author    Cao Zongyan <zongyan.cao@alibaba-inc.com>  2018-09-11 19:59:11 +0800
committer Cao Zongyan <zongyan.cao@alibaba-inc.com>  2018-09-11 19:59:11 +0800
commit    9b3a93edf5a1f259bfe5230cc3b6c076573d4ec9 (patch)
tree      cbb0548282ba1584ed91a1be8f89b03ec882f287 /tensorflow
parent    90cf7fb7786c8a9c135ef73482856b082e80f61a (diff)
parent    e18f84a394bcbde62b344a3b32e8d8fd248fea58 (diff)
Merge remote-tracking branch 'origin'
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/BUILD | 71
-rw-r--r--  tensorflow/__init__.py | 3
-rw-r--r--  tensorflow/api_template.__init__.py | 22
-rw-r--r--  tensorflow/c/BUILD | 11
-rw-r--r--  tensorflow/c/c_api.cc | 7
-rw-r--r--  tensorflow/c/c_api_experimental.cc | 210
-rw-r--r--  tensorflow/c/c_api_experimental.h | 43
-rw-r--r--  tensorflow/c/c_api_test.cc | 16
-rw-r--r--  tensorflow/c/checkpoint_reader.cc | 6
-rw-r--r--  tensorflow/c/checkpoint_reader.h | 6
-rwxr-xr-x [-rw-r--r--]  tensorflow/c/eager/c_api.cc | 36
-rwxr-xr-x [-rw-r--r--]  tensorflow/c/eager/c_api.h | 10
-rw-r--r--  tensorflow/c/eager/c_api_internal.h | 11
-rw-r--r--  tensorflow/c/eager/c_api_test.cc | 82
-rw-r--r--  tensorflow/c/eager/tape.h | 17
-rw-r--r--  tensorflow/c/tf_status_helper.h | 6
-rw-r--r--  tensorflow/cc/framework/cc_op_gen.cc | 11
-rw-r--r--  tensorflow/cc/framework/ops.h | 2
-rw-r--r--  tensorflow/cc/framework/scope.cc | 2
-rw-r--r--  tensorflow/cc/gradients/math_grad.cc | 15
-rw-r--r--  tensorflow/cc/gradients/math_grad_test.cc | 8
-rw-r--r--  tensorflow/cc/saved_model/loader.cc | 6
-rw-r--r--  tensorflow/compiler/aot/BUILD | 10
-rw-r--r--  tensorflow/compiler/aot/codegen.cc | 69
-rw-r--r--  tensorflow/compiler/aot/codegen.h | 4
-rw-r--r--  tensorflow/compiler/aot/codegen_test.cc | 8
-rw-r--r--  tensorflow/compiler/aot/embedded_protocol_buffers.cc | 41
-rw-r--r--  tensorflow/compiler/aot/embedded_protocol_buffers.h | 6
-rw-r--r--  tensorflow/compiler/aot/tests/BUILD | 11
-rw-r--r--  tensorflow/compiler/aot/tests/tfcompile_test.cc | 4
-rw-r--r--  tensorflow/compiler/aot/tfcompile.bzl | 59
-rw-r--r--  tensorflow/compiler/aot/tfcompile_main.cc | 14
-rw-r--r--  tensorflow/compiler/jit/BUILD | 75
-rw-r--r--  tensorflow/compiler/jit/create_xla_launch_op.cc | 15
-rw-r--r--  tensorflow/compiler/jit/create_xla_launch_op_test.cc | 9
-rw-r--r--  tensorflow/compiler/jit/deadness_analysis.cc | 106
-rw-r--r--  tensorflow/compiler/jit/deadness_analysis_internal.h | 2
-rw-r--r--  tensorflow/compiler/jit/deadness_analysis_test.cc | 1
-rw-r--r--  tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc | 79
-rw-r--r--  tensorflow/compiler/jit/encapsulate_subgraphs_pass.h | 6
-rw-r--r--  tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc | 275
-rw-r--r--  tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc | 360
-rw-r--r--  tensorflow/compiler/jit/encapsulate_xla_computations_pass.h | 61
-rw-r--r--  tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc | 346
-rw-r--r--  tensorflow/compiler/jit/graphcycles/BUILD | 1
-rw-r--r--  tensorflow/compiler/jit/graphcycles/graphcycles.cc | 4
-rw-r--r--  tensorflow/compiler/jit/jit_compilation_pass_registration.cc | 7
-rw-r--r--  tensorflow/compiler/jit/kernels/BUILD | 13
-rw-r--r--  tensorflow/compiler/jit/kernels/parallel_check_op.cc | 144
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_launch_op.cc | 48
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_launch_op.h | 11
-rw-r--r--  tensorflow/compiler/jit/legacy_flags/BUILD | 12
-rw-r--r--  tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc | 68
-rw-r--r--  tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h | 52
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass.cc | 334
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass_test.cc | 267
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc | 11
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h | 5
-rw-r--r--  tensorflow/compiler/jit/ops/BUILD | 7
-rw-r--r--  tensorflow/compiler/jit/ops/xla_ops.cc | 19
-rw-r--r--  tensorflow/compiler/jit/partially_decluster_pass.cc | 193
-rw-r--r--  tensorflow/compiler/jit/partially_decluster_pass.h | 31
-rw-r--r--  tensorflow/compiler/jit/partially_decluster_pass_test.cc | 134
-rw-r--r--  tensorflow/compiler/jit/resource_operation_safety_analysis.cc | 336
-rw-r--r--  tensorflow/compiler/jit/resource_operation_safety_analysis.h | 73
-rw-r--r--  tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc | 540
-rw-r--r--  tensorflow/compiler/jit/xla_cluster_util.cc | 39
-rw-r--r--  tensorflow/compiler/jit/xla_cluster_util.h | 14
-rw-r--r--  tensorflow/compiler/jit/xla_cluster_util_test.cc | 1
-rw-r--r--  tensorflow/compiler/jit/xla_compilation_cache.cc | 25
-rw-r--r--  tensorflow/compiler/jit/xla_compilation_cache.h | 6
-rw-r--r--  tensorflow/compiler/jit/xla_compile_on_demand_op.cc | 2
-rw-r--r--  tensorflow/compiler/jit/xla_device.cc | 35
-rw-r--r--  tensorflow/compiler/jit/xla_device.h | 7
-rw-r--r--  tensorflow/compiler/jit/xla_device_context.cc | 30
-rw-r--r--  tensorflow/compiler/jit/xla_device_context.h | 4
-rw-r--r--  tensorflow/compiler/jit/xla_device_ops.h | 18
-rw-r--r--  tensorflow/compiler/jit/xla_fusion_optimizer.cc | 13
-rw-r--r--  tensorflow/compiler/jit/xla_fusion_optimizer_test.cc | 27
-rw-r--r--  tensorflow/compiler/jit/xla_launch_util.cc | 50
-rw-r--r--  tensorflow/compiler/jit/xla_launch_util.h | 2
-rw-r--r--  tensorflow/compiler/jit/xla_tensor.h | 7
-rw-r--r--  tensorflow/compiler/tests/BUILD | 35
-rw-r--r--  tensorflow/compiler/tests/adadelta_test.py | 2
-rw-r--r--  tensorflow/compiler/tests/adagrad_da_test.py | 8
-rw-r--r--  tensorflow/compiler/tests/adagrad_test.py | 6
-rw-r--r--  tensorflow/compiler/tests/adam_test.py | 13
-rw-r--r--  tensorflow/compiler/tests/adamax_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/addsign_test.py | 2
-rw-r--r--  tensorflow/compiler/tests/argminmax_test.py | 2
-rw-r--r--  tensorflow/compiler/tests/binary_ops_test.py | 80
-rw-r--r--  tensorflow/compiler/tests/bucketize_op_test.py | 10
-rw-r--r--  tensorflow/compiler/tests/categorical_op_test.py | 6
-rw-r--r--  tensorflow/compiler/tests/cholesky_op_test.py | 2
-rw-r--r--  tensorflow/compiler/tests/clustering_test.py | 8
-rw-r--r--  tensorflow/compiler/tests/concat_ops_test.py | 30
-rw-r--r--  tensorflow/compiler/tests/conv2d_test.py | 6
-rw-r--r--  tensorflow/compiler/tests/conv3d_test.py | 10
-rw-r--r--  tensorflow/compiler/tests/dense_layer_test.py | 11
-rw-r--r--  tensorflow/compiler/tests/depthwise_conv_op_test.py | 8
-rw-r--r--  tensorflow/compiler/tests/dynamic_slice_ops_test.py | 2
-rw-r--r--  tensorflow/compiler/tests/dynamic_stitch_test.py | 2
-rw-r--r--  tensorflow/compiler/tests/eager_test.py | 101
-rw-r--r--  tensorflow/compiler/tests/extract_image_patches_op_test.py | 2
-rw-r--r--  tensorflow/compiler/tests/fake_quant_ops_test.py | 8
-rw-r--r--  tensorflow/compiler/tests/fft_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/fifo_queue_test.py | 24
-rw-r--r--  tensorflow/compiler/tests/ftrl_test.py | 69
-rw-r--r--  tensorflow/compiler/tests/function_test.py | 10
-rw-r--r--  tensorflow/compiler/tests/fused_batchnorm_test.py | 8
-rw-r--r--  tensorflow/compiler/tests/gather_nd_op_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/gather_test.py | 14
-rw-r--r--  tensorflow/compiler/tests/image_ops_test.py | 60
-rw-r--r--  tensorflow/compiler/tests/jit_test.py | 5
-rw-r--r--  tensorflow/compiler/tests/listdiff_op_test.py | 2
-rw-r--r--  tensorflow/compiler/tests/lrn_ops_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/lstm_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/matrix_band_part_test.py | 2
-rw-r--r--  tensorflow/compiler/tests/matrix_triangular_solve_op_test.py | 2
-rw-r--r--  tensorflow/compiler/tests/momentum_test.py | 6
-rw-r--r--  tensorflow/compiler/tests/nary_ops_test.py | 6
-rw-r--r--  tensorflow/compiler/tests/nullary_ops_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/oom_test.py | 2
-rw-r--r--  tensorflow/compiler/tests/placeholder_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/pooling_ops_3d_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/pooling_ops_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/powersign_test.py | 2
-rw-r--r--  tensorflow/compiler/tests/proximal_adagrad_test.py | 12
-rw-r--r--  tensorflow/compiler/tests/proximal_gradient_descent_test.py | 12
-rw-r--r--  tensorflow/compiler/tests/qr_op_test.py | 6
-rw-r--r--  tensorflow/compiler/tests/random_ops_test.py | 10
-rw-r--r--  tensorflow/compiler/tests/randomized_tests.cc | 69
-rw-r--r--  tensorflow/compiler/tests/reduce_ops_test.py | 58
-rw-r--r--  tensorflow/compiler/tests/reduce_window_test.py | 2
-rw-r--r--  tensorflow/compiler/tests/reshape_op_test.py | 50
-rw-r--r--  tensorflow/compiler/tests/reverse_ops_test.py | 27
-rw-r--r--  tensorflow/compiler/tests/reverse_sequence_op_test.py | 2
-rw-r--r--  tensorflow/compiler/tests/rmsprop_test.py | 2
-rw-r--r--  tensorflow/compiler/tests/scan_ops_test.py | 12
-rw-r--r--  tensorflow/compiler/tests/scatter_nd_op_test.py | 2
-rw-r--r--  tensorflow/compiler/tests/segment_reduction_ops_test.py | 2
-rw-r--r--  tensorflow/compiler/tests/slice_ops_test.py | 33
-rw-r--r--  tensorflow/compiler/tests/sort_ops_test.py | 6
-rw-r--r--  tensorflow/compiler/tests/spacetobatch_op_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/sparse_to_dense_op_test.py | 22
-rw-r--r--  tensorflow/compiler/tests/stack_ops_test.py | 12
-rw-r--r--  tensorflow/compiler/tests/stateless_random_ops_test.py | 12
-rw-r--r--  tensorflow/compiler/tests/tensor_array_ops_test.py | 72
-rw-r--r--  tensorflow/compiler/tests/ternary_ops_test.py | 2
-rw-r--r--  tensorflow/compiler/tests/unary_ops_test.py | 22
-rw-r--r--  tensorflow/compiler/tests/while_test.py | 8
-rw-r--r--  tensorflow/compiler/tests/xla_device_test.py | 6
-rw-r--r--  tensorflow/compiler/tests/xla_ops_test.py | 301
-rw-r--r--  tensorflow/compiler/tf2xla/BUILD | 174
-rw-r--r--  tensorflow/compiler/tf2xla/cc/BUILD | 4
-rw-r--r--  tensorflow/compiler/tf2xla/const_analysis.cc | 35
-rw-r--r--  tensorflow/compiler/tf2xla/const_analysis.h | 20
-rw-r--r--  tensorflow/compiler/tf2xla/const_analysis_test.cc | 19
-rw-r--r--  tensorflow/compiler/tf2xla/dump_graph.cc | 8
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_cond.cc | 1354
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_cond.h | 248
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_cond_test.cc | 106
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_control_flow.cc | 1522
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_control_flow.h | 6
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc | 93
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc | 72
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_control_flow_util.h | 56
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_while.cc | 668
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_while.h | 32
-rw-r--r--  tensorflow/compiler/tf2xla/graph_compiler.cc | 8
-rw-r--r--  tensorflow/compiler/tf2xla/graph_compiler.h | 2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/BUILD | 22
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc | 6
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/bcast_ops.cc | 13
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/binary_ops.cc | 30
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc | 101
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/conv_ops.cc | 78
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/cwise_ops.h | 4
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc | 2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/diag_op.cc | 6
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/gather_op.cc | 32
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/identity_op.cc | 12
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/if_op.cc | 49
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/if_op.h | 2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc | 165
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc | 6
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc | 32
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/pooling_ops.cc | 146
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/qr_op.cc | 14
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/random_ops.cc | 3
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc | 92
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc | 22
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/reshape_op.cc | 6
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/retval_op.cc | 13
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/reverse_op.cc | 20
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/select_op.cc | 4
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/shape_op.cc | 2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/slice_op.cc | 2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/softmax_op.cc | 4
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc | 6
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc | 2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/stack_ops.cc | 2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc | 30
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 4
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/tile_ops.cc | 4
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/training_ops.cc | 4
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/transpose_op.cc | 2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/while_op.cc | 31
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/while_op.h | 2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc | 115
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc | 101
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc | 65
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc | 105
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc | 102
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc | 147
-rw-r--r--  tensorflow/compiler/tf2xla/lib/BUILD | 12
-rw-r--r--  tensorflow/compiler/tf2xla/lib/batch_dot.cc | 12
-rw-r--r--  tensorflow/compiler/tf2xla/lib/batch_dot.h | 10
-rw-r--r--  tensorflow/compiler/tf2xla/lib/cholesky.cc | 26
-rw-r--r--  tensorflow/compiler/tf2xla/lib/cholesky.h | 6
-rw-r--r--  tensorflow/compiler/tf2xla/lib/qr.cc | 70
-rw-r--r--  tensorflow/compiler/tf2xla/lib/qr.h | 6
-rw-r--r--  tensorflow/compiler/tf2xla/lib/scatter.cc | 10
-rw-r--r--  tensorflow/compiler/tf2xla/lib/triangular_solve.cc | 58
-rw-r--r--  tensorflow/compiler/tf2xla/lib/triangular_solve.h | 9
-rw-r--r--  tensorflow/compiler/tf2xla/lib/util.cc | 60
-rw-r--r--  tensorflow/compiler/tf2xla/lib/util.h | 24
-rw-r--r--  tensorflow/compiler/tf2xla/lib/while_loop.cc | 12
-rw-r--r--  tensorflow/compiler/tf2xla/lib/while_loop.h | 14
-rw-r--r--  tensorflow/compiler/tf2xla/literal_util.cc | 5
-rw-r--r--  tensorflow/compiler/tf2xla/literal_util.h | 7
-rw-r--r--  tensorflow/compiler/tf2xla/literal_util_test.cc | 27
-rw-r--r--  tensorflow/compiler/tf2xla/ops/BUILD | 2
-rw-r--r--  tensorflow/compiler/tf2xla/ops/xla_ops.cc | 192
-rw-r--r--  tensorflow/compiler/tf2xla/python/BUILD | 1
-rw-r--r--  tensorflow/compiler/tf2xla/python/xla.py | 336
-rw-r--r--  tensorflow/compiler/tf2xla/resource_operation_table.cc | 130
-rw-r--r--  tensorflow/compiler/tf2xla/resource_operation_table.h | 71
-rw-r--r--  tensorflow/compiler/tf2xla/resource_operation_table_test.cc | 66
-rw-r--r--  tensorflow/compiler/tf2xla/sharding_util.cc | 38
-rw-r--r--  tensorflow/compiler/tf2xla/sharding_util.h | 21
-rw-r--r--  tensorflow/compiler/tf2xla/sharding_util_test.cc | 2
-rw-r--r--  tensorflow/compiler/tf2xla/side_effect_util.cc | 67
-rw-r--r--  tensorflow/compiler/tf2xla/side_effect_util.h | 47
-rw-r--r--  tensorflow/compiler/tf2xla/str_util.cc | 44
-rw-r--r--  tensorflow/compiler/tf2xla/str_util.h | 42
-rw-r--r--  tensorflow/compiler/tf2xla/str_util_test.cc | 60
-rw-r--r--  tensorflow/compiler/tf2xla/test_util.cc | 8
-rw-r--r--  tensorflow/compiler/tf2xla/test_util.h | 16
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla.cc | 16
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc | 8
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla_test.cc | 8
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla_util.cc | 16
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla_util.h | 2
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla_util_test.cc | 12
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compilation_device.cc | 13
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler.cc | 203
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler.h | 53
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler_test.cc | 421
-rw-r--r--  tensorflow/compiler/tf2xla/xla_context.cc | 27
-rw-r--r--  tensorflow/compiler/tf2xla/xla_context.h | 6
-rw-r--r--  tensorflow/compiler/tf2xla/xla_helpers.cc | 4
-rw-r--r--  tensorflow/compiler/tf2xla/xla_helpers.h | 4
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.cc | 72
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.h | 28
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_registry.cc | 47
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_registry.h | 23
-rw-r--r--  tensorflow/compiler/tf2xla/xla_resource.cc | 4
-rw-r--r--  tensorflow/compiler/xla/BUILD | 58
-rw-r--r--  tensorflow/compiler/xla/array.h | 40
-rw-r--r--  tensorflow/compiler/xla/array2d.h | 7
-rw-r--r--  tensorflow/compiler/xla/array4d.h | 6
-rw-r--r--  tensorflow/compiler/xla/array4d_test.cc | 24
-rw-r--r--  tensorflow/compiler/xla/array_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/client/BUILD | 20
-rw-r--r--  tensorflow/compiler/xla/client/client.cc | 36
-rw-r--r--  tensorflow/compiler/xla/client/client.h | 18
-rw-r--r--  tensorflow/compiler/xla/client/client_library.cc | 10
-rw-r--r--  tensorflow/compiler/xla/client/compile_only_client.cc | 6
-rw-r--r--  tensorflow/compiler/xla/client/compile_only_client.h | 4
-rw-r--r--  tensorflow/compiler/xla/client/executable_build_options.cc | 30
-rw-r--r--  tensorflow/compiler/xla/client/executable_build_options.h | 35
-rw-r--r--  tensorflow/compiler/xla/client/lib/BUILD | 21
-rw-r--r--  tensorflow/compiler/xla/client/lib/arithmetic.cc | 4
-rw-r--r--  tensorflow/compiler/xla/client/lib/constants.cc | 2
-rw-r--r--  tensorflow/compiler/xla/client/lib/constants.h | 6
-rw-r--r--  tensorflow/compiler/xla/client/lib/conv_grad_size_util.cc | 96
-rw-r--r--  tensorflow/compiler/xla/client/lib/conv_grad_size_util.h | 44
-rw-r--r--  tensorflow/compiler/xla/client/lib/math.cc | 9
-rw-r--r--  tensorflow/compiler/xla/client/lib/math.h | 3
-rw-r--r--  tensorflow/compiler/xla/client/lib/numeric.cc | 58
-rw-r--r--  tensorflow/compiler/xla/client/lib/numeric_test.cc | 10
-rw-r--r--  tensorflow/compiler/xla/client/lib/pooling.cc | 178
-rw-r--r--  tensorflow/compiler/xla/client/lib/pooling.h | 29
-rw-r--r--  tensorflow/compiler/xla/client/lib/pooling_test.cc | 113
-rw-r--r--  tensorflow/compiler/xla/client/lib/testing.cc | 9
-rw-r--r--  tensorflow/compiler/xla/client/local_client.cc | 61
-rw-r--r--  tensorflow/compiler/xla/client/local_client.h | 25
-rw-r--r--  tensorflow/compiler/xla/client/padding.cc | 17
-rw-r--r--  tensorflow/compiler/xla/client/padding.h | 15
-rw-r--r--  tensorflow/compiler/xla/client/sharding_builder.h | 2
-rw-r--r--  tensorflow/compiler/xla/client/xla_builder.cc | 732
-rw-r--r--  tensorflow/compiler/xla/client/xla_builder.h | 807
-rw-r--r--  tensorflow/compiler/xla/client/xla_builder_test.cc | 9
-rw-r--r--  tensorflow/compiler/xla/client/xla_computation.cc | 4
-rw-r--r--  tensorflow/compiler/xla/device_util.h | 6
-rw-r--r--  tensorflow/compiler/xla/index_util.cc | 19
-rw-r--r--  tensorflow/compiler/xla/index_util.h | 14
-rw-r--r--  tensorflow/compiler/xla/index_util_test.cc | 8
-rw-r--r--  tensorflow/compiler/xla/iterator_util.h | 6
-rw-r--r--  tensorflow/compiler/xla/iterator_util_test.cc | 6
-rw-r--r--  tensorflow/compiler/xla/layout_util.cc | 46
-rw-r--r--  tensorflow/compiler/xla/layout_util.h | 15
-rw-r--r--  tensorflow/compiler/xla/layout_util_test.cc | 6
-rw-r--r--  tensorflow/compiler/xla/legacy_flags/BUILD | 5
-rw-r--r--  tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc | 11
-rw-r--r--  tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h | 30
-rw-r--r--  tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc | 1
-rw-r--r--  tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc | 6
-rw-r--r--  tensorflow/compiler/xla/literal.cc | 377
-rw-r--r--  tensorflow/compiler/xla/literal.h | 220
-rw-r--r--  tensorflow/compiler/xla/literal_comparison.cc | 399
-rw-r--r--  tensorflow/compiler/xla/literal_test.cc | 999
-rw-r--r--  tensorflow/compiler/xla/literal_util.cc | 297
-rw-r--r--  tensorflow/compiler/xla/literal_util.h | 251
-rw-r--r--  tensorflow/compiler/xla/map_util.h | 2
-rw-r--r--  tensorflow/compiler/xla/metric_table_report.cc | 28
-rw-r--r--  tensorflow/compiler/xla/metric_table_report.h | 5
-rw-r--r--  tensorflow/compiler/xla/packed_literal_reader.cc | 22
-rw-r--r--  tensorflow/compiler/xla/packed_literal_reader.h | 3
-rw-r--r--  tensorflow/compiler/xla/python/BUILD | 5
-rw-r--r--  tensorflow/compiler/xla/python/local_computation_builder.cc | 127
-rw-r--r--  tensorflow/compiler/xla/python/local_computation_builder.h | 85
-rw-r--r--  tensorflow/compiler/xla/python/local_computation_builder.i | 61
-rw-r--r--  tensorflow/compiler/xla/python/numpy_bridge.cc | 49
-rw-r--r--  tensorflow/compiler/xla/python/numpy_bridge.h | 4
-rw-r--r--  tensorflow/compiler/xla/python/xla_client.py | 9
-rw-r--r--  tensorflow/compiler/xla/reference_util.cc | 173
-rw-r--r--  tensorflow/compiler/xla/reference_util.h | 110
-rw-r--r--  tensorflow/compiler/xla/reference_util_test.cc | 62
-rw-r--r--  tensorflow/compiler/xla/rpc/BUILD | 2
-rw-r--r--  tensorflow/compiler/xla/rpc/grpc_client_test.cc | 14
-rw-r--r--  tensorflow/compiler/xla/rpc/grpc_service_main.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/BUILD | 327
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc | 354
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 511
-rw-r--r--  tensorflow/compiler/xla/service/allocation_tracker.cc | 22
-rw-r--r--  tensorflow/compiler/xla/service/backend.cc | 17
-rw-r--r--  tensorflow/compiler/xla/service/backend.h | 8
-rw-r--r--  tensorflow/compiler/xla/service/batch_dot_simplification.cc | 14
-rw-r--r--  tensorflow/compiler/xla/service/batch_dot_simplification.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/batchnorm_expander.cc | 18
-rw-r--r--  tensorflow/compiler/xla/service/batchnorm_expander.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/batchnorm_expander_test.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_conversion_folding.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc | 22
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_normalization.cc | 37
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_normalization.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_normalization_test.cc | 34
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_propagation.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_propagation.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_propagation_test.cc | 245
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment.cc | 249
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment.h | 13
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment_test.cc | 350
-rw-r--r--  tensorflow/compiler/xla/service/buffer_liveness.cc | 24
-rw-r--r--  tensorflow/compiler/xla/service/buffer_liveness_test.cc | 136
-rw-r--r--  tensorflow/compiler/xla/service/buffer_value.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/buffer_value.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/call_graph.cc | 40
-rw-r--r--  tensorflow/compiler/xla/service/call_graph.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/call_graph_test.cc | 26
-rw-r--r--  tensorflow/compiler/xla/service/call_inliner.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/call_inliner.h | 8
-rw-r--r--  tensorflow/compiler/xla/service/call_inliner_test.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/channel_tracker.cc | 16
-rw-r--r--  tensorflow/compiler/xla/service/channel_tracker.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/compile_only_service.cc | 10
-rw-r--r--  tensorflow/compiler/xla/service/compile_only_service.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/compiler.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/compiler.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/computation_layout.cc | 9
-rw-r--r--  tensorflow/compiler/xla/service/computation_placer.cc | 16
-rw-r--r--  tensorflow/compiler/xla/service/conditional_simplifier.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/conditional_simplifier.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/convolution_feature_group_converter.cc | 9
-rw-r--r--  tensorflow/compiler/xla/service/convolution_feature_group_converter.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/copy_insertion.cc | 77
-rw-r--r--  tensorflow/compiler/xla/service/copy_insertion.h | 26
-rw-r--r--  tensorflow/compiler/xla/service/cpu/BUILD | 41
-rw-r--r--  tensorflow/compiler/xla/service/cpu/buffer_info_util.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/buffer_info_util.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/compiler_functor.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/cpu/conv_canonicalization.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc | 14
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_compiler.cc | 162
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_compiler.h | 13
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h | 8
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_executable.cc | 58
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_executable.h | 40
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc | 76
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc | 20
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_options.cc | 31
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_options.h | 5
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_runtime.cc | 1
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc | 26
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc | 22
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h | 13
-rw-r--r--  tensorflow/compiler/xla/service/cpu/disassembler.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc | 103
-rw-r--r--  tensorflow/compiler/xla/service/cpu/dot_op_emitter.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc | 25
-rw-r--r--  tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 508
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.h | 87
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_function.cc | 43
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_function.h | 17
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc | 12
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/cpu/runtime_fork_join.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/sample_harness.cc | 36
-rw-r--r--  tensorflow/compiler/xla/service/cpu/shape_partition_test.cc | 30
-rw-r--r--  tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc | 19
-rw-r--r--  tensorflow/compiler/xla/service/cpu/tests/BUILD | 6
-rw-r--r--  tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc | 9
-rw-r--r--  tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc | 41
-rw-r--r--  tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc | 66
-rw-r--r--  tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc | 13
-rw-r--r--  tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc | 27
-rw-r--r--  tensorflow/compiler/xla/service/cpu/vector_support_library.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/cpu/vector_support_library.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/xfeed_manager.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/xfeed_manager.h | 5
-rw-r--r--  tensorflow/compiler/xla/service/defuser.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/despecializer.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/despecializer.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/device_memory_allocator.cc | 11
-rw-r--r--  tensorflow/compiler/xla/service/device_memory_allocator.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/dfs_hlo_visitor.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/dfs_hlo_visitor.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h | 14
-rw-r--r--  tensorflow/compiler/xla/service/dot_decomposer.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/dot_decomposer.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/elemental_ir_emitter.cc | 889
-rw-r--r--  tensorflow/compiler/xla/service/elemental_ir_emitter.h | 119
-rw-r--r--  tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc | 11
-rw-r--r--  tensorflow/compiler/xla/service/executable.cc | 20
-rw-r--r--  tensorflow/compiler/xla/service/executable.h | 57
-rw-r--r--  tensorflow/compiler/xla/service/execution_tracker.cc | 10
-rw-r--r--  tensorflow/compiler/xla/service/flatten_call_graph.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/flatten_call_graph_test.cc | 22
-rw-r--r--  tensorflow/compiler/xla/service/gather_expander.cc | 188
-rw-r--r--  tensorflow/compiler/xla/service/gather_expander.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/gather_expander_test.cc | 16
-rw-r--r--  tensorflow/compiler/xla/service/generic_transfer_manager.cc | 9
-rw-r--r--  tensorflow/compiler/xla/service/generic_transfer_manager.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD | 106
-rw-r--r--  tensorflow/compiler/xla/service/gpu/buffer_allocations.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/gpu/buffer_allocations.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/buffer_comparator.cc | 9
-rw-r--r--  tensorflow/compiler/xla/service/gpu/conditional_thunk.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/convolution_thunk.cc | 55
-rw-r--r--  tensorflow/compiler/xla/service/gpu/convolution_thunk.h | 68
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc | 135
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h | 9
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc | 160
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc | 164
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc | 88
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h | 44
-rw-r--r--  tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc | 152
-rw-r--r--  tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h | 64
-rw-r--r--  tensorflow/compiler/xla/service/gpu/fft_thunk.cc | 13
-rw-r--r--  tensorflow/compiler/xla/service/gpu/fft_thunk.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/for_thunk.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/fusion_merger.cc | 37
-rw-r--r--  tensorflow/compiler/xla/service/gpu/fusion_merger.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc | 33
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gemm_thunk.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_executable.cc | 19
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_executable.h | 10
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_fusible.cc | 84
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_fusible.h | 49
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc | 332
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc (renamed from tensorflow/compiler/xla/service/gpu/hlo_schedule.cc) | 26
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h (renamed from tensorflow/compiler/xla/service/gpu/hlo_schedule.h) | 20
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc (renamed from tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc) | 86
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc | 12
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc | 11
-rw-r--r--  tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/gpu/infeed_manager.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/infeed_thunk.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/instruction_fusion.cc | 14
-rw-r--r--  tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc | 83
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc | 74
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emission_utils.h | 26
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter.cc | 136
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter.h | 25
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc | 16
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 655
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h | 78
-rw-r--r--  tensorflow/compiler/xla/service/gpu/kernel_thunk.cc | 24
-rw-r--r--  tensorflow/compiler/xla/service/gpu/kernel_thunk.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD | 3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc | 11
-rw-r--r--  tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc | 36
-rw-r--r--  tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc | 17
-rw-r--r--  tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc | 96
-rw-r--r--  tensorflow/compiler/xla/service/gpu/multi_output_fusion.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc | 160
-rw-r--r--  tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc | 67
-rw-r--r--  tensorflow/compiler/xla/service/gpu/nvptx_compiler.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/outfeed_manager.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/pad_insertion.cc | 27
-rw-r--r--  tensorflow/compiler/xla/service/gpu/pad_insertion.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h | 11
-rw-r--r--  tensorflow/compiler/xla/service/gpu/partition_assignment.cc | 13
-rw-r--r--  tensorflow/compiler/xla/service/gpu/stream_assignment.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc | 62
-rw-r--r--  tensorflow/compiler/xla/service/gpu/stream_executor_util.cc | 32
-rw-r--r--  tensorflow/compiler/xla/service/gpu/tests/BUILD | 6
-rw-r--r--  tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc | 41
-rw-r--r--  tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc | 32
-rw-r--r--  tensorflow/compiler/xla/service/gpu/thunk_schedule.cc | 11
-rw-r--r--  tensorflow/compiler/xla/service/gpu/tuple_thunk.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/tuple_thunk.h | 5
-rw-r--r--  tensorflow/compiler/xla/service/gpu/while_thunk.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/gpu/while_transformer_test.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/graphviz_example.cc | 16
-rw-r--r--  tensorflow/compiler/xla/service/heap_simulator.cc | 86
-rw-r--r--  tensorflow/compiler/xla/service/heap_simulator.h | 72
-rw-r--r--  tensorflow/compiler/xla/service/heap_simulator_test.cc | 91
-rw-r--r--  tensorflow/compiler/xla/service/hlo.proto | 41
-rw-r--r--  tensorflow/compiler/xla/service/hlo_alias_analysis.cc | 20
-rw-r--r--  tensorflow/compiler/xla/service/hlo_alias_analysis.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc | 36
-rw-r--r--  tensorflow/compiler/xla/service/hlo_buffer.cc | 16
-rw-r--r--  tensorflow/compiler/xla/service/hlo_buffer.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.cc | 165
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.h | 27
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation_test.cc | 38
-rw-r--r--  tensorflow/compiler/xla/service/hlo_constant_folding.cc | 15
-rw-r--r--  tensorflow/compiler/xla/service/hlo_constant_folding.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_constant_folding_test.cc | 51
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cost_analysis.cc | 42
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cost_analysis.h | 5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc | 55
-rw-r--r--  tensorflow/compiler/xla/service/hlo_creation_utils.cc | 101
-rw-r--r--  tensorflow/compiler/xla/service/hlo_creation_utils.h | 49
-rw-r--r--  tensorflow/compiler/xla/service/hlo_creation_utils_test.cc | 135
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cse.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cse.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cse_test.cc | 99
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc | 51
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dataflow_analysis.h | 7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc | 82
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dce.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dce_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_isolator.cc | 33
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_isolator.h | 10
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_map.cc | 73
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_map.h | 21
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_metadata.h | 12
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_remover.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_test.cc | 230
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_verifier.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_verifier.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_element_type_converter.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_element_type_converter.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.cc | 584
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.h | 70
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_test.cc | 719
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 647
-rw-r--r--  tensorflow/compiler/xla/service/hlo_execution_profile.cc | 14
-rw-r--r--  tensorflow/compiler/xla/service/hlo_execution_profile_test.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 253
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc | 670
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h | 246
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction_test.cc | 146
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.cc | 808
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.h | 478
-rw-r--r--  tensorflow/compiler/xla/service/hlo_lexer.cc | 38
-rw-r--r--  tensorflow/compiler/xla/service/hlo_lexer.h | 14
-rw-r--r--  tensorflow/compiler/xla/service/hlo_liveness_analysis.cc | 10
-rw-r--r--  tensorflow/compiler/xla/service/hlo_matchers.cc | 12
-rw-r--r--  tensorflow/compiler/xla/service/hlo_matchers.h | 16
-rw-r--r--  tensorflow/compiler/xla/service/hlo_memory_scheduler.cc (renamed from tensorflow/compiler/xla/service/hlo_scheduling.cc) | 76
-rw-r--r--  tensorflow/compiler/xla/service/hlo_memory_scheduler.h (renamed from tensorflow/compiler/xla/service/hlo_scheduling.h) | 55
-rw-r--r--  tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc (renamed from tensorflow/compiler/xla/service/hlo_scheduling_test.cc) | 133
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module.cc | 55
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module.h | 28
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_config.cc | 13
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_config.h | 16
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_dce.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_group_metadata.cc | 57
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_group_metadata.h | 15
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_group_util.cc | 32
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_group_util.h | 9
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_test.cc | 65
-rw-r--r--  tensorflow/compiler/xla/service/hlo_opcode.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_opcode.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_ordering.cc | 124
-rw-r--r--  tensorflow/compiler/xla/service/hlo_ordering.h | 26
-rw-r--r--  tensorflow/compiler/xla/service/hlo_ordering_test.cc | 102
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.cc | 644
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.h | 24
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser_test.cc | 230
-rw-r--r--  tensorflow/compiler/xla/service/hlo_pass_fix.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_pass_interface.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_pass_pipeline.cc | 33
-rw-r--r--  tensorflow/compiler/xla/service/hlo_pass_pipeline.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_proto_util.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_proto_util_test.cc | 1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_reachability.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/hlo_reachability.h | 13
-rw-r--r--  tensorflow/compiler/xla/service/hlo_reachability_test.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization.cc | 187
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization.h | 92
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization_test.cc | 83
-rw-r--r--  tensorflow/compiler/xla/service/hlo_runner.cc | 53
-rw-r--r--  tensorflow/compiler/xla/service/hlo_runner.h | 36
-rw-r--r--  tensorflow/compiler/xla/service/hlo_schedule.cc | 343
-rw-r--r--  tensorflow/compiler/xla/service/hlo_schedule.h | 158
-rw-r--r--  tensorflow/compiler/xla/service/hlo_schedule_test.cc | 341
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding.cc | 72
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding.h | 16
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding_metadata.cc | 438
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding_metadata.h | 42
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding_test.cc | 11
-rw-r--r--  tensorflow/compiler/xla/service/hlo_subcomputation_unification.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc | 25
-rw-r--r--  tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_value.cc | 27
-rw-r--r--  tensorflow/compiler/xla/service/hlo_value.h | 12
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc | 325
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.h | 63
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier_test.cc | 87
-rw-r--r--  tensorflow/compiler/xla/service/human_readable_profile_builder.cc | 41
-rw-r--r--  tensorflow/compiler/xla/service/human_readable_profile_builder.h | 18
-rw-r--r--  tensorflow/compiler/xla/service/implicit_broadcast_remover.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/indexed_array_analysis.cc | 244
-rw-r--r--  tensorflow/compiler/xla/service/indexed_array_analysis.h | 35
-rw-r--r--  tensorflow/compiler/xla/service/indexed_array_analysis_test.cc | 306
-rw-r--r--  tensorflow/compiler/xla/service/inliner.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/inliner.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/inliner_test.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/instruction_fusion.cc | 34
-rw-r--r--  tensorflow/compiler/xla/service/instruction_fusion.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/instruction_fusion_test.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/interpreter/BUILD | 10
-rw-r--r--  tensorflow/compiler/xla/service/interpreter/compiler.cc | 10
-rw-r--r--  tensorflow/compiler/xla/service/interpreter/executable.cc | 21
-rw-r--r--  tensorflow/compiler/xla/service/interpreter/executable.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/interpreter/executor.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/interpreter/platform.cc | 11
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment.cc | 254
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment.h | 7
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment_test.cc | 229
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/BUILD | 28
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc | 19
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h | 14
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/ir_array.cc | 41
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/ir_array.h | 42
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h | 400
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h | 50
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc | 13
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h | 8
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc | 44
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h | 36
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc | 66
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/llvm_util.h | 42
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc | 10
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h | 7
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/sort_util.cc | 49
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/sort_util.h | 8
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc | 13
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h | 10
-rw-r--r--  tensorflow/compiler/xla/service/local_service.cc | 28
-rw-r--r--  tensorflow/compiler/xla/service/local_service.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/logical_buffer.cc | 11
-rw-r--r--  tensorflow/compiler/xla/service/logical_buffer.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/logical_buffer_analysis.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/maybe_owning_device_memory.cc | 41
-rw-r--r--  tensorflow/compiler/xla/service/maybe_owning_device_memory.h | 70
-rw-r--r--  tensorflow/compiler/xla/service/multi_output_fusion.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/multi_output_fusion.h | 18
-rw-r--r--  tensorflow/compiler/xla/service/name_uniquer.cc | 17
-rw-r--r--  tensorflow/compiler/xla/service/name_uniquer.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/pattern_matcher.h | 9
-rw-r--r--  tensorflow/compiler/xla/service/platform_util.cc | 63
-rw-r--r--  tensorflow/compiler/xla/service/reduce_precision_insertion.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/reshape_mover.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/reshape_mover.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/reshape_mover_test.cc | 10
-rw-r--r--  tensorflow/compiler/xla/service/scatter_expander.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/scatter_expander.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/service.cc | 156
-rw-r--r--  tensorflow/compiler/xla/service/service.h | 26
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.cc | 1072
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.h | 77
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference_test.cc | 296
-rw-r--r--  tensorflow/compiler/xla/service/shaped_buffer.cc | 15
-rw-r--r--  tensorflow/compiler/xla/service/shaped_buffer.h | 10
-rw-r--r--  tensorflow/compiler/xla/service/shaped_buffer_test.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/source_map_util.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/source_map_util.h | 40
-rw-r--r--  tensorflow/compiler/xla/service/stream_pool.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/transfer_manager.cc | 33
-rw-r--r--  tensorflow/compiler/xla/service/transfer_manager.h | 35
-rw-r--r--  tensorflow/compiler/xla/service/transpose_folding.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/transpose_folding.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/transpose_folding_test.cc | 24
-rw-r--r--  tensorflow/compiler/xla/service/tuple_points_to_analysis.cc | 60
-rw-r--r--  tensorflow/compiler/xla/service/tuple_points_to_analysis.h | 10
-rw-r--r--  tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc | 25
-rw-r--r--  tensorflow/compiler/xla/service/tuple_simplifier.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/tuple_simplifier_test.cc | 20
-rw-r--r--  tensorflow/compiler/xla/service/tuple_util.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/tuple_util.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_analysis.cc | 24
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_analysis.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_constant_sinking.cc | 13
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_constant_sinking.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc | 20
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_simplifier.cc | 20
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_simplifier.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_simplifier_test.cc | 15
-rw-r--r--  tensorflow/compiler/xla/service/while_util.cc | 14
-rw-r--r--  tensorflow/compiler/xla/service/while_util.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/while_util_test.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h | 2
-rw-r--r--  tensorflow/compiler/xla/shape_layout.cc | 8
-rw-r--r--  tensorflow/compiler/xla/shape_tree.h | 40
-rw-r--r--  tensorflow/compiler/xla/shape_tree_test.cc | 3
-rw-r--r--  tensorflow/compiler/xla/shape_util.cc | 168
-rw-r--r--  tensorflow/compiler/xla/shape_util.h | 94
-rw-r--r--  tensorflow/compiler/xla/shape_util_test.cc | 24
-rw-r--r--  tensorflow/compiler/xla/sparse_index_array.cc | 19
-rw-r--r--  tensorflow/compiler/xla/sparse_index_array.h | 22
-rw-r--r--  tensorflow/compiler/xla/sparse_index_array_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/status_macros.cc | 24
-rw-r--r--  tensorflow/compiler/xla/test.h | 6
-rw-r--r--  tensorflow/compiler/xla/test_helpers.h | 2
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD | 129
-rw-r--r--  tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc | 428
-rw-r--r--  tensorflow/compiler/xla/tests/batch_normalization_test.cc | 132
-rw-r--r--  tensorflow/compiler/xla/tests/bfloat16_test.cc | 28
-rw-r--r--  tensorflow/compiler/xla/tests/broadcast_simple_test.cc | 109
-rw-r--r--  tensorflow/compiler/xla/tests/broadcast_test.cc | 55
-rw-r--r--  tensorflow/compiler/xla/tests/call_test.cc | 19
-rw-r--r--  tensorflow/compiler/xla/tests/check_execution_arity_test.cc | 14
-rw-r--r--  tensorflow/compiler/xla/tests/client_library_test_base.cc | 140
-rw-r--r--  tensorflow/compiler/xla/tests/client_library_test_base.h | 221
-rw-r--r--  tensorflow/compiler/xla/tests/client_test.cc | 29
-rw-r--r--  tensorflow/compiler/xla/tests/compilation_cache_test.cc | 30
-rw-r--r--  tensorflow/compiler/xla/tests/compute_constant_test.cc | 36
-rw-r--r--  tensorflow/compiler/xla/tests/concat_test.cc | 20
-rw-r--r--  tensorflow/compiler/xla/tests/conditional_test.cc | 100
-rw-r--r--  tensorflow/compiler/xla/tests/constants_test.cc | 25
-rw-r--r--  tensorflow/compiler/xla/tests/convert_test.cc | 53
-rw-r--r--  tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc | 9
-rw-r--r--  tensorflow/compiler/xla/tests/convolution_test.cc | 165
-rw-r--r--  tensorflow/compiler/xla/tests/convolution_variants_test.cc | 24
-rw-r--r--  tensorflow/compiler/xla/tests/copy_test.cc | 70
-rw-r--r--  tensorflow/compiler/xla/tests/cross_replica_sum_test.cc | 11
-rw-r--r--  tensorflow/compiler/xla/tests/custom_call_test.cc | 14
-rw-r--r--  tensorflow/compiler/xla/tests/deallocation_test.cc | 4
-rw-r--r--  tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc | 45
-rw-r--r--  tensorflow/compiler/xla/tests/dot_operation_test.cc | 86
-rw-r--r--  tensorflow/compiler/xla/tests/dynamic_ops_test.cc | 131
-rw-r--r--  tensorflow/compiler/xla/tests/execution_profile_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/tests/floor_ceil_test.cc | 11
-rw-r--r--  tensorflow/compiler/xla/tests/fusion_test.cc | 146
-rw-r--r--  tensorflow/compiler/xla/tests/gather_operation_test.cc | 382
-rw-r--r--  tensorflow/compiler/xla/tests/half_test.cc | 5
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_test_base.cc | 97
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_test_base.h | 73
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_verified_test_base.cc | 13
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_verified_test_base.h | 19
-rw-r--r--  tensorflow/compiler/xla/tests/iota_test.cc | 117
-rw-r--r--  tensorflow/compiler/xla/tests/literal_test_util.cc | 7
-rw-r--r--  tensorflow/compiler/xla/tests/literal_test_util.h | 44
-rw-r--r--  tensorflow/compiler/xla/tests/literal_test_util_test.cc | 53
-rw-r--r--  tensorflow/compiler/xla/tests/llvm_compiler_test.cc | 3
-rw-r--r--  tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc | 8
-rw-r--r--  tensorflow/compiler/xla/tests/local_client_allocation_test.cc | 10
-rw-r--r--  tensorflow/compiler/xla/tests/local_client_execute_test.cc | 253
-rw-r--r--  tensorflow/compiler/xla/tests/local_client_test_base.cc | 12
-rw-r--r--  tensorflow/compiler/xla/tests/local_client_test_base.h | 13
-rw-r--r--  tensorflow/compiler/xla/tests/map_test.cc | 150
-rw-r--r--  tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc | 34
-rw-r--r--  tensorflow/compiler/xla/tests/multioutput_fusion_test.cc | 155
-rw-r--r--  tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc | 30
-rw-r--r--  tensorflow/compiler/xla/tests/pad_test.cc | 76
-rw-r--r--  tensorflow/compiler/xla/tests/params_test.cc | 149
-rw-r--r--  tensorflow/compiler/xla/tests/pred_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/tests/prng_test.cc | 66
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_hlo_test.cc | 13
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_precision_test.cc | 42
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_test.cc | 158
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_window_test.cc | 365
-rw-r--r--  tensorflow/compiler/xla/tests/replay_test.cc | 16
-rw-r--r--  tensorflow/compiler/xla/tests/reshape_motion_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/tests/reshape_test.cc | 354
-rw-r--r--  tensorflow/compiler/xla/tests/reverse_test.cc | 43
-rw-r--r--  tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc | 53
-rw-r--r--  tensorflow/compiler/xla/tests/round_trip_transfer_test.cc | 51
-rw-r--r--  tensorflow/compiler/xla/tests/sample_text_test.cc | 4
-rw-r--r--  tensorflow/compiler/xla/tests/scalar_computations_test.cc | 52
-rw-r--r--  tensorflow/compiler/xla/tests/scatter_test.cc | 177
-rw-r--r--  tensorflow/compiler/xla/tests/select_and_scatter_test.cc | 4
-rw-r--r--  tensorflow/compiler/xla/tests/slice_test.cc | 48
-rw-r--r--  tensorflow/compiler/xla/tests/test_macros.cc | 13
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils.cc | 264
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils.h | 43
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils_test.cc | 55
-rw-r--r--  tensorflow/compiler/xla/tests/token_hlo_test.cc | 38
-rw-r--r--  tensorflow/compiler/xla/tests/transfer_manager_test.cc | 204
-rw-r--r--  tensorflow/compiler/xla/tests/tuple_test.cc | 153
-rw-r--r--  tensorflow/compiler/xla/tests/unary_op_test.cc | 37
-rw-r--r--  tensorflow/compiler/xla/tests/while_test.cc | 66
-rw-r--r--  tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc | 40
-rw-r--r--  tensorflow/compiler/xla/tests/xla_internal_test_main.cc | 14
-rw-r--r--  tensorflow/compiler/xla/text_literal_reader.cc | 95
-rw-r--r--  tensorflow/compiler/xla/text_literal_reader.h | 7
-rw-r--r--  tensorflow/compiler/xla/text_literal_reader_test.cc | 17
-rw-r--r--  tensorflow/compiler/xla/text_literal_writer.cc | 22
-rw-r--r--  tensorflow/compiler/xla/text_literal_writer.h | 5
-rw-r--r--  tensorflow/compiler/xla/text_literal_writer_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/tools/BUILD | 10
-rw-r--r--  tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc | 8
-rw-r--r--  tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc | 24
-rw-r--r--  tensorflow/compiler/xla/tools/dumped_computation_to_text.cc | 8
-rw-r--r--  tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc | 8
-rw-r--r--  tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc | 9
-rw-r--r--  tensorflow/compiler/xla/tools/replay_computation.cc | 31
-rw-r--r--  tensorflow/compiler/xla/tools/show_literal.cc | 4
-rw-r--r--  tensorflow/compiler/xla/tools/show_signature.cc | 8
-rw-r--r--  tensorflow/compiler/xla/tools/show_text_literal.cc | 16
-rw-r--r--  tensorflow/compiler/xla/util.cc | 152
-rw-r--r--  tensorflow/compiler/xla/util.h | 319
-rw-r--r--  tensorflow/compiler/xla/util_test.cc | 43
-rw-r--r--  tensorflow/compiler/xla/window_util.cc | 17
-rw-r--r--  tensorflow/compiler/xla/window_util.h | 6
-rw-r--r--  tensorflow/compiler/xla/xla.proto | 9
-rw-r--r--  tensorflow/compiler/xla/xla_data.proto | 53
-rw-r--r--  tensorflow/compiler/xrt/BUILD | 84
-rw-r--r--  tensorflow/compiler/xrt/cc/BUILD | 20
-rw-r--r--  tensorflow/compiler/xrt/kernels/BUILD | 69
-rw-r--r--  tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc | 239
-rw-r--r--  tensorflow/compiler/xrt/kernels/xrt_execute_op.cc | 254
-rw-r--r--  tensorflow/compiler/xrt/kernels/xrt_state_ops.cc | 110
-rw-r--r--  tensorflow/compiler/xrt/kernels/xrt_state_ops.h | 424
-rw-r--r--  tensorflow/compiler/xrt/ops/xrt_compile_ops.cc | 48
-rw-r--r--  tensorflow/compiler/xrt/ops/xrt_execute_op.cc | 44
-rw-r--r--  tensorflow/compiler/xrt/ops/xrt_state_ops.cc | 122
-rw-r--r--  tensorflow/compiler/xrt/tests/BUILD | 65
-rw-r--r--  tensorflow/compiler/xrt/tests/raw_api_test.cc | 421
-rw-r--r--  tensorflow/compiler/xrt/xrt.proto | 78
-rw-r--r--  tensorflow/compiler/xrt/xrt_compilation_cache.cc | 263
-rw-r--r--  tensorflow/compiler/xrt/xrt_compilation_cache.h | 238
-rw-r--r--  tensorflow/compiler/xrt/xrt_device.cc | 46
-rw-r--r--  tensorflow/compiler/xrt/xrt_device.h | 66
-rw-r--r--  tensorflow/compiler/xrt/xrt_state.cc | 458
-rw-r--r--  tensorflow/compiler/xrt/xrt_state.h | 208
-rw-r--r--  tensorflow/contrib/BUILD | 10
-rw-r--r--  tensorflow/contrib/__init__.py | 12
-rw-r--r--  tensorflow/contrib/android/asset_manager_filesystem.cc | 4
-rw-r--r--  tensorflow/contrib/autograph/converters/builtin_functions.py | 41
-rw-r--r--  tensorflow/contrib/autograph/converters/builtin_functions_test.py | 15
-rw-r--r--  tensorflow/contrib/autograph/converters/call_trees_test.py | 6
-rw-r--r--  tensorflow/contrib/autograph/converters/control_flow.py | 23
-rw-r--r--  tensorflow/contrib/autograph/converters/control_flow_test.py | 50
-rw-r--r--  tensorflow/contrib/autograph/converters/lists_test.py | 6
-rw-r--r--  tensorflow/contrib/autograph/converters/logical_expressions.py | 4
-rw-r--r--  tensorflow/contrib/autograph/converters/logical_expressions_test.py | 13
-rw-r--r--  tensorflow/contrib/autograph/converters/side_effect_guards_test.py | 12
-rw-r--r--  tensorflow/contrib/autograph/converters/slices_test.py | 2
-rw-r--r--  tensorflow/contrib/autograph/docs/pyfunc_dtypes.md | 2
-rw-r--r--  tensorflow/contrib/autograph/examples/integration_tests/BUILD | 1
-rw-r--r--  tensorflow/contrib/autograph/examples/integration_tests/errors_test.py | 4
-rw-r--r--  tensorflow/contrib/autograph/examples/integration_tests/list_literals_test.py | 2
-rw-r--r--  tensorflow/contrib/autograph/impl/api.py | 4
-rw-r--r--  tensorflow/contrib/autograph/impl/api_test.py | 3
-rw-r--r--  tensorflow/contrib/autograph/operators/BUILD | 11
-rw-r--r--  tensorflow/contrib/autograph/operators/__init__.py | 5
-rw-r--r--  tensorflow/contrib/autograph/operators/control_flow.py | 6
-rw-r--r--  tensorflow/contrib/autograph/operators/control_flow_test.py | 9
-rw-r--r--  tensorflow/contrib/autograph/operators/data_structures_test.py | 12
-rw-r--r--  tensorflow/contrib/autograph/operators/py_builtins.py | 225
-rw-r--r--  tensorflow/contrib/autograph/operators/py_builtins_test.py | 131
-rw-r--r--  tensorflow/contrib/autograph/operators/slices.py | 9
-rw-r--r--  tensorflow/contrib/autograph/operators/slices_test.py | 19
-rw-r--r--  tensorflow/contrib/autograph/pyct/common_transformers/BUILD | 4
-rw-r--r--  tensorflow/contrib/autograph/pyct/common_transformers/__init__.py | 0
-rw-r--r--  tensorflow/contrib/autograph/pyct/common_transformers/anf.py | 10
-rw-r--r--  tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py | 40
-rw-r--r--  tensorflow/contrib/autograph/pyct/static_analysis/live_values.py | 7
-rw-r--r--  tensorflow/contrib/autograph/pyct/templates.py | 6
-rw-r--r--  tensorflow/contrib/autograph/pyct/templates_test.py | 36
-rw-r--r--  tensorflow/contrib/autograph/pyct/testing/BUILD | 2
-rw-r--r--  tensorflow/contrib/autograph/utils/BUILD | 23
-rw-r--r--  tensorflow/contrib/autograph/utils/__init__.py | 3
-rw-r--r--  tensorflow/contrib/autograph/utils/builtins.py | 143
-rw-r--r--tensorflow/contrib/autograph/utils/builtins_test.py145
-rw-r--r--tensorflow/contrib/autograph/utils/misc_test.py4
-rw-r--r--tensorflow/contrib/autograph/utils/py_func_test.py8
-rw-r--r--tensorflow/contrib/autograph/utils/tensor_list_test.py8
-rw-r--r--tensorflow/contrib/autograph/utils/tensors.py41
-rw-r--r--tensorflow/contrib/autograph/utils/tensors_test.py57
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py2
-rw-r--r--tensorflow/contrib/bigtable/README.md10
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc6
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_lib.h4
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc2
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc2
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc2
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc2
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc2
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc2
-rw-r--r--tensorflow/contrib/boosted_trees/BUILD1
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator.py43
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py61
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/model.py8
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc226
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/training_ops.cc210
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py20
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py121
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py138
-rw-r--r--tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h8
-rw-r--r--tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc51
-rw-r--r--tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/utils/random.h6
-rw-r--r--tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc3
-rw-r--r--tensorflow/contrib/boosted_trees/ops/training_ops.cc2
-rw-r--r--tensorflow/contrib/boosted_trees/proto/split_info.proto8
-rw-r--r--tensorflow/contrib/boosted_trees/proto/tree_config.proto27
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py16
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py229
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py56
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py31
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py24
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py908
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py39
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py157
-rw-r--r--tensorflow/contrib/checkpoint/__init__.py4
-rw-r--r--tensorflow/contrib/checkpoint/python/BUILD28
-rw-r--r--tensorflow/contrib/checkpoint/python/python_state.py166
-rw-r--r--tensorflow/contrib/checkpoint/python/python_state_test.py101
-rw-r--r--tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc2
-rw-r--r--tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h6
-rw-r--r--tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h6
-rw-r--r--tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py4
-rw-r--r--tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py4
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py4
-rw-r--r--tensorflow/contrib/cmake/CMakeLists.txt2
-rw-r--r--tensorflow/contrib/cmake/external/nsync.cmake8
-rw-r--r--tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt325
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt6
-rw-r--r--tensorflow/contrib/coder/BUILD1
-rw-r--r--tensorflow/contrib/compiler/BUILD34
-rw-r--r--tensorflow/contrib/compiler/jit_test.py17
-rw-r--r--tensorflow/contrib/compiler/xla.py208
-rw-r--r--tensorflow/contrib/compiler/xla_test.py180
-rw-r--r--tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py4
-rw-r--r--tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py10
-rw-r--r--tensorflow/contrib/crf/__init__.py2
-rw-r--r--tensorflow/contrib/crf/python/kernel_tests/crf_test.py20
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py2
-rw-r--r--tensorflow/contrib/data/BUILD22
-rw-r--r--tensorflow/contrib/data/__init__.py12
-rw-r--r--tensorflow/contrib/data/kernels/BUILD38
-rw-r--r--tensorflow/contrib/data/kernels/assert_next_dataset_op.cc2
-rw-r--r--tensorflow/contrib/data/kernels/csv_dataset_op.cc4
-rw-r--r--tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc4
-rw-r--r--tensorflow/contrib/data/kernels/identity_indexed_dataset.cc155
-rw-r--r--tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc4
-rw-r--r--tensorflow/contrib/data/kernels/indexed_dataset.cc373
-rw-r--r--tensorflow/contrib/data/kernels/indexed_dataset.h119
-rw-r--r--tensorflow/contrib/data/kernels/lmdb_dataset_op.cc217
-rw-r--r--tensorflow/contrib/data/kernels/prefetching_kernels.cc19
-rw-r--r--tensorflow/contrib/data/kernels/threadpool_dataset_op.cc2
-rw-r--r--tensorflow/contrib/data/kernels/unique_dataset_op.cc4
-rw-r--r--tensorflow/contrib/data/ops/dataset_ops.cc9
-rw-r--r--tensorflow/contrib/data/ops/indexed_dataset_ops.cc80
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD94
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py224
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/bucketing_test.py32
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py10
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py8
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py76
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py78
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py62
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py66
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py10
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py185
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/BUILD88
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py64
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py58
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py)179
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py219
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py108
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py850
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py28
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py89
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py58
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/resample_test.py6
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py6
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/BUILD14
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py52
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py45
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py50
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py6
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py70
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py64
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py65
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py10
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/test_utils.py73
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py15
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py166
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py6
-rw-r--r--tensorflow/contrib/data/python/ops/BUILD41
-rw-r--r--tensorflow/contrib/data/python/ops/batching.py27
-rw-r--r--tensorflow/contrib/data/python/ops/indexed_dataset_ops.py173
-rw-r--r--tensorflow/contrib/data/python/ops/interleave_ops.py74
-rw-r--r--tensorflow/contrib/data/python/ops/map_defun.py2
-rw-r--r--tensorflow/contrib/data/python/ops/parsing_ops.py150
-rw-r--r--tensorflow/contrib/data/python/ops/readers.py90
-rw-r--r--tensorflow/contrib/data/python/ops/stats_ops.py23
-rw-r--r--tensorflow/contrib/distribute/BUILD3
-rw-r--r--tensorflow/contrib/distribute/README.md305
-rw-r--r--tensorflow/contrib/distribute/__init__.py6
-rw-r--r--tensorflow/contrib/distribute/python/BUILD158
-rw-r--r--tensorflow/contrib/distribute/python/checkpoint_utils_test.py4
-rw-r--r--tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py227
-rw-r--r--tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py116
-rw-r--r--tensorflow/contrib/distribute/python/combinations.py32
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops.py131
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops_test.py107
-rw-r--r--tensorflow/contrib/distribute/python/estimator_training_test.py659
-rw-r--r--tensorflow/contrib/distribute/python/examples/BUILD15
-rw-r--r--tensorflow/contrib/distribute/python/examples/keras_mnist.py125
-rw-r--r--tensorflow/contrib/distribute/python/examples/keras_model_with_estimator.py (renamed from tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py)10
-rw-r--r--tensorflow/contrib/distribute/python/input_ops.py13
-rw-r--r--tensorflow/contrib/distribute/python/input_ops_test.py8
-rw-r--r--tensorflow/contrib/distribute/python/keras_test.py210
-rw-r--r--tensorflow/contrib/distribute/python/metrics_v1_test.py2
-rw-r--r--tensorflow/contrib/distribute/python/minimize_loss_test.py12
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py222
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py153
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_test.py19
-rw-r--r--tensorflow/contrib/distribute/python/monitor_test.py2
-rw-r--r--tensorflow/contrib/distribute/python/multi_worker_strategy.py141
-rw-r--r--tensorflow/contrib/distribute/python/multi_worker_strategy_test.py62
-rw-r--r--tensorflow/contrib/distribute/python/multi_worker_test_base.py107
-rw-r--r--tensorflow/contrib/distribute/python/one_device_strategy.py13
-rw-r--r--tensorflow/contrib/distribute/python/optimizer_v2_test.py2
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy.py340
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy_test.py256
-rw-r--r--tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py6
-rw-r--r--tensorflow/contrib/distribute/python/step_fn_test.py2
-rw-r--r--tensorflow/contrib/distribute/python/strategy_test_lib.py8
-rw-r--r--tensorflow/contrib/distribute/python/tpu_strategy.py141
-rw-r--r--tensorflow/contrib/distribute/python/values.py217
-rw-r--r--tensorflow/contrib/distribute/python/values_test.py25
-rw-r--r--tensorflow/contrib/distribute/python/warm_starting_util_test.py4
-rw-r--r--tensorflow/contrib/distributions/BUILD2
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/autoregressive_test.py4
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py22
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py14
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py34
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py10
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py10
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py12
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py4
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py8
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py96
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py60
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py24
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py22
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py8
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py26
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py19
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py4
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py30
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py44
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py12
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py2
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py58
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py18
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py43
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py30
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py32
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/independent_test.py8
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py30
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py36
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/logistic_test.py18
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py18
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py40
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/moving_stats_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_plus_low_rank_test.py14
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py30
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py16
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py30
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py34
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/onehot_categorical_test.py24
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py12
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py32
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py32
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/relaxed_bernoulli_test.py24
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/relaxed_onehot_categorical_test.py18
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py40
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/shape_test.py16
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/sinh_arcsinh_test.py14
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py30
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py20
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/vector_exponential_diag_test.py22
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py22
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/vector_sinh_arcsinh_diag_test.py16
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py14
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py28
-rw-r--r--tensorflow/contrib/eager/python/BUILD15
-rw-r--r--tensorflow/contrib/eager/python/evaluator_test.py4
-rw-r--r--tensorflow/contrib/eager/python/examples/densenet/densenet_test.py11
-rw-r--r--tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb6
-rw-r--r--tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb2
-rw-r--r--tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb6
-rw-r--r--tensorflow/contrib/eager/python/examples/notebooks/README.md14
-rw-r--r--tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb298
-rw-r--r--tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb389
-rw-r--r--tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb467
-rw-r--r--tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb485
-rw-r--r--tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb7
-rw-r--r--tensorflow/contrib/eager/python/examples/resnet50/resnet50.py8
-rw-r--r--tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py11
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/revnet_test.py8
-rw-r--r--tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py4
-rw-r--r--tensorflow/contrib/eager/python/examples/scan/BUILD25
-rw-r--r--tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py54
-rw-r--r--tensorflow/contrib/eager/python/metrics_impl.py22
-rw-r--r--tensorflow/contrib/eager/python/metrics_test.py71
-rw-r--r--tensorflow/contrib/eager/python/remote.py73
-rw-r--r--tensorflow/contrib/eager/python/remote_test.py13
-rw-r--r--tensorflow/contrib/eager/python/saver_test.py51
-rw-r--r--tensorflow/contrib/eager/python/tfe.py7
-rw-r--r--tensorflow/contrib/estimator/BUILD1
-rw-r--r--tensorflow/contrib/estimator/__init__.py2
-rw-r--r--tensorflow/contrib/estimator/python/estimator/baseline_test.py6
-rw-r--r--tensorflow/contrib/estimator/python/estimator/export.py4
-rw-r--r--tensorflow/contrib/estimator/python/estimator/extenders.py29
-rw-r--r--tensorflow/contrib/estimator/python/estimator/extenders_test.py129
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head_test.py58
-rw-r--r--tensorflow/contrib/estimator/python/estimator/hooks.py26
-rw-r--r--tensorflow/contrib/estimator/python/estimator/hooks_test.py32
-rw-r--r--tensorflow/contrib/estimator/python/estimator/multi_head_test.py22
-rw-r--r--tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py92
-rw-r--r--tensorflow/contrib/estimator/python/estimator/rnn.py14
-rw-r--r--tensorflow/contrib/estimator/python/estimator/rnn_test.py41
-rw-r--r--tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py16
-rw-r--r--tensorflow/contrib/factorization/python/kernel_tests/masked_matmul_ops_test.py4
-rw-r--r--tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py4
-rw-r--r--tensorflow/contrib/factorization/python/ops/factorization_ops_test.py16
-rw-r--r--tensorflow/contrib/factorization/python/ops/gmm_ops_test.py6
-rw-r--r--tensorflow/contrib/factorization/python/ops/kmeans_test.py2
-rw-r--r--tensorflow/contrib/factorization/python/ops/wals.py70
-rw-r--r--tensorflow/contrib/factorization/python/ops/wals_test.py18
-rw-r--r--tensorflow/contrib/ffmpeg/__init__.py2
-rw-r--r--tensorflow/contrib/ffmpeg/decode_audio_op_test.py12
-rw-r--r--tensorflow/contrib/ffmpeg/decode_video_op_test.py2
-rw-r--r--tensorflow/contrib/ffmpeg/encode_audio_op_test.py10
-rw-r--r--tensorflow/contrib/ffmpeg/ffmpeg_ops.py4
-rw-r--r--tensorflow/contrib/framework/__init__.py4
-rw-r--r--tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py32
-rw-r--r--tensorflow/contrib/framework/python/framework/tensor_util.py33
-rw-r--r--tensorflow/contrib/framework/python/framework/tensor_util_test.py34
-rw-r--r--tensorflow/contrib/framework/python/ops/arg_scope_test.py18
-rw-r--r--tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py8
-rw-r--r--tensorflow/contrib/framework/python/ops/critical_section_ops.py5
-rw-r--r--tensorflow/contrib/framework/python/ops/prettyprint_ops_test.py8
-rw-r--r--tensorflow/contrib/framework/python/ops/script_ops.py2
-rw-r--r--tensorflow/contrib/framework/python/ops/sort_ops_test.py12
-rw-r--r--tensorflow/contrib/framework/python/ops/variables_test.py140
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc2
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h6
-rw-r--r--tensorflow/contrib/gan/BUILD52
-rw-r--r--tensorflow/contrib/gan/python/estimator/__init__.py5
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py10
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py2
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/stargan_estimator.py (renamed from tensorflow/contrib/kfac/python/ops/optimizer_lib.py)16
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py363
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py306
-rw-r--r--tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py10
-rw-r--r--tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py6
-rw-r--r--tensorflow/contrib/gan/python/losses/python/losses_impl_test.py52
-rw-r--r--tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py8
-rw-r--r--tensorflow/contrib/gdr/gdr_memory_manager.cc12
-rw-r--r--tensorflow/contrib/gdr/gdr_memory_manager.h6
-rw-r--r--tensorflow/contrib/gdr/gdr_rendezvous_mgr.h6
-rw-r--r--tensorflow/contrib/gdr/gdr_server_lib.h6
-rw-r--r--tensorflow/contrib/gdr/gdr_worker.h6
-rw-r--r--tensorflow/contrib/graph_editor/__init__.py4
-rw-r--r--tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc4
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.cc60
-rw-r--r--tensorflow/contrib/image/ops/image_ops.cc44
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py12
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py2
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/image_ops_test.py32
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py84
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/segmentation_test.py18
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py6
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py10
-rw-r--r--tensorflow/contrib/image/python/ops/image_ops.py7
-rw-r--r--tensorflow/contrib/image/python/ops/interpolate_spline.py35
-rw-r--r--tensorflow/contrib/integrate/__init__.py4
-rw-r--r--tensorflow/contrib/integrate/python/ops/odes_test.py34
-rw-r--r--tensorflow/contrib/kfac/BUILD26
-rw-r--r--tensorflow/contrib/kfac/README.md93
-rw-r--r--tensorflow/contrib/kfac/__init__.py46
-rw-r--r--tensorflow/contrib/kfac/examples/BUILD80
-rw-r--r--tensorflow/contrib/kfac/examples/convnet.py667
-rw-r--r--tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py62
-rw-r--r--tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py48
-rw-r--r--tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py39
-rw-r--r--tensorflow/contrib/kfac/examples/mlp.py354
-rw-r--r--tensorflow/contrib/kfac/examples/mlp_mnist_main.py64
-rw-r--r--tensorflow/contrib/kfac/examples/mnist.py69
-rw-r--r--tensorflow/contrib/kfac/examples/tests/BUILD52
-rw-r--r--tensorflow/contrib/kfac/examples/tests/convnet_test.py166
-rw-r--r--tensorflow/contrib/kfac/examples/tests/mlp_test.py63
-rw-r--r--tensorflow/contrib/kfac/examples/tests/mnist_test.py72
-rw-r--r--tensorflow/contrib/kfac/g3doc/autoencoder.png bin 54204 -> 0 bytes
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/BUILD160
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py310
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py1018
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py955
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py597
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py190
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py50
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py219
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/utils_test.py410
-rw-r--r--tensorflow/contrib/kfac/python/ops/BUILD263
-rw-r--r--tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py183
-rw-r--r--tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py30
-rw-r--r--tensorflow/contrib/kfac/python/ops/estimator.py516
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_blocks.py1752
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py45
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_factors.py1830
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py38
-rw-r--r--tensorflow/contrib/kfac/python/ops/layer_collection.py1269
-rw-r--r--tensorflow/contrib/kfac/python/ops/layer_collection_lib.py46
-rw-r--r--tensorflow/contrib/kfac/python/ops/linear_operator.py95
-rw-r--r--tensorflow/contrib/kfac/python/ops/loss_functions.py754
-rw-r--r--tensorflow/contrib/kfac/python/ops/loss_functions_lib.py39
-rw-r--r--tensorflow/contrib/kfac/python/ops/op_queue.py69
-rw-r--r--tensorflow/contrib/kfac/python/ops/optimizer.py727
-rw-r--r--tensorflow/contrib/kfac/python/ops/placement.py114
-rw-r--r--tensorflow/contrib/kfac/python/ops/utils.py709
-rw-r--r--tensorflow/contrib/kfac/python/ops/utils_lib.py50
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/ops_test.py2
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/test_util.py2
-rw-r--r--tensorflow/contrib/layers/BUILD2
-rw-r--r--tensorflow/contrib/layers/__init__.py4
-rw-r--r--tensorflow/contrib/layers/python/layers/embedding_ops_test.py54
-rw-r--r--tensorflow/contrib/layers/python/layers/encoders_test.py20
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column.py17
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column_ops_test.py206
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column_test.py75
-rw-r--r--tensorflow/contrib/layers/python/layers/initializers_test.py2
-rw-r--r--tensorflow/contrib/layers/python/layers/layers.py4
-rw-r--r--tensorflow/contrib/layers/python/layers/layers_test.py340
-rw-r--r--tensorflow/contrib/layers/python/layers/normalization.py25
-rw-r--r--tensorflow/contrib/layers/python/layers/normalization_test.py108
-rw-r--r--tensorflow/contrib/layers/python/layers/optimizers_test.py42
-rw-r--r--tensorflow/contrib/layers/python/layers/regularizers_test.py14
-rw-r--r--tensorflow/contrib/layers/python/layers/rev_block_lib.py145
-rw-r--r--tensorflow/contrib/layers/python/layers/rev_block_lib_test.py20
-rw-r--r--tensorflow/contrib/layers/python/layers/summaries_test.py12
-rw-r--r--tensorflow/contrib/layers/python/layers/utils_test.py26
-rw-r--r--tensorflow/contrib/layers/python/ops/sparse_ops_test.py46
-rw-r--r--tensorflow/contrib/learn/BUILD1
-rw-r--r--tensorflow/contrib/learn/__init__.py3
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/rnn_common_test.py2
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/stability_test.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/graph_actions_test.py42
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py6
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py26
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py28
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py18
-rw-r--r--tensorflow/contrib/learn/python/learn/monitors_test.py44
-rw-r--r--tensorflow/contrib/learn/python/learn/ops/ops_test.py6
-rw-r--r--tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py6
-rw-r--r--tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py2
-rw-r--r--tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py4
-rw-r--r--tensorflow/contrib/linalg/__init__.py3
-rw-r--r--tensorflow/contrib/linear_optimizer/kernels/g3doc/readme.md40
-rw-r--r--tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py51
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py14
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py6
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py4
-rw-r--r--tensorflow/contrib/lite/BUILD54
-rw-r--r--tensorflow/contrib/lite/RELEASE.md8
-rw-r--r--tensorflow/contrib/lite/allocation.cc4
-rw-r--r--tensorflow/contrib/lite/allocation.h4
-rw-r--r--tensorflow/contrib/lite/arena_planner.h6
-rw-r--r--tensorflow/contrib/lite/build_def.bzl24
-rw-r--r--tensorflow/contrib/lite/builtin_op_data.h277
-rw-r--r--tensorflow/contrib/lite/builtin_ops.h4
-rw-r--r--tensorflow/contrib/lite/c/BUILD39
-rw-r--r--tensorflow/contrib/lite/c/builtin_op_data.h298
-rw-r--r--tensorflow/contrib/lite/c/builtin_op_data_test.cc83
-rw-r--r--tensorflow/contrib/lite/c/c_api_internal.c (renamed from tensorflow/contrib/lite/context.c)6
-rw-r--r--tensorflow/contrib/lite/c/c_api_internal.h491
-rw-r--r--tensorflow/contrib/lite/c/c_api_internal_test.cc (renamed from tensorflow/contrib/lite/context_test.cc)10
-rw-r--r--tensorflow/contrib/lite/context.h475
-rw-r--r--tensorflow/contrib/lite/context_util.h2
-rw-r--r--tensorflow/contrib/lite/core/api/BUILD57
-rw-r--r--tensorflow/contrib/lite/core/api/error_reporter.cc (renamed from tensorflow/compiler/xla/ptr_util.h)45
-rw-r--r--tensorflow/contrib/lite/core/api/error_reporter.h45
-rw-r--r--tensorflow/contrib/lite/core/api/error_reporter_test.cc49
-rw-r--r--tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc622
-rw-r--r--tensorflow/contrib/lite/core/api/flatbuffer_conversions.h48
-rw-r--r--tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc104
-rw-r--r--tensorflow/contrib/lite/core/api/op_resolver.cc60
-rw-r--r--tensorflow/contrib/lite/core/api/op_resolver.h47
-rw-r--r--tensorflow/contrib/lite/core/api/op_resolver_test.cc197
-rw-r--r--tensorflow/contrib/lite/delegates/eager/BUILD24
-rw-r--r--tensorflow/contrib/lite/delegates/eager/buffer_map.h2
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate.h2
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc2
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate_test.cc20
-rw-r--r--tensorflow/contrib/lite/delegates/eager/kernel.cc15
-rw-r--r--tensorflow/contrib/lite/delegates/eager/kernel.h2
-rw-r--r--tensorflow/contrib/lite/delegates/eager/test_util.cc45
-rw-r--r--tensorflow/contrib/lite/delegates/eager/test_util.h28
-rw-r--r--tensorflow/contrib/lite/delegates/eager/util.cc42
-rw-r--r--tensorflow/contrib/lite/delegates/eager/util.h19
-rw-r--r--tensorflow/contrib/lite/delegates/eager/util_test.cc44
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/BUILD2
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc102
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h2
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc83
-rw-r--r--tensorflow/contrib/lite/error_reporter.h38
-rw-r--r--tensorflow/contrib/lite/examples/android/app/build.gradle6
-rw-r--r--tensorflow/contrib/lite/examples/android/build.gradle1
-rw-r--r--tensorflow/contrib/lite/examples/ios/camera/Podfile2
-rw-r--r--tensorflow/contrib/lite/examples/ios/simple/Podfile2
-rw-r--r--tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h6
-rw-r--r--tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h2
-rw-r--r--tensorflow/contrib/lite/examples/label_image/get_top_n.h6
-rw-r--r--tensorflow/contrib/lite/examples/label_image/get_top_n_impl.h6
-rw-r--r--tensorflow/contrib/lite/examples/label_image/label_image.h6
-rw-r--r--tensorflow/contrib/lite/experimental/c/BUILD1
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api.cc12
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api.h13
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_internal.h6
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_test.cc4
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs31
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/BUILD3
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h18
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc4
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc15
-rw-r--r--tensorflow/contrib/lite/experimental/writer/BUILD66
-rw-r--r--tensorflow/contrib/lite/experimental/writer/enum_mapping.h116
-rw-r--r--tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc370
-rw-r--r--tensorflow/contrib/lite/experimental/writer/writer.cc41
-rw-r--r--tensorflow/contrib/lite/experimental/writer/writer_lib.cc281
-rw-r--r--tensorflow/contrib/lite/experimental/writer/writer_lib.h126
-rw-r--r--tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc62
-rw-r--r--tensorflow/contrib/lite/g3doc/README.md4
-rw-r--r--tensorflow/contrib/lite/g3doc/_book.yaml1
-rw-r--r--tensorflow/contrib/lite/g3doc/api_docs/python/index.md10
-rw-r--r--tensorflow/contrib/lite/g3doc/apis.md45
-rw-r--r--tensorflow/contrib/lite/g3doc/custom_operators.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/demo_android.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/demo_ios.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/devguide.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/ios.md8
-rw-r--r--tensorflow/contrib/lite/g3doc/models.md100
-rw-r--r--tensorflow/contrib/lite/g3doc/ops_versioning.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/overview.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/performance.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/rpi.md36
-rw-r--r--tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md27
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/android_build.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/index.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md2
-rw-r--r--tensorflow/contrib/lite/graph_info.h2
-rw-r--r--tensorflow/contrib/lite/interpreter.cc8
-rw-r--r--tensorflow/contrib/lite/interpreter.h16
-rw-r--r--tensorflow/contrib/lite/interpreter_test.cc2
-rw-r--r--tensorflow/contrib/lite/java/demo/README.md6
-rw-r--r--tensorflow/contrib/lite/java/demo/app/build.gradle6
-rw-r--r--tensorflow/contrib/lite/java/ovic/BUILD3
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/build.gradle6
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java44
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java70
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java29
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java51
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/exception_jni.h6
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h14
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/tensor_jni.h8
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h6
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java15
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java19
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java22
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD126
-rw-r--r--tensorflow/contrib/lite/kernels/activation_functor.h2
-rw-r--r--tensorflow/contrib/lite/kernels/activations.cc48
-rw-r--r--tensorflow/contrib/lite/kernels/activations_test.cc70
-rw-r--r--tensorflow/contrib/lite/kernels/add.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/arg_min_max.cc15
-rw-r--r--tensorflow/contrib/lite/kernels/audio_spectrogram.cc6
-rw-r--r--tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/basic_rnn.cc33
-rw-r--r--tensorflow/contrib/lite/kernels/basic_rnn_test.cc21
-rw-r--r--tensorflow/contrib/lite/kernels/batch_to_space_nd.cc16
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc996
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc430
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc467
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc40
-rw-r--r--tensorflow/contrib/lite/kernels/cast.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/comparisons.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/concatenation.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/conv.cc162
-rw-r--r--tensorflow/contrib/lite/kernels/conv_test.cc284
-rw-r--r--tensorflow/contrib/lite/kernels/depthwise_conv.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/dequantize.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/detection_postprocess.cc6
-rw-r--r--tensorflow/contrib/lite/kernels/detection_postprocess_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/div.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/eigen_support.h4
-rw-r--r--tensorflow/contrib/lite/kernels/elementwise.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/embedding_lookup.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/exp.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/expand_dims.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/expand_dims_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/fake_quant.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/floor.cc7
-rw-r--r--tensorflow/contrib/lite/kernels/floor_div.cc146
-rw-r--r--tensorflow/contrib/lite/kernels/floor_div_test.cc90
-rw-r--r--tensorflow/contrib/lite/kernels/fully_connected.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/gather.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/gather_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/gemm_support.h2
-rw-r--r--tensorflow/contrib/lite/kernels/hashtable_lookup.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/internal/BUILD50
-rw-r--r--tensorflow/contrib/lite/kernels/internal/common.h2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.cc683
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.h91
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h6
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h343
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h8
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc31
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h21
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h907
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h18
-rw-r--r--tensorflow/contrib/lite/kernels/internal/quantization_util.cc210
-rw-r--r--tensorflow/contrib/lite/kernels/internal/quantization_util.h38
-rw-r--r--tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc133
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h450
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc48
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h36
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h2008
-rw-r--r--tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc30
-rw-r--r--tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h92
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor.h111
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h135
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils.h17
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc206
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h57
-rw-r--r--tensorflow/contrib/lite/kernels/kernel_util.h10
-rw-r--r--tensorflow/contrib/lite/kernels/l2norm.cc24
-rw-r--r--tensorflow/contrib/lite/kernels/layer_norm_lstm.cc1316
-rw-r--r--tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc664
-rw-r--r--tensorflow/contrib/lite/kernels/local_response_norm.cc18
-rw-r--r--tensorflow/contrib/lite/kernels/logical.cc16
-rw-r--r--tensorflow/contrib/lite/kernels/lsh_projection.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/lstm.cc80
-rw-r--r--tensorflow/contrib/lite/kernels/lstm_test.cc49
-rw-r--r--tensorflow/contrib/lite/kernels/maximum_minimum.cc13
-rw-r--r--tensorflow/contrib/lite/kernels/mfcc.cc6
-rw-r--r--tensorflow/contrib/lite/kernels/mfcc_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/mul.cc79
-rw-r--r--tensorflow/contrib/lite/kernels/neg.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/one_hot.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/op_macros.h6
-rw-r--r--tensorflow/contrib/lite/kernels/optional_tensor_test.cc24
-rw-r--r--tensorflow/contrib/lite/kernels/pack.cc51
-rw-r--r--tensorflow/contrib/lite/kernels/pack_test.cc40
-rw-r--r--tensorflow/contrib/lite/kernels/pad.cc30
-rw-r--r--tensorflow/contrib/lite/kernels/pad_test.cc13
-rw-r--r--tensorflow/contrib/lite/kernels/padding.h2
-rw-r--r--tensorflow/contrib/lite/kernels/pooling.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/pow.cc16
-rw-r--r--tensorflow/contrib/lite/kernels/reduce.cc300
-rw-r--r--tensorflow/contrib/lite/kernels/reduce_test.cc288
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc17
-rw-r--r--tensorflow/contrib/lite/kernels/register.h3
-rw-r--r--tensorflow/contrib/lite/kernels/relu1.cc59
-rw-r--r--tensorflow/contrib/lite/kernels/relu1_test.cc79
-rw-r--r--tensorflow/contrib/lite/kernels/reshape.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/resize_bilinear.cc16
-rw-r--r--tensorflow/contrib/lite/kernels/select.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/shape.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/skip_gram.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/slice.cc30
-rw-r--r--tensorflow/contrib/lite/kernels/space_to_batch_nd.cc18
-rw-r--r--tensorflow/contrib/lite/kernels/space_to_depth.cc14
-rw-r--r--tensorflow/contrib/lite/kernels/sparse_to_dense.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/split.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/squeeze.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/strided_slice.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/sub.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/svdf.cc84
-rw-r--r--tensorflow/contrib/lite/kernels/svdf_test.cc29
-rw-r--r--tensorflow/contrib/lite/kernels/tile.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/tile_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/topk_v2.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/topk_v2_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/transpose.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/transpose_conv.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc112
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc73
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc35
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc20
-rw-r--r--tensorflow/contrib/lite/kernels/unpack.cc130
-rw-r--r--tensorflow/contrib/lite/kernels/unpack_test.cc225
-rwxr-xr-xtensorflow/contrib/lite/lib_package/create_ios_frameworks.sh7
-rw-r--r--tensorflow/contrib/lite/memory_planner.h2
-rw-r--r--tensorflow/contrib/lite/mmap_allocation.cc2
-rw-r--r--tensorflow/contrib/lite/model.cc628
-rw-r--r--tensorflow/contrib/lite/model.h5
-rw-r--r--tensorflow/contrib/lite/model_test.cc2
-rw-r--r--tensorflow/contrib/lite/models/speech_test.cc14
-rw-r--r--tensorflow/contrib/lite/mutable_op_resolver.cc (renamed from tensorflow/contrib/lite/op_resolver.cc)3
-rw-r--r--tensorflow/contrib/lite/mutable_op_resolver.h79
-rw-r--r--tensorflow/contrib/lite/mutable_op_resolver_test.cc (renamed from tensorflow/contrib/lite/op_resolver_test.cc)2
-rw-r--r--tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h6
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc80
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.h4
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate_disabled.cc8
-rw-r--r--tensorflow/contrib/lite/op_resolver.h78
-rw-r--r--tensorflow/contrib/lite/optional_debug_tools.h6
-rw-r--r--tensorflow/contrib/lite/python/BUILD6
-rw-r--r--tensorflow/contrib/lite/python/convert.py103
-rw-r--r--tensorflow/contrib/lite/python/convert_test.py186
-rw-r--r--tensorflow/contrib/lite/python/lite.py209
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py321
-rw-r--r--tensorflow/contrib/lite/python/op_hint.py898
-rw-r--r--tensorflow/contrib/lite/python/tflite_convert.py58
-rw-r--r--tensorflow/contrib/lite/schema/BUILD16
-rw-r--r--tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc2
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs18
-rwxr-xr-xtensorflow/contrib/lite/schema/schema_generated.h274
-rw-r--r--tensorflow/contrib/lite/simple_memory_arena.h2
-rw-r--r--tensorflow/contrib/lite/stderr_reporter.cc (renamed from tensorflow/contrib/lite/error_reporter.cc)22
-rw-r--r--tensorflow/contrib/lite/stderr_reporter.h34
-rw-r--r--tensorflow/contrib/lite/string.h6
-rw-r--r--tensorflow/contrib/lite/string_util.cc2
-rw-r--r--tensorflow/contrib/lite/string_util.h2
-rw-r--r--tensorflow/contrib/lite/string_util_test.cc2
-rw-r--r--tensorflow/contrib/lite/testing/BUILD3
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py98
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc15
-rw-r--r--tensorflow/contrib/lite/testing/parse_testdata.h6
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.cc22
-rw-r--r--tensorflow/contrib/lite/testing/tokenize.h6
-rw-r--r--tensorflow/contrib/lite/testing/util.h2
-rw-r--r--tensorflow/contrib/lite/toco/BUILD2
-rw-r--r--tensorflow/contrib/lite/toco/allocate_transient_arrays.cc16
-rw-r--r--tensorflow/contrib/lite/toco/args.h7
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc47
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md22
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md24
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/python_api.md1
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc29
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc9
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc151
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize.cc3
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc106
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc1
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc19
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc10
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD13
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc167
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc1
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc51
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.h5
-rw-r--r--tensorflow/contrib/lite/toco/model.h20
-rw-r--r--tensorflow/contrib/lite/toco/python/toco_from_protos_test.py2
-rw-r--r--tensorflow/contrib/lite/toco/python/toco_python_api.h6
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h6
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h6
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h6
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h6
-rw-r--r--tensorflow/contrib/lite/toco/tflite/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.cc80
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.h54
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export_test.cc63
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc107
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.h8
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc38
-rw-r--r--tensorflow/contrib/lite/toco/toco_cmdline_flags.cc43
-rw-r--r--tensorflow/contrib/lite/toco/toco_flags.proto21
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc28
-rw-r--r--tensorflow/contrib/lite/toco/toco_types.h6
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc11
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.h8
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/BUILD328
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/README.md38
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h49
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/android_required_build_flags.cc27
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/csv_writer.h79
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/eval_pipeline.cc39
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h87
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.cc100
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h99
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder_test.cc229
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/eval_pipeline_test.cc133
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/file_reader_stage.cc (renamed from tensorflow/contrib/lite/delegates/eager/constants.h)24
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h37
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/file_reader_stage_test.cc110
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD182
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md146
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/clsloc_validation_blacklist.txt1762
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/generate_validation_labels.py105
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc165
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc351
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h124
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc114
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h83
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval_test.cc151
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.cc80
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h75
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing_test.cc123
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/testdata/grace_hopper.jpgbin0 -> 73746 bytes
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op.cc158
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op_test.cc200
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.cc45
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h53
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/stage.h56
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/utils.cc102
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/utils.h46
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/utils_test.cc76
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/BUILD41
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/README.md22
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_model.h4
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc13
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h12
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/command_line_flags.h4
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/ios/README.md4
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/logging.h4
-rw-r--r--tensorflow/contrib/lite/tools/make/Makefile108
-rw-r--r--tensorflow/contrib/lite/tools/optimize/BUILD25
-rw-r--r--tensorflow/contrib/lite/tools/optimize/g3doc/quantize_weights.md70
-rw-r--r--tensorflow/contrib/lite/tools/optimize/quantize_weights.cc432
-rw-r--r--tensorflow/contrib/lite/tools/optimize/quantize_weights.h57
-rw-r--r--tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc226
-rw-r--r--tensorflow/contrib/lite/tutorials/BUILD20
-rw-r--r--tensorflow/contrib/lite/tutorials/dataset.py122
-rw-r--r--tensorflow/contrib/lite/tutorials/mnist_tflite.py87
-rw-r--r--tensorflow/contrib/lite/util.cc7
-rw-r--r--tensorflow/contrib/lite/util.h12
-rw-r--r--tensorflow/contrib/lite/util_test.cc12
-rw-r--r--tensorflow/contrib/lookup/lookup_ops.py50
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py330
-rw-r--r--tensorflow/contrib/losses/__init__.py2
-rw-r--r--tensorflow/contrib/losses/python/losses/__init__.py2
-rw-r--r--tensorflow/contrib/losses/python/losses/loss_ops_test.py214
-rw-r--r--tensorflow/contrib/losses/python/metric_learning/__init__.py4
-rw-r--r--tensorflow/contrib/makefile/Makefile29
-rwxr-xr-xtensorflow/contrib/makefile/compile_nsync.sh1
-rw-r--r--tensorflow/contrib/makefile/proto_text_cc_files.txt113
-rw-r--r--tensorflow/contrib/makefile/proto_text_pb_cc_files.txt74
-rw-r--r--tensorflow/contrib/makefile/proto_text_pb_h_files.txt73
-rw-r--r--tensorflow/contrib/makefile/tf_op_files.txt524
-rw-r--r--tensorflow/contrib/makefile/tf_pb_text_files.txt56
-rw-r--r--tensorflow/contrib/makefile/tf_proto_files.txt76
-rw-r--r--tensorflow/contrib/metrics/__init__.py4
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py51
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py2
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py520
-rw-r--r--tensorflow/contrib/model_pruning/BUILD3
-rw-r--r--tensorflow/contrib/model_pruning/README.md2
-rw-r--r--tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py4
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning.py4
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning_test.py16
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning_utils.py79
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning_utils_test.py68
-rw-r--r--tensorflow/contrib/model_pruning/python/strip_pruning_vars_test.py2
-rw-r--r--tensorflow/contrib/nccl/kernels/nccl_manager.h6
-rw-r--r--tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py2
-rw-r--r--tensorflow/contrib/nn/python/ops/alpha_dropout_test.py2
-rw-r--r--tensorflow/contrib/nn/python/ops/fwd_gradients_test.py4
-rw-r--r--tensorflow/contrib/nn/python/ops/sampling_ops_test.py4
-rw-r--r--tensorflow/contrib/opt/BUILD18
-rw-r--r--tensorflow/contrib/opt/__init__.py2
-rw-r--r--tensorflow/contrib/opt/python/training/adamax_test.py12
-rw-r--r--tensorflow/contrib/opt/python/training/elastic_average_optimizer.py168
-rw-r--r--tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py113
-rw-r--r--tensorflow/contrib/opt/python/training/external_optimizer_test.py18
-rw-r--r--tensorflow/contrib/opt/python/training/ggt_test.py4
-rw-r--r--tensorflow/contrib/opt/python/training/lars_optimizer_test.py4
-rw-r--r--tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py34
-rw-r--r--tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py247
-rw-r--r--tensorflow/contrib/opt/python/training/matrix_functions.py155
-rw-r--r--tensorflow/contrib/opt/python/training/matrix_functions_test.py63
-rw-r--r--tensorflow/contrib/opt/python/training/model_average_optimizer.py8
-rw-r--r--tensorflow/contrib/opt/python/training/model_average_optimizer_test.py40
-rw-r--r--tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py6
-rw-r--r--tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py4
-rw-r--r--tensorflow/contrib/opt/python/training/nadam_optimizer_test.py4
-rw-r--r--tensorflow/contrib/opt/python/training/powersign.py2
-rw-r--r--tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py22
-rw-r--r--tensorflow/contrib/opt/python/training/shampoo.py98
-rw-r--r--tensorflow/contrib/opt/python/training/shampoo_test.py214
-rw-r--r--tensorflow/contrib/opt/python/training/sign_decay_test.py6
-rw-r--r--tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py4
-rw-r--r--tensorflow/contrib/opt/python/training/weight_decay_optimizers.py77
-rw-r--r--tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py2
-rw-r--r--tensorflow/contrib/optimizer_v2/adadelta_test.py4
-rw-r--r--tensorflow/contrib/optimizer_v2/adagrad_test.py18
-rw-r--r--tensorflow/contrib/optimizer_v2/adam.py9
-rw-r--r--tensorflow/contrib/optimizer_v2/adam_test.py12
-rw-r--r--tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py6
-rw-r--r--tensorflow/contrib/optimizer_v2/gradient_descent_test.py16
-rw-r--r--tensorflow/contrib/optimizer_v2/momentum_test.py14
-rw-r--r--tensorflow/contrib/optimizer_v2/optimizer_v2.py6
-rw-r--r--tensorflow/contrib/optimizer_v2/optimizer_v2_test.py10
-rw-r--r--tensorflow/contrib/optimizer_v2/rmsprop_test.py4
-rw-r--r--tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h6
-rw-r--r--tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py14
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py4
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test_base.py2
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py10
-rw-r--r--tensorflow/contrib/quantize/BUILD5
-rw-r--r--tensorflow/contrib/quantize/python/common.py26
-rw-r--r--tensorflow/contrib/quantize/python/common_test.py25
-rw-r--r--tensorflow/contrib/quantize/python/fold_batch_norms.py25
-rw-r--r--tensorflow/contrib/quantize/python/quantize.py5
-rw-r--r--tensorflow/contrib/quantize/python/quantize_graph.py53
-rw-r--r--tensorflow/contrib/quantize/python/quantize_graph_test.py15
-rw-r--r--tensorflow/contrib/rate/BUILD48
-rw-r--r--tensorflow/contrib/rate/rate.py151
-rw-r--r--tensorflow/contrib/rate/rate_test.py97
-rw-r--r--tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py2
-rw-r--r--tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py4
-rw-r--r--tensorflow/contrib/recurrent/python/ops/functional_rnn.py8
-rw-r--r--tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h6
-rw-r--r--tensorflow/contrib/rnn/BUILD8
-rw-r--r--tensorflow/contrib/rnn/__init__.py2
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py77
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py28
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py4
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py60
-rw-r--r--tensorflow/contrib/rnn/python/ops/rnn_cell.py20
-rw-r--r--tensorflow/contrib/saved_model/BUILD17
-rw-r--r--tensorflow/contrib/saved_model/__init__.py7
-rw-r--r--tensorflow/contrib/saved_model/cc/saved_model/BUILD2
-rw-r--r--tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc81
-rw-r--r--tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h3
-rw-r--r--tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc123
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py260
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py303
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/reader_test.py12
-rw-r--r--tensorflow/contrib/seq2seq/__init__.py4
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py10
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py18
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py2
-rw-r--r--tensorflow/contrib/session_bundle/session_bundle.cc11
-rw-r--r--tensorflow/contrib/session_bundle/session_bundle_test.py2
-rw-r--r--tensorflow/contrib/signal/__init__.py4
-rw-r--r--tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py11
-rw-r--r--tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py316
-rw-r--r--tensorflow/contrib/slim/python/slim/evaluation_test.py6
-rw-r--r--tensorflow/contrib/slim/python/slim/learning_test.py8
-rw-r--r--tensorflow/contrib/slim/python/slim/nets/alexnet_test.py14
-rw-r--r--tensorflow/contrib/slim/python/slim/nets/inception_v1_test.py10
-rw-r--r--tensorflow/contrib/slim/python/slim/nets/inception_v2_test.py10
-rw-r--r--tensorflow/contrib/slim/python/slim/nets/inception_v3_test.py10
-rw-r--r--tensorflow/contrib/slim/python/slim/nets/overfeat_test.py14
-rw-r--r--tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py18
-rw-r--r--tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py18
-rw-r--r--tensorflow/contrib/slim/python/slim/nets/vgg_test.py36
-rw-r--r--tensorflow/contrib/slim/python/slim/summaries_test.py2
-rw-r--r--tensorflow/contrib/specs/python/specs_test.py22
-rw-r--r--tensorflow/contrib/specs/python/summaries_test.py8
-rw-r--r--tensorflow/contrib/tensor_forest/BUILD9
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest.py22
-rw-r--r--tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h6
-rw-r--r--tensorflow/contrib/tensor_forest/hybrid/python/kernel_tests/k_feature_routing_function_op_test.py2
-rw-r--r--tensorflow/contrib/tensor_forest/hybrid/python/kernel_tests/routing_function_op_test.py2
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/data_spec.h6
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/tree_utils.h6
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc19
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h4
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc34
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc6
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/input_data.h3
-rw-r--r--tensorflow/contrib/tensorrt/BUILD45
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_graph.cc73
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_graph.h6
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_graph_test.cc140
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.cc266
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.h5
-rw-r--r--tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc23
-rw-r--r--tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h8
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_resource_manager.h2
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment.cc78
-rw-r--r--tensorflow/contrib/tensorrt/test/base_test.py154
-rw-r--r--tensorflow/contrib/tensorrt/test/batch_matmul_test.py42
-rw-r--r--tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py58
-rw-r--r--tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py47
-rw-r--r--tensorflow/contrib/tensorrt/test/concatenation_test.py13
-rw-r--r--tensorflow/contrib/tensorrt/test/const_broadcast_test.py21
-rw-r--r--tensorflow/contrib/tensorrt/test/manual_test.py114
-rw-r--r--tensorflow/contrib/tensorrt/test/memory_alignment_test.py21
-rw-r--r--tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py13
-rw-r--r--tensorflow/contrib/tensorrt/test/neighboring_engine_test.py19
-rw-r--r--tensorflow/contrib/tensorrt/test/rank_two_test.py89
-rw-r--r--tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py304
-rw-r--r--tensorflow/contrib/tensorrt/test/unary_test.py19
-rw-r--r--tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py21
-rw-r--r--tensorflow/contrib/tensorrt/test/vgg_block_test.py21
-rw-r--r--tensorflow/contrib/timeseries/examples/BUILD1
-rw-r--r--tensorflow/contrib/timeseries/examples/known_anomaly.py8
-rw-r--r--tensorflow/contrib/timeseries/examples/known_anomaly_test.py4
-rw-r--r--tensorflow/contrib/timeseries/examples/predict.py16
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py2
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators_test.py6
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head.py3
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head_test.py4
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py6
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py25
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py2
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/state_management_test.py6
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py4
-rw-r--r--tensorflow/contrib/tpu/BUILD2
-rw-r--r--tensorflow/contrib/tpu/__init__.py1
-rw-r--r--tensorflow/contrib/tpu/ops/cross_replica_ops.cc103
-rw-r--r--tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc3
-rw-r--r--tensorflow/contrib/tpu/profiler/op_profile.proto2
-rw-r--r--tensorflow/contrib/tpu/profiler/tf_op_stats.proto5
-rw-r--r--tensorflow/contrib/tpu/proto/optimization_parameters.proto8
-rw-r--r--tensorflow/contrib/tpu/python/ops/tpu_ops.py79
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py992
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py287
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu.py170
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_config.py6
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_context.py6
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py95
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py53
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py31
-rw-r--r--tensorflow/contrib/training/__init__.py4
-rw-r--r--tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py10
-rw-r--r--tensorflow/contrib/training/python/training/bucket_ops_test.py10
-rw-r--r--tensorflow/contrib/training/python/training/evaluation.py9
-rw-r--r--tensorflow/contrib/training/python/training/evaluation_test.py33
-rw-r--r--tensorflow/contrib/training/python/training/resample_test.py8
-rw-r--r--tensorflow/contrib/training/python/training/sampling_ops_test.py18
-rw-r--r--tensorflow/contrib/training/python/training/sampling_ops_threading_test.py2
-rw-r--r--tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py18
-rw-r--r--tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py8
-rw-r--r--tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py18
-rw-r--r--tensorflow/contrib/training/python/training/training_test.py14
-rw-r--r--tensorflow/contrib/util/__init__.py2
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_client.h6
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_service_impl.h6
-rw-r--r--tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc2
-rw-r--r--tensorflow/contrib/verbs/verbs_util.h6
-rw-r--r--tensorflow/core/BUILD117
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ApplyAdam.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_DivNoNan.pbtxt9
-rw-r--r--tensorflow/core/api_def/base_api/api_def_EnsureShape.pbtxt26
-rw-r--r--tensorflow/core/api_def/base_api/api_def_FeatureStatsDataset.pbtxt3
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Fill.pbtxt10
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MatrixExponential.pbtxt31
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt13
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ParseExampleDataset.pbtxt69
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ParseSequenceExample.pbtxt112
-rw-r--r--tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResourceApplyAdam.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ScatterNdUpdate.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt7
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt7
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt7
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt7
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt7
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseSegmentMean.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtN.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseSegmentSum.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt29
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StaticRegexReplace.pbtxt26
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_TensorListGather.pbtxt12
-rw-r--r--tensorflow/core/api_def/base_api/api_def_TensorListScatter.pbtxt14
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnsafeDiv.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt21
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt20
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt20
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt7
-rw-r--r--tensorflow/core/api_def/python_api/api_def_DivNoNan.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_EnsureShape.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_FeatureStatsDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ParseExampleDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ParseSequenceExample.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ScatterNdSub.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_UnsafeDiv.pbtxt4
-rw-r--r--tensorflow/core/common_runtime/base_collective_executor.cc150
-rw-r--r--tensorflow/core/common_runtime/base_collective_executor.h20
-rw-r--r--tensorflow/core/common_runtime/bfc_allocator.cc2
-rw-r--r--tensorflow/core/common_runtime/bfc_allocator.h6
-rw-r--r--tensorflow/core/common_runtime/broadcaster.cc300
-rw-r--r--tensorflow/core/common_runtime/buf_rendezvous.h6
-rw-r--r--tensorflow/core/common_runtime/collective_executor_mgr.h6
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local.cc251
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local.h23
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local_test.cc204
-rw-r--r--tensorflow/core/common_runtime/collective_rma_local.h6
-rw-r--r--tensorflow/core/common_runtime/collective_util.cc83
-rw-r--r--tensorflow/core/common_runtime/collective_util.h38
-rw-r--r--tensorflow/core/common_runtime/constant_folding.cc7
-rw-r--r--tensorflow/core/common_runtime/constant_folding.h6
-rw-r--r--tensorflow/core/common_runtime/debugger_state_interface.h6
-rw-r--r--tensorflow/core/common_runtime/device.h6
-rw-r--r--tensorflow/core/common_runtime/device_factory.h6
-rw-r--r--tensorflow/core/common_runtime/device_mgr.h6
-rw-r--r--tensorflow/core/common_runtime/device_resolver_local.h6
-rw-r--r--tensorflow/core/common_runtime/device_set.h6
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc38
-rw-r--r--tensorflow/core/common_runtime/direct_session.h18
-rw-r--r--tensorflow/core/common_runtime/direct_session_test.cc117
-rw-r--r--tensorflow/core/common_runtime/dma_helper.h6
-rw-r--r--tensorflow/core/common_runtime/eager/BUILD5
-rw-r--r--tensorflow/core/common_runtime/eager/attr_builder.cc8
-rw-r--r--tensorflow/core/common_runtime/eager/attr_builder.h13
-rw-r--r--tensorflow/core/common_runtime/eager/context.cc34
-rw-r--r--tensorflow/core/common_runtime/eager/context.h27
-rw-r--r--tensorflow/core/common_runtime/eager/execute.cc19
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.cc1
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.h8
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device_test.cc4
-rw-r--r--tensorflow/core/common_runtime/eager/tensor_handle.cc1
-rw-r--r--tensorflow/core/common_runtime/executor.cc136
-rw-r--r--tensorflow/core/common_runtime/executor.h6
-rw-r--r--tensorflow/core/common_runtime/function.cc63
-rw-r--r--tensorflow/core/common_runtime/function.h6
-rw-r--r--tensorflow/core/common_runtime/function_test.cc22
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h6
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h6
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h6
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.cc17
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_event_mgr.h6
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_init.h6
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_stream_util.h6
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_util.h6
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc9
-rw-r--r--tensorflow/core/common_runtime/gpu_device_context.h9
-rw-r--r--tensorflow/core/common_runtime/graph_execution_state.cc49
-rw-r--r--tensorflow/core/common_runtime/graph_execution_state.h7
-rw-r--r--tensorflow/core/common_runtime/graph_runner.cc4
-rw-r--r--tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc440
-rw-r--r--tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h (renamed from tensorflow/core/common_runtime/broadcaster.h)58
-rw-r--r--tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc (renamed from tensorflow/core/common_runtime/broadcaster_test.cc)239
-rw-r--r--tensorflow/core/common_runtime/kernel_benchmark_testlib.h6
-rw-r--r--tensorflow/core/common_runtime/local_device.cc2
-rw-r--r--tensorflow/core/common_runtime/local_device.h9
-rw-r--r--tensorflow/core/common_runtime/lower_while_op.cc427
-rw-r--r--tensorflow/core/common_runtime/lower_while_op.h (renamed from tensorflow/core/util/status_util.h)29
-rw-r--r--tensorflow/core/common_runtime/lower_while_op_test.cc249
-rw-r--r--tensorflow/core/common_runtime/mkl_cpu_allocator.h201
-rw-r--r--tensorflow/core/common_runtime/optimization_registry.h11
-rw-r--r--tensorflow/core/common_runtime/placer.cc54
-rw-r--r--tensorflow/core/common_runtime/placer.h8
-rw-r--r--tensorflow/core/common_runtime/placer_test.cc49
-rw-r--r--tensorflow/core/common_runtime/pool_allocator.cc1
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.cc70
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.h5
-rw-r--r--tensorflow/core/common_runtime/rendezvous_mgr.h6
-rw-r--r--tensorflow/core/common_runtime/rendezvous_util.cc37
-rw-r--r--tensorflow/core/common_runtime/rendezvous_util.h3
-rw-r--r--tensorflow/core/common_runtime/ring_reducer.cc352
-rw-r--r--tensorflow/core/common_runtime/ring_reducer.h55
-rw-r--r--tensorflow/core/common_runtime/ring_reducer_test.cc121
-rw-r--r--tensorflow/core/common_runtime/session_factory.h6
-rw-r--r--tensorflow/core/common_runtime/session_state.cc2
-rw-r--r--tensorflow/core/common_runtime/step_stats_collector.cc92
-rw-r--r--tensorflow/core/common_runtime/step_stats_collector.h82
-rw-r--r--tensorflow/core/common_runtime/sycl/sycl_allocator.h6
-rw-r--r--tensorflow/core/common_runtime/threadpool_device.cc11
-rw-r--r--tensorflow/core/common_runtime/threadpool_device.h1
-rw-r--r--tensorflow/core/common_runtime/tracing_device.h60
-rw-r--r--tensorflow/core/common_runtime/visitable_allocator.h6
-rw-r--r--tensorflow/core/debug/debug_callback_registry.h6
-rw-r--r--tensorflow/core/debug/debug_graph_utils.cc4
-rw-r--r--tensorflow/core/debug/debug_graph_utils.h6
-rw-r--r--tensorflow/core/debug/debug_grpc_testlib.h6
-rw-r--r--tensorflow/core/debug/debug_io_utils.cc61
-rw-r--r--tensorflow/core/debug/debug_io_utils.h35
-rw-r--r--tensorflow/core/debug/debug_io_utils_test.cc46
-rw-r--r--tensorflow/core/debug/debug_node_key.h6
-rw-r--r--tensorflow/core/debug/debugger_state_impl.cc3
-rw-r--r--tensorflow/core/debug/debugger_state_impl.h6
-rw-r--r--tensorflow/core/distributed_runtime/BUILD1
-rw-r--r--tensorflow/core/distributed_runtime/master.cc51
-rw-r--r--tensorflow/core/distributed_runtime/master_env.h2
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc21
-rw-r--r--tensorflow/core/distributed_runtime/message_wrappers.h2
-rw-r--r--tensorflow/core/distributed_runtime/remote_device.cc2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_channel.cc14
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_channel.h2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc65
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h3
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_testlib.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc5
-rw-r--r--tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc2
-rw-r--r--tensorflow/core/distributed_runtime/tensor_coding.h3
-rw-r--r--tensorflow/core/distributed_runtime/test_utils.h14
-rw-r--r--tensorflow/core/distributed_runtime/worker_cache.h2
-rw-r--r--tensorflow/core/distributed_runtime/worker_cache_wrapper.h4
-rw-r--r--tensorflow/core/distributed_runtime/worker_session.cc5
-rw-r--r--tensorflow/core/example/example_parser_configuration.h2
-rw-r--r--tensorflow/core/example/feature_util.h6
-rw-r--r--tensorflow/core/framework/attr_value_util.h6
-rw-r--r--tensorflow/core/framework/bfloat16.h6
-rw-r--r--tensorflow/core/framework/cancellation.h6
-rw-r--r--tensorflow/core/framework/collective.cc102
-rw-r--r--tensorflow/core/framework/collective.h119
-rw-r--r--tensorflow/core/framework/common_shape_fns.cc6
-rw-r--r--tensorflow/core/framework/common_shape_fns.h6
-rw-r--r--tensorflow/core/framework/control_flow.h6
-rw-r--r--tensorflow/core/framework/dataset.cc30
-rw-r--r--tensorflow/core/framework/dataset.h81
-rw-r--r--tensorflow/core/framework/dataset_stateful_op_whitelist.h33
-rw-r--r--tensorflow/core/framework/device_base.h9
-rw-r--r--tensorflow/core/framework/fake_input.h6
-rw-r--r--tensorflow/core/framework/function.cc2
-rw-r--r--tensorflow/core/framework/function.h18
-rw-r--r--tensorflow/core/framework/function_testlib.cc76
-rw-r--r--tensorflow/core/framework/function_testlib.h12
-rw-r--r--tensorflow/core/framework/graph_def_util.h6
-rw-r--r--tensorflow/core/framework/kernel_def_builder.h6
-rw-r--r--tensorflow/core/framework/log_memory.h6
-rw-r--r--tensorflow/core/framework/lookup_interface.h6
-rw-r--r--tensorflow/core/framework/memory_types.h6
-rw-r--r--tensorflow/core/framework/node_def_builder.cc17
-rw-r--r--tensorflow/core/framework/node_def_builder.h6
-rw-r--r--tensorflow/core/framework/node_def_util.cc6
-rw-r--r--tensorflow/core/framework/node_def_util.h6
-rw-r--r--tensorflow/core/framework/numeric_op.h6
-rw-r--r--tensorflow/core/framework/numeric_types.h6
-rw-r--r--tensorflow/core/framework/op.h6
-rw-r--r--tensorflow/core/framework/op_def_builder.cc4
-rw-r--r--tensorflow/core/framework/op_def_builder.h6
-rw-r--r--tensorflow/core/framework/op_def_util.cc9
-rw-r--r--tensorflow/core/framework/op_def_util.h11
-rw-r--r--tensorflow/core/framework/op_gen_lib.cc2
-rw-r--r--tensorflow/core/framework/op_gen_lib.h6
-rw-r--r--tensorflow/core/framework/op_kernel.cc2
-rw-r--r--tensorflow/core/framework/queue_interface.h6
-rw-r--r--tensorflow/core/framework/reader_base.h6
-rw-r--r--tensorflow/core/framework/reader_interface.h6
-rw-r--r--tensorflow/core/framework/reader_op_kernel.h6
-rw-r--r--tensorflow/core/framework/register_types.h11
-rw-r--r--tensorflow/core/framework/register_types_traits.h6
-rw-r--r--tensorflow/core/framework/resource_mgr.h10
-rw-r--r--tensorflow/core/framework/resource_op_kernel.h6
-rw-r--r--tensorflow/core/framework/selective_registration.h6
-rw-r--r--tensorflow/core/framework/session_state.h6
-rw-r--r--tensorflow/core/framework/shape_inference_testutil.h2
-rw-r--r--tensorflow/core/framework/stats_aggregator.h3
-rw-r--r--tensorflow/core/framework/tensor.cc4
-rw-r--r--tensorflow/core/framework/tensor_slice.h6
-rw-r--r--tensorflow/core/framework/tensor_types.h6
-rw-r--r--tensorflow/core/framework/tensor_util.h6
-rw-r--r--tensorflow/core/framework/tracking_allocator.h6
-rw-r--r--tensorflow/core/framework/type_index.h6
-rw-r--r--tensorflow/core/framework/type_traits.h6
-rw-r--r--tensorflow/core/framework/types.h6
-rw-r--r--tensorflow/core/framework/variant.h6
-rw-r--r--tensorflow/core/framework/variant_encode_decode.h6
-rw-r--r--tensorflow/core/framework/variant_op_registry.h6
-rw-r--r--tensorflow/core/framework/variant_tensor_data.h6
-rw-r--r--tensorflow/core/graph/algorithm.h6
-rw-r--r--tensorflow/core/graph/colors.h6
-rw-r--r--tensorflow/core/graph/control_flow.h6
-rw-r--r--tensorflow/core/graph/costmodel.h6
-rw-r--r--tensorflow/core/graph/default_device.h6
-rw-r--r--tensorflow/core/graph/graph.cc24
-rw-r--r--tensorflow/core/graph/graph.h4
-rw-r--r--tensorflow/core/graph/graph_constructor.cc10
-rw-r--r--tensorflow/core/graph/graph_constructor.h6
-rw-r--r--tensorflow/core/graph/graph_constructor_test.cc5
-rw-r--r--tensorflow/core/graph/graph_def_builder.cc4
-rw-r--r--tensorflow/core/graph/graph_def_builder.h8
-rw-r--r--tensorflow/core/graph/graph_partition.cc2
-rw-r--r--tensorflow/core/graph/graph_partition.h6
-rw-r--r--tensorflow/core/graph/mkl_graph_util.h2
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc70
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.h6
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass.cc12
-rw-r--r--tensorflow/core/graph/node_builder.cc2
-rw-r--r--tensorflow/core/graph/node_builder.h6
-rw-r--r--tensorflow/core/graph/optimizer_cse.h6
-rw-r--r--tensorflow/core/graph/quantize_training.h6
-rw-r--r--tensorflow/core/graph/subgraph.h6
-rw-r--r--tensorflow/core/graph/tensor_id.cc2
-rw-r--r--tensorflow/core/graph/testlib.cc27
-rw-r--r--tensorflow/core/graph/testlib.h15
-rw-r--r--tensorflow/core/graph/types.h6
-rw-r--r--tensorflow/core/graph/while_context.cc2
-rw-r--r--tensorflow/core/graph/while_context.h6
-rw-r--r--tensorflow/core/grappler/clusters/cluster.cc2
-rw-r--r--tensorflow/core/grappler/clusters/virtual_cluster.cc2
-rw-r--r--tensorflow/core/grappler/costs/analytical_cost_estimator.cc1
-rw-r--r--tensorflow/core/grappler/costs/graph_memory.cc1
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.cc56
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_test.cc93
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc1
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc1
-rw-r--r--tensorflow/core/grappler/costs/utils.cc4
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.cc11
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler_test.cc1
-rw-r--r--tensorflow/core/grappler/graph_analyzer/BUILD139
-rw-r--r--tensorflow/core/grappler/graph_analyzer/gen_node.cc148
-rw-r--r--tensorflow/core/grappler/graph_analyzer/gen_node.h167
-rw-r--r--tensorflow/core/grappler/graph_analyzer/gen_node_test.cc491
-rw-r--r--tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc341
-rw-r--r--tensorflow/core/grappler/graph_analyzer/graph_analyzer.h154
-rw-r--r--tensorflow/core/grappler/graph_analyzer/graph_analyzer_test.cc569
-rw-r--r--tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.cc98
-rw-r--r--tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h (renamed from tensorflow/compiler/jit/ops/parallel_check_op.cc)27
-rw-r--r--tensorflow/core/grappler/graph_analyzer/hash_tools.h47
-rw-r--r--tensorflow/core/grappler/graph_analyzer/hash_tools_test.cc (renamed from tensorflow/core/util/status_util_test.cc)36
-rw-r--r--tensorflow/core/grappler/graph_analyzer/map_tools.h46
-rw-r--r--tensorflow/core/grappler/graph_analyzer/sig_node.cc453
-rw-r--r--tensorflow/core/grappler/graph_analyzer/sig_node.h304
-rw-r--r--tensorflow/core/grappler/graph_analyzer/sig_node_test.cc1235
-rw-r--r--tensorflow/core/grappler/graph_analyzer/subgraph.cc235
-rw-r--r--tensorflow/core/grappler/graph_analyzer/subgraph.h189
-rw-r--r--tensorflow/core/grappler/graph_analyzer/subgraph_test.cc348
-rw-r--r--tensorflow/core/grappler/graph_analyzer/test_tools.cc296
-rw-r--r--tensorflow/core/grappler/graph_analyzer/test_tools.h120
-rw-r--r--tensorflow/core/grappler/grappler_item_builder.cc1
-rw-r--r--tensorflow/core/grappler/op_types.cc37
-rw-r--r--tensorflow/core/grappler/op_types.h2
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD71
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc35
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.h2
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc66
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc75
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.h4
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc157
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD63
-rw-r--r--tensorflow/core/grappler/optimizers/data/filter_fusion.cc141
-rw-r--r--tensorflow/core/grappler/optimizers/data/filter_fusion.h47
-rw-r--r--tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc91
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_rename.cc51
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_rename_test.cc42
-rw-r--r--tensorflow/core/grappler/optimizers/data/fusion_utils.cc166
-rw-r--r--tensorflow/core/grappler/optimizers/data/fusion_utils.h57
-rw-r--r--tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc22
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.cc40
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.h11
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils_test.cc58
-rw-r--r--tensorflow/core/grappler/optimizers/data/latency_all_edges.cc4
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc11
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc12
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc13
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc10
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_fusion.cc12
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_fusion_test.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization.cc258
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization.h (renamed from tensorflow/core/grappler/optimizers/data/function_rename.h)14
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc201
-rw-r--r--tensorflow/core/grappler/optimizers/data/noop_elimination.cc3
-rw-r--r--tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc4
-rw-r--r--tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/evaluation_utils.cc1
-rw-r--r--tensorflow/core/grappler/optimizers/evaluation_utils.h1
-rw-r--r--tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc93
-rw-r--r--tensorflow/core/grappler/optimizers/experimental_implementation_selector.h115
-rw-r--r--tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc139
-rw-r--r--tensorflow/core/grappler/optimizers/function_api_info.cc167
-rw-r--r--tensorflow/core/grappler/optimizers/function_api_info.h80
-rw-r--r--tensorflow/core/grappler/optimizers/function_api_info_test.cc160
-rw-r--r--tensorflow/core/grappler/optimizers/function_optimizer.cc19
-rw-r--r--tensorflow/core/grappler/optimizers/memory_optimizer.cc11
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc35
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.h3
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer_test.cc46
-rw-r--r--tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc1
-rw-r--r--tensorflow/core/grappler/optimizers/shape_optimizer.cc1
-rw-r--r--tensorflow/core/grappler/utils.h2
-rw-r--r--tensorflow/core/grappler/utils/functions.cc38
-rw-r--r--tensorflow/core/grappler/utils/functions.h15
-rw-r--r--tensorflow/core/grappler/utils/functions_test.cc31
-rw-r--r--tensorflow/core/kernels/BUILD134
-rw-r--r--tensorflow/core/kernels/adjust_contrast_op.h6
-rw-r--r--tensorflow/core/kernels/adjust_hue_op.h6
-rw-r--r--tensorflow/core/kernels/adjust_saturation_op.h6
-rw-r--r--tensorflow/core/kernels/aggregate_ops.h6
-rw-r--r--tensorflow/core/kernels/aggregate_ops_cpu.h6
-rw-r--r--tensorflow/core/kernels/argmax_op.h6
-rw-r--r--tensorflow/core/kernels/assign_op.h6
-rw-r--r--tensorflow/core/kernels/avgpooling_op.h6
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_impl.h5
-rw-r--r--tensorflow/core/kernels/batch_norm_op.h6
-rw-r--r--tensorflow/core/kernels/betainc_op.h6
-rw-r--r--tensorflow/core/kernels/bias_op.h6
-rw-r--r--tensorflow/core/kernels/bincount_op.h6
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/BUILD63
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h132
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer_test.cc99
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h330
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream_test.cc276
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h344
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary_test.cc223
-rw-r--r--tensorflow/core/kernels/bounds_check.h6
-rw-r--r--tensorflow/core/kernels/broadcast_to_op.h6
-rw-r--r--tensorflow/core/kernels/bucketize_op.h6
-rw-r--r--tensorflow/core/kernels/candidate_sampler_ops.cc6
-rw-r--r--tensorflow/core/kernels/cast_op.cc8
-rw-r--r--tensorflow/core/kernels/cast_op.h6
-rw-r--r--tensorflow/core/kernels/collective_ops.cc46
-rw-r--r--tensorflow/core/kernels/colorspace_op.h6
-rw-r--r--tensorflow/core/kernels/concat_lib.h6
-rw-r--r--tensorflow/core/kernels/concat_lib_cpu.h5
-rw-r--r--tensorflow/core/kernels/conditional_accumulator.h12
-rw-r--r--tensorflow/core/kernels/conditional_accumulator_base.cc13
-rw-r--r--tensorflow/core/kernels/conditional_accumulator_base.h9
-rw-r--r--tensorflow/core/kernels/conditional_accumulator_base_op.h9
-rw-r--r--tensorflow/core/kernels/conditional_accumulator_op.cc3
-rw-r--r--tensorflow/core/kernels/constant_op.cc5
-rw-r--r--tensorflow/core/kernels/control_flow_ops.h6
-rw-r--r--tensorflow/core/kernels/conv_2d.h6
-rw-r--r--tensorflow/core/kernels/conv_3d.h6
-rw-r--r--tensorflow/core/kernels/conv_ops.h6
-rw-r--r--tensorflow/core/kernels/cross_op.h6
-rw-r--r--tensorflow/core/kernels/cuda_solvers.h5
-rw-r--r--tensorflow/core/kernels/cudnn_pooling_gpu.h6
-rw-r--r--tensorflow/core/kernels/cwise_op_div.cc4
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_div.cu.cc1
-rw-r--r--tensorflow/core/kernels/cwise_op_zeta.cc5
-rw-r--r--tensorflow/core/kernels/cwise_ops.h14
-rw-r--r--tensorflow/core/kernels/cwise_ops_common.h6
-rw-r--r--tensorflow/core/kernels/cwise_ops_gpu_common.cu.h6
-rw-r--r--tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h6
-rw-r--r--tensorflow/core/kernels/cwise_ops_gradients.h6
-rw-r--r--tensorflow/core/kernels/data/BUILD54
-rw-r--r--tensorflow/core/kernels/data/batch_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/cache_dataset_ops.cc4
-rw-r--r--tensorflow/core/kernels/data/captured_function.cc87
-rw-r--r--tensorflow/core/kernels/data/captured_function.h25
-rw-r--r--tensorflow/core/kernels/data/concatenate_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/dataset_ops.cc2
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.cc6
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.h6
-rw-r--r--tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/filter_by_component_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/filter_dataset_op.cc10
-rw-r--r--tensorflow/core/kernels/data/flat_map_dataset_op.cc12
-rw-r--r--tensorflow/core/kernels/data/generator_dataset_op.cc24
-rw-r--r--tensorflow/core/kernels/data/generator_dataset_op.h3
-rw-r--r--tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc20
-rw-r--r--tensorflow/core/kernels/data/group_by_window_dataset_op.cc17
-rw-r--r--tensorflow/core/kernels/data/interleave_dataset_op.cc15
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc104
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.h9
-rw-r--r--tensorflow/core/kernels/data/map_and_batch_dataset_op.cc10
-rw-r--r--tensorflow/core/kernels/data/map_dataset_op.cc21
-rw-r--r--tensorflow/core/kernels/data/map_defun_op.cc104
-rw-r--r--tensorflow/core/kernels/data/optimize_dataset_op.cc38
-rw-r--r--tensorflow/core/kernels/data/optional_ops.cc2
-rw-r--r--tensorflow/core/kernels/data/optional_ops.h2
-rw-r--r--tensorflow/core/kernels/data/padded_batch_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc616
-rw-r--r--tensorflow/core/kernels/data/parallel_map_dataset_op.cc53
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.cc30
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.h12
-rw-r--r--tensorflow/core/kernels/data/parse_example_dataset_op.cc372
-rw-r--r--tensorflow/core/kernels/data/prefetch_autotuner.cc2
-rw-r--r--tensorflow/core/kernels/data/prefetch_autotuner.h2
-rw-r--r--tensorflow/core/kernels/data/prefetch_autotuner_test.cc2
-rw-r--r--tensorflow/core/kernels/data/prefetch_dataset_op.cc31
-rw-r--r--tensorflow/core/kernels/data/prefetch_dataset_op.h2
-rw-r--r--tensorflow/core/kernels/data/random_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/range_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/reader_dataset_ops.cc4
-rw-r--r--tensorflow/core/kernels/data/repeat_dataset_op.cc42
-rw-r--r--tensorflow/core/kernels/data/scan_dataset_op.cc10
-rw-r--r--tensorflow/core/kernels/data/shuffle_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/single_threaded_executor.cc380
-rw-r--r--tensorflow/core/kernels/data/single_threaded_executor.h62
-rw-r--r--tensorflow/core/kernels/data/single_threaded_executor_test.cc332
-rw-r--r--tensorflow/core/kernels/data/skip_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/slide_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/sql/driver_manager.cc4
-rw-r--r--tensorflow/core/kernels/data/sql/driver_manager.h4
-rw-r--r--tensorflow/core/kernels/data/sql/query_connection.h3
-rw-r--r--tensorflow/core/kernels/data/sql/sqlite_query_connection.cc4
-rw-r--r--tensorflow/core/kernels/data/sql/sqlite_query_connection.h4
-rw-r--r--tensorflow/core/kernels/data/sql_dataset_ops.cc5
-rw-r--r--tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc2
-rw-r--r--tensorflow/core/kernels/data/stats_aggregator_ops.cc2
-rw-r--r--tensorflow/core/kernels/data/stats_dataset_ops.cc198
-rw-r--r--tensorflow/core/kernels/data/take_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/tensor_dataset_op.cc15
-rw-r--r--tensorflow/core/kernels/data/tensor_queue_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/tensor_slice_dataset_op.cc15
-rw-r--r--tensorflow/core/kernels/data/unbatch_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/window_dataset.cc2
-rw-r--r--tensorflow/core/kernels/data/window_dataset.h2
-rw-r--r--tensorflow/core/kernels/data/window_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/writer_ops.cc3
-rw-r--r--tensorflow/core/kernels/data/zip_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data_format_ops.h6
-rw-r--r--tensorflow/core/kernels/debug_ops.h19
-rw-r--r--tensorflow/core/kernels/dense_update_functor.h6
-rw-r--r--tensorflow/core/kernels/dynamic_stitch_op.cc4
-rw-r--r--tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h505
-rw-r--r--tensorflow/core/kernels/eigen_backward_spatial_convolutions.h9
-rw-r--r--tensorflow/core/kernels/eigen_benchmark.h304
-rw-r--r--tensorflow/core/kernels/eigen_benchmark_cpu_test.cc422
-rw-r--r--tensorflow/core/kernels/example_parsing_ops.cc165
-rw-r--r--tensorflow/core/kernels/extract_image_patches_op.h6
-rw-r--r--tensorflow/core/kernels/fake_quant_ops_functor.h6
-rw-r--r--tensorflow/core/kernels/fill_functor.h6
-rw-r--r--tensorflow/core/kernels/fractional_pool_common.h6
-rw-r--r--tensorflow/core/kernels/fused_batch_norm_op.h6
-rw-r--r--tensorflow/core/kernels/gather_functor.h6
-rw-r--r--tensorflow/core/kernels/gather_nd_op.h6
-rw-r--r--tensorflow/core/kernels/gather_nd_op_cpu_impl.h21
-rw-r--r--tensorflow/core/kernels/gemm_functors.h5
-rw-r--r--tensorflow/core/kernels/gpu_utils.h3
-rw-r--r--tensorflow/core/kernels/hexagon/graph_transfer_utils.h6
-rw-r--r--tensorflow/core/kernels/hexagon/graph_transferer.cc2
-rw-r--r--tensorflow/core/kernels/hexagon/graph_transferer.h2
-rw-r--r--tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc2
-rw-r--r--tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h6
-rw-r--r--tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h2
-rw-r--r--tensorflow/core/kernels/hexagon/soc_interface.h6
-rw-r--r--tensorflow/core/kernels/hinge-loss.h6
-rw-r--r--tensorflow/core/kernels/histogram_op.h6
-rw-r--r--tensorflow/core/kernels/i_remote_fused_graph_executor.h6
-rw-r--r--tensorflow/core/kernels/identity_n_op.h6
-rw-r--r--tensorflow/core/kernels/identity_op.h6
-rw-r--r--tensorflow/core/kernels/image_resizer_state.h6
-rw-r--r--tensorflow/core/kernels/immutable_constant_op.h6
-rw-r--r--tensorflow/core/kernels/initializable_lookup_table.cc8
-rw-r--r--tensorflow/core/kernels/initializable_lookup_table.h63
-rw-r--r--tensorflow/core/kernels/inplace_ops_functor.h6
-rw-r--r--tensorflow/core/kernels/l2loss_op.h6
-rw-r--r--tensorflow/core/kernels/linalg_ops_common.h6
-rw-r--r--tensorflow/core/kernels/list_kernels.cc12
-rw-r--r--tensorflow/core/kernels/list_kernels.cu.cc15
-rw-r--r--tensorflow/core/kernels/list_kernels.h149
-rw-r--r--tensorflow/core/kernels/logistic-loss.h8
-rw-r--r--tensorflow/core/kernels/lookup_table_init_op.cc4
-rw-r--r--tensorflow/core/kernels/lookup_table_init_op.h6
-rw-r--r--tensorflow/core/kernels/lookup_table_op.cc26
-rw-r--r--tensorflow/core/kernels/lookup_table_op.h15
-rw-r--r--tensorflow/core/kernels/lookup_util.h51
-rw-r--r--tensorflow/core/kernels/loss.h6
-rw-r--r--tensorflow/core/kernels/loss_test.cc174
-rw-r--r--tensorflow/core/kernels/map_stage_op.cc10
-rw-r--r--tensorflow/core/kernels/matmul_op.h6
-rw-r--r--tensorflow/core/kernels/matrix_band_part_op.h6
-rw-r--r--tensorflow/core/kernels/matrix_diag_op.h6
-rw-r--r--tensorflow/core/kernels/matrix_exponential_op.cc1
-rw-r--r--tensorflow/core/kernels/matrix_set_diag_op.h6
-rw-r--r--tensorflow/core/kernels/matrix_solve_ls_op_impl.h5
-rw-r--r--tensorflow/core/kernels/maxpooling_op.h6
-rw-r--r--tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc4
-rw-r--r--tensorflow/core/kernels/mirror_pad_op.h6
-rw-r--r--tensorflow/core/kernels/mirror_pad_op_cpu_impl.h6
-rw-r--r--tensorflow/core/kernels/mkl_aggregate_ops.cc20
-rw-r--r--tensorflow/core/kernels/mkl_avgpooling_op.cc51
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc193
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_input_ops.cc174
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.cc186
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.h414
-rw-r--r--tensorflow/core/kernels/mkl_input_conversion_op.cc17
-rw-r--r--tensorflow/core/kernels/mkl_maxpooling_op.cc59
-rw-r--r--tensorflow/core/kernels/mkl_pooling_ops_common.cc157
-rw-r--r--tensorflow/core/kernels/mkl_pooling_ops_common.h132
-rw-r--r--tensorflow/core/kernels/mkl_relu_op.cc607
-rw-r--r--tensorflow/core/kernels/mkl_softmax_op.cc38
-rw-r--r--tensorflow/core/kernels/multinomial_op.h6
-rw-r--r--tensorflow/core/kernels/neon/depthwiseconv_float.h6
-rw-r--r--tensorflow/core/kernels/no_op.h6
-rw-r--r--tensorflow/core/kernels/non_max_suppression_op.cc130
-rw-r--r--tensorflow/core/kernels/nth_element_op.h6
-rw-r--r--tensorflow/core/kernels/one_hot_op.h6
-rw-r--r--tensorflow/core/kernels/ops_testutil.h6
-rw-r--r--tensorflow/core/kernels/ops_util.h6
-rw-r--r--tensorflow/core/kernels/pad_op.h6
-rw-r--r--tensorflow/core/kernels/padding_fifo_queue.h6
-rw-r--r--tensorflow/core/kernels/parameterized_truncated_normal_op.cc31
-rw-r--r--tensorflow/core/kernels/parameterized_truncated_normal_op.h6
-rw-r--r--tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/partitioned_function_ops.cc73
-rw-r--r--tensorflow/core/kernels/poisson-loss.h109
-rw-r--r--tensorflow/core/kernels/pooling_ops_3d.h6
-rw-r--r--tensorflow/core/kernels/pooling_ops_3d_gpu.h6
-rw-r--r--tensorflow/core/kernels/pooling_ops_common.h6
-rw-r--r--tensorflow/core/kernels/priority_queue.h6
-rw-r--r--tensorflow/core/kernels/qr_op_complex128.cc12
-rw-r--r--tensorflow/core/kernels/qr_op_complex64.cc6
-rw-r--r--tensorflow/core/kernels/qr_op_double.cc12
-rw-r--r--tensorflow/core/kernels/qr_op_float.cc12
-rw-r--r--tensorflow/core/kernels/qr_op_impl.h7
-rw-r--r--tensorflow/core/kernels/random_op.h6
-rw-r--r--tensorflow/core/kernels/random_poisson_op.h6
-rw-r--r--tensorflow/core/kernels/range_sampler.h6
-rw-r--r--tensorflow/core/kernels/range_sampler_test.cc22
-rw-r--r--tensorflow/core/kernels/record_yielder.h6
-rw-r--r--tensorflow/core/kernels/reduction_gpu_kernels.cu.h7
-rw-r--r--tensorflow/core/kernels/reduction_ops.h6
-rw-r--r--tensorflow/core/kernels/reduction_ops_common.h6
-rw-r--r--tensorflow/core/kernels/regex_full_match_op.cc33
-rw-r--r--tensorflow/core/kernels/regex_replace_op.cc80
-rw-r--r--tensorflow/core/kernels/regex_replace_op_test.cc137
-rw-r--r--tensorflow/core/kernels/relu_op.cc27
-rw-r--r--tensorflow/core/kernels/relu_op.h6
-rw-r--r--tensorflow/core/kernels/relu_op_functor.h6
-rw-r--r--tensorflow/core/kernels/relu_op_gpu.cu.cc35
-rw-r--r--tensorflow/core/kernels/remote_fused_graph_execute_utils.cc26
-rw-r--r--tensorflow/core/kernels/reshape_op.h6
-rw-r--r--tensorflow/core/kernels/resize_bilinear_op.cc34
-rw-r--r--tensorflow/core/kernels/resize_nearest_neighbor_op.cc75
-rw-r--r--tensorflow/core/kernels/reverse_op.h6
-rw-r--r--tensorflow/core/kernels/reverse_sequence_op.h6
-rw-r--r--tensorflow/core/kernels/save_restore_tensor.cc9
-rw-r--r--tensorflow/core/kernels/save_restore_tensor.h6
-rw-r--r--tensorflow/core/kernels/save_restore_v2_ops.cc4
-rw-r--r--tensorflow/core/kernels/scan_ops.h6
-rw-r--r--tensorflow/core/kernels/scatter_functor.h6
-rw-r--r--tensorflow/core/kernels/scatter_functor_gpu.cu.h6
-rw-r--r--tensorflow/core/kernels/sdca_internal.cc2
-rw-r--r--tensorflow/core/kernels/sdca_ops.cc3
-rw-r--r--tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h5
-rw-r--r--tensorflow/core/kernels/sendrecv_ops.h6
-rw-r--r--tensorflow/core/kernels/set_kernels.cc14
-rw-r--r--tensorflow/core/kernels/shape_ops.cc93
-rw-r--r--tensorflow/core/kernels/shape_ops.h6
-rw-r--r--tensorflow/core/kernels/slice_op.h6
-rw-r--r--tensorflow/core/kernels/smooth-hinge-loss.h6
-rw-r--r--tensorflow/core/kernels/snapshot_op.h6
-rw-r--r--tensorflow/core/kernels/softmax_op_functor.h6
-rw-r--r--tensorflow/core/kernels/softplus_op.cc11
-rw-r--r--tensorflow/core/kernels/softplus_op.h6
-rw-r--r--tensorflow/core/kernels/softsign_op.cc11
-rw-r--r--tensorflow/core/kernels/softsign_op.h6
-rw-r--r--tensorflow/core/kernels/sparse_conditional_accumulator.h10
-rw-r--r--tensorflow/core/kernels/sparse_conditional_accumulator_op.cc4
-rw-r--r--tensorflow/core/kernels/sparse_matmul_op.h6
-rw-r--r--tensorflow/core/kernels/sparse_softmax_op.cc2
-rw-r--r--tensorflow/core/kernels/sparse_tensor_dense_add_op.h6
-rw-r--r--tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h6
-rw-r--r--tensorflow/core/kernels/sparse_xent_op.h6
-rw-r--r--tensorflow/core/kernels/split_lib.h6
-rw-r--r--tensorflow/core/kernels/squared-loss.h6
-rw-r--r--tensorflow/core/kernels/strided_slice_op.cc38
-rw-r--r--tensorflow/core/kernels/strided_slice_op.h6
-rw-r--r--tensorflow/core/kernels/strided_slice_op_impl.h6
-rw-r--r--tensorflow/core/kernels/string_split_op.cc111
-rw-r--r--tensorflow/core/kernels/string_split_op_test.cc129
-rw-r--r--tensorflow/core/kernels/string_strip_op.cc2
-rw-r--r--tensorflow/core/kernels/svd_op_impl.h5
-rw-r--r--tensorflow/core/kernels/tensor_array.h6
-rw-r--r--tensorflow/core/kernels/tensor_array_ops.cc14
-rw-r--r--tensorflow/core/kernels/tile_functor.h6
-rw-r--r--tensorflow/core/kernels/tile_ops_impl.h6
-rw-r--r--tensorflow/core/kernels/topk_op.h6
-rw-r--r--tensorflow/core/kernels/training_op_helpers.h6
-rw-r--r--tensorflow/core/kernels/training_ops.cc52
-rw-r--r--tensorflow/core/kernels/training_ops.h6
-rw-r--r--tensorflow/core/kernels/typed_conditional_accumulator_base.h11
-rw-r--r--tensorflow/core/kernels/variable_ops.h6
-rw-r--r--tensorflow/core/kernels/warn_about_ints.cc33
-rw-r--r--tensorflow/core/kernels/where_op.h6
-rw-r--r--tensorflow/core/kernels/where_op_gpu.cu.h5
-rw-r--r--tensorflow/core/kernels/whole_file_read_ops.cc2
-rw-r--r--tensorflow/core/kernels/xent_op.h6
-rw-r--r--tensorflow/core/lib/bfloat16/bfloat16.h4
-rw-r--r--tensorflow/core/lib/core/arena.h6
-rw-r--r--tensorflow/core/lib/core/bits.h6
-rw-r--r--tensorflow/core/lib/core/casts.h6
-rw-r--r--tensorflow/core/lib/core/coding.h6
-rw-r--r--tensorflow/core/lib/core/errors.h24
-rw-r--r--tensorflow/core/lib/core/notification.h6
-rw-r--r--tensorflow/core/lib/core/raw_coding.h6
-rw-r--r--tensorflow/core/lib/core/status.cc2
-rw-r--r--tensorflow/core/lib/core/status_test_util.h6
-rw-r--r--tensorflow/core/lib/core/stringpiece.cc54
-rw-r--r--tensorflow/core/lib/core/stringpiece.h119
-rw-r--r--tensorflow/core/lib/core/stringpiece_test.cc4
-rw-r--r--tensorflow/core/lib/core/threadpool.h6
-rw-r--r--tensorflow/core/lib/gtl/array_slice.h294
-rw-r--r--tensorflow/core/lib/gtl/array_slice_internal.h269
-rw-r--r--tensorflow/core/lib/gtl/array_slice_test.cc666
-rw-r--r--tensorflow/core/lib/gtl/cleanup.h6
-rw-r--r--tensorflow/core/lib/gtl/inlined_vector.h671
-rw-r--r--tensorflow/core/lib/gtl/inlined_vector_test.cc898
-rw-r--r--tensorflow/core/lib/gtl/optional.h859
-rw-r--r--tensorflow/core/lib/gtl/optional_test.cc1098
-rw-r--r--tensorflow/core/lib/gtl/priority_queue_util.h6
-rw-r--r--tensorflow/core/lib/hash/crc32c.h6
-rw-r--r--tensorflow/core/lib/hash/hash.h6
-rw-r--r--tensorflow/core/lib/histogram/histogram.h6
-rw-r--r--tensorflow/core/lib/io/buffered_inputstream.h6
-rw-r--r--tensorflow/core/lib/io/inputstream_interface.h4
-rw-r--r--tensorflow/core/lib/io/path.cc6
-rw-r--r--tensorflow/core/lib/io/path.h6
-rw-r--r--tensorflow/core/lib/io/path_test.cc6
-rw-r--r--tensorflow/core/lib/io/proto_encode_helper.h6
-rw-r--r--tensorflow/core/lib/io/random_inputstream.h6
-rw-r--r--tensorflow/core/lib/io/record_reader.cc3
-rw-r--r--tensorflow/core/lib/io/record_reader.h14
-rw-r--r--tensorflow/core/lib/io/record_writer.cc15
-rw-r--r--tensorflow/core/lib/io/record_writer.h40
-rw-r--r--tensorflow/core/lib/io/table.h6
-rw-r--r--tensorflow/core/lib/io/table_builder.h6
-rw-r--r--tensorflow/core/lib/io/table_options.h6
-rw-r--r--tensorflow/core/lib/io/table_test.cc6
-rw-r--r--tensorflow/core/lib/jpeg/jpeg_handle.h6
-rw-r--r--tensorflow/core/lib/jpeg/jpeg_mem.h6
-rw-r--r--tensorflow/core/lib/math/math_util.h6
-rw-r--r--tensorflow/core/lib/monitoring/collection_registry.cc8
-rw-r--r--tensorflow/core/lib/monitoring/collection_registry.h4
-rw-r--r--tensorflow/core/lib/monitoring/metric_def.h4
-rw-r--r--tensorflow/core/lib/random/distribution_sampler.h6
-rw-r--r--tensorflow/core/lib/random/philox_random.h6
-rw-r--r--tensorflow/core/lib/random/random_distributions.h6
-rw-r--r--tensorflow/core/lib/random/simple_philox.h6
-rw-r--r--tensorflow/core/lib/strings/numbers.h10
-rw-r--r--tensorflow/core/lib/strings/str_util.cc5
-rw-r--r--tensorflow/core/lib/strings/str_util.h8
-rw-r--r--tensorflow/core/lib/strings/strcat.h43
-rw-r--r--tensorflow/core/lib/strings/strcat_test.cc8
-rw-r--r--tensorflow/core/lib/strings/stringprintf.h6
-rw-r--r--tensorflow/core/ops/array_grad.cc21
-rw-r--r--tensorflow/core/ops/array_ops.cc24
-rw-r--r--tensorflow/core/ops/array_ops_test.cc18
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt1036
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc2
-rw-r--r--tensorflow/core/ops/dataset_ops.cc41
-rw-r--r--tensorflow/core/ops/image_ops.cc15
-rw-r--r--tensorflow/core/ops/linalg_ops.cc2
-rw-r--r--tensorflow/core/ops/list_ops.cc51
-rw-r--r--tensorflow/core/ops/logging_ops.cc7
-rw-r--r--tensorflow/core/ops/lookup_ops.cc4
-rw-r--r--tensorflow/core/ops/math_grad.cc8
-rw-r--r--tensorflow/core/ops/math_grad_test.cc6
-rw-r--r--tensorflow/core/ops/math_ops.cc7
-rw-r--r--tensorflow/core/ops/math_ops_test.cc2
-rw-r--r--tensorflow/core/ops/nn_ops.cc187
-rw-r--r--tensorflow/core/ops/ops.pbtxt609
-rw-r--r--tensorflow/core/ops/parsing_ops.cc93
-rw-r--r--tensorflow/core/ops/parsing_ops_test.cc82
-rw-r--r--tensorflow/core/ops/sdca_ops.cc2
-rw-r--r--tensorflow/core/ops/string_ops.cc14
-rw-r--r--tensorflow/core/platform/abi.h6
-rw-r--r--tensorflow/core/platform/cloud/auth_provider.h6
-rw-r--r--tensorflow/core/platform/cloud/compute_engine_zone_provider.cc2
-rw-r--r--tensorflow/core/platform/cloud/curl_http_request.cc4
-rw-r--r--tensorflow/core/platform/cloud/gcs_dns_cache.h6
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.cc14
-rw-r--r--tensorflow/core/platform/cloud/google_auth_provider.h6
-rw-r--r--tensorflow/core/platform/cloud/http_request.h6
-rw-r--r--tensorflow/core/platform/cloud/http_request_fake.h6
-rw-r--r--tensorflow/core/platform/cloud/oauth_client.cc4
-rw-r--r--tensorflow/core/platform/cloud/oauth_client_test.cc6
-rw-r--r--tensorflow/core/platform/context.h6
-rw-r--r--tensorflow/core/platform/cpu_feature_guard.h6
-rw-r--r--tensorflow/core/platform/cpu_info.h6
-rw-r--r--tensorflow/core/platform/default/build_config.bzl1194
-rw-r--r--tensorflow/core/platform/default/device_tracer.cc5
-rw-r--r--tensorflow/core/platform/default/integral_types.h6
-rw-r--r--tensorflow/core/platform/default/logging.h6
-rw-r--r--tensorflow/core/platform/default/mutex.h6
-rw-r--r--tensorflow/core/platform/default/thread_annotations.h6
-rw-r--r--tensorflow/core/platform/default/tracing_impl.h6
-rw-r--r--tensorflow/core/platform/denormal.h6
-rw-r--r--tensorflow/core/platform/dynamic_annotations.h6
-rw-r--r--tensorflow/core/platform/env.cc4
-rw-r--r--tensorflow/core/platform/file_system.cc2
-rw-r--r--tensorflow/core/platform/file_system_helper.cc2
-rw-r--r--tensorflow/core/platform/file_system_test.cc2
-rw-r--r--tensorflow/core/platform/hadoop/hadoop_file_system.cc6
-rw-r--r--tensorflow/core/platform/host_info.h6
-rw-r--r--tensorflow/core/platform/init_main.h6
-rw-r--r--tensorflow/core/platform/load_library.h6
-rw-r--r--tensorflow/core/platform/logging.h6
-rw-r--r--tensorflow/core/platform/macros.h6
-rw-r--r--tensorflow/core/platform/mem.h6
-rw-r--r--tensorflow/core/platform/mutex.h6
-rw-r--r--tensorflow/core/platform/net.h6
-rw-r--r--tensorflow/core/platform/png.h6
-rw-r--r--tensorflow/core/platform/posix/error.h2
-rw-r--r--tensorflow/core/platform/posix/port.cc6
-rw-r--r--tensorflow/core/platform/posix/posix_file_system.h2
-rw-r--r--tensorflow/core/platform/posix/subprocess.h6
-rw-r--r--tensorflow/core/platform/prefetch.h6
-rw-r--r--tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.h6
-rw-r--r--tensorflow/core/platform/profile_utils/clock_cycle_profiler.h6
-rw-r--r--tensorflow/core/platform/profile_utils/cpu_utils.h6
-rw-r--r--tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h6
-rw-r--r--tensorflow/core/platform/protobuf.h6
-rw-r--r--tensorflow/core/platform/protobuf_internal.h6
-rw-r--r--tensorflow/core/platform/s3/s3_file_system.cc4
-rw-r--r--tensorflow/core/platform/setround.h6
-rw-r--r--tensorflow/core/platform/snappy.h6
-rw-r--r--tensorflow/core/platform/stacktrace_handler.h6
-rw-r--r--tensorflow/core/platform/subprocess.h6
-rw-r--r--tensorflow/core/platform/test.h6
-rw-r--r--tensorflow/core/platform/test_benchmark.h6
-rw-r--r--tensorflow/core/platform/thread_annotations.h6
-rw-r--r--tensorflow/core/platform/tracing.h10
-rw-r--r--tensorflow/core/platform/types.h6
-rw-r--r--tensorflow/core/platform/windows/cpu_info.h6
-rw-r--r--tensorflow/core/platform/windows/integral_types.h6
-rw-r--r--tensorflow/core/platform/windows/subprocess.h6
-rw-r--r--tensorflow/core/platform/windows/windows_file_system.h2
-rw-r--r--tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h2
-rw-r--r--tensorflow/core/profiler/internal/advisor/tfprof_advisor.h6
-rw-r--r--tensorflow/core/profiler/internal/tfprof_code.cc4
-rw-r--r--tensorflow/core/profiler/tfprof_options.h6
-rw-r--r--tensorflow/core/protobuf/config.proto9
-rw-r--r--tensorflow/core/protobuf/debug.proto6
-rw-r--r--tensorflow/core/public/session.h6
-rw-r--r--tensorflow/core/public/version.h4
-rw-r--r--tensorflow/core/util/activation_mode.h6
-rw-r--r--tensorflow/core/util/bcast.h6
-rw-r--r--tensorflow/core/util/command_line_flags.cc2
-rw-r--r--tensorflow/core/util/ctc/ctc_beam_entry.h2
-rw-r--r--tensorflow/core/util/ctc/ctc_beam_scorer.h2
-rw-r--r--tensorflow/core/util/ctc/ctc_beam_search.h19
-rw-r--r--tensorflow/core/util/ctc/ctc_decoder.h2
-rw-r--r--tensorflow/core/util/ctc/ctc_loss_util.h2
-rw-r--r--tensorflow/core/util/device_name_utils.h6
-rw-r--r--tensorflow/core/util/env_var.cc8
-rw-r--r--tensorflow/core/util/env_var.h5
-rw-r--r--tensorflow/core/util/events_writer.h6
-rw-r--r--tensorflow/core/util/example_proto_fast_parsing.cc230
-rw-r--r--tensorflow/core/util/example_proto_fast_parsing.h3
-rw-r--r--tensorflow/core/util/example_proto_fast_parsing_test.cc2
-rw-r--r--tensorflow/core/util/example_proto_helper.cc53
-rw-r--r--tensorflow/core/util/example_proto_helper.h61
-rw-r--r--tensorflow/core/util/guarded_philox_random.h6
-rw-r--r--tensorflow/core/util/mirror_pad_mode.h6
-rw-r--r--tensorflow/core/util/mkl_util.h164
-rw-r--r--tensorflow/core/util/padding.h6
-rw-r--r--tensorflow/core/util/port.h6
-rw-r--r--tensorflow/core/util/saved_tensor_slice_util.h6
-rw-r--r--tensorflow/core/util/strided_slice_op.cc2
-rw-r--r--tensorflow/core/util/tensor_bundle/naming.h6
-rw-r--r--tensorflow/core/util/tensor_bundle/tensor_bundle.cc16
-rw-r--r--tensorflow/core/util/tensor_bundle/tensor_bundle.h6
-rw-r--r--tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc2
-rw-r--r--tensorflow/core/util/tensor_format.cc4
-rw-r--r--tensorflow/core/util/tensor_format.h1
-rw-r--r--tensorflow/core/util/tensor_slice_reader.h6
-rw-r--r--tensorflow/core/util/tensor_slice_reader_cache.h6
-rw-r--r--tensorflow/core/util/tensor_slice_writer.h6
-rw-r--r--tensorflow/core/util/util.h6
-rw-r--r--tensorflow/core/util/work_sharder.h6
-rw-r--r--tensorflow/docs_src/README.md3
-rw-r--r--tensorflow/docs_src/about/attribution.md9
-rw-r--r--tensorflow/docs_src/about/bib.md131
-rw-r--r--tensorflow/docs_src/about/index.md11
-rw-r--r--tensorflow/docs_src/about/leftnav_files4
-rw-r--r--tensorflow/docs_src/about/uses.md68
-rw-r--r--tensorflow/docs_src/api_guides/cc/guide.md301
-rw-r--r--tensorflow/docs_src/api_guides/python/array_ops.md87
-rw-r--r--tensorflow/docs_src/api_guides/python/check_ops.md19
-rw-r--r--tensorflow/docs_src/api_guides/python/client.md36
-rw-r--r--tensorflow/docs_src/api_guides/python/constant_op.md87
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.crf.md11
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.ffmpeg.md23
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.framework.md64
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.graph_editor.md177
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.integrate.md41
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.layers.md109
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.learn.md63
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.linalg.md30
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.losses.md125
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.metrics.md133
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.rnn.md61
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.seq2seq.md138
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.signal.md172
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.staging.md6
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.training.md50
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.util.md12
-rw-r--r--tensorflow/docs_src/api_guides/python/control_flow_ops.md57
-rw-r--r--tensorflow/docs_src/api_guides/python/framework.md51
-rw-r--r--tensorflow/docs_src/api_guides/python/functional_ops.md18
-rw-r--r--tensorflow/docs_src/api_guides/python/image.md144
-rw-r--r--tensorflow/docs_src/api_guides/python/index.md52
-rw-r--r--tensorflow/docs_src/api_guides/python/input_dataset.md85
-rw-r--r--tensorflow/docs_src/api_guides/python/io_ops.md130
-rw-r--r--tensorflow/docs_src/api_guides/python/math_ops.md199
-rw-r--r--tensorflow/docs_src/api_guides/python/meta_graph.md277
-rw-r--r--tensorflow/docs_src/api_guides/python/nn.md418
-rw-r--r--tensorflow/docs_src/api_guides/python/python_io.md29
-rw-r--r--tensorflow/docs_src/api_guides/python/reading_data.md522
-rw-r--r--tensorflow/docs_src/api_guides/python/regression_examples.md232
-rw-r--r--tensorflow/docs_src/api_guides/python/session_ops.md15
-rw-r--r--tensorflow/docs_src/api_guides/python/sparse_ops.md45
-rw-r--r--tensorflow/docs_src/api_guides/python/spectral_ops.md26
-rw-r--r--tensorflow/docs_src/api_guides/python/state_ops.md110
-rw-r--r--tensorflow/docs_src/api_guides/python/string_ops.md39
-rw-r--r--tensorflow/docs_src/api_guides/python/summary.md23
-rw-r--r--tensorflow/docs_src/api_guides/python/test.md47
-rw-r--r--tensorflow/docs_src/api_guides/python/tfdbg.md50
-rw-r--r--tensorflow/docs_src/api_guides/python/threading_and_queues.md270
-rw-r--r--tensorflow/docs_src/api_guides/python/train.md139
-rw-r--r--tensorflow/docs_src/community/benchmarks.md108
-rw-r--r--tensorflow/docs_src/community/contributing.md49
-rw-r--r--tensorflow/docs_src/community/documentation.md673
-rw-r--r--tensorflow/docs_src/community/groups.md38
-rw-r--r--tensorflow/docs_src/community/index.md85
-rw-r--r--tensorflow/docs_src/community/leftnav_files8
-rw-r--r--tensorflow/docs_src/community/lists.md53
-rw-r--r--tensorflow/docs_src/community/roadmap.md121
-rw-r--r--tensorflow/docs_src/community/style_guide.md136
-rw-r--r--tensorflow/docs_src/deploy/deploy_to_js.md4
-rw-r--r--tensorflow/docs_src/deploy/distributed.md354
-rw-r--r--tensorflow/docs_src/deploy/hadoop.md65
-rw-r--r--tensorflow/docs_src/deploy/index.md21
-rw-r--r--tensorflow/docs_src/deploy/leftnav_files5
-rw-r--r--tensorflow/docs_src/deploy/s3.md93
-rw-r--r--tensorflow/docs_src/extend/add_filesys.md260
-rw-r--r--tensorflow/docs_src/extend/adding_an_op.md1460
-rw-r--r--tensorflow/docs_src/extend/architecture.md217
-rw-r--r--tensorflow/docs_src/extend/index.md34
-rw-r--r--tensorflow/docs_src/extend/language_bindings.md231
-rw-r--r--tensorflow/docs_src/extend/leftnav_files7
-rw-r--r--tensorflow/docs_src/extend/new_data_formats.md305
-rw-r--r--tensorflow/docs_src/extend/tool_developers/index.md186
-rw-r--r--tensorflow/docs_src/extras/README.txt3
-rw-r--r--tensorflow/docs_src/guide/autograph.md3
-rw-r--r--tensorflow/docs_src/guide/checkpoints.md238
-rw-r--r--tensorflow/docs_src/guide/custom_estimators.md602
-rw-r--r--tensorflow/docs_src/guide/datasets.md823
-rw-r--r--tensorflow/docs_src/guide/datasets_for_estimators.md387
-rw-r--r--tensorflow/docs_src/guide/debugger.md814
-rw-r--r--tensorflow/docs_src/guide/eager.md854
-rw-r--r--tensorflow/docs_src/guide/embedding.md262
-rw-r--r--tensorflow/docs_src/guide/estimators.md196
-rw-r--r--tensorflow/docs_src/guide/faq.md296
-rw-r--r--tensorflow/docs_src/guide/feature_columns.md572
-rw-r--r--tensorflow/docs_src/guide/graph_viz.md317
-rw-r--r--tensorflow/docs_src/guide/graphs.md558
-rw-r--r--tensorflow/docs_src/guide/index.md82
-rw-r--r--tensorflow/docs_src/guide/keras.md623
-rw-r--r--tensorflow/docs_src/guide/leftnav_files41
-rw-r--r--tensorflow/docs_src/guide/low_level_intro.md604
-rw-r--r--tensorflow/docs_src/guide/premade_estimators.md430
-rw-r--r--tensorflow/docs_src/guide/saved_model.md999
-rw-r--r--tensorflow/docs_src/guide/summaries_and_tensorboard.md225
-rw-r--r--tensorflow/docs_src/guide/tensorboard_histograms.md245
-rw-r--r--tensorflow/docs_src/guide/tensors.md330
-rw-r--r--tensorflow/docs_src/guide/using_gpu.md215
-rw-r--r--tensorflow/docs_src/guide/using_tpu.md395
-rw-r--r--tensorflow/docs_src/guide/variables.md319
-rw-r--r--tensorflow/docs_src/guide/version_compat.md322
-rw-r--r--tensorflow/docs_src/install/index.md39
-rw-r--r--tensorflow/docs_src/install/install_c.md118
-rw-r--r--tensorflow/docs_src/install/install_go.md142
-rw-r--r--tensorflow/docs_src/install/install_java.md268
-rw-r--r--tensorflow/docs_src/install/install_linux.md714
-rw-r--r--tensorflow/docs_src/install/install_mac.md529
-rw-r--r--tensorflow/docs_src/install/install_raspbian.md313
-rw-r--r--tensorflow/docs_src/install/install_sources.md577
-rw-r--r--tensorflow/docs_src/install/install_sources_windows.md320
-rw-r--r--tensorflow/docs_src/install/install_windows.md227
-rw-r--r--tensorflow/docs_src/install/leftnav_files18
-rw-r--r--tensorflow/docs_src/install/migration.md336
-rw-r--r--tensorflow/docs_src/mobile/README.md3
-rw-r--r--tensorflow/docs_src/performance/benchmarks.md412
-rw-r--r--tensorflow/docs_src/performance/datasets_performance.md331
-rw-r--r--tensorflow/docs_src/performance/index.md52
-rw-r--r--tensorflow/docs_src/performance/leftnav_files14
-rw-r--r--tensorflow/docs_src/performance/performance_guide.md733
-rw-r--r--tensorflow/docs_src/performance/performance_models.md422
-rw-r--r--tensorflow/docs_src/performance/quantization.md253
-rw-r--r--tensorflow/docs_src/performance/xla/broadcasting.md204
-rw-r--r--tensorflow/docs_src/performance/xla/developing_new_backend.md77
-rw-r--r--tensorflow/docs_src/performance/xla/index.md98
-rw-r--r--tensorflow/docs_src/performance/xla/jit.md169
-rw-r--r--tensorflow/docs_src/performance/xla/operation_semantics.md2433
-rw-r--r--tensorflow/docs_src/performance/xla/shapes.md150
-rw-r--r--tensorflow/docs_src/performance/xla/tfcompile.md281
-rw-r--r--tensorflow/docs_src/tutorials/_index.yaml202
-rw-r--r--tensorflow/docs_src/tutorials/_toc.yaml124
-rw-r--r--tensorflow/docs_src/tutorials/eager/custom_training_walkthrough.md3
-rw-r--r--tensorflow/docs_src/tutorials/eager/index.md13
-rw-r--r--tensorflow/docs_src/tutorials/estimators/cnn.md694
-rw-r--r--tensorflow/docs_src/tutorials/estimators/linear.md3
-rw-r--r--tensorflow/docs_src/tutorials/images/deep_cnn.md446
-rw-r--r--tensorflow/docs_src/tutorials/images/image_recognition.md455
-rw-r--r--tensorflow/docs_src/tutorials/keras/basic_classification.md3
-rw-r--r--tensorflow/docs_src/tutorials/keras/basic_regression.md3
-rw-r--r--tensorflow/docs_src/tutorials/keras/basic_text_classification.md3
-rw-r--r--tensorflow/docs_src/tutorials/keras/index.md22
-rw-r--r--tensorflow/docs_src/tutorials/keras/overfit_and_underfit.md3
-rw-r--r--tensorflow/docs_src/tutorials/keras/save_and_restore_models.md3
-rw-r--r--tensorflow/docs_src/tutorials/next_steps.md36
-rw-r--r--tensorflow/docs_src/tutorials/non-ml/mandelbrot.md116
-rw-r--r--tensorflow/docs_src/tutorials/non-ml/pdes.md140
-rw-r--r--tensorflow/docs_src/tutorials/representation/kernel_methods.md303
-rw-r--r--tensorflow/docs_src/tutorials/representation/linear.md239
-rw-r--r--tensorflow/docs_src/tutorials/representation/word2vec.md405
-rw-r--r--tensorflow/docs_src/tutorials/sequences/audio_recognition.md631
-rw-r--r--tensorflow/docs_src/tutorials/sequences/recurrent.md230
-rw-r--r--tensorflow/docs_src/tutorials/sequences/recurrent_quickdraw.md411
-rw-r--r--tensorflow/examples/adding_an_op/cuda_op_test.py2
-rw-r--r--tensorflow/examples/adding_an_op/fact_test.py2
-rw-r--r--tensorflow/examples/adding_an_op/zero_out_1_test.py2
-rw-r--r--tensorflow/examples/adding_an_op/zero_out_2_test.py8
-rw-r--r--tensorflow/examples/adding_an_op/zero_out_3_test.py8
-rw-r--r--tensorflow/examples/android/README.md8
-rw-r--r--tensorflow/examples/android/jni/object_tracking/jni_utils.h2
-rw-r--r--tensorflow/examples/android/jni/object_tracking/logging.h6
-rw-r--r--tensorflow/examples/android/jni/object_tracking/object_model.h6
-rwxr-xr-xtensorflow/examples/android/jni/rgb2yuv.h6
-rw-r--r--tensorflow/examples/android/jni/yuv2rgb.h6
-rw-r--r--tensorflow/examples/ios/benchmark/ios_image_load.h6
-rw-r--r--tensorflow/examples/ios/camera/ios_image_load.h6
-rw-r--r--tensorflow/examples/label_image/main.cc2
-rw-r--r--tensorflow/examples/speech_commands/models.py2
-rw-r--r--tensorflow/go/op/wrappers.go1736
-rw-r--r--tensorflow/java/maven/pom.xml4
-rw-r--r--tensorflow/java/maven/run_inside_container.sh6
-rw-r--r--tensorflow/java/maven/spark-tensorflow-connector/pom.xml (renamed from tensorflow/java/maven/spark-connector/pom.xml)6
-rw-r--r--tensorflow/java/maven/tensorflow-hadoop/pom.xml (renamed from tensorflow/java/maven/hadoop/pom.xml)4
-rw-r--r--tensorflow/java/src/gen/cc/java_defs.h30
-rw-r--r--tensorflow/java/src/gen/cc/op_generator.cc74
-rw-r--r--tensorflow/java/src/gen/cc/op_specs.cc42
-rw-r--r--tensorflow/java/src/gen/cc/op_specs.h14
-rw-r--r--tensorflow/java/src/gen/cc/source_writer.cc1
-rw-r--r--tensorflow/java/src/main/native/exception_jni.h6
-rw-r--r--tensorflow/java/src/main/native/graph_jni.h6
-rw-r--r--tensorflow/java/src/main/native/operation_builder_jni.h6
-rw-r--r--tensorflow/java/src/main/native/operation_jni.h6
-rw-r--r--tensorflow/java/src/main/native/saved_model_bundle_jni.h6
-rw-r--r--tensorflow/java/src/main/native/session_jni.h6
-rw-r--r--tensorflow/java/src/main/native/tensor_jni.h6
-rw-r--r--tensorflow/java/src/main/native/tensorflow_jni.h6
-rw-r--r--tensorflow/java/src/main/native/utils_jni.h6
-rw-r--r--tensorflow/js/BUILD52
-rw-r--r--tensorflow/js/ops/ts_op_gen.cc290
-rw-r--r--tensorflow/js/ops/ts_op_gen.h (renamed from tensorflow/core/kernels/warn_about_ints.h)18
-rw-r--r--tensorflow/js/ops/ts_op_gen_test.cc246
-rw-r--r--tensorflow/python/BUILD79
-rw-r--r--tensorflow/python/__init__.py7
-rw-r--r--tensorflow/python/client/client_lib.py2
-rw-r--r--tensorflow/python/client/session.py8
-rw-r--r--tensorflow/python/client/session_test.py6
-rw-r--r--tensorflow/python/compat/compat.py12
-rw-r--r--tensorflow/python/data/__init__.py2
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD3
-rw-r--r--tensorflow/python/data/kernel_tests/batch_dataset_op_test.py22
-rw-r--r--tensorflow/python/data/kernel_tests/cache_dataset_op_test.py14
-rw-r--r--tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py24
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py16
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py28
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_ops_test.py2
-rw-r--r--tensorflow/python/data/kernel_tests/filter_dataset_op_test.py14
-rw-r--r--tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py8
-rw-r--r--tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py167
-rw-r--r--tensorflow/python/data/kernel_tests/iterator_ops_test.py91
-rw-r--r--tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py18
-rw-r--r--tensorflow/python/data/kernel_tests/map_dataset_op_test.py216
-rw-r--r--tensorflow/python/data/kernel_tests/optional_ops_test.py2
-rw-r--r--tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py4
-rw-r--r--tensorflow/python/data/kernel_tests/range_dataset_op_test.py48
-rw-r--r--tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py50
-rw-r--r--tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py10
-rw-r--r--tensorflow/python/data/kernel_tests/shard_dataset_op_test.py14
-rw-r--r--tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py12
-rw-r--r--tensorflow/python/data/kernel_tests/zip_dataset_op_test.py4
-rw-r--r--tensorflow/python/data/ops/BUILD1
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py81
-rw-r--r--tensorflow/python/data/util/BUILD35
-rw-r--r--tensorflow/python/data/util/convert_test.py16
-rw-r--r--tensorflow/python/data/util/nest.py37
-rw-r--r--tensorflow/python/data/util/nest_test.py27
-rw-r--r--tensorflow/python/data/util/sparse_test.py2
-rw-r--r--tensorflow/python/data/util/structure.py315
-rw-r--r--tensorflow/python/data/util/structure_test.py327
-rw-r--r--tensorflow/python/debug/BUILD19
-rw-r--r--tensorflow/python/debug/__init__.py2
-rw-r--r--tensorflow/python/debug/examples/debug_tflearn_iris.py9
-rw-r--r--tensorflow/python/debug/lib/debug_utils.py12
-rw-r--r--tensorflow/python/debug/wrappers/disk_usage_test.py109
-rw-r--r--tensorflow/python/debug/wrappers/framework.py25
-rw-r--r--tensorflow/python/debug/wrappers/hooks.py5
-rw-r--r--tensorflow/python/debug/wrappers/local_cli_wrapper.py5
-rw-r--r--tensorflow/python/distribute/BUILD83
-rw-r--r--tensorflow/python/distribute/distribute_config.py45
-rw-r--r--tensorflow/python/distribute/distribute_coordinator.py481
-rw-r--r--tensorflow/python/distribute/distribute_coordinator_context.py (renamed from tensorflow/contrib/kfac/python/ops/op_queue_lib.py)21
-rw-r--r--tensorflow/python/distribute/distribute_coordinator_test.py452
-rw-r--r--tensorflow/python/distribute/estimator_training.py264
-rw-r--r--tensorflow/python/distribute/multi_worker_util.py80
-rw-r--r--tensorflow/python/distribute/multi_worker_util_test.py107
-rw-r--r--tensorflow/python/eager/BUILD39
-rw-r--r--tensorflow/python/eager/backprop.py157
-rw-r--r--tensorflow/python/eager/backprop_test.py101
-rw-r--r--tensorflow/python/eager/benchmarks_test.py55
-rw-r--r--tensorflow/python/eager/context.py41
-rw-r--r--tensorflow/python/eager/core_test.py24
-rw-r--r--tensorflow/python/eager/execution_callbacks.py8
-rw-r--r--tensorflow/python/eager/function.py875
-rw-r--r--tensorflow/python/eager/function_test.py290
-rw-r--r--tensorflow/python/eager/graph_callable.py435
-rw-r--r--tensorflow/python/eager/graph_callable_test.py249
-rw-r--r--tensorflow/python/eager/graph_only_ops_test.py4
-rw-r--r--tensorflow/python/eager/imperative_grad.py10
-rw-r--r--tensorflow/python/eager/pywrap_tensor.cc90
-rwxr-xr-x[-rw-r--r--]tensorflow/python/eager/pywrap_tfe.h29
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc375
-rw-r--r--tensorflow/python/eager/tape.py26
-rw-r--r--tensorflow/python/eager/tape_test.py4
-rw-r--r--tensorflow/python/eager/tensor_test.py15
-rw-r--r--tensorflow/python/estimator/BUILD4
-rw-r--r--tensorflow/python/estimator/canned/baseline_test.py10
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees.py263
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees_test.py36
-rw-r--r--tensorflow/python/estimator/canned/dnn.py14
-rw-r--r--tensorflow/python/estimator/canned/dnn_linear_combined.py13
-rw-r--r--tensorflow/python/estimator/canned/dnn_testing_utils.py6
-rw-r--r--tensorflow/python/estimator/canned/head.py4
-rw-r--r--tensorflow/python/estimator/canned/head_test.py208
-rw-r--r--tensorflow/python/estimator/canned/linear_testing_utils.py10
-rw-r--r--tensorflow/python/estimator/canned/prediction_keys.py1
-rw-r--r--tensorflow/python/estimator/estimator.py391
-rw-r--r--tensorflow/python/estimator/estimator_test.py131
-rw-r--r--tensorflow/python/estimator/export/export.py45
-rw-r--r--tensorflow/python/estimator/export/export_output.py24
-rw-r--r--tensorflow/python/estimator/export/export_output_test.py89
-rw-r--r--tensorflow/python/estimator/export/export_test.py30
-rw-r--r--tensorflow/python/estimator/exporter_test.py37
-rw-r--r--tensorflow/python/estimator/gc.py8
-rw-r--r--tensorflow/python/estimator/gc_test.py11
-rw-r--r--tensorflow/python/estimator/inputs/numpy_io_test.py34
-rw-r--r--tensorflow/python/estimator/inputs/pandas_io_test.py24
-rw-r--r--tensorflow/python/estimator/keras.py374
-rw-r--r--tensorflow/python/estimator/keras_test.py112
-rw-r--r--tensorflow/python/estimator/model_fn.py111
-rw-r--r--tensorflow/python/estimator/model_fn_test.py155
-rw-r--r--tensorflow/python/estimator/run_config.py36
-rw-r--r--tensorflow/python/estimator/training.py37
-rw-r--r--tensorflow/python/estimator/training_test.py33
-rw-r--r--tensorflow/python/estimator/util.py8
-rw-r--r--tensorflow/python/estimator/util_test.py4
-rw-r--r--tensorflow/python/feature_column/BUILD1
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py122
-rw-r--r--tensorflow/python/feature_column/feature_column_v2.py564
-rw-r--r--tensorflow/python/feature_column/feature_column_v2_test.py827
-rw-r--r--tensorflow/python/framework/constant_op.py16
-rw-r--r--tensorflow/python/framework/device.py38
-rw-r--r--tensorflow/python/framework/error_interpolation.py85
-rw-r--r--tensorflow/python/framework/error_interpolation_test.py116
-rw-r--r--tensorflow/python/framework/errors_impl.py9
-rw-r--r--tensorflow/python/framework/errors_test.py29
-rw-r--r--tensorflow/python/framework/file_system_test.py2
-rw-r--r--tensorflow/python/framework/function.py23
-rw-r--r--tensorflow/python/framework/function_def_to_graph.py20
-rw-r--r--tensorflow/python/framework/function_def_to_graph_test.py49
-rw-r--r--tensorflow/python/framework/function_test.py36
-rw-r--r--tensorflow/python/framework/importer_test.py22
-rw-r--r--tensorflow/python/framework/meta_graph_test.py21
-rw-r--r--tensorflow/python/framework/ops.py122
-rw-r--r--tensorflow/python/framework/ops_enable_eager_test.py (renamed from tensorflow/contrib/eager/python/examples/scan/scan_test.py)42
-rw-r--r--tensorflow/python/framework/ops_test.py83
-rw-r--r--tensorflow/python/framework/python_op_gen.cc9
-rw-r--r--tensorflow/python/framework/python_op_gen_internal.cc22
-rw-r--r--tensorflow/python/framework/python_op_gen_main.cc2
-rw-r--r--tensorflow/python/framework/smart_cond.py6
-rw-r--r--tensorflow/python/framework/sparse_tensor.py23
-rw-r--r--tensorflow/python/framework/sparse_tensor_test.py20
-rw-r--r--tensorflow/python/framework/subscribe.py7
-rw-r--r--tensorflow/python/framework/subscribe_test.py14
-rw-r--r--tensorflow/python/framework/tensor_shape.py7
-rw-r--r--tensorflow/python/framework/tensor_util.py2
-rw-r--r--tensorflow/python/framework/tensor_util_test.py2
-rw-r--r--tensorflow/python/framework/test_util.py421
-rw-r--r--tensorflow/python/framework/test_util_test.py70
-rw-r--r--tensorflow/python/grappler/cost_analyzer.h6
-rw-r--r--tensorflow/python/grappler/graph_analyzer.i (renamed from tensorflow/core/lib/gtl/optional.cc)19
-rw-r--r--tensorflow/python/grappler/graph_analyzer.py46
-rw-r--r--tensorflow/python/grappler/model_analyzer.h6
-rwxr-xr-xtensorflow/python/keras/BUILD20
-rw-r--r--tensorflow/python/keras/activations_test.py20
-rw-r--r--tensorflow/python/keras/applications/__init__.py51
-rw-r--r--tensorflow/python/keras/applications/applications_test.py8
-rw-r--r--tensorflow/python/keras/applications/densenet.py47
-rw-r--r--tensorflow/python/keras/applications/imagenet_utils.py33
-rw-r--r--tensorflow/python/keras/applications/inception_resnet_v2.py26
-rw-r--r--tensorflow/python/keras/applications/inception_v3.py25
-rw-r--r--tensorflow/python/keras/applications/mobilenet.py25
-rw-r--r--tensorflow/python/keras/applications/mobilenet_v2.py24
-rw-r--r--tensorflow/python/keras/applications/nasnet.py35
-rw-r--r--tensorflow/python/keras/applications/resnet50.py24
-rw-r--r--tensorflow/python/keras/applications/vgg16.py24
-rw-r--r--tensorflow/python/keras/applications/vgg19.py24
-rw-r--r--tensorflow/python/keras/applications/xception.py25
-rw-r--r--tensorflow/python/keras/backend.py75
-rw-r--r--tensorflow/python/keras/backend_test.py111
-rw-r--r--tensorflow/python/keras/callbacks_test.py107
-rw-r--r--tensorflow/python/keras/constraints_test.py8
-rw-r--r--tensorflow/python/keras/engine/base_layer.py50
-rw-r--r--tensorflow/python/keras/engine/distributed_training_utils.py99
-rw-r--r--tensorflow/python/keras/engine/feature_columns_integration_test.py237
-rw-r--r--tensorflow/python/keras/engine/network.py23
-rw-r--r--tensorflow/python/keras/engine/saving.py2
-rw-r--r--tensorflow/python/keras/engine/saving_test.py42
-rw-r--r--tensorflow/python/keras/engine/sequential.py5
-rw-r--r--tensorflow/python/keras/engine/sequential_test.py43
-rw-r--r--tensorflow/python/keras/engine/topology_test.py36
-rw-r--r--tensorflow/python/keras/engine/training.py345
-rw-r--r--tensorflow/python/keras/engine/training_arrays.py5
-rw-r--r--tensorflow/python/keras/engine/training_distributed.py625
-rw-r--r--tensorflow/python/keras/engine/training_eager.py23
-rw-r--r--tensorflow/python/keras/engine/training_test.py1551
-rw-r--r--tensorflow/python/keras/engine/training_utils.py141
-rw-r--r--tensorflow/python/keras/engine/training_utils_test.py89
-rw-r--r--tensorflow/python/keras/initializers.py88
-rw-r--r--tensorflow/python/keras/initializers_test.py43
-rw-r--r--tensorflow/python/keras/integration_test.py22
-rw-r--r--tensorflow/python/keras/layers/advanced_activations_test.py18
-rw-r--r--tensorflow/python/keras/layers/convolutional_recurrent_test.py12
-rw-r--r--tensorflow/python/keras/layers/core_test.py20
-rw-r--r--tensorflow/python/keras/layers/embeddings_test.py2
-rw-r--r--tensorflow/python/keras/layers/gru_test.py8
-rw-r--r--tensorflow/python/keras/layers/local_test.py8
-rw-r--r--tensorflow/python/keras/layers/lstm_test.py22
-rw-r--r--tensorflow/python/keras/layers/merge_test.py6
-rw-r--r--tensorflow/python/keras/layers/noise_test.py4
-rw-r--r--tensorflow/python/keras/layers/normalization.py14
-rw-r--r--tensorflow/python/keras/layers/normalization_test.py18
-rw-r--r--tensorflow/python/keras/layers/recurrent.py163
-rw-r--r--tensorflow/python/keras/layers/recurrent_test.py75
-rw-r--r--tensorflow/python/keras/layers/simplernn_test.py8
-rw-r--r--tensorflow/python/keras/layers/wrappers.py26
-rw-r--r--tensorflow/python/keras/layers/wrappers_test.py38
-rw-r--r--tensorflow/python/keras/losses_test.py8
-rw-r--r--tensorflow/python/keras/metrics.py72
-rw-r--r--tensorflow/python/keras/metrics_test.py57
-rw-r--r--tensorflow/python/keras/models.py229
-rw-r--r--tensorflow/python/keras/models_test.py224
-rw-r--r--tensorflow/python/keras/optimizers.py16
-rw-r--r--tensorflow/python/keras/optimizers_test.py19
-rw-r--r--tensorflow/python/keras/preprocessing/__init__.py2
-rw-r--r--tensorflow/python/keras/preprocessing/image.py492
-rw-r--r--tensorflow/python/keras/preprocessing/sequence.py63
-rw-r--r--tensorflow/python/keras/regularizers_test.py4
-rw-r--r--tensorflow/python/keras/testing_utils.py19
-rw-r--r--tensorflow/python/keras/utils/data_utils.py1
-rw-r--r--tensorflow/python/keras/utils/layer_utils.py1
-rw-r--r--tensorflow/python/keras/utils/multi_gpu_utils_test.py17
-rw-r--r--tensorflow/python/keras/utils/tf_utils.py5
-rw-r--r--tensorflow/python/kernel_tests/BUILD31
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py18
-rw-r--r--tensorflow/python/kernel_tests/batch_scatter_ops_test.py129
-rw-r--r--tensorflow/python/kernel_tests/broadcast_to_ops_test.py44
-rw-r--r--tensorflow/python/kernel_tests/check_ops_test.py239
-rw-r--r--tensorflow/python/kernel_tests/clip_ops_test.py9
-rw-r--r--tensorflow/python/kernel_tests/cond_v2_test.py31
-rw-r--r--tensorflow/python/kernel_tests/conditional_accumulator_test.py88
-rw-r--r--tensorflow/python/kernel_tests/constant_op_eager_test.py2
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py119
-rw-r--r--tensorflow/python/kernel_tests/ctc_decoder_ops_test.py18
-rw-r--r--tensorflow/python/kernel_tests/depthwise_conv_op_test.py13
-rw-r--r--tensorflow/python/kernel_tests/distributions/bernoulli_test.py196
-rw-r--r--tensorflow/python/kernel_tests/distributions/beta_test.py462
-rw-r--r--tensorflow/python/kernel_tests/distributions/bijector_test.py13
-rw-r--r--tensorflow/python/kernel_tests/distributions/categorical_test.py42
-rw-r--r--tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py56
-rw-r--r--tensorflow/python/kernel_tests/distributions/dirichlet_test.py262
-rw-r--r--tensorflow/python/kernel_tests/distributions/exponential_test.py187
-rw-r--r--tensorflow/python/kernel_tests/distributions/gamma_test.py529
-rw-r--r--tensorflow/python/kernel_tests/distributions/identity_bijector_test.py2
-rw-r--r--tensorflow/python/kernel_tests/distributions/kullback_leibler_test.py2
-rw-r--r--tensorflow/python/kernel_tests/distributions/laplace_test.py439
-rw-r--r--tensorflow/python/kernel_tests/distributions/multinomial_test.py46
-rw-r--r--tensorflow/python/kernel_tests/distributions/normal_test.py613
-rw-r--r--tensorflow/python/kernel_tests/distributions/special_math_test.py35
-rw-r--r--tensorflow/python/kernel_tests/distributions/student_t_test.py505
-rw-r--r--tensorflow/python/kernel_tests/distributions/uniform_test.py354
-rw-r--r--tensorflow/python/kernel_tests/distributions/util_test.py230
-rw-r--r--tensorflow/python/kernel_tests/dynamic_stitch_op_test.py21
-rw-r--r--tensorflow/python/kernel_tests/embedding_ops_test.py20
-rw-r--r--tensorflow/python/kernel_tests/extract_image_patches_grad_test.py20
-rw-r--r--tensorflow/python/kernel_tests/functional_ops_test.py470
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py4
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py16
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py14
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py46
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py2
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py4
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py2
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_test.py10
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py44
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_zeros_test.py12
-rw-r--r--tensorflow/python/kernel_tests/list_ops_test.py108
-rw-r--r--tensorflow/python/kernel_tests/matrix_logarithm_op_test.py30
-rw-r--r--tensorflow/python/kernel_tests/parameterized_truncated_normal_op_test.py13
-rw-r--r--tensorflow/python/kernel_tests/parsing_ops_test.py1158
-rw-r--r--tensorflow/python/kernel_tests/partitioned_variables_test.py111
-rw-r--r--tensorflow/python/kernel_tests/py_func_test.py99
-rw-r--r--tensorflow/python/kernel_tests/random/random_crop_test.py6
-rw-r--r--tensorflow/python/kernel_tests/random/random_gamma_test.py2
-rw-r--r--tensorflow/python/kernel_tests/random/random_grad_test.py4
-rw-r--r--tensorflow/python/kernel_tests/random/random_poisson_test.py4
-rw-r--r--tensorflow/python/kernel_tests/random/random_shuffle_queue_test.py114
-rw-r--r--tensorflow/python/kernel_tests/regex_full_match_op_test.py60
-rw-r--r--tensorflow/python/kernel_tests/regex_replace_op_test.py76
-rw-r--r--tensorflow/python/kernel_tests/relu_op_test.py30
-rw-r--r--tensorflow/python/kernel_tests/resource_variable_ops_test.py80
-rw-r--r--tensorflow/python/kernel_tests/rnn_test.py277
-rw-r--r--tensorflow/python/kernel_tests/slice_op_test.py19
-rw-r--r--tensorflow/python/kernel_tests/softplus_op_test.py10
-rw-r--r--tensorflow/python/kernel_tests/softsign_op_test.py10
-rw-r--r--tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py83
-rw-r--r--tensorflow/python/kernel_tests/sparse_matmul_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/sparse_ops_test.py17
-rw-r--r--tensorflow/python/kernel_tests/stack_op_test.py30
-rw-r--r--tensorflow/python/kernel_tests/string_split_op_test.py22
-rw-r--r--tensorflow/python/kernel_tests/template_test.py18
-rw-r--r--tensorflow/python/kernel_tests/topk_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py22
-rw-r--r--tensorflow/python/layers/base.py4
-rw-r--r--tensorflow/python/layers/convolutional_test.py8
-rw-r--r--tensorflow/python/layers/core_test.py23
-rw-r--r--tensorflow/python/layers/normalization_test.py28
-rw-r--r--tensorflow/python/lib/core/py_func.cc2
-rw-r--r--tensorflow/python/lib/core/py_seq_tensor.cc25
-rw-r--r--tensorflow/python/lib/core/py_util.h6
-rw-r--r--tensorflow/python/lib/io/file_io.i2
-rw-r--r--tensorflow/python/lib/io/py_record_reader.cc2
-rw-r--r--tensorflow/python/lib/io/py_record_writer.cc19
-rw-r--r--tensorflow/python/lib/io/py_record_writer.h7
-rw-r--r--tensorflow/python/lib/io/py_record_writer.i22
-rw-r--r--tensorflow/python/lib/io/python_io.py2
-rw-r--r--tensorflow/python/lib/io/tf_record.py112
-rw-r--r--tensorflow/python/lib/io/tf_record_test.py113
-rw-r--r--tensorflow/python/ops/array_grad.py100
-rw-r--r--tensorflow/python/ops/array_ops.py62
-rw-r--r--tensorflow/python/ops/check_ops.py58
-rw-r--r--tensorflow/python/ops/clip_ops.py6
-rw-r--r--tensorflow/python/ops/clip_ops_test.py2
-rw-r--r--tensorflow/python/ops/collective_ops_test.py15
-rw-r--r--tensorflow/python/ops/cond_v2.py2
-rw-r--r--tensorflow/python/ops/cond_v2_impl.py193
-rw-r--r--tensorflow/python/ops/control_flow_ops.py13
-rw-r--r--tensorflow/python/ops/control_flow_ops_test.py36
-rw-r--r--tensorflow/python/ops/custom_gradient.py10
-rw-r--r--tensorflow/python/ops/data_flow_ops.py20
-rw-r--r--tensorflow/python/ops/dequantize_op_test.py2
-rw-r--r--tensorflow/python/ops/distributions/distribution.py18
-rw-r--r--tensorflow/python/ops/embedding_ops.py7
-rw-r--r--tensorflow/python/ops/functional_ops.py5
-rw-r--r--tensorflow/python/ops/gradient_checker_test.py16
-rw-r--r--tensorflow/python/ops/gradients.py2
-rw-r--r--tensorflow/python/ops/gradients_impl.py45
-rw-r--r--tensorflow/python/ops/gradients_test.py69
-rw-r--r--tensorflow/python/ops/histogram_ops.py2
-rw-r--r--tensorflow/python/ops/histogram_ops_test.py8
-rw-r--r--tensorflow/python/ops/image_grad_test.py16
-rw-r--r--tensorflow/python/ops/image_ops.py2
-rw-r--r--tensorflow/python/ops/image_ops_impl.py2
-rw-r--r--tensorflow/python/ops/image_ops_test.py69
-rw-r--r--tensorflow/python/ops/init_ops.py109
-rw-r--r--tensorflow/python/ops/init_ops_test.py66
-rw-r--r--tensorflow/python/ops/io_ops.py40
-rw-r--r--tensorflow/python/ops/list_ops.py15
-rw-r--r--tensorflow/python/ops/lookup_ops.py11
-rw-r--r--tensorflow/python/ops/math_grad.py10
-rw-r--r--tensorflow/python/ops/math_grad_test.py45
-rw-r--r--tensorflow/python/ops/math_ops.py139
-rw-r--r--tensorflow/python/ops/math_ops_test.py39
-rw-r--r--tensorflow/python/ops/metrics_impl.py202
-rw-r--r--tensorflow/python/ops/nn.py2
-rw-r--r--tensorflow/python/ops/nn_batchnorm_test.py10
-rw-r--r--tensorflow/python/ops/nn_grad.py26
-rw-r--r--tensorflow/python/ops/nn_grad_test.py2
-rw-r--r--tensorflow/python/ops/nn_impl.py4
-rw-r--r--tensorflow/python/ops/nn_ops.py18
-rw-r--r--tensorflow/python/ops/nn_test.py58
-rw-r--r--tensorflow/python/ops/nn_xent_test.py10
-rw-r--r--tensorflow/python/ops/parallel_for/control_flow_ops.py13
-rw-r--r--tensorflow/python/ops/parallel_for/control_flow_ops_test.py2
-rw-r--r--tensorflow/python/ops/parallel_for/gradients.py9
-rw-r--r--tensorflow/python/ops/parallel_for/gradients_test.py13
-rw-r--r--tensorflow/python/ops/parallel_for/pfor.py6
-rw-r--r--tensorflow/python/ops/parsing_ops.py524
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py286
-rw-r--r--tensorflow/python/ops/rnn.py50
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py81
-rw-r--r--tensorflow/python/ops/script_ops.py11
-rw-r--r--tensorflow/python/ops/session_ops.py6
-rw-r--r--tensorflow/python/ops/sparse_ops.py119
-rw-r--r--tensorflow/python/ops/sparse_ops_test.py81
-rw-r--r--tensorflow/python/ops/state_ops.py188
-rw-r--r--tensorflow/python/ops/string_ops.py71
-rw-r--r--tensorflow/python/ops/variable_scope.py89
-rw-r--r--tensorflow/python/ops/variables.py512
-rw-r--r--tensorflow/python/platform/test.py2
-rwxr-xr-x[-rw-r--r--]tensorflow/python/pywrap_tfe.i21
-rw-r--r--tensorflow/python/saved_model/BUILD1
-rw-r--r--tensorflow/python/saved_model/loader_test.py18
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py170
-rw-r--r--tensorflow/python/saved_model/signature_def_utils_impl.py79
-rw-r--r--tensorflow/python/saved_model/signature_def_utils_test.py38
-rw-r--r--tensorflow/python/saved_model/simple_save_test.py4
-rw-r--r--tensorflow/python/saved_model/utils_impl.py2
-rw-r--r--tensorflow/python/summary/summary.py2
-rw-r--r--tensorflow/python/summary/summary_test.py20
-rw-r--r--tensorflow/python/summary/text_summary_test.py2
-rw-r--r--tensorflow/python/tensorflow.i1
-rw-r--r--tensorflow/python/tools/BUILD7
-rw-r--r--tensorflow/python/tools/api/generator/BUILD8
-rw-r--r--tensorflow/python/tools/api/generator/api_gen.bzl34
-rw-r--r--tensorflow/python/tools/api/generator/api_init_files.bzl2
-rw-r--r--tensorflow/python/tools/api/generator/api_init_files_v1.bzl1
-rw-r--r--tensorflow/python/tools/api/generator/doc_srcs.py2
-rw-r--r--tensorflow/python/tools/component_api_helper.py86
-rw-r--r--tensorflow/python/tools/freeze_graph.py68
-rw-r--r--tensorflow/python/tools/optimize_for_inference_lib.py8
-rw-r--r--tensorflow/python/tools/print_selective_registration_header_test.py12
-rw-r--r--tensorflow/python/tools/saved_model_cli.py25
-rw-r--r--tensorflow/python/tools/selective_registration_header_lib.py17
-rw-r--r--tensorflow/python/training/adadelta_test.py4
-rw-r--r--tensorflow/python/training/adagrad.py26
-rw-r--r--tensorflow/python/training/adagrad_da_test.py10
-rw-r--r--tensorflow/python/training/adagrad_test.py49
-rw-r--r--tensorflow/python/training/adam.py8
-rw-r--r--tensorflow/python/training/adam_test.py12
-rw-r--r--tensorflow/python/training/basic_session_run_hooks.py6
-rw-r--r--tensorflow/python/training/basic_session_run_hooks_test.py47
-rw-r--r--tensorflow/python/training/checkpoint_management.py18
-rw-r--r--tensorflow/python/training/checkpoint_management_test.py59
-rw-r--r--tensorflow/python/training/checkpoint_ops.py3
-rw-r--r--tensorflow/python/training/checkpoint_ops_test.py18
-rw-r--r--tensorflow/python/training/checkpoint_utils_test.py42
-rw-r--r--tensorflow/python/training/checkpointable/BUILD13
-rw-r--r--tensorflow/python/training/checkpointable/base.py192
-rw-r--r--tensorflow/python/training/checkpointable/data_structures.py56
-rw-r--r--tensorflow/python/training/checkpointable/data_structures_test.py105
-rw-r--r--tensorflow/python/training/checkpointable/layer_utils.py9
-rw-r--r--tensorflow/python/training/checkpointable/tracking_test.py2
-rw-r--r--tensorflow/python/training/checkpointable/util.py328
-rw-r--r--tensorflow/python/training/checkpointable/util_test.py95
-rw-r--r--tensorflow/python/training/distribute.py52
-rw-r--r--tensorflow/python/training/ftrl_test.py129
-rw-r--r--tensorflow/python/training/gradient_descent_test.py18
-rw-r--r--tensorflow/python/training/input.py83
-rw-r--r--tensorflow/python/training/input_test.py94
-rw-r--r--tensorflow/python/training/learning_rate_decay.py432
-rw-r--r--tensorflow/python/training/learning_rate_decay_test.py2
-rw-r--r--tensorflow/python/training/learning_rate_decay_v2.py898
-rw-r--r--tensorflow/python/training/learning_rate_decay_v2_test.py497
-rw-r--r--tensorflow/python/training/momentum_test.py14
-rw-r--r--tensorflow/python/training/monitored_session.py122
-rw-r--r--tensorflow/python/training/monitored_session_test.py176
-rw-r--r--tensorflow/python/training/moving_averages.py55
-rw-r--r--tensorflow/python/training/moving_averages_test.py51
-rw-r--r--tensorflow/python/training/optimizer.py7
-rw-r--r--tensorflow/python/training/optimizer_test.py8
-rw-r--r--tensorflow/python/training/proximal_adagrad_test.py18
-rw-r--r--tensorflow/python/training/proximal_gradient_descent_test.py16
-rw-r--r--tensorflow/python/training/queue_runner_impl.py22
-rw-r--r--tensorflow/python/training/queue_runner_test.py28
-rw-r--r--tensorflow/python/training/rmsprop_test.py4
-rw-r--r--tensorflow/python/training/saver.py18
-rw-r--r--tensorflow/python/training/saver_test.py144
-rw-r--r--tensorflow/python/training/session_manager_test.py28
-rw-r--r--tensorflow/python/training/slot_creator_test.py14
-rw-r--r--tensorflow/python/training/supervisor_test.py6
-rw-r--r--tensorflow/python/training/sync_replicas_optimizer.py6
-rw-r--r--tensorflow/python/training/training.py2
-rw-r--r--tensorflow/python/training/training_util.py2
-rw-r--r--tensorflow/python/training/warm_starting_util.py102
-rw-r--r--tensorflow/python/training/warm_starting_util_test.py216
-rw-r--r--tensorflow/python/util/nest.py76
-rw-r--r--tensorflow/python/util/nest_test.py2
-rw-r--r--tensorflow/python/util/tf_export.py13
-rw-r--r--tensorflow/python/util/tf_should_use_test.py5
-rw-r--r--tensorflow/python/util/util.cc44
-rw-r--r--tensorflow/python/util/util.h25
-rw-r--r--tensorflow/python/util/util.i55
-rw-r--r--tensorflow/stream_executor/blas.h1
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc22
-rw-r--r--tensorflow/stream_executor/cuda/cuda_gpu_executor.cc4
-rw-r--r--tensorflow/stream_executor/dso_loader.cc8
-rw-r--r--tensorflow/stream_executor/kernel.cc2
-rw-r--r--tensorflow/stream_executor/kernel_spec.cc6
-rw-r--r--tensorflow/stream_executor/lib/env.h2
-rw-r--r--tensorflow/stream_executor/lib/path.cc2
-rw-r--r--tensorflow/stream_executor/lib/statusor_internals.h1
-rw-r--r--tensorflow/stream_executor/lib/str_util.h2
-rw-r--r--tensorflow/stream_executor/stream.cc18
-rw-r--r--tensorflow/tensorflow.bzl45
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt10
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt10
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.-variable-aggregation.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.-variable.pbtxt28
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-classifier.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-regressor.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-classifier.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-regressor.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.estimator.-estimator.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-classifier.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-regressor.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.estimator.-run-config.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.glorot_normal_initializer.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.glorot_uniform_initializer.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.initializers.glorot_normal.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.initializers.glorot_uniform.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.initializers.pbtxt16
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.io.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-normal.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-uniform.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-truncated-normal.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.glorot_normal.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.glorot_uniform.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.normal.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.pbtxt16
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_normal.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_uniform.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.truncated_normal.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.uniform.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activation.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activity-regularization.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-add.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-alpha-dropout.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-concatenate.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dot.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dropout.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-e-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-embedding.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-flatten.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-dropout.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-noise.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-leaky-re-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-masking.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-maximum.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-minimum.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multiply.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-p-re-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-permute.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-re-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-repeat-vector.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-reshape.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-subtract.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-ordered-enqueuer.pbtxt26
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.pbtxt42
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-options.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt10
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt10
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-fixed-length-record-reader.pbtxt46
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-identity-reader.pbtxt46
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-l-m-d-b-reader.pbtxt46
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-reader-base.pbtxt45
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-t-f-record-reader.pbtxt46
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-text-line-reader.pbtxt46
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-variable-aggregation.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-variable.pbtxt28
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-whole-file-reader.pbtxt46
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-classifier.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-regressor.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-classifier.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-regressor.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.estimator.-estimator.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-classifier.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-regressor.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.glorot_normal_initializer.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.glorot_uniform_initializer.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_normal.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_uniform.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt16
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.io.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-normal.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-uniform.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-truncated-normal.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_normal.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_uniform.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.normal.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.pbtxt16
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_normal.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_uniform.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.truncated_normal.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.uniform.pbtxt3
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activation.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activity-regularization.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-add.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-alpha-dropout.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-concatenate.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dot.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dropout.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-e-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-embedding.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-flatten.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-dropout.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-noise.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-leaky-re-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-masking.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-maximum.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-minimum.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multiply.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-p-re-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-permute.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-repeat-vector.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-reshape.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-subtract.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-ordered-enqueuer.pbtxt26
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.pbtxt102
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-options.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.train.-queue-runner.pbtxt49
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt68
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.train.queue_runner.-queue-runner.pbtxt49
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.train.queue_runner.pbtxt15
-rw-r--r--tensorflow/tools/api/tests/BUILD5
-rw-r--r--tensorflow/tools/api/tests/api_compatibility_test.py14
-rw-r--r--tensorflow/tools/ci_build/Dockerfile.cmake4
-rw-r--r--tensorflow/tools/ci_build/Dockerfile.gpu2
-rw-r--r--tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.0483
-rw-r--r--tensorflow/tools/ci_build/Dockerfile.rbe.gpu6
-rwxr-xr-xtensorflow/tools/ci_build/ci_parameterized_build.sh28
-rwxr-xr-xtensorflow/tools/ci_build/ci_sanity.sh1
-rwxr-xr-xtensorflow/tools/ci_build/install/install_deb_packages.sh6
-rwxr-xr-xtensorflow/tools/ci_build/install/install_pip_packages.sh10
-rwxr-xr-xtensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh5
-rwxr-xr-xtensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh5
-rwxr-xr-xtensorflow/tools/ci_build/linux/libtensorflow_docker.sh1
-rwxr-xr-xtensorflow/tools/ci_build/update_version.py45
-rw-r--r--tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh2
-rw-r--r--tensorflow/tools/ci_build/windows/bazel/common_env.sh2
-rw-r--r--tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh15
-rw-r--r--tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh17
-rw-r--r--tensorflow/tools/common/public_api.py5
-rw-r--r--tensorflow/tools/compatibility/renames_v2.py1
-rw-r--r--tensorflow/tools/compatibility/testdata/test_file_v0_11.py16
-rw-r--r--tensorflow/tools/compatibility/testdata/test_file_v1_10.py2
-rw-r--r--tensorflow/tools/compatibility/tf_upgrade_v2.py24
-rw-r--r--tensorflow/tools/compatibility/tf_upgrade_v2_test.py13
-rw-r--r--tensorflow/tools/docker/Dockerfile6
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel6
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel-gpu17
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7117
-rwxr-xr-xtensorflow/tools/docker/Dockerfile.devel-mkl4
-rwxr-xr-xtensorflow/tools/docker/Dockerfile.devel-mkl-horovod4
-rw-r--r--tensorflow/tools/docker/Dockerfile.gpu13
-rwxr-xr-xtensorflow/tools/docker/Dockerfile.mkl4
-rwxr-xr-xtensorflow/tools/docker/Dockerfile.mkl-horovod4
-rw-r--r--tensorflow/tools/docker/README.md7
-rw-r--r--tensorflow/tools/dockerfiles/README.md67
-rw-r--r--tensorflow/tools/dockerfiles/assembler.Dockerfile (renamed from tensorflow/contrib/kfac/python/ops/estimator_lib.py)29
-rw-r--r--tensorflow/tools/dockerfiles/assembler.py554
-rw-r--r--tensorflow/tools/dockerfiles/bashrc50
-rw-r--r--tensorflow/tools/dockerfiles/dockerfiles/cpu-devel-jupyter.Dockerfile100
-rw-r--r--tensorflow/tools/dockerfiles/dockerfiles/cpu-devel.Dockerfile89
-rw-r--r--tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile69
-rw-r--r--tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile58
-rw-r--r--tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel-jupyter.Dockerfile126
-rw-r--r--tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel.Dockerfile115
-rw-r--r--tensorflow/tools/dockerfiles/dockerfiles/nvidia-jupyter.Dockerfile95
-rw-r--r--tensorflow/tools/dockerfiles/dockerfiles/nvidia.Dockerfile84
-rw-r--r--tensorflow/tools/dockerfiles/partials/bazel.partial.Dockerfile13
-rw-r--r--tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile8
-rw-r--r--tensorflow/tools/dockerfiles/partials/nvidia-devel.partial.Dockerfile49
-rw-r--r--tensorflow/tools/dockerfiles/partials/nvidia.partial.Dockerfile28
-rw-r--r--tensorflow/tools/dockerfiles/partials/python.partial.Dockerfile12
-rw-r--r--tensorflow/tools/dockerfiles/partials/shell.partial.Dockerfile2
-rw-r--r--tensorflow/tools/dockerfiles/partials/tensorflow.partial.Dockerfile2
-rw-r--r--tensorflow/tools/dockerfiles/partials/ubuntu-devel.partial.Dockerfile24
-rw-r--r--tensorflow/tools/dockerfiles/partials/ubuntu.partial.Dockerfile2
-rw-r--r--tensorflow/tools/dockerfiles/spec.yml195
-rw-r--r--tensorflow/tools/docs/doc_controls_test.py39
-rw-r--r--tensorflow/tools/docs/generate.py5
-rw-r--r--tensorflow/tools/docs/generate_lib.py99
-rw-r--r--tensorflow/tools/docs/parser.py122
-rw-r--r--tensorflow/tools/docs/parser_test.py132
-rw-r--r--tensorflow/tools/docs/pretty_docs.py71
-rw-r--r--tensorflow/tools/graph_transforms/fold_constants_lib.cc2
-rw-r--r--tensorflow/tools/graph_transforms/fold_constants_lib.h6
-rw-r--r--tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc2
-rw-r--r--tensorflow/tools/graph_transforms/sparsify_gather_test.cc4
-rw-r--r--tensorflow/tools/graph_transforms/transform_graph.cc15
-rw-r--r--tensorflow/tools/graph_transforms/transform_utils.cc2
-rw-r--r--tensorflow/tools/pip_package/BUILD19
-rw-r--r--tensorflow/tools/pip_package/MANIFEST.in1
-rw-r--r--tensorflow/tools/pip_package/setup.py26
-rw-r--r--tensorflow/tools/proto_text/gen_proto_text_functions_lib.h6
-rwxr-xr-x[-rw-r--r--]tensorflow/workspace.bzl176
3798 files changed, 133841 insertions, 107916 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 94e059b914..386e0096ff 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -12,6 +12,7 @@ exports_files([
# The leakr files are used by //third_party/cloud_tpu.
"leakr_badwords.dic",
"leakr_badfiles.dic",
+ "leakr_file_type_recipe.ftrcp",
])
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
@@ -23,6 +24,24 @@ load(
"//tensorflow/python/tools/api/generator:api_gen.bzl",
"gen_api_init_files", # @unused
)
+load("//tensorflow/python/tools/api/generator:api_gen.bzl", "get_compat_files")
+load(
+ "//tensorflow/python/tools/api/generator:api_init_files.bzl",
+ "TENSORFLOW_API_INIT_FILES", # @unused
+)
+load(
+ "//tensorflow/python/tools/api/generator:api_init_files_v1.bzl",
+ "TENSORFLOW_API_INIT_FILES_V1", # @unused
+)
+load(
+ "//third_party/ngraph:build_defs.bzl",
+ "if_ngraph",
+)
+
+# @unused
+TENSORFLOW_API_INIT_FILES_V2 = (
+ TENSORFLOW_API_INIT_FILES + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1)
+)
# Config setting used when building for products
# which require restricted licenses to be avoided.
@@ -411,12 +430,28 @@ config_setting(
visibility = ["//visibility:public"],
)
+# This flag is set from the configure step when the user selects the nGraph option.
+# By default it is false.
+config_setting(
+ name = "with_ngraph_support",
+ values = {"define": "with_ngraph_support=true"},
+ visibility = ["//visibility:public"],
+)
+
+# This flag specifies whether the TensorFlow 2.0 API should be built instead
+# of the 1.* API. Note that the TensorFlow 2.0 API is currently under development.
+config_setting(
+ name = "api_version_2",
+ define_values = {"tf_api_version": "2"},
+)
+
package_group(
name = "internal",
packages = [
"-//third_party/tensorflow/python/estimator",
"//learning/meta_rank/...",
"//tensorflow/...",
+ "//tensorflow_estimator/...",
"//tensorflow_fold/llgtm/...",
"//third_party/py/tensor2tensor/...",
],
@@ -563,7 +598,7 @@ tf_cc_shared_object(
"//tensorflow/cc:scope",
"//tensorflow/cc/profiler",
"//tensorflow/core:tensorflow",
- ],
+ ] + if_ngraph(["@ngraph_tf//:ngraph_tf"]),
)
exports_files(
@@ -574,12 +609,39 @@ exports_files(
)
gen_api_init_files(
- name = "tensorflow_python_api_gen",
+ name = "tf_python_api_gen_v1",
srcs = ["api_template.__init__.py"],
api_version = 1,
+ output_dir = "_api/v1/",
+ output_files = TENSORFLOW_API_INIT_FILES_V1,
+ output_package = "tensorflow._api.v1",
+ root_init_template = "api_template.__init__.py",
+)
+
+gen_api_init_files(
+ name = "tf_python_api_gen_v2",
+ srcs = ["api_template.__init__.py"],
+ api_version = 2,
+ compat_api_versions = [1],
+ output_dir = "_api/v2/",
+ output_files = TENSORFLOW_API_INIT_FILES_V2,
+ output_package = "tensorflow._api.v2",
root_init_template = "api_template.__init__.py",
)
+genrule(
+ name = "root_init_gen",
+ srcs = select({
+ "api_version_2": [":tf_python_api_gen_v2"],
+ "//conditions:default": [":tf_python_api_gen_v1"],
+ }),
+ outs = ["__init__.py"],
+ cmd = select({
+ "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)",
+ "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)",
+ }),
+)
+
py_library(
name = "tensorflow_py",
srcs = ["//tensorflow/python/estimator/api:estimator_python_api_gen"],
@@ -594,7 +656,10 @@ py_library(
py_library(
name = "tensorflow_py_no_contrib",
- srcs = [":tensorflow_python_api_gen"],
+ srcs = select({
+ "api_version_2": [":tf_python_api_gen_v2"],
+ "//conditions:default": [":tf_python_api_gen_v1"],
+ }) + [":root_init_gen"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = ["//tensorflow/python:no_contrib"],
diff --git a/tensorflow/__init__.py b/tensorflow/__init__.py
index 440e9f8dbd..21677512b6 100644
--- a/tensorflow/__init__.py
+++ b/tensorflow/__init__.py
@@ -28,7 +28,8 @@ contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
del LazyLoader
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
-app.flags = flags # pylint: disable=undefined-variable
+from tensorflow.python.platform import app # pylint: disable=g-import-not-at-top
+app.flags = flags
del absolute_import
del division
diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py
index 779f65d5b1..53a72b8443 100644
--- a/tensorflow/api_template.__init__.py
+++ b/tensorflow/api_template.__init__.py
@@ -18,11 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os as _os
+
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
try:
- import os # pylint: disable=g-import-not-at-top
# Add `estimator` attribute to allow access to estimator APIs via
# "tf.estimator..."
from tensorflow.python.estimator.api import estimator # pylint: disable=g-import-not-at-top
@@ -30,9 +31,8 @@ try:
# Add `estimator` to the __path__ to allow "from tensorflow.estimator..."
# style imports.
from tensorflow.python.estimator import api as estimator_api # pylint: disable=g-import-not-at-top
- __path__ += [os.path.dirname(estimator_api.__file__)]
+ __path__ += [_os.path.dirname(estimator_api.__file__)]
del estimator_api
- del os
except (ImportError, AttributeError):
print('tf.estimator package not installed.')
@@ -45,6 +45,12 @@ del LazyLoader
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
app.flags = flags # pylint: disable=undefined-variable
+# Make sure directory containing top level submodules is in
+# the __path__ so that "from tensorflow.foo import bar" works.
+_tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disable=undefined-variable
+if _tf_api_dir not in __path__:
+ __path__.append(_tf_api_dir)
+
del absolute_import
del division
del print_function
@@ -54,6 +60,12 @@ del print_function
# must come from this module. So python adds these symbols for the
# resolution to succeed.
# pylint: disable=undefined-variable
-del python
-del core
+try:
+ del python
+ del core
+except NameError:
+ # Don't fail if these modules are not available.
+ # For e.g. we are using this file for compat.v1 module as well and
+ # 'python', 'core' directories are not under compat/v1.
+ pass
# pylint: enable=undefined-variable
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index 8a9301d584..43c279bd80 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -117,6 +117,7 @@ tf_cuda_library(
deps = [
":c_api",
":c_api_internal",
+ "//tensorflow/c/eager:c_api",
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/contrib/tpu:all_ops",
"//tensorflow/core:core_cpu",
@@ -127,6 +128,15 @@ tf_cuda_library(
],
)
+cc_library(
+ name = "c_api_headers",
+ hdrs = [
+ "c_api.h",
+ ],
+ copts = tf_copts(),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
exports_files(
[
"version_script.lds",
@@ -194,6 +204,7 @@ tf_cuda_cc_test(
"//tensorflow:darwin": ["-headerpad_max_install_names"],
"//conditions:default": [],
}),
+ tags = ["noasan"],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 19ccb6e71d..173bbea596 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -202,7 +202,8 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
buf->len_ = len;
if (dtype != TF_STRING && dtype != TF_RESOURCE &&
tensorflow::DataTypeCanUseMemcpy(static_cast<DataType>(dtype)) &&
- reinterpret_cast<intptr_t>(data) % EIGEN_MAX_ALIGN_BYTES != 0) {
+ reinterpret_cast<intptr_t>(data) % std::max(1, EIGEN_MAX_ALIGN_BYTES) !=
+ 0) {
// TF_STRING and TF_RESOURCE tensors have a different representation in
// TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste
// (any alignment requirements will be taken care of by TF_TensorToTensor
@@ -1239,7 +1240,7 @@ void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name,
void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name,
const char* value, size_t length) {
tensorflow::NameAttrList func_name;
- func_name.set_name(std::string(value, value + length));
+ func_name.set_name(string(value, value + length));
desc->node_builder.Attr(attr_name, func_name);
}
@@ -2064,7 +2065,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
for (int i = 0; i < size; ++i) {
TensorId id = results.missing_unused_input_map_keys[i];
- tf_results->missing_unused_key_names_data.push_back(std::string(id.first));
+ tf_results->missing_unused_key_names_data.emplace_back(id.first);
tf_results->missing_unused_key_names[i] =
tf_results->missing_unused_key_names_data.back().c_str();
tf_results->missing_unused_key_indexes[i] = id.second;
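Note on the alignment change above: `x % 0` is undefined behavior in C++, and EIGEN_MAX_ALIGN_BYTES can legitimately be configured to 0 in builds with alignment disabled. A minimal standalone sketch of the guarded check (the names here are illustrative, not from the patch):

    #include <algorithm>
    #include <cstdint>

    // With max_align == 0, std::max(1, max_align) turns the test into
    // `addr % 1 != 0`, which is always false (no realignment copy needed),
    // instead of the undefined `addr % 0`.
    bool NeedsRealignment(const void* data, int max_align) {
      const intptr_t addr = reinterpret_cast<intptr_t>(data);
      return addr % std::max(1, max_align) != 0;
    }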
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 69b3ffe2a1..c046bd66cd 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -79,6 +79,18 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
auto* gpu_options = config.mutable_gpu_options();
gpu_options->set_allow_growth(gpu_memory_allow_growth);
+ // TODO(b/113217601): This is needed for EagerContext::runner_ to use a
+ // threadpool, so that we avoid the possibility of running the runner_ in the
+ // threadpool of GPU event mgr, as that can trigger more callbacks to be
+ // scheduled on that same threadpool, causing a deadlock in cases where the
+ // caller of event_mgr->ThenExecute() blocks on the completion of the callback
+ // (as in the case of ConstOp kernel creation on GPU, which involves copying a
+ // CPU tensor to GPU).
+ // Setting a larger thread pool does not help with the Swift caller, as we use
+ // a different TFE context for each thread of execution (for running graph
+ // functions, and their send/recv coroutines).
+ config.set_inter_op_parallelism_threads(1);
+
TF_Buffer* ret = TF_NewBuffer();
TF_CHECK_OK(MessageToBuffer(config, ret));
return ret;
@@ -8494,3 +8506,201 @@ void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
/*run_metadata*/ nullptr, status);
VLOG(1) << "Enqueuing is done.";
}
+
+TFE_Context* TFE_CreateContextFromSession(TF_Session* session,
+ TF_Status* status) {
+ auto* opts = TFE_NewContextOptions();
+
+ // Reduce GPU memory allocation, and set appropriate config options for TFE
+ // context.
+ auto* config =
+ TF_CreateConfig(/*xla*/ false, /* gpu_memory_allow_growth */ true);
+ TFE_ContextOptionsSetConfig(opts, config->data, config->length, status);
+ if (!status->status.ok()) {
+ CHECK(!config);
+ TFE_DeleteContextOptions(opts);
+ return nullptr;
+ }
+
+ auto* ctx = TFE_NewContextFromSession(opts, session, status);
+ TF_DeleteBuffer(config);
+ TFE_DeleteContextOptions(opts);
+ return ctx;
+}
+
+// TODO: retrieve the device string via TFE_ContextListDevices()
+static const char DEFAULT_CPU_DEVICE[] =
+ "/job:localhost/replica:0/task:0/device:CPU:0";
+
+static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType,
+ int tensor_id, TF_Status* status) {
+ std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> queueOp(
+ TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp);
+ TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status);
+ if (!status->status.ok()) return nullptr;
+ // TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler.
+ TFE_OpSetAttrInt(queueOp.get(), "capacity", 1);
+ TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1);
+ auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id);
+ TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(),
+ shared_name.size());
+ TFE_OpSetAttrString(queueOp.get(), "container", "", 0);
+
+ // TODO: consider making this an unknown shape.
+ const int64_t* dims_ptr = nullptr;
+ int num_dims = 0;
+ TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims,
+ /*num_values*/ 0, status);
+ if (!status->status.ok()) return nullptr;
+
+ int num_retvals = 1;
+ TFE_TensorHandle* queue = nullptr;
+ TFE_Execute(queueOp.get(), &queue, &num_retvals, status);
+ if (!status->status.ok()) return nullptr;
+ CHECK_EQ(num_retvals, 1);
+
+ return queue;
+}
+
+static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType,
+ TFE_TensorHandle* queue, TFE_TensorHandle* tensor,
+ TF_Status* status) {
+ TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status);
+ if (!status->status.ok()) return;
+ std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
+ TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
+ if (!status->status.ok()) return;
+ TFE_OpAddInput(op, queue, status);
+ if (!status->status.ok()) return;
+ TFE_OpAddInput(op, tensor, status);
+ if (!status->status.ok()) return;
+ TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1);
+ TFE_OpSetAttrInt(op, "timeout_ms", -1);
+
+ int num_retvals = 0;
+ TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status);
+ if (!status->status.ok()) return;
+ CHECK_EQ(num_retvals, 0);
+}
+
+static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx,
+ TF_DataType inputType,
+ TFE_TensorHandle* queue,
+ TF_Status* status) {
+ TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status);
+ if (!status->status.ok()) return nullptr;
+ std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
+ TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
+ if (!status->status.ok()) return nullptr;
+
+ TFE_OpAddInput(op, queue, status);
+ if (!status->status.ok()) return nullptr;
+ TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1);
+ TFE_OpSetAttrInt(op, "timeout_ms", -1);
+ TFE_TensorHandle* ret;
+ int num_retvals = 1;
+ TFE_Execute(op, &ret, &num_retvals, status);
+ if (!status->status.ok()) return nullptr;
+ CHECK_EQ(num_retvals, 1);
+ return ret;
+}
+
+TFE_TensorHandle* TFE_DequeueNamedTensor(TF_Session* session, int tensor_id,
+ TF_DataType inputType,
+ TF_Status* status) {
+ assert(session);
+ VLOG(1) << "Dequeuing data tensor with id " << tensor_id;
+
+ auto ctx = TFE_CreateContextFromSession(session, status);
+ if (!status->status.ok()) return nullptr;
+ std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
+ ctx, TFE_DeleteContext);
+
+ TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
+ if (!status->status.ok()) return nullptr;
+ std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+ queue_deleter(queue, TFE_DeleteTensorHandle);
+
+ auto* ret = createTFEDequeue(ctx, inputType, queue, status);
+ return ret;
+}
+
+TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
+ TF_DataType inputType,
+ TF_Status* status) {
+ TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
+ if (!status->status.ok()) return nullptr;
+ std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+ queue_deleter(queue, TFE_DeleteTensorHandle);
+
+ auto* ret = createTFEDequeue(ctx, inputType, queue, status);
+
+ return ret;
+}
+
+void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id,
+ TFE_TensorHandle* tensor, TF_Status* status) {
+ assert(session);
+ VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
+
+ auto ctx = TFE_CreateContextFromSession(session, status);
+ if (!status->status.ok()) return;
+ std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
+ ctx, TFE_DeleteContext);
+
+ TF_DataType inputType = TFE_TensorHandleDataType(tensor);
+ TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
+ if (!status->status.ok()) return;
+ std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+ queue_deleter(queue, TFE_DeleteTensorHandle);
+
+ createTFEEnqueue(ctx, inputType, queue, tensor, status);
+}
+
+void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
+ TFE_TensorHandle* tensor,
+ TF_Status* status) {
+ VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
+
+ TF_DataType inputType = TFE_TensorHandleDataType(tensor);
+ TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
+ if (!status->status.ok()) return;
+ std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+ queue_deleter(queue, TFE_DeleteTensorHandle);
+
+ createTFEEnqueue(ctx, inputType, queue, tensor, status);
+}
+
+void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id,
+ TFE_TensorHandle* tensor, TF_Status* status) {
+ VLOG(1) << "Enqueuing variant tensor with id " << tensor_id;
+
+ auto ctx = TFE_CreateContextFromSession(session, status);
+ if (!status->status.ok()) return;
+ std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
+ ctx, TFE_DeleteContext);
+
+ TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
+ if (!status->status.ok()) return;
+ std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+ queue_deleter(queue, TFE_DeleteTensorHandle);
+
+ createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status);
+}
+
+TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
+ TF_Status* status) {
+ VLOG(1) << "Dequeuing variant tensor with id " << tensor_id;
+
+ auto ctx = TFE_CreateContextFromSession(session, status);
+ if (!status->status.ok()) return nullptr;
+ std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
+ ctx, TFE_DeleteContext);
+
+ TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
+ if (!status->status.ok()) return nullptr;
+ std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+ queue_deleter(queue, TFE_DeleteTensorHandle);
+
+ return createTFEDequeue(ctx, TF_VARIANT, queue, status);
+}
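Taken together, the helpers above let a host program round-trip a tensor through a per-id FIFO queue that lives in the session's device manager. A minimal usage sketch, assuming a valid `TF_Session* session` and an existing float `TFE_TensorHandle* h` (both supplied by the caller, not shown here):

    TF_Status* status = TF_NewStatus();
    // Enqueue under id 0; internally this creates a short-lived eager context
    // from `session`, so repeated calls share the same queue resource.
    TFE_EnqueueNamedTensor(session, /*tensor_id=*/0, h, status);
    if (TF_GetCode(status) == TF_OK) {
      TFE_TensorHandle* out =
          TFE_DequeueNamedTensor(session, /*tensor_id=*/0, TF_FLOAT, status);
      if (TF_GetCode(status) == TF_OK) TFE_DeleteTensorHandle(out);
    }
    TF_DeleteStatus(status);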
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index 6617c5a572..522c91f67e 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <stdint.h>
#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/eager/c_api.h"
// --------------------------------------------------------------------------
// Experimental C API for TensorFlow.
@@ -131,6 +132,48 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
TF_Tensor* tensor,
TF_Status* status);
+// TODO: remove this API in favor of the next one.
+TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession(
+ const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status);
+
+// Creates from `session` a new eager context for running a graph function or
+// sends/recvs, so that these concurrent TFE executions can share (via
+// `session` and its associated device mgr) the same set of fifo queue resource
+// ops used for host<->TF tensor transfers. This way the send/recv calls and
+// graph function execution can access the same fifo queue resource handles
+// (associated with devices managed by the device manager, which can be
+// obtained from `session`).
+//
+// TODO: Remove this function once we migrate away from using session.
+TF_CAPI_EXPORT extern TFE_Context* TFE_CreateContextFromSession(
+ TF_Session* session, TF_Status* status);
+
+// TODO: Retire this API in favor of the next one.
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensor(
+ TF_Session* session, int tensor_id, TF_DataType inputType,
+ TF_Status* status);
+
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(
+ TFE_Context* ctx, int tensor_id, TF_DataType inputType, TF_Status* status);
+
+TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensor(TF_Session* session,
+ int tensor_id,
+ TFE_TensorHandle* tensor,
+ TF_Status* status);
+
+TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensorFromCtx(
+ TFE_Context* ctx, int tensor_id, TFE_TensorHandle* tensor,
+ TF_Status* status);
+
+// TODO: consider folding the 2 APIs below into the ones above.
+TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session,
+ int tensor_id,
+ TFE_TensorHandle* tensor,
+ TF_Status* status);
+
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
+ TF_Session* session, int tensor_id, TF_Status* status);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index aa2a537f03..03516c39dc 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -259,8 +259,8 @@ TEST(CAPI, DeprecatedSession) {
TF_Run(session, run_options, nullptr, nullptr, 0, nullptr, nullptr, 0,
nullptr, 0, run_metadata, s);
EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(std::string("Session was not created with a graph before Run()!"),
- std::string(TF_Message(s)));
+ EXPECT_EQ("Session was not created with a graph before Run()!",
+ string(TF_Message(s)));
TF_DeleteBuffer(run_metadata);
TF_DeleteBuffer(run_options);
@@ -1224,8 +1224,8 @@ class CApiColocationTest : public ::testing::Test {
TF_OperationGetAttrMetadata(op, tensorflow::kColocationAttrName, s_);
if (expected.empty()) {
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
- EXPECT_EQ(std::string("Operation 'add' has no attr named '_class'."),
- std::string(TF_Message(s_)));
+ EXPECT_EQ("Operation 'add' has no attr named '_class'.",
+ string(TF_Message(s_)));
return;
}
EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
@@ -1369,16 +1369,16 @@ TEST(CAPI, SavedModel) {
input.flat<string>()(i) = example.SerializeAsString();
}
- const tensorflow::string input_op_name =
- std::string(tensorflow::ParseTensorName(input_name).first);
+ const tensorflow::string input_op_name(
+ tensorflow::ParseTensorName(input_name).first);
TF_Operation* input_op =
TF_GraphOperationByName(graph, input_op_name.c_str());
ASSERT_TRUE(input_op != nullptr);
csession.SetInputs({{input_op, TF_TensorFromTensor(input, s)}});
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- const tensorflow::string output_op_name =
- std::string(tensorflow::ParseTensorName(output_name).first);
+ const tensorflow::string output_op_name(
+ tensorflow::ParseTensorName(output_name).first);
TF_Operation* output_op =
TF_GraphOperationByName(graph, output_op_name.c_str());
ASSERT_TRUE(output_op != nullptr);
diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc
index 74bc25a491..d3311f0cd0 100644
--- a/tensorflow/c/checkpoint_reader.cc
+++ b/tensorflow/c/checkpoint_reader.cc
@@ -125,7 +125,7 @@ CheckpointReader::BuildV2VarMaps() {
const auto& slice_proto = entry.slices(i);
CHECK(filtered_keys
.insert(EncodeTensorNameSlice(
- std::string(v2_reader_->key()) /* full var's name */,
+ string(v2_reader_->key()) /* full var's name */,
TensorSlice(slice_proto)))
.second);
}
@@ -138,11 +138,11 @@ CheckpointReader::BuildV2VarMaps() {
new TensorSliceReader::VarToDataTypeMap);
v2_reader_->Seek(kHeaderEntryKey);
for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
- if (filtered_keys.count(std::string(v2_reader_->key())) > 0) continue;
+ if (filtered_keys.count(string(v2_reader_->key())) > 0) continue;
CHECK(entry.ParseFromArray(v2_reader_->value().data(),
v2_reader_->value().size()))
<< entry.InitializationErrorString();
- string key = std::string(v2_reader_->key());
+ string key(v2_reader_->key());
(*var_to_shape_map)[key] = TensorShape(entry.shape());
(*var_to_data_type_map)[key] = DataType(entry.dtype());
}
diff --git a/tensorflow/c/checkpoint_reader.h b/tensorflow/c/checkpoint_reader.h
index 4de1300a7f..91654c8d4f 100644
--- a/tensorflow/c/checkpoint_reader.h
+++ b/tensorflow/c/checkpoint_reader.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_C_CHECKPOINT_READER_H
-#define TENSORFLOW_C_CHECKPOINT_READER_H
+#ifndef TENSORFLOW_C_CHECKPOINT_READER_H_
+#define TENSORFLOW_C_CHECKPOINT_READER_H_
#include <memory>
#include <string>
@@ -79,4 +79,4 @@ class CheckpointReader {
} // namespace checkpoint
} // namespace tensorflow
-#endif // TENSORFLOW_C_CHECKPOINT_READER_H
+#endif // TENSORFLOW_C_CHECKPOINT_READER_H_
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index dfb1c9a376..349d9bcd7c 100644..100755
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -244,8 +244,8 @@ void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
}
void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
- unsigned char async) {
- options->async = async;
+ unsigned char enable) {
+ options->async = enable;
}
void TFE_ContextOptionsSetDevicePlacementPolicy(
TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
@@ -253,9 +253,9 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
}
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
- unsigned char async,
+ unsigned char enable,
TF_Status* status) {
- status->status = ctx->context.SetAsyncForThread(async);
+ status->status = ctx->context.SetAsyncForThread(enable);
}
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
@@ -273,7 +273,20 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
new tensorflow::IntraProcessRendezvous(device_mgr.get());
return new TFE_Context(opts->session_options.options, opts->policy,
- opts->async, std::move(device_mgr), r);
+ opts->async, device_mgr.release(),
+ /*device_mgr_owned*/ true, r);
+}
+
+TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
+ TF_Session* sess, TF_Status* status) {
+ const tensorflow::DeviceMgr* device_mgr = nullptr;
+ status->status = sess->session->LocalDeviceManager(&device_mgr);
+ if (!status->status.ok()) return nullptr;
+ tensorflow::Rendezvous* r =
+ new tensorflow::IntraProcessRendezvous(device_mgr);
+ return new TFE_Context(opts->session_options.options, opts->policy,
+ opts->async, device_mgr, /*device_mgr_owned*/ false,
+ r);
}
void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
@@ -386,6 +399,19 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
: d->name().c_str();
}
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
+ TFE_TensorHandle* h, TF_Status* status) {
+ if (h == nullptr || h->handle == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "The passed in handle is a nullptr");
+ return nullptr;
+ }
+
+ h->handle->Ref();
+
+ return new TFE_TensorHandle(h->handle);
+}
+
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
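The copy-sharing handle added above is an intrusive-refcount share: take one extra Ref() on the underlying handle, then hand out a second wrapper that owns that reference, so either wrapper can be deleted in any order. A minimal sketch of the pattern with hypothetical types (not the actual TensorFlow classes):

    struct Counted {
      int refs = 1;
      void Ref() { ++refs; }
      void Unref() {
        if (--refs == 0) delete this;
      }
    };

    struct Wrapper {
      Counted* inner;
    };

    // Mirrors TFE_TensorHandleCopySharingTensor: bump the count and wrap the
    // same object; deleting a wrapper would call inner->Unref() exactly once.
    Wrapper* Share(Wrapper* w) {
      w->inner->Ref();
      return new Wrapper{w->inner};
    }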
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index a0ebc6fa0a..337447eec9 100644..100755
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -76,7 +76,7 @@ typedef enum TFE_ContextDevicePlacementPolicy {
// Sets the default execution mode (sync/async). Note that this can be
// overridden per thread using TFE_ContextSetAsyncForThread.
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*,
- unsigned char async);
+ unsigned char enable);
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy(
TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy);
@@ -114,7 +114,7 @@ TFE_ContextGetDevicePlacementPolicy(TFE_Context*);
// Overrides the execution mode (sync/async) for the current thread.
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*,
- unsigned char async,
+ unsigned char enable,
TF_Status* status);
// A tensorflow.ServerDef specifies remote workers (in addition to the current
@@ -171,6 +171,12 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h,
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName(
TFE_TensorHandle* h, TF_Status* status);
+// Return a pointer to a new TFE_TensorHandle that shares the underlying tensor
+// with `h`. On success, `status` is set to OK. On failure, `status` reflects
+// the error and a nullptr is returned.
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
+ TFE_TensorHandle* h, TF_Status* status);
+
// This function will block till the operation that produces `h` has
// completed. The memory returned might alias the internal memory used by
// TensorFlow. Hence, callers should not mutate this memory (for example by
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index a5c0681e2e..104d52430c 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -62,15 +62,14 @@ struct TFE_ContextOptions {
};
struct TFE_Context {
- explicit TFE_Context(const tensorflow::SessionOptions& opts,
- TFE_ContextDevicePlacementPolicy default_policy,
- bool async,
- std::unique_ptr<tensorflow::DeviceMgr> device_mgr,
- tensorflow::Rendezvous* rendezvous)
+ TFE_Context(const tensorflow::SessionOptions& opts,
+ TFE_ContextDevicePlacementPolicy default_policy, bool async,
+ const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
+ tensorflow::Rendezvous* rendezvous)
: context(opts,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
default_policy),
- async, std::move(device_mgr), rendezvous) {}
+ async, device_mgr, device_mgr_owned, rendezvous) {}
tensorflow::EagerContext context;
};
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 71d5f3613c..55331022b9 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -1471,4 +1471,86 @@ void BM_ReadVariable(int iters) {
}
BENCHMARK(BM_ReadVariable);
+TEST(CAPI, StringAttributes) {
+ // Test that TFE_OpSetAttrString doesn't hold on to the value after it
+ // returns.
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ std::vector<int64_t> dims(4, 1);
+ TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_Tensor* tensor =
+ TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float));
+ float tensor_data[] = {1};
+ memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor));
+ TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, tensor_handle, status);
+ TF_DeleteTensor(tensor);
+ TFE_DeleteTensorHandle(tensor_handle);
+
+ std::vector<int64_t> values(4, 1);
+ TFE_OpSetAttrIntList(op, "ksize", values.data(), values.size());
+ TFE_OpSetAttrIntList(op, "strides", values.data(), values.size());
+
+ const int BUFFER_SIZE = 10;
+ char buffer[BUFFER_SIZE];
+ std::strncpy(buffer, "VALID", BUFFER_SIZE);
+ TFE_OpSetAttrString(op, "padding", buffer, std::strlen(buffer));
+ // Overwriting value in "buffer", should be fine since TFE_Op
+ // shouldn't be holding on to it.
+ std::strncpy(buffer, "NHWC", BUFFER_SIZE);
+ TFE_OpSetAttrString(op, "data_format", buffer, std::strlen(buffer));
+
+ TFE_OpSetAttrType(op, "T", TF_FLOAT);
+
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_TensorHandle* retvals[1];
+ int num_retvals = 1;
+ TFE_Execute(op, &retvals[0], &num_retvals, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ ASSERT_EQ(1, num_retvals);
+
+ tensor = TFE_TensorHandleResolve(retvals[0], status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ EXPECT_EQ(4, TF_TensorByteSize(tensor));
+ TF_DeleteTensor(tensor);
+ TFE_DeleteTensorHandle(retvals[0]);
+
+ TFE_DeleteOp(op);
+
+ TFE_DeleteContext(ctx);
+ TF_DeleteStatus(status);
+}
+
+TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
+ TFE_TensorHandle* h = TestMatrixTensorHandle();
+ EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
+
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+
+ TFE_TensorHandle* h_shares_tensor =
+ TFE_TensorHandleCopySharingTensor(h, status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+ TF_Tensor* t = TFE_TensorHandleResolve(h_shares_tensor, status.get());
+ ASSERT_EQ(16, TF_TensorByteSize(t));
+ float data[4] = {0};
+ memcpy(&data[0], TF_TensorData(t), TF_TensorByteSize(t));
+ EXPECT_EQ(1.0, data[0]);
+ EXPECT_EQ(2.0, data[1]);
+ EXPECT_EQ(3.0, data[2]);
+ EXPECT_EQ(4.0, data[3]);
+ TF_DeleteTensor(t);
+
+ TFE_DeleteTensorHandle(h);
+ TFE_DeleteTensorHandle(h_shares_tensor);
+}
} // namespace
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 1adb0458c3..ce038a4b57 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -440,6 +440,15 @@ Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
return Status::OK();
}
+gtl::FlatMap<string, gtl::FlatSet<int>>* FunctionsAcceptingNoneForIndicesMap() {
+ static auto* const m = new gtl::FlatMap<string, gtl::FlatSet<int>>({
+ {"SoftmaxCrossEntropyWithLogits", {1}},
+ {"SparseSoftmaxCrossEntropyWithLogits", {1}},
+ {"FusedBatchNorm", {1, 2, 3, 4}},
+ });
+ return m;
+}
+
} // namespace
// If over kMinAggregateCount gradients are accumulated and the total
@@ -485,10 +494,6 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
VLOG(1) << " " << t;
}
}
- gtl::FlatMap<string, gtl::FlatSet<int>> functions_accept_none_for_indices({
- {"SoftmaxCrossEntropyWithLogits", {1}},
- {"FusedBatchNorm", {1, 2, 3, 4}},
- });
while (!op_stack.empty()) {
const int64 op = op_stack.back();
VLOG(1) << "Popped " << op;
@@ -509,8 +514,8 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
auto grad_it = gradients.find(id);
if (grad_it == gradients.end()) {
auto func_name_it =
- functions_accept_none_for_indices.find(trace.op_type);
- if (func_name_it != functions_accept_none_for_indices.end() &&
+ FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
+ if (func_name_it != FunctionsAcceptingNoneForIndicesMap()->end() &&
func_name_it->second.find(i) != func_name_it->second.end()) {
out_gradients.push_back(nullptr);
} else {
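The tape.h refactoring above replaces a map rebuilt on every ComputeGradient call with a function-local static that is constructed once, on first use, and deliberately leaked so it is never torn down during static destruction. A minimal sketch of the same pattern using standard containers (gtl::FlatMap is TensorFlow-internal):

    #include <map>
    #include <set>
    #include <string>

    // Built on first call, then reused; the `new` is intentional so the map
    // outlives all callers and avoids static-destruction-order problems.
    const std::map<std::string, std::set<int>>* NoneAcceptingIndices() {
      static auto* const m = new std::map<std::string, std::set<int>>{
          {"SoftmaxCrossEntropyWithLogits", {1}},
          {"FusedBatchNorm", {1, 2, 3, 4}},
      };
      return m;
    }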
diff --git a/tensorflow/c/tf_status_helper.h b/tensorflow/c/tf_status_helper.h
index 86e687df20..7661a01de4 100644
--- a/tensorflow/c/tf_status_helper.h
+++ b/tensorflow/c/tf_status_helper.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_C_TF_STATUS_HELPER_H
-#define TENSORFLOW_C_TF_STATUS_HELPER_H
+#ifndef TENSORFLOW_C_TF_STATUS_HELPER_H_
+#define TENSORFLOW_C_TF_STATUS_HELPER_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/lib/core/status.h"
@@ -29,4 +29,4 @@ Status StatusFromTF_Status(const TF_Status* tf_status);
} // namespace tensorflow
-#endif // TENSORFLOW_C_TF_STATUS_HELPER_H
+#endif // TENSORFLOW_C_TF_STATUS_HELPER_H_
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc
index dfdef88945..a32d1b1eb5 100644
--- a/tensorflow/cc/framework/cc_op_gen.cc
+++ b/tensorflow/cc/framework/cc_op_gen.cc
@@ -466,7 +466,7 @@ string AvoidCPPKeywords(StringPiece name) {
if (IsCPPKeyword(name)) {
return strings::StrCat(name, "_");
}
- return std::string(name);
+ return string(name);
}
void InferArgAttributes(const OpDef::ArgDef& arg,
@@ -508,15 +508,6 @@ bool HasOptionalAttrs(
return false;
}
-const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
- for (int i = 0; i < api_def.in_arg_size(); ++i) {
- if (api_def.in_arg(i).name() == name) {
- return &api_def.in_arg(i);
- }
- }
- return nullptr;
-}
-
struct OpInfo {
// graph_op_def: The OpDef used by the runtime, has the names that
// must be used when calling NodeBuilder.
diff --git a/tensorflow/cc/framework/ops.h b/tensorflow/cc/framework/ops.h
index a085e1d6e2..0717e7dd4b 100644
--- a/tensorflow/cc/framework/ops.h
+++ b/tensorflow/cc/framework/ops.h
@@ -150,7 +150,7 @@ class Input {
Initializer(const std::initializer_list<T>& v, const TensorShape& shape) {
typedef typename RealType<T>::type RealT;
Tensor t(DataTypeToEnum<RealT>::v(), shape);
- if (t.NumElements() != v.size()) {
+ if (t.NumElements() != static_cast<int64>(v.size())) {
status = errors::InvalidArgument(
"Cannot construct a tensor with ", t.NumElements(),
" from an initializer list with ", v.size(), " elements");
diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc
index 8c886f3171..7f6ac4cae7 100644
--- a/tensorflow/cc/framework/scope.cc
+++ b/tensorflow/cc/framework/scope.cc
@@ -225,7 +225,7 @@ std::unordered_set<string> Scope::Impl::GetColocationConstraints(
for (const string& entry : node_constraints) {
StringPiece s(entry);
if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) {
- current_constraints.insert(std::string(s));
+ current_constraints.emplace(s);
}
}
} else {
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index 5dcf00857d..1329b568ab 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -441,21 +441,20 @@ Status RealDivGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);
-Status UnsafeDivGrad(const Scope& scope, const Operation& op,
- const std::vector<Output>& grad_inputs,
- std::vector<Output>* grad_outputs) {
+Status DivNoNanGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
auto x_1 = ConjugateHelper(scope, op.input(0));
auto x_2 = ConjugateHelper(scope, op.input(1));
// y = x_1 / x_2
// dy/dx_1 = 1/x_2
// dy/dx_2 = -x_1/x_2^2
- auto gx_1 = UnsafeDiv(scope, grad_inputs[0], x_2);
- auto gx_2 =
- Mul(scope, grad_inputs[0],
- UnsafeDiv(scope, UnsafeDiv(scope, Neg(scope, x_1), x_2), x_2));
+ auto gx_1 = DivNoNan(scope, grad_inputs[0], x_2);
+ auto gx_2 = Mul(scope, grad_inputs[0],
+ DivNoNan(scope, DivNoNan(scope, Neg(scope, x_1), x_2), x_2));
return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
-REGISTER_GRADIENT_OP("UnsafeDiv", UnsafeDivGrad);
+REGISTER_GRADIENT_OP("DivNoNan", DivNoNanGrad);
Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
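
The UnsafeDiv -> DivNoNan change follows the op's rename; the gradient formulas are untouched. For y = x_1 / x_2 (defined as 0 when x_2 = 0), dy/dx_1 = 1/x_2 and dy/dx_2 = -x_1/x_2^2, with every division routed through DivNoNan so a zero denominator produces a zero gradient instead of NaN. A scalar sketch of the same rule (standalone hypothetical helpers, not the TF ops):

    // y = x1 / x2, defined to be 0 when x2 == 0.
    float DivNoNanScalar(float x1, float x2) {
      return x2 == 0.0f ? 0.0f : x1 / x2;
    }

    // Backprop for DivNoNan: both partials fall back to 0 when x2 == 0.
    void DivNoNanGradScalar(float x1, float x2, float dy,
                            float* dx1, float* dx2) {
      *dx1 = DivNoNanScalar(dy, x2);                            // dy * (1 / x2)
      *dx2 = dy * DivNoNanScalar(DivNoNanScalar(-x1, x2), x2);  // dy * (-x1 / x2^2)
    }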
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index 88aef1fab4..c16938322c 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -33,6 +33,7 @@ using ops::AddN;
using ops::BatchMatMul;
using ops::Const;
using ops::Div;
+using ops::DivNoNan;
using ops::MatMul;
using ops::Max;
using ops::Maximum;
@@ -48,7 +49,6 @@ using ops::SegmentSum;
using ops::SquaredDifference;
using ops::Sub;
using ops::Sum;
-using ops::UnsafeDiv;
// TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added move common test functions
@@ -854,13 +854,13 @@ TEST_F(NaryGradTest, RealDiv) {
RunTest({x}, {x_shape}, {y}, {x_shape});
}
-TEST_F(NaryGradTest, UnsafeDiv) {
+TEST_F(NaryGradTest, DivNoNan) {
{
TensorShape x_shape({3, 2, 5});
const auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
// Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
// division errors in the numeric estimator used by the gradient checker.
- const auto y = UnsafeDiv(
+ const auto y = DivNoNan(
scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
RunTest({x}, {x_shape}, {y}, {x_shape});
}
@@ -868,7 +868,7 @@ TEST_F(NaryGradTest, UnsafeDiv) {
// Return 0 gradient (rather than NaN) for division by zero.
const auto x = Placeholder(scope_, DT_FLOAT);
const auto zero = Const<float>(scope_, 0.0);
- const auto y = UnsafeDiv(scope_, x, zero);
+ const auto y = DivNoNan(scope_, x, zero);
std::vector<Output> grad_outputs;
TF_EXPECT_OK(AddSymbolicGradients(scope_, {y}, {x}, &grad_outputs));
diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc
index 3830416159..c6abe2f41b 100644
--- a/tensorflow/cc/saved_model/loader.cc
+++ b/tensorflow/cc/saved_model/loader.cc
@@ -148,7 +148,7 @@ Status RunMainOp(const RunOptions& run_options, const string& export_dir,
AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
RunMetadata run_metadata;
const StringPiece main_op_name = main_op_it->second.node_list().value(0);
- return RunOnce(run_options, inputs, {}, {main_op_name.ToString()},
+ return RunOnce(run_options, inputs, {}, {string(main_op_name)},
nullptr /* outputs */, &run_metadata, session);
}
return Status::OK();
@@ -182,12 +182,12 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
variables_path_tensor.scalar<string>()() = variables_path;
std::vector<std::pair<string, Tensor>> inputs = {
- {variable_filename_const_op_name.ToString(), variables_path_tensor}};
+ {string(variable_filename_const_op_name), variables_path_tensor}};
AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
RunMetadata run_metadata;
- return RunOnce(run_options, inputs, {}, {restore_op_name.ToString()},
+ return RunOnce(run_options, inputs, {}, {string(restore_op_name)},
nullptr /* outputs */, &run_metadata, session);
}
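
The loader changes drop StringPiece::ToString() in favor of explicit string construction; absl::string_view, which StringPiece is converging on elsewhere in this commit, has no ToString() member, so string(sp) is the portable spelling. A minimal sketch:

    #include <string>
    #include "absl/strings/string_view.h"

    std::string OwnedCopy(absl::string_view name) {
      // No ToString() on absl::string_view; construct the string explicitly.
      return std::string(name);
    }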
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index 1899a32e4d..6c29f09cde 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -32,7 +32,6 @@ cc_library(
deps = [
":embedded_protocol_buffers",
"//tensorflow/compiler/tf2xla",
- "//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
"//tensorflow/compiler/tf2xla:tf2xla_proto",
"//tensorflow/compiler/tf2xla:tf2xla_util",
@@ -55,6 +54,9 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -71,6 +73,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
"@llvm//:support", # fixdeps: keep
"@llvm//:x86_code_gen", # fixdeps: keep
],
@@ -99,6 +102,7 @@ cc_library(
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
],
)
@@ -188,11 +192,13 @@ cc_library(
srcs = ["embedded_protocol_buffers.cc"],
hdrs = ["embedded_protocol_buffers.h"],
deps = [
- "//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
"@llvm//:support",
"@llvm//:target",
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index 89fefdad54..b17bc658fa 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -19,18 +19,19 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_replace.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
-#include "tensorflow/compiler/tf2xla/str_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace tfcompile {
@@ -134,14 +135,14 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
indices = "[0]";
} else {
for (int dim = 0; dim < shape.dimensions_size(); ++dim) {
- dim_vars.push_back(strings::StrCat("size_t dim", dim));
- dim_sizes += strings::StrCat("[", shape.dimensions(dim), "]");
- indices += strings::StrCat("[dim", dim, "]");
+ dim_vars.push_back(absl::StrCat("size_t dim", dim));
+ dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]");
+ indices += absl::StrCat("[dim", dim, "]");
}
}
- rewrites->push_back({"{{I}}", strings::StrCat(i)});
+ rewrites->push_back({"{{I}}", absl::StrCat(i)});
rewrites->push_back({"{{TYPE}}", type});
- rewrites->push_back({"{{DIM_VARS}}", str_util::Join(dim_vars, ", ")});
+ rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")});
rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
rewrites->push_back({"{{INDICES}}", indices});
return Status::OK();
@@ -157,8 +158,9 @@ Status AddRewritesForShape(int i, const xla::Shape& shape,
// text-templating mechanism.
string RewriteWithName(const string& name, string code,
const std::vector<std::pair<string, string>>& rewrites) {
- str_util::ReplaceAllPairs(&code, rewrites);
- return str_util::StringReplace(code, "{{NAME}}", name, /*replace_all=*/true);
+ absl::StrReplaceAll(rewrites, &code);
+ absl::StrReplaceAll({{"{{NAME}}", name}}, &code);
+ return code;
}
// Generate methods for args (inputs).
@@ -192,7 +194,7 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
arg_data({{I}}))){{INDICES}};
}
)";
- *methods += RewriteWithName(strings::StrCat(i), code, rewrites);
+ *methods += RewriteWithName(absl::StrCat(i), code, rewrites);
if (!config.feed(i).name().empty()) {
*methods += RewriteWithName("_" + config.feed(i).name(), code, rewrites);
}
@@ -233,7 +235,7 @@ Status GenResultMethods(const tf2xla::Config& config,
result_data({{I}}))){{INDICES}};
}
)";
- *methods += RewriteWithName(strings::StrCat(i), code, rewrites);
+ *methods += RewriteWithName(absl::StrCat(i), code, rewrites);
if (!config.fetch(i).name().empty()) {
*methods += RewriteWithName("_" + config.fetch(i).name(), code, rewrites);
}
@@ -302,8 +304,8 @@ std::vector<string> BufferInfosToCppExpression(
string encoded_second_as_str =
encoded.second == ~0ULL
? "~0ULL"
- : strings::StrCat(encoded.second, "ULL");
- return strings::StrCat(
+ : absl::StrCat(encoded.second, "ULL");
+ return absl::StrCat(
"::tensorflow::cpu_function_runtime::BufferInfo({",
encoded.first, "ULL, ", encoded_second_as_str, "})");
});
@@ -350,13 +352,13 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
// Create rewrite strings for namespace start and end.
string ns_start;
for (const string& n : opts.namespaces) {
- ns_start += strings::StrCat("namespace ", n, " {\n");
+ ns_start += absl::StrCat("namespace ", n, " {\n");
}
ns_start += "\n";
string ns_end("\n");
for (int i = opts.namespaces.size() - 1; i >= 0; --i) {
const string& n = opts.namespaces[i];
- ns_end += strings::StrCat("} // end namespace ", n, "\n");
+ ns_end += absl::StrCat("} // end namespace ", n, "\n");
}
// Generate metadata.
@@ -566,15 +568,15 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
)";
// The replacement strategy is naive, but good enough for our purposes.
const std::vector<std::pair<string, string>> rewrites = {
- {"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)},
- {"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)},
+ {"{{ARG_BYTES_ALIGNED}}", absl::StrCat(arg_bytes_aligned)},
+ {"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)},
{"{{ARG_NAMES_CODE}}", arg_names_code},
- {"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())},
- {"{{ARG_INDEX_TABLE}}", str_util::Join(arg_index_table, ", ")},
+ {"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())},
+ {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
{"{{CLASS}}", opts.class_name},
{"{{DECLS_FROM_OBJ_FILE}}",
- str_util::Join(metadata_result.header_variable_decls, "\n")},
+ absl::StrJoin(metadata_result.header_variable_decls, "\n")},
{"{{ENTRY}}", compile_result.entry_point},
{"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
metadata_result.hlo_profile_printer_data_access_shim},
@@ -588,25 +590,25 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)},
{"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
metadata_result.program_shape_access_shim},
- {"{{RESULT_INDEX}}", strings::StrCat(result_index)},
+ {"{{RESULT_INDEX}}", absl::StrCat(result_index)},
{"{{RESULT_NAMES_CODE}}", result_names_code},
- {"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)},
- {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)},
- {"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())},
+ {"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)},
+ {"{{TEMP_BYTES_TOTAL}}", absl::StrCat(temp_bytes_total)},
+ {"{{NUM_BUFFERS}}", absl::StrCat(buffer_infos.size())},
{"{{BUFFER_INFOS_AS_STRING}}",
- str_util::Join(buffer_infos_as_strings, ",\n")}};
- str_util::ReplaceAllPairs(header, rewrites);
+ absl::StrJoin(buffer_infos_as_strings, ",\n")}};
+ absl::StrReplaceAll(rewrites, header);
return Status::OK();
}
static string CreateUniqueIdentifier(const CodegenOpts& opts,
- StringPiece suffix) {
+ absl::string_view suffix) {
string result = "__tfcompile";
for (const string& n : opts.namespaces) {
- strings::StrAppend(&result, "_", n);
+ absl::StrAppend(&result, "_", n);
}
- strings::StrAppend(&result, "_", opts.class_name, "_", suffix);
+ absl::StrAppend(&result, "_", opts.class_name, "_", suffix);
return result;
}
@@ -617,7 +619,8 @@ Status GenerateMetadata(const CodegenOpts& opts,
if (opts.gen_program_shape) {
program_shape =
- tensorflow::MakeUnique<xla::ProgramShape>(compile_result.program_shape);
+ absl::make_unique<xla::ProgramShape>(compile_result.program_shape);
+
// The parameter names are currently meaningless, and redundant with the
// rest of our metadata, so clear them out to avoid confusion and save
// space.
@@ -675,7 +678,7 @@ Status ParseCppClass(const string& cpp_class, string* class_name,
return Status::OK();
}
-Status ValidateCppIdent(StringPiece ident, StringPiece msg) {
+Status ValidateCppIdent(absl::string_view ident, absl::string_view msg) {
if (ident.empty()) {
return errors::InvalidArgument("empty identifier: ", msg);
}
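
Most of the codegen.cc hunk is a mechanical strings::StrCat/str_util -> absl migration. One subtlety: absl::StrReplaceAll has two overloads, an in-place form that mutates a std::string* and returns the substitution count, and a value form that returns a new string. RewriteWithName uses the in-place form twice; CreateCPPShimExpression (in embedded_protocol_buffers.cc below) uses the value form. A minimal sketch:

    #include <string>
    #include "absl/strings/str_replace.h"

    std::string Demo() {
      std::string code = "{{TYPE}} arg{{I}};";
      // In-place overload: mutates `code`, returns how many replacements ran.
      absl::StrReplaceAll({{"{{TYPE}}", "float"}, {"{{I}}", "0"}}, &code);
      // Value overload: leaves its input untouched and returns the result.
      return absl::StrReplaceAll(code, {{"float", "double"}});
    }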
diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h
index 83f2d3ee11..90410c46a8 100644
--- a/tensorflow/compiler/aot/codegen.h
+++ b/tensorflow/compiler/aot/codegen.h
@@ -19,9 +19,9 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace tensorflow {
namespace tfcompile {
@@ -96,7 +96,7 @@ Status ParseCppClass(const string& cpp_class, string* class_name,
// ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is
// appended to error messages.
-Status ValidateCppIdent(StringPiece ident, StringPiece msg);
+Status ValidateCppIdent(absl::string_view ident, absl::string_view msg);
} // namespace tfcompile
} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc
index 60d59ae996..bb288d2300 100644
--- a/tensorflow/compiler/aot/codegen_test.cc
+++ b/tensorflow/compiler/aot/codegen_test.cc
@@ -18,13 +18,13 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/strings/match.h"
+#include "absl/strings/string_view.h"
#include "llvm/Support/TargetSelect.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
@@ -34,9 +34,9 @@ namespace {
using ::tensorflow::cpu_function_runtime::BufferInfo;
-void ExpectErrorContains(const Status& status, StringPiece str) {
+void ExpectErrorContains(const Status& status, absl::string_view str) {
EXPECT_NE(Status::OK(), status);
- EXPECT_TRUE(str_util::StrContains(status.error_message(), str))
+ EXPECT_TRUE(absl::StrContains(status.error_message(), str))
<< "expected error: " << status.error_message() << " to contain: " << str;
}
diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc
index 4e27aafec7..3c32d533f6 100644
--- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc
+++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_replace.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/LLVMContext.h"
@@ -26,8 +28,6 @@ limitations under the License.
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
-#include "tensorflow/compiler/tf2xla/str_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/util.h"
@@ -38,11 +38,11 @@ using xla::llvm_ir::AsStringRef;
static void AddEmbeddedProtocolBufferToLlvmModule(
llvm::Module* module, const ::tensorflow::protobuf::MessageLite& proto,
- StringPiece unique_identifier, string* protobuf_array_symbol_name,
+ absl::string_view unique_identifier, string* protobuf_array_symbol_name,
int64* protobuf_array_size) {
string protobuf_array_contents = proto.SerializeAsString();
*protobuf_array_symbol_name =
- strings::StrCat(unique_identifier, "_protobuf_array_contents");
+ absl::StrCat(unique_identifier, "_protobuf_array_contents");
*protobuf_array_size = protobuf_array_contents.size();
llvm::Constant* protobuf_array_initializer =
@@ -55,9 +55,9 @@ static void AddEmbeddedProtocolBufferToLlvmModule(
protobuf_array_initializer, AsStringRef(*protobuf_array_symbol_name));
}
-static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name,
- StringPiece protobuf_array_symbol_name,
- int64 protobuf_array_size) {
+static string CreateCPPShimExpression(
+ absl::string_view qualified_cpp_protobuf_name,
+ absl::string_view protobuf_array_symbol_name, int64 protobuf_array_size) {
string code =
"[]() {\n"
" {{PROTOBUF_NAME}}* proto = new {{PROTOBUF_NAME}};\n"
@@ -65,14 +65,13 @@ static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name,
" return proto;\n"
" }()";
- str_util::ReplaceAllPairs(
- &code,
+ return absl::StrReplaceAll(
+ code,
{
- {"{{ARRAY_SYMBOL}}", strings::StrCat(protobuf_array_symbol_name)},
- {"{{ARRAY_SIZE}}", strings::StrCat(protobuf_array_size)},
- {"{{PROTOBUF_NAME}}", strings::StrCat(qualified_cpp_protobuf_name)},
+ {"{{ARRAY_SYMBOL}}", absl::StrCat(protobuf_array_symbol_name)},
+ {"{{ARRAY_SIZE}}", absl::StrCat(protobuf_array_size)},
+ {"{{PROTOBUF_NAME}}", absl::StrCat(qualified_cpp_protobuf_name)},
});
- return code;
}
static StatusOr<string> CodegenModule(llvm::TargetMachine* target_machine,
@@ -94,10 +93,10 @@ static StatusOr<string> CodegenModule(llvm::TargetMachine* target_machine,
}
static StatusOr<std::unique_ptr<llvm::TargetMachine>>
-GetTargetMachineFromTriple(StringPiece target_triple) {
+GetTargetMachineFromTriple(absl::string_view target_triple) {
std::string error;
std::string normalized_triple =
- llvm::Triple::normalize(AsStringRef(target_triple));
+ llvm::Triple::normalize(AsStringRef(absl::string_view(target_triple)));
const llvm::Target* target =
llvm::TargetRegistry::lookupTarget(normalized_triple, error);
if (target == nullptr) {
@@ -105,20 +104,20 @@ GetTargetMachineFromTriple(StringPiece target_triple) {
error.c_str());
}
- return WrapUnique(target->createTargetMachine(
+ return absl::WrapUnique(target->createTargetMachine(
normalized_triple, /*CPU=*/"",
/*Features=*/"", llvm::TargetOptions(), llvm::None));
}
StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
- StringPiece target_triple,
- gtl::ArraySlice<ProtobufToEmbed> protobufs_to_embed) {
+ absl::string_view target_triple,
+ absl::Span<const ProtobufToEmbed> protobufs_to_embed) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<llvm::TargetMachine> target_machine,
GetTargetMachineFromTriple(target_triple));
llvm::LLVMContext llvm_context;
std::unique_ptr<llvm::Module> module_with_serialized_proto =
- MakeUnique<llvm::Module>("embedded_data_module", llvm_context);
+ absl::make_unique<llvm::Module>("embedded_data_module", llvm_context);
EmbeddedProtocolBuffers result;
@@ -136,8 +135,8 @@ StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
protobuf_to_embed.qualified_cpp_protobuf_name,
protobuf_array_symbol_name, protobuf_array_size);
- cpp_variable_decl = strings::StrCat("extern \"C\" char ",
- protobuf_array_symbol_name, "[];");
+ cpp_variable_decl =
+ absl::StrCat("extern \"C\" char ", protobuf_array_symbol_name, "[];");
} else {
cpp_shim = "nullptr";
}
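
This hunk swaps TF's local MakeUnique/WrapUnique helpers for the absl equivalents. The distinction matters here: absl::make_unique constructs the object itself, while absl::WrapUnique merely adopts ownership of a raw pointer produced elsewhere, which is what createTargetMachine(), an LLVM factory returning a raw pointer, requires. A minimal sketch (Widget and LegacyFactory are hypothetical):

    #include <memory>
    #include "absl/memory/memory.h"

    struct Widget { int x = 0; };
    Widget* LegacyFactory() { return new Widget; }  // factory we don't control

    void Demo() {
      auto a = absl::make_unique<Widget>();        // constructs in place
      auto b = absl::WrapUnique(LegacyFactory());  // adopts an existing pointer
    }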
diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h
index 4e194a6aba..cf5c04ac4b 100644
--- a/tensorflow/compiler/aot/embedded_protocol_buffers.h
+++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h
@@ -20,8 +20,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_
#define TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
@@ -83,8 +83,8 @@ struct ProtobufToEmbed {
// is stored in the object_file_data field in the returned
// EmbeddedProtocolBuffers instance.
StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
- StringPiece target_triple,
- gtl::ArraySlice<ProtobufToEmbed> protobufs_to_embed);
+ absl::string_view target_triple,
+ absl::Span<const ProtobufToEmbed> protobufs_to_embed);
} // namespace tfcompile
} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index 0ecc3feeb6..8d94f5495c 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -67,7 +67,12 @@ genrule(
"test_graph_tfmatmulandadd.pb",
"test_graph_tfsplits.pb",
],
- cmd = "$(location :make_test_graphs) --out_dir $(@D)",
+ # Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any
+ # GPUs which might be present. This is important because builds may run
+ # concurrently with tests, and tests need to be able to assume that they
+ # have control of the full GPU.
+ cmd = "CUDA_VISIBLE_DEVICES='' " +
+ "$(location :make_test_graphs) --out_dir $(@D)",
tags = ["manual"],
tools = [":make_test_graphs"],
)
@@ -187,6 +192,9 @@ tf_library(
cpp_class = "MatMulAndAddCompWithProfiling",
enable_xla_hlo_profiling = True,
graph = "test_graph_tfmatmulandadd.pb",
+ tags = [
+ "manual",
+ ],
)
tf_library(
@@ -226,5 +234,6 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//third_party/eigen3",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index 0c0c676ece..dd2b151098 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL
+#include "absl/strings/str_split.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
@@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -546,7 +546,7 @@ TEST(TFCompileTest, HloProfiling) {
VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string;
std::vector<string> hlo_profile_lines =
- tensorflow::str_util::Split(hlo_profile_as_string, '\n');
+ absl::StrSplit(hlo_profile_as_string, '\n');
auto header = HasSubstr("Execution profile for");
auto total_cycles_profile_line = HasSubstr("[total]");
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index 326f73b975..792b7fe14a 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -105,12 +105,18 @@ def tf_library(
freeze_file = freeze_name + ".pb"
# First run tfcompile to generate the list of out_nodes.
+ #
+ # Here and below, we set CUDA_VISIBLE_DEVICES='' to prevent the code we
+ # launch from using any GPUs which might be present. This is important
+ # because builds may run concurrently with tests, and tests need to be
+ # able to assume that they have control of the full GPU.
out_nodes_file = "out_nodes_" + freeze_name
native.genrule(
name = ("gen_" + out_nodes_file),
srcs = [config],
outs = [out_nodes_file],
- cmd = ("$(location " + tfcompile_tool + ")" +
+ cmd = ("CUDA_VISIBLE_DEVICES='' " +
+ "$(location " + tfcompile_tool + ")" +
" --config=$(location " + config + ")" +
" --dump_fetch_nodes > $@"),
tools = [tfcompile_tool],
@@ -142,9 +148,12 @@ def tf_library(
out_nodes_file,
] + freeze_saver_srcs,
outs = [freeze_file],
- cmd = ("$(location " +
- "//tensorflow/python/tools:freeze_graph)" +
- freeze_args),
+ cmd = (
+ "CUDA_VISIBLE_DEVICES='' " +
+ "$(location " +
+ "//tensorflow/python/tools:freeze_graph)" +
+ freeze_args
+ ),
tools = ["//tensorflow/python/tools:freeze_graph"],
tags = tags,
)
@@ -177,16 +186,19 @@ def tf_library(
metadata_object_file,
function_object_file,
],
- cmd = ("$(location " + tfcompile_tool + ")" +
- " --graph=$(location " + tfcompile_graph + ")" +
- " --config=$(location " + config + ")" +
- " --entry_point=" + ep +
- " --cpp_class=" + cpp_class +
- " --target_triple=" + target_llvm_triple() +
- " --out_header=$(@D)/" + header_file +
- " --out_metadata_object=$(@D)/" + metadata_object_file +
- " --out_function_object=$(@D)/" + function_object_file +
- " " + flags + " " + profiling_flag),
+ cmd = (
+ "CUDA_VISIBLE_DEVICES='' " +
+ "$(location " + tfcompile_tool + ")" +
+ " --graph=$(location " + tfcompile_graph + ")" +
+ " --config=$(location " + config + ")" +
+ " --entry_point=" + ep +
+ " --cpp_class=" + cpp_class +
+ " --target_triple=" + target_llvm_triple() +
+ " --out_header=$(@D)/" + header_file +
+ " --out_metadata_object=$(@D)/" + metadata_object_file +
+ " --out_function_object=$(@D)/" + function_object_file +
+ " " + flags + " " + profiling_flag
+ ),
tools = [tfcompile_tool],
visibility = visibility,
testonly = testonly,
@@ -216,14 +228,17 @@ def tf_library(
outs = [
session_module_pb,
],
- cmd = ("$(location " + tfcompile_tool + ")" +
- " --graph=$(location " + tfcompile_graph + ")" +
- " --config=$(location " + config + ")" +
- " --entry_point=" + ep +
- " --cpp_class=" + cpp_class +
- " --target_triple=" + target_llvm_triple() +
- " --out_session_module=$(@D)/" + session_module_pb +
- " " + flags),
+ cmd = (
+ "CUDA_VISIBLE_DEVICES='' " +
+ "$(location " + tfcompile_tool + ")" +
+ " --graph=$(location " + tfcompile_graph + ")" +
+ " --config=$(location " + config + ")" +
+ " --entry_point=" + ep +
+ " --cpp_class=" + cpp_class +
+ " --target_triple=" + target_llvm_triple() +
+ " --out_session_module=$(@D)/" + session_module_pb +
+ " " + flags
+ ),
tools = [tfcompile_tool],
visibility = visibility,
testonly = testonly,
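
The same CUDA_VISIBLE_DEVICES='' prefix is threaded through every genrule in tfcompile.bzl for the reason the comment gives: with the variable set to an empty string, the CUDA runtime reports zero visible devices to the launched tool, so it cannot claim a GPU that a concurrently running test expects to control exclusively.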
diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc
index 839e1588b7..b95b063348 100644
--- a/tensorflow/compiler/aot/tfcompile_main.cc
+++ b/tensorflow/compiler/aot/tfcompile_main.cc
@@ -18,6 +18,9 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/match.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h"
@@ -32,9 +35,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -55,7 +56,7 @@ const char kUsageHeader[] =
"\n";
Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
- if (str_util::EndsWith(fname, ".pbtxt")) {
+ if (absl::EndsWith(fname, ".pbtxt")) {
return ReadTextProto(Env::Default(), fname, proto);
} else {
return ReadBinaryProto(Env::Default(), fname, proto);
@@ -75,7 +76,7 @@ Status Main(const MainFlags& flags) {
for (const tf2xla::Fetch& fetch : config.fetch()) {
nodes.insert(fetch.id().node_name());
}
- std::cout << str_util::Join(nodes, ",");
+ std::cout << absl::StrJoin(nodes, ",");
return Status::OK();
}
@@ -91,8 +92,9 @@ Status Main(const MainFlags& flags) {
// Write output files.
Env* env = Env::Default();
const std::vector<char>& obj = compile_result.aot->object_file_data();
- TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_function_object,
- StringPiece(obj.data(), obj.size())));
+ TF_RETURN_IF_ERROR(
+ WriteStringToFile(env, flags.out_function_object,
+ absl::string_view(obj.data(), obj.size())));
CodegenOpts codegen_opts;
codegen_opts.gen_name_to_index = flags.gen_name_to_index;
codegen_opts.gen_program_shape = flags.gen_program_shape;
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index e059f77563..352f63bc98 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -128,11 +128,11 @@ cc_library(
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:shaped_buffer",
- "//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/memory",
],
)
@@ -191,6 +191,7 @@ cc_library(
"//tensorflow/core/kernels/data:generator_dataset_op",
"//tensorflow/core/kernels/data:iterator_ops",
"//tensorflow/core/kernels/data:prefetch_dataset_op",
+ "@com_google_absl//absl/memory",
],
)
@@ -235,6 +236,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:variable_ops",
+ "@com_google_absl//absl/memory",
],
)
@@ -283,6 +285,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/memory",
],
alwayslink = 1,
)
@@ -303,6 +306,52 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
+ "@com_google_absl//absl/memory",
+ ],
+)
+
+cc_library(
+ name = "resource_operation_safety_analysis",
+ srcs = ["resource_operation_safety_analysis.cc"],
+ hdrs = ["resource_operation_safety_analysis.h"],
+ deps = [
+ "//tensorflow/compiler/jit/graphcycles",
+ "//tensorflow/compiler/tf2xla:resource_operation_table",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ ],
+)
+
+tf_cc_test(
+ name = "resource_operation_safety_analysis_test",
+ srcs = ["resource_operation_safety_analysis_test.cc"],
+ deps = [
+ ":common",
+ ":resource_operation_safety_analysis",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:cc_ops_internal",
+ "//tensorflow/cc:function_ops",
+ "//tensorflow/cc:functional_ops",
+ "//tensorflow/cc:ops",
+ "//tensorflow/cc:resource_variable_ops",
+ "//tensorflow/cc:sendrecv_ops",
+ "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/tf2xla/kernels:xla_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "@com_google_absl//absl/strings",
],
)
@@ -313,6 +362,7 @@ cc_library(
"deadness_analysis.cc",
"deadness_analysis_internal.h",
"encapsulate_subgraphs_pass.cc",
+ "encapsulate_xla_computations_pass.cc",
"mark_for_compilation_pass.cc",
"mark_for_compilation_pass_test_helper.cc",
"partially_decluster_pass.cc",
@@ -321,6 +371,7 @@ cc_library(
"build_xla_launch_ops_pass.h",
"deadness_analysis.h",
"encapsulate_subgraphs_pass.h",
+ "encapsulate_xla_computations_pass.h",
"mark_for_compilation_pass.h",
"mark_for_compilation_pass_test_helper.h",
"partially_decluster_pass.h",
@@ -331,11 +382,10 @@ cc_library(
":union_find",
":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
- "//tensorflow/compiler/jit/kernels:parallel_check_op",
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
- "//tensorflow/compiler/jit/ops:parallel_check_op",
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:dump_graph",
+ "//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
@@ -347,6 +397,9 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:bounds_check",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -355,12 +408,14 @@ cc_library(
srcs = ["xla_cluster_util.cc"],
hdrs = ["xla_cluster_util.h"],
deps = [
+ ":resource_operation_safety_analysis",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
- "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:bounds_check",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -422,6 +477,7 @@ tf_cc_test(
size = "small",
srcs = [
"encapsulate_subgraphs_pass_test.cc",
+ "encapsulate_xla_computations_pass_test.cc",
"mark_for_compilation_pass_test.cc",
"partially_decluster_pass_test.cc",
],
@@ -429,13 +485,17 @@ tf_cc_test(
":common",
":compilation_passes",
":xla_cluster_util",
+ ":xla_gpu_device",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
+ "//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:sendrecv_ops",
"//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/tf2xla:test_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
@@ -444,6 +504,9 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
+ "//tensorflow/core/grappler/optimizers/data:graph_utils",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -514,6 +577,7 @@ cc_library(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "@com_google_absl//absl/strings",
],
)
@@ -524,6 +588,9 @@ tf_cuda_cc_test(
":common",
":xla_cluster_util",
":xla_fusion_optimizer",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:ops",
+ "//tensorflow/cc:resource_variable_ops",
"//tensorflow/core:graph",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc
index a2e6285339..56b034a30b 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/create_xla_launch_op.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
@@ -125,7 +126,8 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
const DataTypeVector& arg_types = (*fbody)->arg_types;
std::vector<bool> const_args(arg_types.size());
  // If we can't analyze the const args, bail out.
- TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*((*fbody)->graph), &const_args));
+ TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
+ *((*fbody)->graph), &const_args, /*compile_time_const_nodes=*/nullptr));
for (int i = 0; i < const_args.size(); ++i) {
if (const_args[i]) {
@@ -207,8 +209,13 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
// device memory.
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
- // in device memory
+ // in device memory except for resources.
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
+ for (int i = 0; i < fbody->ret_types.size(); ++i) {
+ if (fbody->ret_types[i] == DT_RESOURCE) {
+ output_memory_types[i] = HOST_MEMORY;
+ }
+ }
// Create the kernel.
NameAttrList function;
@@ -223,8 +230,8 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
&fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
- *kernel = MakeUnique<XlaLocalLaunchBase>(&construction, constant_arg_indices,
- resource_arg_indices, function);
+ *kernel = absl::make_unique<XlaLocalLaunchBase>(
+ &construction, constant_arg_indices, resource_arg_indices, function);
return s;
}
diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc
index b75ab486b8..7386660762 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op_test.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/create_xla_launch_op.h"
+#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function_testlib.h"
@@ -65,11 +66,11 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
for (const auto& fdef : flib) {
*(proto.add_function()) = fdef;
}
- lib_def_ =
- MakeUnique<FunctionLibraryDefinition>(OpRegistry::Global(), proto);
+ lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
+ OpRegistry::Global(), proto);
OptimizerOptions opts;
- device_mgr_ = MakeUnique<DeviceMgr>(devices_);
- pflr_ = MakeUnique<ProcessFunctionLibraryRuntime>(
+ device_mgr_ = absl::make_unique<DeviceMgr>(devices_);
+ pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc
index 309aeffc18..9128b48da3 100644
--- a/tensorflow/compiler/jit/deadness_analysis.cc
+++ b/tensorflow/compiler/jit/deadness_analysis.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/deadness_analysis.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/tensor_id.h"
@@ -107,7 +108,7 @@ class Predicate {
virtual string ToString() const = 0;
int64 hash() const { return hash_; }
- virtual gtl::ArraySlice<Predicate*> GetOperands() const = 0;
+ virtual absl::Span<Predicate* const> GetOperands() const = 0;
virtual Kind kind() const = 0;
virtual ~Predicate() {}
@@ -128,7 +129,7 @@ class Predicate {
};
int64 HashPredicateSequence(Predicate::Kind kind,
- gtl::ArraySlice<Predicate*> preds) {
+ absl::Span<Predicate* const> preds) {
int64 hash = ::tensorflow::hash<Predicate::Kind>()(kind);
for (Predicate* pred : preds) {
hash = Hash64Combine(hash, pred->hash());
@@ -153,13 +154,15 @@ class AndPredicate : public Predicate {
std::back_inserter(operands_str),
[](Predicate* pred) { return pred->ToString(); });
- return strings::StrCat("(", str_util::Join(operands_str, " & "), ")");
+ return absl::StrCat("(", absl::StrJoin(operands_str, " & "), ")");
}
Kind kind() const override { return Kind::kAnd; }
- gtl::ArraySlice<Predicate*> GetOperands() const override { return operands_; }
- gtl::ArraySlice<Predicate*> operands() const { return operands_; }
+ absl::Span<Predicate* const> GetOperands() const override {
+ return operands_;
+ }
+ absl::Span<Predicate* const> operands() const { return operands_; }
private:
std::vector<Predicate*> operands_;
@@ -182,12 +185,14 @@ class OrPredicate : public Predicate {
std::back_inserter(operands_str),
[](Predicate* pred) { return pred->ToString(); });
- return strings::StrCat("(", str_util::Join(operands_str, " | "), ")");
+ return absl::StrCat("(", absl::StrJoin(operands_str, " | "), ")");
}
Kind kind() const override { return Kind::kOr; }
- gtl::ArraySlice<Predicate*> GetOperands() const override { return operands_; }
- gtl::ArraySlice<Predicate*> operands() const { return operands_; }
+ absl::Span<Predicate* const> GetOperands() const override {
+ return operands_;
+ }
+ absl::Span<Predicate* const> operands() const { return operands_; }
private:
std::vector<Predicate*> operands_;
@@ -201,12 +206,14 @@ class NotPredicate : public Predicate {
operands_({operand}) {}
string ToString() const override {
- return strings::StrCat("~", operand()->ToString());
+ return absl::StrCat("~", operand()->ToString());
}
Kind kind() const override { return Kind::kNot; }
Predicate* operand() const { return operands_[0]; }
- gtl::ArraySlice<Predicate*> GetOperands() const override { return operands_; }
+ absl::Span<Predicate* const> GetOperands() const override {
+ return operands_;
+ }
private:
std::array<Predicate*, 1> operands_;
@@ -233,13 +240,15 @@ class AndRecurrencePredicate : public Predicate {
Predicate* step() const { return operands_[1]; }
string ToString() const override {
- return strings::StrCat("{", start()->ToString(), ",&,", step()->ToString(),
- "}");
+ return absl::StrCat("{", start()->ToString(), ",&,", step()->ToString(),
+ "}");
}
Kind kind() const override { return Kind::kAndRecurrence; }
- gtl::ArraySlice<Predicate*> GetOperands() const override { return operands_; }
+ absl::Span<Predicate* const> GetOperands() const override {
+ return operands_;
+ }
private:
std::array<Predicate*, 2> operands_;
@@ -258,12 +267,12 @@ class SymbolPredicate : public Predicate {
must_be_true_(must_be_true) {}
string ToString() const override {
- return must_be_true() ? strings::StrCat("*", tensor_id_.ToString())
+ return must_be_true() ? absl::StrCat("*", tensor_id_.ToString())
: tensor_id_.ToString();
}
Kind kind() const override { return Kind::kSymbol; }
- gtl::ArraySlice<Predicate*> GetOperands() const override { return {}; }
+ absl::Span<Predicate* const> GetOperands() const override { return {}; }
// If `must_be_true()` is true this SymbolPredicate represents the proposition
// "tensor_id() is live and evaluates to true".
@@ -312,11 +321,11 @@ template <typename FunctionTy>
// them.
class PredicateFactory {
public:
- Predicate* MakeAndPredicate(gtl::ArraySlice<Predicate*> operands) {
+ Predicate* MakeAndPredicate(absl::Span<Predicate* const> operands) {
return MakeAndOrImpl(operands, /*is_and=*/true);
}
- Predicate* MakeOrPredicate(gtl::ArraySlice<Predicate*> operands) {
+ Predicate* MakeOrPredicate(absl::Span<Predicate* const> operands) {
return MakeAndOrImpl(operands, /*is_and=*/false);
}
@@ -373,7 +382,7 @@ class PredicateFactory {
new PredicateT(std::forward<Args>(args)...));
}
- Predicate* MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands, bool is_and);
+ Predicate* MakeAndOrImpl(absl::Span<Predicate* const> operands, bool is_and);
// Predicate instances are interned, meaning that there is only a single
// instance of a Predicate object with a given content. This makes checking
@@ -386,7 +395,7 @@ class PredicateFactory {
// for the owning pointers to predicate instances.
using SignatureForAndOr =
- std::pair<Predicate::Kind, gtl::ArraySlice<Predicate*>>;
+ std::pair<Predicate::Kind, absl::Span<Predicate* const>>;
using SignatureForNot = Predicate*;
using SignatureForAndRec = std::pair<Predicate*, Predicate*>;
using SignatureForSymbol = std::pair<SafeTensorId, bool>;
@@ -421,8 +430,8 @@ class PredicateFactory {
};
// Common code to create AndPredicate or OrPredicate instances.
-Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands,
- bool is_and) {
+Predicate* PredicateFactory::MakeAndOrImpl(
+ absl::Span<Predicate* const> operands, bool is_and) {
Predicate::Kind pred_kind =
is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;
gtl::FlatSet<Predicate*> simplified_ops_set;
@@ -473,7 +482,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands,
// NB! Because we'll use a non-owning reference to simplified_ops in the
// key for interned_and_or_instances_ we need to be careful to std::move()
// it all the way through.
- gtl::ArraySlice<Predicate*> operands_slice = simplified_ops;
+ absl::Span<Predicate* const> operands_slice = simplified_ops;
std::unique_ptr<Predicate> new_pred =
is_and ? Make<AndPredicate>(std::move(simplified_ops))
: Make<OrPredicate>(std::move(simplified_ops));
@@ -495,7 +504,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
: graph_(*graph), vlog_(VLOG_IS_ON(2)) {}
Status Populate();
- Status PopulateWithReversePostOrder(gtl::ArraySlice<Node*> rpo);
+ Status PopulateWithReversePostOrder(absl::Span<Node* const> rpo);
bool HasInputsWithMismatchingDeadness(const Node& node) override;
void Print() const override;
gtl::FlatMap<TensorId, string, TensorId::Hasher> PredicateMapAsString() const;
@@ -508,8 +517,8 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
// Sets the predicate for output `output_idx` of `n` to `pred`. Sets the i'th
// bit of `should_revisit` if `pred` is different from the current predicate
// for the `output_idx` output of `n`.
- void SetPred(Node* n, int output_idx, Predicate* pred,
- std::vector<bool>* should_revisit) {
+ void SetPredicate(Node* n, int output_idx, Predicate* pred,
+ std::vector<bool>* should_revisit) {
auto insert_result =
predicate_map_.insert({TensorId(n->name(), output_idx), pred});
if (!insert_result.second && insert_result.first->second != pred) {
@@ -526,10 +535,10 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
}
}
- void SetPred(Node* n, gtl::ArraySlice<int> output_idxs, Predicate* pred,
- std::vector<bool>* should_revisit) {
+ void SetPredicate(Node* n, absl::Span<const int> output_idxs, Predicate* pred,
+ std::vector<bool>* should_revisit) {
for (int output_idx : output_idxs) {
- SetPred(n, output_idx, pred, should_revisit);
+ SetPredicate(n, output_idx, pred, should_revisit);
}
}
@@ -580,19 +589,20 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n,
// Output 0 is alive iff all inputs are alive and the condition is false.
input_preds.push_back(false_switch);
- SetPred(n, 0, predicate_factory_.MakeAndPredicate(input_preds),
- should_revisit);
+ SetPredicate(n, 0, predicate_factory_.MakeAndPredicate(input_preds),
+ should_revisit);
input_preds.pop_back();
// Output 1 is alive iff all inputs are alive and the condition is true.
input_preds.push_back(true_switch);
- SetPred(n, 1, predicate_factory_.MakeAndPredicate(input_preds),
- should_revisit);
+ SetPredicate(n, 1, predicate_factory_.MakeAndPredicate(input_preds),
+ should_revisit);
input_preds.pop_back();
// Control is alive iff all inputs are alive.
- SetPred(n, Graph::kControlSlot,
- predicate_factory_.MakeAndPredicate(input_preds), should_revisit);
+ SetPredicate(n, Graph::kControlSlot,
+ predicate_factory_.MakeAndPredicate(input_preds),
+ should_revisit);
return Status::OK();
}
@@ -623,7 +633,7 @@ Predicate* DeduceStepPredicate(PredicateFactory* predicate_factory,
}
std::vector<Predicate*> and_ops;
- gtl::ArraySlice<Predicate*> recurrent_pred_ops =
+ absl::Span<Predicate* const> recurrent_pred_ops =
backedge_predicate->GetOperands();
bool found_sym = false;
@@ -682,14 +692,16 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
// backedge.
Predicate* input_data_pred = predicate_factory_.MakeSymbolPredicate(
TensorId(n->name(), 0), /*must_be_true=*/false);
- SetPred(n, {0, 1, Graph::kControlSlot}, input_data_pred, should_revisit);
+ SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
+ should_revisit);
return Status::OK();
}
  // We're visiting this merge for the first time and it is an acyclic merge.
Predicate* input_data_pred = predicate_factory_.MakeOrPredicate(
GetIncomingPreds(n, EdgeKind::kDataOnly));
- SetPred(n, {0, 1, Graph::kControlSlot}, input_data_pred, should_revisit);
+ SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
+ should_revisit);
return Status::OK();
}
@@ -717,7 +729,7 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
predicate_factory_.MakeOrPredicate(non_recurrent_inputs);
Predicate* and_rec =
predicate_factory_.MakeAndRecurrencePredicate(start, step);
- SetPred(n, {0, 1, Graph::kControlSlot}, and_rec, should_revisit);
+ SetPredicate(n, {0, 1, Graph::kControlSlot}, and_rec, should_revisit);
return Status::OK();
}
}
@@ -733,8 +745,9 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n,
GetIncomingPreds(n, EdgeKind::kDataAndControl);
input_preds.push_back(predicate_factory_.MakeSymbolPredicate(
TensorId(n->name(), 0), /*must_be_true=*/false));
- SetPred(n, {0, Graph::kControlSlot},
- predicate_factory_.MakeAndPredicate(input_preds), should_revisit);
+ SetPredicate(n, {0, Graph::kControlSlot},
+ predicate_factory_.MakeAndPredicate(input_preds),
+ should_revisit);
return Status::OK();
}
@@ -744,9 +757,9 @@ Status DeadnessAnalysisImpl::HandleGeneric(Node* n,
Predicate* pred = predicate_factory_.MakeAndPredicate(
GetIncomingPreds(n, EdgeKind::kDataAndControl));
for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) {
- SetPred(n, output_idx, pred, should_revisit);
+ SetPredicate(n, output_idx, pred, should_revisit);
}
- SetPred(n, Graph::kControlSlot, pred, should_revisit);
+ SetPredicate(n, Graph::kControlSlot, pred, should_revisit);
return Status::OK();
}
@@ -757,7 +770,8 @@ Status DeadnessAnalysisImpl::HandleNode(Node* n,
} else if (n->IsMerge()) {
TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit));
} else if (n->IsControlTrigger()) {
- SetPred(n, Graph::kControlSlot, predicate_factory_.MakeTrue(), nullptr);
+ SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(),
+ nullptr);
} else if (n->IsRecv() || n->IsHostRecv()) {
TF_RETURN_IF_ERROR(HandleRecv(n, should_revisit));
} else if (n->IsNextIteration()) {
@@ -770,7 +784,7 @@ Status DeadnessAnalysisImpl::HandleNode(Node* n,
Status DeadnessAnalysisImpl::Populate() {
std::vector<Node*> rpo;
- GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/{},
+ GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/NodeComparatorName(),
/*edge_filter=*/[](const Edge& edge) {
return !edge.src()->IsNextIteration();
});
@@ -778,7 +792,7 @@ Status DeadnessAnalysisImpl::Populate() {
}
Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
- gtl::ArraySlice<Node*> rpo) {
+ absl::Span<Node* const> rpo) {
  // This is an abstract interpretation over the deadness propagation semantics of
// the graph executor.
//
@@ -918,7 +932,7 @@ Status ComputePredicates(const Graph& graph,
}
Status ComputePredicates(const Graph& graph,
- gtl::ArraySlice<Node*> reverse_post_order,
+ absl::Span<Node* const> reverse_post_order,
PredicateMapTy* out_predicate_map) {
DeadnessAnalysisImpl impl(&graph);
TF_RETURN_IF_ERROR(impl.PopulateWithReversePostOrder(reverse_post_order));
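
Throughout deadness_analysis.cc, gtl::ArraySlice<Predicate*> becomes absl::Span<Predicate* const>: a non-owning view over a sequence of pointers whose slots cannot be reassigned through the view. The hunk also renames SetPred to SetPredicate and makes the reverse post-order deterministic by passing NodeComparatorName() as the stable comparator. A minimal sketch of the Span type:

    #include <vector>
    #include "absl/types/span.h"

    struct Predicate {};

    int CountOperands(absl::Span<Predicate* const> operands) {
      // operands[i] = nullptr; would not compile: the elements are const.
      return static_cast<int>(operands.size());
    }

    void Demo() {
      std::vector<Predicate*> ops(3, nullptr);
      CountOperands(ops);  // implicit, copy-free conversion from the vector
      CountOperands({});   // empty view, as SymbolPredicate::GetOperands returns
    }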
diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h
index 401d6e406a..3df2679c62 100644
--- a/tensorflow/compiler/jit/deadness_analysis_internal.h
+++ b/tensorflow/compiler/jit/deadness_analysis_internal.h
@@ -32,7 +32,7 @@ Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map);
// specified in `reverse_post_order` which must be a valid RPO for the graph
// minus NextIteration->Merge edges.
Status ComputePredicates(const Graph& graph,
- gtl::ArraySlice<Node*> reverse_post_order,
+ absl::Span<Node* const> reverse_post_order,
PredicateMapTy* out_predicate_map);
} // namespace deadness_analysis_internal
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc
index cc9f102398..28a56044d5 100644
--- a/tensorflow/compiler/jit/deadness_analysis_test.cc
+++ b/tensorflow/compiler/jit/deadness_analysis_test.cc
@@ -32,7 +32,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index f150bf1819..e0632ff7e4 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -22,6 +22,8 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
@@ -36,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph.h"
@@ -44,8 +47,6 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"
@@ -58,6 +59,22 @@ const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";
const char* const kXlaHostTransferSequencerAttr =
"_xla_host_transfer_sequencer";
+void SortControlInputs(GraphDef* gdef) {
+ int64 num_nodes = gdef->node_size();
+ for (int64 i = 0; i < num_nodes; ++i) {
+ NodeDef* node = gdef->mutable_node(i);
+ // Stable-sorts the control inputs by name; the order of data inputs is unchanged.
+ std::stable_sort(node->mutable_input()->begin(),
+ node->mutable_input()->end(),
+ [](const string& a, const string& b) {
+ bool a_is_control = absl::StartsWith(a, "^");
+ bool b_is_control = absl::StartsWith(b, "^");
+ return (!a_is_control && b_is_control) ||
+ (a_is_control && b_is_control && a < b);
+ });
+ }
+}
+
namespace {
bool AreAllParentsGuaranteedConst(
@@ -755,7 +772,7 @@ Status Encapsulator::Subgraph::RecordArg(
if (inserted) {
NodeDef arg_def;
NodeDefBuilder builder(
- strings::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp);
+ absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp);
DataType dtype = edge->dst()->input_type(edge->dst_input());
builder.Attr("T", dtype);
builder.Attr("index", arg_index);
@@ -790,7 +807,7 @@ Status Encapsulator::Subgraph::RecordResult(
if (inserted) {
NodeDef ret_def;
NodeDefBuilder builder(
- strings::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp);
+ absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp);
DataType dtype = src_node->output_type(src_slot);
builder.Attr("T", dtype);
builder.Attr("index", ret_index);
@@ -950,16 +967,15 @@ Status Encapsulator::Subgraph::AddHostComputes(
}
NodeDef host_compute_def;
- NodeDefBuilder builder(strings::StrCat("outside_compilation_",
- oc_subgraph_name, "_host_compute"),
+ NodeDefBuilder builder(absl::StrCat("outside_compilation_",
+ oc_subgraph_name, "_host_compute"),
kHostComputeOp);
builder.Input(inputs);
builder.Attr("Tinputs", input_dtypes);
builder.Attr("Toutputs", output_dtypes);
builder.Attr("ancestors", host_compute_ancestors);
- builder.Attr("key",
- strings::StrCat("host_compute_channel_", subgraph_name, "_",
- oc_subgraph_name));
+ builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name,
+ "_", oc_subgraph_name));
builder.Attr("_outside_compilation_subgraph", oc_subgraph_name);
Status s = builder.Finalize(&host_compute_def);
if (!s.ok()) return s;
@@ -1017,8 +1033,7 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name,
Graph* graph_out) {
if (sequencer_ == nullptr) {
NodeDef seq_def;
- NodeDefBuilder builder(strings::StrCat(subgraph_name, "_sequencer"),
- "NoOp");
+ NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp");
builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name);
builder.Device(device_);
Status s = builder.Finalize(&seq_def);
@@ -1091,10 +1106,10 @@ Status Encapsulator::Subgraph::BuildFunctionDef(
if (VLOG_IS_ON(1)) {
VLOG(2) << "Build function def " << name;
- dump_graph::DumpGraphToFile(
- strings::StrCat("encapsulate_fdef_graph_", name), *graph_, library);
- dump_graph::DumpFunctionDefToFile(
- strings::StrCat("encapsulate_fdef_", name), fdef);
+ dump_graph::DumpGraphToFile(absl::StrCat("encapsulate_fdef_graph_", name),
+ *graph_, library);
+ dump_graph::DumpFunctionDefToFile(absl::StrCat("encapsulate_fdef_", name),
+ fdef);
}
if (!reuse_existing_functions || library->Find(name) == nullptr) {
@@ -1130,8 +1145,8 @@ Status Encapsulator::Subgraph::AddShapeInferenceInfo(
host_compute->AddAttr("shapes", shapes);
} else {
string inference_graph_name =
- strings::StrCat("_outside_compilation_shape_inference_", subgraph_name,
- "_", outside_compilation_subgraph_name);
+ absl::StrCat("_outside_compilation_shape_inference_", subgraph_name,
+ "_", outside_compilation_subgraph_name);
FunctionDef fdef;
TF_RETURN_IF_ERROR(
GraphToFunctionDef(*inference_graph, inference_graph_name, &fdef));
@@ -1155,10 +1170,10 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef(
if (VLOG_IS_ON(1)) {
VLOG(2) << "Replace function def " << name;
dump_graph::DumpGraphToFile(
- strings::StrCat("replace_encapsulate_fdef_graph_", name), *graph_,
+ absl::StrCat("replace_encapsulate_fdef_graph_", name), *graph_,
library);
dump_graph::DumpFunctionDefToFile(
- strings::StrCat("replace_encapsulate_fdef_", name), fdef);
+ absl::StrCat("replace_encapsulate_fdef_", name), fdef);
}
TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
@@ -1186,8 +1201,7 @@ Status Encapsulator::Subgraph::AddHostComputeKeyPlaceholder(
GraphDefBuilder::Options options(graph_out, /*status=*/nullptr);
NodeDef key_def;
NodeDefBuilder builder(
- strings::StrCat(call_node_def_.name(), "_key_placeholder"),
- "Placeholder");
+ absl::StrCat(call_node_def_.name(), "_key_placeholder"), "Placeholder");
builder.Attr("dtype", DT_STRING);
builder.Attr("shape", shape_proto);
builder.Attr("_host_compute_call_node", call_node_def_.name());
@@ -1221,16 +1235,16 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode(
}
NodeDef recv_def;
- NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
- "_", oc_subgraph_name, "_recv"),
+ NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name,
+ "_", oc_subgraph_name, "_recv"),
kRecvAtHostOp);
builder.Device(device_);
builder.Attr("Toutputs", dtypes);
// The correct device_ordinal will be inserted during replication in a
// subsequent rewrite.
builder.Attr("device_ordinal", 0);
- builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
- "_", oc_subgraph_name));
+ builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, "_",
+ oc_subgraph_name));
builder.Attr(group_attribute, subgraph_name);
builder.Attr(outside_compilation_attribute, oc_subgraph_name);
builder.Input(host_compute_key_placeholder_->name(), 0, DT_STRING);
@@ -1276,13 +1290,13 @@ Status Encapsulator::Subgraph::AddSendFromHostNode(
}
NodeDef send_def;
- NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
- "_", oc_subgraph_name, "_send"),
+ NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name,
+ "_", oc_subgraph_name, "_send"),
kSendFromHostOp);
builder.Device(device_);
builder.Attr("Tinputs", dtypes);
- builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
- "_", oc_subgraph_name));
+ builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, "_",
+ oc_subgraph_name));
// The correct device_ordinal will be inserted during replication in a
// subsequent rewrite.
builder.Attr("device_ordinal", 0);
@@ -1516,7 +1530,7 @@ Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) {
// Dump subgraphs.
for (auto& entry : subgraphs_) {
dump_graph::DumpGraphToFile(
- strings::StrCat("encapsulate_subgraphs_subgraph_", entry.first),
+ absl::StrCat("encapsulate_subgraphs_subgraph_", entry.first),
*entry.second.GetGraph(), library);
}
}
@@ -2052,7 +2066,7 @@ struct PathDetails {
struct SubgraphAndClusterHash {
inline std::size_t operator()(const SubgraphAndCluster& v) const {
return hash<string>()(
- strings::StrCat(v.subgraph, v.outside_compilation_cluster));
+ absl::StrCat(v.subgraph, v.outside_compilation_cluster));
}
};
@@ -2504,7 +2518,8 @@ Status EncapsulateSubgraphsPass::Run(
const int num_args = input_permutation->size();
std::vector<bool> const_args(num_args);
- TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args));
+ TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
+ **subgraph, &const_args, /*compile_time_const_nodes=*/nullptr));
DataTypeVector arg_types(num_args);
TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types));
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
index 926589546f..90354a801a 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
@@ -102,6 +102,12 @@ extern const char* const kXlaNumConstantArgsAttr;
// Name of the attribute containing the number of resource variable arguments.
extern const char* const kXlaNumResourceArgsAttr;
+// Sorts each node's control inputs by their names. This guarantees that for two
+// structurally equivalent GraphDefs, we get the same traversal ordering on
+// each node's control input fields.
+// TODO(hpucha): Move the utilities to a more appropriate place.
+void SortControlInputs(GraphDef* gdef);
+
class EncapsulateSubgraphsPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
index c0543a0079..49958093b8 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
@@ -16,8 +16,10 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
@@ -25,7 +27,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/equal_graph_def.h"
@@ -48,7 +49,7 @@ Status AddGraphDefToFunctionLibrary(const GraphDefBuilder& graphdef_builder,
FunctionDef* fdef = library->add_function();
TF_RETURN_IF_ERROR(GraphToFunctionDef(
*graph,
- strings::StrCat("_outside_compilation_shape_inference_", name_suffix),
+ absl::StrCat("_outside_compilation_shape_inference_", name_suffix),
fdef));
return Status::OK();
}
@@ -65,18 +66,18 @@ bool EqualProtoMap(const ::tensorflow::protobuf::Map<Tkey, Tvalue>& a,
const auto iter = b.find(elt_a.first);
if (iter == b.end()) {
if (diff) {
- *diff = strings::StrCat(
- map_name, " expected: contains element with key '",
- key_to_string(elt_a.first), "' got: map has no such element");
+ *diff = absl::StrCat(map_name, " expected: contains element with key '",
+ key_to_string(elt_a.first),
+ "' got: map has no such element");
}
return false;
}
if (!compare(elt_a.first, elt_a.second, iter->second)) {
if (diff) {
- *diff = strings::StrCat(map_name, " expected: element with key '",
- key_to_string(elt_a.first), "' has value '",
- value_to_string(elt_a.second), "' got: '",
- value_to_string(iter->second), "'");
+ *diff = absl::StrCat(map_name, " expected: element with key '",
+ key_to_string(elt_a.first), "' has value '",
+ value_to_string(elt_a.second), "' got: '",
+ value_to_string(iter->second), "'");
}
return false;
}
@@ -85,9 +86,9 @@ bool EqualProtoMap(const ::tensorflow::protobuf::Map<Tkey, Tvalue>& a,
const auto iter = a.find(elt_b.first);
if (iter == a.end()) {
if (diff) {
- *diff = strings::StrCat(map_name, " got: contains element with key '",
- key_to_string(elt_b.first),
- "' expected: map has no such element");
+ *diff = absl::StrCat(map_name, " got: contains element with key '",
+ key_to_string(elt_b.first),
+ "' expected: map has no such element");
}
return false;
}
@@ -99,38 +100,38 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
const string& diff_preamble, string* diff) {
if (a.op() != b.op()) {
if (diff) {
- *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
- ", expected op '", a.op(), "' got '", b.op());
+ *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
+ ", expected op '", a.op(), "' got '", b.op());
}
return false;
}
if (a.device() != b.device()) {
if (diff) {
- *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
- ", expected device '", a.device(), "' got '",
- b.device());
+ *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
+ ", expected device '", a.device(), "' got '",
+ b.device());
}
return false;
}
if (a.input_size() != b.input_size()) {
if (diff) {
- *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
- ", expected ", a.input_size(), " inputs got ",
- b.input_size(), " expected:\n", a.DebugString(),
- "\ngot:\n", b.DebugString());
+ *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
+ ", expected ", a.input_size(), " inputs got ",
+ b.input_size(), " expected:\n", a.DebugString(),
+ "\ngot:\n", b.DebugString());
}
return false;
}
std::unordered_set<string> control_input_a;
std::unordered_set<string> control_input_b;
for (int i = 0; i < a.input_size(); ++i) {
- if (str_util::StartsWith(a.input(i), "^")) {
- if (!str_util::StartsWith(b.input(i), "^")) {
+ if (absl::StartsWith(a.input(i), "^")) {
+ if (!absl::StartsWith(b.input(i), "^")) {
if (diff) {
- *diff = strings::StrCat(
- diff_preamble, " mismatch for node ", a.name(), " input ", i,
- ", expected control input ", a.input(i), " got ", b.input(i),
- " expected:\n", a.DebugString(), "\ngot:\n", b.DebugString());
+ *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
+ " input ", i, ", expected control input ",
+ a.input(i), " got ", b.input(i), " expected:\n",
+ a.DebugString(), "\ngot:\n", b.DebugString());
}
return false;
}
@@ -138,19 +139,19 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
control_input_b.insert(b.input(i));
} else if (a.input(i) != b.input(i)) {
if (diff) {
- *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
- " input ", i, ", expected ", a.input(i),
- " got ", b.input(i), " expected:\n",
- a.DebugString(), "\ngot:\n", b.DebugString());
+ *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
+ " input ", i, ", expected ", a.input(i), " got ",
+ b.input(i), " expected:\n", a.DebugString(),
+ "\ngot:\n", b.DebugString());
}
return false;
}
}
if (control_input_a != control_input_b) {
if (diff) {
- *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
- " control inputs differ expected:\n",
- a.DebugString(), "\ngot:\n", b.DebugString());
+ *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
+ " control inputs differ expected:\n",
+ a.DebugString(), "\ngot:\n", b.DebugString());
}
return false;
}
@@ -170,18 +171,17 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
return av.DebugString() == bv.DebugString();
}
},
- strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()),
- diff);
+ absl::StrCat(diff_preamble, " attr mismatch for node ", a.name()), diff);
}
bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
string* diff) {
if (a.signature().DebugString() != b.signature().DebugString()) {
if (diff) {
- *diff = strings::StrCat("Signature mismatch for function ",
- a.signature().name(), ", expected:\n",
- a.signature().DebugString(), "\ngot:\n",
- b.signature().DebugString());
+ *diff =
+ absl::StrCat("Signature mismatch for function ", a.signature().name(),
+ ", expected:\n", a.signature().DebugString(), "\ngot:\n",
+ b.signature().DebugString());
}
return false;
}
@@ -191,7 +191,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
[](const string& key, const AttrValue& av, const AttrValue& bv) {
return av.DebugString() == bv.DebugString();
},
- strings::StrCat("attr mismatch for function ", a.signature().name()),
+ absl::StrCat("attr mismatch for function ", a.signature().name()),
diff)) {
return false;
}
@@ -201,7 +201,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
[](const string& key, const string& av, const string& bv) {
return av == bv;
},
- strings::StrCat("ret mismatch for function ", a.signature().name()),
+ absl::StrCat("ret mismatch for function ", a.signature().name()),
diff)) {
return false;
}
@@ -211,7 +211,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
if (a.node_def(i).name() == b.node_def(j).name()) {
if (!EqualFunctionNodeDef(
a.node_def(i), b.node_def(j),
- strings::StrCat("Function ", a.signature().name()), diff)) {
+ absl::StrCat("Function ", a.signature().name()), diff)) {
return false;
}
found = true;
@@ -220,9 +220,9 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
}
if (!found) {
if (diff) {
- *diff = strings::StrCat("Function ", a.signature().name(),
- ", expected: has node '", a.node_def(i).name(),
- "' got: no node of that name");
+ *diff = absl::StrCat("Function ", a.signature().name(),
+ ", expected: has node '", a.node_def(i).name(),
+ "' got: no node of that name");
}
return false;
}
@@ -237,9 +237,9 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
}
if (!found) {
if (diff) {
- *diff = strings::StrCat("Function ", a.signature().name(),
- ", got: has node '", b.node_def(i).name(),
- "' expected: no node of that name");
+ *diff = absl::StrCat("Function ", a.signature().name(),
+ ", got: has node '", b.node_def(i).name(),
+ "' expected: no node of that name");
}
return false;
}
@@ -258,8 +258,8 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected,
auto it = actual_index.find(expected_function.signature().name());
if (it == actual_index.end()) {
if (diff) {
- *diff = strings::StrCat("Did not find expected function '",
- expected_function.signature().name(), "'");
+ *diff = absl::StrCat("Did not find expected function '",
+ expected_function.signature().name(), "'");
}
return false;
}
@@ -269,9 +269,9 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected,
if (!actual_index.empty()) {
if (diff != nullptr) {
- *diff = strings::StrCat("Found unexpected function '",
- actual_index.begin()->second->signature().name(),
- "'");
+ *diff =
+ absl::StrCat("Found unexpected function '",
+ actual_index.begin()->second->signature().name(), "'");
}
return false;
}
@@ -379,7 +379,7 @@ Node* InputShaped(const GraphDefBuilder::Options& opts) {
return ops::SourceOp("InputTestShaped", opts);
}
-Node* KnownShapeBase(DataType dtype, const gtl::ArraySlice<int>& shape,
+Node* KnownShapeBase(DataType dtype, absl::Span<const int> shape,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const",
@@ -394,7 +394,7 @@ Node* KnownShapeBase(DataType dtype, const gtl::ArraySlice<int>& shape,
.FinalizeBuilder(&node_builder);
}
-Node* KnownShape(const gtl::ArraySlice<int>& shape,
+Node* KnownShape(absl::Span<const int> shape,
const GraphDefBuilder::Options& opts) {
return KnownShapeBase(DT_FLOAT, shape, opts);
}
@@ -417,14 +417,12 @@ Node* KeyPlaceholder(const string& call_node,
}
Node* RecvAtHost(ops::NodeOut key_input, const string& cluster,
- const string& oc_cluster,
- const gtl::ArraySlice<DataType>& dtypes,
+ const string& oc_cluster, absl::Span<const DataType> dtypes,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
- string key =
- strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster);
- string name = strings::StrCat("outside_compilation_", cluster, "_",
- oc_cluster, "_recv");
+ string key = absl::StrCat("host_compute_channel_", cluster, "_", oc_cluster);
+ string name =
+ absl::StrCat("outside_compilation_", cluster, "_", oc_cluster, "_recv");
NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaRecvAtHost"),
"_XlaRecvAtHost", opts.op_registry());
node_builder.Input(std::move(key_input));
@@ -441,10 +439,9 @@ Node* SendFromHost(ops::NodeOut key_input, const string& cluster,
const std::vector<ops::NodeOut>& inputs,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
- string key =
- strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster);
- string name = strings::StrCat("outside_compilation_", cluster, "_",
- oc_cluster, "_send");
+ string key = absl::StrCat("host_compute_channel_", cluster, "_", oc_cluster);
+ string name =
+ absl::StrCat("outside_compilation_", cluster, "_", oc_cluster, "_send");
NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaSendFromHost"),
"_XlaSendFromHost", opts.op_registry());
node_builder.Input(inputs);
@@ -683,8 +680,8 @@ std::vector<std::pair<string, string>> GraphEdges(const Graph& graph) {
for (const Edge* edge : graph.edges()) {
if (edge->src()->IsSource() || edge->dst()->IsSink()) continue;
edges.emplace_back(
- strings::StrCat(edge->src()->name(), ":", edge->src_output()),
- strings::StrCat(edge->dst()->name(), ":", edge->dst_input()));
+ absl::StrCat(edge->src()->name(), ":", edge->src_output()),
+ absl::StrCat(edge->dst()->name(), ":", edge->dst_input()));
}
std::sort(edges.begin(), edges.end());
return edges;
@@ -768,7 +765,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
Graph* graph = graph_ptr->get();
for (const Node* n : graph->nodes()) {
if (n->type_string() == "_Arg" &&
- str_util::StartsWith(n->name(), "const")) {
+ absl::StartsWith(n->name(), "const")) {
++guaranteed_consts;
EXPECT_TRUE(HasGuaranteeConstAttr(*n));
} else {
@@ -813,7 +810,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
Graph* graph = graph_ptr->get();
for (const Node* n : graph->nodes()) {
if (n->type_string() == "_Arg" &&
- str_util::StartsWith(n->name(), "const")) {
+ absl::StartsWith(n->name(), "const")) {
++guaranteed_consts;
EXPECT_TRUE(HasGuaranteeConstAttr(*n));
} else {
@@ -892,13 +889,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"C:o:0", "c:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
- {"shapes", gtl::ArraySlice<DataType>({})},
+ {"shapes", absl::Span<const DataType>({})},
{"_outside_compilation_subgraph", "O1"}},
{"c"}},
},
@@ -1038,26 +1035,26 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
{{"outside_compilation_O2_host_compute"},
"XlaHostCompute",
{"F:o:0", "D:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
{"ancestors",
- gtl::ArraySlice<string>({"outside_compilation_O1_host_compute"})},
+ absl::Span<const string>({"outside_compilation_O1_host_compute"})},
{"key", "host_compute_channel_F1_O2"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O2"},
- {"shapes", gtl::ArraySlice<DataType>({})},
+ {"shapes", absl::Span<const DataType>({})},
{"_outside_compilation_subgraph", "O2"}},
{"F", "outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"C:o:0", "D:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
- {"shapes", gtl::ArraySlice<DataType>({})},
+ {"shapes", absl::Span<const DataType>({})},
{"_outside_compilation_subgraph", "O1"}},
{"D"}},
},
@@ -1190,13 +1187,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"C:o:0", "D:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
- {"shapes", gtl::ArraySlice<DataType>({})},
+ {"shapes", absl::Span<const DataType>({})},
{"_outside_compilation_subgraph", "O1"}},
{"D"}},
},
@@ -1213,13 +1210,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"G:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F2_O1"},
{"shape_inference_graph", ""},
{"shapes",
- gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})},
+ absl::Span<const TensorShapeProto>({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}}},
},
{{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}});
@@ -1364,13 +1361,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"C:o:0", "D:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
- {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+ {"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O1"}},
{"D"}},
},
@@ -1386,13 +1383,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"G:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F2_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F2_O1"},
- {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+ {"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O1"}}},
},
{{"i_0_retval", "I:o:0"}});
@@ -1495,13 +1492,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{},
- {{"Tinputs", gtl::ArraySlice<DataType>({})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", ""},
{"shapes",
- gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})},
+ absl::Span<const TensorShapeProto>({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}}},
},
{{"f_0_retval", "F:o:0"}});
@@ -1579,13 +1576,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{},
- {{"Tinputs", gtl::ArraySlice<DataType>({})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", ""},
{"shapes",
- gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})},
+ absl::Span<const TensorShapeProto>({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"}},
{"D"}},
},
@@ -1661,12 +1658,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"D:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", ""},
- {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+ {"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O1"}}},
},
{{"f_0_retval", "F:o:0"}});
@@ -1742,12 +1739,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"D:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", ""},
- {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+ {"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O1"}}},
},
{{"f_0_retval", "F:o:0"}});
@@ -1846,13 +1843,13 @@ TEST(EncapsulateSubgraphsTest,
{{"outside_compilation_O2_host_compute"},
"XlaHostCompute",
{"F:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O2"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O2"},
- {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+ {"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O2"}}},
},
{{"h_0_retval", "H:o:0"}});
@@ -1955,13 +1952,13 @@ TEST(EncapsulateSubgraphsTest,
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"D:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
- {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+ {"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O1"}}},
},
{{"h_0_retval", "H:o:0"}});
@@ -2066,37 +2063,37 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"D:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
- {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+ {"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O1"}}},
{{"outside_compilation_O2_host_compute"},
"XlaHostCompute",
{"D:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({})},
{"ancestors",
- gtl::ArraySlice<string>({"outside_compilation_O1_host_compute"})},
+ absl::Span<const string>({"outside_compilation_O1_host_compute"})},
{"key", "host_compute_channel_F1_O2"},
{"shape_inference_graph", ""},
- {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+ {"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O2"}},
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O3_host_compute"},
"XlaHostCompute",
{"D:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({})},
{"ancestors",
- gtl::ArraySlice<string>({"outside_compilation_O1_host_compute",
- "outside_compilation_O2_host_compute"})},
+ absl::Span<const string>({"outside_compilation_O1_host_compute",
+ "outside_compilation_O2_host_compute"})},
{"key", "host_compute_channel_F1_O3"},
{"shape_inference_graph", ""},
- {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+ {"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O3"}},
{"outside_compilation_O1_host_compute",
"outside_compilation_O2_host_compute"}}},
@@ -2272,13 +2269,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"c:o:0"},
- {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
- {"ancestors", gtl::ArraySlice<string>({})},
+ {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+ {"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph",
"_outside_compilation_shape_inference_F1_O1"},
- {"shapes", gtl::ArraySlice<DataType>({})},
+ {"shapes", absl::Span<const DataType>({})},
{"_outside_compilation_subgraph", "O1"}},
{"c"}},
},
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
new file mode 100644
index 0000000000..97ef8cd3cb
--- /dev/null
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
@@ -0,0 +1,360 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/proto_serialization.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/fingerprint.h"
+
+namespace tensorflow {
+
+const char* const EncapsulateXlaComputationsPass::kXlaClusterAttr =
+ "_xla_compile_id";
+
+namespace {
+
+const char* const kXlaClusterOutput = "XlaClusterOutput";
+
+// Checks whether a graph node is marked as a guaranteed constant.
+bool is_guaranteed_constant(const Node& n) {
+ bool guaranteed_constant = false;
+ if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", &guaranteed_constant)
+ .ok()) {
+ return false;
+ }
+ return guaranteed_constant;
+}
+
+// Finds the `index` of an _Arg or _Retval node.
+Status GetIndexAttr(const Node& n, int num_args, int* index) {
+ TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", index));
+ if (*index < 0 || *index >= num_args) {
+ return errors::InvalidArgument("Invalid ", n.type_string(), " number ",
+ *index);
+ }
+ return Status::OK();
+}
+
+// Returns the data type of the destination of an edge.
+DataType EdgeType(const Edge* edge) {
+ return edge->dst()->input_type(edge->dst_input());
+}
+
+// Adds the control inputs of `node` to `*deps`.
+void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) {
+ for (const Edge* edge : node.in_edges()) {
+ if (edge->IsControlEdge()) {
+ deps->insert(edge->src());
+ }
+ }
+}
+
+// Adds the control outputs of `node` to `*deps`.
+void AddControlOutputs(const Node& node, gtl::FlatSet<Node*>* deps) {
+ for (const Edge* edge : node.out_edges()) {
+ if (edge->IsControlEdge()) {
+ deps->insert(edge->dst());
+ }
+ }
+}
+
+// Rewrite function to be passed to EncapsulateSubgraphsInFunctions that sorts
+// the arguments into the order expected by XlaLaunch computations:
+// 1) arguments
+// 2) resource variable arguments
+// See the documentation of EncapsulateSubgraphsInFunctions for the meaning
+// of the arguments.
+//
+// TODO(b/113166435): Ordering constraints on XlaLaunch op can be relaxed.
+Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
+ std::unique_ptr<Graph>* graph_ptr,
+ std::vector<int>* input_permutation,
+ std::vector<int>* output_permutation,
+ NodeDef* call_def) {
+ Graph* graph = graph_ptr->get();
+ const int num_args = input_permutation->size();
+ const int num_retvals = output_permutation->size();
+
+ std::vector<Node*> args;
+ std::vector<Node*> retvals;
+ args.reserve(num_args);
+ retvals.reserve(num_retvals);
+ for (Node* n : graph->nodes()) {
+ if (n->type_string() == "_Arg") {
+ // Check if this is a guaranteed constant.
+ if (is_guaranteed_constant(*n)) {
+ return errors::InvalidArgument(
+ "Guaranteed constants are not supported (", n->name(), ")");
+ }
+ args.push_back(n);
+ } else if (n->type_string() == "_Retval") {
+ retvals.push_back(n);
+ }
+ }
+
+ if (std::find(args.begin(), args.end(), nullptr) != args.end()) {
+ return errors::InvalidArgument("Missing or non-consecutive arguments");
+ }
+
+ // Reorders the arguments.
+ std::sort(args.begin(), args.end(), [&](Node* a, Node* b) {
+ // Non-resources appear before resources
+ bool a_is_resource = (a->output_type(0) == DT_RESOURCE);
+ bool b_is_resource = (b->output_type(0) == DT_RESOURCE);
+ // Uses the name as a tiebreaker so the output is deterministic.
+ StringPiece a_name(a->name());
+ StringPiece b_name(b->name());
+ return std::tie(a_is_resource, a_name) < std::tie(b_is_resource, b_name);
+ });
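+ // For example, arguments {v_0_arg(DT_RESOURCE), a_0_arg, u_0_arg(DT_RESOURCE)}
+ // sort to {a_0_arg, u_0_arg, v_0_arg}: non-resource arguments first, then
+ // resources, each group ordered by name.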
+
+ // Sorts the retvals by name so the order is deterministic.
+ std::sort(retvals.begin(), retvals.end(),
+ [](Node* a, Node* b) { return a->name() < b->name(); });
+
+ // Computes the permutation to produce the correct argument order, and updates
+ // the argument indices.
+ int variable_start_index = num_args;
+ for (int i = 0; i < num_args; ++i) {
+ int index;
+ TF_RETURN_IF_ERROR(GetIndexAttr(*args[i], num_args, &index));
+ if (args[i]->output_type(0) == DT_RESOURCE &&
+ variable_start_index == num_args) {
+ variable_start_index = i;
+ }
+ (*input_permutation)[index] = i;
+ args[i]->AddAttr("index", i);
+ }
+ VLOG(4) << "variable_start_index: " << variable_start_index;
+
+ // Computes the permutation to produce the correct retval order, and updates
+ // the retval indices.
+ for (int i = 0; i < num_retvals; ++i) {
+ int index;
+ TF_RETURN_IF_ERROR(GetIndexAttr(*retvals[i], num_retvals, &index));
+ (*output_permutation)[index] = i;
+ retvals[i]->AddAttr("index", i);
+ }
+
+ AddNodeAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, call_def->name(),
+ call_def);
+ AddNodeAttr("_variable_start_index", variable_start_index, call_def);
+
+ // Uniquify the function name.
+ GraphDef gdef;
+ graph->ToGraphDef(&gdef);
+
+ // Before serialization, sort each node's control inputs to achieve
+ // determinism. Sorting control inputs helps, but does not by itself
+ // guarantee, a deterministic serialization and fingerprint. Other sources of
+ // nondeterminism include unstable node ordering.
+ SortControlInputs(&gdef);
+ // Fingerprint the function.
+ // Nondeterminism in serialization would not lead to incorrect results, but
+ // may cause spurious cache misses. DeterministicSerialization is a
+ // best-effort deterministic serialization.
+ string serialized;
+ TF_RET_CHECK(SerializeToStringDeterministic(gdef, &serialized));
+ uint64 fingerprint = Fingerprint64(serialized);
+ LOG(INFO) << "Subgraph fingerprint:" << fingerprint;
+ call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint));
+ return Status::OK();
+}
+
+} // namespace
+
+/*static*/ Status EncapsulateXlaComputationsPass::Encapsulate(
+ std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def) {
+ // Check for undeclared outputs before encapsulation, so we can give a better
+ // error message.
+ // TODO(phawkins): merge this with the encapsulation code to avoid the extra
+ // O(n) pass over the edges.
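+ // For example, a data edge from a node tagged with _xla_compile_id to an
+ // untagged consumer that is not an XlaClusterOutput is rejected here.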
+ for (const Edge* e : (*graph)->edges()) {
+ if (!e->IsControlEdge() &&
+ e->src()->attrs().Find(kXlaClusterAttr) != nullptr &&
+ e->dst()->attrs().Find(kXlaClusterAttr) == nullptr &&
+ e->dst()->type_string() != kXlaClusterOutput) {
+ return errors::InvalidArgument(
+ "Undeclared output of XLA computation. A common cause of this error "
+ "is variable initializers that depend on the XLA computation. Edge: ",
+ e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":",
+ e->dst_input());
+ }
+ }
+
+ auto output = absl::make_unique<Graph>((*graph)->op_registry());
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ EncapsulateSubgraphsInFunctions(
+ kXlaClusterAttr, "", **graph, RewriteSubgraph,
+ /*reuse_existing_functions=*/true, &output, flib_def),
+ "EncapsulateXlaComputationsPass failed");
+ graph->swap(output);
+ return Status::OK();
+}
+
+/*static*/ Status EncapsulateXlaComputationsPass::BuildXlaLaunchOps(
+ Graph* graph) {
+ // Finds all of the XlaLaunch function calls up front, so the graph is not
+ // mutated while it is being iterated over.
+ std::vector<Node*> launch_nodes;
+ for (Node* n : graph->nodes()) {
+ string name;
+ if (GetNodeAttr(n->attrs(), kXlaClusterAttr, &name).ok()) {
+ launch_nodes.push_back(n);
+ }
+ }
+
+ // Replaces each launch function call together with its neighboring
+ // XlaClusterOutput nodes with an XlaLaunch node.
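+ // For example, the subgraph
+ //   launch0 -> Out0 (XlaClusterOutput) -> consumer
+ // becomes
+ //   XlaLaunch -> consumer,
+ // with launch0 and Out0 deleted and their data and control edges reattached
+ // to the new XlaLaunch node.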
+ for (Node* launch : launch_nodes) {
+ int variable_start_index;
+ TF_RETURN_IF_ERROR(GetNodeAttr(launch->attrs(), "_variable_start_index",
+ &variable_start_index));
+
+ std::vector<const Edge*> in_edges;
+ TF_RETURN_IF_ERROR(launch->input_edges(&in_edges));
+
+ const int num_inputs = in_edges.size();
+ const int num_variables = num_inputs - variable_start_index;
+ const int num_args = variable_start_index;
+
+ VLOG(4) << "Launch node '" << launch->name() << "'"
+ << " input edges: " << in_edges.size() << " num_args: " << num_args
+ << " num_variables: " << num_variables;
+
+ std::vector<Node*> nodes_to_remove = {launch};
+
+ // Data and control inputs to the new XlaLaunch node.
+ std::vector<std::pair<Node*, int>> data_inputs(num_inputs);
+ gtl::FlatSet<Node*> control_inputs;
+ DataTypeVector arg_types(num_args);
+
+ AddControlInputs(*launch, &control_inputs);
+
+ for (int i = 0; i < num_args; ++i) {
+ const Edge* edge = in_edges[i];
+ data_inputs[i] = {edge->src(), edge->src_output()};
+ arg_types[i] = EdgeType(edge);
+ }
+
+ // Appends the variable inputs.
+ for (int i = 0; i < num_variables; ++i) {
+ int pos = variable_start_index + i;
+ const Edge* edge = in_edges[pos];
+ data_inputs[pos] = {edge->src(), edge->src_output()};
+ }
+
+ // Outputs.
+ const int num_outputs = launch->output_types().size();
+ gtl::FlatSet<Node*> control_outputs;
+ std::vector<std::vector<std::pair<Node*, int>>> data_outputs(num_outputs);
+ DataTypeVector output_types(num_outputs);
+
+ for (const Edge* le : launch->out_edges()) {
+ if (le->IsControlEdge()) {
+ control_outputs.insert(le->dst());
+ } else {
+ TF_RET_CHECK(le->src_output() < num_outputs);
+ Node* output_node = le->dst();
+
+ TF_RET_CHECK(output_node->type_string() == kXlaClusterOutput)
+ << le->DebugString();
+ nodes_to_remove.push_back(output_node);
+
+ for (const Edge* oe : output_node->out_edges()) {
+ TF_RET_CHECK(!oe->IsControlEdge());
+ data_outputs[le->src_output()].push_back(
+ {oe->dst(), oe->dst_input()});
+ }
+ output_types[le->src_output()] = output_node->input_type(0);
+
+ AddControlOutputs(*output_node, &control_outputs);
+ }
+ }
+
+ NodeDef def;
+ def.set_name(launch->name());
+
+ // Target the XLA CPU/GPU backends.
+ VLOG(2) << "Replacing with XlaLaunch";
+ def.set_op("XlaLaunch");
+ AddNodeAttr("Tconstants", DataTypeVector{}, &def);
+ AddNodeAttr("Targs", arg_types, &def);
+ AddNodeAttr("Nresources", num_variables, &def);
+ AddNodeAttr("Tresults", output_types, &def);
+ NameAttrList function;
+ function.set_name(launch->type_string());
+ AddNodeAttr("function", function, &def);
+
+ for (Node* node : nodes_to_remove) {
+ VLOG(2) << "Deleting node " << node->DebugString();
+ // Ensure that we do not attempt to add control edges to nodes that are
+ // deleted.
+ control_inputs.erase(node);
+ control_outputs.erase(node);
+ graph->RemoveNode(node);
+ }
+
+ Status status;
+ Node* xla_launch = graph->AddNode(def, &status);
+ if (!status.ok()) {
+ return status;
+ }
+ for (int i = 0; i < data_inputs.size(); ++i) {
+ graph->AddEdge(data_inputs[i].first, data_inputs[i].second, xla_launch,
+ i);
+ }
+ for (Node* n : control_inputs) {
+ graph->AddControlEdge(n, xla_launch);
+ }
+ for (int i = 0; i < data_outputs.size(); ++i) {
+ for (const auto& successor : data_outputs[i]) {
+ graph->AddEdge(xla_launch, i, successor.first, successor.second);
+ }
+ }
+ for (Node* n : control_outputs) {
+ graph->AddControlEdge(xla_launch, n);
+ }
+ }
+ return Status::OK();
+}
+
+Status EncapsulateXlaComputationsPass::Run(
+ const GraphOptimizationPassOptions& options) {
+ VLOG(1) << "EncapsulateXlaComputations(): "
+ << dump_graph::DumpGraphToFile("encapsulate_xla_computations_before",
+ **options.graph, options.flib_def);
+
+ TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def));
+ VLOG(1) << "EncapsulateXlaComputations() half-way: "
+ << dump_graph::DumpGraphToFile("encapsulate_xla_computations_halfway",
+ **options.graph, options.flib_def);
+
+ TF_RETURN_IF_ERROR(BuildXlaLaunchOps(options.graph->get()));
+ VLOG(1) << "EncapsulateXlaComputations() finished: "
+ << dump_graph::DumpGraphToFile("encapsulate_xla_computations_after",
+ **options.graph, options.flib_def);
+ return Status::OK();
+}
+
+} // namespace tensorflow
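The sort-serialize-fingerprint pattern in RewriteSubgraph above is self-contained and can be sketched in isolation; StableGraphFingerprint below is a hypothetical helper, not part of this patch:

    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/lib/strings/proto_serialization.h"
    #include "tensorflow/core/platform/fingerprint.h"

    namespace tensorflow {

    // Returns false if the best-effort deterministic serialization fails.
    // Assumes SortControlInputs() has already been applied to `gdef`.
    bool StableGraphFingerprint(const GraphDef& gdef, uint64* fingerprint) {
      string serialized;
      if (!SerializeToStringDeterministic(gdef, &serialized)) return false;
      *fingerprint = Fingerprint64(serialized);
      return true;
    }

    }  // namespace tensorflow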
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
new file mode 100644
index 0000000000..c8bb4dc114
--- /dev/null
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
@@ -0,0 +1,61 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Rewrites computations generated by the xla.compile() Python code into
+// XlaLaunch nodes.
+//
+// xla.compile() does two main things:
+// a) marks operators that make up an XLA computation with the attribute
+// _xla_compile_id=XYZ, where XYZ is a unique key.
+// b) adds XlaClusterOutput nodes to represent outputs of the computation.
+// These nodes are not marked with the _xla_compile_id attribute.
+
+#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
+#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
+
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+
+// Encapsulates nodes marked with the _xla_compile_id attribute into
+// XlaLaunch operators.
+class EncapsulateXlaComputationsPass : public GraphOptimizationPass {
+ public:
+ static const char* const kXlaClusterAttr; // _xla_compile_id
+
+ Status Run(const GraphOptimizationPassOptions& options) override;
+
+ // The following methods are public only for unit tests.
+
+ // This pass has two stages:
+ // a) first, we call EncapsulateSubgraphsPass to encapsulate all nodes
+ // marked with the same _xla_compile_id attribute into functions. These
+ // functions contain the computations to be passed to XlaLaunch. During
+ // encapsulation, we sort the arguments into the order expected by
+ // XlaLaunch.
+ static Status Encapsulate(std::unique_ptr<Graph>* graph,
+ FunctionLibraryDefinition* flib_def);
+
+ // b) we rewrite the function calls generated in phase (a) into XlaLaunch
+ // operators. We also convert the XlaClusterOutput output nodes of the
+ // function call into the outputs of the XlaLaunch operator.
+ static Status BuildXlaLaunchOps(Graph* graph);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
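For unit tests, the two stages can be driven directly, as the new test file below does; a minimal sketch, assuming the graph's nodes have been tagged with _xla_compile_id:

    FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
    std::unique_ptr<Graph> graph(new Graph(&flib_def));
    // ... build the graph and tag its nodes with _xla_compile_id ...
    TF_CHECK_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));
    TF_CHECK_OK(EncapsulateXlaComputationsPass::BuildXlaLaunchOps(graph.get()));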
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
new file mode 100644
index 0000000000..f643fb0cfe
--- /dev/null
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
@@ -0,0 +1,346 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
+
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h"
+#include "tensorflow/compiler/tf2xla/test_util.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/proto_serialization.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/equal_graph_def.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+
+static std::unique_ptr<Graph> MakeOuterGraph(
+ const FunctionLibraryDefinition& flib_def, const string& function) {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto()));
+
+ auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
+ auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
+ auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
+ auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
+ auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
+ auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
+ auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
+
+ NodeDef def;
+ TF_CHECK_OK(
+ NodeDefBuilder("launch0", function, &flib_def)
+ .Input(a.node()->name(), 0, DT_INT32)
+ .Input(b.node()->name(), 0, DT_FLOAT)
+ .Input(c.node()->name(), 0, DT_INT32)
+ .Input(d.node()->name(), 0, DT_FLOAT)
+ .Input(u.node()->name(), 0, DT_RESOURCE)
+ .Input(v.node()->name(), 0, DT_RESOURCE)
+ .Input(w.node()->name(), 0, DT_RESOURCE)
+ .Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0")
+ .Attr("_variable_start_index", 4)
+ .Finalize(&def));
+
+ Status status;
+ Node* launch = scope.graph()->AddNode(def, &status);
+ TF_CHECK_OK(status);
+ TF_CHECK_OK(scope.DoShapeInference(launch));
+ scope.graph()->AddEdge(a.node(), 0, launch, 0);
+ scope.graph()->AddEdge(b.node(), 0, launch, 1);
+ scope.graph()->AddEdge(c.node(), 0, launch, 2);
+ scope.graph()->AddEdge(d.node(), 0, launch, 3);
+ scope.graph()->AddEdge(u.node(), 0, launch, 4);
+ scope.graph()->AddEdge(v.node(), 0, launch, 5);
+ scope.graph()->AddEdge(w.node(), 0, launch, 6);
+
+ auto out0 =
+ ops::XlaClusterOutput(scope.WithOpName("Out0"), Output(launch, 0));
+ auto out1 =
+ ops::XlaClusterOutput(scope.WithOpName("Out1"), Output(launch, 1));
+ auto out2 =
+ ops::XlaClusterOutput(scope.WithOpName("Out2"), Output(launch, 2));
+ auto out3 =
+ ops::XlaClusterOutput(scope.WithOpName("Out3"), Output(launch, 3));
+
+ auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0);
+ auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0);
+ auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0);
+ auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1);
+ auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2);
+ auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_CHECK_OK(scope.ToGraph(graph.get()));
+ return graph;
+}
+
+// Makes an encapsulated body graph for use in tests.
+static std::unique_ptr<Graph> MakeBodyGraph() {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+
+ auto arg0 = ops::_Arg(scope.WithOpName("a_0_arg"), DT_INT32, 0);
+ auto arg1 = ops::_Arg(scope.WithOpName("b_0_arg"), DT_FLOAT, 1);
+ auto arg2 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 2);
+ auto arg3 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 3);
+
+ auto arg4 = ops::_Arg(scope.WithOpName("u_0_arg"), DT_RESOURCE, 4);
+ auto arg5 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 5);
+ auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6);
+
+ auto add_attrs = [](Node* node) {
+ node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
+ };
+
+ auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1);
+
+ auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT);
+ add_attrs(read_u.node());
+ auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT);
+ add_attrs(read_v.node());
+ auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), arg6, DT_FLOAT);
+ add_attrs(read_w.node());
+
+ auto e = ops::Add(scope.WithOpName("E"), arg0, arg2);
+ add_attrs(e.node());
+ auto f = ops::Add(scope.WithOpName("F"), read_v, read_w);
+ add_attrs(f.node());
+ auto g = ops::Add(scope.WithOpName("G"), f, arg3);
+ add_attrs(g.node());
+
+ auto out0 = ops::_Retval(scope.WithOpName("b_identity_0_retval_RetVal"),
+ b_identity, 0);
+ auto out1 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 1);
+ auto out2 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 2);
+ auto out3 =
+ ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_CHECK_OK(scope.ToGraph(graph.get()));
+ return graph;
+}
+
+TEST(EncapsulateXlaComputations, DeterministicEncapsulate) {
+ // Test that control edge insertion order doesn't affect the cache key
+ // (cluster name) generated by the encapsulate pass.
+ auto get_serialized_graph = [](bool control_input_reversed,
+ bool operand_reversed) -> string {
+ FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
+ std::unique_ptr<Graph> graph(new Graph(&flib_def));
+ {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32);
+ auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32);
+
+ ops::Add e = operand_reversed ? ops::Add(scope.WithOpName("E"), a0, a1)
+ : ops::Add(scope.WithOpName("E"), a1, a0);
+
+ auto add_attrs = [](Node* node) {
+ node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr,
+ "launch0");
+ };
+ add_attrs(e.node());
+
+ TF_CHECK_OK(scope.ToGraph(graph.get()));
+ auto get_node_in_graph = [&graph](Node* node) {
+ return graph->FindNodeId(node->id());
+ };
+      // Insert the control edges in a different order. The order should not
+      // affect the encapsulated or serialized graph.
+ if (!control_input_reversed) {
+ graph->AddControlEdge(get_node_in_graph(a0.node()),
+ get_node_in_graph(e.node()), true);
+ graph->AddControlEdge(get_node_in_graph(a1.node()),
+ get_node_in_graph(e.node()), true);
+ } else {
+ graph->AddControlEdge(get_node_in_graph(a1.node()),
+ get_node_in_graph(e.node()), true);
+ graph->AddControlEdge(get_node_in_graph(a0.node()),
+ get_node_in_graph(e.node()), true);
+ }
+ }
+ TF_CHECK_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));
+ GraphDef gdef;
+ graph->ToGraphDef(&gdef);
+    // Sort the control inputs before serialization to remove
+    // nondeterminism.
+ SortControlInputs(&gdef);
+ string serialized;
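+    // SerializeToStringDeterministic serializes proto map fields in a fixed
+    // order, so equal graphs produce byte-identical strings.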
+ SerializeToStringDeterministic(gdef, &serialized);
+ return serialized;
+ };
+
+  // Reordering control inputs shouldn't affect the generated graph.
+ EXPECT_EQ(get_serialized_graph(/*control_input_reversed=*/true,
+ /*operand_reversed=*/false),
+ get_serialized_graph(/*control_input_reversed=*/false,
+ /*operand_reversed=*/false));
+
+  // Reordering data inputs should change the generated graph.
+ EXPECT_NE(get_serialized_graph(/*control_input_reversed=*/false,
+ /*operand_reversed=*/true),
+ get_serialized_graph(/*control_input_reversed=*/false,
+ /*operand_reversed=*/false));
+}
+
+TEST(EncapsulateXlaComputations, Encapsulate) {
+ FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
+ std::unique_ptr<Graph> graph(new Graph(&flib_def));
+ {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
+ auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
+ auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
+ auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
+ auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
+ auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
+ auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
+
+ auto add_attrs = [](Node* node) {
+ node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
+ };
+
+ auto b_identity = ops::Identity(scope.WithOpName("B_identity"), b);
+ add_attrs(b_identity.node());
+
+ auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), u, DT_FLOAT);
+ add_attrs(read_u.node());
+ auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), v, DT_FLOAT);
+ add_attrs(read_v.node());
+ auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), w, DT_FLOAT);
+ add_attrs(read_w.node());
+
+ auto e = ops::Add(scope.WithOpName("E"), a, c);
+ add_attrs(e.node());
+ auto f = ops::Add(scope.WithOpName("F"), read_v, read_w);
+ add_attrs(f.node());
+ auto g = ops::Add(scope.WithOpName("G"), f, d);
+ add_attrs(g.node());
+
+ auto out0 = ops::XlaClusterOutput(scope.WithOpName("Out0"), b_identity);
+ auto out1 = ops::XlaClusterOutput(scope.WithOpName("Out1"), e);
+ auto out2 = ops::XlaClusterOutput(scope.WithOpName("Out2"), g);
+ auto out3 = ops::XlaClusterOutput(scope.WithOpName("Out3"), read_u);
+
+ auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0);
+ auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0);
+ auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0);
+ auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1);
+ auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2);
+ auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3);
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+ }
+
+ std::unique_ptr<Graph> graph_copy(new Graph(&flib_def));
+ CopyGraph(*graph, graph_copy.get());
+
+ TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));
+
+ std::unordered_map<string, Node*> index = BuildNodeIndex(*graph);
+ string function = index.at("launch0")->type_string();
+
+  // Tests that the outer graph is as expected.
+ {
+ std::unique_ptr<Graph> outer = MakeOuterGraph(flib_def, function);
+ GraphDef expected_def;
+ outer->ToGraphDef(&expected_def);
+
+ GraphDef actual_def;
+ graph->ToGraphDef(&actual_def);
+ TF_EXPECT_GRAPH_EQ_INTERNAL(expected_def, actual_def);
+ }
+
+  // Tests that the encapsulated body graph is as expected.
+ {
+ std::unique_ptr<Graph> body = MakeBodyGraph();
+ GraphDef expected_body_def;
+ body->ToGraphDef(&expected_body_def);
+
+ InstantiationResultForTest result;
+ TF_EXPECT_OK(InstantiateFunctionForTest(function, flib_def, &result));
+
+ EXPECT_EQ((DataTypeVector{DT_INT32, DT_FLOAT, DT_INT32, DT_FLOAT,
+ DT_RESOURCE, DT_RESOURCE, DT_RESOURCE}),
+ result.arg_types);
+ EXPECT_EQ((DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}),
+ result.ret_types);
+ TF_EXPECT_GRAPH_EQ(expected_body_def, result.gdef);
+ }
+
+  // Encapsulates the same computation again and verifies the same function is
+  // reused. Encapsulation should be deterministic to avoid recompilation.
+ TF_ASSERT_OK(
+ EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def));
+ std::unordered_map<string, Node*> index_copy = BuildNodeIndex(*graph_copy);
+ string function_copy = index_copy.at("launch0")->type_string();
+ EXPECT_EQ(function, function_copy);
+}
+
+TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) {
+ std::unique_ptr<Graph> body_graph = MakeBodyGraph();
+ FunctionDefLibrary flib;
+ TF_ASSERT_OK(GraphToFunctionDef(*body_graph, "launch0", flib.add_function()));
+
+ FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
+
+ std::unique_ptr<Graph> graph = MakeOuterGraph(flib_def, "launch0");
+ TF_ASSERT_OK(EncapsulateXlaComputationsPass::BuildXlaLaunchOps(graph.get()));
+
+ Scope scope = Scope::DisabledShapeInferenceScope().ExitOnError();
+ TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib));
+
+ auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
+ auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
+ auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
+ auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
+ auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
+ auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
+ auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
+
+ NameAttrList function;
+ function.set_name("launch0");
+ auto launch = ops::XlaLaunch(
+ scope.WithOpName("launch0"), std::initializer_list<Input>{},
+ std::initializer_list<Input>{a, b, c, d},
+ std::initializer_list<Input>{u, v, w},
+ DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function);
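+  // XlaLaunch inputs are grouped as (constants, args, resources); here there
+  // are no constants, four data arguments, and three resource arguments.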
+
+ auto consumer0_a =
+ ops::Identity(scope.WithOpName("consumer0_a"), launch.results[0]);
+ auto consumer0_b =
+ ops::Identity(scope.WithOpName("consumer0_b"), launch.results[0]);
+ auto consumer0_c =
+ ops::Identity(scope.WithOpName("consumer0_c"), launch.results[0]);
+ auto consumer1 =
+ ops::Identity(scope.WithOpName("consumer1"), launch.results[1]);
+ auto consumer2 =
+ ops::Identity(scope.WithOpName("consumer2"), launch.results[2]);
+ auto consumer3 =
+ ops::Identity(scope.WithOpName("consumer3"), launch.results[3]);
+
+ GraphDef expected_def;
+ TF_ASSERT_OK(scope.ToGraphDef(&expected_def));
+
+ GraphDef actual_def;
+ graph->ToGraphDef(&actual_def);
+ TF_EXPECT_GRAPH_EQ(expected_def, actual_def);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/graphcycles/BUILD b/tensorflow/compiler/jit/graphcycles/BUILD
index 676f71a75a..8212956adf 100644
--- a/tensorflow/compiler/jit/graphcycles/BUILD
+++ b/tensorflow/compiler/jit/graphcycles/BUILD
@@ -14,6 +14,7 @@ cc_library(
hdrs = ["graphcycles.h"],
deps = [
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:inlined_vector",
],
)
diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.cc b/tensorflow/compiler/jit/graphcycles/graphcycles.cc
index 805bbc62c1..756377bd95 100644
--- a/tensorflow/compiler/jit/graphcycles/graphcycles.cc
+++ b/tensorflow/compiler/jit/graphcycles/graphcycles.cc
@@ -34,7 +34,7 @@ limitations under the License.
#include <algorithm>
#include <unordered_set>
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "absl/container/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
@@ -44,7 +44,7 @@ namespace {
typedef std::unordered_set<int32> NodeSet;
template <typename T>
struct VecStruct {
- typedef gtl::InlinedVector<T, 4> type;
+ typedef absl::InlinedVector<T, 4> type;
};
template <typename T>
using Vec = typename VecStruct<T>::type;
diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
index c37b6112cc..315fcb2fa7 100644
--- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
+++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
@@ -15,12 +15,19 @@ limitations under the License.
#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
+// EncapsulateXlaComputationsPass rewrites computations generated by the
+// xla.compile() Python code into XlaLaunch nodes.
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26,
+ EncapsulateXlaComputationsPass);
+
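+// (The phase number, 26 above, orders this pass relative to the other
+// PRE_PLACEMENT passes.)
+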
+// The following POST_REWRITE passes support auto-clustering to enable XLA.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
MarkForCompilationPass);
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 8f78c110cb..253a5d2547 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -29,16 +29,3 @@ cc_library(
],
alwayslink = 1,
)
-
-cc_library(
- name = "parallel_check_op",
- srcs = ["parallel_check_op.cc"],
- visibility = ["//tensorflow/compiler/jit:friends"],
- deps = [
- "//tensorflow/compiler/jit/legacy_flags:parallel_check_op_flags",
- "//tensorflow/core:core_cpu",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- ],
- alwayslink = 1,
-)
diff --git a/tensorflow/compiler/jit/kernels/parallel_check_op.cc b/tensorflow/compiler/jit/kernels/parallel_check_op.cc
deleted file mode 100644
index bd4eefbc0b..0000000000
--- a/tensorflow/compiler/jit/kernels/parallel_check_op.cc
+++ /dev/null
@@ -1,144 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h"
-#include "tensorflow/core/common_runtime/device.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/macros.h"
-
-namespace tensorflow {
-namespace {
-
-// Inputs 2*N tensors, outputs the first N inputs.
-// Logs errors if input tensor i and i + N are not (near) identical
-// in any position.
-class ParallelCheckOp : public OpKernel {
- public:
- explicit ParallelCheckOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
- template <typename T>
- int CompareTensors(DataType dtype, const char* v0, const char* v1,
- int64 num_elts, int input_idx) {
- int failed = 0;
- const T* p0 = reinterpret_cast<const T*>(v0);
- const T* p1 = reinterpret_cast<const T*>(v1);
- double rtol;
- legacy_flags::ParallelCheckOpFlags* flags =
- legacy_flags::GetParallelCheckOpFlags();
- if (!tensorflow::strings::safe_strtod(flags->parallel_check_rtol.c_str(),
- &rtol)) {
- LOG(ERROR) << "can't convert parallel_check_rtol "
- << flags->parallel_check_rtol << " to double";
- }
- double atol;
- if (!tensorflow::strings::safe_strtod(flags->parallel_check_atol.c_str(),
- &atol)) {
- LOG(ERROR) << "can't convert parallel_check_atol "
- << flags->parallel_check_atol << " to double";
- }
- for (int i = 0; i < num_elts; ++i) {
- bool ok = (p0[i] == p1[i]);
- VLOG(2) << "output " << input_idx << " element " << i << ": " << p0[i];
- if (!ok) {
- if (std::is_same<T, float>::value || std::is_same<T, double>::value) {
- float tolerance =
- std::max(atol, std::max(fabs(rtol * p0[i]), fabs(rtol * p1[i])));
- T diff = p0[i] - p1[i];
- if (diff < 0) diff = 0 - diff;
- ok = (diff <= tolerance);
- }
- if (ok) continue;
- LOG(ERROR) << "Op " << name() << " fails equality at output "
- << input_idx << " type " << DataTypeString(dtype)
- << " element " << i << ": std_val=" << p0[i]
- << " test_val=" << p1[i] << " diff=" << (p0[i] - p1[i]);
- if (++failed > 10) break;
- }
- }
- return failed;
- }
-
- void Compute(OpKernelContext* ctx) override {
- VLOG(1) << "Compute " << name();
- const int num_pairs = ctx->num_inputs() / 2;
- for (int i = 0; i < num_pairs; ++i) {
- CHECK_EQ(ctx->input_dtype(i), ctx->input_dtype(i + num_pairs));
- Tensor t0 = ctx->input(i);
- Tensor t1 = ctx->input(i + num_pairs);
- int64 num_elts = t0.NumElements();
- CHECK_EQ(num_elts, t1.NumElements());
-
- // Compare inputs elementwise for near-exact equality.
- const char* v0 = t0.tensor_data().data();
- const char* v1 = t1.tensor_data().data();
- int failed = 0;
- switch (ctx->input_dtype(i)) {
- case DT_INT32:
- failed =
- CompareTensors<int32>(ctx->input_dtype(i), v0, v1, num_elts, i);
- break;
- case DT_INT64:
- failed =
- CompareTensors<int64>(ctx->input_dtype(i), v0, v1, num_elts, i);
- break;
- case DT_FLOAT:
- failed =
- CompareTensors<float>(ctx->input_dtype(i), v0, v1, num_elts, i);
- break;
- case DT_DOUBLE:
- failed =
- CompareTensors<double>(ctx->input_dtype(i), v0, v1, num_elts, i);
- break;
- case DT_BOOL:
- failed =
- CompareTensors<bool>(ctx->input_dtype(i), v0, v1, num_elts, i);
- break;
- default:
- LOG(FATAL) << "unimpl: " << ctx->input_dtype(i);
- }
- if (failed > 0) {
- LOG(ERROR) << "check failed for " << name() << " output " << i
- << " num_elts: " << num_elts;
- legacy_flags::ParallelCheckOpFlags* flags =
- legacy_flags::GetParallelCheckOpFlags();
- if (flags->parallel_check_failfast) {
- LOG(QFATAL) << "failfast on first parallel-check failure";
- }
- } else {
- VLOG(1) << "check passed for " << name() << " output " << i
- << " num_elts: " << num_elts;
- }
-
- // Propagate the std value.
- if (IsRefType(ctx->input_dtype(i))) {
- ctx->forward_ref_input_to_ref_output(i, i);
- } else {
- ctx->set_output(i, ctx->input(i));
- }
- }
- }
-
- TF_DISALLOW_COPY_AND_ASSIGN(ParallelCheckOp);
-};
-
-REGISTER_KERNEL_BUILDER(Name("ParallelCheck").Device(DEVICE_CPU),
- ParallelCheckOp);
-
-} // namespace
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 7f4370b5b0..b6f2f632f7 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -16,7 +16,6 @@ limitations under the License.
#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
#include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
@@ -57,18 +56,17 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
->stream->parent()
->platform()
->id();
- } else {
- platform_id_ = nullptr;
+ } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata_).ok()) {
+ use_multiple_streams_ = xla_device_metadata_->UseMultipleStreams();
+ platform_id_ = xla_device_metadata_->platform()->id();
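+    // (On devices without XlaDevice metadata, xla_device_metadata_ stays
+    // null; Compute() keys off that to decide whether to allocate XlaTensors.)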
}
}
Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx,
XlaCompilationCache** cache) {
- const XlaDevice::Metadata* metadata;
- Status s = XlaDevice::GetMetadata(ctx, &metadata);
- if (s.ok()) {
- *cache = new XlaCompilationCache(metadata->client(),
- metadata->jit_device_type());
+ if (xla_device_metadata_) {
+ *cache = new XlaCompilationCache(xla_device_metadata_->client(),
+ xla_device_metadata_->jit_device_type());
return Status::OK();
}
@@ -117,18 +115,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
// this is more obviously correct.)
core::ScopedUnref cache_ref(cache);
- const XlaDevice::Metadata* metadata = nullptr;
- Status s = XlaDevice::GetMetadata(ctx, &metadata);
- bool allocate_xla_tensors = s.ok();
- bool use_multiple_streams = s.ok() && metadata->UseMultipleStreams();
-
- // Get the platform_id_ for XLA_* devices.
- if (platform_id_ == nullptr) {
- if (s.ok()) {
- platform_id_ = metadata->platform()->id();
- }
- }
-
std::map<int, OptionalTensor> variables =
SnapshotResourceVariables(ctx, resources_);
@@ -146,7 +132,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
// (which local_xla_allocator above uses) as on an XlaDevice, this is a
// dummy allocator that returns XlaTensor objects. The XlaCompiler needs a
// real allocator to allocate real buffers.
- if (allocate_xla_tensors) {
+ if (xla_device_metadata_) {
xla_allocator = client->backend().memory_allocator();
} else {
xla_allocator = &local_xla_allocator;
@@ -163,8 +149,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
options.device_allocator = xla_allocator;
- if (metadata) {
- options.shape_representation_fn = metadata->shape_representation_fn();
+ if (xla_device_metadata_) {
+ options.shape_representation_fn =
+ xla_device_metadata_->shape_representation_fn();
}
const XlaCompiler::CompilationResult* kernel;
@@ -176,22 +163,25 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
}
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true;
- // Optimization: don't resolve constants. If we resolve constants we never
- // emit them on the device, meaning that if they are needed by a following
- // computation the host has to transfer them.
- compile_options.resolve_compile_time_constants = false;
+  // If we resolve constants we never emit them on the device, meaning that
+  // if they are needed by a following computation the host has to transfer
+  // them; not resolving constants is therefore expected to be faster than
+  // resolving them.
+ compile_options.resolve_compile_time_constants = true;
// Optimization: where possible, have the computation return a naked array
// rather than a one-element tuple.
compile_options.always_return_tuple = false;
OP_REQUIRES_OK(
ctx, cache->Compile(options, function_, constant_args, variables, ctx,
- &kernel, &executable, &compile_options));
+ &kernel, &executable, compile_options));
VLOG(1) << "Executing XLA Computation...";
XlaComputationLaunchContext launch_context(
- client, xla_allocator, allocate_xla_tensors, use_multiple_streams);
+ client, xla_allocator,
+ /*allocate_xla_tensors=*/xla_device_metadata_ != nullptr,
+ use_multiple_streams_);
launch_context.PopulateInputs(ctx, kernel, variables);
// Execute the computation.
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h
index 8dfc4b382d..e0f10e9817 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.h
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.h
@@ -13,10 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LOCAL_LAUNCH_OP_H_
-#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LOCAL_LAUNCH_OP_H_
+#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
+#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
+#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -58,7 +59,9 @@ class XlaLocalLaunchBase : public OpKernel {
DeviceType device_type_;
NameAttrList function_;
- se::Platform::Id platform_id_;
+ se::Platform::Id platform_id_ = nullptr;
+ bool use_multiple_streams_ = false;
+ const XlaDevice::Metadata* xla_device_metadata_ = nullptr;
};
// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
@@ -81,4 +84,4 @@ class XlaLocalLaunchOp : public XlaLocalLaunchBase {
} // namespace tensorflow
-#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LOCAL_LAUNCH_OP_H_
+#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD
index 5b6692f523..07c5b23188 100644
--- a/tensorflow/compiler/jit/legacy_flags/BUILD
+++ b/tensorflow/compiler/jit/legacy_flags/BUILD
@@ -29,18 +29,6 @@ cc_library(
)
cc_library(
- name = "parallel_check_op_flags",
- srcs = ["parallel_check_op_flags.cc"],
- hdrs = ["parallel_check_op_flags.h"],
- deps =
- [
- "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- ],
-)
-
-cc_library(
name = "xla_device_flags",
srcs = ["xla_device_flags.cc"],
hdrs = ["xla_device_flags.h"],
diff --git a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc
deleted file mode 100644
index a61694b494..0000000000
--- a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc
+++ /dev/null
@@ -1,68 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// Legacy flags for the XLA bridge's parallel_check_op module.
-
-#include <mutex>
-#include <vector>
-
-#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h"
-#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace tensorflow {
-namespace legacy_flags {
-
-// Pointers to the parsed value of the flags and flag descriptors, initialized
-// via flags_init.
-static ParallelCheckOpFlags* flags;
-static std::vector<Flag>* flag_list;
-static std::once_flag flags_init;
-
-// Allocate *flags. Called via call_once(&flags_init,...).
-static void AllocateFlags() {
- flags = new ParallelCheckOpFlags;
- flags->parallel_check_failfast = true;
- flags->parallel_check_atol = "1e-5";
- flags->parallel_check_rtol = "1e-5";
- flag_list = new std::vector<Flag>({
- Flag("parallel_check_failfast", &flags->parallel_check_failfast,
- "Fail immediately on first parallel-check comparison error."),
- Flag("parallel_check_atol", &flags->parallel_check_atol,
- "Absolute error tolerance for parallel-check comparison."),
- Flag("parallel_check_rtol", &flags->parallel_check_rtol,
- "Relative error tolerance for parallel-check comparison."),
- });
- xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
-}
-
-// Append to *append_to flag definitions associated with the XLA bridge's
-// parallel_check_op module.
-void AppendParallelCheckOpFlags(std::vector<Flag>* append_to) {
- std::call_once(flags_init, &AllocateFlags);
- append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
-}
-
-// Return a pointer to the ParallelCheckOpFlags struct;
-// repeated calls return the same pointer.
-// This should be called only after Flags::Parse() has returned.
-ParallelCheckOpFlags* GetParallelCheckOpFlags() {
- std::call_once(flags_init, &AllocateFlags);
- return flags;
-}
-
-} // namespace legacy_flags
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h
deleted file mode 100644
index 156a2a2a71..0000000000
--- a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h
+++ /dev/null
@@ -1,52 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_
-#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_
-
-// Legacy flags for the XLA bridge's parallel_check_op module.
-
-#include <vector>
-
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace tensorflow {
-namespace legacy_flags {
-
-// Append to *flag_list flag definitions associated with the XLA bridge's
-// parallel_check_op module.
-void AppendParallelCheckOpFlags(std::vector<tensorflow::Flag>* flag_list);
-
-// The values of flags associated with the XLA bridge's
-// parallel_check_op module.
-typedef struct {
- bool parallel_check_failfast; // Fail immediately on first parallel-check
- // comparison error.
- string parallel_check_atol; // Absolute error tolerance for parallel-check
- // comparison.
- string parallel_check_rtol; // Relative error tolerance for parallel-check
- // comparison.
-} ParallelCheckOpFlags;
-
-// Return a pointer to the ParallelCheckOpFlags struct;
-// repeated calls return the same pointer.
-// This should be called only after Flags::Parse() has returned.
-ParallelCheckOpFlags* GetParallelCheckOpFlags();
-
-} // namespace legacy_flags
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 90d5d56998..e6cc6e52ae 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -27,7 +27,9 @@ limitations under the License.
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -39,7 +41,9 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/kernels/bounds_check.h"
-#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
@@ -72,18 +76,40 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok();
}
+bool HasResourceOutput(const Node& node) {
+ return std::find(node.output_types().begin(), node.output_types().end(),
+ DT_RESOURCE) != node.output_types().end();
+}
+
+bool HasResourceInput(const Node& node) {
+ return std::find(node.input_types().begin(), node.input_types().end(),
+ DT_RESOURCE) != node.input_types().end();
+}
+
+// Returns true if `node` is a resource operation recognized by tf2xla that
+// operates on something other than resource variables.
+bool IsNonResourceVarResourceOp(const Node& node) {
+ // TODO(b/112837194): We can't cluster these because we only support
+ // snapshotting resource variables (and we can't e.g. snapshot stacks). This
+ // limitation may be fixable with some work.
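+  // For example, TensorArrayV3 and StackV2 produce non-variable resources
+  // (so this returns true), while ops on resource variables return false.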
+ const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(node.type_string());
+ return op_info && op_info->resource_kind() != XlaResourceKind::kVariable;
+}
+
// Make sure we don't recurse infinitely on recursive functions.
const int kMaxRecursionDepth = 10;
bool IsCompilableCall(const NodeDef& call_def,
- const DeviceType& jit_device_type, int depth,
+ const DeviceType& jit_device_type,
+ bool allow_resource_ops, int depth,
FunctionLibraryRuntime* lib_runtime);
// Tests whether 'while_node' is a completely compilable loop.
// Every operator in the condition and body functions must be compilable for a
// while loop to be compilable.
bool IsCompilableWhile(const Node& while_node,
- const DeviceType& jit_device_type, int depth,
+ const DeviceType& jit_device_type,
+ bool allow_resource_ops, int depth,
FunctionLibraryRuntime* lib_runtime) {
const NameAttrList* name_attr;
NodeDef call;
@@ -98,7 +124,8 @@ bool IsCompilableWhile(const Node& while_node,
call.set_name("while_cond");
call.set_op(cond_func);
*call.mutable_attr() = name_attr->attr();
- if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
+ if (!IsCompilableCall(call, jit_device_type, allow_resource_ops, depth + 1,
+ lib_runtime)) {
VLOG(2) << "Rejecting While " << while_node.name()
<< ": can't compile loop condition: " << cond_func;
return false;
@@ -113,7 +140,8 @@ bool IsCompilableWhile(const Node& while_node,
call.set_name("while_body");
call.set_op(body_func);
*call.mutable_attr() = name_attr->attr();
- if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
+ if (!IsCompilableCall(call, jit_device_type, allow_resource_ops, depth + 1,
+ lib_runtime)) {
VLOG(2) << "Rejecting While " << while_node.name()
<< ": can't compile loop body: " << body_func;
return false;
@@ -125,7 +153,8 @@ bool IsCompilableWhile(const Node& while_node,
// Every operator in the function must be compilable for a function to be
// compilable.
bool IsCompilableCall(const NodeDef& call_def,
- const DeviceType& jit_device_type, int depth,
+ const DeviceType& jit_device_type,
+ bool allow_resource_ops, int depth,
FunctionLibraryRuntime* lib_runtime) {
if (depth > kMaxRecursionDepth) {
VLOG(2) << "Rejecting " << call_def.op()
@@ -141,6 +170,10 @@ bool IsCompilableCall(const NodeDef& call_def,
<< ": could not instantiate: " << status;
return false;
}
+
+ auto release_handle_on_return = gtl::MakeCleanup(
+ [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
+
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
CHECK(fbody);
const FunctionDef& fdef = fbody->fdef;
@@ -161,12 +194,17 @@ bool IsCompilableCall(const NodeDef& call_def,
if (node->type_string() == "_Arg" || node->type_string() == "_Retval")
continue;
if (node->type_string() == "While") {
- // Handle functional While loop (not in open source build).
- return IsCompilableWhile(*node, jit_device_type, depth + 1, lib_runtime);
+ // Handle functional While loop.
+ return IsCompilableWhile(*node, jit_device_type, allow_resource_ops,
+ depth + 1, lib_runtime);
+ }
+ if (!allow_resource_ops &&
+ (HasResourceInput(*node) || HasResourceOutput(*node))) {
+ return false;
}
if (!HasXLAKernel(*node, jit_device_type) &&
- !IsCompilableCall(node->def(), jit_device_type, depth + 1,
- lib_runtime)) {
+ !IsCompilableCall(node->def(), jit_device_type, allow_resource_ops,
+ depth + 1, lib_runtime)) {
VLOG(2) << "Rejecting " << call_def.op() << ": unsupported op "
<< node->name() << ": " << node->def().ShortDebugString();
return false;
@@ -337,6 +375,10 @@ Status FindCompilationCandidates(
flib_def, opts));
FunctionLibraryRuntime* lib_runtime =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
+ std::vector<bool> compile_time_const_nodes(graph.num_node_ids(), false);
+ TF_RETURN_IF_ERROR(
+ BackwardsConstAnalysis(graph, /*compile_time_const_arg_indices=*/nullptr,
+ &compile_time_const_nodes));
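+  // compile_time_const_nodes[id] is true when XLA requires the node with that
+  // id to be a compile-time constant (e.g. because it feeds a shape operand).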
int64& fuel =
legacy_flags::GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel;
@@ -380,19 +422,46 @@ Status FindCompilationCandidates(
XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration));
DeviceType jit_device_type(registration->compilation_device_name);
if (!HasXLAKernel(*node, jit_device_type) &&
- !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime)) {
+ !IsCompilableCall(node->def(), jit_device_type,
+ registration->compile_resource_ops, 0, lib_runtime)) {
VLOG(2) << "Rejecting " << node->name() << ": unsupported op "
<< node->type_string();
continue;
}
if (!registration->compile_resource_ops &&
- HasResourceInputOrOutput(*node)) {
- VLOG(2) << "Rejecting: " << node->name() << ": resource input/output "
+ (HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) {
+      // We don't have a way of returning values of type DT_RESOURCE from XLA
+      // computations, so we avoid auto-clustering nodes producing DT_RESOURCE.
+ // XlaLaunchOp also cannot snapshot resources that are not resource
+ // variables so we avoid clustering resource operations that operate on
+ // non-resource variables.
+ VLOG(2) << "Rejecting: " << node->name() << ": resource output "
<< node->type_string();
continue;
}
+ if (compile_time_const_nodes[node->id()] &&
+ !registration->requires_compilation) {
+ const OpDef* op_def;
+ TF_RETURN_IF_ERROR(
+ graph.op_registry()->LookUpOpDef(node->type_string(), &op_def));
+ if (op_def->is_stateful()) {
+ // We need to be able to constant fold the nodes in
+ // compile_time_const_nodes given constant inputs (required by XLA) and
+ // therefore can't auto-cluster stateful ops since these can never be
+ // constant folded.
+ VLOG(2) << "Rejecting " << node->name()
+ << ": must-be-constant stateful op";
+ continue;
+ }
+ }
+ // We don't auto-cluster functional control flow nodes containing resource
+ // operations because safety checks are trickier in this case.
+ // registration->compile_resource_ops is true for XLA_CPU/XLA_GPU but not
+ // for CPU/GPU.
if (node->type_string() == "While" &&
- !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime)) {
+ !IsCompilableWhile(*node, jit_device_type,
+ registration->compile_resource_ops, 0,
+ lib_runtime)) {
continue;
}
// _Arg nodes in a top-level function represent feeds.
@@ -412,6 +481,31 @@ Status FindCompilationCandidates(
return Status::OK();
}
+// Determines the global JIT level, which is ON if either the
+// GraphOptimizationPassOptions has JIT enabled or the --tf_xla_auto_jit flag
+// is set to a valid, non-zero value.
+OptimizerOptions::GlobalJitLevel GetGlobalJitLevel(
+ const GraphOptimizationPassOptions& options) {
+ OptimizerOptions::GlobalJitLevel global_jit_level =
+ options.session_options->config.graph_options()
+ .optimizer_options()
+ .global_jit_level();
+ if (global_jit_level == OptimizerOptions::DEFAULT) {
+ // To set compilation to be on by default, change the following line.
+ global_jit_level = OptimizerOptions::OFF;
+ }
+ legacy_flags::MarkForCompilationPassFlags* flags =
+ legacy_flags::GetMarkForCompilationPassFlags();
+ if (flags->tf_xla_auto_jit == -1 ||
+ (1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) {
+ // If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides
+ // the setting in ConfigProto.
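+    // (OptimizerOptions::GlobalJitLevel defines OFF = -1, ON_1 = 1 and
+    // ON_2 = 2, so the flag value can be cast directly.)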
+ global_jit_level =
+ static_cast<OptimizerOptions::GlobalJitLevel>(flags->tf_xla_auto_jit);
+ }
+ return global_jit_level;
+}
+
struct Cluster {
// Identifies the node that represents this cluster in the cycle detection
// graph.
@@ -426,7 +520,11 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
&registration));
DeviceType jit_device_type(registration->compilation_device_name);
- return IsCompilableCall(ndef, jit_device_type, 0, flr);
+
+ // We can always *compile* resource operations, even if we are sometimes
+ // unable to auto-cluster them.
+ const bool compile_resource_ops = true;
+ return IsCompilableCall(ndef, jit_device_type, compile_resource_ops, 0, flr);
}
Status MarkForCompilationPass::Run(
@@ -434,22 +532,9 @@ Status MarkForCompilationPass::Run(
// TODO(phawkins): precompute the "GetCompilationDevice" properties of each
// device ahead of time.
OptimizerOptions::GlobalJitLevel global_jit_level =
- options.session_options->config.graph_options()
- .optimizer_options()
- .global_jit_level();
- if (global_jit_level == OptimizerOptions::DEFAULT) {
- // To set compilation to be on by default, change the following line.
- global_jit_level = OptimizerOptions::OFF;
- }
+ GetGlobalJitLevel(options);
legacy_flags::MarkForCompilationPassFlags* flags =
legacy_flags::GetMarkForCompilationPassFlags();
- if (flags->tf_xla_auto_jit == -1 ||
- (1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) {
- // If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides
- // the setting in ConfigProto.
- global_jit_level =
- static_cast<OptimizerOptions::GlobalJitLevel>(flags->tf_xla_auto_jit);
- }
bool cpu_global_jit = flags->tf_xla_cpu_global_jit;
bool fusion_only = flags->tf_xla_fusion_only;
@@ -517,9 +602,9 @@ Status MarkForCompilationPass::Run(
bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU;
bool should_compile =
(ignore_registration || registration->enable_jit_by_default) &&
- global_jit_level > 0;
+ global_jit_level != OptimizerOptions::OFF;
if (!should_compile) {
- if (global_jit_level <= 0) {
+ if (global_jit_level == OptimizerOptions::OFF) {
VLOG(2) << "Rejecting " << node->name() << ": global jit disabled.";
} else {
VLOG(2) << "Rejecting " << node->name() << ": JIT for device disabled.";
@@ -530,6 +615,139 @@ Status MarkForCompilationPass::Run(
return RunImpl(options, is_compilable);
}
+static string RatioToString(int numerator, int denominator) {
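+  // e.g. RatioToString(3, 12) returns "3 / 12 (25.00%)".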
+  return strings::Printf("%d / %d (%.2f%%)", numerator, denominator,
+                         (100.0 * numerator) / denominator);
+}
+
+static void VLogClusteringSummary(const Graph& g) {
+ if (!VLOG_IS_ON(2)) {
+ return;
+ }
+
+ std::map<absl::string_view, int> cluster_name_to_size;
+ std::map<absl::string_view, std::map<absl::string_view, int>>
+ cluster_name_to_op_histogram;
+ std::map<absl::string_view, int> unclustered_op_histogram;
+ int clustered_node_count = 0;
+
+ for (Node* n : g.nodes()) {
+ absl::optional<absl::string_view> cluster_name = GetXlaClusterForNode(*n);
+ if (cluster_name) {
+ clustered_node_count++;
+ cluster_name_to_size[*cluster_name]++;
+ cluster_name_to_op_histogram[*cluster_name][n->type_string()]++;
+ } else {
+ unclustered_op_histogram[n->type_string()]++;
+ }
+ }
+
+ int unclustered_node_count = g.num_nodes() - clustered_node_count;
+
+ VLOG(2) << "*** Clustering info for graph of size " << g.num_nodes();
+ VLOG(2) << " Built " << cluster_name_to_size.size() << " clusters, size "
+ << RatioToString(clustered_node_count, g.num_nodes());
+
+ for (const auto& cluster_name_size_pair : cluster_name_to_size) {
+ absl::string_view cluster_name = cluster_name_size_pair.first;
+ int size = cluster_name_size_pair.second;
+ VLOG(2) << " " << cluster_name << " "
+ << RatioToString(size, g.num_nodes());
+ for (const auto& op_count_pair :
+ cluster_name_to_op_histogram[cluster_name]) {
+ VLOG(3) << " " << op_count_pair.first << ": " << op_count_pair.second
+ << " instances";
+ }
+ }
+
+ if (!unclustered_op_histogram.empty()) {
+ VLOG(2) << " Unclustered nodes: "
+ << RatioToString(unclustered_node_count, g.num_nodes());
+ for (const auto& pair : unclustered_op_histogram) {
+ VLOG(3) << " " << pair.first << ": " << pair.second << " instances";
+ }
+ }
+
+ struct EdgeInfo {
+ absl::string_view node_name;
+ absl::optional<absl::string_view> cluster_name;
+
+ absl::string_view GetClusterName() const {
+ return cluster_name ? *cluster_name : "[none]";
+ }
+
+ std::pair<absl::string_view, absl::optional<absl::string_view>> AsPair()
+ const {
+ return {node_name, cluster_name};
+ }
+
+ bool operator<(const EdgeInfo& other) const {
+ return AsPair() < other.AsPair();
+ }
+ };
+
+ using EdgeInfoMap = std::map<absl::string_view, std::map<EdgeInfo, int64>>;
+
+ EdgeInfoMap incoming_edge_infos;
+ EdgeInfoMap outgoing_edge_infos;
+
+ std::set<absl::string_view> cluster_names_to_print;
+
+ for (const Edge* e : g.edges()) {
+ const Node* from = e->src();
+ absl::optional<absl::string_view> from_cluster_name =
+ GetXlaClusterForNode(*from);
+
+ const Node* to = e->dst();
+ absl::optional<absl::string_view> to_cluster_name =
+ GetXlaClusterForNode(*to);
+
+ if (to_cluster_name == from_cluster_name) {
+ continue;
+ }
+
+ if (to_cluster_name) {
+ incoming_edge_infos[*to_cluster_name]
+ [EdgeInfo{from->name(), from_cluster_name}]++;
+ cluster_names_to_print.insert(*to_cluster_name);
+ }
+
+ if (from_cluster_name) {
+ outgoing_edge_infos[*from_cluster_name][{to->name(), to_cluster_name}]++;
+ cluster_names_to_print.insert(*from_cluster_name);
+ }
+ }
+
+ VLOG(2) << "*** Inter-Cluster edges:";
+ if (cluster_names_to_print.empty()) {
+ VLOG(2) << " [none]";
+ }
+
+ auto print_edge_info_set_for_cluster = [&](absl::string_view cluster_name,
+ const EdgeInfoMap& edge_info_map,
+ absl::string_view desc) {
+ auto it = edge_info_map.find(cluster_name);
+ if (it != edge_info_map.end()) {
+ VLOG(2) << " " << it->second.size() << " " << desc << " edges";
+ for (const auto& edge_info_count_pair : it->second) {
+ VLOG(2) << " " << edge_info_count_pair.first.GetClusterName() << " "
+ << edge_info_count_pair.first.node_name << " # "
+ << edge_info_count_pair.second;
+ }
+ } else {
+ VLOG(2) << " No " << desc << " edges.";
+ }
+ };
+
+ for (absl::string_view cluster_name : cluster_names_to_print) {
+ VLOG(2) << " ** Cluster " << cluster_name;
+ print_edge_info_set_for_cluster(cluster_name, incoming_edge_infos,
+ "incoming");
+ print_edge_info_set_for_cluster(cluster_name, outgoing_edge_infos,
+ "outgoing");
+ }
+}
+
// Is 'node' an operator that consumes only the shape of its input, not the
// data itself?
static bool IsShapeConsumerOp(const Node& node) {
@@ -537,6 +755,43 @@ static bool IsShapeConsumerOp(const Node& node) {
node.type_string() == "Size";
}
+static Status IgnoreResourceOpForSafetyAnalysis(const Node& n, bool* ignore) {
+ // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then
+ // ignore it during resource operation safety analysis. We need this hack
+ // because of two reasons:
+ //
+  // 1. Operations assigned to XLA_CPU and XLA_GPU always have to be compiled.
+ // 2. We don't support live-out values of type DT_RESOURCE and live-in values
+ // of type DT_RESOURCE that are not resource variables.
+ //
+ // Together these imply we cannot let resource variable safety analysis
+ // constrain e.g. a TensorArrayV3->TensorArrayAssignV3 edge to be in different
+ // clusters: both of them will have to be clustered because of (1) and we
+ // won't be able to keep the edge between the two as neither the input to the
+ // second XLA cluster nor the output from the first XLA cluster are supported
+ // because of (2).
+ //
+ // TODO(b/113100872): This can be fixed if the TensorFlow representation for
+ // TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then
+ // (2) would no longer hold.
+
+ if (n.assigned_device_name().empty()) {
+ *ignore = false;
+ return Status::OK();
+ }
+ DeviceType device_type("");
+ TF_RETURN_IF_ERROR(
+ DeviceToDeviceType(n.assigned_device_name(), &device_type));
+
+ const XlaOpRegistry::DeviceRegistration* registration;
+ if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
+ *ignore = true;
+ } else {
+ *ignore = registration->compile_resource_ops;
+ }
+ return Status::OK();
+}
+
// Sequence number generator to ensure clusters have unique names.
static std::atomic<int64> cluster_sequence_num;
@@ -565,6 +820,8 @@ Status MarkForCompilationPass::RunImpl(
GraphCycles cycles;
TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles));
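+  // The extra edges make unsafe resource-op orderings appear as would-be
+  // cycles, so the merging loop below never puts them in one cluster.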
+ TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
+ graph, options.flib_def, IgnoreResourceOpForSafetyAnalysis, &cycles));
// Each compilation candidate belongs to a cluster. The cluster's
// representative
@@ -577,6 +834,8 @@ Status MarkForCompilationPass::RunImpl(
worklist.push_back(&clusters[node->id()]);
}
+ OptimizerOptions::GlobalJitLevel global_jit_level =
+ GetGlobalJitLevel(options);
legacy_flags::MarkForCompilationPassFlags* flags =
legacy_flags::GetMarkForCompilationPassFlags();
@@ -601,7 +860,7 @@ Status MarkForCompilationPass::RunImpl(
string to_scope;
for (int to : cycles.Successors(from)) {
if (to >= graph->num_node_ids()) {
- // Node is a "frame" node that is present only in the cycle detection
+ // Node is a fictitious node that is present only in the cycle detection
// graph. No clustering is possible.
continue;
}
@@ -616,13 +875,15 @@ Status MarkForCompilationPass::RunImpl(
}
// Look for an _XlaScope on both nodes. If both nodes have a
// scope and the scopes do not match, do not cluster along this
- // edge. If even one of the nodes lacks an _XlaScope attribute,
+ // edge. This restriction is overridden if the global_jit_level is ON. If
+ // even one of the nodes lacks an _XlaScope attribute,
// then it is treated as a "bridge" and a cluster may be created
// along it. We may want to restrict this behavior to require
// all nodes marked with _XlaCompile=true to also have a
// _XlaScope property set (and raise an error otherwise); but
// for now we don't do this.
- if (GetNodeAttr(node_from->attrs(), kXlaScopeAttr, &from_scope).ok() &&
+ if (global_jit_level == OptimizerOptions::OFF &&
+ GetNodeAttr(node_from->attrs(), kXlaScopeAttr, &from_scope).ok() &&
GetNodeAttr(node_to->attrs(), kXlaScopeAttr, &to_scope).ok() &&
from_scope != to_scope) {
continue;
@@ -707,7 +968,7 @@ Status MarkForCompilationPass::RunImpl(
string& name = cluster_names[cluster];
if (name.empty()) {
- name = strings::StrCat("cluster_", cluster_sequence_num++);
+ name = absl::StrCat("cluster_", cluster_sequence_num++);
}
n->AddAttr(kXlaClusterAttr, name);
VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
@@ -718,6 +979,9 @@ Status MarkForCompilationPass::RunImpl(
dump_graph::DumpGraphToFile("mark_for_compilation", **options.graph,
options.flib_def);
}
+
+ VLogClusteringSummary(*graph);
+
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index a780d4a936..c59770a4c8 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -15,10 +15,13 @@ limitations under the License.
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
@@ -26,11 +29,11 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -48,9 +51,35 @@ std::unordered_map<string, string> GetClusters(const Graph& graph) {
ids[node->name()] = cluster;
}
}
+
+ if (VLOG_IS_ON(2)) {
+ VLOG(2) << "Clusters:";
+ for (const auto& p : ids) {
+ VLOG(2) << " " << p.first << " -> " << p.second;
+ }
+ }
return ids;
}
+gtl::FlatMap<string, std::vector<string>> GetClusterSets(
+ const Graph& g, std::vector<string>* cluster_names = nullptr) {
+ CHECK(cluster_names == nullptr || cluster_names->empty());
+ gtl::FlatMap<string, std::vector<string>> cluster_sets;
+ for (const auto& p : GetClusters(g)) {
+ cluster_sets[p.second].push_back(p.first);
+ }
+ for (auto& p : cluster_sets) {
+ if (cluster_names != nullptr) {
+ cluster_names->push_back(p.first);
+ }
+ std::sort(p.second.begin(), p.second.end());
+ }
+ if (cluster_names != nullptr) {
+ std::sort(cluster_names->begin(), cluster_names->end());
+ }
+ return cluster_sets;
+}
+
TEST(XlaCompilationTest, Chains) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
@@ -199,7 +228,7 @@ TEST(XlaCompilationTest, FunctionCalls) {
{}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}});
FunctionDef noinline = compilable;
noinline.mutable_signature()->set_name("NoInlineFn");
- AddAttr("_noinline", bool(true), noinline.mutable_attr());
+ AddAttr("_noinline", static_cast<bool>(true), noinline.mutable_attr());
FunctionDefLibrary flib;
*flib.add_function() = compilable;
@@ -372,6 +401,44 @@ TEST(XlaCompilationTest, Loops) {
EXPECT_EQ(0, clusters.size());
}
+TEST(XlaCompilationTest, CyclesWithAllDifferentScopesGlobalJitOverridden) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ GraphDef graphdef;
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* a = ops::SourceOp("Const", builder.opts()
+ .WithName("A")
+ .WithAttr("dtype", DT_FLOAT)
+ .WithAttr("value", Tensor())
+ .WithAttr(kXlaScopeAttr, "ScopeA"));
+ Node* b = ops::UnaryOp(
+ "Relu", a,
+ builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
+ ops::BinaryOp(
+ "MatMul", a, b,
+ builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC"));
+ TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
+ }
+
+ FunctionDefLibrary flib;
+ FunctionLibraryDefinition flib_def(graph->op_registry(), flib);
+ SessionOptions session_options;
+ session_options.config.mutable_graph_options()
+ ->mutable_optimizer_options()
+ ->set_global_jit_level(OptimizerOptions::ON_2);
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
+ &graph, &flib_def, &session_options));
+ auto clusters = GetClusters(*graph);
+
+  // The computation is: C = MatMul(A, Relu(A)),
+  // where A sits in ScopeA, Relu(A) sits in ScopeB, and C sits in ScopeC.
+  // Because the global JIT level is ON_2, the differing _XlaScope attributes
+  // are ignored and all three nodes are clustered together.
+ EXPECT_EQ(3, clusters.size());
+ EXPECT_EQ(clusters["A"], clusters["B"]);
+ EXPECT_EQ(clusters["A"], clusters["C"]);
+}
+
TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
@@ -463,45 +530,111 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
EXPECT_EQ(clusters["B"], clusters["C"]);
}
-REGISTER_OP("ResourceInput").Input("a: resource").Output("o: float");
-REGISTER_OP("ResourceOutput").Input("a: float").Output("o: resource");
-
namespace {
+Node* MakeRead(const Scope& scope, const string& id) {
+ Output var_handle =
+ ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
+ Output read =
+ ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT);
+ return read.node();
+}
-class DummyOp : public XlaOpKernel {
- using XlaOpKernel::XlaOpKernel;
- void Compile(XlaOpKernelContext* ctx) override {}
-};
-
-REGISTER_XLA_OP(Name("ResourceInput"), DummyOp);
-REGISTER_XLA_OP(Name("ResourceOutput"), DummyOp);
+Node* MakeWrite(const Scope& scope, const string& id) {
+ Output var_handle =
+ ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
+ Output value_to_write =
+ ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f);
+ ops::AssignVariableOp assign_op(scope.WithOpName("Assignment" + id),
+ var_handle, value_to_write);
+ return assign_op.operation.node();
+}
+Node* MakeNeutral(const Scope& scope, const string& id) {
+ return ops::Const(scope.WithOpName("Const" + id), 42.0f).node();
+}
} // namespace
-TEST(XlaCompilationTest, Resources) {
+TEST(XlaCompilationTest, ResourcesClusteringAllowed) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(read, write);
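+  // A read ordered before a write may share a cluster: the read sees the
+  // snapshot taken when the cluster launches and the write is applied on
+  // exit, preserving the original ordering.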
+
+ FixupSourceAndSinkEdges(root.graph());
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
- GraphDef graphdef;
- {
- GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
- Node* a =
- ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
- Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
- // We should not form clusters with resource ops by default.
- Node* c = ops::UnaryOp("ResourceOutput", b, builder.opts().WithName("C"));
- Node* d = ops::UnaryOp("ResourceInput", c, builder.opts().WithName("D"));
- ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
- TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
- }
+ TF_EXPECT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
- auto clusters = GetClusters(*graph);
- EXPECT_EQ(0, clusters.size()); // Nothing should be compiled.
+ gtl::FlatMap<string, std::vector<string>> cluster_sets =
+ GetClusterSets(*graph);
+ ASSERT_EQ(cluster_sets.size(), 1);
+ std::vector<string> expected_clustered_nodes = {"AssignmentW", "ReadR",
+ "ValueToAssignW"};
+ ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes);
+}
+
+TEST(XlaCompilationTest, ResourcesClusteringDisallowed) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(write, read);
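+  // A write ordered before a read must not share a cluster: XlaLaunch
+  // snapshots variables on entry and applies writes on exit, so an in-cluster
+  // read would miss the write.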
+
+ FixupSourceAndSinkEdges(root.graph());
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_EXPECT_OK(root.ToGraph(graph.get()));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+ gtl::FlatMap<string, std::vector<string>> cluster_sets =
+ GetClusterSets(*graph);
+ ASSERT_EQ(cluster_sets.size(), 1);
+ std::vector<string> expected_clustered_nodes = {"AssignmentW",
+ "ValueToAssignW"};
+ ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes);
+}
+
+TEST(XlaCompilationTest, ChainOfOps) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* write_0 = MakeWrite(root, "W0");
+ Node* neutral_0 = MakeNeutral(root, "N0");
+ Node* read_0 = MakeRead(root, "R0");
+ Node* write_1 = MakeWrite(root, "W1");
+ Node* neutral_1 = MakeNeutral(root, "N1");
+ Node* read_1 = MakeRead(root, "R1");
+
+ root.graph()->AddControlEdge(write_0, neutral_0);
+ root.graph()->AddControlEdge(neutral_0, read_0);
+ root.graph()->AddControlEdge(read_0, write_1);
+ root.graph()->AddControlEdge(write_1, neutral_1);
+ root.graph()->AddControlEdge(neutral_1, read_1);
+
+ FixupSourceAndSinkEdges(root.graph());
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_EXPECT_OK(root.ToGraph(graph.get()));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+ std::vector<string> cluster_names;
+ gtl::FlatMap<string, std::vector<string>> cluster_sets =
+ GetClusterSets(*graph, &cluster_names);
+
+ ASSERT_EQ(cluster_sets.size(), 2);
+
+ std::vector<string> expected_clustered_nodes_a = {"AssignmentW0", "ConstN0",
+ "ValueToAssignW0"};
+ ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a);
+
+ std::vector<string> expected_clustered_nodes_b = {
+ "AssignmentW1", "ConstN1", "ReadR0", "ValueToAssignW1"};
+ ASSERT_EQ(cluster_sets[cluster_names[1]], expected_clustered_nodes_b);
}
TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
Scope root = Scope::NewRootScope().ExitOnError();
{
- auto BuildNoopNode = [](StringPiece name, Graph* graph) {
+ auto BuildNoopNode = [](absl::string_view name, Graph* graph) {
NodeDefBuilder builder(name, "NoOp");
NodeDef def;
TF_CHECK_OK(builder.Finalize(&def));
@@ -524,11 +657,11 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph);
EXPECT_FALSE(status.ok());
- EXPECT_TRUE(str_util::StrContains(status.ToString(),
- "Edge from c to a would create a cycle.\n"
- "+-> a\n"
- "| b\n"
- "+-- c\n"));
+ EXPECT_TRUE(absl::StrContains(status.ToString(),
+ "Edge from c to a would create a cycle.\n"
+ "+-> a\n"
+ "| b\n"
+ "+-- c\n"));
}
TEST(XlaCompilationTest, Retval) {
@@ -693,5 +826,73 @@ TEST(XlaCompilationTest, ClusterControlTrigger) {
EXPECT_EQ(clusters, expected_clusters);
}
+TEST(XlaCompilationTest, RandomShape) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output shape_shape = ops::Const(root.WithOpName("shape_shape"), {2}, {1});
+ Output shape =
+ ops::RandomUniformInt(root.WithOpName("shape"), shape_shape,
+ ops::Const(root.WithOpName("minval"), 1),
+ ops::Const(root.WithOpName("maxval"), 20));
+ Output reshape_input =
+ ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({500, 500})));
+ Output reshape =
+ ops::Reshape(root.WithOpName("reshape"), reshape_input, shape);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+ EXPECT_EQ(clusters["shape"], "");
+}
+
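+// As above, but the stateful random op producing the Reshape's shape operand
+// is hidden behind a function call; the call node is likewise expected to
+// stay unclustered.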
+TEST(XlaCompilationTest, RandomShapeWithFunc) {
+ Scope root = Scope::DisabledShapeInferenceScope().ExitOnError();
+
+ FunctionDefLibrary flib_def;
+ FunctionDef func = FunctionDefHelper::Create(
+ /*function_name=*/"Stateful_func", /*in_def=*/{},
+ /*out_def=*/{"out: int32"},
+ /*attr_def*/
+ {}, /*node_def=*/
+ {FunctionDefHelper::Const("shape_shape", 2),
+ FunctionDefHelper::Const("minval", 1),
+ FunctionDefHelper::Const("maxval", 20),
+ {{"shape"},
+ "RandomUniformInt",
+ {"shape_shape:output:0", "minval:output:0", "maxval:output:0"},
+ {{"Tout", DataType::DT_INT32}, {"T", DataType::DT_INT32}}}},
+ /*ret_def=*/{{"out", "shape:output:0"}});
+
+ func.mutable_signature()->set_is_stateful(true);
+ *flib_def.add_function() = std::move(func);
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+ NodeDef call_node;
+ call_node.set_name("fn_call");
+ call_node.set_op("Stateful_func");
+ Status status;
+ Node* call = root.graph()->AddNode(call_node, &status);
+ TF_ASSERT_OK(status);
+
+ Output shape = Output(call, 0);
+ Output reshape_input =
+ ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({500, 500})));
+ Output reshape =
+ ops::Reshape(root.WithOpName("reshape"), reshape_input, shape);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+ auto fld = absl::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(),
+ flib_def);
+ TF_ASSERT_OK(
+ MarkForCompilationPassTestHelper::MarkForCompilation(&graph, fld.get()));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+ EXPECT_EQ(clusters["fn_call"], "");
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
index a84b82e479..65669877f7 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
@@ -14,10 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
+#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
- std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def) {
+ std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
+ SessionOptions* session_options) {
// Assign all nodes to the CPU device.
static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
for (Node* n : (*graph)->nodes()) {
@@ -26,12 +28,19 @@ namespace tensorflow {
GraphOptimizationPassOptions opt_options;
opt_options.graph = graph;
+ opt_options.session_options = session_options;
opt_options.flib_def = flib_def;
MarkForCompilationPass pass;
return pass.RunImpl(opt_options);
}
/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
+ std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def) {
+ SessionOptions session_options;
+ return MarkForCompilation(graph, flib_def, &session_options);
+}
+
+/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
std::unique_ptr<Graph>* graph) {
FunctionDefLibrary flib;
FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h
index b9a0531cb0..216baaf933 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h
@@ -25,6 +25,11 @@ class MarkForCompilationPassTestHelper {
// `graph` to the CPU device. To make testing easier, ignores device
// registration, _XlaCompile attributes, input deadness and global jit level.
static Status MarkForCompilation(std::unique_ptr<Graph>* graph,
+ FunctionLibraryDefinition* flib_def,
+ SessionOptions* session_options);
+
+ // Like `MarkForCompilation` but creates a default SessionOptions.
+ static Status MarkForCompilation(std::unique_ptr<Graph>* graph,
FunctionLibraryDefinition* flib_def);
// Like `MarkForCompilation` but creates `flib_def` from the op registry.
diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD
index c9e46bc147..13804c6a05 100644
--- a/tensorflow/compiler/jit/ops/BUILD
+++ b/tensorflow/compiler/jit/ops/BUILD
@@ -10,10 +10,3 @@ cc_library(
deps = ["//tensorflow/core:framework"],
alwayslink = 1,
)
-
-cc_library(
- name = "parallel_check_op",
- srcs = ["parallel_check_op.cc"],
- deps = ["//tensorflow/core:framework"],
- alwayslink = 1,
-)
diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc
index f2473d98ff..1a29c3caab 100644
--- a/tensorflow/compiler/jit/ops/xla_ops.cc
+++ b/tensorflow/compiler/jit/ops/xla_ops.cc
@@ -13,10 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
+using shape_inference::InferenceContext;
+
REGISTER_OP("XlaLaunch")
.Input("constants: Tconstants")
.Attr("Tconstants: list(type) >= 0")
@@ -32,4 +36,19 @@ REGISTER_OP("XlaLaunch")
.SetIsStateful()
.Doc("XLA Launch Op. For use by the XLA JIT only.");
+REGISTER_OP("XlaClusterOutput")
+ .Input("input: T")
+ // Note: when replication is supported, this op will have N outputs.
+ .Output("outputs: T")
+ .Attr("T: type")
+ .SetShapeFn([](InferenceContext* c) {
+ for (int i = 0; i < c->num_outputs(); ++i) {
+ c->set_output(i, c->input(0));
+ }
+ return Status::OK();
+ })
+ .Doc(
+ "Operator that connects the output of an XLA computation to other "
+ "consumer graph nodes.");
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc
index 68ead39424..10fc9e85d9 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass.cc
@@ -14,7 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/compiler/tf2xla/const_analysis.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/gtl/flatset.h"
@@ -22,7 +26,7 @@ limitations under the License.
namespace tensorflow {
namespace {
Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result,
- gtl::ArraySlice<Node*> post_order) {
+ absl::Span<Node* const> post_order) {
// Find nodes that have at least one user outside their cluster that expects
// hostmem output. These nodes should be cloned to outside the cluster to
// avoid the device-host copy we'd otherwise need.
@@ -30,7 +34,7 @@ Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result,
MemoryTypeVector input_mtypes, output_mtypes;
for (Node* n : post_order) {
- gtl::optional<StringPiece> from_cluster = GetXlaClusterForNode(*n);
+ absl::optional<absl::string_view> from_cluster = GetXlaClusterForNode(*n);
if (!from_cluster) {
continue;
}
@@ -79,8 +83,8 @@ Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result,
// Check if `dst` is in a different cluster, unclustered, or about to be
// partially declustered (here we rely on the post-order traversal order).
// If yes, decluster `n` to avoid the device-to-host memcpy.
- gtl::optional<StringPiece> dst_cluster =
- result->count(dst) ? gtl::nullopt : GetXlaClusterForNode(*dst);
+ absl::optional<absl::string_view> dst_cluster =
+ result->count(dst) ? absl::nullopt : GetXlaClusterForNode(*dst);
if (from_cluster != dst_cluster) {
CHECK(result->insert(n).second);
break;
@@ -91,15 +95,16 @@ Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result,
}
Status PartiallyDeclusterNode(Graph* graph, Node* n) {
- StringPiece cluster_name = *GetXlaClusterForNode(*n);
- gtl::InlinedVector<const Edge*, 6> out_edges_to_clone;
+ absl::string_view cluster_name = *GetXlaClusterForNode(*n);
+ absl::InlinedVector<const Edge*, 6> out_edges_to_clone;
for (const Edge* out_edge : n->out_edges()) {
if (out_edge->IsControlEdge()) {
continue;
}
Node* dst = out_edge->dst();
- gtl::optional<StringPiece> dst_cluster_name = GetXlaClusterForNode(*dst);
+ absl::optional<absl::string_view> dst_cluster_name =
+ GetXlaClusterForNode(*dst);
if (dst_cluster_name != cluster_name) {
out_edges_to_clone.push_back(out_edge);
}
@@ -108,7 +113,7 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) {
CHECK(!out_edges_to_clone.empty()) << n->DebugString();
NodeDef ndef = n->def();
- ndef.set_name(strings::StrCat(n->name(), "/declustered"));
+ ndef.set_name(absl::StrCat(n->name(), "/declustered"));
RemoveFromXlaCluster(&ndef);
Status s;
Node* cloned_node = graph->AddNode(ndef, &s);
@@ -128,30 +133,47 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) {
return Status::OK();
}
-} // namespace
-Status PartiallyDeclusterPass::Run(
- const GraphOptimizationPassOptions& options) {
- // NB! In this pass we assume the only XLA-auto-clusterable operations that
- // may have side effects are resource variable operations so we don't cluster
- // those. The pass will have to be updated if this assumption becomes
- // invalid.
-
- Graph* graph = options.graph->get();
+bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); }
+// Clones nodes to outside their cluster to avoid device-to-host copies. For
+// instance, converts this:
+//
+// .....
+// |
+// v
+// A_Clustered ====> C_Unclustered
+// |
+// v
+// B_Clustered
+//
+// to:
+//
+// .....
+// | |
+// | +-------------+
+// | |
+// v v
+// A_Clustered A_Unclustered ====> C_Unclustered
+// |
+// v
+// B_Clustered
+//
+// where the ===> arrow has a hostmem source and destination and would entail a
+// device-to-host copy if the source and destination were not in the same XLA
+// cluster.
+Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) {
// When deciding whether to decluster a particular node, we base our decision
// on if we've decided that some of its consumers have to be declustered too.
// Iterating the graph in post-order guarantees that consumers have been
// visited before producers.
std::vector<Node*> post_order;
GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
- /*edge_filter=*/[](const Edge& edge) {
- return !edge.src()->IsNextIteration();
- });
+ /*edge_filter=*/NotBackedge);
gtl::FlatSet<Node*> nodes_to_partially_decluster;
- TF_RETURN_IF_ERROR(FindNodesToDecluster(
- **options.graph, &nodes_to_partially_decluster, post_order));
+ TF_RETURN_IF_ERROR(
+ FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
if (VLOG_IS_ON(3)) {
for (Node* n : post_order) {
@@ -168,10 +190,133 @@ Status PartiallyDeclusterPass::Run(
}
nodes_to_partially_decluster.clear();
- TF_RETURN_IF_ERROR(FindNodesToDecluster(
- **options.graph, &nodes_to_partially_decluster, post_order));
+ TF_RETURN_IF_ERROR(
+ FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
CHECK(nodes_to_partially_decluster.empty());
return Status::OK();
}
+
+bool IsIntraClusterEdge(const Edge& edge) {
+ absl::optional<absl::string_view> src_cluster_name =
+ GetXlaClusterForNode(*edge.src());
+ absl::optional<absl::string_view> dst_cluster_name =
+ GetXlaClusterForNode(*edge.dst());
+ return src_cluster_name.has_value() && src_cluster_name == dst_cluster_name;
+}
+
+Status MustCompileNode(const Node* n, bool* result) {
+ DeviceType device_type("");
+ TF_RETURN_IF_ERROR(
+ DeviceToDeviceType(n->assigned_device_name(), &device_type));
+
+ const XlaOpRegistry::DeviceRegistration* registration;
+ if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
+ *result = false;
+ } else {
+ *result = registration->requires_compilation;
+ }
+
+ return Status::OK();
+}
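+
+// Note: device registrations that set requires_compilation (e.g. the XLA_*
+// devices) make MustCompileNode return true, so nodes already placed on such
+// devices are never declustered; see the DontDeclusterXlaDeviceOps test.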
+
+// Declusters nodes to reduce the number of times we think we need to recompile
+// a TensorFlow graph.
+//
+// Abstractly, if we have a cluster of this form:
+//
+// x0 = arg0
+// x1 = arg1
+// ...
+// shape = f(x0, x1, ...)
+// result = Reshape(input=<something>, new_shape=shape)
+//
+// then pulling `f` out of the cluster may reduce the number of compilations and
+// will never increase the number of compilations.
+//
+// We may reduce the number of compilations if f is many to one. For instance
+// if f(x,y) = x-y then x=3,y=1 and x=4,y=2 will generate two different
+// compilations if f is in the cluster but only one compilation if f is outside
+// the cluster.
+//
+// Declustering f will increase the number of compilations only if f is a
+// one-to-many "function" i.e. isn't a function at all. RNG is one possible
+// example, depending on how we look at it. But we never create clusters where
+// such f's would be marked as must-be-constant.
+//
+// We assume here that the extra repeated host-side computation of f (repeated
+// relative to a clustered f, where it would always be constant folded) does
+// not regress performance in any significant manner. We will have to revisit
+// this algorithm with a more complex cost model if this assumption turns out
+// to be incorrect.
+Status DeclusterNodesToReduceRecompilations(Graph* graph) {
+ std::vector<bool> compile_time_const_nodes(graph->num_node_ids());
+ TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
+ *graph, nullptr, &compile_time_const_nodes, IsIntraClusterEdge));
+
+ std::vector<Node*> rpo;
+ GetReversePostOrder(*graph, &rpo, /*stable_comparator=*/NodeComparatorName(),
+ /*edge_filter=*/NotBackedge);
+ for (Node* n : rpo) {
+ if (!compile_time_const_nodes[n->id()]) {
+ continue;
+ }
+
+ absl::string_view cluster_name = *GetXlaClusterForNode(*n);
+ bool node_on_cluster_edge =
+ absl::c_all_of(n->in_edges(), [&](const Edge* e) {
+ absl::optional<absl::string_view> incoming_cluster =
+ GetXlaClusterForNode(*e->src());
+ return !incoming_cluster || *incoming_cluster != cluster_name;
+ });
+
+ // We don't want to decluster F in a graph like
+ //
+ // Input -> OP -> Shape -> F -> Reshape
+ //
+ // Doing so will break up the cluster. Even if we were okay with breaking
+  // up the cluster, we would at least have to relabel the two clusters to have
+ // different cluster names.
+ //
+ // We may want to revisit this in the future: we may have cases where OP is
+ // a small computation that does not benefit from XLA while XLA can optimize
+ // everything that follows the Reshape. In these cases it may be wise to
+ // remove Input, OP, Shape and F from the cluster, if F is a many-to-one
+ // function.
+ //
+  // Note that we do the right thing for graphs like:
+ //
+ // Input -> F0 -> F1 -> Reshape
+ //
+ // Since we iterate in RPO, we'll first encounter F0, decluster it, then
+ // encounter F1, decluster it and so on.
+ if (node_on_cluster_edge) {
+ bool must_compile_node;
+ TF_RETURN_IF_ERROR(MustCompileNode(n, &must_compile_node));
+ if (!must_compile_node) {
+ VLOG(3) << "Declustering must-be-constant node " << n->name();
+ RemoveFromXlaCluster(n);
+ }
+ }
+ }
+
+ return Status::OK();
+}
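+
+// A concrete instance of the rule above, mirroring the
+// DeclusterMustBeConstantNodes test: with shape = Add(shape_a, shape_b)
+// feeding reshape = Reshape(input, shape) inside one cluster, `shape` is
+// must-be-constant and all of its inputs arrive from outside the cluster, so
+// it is removed from the cluster and evaluated host-side instead.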
+
+} // namespace
+
+Status PartiallyDeclusterPass::Run(
+ const GraphOptimizationPassOptions& options) {
+ // NB! In this pass we assume the only XLA-auto-clusterable operations that
+ // may have side effects are resource variable operations so we don't cluster
+ // those. The pass will have to be updated if this assumption becomes
+ // invalid.
+
+ Graph* graph = options.graph->get();
+
+ TF_RETURN_IF_ERROR(PartiallyDeclusterToRemoveDeviceToHostCopies(graph));
+ TF_RETURN_IF_ERROR(DeclusterNodesToReduceRecompilations(graph));
+
+ return Status::OK();
+}
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.h b/tensorflow/compiler/jit/partially_decluster_pass.h
index 6949b5028e..cfc4ddb563 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass.h
+++ b/tensorflow/compiler/jit/partially_decluster_pass.h
@@ -20,34 +20,11 @@ limitations under the License.
namespace tensorflow {
-// Clones nodes from within a cluster to outside the cluster if profitable.
+// Clones or moves nodes from within a cluster to outside the cluster if
+// profitable. There are two reasons why we do this:
//
-// Today this only clones to avoid device-to-host copies, but in the future we
-// may consider other reasons to clone. For instance, we convert this:
-//
-// .....
-// |
-// v
-// A_Clustered ====> C_Unclustered
-// |
-// v
-// B_Clustered
-//
-// to:
-//
-// .....
-// | |
-// | +-------------+
-// | |
-// v v
-// A_Clustered A_Unclustered ====> C_Unclustered
-// |
-// v
-// B_Clustered
-//
-// where the ===> arrow has a hostmem source and destination and would entail a
-// device to host copy if the source and destination were not in the same XLA
-// cluster.
+// - Reducing device-to-host copies.
+// - Reducing the number of XLA recompilations.
class PartiallyDeclusterPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
index 08a956e4c6..35872daa65 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
+#include "absl/memory/memory.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
@@ -31,8 +32,8 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -83,7 +84,9 @@ Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
// Assign all nodes to the CPU device.
static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
for (Node* n : (*graph)->nodes()) {
- n->set_assigned_device_name(kCpuDevice);
+ if (n->assigned_device_name().empty()) {
+ n->set_assigned_device_name(kCpuDevice);
+ }
}
GraphOptimizationPassOptions opt_options;
@@ -92,8 +95,8 @@ Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
return pass.Run(opt_options);
}
-const Node* FindNodeByName(const Graph& graph, const string& name) {
- for (const Node* node : graph.nodes()) {
+Node* FindNodeByName(const Graph& graph, const string& name) {
+ for (Node* node : graph.nodes()) {
if (node->name() == name) {
return node;
}
@@ -280,5 +283,128 @@ TEST(PartiallyDeclusterPassTest, DeclusterDependentNodes) {
"ClusteredProducer0/declustered");
EXPECT_EQ(declustered_producer_1_inputs[1]->name(), "Input");
}
+
+void AddToCluster(absl::Span<Node* const> nodes,
+ absl::string_view cluster_name) {
+ for (Node* n : nodes) {
+ n->AddAttr(kXlaClusterAttr, string(cluster_name));
+ }
+}
+
+TEST(PartiallyDeclusterPassTest, DeclusterMustBeConstantNodes) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
+ ops::Placeholder::Attrs{});
+ Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
+ ops::Placeholder::Attrs{});
+ Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
+
+ Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+ DT_FLOAT, ops::Placeholder::Attrs{});
+ Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+ AddToCluster({shape.node(), reshape.node()}, "cluster_0");
+
+ auto graph = absl::make_unique<Graph>(OpRegistry::Global());
+ TF_ASSERT_OK(s.ToGraph(graph.get()));
+ TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+ const Node* n = FindNodeByName(*graph, "shape");
+ ASSERT_NE(n, nullptr);
+
+ EXPECT_EQ(GetXlaClusterForNode(*n), absl::nullopt);
+}
+
+TEST(PartiallyDeclusterPassTest, DeclusteringStopsAtMetadataOps) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output input_a = ops::Placeholder(s.WithOpName("input_a"), DT_INT32,
+ ops::Placeholder::Attrs{});
+ Output input_b = ops::Placeholder(s.WithOpName("shape_b"), DT_FLOAT,
+ ops::Placeholder::Attrs{});
+ Output mul = ops::Mul(s.WithOpName("mul"), input_b, input_b);
+ Output shape_of_mul = ops::Shape(s.WithOpName("shape_of_mul"), mul);
+
+ Output shape = ops::Add(s.WithOpName("shape"), shape_of_mul, input_a);
+
+ Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+ DT_FLOAT, ops::Placeholder::Attrs{});
+ Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+ AddToCluster({mul.node(), shape_of_mul.node(), shape.node(), reshape.node()},
+ "cluster_0");
+
+ std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
+ TF_ASSERT_OK(s.ToGraph(graph.get()));
+ TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+ const Node* n = FindNodeByName(*graph, "shape");
+ ASSERT_NE(n, nullptr);
+
+ EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
+}
+
+TEST(PartiallyDeclusterPassTest, EdgeAcrossDifferentClusters) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
+ ops::Placeholder::Attrs{});
+ Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
+ ops::Placeholder::Attrs{});
+ Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
+
+ Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+ DT_FLOAT, ops::Placeholder::Attrs{});
+ Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+ AddToCluster({reshape.node()}, "cluster_0");
+ AddToCluster({shape.node()}, "cluster_1");
+
+ std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
+ TF_ASSERT_OK(s.ToGraph(graph.get()));
+ TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+ const Node* n = FindNodeByName(*graph, "shape");
+ ASSERT_NE(n, nullptr);
+
+ EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_1");
+}
+
+TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
+ ops::Placeholder::Attrs{});
+ Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
+ ops::Placeholder::Attrs{});
+ Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
+
+ Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+ DT_FLOAT, ops::Placeholder::Attrs{});
+ Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+ AddToCluster({shape.node(), reshape.node()}, "cluster_0");
+
+ std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
+ TF_ASSERT_OK(s.ToGraph(graph.get()));
+
+ // This is needed to register the XLA_GPU device.
+ std::vector<Device*> devices;
+ TF_ASSERT_OK(DeviceFactory::AddDevices(
+ SessionOptions(), "/job:localhost/replica:0/task:0", &devices));
+
+ // Scope::ToGraph loses the assigned device name since it goes through
+ // GraphDef/NodeDef which does not have a field for the assigned device name.
+ Node* n = FindNodeByName(*graph, "shape");
+ ASSERT_NE(n, nullptr);
+ n->set_assigned_device_name(
+ "/job:localhost/replica:0/task:0/device:XLA_GPU:0");
+
+ TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+ EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
+
+ for (Device* d : devices) {
+ delete d;
+ }
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
new file mode 100644
index 0000000000..56e35c0059
--- /dev/null
+++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
@@ -0,0 +1,336 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// ALGORITHM OVERVIEW
+// ==================
+//
+// An XLA cluster hoists all resource reads to the beginning of the cluster
+// execution and all the resource writes to the end. This means it cannot
+// enforce arbitrary ordering dependencies (via control or data edges) between
+// resource operations. Since all resource reads happen before all resource
+// writes, edges constraining resource reads to happen before resource writes
+// are fine, but all other kinds of edges are problematic. This analysis
+// computes the set of pairs of resource operations that cannot be put in the
+// same cluster because XLA cannot respect the dependencies between them in the
+// TensorFlow program.
+//
+// TODO(b/112856632): We can, in theory, support Read->Read and Write->Write
+// dependencies.
+//
+// Specifically, the result computed by this analysis contains the edge {W, R}
+// iff all of these hold true:
+//
+// - In the graph (g - {edges from NextIteration to Merge}) there is a path
+// from W to R.
+// - IsEdgeSafe(W, R) == False [defined below]
+// - W != R (note: some resource operations both read from and write to
+// resource variables).
+//
+// The result is incorrect around loops because we ignore edges from
+// NextIteration to Merge, but that should be fine because we don't cluster
+// these edges. For instance, in:
+//
+// Init -----> Merge <-------+
+// | |
+// v |
+// Read |
+// | |
+// v |
+// Write |
+// | |
+// v |
+// NextIteration --+
+//
+// we won't put (Read, Write) in the returned set. This is fine if
+// auto-clustering can only cluster the Read->Write edge, but it is a problem if
+// it clusters the Write->NextIteration->Merge->Read edges instead. The same
+// problem is present for the functional version of the loop above. We rely on
+// auto-clustering to not cluster control flow edges like NextIteration->Merge.
+// This is enough to avoid the explicit-control-flow problem shown above. One
+// way to think about this is that we only care about cases where two nodes, A
+// and B, would normally have been put in the same cluster but cannot legally be
+// in the same cluster because of resource-variable dependencies. If A and B would
+// normally have been put in the same cluster then all paths between A and B
+// would have to be clusterable (otherwise we'd have introduced a cycle). Ergo
+// there could not have been a NextIteration->Merge edge between A and B since
+// we don't cluster these edges.
+//
+// We also rely on auto-clustering to not cluster functional control flow nodes
+// that contain resource operations.
+//
+// IMPLEMENTATION
+// --------------
+//
+// We traverse the graph minus backedges in reverse post order, mapping each
+// node to the set of resource operations reaching that node. Since we visit
+// producers before consumers, we can construct the set of reaching operations
+// by taking the union of the operations reaching the input nodes. These
+// "reaching resource operations" can then be used to create the pairs of
+// incompatible nodes using `IsEdgeSafe`.
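+//
+// A minimal worked example: in the graph Write -> Neutral -> Read (control
+// edges only), the write reaches the read, and Write->Read is not an ordering
+// XLA can enforce, so the analysis reports the single pair {Write, Read}.
+// For Read -> Write nothing is reported, since reads may precede writes.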
+
+#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/tensor_id.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+namespace {
+// Sets `*out_result` to true if `n` may call a function.
+Status MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def,
+ bool* out_result) {
+ if (flib_def->Contains(n.type_string())) {
+ *out_result = true;
+ } else {
+ *out_result =
+ std::any_of(n.def().attr().begin(), n.def().attr().end(),
+ [](const std::pair<string, AttrValue>& name_attr_pair) {
+ return name_attr_pair.second.has_func();
+ });
+ }
+
+ return Status::OK();
+}
+
+// Maps `n` to the XlaResourceOpKind corresponding to its operation. If `n` is
+// not a resource operation recognized by XLA then sets `out_resource_op_kind`
+// to nullopt.
+Status XlaResourceOpKindForNode(
+ const Node& n, const FunctionLibraryDefinition* flib_def,
+ const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
+ absl::optional<XlaResourceOpKind>* out_resource_op_kind) {
+ bool should_ignore = false;
+ if (resource_ops_to_ignore) {
+ TF_RETURN_IF_ERROR(resource_ops_to_ignore(n, &should_ignore));
+ }
+ if (should_ignore) {
+ *out_resource_op_kind = absl::nullopt;
+ return Status::OK();
+ }
+
+ const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.type_string());
+ if (op_info) {
+ *out_resource_op_kind = op_info->kind();
+ return Status::OK();
+ }
+
+ // We conservatively assume that functions will both read and write resource
+ // variables. In the future we may consider doing some form of
+ // inter-procedural analysis.
+ bool may_call_function;
+ TF_RETURN_IF_ERROR(MayCallFunction(n, flib_def, &may_call_function));
+ if (may_call_function) {
+ *out_resource_op_kind = XlaResourceOpKind::kReadWrite;
+ } else {
+ *out_resource_op_kind = absl::nullopt;
+ }
+
+ return Status::OK();
+}
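+
+// A consequence of the conservative kReadWrite classification above: a
+// function call node pairs unsafely with any resource op that reaches it or
+// that it reaches, in either direction (see the CallRead/ReadCall/CallWrite/
+// WriteCall tests).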
+
+// Returns true if a control or data dependence from a TensorFlow operation of
+// resource op kind `from` to a TensorFlow operation of resource op kind `to`
+// can be represented by an XLA cluster and needs no special handling around
+// auto-jit.
+bool IsEdgeSafe(XlaResourceOpKind from, XlaResourceOpKind to) {
+  // An XLA cluster forces all reads to happen before all writes, which means
+  // the kinds of edges it can faithfully represent are: Read->Write,
+  // Read->Modify, Modify->Write, Read->Read, Write->Write.
+ //
+ // TODO(b/112856632): We can, in theory, support Read->Read and Write->Write
+ // dependencies.
+ return from == XlaResourceOpKind::kRead && to == XlaResourceOpKind::kWrite;
+}
+
+using ResourceOp = std::pair<int, XlaResourceOpKind>;
+
+string ResourceOpToString(const ResourceOp& resource_op) {
+ return absl::StrCat(
+ resource_op.first, ": ",
+ XlaResourceOpInfo::XlaResourceOpKindToString(resource_op.second));
+}
+
+// A copy-on-write set used to store the set of ResourceOps reaching a node in a
+// TensorFlow graph.
+//
+// TODO(sanjoy): It may be useful to pull this out into its own header at some
+// point.
+class ResourceOpSet {
+ private:
+ using Impl = gtl::FlatSet<ResourceOp>;
+
+ public:
+ ResourceOpSet() = default;
+
+  // Adds all ResourceOps in `other` to this set.
+ void Add(const ResourceOpSet& other) {
+ CHECK(!frozen_);
+ if (other.impl_ == impl_) {
+ other.frozen_ = true;
+ return;
+ }
+
+ if (!impl_) {
+ other.frozen_ = true;
+ impl_ = other.impl_;
+ return;
+ }
+
+ for (ResourceOp resource_op : other) {
+ Add(resource_op);
+ }
+ }
+
+ void Add(const ResourceOp& resource_op) {
+ CHECK(!frozen_);
+ if (!IsCopy() && Contains(resource_op)) {
+ // We can avoid the copy if the item we want to insert already exists.
+ return;
+ }
+
+ EnsureIsCopied();
+ impl_->insert(resource_op);
+ }
+
+ Impl::const_iterator begin() const {
+ return impl_ ? impl_->begin() : GetEmptyImpl()->begin();
+ }
+
+ Impl::const_iterator end() const {
+ return impl_ ? impl_->end() : GetEmptyImpl()->end();
+ }
+
+ bool Contains(const ResourceOp& resource_op) const {
+ return impl_ != nullptr && impl_->count(resource_op);
+ }
+
+ private:
+ bool IsCopy() const { return storage_ != nullptr; }
+
+ void EnsureIsCopied() {
+ if (storage_ == nullptr) {
+ storage_ = absl::make_unique<Impl>();
+ for (ResourceOp op : *this) {
+ storage_->insert(op);
+ }
+ impl_ = storage_.get();
+ }
+ }
+
+ static Impl* GetEmptyImpl() {
+ static Impl* empty_impl = new Impl;
+ return empty_impl;
+ }
+
+ Impl* impl_ = nullptr;
+ std::unique_ptr<Impl> storage_;
+
+  // frozen_ is true if there is another set pointing to this set's impl_. We
+  // can no longer add elements to this set in that case, since the sets
+  // aliasing impl_ expect its contents to remain stable.
+ mutable bool frozen_ = false;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ResourceOpSet);
+};
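+
+// A minimal usage sketch of the copy-on-write behavior (variable names are
+// illustrative):
+//
+//   ResourceOpSet a, b;
+//   a.Add({0, XlaResourceOpKind::kRead});
+//   b.Add(a);   // `b` now aliases `a`'s storage and freezes `a`.
+//   b.Add({1, XlaResourceOpKind::kWrite});  // `b` copies before inserting.
+//
+// Afterwards `a` still contains only the read and may no longer be mutated,
+// while `b` owns a private copy holding both elements.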
+
+string ResourceOpSetToString(const ResourceOpSet& resource_op_set) {
+ std::vector<string> elements_debug_string;
+ std::transform(resource_op_set.begin(), resource_op_set.end(),
+ std::back_inserter(elements_debug_string), ResourceOpToString);
+ return absl::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}");
+}
+
+string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) {
+ return absl::StrCat(
+ "[", n.name(), ": ", n.type_string(), "(",
+ XlaResourceOpInfo::XlaResourceOpKindToString(resource_op_kind), ")", "]");
+}
+} // namespace
+
+Status ComputeIncompatibleResourceOperationPairs(
+ const Graph& g, const FunctionLibraryDefinition* flib_def,
+ const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
+ std::vector<std::pair<int, int>>* result) {
+ CHECK(result->empty());
+
+ std::vector<Node*> rpo;
+ GetReversePostOrder(g, &rpo, /*stable_comparator=*/NodeComparatorName(),
+ /*edge_filter=*/[](const Edge& edge) {
+ return !edge.src()->IsNextIteration();
+ });
+
+ auto resource_op_set_for_node =
+ absl::make_unique<ResourceOpSet[]>(g.num_node_ids());
+
+ const bool vlog = VLOG_IS_ON(2);
+
+ for (Node* n : rpo) {
+ absl::optional<XlaResourceOpKind> op_kind;
+ TF_RETURN_IF_ERROR(XlaResourceOpKindForNode(
+ *n, flib_def, resource_ops_to_ignore, &op_kind));
+
+ ResourceOpSet* resource_op_set = &resource_op_set_for_node[n->id()];
+
+ // Merge the reaching resource operations for all the incoming edges to
+ // create the set of all possible resource ops reaching `n`.
+ for (const Edge* e : n->in_edges()) {
+ if (n->IsMerge() && e->src()->IsNextIteration()) {
+ // Ignore back-edges (see file comment).
+ continue;
+ }
+
+ const ResourceOpSet& incoming_op_set =
+ resource_op_set_for_node[e->src()->id()];
+ resource_op_set->Add(incoming_op_set);
+ }
+
+ // Add to the "incompatible resource ops" set if necessary.
+ if (op_kind) {
+ for (ResourceOp incoming_op : *resource_op_set) {
+ if (IsEdgeSafe(incoming_op.second, *op_kind)) {
+ continue;
+ }
+
+ if (vlog) {
+ VLOG(2) << "Unsafe edge: "
+ << NodeToString(*g.FindNodeId(incoming_op.first),
+ incoming_op.second)
+ << " -> " << NodeToString(*n, *op_kind);
+ }
+ result->push_back({incoming_op.first, n->id()});
+ }
+
+ resource_op_set->Add({n->id(), *op_kind});
+ }
+
+ if (vlog) {
+ VLOG(3) << n->name() << " -> " << ResourceOpSetToString(*resource_op_set);
+ }
+ }
+
+ std::sort(result->begin(), result->end());
+ CHECK(std::unique(result->begin(), result->end()) == result->end());
+
+ return Status::OK();
+}
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.h b/tensorflow/compiler/jit/resource_operation_safety_analysis.h
new file mode 100644
index 0000000000..ae8cfeecad
--- /dev/null
+++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.h
@@ -0,0 +1,73 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_
+#define TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_
+
+#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+// An XLA cluster hoists all resource reads to the beginning of the cluster
+// execution and all the resource writes to the end. This means it cannot
+// enforce arbitrary ordering dependencies (via control or data edges) between
+// resource operations. Since all resource reads happen before all resource
+// writes, edges constraining resource reads to happen before resource writes
+// are fine, but all other kinds of edges are problematic. This analysis
+// returns the set of pairs of resource operations that cannot be put in the
+// same cluster because XLA cannot respect the dependencies between them in the
+// TensorFlow program.
+//
+// The restrictions are not transitive: it is fine to put A and C in the same
+// cluster even if the returned set contains (A,B) and (B,C).
+//
+// In other words, if these pairs are seen as edges in an undirected graph of
+// the nodes in `g` then auto-clustering is at least as constrained as the graph
+// coloring problem on this graph.
+//
+//
+// For instance if we auto-cluster all operations in this TensorFlow graph:
+//
+// ReadVariableOp0 -> ReadVariableOp1
+// |
+// v
+// AssignVariableOp0 -> AssignVariableOp1
+//
+// we will lose the ReadVariableOp0 -> ReadVariableOp1 and the
+// AssignVariableOp0 -> AssignVariableOp1 dependencies. I.e. it is possible for
+// XlaLaunchOp to issue ReadVariableOp1 before ReadVariableOp0 since it reads
+// all the resource variables when the cluster starts executing without any
+// particular ordering between them; same holds for the AssignVariableOp0 ->
+// AssignVariableOp1 edge. The ReadVariableOp1 -> AssignVariableOp0 edge will
+// be respected by XlaLaunchOp though because all reads happen before all
+// writes.
+//
+//
+// NB! The result computed by this analysis assumes that we don't auto-cluster
+// back-edges (i.e. the edges from NextIteration to Merge).
+//
+// NB! The result computed by this analysis assumes that we don't auto-cluster
+// functional control flow nodes containing resource operations.
+//
+// If `resource_ops_to_ignore` is set then nodes for which it returns true are
+// ignored (we pretend these nodes are not resource operations).
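+//
+// A minimal call sketch (variable names are illustrative):
+//
+//   std::vector<std::pair<int, int>> incompatible_pairs;
+//   TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs(
+//       *graph, &graph->flib_def(), /*resource_ops_to_ignore=*/{},
+//       &incompatible_pairs));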
+Status ComputeIncompatibleResourceOperationPairs(
+ const Graph& g, const FunctionLibraryDefinition* flib_def,
+ const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
+ std::vector<std::pair<int, int>>* result);
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_
diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc
new file mode 100644
index 0000000000..e54b547abc
--- /dev/null
+++ b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc
@@ -0,0 +1,540 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/functional_ops.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
+#include "tensorflow/cc/ops/sendrecv_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/graph_def_builder_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+Node* MakeRead(const Scope& scope, const string& id) {
+ Output var_handle =
+ ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
+ Output read =
+ ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT);
+ return read.node();
+}
+
+Node* MakeWrite(const Scope& scope, const string& id) {
+ Output var_handle =
+ ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
+ Output value_to_write =
+ ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f);
+ ops::AssignVariableOp assign_op(scope.WithOpName("Assignee" + id), var_handle,
+ value_to_write);
+ return assign_op.operation.node();
+}
+
+Node* MakeModify(const Scope& scope, const string& id) {
+ Output var_handle =
+ ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
+ Output value_to_write = ops::Const(scope.WithOpName("Increment" + id), 1.0f);
+ ops::AssignAddVariableOp assign_add_op(scope.WithOpName("Increment" + id),
+ var_handle, value_to_write);
+ return assign_add_op.operation.node();
+}
+
+Node* MakeNeutral(const Scope& scope, const string& id) {
+ return ops::Const(scope.WithOpName("Const" + id), 42.0f).node();
+}
+
+Status ComputeIncompatiblePairs(Graph* g,
+ std::vector<std::pair<int, int>>* result) {
+ FixupSourceAndSinkEdges(g);
+ return ComputeIncompatibleResourceOperationPairs(*g, &g->flib_def(), {},
+ result);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, WriteRead) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(write, read);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair<int, int> write_read_pair = {write->id(), read->id()};
+ EXPECT_EQ(incompatible_pairs[0], write_read_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ReadWrite) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(read, write);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ EXPECT_EQ(incompatible_pairs.size(), 0);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ReadWriteNoEdges) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ MakeRead(root, "R");
+ MakeWrite(root, "W");
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ EXPECT_EQ(incompatible_pairs.size(), 0);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ReadModify) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* modify = MakeModify(root, "M");
+
+ root.graph()->AddControlEdge(read, modify);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ EXPECT_EQ(incompatible_pairs.size(), 1);
+ std::pair<int, int> read_modify_pair = {read->id(), modify->id()};
+ EXPECT_EQ(incompatible_pairs[0], read_modify_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ModifyRead) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* modify = MakeModify(root, "M");
+
+ root.graph()->AddControlEdge(modify, read);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair<int, int> modify_read_pair = {modify->id(), read->id()};
+ EXPECT_EQ(incompatible_pairs[0], modify_read_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ModifyWrite) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* modify = MakeModify(root, "M");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(modify, write);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ EXPECT_EQ(incompatible_pairs.size(), 1);
+ std::pair<int, int> modify_write_pair = {modify->id(), write->id()};
+ EXPECT_EQ(incompatible_pairs[0], modify_write_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, WriteModify) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* modify = MakeModify(root, "M");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(write, modify);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair<int, int> write_modify_pair = {write->id(), modify->id()};
+ EXPECT_EQ(incompatible_pairs[0], write_modify_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ReadModifyWrite) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* modify = MakeModify(root, "M");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(read, modify);
+ root.graph()->AddControlEdge(modify, write);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ EXPECT_EQ(incompatible_pairs.size(), 2);
+ std::pair<int, int> modify_write_pair = {modify->id(), write->id()};
+ std::pair<int, int> read_modify_pair = {read->id(), modify->id()};
+ EXPECT_EQ(incompatible_pairs[0], read_modify_pair);
+ EXPECT_EQ(incompatible_pairs[1], modify_write_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, WriteModifyRead) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* modify = MakeModify(root, "M");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(write, modify);
+ root.graph()->AddControlEdge(modify, read);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 3);
+
+ std::pair<int, int> write_modify_pair = {write->id(), modify->id()};
+ std::pair<int, int> modify_read_pair = {modify->id(), read->id()};
+ std::pair<int, int> write_read_pair = {write->id(), read->id()};
+ EXPECT_EQ(incompatible_pairs[0], modify_read_pair);
+ EXPECT_EQ(incompatible_pairs[1], write_read_pair);
+ EXPECT_EQ(incompatible_pairs[2], write_modify_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, WriteReadModify) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* read = MakeRead(root, "R");
+ Node* modify = MakeModify(root, "M");
+ Node* write = MakeWrite(root, "W");
+
+ root.graph()->AddControlEdge(write, read);
+ root.graph()->AddControlEdge(read, modify);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 3);
+
+ std::pair<int, int> write_modify_pair = {write->id(), modify->id()};
+ std::pair<int, int> write_read_pair = {write->id(), read->id()};
+ std::pair<int, int> read_modify_pair = {read->id(), modify->id()};
+ EXPECT_EQ(incompatible_pairs[0], read_modify_pair);
+ EXPECT_EQ(incompatible_pairs[1], write_read_pair);
+ EXPECT_EQ(incompatible_pairs[2], write_modify_pair);
+}
+
+FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) {
+ FunctionDefLibrary flib_def;
+ FunctionDef func = FunctionDefHelper::Create(
+ /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"},
+ /*attr_def*/
+ {}, /*node_def=*/{FunctionDefHelper::Const("one", 1.0f)},
+ /*ret_def=*/{{"out", "out:output:0"}});
+ *flib_def.add_function() = std::move(func);
+ return flib_def;
+}
+
+Node* MakeCall(Graph* graph, const string& callee_name, const string& node_name,
+ Status* status) {
+ NodeDef call_node;
+ call_node.set_name(node_name);
+ call_node.set_op(callee_name);
+ return graph->AddNode(call_node, status);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, CallRead) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("Const_func");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+
+ Node* read = MakeRead(root, "R");
+ Status status;
+ Node* call = MakeCall(root.graph(), "Const_func", "C", &status);
+ TF_ASSERT_OK(status);
+
+ root.graph()->AddControlEdge(call, read);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair<int, int> call_read_edge = {call->id(), read->id()};
+ EXPECT_EQ(incompatible_pairs[0], call_read_edge);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ReadCall) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("Const_func");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+
+ Node* read = MakeRead(root, "R");
+ Status status;
+ Node* call = MakeCall(root.graph(), "Const_func", "C", &status);
+ TF_ASSERT_OK(status);
+
+ root.graph()->AddControlEdge(read, call);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair<int, int> read_call_edge = {read->id(), call->id()};
+ EXPECT_EQ(incompatible_pairs[0], read_call_edge);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, CallWrite) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("Const_func");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+
+ Node* write = MakeWrite(root, "W");
+ Status status;
+ Node* call = MakeCall(root.graph(), "Const_func", "C", &status);
+ TF_ASSERT_OK(status);
+
+ root.graph()->AddControlEdge(call, write);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair<int, int> call_write_edge = {call->id(), write->id()};
+ EXPECT_EQ(incompatible_pairs[0], call_write_edge);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, WriteCall) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("Const_func");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+
+ Node* write = MakeWrite(root, "W");
+ Status status;
+ Node* call = MakeCall(root.graph(), "Const_func", "C", &status);
+ TF_ASSERT_OK(status);
+
+ root.graph()->AddControlEdge(write, call);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair<int, int> write_call_edge = {write->id(), call->id()};
+ EXPECT_EQ(incompatible_pairs[0], write_call_edge);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, SymbolicGradientRead) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("Const_func");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+
+ Node* read = MakeRead(root, "R");
+ NameAttrList fn;
+ fn.set_name("Const_func");
+ Node* symbolic_gradient =
+ ops::SymbolicGradient(root, /*input=*/{ops::Const(root, 1.0f)},
+ /*Tout=*/{DT_FLOAT}, fn)
+ .output[0]
+ .node();
+
+ root.graph()->AddControlEdge(symbolic_gradient, read);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair<int, int> symbolic_gradient_read_edge = {symbolic_gradient->id(),
+ read->id()};
+ EXPECT_EQ(incompatible_pairs[0], symbolic_gradient_read_edge);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, WriteSymbolicGradient) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("Const_func");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+
+ Node* write = MakeWrite(root, "W");
+ NameAttrList fn;
+ fn.set_name("Const_func");
+ Node* symbolic_gradient =
+ ops::SymbolicGradient(root, /*input=*/{ops::Const(root, 1.0f)},
+ /*Tout=*/{DT_FLOAT}, fn)
+ .output[0]
+ .node();
+
+ root.graph()->AddControlEdge(write, symbolic_gradient);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+ std::pair<int, int> write_symbolic_gradient_edge = {write->id(),
+ symbolic_gradient->id()};
+ EXPECT_EQ(incompatible_pairs[0], write_symbolic_gradient_edge);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, ChainOfOps) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* write_0 = MakeWrite(root, "W0");
+ Node* neutral_0 = MakeNeutral(root, "N0");
+ Node* read_0 = MakeRead(root, "R0");
+ Node* write_1 = MakeWrite(root, "W1");
+ Node* neutral_1 = MakeNeutral(root, "N1");
+ Node* read_1 = MakeRead(root, "R1");
+
+ root.graph()->AddControlEdge(write_0, neutral_0);
+ root.graph()->AddControlEdge(neutral_0, read_0);
+ root.graph()->AddControlEdge(read_0, write_1);
+ root.graph()->AddControlEdge(write_1, neutral_1);
+ root.graph()->AddControlEdge(neutral_1, read_1);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 5);
+ std::pair<int, int> write_0_read_0_pair = {write_0->id(), read_0->id()};
+ std::pair<int, int> write_0_read_1_pair = {write_0->id(), read_1->id()};
+ std::pair<int, int> write_1_read_1_pair = {write_1->id(), read_1->id()};
+ std::pair<int, int> write_0_write_1_pair = {write_0->id(), write_1->id()};
+ std::pair<int, int> read_0_read_1_pair = {read_0->id(), read_1->id()};
+
+ EXPECT_EQ(incompatible_pairs[0], write_0_read_0_pair);
+ EXPECT_EQ(incompatible_pairs[1], write_0_write_1_pair);
+ EXPECT_EQ(incompatible_pairs[2], write_0_read_1_pair);
+ EXPECT_EQ(incompatible_pairs[3], read_0_read_1_pair);
+ EXPECT_EQ(incompatible_pairs[4], write_1_read_1_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, DagOfOps) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* write_0 = MakeWrite(root, "W0");
+ Node* write_1 = MakeWrite(root, "W1");
+ Node* neutral = MakeNeutral(root, "N");
+ Node* read_0 = MakeRead(root, "R0");
+ Node* read_1 = MakeRead(root, "R1");
+
+ root.graph()->AddControlEdge(write_0, neutral);
+ root.graph()->AddControlEdge(write_1, neutral);
+ root.graph()->AddControlEdge(neutral, read_0);
+ root.graph()->AddControlEdge(neutral, read_1);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 4);
+ std::pair<int, int> write_0_read_0_pair = {write_0->id(), read_0->id()};
+ std::pair<int, int> write_0_read_1_pair = {write_0->id(), read_1->id()};
+ std::pair<int, int> write_1_read_0_pair = {write_1->id(), read_0->id()};
+ std::pair<int, int> write_1_read_1_pair = {write_1->id(), read_1->id()};
+
+ EXPECT_EQ(incompatible_pairs[0], write_0_read_0_pair);
+ EXPECT_EQ(incompatible_pairs[1], write_0_read_1_pair);
+ EXPECT_EQ(incompatible_pairs[2], write_1_read_0_pair);
+ EXPECT_EQ(incompatible_pairs[3], write_1_read_1_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, DagOfOpsWithRepeatedPaths) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Node* write_0 = MakeWrite(root, "W0");
+ Node* write_1 = MakeWrite(root, "W1");
+ Node* neutral = MakeNeutral(root, "N");
+ Node* read_0 = MakeRead(root, "R0");
+ Node* read_1 = MakeRead(root, "R1");
+
+ root.graph()->AddControlEdge(write_0, neutral);
+ root.graph()->AddControlEdge(write_1, neutral);
+ root.graph()->AddControlEdge(neutral, read_0);
+ root.graph()->AddControlEdge(neutral, read_1);
+ root.graph()->AddControlEdge(write_1, read_1);
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 4);
+ std::pair<int, int> write_0_read_0_pair = {write_0->id(), read_0->id()};
+ std::pair<int, int> write_0_read_1_pair = {write_0->id(), read_1->id()};
+ std::pair<int, int> write_1_read_0_pair = {write_1->id(), read_0->id()};
+ std::pair<int, int> write_1_read_1_pair = {write_1->id(), read_1->id()};
+
+ EXPECT_EQ(incompatible_pairs[0], write_0_read_0_pair);
+ EXPECT_EQ(incompatible_pairs[1], write_0_read_1_pair);
+ EXPECT_EQ(incompatible_pairs[2], write_1_read_0_pair);
+ EXPECT_EQ(incompatible_pairs[3], write_1_read_1_pair);
+}
+
+TEST(ResourceOperationSafetyAnalysisTest, Loop) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
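+  // Build a minimal while-loop skeleton: Enter -> Merge -> Switch -> Exit,
+  // with NextIteration wired back into the Merge, so that the write/read
+  // pair below sits on the loop body's control path.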
+ Output init_value = ops::Placeholder(root.WithOpName("init"), DT_FLOAT);
+  Output loop_cond = ops::Placeholder(root.WithOpName("cond"), DT_BOOL);
+ Output enter_value =
+ ops::internal::Enter(root.WithOpName("enter"), init_value, "fr");
+ ops::Merge iv(root.WithOpName("iv"), {enter_value, enter_value});
+ ops::Switch latch(root.WithOpName("latch"), iv.output, loop_cond);
+ ops::internal::Exit exit(root.WithOpName("exit"), iv.output);
+ Output next_iteration =
+ ops::NextIteration(root.WithOpName("next_iteration"), latch.output_true);
+ TF_ASSERT_OK(
+ root.graph()->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1));
+
+ Node* write = MakeWrite(root, "W");
+ Node* read = MakeRead(root, "R");
+
+ root.graph()->AddControlEdge(iv.output.node(), write);
+ root.graph()->AddControlEdge(write, read);
+ root.graph()->AddControlEdge(read, next_iteration.node());
+
+ std::vector<std::pair<int, int>> incompatible_pairs;
+ TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs));
+
+ ASSERT_EQ(incompatible_pairs.size(), 1);
+
+ std::pair<int, int> write_read_pair = {write->id(), read->id()};
+ EXPECT_EQ(incompatible_pairs[0], write_read_pair);
+}
+
+bool IsResourceArgDef(const OpDef::ArgDef& arg_def) {
+ return arg_def.type() == DT_RESOURCE;
+}
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc
index 0a025a1fc0..f85121ca27 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.cc
+++ b/tensorflow/compiler/jit/xla_cluster_util.cc
@@ -17,6 +17,8 @@ limitations under the License.
#include <unordered_map>
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/kernels/bounds_check.h"
@@ -51,8 +53,8 @@ string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src,
};
string description;
- strings::StrAppend(&description, "Edge from ", node_name(src), " to ",
- node_name(dst), " would create a cycle.\n");
+ absl::StrAppend(&description, "Edge from ", node_name(src), " to ",
+ node_name(dst), " would create a cycle.\n");
path.resize(path_size);
for (int32 node_id : path) {
string ascii_art;
@@ -63,7 +65,7 @@ string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src,
} else {
ascii_art = "+-- ";
}
- strings::StrAppend(&description, ascii_art, node_name(node_id), "\n");
+ absl::StrAppend(&description, ascii_art, node_name(node_id), "\n");
}
return description;
}
@@ -185,14 +187,14 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) {
return Status::OK();
}
-gtl::optional<StringPiece> GetXlaClusterForNode(const Node& node) {
+absl::optional<absl::string_view> GetXlaClusterForNode(const Node& node) {
const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr);
if (attr_value == nullptr) {
- return gtl::nullopt;
+ return absl::nullopt;
}
Status s = AttrValueHasType(*attr_value, "string");
if (!s.ok()) {
- return gtl::nullopt;
+ return absl::nullopt;
}
return attr_value->s();
}
@@ -207,4 +209,29 @@ bool HasResourceInputOrOutput(const Node& node) {
void RemoveFromXlaCluster(NodeDef* node_def) {
node_def->mutable_attr()->erase(kXlaClusterAttr);
}
+
+void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); }
+
+Status AdjustCycleDetectionGraphForResourceOps(
+ const Graph* graph, const FunctionLibraryDefinition* flib_def,
+ const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
+ GraphCycles* cycles) {
+ std::vector<std::pair<int, int>> unsafe_deps;
+ TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs(
+ *graph, flib_def, resource_ops_to_ignore, &unsafe_deps));
+
+ // An edge {P,Q} in `unsafe_deps` denotes that P and Q, both of which are
+ // operations that interact with resource variables, must not be put in the
+ // same cluster. We enforce this constraint by creating a phantom node, X,
+ // and adding edges P->X and X->Q. MarkForCompilation then cannot cluster P
+ // and Q together since that would create a cycle with X.
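+  //
+  // For example, given an unsafe pair {3, 7} (node ids illustrative), we
+  // create a fresh node X and the edges 3->X and X->7. Contracting 3 and 7
+  // into a single cluster would then leave both an edge into X and an edge
+  // out of X from that cluster, i.e. the cycle cluster->X->cluster.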
+
+ for (std::pair<int, int> unsafe_dep : unsafe_deps) {
+ int phantom_node_id = cycles->NewNode();
+ CHECK(cycles->InsertEdge(unsafe_dep.first, phantom_node_id));
+ CHECK(cycles->InsertEdge(phantom_node_id, unsafe_dep.second));
+ }
+ return Status::OK();
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h
index bff76da6f9..ba218f3315 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.h
+++ b/tensorflow/compiler/jit/xla_cluster_util.h
@@ -18,9 +18,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
+#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/core/graph/algorithm.h"
-#include "tensorflow/core/lib/gtl/optional.h"
namespace tensorflow {
@@ -47,14 +47,24 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles);
// Returns the XLA cluster in which `node` is placed if it is in an XLA cluster,
// otherwise returns nullopt.
-gtl::optional<StringPiece> GetXlaClusterForNode(const Node& node);
+absl::optional<absl::string_view> GetXlaClusterForNode(const Node& node);
// Removes `node_def` from its XLA cluster (by clearing its _XlaCluster attribute).
void RemoveFromXlaCluster(NodeDef* node_def);
+// Removes `node` from its XLA cluster (by clearing its _XlaCluster attribute).
+void RemoveFromXlaCluster(Node* node);
+
// Returns true if `node` has a DT_RESOURCE typed input or output.
bool HasResourceInputOrOutput(const Node& node);
+// Adds edges to `cycles` to prevent clustering resource operations that cannot
+// be legally clustered.
+Status AdjustCycleDetectionGraphForResourceOps(
+ const Graph* graph, const FunctionLibraryDefinition* flib_def,
+ const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
+ GraphCycles* cycles);
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
diff --git a/tensorflow/compiler/jit/xla_cluster_util_test.cc b/tensorflow/compiler/jit/xla_cluster_util_test.cc
index 2cb351e1ec..65bbf3efe8 100644
--- a/tensorflow/compiler/jit/xla_cluster_util_test.cc
+++ b/tensorflow/compiler/jit/xla_cluster_util_test.cc
@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 7140d47a94..3aa9e9c7ed 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -67,12 +67,12 @@ string XlaCompilationCache::DebugString() {
string XlaCompilationCache::SignatureDebugString(const Signature& sig) {
string result = sig.name;
for (const auto& a : sig.arg_types) {
- strings::StrAppend(&result, ",", DataTypeString(a.first),
- a.second.DebugString());
+ absl::StrAppend(&result, ",", DataTypeString(a.first),
+ a.second.DebugString());
}
for (const auto& v : sig.arg_values) {
- strings::StrAppend(&result, "; ", v.DebugString());
+ absl::StrAppend(&result, "; ", v.DebugString());
}
return result;
}
@@ -230,7 +230,7 @@ Status XlaCompilationCache::Compile(
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions* compile_options) {
+ const XlaCompiler::CompileOptions& compile_options) {
return CompileImpl(options, function, constant_args, variable_args, ctx,
compilation_result, executable, compile_options, false);
}
@@ -241,7 +241,7 @@ Status XlaCompilationCache::CompileSingleOp(
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions* compile_options) {
+ const XlaCompiler::CompileOptions& compile_options) {
const NodeDef& def = ctx->op_kernel().def();
NameAttrList name;
name.set_name(def.op());
@@ -256,10 +256,10 @@ Status XlaCompilationCache::CompileImpl(
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions* compile_options,
+ const XlaCompiler::CompileOptions& compile_options,
bool compile_single_op) {
CHECK_NE(executable, nullptr);
- VLOG(1) << "XlaCompilationCache::Compile " << DebugString();
+ VLOG(2) << "XlaCompilationCache::Compile " << DebugString();
if (VLOG_IS_ON(2)) {
VLOG(2) << "num_inputs=" << ctx->num_inputs()
@@ -310,7 +310,7 @@ Status XlaCompilationCache::CompileImpl(
// cache eviction.
mutex_lock entry_lock(entry->mu);
if (!entry->compiled) {
- VLOG(1) << "Compilation cache miss for signature: "
+ VLOG(2) << "Compilation cache miss for signature: "
<< SignatureDebugString(signature);
tensorflow::Env* env = tensorflow::Env::Default();
const uint64 compile_start_us = env->NowMicros();
@@ -324,13 +324,12 @@ Status XlaCompilationCache::CompileImpl(
entry->compiled = true;
if (compile_single_op) {
- entry->compilation_status = compiler.CompileSingleOp(
- compile_options ? *compile_options : XlaCompiler::CompileOptions(),
- signature.name, ctx, args, &entry->compilation_result);
+ entry->compilation_status =
+ compiler.CompileSingleOp(compile_options, signature.name, ctx, args,
+ &entry->compilation_result);
} else {
entry->compilation_status = compiler.CompileFunction(
- compile_options ? *compile_options : XlaCompiler::CompileOptions(),
- function, args, &entry->compilation_result);
+ compile_options, function, args, &entry->compilation_result);
}
TF_RETURN_IF_ERROR(entry->compilation_status);
CHECK_EQ(entry->executable.get(), nullptr);
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h
index fc5f008f4f..10ad87e38c 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.h
+++ b/tensorflow/compiler/jit/xla_compilation_cache.h
@@ -70,7 +70,7 @@ class XlaCompilationCache : public ResourceBase {
OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions* compile_options);
+ const XlaCompiler::CompileOptions& compile_options);
// As above, but calls XlaCompiler::CompileSingleOp instead of
// XlaCompiler::CompileFunction.
@@ -80,7 +80,7 @@ class XlaCompilationCache : public ResourceBase {
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions* compile_options);
+ const XlaCompiler::CompileOptions& compile_options);
xla::LocalClient* client() const { return client_; }
const DeviceType& device_type() const { return device_type_; }
@@ -96,7 +96,7 @@ class XlaCompilationCache : public ResourceBase {
OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions* compile_options,
+ const XlaCompiler::CompileOptions& compile_options,
bool compile_single_op);
// Takes `result` which has been compiled from a Tensorflow subgraph to a
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index dd84fb34c1..3ba48e8c31 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -177,7 +177,7 @@ Status XlaCompileOnDemandOp::Compile(
std::map<int, OptionalTensor> variable_args = GetVariables(ctx);
return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx,
- result, executable, &compile_options);
+ result, executable, compile_options);
}
void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 2a2691a6a4..51797def04 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <stdlib.h>
#include <unordered_set>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
@@ -101,7 +102,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
}
std::unique_ptr<XlaDeviceAllocator> alloc =
- xla::MakeUnique<XlaDeviceAllocator>();
+ absl::make_unique<XlaDeviceAllocator>();
XlaDeviceAllocator* alloc_ptr = alloc.get();
state.allocators_[{backend, device_ordinal}] = std::move(alloc);
return alloc_ptr;
@@ -147,10 +148,9 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
}
const DeviceAttributes attrs = Device::BuildDeviceAttributes(
- strings::StrCat(name_prefix, "/device:", device_name, ":",
- device_ordinal),
+ absl::StrCat(name_prefix, "/device:", device_name, ":", device_ordinal),
DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
- strings::StrCat("device: ", device_name, " device"));
+ absl::StrCat("device: ", device_name, " device"));
device->reset(
new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name),
@@ -184,14 +184,13 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
return device_type_;
}
-/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx,
- const Metadata** metadata) {
+/*static*/ Status XlaDevice::GetMetadataFromDevice(
+ DeviceBase* device, const XlaDevice::Metadata** metadata) {
*metadata = nullptr;
- XlaDevice* xla_device =
- dynamic_cast<XlaDevice*>(ctx->device()->UnderlyingDevice());
+ XlaDevice* xla_device = dynamic_cast<XlaDevice*>(device->UnderlyingDevice());
if (xla_device == nullptr) {
return errors::Internal(
- "Cannot get XLA metadata from non-XLA device \"", ctx->device()->name(),
+ "Cannot get XLA metadata from non-XLA device \"", device->name(),
"\". GetMetadata must only be called on an XLA device. Either an "
"internal bug has been triggered, or an XLA-specific op has been "
"placed on the wrong device.");
@@ -200,6 +199,16 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
return Status::OK();
}
+/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx,
+ const Metadata** metadata) {
+ return GetMetadataFromDevice(ctx->device(), metadata);
+}
+
+/* static */ Status XlaDevice::GetMetadata(OpKernelConstruction* ctx,
+ const Metadata** metadata) {
+ return GetMetadataFromDevice(ctx->device(), metadata);
+}
+
XlaDevice::XlaDevice(
const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
@@ -327,7 +336,7 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
// to those methods; see the bug for details. Our only saving grace at the
// moment is that this race doesn't seem to occur in practice.
if (use_gpu_device_info_) {
- auto gpu_device_info = MakeUnique<GpuDeviceInfo>();
+ auto gpu_device_info = absl::make_unique<GpuDeviceInfo>();
gpu_device_info->stream = stream_.get();
gpu_device_info->default_context = device_context_;
set_tensorflow_gpu_device_info(gpu_device_info.get());
@@ -364,11 +373,7 @@ Status XlaDevice::FillContextMap(const Graph* graph,
void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
<< op_kernel->type_string();
- // When Xprof profiling is off (which is the default), constructing the
- // activity is simple enough that its overhead is negligible.
- tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(),
- op_kernel->IsExpensive());
- op_kernel->Compute(context);
+ TracingDevice::Compute(op_kernel, context);
}
void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index dbf35f349f..92891ffa8c 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -88,6 +88,10 @@ class XlaDevice : public LocalDevice {
// Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`.
static Status GetMetadata(OpKernelContext* ctx, const Metadata** metadata);
+ // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`.
+ static Status GetMetadata(OpKernelConstruction* ctx,
+ const Metadata** metadata);
+
// Factory function. 'platform_name' is the name of the XLA platform.
// 'device_name' is the name of the Tensorflow device to create.
// 'jit_device_name' is the name of the corresponding JIT device.
@@ -158,6 +162,9 @@ class XlaDevice : public LocalDevice {
xla::StatusOr<XlaDeviceContext*> GetDeviceContextLocked()
EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ static Status GetMetadataFromDevice(DeviceBase* device,
+ const XlaDevice::Metadata** metadata);
+
mutex mu_;
// The metadata of this XlaDevice.
const Metadata xla_metadata_;
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 0a0c089241..af83c792e5 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -91,7 +91,8 @@ Status XlaTransferManager::TransferLiteralToDevice(
const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
VLOG(1) << "Transfer to device as literal: " << literal->ToString() << " "
<< shaped_buffer.ToString();
- if (UseMultipleStreams()) {
+ if (UseMultipleStreams() && !transfer_manager_->CanShapedBufferBeAccessedNow(
+ stream_->parent(), shaped_buffer)) {
// Initially wait for the compute stream so that memory allocations are
// synchronized.
host_to_device_stream_->ThenWaitFor(stream_.get());
@@ -123,11 +124,11 @@ void XlaTransferManager::TransferLiteralFromDevice(
TensorReference ref(device_tensor);
transfer_manager_->TransferLiteralFromDevice(
device_to_host_stream_.get(), shaped_buffer, literal,
- [=, &shaped_buffer, &literal](xla::Status status) {
+ [=, &shaped_buffer](xla::Status status) {
ref.Unref();
done([&]() -> Status {
- VLOG(1) << "Transfer from device as literal: " << literal.ToString()
- << " " << shaped_buffer.ToString();
+ VLOG(1) << "Transfer from device as literal: "
+ << shaped_buffer.ToString();
return status;
}());
});
@@ -183,18 +184,6 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
return;
}
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
- if (status.ok()) {
- xla_tensor->set_host_tensor(*cpu_tensor);
- host_to_device_stream_->ThenDoHostCallback([this, done]() {
- // We must not call the done closure directly from DoHostCallback
- // to avoid a deadlock. If done() is the callback that ends an
- // Executor's run, the Executor may call XlaDevice::Sync() inside the
- // callback. This deadlocks, because XlaDevice::Sync() waits for all
- // stream activity to complete.
- thread_pool_->Schedule([done]() { done(Status::OK()); });
- });
- return;
- }
} else {
se::DeviceMemoryBase dev_dst_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
@@ -207,13 +196,14 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
host_to_device_stream_.get(), block_status.error_message().c_str());
}
}
- xla_tensor->set_host_tensor(*cpu_tensor);
-
+ if (status.ok()) {
+ xla_tensor->set_host_tensor(*cpu_tensor);
+ }
done(status);
}
void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
- StringPiece tensor_name,
+ absl::string_view tensor_name,
Device* device,
Tensor* cpu_tensor,
StatusCallback done) {
@@ -349,7 +339,7 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
}
void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
- StringPiece tensor_name,
+ absl::string_view tensor_name,
Device* device, Tensor* cpu_tensor,
StatusCallback done) {
manager_.CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor,
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index 2e7445340c..df82421294 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -57,7 +57,7 @@ class XlaTransferManager {
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor, StatusCallback done) const;
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
- StringPiece tensor_name, Device* device,
+ absl::string_view tensor_name, Device* device,
Tensor* cpu_tensor, StatusCallback done);
void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
@@ -111,7 +111,7 @@ class XlaDeviceContext : public DeviceContext {
Tensor* device_tensor,
StatusCallback done) const override;
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
- StringPiece tensor_name, Device* device,
+ absl::string_view tensor_name, Device* device,
Tensor* cpu_tensor, StatusCallback done) override;
void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
const StatusCallback& done);
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index da3e329247..49c8582682 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -198,31 +198,33 @@ class XlaAssignVariableOp : public AsyncOpKernel {
\
REGISTER_KERNEL_BUILDER( \
Name("GeneratorDataset").Device(DEVICE).HostMemory("handle"), \
- GeneratorDatasetOp); \
+ data::GeneratorDatasetOp); \
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset") \
.Device(DEVICE) \
.HostMemory("buffer_size") \
.HostMemory("input_dataset") \
.HostMemory("handle"), \
- PrefetchDatasetOp); \
+ data::PrefetchDatasetOp); \
\
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE), \
- IteratorHandleOp); \
+ data::IteratorHandleOp); \
REGISTER_KERNEL_BUILDER( \
Name("MakeIterator").Device(DEVICE).HostMemory("dataset"), \
- MakeIteratorOp); \
+ data::MakeIteratorOp); \
REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \
- AnonymousIteratorHandleOp); \
+ data::AnonymousIteratorHandleOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \
- IteratorGetNextOp); \
+ data::IteratorGetNextOp); \
+ REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE), \
+ data::IteratorGetNextSyncOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \
.Device(DEVICE) \
.HostMemory("string_handle"), \
- IteratorToStringHandleOp); \
+ data::IteratorToStringHandleOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2") \
.Device(DEVICE) \
.HostMemory("string_handle"), \
- IteratorFromStringHandleOp); \
+ data::IteratorFromStringHandleOp); \
REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kArgOp) \
.Device(DEVICE) \
.HostMemory("output") \
diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc
index 4b499b1613..bc0db558d8 100644
--- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc
+++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
@@ -41,8 +42,8 @@ static bool IsShapeConsumerOp(const Node& node) {
}
// Returns true if the op can be decomposed into XLA ops for which
-// there are fusable elemental implementations.
-bool IsXlaFusable(const NodeDef& node) {
+// there are fusible elemental implementations.
+static bool IsXlaFusible(const NodeDef& node) {
static const std::unordered_set<std::string>* elementwise_ops =
new std::unordered_set<std::string>(
{// tf2xla/kernels/aggregate_ops.cc
@@ -176,9 +177,9 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
TF_RETURN_IF_ERROR(DeviceToDeviceType(node->def().device(), &device_type));
if (device_type.type_string().find("XLA") != string::npos) continue;
- // Assume all fusable ops are registered.
+ // Assume all fusible ops are registered.
// TODO(hpucha): Check for registration if possible.
- if (!IsXlaFusable(node->def())) {
+ if (!IsXlaFusible(node->def())) {
continue;
}
@@ -208,6 +209,8 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
GraphCycles cycles;
TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles));
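+  // Add the phantom edges for unsafe resource-op pairs so that clustering can
+  // never place such a pair in one cluster (see
+  // AdjustCycleDetectionGraphForResourceOps in xla_cluster_util.h).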
+ TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
+ &graph, &graph.flib_def(), /*resource_ops_to_ignore=*/{}, &cycles));
// TODO(hpucha): Make clustering more robust. There are two known issues that
// we need to mitigate: (a) Non-resource variables can cause deadlocks
@@ -324,7 +327,7 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
string& name = cluster_names[cluster];
if (name.empty()) {
- name = strings::StrCat("cluster_", cluster_sequence_num++);
+ name = absl::StrCat("cluster_", cluster_sequence_num++);
}
n->AddAttr(kXlaClusterAttr, name);
VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc
index 5736760a87..68e19c8a13 100644
--- a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc
+++ b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_fusion_optimizer.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/graph/graph_def_builder.h"
@@ -71,7 +73,7 @@ TEST_F(XlaFusionOptimizerTest, Chains) {
EXPECT_TRUE(clusters.find("D") == clusters.cend());
}
-TEST_F(XlaFusionOptimizerTest, FusableOps) {
+TEST_F(XlaFusionOptimizerTest, FusibleOps) {
GraphDef graph;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
@@ -179,5 +181,28 @@ TEST_F(XlaFusionOptimizerTest, CompilableCycles) {
EXPECT_EQ(clusters["A"], clusters["C"]);
}
+TEST_F(XlaFusionOptimizerTest, ResourcesClusteringDisallowed) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output var_handle =
+ ops::VarHandleOp(root.WithOpName("Var"), DT_FLOAT, TensorShape({}));
+ Output to_assign = ops::Const(root.WithOpName("Const"), 10.0f);
+ Output begin = ops::Const(root.WithOpName("begin"), 0);
+ Output end = ops::Const(root.WithOpName("end"), 1);
+ Output strides = ops::Const(root.WithOpName("strides"), 1);
+ ops::ResourceStridedSliceAssign assign_1(
+ root.WithOpName("assign_1"), var_handle, begin, end, strides, to_assign);
+ ops::ResourceStridedSliceAssign assign_2(
+ root.WithOpName("assign_2"), var_handle, begin, end, strides, to_assign);
+ root.graph()->AddControlEdge(assign_1.operation.node(),
+ assign_2.operation.node());
+ grappler::GrapplerItem item;
+ root.graph()->ToGraphDef(&item.graph);
+
+ XlaFusionOptimizer optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ auto clusters = GetClusters(output);
+ EXPECT_NE(clusters["assign_1"], clusters["assign_2"]);
+}
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 4efbb2d5d7..affeab4a8c 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -175,7 +176,7 @@ void XlaComputationLaunchContext::PopulateInputs(
<< " not the same as on-host shape "
<< xla::ShapeUtil::HumanStringWithLayout(shape);
se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
- arg_buffers_[i] = xla::MakeUnique<ShapedBuffer>(
+ arg_buffers_[i] = absl::make_unique<ShapedBuffer>(
/*on_host_shape=*/shape, /*on_device_shape=*/shape,
client_->platform(), client_->default_device_ordinal());
arg_buffers_[i]->set_buffer(dmem, /*index=*/{});
@@ -270,31 +271,36 @@ Status XlaComputationLaunchContext::PopulateOutputs(
}
} else {
const TensorShape& shape = kernel->outputs[i].shape;
- VLOG(2) << "Retval " << i << " shape " << shape.DebugString();
-
- se::DeviceMemoryBase buffer = output.buffer({output_num});
- if (allocate_xla_tensors_) {
- Tensor* output_tensor;
- TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
- XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
- if (xla_tensor) {
- xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
- ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
- if (use_multiple_streams_) {
- xla_tensor->SetDefinedOn(stream, definition_event);
+ const DataType& type = kernel->outputs[i].type;
+ VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
+ << DataTypeString(type);
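+      // A DT_RESOURCE retval is not materialized by the XLA computation; it
+      // aliases one of the kernel's resource inputs, so forward that input
+      // directly and do not consume an XLA output buffer for it.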
+ if (type == DT_RESOURCE) {
+ ctx->set_output(i, ctx->input(kernel->outputs[i].input_index));
+ } else {
+ se::DeviceMemoryBase buffer = output.buffer({output_num});
+ if (allocate_xla_tensors_) {
+ Tensor* output_tensor;
+ TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
+ XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
+ if (xla_tensor) {
+ xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
+ ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
+ if (use_multiple_streams_) {
+ xla_tensor->SetDefinedOn(stream, definition_event);
+ }
+ } else {
+ // xla_tensor wasn't valid, which must mean this is a zero-element
+ // tensor.
+ CHECK_EQ(output_tensor->TotalBytes(), 0);
}
} else {
- // xla_tensor wasn't valid, which must mean this is a zero-element
- // tensor.
- CHECK_EQ(output_tensor->TotalBytes(), 0);
+ Tensor output_tensor = XlaTensorBuffer::MakeTensor(
+ ctx->expected_output_dtype(i), shape, buffer, allocator);
+ output.set_buffer(xla::OwningDeviceMemory(), {output_num});
+ ctx->set_output(i, output_tensor);
}
- } else {
- Tensor output_tensor = XlaTensorBuffer::MakeTensor(
- ctx->expected_output_dtype(i), shape, buffer, allocator);
- output.set_buffer(xla::OwningDeviceMemory(), {output_num});
- ctx->set_output(i, output_tensor);
+ ++output_num;
}
- ++output_num;
}
if (VLOG_IS_ON(3)) {
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 4232f514b3..7ac275fab8 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -167,4 +167,4 @@ xla::ScopedShapedBuffer ExtractSubShapedBuffer(
} // namespace tensorflow
-#endif
+#endif // TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_
diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h
index 8d36d0fa0a..d95da63405 100644
--- a/tensorflow/compiler/jit/xla_tensor.h
+++ b/tensorflow/compiler/jit/xla_tensor.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/core/framework/allocator.h"
@@ -70,7 +71,7 @@ class XlaTensor {
// Mutates the XlaTensor to set the ShapedBuffer.
void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) {
shaped_buffer_ =
- xla::MakeUnique<xla::ScopedShapedBuffer>(std::move(shaped_buffer));
+ absl::make_unique<xla::ScopedShapedBuffer>(std::move(shaped_buffer));
}
// Some tensors on the device may have known values on the host. We use these
@@ -121,10 +122,10 @@ class XlaTensor {
std::shared_ptr<se::Event> definition_event_;
// A list of all streams for which the tensor's content is defined for any
// newly enqueued command.
- gtl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_);
+ absl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_);
mutex mu_;
};
} // namespace tensorflow
-#endif
+#endif // TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index ae98b3f0f9..050d827a09 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -72,7 +72,7 @@ py_test(
tf_xla_py_test(
name = "adadelta_test",
- size = "medium",
+ size = "large",
srcs = ["adadelta_test.py"],
deps = [
":xla_test",
@@ -251,6 +251,7 @@ tf_xla_py_test(
tf_xla_py_test(
name = "matrix_triangular_solve_op_test",
size = "small",
+ timeout = "moderate",
srcs = ["matrix_triangular_solve_op_test.py"],
tags = ["optonly"],
deps = [
@@ -388,6 +389,19 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "reshape_op_test",
+ size = "small",
+ srcs = ["reshape_op_test.py"],
+ deps = [
+ "//tensorflow/compiler/tests:xla_test",
+ "//tensorflow/compiler/tf2xla/python:xla",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:dtypes",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+tf_xla_py_test(
name = "dynamic_stitch_test",
size = "small",
srcs = ["dynamic_stitch_test.py"],
@@ -559,6 +573,7 @@ tf_xla_py_test(
tf_xla_py_test(
name = "matrix_band_part_test",
size = "medium",
+ timeout = "long",
srcs = ["matrix_band_part_test.py"],
tags = ["optonly"],
deps = [
@@ -715,6 +730,7 @@ tf_xla_py_test(
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
+ "@absl_py//absl/testing:parameterized",
],
)
@@ -1087,6 +1103,7 @@ cc_library(
"//tensorflow/core:test",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:ops_util",
+ "@com_google_absl//absl/strings",
],
)
@@ -1177,3 +1194,19 @@ tf_xla_py_test(
"//tensorflow/python:platform_test",
],
)
+
+tf_xla_py_test(
+ name = "xla_ops_test",
+ size = "small",
+ srcs = ["xla_ops_test.py"],
+ disabled_backends = ["cpu_ondemand"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/compiler/tf2xla/python:xla",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:platform_test",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
diff --git a/tensorflow/compiler/tests/adadelta_test.py b/tensorflow/compiler/tests/adadelta_test.py
index 3e3c09c66e..b7b7fda293 100644
--- a/tensorflow/compiler/tests/adadelta_test.py
+++ b/tensorflow/compiler/tests/adadelta_test.py
@@ -33,7 +33,7 @@ class AdadeltaOptimizerTest(xla_test.XLATestCase):
def testBasic(self):
num_updates = 4 # number of ADADELTA steps to perform
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
for grad in [0.2, 0.1, 0.01]:
for lr in [1.0, 0.5, 0.1]:
var0_init = [1.0, 2.0]
diff --git a/tensorflow/compiler/tests/adagrad_da_test.py b/tensorflow/compiler/tests/adagrad_da_test.py
index dc1625793a..69fb3ec296 100644
--- a/tensorflow/compiler/tests/adagrad_da_test.py
+++ b/tensorflow/compiler/tests/adagrad_da_test.py
@@ -33,7 +33,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
def testAdagradDAWithoutRegularizationBasic1(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
global_step = resource_variable_ops.ResourceVariable(
0, dtype=dtypes.int64)
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
@@ -69,7 +69,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
def testAdagradDAwithoutRegularizationBasic2(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
global_step = resource_variable_ops.ResourceVariable(
0, dtype=dtypes.int64)
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
@@ -100,7 +100,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
def testAdagradDAWithL1(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
global_step = resource_variable_ops.ResourceVariable(
0, dtype=dtypes.int64)
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
@@ -131,7 +131,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
def testAdagradDAWithL1_L2(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
global_step = resource_variable_ops.ResourceVariable(
0, dtype=dtypes.int64)
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py
index d775850a80..ab69319c59 100644
--- a/tensorflow/compiler/tests/adagrad_test.py
+++ b/tensorflow/compiler/tests/adagrad_test.py
@@ -32,7 +32,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase):
def testBasic(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -57,7 +57,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase):
def testTensorLearningRate(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -83,7 +83,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase):
def testSharing(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py
index 0d2e4d0296..058576b3d4 100644
--- a/tensorflow/compiler/tests/adam_test.py
+++ b/tensorflow/compiler/tests/adam_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
@@ -53,9 +54,9 @@ class AdamOptimizerTest(xla_test.XLATestCase):
def testBasic(self):
for dtype in self.float_types:
      # TODO: test fails for float16/bfloat16 due to excessive precision requirements.
- if dtype == np.float16:
+ if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
continue
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True)
# Initialize variables for numpy implementation.
@@ -95,9 +96,9 @@ class AdamOptimizerTest(xla_test.XLATestCase):
def testTensorLearningRate(self):
for dtype in self.float_types:
      # TODO: test fails for float16/bfloat16 due to excessive precision requirements.
- if dtype == np.float16:
+ if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
continue
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True)
# Initialize variables for numpy implementation.
@@ -137,9 +138,9 @@ class AdamOptimizerTest(xla_test.XLATestCase):
def testSharing(self):
for dtype in self.float_types:
      # TODO: test fails for float16/bfloat16 due to excessive precision requirements.
- if dtype == np.float16:
+ if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
continue
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True)
# Initialize variables for numpy implementation.
diff --git a/tensorflow/compiler/tests/adamax_test.py b/tensorflow/compiler/tests/adamax_test.py
index c4fdbc5974..3ed1d41b71 100644
--- a/tensorflow/compiler/tests/adamax_test.py
+++ b/tensorflow/compiler/tests/adamax_test.py
@@ -49,7 +49,7 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase):
def testBasic(self):
for i, dtype in enumerate(self.float_types):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True)
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
@@ -100,7 +100,7 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase):
def testTensorLearningRate(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True)
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
diff --git a/tensorflow/compiler/tests/addsign_test.py b/tensorflow/compiler/tests/addsign_test.py
index 9ec5a964cb..1bc07ace23 100644
--- a/tensorflow/compiler/tests/addsign_test.py
+++ b/tensorflow/compiler/tests/addsign_test.py
@@ -63,7 +63,7 @@ class AddSignTest(xla_test.XLATestCase):
alpha=1.0,
beta=0.9):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
# Initialize variables for numpy implementation.
m0, m1 = 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype)
diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py
index 9d3a889b1f..4155342787 100644
--- a/tensorflow/compiler/tests/argminmax_test.py
+++ b/tensorflow/compiler/tests/argminmax_test.py
@@ -40,7 +40,7 @@ class ArgMinMaxTest(xla_test.XLATestCase):
op_input: numpy input array to use as input to 'op'.
expected: numpy array representing the expected output of 'op'.
"""
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
pinp = array_ops.placeholder(
dtypes.as_dtype(op_input.dtype), op_input.shape, name="a")
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 069e83d083..c478ff4eea 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -36,7 +36,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
"""Test cases for binary operators."""
def _testBinary(self, op, a, b, expected, equality_test=None):
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b")
@@ -1018,7 +1018,38 @@ class BinaryOpsTest(xla_test.XLATestCase):
[7, 7, 7, 7, 7, 7]],
dtype=dtype))
- def testMirrorPad(self):
+ def testSymmetricMirrorPad(self):
+ mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "SYMMETRIC")
+ for dtype in self.numeric_types:
+ self._testBinary(
+ mirror_pad,
+ np.array(
+ [
+ [1, 2, 3], #
+ [4, 5, 6], #
+ ],
+ dtype=dtype),
+          np.array([[2, 2], [3, 3]], dtype=np.int32),
+ expected=np.array(
+ [
+ [6, 5, 4, 4, 5, 6, 6, 5, 4], #
+ [3, 2, 1, 1, 2, 3, 3, 2, 1], #
+ [3, 2, 1, 1, 2, 3, 3, 2, 1], #
+ [6, 5, 4, 4, 5, 6, 6, 5, 4], #
+ [6, 5, 4, 4, 5, 6, 6, 5, 4], #
+ [3, 2, 1, 1, 2, 3, 3, 2, 1], #
+ ],
+ dtype=dtype))
+ self._testBinary(
+ mirror_pad,
+ np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype),
+ np.array([[0, 0], [0, 0]], dtype=np.int32),
+ expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype))
+
+ def testReflectMirrorPad(self):
mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT")
for dtype in self.numeric_types:
self._testBinary(
@@ -1175,6 +1206,16 @@ class BinaryOpsTest(xla_test.XLATestCase):
for dtype in self.numeric_types:
self._testBinary(
array_ops.tile,
+ np.array([[6], [3], [4]], dtype=dtype),
+ np.array([2, 0], dtype=np.int32),
+ expected=np.empty([6, 0], dtype=dtype))
+ self._testBinary(
+ array_ops.tile,
+ np.array([[6, 3, 4]], dtype=dtype),
+ np.array([2, 0], dtype=np.int32),
+ expected=np.empty([2, 0], dtype=dtype))
+ self._testBinary(
+ array_ops.tile,
np.array([[6]], dtype=dtype),
np.array([1, 2], dtype=np.int32),
expected=np.array([[6, 6]], dtype=dtype))
@@ -1370,5 +1411,40 @@ class BinaryOpsTest(xla_test.XLATestCase):
[[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]],
dtype=dtype))
+ def testBroadcastTo(self):
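+    # The expectations below assume XLA's broadcast_to can repeat a dimension
+    # to any integer multiple of its size (np.tile semantics), e.g. [2, 3] to
+    # [7, 4, 3], rather than only broadcasting size-1 dimensions.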
+ for dtype in self.all_types:
+      x = np.random.randint(0, high=100, size=[2, 3]).astype(dtype)
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array([2, 3], dtype=np.int32),
+ expected=x)
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array([6, 6], dtype=np.int32),
+ expected=np.tile(x, [3, 2]))
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array([7, 4, 3], dtype=np.int32),
+ expected=np.tile(x, [7, 2, 1]))
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array([7, 0, 3], dtype=np.int32),
+ expected=np.zeros([7, 0, 3], dtype=dtype))
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array([7, 1, 2, 9], dtype=np.int32),
+ expected=np.tile(x, [7, 1, 1, 3]))
+ self._testBinary(
+ array_ops.broadcast_to,
+ np.zeros([2, 0], dtype=dtype),
+ np.array([4, 0], dtype=np.int32),
+ expected=np.zeros([4, 0], dtype=dtype))
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/compiler/tests/bucketize_op_test.py b/tensorflow/compiler/tests/bucketize_op_test.py
index ef4d5f6322..5c24db539b 100644
--- a/tensorflow/compiler/tests/bucketize_op_test.py
+++ b/tensorflow/compiler/tests/bucketize_op_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
class BucketizationOpTest(xla_test.XLATestCase):
def testInt(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p = array_ops.placeholder(dtypes.int32)
with self.test_scope():
op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11])
@@ -38,7 +38,7 @@ class BucketizationOpTest(xla_test.XLATestCase):
sess.run(op, {p: [-5, 0, 2, 3, 5, 8, 10, 11, 12]}))
def testFloat(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p = array_ops.placeholder(dtypes.float32)
with self.test_scope():
op = math_ops._bucketize(p, boundaries=[0., 3., 8., 11.])
@@ -48,7 +48,7 @@ class BucketizationOpTest(xla_test.XLATestCase):
sess.run(op, {p: [-5., 0., 2., 3., 5., 8., 10., 11., 12.]}))
def test2DInput(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p = array_ops.placeholder(dtypes.float32)
with self.test_scope():
op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11])
@@ -58,7 +58,7 @@ class BucketizationOpTest(xla_test.XLATestCase):
{p: [[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]}))
def testInvalidBoundariesOrder(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p = array_ops.placeholder(dtypes.int32)
with self.test_scope():
op = math_ops._bucketize(p, boundaries=[0, 8, 3, 11])
@@ -67,7 +67,7 @@ class BucketizationOpTest(xla_test.XLATestCase):
sess.run(op, {p: [-5, 0]})
def testBoundariesNotList(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, "Expected list.*"):
p = array_ops.placeholder(dtypes.int32)
with self.test_scope():
diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py
index a4e7f75081..a57d1dc81e 100644
--- a/tensorflow/compiler/tests/categorical_op_test.py
+++ b/tensorflow/compiler/tests/categorical_op_test.py
@@ -56,7 +56,7 @@ class CategoricalTest(xla_test.XLATestCase):
Returns:
Frequencies from sampled classes; shape [batch_size, num_classes].
"""
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
random_seed.set_random_seed(1618)
op = random_ops.multinomial(logits, num_samples,
output_dtype=dtypes.int32)
@@ -79,7 +79,7 @@ class CategoricalTest(xla_test.XLATestCase):
def _testRngIsNotConstant(self, rng, dtype, output_dtype):
# Tests that 'rng' does not always return the same value.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
x = rng(dtype, output_dtype)
@@ -107,7 +107,7 @@ class CategoricalTest(xla_test.XLATestCase):
def testCategoricalIsInRange(self):
for dtype in self.float_types:
for output_dtype in self.output_dtypes():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
x = random_ops.multinomial(
array_ops.ones(shape=[1, 20], dtype=dtype), 1000,
diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py
index ed532db0ee..d1896a50f7 100644
--- a/tensorflow/compiler/tests/cholesky_op_test.py
+++ b/tensorflow/compiler/tests/cholesky_op_test.py
@@ -54,7 +54,7 @@ class CholeskyOpTest(xla_test.XLATestCase):
def _verifyCholesky(self, x, atol=1e-6):
# Verify that LL^T == x.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(
dtypes.as_dtype(x.dtype), shape=x.shape)
with self.test_scope():
diff --git a/tensorflow/compiler/tests/clustering_test.py b/tensorflow/compiler/tests/clustering_test.py
index e42ebf8f9e..88bd58b2da 100644
--- a/tensorflow/compiler/tests/clustering_test.py
+++ b/tensorflow/compiler/tests/clustering_test.py
@@ -38,7 +38,7 @@ class ClusteringTest(xla_test.XLATestCase):
val1 = np.array([4, 3, 2, 1], dtype=np.float32)
val2 = np.array([5, 6, 7, 8], dtype=np.float32)
expected = val1 + val2
- with self.test_session():
+ with self.cached_session():
with self.test_scope():
input1 = constant_op.constant(val1, name="const1")
input2 = constant_op.constant(val2, name="const2")
@@ -50,7 +50,7 @@ class ClusteringTest(xla_test.XLATestCase):
val1 = np.array([4, 3, 2, 1]).astype(np.float32)
val2 = np.array([5, 6, 7, 8]).astype(np.float32)
expected = val1 + val2
- with self.test_session():
+ with self.cached_session():
with ops.device(CPU_DEVICE):
input1 = constant_op.constant(val1, name="const1")
input2 = constant_op.constant(val2, name="const2")
@@ -68,7 +68,7 @@ class ClusteringTest(xla_test.XLATestCase):
# where x and z are placed on the CPU and y and w are placed on the XLA
# device. If y and w are clustered for compilation, then the graph will
# deadlock since the clustered graph will contain a self-loop.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with ops.device(CPU_DEVICE):
x = array_ops.placeholder(dtypes.float32, [2])
with self.test_scope():
@@ -81,7 +81,7 @@ class ClusteringTest(xla_test.XLATestCase):
self.assertAllClose(result, [12., 2.], rtol=1e-3)
def testHostMemory(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.int32)
with self.test_scope():
y = x + 1
diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py
index d9ad428147..37e5318bb5 100644
--- a/tensorflow/compiler/tests/concat_ops_test.py
+++ b/tensorflow/compiler/tests/concat_ops_test.py
@@ -33,7 +33,7 @@ from tensorflow.python.platform import googletest
class ConcatTest(xla_test.XLATestCase):
def testHStack(self):
- with self.test_session():
+ with self.cached_session():
p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
with self.test_scope():
@@ -49,7 +49,7 @@ class ConcatTest(xla_test.XLATestCase):
self.assertAllEqual(result[4:, :], params[p2])
def testVStack(self):
- with self.test_session():
+ with self.cached_session():
p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
with self.test_scope():
@@ -65,7 +65,7 @@ class ConcatTest(xla_test.XLATestCase):
self.assertAllEqual(result[:, 4:], params[p2])
def testInt32(self):
- with self.test_session():
+ with self.cached_session():
p1 = np.random.rand(2, 3).astype("i")
p2 = np.random.rand(2, 3).astype("i")
x1 = constant_op.constant(p1)
@@ -88,7 +88,7 @@ class ConcatTest(xla_test.XLATestCase):
dtype_feed = dtypes.float32
else:
dtype_feed = dtype
- with self.test_session():
+ with self.cached_session():
p = []
for i in np.arange(num_tensors):
input_shape = shape
@@ -130,7 +130,7 @@ class ConcatTest(xla_test.XLATestCase):
self._testRandom(dtypes.int32)
def _testGradientsSimple(self):
- with self.test_session():
+ with self.cached_session():
inp = []
inp_tensors = []
with self.test_scope():
@@ -157,7 +157,7 @@ class ConcatTest(xla_test.XLATestCase):
self._testGradientsSimple()
def _testGradientsFirstDim(self):
- with self.test_session():
+ with self.cached_session():
inp = []
inp_tensors = []
with self.test_scope():
@@ -185,7 +185,7 @@ class ConcatTest(xla_test.XLATestCase):
self._testGradientsFirstDim()
def _testGradientsLastDim(self):
- with self.test_session():
+ with self.cached_session():
inp = []
inp_tensors = []
with self.test_scope():
@@ -220,7 +220,7 @@ class ConcatTest(xla_test.XLATestCase):
# Random dim to concat on
concat_dim = np.random.randint(5)
concat_dim_sizes = np.random.randint(1, 5, size=num_tensors)
- with self.test_session():
+ with self.cached_session():
inp = []
inp_tensors = []
with self.test_scope():
@@ -254,7 +254,7 @@ class ConcatTest(xla_test.XLATestCase):
def DISABLED_testZeroSize(self):
# Verify that concat doesn't crash and burn for zero size inputs
np.random.seed(7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
for shape0 in (), (2,):
axis = len(shape0)
@@ -276,14 +276,14 @@ class ConcatTest(xla_test.XLATestCase):
def testConcatTuple(self):
c1 = np.random.rand(4, 4).astype(np.float32)
c2 = np.random.rand(4, 4).astype(np.float32)
- with self.test_session():
+ with self.cached_session():
with self.test_scope():
concat_list_t = array_ops.concat([c1, c2], 0)
concat_tuple_t = array_ops.concat((c1, c2), 0)
self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval())
def testConcatNoScalars(self):
- with self.test_session():
+ with self.cached_session():
with self.test_scope():
scalar = constant_op.constant(7)
dim = array_ops.placeholder(dtypes.int32)
@@ -295,7 +295,7 @@ class ConcatTest(xla_test.XLATestCase):
class ConcatOffsetTest(xla_test.XLATestCase):
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
cdim = constant_op.constant(1, dtypes.int32)
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
@@ -309,7 +309,7 @@ class ConcatOffsetTest(xla_test.XLATestCase):
class PackTest(xla_test.XLATestCase):
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
@@ -319,7 +319,7 @@ class PackTest(xla_test.XLATestCase):
self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]])
def testScalars(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
s0 = constant_op.constant(2, dtypes.int32)
s1 = constant_op.constant(3, dtypes.int32)
@@ -329,7 +329,7 @@ class PackTest(xla_test.XLATestCase):
self.assertAllEqual(ans, [2, 3, 5])
def testEmpty(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
s0 = constant_op.constant([[]], dtypes.int32)
s1 = constant_op.constant([[]], dtypes.int32)
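
Beyond the session swap, the ConcatOffsetTest hunk is worth a gloss: for inputs of shapes [2, 3, 5], [2, 7, 5] and [2, 20, 5] concatenated along dimension 1, ConcatOffset reports where each input starts in the result, namely an all-zero vector except for a running sum along the concat dimension. A pure-Python sketch of those semantics (concat_offset here is an illustrative helper, not the TF op itself):

    def concat_offset(concat_dim, shapes):
      # Offset of each input inside the concatenated output: zeros
      # everywhere except a running total of the preceding extents
      # along concat_dim.
      offsets, running = [], 0
      for shape in shapes:
        off = [0] * len(shape)
        off[concat_dim] = running
        offsets.append(off)
        running += shape[concat_dim]
      return offsets

    print(concat_offset(1, [[2, 3, 5], [2, 7, 5], [2, 20, 5]]))
    # -> [[0, 0, 0], [0, 3, 0], [0, 10, 0]]
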
diff --git a/tensorflow/compiler/tests/conv2d_test.py b/tensorflow/compiler/tests/conv2d_test.py
index f9db103f6d..af00ff287d 100644
--- a/tensorflow/compiler/tests/conv2d_test.py
+++ b/tensorflow/compiler/tests/conv2d_test.py
@@ -87,7 +87,7 @@ class Conv2DTest(xla_test.XLATestCase, parameterized.TestCase):
dilations = test_utils.PermuteDimsBetweenDataFormats(
dilations, data_format_src, data_format_dst)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes)
t2 = array_ops.placeholder(dtypes.float32, shape=filter_sizes)
with self.test_scope():
@@ -288,7 +288,7 @@ class Conv2DBackpropInputTest(xla_test.XLATestCase, parameterized.TestCase):
dilations = test_utils.PermuteDimsBetweenDataFormats(
dilations, data_format_src, data_format_dst)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t1 = array_ops.placeholder(dtypes.float32, shape=filter_sizes)
t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes)
with self.test_scope():
@@ -586,7 +586,7 @@ class Conv2DBackpropFilterTest(xla_test.XLATestCase, parameterized.TestCase):
dilations = test_utils.PermuteDimsBetweenDataFormats(
dilations, data_format_src, data_format_dst)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes)
t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes)
with self.test_scope():
diff --git a/tensorflow/compiler/tests/conv3d_test.py b/tensorflow/compiler/tests/conv3d_test.py
index 31ee41f04f..33fd983b54 100644
--- a/tensorflow/compiler/tests/conv3d_test.py
+++ b/tensorflow/compiler/tests/conv3d_test.py
@@ -36,7 +36,7 @@ from tensorflow.python.platform import googletest
class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase):
def testGradient(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
for padding in ["SAME", "VALID"]:
for stride in [1, 2]:
np.random.seed(1)
@@ -69,7 +69,7 @@ class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase):
class Conv3DTransposeTest(xla_test.XLATestCase):
def testConv3DTransposeSingleStride(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
strides = [1, 1, 1, 1, 1]
# Input, output: [batch, depth, height, width, channel]
@@ -119,7 +119,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase):
self.assertAllClose(target, value[n, d, h, w, k])
def testConv3DTransposeSame(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
strides = [1, 2, 2, 2, 1]
# Input, output: [batch, depth, height, width, depth]
@@ -157,7 +157,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase):
self.assertAllClose(target, value[n, d, h, w, k])
def testConv3DTransposeValid(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
strides = [1, 2, 2, 2, 1]
# Input, output: [batch, depth, height, width, depth]
@@ -217,7 +217,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase):
np.random.seed(1) # Make it reproducible.
x_val = np.random.random_sample(x_shape).astype(np.float64)
f_val = np.random.random_sample(f_shape).astype(np.float64)
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
x = constant_op.constant(x_val, name="x", dtype=dtypes.float32)
f = constant_op.constant(f_val, name="f", dtype=dtypes.float32)
output = nn_ops.conv3d_transpose(
diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py
index 865f60ccab..0af74c2d8f 100644
--- a/tensorflow/compiler/tests/dense_layer_test.py
+++ b/tensorflow/compiler/tests/dense_layer_test.py
@@ -58,7 +58,8 @@ class DenseLayerTest(test.TestCase):
Dense layer should be compiled into a single XlaLaunch op in auto-jit mode.
"""
- os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit")
+ os.environ["TF_XLA_FLAGS"] = (
+ "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", ""))
config = config_pb2.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = (
config_pb2.OptimizerOptions.ON_1)
@@ -77,7 +78,7 @@ class DenseLayerTest(test.TestCase):
labels = GetRunMetadataLabels(run_metadata)
self.assertEqual(1, XlaLaunchOpCount(labels))
- self.assertFalse(InLabels(labels, "ListDiff"))
+ self.assertFalse(InLabels(labels, "MatMul"))
def testDenseLayerJitScopeDefinedShape(self):
"""Tests that the dense layer node is properly compiled in jit scope.
@@ -86,7 +87,7 @@ class DenseLayerTest(test.TestCase):
XlaLaunch op by XLA.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(shape=[2, 2, 3], dtype=np.float32)
with jit_scope():
y = layers.dense(x, 3)
@@ -113,7 +114,7 @@ class DenseLayerTest(test.TestCase):
cluster, causing dense layer to be split into TWO XlaLaunch ops.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32)
with jit_scope():
y = layers.dense(x, 3)
@@ -128,7 +129,7 @@ class DenseLayerTest(test.TestCase):
labels = GetRunMetadataLabels(run_metadata)
self.assertEqual(2, XlaLaunchOpCount(labels))
- self.assertFalse(InLabels(labels, "ListDiff"))
+ self.assertFalse(InLabels(labels, "MatMul"))
if __name__ == "__main__":
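
Two substantive fixes ride along with the session migration in dense_layer_test.py: the TF_XLA_FLAGS assignment now appends whatever the harness already set instead of clobbering it, and the post-run label check looks for the MatMul kernel the XLA cluster should have absorbed rather than the unrelated ListDiff op. The environment-variable pattern generalizes; a small sketch (add_xla_flags is a hypothetical helper, the tests inline the same logic):

    import os

    def add_xla_flags(*flags):
      # Hypothetical helper: prepend new XLA flags while preserving
      # anything already present in TF_XLA_FLAGS rather than overwriting.
      existing = os.environ.get("TF_XLA_FLAGS", "")
      os.environ["TF_XLA_FLAGS"] = " ".join(flags) + " " + existing

    add_xla_flags("--tf_xla_cpu_global_jit")
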
diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py
index 98dc73e189..6ef8a68ca5 100644
--- a/tensorflow/compiler/tests/depthwise_conv_op_test.py
+++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py
@@ -151,7 +151,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
dtype=data_type).reshape(tensor_in_sizes)
x2 = np.array([f * 1.0 for f in range(1, total_size_2 + 1)],
dtype=data_type).reshape(filter_in_sizes)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if data_type == np.float32:
tolerance = 1e-4
else:
@@ -247,7 +247,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
dtype=np.float32).reshape(tensor_in_sizes)
x2 = np.array([f * 1.0 for f in range(1, total_size_2 + 1)],
dtype=np.float32).reshape(filter_in_sizes)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t1 = array_ops.placeholder(shape=tensor_in_sizes, dtype=np.float32)
t2 = array_ops.placeholder(shape=filter_in_sizes, dtype=np.float32)
with self.test_scope():
@@ -321,7 +321,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
x2 = np.random.rand(*output_sizes).astype(np.float32)
def _GetVal(use_xla):
- with self.test_session():
+ with self.cached_session():
t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)])
t1 = array_ops.placeholder(np.float32, shape=filter_sizes)
t2 = array_ops.placeholder(np.float32, shape=output_sizes)
@@ -356,7 +356,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
x2 = np.random.rand(*output_sizes).astype(np.float32)
def _GetVal(use_xla):
- with self.test_session():
+ with self.cached_session():
t0 = array_ops.placeholder(np.float32, shape=input_sizes)
t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
t2 = array_ops.placeholder(np.float32, shape=output_sizes)
diff --git a/tensorflow/compiler/tests/dynamic_slice_ops_test.py b/tensorflow/compiler/tests/dynamic_slice_ops_test.py
index 154e36b10e..5f01e128f0 100644
--- a/tensorflow/compiler/tests/dynamic_slice_ops_test.py
+++ b/tensorflow/compiler/tests/dynamic_slice_ops_test.py
@@ -30,7 +30,7 @@ from tensorflow.python.platform import test
class DynamicUpdateSliceOpsTest(xla_test.XLATestCase):
def _assertOpOutputMatchesExpected(self, op, args, expected):
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
placeholders = [
array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
diff --git a/tensorflow/compiler/tests/dynamic_stitch_test.py b/tensorflow/compiler/tests/dynamic_stitch_test.py
index edd78153b5..50b04daa6b 100644
--- a/tensorflow/compiler/tests/dynamic_stitch_test.py
+++ b/tensorflow/compiler/tests/dynamic_stitch_test.py
@@ -30,7 +30,7 @@ from tensorflow.python.platform import googletest
class DynamicStitchTest(xla_test.XLATestCase):
def _AssertDynamicStitchResultIs(self, indices, data, expected):
- with self.test_session() as session:
+ with self.cached_session() as session:
index_placeholders = [
array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in indices
]
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index ff097f80f1..63cee550fd 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -101,7 +101,7 @@ class EagerTest(xla_test.XLATestCase):
self.assertAllEqual(15, product)
# Run some ops in graph mode
- with context.graph_mode(), self.test_session() as sess:
+ with context.graph_mode(), self.cached_session() as sess:
with self.test_scope():
three = constant_op.constant(3)
five = constant_op.constant(5)
@@ -351,6 +351,38 @@ class EagerFunctionTest(xla_test.XLATestCase):
var = f(v)
self.assertEqual(2.0, var.numpy())
+ def testReturnResourceHandle(self):
+ with self.test_scope():
+ v = resource_variable_ops.ResourceVariable([[1.0, 2.0], [3.0, 4.0]])
+
+ def f(v):
+ return v.handle
+
+ f = function.defun(f)
+ handle = f(v)
+ self.assertAllEqual(v.numpy(),
+ resource_variable_ops.read_variable_op(
+ handle, dtypes.float32).numpy())
+
+ def testReturnMultipleResourceHandles(self):
+ with self.test_scope():
+ v1 = resource_variable_ops.ResourceVariable(1.25)
+ v2 = resource_variable_ops.ResourceVariable(2.0)
+
+ def f(v):
+ return v.handle, 3.0 * v, v2.handle, v + v2
+
+ f = function.defun(f)
+ v1_handle, v1_times_3, v2_handle, variable_sum = f(v1)
+ self.assertAllEqual(v1.numpy(),
+ resource_variable_ops.read_variable_op(
+ v1_handle, dtypes.float32).numpy())
+ self.assertEqual(3.75, v1_times_3.numpy())
+ self.assertAllEqual(v2.numpy(),
+ resource_variable_ops.read_variable_op(
+ v2_handle, dtypes.float32).numpy())
+ self.assertEqual(3.25, variable_sum.numpy())
+
def testAllArgumentKinds(self):
"""Test a complex function that takes different argument kinds.
@@ -443,7 +475,6 @@ class EagerFunctionTest(xla_test.XLATestCase):
self.assertAllEqual((2, 3, 4), dz.shape.as_list())
def testNestedDefun(self):
- self.skipTest('Nested defuns do not work on TPU at the moment')
with self.test_scope():
@function.defun
@@ -458,6 +489,72 @@ class EagerFunctionTest(xla_test.XLATestCase):
y = two_x_plus_1(x)
self.assertAllEqual([5, 7, 9], y.numpy())
+ def testNestedDefunWithVariable(self):
+ with self.test_scope():
+ v0 = resource_variable_ops.ResourceVariable(5.0)
+
+ @function.defun
+ def g(x):
+ x = v0 * x
+ return x
+
+ @function.defun
+ def f(x):
+ x = g(v0 * x)
+ return x
+
+ x = constant_op.constant(3.0)
+ y = f(x)
+
+ self.assertEqual(75, y.numpy())
+
+ def testNestedDefunInGradientTape(self):
+ with self.test_scope():
+ v0 = resource_variable_ops.ResourceVariable(5.0)
+
+ @function.defun
+ def g(x):
+ x = v0 * x
+ return x
+
+ @function.defun
+ def f(x):
+ x = g(v0 * x)
+ return x
+
+ x = constant_op.constant(3.0)
+ with backprop.GradientTape() as tape:
+ y = f(x)
+ dy = tape.gradient(y, v0)
+
+ self.assertEqual(75, y.numpy())
+ self.assertEqual(30, dy.numpy())
+
+ def testNestedDefunInGradientTapeDifferentVars(self):
+ with self.test_scope():
+ v0 = resource_variable_ops.ResourceVariable(5.0)
+ v1 = resource_variable_ops.ResourceVariable(3.0)
+
+ @function.defun
+ def g(x):
+ x = v1 * x
+ return x
+
+ @function.defun
+ def f(x):
+ x = g(v0 * x)
+ return x
+
+ x = constant_op.constant(3.0)
+ with backprop.GradientTape(persistent=True) as tape:
+ y = f(x)
+ dy_v0 = tape.gradient(y, v0)
+ dy_v1 = tape.gradient(y, v1)
+
+ self.assertEqual(45, y.numpy())
+ self.assertEqual(9, dy_v0.numpy())
+ self.assertEqual(15, dy_v1.numpy())
+
class ExcessivePaddingTest(xla_test.XLATestCase):
"""Test that eager execution works with TPU flattened tensors.
diff --git a/tensorflow/compiler/tests/extract_image_patches_op_test.py b/tensorflow/compiler/tests/extract_image_patches_op_test.py
index 5529fdbb09..37061e91d1 100644
--- a/tensorflow/compiler/tests/extract_image_patches_op_test.py
+++ b/tensorflow/compiler/tests/extract_image_patches_op_test.py
@@ -44,7 +44,7 @@ class ExtractImagePatches(xla_test.XLATestCase):
strides = [1] + strides + [1]
rates = [1] + rates + [1]
- with self.test_session():
+ with self.cached_session():
image_placeholder = array_ops.placeholder(dtypes.float32)
with self.test_scope():
out_tensor = array_ops.extract_image_patches(
diff --git a/tensorflow/compiler/tests/fake_quant_ops_test.py b/tensorflow/compiler/tests/fake_quant_ops_test.py
index c48ab178bf..2178c44556 100644
--- a/tensorflow/compiler/tests/fake_quant_ops_test.py
+++ b/tensorflow/compiler/tests/fake_quant_ops_test.py
@@ -107,7 +107,7 @@ class FakeQuantWithMinMaxArgsTest(xla_test.XLATestCase):
],
dtype=np.float32)
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
input_placeholder = array_ops.placeholder(
dtypes.float32, inputs.shape, name="inputs")
@@ -198,7 +198,7 @@ class FakeQuantWithMinMaxArgsGradientTest(xla_test.XLATestCase):
[0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0],
dtype=np.float32)
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
gradient_placeholder = array_ops.placeholder(
dtypes.float32, gradients.shape, name="gradients")
@@ -306,7 +306,7 @@ class FakeQuantWithMinMaxVarsTest(xla_test.XLATestCase):
],
dtype=np.float32)
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
input_placeholder = array_ops.placeholder(
dtypes.float32, inputs.shape, name="inputs")
@@ -406,7 +406,7 @@ class FakeQuantWithMinMaxVarsGradientTest(xla_test.XLATestCase):
expected_backprops_wrt_min = 1.0 + 2.0
expected_backprops_wrt_max = 10.0 + 11.0
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
gradient_placeholder = array_ops.placeholder(
dtypes.float32, gradients.shape, name="gradients")
diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py
index c64ea249ec..b3e13fbaa6 100644
--- a/tensorflow/compiler/tests/fft_test.py
+++ b/tensorflow/compiler/tests/fft_test.py
@@ -71,7 +71,7 @@ class FFTTest(xla_test.XLATestCase):
data = np.reshape(data.astype(np.float32).view(np.complex64), shape)
data = to_32bit(complex_to_input(data))
expected = to_32bit(input_to_expected(data))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
ph = array_ops.placeholder(
dtypes.as_dtype(data.dtype), shape=data.shape)
@@ -93,7 +93,7 @@ class FFTTest(xla_test.XLATestCase):
data, nperseg=ws, noverlap=ws - hs, boundary=None, window=window)[2]
expected = np.swapaxes(expected, -1, -2)
expected *= window.sum() # scipy divides by window sum
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
ph = array_ops.placeholder(
dtypes.as_dtype(data.dtype), shape=data.shape)
diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py
index 0f64cc87cd..8c7edfd277 100644
--- a/tensorflow/compiler/tests/fifo_queue_test.py
+++ b/tensorflow/compiler/tests/fifo_queue_test.py
@@ -31,13 +31,13 @@ from tensorflow.python.platform import test
class FIFOQueueTest(xla_test.XLATestCase):
def testEnqueue(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
enqueue_op = q.enqueue((10.0,))
enqueue_op.run()
def testEnqueueWithShape(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2))
enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
enqueue_correct_op.run()
@@ -46,7 +46,7 @@ class FIFOQueueTest(xla_test.XLATestCase):
self.assertEqual(1, q.size().eval())
def testMultipleDequeues(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
self.evaluate(q.enqueue([1]))
self.evaluate(q.enqueue([2]))
@@ -55,7 +55,7 @@ class FIFOQueueTest(xla_test.XLATestCase):
self.assertAllEqual(set([1, 2, 3]), set([a, b, c]))
def testQueuesDontShare(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
self.evaluate(q.enqueue(1))
q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
@@ -64,13 +64,13 @@ class FIFOQueueTest(xla_test.XLATestCase):
self.assertAllEqual(self.evaluate(q.dequeue()), 1)
def testEnqueueDictWithoutNames(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
with self.assertRaisesRegexp(ValueError, "must have names"):
q.enqueue({"a": 12.0})
def testParallelEnqueue(self):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -95,7 +95,7 @@ class FIFOQueueTest(xla_test.XLATestCase):
self.assertItemsEqual(elems, results)
def testParallelDequeue(self):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -119,7 +119,7 @@ class FIFOQueueTest(xla_test.XLATestCase):
self.assertItemsEqual(elems, results)
def testDequeue(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -133,7 +133,7 @@ class FIFOQueueTest(xla_test.XLATestCase):
self.assertEqual([elems[i]], vals)
def testEnqueueAndBlockingDequeue(self):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -163,7 +163,7 @@ class FIFOQueueTest(xla_test.XLATestCase):
self.assertEqual([elem], result)
def testMultiEnqueueAndDequeue(self):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32))
elems = [(5, 10.0), (10, 20.0), (15, 30.0)]
enqueue_ops = [q.enqueue((x, y)) for x, y in elems]
@@ -179,12 +179,12 @@ class FIFOQueueTest(xla_test.XLATestCase):
self.assertEqual([y], y_val)
def testQueueSizeEmpty(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
self.assertEqual([0], q.size().eval())
def testQueueSizeAfterEnqueueAndDequeue(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
enqueue_op = q.enqueue((10.0,))
dequeued_t = q.dequeue()
diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py
index 1da97fd512..f1b87a5ffb 100644
--- a/tensorflow/compiler/tests/ftrl_test.py
+++ b/tensorflow/compiler/tests/ftrl_test.py
@@ -29,7 +29,6 @@ from tensorflow.python.training import adagrad
from tensorflow.python.training import ftrl
from tensorflow.python.training import gradient_descent
-
class FtrlOptimizerTest(xla_test.XLATestCase):
def initVariableAndGradient(self, dtype):
@@ -112,7 +111,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
def testFtrlwithoutRegularization(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -146,7 +145,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
def testFtrlwithoutRegularization2(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -174,7 +173,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
def testFtrlWithL1(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -196,13 +195,17 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
# Validate updated params
self.assertAllCloseAccordingToType(
- np.array([-7.66718769, -10.91273689]), var0.eval(), rtol=1e-4)
+ np.array([-7.66718769, -10.91273689]),
+ var0.eval(),
+ rtol=1e-4,
+ bfloat16_rtol=1e-1,
+ bfloat16_atol=1e-1)
self.assertAllCloseAccordingToType(
np.array([-0.93460727, -1.86147261]), var1.eval(), rtol=1e-4)
def testFtrlWithL1_L2(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -236,7 +239,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
weights will tend to have smaller magnitudes with this parameter set.
"""
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -259,9 +262,49 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
# Validate updated params
self.assertAllCloseAccordingToType(
- np.array([-0.21931979, -0.40642974]), var0.eval(), rtol=1e-4)
+ np.array([-0.22578996, -0.44345799]), var0.eval(), rtol=1e-4)
self.assertAllCloseAccordingToType(
- np.array([-0.0282721, -0.07188385]), var1.eval(), rtol=1e-4)
+ np.array([-0.14378493, -0.13229476]), var1.eval(), rtol=1e-4)
+
+ def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self):
+ """Verifies that l2 shrinkage in FTRL does not change lr schedule."""
+ for dtype in self.float_types:
+ with self.cached_session(), self.test_scope():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.1, 0.2], dtype=dtype)
+
+ opt0 = ftrl.FtrlOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0,
+ l2_shrinkage_regularization_strength=0.1)
+ opt1 = ftrl.FtrlOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0)
+ update0 = opt0.apply_gradients([(grads0, var0)])
+ update1 = opt1.apply_gradients([(grads1, var1)])
+ variables.global_variables_initializer().run()
+
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([1.0, 2.0], var1.eval())
+
+ # Run 10 steps FTRL
+ for _ in range(10):
+ update0.run()
+ update1.run()
+
+ # var0 is experiencing L2 shrinkage so it should be smaller than var1
+ # in magnitude.
+ self.assertTrue((var0.eval()**2 < var1.eval()**2).all())
+ accum0 = list(opt0._slots["accum"].values())[0].eval()
+ accum1 = list(opt1._slots["accum"].values())[0].eval()
+ # L2 shrinkage should not change how we update grad accumulator.
+ self.assertAllCloseAccordingToType(accum0, accum1)
# When variables are initialized with Zero, FTRL-Proximal has two properties:
# 1. Without L1&L2 but with fixed learning rate, FTRL-Proximal is identical
@@ -273,9 +316,9 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
def testEquivAdagradwithoutRegularization(self):
steps = 5
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
val0, val1 = self.equivAdagradTest_FtrlPart(steps, dtype)
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
val2, val3 = self.equivAdagradTest_AdagradPart(steps, dtype)
self.assertAllCloseAccordingToType(val0, val2, rtol=1e-4, half_rtol=1e-2)
@@ -284,9 +327,9 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
def testEquivGradientDescentwithoutRegularization(self):
steps = 5
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
val0, val1 = self.equivGradientDescentTest_FtrlPart(steps, dtype)
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
val2, val3 = self.equivGradientDescentTest_GradientDescentPart(
steps, dtype)
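
testFtrlWithL2ShrinkageDoesNotChangeLrSchedule, together with the revised expected values above it, pins down the intended semantics of l2_shrinkage_regularization_strength: shrinkage adds 2*lambda*w to the gradient used for the weight update, but the squared-gradient accumulator, and hence the per-coordinate learning-rate schedule, still sees the raw gradient. A simplified per-coordinate NumPy model of that rule (a sketch of the behavior the test asserts, not TensorFlow's kernel):

    import numpy as np

    def ftrl_step(w, n, z, g, lr=3.0, l1=0.001, l2=2.0, l2_shrinkage=0.1):
      g_shrunk = g + 2.0 * l2_shrinkage * w  # shrinkage alters the update...
      n_new = n + g * g                      # ...but never the accumulator
      sigma = (np.sqrt(n_new) - np.sqrt(n)) / lr
      z_new = z + g_shrunk - sigma * w
      w_new = np.where(
          np.abs(z_new) <= l1, 0.0,
          (np.sign(z_new) * l1 - z_new) / (np.sqrt(n_new) / lr + 2.0 * l2))
      return w_new, n_new, z_new

With the constant gradients the test feeds, running this with l2_shrinkage=0.1 and with 0.0 gives different weights but identical n trajectories, which is exactly the accum0 versus accum1 assertion at the end of the test.
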
diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py
index 04fba44446..b1891b918c 100644
--- a/tensorflow/compiler/tests/function_test.py
+++ b/tensorflow/compiler/tests/function_test.py
@@ -40,7 +40,7 @@ class FunctionTest(xla_test.XLATestCase):
bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
expected = APlus2B(aval, bval)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
@function.Defun(dtypes.float32, dtypes.float32)
def Foo(a, b):
@@ -66,7 +66,7 @@ class FunctionTest(xla_test.XLATestCase):
bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
expected = APlus2B(aval, bval)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
@function.Defun(dtypes.float32, dtypes.float32)
def Foo(a, b):
@@ -90,7 +90,7 @@ class FunctionTest(xla_test.XLATestCase):
bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
expected = Func(aval, bval)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
@function.Defun(dtypes.float32, dtypes.float32)
def Foo(a, b):
@@ -105,7 +105,7 @@ class FunctionTest(xla_test.XLATestCase):
def testCompileTimeConstantsInDefun(self):
"""Tests that XLA handles compile-time constants in defuns."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
@function.Defun(dtypes.float32, dtypes.int32, dtypes.int32)
def Foo(a, c, d):
@@ -140,7 +140,7 @@ class FunctionTest(xla_test.XLATestCase):
bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
expected = aval + bval * 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
a = array_ops.placeholder(dtypes.float32, name="a")
b = array_ops.placeholder(dtypes.float32, name="b")
diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py
index 132e42ac7a..8c018cccb8 100644
--- a/tensorflow/compiler/tests/fused_batchnorm_test.py
+++ b/tensorflow/compiler/tests/fused_batchnorm_test.py
@@ -83,7 +83,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
y_ref, mean_ref, var_ref = self._reference_training(
x_val, scale_val, offset_val, epsilon, data_format_src)
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
# To avoid constant folding
x_val_converted = test_utils.ConvertBetweenDataFormats(
x_val, data_format_src, data_format)
@@ -126,7 +126,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
y_ref, mean_ref, var_ref = self._reference_training(
x_val, scale_val, offset_val, epsilon, data_format_src)
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
# To avoid constant folding
x_val_converted = test_utils.ConvertBetweenDataFormats(
x_val, data_format_src, data_format)
@@ -210,7 +210,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad(
x_val, grad_val, scale_val, mean_val, var_val, epsilon, data_format_src)
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
grad_val_converted = test_utils.ConvertBetweenDataFormats(
grad_val, data_format_src, data_format)
x_val_converted = test_utils.ConvertBetweenDataFormats(
@@ -260,7 +260,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
var_val = np.random.random_sample(scale_shape).astype(np.float32)
data_format_src = "NHWC"
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
grad_val_converted = test_utils.ConvertBetweenDataFormats(
grad_val, data_format_src, data_format)
x_val_converted = test_utils.ConvertBetweenDataFormats(
diff --git a/tensorflow/compiler/tests/gather_nd_op_test.py b/tensorflow/compiler/tests/gather_nd_op_test.py
index 23b0aed34f..7161f4ab33 100644
--- a/tensorflow/compiler/tests/gather_nd_op_test.py
+++ b/tensorflow/compiler/tests/gather_nd_op_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
class GatherNdTest(xla_test.XLATestCase):
def _runGather(self, params, indices):
- with self.test_session():
+ with self.cached_session():
paramsp = array_ops.placeholder(params.dtype)
indicesp = array_ops.placeholder(indices.dtype)
with self.test_scope():
@@ -46,7 +46,7 @@ class GatherNdTest(xla_test.XLATestCase):
np.array([[4], [4], [0]], np.int32)))
def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self):
- with self.test_session():
+ with self.cached_session():
params = np.ones((3, 3), dtype=np.float32)
indices_empty = np.empty((0, 2), dtype=np.int32)
diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py
index e9c8ef7c91..089d95daab 100644
--- a/tensorflow/compiler/tests/gather_test.py
+++ b/tensorflow/compiler/tests/gather_test.py
@@ -42,7 +42,7 @@ class GatherTest(xla_test.XLATestCase):
return data
def testScalar1D(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
data = np.array([0, 1, 2, 3, 7, 5])
for dtype in self.all_tf_types:
for indices in 4, [4], [1, 2, 2, 4, 5]:
@@ -55,7 +55,7 @@ class GatherTest(xla_test.XLATestCase):
self.assertAllEqual(np_val, gather_val)
def testScalar2D(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
[12, 13, 14]])
for dtype in self.all_tf_types:
@@ -69,7 +69,7 @@ class GatherTest(xla_test.XLATestCase):
self.assertAllEqual(expected, gather_val)
def testSimpleTwoD32(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
[12, 13, 14]])
for dtype in self.all_tf_types:
@@ -87,7 +87,7 @@ class GatherTest(xla_test.XLATestCase):
if np.int64 not in self.int_types:
return
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
[12, 13, 14]])
# The indices must be in bounds for any axis.
@@ -114,7 +114,7 @@ class GatherTest(xla_test.XLATestCase):
for axis in 0, 1, 2, 3, -1, -2:
params = self._buildParams(np.random.randn(*shape), dtype)
indices = np.random.randint(shape[axis], size=indices_shape)
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
tf_params = array_ops.placeholder(dtype=dtype)
tf_indices = constant_op.constant(indices, dtype=dtypes.int32)
gather = array_ops.gather(tf_params, tf_indices, axis=axis)
@@ -123,7 +123,7 @@ class GatherTest(xla_test.XLATestCase):
self.assertAllEqual(gather_np, gather_value)
def testIndicesWithDifferentDimensions(self):
- with self.test_session():
+ with self.cached_session():
for dtype in self.numeric_tf_types:
params = array_ops.placeholder(dtype=dtype)
indices = array_ops.placeholder(dtype=np.int32)
@@ -137,7 +137,7 @@ class GatherTest(xla_test.XLATestCase):
[[7]], gather.eval(feed_dict={params: [4, 7, 2], indices: [[1]]}))
def testGatherPrecision(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
data = np.array([[0, 0, 0, 0], [0, 2 * (1 + np.exp2(-8)), 0, 0],
[0, 0, 0, 0], [0.015789, 0.0985, 0.55789, 0.3842]])
indices = np.array([1, 2, 3, 1])
diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py
index bf986ade06..6fe5a66e0e 100644
--- a/tensorflow/compiler/tests/image_ops_test.py
+++ b/tensorflow/compiler/tests/image_ops_test.py
@@ -54,7 +54,7 @@ class RGBToHSVTest(xla_test.XLATestCase):
inp = GenerateNumpyRandomRGB(shape).astype(nptype)
# Convert to HSV and back, as a batch and individually
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch0 = array_ops.placeholder(nptype, shape=shape)
with self.test_scope():
batch1 = image_ops.rgb_to_hsv(batch0)
@@ -78,7 +78,7 @@ class RGBToHSVTest(xla_test.XLATestCase):
data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
for nptype in self.float_types:
rgb_np = np.array(data, dtype=nptype).reshape([2, 2, 3]) / 255.
- with self.test_session():
+ with self.cached_session():
placeholder = array_ops.placeholder(nptype)
with self.test_scope():
hsv = image_ops.rgb_to_hsv(placeholder)
@@ -97,7 +97,7 @@ class RGBToHSVTest(xla_test.XLATestCase):
for r, g, b in rgb_flat
])
hsv_np = hsv_np.reshape(4, 4, 4, 3)
- with self.test_session():
+ with self.cached_session():
placeholder = array_ops.placeholder(nptype)
with self.test_scope():
hsv_op = image_ops.rgb_to_hsv(placeholder)
@@ -108,7 +108,7 @@ class RGBToHSVTest(xla_test.XLATestCase):
class AdjustContrastTest(xla_test.XLATestCase):
def _testContrast(self, x_np, y_np, contrast_factor):
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(x_np.dtype, shape=x_np.shape)
flt_x = image_ops.convert_image_dtype(x, dtypes.float32)
with self.test_scope():
@@ -146,7 +146,7 @@ class AdjustContrastTest(xla_test.XLATestCase):
return y_np
def _adjustContrastTf(self, x_np, contrast_factor):
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(np.float32)
with self.test_scope():
y = image_ops.adjust_contrast(x, contrast_factor)
@@ -180,7 +180,7 @@ class AdjustHueTest(xla_test.XLATestCase):
y_data = [0, 13, 1, 54, 226, 59, 8, 234, 150, 255, 39, 1]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(x_np.dtype, shape=x_shape)
flt_x = image_ops.convert_image_dtype(x, dtypes.float32)
with self.test_scope():
@@ -198,7 +198,7 @@ class AdjustHueTest(xla_test.XLATestCase):
y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(x_np.dtype, shape=x_shape)
flt_x = image_ops.convert_image_dtype(x, dtypes.float32)
with self.test_scope():
@@ -216,7 +216,7 @@ class AdjustHueTest(xla_test.XLATestCase):
y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(x_np.dtype, shape=x_shape)
flt_x = image_ops.convert_image_dtype(x, dtypes.float32)
with self.test_scope():
@@ -244,7 +244,7 @@ class AdjustHueTest(xla_test.XLATestCase):
return y_v.reshape(x_np.shape)
def _adjustHueTf(self, x_np, delta_h):
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(dtypes.float32)
with self.test_scope():
y = gen_image_ops.adjust_hue(x, delta_h)
@@ -324,7 +324,7 @@ class AdjustSaturationTest(xla_test.XLATestCase):
y_rgb_data = [6, 9, 13, 140, 180, 226, 135, 121, 234, 172, 255, 128]
y_np = np.array(y_rgb_data, dtype=np.uint8).reshape(x_shape)
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(x_np.dtype, shape=x_shape)
y = self._adjust_saturation(x, saturation_factor)
y_tf = y.eval({x: x_np})
@@ -339,7 +339,7 @@ class AdjustSaturationTest(xla_test.XLATestCase):
y_data = [0, 5, 13, 0, 106, 226, 30, 0, 234, 89, 255, 0]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(x_np.dtype, shape=x_shape)
y = self._adjust_saturation(x, saturation_factor)
y_tf = y.eval({x: x_np})
@@ -378,7 +378,7 @@ class AdjustSaturationTest(xla_test.XLATestCase):
"gb_same",
"rgb_same",
]
- with self.test_session():
+ with self.cached_session():
for x_shape in x_shapes:
for test_style in test_styles:
x_np = np.random.rand(*x_shape) * 255.
@@ -410,13 +410,14 @@ class ResizeBilinearTest(xla_test.XLATestCase):
image_np,
target_shape,
expected=None,
- large_tolerance=False):
+ large_tolerance=False,
+ align_corners=True):
if expected is None:
self.fail("expected must be specified")
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
image = array_ops.placeholder(image_np.dtype)
resized = gen_image_ops.resize_bilinear(
- image, target_shape, align_corners=True)
+ image, target_shape, align_corners=align_corners)
out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]})
if large_tolerance:
self.assertAllClose(
@@ -433,7 +434,7 @@ class ResizeBilinearTest(xla_test.XLATestCase):
self.fail("input_shape must be specified")
if expected is None:
self.fail("expected must be specified")
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
dtype = dtype or np.float32
grads = array_ops.placeholder(np.float32)
resized = gen_image_ops.resize_bilinear_grad(
@@ -579,6 +580,27 @@ class ResizeBilinearTest(xla_test.XLATestCase):
dtype=np.float32)),
large_tolerance=True)
+ def testNonAlignCorners3x2To6x4(self):
+ input_data = [[64, 32], [32, 64], [50, 100]]
+ expected_data = [[64.0, 48.0, 32.0, 32.0], [48.0, 48.0, 48.0, 48.0],
+ [32.0, 48.0, 64.0, 64.0], [41.0, 61.5, 82.0, 82.0],
+ [50.0, 75.0, 100.0, 100.0], [50.0, 75.0, 100.0, 100.0]]
+ for dtype in self.float_types:
+ self._assertForwardOpMatchesExpected(
+ np.array(input_data, dtype=dtype), [6, 4],
+ expected=np.array(expected_data, dtype=np.float32),
+ align_corners=False)
+
+ def testNonAlignCorners6x4To3x2(self):
+ input_data = [[127, 127, 64, 64], [127, 127, 64, 64], [64, 64, 127, 127],
+ [64, 64, 127, 127], [50, 50, 100, 100], [50, 50, 100, 100]]
+ expected_data = [[127, 64], [64, 127], [50, 100]]
+ for dtype in self.float_types:
+ self._assertForwardOpMatchesExpected(
+ np.array(input_data, dtype=dtype), [3, 2],
+ expected=np.array(expected_data, dtype=dtype),
+ align_corners=False)
+
class NonMaxSuppressionTest(xla_test.XLATestCase):
@@ -596,7 +618,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
iou_threshold_np = np.array(0.5, dtype=np.float32)
score_threshold_np = np.array(0.0, dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
@@ -639,7 +661,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
iou_threshold_np = np.array(0.5, dtype=np.float32)
score_threshold_np = np.array(0.0, dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
@@ -686,7 +708,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
iou_threshold_np = np.array(0.5, dtype=np.float32)
score_threshold_np = np.array(0.4, dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
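
The ResizeBilinearTest changes thread an align_corners argument through the forward helper and add explicit non-align-corners cases. The two modes differ only in how an output index maps back to a source coordinate before the usual bilinear mix; a sketch of the legacy TF 1.x mapping these tests exercise (half-pixel centers are a later, third variant):

    def src_coord(i, in_size, out_size, align_corners):
      # Source coordinate sampled for output index i; interpolation then
      # clamps to [0, in_size - 1] before mixing the two neighbours.
      if align_corners and out_size > 1:
        return i * (in_size - 1) / (out_size - 1)  # endpoints line up
      return i * in_size / out_size                # pure in/out scaling

    # Upsampling 3 rows to 6, as in testNonAlignCorners3x2To6x4:
    print([src_coord(i, 3, 6, True) for i in range(6)])   # 0.0, 0.4, ..., 2.0
    print([src_coord(i, 3, 6, False) for i in range(6)])  # 0.0, 0.5, ..., 2.5
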
diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py
index 6e0db54b7a..0839fb123e 100644
--- a/tensorflow/compiler/tests/jit_test.py
+++ b/tensorflow/compiler/tests/jit_test.py
@@ -489,8 +489,9 @@ class ElementWiseFusionTest(test.TestCase):
def testElementWiseClustering(self):
arg0 = np.random.rand(2, 2).astype(np.float32)
arg1 = np.random.rand(2, 2).astype(np.float32)
- os.environ["TF_XLA_FLAGS"] = ("--tf_xla_fusion_only=true "
- "--tf_xla_cpu_global_jit")
+ os.environ["TF_XLA_FLAGS"] = (
+ "--tf_xla_fusion_only=true "
+ "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", ""))
tf_op, tf_count = self.simpleTest(arg0, arg1,
config_pb2.OptimizerOptions.OFF)
self.assertEqual(0, tf_count)
diff --git a/tensorflow/compiler/tests/listdiff_op_test.py b/tensorflow/compiler/tests/listdiff_op_test.py
index 45a04f0cf5..58622114e4 100644
--- a/tensorflow/compiler/tests/listdiff_op_test.py
+++ b/tensorflow/compiler/tests/listdiff_op_test.py
@@ -33,7 +33,7 @@ class ListDiffTest(xla_test.XLATestCase):
def _testListDiff(self, x, y, out, idx):
for dtype in [dtypes.int32, dtypes.int64]:
for index_dtype in [dtypes.int32, dtypes.int64]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x_tensor = ops.convert_to_tensor(x, dtype=dtype)
y_tensor = ops.convert_to_tensor(y, dtype=dtype)
with self.test_scope():
diff --git a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py
index 253b45902f..c6ad67993e 100644
--- a/tensorflow/compiler/tests/lrn_ops_test.py
+++ b/tensorflow/compiler/tests/lrn_ops_test.py
@@ -58,7 +58,7 @@ class LRNTest(xla_test.XLATestCase):
return output
def _RunAndVerify(self, dtype):
- with self.test_session():
+ with self.cached_session():
# random shape
shape = np.random.randint(1, 16, size=4)
# Make depth at least 2 to make it meaningful
@@ -110,7 +110,7 @@ class LRNTest(xla_test.XLATestCase):
alpha = 1.0 * np.random.rand()
beta = 1.0 * np.random.rand()
- with self.test_session():
+ with self.cached_session():
in_image = constant_op.constant(in_image_vals, shape=shape)
out_image = constant_op.constant(out_image_vals, shape=shape)
out_grads = constant_op.constant(out_grads_vals, shape=shape)
diff --git a/tensorflow/compiler/tests/lstm_test.py b/tensorflow/compiler/tests/lstm_test.py
index 31093c6571..265c0b6d14 100644
--- a/tensorflow/compiler/tests/lstm_test.py
+++ b/tensorflow/compiler/tests/lstm_test.py
@@ -73,7 +73,7 @@ class LSTMTest(test.TestCase):
def _RunLSTMCell(self, basename, init_weights, m_prev_scalar, c_prev_scalar,
pad_scalar):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_inputs = 1
num_nodes = 1
@@ -156,7 +156,7 @@ class LSTMTest(test.TestCase):
def _RunLSTMLayer(self, basename, init_weights, m_init_scalar, c_init_scalar,
pad_scalar):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_inputs = 1
num_nodes = 1
seq_length = 3
diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py
index 0d9f99f8a6..9222db4b7e 100644
--- a/tensorflow/compiler/tests/matrix_band_part_test.py
+++ b/tensorflow/compiler/tests/matrix_band_part_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
class MatrixBandPartTest(xla_test.XLATestCase):
def _testMatrixBandPart(self, dtype, shape):
- with self.test_session():
+ with self.cached_session():
batch_shape = shape[:-2]
mat = np.ones(shape).astype(dtype)
batch_mat = np.tile(mat, batch_shape + [1, 1])
diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
index 2bb8a97bda..94cd3eeb31 100644
--- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
+++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
@@ -54,7 +54,7 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase):
def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol):
clean_a = np.tril(a) if lower else np.triu(a)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder_a = MakePlaceholder(a)
placeholder_ca = MakePlaceholder(clean_a)
placeholder_b = MakePlaceholder(b)
diff --git a/tensorflow/compiler/tests/momentum_test.py b/tensorflow/compiler/tests/momentum_test.py
index c2592c54cf..f77521a7c4 100644
--- a/tensorflow/compiler/tests/momentum_test.py
+++ b/tensorflow/compiler/tests/momentum_test.py
@@ -41,7 +41,7 @@ class MomentumOptimizerTest(xla_test.XLATestCase):
def testBasic(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -95,7 +95,7 @@ class MomentumOptimizerTest(xla_test.XLATestCase):
def testNesterovMomentum(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([0.1, 0.2], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([0.3, 0.4], dtype=dtype)
var0_np = np.array([0.1, 0.2], dtype=dtype)
@@ -120,7 +120,7 @@ class MomentumOptimizerTest(xla_test.XLATestCase):
def testTensorLearningRateAndMomentum(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py
index da08225e9f..a1c07fce73 100644
--- a/tensorflow/compiler/tests/nary_ops_test.py
+++ b/tensorflow/compiler/tests/nary_ops_test.py
@@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest
class NAryOpsTest(xla_test.XLATestCase):
def _testNAry(self, op, args, expected, equality_fn=None):
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
placeholders = [
array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
@@ -126,7 +126,7 @@ class NAryOpsTest(xla_test.XLATestCase):
[[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], dtype=np.float32))
def testOneHot(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
indices = array_ops.constant(np.array([[2, 3], [0, 1]], dtype=np.int32))
op = array_ops.one_hot(indices,
np.int32(4),
@@ -148,7 +148,7 @@ class NAryOpsTest(xla_test.XLATestCase):
self.assertAllEqual(output, expected)
def testSplitV(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
output = session.run(
array_ops.split(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2]],
diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py
index 2f9122645d..f985c5d2d9 100644
--- a/tensorflow/compiler/tests/nullary_ops_test.py
+++ b/tensorflow/compiler/tests/nullary_ops_test.py
@@ -29,14 +29,14 @@ from tensorflow.python.platform import googletest
class NullaryOpsTest(xla_test.XLATestCase):
def _testNullary(self, op, expected):
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
output = op()
result = session.run(output)
self.assertAllClose(result, expected, rtol=1e-3)
def testNoOp(self):
- with self.test_session():
+ with self.cached_session():
with self.test_scope():
output = control_flow_ops.no_op()
# This should not crash.
diff --git a/tensorflow/compiler/tests/oom_test.py b/tensorflow/compiler/tests/oom_test.py
index d68d32057a..7635f89249 100644
--- a/tensorflow/compiler/tests/oom_test.py
+++ b/tensorflow/compiler/tests/oom_test.py
@@ -46,7 +46,7 @@ class OutOfMemoryTest(xla_test.XLATestCase):
def test_loop():
size = int(2e8)
while True:
- with self.test_session():
+ with self.cached_session():
# Force the compiled code to not be constant by feeding in a
# parameter.
p = array_ops.placeholder(dtypes.float32, shape=[2, 1, 1])
diff --git a/tensorflow/compiler/tests/placeholder_test.py b/tensorflow/compiler/tests/placeholder_test.py
index a75d99189b..77bb839409 100644
--- a/tensorflow/compiler/tests/placeholder_test.py
+++ b/tensorflow/compiler/tests/placeholder_test.py
@@ -28,7 +28,7 @@ from tensorflow.python.platform import googletest
class PlaceholderTest(xla_test.XLATestCase):
def test_placeholder_with_default_default(self):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(4.0)
ph = array_ops.placeholder_with_default(v, shape=[])
out = ph * 2
@@ -36,7 +36,7 @@ class PlaceholderTest(xla_test.XLATestCase):
self.assertEqual(8.0, sess.run(out))
def test_placeholder_with_default_fed(self):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(4.0)
ph = array_ops.placeholder_with_default(v, shape=[])
out = ph * 2
diff --git a/tensorflow/compiler/tests/pooling_ops_3d_test.py b/tensorflow/compiler/tests/pooling_ops_3d_test.py
index 17f860db61..b6cdd38345 100644
--- a/tensorflow/compiler/tests/pooling_ops_3d_test.py
+++ b/tensorflow/compiler/tests/pooling_ops_3d_test.py
@@ -62,7 +62,7 @@ class Pooling3DTest(xla_test.XLATestCase):
# numbers from 1.
x = np.arange(1.0, total_size + 1, dtype=np.float32)
x = x.reshape(input_sizes)
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
inputs = array_ops.placeholder(dtypes.float32)
t = pool_func(
inputs,
@@ -210,7 +210,7 @@ class Pooling3DTest(xla_test.XLATestCase):
strides = [1] + strides + [1]
total_size = np.prod(input_sizes)
x = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_sizes)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Use the forward pool function to compute some corresponding outputs
# (needed for the CPU device, and we need the shape in both cases).
with ops.device("CPU"):
diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py
index 9fc94752ea..d03bd4fdbb 100644
--- a/tensorflow/compiler/tests/pooling_ops_test.py
+++ b/tensorflow/compiler/tests/pooling_ops_test.py
@@ -89,7 +89,7 @@ class PoolingTest(xla_test.XLATestCase):
# numbers from 1.
x = np.array([f * 1.0 for f in range(1, total_size + 1)], dtype=np.float32)
x = x.reshape(input_sizes)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
inputs = array_ops.placeholder(dtypes.float32)
t = inputs
@@ -324,7 +324,7 @@ class PoolGradTest(xla_test.XLATestCase):
# TODO(b/74222344): Fix nan handling for max pool grad.
# x[np.random.choice(total_size)] = np.nan
x = x.reshape(input_sizes)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Use the forward pool function to compute some corresponding outputs
# (needed for the CPU device, and we need the shape in both cases).
with ops.device(self.CPU_DEVICE):
diff --git a/tensorflow/compiler/tests/powersign_test.py b/tensorflow/compiler/tests/powersign_test.py
index 5fa7706d72..86536da7fe 100644
--- a/tensorflow/compiler/tests/powersign_test.py
+++ b/tensorflow/compiler/tests/powersign_test.py
@@ -64,7 +64,7 @@ class PowerSignTest(xla_test.XLATestCase):
base=math.e,
beta=0.9):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
# Initialize variables for numpy implementation.
m0, m1 = 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype)
diff --git a/tensorflow/compiler/tests/proximal_adagrad_test.py b/tensorflow/compiler/tests/proximal_adagrad_test.py
index cde87db63d..c41b4171e2 100644
--- a/tensorflow/compiler/tests/proximal_adagrad_test.py
+++ b/tensorflow/compiler/tests/proximal_adagrad_test.py
@@ -32,7 +32,7 @@ from tensorflow.python.training import proximal_adagrad
class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
def testResourceProximalAdagradwithoutRegularization(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0])
var1 = resource_variable_ops.ResourceVariable([0.0, 0.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -60,7 +60,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
self.assertEqual(2, len(opt_vars))
def testProximalAdagradwithoutRegularization2(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -84,7 +84,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
self.assertAllClose(np.array([3.715679, 2.433051]), var1.eval())
def testProximalAdagradWithL1(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -108,7 +108,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
self.assertAllClose(np.array([2.959304, 1.029232]), var1.eval())
def testProximalAdagradWithL1_L2(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -151,7 +151,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
return var0.eval(), var1.eval()
def testEquivAdagradwithoutRegularization(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
val0, val1 = self.applyOptimizer(
proximal_adagrad.ProximalAdagradOptimizer(
3.0,
@@ -159,7 +159,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
l1_regularization_strength=0.0,
l2_regularization_strength=0.0))
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
val2, val3 = self.applyOptimizer(
adagrad.AdagradOptimizer(
3.0, initial_accumulator_value=0.1))
diff --git a/tensorflow/compiler/tests/proximal_gradient_descent_test.py b/tensorflow/compiler/tests/proximal_gradient_descent_test.py
index 11eb768711..3d808e6b8a 100644
--- a/tensorflow/compiler/tests/proximal_gradient_descent_test.py
+++ b/tensorflow/compiler/tests/proximal_gradient_descent_test.py
@@ -32,7 +32,7 @@ from tensorflow.python.training import proximal_gradient_descent
class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase):
def testResourceProximalGradientDescentwithoutRegularization(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0])
var1 = resource_variable_ops.ResourceVariable([0.0, 0.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -53,7 +53,7 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase):
self.assertAllClose(np.array([-0.09, -0.18]), var1.eval())
def testProximalGradientDescentwithoutRegularization2(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -75,7 +75,7 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase):
self.assertAllClose(np.array([3.91, 2.82]), var1.eval())
def testProximalGradientDescentWithL1(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -97,7 +97,7 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase):
self.assertAllClose(np.array([3.67, 2.37]), var1.eval())
def testProximalGradientDescentWithL1_L2(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -137,14 +137,14 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase):
return var0.eval(), var1.eval()
def testEquivGradientDescentwithoutRegularization(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
val0, val1 = self.applyOptimizer(
proximal_gradient_descent.ProximalGradientDescentOptimizer(
3.0,
l1_regularization_strength=0.0,
l2_regularization_strength=0.0))
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
val2, val3 = self.applyOptimizer(
gradient_descent.GradientDescentOptimizer(3.0))
diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py
index 1b969ee2b3..236b1b881d 100644
--- a/tensorflow/compiler/tests/qr_op_test.py
+++ b/tensorflow/compiler/tests/qr_op_test.py
@@ -71,7 +71,7 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):
x_np = np.random.uniform(
low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x_tf = array_ops.placeholder(dtype)
with self.test_scope():
q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices)
@@ -101,8 +101,8 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):
@parameterized.parameters(*PARAMS)
def testQR(self, rows, cols, dtype):
- # TODO(b/111317468): implement full_matrices=False, test other types.
- for full_matrices in [True]:
+ # TODO(b/111317468): Test other types.
+ for full_matrices in [True, False]:
# Only tests the (3, 2) case for small numbers of rows/columns.
for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
self._test(dtype, batch_dims + (rows, cols), full_matrices)
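Flipping the loop above to full_matrices in [True, False] exercises the reduced QR decomposition as well: for an m x n input with m >= n, the full factorization yields an m x m Q and m x n R, while the reduced one yields m x n and n x n. A numpy sketch of the shape difference the test now covers (numpy is used here for illustration only; the test itself runs linalg_ops.qr under XLA):

import numpy as np

x = np.random.uniform(-1.0, 1.0, size=(3, 2)).astype(np.float32)
q_full, r_full = np.linalg.qr(x, mode='complete')  # full_matrices=True: q is (3, 3), r is (3, 2)
q_thin, r_thin = np.linalg.qr(x, mode='reduced')   # full_matrices=False: q is (3, 2), r is (2, 2)
assert np.allclose(q_full @ r_full, x, atol=1e-5)  # both factorizations reconstruct x
assert np.allclose(q_thin @ r_thin, x, atol=1e-5)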
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index 8c4e16e4e0..6e18344117 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -39,7 +39,7 @@ class RandomOpsTest(xla_test.XLATestCase):
def _testRngIsNotConstant(self, rng, dtype):
# Tests that 'rng' does not always return the same value.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
x = rng(dtype)
@@ -79,7 +79,7 @@ class RandomOpsTest(xla_test.XLATestCase):
if (self.device in ["XLA_GPU", "XLA_CPU"
]) and (dtype in [dtypes.bfloat16, dtypes.half]):
continue
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
x = random_ops.random_uniform(
shape=[1000], dtype=dtype, minval=-2, maxval=33)
@@ -99,7 +99,7 @@ class RandomOpsTest(xla_test.XLATestCase):
count = 10000000
# TODO(b/34339814): implement inverse erf support for non-F32 types.
for dtype in [dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
x = random_ops.truncated_normal(shape=[count], dtype=dtype)
y = sess.run(x)
@@ -147,7 +147,7 @@ class RandomOpsTest(xla_test.XLATestCase):
# TODO(b/26783907): this test requires the CPU backend to implement sort.
if self.device in ["XLA_CPU"]:
return
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
x = math_ops.range(1 << 16)
shuffle = random_ops.random_shuffle(x)
@@ -158,7 +158,7 @@ class RandomOpsTest(xla_test.XLATestCase):
self.assertAllEqual(set(result), set(expected))
def testShuffle2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
x = array_ops.diag(math_ops.range(20))
shuffle = random_ops.random_shuffle(x)
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index c0ea242044..bddda6f302 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -45,6 +45,8 @@ limitations under the License.
#include <random>
#include <unordered_map>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/core/common_runtime/device.h"
@@ -61,7 +63,6 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
@@ -81,7 +82,7 @@ string* tf_xla_test_device_ptr; // initial value set in main()
bool tf_xla_test_use_jit = true;
string LocalDeviceToFullDeviceName(const string& device) {
- return strings::StrCat("/job:localhost/replica:0/task:0/device:", device);
+ return absl::StrCat("/job:localhost/replica:0/task:0/device:", device);
}
constexpr std::array<DataType, 5> kAllXlaTypes = {
@@ -107,11 +108,12 @@ class OpTestBuilder {
// Sets an attribute.
template <class T>
- OpTestBuilder& Attr(StringPiece attr_name, T&& value);
+ OpTestBuilder& Attr(absl::string_view attr_name, T&& value);
// Overload needed to allow {...} expressions for value.
template <class T>
- OpTestBuilder& Attr(StringPiece attr_name, std::initializer_list<T> value);
+ OpTestBuilder& Attr(absl::string_view attr_name,
+ std::initializer_list<T> value);
  // Adds nodes that execute the operator under test on 'device' to 'graphdef'.
// If 'use_jit' is true, marks the operator under test to be compiled by XLA.
@@ -185,13 +187,13 @@ OpTestBuilder& OpTestBuilder::RandomUniqueInput(DataType type,
}
template <class T>
-OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, T&& value) {
+OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name, T&& value) {
AddNodeAttr(attr_name, std::forward<T>(value), &node_def_);
return *this;
}
template <class T>
-OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name,
+OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name,
std::initializer_list<T> value) {
Attr<std::initializer_list<T>>(attr_name, std::move(value));
return *this;
@@ -209,7 +211,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix,
NodeDef* test_def = graphdef->add_node();
*test_def = node_def_;
- test_def->set_name(strings::StrCat(name_prefix, "_op_under_test"));
+ test_def->set_name(absl::StrCat(name_prefix, "_op_under_test"));
test_def->set_device(device);
AddDefaultsToNodeDef(*op_def, test_def);
if (use_jit) {
@@ -224,7 +226,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix,
// Build feed and fetch nodes.
for (int i = 0; i < input_types.size(); ++i) {
NodeDef* def = graphdef->add_node();
- string name = strings::StrCat(name_prefix, "_input_", i);
+ string name = absl::StrCat(name_prefix, "_input_", i);
TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Placeholder")
.Device(device)
.Attr("dtype", input_types[i])
@@ -235,7 +237,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix,
for (int i = 0; i < output_types.size(); ++i) {
NodeDef* def = graphdef->add_node();
- string name = strings::StrCat(name_prefix, "_output_", i);
+ string name = absl::StrCat(name_prefix, "_output_", i);
TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Identity")
.Device(device)
.Attr("T", output_types[i])
@@ -275,13 +277,13 @@ class OpTest : public ::testing::Test {
// Select a random element from 'candidates'.
template <typename T>
- T Choose(gtl::ArraySlice<T> candidates);
+ T Choose(absl::Span<const T> candidates);
static constexpr int kDefaultMaxRank = 5;
static constexpr int64 kDefaultMaxDimensionSize = 256LL;
// Returns true if 'dims' have a size less than tf_xla_max_tensor_size.
- bool TensorSizeIsOk(gtl::ArraySlice<int64> dims);
+ bool TensorSizeIsOk(absl::Span<const int64> dims);
// Returns a random dimension size, in the range [min, max).
int64 RandomDim(int64 min = 0, int64 max = kDefaultMaxDimensionSize);
@@ -307,11 +309,11 @@ class OpTest : public ::testing::Test {
// of the type's range. If the shape is omitted, a random shape is used.
// TODO(phawkins): generalize this code to a caller-supplied distribution.
Tensor RandomTensor(DataType dtype, bool needs_unique_values,
- gtl::ArraySlice<int64> shape);
+ absl::Span<const int64> shape);
Tensor RandomTensor(DataType dtype);
// Like RandomTensor, but uses values >= 0.
- Tensor RandomNonNegativeTensor(DataType dtype, gtl::ArraySlice<int64> shape);
+ Tensor RandomNonNegativeTensor(DataType dtype, absl::Span<const int64> shape);
Tensor RandomNonNegativeTensor(DataType dtype);
// Returns a random subset of the integers in the range [0, rank), suitable
@@ -415,7 +417,7 @@ void OpTest::Repeatedly(const std::function<TestResult(void)>& fn) {
}
template <typename T>
-T OpTest::Choose(gtl::ArraySlice<T> candidates) {
+T OpTest::Choose(absl::Span<const T> candidates) {
std::uniform_int_distribution<size_t> d(0, candidates.size() - 1);
return candidates[d(generator())];
}
@@ -425,7 +427,7 @@ int64 OpTest::RandomDim(int64 min, int64 max) {
return size_distribution(generator());
}
-bool OpTest::TensorSizeIsOk(gtl::ArraySlice<int64> dims) {
+bool OpTest::TensorSizeIsOk(absl::Span<const int64> dims) {
int64 size = 1LL;
for (int64 dim : dims) {
size *= dim;
@@ -451,7 +453,7 @@ std::vector<int64> OpTest::RandomDims(int min_rank, int max_rank,
}
Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
- gtl::ArraySlice<int64> shape) {
+ absl::Span<const int64> shape) {
Tensor tensor(dtype, TensorShape(shape));
switch (dtype) {
case DT_FLOAT: {
@@ -548,7 +550,7 @@ Tensor OpTest::RandomTensor(DataType dtype) {
}
Tensor OpTest::RandomNonNegativeTensor(DataType dtype,
- gtl::ArraySlice<int64> shape) {
+ absl::Span<const int64> shape) {
Tensor tensor(dtype, TensorShape(shape));
switch (dtype) {
case DT_FLOAT: {
@@ -726,11 +728,11 @@ bool IsClose<complex64>(const complex64& x, const complex64& y, double atol,
template <typename T>
string Str(T x) {
- return strings::StrCat(x);
+ return absl::StrCat(x);
}
template <>
string Str<complex64>(complex64 x) {
- return strings::StrCat("(", x.real(), ", ", x.imag(), ")");
+ return absl::StrCat("(", x.real(), ", ", x.imag(), ")");
}
template <typename T>
@@ -740,11 +742,11 @@ Status TensorsAreCloseImpl(const Tensor& x, const Tensor& y, double atol,
auto Ty = y.flat<T>();
for (int i = 0; i < Tx.size(); ++i) {
if (!IsClose(Tx(i), Ty(i), atol, rtol)) {
- return errors::InvalidArgument(strings::StrCat(
- i, "-th tensor element isn't close: ", Str(Tx(i)), " vs. ",
- Str(Ty(i)), ". x = ", x.DebugString(), "y = ", y.DebugString(),
- "atol = ", atol, " rtol = ", rtol,
- " tol = ", atol + rtol * Abs(Tx(i))));
+ return errors::InvalidArgument(
+ absl::StrCat(i, "-th tensor element isn't close: ", Str(Tx(i)),
+ " vs. ", Str(Ty(i)), ". x = ", x.DebugString(),
+ "y = ", y.DebugString(), "atol = ", atol,
+ " rtol = ", rtol, " tol = ", atol + rtol * Abs(Tx(i))));
}
}
return Status::OK();
@@ -756,7 +758,7 @@ Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) {
auto Ty = y.flat<T>();
for (int i = 0; i < Tx.size(); ++i) {
if (Tx(i) != Ty(i)) {
- return errors::InvalidArgument(strings::StrCat(
+ return errors::InvalidArgument(absl::StrCat(
i, "-th tensor element isn't equal: ", Tx(i), " vs. ", Ty(i),
". x = ", x.DebugString(), "y = ", y.DebugString()));
}
@@ -771,14 +773,14 @@ Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) {
Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol,
double rtol) {
if (a.dtype() != b.dtype()) {
- return errors::InvalidArgument(strings::StrCat(
+ return errors::InvalidArgument(absl::StrCat(
"Tensors have different types: ", DataTypeString(a.dtype()), " and ",
DataTypeString(b.dtype())));
}
if (!a.IsSameSize(b)) {
- return errors::InvalidArgument(strings::StrCat(
- "Tensors have different shapes: ", a.shape().DebugString(), " and ",
- b.shape().DebugString()));
+ return errors::InvalidArgument(
+ absl::StrCat("Tensors have different shapes: ", a.shape().DebugString(),
+ " and ", b.shape().DebugString()));
}
switch (a.dtype()) {
@@ -827,7 +829,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose(
}
string cpu_device =
- LocalDeviceToFullDeviceName(strings::StrCat(DEVICE_CPU, ":0"));
+ LocalDeviceToFullDeviceName(absl::StrCat(DEVICE_CPU, ":0"));
string test_device = LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr);
DeviceNameUtils::ParsedName parsed_name;
@@ -842,7 +844,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose(
std::vector<string> expected_inputs, test_inputs;
std::vector<string> expected_fetches, test_fetches;
Status status = builder.BuildGraph(
- strings::StrCat("test", num_tests_, "_expected"), cpu_device,
+ absl::StrCat("test", num_tests_, "_expected"), cpu_device,
/* use_jit= */ false, &graph, /* test_node_def= */ nullptr,
&expected_inputs, &expected_fetches);
if (!status.ok()) {
@@ -851,7 +853,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose(
}
NodeDef* node_def;
- status = builder.BuildGraph(strings::StrCat("test", num_tests_, "_test"),
+ status = builder.BuildGraph(absl::StrCat("test", num_tests_, "_test"),
test_device, tf_xla_test_use_jit, &graph,
&node_def, &test_inputs, &test_fetches);
if (!status.ok()) {
@@ -1884,7 +1886,8 @@ TEST_F(OpTest, DynamicStitch) {
for (int i = 0; i < n; ++i) {
TensorShape shape(index_dims[i]);
Tensor t = test::AsTensor<int32>(
- gtl::ArraySlice<int32>(indices, pos, shape.num_elements()), shape);
+ absl::Span<const int32>(indices).subspan(pos, shape.num_elements()),
+ shape);
builder.Input(t);
pos += t.NumElements();
}
diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py
index cea2ec816f..132c59c32c 100644
--- a/tensorflow/compiler/tests/reduce_ops_test.py
+++ b/tensorflow/compiler/tests/reduce_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import functools
import itertools
+from absl.testing import parameterized
import numpy as np
from tensorflow.compiler.tests import xla_test
@@ -30,22 +31,24 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
-class ReduceOpsTest(xla_test.XLATestCase):
-
+@parameterized.named_parameters(('32_bit_index', dtypes.int32),
+ ('64_bit_index', dtypes.int64))
+class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase):
def _testReduction(self,
tf_reduce_fn,
np_reduce_fn,
dtype,
test_inputs,
+ index_dtype,
rtol=1e-4,
atol=1e-4):
"""Tests that the output of 'tf_reduce_fn' matches numpy's output."""
for test_input in test_inputs:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
a = array_ops.placeholder(dtype)
- index = array_ops.placeholder(dtypes.int32)
+ index = array_ops.placeholder(index_dtype)
out = tf_reduce_fn(a, index)
result = sess.run(out, {a: test_input, index: [0]})
self.assertAllClose(
@@ -89,22 +92,23 @@ class ReduceOpsTest(xla_test.XLATestCase):
np.array([[False, True, False], [True, True, False]]),
]
- def testReduceSumF32(self):
- self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA)
+ def testReduceSumF32(self, index_dtype):
+ self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA,
+ index_dtype)
- def testReduceSumC64(self):
+ def testReduceSumC64(self, index_dtype):
self._testReduction(math_ops.reduce_sum, np.sum, np.complex64,
- self.COMPLEX_DATA)
+ self.COMPLEX_DATA, index_dtype)
- def testReduceProdF32(self):
+ def testReduceProdF32(self, index_dtype):
self._testReduction(math_ops.reduce_prod, np.prod, np.float32,
- self.REAL_DATA)
+ self.REAL_DATA, index_dtype)
- def testReduceProdC64(self):
+ def testReduceProdC64(self, index_dtype):
self._testReduction(math_ops.reduce_prod, np.prod, np.complex64,
- self.COMPLEX_DATA)
+ self.COMPLEX_DATA, index_dtype)
- def testReduceMin(self):
+ def testReduceMin(self, index_dtype):
def reference_min(dtype, inp, axis):
"""Wrapper around np.amin that returns +infinity for an empty input."""
@@ -119,9 +123,9 @@ class ReduceOpsTest(xla_test.XLATestCase):
[np.float32, np.int32, np.int64]):
self._testReduction(math_ops.reduce_min,
functools.partial(reference_min, dtype), dtype,
- self.REAL_DATA)
+ self.REAL_DATA, index_dtype)
- def testReduceMax(self):
+ def testReduceMax(self, index_dtype):
def reference_max(dtype, inp, axis):
"""Wrapper around np.amax that returns -infinity for an empty input."""
@@ -137,23 +141,25 @@ class ReduceOpsTest(xla_test.XLATestCase):
[np.float32, np.int32, np.int64]):
self._testReduction(math_ops.reduce_max,
functools.partial(reference_max, dtype), dtype,
- self.REAL_DATA)
+ self.REAL_DATA, index_dtype)
- def testReduceMeanF32(self):
+ def testReduceMeanF32(self, index_dtype):
# TODO(phawkins): mean on XLA currently returns 0 instead of NaN when
# reducing across zero inputs.
self._testReduction(math_ops.reduce_mean, np.mean, np.float32,
- self.NONEMPTY_REAL_DATA)
+ self.NONEMPTY_REAL_DATA, index_dtype)
- def testReduceMeanC64(self):
+ def testReduceMeanC64(self, index_dtype):
self._testReduction(math_ops.reduce_mean, np.mean, np.complex64,
- self.NONEMPTY_COMPLEX_DATA)
+ self.NONEMPTY_COMPLEX_DATA, index_dtype)
- def testReduceAll(self):
- self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA)
+ def testReduceAll(self, index_dtype):
+ self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA,
+ index_dtype)
- def testReduceAny(self):
- self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA)
+ def testReduceAny(self, index_dtype):
+ self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA,
+ index_dtype)
class ReduceOpPrecisionTest(xla_test.XLATestCase):
@@ -178,7 +184,7 @@ class ReduceOpPrecisionTest(xla_test.XLATestCase):
"""
for test_input in test_inputs:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
a = array_ops.placeholder(dtype)
index = array_ops.placeholder(dtypes.int32)
@@ -213,7 +219,7 @@ class ReduceOpPrecisionTest(xla_test.XLATestCase):
bf16_max = np.float32(dtypes.bfloat16.max)
f32_max = dtypes.float32.max
- value = min(bf16_max, f32_max - bf16_max)
+ value = min(bf16_max, f32_max - bf16_max) / 2
self._testReduceSum(
dtypes.bfloat16.as_numpy_dtype(value), dtypes.bfloat16.as_numpy_dtype,
itertools.permutations([bf16_max, value, bf16_max * (-1.0)], 3))
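The rewrite above converts ReduceOpsTest to an absl parameterized test: the class-level decorator runs every test method once per named parameter, threading the index dtype through to _testReduction. A standalone sketch of the mechanism (ExampleParameterizedTest is an illustrative name, not part of this diff):

from absl.testing import parameterized

from tensorflow.python.framework import dtypes


@parameterized.named_parameters(('32_bit_index', dtypes.int32),
                                ('64_bit_index', dtypes.int64))
class ExampleParameterizedTest(parameterized.TestCase):

  def testIndexDtype(self, index_dtype):
    # Runs twice, once per named parameter; the name ('32_bit_index' or
    # '64_bit_index') is appended to the generated test name.
    self.assertIn(index_dtype, (dtypes.int32, dtypes.int64))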
diff --git a/tensorflow/compiler/tests/reduce_window_test.py b/tensorflow/compiler/tests/reduce_window_test.py
index c69b6837b0..ff20ea3f42 100644
--- a/tensorflow/compiler/tests/reduce_window_test.py
+++ b/tensorflow/compiler/tests/reduce_window_test.py
@@ -32,7 +32,7 @@ class ReduceWindowTest(xla_test.XLATestCase):
"""Test cases for xla.reduce_window."""
def _reduce_window(self, operand, init, reducer, **kwargs):
- with self.test_session():
+ with self.cached_session():
placeholder = array_ops.placeholder(operand.dtype)
with self.test_scope():
output = xla.reduce_window(placeholder, init, reducer, **kwargs)
diff --git a/tensorflow/compiler/tests/reshape_op_test.py b/tensorflow/compiler/tests/reshape_op_test.py
new file mode 100644
index 0000000000..96e0b07475
--- /dev/null
+++ b/tensorflow/compiler/tests/reshape_op_test.py
@@ -0,0 +1,50 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slicing."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import googletest
+
+
+class ReshapeTest(xla_test.XLATestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(('32_bit_index', dtypes.int32),
+ ('64_bit_index', dtypes.int64))
+ def testBasic(self, index_dtype):
+ for dtype in self.numeric_types:
+ with self.cached_session():
+ i = array_ops.placeholder(dtype, shape=[2, 3])
+ with self.test_scope():
+ shape = constant_op.constant([3, 2], dtype=index_dtype)
+ o = array_ops.reshape(i, shape)
+ params = {
+ i: [[1, 2, 3], [4, 5, 6]],
+ }
+ result = o.eval(feed_dict=params)
+
+ self.assertAllEqual([[1, 2], [3, 4], [5, 6]], result)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/compiler/tests/reverse_ops_test.py b/tensorflow/compiler/tests/reverse_ops_test.py
index d01c676e7c..392290fd92 100644
--- a/tensorflow/compiler/tests/reverse_ops_test.py
+++ b/tensorflow/compiler/tests/reverse_ops_test.py
@@ -32,33 +32,40 @@ class ReverseOpsTest(xla_test.XLATestCase):
def testReverseOneDim(self):
shape = (7, 5, 9, 11)
- for revdim in range(len(shape)):
+ for revdim in range(-len(shape), len(shape)):
self._AssertReverseEqual([revdim], shape)
def testReverseMoreThanOneDim(self):
shape = (7, 5, 9, 11)
+    # The offset is used to test various (but not all) combinations of
+    # negative and positive axis indices that are guaranteed not to collide
+    # at the same index.
for revdims in itertools.chain.from_iterable(
- itertools.combinations(range(len(shape)), k)
- for k in range(2, len(shape)+1)):
+ itertools.combinations(range(-offset,
+ len(shape) - offset), k)
+ for k in range(2,
+ len(shape) + 1)
+ for offset in range(0, len(shape))):
self._AssertReverseEqual(revdims, shape)
def _AssertReverseEqual(self, revdims, shape):
np.random.seed(120)
pval = np.random.randint(0, 100, size=shape).astype(float)
- with self.test_session():
+ with self.cached_session():
with self.test_scope():
p = array_ops.placeholder(dtypes.int32, shape=shape)
axis = constant_op.constant(
np.array(revdims, dtype=np.int32),
- shape=(len(revdims),), dtype=dtypes.int32)
+ shape=(len(revdims),),
+ dtype=dtypes.int32)
rval = array_ops.reverse(p, axis).eval({p: pval})
slices = [
- slice(-1, None, -1) if d in revdims else slice(None)
- for d in range(len(shape))]
- self.assertEqual(
- pval[slices].flatten().tolist(),
- rval.flatten().tolist())
+ slice(-1, None, -1)
+ if d in revdims or d - len(shape) in revdims else slice(None)
+ for d in range(len(shape))
+ ]
+ self.assertEqual(pval[slices].flatten().tolist(), rval.flatten().tolist())
if __name__ == '__main__':
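testReverseOneDim and testReverseMoreThanOneDim now sweep negative axis indices as well; axis -k addresses the same dimension as axis len(shape) - k, which is why the expected-slice construction above checks both d in revdims and d - len(shape) in revdims. A numpy sketch of the equivalence (numpy is used here for illustration only):

import numpy as np

x = np.arange(6).reshape(2, 3)
# Reversing axis -1 is the same as reversing axis 1 on a rank-2 array.
assert (np.flip(x, axis=-1) == np.flip(x, axis=1)).all()
# The test builds the equivalent reversal out of extended slices:
assert (x[:, ::-1] == np.flip(x, axis=-1)).all()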
diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py
index ccfa630016..60c2337743 100644
--- a/tensorflow/compiler/tests/reverse_sequence_op_test.py
+++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py
@@ -35,7 +35,7 @@ class ReverseSequenceTest(xla_test.XLATestCase):
seq_lengths,
truth,
expected_err_re=None):
- with self.test_session():
+ with self.cached_session():
p = array_ops.placeholder(dtypes.as_dtype(x.dtype))
lengths = array_ops.placeholder(dtypes.as_dtype(seq_lengths.dtype))
with self.test_scope():
diff --git a/tensorflow/compiler/tests/rmsprop_test.py b/tensorflow/compiler/tests/rmsprop_test.py
index ff8bbac911..8840a1329a 100644
--- a/tensorflow/compiler/tests/rmsprop_test.py
+++ b/tensorflow/compiler/tests/rmsprop_test.py
@@ -55,7 +55,7 @@ class RmspropTest(xla_test.XLATestCase):
def testBasic(self):
for dtype in self.float_types:
for centered in [False, True]:
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
# Initialize variables for numpy implementation.
var0_np = np.array([1.0, 2.0], dtype=dtype)
grads0_np = np.array([0.1, 0.1], dtype=dtype)
diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py
index 4292352e76..897db384b7 100644
--- a/tensorflow/compiler/tests/scan_ops_test.py
+++ b/tensorflow/compiler/tests/scan_ops_test.py
@@ -78,7 +78,7 @@ class CumsumTest(xla_test.XLATestCase):
def _compare(self, x, axis, exclusive, reverse):
np_out = handle_options(np.cumsum, x, axis, exclusive, reverse)
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
p = array_ops.placeholder(x.dtype)
tf_out = math_ops.cumsum(p, axis, exclusive, reverse).eval(
feed_dict={p: x})
@@ -100,7 +100,7 @@ class CumsumTest(xla_test.XLATestCase):
for dtype in self.valid_dtypes:
x = np.arange(1, 6).reshape([5]).astype(dtype)
for axis_dtype in self.axis_dtypes():
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
p = array_ops.placeholder(x.dtype)
axis = constant_op.constant(0, axis_dtype)
math_ops.cumsum(p, axis).eval(feed_dict={p: x})
@@ -131,7 +131,7 @@ class CumsumTest(xla_test.XLATestCase):
def testInvalidAxis(self):
x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
input_tensor = ops.convert_to_tensor(x)
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError,
@@ -156,7 +156,7 @@ class CumprodTest(xla_test.XLATestCase):
def _compare(self, x, axis, exclusive, reverse):
np_out = handle_options(np.cumprod, x, axis, exclusive, reverse)
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
p = array_ops.placeholder(x.dtype)
prod = math_ops.cumprod(p, axis, exclusive, reverse)
tf_out = prod.eval(feed_dict={p: x})
@@ -178,7 +178,7 @@ class CumprodTest(xla_test.XLATestCase):
for dtype in self.valid_dtypes:
x = np.arange(1, 6).reshape([5]).astype(dtype)
for axis_dtype in self.axis_dtypes():
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
p = array_ops.placeholder(x.dtype)
axis = constant_op.constant(0, axis_dtype)
math_ops.cumprod(x, axis).eval(feed_dict={p: x})
@@ -209,7 +209,7 @@ class CumprodTest(xla_test.XLATestCase):
def testInvalidAxis(self):
x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
input_tensor = ops.convert_to_tensor(x)
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError,
diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py
index f606f88545..693f8513bc 100644
--- a/tensorflow/compiler/tests/scatter_nd_op_test.py
+++ b/tensorflow/compiler/tests/scatter_nd_op_test.py
@@ -119,7 +119,7 @@ class ScatterNdTest(xla_test.XLATestCase):
self._VariableRankTest(np_scatter, tf_scatter, vtype, itype)
def _runScatterNd(self, indices, updates, shape):
- with self.test_session():
+ with self.cached_session():
updates_placeholder = array_ops.placeholder(updates.dtype)
indices_placeholder = array_ops.placeholder(indices.dtype)
with self.test_scope():
diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py
index 772c20fd42..287bb0d84e 100644
--- a/tensorflow/compiler/tests/segment_reduction_ops_test.py
+++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py
@@ -32,7 +32,7 @@ class SegmentReductionOpsTest(xla_test.XLATestCase):
"""Test cases for segment reduction ops."""
def _segmentReduction(self, op, data, indices, num_segments):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
d = array_ops.placeholder(data.dtype, shape=data.shape)
if isinstance(indices, int):
i = array_ops.placeholder(np.int32, shape=[])
diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py
index 6c4890565d..2c611a959e 100644
--- a/tensorflow/compiler/tests/slice_ops_test.py
+++ b/tensorflow/compiler/tests/slice_ops_test.py
@@ -29,7 +29,7 @@ class SliceTest(xla_test.XLATestCase):
def test1D(self):
for dtype in self.numeric_types:
- with self.test_session():
+ with self.cached_session():
i = array_ops.placeholder(dtype, shape=[10])
with self.test_scope():
o = array_ops.slice(i, [2], [4])
@@ -40,9 +40,22 @@ class SliceTest(xla_test.XLATestCase):
self.assertAllEqual([2, 3, 4, 5], result)
+ def testZeroSlice(self):
+ for dtype in self.numeric_types:
+ with self.cached_session():
+ i = array_ops.placeholder(dtype, shape=[2])
+ with self.test_scope():
+ o = array_ops.slice(i, [0], [0])
+ params = {
+ i: [0, 1],
+ }
+ result = o.eval(feed_dict=params)
+
+ self.assertAllEqual([], result)
+
def test3D(self):
for dtype in self.numeric_types:
- with self.test_session():
+ with self.cached_session():
i = array_ops.placeholder(dtype, shape=[3, 3, 10])
with self.test_scope():
o = array_ops.slice(i, [1, 2, 2], [1, 1, 4])
@@ -64,7 +77,7 @@ class SliceTest(xla_test.XLATestCase):
def test3DWithDynamicBegin(self):
"""Tests a slice where the start offset is not known at compile time."""
for dtype in self.numeric_types:
- with self.test_session():
+ with self.cached_session():
i = array_ops.placeholder(dtype, shape=[3, 3, 10])
begin = array_ops.placeholder(dtypes.int32, shape=[3])
with self.test_scope():
@@ -88,7 +101,7 @@ class SliceTest(xla_test.XLATestCase):
def test3DWithDynamicBeginAndNegativeSize(self):
"""Tests a slice where `begin` is fed dynamically and `size` contains -1."""
for dtype in self.numeric_types:
- with self.test_session():
+ with self.cached_session():
i = array_ops.placeholder(dtype, shape=[3, 3, 10])
begin = array_ops.placeholder(dtypes.int32, shape=[3])
with self.test_scope():
@@ -114,7 +127,7 @@ class StridedSliceTest(xla_test.XLATestCase):
def test1D(self):
for dtype in self.numeric_types:
- with self.test_session():
+ with self.cached_session():
i = array_ops.placeholder(dtype, shape=[10])
with self.test_scope():
o = array_ops.strided_slice(i, [2], [6], [2])
@@ -127,7 +140,7 @@ class StridedSliceTest(xla_test.XLATestCase):
def test1DNegativeStride(self):
for dtype in self.numeric_types:
- with self.test_session():
+ with self.cached_session():
i = array_ops.placeholder(dtype, shape=[10])
with self.test_scope():
o = array_ops.strided_slice(i, [6], [2], [-2])
@@ -140,7 +153,7 @@ class StridedSliceTest(xla_test.XLATestCase):
def test2DDegenerate(self):
for dtype in self.numeric_types:
- with self.test_session():
+ with self.cached_session():
i = array_ops.placeholder(dtype, shape=[2, 3])
with self.test_scope():
o = array_ops.strided_slice(i, [-1, 0], [0, 3])
@@ -154,7 +167,7 @@ class StridedSliceTest(xla_test.XLATestCase):
def test2DDegenerateNegativeStride(self):
for dtype in self.numeric_types:
- with self.test_session():
+ with self.cached_session():
i = array_ops.placeholder(dtype, shape=[2, 3])
with self.test_scope():
o = array_ops.strided_slice(i, [0, 0], [-1, 3], [-1, 1])
@@ -168,7 +181,7 @@ class StridedSliceTest(xla_test.XLATestCase):
def test3D(self):
for dtype in self.numeric_types:
- with self.test_session():
+ with self.cached_session():
i = array_ops.placeholder(dtype, shape=[3, 3, 10])
with self.test_scope():
o = array_ops.strided_slice(i, [0, 2, 2], [2, 3, 6], [1, 1, 2])
@@ -189,7 +202,7 @@ class StridedSliceTest(xla_test.XLATestCase):
def test3DNegativeStride(self):
for dtype in self.numeric_types:
- with self.test_session():
+ with self.cached_session():
i = array_ops.placeholder(dtype, shape=[3, 4, 10])
with self.test_scope():
o = array_ops.strided_slice(i, [2, 2, 6], [0, 0, 2], [-1, -1, -2])
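The strided-slice tests above mirror Python's extended slicing semantics, including negative strides: strided_slice(i, [6], [2], [-2]) on a length-10 vector selects indices 6 and 4, exactly like i[6:2:-2]. A numpy sketch of the correspondence:

import numpy as np

i = np.arange(10)
# tf.strided_slice(i, [6], [2], [-2]) follows the same begin:end:stride
# rules as Python slicing, so the expected values are [6, 4].
assert (i[6:2:-2] == np.array([6, 4])).all()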
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index 7ff01be3cb..51c04b5c47 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -32,7 +32,7 @@ from tensorflow.python.platform import test
class XlaSortOpTest(xla_test.XLATestCase):
def _assertOpOutputMatchesExpected(self, op, args, expected):
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
placeholders = [
array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
@@ -131,7 +131,7 @@ class XlaSortOpTest(xla_test.XLATestCase):
if bfloat16 not in self.numeric_types:
return
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p = array_ops.placeholder(dtypes.bfloat16)
with self.test_scope():
topk = nn_ops.top_k(p, k=4)
@@ -153,7 +153,7 @@ class XlaSortOpTest(xla_test.XLATestCase):
if bfloat16 not in self.numeric_types:
return
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p = array_ops.placeholder(dtypes.bfloat16)
with self.test_scope():
topk = nn_ops.top_k(p, k=6)
diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py
index c685bc548f..33b84cec71 100644
--- a/tensorflow/compiler/tests/spacetobatch_op_test.py
+++ b/tensorflow/compiler/tests/spacetobatch_op_test.py
@@ -72,7 +72,7 @@ class SpaceToBatchTest(xla_test.XLATestCase):
"""Tests input-output pairs for the SpaceToBatch and BatchToSpace ops."""
def _testPad(self, inputs, paddings, block_size, outputs):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
for dtype in self.float_types:
# outputs = space_to_batch(inputs)
placeholder = array_ops.placeholder(dtype)
@@ -155,7 +155,7 @@ class SpaceToBatchNDTest(xla_test.XLATestCase):
def _testPad(self, inputs, block_shape, paddings, outputs):
block_shape = np.array(block_shape)
paddings = np.array(paddings).reshape((len(block_shape), 2))
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
for dtype in self.float_types:
# TODO(b/68813416): Skip bfloat16's as the input type for direct is
# float32 and results in a mismatch, while making testDirect provide the
diff --git a/tensorflow/compiler/tests/sparse_to_dense_op_test.py b/tensorflow/compiler/tests/sparse_to_dense_op_test.py
index 3db8101c4b..07afd1ab3f 100644
--- a/tensorflow/compiler/tests/sparse_to_dense_op_test.py
+++ b/tensorflow/compiler/tests/sparse_to_dense_op_test.py
@@ -45,32 +45,32 @@ def _SparseToDense(sparse_indices,
class SparseToDenseTest(xla_test.XLATestCase):
def testInt(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
tf_ans = _SparseToDense([1, 3], [5], 1, 0)
np_ans = np.array([0, 1, 0, 1, 0]).astype(np.int32)
self.assertAllClose(np_ans, tf_ans)
def testFloat(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
tf_ans = _SparseToDense([1, 3], [5], 1.0, 0.0)
np_ans = np.array([0, 1, 0, 1, 0]).astype(np.float32)
self.assertAllClose(np_ans, tf_ans)
def testSetValue(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
tf_ans = _SparseToDense([1, 3], [5], [1, 2], -1)
np_ans = np.array([-1, 1, -1, 2, -1]).astype(np.int32)
self.assertAllClose(np_ans, tf_ans)
def testSetSingleValue(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
tf_ans = _SparseToDense([1, 3], [5], 1, -1)
np_ans = np.array([-1, 1, -1, 1, -1]).astype(np.int32)
self.assertAllClose(np_ans, tf_ans)
def test2d(self):
# pylint: disable=bad-whitespace
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
tf_ans = _SparseToDense([[1, 3], [2, 0]], [3, 4], 1, -1)
np_ans = np.array([[-1, -1, -1, -1],
[-1, -1, -1, 1],
@@ -78,12 +78,12 @@ class SparseToDenseTest(xla_test.XLATestCase):
self.assertAllClose(np_ans, tf_ans)
def testZeroDefault(self):
- with self.test_session():
+ with self.cached_session():
x = sparse_ops.sparse_to_dense(2, [4], 7).eval()
self.assertAllEqual(x, [0, 0, 7, 0])
def test3d(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
tf_ans = _SparseToDense([[1, 3, 0], [2, 0, 1]], [3, 4, 2], 1, -1)
np_ans = np.ones((3, 4, 2), dtype=np.int32) * -1
np_ans[1, 3, 0] = 1
@@ -91,25 +91,25 @@ class SparseToDenseTest(xla_test.XLATestCase):
self.assertAllClose(np_ans, tf_ans)
def testBadShape(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"):
_SparseToDense([1, 3], [[5], [3]], 1, -1)
def testBadValue(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
with self.assertRaisesOpError(
r"sparse_values has incorrect shape \[2,1\], "
r"should be \[\] or \[2\]"):
_SparseToDense([1, 3], [5], [[5], [3]], -1)
def testBadNumValues(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
with self.assertRaisesOpError(
r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"):
_SparseToDense([1, 3], [5], [1, 2, 3], -1)
def testBadDefault(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
with self.assertRaisesOpError("default_value should be a scalar"):
_SparseToDense([1, 3], [5], [1, 2], [0])
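The _SparseToDense helper used throughout this file wraps sparse_ops.sparse_to_dense: given indices, an output shape, values (scalar or per-index), and a default, it materializes the dense tensor. A numpy sketch of what testInt above expects:

import numpy as np

dense = np.full([5], 0, dtype=np.int32)  # default_value = 0, output_shape = [5]
dense[[1, 3]] = 1                        # sparse_indices = [1, 3], sparse_value = 1
# dense is now [0, 1, 0, 1, 0], the np_ans checked in testInt.
assert (dense == np.array([0, 1, 0, 1, 0])).all()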
diff --git a/tensorflow/compiler/tests/stack_ops_test.py b/tensorflow/compiler/tests/stack_ops_test.py
index b7dd787fef..720595a159 100644
--- a/tensorflow/compiler/tests/stack_ops_test.py
+++ b/tensorflow/compiler/tests/stack_ops_test.py
@@ -31,7 +31,7 @@ from tensorflow.python.platform import test
class StackOpTest(xla_test.XLATestCase):
def testStackPushPop(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
size = array_ops.placeholder(dtypes.int32)
v = array_ops.placeholder(dtypes.float32)
h = gen_data_flow_ops.stack_v2(size, dtypes.float32, stack_name="foo")
@@ -41,7 +41,7 @@ class StackOpTest(xla_test.XLATestCase):
self.assertAllClose([[4.0, 5.0]], c1.eval({size: 5, v: [[4.0, 5.0]]}))
def testStackPushPopSwap(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
a = np.arange(2000)
x = array_ops.placeholder(dtypes.float32)
h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo")
@@ -51,7 +51,7 @@ class StackOpTest(xla_test.XLATestCase):
self.assertAllClose(a, c1.eval({x: a}))
def testMultiStack(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
v = array_ops.placeholder(dtypes.float32)
h1 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo")
c1 = gen_data_flow_ops.stack_push_v2(h1, v)
@@ -66,7 +66,7 @@ class StackOpTest(xla_test.XLATestCase):
def testSameNameStacks(self):
"""Different stacks with the same name do not interfere."""
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
v1 = array_ops.placeholder(dtypes.float32)
v2 = array_ops.placeholder(dtypes.float32)
h1 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo")
@@ -84,14 +84,14 @@ class StackOpTest(xla_test.XLATestCase):
self.assertAllClose(out2, 5.0)
def testCloseStack(self):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
size = array_ops.placeholder(dtypes.int32)
h = gen_data_flow_ops.stack_v2(size, dtypes.float32, stack_name="foo")
c1 = gen_data_flow_ops.stack_close_v2(h)
sess.run(c1, {size: 5})
def testPushCloseStack(self):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
v = array_ops.placeholder(dtypes.float32)
h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo")
c = gen_data_flow_ops.stack_push_v2(h, v)
diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py
index d162675ef8..1bea7d9355 100644
--- a/tensorflow/compiler/tests/stateless_random_ops_test.py
+++ b/tensorflow/compiler/tests/stateless_random_ops_test.py
@@ -38,7 +38,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
def testDeterminism(self):
# Stateless values should be equal iff the seeds are equal (roughly)
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
seeds = [(x, y) for x in range(5) for y in range(5)] * 3
for stateless_op in [
@@ -55,7 +55,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
self.assertEqual(s0 == s1, np.all(v0 == v1))
def testRandomUniformIsInRange(self):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
for dtype in self._random_types():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
x = stateless.stateless_random_uniform(
@@ -74,7 +74,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
def testDistributionOfStatelessRandomUniform(self):
"""Use Pearson's Chi-squared test to test for uniformity."""
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
for dtype in self._random_types():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
n = 1000
@@ -88,7 +88,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
self.assertTrue(self._chi_squared(y, 10) < 16.92)
def testRandomNormalIsFinite(self):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
for dtype in self._random_types():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
x = stateless.stateless_random_uniform(
@@ -111,7 +111,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
def testDistributionOfStatelessRandomNormal(self):
"""Use Anderson-Darling test to test distribution appears normal."""
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
for dtype in self._random_types():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
n = 1000
@@ -126,7 +126,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
def testTruncatedNormalIsInRange(self):
# TODO(b/34339814): implement inverse erf support for non-F32 types.
for dtype in [dtypes.float32]:
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
n = 10000000
x = stateless.stateless_truncated_normal(
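testDistributionOfStatelessRandomUniform bins the samples and applies Pearson's chi-squared test; 16.92 is the critical value for 9 degrees of freedom (10 buckets) at p = 0.05. A standalone numpy sketch of the statistic, assuming uniform samples on [0, 1) (chi_squared below is an illustrative re-implementation, not the test's _chi_squared helper):

import numpy as np

def chi_squared(x, bins):
  # Pearson's statistic: sum over buckets of (observed - expected)^2 / expected.
  counts, _ = np.histogram(x, bins=bins, range=(0.0, 1.0))
  expected = float(len(x)) / bins
  return np.sum((counts - expected) ** 2 / expected)

samples = np.random.uniform(size=1000)
stat = chi_squared(samples, 10)
# Under true uniformity, stat exceeds 16.92 only about 5% of the time.
print(stat)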
diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py
index f332aa2e9b..78244d0b36 100644
--- a/tensorflow/compiler/tests/tensor_array_ops_test.py
+++ b/tensorflow/compiler/tests/tensor_array_ops_test.py
@@ -44,7 +44,7 @@ def _make_converter(dtype):
class TensorArrayTest(xla_test.XLATestCase):
def testTensorArrayWriteRead(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -66,7 +66,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual([], flow_val.shape)
def _testTensorArrayWritePack(self, tf_dtype):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3)
@@ -86,7 +86,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayWritePack(dtype)
def testEmptyTensorArrayPack(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
@@ -100,7 +100,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual([3, 0, 1], c0.eval().shape)
def _testTensorArrayWriteConcat(self, tf_dtype):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3)
@@ -121,7 +121,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayWriteConcat(dtype)
def _testTensorArrayUnpackRead(self, tf_dtype):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3)
@@ -176,7 +176,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayUnpackReadMaybeLegacy()
def _testTensorArraySplitRead(self, tf_dtype):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3)
@@ -228,7 +228,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArraySplitRead(dtype)
def testTensorGradArrayWriteRead(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -261,7 +261,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual([[-2.0]], g_d2)
def testTensorGradArrayDynamicWriteRead(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -300,7 +300,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual(3, g_vs)
def testTensorGradAccessTwiceReceiveSameObject(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3,
element_shape=[1, 2])
@@ -317,7 +317,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual([[4.0, 5.0]], d_r1_0)
def testTensorArrayWriteWrongIndexOrDataTypeFails(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
@@ -331,7 +331,7 @@ class TensorArrayTest(xla_test.XLATestCase):
# the first type, but try to read the other type.
if len(self.float_types) > 1:
dtype1, dtype2 = list(self.float_types)[:2]
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtype1, tensor_array_name="foo", size=3)
@@ -347,7 +347,7 @@ class TensorArrayTest(xla_test.XLATestCase):
w0.read(1)
def testTensorArraySplitIncompatibleShapesFails(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -379,7 +379,7 @@ class TensorArrayTest(xla_test.XLATestCase):
ta.split([1.0], [1]).flow.eval()
def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtype, tensor_array_name="foo", size=3, infer_shape=False)
@@ -410,7 +410,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayWriteGradientAddMultipleAdds(dtype)
def testMultiTensorArray(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
h1 = tensor_array_ops.TensorArray(
size=1, dtype=dtypes.float32, tensor_array_name="foo")
w1 = h1.write(0, 4.0)
@@ -425,7 +425,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllClose(9.0, r.eval())
def _testTensorArrayGradientWriteReadType(self, dtype):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.as_dtype(dtype),
tensor_array_name="foo",
@@ -478,7 +478,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayGradientWriteReadType(dtype)
def _testTensorArrayGradientWritePackConcatAndRead(self):
- with self.test_session() as sess, self.test_scope():
+ with self.cached_session() as sess, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -513,7 +513,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayGradientWritePackConcatAndRead()
def testTensorArrayReadTwice(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
ta_readtwice = tensor_array_ops.TensorArray(
@@ -529,7 +529,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual([1.0, -1.0], r1_readtwice.eval())
def _testTensorArrayGradientUnpackRead(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -557,7 +557,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayGradientUnpackRead()
def testTensorArrayGradientSplitConcat(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=2)
@@ -581,21 +581,21 @@ class TensorArrayTest(xla_test.XLATestCase):
grad_vals[0])
def testCloseTensorArray(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
c1 = ta.close()
session.run(c1)
def testSizeTensorArray(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
s = ta.size()
self.assertAllEqual(3, s.eval())
def testWriteCloseTensorArray(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -608,7 +608,7 @@ class TensorArrayTest(xla_test.XLATestCase):
# TODO(phawkins): implement while loops.
# def _testWhileLoopWritePackGradients(self, dynamic_size, dtype):
# np_dtype = dtype.as_numpy_dtype
- # with self.test_session() as session, self.test_scope():
+ # with self.cached_session() as session, self.test_scope():
# v0 = array_ops.identity(np.arange(3 * 5, dtype=np_dtype).reshape(3, 5))
# var = variables.Variable(np.arange(100, 105, dtype=np_dtype))
# state0 = array_ops.identity(np.array([1] * 5, dtype=np_dtype))
@@ -692,7 +692,7 @@ class TensorArrayTest(xla_test.XLATestCase):
# dynamic_size=True, dtype=dtypes.float32)
# def testGradSerialTwoLoops(self):
- # with self.test_session(), self.test_scope():
+ # with self.cached_session(), self.test_scope():
# num_steps = 100
# acc = tensor_array_ops.TensorArray(
# dtype=dtypes.float32,
@@ -725,7 +725,7 @@ class TensorArrayTest(xla_test.XLATestCase):
# self.assertAllClose(31.0, grad.eval())
def testSumOfTwoReadVariablesWithoutRepeatGrad(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
a = array_ops.identity(
np.arange(
3 * 5, dtype=np.float32).reshape(3, 5) + 1)
@@ -757,7 +757,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual(joint_grad_b_t, g0)
def testWriteShape(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
c0 = constant_op.constant([4.0, 5.0])
@@ -781,7 +781,7 @@ class TensorArrayTest(xla_test.XLATestCase):
w0.write(0, c2)
def testPartlyUnknownShape(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=6)
@@ -821,7 +821,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual([5, 4, 2, 3], r5.get_shape().as_list())
def _testUnpackShape(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -846,7 +846,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testUnpackShape()
def testSplitShape(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -867,7 +867,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape())
def testWriteUnknownShape(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -879,7 +879,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape())
def _testGradientWhenNotAllComponentsRead(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
x = constant_op.constant([2.0, 3.0])
w = ta.unstack(x)
@@ -893,7 +893,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testGradientWhenNotAllComponentsRead()
def _testTensorArrayEvalEmpty(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, size=0, infer_shape=False)
with self.assertRaisesOpError(
@@ -906,7 +906,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayEvalEmpty()
def _testTensorArrayEvalEmptyWithDefault(self):
- with self.test_session(), self.test_scope():
+ with self.cached_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, size=0, infer_shape=True)
self.assertEqual(0, ta.size().eval())
@@ -921,7 +921,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self._testTensorArrayEvalEmptyWithDefault()
def testTensorArrayScatterReadAndGradients(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -946,7 +946,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0])
def testTensorArrayWriteGatherAndGradients(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32,
tensor_array_name="foo",
@@ -974,7 +974,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual(expected_grad, grad_vals[0])
def testTensorArrayIdentity(self):
- with self.test_session() as session, self.test_scope():
+ with self.cached_session() as session, self.test_scope():
ta0 = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2,
infer_shape=False)
ta1 = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=4,
diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py
index effa5a59fe..55a992195f 100644
--- a/tensorflow/compiler/tests/ternary_ops_test.py
+++ b/tensorflow/compiler/tests/ternary_ops_test.py
@@ -31,7 +31,7 @@ from tensorflow.python.platform import googletest
class TernaryOpsTest(xla_test.XLATestCase):
def _testTernary(self, op, a, b, c, expected):
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b")
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 91f876fa23..dd29ef34ce 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -65,7 +65,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
rtol: relative tolerance for equality test.
atol: absolute tolerance for equality test.
"""
- with self.test_session() as session:
+ with self.cached_session() as session:
with self.test_scope():
pinp = array_ops.placeholder(
dtypes.as_dtype(inp.dtype), inp.shape, name="a")
@@ -202,7 +202,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
# Disable float16 testing for now
if dtype != np.float16:
x = np.arange(-10, 10, 1).astype(dtype)
- with self.test_session() as session:
+ with self.cached_session() as session:
erf_x = session.run(math_ops.erf(x))
erfc_x = session.run(math_ops.erfc(x))
@@ -403,6 +403,11 @@ class UnaryOpsTest(xla_test.XLATestCase):
self._assertOpOutputMatchesExpected(
math_ops.lgamma,
+ np.array(0.5, dtype=dtype),
+ expected=np.array(np.log(np.pi) / 2, dtype=dtype))
+
+ self._assertOpOutputMatchesExpected(
+ math_ops.lgamma,
np.array(
[[1, 2, 3], [4, 5, 6], [1 / 2, 3 / 2, 5 / 2],
[-3 / 2, -7 / 2, -11 / 2]],
@@ -425,6 +430,19 @@ class UnaryOpsTest(xla_test.XLATestCase):
],
dtype=dtype))
+ # The actual result is complex. Take the real part.
+ self._assertOpOutputMatchesExpected(
+ math_ops.lgamma,
+ np.array([-1 / 2, -5 / 2, -9 / 2], dtype=dtype),
+ expected=np.array(
+ [
+ np.log(np.pi) / 2 + np.log(2),
+ np.log(np.pi) / 2 - np.log(15) + np.log(8),
+ np.log(np.pi) / 2 - np.log(945) + np.log(32),
+ ],
+ dtype=dtype),
+ atol=1e-4)
+
self._assertOpOutputMatchesExpected(
math_ops.digamma,
np.array(
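The expectations hard-coded in the new lgamma cases follow from the reflection formula log|Γ(z)| = log(π/|sin(πz)|) − log|Γ(1−z)|. A quick NumPy check of the three negative half-integer points, as a sketch assuming scipy.special.gammaln is available (the test itself does not use SciPy):

    import numpy as np
    from scipy.special import gammaln  # log|Gamma|; assumed available here

    cases = [(-1 / 2, np.log(np.pi) / 2 + np.log(2)),
             (-5 / 2, np.log(np.pi) / 2 - np.log(15) + np.log(8)),
             (-9 / 2, np.log(np.pi) / 2 - np.log(945) + np.log(32))]
    for z, expected in cases:
        # Reflection: log|Gamma(z)| = log(pi / |sin(pi*z)|) - log|Gamma(1-z)|.
        reflected = np.log(np.pi / abs(np.sin(np.pi * z))) - gammaln(1 - z)
        assert np.isclose(reflected, expected) and np.isclose(gammaln(z), expected)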
diff --git a/tensorflow/compiler/tests/while_test.py b/tensorflow/compiler/tests/while_test.py
index b637cf31cf..4ee144beb7 100644
--- a/tensorflow/compiler/tests/while_test.py
+++ b/tensorflow/compiler/tests/while_test.py
@@ -43,7 +43,7 @@ class WhileTest(xla_test.XLATestCase):
def loop_cond(step):
return step < 10
- with self.test_session() as sess:
+ with self.cached_session() as sess:
init_index = array_ops.placeholder(dtypes.int32, [])
with self.test_scope():
loop_outputs = xla.while_loop([init_index], loop_cond, loop_body)
@@ -65,7 +65,7 @@ class WhileTest(xla_test.XLATestCase):
del rsum
return step < 10
- with self.test_session() as sess:
+ with self.cached_session() as sess:
init_index = array_ops.placeholder(dtypes.int32, [])
init_sum = array_ops.placeholder(dtypes.float32, [])
with self.test_scope():
@@ -91,7 +91,7 @@ class WhileTest(xla_test.XLATestCase):
del rsum
return step < 10
- with self.test_session() as sess:
+ with self.cached_session() as sess:
init_index = array_ops.placeholder(dtypes.int32, [])
init_sum = array_ops.placeholder(dtypes.complex64, [])
with self.test_scope():
@@ -117,7 +117,7 @@ class WhileTest(xla_test.XLATestCase):
del x
return step < 10
- with self.test_session() as sess:
+ with self.cached_session() as sess:
init_index = array_ops.placeholder(dtypes.int32, [])
with self.test_scope():
loop_outputs = xla.while_loop([init_index, 42], loop_cond, loop_body)
diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py
index 85084bb124..28d61fb07d 100644
--- a/tensorflow/compiler/tests/xla_device_test.py
+++ b/tensorflow/compiler/tests/xla_device_test.py
@@ -37,7 +37,7 @@ class XlaDeviceTest(xla_test.XLATestCase):
[16384, 1], [1, 16384], [1, 20000, 1, 1]]
for dtype in self.numeric_types:
for shape in shapes:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with ops.device("CPU"):
x = array_ops.placeholder(dtype, shape)
with self.test_scope():
@@ -58,7 +58,7 @@ class XlaDeviceTest(xla_test.XLATestCase):
])
shape = (10, 10)
for unsupported_dtype in test_types - self.all_types:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with ops.device("CPU"):
x = array_ops.placeholder(unsupported_dtype, shape)
with self.test_scope():
@@ -78,7 +78,7 @@ class XlaDeviceTest(xla_test.XLATestCase):
pass
def testControlTrigger(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.test_scope():
x = gen_control_flow_ops.control_trigger()
sess.run(x)
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
new file mode 100644
index 0000000000..0f3843dc1e
--- /dev/null
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -0,0 +1,301 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for XLA op wrappers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.compiler.tf2xla.python import xla
+from tensorflow.compiler.xla import xla_data_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import googletest
+
+
+class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
+
+ def _assertOpOutputMatchesExpected(self, op, args, expected,
+ equality_fn=None):
+ with self.cached_session() as session:
+ with self.test_scope():
+ placeholders = [
+ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
+ for arg in args
+ ]
+ feeds = {placeholders[i]: args[i] for i in range(0, len(args))}
+ output = op(*placeholders)
+ result = session.run(output, feeds)
+ if not equality_fn:
+ equality_fn = self.assertAllClose
+ equality_fn(result, expected, rtol=1e-3)
+
+ def testAdd(self):
+ for dtype in self.numeric_types:
+ self._assertOpOutputMatchesExpected(
+ xla.add,
+ args=(np.array([1, 2, 3], dtype=dtype),
+ np.array([4, 5, 6], dtype=dtype)),
+ expected=np.array([5, 7, 9], dtype=dtype))
+
+ self._assertOpOutputMatchesExpected(
+ lambda x, y: xla.add(x, y, broadcast_dims=(0,)),
+ args=(np.array([[1, 2], [3, 4]], dtype=dtype),
+ np.array([7, 11], dtype=dtype)),
+ expected=np.array([[8, 9], [14, 15]], dtype=dtype))
+
+ self._assertOpOutputMatchesExpected(
+ lambda x, y: xla.add(x, y, broadcast_dims=(1,)),
+ args=(np.array([[1, 2], [3, 4]], dtype=dtype),
+ np.array([7, 11], dtype=dtype)),
+ expected=np.array([[8, 13], [10, 15]], dtype=dtype))
+
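For readers unfamiliar with broadcast_dims: it names the axes of the higher-rank operand onto which the lower-rank operand's axes are mapped. A minimal NumPy sketch of that semantics, matching the two cases above (the helper name is ours, not an XLA API):

    import numpy as np

    def add_with_broadcast_dims(x, y, broadcast_dims):
        # Place y's axes on the x axes named in broadcast_dims, with size-1
        # axes everywhere else, then let ordinary NumPy broadcasting add.
        shape = [1] * x.ndim
        for axis, size in zip(broadcast_dims, y.shape):
            shape[axis] = size
        return x + y.reshape(shape)

    x = np.array([[1, 2], [3, 4]])
    y = np.array([7, 11])
    print(add_with_broadcast_dims(x, y, (0,)))  # [[ 8  9] [14 15]]
    print(add_with_broadcast_dims(x, y, (1,)))  # [[ 8 13] [10 15]]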
+ def testBroadcast(self):
+ for dtype in self.numeric_types:
+ v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2])
+ self._assertOpOutputMatchesExpected(
+ lambda x: xla.broadcast(x, (7, 42)),
+ args=(v,),
+ expected=np.tile(v, (7, 42, 1, 1)))
+
+ def testShiftRightLogical(self):
+ self._assertOpOutputMatchesExpected(
+ xla.shift_right_logical,
+ args=(np.array([-1, 16], dtype=np.int32), np.int32(4)),
+ expected=np.array([0x0FFFFFFF, 1], dtype=np.int32))
+
+ self._assertOpOutputMatchesExpected(
+ xla.shift_right_logical,
+ args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)),
+ expected=np.array([0x0FFFFFFF, 1], dtype=np.uint32))
+
+ def testShiftRightArithmetic(self):
+ self._assertOpOutputMatchesExpected(
+ xla.shift_right_arithmetic,
+ args=(np.array([-1, 16], dtype=np.int32), np.int32(4)),
+ expected=np.array([-1, 1], dtype=np.int32))
+
+ self._assertOpOutputMatchesExpected(
+ xla.shift_right_arithmetic,
+ args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)),
+ expected=np.array([0xFFFFFFFF, 1], dtype=np.uint32))
+
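These two tests pin down the only case where the shifts differ: a set sign bit. A plain-NumPy illustration of the distinction the XLA ops are expected to honor:

    import numpy as np

    x = np.array([-1, 16], dtype=np.int32)
    # Logical shift zero-fills: reinterpret the bits as unsigned, shift, view back.
    print((x.view(np.uint32) >> 4).view(np.int32))  # [268435455 1] == [0x0FFFFFFF, 1]
    # Arithmetic shift replicates the sign bit, so -1 stays -1.
    print(x >> 4)                                   # [-1  1]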
+ PRECISION_VALUES = (None, xla_data_pb2.PrecisionConfig.DEFAULT,
+ xla_data_pb2.PrecisionConfig.HIGH,
+ xla_data_pb2.PrecisionConfig.HIGHEST)
+
+ @parameterized.parameters(*PRECISION_VALUES)
+ def testConv(self, precision):
+ for dtype in set(self.float_types).intersection(
+ set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
+
+ def conv_1d_fn(lhs, rhs):
+ dnums = xla_data_pb2.ConvolutionDimensionNumbers()
+ num_spatial_dims = 1
+ dnums.input_batch_dimension = 0
+ dnums.input_feature_dimension = 1
+ dnums.output_batch_dimension = 0
+ dnums.output_feature_dimension = 1
+ dnums.kernel_output_feature_dimension = 0
+ dnums.kernel_input_feature_dimension = 1
+ dnums.input_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
+ dnums.kernel_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
+ dnums.output_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
+ precision_config = None
+ if precision:
+ precision_config = xla_data_pb2.PrecisionConfig()
+ precision_config.operand_precision.extend([precision, precision])
+ return xla.conv(
+ lhs,
+ rhs,
+ window_strides=(1,),
+ padding=((2, 1),),
+ lhs_dilation=(1,),
+ rhs_dilation=(2,),
+            dimension_numbers=dnums,
+            precision_config=precision_config)
+
+ self._assertOpOutputMatchesExpected(
+ conv_1d_fn,
+ args=(
+ np.array([[[3, 4, 5, 6]]], dtype=dtype),
+ np.array([[[-2, -3]]], dtype=dtype),
+ ),
+ expected=np.array([[[-9, -12, -21, -26, -10]]], dtype=dtype))
+
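The expected vector can be reproduced by hand: rhs_dilation=(2,) turns the kernel [-2, -3] into [-2, 0, -3], padding=((2, 1),) zero-pads the input, and XLA convolution is a correlation (no kernel flip). A NumPy sketch of the 1-D case:

    import numpy as np

    lhs = np.array([3, 4, 5, 6], dtype=np.float32)
    kernel = np.zeros(3, dtype=np.float32)
    kernel[::2] = [-2, -3]        # rhs_dilation=(2,) inserts a zero gap
    padded = np.pad(lhs, (2, 1))  # padding=((2, 1),)
    out = [np.dot(padded[i:i + 3], kernel) for i in range(len(padded) - 2)]
    print(np.array(out))          # [ -9. -12. -21. -26. -10.]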
+ @parameterized.parameters(*PRECISION_VALUES)
+ def testDotGeneral(self, precision):
+ for dtype in self.float_types:
+
+ def dot_fn(lhs, rhs):
+ dnums = xla_data_pb2.DotDimensionNumbers()
+ dnums.lhs_contracting_dimensions.append(2)
+ dnums.rhs_contracting_dimensions.append(1)
+ dnums.lhs_batch_dimensions.append(0)
+ dnums.rhs_batch_dimensions.append(0)
+ precision_config = None
+ if precision:
+ precision_config = xla_data_pb2.PrecisionConfig()
+ precision_config.operand_precision.extend([precision, precision])
+ return xla.dot_general(
+ lhs,
+ rhs,
+ dimension_numbers=dnums,
+ precision_config=precision_config)
+
+ lhs = np.array(
+ [
+ [[1, 2], [3, 4]],
+ [[5, 6], [7, 8]],
+ ], dtype=dtype)
+ rhs = np.array(
+ [
+ [[1, 2, 3], [4, 5, 6]],
+ [[7, 8, 9], [10, 11, 12]],
+ ], dtype=dtype)
+ self._assertOpOutputMatchesExpected(
+ dot_fn,
+ args=(lhs, rhs),
+ expected=np.array(
+ [
+ [[9, 12, 15], [19, 26, 33]],
+ [[95, 106, 117], [129, 144, 159]],
+ ],
+ dtype=dtype))
+
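With batch dimension 0 on both operands and lhs dimension 2 contracted against rhs dimension 1, this DotDimensionNumbers setup is an ordinary batched matmul, so the expectation can be cross-checked with einsum:

    import numpy as np

    lhs = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.float32)
    rhs = np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
                   dtype=np.float32)
    # b = batch dim, k = contracted dim (lhs dim 2 against rhs dim 1).
    print(np.einsum('bik,bkj->bij', lhs, rhs))
    # [[[  9.  12.  15.] [ 19.  26.  33.]]
    #  [[ 95. 106. 117.] [129. 144. 159.]]]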
+ def testNeg(self):
+ for dtype in self.numeric_types:
+ self._assertOpOutputMatchesExpected(
+ xla.neg,
+ args=(np.array([1, 2, 3], dtype=dtype),),
+ expected=np.array([-1, -2, -3], dtype=dtype))
+
+ def testPad(self):
+ for dtype in self.numeric_types:
+
+ def pad_fn(x):
+ return xla.pad(
+ x,
+ padding_value=7,
+ padding_low=[2, 1],
+ padding_high=[1, 2],
+ padding_interior=[1, 0])
+
+ self._assertOpOutputMatchesExpected(
+ pad_fn,
+ args=(np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]),),
+ expected=np.array(
+ [[7, 7, 7, 7, 7], [7, 7, 7, 7, 7], [7, 0, 1, 7, 7],
+ [7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]],
+ dtype=dtype))
+
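Unlike tf.pad, xla.pad also takes padding_interior, which inserts padding between neighbouring elements before the edge padding is applied; that is where the all-7 row between the two data rows comes from. A NumPy sketch of the semantics (the helper is ours, not an XLA API):

    import numpy as np

    def xla_style_pad(x, value, low, high, interior):
        # Interior padding first: interior[d] copies of `value` between
        # neighbouring elements along dim d, then ordinary edge padding.
        shape = [(s - 1) * (i + 1) + 1 for s, i in zip(x.shape, interior)]
        out = np.full(shape, value, dtype=x.dtype)
        out[tuple(slice(None, None, i + 1) for i in interior)] = x
        return np.pad(out, list(zip(low, high)), constant_values=value)

    x = np.arange(4).reshape(2, 2)
    print(xla_style_pad(x, 7, low=[2, 1], high=[1, 2], interior=[1, 0]))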
+ def testReduce(self):
+ for dtype in set(self.numeric_types).intersection(
+ set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
+
+ @function.Defun(dtype, dtype)
+ def sum_reducer(x, y):
+ return x + y
+
+ def sum_reduction(dims):
+
+ def fn(x):
+ return xla.reduce(
+ x, init_value=0, dimensions_to_reduce=dims, reducer=sum_reducer)
+
+ return fn
+
+ self._assertOpOutputMatchesExpected(
+ sum_reduction(dims=[]),
+ args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
+ expected=np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]))
+ self._assertOpOutputMatchesExpected(
+ sum_reduction(dims=[0]),
+ args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
+ expected=np.array([12, 15, 18, 21], dtype=dtype))
+ self._assertOpOutputMatchesExpected(
+ sum_reduction(dims=[1]),
+ args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
+ expected=np.array([6, 22, 38], dtype=dtype))
+ self._assertOpOutputMatchesExpected(
+ sum_reduction(dims=[0, 1]),
+ args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
+ expected=dtype(66))
+
+ @function.Defun(dtype, dtype)
+ def mul_reducer(x, y):
+ return x * y
+
+ def mul_reduction(dims):
+
+ def fn(x):
+ return xla.reduce(
+ x, init_value=1, dimensions_to_reduce=dims, reducer=mul_reducer)
+
+ return fn
+
+ self._assertOpOutputMatchesExpected(
+ mul_reduction(dims=[0]),
+ args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
+ expected=np.array([0, 45, 120, 231], dtype=dtype))
+
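xla.reduce with a Defun reducer and a matching init value behaves like a NumPy reduction over the listed axes, with an empty dimensions_to_reduce being the identity. The sum and product cases above correspond to:

    import numpy as np

    x = np.arange(12).reshape(3, 4)
    print(x.sum(axis=(0,)))               # [12 15 18 21]
    print(x.sum(axis=(1,)))               # [ 6 22 38]
    print(x.sum(axis=(0, 1)))             # 66
    print(np.multiply.reduce(x, axis=0))  # [  0  45 120 231]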
+ def testSelectAndScatter(self):
+ for dtype in set(self.numeric_types).intersection(
+ set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
+
+ @function.Defun(dtype, dtype)
+ def add_scatter(x, y):
+ return x + y
+
+ @function.Defun(dtype, dtype)
+ def ge_select(x, y):
+ return x >= y
+
+ def test_fn(operand, source):
+ return xla.select_and_scatter(
+ operand,
+ window_dimensions=[2, 3, 1, 1],
+ window_strides=[2, 2, 1, 1],
+ padding=[[0, 0]] * 4,
+ source=source,
+ init_value=0,
+ select=ge_select,
+ scatter=add_scatter)
+
+ self._assertOpOutputMatchesExpected(
+ test_fn,
+ args=(np.array(
+ [[7, 2, 5, 3, 8], [3, 8, 9, 3, 4], [1, 5, 7, 5, 6],
+ [0, 6, 2, 10, 2]],
+ dtype=dtype).reshape((4, 5, 1, 1)),
+ np.array([[2, 6], [3, 1]], dtype=dtype).reshape((2, 2, 1, 1))),
+ expected=np.array(
+ [[0, 0, 0, 0, 0], [0, 0, 8, 0, 0], [0, 0, 3, 0, 0],
+ [0, 0, 0, 1, 0]],
+ dtype=dtype).reshape((4, 5, 1, 1)))
+
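select_and_scatter is the scatter-style counterpart of pooling: for each 2x3 window (stride 2x2) over the 4x5 operand, ge_select picks the window maximum and add_scatter accumulates the corresponding source value at that position. A NumPy sketch reproducing the expectation, with the trailing size-1 dims dropped and first-maximum tie-breaking assumed:

    import numpy as np

    operand = np.array([[7, 2, 5, 3, 8], [3, 8, 9, 3, 4],
                        [1, 5, 7, 5, 6], [0, 6, 2, 10, 2]], dtype=np.float32)
    source = np.array([[2, 6], [3, 1]], dtype=np.float32)
    out = np.zeros_like(operand)  # init_value=0
    for i in range(source.shape[0]):
        for j in range(source.shape[1]):
            window = operand[2 * i:2 * i + 2, 2 * j:2 * j + 3]
            r, c = np.unravel_index(np.argmax(window), window.shape)
            out[2 * i + r, 2 * j + c] += source[i, j]  # scatter onto the max
    print(out)  # matches the expected array above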
+ def testTranspose(self):
+ for dtype in self.numeric_types:
+ v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2])
+ self._assertOpOutputMatchesExpected(
+ lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index fda32c8a1c..74b131e07e 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -39,6 +39,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
],
)
@@ -88,6 +89,7 @@ cc_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
],
)
@@ -189,6 +191,7 @@ cc_library(
":functionalize_control_flow",
":host_compute_metadata_proto",
":sharding_util",
+ ":side_effect_util",
":tf2xla_util",
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/xla:literal",
@@ -211,6 +214,8 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
alwayslink = 1,
)
@@ -220,13 +225,11 @@ cc_library(
srcs = [
"literal_util.cc",
"shape_util.cc",
- "str_util.cc",
"type_util.cc",
],
hdrs = [
"literal_util.h",
"shape_util.h",
- "str_util.h",
"type_util.h",
],
visibility = [":friends"],
@@ -238,6 +241,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/types:span",
],
)
@@ -255,6 +259,7 @@ cc_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
],
)
@@ -287,6 +292,8 @@ cc_library(
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -305,6 +312,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
],
)
@@ -352,6 +360,7 @@ tf_cc_test(
name = "xla_compiler_test",
srcs = ["xla_compiler_test.cc"],
deps = [
+ ":side_effect_util",
":xla_compiler",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:function_ops",
@@ -363,6 +372,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:core_cpu_internal",
@@ -372,19 +382,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
- ],
-)
-
-tf_cc_test(
- name = "str_util_test",
- srcs = [
- "str_util_test.cc",
- ],
- deps = [
- ":common",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
],
)
@@ -439,25 +437,101 @@ cc_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "functionalize_control_flow_util",
+ srcs = [
+ "functionalize_control_flow_util.cc",
+ ],
+ hdrs = [
+ "functionalize_control_flow_util.h",
+ ],
+ deps = [
+ "//tensorflow/compiler/tf2xla/ops:xla_ops",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "functionalize_cond",
+ srcs = [
+ "functionalize_cond.cc",
+ ],
+ hdrs = [
+ "functionalize_cond.h",
+ ],
+ deps = [
+ ":functionalize_control_flow_util",
+ ":tf2xla_util",
+ "//tensorflow/compiler/jit:union_find",
+ "//tensorflow/compiler/tf2xla:dump_graph",
+ "//tensorflow/compiler/tf2xla/ops:xla_ops",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:graph",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
cc_library(
name = "functionalize_control_flow",
- srcs = ["functionalize_control_flow.cc"],
- hdrs = ["functionalize_control_flow.h"],
+ srcs = [
+ "functionalize_control_flow.cc",
+ ],
+ hdrs = [
+ "functionalize_control_flow.h",
+ ],
deps = [
+ ":functionalize_cond",
+ ":functionalize_control_flow_util",
+ ":functionalize_while",
":tf2xla_util",
"//tensorflow/compiler/jit:union_find",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla:util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:optional",
+ ],
+)
+
+cc_library(
+ name = "functionalize_while",
+ srcs = [
+ "functionalize_while.cc",
+ ],
+ hdrs = [
+ "functionalize_while.h",
+ ],
+ deps = [
+ ":functionalize_control_flow_util",
+ ":tf2xla_util",
+ "//tensorflow/compiler/jit:union_find",
+ "//tensorflow/compiler/tf2xla:dump_graph",
+ "//tensorflow/compiler/tf2xla/ops:xla_ops",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:graph",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -485,6 +559,32 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "functionalize_cond_test",
+ srcs = ["functionalize_cond_test.cc"],
+ deps = [
+ ":functionalize_cond",
+ ":functionalize_control_flow",
+ ":test_util",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:cc_ops_internal",
+ "//tensorflow/cc:function_ops",
+ "//tensorflow/cc:ops",
+ "//tensorflow/cc:resource_variable_ops",
+ "//tensorflow/compiler/tf2xla/cc:xla_ops",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:resource_variable_ops_op_lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
cc_library(
name = "test_util",
testonly = 1,
@@ -494,6 +594,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
@@ -508,3 +609,38 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
+
+cc_library(
+ name = "resource_operation_table",
+ srcs = ["resource_operation_table.cc"],
+ hdrs = ["resource_operation_table.h"],
+ deps = [
+ "//tensorflow/core:lib",
+ "//tensorflow/core:ops",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+tf_cc_test(
+ name = "resource_operation_table_test",
+ srcs = ["resource_operation_table_test.cc"],
+ deps = [
+ ":resource_operation_table",
+ ":xla_compiler",
+ "//tensorflow/compiler/tf2xla/kernels:xla_ops",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "side_effect_util",
+ srcs = ["side_effect_util.cc"],
+ hdrs = ["side_effect_util.h"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ ],
+)
diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD
index ea8d1b3d14..8ac5eb5df9 100644
--- a/tensorflow/compiler/tf2xla/cc/BUILD
+++ b/tensorflow/compiler/tf2xla/cc/BUILD
@@ -31,7 +31,9 @@ cc_library(
tf_gen_op_wrapper_cc(
name = "xla_jit_op_gen",
out_ops_file = "ops/xla_jit_op",
- deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
+ deps = [
+ "//tensorflow/compiler/jit/ops:xla_ops",
+ ],
)
cc_library(
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index de1008803d..922ae7c79a 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -23,11 +23,12 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
namespace tensorflow {
-
// Backwards dataflow analysis that finds arguments to a graph that must be
// compile-time constants.
Status BackwardsConstAnalysis(const Graph& g,
- std::vector<bool>* compile_time_const_args) {
+ std::vector<bool>* compile_time_const_arg_indices,
+ std::vector<bool>* compile_time_const_nodes,
+ std::function<bool(const Edge&)> edge_filter) {
// Operators that don't look at the data of their inputs, just the shapes.
const std::unordered_set<string> metadata_ops = {
"Rank",
@@ -36,10 +37,16 @@ Status BackwardsConstAnalysis(const Graph& g,
"Size",
};
+ std::vector<bool> compile_time_const_nodes_impl;
+ if (compile_time_const_nodes) {
+ CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids());
+ } else {
+ compile_time_const_nodes_impl.resize(g.num_node_ids());
+ compile_time_const_nodes = &compile_time_const_nodes_impl;
+ }
+
Status status;
- std::unordered_set<const Node*> must_be_const;
- auto visit = [&status, &metadata_ops, &must_be_const,
- compile_time_const_args](Node* node) {
+ auto visit = [&](Node* node) {
if (!status.ok()) return;
// If this is a metadata-only op, don't propagate the const requirement.
@@ -47,17 +54,19 @@ Status BackwardsConstAnalysis(const Graph& g,
// If this node must be const, and it isn't a metadata op, then all of its
// parents must be const.
- if (must_be_const.find(node) != must_be_const.end()) {
+ if ((*compile_time_const_nodes)[node->id()]) {
if (node->type_string() == "_Arg") {
int index;
status = GetNodeAttr(node->attrs(), "index", &index);
if (!status.ok()) return;
- compile_time_const_args->at(index) = true;
+ if (compile_time_const_arg_indices) {
+ (*compile_time_const_arg_indices)[index] = true;
+ }
return;
}
for (const Edge* pred : node->in_edges()) {
- if (!pred->IsControlEdge()) {
- must_be_const.insert(pred->src());
+ if (!pred->IsControlEdge() && edge_filter(*pred)) {
+ (*compile_time_const_nodes)[pred->src()->id()] = true;
}
}
return;
@@ -79,8 +88,9 @@ Status BackwardsConstAnalysis(const Graph& g,
for (Edge const* edge : node->in_edges()) {
if (edge->dst_input() >= name_range->second.first &&
- edge->dst_input() < name_range->second.second) {
- must_be_const.insert(edge->src());
+ edge->dst_input() < name_range->second.second &&
+ edge_filter(*edge)) {
+ (*compile_time_const_nodes)[edge->src()->id()] = true;
}
}
}
@@ -88,7 +98,8 @@ Status BackwardsConstAnalysis(const Graph& g,
// Post-order traversal visits nodes in reverse topological order for an
// acyclic graph.
- DFS(g, {}, visit);
+ DFS(g, /*enter=*/{}, /*leave=*/visit, NodeComparatorName{},
+ [](const Edge& edge) { return !edge.src()->IsNextIteration(); });
return status;
}
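For intuition: the rewritten pass seeds const-ness at inputs that ops demand at compile time (e.g. the shape operand of Reshape), then propagates it backwards in reverse topological order, now recording node ids as well as _Arg indices and skipping any edge the filter rejects (NextIteration back-edges above). A toy Python sketch of the propagation over hypothetical structures, not the TF API:

    METADATA_OPS = {"Shape", "Rank", "Size"}  # look only at input shapes

    def backwards_const_analysis(rev_topo, in_edges, op, const_input_edges,
                                 edge_filter=lambda src, dst: True):
        # Seed with the sources of compile-time-constant inputs.
        must_be_const = {src for src, dst in const_input_edges
                         if edge_filter(src, dst)}
        for node in rev_topo:  # consumers visited before their producers
            if node in must_be_const and op.get(node) not in METADATA_OPS:
                for src in in_edges.get(node, []):
                    if edge_filter(src, node):
                        must_be_const.add(src)
        return must_be_const

    rev_topo = ["reshape", "shape_in", "arg1", "data_in", "arg0"]
    in_edges = {"reshape": ["data_in", "shape_in"], "shape_in": ["arg1"],
                "data_in": ["arg0"]}
    print(backwards_const_analysis(rev_topo, in_edges, {},
                                   [("shape_in", "reshape")]))
    # {'shape_in', 'arg1'}: only the chain feeding the shape operand is
    # forced constant, mirroring ConstAnalysisTest.Basics.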
diff --git a/tensorflow/compiler/tf2xla/const_analysis.h b/tensorflow/compiler/tf2xla/const_analysis.h
index 634b97d7e3..49b3c6d413 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.h
+++ b/tensorflow/compiler/tf2xla/const_analysis.h
@@ -23,10 +23,22 @@ limitations under the License.
namespace tensorflow {
-// Backwards dataflow analysis that finds arguments (_Arg nodes) to a graph that
-// must be compile-time constants.
-Status BackwardsConstAnalysis(const Graph& graph,
- std::vector<bool>* compile_time_const_args);
+// Backwards dataflow analysis that finds nodes in a graph that must be
+// compile-time constants for us to be able to lower the graph to XLA.
+//
+// The indices of the arguments to `graph` that must be constant are returned in
+// `compile_time_const_arg_indices`, if `compile_time_const_arg_indices` is not
+// null.
+//
+// The ids of the nodes in `graph` that must be constant are returned in
+// `compile_time_const_nodes`, if `compile_time_const_nodes` is not null.
+//
+// Only propagate const-ness along edges for which `edge_filter` returns true.
+Status BackwardsConstAnalysis(const Graph& g,
+ std::vector<bool>* compile_time_const_arg_indices,
+ std::vector<bool>* compile_time_const_nodes,
+ std::function<bool(const Edge&)> edge_filter =
+ [](const Edge& e) { return true; });
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/const_analysis_test.cc b/tensorflow/compiler/tf2xla/const_analysis_test.cc
index 992b12c06d..56065be894 100644
--- a/tensorflow/compiler/tf2xla/const_analysis_test.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -38,17 +39,23 @@ TEST(ConstAnalysisTest, Basics) {
auto c = ops::Reshape(root, arg2, b);
auto d = ops::Mul(root, c, ops::Sum(root, arg3, arg3));
- Graph graph(OpRegistry::Global());
- TF_ASSERT_OK(root.ToGraph(&graph));
+ FixupSourceAndSinkEdges(root.graph());
std::vector<bool> const_args(4, false);
- TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args));
+ std::vector<bool> const_nodes(root.graph()->num_node_ids(), false);
+ TF_ASSERT_OK(
+ BackwardsConstAnalysis(*root.graph(), &const_args, &const_nodes));
// Arg 0 doesn't need to be constant since the graph only uses its shape.
// Arg 1 must be constant because it flows to the shape argument of a Reshape.
// Arg 2 is used only as the value input to a Reshape and need not be const.
// Arg 3 is used as the reduction-indices argument to Sum and must be const.
EXPECT_EQ(const_args, std::vector<bool>({false, true, false, true}));
+
+ EXPECT_FALSE(const_nodes[arg0.node()->id()]);
+ EXPECT_TRUE(const_nodes[arg1.node()->id()]);
+ EXPECT_FALSE(const_nodes[arg2.node()->id()]);
+ EXPECT_TRUE(const_nodes[arg3.node()->id()]);
}
// Regression test for a case where the backward const analysis did
@@ -73,7 +80,8 @@ TEST(ConstAnalysisTest, TopologicalOrder) {
TF_ASSERT_OK(root.ToGraph(&graph));
std::vector<bool> const_args(3, false);
- TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args));
+ TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
+ /*compile_time_const_nodes=*/nullptr));
EXPECT_EQ(const_args, std::vector<bool>({true, true, false}));
}
@@ -93,7 +101,8 @@ TEST(ConstAnalysisTest, DontFollowControlDependencies) {
TF_ASSERT_OK(root.ToGraph(&graph));
std::vector<bool> const_args(2, false);
- TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args));
+ TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
+ /*compile_time_const_nodes=*/nullptr));
EXPECT_EQ(const_args, std::vector<bool>({false, true}));
}
diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc
index 24616c01c7..380c6a7e23 100644
--- a/tensorflow/compiler/tf2xla/dump_graph.cc
+++ b/tensorflow/compiler/tf2xla/dump_graph.cc
@@ -18,8 +18,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2xla/dump_graph_flags.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
@@ -52,9 +52,9 @@ string MakeUniqueFilename(string name) {
string filename = name;
if (count > 0) {
- strings::StrAppend(&filename, "_", count);
+ absl::StrAppend(&filename, "_", count);
}
- strings::StrAppend(&filename, ".pbtxt");
+ absl::StrAppend(&filename, ".pbtxt");
return filename;
}
@@ -69,7 +69,7 @@ string WriteTextProtoToUniqueFile(
<< proto_type << ": " << status;
return "(unavailable)";
}
- string filepath = strings::StrCat(dirname, "/", MakeUniqueFilename(name));
+ string filepath = absl::StrCat(dirname, "/", MakeUniqueFilename(name));
status = WriteTextProto(Env::Default(), filepath, proto);
if (!status.ok()) {
LOG(WARNING) << "Failed to dump " << proto_type << " to file: " << filepath
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc
new file mode 100644
index 0000000000..0911550f1f
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc
@@ -0,0 +1,1354 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
+
+#include <algorithm>
+#include <deque>
+#include <stack>
+#include <unordered_set>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/jit/union_find.h"
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/control_flow.h"
+#include "tensorflow/core/graph/node_builder.h"
+
+using xla::StatusOr;
+
+namespace tensorflow {
+namespace functionalize_cond {
+
+// TODO(jpienaar): Move to OutputTensor.
+string DebugString(const OutputTensor& tensor) {
+ return absl::StrCat(tensor.node->name(), ":", tensor.index);
+}
+
+string Branch_Name(BranchType b) {
+ switch (b) {
+ case BranchType::kElseBranch:
+ return "else";
+ case BranchType::kThenBranch:
+ return "then";
+ case BranchType::kBoth:
+ return "both";
+ case BranchType::kNeither:
+ return "neither";
+ }
+}
+
+string DebugString(StateMap::CondId cond_state) {
+ if (cond_state == nullptr || cond_state->empty()) return "{}";
+ using value_type = StateMap::CondState::value_type;
+ return absl::StrCat(
+ "{",
+ absl::StrJoin(*cond_state, ", ",
+ [](string* output, const value_type& pred_branch) {
+ const OutputTensor& pred = pred_branch.first;
+ const BranchType& branch = pred_branch.second;
+ if (branch == BranchType::kNeither)
+ absl::StrAppend(output, "d");
+ else
+ absl::StrAppend(output, "s(", DebugString(pred), ",",
+ Branch_Name(branch), ")");
+ }),
+ "}");
+}
+
+// Returns the predicate of a switch.
+Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) {
+ const Edge* pred_edge;
+ TF_RETURN_IF_ERROR(switch_node.input_edge(1, &pred_edge));
+  // The predicate can be preceded by an identity node. Look through
+  // identity nodes to find the predicate.
+ while (pred_edge->src()->IsIdentity()) {
+ TF_RETURN_IF_ERROR(pred_edge->src()->input_edge(0, &pred_edge));
+ }
+ *pred = OutputTensor(pred_edge->src(), pred_edge->src_output());
+ return Status::OK();
+}
+
+Status GetSwitchValue(const Node& switch_node, OutputTensor* val) {
+ const Edge* val_edge;
+ TF_RETURN_IF_ERROR(switch_node.input_edge(0, &val_edge));
+ *val = OutputTensor(val_edge->src(), val_edge->src_output());
+ return Status::OK();
+}
+
+bool StateMap::OutputTensorLess::operator()(const OutputTensor& lhs,
+ const OutputTensor& rhs) const {
+ return (lhs.node->id() < rhs.node->id()) ||
+ (lhs.node->id() == rhs.node->id() && lhs.index < rhs.index);
+}
+
+struct CondStateLess {
+ bool operator()(const StateMap::CondState::value_type& lhs,
+ const StateMap::CondState::value_type& rhs) const {
+ if (StateMap::OutputTensorLess().operator()(lhs.first, rhs.first))
+ return true;
+ if (lhs.first.node->id() == rhs.first.node->id() &&
+ lhs.first.index == rhs.first.index)
+ return lhs.second < rhs.second;
+ return false;
+ }
+};
+
+StateMap::StateMap(Graph* graph) {
+ node_to_condid_map_.resize(graph->num_node_ids());
+ node_to_ancestorid_map_.resize(graph->num_node_ids());
+ // Initialize the dead state (empty state is designated with a nullptr).
+ dead_id_ = GetCondId(
+ {std::make_pair(OutputTensor(nullptr, -1), BranchType::kNeither)});
+}
+
+bool StateMap::IsDead(StateMap::CondId id) const { return id == dead_id_; }
+
+bool StateMap::IsEmpty(StateMap::CondId id) const { return id == nullptr; }
+
+size_t StateMap::Hash::operator()(const StateMap::CondState& map) const {
+ if (map.empty()) return 0;
+ // Compute hash of the front element.
+ auto it = map.begin();
+ size_t h = Hash64Combine(OutputTensor::Hash()(it->first),
+ hash<BranchType>()(it->second));
+ for (++it; it != map.end(); ++it) {
+    // Combine the hash with the hashes of the remaining elements in the map.
+ h = Hash64Combine(h, Hash64Combine(OutputTensor::Hash()(it->first),
+ hash<BranchType>()(it->second)));
+ }
+ return h;
+}
+
+size_t StateMap::Hash::operator()(const StateMap::AncestorState& map) const {
+ if (map.empty()) return 0;
+ // Compute hash of the front element.
+ auto it = map.begin();
+ size_t h = hash<Node*>()(*it);
+ for (++it; it != map.end(); ++it) {
+    // Combine the hash with the hashes of the remaining elements in the map.
+ h = Hash64Combine(h, hash<Node*>()(*it));
+ }
+ return h;
+}
+
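Both Hash specializations fold per-element hashes left to right with Hash64Combine, so equal states hash equally regardless of where the set lives. A Python sketch of that chaining; the boost-style golden-ratio mixer is an assumption about Hash64Combine's definition:

    MASK = (1 << 64) - 1

    def hash64_combine(a, b):
        # Boost-style 64-bit combine (assumed shape of TF's Hash64Combine).
        return (a ^ (b + 0x9E3779B97F4A7C15 + ((a << 6) & MASK) + (a >> 2))) & MASK

    def hash_state(elems):
        if not elems:
            return 0  # empty state hashes to 0, as in the code above
        h = hash(elems[0]) & MASK
        for e in elems[1:]:
            h = hash64_combine(h, hash(e) & MASK)
        return h

    print(hash_state([("pred:0", "then"), ("pred:1", "else")]))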
+// CondArgNode represents an input to the conditional and its corresponding
+// switch nodes.
+struct CondArgNode {
+ explicit CondArgNode(Node* src, int src_output)
+ : src(src), src_output(src_output) {}
+
+ string ToString() const {
+ return absl::StrCat("src=", src->name(), ":", src_output,
+ " switches=", NodesToString(switches));
+ }
+
+ Node* src;
+ int src_output;
+ std::array<Node*, 2> branch_copy;
+ std::vector<Node*> switches;
+};
+using CondArgNodes = std::vector<CondArgNode>;
+
+string DebugString(const CondArgNodes& nodes) {
+ return absl::StrCat(
+ "[",
+ absl::StrJoin(nodes, ", ",
+ [](string* output, const CondArgNode& node) {
+ absl::StrAppend(output, node.ToString());
+ }),
+ "]");
+}
+
+StateMap::CondId StateMap::LookupCondId(const Node* node) const {
+ if (node->id() < node_to_condid_map_.size())
+ return node_to_condid_map_[node->id()];
+ return added_node_condid_mapping_.at(node->id());
+}
+
+StateMap::CondId StateMap::GetCondId(const StateMap::CondState& state) {
+ if (state.empty()) return nullptr;
+ return &*condstate_set_.insert(state).first;
+}
+
+void StateMap::ResetCondId(const Node* node, StateMap::CondId id) {
+ if (node->id() < node_to_condid_map_.size())
+ node_to_condid_map_[node->id()] = id;
+ else
+ added_node_condid_mapping_[node->id()] = id;
+}
+
+StateMap::AncestorId StateMap::LookupAncestorId(const Node* node) const {
+ if (node->id() < node_to_ancestorid_map_.size())
+ return node_to_ancestorid_map_[node->id()];
+ return added_node_ancestorid_mapping_.at(node->id());
+}
+
+StateMap::AncestorId StateMap::GetAncestorId(
+ const StateMap::AncestorState& state) {
+ if (state.empty()) return nullptr;
+ return &*ancestorstate_set_.insert(state).first;
+}
+
+void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) {
+ if (node->id() < node_to_ancestorid_map_.size())
+ node_to_ancestorid_map_[node->id()] = id;
+ else
+ added_node_ancestorid_mapping_[node->id()] = id;
+}
+
+const StateMap::CondState& StateMap::LookupState(const Node* node) const {
+ return *LookupCondId(node);
+}
+
+void StateMap::MarkDead(const Node* node) { ResetCondId(node, dead_id_); }
+
+string StateMap::CondStateToString(const Node* node) const {
+ return CondStateToString(LookupCondId(node));
+}
+
+string StateMap::CondStateToString(StateMap::CondId id) const {
+ return DebugString(id);
+}
+
+string StateMap::AncestorStateToString(const Node* node) const {
+ if (auto id = LookupAncestorId(node)) return NodesToString(*id);
+ return "{}";
+}
+
+FunctionalizeCond::FunctionalizeCond(Graph* graph,
+ FunctionLibraryDefinition* library)
+ : state_map_(graph), library_(library), graph_(graph) {}
+
+// Class representing the merge/switch nodes that will become a conditional.
+class Conditional {
+ public:
+ Conditional(OutputTensor predicate, FunctionalizeCond* parent,
+ StateMap* cond_state_map);
+
+ // Adds merge node that is part of this conditional.
+ Status AddMerge(Node* m);
+
+ // Constructs an If node from the merge nodes.
+ Status BuildAndReplace(Graph* graph, FunctionLibraryDefinition* library);
+
+ private:
+  // Extracts the then/else bodies: creates new graphs whose nodes
+  // correspond to the nodes in the then/else branches of this
+  // conditional, to be used as function bodies.
+ Status ExtractBodies(Graph* graph);
+
+ // Builds the arguments that are the input to the If.
+ Status BuildArgumentNodes();
+
+ // Builds the If node for the extracted bodies with the given predicate.
+ Status BuildIfNode(Graph* graph, FunctionLibraryDefinition* library);
+
+ // Adds input edges to If node.
+ Status AddInputEdges(Graph* graph);
+
+ // Adds output edges from If node.
+ Status AddOutputEdges(Graph* graph);
+
+ // Adds switch node that is part of this conditional.
+ Status AddSwitch(Node* s);
+
+  // Adds a switch node along the edge and rewires the edge to go via the switch.
+ Status AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch,
+ Graph* graph);
+
+ // Internal name of conditional. The name is based on the first merge node
+ // added.
+ string name() const;
+
+ // The FunctionalizeCond instance that created this.
+ FunctionalizeCond* parent_;
+
+ // Mapping between nodes and their cond state.
+ StateMap* state_map_;
+
+ // The predicate of the conditional.
+ OutputTensor predicate_;
+
+ // The predicate of the switches of the conditional. This may be different
+ // than predicate (which is initialized from the original graph) as the
+ // predicate could be the output of a newly created If node.
+ OutputTensor switch_predicate_;
+
+ // Switch nodes in graph that are part of this conditional.
+ std::set<Node*, NodeCmpByNameResourcesLast> switches_;
+
+ // Merge nodes in graph that are part of this conditional.
+ std::set<Node*, NodeCmpByNameResourcesLast> merges_;
+
+ // Vector of control inputs from outside the conditional to a node inside.
+ std::vector<Node*> external_control_inputs_;
+ std::vector<Node*> external_control_outputs_;
+
+ // Graphs corresponding to the then and else branch.
+ std::array<std::unique_ptr<Graph>, 2> bodies_;
+
+  // Maps from node ids in graph_ to the corresponding node in each
+  // branch body's graph.
+ std::array<std::vector<Node*>, 2> node_maps_;
+
+ // The argument nodes created for the switches.
+ CondArgNodes cond_arg_nodes_;
+
+ // The constructed If node.
+ Node* if_node_ = nullptr;
+
+ // Whether the merge nodes of this conditional have been replaced.
+ bool replaced_ = false;
+};
+
+Conditional::Conditional(OutputTensor predicate, FunctionalizeCond* parent,
+ StateMap* cond_state_map)
+ : parent_(parent), state_map_(cond_state_map), predicate_(predicate) {}
+
+Status Conditional::AddMerge(Node* m) {
+ merges_.insert(m);
+ return Status::OK();
+}
+
+Status Conditional::AddSwitch(Node* s) {
+ VLOG(5) << "Adding switch " << s->DebugString();
+ OutputTensor predicate;
+ TF_RETURN_IF_ERROR(GetSwitchPredicate(*s, &predicate));
+ if (switch_predicate_.node == nullptr) switch_predicate_ = predicate;
+ if (!(switch_predicate_ == predicate)) {
+ return errors::InvalidArgument(
+ "Merge nodes ", NodesToString(merges_),
+ " directly dominated by switch nodes with different predicates (",
+ DebugString(switch_predicate_), " vs ", DebugString(predicate), ").");
+ }
+ switches_.insert(s);
+ return Status::OK();
+}
+
+Status Conditional::BuildArgumentNodes() {
+ VLOG(1) << "Build function arguments";
+ struct Hash {
+ size_t operator()(const std::pair<Node*, int>& item) const {
+ return Hash64Combine(hash<Node*>()(item.first),
+ std::hash<int>()(item.second));
+ }
+ };
+
+ std::unordered_map<std::pair<Node*, int>, int, Hash> input_index;
+ for (Node* switch_node : switches_) {
+ const Edge* e;
+ TF_RETURN_IF_ERROR(switch_node->input_edge(0, &e));
+ std::pair<Node*, int> key = std::make_pair(e->src(), e->src_output());
+ if (input_index.find(key) == input_index.end()) {
+ input_index[key] = cond_arg_nodes_.size();
+ cond_arg_nodes_.emplace_back(key.first, key.second);
+ }
+ cond_arg_nodes_.at(input_index.at(key)).switches.push_back(switch_node);
+ }
+ VLOG(5) << "CondArg nodes created: " << DebugString(cond_arg_nodes_);
+
+ int arg_count = 0;
+ for (CondArgNode& cond_arg_node : cond_arg_nodes_) {
+ DataType dtype = cond_arg_node.src->output_type(cond_arg_node.src_output);
+ for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
+ int branch_index = static_cast<int>(branch);
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(absl::StrCat("_Arg", arg_count),
+ FunctionLibraryDefinition::kArgOp)
+ .Attr("T", dtype)
+ .Attr("index", arg_count)
+ .Finalize(bodies_[branch_index].get(),
+ &cond_arg_node.branch_copy[branch_index]));
+ }
+ for (Node* node : cond_arg_node.switches) {
+ for (const Edge* e : node->out_edges()) {
+ if (e->IsControlEdge()) continue;
+ int branch_index = e->src_output();
+ Node* src_copy = cond_arg_node.branch_copy[branch_index];
+ Node* dst_copy = node_maps_[branch_index][e->dst()->id()];
+
+          // The graph may contain dead switch nodes; skip edges whose
+          // destination was never copied.
+ if (dst_copy == nullptr) continue;
+
+ TF_RET_CHECK(dst_copy != nullptr)
+ << "Unable to find copied node for " << e->dst()->DebugString()
+ << " on branch " << Branch_Name(BranchType(branch_index));
+ // If the input goes directly to a merge then the merge has
+ // been replaced by a retval so the dst input is 0 instead of
+ // dst_input.
+ int dst_input = IsMerge(e->dst()) ? 0 : e->dst_input();
+ bodies_[branch_index]->AddEdge(src_copy, 0, dst_copy, dst_input);
+ }
+ }
+ ++arg_count;
+ }
+
+ // Verify that all retvals have an input.
+ // TODO(jpienaar): One could add a ZerosLike in the branch that doesn't have
+ // input.
+ for (Node* m : merges_) {
+ for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
+ bool has_input = false;
+ for (auto e : node_maps_[static_cast<int>(branch)][m->id()]->in_edges()) {
+ if (!e->IsControlEdge()) {
+ has_input = true;
+ break;
+ }
+ }
+ if (!has_input) {
+ return errors::Internal(
+ "Failed to functionalize control flow with merge ",
+ FormatNodeForError(*m), " that doesn't have input on ",
+ Branch_Name(branch), " branch.");
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+Status Conditional::AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch,
+ Graph* graph) {
+  // Previously we had the edge:
+  //   src:src_output ---- edge ----> dst:dst_input
+  // After this rewrite we have (in graph):
+  //   src:src_output --> switch<pred> --- new_edge --> dst:dst_input
+
+ // TODO(jpienaar): One could keep a map caching the extra switch nodes added
+ // to avoid adding another switch to feed a value for which a switch was
+ // already added.
+ Node* switch_node;
+ Node* src = edge->src();
+ int src_output = edge->src_output();
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(graph->NewName(absl::StrCat(src->name(), "_added_switch")),
+ "Switch")
+ .Input(src, src_output)
+ .Input(const_cast<Node*>(predicate_.node), predicate_.index)
+ .Finalize(graph, &switch_node));
+ state_map_->ResetCondId(switch_node, state_map_->LookupCondId(src));
+ state_map_->ResetAncestorId(switch_node, state_map_->LookupAncestorId(src));
+
+ Node* dst = edge->dst();
+ int dst_input = edge->dst_input();
+ graph->RemoveEdge(edge);
+ graph->AddEdge(switch_node, static_cast<int>(branch), dst, dst_input);
+ return AddSwitch(switch_node);
+}
+
+Status Conditional::ExtractBodies(Graph* graph) {
+ VLOG(2) << "Extracting bodies for " << name();
+ for (auto b : {BranchType::kElseBranch, BranchType::kThenBranch}) {
+ bodies_[static_cast<int>(b)] =
+ absl::make_unique<Graph>(graph->op_registry());
+ }
+
+ auto find_branch = [&](const Edge* e) {
+ const auto& id = state_map_->LookupCondId(e->src());
+ return IsSwitch(e->src()) ? BranchType(e->src_output())
+ : state_map_->FindBranchOf(id, predicate_);
+ };
+
+ std::array<std::vector<Node*>, 2> stacks;
+ VLOG(5) << "Merges: " << NodesToString(merges_);
+ for (Node* m : merges_) {
+ VLOG(5) << "For merge: " << m->DebugString() << " "
+ << state_map_->CondStateToString(m);
+ for (auto e : m->in_edges()) {
+ if (e->IsControlEdge()) continue;
+ BranchType branch = find_branch(e);
+ TF_RET_CHECK(branch == BranchType::kThenBranch ||
+ branch == BranchType::kElseBranch)
+ << "Error: " << e->src()->name()
+ << " is not on either then or else branch (" << Branch_Name(branch)
+ << ") for predicate " << DebugString(predicate_) << " ["
+ << DebugString(state_map_->LookupCondId(e->src())) << "].";
+ Node* src = e->src();
+ if (IsSwitch(src)) {
+ // Switch node outputs and dependencies are handled separately.
+ TF_RETURN_IF_ERROR(AddSwitch(src));
+ } else {
+ stacks[static_cast<int>(branch)].push_back(src);
+ }
+ }
+ }
+
+ for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
+ int branch_index = static_cast<int>(branch);
+ auto output = bodies_[branch_index].get();
+ auto& stack = stacks[branch_index];
+ VLOG(5) << "In branch: " << Branch_Name(branch) << " "
+ << NodesToString(stack);
+ std::vector<bool> visited(graph->num_node_ids(), false);
+ node_maps_[branch_index].resize(graph->num_node_ids(), nullptr);
+ auto& node_map = node_maps_[branch_index];
+
+ while (!stack.empty()) {
+ Node* n = stack.back();
+ stack.pop_back();
+
+ if (visited.at(n->id())) continue;
+ visited[n->id()] = true;
+
+      // Verify output edges and record control edges exiting the scope.
+ for (const Edge* e : n->out_edges()) {
+ Node* dst = e->dst();
+ if (IsMerge(dst)) continue;
+ Node* src = e->src();
+
+ auto dst_id = state_map_->LookupCondId(dst);
+ auto src_id = state_map_->LookupCondId(src);
+ if (dst_id != src_id) {
+ if (e->IsControlEdge()) {
+ external_control_outputs_.push_back(e->src());
+ } else {
+            // Constants are treated specially to work around the case of
+ // non-dominated constant nodes.
+ if (!IsConstant(src)) {
+ // TODO(b/78882471): A node that feeds into two different
+ // CondState is not necessarily an error so log a warning for now
+ // but revisit to improve the testing to enable making this an
+ // error.
+ LOG(WARNING) << errors::InvalidArgument(
+ "Graph contains node ", FormatNodeForError(*src),
+ " that feeds into node ", FormatNodeForError(*dst),
+ " but these nodes are in different control contexts (",
+ DebugString(src_id), " vs ", DebugString(dst_id),
+                    ") (detected during out edge testing)");
+ }
+ }
+ }
+ }
+
+      // Copy incoming edges to the dst node. Iterate over a copy of the
+      // edges, as they could be mutated during iteration.
+ std::vector<const Edge*> in_edges(n->in_edges().begin(),
+ n->in_edges().end());
+ for (const Edge* e : in_edges) {
+ Node* src = e->src();
+ // Skip src/dst node.
+ if (!src->IsOp()) continue;
+
+ Node* dst = e->dst();
+ if (IsSwitch(src)) {
+ // Switch node outputs and dependencies are handled separately.
+ TF_RETURN_IF_ERROR(AddSwitch(src));
+ continue;
+ }
+
+ // Verify input is from the same context.
+ auto src_id = state_map_->LookupCondId(src);
+ auto dst_id = state_map_->LookupCondId(dst);
+ if (IsMerge(dst) || src_id == dst_id) {
+ // TODO(jpienaar): The merge case can be more strict.
+ if (node_map.at(src->id()) == nullptr) {
+ node_map.at(src->id()) = output->CopyNode(src);
+ stack.push_back(src);
+ }
+ } else if (e->IsControlEdge()) {
+ external_control_inputs_.push_back(src);
+ } else {
+          // This shouldn't happen: it means an external data input enters
+          // the body without going through a switch node. Work around it:
+          // * for constant nodes, copy them;
+          // * for non-constant nodes, insert a switch along the edge.
+ if (IsConstant(src)) {
+ node_map.at(src->id()) = output->CopyNode(src);
+ } else {
+ StateMap::CondState state = *dst_id;
+ state.erase(predicate_);
+ if (state_map_->GetCondId(state) == src_id) {
+ TF_RETURN_IF_ERROR(AddSwitchNodeAlongEdge(e, branch, graph));
+ continue;
+ } else {
+ return errors::InvalidArgument(
+ "Graph contains node ", FormatNodeForError(*src),
+ " that feeds into node ", FormatNodeForError(*dst),
+ " but these nodes are in different control contexts (",
+ DebugString(src_id), " vs ", DebugString(dst_id),
+                ") (detected during in edge testing)");
+ }
+ }
+ }
+
+ Node* src_copy = node_map.at(e->src()->id());
+ int src_output = e->src_output();
+ if (node_map.at(dst->id()) == nullptr) {
+ node_map.at(dst->id()) = output->CopyNode(dst);
+ }
+ Node* dst_copy = node_map.at(e->dst()->id());
+ if (e->IsControlEdge()) {
+ // Skip control inputs from external context.
+ if (src_copy != nullptr) output->AddControlEdge(src_copy, dst_copy);
+ } else {
+ output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
+ }
+ }
+ }
+ }
+
+ // Build return values from the merge nodes.
+ int index = 0;
+ for (Node* m : merges_) {
+ for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
+ int branch_index = static_cast<int>(branch);
+ auto& node_map = node_maps_[branch_index];
+ auto output = bodies_[branch_index].get();
+ TF_ASSIGN_OR_RETURN(node_map[m->id()],
+ BuildRetvalNode(output, m->output_type(0), index));
+ }
+ ++index;
+
+    // Connect each input of the merge to the retval, unless it is a
+    // Switch node, which is handled separately.
+ for (auto e : m->in_edges()) {
+ if (e->IsControlEdge()) continue;
+ int branch_index = static_cast<int>(find_branch(e));
+ auto& node_map = node_maps_[branch_index];
+ auto output = bodies_[branch_index].get();
+ Node* in = e->src();
+ if (!IsSwitch(in)) {
+ if (node_map.at(in->id()) == nullptr) {
+ node_map[in->id()] = output->CopyNode(in);
+ }
+ output->AddEdge(node_map[in->id()], e->src_output(),
+ node_map.at(m->id()), 0);
+ }
+ }
+ }
+ return Status::OK();
+}
+
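The heart of ExtractBodies is a backward walk from each merge's data inputs: nodes that share the merge's cond state are copied into the branch body, while Switch nodes form the boundary and later become _Arg nodes. A toy sketch of that carve-out over hypothetical graph structures:

    def extract_branch(merge_inputs, in_edges, is_switch, same_state):
        # Walk backwards; copy same-state nodes, stop at Switch boundaries.
        body, args, stack = set(), set(), list(merge_inputs)
        while stack:
            n = stack.pop()
            if n in body:
                continue
            if is_switch(n):
                args.add(n)  # handled separately, becomes an argument
                continue
            body.add(n)
            for src in in_edges.get(n, []):
                if is_switch(src) or same_state(src, n):
                    stack.append(src)
        return body, args

    in_edges = {"add": ["mul"], "mul": ["switch_t"]}
    print(extract_branch(["add"], in_edges,
                         is_switch=lambda n: n.startswith("switch"),
                         same_state=lambda a, b: True))
    # body {'add', 'mul'}, boundary {'switch_t'}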
+Status Conditional::BuildIfNode(Graph* graph,
+ FunctionLibraryDefinition* library) {
+ VLOG(2) << "Build cond function for " << name();
+ NodeDefBuilder builder(name(), "If");
+ const string branch_name[] = {"else_branch", "then_branch"};
+ for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
+ int branch_index = static_cast<int>(branch);
+ static std::atomic<int64> sequence_num(0LL);
+ int64 id = ++sequence_num;
+
+ NameAttrList body_name;
+ body_name.set_name(
+ absl::StrCat("_functionalize_if_", branch_name[branch_index], "_", id));
+
+ VLOG(3) << "FunctionalizeControlFlow (" << branch_name[branch_index]
+ << "): "
+ << dump_graph::DumpGraphToFile(
+ "functionalize_cond_body_" + branch_name[branch_index],
+ *bodies_[branch_index], nullptr);
+
+ FunctionDef body_fdef;
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(*bodies_[branch_index],
+ body_name.name(), &body_fdef));
+ TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
+ builder.Attr(branch_name[branch_index], body_name);
+ }
+
+ VLOG(3) << "Build input type";
+ std::vector<NodeDefBuilder::NodeOut> inputs;
+ DataTypeVector in_arg_types;
+ for (auto& kv : cond_arg_nodes_) {
+ bool inserted = false;
+ for (const Node* arg : kv.switches) {
+ const Edge* in_edge;
+ TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
+ if (in_edge->IsControlEdge()) {
+ builder.ControlInput(in_edge->src()->name());
+ } else {
+ if (!inserted) {
+ DataType dtype = arg->input_type(0);
+ inputs.emplace_back(NodeDefBuilder::NodeOut(
+ in_edge->src()->name(), in_edge->src_output(), dtype));
+ in_arg_types.push_back(dtype);
+ inserted = true;
+ }
+ }
+ }
+ }
+ builder.Attr("Tin", in_arg_types);
+
+ DataTypeVector out_type;
+ for (const Node* merge : merges_) {
+ DataType dtype = merge->output_type(0);
+ out_type.push_back(dtype);
+ }
+ builder.Attr("Tout", out_type);
+ VLOG(3) << "Build output type: " << DataTypeVectorString(out_type);
+
+ builder.Attr("Tcond", DT_BOOL);
+ builder.Device(predicate_.node->assigned_device_name());
+ // Conditional should be the first input ...
+ builder.Input(NodeDefBuilder::NodeOut(predicate_.node->name(),
+ predicate_.index,
+ predicate_.node->output_type(0)));
+ // ... followed by the other inputs.
+ builder.Input(inputs);
+
+ VLOG(3) << "Build If node";
+ NodeDef if_def;
+ TF_RETURN_IF_ERROR(builder.Finalize(&if_def));
+ TF_ASSIGN_OR_RETURN(if_node_,
+ parent_->AddIfNode(if_def, *merges_.begin(), predicate_));
+
+ return Status::OK();
+}
+
+Status Conditional::AddInputEdges(Graph* graph) {
+ VLOG(2) << "AddInputEdges for " << if_node_->name();
+ int index = 0;
+ // Add predicate input.
+ graph->AddEdge(const_cast<Node*>(predicate_.node), predicate_.index, if_node_,
+ index++);
+ // Add function body inputs.
+ for (auto& arg : cond_arg_nodes_) {
+ if (arg.src_output == Graph::kControlSlot) {
+ graph->AddControlEdge(arg.src, if_node_);
+ } else {
+ graph->AddEdge(arg.src, arg.src_output, if_node_, index++);
+ }
+ }
+ for (Node* n : external_control_inputs_) {
+ graph->AddControlEdge(n, if_node_);
+ }
+ return Status::OK();
+}
+
+Status Conditional::AddOutputEdges(Graph* graph) {
+ VLOG(2) << "AddOutputEdges for " << if_node_->name();
+ int i = 0;
+ for (Node* node : merges_) {
+ TF_RETURN_IF_ERROR(parent_->AddIdentityNode(node, if_node_, i));
+ std::vector<const Edge*> edges(node->out_edges().begin(),
+ node->out_edges().end());
+ for (const Edge* edge : edges) {
+ Node* dst = edge->dst();
+ int dst_input = edge->dst_input();
+ if (edge->src_output() > 0) {
+ return errors::Unimplemented("Output of index (", edge->src_output(),
+ ") of merge node ",
+ FormatNodeForError(*node));
+ }
+
+ bool control_edge = edge->IsControlEdge();
+ graph->RemoveEdge(edge);
+ if (control_edge) {
+ graph->AddControlEdge(if_node_, dst);
+ } else {
+ graph->AddEdge(if_node_, i, dst, dst_input);
+ }
+ }
+ ++i;
+ }
+ for (Node* n : external_control_outputs_) {
+ graph->AddControlEdge(if_node_, n);
+ }
+
+ return Status::OK();
+}
+
+Status Conditional::BuildAndReplace(Graph* graph,
+ FunctionLibraryDefinition* library) {
+ VLOG(1) << "Build If and replace merge nodes "
+ << NodesToString(this->merges_);
+ if (replaced_) return Status::OK();
+
+ TF_RETURN_IF_ERROR(ExtractBodies(graph));
+ TF_RETURN_IF_ERROR(BuildArgumentNodes());
+
+ if (VLOG_IS_ON(3)) {
+ LOG(INFO) << "Extracted bodies:";
+ for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
+ int branch_index = static_cast<int>(branch);
+ auto output = bodies_[branch_index].get();
+ LOG(INFO) << Branch_Name(branch) << ": "
+ << DebugString(output->ToGraphDefDebug());
+ }
+ }
+
+ TF_RETURN_IF_ERROR(BuildIfNode(graph, library));
+ TF_RETURN_IF_ERROR(AddInputEdges(graph));
+ TF_RETURN_IF_ERROR(AddOutputEdges(graph));
+ TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_));
+ for (Node* m : merges_) state_map_->MarkDead(m);
+
+ // Check that the if_node doesn't feed into itself.
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ CheckNodeNotInCycle(if_node_, graph->num_node_ids()),
+ "Converting to If failed.");
+
+ replaced_ = true;
+ return Status::OK();
+}
+
+string Conditional::name() const {
+ CHECK(!merges_.empty());
+ return absl::StrCat((*merges_.begin())->name(), "_if");
+}
+
+Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node,
+ int port) {
+ Node* id;
+ TF_RETURN_IF_ERROR(NodeBuilder(replacee->name(), "Identity")
+ .Input(if_node, port)
+ .Finalize(graph_, &id));
+ state_map_.ResetCondId(id, state_map_.LookupCondId(if_node));
+ state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node));
+ return Status::OK();
+}
+
+StatusOr<Node*> FunctionalizeCond::AddIfNode(const NodeDef& def,
+ const Node* replacee,
+ const OutputTensor& predicate) {
+ Status status;
+ Node* ret = graph_->AddNode(def, &status);
+ TF_RETURN_IF_ERROR(status);
+ VLOG(1) << "Adding If for " << replacee->name();
+ StateMap::CondId id = state_map_.LookupCondId(replacee);
+ if (id) {
+ StateMap::CondState state = *id;
+ state.erase(predicate);
+ state_map_.ResetCondId(ret, state_map_.GetCondId(state));
+ } else {
+ state_map_.ResetCondId(ret, nullptr);
+ }
+
+ state_map_.ResetAncestorId(ret, state_map_.LookupAncestorId(replacee));
+
+ return ret;
+}
+
+Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) {
+ VLOG(2) << "Propagating update state for " << replacee->name() << " "
+ << state_map_.CondStateToString(replacee);
+ // Redo topological sort as the order could have changed.
+ // TODO(jpienaar): The original topological order could also be updated
+ // dynamically if needed.
+ std::vector<Node*> rev_topo_order;
+ GetPostOrder(*graph_, &rev_topo_order);
+
+ // All the outputs of the new node could potentially be updated.
+ std::unordered_set<Node*> changed;
+ for (auto n : replacee->out_nodes())
+ if (n->IsOp()) changed.insert(n);
+
+ // Iterate through the changed/possibly changed nodes in topological order.
+ for (auto it = rev_topo_order.rbegin();
+ it != rev_topo_order.rend() && !changed.empty(); ++it) {
+ if (changed.find(*it) != changed.end()) {
+ // Update the node state.
+ Node* n = *it;
+ StateMap::CondId old_state = state_map_.LookupCondId(n);
+ state_map_.ResetCondId(n, nullptr);
+ TF_RETURN_IF_ERROR(DetermineCondState(n));
+ if (state_map_.LookupCondId(n) != old_state) {
+ for (auto out : n->out_nodes())
+ if (out->IsOp()) changed.insert(out);
+ }
+ changed.erase(n);
+ }
+ }
+ return Status::OK();
+}
+
+// Returns the most restrictive branch of two branches or neither. This is the
+// meet operator of the BranchType lattice.
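+// As a sketch of the lattice behaviour (each equality follows directly from
+// the cases below):
+//   MeetBranch(kThenBranch, kThenBranch) == kThenBranch  (equal branches)
+//   MeetBranch(kNeither, kElseBranch)    == kElseBranch  (kNeither yields)
+//   MeetBranch(kBoth, kThenBranch)       == kThenBranch  (kBoth yields)
+//   MeetBranch(kThenBranch, kElseBranch) == kNeither     (no common branch)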
+BranchType MeetBranch(const BranchType& lhs, const BranchType& rhs) {
+ if (lhs == rhs) return lhs;
+ if (lhs == BranchType::kNeither) return rhs;
+ if (rhs == BranchType::kNeither) return lhs;
+ if (lhs == BranchType::kBoth) return rhs;
+ if (rhs == BranchType::kBoth) return lhs;
+ return BranchType::kNeither;
+}
+
+BranchType StateMap::FindBranchOf(CondId id, OutputTensor predicate) const {
+ if (IsEmpty(id)) return BranchType::kNeither;
+ const CondState& nodes = *id;
+ auto it = nodes.find(predicate);
+ if (it == nodes.end()) return BranchType::kNeither;
+ return it->second;
+}
+
+StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesNonMerge(
+ StateMap::CondId src, StateMap::CondId dst) {
+ VLOG(5) << "Joining src=" << DebugString(src) << " [" << src
+ << "] and dst=" << DebugString(dst) << " [" << dst << "]";
+
+ if (state_map_.IsEmpty(dst) || state_map_.IsDead(src)) return src;
+ if (state_map_.IsDead(dst) || state_map_.IsEmpty(src)) return dst;
+
+ // Nothing to do if the CondState is the same.
+ if (src == dst) return src;
+
+ StateMap::CondState both = *src;
+ for (const auto& kv : *dst) {
+ auto it = both.find(kv.first);
+ if (it == both.end()) {
+ both.insert(kv);
+ } else {
+ if (it->second != kv.second) {
+ return errors::InvalidArgument(
+ "Graph contains node with inputs predicated on incompatible "
+ "predicates: ",
+ DebugString(src), " and ", DebugString(dst));
+ }
+ }
+ }
+ return state_map_.GetCondId(both);
+}
+
+StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesMerge(
+ Node* merge, StateMap::CondId src, StateMap::CondId dst) {
+ // Determine the flow state when joining two states for a merge
+ // node. Combining the two states for a merge node is effectively performing a
+ // disjunction of the states along the different input edges. For a merge
+ // that can be transformed into an If, the two input paths have to have a
+ // predicate on which they differ (e.g., along one edge predicate `p` has to
+ // hold while on the other it must not). This function first determines this
+ // predicate and then the resultant state is the common path between the two
+ // inputs followed by s(p, both).
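+ // For example (a sketch in {predicate: branch} notation): joining
+ // src = {p: then} and dst = {p: else} at a merge gives a symmetric
+ // difference of two entries that differ only in branch, so the merge is
+ // predicated on `p` and its resulting CondState is the (empty) common
+ // intersection of the two input states.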
+ VLOG(4) << "Joining (for merge) " << DebugString(src) << " and "
+ << DebugString(dst);
+ if (state_map_.IsEmpty(dst)) return src;
+
+ if (state_map_.IsDead(src)) return src;
+ if (state_map_.IsDead(dst)) return dst;
+
+ std::vector<StateMap::CondState::value_type> diff;
+ StateMap::CondState merged;
+ std::set_symmetric_difference(src->begin(), src->end(), dst->begin(),
+ dst->end(), std::back_inserter(diff),
+ CondStateLess());
+ std::set_intersection(src->begin(), src->end(), dst->begin(), dst->end(),
+ std::inserter(merged, merged.begin()), CondStateLess());
+
+ // Update mapping from merge node to predicate.
+ if (diff.size() == 2) {
+ auto pred = diff[0].first;
+ bool different_branches = (diff[0].second != diff[1].second) &&
+ (diff[0].second == BranchType::kThenBranch ||
+ diff[0].second == BranchType::kElseBranch) &&
+ (diff[1].second == BranchType::kThenBranch ||
+ diff[1].second == BranchType::kElseBranch);
+ if (!(pred == diff[1].first) || !different_branches)
+ return errors::InvalidArgument(
+ "Unable to determine predicate for merge node");
+ merge_to_predicate_[merge] = pred;
+ } else {
+ return errors::InvalidArgument(
+ "Merge of two inputs that differ on more than one predicate ",
+ DebugString(src), " and ", DebugString(dst));
+ }
+
+ return state_map_.GetCondId(merged);
+}
+
+StateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) {
+ Node* src = e->src();
+ StateMap::CondId id = state_map_.LookupCondId(e->src());
+
+ // Dead nodes only propagate dead state.
+ if (state_map_.IsDead(id)) return id;
+
+ if (IsSwitch(src)) {
+ StateMap::CondState state;
+ if (id != nullptr) state = *id;
+ OutputTensor predicate;
+ TF_CHECK_OK(GetSwitchPredicate(*src, &predicate));
+ if (!e->IsControlEdge()) {
+ state[predicate] = BranchType(e->src_output());
+ }
+ return state_map_.GetCondId(state);
+ }
+ return id;
+}
+
+Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) {
+ // Only Merge nodes with two inputs are supported, but if this is a redundant
+ // merge, then the dead edge may already have been removed (if due to a
+ // switch) and so the input count would be incorrect.
+ if (state_map_.IsDead(state_map_.LookupCondId(dst))) return Status::OK();
+
+ int data_inputs = 0;
+ for (auto e : dst->in_edges()) {
+ Node* src = e->src();
+ VLOG(5) << "Processing forward flow for merge: " << e->DebugString() << " "
+ << state_map_.CondStateToString(src);
+ if (!src->IsOp()) continue;
+ if (!e->IsControlEdge()) ++data_inputs;
+
+ StateMap::CondId prop = StateAlongEdge(e);
+ auto id_or = JoinCondStatesMerge(dst, prop, state_map_.LookupCondId(dst));
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
+ FormatNodeForError(*dst));
+ state_map_.ResetCondId(dst, id_or.ValueOrDie());
+ }
+
+ // Incomplete Merge nodes are not supported.
+ if (data_inputs != 2) {
+ return errors::Unimplemented(
+ dst->name(), " only has ", data_inputs,
+ " inputs, while only merge nodes with two inputs supported.");
+ }
+ return Status::OK();
+}
+
+Status FunctionalizeCond::DetermineCondStateNonMerge(Node* dst) {
+ // Handle non-merge join.
+ for (auto e : dst->in_edges()) {
+ VLOG(4) << "Processing forward flow for: " << e->DebugString() << " "
+ << state_map_.CondStateToString(dst);
+ Node* src = e->src();
+ if (!src->IsOp()) continue;
+
+ // Join the propagated state with the node's current state.
+ StateMap::CondId prop = StateAlongEdge(e);
+ auto id_or = JoinCondStatesNonMerge(prop, state_map_.LookupCondId(dst));
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
+ FormatNodeForError(*dst));
+ state_map_.ResetCondId(dst, id_or.ValueOrDie());
+ }
+ return Status::OK();
+}
+
+Status FunctionalizeCond::RemoveRedundantMerge(Node* node) {
+ // Handle redundant merge nodes. A merge node is considered redundant if
+ // one input edge is dead while the other has a value.
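+ // For example (a sketch): if one input comes from the untaken output of an
+ // already-resolved switch, the merge simply forwards the value of its
+ // other, live input to all consumers.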
+ if (!state_map_.IsDead(state_map_.LookupCondId(node))) return Status::OK();
+
+ const Edge* non_dead_edge = nullptr;
+ for (auto e : node->in_edges()) {
+ if (e->IsControlEdge()) continue;
+ Node* src = e->src();
+
+ // Handle merge with dead state.
+ const auto& src_id = state_map_.LookupCondId(src);
+ if (!state_map_.IsDead(src_id)) {
+ non_dead_edge = e;
+ break;
+ }
+ }
+
+ if (non_dead_edge == nullptr) {
+ return errors::InvalidArgument("Merge node ", FormatNodeForError(*node),
+ " has no non-dead inputs.");
+ }
+ state_map_.MarkDead(node);
+ delete_nodes_.push_back(node->id());
+ VLOG(5) << "removing redundant merge: " << node->name();
+ while (!node->out_edges().empty()) {
+ const Edge* oe = *node->out_edges().begin();
+ Node* dst_node = oe->dst();
+ int dst_port = oe->dst_input();
+ graph_->RemoveEdge(oe);
+ graph_->AddEdge(non_dead_edge->src(),
+ dst_port == Graph::kControlSlot
+ ? Graph::kControlSlot
+ : non_dead_edge->src_output(),
+ dst_node, dst_port);
+ }
+ return Status::OK();
+}
+
+Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) {
+ // Handle redundant switch nodes. A switch node is considered redundant if
+ // the predicate of the switch already holds on the current branch. E.g., if
+ // p is the predicate of the switch but p is already known to hold on this
+ // branch, then the switch can be removed and the dead state propagated
+ // along the untaken output. The predicate check is based on the exact
+ // predicate (rather than boolean equivalence) and is aimed at redundant
+ // currently generated by gradient code.
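+ // For example (a sketch): a switch on predicate `p` whose own CondState
+ // already contains {p: kThenBranch} always forwards its input along the
+ // then output; consumers of its else output are marked dead below.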
+ StateMap::CondId dst_id = state_map_.LookupCondId(node);
+ if (state_map_.IsDead(dst_id)) return Status::OK();
+
+ BranchType b;
+ OutputTensor pred;
+ TF_RETURN_IF_ERROR(GetSwitchPredicate(*node, &pred));
+
+ // Determine if we are already on a branch where the switch predicate is
+ // true/false. Consider both the data and predicate to determine if the
+ // node is redundant (skipping over Identity nodes).
+ b = state_map_.FindBranchOf(dst_id, pred);
+ if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) {
+ OutputTensor val;
+ const Edge* e;
+ TF_RETURN_IF_ERROR(node->input_edge(0, &e));
+ val = OutputTensor(e->src(), e->src_output());
+ while (IsIdentity(val.node)) {
+ TF_RETURN_IF_ERROR(val.node->input_edge(0, &e));
+ val = OutputTensor(e->src(), e->src_output());
+ }
+ b = state_map_.FindBranchOf(dst_id, val);
+ if (b != BranchType::kThenBranch && b != BranchType::kElseBranch)
+ return Status::OK();
+ }
+
+ VLOG(5) << "Redundant switch " << node->name() << " " << Branch_Name(b) << " "
+ << DebugString(dst_id);
+ const Edge* value_edge;
+ TF_RETURN_IF_ERROR(node->input_edge(0, &value_edge));
+ Node* val_node = value_edge->src();
+ int val_port = value_edge->src_output();
+ while (!node->out_edges().empty()) {
+ auto e = *node->out_edges().begin();
+ Node* dst_node = e->dst();
+ int dst_input = e->dst_input();
+ int switch_branch = e->src_output();
+ graph_->RemoveEdge(e);
+ if (switch_branch == Graph::kControlSlot) {
+ if (IsMerge(dst_node)) {
+ auto id_or = JoinCondStatesMerge(dst_node, dst_id,
+ state_map_.LookupCondId(dst_node));
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
+ FormatNodeForError(*dst_node));
+ state_map_.ResetCondId(dst_node, id_or.ValueOrDie());
+ } else {
+ auto id_or =
+ JoinCondStatesNonMerge(dst_id, state_map_.LookupCondId(dst_node));
+ TF_RETURN_IF_ERROR(id_or.status());
+ state_map_.ResetCondId(dst_node, id_or.ValueOrDie());
+ }
+ } else if (BranchType(switch_branch) != b) {
+ state_map_.MarkDead(dst_node);
+ delete_nodes_.push_back(dst_node->id());
+ continue;
+ }
+ graph_->AddEdge(
+ val_node,
+ switch_branch == Graph::kControlSlot ? Graph::kControlSlot : val_port,
+ dst_node, dst_input);
+ }
+ return Status::OK();
+}
+
+Status FunctionalizeCond::DetermineStates(std::vector<Node*> rev_topo_order) {
+ // Determine the states of the nodes, visited in reverse topological order.
+ for (auto it = rev_topo_order.rbegin(); it != rev_topo_order.rend(); ++it) {
+ Node* dst = *it;
+ TF_RETURN_IF_ERROR(DetermineCondState(dst));
+ TF_RETURN_IF_ERROR(DetermineAncestorState(dst));
+ if (IsSwitch(dst)) TF_RETURN_IF_ERROR(RemoveRedundantSwitch(dst));
+ if (IsMerge(dst)) TF_RETURN_IF_ERROR(RemoveRedundantMerge(dst));
+
+ VLOG(5) << dst->name() << " :: " << state_map_.CondStateToString(dst)
+ << " @ " << state_map_.AncestorStateToString(dst);
+ if (VLOG_IS_ON(10)) DumpGraphWithCondState("cond_it");
+ }
+ return Status::OK();
+}
+
+Status FunctionalizeCond::DetermineAncestorState(Node* dst) {
+ StateMap::AncestorId id = nullptr;
+ StateMap::AncestorState state;
+
+ auto insert = [&](StateMap::AncestorId id, Node* src) {
+ auto other_id = state_map_.LookupAncestorId(src);
+ if (other_id != id && other_id != nullptr) {
+ state.insert(other_id->begin(), other_id->end());
+ }
+ if (IsSwitch(src) || IsMerge(src)) {
+ state.insert(src);
+ }
+ return state_map_.GetAncestorId(state);
+ };
+
+ // Compute the union of all the switch/merge nodes that affect the input of
+ // dst.
+ for (auto e : dst->in_edges()) {
+ Node* src = e->src();
+ id = insert(id, src);
+ }
+ state_map_.ResetAncestorId(dst, id);
+ return Status::OK();
+}
+
+void FunctionalizeCond::DeleteReachableNodes() {
+ // Delete all nodes that have been extracted or are reachable from
+ // deleted/dead nodes. The incoming and outgoing edges should have already
+ // been removed.
+ std::vector<bool> deleted(graph_->num_node_ids(), false);
+ // Don't try to delete source or sink nodes.
+ deleted[graph_->kSourceId] = true;
+ deleted[graph_->kSinkId] = true;
+ while (!delete_nodes_.empty()) {
+ int d_id = delete_nodes_.front();
+ delete_nodes_.pop_front();
+ if (deleted[d_id]) continue;
+ Node* d = graph_->FindNodeId(d_id);
+ // Switch and Merge nodes could have been deleted already.
+ if (d == nullptr) continue;
+ for (const Edge* e : d->out_edges()) {
+ delete_nodes_.push_back(e->dst()->id());
+ }
+ deleted[d_id] = true;
+ graph_->RemoveNode(d);
+ }
+}
+
+void FunctionalizeCond::SortMergeNodes(std::vector<Node*>* merge_order) {
+ // Sort merge nodes by nesting depth.
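+ // For example (a sketch): merges whose CondStates have sizes {1, 2, 2}
+ // come out with the two depth-2 (innermost) merges first, followed by the
+ // depth-1 merge.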
+ using sort_pair = std::pair<int, Node*>;
+ std::vector<sort_pair> inner_to_outer_merge_order;
+ inner_to_outer_merge_order.reserve(merge_order->size());
+ for (auto it = merge_order->rbegin(); it != merge_order->rend(); ++it) {
+ Node* merge = *it;
+ StateMap::CondId id = state_map_.LookupCondId(merge);
+ int depth = id != nullptr ? id->size() : 0;
+ inner_to_outer_merge_order.emplace_back(depth, merge);
+ }
+ std::stable_sort(
+ inner_to_outer_merge_order.begin(), inner_to_outer_merge_order.end(),
+ [](sort_pair lhs, sort_pair rhs) { return lhs.first > rhs.first; });
+ merge_order->clear();
+ for (sort_pair t : inner_to_outer_merge_order) {
+ merge_order->push_back(t.second);
+ }
+}
+
+Status FunctionalizeCond::FunctionalizeInternal() {
+ // The general approach for converting a tf.cond (as lowered via switch/merge
+ // nodes) to a functional if is as follows:
+ // 1. Determine the topological order and collect all the switch and merge
+ // nodes in the graph;
+ // 2. Compute the predicates and dominance structure for all the nodes in the
+ // graph - this includes which predicate must be true for an op to execute
+ // (predicate values are considered directly rather than attempting to
+ // determine deeper equivalence). We shall refer to this structure as the
+ // CondState;
+ // 3. Sort the merge nodes by nesting depth;
+ // 4. Extract merge nodes together that have the same CondState and
+ // AncestorState, from the innermost to the outermost, into If ops;
+ // Note: In the above only nodes that feed into a merge node will be
+ // considered for functionalization (see the sketch below).
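+ //
+ // As a sketch (illustrative only), `tf.cond(p, fn1, fn2)` is lowered to a
+ // pattern like
+ //
+ //            Switch(v, p)
+ //           0/         \1
+ //      fn2 ops        fn1 ops
+ //           \          /
+ //              Merge
+ //
+ // which this pass rewrites to `If(p, then_branch=fn1, else_branch=fn2)`,
+ // with the Merge replaced by an Identity fed from the If output.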
+
+ // Perform a DFS over the graph and
+ // * Determine the reverse topological order of the nodes (there should be no
+ // cycles at this point so the post-order numbering corresponds to the
+ // reverse topological sorting);
+ // * Record the switch and merge nodes (also in reverse topological order);
+ std::vector<Node*> rev_topo_order;
+ std::vector<int> switch_ids;
+ std::vector<Node*> merge_order;
+ DFS(*graph_, nullptr, [&](Node* n) {
+ if (IsSwitch(n)) {
+ switch_ids.push_back(n->id());
+ }
+ if (IsMerge(n)) {
+ merge_order.push_back(n);
+ }
+ if (n->IsOp()) {
+ rev_topo_order.push_back(n);
+ }
+ });
+
+ // No merges to functionalize.
+ if (merge_order.empty()) {
+ // No merges means no switch values are consumed (we only consider values
+ // fetchable as the output of a merge), so remove all switch nodes.
+ for (auto it = switch_ids.begin(); it != switch_ids.end(); ++it) {
+ graph_->RemoveNode(graph_->FindNodeId(*it));
+ }
+ return Status::OK();
+ }
+
+ TF_RETURN_IF_ERROR(DetermineStates(std::move(rev_topo_order)));
+ if (VLOG_IS_ON(4)) DumpGraphWithCondState("cond_id");
+
+ // Sort the merge nodes from innermost outwards.
+ SortMergeNodes(&merge_order);
+
+ // Cluster merge nodes by CondId and AncestorId in order of nesting.
+ using ClusterPair = std::pair<StateMap::CondId, StateMap::AncestorId>;
+ std::deque<std::vector<Node*>> merge_clusters;
+ std::map<ClusterPair, int> merge_cluster_index;
+ for (Node* merge : merge_order) {
+ auto cond_id = state_map_.LookupCondId(merge);
+ if (state_map_.IsDead(cond_id)) continue;
+
+ ClusterPair key =
+ std::make_pair(cond_id, state_map_.LookupAncestorId(merge));
+ auto idx = merge_cluster_index.find(key);
+ if (idx == merge_cluster_index.end()) {
+ merge_cluster_index[key] = merge_clusters.size();
+ merge_clusters.push_back({merge});
+ } else {
+ merge_clusters[idx->second].emplace_back(merge);
+ }
+ }
+
+ // Extract the conditionals from innermost to outermost. Extracting from
+ // innermost to outermost enables the extraction pass to stop once it
+ // encounters a Switch node instead of having to keep track of Switch/Merge
+ // nodes seen.
+ for (const auto& cluster : merge_clusters) {
+ // Construct a Conditional with the predicate of the merge.
+ Conditional cond(merge_to_predicate_.at(cluster.front()), this,
+ &state_map_);
+ for (Node* merge : cluster) TF_RETURN_IF_ERROR(cond.AddMerge(merge));
+ TF_RETURN_IF_ERROR(cond.BuildAndReplace(graph_, library_));
+
+ if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract");
+ }
+
+ // All remaining Switch nodes are not reachable from a Merge node and are
+ // removed. This accounts for dead Switch nodes.
+ for (int s_id : switch_ids) delete_nodes_.push_back(s_id);
+ for (Node* m : merge_order) delete_nodes_.push_back(m->id());
+ DeleteReachableNodes();
+
+ return Status::OK();
+}
+
+void FunctionalizeCond::DumpGraphWithCondState(const string& name) {
+ const char* const kCondGroupDebugAttr = "_XlaFunctionalizeCondGroup";
+
+ for (Node* n : graph_->nodes()) {
+ n->ClearAttr(kCondGroupDebugAttr);
+ n->AddAttr(kCondGroupDebugAttr,
+ absl::StrCat(state_map_.CondStateToString(n), "_",
+ state_map_.AncestorStateToString(n)));
+ }
+ LOG(INFO) << "FunctionalizeControlFlow (" << name << "): "
+ << dump_graph::DumpGraphToFile(absl::StrCat("functionalize_", name),
+ *graph_, library_);
+}
+
+Status FunctionalizeCond::Functionalize(Graph* graph,
+ FunctionLibraryDefinition* library) {
+ VLOG(1) << "FunctionalizeCond::Functionalize";
+ FunctionalizeCond fc(graph, library);
+ return fc.FunctionalizeInternal();
+}
+
+} // namespace functionalize_cond
+
+Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) {
+ // FunctionalizeControlFlow is invoked for every function, so the loops'
+ // bodies and conditionals that were extracted into functions will be handled
+ // in successive invocations.
+ return functionalize_cond::FunctionalizeCond::Functionalize(graph, library);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h
new file mode 100644
index 0000000000..28301150ea
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.h
@@ -0,0 +1,248 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_
+#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_
+
+#include <deque>
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+// Functionalize all the switch-merge nodes of a loop-free graph into If
+// nodes. That is, attempt to transform every remaining switch and merge node
+// in the graph into If nodes.
+// Precondition: All while loops have been removed from graph.
+Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library);
+
+// Internal functions/classes exposed for testing purposes.
+namespace functionalize_cond {
+
+// All nodes are assumed to be either in no branch, then branch, else branch,
+// or both branches (such as merge nodes).
+// The code below relies on Else and Then being 0 and 1 (corresponding to the
+// switch outputs). Both and Neither are arbitrary.
+enum class BranchType {
+ kElseBranch = 0,
+ kThenBranch = 1,
+ kBoth = 2,
+ kNeither = 3,
+};
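+
+// Because of this correspondence, the branch taken along a Switch output edge
+// `e` can be recovered directly (as done in functionalize_cond.cc):
+//   BranchType branch = BranchType(e->src_output());  // 0 => else, 1 => then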
+
+// StateMap is responsible for mapping from each graph Node to
+// * a CondState, where each CondState is a map from predicate to branch
+// (i.e., what predicates have to hold or not hold);
+// * an AncestorState, where each AncestorState is a set of switch/merge nodes
+// that are ancestors of the node in the graph.
+// For efficiency, this class interns the CondState (AncestorState), so that
+// CondState (AncestorState) equality comparisons are simply pointer
+// comparisons.
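+// For example (a sketch): two nodes guarded by exactly the same predicates
+// share one interned CondState, so
+//   state_map.LookupCondId(a) == state_map.LookupCondId(b)
+// tests state equality without comparing the underlying maps element-wise.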
+class StateMap {
+ public:
+ explicit StateMap(Graph* graph);
+
+ // Compare two OutputTensors by (node id, index).
+ struct OutputTensorLess {
+ bool operator()(const OutputTensor& lhs, const OutputTensor& rhs) const;
+ };
+
+ // A node in the graph is executed when multiple conditions hold. Keep track
+ // of the predicates that must hold for a node to execute.
+ using CondState = std::map<OutputTensor, BranchType, OutputTensorLess>;
+
+ // Every unique ID is mapped to a CondState.
+ using CondId = const CondState*;
+
+ // Keep track of which switch/merge nodes feed into a node's values.
+ using AncestorState = std::set<Node*>;
+
+ // Every unique ID is mapped to an AncestorState.
+ using AncestorId = const AncestorState*;
+
+ // Returns the CondId for a given node.
+ CondId LookupCondId(const Node* node) const;
+
+ // Returns the unique CondId for CondState.
+ CondId GetCondId(const CondState& state);
+
+ // Resets the CondId for a given node.
+ void ResetCondId(const Node* node, CondId id);
+
+ // Returns the AncestorId for a given node.
+ AncestorId LookupAncestorId(const Node* node) const;
+
+ // Returns the unique AncestorId for an AncestorState.
+ AncestorId GetAncestorId(const AncestorState& state);
+
+ // Resets the AncestorId for a given node.
+ void ResetAncestorId(const Node* node, AncestorId id);
+
+ // Returns the CondState for a Node.
+ // REQUIRES: node has a non-empty CondState.
+ const CondState& LookupState(const Node* node) const;
+
+ // Marks `node` as dead.
+ void MarkDead(const Node* node);
+
+ // Determines the branch taken by `predicate` in the CondState `id`.
+ BranchType FindBranchOf(CondId id, OutputTensor predicate) const;
+
+ // Returns textual representation of node's CondState.
+ string CondStateToString(const Node* node) const;
+ string CondStateToString(CondId id) const;
+
+ // Returns textual representation of node's AncestorState.
+ string AncestorStateToString(const Node* node) const;
+
+ // Returns whether the cond state is the dead state.
+ bool IsDead(CondId id) const;
+
+ // Returns whether the cond state is the empty state.
+ bool IsEmpty(CondId id) const;
+
+ private:
+ // Hash for CondState and AncestorState.
+ struct Hash {
+ size_t operator()(const CondState& map) const;
+ size_t operator()(const AncestorState& map) const;
+ };
+
+ // Set to keep track of unique CondStates.
+ // Pointers to the entries in the unordered set are used as identifiers:
+ // unordered_set guarantees that the pointers remain the same.
+ std::unordered_set<CondState, Hash> condstate_set_;
+
+ // Mapping from Node id to CondId.
+ std::vector<CondId> node_to_condid_map_;
+
+ // Track the CondId for newly inserted nodes. We use a vector to quickly map
+ // from Node id in the original graph to the CondId, but there will be nodes
+ // added to the original graph (such as If nodes) whose CondState needs to be
+ // tracked too.
+ std::unordered_map<int, CondId> added_node_condid_mapping_;
+
+ // AncestorId variants of the CondId members.
+ std::unordered_set<AncestorState, Hash> ancestorstate_set_;
+ std::vector<AncestorId> node_to_ancestorid_map_;
+ std::unordered_map<int, AncestorId> added_node_ancestorid_mapping_;
+
+ // Identifier of the dead flow state. The empty flow state is represented with
+ // a nullptr.
+ CondId dead_id_;
+};
+
+// FunctionalizeCond groups all the state used by functionalizing conditionals
+// of the given graph together.
+class FunctionalizeCond {
+ public:
+ // Functionalize all the switch-merge nodes of a loop-free graph into If
+ // nodes. That is, attempt to transform every remaining switch and merge node
+ // in the graph into If nodes.
+ // Precondition: All while loops have been removed from graph.
+ static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library);
+
+ // Builds an Identity node with the same name as the merge that will be
+ // replaced, in case the output is fetched/colocated.
+ Status AddIdentityNode(const Node* replacee, Node* if_node, int port);
+
+ // Adds an If node, defined by `def`, to the graph that will, amongst other
+ // things, replace `replacee` in the graph.
+ xla::StatusOr<Node*> AddIfNode(const NodeDef& def, const Node* replacee,
+ const OutputTensor& predicate);
+
+ // Propagates the state of a newly inserted node.
+ Status PropagateUpdatedState(const Node* replacee);
+
+ // Dump graph with the CondState annotated.
+ void DumpGraphWithCondState(const string& name);
+
+ private:
+ FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library);
+
+ // Performs the actual cond functionalization. Iterates over groups of merge
+ // nodes (linked by common predicates & ancestor IDs), from innermost to
+ // outermost, and extracts them into If nodes.
+ Status FunctionalizeInternal();
+
+ // Returns the forward flow state propagated along edge `e`.
+ // This may modify state_map_.
+ StateMap::CondId StateAlongEdge(const Edge* e);
+
+ // Determines the CondState and AncestorState of all the nodes in the given
+ // vector where the input is expected in reverse topological order.
+ // This populates the state_map_.
+ Status DetermineStates(std::vector<Node*> rev_topo_order);
+
+ // Determines the CondState for a given node using the incoming edges
+ // to the node. Note: it is expected that this node's CondState is only
+ // determined once its inputs' CondStates are.
+ Status DetermineCondState(Node* dst) {
+ if (IsMerge(dst)) return DetermineCondStateMerge(dst);
+ return DetermineCondStateNonMerge(dst);
+ }
+
+ // Helper functions for DetermineCondState.
+ Status DetermineCondStateNonMerge(Node* dst);
+ Status DetermineCondStateMerge(Node* dst);
+
+ // Determines the dst node's CondState by joining the src and dst's CondState,
+ // depending on whether the dst node is a merge or not.
+ // These may modify state_map_.
+ xla::StatusOr<StateMap::CondId> JoinCondStatesMerge(Node* merge,
+ StateMap::CondId src,
+ StateMap::CondId dst);
+ xla::StatusOr<StateMap::CondId> JoinCondStatesNonMerge(StateMap::CondId src,
+ StateMap::CondId dst);
+
+ // Determines which switch/merge nodes are ancestors of this node.
+ Status DetermineAncestorState(Node* dst);
+
+ // Checks if a merge node is redundant and if so removes it from the graph.
+ Status RemoveRedundantMerge(Node* node);
+
+ // Checks if a switch node is redundant and if so removes it from the graph.
+ Status RemoveRedundantSwitch(Node* node);
+
+ // Sorts merge nodes (in reverse topological order) in order of decreasing
+ // nesting depth, so that the innermost merges come first.
+ void SortMergeNodes(std::vector<Node*>* merge_order);
+
+ // Deletes all nodes in `delete_nodes_` and the nodes reachable from them.
+ void DeleteReachableNodes();
+
+ // Member used to intern each CondState as a unique CondId (and each
+ // AncestorState as a unique AncestorId) and to keep track of the
+ // CondState/CondId (AncestorState/AncestorId) per Node.
+ StateMap state_map_;
+
+ // Mapping from merge nodes to predicate.
+ std::unordered_map<Node*, OutputTensor> merge_to_predicate_;
+
+ // Nodes to be deleted.
+ std::deque<int> delete_nodes_;
+
+ FunctionLibraryDefinition* library_;
+ Graph* graph_;
+
+ friend class FunctionalizeCondTest;
+};
+
+} // namespace functionalize_cond
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc
new file mode 100644
index 0000000000..b0aabd63bb
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc
@@ -0,0 +1,106 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests for the cond functionalization pass.
+
+#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace functionalize_cond {
+
+class FunctionalizeCondTest : public ::testing::Test {
+ protected:
+ FunctionalizeCondTest() {
+ graph_.reset(new Graph(OpRegistry::Global()));
+ flib_def_.reset(
+ new FunctionLibraryDefinition(OpRegistry::Global(), fdef_lib_));
+ fc_.reset(new functionalize_cond::FunctionalizeCond(graph_.get(),
+ flib_def_.get()));
+ }
+
+ StateMap::CondId GetUniqueId(const StateMap::CondState& state) {
+ return fc_->state_map_.GetCondId(state);
+ }
+
+ string GetString(const StateMap::CondId id) {
+ return fc_->state_map_.CondStateToString(id);
+ }
+
+ xla::StatusOr<StateMap::CondId> JoinCondStatesNonMerge(StateMap::CondId src,
+ StateMap::CondId dst) {
+ return fc_->JoinCondStatesNonMerge(src, dst);
+ }
+
+ xla::StatusOr<StateMap::CondId> JoinCondStatesMerge(Node* n,
+ StateMap::CondId src,
+ StateMap::CondId dst) {
+ return fc_->JoinCondStatesMerge(n, src, dst);
+ }
+
+ FunctionDefLibrary fdef_lib_;
+ std::unique_ptr<functionalize_cond::FunctionalizeCond> fc_;
+ std::unique_ptr<FunctionLibraryDefinition> flib_def_;
+ std::unique_ptr<Graph> graph_;
+};
+
+namespace {
+
+TEST_F(FunctionalizeCondTest, JoinCondStates) {
+ Tensor pred_tensor(DT_BOOL, TensorShape());
+ pred_tensor.flat<bool>().setZero();
+ Node* pred = test::graph::Constant(graph_.get(), pred_tensor, "pred");
+ Tensor val_tensor(DT_INT32, TensorShape());
+ val_tensor.flat<int>().setZero();
+ Node* val = test::graph::Constant(graph_.get(), val_tensor, "val");
+ Node* m = test::graph::Merge(graph_.get(), val, val);
+
+ StateMap::CondId then_branch;
+ {
+ StateMap::CondState ss;
+ ss.insert(std::make_pair(OutputTensor(pred, 0), BranchType::kThenBranch));
+ then_branch = GetUniqueId(ss);
+ }
+ StateMap::CondId else_branch;
+ {
+ StateMap::CondState ss;
+ ss.insert(std::make_pair(OutputTensor(pred, 0), BranchType::kElseBranch));
+ else_branch = GetUniqueId(ss);
+ }
+
+ // A non-merge op with inputs from the then and else branches.
+ Status status = JoinCondStatesNonMerge(then_branch, else_branch).status();
+ EXPECT_TRUE(errors::IsInvalidArgument(status));
+
+ // Merge between then and else branch.
+ auto joined_or = JoinCondStatesMerge(m, then_branch, else_branch);
+ TF_EXPECT_OK(joined_or.status());
+ StateMap::CondId joined = joined_or.ValueOrDie();
+
+ // Non-merge join between the then branch and the joined (post-merge) state.
+ auto t = JoinCondStatesNonMerge(then_branch, joined);
+ // Note: this is OK in terms of constraint predication.
+ TF_EXPECT_OK(t.status());
+}
+
+} // namespace
+} // namespace functionalize_cond
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 0904778f97..5932be4e52 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -21,1440 +21,24 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "absl/memory/memory.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
+#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
+#include "tensorflow/compiler/tf2xla/functionalize_while.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
-#include "tensorflow/core/lib/gtl/optional.h"
+#include "tensorflow/core/graph/node_builder.h"
namespace tensorflow {
-namespace {
-
-using xla::StatusOr;
-
-const char* const kArgOp = "_Arg";
-const char* const kRetValOp = "_Retval";
-
-// Information about a loop argument.
-struct Arg {
- // Every loop argument has an Enter node.
- Node* enter;
-
- // Is the loop argument a loop-invariant value? Taken from the `is_constant`
- // attribute on the Enter node.
- bool is_loop_invariant;
-
- // If 'is_loop_invariant' is true, the following are all nullptr. Non-constant
- // arguments must have all of the following nodes:
- Node* merge = nullptr;
- Node* switch_node = nullptr;
- Node* next_iteration = nullptr;
- Node* exit = nullptr;
-};
-
-// Information about a loop frame.
-struct Frame {
- string name;
-
- // Pointer to the parent frame. The root frame has a pointer to itself.
- Frame* parent = nullptr;
- int num_children = 0;
-
- // Arguments to this loop.
- std::vector<Arg> args;
-
- // The loop condition of the loop. There should be exactly one loop condition
- // in every loop.
- Node* loop_cond = nullptr;
-
- // Set of nodes that belong to the loop frame.
- std::unordered_set<Node*> nodes;
-};
-
-// Comparison function used for sorting nodes consistently.
-// a) resource variables are last, and
-// b) sort lexicographically by name (for deterministic output).
-struct NodeCmp {
- bool operator()(const Node* lhs, const Node* rhs) const {
- bool lhs_is_resource =
- lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false;
- bool rhs_is_resource =
- rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false;
- return std::tie(lhs_is_resource, lhs->name()) <
- std::tie(rhs_is_resource, rhs->name());
- }
-};
-
-// Returns a textual representation of the names of the nodes in the input.
-template <typename T>
-string NodesToString(const T& nodes) {
- return strings::StrCat("{",
- str_util::Join(nodes, ",",
- [](string* output, const Node* node) {
- strings::StrAppend(output,
- node->name());
- }),
- "}");
-}
-
-// Copies a subgraph from `graph` to `output` by performing a reverse DFS
-// starting at nodes in vector `stack`.
-// `node_map` is a vector indexed by source node ID to dest nodes.
-// Does not traverse into nodes in `node_map`, so by adding nodes to `node_map`
-// before the traversal clients can cut the graph. If a frame is provided (frame
-// != nullptr), then this functions will return an error if the
-// traversal leaves 'frame'; the client must add enough nodes to `node_map` to
-// cut the graph and prevent the traversal from escaping.
-//
-// `squash_src_outputs` contains a bool for each source node ID. If true, then
-// the source output on that node will be replaced by zero when copied. This is
-// used when replacing a Switch node with an _Arg node. The output we are
-// taking from the Switch node was not necessarily the first output, but _Arg
-// nodes only have one output. By adding the Switch node to `squash_src_outputs`
-// we rewrite the src_output of the corresponding edge to be 0.
-Status CopySubgraph(const Graph& graph, const Frame* frame,
- std::vector<Node*> stack,
- const std::vector<bool>& squash_src_outputs,
- std::vector<Node*>* node_map, Graph* output) {
- VLOG(3) << "Stack: " << NodesToString(stack);
- std::vector<bool> visited(graph.num_node_ids(), false);
- while (!stack.empty()) {
- Node* n = stack.back();
- stack.pop_back();
-
- VLOG(5) << "Copying node " << n->name();
-
- if (visited[n->id()]) continue;
- visited[n->id()] = true;
-
- for (const Edge* e : n->in_edges()) {
- Node* src = e->src();
- if (frame != nullptr && frame->nodes.find(src) == frame->nodes.end()) {
- // We traversed out of the loop frame, without encountering a cut node.
- return errors::Internal("Graph traversal of loop frame ", frame->name,
- " escaped frame at ", src->name(),
- " without encountering an argument node.");
- }
- if ((*node_map)[src->id()] == nullptr) {
- (*node_map)[src->id()] = output->CopyNode(src);
- stack.push_back(src);
- }
- Node* src_copy = (*node_map)[e->src()->id()];
- int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge()
- ? 0
- : e->src_output();
- Node* dst_copy = (*node_map)[e->dst()->id()];
- output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
- }
- }
- return Status::OK();
-}
-
-StatusOr<Node*> AddNode(const NodeDef& node_def, Graph* graph) {
- Status status;
- Node* inserted_node = graph->AddNode(node_def, &status);
- if (!status.ok()) {
- return status;
- }
- return inserted_node;
-}
-
-// Check that the graph has no cycle containing the given node.
-Status CheckNoCycleContains(const Node* node, const int num_nodes) {
- std::vector<const Node*> ready;
- ready.push_back(node);
- std::vector<bool> visited(num_nodes);
- while (!ready.empty()) {
- const Node* current_node = ready.back();
- ready.pop_back();
- visited[current_node->id()] = true;
- for (const Edge* out : current_node->out_edges()) {
- if (out->dst() == node) {
- return errors::Internal("Detected a cycle: ", FormatNodeForError(*node),
- "(", node->def().op(), ") feeds into itself.");
- } else if (!visited[out->dst()->id()]) {
- ready.push_back(out->dst());
- }
- }
- }
- return Status::OK();
-}
-
-StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
- NodeDef arg_def;
- NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp);
- builder.Attr("T", type);
- builder.Attr("index", index);
- TF_RETURN_IF_ERROR(builder.Finalize(&arg_def));
- return AddNode(arg_def, graph);
-}
-
-StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index) {
- NodeDef ret_def;
- ret_def.set_op(kRetValOp);
- ret_def.set_name(strings::StrCat(kRetValOp, index));
- AddNodeAttr("T", type, &ret_def);
- AddNodeAttr("index", index, &ret_def);
- return AddNode(ret_def, graph);
-}
-
-// Builds a graph for the loop condition.
-Status BuildLoopCondition(const Graph& graph, Frame* frame,
- std::unique_ptr<Graph>* cond_output) {
- VLOG(2) << "Building loop condition for " << frame->name;
- *cond_output = xla::MakeUnique<Graph>(graph.op_registry());
- Graph* output = cond_output->get();
-
- // Map from nodes in the original graph to the condition graph.
- std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
- std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
-
- // Build one _Arg node for each Enter node.
- for (int i = 0; i < frame->args.size(); ++i) {
- const Arg& arg = frame->args[i];
-
- TF_ASSIGN_OR_RETURN(Node * arg_node,
- BuildArgNode(output, arg.enter->input_type(0), i));
- if (arg.is_loop_invariant) {
- node_map[arg.enter->id()] = arg_node;
- } else {
- node_map[arg.merge->id()] = arg_node;
- }
- }
-
- // Build a Retval node for the loop condition. The LoopCond nodes are always
- // boolean because of the type constraints on the LoopCond op.
- TF_ASSIGN_OR_RETURN(node_map[frame->loop_cond->id()],
- BuildRetvalNode(output, DT_BOOL, 0));
-
- // Performs a reverse DFS, copying nodes and edges to the output graph.
- // The _Arg and _Retval nodes were added unconditionally above, so we are
- // guaranteed to get the correct function signature.
- return CopySubgraph(graph, frame, {frame->loop_cond}, squash_src_outputs,
- &node_map, output);
-}
-
-// Builds a graph for the loop body.
-Status BuildLoopBody(const Graph& graph, Frame* frame,
- DataTypeVector* arg_types,
- std::unique_ptr<Graph>* body_output) {
- VLOG(2) << "Building loop body for " << frame->name;
- *body_output = xla::MakeUnique<Graph>(graph.op_registry());
- Graph* output = body_output->get();
-
- // Map from nodes in the original graph to the condition graph.
- std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
- std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
-
- // Build one _Arg node for each Enter node.
- std::vector<Node*> next_iterations;
- next_iterations.reserve(frame->args.size());
- arg_types->reserve(frame->args.size());
- for (int i = 0; i < frame->args.size(); ++i) {
- const Arg& arg = frame->args[i];
-
- DataType dtype = arg.enter->input_type(0);
- arg_types->push_back(dtype);
-
- TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i));
-
- if (dtype == DT_RESOURCE) {
- // The convention of the XLA bridge is that resource variable arguments
- // are only inputs to the loop body and have no corresponding output.
- // TODO(b/37741920): change the convention so that DT_RESOURCE variables
- // are both inputs and outputs, and then remove this case.
- TF_RET_CHECK(arg.is_loop_invariant);
- node_map[arg.enter->id()] = arg_node;
- } else {
- TF_ASSIGN_OR_RETURN(Node * retval_node,
- BuildRetvalNode(output, dtype, i));
-
- if (arg.is_loop_invariant) {
- // Argument is loop-invariant. Forward it from the Arg to the Retval.
- node_map[arg.enter->id()] = arg_node;
- output->AddEdge(arg_node, 0, retval_node, 0);
- } else {
- // Argument is loop-varying.
- node_map[arg.switch_node->id()] = arg_node;
- // The Switch node has two outputs, but _Arg only has one. This tells
- // the CopySubgraph function to rewrite the output number of edges from
- // the _Arg node to be 0 rather than copying the output number from the
- // Switch node.
- squash_src_outputs[arg.switch_node->id()] = true;
- node_map[arg.next_iteration->id()] = retval_node;
- next_iterations.push_back(arg.next_iteration);
- }
- }
- }
-
- // Performs a reverse DFS, copying nodes and edges to the output graph.
- // The _Arg and _Retval nodes were added unconditionally above, so we are
- // guaranteed to get the correct function signature.
- TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations),
- squash_src_outputs, &node_map, output));
-
- return Status::OK();
-}
-
-// Copy the FunctionDef of given function from lookup_library to library, if
-// it can be found in lookup_library but is missing from library.
-Status AddMissingFunctionByName(const string& function_name,
- const FunctionLibraryDefinition* lookup_library,
- FunctionLibraryDefinition* library) {
- if (!library->Find(function_name) && lookup_library->Find(function_name)) {
- return library->AddFunctionDef(*lookup_library->Find(function_name));
- }
- return Status::OK();
-}
-
-// Iterate over all functions that the given fdef refers to. Copy the missing
-// FunctionDefs from lookup_library to library.
-Status AddMissingFunctionDef(const FunctionDef& fdef,
- const FunctionLibraryDefinition* lookup_library,
- FunctionLibraryDefinition* library) {
- TF_RET_CHECK(lookup_library);
- for (const NodeDef& node : fdef.node_def()) {
- if (library->Find(node.op())) {
- continue;
- }
- // The function referred by 'SymbolicGradient' node is specified in its
- // attribute 'f'.
- if (node.op() == FunctionLibraryDefinition::kGradientOp) {
- const AttrValue* attr =
- AttrSlice(&node.attr()).Find(FunctionLibraryDefinition::kFuncAttr);
- if (!attr) {
- return errors::InvalidArgument("SymbolicGradient is missing attr: f");
- }
- const string& func_name = attr->func().name();
- TF_RETURN_IF_ERROR(
- AddMissingFunctionByName(func_name, lookup_library, library));
- // Copy the user-defined gradient function if it exists.
- const string grad_name = lookup_library->FindGradient(func_name);
- if (!grad_name.empty() && library->FindGradient(func_name).empty()) {
- TF_RETURN_IF_ERROR(
- AddMissingFunctionByName(grad_name, lookup_library, library));
- GradientDef grad_def;
- grad_def.set_function_name(func_name);
- grad_def.set_gradient_func(grad_name);
- TF_RETURN_IF_ERROR(library->AddGradientDef(grad_def));
- }
- } else if (lookup_library->Find(node.op())) {
- TF_RETURN_IF_ERROR(
- library->AddFunctionDef(*lookup_library->Find(node.op())));
- }
- }
- return Status::OK();
-}
-
-Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
- Graph* graph, Frame* frame,
- FunctionLibraryDefinition* library) {
- VLOG(2) << "Frame " << frame->name << " before: "
- << dump_graph::DumpGraphToFile("functionalize_before", *graph,
- library);
-
- // Split loop-varying Enter nodes with multiple successors. If the same
- // Tensor is fed as input to multiple loop arguments, we may end up with a
- // shared Enter node. We clone Enter nodes with multiple successors to
- // maintain the invariant of a unique Enter node per argument of the final
- // loop.
- std::vector<Arg> args;
- for (const Arg& arg : frame->args) {
- if (arg.is_loop_invariant) {
- args.push_back(arg);
- } else {
- std::vector<const Edge*> edges(arg.enter->out_edges().begin(),
- arg.enter->out_edges().end());
- for (int i = 0; i < edges.size(); ++i) {
- if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) {
- continue;
- }
- TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name();
- Arg new_arg;
- new_arg.is_loop_invariant = false;
- if (i == 0) {
- new_arg.enter = arg.enter;
- } else {
- new_arg.enter = graph->CopyNode(arg.enter);
- frame->nodes.insert(new_arg.enter);
- for (Edge const* e : arg.enter->in_edges()) {
- graph->AddEdge(e->src(), e->src_output(), new_arg.enter,
- e->IsControlEdge() ? Graph::kControlSlot : 0);
- }
- Node* dst = edges[i]->dst();
- int dst_input = edges[i]->dst_input();
- graph->RemoveEdge(edges[i]);
- graph->AddEdge(new_arg.enter, 0, dst, dst_input);
- }
- args.push_back(new_arg);
- }
- }
- }
- frame->args = std::move(args);
-
- std::sort(
- frame->args.begin(), frame->args.end(),
- [](const Arg& a, const Arg& b) { return NodeCmp()(a.enter, b.enter); });
-
- if (frame->loop_cond == nullptr) {
- return errors::InvalidArgument("Loop ", frame->name,
- " has no LoopCond node");
- }
-
- // Find the set of Switch nodes that are successors of the LoopCond.
- std::unordered_set<Node*> switches;
- for (const Edge* edge : frame->loop_cond->out_edges()) {
- if (!edge->IsControlEdge() && IsSwitch(edge->dst()) &&
- edge->dst_input() == 1) {
- switches.insert(edge->dst());
- }
- }
-
- // For each non-constant argument, looks for the following pattern of nodes:
- // Enter ----> Merge --------> Switch --> Exit
- // ^ ^
- // | |
- // NextIteration LoopCond
- // ^ ^
- // | |
- // ... ...
- for (Arg& arg : frame->args) {
- if (!arg.is_loop_invariant) {
- // Follow the edge from the Enter to Merge.
- const Edge* enter_merge = nullptr;
- for (const Edge* e : arg.enter->out_edges()) {
- // Ignore control-edges to the sink node. These are allowed by the
- // graph invariants, although probably they should have been stripped
- // off earlier.
- if (e->IsControlEdge() && e->dst()->IsSink()) {
- continue;
- }
- if (enter_merge != nullptr) {
- return errors::Internal("Enter node for loop-varying argument ",
- FormatNodeForError(*arg.enter),
- " has multiple successors: ",
- FormatNodeForError(*enter_merge->dst()),
- " and ", FormatNodeForError(*e->dst()));
- }
- enter_merge = e;
- }
- if (enter_merge == nullptr) {
- return errors::Internal("Enter node for loop-varying argument ",
- FormatNodeForError(*arg.enter),
- " has zero successors");
- }
- arg.merge = enter_merge->dst();
- if (!IsMerge(arg.merge)) {
- return errors::InvalidArgument(
- "Successor of Enter node for loop-varying argument ",
- FormatNodeForError(*arg.merge),
- " is not a Merge node; got: ", arg.merge->type_string());
- }
-
- // Find the NextIteration from the merge. There should be two inputs to
- // the Merge and the NextIteration should be the other input.
- if (arg.merge->input_types().size() != 2) {
- return errors::InvalidArgument(
- "Unexpected number of inputs to Merge node for loop-varying "
- "argument ",
- FormatNodeForError(*arg.merge), "; expected 2, got ",
- arg.merge->input_types().size());
- }
- TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(),
- &arg.next_iteration));
- if (!IsNextIteration(arg.next_iteration)) {
- return errors::InvalidArgument(
- "Expected NextIteration node as input to Merge node; got node ",
- FormatNodeForError(*arg.next_iteration), " with kind ",
- arg.next_iteration->type_string());
- }
-
- // Find the Switch successor of the Merge. There should be exactly one
- // Switch node that is a successor of both the Merge and the LoopCond.
- for (const Edge* edge : arg.merge->out_edges()) {
- if (edge->dst_input() == 0 && IsSwitch(edge->dst()) &&
- switches.find(edge->dst()) != switches.end()) {
- if (arg.switch_node != nullptr) {
- return errors::InvalidArgument("Duplicate Switch successors to ",
- FormatNodeForError(*arg.merge));
- }
- arg.switch_node = edge->dst();
- }
- }
- if (arg.switch_node == nullptr) {
- return errors::InvalidArgument("Missing Switch successor to ",
- FormatNodeForError(*arg.merge));
- }
-
- // Update the device on the Identity outputs of the switch to match their
- // target. These Identity outputs do not
-
- // Loop over the switch node's output to:
- // - Find the Exit successor.
- // - Set the sharding on all Identity outputs of the switch. These
- // identity nodes are values used by the loop body or condition.
- // The Identity node may have the wrong device so copy the device from
- // one of its outputs instead.
- std::deque<const Edge*> possible_exit;
- for (const Edge* edge : arg.switch_node->out_edges()) {
- if (edge->src_output() == 0) {
- possible_exit.push_back(edge);
- }
- if (IsIdentity(edge->dst())) {
- TF_RETURN_IF_ERROR(
- SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true));
- }
- }
- // TODO(b/67425339): Allow general graph between switch and exit.
- while (!possible_exit.empty()) {
- const Edge* edge = possible_exit.front();
- possible_exit.pop_front();
- if (IsExit(edge->dst())) {
- if (arg.exit != nullptr) {
- return errors::InvalidArgument(
- "Duplicate Exit successors to ",
- FormatNodeForError(*arg.switch_node));
- }
- arg.exit = edge->dst();
- } else {
- if (!IsIdentity(edge->dst())) {
- return errors::Unimplemented("General graph between switch (",
- FormatNodeForError(*arg.switch_node),
- ") and exit node of frame ",
- frame->name, " not supported yet.");
- }
- for (const Edge* out : edge->dst()->out_edges()) {
- possible_exit.push_back(out);
- }
- }
- }
- }
- }
-
- // Builds the condition and body functions.
- std::unique_ptr<Graph> cond_graph;
- TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
- DataTypeVector arg_types;
- std::unique_ptr<Graph> body_graph;
- TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
-
- VLOG(2) << "Frame " << frame->name << " condition: "
- << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library)
- << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph);
-
- static std::atomic<int64> sequence_num(0LL);
- int64 id = ++sequence_num;
- NameAttrList cond_name;
- cond_name.set_name(strings::StrCat("_functionalize_cond_", id));
- NameAttrList body_name;
- body_name.set_name(strings::StrCat("_functionalize_body_", id));
- FunctionDef cond_fdef;
- TF_RETURN_IF_ERROR(
- GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef));
- FunctionDef body_fdef;
- TF_RETURN_IF_ERROR(
- GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef));
-
- TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
- TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
- if (lookup_library) {
- // Copy missing FunctionDefs from lookup_library to library to make library
- // self-contained.
- TF_RETURN_IF_ERROR(
- AddMissingFunctionDef(cond_fdef, lookup_library, library));
- TF_RETURN_IF_ERROR(
- AddMissingFunctionDef(body_fdef, lookup_library, library));
- }
-
- // Builds a While operator.
- NodeDef while_def;
- NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile");
- builder.Attr("T", arg_types);
- builder.Attr("cond", cond_name);
- builder.Attr("body", body_name);
- std::vector<NodeDefBuilder::NodeOut> inputs;
- for (int i = 0; i < frame->args.size(); ++i) {
- const Arg& arg = frame->args[i];
- const Edge* in_edge;
- TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
- if (in_edge->IsControlEdge()) {
- builder.ControlInput(in_edge->src()->name());
- } else {
- inputs.push_back(NodeDefBuilder::NodeOut(
- in_edge->src()->name(), in_edge->src_output(), arg_types[i]));
- }
- }
- builder.Input(inputs);
- TF_RETURN_IF_ERROR(builder.Finalize(&while_def));
- TF_ASSIGN_OR_RETURN(Node * while_node, AddNode(while_def, graph));
-
- // Copies edges to the Enter nodes and from the Exit nodes onto the While.
- for (int i = 0; i < frame->args.size(); ++i) {
- const Arg& arg = frame->args[i];
- const Edge* in_edge;
- TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
- if (in_edge->IsControlEdge()) {
- graph->AddControlEdge(in_edge->src(), while_node);
- } else {
- graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i);
- }
-
- if (!arg.is_loop_invariant) {
- // Add output edges if the output of the loop is consumed.
- if (arg.exit != nullptr) {
- std::vector<const Edge*> edges(arg.exit->out_edges().begin(),
- arg.exit->out_edges().end());
- for (const Edge* edge : edges) {
- Node* dst = edge->dst();
- int dst_input = edge->dst_input();
- graph->RemoveEdge(edge);
-
- if (dst_input == Graph::kControlSlot) {
- graph->AddControlEdge(while_node, dst);
- } else {
- graph->AddEdge(while_node, i, dst, dst_input);
- }
- }
- }
- }
- }
-
- // Remove the old nodes from the graph, and add the while node to the parent
- // frame.
- for (Node* node : frame->nodes) {
- graph->RemoveNode(node);
- }
- frame->nodes.clear();
- frame->parent->nodes.insert(while_node);
-
- VLOG(2) << "Frame " << frame->name << " after: "
- << dump_graph::DumpGraphToFile("functionalize_after", *graph,
- library);
-
- return Status::OK();
-}
-
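The net effect of FunctionalizeLoop above is to collapse each frame's Enter/Merge/Switch/NextIteration/Exit ring into a single functional node. As a sketch only (values illustrative, following the naming scheme in the code above), the resulting NodeDef has the shape:

    name: "while/LoopCond"  op: "XlaWhile"
    attr { key: "T"    value { list { type: DT_INT32 } } }
    attr { key: "cond" value { func { name: "_functionalize_cond_1" } } }
    attr { key: "body" value { func { name: "_functionalize_body_1" } } }

with one data input per loop argument, taken from the corresponding Enter node's input.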
-class FunctionalizeCond {
- public:
- // All nodes are assumed to be either in no branch, then branch, else branch,
- // or both branches (such as merge nodes).
- enum Branch {
- kElseBranch = 0,
- kThenBranch = 1,
- kBoth = 2,
- kNeither = 3,
- kNumBranchTypes = 4
- };
-
- // Returns a textual representation of the Branch b.
- static string Branch_Name(FunctionalizeCond::Branch b);
-
-  // Functionalize all the switch-merge nodes of a loop-free graph into XlaIf
-  // nodes. That is, attempt to transform every remaining switch and merge node
-  // in the graph into XlaIf nodes.
- // Precondition: All while loops have been removed from graph.
- static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library);
-
- private:
-  // CondArgNode represents an input to the conditional and its corresponding
-  // switch nodes.
- struct CondArgNode {
- explicit CondArgNode(Node* src, int src_output)
- : src(src), src_output(src_output) {}
- string ToString() const {
- return strings::StrCat("src=", src->name(), ":", src_output,
- " switches=", NodesToString(switches));
- }
-
- Node* src;
- int src_output;
- std::vector<Node*> switches;
- };
- using CondArgNodes = std::vector<CondArgNode>;
-
- struct ForwardFlowNode {
- explicit ForwardFlowNode(Branch branch = Branch::kNeither)
- : branch(branch), count(0) {}
- string ToString() const {
- return strings::StrCat("branch=", Branch_Name(branch), " count=", count);
- }
- Branch branch;
- int count;
- };
-
- // Group of switch nodes that will be part of the same XlaIf.
- struct SwitchCluster {
- explicit SwitchCluster(const Edge* predicate_edge)
- : predicate_edge(predicate_edge) {}
- string ToString() const {
- return strings::StrCat(name, " predicate=", predicate_edge->src()->name(),
- " switches=", NodesToString(switches));
- }
-
- string name;
- const Edge* predicate_edge;
- std::vector<Node*> switches;
- };
-
- FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library,
- bool dump_graphs)
- : library_(library), graph_(graph), dump_graphs_(dump_graphs) {}
-
- // Perform the actual cond functionalization. Iterate over groups of switch
- // nodes (linked by common predicate), from innermost to outermost, and
- // extract into XlaIf nodes.
- Status FunctionalizeInternal();
-
- // Determines the branch_map (mapping from node to branch of cond) and
- // frontier (the nodes where the cond ends).
- StatusOr<std::pair<std::unordered_map<Node*, ForwardFlowNode>,
- std::unordered_set<Node*>>>
- DetermineBranchMapAndFrontier(const SwitchCluster& switch_cluster);
-
-  // Returns the XlaIf node created from the subgraph of merge and switch
-  // nodes. This encapsulates extracting the bodies needed for the then and
-  // else branches, creating an XlaIf node, removing the nodes of the branches
-  // from the graph, and replacing the merge node with an XlaIf.
- StatusOr<Node*> ConvertToXlaIf(const CondArgNodes& cond_arg_nodes,
- const SwitchCluster& switch_cluster,
- const std::vector<Node*>& switches);
-
-  // Builds an XlaIfOp that replaces the Switch-Graph-Merge cluster.
- StatusOr<Node*> BuildAndAddXlaIfOp(const CondArgNodes& cond_arg_nodes,
- const SwitchCluster& switch_cluster,
- const std::vector<Node*>& merge_nodes);
-
- // Extracts a function body corresponding to the given input edge of the merge
- // node.
- Status ExtractBody(const CondArgNodes& cond_arg_nodes,
- const std::vector<Node*>& switches,
- const std::vector<Node*>& merge_nodes, int input_edge,
- Graph* body);
-
- // Adds all the input edges to `if_node` corresponding to the arguments.
- Status AddInputEdges(const CondArgNodes& cond_arg_nodes,
- const Edge* predicate_edge, Node* if_node);
-
- // Adds all output edges from the `if_node`.
- Status AddOutputEdges(const std::vector<Node*>& outputs, Node* if_node);
-
- // Returns the switch clusters of graph_ in postorder. Dead switch nodes are
- // skipped and removed from the graph.
- StatusOr<std::vector<SwitchCluster>> DeterminePredicateSwitchOrder();
-
- // Update the state for destination based on the state of source and the node
- // being updated.
- Status Join(const ForwardFlowNode& src_state, const Node* dst,
- ForwardFlowNode* dst_state);
-
-  // Ensures that all nodes in the branch_map are dominated by the switch
-  // nodes. Returns the nodes that are not dominated by the switches but are a
-  // control dependency of a node in the cond, and removes such control
-  // dependencies.
- StatusOr<std::vector<Node*>> EnsureDominanceAndReturnNonDominatedControlNodes(
- const std::unordered_map<Node*, ForwardFlowNode>& branch_map,
- const std::vector<Node*>& switches);
-
-  // Validates that the frontier of nodes for the conditional
-  // section is as expected.
- Status ValidateFrontier(
- const std::unordered_map<Node*, ForwardFlowNode>& branch_map,
- const std::unordered_set<Node*>& frontier);
-
- FunctionLibraryDefinition* library_;
- Graph* graph_;
- bool dump_graphs_;
-};
-
-bool IsDeadSwitch(const Node* node) {
- for (const Edge* e : node->out_edges()) {
- const Node* dst = e->dst();
- if (!dst->IsIdentity()) {
- return false;
- }
- for (const Edge* ee : dst->out_edges()) {
- if (!ee->IsControlEdge() || !ee->dst()->IsSink()) {
- return false;
- }
- }
- }
- return true;
-}
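In effect, IsDeadSwitch matches switches whose values never reach a real consumer. A sketch of the accepted pattern (node names illustrative):

    input --> Switch --> Identity --(control edge)--> sink
    pred  ----^

Every out edge of the Switch must target an Identity, and every out edge of each such Identity must be a control edge into the sink.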
-
-string FunctionalizeCond::Branch_Name(FunctionalizeCond::Branch b) {
- const string branch_name[FunctionalizeCond::kNumBranchTypes + 1] = {
- "else", "then", "both", "neither", "count"};
- return branch_name[b];
-}
-
-Status FunctionalizeCond::ValidateFrontier(
- const std::unordered_map<Node*, FunctionalizeCond::ForwardFlowNode>&
- branch_map,
- const std::unordered_set<Node*>& frontier) {
- std::unordered_set<const Node*> pending[kNumBranchTypes];
- for (Node* n : frontier) {
- pending[branch_map.at(n).branch].insert(n);
- }
- TF_RET_CHECK(pending[kNeither].empty()) << NodesToString(pending[kNeither]);
- for (const Node* n : pending[kBoth]) {
- TF_RET_CHECK(IsMerge(n)) << n->DebugString();
-      // Merge nodes may be in the then or else branch too.
- }
- int index = (pending[kThenBranch].size() <= pending[kElseBranch].size())
- ? kThenBranch
- : kElseBranch;
- int other = 1 - index;
- for (const Node* n : pending[index]) {
- if (pending[other].find(n) != pending[other].end()) {
- return errors::Internal(
- "Node (", n->DebugString().c_str(),
- ") in both Else and Then branch should be in Both.");
- }
- }
- // An empty frontier indicates a dead switch. Above we attempt to remove dead
- // switch nodes, but not all are removed so don't treat it as an error yet.
- // TODO(jpienaar): Find out why dead switch nodes remain.
- // if (pending[kBoth].empty() && pending[kThenBranch].empty() &&
- // pending[kElseBranch].empty()) {
- // return errors::Internal("Unexpected empty frontier for switch nodes");
- // }
- return Status::OK();
-}
-
-Status FunctionalizeCond::Join(const ForwardFlowNode& src_state,
- const Node* dst, ForwardFlowNode* dst_state) {
- TF_RET_CHECK(dst_state->branch != Branch::kBoth &&
- dst_state->branch != Branch::kNumBranchTypes)
- << "Unexpected/Invalid branch type: Merging "
- << Branch_Name(src_state.branch) << " with "
- << Branch_Name(dst_state->branch);
- if (dst_state->branch == Branch::kNeither) {
- dst_state->branch = src_state.branch;
- } else if (src_state.branch != dst_state->branch &&
- src_state.branch != Branch::kNeither) {
- if (IsMerge(dst)) {
- dst_state->branch = Branch::kBoth;
- } else {
- return errors::Internal("Illegal merge:\n", src_state.ToString(),
- " with ", dst_state->ToString(), " for\n",
- dst->DebugString());
- }
- }
- ++dst_state->count;
- return Status::OK();
-}
-
-StatusOr<std::vector<FunctionalizeCond::SwitchCluster>>
-FunctionalizeCond::DeterminePredicateSwitchOrder() {
- struct Cluster {
- bool operator==(const Cluster& other) const {
- return representative == other.representative;
- }
- int representative = -1;
- };
-
- // Perform a DFS over the graph and
- // * Determine the reverse topological order of the nodes (there should be no
- // cycles at this point so the post-order numbering corresponds to the
- // reverse topological sorting);
- // * Identify dead switches;
- // * Initialize the cluster's representative;
- std::vector<UnionFind<Cluster>> clusters(graph_->num_node_ids());
- std::vector<Node*> dead_switches;
- std::vector<Node*> switch_order;
- std::vector<Node*> rev_topo_sorted_nodes;
- DFS(*graph_, nullptr, [&](Node* n) {
- clusters[n->id()].Get().representative = n->id();
- if (IsSwitch(n)) {
- if (IsDeadSwitch(n)) {
- dead_switches.push_back(n);
- } else {
- rev_topo_sorted_nodes.push_back(n);
- switch_order.push_back(n);
- }
- } else if (n->IsOp()) {
- // Exclude src and sink nodes from further consideration.
- rev_topo_sorted_nodes.push_back(n);
- }
- });
-
- std::vector<SwitchCluster> switch_clusters;
- // Return early if there are no switches in the graph.
- if (switch_order.empty()) {
- return switch_clusters;
- }
-
- // Remove all dead switch nodes.
- for (Node* n : dead_switches) {
- VLOG(2) << "Removing dead switch: " << n->DebugString();
- graph_->RemoveNode(n);
- }
-
- // Identify switch nodes that are part of the same control flow context by
- // considering the operands of operations: an operation is part of the same
- // control context as its operands unless the operation is a switch. Control
- // dependencies are considered part of the same control flow context if the
- // switch depth is the same (see comment below).
-
-  // entry_cluster records the input cluster of a switch node. It is used when
-  // merging through a merge node: the dst's cluster is merged with the entry
-  // cluster of the merge node's cluster (which corresponds to a switch cluster
-  // and so has an entry cluster).
- std::unordered_map<int, UnionFind<Cluster>*> entry_cluster;
-
-  // Returns the output cluster of a node, i.e., the cluster where the output
-  // of the node is used. For non-merge nodes this is simply the cluster they
-  // are part of, while for merge nodes it is the entry cluster of the cluster
-  // they are part of (this will correspond to the entry node of a switch node
-  // that dominates the merge).
- auto find_output_cluster = [&](Node* n) {
- UnionFind<Cluster>* cluster = &clusters[n->id()];
- if (!IsMerge(n)) return cluster;
- auto it = entry_cluster.find(clusters[n->id()].Get().representative);
- // If the cluster is not found in the entry_cluster map then an
- // instruction not dominated by a switch node has been merged into the
- // cluster of the merge. This indicates a failure of the clustering.
- CHECK(it != entry_cluster.end())
- << "Unable to find entry for n=" << n->id() << " ("
- << cluster->Get().representative << ")";
- return it->second;
- };
-
- // TODO(jpienaar): This could be combined with DetermineBranchMapAndFrontier.
- std::vector<int> switch_depth(graph_->num_node_ids());
- for (auto it = rev_topo_sorted_nodes.rbegin();
- it != rev_topo_sorted_nodes.rend(); ++it) {
- Node* n = *it;
-
- // Compute switch depth.
- int new_switch_depth = 0;
- for (const Edge* e : n->in_edges()) {
- Node* src = e->src();
- new_switch_depth = std::max(
- new_switch_depth, switch_depth[src->id()] - (IsMerge(src) ? 1 : 0));
- }
- switch_depth[n->id()] = new_switch_depth + (IsSwitch(n) ? 1 : 0);
-
- // Only merge the input operands of a switch. The switch's clustering itself
- // is determined by the interaction of the switch's outputs.
- if (IsSwitch(n)) {
- Node* input;
- TF_CHECK_OK(n->input_node(0, &input));
- entry_cluster[n->id()] = find_output_cluster(input);
- UnionFind<Cluster>* cluster = entry_cluster[n->id()];
- int cluster_depth = switch_depth[cluster->Get().representative];
- // Merge the inputs of the switch node with one another. This results in
- // predicates and control input residing in the same cluster.
- for (const Edge* e : n->in_edges()) {
- // Only consider the data inputs to the Switch node.
- if (e->IsControlEdge()) continue;
-
- Node* src = e->src();
- UnionFind<Cluster>* src_cluster = find_output_cluster(src);
- int src_cluster_depth = switch_depth[src_cluster->Get().representative];
- if (cluster_depth != src_cluster_depth) {
- return errors::InvalidArgument(
- "Unable to functionalize control flow in graph: Switch ('",
- n->name(), "') has operands ('", input->name(), "' and '",
- src->name(), "') that have different switch depths (",
- cluster_depth, " != ", src_cluster_depth, ")");
- }
- cluster->Merge(src_cluster);
- }
- continue;
- }
-
- for (const Edge* e : n->in_edges()) {
- Node* src = e->src();
- if (!src->IsOp()) continue;
- UnionFind<Cluster>* cluster = find_output_cluster(src);
- // Merge a node with its data operands and with its control operands if
- // the src and dst are in the same ControlContext. The ControlContext is
- // not explicitly available here, and instead the switch depth is used as
- // a proxy here. Due to the invariant that control edges can only be from
- // a containing scope to an inner scope or from the inner scope to its
- // containing scope (for exit nodes), the switch depth will only match if
- // the src and dst are in the same ControlContext. Control edges between
- // ControlContexts are handled during the extraction.
- int src_id = cluster->Get().representative;
- int src_depth = switch_depth[src_id];
- if (!e->IsControlEdge() || new_switch_depth == src_depth) {
- if (src_depth != new_switch_depth) {
- // TODO(b/77601805) remove this when outside_compilation supports
- // control flow.
- if (str_util::StrContains(src->name(), "outside_compilation") ||
- str_util::StrContains(n->name(), "outside_compilation")) {
- return errors::InvalidArgument(
- "outside_compilation is not yet supported within TensorFlow "
- "control flow constructs b/77601805");
- }
- return errors::InvalidArgument(
- "Unable to functionalize control flow in graph: Operand ('",
- src->name(), "') and operator ('", n->name(),
- "') have different switch depths (", src_depth,
- " != ", new_switch_depth, ")");
- }
- cluster->Merge(&clusters[n->id()]);
- }
- }
- }
-
- if (dump_graphs_) {
- // Mark the switch cluster each node is part of.
- for (Node* n : graph_->nodes()) {
- n->ClearAttr("_XlaFunctionalizeSwitchGroup");
- n->AddAttr("_XlaFunctionalizeSwitchGroup",
- clusters[n->id()].Get().representative);
- }
- LOG(INFO) << "FunctionalizeControlFlow (with_clusters): "
- << dump_graph::DumpGraphToFile("functionalize_clustered", *graph_,
- library_);
- }
-
- // Verify all the nodes of a cluster are at the same depth.
- std::unordered_map<int, std::pair<int, Node*>> cluster_to_depth_node;
- for (Node* n : graph_->nodes()) {
- int depth = switch_depth[n->id()];
- int cluster_rep = clusters[n->id()].Get().representative;
- auto it = cluster_to_depth_node.find(cluster_rep);
- if (it == cluster_to_depth_node.end()) {
- cluster_to_depth_node[cluster_rep] = std::make_pair(depth, n);
- } else {
- if (it->second.first != depth) {
- return errors::Internal(
- "Illegal clustering created, mismatch in depths:", "\n\t",
- n->DebugString(), "(", clusters[n->id()].Get().representative,
- ") at depth=", depth, " vs\n\t", it->second.second->DebugString(),
- "(", clusters[n->id()].Get().representative, ") at depth ",
- it->second.first);
- }
- }
- }
-
- struct Hash {
- size_t operator()(const std::pair<Node*, Cluster>& item) const {
- return Hash64Combine(hash<Node*>()(item.first),
- std::hash<int>()(item.second.representative));
- }
- };
-
- // Merge Switch nodes with common predicate.
- std::unordered_map<std::pair<Node*, Cluster>, int, Hash> predicate_index;
- // The nodes in switch_order are in reverse topological order, but the
- // clustered switches need not be (i.e., when considered as a cluster one
- // element of a cluster may be later in the topological order than another
- // node whose cluster is later in the topological order of clustered
- // switches).
- for (auto it = switch_order.rbegin(); it != switch_order.rend(); ++it) {
- const Edge* pred_edge;
- TF_CHECK_OK((*it)->input_edge(1, &pred_edge));
-    // The predicate can be preceded by an Identity node. Look through Identity
-    // nodes to find the predicate.
- while (pred_edge->src()->IsIdentity()) {
- TF_CHECK_OK(pred_edge->src()->input_edge(0, &pred_edge));
- }
- auto repr = std::make_pair(pred_edge->src(), clusters[(*it)->id()].Get());
- if (predicate_index.find(repr) == predicate_index.end()) {
- predicate_index[repr] = switch_clusters.size();
- switch_clusters.emplace_back(pred_edge);
- // Generate a name by concatenating with the cluster representative as
- // there could be multiple switch clusters with the same predicate.
- switch_clusters[predicate_index[repr]].name = strings::StrCat(
- pred_edge->src()->name(), "_", repr.second.representative, "_If");
- }
- switch_clusters[predicate_index[repr]].switches.push_back(*it);
- }
-
- return switch_clusters;
-}
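Restated as a recurrence (equivalent to the per-node loop above, not an addition to it), the switch depth is:

    // depth(n) = max over in-edges e of
    //              (depth(src(e)) - (IsMerge(src(e)) ? 1 : 0))
    //            + (IsSwitch(n) ? 1 : 0)

so depth rises by one at a Switch and falls by one past a Merge; e.g. a predicate at depth 0, its Switch and the branch ops (including the Merge) at depth 1, and consumers of the Merge back at depth 0.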
-
-StatusOr<std::vector<Node*>>
-FunctionalizeCond::EnsureDominanceAndReturnNonDominatedControlNodes(
- const std::unordered_map<Node*, ForwardFlowNode>& branch_map,
- const std::vector<Node*>& switches) {
- std::vector<Node*> old_control_nodes;
- for (const auto& kv : branch_map) {
- if (kv.second.count != kv.first->in_edges().size()) {
- std::vector<const Edge*> delete_edges;
- for (const Edge* in : kv.first->in_edges()) {
- auto it = branch_map.find(in->src());
- if (it == branch_map.end()) {
- if (in->IsControlEdge()) {
- old_control_nodes.push_back(in->src());
- delete_edges.push_back(in);
- } else {
- if (IsSwitch(in->src())) {
- if (std::find(switches.begin(), switches.end(), in->src()) ==
- switches.end()) {
- return errors::Internal(
- "Unexpected switch node found during flow forward: ",
- in->src()->DebugString());
- }
- continue;
- }
- return errors::InvalidArgument(
- "Value ", kv.first->name(), "'s input, ", in->src()->name(),
- ", is not dominated by switch nodes ", NodesToString(switches));
- }
- }
- }
- // Remove control edges from nodes that are not dominated by the switch
- // nodes. New control dependencies will be added between these nodes and
- // the XlaIf node inserted.
- for (const Edge* e : delete_edges) {
- graph_->RemoveEdge(e);
- }
- }
- }
- return old_control_nodes;
-}
-
-StatusOr<
- std::pair<std::unordered_map<Node*, FunctionalizeCond::ForwardFlowNode>,
- std::unordered_set<Node*>>>
-FunctionalizeCond::DetermineBranchMapAndFrontier(
- const SwitchCluster& switch_cluster) {
- std::unordered_map<Node*, ForwardFlowNode> branch_map;
- std::unordered_set<Node*> frontier;
- std::vector<Node*> stack = switch_cluster.switches;
- std::vector<bool> visited(graph_->num_node_ids(), false);
- while (!stack.empty()) {
- Node* n = stack.back();
- stack.pop_back();
-
- if (visited[n->id()]) {
- continue;
- }
- visited[n->id()] = true;
-
- // Propagate branch state along each edge of a switch node.
- bool sink_only = true;
- for (const Edge* e : n->out_edges()) {
- Node* out = e->dst();
- if (!out->IsOp()) {
- continue;
- }
- sink_only = false;
- // Propagate branch information.
- ForwardFlowNode& ffn = branch_map[out];
- if (IsSwitch(n)) {
- int index = e->IsControlEdge() ? Branch::kNeither : e->src_output();
- TF_RETURN_WITH_CONTEXT_IF_ERROR(
- Join(ForwardFlowNode(Branch(index)), out, &ffn), " when joining ",
- e->DebugString());
- } else {
- TF_RETURN_WITH_CONTEXT_IF_ERROR(Join(branch_map[n], out, &ffn),
- " when joining ", e->DebugString());
- }
- if (IsMerge(out)) {
- if (out->in_edges().size() == ffn.count) {
- frontier.insert(out);
- }
- } else if (!visited[out->id()]) {
- stack.push_back(out);
- }
- }
- if (sink_only) {
- if (!IsIdentity(n)) {
- VLOG(1) << "Feeding into sink: " << n->DebugString();
- }
- }
- }
-
- if (dump_graphs_) {
- for (const auto& kv : branch_map) {
- // Append attribute to the graph if running with logging to make the
- // changes clearer in the visualization.
- kv.first->AddAttr("_XlaFunctionalizeBranch",
- Branch_Name(kv.second.branch));
- }
- }
- return std::make_pair(std::move(branch_map), std::move(frontier));
-}
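Concretely, the propagation above seeds branch state from the switch output ports, per the Branch enum and Join:

    // src_output 0 -> kElseBranch, src_output 1 -> kThenBranch,
    // control edge -> kNeither

Join escalates a node reached from both branches to kBoth only if it is a Merge, and such Merges enter the frontier once all of their in-edges have been visited.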
-
-Status FunctionalizeCond::FunctionalizeInternal() {
- TF_ASSIGN_OR_RETURN(std::vector<SwitchCluster> predicate_switch_order,
- DeterminePredicateSwitchOrder());
-
- // Iterate from innermost set of clustered switches to outermost, replacing
- // matching switch->merge subgraphs with single XlaIf nodes.
- for (auto it = predicate_switch_order.rbegin();
- it != predicate_switch_order.rend(); ++it) {
- auto& ps = *it;
- VLOG(3) << "Flow down from: " << ps.ToString();
-
- std::unordered_map<Node*, ForwardFlowNode> branch_map;
- std::unordered_set<Node*> frontier;
- TF_ASSIGN_OR_RETURN(std::tie(branch_map, frontier),
- DetermineBranchMapAndFrontier(ps));
-
- if (dump_graphs_)
- LOG(INFO) << "FunctionalizeControlFlow (before XlaIf conversion): "
- << dump_graph::DumpGraphToFile("functionalize_bc", *graph_,
- library_);
- TF_RETURN_IF_ERROR(ValidateFrontier(branch_map, frontier));
-
- struct Hash {
- size_t operator()(const std::pair<Node*, int>& item) const {
- return Hash64Combine(hash<Node*>()(item.first),
- std::hash<int>()(item.second));
- }
- };
-
-    // Sort the merge and switch nodes using NodeCmp. The switch nodes are
-    // further grouped (after sorting) by input to the switch node, since in
-    // the functionalized form each input will be passed in only once. This
-    // grouping should retain the sorted order.
- CondArgNodes cond_arg_nodes;
- std::sort(ps.switches.begin(), ps.switches.end(), NodeCmp());
- std::unordered_map<std::pair<Node*, int>, int, Hash> input_index;
- for (Node* switch_node : ps.switches) {
- const Edge* e;
- TF_RETURN_IF_ERROR(switch_node->input_edge(0, &e));
- std::pair<Node*, int> key = std::make_pair(e->src(), e->src_output());
- if (input_index.find(key) == input_index.end()) {
- input_index[key] = cond_arg_nodes.size();
- cond_arg_nodes.emplace_back(key.first, key.second);
- }
- cond_arg_nodes.at(input_index.at(key)).switches.push_back(switch_node);
- }
- std::vector<Node*> merge_nodes(frontier.begin(), frontier.end());
- std::sort(merge_nodes.begin(), merge_nodes.end(), NodeCmp());
-
- TF_ASSIGN_OR_RETURN(std::vector<Node*> old_control_nodes,
- EnsureDominanceAndReturnNonDominatedControlNodes(
- branch_map, ps.switches));
-
- TF_ASSIGN_OR_RETURN(Node * if_node,
- ConvertToXlaIf(cond_arg_nodes, ps, merge_nodes));
- for (Node* old : old_control_nodes) {
- graph_->AddControlEdge(old, if_node);
- }
-
- for (auto& del_kv : branch_map) {
- graph_->RemoveNode(del_kv.first);
- }
- for (auto& kv : cond_arg_nodes) {
- for (Node* node : kv.switches) {
- graph_->RemoveNode(node);
- }
- }
- if (dump_graphs_)
- LOG(INFO) << "FunctionalizeControlFlow (after XlaIf conversion): "
- << dump_graph::DumpGraphToFile("functionalize_ac", *graph_,
- library_);
- }
- return Status::OK();
-}
-
-StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
- const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster,
- const std::vector<Node*>& merge_nodes) {
- VLOG(2) << "Build if op for " << switch_cluster.name;
-
- NodeDef if_def;
-  // Create a new If node using the name of the switch cluster.
- NodeDefBuilder builder(switch_cluster.name, "XlaIf");
- string branch[] = {"else_branch", "then_branch"};
- for (int i = 0; i < 2; ++i) {
- static std::atomic<int64> sequence_num(0LL);
- int64 id = ++sequence_num;
-
- NameAttrList body_name;
- body_name.set_name(
- strings::StrCat("_functionalize_if_", branch[i], "_", id));
- auto body = xla::MakeUnique<Graph>(graph_->op_registry());
- TF_RETURN_IF_ERROR(ExtractBody(cond_arg_nodes, switch_cluster.switches,
- merge_nodes, i, body.get()));
- VLOG(3) << "Body " << branch[i] << ": " << DebugString(body.get());
- FunctionDef body_fdef;
- TF_RETURN_IF_ERROR(GraphToFunctionDef(*body, body_name.name(), &body_fdef));
- TF_RETURN_IF_ERROR(library_->AddFunctionDef(body_fdef));
- builder.Attr(branch[i], body_name);
- }
-
- // Build input type.
- std::vector<NodeDefBuilder::NodeOut> inputs;
- DataTypeVector in_arg_types;
- for (auto& kv : cond_arg_nodes) {
- bool inserted = false;
- for (const Node* arg : kv.switches) {
- const Edge* in_edge;
- TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
- if (in_edge->IsControlEdge()) {
- builder.ControlInput(in_edge->src()->name());
- } else {
- if (!inserted) {
- DataType dtype = arg->input_type(0);
- inputs.emplace_back(NodeDefBuilder::NodeOut(
- in_edge->src()->name(), in_edge->src_output(), dtype));
- in_arg_types.push_back(dtype);
- inserted = true;
- }
- }
- }
- }
- builder.Attr("Tin", in_arg_types);
-
- // Build output type.
- DataTypeVector out_type;
- for (const Node* merge : merge_nodes) {
- DataType dtype = merge->output_type(0);
- out_type.push_back(dtype);
- }
- builder.Attr("Tout", out_type);
-
- builder.Attr("Tcond", DT_BOOL);
- builder.Device(switch_cluster.predicate_edge->src()->assigned_device_name());
- // Conditional should be the first input ...
- builder.Input(NodeDefBuilder::NodeOut(
- switch_cluster.predicate_edge->src()->name(),
- switch_cluster.predicate_edge->src_output(),
- switch_cluster.predicate_edge->src()->output_type(0)));
- // ... followed by the other inputs.
- builder.Input(inputs);
-
- TF_RETURN_IF_ERROR(builder.Finalize(&if_def));
- TF_ASSIGN_OR_RETURN(Node * if_node, AddNode(if_def, graph_));
- return if_node;
-}
-
-Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes,
- const std::vector<Node*>& switches,
- const std::vector<Node*>& merge_nodes,
- int input_edge, Graph* body) {
- VLOG(2) << "ExtractBody for " << NodesToString(merge_nodes) << " along edge "
- << input_edge;
- std::vector<bool> squash_src_outputs(graph_->num_node_ids(), false);
- std::vector<Node*> node_map(graph_->num_node_ids(), nullptr);
- int arg_count = 0;
- for (auto& kv : cond_arg_nodes) {
- Node* arg_node = nullptr;
- for (const auto* arg : kv.switches) {
- DataType dtype = arg->input_type(0);
- if (arg_node == nullptr) {
- TF_ASSIGN_OR_RETURN(arg_node, BuildArgNode(body, dtype, arg_count++));
- }
- node_map.at(arg->id()) = arg_node;
- squash_src_outputs.at(arg->id()) = true;
- }
- }
-
- std::vector<Node*> stack;
- stack.reserve(merge_nodes.size());
- for (int j = 0; j < merge_nodes.size(); ++j) {
- Node* node = merge_nodes[j];
- TF_ASSIGN_OR_RETURN(node_map.at(node->id()),
- BuildRetvalNode(body, node->output_type(0),
- /*index=*/j));
- const Edge* in_edge;
- TF_RETURN_IF_ERROR(node->input_edge(input_edge, &in_edge));
- Node* in = in_edge->src();
- if (node_map.at(in->id()) == nullptr) {
- node_map.at(in->id()) = body->CopyNode(in);
- }
-
- if (std::find(switches.begin(), switches.end(), in) == switches.end()) {
- body->AddEdge(node_map.at(in->id()), in_edge->src_output(),
- node_map.at(node->id()), 0);
- } else {
- body->AddEdge(node_map.at(in->id()), 0, node_map.at(node->id()), 0);
-      // The input is a switch already mapped to an _Arg and wired directly to
-      // the retval; don't add it to the stack.
- continue;
- }
- stack.push_back(in);
- }
-
- return CopySubgraph(*graph_, nullptr, stack, squash_src_outputs, &node_map,
- body);
-}
-
-Status FunctionalizeCond::AddInputEdges(const CondArgNodes& cond_arg_nodes,
- const Edge* predicate_edge,
- Node* if_node) {
- VLOG(3) << "AddInputEdges for " << if_node->name();
- int index = 0;
- graph_->AddEdge(predicate_edge->src(), predicate_edge->src_output(), if_node,
- index++);
- for (auto& arg : cond_arg_nodes) {
- if (arg.src_output == Graph::kControlSlot) {
- graph_->AddControlEdge(arg.src, if_node);
- } else {
- graph_->AddEdge(arg.src, arg.src_output, if_node, index++);
- }
- }
- return Status::OK();
-}
-
-Status FunctionalizeCond::AddOutputEdges(const std::vector<Node*>& outputs,
- Node* if_node) {
- VLOG(3) << "AddOutputEdges for " << if_node->name();
- for (int i = 0; i < outputs.size(); ++i) {
- Node* node = outputs[i];
- std::vector<const Edge*> edges(node->out_edges().begin(),
- node->out_edges().end());
- for (const Edge* edge : edges) {
- Node* dst = edge->dst();
- int dst_input = edge->dst_input();
-
- if (edge->src_output() > 0) {
- return errors::Unimplemented("Output of index (", edge->src_output(),
- ") of merge node ", node->name());
- }
-
- int src_output =
- dst_input == Graph::kControlSlot ? Graph::kControlSlot : i;
- graph_->RemoveEdge(edge);
- graph_->AddEdge(if_node, src_output, dst, dst_input);
- }
- }
- return Status::OK();
-}
-
-StatusOr<Node*> FunctionalizeCond::ConvertToXlaIf(
- const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster,
- const std::vector<Node*>& merge_nodes) {
- VLOG(1) << "ConvertToXlaIf for " << switch_cluster.ToString() << " -> "
- << NodesToString(merge_nodes);
-
-  // Extracts the bodies and builds an If operator.
- TF_ASSIGN_OR_RETURN(
- Node * if_node,
- BuildAndAddXlaIfOp(cond_arg_nodes, switch_cluster, merge_nodes));
- TF_RETURN_IF_ERROR(
- AddInputEdges(cond_arg_nodes, switch_cluster.predicate_edge, if_node));
- TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node));
- // Check that the if_node doesn't feed into itself.
- TF_RETURN_WITH_CONTEXT_IF_ERROR(
- CheckNoCycleContains(if_node, graph_->num_node_ids()),
- "ConvertToXlaIf failed.");
-
- return if_node;
-}
-
-Status FunctionalizeCond::Functionalize(Graph* graph,
- FunctionLibraryDefinition* library) {
- VLOG(1) << "FunctionalizeCond::Functionalize";
- FunctionalizeCond fc(graph, library, /*dump_graphs=*/VLOG_IS_ON(2));
- return fc.FunctionalizeInternal();
-}
-
-} // namespace
-
-// Transformation that converts TensorFlow's graph control flow constructs into
-// functional equivalents.
-Status FunctionalizeControlFlow(Graph* graph,
- FunctionLibraryDefinition* library) {
- return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library);
-}
-
Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
Graph* graph,
FunctionLibraryDefinition* library) {
@@ -1462,98 +46,26 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
<< dump_graph::DumpGraphToFile("functionalize_initial", *graph,
library);
- // Note: BuildControlFlowInfo() requires that the graph's source node is
- // connected to all source nodes in the graph. Many graphs violate this
- // invariant.
- std::vector<ControlFlowInfo> cf_info;
- std::vector<string> unreachable_nodes;
- TF_RETURN_WITH_CONTEXT_IF_ERROR(
- BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes),
- "FunctionalizeControlFlow failed");
- if (!unreachable_nodes.empty()) {
- return errors::InvalidArgument(
- "The following nodes are unreachable from the source in the graph: ",
- errors::FormatNodeNamesForError(unreachable_nodes));
- }
-
- // Builds Frames, indexed by name.
- std::unordered_map<string, Frame> frames;
- for (Node* node : graph->op_nodes()) {
- const ControlFlowInfo& cf = cf_info[node->id()];
-
- VLOG(2) << "node: " << node->name() << " (" << node->id()
- << ") frame_name: " << cf.frame_name
- << " frame: " << (cf.frame ? cf.frame->name() : "---")
- << " parent_frame: "
- << (cf.parent_frame ? cf.parent_frame->name() : "---");
- TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr);
-
- Frame& frame = frames[cf.frame_name];
- Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name];
- if (frame.parent == nullptr) {
- frame.parent = parent;
- frame.name = cf.frame_name;
- ++parent->num_children;
- }
-
- if (IsEnter(node)) {
- Arg arg;
- arg.enter = node;
- TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant",
- &arg.is_loop_invariant));
- frame.args.push_back(arg);
- } else if (IsLoopCond(node)) {
- frame.loop_cond = node;
- }
- frame.nodes.insert(node);
- }
-
- // Adds frames with no children (i.e., the innermost frames) to a worklist.
- std::deque<Frame*> worklist;
- for (auto& frame : frames) {
- if (frame.second.num_children == 0) {
- worklist.push_back(&frame.second);
- }
- }
-
- // Eliminate loops from innermost to outermost.
- while (!worklist.empty()) {
- Frame* frame = worklist.front();
- worklist.pop_front();
- if (frame->parent == frame) {
- // Skip the root frame.
- continue;
- }
-
- TF_RETURN_IF_ERROR(
- FunctionalizeLoop(lookup_library, graph, frame, library));
-
- // If the parent has no remaining children, add it to the worklist.
- --frame->parent->num_children;
- if (frame->parent->num_children == 0) {
- worklist.push_back(frame->parent);
- }
- }
- // There should be no cycle at this point, since while loops have been removed
- // from graph.
- // Check that the newly added XlaWhile nodes don't feed into themselves.
- for (const Node* node : graph->op_nodes()) {
- if (node->def().op() == "XlaWhile") {
- TF_RETURN_WITH_CONTEXT_IF_ERROR(
- CheckNoCycleContains(node, graph->num_node_ids()),
- "FunctionalizeLoop failed.");
- }
- }
+ // Functionalize and remove while loops from graph.
+ TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(lookup_library, graph, library));
   // FunctionalizeControlFlow is invoked for every function, so the loops'
// bodies and conditionals that were extracted into functions will be handled
// in successive invocations.
- TF_RETURN_IF_ERROR(FunctionalizeCond::Functionalize(graph, library));
+ TF_RETURN_IF_ERROR(FunctionalizeCond(graph, library));
VLOG(2) << "FunctionalizeControlFlow (final): "
<< dump_graph::DumpGraphToFile("functionalize_final", *graph,
library);
+
return Status::OK();
}
+// Transformation that converts TensorFlow's graph control flow constructs into
+// functional equivalents.
+Status FunctionalizeControlFlow(Graph* graph,
+ FunctionLibraryDefinition* library) {
+ return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h
index d941041d15..55600f2a8b 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h
@@ -16,14 +16,16 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
// Transformation that converts tf.while_loop() loops into functional While
-// operators, suitable for XLA compilation. If lookup_library is provided, use
-// it to make the library for control flow self-contained.
+// operators and tf.cond() conditionals into function If operators, suitable for
+// XLA compilation. If lookup_library is provided, use it to make the library
+// for control flow self-contained.
Status FunctionalizeControlFlow(Graph* graph,
FunctionLibraryDefinition* library);
Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
index ccf249b35d..c068a4110c 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
@@ -37,12 +37,12 @@ limitations under the License.
namespace tensorflow {
namespace {
-// Returns the names of the "then" and "else" functions for the XlaIf node in a
+// Returns the names of the "then" and "else" functions for the If node in a
// graph.
Status FindIfThenAndElse(const GraphDef& graph, string* op_name,
NameAttrList* then_fn, NameAttrList* else_fn) {
for (const NodeDef& node : graph.node()) {
- if (node.op() == "XlaIf") {
+ if (node.op() == "If") {
*op_name = node.name();
const NameAttrList* result;
TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result));
@@ -52,7 +52,7 @@ Status FindIfThenAndElse(const GraphDef& graph, string* op_name,
return Status::OK();
}
}
- return errors::NotFound("No XlaIf node found in graph");
+ return errors::NotFound("No If node found in graph");
}
// Graph:
@@ -115,8 +115,13 @@ TEST(FunctionalizeControlFlow, Conditional) {
auto if_op = ops::XlaIf(scope.WithOpName(op_name), less,
std::initializer_list<Input>{less, y, x}, then_fn,
else_fn, {DT_INT32});
+ auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
+ // TODO(jpienaar): Create wrapper for IfOp.
+ for (NodeDef& n : *expected.mutable_node()) {
+ if (n.op() == "XlaIf") n.set_op("If");
+ }
TF_EXPECT_GRAPH_EQ(expected, graph_def);
}
@@ -800,11 +805,11 @@ TEST(FunctionalizeControlFlow, Complex) {
auto assign = ops::AssignAddVariableOp(
scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx);
- auto one =
- ops::Const<int32>(scope.WithOpName("outer/inner/One")
- .WithControlDependencies(
- gtl::ArraySlice<Operation>{assign.operation}),
- 1);
+ auto one = ops::Const<int32>(
+ scope.WithOpName("outer/inner/One")
+ .WithControlDependencies(
+ absl::Span<const Operation>{assign.operation}),
+ 1);
auto add_j =
ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
@@ -818,7 +823,7 @@ TEST(FunctionalizeControlFlow, Complex) {
scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
auto add_i =
ops::Add(scope.WithOpName("outer/add")
- .WithControlDependencies(gtl::ArraySlice<Operation>{
+ .WithControlDependencies(absl::Span<const Operation>{
exit_j.output.op(), exit_k.output.op()}),
identity_i, one_outer);
auto next_iteration_i =
@@ -924,7 +929,7 @@ TEST(FunctionalizeControlFlow, Complex) {
scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
auto add_i =
ops::Add(scope.WithOpName("outer/add")
- .WithControlDependencies(gtl::ArraySlice<Operation>{
+ .WithControlDependencies(absl::Span<const Operation>{
while_op[0].op(), while_op[1].op()}),
identity_i, one_outer);
@@ -986,11 +991,11 @@ TEST(FunctionalizeControlFlow, Complex) {
auto assign = ops::AssignAddVariableOp(
scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx);
- auto one =
- ops::Const<int32>(scope.WithOpName("outer/inner/One")
- .WithControlDependencies(
- gtl::ArraySlice<Operation>{assign.operation}),
- 1);
+ auto one = ops::Const<int32>(
+ scope.WithOpName("outer/inner/One")
+ .WithControlDependencies(
+ absl::Span<const Operation>{assign.operation}),
+ 1);
auto add_j =
ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
@@ -1013,63 +1018,5 @@ TEST(FunctionalizeControlFlow, Complex) {
}
}
-TEST(FunctionalizeControlFlow, Cycle) {
- std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
- // -----------------------------------------------------
- // | |
- // | v
- // less -> switch_1 --> add -> merge_1 -> identity -> switch_2
- // | ^ |
- // | | v
- // --------> one -------------------------> add_2 ---> merge_2
- {
- Scope scope = Scope::NewRootScope().ExitOnError();
-
- auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
- auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
- auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
- auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), x, less);
- auto two =
- ops::Const<int32>(scope.WithOpName("cond/two")
- .WithControlDependencies(switch_1.output_true),
- 2);
- auto mul = ops::Multiply(scope.WithOpName("cond/true/mul"),
- switch_1.output_true, two);
- auto one =
- ops::Const<int32>(scope.WithOpName("cond/one")
- .WithControlDependencies(switch_1.output_false),
- 1);
- auto add = ops::Add(scope.WithOpName("cond/false/add"),
- switch_1.output_false, one);
-
- auto merge_1 = ops::Merge(scope.WithOpName("cond/Merge"),
- std::initializer_list<Input>{add, mul});
- auto identity =
- ops::Identity(scope.WithOpName("cond/Merge/identity"), merge_1.output);
- auto switch_2 =
- ops::Switch(scope.WithOpName("grad/cond/Switch"), identity, less);
- auto add_2 = ops::Add(scope.WithOpName("cond_2/false/add"),
- switch_2.output_false, one);
- auto mul_2 = ops::Multiply(scope.WithOpName("cond_2/true/mul"),
- switch_2.output_true, two);
- auto merge_2 = ops::Merge(scope.WithOpName("cond_2/Merge"),
- std::initializer_list<Input>{add_2, mul_2});
- TF_ASSERT_OK(scope.ToGraph(graph.get()));
- }
- // No cycle before functionalize control flow.
- TF_EXPECT_OK(graph::ValidateGraphHasNoCycle(*graph));
- FunctionLibraryDefinition library(OpRegistry::Global(), {});
- // switch_1 and switch_2 have the same switch depth. They are replaced by a
- // single XlaIf node during FunctionalizeControlFlow, resulting in a cycle:
- // less -> XlaIf <--> identity.
- Status status = FunctionalizeControlFlow(graph.get(), &library);
- EXPECT_FALSE(status.ok());
- EXPECT_TRUE(str_util::StrContains(status.error_message(), "Detected a cycle"))
- << status.error_message();
- EXPECT_TRUE(
- str_util::StrContains(status.error_message(), "{{node cond/Less_5_If}}"))
- << status.error_message();
-}
-
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc
new file mode 100644
index 0000000000..54cebc6177
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc
@@ -0,0 +1,72 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
+
+#include "tensorflow/core/framework/node_def.pb.h"
+
+namespace tensorflow {
+
+bool NodeCmpByNameResourcesLast::operator()(const Node* lhs,
+ const Node* rhs) const {
+ bool lhs_is_resource =
+ lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false;
+ bool rhs_is_resource =
+ rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false;
+ return std::tie(lhs_is_resource, lhs->name()) <
+ std::tie(rhs_is_resource, rhs->name());
+}
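A usage sketch for the comparator (gathering of the nodes elided; assumes a populated vector):

    std::vector<Node*> nodes;  // assumed filled, e.g. from graph->op_nodes()
    std::sort(nodes.begin(), nodes.end(), NodeCmpByNameResourcesLast());
    // Non-resource nodes come first, each group ordered by node name.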
+
+xla::StatusOr<Node*> AddNodeDefToGraph(const NodeDef& node_def, Graph* graph) {
+ Status status;
+ Node* inserted_node = graph->AddNode(node_def, &status);
+ if (!status.ok()) {
+ return status;
+ }
+ return inserted_node;
+}
+
+xla::StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index) {
+ const char* const kRetValOp = "_Retval";
+ NodeDef ret_def;
+ ret_def.set_op(kRetValOp);
+ ret_def.set_name(absl::StrCat(kRetValOp, index));
+ AddNodeAttr("T", type, &ret_def);
+ AddNodeAttr("index", index, &ret_def);
+ return AddNodeDefToGraph(ret_def, graph);
+}
+
+// Check that the graph has no cycle containing the given node.
+Status CheckNodeNotInCycle(const Node* node, const int num_nodes) {
+ std::vector<const Node*> ready;
+ ready.push_back(node);
+ std::vector<bool> visited(num_nodes);
+ while (!ready.empty()) {
+ const Node* current_node = ready.back();
+ ready.pop_back();
+ visited[current_node->id()] = true;
+ for (const Edge* out : current_node->out_edges()) {
+ if (out->dst() == node) {
+ return errors::Internal("Detected a cycle: ", FormatNodeForError(*node),
+ " (", node->def().op(), ") feeds into itself.");
+ } else if (!visited[out->dst()->id()]) {
+ ready.push_back(out->dst());
+ }
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace tensorflow
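Mirroring the call sites this helper replaces (compare the removed CheckNoCycleContains uses above), the intended pattern after inserting a node is:

    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        CheckNodeNotInCycle(if_node, graph->num_node_ids()),
        "ConvertToXlaIf failed.");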
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h
new file mode 100644
index 0000000000..582b49d511
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h
@@ -0,0 +1,56 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_
+
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/graph/graph.h"
+
+// Utility functions shared between functionalize cond and while.
+
+namespace tensorflow {
+
+// Check that the graph has no cycle containing the given node.
+Status CheckNodeNotInCycle(const Node* node, const int num_nodes);
+
+// Comparison function used for sorting nodes consistently.
+// a) resource variables are last, and
+// b) sort lexicographically by name (for deterministic output).
+struct NodeCmpByNameResourcesLast {
+ bool operator()(const Node* lhs, const Node* rhs) const;
+};
+
+// Returns the Node* created from the NodeDef in the Graph.
+xla::StatusOr<Node*> AddNodeDefToGraph(const NodeDef& node_def, Graph* graph);
+
+// Build a retval node of given type and index.
+xla::StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index);
+
+// Returns a textual representation of the names of the nodes in the input.
+template <typename T>
+string NodesToString(const T& nodes) {
+ return absl::StrCat("{",
+ absl::StrJoin(nodes, ",",
+ [](string* output, const Node* node) {
+ absl::StrAppend(output, node->name());
+ }),
+ "}");
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_
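For reference, NodesToString joins the node names with commas inside braces, in the order of the input container; three nodes named a, b and c render as:

    {a,b,c}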
diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc
new file mode 100644
index 0000000000..7f45e3bffa
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/functionalize_while.cc
@@ -0,0 +1,668 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/functionalize_while.h"
+
+#include <algorithm>
+#include <deque>
+#include <stack>
+#include <unordered_set>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/jit/union_find.h"
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/control_flow.h"
+#include "tensorflow/core/graph/node_builder.h"
+
+namespace tensorflow {
+namespace {
+
+using xla::StatusOr;
+
+// Information about a loop argument.
+struct Arg {
+ // Every loop argument has an Enter node.
+ Node* enter;
+
+ // Is the loop argument a loop-invariant value? Taken from the `is_constant`
+ // attribute on the Enter node.
+ bool is_loop_invariant;
+
+ // If 'is_loop_invariant' is true, the following are all nullptr. Non-constant
+ // arguments must have all of the following nodes:
+ Node* merge = nullptr;
+ Node* switch_node = nullptr;
+ Node* next_iteration = nullptr;
+ Node* exit = nullptr;
+};
+
+// Information about a loop frame.
+struct Frame {
+ string name;
+
+ // Pointer to the parent frame. The root frame has a pointer to itself.
+ Frame* parent = nullptr;
+ int num_children = 0;
+
+ // Arguments to this loop.
+ std::vector<Arg> args;
+
+ // The loop condition of the loop. There should be exactly one loop condition
+ // in every loop.
+ Node* loop_cond = nullptr;
+
+ // Set of nodes that belong to the loop frame.
+ std::unordered_set<Node*> nodes;
+};
+
+// Copies a subgraph from `graph` to `output` by performing a reverse DFS
+// starting at nodes in vector `stack`.
+// `node_map` is a vector, indexed by source node ID, of destination nodes.
+// Does not traverse into nodes in `node_map`, so by adding nodes to `node_map`
+// before the traversal clients can cut the graph. If a frame is provided (frame
+// != nullptr), then this function will return an error if the
+// traversal leaves 'frame'; the client must add enough nodes to `node_map` to
+// cut the graph and prevent the traversal from escaping.
+//
+// `squash_src_outputs` contains a bool for each source node ID. If true, then
+// the source output on that node will be replaced by zero when copied. This is
+// used when replacing a Switch node with an _Arg node. The output we are
+// taking from the Switch node was not necessarily the first output, but _Arg
+// nodes only have one output. By adding the Switch node to `squash_src_outputs`
+// we rewrite the src_output of the corresponding edge to be 0.
+Status CopySubgraph(const Graph& graph, const Frame* frame,
+ std::vector<Node*> stack,
+ const std::vector<bool>& squash_src_outputs,
+ std::vector<Node*>* node_map, Graph* output) {
+ VLOG(3) << "Stack: " << NodesToString(stack);
+ std::vector<bool> visited(graph.num_node_ids(), false);
+ while (!stack.empty()) {
+ Node* n = stack.back();
+ stack.pop_back();
+
+ VLOG(5) << "Copying node " << n->name();
+
+ if (visited[n->id()]) continue;
+ visited[n->id()] = true;
+
+ for (const Edge* e : n->in_edges()) {
+ Node* src = e->src();
+ if (frame != nullptr && frame->nodes.find(src) == frame->nodes.end()) {
+ // We traversed out of the loop frame, without encountering a cut node.
+ return errors::Internal("Graph traversal of loop frame ", frame->name,
+ " escaped frame at ", src->name(),
+ " without encountering an argument node.");
+ }
+ if ((*node_map)[src->id()] == nullptr) {
+ (*node_map)[src->id()] = output->CopyNode(src);
+ stack.push_back(src);
+ }
+ Node* src_copy = (*node_map)[e->src()->id()];
+ int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge()
+ ? 0
+ : e->src_output();
+ Node* dst_copy = (*node_map)[e->dst()->id()];
+ output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
+ }
+ }
+ return Status::OK();
+}
+
+StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
+ const char* const kArgOp = "_Arg";
+ NodeDef arg_def;
+ NodeDefBuilder builder(absl::StrCat(kArgOp, index), kArgOp);
+ builder.Attr("T", type);
+ builder.Attr("index", index);
+ TF_RETURN_IF_ERROR(builder.Finalize(&arg_def));
+ return AddNodeDefToGraph(arg_def, graph);
+}
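For orientation, BuildArgNode(graph, DT_INT32, 0) adds a node equivalent to this NodeDef (textproto sketch):

    name: "_Arg0"  op: "_Arg"
    attr { key: "T"     value { type: DT_INT32 } }
    attr { key: "index" value { i: 0 } }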
+
+// Builds a graph for the loop condition.
+Status BuildLoopCondition(const Graph& graph, Frame* frame,
+ std::unique_ptr<Graph>* cond_output) {
+ VLOG(2) << "Building loop condition for " << frame->name;
+ *cond_output = absl::make_unique<Graph>(graph.op_registry());
+ Graph* output = cond_output->get();
+
+ // Map from nodes in the original graph to the condition graph.
+ std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
+ std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
+
+ // Build one _Arg node for each Enter node.
+ for (int i = 0; i < frame->args.size(); ++i) {
+ const Arg& arg = frame->args[i];
+
+ TF_ASSIGN_OR_RETURN(Node * arg_node,
+ BuildArgNode(output, arg.enter->input_type(0), i));
+ if (arg.is_loop_invariant) {
+ node_map[arg.enter->id()] = arg_node;
+ } else {
+ node_map[arg.merge->id()] = arg_node;
+ }
+ }
+
+ // Build a Retval node for the loop condition. The LoopCond nodes are always
+ // boolean because of the type constraints on the LoopCond op.
+ TF_ASSIGN_OR_RETURN(node_map[frame->loop_cond->id()],
+ BuildRetvalNode(output, DT_BOOL, 0));
+
+ // Performs a reverse DFS, copying nodes and edges to the output graph.
+ // The _Arg and _Retval nodes were added unconditionally above, so we are
+ // guaranteed to get the correct function signature.
+ return CopySubgraph(graph, frame, {frame->loop_cond}, squash_src_outputs,
+ &node_map, output);
+}
+
+// Builds a graph for the loop body.
+Status BuildLoopBody(const Graph& graph, Frame* frame,
+ DataTypeVector* arg_types,
+ std::unique_ptr<Graph>* body_output) {
+ VLOG(2) << "Building loop body for " << frame->name;
+ *body_output = absl::make_unique<Graph>(graph.op_registry());
+ Graph* output = body_output->get();
+
+  // Map from nodes in the original graph to the body graph.
+ std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
+ std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
+
+ // Build one _Arg node for each Enter node.
+ std::vector<Node*> next_iterations;
+ next_iterations.reserve(frame->args.size());
+ arg_types->reserve(frame->args.size());
+ for (int i = 0; i < frame->args.size(); ++i) {
+ const Arg& arg = frame->args[i];
+
+ DataType dtype = arg.enter->input_type(0);
+ arg_types->push_back(dtype);
+
+ TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i));
+
+ if (dtype == DT_RESOURCE) {
+ // The convention of the XLA bridge is that resource variable arguments
+ // are only inputs to the loop body and have no corresponding output.
+ // TODO(b/37741920): change the convention so that DT_RESOURCE variables
+ // are both inputs and outputs, and then remove this case.
+ TF_RET_CHECK(arg.is_loop_invariant);
+ node_map[arg.enter->id()] = arg_node;
+ } else {
+ TF_ASSIGN_OR_RETURN(Node * retval_node,
+ BuildRetvalNode(output, dtype, i));
+
+ if (arg.is_loop_invariant) {
+ // Argument is loop-invariant. Forward it from the Arg to the Retval.
+ node_map[arg.enter->id()] = arg_node;
+ output->AddEdge(arg_node, 0, retval_node, 0);
+ } else {
+ // Argument is loop-varying.
+ node_map[arg.switch_node->id()] = arg_node;
+ // The Switch node has two outputs, but _Arg only has one. This tells
+ // the CopySubgraph function to rewrite the output number of edges from
+ // the _Arg node to be 0 rather than copying the output number from the
+ // Switch node.
+ squash_src_outputs[arg.switch_node->id()] = true;
+ node_map[arg.next_iteration->id()] = retval_node;
+ next_iterations.push_back(arg.next_iteration);
+ }
+ }
+ }
+
+ // Performs a reverse DFS, copying nodes and edges to the output graph.
+ // The _Arg and _Retval nodes were added unconditionally above, so we are
+ // guaranteed to get the correct function signature.
+ TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations),
+ squash_src_outputs, &node_map, output));
+
+ return Status::OK();
+}
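+// For example (illustrative), a loop carrying two values (i, acc) yields a
+// body function (i, acc) -> (i', acc'): loop-invariant arguments are
+// forwarded straight from _Arg to _Retval, while loop-varying arguments flow
+// through the copied subgraph ending at their NextIteration nodes.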
+
+// Copies the FunctionDef of the given function from lookup_library to
+// library if it can be found in lookup_library but is missing from library.
+Status AddMissingFunctionByName(const string& function_name,
+ const FunctionLibraryDefinition* lookup_library,
+ FunctionLibraryDefinition* library) {
+ if (!library->Find(function_name) && lookup_library->Find(function_name)) {
+ return library->AddFunctionDef(*lookup_library->Find(function_name));
+ }
+ return Status::OK();
+}
+
+// Iterate over all functions that the given fdef refers to. Copy the missing
+// FunctionDefs from lookup_library to library.
+Status AddMissingFunctionDef(const FunctionDef& fdef,
+ const FunctionLibraryDefinition* lookup_library,
+ FunctionLibraryDefinition* library) {
+ TF_RET_CHECK(lookup_library);
+ for (const NodeDef& node : fdef.node_def()) {
+ if (library->Find(node.op())) {
+ continue;
+ }
+ // The function referred to by a 'SymbolicGradient' node is specified in
+ // its 'f' attribute.
+ if (node.op() == FunctionLibraryDefinition::kGradientOp) {
+ const AttrValue* attr =
+ AttrSlice(&node.attr()).Find(FunctionLibraryDefinition::kFuncAttr);
+ if (!attr) {
+ return errors::InvalidArgument("SymbolicGradient is missing attr: f");
+ }
+ const string& func_name = attr->func().name();
+ TF_RETURN_IF_ERROR(
+ AddMissingFunctionByName(func_name, lookup_library, library));
+ // Copy the user-defined gradient function if it exists.
+ const string grad_name = lookup_library->FindGradient(func_name);
+ if (!grad_name.empty() && library->FindGradient(func_name).empty()) {
+ TF_RETURN_IF_ERROR(
+ AddMissingFunctionByName(grad_name, lookup_library, library));
+ GradientDef grad_def;
+ grad_def.set_function_name(func_name);
+ grad_def.set_gradient_func(grad_name);
+ TF_RETURN_IF_ERROR(library->AddGradientDef(grad_def));
+ }
+ } else if (lookup_library->Find(node.op())) {
+ TF_RETURN_IF_ERROR(
+ library->AddFunctionDef(*lookup_library->Find(node.op())));
+ }
+ }
+ return Status::OK();
+}
+
+Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
+ Graph* graph, Frame* frame,
+ FunctionLibraryDefinition* library) {
+ VLOG(2) << "Frame " << frame->name << " before: "
+ << dump_graph::DumpGraphToFile("functionalize_before", *graph,
+ library);
+
+ // Split loop-varying Enter nodes with multiple successors. If the same
+ // Tensor is fed as input to multiple loop arguments, we may end up with a
+ // shared Enter node. We clone Enter nodes with multiple successors to
+ // maintain the invariant of a unique Enter node per argument of the final
+ // loop.
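+ // For example (illustrative), if one tensor t feeds two loop-varying
+ // arguments, their Merge nodes would share a single Enter(t); after this
+ // step each Merge is fed by its own, possibly cloned, Enter node.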
+ std::vector<Arg> args;
+ for (const Arg& arg : frame->args) {
+ if (arg.is_loop_invariant) {
+ args.push_back(arg);
+ } else {
+ std::vector<const Edge*> edges(arg.enter->out_edges().begin(),
+ arg.enter->out_edges().end());
+ for (int i = 0; i < edges.size(); ++i) {
+ if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) {
+ continue;
+ }
+ TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name();
+ Arg new_arg;
+ new_arg.is_loop_invariant = false;
+ if (i == 0) {
+ new_arg.enter = arg.enter;
+ } else {
+ new_arg.enter = graph->CopyNode(arg.enter);
+ frame->nodes.insert(new_arg.enter);
+ for (Edge const* e : arg.enter->in_edges()) {
+ graph->AddEdge(e->src(), e->src_output(), new_arg.enter,
+ e->IsControlEdge() ? Graph::kControlSlot : 0);
+ }
+ Node* dst = edges[i]->dst();
+ int dst_input = edges[i]->dst_input();
+ graph->RemoveEdge(edges[i]);
+ graph->AddEdge(new_arg.enter, 0, dst, dst_input);
+ }
+ args.push_back(new_arg);
+ }
+ }
+ }
+ frame->args = std::move(args);
+
+ std::sort(frame->args.begin(), frame->args.end(),
+ [](const Arg& a, const Arg& b) {
+ return NodeCmpByNameResourcesLast()(a.enter, b.enter);
+ });
+
+ if (frame->loop_cond == nullptr) {
+ return errors::InvalidArgument("Loop ", frame->name,
+ " has no LoopCond node");
+ }
+
+ // Find the set of Switch nodes that are successors of the LoopCond.
+ std::unordered_set<Node*> switches;
+ for (const Edge* edge : frame->loop_cond->out_edges()) {
+ if (!edge->IsControlEdge() && IsSwitch(edge->dst()) &&
+ edge->dst_input() == 1) {
+ switches.insert(edge->dst());
+ }
+ }
+
+ // For each non-constant argument, looks for the following pattern of nodes:
+ //   Enter ----> Merge --------> Switch --> Exit
+ //                 ^                ^
+ //                 |                |
+ //          NextIteration       LoopCond
+ //                 ^                ^
+ //                 |                |
+ //                ...              ...
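+ // For example (illustrative), tf.while_loop(lambda i: i < 10,
+ // lambda i: i + 1, [0]) produces one such argument: Enter feeds Merge;
+ // Merge feeds both the Less node (whose output reaches LoopCond) and
+ // Switch; the Switch true output feeds the Add, which loops back through
+ // NextIteration into Merge; and the Switch false output feeds Exit.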
+ for (Arg& arg : frame->args) {
+ if (!arg.is_loop_invariant) {
+ // Follow the edge from the Enter to Merge.
+ const Edge* enter_merge = nullptr;
+ for (const Edge* e : arg.enter->out_edges()) {
+ // Ignore control-edges to the sink node. These are allowed by the
+ // graph invariants, although probably they should have been stripped
+ // off earlier.
+ if (e->IsControlEdge() && e->dst()->IsSink()) {
+ continue;
+ }
+ if (enter_merge != nullptr) {
+ return errors::Internal("Enter node for loop-varying argument ",
+ FormatNodeForError(*arg.enter),
+ " has multiple successors: ",
+ FormatNodeForError(*enter_merge->dst()),
+ " and ", FormatNodeForError(*e->dst()));
+ }
+ enter_merge = e;
+ }
+ if (enter_merge == nullptr) {
+ return errors::Internal("Enter node for loop-varying argument ",
+ FormatNodeForError(*arg.enter),
+ " has zero successors");
+ }
+ arg.merge = enter_merge->dst();
+ if (!IsMerge(arg.merge)) {
+ return errors::InvalidArgument(
+ "Successor of Enter node for loop-varying argument ",
+ FormatNodeForError(*arg.merge),
+ " is not a Merge node; got: ", arg.merge->type_string());
+ }
+
+ // Find the NextIteration from the Merge. The Merge should have exactly
+ // two inputs: the Enter found above and the NextIteration.
+ if (arg.merge->input_types().size() != 2) {
+ return errors::InvalidArgument(
+ "Unexpected number of inputs to Merge node for loop-varying "
+ "argument ",
+ FormatNodeForError(*arg.merge), "; expected 2, got ",
+ arg.merge->input_types().size());
+ }
+ TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(),
+ &arg.next_iteration));
+ if (!IsNextIteration(arg.next_iteration)) {
+ return errors::InvalidArgument(
+ "Expected NextIteration node as input to Merge node; got node ",
+ FormatNodeForError(*arg.next_iteration), " with kind ",
+ arg.next_iteration->type_string());
+ }
+
+ // Find the Switch successor of the Merge. There should be exactly one
+ // Switch node that is a successor of both the Merge and the LoopCond.
+ for (const Edge* edge : arg.merge->out_edges()) {
+ if (edge->dst_input() == 0 && IsSwitch(edge->dst()) &&
+ switches.find(edge->dst()) != switches.end()) {
+ if (arg.switch_node != nullptr) {
+ return errors::InvalidArgument("Duplicate Switch successors to ",
+ FormatNodeForError(*arg.merge));
+ }
+ arg.switch_node = edge->dst();
+ }
+ }
+ if (arg.switch_node == nullptr) {
+ return errors::InvalidArgument("Missing Switch successor to ",
+ FormatNodeForError(*arg.merge));
+ }
+
+ // Loop over the switch node's outputs to:
+ // - Find the Exit successor.
+ // - Set the sharding on all Identity outputs of the switch. These
+ // Identity nodes are values used by the loop body or condition; an
+ // Identity node may have the wrong device, so copy the device from one
+ // of its outputs instead.
+ std::deque<const Edge*> possible_exit;
+ for (const Edge* edge : arg.switch_node->out_edges()) {
+ if (edge->src_output() == 0) {
+ possible_exit.push_back(edge);
+ }
+ if (IsIdentity(edge->dst())) {
+ TF_RETURN_IF_ERROR(
+ SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true));
+ }
+ }
+ // TODO(b/67425339): Allow general graph between switch and exit.
+ while (!possible_exit.empty()) {
+ const Edge* edge = possible_exit.front();
+ possible_exit.pop_front();
+ if (IsExit(edge->dst())) {
+ if (arg.exit != nullptr) {
+ return errors::InvalidArgument(
+ "Duplicate Exit successors to ",
+ FormatNodeForError(*arg.switch_node));
+ }
+ arg.exit = edge->dst();
+ } else {
+ if (!IsIdentity(edge->dst())) {
+ return errors::Unimplemented("General graph between switch (",
+ FormatNodeForError(*arg.switch_node),
+ ") and exit node of frame ",
+ frame->name, " not supported yet.");
+ }
+ for (const Edge* out : edge->dst()->out_edges()) {
+ possible_exit.push_back(out);
+ }
+ }
+ }
+ }
+ }
+
+ // Builds the condition and body functions.
+ std::unique_ptr<Graph> cond_graph;
+ TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
+ DataTypeVector arg_types;
+ std::unique_ptr<Graph> body_graph;
+ TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
+
+ VLOG(2) << "Frame " << frame->name << " condition: "
+ << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library)
+ << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph);
+
+ static std::atomic<int64> sequence_num(0LL);
+ int64 id = ++sequence_num;
+ NameAttrList cond_name;
+ cond_name.set_name(absl::StrCat("_functionalize_cond_", id));
+ NameAttrList body_name;
+ body_name.set_name(absl::StrCat("_functionalize_body_", id));
+ FunctionDef cond_fdef;
+ TF_RETURN_IF_ERROR(
+ GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef));
+ FunctionDef body_fdef;
+ TF_RETURN_IF_ERROR(
+ GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef));
+
+ TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
+ TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
+ if (lookup_library) {
+ // Copy missing FunctionDefs from lookup_library to library to make library
+ // self-contained.
+ TF_RETURN_IF_ERROR(
+ AddMissingFunctionDef(cond_fdef, lookup_library, library));
+ TF_RETURN_IF_ERROR(
+ AddMissingFunctionDef(body_fdef, lookup_library, library));
+ }
+
+ // Builds a While operator.
+ NodeDef while_def;
+ NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile");
+ builder.Attr("T", arg_types);
+ builder.Attr("cond", cond_name);
+ builder.Attr("body", body_name);
+ std::vector<NodeDefBuilder::NodeOut> inputs;
+ for (int i = 0; i < frame->args.size(); ++i) {
+ const Arg& arg = frame->args[i];
+ const Edge* in_edge;
+ TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
+ if (in_edge->IsControlEdge()) {
+ builder.ControlInput(in_edge->src()->name());
+ } else {
+ inputs.push_back(NodeDefBuilder::NodeOut(
+ in_edge->src()->name(), in_edge->src_output(), arg_types[i]));
+ }
+ }
+ builder.Input(inputs);
+ TF_RETURN_IF_ERROR(builder.Finalize(&while_def));
+ TF_ASSIGN_OR_RETURN(Node * while_node, AddNodeDefToGraph(while_def, graph));
+
+ // Copies edges to the Enter nodes and from the Exit nodes onto the While.
+ for (int i = 0; i < frame->args.size(); ++i) {
+ const Arg& arg = frame->args[i];
+ const Edge* in_edge;
+ TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
+ if (in_edge->IsControlEdge()) {
+ graph->AddControlEdge(in_edge->src(), while_node);
+ } else {
+ graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i);
+ }
+
+ if (!arg.is_loop_invariant) {
+ // Add output edges if the output of the loop is consumed.
+ if (arg.exit != nullptr) {
+ std::vector<const Edge*> edges(arg.exit->out_edges().begin(),
+ arg.exit->out_edges().end());
+ for (const Edge* edge : edges) {
+ Node* dst = edge->dst();
+ int dst_input = edge->dst_input();
+ graph->RemoveEdge(edge);
+
+ if (dst_input == Graph::kControlSlot) {
+ graph->AddControlEdge(while_node, dst);
+ } else {
+ graph->AddEdge(while_node, i, dst, dst_input);
+ }
+ }
+ }
+ }
+ }
+
+ // Remove the old nodes from the graph, and add the while node to the parent
+ // frame.
+ for (Node* node : frame->nodes) {
+ graph->RemoveNode(node);
+ }
+ frame->nodes.clear();
+ frame->parent->nodes.insert(while_node);
+
+ VLOG(2) << "Frame " << frame->name << " after: "
+ << dump_graph::DumpGraphToFile("functionalize_after", *graph,
+ library);
+
+ return Status::OK();
+}
+} // namespace
+
+Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
+ Graph* graph,
+ FunctionLibraryDefinition* library) {
+ // Note: BuildControlFlowInfo() requires that the graph's source node is
+ // connected to all source nodes in the graph. Many graphs violate this
+ // invariant.
+ std::vector<ControlFlowInfo> cf_info;
+ std::vector<string> unreachable_nodes;
+ TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes));
+ if (!unreachable_nodes.empty()) {
+ return errors::InvalidArgument(
+ "The following nodes are unreachable from the source in the graph: ",
+ errors::FormatNodeNamesForError(unreachable_nodes));
+ }
+
+ // Builds Frames, indexed by name.
+ std::unordered_map<string, Frame> frames;
+ for (Node* node : graph->op_nodes()) {
+ const ControlFlowInfo& cf = cf_info[node->id()];
+
+ VLOG(2) << "node: " << node->name() << " (" << node->id()
+ << ") frame_name: " << cf.frame_name
+ << " frame: " << (cf.frame ? cf.frame->name() : "---")
+ << " parent_frame: "
+ << (cf.parent_frame ? cf.parent_frame->name() : "---");
+ TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr);
+
+ Frame& frame = frames[cf.frame_name];
+ Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name];
+ if (frame.parent == nullptr) {
+ frame.parent = parent;
+ frame.name = cf.frame_name;
+ ++parent->num_children;
+ }
+
+ if (IsEnter(node)) {
+ Arg arg;
+ arg.enter = node;
+ TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant",
+ &arg.is_loop_invariant));
+ frame.args.push_back(arg);
+ } else if (IsLoopCond(node)) {
+ frame.loop_cond = node;
+ }
+ frame.nodes.insert(node);
+ }
+
+ // Adds frames with no children (i.e., the innermost frames) to a worklist.
+ std::deque<Frame*> worklist;
+ for (auto& frame : frames) {
+ if (frame.second.num_children == 0) {
+ worklist.push_back(&frame.second);
+ }
+ }
+
+ // Eliminate loops from innermost to outermost.
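+ // E.g., for a loop B nested inside a loop A, B's frame has no children and
+ // is functionalized first; that decrements A's child count, so A is
+ // processed next and sees B only as a single XlaWhile node.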
+ while (!worklist.empty()) {
+ Frame* frame = worklist.front();
+ worklist.pop_front();
+ if (frame->parent == frame) {
+ // Skip the root frame.
+ continue;
+ }
+
+ TF_RETURN_IF_ERROR(
+ FunctionalizeLoop(lookup_library, graph, frame, library));
+
+ // If the parent has no remaining children, add it to the worklist.
+ --frame->parent->num_children;
+ if (frame->parent->num_children == 0) {
+ worklist.push_back(frame->parent);
+ }
+ }
+
+ // There should be no cycles at this point, since all while loops have been
+ // removed from the graph.
+ // Check that the newly added XlaWhile nodes don't feed into themselves.
+ for (const Node* node : graph->op_nodes()) {
+ if (node->def().op() == "XlaWhile") {
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ CheckNodeNotInCycle(node, graph->num_node_ids()),
+ "Functionalizing loop failed.");
+ }
+ }
+
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/functionalize_while.h b/tensorflow/compiler/tf2xla/functionalize_while.h
new file mode 100644
index 0000000000..a708c6e4ec
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/functionalize_while.h
@@ -0,0 +1,32 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_WHILE_H_
+#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_WHILE_H_
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+// Transformation that converts tf.while_loop() loops into functional While
+// operators, suitable for XLA compilation. If lookup_library is provided, use
+// it to make the library for control flow self-contained.
+Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
+ Graph* graph, FunctionLibraryDefinition* library);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_WHILE_H_
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc
index e4fdf0a618..82e9eef005 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -57,7 +57,8 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
std::vector<bool> compile_time_constant_flags(expressions.size());
TF_RETURN_IF_ERROR(
- BackwardsConstAnalysis(*graph, &compile_time_constant_flags));
+ BackwardsConstAnalysis(*graph, &compile_time_constant_flags,
+ /*compile_time_const_nodes=*/nullptr));
args->resize(expressions.size());
for (int i = 0; i < args->size(); ++i) {
@@ -80,7 +81,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
TF_ASSIGN_OR_RETURN(auto literal,
client->ComputeConstant(constant_graph));
TF_RETURN_IF_ERROR(
- LiteralToHostTensor(*literal, arg.type, &arg.constant_value));
+ LiteralToHostTensor(literal, arg.type, &arg.constant_value));
} else {
arg.kind = XlaCompiler::Argument::kParameter;
}
@@ -126,7 +127,7 @@ Status GraphCompiler::Compile() {
TF_RET_CHECK(!n->IsRecv() && !n->IsSend() && !n->IsSwitch())
<< "Not supported node: " << n->DebugString();
params.op_kernel = op_kernel.get();
- gtl::InlinedVector<AllocatorAttributes, 4> output_attr(n->num_outputs());
+ absl::InlinedVector<AllocatorAttributes, 4> output_attr(n->num_outputs());
params.output_attr_array = output_attr.data();
// tensor_inputs_ is a buffer reused across graph traversal. We clean up and
@@ -145,6 +146,7 @@ Status GraphCompiler::Compile() {
}
OpKernelContext op_context(&params, n->num_outputs());
+ VLOG(3) << "Translating " << params.op_kernel->name();
if (IsFunctional(n)) {
TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context));
} else {
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.h b/tensorflow/compiler/tf2xla/graph_compiler.h
index 127562eb23..ab7cac7100 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.h
+++ b/tensorflow/compiler/tf2xla/graph_compiler.h
@@ -89,7 +89,7 @@ class GraphCompiler {
ScopedStepContainer* step_container_;
// A buffer to hold tensor inputs to a node, this is reused across the graph
// traversal.
- gtl::InlinedVector<TensorValue, 4> tensor_inputs_;
+ absl::InlinedVector<TensorValue, 4> tensor_inputs_;
};
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index b1366e9e31..46794f7b50 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -22,6 +22,7 @@ tf_kernel_library(
"bcast_ops.cc",
"bias_ops.cc",
"binary_ops.cc",
+ "broadcast_to_op.cc",
"bucketize_op.cc",
"cast_op.cc",
"categorical_op.cc",
@@ -100,6 +101,12 @@ tf_kernel_library(
"unary_ops.cc",
"unpack_op.cc",
"variable_ops.cc",
+ "xla_broadcast_helper_op.cc",
+ "xla_conv_op.cc",
+ "xla_dot_op.cc",
+ "xla_pad_op.cc",
+ "xla_reduce_op.cc",
+ "xla_select_and_scatter_op.cc",
],
hdrs = [
"index_ops.h",
@@ -158,14 +165,11 @@ tf_kernel_library(
"//tensorflow/core/kernels:sparse_to_dense_op",
"//tensorflow/core/kernels:stack_ops",
"//tensorflow/core/kernels:training_ops",
- ] + if_mkl(
- [
- "//tensorflow/core/kernels:mkl_transpose_op",
- ],
- [
- "//tensorflow/core/kernels:transpose_op",
- ],
- ),
+ "//tensorflow/core/kernels:transpose_op",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
)
tf_kernel_library(
@@ -174,6 +178,7 @@ tf_kernel_library(
hdrs = ["while_op.h"],
deps = [
"//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal",
@@ -191,6 +196,7 @@ tf_kernel_library(
hdrs = ["if_op.h"],
deps = [
"//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal",
diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
index 48f2a005ab..a18e04995b 100644
--- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
@@ -23,10 +23,10 @@ namespace {
void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
DataType input_dtype, const TensorShape& input_tensor_shape,
- gtl::ArraySlice<int64> block_shape,
+ absl::Span<const int64> block_shape,
const xla::Literal& crops) {
const int input_rank = input_tensor_shape.dims();
- const gtl::InlinedVector<int64, 4> input_shape =
+ const absl::InlinedVector<int64, 4> input_shape =
input_tensor_shape.dim_sizes();
const int block_rank = block_shape.size();
@@ -34,7 +34,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
ctx, input_rank >= 1 + block_rank,
errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
" instead of ", input_rank));
- gtl::ArraySlice<int64> remainder_shape(input_shape);
+ absl::Span<const int64> remainder_shape(input_shape);
remainder_shape.remove_prefix(1 + block_rank);
OP_REQUIRES(
diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
index ba3b1c9dab..182f7c9934 100644
--- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
@@ -16,6 +16,7 @@ limitations under the License.
// XLA-specific Ops for broadcasting used in gradient
// code.
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -38,7 +39,7 @@ class BCastArgsOp : public XlaOpKernel {
OP_REQUIRES(
ctx, ctx->num_inputs() == 2,
errors::Unimplemented("Broadcast for n-ary operations (n > 2)"));
- gtl::InlinedVector<BCast::Vec, 2> shapes;
+ absl::InlinedVector<BCast::Vec, 2> shapes;
for (int i = 0; i < ctx->num_inputs(); ++i) {
const TensorShape in_shape = ctx->InputShape(i);
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape),
@@ -51,8 +52,8 @@ class BCastArgsOp : public XlaOpKernel {
BCast bcast(shapes[0], shapes[1]);
OP_REQUIRES(ctx, bcast.IsValid(),
errors::InvalidArgument(
- "Incompatible shapes: [", str_util::Join(shapes[0], ","),
- "] vs. [", str_util::Join(shapes[1], ","), "]"));
+ "Incompatible shapes: [", absl::StrJoin(shapes[0], ","),
+ "] vs. [", absl::StrJoin(shapes[1], ","), "]"));
const int64 len = bcast.output_shape().size();
Tensor output(DT_INT32, TensorShape({len}));
@@ -87,7 +88,7 @@ class BCastGradArgsOp : public XlaOpKernel {
ctx, ctx->num_inputs() == 2,
errors::Unimplemented("Broadcast for n-ary operations (n > 2)"));
- gtl::InlinedVector<BCast::Vec, 4> shapes;
+ absl::InlinedVector<BCast::Vec, 4> shapes;
for (int i = 0; i < ctx->num_inputs(); ++i) {
const TensorShape in_shape = ctx->InputShape(i);
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape),
@@ -105,8 +106,8 @@ class BCastGradArgsOp : public XlaOpKernel {
BCast bcast(shapes[0], shapes[1]);
OP_REQUIRES(ctx, bcast.IsValid(),
errors::InvalidArgument(
- "Incompatible shapes: [", str_util::Join(shapes[0], ","),
- "] vs. [", str_util::Join(shapes[1], ","), "]"));
+ "Incompatible shapes: [", absl::StrJoin(shapes[0], ","),
+ "] vs. [", absl::StrJoin(shapes[1], ","), "]"));
Output(ctx, 0, bcast.grad_x_reduce_idx());
Output(ctx, 1, bcast.grad_y_reduce_idx());
}
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index 2c328102e0..df17da4c1c 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -30,21 +30,21 @@ namespace {
// A subclass of a XlaBinaryOp must build the computation that
// describes the (tensor,tensor)->tensor function to apply to each element of
// the input.
-#define XLA_MAKE_BINARY(NAME, HLO) \
- class NAME##Op : public XlaBinaryOp { \
- public: \
- explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \
- xla::XlaOp Computation( \
- XlaOpKernelContext* ctx, const xla::XlaOp& lhs, \
- const gtl::ArraySlice<int64>& lhs_shape, const xla::XlaOp& rhs, \
- const gtl::ArraySlice<int64>& rhs_shape, \
- const BCast& broadcast_helper, \
- const std::vector<int64>& extend_dimensions) override { \
- xla::XlaBuilder* b = ctx->builder(); \
- (void)b; \
- return HLO; \
- } \
- }; \
+#define XLA_MAKE_BINARY(NAME, HLO) \
+ class NAME##Op : public XlaBinaryOp { \
+ public: \
+ explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \
+ xla::XlaOp Computation( \
+ XlaOpKernelContext* ctx, const xla::XlaOp& lhs, \
+ const absl::Span<const int64>& lhs_shape, const xla::XlaOp& rhs, \
+ const absl::Span<const int64>& rhs_shape, \
+ const BCast& broadcast_helper, \
+ const std::vector<int64>& extend_dimensions) override { \
+ xla::XlaBuilder* b = ctx->builder(); \
+ (void)b; \
+ return HLO; \
+ } \
+ }; \
REGISTER_XLA_OP(Name(#NAME), NAME##Op)
XLA_MAKE_BINARY(Add, xla::Add(lhs, rhs, extend_dimensions));
diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
new file mode 100644
index 0000000000..4bd7c74dca
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
@@ -0,0 +1,101 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/algorithm/container.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/bcast.h"
+
+namespace tensorflow {
+namespace {
+
+class BroadcastToOp : public XlaOpKernel {
+ public:
+ explicit BroadcastToOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape input_shape = context->InputShape(0);
+ TensorShape output_shape;
+ OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape));
+
+ OP_REQUIRES(context, input_shape.dims() <= output_shape.dims(),
+ errors::InvalidArgument(
+ "Input rank (", input_shape.dims(),
+ ") must be less than or equal to the output rank (",
+ output_shape.dims(), ")"));
+
+ auto input_dims = input_shape.dim_sizes();
+ auto output_dims = output_shape.dim_sizes();
+
+ // Broadcasting is done right-to-left on right-aligned dimensions; reverse
+ // the two vectors so elements to be broadcast are aligned.
+ absl::c_reverse(input_dims);
+ absl::c_reverse(output_dims);
+
+ std::vector<int64> broadcast_dims;
+ std::vector<int64> broadcast_shape;
+ for (int i = 0; i < output_shape.dims(); ++i) {
+ if (i < input_shape.dims()) {
+ OP_REQUIRES(
+ context,
+ (output_dims[i] == 0 && input_dims[i] == 0) ||
+ (input_dims[i] != 0 && output_dims[i] % input_dims[i] == 0),
+ errors::InvalidArgument("invalid shape to broadcast from ",
+ input_shape.DebugString(), " to ",
+ output_shape.DebugString()));
+
+ broadcast_dims.push_back(broadcast_shape.size());
+ if (output_dims[i] == input_dims[i] || input_dims[i] == 1) {
+ broadcast_shape.push_back(output_dims[i]);
+ } else {
+ // Add dimensions [I, O/I], which we will later flatten to just
+ // [O]. We must do this in two phases since XLA broadcasting does not
+ // support tiling.
+ broadcast_shape.push_back(input_dims[i]);
+ broadcast_shape.push_back(output_dims[i] / input_dims[i]);
+ }
+ } else {
+ broadcast_shape.push_back(output_dims[i]);
+ }
+ }
+ absl::c_reverse(broadcast_dims);
+ int broadcast_shape_size = broadcast_shape.size();
+ for (int64& broadcast_dim : broadcast_dims) {
+ broadcast_dim = broadcast_shape_size - broadcast_dim - 1;
+ }
+ absl::c_reverse(broadcast_shape);
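+ // Worked example (illustrative): broadcasting [3] to [2, 6] produces
+ // broadcast_shape = [2, 2, 3] and broadcast_dims = {2}; BroadcastInDim
+ // maps the input's single dimension onto the trailing 3, and the Reshape
+ // below flattens [2, 2, 3] into the requested [2, 6].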
+ xla::XlaOp output = xla::Reshape(
+ xla::BroadcastInDim(context->Input(0),
+ xla::ShapeUtil::MakeShape(
+ context->input_xla_type(0), broadcast_shape),
+ broadcast_dims),
+ output_shape.dim_sizes());
+ context->SetOutput(0, output);
+ }
+};
+
+REGISTER_XLA_OP(Name("BroadcastTo").CompileTimeConstInput("shape"),
+ BroadcastToOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index 5da7972397..674720e22f 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -120,45 +120,30 @@ xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape,
{expanded_filter_shape.dims() - 2});
}
-// Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding
-// zeros for the cross-depth filters. Used to build a depthwise convolution.
-xla::XlaOp ExpandFilterForDepthwiseConvolution(const TensorShape& filter_shape,
- DataType dtype,
- const xla::XlaOp& filter,
- xla::XlaBuilder* builder) {
- int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1);
- int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2);
- TensorShape expanded_filter_shape =
- ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
+// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to
+// build a depthwise convolution.
+xla::XlaOp ReshapeFilterForDepthwiseConvolution(const TensorShape& filter_shape,
+ const xla::XlaOp& filter) {
+ int64 input_feature_dim = filter_shape.dims() - 2;
+ int64 output_feature_dim = filter_shape.dims() - 1;
+ int64 depthwise_multiplier = filter_shape.dim_size(output_feature_dim);
+ int64 input_feature = filter_shape.dim_size(input_feature_dim);
// Create a [H, W, ..., 1, N*M] reshape of the filter.
- TensorShape implicit_broadcast_filter_shape = expanded_filter_shape;
- implicit_broadcast_filter_shape.set_dim(
- implicit_broadcast_filter_shape.dims() - 2, 1);
- implicit_broadcast_filter_shape.set_dim(
- implicit_broadcast_filter_shape.dims() - 1,
- depthwise_multiplier * input_feature);
- auto implicit_broadcast_filter =
- xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes());
-
- // Broadcast the filter to [H, W, ..., M, M*N].
- auto expanded_zero = CreateExpandedZero(filter_shape, dtype, builder);
- auto expanded_filter = xla::Add(implicit_broadcast_filter, expanded_zero);
-
- // If the filter mask is set, choose the broadcasted filter, othwerwise,
- // choose zero.
- return xla::Select(CreateExpandedFilterMask(filter_shape, builder),
- expanded_filter, expanded_zero);
+ TensorShape implicit_broadcast_filter_shape = filter_shape;
+ implicit_broadcast_filter_shape.set_dim(input_feature_dim, 1);
+ implicit_broadcast_filter_shape.set_dim(output_feature_dim,
+ depthwise_multiplier * input_feature);
+ return xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes());
}
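+// For example (illustrative), a [3, 3, 8, 2] depthwise filter (M=8 input
+// channels, depth multiplier N=2) becomes [3, 3, 1, 16], the layout used by
+// the grouped convolution below.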
-// Inverse of ExpandFilterForDepthwiseConvolution.
+// Reduces the results of the convolution with an expanded filter to the
+// non-expanded filter.
xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx,
const TensorShape& filter_shape,
DataType dtype,
const xla::XlaOp& filter_backprop,
xla::XlaBuilder* builder) {
- TensorShape expanded_filter_shape =
- ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
auto masked_expanded_filter = xla::Select(
CreateExpandedFilterMask(filter_shape, builder), filter_backprop,
CreateExpandedZero(filter_shape, dtype, builder));
@@ -168,8 +153,7 @@ xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx,
// ExpandedZero guarantees that only one element is non zero, so there
// cannot be accumulated precision error.
xla::Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype),
- *ctx->GetOrCreateAdd(dtype),
- {expanded_filter_shape.dims() - 2}),
+ *ctx->GetOrCreateAdd(dtype), {filter_shape.dims() - 2}),
filter_shape.dim_sizes());
}
@@ -245,15 +229,9 @@ class ConvOp : public XlaOpKernel {
"input and filter must have the same depth: ", in_depth,
" vs ", input_shape.dim_size(feature_dim)));
- xla::XlaBuilder* b = ctx->builder();
-
xla::XlaOp filter = ctx->Input(1);
- TensorShape expanded_filter_shape = filter_shape;
if (depthwise_) {
- filter = ExpandFilterForDepthwiseConvolution(
- filter_shape, ctx->input_type(0), filter, b);
- expanded_filter_shape =
- ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
+ filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter);
}
xla::ConvolutionDimensionNumbers dims;
@@ -280,14 +258,15 @@ class ConvOp : public XlaOpKernel {
int64 unused_output_size;
OP_REQUIRES_OK(
ctx, GetWindowedOutputSizeVerboseV2(
- input_shape.dim_size(dim), expanded_filter_shape.dim_size(i),
+ input_shape.dim_size(dim), filter_shape.dim_size(i),
rhs_dilation[i], window_strides[i], padding_,
&unused_output_size, &padding[i].first, &padding[i].second));
}
- xla::XlaOp conv =
- xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
- lhs_dilation, rhs_dilation, dims);
+ xla::XlaOp conv = xla::ConvGeneralDilated(
+ ctx->Input(0), filter, window_strides, padding, lhs_dilation,
+ rhs_dilation, dims,
+ /*feature_group_count=*/depthwise_ ? in_depth : 1);
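+ // Setting feature_group_count to the input depth makes XLA treat this as a
+ // grouped convolution with one input channel per group, which is exactly a
+ // depthwise convolution.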
ctx->SetOutput(0, conv);
}
@@ -388,7 +367,6 @@ class ConvBackpropInputOp : public XlaOpKernel {
expanded_filter_shape, out_backprop_shape, dilations_,
strides_, padding_, data_format_, &dims));
- xla::XlaBuilder* b = ctx->builder();
auto filter = ctx->Input(1);
auto out_backprop = ctx->Input(2);
@@ -425,12 +403,6 @@ class ConvBackpropInputOp : public XlaOpKernel {
rhs_dilation[i] = dilations_[dim];
}
- // If this is a depthwise convolution, expand the filter.
- if (depthwise_) {
- filter = ExpandFilterForDepthwiseConvolution(
- filter_shape, ctx->input_type(1), filter, b);
- }
-
// Mirror the filter in the spatial dimensions.
xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims);
@@ -438,7 +410,11 @@ class ConvBackpropInputOp : public XlaOpKernel {
// = gradients (with padding and dilation) <conv> mirrored_weights
xla::XlaOp in_backprop = xla::ConvGeneralDilated(
out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
- lhs_dilation, rhs_dilation, dnums);
+ lhs_dilation, rhs_dilation, dnums,
+ /*feature_group_count=*/
+ depthwise_ ? out_backprop_shape.dim_size(feature_dim) /
+ filter_shape.dim_size(num_spatial_dims_ + 1)
+ : 1);
ctx->SetOutput(0, in_backprop);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
index a5b870f8db..6653944a91 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
@@ -57,8 +57,8 @@ class XlaBinaryOp : public XlaOpKernel {
// in the XLA documentation.
virtual xla::XlaOp Computation(
XlaOpKernelContext* ctx, const xla::XlaOp& lhs,
- const gtl::ArraySlice<int64>& lhs_shape, const xla::XlaOp& rhs,
- const gtl::ArraySlice<int64>& rhs_shape, const BCast& broadcast_helper,
+ const absl::Span<const int64>& lhs_shape, const xla::XlaOp& rhs,
+ const absl::Span<const int64>& rhs_shape, const BCast& broadcast_helper,
const std::vector<int64>& extend_dimensions) = 0;
void Compile(XlaOpKernelContext* ctx) override;
diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
index 12b0e38288..e96a1adce4 100644
--- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
@@ -48,7 +48,7 @@ class DepthToSpaceOp : public XlaOpKernel {
OP_REQUIRES(ctx, kRequiredDims == input_rank,
errors::InvalidArgument("Input rank should be ", kRequiredDims,
"; got: ", input_rank));
- const gtl::InlinedVector<int64, 4> input_shape =
+ const absl::InlinedVector<int64, 4> input_shape =
input_tensor_shape.dim_sizes();
xla::XlaOp input = ctx->Input(0);
diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
index ed44ad218b..49c12fc232 100644
--- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
@@ -29,7 +29,7 @@ namespace {
// Create a diagonal / batch diagonal matrix with 'input' on the diagonal.
xla::XlaOp CreateDiagonal(xla::XlaOp input, int64 last_dim_size,
- gtl::ArraySlice<int64> other_dims,
+ absl::Span<const int64> other_dims,
xla::PrimitiveType element_type) {
xla::XlaBuilder* builder = input.builder();
// Create two matrices that have the following forms, and compare them:
@@ -177,8 +177,8 @@ class MatrixDiagOp : public XlaOpKernel {
int last_dim = dims.size() - 1;
int64 last_dim_size = input_shape.dim_size(last_dim);
- tensorflow::gtl::ArraySlice<int64> other_dims(dims);
- other_dims.pop_back();
+ absl::Span<const int64> other_dims(dims);
+ other_dims.remove_suffix(1);
xla::XlaOp input = ctx->Input(0);
xla::XlaOp diag = CreateDiagonal(input, last_dim_size, other_dims,
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
index 35de96e0aa..44140304fd 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
@@ -95,11 +95,11 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
// operand = s32[3,3] parameter(0)
// indices = s32[2] parameter(1)
// gather = s32[3,2] gather(operand, indices),
- // output_window_dims={0},
- // elided_window_dims={1},
- // gather_dims_to_operand_dims={1},
+ // offset_dims={0},
+ // collapsed_slice_dims={1},
+ // start_index_map={1},
// index_vector_dim=1,
- // window_bounds={3, 1}
+ // slice_sizes={3, 1}
//
//
// Example of an N-D gather pulling out slices of shape [1,1,2] out of a
@@ -108,42 +108,42 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
// operand = s32[3,3,2] parameter(0)
// indices = s32[2,2] parameter(1)
// gather = s32[2,2] gather(operand, indices),
- // output_window_dims={1},
- // elided_window_dims={0,1},
- // gather_dims_to_operand_dims={0,1},
+ // offset_dims={1},
+ // collapsed_slice_dims={0,1},
+ // start_index_map={0,1},
// index_vector_dim=0,
- // window_bounds={1,1,2}
+ // slice_sizes={1,1,2}
xla::GatherDimensionNumbers dim_numbers;
- std::vector<int64> window_bounds;
- window_bounds.reserve(input_shape.dims());
+ std::vector<int64> slice_sizes;
+ slice_sizes.reserve(input_shape.dims());
for (int64 i = 0; i < input_shape.dims(); i++) {
int64 window_bound;
if (axis <= i && i < (axis + num_index_dims)) {
- dim_numbers.add_elided_window_dims(i);
+ dim_numbers.add_collapsed_slice_dims(i);
window_bound = 1;
} else {
window_bound = input_shape.dim_size(i);
}
- window_bounds.push_back(window_bound);
+ slice_sizes.push_back(window_bound);
if (i < axis) {
- dim_numbers.add_output_window_dims(i);
+ dim_numbers.add_offset_dims(i);
} else if (i >= (axis + num_index_dims)) {
int64 indices_rank =
indices_are_nd ? (indices_shape.dims() - 1) : indices_shape.dims();
- dim_numbers.add_output_window_dims(i + indices_rank - num_index_dims);
+ dim_numbers.add_offset_dims(i + indices_rank - num_index_dims);
}
}
dim_numbers.set_index_vector_dim(indices_are_nd ? (indices_shape.dims() - 1)
: indices_shape.dims());
for (int64 i = axis; i < axis + num_index_dims; i++) {
- dim_numbers.add_gather_dims_to_operand_dims(i);
+ dim_numbers.add_start_index_map(i);
}
- *gather_output = xla::Gather(input, indices, dim_numbers, window_bounds);
+ *gather_output = xla::Gather(input, indices, dim_numbers, slice_sizes);
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc
index e72200bfbc..19dd38c46e 100644
--- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc
@@ -25,7 +25,10 @@ class IdentityOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
for (int i = 0; i < ctx->num_inputs(); ++i) {
- ctx->SetOutput(i, ctx->Input(i));
+ // Forwards using the underlying op_kernel_context so both tensor and
+ // resource values are forwarded correctly.
+ ctx->op_kernel_context()->set_output(i,
+ ctx->op_kernel_context()->input(i));
}
}
@@ -35,9 +38,10 @@ class IdentityOp : public XlaOpKernel {
// XLA_* devices also register a "real" Identity operator so we suppress the
// dummy operator using CompilationOnly().
-REGISTER_XLA_OP(Name("Identity").CompilationOnly(), IdentityOp);
-
-REGISTER_XLA_OP(Name("IdentityN").CompilationOnly(), IdentityOp);
+REGISTER_XLA_OP(Name("Identity").AllowResourceTypes().CompilationOnly(),
+ IdentityOp);
+REGISTER_XLA_OP(Name("IdentityN").AllowResourceTypes().CompilationOnly(),
+ IdentityOp);
REGISTER_XLA_OP(Name("PlaceholderWithDefault"), IdentityOp);
REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp);
REGISTER_XLA_OP(Name("StopGradient"), IdentityOp);
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc
index 6a7eb8d90c..56da50f140 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/if_op.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -33,6 +34,11 @@ XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcond", &cond_type_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_));
+ if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
+ has_token_input_output_ = false;
+ } else {
+ has_token_input_output_ = !token_input_nodes_.empty();
+ }
}
// TODO(b/35949885): There is duplication here with the handling of the
@@ -90,6 +96,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
options.resolve_compile_time_constants = false;
options.return_updated_values_for_all_resources = true;
options.is_entry_computation = false;
+ options.add_token_input_output = has_token_input_output_;
XlaCompiler* compiler = ctx->compiler();
XlaCompiler::CompilationResult then_result;
@@ -191,7 +198,16 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
std::vector<xla::XlaOp> inputs(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
int input_num = then_result.input_mapping[i] + 1;
- if (ctx->input_type(input_num) == DT_RESOURCE) {
+ if (has_token_input_output_ && i == num_inputs - 1) {
+ // Set token input for this "if" op.
+ std::vector<xla::XlaOp> token_inputs;
+ for (const string& node_name : token_input_nodes_) {
+ auto token_or = compiler->GetNodeToken(node_name);
+ OP_REQUIRES_OK(ctx, token_or.status());
+ token_inputs.push_back(token_or.ValueOrDie());
+ }
+ inputs[i] = xla::AfterAll(b, token_inputs);
+ } else if (ctx->input_type(input_num) == DT_RESOURCE) {
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b));
@@ -200,21 +216,10 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
}
}
- bool resource_variable_seen = false;
- for (int i = 0; i < ctx->num_inputs(); ++i) {
- if (ctx->input_type(i) == DT_RESOURCE) {
- resource_variable_seen = true;
- } else {
- OP_REQUIRES(
- ctx, !resource_variable_seen,
- errors::FailedPrecondition(
- "Resource variables and regular inputs cannot be interleaved."));
- }
- }
-
- xla::XlaOp outputs = xla::Conditional(
- ctx->Input(0), xla::Tuple(b, inputs), *then_result.computation,
- xla::Tuple(b, inputs), *else_result.computation);
+ auto input_tuple = xla::Tuple(b, inputs);
+ xla::XlaOp outputs =
+ xla::Conditional(ctx->Input(0), input_tuple, *then_result.computation,
+ input_tuple, *else_result.computation);
// Sets non-variable outputs.
for (int i = 0; i < output_types_.size(); ++i) {
xla::XlaOp output_handle = xla::GetTupleElement(outputs, i);
@@ -230,6 +235,18 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
}
ctx->SetOutput(i, output_handle);
}
+ if (has_token_input_output_) {
+ // Set token output for this "if" op.
+ xla::XlaOp token_output =
+ xla::GetTupleElement(outputs, output_types_.size());
+ auto shape_or = b->GetShape(token_output);
+ OP_REQUIRES_OK(ctx, shape_or.status());
+ OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()),
+ errors::FailedPrecondition(
+ "Token output is not token type: ",
+ xla::ShapeUtil::HumanString(shape_or.ValueOrDie())));
+ OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output));
+ }
// Updates the values of any resource variables modified by the conditional
// bodies.
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.h b/tensorflow/compiler/tf2xla/kernels/if_op.h
index f9bc98a198..7783e13a8a 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.h
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.h
@@ -52,6 +52,8 @@ class XlaIfOp : public XlaOpKernel {
DataType cond_type_;
DataTypeVector input_types_;
DataTypeVector output_types_;
+ bool has_token_input_output_;
+ std::vector<string> token_input_nodes_;
};
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
index 8d75624e74..d9a0257b70 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
@@ -32,13 +32,13 @@ namespace {
//
// 1. S := (N - 1) / gcd(N-1, R-1)
// 2. k := (R - 1) / gcd(N-1, R-1)
-// 3. Convolution(kxk, stride=S, lhs_dilation=k, padding=k-1)
+// 3. Convolution((2k-1)x(2k-1), stride=S, lhs_dilation=k, padding=k-1)
//
// For example, to Scale from 7x7 -> 15x15:
//
// 1. S := (7-1) / gcd(7-1, 15-1) = 6 / gcd(6, 14) = 6 / 2 = 3
// 2. k := (15 - 1) / gcd(7-1, 15-1) = 14 / gcd(6, 14) = 14 / 2 = 7
-// 3. Convolution(7x7, stride=3, lhs_dilation=3, padding=2)
+// 3. Convolution(13x13, stride=3, lhs_dilation=7, padding=6)
//
//
// The 7x7 -> 15x15 case is much too large to write out in full as an
@@ -65,6 +65,8 @@ namespace {
// 1/9 * 3 6 9 6 3
// 2 4 6 4 2
// 1 2 3 2 1
+// Note that the convolution kernel matrix is separable, so we can instead
+// apply two consecutive 1D kernels of dimension 2k-1, one along each axis.
// Computes the size of the convolutional kernel and stride to use when resizing
// from in_size to out_size.
@@ -76,7 +78,8 @@ struct ResizeConvolutionDims {
std::vector<int64> stride;
};
ResizeConvolutionDims ComputeResizeConvolutionParameters(
- gtl::ArraySlice<int64> in_size, gtl::ArraySlice<int64> out_size) {
+ absl::Span<const int64> in_size, absl::Span<const int64> out_size,
+ bool align_corners) {
CHECK_EQ(in_size.size(), out_size.size());
int num_spatial_dims = in_size.size();
ResizeConvolutionDims dims;
@@ -92,15 +95,32 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters(
// entry before resizing.
dims.stride[i] = dims.kernel_size[i] = 1;
} else {
- int64 gcd = MathUtil::GCD(static_cast<uint64>(in_size[i] - 1),
- static_cast<uint64>(out_size[i] - 1));
- dims.stride[i] = (in_size[i] - 1) / gcd;
- dims.kernel_size[i] = (out_size[i] - 1) / gcd;
+ // The scaling factor changes depending on the alignment of corners.
+ const int64 in_size_factor = align_corners ? in_size[i] - 1 : in_size[i];
+ const int64 out_size_factor =
+ align_corners ? out_size[i] - 1 : out_size[i];
+
+ int64 gcd = MathUtil::GCD(static_cast<uint64>(in_size_factor),
+ static_cast<uint64>(out_size_factor));
+ dims.stride[i] = in_size_factor / gcd;
+ dims.kernel_size[i] = out_size_factor / gcd;
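+ // For instance (illustrative): resizing 4 -> 8 with align_corners=false
+ // uses factors 4 and 8, so gcd=4, stride=1, kernel_size=2; with
+ // align_corners=true the factors are 3 and 7, so gcd=1, stride=3,
+ // kernel_size=7.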
}
}
return dims;
}
+// The upper padding of the input needed by the ConvGeneralDilated calls is
+// determined by solving two related relationships (assuming rhs_dilation == 1,
+// i.e. no kernel dilation):
+// 1. dilated_input_dim = lower_padding + upper_padding
+// + lhs_dilation * (in_size - 1) + 1
+// 2. dilated_input_dim = (2 * dims.kernel_size - 1)
+// + dims.stride * (out_size - 1)
+int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size,
+ int64 stride) {
+ return (2 * kernel_size - 1) + (out_size - 1) * stride - (kernel_size - 1) -
+ 1 - (kernel_size * (in_size - 1));
+}
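+// For instance (illustrative): with in_size=7, out_size=15, kernel_size=7,
+// stride=3 this yields 13 + 42 - 6 - 1 - 42 = 6, and both relationships above
+// agree on dilated_input_dim = 55.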
+
// Form a 2D convolution kernel like:
// 1 2 3 2 1
// 2 4 6 4 2
@@ -127,7 +147,7 @@ std::vector<float> Make1DKernel(int64 n) {
const int64 kMax2DKernelSize = 16;
xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder,
- gtl::ArraySlice<int64> kernel_size,
+ absl::Span<const int64> kernel_size,
int64 channels) {
xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
@@ -145,7 +165,7 @@ xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder,
}
xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder,
- gtl::ArraySlice<int64> kernel_size,
+ absl::Span<const int64> kernel_size,
int64 channels, int64 dim) {
xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
@@ -171,7 +191,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
const int num_spatial_dims,
std::vector<int64> in_size,
std::vector<int64> out_size,
- const int64 channels) {
+ const int64 channels,
+ const bool align_corners) {
// Picture for a 1x3 to 1x4 resize:
// stride = 2, kernel size = 3
// Input:
@@ -196,27 +217,82 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);
ResizeConvolutionDims dims =
- ComputeResizeConvolutionParameters(in_size, out_size);
+ ComputeResizeConvolutionParameters(in_size, out_size, align_corners);
xla::XlaOp output;
- // Split convolutions into independent dimensions if they wmuld be a very
+
+ // The concatenation and padding below currently assume num_spatial_dims is
+ // 2 to prevent needless code complexity.
+ CHECK_EQ(num_spatial_dims, 2)
+ << "ResizeUsingDilationAndConvolution pads only 2 dimensions currently.";
+ std::vector<int64> upper_padding(num_spatial_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ upper_padding[i] = dims.kernel_size[i] - 1;
+ }
+ xla::XlaOp input_data = input;
+
+ if (!align_corners) {
+ // When TensorFlow resizes without align_corners, the resize indexing can
+ // access beyond the upper bound and is instead clamped to prevent out of
+ // bounds reads. This is conceptually the same as extending the edges of the
+ // input. We emulate this by copying the last row/column of the input.
+ // Calculate what padding would be needed, then determine how far to extend
+ // the border before lhs dilation.
+ std::vector<int64> num_extended(num_spatial_dims);
+ upper_padding[0] = CalculateUpperPadding(
+ in_size[0], out_size[0], dims.kernel_size[0], dims.stride[0]);
+ upper_padding[1] = CalculateUpperPadding(
+ in_size[1], out_size[1], dims.kernel_size[1], dims.stride[1]);
+ num_extended[0] = upper_padding[0] / (dims.kernel_size[0]);
+ num_extended[1] = upper_padding[1] / (dims.kernel_size[1]);
+
+ if (num_extended[0] > 0) {
+ auto slice =
+ xla::Slice(input_data, {0, in_size[0] - 1, 0, 0},
+ {1, in_size[0], in_size[1], channels}, {1, 1, 1, 1});
+ for (int i = 0; i < num_extended[0]; i++) {
+ input_data = xla::ConcatInDim(builder, {input_data, slice}, 1);
+ }
+ }
+
+ if (num_extended[1] > 0) {
+ auto slice =
+ xla::Slice(input_data, {0, 0, in_size[1] - 1, 0},
+ {1, in_size[0] + num_extended[0], in_size[1], channels},
+ {1, 1, 1, 1});
+ for (int i = 0; i < num_extended[1]; i++) {
+ input_data = xla::ConcatInDim(builder, {input_data, slice}, 2);
+ }
+ }
+
+ // Setting in_size to (in_size + num_extended) due to the above Slice and
+ // ConcatInDim. Recalculate needed padding after the above Slice/Concat.
+ upper_padding[0] =
+ CalculateUpperPadding(in_size[0] + num_extended[0], out_size[0],
+ dims.kernel_size[0], dims.stride[0]);
+ upper_padding[1] =
+ CalculateUpperPadding(in_size[1] + num_extended[1], out_size[1],
+ dims.kernel_size[1], dims.stride[1]);
+ }
+
+ // Split convolutions into independent dimensions if they would be a very
// large kernel.
if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) {
xla::XlaOp kernel =
MakeBilinearResizeKernel(builder, dims.kernel_size, channels);
- output = xla::ConvGeneralDilated(
- input, kernel, dims.stride,
- /*padding=*/
- {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
- {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
- /*lhs_dilation=*/dims.kernel_size,
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ output =
+ xla::ConvGeneralDilated(input_data, kernel, dims.stride,
+ /*padding=*/
+ {{dims.kernel_size[0] - 1, upper_padding[0]},
+ {dims.kernel_size[1] - 1, upper_padding[1]}},
+ /*lhs_dilation=*/dims.kernel_size,
+ /*rhs_dilation=*/{1, 1}, dimension_numbers);
} else {
xla::XlaOp kernel0 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0);
output = xla::ConvGeneralDilated(
- input, kernel0, {dims.stride[0], 1},
+ input_data, kernel0, {dims.stride[0], 1},
/*padding=*/
- {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}},
+ {{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}},
/*lhs_dilation=*/{dims.kernel_size[0], 1},
/*rhs_dilation=*/{1, 1}, dimension_numbers);
xla::XlaOp kernel1 =
@@ -224,7 +300,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
output = xla::ConvGeneralDilated(
output, kernel1, {1, dims.stride[1]},
/*padding=*/
- {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
+ {{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}},
/*lhs_dilation=*/{1, dims.kernel_size[1]},
/*rhs_dilation=*/{1, 1}, dimension_numbers);
}
@@ -245,9 +321,10 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
const int num_spatial_dims,
std::vector<int64> in_size,
std::vector<int64> grad_size,
- const int64 channels) {
+ const int64 channels,
+ const bool align_corners) {
ResizeConvolutionDims dims =
- ComputeResizeConvolutionParameters(in_size, grad_size);
+ ComputeResizeConvolutionParameters(in_size, grad_size, align_corners);
// To form the backward convolution, we keep the kernel unchanged (it is
// already symmetric) and swap the roles of strides and LHS dilation.
@@ -341,10 +418,6 @@ class ResizeBilinearOp : public XlaOpKernel {
public:
explicit ResizeBilinearOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
- OP_REQUIRES(
- ctx, align_corners_ == true,
- errors::Unimplemented(
- "ResizeBilinear with align_corners=False is not yet implemented"));
}
void Compile(XlaOpKernelContext* ctx) override {
@@ -377,20 +450,19 @@ class ResizeBilinearOp : public XlaOpKernel {
// If in_size[i] > 1 and out_size[i] == 1, slice out the first input in
// dimension i.
- std::vector<int64> slice_size = in_size;
bool slice_input = false;
for (int i = 0; i < num_spatial_dims; ++i) {
if (in_size[i] > 1 && out_size[i] == 1) {
// If in_size[i] > 1 but out_size[i] == 1, then we slice out the first
// entry before resizing.
slice_input = true;
- slice_size[i] = 1;
+ in_size[i] = 1;
}
}
if (slice_input) {
- input = xla::Slice(input, {0, 0, 0, 0},
- {batch, slice_size[0], slice_size[1], channels},
- {1, 1, 1, 1});
+ input =
+ xla::Slice(input, {0, 0, 0, 0},
+ {batch, in_size[0], in_size[1], channels}, {1, 1, 1, 1});
}
// Output is always type float.
@@ -406,6 +478,9 @@ class ResizeBilinearOp : public XlaOpKernel {
// operations along different dimensions.
// Given sufficient numerical stability and a<e<c and b<f<d, bilinear resize
// from image of size axb -> cxd is same as resizing axb -> exf -> cxd.
+ // This does not work in the case of align_corners_=false because of special
+ // padding requirements that cause multiple resizes to be very different
+ // from a single resize.
//
// This makes the convolution kernels smaller and the operation faster.
xla::XlaOp output = input;
@@ -415,21 +490,24 @@ class ResizeBilinearOp : public XlaOpKernel {
(static_cast<float>(out_size[0]) - 1) / ((in_size[0] - 1) * 2),
(static_cast<float>(out_size[1]) - 1) / ((in_size[1] - 1) * 2)};
if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) &&
- k[0] > 1 && k[1] > 1) {
+ k[0] > 1 && k[1] > 1 && align_corners_) {
std::vector<int64> next_out_size = {(in_size[0] - 1) * 2 + 1,
(in_size[1] - 1) * 2 + 1};
- output = ResizeUsingDilationAndConvolution(
- b, input, num_spatial_dims, in_size, next_out_size, channels);
+ output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims,
+ in_size, next_out_size,
+ channels, align_corners_);
input = output;
in_size = next_out_size;
} else {
- output = ResizeUsingDilationAndConvolution(
- b, input, num_spatial_dims, in_size, out_size, channels);
+ output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims,
+ in_size, out_size,
+ channels, align_corners_);
in_size = out_size;
}
} else {
output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims,
- in_size, out_size, channels);
+ in_size, out_size, channels,
+ align_corners_);
in_size = out_size;
}
}
@@ -509,17 +587,20 @@ class ResizeBilinearGradOp : public XlaOpKernel {
std::vector<int64> next_grad_size = {(in_size[0] - 1) * 2 + 1,
(in_size[1] - 1) * 2 + 1};
output = ResizeUsingDilationAndConvolutionGradOp(
- b, grad, num_spatial_dims, in_size, next_grad_size, channels);
+ b, grad, num_spatial_dims, in_size, next_grad_size, channels,
+ align_corners_);
grad = output;
in_size = next_grad_size;
} else {
output = ResizeUsingDilationAndConvolutionGradOp(
- b, grad, num_spatial_dims, in_size, grad_size, channels);
+ b, grad, num_spatial_dims, in_size, grad_size, channels,
+ align_corners_);
in_size = grad_size;
}
} else {
output = ResizeUsingDilationAndConvolutionGradOp(
- b, grad, num_spatial_dims, in_size, grad_size, channels);
+ b, grad, num_spatial_dims, in_size, grad_size, channels,
+ align_corners_);
in_size = grad_size;
}
}
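
The guard above computes k[i] = (out_size[i] - 1) / ((in_size[i] - 1) * 2); when k is integral and greater than 1, the resize can be split as in -> (in - 1) * 2 + 1 -> out, which keeps the convolution kernels small. A minimal standalone sketch of that check (plain C++, not part of this patch) for a 3x3 -> 9x9 resize with align_corners=true:

  #include <cmath>
  #include <cstdio>

  int main() {
    const int in = 3, out = 9;
    // Integral k > 1 lets us resize 3 -> 5 -> 9 instead of 3 -> 9 directly.
    const float k = (static_cast<float>(out) - 1) / ((in - 1) * 2);
    const bool can_split = (k == std::floor(k)) && k > 1;
    std::printf("k = %g, split = %d, next_out = %d\n", k, can_split,
                (in - 1) * 2 + 1);  // prints: k = 2, split = 1, next_out = 5
    return 0;
  }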
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
index 22a45b2a11..3d81ae9eb8 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
@@ -78,14 +78,14 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
std::vector<xla::XlaOp> args;
args.push_back(ctx->Input(0));
args.push_back(xla::ConstantLiteral(
- &b, *xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
+ &b, xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
if (input_shape.dims() > 1) {
// Don't bother passing the output shape and dim for the 1d case, since
// the shape is always a scalar and the dim is always 0.
args.push_back(xla::ConstantLiteral(
- &b, *xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
+ &b, xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
args.push_back(
- xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0<int32>(dim)));
+ xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0<int32>(dim)));
}
xla::Shape xla_shape =
diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
index eedfc3c914..2a42eeaf76 100644
--- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
@@ -29,7 +29,14 @@ class MirrorPadOp : public XlaOpKernel {
xla::StatusOr<xla::XlaOp> DoMirrorPad(const xla::XlaOp& t,
const xla::Shape& original_shape,
const xla::LiteralSlice& pad_literal,
+ const MirrorPadMode mode,
xla::XlaBuilder* b) {
+ // The difference in the semantics of REFLECT and SYMMETRIC is that REFLECT
+ // will not mirror the border values, while SYMMETRIC does.
+ // E.g. if the input is [1, 2, 3] and paddings is [0, 2], then the output is:
+ // - [1, 2, 3, 2, 1] in REFLECT mode;
+ // - [1, 2, 3, 3, 2] in SYMMETRIC mode.
+ int64 excluded_edges = mode == MirrorPadMode::REFLECT ? 1 : 0;
xla::XlaOp accum = t;
for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0;
--dimno) {
@@ -39,9 +46,19 @@ class MirrorPadOp : public XlaOpKernel {
TF_ASSIGN_OR_RETURN(int64 rhs_padding,
pad_literal.GetIntegralAsS64({dimno, 1}));
int64 dim_size = original_shape.dimensions(dimno);
- auto lhs_pad = xla::SliceInDim(t_rev, dim_size - 1 - lhs_padding,
- dim_size - 1, 1, dimno);
- auto rhs_pad = xla::SliceInDim(t_rev, 1, 1 + rhs_padding, 1, dimno);
+
+ // Padding amounts on each side must be non-negative and no more than the
+ // original dimension size, less the edge excluded in REFLECT mode.
+ TF_RET_CHECK(lhs_padding >= 0 &&
+ lhs_padding <= dim_size - excluded_edges);
+ TF_RET_CHECK(rhs_padding >= 0 &&
+ rhs_padding <= dim_size - excluded_edges);
+
+ auto lhs_pad =
+ xla::SliceInDim(t_rev, dim_size - excluded_edges - lhs_padding,
+ dim_size - excluded_edges, 1, dimno);
+ auto rhs_pad = xla::SliceInDim(t_rev, excluded_edges,
+ excluded_edges + rhs_padding, 1, dimno);
accum = xla::ConcatInDim(b, {lhs_pad, accum, rhs_pad}, dimno);
}
return accum;
@@ -53,9 +70,10 @@ class MirrorPadOp : public XlaOpKernel {
MirrorPadMode mode;
OP_REQUIRES_OK(ctx, GetNodeAttr(def(), "mode", &mode));
- OP_REQUIRES(ctx, mode == MirrorPadMode::REFLECT,
- xla::Unimplemented(
- "Only REFLECT MirrorPad mode is currently supported"));
+ OP_REQUIRES(
+ ctx, mode == MirrorPadMode::REFLECT || mode == MirrorPadMode::SYMMETRIC,
+ xla::Unimplemented("Unsupported MirrorPad mode. Only SYMMETRIC and "
+ "REFLECT modes are currently supported"));
const int dims = input_shape.dims();
OP_REQUIRES(
@@ -83,7 +101,7 @@ class MirrorPadOp : public XlaOpKernel {
xla::StatusOr<xla::Shape> in0_shape = b->GetShape(in0);
OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status());
xla::StatusOr<xla::XlaOp> accum_status =
- DoMirrorPad(in0, in0_shape.ValueOrDie(), pad_literal, b);
+ DoMirrorPad(in0, in0_shape.ValueOrDie(), pad_literal, mode, b);
OP_REQUIRES_OK(ctx, accum_status.status());
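
The excluded_edges slicing above is easiest to follow in one dimension. A minimal sketch over std::vector in place of the XLA slice ops (illustrative helper, not part of the patch):

  #include <vector>

  // Pads one dimension the way DoMirrorPad slices t_rev: REFLECT skips the
  // edge element (excluded = 1), SYMMETRIC keeps it (excluded = 0).
  std::vector<int> MirrorPad1D(const std::vector<int>& t, int lhs, int rhs,
                               bool reflect) {
    const int excluded = reflect ? 1 : 0;
    const std::vector<int> rev(t.rbegin(), t.rend());  // t_rev
    std::vector<int> out(rev.end() - excluded - lhs, rev.end() - excluded);
    out.insert(out.end(), t.begin(), t.end());
    out.insert(out.end(), rev.begin() + excluded, rev.begin() + excluded + rhs);
    return out;
  }
  // MirrorPad1D({1, 2, 3}, 0, 2, /*reflect=*/true)  -> {1, 2, 3, 2, 1}
  // MirrorPad1D({1, 2, 3}, 0, 2, /*reflect=*/false) -> {1, 2, 3, 3, 2}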
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
index d4d180aff8..27690c156e 100644
--- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -138,7 +138,7 @@ xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format,
int num_dims = num_spatial_dims + 2;
int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format);
int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format);
- gtl::InlinedVector<int64, 4> spatial_dimensions(num_spatial_dims);
+ absl::InlinedVector<int64, 4> spatial_dimensions(num_spatial_dims);
for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) {
spatial_dimensions[spatial_dim] =
GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim);
@@ -199,59 +199,6 @@ class MaxPool3DOp : public MaxPoolOp {
};
REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp);
-// Divide each element of an image by the count of elements that contributed to
-// that element during pooling.
-static xla::XlaOp AvgPoolDivideByCount(
- XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype,
- const TensorShape& input_shape, xla::Padding padding,
- const std::vector<int64>& ksize, const std::vector<int64>& stride,
- int num_spatial_dims, TensorFormat data_format) {
- if (padding == xla::Padding::kValid) {
- // In VALID padding, all windows have the same number of elements
- // contributing to each average. Divide by the window size everywhere to
- // get the average.
- int64 window_size = std::accumulate(ksize.begin(), ksize.end(), 1,
- [](int64 a, int64 b) { return a * b; });
-
- auto divisor =
- XlaHelpers::IntegerLiteral(ctx->builder(), dtype, window_size);
- return xla::Div(output, divisor);
- } else {
- // For SAME padding, the padding shouldn't be included in the
- // counts. We use another ReduceWindow to find the right counts.
-
- // TODO(phawkins): use a less brute-force way to compute this. Only
- // the boundary regions will have interesting values here.
-
- std::vector<int64> input_dim_sizes(num_spatial_dims);
- std::vector<int64> window_dims(num_spatial_dims);
- std::vector<int64> window_ksize(num_spatial_dims);
- std::vector<int64> window_stride(num_spatial_dims);
- for (int i = 0; i < num_spatial_dims; ++i) {
- int dim = GetTensorSpatialDimIndex(num_spatial_dims + 2, data_format, i);
- input_dim_sizes[i] = input_shape.dim_size(dim);
- window_dims[i] = dim;
- window_ksize[i] = ksize[dim];
- window_stride[i] = stride[dim];
- }
-
- // Build a matrix of all 1s, with the same width/height as the input.
- const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype);
- auto ones = xla::Broadcast(
- XlaHelpers::One(ctx->builder(), accumulation_type), input_dim_sizes);
-
- // Perform a ReduceWindow with the same window size, strides, and padding
- // to count the number of contributions to each result element.
- auto reduce = xla::ReduceWindow(
- ones, XlaHelpers::Zero(ctx->builder(), accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), window_ksize, window_stride,
- xla::Padding::kSame);
- auto counts = XlaHelpers::ConvertElementType(ctx->builder(), reduce, dtype);
-
- return xla::Div(output, counts, window_dims);
- }
-}
-
class AvgPoolOp : public PoolingOp {
public:
AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
@@ -463,78 +410,31 @@ class AvgPoolGradOp : public XlaOpKernel {
errors::InvalidArgument("out_backprop must be ", num_dims(),
"-dimensional"));
- int depth_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
- int64 depth = out_backprop_shape.dim_size(depth_dim);
-
- // We can think of average-pooling as:
- // * a convolution with a kernel consisting entirely of 1s, where the
- // input feature and output feature are equal, and 0s everywhere else.
- // * followed by dividing by the counts.
- //
- // This then gives us an algorithm to build the gradient:
- // * divide out_backprop by the counts, followed by
- // * Conv2DBackpropInput specialized for that kernel, which simplifies to
- // a Pad and a ReduceWindow.
- //
- // For an explanation of backpropagation for convolution, see the comments
- // in third_party/tensorflow/core/kernels/conv_grad_ops.h
-
- // TF filter shape is [ H, W, ..., inC, outC ]
- std::vector<int64> filter_dims(num_dims());
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- filter_dims[i] = ksize_[dim];
- }
- filter_dims[num_dims() - 2] = depth;
- filter_dims[num_dims() - 1] = depth;
- TensorShape filter_shape(filter_dims);
-
- // Reuse the logic from Conv2DBackpropInput to compute padding.
- ConvBackpropDimensions dims;
- OP_REQUIRES_OK(
- ctx, ConvBackpropComputeDimensions(
- type_string(), /*num_spatial_dims=*/num_spatial_dims_,
- gradients_shape, filter_shape, out_backprop_shape, stride_,
- padding_, data_format_, &dims));
-
- // The input gradients are computed by a convolution of the output gradients
- // and the filter, with some appropriate padding. See the comment at the top
- // of conv_grad_ops.h for details.
- xla::XlaBuilder* const b = ctx->builder();
auto out_backprop = ctx->Input(1);
- auto dtype = input_type(1);
+ std::vector<int64> stride_int64s(stride_.begin(), stride_.end());
xla::Padding xla_padding =
(padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
-
- // Divide the out_backprop values by the counts for each spatial position.
- std::vector<int64> stride_int64s(stride_.begin(), stride_.end());
- auto out_backprop_div = AvgPoolDivideByCount(
- ctx, out_backprop, dtype, gradients_shape, xla_padding, ksize_,
- stride_int64s, num_spatial_dims_, data_format_);
-
- // Pad the gradients in the spatial dimensions. We use the same padding
- // as Conv2DBackpropInput.
- xla::PaddingConfig padding_config = xla::MakeNoPaddingConfig(num_dims());
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- auto* padding = padding_config.mutable_dimensions(dim);
- padding->set_edge_padding_low(dims.spatial_dims[i].pad_before);
- padding->set_edge_padding_high(dims.spatial_dims[i].pad_after);
- padding->set_interior_padding(dims.spatial_dims[i].stride - 1);
- }
-
- auto zero = XlaHelpers::Zero(b, dtype);
- auto padded_gradients = xla::Pad(out_backprop_div, zero, padding_config);
-
- // in_backprop = padded_gradients <conv> ones
- std::vector<int64> ones(num_dims(), 1LL);
- auto accumulation_type = XlaHelpers::SumAccumulationType(dtype);
- auto in_backprop = xla::ReduceWindow(
- XlaHelpers::ConvertElementType(b, padded_gradients, accumulation_type),
- XlaHelpers::Zero(b, accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), ksize_,
- /* window_strides=*/ones, xla::Padding::kValid);
- ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, in_backprop, dtype));
+ xla::PrimitiveType xla_reduction_type;
+ auto reduction_type = XlaHelpers::SumAccumulationType(ctx->input_type(1));
+ OP_REQUIRES_OK(
+ ctx, DataTypeToPrimitiveType(reduction_type, &xla_reduction_type));
+ auto converted_out_backprop =
+ xla::ConvertElementType(out_backprop, xla_reduction_type);
+ auto xla_data_format =
+ XlaTensorFormat(data_format_, gradients_shape.dims() - 2);
+ auto padding_values =
+ MakeSpatialPadding(gradients_shape.dim_sizes(), ksize_, stride_int64s,
+ xla_padding, xla_data_format);
+ auto in_backprop =
+ xla::AvgPoolGrad(converted_out_backprop, gradients_shape.dim_sizes(),
+ ksize_, stride_int64s, padding_values, xla_data_format,
+ /*counts_include_padding=*/padding_ == VALID);
+ // Convert the pooling result back to the input type before returning it.
+ xla::PrimitiveType xla_out_backprop_type;
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1),
+ &xla_out_backprop_type));
+ ctx->SetOutput(0,
+ xla::ConvertElementType(in_backprop, xla_out_backprop_type));
}
protected:
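
In the rewrite above, xla::AvgPoolGrad takes over the divide-by-count bookkeeping that the deleted AvgPoolDivideByCount performed by hand; counts_include_padding captures the difference between the two padding modes. A worked 1-D illustration (values for exposition only):

  // Input length 4, ksize 3, stride 1.
  //   VALID: windows [x0 x1 x2], [x1 x2 x3]        -> every count is 3.
  //   SAME : windows [_ x0 x1], [x0 x1 x2],
  //          [x1 x2 x3], [x2 x3 _]                 -> counts are 2, 3, 3, 2,
  // so under SAME each output gradient is divided by its own window count.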
diff --git a/tensorflow/compiler/tf2xla/kernels/qr_op.cc b/tensorflow/compiler/tf2xla/kernels/qr_op.cc
index de9068a640..7ea0afc1f5 100644
--- a/tensorflow/compiler/tf2xla/kernels/qr_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/qr_op.cc
@@ -23,15 +23,10 @@ namespace {
class QROp : public XlaOpKernel {
public:
explicit QROp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
- bool full_matrices;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices));
- OP_REQUIRES(
- ctx, full_matrices,
- errors::Unimplemented("full_matrices=False case of QR decomposition is "
- "not implemented in TF/XLA"));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_));
}
void Compile(XlaOpKernelContext* ctx) override {
- auto result = QRDecomposition(ctx->Input(0));
+ auto result = QRDecomposition(ctx->Input(0), full_matrices_);
if (!result.ok()) {
ctx->SetStatus(result.status());
return;
@@ -39,6 +34,11 @@ class QROp : public XlaOpKernel {
ctx->SetOutput(0, result.ValueOrDie().q);
ctx->SetOutput(1, result.ValueOrDie().r);
}
+
+ private:
+ // If true, compute full-sized q and r. If false, compute only the leading
+ // P = min(num_rows, num_cols) columns of q.
+ bool full_matrices_;
};
REGISTER_XLA_OP(Name("Qr").TypeConstraint("T", kFloatTypes), QROp);
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
index 2da9340625..afd5986846 100644
--- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -155,7 +155,8 @@ class RandomShuffleOp : public XlaOpKernel {
xla::XlaOp indices = xla::Iota(builder, xla::S32, n);
// Swap the indices at i and swaps[i].
- auto swap_body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
+ auto swap_body_fn = [&](xla::XlaOp i,
+ absl::Span<const xla::XlaOp> loop_vars,
xla::XlaBuilder* builder)
-> xla::StatusOr<std::vector<xla::XlaOp>> {
auto swaps = loop_vars[0];
diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
index b11a4ce36d..8102faad28 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
@@ -32,41 +32,30 @@ class ReduceWindowOp : public XlaOpKernel {
explicit ReduceWindowOp(OpKernelConstruction* context)
: XlaOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("computation", &computation_));
- OP_REQUIRES_OK(context,
- context->GetAttr("window_dimensions", &window_dimensions_));
- OP_REQUIRES_OK(context,
- context->GetAttr("window_strides", &window_strides_));
- OP_REQUIRES_OK(context, context->GetAttr("padding_low", &padding_low_));
- OP_REQUIRES_OK(context, context->GetAttr("padding_high", &padding_high_));
}
void Compile(XlaOpKernelContext* context) override {
const TensorShape input_shape = context->InputShape(0);
const DataType dtype = context->input_type(0);
+ std::vector<int64> window_dimensions;
+ std::vector<int64> window_strides;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector(
+ "window_dimensions", &window_dimensions));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides",
+ &window_strides));
+
const int rank = input_shape.dims();
- OP_REQUIRES(context, rank == window_dimensions_.size(),
+ OP_REQUIRES(context, rank == window_dimensions.size(),
errors::InvalidArgument(
"The size of window_dimensions must be equal to the input "
"rank (",
- window_dimensions_.size(), " vs. ", rank, ")"));
- OP_REQUIRES(context, rank == window_strides_.size(),
+ window_dimensions.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == window_strides.size(),
errors::InvalidArgument(
"The size of window_strides must be equal to the input "
"rank (",
- window_strides_.size(), " vs. ", rank, ")"));
- OP_REQUIRES(context, rank == padding_low_.size(),
- errors::InvalidArgument(
- "The size of padding_low must be equal to the input "
- "rank (",
- padding_low_.size(), " vs. ", rank, ")"));
- OP_REQUIRES(context, rank == padding_high_.size(),
- errors::InvalidArgument(
- "The size of padding_high must be equal to the input "
- "rank (",
- padding_high_.size(), " vs. ", rank, ")"));
-
- xla::XlaBuilder* builder = context->builder();
+ window_strides.size(), " vs. ", rank, ")"));
// Build the reducer function.
XlaCompiler::Argument reducer_arg;
@@ -78,6 +67,7 @@ class ReduceWindowOp : public XlaOpKernel {
compile_options.use_tuple_arg = false;
compile_options.resolve_compile_time_constants = false;
compile_options.is_entry_computation = false;
+ compile_options.always_return_tuple = false;
XlaCompiler::CompilationResult reducer;
OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
compile_options, *computation_,
@@ -86,51 +76,47 @@ class ReduceWindowOp : public XlaOpKernel {
xla::Shape scalar_shape;
OP_REQUIRES_OK(context,
TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape));
+ OP_REQUIRES(
+ context,
+ xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape),
+ errors::InvalidArgument(
+ "Invalid output shape of ReduceWindow reducer. Expected ",
+ xla::ShapeUtil::HumanString(scalar_shape), " got ",
+ xla::ShapeUtil::HumanString(reducer.xla_output_shape)));
+
+ const TensorShape padding_shape = context->InputShape("padding");
OP_REQUIRES(context,
- xla::ShapeUtil::Compatible(
- reducer.xla_output_shape,
- xla::ShapeUtil::MakeTupleShape({scalar_shape})),
+ TensorShapeUtils::IsMatrix(padding_shape) &&
+ padding_shape.dim_size(1) == 2,
errors::InvalidArgument(
- "Invalid output shape of ReduceWindow reducer. Expected ",
- xla::ShapeUtil::HumanString(scalar_shape), " got ",
- xla::ShapeUtil::HumanString(reducer.xla_output_shape)));
-
- // Wraps the reducer in a computation that unpacks the output tuple.
- xla::XlaComputation wrapper;
- {
- std::unique_ptr<xla::XlaBuilder> cb =
- builder->CreateSubBuilder("wrapper");
- auto x = xla::Parameter(cb.get(), 0, scalar_shape, "x");
- auto y = xla::Parameter(cb.get(), 1, scalar_shape, "y");
- auto outputs = xla::Call(cb.get(), *reducer.computation, {x, y});
- xla::GetTupleElement(outputs, 0);
- xla::StatusOr<xla::XlaComputation> result = cb->Build();
- OP_REQUIRES_OK(context, result.status());
- wrapper = std::move(result.ValueOrDie());
- }
-
- std::vector<std::pair<int64, int64>> padding(rank);
- for (int i = 0; i < rank; ++i) {
- padding[i] = {padding_low_[i], padding_high_[i]};
+ "padding must be a matrix with minor dimension 2, got ",
+ padding_shape.DebugString()));
+ xla::Literal padding_literal;
+ OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal(
+ "padding", &padding_literal));
+ std::vector<std::pair<int64, int64>> padding(padding_shape.dim_size(0));
+ for (int i = 0; i < padding.size(); ++i) {
+ padding[i] = {padding_literal.Get<int64>({i, 0}),
+ padding_literal.Get<int64>({i, 1})};
}
xla::XlaOp output = xla::ReduceWindowWithGeneralPadding(
- context->Input(0), context->Input(1), wrapper, window_dimensions_,
- window_strides_, padding);
+ context->Input(0), context->Input(1), *reducer.computation,
+ window_dimensions, window_strides, padding);
context->SetOutput(0, output);
}
private:
const NameAttrList* computation_;
- std::vector<int64> window_dimensions_;
- std::vector<int64> window_strides_;
- std::vector<int64> padding_low_;
- std::vector<int64> padding_high_;
TF_DISALLOW_COPY_AND_ASSIGN(ReduceWindowOp);
};
-REGISTER_XLA_OP(Name("XlaReduceWindow"), ReduceWindowOp);
+REGISTER_XLA_OP(Name("XlaReduceWindow")
+ .CompileTimeConstInput("window_dimensions")
+ .CompileTimeConstInput("window_strides")
+ .CompileTimeConstInput("padding"),
+ ReduceWindowOp);
} // namespace
} // namespace tensorflow
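
The padding input is now a [rank, 2] compile-time constant whose rows become (low, high) pairs. A minimal sketch of that mapping, assuming a flat row-major int64 array in place of the xla::Literal:

  #include <cstddef>
  #include <cstdint>
  #include <utility>
  #include <vector>

  // Expands a row-major [N, 2] padding matrix into (low, high) pairs.
  std::vector<std::pair<int64_t, int64_t>> PaddingPairs(
      const std::vector<int64_t>& flat) {
    std::vector<std::pair<int64_t, int64_t>> padding(flat.size() / 2);
    for (std::size_t i = 0; i < padding.size(); ++i) {
      padding[i] = {flat[2 * i], flat[2 * i + 1]};  // {low, high}
    }
    return padding;
  }
  // PaddingPairs({1, 1, 0, 2}) -> {(1, 1), (0, 2)} for a rank-2 window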
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
index b52f0a0ab6..118f2798d5 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
@@ -15,6 +15,7 @@ limitations under the License.
// XLA-specific reduction Ops.
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
@@ -29,9 +30,6 @@ namespace tensorflow {
XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx,
DataType reduction_type)
: XlaOpKernel(ctx), reduction_type_(reduction_type) {
- const DataType dt = BaseType(input_type(0));
- OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt}));
-
OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
OP_REQUIRES_OK(
ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
@@ -58,20 +56,24 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
return;
}
+ OP_REQUIRES(ctx, axes_tensor_shape.dims() <= 1,
+ errors::InvalidArgument(
+ "Expected scalar or vector as index argument, got ",
+ axes_tensor_shape.DebugString()));
+
// Evaluate the constant, reshaping to a 1-vector if it is a scalar.
+ std::vector<int64> axes;
xla::Literal axes_literal;
- OP_REQUIRES_OK(
- ctx, ctx->ConstantInputReshaped(1, {axes_tensor_shape.num_elements()},
- &axes_literal));
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector(1, &axes));
VLOG(1) << "data shape: " << data_shape.DebugString();
- VLOG(1) << "axes : " << axes_literal.ToString();
+ VLOG(1) << "axes : " << absl::StrJoin(axes, ",");
- gtl::InlinedVector<bool, 4> bitmap(data_shape.dims(), false);
+ absl::InlinedVector<bool, 4> bitmap(data_shape.dims(), false);
std::vector<int64> xla_axes;
int64 num_elements_reduced = 1LL;
for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) {
- int32 index = axes_literal.Get<int>({i});
+ int64 index = axes[i];
OP_REQUIRES(ctx,
!(index < -data_shape.dims() || index >= data_shape.dims()),
errors::InvalidArgument("Invalid reduction dimension (", index,
@@ -101,7 +103,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
xla::XlaBuilder* const b = ctx->builder();
// Construct the builder for the reduction lambda.
- xla::XlaBuilder r(strings::StrCat(desc, "-reduction"));
+ xla::XlaBuilder r(absl::StrCat(desc, "-reduction"));
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type));
diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
index 121750a82a..366ce42866 100644
--- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
@@ -41,8 +41,8 @@ class ReshapeOp : public XlaOpKernel {
sizes_shape.DebugString()));
const int64 num_dims = sizes_shape.num_elements();
- xla::Literal literal;
- OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal));
+ std::vector<int64> shape_input;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input));
// Compute the output shape. Determine product of specified
// dimensions, and find the index of the unspecified one if there
@@ -51,7 +51,7 @@ class ReshapeOp : public XlaOpKernel {
int64 product = 1;
int unknown_index = -1;
for (int d = 0; d < num_dims; ++d) {
- const int32 size = literal.Get<int>({d});
+ const int32 size = shape_input[d];
if (size == -1) {
OP_REQUIRES(
ctx, unknown_index == -1,
diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
index 64900e4709..e172c64932 100644
--- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
@@ -48,6 +48,15 @@ class RetvalOp : public XlaOpKernel {
} else {
xla::XlaOp input = ctx->Input(0);
const TensorShape input_shape = ctx->InputShape(0);
+ DataType input_type = ctx->input_type(0);
+ XlaContext& tc = XlaContext::Get(ctx);
+
+ if (input_type == DT_RESOURCE) {
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
+ ctx->SetStatus(tc.AddResourceRetval(index_, resource));
+ return;
+ }
auto is_constant = ctx->builder()->IsConstant(input);
if (!is_constant.ok()) {
@@ -55,7 +64,6 @@ class RetvalOp : public XlaOpKernel {
return;
}
- XlaContext& tc = XlaContext::Get(ctx);
if (tc.resolve_compile_time_constants() &&
(input_shape.num_elements() == 0 || is_constant.ValueOrDie())) {
xla::Literal literal;
@@ -104,7 +112,8 @@ class RetvalOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp);
};
-REGISTER_XLA_OP(Name("_Retval").CompilationOnly(), RetvalOp);
+REGISTER_XLA_OP(Name("_Retval").AllowResourceTypes().CompilationOnly(),
+ RetvalOp);
} // anonymous namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
index d962ef4a5f..8494864b33 100644
--- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
@@ -95,10 +95,24 @@ class ReverseV2Op : public XlaOpKernel {
std::vector<int64> axes;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &axes));
+ // witnessed_axes is used to ensure that the same axis is not marked to be
+ // reversed multiple times.
+ absl::InlinedVector<bool, 8> witnessed_axes(x_shape.dims(), false);
+
for (int d = 0; d < axes.size(); ++d) {
- OP_REQUIRES(ctx, (0 <= axes[d]) && (axes[d] < x_shape.dims()),
- errors::InvalidArgument(axes[d], " is out of range [0, ",
- x_shape.dims(), ")."));
+ OP_REQUIRES(
+ ctx, (-x_shape.dims() <= axes[d]) && (axes[d] < x_shape.dims()),
+ errors::InvalidArgument(axes[d], " is out of range [-",
+ x_shape.dims(), ", ", x_shape.dims(), ")."));
+ // Axes can be negative and are shifted to the canonical index before
+ // being lowered to HLO.
+ if (axes[d] < 0) {
+ axes[d] += x_shape.dims();
+ }
+ OP_REQUIRES(ctx, !witnessed_axes[axes[d]],
+ errors::InvalidArgument("canonicalized axis ", axes[d],
+ " was repeated."));
+ witnessed_axes[axes[d]] = true;
}
ctx->SetOutput(0, xla::Rev(ctx->Input(0), axes));
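
A standalone sketch of the axis canonicalization above (plain C++, mirroring the range check, the negative-axis shift, and the witnessed_axes duplicate check):

  #include <cstdint>
  #include <vector>

  // Returns false where the op would raise InvalidArgument.
  bool CanonicalizeAxes(int dims, std::vector<int64_t>& axes) {
    std::vector<bool> witnessed(dims, false);
    for (int64_t& a : axes) {
      if (a < -dims || a >= dims) return false;  // out of range
      if (a < 0) a += dims;                      // e.g. dims = 4: -1 -> 3
      if (witnessed[a]) return false;            // repeated axis
      witnessed[a] = true;
    }
    return true;
  }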
diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc
index 6ce50efb4a..9e4c57c9bf 100644
--- a/tensorflow/compiler/tf2xla/kernels/select_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc
@@ -66,8 +66,8 @@ class SelectOp : public XlaOpKernel {
// XLA. It seems we have to broadcast on the left and then Reshape
// to get the dimensions in the right order.
const auto dim_sizes = then_shape.dim_sizes();
- gtl::ArraySlice<int64> bdims = dim_sizes;
- bdims.pop_front();
+ absl::Span<const int64> bdims = dim_sizes;
+ bdims.remove_prefix(1);
cond_handle = xla::Broadcast(cond_handle, bdims);
std::vector<int64> dim_order(then_shape.dims());
diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
index 4e0cf99d8e..2e0a69b70e 100644
--- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
@@ -115,7 +115,7 @@ class ExpandDimsOp : public XlaOpKernel {
// accept legacy scalars, even when they should be forbidden by the graphdef
// version.
OP_REQUIRES(ctx, dim_shape.num_elements() == 1,
- errors::InvalidArgument(strings::StrCat(
+ errors::InvalidArgument(absl::StrCat(
"dim input to ExpandDims must be a scalar; got ",
dim_shape.DebugString())));
diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
index 6adc3c58de..537b71f3c0 100644
--- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
@@ -15,6 +15,7 @@ limitations under the License.
// XLA-specific Slice Op.
+#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mem.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
index 025ba82741..d6bd927135 100644
--- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
@@ -15,6 +15,7 @@ limitations under the License.
// XLA-specific Ops for softmax.
+#include "absl/strings/match.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
namespace {
@@ -33,7 +33,7 @@ namespace {
class SoftmaxOp : public XlaOpKernel {
public:
explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
- log_ = str_util::StartsWith(type_string(), "Log");
+ log_ = absl::StartsWith(type_string(), "Log");
}
void Compile(XlaOpKernelContext* ctx) override {
diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
index 7327258c31..76b79be6f6 100644
--- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
@@ -23,10 +23,10 @@ namespace {
void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
DataType input_dtype, const TensorShape& input_tensor_shape,
- gtl::ArraySlice<int64> block_shape,
+ absl::Span<const int64> block_shape,
const xla::Literal& paddings) {
const int input_rank = input_tensor_shape.dims();
- const gtl::InlinedVector<int64, 4> input_shape =
+ const absl::InlinedVector<int64, 4> input_shape =
input_tensor_shape.dim_sizes();
const int block_rank = block_shape.size();
@@ -34,7 +34,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
ctx, input_rank >= 1 + block_rank,
errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
" instead of ", input_rank));
- gtl::ArraySlice<int64> remainder_shape(input_shape);
+ absl::Span<const int64> remainder_shape(input_shape);
remainder_shape.remove_prefix(1 + block_rank);
OP_REQUIRES(
diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
index 4493539fe3..3293c13b21 100644
--- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
@@ -48,7 +48,7 @@ class SpaceToDepthOp : public XlaOpKernel {
OP_REQUIRES(ctx, kRequiredDims == input_rank,
errors::InvalidArgument("Input rank should be ", kRequiredDims,
"; got ", input_rank));
- const gtl::InlinedVector<int64, 4> input_shape =
+ const absl::InlinedVector<int64, 4> input_shape =
input_tensor_shape.dim_sizes();
xla::XlaOp input = ctx->Input(0);
diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
index df91900570..ee70f508a9 100644
--- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
@@ -111,7 +111,7 @@ class StackOp : public XlaOpKernel {
xla::XlaOp value;
XlaContext& xc = XlaContext::Get(ctx);
XlaResource* resource;
- string name = strings::StrCat("Stack: ", stack_name_);
+ string name = absl::StrCat("Stack: ", stack_name_);
OP_REQUIRES_OK(
ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_,
TensorShape(), value, /*tensor_array_size=*/size,
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index 1062399d91..2b2e3de64f 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/strided_slice_op.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
@@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mem.h"
namespace tensorflow {
@@ -46,9 +46,9 @@ class StridedSliceOp : public XlaOpKernel {
const TensorShape input_shape = ctx->InputShape(0);
TensorShape final_shape;
- gtl::InlinedVector<int64, 4> begin;
- gtl::InlinedVector<int64, 4> end;
- gtl::InlinedVector<int64, 4> strides;
+ absl::InlinedVector<int64, 4> begin;
+ absl::InlinedVector<int64, 4> end;
+ absl::InlinedVector<int64, 4> strides;
xla::Literal begin_literal, end_literal, strides_literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
@@ -72,8 +72,8 @@ class StridedSliceOp : public XlaOpKernel {
shrink_axis_mask_, &dummy_processing_shape, &final_shape,
&dummy, &dummy, &dummy, &begin, &end, &strides));
- gtl::InlinedVector<int64, 4> dimensions_to_reverse;
- gtl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;
+ absl::InlinedVector<int64, 4> dimensions_to_reverse;
+ absl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;
for (int i = 0; i < begin.size(); ++i) {
if (strides[i] > 0) {
@@ -127,9 +127,9 @@ class StridedSliceGradOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
TensorShape processing_shape, final_shape;
- gtl::InlinedVector<int64, 4> begin;
- gtl::InlinedVector<int64, 4> end;
- gtl::InlinedVector<int64, 4> strides;
+ absl::InlinedVector<int64, 4> begin;
+ absl::InlinedVector<int64, 4> end;
+ absl::InlinedVector<int64, 4> strides;
TensorShape input_shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape));
@@ -175,7 +175,7 @@ class StridedSliceGradOp : public XlaOpKernel {
grad = xla::Reshape(grad, processing_shape.dim_sizes());
// Pad the input gradients.
- gtl::InlinedVector<int64, 4> dimensions_to_reverse;
+ absl::InlinedVector<int64, 4> dimensions_to_reverse;
xla::PaddingConfig padding_config;
for (int i = 0; i < processing_shape.dims(); ++i) {
@@ -238,9 +238,9 @@ class StridedSliceAssignOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
TensorShape final_shape;
- gtl::InlinedVector<int64, 4> begin;
- gtl::InlinedVector<int64, 4> end;
- gtl::InlinedVector<int64, 4> strides;
+ absl::InlinedVector<int64, 4> begin;
+ absl::InlinedVector<int64, 4> end;
+ absl::InlinedVector<int64, 4> strides;
xla::Literal begin_literal, end_literal, strides_literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
@@ -287,8 +287,8 @@ class StridedSliceAssignOp : public XlaOpKernel {
xla::XlaOp rhs = ctx->Input(4);
- gtl::InlinedVector<int64, 4> dimensions_to_reverse;
- gtl::InlinedVector<int64, 4> slice_begin, slice_dims;
+ absl::InlinedVector<int64, 4> dimensions_to_reverse;
+ absl::InlinedVector<int64, 4> slice_begin, slice_dims;
for (int i = 0; i < begin.size(); ++i) {
// TODO(phawkins): implement strides != 1
OP_REQUIRES(
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index be1814d8e3..94108b764f 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -122,7 +122,7 @@ Status GetTensorArrayShape(const XlaResource* resource,
// relevant slice of 'operand'.
xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand,
const xla::XlaOp& update,
- const gtl::ArraySlice<int64>& update_dims,
+ absl::Span<const int64> update_dims,
const xla::XlaOp& start_indices) {
xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims);
xla::XlaOp sum = xla::Add(current, update);
@@ -167,7 +167,7 @@ class TensorArrayOp : public XlaOpKernel {
XlaContext& xc = XlaContext::Get(ctx);
XlaResource* var;
- string name = strings::StrCat("TensorArray: ", tensor_array_name_);
+ string name = absl::StrCat("TensorArray: ", tensor_array_name_);
OP_REQUIRES_OK(
ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
dtype_, shape, value, /*tensor_array_size=*/size,
diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
index 1233a37565..93d5996b5e 100644
--- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
@@ -16,6 +16,7 @@ limitations under the License.
// XLA-specific Tile Op.
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -26,7 +27,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
@@ -70,7 +70,7 @@ class TileOp : public XlaOpKernel {
bool one_dimension_is_broadcasted_without_multiple = true;
for (int i = 0; i < input_dims; ++i) {
int multiple = literal.Get<int>({i});
- OP_REQUIRES(ctx, multiple,
+ OP_REQUIRES(ctx, multiple >= 0,
errors::InvalidArgument("Expected multiples[", i,
"] >= 0, but got ", multiple));
int64 new_dim = input_shape.dim_size(i) * multiple;
diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index be5e911386..7077c2e3a5 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -688,7 +688,7 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
}
// grad_to_use = grad + 2 * l2_shrinkage * var
- // new_accum = accum + grad_to_use * grad_to_use
+ // new_accum = accum + grad * grad
// linear += grad_to_use -
// (new_accum^(-lr_power) - accum^(-lr_power)) / lr * var
// quadratic = (new_accum^(-lr_power) / lr) + 2 * l2
@@ -704,7 +704,7 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
grad_to_use = grad;
}
- xla::XlaOp new_accum = accum + xla::Square(grad_to_use);
+ xla::XlaOp new_accum = accum + xla::Square(grad);
xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, -lr_power);
xla::XlaOp accum_lr_pow = xla::Pow(accum, -lr_power);
linear = linear + grad_to_use - (new_accum_lr_pow - accum_lr_pow) / lr * var;
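
Restating the update from the comments above with the fixed accumulator (l2_shrinkage enters only through the linear term, never the second-moment accumulator):

  // grad_to_use = grad + 2 * l2_shrinkage * var
  // new_accum   = accum + grad * grad        // not grad_to_use * grad_to_use
  // linear     += grad_to_use
  //               - (new_accum^(-lr_power) - accum^(-lr_power)) / lr * var
  // quadratic   = new_accum^(-lr_power) / lr + 2 * l2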
diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
index f9148b3942..6b303b31d4 100644
--- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
@@ -61,7 +61,7 @@ class TransposeOp : public XlaOpKernel {
std::vector<int64> transposed_order;
// Check whether permutation is a permutation of integers of [0 .. dims).
- gtl::InlinedVector<bool, 8> bits(dims);
+ absl::InlinedVector<bool, 8> bits(dims);
bool is_identity = true;
for (int i = 0; i < dims; ++i) {
const int32 d = perm[i];
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index 296518229e..559414eeaa 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/while_op.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
@@ -90,6 +91,11 @@ XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
cond_name_attr_ = *name_attr;
OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &name_attr));
body_name_attr_ = *name_attr;
+ if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
+ has_token_input_output_ = false;
+ } else {
+ has_token_input_output_ = !token_input_nodes_.empty();
+ }
}
void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
@@ -120,6 +126,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
body_options.return_updated_values_for_all_resources = true;
body_options.resolve_compile_time_constants = false;
body_options.is_entry_computation = false;
+ body_options.add_token_input_output = has_token_input_output_;
XlaCompiler::CompilationResult body;
OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
arguments, &body));
@@ -192,6 +199,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
cond_options.use_tuple_arg = true;
cond_options.resolve_compile_time_constants = false;
cond_options.is_entry_computation = false;
+ cond_options.add_token_input_output = has_token_input_output_;
XlaCompiler::CompilationResult cond;
OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_,
arguments, &cond));
@@ -238,7 +246,16 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
std::vector<xla::XlaOp> inputs(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
int input_num = body.input_mapping[i];
- if (ctx->input_type(input_num) == DT_RESOURCE) {
+ if (has_token_input_output_ && i == num_inputs - 1) {
+ // Set token input for this "while" op.
+ std::vector<xla::XlaOp> token_inputs;
+ for (const string& node_name : token_input_nodes_) {
+ auto token_or = compiler->GetNodeToken(node_name);
+ OP_REQUIRES_OK(ctx, token_or.status());
+ token_inputs.push_back(token_or.ValueOrDie());
+ }
+ inputs[i] = xla::AfterAll(builder, token_inputs);
+ } else if (ctx->input_type(input_num) == DT_RESOURCE) {
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], builder));
@@ -273,6 +290,18 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
xla::GetTupleElement(while_result, i));
}
}
+ if (has_token_input_output_) {
+ // Set token output for this "while" op.
+ xla::XlaOp token_output =
+ xla::GetTupleElement(while_result, ctx->num_outputs());
+ auto shape_or = builder->GetShape(token_output);
+ OP_REQUIRES_OK(ctx, shape_or.status());
+ OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()),
+ errors::FailedPrecondition(
+ "Token output is not token type: ",
+ xla::ShapeUtil::HumanString(shape_or.ValueOrDie())));
+ OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output));
+ }
// Updates the values of any resource variables modified by the loop.
for (int i = 0; i < body.resource_updates.size(); ++i) {
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.h b/tensorflow/compiler/tf2xla/kernels/while_op.h
index 67edebabf9..aeeff40e68 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.h
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.h
@@ -56,6 +56,8 @@ class XlaWhileOp : public XlaOpKernel {
private:
NameAttrList cond_name_attr_;
NameAttrList body_name_attr_;
+ bool has_token_input_output_;
+ std::vector<string> token_input_nodes_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaWhileOp);
};
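
When the loop carries side effects, the tokens of the upstream nodes are merged into one value, threaded through the loop tuple as its last element, and re-registered under this node's name on the way out. A minimal sketch of the merge step, assuming the xla::CreateToken helper and two upstream tokens (hypothetical names):

  // Tokens from earlier side-effecting ops (illustrative only).
  xla::XlaOp t0 = xla::CreateToken(builder);
  xla::XlaOp t1 = xla::CreateToken(builder);
  // AfterAll joins any number of tokens into a single ordering dependency,
  // which becomes the final element of the while loop's input tuple.
  xla::XlaOp merged = xla::AfterAll(builder, {t0, t1});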
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc
new file mode 100644
index 0000000000..412afeaaad
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc
@@ -0,0 +1,115 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaBroadcastHelperOp : public XlaOpKernel {
+ public:
+ explicit XlaBroadcastHelperOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ xla::XlaOp lhs = context->Input(0);
+ xla::XlaOp rhs = context->Input(1);
+ const TensorShape lhs_shape = context->InputShape(0);
+ const TensorShape rhs_shape = context->InputShape(1);
+
+ const bool broadcast_lhs = lhs_shape.dims() < rhs_shape.dims();
+ const TensorShape* min_rank_shape = broadcast_lhs ? &lhs_shape : &rhs_shape;
+ const TensorShape* max_rank_shape = broadcast_lhs ? &rhs_shape : &lhs_shape;
+
+ std::vector<int64> broadcast_dims;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("broadcast_dims",
+ &broadcast_dims));
+ if (broadcast_dims.empty()) {
+ OP_REQUIRES(
+ context,
+ lhs_shape.dims() == rhs_shape.dims() || lhs_shape.dims() == 0 ||
+ rhs_shape.dims() == 0,
+ errors::InvalidArgument(
+ "If broadcast_dims is empty, both "
+ "arguments must have equal rank; "
+ "argument shapes, or at least one argument must be a scalar: ",
+ lhs_shape.DebugString(), " and ", rhs_shape.DebugString()));
+ context->SetOutput(0, lhs);
+ context->SetOutput(1, rhs);
+ return;
+ }
+
+ OP_REQUIRES(
+ context, broadcast_dims.size() == min_rank_shape->dims(),
+ errors::InvalidArgument(
+ "broadcast_dims must have size equal to the smaller argument rank; "
+ "broadcast_dims: [",
+ absl::StrJoin(broadcast_dims, ","), "]; argument shapes: ",
+ lhs_shape.DebugString(), " and ", rhs_shape.DebugString()));
+ std::vector<int64> sorted_broadcast_dims = broadcast_dims;
+ absl::c_sort(sorted_broadcast_dims);
+ std::set<int64> dims_set(broadcast_dims.begin(), broadcast_dims.end());
+ OP_REQUIRES(context,
+ dims_set.size() == broadcast_dims.size() &&
+ broadcast_dims == sorted_broadcast_dims,
+ errors::InvalidArgument(
+ "Duplicate or nonmonotonic dimension in broadcast_dims; "
+ "broadcast_dims: [",
+ absl::StrJoin(broadcast_dims, ","), "]"));
+
+ std::vector<int64> broadcast_shape(max_rank_shape->dims(), 1LL);
+ for (int i = 0; i < broadcast_dims.size(); ++i) {
+ const int dim = broadcast_dims[i];
+ OP_REQUIRES(
+ context, dim >= 0 && dim < broadcast_shape.size(),
+ errors::InvalidArgument(
+ "Invalid broadcast dimension (", dim, "); broadcast_dims: [",
+ absl::StrJoin(broadcast_dims, ","), "]; argument shapes: ",
+ lhs_shape.DebugString(), " and ", rhs_shape.DebugString()));
+ broadcast_shape[dim] = min_rank_shape->dim_size(i);
+ }
+ xla::PrimitiveType type = context->input_xla_type(0);
+ xla::Shape broadcast_xla_shape =
+ xla::ShapeUtil::MakeShape(type, broadcast_shape);
+ if (broadcast_lhs) {
+ lhs = xla::BroadcastInDim(lhs, broadcast_xla_shape, broadcast_dims);
+ } else {
+ rhs = xla::BroadcastInDim(rhs, broadcast_xla_shape, broadcast_dims);
+ }
+ context->SetOutput(0, lhs);
+ context->SetOutput(1, rhs);
+ }
+
+ private:
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaBroadcastHelperOp);
+};
+
+REGISTER_XLA_OP(
+ Name("XlaBroadcastHelper").CompileTimeConstInput("broadcast_dims"),
+ XlaBroadcastHelperOp);
+
+} // namespace
+} // namespace tensorflow
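
A worked example of the shape computation above (illustrative shapes only): with lhs of shape [2, 3], rhs of shape [2, 5, 3], and broadcast_dims = {0, 2}, the lower-rank lhs is the side that gets broadcast:

  // broadcast_shape starts as [1, 1, 1] and picks up the lhs dims:
  //   broadcast_shape[0] = 2, broadcast_shape[2] = 3  ->  [2, 1, 3]
  // lhs = BroadcastInDim(lhs, f32[2,1,3], {0, 2}); the size-1 middle
  // dimension then broadcasts against rhs [2, 5, 3] in the consuming op.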
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc
new file mode 100644
index 0000000000..fecc7c556e
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc
@@ -0,0 +1,101 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaConvOp : public XlaOpKernel {
+ public:
+ explicit XlaConvOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+ string dnums_attr;
+ OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
+ OP_REQUIRES(
+ context, dnums_.ParsePartialFromString(dnums_attr),
+ errors::InvalidArgument("Error parsing convolution dimension numbers"));
+ string precision_config_attr;
+ OP_REQUIRES_OK(
+ context, context->GetAttr("precision_config", &precision_config_attr));
+ OP_REQUIRES(
+ context,
+ precision_config_.ParsePartialFromString(precision_config_attr),
+ errors::InvalidArgument("Error parsing convolution dimension numbers"));
+ }
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape lhs_shape = context->InputShape(0);
+ const TensorShape rhs_shape = context->InputShape(1);
+ const TensorShape padding_shape = context->InputShape("padding");
+ std::vector<int64> window_strides;
+ std::vector<int64> lhs_dilation;
+ std::vector<int64> rhs_dilation;
+ int64 feature_group_count;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides",
+ &window_strides));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("lhs_dilation",
+ &lhs_dilation));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("rhs_dilation",
+ &rhs_dilation));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(
+ "feature_group_count", &feature_group_count));
+
+ OP_REQUIRES(context,
+ TensorShapeUtils::IsMatrix(padding_shape) &&
+ padding_shape.dim_size(1) == 2,
+ errors::InvalidArgument(
+ "padding must be a matrix with minor dimension 2, got ",
+ padding_shape.DebugString()));
+ xla::Literal padding_literal;
+ OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal(
+ "padding", &padding_literal));
+ std::vector<std::pair<int64, int64>> padding(padding_shape.dim_size(0));
+ for (int i = 0; i < padding.size(); ++i) {
+ padding[i] = {padding_literal.Get<int64>({i, 0}),
+ padding_literal.Get<int64>({i, 1})};
+ }
+
+ // We do only minimal checking, relying on XLA to check the shape
+ // invariants.
+ xla::XlaOp output = xla::ConvGeneralDilated(
+ context->Input(0), context->Input(1), window_strides, padding,
+ lhs_dilation, rhs_dilation, dnums_, feature_group_count,
+ &precision_config_);
+ context->SetOutput(0, output);
+ }
+
+ private:
+ xla::ConvolutionDimensionNumbers dnums_;
+ xla::PrecisionConfig precision_config_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaConvOp);
+};
+
+REGISTER_XLA_OP(Name("XlaConv")
+ .CompileTimeConstInput("window_strides")
+ .CompileTimeConstInput("lhs_dilation")
+ .CompileTimeConstInput("rhs_dilation")
+ .CompileTimeConstInput("feature_group_count")
+ .CompileTimeConstInput("padding"),
+ XlaConvOp);
+
+} // namespace
+} // namespace tensorflow
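
The dimension_numbers and precision_config attrs carry serialized protos, which is what ParsePartialFromString expects. A sketch of producing the dimension-numbers attr on the client side for an NHWC input / HWIO kernel layout (a minimal sketch; field names per xla_data.proto):

  #include <string>
  #include "tensorflow/compiler/xla/xla_data.pb.h"

  std::string MakeNhwcConvDnumsAttr() {
    xla::ConvolutionDimensionNumbers dnums;
    dnums.set_input_batch_dimension(0);      // N
    dnums.add_input_spatial_dimensions(1);   // H
    dnums.add_input_spatial_dimensions(2);   // W
    dnums.set_input_feature_dimension(3);    // C
    dnums.add_kernel_spatial_dimensions(0);  // H
    dnums.add_kernel_spatial_dimensions(1);  // W
    dnums.set_kernel_input_feature_dimension(2);   // I
    dnums.set_kernel_output_feature_dimension(3);  // O
    dnums.set_output_batch_dimension(0);
    dnums.add_output_spatial_dimensions(1);
    dnums.add_output_spatial_dimensions(2);
    dnums.set_output_feature_dimension(3);
    return dnums.SerializeAsString();
  }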
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc
new file mode 100644
index 0000000000..40b15b5579
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaDotOp : public XlaOpKernel {
+ public:
+ explicit XlaDotOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+ string dnums_attr;
+ OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
+ OP_REQUIRES(
+ context, dnums_.ParsePartialFromString(dnums_attr),
+ errors::InvalidArgument("Error parsing convolution dimension numbers"));
+ string precision_config_attr;
+ OP_REQUIRES_OK(
+ context, context->GetAttr("precision_config", &precision_config_attr));
+ OP_REQUIRES(
+ context,
+ precision_config_.ParsePartialFromString(precision_config_attr),
+ errors::InvalidArgument("Error parsing convolution dimension numbers"));
+ }
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape lhs_shape = context->InputShape(0);
+ const TensorShape rhs_shape = context->InputShape(1);
+
+ // We do only minimal checking, relying on XLA to check the shape
+ // invariants.
+ xla::XlaOp output = xla::DotGeneral(context->Input(0), context->Input(1),
+ dnums_, &precision_config_);
+ context->SetOutput(0, output);
+ }
+
+ private:
+ xla::DotDimensionNumbers dnums_;
+ xla::PrecisionConfig precision_config_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaDotOp);
+};
+
+REGISTER_XLA_OP(Name("XlaDot"), XlaDotOp);
+
+} // namespace
+} // namespace tensorflow
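
Likewise for XlaDot: a plain [m, k] x [k, n] matmul corresponds to a single contracting pair and no batch dimensions (a minimal sketch; field names per xla_data.proto):

  #include <string>
  #include "tensorflow/compiler/xla/xla_data.pb.h"

  std::string MakeMatmulDnumsAttr() {
    xla::DotDimensionNumbers dnums;
    dnums.add_lhs_contracting_dimensions(1);  // contract lhs dim 1 ...
    dnums.add_rhs_contracting_dimensions(0);  // ... with rhs dim 0
    return dnums.SerializeAsString();
  }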
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc
new file mode 100644
index 0000000000..59502d83c7
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc
@@ -0,0 +1,105 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaPadOp : public XlaOpKernel {
+ public:
+ explicit XlaPadOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape input_shape = context->InputShape("input");
+ const TensorShape padding_value_shape =
+ context->InputShape("padding_value");
+
+ std::vector<int64> padding_low;
+ std::vector<int64> padding_high;
+ std::vector<int64> padding_interior;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("padding_low",
+ &padding_low));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("padding_high",
+ &padding_high));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector(
+ "padding_interior", &padding_interior));
+
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(padding_value_shape),
+ errors::InvalidArgument("padding_value must be a scalar"));
+ const int rank = input_shape.dims();
+ OP_REQUIRES(context, rank == padding_low.size(),
+ errors::InvalidArgument(
+ "The size of padding_low must be equal to the input "
+ "rank (",
+ padding_low.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == padding_high.size(),
+ errors::InvalidArgument(
+ "The size of padding_high must be equal to the input "
+ "rank (",
+ padding_high.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == padding_interior.size(),
+ errors::InvalidArgument(
+ "The size of padding_interior must be equal to the input "
+ "rank (",
+ padding_interior.size(), " vs. ", rank, ")"));
+
+ auto non_negative = [](int64 x) { return x >= 0; };
+ OP_REQUIRES(
+ context, absl::c_all_of(padding_low, non_negative),
+ errors::InvalidArgument("padding_low must be non-negative, got [",
+ absl::StrJoin(padding_low, ","), "]"));
+ OP_REQUIRES(
+ context, absl::c_all_of(padding_high, non_negative),
+ errors::InvalidArgument("padding_high must be non-negative, got [",
+ absl::StrJoin(padding_high, ","), "]"));
+ OP_REQUIRES(
+ context, absl::c_all_of(padding_interior, non_negative),
+ errors::InvalidArgument("padding_interior must be non-negative, got [",
+ absl::StrJoin(padding_interior, ","), "]"));
+
+ xla::PaddingConfig padding_config;
+ for (int i = 0; i < rank; ++i) {
+ auto* dim = padding_config.add_dimensions();
+ dim->set_edge_padding_low(padding_low[i]);
+ dim->set_edge_padding_high(padding_high[i]);
+ dim->set_interior_padding(padding_interior[i]);
+ }
+
+ xla::XlaOp output =
+ xla::Pad(context->Input("input"), context->Input("padding_value"),
+ padding_config);
+ context->SetOutput(0, output);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaPadOp);
+};
+
+REGISTER_XLA_OP(Name("XlaPad")
+ .CompileTimeConstInput("padding_low")
+ .CompileTimeConstInput("padding_high")
+ .CompileTimeConstInput("padding_interior"),
+ XlaPadOp);
+
+} // namespace
+} // namespace tensorflow
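
[Worked example, not part of the patch] The three padding vectors combine per xla::Pad's semantics: interior padding inserts elements between existing ones, so the output extent of XlaPad along each dimension is:

// Output size of XlaPad along one dimension of input size `in`.
int64 PaddedSize(int64 in, int64 low, int64 high, int64 interior) {
  const int64 between = in > 0 ? (in - 1) * interior : 0;
  return low + in + between + high;
}
// e.g. in = 3, low = 1, high = 2, interior = 1  ->  1 + 3 + 2 + 2 = 8.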
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc
new file mode 100644
index 0000000000..fc2425f37b
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc
@@ -0,0 +1,102 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/algorithm/container.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaReduceOp : public XlaOpKernel {
+ public:
+ explicit XlaReduceOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("reducer", &reducer_));
+ OP_REQUIRES_OK(context, context->GetAttr("dimensions_to_reduce",
+ &dimensions_to_reduce_));
+ std::set<int64> dims_set(dimensions_to_reduce_.begin(),
+ dimensions_to_reduce_.end());
+ OP_REQUIRES(
+ context, dims_set.size() == dimensions_to_reduce_.size(),
+ errors::InvalidArgument("Duplicate dimension in dimensions_to_reduce "
+ "argument to XlaReduce"));
+ }
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape input_shape = context->InputShape("input");
+ const TensorShape init_value_shape = context->InputShape("init_value");
+ const DataType dtype = context->input_type(0);
+
+ const int rank = input_shape.dims();
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(init_value_shape),
+ errors::InvalidArgument("init_value must be a scalar"));
+
+ auto dim_in_range = [rank](int64 dim) { return dim >= 0 && dim < rank; };
+ OP_REQUIRES(context,
+ rank >= dimensions_to_reduce_.size() &&
+ absl::c_all_of(dimensions_to_reduce_, dim_in_range),
+ errors::InvalidArgument(
+ "Invalid dimensions_to_reduce argument to XlaReduce"));
+
+ // Build the reducer function.
+ XlaCompiler::Argument reducer_arg;
+ reducer_arg.kind = XlaCompiler::Argument::kParameter;
+ reducer_arg.type = dtype;
+ reducer_arg.shape = TensorShape();
+
+ XlaCompiler::CompileOptions compile_options;
+ compile_options.use_tuple_arg = false;
+ compile_options.always_return_tuple = false;
+ compile_options.resolve_compile_time_constants = false;
+ compile_options.is_entry_computation = false;
+ XlaCompiler::CompilationResult reducer;
+ OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
+ compile_options, *reducer_,
+ {reducer_arg, reducer_arg}, &reducer));
+
+ xla::Shape scalar_shape;
+ OP_REQUIRES_OK(context,
+ TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape));
+ OP_REQUIRES(
+ context,
+ xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape),
+ errors::InvalidArgument(
+ "Invalid output shape of XlaReduce reducer. Expected ",
+ xla::ShapeUtil::HumanString(scalar_shape), " got ",
+ xla::ShapeUtil::HumanString(reducer.xla_output_shape)));
+
+ xla::XlaOp output =
+ xla::Reduce(context->Input("input"), context->Input("init_value"),
+ *reducer.computation, dimensions_to_reduce_);
+ context->SetOutput(0, output);
+ }
+
+ private:
+ const NameAttrList* reducer_;
+ std::vector<int64> dimensions_to_reduce_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaReduceOp);
+};
+
+REGISTER_XLA_OP(Name("XlaReduce"), XlaReduceOp);
+
+} // namespace
+} // namespace tensorflow
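
[Sketch, not part of the patch] The shape check above pins down the reducer contract: a binary function from two scalars of the input dtype to one scalar. At the XLA level the compiled function is equivalent to a computation like the following, built directly with XlaBuilder rather than via CompileFunction; F32 is assumed:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"

xla::XlaComputation MakeSumReducer() {
  xla::XlaBuilder b("sum_reducer");
  const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
  auto x = xla::Parameter(&b, 0, scalar, "x");
  auto y = xla::Parameter(&b, 1, scalar, "y");
  xla::Add(x, y);  // Root is a scalar F32, matching the check above.
  return b.Build().ValueOrDie();
}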
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc
new file mode 100644
index 0000000000..089776fcf7
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc
@@ -0,0 +1,147 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/kernels/while_op.h"
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaSelectAndScatterOp : public XlaOpKernel {
+ public:
+ explicit XlaSelectAndScatterOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("select", &select_computation_));
+ OP_REQUIRES_OK(context, context->GetAttr("scatter", &scatter_computation_));
+ }
+
+ void Compile(XlaOpKernelContext* context) override {
+ const TensorShape input_shape = context->InputShape(0);
+ const DataType dtype = context->input_type(0);
+
+ std::vector<int64> window_dimensions;
+ std::vector<int64> window_strides;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector(
+ "window_dimensions", &window_dimensions));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides",
+ &window_strides));
+
+ const int rank = input_shape.dims();
+ OP_REQUIRES(context, rank == window_dimensions.size(),
+ errors::InvalidArgument(
+ "The size of window_dimensions must be equal to the input "
+ "rank (",
+ window_dimensions.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == window_strides.size(),
+ errors::InvalidArgument(
+ "The size of window_strides must be equal to the input "
+ "rank (",
+ window_strides.size(), " vs. ", rank, ")"));
+
+ XlaCompiler::CompileOptions compile_options;
+ compile_options.use_tuple_arg = false;
+ compile_options.resolve_compile_time_constants = false;
+ compile_options.is_entry_computation = false;
+ compile_options.always_return_tuple = false;
+
+ // Build the select function.
+ XlaCompiler::Argument select_arg;
+ select_arg.kind = XlaCompiler::Argument::kParameter;
+ select_arg.type = dtype;
+ select_arg.shape = TensorShape();
+
+ XlaCompiler::CompilationResult select;
+ OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
+ compile_options, *select_computation_,
+ {select_arg, select_arg}, &select));
+
+ xla::Shape select_output_shape = xla::ShapeUtil::MakeShape(xla::PRED, {});
+ OP_REQUIRES(
+ context,
+ xla::ShapeUtil::Compatible(select.xla_output_shape,
+ select_output_shape),
+ errors::InvalidArgument(
+ "Invalid output shape of XlaSelectAndScatter select. Expected ",
+ xla::ShapeUtil::HumanString(select_output_shape), " got ",
+ xla::ShapeUtil::HumanString(select.xla_output_shape)));
+
+ // Build the scatter function.
+ XlaCompiler::Argument scatter_arg;
+ scatter_arg.kind = XlaCompiler::Argument::kParameter;
+ scatter_arg.type = dtype;
+ scatter_arg.shape = TensorShape();
+
+ XlaCompiler::CompilationResult scatter;
+ OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
+ compile_options, *scatter_computation_,
+ {scatter_arg, scatter_arg}, &scatter));
+
+ xla::Shape scalar_shape;
+ OP_REQUIRES_OK(context,
+ TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape));
+ OP_REQUIRES(
+ context,
+ xla::ShapeUtil::Compatible(scatter.xla_output_shape, scalar_shape),
+ errors::InvalidArgument(
+ "Invalid output shape of scatter. Expected ",
+ xla::ShapeUtil::HumanString(scalar_shape), " got ",
+ xla::ShapeUtil::HumanString(scatter.xla_output_shape)));
+
+ const TensorShape padding_shape = context->InputShape("padding");
+ OP_REQUIRES(context,
+ TensorShapeUtils::IsMatrix(padding_shape) &&
+ padding_shape.dim_size(1) == 2,
+ errors::InvalidArgument(
+ "padding must be a matrix with minor dimension 2, got ",
+ padding_shape.DebugString()));
+ xla::Literal padding_literal;
+ OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal(
+ "padding", &padding_literal));
+ std::vector<std::pair<int64, int64>> padding(padding_shape.dim_size(0));
+ for (int i = 0; i < padding.size(); ++i) {
+ padding[i] = {padding_literal.Get<int64>({i, 0}),
+ padding_literal.Get<int64>({i, 1})};
+ }
+
+ xla::XlaOp output = xla::SelectAndScatterWithGeneralPadding(
+ context->Input("operand"), *select.computation, window_dimensions,
+ window_strides, padding, context->Input("source"),
+ context->Input("init_value"), *scatter.computation);
+ context->SetOutput(0, output);
+ }
+
+ private:
+ const NameAttrList* select_computation_;
+ const NameAttrList* scatter_computation_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaSelectAndScatterOp);
+};
+
+REGISTER_XLA_OP(Name("XlaSelectAndScatter")
+ .CompileTimeConstInput("window_dimensions")
+ .CompileTimeConstInput("window_strides")
+ .CompileTimeConstInput("padding"),
+ XlaSelectAndScatterOp);
+
+} // namespace
+} // namespace tensorflow
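
[Sketch, not part of the patch] The compile-time constant `padding` input decoded above is a [rank, 2] matrix whose row i carries the (low, high) edge padding for dimension i, the same per-dimension pairs handed to SelectAndScatterWithGeneralPadding. The kernel's internal form for a rank-2 operand padded by one row at the top only would be:

// {low, high} per dimension; matches padding_literal entries {i, 0}, {i, 1}.
const std::vector<std::pair<int64, int64>> padding = {{1, 0}, {0, 0}};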
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index cb7a40e23d..8597e7f139 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -25,8 +25,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/core:lib",
],
)
@@ -44,8 +44,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/core:lib",
],
@@ -78,8 +78,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:math",
@@ -104,6 +104,7 @@ cc_library(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -119,6 +120,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:constants",
@@ -165,6 +167,7 @@ cc_library(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -202,6 +205,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
index f666d22ea4..64f2d781a6 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
@@ -27,7 +27,8 @@ limitations under the License.
namespace tensorflow {
xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x,
- bool transpose_y, bool conjugate_x, bool conjugate_y) {
+ bool transpose_y, bool conjugate_x, bool conjugate_y,
+ xla::PrecisionConfig::Precision precision) {
xla::XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x));
@@ -95,6 +96,10 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x,
y = xla::Conj(y);
}
+ xla::PrecisionConfig precision_proto;
+ precision_proto.add_operand_precision(precision);
+ precision_proto.add_operand_precision(precision);
+
// If there are no batch dimensions, use a regular Dot.
// TODO(b/69062148) Remove this code when Dot emitters can be passed
// dimensions to transpose directly (i.e. without requiring a Transpose
@@ -102,7 +107,7 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x,
if (batch_dimension_numbers.empty()) {
auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x;
auto rhs = transpose_y ? xla::Transpose(y, {1, 0}) : y;
- return xla::Dot(lhs, rhs);
+ return xla::Dot(lhs, rhs, &precision_proto);
}
xla::DotDimensionNumbers dot_dnums;
@@ -112,7 +117,8 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x,
dot_dnums.add_lhs_batch_dimensions(batch_dimension_number);
dot_dnums.add_rhs_batch_dimensions(batch_dimension_number);
}
- return xla::DotGeneral(x, y, dot_dnums);
+
+ return xla::DotGeneral(x, y, dot_dnums, &precision_proto);
});
}
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h
index 8757b16a1c..6edd63a4d3 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.h
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h
@@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_
#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace tensorflow {
@@ -43,9 +43,11 @@ namespace tensorflow {
// It is computed as:
//
// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
-xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x = false,
- bool transpose_y = false, bool conjugate_x = false,
- bool conjugate_y = false);
+xla::XlaOp BatchDot(
+ xla::XlaOp x, xla::XlaOp y, bool transpose_x = false,
+ bool transpose_y = false, bool conjugate_x = false,
+ bool conjugate_y = false,
+ xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT);
} // namespace tensorflow
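
[Sketch, not part of the patch] Existing call sites keep compiling because `precision` defaults to xla::PrecisionConfig::DEFAULT; precision-sensitive callers can now opt in explicitly. A minimal usage sketch, assuming x and y are XlaOps shaped [..., m, k] and [..., n, k]:

#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"

// Computes x * y^T at highest operand precision.
xla::XlaOp MatmulTransposed(xla::XlaOp x, xla::XlaOp y) {
  return tensorflow::BatchDot(x, y, /*transpose_x=*/false,
                              /*transpose_y=*/true, /*conjugate_x=*/false,
                              /*conjugate_y=*/false,
                              xla::PrecisionConfig::HIGHEST);
}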
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc
index 87d73eb3f0..ab3d0a5668 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.cc
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc
@@ -49,20 +49,22 @@ namespace {
// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) /
// l[..., j, j]
// return l
-xla::XlaOp CholeskyUnblocked(xla::XlaOp a) {
+xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
+ xla::PrecisionConfig::Precision precision) {
xla::XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
const int n_dims = xla::ShapeUtil::Rank(a_shape);
const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(a_shape.dimensions()),
- /*pos=*/0,
- /*len=*/n_dims - 2);
+ auto major_dims = xla::AsInt64Slice(a_shape.dimensions())
+ .subspan(
+ /*pos=*/0,
+ /*len=*/n_dims - 2);
xla::XlaOp l = xla::ZerosLike(a);
// Construct the for loop body to iterate over rows.
- auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
+ auto body_fn = [&](xla::XlaOp i, absl::Span<const xla::XlaOp> loop_vars,
xla::XlaBuilder* body_builder)
-> xla::StatusOr<std::vector<xla::XlaOp>> {
xla::Shape col_shape;
@@ -101,7 +103,8 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) {
// np.dot(row, np.swapaxes(row, -1, -2))
auto diag_dot = BatchDot(row, row,
/*transpose_x=*/false,
- /*transpose_y=*/true);
+ /*transpose_y=*/true, /*conjugate_x=*/false,
+ /*conjugate_y=*/false, precision);
// l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row,
// np.swapaxes(row, -1, -2)))
auto l_ii =
@@ -121,7 +124,8 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) {
// r.T)
auto dot = BatchDot(body_l, row,
/*transpose_x=*/false,
- /*transpose_y=*/true);
+ /*transpose_y=*/true, /*conjugate_x=*/false,
+ /*conjugate_y=*/false, precision);
// np.dot(l[..., i+1:, :i], r.T)
auto dot_ip1 =
xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot);
@@ -145,7 +149,8 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) {
} // namespace
-xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) {
+xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size,
+ xla::PrecisionConfig::Precision precision) {
xla::XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
@@ -181,14 +186,15 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) {
auto lhs = SliceInMinorDims(l, {i, 0}, {n, i});
auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i});
auto delta = BatchDot(lhs, rhs, /*transpose_x=*/false,
- /*transpose_y=*/true);
+ /*transpose_y=*/true, /*conjugate_x=*/false,
+ /*conjugate_y=*/false, precision);
auto before = SliceInMinorDims(a, {i, i}, {n, i + k});
a = UpdateSliceInMinorDims(a, before - delta, {i, i});
}
// l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k])
auto x = SliceInMinorDims(a, {i, i}, {i + k, i + k});
- auto factorized = CholeskyUnblocked(x);
+ auto factorized = CholeskyUnblocked(x, precision);
l = UpdateSliceInMinorDims(l, factorized, {i, i});
if (i + k < n) {
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h
index 1bef9bb166..9a561c34b9 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.h
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.h
@@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace tensorflow {
@@ -30,7 +30,9 @@ namespace tensorflow {
// TODO(phawkins): check for negative values on the diagonal and return an
// error, instead of silently yielding NaNs.
// TODO(znado): handle the complex Hermitian case
-xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256);
+xla::XlaOp Cholesky(
+ xla::XlaOp a, int64 block_size = 256,
+ xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc
index fc0c1ee838..6b3f2b6e06 100644
--- a/tensorflow/compiler/tf2xla/lib/qr.cc
+++ b/tensorflow/compiler/tf2xla/lib/qr.cc
@@ -65,9 +65,9 @@ namespace {
// return (v, tau, beta)
// TODO(phawkins): LAPACK's xLARFG implementation has code for handling
// overflows in the norm/beta calculations. Perhaps do the same here.
-xla::Status House(xla::XlaOp x, xla::XlaOp k, gtl::ArraySlice<int64> batch_dims,
- const int64 m, xla::XlaOp* v, xla::XlaOp* tau,
- xla::XlaOp* beta) {
+xla::Status House(xla::XlaOp x, xla::XlaOp k,
+ absl::Span<const int64> batch_dims, const int64 m,
+ xla::XlaOp* v, xla::XlaOp* tau, xla::XlaOp* beta) {
xla::XlaBuilder* const builder = x.builder();
TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x));
const xla::PrimitiveType type = x_shape.element_type();
@@ -149,7 +149,8 @@ struct QRBlockResult {
xla::XlaOp taus; // Shape: [..., n]
xla::XlaOp vs; // Shape: [..., m, n]
};
-xla::StatusOr<QRBlockResult> QRBlock(xla::XlaOp a) {
+xla::StatusOr<QRBlockResult> QRBlock(
+ xla::XlaOp a, xla::PrecisionConfig::Precision precision) {
xla::XlaBuilder* builder = a.builder();
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
const int num_dims = xla::ShapeUtil::Rank(a_shape);
@@ -172,7 +173,7 @@ xla::StatusOr<QRBlockResult> QRBlock(xla::XlaOp a) {
std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
auto qr_body_fn =
- [&](xla::XlaOp j, gtl::ArraySlice<xla::XlaOp> values,
+ [&](xla::XlaOp j, absl::Span<const xla::XlaOp> values,
xla::XlaBuilder* builder) -> xla::StatusOr<std::vector<xla::XlaOp>> {
auto a = values[0];
auto vs = values[1];
@@ -190,8 +191,12 @@ xla::StatusOr<QRBlockResult> QRBlock(xla::XlaOp a) {
auto v_broadcast = xla::Reshape(v, shape);
// a[:, :] -= tau * np.dot(v[:, np.newaxis],
// np.dot(v[np.newaxis, :], a[:, :]))
- auto vva = BatchDot(v_broadcast, a);
- vva = BatchDot(v_broadcast, vva, /*transpose_x=*/true);
+ auto vva =
+ BatchDot(v_broadcast, a, /*transpose_x=*/false, /*transpose_y=*/false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
+ vva =
+ BatchDot(v_broadcast, vva, /*transpose_x=*/true, /*transpose_y=*/false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
a = a - xla::Mul(tau, vva,
/*broadcast_dimensions=*/batch_dim_indices);
@@ -250,14 +255,15 @@ xla::StatusOr<QRBlockResult> QRBlock(xla::XlaOp a) {
// There is no need to return Y since at termination of the loop it is equal to
// vs.
xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
- xla::PrimitiveType type, gtl::ArraySlice<int64> batch_dims, xla::XlaOp vs,
- xla::XlaOp taus, int64 m, int64 n) {
+ xla::PrimitiveType type, absl::Span<const int64> batch_dims, xla::XlaOp vs,
+ xla::XlaOp taus, int64 m, int64 n,
+ xla::PrecisionConfig::Precision precision) {
std::vector<int64> batch_dim_indices(batch_dims.size());
std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
int64 n_index = batch_dims.size() + 1;
auto body_fn =
- [&](xla::XlaOp j, gtl::ArraySlice<xla::XlaOp> values,
+ [&](xla::XlaOp j, absl::Span<const xla::XlaOp> values,
xla::XlaBuilder* builder) -> xla::StatusOr<std::vector<xla::XlaOp>> {
auto w = values[0];
auto y = values[1];
@@ -272,9 +278,12 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
auto beta = DynamicSliceInMinorDims(taus, {j}, {1});
// yv has shape [..., n, 1]
- auto yv = BatchDot(y, v, /*transpose_x=*/true);
+ auto yv = BatchDot(y, v, /*transpose_x=*/true, /*transpose_y=*/false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
// wyv has shape [..., m, 1]
- auto wyv = BatchDot(w, yv);
+ auto wyv =
+ BatchDot(w, yv, /*transpose_x=*/false, /*transpose_y=*/false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
auto z = xla::Mul(
-beta, v + wyv,
@@ -321,8 +330,9 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
// return (q, a)
// TODO(phawkins): consider using UT transformations (in the form I - V U V')
// rather than WY transformations.
-xla::StatusOr<QRDecompositionResult> QRDecomposition(xla::XlaOp a,
- int64 block_size) {
+xla::StatusOr<QRDecompositionResult> QRDecomposition(
+ xla::XlaOp a, bool full_matrices, int64 block_size,
+ xla::PrecisionConfig::Precision precision) {
xla::XlaBuilder* builder = a.builder();
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
const int num_dims = xla::ShapeUtil::Rank(a_shape);
@@ -352,33 +362,47 @@ xla::StatusOr<QRDecompositionResult> QRDecomposition(xla::XlaOp a,
int64 k = std::min(block_size, p - i);
auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k});
- TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block));
+ TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block, precision));
a = UpdateSliceInMinorDims(a, qr_block.r, {i, i});
// Compute the I-WY block representation of a product of Householder
// matrices.
- TF_ASSIGN_OR_RETURN(auto w,
- ComputeWYRepresentation(type, batch_dims, qr_block.vs,
- qr_block.taus, m - i, k));
+ TF_ASSIGN_OR_RETURN(
+ auto w, ComputeWYRepresentation(type, batch_dims, qr_block.vs,
+ qr_block.taus, m - i, k, precision));
auto y = qr_block.vs;
// a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:]))
auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n});
- auto a_update = BatchDot(w, a_panel, /*transpose_x=*/true);
- a_update = BatchDot(y, a_update);
+ auto a_update =
+ BatchDot(w, a_panel, /*transpose_x=*/true, /*transpose_y=*/false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
+ a_update =
+ BatchDot(y, a_update, /*transpose_x=*/false, /*transpose_y=*/false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
a_panel = a_panel + a_update;
a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});
// q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T))
auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
- auto q_update = BatchDot(q_panel, w);
- q_update =
- BatchDot(q_update, y, /*transpose_x=*/false, /*transpose_y=*/true);
+ auto q_update =
+ BatchDot(q_panel, w, /*transpose_x=*/false, /*transpose_y=*/false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
+ q_update = BatchDot(q_update, y, /*transpose_x=*/false,
+ /*transpose_y=*/true, /*conjugate_x=*/false,
+ /*conjugate_y=*/false, precision);
q_panel = q_panel + q_update;
q = UpdateSliceInMinorDims(q, q_panel, {0, i});
}
QRDecompositionResult result;
+
+ // full_matrices is false when only a partial result is needed. Slice to the
+ // needed dimensions here.
+ if (!full_matrices) {
+ q = SliceInMinorDims(q, {0, 0}, {m, p});
+ a = SliceInMinorDims(a, {0, 0}, {p, n});
+ }
result.q = q;
result.r = a;
return result;
diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h
index abd2316ac9..24b537ac8b 100644
--- a/tensorflow/compiler/tf2xla/lib/qr.h
+++ b/tensorflow/compiler/tf2xla/lib/qr.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_
#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace tensorflow {
@@ -32,8 +33,9 @@ struct QRDecompositionResult {
xla::XlaOp r;
};
-xla::StatusOr<QRDecompositionResult> QRDecomposition(xla::XlaOp a,
- int64 block_size = 128);
+xla::StatusOr<QRDecompositionResult> QRDecomposition(
+ xla::XlaOp a, bool full_matrices, int64 block_size = 128,
+ xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST);
} // namespace tensorflow
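
[Sketch, not part of the patch] The new `full_matrices` flag selects between the full and reduced factorizations; with p = min(m, n), passing false slices the result to q: [..., m, p] and r: [..., p, n], as done at the end of QRDecomposition above. A usage sketch with illustrative naming:

#include "tensorflow/compiler/tf2xla/lib/qr.h"

xla::StatusOr<tensorflow::QRDecompositionResult> ReducedQR(xla::XlaOp a) {
  // Reduced factorization; block_size and precision keep their defaults.
  return tensorflow::QRDecomposition(a, /*full_matrices=*/false);
}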
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc
index ba22eff73a..38dfde165d 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.cc
+++ b/tensorflow/compiler/tf2xla/lib/scatter.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
@@ -40,9 +40,9 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer));
TF_RETURN_IF_ERROR(builder->GetShape(updates).status());
TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices));
- gtl::ArraySlice<int64> indices_dims =
+ absl::Span<const int64> indices_dims =
xla::AsInt64Slice(indices_shape.dimensions());
- gtl::ArraySlice<int64> buffer_dims =
+ absl::Span<const int64> buffer_dims =
xla::AsInt64Slice(buffer_shape.dimensions());
// If the indices are N-dimensional, the minor dimension of indices contains
@@ -58,7 +58,7 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
") must be <= the rank of the buffer (shape: ",
xla::ShapeUtil::HumanString(buffer_shape), ")");
}
- indices_dims.pop_back();
+ indices_dims.remove_suffix(1);
}
int64 num_indices = 1;
@@ -107,7 +107,7 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
// index = dynamic-slice(indices, i)
// update = dynamic-slice(updates, i)
// buffer = dynamic-update-slice(buffer, update, index)
- auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
+ auto body_fn = [&](xla::XlaOp i, absl::Span<const xla::XlaOp> loop_vars,
xla::XlaBuilder* body_builder) {
auto indices = loop_vars[0];
auto updates = loop_vars[1];
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
index 04fa10108c..6524c2a9b1 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -57,7 +57,7 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) {
// We can grab entire blocks using gather
if (n > block_size) {
// Construct the starting indices of the diagonal blocks
- auto gather_indices =
+ auto start_indices =
Transpose(Broadcast(Mul(Iota(builder, xla::S32, num_blocks),
xla::ConstantR0<int32>(builder, block_size)),
/*broadcast_sizes=*/{2}),
@@ -65,13 +65,13 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) {
// Gather the diagonal blocks
xla::GatherDimensionNumbers dim_numbers;
- dim_numbers.add_output_window_dims(ndims - 1);
- dim_numbers.add_output_window_dims(ndims);
- dim_numbers.add_gather_dims_to_operand_dims(ndims - 2);
- dim_numbers.add_gather_dims_to_operand_dims(ndims - 1);
+ dim_numbers.add_offset_dims(ndims - 1);
+ dim_numbers.add_offset_dims(ndims);
+ dim_numbers.add_start_index_map(ndims - 2);
+ dim_numbers.add_start_index_map(ndims - 1);
dim_numbers.set_index_vector_dim(1);
- diag_blocks = Gather(a, gather_indices, dim_numbers,
- /*window_bounds=*/{block_size, block_size});
+ diag_blocks = Gather(a, start_indices, dim_numbers,
+ /*slice_sizes=*/{block_size, block_size});
}
// The last block might be smaller than the block size,
@@ -111,7 +111,8 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) {
}
xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower,
- bool transpose_a, bool conjugate_a) {
+ bool transpose_a, bool conjugate_a,
+ xla::PrecisionConfig::Precision precision) {
xla::XlaBuilder* builder = diag_blocks.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
// Input is a batch of square lower triangular matrices. Its shape is
@@ -215,7 +216,10 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower,
dnums.add_rhs_batch_dimensions(0);
dnums.add_lhs_contracting_dimensions(2);
dnums.add_rhs_contracting_dimensions(1);
- auto update = -DotGeneral(input_row, body_out, dnums);
+ xla::PrecisionConfig precision_proto;
+ precision_proto.add_operand_precision(precision);
+ precision_proto.add_operand_precision(precision);
+ auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto);
body_out = DynamicUpdateSlice(body_out, update, start_indices);
@@ -238,10 +242,10 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower,
});
}
-xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b,
- xla::XlaOp inv_diag_blocks,
- bool left_side, bool lower,
- bool transpose_a, bool conjugate_a) {
+xla::XlaOp SolveWithInvertedDiagonalBlocks(
+ xla::XlaOp a, xla::XlaOp b, xla::XlaOp inv_diag_blocks, bool left_side,
+ bool lower, bool transpose_a, bool conjugate_a,
+ xla::PrecisionConfig::Precision precision) {
xla::XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape,
@@ -307,9 +311,13 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b,
auto a_row =
MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a);
if (left_side) {
- remainder = b_row - BatchDot(a_row, x, transpose_a, false);
+ remainder = b_row - BatchDot(a_row, x, transpose_a, false,
+ /*conjugate_x=*/false,
+ /*conjugate_y=*/false, precision);
} else {
- remainder = b_row - BatchDot(x, a_row, false, transpose_a);
+ remainder = b_row - BatchDot(x, a_row, false, transpose_a,
+ /*conjugate_x=*/false,
+ /*conjugate_y=*/false, precision);
}
}
@@ -319,9 +327,13 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b,
xla::ConstantR0WithType(builder, xla::S32, j * block_size);
std::vector<xla::XlaOp> update_starts = {start_index, zero};
if (left_side) {
- x_update = BatchDot(inv_block, remainder, transpose_a, false);
+ x_update =
+ BatchDot(inv_block, remainder, transpose_a, false,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
} else {
- x_update = BatchDot(remainder, inv_block, false, transpose_a);
+ x_update =
+ BatchDot(remainder, inv_block, false, transpose_a,
+ /*conjugate_x=*/false, /*conjugate_y=*/false, precision);
std::swap(update_starts[0], update_starts[1]);
}
x = DynamicUpdateSliceInMinorDims(x, x_update, /*starts=*/update_starts);
@@ -333,7 +345,8 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b,
xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side,
bool lower, bool transpose_a, bool conjugate_a,
- int64 block_size) {
+ int64 block_size,
+ xla::PrecisionConfig::Precision precision) {
xla::XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
@@ -388,12 +401,13 @@ xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side,
auto diag_blocks = DiagonalBlocks(a, block_size);
// We invert these blocks in parallel using batched matrix-vector products
- auto inv_diag_blocks =
- InvertDiagonalBlocks(diag_blocks, lower, transpose_a, conjugate_a);
+ auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a,
+ conjugate_a, precision);
// We now find the solution using GEMMs
- auto x = SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side,
- lower, transpose_a, conjugate_a);
+ auto x =
+ SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, lower,
+ transpose_a, conjugate_a, precision);
return x;
});
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
index 555760b7ef..2303234f36 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
@@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_
#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace tensorflow {
@@ -57,9 +57,10 @@ namespace tensorflow {
//
// Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no
// blocking is used.
-xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side,
- bool lower, bool transpose_a, bool conjugate_a,
- int64 block_size = 128);
+xla::XlaOp TriangularSolve(
+ xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a,
+ bool conjugate_a, int64 block_size = 128,
+ xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc
index 8b5beba383..804671fbc7 100644
--- a/tensorflow/compiler/tf2xla/lib/util.cc
+++ b/tensorflow/compiler/tf2xla/lib/util.cc
@@ -64,31 +64,31 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
xla::Literal literal;
switch (type) {
case xla::U8:
- literal = std::move(*xla::LiteralUtil::CreateR0<uint8>(value));
+ literal = xla::LiteralUtil::CreateR0<uint8>(value);
break;
case xla::U32:
- literal = std::move(*xla::LiteralUtil::CreateR0<uint32>(value));
+ literal = xla::LiteralUtil::CreateR0<uint32>(value);
break;
case xla::U64:
- literal = std::move(*xla::LiteralUtil::CreateR0<uint64>(value));
+ literal = xla::LiteralUtil::CreateR0<uint64>(value);
break;
case xla::S8:
- literal = std::move(*xla::LiteralUtil::CreateR0<int8>(value));
+ literal = xla::LiteralUtil::CreateR0<int8>(value);
break;
case xla::S32:
- literal = std::move(*xla::LiteralUtil::CreateR0<int32>(value));
+ literal = xla::LiteralUtil::CreateR0<int32>(value);
break;
case xla::S64:
- literal = std::move(*xla::LiteralUtil::CreateR0<int64>(value));
+ literal = xla::LiteralUtil::CreateR0<int64>(value);
break;
case xla::F32:
- literal = std::move(*xla::LiteralUtil::CreateR0<float>(value));
+ literal = xla::LiteralUtil::CreateR0<float>(value);
break;
case xla::F64:
- literal = std::move(*xla::LiteralUtil::CreateR0<double>(value));
+ literal = xla::LiteralUtil::CreateR0<double>(value);
break;
case xla::C64:
- literal = std::move(*xla::LiteralUtil::CreateR0<complex64>(value));
+ literal = xla::LiteralUtil::CreateR0<complex64>(value);
break;
case xla::PRED:
LOG(FATAL) << "pred element type is not integral";
@@ -96,12 +96,12 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
case xla::U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case xla::BF16:
- literal = std::move(
- *xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value)));
+ literal =
+ xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value));
break;
case xla::F16:
- literal = std::move(*xla::LiteralUtil::CreateR0<xla::half>(
- static_cast<xla::half>(value)));
+ literal =
+ xla::LiteralUtil::CreateR0<xla::half>(static_cast<xla::half>(value));
break;
case xla::TUPLE:
LOG(FATAL) << "tuple element type is not integral";
@@ -113,8 +113,8 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
return xla::ConstantLiteral(builder, literal);
}
-xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice<int64> start,
- gtl::ArraySlice<int64> end) {
+xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span<const int64> start,
+ absl::Span<const int64> end) {
xla::XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_RET_CHECK(start.size() == end.size());
@@ -124,9 +124,10 @@ xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice<int64> start,
const int64 n_dims = xla::ShapeUtil::Rank(shape);
TF_RET_CHECK(n_minor_dims <= n_dims);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()),
- /*pos=*/0,
- /*len=*/n_dims - n_minor_dims);
+ auto major_dims = xla::AsInt64Slice(shape.dimensions())
+ .subspan(
+ /*pos=*/0,
+ /*len=*/n_dims - n_minor_dims);
// Prepends 0s in the major dim
std::vector<int64> padded_start(n_dims, 0);
@@ -143,8 +144,8 @@ xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice<int64> start,
});
}
-std::vector<int64> ConcatVectors(gtl::ArraySlice<int64> xs,
- gtl::ArraySlice<int64> ys) {
+std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
+ absl::Span<const int64> ys) {
std::vector<int64> output(xs.size() + ys.size());
std::copy(xs.begin(), xs.end(), output.begin());
std::copy(ys.begin(), ys.end(), output.begin() + xs.size());
@@ -152,8 +153,8 @@ std::vector<int64> ConcatVectors(gtl::ArraySlice<int64> xs,
}
xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x,
- gtl::ArraySlice<xla::XlaOp> starts,
- gtl::ArraySlice<int64> sizes) {
+ absl::Span<const xla::XlaOp> starts,
+ absl::Span<const int64> sizes) {
xla::XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
@@ -161,9 +162,10 @@ xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x,
int64 n_minor_dims = starts.size();
TF_RET_CHECK(n_minor_dims == sizes.size());
TF_RET_CHECK(n_minor_dims <= n_dims);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()),
- /*pos=*/0,
- /*len=*/n_dims - sizes.size());
+ auto major_dims = xla::AsInt64Slice(shape.dimensions())
+ .subspan(
+ /*pos=*/0,
+ /*len=*/n_dims - sizes.size());
auto padded_starts = PrependZerosInMajorDims(x, starts);
auto padded_sizes = ConcatVectors(major_dims, sizes);
return xla::DynamicSlice(x, padded_starts, padded_sizes);
@@ -171,7 +173,7 @@ xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x,
}
xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update,
- gtl::ArraySlice<int64> start) {
+ absl::Span<const int64> start) {
xla::XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
// TODO(phawkins): make int64 work on all backends, remove the int32 cast.
@@ -189,7 +191,7 @@ xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update,
}
xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
- gtl::ArraySlice<int64> start) {
+ absl::Span<const int64> start) {
xla::XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
@@ -204,13 +206,13 @@ xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
}
xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
- gtl::ArraySlice<xla::XlaOp> starts) {
+ absl::Span<const xla::XlaOp> starts) {
auto padded_starts = PrependZerosInMajorDims(x, starts);
return xla::DynamicUpdateSlice(x, update, padded_starts);
}
xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x,
- gtl::ArraySlice<xla::XlaOp> starts) {
+ absl::Span<const xla::XlaOp> starts) {
xla::XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
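
[Sketch, not part of the patch] The hunks above track the xla::LiteralUtil API change: the CreateR0 factories now return xla::Literal by value instead of std::unique_ptr<xla::Literal>, which removes the std::move(*...) dance at every call site. A before/after sketch:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"

// Before: literal = std::move(*xla::LiteralUtil::CreateR0<float>(v));
// After: value semantics throughout.
xla::XlaOp FloatConstant(xla::XlaBuilder* builder, float v) {
  xla::Literal literal = xla::LiteralUtil::CreateR0<float>(v);
  return xla::ConstantLiteral(builder, literal);
}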
diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h
index b4905c9528..80e9e5b002 100644
--- a/tensorflow/compiler/tf2xla/lib/util.h
+++ b/tensorflow/compiler/tf2xla/lib/util.h
@@ -16,10 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
@@ -31,7 +31,7 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
// Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros
// prepended until the array is length n_dims.
xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x,
- gtl::ArraySlice<xla::XlaOp> starts);
+ absl::Span<const xla::XlaOp> starts);
// Returns an integer scalar constant of 'type' with 'value'.
// If 'type' is complex, returns a real value with zero imaginary component.
@@ -41,33 +41,33 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
// Builds a vector of zeros of length rank(x) with the last values being
// those in `starts`.
xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x,
- gtl::ArraySlice<xla::XlaOp> starts);
+ absl::Span<const xla::XlaOp> starts);
// Performs a slice in the minor dimensions of a Tensor.
-xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice<int64> start,
- gtl::ArraySlice<int64> end);
+xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span<const int64> start,
+ absl::Span<const int64> end);
// Returns the concatenation of `xs` and `ys`.
-std::vector<int64> ConcatVectors(gtl::ArraySlice<int64> xs,
- gtl::ArraySlice<int64> ys);
+std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
+ absl::Span<const int64> ys);
// Performs a dynamic slice in the minor dimensions of a Tensor.
xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x,
- gtl::ArraySlice<xla::XlaOp> starts,
- gtl::ArraySlice<int64> sizes);
+ absl::Span<const xla::XlaOp> starts,
+ absl::Span<const int64> sizes);
// Updates a slice of 'x', i.e.,
// x[start[0], ..., start[n]] = update
xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update,
- gtl::ArraySlice<int64> start);
+ absl::Span<const int64> start);
// Updates a slice of 'x', where 'start' contains a list of minor dimensions:
// x[..., start[0], ..., start[n]] = update
xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
- gtl::ArraySlice<int64> start);
+ absl::Span<const int64> start);
xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
- gtl::ArraySlice<xla::XlaOp> starts);
+ absl::Span<const xla::XlaOp> starts);
// Transposes a stack of matrices `x` by swapping the last two dimensions.
xla::XlaOp TransposeInMinorDims(xla::XlaOp x);
diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc
index d64394f140..594ab1dfd0 100644
--- a/tensorflow/compiler/tf2xla/lib/while_loop.cc
+++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc
@@ -24,7 +24,7 @@ namespace tensorflow {
xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
const LoopConditionFunction& condition_function,
const LoopBodyFunction& body_function,
- gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+ absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
xla::XlaBuilder* builder) {
int arity = initial_values.size();
std::vector<xla::Shape> var_shapes;
@@ -47,7 +47,7 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
// Build the condition.
std::unique_ptr<xla::XlaBuilder> cond_builder =
- builder->CreateSubBuilder(strings::StrCat(name, "_condition"));
+ builder->CreateSubBuilder(absl::StrCat(name, "_condition"));
{
auto parameter =
xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter");
@@ -61,7 +61,7 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
// Build the body.
std::unique_ptr<xla::XlaBuilder> body_builder =
- builder->CreateSubBuilder(strings::StrCat(name, "_body"));
+ builder->CreateSubBuilder(absl::StrCat(name, "_body"));
{
auto parameter =
xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter");
@@ -84,15 +84,15 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
int64 num_iterations, xla::PrimitiveType num_iterations_type,
const ForEachIndexBodyFunction& body_function,
- gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+ absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
xla::XlaBuilder* builder) {
auto while_cond_fn =
- [&](gtl::ArraySlice<xla::XlaOp> values,
+ [&](absl::Span<const xla::XlaOp> values,
xla::XlaBuilder* cond_builder) -> xla::StatusOr<xla::XlaOp> {
return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type,
num_iterations));
};
- auto while_body_fn = [&](gtl::ArraySlice<xla::XlaOp> values,
+ auto while_body_fn = [&](absl::Span<const xla::XlaOp> values,
xla::XlaBuilder* body_builder)
-> xla::StatusOr<std::vector<xla::XlaOp>> {
xla::XlaOp iteration = values[0];
diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h
index 9493b1f109..f2134bb449 100644
--- a/tensorflow/compiler/tf2xla/lib/while_loop.h
+++ b/tensorflow/compiler/tf2xla/lib/while_loop.h
@@ -19,24 +19,24 @@ limitations under the License.
#include <functional>
#include <vector>
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
// Function that builds a loop condition. Takes as input a sequence of input
// values, and returns a boolean value indicating whether the condition holds.
-typedef std::function<xla::StatusOr<xla::XlaOp>(gtl::ArraySlice<xla::XlaOp>,
+typedef std::function<xla::StatusOr<xla::XlaOp>(absl::Span<const xla::XlaOp>,
xla::XlaBuilder*)>
LoopConditionFunction;
// Function that builds a loop body. Takes as input a sequence of input values
// and returns a sequence of output values.
typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
- gtl::ArraySlice<xla::XlaOp>, xla::XlaBuilder*)>
+ absl::Span<const xla::XlaOp>, xla::XlaBuilder*)>
LoopBodyFunction;
// Helper function for building an XLA while loop, where the values carried by
@@ -50,7 +50,7 @@ typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
const LoopConditionFunction& condition_function,
const LoopBodyFunction& body_function,
- gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+ absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
xla::XlaBuilder* builder);
// Builds an XLA loop that repeats a computation `num_iterations` times.
@@ -59,13 +59,13 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
// (current iteration number, loop-carried values), and returns an updated
// vector of the loop-carried values.
typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
- xla::XlaOp, gtl::ArraySlice<xla::XlaOp>, xla::XlaBuilder*)>
+ xla::XlaOp, absl::Span<const xla::XlaOp>, xla::XlaBuilder*)>
ForEachIndexBodyFunction;
xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
int64 num_iterations, xla::PrimitiveType num_iterations_type,
const ForEachIndexBodyFunction& body_function,
- gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+ absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
xla::XlaBuilder* builder);
} // namespace tensorflow
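
[Sketch, not part of the patch] With the typedefs now taking absl::Span and absl::string_view, loop bodies are written against spans of loop-carried values. A minimal sketch that sums the integers 0..n-1 with XlaForEachIndex; the function and variable names are illustrative:

#include "tensorflow/compiler/tf2xla/lib/while_loop.h"

xla::StatusOr<std::vector<xla::XlaOp>> SumFirstN(xla::XlaBuilder* builder,
                                                 int64 n) {
  auto body = [](xla::XlaOp i, absl::Span<const xla::XlaOp> values,
                 xla::XlaBuilder*) -> xla::StatusOr<std::vector<xla::XlaOp>> {
    // One loop-carried value: the running S32 sum.
    return std::vector<xla::XlaOp>{values[0] + i};
  };
  return tensorflow::XlaForEachIndex(
      n, xla::S32, body, {xla::ConstantR0<int32>(builder, 0)}, "sum_loop",
      builder);
}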
diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc
index 77da1bf29c..20103ec3ae 100644
--- a/tensorflow/compiler/tf2xla/literal_util.cc
+++ b/tensorflow/compiler/tf2xla/literal_util.cc
@@ -49,9 +49,8 @@ Status HostTensorToMutableBorrowingLiteral(
return Status::OK();
}
-Status HostTensorsToBorrowingLiteralTuple(
- tensorflow::gtl::ArraySlice<Tensor> host_tensors,
- xla::BorrowingLiteral* literal) {
+Status HostTensorsToBorrowingLiteralTuple(absl::Span<const Tensor> host_tensors,
+ xla::BorrowingLiteral* literal) {
std::vector<const char*> buf_ptrs;
buf_ptrs.reserve(host_tensors.size());
std::vector<xla::Shape> tensor_shapes(host_tensors.size());
diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h
index 09d6fa8116..1db7470ee2 100644
--- a/tensorflow/compiler/tf2xla/literal_util.h
+++ b/tensorflow/compiler/tf2xla/literal_util.h
@@ -18,11 +18,11 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
@@ -43,9 +43,8 @@ Status HostTensorToMutableBorrowingLiteral(
// Returns a BorrowingLiteral tuple that utilizes the same underlying buffers
// owned by 'host_tensors'.
-Status HostTensorsToBorrowingLiteralTuple(
- tensorflow::gtl::ArraySlice<Tensor> host_tensors,
- xla::BorrowingLiteral* literal);
+Status HostTensorsToBorrowingLiteralTuple(absl::Span<const Tensor> host_tensors,
+ xla::BorrowingLiteral* literal);
// Copies 'literal' into a freshly allocated 'host_tensor' of
// type <target_type>.
diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc
index a3404c2b3d..ed452bceeb 100644
--- a/tensorflow/compiler/tf2xla/literal_util_test.cc
+++ b/tensorflow/compiler/tf2xla/literal_util_test.cc
@@ -27,19 +27,17 @@ TEST(LiteralUtil, LiteralToHostTensor) {
// int64 literal can only be converted to an int64 host tensor.
{
std::vector<int64> int64_values = {1, 2, 3};
- std::unique_ptr<xla::Literal> int64_values_literal =
- xla::LiteralUtil::CreateR1(gtl::ArraySlice<int64>(int64_values));
+ xla::Literal int64_values_literal =
+ xla::LiteralUtil::CreateR1(absl::Span<const int64>(int64_values));
Tensor host_tensor;
EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32",
- LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor)
+ LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor)
+ .error_message());
+ EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32",
+ LiteralToHostTensor(int64_values_literal, DT_QINT32, &host_tensor)
.error_message());
- EXPECT_EQ(
- "Cannot convert literal of type S64 to tensor of type qint32",
- LiteralToHostTensor(*int64_values_literal, DT_QINT32, &host_tensor)
- .error_message());
EXPECT_TRUE(
- LiteralToHostTensor(*int64_values_literal, DT_INT64, &host_tensor)
- .ok());
+ LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok());
test::ExpectTensorEqual<int64>(host_tensor,
test::AsTensor<int64>(int64_values));
}
@@ -48,23 +46,22 @@ TEST(LiteralUtil, LiteralToHostTensor) {
// Repeat tests with int32.
Tensor host_tensor;
std::vector<int32> int32_values = {10, 11};
- std::unique_ptr<xla::Literal> int32_values_literal =
- xla::LiteralUtil::CreateR1(gtl::ArraySlice<int32>(int32_values));
+ xla::Literal int32_values_literal =
+ xla::LiteralUtil::CreateR1(absl::Span<const int32>(int32_values));
EXPECT_TRUE(
- LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor)
- .ok());
+ LiteralToHostTensor(int32_values_literal, DT_INT32, &host_tensor).ok());
test::ExpectTensorEqual<int32>(host_tensor,
test::AsTensor<int32>(int32_values));
EXPECT_TRUE(
- LiteralToHostTensor(*int32_values_literal, DT_QINT32, &host_tensor)
+ LiteralToHostTensor(int32_values_literal, DT_QINT32, &host_tensor)
.ok());
std::vector<qint32> qint32_values = {10, 11};
test::ExpectTensorEqual<qint32>(host_tensor,
test::AsTensor<qint32>(qint32_values));
EXPECT_EQ("Cannot convert literal of type S32 to tensor of type int64",
- LiteralToHostTensor(*int32_values_literal, DT_INT64, &host_tensor)
+ LiteralToHostTensor(int32_values_literal, DT_INT64, &host_tensor)
.error_message());
}
}
diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD
index ace6fd1d8e..4dce0a2102 100644
--- a/tensorflow/compiler/tf2xla/ops/BUILD
+++ b/tensorflow/compiler/tf2xla/ops/BUILD
@@ -11,6 +11,8 @@ cc_library(
srcs = ["xla_ops.cc"],
deps = [
"//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
],
alwayslink = 1,
)
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index a59c77f5c3..68cfdc1785 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -13,11 +13,97 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "absl/algorithm/container.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
+namespace {
+
+// Helper shape function for operators that return an output with the same rank
+// as their first input.
+Status UnchangedRank(shape_inference::InferenceContext* c) {
+ if (c->RankKnown(c->input(0))) {
+ c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
+ } else {
+ c->set_output(0, c->input(0));
+ }
+ return Status::OK();
+}
+
+REGISTER_OP("XlaBroadcastHelper")
+ .Input("lhs: T")
+ .Input("rhs: T")
+ .Input("broadcast_dims: Tindices")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Output("lhs_output: T")
+ .Output("rhs_output: T")
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+Helper operator for performing XLA-style broadcasts.
+
+Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to
+whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules
+for binary operators.
+
+lhs: the LHS input tensor
+rhs: the RHS input tensor
+broadcast_dims: an XLA-style broadcast dimension specification
+lhs_output: the broadcasted LHS tensor
+rhs_output: the broadcasted RHS tensor
+)doc");
+
+REGISTER_OP("XlaConv")
+ .Input("lhs: T")
+ .Input("rhs: T")
+ .Input("window_strides: Tindices")
+ .Input("padding: Tindices")
+ .Input("lhs_dilation: Tindices")
+ .Input("rhs_dilation: Tindices")
+ .Input("feature_group_count: Tindices")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("dimension_numbers: string")
+ .Attr("precision_config: string")
+ .Output("output: T")
+ .SetShapeFn(UnchangedRank)
+ .Doc(R"doc(
+Wraps the XLA ConvGeneralDilated operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
+.
+
+lhs: the input tensor
+rhs: the kernel tensor
+window_strides: the inter-window strides
+padding: the padding to apply at the start and end of each input dimension
+lhs_dilation: dilation to apply between input elements
+rhs_dilation: dilation to apply between kernel elements
+feature_group_count: number of feature groups for grouped convolution.
+dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
+precision_config: a serialized xla::PrecisionConfig proto.
+)doc");
+
+REGISTER_OP("XlaDot")
+ .Input("lhs: T")
+ .Input("rhs: T")
+ .Attr("T: numbertype")
+ .Attr("dimension_numbers: string")
+ .Attr("precision_config: string")
+ .Output("output: T")
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+Wraps the XLA DotGeneral operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
+.
+
+lhs: the LHS tensor
+rhs: the RHS tensor
+dimension_numbers: a serialized xla::DotDimensionNumbers proto.
+precision_config: a serialized xla::PrecisionConfig proto.
+)doc");
REGISTER_OP("XlaDynamicUpdateSlice")
.Input("input: T")
@@ -73,6 +159,29 @@ else_branch: A function takes 'inputs' and returns a list of tensors.
whose types are the same as what then_branch returns.
)doc");
+REGISTER_OP("XlaPad")
+ .Input("input: T")
+ .Input("padding_value: T")
+ .Input("padding_low: Tindices")
+ .Input("padding_high: Tindices")
+ .Input("padding_interior: Tindices")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(UnchangedRank)
+ .Doc(R"doc(
+Wraps the XLA Pad operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#pad
+.
+
+input: A `Tensor` of type T.
+padding_value: A scalar `Tensor` of type T.
+padding_low: the padding to apply at the start of each input dimension.
+padding_high: the padding to apply at the end of each input dimension.
+padding_interior: the padding to apply between each input element.
+output: A `Tensor` of type T.
+)doc");
+
REGISTER_OP("XlaRecv")
.Output("tensor: dtype")
.Attr("dtype: type")
@@ -98,17 +207,58 @@ tensor_name: A string key that identifies the channel.
shape: The shape of the tensor.
)doc");
+REGISTER_OP("XlaReduce")
+ .Input("input: T")
+ .Input("init_value: T")
+ .Attr("T: numbertype")
+ .Attr("dimensions_to_reduce: list(int)")
+ .Attr("reducer: func")
+ .Output("output: T")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ if (c->RankKnown(c->input(0))) {
+ int rank = c->Rank(c->input(0));
+ std::vector<int64> dimensions_to_reduce;
+ TF_RETURN_IF_ERROR(
+ c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce));
+ std::set<int64> dims_set(dimensions_to_reduce.begin(),
+ dimensions_to_reduce.end());
+ auto dim_in_range = [rank](int64 dim) {
+ return dim >= 0 && dim < rank;
+ };
+ if (rank < dimensions_to_reduce.size() ||
+ dims_set.size() != dimensions_to_reduce.size() ||
+ !absl::c_all_of(dimensions_to_reduce, dim_in_range)) {
+ return errors::InvalidArgument(
+ "Invalid dimensions_to_reduce argument to XlaReduce");
+ }
+ c->set_output(
+ 0, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size()));
+ } else {
+ c->set_output(0, c->input(0));
+ }
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Wraps the XLA Reduce operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#reduce .
+
+input: the input tensor
+init_value: a scalar representing the initial value for the reduction
+reducer: a reducer function to apply
+dimensions_to_reduce: dimension numbers over which to reduce
+)doc");
+
REGISTER_OP("XlaReduceWindow")
.Input("input: T")
.Input("init_value: T")
+ .Input("window_dimensions: Tindices")
+ .Input("window_strides: Tindices")
+ .Input("padding: Tindices")
.Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
.Attr("computation: func")
- .Attr("window_dimensions: list(int)")
- .Attr("window_strides: list(int)")
- .Attr("padding_low: list(int)")
- .Attr("padding_high: list(int)")
.Output("output: T")
- .SetShapeFn(shape_inference::UnknownShape)
+ .SetShapeFn(UnchangedRank)
.Doc(R"doc(
Wraps the XLA ReduceWindow operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .
@@ -118,8 +268,35 @@ init_value: a scalar representing the initial value for the reduction
computation: a reducer function to apply
window_dimensions: the shape of the window
window_strides: the inter-window strides
-padding_low: the padding to apply at the start of each input dimensions
-padding_high: the padding to apply at the end of each input dimension.
+padding: the padding to apply at the start and end of each input dimension
+)doc");
+
+REGISTER_OP("XlaSelectAndScatter")
+ .Input("operand: T")
+ .Input("window_dimensions: Tindices")
+ .Input("window_strides: Tindices")
+ .Input("padding: Tindices")
+ .Input("source: T")
+ .Input("init_value: T")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("select: func")
+ .Attr("scatter: func")
+ .Output("output: T")
+ .SetShapeFn(UnchangedRank)
+ .Doc(R"doc(
+Wraps the XLA SelectAndScatter operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter
+.
+
+operand: the input tensor
+window_dimensions: the shape of the window
+window_strides: the inter-window strides
+padding: the padding to apply at the start and end of each input dimension
+source: a tensor of values to scatter
+init_value: a scalar representing the initial value for the output tensor
+select: a selection function to apply
+scatter: a scatter function to apply
)doc");
REGISTER_OP("XlaSend")
@@ -179,4 +356,5 @@ body: A function that takes a list of tensors and returns another
list of tensors. Both lists have the same types as specified by T.
)doc");
+} // namespace
} // namespace tensorflow
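The XlaReduce shape function above admits a reduction only when the dimension
list has no duplicates and every entry lies in [0, rank); the output rank is
then rank - len(dimensions_to_reduce). A minimal Python sketch of that check,
written here for illustration and not part of the TensorFlow API:

    def xla_reduce_output_rank(input_rank, dimensions_to_reduce):
        """Mirrors the shape-function validation above: output rank or error."""
        dims = list(dimensions_to_reduce)
        if (input_rank < len(dims) or len(set(dims)) != len(dims)
                or not all(0 <= d < input_rank for d in dims)):
            raise ValueError("Invalid dimensions_to_reduce argument to XlaReduce")
        return input_rank - len(dims)

    assert xla_reduce_output_rank(4, [1, 3]) == 2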
diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD
index 42b6292f79..69ca394360 100644
--- a/tensorflow/compiler/tf2xla/python/BUILD
+++ b/tensorflow/compiler/tf2xla/python/BUILD
@@ -28,5 +28,6 @@ py_library(
srcs = ["xla.py"],
deps = [
"//tensorflow/compiler/tf2xla/ops:gen_xla_ops",
+ "//tensorflow/compiler/xla:xla_data_proto_py",
],
)
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index 2fc47dffb8..3626de375e 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -15,11 +15,12 @@
"""Experimental library that exposes XLA operations directly in TensorFlow.
It is sometimes useful to be able to build HLO programs directly from
-TensorFlow. This file provides Tensorflow operators that map as closely as
-possible to HLO operators.
+TensorFlow. This file provides TensorFlow operators that mirror the semantics of
+HLO operators as closely as possible.
-There is no promise of backward or forward compatibility for operators defined
-in this module.
+Note: There is no promise of backward or forward compatibility for operators
+defined in this module. This is primarily because the underlying HLO operators
+do not promise backward or forward compatibility.
"""
from __future__ import absolute_import
@@ -27,11 +28,298 @@ from __future__ import division
from __future__ import print_function
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import bitwise_ops
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+
+# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing
+# ops include:
+# infeed/outfeed (available via tf.contrib.tpu)
+# collectives, e.g., cross-replica-sum (available via tf.contrib.tpu)
+# conditional
+# gather/scatter
+# collapse
+
+# This file reuses builtin names (following XLA's names, so we can call things
+# like xla.max), so we capture the builtin versions here.
+# pylint: disable=redefined-builtin
+_max = max
+_min = min
+_slice = slice # pylint: disable=invalid-name
+
+constant = constant_op.constant
+
+# Unary operators.
+
+# For most arithmetic operators there is a TensorFlow operator
+# that exactly corresponds to each XLA operator. Rather than defining
+# XLA-specific variants, we reuse the corresponding TensorFlow operator.
+# TODO(phawkins): It would be even better to have TensorFlow operators that 1:1
+# wrap every HLO operator, because that would allow us to be confident that the
+# semantics match.
+
+
+def _unary_op(fn):
+ """Wrapper that restricts `fn` to have the correct signature."""
+
+ def unary_op_wrapper(x, name=None):
+ return fn(x, name=name)
+
+ return unary_op_wrapper
+
+
+abs = _unary_op(math_ops.abs)
+# TODO(phawkins): implement clz.
+conj = _unary_op(math_ops.conj)
+cos = _unary_op(math_ops.cos)
+ceil = _unary_op(math_ops.ceil)
+digamma = _unary_op(math_ops.digamma)
+erf = _unary_op(math_ops.erf)
+erfc = _unary_op(math_ops.erfc)
+# TODO(phawkins): implement erfinv
+exp = _unary_op(math_ops.exp)
+expm1 = _unary_op(math_ops.expm1)
+floor = _unary_op(math_ops.floor)
+imag = _unary_op(math_ops.imag)
+is_finite = _unary_op(math_ops.is_finite)
+lgamma = _unary_op(math_ops.lgamma)
+log = _unary_op(math_ops.log)
+log1p = _unary_op(math_ops.log1p)
+logical_not = _unary_op(math_ops.logical_not)
+neg = _unary_op(math_ops.neg)
+real = _unary_op(math_ops.real)
+# TODO(phawkins): unlike xla::Round, this rounds numbers halfway between two
+# integers to even instead of away from zero.
+round = _unary_op(math_ops.round)
+sin = _unary_op(math_ops.sin)
+sign = _unary_op(math_ops.sign)
+tanh = _unary_op(math_ops.tanh)
+
+# Binary operators
+
+# The main difference between TensorFlow and XLA binary ops is the broadcasting
+# semantics. TensorFlow uses Numpy-style broadcasting semantics, whereas XLA
+# requires an explicit specification of which dimensions to broadcast if the
+# arguments have different ranks.
+
+
+def _broadcasting_binary_op(fn):
+ """Wraps a binary Tensorflow operator and performs XLA-style broadcasting."""
+
+ def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None):
+ """Inner wrapper function."""
+ broadcast_dims = broadcast_dims or []
+ broadcast_dims = ops.convert_to_tensor(broadcast_dims, dtypes.int64)
+ # Rather than relying on having static shape information in the TensorFlow
+ # graph, we use an XlaBroadcastHelper op that can compute the correct shapes
+ # at JIT compilation time.
+ x, y = gen_xla_ops.xla_broadcast_helper(x, y, broadcast_dims)
+ return fn(x, y, name=name)
+
+ return broadcasting_binary_op_wrapper
+
+
+# Map from TF signed types to TF unsigned types.
+_SIGNED_TO_UNSIGNED_TABLE = {
+ dtypes.int8: dtypes.uint8,
+ dtypes.int16: dtypes.uint16,
+ dtypes.int32: dtypes.uint32,
+ dtypes.int64: dtypes.uint64,
+}
+
+# Map from TF unsigned types to TF signed types.
+_UNSIGNED_TO_SIGNED_TABLE = {
+ dtypes.uint8: dtypes.int8,
+ dtypes.uint16: dtypes.int16,
+ dtypes.uint32: dtypes.int32,
+ dtypes.uint64: dtypes.int64,
+}
+
+
+def _shift_right_logical_helper(x, y, name=None):
+ """Performs an integer right logical shift irrespective of input type."""
+ assert y.dtype == x.dtype
+ dtype = x.dtype
+ signed = dtype in _SIGNED_TO_UNSIGNED_TABLE
+ if signed:
+ unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[dtype]
+ x = math_ops.cast(x, unsigned_dtype)
+ y = math_ops.cast(y, unsigned_dtype)
+ output = bitwise_ops.right_shift(x, y, name=name)
+ if signed:
+ output = math_ops.cast(output, dtype)
+ return output
+
+
+def _shift_right_arithmetic_helper(x, y, name=None):
+ """Performs an integer right arithmetic shift irrespective of input type."""
+ assert y.dtype == x.dtype
+ dtype = x.dtype
+ unsigned = dtype in _UNSIGNED_TO_SIGNED_TABLE
+ if unsigned:
+ signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[dtype]
+ x = math_ops.cast(x, signed_dtype)
+ y = math_ops.cast(y, signed_dtype)
+ output = bitwise_ops.right_shift(x, y, name=name)
+ if unsigned:
+ output = math_ops.cast(output, dtype)
+ return output
+
+
+add = _broadcasting_binary_op(math_ops.add)
+sub = _broadcasting_binary_op(math_ops.sub)
+mul = _broadcasting_binary_op(math_ops.mul)
+div = _broadcasting_binary_op(math_ops.div)
+rem = _broadcasting_binary_op(gen_math_ops.mod)
+max = _broadcasting_binary_op(math_ops.maximum)
+min = _broadcasting_binary_op(math_ops.minimum)
+atan2 = _broadcasting_binary_op(math_ops.atan2)
+complex = _broadcasting_binary_op(math_ops.complex)
+logical_and = _broadcasting_binary_op(math_ops.logical_and)
+logical_or = _broadcasting_binary_op(math_ops.logical_or)
+logical_xor = _broadcasting_binary_op(math_ops.logical_xor)
+eq = _broadcasting_binary_op(math_ops.equal)
+ne = _broadcasting_binary_op(math_ops.not_equal)
+ge = _broadcasting_binary_op(math_ops.greater_equal)
+gt = _broadcasting_binary_op(math_ops.greater)
+le = _broadcasting_binary_op(math_ops.less_equal)
+lt = _broadcasting_binary_op(math_ops.less)
+pow = _broadcasting_binary_op(math_ops.pow)
+shift_left = _broadcasting_binary_op(bitwise_ops.left_shift)
+shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper)
+shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper)
+
+
+def _binary_op(fn):
+ """Wrapper that restricts `fn` to have the correct signature."""
+
+ def binary_op_wrapper(x, y, name=None):
+ return fn(x, y, name=name)
+
+ return binary_op_wrapper
+
+
+transpose = _binary_op(array_ops.transpose)
+rev = _binary_op(array_ops.reverse)
+
+bitcast_convert_type = array_ops.bitcast
+
+
+def broadcast(x, dims, name=None):
+ x = ops.convert_to_tensor(x)
+ shape = array_ops.concat(
+ [constant_op.constant(dims),
+ array_ops.shape(x)], axis=0)
+ return array_ops.broadcast_to(x, shape, name=name)
+
+
+def clamp(a, x, b, name=None):
+ return min(max(a, x, name=name), b, name=name)
+
+
+concatenate = array_ops.concat
+
+
+def conv(lhs,
+ rhs,
+ window_strides,
+ padding,
+ lhs_dilation,
+ rhs_dilation,
+ dimension_numbers,
+ feature_group_count=1,
+ precision_config=None,
+ name=None):
+ """Wraps the XLA ConvGeneralDilated operator.
+
+ ConvGeneralDilated is the most general form of XLA convolution and is
+ documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
+
+ Args:
+ lhs: the input tensor
+ rhs: the kernel tensor
+ window_strides: the inter-window strides
+ padding: the padding to apply at the start and end of each input dimensions
+ lhs_dilation: dilation to apply between input elements
+ rhs_dilation: dilation to apply between kernel elements
+ dimension_numbers: a `ConvolutionDimensionNumbers` proto.
+ feature_group_count: number of feature groups for grouped convolution.
+ precision_config: a `PrecisionConfigProto` proto.
+ name: an optional name for the operator
+
+ Returns:
+ A tensor representing the output of the convolution.
+ """
+ precision_config_proto = ""
+ if precision_config:
+ precision_config_proto = precision_config.SerializeToString()
+ return gen_xla_ops.xla_conv(
+ lhs,
+ rhs,
+ window_strides=window_strides,
+ padding=padding,
+ lhs_dilation=lhs_dilation,
+ rhs_dilation=rhs_dilation,
+ feature_group_count=feature_group_count,
+ dimension_numbers=dimension_numbers.SerializeToString(),
+ precision_config=precision_config_proto,
+ name=name)
+
+
+convert_element_type = math_ops.cast
+
+
+def dot(lhs, rhs, name=None):
+ return math_ops.tensordot(lhs, rhs, axes=1, name=name)
+
+
+def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None):
+ precision_config_proto = ""
+ if precision_config:
+ precision_config_proto = precision_config.SerializeToString()
+ return gen_xla_ops.xla_dot(
+ lhs,
+ rhs,
+ dimension_numbers=dimension_numbers.SerializeToString(),
+ precision_config=precision_config_proto,
+ name=name)
+
+
+def dynamic_slice(x, starts, sizes, name=None):
+ # TODO(phawkins): the Slice operator lowers to DynamicSlice if `starts` is not
+ # a compile-time constant. This doesn't exactly mimic the semantics of dynamic
+ # slice if the slice is out of bounds.
+ return array_ops.slice(x, starts, sizes, name=name)
-# TODO(phawkins): provide wrappers for all XLA operators.
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
+# TODO(phawkins): generalize tf.pad to support interior padding, and then remove
+# the XLA-specific pad operator.
+pad = gen_xla_ops.xla_pad
+
+
+def random_normal(mu, sigma, dims, name=None):
+ mu = ops.convert_to_tensor(mu)
+ return random_ops.random_normal(
+ dims, mean=mu, stddev=sigma, dtype=mu.dtype, name=name)
+
+
+def random_uniform(minval, maxval, dims, name=None):
+ minval = ops.convert_to_tensor(minval)
+ return random_ops.random_uniform(
+ dims, minval, maxval, dtype=minval.dtype, name=name)
+
+
+recv = gen_xla_ops.xla_recv
+reduce = gen_xla_ops.xla_reduce
+
def reduce_window(operand,
init,
@@ -61,22 +349,38 @@ def reduce_window(operand,
"""
window_strides = window_strides or [1] * len(window_dimensions)
padding = padding or [(0, 0)] * len(window_dimensions)
- padding_low = [x for (x, _) in padding]
- padding_high = [y for (_, y) in padding]
return gen_xla_ops.xla_reduce_window(
- operand,
- init,
- reducer,
- window_dimensions,
- window_strides,
- padding_low,
- padding_high,
+ input=operand,
+ init_value=init,
+ window_dimensions=window_dimensions,
+ window_strides=window_strides,
+ padding=padding,
+ computation=reducer,
name=name)
-recv = gen_xla_ops.xla_recv
+def reshape(x, new_sizes, dimensions=None, name=None):
+ if dimensions is not None:
+ x = array_ops.transpose(x, dimensions)
+ x = array_ops.reshape(x, new_sizes, name=name)
+ return x
+
+
+def select(condition, x, y, name=None):
+ return array_ops.where(condition, x, y, name)
+
+
+select_and_scatter = gen_xla_ops.xla_select_and_scatter
send = gen_xla_ops.xla_send
-sort = gen_xla_ops.xla_sort
+def slice(x, start_dims, limit_dims, strides):
+ spec = [
+ _slice(start, limit, stride)
+ for (start, limit, stride) in zip(start_dims, limit_dims, strides)
+ ]
+ return x[tuple(spec)]
+
+
+sort = gen_xla_ops.xla_sort
while_loop = gen_xla_ops.xla_while
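
The two shift helpers above obtain a logical right shift on signed types (and
an arithmetic one on unsigned types) by casting to the opposite signedness,
shifting, and casting back; for same-width integers the cast is a modular
conversion, so the bit pattern is preserved. A self-contained numpy sketch of
the signed case, a stand-in for the TF ops rather than this module's actual
code path:

    import numpy as np

    def shift_right_logical_i32(x, y):
        # For same-width integers, a signed-to-unsigned cast is a modular
        # conversion, which numpy's view() reproduces bit-for-bit.
        out = x.view(np.uint32) >> y.view(np.uint32)
        return out.view(np.int32)

    x = np.array([-8], dtype=np.int32)
    print(shift_right_logical_i32(x, np.array([1], dtype=np.int32)))
    # [2147483644], i.e. 0x7ffffffc: the sign bit is shifted out, not copied.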
diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc
new file mode 100644
index 0000000000..20f2ce2919
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc
@@ -0,0 +1,130 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
+#include "absl/algorithm/container.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+
+namespace tensorflow {
+/*static*/ absl::string_view XlaResourceOpInfo::XlaResourceOpKindToString(
+ XlaResourceOpKind op_kind) {
+ switch (op_kind) {
+ case XlaResourceOpKind::kRead:
+ return "Read";
+ case XlaResourceOpKind::kWrite:
+ return "Write";
+ case XlaResourceOpKind::kReadWrite:
+ return "Modify";
+ }
+}
+
+static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>*
+CreateResourceOpInfoMap() {
+ auto* result = new gtl::FlatMap<absl::string_view, XlaResourceOpInfo>;
+
+ auto add = [&](absl::string_view op, XlaResourceOpKind op_kind,
+ XlaResourceKind resource_kind) {
+ auto insert_result =
+ result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)});
+ CHECK(insert_result.second);
+ };
+
+ auto kRead = XlaResourceOpKind::kRead;
+ auto kWrite = XlaResourceOpKind::kWrite;
+ auto kReadWrite = XlaResourceOpKind::kReadWrite;
+
+ auto kVariable = XlaResourceKind::kVariable;
+ auto kStack = XlaResourceKind::kStack;
+ auto kTensorArray = XlaResourceKind::kTensorArray;
+
+ // clang-format off
+ add("AssignAddVariableOp" , kReadWrite, kVariable);
+ add("AssignSubVariableOp" , kReadWrite, kVariable);
+ add("AssignVariableOp" , kWrite, kVariable);
+ add("ReadVariableOp" , kRead, kVariable);
+ add("ResourceApplyAdaMax" , kReadWrite, kVariable);
+ add("ResourceApplyAdadelta" , kReadWrite, kVariable);
+ add("ResourceApplyAdagrad" , kReadWrite, kVariable);
+ add("ResourceApplyAdagradDA" , kReadWrite, kVariable);
+ add("ResourceApplyAdam" , kReadWrite, kVariable);
+ add("ResourceApplyAddSign" , kReadWrite, kVariable);
+ add("ResourceApplyCenteredRMSProp" , kReadWrite, kVariable);
+ add("ResourceApplyFtrl" , kReadWrite, kVariable);
+ add("ResourceApplyFtrlV2" , kReadWrite, kVariable);
+ add("ResourceApplyGradientDescent" , kReadWrite, kVariable);
+ add("ResourceApplyMomentum" , kReadWrite, kVariable);
+ add("ResourceApplyPowerSign" , kReadWrite, kVariable);
+ add("ResourceApplyProximalAdagrad" , kReadWrite, kVariable);
+ add("ResourceApplyProximalGradientDescent" , kReadWrite, kVariable);
+ add("ResourceApplyRMSProp" , kReadWrite, kVariable);
+ add("ResourceGather" , kRead, kVariable);
+ add("ResourceScatterAdd" , kReadWrite, kVariable);
+ add("ResourceScatterDiv" , kReadWrite, kVariable);
+ add("ResourceScatterMax" , kReadWrite, kVariable);
+ add("ResourceScatterMin" , kReadWrite, kVariable);
+ add("ResourceScatterMul" , kReadWrite, kVariable);
+ add("ResourceScatterNdAdd" , kReadWrite, kVariable);
+ add("ResourceScatterNdUpdate" , kReadWrite, kVariable);
+ add("ResourceScatterSub" , kReadWrite, kVariable);
+ add("ResourceScatterUpdate" , kReadWrite, kVariable);
+ add("ResourceStridedSliceAssign" , kReadWrite, kVariable);
+ add("VarIsInitializedOp" , kRead, kVariable);
+ add("VariableShape" , kRead, kVariable);
+
+ add("StackV2" , kWrite, kStack);
+ add("StackCloseV2" , kRead, kStack);
+ add("StackPopV2" , kReadWrite, kStack);
+ add("StackPushV2" , kReadWrite, kStack);
+
+ add("TensorArrayV3" , kWrite, kTensorArray);
+ add("TensorArrayConcatV3" , kRead, kTensorArray);
+ add("TensorArrayGatherV3" , kRead, kTensorArray);
+ add("TensorArrayScatterV3" , kWrite, kTensorArray);
+ add("TensorArrayGradV3" , kRead, kTensorArray);
+ add("TensorArrayCloseV3" , kRead, kTensorArray);
+ add("TensorArrayReadV3" , kRead, kTensorArray);
+ add("TensorArraySizeV3" , kRead, kTensorArray);
+ add("TensorArraySplitV3" , kWrite, kTensorArray);
+ add("TensorArrayWriteV3" , kWrite, kTensorArray);
+ // clang-format on
+
+ return result;
+}
+
+static const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>&
+GetStaticResourceOpInfoMap() {
+ static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>* op_info_map =
+ CreateResourceOpInfoMap();
+ return *op_info_map;
+}
+
+const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) {
+ const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>& op_infos =
+ GetStaticResourceOpInfoMap();
+ auto it = op_infos.find(op);
+ return it == op_infos.end() ? nullptr : &it->second;
+}
+
+namespace resource_op_table_internal {
+std::vector<absl::string_view> GetKnownResourceOps() {
+ std::vector<absl::string_view> result;
+ for (const auto& p : GetStaticResourceOpInfoMap()) {
+ result.push_back(p.first);
+ }
+ absl::c_sort(result);
+ return result;
+}
+} // namespace resource_op_table_internal
+} // namespace tensorflow
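
The table above reduces to a name-to-(access kind, resource kind) map with a
null-on-miss lookup. A toy Python analogue of that contract, with entries
abbreviated and names chosen purely for illustration:

    _RESOURCE_OP_TABLE = {
        "ReadVariableOp": ("Read", "Variable"),
        "AssignVariableOp": ("Write", "Variable"),
        "StackPopV2": ("Modify", "Stack"),
    }

    def get_resource_op_info(op_name):
        # None means op_name is not a resource op, or not supported by
        # tf2xla: the same contract as GetResourceOpInfoForOp's nullptr.
        return _RESOURCE_OP_TABLE.get(op_name)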
diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.h b/tensorflow/compiler/tf2xla/resource_operation_table.h
new file mode 100644
index 0000000000..61c7a56ff0
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/resource_operation_table.h
@@ -0,0 +1,71 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_
+#define TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_
+
+#include <string>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "tensorflow/core/platform/logging.h"
+
+// Exposes information about the resource operations supported by tf2xla in a
+// structured form.
+
+namespace tensorflow {
+enum class XlaResourceOpKind {
+ kRead, // Only reads from resources.
+ kWrite, // Only writes to resources.
+ kReadWrite // Reads from and writes to resources.
+};
+
+enum class XlaResourceKind {
+ kVariable, // Operates on resource variables.
+ kStack, // Operates on stacks.
+ kTensorArray // Operates on tensor arrays.
+};
+
+class XlaResourceOpInfo {
+ public:
+ explicit XlaResourceOpInfo(XlaResourceOpKind op_kind,
+ XlaResourceKind resource_kind)
+ : op_kind_(op_kind), resource_kind_(resource_kind) {}
+
+ XlaResourceOpKind kind() const { return op_kind_; }
+ XlaResourceKind resource_kind() const { return resource_kind_; }
+
+ static absl::string_view XlaResourceOpKindToString(XlaResourceOpKind op_kind);
+
+ private:
+ XlaResourceOpKind op_kind_;
+ XlaResourceKind resource_kind_;
+};
+
+// Returns an XlaResourceOpInfo describing `op` if it is a resource operation
+// supported by tf2xla, otherwise returns null (i.e. if this returns null then
+// `op` is either not a resource operation or is unsupported by XLA).
+const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op);
+
+namespace resource_op_table_internal {
+// NB! Implementation detail exposed for unit testing, do not use.
+//
+// Returns the set of resource operations known by this module.
+std::vector<absl::string_view> GetKnownResourceOps();
+} // namespace resource_op_table_internal
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_
diff --git a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc
new file mode 100644
index 0000000000..a85ef040a7
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc
@@ -0,0 +1,66 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+bool IsResourceArgDef(const OpDef::ArgDef& arg_def) {
+ return arg_def.type() == DT_RESOURCE;
+}
+
+bool HasResourceInputOrOutput(const OpDef& op_def) {
+ return absl::c_any_of(op_def.input_arg(), IsResourceArgDef) ||
+ absl::c_any_of(op_def.output_arg(), IsResourceArgDef);
+}
+
+TEST(ResourceOperationTableTest, HaveAllResourceOps) {
+ gtl::FlatMap<string, bool> known_resource_ops;
+ for (absl::string_view known_resource_op :
+ resource_op_table_internal::GetKnownResourceOps()) {
+ ASSERT_TRUE(
+ known_resource_ops.insert({string(known_resource_op), false}).second);
+ }
+
+ std::vector<string> xla_op_names = XlaOpRegistry::GetAllRegisteredOps();
+ for (const string& xla_op_name : xla_op_names) {
+ const OpDef* op_def;
+ TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef(xla_op_name, &op_def));
+ if (HasResourceInputOrOutput(*op_def)) {
+ EXPECT_EQ(known_resource_ops.count(xla_op_name), 1)
+ << "Unknown resource op " << xla_op_name;
+ known_resource_ops[xla_op_name] = true;
+ }
+ }
+
+ std::vector<string> unnecessary_resource_ops;
+ for (const auto& pair : known_resource_ops) {
+ if (!pair.second) {
+ unnecessary_resource_ops.push_back(pair.first);
+ }
+ }
+
+ EXPECT_TRUE(unnecessary_resource_ops.empty())
+ << "Stale resource ops:\n"
+ << absl::StrJoin(unnecessary_resource_ops, "\n");
+}
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc
index 5759c72af3..8aae498be1 100644
--- a/tensorflow/compiler/tf2xla/sharding_util.cc
+++ b/tensorflow/compiler/tf2xla/sharding_util.cc
@@ -14,10 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/sharding_util.h"
+#include "absl/strings/match.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@@ -27,10 +26,10 @@ const char kShardingAttribute[] = "_XlaSharding";
} // namespace
namespace {
-xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
-GetShardingFromNodeDef(const NodeDef& node_def) {
+xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
+ const NodeDef& node_def) {
if (!HasNodeAttr(node_def, kShardingAttribute)) {
- return tensorflow::gtl::optional<xla::OpSharding>();
+ return absl::optional<xla::OpSharding>();
}
string value;
xla::OpSharding sharding;
@@ -40,7 +39,7 @@ GetShardingFromNodeDef(const NodeDef& node_def) {
"Experimental _XlaSharding attribute was not a valid encoded "
"xla::OpSharding proto.");
}
- return tensorflow::gtl::optional<xla::OpSharding>(sharding);
+ return absl::optional<xla::OpSharding>(sharding);
}
Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
@@ -50,12 +49,11 @@ Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
}
} // namespace
-xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
-ParseShardingFromDevice(
+xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
const string& device_name, int num_cores_per_replica,
- tensorflow::gtl::optional<xla::OpSharding> explicit_sharding) {
+ absl::optional<xla::OpSharding> explicit_sharding) {
if (device_name.empty()) {
- return tensorflow::gtl::optional<xla::OpSharding>();
+ return absl::optional<xla::OpSharding>();
}
DeviceNameUtils::ParsedName parsed_device;
if (!DeviceNameUtils::ParseFullName(device_name, &parsed_device)) {
@@ -66,34 +64,34 @@ ParseShardingFromDevice(
if (explicit_sharding.has_value()) {
return explicit_sharding;
} else if (!parsed_device.has_type || !parsed_device.has_id ||
- !str_util::StrContains(parsed_device.type,
- kDeviceSuffixReplicatedCore)) {
- return tensorflow::gtl::optional<xla::OpSharding>();
+ !absl::StrContains(parsed_device.type,
+ kDeviceSuffixReplicatedCore)) {
+ return absl::optional<xla::OpSharding>();
} else {
const int core = parsed_device.id;
if (core < 0 || core >= num_cores_per_replica) {
return CoreOutOfRangeError(core, num_cores_per_replica);
}
- return tensorflow::gtl::optional<xla::OpSharding>(
+ return absl::optional<xla::OpSharding>(
xla::sharding_builder::AssignDevice(core));
}
}
-xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
-ParseShardingFromDevice(const NodeDef& node_def, int num_cores_per_replica) {
+xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
+ const NodeDef& node_def, int num_cores_per_replica) {
const string& device_name = node_def.device();
- TF_ASSIGN_OR_RETURN(tensorflow::gtl::optional<xla::OpSharding> sharding,
+ TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
GetShardingFromNodeDef(node_def));
return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding);
}
-xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
-ParseShardingFromDevice(const Node& node, int num_cores_per_replica) {
+xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
+ const Node& node, int num_cores_per_replica) {
string device_name = node.assigned_device_name();
if (device_name.empty()) {
device_name = node.requested_device();
}
- TF_ASSIGN_OR_RETURN(tensorflow::gtl::optional<xla::OpSharding> sharding,
+ TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
GetShardingFromNodeDef(node.def()));
return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding);
}
diff --git a/tensorflow/compiler/tf2xla/sharding_util.h b/tensorflow/compiler/tf2xla/sharding_util.h
index b1c817bdcc..ab67d4f154 100644
--- a/tensorflow/compiler/tf2xla/sharding_util.h
+++ b/tensorflow/compiler/tf2xla/sharding_util.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_TF2XLA_TPU_UTIL_H_
-#define TENSORFLOW_COMPILER_TF2XLA_TPU_UTIL_H_
+#ifndef TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_
#include <string>
@@ -33,19 +33,18 @@ namespace tensorflow {
// - explicit_sharding if explicit_sharding.has_value()
// - a non-value if there is no assigned core or
// - a sharding set as per xla::sharding_builder::AssignDevice.
-xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
-ParseShardingFromDevice(const string& device_name, int num_cores_per_replica,
- tensorflow::gtl::optional<xla::OpSharding>
- explicit_sharding = tensorflow::gtl::nullopt);
+xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
+ const string& device_name, int num_cores_per_replica,
+ absl::optional<xla::OpSharding> explicit_sharding = absl::nullopt);
-xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
-ParseShardingFromDevice(const Node& node, int num_cores_per_replica);
+xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
+ const Node& node, int num_cores_per_replica);
-xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
-ParseShardingFromDevice(const NodeDef& node_def, int num_cores_per_replica);
+xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
+ const NodeDef& node_def, int num_cores_per_replica);
void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst);
} // namespace tensorflow
-#endif // TENSORFLOW_COMPILER_TF2XLA_TPU_UTIL_H_
+#endif // TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_
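
The decision order documented above is: an empty device name yields no
sharding, an explicit sharding wins, and otherwise a REPLICATED_CORE device
suffix maps to AssignDevice(core) after a range check. A hedged Python sketch
of just that control flow; _parse_core and the returned tuple are illustrative
stand-ins, not TensorFlow APIs:

    import re

    def _parse_core(device_name):
        # Hypothetical stand-in for DeviceNameUtils: pull the id off a
        # "..._REPLICATED_CORE:<n>" suffix, else None.
        m = re.search(r"REPLICATED_CORE:(\d+)$", device_name)
        return int(m.group(1)) if m else None

    def parse_sharding_from_device(device_name, num_cores, explicit=None):
        if not device_name:
            return None
        if explicit is not None:
            return explicit
        core = _parse_core(device_name)
        if core is None:
            return None
        if not 0 <= core < num_cores:
            raise ValueError("core %d out of range [0, %d)" % (core, num_cores))
        return ("maximal", core)  # stand-in for sharding_builder::AssignDevice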
diff --git a/tensorflow/compiler/tf2xla/sharding_util_test.cc b/tensorflow/compiler/tf2xla/sharding_util_test.cc
index bff5978237..dcb7e212b7 100644
--- a/tensorflow/compiler/tf2xla/sharding_util_test.cc
+++ b/tensorflow/compiler/tf2xla/sharding_util_test.cc
@@ -23,7 +23,7 @@ TEST(CoreUtilTest, ParseShardingFromDevice) {
Graph graph(OpRegistry::Global());
auto core_from_sharding =
- [](tensorflow::gtl::optional<xla::OpSharding> sharding) -> int64 {
+ [](absl::optional<xla::OpSharding> sharding) -> int64 {
if (sharding.has_value() &&
sharding.value().type() ==
xla::OpSharding::Type::OpSharding_Type_MAXIMAL) {
diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc
new file mode 100644
index 0000000000..6cd7b24592
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/side_effect_util.cc
@@ -0,0 +1,67 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
+
+#include "tensorflow/core/graph/algorithm.h"
+
+namespace tensorflow {
+
+const char kXlaTokenInputNodesAttrName[] = "_xla_token_input_nodes";
+
+const char kXlaTokenArgNodeName[] = "_xla_token_arg_node";
+
+std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g) {
+ std::set<std::string> results;
+ Node* first_side_effecting_node_on_path = nullptr;
+ ReverseDFS(g,
+ [&](Node* n) {
+ std::vector<string> token_input_nodes;
+ if (!GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName,
+ &token_input_nodes)
+ .ok() ||
+ token_input_nodes.empty()) {
+ return;
+ }
+
+ if (first_side_effecting_node_on_path != nullptr) {
+ return;
+ }
+
+ first_side_effecting_node_on_path = n;
+ results.insert(n->name());
+ },
+ [&](Node* n) {
+ if (first_side_effecting_node_on_path == n) {
+ first_side_effecting_node_on_path = nullptr;
+ }
+ },
+ NodeComparatorName());
+ return results;
+}
+
+bool HasSideEffectingNodes(const Graph& g) {
+ for (Node* n : g.nodes()) {
+ std::vector<string> token_input_nodes;
+ if (GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, &token_input_nodes)
+ .ok() &&
+ !token_input_nodes.empty()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace tensorflow
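
CalculateTokenInputsForOutputToken above records only the first token-consuming
node reached on each reverse path; the leave callback clears the marker so
sibling branches are scanned independently. A simplified Python sketch over an
acyclic predecessor map (unlike the real ReverseDFS it keeps no visited set, so
it may revisit shared nodes):

    def token_inputs_for_output(preds, outputs, has_token_attr):
        # preds: node -> iterable of predecessor nodes (assumed shape).
        results = set()

        def visit(node, blocked):
            if has_token_attr(node) and not blocked:
                results.add(node)
                blocked = True  # shadows deeper token nodes on this path
            for p in preds.get(node, ()):
                # Passing `blocked` by value plays the role of the
                # leave callback: the marker resets on the way back up.
                visit(p, blocked)

        for out in outputs:
            visit(out, False)
        return results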
diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h
new file mode 100644
index 0000000000..ad07624729
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/side_effect_util.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
+
+#include <vector>
+
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+// Side-effecting nodes will have this attribute set. Its value is the list of
+// node names which this node has side-effect dependencies on.
+//
+// Nodes like HostCompute, SendToHost, RecvFromHost always have this attribute,
+// because they always have side effects.
+// If and While nodes may or may not have this attribute, depending on whether
+// their bodies have side-effecting nodes.
+extern const char kXlaTokenInputNodesAttrName[];
+
+// This node name is used in kXlaTokenInputNodesAttrName attr to signal that a
+// node has side-effect dependency on current graph's token input.
+extern const char kXlaTokenArgNodeName[];
+
+// Calculates side-effect dependencies for the graph's token output.
+// Returns a set of node names representing these dependencies.
+std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g);
+
+// Returns whether a graph contains side-effecting nodes.
+bool HasSideEffectingNodes(const Graph& g);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
diff --git a/tensorflow/compiler/tf2xla/str_util.cc b/tensorflow/compiler/tf2xla/str_util.cc
deleted file mode 100644
index 2b0834fe7b..0000000000
--- a/tensorflow/compiler/tf2xla/str_util.cc
+++ /dev/null
@@ -1,44 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/tf2xla/str_util.h"
-
-#include <string>
-#include <utility>
-#include <vector>
-
-namespace tensorflow {
-namespace str_util {
-
-static void ReplaceAll(string* text, StringPiece from, StringPiece to) {
- size_t pos = 0;
- while ((pos = text->find(from.data(), pos, from.size())) != string::npos) {
- text->replace(pos, from.size(), to.data(), to.size());
- pos += to.size();
- if (from.empty()) {
- pos++; // Match at the beginning of the text and after every byte
- }
- }
-}
-
-void ReplaceAllPairs(string* text,
- const std::vector<std::pair<string, string>>& replace) {
- for (const std::pair<string, string>& from_to : replace) {
- ReplaceAll(text, from_to.first, from_to.second);
- }
-}
-
-} // namespace str_util
-} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/str_util.h b/tensorflow/compiler/tf2xla/str_util.h
deleted file mode 100644
index 51f25009d7..0000000000
--- a/tensorflow/compiler/tf2xla/str_util.h
+++ /dev/null
@@ -1,42 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// String utilities that are esoteric enough that they don't belong in
-// third_party/tensorflow/core/lib/strings/str_util.h, but are still generally
-// useful under xla.
-
-#ifndef TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_
-#define TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_
-
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "tensorflow/core/lib/core/stringpiece.h"
-
-namespace tensorflow {
-namespace str_util {
-
-// Replace all non-overlapping occurrences of the given (from,to) pairs in-place
-// in text. If from is empty, it matches at the beginning of the text and after
-// every byte. Each (from,to) replacement pair is processed in the order it is
-// given.
-void ReplaceAllPairs(string* text,
- const std::vector<std::pair<string, string>>& replace);
-
-} // namespace str_util
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_
diff --git a/tensorflow/compiler/tf2xla/str_util_test.cc b/tensorflow/compiler/tf2xla/str_util_test.cc
deleted file mode 100644
index 8817f6902a..0000000000
--- a/tensorflow/compiler/tf2xla/str_util_test.cc
+++ /dev/null
@@ -1,60 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/tf2xla/str_util.h"
-
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace str_util {
-
-class ReplaceAllPairsTest : public ::testing::Test {
- protected:
- void ExpectReplaceAllPairs(
- string text, const std::vector<std::pair<string, string>>& replace,
- StringPiece want) {
- ReplaceAllPairs(&text, replace);
- EXPECT_EQ(text, want);
- }
-};
-
-TEST_F(ReplaceAllPairsTest, Simple) {
- ExpectReplaceAllPairs("", {}, "");
- ExpectReplaceAllPairs("", {{"", ""}}, "");
- ExpectReplaceAllPairs("", {{"", "X"}}, "X");
- ExpectReplaceAllPairs("", {{"", "XYZ"}}, "XYZ");
- ExpectReplaceAllPairs("", {{"", "XYZ"}, {"", "_"}}, "_X_Y_Z_");
- ExpectReplaceAllPairs("", {{"", "XYZ"}, {"", "_"}, {"_Y_", "a"}}, "_XaZ_");
- ExpectReplaceAllPairs("banana", {}, "banana");
- ExpectReplaceAllPairs("banana", {{"", ""}}, "banana");
- ExpectReplaceAllPairs("banana", {{"", "_"}}, "_b_a_n_a_n_a_");
- ExpectReplaceAllPairs("banana", {{"", "__"}}, "__b__a__n__a__n__a__");
- ExpectReplaceAllPairs("banana", {{"a", "a"}}, "banana");
- ExpectReplaceAllPairs("banana", {{"a", ""}}, "bnn");
- ExpectReplaceAllPairs("banana", {{"a", "X"}}, "bXnXnX");
- ExpectReplaceAllPairs("banana", {{"a", "XX"}}, "bXXnXXnXX");
- ExpectReplaceAllPairs("banana", {{"a", "XX"}, {"XnX", "z"}}, "bXzzX");
- ExpectReplaceAllPairs("a{{foo}}b{{bar}}c{{foo}}",
- {{"{{foo}}", "0"}, {"{{bar}}", "123456789"}},
- "a0b123456789c0");
-}
-
-} // namespace str_util
-} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc
index 3c6c9a91b6..f31bfb45a2 100644
--- a/tensorflow/compiler/tf2xla/test_util.cc
+++ b/tensorflow/compiler/tf2xla/test_util.cc
@@ -40,4 +40,12 @@ Status InstantiateFunctionForTest(const string& name,
return Status::OK();
}
+std::unordered_map<string, Node*> BuildNodeIndex(const Graph& graph) {
+ std::unordered_map<string, Node*> index;
+ for (Node* node : graph.nodes()) {
+ index[node->name()] = node;
+ }
+ return index;
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h
index e6e4ae92ed..350a868568 100644
--- a/tensorflow/compiler/tf2xla/test_util.h
+++ b/tensorflow/compiler/tf2xla/test_util.h
@@ -24,8 +24,10 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
@@ -42,6 +44,20 @@ Status InstantiateFunctionForTest(const string& name,
const FunctionLibraryDefinition& library,
InstantiationResultForTest* result);
+// Builds a map from node name to Node* for `graph`.
+std::unordered_map<string, Node*> BuildNodeIndex(const Graph& graph);
+
} // namespace tensorflow
+// Variant of TF_EXPECT_GRAPH_EQ that also compares internal attributes for
+// equality.
+#define TF_EXPECT_GRAPH_EQ_INTERNAL(expected, actual) \
+ do { \
+ string diff; \
+ EqualGraphDefOptions eq_options; \
+ eq_options.ignore_internal_attrs = false; \
+ EXPECT_TRUE(EqualGraphDef(actual, expected, &diff, eq_options)) \
+ << diff << "\nActual: " << SummarizeGraphDef(actual); \
+ } while (false)
+
#endif // TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_
diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc
index 48568c825b..7dbe3a0b58 100644
--- a/tensorflow/compiler/tf2xla/tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla.cc
@@ -22,6 +22,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
@@ -40,8 +42,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -75,7 +75,7 @@ Status AddArgNodes(Graph* graph, const NodeMap& node_map,
auto node_it = node_map.find(remap_it->second);
if (node_it == node_map.end()) {
// Strip off the aot_feed_#/ prefix.
- StringPiece name(remap_it->second);
+ absl::string_view name(remap_it->second);
const auto index = name.find('/');
if (index > 0) name.remove_prefix(index + 1);
return errors::InvalidArgument(
@@ -89,7 +89,7 @@ Status AddArgNodes(Graph* graph, const NodeMap& node_map,
// explicitly specify or override them.
Node* arg_node = nullptr;
TF_RETURN_IF_ERROR(
- NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp)
+ NodeBuilder(absl::StrCat("_arg_", arg_index), kArgOp)
.Attr("T", BaseType(feed_node->output_type(output_index)))
.Attr("index", arg_index)
.Attr(kFeedIdAttr, TensorIdToString(feed.id()))
@@ -136,7 +136,7 @@ Status AddRetvalNodes(Graph* graph, const NodeMap& node_map,
// Connects fetch_node -> retval_node.
Node* retval_node = nullptr;
TF_RETURN_IF_ERROR(
- NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp)
+ NodeBuilder(absl::StrCat("_retval_", ret_index), kRetvalOp)
.Input(fetch_node, id.output_index())
.Attr("T", BaseType(fetch_node->output_type(id.output_index())))
.Attr("index", ret_index)
@@ -197,8 +197,8 @@ Status RewriteAndPruneGraph(
if (!missing_feeds.empty() || !missing_fetches.empty()) {
return errors::Aborted(
"Post graph-pruning",
- ", missing feeds: ", str_util::Join(missing_feeds, ", "),
- ", missing fetches: ", str_util::Join(missing_fetches, ", "));
+ ", missing feeds: ", absl::StrJoin(missing_feeds, ", "),
+ ", missing fetches: ", absl::StrJoin(missing_fetches, ", "));
}
return Status::OK();
}
@@ -256,7 +256,7 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph, xla::Client* client,
XlaOpRegistry::RegisterCompilationKernels();
for (Node* node : graph->nodes()) {
node->set_assigned_device_name(
- strings::StrCat("/device:", DEVICE_CPU_XLA_JIT));
+ absl::StrCat("/device:", DEVICE_CPU_XLA_JIT));
}
std::vector<XlaCompiler::Argument> xla_args;
TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args));
diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc
index 7aca889a26..567d212b5e 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc
@@ -20,11 +20,11 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"
@@ -54,10 +54,10 @@ void PrintSupportedOps(const string& device, const string& regen_run) {
}
std::sort(types.begin(), types.end());
constraints.push_back("`" + constraint.name() + "={" +
- str_util::Join(types, ",") + "}`");
+ absl::StrJoin(types, ",") + "}`");
}
std::cout << "`" << kdef->op() << "` | "
- << str_util::Join(constraints, "<br>") << std::endl;
+ << absl::StrJoin(constraints, "<br>") << std::endl;
}
std::cout << "\nTo regenerate this table, run:\n\n```shell\n"
@@ -76,7 +76,7 @@ void SupportedOpsMain(int argc, char** argv, const char* regen_run) {
{"device", &device,
"Name of the compilation device for which to print supported ops, "
"one of: " +
- str_util::Join(device_names, ",")},
+ absl::StrJoin(device_names, ",")},
};
string usage = Flags::Usage(argv[0], flag_list);
bool parsed_flags_ok = Flags::Parse(&argc, argv, flag_list);
diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc
index 56f7045a98..ab26d939cc 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_test.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc
@@ -77,8 +77,8 @@ TEST(ConvertGraphDefToXla, Sum) {
// Set up arguments.
auto x_literal = xla::LiteralUtil::CreateR0<int32>(10);
auto y_literal = xla::LiteralUtil::CreateR0<int32>(32);
- auto x_global_or = client->TransferToServer(*x_literal);
- auto y_global_or = client->TransferToServer(*y_literal);
+ auto x_global_or = client->TransferToServer(x_literal);
+ auto y_global_or = client->TransferToServer(y_literal);
TF_EXPECT_OK(x_global_or.status());
TF_EXPECT_OK(y_global_or.status());
std::unique_ptr<xla::GlobalData> x_global =
@@ -90,8 +90,8 @@ TEST(ConvertGraphDefToXla, Sum) {
auto result_or =
client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()});
TF_EXPECT_OK(result_or.status());
- std::unique_ptr<xla::Literal> result = std::move(result_or.ValueOrDie());
- EXPECT_EQ("(s32[]) (\n42\n)", result->ToString());
+ xla::Literal result = std::move(result_or.ValueOrDie());
+ EXPECT_EQ("(s32[]) (\n42\n)", result.ToString());
config.mutable_feed(0)->mutable_id()->set_output_index(
123); /* invalid output_index */
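The updates in this test track xla::Literal's move to value semantics: LiteralUtil::CreateR0/CreateR1 now return a Literal by value rather than a std::unique_ptr<Literal>, and Client::TransferToServer takes a const Literal&. A sketch of the updated pattern, assuming a connected xla::Client* named `client` as in this test:

// Sketch only; assumes a connected xla::Client* `client`.
xla::Literal x = xla::LiteralUtil::CreateR0<int32>(10);  // value, not unique_ptr
auto global_or = client->TransferToServer(x);            // const Literal&, no '*'
TF_EXPECT_OK(global_or.status());
std::unique_ptr<xla::GlobalData> global = std::move(global_or.ValueOrDie());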
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc
index 0e07485d18..211caf8736 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include <set>
#include <unordered_map>
+#include "absl/strings/str_cat.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -32,8 +34,6 @@ limitations under the License.
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/optional.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
@@ -112,8 +112,8 @@ Status AddPlaceholdersForFeeds(
const string name_port = TensorIdToString(feed->id());
PlaceholderInfo& info = placeholder_info[name_port];
info.feed = feed;
- info.placeholder_name = strings::StrCat(
- "aot_feed_", feed->id().output_index(), "/", feed->id().node_name());
+ info.placeholder_name = absl::StrCat("aot_feed_", feed->id().output_index(),
+ "/", feed->id().node_name());
(*feed_remapping)[name_port] = info.placeholder_name;
}
@@ -233,7 +233,7 @@ Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
// Push input nodes of the currently visited node to name_queue.
for (const string& in_edge : map_entry.second->input()) {
auto id = ParseTensorName(in_edge);
- const string node_name = std::string(id.first);
+ const string node_name = string(id.first);
if (feed_tensors.find(std::make_pair(node_name, id.second)) ==
feed_tensors.end()) {
name_queue.push(node_name);
@@ -258,7 +258,7 @@ Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
}
string TensorIdToString(const tf2xla::TensorId& id) {
- return strings::StrCat(id.node_name(), ":", id.output_index());
+ return absl::StrCat(id.node_name(), ":", id.output_index());
}
Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
@@ -268,7 +268,7 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
if (edge->IsControlEdge()) continue;
const Node* possible_match = out_edges ? edge->dst() : edge->src();
TF_ASSIGN_OR_RETURN(
- tensorflow::gtl::optional<xla::OpSharding> sharding,
+ absl::optional<xla::OpSharding> sharding,
ParseShardingFromDevice(
*possible_match,
/*num_cores_per_replica=*/std::numeric_limits<int32>::max()));
@@ -289,7 +289,7 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
return Status::OK();
}
-void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype,
+void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype,
KernelDef* kdef) {
for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) {
if (constraint.name() == name) {
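tensorflow::gtl::optional is an alias for absl::optional, so the sharding lookup above is a pure spelling change. A standalone sketch of the optional-returning style (the helper below is hypothetical, not the real ParseShardingFromDevice signature):

#include <string>

#include "absl/types/optional.h"

// Hypothetical helper: returns a core id if the device string requests one.
absl::optional<int> ParseCoreFromDevice(const std::string& device) {
  if (device.empty()) return absl::nullopt;  // no sharding requested
  return 0;                                  // placeholder core id
}

void Example() {
  absl::optional<int> core = ParseCoreFromDevice("/device:XLA_CPU:0");
  if (core.has_value()) {
    // Use *core to assign the op to a device.
  }
}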
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h
index 33620ef810..a29e764466 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.h
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.h
@@ -53,7 +53,7 @@ string TensorIdToString(const tf2xla::TensorId& id);
Status SetNodeShardingFromNeighbors(Node* n, bool out_edges);
// Add an allowed data type to the AttrConstraint with the given name.
-void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype,
+void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype,
KernelDef* kdef);
// Returns the next random seed to use for seeding xla rng.
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc
index ae51446204..68441b3d47 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc
@@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/function_ops.h"
@@ -24,17 +27,14 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
-void ExpectErrorContains(const Status& status, StringPiece str) {
+void ExpectErrorContains(const Status& status, absl::string_view str) {
EXPECT_NE(Status::OK(), status);
- EXPECT_TRUE(str_util::StrContains(status.error_message(), str))
+ EXPECT_TRUE(absl::StrContains(status.error_message(), str))
<< "expected error: " << status.error_message() << " to contain: " << str;
}
@@ -153,7 +153,7 @@ static tf2xla::Config FetchesConfig(std::vector<string> fetches) {
tf2xla::Config config;
for (const auto& fetch_node_name : fetches) {
auto* fetch = config.add_fetch();
- fetch->set_name(strings::StrCat("fetch_", fetch_node_name));
+ fetch->set_name(absl::StrCat("fetch_", fetch_node_name));
fetch->mutable_id()->set_node_name(fetch_node_name);
}
return config;
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
index e89f473328..7f860500c7 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
@@ -76,12 +76,11 @@ class XlaCompilationAllocator : public Allocator {
XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options,
DeviceType type)
- : LocalDevice(
- options,
- Device::BuildDeviceAttributes(
- strings::StrCat("/device:", type.type(), ":0"), type,
- Bytes(256 << 20), DeviceLocality(),
- strings::StrCat("device: XLA compilation device ", type.type()))),
+ : LocalDevice(options, Device::BuildDeviceAttributes(
+ absl::StrCat("/device:", type.type(), ":0"),
+ type, Bytes(256 << 20), DeviceLocality(),
+ absl::StrCat("device: XLA compilation device ",
+ type.type()))),
allocator_(new XlaCompilationAllocator()) {}
XlaCompilationDevice::~XlaCompilationDevice() {}
@@ -103,7 +102,7 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel,
auto sharding_parse_result = ParseShardingFromDevice(
op_kernel->def(), std::numeric_limits<int>::max());
OP_REQUIRES_OK(context, sharding_parse_result.status());
- tensorflow::gtl::optional<xla::OpSharding> op_sharding =
+ absl::optional<xla::OpSharding> op_sharding =
sharding_parse_result.ValueOrDie();
// If no sharding metadata is found, XLA is free to use whatever device it
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 226c89bcf1..dcb455779d 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -18,11 +18,13 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
@@ -197,14 +199,14 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
// lowest-numbered core that consumes the argument. We choose the
// lowest-numbered core so the assignment is deterministic.
for (Node* n : graph->nodes()) {
- if (StringPiece(n->type_string()) == "_Arg") {
+ if (absl::string_view(n->type_string()) == "_Arg") {
TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true));
}
}
// Do _Retval as a second loop, in case the retval's input is an _Arg (which
// may have gotten a device assignment from the first loop).
for (Node* n : graph->nodes()) {
- if (StringPiece(n->type_string()) == "_Retval") {
+ if (absl::string_view(n->type_string()) == "_Retval") {
TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false));
}
}
@@ -212,8 +214,7 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
if (VLOG_IS_ON(2)) {
VLOG(2) << "XlaCompiler::CompileFunction: "
<< dump_graph::DumpGraphToFile(
- strings::StrCat("xla_compile_function_", function_id),
- *graph);
+ absl::StrCat("xla_compile_function_", function_id), *graph);
}
VLOG(1) << "====================================================";
@@ -291,6 +292,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
"Invalid resource type in XLAShapeForArgument()");
}
}
+ case XlaCompiler::Argument::kToken: {
+ *xla_shape = xla::ShapeUtil::MakeTokenShape();
+ return Status::OK();
+ }
case XlaCompiler::Argument::kInvalid:
return errors::Internal("Invalid argument type in XLAShapeForArgument()");
}
@@ -310,7 +315,7 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
// unique_ptr so we can capture the cleanup status in the end.
xla_context->Ref();
Status status;
- auto step_container = xla::MakeUnique<ScopedStepContainer>(
+ auto step_container = absl::make_unique<ScopedStepContainer>(
step_id, [&status, device](const string& name) {
status = device->resource_manager()->Cleanup(name);
});
@@ -360,6 +365,9 @@ Status BuildComputation(
if (retval.has_constant_value()) {
output.is_constant = true;
output.constant_value = retval.constant_value();
+ } else if (retval.resource() != nullptr) {
+ output.is_constant = false;
+ output.input_index = retval.resource()->arg_num();
} else {
output.is_constant = false;
elems.push_back(retval.handle());
@@ -413,7 +421,7 @@ Status BuildComputation(
// Request that the value be returned on a specific core.
xla::XlaScopedShardingAssignment assign_sharding(
- builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
+ builder, core == -1 ? absl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
xla::XlaOp handle;
@@ -464,8 +472,6 @@ Status XlaCompiler::BuildArguments(
// XLA computation as runtime parameters.
input_mapping->clear();
input_mapping->reserve(args.size());
- std::vector<int> resources;
- resources.reserve(args.size());
// Fills in constant arguments, and computes non-constant argument order.
for (std::vector<XlaCompiler::Argument>::size_type i = 0; i < args.size();
@@ -484,10 +490,12 @@ Status XlaCompiler::BuildArguments(
/*tensor_array_gradients=*/arg.tensor_array_gradients, &resource));
arg_expression.set_resource(resource);
if (arg.initialized) {
- resources.push_back(i);
+ input_mapping->push_back(i);
}
+
break;
- case XlaCompiler::Argument::kParameter: {
+ case XlaCompiler::Argument::kParameter:
+ case XlaCompiler::Argument::kToken: {
input_mapping->push_back(i);
break;
}
@@ -495,14 +503,11 @@ Status XlaCompiler::BuildArguments(
arg_expression.set_constant_value(arg.constant_value);
break;
case XlaCompiler::Argument::kInvalid:
- return errors::Internal("Unreachable case in BuildArguments()");
+ return errors::Internal(
+ "Unreachable case in BuildArguments() while filling constant args");
}
}
- // Append parameters containing variable values after the other runtime
- // parameters.
- input_mapping->insert(input_mapping->end(), resources.begin(),
- resources.end());
if (input_mapping->empty()) {
return Status::OK();
}
@@ -522,7 +527,7 @@ Status XlaCompiler::BuildArguments(
// Use the _Arg nodes in the graph to resolve core assignments.
for (const Node* n : graph.nodes()) {
- if (StringPiece(n->type_string()) != "_Arg") continue;
+ if (absl::string_view(n->type_string()) != "_Arg") continue;
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
TF_RET_CHECK(index >= 0 && index < args.size())
@@ -570,7 +575,7 @@ Status XlaCompiler::BuildArguments(
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
const int core = (*arg_cores)[input_mapping->at(i)];
xla::XlaScopedShardingAssignment assign_sharding(
- builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
+ builder, core == -1 ? absl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
arg_handles[i] = xla::GetTupleElement(tuple, i);
}
@@ -578,10 +583,10 @@ Status XlaCompiler::BuildArguments(
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
const int core = (*arg_cores)[input_mapping->at(i)];
xla::XlaScopedShardingAssignment assign_sharding(
- builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
+ builder, core == -1 ? absl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i],
- strings::StrCat("arg", i));
+ absl::StrCat("arg", i));
}
}
@@ -617,9 +622,14 @@ Status XlaCompiler::BuildArguments(
arg_expression.set_handle(arg_handles[i]);
}
break;
+ case XlaCompiler::Argument::kToken: {
+ arg_expression.set_handle(arg_handles[i]);
+ break;
+ }
case XlaCompiler::Argument::kConstant:
case XlaCompiler::Argument::kInvalid:
- return errors::Internal("Unreachable case in BuildArguments()");
+ return errors::Internal(
+ "Unreachable case in BuildArguments() while filling handles");
}
}
@@ -643,7 +653,7 @@ Status XlaCompiler::CompileSingleOp(
// dependency edge to the _SOURCE node.
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
Node* node;
- string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_arg");
+ string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_arg");
Status status = NodeBuilder(name, "_Arg")
.ControlInput(graph->source_node())
.Attr("T", ctx->input_dtype(i))
@@ -656,7 +666,7 @@ Status XlaCompiler::CompileSingleOp(
// Similarly with return values, create dummy _Retval nodes fed by `node`.
for (int64 i = 0; i < ctx->num_outputs(); ++i) {
Node* node;
- string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_retval");
+ string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_retval");
Status status = NodeBuilder(name, "_Retval")
.Input(main_node, i)
.Attr("T", ctx->expected_output_dtype(i))
@@ -692,7 +702,7 @@ Status ValidateGraph(const Graph* graph,
const DeviceType& device_type, const string& name) {
auto maybe_error = [&](const Node* node, const Status& s) -> Status {
if (!s.ok()) {
- return errors::InvalidArgument(strings::StrCat(
+ return errors::InvalidArgument(absl::StrCat(
"Detected unsupported operations when trying to compile graph ", name,
" on ", device_type.type_string(), ": ", node->def().op(), " (",
s.error_message(), ")", FormatNodeForError(*node)));
@@ -733,7 +743,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
if (VLOG_IS_ON(2)) {
VLOG(2) << "XlaCompiler::CompileGraph: "
<< dump_graph::DumpGraphToFile(
- strings::StrCat("xla_compile_graph_", name), *graph);
+ absl::StrCat("xla_compile_graph_", name), *graph);
}
// Report the error here if initialization failed.
@@ -757,23 +767,71 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
&options_.shape_representation_fn);
core::ScopedUnref context_unref(context);
+ std::vector<XlaCompiler::Argument> real_args(args);
+ int token_input_index = -1;
+ if (options.add_token_input_output) {
+ // Add extra token input.
+ token_input_index = real_args.size();
+
+ XlaCompiler::Argument token_arg;
+ token_arg.kind = XlaCompiler::Argument::kToken;
+ real_args.push_back(token_arg);
+ }
+
std::vector<XlaExpression> arg_expressions;
std::vector<int> arg_cores;
- TF_RETURN_IF_ERROR(
- BuildArguments(*graph, args, options.use_tuple_arg, &builder, context,
- &arg_cores, &arg_expressions, &result->input_mapping,
- &result->xla_input_shapes, options.is_entry_computation));
+ TF_RETURN_IF_ERROR(BuildArguments(
+ *graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores,
+ &arg_expressions, &result->input_mapping, &result->xla_input_shapes,
+ options.is_entry_computation));
context->set_args(std::move(arg_expressions));
+ PushNodeTokenMapping();
+ // Use std::set instead of std::unordered_set to ensure determinism.
+ std::set<std::string> output_node_token_inputs;
+ if (token_input_index != -1) {
+ // Original token comes from input.
+ auto arg_expression = context->args()[token_input_index];
+ TF_RETURN_IF_ERROR(
+ SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle()));
+
+ // Calculate token inputs for output token.
+ output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph);
+
+ // If there's no side-effecting op in the graph, use token input as token
+ // output.
+ if (output_node_token_inputs.empty()) {
+ output_node_token_inputs.insert(kXlaTokenArgNodeName);
+ }
+ } else if (options.is_entry_computation) {
+ // Original token is manually created.
+ if (HasSideEffectingNodes(*graph)) {
+ TF_RETURN_IF_ERROR(
+ SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder)));
+ }
+ }
+
TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_,
flib_runtime_, NextStepId()));
+ if (token_input_index != -1) {
+ // Add extra token output.
+ std::vector<xla::XlaOp> token_inputs;
+ for (const auto& node_name : output_node_token_inputs) {
+ auto token_or = GetNodeToken(node_name);
+ TF_RETURN_IF_ERROR(token_or.status());
+ token_inputs.push_back(token_or.ValueOrDie());
+ }
+ TF_RETURN_IF_ERROR(
+ context->AppendTokenRetval(xla::AfterAll(&builder, token_inputs)));
+ }
+ TF_RETURN_IF_ERROR(PopNodeTokenMapping());
int num_nonconst_outputs;
int num_computation_outputs;
result->computation = std::make_shared<xla::XlaComputation>();
result->outputs.resize(context->retvals().size());
TF_RETURN_IF_ERROR(BuildComputation(
- args, arg_cores, context->retvals(), context->resources(),
+ real_args, arg_cores, context->retvals(), context->resources(),
options.return_updated_values_for_all_resources,
options.always_return_tuple, &builder, result->computation.get(),
&num_computation_outputs, &num_nonconst_outputs, &result->outputs,
@@ -791,14 +849,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
VLOG(2) << "XLA output shape: "
<< xla::ShapeUtil::HumanString(result->xla_output_shape);
- // Copy the host transfer metadata to the result.
- for (const auto& send : host_compute_sends_) {
- *result->host_compute_metadata.add_device_to_host() = send.second;
- }
- for (const auto& recv : host_compute_recvs_) {
- *result->host_compute_metadata.add_host_to_device() = recv.second;
- }
-
// Tensorflow expects a major-to-minor order of results.
xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape);
@@ -816,10 +866,34 @@ Status XlaCompiler::GetChannelHandle(const string& key,
return Status::OK();
}
+Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key,
+ xla::ChannelHandle* channel) {
+ auto result = channels_.emplace(key, xla::ChannelHandle());
+ if (result.second) {
+ TF_ASSIGN_OR_RETURN(result.first->second,
+ client()->CreateHostToDeviceChannelHandle());
+ }
+ *channel = result.first->second;
+ VLOG(1) << "Host to device channel: " << key << " " << channel->DebugString();
+ return Status::OK();
+}
+
+Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key,
+ xla::ChannelHandle* channel) {
+ auto result = channels_.emplace(key, xla::ChannelHandle());
+ if (result.second) {
+ TF_ASSIGN_OR_RETURN(result.first->second,
+ client()->CreateDeviceToHostChannelHandle());
+ }
+ *channel = result.first->second;
+ VLOG(1) << "Device to host channel: " << key << " " << channel->DebugString();
+ return Status::OK();
+}
+
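Both new getters rely on the same lookup-or-create idiom: map::emplace inserts a default-constructed handle only when the key is new, and result.second reports whether that happened, so the channel allocation runs exactly once per key. A minimal sketch of the idiom with plain types (names are illustrative, not the real API):

#include <map>
#include <string>

struct Handle { int id = -1; };

// Stand-in for client()->Create*ChannelHandle().
Handle AllocateHandle() {
  static int next = 0;
  Handle h;
  h.id = next++;
  return h;
}

std::map<std::string, Handle> channels;

Handle GetOrCreate(const std::string& key) {
  auto result = channels.emplace(key, Handle());
  if (result.second) {          // key was absent: allocate exactly once
    result.first->second = AllocateHandle();
  }
  return result.first->second;  // cached or freshly created handle
}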
namespace {
-void SetTransfer(const string& key, gtl::ArraySlice<DataType> types,
- gtl::ArraySlice<TensorShape> shapes,
+void SetTransfer(const string& key, absl::Span<const DataType> types,
+ absl::Span<const TensorShape> shapes,
tf2xla::HostTransferMetadata* transfer) {
transfer->set_key(key);
CHECK(types.size() == shapes.size());
@@ -833,8 +907,8 @@ void SetTransfer(const string& key, gtl::ArraySlice<DataType> types,
} // namespace
Status XlaCompiler::SetDeviceToHostMetadata(
- const string& key, gtl::ArraySlice<DataType> types,
- gtl::ArraySlice<TensorShape> shapes) {
+ const string& key, absl::Span<const DataType> types,
+ absl::Span<const TensorShape> shapes) {
if (host_compute_sends_.find(key) != host_compute_sends_.end()) {
return errors::InvalidArgument(
"Duplicate calls to SetDeviceToHostMetadata with key ", key);
@@ -860,8 +934,8 @@ Status XlaCompiler::GetDeviceToHostShapes(
}
Status XlaCompiler::SetHostToDeviceMetadata(
- const string& key, gtl::ArraySlice<DataType> types,
- gtl::ArraySlice<TensorShape> shapes) {
+ const string& key, absl::Span<const DataType> types,
+ absl::Span<const TensorShape> shapes) {
  if (host_compute_recvs_.find(key) != host_compute_recvs_.end()) {
return errors::InvalidArgument(
"Duplicate calls to SetHostToDeviceMetadata with key ", key);
@@ -896,4 +970,47 @@ Status XlaCompiler::SetHostComputeControlDependency(
return Status::OK();
}
+void XlaCompiler::PushNodeTokenMapping() {
+ node_token_mapping_stack_.emplace(std::map<string, xla::XlaOp>{});
+}
+
+Status XlaCompiler::PopNodeTokenMapping() {
+ if (node_token_mapping_stack_.empty()) {
+ return errors::FailedPrecondition(
+ "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is "
+ "empty.");
+ }
+ node_token_mapping_stack_.pop();
+ return Status::OK();
+}
+
+Status XlaCompiler::SetNodeToken(const string& node_name,
+ const xla::XlaOp& op) {
+ if (node_token_mapping_stack_.empty()) {
+ return errors::FailedPrecondition(
+ "Calling SetNodeToken() when node_token_mapping_stack_ is "
+ "empty.");
+ }
+ auto insert_result = node_token_mapping_stack_.top().insert({node_name, op});
+ if (!insert_result.second) {
+ return errors::FailedPrecondition("Token mapping already exists for node ",
+ node_name);
+ }
+ return Status::OK();
+}
+
+xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
+ if (node_token_mapping_stack_.empty()) {
+ return errors::FailedPrecondition(
+ "Calling GetNodeToken() when node_token_mapping_stack_ is "
+ "empty.");
+ }
+ auto iter = node_token_mapping_stack_.top().find(node_name);
+ if (iter == node_token_mapping_stack_.top().end()) {
+ return errors::FailedPrecondition("Cannot find token mapping for node ",
+ node_name);
+ }
+ return iter->second;
+}
+
} // namespace tensorflow
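The token bookkeeping added in this file is a stack of per-graph maps: CompileGraph() pushes an empty map, each side-effecting op records its token output under its node name, and the output token is the AfterAll of the recorded outputs. A simplified standalone sketch of the data structure (std::string stands in for xla::XlaOp):

#include <map>
#include <stack>
#include <string>

class TokenMap {
 public:
  void Push() { stack_.push({}); }
  void Pop() { stack_.pop(); }

  // Records a node's token output; returns false if one already exists.
  bool Set(const std::string& node, const std::string& token) {
    return stack_.top().insert({node, token}).second;
  }

  // Looks up a previously recorded token output, or nullptr if absent.
  const std::string* Get(const std::string& node) const {
    auto it = stack_.top().find(node);
    return it == stack_.top().end() ? nullptr : &it->second;
  }

 private:
  // One map per nested CompileGraph() invocation.
  std::stack<std::map<std::string, std::string>> stack_;
};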
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 25332c8d8e..2cc603a580 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
+#include <stack>
+
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -26,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
@@ -106,6 +109,9 @@ class XlaCompiler {
// Argument is a run-time parameter.
kParameter,
+
+ // Argument is an XLA token.
+ kToken,
};
Kind kind = kInvalid;
@@ -179,10 +185,15 @@ class XlaCompiler {
// True when compiling the entry computation, false for subcomputations
// (while, call, etc.)
bool is_entry_computation = true;
+
+ // True when we should add an XLA token input & output to the graph/function.
+ bool add_token_input_output = false;
};
struct OutputDescription {
// Type and shape of the output. The shape is the unflattened shape.
+ // When `type` is DT_RESOURCE, `shape` is the shape of the resource
+ // variable's value.
DataType type;
TensorShape shape;
@@ -190,6 +201,10 @@ class XlaCompiler {
// 'Tensor' is in host memory.
bool is_constant = false;
Tensor constant_value;
+
+ // When this output is a resource, i.e. `type == DT_RESOURCE`, this is
+ // the index of the input that contains the resource.
+ int input_index;
};
// Describes a variable write side effect of the computation.
@@ -212,9 +227,9 @@ class XlaCompiler {
struct CompilationResult {
// Vector that maps from the parameters of the XLA computation to their
- // original argument positions. To handle compile-time constant inputs and
- // resources, the parameters to the XLA computation may be a subset of the
- // original arguments, and are not necessarily in the same order.)
+ // original argument positions. To handle compile-time constant inputs, the
+ // parameters to the XLA computation may be a subset of the original
+ // arguments. The relative ordering of parameters is maintained.
std::vector<int> input_mapping;
// Input shapes of the computation. If we are flattening inputs, these are
@@ -332,11 +347,21 @@ class XlaCompiler {
// same XlaCompiler.
Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);
+ // Retrieves the host-to-device channel handle associated with `key`.
+ // Allocates a new channel handle if none exists.
+ Status GetHostToDeviceChannelHandle(const string& key,
+ xla::ChannelHandle* channel);
+
+ // Retrieves the device-to-host channel handle associated with `key`.
+ // Allocates a new channel handle if none exists.
+ Status GetDeviceToHostChannelHandle(const string& key,
+ xla::ChannelHandle* channel);
+
// Sets the shapes and types for the device to host transfer associated with
// 'key'.
Status SetDeviceToHostMetadata(const string& key,
- gtl::ArraySlice<DataType> types,
- gtl::ArraySlice<TensorShape> shapes);
+ absl::Span<const DataType> types,
+ absl::Span<const TensorShape> shapes);
// Gets the shapes the device to host transfer associated with 'key'.
Status GetDeviceToHostShapes(const string& key,
@@ -345,8 +370,8 @@ class XlaCompiler {
// Sets the shapes and types for the host to device transfer associated with
// 'key'.
Status SetHostToDeviceMetadata(const string& key,
- gtl::ArraySlice<DataType> types,
- gtl::ArraySlice<TensorShape> shapes);
+ absl::Span<const DataType> types,
+ absl::Span<const TensorShape> shapes);
// In order to avoid deadlocks from dependencies in host computations, it can
// be necessary to enforce a partial order on the execution of HostCompute
@@ -368,6 +393,11 @@ class XlaCompiler {
xla::Client* client() const { return options_.client; }
FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }
+ void PushNodeTokenMapping();
+ Status PopNodeTokenMapping();
+ Status SetNodeToken(const string& node_name, const xla::XlaOp& op);
+ xla::StatusOr<xla::XlaOp> GetNodeToken(const string& node_name);
+
private:
// Sets the function body `fbody` to the one registered as `function`.
Status FindFunctionBody(const NameAttrList& function,
@@ -432,6 +462,15 @@ class XlaCompiler {
std::unordered_map<string, xla::XlaOp> host_compute_control_output_;
+ // This is used to store the <node name, token output> mapping. Side-effecting
+ // ops call SetNodeToken() to record their token outputs, so that later
+ // side-effecting ops can use GetNodeToken() to fetch them as token inputs.
+ //
+ // It's a stack because we need one such mapping for each level of nested
+ // CompileGraph() calls. CompileGraph() pushes a new mapping onto the stack
+ // on entry, and pops it before returning.
+ std::stack<std::map<string, xla::XlaOp>> node_token_mapping_stack_;
+
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
};
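With the resource reordering removed, input_mapping now preserves the relative order of the original arguments: compile-time constants are skipped and all runtime parameters keep their relative positions, which is what the MixedOrderArguments test below verifies. A small illustrative sketch of the contract (not the real compiler code):

#include <vector>

// args: 0 = kConstant, 1 = kParameter, 2 = kResource (initialized).
// Constants are skipped; runtime parameters keep their relative order.
std::vector<int> ComputeInputMapping() {
  std::vector<int> input_mapping;
  for (int i = 0; i < 3; ++i) {
    const bool is_constant = (i == 0);
    if (!is_constant) input_mapping.push_back(i);
  }
  return input_mapping;  // {1, 2}: same relative order as the arguments
}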
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index be00ed8813..70efa7781d 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -14,15 +14,18 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -31,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@@ -38,7 +42,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/version.h"
@@ -205,27 +208,22 @@ TEST_F(XlaCompilerTest, Simple) {
std::move(graph), args, &result));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
-
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({4, 143});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
+
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({4, 143});
+ xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests compilation of a graph where the _Retval node is not necessarily last
@@ -261,23 +259,68 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) {
args, &result));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal));
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal));
+}
+
+// Tests that the compiler doesn't reorder the parameters.
+TEST_F(XlaCompilerTest, MixedOrderArguments) {
+ for (bool swap_order : {false, true}) {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto var =
+ ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, swap_order ? 0 : 1);
+ auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, swap_order ? 1 : 0);
+ // Adds an identity op around the resource to make sure identity ops
+ // propagate resources correctly.
+ auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
+ auto write = ops::AssignAddVariableOp(scope, identity, a);
+ auto read = ops::ReadVariableOp(
+ scope.WithControlDependencies(std::vector<Operation>{write}), var,
+ DT_INT32);
+ auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
+ auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+ // Builds a description of the arguments.
+ std::vector<XlaCompiler::Argument> args(2);
+ args[0].kind = XlaCompiler::Argument::kParameter;
+ args[0].type = DT_INT32;
+ args[0].shape = TensorShape({2});
+ args[1].kind = XlaCompiler::Argument::kResource;
+ args[1].resource_kind = XlaResource::kVariable;
+ args[1].initialized = true;
+ args[1].type = DT_INT32;
+ args[1].shape = TensorShape({2});
+
+ if (swap_order) {
+ // Even after swapping arguments, the compiler should maintain the new
+ // ordering of parameters.
+ std::swap(args[0], args[1]);
+ }
+ // Compiles the graph.
+ XlaCompiler compiler(DefaultOptions());
+
+ XlaCompiler::CompileOptions compile_options;
+ compile_options.always_return_tuple = false;
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
+ args, &result));
+
+ EXPECT_THAT(result.input_mapping, ::testing::ElementsAre(0, 1));
+ }
}
TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
@@ -309,10 +352,10 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
std::move(graph), args, &result);
EXPECT_FALSE(status.ok());
EXPECT_TRUE(
- str_util::StrContains(status.error_message(), "depends on a parameter"))
+ absl::StrContains(status.error_message(), "depends on a parameter"))
<< status.error_message();
EXPECT_TRUE(
- str_util::StrContains(status.error_message(), "[[{{node C}} = Reshape"))
+ absl::StrContains(status.error_message(), "[[{{node C}} = Reshape"))
<< status.error_message();
}
@@ -357,23 +400,19 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
EXPECT_FALSE(result.outputs[1].is_constant);
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param0_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
+ xla::Literal actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({-7, -42});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get()});
- EXPECT_TRUE(
- xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
+ xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
{
@@ -392,24 +431,21 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
EXPECT_FALSE(result.outputs[1].is_constant);
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param0_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
+ xla::Literal actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR0<int32>(7);
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({-7, -42});
- std::unique_ptr<xla::Literal> expected =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal));
+ xla::Literal expected0 = xla::LiteralUtil::CreateR0<int32>(7);
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
+ xla::Literal expected =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual_literal));
}
}
@@ -621,34 +657,26 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
update.tensor_array_gradients_accessed);
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> input_base =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> input_grad2 =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
- std::unique_ptr<xla::Literal> input =
- xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()});
+ xla::Literal input_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal input_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal input = xla::LiteralUtil::MakeTuple({&input_base, &input_grad2});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*input).ConsumeValueOrDie();
+ client_->TransferToServer(input).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_->Execute(*result.computation, {param0_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
-
- std::unique_ptr<xla::Literal> output_read =
- xla::LiteralUtil::CreateR0<int32>(42);
- std::unique_ptr<xla::Literal> output_base =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> output_grad1 =
- xla::LiteralUtil::CreateR1<int32>({0, 1});
- std::unique_ptr<xla::Literal> output_grad2 =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
- std::unique_ptr<xla::Literal> output_resource = xla::LiteralUtil::MakeTuple(
- {output_base.get(), output_grad1.get(), output_grad2.get()});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
+
+ xla::Literal output_read = xla::LiteralUtil::CreateR0<int32>(42);
+ xla::Literal output_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal output_grad1 = xla::LiteralUtil::CreateR1<int32>({0, 1});
+ xla::Literal output_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ xla::Literal output_resource =
+ xla::LiteralUtil::MakeTuple({&output_base, &output_grad1, &output_grad2});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&output_read, &output_resource});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests compilation and execution of a graph that adds two tensors.
@@ -727,8 +755,7 @@ TEST_F(XlaCompilerTest, UndefinedFunctionFails) {
compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr,
/*args=*/{}, &result);
EXPECT_FALSE(status.ok());
- EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()),
- "is not defined."))
+ EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined."))
<< status.error_message();
}
@@ -807,21 +834,44 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) {
ASSERT_FALSE(status.ok());
// Flib lookup failure.
- EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()),
- "is not defined."))
+ EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined."))
<< status.error_message();
// Local flib lookup failure.
- EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()),
- "Attr T is not found"))
+ EXPECT_TRUE(absl::StrContains(status.error_message(), "Attr T is not found"))
<< status.error_message();
}
+void RunAndCheckVariablesComputation(
+ xla::Client* client, const XlaCompiler::CompilationResult& result) {
+ xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ std::unique_ptr<xla::GlobalData> param0_data =
+ client->TransferToServer(param0_literal).ConsumeValueOrDie();
+ std::unique_ptr<xla::GlobalData> param1_data =
+ client->TransferToServer(param1_literal).ConsumeValueOrDie();
+
+ std::unique_ptr<xla::GlobalData> actual =
+ client
+ ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
+ .ConsumeValueOrDie();
+ xla::Literal actual_literal = client->Transfer(*actual).ConsumeValueOrDie();
+
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({5, 144});
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({4, 143});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
+}
+
// Tests a simple graph that reads and writes a variable.
TEST_F(XlaCompilerTest, Variables) {
Scope scope = Scope::NewRootScope().ExitOnError();
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
- auto write = ops::AssignAddVariableOp(scope, var, a);
+ // Adds an identity op around the resource to make sure identity ops propagate
+ // resources correctly.
+ auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
+ auto write = ops::AssignAddVariableOp(scope, identity, a);
auto read = ops::ReadVariableOp(
scope.WithControlDependencies(std::vector<Operation>{write}), var,
DT_INT32);
@@ -847,31 +897,82 @@ TEST_F(XlaCompilerTest, Variables) {
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
std::move(graph), args, &result));
+ RunAndCheckVariablesComputation(client_, result);
+}
+
+// Tests a graph that returns a resource handle without reading or writing it.
+TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 0);
+ auto d = ops::_Retval(scope.WithOpName("D"), var, 0);
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+ // Builds a description of the arguments.
+ std::vector<XlaCompiler::Argument> args(1);
+ args[0].kind = XlaCompiler::Argument::kResource;
+ args[0].resource_kind = XlaResource::kVariable;
+ args[0].initialized = true;
+ args[0].type = DT_INT32;
+ args[0].shape = TensorShape({2});
+
+ // Compiles the graph.
+ XlaCompiler compiler(DefaultOptions());
+
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
+ std::move(graph), args, &result));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::LiteralUtil::CreateR1<int32>({7, 42});
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR1<int32>({-3, 101});
- std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
- client_
- ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
+ client_->Execute(*result.computation, {param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
-
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({5, 144});
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({4, 143});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
+
+ xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
+}
+
+TEST_F(XlaCompilerTest, ReturnResourceHandle) {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
+ auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
+ // Adds an identity op around the resource to make sure identity ops propagate
+ // resources correctly.
+ auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
+ auto write = ops::AssignAddVariableOp(scope, identity, a);
+ auto read = ops::ReadVariableOp(
+ scope.WithControlDependencies(std::vector<Operation>{write}), var,
+ DT_INT32);
+ auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
+ auto r = ops::_Retval(scope.WithOpName("R"), var, 0);
+ auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 1);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+ // Builds a description of the arguments.
+ std::vector<XlaCompiler::Argument> args(2);
+ args[0].kind = XlaCompiler::Argument::kParameter;
+ args[0].type = DT_INT32;
+ args[0].shape = TensorShape({2});
+ args[1].kind = XlaCompiler::Argument::kResource;
+ args[1].resource_kind = XlaResource::kVariable;
+ args[1].initialized = true;
+ args[1].type = DT_INT32;
+ args[1].shape = TensorShape({2});
+
+ // Compiles the graph.
+ XlaCompiler compiler(DefaultOptions());
+
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
+ std::move(graph), args, &result));
+ RunAndCheckVariablesComputation(client_, result);
}
xla::StatusOr<std::unique_ptr<Graph>> BuildTestGraph() {
@@ -937,29 +1038,27 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
xla::ShapeUtil::MakeShape(xla::S32, {4})})));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
+ xla::Literal param0_literal =
xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}});
- std::unique_ptr<xla::Literal> param1_literal =
+ xla::Literal param1_literal =
xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected0 =
+ xla::Literal expected0 =
xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}});
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
@@ -1006,29 +1105,26 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
xla::ShapeUtil::MakeShape(xla::S32, {4})})));
// Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
+ xla::Literal param0_literal =
xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3});
- std::unique_ptr<xla::Literal> param1_literal =
+ xla::Literal param1_literal =
xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
-
- std::unique_ptr<xla::Literal> expected0 =
- xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
- std::unique_ptr<xla::Literal> expected1 =
- xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
+
+ xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
+ xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
+ xla::Literal expected_literal =
+ xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
// Tests a graph which has a function with an invalid op.
@@ -1075,9 +1171,9 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
std::move(graph), args, &result);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(str_util::StrContains(status.error_message(), "InvalidOp"))
+ EXPECT_TRUE(absl::StrContains(status.error_message(), "InvalidOp"))
<< status.error_message();
- EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node fill_fn}}"))
+ EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node fill_fn}}"))
<< status.error_message();
}
@@ -1100,10 +1196,10 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) {
status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type",
std::move(graph), args, &result);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(str_util::StrContains(status.error_message(),
- "is not in the list of allowed values"))
+ EXPECT_TRUE(absl::StrContains(status.error_message(),
+ "is not in the list of allowed values"))
<< status.error_message();
- EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node Shape}}"))
+ EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node Shape}}"))
<< status.error_message();
}
@@ -1127,9 +1223,9 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
std::move(graph_copy), args, &result);
ASSERT_FALSE(status.ok());
EXPECT_TRUE(
- str_util::StrContains(status.error_message(),
- "The following nodes are unreachable "
- "from the source in the graph: {{node NoOp}}"))
+ absl::StrContains(status.error_message(),
+ "The following nodes are unreachable "
+ "from the source in the graph: {{node NoOp}}"))
<< status.error_message();
}
@@ -1145,5 +1241,70 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
}
}
+class DummySideEffectingOp : public XlaOpKernel {
+ public:
+ explicit DummySideEffectingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ void Compile(XlaOpKernelContext* ctx) override {
+ OP_REQUIRES_OK(ctx, ctx->compiler()->SetNodeToken(
+ name(), xla::CreateToken(ctx->builder())));
+ }
+};
+
+REGISTER_OP("DummySideEffectingOp");
+
+REGISTER_XLA_OP(Name("DummySideEffectingOp"), DummySideEffectingOp);
+
+TEST_F(XlaCompilerTest, TokenInputAndOutput) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ NodeDef side_effecting_op;
+ side_effecting_op.set_name("DummySideEffectingOp");
+ side_effecting_op.set_op("DummySideEffectingOp");
+ AddNodeAttr(kXlaTokenInputNodesAttrName,
+ std::vector<string>{kXlaTokenArgNodeName}, &side_effecting_op);
+ Status status;
+ graph->AddNode(side_effecting_op, &status);
+ TF_ASSERT_OK(status);
+ EXPECT_TRUE(FixupSourceAndSinkEdges(graph.get()));
+
+ const std::vector<XlaCompiler::Argument> empty_args;
+ {
+ // The case for an entry computation: we don't add a token input/output.
+ // Instead, we use the CreateToken HLO to create the entry token.
+ XlaCompiler::CompileOptions options;
+ options.is_entry_computation = true;
+ options.add_token_input_output = false;
+ XlaCompiler compiler(DefaultOptions());
+
+ std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
+ CopyGraph(*graph, graph_copy.get());
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
+ empty_args, &result));
+ EXPECT_EQ(result.xla_input_shapes.size(), 0);
+ EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape));
+ EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 0);
+ }
+ {
+ // The case for a non-entry computation (e.g. a while loop body): we add a
+ // token input and output.
+ XlaCompiler::CompileOptions options;
+ options.is_entry_computation = false;
+ options.add_token_input_output = true;
+ XlaCompiler compiler(DefaultOptions());
+
+ std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
+ CopyGraph(*graph, graph_copy.get());
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
+ empty_args, &result));
+ EXPECT_EQ(result.xla_input_shapes.size(), 1);
+ EXPECT_TRUE(xla::ShapeUtil::IsToken(result.xla_input_shapes[0]));
+ EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape));
+ EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1);
+ EXPECT_TRUE(xla::ShapeUtil::IsToken(
+ xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 0)));
+ }
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index b24e3aabbe..f247570d72 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
@@ -31,8 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
@@ -107,6 +106,30 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
return Status::OK();
}
+Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) {
+ VLOG(1) << "Adding retval index " << retval_index << " with resource "
+ << resource->name() << ":" << resource->shape().DebugString()
+ << " to XLA computation";
+ if (retvals_.size() <= retval_index) {
+ retvals_.resize(retval_index + 1);
+ }
+ XlaExpression e;
+ e.set_resource(resource);
+ retvals_[retval_index] = Retval{DT_RESOURCE, resource->shape(), e};
+ return Status::OK();
+}
+
+Status XlaContext::AppendTokenRetval(const xla::XlaOp& token) {
+ VLOG(1) << "Adding retval index " << retvals_.size()
+ << " with token to XLA computation";
+ XlaExpression e;
+ e.set_handle(token);
+  // We use DT_INVALID because there is no TF DataType that corresponds to an
+  // XLA token. XlaCompiler handles this case separately, so putting it here is OK.
+ retvals_.push_back(Retval{DT_INVALID, TensorShape(), e});
+ return Status::OK();
+}
+
xla::XlaBuilder* XlaContext::builder() { return builder_; }
Status XlaContext::CreateResource(
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index 3db37afdba..d7dbdc957f 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -86,6 +86,12 @@ class XlaContext : public ResourceBase {
Status AddConstRetval(int retval_index, DataType dtype,
const xla::LiteralSlice& literal);
+ // As for Retval, but for return values that are resource handles.
+ Status AddResourceRetval(int retval_index, XlaResource* resource);
+
+ // As for Retval, but for return values that are XLA tokens.
+ Status AppendTokenRetval(const xla::XlaOp& token);
+
// Creates a resource with resource `kind` and initial value `handle`. `name`
// is a descriptive name for use in error messages. See the `XlaResource`
// constructor for a description of the remaining arguments.
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index 8efb3d55c8..9a34cd8c6a 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
@@ -31,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
@@ -119,7 +119,7 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type,
}
/* static */ Status XlaHelpers::ReshapeLiteral(
- const xla::Literal& input, gtl::ArraySlice<int64> dimensions,
+ const xla::Literal& input, absl::Span<const int64> dimensions,
xla::Literal* output) {
if (xla::ShapeUtil::IsTuple(input.shape())) {
return errors::InvalidArgument("ReshapeLiteral does not support tuples.");
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h
index e6522157a5..39578144ca 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.h
+++ b/tensorflow/compiler/tf2xla/xla_helpers.h
@@ -18,10 +18,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_
+#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
@@ -50,7 +50,7 @@ class XlaHelpers {
// Reshapes literal 'input' to have 'shape'. Both the original shape and
// 'shape' must contain the same number of elements.
static Status ReshapeLiteral(const xla::Literal& input,
- gtl::ArraySlice<int64> shape,
+ absl::Span<const int64> shape,
xla::Literal* output);
// Returns the argmax of `input` along `axis`. `output_type` is the type to
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 82028c8b9c..d10a504da0 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -67,7 +67,7 @@ const xla::XlaOp& XlaOpKernelContext::Input(int index) {
return GetComputationFromTensor(context_->input(index));
}
-const xla::XlaOp& XlaOpKernelContext::Input(StringPiece name) {
+const xla::XlaOp& XlaOpKernelContext::Input(absl::string_view name) {
return GetComputationFromTensor(GetInputTensorByName(name));
}
@@ -75,7 +75,7 @@ TensorShape XlaOpKernelContext::InputShape(int index) {
return context_->input(index).shape();
}
-TensorShape XlaOpKernelContext::InputShape(StringPiece name) {
+TensorShape XlaOpKernelContext::InputShape(absl::string_view name) {
return GetInputTensorByName(name).shape();
}
@@ -99,8 +99,27 @@ Status XlaOpKernelContext::ConstantInput(int index,
index, context_->input(index).shape().dim_sizes(), constant_literal);
}
+static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context,
+ absl::string_view name) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued input name '",
+ name,
+ "' when single-valued input was "
+ "expected");
+ }
+ return start;
+}
+
+Status XlaOpKernelContext::ConstantInput(absl::string_view name,
+ xla::Literal* constant_literal) {
+ TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
+ return ConstantInput(index, constant_literal);
+}
+
Status XlaOpKernelContext::ConstantInputReshaped(
- int index, gtl::ArraySlice<int64> new_dims,
+ int index, absl::Span<const int64> new_dims,
xla::Literal* constant_literal) {
const Tensor& tensor = context_->input(index);
TensorShape new_shape(new_dims);
@@ -194,16 +213,15 @@ Status XlaOpKernelContext::ConstantInputReshaped(
context_->op_kernel().name(), " input ", index,
".\nError: ", constant_graph.status().error_message());
}
- xla::StatusOr<std::unique_ptr<xla::Literal>> computed =
- compiler()->client()->ComputeConstant(constant_graph.ValueOrDie(),
- &layout);
+ xla::StatusOr<xla::Literal> computed = compiler()->client()->ComputeConstant(
+ constant_graph.ValueOrDie(), &layout);
if (!computed.ok()) {
return errors::Internal("Error evaluating ", context_->op_kernel().name(),
" input ", index,
- "as a compile-time constant.\nError: ",
+ " as a compile-time constant.\nError: ",
computed.status().error_message());
}
- *constant_literal = std::move(*computed.ValueOrDie());
+ *constant_literal = std::move(computed).ValueOrDie();
return Status::OK();
}
@@ -246,6 +264,12 @@ Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) {
return LiteralToInt64Scalar(literal, out);
}
+Status XlaOpKernelContext::ConstantInputAsIntScalar(absl::string_view name,
+ int64* out) {
+ TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
+ return ConstantInputAsIntScalar(index, out);
+}
+
Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) {
xla::Literal literal;
TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
@@ -280,6 +304,20 @@ Status XlaOpKernelContext::ConstantInputAsIntVector(int index,
return LiteralToInt64Vector(literal, out);
}
+Status XlaOpKernelContext::ConstantInputAsIntVector(absl::string_view name,
+ std::vector<int64>* out) {
+ TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
+ return ConstantInputAsIntVector(index, out);
+}
+
+Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
+ int index, std::vector<int64>* out) {
+ xla::Literal literal;
+ TF_RETURN_IF_ERROR(ConstantInputReshaped(
+ index, {InputShape(index).num_elements()}, &literal));
+ return LiteralToInt64Vector(literal, out);
+}
+
Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
xla::Literal* out) {
xla::Literal literal;
@@ -305,6 +343,12 @@ Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
}
}
+Status XlaOpKernelContext::ConstantInputAsInt64Literal(absl::string_view name,
+ xla::Literal* out) {
+ TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
+ return ConstantInputAsInt64Literal(index, out);
+}
+
// TODO(phawkins): validate that the dimensions form a valid shape, fail
// gracefully if they do not.
Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) {
@@ -316,7 +360,7 @@ Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) {
return Status::OK();
}
-Status XlaOpKernelContext::InputList(StringPiece name,
+Status XlaOpKernelContext::InputList(absl::string_view name,
std::vector<xla::XlaOp>* handles,
std::vector<TensorShape>* shapes) {
OpInputList inputs;
@@ -331,7 +375,7 @@ Status XlaOpKernelContext::InputList(StringPiece name,
}
Status XlaOpKernelContext::ConstantInputList(
- StringPiece name, std::vector<xla::Literal>* outputs) {
+ absl::string_view name, std::vector<xla::Literal>* outputs) {
int start, stop;
TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop));
outputs->resize(stop - start);
@@ -384,8 +428,8 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
value);
}
-Status XlaOpKernelContext::ReadVariableInput(StringPiece name, DataType type,
- TensorShape* shape,
+Status XlaOpKernelContext::ReadVariableInput(absl::string_view name,
+ DataType type, TensorShape* shape,
xla::XlaOp* value) {
return ReadVariableInputTensor(GetInputTensorByName(name), type, context_,
shape, value);
@@ -519,7 +563,7 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
handle, builder());
}
-Status XlaOpKernelContext::AssignVariable(StringPiece name, DataType type,
+Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type,
xla::XlaOp handle) {
TF_RET_CHECK(handle.valid());
return AssignVariableTensor(GetInputTensorByName(name), type, context_,
@@ -565,7 +609,7 @@ const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
return XlaContext::Get(context_).GetOrCreateMul(type);
}
-const Tensor& XlaOpKernelContext::GetInputTensorByName(StringPiece name) {
+const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) {
const Tensor* tensor;
CHECK(context_->input(name, &tensor).ok());
return *tensor;
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index ac9dfe3369..962c86d3a5 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -80,14 +80,14 @@ class XlaOpKernelContext {
TensorShape InputShape(int index);
// Returns the shape of input `name`.
- TensorShape InputShape(StringPiece name);
+ TensorShape InputShape(absl::string_view name);
  // Returns input `index` as an XlaOp. Unlike OpKernelContext::Input, this
  // returns a symbolic value rather than a concrete Tensor.
const xla::XlaOp& Input(int index);
  // Returns input `name` as an XlaOp.
- const xla::XlaOp& Input(StringPiece name);
+ const xla::XlaOp& Input(absl::string_view name);
// Returns true if all inputs are the same shape, otherwise sets the
// status to a non-OK value and returns false.
@@ -97,7 +97,7 @@ class XlaOpKernelContext {
// Returns the named list-valued immutable input in "list", as
  // defined in the OpDef. If the named input is not list-valued,
// returns a one-element list.
- Status InputList(StringPiece name, std::vector<xla::XlaOp>* handles,
+ Status InputList(absl::string_view name, std::vector<xla::XlaOp>* handles,
std::vector<TensorShape>* shapes);
// Helper methods for constant inputs.
@@ -106,26 +106,35 @@ class XlaOpKernelContext {
// expression cannot be evaluated, e.g., because it depends on unbound
// parameters, returns a non-OK status.
Status ConstantInput(int index, xla::Literal* constant_literal);
+ Status ConstantInput(absl::string_view name, xla::Literal* constant_literal);
  // Evaluates input `index`, reshapes it to `new_dims` if `new_dims` differs
  // from InputShape(index).dim_sizes(), and stores it in `*constant_literal`.
  // If the input cannot be evaluated, e.g., because it depends on unbound
  // parameters, returns a non-Ok status. If the element count implied by
  // `new_dims` differs from InputShape(index).num_elements(), returns an
  // error status.
- Status ConstantInputReshaped(int index, gtl::ArraySlice<int64> new_shape,
+ Status ConstantInputReshaped(int index, absl::Span<const int64> new_dims,
xla::Literal* constant_literal);
// Converts a constant scalar int32 or int64 tensor into an int64.
Status ConstantInputAsIntScalar(int index, int64* out);
+ Status ConstantInputAsIntScalar(absl::string_view name, int64* out);
// Converts a constant scalar float32 or float64 tensor into a float64.
Status ConstantInputAsFloatScalar(int index, double* out);
// Converts a constant 1D int32 or int64 tensor into a vector of int64s.
Status ConstantInputAsIntVector(int index, std::vector<int64>* out);
+ Status ConstantInputAsIntVector(absl::string_view name,
+ std::vector<int64>* out);
+
+ // Reshapes and converts a constant int32 or int64 tensor into a vector of
+ // int64s.
+ Status ConstantInputReshapedToIntVector(int index, std::vector<int64>* out);
// Converts a constant int32 or int64 Tensor into an xla int64 Literal.
Status ConstantInputAsInt64Literal(int index, xla::Literal* out);
+ Status ConstantInputAsInt64Literal(absl::string_view name, xla::Literal* out);
// Converts a constant 1D int32 or int64 tensor into a TensorShape.
Status ConstantInputAsShape(int index, TensorShape* shape);
@@ -133,7 +142,7 @@ class XlaOpKernelContext {
// Returns the named list-valued immutable input in "list", as
  // defined in the OpDef. If the named input is not list-valued,
// returns a one-element list.
- Status ConstantInputList(StringPiece name,
+ Status ConstantInputList(absl::string_view name,
std::vector<xla::Literal>* literals);
// Outputs
@@ -182,8 +191,8 @@ class XlaOpKernelContext {
xla::XlaOp* value);
  // Reads the current value of the resource variable referred to by input
// `name`.
- Status ReadVariableInput(StringPiece name, DataType type, TensorShape* shape,
- xla::XlaOp* value);
+ Status ReadVariableInput(absl::string_view name, DataType type,
+ TensorShape* shape, xla::XlaOp* value);
// Assigns the value `handle` to the variable referenced by input
// `input_index`. The variable must be of `type`. Returns an error if the
@@ -191,7 +200,8 @@ class XlaOpKernelContext {
// different shape.
Status AssignVariable(int input_index, DataType type, xla::XlaOp handle);
// Assigns the value `handle` to the variable referenced by input `name`.
- Status AssignVariable(StringPiece name, DataType type, xla::XlaOp handle);
+ Status AssignVariable(absl::string_view name, DataType type,
+ xla::XlaOp handle);
// Helper routines for the OP_REQUIRES macros
void CtxFailure(const Status& s);
@@ -240,7 +250,7 @@ class XlaOpKernelContext {
private:
// Returns the tensor of input `name`.
- const Tensor& GetInputTensorByName(StringPiece name);
+ const Tensor& GetInputTensorByName(absl::string_view name);
OpKernelContext* const context_;
};
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index 46785bc1f0..b0eeee3174 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -105,7 +105,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
/* static */ void XlaOpRegistry::RegisterBackend(
const string& compilation_device_name,
- gtl::ArraySlice<DataType> supported_types, BackendOpFilter op_filter) {
+ absl::Span<const DataType> supported_types, BackendOpFilter op_filter) {
XlaOpRegistry& registry = Instance();
mutex_lock lock(registry.mutex_);
auto result = registry.backends_.emplace(compilation_device_name, Backend());
@@ -325,6 +325,17 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
return kernels;
}
+/* static */ std::vector<string> XlaOpRegistry::GetAllRegisteredOps() {
+ std::vector<string> ops;
+ XlaOpRegistry& registry = Instance();
+ mutex_lock lock(registry.mutex_);
+ for (const auto& pair : registry.ops_) {
+ ops.push_back(pair.first);
+ }
+ std::sort(ops.begin(), ops.end());
+ return ops;
+}
+
/* static */ const std::unordered_set<string>*
XlaOpRegistry::CompileTimeConstantInputs(const string& op) {
XlaOpRegistry& registry = Instance();
@@ -360,28 +371,30 @@ XlaOpRegistry& XlaOpRegistry::Instance() {
return *r;
}
-XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(StringPiece name) {
+XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(absl::string_view name) {
registration_.reset(new XlaOpRegistry::OpRegistration);
- registration_->name = std::string(name);
+ registration_->name = string(name);
}
-XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(StringPiece name) {
+XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(
+ absl::string_view name) {
XlaOpRegistrationBuilder registration(name);
return registration;
}
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(
- gtl::ArraySlice<StringPiece> devices) {
+ absl::Span<const absl::string_view> devices) {
registration_->has_device_whitelist = true;
- for (StringPiece device : devices) {
- registration_->device_whitelist.insert(std::string(device));
+ for (absl::string_view device : devices) {
+ registration_->device_whitelist.emplace(device);
}
return *this;
}
-XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(StringPiece device) {
+XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(
+ absl::string_view device) {
registration_->has_device_whitelist = true;
- registration_->device_whitelist.insert(std::string(device));
+ registration_->device_whitelist.emplace(device);
return *this;
}
@@ -396,17 +409,17 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowResourceTypes() {
}
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
- StringPiece attr_name, DataType allowed) {
+ absl::string_view attr_name, DataType allowed) {
std::set<DataType>& types =
- registration_->type_constraints[std::string(attr_name)];
+ registration_->type_constraints[string(attr_name)];
types.insert(allowed);
return *this;
}
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
- StringPiece attr_name, gtl::ArraySlice<DataType> allowed) {
+ absl::string_view attr_name, absl::Span<const DataType> allowed) {
std::set<DataType>& types =
- registration_->type_constraints[std::string(attr_name)];
+ registration_->type_constraints[string(attr_name)];
for (DataType t : allowed) {
types.insert(t);
}
@@ -414,8 +427,8 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
}
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput(
- StringPiece input_name) {
- registration_->compile_time_constant_inputs.insert(std::string(input_name));
+ absl::string_view input_name) {
+ registration_->compile_time_constant_inputs.emplace(input_name);
return *this;
}
@@ -441,10 +454,10 @@ XlaOpRegistrar::XlaOpRegistrar(
}
XlaBackendRegistrar::XlaBackendRegistrar(
- StringPiece name, gtl::ArraySlice<DataType> types,
+ absl::string_view name, absl::Span<const DataType> types,
XlaOpRegistry::BackendOpFilter op_filter) {
XlaOpRegistry& registry = XlaOpRegistry::Instance();
- registry.RegisterBackend(std::string(name), types, op_filter);
+ registry.RegisterBackend(string(name), types, op_filter);
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index fc14834ca6..74a4885f1f 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -94,7 +94,7 @@ class XlaOpRegistry {
// the device; it may optionally modify the KernelDef.
typedef bool (*BackendOpFilter)(KernelDef* kdef);
static void RegisterBackend(const string& compilation_device_name,
- gtl::ArraySlice<DataType> supported_types,
+ absl::Span<const DataType> supported_types,
BackendOpFilter op_filter);
// Returns the names of the registered backends.
@@ -128,6 +128,9 @@ class XlaOpRegistry {
const string& compilation_device_name,
bool include_compilation_only_kernels);
+ // Returns all operations for which there are XLA kernels on any device.
+ static std::vector<string> GetAllRegisteredOps();
+
// Returns the set of compile-time constant inputs to 'op'. Returns nullptr
// if the op is not registered.
static const std::unordered_set<string>* CompileTimeConstantInputs(
@@ -229,19 +232,19 @@ class XlaOpRegistry {
class XlaOpRegistrationBuilder {
public:
// Starts an operator registration chain.
- static XlaOpRegistrationBuilder Name(StringPiece name);
+ static XlaOpRegistrationBuilder Name(absl::string_view name);
// Specifies a whitelist of devices on which the operator may run.
- XlaOpRegistrationBuilder& Device(StringPiece devices);
- XlaOpRegistrationBuilder& Device(gtl::ArraySlice<StringPiece> devices);
+ XlaOpRegistrationBuilder& Device(absl::string_view devices);
+ XlaOpRegistrationBuilder& Device(absl::Span<const absl::string_view> devices);
// Specifies a type constraint for a type variable attribute. Each constraint
// specifies the set of types that the type variable may assume.
- XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name,
+ XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name,
DataType allowed);
- XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name,
- gtl::ArraySlice<DataType> allowed);
+ XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name,
+ absl::Span<const DataType> allowed);
// Specifies that a dummy copy of this operator should not be registered on
// XLA_* devices, but may be used during compilation.
@@ -251,13 +254,13 @@ class XlaOpRegistrationBuilder {
XlaOpRegistrationBuilder& AllowResourceTypes();
// Mark 'input_name' as an argument whose value must be known at compile-time.
- XlaOpRegistrationBuilder& CompileTimeConstInput(StringPiece input_name);
+ XlaOpRegistrationBuilder& CompileTimeConstInput(absl::string_view input_name);
std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
XlaOpRegistry::Factory factory);
private:
- XlaOpRegistrationBuilder(StringPiece name);
+ XlaOpRegistrationBuilder(absl::string_view name);
std::unique_ptr<XlaOpRegistry::OpRegistration> registration_;
};
@@ -285,7 +288,7 @@ class XlaOpRegistrar {
class XlaBackendRegistrar {
public:
- XlaBackendRegistrar(StringPiece name, gtl::ArraySlice<DataType> types,
+ XlaBackendRegistrar(absl::string_view name, absl::Span<const DataType> types,
XlaOpRegistry::BackendOpFilter op_filter = nullptr);
};
diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc
index 7928fa0347..56c2e01055 100644
--- a/tensorflow/compiler/tf2xla/xla_resource.cc
+++ b/tensorflow/compiler/tf2xla/xla_resource.cc
@@ -43,7 +43,7 @@ XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type,
for (const string& gradient : tensor_array_gradients) {
tensor_array_gradients_[gradient].reset(new XlaResource(
/*kind=*/kTensorArray, /*arg_num=*/-1,
- /*name=*/strings::StrCat("TensorArrayGrad: ", name_), type_, shape_,
+ /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_,
xla::XlaOp(), tensor_array_size_, /*tensor_array_gradients=*/{}));
}
}
@@ -135,7 +135,7 @@ Status XlaResource::GetOrCreateTensorArrayGradient(const string& source,
xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
gradient.reset(
new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
- /*name=*/strings::StrCat("TensorArrayGrad: ", name_),
+ /*name=*/absl::StrCat("TensorArrayGrad: ", name_),
type_, shape_, gradient_value, tensor_array_size_,
/*tensor_array_gradients=*/{}));
}
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index fdf13bb18c..76e36f3c46 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -113,6 +113,7 @@ cc_library(
":statusor",
":types",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -161,7 +162,6 @@ cc_library(
"iterator_util.h",
"map_util.h",
"overflow_util.h",
- "ptr_util.h",
"util.h",
],
visibility = ["//visibility:public"],
@@ -172,7 +172,11 @@ cc_library(
":types",
":xla_data_proto",
"//tensorflow/core:lib",
- "//tensorflow/core:ptr_util",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -210,6 +214,7 @@ tf_cc_test(
":test",
":util",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -236,10 +241,13 @@ cc_library(
":types",
":util",
":xla_data_proto",
- "//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:regexp_internal",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
],
)
@@ -256,6 +264,7 @@ tf_cc_test(
":xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
],
)
@@ -297,6 +306,10 @@ cc_library(
":util",
":xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -315,6 +328,8 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -335,6 +350,9 @@ cc_library(
":util",
":xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -353,6 +371,8 @@ cc_library(
":literal_util",
":util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -364,6 +384,8 @@ cc_library(
deps = [
":util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -373,8 +395,8 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":types",
- "//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/strings",
],
)
@@ -385,6 +407,8 @@ cc_library(
":status",
":types",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -405,8 +429,9 @@ cc_library(
deps = [
":array",
":types",
- ":util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -451,6 +476,8 @@ cc_library(
":array2d",
":types",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -462,6 +489,7 @@ tf_cc_test(
":test",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/types:span",
],
)
@@ -489,6 +517,8 @@ cc_library(
":util",
":xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/base",
+ "@com_google_absl//absl/memory",
],
)
@@ -503,6 +533,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
)
@@ -521,6 +552,8 @@ cc_library(
":xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -551,6 +584,8 @@ cc_library(
":types",
":xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -576,10 +611,12 @@ cc_library(
deps = [
":shape_util",
":status_macros",
- ":util",
":xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
],
)
@@ -593,6 +630,7 @@ tf_cc_test(
":xla_data_proto",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -619,6 +657,8 @@ cc_library(
":types",
":xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -642,6 +682,8 @@ cc_library(
"//tensorflow/compiler/xla/service:shape_inference",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
)
@@ -660,6 +702,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/client:padding",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -671,7 +714,8 @@ cc_library(
":array2d",
":shape_util",
":xla_data_proto",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/types:span",
],
)
diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h
index 2d5d078aa7..58cc157585 100644
--- a/tensorflow/compiler/xla/array.h
+++ b/tensorflow/compiler/xla/array.h
@@ -27,12 +27,12 @@ limitations under the License.
#include <type_traits>
#include <vector>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/bits.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -97,12 +97,11 @@ class Array {
using value_type = T;
// Creates a new array with the specified dimensions.
- explicit Array(tensorflow::gtl::ArraySlice<int64> sizes)
- : Array(sizes, T()) {}
+ explicit Array(absl::Span<const int64> sizes) : Array(sizes, T()) {}
// Creates a new array with the specified dimensions and specified value for
// every cell.
- Array(tensorflow::gtl::ArraySlice<int64> sizes, T value)
+ Array(absl::Span<const int64> sizes, T value)
: sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]) {
Fill(value);
}
@@ -301,7 +300,7 @@ class Array {
// Invokes a callback with the (indices, value_ptr) for each cell in the
// array.
- void Each(std::function<void(tensorflow::gtl::ArraySlice<int64>, T*)> f) {
+ void Each(std::function<void(absl::Span<const int64>, T*)> f) {
std::vector<int64> index(sizes_.size());
for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
f(index, &values_[i]);
@@ -309,8 +308,7 @@ class Array {
}
// Invokes a callback with the (indices, value) for each cell in the array.
- void Each(
- std::function<void(tensorflow::gtl::ArraySlice<int64>, T)> f) const {
+ void Each(std::function<void(absl::Span<const int64>, T)> f) const {
std::vector<int64> index(sizes_.size());
for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
f(index, values_[i]);
@@ -320,8 +318,7 @@ class Array {
// Invokes a callback with the (indices, value_ptr) for each cell in the
  // array. If a callback returns a non-OK status, returns that status;
  // otherwise returns Status::OK().
- Status EachStatus(
- std::function<Status(tensorflow::gtl::ArraySlice<int64>, T*)> f) {
+ Status EachStatus(std::function<Status(absl::Span<const int64>, T*)> f) {
std::vector<int64> index(sizes_.size());
for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
Status s = f(index, &values_[i]);
@@ -335,8 +332,7 @@ class Array {
// Invokes a callback with the (indices, value) for each cell in the array.
  // If a callback returns a non-OK status, returns that status; otherwise
  // returns Status::OK().
- Status EachStatus(
- std::function<Status(tensorflow::gtl::ArraySlice<int64>, T)> f) const {
+ Status EachStatus(std::function<Status(absl::Span<const int64>, T)> f) const {
std::vector<int64> index(sizes_.size());
for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
Status s = f(index, values_[i]);
@@ -377,13 +373,13 @@ class Array {
  // Returns the value at the cell specified by the indexes. The number of
  // arguments has to match the number of dimensions of the array.
- const T& operator()(tensorflow::gtl::ArraySlice<int64> indexes) const {
+ const T& operator()(absl::Span<const int64> indexes) const {
return values_[calculate_index(indexes)];
}
  // Returns the value at the cell specified by the indexes. The number of
  // arguments has to match the number of dimensions of the array.
- T& operator()(tensorflow::gtl::ArraySlice<int64> indexes) {
+ T& operator()(absl::Span<const int64> indexes) {
return values_[calculate_index(indexes)];
}
@@ -438,8 +434,8 @@ class Array {
bool operator!=(const Array<T>& other) const { return !(*this == other); }
// Performs the equivalent of a slice operation on this array.
- Array<T> Slice(tensorflow::gtl::ArraySlice<int64> starts,
- tensorflow::gtl::ArraySlice<int64> limits) const {
+ Array<T> Slice(absl::Span<const int64> starts,
+ absl::Span<const int64> limits) const {
CHECK_EQ(starts.size(), num_dimensions());
CHECK_EQ(limits.size(), num_dimensions());
@@ -464,7 +460,7 @@ class Array {
// Performs the equivalent of a DynamicUpdateSlice in-place on this array.
void UpdateSlice(const Array<T>& from,
- tensorflow::gtl::ArraySlice<int64> start_indices) {
+ absl::Span<const int64> start_indices) {
CHECK_EQ(from.num_dimensions(), num_dimensions());
std::vector<int64> limit_indices;
std::transform(start_indices.begin(), start_indices.end(),
@@ -484,7 +480,7 @@ class Array {
// Performs an in-place reshape, modifying the dimensions but not the
// underlying data.
- void Reshape(tensorflow::gtl::ArraySlice<int64> new_dimensions) {
+ void Reshape(absl::Span<const int64> new_dimensions) {
int64 old_num_elements = num_elements();
sizes_ = std::vector<int64>(new_dimensions.begin(), new_dimensions.end());
CHECK_EQ(num_elements(), old_num_elements);
@@ -507,9 +503,7 @@ class Array {
}
}
- pieces.push_back(
- tensorflow::strings::AlphaNum(values_[calculate_index(index)])
- .data());
+ pieces.push_back(absl::StrCat(values_[calculate_index(index)]));
// Emit comma if it isn't the last element
if (index.back() != sizes_.back() - 1) {
@@ -527,7 +521,7 @@ class Array {
}
}
} while (next_index(&index));
- return tensorflow::str_util::Join(pieces, "");
+ return absl::StrJoin(pieces, "");
}
private:
diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h
index a17e81f448..782c966b4c 100644
--- a/tensorflow/compiler/xla/array2d.h
+++ b/tensorflow/compiler/xla/array2d.h
@@ -24,12 +24,11 @@ limitations under the License.
#include <random>
#include <vector>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/array.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/bits.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -101,7 +100,7 @@ class Array2D : public Array<T> {
template <typename NativeT = float>
std::unique_ptr<Array2D<NativeT>> MakeLinspaceArray2D(double from, double to,
int64 n1, int64 n2) {
- auto array = MakeUnique<Array2D<NativeT>>(n1, n2);
+ auto array = absl::make_unique<Array2D<NativeT>>(n1, n2);
int64 count = n1 * n2;
NativeT step =
static_cast<NativeT>((count > 1) ? (to - from) / (count - 1) : 0);
diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h
index a75fffc605..e23d317baf 100644
--- a/tensorflow/compiler/xla/array4d.h
+++ b/tensorflow/compiler/xla/array4d.h
@@ -26,13 +26,11 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/array4d_test.cc b/tensorflow/compiler/xla/array4d_test.cc
index 927733ea1e..918872a7a0 100644
--- a/tensorflow/compiler/xla/array4d_test.cc
+++ b/tensorflow/compiler/xla/array4d_test.cc
@@ -18,8 +18,8 @@ limitations under the License.
#include <initializer_list>
#include <numeric>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace {
@@ -27,8 +27,7 @@ namespace {
// Given an Array4D and a 4-tuple index, computes the linear index into the
// array idx represents.
template <typename T>
-int64 Array4DLinearIndex(const Array4D<T>& arr,
- tensorflow::gtl::ArraySlice<int64> idx) {
+int64 Array4DLinearIndex(const Array4D<T>& arr, absl::Span<const int64> idx) {
EXPECT_EQ(4, idx.size());
return (idx[3] + idx[2] * arr.n4() + idx[1] * arr.n3() * arr.n4() +
idx[0] * arr.n2() * arr.n3() * arr.n4());
@@ -51,9 +50,8 @@ TEST(Array4dTest, FillCtor) {
EXPECT_EQ(fullof7.n3(), 4);
EXPECT_EQ(fullof7.n4(), 5);
- fullof7.Each([](tensorflow::gtl::ArraySlice<int64> idx, int* cell) {
- EXPECT_EQ(*cell, 7);
- });
+ fullof7.Each(
+ [](absl::Span<const int64> idx, int* cell) { EXPECT_EQ(*cell, 7); });
}
TEST(Array4dTest, ContainerCtor) {
@@ -69,7 +67,7 @@ TEST(Array4dTest, ContainerCtor) {
EXPECT_EQ(arr.n3(), 4);
EXPECT_EQ(arr.n4(), 5);
- arr.Each([&arr](tensorflow::gtl::ArraySlice<int64> idx, int* cell) {
+ arr.Each([&arr](absl::Span<const int64> idx, int* cell) {
EXPECT_EQ(*cell, Array4DLinearIndex(arr, idx));
});
}
@@ -129,21 +127,19 @@ TEST(Array3dTest, InitializerListCtorHalf) {
TEST(Array4dTest, Fill) {
Array4D<int> fullof7(2, 3, 4, 5, 7);
- fullof7.Each([](tensorflow::gtl::ArraySlice<int64> idx, int* cell) {
- EXPECT_EQ(*cell, 7);
- });
+ fullof7.Each(
+ [](absl::Span<const int64> idx, int* cell) { EXPECT_EQ(*cell, 7); });
fullof7.Fill(11);
- fullof7.Each([](tensorflow::gtl::ArraySlice<int64> idx, int* cell) {
- EXPECT_EQ(*cell, 11);
- });
+ fullof7.Each(
+ [](absl::Span<const int64> idx, int* cell) { EXPECT_EQ(*cell, 11); });
}
TEST(Array4dTest, FillWithMultiples) {
Array4D<float> arr(2, 3, 4, 5);
arr.FillWithMultiples(2.0f);
- arr.Each([&arr](tensorflow::gtl::ArraySlice<int64> idx, float* cell) {
+ arr.Each([&arr](absl::Span<const int64> idx, float* cell) {
EXPECT_EQ(*cell, 2.0f * Array4DLinearIndex(arr, idx));
});
}
diff --git a/tensorflow/compiler/xla/array_test.cc b/tensorflow/compiler/xla/array_test.cc
index e8356c9832..2d0ac98bd4 100644
--- a/tensorflow/compiler/xla/array_test.cc
+++ b/tensorflow/compiler/xla/array_test.cc
@@ -163,7 +163,7 @@ TEST(ArrayTest, Each) {
arr.FillWithMultiples(1);
int64 each_count = 0, each_sum = 0;
- arr.Each([&](tensorflow::gtl::ArraySlice<int64> idx, int cell) {
+ arr.Each([&](absl::Span<const int64> idx, int cell) {
int64 lin_idx = idx[0] * 12 + idx[1] * 4 + idx[2];
EXPECT_EQ(lin_idx, cell);
each_count++;
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index ad3fcee05b..f825f67b44 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -45,6 +45,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -71,12 +72,14 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -90,6 +93,9 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -104,7 +110,6 @@ cc_library(
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:compiler",
@@ -115,8 +120,9 @@ cc_library(
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:source_map_util",
"//tensorflow/compiler/xla/service:stream_pool",
- "//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
"@llvm//:support",
],
)
@@ -130,11 +136,11 @@ cc_library(
":xla_computation",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:compile_only_service",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
"@llvm//:support",
],
)
@@ -159,6 +165,7 @@ cc_library(
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
)
@@ -186,6 +193,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_proto",
+ "@com_google_absl//absl/memory",
],
)
@@ -211,6 +219,10 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:shape_inference",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index d0ce5e8a6a..5dde5b432f 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -18,15 +18,15 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
@@ -37,8 +37,8 @@ Client::Client(ServiceInterface* stub) : stub_(stub) {}
Client::~Client() = default;
-StatusOr<std::unique_ptr<Literal>> Client::Transfer(
- const GlobalData& data, const Shape* shape_with_layout) {
+StatusOr<Literal> Client::Transfer(const GlobalData& data,
+ const Shape* shape_with_layout) {
TransferToClientRequest request;
*request.mutable_data() = data.handle();
if (shape_with_layout != nullptr) {
@@ -89,7 +89,7 @@ StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
"TransferToServer request");
}
- return MakeUnique<GlobalData>(stub_, response.data());
+ return absl::make_unique<GlobalData>(stub_, response.data());
}
Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id,
@@ -114,7 +114,7 @@ Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id,
return Status::OK();
}
-StatusOr<std::unique_ptr<Literal>> Client::TransferFromOutfeed(
+StatusOr<Literal> Client::TransferFromOutfeed(
const Shape* shape_with_layout, int64 replica_id,
const DeviceHandle* device_handle) {
TransferFromOutfeedRequest request;
@@ -162,9 +162,8 @@ Status Client::ResetDevice() {
return Status::OK();
}
-StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+StatusOr<Literal> Client::ExecuteAndTransfer(
+ const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options,
ExecutionProfile* execution_profile) {
TF_ASSIGN_OR_RETURN(
@@ -178,8 +177,8 @@ StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
return Transfer(*data, shape_with_output_layout);
}
-StatusOr<std::unique_ptr<Literal>> Client::ComputeConstant(
- const XlaComputation& computation, const Layout* output_layout) const {
+StatusOr<Literal> Client::ComputeConstant(const XlaComputation& computation,
+ const Layout* output_layout) const {
ComputeConstantGraphRequest request;
*request.mutable_computation() = computation.proto();
if (output_layout != nullptr) {
@@ -212,8 +211,7 @@ StatusOr<XlaComputation> Client::LoadSnapshot(const HloSnapshot& module) {
}
StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options,
ExecutionProfile* execution_profile) {
ExecuteGraphRequest request;
@@ -248,11 +246,11 @@ StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
}
}
- return MakeUnique<GlobalData>(stub_, response.output());
+ return absl::make_unique<GlobalData>(stub_, response.output());
}
StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
- tensorflow::gtl::ArraySlice<XlaComputationInstance> computations) {
+ absl::Span<const XlaComputationInstance> computations) {
ExecuteGraphParallelRequest request;
for (const XlaComputationInstance& computation : computations) {
@@ -278,7 +276,7 @@ StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
std::vector<std::unique_ptr<GlobalData>> outputs;
for (size_t i = 0; i < computations.size(); ++i) {
outputs.push_back(
- MakeUnique<GlobalData>(stub_, response.responses(i).output()));
+ absl::make_unique<GlobalData>(stub_, response.responses(i).output()));
if (computations[i].execution_profile != nullptr) {
*computations[i].execution_profile = response.responses(i).profile();
}
@@ -340,7 +338,7 @@ StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::DeconstructTuple(
std::vector<std::unique_ptr<GlobalData>> handles;
for (auto& handle : response.element_handles()) {
- handles.push_back(MakeUnique<GlobalData>(stub_, handle));
+ handles.push_back(absl::make_unique<GlobalData>(stub_, handle));
}
return std::move(handles);
}
@@ -369,7 +367,7 @@ StatusOr<ComputationStats> Client::GetComputationStats(
StatusOr<std::unique_ptr<ProgramShape>> Client::GetComputationShape(
const XlaComputation& computation) {
TF_ASSIGN_OR_RETURN(const auto& result, computation.GetProgramShape());
- return MakeUnique<ProgramShape>(result);
+ return absl::make_unique<ProgramShape>(result);
}
StatusOr<Shape> Client::GetShape(const GlobalData& data) {
@@ -400,7 +398,7 @@ StatusOr<string> Client::ExecutionStatsAsString(
int64 nanoseconds = profile.compute_time_ns();
int64 cycle_count = profile.compute_cycle_count();
double gflops = total_flops / nanoseconds;
- return tensorflow::strings::StrCat(
+ return absl::StrCat(
"[Execution Statistics] flop count: ", computation_stats.flop_count(),
", transcendental count: ", computation_stats.transcendental_count(),
", compute execution time: ", nanoseconds, " nsec",
diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h
index be50cebfcc..6f4d33c469 100644
--- a/tensorflow/compiler/xla/client/client.h
+++ b/tensorflow/compiler/xla/client/client.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
@@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
@@ -53,7 +53,7 @@ class Client {
// will be filled with profile data from the execution.
StatusOr<std::unique_ptr<GlobalData>> Execute(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options = nullptr,
ExecutionProfile* execution_profile = nullptr);
@@ -82,7 +82,7 @@ class Client {
// from each computation.
//
StatusOr<std::vector<std::unique_ptr<GlobalData>>> ExecuteParallel(
- tensorflow::gtl::ArraySlice<XlaComputationInstance> computations);
+ absl::Span<const XlaComputationInstance> computations);
// Requests device_count device handles available on the target. The returned
// device handles are used to specify the devices to execute the computations
@@ -96,8 +96,8 @@ class Client {
//
// If shape_with_layout is not nullptr, it points to a shape whose layout will
// be the layout of the returned literal.
- StatusOr<std::unique_ptr<Literal>> Transfer(
- const GlobalData& data, const Shape* shape_with_layout = nullptr);
+ StatusOr<Literal> Transfer(const GlobalData& data,
+ const Shape* shape_with_layout = nullptr);
// Transfer the given literal to the server. This allocates memory on the
// device and copies the literal's contents over. Returns a global data handle
@@ -122,7 +122,7 @@ class Client {
// device_handle and replica_id together specify a particular device; a device
// assigned for the given replica_id among the replicas that the given device
// handle belongs to.
- StatusOr<std::unique_ptr<Literal>> TransferFromOutfeed(
+ StatusOr<Literal> TransferFromOutfeed(
const Shape* shape_with_layout, int64 replica_id = 0,
const DeviceHandle* device_handle = nullptr);
@@ -132,9 +132,9 @@ class Client {
// Executes the computation with the given arguments and transfers the result
// to the client as a literal. Parameters are defined the same as for
// Execute() and Transfer().
- StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
+ StatusOr<Literal> ExecuteAndTransfer(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options = nullptr,
ExecutionProfile* execution_profile = nullptr);
@@ -153,7 +153,7 @@ class Client {
//
// If output_layout is non-null, then the output of the computation will be
// stored using that layout.
- StatusOr<std::unique_ptr<Literal>> ComputeConstant(
+ StatusOr<Literal> ComputeConstant(
const XlaComputation& computation,
const Layout* output_layout = nullptr) const;
diff --git a/tensorflow/compiler/xla/client/client_library.cc b/tensorflow/compiler/xla/client/client_library.cc
index 803a9e4009..27b7fa7b29 100644
--- a/tensorflow/compiler/xla/client/client_library.cc
+++ b/tensorflow/compiler/xla/client/client_library.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -94,10 +95,10 @@ ClientLibrary::~ClientLibrary() = default;
service_options.set_intra_op_parallelism_threads(
options.intra_op_parallelism_threads());
- auto instance = MakeUnique<LocalInstance>();
+ auto instance = absl::make_unique<LocalInstance>();
TF_ASSIGN_OR_RETURN(instance->service,
LocalService::NewService(service_options));
- instance->client = MakeUnique<LocalClient>(instance->service.get());
+ instance->client = absl::make_unique<LocalClient>(instance->service.get());
LocalClient* cl = instance->client.get();
client_library.local_instances_.insert(
@@ -134,10 +135,11 @@ ClientLibrary::GetOrCreateCompileOnlyClient(se::Platform* platform) {
return it->second->client.get();
}
- auto instance = MakeUnique<CompileOnlyInstance>();
+ auto instance = absl::make_unique<CompileOnlyInstance>();
TF_ASSIGN_OR_RETURN(instance->service,
CompileOnlyService::NewService(platform));
- instance->client = MakeUnique<CompileOnlyClient>(instance->service.get());
+ instance->client =
+ absl::make_unique<CompileOnlyClient>(instance->service.get());
CompileOnlyClient* cl = instance->client.get();
client_library.compile_only_instances_.insert(
diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc
index 5c9abad4c3..a6c58cb175 100644
--- a/tensorflow/compiler/xla/client/compile_only_client.cc
+++ b/tensorflow/compiler/xla/client/compile_only_client.cc
@@ -15,15 +15,15 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/compile_only_client.h"
+#include "absl/memory/memory.h"
#include "llvm/ADT/Triple.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace xla {
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileOnlyClient::CompileAheadOfTime(
- const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
+ const absl::Span<const AotXlaComputationInstance> computations,
const AotCompilationOptions& options,
std::unique_ptr<AotCompilationMetadata>* metadata) {
std::vector<CompileOnlyService::AotXlaComputationInstance> service_instances;
@@ -41,7 +41,7 @@ CompileOnlyClient::CompileAheadOfTime(
metadata);
}
-int64 CompileOnlyClient::PointerSizeForTriple(tensorflow::StringPiece triple) {
+int64 CompileOnlyClient::PointerSizeForTriple(absl::string_view triple) {
llvm::Triple llvm_triple(
llvm::Triple::normalize(llvm::StringRef(triple.data(), triple.size())));
if (llvm_triple.isArch64Bit()) {
diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h
index a551edeab0..9e3ed23734 100644
--- a/tensorflow/compiler/xla/client/compile_only_client.h
+++ b/tensorflow/compiler/xla/client/compile_only_client.h
@@ -52,12 +52,12 @@ class CompileOnlyClient : public Client {
// code. |metadata|, if provided, is populated during compilation.
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(
- const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
+ const absl::Span<const AotXlaComputationInstance> computations,
const AotCompilationOptions& options,
std::unique_ptr<AotCompilationMetadata>* metadata = nullptr);
// Returns the size of a pointer in bytes for a given triple.
- static int64 PointerSizeForTriple(tensorflow::StringPiece triple);
+ static int64 PointerSizeForTriple(absl::string_view triple);
private:
CompileOnlyService* compiler_service_;
diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc
index 7dee41f6a0..0f1745366b 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.cc
+++ b/tensorflow/compiler/xla/client/executable_build_options.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/executable_build_options.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
@@ -59,10 +59,10 @@ string ExecutableBuildOptions::ToString() const {
if (generate_hlo_graph_.has_value()) {
generate_hlo_graph = generate_hlo_graph_.value();
}
- return tensorflow::strings::Printf(
+ return absl::StrFormat(
"ExecutableBuildOptions{device_ordinal=%d, result_layout=%s, "
"generate_hlo_graph=%s}",
- device_ordinal_, result_layout.c_str(), generate_hlo_graph.c_str());
+ device_ordinal_, result_layout, generate_hlo_graph);
}
ExecutableBuildOptions& ExecutableBuildOptions::set_generate_hlo_graph(
@@ -71,41 +71,41 @@ ExecutableBuildOptions& ExecutableBuildOptions::set_generate_hlo_graph(
return *this;
}
-const tensorflow::gtl::optional<string>&
-ExecutableBuildOptions::generate_hlo_graph() const {
+const absl::optional<string>& ExecutableBuildOptions::generate_hlo_graph()
+ const {
return generate_hlo_graph_;
}
ExecutableBuildOptions& ExecutableBuildOptions::set_dump_optimized_hlo_proto_to(
- tensorflow::StringPiece dirpath) {
- dump_optimized_hlo_proto_to_ = dirpath.ToString();
+ absl::string_view dirpath) {
+ dump_optimized_hlo_proto_to_ = string(dirpath);
return *this;
}
-const tensorflow::gtl::optional<string>&
+const absl::optional<string>&
ExecutableBuildOptions::dump_optimized_hlo_proto_to() const {
return dump_optimized_hlo_proto_to_;
}
ExecutableBuildOptions&
ExecutableBuildOptions::set_dump_unoptimized_hlo_proto_to(
- tensorflow::StringPiece dirpath) {
- dump_unoptimized_hlo_proto_to_ = dirpath.ToString();
+ absl::string_view dirpath) {
+ dump_unoptimized_hlo_proto_to_ = string(dirpath);
return *this;
}
-const tensorflow::gtl::optional<string>&
+const absl::optional<string>&
ExecutableBuildOptions::dump_unoptimized_hlo_proto_to() const {
return dump_unoptimized_hlo_proto_to_;
}
ExecutableBuildOptions& ExecutableBuildOptions::set_dump_per_pass_hlo_proto_to(
- tensorflow::StringPiece dirpath) {
- dump_per_pass_hlo_proto_to_ = dirpath.ToString();
+ absl::string_view dirpath) {
+ dump_per_pass_hlo_proto_to_ = string(dirpath);
return *this;
}
-const tensorflow::gtl::optional<string>&
+const absl::optional<string>&
ExecutableBuildOptions::dump_per_pass_hlo_proto_to() const {
return dump_per_pass_hlo_proto_to_;
}
@@ -115,7 +115,7 @@ ExecutableBuildOptions& ExecutableBuildOptions::set_hlo_profile(bool enabled) {
return *this;
}
-tensorflow::gtl::optional<bool> ExecutableBuildOptions::hlo_profile() const {
+absl::optional<bool> ExecutableBuildOptions::hlo_profile() const {
return hlo_profile_;
}
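
Note: unlike tensorflow::strings::Printf, absl::StrFormat is type-checked at compile time and its %s conversion binds to std::string and absl::string_view directly, which is why the .c_str() calls are dropped above. A minimal sketch (hypothetical names):

#include <string>
#include "absl/strings/str_format.h"

std::string Describe(int ordinal, const std::string& layout) {
  // %d consumes the int, %s the std::string -- no c_str() needed.
  return absl::StrFormat("device_ordinal=%d, result_layout=%s", ordinal, layout);
}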
diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h
index 9dc9be4423..93334db88b 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.h
+++ b/tensorflow/compiler/xla/client/executable_build_options.h
@@ -16,11 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/optional.h"
namespace xla {
@@ -57,37 +57,36 @@ class ExecutableBuildOptions {
// If set, specifies a regexp of HLO graphs to dump (as in DebugOptions).
ExecutableBuildOptions& set_generate_hlo_graph(string regex);
- const tensorflow::gtl::optional<string>& generate_hlo_graph() const;
+ const absl::optional<string>& generate_hlo_graph() const;
// If set, specifies a dirpath to dump the end-of-optimization-pipeline HLO
// protobuf to (as in DebugOptions).
ExecutableBuildOptions& set_dump_optimized_hlo_proto_to(
- tensorflow::StringPiece dirpath);
- const tensorflow::gtl::optional<string>& dump_optimized_hlo_proto_to() const;
+ absl::string_view dirpath);
+ const absl::optional<string>& dump_optimized_hlo_proto_to() const;
// If set, specifies a dirpath to dump the start-of-optimization-pipeline HLO
// protobuf to (as in DebugOptions).
ExecutableBuildOptions& set_dump_unoptimized_hlo_proto_to(
- tensorflow::StringPiece dirpath);
- const tensorflow::gtl::optional<string>& dump_unoptimized_hlo_proto_to()
- const;
+ absl::string_view dirpath);
+ const absl::optional<string>& dump_unoptimized_hlo_proto_to() const;
// If set, specifies a dirpath to dump the per-pass-in-pipeline HLO protobufs
// to (as in DebugOptions).
ExecutableBuildOptions& set_dump_per_pass_hlo_proto_to(
- tensorflow::StringPiece dirpath);
- const tensorflow::gtl::optional<string>& dump_per_pass_hlo_proto_to() const;
+ absl::string_view dirpath);
+ const absl::optional<string>& dump_per_pass_hlo_proto_to() const;
// If true, specifies that we should record an HLO profile during execution
// and log it after execution (as in DebugOptions). If nullopt the default is
// used.
ExecutableBuildOptions& set_hlo_profile(bool enabled);
- tensorflow::gtl::optional<bool> hlo_profile() const;
+ absl::optional<bool> hlo_profile() const;
- void add_disabled_hlo_pass(tensorflow::StringPiece pass_name) {
+ void add_disabled_hlo_pass(absl::string_view pass_name) {
disabled_hlo_passes_.push_back(std::string(pass_name));
}
- const tensorflow::gtl::ArraySlice<std::string> disabled_hlo_passes() const {
+ const absl::Span<const std::string> disabled_hlo_passes() const {
return disabled_hlo_passes_;
}
@@ -96,14 +95,14 @@ class ExecutableBuildOptions {
string ToString() const;
private:
- tensorflow::gtl::optional<bool> hlo_profile_;
+ absl::optional<bool> hlo_profile_;
int device_ordinal_ = -1;
Shape result_layout_;
bool result_layout_set_ = false;
- tensorflow::gtl::optional<string> generate_hlo_graph_;
- tensorflow::gtl::optional<string> dump_optimized_hlo_proto_to_;
- tensorflow::gtl::optional<string> dump_unoptimized_hlo_proto_to_;
- tensorflow::gtl::optional<string> dump_per_pass_hlo_proto_to_;
+ absl::optional<string> generate_hlo_graph_;
+ absl::optional<string> dump_optimized_hlo_proto_to_;
+ absl::optional<string> dump_unoptimized_hlo_proto_to_;
+ absl::optional<string> dump_per_pass_hlo_proto_to_;
DeviceMemoryAllocator* device_allocator_ = nullptr;
std::vector<std::string> disabled_hlo_passes_;
};
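
Note: absl::optional mirrors std::optional, so the setter/getter pattern above keeps its gtl::optional semantics; the only call-site change is that absl::string_view has no ToString(), hence the explicit string(dirpath) construction. A minimal sketch of the pattern (hypothetical class):

#include <string>
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"

class Options {
 public:
  Options& set_dump_dir(absl::string_view dirpath) {
    dump_dir_ = std::string(dirpath);  // explicit copy; string_view has no ToString()
    return *this;
  }
  const absl::optional<std::string>& dump_dir() const { return dump_dir_; }

 private:
  absl::optional<std::string> dump_dir_;  // remains nullopt until set
};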
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index a2f32ab97e..a18c94c4e6 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -31,7 +31,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -65,6 +65,17 @@ xla_test(
)
cc_library(
+ name = "conv_grad_size_util",
+ srcs = ["conv_grad_size_util.cc"],
+ hdrs = ["conv_grad_size_util.h"],
+ deps = [
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "math",
srcs = ["math.cc"],
hdrs = ["math.h"],
@@ -102,7 +113,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -128,9 +139,9 @@ cc_library(
deps = [
":arithmetic",
":constants",
- "//tensorflow/compiler/tf2xla/lib:util",
+ ":conv_grad_size_util",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/container:inlined_vector",
],
)
@@ -142,6 +153,7 @@ xla_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/container:inlined_vector",
],
)
@@ -209,5 +221,6 @@ cc_library(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc
index 9225b1acd6..e86c10f030 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.cc
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <string>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
@@ -24,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace {
@@ -39,7 +39,7 @@ XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
b = builder->CreateSubBuilder(name);
} else {
b = builder->CreateSubBuilder(
- tensorflow::strings::StrCat(name, "_", PrimitiveType_Name(type)));
+ absl::StrCat(name, "_", PrimitiveType_Name(type)));
}
const Shape scalar = ShapeUtil::MakeShape(type, {});
diff --git a/tensorflow/compiler/xla/client/lib/constants.cc b/tensorflow/compiler/xla/client/lib/constants.cc
index 031d62e4ff..1ada7b4a96 100644
--- a/tensorflow/compiler/xla/client/lib/constants.cc
+++ b/tensorflow/compiler/xla/client/lib/constants.cc
@@ -56,7 +56,7 @@ XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) {
std::numeric_limits<double>::epsilon());
default:
return builder->ReportError(InvalidArgument(
- "Invalid type for Epsilon (%s).", PrimitiveType_Name(type).c_str()));
+ "Invalid type for Epsilon (%s).", PrimitiveType_Name(type)));
}
}
diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h
index 0c8a9b8cc0..81624614c1 100644
--- a/tensorflow/compiler/xla/client/lib/constants.h
+++ b/tensorflow/compiler/xla/client/lib/constants.h
@@ -37,13 +37,13 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) {
primitive_util::IsComplexType(type))) {
return builder->ReportError(InvalidArgument(
"Invalid cast from floating point type to %s in ConstantR0WithType.",
- PrimitiveType_Name(type).c_str()));
+ PrimitiveType_Name(type)));
}
if (std::is_same<T, complex64>::value &&
!primitive_util::IsComplexType(type)) {
return builder->ReportError(InvalidArgument(
"Invalid cast from complex type to %s in ConstantR0WithType.",
- PrimitiveType_Name(type).c_str()));
+ PrimitiveType_Name(type)));
}
switch (type) {
case F16:
@@ -71,7 +71,7 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) {
default:
return builder->ReportError(
InvalidArgument("Invalid type for ConstantR0WithType (%s).",
- PrimitiveType_Name(type).c_str()));
+ PrimitiveType_Name(type)));
}
}
diff --git a/tensorflow/compiler/xla/client/lib/conv_grad_size_util.cc b/tensorflow/compiler/xla/client/lib/conv_grad_size_util.cc
new file mode 100644
index 0000000000..a4c50a5491
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/conv_grad_size_util.cc
@@ -0,0 +1,96 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+
+namespace {
+
+StatusOr<SpatialDimensionOutputSizeAndPadding> GetWindowedOutputSize(
+ int64 input_size, int64 filter_size, int64 dilation_rate, int64 stride,
+ Padding padding_type) {
+ if (stride <= 0) {
+ return tensorflow::errors::InvalidArgument("Stride must be > 0, but got ",
+ stride);
+ }
+ if (dilation_rate < 1) {
+ return tensorflow::errors::InvalidArgument(
+ "Dilation rate must be >= 1, but got ", dilation_rate);
+ }
+
+ int64 effective_filter_size = (filter_size - 1) * dilation_rate + 1;
+ SpatialDimensionOutputSizeAndPadding dim;
+ switch (padding_type) {
+ case Padding::kValid:
+ dim.output_size = (input_size - effective_filter_size + stride) / stride;
+ dim.pad_before = dim.pad_after = 0;
+ break;
+ case Padding::kSame:
+ dim.output_size = (input_size + stride - 1) / stride;
+ const int64 padding_needed =
+ std::max(int64{0}, (dim.output_size - 1) * stride +
+ effective_filter_size - input_size);
+ // For odd values of total padding, add more padding on the "after" side
+ // of the given dimension.
+ dim.pad_before = padding_needed / 2;
+ dim.pad_after = padding_needed - dim.pad_before;
+ break;
+ }
+ if (dim.output_size < 0) {
+ return tensorflow::errors::InvalidArgument(
+ "Computed output size would be negative: ", dim.output_size,
+ " [input_size: ", input_size,
+ ", effective_filter_size: ", effective_filter_size,
+ ", stride: ", stride, "]");
+ }
+ return dim;
+}
+
+} // namespace
+
+StatusOr<SpatialDimensionOutputSizeAndPadding>
+ConvGradExtractAndVerifyDimension(int64 input_size, int64 filter_size,
+ int64 output_size, int64 dilation,
+ int64 stride, Padding padding) {
+ TF_ASSIGN_OR_RETURN(SpatialDimensionOutputSizeAndPadding output_dim,
+ GetWindowedOutputSize(input_size, filter_size, dilation,
+ stride, padding));
+ if (output_size != output_dim.output_size) {
+ return tensorflow::errors::InvalidArgument(
+ "Size of out_backprop doesn't match computed: ", "actual = ",
+ output_size, ", computed = ", output_dim.output_size,
+ " input: ", input_size, " filter: ", filter_size,
+ " output: ", output_size, " stride: ", stride, " dilation: ", dilation);
+ }
+
+ SpatialDimensionOutputSizeAndPadding dim;
+ int64 effective_filter_size = (filter_size - 1) * dilation + 1;
+ dim.output_size = (output_dim.output_size - 1) * stride + 1;
+ const auto padded_out_size = input_size + effective_filter_size - 1;
+ dim.pad_before = effective_filter_size - 1 - output_dim.pad_before;
+ dim.pad_after = padded_out_size - dim.output_size - dim.pad_before;
+ VLOG(2) << "expanded_out = " << dim.output_size
+ << ", effective_filter_size = " << effective_filter_size
+ << ", padded_out = " << padded_out_size
+ << ", pad_before = " << dim.pad_before
+ << ", pad_after = " << dim.pad_after << ", dilation = " << dilation
+ << ", strides = " << stride;
+ return dim;
+}
+
+} // namespace xla
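
Note: to make the kValid/kSame arithmetic above concrete: with input_size=5, filter_size=3, dilation=1, stride=2, kValid gives output (5-3+2)/2 = 2 with no padding, while kSame gives output ceil(5/2) = 3 with total padding max(0, (3-1)*2+3-5) = 2, split 1 before and 1 after. A standalone check of the formulas (mirroring, not calling, the XLA helper):

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  const int64_t input = 5, filter = 3, dilation = 1, stride = 2;
  const int64_t effective_filter = (filter - 1) * dilation + 1;       // 3
  assert((input - effective_filter + stride) / stride == 2);          // kValid
  const int64_t out = (input + stride - 1) / stride;                  // kSame: 3
  const int64_t needed =
      std::max<int64_t>(0, (out - 1) * stride + effective_filter - input);
  assert(needed == 2 && needed / 2 == 1 && needed - needed / 2 == 1); // pads (1,1)
}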
diff --git a/tensorflow/compiler/xla/client/lib/conv_grad_size_util.h b/tensorflow/compiler/xla/client/lib/conv_grad_size_util.h
new file mode 100644
index 0000000000..0ad01728e6
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/conv_grad_size_util.h
@@ -0,0 +1,44 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONV_GRAD_SIZE_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONV_GRAD_SIZE_UTIL_H_
+
+#include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Information about a single spatial dimension for convolution gradients and
+// windowed operations.
+struct SpatialDimensionOutputSizeAndPadding {
+ // Effective size of the operation output (potentially expanded).
+ int64 output_size;
+ // Number of padding elements to be added before/after this dimension of
+ // the input when computing the input gradient.
+ int64 pad_before;
+ int64 pad_after;
+};
+
+// Verifies that the dimensions all match, and computes the size and padding of
+// a spatial dimension for convolution gradient operations.
+StatusOr<SpatialDimensionOutputSizeAndPadding>
+ConvGradExtractAndVerifyDimension(int64 input_size, int64 filter_size,
+ int64 output_size, int64 dilation,
+ int64 stride, Padding padding);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONV_GRAD_SIZE_UTIL_H_
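
Note: continuing that example through ConvGradExtractAndVerifyDimension's math (input 5, filter 3, stride 2; forward output 3 with pad_before 1): the output is expanded to (3-1)*2+1 = 5, padded_out = 5+3-1 = 7, so the backprop pads become pad_before = 3-1-1 = 1 and pad_after = 7-5-1 = 1. A standalone check mirroring the .cc above:

#include <cassert>
#include <cstdint>

int main() {
  const int64_t input = 5, filter = 3, stride = 2;
  const int64_t fwd_output = 3, fwd_pad_before = 1;  // from GetWindowedOutputSize
  const int64_t effective_filter = filter;           // dilation == 1
  const int64_t expanded_out = (fwd_output - 1) * stride + 1;        // 5
  const int64_t padded_out = input + effective_filter - 1;           // 7
  const int64_t pad_before = effective_filter - 1 - fwd_pad_before;  // 1
  const int64_t pad_after = padded_out - expanded_out - pad_before;  // 1
  assert(expanded_out == 5 && pad_before == 1 && pad_after == 1);
}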
diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc
index 0221de7672..d3d7edb42a 100644
--- a/tensorflow/compiler/xla/client/lib/math.cc
+++ b/tensorflow/compiler/xla/client/lib/math.cc
@@ -69,8 +69,7 @@ std::array<float, 6> kErfUCoefficient = {
// Evaluate the polynomial given coefficients and `x`.
// N.B. Coefficients should be supplied in decreasing order.
-XlaOp EvaluatePolynomial(XlaOp x,
- tensorflow::gtl::ArraySlice<float> coefficients) {
+XlaOp EvaluatePolynomial(XlaOp x, absl::Span<const float> coefficients) {
XlaOp poly = ScalarLike(x, 0.0);
for (float c : coefficients) {
poly = poly * x + ScalarLike(x, c);
@@ -207,7 +206,11 @@ XlaOp Lgamma(XlaOp input) {
XlaOp log_y = log_sqrt_two_pi + (z + one_half) * log_t - t + Log(x);
- XlaOp reflection = log_pi - Log(Sin(pi * input)) - log_y;
+  // If z = a + 0j is real, the real part of log(z) reduces to the log of
+  // the absolute value of a:
+  //   Re(log(z)) = Re(log|z| + arg(z)j)
+  //              = log|a|
+ XlaOp reflection = log_pi - Log(Abs(Sin(pi * input))) - log_y;
XlaOp result = Select(need_to_reflect, reflection, log_y);
return result;
}
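
Note: the Abs() matters because for non-integer negative inputs sin(pi*x) can be negative, and a real-typed Log of a negative value is NaN. A host-side check of the reflection identity, using std::lgamma (which likewise returns log|Gamma(x)|):

#include <cassert>
#include <cmath>

int main() {
  const double x = -2.5;
  const double pi = std::acos(-1.0);
  // log|Gamma(x)| = log(pi) - log|sin(pi*x)| - log|Gamma(1 - x)|
  const double reflected =
      std::log(pi) - std::log(std::abs(std::sin(pi * x))) - std::lgamma(1.0 - x);
  assert(std::abs(reflected - std::lgamma(x)) < 1e-9);
}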
diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h
index 13db232556..a6cafd4207 100644
--- a/tensorflow/compiler/xla/client/lib/math.h
+++ b/tensorflow/compiler/xla/client/lib/math.h
@@ -34,8 +34,7 @@ XlaOp Reciprocal(XlaOp operand);
// Evaluates a polynomial given coefficients and `x`.
// N.B. Coefficients should be supplied in decreasing order.
-XlaOp EvaluatePolynomial(XlaOp x,
- tensorflow::gtl::ArraySlice<float> coefficients);
+XlaOp EvaluatePolynomial(XlaOp x, absl::Span<const float> coefficients);
// Computes an approximation of the error function complement (1 - erf(x)).
XlaOp Erfc(XlaOp x);
diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc
index 1c91237ae1..377654220b 100644
--- a/tensorflow/compiler/xla/client/lib/numeric.cc
+++ b/tensorflow/compiler/xla/client/lib/numeric.cc
@@ -16,61 +16,13 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
-namespace {
-
-template <typename T>
-XlaOp MakeIota(XlaBuilder* builder, int64 size) {
- std::vector<T> values(size);
- for (int64 i = 0; i < size; ++i) {
- values[i] = static_cast<T>(i);
- }
- return ConstantR1<T>(builder, values);
-}
-
-} // namespace
-
-XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) {
- switch (type) {
- case S8:
- return MakeIota<int8>(builder, size);
- case S16:
- return MakeIota<int16>(builder, size);
- case S32:
- return MakeIota<int32>(builder, size);
- case S64:
- return MakeIota<int64>(builder, size);
- case U8:
- return MakeIota<uint8>(builder, size);
- case U16:
- return MakeIota<uint16>(builder, size);
- case U32:
- return MakeIota<uint32>(builder, size);
- case U64:
- return MakeIota<uint64>(builder, size);
- case BF16:
- return MakeIota<bfloat16>(builder, size);
- case F16:
- return MakeIota<half>(builder, size);
- case F32:
- return MakeIota<float>(builder, size);
- case F64:
- return MakeIota<double>(builder, size);
- case C64:
- return MakeIota<complex64>(builder, size);
- default:
- return builder->ReportError(
- InvalidArgument("Unimplemented type for Iota: %s.",
- PrimitiveType_Name(type).c_str()));
- }
-}
-
XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m,
int64 n) {
auto a = Iota(builder, type, m);
@@ -87,8 +39,8 @@ XlaOp GetMatrixDiagonal(XlaOp x) {
TF_RET_CHECK(n_dims >= 2);
const int64 m = shape.dimensions(n_dims - 2);
const int64 n = shape.dimensions(n_dims - 1);
- tensorflow::gtl::ArraySlice<int64> major_dims(
- AsInt64Slice(shape.dimensions()), /*pos=*/0, /*len=*/n_dims - 2);
+ absl::Span<const int64> major_dims =
+ AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2);
auto a = Iota(builder, U32, n);
auto b = Iota(builder, U32, m);
auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});
@@ -114,8 +66,8 @@ XlaOp Triangle(XlaOp x, bool lower) {
TF_RET_CHECK(n_dims >= 2);
const int64 m = shape.dimensions(n_dims - 2);
const int64 n = shape.dimensions(n_dims - 1);
- tensorflow::gtl::ArraySlice<int64> major_dims(
- AsInt64Slice(shape.dimensions()), /*pos=*/0, /*len=*/n_dims - 2);
+ absl::Span<const int64> major_dims =
+ AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2);
auto a = Iota(builder, U32, n);
auto b = Iota(builder, U32, m);
xla::XlaOp indicator;
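
Note: the three-argument ArraySlice(slice, pos, len) constructor has no Span counterpart; the idiomatic replacement, used above, is subspan(pos, len), here keeping all but the two minor-most dimensions. A minimal sketch:

#include <cassert>
#include <cstdint>
#include "absl/types/span.h"

int main() {
  const int64_t dims[] = {2, 3, 4, 5};  // e.g. dimensions of a rank-4 shape
  absl::Span<const int64_t> all(dims);
  absl::Span<const int64_t> major = all.subspan(0, all.size() - 2);
  assert(major.size() == 2 && major[0] == 2 && major[1] == 3);
}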
diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/numeric_test.cc
index 8a96ec68d2..7d6aedd494 100644
--- a/tensorflow/compiler/xla/client/lib/numeric_test.cc
+++ b/tensorflow/compiler/xla/client/lib/numeric_test.cc
@@ -30,16 +30,6 @@ class NumericTest : public ClientLibraryTestBase {
void TestMatrixDiagonal();
};
-// TODO(b/64798317): Delete this test case once xla::IotaGen is converted to
-// xla::Iota. This test is already implemented for xla::IotaGen in
-// xla/tests/iota_test.cc.
-XLA_TEST_F(NumericTest, Iota) {
- XlaBuilder builder(TestName());
- Iota(&builder, S32, 10);
-
- ComputeAndCompareR1<int32>(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {});
-}
-
XLA_TEST_F(NumericTest, Triangle) {
XlaBuilder builder(TestName());
Array3D<int32> input(2, 3, 4);
diff --git a/tensorflow/compiler/xla/client/lib/pooling.cc b/tensorflow/compiler/xla/client/lib/pooling.cc
index 7199269a6c..1979c867a4 100644
--- a/tensorflow/compiler/xla/client/lib/pooling.cc
+++ b/tensorflow/compiler/xla/client/lib/pooling.cc
@@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/pooling.h"
-#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h"
namespace xla {
@@ -26,11 +26,9 @@ namespace {
// element of an image by the count of elements that contributed to that
// element during pooling.
XlaOp AvgPoolDivideByCountWithGeneralPadding(
- XlaOp sums, PrimitiveType dtype,
- tensorflow::gtl::ArraySlice<int64> input_shape,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> spatial_padding,
- tensorflow::gtl::ArraySlice<int64> ksize,
- tensorflow::gtl::ArraySlice<int64> stride,
+ XlaOp sums, PrimitiveType dtype, absl::Span<const int64> input_shape,
+ absl::Span<const std::pair<int64, int64>> spatial_padding,
+ absl::Span<const int64> ksize, absl::Span<const int64> stride,
const TensorFormat& data_format) {
// The padding shouldn't be included in the counts. We use another
// ReduceWindow to find the right counts.
@@ -73,8 +71,8 @@ XlaOp AvgPoolDivideByCountWithGeneralPadding(
// Sums all elements in the window specified by 'kernel_size' and 'stride'.
XlaOp ComputeSums(XlaOp operand, XlaOp init_value,
- tensorflow::gtl::ArraySlice<int64> kernel_size,
- tensorflow::gtl::ArraySlice<int64> stride,
+ absl::Span<const int64> kernel_size,
+ absl::Span<const int64> stride,
const TensorFormat& data_format) {
XlaBuilder* b = operand.builder();
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@@ -89,11 +87,9 @@ XlaOp ComputeSums(XlaOp operand, XlaOp init_value,
// Creates a padding configuration out of spatial padding values.
PaddingConfig MakeSpatialPaddingConfig(
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> spatial_padding,
- tensorflow::gtl::ArraySlice<int64> kernel_size,
- tensorflow::gtl::ArraySlice<int64> stride,
+ absl::Span<const std::pair<int64, int64>> spatial_padding,
+ int num_spatial_dims, absl::Span<const int64> stride,
const TensorFormat& data_format) {
- const int num_spatial_dims = kernel_size.size() - 2;
PaddingConfig padding_config;
for (int i = 0; i < 2 + num_spatial_dims; ++i) {
padding_config.add_dimensions();
@@ -109,10 +105,33 @@ PaddingConfig MakeSpatialPaddingConfig(
return padding_config;
}
+XlaOp AvgPoolDivideByCount(XlaOp pooled, absl::Span<const int64> input_size,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ PrimitiveType dtype, const TensorFormat& data_format,
+ bool counts_include_padding) {
+ if (counts_include_padding) {
+ // If counts include padding, all windows have the same number of elements
+ // contributing to each average. Divide by the window size everywhere to get
+ // the average.
+    int64 window_size =
+        std::accumulate(window_dimensions.begin(), window_dimensions.end(),
+                        int64{1}, [](int64 a, int64 b) { return a * b; });
+ auto divisor = ConstantR0WithType(pooled.builder(), dtype, window_size);
+
+ return pooled / divisor;
+ } else {
+ return AvgPoolDivideByCountWithGeneralPadding(pooled, dtype, input_size,
+ padding, window_dimensions,
+ window_strides, data_format);
+ }
+}
+
} // namespace
-XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
- tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
+XlaOp MaxPool(XlaOp operand, absl::Span<const int64> kernel_size,
+ absl::Span<const int64> stride, Padding padding,
const TensorFormat& data_format) {
XlaBuilder* b = operand.builder();
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@@ -125,9 +144,9 @@ XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
});
}
-XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
- tensorflow::gtl::ArraySlice<int64> stride,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+XlaOp AvgPool(XlaOp operand, absl::Span<const int64> kernel_size,
+ absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding,
const TensorFormat& data_format,
const bool counts_include_padding) {
XlaBuilder* b = operand.builder();
@@ -137,32 +156,22 @@ XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
auto init_value = Zero(b, dtype);
std::vector<int64> input_size(operand_shape.dimensions().begin(),
operand_shape.dimensions().end());
- auto padding_config =
- MakeSpatialPaddingConfig(padding, kernel_size, stride, data_format);
+ const int num_dims = kernel_size.size();
+ const int num_spatial_dims = num_dims - 2;
+ auto padding_config = MakeSpatialPaddingConfig(padding, num_spatial_dims,
+ stride, data_format);
auto padded_operand = Pad(operand, Zero(b, dtype), padding_config);
auto pooled = ComputeSums(padded_operand, init_value, kernel_size, stride,
data_format);
- if (counts_include_padding) {
- // If counts include padding, all windows have the same number of elements
- // contributing to each average. Divide by the window size everywhere to
- // get the average.
- int64 window_size =
- std::accumulate(kernel_size.begin(), kernel_size.end(), 1,
- [](int64 x, int64 y) { return x * y; });
-
- auto divisor = ConstantR0WithType(b, dtype, window_size);
- return pooled / divisor;
- } else {
- return AvgPoolDivideByCountWithGeneralPadding(
- pooled, dtype, input_size, padding, kernel_size, stride, data_format);
- }
+ return AvgPoolDivideByCount(pooled, input_size, kernel_size, stride,
+ padding, dtype, data_format,
+ counts_include_padding);
});
}
std::vector<std::pair<int64, int64>> MakeSpatialPadding(
- tensorflow::gtl::ArraySlice<int64> input_size,
- tensorflow::gtl::ArraySlice<int64> kernel_size,
- tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
+ absl::Span<const int64> input_size, absl::Span<const int64> kernel_size,
+ absl::Span<const int64> stride, Padding padding,
const TensorFormat& data_format) {
const int num_spatial_dims = kernel_size.size() - 2;
std::vector<int64> input_spatial_dimensions;
@@ -180,4 +189,101 @@ std::vector<std::pair<int64, int64>> MakeSpatialPadding(
stride_spatial_dimensions, padding);
}
+XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span<const int64> gradients_size,
+ absl::Span<const int64> kernel_size,
+ absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> spatial_padding,
+ const TensorFormat& data_format,
+ const bool counts_include_padding) {
+ XlaBuilder* b = out_backprop.builder();
+ return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ const int num_dims = kernel_size.size();
+
+ if (gradients_size.size() != num_dims) {
+ return tensorflow::errors::InvalidArgument("gradients must be ", num_dims,
+ "-dimensional");
+ }
+
+ TF_ASSIGN_OR_RETURN(Shape out_backprop_xla_shape,
+ b->GetShape(out_backprop));
+ if (out_backprop_xla_shape.dimensions().size() != num_dims) {
+ return tensorflow::errors::InvalidArgument("out_backprop must be ",
+ num_dims, "-dimensional");
+ }
+
+    // We can think of average-pooling as:
+    // * a convolution whose kernel is 1 wherever the input feature equals
+    //   the output feature, and 0 everywhere else,
+    // * followed by dividing by the counts.
+ //
+ // This then gives us an algorithm to build the gradient:
+ // * divide out_backprop by the counts, followed by
+ // * Conv2DBackpropInput specialized for that kernel, which simplifies to
+ // a Pad and a ReduceWindow.
+ //
+ // For an explanation of backpropagation for convolution, see the comments
+ // in third_party/tensorflow/core/kernels/conv_grad_ops.h
+
+ // TF filter shape is [ H, W, ..., inC, outC ]
+
+ // The input gradients are computed by a convolution of the output gradients
+ // and the filter, with some appropriate padding. See the comment at the top
+ // of conv_grad_ops.h for details.
+ PrimitiveType dtype = out_backprop_xla_shape.element_type();
+ auto out_backprop_div = AvgPoolDivideByCount(
+ out_backprop, gradients_size, kernel_size, stride, spatial_padding,
+ dtype, data_format, counts_include_padding);
+
+ // Pad the gradients in the spatial dimensions. We use the same padding
+ // as Conv2DBackpropInput.
+ PaddingConfig padding_config = MakeNoPaddingConfig(num_dims);
+ std::vector<int64> padded_gradients_size(gradients_size.begin(),
+ gradients_size.end());
+ // First, pad the output gradients the same way as the input. The additional
+ // padding will be removed as a last step before returning the input
+ // gradients.
+ const int num_spatial_dims = num_dims - 2;
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ int dim = data_format.spatial_dimension(i);
+ padded_gradients_size[dim] +=
+ (spatial_padding[i].first + spatial_padding[i].second);
+ }
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ int dim = data_format.spatial_dimension(i);
+ TF_ASSIGN_OR_RETURN(
+ SpatialDimensionOutputSizeAndPadding conv_backprop_spatial_dim,
+ ConvGradExtractAndVerifyDimension(
+ /*input_size=*/padded_gradients_size[dim],
+ /*filter_size=*/kernel_size[dim],
+ /*output_size=*/out_backprop_xla_shape.dimensions(dim),
+ /*dilation=*/1,
+ /*stride=*/stride[dim], /*padding=*/Padding::kValid));
+ auto* padding = padding_config.mutable_dimensions(dim);
+ padding->set_edge_padding_low(conv_backprop_spatial_dim.pad_before);
+ padding->set_edge_padding_high(conv_backprop_spatial_dim.pad_after);
+ padding->set_interior_padding(stride[dim] - 1);
+ }
+
+ auto zero = Zero(b, dtype);
+ auto padded_gradients = Pad(out_backprop_div, zero, padding_config);
+
+ // in_backprop = padded_gradients <conv> ones
+ std::vector<int64> ones(num_dims, 1LL);
+ auto in_backprop =
+ ReduceWindow(padded_gradients, Zero(b, dtype),
+ CreateScalarAddComputation(dtype, b), kernel_size,
+ /*window_strides=*/ones, Padding::kValid);
+    // The input padding doesn't contribute to the gradient; remove it.
+ std::vector<std::pair<int64, int64>> neg_spatial_padding;
+ neg_spatial_padding.reserve(spatial_padding.size());
+ for (const std::pair<int64, int64>& spatial_padding_dim : spatial_padding) {
+ neg_spatial_padding.emplace_back(-spatial_padding_dim.first,
+ -spatial_padding_dim.second);
+ }
+ auto remove_padding_config = MakeSpatialPaddingConfig(
+ neg_spatial_padding, num_spatial_dims, stride, data_format);
+ return Pad(in_backprop, zero, remove_padding_config);
+ });
+}
+
} // namespace xla
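
Note: the refactoring hoists the two division strategies into AvgPoolDivideByCount so that AvgPoolGrad can reuse them. In the counts_include_padding branch every window divides by the full kernel volume; a host-side sketch of just that arithmetic (the general-padding branch needs a ReduceWindow and is not reproduced):

#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  const std::vector<int64_t> kernel_size = {1, 1, 2, 2};  // NCHW, 2x2 window
  const int64_t window_size =
      std::accumulate(kernel_size.begin(), kernel_size.end(), int64_t{1},
                      [](int64_t a, int64_t b) { return a * b; });
  assert(window_size == 4);  // every element scaled by 1/4, matching the
                             // flat 0.25 expectations in pooling_test.cc
}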
diff --git a/tensorflow/compiler/xla/client/lib/pooling.h b/tensorflow/compiler/xla/client/lib/pooling.h
index 1699c585d3..5c0054857d 100644
--- a/tensorflow/compiler/xla/client/lib/pooling.h
+++ b/tensorflow/compiler/xla/client/lib/pooling.h
@@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_
+#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
namespace xla {
@@ -25,7 +25,7 @@ namespace xla {
class TensorFormat {
public:
TensorFormat(int batch_dimension, int feature_dimension,
- tensorflow::gtl::ArraySlice<int64> spatial_dimensions)
+ absl::Span<const int64> spatial_dimensions)
: batch_dimension_(batch_dimension),
feature_dimension_(feature_dimension),
spatial_dimensions_(spatial_dimensions.begin(),
@@ -45,29 +45,36 @@ class TensorFormat {
// The number of the dimension that represents the features.
int feature_dimension_;
// The dimension numbers for the spatial dimensions.
- tensorflow::gtl::InlinedVector<int, 4> spatial_dimensions_;
+ absl::InlinedVector<int, 4> spatial_dimensions_;
};
// Computes the max pool of 'operand'.
-XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
- tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
+XlaOp MaxPool(XlaOp operand, absl::Span<const int64> kernel_size,
+ absl::Span<const int64> stride, Padding padding,
const TensorFormat& data_format);
// Computes the average pool of 'operand'.
-XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
- tensorflow::gtl::ArraySlice<int64> stride,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+XlaOp AvgPool(XlaOp operand, absl::Span<const int64> kernel_size,
+ absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding,
const TensorFormat& data_format,
const bool counts_include_padding);
// Returns the list of low and high padding elements in each spatial dimension
// for the given 'padding' specification.
std::vector<std::pair<int64, int64>> MakeSpatialPadding(
- tensorflow::gtl::ArraySlice<int64> input_size,
- tensorflow::gtl::ArraySlice<int64> kernel_size,
- tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
+ absl::Span<const int64> input_size, absl::Span<const int64> kernel_size,
+ absl::Span<const int64> stride, Padding padding,
const TensorFormat& data_format);
+// Computes the average pool gradient.
+XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span<const int64> gradients_size,
+ absl::Span<const int64> kernel_size,
+ absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> spatial_padding,
+ const TensorFormat& data_format,
+ const bool counts_include_padding);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_
diff --git a/tensorflow/compiler/xla/client/lib/pooling_test.cc b/tensorflow/compiler/xla/client/lib/pooling_test.cc
index 4b4553b60d..30adb9b1ad 100644
--- a/tensorflow/compiler/xla/client/lib/pooling_test.cc
+++ b/tensorflow/compiler/xla/client/lib/pooling_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/pooling.h"
+#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
@@ -22,7 +23,7 @@ namespace xla {
namespace {
TensorFormat MakeNCHWFormat(int num_spatial_dims) {
- tensorflow::gtl::InlinedVector<int64, 4> spatial_dimensions;
+ absl::InlinedVector<int64, 4> spatial_dimensions;
for (int i = 0; i < num_spatial_dims; ++i) {
spatial_dimensions.push_back(i + 2);
}
@@ -31,8 +32,8 @@ TensorFormat MakeNCHWFormat(int num_spatial_dims) {
}
std::vector<std::pair<int64, int64>> MakeGeneralPadding(
- XlaOp input, tensorflow::gtl::ArraySlice<int64> kernel_size,
- tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
+ XlaOp input, absl::Span<const int64> kernel_size,
+ absl::Span<const int64> stride, Padding padding,
const xla::TensorFormat& data_format) {
XlaBuilder* b = input.builder();
Shape operand_shape = b->GetShape(input).ValueOrDie();
@@ -45,7 +46,7 @@ std::vector<std::pair<int64, int64>> MakeGeneralPadding(
// Add singleton batch and feature dimensions to spatial dimensions, according
// to 'data_format' specification.
std::vector<int64> ExpandWithBatchAndFeatureDimensions(
- tensorflow::gtl::ArraySlice<int64> spatial_dim_sizes,
+ absl::Span<const int64> spatial_dim_sizes,
const xla::TensorFormat& data_format) {
const int num_spatial_dims = spatial_dim_sizes.size();
std::vector<int64> tensor_sizes(num_spatial_dims + 2, 1);
@@ -181,5 +182,109 @@ XLA_TEST_F(PoolingTest,
error_spec_);
}
+XLA_TEST_F(PoolingTest, AvgPool2DGradNoPadding) {
+ XlaBuilder builder(TestName());
+ for (bool counts_include_padding : {false, true}) {
+ XlaOp out_backprop = ConstantR4FromArray4D<float>(&builder, {{{{1.}}}});
+ auto data_format = MakeNCHWFormat(2);
+ auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
+ auto stride = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
+ AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride,
+ {{0, 0}, {0, 0}}, MakeNCHWFormat(2),
+ /*counts_include_padding=*/counts_include_padding);
+ // Without padding, counts_include_padding makes no difference.
+ ComputeAndCompareR4<float>(
+ &builder, {{{{0.25, 0.25, 0.}, {0.25, 0.25, 0.}, {0., 0., 0.}}}}, {},
+ error_spec_);
+ }
+}
+
+XLA_TEST_F(PoolingTest, AvgPool2DGradNoPaddingWithStride) {
+ XlaBuilder builder(TestName());
+ for (bool counts_include_padding : {false, true}) {
+ XlaOp out_backprop =
+ ConstantR4FromArray4D<float>(&builder, {{{{1., 1.}, {1., 1.}}}});
+ auto data_format = MakeNCHWFormat(2);
+ auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
+ auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format);
+ AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride,
+ {{0, 0}, {0, 0}}, MakeNCHWFormat(2),
+ /*counts_include_padding=*/counts_include_padding);
+ // Without padding, counts_include_padding makes no difference.
+ ComputeAndCompareR4<float>(
+ &builder, {{{{0.25, 0.5, 0.25}, {0.5, 1., 0.5}, {0.25, 0.5, 0.25}}}},
+ {}, error_spec_);
+ }
+}
+
+XLA_TEST_F(PoolingTest, AvgPool2DGradWithPadding) {
+ XlaBuilder builder(TestName());
+
+ XlaOp out_backprop =
+ ConstantR4FromArray4D<float>(&builder, {{{{1., 1.}, {1., 1.}}}});
+ auto data_format = MakeNCHWFormat(2);
+ auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
+ auto stride = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
+ AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride, {{1, 1}, {1, 1}},
+ MakeNCHWFormat(2),
+ /*counts_include_padding=*/true);
+ ComputeAndCompareR4<float>(
+ &builder,
+ {{{{0.25, 0.25, 0.25}, {0.25, 0.25, 0.25}, {0.25, 0.25, 0.25}}}}, {},
+ error_spec_);
+}
+
+XLA_TEST_F(PoolingTest, AvgPool2DGradWithPaddingCountNotIncludePadding) {
+ XlaBuilder builder(TestName());
+
+ XlaOp out_backprop =
+ ConstantR4FromArray4D<float>(&builder, {{{{1., 1.}, {1., 1.}}}});
+ auto data_format = MakeNCHWFormat(2);
+ auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
+ auto stride = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
+ AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride, {{1, 1}, {1, 1}},
+ MakeNCHWFormat(2), false);
+ ComputeAndCompareR4<float>(
+ &builder, {{{{1., 0.5, 0.5}, {0.5, 0.25, 0.25}, {0.5, 0.25, 0.25}}}}, {},
+ error_spec_);
+}
+
+XLA_TEST_F(PoolingTest, AvgPool2DGradWithPaddingCountWithStride) {
+ XlaBuilder builder(TestName());
+
+ XlaOp out_backprop =
+ ConstantR4FromArray4D<float>(&builder, {{{{1., 1., 1., 1.},
+ {1., 1., 1., 1.},
+ {1., 1., 1., 1.},
+ {1., 1., 1., 1.}}}});
+ auto data_format = MakeNCHWFormat(2);
+ auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
+ auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format);
+ AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride, {{1, 1}, {1, 1}},
+ MakeNCHWFormat(2), true);
+ ComputeAndCompareR4<float>(&builder,
+ {{{{1., 1., 1.}, {1., 1., 1.}, {1., 1., 1.}}}}, {},
+ error_spec_);
+}
+
+XLA_TEST_F(PoolingTest,
+ AvgPool2DGradWithPaddingCountWithStrideNotIncludePadding) {
+ XlaBuilder builder(TestName());
+
+ XlaOp out_backprop =
+ ConstantR4FromArray4D<float>(&builder, {{{{1., 1., 1., 1.},
+ {1., 1., 1., 1.},
+ {1., 1., 1., 1.},
+ {1., 1., 1., 1.}}}});
+ auto data_format = MakeNCHWFormat(2);
+ auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
+ auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format);
+ AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride, {{1, 1}, {1, 1}},
+ MakeNCHWFormat(2), false);
+ ComputeAndCompareR4<float>(
+ &builder, {{{{2.25, 1.5, 2.25}, {1.5, 1., 1.5}, {2.25, 1.5, 2.25}}}}, {},
+ error_spec_);
+}
+
} // namespace
} // namespace xla
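
Note: the expectations in AvgPool2DGradWithPaddingCountNotIncludePadding follow from per-window element counts: with a 3x3 input, {1,1} padding per spatial edge, a 2x2 kernel and stride 2, the corner window covers one real element (gradient 1.0), edge windows two (0.5), and the interior window four (0.25). A standalone count check:

#include <algorithm>
#include <cassert>

int main() {
  const int input = 3, pad = 1, k = 2, stride = 2;
  // Real (non-padding) elements a 1-D window w covers.
  auto count = [&](int w) {
    const int lo = w * stride - pad, hi = lo + k;  // hi is exclusive
    return std::max(0, std::min(hi, input) - std::max(lo, 0));
  };
  // 2-D counts are products of 1-D counts.
  assert(count(0) * count(0) == 1);  // corner   -> 1.0
  assert(count(0) * count(1) == 2);  // edge     -> 0.5
  assert(count(1) * count(1) == 4);  // interior -> 0.25
}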
diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc
index 081fec7ad9..25cc37edc4 100644
--- a/tensorflow/compiler/xla/client/lib/testing.cc
+++ b/tensorflow/compiler/xla/client/lib/testing.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/testing.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal.h"
@@ -23,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
@@ -61,8 +61,7 @@ XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) {
std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
Client* client) {
- XlaBuilder b(
- tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape)));
+ XlaBuilder b(absl::StrCat("make_fake_", ShapeUtil::HumanString(shape)));
BuildFakeDataOpOnDevice(shape, &b);
XlaComputation computation = b.Build().ConsumeValueOrDie();
@@ -77,7 +76,7 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
Client* client) {
if (DataSizeOfShape(shape) < (1LL << 20)) {
- StatusOr<std::unique_ptr<Literal>> literal_status = MakeFakeLiteral(shape);
+ StatusOr<Literal> literal_status = MakeFakeLiteral(shape);
if (!literal_status.ok()) {
// If we got an Unimplemented error, fall back to making the fake data via
// an on-device computation.
@@ -85,7 +84,7 @@ std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
tensorflow::error::UNIMPLEMENTED);
return MakeFakeDataViaDeviceOrDie(shape, client);
}
- return client->TransferToServer(*literal_status.ValueOrDie()).ValueOrDie();
+ return client->TransferToServer(literal_status.ValueOrDie()).ValueOrDie();
}
// If the data is large, generate it on-device.
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index cffb24e29b..f96b6c9c26 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -17,9 +17,9 @@ limitations under the License.
#include <utility>
+#include "absl/memory/memory.h"
#include "llvm/ADT/Triple.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tensorflow/compiler/xla/service/source_map_util.h"
@@ -51,7 +51,7 @@ LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
}
Status LocalExecutable::ValidateExecutionOptions(
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const absl::Span<const ShapedBuffer* const> arguments,
const ExecutableRunOptions& run_options, const Backend& backend) {
const ComputationLayout& computation_layout =
executable_->module_config().entry_computation_layout();
@@ -59,7 +59,7 @@ Status LocalExecutable::ValidateExecutionOptions(
// Check argument number, shapes, and layouts.
if (arguments.size() != computation_layout.parameter_count()) {
return InvalidArgument(
- "invalid number of arguments for computation: expected %d, got %zu",
+ "invalid number of arguments for computation: expected %d, got %u",
computation_layout.parameter_count(), arguments.size());
}
for (int i = 0; i < arguments.size(); ++i) {
@@ -71,9 +71,9 @@ Status LocalExecutable::ValidateExecutionOptions(
"parameter "
"%d: want %s, got %s",
i,
- ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape())
- .c_str(),
- ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str());
+ ShapeUtil::HumanString(
+ computation_layout.parameter_layout(i).shape()),
+ ShapeUtil::HumanString(arguments[i]->on_host_shape()));
}
}
@@ -88,8 +88,7 @@ Status LocalExecutable::ValidateExecutionOptions(
if (stream_platform != backend_->platform()) {
return InvalidArgument(
"stream is for platform %s, but service targets platform %s",
- stream_platform->Name().c_str(),
- backend_->platform()->Name().c_str());
+ stream_platform->Name(), backend_->platform()->Name());
}
// Cannot specify device_ordinal with a stream. The stream determines these
@@ -120,10 +119,10 @@ Status LocalExecutable::ValidateExecutionOptions(
return InvalidArgument(
"executable is built for device %s of type \"%s\"; cannot run it on "
"device %s of type \"%s\"",
- backend_->device_name(build_device_ordinal()).c_str(),
- build_executor->GetDeviceDescription().name().c_str(),
- backend_->device_name(run_device_ordinal).c_str(),
- run_executor->GetDeviceDescription().name().c_str());
+ backend_->device_name(build_device_ordinal()),
+ build_executor->GetDeviceDescription().name(),
+ backend_->device_name(run_device_ordinal),
+ run_executor->GetDeviceDescription().name());
}
if (!run_options.allocator()) {
@@ -133,15 +132,15 @@ Status LocalExecutable::ValidateExecutionOptions(
if (run_options.allocator()->platform() != backend.platform()) {
return InvalidArgument(
"allocator platform (%s) does not match service platform (%s)",
- run_options.allocator()->platform()->Name().c_str(),
- backend.platform()->Name().c_str());
+ run_options.allocator()->platform()->Name(),
+ backend.platform()->Name());
}
return Status::OK();
}
StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const absl::Span<const ShapedBuffer* const> arguments,
ExecutableRunOptions run_options) {
TF_RETURN_IF_ERROR(
ValidateExecutionOptions(arguments, run_options, *backend_));
@@ -178,7 +177,7 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
StatusOr<ScopedShapedBuffer> LocalExecutable::ExecuteAndDump(
const ServiceExecutableRunOptions* run_options,
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ const absl::Span<const ShapedBuffer* const> arguments) {
executable_->hlo_snapshot()->set_execution_platform(
backend_->platform()->Name());
TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->hlo_snapshot()));
@@ -192,13 +191,12 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::ExecuteAndDump(
}
Status LocalExecutable::RecordArguments(
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const absl::Span<const ShapedBuffer* const> arguments,
HloSnapshot* hlo_snapshot) {
hlo_snapshot->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
- LiteralFromShapedBuffer(*argument));
- *hlo_snapshot->add_arguments() = literal->ToProto();
+ TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*argument));
+ *hlo_snapshot->add_arguments() = literal.ToProto();
}
return Status::OK();
}
@@ -206,13 +204,12 @@ Status LocalExecutable::RecordArguments(
Status LocalExecutable::RecordResult(const ShapedBuffer* result,
HloSnapshot* hlo_snapshot) {
hlo_snapshot->clear_result();
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
- LiteralFromShapedBuffer(*result));
- *hlo_snapshot->mutable_result() = literal->ToProto();
+ TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*result));
+ *hlo_snapshot->mutable_result() = literal.ToProto();
return Status::OK();
}
-StatusOr<std::unique_ptr<Literal>> LocalExecutable::LiteralFromShapedBuffer(
+StatusOr<Literal> LocalExecutable::LiteralFromShapedBuffer(
const ShapedBuffer& shaped_buffer) {
TF_ASSIGN_OR_RETURN(auto stream,
backend_->BorrowStream(shaped_buffer.device_ordinal()));
@@ -246,7 +243,7 @@ Backend* LocalClient::mutable_backend() {
StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
const XlaComputation& computation,
- const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const absl::Span<const Shape* const> argument_layouts,
const ExecutableBuildOptions& options) {
ExecutableBuildOptions updated_options = options;
if (options.device_ordinal() == -1) {
@@ -257,9 +254,9 @@ StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
local_service_->CompileExecutable(
computation, argument_layouts, updated_options));
- return WrapUnique(new LocalExecutable(std::move(executable),
- local_service_->mutable_backend(),
- updated_options));
+ return absl::WrapUnique(new LocalExecutable(std::move(executable),
+ local_service_->mutable_backend(),
+ updated_options));
}
StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(
@@ -278,7 +275,7 @@ StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(
return std::move(scoped_buffer);
}
-StatusOr<std::unique_ptr<Literal>> LocalClient::ShapedBufferToLiteral(
+StatusOr<Literal> LocalClient::ShapedBufferToLiteral(
const ShapedBuffer& shaped_buffer) {
TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream(
shaped_buffer.device_ordinal()));
@@ -299,13 +296,13 @@ Status LocalClient::TransferToInfeedLocal(const Literal& literal,
literal);
}
-StatusOr<std::unique_ptr<Literal>> LocalClient::TransferFromOutfeedLocal(
- const Shape& shape, int device_ordinal) {
+StatusOr<Literal> LocalClient::TransferFromOutfeedLocal(const Shape& shape,
+ int device_ordinal) {
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
backend().stream_executor(device_ordinal));
auto literal = Literal::CreateFromShape(shape);
TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
- executor, shape, literal.get()));
+ executor, shape, &literal));
return std::move(literal);
}
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index ae23809261..feb2f8ec9d 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
@@ -30,7 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
@@ -40,7 +40,7 @@ class LocalExecutable {
// Run the compiled computation with the given arguments and options and
// return the result.
StatusOr<ScopedShapedBuffer> Run(
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const absl::Span<const ShapedBuffer* const> arguments,
ExecutableRunOptions run_options);
// Return the options used to build the executable.
@@ -63,7 +63,7 @@ class LocalExecutable {
// The given ExecutableRunOptions override any values from legacy_flags
// (TF_XLA_FLAGS environment variable).
Status ValidateExecutionOptions(
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const absl::Span<const ShapedBuffer* const> arguments,
const ExecutableRunOptions& run_options, const Backend& backend);
// Records the computation in a SessionModule proto with the arguments used to
@@ -73,20 +73,18 @@ class LocalExecutable {
// (TF_XLA_FLAGS environment variable).
StatusOr<ScopedShapedBuffer> ExecuteAndDump(
const ServiceExecutableRunOptions* run_options,
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
+ const absl::Span<const ShapedBuffer* const> arguments);
// Records the arguments used to invoke the computation in a SessionModule
// proto.
- Status RecordArguments(
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- HloSnapshot* hlo_snapshot);
+ Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
+ HloSnapshot* hlo_snapshot);
// Records the result of the computation in a SessionModule proto.
Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot);
// Returns a literal containing the contents of the given ShapedBuffer.
- StatusOr<std::unique_ptr<Literal>> LiteralFromShapedBuffer(
- const ShapedBuffer& shaped_buffer);
+ StatusOr<Literal> LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer);
// The ordinal of the device which this executable was compiled for. The
// executable can run on all equivalent devices (as determined by
@@ -120,7 +118,7 @@ class LocalClient : public Client {
// (TF_XLA_FLAGS environment variable).
StatusOr<std::unique_ptr<LocalExecutable>> Compile(
const XlaComputation& computation,
- const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const absl::Span<const Shape* const> argument_layouts,
const ExecutableBuildOptions& options);
// Copy the literal data to the device with the given ordinal and return as a
@@ -133,8 +131,7 @@ class LocalClient : public Client {
// Copy the data from the device contained in the given ShapedBuffer and
// return as a Literal.
- StatusOr<std::unique_ptr<Literal>> ShapedBufferToLiteral(
- const ShapedBuffer& shaped_buffer);
+ StatusOr<Literal> ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer);
// Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
// as long as the handle is valid.
@@ -152,8 +149,8 @@ class LocalClient : public Client {
// TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
// not inherit from Client and there is no possibility of confusion with
// Client::TransferFromOutfeed.
- StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocal(
- const Shape& shape, int device_ordinal);
+ StatusOr<Literal> TransferFromOutfeedLocal(const Shape& shape,
+ int device_ordinal);
// Returns the device ordinal that corresponds to the given replica number.
//
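
Note: these signatures track Literal's migration from std::unique_ptr<Literal> to a movable value type; call sites switch from -> to . access (see literal.ToProto() in local_client.cc above). A toy stand-in for the shape of that change (hypothetical types, not XLA's):

#include <cassert>
#include <memory>
#include <string>

struct Literal { std::string data; };

std::unique_ptr<Literal> OldStyle() { return std::make_unique<Literal>(Literal{"x"}); }
Literal NewStyle() { return Literal{"x"}; }  // value return; moves, no heap indirection

int main() {
  assert(OldStyle()->data == "x");  // pointer access, old API
  assert(NewStyle().data == "x");   // value access, new API
}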
diff --git a/tensorflow/compiler/xla/client/padding.cc b/tensorflow/compiler/xla/client/padding.cc
index 6a9cf466ac..992b13139c 100644
--- a/tensorflow/compiler/xla/client/padding.cc
+++ b/tensorflow/compiler/xla/client/padding.cc
@@ -23,16 +23,15 @@ limitations under the License.
namespace xla {
-Status ValidatePaddingValues(
- tensorflow::gtl::ArraySlice<int64> input_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides) {
+Status ValidatePaddingValues(absl::Span<const int64> input_dimensions,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides) {
bool ok = input_dimensions.size() == window_dimensions.size() &&
input_dimensions.size() == window_strides.size();
if (!ok) {
return InvalidArgument(
- "Want input dimensions size %zu = window dimensions size %zu = window "
- "strides size %zu",
+ "Want input dimensions size %u = window dimensions size %u = window "
+ "strides size %u",
input_dimensions.size(), window_dimensions.size(),
window_strides.size());
}
@@ -40,9 +39,9 @@ Status ValidatePaddingValues(
}
std::vector<std::pair<int64, int64>> MakePadding(
- tensorflow::gtl::ArraySlice<int64> input_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
+ absl::Span<const int64> input_dimensions,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides, Padding padding) {
TF_CHECK_OK(ValidatePaddingValues(input_dimensions, window_dimensions,
window_strides));
std::vector<std::pair<int64, int64>> low_high_padding;
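
Most hunks in this merge follow the mechanical pattern seen here: tensorflow::gtl::ArraySlice<T> parameters become absl::Span<const T>. Call sites rarely need changes, since both types bind implicitly to contiguous containers and brace lists. A self-contained sketch with a hypothetical helper (int64_t stands in for the int64 alias used in the surrounding code):

#include <cstdint>
#include <vector>
#include "absl/types/span.h"

// Hypothetical helper illustrating the parameter style after the migration.
int64_t Product(absl::Span<const int64_t> dims) {
  int64_t result = 1;
  for (int64_t d : dims) result *= d;  // walk the borrowed contiguous view
  return result;
}

int main() {
  std::vector<int64_t> sizes = {2, 3, 5};
  int64_t a = Product(sizes);    // a vector binds without copying
  int64_t b = Product({7, 11});  // so does a brace list
  return (a == 30 && b == 77) ? 0 : 1;
}
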
diff --git a/tensorflow/compiler/xla/client/padding.h b/tensorflow/compiler/xla/client/padding.h
index e23b0b3a90..5c009bd49e 100644
--- a/tensorflow/compiler/xla/client/padding.h
+++ b/tensorflow/compiler/xla/client/padding.h
@@ -19,9 +19,9 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
@@ -41,10 +41,9 @@ enum class Padding {
// Validates that the slices are acceptable for determining padding -- this can
// be used to check the preconditions of MakePadding below to produce an error
// message that can be returned to the user.
-Status ValidatePaddingValues(
- tensorflow::gtl::ArraySlice<int64> input_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides);
+Status ValidatePaddingValues(absl::Span<const int64> input_dimensions,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides);
// Returns the padding needed for the base area, given the base area dimensions,
// window dimensions, strides, and the type of padding.
@@ -58,9 +57,9 @@ Status ValidatePaddingValues(
// window_dimensions, and strides must match, which is equal to the number
// of elements in the result vector.
std::vector<std::pair<int64, int64>> MakePadding(
- tensorflow::gtl::ArraySlice<int64> input_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);
+ absl::Span<const int64> input_dimensions,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides, Padding padding);
} // namespace xla
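
Since the MakePadding contract above is easy to misread, a worked 1-D example; the numbers are illustrative, not from the patch, and follow from the SAME/VALID definitions:

// Base area of 10 elements, window of 3, stride 1.
auto same = xla::MakePadding(/*input_dimensions=*/{10},
                             /*window_dimensions=*/{3},
                             /*window_strides=*/{1}, xla::Padding::kSame);
// same == {{1, 1}}: one low and one high pad element keep the output
// length at ceil(10 / 1) == 10.

auto valid = xla::MakePadding({10}, {3}, {1}, xla::Padding::kValid);
// valid == {{0, 0}}: no padding; the output shrinks to 10 - 3 + 1 == 8.
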
diff --git a/tensorflow/compiler/xla/client/sharding_builder.h b/tensorflow/compiler/xla/client/sharding_builder.h
index 34763e54d9..59df3a8762 100644
--- a/tensorflow/compiler/xla/client/sharding_builder.h
+++ b/tensorflow/compiler/xla/client/sharding_builder.h
@@ -56,4 +56,4 @@ OpSharding Tuple(const ShapeTree<OpSharding>& shardings);
} // namespace sharding_builder
} // namespace xla
-#endif
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_SHARDING_BUILDER_H_
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index aa47f992bc..4e1ff9e5c0 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -21,19 +21,24 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
namespace xla {
-using tensorflow::strings::StrCat;
+using absl::StrCat;
namespace {
@@ -67,7 +72,7 @@ XlaOp operator>>(const XlaOp& x, const XlaOp& y) {
if (!ShapeUtil::ElementIsIntegral(shape)) {
return InvalidArgument(
"Argument to >> operator does not have an integral type (%s).",
- ShapeUtil::HumanString(shape).c_str());
+ ShapeUtil::HumanString(shape));
}
if (ShapeUtil::ElementIsSigned(shape)) {
return ShiftRightArithmetic(x, y);
@@ -85,7 +90,7 @@ StatusOr<Shape> XlaBuilder::GetShape(const XlaOp& op) const {
}
StatusOr<std::vector<Shape>> XlaBuilder::GetOperandShapes(
- tensorflow::gtl::ArraySlice<XlaOp> operands) const {
+ absl::Span<const XlaOp> operands) const {
std::vector<Shape> operand_shapes;
for (const XlaOp& operand : operands) {
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
@@ -194,7 +199,6 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
      // TODO(b/33009255): Implement constant folding for cross replica sum.
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
- case HloOpcode::kHostCompute:
case HloOpcode::kCall:
// TODO(b/32495713): We aren't checking the to_apply computation itself,
// so we conservatively say that computations containing the Call op
@@ -221,8 +225,7 @@ XlaComputation XlaBuilder::BuildAndNoteError() {
auto build_status = Build();
if (!build_status.ok()) {
parent_builder_->ReportError(
- AddStatus(build_status.status(),
- tensorflow::strings::StrCat("error from: ", name_)));
+ AddStatus(build_status.status(), absl::StrCat("error from: ", name_)));
return {};
}
return build_status.ConsumeValueOrDie();
@@ -288,7 +291,7 @@ StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id) {
StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
const Shape& shape, const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
TF_RETURN_IF_ERROR(first_error_);
HloInstructionProto instr;
@@ -349,9 +352,8 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) {
});
}
-XlaOp XlaBuilder::BinaryOp(
- HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> broadcast_dimensions) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
@@ -445,12 +447,12 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
}
XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kAdd, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions);
}
@@ -463,14 +465,27 @@ XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
});
}
+XlaOp XlaBuilder::Iota(const Shape& shape, int64 iota_dimension) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = shape;
+ instr.add_dimensions(iota_dimension);
+ return AddInstruction(std::move(instr), HloOpcode::kIota);
+ });
+}
+
+XlaOp XlaBuilder::Iota(PrimitiveType type, int64 size) {
+ return Iota(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0);
+}
+
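
The two Iota overloads added above are the builder-level entry points for HloOpcode::kIota. A short usage sketch (the builder name and shapes are illustrative):

XlaBuilder b("iota_example");
// Rank-1 convenience form: an S32 vector [0, 1, ..., 7].
XlaOp v = b.Iota(S32, 8);
// General form: fill dimension 1 of a [2, 3] F32 array, giving
// {{0, 1, 2}, {0, 1, 2}}.
XlaOp m = b.Iota(ShapeUtil::MakeShape(F32, {2, 3}), /*iota_dimension=*/1);
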
XlaOp XlaBuilder::Call(const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<XlaOp> operands) {
+ absl::Span<const XlaOp> operands) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
- c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
- [](const Shape& shape) { return &shape; });
+ absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
+ [](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
computation.GetProgramShape());
TF_ASSIGN_OR_RETURN(
@@ -489,7 +504,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (!parameter_numbers_.insert(parameter_number).second) {
- return InvalidArgument("parameter %lld already registered",
+ return InvalidArgument("parameter %d already registered",
parameter_number);
}
instr.set_parameter_number(parameter_number);
@@ -499,8 +514,8 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
});
}
-XlaOp XlaBuilder::Broadcast(
- const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
+XlaOp XlaBuilder::Broadcast(const XlaOp& operand,
+ absl::Span<const int64> broadcast_sizes) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
@@ -525,7 +540,7 @@ XlaOp XlaBuilder::Broadcast(
XlaOp XlaBuilder::BroadcastInDim(
const XlaOp& operand, const Shape& shape,
- const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ const absl::Span<const int64> broadcast_dimensions) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
return InDimBroadcast(shape, operand, broadcast_dimensions);
});
@@ -540,9 +555,9 @@ StatusOr<XlaOp> XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) {
}
XlaOp XlaBuilder::Slice(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides) {
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -577,7 +592,7 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index,
}
XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ absl::Span<const int64> slice_sizes) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@@ -615,15 +630,15 @@ XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
});
}
-XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands,
+XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
int64 dimension) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
- c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
- [](const Shape& shape) { return &shape; });
+ absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
+ [](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferConcatOpShape(operand_shape_ptrs, dimension));
@@ -655,8 +670,8 @@ XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value,
}
XlaOp XlaBuilder::Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes) {
+ absl::Span<const int64> dimensions,
+ absl::Span<const int64> new_sizes) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(const Shape& shape,
@@ -670,7 +685,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand,
}
XlaOp XlaBuilder::Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> new_sizes) {
+ absl::Span<const int64> new_sizes) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand));
std::vector<int64> dimensions(shape.dimensions_size());
@@ -680,7 +695,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand,
}
XlaOp XlaBuilder::Collapse(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
+ absl::Span<const int64> dimensions) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (dimensions.size() <= 1) {
// Not collapsing anything, trivially we can return the operand versus
@@ -690,8 +705,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand,
// Out-of-order collapse is not supported.
// Checks that the collapsed dimensions are in order and consecutive.
- for (tensorflow::gtl::ArraySlice<int64>::size_type i = 1;
- i < dimensions.size(); ++i) {
+ for (absl::Span<const int64>::size_type i = 1; i < dimensions.size(); ++i) {
if (dimensions[i] - 1 != dimensions[i - 1]) {
return InvalidArgument(
"Collapsed dimensions are not in consecutive order.");
@@ -703,8 +717,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand,
TF_ASSIGN_OR_RETURN(const Shape& original_shape, GetShape(operand));
VLOG(3) << "original shape: " << ShapeUtil::HumanString(original_shape);
- VLOG(3) << "dims to collapse: "
- << tensorflow::str_util::Join(dimensions, ",");
+ VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ",");
std::vector<int64> new_sizes;
for (int i = 0; i < ShapeUtil::Rank(original_shape); ++i) {
@@ -715,8 +728,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand,
}
}
- VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",")
- << "]";
+ VLOG(3) << "new sizes: [" << absl::StrJoin(new_sizes, ",") << "]";
return Reshape(operand, new_sizes);
});
@@ -726,7 +738,7 @@ void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = ShapeUtil::MakeNil();
- *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag)->ToProto();
+ *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto();
return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
});
}
@@ -744,13 +756,13 @@ XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true,
});
}
-XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements) {
+XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements));
- c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
- [](const Shape& shape) { return &shape; });
+ absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
+ [](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferVariadicOpShape(
HloOpcode::kTuple, operand_shape_ptrs));
@@ -765,7 +777,7 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) {
if (!ShapeUtil::IsTuple(tuple_shape)) {
return InvalidArgument(
"Operand to GetTupleElement() is not a tuple; got %s",
- ShapeUtil::HumanString(tuple_shape).c_str());
+ ShapeUtil::HumanString(tuple_shape));
}
*instr.mutable_shape() =
ShapeUtil::GetTupleElementShape(tuple_shape, index);
@@ -778,36 +790,37 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) {
}
XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kEq, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Ne(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kNe, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Ge(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kGe, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Gt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kGt, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Le(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kLe, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kLt, lhs, rhs, broadcast_dimensions);
}
-XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) {
+XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs,
+ const PrecisionConfig* precision_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
@@ -815,12 +828,13 @@ XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) {
dimension_numbers.add_lhs_contracting_dimensions(
lhs_shape.dimensions_size() == 1 ? 0 : 1);
dimension_numbers.add_rhs_contracting_dimensions(0);
- return DotGeneral(lhs, rhs, dimension_numbers);
+ return DotGeneral(lhs, rhs, dimension_numbers, precision_config);
});
}
XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
- const DotDimensionNumbers& dimension_numbers) {
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfig* precision_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
@@ -829,6 +843,9 @@ XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape,
dimension_numbers));
*instr.mutable_dot_dimension_numbers() = dimension_numbers;
+ if (precision_config != nullptr) {
+ *instr.mutable_precision_config() = *precision_config;
+ }
return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs});
});
}
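
Dot, DotGeneral, and the convolution builders below now thread an optional PrecisionConfig through to the instruction proto; passing nullptr leaves the backend default untouched. A hedged opt-in sketch, assuming the proto exposes the usual repeated operand_precision field:

// Request highest available precision for both dot operands.
PrecisionConfig config;
config.add_operand_precision(PrecisionConfig::HIGHEST);  // lhs
config.add_operand_precision(PrecisionConfig::HIGHEST);  // rhs
XlaOp product = builder.Dot(lhs, rhs, &config);
// Passing nullptr instead of &config keeps the default precision.
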
@@ -840,16 +857,14 @@ Status XlaBuilder::VerifyConvolution(
return InvalidArgument(
"Convolution arguments must have same number of "
"dimensions. Got: %s and %s",
- ShapeUtil::HumanString(lhs_shape).c_str(),
- ShapeUtil::HumanString(rhs_shape).c_str());
+ ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
}
int num_dims = ShapeUtil::Rank(lhs_shape);
if (num_dims < 2) {
return InvalidArgument(
"Convolution expects argument arrays with >= 3 dimensions. "
"Got: %s and %s",
- ShapeUtil::HumanString(lhs_shape).c_str(),
- ShapeUtil::HumanString(rhs_shape).c_str());
+ ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
}
int num_spatial_dims = num_dims - 2;
@@ -863,7 +878,7 @@ Status XlaBuilder::VerifyConvolution(
}
for (int i = 0; i < numbers.size(); ++i) {
if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
- return InvalidArgument("Convolution %s[%d] is out of bounds: %lld",
+ return InvalidArgument("Convolution %s[%d] is out of bounds: %d",
field_name, i, numbers.Get(i));
}
}
@@ -881,29 +896,28 @@ Status XlaBuilder::VerifyConvolution(
}
XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding, int64 feature_group_count) {
+ absl::Span<const int64> window_strides, Padding padding,
+ int64 feature_group_count,
+ const PrecisionConfig* precision_config) {
return ConvWithGeneralDimensions(
lhs, rhs, window_strides, padding,
CreateDefaultConvDimensionNumbers(window_strides.size()),
- feature_group_count);
+ feature_group_count, precision_config);
}
XlaOp XlaBuilder::ConvWithGeneralPadding(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- int64 feature_group_count) {
+ const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ int64 feature_group_count, const PrecisionConfig* precision_config) {
return ConvGeneral(lhs, rhs, window_strides, padding,
CreateDefaultConvDimensionNumbers(window_strides.size()),
- feature_group_count);
+ feature_group_count, precision_config);
}
XlaOp XlaBuilder::ConvWithGeneralDimensions(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count) {
+ const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
+ Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count, const PrecisionConfig* precision_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
@@ -930,28 +944,27 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions(
return ConvGeneral(lhs, rhs, window_strides,
MakePadding(base_area_dimensions, window_dimensions,
window_strides, padding),
- dimension_numbers, feature_group_count);
+ dimension_numbers, feature_group_count,
+ precision_config);
});
}
XlaOp XlaBuilder::ConvGeneral(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count) {
+ int64 feature_group_count, const PrecisionConfig* precision_config) {
return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
- dimension_numbers, feature_group_count);
+ dimension_numbers, feature_group_count,
+ precision_config);
}
XlaOp XlaBuilder::ConvGeneralDilated(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count) {
+ int64 feature_group_count, const PrecisionConfig* precision_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
@@ -972,34 +985,37 @@ XlaOp XlaBuilder::ConvGeneralDilated(
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferConvolveShape(
- lhs_shape, rhs_shape, instr.window(),
- dimension_numbers, feature_group_count));
+ lhs_shape, rhs_shape, feature_group_count,
+ instr.window(), dimension_numbers));
*instr.mutable_convolution_dimension_numbers() = dimension_numbers;
instr.set_feature_group_count(feature_group_count);
+ if (precision_config != nullptr) {
+ *instr.mutable_precision_config() = *precision_config;
+ }
+
return AddInstruction(std::move(instr), HloOpcode::kConvolution,
{lhs, rhs});
});
}
StatusOr<Window> XlaBuilder::MakeWindow(
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation) const {
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation,
+ absl::Span<const int64> rhs_dilation) const {
const auto verify_size = [&](const size_t x, const char* x_name) {
if (x == 0 || x == window_dimensions.size()) {
return Status::OK();
} else {
return InvalidArgument(
- "%s", tensorflow::strings::StrCat(
+ "%s", absl::StrCat(
"Window has different number of window dimensions than of ",
x_name,
"\nNumber of window dimensions: ", window_dimensions.size(),
- "\nNumber of ", x_name, ": ", x, "\n")
- .c_str());
+ "\nNumber of ", x_name, ": ", x, "\n"));
}
};
TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides"));
@@ -1039,7 +1055,7 @@ StatusOr<Window> XlaBuilder::MakeWindow(
}
XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type,
- const tensorflow::gtl::ArraySlice<int64> fft_length) {
+ const absl::Span<const int64> fft_length) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -1175,8 +1191,8 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) {
return InvalidArgument(
"Outfeed shape %s must be compatible with operand shape %s",
- ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(),
- ShapeUtil::HumanStringWithLayout(operand_shape).c_str());
+ ShapeUtil::HumanStringWithLayout(shape_with_layout),
+ ShapeUtil::HumanStringWithLayout(operand_shape));
}
*instr.mutable_outfeed_shape() = shape_with_layout;
@@ -1228,8 +1244,8 @@ XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) {
return InvalidArgument(
"Outfeed shape %s must be compatible with operand shape %s",
- ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(),
- ShapeUtil::HumanStringWithLayout(operand_shape).c_str());
+ ShapeUtil::HumanStringWithLayout(shape_with_layout),
+ ShapeUtil::HumanStringWithLayout(operand_shape));
}
*instr.mutable_outfeed_shape() = shape_with_layout;
@@ -1248,7 +1264,7 @@ XlaOp XlaBuilder::CreateToken() {
});
}
-XlaOp XlaBuilder::AfterAll(tensorflow::gtl::ArraySlice<XlaOp> tokens) {
+XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (tokens.empty()) {
return InvalidArgument("AfterAll requires at least one operand");
@@ -1260,15 +1276,15 @@ XlaOp XlaBuilder::AfterAll(tensorflow::gtl::ArraySlice<XlaOp> tokens) {
}
XlaOp XlaBuilder::CustomCall(const string& call_target_name,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
+ absl::Span<const XlaOp> operands,
const Shape& shape) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
- if (tensorflow::str_util::StartsWith(call_target_name, "$")) {
+ if (absl::StartsWith(call_target_name, "$")) {
return InvalidArgument(
"Invalid custom_call_target \"%s\": Call targets that start with '$' "
"are reserved for internal use.",
- call_target_name.c_str());
+ call_target_name);
}
*instr.mutable_shape() = shape;
instr.set_custom_call_target(call_target_name);
@@ -1276,21 +1292,8 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name,
});
}
-XlaOp XlaBuilder::HostCompute(tensorflow::gtl::ArraySlice<XlaOp> operands,
- const string& channel_name,
- int64 cost_estimate_ns, const Shape& shape) {
- return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- HloInstructionProto instr;
- *instr.mutable_shape() = shape;
- instr.set_channel_name(channel_name);
- instr.set_cost_estimate_ns(cost_estimate_ns);
- return AddInstruction(std::move(instr), HloOpcode::kHostCompute, operands);
- });
-}
-
-XlaOp XlaBuilder::Complex(
- const XlaOp& real, const XlaOp& imag,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+XlaOp XlaBuilder::Complex(const XlaOp& real, const XlaOp& imag,
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kComplex, real, imag, broadcast_dimensions);
}
@@ -1299,42 +1302,42 @@ XlaOp XlaBuilder::Conj(const XlaOp& operand) {
}
XlaOp XlaBuilder::Sub(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kSubtract, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Div(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kDivide, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Rem(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kRemainder, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Max(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kMaximum, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Min(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kMinimum, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::And(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kAnd, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Or(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kOr, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Xor(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kXor, lhs, rhs, broadcast_dimensions);
}
@@ -1342,22 +1345,21 @@ XlaOp XlaBuilder::Not(const XlaOp& operand) {
return UnaryOp(HloOpcode::kNot, operand);
}
-XlaOp XlaBuilder::ShiftLeft(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+XlaOp XlaBuilder::ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kShiftLeft, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::ShiftRightArithmetic(
const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs,
broadcast_dimensions);
}
XlaOp XlaBuilder::ShiftRightLogical(
const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs,
broadcast_dimensions);
}
@@ -1366,9 +1368,8 @@ XlaOp XlaBuilder::Abs(const XlaOp& operand) {
return UnaryOp(HloOpcode::kAbs, operand);
}
-XlaOp XlaBuilder::Atan2(
- const XlaOp& y, const XlaOp& x,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+XlaOp XlaBuilder::Atan2(const XlaOp& y, const XlaOp& x,
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kAtan2, y, x, broadcast_dimensions);
}
@@ -1433,7 +1434,7 @@ XlaOp XlaBuilder::IsFinite(const XlaOp& operand) {
}
XlaOp XlaBuilder::Transpose(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> permutation) {
+ absl::Span<const int64> permutation) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -1448,7 +1449,7 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand,
}
XlaOp XlaBuilder::Rev(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
+ absl::Span<const int64> dimensions) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -1462,7 +1463,7 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand,
});
}
-XlaOp XlaBuilder::Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values,
+XlaOp XlaBuilder::Sort(XlaOp keys, absl::optional<XlaOp> values,
int64 dimension) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@@ -1490,7 +1491,7 @@ XlaOp XlaBuilder::Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values,
}
XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kPower, lhs, rhs, broadcast_dimensions);
}
@@ -1528,10 +1529,10 @@ XlaOp XlaBuilder::Clamp(const XlaOp& min, const XlaOp& operand,
return TernaryOp(HloOpcode::kClamp, min, operand, max);
}
-XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice<XlaOp> operands,
+XlaOp XlaBuilder::Map(absl::Span<const XlaOp> operands,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<XlaOp> static_operands) {
+ absl::Span<const int64> dimensions,
+ absl::Span<const XlaOp> static_operands) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (!static_operands.empty()) {
return Unimplemented("static_operands is not supported in Map");
@@ -1540,8 +1541,8 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice<XlaOp> operands,
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
- c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
- [](const Shape& shape) { return &shape; });
+ absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
+ [](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
computation.GetProgramShape());
TF_ASSIGN_OR_RETURN(
@@ -1572,7 +1573,7 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice<XlaOp> operands,
}
XlaOp XlaBuilder::RngOp(RandomDistribution distribution,
- tensorflow::gtl::ArraySlice<XlaOp> parameters,
+ absl::Span<const XlaOp> parameters,
const Shape& shape) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@@ -1584,7 +1585,7 @@ XlaOp XlaBuilder::RngOp(RandomDistribution distribution,
if (parameters.size() != 2) {
return InvalidArgument(
"RNG distribution (%s) expects 2 parameters, but got %ld",
- RandomDistribution_Name(distribution).c_str(), parameters.size());
+ RandomDistribution_Name(distribution), parameters.size());
}
break;
default:
@@ -1631,27 +1632,27 @@ XlaOp XlaBuilder::While(const XlaComputation& condition,
});
}
-XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices,
+XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds) {
+ absl::Span<const int64> slice_sizes) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input));
- TF_ASSIGN_OR_RETURN(const Shape& gather_indices_shape,
- GetShape(gather_indices));
+ TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape,
+ GetShape(start_indices));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
- ShapeInference::InferGatherShape(input_shape, gather_indices_shape,
- dimension_numbers, window_bounds));
+ ShapeInference::InferGatherShape(input_shape, start_indices_shape,
+ dimension_numbers, slice_sizes));
*instr.mutable_gather_dimension_numbers() = dimension_numbers;
- for (int64 bound : window_bounds) {
- instr.add_gather_window_bounds(bound);
+ for (int64 bound : slice_sizes) {
+ instr.add_gather_slice_sizes(bound);
}
return AddInstruction(std::move(instr), HloOpcode::kGather,
- {input, gather_indices});
+ {input, start_indices});
});
}
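
This hunk tracks the HLO gather renaming: gather_indices becomes start_indices, window_bounds becomes slice_sizes, and the proto field moves from gather_window_bounds to gather_slice_sizes. A hedged row-lookup sketch under the dimension-number field names of this era (shapes and operands are hypothetical):

// input: f32[N, 64]; start_indices: s32[K, 1]; result: f32[K, 64].
GatherDimensionNumbers dnums;
dnums.add_offset_dims(1);           // the kept slice dimension in the output
dnums.add_collapsed_slice_dims(0);  // the indexed row dimension collapses
dnums.add_start_index_map(0);       // indices address input dimension 0
dnums.set_index_vector_dim(1);
XlaOp rows = builder.Gather(input, start_indices, dnums,
                            /*slice_sizes=*/{1, 64});
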
@@ -1713,22 +1714,39 @@ XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand,
});
}
-XlaOp XlaBuilder::Reduce(
- const XlaOp& operand, const XlaOp& init_value,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
+XlaOp XlaBuilder::Reduce(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ absl::Span<const int64> dimensions_to_reduce) {
+ return Reduce(absl::Span<const XlaOp>({operand}),
+ absl::Span<const XlaOp>({init_value}), computation,
+ dimensions_to_reduce);
+}
+
+XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
+ absl::Span<const XlaOp> init_values,
+ const XlaComputation& computation,
+ absl::Span<const int64> dimensions_to_reduce) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
- TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
- TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value));
TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
computation.GetProgramShape());
- TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
- ShapeInference::InferReduceShape(
- {&operand_shape, &init_shape}, dimensions_to_reduce,
- called_program_shape));
+ std::vector<XlaOp> all_operands;
+ all_operands.insert(all_operands.end(), operands.begin(), operands.end());
+ all_operands.insert(all_operands.end(), init_values.begin(),
+ init_values.end());
+
+ std::vector<const Shape*> operand_shape_ptrs;
+ TF_ASSIGN_OR_RETURN(const auto& operand_shapes,
+ GetOperandShapes(all_operands));
+ absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
+ [](const Shape& shape) { return &shape; });
+
+ TF_ASSIGN_OR_RETURN(
+ *instr.mutable_shape(),
+ ShapeInference::InferReduceShape(
+ operand_shape_ptrs, dimensions_to_reduce, called_program_shape));
for (int64 dim : dimensions_to_reduce) {
instr.add_dimensions(dim);
@@ -1736,8 +1754,7 @@ XlaOp XlaBuilder::Reduce(
AddCalledComputation(computation, &instr);
- return AddInstruction(std::move(instr), HloOpcode::kReduce,
- {operand, init_value});
+ return AddInstruction(std::move(instr), HloOpcode::kReduce, all_operands);
});
}
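
The new overload generalizes Reduce to N operands reduced in lockstep: the reducer takes 2N scalars (by convention, N running values followed by N incoming values) and returns an N-tuple, and the single-operand form simply forwards to it, as the delegation at the top of this hunk shows. A sketch folding a sum and a max in one pass; zero and lowest are assumed scalar constants built elsewhere:

// Reducer mapping (acc_sum, acc_max, x_sum, x_max) -> (sum, max).
XlaBuilder rb("sum_and_max");
Shape s = ShapeUtil::MakeShape(F32, {});
XlaOp acc_sum = rb.Parameter(0, s, "acc_sum");
XlaOp acc_max = rb.Parameter(1, s, "acc_max");
XlaOp x_sum = rb.Parameter(2, s, "x_sum");
XlaOp x_max = rb.Parameter(3, s, "x_max");
rb.Tuple({rb.Add(acc_sum, x_sum), rb.Max(acc_max, x_max)});
XlaComputation sum_and_max = rb.Build().ConsumeValueOrDie();

// One traversal of `data` yields a (sum, max) tuple.
XlaOp reduced = builder.Reduce({data, data}, {zero, lowest}, sum_and_max,
                               /*dimensions_to_reduce=*/{0});
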
@@ -1751,11 +1768,11 @@ XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value,
});
}
-XlaOp XlaBuilder::ReduceWindow(
- const XlaOp& operand, const XlaOp& init_value,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
+XlaOp XlaBuilder::ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ Padding padding) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@@ -1776,9 +1793,9 @@ XlaOp XlaBuilder::ReduceWindow(
XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@@ -1873,8 +1890,7 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
}
XlaOp XlaBuilder::CrossReplicaSum(
- const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids) {
+ const XlaOp& operand, absl::Span<const ReplicaGroup> replica_groups) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {});
@@ -1882,23 +1898,24 @@ XlaOp XlaBuilder::CrossReplicaSum(
b->Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"),
b->Parameter(/*parameter_number=*/1, scalar_shape, "y"));
TF_ASSIGN_OR_RETURN(auto computation, b->Build());
- return CrossReplicaSum(operand, computation, replica_group_ids,
- /*channel_id=*/tensorflow::gtl::nullopt);
+ return CrossReplicaSum(operand, computation, replica_groups,
+ /*channel_id=*/absl::nullopt);
});
}
XlaOp XlaBuilder::CrossReplicaSum(
const XlaOp& operand, const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids,
- const tensorflow::gtl::optional<ChannelHandle>& channel_id) {
+ absl::Span<const ReplicaGroup> replica_groups,
+ const absl::optional<ChannelHandle>& channel_id) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferCrossReplicaSumShape({&operand_shape}));
- for (int64 replica_group_id : replica_group_ids) {
- instr.add_replica_group_ids(replica_group_id);
+
+ for (const ReplicaGroup& group : replica_groups) {
+ *instr.add_replica_groups() = group;
}
if (channel_id.has_value()) {
@@ -1945,8 +1962,8 @@ XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension,
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(auto slice_shapes, this->GetOperandShapes(slices));
std::vector<const Shape*> slice_shape_ptrs;
- c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs),
- [](const Shape& shape) { return &shape; });
+ absl::c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs),
+ [](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs));
@@ -1967,12 +1984,34 @@ XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension,
});
}
-XlaOp XlaBuilder::SelectAndScatter(
- const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const XlaOp& source, const XlaOp& init_value,
- const XlaComputation& scatter) {
+XlaOp XlaBuilder::CollectivePermute(
+ const XlaOp& operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+ HloInstructionProto instr;
+ TF_ASSIGN_OR_RETURN(
+ *instr.mutable_shape(),
+ ShapeInference::InferCollectivePermuteShape(operand_shape));
+
+ for (const auto& pair : source_target_pairs) {
+ auto* proto_pair = instr.add_source_target_pairs();
+ proto_pair->set_source(pair.first);
+ proto_pair->set_target(pair.second);
+ }
+
+ return AddInstruction(std::move(instr), HloOpcode::kCollectivePermute,
+ {operand});
+ });
+}
+
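
CollectivePermute, added above, forwards data between replicas according to explicit (source, target) pairs; a replica that is no pair's target receives zeros, per the op's usual semantics. A two-replica swap sketch (operand is assumed to be an XlaOp already built on this builder):

// Replica 0 sends its operand to replica 1, and vice versa.
XlaOp swapped = builder.CollectivePermute(operand, {{0, 1}, {1, 0}});
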
+XlaOp XlaBuilder::SelectAndScatter(const XlaOp& operand,
+ const XlaComputation& select,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ Padding padding, const XlaOp& source,
+ const XlaOp& init_value,
+ const XlaComputation& scatter) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
return SelectAndScatterWithGeneralPadding(
@@ -1985,11 +2024,10 @@ XlaOp XlaBuilder::SelectAndScatter(
XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const XlaOp& source, const XlaOp& init_value,
- const XlaComputation& scatter) {
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
+ const XlaOp& init_value, const XlaComputation& scatter) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@@ -2133,13 +2171,13 @@ XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token,
if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) {
return InvalidArgument(
"SendToHost shape %s must be compatible with operand shape %s",
- ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(),
- ShapeUtil::HumanStringWithLayout(operand_shape).c_str());
+ ShapeUtil::HumanStringWithLayout(shape_with_layout),
+ ShapeUtil::HumanStringWithLayout(operand_shape));
}
// TODO(b/111544877): Support tuple shapes.
if (!ShapeUtil::IsArray(operand_shape)) {
return InvalidArgument("SendToHost only supports array shapes, shape: %s",
- ShapeUtil::HumanString(operand_shape).c_str());
+ ShapeUtil::HumanString(operand_shape));
}
if (handle.type() != ChannelHandle::DEVICE_TO_HOST) {
@@ -2178,7 +2216,7 @@ XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape,
if (!ShapeUtil::IsArray(shape)) {
return InvalidArgument(
"RecvFromHost only supports array shapes, shape: %s",
- ShapeUtil::HumanString(shape).c_str());
+ ShapeUtil::HumanString(shape));
}
if (handle.type() != ChannelHandle::HOST_TO_DEVICE) {
@@ -2233,7 +2271,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
"of being evaluated at XLA compile time.\n\n"
"Please file a usability bug with the framework being used (e.g. "
"TensorFlow).",
- op_string.c_str());
+ op_string);
}
TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
@@ -2296,7 +2334,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
std::unique_ptr<XlaBuilder> XlaBuilder::CreateSubBuilder(
const string& computation_name) {
- auto sub_builder = MakeUnique<XlaBuilder>(computation_name);
+ auto sub_builder = absl::make_unique<XlaBuilder>(computation_name);
sub_builder->parent_builder_ = this;
sub_builder->die_immediately_on_error_ = this->die_immediately_on_error_;
return sub_builder;
@@ -2341,8 +2379,8 @@ XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)})
.size() != 4) {
return FailedPrecondition(
- "dimension numbers for the input are not unique: (%lld, %lld, %lld, "
- "%lld)",
+ "dimension numbers for the input are not unique: (%d, %d, %d, "
+ "%d)",
dnum.input_batch_dimension(), dnum.input_feature_dimension(),
dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1));
}
@@ -2352,8 +2390,8 @@ XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
dnum.kernel_spatial_dimensions(1)})
.size() != 4) {
return FailedPrecondition(
- "dimension numbers for the weight are not unique: (%lld, %lld, %lld, "
- "%lld)",
+ "dimension numbers for the weight are not unique: (%d, %d, %d, "
+ "%d)",
dnum.kernel_output_feature_dimension(),
dnum.kernel_input_feature_dimension(),
dnum.kernel_spatial_dimensions(0), dnum.kernel_spatial_dimensions(1));
@@ -2364,17 +2402,17 @@ XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
dnum.output_spatial_dimensions(1)})
.size() != 4) {
return FailedPrecondition(
- "dimension numbers for the output are not unique: (%lld, %lld, %lld, "
- "%lld)",
+ "dimension numbers for the output are not unique: (%d, %d, %d, "
+ "%d)",
dnum.output_batch_dimension(), dnum.output_feature_dimension(),
dnum.output_spatial_dimensions(0), dnum.output_spatial_dimensions(1));
}
return Status::OK();
}
-StatusOr<XlaOp> XlaBuilder::AddInstruction(
- HloInstructionProto&& instr, HloOpcode opcode,
- tensorflow::gtl::ArraySlice<XlaOp> operands) {
+StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
+ HloOpcode opcode,
+ absl::Span<const XlaOp> operands) {
TF_RETURN_IF_ERROR(first_error_);
const int64 handle = instructions_.size();
@@ -2385,13 +2423,11 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(
}
for (const auto& operand : operands) {
if (operand.builder_ == nullptr) {
- return InvalidArgument("invalid XlaOp with handle %lld",
- operand.handle());
+ return InvalidArgument("invalid XlaOp with handle %d", operand.handle());
}
if (operand.builder_ != this) {
return InvalidArgument("Do not add XlaOp from builder %s to builder %s",
- operand.builder_->name().c_str(),
- this->name().c_str());
+ operand.builder_->name(), this->name());
}
instr.add_operand_ids(operand.handle());
}
@@ -2421,18 +2457,18 @@ StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
if (op.builder_ == nullptr) {
return InvalidArgument(
- "invalid XlaOp with handle %lld; the builder of this op is freed",
+ "invalid XlaOp with handle %d; the builder of this op is freed",
op.handle());
}
if (op.builder_ != this) {
return InvalidArgument(
- "XlaOp with handle %lld is built by builder '%s', but is trying to use "
+ "XlaOp with handle %d is built by builder '%s', but is trying to use "
"it in builder '%s'",
- op.handle(), op.builder_->name().c_str(), this->name().c_str());
+ op.handle(), op.builder_->name(), this->name());
}
if (op.handle() >= instructions_.size() || op.handle() < 0) {
- return InvalidArgument("no XlaOp value %lld", op.handle());
+ return InvalidArgument("no XlaOp value %d", op.handle());
}
return &instructions_[op.handle()];
}
@@ -2450,14 +2486,12 @@ XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) {
return builder->ConstantLiteral(literal);
}
-XlaOp Broadcast(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
+XlaOp Broadcast(const XlaOp& operand, absl::Span<const int64> broadcast_sizes) {
return operand.builder()->Broadcast(operand, broadcast_sizes);
}
-XlaOp BroadcastInDim(
- const XlaOp& operand, const Shape& shape,
- const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape,
+ const absl::Span<const int64> broadcast_dimensions) {
return operand.builder()->BroadcastInDim(operand, shape,
broadcast_dimensions);
}
@@ -2467,26 +2501,22 @@ XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
return operand.builder()->Pad(operand, padding_value, padding_config);
}
-XlaOp Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes) {
+XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
+ absl::Span<const int64> new_sizes) {
return operand.builder()->Reshape(operand, dimensions, new_sizes);
}
-XlaOp Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> new_sizes) {
+XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes) {
return operand.builder()->Reshape(operand, new_sizes);
}
-XlaOp Collapse(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
+XlaOp Collapse(const XlaOp& operand, absl::Span<const int64> dimensions) {
return operand.builder()->Collapse(operand, dimensions);
}
-XlaOp Slice(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides) {
+XlaOp Slice(const XlaOp& operand, absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides) {
return operand.builder()->Slice(operand, start_indices, limit_indices,
strides);
}
@@ -2498,7 +2528,7 @@ XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
}
XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ absl::Span<const int64> slice_sizes) {
return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes);
}
@@ -2507,8 +2537,7 @@ XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
return operand.builder()->DynamicUpdateSlice(operand, update, start_indices);
}
-XlaOp ConcatInDim(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
+XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
int64 dimension) {
return builder->ConcatInDim(operands, dimension);
}
@@ -2521,7 +2550,7 @@ XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false) {
return pred.builder()->Select(pred, on_true, on_false);
}
-XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> elements) {
+XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements) {
return builder->Tuple(elements);
}
@@ -2530,94 +2559,98 @@ XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index) {
}
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Eq(lhs, rhs, broadcast_dimensions);
}
XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Ne(lhs, rhs, broadcast_dimensions);
}
XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Ge(lhs, rhs, broadcast_dimensions);
}
XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Gt(lhs, rhs, broadcast_dimensions);
}
XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Lt(lhs, rhs, broadcast_dimensions);
}
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Le(lhs, rhs, broadcast_dimensions);
}
-XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs) {
- return lhs.builder()->Dot(lhs, rhs);
+XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
+ const PrecisionConfig* precision_config) {
+ return lhs.builder()->Dot(lhs, rhs, precision_config);
}
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
- const DotDimensionNumbers& dimension_numbers) {
- return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers);
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfig* precision_config) {
+ return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers,
+ precision_config);
}
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- int64 feature_group_count) {
+ absl::Span<const int64> window_strides, Padding padding,
+ int64 feature_group_count, const PrecisionConfig* precision_config) {
return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
- feature_group_count);
+ feature_group_count, precision_config);
}
-XlaOp ConvWithGeneralPadding(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- int64 feature_group_count) {
- return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides,
- padding, feature_group_count);
+XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ int64 feature_group_count,
+ const PrecisionConfig* precision_config) {
+ return lhs.builder()->ConvWithGeneralPadding(
+ lhs, rhs, window_strides, padding, feature_group_count, precision_config);
}
XlaOp ConvWithGeneralDimensions(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count) {
- return lhs.builder()->ConvWithGeneralDimensions(lhs, rhs, window_strides,
- padding, dimension_numbers,
- feature_group_count);
+ const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
+ Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count, const PrecisionConfig* precision_config) {
+ return lhs.builder()->ConvWithGeneralDimensions(
+ lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
+ precision_config);
}
XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count) {
+ int64 feature_group_count,
+ const PrecisionConfig* precision_config) {
return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding,
- dimension_numbers, feature_group_count);
-}
-
-XlaOp ConvGeneralDilated(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count) {
+ dimension_numbers, feature_group_count,
+ precision_config);
+}
+
+XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation,
+ absl::Span<const int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count,
+ const PrecisionConfig* precision_config) {
return lhs.builder()->ConvGeneralDilated(
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
- dimension_numbers, feature_group_count);
+ dimension_numbers, feature_group_count, precision_config);
}
XlaOp Fft(const XlaOp& operand, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length) {
+ absl::Span<const int64> fft_length) {
return operand.builder()->Fft(operand, fft_type, fft_length);
}
@@ -2631,106 +2664,106 @@ void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
}
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<XlaOp> operands) {
+ absl::Span<const XlaOp> operands) {
return builder->Call(computation, operands);
}
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- const Shape& shape) {
+ absl::Span<const XlaOp> operands, const Shape& shape) {
return builder->CustomCall(call_target_name, operands, shape);
}
-XlaOp HostCompute(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- const string& channel_name, int64 cost_estimate_ns,
- const Shape& shape) {
- return builder->HostCompute(operands, channel_name, cost_estimate_ns, shape);
-}
-
XlaOp Complex(const XlaOp& real, const XlaOp& imag,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return real.builder()->Complex(real, imag, broadcast_dimensions);
}
XlaOp Conj(const XlaOp& operand) { return operand.builder()->Conj(operand); }
XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Add(lhs, rhs, broadcast_dimensions);
}
XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Sub(lhs, rhs, broadcast_dimensions);
}
XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Mul(lhs, rhs, broadcast_dimensions);
}
XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Div(lhs, rhs, broadcast_dimensions);
}
XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Rem(lhs, rhs, broadcast_dimensions);
}
XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Max(lhs, rhs, broadcast_dimensions);
}
XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Min(lhs, rhs, broadcast_dimensions);
}
XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->And(lhs, rhs, broadcast_dimensions);
}
XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Or(lhs, rhs, broadcast_dimensions);
}
XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Xor(lhs, rhs, broadcast_dimensions);
}
XlaOp Not(const XlaOp& operand) { return operand.builder()->Not(operand); }
XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->ShiftLeft(lhs, rhs, broadcast_dimensions);
}
-XlaOp ShiftRightArithmetic(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->ShiftRightArithmetic(lhs, rhs, broadcast_dimensions);
}
-XlaOp ShiftRightLogical(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->ShiftRightLogical(lhs, rhs, broadcast_dimensions);
}
XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
+ absl::Span<const int64> dimensions_to_reduce) {
return operand.builder()->Reduce(operand, init_value, computation,
dimensions_to_reduce);
}
+// Reduces several arrays simultaneously among the provided dimensions, given
+// "computation" as a reduction operator.
+XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
+ absl::Span<const XlaOp> init_values,
+ const XlaComputation& computation,
+ absl::Span<const int64> dimensions_to_reduce) {
+ return builder->Reduce(operands, init_values, computation,
+ dimensions_to_reduce);
+}
+
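A sketch of the new builder-level Reduce overload. With one operand/init pair it behaves like the classic form; with N pairs the reduction computation takes 2N scalar parameters and returns an N-tuple. Names below are illustrative.

xla::XlaOp RowSums(xla::XlaBuilder* b, xla::XlaOp data) {
  // Scalar add computation used as the reduction operator.
  auto sub = b->CreateSubBuilder("add");
  xla::Add(
      xla::Parameter(sub.get(), 0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x"),
      xla::Parameter(sub.get(), 1, xla::ShapeUtil::MakeShape(xla::F32, {}), "y"));
  xla::XlaComputation add = sub->BuildAndNoteError();
  xla::XlaOp zero = xla::ConstantR0<float>(b, 0.0f);
  // For F32[4, 8] data this sums away dimension 1, yielding F32[4].
  return xla::Reduce(b, {data}, {zero}, add, /*dimensions_to_reduce=*/{1});
}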
XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation) {
return operand.builder()->ReduceAll(operand, init_value, computation);
@@ -2738,9 +2771,8 @@ XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding) {
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides, Padding padding) {
return operand.builder()->ReduceWindow(operand, init_value, computation,
window_dimensions, window_strides,
padding);
@@ -2749,25 +2781,24 @@ XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
XlaOp ReduceWindowWithGeneralPadding(
const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding) {
return operand.builder()->ReduceWindowWithGeneralPadding(
operand, init_value, computation, window_dimensions, window_strides,
padding);
}
XlaOp CrossReplicaSum(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids) {
- return operand.builder()->CrossReplicaSum(operand, replica_group_ids);
+ absl::Span<const ReplicaGroup> replica_groups) {
+ return operand.builder()->CrossReplicaSum(operand, replica_groups);
}
-XlaOp CrossReplicaSum(
- const XlaOp& operand, const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids,
- const tensorflow::gtl::optional<ChannelHandle>& channel_id) {
+XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation,
+ absl::Span<const ReplicaGroup> replica_groups,
+ const absl::optional<ChannelHandle>& channel_id) {
return operand.builder()->CrossReplicaSum(operand, computation,
- replica_group_ids, channel_id);
+ replica_groups, channel_id);
}
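Callers now pass ReplicaGroup protos instead of flat replica id lists. A sketch matching the {0,2},{1,3} grouping described in the header comments, assuming the proto's repeated replica_ids field:

xla::XlaOp GroupedSum(xla::XlaOp operand) {
  xla::ReplicaGroup g0, g1;
  g0.add_replica_ids(0);
  g0.add_replica_ids(2);
  g1.add_replica_ids(1);
  g1.add_replica_ids(3);
  // Sums within {0, 2} and within {1, 3} independently.
  return xla::CrossReplicaSum(operand, {g0, g1});
}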
XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
@@ -2777,11 +2808,17 @@ XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
split_count, replica_groups);
}
+XlaOp CollectivePermute(
+ const XlaOp& operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs) {
+ return operand.builder()->CollectivePermute(operand, source_target_pairs);
+}
+
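A sketch of the new CollectivePermute wrapper: each (source, target) pair forwards the source replica's operand to the target replica.

xla::XlaOp RotateRight(xla::XlaOp operand) {
  // Rotate data one replica to the right among four replicas.
  std::vector<std::pair<xla::int64, xla::int64>> pairs =
      {{0, 1}, {1, 2}, {2, 3}, {3, 0}};
  return xla::CollectivePermute(operand, pairs);
}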
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding, const XlaOp& source,
- const XlaOp& init_value, const XlaComputation& scatter) {
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides, Padding padding,
+ const XlaOp& source, const XlaOp& init_value,
+ const XlaComputation& scatter) {
return operand.builder()->SelectAndScatter(operand, select, window_dimensions,
window_strides, padding, source,
init_value, scatter);
@@ -2789,11 +2826,10 @@ XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
XlaOp SelectAndScatterWithGeneralPadding(
const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const XlaOp& source, const XlaOp& init_value,
- const XlaComputation& scatter) {
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
+ const XlaOp& init_value, const XlaComputation& scatter) {
return operand.builder()->SelectAndScatterWithGeneralPadding(
operand, select, window_dimensions, window_strides, padding, source,
init_value, scatter);
@@ -2802,7 +2838,7 @@ XlaOp SelectAndScatterWithGeneralPadding(
XlaOp Abs(const XlaOp& operand) { return operand.builder()->Abs(operand); }
XlaOp Atan2(const XlaOp& y, const XlaOp& x,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return y.builder()->Atan2(y, x, broadcast_dimensions);
}
@@ -2835,7 +2871,7 @@ XlaOp Real(const XlaOp& operand) { return operand.builder()->Real(operand); }
XlaOp Imag(const XlaOp& operand) { return operand.builder()->Imag(operand); }
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->Pow(lhs, rhs, broadcast_dimensions);
}
@@ -2853,17 +2889,15 @@ XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) {
XlaOp Neg(const XlaOp& operand) { return operand.builder()->Neg(operand); }
-XlaOp Transpose(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> permutation) {
+XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation) {
return operand.builder()->Transpose(operand, permutation);
}
-XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
+XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions) {
return operand.builder()->Rev(operand, dimensions);
}
-XlaOp Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values,
- int64 dimension) {
+XlaOp Sort(XlaOp keys, absl::optional<XlaOp> values, int64 dimension) {
return keys.builder()->Sort(keys, std::move(values), dimension);
}
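With the absl::optional signature, the key-only and key-value forms read as below (sketch; keys and values are illustrative arrays of matching shape):

// Keys only: sorts along the last dimension.
xla::XlaOp sorted_keys = xla::Sort(keys, absl::nullopt, /*dimension=*/-1);
// Key-value form: values are permuted alongside keys, and the result is a
// (sorted_keys, reordered_values) tuple.
xla::XlaOp sorted_pair = xla::Sort(keys, values, /*dimension=*/-1);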
@@ -2871,10 +2905,9 @@ XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) {
return min.builder()->Clamp(min, operand, max);
}
-XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> operands,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<XlaOp> static_operands) {
+XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
+ const XlaComputation& computation, absl::Span<const int64> dimensions,
+ absl::Span<const XlaOp> static_operands) {
return builder->Map(operands, computation, dimensions, static_operands);
}
@@ -2906,11 +2939,11 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
mantissa_bits);
}
-XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds) {
- return input.builder()->Gather(input, gather_indices, dimension_numbers,
- window_bounds);
+ absl::Span<const int64> slice_sizes) {
+ return input.builder()->Gather(input, start_indices, dimension_numbers,
+ slice_sizes);
}
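The Gather parameter renames (gather_indices to start_indices, window_bounds to slice_sizes) track the updated gather semantics: each start index selects a slice of size slice_sizes from the input. A sketch with the dimension-number fields elided:

xla::GatherDimensionNumbers dnums;
// ... populate dnums for the desired gather (fields elided in this sketch) ...
// For an F32[8, 16] input, each start index picks out an F32[1, 16] slice.
xla::XlaOp gathered =
    xla::Gather(input, start_indices, dnums, /*slice_sizes=*/{1, 16});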
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
@@ -2964,7 +2997,7 @@ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); }
-XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> tokens) {
+XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens) {
return builder->AfterAll(tokens);
}
@@ -2991,11 +3024,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
grad_output, epsilon, feature_index);
}
-XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size) {
- HloInstructionProto instr;
- *instr.mutable_shape() = ShapeUtil::MakeShape(type, {size});
- return builder->ReportErrorOrReturn(
- builder->AddInstruction(std::move(instr), HloOpcode::kIota));
+XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) {
+ return builder->Iota(type, size);
+}
+
+XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) {
+ return builder->Iota(shape, iota_dimension);
}
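The renamed entry points in a short sketch:

xla::XlaBuilder b("iota_example");
// Rank-1 form: S32[10] containing [0, 1, ..., 9].
xla::XlaOp v = xla::Iota(&b, xla::S32, 10);
// Shaped form: F32[3, 4] whose values count up along dimension 1.
xla::XlaOp m = xla::Iota(&b, xla::ShapeUtil::MakeShape(xla::F32, {3, 4}), 1);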
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 78aec770a6..833eafcf85 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -21,6 +21,8 @@ limitations under the License.
#include <type_traits>
#include <utility>
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
@@ -32,8 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stacktrace.h"
@@ -154,12 +154,10 @@ class XlaBuilder {
// Clears the sharding. Ops will be sharded according to the default placement
// policy.
- void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; }
+ void ClearSharding() { sharding_ = absl::nullopt; }
// Returns the OpSharding that will be attached to all instructions.
- const tensorflow::gtl::optional<OpSharding>& sharding() const {
- return sharding_;
- }
+ const absl::optional<OpSharding>& sharding() const { return sharding_; }
// Sets the builder to a mode where it will die immediately when an error is
// encountered, rather than producing it in a deferred fashion when Build() is
@@ -296,7 +294,7 @@ class XlaBuilder {
template <typename NativeT>
XlaOp ConstantR0(NativeT value);
template <typename NativeT>
- XlaOp ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values);
+ XlaOp ConstantR1(absl::Span<const NativeT> values);
XlaOp ConstantR1(const tensorflow::core::Bitmap& values);
template <typename NativeT>
XlaOp ConstantR2(
@@ -338,7 +336,7 @@ class XlaBuilder {
//
// output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
XlaOp Broadcast(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+ absl::Span<const int64> broadcast_sizes);
// Performs in-dimension-style broadcast.
//
@@ -357,9 +355,8 @@ class XlaBuilder {
// will generate output
// [1 , 1]
// [2 , 2]
- XlaOp BroadcastInDim(
- const XlaOp& operand, const Shape& shape,
- const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape,
+ const absl::Span<const int64> broadcast_dimensions);
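The example in the comment above corresponds to the following call (sketch, using the free-function form declared later in this header):

// operand is F32[2] with values {1, 2}; mapping operand dimension 0 to
// output dimension 0 yields [[1, 1], [2, 2]].
xla::XlaOp out = xla::BroadcastInDim(
    operand, xla::ShapeUtil::MakeShape(xla::F32, {2, 2}),
    /*broadcast_dimensions=*/{0});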
// Enqueues a pad operation onto the computation that pads the given value on
// the edges as well as between the elements of the input. padding_config
@@ -372,15 +369,13 @@ class XlaBuilder {
// given, followed by reshaping it into the shape with the given dimension
// sizes (also major to minor). Conceptually, this is a limited form of
// "shape casting".
- XlaOp Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
+ XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
+ absl::Span<const int64> new_sizes);
// Enqueues an operation onto the computation that collapses the operand, from
// first to last dimension (C order), then reshapes it to the given dimension
// sizes. Conceptually, this is a limited form of "shape casting".
- XlaOp Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
+ XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);
// Wrapper for Reshape.
// Enqueues an operation to collapse the provided dimensions; e.g. an
@@ -400,8 +395,7 @@ class XlaBuilder {
//
// This could potentially cause data to be moved -- it provides a more
// structured form of reshaping than an arbitrary Reshape operation.
- XlaOp Collapse(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ XlaOp Collapse(const XlaOp& operand, absl::Span<const int64> dimensions);
// Enqueues a slice operation onto the computation that slices the operand
// from the start indices to the limit indices; e.g.
@@ -414,10 +408,9 @@ class XlaBuilder {
// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
// range notation.
  // The strides parameter determines the stride over the slice.
- XlaOp Slice(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
+ XlaOp Slice(const XlaOp& operand, absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides);
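A strided-slice sketch against the new Span signature:

// For an F32[4, 8] operand: all rows, every other column -> F32[4, 4].
xla::XlaOp s = xla::Slice(operand,
                          /*start_indices=*/{0, 0},
                          /*limit_indices=*/{4, 8},
                          /*strides=*/{1, 2});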
// Enqueues a slice operation in a given dimension, taking all other
// dimensions as they are; e.g. if dimno is 1 from start_index 2 to
@@ -438,7 +431,7 @@ class XlaBuilder {
// Slice index calculations are computed modulo input dimension sizes to
// prevent dynamic start indices from generating out-of-bound array accesses.
XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
// Enqueues a dynamic update slice operation onto the computation, which
// updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
@@ -461,8 +454,7 @@ class XlaBuilder {
// Enqueues a concatenate instruction onto the computation. 'operands' must
// have >= 1 entry.
- XlaOp ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands,
- int64 dimension);
+ XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);
  // Enqueues a tracing operation onto the computation; the computation will emit
// a logging message with the operand.
@@ -473,88 +465,93 @@ class XlaBuilder {
XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);
// Enqueues a tuple-creation instruction onto the computation.
- XlaOp Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements);
+ XlaOp Tuple(absl::Span<const XlaOp> elements);
// Enqueues a tuple-element-get instruction onto the computation.
XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
// Enqueues an equal-to comparison instruction onto the computation.
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a not-equal comparison instruction onto the computation.
XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a greater-or-equal comparison instruction onto the computation.
XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a greater-than comparison instruction onto the computation.
XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a less-than comparison instruction onto the computation.
XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a less-or-equal comparison instruction onto the computation.
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a dot instruction onto the computation.
- XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs);
+ XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a general dot instruction onto the computation.
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
- const DotDimensionNumbers& dimension_numbers);
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- int64 feature_group_count = 1);
+ absl::Span<const int64> window_strides, Padding padding,
+ int64 feature_group_count = 1,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
XlaOp ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- int64 feature_group_count = 1);
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ int64 feature_group_count = 1,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ absl::Span<const int64> window_strides, Padding padding,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1);
+ int64 feature_group_count = 1,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
- XlaOp ConvGeneral(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1);
+ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
- XlaOp ConvGeneralDilated(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1);
+ XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation,
+ absl::Span<const int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
XlaOp Fft(const XlaOp& operand, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length);
+ absl::Span<const int64> fft_length);
// Enqueues an infeed instruction onto the computation, which writes data of
// the given shape to the infeed buffer of the device.
@@ -576,25 +573,14 @@ class XlaBuilder {
// Enqueues a call instruction onto the computation.
XlaOp Call(const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<XlaOp> operands);
+ absl::Span<const XlaOp> operands);
// Enqueues a custom call instruction onto the computation.
// During code generation, a call instruction is emitted which targets a
// symbol with the name |call_target_name|. The |operands| are passed to the
// call instruction. |shape| is the resultant shape.
XlaOp CustomCall(const string& call_target_name,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- const Shape& shape);
-
- // Enqueues a pseudo-op to represent host-side computation data-dependencies.
- // During code generation, host send and receive operations will be generated
- // to transfer |operands| to the host and a single result of |shape| back to
- // the device. Host send/recv operations are emitted using |channel_name|.
- // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO
- // instruction scheduling.
- XlaOp HostCompute(tensorflow::gtl::ArraySlice<XlaOp> operands,
- const string& channel_name, int64 cost_estimate_ns,
- const Shape& shape);
+ absl::Span<const XlaOp> operands, const Shape& shape);
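A sketch of the surviving CustomCall entry point; "my_custom_op" is a hypothetical symbol name, not one registered by XLA, and b, x, y are assumed in scope:

// Targets the symbol "my_custom_op" with operands x and y; the result is
// declared (unchecked) to be F32[16].
xla::XlaOp out = xla::CustomCall(
    &b, "my_custom_op", {x, y}, xla::ShapeUtil::MakeShape(xla::F32, {16}));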
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
@@ -603,65 +589,70 @@ class XlaBuilder {
// Enqueues a complex compose instruction onto the computation.
XlaOp Complex(const XlaOp& real, const XlaOp& imag,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a complex conjugate instruction onto the computation.
XlaOp Conj(const XlaOp& operand);
// Enqueues an add instruction onto the computation.
XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a subtract instruction onto the computation.
XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a multiply instruction onto the computation.
XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a divide instruction onto the computation.
XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a remainder instruction onto the computation.
XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a max instruction onto the computation.
XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a min instruction onto the computation.
XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Element-wise logical operators
XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
XlaOp Not(const XlaOp& operand);
XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
- XlaOp ShiftRightArithmetic(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
- XlaOp ShiftRightLogical(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
+ XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> broadcast_dimensions = {});
+ XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> broadcast_dimensions = {});
// Reduces an array among the provided dimensions, given "computation" as a
// reduction operator.
XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+ absl::Span<const int64> dimensions_to_reduce);
+
+ // Reduces several arrays simultaneously among the provided dimensions, given
+ // "computation" as a reduction operator.
+ XlaOp Reduce(absl::Span<const XlaOp> operands,
+ absl::Span<const XlaOp> init_values,
+ const XlaComputation& computation,
+ absl::Span<const int64> dimensions_to_reduce);
// Convenience wrapper around the above that reduces all the dimensions in the
// operand shape.
@@ -671,25 +662,23 @@ class XlaBuilder {
// Enqueues a windowed reduce instruction onto the computation.
XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding);
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides, Padding padding);
// As ReduceWindow(), but the padding is given in the format
// returned by MakePadding().
XlaOp ReduceWindowWithGeneralPadding(
const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding);
// Returns the sum of the operand value within each subgroup of replicas. All
// replicas supply one input to the sum and all replicas receive the resulting
// sum for each subgroup.
- XlaOp CrossReplicaSum(
- const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids = {});
+ XlaOp CrossReplicaSum(const XlaOp& operand,
+ absl::Span<const ReplicaGroup> replica_groups = {});
  // Enqueues an operation that does an AllReduce of the operand across cores. Here
  // AllReduce means doing a reduction on the input operand across cores and then
@@ -698,10 +687,11 @@ class XlaBuilder {
// scalars, e.g., add, min, or max. The way that AllReduce is applied is
// configured by:
//
- // - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all
- // replicas belong to one group. Allreduce will be applied within subgroups.
- // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means,
- // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
+ // - `replica_groups`: each ReplicaGroup contains a list of replica ids. If
+ // empty, all replicas belong to one group. Allreduce will be applied within
+ // subgroups. For example, with 4 replicas, replica_groups={{0,2},{1,3}}
+ // means replicas 0 and 2 are in subgroup 0, and replicas 1 and 3 are in
+ // subgroup 1.
//
// - `channel_id`: for Allreduce nodes from different modules, if they have
// the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will
@@ -710,22 +700,25 @@ class XlaBuilder {
// TODO(b/79737069): Rename this to AllReduce when it's ready to use.
XlaOp CrossReplicaSum(
const XlaOp& operand, const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids = {},
- const tensorflow::gtl::optional<ChannelHandle>& channel_id =
- tensorflow::gtl::nullopt);
+ absl::Span<const ReplicaGroup> replica_groups = {},
+ const absl::optional<ChannelHandle>& channel_id = absl::nullopt);
  // Enqueues an operation that does an AllToAll of the operand across cores.
- //
- // TODO(b/110096724): This is NOT YET ready to use.
XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups);
+ // Enqueues an operation that does a CollectivePermute of the operand
+ // across cores.
+ XlaOp CollectivePermute(
+ const XlaOp& operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs);
+
// Enqueues an operation that scatters the `source` array to the selected
// indices of each window.
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
Padding padding, const XlaOp& source,
const XlaOp& init_value,
const XlaComputation& scatter);
@@ -734,18 +727,17 @@ class XlaBuilder {
// returned by MakePadding().
XlaOp SelectAndScatterWithGeneralPadding(
const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const XlaOp& source, const XlaOp& init_value,
- const XlaComputation& scatter);
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
+ const XlaOp& init_value, const XlaComputation& scatter);
// Enqueues an abs instruction onto the computation.
XlaOp Abs(const XlaOp& operand);
  // Enqueues an atan2 instruction onto the computation.
XlaOp Atan2(const XlaOp& y, const XlaOp& x,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues an exp instruction onto the computation.
XlaOp Exp(const XlaOp& operand);
@@ -792,7 +784,7 @@ class XlaBuilder {
// Enqueues a lhs^rhs computation onto the computation.
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues an operator that tests if the operand's values are finite, i.e.,
// not Inf or NaN. Defined only for floating-point types. Returns an array of
@@ -800,6 +792,12 @@ class XlaBuilder {
// entry was NaN.
XlaOp IsFinite(const XlaOp& operand);
+ // Enqueues an iota operation onto the computation.
+ XlaOp Iota(const Shape& shape, int64 iota_dimension);
+
+ // Enqueues a rank-1 iota operation onto the computation.
+ XlaOp Iota(PrimitiveType type, int64 size);
+
// Enqueues a convert instruction onto the computation that changes the
// element type of the operand array to primitive_type.
XlaOp ConvertElementType(const XlaOp& operand,
@@ -816,14 +814,12 @@ class XlaBuilder {
XlaOp Neg(const XlaOp& operand);
// Enqueues a transpose instruction onto the computation.
- XlaOp Transpose(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> permutation);
+ XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation);
// Enqueues a reverse instruction onto the computation. The order of the
// elements in the given dimensions is reversed (i.e., the element at index i
// is moved to index dimension_size - 1 - i).
- XlaOp Rev(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);
// Enqueues a sort (as increasing order) instruction onto the computation.
// If only keys are provided:
@@ -841,18 +837,16 @@ class XlaBuilder {
// * The result is a tuple that consists of a sorted tensor of keys (along the
// provided dimension, as above) as the first element, and a tensor with their
// corresponding values as the second element.
- XlaOp Sort(XlaOp keys,
- tensorflow::gtl::optional<XlaOp> values = tensorflow::gtl::nullopt,
+ XlaOp Sort(XlaOp keys, absl::optional<XlaOp> values = absl::nullopt,
int64 dimension = -1);
// Enqueues a clamp instruction onto the computation.
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
// Enqueues a map instruction onto the computation.
- XlaOp Map(tensorflow::gtl::ArraySlice<XlaOp> operands,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<XlaOp> static_operands = {});
+ XlaOp Map(absl::Span<const XlaOp> operands, const XlaComputation& computation,
+ absl::Span<const int64> dimensions,
+ absl::Span<const XlaOp> static_operands = {});
  // Enqueues an N(mu, sigma) random number generation instruction onto the
// computation.
@@ -877,9 +871,9 @@ class XlaBuilder {
const int mantissa_bits);
// Enqueues a Gather node onto the computation.
- XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+ XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ absl::Span<const int64> slice_sizes);
// Enqueues a Scatter node onto the computation.
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
@@ -907,7 +901,7 @@ class XlaBuilder {
// Enqueues an AfterAll operation with no operands producing a token-shaped
// value.
- XlaOp AfterAll(tensorflow::gtl::ArraySlice<XlaOp> tokens);
+ XlaOp AfterAll(absl::Span<const XlaOp> tokens);
// Enqueues a Recv node onto the computation. The data comes from a Send
// instruction that shares the same channel handle and its shape must
@@ -954,9 +948,8 @@ class XlaBuilder {
const XlaOp& grad_output, float epsilon,
int64 feature_index);
- StatusOr<XlaOp> AddInstruction(
- HloInstructionProto&& instr, HloOpcode opcode,
- tensorflow::gtl::ArraySlice<XlaOp> operands = {});
+ StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
+ absl::Span<const XlaOp> operands = {});
void AddCalledComputation(const XlaComputation& computation,
HloInstructionProto* instr);
@@ -970,19 +963,17 @@ class XlaBuilder {
// broadcast_dimensions specifies which dimensions to use for broadcasting
// when the operation is between tensors of different ranks.
XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
// Internal helper method that does the building for an arbitrary ternary op.
XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
const XlaOp& ehs);
XlaOp RngOp(RandomDistribution distribution,
- tensorflow::gtl::ArraySlice<XlaOp> parameters,
- const Shape& shape);
+ absl::Span<const XlaOp> parameters, const Shape& shape);
- StatusOr<XlaOp> InDimBroadcast(
- const Shape& shape, const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ StatusOr<XlaOp> InDimBroadcast(const Shape& shape, const XlaOp& operand,
+ absl::Span<const int64> broadcast_dimensions);
// Internal helper method that creates a sequence of instructions that
// performs an explicit broadcast of the operand to the target shape.
@@ -998,7 +989,7 @@ class XlaBuilder {
// Returns shapes for the operands.
StatusOr<std::vector<Shape>> GetOperandShapes(
- tensorflow::gtl::ArraySlice<XlaOp> operands) const;
+ absl::Span<const XlaOp> operands) const;
// A visitor which checks whether an operation is a compile-time constant,
// meaning that it doesn't depend on any parameters, or on any stateful
@@ -1015,12 +1006,11 @@ class XlaBuilder {
// Helper function for creating a Window proto from user-supplied data.
  // Returns an error if the user-supplied data was invalid.
- StatusOr<Window> MakeWindow(
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation) const;
+ StatusOr<Window> MakeWindow(absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation,
+ absl::Span<const int64> rhs_dilation) const;
string name_; // Name to use for the built computation.
@@ -1049,7 +1039,7 @@ class XlaBuilder {
// Sharding for this operator. This is structured as a "model"-like operation,
// in order to simplify client code, similar to metadata_.
- tensorflow::gtl::optional<OpSharding> sharding_;
+ absl::optional<OpSharding> sharding_;
// Mode bit that indicates whether to die when a first error is encountered.
bool die_immediately_on_error_ = false;
@@ -1064,7 +1054,7 @@ class XlaBuilder {
friend XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
template <typename NativeT>
friend XlaOp ConstantR1(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<NativeT> values);
+ absl::Span<const NativeT> values);
friend XlaOp ConstantR1(XlaBuilder* builder,
const tensorflow::core::Bitmap& values);
template <typename NativeT>
@@ -1104,182 +1094,180 @@ class XlaBuilder {
friend XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);
friend XlaOp Broadcast(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+ absl::Span<const int64> broadcast_sizes);
friend XlaOp BroadcastInDim(
const XlaOp& operand, const Shape& shape,
- const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ const absl::Span<const int64> broadcast_dimensions);
friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
const PaddingConfig& padding_config);
- friend XlaOp Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
+ friend XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
+ absl::Span<const int64> new_sizes);
- friend XlaOp Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
+ friend XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);
friend XlaOp Collapse(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ absl::Span<const int64> dimensions);
friend XlaOp Slice(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides);
friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index,
int64 limit_index, int64 stride, int64 dimno);
friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
const XlaOp& start_indices);
friend XlaOp ConcatInDim(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- int64 dimension);
+ absl::Span<const XlaOp> operands, int64 dimension);
friend void Trace(const string& tag, const XlaOp& operand);
friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true,
const XlaOp& on_false);
- friend XlaOp Tuple(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> elements);
+ friend XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs);
+ absl::Span<const int64> broadcast_dimensions);
+ friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
+ const PrecisionConfig* precision_config);
friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
- const DotDimensionNumbers& dimension_numbers);
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfig* precision_config);
friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding, int64 feature_group_count);
+ absl::Span<const int64> window_strides, Padding padding,
+ int64 feature_group_count,
+ const PrecisionConfig* precision_config);
friend XlaOp ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- int64 feature_group_count);
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ int64 feature_group_count, const PrecisionConfig* precision_config);
friend XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count);
- friend XlaOp ConvGeneral(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ absl::Span<const int64> window_strides, Padding padding,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count);
+ int64 feature_group_count, const PrecisionConfig* precision_config);
+ friend XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count,
+ const PrecisionConfig* precision_config);
friend XlaOp ConvGeneralDilated(
const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation,
+ absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count);
+ int64 feature_group_count, const PrecisionConfig* precision_config);
friend XlaOp Fft(const XlaOp& operand, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length);
+ absl::Span<const int64> fft_length);
friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
const string& config);
friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
const string& outfeed_config);
friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<XlaOp> operands);
+ absl::Span<const XlaOp> operands);
friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- const Shape& shape);
- friend XlaOp HostCompute(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- const string& channel_name, int64 cost_estimate_ns,
- const Shape& shape);
+ absl::Span<const XlaOp> operands, const Shape& shape);
friend XlaOp Complex(const XlaOp& real, const XlaOp& imag,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Conj(const XlaOp& operand);
friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Not(const XlaOp& operand);
- friend XlaOp ShiftLeft(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp ShiftRightArithmetic(
const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- friend XlaOp ShiftRightLogical(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
+ friend XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+ absl::Span<const int64> dimensions_to_reduce);
+ friend XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
+ absl::Span<const XlaOp> init_values,
+ const XlaComputation& computation,
+ absl::Span<const int64> dimensions_to_reduce);
friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation);
- friend XlaOp ReduceWindow(
- const XlaOp& operand, const XlaOp& init_value,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);
+ friend XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ Padding padding);
friend XlaOp ReduceWindowWithGeneralPadding(
const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
- friend XlaOp CrossReplicaSum(
- const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids);
- friend XlaOp CrossReplicaSum(
- const XlaOp& operand, const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids,
- const tensorflow::gtl::optional<ChannelHandle>& channel_id);
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding);
+ friend XlaOp CrossReplicaSum(const XlaOp& operand,
+ absl::Span<const ReplicaGroup> replica_groups);
+ friend XlaOp CrossReplicaSum(const XlaOp& operand,
+ const XlaComputation& computation,
+ absl::Span<const ReplicaGroup> replica_groups,
+ const absl::optional<ChannelHandle>& channel_id);
friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups);
- friend XlaOp SelectAndScatter(
- const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const XlaOp& source, const XlaOp& init_value,
- const XlaComputation& scatter);
+ friend XlaOp CollectivePermute(
+ const XlaOp& operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs);
+ friend XlaOp SelectAndScatter(const XlaOp& operand,
+ const XlaComputation& select,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ Padding padding, const XlaOp& source,
+ const XlaOp& init_value,
+ const XlaComputation& scatter);
friend XlaOp SelectAndScatterWithGeneralPadding(
const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const XlaOp& source, const XlaOp& init_value,
- const XlaComputation& scatter);
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
+ const XlaOp& init_value, const XlaComputation& scatter);
friend XlaOp Abs(const XlaOp& operand);
friend XlaOp Atan2(const XlaOp& y, const XlaOp& x,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp Exp(const XlaOp& operand);
friend XlaOp Expm1(const XlaOp& operand);
friend XlaOp Floor(const XlaOp& operand);
@@ -1295,28 +1283,25 @@ class XlaBuilder {
friend XlaOp Real(const XlaOp& operand);
friend XlaOp Imag(const XlaOp& operand);
friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
friend XlaOp IsFinite(const XlaOp& operand);
- // TODO(b/64798317): Finish CPU & GPU implementation, then replace xla::Iota
- // in xla/client/lib/numeric.h with this (renamed to xla::Iota).
- friend XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size);
+ friend XlaOp Iota(XlaBuilder* builder, const Shape& shape,
+ int64 iota_dimension);
+ friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
friend XlaOp ConvertElementType(const XlaOp& operand,
PrimitiveType new_element_type);
friend XlaOp BitcastConvertType(const XlaOp& operand,
PrimitiveType new_element_type);
friend XlaOp Neg(const XlaOp& operand);
friend XlaOp Transpose(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> permutation);
- friend XlaOp Rev(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
- friend XlaOp Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values,
- int64 dimension);
+ absl::Span<const int64> permutation);
+ friend XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);
+ friend XlaOp Sort(XlaOp keys, absl::optional<XlaOp> values, int64 dimension);
friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
- friend XlaOp Map(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
+ friend XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<XlaOp> static_operands);
+ absl::Span<const int64> dimensions,
+ absl::Span<const XlaOp> static_operands);
friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma,
const Shape& shape);
friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
@@ -1328,9 +1313,9 @@ class XlaBuilder {
const XlaComputation& false_computation);
friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
const int mantissa_bits);
- friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+ friend XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ absl::Span<const int64> slice_sizes);
friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
const XlaOp& updates,
const XlaComputation& update_computation,
@@ -1364,8 +1349,7 @@ class XlaBuilder {
const Shape& shape_with_layout,
const string& outfeed_config);
friend XlaOp CreateToken(XlaBuilder* builder);
- friend XlaOp AfterAll(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> tokens);
+ friend XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);
};
// RAII-style object: sets the current sharding assignment in builder on
@@ -1373,7 +1357,7 @@ class XlaBuilder {
class XlaScopedShardingAssignment {
public:
XlaScopedShardingAssignment(xla::XlaBuilder* builder,
- tensorflow::gtl::optional<OpSharding> sharding)
+ absl::optional<OpSharding> sharding)
: builder_(builder), prev_sharding_(builder->sharding()) {
SetSharding(sharding);
}
@@ -1385,7 +1369,7 @@ class XlaScopedShardingAssignment {
~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); }
private:
- void SetSharding(const tensorflow::gtl::optional<OpSharding>& sharding) {
+ void SetSharding(const absl::optional<OpSharding>& sharding) {
if (sharding.has_value()) {
builder_->SetSharding(sharding.value());
} else {
@@ -1394,7 +1378,7 @@ class XlaScopedShardingAssignment {
}
xla::XlaBuilder* const builder_;
- tensorflow::gtl::optional<OpSharding> prev_sharding_;
+ absl::optional<OpSharding> prev_sharding_;
};
// Free functions for building XlaOps. The intention is that these will
@@ -1429,8 +1413,7 @@ XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal);
template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
template <typename NativeT>
-XlaOp ConstantR1(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<NativeT> values);
+XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values);
XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values);
template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
@@ -1479,8 +1462,7 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);
// The new dimensions index into copies of the operand, i.e.
//
// output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
-XlaOp Broadcast(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+XlaOp Broadcast(const XlaOp& operand, absl::Span<const int64> broadcast_sizes);
// Performs in-dimension-style broadcast.
//
@@ -1499,9 +1481,8 @@ XlaOp Broadcast(const XlaOp& operand,
// will generate output
// [1 , 1]
// [2 , 2]
-XlaOp BroadcastInDim(
- const XlaOp& operand, const Shape& shape,
- const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape,
+ const absl::Span<const int64> broadcast_dimensions);
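To make the in-dimension broadcast concrete, a minimal builder sketch (the builder name, shape, and values are illustrative, not part of this change):

XlaBuilder b("broadcast_in_dim");
// f32[2] operand {1, 2}; mapping operand dim 0 to output dim 0 repeats each
// value along output dim 1, giving {{1, 1}, {2, 2}} as in the comment above.
auto x = ConstantR1<float>(&b, {1.0f, 2.0f});
BroadcastInDim(x, ShapeUtil::MakeShape(F32, {2, 2}),
               /*broadcast_dimensions=*/{0});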
// Enqueues a pad operation onto the computation that pads the given value on
// the edges as well as between the elements of the input. padding_config
@@ -1514,15 +1495,13 @@ XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
// given, followed by reshaping it into the shape with the given dimension
// sizes (also major to minor). Conceptually, this is a limited form of
// "shape casting".
-XlaOp Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
+XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
+ absl::Span<const int64> new_sizes);
// Enqueues an operation onto the computation that collapses the operand, from
// first to last dimension (C order), then reshapes it to the given dimension
// sizes. Conceptually, this is a limited form of "shape casting".
-XlaOp Reshape(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
+XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);
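A short sketch of the two overloads above (builder, shapes, and values assumed for illustration):

auto m = ConstantR2<float>(&b, {{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}});  // f32[2,3]
// Reorder dimensions {1,0} first (a transpose), then reshape to f32[3,2].
Reshape(m, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2});
// Collapse in C order and reshape directly to f32[6].
Reshape(m, /*new_sizes=*/{6});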
// Wrapper for Reshape.
// Enqueues an operation to collapse the provided dimensions; e.g. an
@@ -1542,8 +1521,7 @@ XlaOp Reshape(const XlaOp& operand,
//
// This could potentially cause data to be moved -- it provides a more
// structured form of reshaping than an arbitrary Reshape operation.
-XlaOp Collapse(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+XlaOp Collapse(const XlaOp& operand, absl::Span<const int64> dimensions);
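For instance (hypothetical shapes), collapsing the two trailing dimensions of a rank-3 array:

// f32[2,3,4] -> f32[2,12]; the collapsed dimensions {1,2} must be
// consecutive in the operand's layout.
auto t = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3, 4}), "t");
Collapse(t, /*dimensions=*/{1, 2});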
// Enqueues a slice operation onto the computation that slices the operand
// from the start indices to the limit indices; e.g.
@@ -1556,10 +1534,9 @@ XlaOp Collapse(const XlaOp& operand,
// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
// range notation.
// The strides parameter determines the stride over the slice.
-XlaOp Slice(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
+XlaOp Slice(const XlaOp& operand, absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides);
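A small example of the [start, limit) semantics (shapes assumed):

// From f32[3,4]: rows [0,2), columns [1,3), unit strides -> f32[2,2].
auto a = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3, 4}), "a");
Slice(a, /*start_indices=*/{0, 1}, /*limit_indices=*/{2, 3},
      /*strides=*/{1, 1});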
// Enqueues a slice operation in a given dimension, taking all other
// dimensions as they are; e.g. if dimno is 1 from start_index 2 to
@@ -1580,7 +1557,7 @@ XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
// Slice index calculations are computed modulo input dimension sizes to
// prevent dynamic start indices from generating out-of-bound array accesses.
XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
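A sketch with assumed shapes; start_indices is itself an XlaOp, so the window origin can be computed at run time:

auto input = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 4}), "input");
auto starts = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {2}), "starts");
DynamicSlice(input, starts, /*slice_sizes=*/{2, 2});  // f32[2,2] window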
// Enqueues a dynamic update slice operation onto the computation, which
// updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
@@ -1603,8 +1580,8 @@ XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
// Enqueues a concatenate instruction onto the computation. 'operands' must
// have >= 1 entry.
-XlaOp ConcatInDim(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> operands, int64 dimension);
+XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
+ int64 dimension);
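For example (with `x` and `y` assumed to be previously built f32[2,3] values):

// Concatenate along dimension 0: f32[2,3] ++ f32[2,3] -> f32[4,3].
ConcatInDim(&b, {x, y}, /*dimension=*/0);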
// Enqueues a tracing operation onto the computation; the computation will emit
// a logging message with the operand.
@@ -1615,87 +1592,91 @@ void Trace(const string& tag, const XlaOp& operand);
XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);
// Enqueues a tuple-creation instruction onto the computation.
-XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> elements);
+XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
// Enqueues a tuple-element-get instruction onto the computation.
XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
// Enqueues an equal-to comparison instruction onto the computation.
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a not-equal comparison instruction onto the computation.
XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a greater-or-equal comparison instruction onto the computation.
XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a greater-than comparison instruction onto the computation.
XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a less-than comparison instruction onto the computation.
XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a less-or-equal comparison instruction onto the computation.
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a dot instruction onto the computation.
-XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs);
+XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a general dot instruction onto the computation.
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
- const DotDimensionNumbers& dimension_numbers);
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- int64 feature_group_count = 1);
+ absl::Span<const int64> window_strides, Padding padding,
+ int64 feature_group_count = 1,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
-XlaOp ConvWithGeneralPadding(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- int64 feature_group_count = 1);
+XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ int64 feature_group_count = 1,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
XlaOp ConvWithGeneralDimensions(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1);
+ const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
+ Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1);
+ int64 feature_group_count = 1,
+ const PrecisionConfig* precision_config = nullptr);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
-XlaOp ConvGeneralDilated(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1);
+XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation,
+ absl::Span<const int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1,
+ const PrecisionConfig* precision_config = nullptr);
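The dot and convolution builders above now take an optional PrecisionConfig. A hedged call-site sketch; the proto accessors below are assumptions inferred from the parameter type, not shown in this diff:

PrecisionConfig config;
config.add_operand_precision(PrecisionConfig::HIGHEST);  // lhs
config.add_operand_precision(PrecisionConfig::HIGHEST);  // rhs
Dot(lhs, rhs, &config);  // omitting the argument (nullptr) keeps the default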
// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
XlaOp Fft(const XlaOp& operand, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length);
+ absl::Span<const int64> fft_length);
// Enqueues an infeed instruction onto the computation, which writes data of
// the given shape to the infeed buffer of the device.
@@ -1727,26 +1708,14 @@ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
// Enqueues a call instruction onto the computation.
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<XlaOp> operands);
+ absl::Span<const XlaOp> operands);
// Enqueues a custom call instruction onto the computation.
// During code generation, a call instruction is emitted which targets a
// symbol with the name |call_target_name|. The |operands| are passed to the
// call instruction. |shape| is the resultant shape.
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- const Shape& shape);
-
-// Enqueues a pseudo-op to represent host-side computation data-dependencies.
-// During code generation, host send and receive operations will be generated
-// to transfer |operands| to the host and a single result of |shape| back to
-// the device. Host send/recv operations are emitted using |channel_name|.
-// Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO
-// instruction scheduling.
-XlaOp HostCompute(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- const string& channel_name, int64 cost_estimate_ns,
- const Shape& shape);
+ absl::Span<const XlaOp> operands, const Shape& shape);
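A minimal sketch (the symbol name and operand are illustrative):

// Emits a call targeting the symbol "my_custom_fn"; the result takes the
// declared shape f32[2].
CustomCall(&b, "my_custom_fn", /*operands=*/{x},
           ShapeUtil::MakeShape(F32, {2}));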
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
@@ -1755,65 +1724,70 @@ XlaOp HostCompute(XlaBuilder* builder,
// Enqueues a complex compose instruction onto the computation.
XlaOp Complex(const XlaOp& real, const XlaOp& imag,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a complex conjugate instruction onto the computation.
XlaOp Conj(const XlaOp& operand);
// Enqueues an add instruction onto the computation.
XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a subtract instruction onto the computation.
XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a multiply instruction onto the computation.
XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a divide instruction onto the computation.
XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a remainder instruction onto the computation.
XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a max instruction onto the computation.
XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a min instruction onto the computation.
XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Element-wise logical operators
XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
XlaOp Not(const XlaOp& operand);
XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-XlaOp ShiftRightArithmetic(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-XlaOp ShiftRightLogical(
- const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
+XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> broadcast_dimensions = {});
+XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> broadcast_dimensions = {});
// Reduces an array among the provided dimensions, given "computation" as a
// reduction operator.
XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+ absl::Span<const int64> dimensions_to_reduce);
+
+// Reduces several arrays simultaneously among the provided dimensions, given
+// "computation" as a reduction operator.
+XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
+ absl::Span<const XlaOp> init_values,
+ const XlaComputation& computation,
+ absl::Span<const int64> dimensions_to_reduce);
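A sketch of the new variadic overload; `arr_a`, `arr_b`, `init_a`, `init_b`, and `combine` are assumed, where `combine` is a user-built XlaComputation taking (acc_a, acc_b, elem_a, elem_b) and returning a two-element tuple:

// Reduce two equally-shaped arrays in lock-step along dimension 0.
Reduce(&b, /*operands=*/{arr_a, arr_b},
       /*init_values=*/{init_a, init_b}, combine,
       /*dimensions_to_reduce=*/{0});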
// Convenience wrapper around the above that reduces all the dimensions in the
// operand shape.
@@ -1823,25 +1797,23 @@ XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
// Enqueues a windowed reduce instruction onto the computation.
XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding);
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides, Padding padding);
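For instance, 2x2 max pooling with stride 2 (the `max_comp` computation and `neg_inf` init value are assumed to be built elsewhere):

ReduceWindow(input, neg_inf, max_comp,
             /*window_dimensions=*/{2, 2},
             /*window_strides=*/{2, 2}, Padding::kValid);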
// As ReduceWindow(), but the padding is given in the format
// returned by MakePadding().
XlaOp ReduceWindowWithGeneralPadding(
const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding);
// Returns the sum of the operand value within each subgroup of replicas. All
// replicas supply one input to the sum and all replicas receive the resulting
// sum for each subgroup.
-XlaOp CrossReplicaSum(
- const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids = {});
+XlaOp CrossReplicaSum(const XlaOp& operand,
+ absl::Span<const ReplicaGroup> replica_groups = {});
// Enqueues an operation that does an AllReduce of the operand across cores. Here
// AllReduce means doing a reduction on the input operand across cores and then
@@ -1850,52 +1822,61 @@ XlaOp CrossReplicaSum(
// scalars, e.g., add, min, or max. The way that AllReduce is applied is
// configured by:
//
-// - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all
-// replicas belong to one group. Allreduce will be applied within subgroups.
-// For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means,
-// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
+// - `replica_groups`: each ReplicaGroup contains a list of replica ids. If
+// empty, all replicas belong to one group. Allreduce will be applied within
+// subgroups. For example, with 4 replicas, replica_groups={{0,2},{1,3}} means
+// replicas 0 and 2 are in subgroup 0, and replicas 1 and 3 are in subgroup 1.
//
// - `channel_id`: for Allreduce nodes from different modules, if they have the
// same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be
// applied across modules.
//
// TODO(b/79737069): Rename this to AllReduce when it's ready to use.
-XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids = {},
- const tensorflow::gtl::optional<ChannelHandle>&
- channel_id = tensorflow::gtl::nullopt);
+XlaOp CrossReplicaSum(
+ const XlaOp& operand, const XlaComputation& computation,
+ absl::Span<const ReplicaGroup> replica_groups = {},
+ const absl::optional<ChannelHandle>& channel_id = absl::nullopt);
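Following the replica_groups semantics above, a minimal call-site sketch (`operand` and the `add` XlaComputation are assumed to be built elsewhere):

// Four replicas: sum within {0,2} and within {1,3}.
ReplicaGroup g0, g1;
g0.add_replica_ids(0);
g0.add_replica_ids(2);
g1.add_replica_ids(1);
g1.add_replica_ids(3);
CrossReplicaSum(operand, add, /*replica_groups=*/{g0, g1});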
// Enqueues an operation that does an AllToAll of the operand across cores.
-//
-// TODO(b/110096724): This is NOT YET ready to use.
XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
int64 concat_dimension, int64 split_count,
const std::vector<ReplicaGroup>& replica_groups = {});
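A hedged example, assuming four replicas and an operand whose dimension 1 is divisible by split_count:

// Split x along dimension 1 into 4 pieces, exchange them across replicas,
// and concatenate what each replica receives along dimension 0.
AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4);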
+// Enqueues a collective operation that sends and receives data across
+// replicas.
+//
+// - `source_target_pairs`: a list of (source_replica_id, target_replica_id)
+// pairs. For each pair, the operand is sent from the source replica to the
+// target replica. Note that 1) no two pairs may have the same target replica
+// id, and no two pairs may have the same source replica id; 2) if a replica id
+// is not a target in any pair, then the output on that replica is a tensor
+// consisting of 0(s) with the same shape as the input.
+XlaOp CollectivePermute(
+ const XlaOp& operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs);
+
// Enqueues an operation that scatters the `source` array to the selected
// indices of each window.
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding, const XlaOp& source,
- const XlaOp& init_value, const XlaComputation& scatter);
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides, Padding padding,
+ const XlaOp& source, const XlaOp& init_value,
+ const XlaComputation& scatter);
// As SelectAndScatter(), but the padding is given in the format
// returned by MakePadding().
XlaOp SelectAndScatterWithGeneralPadding(
const XlaOp& operand, const XlaComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const XlaOp& source, const XlaOp& init_value,
- const XlaComputation& scatter);
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
+ const XlaOp& init_value, const XlaComputation& scatter);
// Enqueues an abs instruction onto the computation.
XlaOp Abs(const XlaOp& operand);
// Enqueues an atan2 instruction onto the computation.
XlaOp Atan2(const XlaOp& y, const XlaOp& x,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues an exp instruction onto the computation.
XlaOp Exp(const XlaOp& operand);
@@ -1942,7 +1923,7 @@ XlaOp Imag(const XlaOp& operand);
// Enqueues a lhs^rhs computation onto the computation.
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ absl::Span<const int64> broadcast_dimensions = {});
// Enqueues an operator that tests if the operand's values are finite, i.e.,
// not Inf or NaN. Defined only for floating-point types. Returns an array of
@@ -1950,6 +1931,12 @@ XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
// entry was NaN.
XlaOp IsFinite(const XlaOp& operand);
+// Enqueues an iota operation onto the computation.
+XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension);
+
+// Enqueues a rank-1 iota operation onto the computation.
+XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
+
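Both overloads in one sketch:

Iota(&b, S32, /*size=*/10);  // rank 1: s32[10] = {0, 1, ..., 9}
// f32[3,4] whose values run along dimension 1, so every row is {0, 1, 2, 3}.
Iota(&b, ShapeUtil::MakeShape(F32, {3, 4}), /*iota_dimension=*/1);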
// Enqueues a convert instruction onto the computation that changes the
// element type of the operand array to primitive_type.
XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type);
@@ -1964,13 +1951,12 @@ XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type);
XlaOp Neg(const XlaOp& operand);
// Enqueues a transpose instruction onto the computation.
-XlaOp Transpose(const XlaOp& operand,
- tensorflow::gtl::ArraySlice<int64> permutation);
+XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation);
// Enqueues a reverse instruction onto the computation. The order of the
// elements in the given dimensions is reversed (i.e., the element at index i
// is moved to index dimension_size - 1 - i).
-XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions);
+XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);
// Enqueues a sort (in increasing order) instruction onto the computation.
// If only keys are provided:
@@ -1988,18 +1974,16 @@ XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions);
// * The result is a tuple that consists of a sorted tensor of keys (along the
// provided dimension, as above) as the first element, and a tensor with their
// corresponding values as the second element.
-XlaOp Sort(XlaOp keys,
- tensorflow::gtl::optional<XlaOp> values = tensorflow::gtl::nullopt,
+XlaOp Sort(XlaOp keys, absl::optional<XlaOp> values = absl::nullopt,
int64 dimension = -1);
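Both modes in a short sketch (`keys` and `values` are assumed, equally-shaped ops):

Sort(keys);                           // keys only, along the last dimension
Sort(keys, values, /*dimension=*/0);  // (sorted keys, co-sorted values) tuple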
// Enqueues a clamp instruction onto the computation.
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
// Enqueues a map instruction onto the computation.
-XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> operands,
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<XlaOp> static_operands = {});
+XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
+ const XlaComputation& computation, absl::Span<const int64> dimensions,
+ absl::Span<const XlaOp> static_operands = {});
// Enqueues an N(mu, sigma) random number generation instruction onto the
// computation.
@@ -2024,9 +2008,9 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
const int mantissa_bits);
// Enqueues a Gather node onto the computation.
-XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ absl::Span<const int64> slice_sizes);
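With the renamed arguments, a hedged row-gather sketch; the GatherDimensionNumbers field names are assumptions taken from the corresponding proto, which this diff does not show:

// table: f32[8,4]; start_indices: s32[N,1] of row numbers (both assumed).
GatherDimensionNumbers dnums;
dnums.add_offset_dims(1);           // output dim 1 carries the row contents
dnums.add_collapsed_slice_dims(0);  // the length-1 row dimension is dropped
dnums.add_start_index_map(0);
dnums.set_index_vector_dim(1);
Gather(table, start_indices, dnums, /*slice_sizes=*/{1, 4});  // -> f32[N,4]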
// Enqueues a Scatter node onto the computation.
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
@@ -2084,7 +2068,7 @@ XlaOp CreateToken(XlaBuilder* builder);
// Enqueues an AfterAll instruction which produces a token-shaped value and
// takes a variadic number of token-shaped operands. The number of operands must
// be greater than zero. Used for joining tokens.
-XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> tokens);
+XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);
// Normalizes operand across spatial and batch dimensions for each feature.
//
@@ -2128,12 +2112,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
template <typename NativeT>
XlaOp XlaBuilder::ConstantR0(NativeT value) {
- return ConstantLiteral(*LiteralUtil::CreateR0<NativeT>(value));
+ return ConstantLiteral(LiteralUtil::CreateR0<NativeT>(value));
}
template <typename NativeT>
-XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values) {
- return ConstantLiteral(*LiteralUtil::CreateR1<NativeT>(values));
+XlaOp XlaBuilder::ConstantR1(absl::Span<const NativeT> values) {
+ return ConstantLiteral(LiteralUtil::CreateR1<NativeT>(values));
}
template <typename NativeT>
@@ -2145,44 +2129,44 @@ XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) {
}
inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) {
- return ConstantLiteral(*LiteralUtil::CreateR1(values));
+ return ConstantLiteral(LiteralUtil::CreateR1(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantLiteral(*LiteralUtil::CreateR2<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateR2<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantFromArray(const Array<NativeT>& values) {
- return ConstantLiteral(*LiteralUtil::CreateFromArray<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateFromArray<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout) {
return ConstantLiteral(
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D<NativeT>& values) {
- return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout(
const Array3D<NativeT>& values, const Layout& layout) {
return ConstantLiteral(
- *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
@@ -2205,13 +2189,12 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) {
template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR0<NativeT>(value));
+ return ConstantLiteral(builder, LiteralUtil::CreateR0<NativeT>(value));
}
template <typename NativeT>
-XlaOp ConstantR1(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<NativeT> values) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR1<NativeT>(values));
+XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
+ return ConstantLiteral(builder, LiteralUtil::CreateR1<NativeT>(values));
}
template <typename NativeT>
@@ -2224,13 +2207,13 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
inline XlaOp ConstantR1(XlaBuilder* builder,
const tensorflow::core::Bitmap& values) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR1(values));
+ return ConstantLiteral(builder, LiteralUtil::CreateR1(values));
}
template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR2<NativeT>(values));
+ return ConstantLiteral(builder, LiteralUtil::CreateR2<NativeT>(values));
}
template <typename NativeT>
@@ -2238,14 +2221,13 @@ XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
const Array<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
- builder,
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
return ConstantLiteral(builder,
- *LiteralUtil::CreateFromArray<NativeT>(values));
+ LiteralUtil::CreateFromArray<NativeT>(values));
}
template <typename NativeT>
@@ -2253,15 +2235,14 @@ XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
const Array2D<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
- builder,
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
const Array2D<NativeT>& values) {
return ConstantLiteral(builder,
- *LiteralUtil::CreateR2FromArray2D<NativeT>(values));
+ LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}
template <typename NativeT>
@@ -2270,7 +2251,7 @@ XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
const Layout& layout) {
return ConstantLiteral(
builder,
- *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc
index 49a15ec3b4..7c37ed00cd 100644
--- a/tensorflow/compiler/xla/client/xla_builder_test.cc
+++ b/tensorflow/compiler/xla/client/xla_builder_test.cc
@@ -320,6 +320,15 @@ TEST_F(XlaBuilderTest, AllToAll) {
ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8})));
}
+TEST_F(XlaBuilderTest, CollectivePermute) {
+ XlaBuilder b(TestName());
+ auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
+ CollectivePermute(x, {{0, 1}, {1, 2}, {2, 3}});
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+ auto root = module->entry_computation()->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kCollectivePermute);
+}
+
TEST_F(XlaBuilderTest, ReportError) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
diff --git a/tensorflow/compiler/xla/client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_computation.cc
index 3543d41fc2..22c9e83bb2 100644
--- a/tensorflow/compiler/xla/client/xla_computation.cc
+++ b/tensorflow/compiler/xla/client/xla_computation.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <utility>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
@@ -32,7 +32,7 @@ StatusOr<std::unique_ptr<HloSnapshot>> XlaComputation::Snapshot() const {
if (IsNull()) {
return InvalidArgument("Computation is invalid.");
}
- auto session = MakeUnique<HloSnapshot>();
+ auto session = absl::make_unique<HloSnapshot>();
*session->mutable_hlo()->mutable_hlo_module() = proto_;
return std::move(session);
}
diff --git a/tensorflow/compiler/xla/device_util.h b/tensorflow/compiler/xla/device_util.h
index 1a51fdee68..6d51126d88 100644
--- a/tensorflow/compiler/xla/device_util.h
+++ b/tensorflow/compiler/xla/device_util.h
@@ -21,8 +21,8 @@ limitations under the License.
#include <string>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
@@ -30,8 +30,8 @@ namespace xla {
// Returns a string that represents the device in terms of platform and ordinal;
// e.g. the first CUDA device will be "cuda:0"
string DeviceIdentifier(se::StreamExecutor* stream_exec) {
- return tensorflow::strings::StrCat(stream_exec->platform()->Name(), ":",
- stream_exec->device_ordinal());
+ return absl::StrCat(stream_exec->platform()->Name(), ":",
+ stream_exec->device_ordinal());
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc
index ffd1fb79e9..3fadabcf52 100644
--- a/tensorflow/compiler/xla/index_util.cc
+++ b/tensorflow/compiler/xla/index_util.cc
@@ -18,16 +18,16 @@ limitations under the License.
#include <algorithm>
#include <string>
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
/* static */ int64 IndexUtil::MultidimensionalIndexToLinearIndex(
- const Shape& shape, tensorflow::gtl::ArraySlice<int64> multi_index) {
+ const Shape& shape, absl::Span<const int64> multi_index) {
DCHECK_EQ(shape.dimensions_size(), multi_index.size());
// Padding and nested layouts not supported yet.
DCHECK_EQ(0, shape.layout().padded_dimensions_size());
@@ -36,7 +36,7 @@ namespace xla {
DCHECK_GE(multi_index[i], 0);
DCHECK_LT(multi_index[i], shape.dimensions(i))
<< "indexing beyond extent in dimension " << i << ":"
- << "\n\tindex: " << tensorflow::str_util::Join(multi_index, ",")
+ << "\n\tindex: " << absl::StrJoin(multi_index, ",")
<< "\n\tshape: " << ShapeUtil::HumanString(shape);
}
@@ -118,8 +118,8 @@ namespace xla {
return multi_index;
}
-/* static */ bool IndexUtil::BumpIndices(
- const Shape& shape, tensorflow::gtl::MutableArraySlice<int64> indices) {
+/* static */ bool IndexUtil::BumpIndices(const Shape& shape,
+ absl::Span<int64> indices) {
for (int64 dimno = indices.size() - 1; dimno >= 0; --dimno) {
int64 limit = shape.dimensions(dimno);
if (indices[dimno] + 1 < limit) {
@@ -149,8 +149,8 @@ namespace xla {
return stride;
}
-/* static */ bool IndexUtil::IndexInBounds(
- const Shape& shape, tensorflow::gtl::ArraySlice<int64> index) {
+/* static */ bool IndexUtil::IndexInBounds(const Shape& shape,
+ absl::Span<const int64> index) {
int64 rank = ShapeUtil::Rank(shape);
if (rank != index.size()) {
return false;
@@ -163,9 +163,8 @@ namespace xla {
return true;
}
-/* static */ int IndexUtil::CompareIndices(
- tensorflow::gtl::ArraySlice<int64> lhs,
- tensorflow::gtl::ArraySlice<int64> rhs) {
+/* static */ int IndexUtil::CompareIndices(absl::Span<const int64> lhs,
+ absl::Span<const int64> rhs) {
int64 rank = lhs.size();
CHECK_EQ(rhs.size(), rank);
for (int64 dim = 0; dim < rank; ++dim) {
diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h
index 142006f262..2979cf87dd 100644
--- a/tensorflow/compiler/xla/index_util.h
+++ b/tensorflow/compiler/xla/index_util.h
@@ -20,9 +20,9 @@ limitations under the License.
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
@@ -35,7 +35,7 @@ class IndexUtil {
// on the shape and its layout. The first index in the multi_index is
// dimension 0.
static int64 MultidimensionalIndexToLinearIndex(
- const Shape& shape, tensorflow::gtl::ArraySlice<int64> multi_index);
+ const Shape& shape, absl::Span<const int64> multi_index);
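A worked instance of the mapping (default row-major layout assumed):

// For s32[3,4] with minor_to_major {1,0} (row-major), the multi-index {1, 2}
// maps to linear index 1*4 + 2 = 6.
IndexUtil::MultidimensionalIndexToLinearIndex(
    ShapeUtil::MakeShape(S32, {3, 4}), {1, 2});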
// Converts a linear index into multidimensional index (eg {x, y, z}) based on
// the shape and its layout. The first index in the returned multidimensional
@@ -58,8 +58,7 @@ class IndexUtil {
//
// Returns true iff the indices were successfully bumped; false if we've hit
// the limit where it can no longer be bumped in-bounds.
- static bool BumpIndices(const Shape& shape,
- tensorflow::gtl::MutableArraySlice<int64> indices);
+ static bool BumpIndices(const Shape& shape, absl::Span<int64> indices);
// Calculates the stride size (in number of elements, not byte size) of a
// given logical shape dimension (from 0 to rank-1). If available, padded
@@ -71,15 +70,14 @@ class IndexUtil {
// Returns true iff the given multi-index is contained in the bounds for the
// shape.
- static bool IndexInBounds(const Shape& shape,
- tensorflow::gtl::ArraySlice<int64> index);
+ static bool IndexInBounds(const Shape& shape, absl::Span<const int64> index);
// Compares the given indices in lexicographic order. lhs[0] and rhs[0] are
// compared first, and lhs[rank-1] and rhs[rank-1] last. If lhs is larger,
// then -1 is returned. If rhs is larger, then 1 is returned. Otherwise, 0 is
// returned.
- static int CompareIndices(tensorflow::gtl::ArraySlice<int64> lhs,
- tensorflow::gtl::ArraySlice<int64> rhs);
+ static int CompareIndices(absl::Span<const int64> lhs,
+ absl::Span<const int64> rhs);
private:
TF_DISALLOW_COPY_AND_ASSIGN(IndexUtil);
diff --git a/tensorflow/compiler/xla/index_util_test.cc b/tensorflow/compiler/xla/index_util_test.cc
index 7c4efdee48..93522d2ca8 100644
--- a/tensorflow/compiler/xla/index_util_test.cc
+++ b/tensorflow/compiler/xla/index_util_test.cc
@@ -142,13 +142,13 @@ TEST(IndexUtilTest, LinearToMultiToLinear) {
TEST(IndexUtilTest, BumpIndices2x2) {
auto shape = ShapeUtil::MakeShape(S32, {2, 2});
std::vector<int64> indices = {0, 0};
- EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices));
+ EXPECT_TRUE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices)));
EXPECT_THAT(indices, ::testing::ElementsAre(0, 1));
- EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices));
+ EXPECT_TRUE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices)));
EXPECT_THAT(indices, ::testing::ElementsAre(1, 0));
- EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices));
+ EXPECT_TRUE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices)));
EXPECT_THAT(indices, ::testing::ElementsAre(1, 1));
- EXPECT_FALSE(IndexUtil::BumpIndices(shape, &indices));
+ EXPECT_FALSE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices)));
}
} // namespace
diff --git a/tensorflow/compiler/xla/iterator_util.h b/tensorflow/compiler/xla/iterator_util.h
index a8bb8c7a7e..3a3ee21e76 100644
--- a/tensorflow/compiler/xla/iterator_util.h
+++ b/tensorflow/compiler/xla/iterator_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_
+#ifndef TENSORFLOW_COMPILER_XLA_ITERATOR_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_ITERATOR_UTIL_H_
#include <iterator>
#include <utility>
@@ -95,4 +95,4 @@ UnwrappingIterator<NestedIter> MakeUnwrappingIterator(NestedIter iter) {
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_
+#endif // TENSORFLOW_COMPILER_XLA_ITERATOR_UTIL_H_
diff --git a/tensorflow/compiler/xla/iterator_util_test.cc b/tensorflow/compiler/xla/iterator_util_test.cc
index 7bc3189507..ec8b66df2d 100644
--- a/tensorflow/compiler/xla/iterator_util_test.cc
+++ b/tensorflow/compiler/xla/iterator_util_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <algorithm>
#include <list>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/test.h"
namespace xla {
@@ -27,7 +27,7 @@ namespace {
TEST(UnwrappingIteratorTest, Simple) {
std::vector<std::unique_ptr<int>> v;
for (int i = 0; i < 3; ++i) {
- v.push_back(MakeUnique<int>(i));
+ v.push_back(absl::make_unique<int>(i));
}
int i = 0;
for (auto iter = MakeUnwrappingIterator(v.begin());
@@ -51,7 +51,7 @@ TEST(UnwrappingIteratorTest, PostincrementOperator) {
TEST(UnwrappingIteratorTest, StdFind) {
std::list<std::unique_ptr<int>> l;
for (int i = 0; i < 3; ++i) {
- l.push_back(MakeUnique<int>(i));
+ l.push_back(absl::make_unique<int>(i));
}
EXPECT_EQ(l.begin()->get(),
*std::find(MakeUnwrappingIterator(l.begin()),
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
index b72d190d54..d310335618 100644
--- a/tensorflow/compiler/xla/layout_util.cc
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -23,6 +23,8 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -31,8 +33,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -56,7 +56,7 @@ void SetDefaultLayoutToContainer(
} // namespace
/* static */ Layout LayoutUtil::MakeLayout(
- tensorflow::gtl::ArraySlice<int64> minor_to_major) {
+ absl::Span<const int64> minor_to_major) {
Layout layout;
layout.set_format(DENSE);
for (int64 dimension_number : minor_to_major) {
@@ -66,7 +66,7 @@ void SetDefaultLayoutToContainer(
}
/* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor(
- tensorflow::gtl::ArraySlice<int64> major_to_minor) {
+ absl::Span<const int64> major_to_minor) {
Layout layout;
layout.set_format(DENSE);
for (int i = major_to_minor.size() - 1; i >= 0; i--) {
@@ -169,7 +169,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
} else if (ShapeUtil::IsArray(shape)) {
if (!shape.has_layout()) {
return InvalidArgument("shape %s does not have a layout",
- ShapeUtil::HumanString(shape).c_str());
+ ShapeUtil::HumanString(shape));
}
return ValidateLayoutForShape(shape.layout(), shape);
} else {
@@ -177,7 +177,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
if (shape.has_layout()) {
return InvalidArgument(
"shape of primitive type %s should not have a layout",
- PrimitiveType_Name(shape.element_type()).c_str());
+ PrimitiveType_Name(shape.element_type()));
}
return Status::OK();
}
@@ -194,7 +194,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
layout.padded_dimensions_size() != 0) {
return InvalidArgument(
"shape of primitive type %s should not have a non-trivial layout",
- PrimitiveType_Name(shape.element_type()).c_str());
+ PrimitiveType_Name(shape.element_type()));
}
return Status::OK();
}
@@ -202,17 +202,17 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
if (layout.format() == INVALID_FORMAT) {
return InvalidArgument(
"Layout does not have a valid format: layout {%s}, shape {%s}",
- layout.ShortDebugString().c_str(), shape.ShortDebugString().c_str());
+ layout.ShortDebugString(), shape.ShortDebugString());
}
if (layout.format() == DENSE) {
if (layout.minor_to_major_size() != ShapeUtil::Rank(shape)) {
return InvalidArgument(
"layout minor_to_major field contains %d elements, "
- "but shape is rank %lld: {%s}; shape: %s",
+ "but shape is rank %d: {%s}; shape: %s",
layout.minor_to_major_size(), ShapeUtil::Rank(shape),
- tensorflow::str_util::Join(layout.minor_to_major(), ", ").c_str(),
- shape.ShortDebugString().c_str());
+ absl::StrJoin(layout.minor_to_major(), ", "),
+ shape.ShortDebugString());
}
std::vector<bool> dimensions_in_layout(ShapeUtil::Rank(shape), false);
@@ -221,12 +221,12 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
if (dim < 0 || dim >= ShapeUtil::Rank(shape)) {
return InvalidArgument(
"layout minor_to_major field has out-of-bounds value: %s",
- HumanString(layout).c_str());
+ HumanString(layout));
}
if (dimensions_in_layout[dim]) {
return InvalidArgument(
"layout minor_to_major field has duplicate values: {%s}",
- HumanString(layout).c_str());
+ HumanString(layout));
}
dimensions_in_layout[dim] = true;
}
@@ -234,14 +234,14 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
if (layout.padded_dimensions_size() > 0) {
if (layout.padded_dimensions_size() != ShapeUtil::Rank(shape)) {
return InvalidArgument(
- "layout has %d padded dimensions, but shape is rank %lld",
+ "layout has %d padded dimensions, but shape is rank %d",
layout.padded_dimensions_size(), ShapeUtil::Rank(shape));
}
for (int i = 0; i < layout.padded_dimensions_size(); ++i) {
if (layout.padded_dimensions(i) < shape.dimensions(i)) {
return InvalidArgument(
- "for dimension %d, dimension padding (%lld) is smaller than "
- "the dimension size (%lld) of the shape",
+ "for dimension %d, dimension padding (%d) is smaller than "
+ "the dimension size (%d) of the shape",
i, layout.padded_dimensions(i), shape.dimensions(i));
}
}
@@ -307,7 +307,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
return false;
}
-/* static */ tensorflow::gtl::ArraySlice<int64> LayoutUtil::PaddedDimensions(
+/* static */ absl::Span<const int64> LayoutUtil::PaddedDimensions(
const Shape& shape) {
CHECK(IsDenseArray(shape));
return AsInt64Slice(shape.layout().padded_dimensions());
@@ -363,13 +363,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
return protobuf_util::ProtobufEquals(lhs, rhs);
}
-/* static */ tensorflow::gtl::ArraySlice<int64> LayoutUtil::MinorToMajor(
+/* static */ absl::Span<const int64> LayoutUtil::MinorToMajor(
const Shape& shape) {
CHECK(IsDenseArray(shape));
return AsInt64Slice(shape.layout().minor_to_major());
}
-/* static */ tensorflow::gtl::ArraySlice<int64> LayoutUtil::MinorToMajor(
+/* static */ absl::Span<const int64> LayoutUtil::MinorToMajor(
const Layout& layout) {
CHECK(layout.format() == DENSE);
return AsInt64Slice(layout.minor_to_major());
@@ -403,12 +403,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
/* static */ string LayoutUtil::HumanString(const Layout& layout) {
if (IsSparse(layout)) {
- return tensorflow::strings::StrCat("sparse{", layout.max_sparse_elements(),
- "}");
+ return absl::StrCat("sparse{", layout.max_sparse_elements(), "}");
}
CHECK(IsDense(layout));
- return tensorflow::strings::StrCat(
- "{", tensorflow::str_util::Join(layout.minor_to_major(), ","), "}");
+ return absl::StrCat("{", absl::StrJoin(layout.minor_to_major(), ","), "}");
}
namespace {
@@ -474,7 +472,7 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
}
/* static */ bool LayoutUtil::AreDimensionsConsecutive(
- const Layout& layout, tensorflow::gtl::ArraySlice<int64> dims) {
+ const Layout& layout, absl::Span<const int64> dims) {
CHECK(IsDense(layout));
std::vector<int64> positions_in_layout;
for (int64 dim : dims) {
diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h
index 739bbe7367..b78883c2d8 100644
--- a/tensorflow/compiler/xla/layout_util.h
+++ b/tensorflow/compiler/xla/layout_util.h
@@ -20,10 +20,10 @@ limitations under the License.
#include <string>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -34,11 +34,11 @@ class LayoutUtil {
public:
// Creates a layout with the given minor-to-major dimension order. (This is a
// convenience function for protobuf construction.)
- static Layout MakeLayout(tensorflow::gtl::ArraySlice<int64> minor_to_major);
+ static Layout MakeLayout(absl::Span<const int64> minor_to_major);
// Similar to MakeLayout, but take indices in reverse order.
static Layout MakeLayoutFromMajorToMinor(
- tensorflow::gtl::ArraySlice<int64> major_to_minor);
+ absl::Span<const int64> major_to_minor);
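The two constructors describe the same layout from opposite ends; a small sketch:

// Row-major for rank 2: minor-to-major is {1,0}, major-to-minor is {0,1}.
Layout a = LayoutUtil::MakeLayout({1, 0});
Layout b = LayoutUtil::MakeLayoutFromMajorToMinor({0, 1});
// LayoutUtil::Equal(a, b) holds: both protos have minor_to_major = [1, 0].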
// Creates a sparse layout with the given maximum number of elements. (This is
// a convenience function for protobuf construction.)
@@ -104,8 +104,7 @@ class LayoutUtil {
// Returns the padded_dimensions array for the given Shape. Requires that the
// shape is an array and has a dense layout.
- static tensorflow::gtl::ArraySlice<int64> PaddedDimensions(
- const Shape& shape);
+ static absl::Span<const int64> PaddedDimensions(const Shape& shape);
// Returns the given index of the padded_dimensions array for the given Shape.
// Requires that the shape is an array and has a dense layout.
@@ -138,8 +137,8 @@ class LayoutUtil {
// Returns the minor_to_major array for the given Shape. Requires that the
// shape is an array and has a dense layout.
- static tensorflow::gtl::ArraySlice<int64> MinorToMajor(const Shape& shape);
- static tensorflow::gtl::ArraySlice<int64> MinorToMajor(const Layout& layout);
+ static absl::Span<const int64> MinorToMajor(const Shape& shape);
+ static absl::Span<const int64> MinorToMajor(const Layout& layout);
// Major(0) is the most major logical dimension number, Major(1) is the
// second-most-major logical dimension number and so on.
@@ -196,7 +195,7 @@ class LayoutUtil {
// Returns whether the given dimensions are consecutive in the given layout,
// not necessarily in the order given.
static bool AreDimensionsConsecutive(const Layout& layout,
- tensorflow::gtl::ArraySlice<int64> dims);
+ absl::Span<const int64> dims);
// Compute a hash for `layout`.
static size_t Hash(const Layout& layout);
diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc
index e4c825450d..f25dae6ff4 100644
--- a/tensorflow/compiler/xla/layout_util_test.cc
+++ b/tensorflow/compiler/xla/layout_util_test.cc
@@ -27,15 +27,15 @@ namespace {
class LayoutUtilTest : public ::testing::Test {
protected:
Shape MakeShapeWithLayout(PrimitiveType element_type,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> minor_to_major) {
+ absl::Span<const int64> dimensions,
+ absl::Span<const int64> minor_to_major) {
Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
*shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
return shape;
}
Shape MakeShapeWithSparseLayout(PrimitiveType element_type,
- tensorflow::gtl::ArraySlice<int64> dimensions,
+ absl::Span<const int64> dimensions,
int64 max_sparse_elements) {
Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
*shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);
diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD
index 89353448e2..3e79129aaf 100644
--- a/tensorflow/compiler/xla/legacy_flags/BUILD
+++ b/tensorflow/compiler/xla/legacy_flags/BUILD
@@ -26,6 +26,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -39,6 +40,7 @@ tf_cc_test(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -56,6 +58,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -73,5 +76,7 @@ tf_cc_test(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
index 1bf8948ef6..0d3136b0cc 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -17,9 +17,9 @@ limitations under the License.
#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
#include <vector>
+#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace legacy_flags {
@@ -87,7 +87,7 @@ void AllocateFlags() {
// Custom "sub-parser" lambda for xla_disable_hlo_passes.
auto setter_for_xla_disable_hlo_passes = [](string comma_separated_values) {
std::vector<string> disabled_passes =
- tensorflow::str_util::Split(comma_separated_values, ',');
+ absl::StrSplit(comma_separated_values, ',');
for (const auto& passname : disabled_passes) {
flag_values->add_xla_disable_hlo_passes(passname);
}
@@ -316,6 +316,13 @@ void AllocateFlags() {
bool_setter_for(&DebugOptions::set_xla_cpu_use_mkl_dnn),
flag_values->xla_cpu_use_mkl_dnn(),
"Generate calls to MKL-DNN in the CPU backend."),
+ tensorflow::Flag(
+ "xla_gpu_crash_on_verification_failures",
+ bool_setter_for(
+ &DebugOptions::set_xla_gpu_crash_on_verification_failures),
+ flag_values->xla_gpu_crash_on_verification_failures(),
+ "Crashes the program on extra verification failures, e.g. cuDNN "
+ "cross checking failures"),
});
ParseFlagsFromEnv(*flag_objects);
}
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h
index e9cf435d83..ee7eb019c0 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h
@@ -17,10 +17,10 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_
#include <vector>
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/xla.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
namespace legacy_flags {
@@ -30,7 +30,7 @@ template <typename T>
void parse_xla_backend_extra_options(T* extra_options_map,
string comma_separated_values) {
std::vector<string> extra_options_parts =
- tensorflow::str_util::Split(comma_separated_values, ',');
+ absl::StrSplit(comma_separated_values, ',');
// The flag contains a comma-separated list of options; some options
// have arguments following "=", some don't.
@@ -59,8 +59,7 @@ void parse_xla_backend_extra_options(T* extra_options_map,
inline bool parse_xla_reduce_precision_option(
HloReducePrecisionOptions* options, string option_string) {
// Split off "LOCATION" from remainder of string.
- std::vector<string> eq_split =
- tensorflow::str_util::Split(option_string, '=');
+ std::vector<string> eq_split = absl::StrSplit(option_string, '=');
if (eq_split.size() != 2) {
return false;
}
@@ -80,26 +79,25 @@ inline bool parse_xla_reduce_precision_option(
}
// Split off "E,M" from remainder of string.
- std::vector<string> colon_split =
- tensorflow::str_util::Split(eq_split[1], ':');
+ std::vector<string> colon_split = absl::StrSplit(eq_split[1], ':');
if (colon_split.size() != 2) {
return false;
}
// Split E and M, and parse.
std::vector<int32> bitsizes;
- if (!tensorflow::str_util::SplitAndParseAsInts(colon_split[0], ',',
- &bitsizes) ||
- bitsizes.size() != 2) {
- return false;
+ for (const auto& s : absl::StrSplit(colon_split[0], ',')) {
+ bitsizes.emplace_back();
+ if (!absl::SimpleAtoi(s, &bitsizes.back())) {
+ return false;
+ }
}
options->set_exponent_bits(bitsizes[0]);
options->set_mantissa_bits(bitsizes[1]);
// Split off OPS comma-separated list from remainder of string, if the
// remainder exists.
- std::vector<string> semicolon_split =
- tensorflow::str_util::Split(colon_split[1], ';');
+ std::vector<string> semicolon_split = absl::StrSplit(colon_split[1], ';');
if (semicolon_split.size() > 2) {
return false;
}
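Abseil has no direct SplitAndParseAsInts equivalent, so the new code splits and parses in two steps. The same pattern as a standalone sketch, with placeholder input:

    std::vector<int32> bitsizes;
    for (absl::string_view piece : absl::StrSplit("8,23", ',')) {
      int32 value;
      if (!absl::SimpleAtoi(piece, &value)) {
        return false;  // parse failure, as in the surrounding code
      }
      bitsizes.push_back(value);
    }
    // bitsizes == {8, 23}, e.g. exponent and mantissa bit counts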
@@ -113,8 +111,7 @@ inline bool parse_xla_reduce_precision_option(
options->add_opcodes_to_suffix(i);
}
} else {
- std::vector<string> opcodes =
- tensorflow::str_util::Split(opcode_string, ',');
+ std::vector<string> opcodes = absl::StrSplit(opcode_string, ',');
for (const string& opcode : opcodes) {
bool found = false;
for (int i = 0; i < HloOpcodeCount(); i++) {
@@ -132,8 +129,7 @@ inline bool parse_xla_reduce_precision_option(
// Process the NAMES string, if it exists.
if (semicolon_split.size() == 2) {
- std::vector<string> opnames =
- tensorflow::str_util::Split(semicolon_split[1], ',');
+ std::vector<string> opnames = absl::StrSplit(semicolon_split[1], ',');
for (const string& opname : opnames) {
if (opname.length() > 0) {
options->add_opname_substrings_to_suffix(opname);
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc
index 0ed788a967..6f197aec53 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include <unordered_map>
#include <vector>
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc
index 7b6ae311c1..138c0c852e 100644
--- a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc
+++ b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc
@@ -21,8 +21,8 @@ limitations under the License.
#include <stdlib.h>
#include <vector>
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/subprocess.h"
#include "tensorflow/core/platform/test.h"
@@ -106,8 +106,8 @@ TEST(ParseFlagsFromEnv, File) {
if (tmp_dir == nullptr) {
tmp_dir = kTempDir;
}
- string tmp_file = tensorflow::strings::Printf("%s/parse_flags_from_env.%d",
- tmp_dir, getpid());
+ string tmp_file =
+ absl::StrFormat("%s/parse_flags_from_env.%d", tmp_dir, getpid());
FILE* fp = fopen(tmp_file.c_str(), "w");
CHECK_NE(fp, nullptr) << "can't write to " << tmp_file;
for (int i = 0; kTestFlagString[i] != '\0'; i++) {
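absl::StrFormat returns a std::string directly and type-checks the format string at compile time; %d accepts any integral width, which is also why the %lld specifiers elsewhere in this commit collapse to plain %d. A sketch with placeholder values:

    string tmp_file = absl::StrFormat("%s/parse_flags_from_env.%d", "/tmp", 12345);
    // tmp_file == "/tmp/parse_flags_from_env.12345"; getpid() supplies the real pid above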
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index 36e472568e..f1f255efae 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -22,6 +22,10 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -30,19 +34,15 @@ limitations under the License.
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
-using tensorflow::strings::Printf;
-using tensorflow::strings::StrCat;
-
namespace xla {
-
namespace {
+using absl::StrCat;
+using absl::StrFormat;
+
constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
// Converts between little and big endian.
@@ -73,7 +73,7 @@ std::ostream& operator<<(std::ostream& out, const Literal& literal) {
MutableLiteralBase::StrideConfig::StrideConfig(
const Shape& source_shape, const Shape& dest_shape,
- tensorflow::gtl::ArraySlice<int64> dimensions)
+ absl::Span<const int64> dimensions)
: dimensions(dimensions),
base(dimensions.size(), 0),
step(dimensions.size(), 1) {
@@ -134,7 +134,7 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
Literal::Literal(const Shape& shape, bool allocate_arrays)
: MutableLiteralBase() {
- shape_ = MakeUnique<Shape>(shape);
+ shape_ = absl::make_unique<Shape>(shape);
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
root_piece_->set_subshape(shape_.get());
@@ -174,9 +174,9 @@ Literal& Literal::operator=(Literal&& other) {
return *this;
}
-std::unique_ptr<Literal> LiteralBase::CreateFromShape(const Shape& shape) {
- auto literal = MakeUnique<Literal>(shape);
- literal->root_piece_->ForEachMutableSubpiece(
+Literal LiteralBase::CreateFromShape(const Shape& shape) {
+ Literal literal(shape);
+ literal.root_piece_->ForEachMutableSubpiece(
[&](const ShapeIndex& index, Piece* piece) {
if (ShapeUtil::IsArray(piece->subshape())) {
memset(piece->untyped_data(), 0, piece->size_bytes());
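CreateFromShape now hands back a Literal by value. Literal owns its buffers and is movable, so this is a pointer swap rather than a deep copy, and callers lose one level of indirection. A call-site sketch:

    // before: std::unique_ptr<Literal> zeros = LiteralBase::CreateFromShape(s);
    Literal zeros = LiteralBase::CreateFromShape(ShapeUtil::MakeShape(F32, {2, 2}));
    float v = zeros.Get<float>({0, 0});  // was zeros->Get<float>({0, 0})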
@@ -197,14 +197,13 @@ SparseIndexArray* MutableLiteralBase::sparse_indices(
template <typename NativeT>
Status MutableLiteralBase::CopySliceFromInternal(
- const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size) {
+ const LiteralBase& src_literal, absl::Span<const int64> src_base,
+ absl::Span<const int64> dest_base, absl::Span<const int64> copy_size) {
TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size());
TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size());
auto linear_index = [](const Shape& shape,
- tensorflow::gtl::ArraySlice<int64> multi_index) {
+ absl::Span<const int64> multi_index) {
return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
};
@@ -232,7 +231,7 @@ Status MutableLiteralBase::CopySliceFromInternal(
MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(),
copy_size);
- auto copy_proc = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
+ auto copy_proc = [&](absl::Span<const int64> indexes) {
// Map from multi-dimensional index, to source index.
std::transform(indexes.begin(), indexes.end(), src_base.begin(),
src_indexes.begin(), std::plus<int64>());
@@ -257,10 +256,9 @@ Status MutableLiteralBase::CopySliceFromInternal(
return Status::OK();
}
-Status MutableLiteralBase::CopyElementFrom(
- const LiteralSlice& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_index,
- tensorflow::gtl::ArraySlice<int64> dest_index) {
+Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
+ absl::Span<const int64> src_index,
+ absl::Span<const int64> dest_index) {
DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(
src_literal.shape(), src_index);
@@ -280,8 +278,8 @@ Status MutableLiteralBase::CopyElementFrom(
return Status::OK();
}
-/* static */ StatusOr<std::unique_ptr<Literal>>
-MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
+/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
+ const LiteralProto& proto) {
if (!proto.has_shape()) {
return InvalidArgument("LiteralProto has no shape");
}
@@ -289,9 +287,9 @@ MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
return InvalidArgument("LiteralProto has no layout");
}
- auto literal = MakeUnique<Literal>(proto.shape());
+ Literal literal(proto.shape());
- TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus(
+ TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
[&](const ShapeIndex& index, Piece* piece) {
const LiteralProto* proto_element = &proto;
for (int64 i : index) {
@@ -303,7 +301,7 @@ MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
if (proto_element->tuple_literals_size() !=
ShapeUtil::TupleElementCount(piece->subshape())) {
return InvalidArgument(
- "Expected %lld tuple elements in LiteralProto, has %d",
+ "Expected %d tuple elements in LiteralProto, has %d",
ShapeUtil::TupleElementCount(piece->subshape()),
proto_element->tuple_literals_size());
}
@@ -355,9 +353,9 @@ namespace {
// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
// the array slices are indicated by dest_shape and src_shape respectively.
template <typename NativeT>
-void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
- tensorflow::gtl::ArraySlice<NativeT> src,
- const Shape& dest_shape, const Shape& src_shape) {
+void CopyElementsBetween(absl::Span<NativeT> dest,
+ absl::Span<const NativeT> src, const Shape& dest_shape,
+ const Shape& src_shape) {
CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
if (ShapeUtil::IsZeroElementArray(dest_shape)) {
return;
@@ -366,7 +364,7 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
do {
dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
- } while (IndexUtil::BumpIndices(dest_shape, &index));
+ } while (IndexUtil::BumpIndices(dest_shape, absl::MakeSpan(index)));
}
} // namespace
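A std::vector does not convert implicitly to a mutable span, so absl::MakeSpan(index) builds the absl::Span<int64> that BumpIndices now expects (a const vector still converts to absl::Span<const int64> on its own). The iteration idiom, assuming a dense array shape:

    std::vector<int64> index(ShapeUtil::Rank(shape), 0);
    do {
      // ... visit the element at `index` ...
    } while (IndexUtil::BumpIndices(shape, absl::MakeSpan(index)));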
@@ -404,7 +402,7 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
default:
return Unimplemented(
"Copying a Literal object with element type %s is not implemented.",
- PrimitiveType_Name(subshape().element_type()).c_str());
+ PrimitiveType_Name(subshape().element_type()));
}
}
return Status::OK();
@@ -420,8 +418,8 @@ Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal,
if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
return InvalidArgument(
"Destination subshape incompatible with source subshape: %s vs %s",
- ShapeUtil::HumanString(dest_subshape).c_str(),
- ShapeUtil::HumanString(src_subshape).c_str());
+ ShapeUtil::HumanString(dest_subshape),
+ ShapeUtil::HumanString(src_subshape));
}
return root_piece_->ForEachMutableSubpieceWithStatus(
[&](const ShapeIndex& index, Piece* piece) {
@@ -458,8 +456,8 @@ Status Literal::MoveFrom(Literal&& src_literal,
if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) {
return InvalidArgument(
"Destination subshape not equal to source shape: %s vs %s",
- ShapeUtil::HumanString(dest_subshape).c_str(),
- ShapeUtil::HumanString(src_literal.shape()).c_str());
+ ShapeUtil::HumanString(dest_subshape),
+ ShapeUtil::HumanString(src_literal.shape()));
}
src_literal.root_piece_->ForEachSubpiece(
@@ -479,7 +477,7 @@ Status Literal::MoveFrom(Literal&& src_literal,
dest_piece.set_sparse_indices(src_piece.sparse_indices());
});
- src_literal.shape_ = MakeUnique<Shape>(ShapeUtil::MakeNil());
+ src_literal.shape_ = absl::make_unique<Shape>(ShapeUtil::MakeNil());
delete src_literal.root_piece_;
src_literal.root_piece_ = new LiteralBase::Piece();
src_literal.root_piece_->set_subshape(src_literal.shape_.get());
@@ -487,11 +485,10 @@ Status Literal::MoveFrom(Literal&& src_literal,
return Status::OK();
}
-Status MutableLiteralBase::CopySliceFrom(
- const LiteralSlice& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size) {
+Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal,
+ absl::Span<const int64> src_base,
+ absl::Span<const int64> dest_base,
+ absl::Span<const int64> copy_size) {
TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape());
TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape()))
<< ShapeUtil::HumanString(src_literal.shape());
@@ -559,40 +556,38 @@ void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) {
}
}
-std::unique_ptr<Literal> LiteralBase::Relayout(
- const Layout& new_layout, const ShapeIndex& shape_index) const {
+Literal LiteralBase::Relayout(const Layout& new_layout,
+ const ShapeIndex& shape_index) const {
// Create new shape with 'new_layout' set at the given shape index.
Shape new_shape = shape();
Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
*subshape->mutable_layout() = new_layout;
- auto result = MakeUnique<Literal>(new_shape);
- TF_CHECK_OK(result->CopyFrom(*this));
+ Literal result(new_shape);
+ TF_CHECK_OK(result.CopyFrom(*this));
return result;
}
-std::unique_ptr<Literal> LiteralBase::Relayout(
- const Shape& shape_with_layout) const {
+Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
<< "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
<< " not compatible with literal shape "
<< ShapeUtil::HumanString(shape());
- std::unique_ptr<Literal> result = CreateFromShape(shape_with_layout);
+ Literal result = CreateFromShape(shape_with_layout);
ShapeUtil::ForEachSubshape(
- result->shape(),
+ result.shape(),
[this, &result](const Shape& subshape, const ShapeIndex& index) {
if (ShapeUtil::IsArray(subshape)) {
- TF_CHECK_OK(result->CopyFrom(*this,
- /*dest_shape_index=*/index,
- /*src_shape_index=*/index));
+ TF_CHECK_OK(result.CopyFrom(*this,
+ /*dest_shape_index=*/index,
+ /*src_shape_index=*/index));
}
});
return result;
}
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
- const Shape& result_shape,
- tensorflow::gtl::ArraySlice<int64> dimensions) const {
+StatusOr<Literal> LiteralBase::Broadcast(
+ const Shape& result_shape, absl::Span<const int64> dimensions) const {
if (!ShapeUtil::IsArray(shape())) {
return InvalidArgument("Broadcast only supports arrays.");
}
@@ -602,20 +597,20 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
result_shape.dimensions(dimensions[i]));
}
- std::unique_ptr<Literal> result = MakeUnique<Literal>(result_shape);
+ Literal result(result_shape);
// scratch_source_index is temporary storage space for the computed index into
// the input literal. We put it here to avoid allocating an std::vector in
// every iteration of ShapeUtil::ForEachIndex.
std::vector<int64> scratch_source_index(shape().dimensions_size());
- char* dest_data = static_cast<char*>(result->untyped_data());
+ char* dest_data = static_cast<char*>(result.untyped_data());
const char* source_data = static_cast<const char*>(untyped_data());
const int64 primitive_size =
ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
ShapeUtil::ForEachIndex(
- result_shape, [&](tensorflow::gtl::ArraySlice<int64> output_index) {
+ result_shape, [&](absl::Span<const int64> output_index) {
for (int64 i = 0; i < dimensions.size(); ++i) {
scratch_source_index[i] = output_index[dimensions[i]];
}
@@ -631,37 +626,36 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
return std::move(result);
}
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
- tensorflow::gtl::ArraySlice<int64> dimensions) const {
+StatusOr<Literal> LiteralBase::Reshape(
+ absl::Span<const int64> dimensions) const {
if (!ShapeUtil::IsArray(shape())) {
return InvalidArgument("Reshape does not support tuples.");
}
- std::unique_ptr<Literal> output;
+ Literal output;
if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
output =
Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape())));
} else {
- output = CloneToUnique();
+ output = Clone();
}
// Because the layout is monotonic, we can simply reuse the same sequence of
// values without changing their order.
- *output->mutable_shape_do_not_use() =
+ *output.mutable_shape_do_not_use() =
ShapeUtil::MakeShape(shape().element_type(), dimensions);
int64 elements_before = ShapeUtil::ElementsIn(shape());
- int64 elements_after = ShapeUtil::ElementsIn(output->shape());
+ int64 elements_after = ShapeUtil::ElementsIn(output.shape());
if (elements_before != elements_after) {
return InvalidArgument(
"Shapes before and after Literal::Reshape have different numbers "
"of elements: %s vs %s.",
- ShapeUtil::HumanString(shape()).c_str(),
- ShapeUtil::HumanString(output->shape()).c_str());
+ ShapeUtil::HumanString(shape()),
+ ShapeUtil::HumanString(output.shape()));
}
return std::move(output);
}
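The `return std::move(output);` above is deliberate: `output` is a Literal local, the return type is StatusOr<Literal>, and Literal is move-only, so the explicit move keeps the converting construction from attempting a copy. A minimal sketch of the same shape (function name hypothetical):

    StatusOr<Literal> MakeOnesLiteral() {
      Literal out(ShapeUtil::MakeShape(F32, {2, 2}));
      out.PopulateWithValue<float>(1.0f);
      return std::move(out);  // explicit move into the StatusOr<Literal>
    }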
-std::unique_ptr<Literal> LiteralBase::Transpose(
- tensorflow::gtl::ArraySlice<int64> permutation) const {
+Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
<< "Given permutation is not a permutation of dimension numbers";
@@ -691,33 +685,31 @@ std::unique_ptr<Literal> LiteralBase::Transpose(
for (auto index : LayoutUtil::MinorToMajor(shape())) {
layout->add_minor_to_major(inverse_permutation[index]);
}
- auto new_literal = MakeUnique<Literal>(permuted_shape);
- DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()),
+ Literal new_literal(permuted_shape);
+ DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
ShapeUtil::ByteSizeOf(shape()));
- std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes());
+ std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
return new_literal;
}
template <typename NativeT>
-std::unique_ptr<Literal> LiteralBase::SliceInternal(
- const Shape& result_shape,
- tensorflow::gtl::ArraySlice<int64> start_indices) const {
- auto result_literal = MakeUnique<Literal>(result_shape);
+Literal LiteralBase::SliceInternal(
+ const Shape& result_shape, absl::Span<const int64> start_indices) const {
+ Literal result_literal(result_shape);
DimensionVector new_indices(ShapeUtil::Rank(result_shape));
- result_literal->EachCell<NativeT>(
- [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT /*value*/) {
+ result_literal.EachCell<NativeT>(
+ [&](absl::Span<const int64> indices, NativeT /*value*/) {
for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
new_indices[i] = indices[i] + start_indices[i];
}
NativeT value = Get<NativeT>(new_indices);
- result_literal->Set<NativeT>(indices, value);
+ result_literal.Set<NativeT>(indices, value);
});
return result_literal;
}
-std::unique_ptr<Literal> LiteralBase::Slice(
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices) const {
+Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices) const {
CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
DimensionVector result_dimensions;
@@ -755,13 +747,7 @@ Literal LiteralBase::Clone() const {
return result;
}
-std::unique_ptr<Literal> LiteralBase::CloneToUnique() const {
- auto result = MakeUnique<Literal>(shape());
- TF_CHECK_OK(result->CopyFrom(*this));
- return result;
-}
-
-string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
+string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
const ShapeIndex& shape_index) const {
const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
CHECK(LayoutUtil::IsDenseArray(subshape));
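With CloneToUnique gone, Clone is the single deep-copy entry point and returns an independent Literal by value:

    Literal copy = original.Clone();  // was: auto copy = original.CloneToUnique();
    CHECK(copy == original);          // deep copy; compares equal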
@@ -858,7 +844,7 @@ string LiteralBase::GetSparseElementAsString(
}
StatusOr<int64> LiteralBase::GetIntegralAsS64(
- tensorflow::gtl::ArraySlice<int64> multi_index) const {
+ absl::Span<const int64> multi_index) const {
CHECK(LayoutUtil::IsDenseArray(shape()));
switch (shape().element_type()) {
case PRED:
@@ -874,9 +860,8 @@ StatusOr<int64> LiteralBase::GetIntegralAsS64(
case U64:
return Get<uint64>(multi_index);
default:
- return FailedPrecondition(
- "Array element type is not integral: %s",
- PrimitiveType_Name(shape().element_type()).c_str());
+ return FailedPrecondition("Array element type is not integral: %s",
+ PrimitiveType_Name(shape().element_type()));
}
}
@@ -901,8 +886,8 @@ size_t LiteralBase::Hash() const {
return hash_value;
}
-Status MutableLiteralBase::SetIntegralAsS64(
- tensorflow::gtl::ArraySlice<int64> multi_index, int64 value) {
+Status MutableLiteralBase::SetIntegralAsS64(absl::Span<const int64> multi_index,
+ int64 value) {
CHECK(LayoutUtil::IsDenseArray(shape()));
switch (shape().element_type()) {
case PRED:
@@ -924,14 +909,13 @@ Status MutableLiteralBase::SetIntegralAsS64(
Set<uint64>(multi_index, value);
break;
default:
- return FailedPrecondition(
- "Array element type is not integral: %s",
- PrimitiveType_Name(shape().element_type()).c_str());
+ return FailedPrecondition("Array element type is not integral: %s",
+ PrimitiveType_Name(shape().element_type()));
}
return Status::OK();
}
-tensorflow::gtl::ArraySlice<int64> LiteralBase::GetSparseIndex(
+absl::Span<const int64> LiteralBase::GetSparseIndex(
int64 sparse_element_number, const ShapeIndex& shape_index) const {
const Piece& p = piece(shape_index);
CHECK_GE(sparse_element_number, 0);
@@ -1000,7 +984,7 @@ void LiteralBase::Piece::SortSparseElementsInternal() {
auto values = data<NativeT>();
CHECK_LE(num_elements, values.size());
sparse_indices()->SortWithValues(
- tensorflow::gtl::MutableArraySlice<NativeT>(values.data(), num_elements));
+ absl::Span<NativeT>(values.data(), num_elements));
}
namespace {
@@ -1029,9 +1013,9 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
element_index.push_back(i);
std::vector<string> element_pieces;
ToStringHelper(literal, element_index, print_layout, &element_pieces);
- tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, ""));
+ tuple_pieces.push_back(absl::StrJoin(element_pieces, ""));
}
- pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n"));
+ pieces->push_back(absl::StrJoin(tuple_pieces, ",\n"));
pieces->push_back("\n)");
return;
}
@@ -1055,8 +1039,7 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
pieces->push_back(": ");
} else {
pieces->push_back("[");
- pieces->push_back(
- tensorflow::str_util::Join(literal.GetSparseIndex(i), ", "));
+ pieces->push_back(absl::StrJoin(literal.GetSparseIndex(i), ", "));
pieces->push_back("]: ");
}
pieces->push_back(literal.GetSparseElementAsString(i));
@@ -1067,8 +1050,7 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
CHECK(LayoutUtil::IsDenseArray(subshape));
- auto element_to_string =
- [&](tensorflow::gtl::ArraySlice<int64> indices) -> string {
+ auto element_to_string = [&](absl::Span<const int64> indices) -> string {
PrimitiveType element_type = subshape.element_type();
if (element_type == PRED) {
// We display predicates in a densely packed form.
@@ -1117,9 +1099,9 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
pieces->push_back(shape_to_string(subshape));
pieces->push_back(" {\n");
for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
- pieces->push_back(Printf(" { /*i0=%lld*/\n", i0));
+ pieces->push_back(StrFormat(" { /*i0=%d*/\n", i0));
for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
- pieces->push_back(Printf(" { /*i1=%lld*/\n", i1));
+ pieces->push_back(StrFormat(" { /*i1=%d*/\n", i1));
for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
pieces->push_back(" {");
for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) {
@@ -1137,11 +1119,11 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
pieces->push_back(shape_to_string(subshape));
pieces->push_back(" {\n");
for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
- pieces->push_back(Printf(" { /*i0=%lld*/\n", i0));
+ pieces->push_back(StrFormat(" { /*i0=%d*/\n", i0));
for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
- pieces->push_back(Printf(" { /*i1=%lld*/\n", i1));
+ pieces->push_back(StrFormat(" { /*i1=%d*/\n", i1));
for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
- pieces->push_back(Printf(" { /*i2=%lld*/\n", i2));
+ pieces->push_back(StrFormat(" { /*i2=%d*/\n", i2));
for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) {
pieces->push_back(" {");
for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) {
@@ -1163,7 +1145,7 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
pieces->push_back(shape_to_string(subshape));
pieces->push_back(" {");
literal.EachCellAsString(
- [&](tensorflow::gtl::ArraySlice<int64> indices, const string& value) {
+ [&](absl::Span<const int64> indices, const string& value) {
pieces->push_back(" ");
pieces->push_back(value);
});
@@ -1182,11 +1164,11 @@ string LiteralBase::ToString(bool print_layout) const {
std::vector<string> pieces;
CHECK(LayoutUtil::HasLayout(this->shape()));
ToStringHelper(*this, {}, print_layout, &pieces);
- return tensorflow::str_util::Join(pieces, "");
+ return absl::StrJoin(pieces, "");
}
void LiteralBase::EachCellAsString(
- const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
+ const std::function<void(absl::Span<const int64> indices,
const string& value)>& per_cell) const {
if (ShapeUtil::IsZeroElementArray(shape())) {
return;
@@ -1195,19 +1177,19 @@ void LiteralBase::EachCellAsString(
shape(), /*linear_index=*/0);
do {
per_cell(indices, GetAsString(indices));
- } while (IndexUtil::BumpIndices(shape(), &indices));
+ } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
}
namespace {
template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
-std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
- const LiteralBase& src_literal, const ConverterType& converter) {
+Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal,
+ const ConverterType& converter) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
- auto result_literal = MakeUnique<Literal>(ShapeUtil::ChangeElementType(
+ Literal result_literal(ShapeUtil::ChangeElementType(
src_literal.shape(),
primitive_util::NativeToPrimitiveType<NativeDestT>()));
auto src_data = src_literal.data<NativeSrcT>();
- auto dest_data = result_literal->template data<NativeDestT>();
+ auto dest_data = result_literal.template data<NativeDestT>();
int64 num_elements = src_literal.element_count();
for (int64 i = 0; i < num_elements; ++i) {
@@ -1217,8 +1199,7 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
}
template <typename NativeSrcT, typename NativeDestT>
-std::unique_ptr<Literal> ConvertBetweenNativeTypes(
- const LiteralBase& src_literal) {
+Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
src_literal, converter);
@@ -1226,7 +1207,7 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypes(
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)),
- std::unique_ptr<Literal>>::type
+ Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) {
return tensorflow::bit_cast<NativeDestT>(src);
@@ -1241,22 +1222,20 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
// identical sizes higher up.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
- std::unique_ptr<Literal>>::type
+ Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
LOG(FATAL) << "Invalid bitcast between types of different sizes.";
}
template <PrimitiveType primitive_src_type>
-std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
+Literal ConvertToC64(const LiteralBase& src_literal) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
- auto result_literal = MakeUnique<Literal>(
+ Literal result_literal(
ShapeUtil::ChangeElementType(src_literal.shape(), C64));
using NativeSrcT =
typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
- tensorflow::gtl::ArraySlice<NativeSrcT> src_data =
- src_literal.data<NativeSrcT>();
- tensorflow::gtl::MutableArraySlice<complex64> dest_data =
- result_literal->data<complex64>();
+ absl::Span<const NativeSrcT> src_data = src_literal.data<NativeSrcT>();
+ absl::Span<complex64> dest_data = result_literal.data<complex64>();
int64 num_elements = src_literal.element_count();
for (int64 i = 0; i < num_elements; ++i) {
dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
@@ -1265,8 +1244,7 @@ std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
}
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
-std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
- bool bitcast) {
+Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) {
CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
if (bitcast) {
return BitcastBetweenNativeTypes<
@@ -1284,9 +1262,9 @@ std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
}
template <PrimitiveType primitive_src_type>
-StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
- const LiteralBase& src_literal, PrimitiveType primitive_dest_type,
- bool bitcast) {
+StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
+ PrimitiveType primitive_dest_type,
+ bool bitcast) {
switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type) \
case (type): \
@@ -1313,18 +1291,17 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
default:
break;
}
- return Unimplemented(
- "Converting from type %s to type %s is not implemented.",
- PrimitiveType_Name(src_literal.shape().element_type()).c_str(),
- PrimitiveType_Name(primitive_dest_type).c_str());
+ return Unimplemented("Converting from type %s to type %s is not implemented.",
+ PrimitiveType_Name(src_literal.shape().element_type()),
+ PrimitiveType_Name(primitive_dest_type));
}
-StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
- const LiteralBase& literal, PrimitiveType primitive_dest_type,
- bool bitcast) {
+StatusOr<Literal> ConvertSwitch(const LiteralBase& literal,
+ PrimitiveType primitive_dest_type,
+ bool bitcast) {
TF_RET_CHECK(ShapeUtil::IsArray(literal.shape()));
if (literal.shape().element_type() == primitive_dest_type) {
- return literal.CloneToUnique();
+ return literal.Clone();
}
switch (literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
@@ -1345,38 +1322,37 @@ StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
#undef CONVERT_IF_DEST_TYPE_MATCHES
// Other types are not yet supported.
default:
- return Unimplemented(
- "%s from type %s to type %s is not implemented.",
- (bitcast ? "Bitcast converting" : "Converting"),
- PrimitiveType_Name(literal.shape().element_type()).c_str(),
- PrimitiveType_Name(primitive_dest_type).c_str());
+ return Unimplemented("%s from type %s to type %s is not implemented.",
+ (bitcast ? "Bitcast converting" : "Converting"),
+ PrimitiveType_Name(literal.shape().element_type()),
+ PrimitiveType_Name(primitive_dest_type));
}
}
} // namespace
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Convert(
+StatusOr<Literal> LiteralBase::Convert(
PrimitiveType primitive_dest_type) const {
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
}
-StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
+StatusOr<Literal> LiteralBase::BitcastConvert(
PrimitiveType primitive_dest_type) const {
if (primitive_util::BitWidth(shape().element_type()) !=
primitive_util::BitWidth(primitive_dest_type)) {
return InvalidArgument(
"Cannot bitcast convert from %s to %s, bit widths are different: %d != "
"%d",
- PrimitiveType_Name(shape().element_type()).c_str(),
- PrimitiveType_Name(primitive_dest_type).c_str(),
+ PrimitiveType_Name(shape().element_type()),
+ PrimitiveType_Name(primitive_dest_type),
primitive_util::BitWidth(shape().element_type()),
primitive_util::BitWidth(primitive_dest_type));
}
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
}
-StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
- const Shape& dest_shape, bool round_f32_to_bf16) const {
+StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape,
+ bool round_f32_to_bf16) const {
if (!ShapeUtil::IsTuple(dest_shape)) {
if (round_f32_to_bf16 && shape().element_type() == F32 &&
dest_shape.element_type() == BF16) {
@@ -1394,15 +1370,13 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
TF_ASSIGN_OR_RETURN(
auto new_element,
element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
- elements.push_back(std::move(*new_element));
+ elements.push_back(std::move(new_element));
}
- auto converted = MakeUnique<Literal>();
- *converted = MutableLiteralBase::MoveIntoTuple(&elements);
- return std::move(converted);
+ return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
}
/* static */ Literal MutableLiteralBase::MoveIntoTuple(
- tensorflow::gtl::MutableArraySlice<Literal> elements) {
+ absl::Span<Literal> elements) {
std::vector<Shape> element_shapes;
for (const Literal& element : elements) {
element_shapes.push_back(element.shape());
@@ -1435,6 +1409,12 @@ bool LiteralBase::Piece::EqualElementsInternal(
bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
+ if (ShapeUtil::Equal(subshape(), other.subshape()) &&
+ LayoutUtil::IsDenseArray(subshape())) {
+ CHECK_EQ(size_bytes(), other.size_bytes());
+ return memcmp(buffer(), other.buffer(), size_bytes()) == 0;
+ }
+
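The new fast path compares dense, layout-identical pieces bytewise. Note that bytewise and element-wise equality diverge for floating point: bit-identical NaNs compare equal here although NaN != NaN numerically, and -0.0 vs +0.0 compare unequal here although they are numerically equal. A two-line demonstration of the divergence (needs <cstring> for memcmp):

    float a = -0.0f, b = +0.0f;
    CHECK(a == b);                              // numerically equal
    CHECK(memcmp(&a, &b, sizeof(float)) != 0);  // bytewise different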
std::vector<int64> multi_index;
switch (subshape().element_type()) {
case PRED:
@@ -1487,7 +1467,7 @@ bool LiteralBase::operator==(const LiteralBase& other) const {
namespace {
template <typename NativeT>
-static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice<NativeT> data,
+static bool AllElementsEqualValue(absl::Span<const NativeT> data,
NativeT value) {
for (int64 i = 0; i < data.size(); ++i) {
if (data[i] != value) {
@@ -1686,7 +1666,62 @@ bool LiteralBase::IsAllFirst() const {
});
}
-bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
+bool LiteralBase::IsR1Iota() const {
+ if (!ShapeUtil::IsArray(shape())) {
+ return false;
+ }
+
+ if (ShapeUtil::Rank(shape()) != 1) {
+ return false;
+ }
+
+ auto is_iota_at_idx = [&](const int64 idx) {
+ switch (shape().element_type()) {
+ case U8:
+ return Get<uint8>({idx}) == idx;
+ case U16:
+ return Get<uint16>({idx}) == idx;
+ case U32:
+ return Get<uint32>({idx}) == idx;
+ case U64:
+ return Get<uint64>({idx}) == idx;
+ case S8:
+ return Get<int8>({idx}) == idx;
+ case S16:
+ return Get<int16>({idx}) == idx;
+ case S32:
+ return Get<int32>({idx}) == idx;
+ case S64:
+ return Get<int64>({idx}) == idx;
+ case F32:
+ return Get<float>({idx}) == idx;
+ case F64:
+ return Get<double>({idx}) == idx;
+ case F16:
+ return Get<half>({idx}) == static_cast<half>(idx);
+ case BF16:
+ return Get<bfloat16>({idx}) == static_cast<bfloat16>(idx);
+ case C64:
+ return Get<complex64>({idx}) == complex64(idx, 0.0f);
+ case PRED:
+ return Get<bool>({idx}) == idx;
+      // token, opaque, tuple, etc. are never iota.
+ default:
+ return false;
+ }
+ };
+
+ const int64 elements = ShapeUtil::ElementsIn(shape());
+ for (int64 idx = 0; idx < elements; ++idx) {
+ if (!is_iota_at_idx(idx)) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
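A quick usage sketch for the new predicate; Set keeps the example self-contained:

    Literal v(ShapeUtil::MakeShape(S32, {3}));
    v.Set<int32>({0}, 0);
    v.Set<int32>({1}, 1);
    v.Set<int32>({2}, 2);
    CHECK(v.IsR1Iota());  // rank 1, and element i == i at every index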
+bool LiteralBase::IsZero(absl::Span<const int64> indices) const {
CHECK(ShapeUtil::IsArray(shape()));
switch (shape().element_type()) {
case U8:
@@ -1722,7 +1757,7 @@ namespace {
template <typename RepeatedFieldT, typename NativeT>
void CopyToRepeatedField(RepeatedFieldT* dest,
- const tensorflow::gtl::ArraySlice<NativeT> src) {
+ const absl::Span<const NativeT> src) {
*dest = RepeatedFieldT(src.begin(), src.end());
}
@@ -1800,7 +1835,7 @@ void* LiteralBase::Piece::untyped_data() {
namespace {
template <typename RepeatedFieldT, typename NativeT>
-Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
+Status CopyFromRepeatedField(absl::Span<NativeT> dest,
const RepeatedFieldT& src) {
if (dest.size() != src.size()) {
return InvalidArgument(
@@ -1956,7 +1991,7 @@ MutableLiteralBase::~MutableLiteralBase() {}
MutableBorrowingLiteral::MutableBorrowingLiteral(
const MutableBorrowingLiteral& literal)
: MutableLiteralBase() {
- shape_ = MakeUnique<Shape>(literal.shape());
+ shape_ = absl::make_unique<Shape>(literal.shape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
@@ -1967,7 +2002,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral(
MutableBorrowingLiteral& MutableBorrowingLiteral::operator=(
const MutableBorrowingLiteral& literal) {
- shape_ = MakeUnique<Shape>(literal.shape());
+ shape_ = absl::make_unique<Shape>(literal.shape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
@@ -1981,7 +2016,7 @@ MutableBorrowingLiteral& MutableBorrowingLiteral::operator=(
MutableBorrowingLiteral::MutableBorrowingLiteral(
const MutableLiteralBase& literal)
: MutableLiteralBase() {
- shape_ = MakeUnique<Shape>(literal.shape());
+ shape_ = absl::make_unique<Shape>(literal.shape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
@@ -1992,7 +2027,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral(
MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal)
: MutableLiteralBase() {
- shape_ = MakeUnique<Shape>(literal->shape());
+ shape_ = absl::make_unique<Shape>(literal->shape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
@@ -2004,7 +2039,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal)
MutableBorrowingLiteral::MutableBorrowingLiteral(
MutableBorrowingLiteral literal, const ShapeIndex& view_root)
: MutableLiteralBase() {
- shape_ = MakeUnique<Shape>(literal.piece(view_root).subshape());
+ shape_ = absl::make_unique<Shape>(literal.piece(view_root).subshape());
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = new Piece();
@@ -2016,7 +2051,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral(
MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr,
const Shape& shape)
: MutableLiteralBase() {
- shape_ = MakeUnique<Shape>(shape);
+ shape_ = absl::make_unique<Shape>(shape);
CHECK(LayoutUtil::HasLayout(*shape_));
CHECK(!ShapeUtil::IsTuple(*shape_));
@@ -2061,7 +2096,7 @@ void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
}
BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
- : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
+ : LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
CHECK(ShapeUtil::IsArray(*shape_));
CHECK(LayoutUtil::HasLayout(*shape_));
@@ -2070,9 +2105,9 @@ BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
root_piece_.set_subshape(shape_.get());
}
-BorrowingLiteral::BorrowingLiteral(
- tensorflow::gtl::ArraySlice<const char*> src_buf_ptrs, const Shape& shape)
- : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
+BorrowingLiteral::BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
+ const Shape& shape)
+ : LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
CHECK(ShapeUtil::IsTuple(*shape_));
CHECK(!ShapeUtil::IsNestedTuple(*shape_));
CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
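That completes the literal.cc side of the ArraySlice-to-Span migration. absl::Span<const T> binds to vectors, C arrays, and braced lists alike, which is why almost no call sites change:

    std::vector<int64> dims = {2, 3};
    Shape a = ShapeUtil::MakeShape(F32, dims);    // from a std::vector
    Shape b = ShapeUtil::MakeShape(F32, {2, 3});  // from an initializer list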
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
index 92c0f903cb..fa5b5f7fab 100644
--- a/tensorflow/compiler/xla/literal.h
+++ b/tensorflow/compiler/xla/literal.h
@@ -25,13 +25,15 @@ limitations under the License.
#include <type_traits>
#include <vector>
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/sparse_index_array.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -40,8 +42,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -70,13 +70,12 @@ class LiteralBase {
// Serialize to proto.
LiteralProto ToProto() const;
- // Returns an ArraySlice of the array for this literal for the given NativeT
+ // Returns a Span of the array for this literal for the given NativeT
// (e.g., float). CHECKs if the subshape of the literal at the given
  // ShapeIndex is not an array. See primitive_util.h for the mapping from XLA type
// to native type.
template <typename NativeT>
- tensorflow::gtl::ArraySlice<NativeT> data(
- const ShapeIndex& shape_index = {}) const;
+ absl::Span<const NativeT> data(const ShapeIndex& shape_index = {}) const;
// Returns a const pointer to the sparse index array. Returns nullptr if the
// literal is not a sparse array.
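Reading through the new accessor yields a non-owning view, and per the comment above it CHECK-fails if NativeT does not match the stored element type:

    absl::Span<const float> xs = literal.data<float>();  // no copy
    for (float x : xs) {
      // ... consume x ...
    }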
@@ -100,12 +99,12 @@ class LiteralBase {
// Gets an element in the literal at the given index. The multi_index is
// CHECKed against the dimension sizes.
template <typename NativeT>
- NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index,
+ NativeT Get(absl::Span<const int64> multi_index,
const ShapeIndex& shape_index) const;
// Overloads of Get for array literals. CHECKs if the literal is not
// array-shaped and dense.
template <typename NativeT>
- NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
+ NativeT Get(absl::Span<const int64> multi_index) const;
// Returns the element value at index (0, ..., 0), however many zeroes are
// required for that index.
@@ -114,7 +113,7 @@ class LiteralBase {
// As Get(), but determines the correct type and converts the value
// into text.
- string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
+ string GetAsString(absl::Span<const int64> multi_index,
const ShapeIndex& shape_index = {}) const;
// As GetSparseElement(), but determines the correct type and converts the
// value into text.
@@ -122,14 +121,13 @@ class LiteralBase {
const ShapeIndex& shape_index = {}) const;
// As Get(), but determines the correct type and converts the value into
// int64. This literal must be an array.
- StatusOr<int64> GetIntegralAsS64(
- tensorflow::gtl::ArraySlice<int64> multi_index) const;
+ StatusOr<int64> GetIntegralAsS64(absl::Span<const int64> multi_index) const;
// Returns the multi-index of the element in a sparse literal at the given
  // sparse element number. The sparse element number is the position within
// the sparse array's list of (index, value) pairs, and is checked against the
// total number of (index, value) pairs in the sparse array.
- tensorflow::gtl::ArraySlice<int64> GetSparseIndex(
+ absl::Span<const int64> GetSparseIndex(
int64 sparse_element_number, const ShapeIndex& shape_index = {}) const;
// Returns the value of the element in a sparse literal at the given sparse
@@ -150,12 +148,12 @@ class LiteralBase {
//
// This literal must have a dense layout.
void EachCellAsString(
- const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
+ const std::function<void(absl::Span<const int64> indices,
const string& value)>& per_cell) const;
template <typename NativeT>
- void EachCell(std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
- NativeT value)>
- per_cell) const;
+ void EachCell(
+ std::function<void(absl::Span<const int64> indices, NativeT value)>
+ per_cell) const;
// Returns whether every element in this literal is equal to value.
//
@@ -195,9 +193,12 @@ class LiteralBase {
// Literal consists entirely of the first element of the literal.
bool IsAllFirst() const;
+ // Literal consists entirely of an iota.
+ bool IsR1Iota() const;
+
// Returns whether this literal is zero at the specified index. This literal
// must be an array with a dense layout.
- bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
+ bool IsZero(absl::Span<const int64> indices) const;
// Returns the count of the elements in the array at the given shape index in
// this literal.
@@ -222,25 +223,21 @@ class LiteralBase {
//
// TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes
// the default behavior.
- StatusOr<std::unique_ptr<Literal>> ConvertToShape(
- const Shape& dest_shape, bool round_f32_to_bf16 = false) const;
+ StatusOr<Literal> ConvertToShape(const Shape& dest_shape,
+ bool round_f32_to_bf16 = false) const;
// Converts this literal to another primitive type using a bitcast
// conversion. The to and from primitive types must have the same bit
// width. Returns an error if the conversion is not possible. This literal
// must be array-shaped.
- StatusOr<std::unique_ptr<Literal>> BitcastConvert(
- PrimitiveType primitive_dest_type) const;
+ StatusOr<Literal> BitcastConvert(PrimitiveType primitive_dest_type) const;
// Converts this literal to another primitive type. Returns an error if the
// conversion is not possible. This literal must be array-shaped.
- StatusOr<std::unique_ptr<Literal>> Convert(
- PrimitiveType primitive_dest_type) const;
+ StatusOr<Literal> Convert(PrimitiveType primitive_dest_type) const;
- // Clones the underlying buffers into a new Literal, or new
- // std::unique_ptr<Literal>.
+ // Clones the underlying buffers into a new Literal.
Literal Clone() const;
- std::unique_ptr<Literal> CloneToUnique() const;
// TODO(b/67651157): The methods below which perform computation on Literals
// (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
@@ -258,25 +255,23 @@ class LiteralBase {
// Note: this is useful when the client wants to ensure that a value placed in
// the XLA allocation tracker has a particular layout; for efficiency
// purposes or avoiding unimplemented operation/layout combinations.
- std::unique_ptr<Literal> Relayout(const Layout& new_layout,
- const ShapeIndex& shape_index = {}) const;
+ Literal Relayout(const Layout& new_layout,
+ const ShapeIndex& shape_index = {}) const;
// An overload of Relayout which changes the layout of the entire shape rather
// than being limited to a single array within the shape.
- std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
+ Literal Relayout(const Shape& shape_with_layout) const;
// Creates a new literal by reshaping this literal to have the given
  // dimensions. The total number of elements must not change; the
// implementation currently only supports monotonic dim0-major layouts.
// This literal must be an array.
- StatusOr<std::unique_ptr<Literal>> Reshape(
- tensorflow::gtl::ArraySlice<int64> dimensions) const;
+ StatusOr<Literal> Reshape(absl::Span<const int64> dimensions) const;
// Creates a new literal by broadcasting this literal with `dimensions` to
// yield a literal of shape `result_shape`.
- StatusOr<std::unique_ptr<Literal>> Broadcast(
- const Shape& result_shape,
- tensorflow::gtl::ArraySlice<int64> dimensions) const;
+ StatusOr<Literal> Broadcast(const Shape& result_shape,
+ absl::Span<const int64> dimensions) const;
// Creates a new literal by reordering the dimensions of this literal.
// The given `permutation` must be a permutation of the dimension numbers
@@ -285,8 +280,7 @@ class LiteralBase {
// For example, a transpose call on a literal of shape [3 x 8 x 4] and
// `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
// This literal must be an array.
- std::unique_ptr<Literal> Transpose(
- tensorflow::gtl::ArraySlice<int64> permutation) const;
+ Literal Transpose(absl::Span<const int64> permutation) const;
// Creates a sub-array from this literal by extracting the indices
// [start_index, limit_index) of each dimension. The result literal has the
@@ -294,16 +288,15 @@ class LiteralBase {
// start_indices and limit_indices must be the rank of the literal, and the
// indices follow the order of the dimensions.
// This literal must be an array.
- std::unique_ptr<Literal> Slice(
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices) const;
+ Literal Slice(absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices) const;
// Creates a literal with a prepended dimension with bound "times"; e.g. a
// f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
// literal replicated four times.
// This literal must be an array.
template <typename NativeT>
- std::unique_ptr<Literal> Replicate(int64 times) const;
+ Literal Replicate(int64 times) const;
// Creates a new Literal object with the shape specified as parameter.
// The content of the literal values is the default value of the primitive
@@ -312,9 +305,9 @@ class LiteralBase {
// Note: It's an antipattern to use this method then immediately call
// MutableLiteralBase::Populate on the result (since that results in zero
  // initialization, then reinitialization). Consider if a call to
- // MakeUnique<Literal>(shape), followed by the call to
+ // absl::make_unique<Literal>(shape), followed by the call to
// MutableLiteralBase::Populate can be used instead.
- static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
+ static Literal CreateFromShape(const Shape& shape);
protected:
// A data structure representing a subshape at a particular ShapeIndex within
@@ -325,9 +318,9 @@ class LiteralBase {
// Returns the buffer holding the array data for this piece as an array
// slice. This piece must be array-shaped.
template <typename NativeT>
- tensorflow::gtl::ArraySlice<NativeT> data() const;
+ absl::Span<const NativeT> data() const;
template <typename NativeT>
- tensorflow::gtl::MutableArraySlice<NativeT> data();
+ absl::Span<NativeT> data();
// Returns the buffer holding the array data for this piece as a void*. This
// piece must be array-shaped.
@@ -338,9 +331,9 @@ class LiteralBase {
// is CHECKed against the dimension sizes of the array. This piece must be
// array-shaped.
template <typename NativeT>
- NativeT Get(tensorflow::gtl::ArraySlice<int64> index) const;
+ NativeT Get(absl::Span<const int64> index) const;
template <typename NativeT>
- void Set(tensorflow::gtl::ArraySlice<int64> index, NativeT value);
+ void Set(absl::Span<const int64> index, NativeT value);
// Gets/sets the buffer holding the array data.
char* buffer() const { return buffer_; }
@@ -541,9 +534,8 @@ class LiteralBase {
private:
template <typename NativeT>
- std::unique_ptr<Literal> SliceInternal(
- const Shape& result_shape,
- tensorflow::gtl::ArraySlice<int64> start_indices) const;
+ Literal SliceInternal(const Shape& result_shape,
+ absl::Span<const int64> start_indices) const;
};
// Abstract base class representing a mutable literal in XLA.
@@ -551,13 +543,12 @@ class MutableLiteralBase : public LiteralBase {
public:
virtual ~MutableLiteralBase() = 0;
- // Returns a MutableArraySlice view of the array for this literal for the
+ // Returns a Span view of the array for this literal for the
// given NativeT (e.g., float). CHECKs if the subshape of the literal at the
  // given ShapeIndex is not an array. See primitive_util.h for the mapping from
// XLA type to native type.
template <typename NativeT>
- tensorflow::gtl::MutableArraySlice<NativeT> data(
- const ShapeIndex& shape_index = {});
+ absl::Span<NativeT> data(const ShapeIndex& shape_index = {});
// Unhide const method from parent class.
using LiteralBase::data;
@@ -584,8 +575,7 @@ class MutableLiteralBase : public LiteralBase {
// are populated.
template <typename NativeT>
void PopulateSparse(SparseIndexArray indices,
- tensorflow::gtl::ArraySlice<NativeT> values,
- bool sort = true);
+ absl::Span<const NativeT> values, bool sort = true);
// Copy values from 'src_literal' rooted at 'src_shape_index' into this
// literal rooted at 'dest_shape_index'. The subshape of this literal rooted
@@ -606,39 +596,38 @@ class MutableLiteralBase : public LiteralBase {
// corresponding base indices being 0.
// This literal and 'src_literal' must be arrays.
Status CopySliceFrom(const LiteralSlice& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size);
+ absl::Span<const int64> src_base,
+ absl::Span<const int64> dest_base,
+ absl::Span<const int64> copy_size);
// Copies one element from src_literal[src_index] to (*this)[dest_index].
Status CopyElementFrom(const LiteralSlice& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_index,
- tensorflow::gtl::ArraySlice<int64> dest_index);
+ absl::Span<const int64> src_index,
+ absl::Span<const int64> dest_index);
// Sets an element in the literal at the given index. The multi_index is
// CHECKed against the dimension sizes.
template <typename NativeT>
- void Set(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index, NativeT value);
+ void Set(absl::Span<const int64> multi_index, const ShapeIndex& shape_index,
+ NativeT value);
// Overloads of Set for array literals. CHECKs if the literal is not
// array-shaped and dense.
template <typename NativeT>
- void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
+ void Set(absl::Span<const int64> multi_index, NativeT value);
// Appends the given element to the literal. If the elements are not appended
// in sorted order, then SortSparseElements should be called before calling
// other methods. This literal must have a sparse layout.
template <typename NativeT>
- void AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index,
- NativeT value, const ShapeIndex& shape_index = {});
+ void AppendSparseElement(absl::Span<const int64> multi_index, NativeT value,
+ const ShapeIndex& shape_index = {});
// Sorts the elements in a sparse array.
void SortSparseElements(const ShapeIndex& shape_index = {});
// As Set(), but truncates `value` to the literal element type before storing.
// This literal must be an array.
- Status SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
- int64 value);
+ Status SetIntegralAsS64(absl::Span<const int64> multi_index, int64 value);
// Populate this literal with the given values. Examples:
//
@@ -653,7 +642,7 @@ class MutableLiteralBase : public LiteralBase {
// example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
// array of S32.
template <typename NativeT>
- void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values);
+ void PopulateR1(absl::Span<const NativeT> values);
void PopulateR1(const tensorflow::core::Bitmap& values);
template <typename NativeT>
void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
@@ -670,7 +659,7 @@ class MutableLiteralBase : public LiteralBase {
// in this literal object.
//
// generator must be a callable of the type
- // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
+ // NativeT(absl::Span<const int64> indexes) or compatible.
//
// This literal must have a dense layout.
template <typename NativeT, typename FnType>
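For intuition, a sketch of what a caller of Populate looks like under the new signature, assuming the surrounding XLA types (Literal, ShapeUtil, int64) are in scope; only the lambda's parameter type is the point being illustrated:

  // Fill a 3x4 S64 literal so that element (i, j) = i * 10 + j.
  Literal literal(ShapeUtil::MakeShape(S64, {3, 4}));
  TF_CHECK_OK(literal.Populate<int64>(
      [](absl::Span<const int64> indexes) -> int64 {
        return indexes[0] * 10 + indexes[1];
      }));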
@@ -690,12 +679,10 @@ class MutableLiteralBase : public LiteralBase {
// moved into the tuple elements of a new tuple-shaped Literal which is
// returned. Upon return, each of the Literals in 'elements' is set to a nil
// shape (empty tuple).
- static Literal MoveIntoTuple(
- tensorflow::gtl::MutableArraySlice<Literal> elements);
+ static Literal MoveIntoTuple(absl::Span<Literal> elements);
// Serialize from a proto.
- static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
- const LiteralProto& proto);
+ static StatusOr<Literal> CreateFromProto(const LiteralProto& proto);
protected:
// Returns the piece at the given ShapeIndex.
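The CreateFromProto change is representative of the unique_ptr-to-value migration in this commit: Literal is a move-only value type, so call sites drop a level of indirection. Roughly (a sketch, error handling elided):

  // Before: StatusOr<std::unique_ptr<Literal>>
  //   std::unique_ptr<Literal> lit =
  //       Literal::CreateFromProto(proto).ConsumeValueOrDie();
  // After: StatusOr<Literal>
  Literal lit = Literal::CreateFromProto(proto).ConsumeValueOrDie();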
@@ -709,20 +696,20 @@ class MutableLiteralBase : public LiteralBase {
// arguments one by one.
template <typename NativeT>
Status CopySliceFromInternal(const LiteralBase& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size);
+ absl::Span<const int64> src_base,
+ absl::Span<const int64> dest_base,
+ absl::Span<const int64> copy_size);
// Utility structure which is used to create the optimal configuration for
// a ShapeUtil::ForEachIndex() scan across two literals.
struct StrideConfig {
StrideConfig(const Shape& source_shape, const Shape& dest_shape,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ absl::Span<const int64> dimensions);
// The dimensions of the stride operation. Essentially every dimension
// will be iterated from base[i] to base[i]+dimensions[i], in step[i]
// steps.
- tensorflow::gtl::ArraySlice<int64> dimensions;
+ absl::Span<const int64> dimensions;
DimensionVector base;
DimensionVector step;
int64 minor_dimension = 0;
@@ -851,7 +838,7 @@ class BorrowingLiteral : public LiteralBase {
// This constructor is only used for array shapes.
BorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
// Similar as above, except to be used for constructing non-nested tuples.
- BorrowingLiteral(tensorflow::gtl::ArraySlice<const char*> src_buf_ptrs,
+ BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
const Shape& shape);
// TODO(b/79707221): adding constructors for nested tuples as well.
@@ -871,7 +858,7 @@ class BorrowingLiteral : public LiteralBase {
};
template <typename NativeT>
-tensorflow::gtl::ArraySlice<NativeT> LiteralBase::Piece::data() const {
+absl::Span<const NativeT> LiteralBase::Piece::data() const {
CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
CHECK_EQ(subshape().element_type(),
primitive_util::NativeToPrimitiveType<NativeT>())
@@ -879,12 +866,12 @@ tensorflow::gtl::ArraySlice<NativeT> LiteralBase::Piece::data() const {
<< PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
<< " type, but literal element type is "
<< PrimitiveType_Name(subshape().element_type());
- return tensorflow::gtl::ArraySlice<NativeT>(
- reinterpret_cast<const NativeT*>(buffer()), element_count());
+ return absl::Span<const NativeT>(reinterpret_cast<const NativeT*>(buffer()),
+ element_count());
}
template <typename NativeT>
-tensorflow::gtl::MutableArraySlice<NativeT> LiteralBase::Piece::data() {
+absl::Span<NativeT> LiteralBase::Piece::data() {
CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
CHECK_EQ(subshape().element_type(),
primitive_util::NativeToPrimitiveType<NativeT>())
@@ -892,20 +879,19 @@ tensorflow::gtl::MutableArraySlice<NativeT> LiteralBase::Piece::data() {
<< PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
<< " type, but literal element type is "
<< PrimitiveType_Name(subshape().element_type());
- return tensorflow::gtl::MutableArraySlice<NativeT>(
- reinterpret_cast<NativeT*>(buffer()), element_count());
+ return absl::Span<NativeT>(reinterpret_cast<NativeT*>(buffer()),
+ element_count());
}
template <typename NativeT>
-NativeT LiteralBase::Piece::Get(
- tensorflow::gtl::ArraySlice<int64> multi_index) const {
+NativeT LiteralBase::Piece::Get(absl::Span<const int64> multi_index) const {
CHECK(LayoutUtil::IsDenseArray(subshape()));
return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
subshape(), multi_index)];
}
template <typename NativeT>
-void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+void LiteralBase::Piece::Set(absl::Span<const int64> multi_index,
NativeT value) {
CHECK(LayoutUtil::IsDenseArray(subshape()));
data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
@@ -913,39 +899,37 @@ void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
}
template <typename NativeT>
-tensorflow::gtl::ArraySlice<NativeT> LiteralBase::data(
+absl::Span<const NativeT> LiteralBase::data(
const ShapeIndex& shape_index) const {
return piece(shape_index).data<NativeT>();
}
template <typename NativeT>
-tensorflow::gtl::MutableArraySlice<NativeT> MutableLiteralBase::data(
- const ShapeIndex& shape_index) {
+absl::Span<NativeT> MutableLiteralBase::data(const ShapeIndex& shape_index) {
return piece(shape_index).data<NativeT>();
}
template <typename NativeT>
-inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice<int64> multi_index,
+inline NativeT LiteralBase::Get(absl::Span<const int64> multi_index,
const ShapeIndex& shape_index) const {
return piece(shape_index).Get<NativeT>(multi_index);
}
template <typename NativeT>
-inline NativeT LiteralBase::Get(
- tensorflow::gtl::ArraySlice<int64> multi_index) const {
+inline NativeT LiteralBase::Get(absl::Span<const int64> multi_index) const {
return root_piece().Get<NativeT>(multi_index);
}
template <typename NativeT>
-inline void MutableLiteralBase::Set(
- tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index, NativeT value) {
+inline void MutableLiteralBase::Set(absl::Span<const int64> multi_index,
+ const ShapeIndex& shape_index,
+ NativeT value) {
return piece(shape_index).Set<NativeT>(multi_index, value);
}
template <typename NativeT>
-inline void MutableLiteralBase::Set(
- tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value) {
+inline void MutableLiteralBase::Set(absl::Span<const int64> multi_index,
+ NativeT value) {
return root_piece().Set<NativeT>(multi_index, value);
}
@@ -964,7 +948,7 @@ NativeT LiteralBase::GetSparseElement(int64 sparse_element_number,
template <typename NativeT>
void MutableLiteralBase::AppendSparseElement(
- tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value,
+ absl::Span<const int64> multi_index, NativeT value,
const ShapeIndex& shape_index) {
Piece& p = piece(shape_index);
const Shape& subshape = p.subshape();
@@ -980,8 +964,7 @@ void MutableLiteralBase::AppendSparseElement(
template <typename NativeT>
void LiteralBase::EachCell(
- std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
- NativeT value)>
+ std::function<void(absl::Span<const int64> indices, NativeT value)>
per_cell) const {
if (ShapeUtil::IsZeroElementArray(shape())) {
return;
@@ -989,12 +972,11 @@ void LiteralBase::EachCell(
std::vector<int64> indices(ShapeUtil::Rank(shape()), 0);
do {
per_cell(indices, Get<NativeT>(indices));
- } while (IndexUtil::BumpIndices(shape(), &indices));
+ } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices)));
}
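BumpIndices now takes an explicit absl::MakeSpan(indices) because, unlike Span<const T>, a mutable absl::Span<T> is deliberately not an implicit conversion target for containers; absl::MakeSpan is the idiomatic way to opt in to a writable view. A standalone sketch:

  #include <cstdint>
  #include <vector>
  #include "absl/types/span.h"

  void ZeroOut(absl::Span<int64_t> v) {  // mutable view
    for (int64_t& e : v) e = 0;
  }

  void Example() {
    std::vector<int64_t> indices(3, 7);
    ZeroOut(absl::MakeSpan(indices));  // explicit opt-in to mutation
  }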
template <typename NativeT>
-inline void MutableLiteralBase::PopulateR1(
- tensorflow::gtl::ArraySlice<NativeT> values) {
+inline void MutableLiteralBase::PopulateR1(absl::Span<const NativeT> values) {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(ShapeUtil::Rank(shape()), 1);
CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
@@ -1039,8 +1021,9 @@ void MutableLiteralBase::PopulateFromArray(const Array<NativeT>& values) {
for (int dim = 0; dim < values.num_dimensions(); ++dim) {
CHECK_EQ(values.dim(dim), shape().dimensions(dim));
}
- values.Each([this](tensorflow::gtl::ArraySlice<int64> indices,
- NativeT value) { this->Set(indices, value); });
+ values.Each([this](absl::Span<const int64> indices, NativeT value) {
+ this->Set(indices, value);
+ });
}
template <typename NativeT>
@@ -1059,9 +1042,9 @@ void MutableLiteralBase::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
}
template <typename NativeT>
-void MutableLiteralBase::PopulateSparse(
- SparseIndexArray indices, tensorflow::gtl::ArraySlice<NativeT> values,
- bool sort) {
+void MutableLiteralBase::PopulateSparse(SparseIndexArray indices,
+ absl::Span<const NativeT> values,
+ bool sort) {
CHECK(LayoutUtil::IsSparseArray(shape()));
int rank = ShapeUtil::Rank(shape());
CHECK_EQ(indices.rank(), rank);
@@ -1071,7 +1054,7 @@ void MutableLiteralBase::PopulateSparse(
CHECK_LE(num_elements, max_elements);
CHECK_EQ(num_elements, indices.index_count());
auto root_data = root_piece().data<NativeT>();
- // Piece::data() returns an ArraySlice of size equal to the number of indices
+ // Piece::data() returns a Span of size equal to the number of indices
// in the SparseIndexArray. So there is no need to adjust the size of the data
// here. It is enough to just copy the incoming values into the data buffer.
std::copy(values.begin(), values.end(), root_data.begin());
@@ -1091,14 +1074,14 @@ Status MutableLiteralBase::PopulateInternal(const FnType& generator,
TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
TF_RET_CHECK(this_shape.element_type() ==
primitive_util::NativeToPrimitiveType<NativeT>());
- tensorflow::gtl::MutableArraySlice<NativeT> literal_data = data<NativeT>();
+ absl::Span<NativeT> literal_data = data<NativeT>();
if (rank > 0) {
StrideConfig stride_config(this_shape, this_shape,
AsInt64Slice(this_shape.dimensions()));
int64 minor_dimension_size =
ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
- auto init_function = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
+ auto init_function = [&](absl::Span<const int64> indexes) {
DimensionVector minor_scan_indexes(rank, 0);
const int64 index =
IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
@@ -1116,7 +1099,7 @@ Status MutableLiteralBase::PopulateInternal(const FnType& generator,
ShapeUtil::ForEachIndex(
this_shape, stride_config.base, stride_config.dimensions,
stride_config.step,
- [&init_function](tensorflow::gtl::ArraySlice<int64> indexes) {
+ [&init_function](absl::Span<const int64> indexes) {
init_function(indexes);
return true;
});
@@ -1148,27 +1131,26 @@ void MutableLiteralBase::PopulateWithValue(NativeT value) {
}
template <typename NativeT>
-std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
+Literal LiteralBase::Replicate(int64 times) const {
DimensionVector bounds = {times};
bounds.reserve(shape().dimensions_size() + 1);
for (int64 bound : shape().dimensions()) {
bounds.push_back(bound);
}
- auto literal =
- MakeUnique<Literal>(ShapeUtil::MakeShape(shape().element_type(), bounds));
- int64 elements = ShapeUtil::ElementsIn(literal->shape());
+ Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds));
+ int64 elements = ShapeUtil::ElementsIn(literal.shape());
if (elements == 0) {
return literal;
}
DimensionVector output_indices(bounds.size(), 0);
- tensorflow::gtl::ArraySlice<int64> input_indices = output_indices;
+ absl::Span<const int64> input_indices = output_indices;
input_indices.remove_prefix(1);
bool done = false;
while (!done) {
const auto element = Get<NativeT>(input_indices);
- literal->Set<NativeT>(output_indices, element);
+ literal.Set<NativeT>(output_indices, element);
done = true;
for (int n = 0; n < output_indices.size(); ++n) {
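Two details in the Replicate rewrite are easy to miss: the result is now returned by value (Literal is movable, so the MakeUnique allocation disappears), and input_indices is a Span that views output_indices with the leading replication dimension stripped, so each bump of output_indices is observed through input_indices at no cost. A sketch of that aliasing (illustrative only):

  std::vector<int64_t> output_indices = {0, 4, 2};
  absl::Span<const int64_t> input_indices = output_indices;
  input_indices.remove_prefix(1);  // input_indices now views {4, 2}
  output_indices[1] = 9;           // input_indices[0] now reads 9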
diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc
index 94993cc874..3d8725ed70 100644
--- a/tensorflow/compiler/xla/literal_comparison.cc
+++ b/tensorflow/compiler/xla/literal_comparison.cc
@@ -19,16 +19,16 @@ limitations under the License.
#include <cmath>
#include <vector>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/casts.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
-using tensorflow::strings::Appendf;
-using tensorflow::strings::Printf;
-using tensorflow::strings::StrAppend;
-using tensorflow::strings::StrCat;
+using absl::StrAppend;
+using absl::StrAppendFormat;
+using absl::StrCat;
namespace xla {
namespace literal_comparison {
@@ -38,7 +38,8 @@ namespace {
// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
// -- on miscompare, a nice error message is given in the AssertionFailure.
template <typename FloatT, typename UnsignedT>
-Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
+Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs,
+ absl::Span<const int64> multi_index) {
auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
auto lhs_double = static_cast<double>(lhs);
@@ -46,9 +47,10 @@ Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
if (ulhs != urhs) {
return InvalidArgument(
"floating values are not bitwise-equal; and equality testing "
- "was requested: %s=%g=%a vs %s=%g=%a",
- StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double, lhs_double,
- StrCat(tensorflow::strings::Hex(urhs)).c_str(), rhs_double, rhs_double);
+ "was requested: %s=%g=%a vs %s=%g=%a at array index %s",
+ StrCat(absl::Hex(ulhs)), lhs_double, lhs_double,
+ StrCat(absl::Hex(urhs)), rhs_double, rhs_double,
+ LiteralUtil::MultiIndexAsString(multi_index));
}
return Status::OK();
}
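The bit-cast comparison exists because floating-point operator== would declare +0.0 and -0.0 equal and a NaN unequal to itself; comparing the raw bit patterns through a same-width unsigned type avoids both. A minimal portable illustration (using memcpy where the XLA code uses tensorflow::bit_cast):

  #include <cstdint>
  #include <cstring>

  bool BitwiseEqual(float lhs, float rhs) {
    uint32_t ul, ur;
    std::memcpy(&ul, &lhs, sizeof(ul));  // portable bit reinterpretation
    std::memcpy(&ur, &rhs, sizeof(ur));
    return ul == ur;  // -0.0f != +0.0f here; a NaN equals itself bitwise
  }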
@@ -57,39 +59,47 @@ Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
// bitwise helper above (this is the un-specialized fallback, to just use the
// default gunit implementation).
template <typename NativeT>
-Status CompareEqual(NativeT lhs, NativeT rhs) {
+Status CompareEqual(NativeT lhs, NativeT rhs,
+ absl::Span<const int64> multi_index) {
if (lhs == rhs) {
return Status::OK();
}
- return InvalidArgument("Expected equality of these values:\n %s\n %s",
- StrCat(lhs).c_str(), StrCat(rhs).c_str());
+ return InvalidArgument(
+ "first mismatch at array index %s:\n expected value: %s\n actual "
+ "value: %s",
+ LiteralUtil::MultiIndexAsString(multi_index), StrCat(lhs), StrCat(rhs));
}
// Specializations for floating types that do bitwise comparisons when equality
// comparison is requested.
template <>
-Status CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs) {
- return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs);
+Status CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs,
+ absl::Span<const int64> multi_index) {
+ return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs, multi_index);
}
template <>
-Status CompareEqual<Eigen::half>(Eigen::half lhs, Eigen::half rhs) {
- return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs);
+Status CompareEqual<Eigen::half>(Eigen::half lhs, Eigen::half rhs,
+ absl::Span<const int64> multi_index) {
+ return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs, multi_index);
}
template <>
-Status CompareEqual<float>(float lhs, float rhs) {
- return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs);
+Status CompareEqual<float>(float lhs, float rhs,
+ absl::Span<const int64> multi_index) {
+ return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs, multi_index);
}
template <>
-Status CompareEqual<double>(double lhs, double rhs) {
- return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs);
+Status CompareEqual<double>(double lhs, double rhs,
+ absl::Span<const int64> multi_index) {
+ return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs, multi_index);
}
template <>
-Status CompareEqual<complex64>(complex64 lhs, complex64 rhs) {
- auto res = CompareEqual<float>(lhs.real(), rhs.real());
+Status CompareEqual<complex64>(complex64 lhs, complex64 rhs,
+ absl::Span<const int64> multi_index) {
+ auto res = CompareEqual<float>(lhs.real(), rhs.real(), multi_index);
if (!res.ok()) {
return res;
}
- return CompareEqual<float>(lhs.imag(), rhs.imag());
+ return CompareEqual<float>(lhs.imag(), rhs.imag(), multi_index);
}
// A recursive function which iterates through every index of expected and
@@ -97,18 +107,18 @@ Status CompareEqual<complex64>(complex64 lhs, complex64 rhs) {
// elements are equal.
template <typename NativeT>
Status Equal(LiteralSlice expected, LiteralSlice actual,
- tensorflow::gtl::MutableArraySlice<int64> multi_index,
- int64 dimension) {
+ absl::Span<int64> multi_index, int64 dimension) {
if (dimension == expected.shape().dimensions_size()) {
NativeT expected_value = expected.Get<NativeT>(multi_index);
NativeT actual_value = actual.Get<NativeT>(multi_index);
- return CompareEqual<NativeT>(expected_value, actual_value);
+ return CompareEqual<NativeT>(expected_value, actual_value, multi_index);
}
Status result;
for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
multi_index[dimension] = i;
- result.Update(Equal<NativeT>(expected, actual, multi_index, dimension + 1));
+ TF_RETURN_IF_ERROR(
+ Equal<NativeT>(expected, actual, multi_index, dimension + 1));
}
return result;
}
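Note the small behavioral change in this hunk: result.Update(...) recorded the first error but kept scanning every remaining element, whereas TF_RETURN_IF_ERROR bails out of the recursion at the first mismatching index. The macro's effect is roughly:

  // Approximate expansion of TF_RETURN_IF_ERROR(expr); the real macro
  // is a do/while and may attach extra context to the message.
  {
    ::tensorflow::Status _status =
        Equal<NativeT>(expected, actual, multi_index, dimension + 1);
    if (!_status.ok()) return _status;
  }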
@@ -152,15 +162,26 @@ bool NanMismatch<half>(half expected, half actual, bool relaxed_nans) {
static_cast<float>(actual), relaxed_nans);
}
+// Returns whether the given value is infinity.
+template <typename NativeT>
+bool IsInf(NativeT val) {
+ return std::isinf(val);
+}
+
+template <>
+bool IsInf<half>(half val) {
+ return std::isinf(static_cast<float>(val));
+}
+
// Converts the given floating-point value to a string.
template <typename NativeT>
string FpValueToString(NativeT value) {
- return Printf("%8.4g", static_cast<double>(value));
+ return absl::StrFormat("%8.4g", static_cast<double>(value));
}
template <>
string FpValueToString<complex64>(complex64 value) {
- return Printf("%8.4g + %8.4fi", value.real(), value.imag());
+ return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag());
}
// Returns the absolute value of the given floating point value. This function
@@ -215,13 +236,12 @@ class NearComparator {
}
string ToString(const Shape& shape) const {
- return Printf(
+ return absl::StrFormat(
"actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g",
- FpValueToString(actual).c_str(), FpValueToString(expected).c_str(),
+ FpValueToString(actual), FpValueToString(expected),
LiteralUtil::MultiIndexAsString(
IndexUtil::LinearIndexToMultidimensionalIndex(shape,
- linear_index))
- .c_str(),
+ linear_index)),
rel_error, abs_error);
}
};
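The wholesale removal of .c_str() in this file works because the absl formatting functions are type-safe rather than varargs-based: %s accepts std::string and absl::string_view directly, and %d accepts any integral width, which is also why the platform-sensitive %lld specifiers above could become plain %d. For example (a sketch):

  #include <cstdint>
  #include <string>
  #include "absl/strings/str_format.h"

  std::string Describe(const std::string& name, int64_t count) {
    // No .c_str() and no %lld needed; arguments are checked by type.
    return absl::StrFormat("%s has %d elements", name, count);
  }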
@@ -240,17 +260,12 @@ class NearComparator {
// Runs the comparison between expected and actual literals.
Status Run() {
- VLOG(1) << "expected:";
- XLA_VLOG_LINES(1, ToStringTruncated(expected_));
- VLOG(1) << "actual:";
- XLA_VLOG_LINES(1, ToStringTruncated(actual_));
-
// If the shapes mismatch, we simply fail the expectation instead of
// printing out data, as it's a type error rather than a value error.
TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape()));
if (!ShapeUtil::IsArray(expected_.shape())) {
return InvalidArgument("Expected array shape; got %s.",
- ShapeUtil::HumanString(expected_.shape()).c_str());
+ ShapeUtil::HumanString(expected_.shape()));
}
mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED));
@@ -263,7 +278,7 @@ class NearComparator {
} else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) {
miscompare_callback_(expected_, actual_, mismatches_);
}
- return InvalidArgument("%s", ErrorMessage().c_str());
+ return InvalidArgument("%s", ErrorMessage());
}
// Insert the given absolute value into the absolute value bucket vector. The
@@ -288,8 +303,7 @@ class NearComparator {
}
// Insert the given error into the given error bucket vector.
- void UpdateErrorBucket(
- float error, tensorflow::gtl::MutableArraySlice<int64> error_buckets) {
+ void UpdateErrorBucket(float error, absl::Span<int64> error_buckets) {
CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size());
for (int i = 0; i < error_buckets.size(); ++i) {
if (error >= kErrorBucketBounds[i]) {
@@ -300,12 +314,13 @@ class NearComparator {
// Compares the two given elements from the expected and actual literals at
// the given literal_index and keeps track of various mismatch statistics.
- void CompareValues(NativeT expected, NativeT actual, int64 linear_index) {
+ template <typename T>
+ void CompareValues(T expected, T actual, int64 linear_index) {
const bool is_nan_mismatch =
NanMismatch(expected, actual, error_.relaxed_nans);
float abs_error;
float rel_error;
- if (actual == expected) {
+ if (CompareEqual<T>(expected, actual, {linear_index}).ok()) {
abs_error = 0;
rel_error = 0;
} else if (is_nan_mismatch) {
@@ -316,6 +331,12 @@ class NearComparator {
// weak ordering requirement of std containers.
abs_error = std::numeric_limits<float>::infinity();
rel_error = std::numeric_limits<float>::infinity();
+ } else if (IsInf(expected) || IsInf(actual)) {
+ // If either the expected or actual value is infinity but not both,
+ // then both absolute and relative error are regarded as infinity.
+ CHECK(!CompareEqual(expected, actual, {linear_index}).ok());
+ abs_error = std::numeric_limits<float>::infinity();
+ rel_error = std::numeric_limits<float>::infinity();
} else {
abs_error = FpAbsoluteValue(actual - expected);
rel_error = abs_error / FpAbsoluteValue(expected);
@@ -329,11 +350,11 @@ class NearComparator {
// bound is exceeded and vice versa.
if (is_abs_mismatch) {
num_abs_mismatches_++;
- UpdateErrorBucket(rel_error, &rel_error_buckets_);
+ UpdateErrorBucket(rel_error, absl::MakeSpan(rel_error_buckets_));
}
if (is_rel_mismatch) {
num_rel_mismatches_++;
- UpdateErrorBucket(abs_error, &abs_error_buckets_);
+ UpdateErrorBucket(abs_error, absl::MakeSpan(abs_error_buckets_));
}
UpdateAbsValueBucket(actual, is_mismatch);
@@ -358,15 +379,36 @@ class NearComparator {
mismatches_.data<bool>()[linear_index] = true;
}
+ // For complex64 types, we compare real and imaginary parts individually.
+ void CompareValues(complex64 expected, complex64 actual, int64 linear_index) {
+ bool mismatch = false;
+ CompareValues<float>(expected.real(), actual.real(), linear_index);
+ if (mismatches_.data<bool>()[linear_index] == true) {
+ mismatch = true;
+ // Defer the mismatch count increase for the real part; the complex
+ // number as a whole is counted as a single mismatch below.
+ num_mismatches_--;
+ }
+ CompareValues<float>(expected.imag(), actual.imag(), linear_index);
+ if (mismatches_.data<bool>()[linear_index] == true) {
+ mismatch = true;
+ // Likewise defer the count for the imaginary part; real and imaginary
+ // mismatches together count as one mismatch for the whole number.
+ num_mismatches_--;
+ }
+ if (mismatch == true) {
+ num_mismatches_++;
+ }
+ mismatches_.data<bool>()[linear_index] = mismatch;
+ }
+
// Compares the two literals elementwise.
void CompareLiterals() {
// Fast path optimization for the case where layouts match.
if (LayoutUtil::Equal(actual_.shape().layout(),
expected_.shape().layout())) {
- tensorflow::gtl::ArraySlice<const NativeT> expected_data =
- expected_.data<NativeT>();
- tensorflow::gtl::ArraySlice<const NativeT> actual_data =
- actual_.data<NativeT>();
+ absl::Span<const NativeT> expected_data = expected_.data<NativeT>();
+ absl::Span<const NativeT> actual_data = actual_.data<NativeT>();
const int64 len = expected_data.size();
for (int64 i = 0; i < len; ++i) {
CompareValues(expected_data[i], actual_data[i], i);
@@ -402,23 +444,23 @@ class NearComparator {
auto percent_string = [](float a, float b) {
float pct = b == 0.0 ? 0.0 : 100.0 * a / b;
- return Printf("%0.4f%%", pct);
+ return absl::StrFormat("%0.4f%%", pct);
};
- Appendf(&out,
- "\nMismatch count %lld (%s) in shape %s (%lld elements), abs bound "
- "%g, rel bound %g\n",
- num_mismatches_,
- percent_string(num_mismatches_, element_count).c_str(),
- ShapeUtil::HumanString(actual_.shape()).c_str(),
- ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel);
+ StrAppendFormat(
+ &out,
+ "\nMismatch count %d (%s) in shape %s (%d elements), abs bound "
+ "%g, rel bound %g\n",
+ num_mismatches_, percent_string(num_mismatches_, element_count),
+ ShapeUtil::HumanString(actual_.shape()),
+ ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel);
if (num_nan_mismatches_ > 0) {
StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n");
}
- Appendf(&out, "Top relative error mismatches:\n");
+ StrAppendFormat(&out, "Top relative error mismatches:\n");
for (auto it = top_rel_mismatches_.rbegin();
it != top_rel_mismatches_.rend(); ++it) {
- StrAppend(&out, " ", it->ToString(actual_.shape()).c_str(), "\n");
+ StrAppend(&out, " ", it->ToString(actual_.shape()), "\n");
}
if (!detailed_message_) {
@@ -430,36 +472,37 @@ class NearComparator {
for (int i = 0; i < abs_value_buckets_.size(); ++i) {
const int64 bucket_size = abs_value_buckets_[i].first;
const int64 bucket_mismatches = abs_value_buckets_[i].second;
- string mismatch_str = bucket_mismatches > 0
- ? Printf(", mismatches %lld", bucket_mismatches)
- : "";
- Appendf(&out, " %-6g <= x < %-6g : %7lld (%9s)%s\n",
- kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1],
- bucket_size, percent_string(bucket_size, element_count).c_str(),
- mismatch_str.c_str());
+ string mismatch_str =
+ bucket_mismatches > 0
+ ? absl::StrFormat(", mismatches %d", bucket_mismatches)
+ : "";
+ StrAppendFormat(&out, " %-6g <= x < %-6g : %7d (%9s)%s\n",
+ kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1],
+ bucket_size, percent_string(bucket_size, element_count),
+ mismatch_str);
}
auto print_accum_buckets = [&](const string& header, int64 total,
- tensorflow::gtl::ArraySlice<int64> buckets) {
+ absl::Span<const int64> buckets) {
StrAppend(&out, header, ":\n");
- Appendf(&out, " < %-6g : %7lld (%s)\n", kErrorBucketBounds[0],
- total - buckets[0],
- percent_string(total - buckets[0], total).c_str());
+ StrAppendFormat(&out, " < %-6g : %7d (%s)\n", kErrorBucketBounds[0],
+ total - buckets[0],
+ percent_string(total - buckets[0], total));
CHECK_EQ(buckets.size(), kErrorBucketBounds.size());
for (int i = 0; i < kErrorBucketBounds.size(); ++i) {
- Appendf(&out, " >= %-6g : %7lld (%s)\n", kErrorBucketBounds[i],
- buckets[i], percent_string(buckets[i], total).c_str());
+ StrAppendFormat(&out, " >= %-6g : %7d (%s)\n", kErrorBucketBounds[i],
+ buckets[i], percent_string(buckets[i], total));
}
};
- Appendf(&out, "Elements exceeding abs error bound %g: %lld (%s)\n",
- error_.abs, num_abs_mismatches_,
- percent_string(num_abs_mismatches_, element_count).c_str());
+ StrAppendFormat(&out, "Elements exceeding abs error bound %g: %d (%s)\n",
+ error_.abs, num_abs_mismatches_,
+ percent_string(num_abs_mismatches_, element_count));
print_accum_buckets(
"Relative error breakdown of elements exceeding abs error bound",
num_abs_mismatches_, rel_error_buckets_);
- Appendf(&out, "Elements exceeding rel error bound %g: %lld (%s)\n",
- error_.rel, num_rel_mismatches_,
- percent_string(num_rel_mismatches_, element_count).c_str());
+ StrAppendFormat(&out, "Elements exceeding rel error bound %g: %d (%s)\n",
+ error_.rel, num_rel_mismatches_,
+ percent_string(num_rel_mismatches_, element_count));
print_accum_buckets(
"Absolute error breakdown of elements exceeding rel error bound",
num_rel_mismatches_, abs_error_buckets_);
@@ -528,6 +571,63 @@ constexpr std::array<float, 7> NearComparator<NativeT>::kAbsValueBucketBounds;
template <typename NativeT>
constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
+Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) {
+ TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
+ std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
+ auto index = absl::MakeSpan(multi_index);
+ Status result;
+ switch (expected.shape().element_type()) {
+ case PRED:
+ result = Equal<bool>(expected, actual, index, 0);
+ break;
+ case U8:
+ result = Equal<uint8>(expected, actual, index, 0);
+ break;
+ case S32:
+ result = Equal<int32>(expected, actual, index, 0);
+ break;
+ case S64:
+ result = Equal<int64>(expected, actual, index, 0);
+ break;
+ case U32:
+ result = Equal<uint32>(expected, actual, index, 0);
+ break;
+ case U64:
+ result = Equal<uint64>(expected, actual, index, 0);
+ break;
+ case BF16:
+ result = Equal<bfloat16>(expected, actual, index, 0);
+ break;
+ case F16:
+ result = Equal<half>(expected, actual, index, 0);
+ break;
+ case F32:
+ result = Equal<float>(expected, actual, index, 0);
+ break;
+ case F64:
+ result = Equal<double>(expected, actual, index, 0);
+ break;
+ case C64:
+ result = Equal<complex64>(expected, actual, index, 0);
+ break;
+ case TUPLE: {
+ for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
+ result.Update(EqualHelper(LiteralSlice(expected, {i}),
+ LiteralSlice(actual, {i})));
+ }
+ break;
+ }
+ case TOKEN:
+ // Tokens have no on-device representation and are trivially equal.
+ return Status::OK();
+ default:
+ LOG(FATAL) << "Unsupported primitive type: "
+ << PrimitiveType_Name(expected.shape().element_type());
+ }
+
+ return result;
+}
+
// Helper function for comparing two literals for nearness. Handles tuple-shapes
// via recursion. shape_index is the ShapeIndex of expected (or actual)
// currently being compared.
@@ -544,17 +644,18 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
const auto actual_element = LiteralSlice(actual, {i});
ShapeIndex element_index = shape_index;
element_index.push_back(i);
- Status res =
+ Status element_result =
NearHelper(expected_element, actual_element, error, detailed_message,
miscompare_callback, element_index);
- if (!res.ok()) {
- string err_message = Printf("\nArray at shape index %s%s",
- element_index.ToString().c_str(),
- res.error_message().c_str());
+ if (!element_result.ok()) {
+ element_result = InvalidArgument("Array at shape index %s, %s",
+ element_index.ToString(),
+ element_result.error_message());
if (return_status.ok()) {
- return_status = res;
+ return_status = element_result;
} else {
- return_status = AppendStatus(return_status, res.error_message());
+ return_status =
+ AppendStatus(return_status, element_result.error_message());
}
}
}
@@ -562,10 +663,10 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
// Emit a top-level error message containing the top-level shape in case
// of mismatch.
int64 total_elements = RecursiveElementCount(actual.shape());
- return_status = InvalidArgument(
- "\nMismatches in shape %s (%lld elements):\n%s",
- ShapeUtil::HumanString(actual.shape()).c_str(), total_elements,
- return_status.error_message().c_str());
+ return_status =
+ InvalidArgument("\nMismatches in shape %s (%d elements):\n%s",
+ ShapeUtil::HumanString(actual.shape()),
+ total_elements, return_status.error_message());
}
return return_status;
}
@@ -600,8 +701,8 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
}
}
- // Non-floating point literal.
- return literal_comparison::Equal(expected, actual);
+ // Non-floating point, non-tuple literal.
+ return EqualHelper(expected, actual);
}
} // namespace
@@ -609,14 +710,14 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
Status EqualShapes(const Shape& expected, const Shape& actual) {
if (expected.element_type() != actual.element_type()) {
return InvalidArgument("element type mismatch, want: %s got %s",
- ShapeUtil::HumanString(expected).c_str(),
- ShapeUtil::HumanString(actual).c_str());
+ ShapeUtil::HumanString(expected),
+ ShapeUtil::HumanString(actual));
}
if (ShapeUtil::IsTuple(expected)) {
if (ShapeUtil::TupleElementCount(expected) !=
ShapeUtil::TupleElementCount(actual)) {
return InvalidArgument(
- "want tuple element count: %lld got tuple element count: %lld",
+ "want tuple element count: %d got tuple element count: %d",
ShapeUtil::TupleElementCount(expected),
ShapeUtil::TupleElementCount(actual));
}
@@ -630,14 +731,13 @@ Status EqualShapes(const Shape& expected, const Shape& actual) {
} else if (ShapeUtil::IsArray(expected)) {
if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) {
return InvalidArgument("want rank of %s got rank of %s",
- ShapeUtil::HumanString(expected).c_str(),
- ShapeUtil::HumanString(actual).c_str());
+ ShapeUtil::HumanString(expected),
+ ShapeUtil::HumanString(actual));
}
if (expected.element_type() != actual.element_type()) {
- return InvalidArgument(
- "mismatch in primitive type %s vs %s",
- PrimitiveType_Name(expected.element_type()).c_str(),
- PrimitiveType_Name(actual.element_type()).c_str());
+ return InvalidArgument("mismatch in primitive type %s vs %s",
+ PrimitiveType_Name(expected.element_type()),
+ PrimitiveType_Name(actual.element_type()));
}
if (expected.dimensions_size() != actual.dimensions_size()) {
return InvalidArgument("want dimensions_size %d got dimensions_size %d",
@@ -648,8 +748,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) {
if (expected.dimensions(i) != actual.dimensions(i)) {
return InvalidArgument(
"mismatch in dimension #%d expected: %s actual: %s", i,
- ShapeUtil::HumanString(expected).c_str(),
- ShapeUtil::HumanString(actual).c_str());
+ ShapeUtil::HumanString(expected), ShapeUtil::HumanString(actual));
}
}
}
@@ -657,83 +756,43 @@ Status EqualShapes(const Shape& expected, const Shape& actual) {
return Status::OK();
}
+namespace {
+
+// If result is an error, extend the error message with the expected and actual
+// literals.
+Status EmitLiteralsInErrorMessage(const Status& result,
+ const LiteralSlice& expected,
+ const LiteralSlice& actual) {
+ if (result.ok()) {
+ return result;
+ }
+ return InvalidArgument("%s\n\nExpected literal:\n%s\n\nActual literal:\n%s",
+ result.error_message(), ToStringTruncated(expected),
+ ToStringTruncated(actual));
+}
+
+} // namespace
+
Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
VLOG(1) << "expected:";
XLA_VLOG_LINES(1, expected.ToString());
VLOG(1) << "actual:";
XLA_VLOG_LINES(1, actual.ToString());
-
- TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
- std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
- Status result;
- switch (expected.shape().element_type()) {
- case PRED:
- result = Equal<bool>(expected, actual, &multi_index, 0);
- break;
- case U8:
- result = Equal<uint8>(expected, actual, &multi_index, 0);
- break;
- case S32:
- result = Equal<int32>(expected, actual, &multi_index, 0);
- break;
- case S64:
- result = Equal<int64>(expected, actual, &multi_index, 0);
- break;
- case U32:
- result = Equal<uint32>(expected, actual, &multi_index, 0);
- break;
- case U64:
- result = Equal<uint64>(expected, actual, &multi_index, 0);
- break;
- case BF16:
- result = Equal<bfloat16>(expected, actual, &multi_index, 0);
- break;
- case F16:
- result = Equal<half>(expected, actual, &multi_index, 0);
- break;
- case F32:
- result = Equal<float>(expected, actual, &multi_index, 0);
- break;
- case F64:
- result = Equal<double>(expected, actual, &multi_index, 0);
- break;
- case C64:
- result = Equal<complex64>(expected, actual, &multi_index, 0);
- break;
- case TUPLE: {
- for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
- result.Update(
- Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i})));
- }
- break;
- }
- case TOKEN:
- // Tokens have no on-device representation and are trivially equal.
- return Status::OK();
- default:
- LOG(FATAL)
- << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: "
- << PrimitiveType_Name(expected.shape().element_type());
- }
-
- if (result.ok()) {
- return Status::OK();
- }
-
- return AppendStatus(result,
- tensorflow::strings::Printf(
- "\nat index: %s\nexpected: %s\nactual: %s",
- LiteralUtil::MultiIndexAsString(multi_index).c_str(),
- ToStringTruncated(expected).c_str(),
- ToStringTruncated(actual).c_str()));
+ Status result = EqualHelper(expected, actual);
+ return EmitLiteralsInErrorMessage(result, expected, actual);
}
Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
const ErrorSpec& error, bool detailed_message,
const MiscompareCallback& miscompare_callback) {
- return NearHelper(expected, actual, error, detailed_message,
- miscompare_callback,
- /*shape_index=*/{});
+ VLOG(1) << "Expected literal:";
+ XLA_VLOG_LINES(1, expected.ToString());
+ VLOG(1) << "Actual literal:";
+ XLA_VLOG_LINES(1, actual.ToString());
+ Status result =
+ NearHelper(expected, actual, error, detailed_message, miscompare_callback,
+ /*shape_index=*/{});
+ return EmitLiteralsInErrorMessage(result, expected, actual);
}
string ToStringTruncated(const LiteralSlice& literal) {
diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc
index e8f919950f..ba7fd29a62 100644
--- a/tensorflow/compiler/xla/literal_test.cc
+++ b/tensorflow/compiler/xla/literal_test.cc
@@ -17,6 +17,9 @@ limitations under the License.
#include <vector>
+#include "absl/memory/memory.h"
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -33,7 +36,6 @@ limitations under the License.
namespace xla {
namespace {
-using tensorflow::gtl::ArraySlice;
using ::testing::ElementsAre;
using ::testing::HasSubstr;
@@ -90,48 +92,48 @@ class LiteralUtilTest : public ::testing::Test {
Layout layout_r3_dim0minor_;
Layout layout_r4_dim0major_;
Layout layout_r4_dim0minor_;
- std::unique_ptr<Literal> literal_r4_2x2x3x3_dim0major_;
- std::unique_ptr<Literal> literal_r4_2x2x3x3_dim0minor_;
+ Literal literal_r4_2x2x3x3_dim0major_;
+ Literal literal_r4_2x2x3x3_dim0minor_;
};
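The rest of this test file is the same mechanical migration: the LiteralUtil::Create* factories now return Literal by value instead of std::unique_ptr<Literal>, so -> becomes ., *lit dereferences disappear, and MakeTuple takes plain pointers (&lit) rather than lit.get(). Schematically:

  // Before:
  //   std::unique_ptr<Literal> lit = LiteralUtil::CreateR0<float>(1.0);
  //   EXPECT_EQ("1", lit->ToString());
  // After:
  Literal lit = LiteralUtil::CreateR0<float>(1.0);
  EXPECT_EQ("1", lit.ToString());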
TEST_F(LiteralUtilTest, LiteralScalarToString) {
auto true_lit = LiteralUtil::CreateR0<bool>(true);
- ASSERT_EQ("true", true_lit->ToString());
+ EXPECT_EQ("true", true_lit.ToString());
auto false_lit = LiteralUtil::CreateR0<bool>(false);
- ASSERT_EQ("false", false_lit->ToString());
+ EXPECT_EQ("false", false_lit.ToString());
auto u32_lit = LiteralUtil::CreateR0<uint32>(42);
- ASSERT_EQ("42", u32_lit->ToString());
+ EXPECT_EQ("42", u32_lit.ToString());
auto s32_lit = LiteralUtil::CreateR0<int32>(-999);
- ASSERT_EQ("-999", s32_lit->ToString());
+ EXPECT_EQ("-999", s32_lit.ToString());
auto f32_lit = LiteralUtil::CreateR0<float>(3.14f);
- ASSERT_EQ("3.14", f32_lit->ToString());
+ EXPECT_EQ("3.14", f32_lit.ToString());
auto f16_lit = LiteralUtil::CreateR0<half>(static_cast<half>(0.5f));
- ASSERT_EQ("0.5", f16_lit->ToString());
+ EXPECT_EQ("0.5", f16_lit.ToString());
auto c64_lit = LiteralUtil::CreateR0<complex64>({3.14f, 2.78f});
- ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString());
+ EXPECT_EQ("(3.14, 2.78)", c64_lit.ToString());
auto bf16_lit = LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
- ASSERT_EQ("0.5", bf16_lit->ToString());
+ EXPECT_EQ("0.5", bf16_lit.ToString());
- // 3.14 will be truncated to 3.125 in bfloat16 format.
+ // 3.14 will be rounded to 3.14062 in bfloat16 format.
auto bf16_lit_truncated =
LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
- ASSERT_EQ("3.125", bf16_lit_truncated->ToString());
+ ASSERT_EQ("3.14062", bf16_lit_truncated.ToString());
auto bf16_lit_truncated2 =
LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
- ASSERT_EQ("9", bf16_lit_truncated2->ToString());
+ EXPECT_EQ("9", bf16_lit_truncated2.ToString());
}
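The updated bfloat16 expectations follow from the format's 7 stored fraction bits once the float-to-bfloat16 conversion rounds to nearest instead of truncating; as a worked example:

  3.14f = 1.10010001111... (base 2) x 2^1
  truncated to 7 fraction bits: 1.1001000 x 2^1 = 3.125
  rounded to nearest:           1.1001001 x 2^1 = 3.140625, printed "3.14062"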
TEST_F(LiteralUtilTest, LiteralVectorToString) {
auto pred_vec = LiteralUtil::CreateR1<bool>({true, false, true});
- ASSERT_EQ("{101}", pred_vec->ToString());
+ EXPECT_EQ("{101}", pred_vec.ToString());
}
TEST_F(LiteralUtilTest, R2ToString) {
@@ -141,7 +143,7 @@ TEST_F(LiteralUtilTest, R2ToString) {
{ 3, 4 },
{ 5, 6 }
})";
- ASSERT_EQ(expected, literal->ToString());
+ EXPECT_EQ(expected, literal.ToString());
}
TEST_F(LiteralUtilTest, R3ToString) {
@@ -155,13 +157,13 @@ TEST_F(LiteralUtilTest, R3ToString) {
{ { 5 },
{ 6 } }
})";
- ASSERT_EQ(expected, literal->ToString());
+ EXPECT_EQ(expected, literal.ToString());
}
TEST_F(LiteralUtilTest, TupleToString) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
const string expected = R"((f32[], f32[2,2]) (
1,
f32[2,2] {
@@ -169,7 +171,7 @@ f32[2,2] {
{ 3, 4 }
}
))";
- ASSERT_EQ(expected, tuple->ToString());
+ EXPECT_EQ(expected, tuple.ToString());
}
TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
@@ -185,8 +187,8 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
// clang-format on
auto literal = LiteralUtil::CreateR3FromArray3D(array_3d);
- EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2));
- string result = literal->ToString();
+ EXPECT_THAT(literal.shape().dimensions(), ElementsAre(2, 3, 2));
+ string result = literal.ToString();
const string expected = R"(f32[2,3,2] {
{ { 1, 2 },
{ 3, 4 },
@@ -195,7 +197,7 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
{ 9, 10 },
{ 11, 12 } }
})";
- ASSERT_EQ(expected, result);
+ EXPECT_EQ(expected, result);
}
TEST_F(LiteralUtilTest, CreateSparse) {
@@ -218,10 +220,10 @@ TEST_F(LiteralUtilTest, CreateSparse) {
};
std::vector<int64> expected_values = {8, 9, 7, 10};
- EXPECT_EQ(literal->sparse_indices()->data(),
- ArraySlice<int64>(expected_indices.data(),
- expected_indices.num_elements()));
- EXPECT_EQ(literal->data<int64>(), ArraySlice<int64>(expected_values));
+ EXPECT_EQ(literal.sparse_indices()->data(),
+ absl::Span<const int64>(expected_indices.data(),
+ expected_indices.num_elements()));
+ EXPECT_EQ(literal.data<int64>(), absl::Span<const int64>(expected_values));
}
TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
@@ -232,8 +234,8 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
{2001, 2002},
}, /*projection_p=*/1, /*projection_z=*/2);
// clang-format on
- EXPECT_THAT(literal->shape().dimensions(), ElementsAre(1, 2, 3, 2));
- string result = literal->ToString();
+ EXPECT_THAT(literal.shape().dimensions(), ElementsAre(1, 2, 3, 2));
+ string result = literal.ToString();
const string expected = R"(f32[1,2,3,2] {
{ /*i0=0*/
{ /*i1=0*/
@@ -248,13 +250,13 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
}
}
})";
- ASSERT_EQ(expected, result);
+ EXPECT_EQ(expected, result);
}
TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) {
- EXPECT_THAT(literal_r4_2x2x3x3_dim0major_->shape().dimensions(),
+ EXPECT_THAT(literal_r4_2x2x3x3_dim0major_.shape().dimensions(),
ElementsAre(2, 2, 3, 3));
- string result = literal_r4_2x2x3x3_dim0major_->ToString();
+ string result = literal_r4_2x2x3x3_dim0major_.ToString();
const string expected = R"(f32[2,2,3,3] {
{ /*i0=0*/
{ /*i1=0*/
@@ -281,7 +283,7 @@ TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) {
}
}
})";
- ASSERT_EQ(expected, result);
+ EXPECT_EQ(expected, result);
}
TEST_F(LiteralUtilTest, EachCellR2F32) {
@@ -292,8 +294,8 @@ TEST_F(LiteralUtilTest, EachCellR2F32) {
});
// clang-format on
std::vector<std::tuple<int64, int64, string>> seen;
- literal->EachCellAsString(
- [&seen](ArraySlice<int64> indices, const string& value) {
+ literal.EachCellAsString(
+ [&seen](absl::Span<const int64> indices, const string& value) {
seen.emplace_back(indices[0], indices[1], value);
});
@@ -308,14 +310,14 @@ TEST_F(LiteralUtilTest, ScalarEquality) {
auto f32_42 = LiteralUtil::CreateR0<float>(42.0);
auto f32_42_clone = LiteralUtil::CreateR0<float>(42.0);
- EXPECT_EQ(*f32_42, *f32_42);
- EXPECT_EQ(*f32_42, *f32_42_clone);
+ EXPECT_EQ(f32_42, f32_42);
+ EXPECT_EQ(f32_42, f32_42_clone);
auto f32_123 = LiteralUtil::CreateR0<float>(123.0);
- EXPECT_NE(*f32_42, *f32_123);
+ EXPECT_NE(f32_42, f32_123);
auto f64_42 = LiteralUtil::CreateR0<double>(42.0);
- EXPECT_NE(*f32_42, *f64_42);
+ EXPECT_NE(f32_42, f64_42);
}
TEST_F(LiteralUtilTest, NonScalarEquality) {
@@ -328,12 +330,12 @@ TEST_F(LiteralUtilTest, NonScalarEquality) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
Literal nil(ShapeUtil::MakeNil());
- EXPECT_EQ(*matrix, *matrix);
- EXPECT_EQ(*matrix, *matrix_clone);
- EXPECT_NE(*matrix, *matrix_different);
- EXPECT_NE(*matrix, *vector_literal);
- EXPECT_NE(*matrix, *scalar);
- EXPECT_NE(*matrix, nil);
+ EXPECT_EQ(matrix, matrix);
+ EXPECT_EQ(matrix, matrix_clone);
+ EXPECT_NE(matrix, matrix_different);
+ EXPECT_NE(matrix, vector_literal);
+ EXPECT_NE(matrix, scalar);
+ EXPECT_NE(matrix, nil);
EXPECT_EQ(nil, nil);
}
@@ -342,57 +344,54 @@ TEST_F(LiteralUtilTest, TokenEquality) {
auto token1 = LiteralUtil::CreateToken();
auto scalar = LiteralUtil::CreateR0<float>(1.0);
- EXPECT_EQ(*token0, *token1);
- EXPECT_NE(*token0, *scalar);
+ EXPECT_EQ(token0, token1);
+ EXPECT_NE(token0, scalar);
- EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get()}),
- *LiteralUtil::MakeTuple({token0.get()}));
- EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}),
- *LiteralUtil::MakeTuple({token1.get(), scalar.get()}));
- EXPECT_NE(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}),
- *LiteralUtil::MakeTuple({scalar.get(), token1.get()}));
+ EXPECT_EQ(LiteralUtil::MakeTuple({&token0}),
+ LiteralUtil::MakeTuple({&token0}));
+ EXPECT_EQ(LiteralUtil::MakeTuple({&token0, &scalar}),
+ LiteralUtil::MakeTuple({&token1, &scalar}));
+ EXPECT_NE(LiteralUtil::MakeTuple({&token0, &scalar}),
+ LiteralUtil::MakeTuple({&scalar, &token1}));
}
TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
// Test equality with literals which have different layouts.
- auto colmajor =
- MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
- colmajor->Set<float>({0, 0}, 1.0);
- colmajor->Set<float>({0, 1}, 2.0);
- colmajor->Set<float>({1, 0}, 3.0);
- colmajor->Set<float>({1, 1}, 4.0);
+ Literal colmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
+ colmajor.Set<float>({0, 0}, 1.0);
+ colmajor.Set<float>({0, 1}, 2.0);
+ colmajor.Set<float>({1, 0}, 3.0);
+ colmajor.Set<float>({1, 1}, 4.0);
- auto rowmajor =
- MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
- rowmajor->Set<float>({0, 0}, 1.0);
- rowmajor->Set<float>({0, 1}, 2.0);
- rowmajor->Set<float>({1, 0}, 3.0);
- rowmajor->Set<float>({1, 1}, 4.0);
+ Literal rowmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
+ rowmajor.Set<float>({0, 0}, 1.0);
+ rowmajor.Set<float>({0, 1}, 2.0);
+ rowmajor.Set<float>({1, 0}, 3.0);
+ rowmajor.Set<float>({1, 1}, 4.0);
- EXPECT_EQ(*rowmajor, *colmajor);
+ EXPECT_EQ(rowmajor, colmajor);
}
TEST_F(LiteralUtilTest, TupleEquality) {
// Test equality with tuples.
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple1 = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+ auto tuple1 = LiteralUtil::MakeTuple({&scalar, &matrix});
// Tuple with the same elements. One element is shared with the original
// tuple, the other is a clone of the element in the original tuple.
auto scalar_clone = LiteralUtil::CreateR0<float>(1.0);
- auto tuple2 = LiteralUtil::MakeTuple({scalar_clone.get(), matrix.get()});
- EXPECT_EQ(*tuple1, *tuple2);
+ auto tuple2 = LiteralUtil::MakeTuple({&scalar_clone, &matrix});
+ EXPECT_EQ(tuple1, tuple2);
// Tuple with elements reversed.
- auto reversed_tuple = LiteralUtil::MakeTuple({matrix.get(), scalar.get()});
- EXPECT_NE(*tuple1, *reversed_tuple);
+ auto reversed_tuple = LiteralUtil::MakeTuple({&matrix, &scalar});
+ EXPECT_NE(tuple1, reversed_tuple);
// Tuple with different value.
auto scalar_42 = LiteralUtil::CreateR0<float>(42.0);
- auto different_tuple =
- LiteralUtil::MakeTuple({scalar_42.get(), matrix.get()});
- EXPECT_NE(*tuple1, *different_tuple);
+ auto different_tuple = LiteralUtil::MakeTuple({&scalar_42, &matrix});
+ EXPECT_NE(tuple1, different_tuple);
}
TEST_F(LiteralUtilTest, C64Equality) {
@@ -403,162 +402,161 @@ TEST_F(LiteralUtilTest, C64Equality) {
// tuple, the other is a clone of the element in the original tuple.
auto vector_clone =
LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
- EXPECT_EQ(*vector, *vector_clone);
+ EXPECT_EQ(vector, vector_clone);
auto vector_reversed =
LiteralUtil::CreateR1<complex64>({{3.0, 4.0}, {1.0, 2.0}});
- EXPECT_NE(*vector, *vector_reversed);
+ EXPECT_NE(vector, vector_reversed);
}
TEST_F(LiteralUtilTest, IsAllTuple) {
auto element1 = LiteralUtil::CreateR0<float>(0.0);
auto element2 = LiteralUtil::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
- auto tuple = LiteralUtil::MakeTuple({element1.get(), element1.get()});
+ auto tuple = LiteralUtil::MakeTuple({&element1, &element1});
// Tuples should always return false for IsAll.
- EXPECT_FALSE(tuple->IsAll(0));
- EXPECT_FALSE(tuple->IsAll(1));
+ EXPECT_FALSE(tuple.IsAll(0));
+ EXPECT_FALSE(tuple.IsAll(1));
}
// Verifies that CreateFromShape works for tuples.
TEST_F(LiteralUtilTest, CreateFromShapeTuple) {
auto scalar = LiteralUtil::CreateR0<float>(0.0);
auto matrix = LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
- auto x = Literal::CreateFromShape(tuple->shape());
- EXPECT_EQ(*tuple, *x);
+ auto x = Literal::CreateFromShape(tuple.shape());
+ EXPECT_EQ(tuple, x);
}
TEST_F(LiteralUtilTest, IsAll) {
- EXPECT_TRUE(LiteralUtil::CreateR0<bool>(false)->IsAll(0));
- EXPECT_TRUE(LiteralUtil::CreateR0<bool>(true)->IsAll(1));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAll(1));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAll(2));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(2));
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(-1));
+ EXPECT_TRUE(LiteralUtil::CreateR0<bool>(false).IsAll(0));
+ EXPECT_TRUE(LiteralUtil::CreateR0<bool>(true).IsAll(1));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAll(1));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAll(2));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(2));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(-1));
// We shouldn't reinterpret int8_min as an unsigned type and then decide that
// it is equal to 255.
auto int8_min = std::numeric_limits<int8>::min();
- EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(255)->IsAll(int8_min));
+ EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(255).IsAll(int8_min));
- EXPECT_TRUE(LiteralUtil::CreateR0<float>(42.0)->IsAll(42));
- EXPECT_FALSE(LiteralUtil::CreateR0<float>(42.0001)->IsAll(42));
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(42.0).IsAll(42));
+ EXPECT_FALSE(LiteralUtil::CreateR0<float>(42.0001).IsAll(42));
- EXPECT_TRUE(LiteralUtil::CreateR1<int>({100, 100, 100})->IsAll(100));
- EXPECT_FALSE(LiteralUtil::CreateR1<double>({100, 100, 100.001})->IsAll(100));
+ EXPECT_TRUE(LiteralUtil::CreateR1<int>({100, 100, 100}).IsAll(100));
+ EXPECT_FALSE(LiteralUtil::CreateR1<double>({100, 100, 100.001}).IsAll(100));
- EXPECT_TRUE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 8}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 9}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{9, 8}, {8, 8}})->IsAll(8));
+ EXPECT_TRUE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 8}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 9}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{9, 8}, {8, 8}}).IsAll(8));
half h8(8.0f);
half h9(9.0f);
- EXPECT_TRUE(LiteralUtil::CreateR2<half>({{h8}, {h8}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h8}, {h9}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h9}, {h8}})->IsAll(8));
+ EXPECT_TRUE(LiteralUtil::CreateR2<half>({{h8}, {h8}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h8}, {h9}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h9}, {h8}}).IsAll(8));
bfloat16 b8(8.0f);
bfloat16 b9(9.0f);
- EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b8}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b9}})->IsAll(8));
- EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b9}, {b8}})->IsAll(8));
+ EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b8}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b9}}).IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b9}, {b8}}).IsAll(8));
// 9.001 will be rounded to 9.0
bfloat16 b91(9.001f);
bfloat16 b90(9.00f);
- EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b91}, {b90}})->IsAll(9.0));
+ EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b91}, {b90}}).IsAll(9.0));
complex64 c8_9 = {8, 9};
- EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAll(8));
auto uint64_max = std::numeric_limits<uint64>::max();
EXPECT_FALSE(LiteralUtil::CreateR2<uint64>(
{{uint64_max, uint64_max}, {uint64_max, uint64_max}})
- ->IsAll(-1));
+ .IsAll(-1));
}
TEST_F(LiteralUtilTest, IsAllFloat) {
// IsAllFloat always returns false when the literal is not floating-point.
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAllFloat(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0)->IsAllFloat(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0)->IsAllFloat(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<int>(0)->IsAllFloat(0));
-
- EXPECT_TRUE(LiteralUtil::CreateR0<float>(0)->IsAllFloat(0));
- EXPECT_TRUE(LiteralUtil::CreateR0<float>(.5)->IsAllFloat(.5));
- EXPECT_TRUE(LiteralUtil::CreateR0<float>(-.5)->IsAllFloat(-.5));
- EXPECT_FALSE(LiteralUtil::CreateR0<float>(-.5)->IsAllFloat(-.49));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllFloat(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0).IsAllFloat(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0).IsAllFloat(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllFloat(0));
+
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(0).IsAllFloat(0));
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(.5).IsAllFloat(.5));
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(-.5).IsAllFloat(-.5));
+ EXPECT_FALSE(LiteralUtil::CreateR0<float>(-.5).IsAllFloat(-.49));
EXPECT_FALSE(
- LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
+ LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0));
EXPECT_TRUE(LiteralUtil::CreateR2<float>({{.5, .5, .5}, {.5, .5, .5}})
- ->IsAllFloat(.5));
+ .IsAllFloat(.5));
- EXPECT_TRUE(LiteralUtil::CreateR0<double>(0)->IsAllFloat(0));
- EXPECT_TRUE(LiteralUtil::CreateR0<double>(.5)->IsAllFloat(.5));
- EXPECT_TRUE(LiteralUtil::CreateR0<double>(-.5)->IsAllFloat(-.5));
- EXPECT_FALSE(LiteralUtil::CreateR0<double>(-.5)->IsAllFloat(-.49));
+ EXPECT_TRUE(LiteralUtil::CreateR0<double>(0).IsAllFloat(0));
+ EXPECT_TRUE(LiteralUtil::CreateR0<double>(.5).IsAllFloat(.5));
+ EXPECT_TRUE(LiteralUtil::CreateR0<double>(-.5).IsAllFloat(-.5));
+ EXPECT_FALSE(LiteralUtil::CreateR0<double>(-.5).IsAllFloat(-.49));
EXPECT_FALSE(
- LiteralUtil::CreateR2<double>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
+ LiteralUtil::CreateR2<double>({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0));
}
TEST_F(LiteralUtilTest, IsAllComplex) {
// IsAllComplex always returns false when the literal is not complex.
- EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<int>(0)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<float>(0)->IsAllComplex(0));
- EXPECT_FALSE(LiteralUtil::CreateR0<double>(0)->IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<float>(0).IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<double>(0).IsAllComplex(0));
complex64 c8_9 = {8, 9};
complex64 c7_9 = {7, 9};
EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})
- ->IsAllComplex({8.0f, 9.0f}));
+ .IsAllComplex({8.0f, 9.0f}));
EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}})
- ->IsAllComplex({8.0f, 9.0f}));
+ .IsAllComplex({8.0f, 9.0f}));
EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c7_9}})
- ->IsAllComplex({8.0f, 9.0f}));
+ .IsAllComplex({8.0f, 9.0f}));
}
TEST_F(LiteralUtilTest, IsAllFirst) {
// IsAllFirst returns whether every element equals the first element.
- EXPECT_FALSE(LiteralUtil::CreateR1<bool>({false, true})->IsAllFirst());
- EXPECT_TRUE(LiteralUtil::CreateR1<bool>({false, false})->IsAllFirst());
- EXPECT_FALSE(LiteralUtil::CreateR1<int8>({1, 1, 2})->IsAllFirst());
- EXPECT_TRUE(LiteralUtil::CreateR1<int8>({5, 5, 5, 5})->IsAllFirst());
- EXPECT_FALSE(LiteralUtil::CreateR1<uint8>({1, 1, 2})->IsAllFirst());
- EXPECT_TRUE(LiteralUtil::CreateR1<int32>({5, 5, 5, 5})->IsAllFirst());
- EXPECT_FALSE(LiteralUtil::CreateR1<int32>({1, 1, 2})->IsAllFirst());
- EXPECT_TRUE(LiteralUtil::CreateR1<uint32>({5, 5, 5, 5})->IsAllFirst());
- EXPECT_FALSE(LiteralUtil::CreateR1<uint32>({1, 1, 2})->IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<bool>({false, true}).IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<bool>({false, false}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<int8>({1, 1, 2}).IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<int8>({5, 5, 5, 5}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<uint8>({1, 1, 2}).IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<int32>({5, 5, 5, 5}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<int32>({1, 1, 2}).IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<uint32>({5, 5, 5, 5}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<uint32>({1, 1, 2}).IsAllFirst());
complex64 c8_9 = {8, 9};
complex64 c7_9 = {7, 9};
- EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAllFirst());
- EXPECT_FALSE(
- LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}})->IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}}).IsAllFirst());
}
TEST_F(LiteralUtilTest, IsZero) {
auto scalar_zero = LiteralUtil::CreateR0<float>(0.0f);
auto scalar_one = LiteralUtil::CreateR0<float>(1.0f);
- EXPECT_TRUE(scalar_zero->IsZero({}));
- EXPECT_FALSE(scalar_one->IsZero({}));
+ EXPECT_TRUE(scalar_zero.IsZero({}));
+ EXPECT_FALSE(scalar_one.IsZero({}));
auto array = LiteralUtil::CreateR2<uint32>({{1, 2, 0, 3}, {1, 0, 1, 2}});
- EXPECT_FALSE(array->IsZero({0, 1}));
- EXPECT_TRUE(array->IsZero({0, 2}));
- EXPECT_TRUE(array->IsZero({1, 1}));
- EXPECT_FALSE(array->IsZero({1, 2}));
+ EXPECT_FALSE(array.IsZero({0, 1}));
+ EXPECT_TRUE(array.IsZero({0, 2}));
+ EXPECT_TRUE(array.IsZero({1, 1}));
+ EXPECT_FALSE(array.IsZero({1, 2}));
auto complex_zero = LiteralUtil::CreateR0<complex64>(0.0f);
auto complex_nonzero = LiteralUtil::CreateR0<complex64>(0.5f);
- EXPECT_TRUE(complex_zero->IsZero({}));
- EXPECT_FALSE(complex_nonzero->IsZero({}));
+ EXPECT_TRUE(complex_zero.IsZero({}));
+ EXPECT_FALSE(complex_nonzero.IsZero({}));
}
template <typename T>
@@ -574,19 +572,19 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) {
const Layout layout01 = LayoutUtil::MakeLayout({0, 1});
const Layout layout10 = LayoutUtil::MakeLayout({1, 0});
- auto data01 = data->Relayout(layout01);
- EXPECT_TRUE(LayoutUtil::Equal(data01->shape().layout(), layout01));
- EXPECT_EQ(*data, *data01);
+ auto data01 = data.Relayout(layout01);
+ EXPECT_TRUE(LayoutUtil::Equal(data01.shape().layout(), layout01));
+ EXPECT_EQ(data, data01);
- auto data10 = data->Relayout(layout10);
- EXPECT_TRUE(LayoutUtil::Equal(data10->shape().layout(), layout10));
- EXPECT_EQ(*data, *data10);
+ auto data10 = data.Relayout(layout10);
+ EXPECT_TRUE(LayoutUtil::Equal(data10.shape().layout(), layout10));
+ EXPECT_EQ(data, data10);
}
TEST_F(LiteralUtilTest, ReshapeR0) {
auto original = LiteralUtil::CreateR0<float>(1.7f);
- auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie();
- EXPECT_EQ(*original, *reshape);
+ auto reshape = original.Reshape(/*dimensions=*/{}).ConsumeValueOrDie();
+ EXPECT_EQ(original, reshape);
}
TEST_F(LiteralUtilTest, ReshapeR4) {
@@ -604,9 +602,9 @@ TEST_F(LiteralUtilTest, ReshapeR4) {
{{26, 27}, {28, 29}, {30, 31}, {32, 33}},
}, layout_r3_dim0major_);
// clang-format on
- auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie();
+ auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie();
- EXPECT_EQ(*expected, *reshape);
+ EXPECT_EQ(expected, reshape);
}
TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
@@ -624,15 +622,15 @@ TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
{{26, 27}, {28, 29}, {30, 31}, {32, 33}},
}, layout_r3_dim0major_);
// clang-format on
- auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie();
+ auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie();
- EXPECT_EQ(*expected, *reshape);
+ EXPECT_EQ(expected, reshape);
}
TEST_F(LiteralUtilTest, TransposeR0) {
auto original = LiteralUtil::CreateR0<float>(1.7f);
- auto reshape = original->Transpose(/*permutation=*/{});
- EXPECT_EQ(*original, *reshape);
+ auto reshape = original.Transpose(/*permutation=*/{});
+ EXPECT_EQ(original, reshape);
}
TEST_F(LiteralUtilTest, TransposeR4) {
@@ -644,10 +642,10 @@ TEST_F(LiteralUtilTest, TransposeR4) {
{{26, 27, 28, 29}, {30, 31, 32, 33}},
}});
// clang-format on
- auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1});
+ auto reshape = original.Transpose(/*permutation=*/{2, 3, 0, 1});
- reshape->EachCell<float>([&](ArraySlice<int64> indices, float value) {
- EXPECT_EQ(value, original->Get<float>(
+ reshape.EachCell<float>([&](absl::Span<const int64> indices, float value) {
+ EXPECT_EQ(value, original.Get<float>(
{indices[2], indices[3], indices[0], indices[1]}));
});
}
@@ -656,35 +654,35 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
// Tests that using Relayout on an array is equivalent to creating it in the
// target layout in the first place.
auto dim0minor_relaid_to_dim0major =
- literal_r4_2x2x3x3_dim0minor_->Relayout(layout_r4_dim0major_);
- EXPECT_EQ(*literal_r4_2x2x3x3_dim0major_, *dim0minor_relaid_to_dim0major);
+ literal_r4_2x2x3x3_dim0minor_.Relayout(layout_r4_dim0major_);
+ EXPECT_EQ(literal_r4_2x2x3x3_dim0major_, dim0minor_relaid_to_dim0major);
auto dim0major_relaid_to_dim0minor =
- literal_r4_2x2x3x3_dim0major_->Relayout(layout_r4_dim0minor_);
- EXPECT_EQ(*literal_r4_2x2x3x3_dim0minor_, *dim0major_relaid_to_dim0minor);
+ literal_r4_2x2x3x3_dim0major_.Relayout(layout_r4_dim0minor_);
+ EXPECT_EQ(literal_r4_2x2x3x3_dim0minor_, dim0major_relaid_to_dim0minor);
}
TEST_F(LiteralUtilTest, TestR2LinearLayout) {
// Test expected memory layout of R2 dim0-minor (column-major) literal.
auto mat_dim0minor = LiteralUtil::CreateR2WithLayout<int32>(
{{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_);
- EXPECT_EQ(mat_dim0minor->element_count(), 6);
- EXPECT_THAT(mat_dim0minor->data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));
+ EXPECT_EQ(mat_dim0minor.element_count(), 6);
+ EXPECT_THAT(mat_dim0minor.data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));
// Test expected memory layout when using Relayout to row major.
- auto relaid_mat_to_dim0major = mat_dim0minor->Relayout(layout_r2_dim0major_);
- EXPECT_THAT(relaid_mat_to_dim0major->data<int32>(),
+ auto relaid_mat_to_dim0major = mat_dim0minor.Relayout(layout_r2_dim0major_);
+ EXPECT_THAT(relaid_mat_to_dim0major.data<int32>(),
ElementsAre(1, 2, 3, 4, 5, 6));
// Test expected memory layout of R2 created with dim0-major (row-major).
auto mat_dim0major = LiteralUtil::CreateR2WithLayout<int32>(
{{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_);
- EXPECT_EQ(mat_dim0major->element_count(), 6);
- EXPECT_THAT(mat_dim0major->data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));
+ EXPECT_EQ(mat_dim0major.element_count(), 6);
+ EXPECT_THAT(mat_dim0major.data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));
// Test expected memory layout when using Relayout to column major.
- auto relaid_mat_to_dim0minor = mat_dim0major->Relayout(layout_r2_dim0minor_);
- EXPECT_THAT(relaid_mat_to_dim0minor->data<int32>(),
+ auto relaid_mat_to_dim0minor = mat_dim0major.Relayout(layout_r2_dim0minor_);
+ EXPECT_THAT(relaid_mat_to_dim0minor.data<int32>(),
ElementsAre(1, 4, 2, 5, 3, 6));
}
@@ -705,77 +703,77 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) {
auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
arr3d, layout_r3_dim0minor_);
- EXPECT_EQ(lit_dim0minor->element_count(), 12);
+ EXPECT_EQ(lit_dim0minor.element_count(), 12);
std::vector<int> expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12};
- EXPECT_THAT(lit_dim0minor->data<int32>(),
+ EXPECT_THAT(lit_dim0minor.data<int32>(),
testing::ElementsAreArray(expected_dim0minor));
// Test expected memory layout when using Relayout to row major.
- auto relaid_lit_to_dim0major = lit_dim0minor->Relayout(layout_r3_dim0major_);
+ auto relaid_lit_to_dim0major = lit_dim0minor.Relayout(layout_r3_dim0major_);
std::vector<int> expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
- EXPECT_THAT(relaid_lit_to_dim0major->data<int32>(),
+ EXPECT_THAT(relaid_lit_to_dim0major.data<int32>(),
testing::ElementsAreArray(expected_dim0major));
// Test expected memory layout of R3 created with dim0-major (row-major).
auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
arr3d, layout_r3_dim0major_);
- EXPECT_EQ(lit_dim0major->element_count(), 12);
- EXPECT_THAT(lit_dim0major->data<int32>(),
+ EXPECT_EQ(lit_dim0major.element_count(), 12);
+ EXPECT_THAT(lit_dim0major.data<int32>(),
testing::ElementsAreArray(expected_dim0major));
// Test expected memory layout when using Relayout to column major.
- auto relaid_lit_to_dim0minor = lit_dim0major->Relayout(layout_r3_dim0minor_);
- EXPECT_THAT(relaid_lit_to_dim0minor->data<int32>(),
+ auto relaid_lit_to_dim0minor = lit_dim0major.Relayout(layout_r3_dim0minor_);
+ EXPECT_THAT(relaid_lit_to_dim0minor.data<int32>(),
testing::ElementsAreArray(expected_dim0minor));
}
TEST_F(LiteralUtilTest, SliceR0S32) {
auto input = LiteralUtil::CreateR0<int32>(1);
- auto result = input->Slice({}, {});
- EXPECT_EQ(*input, *result);
+ auto result = input.Slice({}, {});
+ EXPECT_EQ(input, result);
}
TEST_F(LiteralUtilTest, SliceR1F32) {
auto input = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0});
- auto result = input->Slice({3}, {4});
+ auto result = input.Slice({3}, {4});
auto expected = LiteralUtil::CreateR1<float>({4.0});
- EXPECT_EQ(*expected, *result);
+ EXPECT_EQ(expected, result);
}
TEST_F(LiteralUtilTest, SliceR2U32) {
auto input_3x4 = LiteralUtil::CreateR2<uint32>(
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
- auto result = input_3x4->Slice({0, 2}, {2, 4});
+ auto result = input_3x4.Slice({0, 2}, {2, 4});
auto expected = LiteralUtil::CreateR2<uint32>({{3, 4}, {7, 8}});
- EXPECT_EQ(*expected, *result);
+ EXPECT_EQ(expected, result);
}
TEST_F(LiteralUtilTest, SliceR3U32Full) {
auto input_2x3x2 = LiteralUtil::CreateR3<uint32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
- auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2});
- EXPECT_EQ(*input_2x3x2, *result);
+ auto result = input_2x3x2.Slice({0, 0, 0}, {2, 3, 2});
+ EXPECT_EQ(input_2x3x2, result);
}
TEST_F(LiteralUtilTest, PopulateR1S64) {
Literal output(ShapeUtil::MakeShape(S64, {1}));
output.PopulateR1<int64>({77});
auto expected = LiteralUtil::CreateR1<int64>({77});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateR1U64) {
Literal output(ShapeUtil::MakeShape(U64, {2}));
output.PopulateR1<uint64>({{77, 88}});
auto expected = LiteralUtil::CreateR1<uint64>({{77, 88}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateR1C64) {
Literal output(ShapeUtil::MakeShape(C64, {1}));
output.PopulateR1<complex64>({{77, 88}});
auto expected = LiteralUtil::CreateR1<complex64>({{77, 88}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateR2C64) {
@@ -783,7 +781,7 @@ TEST_F(LiteralUtilTest, PopulateR2C64) {
output.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
auto expected =
LiteralUtil::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
@@ -791,7 +789,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
bfloat16 h(0.25f);
output.PopulateWithValue<bfloat16>(h);
auto expected = LiteralUtil::CreateR0<bfloat16>(h);
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
@@ -799,7 +797,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
bfloat16 h(0.5f);
output.PopulateWithValue<bfloat16>(h);
auto expected = LiteralUtil::CreateR1<bfloat16>({h, h, h});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
@@ -807,28 +805,28 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
bfloat16 h(2.0f);
output.PopulateWithValue<bfloat16>(h);
auto expected = LiteralUtil::CreateR2<bfloat16>({{h, h}, {h, h}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
Literal output(ShapeUtil::MakeShape(F32, {}));
output.PopulateWithValue<float>(2.5f);
auto expected = LiteralUtil::CreateR0<float>(2.5f);
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1S64) {
Literal output(ShapeUtil::MakeShape(S64, {3}));
output.PopulateWithValue<int64>(-7);
auto expected = LiteralUtil::CreateR1<int64>({-7, -7, -7});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
Literal output(ShapeUtil::MakeShape(U64, {2, 2}));
output.PopulateWithValue<uint64>(42);
auto expected = LiteralUtil::CreateR2<uint64>({{42, 42}, {42, 42}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
@@ -836,7 +834,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
output.PopulateWithValue<complex64>({4, 2});
auto expected =
LiteralUtil::CreateR2<complex64>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
@@ -844,7 +842,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
half h(0.25f);
output.PopulateWithValue<half>(h);
auto expected = LiteralUtil::CreateR0<half>(h);
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
@@ -852,7 +850,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
half h(0.5f);
output.PopulateWithValue<half>(h);
auto expected = LiteralUtil::CreateR1<half>({h, h, h});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
@@ -860,18 +858,18 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
half h(2.0f);
output.PopulateWithValue<half>(h);
auto expected = LiteralUtil::CreateR2<half>({{h, h}, {h, h}});
- EXPECT_EQ(output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, ReplicateR2U32) {
auto input = LiteralUtil::CreateR2<uint32>(
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
- auto output = input->Replicate<uint32>(3);
+ auto output = input.Replicate<uint32>(3);
auto expected = LiteralUtil::CreateR3<uint32>(
{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}});
- EXPECT_EQ(*output, *expected);
+ EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, CopySliceFrom) {
@@ -886,35 +884,35 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
const int64 zero_base[] = {0, 0, 0, 0};
const int64 step[] = {1, 1, 1, 1};
uint32 seqnr = 0;
- auto init_proc = [&](ArraySlice<int64> indexes) {
- source->Set(indexes, ++seqnr);
+ auto init_proc = [&](absl::Span<const int64> indexes) {
+ source.Set(indexes, ++seqnr);
return true;
};
- ShapeUtil::ForEachIndex(source->shape(), zero_base, dimensions, step,
+ ShapeUtil::ForEachIndex(source.shape(), zero_base, dimensions, step,
init_proc);
auto blank = Literal::CreateFromShape(shape);
const int64 src_base[] = {3, 1, 5, 7};
const int64 dest_base[] = {6, 4, 12, 2};
const int64 copy_size[] = {7, 8, 11, 9};
- TF_EXPECT_OK(blank->CopySliceFrom(*source, src_base, dest_base, copy_size));
+ TF_EXPECT_OK(blank.CopySliceFrom(source, src_base, dest_base, copy_size));
std::vector<int64> source_indexes(TF_ARRAYSIZE(dimensions), 0);
std::vector<int64> blank_indexes(TF_ARRAYSIZE(dimensions), 0);
bool matched = true;
- auto check_proc = [&](ArraySlice<int64> indexes) {
+ auto check_proc = [&](absl::Span<const int64> indexes) {
std::copy(indexes.begin(), indexes.end(), source_indexes.begin());
std::transform(source_indexes.begin(), source_indexes.end(), src_base,
source_indexes.begin(), std::plus<int64>());
std::copy(indexes.begin(), indexes.end(), blank_indexes.begin());
std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base,
blank_indexes.begin(), std::plus<int64>());
- auto bval = blank->Get<uint32>(blank_indexes);
- matched = (bval != 0 && bval == source->Get<uint32>(source_indexes));
+ auto bval = blank.Get<uint32>(blank_indexes);
+ matched = (bval != 0 && bval == source.Get<uint32>(source_indexes));
return matched;
};
- ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step,
+ ShapeUtil::ForEachIndex(source.shape(), zero_base, copy_size, step,
check_proc);
EXPECT_TRUE(matched);
}
@@ -923,14 +921,14 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
TEST_F(LiteralUtilTest, CopyFromScalars) {
auto zero = LiteralUtil::CreateR0<uint32>(0);
auto nine = LiteralUtil::CreateR0<uint32>(9);
- TF_EXPECT_OK(zero->CopyFrom(*nine));
- EXPECT_EQ(*zero, *nine);
+ TF_EXPECT_OK(zero.CopyFrom(nine));
+ EXPECT_EQ(zero, nine);
auto vect = LiteralUtil::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21});
- TF_EXPECT_OK(zero->CopySliceFrom(*vect, {5}, {}, {}));
- EXPECT_EQ(zero->Get<uint32>({}), 17);
- TF_EXPECT_OK(vect->CopySliceFrom(*zero, {}, {4}, {}));
- EXPECT_EQ(vect->Get<uint32>({4}), 17);
+ TF_EXPECT_OK(zero.CopySliceFrom(vect, {5}, {}, {}));
+ EXPECT_EQ(zero.Get<uint32>({}), 17);
+ TF_EXPECT_OK(vect.CopySliceFrom(zero, {}, {4}, {}));
+ EXPECT_EQ(vect.Get<uint32>({4}), 17);
}
TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
@@ -943,17 +941,17 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
const auto empty = Literal::CreateFromShape(empty_r1_shape);
auto nine = LiteralUtil::CreateR1<float>({9});
- TF_EXPECT_OK(nine->CopySliceFrom(*empty, {0}, {0}, {0}));
- EXPECT_EQ(*nine, *const_nine);
+ TF_EXPECT_OK(nine.CopySliceFrom(empty, {0}, {0}, {0}));
+ EXPECT_EQ(nine, const_nine);
}
{
// Copy 0 element to destination with zero elements.
- const auto empty = Literal::CreateFromShape(empty_r1_shape);
+ auto empty = Literal::CreateFromShape(empty_r1_shape);
auto nine = LiteralUtil::CreateR1<float>({9});
- TF_EXPECT_OK(empty->CopySliceFrom(*nine, {0}, {0}, {0}));
- EXPECT_EQ(*empty, *const_empty);
+ TF_EXPECT_OK(empty.CopySliceFrom(nine, {0}, {0}, {0}));
+ EXPECT_EQ(empty, const_empty);
}
}
@@ -967,76 +965,77 @@ TEST_F(LiteralUtilTest, CopyFromNilShape) {
TEST_F(LiteralUtilTest, CopyFromArrays) {
auto scalar_42 = LiteralUtil::CreateR0<float>(42.0);
auto scalar_123 = LiteralUtil::CreateR0<float>(123.0);
- EXPECT_NE(*scalar_42, *scalar_123);
- TF_ASSERT_OK(scalar_42->CopyFrom(*scalar_123, /*dest_shape_index=*/{},
- /*src_shape_index=*/{}));
- EXPECT_EQ(*scalar_42, *scalar_123);
- EXPECT_EQ(scalar_42->Get<float>({}), 123.0f);
+ EXPECT_NE(scalar_42, scalar_123);
+ TF_ASSERT_OK(scalar_42.CopyFrom(scalar_123, /*dest_shape_index=*/{},
+ /*src_shape_index=*/{}));
+ EXPECT_EQ(scalar_42, scalar_123);
+ EXPECT_EQ(scalar_42.Get<float>({}), 123.0f);
auto matrix_1234 = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto matrix_5678 = LiteralUtil::CreateR2<float>({{5.0, 6.0}, {7.0, 8.0}});
- EXPECT_NE(*matrix_1234, *matrix_5678);
- EXPECT_EQ(matrix_1234->Get<float>({0, 0}), 1.0f);
- TF_ASSERT_OK(matrix_1234->CopyFrom(*matrix_5678, /*dest_shape_index=*/{},
- /*src_shape_index=*/{}));
- EXPECT_EQ(*matrix_1234, *matrix_5678);
- EXPECT_EQ(matrix_1234->Get<float>({0, 0}), 5.0f);
+ EXPECT_NE(matrix_1234, matrix_5678);
+ EXPECT_EQ(matrix_1234.Get<float>({0, 0}), 1.0f);
+ TF_ASSERT_OK(matrix_1234.CopyFrom(matrix_5678, /*dest_shape_index=*/{},
+ /*src_shape_index=*/{}));
+ EXPECT_EQ(matrix_1234, matrix_5678);
+ EXPECT_EQ(matrix_1234.Get<float>({0, 0}), 5.0f);
}
TEST_F(LiteralUtilTest, CopyFromTuples) {
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
Literal nil_literal(ShapeUtil::MakeNil());
- auto nested_tuple = LiteralUtil::MakeTuple(
- {matrix.get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR1<double>({23.0, 44.0}).get(), &nil_literal})
- .get()});
+ Literal inner_elements[] = {LiteralUtil::CreateR0<int32>(42),
+ LiteralUtil::CreateR1<double>({23.0, 44.0})};
+ Literal inner_tuple = LiteralUtil::MakeTuple(
+ {&inner_elements[0], &inner_elements[1], &nil_literal});
+ Literal nested_tuple = LiteralUtil::MakeTuple({&matrix, &inner_tuple});
  // Create a tuple with the same shape as the inner tuple of nested_tuple but
  // with different values.
- auto tuple = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<int32>(-5).get(),
- LiteralUtil::CreateR1<double>({2.0, 4.0}).get(), &nil_literal});
+ Literal int32_minus5 = LiteralUtil::CreateR0<int32>(-5);
+ Literal double_2_4 = LiteralUtil::CreateR1<double>({2.0, 4.0});
+ Literal tuple =
+ LiteralUtil::MakeTuple({&int32_minus5, &double_2_4, &nil_literal});
- EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0}));
- EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), 42);
- EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 23.0);
- EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 44.0);
+ EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
+ EXPECT_EQ(nested_tuple.Get<int32>({}, {1, 0}), 42);
+ EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), 23.0);
+ EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), 44.0);
// Overwrite the inner tuple element of nested_tuple with the contents of
// 'tuple'.
- TF_ASSERT_OK(nested_tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1},
- /*src_shape_index=*/{}));
+ TF_ASSERT_OK(nested_tuple.CopyFrom(tuple, /*dest_shape_index=*/{1},
+ /*src_shape_index=*/{}));
// The matrix element should be unchanged.
- EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0}));
+ EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
// The tuple element should have been copied from 'tuple'.
- EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), -5);
- EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 2.0);
- EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 4.0);
+ EXPECT_EQ(nested_tuple.Get<int32>({}, {1, 0}), -5);
+ EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), 2.0);
+ EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), 4.0);
}
TEST_F(LiteralUtilTest, CopyBetweenSameTuple) {
- auto tuple = LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(-2).get(),
- LiteralUtil::CreateR0<int32>(4).get()});
+ Literal elements[] = {LiteralUtil::CreateR0<int32>(-2),
+ LiteralUtil::CreateR0<int32>(4)};
+ Literal tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
- EXPECT_EQ(tuple->Get<int32>({}, {0}), -2);
- EXPECT_EQ(tuple->Get<int32>({}, {1}), 4);
+ EXPECT_EQ(tuple.Get<int32>({}, {0}), -2);
+ EXPECT_EQ(tuple.Get<int32>({}, {1}), 4);
// Copy from one element to the other.
- TF_ASSERT_OK(tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1},
- /*src_shape_index=*/{0}));
+ TF_ASSERT_OK(tuple.CopyFrom(tuple, /*dest_shape_index=*/{1},
+ /*src_shape_index=*/{0}));
- EXPECT_EQ(tuple->Get<int32>({}, {0}), -2);
- EXPECT_EQ(tuple->Get<int32>({}, {1}), -2);
+ EXPECT_EQ(tuple.Get<int32>({}, {0}), -2);
+ EXPECT_EQ(tuple.Get<int32>({}, {1}), -2);
}
TEST_F(LiteralUtilTest, CopyFromDifferentShapes) {
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto vector = LiteralUtil::CreateR1<float>({5.0, 7.0});
- Status status = matrix->CopyFrom(*vector);
+ Status status = matrix.CopyFrom(vector);
ASSERT_FALSE(status.ok());
- ASSERT_THAT(status.error_message(),
+ EXPECT_THAT(status.error_message(),
HasSubstr("Destination subshape incompatible"));
}
@@ -1044,9 +1043,8 @@ TEST_F(LiteralUtilTest, F16) {
  // Verify that the internal data views are consistent and that they
  // are in little-endian format.
  // TODO - modify if we make the data format machine-endianness dependent.
- auto m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2}));
- Literal* l1 = m1.get();
- const char* d1 = reinterpret_cast<const char*>(l1->data<half>().data());
+ Literal m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2}));
+ const char* d1 = reinterpret_cast<const char*>(m1.data<half>().data());
EXPECT_EQ(d1[0], 0);
EXPECT_EQ(d1[1], 0);
EXPECT_EQ(d1[2], 0);
@@ -1059,8 +1057,7 @@ TEST_F(LiteralUtilTest, F16) {
half h1(1.0f);
half h2(2.0f);
auto m2 = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
- Literal* l2 = m2.get();
- const char* d2 = reinterpret_cast<const char*>(l2->data<half>().data());
+ const char* d2 = reinterpret_cast<const char*>(m2.data<half>().data());
EXPECT_EQ(d2[0], 0);
EXPECT_EQ(d2[1], 0x3C);
EXPECT_EQ(d2[2], 0);
@@ -1089,25 +1086,25 @@ TEST_F(LiteralUtilTest, Populate) {
Shape shape = ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
data.layout);
- auto literal = MakeUnique<Literal>(shape);
- auto generator = [&](ArraySlice<int64> indexes) -> uint32 {
+ Literal literal(shape);
+ auto generator = [&](absl::Span<const int64> indexes) -> uint32 {
      // Offset from the linear index just to avoid R0 literals being
      // initialized to zero.
- return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
+ return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
indexes) +
17;
};
- TF_EXPECT_OK(literal->Populate<uint32>(generator));
+ TF_EXPECT_OK(literal.Populate<uint32>(generator));
std::vector<int64> zero_base(data.dimensions.size(), 0);
std::vector<int64> step(data.dimensions.size(), 1);
bool matched = true;
- auto check_function = [&](ArraySlice<int64> indexes) {
- auto value = literal->Get<uint32>(indexes);
+ auto check_function = [&](absl::Span<const int64> indexes) {
+ auto value = literal.Get<uint32>(indexes);
matched = matched && (value == generator(indexes));
return matched;
};
- ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step,
+ ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step,
check_function);
EXPECT_TRUE(matched);
}
@@ -1131,25 +1128,25 @@ TEST_F(LiteralUtilTest, PopulateParallel) {
Shape shape = ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
data.layout);
- auto literal = MakeUnique<Literal>(shape);
- auto generator = [&](ArraySlice<int64> indexes) -> uint32 {
+ Literal literal(shape);
+ auto generator = [&](absl::Span<const int64> indexes) -> uint32 {
      // Offset from the linear index just to avoid R0 literals being
      // initialized to zero.
- return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
+ return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
indexes) +
17;
};
- TF_EXPECT_OK(literal->PopulateParallel<uint32>(generator));
+ TF_EXPECT_OK(literal.PopulateParallel<uint32>(generator));
std::vector<int64> zero_base(data.dimensions.size(), 0);
std::vector<int64> step(data.dimensions.size(), 1);
bool matched = true;
- auto check_function = [&](ArraySlice<int64> indexes) {
- auto value = literal->Get<uint32>(indexes);
+ auto check_function = [&](absl::Span<const int64> indexes) {
+ auto value = literal.Get<uint32>(indexes);
matched = matched && (value == generator(indexes));
return matched;
};
- ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step,
+ ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step,
check_function);
EXPECT_TRUE(matched);
}
@@ -1168,10 +1165,9 @@ TEST_F(LiteralUtilTest, ConvertR4) {
{{26, 27, 28, 29}, {30, 31, 32, 33}},
}}, layout_r4_dim0major_);
// clang-format on
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> converted,
- original->Convert(U32));
+ TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.Convert(U32));
- EXPECT_EQ(*expected, *converted);
+ EXPECT_EQ(expected, converted);
}
TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
@@ -1243,69 +1239,65 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
{{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
}}, layout_r4_dim0major_);
// clang-format on
- std::unique_ptr<Literal> conv;
+ Literal conv;
- conv = s8->Convert(U32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *u32);
+ conv = s8.Convert(U32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, u32);
- conv = s8->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s32);
+ conv = s8.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s32);
- conv = s8->Convert(U64).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *u64);
+ conv = s8.Convert(U64).ConsumeValueOrDie();
+ EXPECT_EQ(conv, u64);
- conv = s8->Convert(S64).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s64);
+ conv = s8.Convert(S64).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s64);
- conv = s8->Convert(PRED).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *pred);
+ conv = s8.Convert(PRED).ConsumeValueOrDie();
+ EXPECT_EQ(conv, pred);
- conv = bf16->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s32);
+ conv = bf16.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s32);
- conv = bf16->Convert(F32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f32);
+ conv = bf16.Convert(F32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f32);
- conv = pred->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *int32_pred);
+ conv = pred.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, int32_pred);
- conv = f32->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s32);
+ conv = f32.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s32);
- conv = f64->Convert(S32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *s32);
+ conv = f64.Convert(S32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, s32);
- conv = s32->Convert(F32).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f32);
+ conv = s32.Convert(F32).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f32);
- conv = f32->Convert(F16).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f16);
+ conv = f32.Convert(F16).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f16);
- conv = f64->Convert(F16).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f16);
+ conv = f64.Convert(F16).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f16);
- conv = s32->Convert(F16).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f16);
+ conv = s32.Convert(F16).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f16);
- conv = u32->Convert(F16).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *f16);
+ conv = u32.Convert(F16).ConsumeValueOrDie();
+ EXPECT_EQ(conv, f16);
- conv = s32->Convert(C64).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *c64);
+ conv = s32.Convert(C64).ConsumeValueOrDie();
+ EXPECT_EQ(conv, c64);
- conv = f16->Convert(C64).ConsumeValueOrDie();
- EXPECT_EQ(*conv, *c64);
+ conv = f16.Convert(C64).ConsumeValueOrDie();
+ EXPECT_EQ(conv, c64);
- EXPECT_EQ(s32->Convert(TUPLE).status().code(),
- tensorflow::error::UNIMPLEMENTED);
- EXPECT_EQ(s32->Convert(S16).status().code(),
- tensorflow::error::UNIMPLEMENTED);
- EXPECT_EQ(s32->Convert(U16).status().code(),
- tensorflow::error::UNIMPLEMENTED);
- EXPECT_EQ(c64->Convert(F32).status().code(),
- tensorflow::error::UNIMPLEMENTED);
- EXPECT_EQ(c64->Convert(S32).status().code(),
+ EXPECT_EQ(s32.Convert(TUPLE).status().code(),
tensorflow::error::UNIMPLEMENTED);
+ EXPECT_EQ(s32.Convert(S16).status().code(), tensorflow::error::UNIMPLEMENTED);
+ EXPECT_EQ(s32.Convert(U16).status().code(), tensorflow::error::UNIMPLEMENTED);
+ EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED);
+ EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED);
}
TEST_F(LiteralUtilTest, BitcastConvert) {
@@ -1315,16 +1307,15 @@ TEST_F(LiteralUtilTest, BitcastConvert) {
tensorflow::bit_cast<uint32>(100.f), 0xbeef});
auto expected = LiteralUtil::CreateR1<float>(
{2.5f, -42.25f, 100.0f, tensorflow::bit_cast<float>(0xbeef)});
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> converted,
- original->BitcastConvert(F32));
+ TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.BitcastConvert(F32));
}
TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) {
auto literal = LiteralUtil::CreateR0<uint32>(1234);
- Status status = literal->BitcastConvert(F64).status();
+ Status status = literal.BitcastConvert(F64).status();
EXPECT_NE(Status::OK(), status);
- EXPECT_TRUE(tensorflow::str_util::StrContains(status.error_message(),
- "bit widths are different"));
+ EXPECT_TRUE(
+ absl::StrContains(status.error_message(), "bit widths are different"));
}
TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
@@ -1339,11 +1330,10 @@ TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
p.add_preds((i % 2) == (len % 2));
}
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> literal,
- Literal::CreateFromProto(p));
- ASSERT_EQ(len, literal->data<bool>().size());
+ TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
+ ASSERT_EQ(len, literal.data<bool>().size());
int i = 0;
- for (bool value : literal->data<bool>()) {
+ for (bool value : literal.data<bool>()) {
EXPECT_EQ((i % 2) == (len % 2), value);
++i;
}
@@ -1356,11 +1346,10 @@ TEST_F(LiteralUtilTest, ToProto_f16) {
half h2(2.0f);
auto m = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
- Literal* l = m.get();
- EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape()));
- EXPECT_EQ(4, l->data<half>().size());
+ EXPECT_EQ(4, ShapeUtil::ElementsIn(m.shape()));
+ EXPECT_EQ(4, m.data<half>().size());
- LiteralProto p = l->ToProto();
+ LiteralProto p = m.ToProto();
EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape()));
EXPECT_EQ(8, p.f16s().size());
const char* d = p.f16s().data();
@@ -1387,56 +1376,53 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) {
LayoutUtil::SetToDefaultLayout(p.mutable_shape());
p.clear_f16s();
p.set_f16s(half_vals, 8);
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> literal,
- Literal::CreateFromProto(p));
- auto r = literal->data<half>();
+ TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
+ auto r = literal.data<half>();
ASSERT_EQ(4, r.size());
- ASSERT_EQ(h1, r[0]);
- ASSERT_EQ(h2, r[1]);
- ASSERT_EQ(h2, r[2]);
- ASSERT_EQ(h1, r[3]);
+ EXPECT_EQ(h1, r[0]);
+ EXPECT_EQ(h2, r[1]);
+ EXPECT_EQ(h2, r[2]);
+ EXPECT_EQ(h1, r[3]);
}
TEST_F(LiteralUtilTest, LiteralSliceTest) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
- auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
+ auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
Literal nil(ShapeUtil::MakeNil());
- EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar);
- EXPECT_EQ(LiteralSlice(*matrix, {}), *matrix);
- EXPECT_EQ(LiteralSlice(*tuple, {}), *tuple);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {}), *nested_tuple);
+ EXPECT_EQ(LiteralSlice(scalar, {}), scalar);
+ EXPECT_EQ(LiteralSlice(matrix, {}), matrix);
+ EXPECT_EQ(LiteralSlice(tuple, {}), tuple);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {}), nested_tuple);
EXPECT_EQ(LiteralSlice(nil, {}), nil);
- EXPECT_EQ(LiteralSlice(*tuple, {0}), *scalar);
- EXPECT_EQ(LiteralSlice(*tuple, {1}), *matrix);
+ EXPECT_EQ(LiteralSlice(tuple, {0}), scalar);
+ EXPECT_EQ(LiteralSlice(tuple, {1}), matrix);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {0}), *tuple);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 0}), *scalar);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 1}), *matrix);
- EXPECT_EQ(LiteralSlice(*nested_tuple, {1}), *scalar);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {0}), tuple);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {0, 0}), scalar);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {0, 1}), matrix);
+ EXPECT_EQ(LiteralSlice(nested_tuple, {1}), scalar);
}
TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
- auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
+ auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
// Verify that changing the underlying data beneath the view changes the
// data of the view itself.
- const auto nested_tuple_view = LiteralSlice(*nested_tuple);
- EXPECT_EQ(
- nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
- 1.0f);
+ const auto nested_tuple_view = LiteralSlice(nested_tuple);
+ EXPECT_EQ(nested_tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
+ 1.0f);
EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{},
/*shape_index=*/{0, 0}),
1.0f);
- nested_tuple->Set<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f);
- EXPECT_EQ(
- nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
- 555.0f);
+ nested_tuple.Set<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f);
+ EXPECT_EQ(nested_tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
+ 555.0f);
EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{},
/*shape_index=*/{0, 0}),
555.0f);
@@ -1445,14 +1431,14 @@ TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) {
auto scalar = LiteralUtil::CreateR0<float>(1.0);
auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
- auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
+ auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
+ auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
- const auto nested_tuple_view = LiteralSlice(*nested_tuple);
+ const auto nested_tuple_view = LiteralSlice(nested_tuple);
const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0});
const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1});
EXPECT_EQ(matrix_view,
- *LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
}
TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) {
@@ -1495,9 +1481,8 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) {
}
TEST_F(LiteralUtilTest, LiteralMove) {
- std::unique_ptr<Literal> matrix =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- Literal literal(std::move(*matrix));
+ Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ Literal literal(std::move(matrix));
EXPECT_TRUE(
ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));
@@ -1509,17 +1494,21 @@ TEST_F(LiteralUtilTest, LiteralMove) {
TEST_F(LiteralUtilTest, DecomposeTuple) {
Literal nil_literal(ShapeUtil::MakeNil());
- auto nested_tuple = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}}).get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR1<double>({23.0, 44.0}).get(), &nil_literal})
- .get(),
- &nil_literal});
-
- EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple->shape()));
- std::vector<Literal> elements = nested_tuple->DecomposeTuple();
- EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple->shape()));
+ Literal inner_elements[] = {
+ LiteralUtil::CreateR0<int32>(42),
+ LiteralUtil::CreateR1<double>({23.0, 44.0}),
+ };
+ Literal tuple_elements[] = {
+ LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}}),
+ LiteralUtil::MakeTuple(
+ {&inner_elements[0], &inner_elements[1], &nil_literal}),
+ };
+ Literal nested_tuple = LiteralUtil::MakeTuple(
+ {&tuple_elements[0], &tuple_elements[1], &nil_literal});
+
+ EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple.shape()));
+ std::vector<Literal> elements = nested_tuple.DecomposeTuple();
+ EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple.shape()));
ASSERT_EQ(elements.size(), 3);
@@ -1550,15 +1539,15 @@ TEST_F(LiteralUtilTest, DecomposeEmptyTuple) {
TEST_F(LiteralUtilTest, MoveIntoTuple) {
std::vector<Literal> elements;
- elements.push_back(std::move(*LiteralUtil::CreateR0<float>(1.0)));
- elements.push_back(std::move(*LiteralUtil::CreateR1<int32>({4, 8})));
- elements.push_back(std::move(*LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR1<double>({23.0, 44.0}).get()})
-
- ));
-
- Literal literal = Literal::MoveIntoTuple(&elements);
+ elements.push_back(LiteralUtil::CreateR0<float>(1.0));
+ elements.push_back(LiteralUtil::CreateR1<int32>({4, 8}));
+ std::vector<Literal> inner_elements;
+ inner_elements.push_back(LiteralUtil::CreateR0<int32>(42));
+ inner_elements.push_back(LiteralUtil::CreateR1<double>({23.0, 44.0}));
+ elements.push_back(
+ LiteralUtil::MakeTuple({&inner_elements[0], &inner_elements[1]}));
+
+ Literal literal = Literal::MoveIntoTuple(absl::MakeSpan(elements));
ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape()));
ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 3);
@@ -1577,16 +1566,15 @@ TEST_F(LiteralUtilTest, MoveIntoTuple) {
TEST_F(LiteralUtilTest, MoveIntoEmptyTuple) {
Literal literal = Literal::MoveIntoTuple({});
ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape()));
- ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 0);
+ EXPECT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 0);
}
TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
Literal literal;
EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape()));
- std::unique_ptr<Literal> matrix =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- literal = std::move(*matrix);
+ Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ literal = std::move(matrix);
EXPECT_TRUE(
ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));
@@ -1597,9 +1585,8 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
}
TEST_F(LiteralUtilTest, LiteralSliceCopy) {
- std::unique_ptr<Literal> matrix =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- const auto matrix_view = LiteralSlice(*matrix);
+ Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ const auto matrix_view = LiteralSlice(matrix);
LiteralSlice matrix_view_copy(matrix_view);
EXPECT_EQ(matrix_view_copy.Get<float>({0, 0}), 1.0);
@@ -1609,45 +1596,43 @@ TEST_F(LiteralUtilTest, LiteralSliceCopy) {
}
TEST_F(LiteralUtilTest, GetSetTuple) {
- auto tuple = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(42.0).get(),
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get()});
- EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0);
- tuple->Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
- EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0);
-
- EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
- 3.0);
- tuple->Set<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0);
- EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
+ Literal elements[] = {
+ LiteralUtil::CreateR0<float>(42.0),
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+ };
+ auto tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
+ EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0);
+ tuple.Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
+ EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0);
+
+ EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), 3.0);
+ tuple.Set<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0);
+ EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
-4.0);
}
TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) {
  // Literals constructed using CreateFromShape should be zero-initialized.
- std::unique_ptr<Literal> scalar_f32 =
- Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {}));
- EXPECT_EQ(scalar_f32->Get<float>({}), 0.0);
- EXPECT_TRUE(scalar_f32->IsAll(0));
-
- std::unique_ptr<Literal> vector_s32 =
- Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3}));
- EXPECT_EQ(vector_s32->Get<int32>({0}), 0);
- EXPECT_EQ(vector_s32->Get<int32>({1}), 0);
- EXPECT_EQ(vector_s32->Get<int32>({2}), 0);
- EXPECT_TRUE(vector_s32->IsAll(0));
-
- std::unique_ptr<Literal> tuple =
- Literal::CreateFromShape(ShapeUtil::MakeTupleShape(
- {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
- ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})}));
-
- EXPECT_EQ(tuple->Get<double>({}, {0}), 0.0);
- EXPECT_EQ(tuple->Get<bool>({0}, {1}), false);
- EXPECT_EQ(tuple->Get<bool>({1}, {1}), false);
- EXPECT_EQ(tuple->Get<uint64>({0, 0}, {2}), 0);
- EXPECT_EQ(tuple->Get<uint64>({1, 0}, {2}), 0);
- EXPECT_EQ(tuple->Get<complex64>({}, {3}), complex64(0.0f, 0.0f));
+ Literal scalar_f32 = Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {}));
+ EXPECT_EQ(scalar_f32.Get<float>({}), 0.0);
+ EXPECT_TRUE(scalar_f32.IsAll(0));
+
+ Literal vector_s32 = Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3}));
+ EXPECT_EQ(vector_s32.Get<int32>({0}), 0);
+ EXPECT_EQ(vector_s32.Get<int32>({1}), 0);
+ EXPECT_EQ(vector_s32.Get<int32>({2}), 0);
+ EXPECT_TRUE(vector_s32.IsAll(0));
+
+ Literal tuple = Literal::CreateFromShape(ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
+ ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})}));
+
+ EXPECT_EQ(tuple.Get<double>({}, {0}), 0.0);
+ EXPECT_EQ(tuple.Get<bool>({0}, {1}), false);
+ EXPECT_EQ(tuple.Get<bool>({1}, {1}), false);
+ EXPECT_EQ(tuple.Get<uint64>({0, 0}, {2}), 0);
+ EXPECT_EQ(tuple.Get<uint64>({1, 0}, {2}), 0);
+ EXPECT_EQ(tuple.Get<complex64>({}, {3}), complex64(0.0f, 0.0f));
}
TEST_F(LiteralUtilTest, ProtoRoundTrip) {
@@ -1663,25 +1648,25 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) {
auto matrix_pred =
LiteralUtil::CreateR2<bool>({{true, false, true}, {false, false, true}});
auto tuple = LiteralUtil::MakeTuple(
- {one_f32.get(), vector_half.get(), matrix_pred.get(), matrix_pred.get()});
+ {&one_f32, &vector_half, &matrix_pred, &matrix_pred});
Literal nil_literal(ShapeUtil::MakeNil());
- auto nested_tuple = LiteralUtil::MakeTuple(
- {tuple.get(), vector_bfloat16.get(), tuple.get(), &nil_literal});
+ auto nested_tuple =
+ LiteralUtil::MakeTuple({&tuple, &vector_bfloat16, &tuple, &nil_literal});
auto to_from_proto = [](const Literal& literal) -> Literal {
- return std::move(*Literal::CreateFromProto(literal.ToProto()).ValueOrDie());
+ return Literal::CreateFromProto(literal.ToProto()).ValueOrDie();
};
- EXPECT_EQ(*one_f32, to_from_proto(*one_f32));
- EXPECT_EQ(*vector_c64, to_from_proto(*vector_c64));
- EXPECT_EQ(*vector_bfloat16, to_from_proto(*vector_bfloat16));
- EXPECT_EQ(*matrix_pred, to_from_proto(*matrix_pred));
- EXPECT_EQ(*tuple, to_from_proto(*tuple));
- EXPECT_EQ(*nested_tuple, to_from_proto(*nested_tuple));
+ EXPECT_EQ(one_f32, to_from_proto(one_f32));
+ EXPECT_EQ(vector_c64, to_from_proto(vector_c64));
+ EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16));
+ EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred));
+ EXPECT_EQ(tuple, to_from_proto(tuple));
+ EXPECT_EQ(nested_tuple, to_from_proto(nested_tuple));
EXPECT_EQ(nil_literal, to_from_proto(nil_literal));
- EXPECT_NE(*one_f32, *two_f32);
- EXPECT_NE(*one_f32, to_from_proto(*two_f32));
+ EXPECT_NE(one_f32, two_f32);
+ EXPECT_NE(one_f32, to_from_proto(two_f32));
}
TEST_F(LiteralUtilTest, InvalidProtoNoValues) {
@@ -1690,7 +1675,7 @@ TEST_F(LiteralUtilTest, InvalidProtoNoValues) {
*proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3});
Status status = Literal::CreateFromProto(proto).status();
ASSERT_FALSE(status.ok());
- ASSERT_THAT(status.error_message(),
+ EXPECT_THAT(status.error_message(),
HasSubstr("Expected 3 elements in LiteralProto"));
}
@@ -1702,7 +1687,7 @@ TEST_F(LiteralUtilTest, InvalidProtoNoShape) {
proto.add_preds(false);
Status status = Literal::CreateFromProto(proto).status();
ASSERT_FALSE(status.ok());
- ASSERT_THAT(status.error_message(), HasSubstr("LiteralProto has no shape"));
+ EXPECT_THAT(status.error_message(), HasSubstr("LiteralProto has no shape"));
}
TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) {
@@ -1714,7 +1699,7 @@ TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) {
proto.add_preds(false);
Status status = Literal::CreateFromProto(proto).status();
ASSERT_FALSE(status.ok());
- ASSERT_THAT(status.error_message(),
+ EXPECT_THAT(status.error_message(),
HasSubstr("Expected 3 elements in LiteralProto"));
}
@@ -1727,7 +1712,7 @@ TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) {
proto.add_f32s(3.0);
Status status = Literal::CreateFromProto(proto).status();
ASSERT_FALSE(status.ok());
- ASSERT_THAT(status.error_message(),
+ EXPECT_THAT(status.error_message(),
HasSubstr("Expected 84 elements in LiteralProto"));
}
@@ -1740,7 +1725,7 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) {
proto.add_s32s(100);
Status status = Literal::CreateFromProto(proto).status();
ASSERT_FALSE(status.ok());
- ASSERT_THAT(status.error_message(),
+ EXPECT_THAT(status.error_message(),
HasSubstr("Expected 2 elements in LiteralProto"));
}
@@ -1755,7 +1740,7 @@ TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) {
proto.add_preds(false);
Status status = Literal::CreateFromProto(proto).status();
ASSERT_FALSE(status.ok());
- ASSERT_THAT(status.error_message(), HasSubstr("LiteralProto has no layout"));
+ EXPECT_THAT(status.error_message(), HasSubstr("LiteralProto has no layout"));
}
TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) {
@@ -1771,7 +1756,7 @@ TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) {
Status status = Literal::CreateFromProto(proto).status();
ASSERT_FALSE(status.ok());
- ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements"));
+ EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements"));
}
TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) {
@@ -1794,17 +1779,17 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) {
Status status = Literal::CreateFromProto(proto).status();
ASSERT_FALSE(status.ok());
- ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements"));
+ EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements"));
}
TEST_F(LiteralUtilTest, SortSparseElements) {
auto literal = LiteralUtil::CreateSparse<float>({10, 10, 10},
SparseIndexArray(10, 3), {});
- literal->AppendSparseElement<float>({2, 3, 4}, 2.0);
- literal->AppendSparseElement<float>({3, 4, 5}, 3.0);
- literal->AppendSparseElement<float>({1, 2, 3}, 1.0);
- literal->SortSparseElements();
- ASSERT_EQ(literal->ToString(false),
+ literal.AppendSparseElement<float>({2, 3, 4}, 2.0);
+ literal.AppendSparseElement<float>({3, 4, 5}, 3.0);
+ literal.AppendSparseElement<float>({1, 2, 3}, 1.0);
+ literal.SortSparseElements();
+ EXPECT_EQ(literal.ToString(false),
"f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}");
}
@@ -1812,60 +1797,56 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) {
std::vector<int64> dimensions = {10, 10, 10};
SparseIndexArray indices(10, {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}});
- ASSERT_EQ(
+ EXPECT_EQ(
LiteralUtil::CreateSparse<bool>(dimensions, indices, {true, false, true})
- ->GetSparseElementAsString(1),
+ .GetSparseElementAsString(1),
"false");
- ASSERT_EQ(LiteralUtil::CreateSparse<int64>(dimensions, indices, {1, 2, 3})
- ->GetSparseElementAsString(1),
- tensorflow::strings::StrCat(int64{2}));
- ASSERT_EQ(
+ EXPECT_EQ(LiteralUtil::CreateSparse<int64>(dimensions, indices, {1, 2, 3})
+ .GetSparseElementAsString(1),
+ absl::StrCat(int64{2}));
+ EXPECT_EQ(
LiteralUtil::CreateSparse<double>(dimensions, indices, {1.0, 2.0, 3.0})
- ->GetSparseElementAsString(1),
- tensorflow::strings::StrCat(double{2.0}));
- ASSERT_EQ(LiteralUtil::CreateSparse<half>(dimensions, indices,
+ .GetSparseElementAsString(1),
+ absl::StrCat(double{2.0}));
+ EXPECT_EQ(LiteralUtil::CreateSparse<half>(dimensions, indices,
{half{1.0}, half{2.0}, half{3.0}})
- ->GetSparseElementAsString(1),
- tensorflow::strings::StrCat(static_cast<float>(half{2.0})));
- ASSERT_EQ(
- LiteralUtil::CreateSparse<complex64>(
- dimensions, indices,
- std::vector<complex64>{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})
- ->GetSparseElementAsString(1),
- tensorflow::strings::StrCat("(", float{3.0}, ", ", float{4.0}, ")"));
+ .GetSparseElementAsString(1),
+ absl::StrCat(static_cast<float>(half{2.0})));
+ EXPECT_EQ(LiteralUtil::CreateSparse<complex64>(
+ dimensions, indices,
+ std::vector<complex64>{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})
+ .GetSparseElementAsString(1),
+ absl::StrCat("(", float{3.0}, ", ", float{4.0}, ")"));
}
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int64>({1, 2});
+ Literal literal = LiteralUtil::CreateR1<int64>({1, 2});
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> broadcasted_literal,
- literal->Broadcast(
- /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
- /*dimensions=*/{0}));
- EXPECT_EQ(*broadcasted_literal,
- *LiteralUtil::CreateR2<int64>({{1, 1}, {2, 2}}));
+ Literal broadcasted_literal,
+ literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
+ /*dimensions=*/{0}));
+ EXPECT_EQ(broadcasted_literal,
+ LiteralUtil::CreateR2<int64>({{1, 1}, {2, 2}}));
}
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int64>({1, 2});
+ Literal literal = LiteralUtil::CreateR1<int64>({1, 2});
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> broadcasted_literal,
- literal->Broadcast(
- /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
- /*dimensions=*/{1}));
- EXPECT_EQ(*broadcasted_literal,
- *LiteralUtil::CreateR2<int64>({{1, 2}, {1, 2}}));
+ Literal broadcasted_literal,
+ literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
+ /*dimensions=*/{1}));
+ EXPECT_EQ(broadcasted_literal,
+ LiteralUtil::CreateR2<int64>({{1, 2}, {1, 2}}));
}
TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<int32>(9);
+ Literal literal = LiteralUtil::CreateR0<int32>(9);
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> broadcasted_literal,
- literal->Broadcast(
- /*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}),
- /*dimensions=*/{}));
- EXPECT_EQ(*broadcasted_literal,
- *LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}}));
+ Literal broadcasted_literal,
+ literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}),
+ /*dimensions=*/{}));
+ EXPECT_EQ(broadcasted_literal,
+ LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}}));
}
} // namespace
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 5d33df7d40..0cb1ae35f4 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -22,6 +22,9 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -30,23 +33,19 @@ limitations under the License.
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
-using tensorflow::strings::StrCat;
-
namespace xla {
-
namespace {
+using absl::StrCat;
+
// Returns a copy of the given literal with all arrays of type FromNativeT
// converted to type ToNativeT.
template <typename FromNativeT, typename ToNativeT>
-std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
+Literal ConvertType(LiteralSlice literal) {
  // First, construct the shape of the result.
Shape result_shape(literal.shape());
ShapeUtil::ForEachMutableSubshape(
@@ -57,7 +56,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
primitive_util::NativeToPrimitiveType<ToNativeT>());
}
});
- auto result = MakeUnique<Literal>(result_shape);
+ Literal result(result_shape);
// Then copy over the data from 'literal' converting FromNativeT values to
// ToNativeT values as necessary.
@@ -68,14 +67,14 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
if (subshape.element_type() ==
primitive_util::NativeToPrimitiveType<FromNativeT>()) {
auto src = literal.data<FromNativeT>(shape_index);
- auto dest = result->data<ToNativeT>(shape_index);
+ auto dest = result.data<ToNativeT>(shape_index);
for (int64 i = 0; i < src.size(); ++i) {
dest[i] = static_cast<ToNativeT>(src[i]);
}
} else {
- TF_CHECK_OK(result->CopyFrom(literal,
- /*dest_shape_index=*/shape_index,
- /*src_shape_index=*/shape_index));
+ TF_CHECK_OK(result.CopyFrom(literal,
+ /*dest_shape_index=*/shape_index,
+ /*src_shape_index=*/shape_index));
}
}
});
@@ -84,54 +83,52 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
} // namespace
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromDimensions(
- PrimitiveType primitive_type,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
+/* static */ Literal LiteralUtil::CreateFromDimensions(
+ PrimitiveType primitive_type, absl::Span<const int64> dimensions) {
return Literal::CreateFromShape(
ShapeUtil::MakeShape(primitive_type, dimensions));
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::ConvertBF16ToF32(
+/* static */ Literal LiteralUtil::ConvertBF16ToF32(
const LiteralSlice& bf16_literal) {
return ConvertType<bfloat16, float>(bf16_literal);
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::ConvertF32ToBF16(
+/* static */ Literal LiteralUtil::ConvertF32ToBF16(
const LiteralSlice& f32_literal) {
return ConvertType<float, bfloat16>(f32_literal);
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateToken() {
- return MakeUnique<Literal>(ShapeUtil::MakeTokenShape());
+/* static */ Literal LiteralUtil::CreateToken() {
+ return Literal(ShapeUtil::MakeTokenShape());
}
/* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return std::move(*LiteralUtil::CreateR0<uint8>(0));
+ return LiteralUtil::CreateR0<uint8>(0);
case U32:
- return std::move(*LiteralUtil::CreateR0<uint32>(0));
+ return LiteralUtil::CreateR0<uint32>(0);
case U64:
- return std::move(*LiteralUtil::CreateR0<uint64>(0));
+ return LiteralUtil::CreateR0<uint64>(0);
case S8:
- return std::move(*LiteralUtil::CreateR0<int8>(0));
+ return LiteralUtil::CreateR0<int8>(0);
case S32:
- return std::move(*LiteralUtil::CreateR0<int32>(0));
+ return LiteralUtil::CreateR0<int32>(0);
case S64:
- return std::move(*LiteralUtil::CreateR0<int64>(0));
+ return LiteralUtil::CreateR0<int64>(0);
case F16:
- return std::move(*LiteralUtil::CreateR0<half>(static_cast<half>(0.0f)));
+ return LiteralUtil::CreateR0<half>(static_cast<half>(0.0f));
case BF16:
- return std::move(
- *LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f)));
+ return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
case F32:
- return std::move(*LiteralUtil::CreateR0<float>(0));
+ return LiteralUtil::CreateR0<float>(0);
case F64:
- return std::move(*LiteralUtil::CreateR0<double>(0));
+ return LiteralUtil::CreateR0<double>(0);
case C64:
- return std::move(*LiteralUtil::CreateR0<complex64>(0));
+ return LiteralUtil::CreateR0<complex64>(0);
case PRED:
- return std::move(*LiteralUtil::CreateR0<bool>(false));
+ return LiteralUtil::CreateR0<bool>(false);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
@@ -147,30 +144,29 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
/* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return std::move(*LiteralUtil::CreateR0<uint8>(1));
+ return LiteralUtil::CreateR0<uint8>(1);
case U32:
- return std::move(*LiteralUtil::CreateR0<uint32>(1));
+ return LiteralUtil::CreateR0<uint32>(1);
case U64:
- return std::move(*LiteralUtil::CreateR0<uint64>(1));
+ return LiteralUtil::CreateR0<uint64>(1);
case S8:
- return std::move(*LiteralUtil::CreateR0<int8>(1));
+ return LiteralUtil::CreateR0<int8>(1);
case S32:
- return std::move(*LiteralUtil::CreateR0<int32>(1));
+ return LiteralUtil::CreateR0<int32>(1);
case S64:
- return std::move(*LiteralUtil::CreateR0<int64>(1));
+ return LiteralUtil::CreateR0<int64>(1);
case F16:
- return std::move(*LiteralUtil::CreateR0<half>(static_cast<half>(1.0f)));
+ return LiteralUtil::CreateR0<half>(static_cast<half>(1.0f));
case BF16:
- return std::move(
- *LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f)));
+ return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f));
case F32:
- return std::move(*LiteralUtil::CreateR0<float>(1));
+ return LiteralUtil::CreateR0<float>(1);
case F64:
- return std::move(*LiteralUtil::CreateR0<double>(1));
+ return LiteralUtil::CreateR0<double>(1);
case C64:
- return std::move(*LiteralUtil::CreateR0<complex64>(1));
+ return LiteralUtil::CreateR0<complex64>(1);
case PRED:
- return std::move(*LiteralUtil::CreateR0<bool>(true));
+ return LiteralUtil::CreateR0<bool>(true);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
@@ -186,42 +182,36 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
/* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return std::move(
- *LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min()));
+ return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min());
case U32:
- return std::move(
- *LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min()));
+ return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min());
case U64:
- return std::move(
- *LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min()));
+ return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min());
case S8:
- return std::move(
- *LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min()));
+ return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min());
case S32:
- return std::move(
- *LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min()));
+ return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min());
case S64:
- return std::move(
- *LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min()));
+ return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min());
case F32:
- return std::move(*LiteralUtil::CreateR0<float>(
- -std::numeric_limits<float>::infinity()));
+ return LiteralUtil::CreateR0<float>(
+ -std::numeric_limits<float>::infinity());
case F64:
- return std::move(*LiteralUtil::CreateR0<double>(
- -std::numeric_limits<double>::infinity()));
+ return LiteralUtil::CreateR0<double>(
+ -std::numeric_limits<double>::infinity());
case C64:
LOG(FATAL) << "C64 element type has no minimum value";
case PRED:
- return std::move(*LiteralUtil::CreateR0<bool>(false));
+ return LiteralUtil::CreateR0<bool>(false);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case F16:
- return std::move(*LiteralUtil::CreateR0<half>(
- static_cast<half>(-std::numeric_limits<float>::infinity())));
+ return LiteralUtil::CreateR0<half>(
+ static_cast<half>(-std::numeric_limits<float>::infinity()));
case BF16:
- return std::move(*LiteralUtil::CreateR0<bfloat16>(
- static_cast<bfloat16>(-std::numeric_limits<float>::infinity())));
+ return LiteralUtil::CreateR0<bfloat16>(
+ static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
case TUPLE:
LOG(FATAL) << "tuple element type has no minimum value";
case OPAQUE:
@@ -234,40 +224,34 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
/* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return std::move(
- *LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max()));
+ return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max());
case U32:
- return std::move(
- *LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max()));
+ return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max());
case U64:
- return std::move(
- *LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max()));
+ return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max());
case S8:
- return std::move(
- *LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max()));
+ return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max());
case S32:
- return std::move(
- *LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max()));
+ return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max());
case S64:
- return std::move(
- *LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max()));
+ return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max());
case F32:
- return std::move(*LiteralUtil::CreateR0<float>(
- std::numeric_limits<float>::infinity()));
+ return LiteralUtil::CreateR0<float>(
+ std::numeric_limits<float>::infinity());
case F64:
- return std::move(*LiteralUtil::CreateR0<double>(
- std::numeric_limits<double>::infinity()));
+ return LiteralUtil::CreateR0<double>(
+ std::numeric_limits<double>::infinity());
case PRED:
- return std::move(*LiteralUtil::CreateR0<bool>(true));
+ return LiteralUtil::CreateR0<bool>(true);
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case F16:
- return std::move(*LiteralUtil::CreateR0<half>(
- static_cast<half>(std::numeric_limits<float>::infinity())));
+ return LiteralUtil::CreateR0<half>(
+ static_cast<half>(std::numeric_limits<float>::infinity()));
case BF16:
- return std::move(*LiteralUtil::CreateR0<bfloat16>(
- static_cast<bfloat16>(std::numeric_limits<float>::infinity())));
+ return LiteralUtil::CreateR0<bfloat16>(
+ static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
case TUPLE:
LOG(FATAL) << "tuple element type has no maximum value";
case OPAQUE:
@@ -277,34 +261,31 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
}
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
+/* static */ Literal LiteralUtil::CreateR1(
const tensorflow::core::Bitmap& values) {
- auto literal = MakeUnique<Literal>(
+ Literal literal(
ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())}));
- literal->PopulateR1(values);
+ literal.PopulateR1(values);
return literal;
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1U8(
- tensorflow::StringPiece value) {
- auto literal = MakeUnique<Literal>(
- ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
+/* static */ Literal LiteralUtil::CreateR1U8(absl::string_view value) {
+ Literal literal(ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
for (int i = 0; i < value.size(); ++i) {
- literal->Set<uint8>({i}, value[i]);
+ literal.Set<uint8>({i}, value[i]);
}
return literal;
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2F32Linspace(
- float from, float to, int64 rows, int64 cols) {
+/* static */ Literal LiteralUtil::CreateR2F32Linspace(float from, float to,
+ int64 rows, int64 cols) {
auto value = MakeLinspaceArray2D(from, to, rows, cols);
return CreateR2FromArray2D(*value);
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::ReshapeSlice(
- tensorflow::gtl::ArraySlice<int64> new_dimensions,
- tensorflow::gtl::ArraySlice<int64> minor_to_major,
- const LiteralSlice& literal) {
+/* static */ Literal LiteralUtil::ReshapeSlice(
+ absl::Span<const int64> new_dimensions,
+ absl::Span<const int64> minor_to_major, const LiteralSlice& literal) {
int64 new_num_elements = 1;
for (int64 i = 0; i < new_dimensions.size(); ++i) {
new_num_elements *= new_dimensions[i];
@@ -312,13 +293,13 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
CHECK_EQ(new_dimensions.size(), minor_to_major.size());
- auto new_literal = MakeUnique<Literal>(
+ Literal new_literal(
ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
// Create a new shape with the given minor-to-major layout. This shape is used
// solely for converting linear addresses to multi-dimensional addresses when
// writing elements to the new literal.
- Shape shape_with_layout = new_literal->shape();
+ Shape shape_with_layout = new_literal.shape();
*shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
// Copy data into new literal, element-by-element.
@@ -329,40 +310,40 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
switch (literal.shape().element_type()) {
case PRED:
- new_literal->Set<bool>(to_multi_index,
- literal.Get<bool>(from_multi_index));
+ new_literal.Set<bool>(to_multi_index,
+ literal.Get<bool>(from_multi_index));
break;
case U8:
- new_literal->Set<uint8>(to_multi_index,
- literal.Get<uint8>(from_multi_index));
+ new_literal.Set<uint8>(to_multi_index,
+ literal.Get<uint8>(from_multi_index));
break;
case U32:
- new_literal->Set<uint32>(to_multi_index,
- literal.Get<uint32>(from_multi_index));
+ new_literal.Set<uint32>(to_multi_index,
+ literal.Get<uint32>(from_multi_index));
break;
case S32:
- new_literal->Set<int32>(to_multi_index,
- literal.Get<int32>(from_multi_index));
+ new_literal.Set<int32>(to_multi_index,
+ literal.Get<int32>(from_multi_index));
break;
case U64:
- new_literal->Set<uint64>(to_multi_index,
- literal.Get<uint64>(from_multi_index));
+ new_literal.Set<uint64>(to_multi_index,
+ literal.Get<uint64>(from_multi_index));
break;
case S64:
- new_literal->Set<int64>(to_multi_index,
- literal.Get<int64>(from_multi_index));
+ new_literal.Set<int64>(to_multi_index,
+ literal.Get<int64>(from_multi_index));
break;
case F32:
- new_literal->Set<float>(to_multi_index,
- literal.Get<float>(from_multi_index));
+ new_literal.Set<float>(to_multi_index,
+ literal.Get<float>(from_multi_index));
break;
case F64:
- new_literal->Set<double>(to_multi_index,
- literal.Get<double>(from_multi_index));
+ new_literal.Set<double>(to_multi_index,
+ literal.Get<double>(from_multi_index));
break;
case C64:
- new_literal->Set<complex64>(to_multi_index,
- literal.Get<complex64>(from_multi_index));
+ new_literal.Set<complex64>(to_multi_index,
+ literal.Get<complex64>(from_multi_index));
break;
default:
LOG(FATAL) << "Unhandled primitive element type: "
@@ -379,101 +360,89 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0);
switch (literal.shape().element_type()) {
case PRED:
- return std::move(
- *LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>()));
+ return LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>());
// 8 bit types.
case S8:
- return std::move(
- *LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>()));
+ return LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>());
case U8:
- return std::move(
- *LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>()));
+ return LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>());
// 16 bit types.
case BF16:
- return std::move(*LiteralUtil::CreateR0<bfloat16>(
- literal.GetFirstElement<bfloat16>()));
+ return LiteralUtil::CreateR0<bfloat16>(
+ literal.GetFirstElement<bfloat16>());
case F16:
- return std::move(
- *LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>()));
+ return LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>());
case S16:
- return std::move(
- *LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>()));
+ return LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>());
case U16:
- return std::move(
- *LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>()));
+ return LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>());
// 32 bit types.
case F32:
- return std::move(
- *LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>()));
+ return LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>());
case S32:
- return std::move(
- *LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>()));
+ return LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>());
case U32:
- return std::move(
- *LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>()));
+ return LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>());
// 64 bit types.
case C64:
- return std::move(*LiteralUtil::CreateR0<complex64>(
- literal.GetFirstElement<complex64>()));
+ return LiteralUtil::CreateR0<complex64>(
+ literal.GetFirstElement<complex64>());
case F64:
- return std::move(
- *LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>()));
+ return LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>());
case S64:
- return std::move(
- *LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>()));
+ return LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>());
case U64:
- return std::move(
- *LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>()));
+ return LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>());
default:
LOG(FATAL) << "Unhandled primitive type "
<< literal.shape().element_type();
}
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTuple(
- tensorflow::gtl::ArraySlice<const Literal*> elements) {
+/* static */ Literal LiteralUtil::MakeTuple(
+ absl::Span<const Literal* const> elements) {
std::vector<Shape> element_shapes;
for (const auto* element : elements) {
element_shapes.push_back(element->shape());
}
- auto literal = MakeUnique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+ Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
for (int i = 0; i < elements.size(); ++i) {
- TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
+ TF_CHECK_OK(literal.CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
}
return literal;
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTupleFromSlices(
- tensorflow::gtl::ArraySlice<LiteralSlice> elements) {
+/* static */ Literal LiteralUtil::MakeTupleFromSlices(
+ absl::Span<const LiteralSlice> elements) {
std::vector<Shape> element_shapes;
for (const auto& element : elements) {
element_shapes.push_back(element.shape());
}
- auto literal = MakeUnique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+ Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
for (int i = 0; i < elements.size(); ++i) {
- TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i}));
+ TF_CHECK_OK(literal.CopyFrom(elements[i], /*dest_shape_index=*/{i}));
}
return literal;
}
-/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTupleOwned(
- std::vector<std::unique_ptr<Literal>> elements) {
+/* static */ Literal LiteralUtil::MakeTupleOwned(
+ std::vector<Literal> elements) {
std::vector<Shape> element_shapes;
element_shapes.reserve(elements.size());
for (const auto& element : elements) {
- element_shapes.push_back(element->shape());
+ element_shapes.push_back(element.shape());
}
- auto literal = MakeUnique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+ Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
for (int64 i = 0; i < elements.size(); ++i) {
TF_CHECK_OK(
- literal->MoveFrom(std::move(*elements[i]), /*dest_shape_index=*/{i}));
+ literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
}
return literal;
}
/* static */ string LiteralUtil::MultiIndexAsString(
- tensorflow::gtl::ArraySlice<int64> multi_index) {
- return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}");
+ absl::Span<const int64> multi_index) {
+ return StrCat("{", absl::StrJoin(multi_index, ","), "}");
}
} // namespace xla
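Nearly every hunk in this file is the same mechanical rewrite: return std::move(*LiteralUtil::CreateR0<T>(v)), which dereferenced a heap-allocated literal and moved out of it, collapses to return LiteralUtil::CreateR0<T>(v) once the factory returns by value. A sketch of the pattern under the new API (helper name hypothetical):

    // One heap allocation plus a move-out before; now the returned
    // prvalue is constructed in place (copy elision), with at most a
    // cheap move of the Literal handle.
    Literal ZeroF32() {
      return LiteralUtil::CreateR0<float>(0.0f);
    }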
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index e3737a9d00..2b181621ed 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -27,6 +27,9 @@ limitations under the License.
#include <type_traits>
#include <vector>
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -34,7 +37,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/sparse_index_array.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -43,8 +45,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -69,37 +69,34 @@ class LiteralUtil {
// The variants not ending with WithLayout use the default XLA layout for the
// literal's linear representation in memory.
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR0(NativeT value);
+ static Literal CreateR0(NativeT value);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR1(
- tensorflow::gtl::ArraySlice<NativeT> values);
- static std::unique_ptr<Literal> CreateR1(
- const tensorflow::core::Bitmap& values);
+ static Literal CreateR1(absl::Span<const NativeT> values);
+ static Literal CreateR1(const tensorflow::core::Bitmap& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR2(
+ static Literal CreateR2(
std::initializer_list<std::initializer_list<NativeT>> values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR2WithLayout(
+ static Literal CreateR2WithLayout(
std::initializer_list<std::initializer_list<NativeT>> values,
const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3(
- std::initializer_list<
- std::initializer_list<std::initializer_list<NativeT>>>
- values);
+ static Literal CreateR3(std::initializer_list<
+ std::initializer_list<std::initializer_list<NativeT>>>
+ values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3WithLayout(
+ static Literal CreateR3WithLayout(
std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>
values,
const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4(
+ static Literal CreateR4(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4WithLayout(
+ static Literal CreateR4WithLayout(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values,
@@ -140,9 +137,10 @@ class LiteralUtil {
// [9, 10, 11]: 4.0
//
template <typename NativeT>
- static std::unique_ptr<Literal> CreateSparse(
- tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
- tensorflow::gtl::ArraySlice<NativeT> values, bool sort = true);
+ static Literal CreateSparse(absl::Span<const int64> dimensions,
+ SparseIndexArray indices,
+ absl::Span<const NativeT> values,
+ bool sort = true);
// Creates a scalar literal value zero of the given primitive type.
static Literal Zero(PrimitiveType primitive_type);
@@ -156,132 +154,120 @@ class LiteralUtil {
static Literal MaxValue(PrimitiveType primitive_type);
// Creates a literal of the given shape where each element is `value`.
template <typename NativeT>
- static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
- tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value);
+ static Literal CreateFullWithDescendingLayout(
+ absl::Span<const int64> dimensions, NativeT value);
// Creates a new literal from an Array type. The variants not ending with
// WithLayout use the default XLA layout for the literal's linear
// representation in memory.
template <typename NativeT>
- static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values);
+ static Literal CreateFromArray(const Array<NativeT>& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateFromArrayWithLayout(
- const Array<NativeT>& values, const Layout& layout);
+ static Literal CreateFromArrayWithLayout(const Array<NativeT>& values,
+ const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR2FromArray2D(
- const Array2D<NativeT>& values);
+ static Literal CreateR2FromArray2D(const Array2D<NativeT>& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR2FromArray2DWithLayout(
- const Array2D<NativeT>& values, const Layout& layout);
+ static Literal CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
+ const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3FromArray3D(
- const Array3D<NativeT>& values);
+ static Literal CreateR3FromArray3D(const Array3D<NativeT>& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3FromArray3DWithLayout(
- const Array3D<NativeT>& values, const Layout& layout);
+ static Literal CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
+ const Layout& layout);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4FromArray4D(
- const Array4D<NativeT>& values);
+ static Literal CreateR4FromArray4D(const Array4D<NativeT>& values);
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4FromArray4DWithLayout(
- const Array4D<NativeT>& values, const Layout& layout);
+ static Literal CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
+ const Layout& layout);
// Creates a new vector of U8s literal value from a string.
- static std::unique_ptr<Literal> CreateR1U8(tensorflow::StringPiece value);
+ static Literal CreateR1U8(absl::string_view value);
// Creates a linspace-populated literal with the given number of rows and
// columns.
- static std::unique_ptr<Literal> CreateR2F32Linspace(float from, float to,
- int64 rows, int64 cols);
+ static Literal CreateR2F32Linspace(float from, float to, int64 rows,
+ int64 cols);
// Creates a literal that projects the (x, y) dimensions given in values into
// the z dimension given by "projection".
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR3Projected(
+ static Literal CreateR3Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection);
// Creates a literal that projects the (x, y) dimensions given in values into
// the z and p dimensions given.
template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4Projected(
+ static Literal CreateR4Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection_p, int64 projection_z);
// Returns an identity matrix (rank 2) with the given row and column count.
template <typename NativeT>
- static std::unique_ptr<Literal> MakeIdentityR2(int64 size);
+ static Literal MakeIdentityR2(int64 size);
// Returns a tuple literal composed of given literals. Data is copied from the
// given elements into the returned literal.
- static std::unique_ptr<Literal> MakeTuple(
- tensorflow::gtl::ArraySlice<const Literal*> elements);
+ static Literal MakeTuple(absl::Span<const Literal* const> elements);
- static std::unique_ptr<Literal> MakeTupleFromSlices(
- tensorflow::gtl::ArraySlice<LiteralSlice> elements);
+ static Literal MakeTupleFromSlices(absl::Span<const LiteralSlice> elements);
// As above, but intended to be invoked with move semantics; i.e.
//
- // std::vector<std::unique_ptr<Literal>> elements = ...;
+ // std::vector<Literal> elements = ...;
// auto result = LiteralUtil::MakeTupleOwned(std::move(elements));
//
// This would have been declared as an overload, but there is ambiguity
// in invocation between the above signature and this one.
- static std::unique_ptr<Literal> MakeTupleOwned(
- std::vector<std::unique_ptr<Literal>> elements);
+ static Literal MakeTupleOwned(std::vector<Literal> elements);
- // This overload lets you pass a braced list of unique_ptr<Literal>s to
+ // This overload lets you pass a braced list of Literals to
// MakeTupleOwned:
//
// LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...).
//
- // Simply relying on the MakeTupleOwned(std::vector<unique_ptr<Literal>>)
+ // Simply relying on the MakeTupleOwned(std::vector<Literal>)
// overload doesn't work because std::initializer_list's elements are always
// const.
//
- // The arguments to this function must all be unique_ptr<Literal>.
+ // The arguments to this function must all be Literal.
template <typename... Ts>
- static std::unique_ptr<Literal> MakeTupleOwned(
- std::unique_ptr<Ts>... elements) {
- std::array<std::unique_ptr<Literal>, sizeof...(Ts)> arr{
- std::move(elements)...};
- std::vector<std::unique_ptr<Literal>> v;
+ static Literal MakeTupleOwned(Ts... elements) {
+ std::array<Literal, sizeof...(Ts)> arr{std::move(elements)...};
+ std::vector<Literal> v;
v.insert(v.begin(), std::make_move_iterator(arr.begin()),
std::make_move_iterator(arr.end()));
return MakeTupleOwned(std::move(v));
}
// Create a constant token literal. Token types have no value.
- static std::unique_ptr<Literal> CreateToken();
+ static Literal CreateToken();
// Creates a new Literal object with its values having the primitive_type
// type, and with dimensions defined by the dimensions parameter.
// The content of the literal values is the default value of the primitive
// type of literal itself (0 for numeric types, and false for predicates).
- static std::unique_ptr<Literal> CreateFromDimensions(
- PrimitiveType primitive_type,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ static Literal CreateFromDimensions(PrimitiveType primitive_type,
+ absl::Span<const int64> dimensions);
// If the given literal's data type is bfloat16, converts it to a float
// literal; otherwise, returns a copy of it. If the literal is a tuple,
// recursively converts its elements.
- static std::unique_ptr<Literal> ConvertBF16ToF32(
- const LiteralSlice& bf16_literal);
+ static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal);
// If the given literal's data type is float, converts it to a bfloat16
// literal; otherwise, returns a copy of it. If the literal is a tuple,
// recursively converts its elements.
- static std::unique_ptr<Literal> ConvertF32ToBF16(
- const LiteralSlice& f32_literal);
+ static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal);
// Creates a literal with a new shape with the given new dimensions using the
// data in the given input literal. For reshaping purposes the (flat) data
// buffer of the input literal is assumed to have the given minor_to_major
// layout order.
- static std::unique_ptr<Literal> ReshapeSlice(
- tensorflow::gtl::ArraySlice<int64> new_dimensions,
- tensorflow::gtl::ArraySlice<int64> minor_to_major,
- const LiteralSlice& literal);
+ static Literal ReshapeSlice(absl::Span<const int64> new_dimensions,
+ absl::Span<const int64> minor_to_major,
+ const LiteralSlice& literal);
// Creates a literal with the supplied shape, and uses the provided value
// generator to populate the literal's values.
@@ -289,9 +275,9 @@ class LiteralUtil {
template <
PrimitiveType type,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
- static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
+ static StatusOr<Literal> CreateRandomLiteral(
const Shape& shape,
- const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator);
+ const std::function<T(absl::Span<const int64>)>& generator);
// Creates a literal with the supplied shape, and initializes the literal
// values using a normal distribution with given mean and stddev standard
@@ -300,8 +286,8 @@ class LiteralUtil {
template <
PrimitiveType type, typename E,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
- static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
- const Shape& shape, E* engine, T mean, T stddev);
+ static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, E* engine,
+ T mean, T stddev);
// Creates a literal with the supplied shape, and initializes the literal
// values using a normal distribution with given mean and stddev standard
@@ -310,8 +296,8 @@ class LiteralUtil {
template <
PrimitiveType type,
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
- static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
- const Shape& shape, T mean, T stddev);
+ static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, T mean,
+ T stddev);
//
// End of factory methods.
@@ -319,51 +305,49 @@ class LiteralUtil {
// Returns a multi-dimensional index as a string. For example: '{7, 8}' will
// be returned for a 2-dimensional index with dimension 0 index equal to 7,
// dimension 1 equal to 8.
- static string MultiIndexAsString(
- tensorflow::gtl::ArraySlice<int64> multi_index);
+ static string MultiIndexAsString(absl::Span<const int64> multi_index);
};
std::ostream& operator<<(std::ostream& out, const Literal& literal);
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR0(NativeT value) {
- auto literal = MakeUnique<Literal>(ShapeUtil::MakeShape(
+/* static */ Literal LiteralUtil::CreateR0(NativeT value) {
+ Literal literal(ShapeUtil::MakeShape(
primitive_util::NativeToPrimitiveType<NativeT>(), {}));
- literal->Set({}, value);
+ literal.Set({}, value);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
- tensorflow::gtl::ArraySlice<NativeT> values) {
- auto literal = MakeUnique<Literal>(
+/* static */ Literal LiteralUtil::CreateR1(absl::Span<const NativeT> values) {
+ Literal literal(
ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
{static_cast<int64>(values.size())}));
- literal->PopulateR1(values);
+ literal.PopulateR1(values);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2WithLayout(
+/* static */ Literal LiteralUtil::CreateR2WithLayout(
std::initializer_list<std::initializer_list<NativeT>> values,
const Layout& layout) {
- auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(
+ Literal literal(ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(),
{static_cast<int64>(values.size()),
static_cast<int64>(values.begin()->size())},
AsInt64Slice(layout.minor_to_major())));
- literal->PopulateR2(values);
+ literal.PopulateR2(values);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2(
+/* static */ Literal LiteralUtil::CreateR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3WithLayout(
+/* static */ Literal LiteralUtil::CreateR3WithLayout(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
values,
const Layout& layout) {
@@ -388,14 +372,14 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3(
+/* static */ Literal LiteralUtil::CreateR3(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
values) {
return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4WithLayout(
+/* static */ Literal LiteralUtil::CreateR4WithLayout(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values,
@@ -426,22 +410,22 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateSparse(
- tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
- tensorflow::gtl::ArraySlice<NativeT> values, bool sort) {
+/* static */ Literal LiteralUtil::CreateSparse(
+ absl::Span<const int64> dimensions, SparseIndexArray indices,
+ absl::Span<const NativeT> values, bool sort) {
int64 num_elements = values.size();
int64 rank = dimensions.size();
CHECK_EQ(num_elements, indices.index_count());
CHECK_EQ(rank, indices.rank());
- auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithSparseLayout(
+ Literal literal(ShapeUtil::MakeShapeWithSparseLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
indices.max_indices()));
- literal->PopulateSparse(indices, values, sort);
+ literal.PopulateSparse(indices, values, sort);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4(
+/* static */ Literal LiteralUtil::CreateR4(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values) {
@@ -449,50 +433,48 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArrayWithLayout(
+/* static */ Literal LiteralUtil::CreateFromArrayWithLayout(
const Array<NativeT>& values, const Layout& layout) {
- auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(
+ Literal literal(ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
AsInt64Slice(layout.minor_to_major())));
- literal->PopulateFromArray(values);
+ literal.PopulateFromArray(values);
return literal;
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArray(
+/* static */ Literal LiteralUtil::CreateFromArray(
const Array<NativeT>& values) {
return CreateFromArrayWithLayout(
values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
- const Layout& layout) {
+/* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout(
+ const Array2D<NativeT>& values, const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2FromArray2D(
+/* static */ Literal LiteralUtil::CreateR2FromArray2D(
const Array2D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
- const Layout& layout) {
+/* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout(
+ const Array3D<NativeT>& values, const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3FromArray3D(
+/* static */ Literal LiteralUtil::CreateR3FromArray3D(
const Array3D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3Projected(
+/* static */ Literal LiteralUtil::CreateR3Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection) {
int64 dim0_size = projection;
@@ -517,7 +499,7 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4Projected(
+/* static */ Literal LiteralUtil::CreateR4Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection_p, int64 projection_z) {
int64 dim0_size = projection_p;
@@ -545,21 +527,20 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4FromArray4D(
+/* static */ Literal LiteralUtil::CreateR4FromArray4D(
const Array4D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
- const Layout& layout) {
+/* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout(
+ const Array4D<NativeT>& values, const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
// Returns an identity matrix (rank 2) with the given row and column count.
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::MakeIdentityR2(int64 size) {
+/* static */ Literal LiteralUtil::MakeIdentityR2(int64 size) {
Array2D<NativeT> array(size, size, 0);
for (int64 i = 0; i < size; ++i) {
array(i, i) = 1;
@@ -568,45 +549,39 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateFullWithDescendingLayout(
- tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value) {
- auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout(
+/* static */ Literal LiteralUtil::CreateFullWithDescendingLayout(
+ absl::Span<const int64> dimensions, NativeT value) {
+ Literal literal(ShapeUtil::MakeShapeWithDescendingLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
- literal->PopulateWithValue(value);
+ literal.PopulateWithValue(value);
return literal;
}
template <PrimitiveType type, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralUtil::CreateRandomLiteral(
+/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
const Shape& shape,
- const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator) {
+ const std::function<T(absl::Span<const int64>)>& generator) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
TF_RET_CHECK(shape.element_type() == type);
- auto literal = MakeUnique<Literal>(shape);
- TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>(
- [&](tensorflow::gtl::ArraySlice<int64> indexes) {
- return generator(indexes);
- }));
+ Literal literal(shape);
+ TF_RETURN_IF_ERROR(literal.Populate<NativeT>(
+ [&](absl::Span<const int64> indexes) { return generator(indexes); }));
return std::move(literal);
}
template <PrimitiveType type, typename E, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
- T stddev) {
+/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
+ const Shape& shape, E* engine, T mean, T stddev) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
std::normal_distribution<NativeT> generator(mean, stddev);
return CreateRandomLiteral<type, NativeT>(
- shape, [&](tensorflow::gtl::ArraySlice<int64> /*indexes*/) {
- return generator(*engine);
- });
+ shape,
+ [&](absl::Span<const int64> /*indexes*/) { return generator(*engine); });
}
template <PrimitiveType type, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) {
+/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
+ const Shape& shape, T mean, T stddev) {
std::minstd_rand0 engine;
return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
}
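For callers, the two MakeTupleOwned forms documented above now take and return plain Literal values. A usage sketch (shapes and values invented):

    // Vector overload: elements are moved into the tuple literal.
    std::vector<Literal> elements;
    elements.push_back(LiteralUtil::CreateR0<int32>(1));
    elements.push_back(LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
    Literal t1 = LiteralUtil::MakeTupleOwned(std::move(elements));

    // Variadic overload: arguments must themselves be Literal values,
    // which the by-value factories now produce directly.
    Literal t2 = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(1),
                                             LiteralUtil::CreateR0<int32>(2));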
diff --git a/tensorflow/compiler/xla/map_util.h b/tensorflow/compiler/xla/map_util.h
index 3c74e070da..fcff48b6b1 100644
--- a/tensorflow/compiler/xla/map_util.h
+++ b/tensorflow/compiler/xla/map_util.h
@@ -60,7 +60,7 @@ MaybeFind(const Collection& collection,
if (it == collection.end()) {
std::ostringstream os;
os << key;
- return NotFound("key not found: %s", os.str().c_str());
+ return NotFound("key not found: %s", os.str());
}
return {it->second};
}
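The dropped .c_str() works because XLA's error constructors (NotFound and friends) now format via absl::StrFormat, whose %s conversion accepts std::string and absl::string_view directly, unlike printf-style formatting. A standalone illustration of the behavior being relied on (key string invented):

    #include <string>
    #include "absl/strings/str_format.h"

    std::string key = "embedding/table0";
    // %s binds a std::string directly; no .c_str() required, whereas
    // passing a std::string to a printf-style %s would be undefined.
    std::string msg = absl::StrFormat("key not found: %s", key);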
diff --git a/tensorflow/compiler/xla/metric_table_report.cc b/tensorflow/compiler/xla/metric_table_report.cc
index 69ef4f7a2f..4eab4fa429 100644
--- a/tensorflow/compiler/xla/metric_table_report.cc
+++ b/tensorflow/compiler/xla/metric_table_report.cc
@@ -18,7 +18,8 @@ limitations under the License.
#include <cctype>
#include <unordered_map>
-#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -84,7 +85,7 @@ void MetricTableReport::WriteReportToInfoLog(double expected_metric_sum) {
if (end_of_line == string::npos) {
end_of_line = report.size();
}
- tensorflow::StringPiece line(report.data() + pos, end_of_line - pos);
+ absl::string_view line(report.data() + pos, end_of_line - pos);
// TODO(b/34779244): Figure out how to do this without the verbose log-line
// prefix. The usual way didn't compile on open source.
@@ -152,8 +153,8 @@ void MetricTableReport::AppendCategoryTable() {
if (text.empty()) {
text = "[no category]";
}
- tensorflow::strings::StrAppend(&text, " (", category.entries.size(), " ",
- entry_name_, ")");
+ absl::StrAppend(&text, " (", category.entries.size(), " ", entry_name_,
+ ")");
AppendTableRow(text, category.metric_sum, metric_sum);
// Show the top entries in the category.
@@ -177,9 +178,9 @@ void MetricTableReport::AppendCategoryTable() {
}
const int64 remaining_categories = categories.size() - categories_shown;
if (remaining_categories > 0) {
- AppendTableRow(tensorflow::strings::StrCat("... (", remaining_categories,
- " more categories)"),
- expected_metric_sum_ - metric_sum, expected_metric_sum_);
+ AppendTableRow(
+ absl::StrCat("... (", remaining_categories, " more categories)"),
+ expected_metric_sum_ - metric_sum, expected_metric_sum_);
}
}
@@ -206,9 +207,9 @@ void MetricTableReport::AppendEntryTable() {
}
const int64 remaining_entries = entries_.size() - entries_shown;
if (remaining_entries > 0) {
- AppendTableRow(tensorflow::strings::StrCat("... (", remaining_entries,
- " more ", entry_name_, ")"),
- expected_metric_sum_ - metric_sum, expected_metric_sum_);
+ AppendTableRow(
+ absl::StrCat("... (", remaining_entries, " more ", entry_name_, ")"),
+ expected_metric_sum_ - metric_sum, expected_metric_sum_);
}
}
@@ -241,10 +242,10 @@ double MetricTableReport::UnaccountedMetric() {
string MetricTableReport::MetricString(double metric) {
// Round to integer and stringify.
- string s1 = tensorflow::strings::StrCat(std::llround(metric));
+ string s1 = absl::StrCat(std::llround(metric));
// Code below commafies the string, e.g. "1234" becomes "1,234".
- tensorflow::StringPiece sp1(s1);
+ absl::string_view sp1(s1);
string output;
// Copy leading non-digit characters unconditionally.
// This picks up the leading sign.
@@ -263,8 +264,7 @@ string MetricTableReport::MetricString(double metric) {
}
string MetricTableReport::MetricPercent(double metric) {
- return tensorflow::strings::Printf("%5.2f%%",
- metric / expected_metric_sum_ * 100.0);
+ return absl::StrFormat("%5.2f%%", metric / expected_metric_sum_ * 100.0);
}
} // namespace xla
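MetricPercent above is a direct Printf-to-StrFormat translation; absl::StrFormat additionally checks the format string against the argument types at compile time, which tensorflow::strings::Printf could not. A small check of the format used (value invented):

    #include "absl/strings/str_format.h"

    // "%5.2f%%" renders 12.3456 as "12.35%": two decimals, a minimum
    // field width of 5 for the number, and a literal percent sign.
    std::string pct = absl::StrFormat("%5.2f%%", 12.3456);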
diff --git a/tensorflow/compiler/xla/metric_table_report.h b/tensorflow/compiler/xla/metric_table_report.h
index 818fb1d3fe..062d8ed99b 100644
--- a/tensorflow/compiler/xla/metric_table_report.h
+++ b/tensorflow/compiler/xla/metric_table_report.h
@@ -18,9 +18,8 @@ limitations under the License.
#include <vector>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
@@ -108,7 +107,7 @@ class MetricTableReport {
// Append all parameters to the report.
template <typename... Args>
void AppendLine(Args... args) {
- tensorflow::strings::StrAppend(&report_, std::forward<Args>(args)..., "\n");
+ absl::StrAppend(&report_, std::forward<Args>(args)..., "\n");
}
// Represents a set of entries with the same category_text.
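AppendLine above forwards any mix of string-like and numeric arguments to absl::StrAppend, which stringifies each piece and appends them in one pass. A self-contained restatement of the pattern (written as a free function for illustration):

    #include <string>
    #include <utility>
    #include "absl/strings/str_cat.h"

    template <typename... Args>
    void AppendLine(std::string* report, Args... args) {
      // Each argument is stringified and appended, then a newline.
      absl::StrAppend(report, std::forward<Args>(args)..., "\n");
    }

    // e.g. AppendLine(&out, "matched ", 42, " entries (", 93.5, "%)");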
diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc
index 6b7fd10d63..0f86f9f35e 100644
--- a/tensorflow/compiler/xla/packed_literal_reader.cc
+++ b/tensorflow/compiler/xla/packed_literal_reader.cc
@@ -19,15 +19,15 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/base/casts.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
@@ -39,8 +39,8 @@ PackedLiteralReader::PackedLiteralReader(tensorflow::RandomAccessFile* file)
PackedLiteralReader::~PackedLiteralReader() { delete file_; }
-StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
- const Shape& shape, const Layout* layout) {
+StatusOr<Literal> PackedLiteralReader::Read(const Shape& shape,
+ const Layout* layout) {
VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape)
<< " layout: "
<< (layout == nullptr ? "<none>" : layout->ShortDebugString());
@@ -54,17 +54,17 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
if (shape.element_type() != F32) {
return Unimplemented(
"not yet implemented element type for packed literal reading: %s",
- PrimitiveType_Name(shape.element_type()).c_str());
+ PrimitiveType_Name(shape.element_type()));
}
- auto result = MakeUnique<Literal>(literal_shape);
- result->PopulateWithValue(std::numeric_limits<float>::quiet_NaN());
+ Literal result(literal_shape);
+ result.PopulateWithValue(std::numeric_limits<float>::quiet_NaN());
int64 elements = ShapeUtil::ElementsIn(shape);
- tensorflow::gtl::ArraySlice<float> field = result->data<float>();
- char* data = tensorflow::bit_cast<char*>(field.data());
+ absl::Span<const float> field = result.data<float>();
+ char* data = absl::bit_cast<char*>(field.data());
uint64 bytes = elements * sizeof(float);
- tensorflow::StringPiece sp;
+ absl::string_view sp;
auto s = file_->Read(offset_, bytes, &sp, data);
offset_ += sp.size();
if (!s.ok()) {
@@ -85,7 +85,7 @@ bool PackedLiteralReader::IsExhausted() const {
// Try to read a single byte from offset_. If we can't, we've
// exhausted the data.
char single_byte[1];
- tensorflow::StringPiece sp;
+ absl::string_view sp;
auto s = file_->Read(offset_, sizeof(single_byte), &sp, single_byte);
return !s.ok();
}
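The Read hunk above also swaps tensorflow::bit_cast for absl::bit_cast; both copy the object representation between equally-sized types, here used to view the literal's float storage as raw bytes for the file read. The canonical use, as a minimal standalone example:

    #include <cstdint>
    #include "absl/base/casts.h"

    float f = 1.0f;
    // Reinterpret the float's bits as an integer without undefined
    // behavior; yields 0x3f800000 for 1.0f.
    uint32_t bits = absl::bit_cast<uint32_t>(f);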
diff --git a/tensorflow/compiler/xla/packed_literal_reader.h b/tensorflow/compiler/xla/packed_literal_reader.h
index 98dccaa9a2..d6d2ff1521 100644
--- a/tensorflow/compiler/xla/packed_literal_reader.h
+++ b/tensorflow/compiler/xla/packed_literal_reader.h
@@ -41,8 +41,7 @@ class PackedLiteralReader {
//
// Layout is optional. If it is not provided, no layout is set on the literal
// that is produced.
- StatusOr<std::unique_ptr<Literal>> Read(const Shape& shape,
- const Layout* layout = nullptr);
+ StatusOr<Literal> Read(const Shape& shape, const Layout* layout = nullptr);
// Returns whether the input file has been fully exhausted; i.e. all available
// packed literals have been read and we're at the end of the file.
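A hedged sketch of a caller of the new Read signature; the macro and usage pattern shown are assumptions about a typical call site, not part of the patch:

    // Inside a function returning Status or StatusOr. The Literal is
    // moved out of the StatusOr<Literal> on success; previously the
    // caller received a unique_ptr to dereference.
    TF_ASSIGN_OR_RETURN(Literal lit,
                        reader.Read(ShapeUtil::MakeShape(F32, {16})));
    float first = lit.data<float>()[0];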
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index c8f2d65c22..f0d84646b9 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -39,6 +39,9 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/python:numpy_lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -59,6 +62,8 @@ cc_library(
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
)
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index 8246f76d34..9da5dc0d2d 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -14,10 +14,10 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/local_computation_builder.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -81,8 +81,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal,
return client->TransferToInfeedLocal(literal, device_ordinal);
}
-StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocalReplica(
- const Shape& shape, int replica_number) {
+StatusOr<Literal> TransferFromOutfeedLocalReplica(const Shape& shape,
+ int replica_number) {
VLOG(1) << "Outfeeding literal from replica number: " << replica_number
<< " shape: " << shape;
LocalClient* client = GetOrCreateLocalClient();
@@ -137,14 +137,12 @@ static StatusOr<ScopedShapedBuffer> ToBuffer(LocalClient* client,
/* static */
StatusOr<LocalShapedBuffer*> LocalShapedBuffer::FromLiteral(
- const Literal& argument,
- const tensorflow::gtl::optional<Shape>& shape_with_layout) {
+ const Literal& argument, const absl::optional<Shape>& shape_with_layout) {
LocalClient* client = GetOrCreateLocalClient();
StatusOr<ScopedShapedBuffer> buf = [&] {
if (shape_with_layout) {
- std::unique_ptr<Literal> relaid =
- argument.Relayout(shape_with_layout.value());
- return ToBuffer(client, /*device_ordinal=*/0, *relaid);
+ Literal relaid = argument.Relayout(shape_with_layout.value());
+ return ToBuffer(client, /*device_ordinal=*/0, relaid);
}
return ToBuffer(client, /*device_ordinal=*/0, argument);
}();
@@ -152,7 +150,7 @@ StatusOr<LocalShapedBuffer*> LocalShapedBuffer::FromLiteral(
return new LocalShapedBuffer(std::move(buf).ValueOrDie());
}
-StatusOr<std::unique_ptr<Literal>> LocalShapedBuffer::ToLiteral() const {
+StatusOr<Literal> LocalShapedBuffer::ToLiteral() const {
LocalClient* client = GetOrCreateLocalClient();
return client->ShapedBufferToLiteral(*shaped_buffer());
}
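FromLiteral and ToLiteral above show the consumer side of the value migration: Relayout hands back a Literal (no dereference before passing it on), and ToLiteral yields StatusOr<Literal>. A sketch of unpacking the latter (error handling elided; usage assumed, not from the patch):

    StatusOr<Literal> statusor = local_shaped_buffer->ToLiteral();
    if (statusor.ok()) {
      // Move the Literal out of the StatusOr rather than copying it.
      Literal value = std::move(statusor.ValueOrDie());
    }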
@@ -161,16 +159,16 @@ CompiledLocalComputation::CompiledLocalComputation(
std::unique_ptr<LocalExecutable> executable)
: executable_(std::move(executable)) {}
-StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
+StatusOr<Literal> CompiledLocalComputation::Execute(
const std::vector<Literal>& arguments,
- const std::vector<tensorflow::gtl::optional<Shape>>& shapes_with_layout) {
+ const std::vector<absl::optional<Shape>>& shapes_with_layout) {
LocalClient* client = GetOrCreateLocalClient();
VLOG(1) << "Execution requested with " << GetReplicaCount() << " replicas.";
// Each replica populates a StatusOr result, but only replica zero actually
// retrieves its literal value.
- std::vector<StatusOr<std::unique_ptr<Literal>>> results(GetReplicaCount());
+ std::vector<StatusOr<Literal>> results(GetReplicaCount());
{
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun",
GetReplicaCount());
@@ -194,14 +192,13 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
scoped_buffers.reserve(arguments.size());
for (int i = 0; i < arguments.size(); ++i) {
const Literal& argument = arguments[i];
- const tensorflow::gtl::optional<Shape>& shape_with_layout =
+ const absl::optional<Shape>& shape_with_layout =
shapes_with_layout[i];
StatusOr<ScopedShapedBuffer> pushed;
if (shape_with_layout) {
- std::unique_ptr<Literal> relaid =
- argument.Relayout(shape_with_layout.value());
- pushed = ToBuffer(client, device_ordinal, *relaid);
+ Literal relaid = argument.Relayout(shape_with_layout.value());
+ pushed = ToBuffer(client, device_ordinal, relaid);
} else {
pushed = ToBuffer(client, device_ordinal, argument);
}
@@ -252,7 +249,7 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
return InternalError(
"Failed running replica %d (other replicas may have failed as well): "
"%s.",
- replica, statusor.status().ToString().c_str());
+ replica, statusor.status().ToString());
}
}
@@ -260,7 +257,7 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
}
LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers(
- tensorflow::gtl::ArraySlice<LocalShapedBuffer*> argument_handles) {
+ absl::Span<LocalShapedBuffer* const> argument_handles) {
LocalClient* client = GetOrCreateLocalClient();
std::vector<const ShapedBuffer*> argument_buffers;
@@ -370,8 +367,7 @@ LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) {
}
LocalOp LocalComputationBuilder::Broadcast(
- const LocalOp& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
+ const LocalOp& operand, absl::Span<const int64> broadcast_sizes) {
return xla::Broadcast(operand.op(), broadcast_sizes);
}
@@ -381,14 +377,14 @@ LocalOp LocalComputationBuilder::Pad(const LocalOp& operand,
return xla::Pad(operand.op(), padding_value.op(), padding_config);
}
-LocalOp LocalComputationBuilder::Reshape(
- const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes) {
+LocalOp LocalComputationBuilder::Reshape(const LocalOp& operand,
+ absl::Span<const int64> dimensions,
+ absl::Span<const int64> new_sizes) {
return xla::Reshape(operand.op(), dimensions, new_sizes);
}
-LocalOp LocalComputationBuilder::Collapse(
- const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
+LocalOp LocalComputationBuilder::Collapse(const LocalOp& operand,
+ absl::Span<const int64> dimensions) {
return xla::Collapse(operand.op(), dimensions);
}
@@ -396,10 +392,10 @@ LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) {
return xla::CrossReplicaSum(operand.op());
}
-LocalOp LocalComputationBuilder::Slice(
- const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides) {
+LocalOp LocalComputationBuilder::Slice(const LocalOp& operand,
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides) {
return xla::Slice(operand.op(), start_indices, limit_indices, strides);
}
@@ -412,7 +408,7 @@ LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand,
LocalOp LocalComputationBuilder::DynamicSlice(
const LocalOp& operand, const LocalOp& start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ absl::Span<const int64> slice_sizes) {
return xla::DynamicSlice(operand.op(), start_indices.op(), slice_sizes);
}
@@ -422,8 +418,8 @@ LocalOp LocalComputationBuilder::DynamicUpdateSlice(
return xla::DynamicUpdateSlice(operand.op(), update.op(), start_indices.op());
}
-LocalOp LocalComputationBuilder::ConcatInDim(
- tensorflow::gtl::ArraySlice<LocalOp> operands, int64 dimension) {
+LocalOp LocalComputationBuilder::ConcatInDim(absl::Span<const LocalOp> operands,
+ int64 dimension) {
std::vector<XlaOp> xla_ops;
xla_ops.reserve(operands.size());
for (const auto& op : operands) {
@@ -434,18 +430,16 @@ LocalOp LocalComputationBuilder::ConcatInDim(
LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding(
const LocalOp& operand, const LocalComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const LocalOp& source, const LocalOp& init_value,
- const LocalComputation& scatter) {
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding, const LocalOp& source,
+ const LocalOp& init_value, const LocalComputation& scatter) {
return xla::SelectAndScatterWithGeneralPadding(
operand.op(), select.computation(), window_dimensions, window_strides,
padding, source.op(), init_value.op(), scatter.computation());
}
-LocalOp LocalComputationBuilder::Tuple(
- tensorflow::gtl::ArraySlice<LocalOp> elements) {
+LocalOp LocalComputationBuilder::Tuple(absl::Span<const LocalOp> elements) {
std::vector<XlaOp> xla_ops;
xla_ops.reserve(elements.size());
for (const auto& op : elements) {
@@ -472,10 +466,9 @@ LocalOp LocalComputationBuilder::DotGeneral(
LocalOp LocalComputationBuilder::ConvGeneralDilated(
const LocalOp& lhs, const LocalOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers) {
return xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers);
@@ -491,9 +484,8 @@ LocalOp LocalComputationBuilder::BitcastConvertType(
return xla::BitcastConvertType(operand.op(), new_element_type);
}
-LocalOp LocalComputationBuilder::Call(
- const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<LocalOp> operands) {
+LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation,
+ absl::Span<const LocalOp> operands) {
std::vector<XlaOp> xla_ops;
xla_ops.reserve(operands.size());
for (const auto& op : operands) {
@@ -503,19 +495,18 @@ LocalOp LocalComputationBuilder::Call(
}
LocalOp LocalComputationBuilder::Transpose(
- const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> permutation) {
+ const LocalOp& operand, absl::Span<const int64> permutation) {
return xla::Transpose(operand.op(), permutation);
}
-LocalOp LocalComputationBuilder::Rev(
- const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
+LocalOp LocalComputationBuilder::Rev(const LocalOp& operand,
+ absl::Span<const int64> dimensions) {
return xla::Rev(operand.op(), dimensions);
}
-LocalOp LocalComputationBuilder::Map(
- tensorflow::gtl::ArraySlice<LocalOp> operands,
- const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
+LocalOp LocalComputationBuilder::Map(absl::Span<const LocalOp> operands,
+ const LocalComputation& local_computation,
+ absl::Span<const int64> dimensions) {
std::vector<XlaOp> xla_ops;
xla_ops.reserve(operands.size());
for (const auto& op : operands) {
@@ -529,7 +520,7 @@ LocalOp LocalComputationBuilder::Map(
LocalOp LocalComputationBuilder::Reduce(
const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
+ absl::Span<const int64> dimensions_to_reduce) {
return xla::Reduce(operand.op(), init_value.op(),
local_computation.computation(), dimensions_to_reduce);
}
@@ -537,9 +528,9 @@ LocalOp LocalComputationBuilder::Reduce(
LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding(
const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding) {
return xla::ReduceWindowWithGeneralPadding(
operand.op(), init_value.op(), local_computation.computation(),
window_dimensions, window_strides, padding);
@@ -575,6 +566,16 @@ StatusOr<bool> LocalComputationBuilder::IsConstant(const LocalOp& operand) {
return builder_.IsConstant(operand.op());
}
+LocalOp LocalComputationBuilder::Sort(const LocalOp& operand, int64 dimension) {
+ return xla::Sort(operand.op(), absl::nullopt, dimension);
+}
+
+LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys,
+ const LocalOp& values,
+ int64 dimension) {
+ return xla::Sort(keys.op(), values.op(), dimension);
+}
+
StatusOr<LocalComputation*> LocalComputationBuilder::BuildConstantSubGraph(
const LocalOp& operand) {
TF_ASSIGN_OR_RETURN(XlaComputation computation,
@@ -590,10 +591,10 @@ StatusOr<LocalComputation*> LocalComputationBuilder::BuildConstantSubGraph(
#define _FORWARD_UNOP(method_name) \
_FORWARD(method_name, LocalOp, (const LocalOp& operand), (operand.op()))
-#define _FORWARD_BINOP(method_name) \
- _FORWARD(method_name, LocalOp, \
- (const LocalOp& lhs, const LocalOp& rhs, \
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions), \
+#define _FORWARD_BINOP(method_name) \
+ _FORWARD(method_name, LocalOp, \
+ (const LocalOp& lhs, const LocalOp& rhs, \
+ absl::Span<const int64> broadcast_dimensions), \
(lhs.op(), rhs.op(), broadcast_dimensions))
#define _FORWARD_TRIOP(method_name) \
@@ -640,7 +641,6 @@ _FORWARD_UNOP(Sin)
_FORWARD_UNOP(Tanh)
_FORWARD_UNOP(IsFinite)
_FORWARD_UNOP(Neg)
-_FORWARD_UNOP(Sort)
_FORWARD_UNOP(Sqrt)
_FORWARD_UNOP(Rsqrt)
_FORWARD_UNOP(Square)
@@ -688,8 +688,7 @@ StatusOr<LocalShapedBufferTuple*> DestructureLocalShapedBufferTuple(
"Attemped to destructure a LocalShapedBuffer that did not have a tuple "
"shape; shape: %s",
ShapeUtil::HumanString(
- local_shaped_buffer->shaped_buffer()->on_device_shape())
- .c_str());
+ local_shaped_buffer->shaped_buffer()->on_device_shape()));
}
DeviceMemoryAllocator* allocator =
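
The pattern of this file's changes: every `tensorflow::gtl::ArraySlice<T>` parameter becomes `absl::Span<const T>`, and `Literal` values now move by value instead of through `std::unique_ptr`. A minimal self-contained sketch of why call sites survive the span migration unchanged, assuming only the real absl header (`SumSizes` is a hypothetical stand-in, and `int64_t` stands in for XLA's `int64` alias):

    #include <cstdint>
    #include <numeric>
    #include <vector>

    #include "absl/types/span.h"

    // absl::Span<const T> is a non-owning view; like ArraySlice it binds
    // implicitly to vectors and to braced lists at the call site.
    int64_t SumSizes(absl::Span<const int64_t> sizes) {
      return std::accumulate(sizes.begin(), sizes.end(), int64_t{0});
    }

    int main() {
      std::vector<int64_t> v = {2, 3, 5};
      int64_t a = SumSizes(v);        // binds to a container
      int64_t b = SumSizes({7, 11});  // binds to a braced list
      return (a == 10 && b == 18) ? 0 : 1;
    }
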
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index a568c24c63..1d5dfe5911 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -23,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace swig {
@@ -51,8 +51,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number);
// Transfers a literal of the given shape from the outfeed of the given replica.
//
// The replica number is resolved to an appropriate device ordinal.
-StatusOr<std::unique_ptr<Literal> > TransferFromOutfeedLocalReplica(
- const Shape& shape, int replica_number);
+StatusOr<Literal> TransferFromOutfeedLocalReplica(const Shape& shape,
+ int replica_number);
// Wraps a ScopedShapedBuffer produced by copying a literal "to
// device," i.e. copying a literal to a scoped buffer via the local
@@ -60,13 +60,12 @@ StatusOr<std::unique_ptr<Literal> > TransferFromOutfeedLocalReplica(
class LocalShapedBuffer {
public:
static StatusOr<LocalShapedBuffer*> FromLiteral(
- const Literal& argument,
- const tensorflow::gtl::optional<Shape>& shape_with_layout);
+ const Literal& argument, const absl::optional<Shape>& shape_with_layout);
LocalShapedBuffer(ScopedShapedBuffer shaped_buffer);
const ScopedShapedBuffer* shaped_buffer() const;
- StatusOr<std::unique_ptr<Literal> > ToLiteral() const;
+ StatusOr<Literal> ToLiteral() const;
// Transfers ownership of the encapsulated ShapedBuffer to the caller,
// analogous to std::unique_ptr::release().
@@ -118,12 +117,12 @@ class CompiledLocalComputation {
// with optionally-specified argument layouts. The literals will be
// re-laid out according to the corresponding elements of
// shapes_with_layout.
- StatusOr<std::unique_ptr<Literal> > Execute(
+ StatusOr<Literal> Execute(
const std::vector<Literal>& arguments,
- const std::vector<tensorflow::gtl::optional<Shape> >& shapes_with_layout);
+ const std::vector<absl::optional<Shape> >& shapes_with_layout);
LocalShapedBuffer* ExecuteWithShapedBuffers(
- tensorflow::gtl::ArraySlice<LocalShapedBuffer*> argument_handles);
+ absl::Span<LocalShapedBuffer* const> argument_handles);
private:
std::unique_ptr<LocalExecutable> executable_;
@@ -200,46 +199,41 @@ class LocalComputationBuilder {
LocalOp ConstantLiteral(const Literal& literal);
LocalOp Broadcast(const LocalOp& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+ absl::Span<const int64> broadcast_sizes);
LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value,
const PaddingConfig& padding_config);
- LocalOp Reshape(const LocalOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
+ LocalOp Reshape(const LocalOp& operand, absl::Span<const int64> dimensions,
+ absl::Span<const int64> new_sizes);
- LocalOp Collapse(const LocalOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ LocalOp Collapse(const LocalOp& operand, absl::Span<const int64> dimensions);
LocalOp CrossReplicaSum(const LocalOp& operand);
- LocalOp Slice(const LocalOp& operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
+ LocalOp Slice(const LocalOp& operand, absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides);
LocalOp SliceInDim(const LocalOp& operand, int64 start_index,
int64 limit_index, int64 stride, int64 dimno);
LocalOp DynamicSlice(const LocalOp& operand, const LocalOp& start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
LocalOp DynamicUpdateSlice(const LocalOp& operand, const LocalOp& update,
const LocalOp& start_indices);
- LocalOp ConcatInDim(tensorflow::gtl::ArraySlice<LocalOp> operands,
- int64 dimension);
+ LocalOp ConcatInDim(absl::Span<const LocalOp> operands, int64 dimension);
LocalOp SelectAndScatterWithGeneralPadding(
const LocalOp& operand, const LocalComputation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64> > padding,
- const LocalOp& source, const LocalOp& init_value,
- const LocalComputation& scatter);
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64> > padding, const LocalOp& source,
+ const LocalOp& init_value, const LocalComputation& scatter);
- LocalOp Tuple(tensorflow::gtl::ArraySlice<LocalOp> elements);
+ LocalOp Tuple(absl::Span<const LocalOp> elements);
LocalOp GetTupleElement(const LocalOp& tuple_data, int64 index);
@@ -250,10 +244,10 @@ class LocalComputationBuilder {
LocalOp ConvGeneralDilated(
const LocalOp& lhs, const LocalOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64> > padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64> > padding,
+ absl::Span<const int64> lhs_dilation,
+ absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers);
LocalOp ConvertElementType(const LocalOp& operand,
@@ -263,28 +257,27 @@ class LocalComputationBuilder {
PrimitiveType new_element_type);
LocalOp Call(const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<LocalOp> operands);
+ absl::Span<const LocalOp> operands);
LocalOp Transpose(const LocalOp& operand,
- tensorflow::gtl::ArraySlice<int64> permutation);
+ absl::Span<const int64> permutation);
- LocalOp Rev(const LocalOp& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ LocalOp Rev(const LocalOp& operand, absl::Span<const int64> dimensions);
- LocalOp Map(tensorflow::gtl::ArraySlice<LocalOp> operands,
+ LocalOp Map(absl::Span<const LocalOp> operands,
const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ absl::Span<const int64> dimensions);
LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+ absl::Span<const int64> dimensions_to_reduce);
LocalOp ReduceWindowWithGeneralPadding(
const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64> > padding);
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64> > padding);
LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma,
const Shape& shape);
@@ -301,6 +294,11 @@ class LocalComputationBuilder {
StatusOr<bool> IsConstant(const LocalOp& operand);
+ LocalOp Sort(const LocalOp& operand, int64 dimension);
+
+ LocalOp SortKeyVal(const LocalOp& keys, const LocalOp& values,
+ int64 dimension);
+
StatusOr<LocalComputation*> BuildConstantSubGraph(const LocalOp& operand);
#define _FORWARD(method_name, return_sig, args_sig) \
@@ -312,7 +310,7 @@ class LocalComputationBuilder {
#define _FORWARD_BINOP(method_name) \
_FORWARD(method_name, LocalOp, \
(const LocalOp& lhs, const LocalOp& rhs, \
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions))
+ absl::Span<const int64> broadcast_dimensions))
#define _FORWARD_TRIOP(method_name) \
_FORWARD(method_name, LocalOp, \
@@ -357,7 +355,6 @@ class LocalComputationBuilder {
_FORWARD_UNOP(Tanh)
_FORWARD_UNOP(IsFinite)
_FORWARD_UNOP(Neg)
- _FORWARD_UNOP(Sort)
_FORWARD_UNOP(Sqrt)
_FORWARD_UNOP(Rsqrt)
_FORWARD_UNOP(Square)
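
`Sort` leaves the `_FORWARD_UNOP` list because it now carries a `dimension` argument (plus a key-value variant), so it gets explicit declarations instead of a macro-generated wrapper. A simplified sketch of the forwarding-macro technique itself, with stand-in types rather than the real `XlaBuilder` (spelled `FORWARD` here to avoid the reserved leading underscore):

    #include <iostream>

    struct Op { int id; };

    struct Builder {
      Op Neg(Op operand) { return Op{-operand.id}; }
      Op Add(Op lhs, Op rhs) { return Op{lhs.id + rhs.id}; }
    };

    // Expands to a thin wrapper method that delegates to builder_; only
    // methods whose signatures fit the macro shape can be forwarded this way.
    #define FORWARD(name, ret, args_sig, args) \
      ret name args_sig { return builder_.name args; }

    struct Wrapper {
      Builder builder_;
      FORWARD(Neg, Op, (Op operand), (operand))
      FORWARD(Add, Op, (Op lhs, Op rhs), (lhs, rhs))
    };

    int main() {
      Wrapper w;
      std::cout << w.Add(w.Neg(Op{2}), Op{5}).id << "\n";  // prints 3
    }
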
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index 5d5a955bfe..521490e76c 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -22,15 +22,15 @@ limitations under the License.
//
// C++ Python
// -------------------------------------+---------------------------------------
-// ArraySlice<int64> <- sequence of int
-// ArraySlice<LocalOp> <- sequence of LocalOp
+// Span<int64> <- sequence of int
+// Span<LocalOp> <- sequence of LocalOp
// Literal <-> (nested tuple of) numpy ndarray
// std::vector<Literal> <- sequence of (nested tuple of) ndarray
// Shape -> pair holding (dtype, dimensions)
// <- object duck-typed as xla_client.Shape
// std::vector<Shape> <- sequence of xla_client.Shape objects
// PrimitiveType <- int
-// ArraySlice<pair<int64, in64>> <- sequence of int pairs
+// Span<pair<int64, in64>> <- sequence of int pairs
// PaddingConfig proto <- corresponding Python proto
// ConvolutionDimensionNumbers proto <- corresponding Python proto
// DotDimensionNumbers proto <- corresponding Python proto
@@ -109,10 +109,12 @@ limitations under the License.
// Must be included first
#include "tensorflow/python/lib/core/numpy.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/python/numpy_bridge.h"
#include "tensorflow/compiler/xla/python/local_computation_builder.h"
@@ -154,8 +156,8 @@ bool HandleStringAttribute(PyObject* o,
return true; // The attribute is None, which we consider ok.
}
if (!PyString_Check(attr)) {
- string message = tensorflow::strings::Printf("%s must be a string or none; got %s",
- attr_name, numpy::PyObjectCppRepr(attr).c_str());
+ string message = absl::StrFormat("%s must be a string or none; got %s",
+ attr_name, numpy::PyObjectCppRepr(attr));
PyErr_SetString(PyExc_TypeError, message.c_str());
Py_DECREF(attr);
return false; // Type error, not ok.
@@ -214,9 +216,9 @@ tensorflow::ImportNumpy();
}
-%typemap(out) StatusOr< std::unique_ptr<Literal> > {
+%typemap(out) StatusOr<Literal> {
if ($1.ok()) {
- std::unique_ptr<Literal> value = $1.ConsumeValueOrDie();
+ Literal value = $1.ConsumeValueOrDie();
$result = numpy::PyObjectFromXlaLiteral(*value);
} else {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
@@ -265,9 +267,9 @@ tensorflow::ImportNumpy();
$result = Py_None;
}
-// ArraySlice<int64>
+// Span<int64>
-%typemap(in) tensorflow::gtl::ArraySlice<int64>
+%typemap(in) absl::Span<const int64>
(std::vector<int64> temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
@@ -297,9 +299,9 @@ tensorflow::ImportNumpy();
$1 = temps;
}
-// ArraySlice<LocalOp>
+// Span<LocalOp>
-%typemap(in) tensorflow::gtl::ArraySlice<xla::swig::LocalOp>(
+%typemap(in) absl::Span<const xla::swig::LocalOp>(
std::vector<LocalOp> temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
@@ -321,7 +323,7 @@ tensorflow::ImportNumpy();
// LocalShapedBuffer*
-%typemap(in) tensorflow::gtl::ArraySlice<xla::swig::LocalShapedBuffer*>
+%typemap(in) absl::Span<xla::swig::LocalShapedBuffer* const>
(std::vector<LocalShapedBuffer*> temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
@@ -344,25 +346,25 @@ tensorflow::ImportNumpy();
// Literal
-%typemap(in) const Literal& (StatusOr< std::unique_ptr<Literal> > literal_status) {
+%typemap(in) const Literal& (StatusOr<Literal> literal_status) {
literal_status = numpy::XlaLiteralFromPyObject($input);
if (!literal_status.ok()) {
PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
SWIG_fail;
}
- $1 = literal_status.ValueOrDie().get();
+ $1 = &literal_status.ValueOrDie();
}
-%typemap(out) std::unique_ptr<Literal> {
+%typemap(out) Literal {
$result = numpy::PyObjectFromXlaLiteral(*$1);
}
-%typemap(out) StatusOr< std::unique_ptr<Literal> > {
+%typemap(out) StatusOr<Literal> {
if (!$1.ok()) {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
SWIG_fail;
}
- $result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie());
+ $result = numpy::PyObjectFromXlaLiteral($1.ValueOrDie());
}
%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
@@ -373,13 +375,13 @@ tensorflow::ImportNumpy();
const int size = PySequence_Size($input);
for (int i = 0; i < size; ++i) {
PyObject* o = PySequence_GetItem($input, i);
- StatusOr< std::unique_ptr<Literal> > literal_status = numpy::XlaLiteralFromPyObject(o);
+ StatusOr<Literal> literal_status = numpy::XlaLiteralFromPyObject(o);
if (!literal_status.ok()) {
PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
Py_DECREF(o);
SWIG_fail;
}
- temps.push_back(std::move(*literal_status.ConsumeValueOrDie()));
+ temps.push_back(literal_status.ConsumeValueOrDie());
Py_DECREF(o);
}
$1 = &temps;
@@ -409,10 +411,10 @@ tensorflow::ImportNumpy();
$1 = &temp;
}
-%typemap(in) const tensorflow::gtl::optional<Shape>& (
- tensorflow::gtl::optional<Shape> temp) {
+%typemap(in) const absl::optional<Shape>& (
+ absl::optional<Shape> temp) {
if ($input == Py_None) {
- temp = tensorflow::gtl::nullopt;
+ temp = absl::nullopt;
$1 = &temp;
} else {
StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape($input);
@@ -448,8 +450,8 @@ tensorflow::ImportNumpy();
$1 = &temps;
}
-%typemap(in) const std::vector<tensorflow::gtl::optional<Shape> >& (
- std::vector<tensorflow::gtl::optional<Shape> > temps) {
+%typemap(in) const std::vector<absl::optional<Shape> >& (
+ std::vector<absl::optional<Shape> > temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
SWIG_fail;
@@ -458,7 +460,7 @@ tensorflow::ImportNumpy();
for (int i = 0; i < size; ++i) {
PyObject* o = PySequence_GetItem($input, i);
if (o == Py_None) {
- temps.push_back(tensorflow::gtl::nullopt);
+ temps.push_back(absl::nullopt);
} else {
StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
Py_DECREF(o);
@@ -494,9 +496,9 @@ tensorflow::ImportNumpy();
$1 = static_cast<PrimitiveType>(value);
}
-// ArraySlice<pair<int64, in64>>
+// Span<pair<int64, in64>>
-%typemap(in) tensorflow::gtl::ArraySlice<std::pair<int64, int64> >
+%typemap(in) absl::Span<const std::pair<int64, int64> >
(std::vector<std::pair<int64, int64> > temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
@@ -896,7 +898,7 @@ tensorflow::ImportNumpy();
if (o != Py_None) {
StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
if (!statusor.ok()) {
- PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str());
+ PyErr_SetString(PyExc_TypeError, absl::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str());
Py_DECREF(o);
SWIG_fail;
}
@@ -1011,6 +1013,7 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Pow;
%unignore xla::swig::LocalComputationBuilder::Neg;
%unignore xla::swig::LocalComputationBuilder::Sort;
+%unignore xla::swig::LocalComputationBuilder::SortKeyVal;
%unignore xla::swig::LocalComputationBuilder::Sqrt;
%unignore xla::swig::LocalComputationBuilder::Rsqrt;
%unignore xla::swig::LocalComputationBuilder::Square;
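
The `(std::vector<...> temps)` locals in the typemaps above are what make the span migration safe on the SWIG side: the Python sequence is converted into `temps`, and the `absl::Span` argument is only a view of that storage for the duration of the wrapped call. A minimal sketch of the same keep-alive pattern, with a hypothetical `First` function in place of a wrapped builder method:

    #include <cstdint>
    #include <vector>

    #include "absl/types/span.h"

    int64_t First(absl::Span<const int64_t> xs) { return xs.empty() ? -1 : xs[0]; }

    // Converted elements live in `temps`, which outlives the call that
    // receives the span view -- mirroring the `temps` locals in the typemaps.
    int64_t CallWithConverted(const std::vector<double>& python_like_values) {
      std::vector<int64_t> temps;
      temps.reserve(python_like_values.size());
      for (double v : python_like_values) temps.push_back(static_cast<int64_t>(v));
      return First(temps);
    }

    int main() { return CallWithConverted({4.0, 2.0}) == 4 ? 0 : 1; }
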
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc
index 6f665faf61..b0aa024c74 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.cc
+++ b/tensorflow/compiler/xla/python/numpy_bridge.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/numpy_bridge.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"
@@ -149,9 +151,7 @@ static int NumpyTypenum(PyObject* o) {
//
// NOTE: this is an internal helper for conversion to a C++ string, and so decrefs r.
static string ExtractStringAndDecref(PyObject* r) {
- auto error = [r] {
- return tensorflow::strings::Printf("<failed conversion of %p>", r);
- };
+ auto error = [r] { return absl::StrFormat("<failed conversion of %p>", r); };
if (r == nullptr) {
return error();
}
@@ -191,8 +191,8 @@ StatusOr<Shape> XlaShapeFromPyShape(PyObject* o) {
PyObject* result =
PyObject_CallMethod(o, const_cast<char*>(method.c_str()), nullptr);
if (result == nullptr) {
- return error(tensorflow::strings::StrCat(
- "Failed to call method of shape object:", method));
+ return error(
+ absl::StrCat("Failed to call method of shape object:", method));
}
return result;
};
@@ -281,15 +281,15 @@ StatusOr<Shape> XlaShapeFromPyShape(PyObject* o) {
// Helper that retrieves the member with attr_name, stringifies it if it is not
// None, and returns it as a C++ string.
-static tensorflow::gtl::optional<string> GetAttrAsString(
- PyObject* o, const string& attr_name) {
+static absl::optional<string> GetAttrAsString(PyObject* o,
+ const string& attr_name) {
if (!PyObject_HasAttrString(o, attr_name.c_str())) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
PyObject* attr = PyObject_GetAttrString(o, attr_name.c_str());
if (attr == Py_None) {
Py_DECREF(attr);
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
string result = PyObjectCppStr(attr);
Py_DECREF(attr);
@@ -298,48 +298,46 @@ static tensorflow::gtl::optional<string> GetAttrAsString(
// Helper that retrieves the member with attr_name, checks that it is an integer
// if it is not None, and returns it as an int32 value.
-static tensorflow::gtl::optional<int32> GetAttrAsInt32(
- PyObject* o, const string& attr_name) {
+static absl::optional<int32> GetAttrAsInt32(PyObject* o,
+ const string& attr_name) {
if (!PyObject_HasAttrString(o, attr_name.c_str())) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
PyObject* attr = PyObject_GetAttrString(o, attr_name.c_str());
if (attr == Py_None) {
Py_DECREF(attr);
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
if (!CheckPyIntOrLong(attr)) {
Py_DECREF(attr);
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
long value = PyIntOrPyLongToLong(attr); // NOLINT
Py_DECREF(attr);
if (value == -1 && PyErr_Occurred() != nullptr) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
if (static_cast<int32>(value) != value) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
return value;
}
StatusOr<OpMetadata> OpMetadataFromPyObject(PyObject* o) {
OpMetadata result;
- tensorflow::gtl::optional<string> op_type = GetAttrAsString(o, "op_type");
+ absl::optional<string> op_type = GetAttrAsString(o, "op_type");
if (op_type.has_value()) {
result.set_op_type(op_type.value());
}
- tensorflow::gtl::optional<string> op_name = GetAttrAsString(o, "op_name");
+ absl::optional<string> op_name = GetAttrAsString(o, "op_name");
if (op_name.has_value()) {
result.set_op_name(op_name.value());
}
- tensorflow::gtl::optional<string> source_file =
- GetAttrAsString(o, "source_file");
+ absl::optional<string> source_file = GetAttrAsString(o, "source_file");
if (source_file.has_value()) {
result.set_source_file(source_file.value());
}
- tensorflow::gtl::optional<int32> source_line =
- GetAttrAsInt32(o, "source_line");
+ absl::optional<int32> source_line = GetAttrAsInt32(o, "source_line");
if (source_line.has_value()) {
result.set_source_line(source_line.value());
}
@@ -370,10 +368,10 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) {
}
}
-StatusOr<std::unique_ptr<Literal>> XlaLiteralFromPyObject(PyObject* o) {
+StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o) {
if (PyTuple_Check(o)) {
int num_elements = PyTuple_Size(o);
- std::vector<std::unique_ptr<Literal>> elements;
+ std::vector<Literal> elements;
elements.reserve(num_elements);
for (int i = 0; i < num_elements; i++) {
PyObject* element = PyTuple_GetItem(o, i);
@@ -391,8 +389,7 @@ StatusOr<std::unique_ptr<Literal>> XlaLiteralFromPyObject(PyObject* o) {
int np_type = PyArray_TYPE(py_array);
auto literal = LiteralUtil::CreateFromDimensions(
NumpyTypeToPrimitiveType(np_type), dimensions);
- TF_RETURN_IF_ERROR(
- CopyNumpyArrayToLiteral(np_type, py_array, literal.get()));
+ TF_RETURN_IF_ERROR(CopyNumpyArrayToLiteral(np_type, py_array, &literal));
return std::move(literal);
} else {
return InvalidArgument(
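
`absl::optional` replaces `tensorflow::gtl::optional` one-for-one in the attribute helpers above: missing, `None`, or out-of-range attributes all map to `absl::nullopt`. A self-contained sketch of that shape, using a plain map as a hypothetical stand-in for the `PyObject*` attribute lookups:

    #include <cstdint>
    #include <map>
    #include <string>

    #include "absl/types/optional.h"

    // Returns nullopt when the attribute is absent or does not fit in int32,
    // mirroring GetAttrAsInt32 above.
    absl::optional<int32_t> GetAttrAsInt32(
        const std::map<std::string, long>& attrs, const std::string& name) {
      auto it = attrs.find(name);
      if (it == attrs.end()) return absl::nullopt;
      long value = it->second;
      if (static_cast<int32_t>(value) != value) return absl::nullopt;
      return static_cast<int32_t>(value);
    }

    int main() {
      std::map<std::string, long> attrs = {{"source_line", 42}};
      absl::optional<int32_t> line = GetAttrAsInt32(attrs, "source_line");
      absl::optional<int32_t> name = GetAttrAsInt32(attrs, "op_name");
      return (line.has_value() && *line == 42 && !name.has_value()) ? 0 : 1;
    }
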
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h
index a67c93a4fb..40ff2d9ad2 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.h
+++ b/tensorflow/compiler/xla/python/numpy_bridge.h
@@ -25,9 +25,9 @@ limitations under the License.
#include <algorithm>
#include <memory>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/python/lib/core/numpy.h"
namespace xla {
@@ -82,7 +82,7 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal);
// To avoid transferring ownership of the data buffers that underlie
// PyArrays and XLA literals, this function makes deep copies of all
// array data.
-StatusOr<std::unique_ptr<Literal> > XlaLiteralFromPyObject(PyObject* o);
+StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o);
// The following functions copy array data from the buffers underlying Numpy
// ndarrays into those underlying XLA literals, and vice versa.
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index a2c6fc344d..fa4366ff07 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -105,7 +105,6 @@ _UNARY_OPS = [
'Square',
'Reciprocal',
'Neg',
- 'Sort',
'Erf',
'Erfc',
'ErfInv',
@@ -1218,6 +1217,14 @@ class ComputationBuilder(object):
lhs_dilation, rhs_dilation,
dimension_numbers)
+ def Sort(self, operand, dimension=-1):
+ """Enqueues a sort operation onto the computation."""
+ return self._client.Sort(operand, dimension)
+
+ def SortKeyVal(self, keys, values, dimension=-1):
+ """Enqueues a key-value sort operation onto the computation."""
+ return self._client.SortKeyVal(keys, values, dimension)
+
def _forward_methods_to_local_builder():
"""Forward remaining ComputationBuilder methods to the C API.
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index a803520876..05325367f5 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <array>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
@@ -43,7 +44,7 @@ std::unique_ptr<Array2D<T>> MatmulArray2DImpl(
int m = lhs.height();
int n = rhs.width();
int k = lhs.width();
- auto result = MakeUnique<Array2D<T>>(m, n);
+ auto result = absl::make_unique<Array2D<T>>(m, n);
// Because Eigen is a header-oriented library, make sure that the Eigen code
// is the same as the code used by the CPU backend (otherwise the linker will
// randomly pick *some* definition).
@@ -77,7 +78,8 @@ std::unique_ptr<Array2D<T>> MatmulArray2DImpl(
/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64(
const Array2D<float>& input) {
- auto result = MakeUnique<Array2D<double>>(input.height(), input.width());
+ auto result =
+ absl::make_unique<Array2D<double>>(input.height(), input.width());
for (int64 rowno = 0; rowno < input.height(); ++rowno) {
for (int64 colno = 0; colno < input.height(); ++colno) {
(*result)(rowno, colno) = input(rowno, colno);
@@ -106,17 +108,15 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated(
// array by adding a fourth dummy dimension of size 1 without stride, padding
// and dilation.
Array4D<float> a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1);
- a4dlhs.Each(
- [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) {
- CHECK_EQ(indices[3], 0);
- *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]);
- });
+ a4dlhs.Each([&](absl::Span<const int64> indices, float* value_ptr) {
+ CHECK_EQ(indices[3], 0);
+ *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]);
+ });
Array4D<float> a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1);
- a4drhs.Each(
- [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) {
- CHECK_EQ(indices[3], 0);
- *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]);
- });
+ a4drhs.Each([&](absl::Span<const int64> indices, float* value_ptr) {
+ CHECK_EQ(indices[3], 0);
+ *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]);
+ });
   // Add a second dummy spatial dimension.
ConvolutionDimensionNumbers dnums2d = dnums;
dnums2d.add_input_spatial_dimensions(3);
@@ -126,13 +126,12 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated(
a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1},
{rhs_dilation, 1}, dnums2d);
- auto convr3 = MakeUnique<Array3D<float>>(convr4->planes(), convr4->depth(),
- convr4->height());
- convr4->Each(
- [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) {
- CHECK_EQ(indices[3], 0);
- convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr;
- });
+ auto convr3 = absl::make_unique<Array3D<float>>(
+ convr4->planes(), convr4->depth(), convr4->height());
+ convr4->Each([&](absl::Span<const int64> indices, float* value_ptr) {
+ CHECK_EQ(indices[3], 0);
+ convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr;
+ });
return convr3;
}
@@ -187,11 +186,11 @@ ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
/* static */ std::unique_ptr<std::vector<float>>
ReferenceUtil::ReduceWindow1DGeneric(
- const tensorflow::gtl::ArraySlice<float>& operand, float init,
+ const absl::Span<const float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride,
- const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) {
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride,
+ const absl::Span<const std::pair<int64, int64>>& padding) {
std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
std::vector<int64> window_counts(window.size(), 0);
std::vector<int64> pad_low(window.size(), 0);
@@ -201,7 +200,7 @@ ReferenceUtil::ReduceWindow1DGeneric(
window_util::StridedBound(padded_width, window[i], stride[i]);
pad_low[i] = padding[i].first;
}
- auto result = MakeUnique<std::vector<float>>(window_counts[0]);
+ auto result = absl::make_unique<std::vector<float>>(window_counts[0]);
// Do a full 1D reduce window.
for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
@@ -219,10 +218,11 @@ ReferenceUtil::ReduceWindow1DGeneric(
}
/* static */ std::unique_ptr<std::vector<float>>
-ReferenceUtil::ReduceWindow1DAdd(
- const tensorflow::gtl::ArraySlice<float>& operand, float init,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
+ReferenceUtil::ReduceWindow1DAdd(const absl::Span<const float>& operand,
+ float init,
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride,
+ Padding padding) {
const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
return ReduceWindow1DGeneric(
@@ -234,9 +234,9 @@ ReferenceUtil::ReduceWindow1DAdd(
ReferenceUtil::ReduceWindow2DGeneric(
const Array2D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride,
- const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) {
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride,
+ const absl::Span<const std::pair<int64, int64>>& padding) {
std::vector<int64> dim_lengths{operand.height(), operand.width()};
std::vector<int64> window_counts(window.size(), 0);
@@ -247,7 +247,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
window_util::StridedBound(padded_width, window[i], stride[i]);
pad_low[i] = padding[i].first;
}
- auto result = MakeUnique<Array2D<float>>(window_counts[0], window_counts[1]);
+ auto result =
+ absl::make_unique<Array2D<float>>(window_counts[0], window_counts[1]);
// Do a full 2D reduce window.
for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
@@ -273,8 +274,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::ReduceWindow2DAdd(
const Array2D<float>& operand, float init,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride, Padding padding) {
const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
std::vector<int64> dim_lengths{operand.height(), operand.width()};
return ReduceWindow2DGeneric(
@@ -284,8 +285,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
/* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ReduceWindow3DAdd(
const Array3D<float>& operand, float init,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride, Padding padding) {
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3()};
auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
@@ -296,8 +297,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
WindowCount(dim_lengths[i], window[i], stride[i], padding);
pad_low[i] = padding_both[i].first;
}
- auto result = MakeUnique<Array3D<float>>(window_counts[0], window_counts[1],
- window_counts[2]);
+ auto result = absl::make_unique<Array3D<float>>(
+ window_counts[0], window_counts[1], window_counts[2]);
for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
@@ -331,8 +332,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
ReferenceUtil::ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride, Padding padding) {
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
operand.n4()};
return ReduceWindow4DGeneric(
@@ -344,9 +345,9 @@ ReferenceUtil::ReduceWindow4DGeneric(
ReferenceUtil::ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride,
- const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) {
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride,
+ const absl::Span<const std::pair<int64, int64>>& padding) {
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
operand.n4()};
@@ -358,8 +359,8 @@ ReferenceUtil::ReduceWindow4DGeneric(
window_util::StridedBound(padded_width, window[i], stride[i]);
pad_low[i] = padding[i].first;
}
- auto result = MakeUnique<Array4D<float>>(window_counts[0], window_counts[1],
- window_counts[2], window_counts[3]);
+ auto result = absl::make_unique<Array4D<float>>(
+ window_counts[0], window_counts[1], window_counts[2], window_counts[3]);
// Do a full 4D reduce window.
for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
@@ -399,8 +400,8 @@ ReferenceUtil::ReduceWindow4DGeneric(
/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
const Array4D<float>& operand, float init,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride, Padding padding) {
const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
padding);
@@ -421,13 +422,15 @@ ReferenceUtil::ReduceWindow4DGeneric(
}
/* static */ std::unique_ptr<Array4D<float>>
-ReferenceUtil::SelectAndScatter4DGePlus(
- const Array4D<float>& operand, const Array4D<float>& source, float init,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, bool same_padding) {
+ReferenceUtil::SelectAndScatter4DGePlus(const Array4D<float>& operand,
+ const Array4D<float>& source,
+ float init,
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride,
+ bool same_padding) {
Padding padding = same_padding ? Padding::kSame : Padding::kValid;
- auto result = MakeUnique<Array4D<float>>(operand.n1(), operand.n2(),
- operand.n3(), operand.n4());
+ auto result = absl::make_unique<Array4D<float>>(operand.n1(), operand.n2(),
+ operand.n3(), operand.n4());
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
operand.n4()};
auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
@@ -526,13 +529,13 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
}
ordered_input_dimensions[0] =
- lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0));
+ lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0));
ordered_input_dimensions[1] =
- lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1));
+ lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1));
ordered_kernel_dimensions[0] =
- rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0));
+ rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0));
ordered_kernel_dimensions[1] =
- rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1));
+ rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1));
std::vector<std::pair<int64, int64>> paddings =
MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
@@ -543,7 +546,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
WindowDimension dim;
dim.set_size(
- rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)));
+ rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)));
dim.set_stride(kernel_stride.first);
dim.set_padding_low(paddings[0].first);
dim.set_padding_high(paddings[0].second);
@@ -553,7 +556,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
WindowDimension dim2;
dim2.set_size(
- rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)));
+ rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)));
dim2.set_stride(kernel_stride.second);
dim2.set_padding_low(paddings[1].first);
dim2.set_padding_high(paddings[1].second);
@@ -561,35 +564,39 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
dim2.set_base_dilation(lhs_dilation.second);
*window.add_dimensions() = dim2;
- const Shape& shape =
- ShapeInference::InferConvolveShape(lhs_literal->shape(),
- rhs_literal->shape(), window, dnums)
- .ConsumeValueOrDie();
+ const Shape& shape = ShapeInference::InferConvolveShape(
+ lhs_literal.shape(), rhs_literal.shape(),
+ /*feature_group_count=*/1, window, dnums)
+ .ConsumeValueOrDie();
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
+ PrecisionConfig precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ /*new_size=*/2, PrecisionConfig::DEFAULT);
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, precision_config));
HloModuleConfig config;
HloModule module("ReferenceUtil", config);
auto computation = module.AddEntryComputation(b.Build());
HloEvaluator evaluator;
- std::unique_ptr<Literal> result_literal =
+ Literal result_literal =
evaluator.Evaluate<const Literal*>(*computation, {}).ConsumeValueOrDie();
- CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4);
+ CHECK_EQ(ShapeUtil::Rank(result_literal.shape()), 4);
auto result =
- MakeUnique<Array4D<float>>(result_literal->shape().dimensions(0),
- result_literal->shape().dimensions(1),
- result_literal->shape().dimensions(2),
- result_literal->shape().dimensions(3));
+ absl::make_unique<Array4D<float>>(result_literal.shape().dimensions(0),
+ result_literal.shape().dimensions(1),
+ result_literal.shape().dimensions(2),
+ result_literal.shape().dimensions(3));
- result->Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
- *value = result_literal->Get<float>(indices);
+ result->Each([&](absl::Span<const int64> indices, float* value) {
+ *value = result_literal.Get<float>(indices);
});
return result;
@@ -601,7 +608,7 @@ ReferenceUtil::ReduceToColArray2D(
const std::function<float(float, float)>& reduce_function) {
int64 rows = matrix.height();
int64 cols = matrix.width();
- auto result = MakeUnique<std::vector<float>>();
+ auto result = absl::make_unique<std::vector<float>>();
for (int64 i = 0; i < rows; ++i) {
float acc = init;
for (int64 j = 0; j < cols; ++j) {
@@ -618,7 +625,7 @@ ReferenceUtil::ReduceToRowArray2D(
const std::function<float(float, float)>& reduce_function) {
int64 rows = matrix.height();
int64 cols = matrix.width();
- auto result = MakeUnique<std::vector<float>>();
+ auto result = absl::make_unique<std::vector<float>>();
for (int64 i = 0; i < cols; ++i) {
float acc = init;
for (int64 j = 0; j < rows; ++j) {
@@ -630,8 +637,7 @@ ReferenceUtil::ReduceToRowArray2D(
}
/*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D(
- const Array4D<float>& array, float init,
- tensorflow::gtl::ArraySlice<int64> dims,
+ const Array4D<float>& array, float init, absl::Span<const int64> dims,
const std::function<float(float, float)>& reduce_function) {
std::vector<float> result;
CHECK_EQ(dims.size(), 3);
@@ -674,8 +680,8 @@ ReferenceUtil::ReduceToRowArray2D(
/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::Broadcast1DTo4D(
const std::vector<float>& array, const std::vector<int64>& bounds,
int64 broadcast_from_dim) {
- auto result =
- MakeUnique<Array4D<float>>(bounds[0], bounds[1], bounds[2], bounds[3]);
+ auto result = absl::make_unique<Array4D<float>>(bounds[0], bounds[1],
+ bounds[2], bounds[3]);
for (int64 i = 0; i < result->n1(); ++i) {
for (int64 j = 0; j < result->n2(); ++j) {
for (int64 k = 0; k < result->n3(); ++k) {
@@ -704,13 +710,12 @@ ReferenceUtil::ReduceToRowArray2D(
}
/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D(
- const Array3D<float>& array, float init,
- tensorflow::gtl::ArraySlice<int64> dims,
+ const Array3D<float>& array, float init, absl::Span<const int64> dims,
const std::function<float(float, float)>& reduce_function) {
CHECK_EQ(dims.size(), 1);
int64 rows = dims[0] == 0 ? array.n2() : array.n1();
int64 cols = dims[0] == 2 ? array.n2() : array.n3();
- auto result = MakeUnique<Array2D<float>>(rows, cols);
+ auto result = absl::make_unique<Array2D<float>>(rows, cols);
result->Fill(init);
for (int i0 = 0; i0 < array.n1(); ++i0) {
for (int i1 = 0; i1 < array.n2(); ++i1) {
@@ -730,7 +735,7 @@ ReferenceUtil::ReduceToRowArray2D(
const std::function<float(float)>& map_function) {
int64 rows = matrix.height();
int64 cols = matrix.width();
- auto result = MakeUnique<Array2D<float>>(rows, cols);
+ auto result = absl::make_unique<Array2D<float>>(rows, cols);
for (int64 i = 0; i < rows; ++i) {
for (int64 j = 0; j < cols; ++j) {
(*result)(i, j) = map_function(matrix(i, j));
@@ -746,7 +751,7 @@ ReferenceUtil::ReduceToRowArray2D(
CHECK_EQ(lhs.width(), rhs.width());
int64 rows = lhs.height();
int64 cols = rhs.width();
- auto result = MakeUnique<Array2D<float>>(rows, cols);
+ auto result = absl::make_unique<Array2D<float>>(rows, cols);
for (int64 i = 0; i < rows; ++i) {
for (int64 j = 0; j < cols; ++j) {
(*result)(i, j) = map_function(lhs(i, j), rhs(i, j));
@@ -760,7 +765,7 @@ ReferenceUtil::ReduceToRowArray2D(
const std::function<float(float, int64, int64)>& map_function) {
int64 rows = matrix.height();
int64 cols = matrix.width();
- auto result = MakeUnique<Array2D<float>>(rows, cols);
+ auto result = absl::make_unique<Array2D<float>>(rows, cols);
for (int64 i = 0; i < rows; ++i) {
for (int64 j = 0; j < cols; ++j) {
(*result)(i, j) = map_function(matrix(i, j), i, j);
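
The remaining churn in this file is mechanical: XLA's local `MakeUnique` helper (from `ptr_util.h`, whose include is dropped below) is replaced by `absl::make_unique`, which has the same call shape under a different namespace. A minimal sketch with a hypothetical `Grid` struct in place of the real `Array2D`:

    #include <memory>

    #include "absl/memory/memory.h"

    struct Grid {
      Grid(int h, int w) : height(h), width(w) {}
      int height, width;
    };

    int main() {
      // Was: auto g = MakeUnique<Grid>(3, 4);
      auto g = absl::make_unique<Grid>(3, 4);
      return (g->height == 3 && g->width == 4) ? 0 : 1;
    }
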
diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h
index 8fa6961d19..9ce098029d 100644
--- a/tensorflow/compiler/xla/reference_util.h
+++ b/tensorflow/compiler/xla/reference_util.h
@@ -22,14 +22,14 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/padding.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -42,7 +42,8 @@ class ReferenceUtil {
template <typename T>
static std::unique_ptr<Array2D<T>> TransposeArray2D(
const Array2D<T>& operand) {
- auto result = MakeUnique<Array2D<T>>(operand.width(), operand.height());
+ auto result =
+ absl::make_unique<Array2D<T>>(operand.width(), operand.height());
for (int64 w = 0; w < operand.width(); ++w) {
for (int64 h = 0; h < operand.height(); ++h) {
(*result)(w, h) = operand(h, w);
@@ -143,8 +144,7 @@ class ReferenceUtil {
// Returns the result of reducing the 4D array to a vector, reducing away
// the dimensions specified in dims.
static std::vector<float> Reduce4DTo1D(
- const Array4D<float>& array, float init,
- tensorflow::gtl::ArraySlice<int64> dims,
+ const Array4D<float>& array, float init, absl::Span<const int64> dims,
const std::function<float(float, float)>& reduce_function);
// Broadcast 1D dimension to 4D, from the dimension `broadcast_from_dim`.
@@ -155,8 +155,7 @@ class ReferenceUtil {
// Returns the result of reducing the 3D array to a 2D array, reducing away
// the dimensions specified in dims.
static std::unique_ptr<Array2D<float>> Reduce3DTo2D(
- const Array3D<float>& array, float init,
- tensorflow::gtl::ArraySlice<int64> dims,
+ const Array3D<float>& array, float init, absl::Span<const int64> dims,
const std::function<float(float, float)>& reduce_function);
// Applies map_function to each element in the input (2D array) and returns
@@ -178,47 +177,47 @@ class ReferenceUtil {
// Windowed reductions with Add as the function to apply.
static std::unique_ptr<std::vector<float>> ReduceWindow1DAdd(
- const tensorflow::gtl::ArraySlice<float>& operand, float init,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
+ const absl::Span<const float>& operand, float init,
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride, Padding padding);
static std::unique_ptr<Array2D<float>> ReduceWindow2DAdd(
const Array2D<float>& operand, float init,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride, Padding padding);
static std::unique_ptr<Array3D<float>> ReduceWindow3DAdd(
const Array3D<float>& operand, float init,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride, Padding padding);
static std::unique_ptr<Array4D<float>> ReduceWindow4DAdd(
const Array4D<float>& operand, float init,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride, Padding padding);
// Windowed reductions with a generic reduce function.
static std::unique_ptr<std::vector<float>> ReduceWindow1DGeneric(
- const tensorflow::gtl::ArraySlice<float>& operand, float init,
+ const absl::Span<const float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride,
- const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding);
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride,
+ const absl::Span<const std::pair<int64, int64>>& padding);
static std::unique_ptr<Array2D<float>> ReduceWindow2DGeneric(
const Array2D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride,
- const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding);
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride,
+ const absl::Span<const std::pair<int64, int64>>& padding);
static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride, Padding padding);
// With arbitrary padding.
static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride,
- const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding);
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride,
+ const absl::Span<const std::pair<int64, int64>>& padding);
// Batch normalize data.
static std::unique_ptr<Array4D<float>> BatchNorm4D(
@@ -231,8 +230,8 @@ class ReferenceUtil {
// TODO(b/74533103) Switch tests to evaluator and remove this implementation.
static std::unique_ptr<Array4D<float>> SelectAndScatter4DGePlus(
const Array4D<float>& operand, const Array4D<float>& source, float init,
- const tensorflow::gtl::ArraySlice<int64>& window,
- const tensorflow::gtl::ArraySlice<int64>& stride, bool same_padding);
+ const absl::Span<const int64>& window,
+ const absl::Span<const int64>& stride, bool same_padding);
// Concatenates the lhs and rhs arrays along the concatenate_dimension.
// E.g. if concatenate_dimension is 0, the "n1"/height dimension is
@@ -242,7 +241,7 @@ class ReferenceUtil {
const Array2D<T>& rhs,
int concatenate_dimension) {
CHECK(0 <= concatenate_dimension && concatenate_dimension < 2);
- auto result = MakeUnique<Array2D<T>>(
+ auto result = absl::make_unique<Array2D<T>>(
concatenate_dimension == 0 ? lhs.n1() + rhs.n1() : lhs.n1(),
concatenate_dimension == 1 ? lhs.n2() + rhs.n2() : lhs.n2());
for (int64 i0 = 0; i0 < result->n1(); ++i0) {
@@ -276,7 +275,8 @@ class ReferenceUtil {
out_dims[i] = lhs_dims[i] + rhs_dims[i];
}
}
- auto result = MakeUnique<Array3D<T>>(out_dims[0], out_dims[1], out_dims[2]);
+ auto result =
+ absl::make_unique<Array3D<T>>(out_dims[0], out_dims[1], out_dims[2]);
for (int64 i0 = 0; i0 < result->n1(); ++i0) {
for (int64 i1 = 0; i1 < result->n2(); ++i1) {
for (int64 i2 = 0; i2 < result->n3(); ++i2) {
@@ -310,8 +310,8 @@ class ReferenceUtil {
out_dims[i] = lhs_dims[i] + rhs_dims[i];
}
}
- auto result = MakeUnique<Array4D<T>>(out_dims[0], out_dims[1], out_dims[2],
- out_dims[3]);
+ auto result = absl::make_unique<Array4D<T>>(out_dims[0], out_dims[1],
+ out_dims[2], out_dims[3]);
for (int64 i0 = 0; i0 < result->n1(); ++i0) {
for (int64 i1 = 0; i1 < result->n2(); ++i1) {
for (int64 i2 = 0; i2 < result->n3(); ++i2) {
@@ -332,8 +332,8 @@ class ReferenceUtil {
// Slices with index clamping
template <typename T>
- static std::vector<T> ClampSlice1D(
- const tensorflow::gtl::ArraySlice<T>& input, int64 start, int64 size) {
+ static std::vector<T> ClampSlice1D(const absl::Span<const T>& input,
+ int64 start, int64 size) {
start = std::min<int64>(std::max<int64>(0, start), input.size() - size);
std::vector<T> result;
for (int64 i = 0; i < size; ++i) {
@@ -355,9 +355,9 @@ class ReferenceUtil {
CHECK_LE(limits[1], input.n2());
CHECK_GE(strides[0], 1);
CHECK_GE(strides[1], 1);
- auto result =
- MakeUnique<Array2D<T>>(CeilOfRatio(limits[0] - starts[0], strides[0]),
- CeilOfRatio(limits[1] - starts[1], strides[1]));
+ auto result = absl::make_unique<Array2D<T>>(
+ CeilOfRatio(limits[0] - starts[0], strides[0]),
+ CeilOfRatio(limits[1] - starts[1], strides[1]));
for (int64 i0 = 0; i0 < result->n1(); ++i0) {
for (int64 i1 = 0; i1 < result->n2(); ++i1) {
(*result)(i0, i1) =
@@ -381,10 +381,10 @@ class ReferenceUtil {
CHECK_GE(strides[0], 1);
CHECK_GE(strides[1], 1);
CHECK_GE(strides[2], 1);
- auto result =
- MakeUnique<Array3D<T>>(CeilOfRatio(limits[0] - starts[0], strides[0]),
- CeilOfRatio(limits[1] - starts[1], strides[1]),
- CeilOfRatio(limits[2] - starts[2], strides[2]));
+ auto result = absl::make_unique<Array3D<T>>(
+ CeilOfRatio(limits[0] - starts[0], strides[0]),
+ CeilOfRatio(limits[1] - starts[1], strides[1]),
+ CeilOfRatio(limits[2] - starts[2], strides[2]));
for (int64 i0 = 0; i0 < result->n1(); ++i0) {
for (int64 i1 = 0; i1 < result->n2(); ++i1) {
@@ -415,11 +415,11 @@ class ReferenceUtil {
CHECK_GE(strides[1], 1);
CHECK_GE(strides[2], 1);
CHECK_GE(strides[3], 1);
- auto result =
- MakeUnique<Array4D<T>>(CeilOfRatio(limits[0] - starts[0], strides[0]),
- CeilOfRatio(limits[1] - starts[1], strides[1]),
- CeilOfRatio(limits[2] - starts[2], strides[2]),
- CeilOfRatio(limits[3] - starts[3], strides[3]));
+ auto result = absl::make_unique<Array4D<T>>(
+ CeilOfRatio(limits[0] - starts[0], strides[0]),
+ CeilOfRatio(limits[1] - starts[1], strides[1]),
+ CeilOfRatio(limits[2] - starts[2], strides[2]),
+ CeilOfRatio(limits[3] - starts[3], strides[3]));
for (int64 i0 = 0; i0 < result->n1(); ++i0) {
for (int64 i1 = 0; i1 < result->n2(); ++i1) {
for (int64 i2 = 0; i2 < result->n3(); ++i2) {
@@ -460,8 +460,8 @@ class ReferenceUtil {
template <typename F>
static std::unique_ptr<Array4D<float>> MapWithIndexArray4D(
const Array4D<float>& input, F&& map_function) {
- auto result = MakeUnique<Array4D<float>>(input.planes(), input.depth(),
- input.height(), input.width());
+ auto result = absl::make_unique<Array4D<float>>(
+ input.planes(), input.depth(), input.height(), input.width());
for (int64 plane = 0; plane < input.planes(); ++plane) {
for (int64 depth = 0; depth < input.depth(); ++depth) {
for (int64 height = 0; height < input.height(); ++height) {
@@ -495,8 +495,8 @@ class ReferenceUtil {
template <typename F>
static std::unique_ptr<Array4D<float>> MapWithIndexArray4D(
const Array4D<float>& lhs, const Array4D<float>& rhs, F&& map_function) {
- auto result = MakeUnique<Array4D<float>>(lhs.planes(), lhs.depth(),
- lhs.height(), lhs.width());
+ auto result = absl::make_unique<Array4D<float>>(lhs.planes(), lhs.depth(),
+ lhs.height(), lhs.width());
for (int64 plane = 0; plane < lhs.planes(); ++plane) {
for (int64 depth = 0; depth < lhs.depth(); ++depth) {
for (int64 height = 0; height < lhs.height(); ++height) {
@@ -530,7 +530,7 @@ class ReferenceUtil {
int64 out1 =
in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1;
- auto result = MakeUnique<Array2D<NativeT>>(out0, out1);
+ auto result = absl::make_unique<Array2D<NativeT>>(out0, out1);
result->Fill(pad);
int64 o0 = low_padding0;
for (int64 i0 = 0; i0 < in0; ++i0) {
@@ -631,7 +631,7 @@ class ReferenceUtil {
Array4D<NativeT> result(output_bounds[0], output_bounds[1],
output_bounds[2], output_bounds[3]);
result.Each(
- [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT* value) {
+ [&](absl::Span<const int64> indices, NativeT* value) {
for (int i = 0; i < 4; ++i) {
bool in_low_padding = indices[i] < pad_low[i];
bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
@@ -669,7 +669,7 @@ class ReferenceUtil {
static std::unique_ptr<Array2D<T1>> ApplyElementwise2D(
F&& f, const Array2D<T1>& array1, const Array2D<Ts>&... arrays) {
AssertSameSize2D(array1, arrays...);
- auto result = MakeUnique<Array2D<T1>>(array1.n1(), array1.n2());
+ auto result = absl::make_unique<Array2D<T1>>(array1.n1(), array1.n2());
for (int64 i = 0; i < array1.n1(); ++i) {
for (int64 j = 0; j < array1.n2(); ++j) {
(*result)(i, j) = f(array1(i, j), arrays(i, j)...);
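The reference_util.h hunks above are a mechanical migration: every `tensorflow::gtl::ArraySlice<T>` becomes `absl::Span<const T>`, and every `MakeUnique<T>` becomes `absl::make_unique<T>`. A minimal standalone sketch of the same substitution (the function and values below are illustrative, not from the patch); note the patch keeps the `const ... &` qualifiers to stay a one-for-one rewrite, although absl style normally passes `Span` by value:

  #include <cstdint>
  #include <memory>
  #include <vector>

  #include "absl/memory/memory.h"
  #include "absl/types/span.h"

  // absl::Span<const T> is a non-owning view, the drop-in replacement for
  // tensorflow::gtl::ArraySlice<T>.
  int64_t Sum(absl::Span<const int64_t> values) {
    int64_t total = 0;
    for (int64_t v : values) total += v;
    return total;
  }

  int main() {
    // absl::make_unique replaces the removed MakeUnique helper.
    auto data = absl::make_unique<std::vector<int64_t>>(
        std::vector<int64_t>{1, 2, 3});
    return Sum(*data) == 6 ? 0 : 1;
  }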
diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc
index 8091bed499..a1b0f4045f 100644
--- a/tensorflow/compiler/xla/reference_util_test.cc
+++ b/tensorflow/compiler/xla/reference_util_test.cc
@@ -18,12 +18,12 @@ limitations under the License.
#include <cmath>
#include <memory>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -36,7 +36,7 @@ namespace {
class ReferenceUtilTest : public ::testing::Test {
protected:
ReferenceUtilTest() {
- matrix_ = MakeUnique<Array2D<float>>(rows_, cols_);
+ matrix_ = absl::make_unique<Array2D<float>>(rows_, cols_);
// [1.f 2.f 3.f]
// [4.f 5.f 6.f]
for (int64 i = 0; i < rows_; ++i) {
@@ -55,7 +55,7 @@ TEST_F(ReferenceUtilTest, TransposeArray2D) {
auto result = ReferenceUtil::TransposeArray2D(*matrix_);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, MatmulArray2D) {
@@ -67,14 +67,14 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) {
auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{58.f, 64.f}, {139.f, 154.f}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, ReduceToColArray2D) {
auto add = [](float lhs, float rhs) { return lhs + rhs; };
auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add);
auto actual_literal = LiteralUtil::CreateR1<float>(*result);
- LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, *actual_literal,
+ LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, actual_literal,
ErrorSpec(0.0001));
}
@@ -82,7 +82,7 @@ TEST_F(ReferenceUtilTest, ReduceToRowArray2D) {
auto add = [](float lhs, float rhs) { return lhs + rhs; };
auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add);
auto actual_literal = LiteralUtil::CreateR1<float>(*result);
- LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, *actual_literal,
+ LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, actual_literal,
ErrorSpec(0.0001));
}
@@ -90,14 +90,14 @@ TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) {
auto result = LiteralUtil::CreateR1<float>(ReferenceUtil::Reduce4DTo1D(
Array4D<float>(1, 0, 1, 1), /*init=*/0, /*dims=*/{0, 1, 2},
[](float a, float b) { return a + b; }));
- LiteralTestUtil::ExpectR1Equal<float>({0}, *result);
+ LiteralTestUtil::ExpectR1Equal<float>({0}, result);
}
TEST_F(ReferenceUtilTest, MapArray2D) {
auto identity = [](float value) { return log(exp(value)); };
auto result = ReferenceUtil::MapArray2D(*matrix_, identity);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
- LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal,
+ LiteralTestUtil::ExpectR2NearArray2D(*matrix_, actual_literal,
ErrorSpec(0.0001));
}
@@ -108,12 +108,12 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) {
auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, MapArray4D) {
- auto input = MakeUnique<Array4D<float>>(/*planes=*/2, /*depth=*/3,
- /*height=*/4, /*width=*/5);
+ auto input = absl::make_unique<Array4D<float>>(/*planes=*/2, /*depth=*/3,
+ /*height=*/4, /*width=*/5);
input->FillWithMultiples(1.0f);
auto multiply_by_two = [](float value) { return 2 * value; };
auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two);
@@ -121,13 +121,13 @@ TEST_F(ReferenceUtilTest, MapArray4D) {
Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
expected.FillWithMultiples(2.0f);
- LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal,
ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, MapWithIndexArray4D) {
- auto input = MakeUnique<Array4D<float>>(/*planes=*/2, /*depth=*/3,
- /*height=*/4, /*width=*/5);
+ auto input = absl::make_unique<Array4D<float>>(/*planes=*/2, /*depth=*/3,
+ /*height=*/4, /*width=*/5);
input->FillWithMultiples(1.0f);
auto subtract_index = [](float value, int64 plane, int64 depth, int64 height,
int64 width) {
@@ -138,7 +138,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) {
Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
expected.Fill(0.0f);
- LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -146,16 +146,16 @@ TEST_F(ReferenceUtilTest, SliceArray2D) {
auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 2}}, {{1, 1}});
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
- LiteralTestUtil::ExpectR2Near<float>({{1.f, 2.f}, {4.f, 5.f}},
- *actual_literal, ErrorSpec(0.0001));
+ LiteralTestUtil::ExpectR2Near<float>({{1.f, 2.f}, {4.f, 5.f}}, actual_literal,
+ ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, SliceStridedArray2D) {
auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 3}}, {{1, 2}});
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
- LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f}, {4.f, 6.f}},
- *actual_literal, ErrorSpec(0.0001));
+ LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f}, {4.f, 6.f}}, actual_literal,
+ ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, SliceArray3D) {
@@ -167,7 +167,7 @@ TEST_F(ReferenceUtilTest, SliceArray3D) {
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);
LiteralTestUtil::ExpectR3Near<float>(
- {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, *actual_literal,
+ {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, actual_literal,
ErrorSpec(0.0001));
}
@@ -180,8 +180,8 @@ TEST_F(ReferenceUtilTest, SliceStridedArray3D) {
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);
LiteralTestUtil::ExpectR3Near<float>(
- {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}},
- *actual_literal, ErrorSpec(0.0001));
+ {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, actual_literal,
+ ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, SliceArray4D) {
@@ -194,7 +194,7 @@ TEST_F(ReferenceUtilTest, SliceArray4D) {
LiteralTestUtil::ExpectR4Near<float>(
{{{{60.f, 61.f}, {65.f, 66.f}}, {{80.f, 81.f}, {85.f, 86.f}}}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, SliceStridedArray4D) {
@@ -208,7 +208,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray4D) {
LiteralTestUtil::ExpectR4Near<float>(
{{{{60.f, 62.f, 64.f}, {70.f, 72.f, 74.f}},
{{100.f, 102.f, 104.f}, {110.f, 112.f, 114.f}}}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) {
@@ -220,7 +220,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) {
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual);
- LiteralTestUtil::ExpectR3NearArray3D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR3NearArray3D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -233,7 +233,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithValidPadding) {
auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual);
- LiteralTestUtil::ExpectR3NearArray3D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR3NearArray3D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -268,7 +268,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) {
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
- LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -302,7 +302,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) {
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
- LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -358,7 +358,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) {
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
- LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -411,7 +411,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) {
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
- LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+ LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
ErrorSpec(0.0001));
}
@@ -424,7 +424,7 @@ TEST_F(ReferenceUtilTest, ApplyElementwise2D) {
[](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual);
LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}},
- *actual_literal, ErrorSpec(0.0001));
+ actual_literal, ErrorSpec(0.0001));
}
} // namespace
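Most of the test edits above drop a `*`: `LiteralUtil::CreateR1`, `CreateR2FromArray2D`, and friends now return `Literal` by value instead of `std::unique_ptr<Literal>`, so call sites pass the object directly. A self-contained analogue of that API change (the `Literal` struct here is a hypothetical stand-in for xla::Literal, not the real class):

  #include <memory>
  #include <vector>

  struct Literal {  // Hypothetical stand-in for xla::Literal.
    std::vector<float> data;
  };

  // Pre-patch shape: heap-allocated result, '*' at every call site.
  std::unique_ptr<Literal> MakeLiteralOld() {
    return std::make_unique<Literal>(Literal{{1.f, 2.f}});
  }

  // Post-patch shape: Literal is cheaply movable, so return it by value.
  Literal MakeLiteralNew() { return Literal{{1.f, 2.f}}; }

  int main() {
    Literal a = *MakeLiteralOld();  // old call sites dereference
    Literal b = MakeLiteralNew();   // new call sites use the value directly
    return a.data == b.data ? 0 : 1;
  }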
diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD
index 44b22a5586..97fcd37f6b 100644
--- a/tensorflow/compiler/xla/rpc/BUILD
+++ b/tensorflow/compiler/xla/rpc/BUILD
@@ -43,6 +43,7 @@ tf_cc_binary(
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -62,6 +63,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings:str_format",
],
)
diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc
index 6788676181..84fe5b17d1 100644
--- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc
@@ -23,12 +23,12 @@ limitations under the License.
#include "grpcpp/create_channel.h"
#include "grpcpp/security/credentials.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/rpc/grpc_stub.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/subprocess.h"
@@ -46,7 +46,7 @@ class GRPCClientTestBase : public ::testing::Test {
int port = tensorflow::internal::PickUnusedPortOrDie();
subprocess_.SetProgram(
service_main_path,
- {service_main_path, tensorflow::strings::Printf("--port=%d", port)});
+ {service_main_path, absl::StrFormat("--port=%d", port)});
subprocess_.SetChannelAction(tensorflow::CHAN_STDOUT,
tensorflow::ACTION_DUPPARENT);
subprocess_.SetChannelAction(tensorflow::CHAN_STDERR,
@@ -54,9 +54,8 @@ class GRPCClientTestBase : public ::testing::Test {
CHECK(subprocess_.Start());
LOG(INFO) << "Launched subprocess";
- auto channel =
- ::grpc::CreateChannel(tensorflow::strings::Printf("localhost:%d", port),
- ::grpc::InsecureChannelCredentials());
+ auto channel = ::grpc::CreateChannel(absl::StrFormat("localhost:%d", port),
+ ::grpc::InsecureChannelCredentials());
channel->WaitForConnected(gpr_time_add(
gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(10, GPR_TIMESPAN)));
LOG(INFO) << "Channel to server is connected on port " << port;
@@ -96,12 +95,11 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) {
std::vector<float> expected = {
1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796,
6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327};
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR1<float>(expected);
+ Literal expected_literal = LiteralUtil::CreateR1<float>(expected);
TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer(
computation, {}, nullptr));
- EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal,
+ EXPECT_TRUE(LiteralTestUtil::Near(expected_literal, result_literal,
ErrorSpec(0.0001)));
}
diff --git a/tensorflow/compiler/xla/rpc/grpc_service_main.cc b/tensorflow/compiler/xla/rpc/grpc_service_main.cc
index c68c857c30..d6b5149a24 100644
--- a/tensorflow/compiler/xla/rpc/grpc_service_main.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_service_main.cc
@@ -18,8 +18,8 @@ limitations under the License.
#include "grpcpp/security/server_credentials.h"
#include "grpcpp/server.h"
#include "grpcpp/server_builder.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/rpc/grpc_service.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
@@ -44,7 +44,7 @@ int RealMain(int argc, char** argv) {
xla::GRPCService::NewService().ConsumeValueOrDie();
::grpc::ServerBuilder builder;
- string server_address(tensorflow::strings::Printf("localhost:%d", port));
+ string server_address(absl::StrFormat("localhost:%d", port));
builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials());
builder.RegisterService(service.get());
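Both RPC files above replace `tensorflow::strings::Printf` with `absl::StrFormat`, whose format string is checked against the argument types at compile time. A minimal sketch of the new call:

  #include <string>

  #include "absl/strings/str_format.h"

  int main() {
    int port = 50051;
    // The %d placeholder is validated against `port` at compile time; a
    // mismatched specifier fails the build instead of misbehaving at
    // runtime as printf-style formatting can.
    std::string address = absl::StrFormat("localhost:%d", port);
    return address == "localhost:50051" ? 0 : 1;
  }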
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index a65bdebf51..f4e24bff34 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -69,6 +69,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -86,6 +87,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
],
@@ -99,9 +101,11 @@ cc_library(
":bfloat16_support",
":hlo",
":hlo_pass",
+ "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -120,6 +124,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
],
@@ -156,6 +161,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
],
@@ -175,6 +181,10 @@ cc_library(
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -191,6 +201,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -226,6 +237,7 @@ cc_library(
hdrs = ["hlo_evaluator.h"],
deps = [
":hlo",
+ ":hlo_casting_utils",
":hlo_query",
":shape_inference",
"//tensorflow/compiler/xla:literal",
@@ -237,6 +249,12 @@ cc_library(
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
],
)
@@ -263,6 +281,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -275,6 +294,7 @@ cc_library(
"hlo_instructions.cc",
"hlo_module.cc",
"hlo_opcode.cc",
+ "hlo_schedule.cc",
"hlo_sharding.cc",
],
hdrs = [
@@ -287,6 +307,7 @@ cc_library(
"hlo_instructions.h",
"hlo_module.h",
"hlo_opcode.h",
+ "hlo_schedule.h",
"hlo_sharding.h",
],
deps = [
@@ -311,6 +332,13 @@ cc_library(
"//tensorflow/core:human_readable_json",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
],
)
@@ -326,6 +354,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -337,7 +366,7 @@ cc_library(
deps = [
":hlo",
"//tensorflow/compiler/xla:shape_util",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -363,6 +392,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/types:span",
],
)
@@ -375,6 +405,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -389,7 +420,8 @@ cc_library(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -419,6 +451,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -449,6 +482,9 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -466,6 +502,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -517,6 +554,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -535,6 +573,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -552,6 +591,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/strings",
],
)
@@ -574,6 +614,9 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//third_party/eigen3",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -615,6 +658,10 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:ptr_util",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
alwayslink = 1,
)
@@ -647,6 +694,10 @@ cc_library(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -669,6 +720,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/strings",
],
)
@@ -719,6 +771,10 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -736,6 +792,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:ptr_util",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -753,9 +810,11 @@ cc_library(
":hlo_execution_profile",
":hlo_graph_dumper",
":hlo_proto",
+ ":maybe_owning_device_memory",
":shaped_buffer",
":stream_pool",
"//tensorflow/compiler/xla:executable_run_options",
+ "//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -766,6 +825,10 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/stream_executor",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
+ "@com_google_absl//absl/types:variant",
],
)
@@ -784,6 +847,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/types:span",
],
)
@@ -813,6 +877,9 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -831,6 +898,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -847,6 +916,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
)
@@ -864,6 +934,9 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -874,6 +947,7 @@ cc_library(
deps = [
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -908,6 +982,8 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -917,12 +993,15 @@ tf_cc_test(
deps = [
":buffer_liveness",
":hlo",
+ ":hlo_dataflow_analysis",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -939,8 +1018,8 @@ cc_library(
":buffer_value_containers",
":heap_simulator",
":hlo",
+ ":hlo_memory_scheduler",
":hlo_proto",
- ":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -950,6 +1029,10 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -964,8 +1047,8 @@ tf_cc_test(
":cpu_plugin",
":flatten_call_graph",
":hlo",
+ ":hlo_memory_scheduler",
":hlo_ordering",
- ":hlo_scheduling",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -975,8 +1058,11 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -996,6 +1082,8 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -1006,14 +1094,15 @@ tf_cc_test(
deps = [
":hlo",
":hlo_dataflow_analysis",
+ ":hlo_memory_scheduler",
":hlo_ordering",
- ":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:test",
],
)
@@ -1031,6 +1120,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -1047,8 +1137,11 @@ tf_cc_test(
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -1059,12 +1152,15 @@ cc_library(
deps = [
":hlo",
":hlo_casting_utils",
+ ":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -1074,6 +1170,7 @@ cc_library(
hdrs = ["hlo_module_group_util.h"],
deps = [
":hlo",
+ ":hlo_casting_utils",
":hlo_module_group_metadata",
":hlo_reachability",
"//tensorflow/compiler/xla:status",
@@ -1082,17 +1179,41 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+tf_cc_test(
+ name = "hlo_schedule_test",
+ srcs = ["hlo_schedule_test.cc"],
+ deps = [
+ ":heap_simulator",
+ ":hlo",
+ ":hlo_dce",
+ ":hlo_memory_scheduler",
+ ":hlo_ordering",
+ ":hlo_parser",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:test",
+ "@com_google_absl//absl/algorithm:container",
],
)
cc_library(
- name = "hlo_scheduling",
- srcs = ["hlo_scheduling.cc"],
- hdrs = ["hlo_scheduling.h"],
+ name = "hlo_memory_scheduler",
+ srcs = ["hlo_memory_scheduler.cc"],
+ hdrs = ["hlo_memory_scheduler.h"],
deps = [
":heap_simulator",
":hlo",
":hlo_ordering",
+ ":hlo_pass",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -1101,24 +1222,27 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
],
)
tf_cc_test(
- name = "hlo_scheduling_test",
- srcs = ["hlo_scheduling_test.cc"],
+ name = "hlo_memory_scheduler_test",
+ srcs = ["hlo_memory_scheduler_test.cc"],
deps = [
- ":buffer_value",
":heap_simulator",
":hlo",
+ ":hlo_dce",
+ ":hlo_memory_scheduler",
":hlo_ordering",
- ":hlo_scheduling",
+ ":hlo_parser",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:test",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -1142,6 +1266,7 @@ cc_library(
":hlo_pass",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -1167,6 +1292,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -1181,6 +1307,9 @@ cc_library(
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1196,8 +1325,10 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -1216,6 +1347,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1231,6 +1364,7 @@ cc_library(
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -1245,6 +1379,7 @@ cc_library(
":while_util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -1267,6 +1402,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -1276,6 +1412,7 @@ cc_library(
hdrs = ["algebraic_simplifier.h"],
deps = [
":hlo",
+ ":hlo_casting_utils",
":hlo_creation_utils",
":hlo_pass",
":hlo_query",
@@ -1289,6 +1426,11 @@ cc_library(
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1298,6 +1440,7 @@ tf_cc_test(
deps = [
":algebraic_simplifier",
":hlo",
+ ":hlo_casting_utils",
":hlo_matchers",
":hlo_pass",
"//tensorflow/compiler/xla:literal",
@@ -1312,6 +1455,8 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1323,8 +1468,7 @@ cc_library(
":hlo",
":hlo_creation_utils",
":hlo_pass",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -1377,6 +1521,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -1414,6 +1559,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1439,8 +1586,7 @@ cc_library(
deps = [
":hlo",
":hlo_evaluator",
- "//tensorflow/compiler/xla:literal",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -1455,6 +1601,8 @@ cc_library(
":while_loop_analysis",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -1468,6 +1616,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
)
@@ -1567,6 +1716,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/core:test",
],
)
@@ -1582,6 +1732,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -1602,6 +1753,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -1616,6 +1768,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1635,6 +1788,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -1654,6 +1808,8 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
alwayslink = True, # Contains per-platform computation placer registration
)
@@ -1667,6 +1823,8 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -1704,6 +1862,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1744,6 +1903,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
],
)
@@ -1758,6 +1919,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -1789,6 +1951,8 @@ tf_cc_binary(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1797,6 +1961,8 @@ tf_cc_test(
srcs = ["hlo_module_test.cc"],
deps = [
":hlo",
+ ":hlo_matchers",
+ ":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -1805,6 +1971,9 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1820,6 +1989,8 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1847,6 +2018,8 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1864,6 +2037,9 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1882,6 +2058,10 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1923,6 +2103,8 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -1959,6 +2141,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -1979,6 +2162,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1999,6 +2184,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@@ -2016,6 +2202,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/memory",
],
)
@@ -2028,7 +2215,6 @@ cc_library(
":hlo_dataflow_analysis",
":logical_buffer",
":logical_buffer_analysis",
- "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@@ -2036,6 +2222,11 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -2054,6 +2245,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@@ -2086,6 +2278,10 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -2108,6 +2304,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -2127,6 +2324,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/core:test",
],
)
@@ -2175,7 +2373,10 @@ cc_library(
":hlo_pass",
":shape_inference",
"//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -2204,21 +2405,23 @@ cc_library(
":buffer_liveness",
":buffer_value",
":call_graph",
- ":copy_insertion",
":flatten_call_graph",
":hlo",
":hlo_dce",
+ ":hlo_memory_scheduler",
":hlo_ordering",
- ":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
- ":tuple_simplifier",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -2235,6 +2438,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -2258,6 +2462,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -2300,9 +2505,11 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/types:span",
],
)
@@ -2339,6 +2546,9 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -2355,6 +2565,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:inlined_vector",
],
)
@@ -2373,9 +2584,11 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -2392,6 +2605,7 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -2402,6 +2616,7 @@ tf_cc_test(
":hlo",
":hlo_constant_folding",
":hlo_matchers",
+ ":hlo_parser",
":hlo_pass",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@@ -2423,6 +2638,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -2437,6 +2653,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
)
@@ -2497,6 +2715,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -2543,6 +2762,22 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+cc_library(
+ name = "maybe_owning_device_memory",
+ srcs = [
+ "maybe_owning_device_memory.cc",
+ ],
+ hdrs = [
+ "maybe_owning_device_memory.h",
+ ],
+ deps = [
+ ":device_memory_allocator",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:variant",
],
)
@@ -2552,6 +2787,7 @@ cc_library(
hdrs = ["elemental_ir_emitter.h"],
deps = [
":hlo",
+ ":hlo_casting_utils",
":hlo_module_config",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -2560,11 +2796,14 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+ "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
"@llvm//:transform_utils",
],
@@ -2596,10 +2835,11 @@ cc_library(
":computation_layout",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_proto",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -2612,6 +2852,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -2648,8 +2889,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/core:framework",
- "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/strings",
],
)
@@ -2659,6 +2900,7 @@ tf_cc_test(
deps = [
":hlo_tfgraph_builder",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
],
@@ -2683,6 +2925,9 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:regexp_internal",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:optional",
],
alwayslink = 1,
)
@@ -2699,6 +2944,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -2780,9 +3026,9 @@ cc_library(
hdrs = ["stream_pool.h"],
deps = [
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
)
@@ -2880,6 +3126,8 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//third_party/eigen3",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
)
@@ -2900,7 +3148,7 @@ cc_library(
hdrs = ["tuple_util.h"],
deps = [
":hlo",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -2926,7 +3174,8 @@ cc_library(
":hlo_creation_utils",
":tuple_util",
"//tensorflow/compiler/xla:literal_util",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
],
)
@@ -2940,6 +3189,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -2955,6 +3205,8 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:inlined_vector",
],
)
@@ -2982,6 +3234,8 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:inlined_vector",
],
)
@@ -3015,13 +3269,13 @@ cc_library(
cc_library(
name = "source_map_util",
- srcs = ["source_map_util.cc"],
+ srcs = [],
hdrs = ["source_map_util.h"],
deps = [
":executable",
"//tensorflow/compiler/xla:status",
- "//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -3036,6 +3290,10 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:ptr_util",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -3067,8 +3325,11 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -3077,11 +3338,15 @@ tf_cc_test(
size = "small",
srcs = ["hlo_parser_test.cc"],
deps = [
+ ":hlo",
+ ":hlo_casting_utils",
+ ":hlo_matchers",
":hlo_parser",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main", # fixdeps: keep
+ "@com_google_absl//absl/strings",
],
)
@@ -3100,6 +3365,8 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
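The long run of BUILD edits above follows one rule: a target whose sources gain an `#include` of an absl header must list the matching `@com_google_absl//...` label in `deps` (alongside the `hlo_scheduling` -> `hlo_memory_scheduler` rename and the new `hlo_schedule` and `maybe_owning_device_memory` targets). An illustrative source-to-dep pairing, with a hypothetical file name:

  // hlo_example.cc -- hypothetical source in one of the targets above.
  #include <cstdint>
  #include <string>

  #include "absl/strings/str_cat.h"  // needs "@com_google_absl//absl/strings"
  #include "absl/types/span.h"       // needs "@com_google_absl//absl/types:span"

  std::string DescribeShape(absl::Span<const int64_t> dims) {
    std::string out = "shape";
    for (int64_t d : dims) absl::StrAppend(&out, "_", d);
    return out;  // e.g. {2, 3} -> "shape_2_3"
  }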
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index f7812d9661..c88a3a3b4b 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -22,13 +22,20 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
@@ -40,8 +47,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -122,6 +127,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
Status HandleImag(HloInstruction* imag) override;
+ Status HandleIota(HloInstruction* instruction) override;
+
Status HandleConvolution(HloInstruction* convolution) override;
Status HandleDivide(HloInstruction* divide) override;
@@ -198,7 +205,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
HloInstruction* zero =
computation_->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(hlo->shape().element_type()).CloneToUnique()));
+ LiteralUtil::Zero(hlo->shape().element_type()).Clone()));
HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape());
return computation_->AddInstruction(HloInstruction::CreateReduce(
@@ -266,7 +273,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot);
StatusOr<HloInstruction*> OptimizeDotOfConcatHelper(
- const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim,
+ const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim,
HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped);
StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot);
@@ -289,6 +296,14 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
return scalar_add_computation_;
}
+ // Tries to fold a kPad in the input or filter into the convolution
+ // instruction's window.
+ StatusOr<bool> FoldConvInputPad(HloInstruction* convolution);
+ StatusOr<bool> FoldConvFilterPad(HloInstruction* convolution);
+
+ // Tries to use a kDot in place of the given convolution.
+ StatusOr<bool> SimplifyConvToDot(HloInstruction* convolution);
+
// Current HloComputation instance the AlgebraicSimplifierVisitor is
// traversing.
HloComputation* computation_;
@@ -305,7 +320,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
// Disable dot strength reduction on platforms where it causes a slowdown.
bool enable_dot_strength_reduction_;
- // Disable convolution simplification on platforms where it causes a slowdown.
+ // Disable convolution -> dot simplification on platforms where it causes a
+ // slowdown.
bool enable_conv_simplification_;
// Cached computation for adding two scalar F32.
@@ -444,8 +460,7 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {
Status AlgebraicSimplifierVisitor::HandleConcatenate(
HloInstruction* concatenate) {
- tensorflow::gtl::ArraySlice<HloInstruction*> operands(
- concatenate->operands());
+ absl::Span<HloInstruction* const> operands(concatenate->operands());
if (operands.size() == 1) {
// Unary concatenates are useless.
ReplaceInstructionIfSameShape(concatenate, operands[0]);
@@ -521,7 +536,7 @@ static HloInstruction* BuildTupleConstant(HloComputation* computation,
return computation->AddInstruction(HloInstruction::CreateTuple(elems));
} else {
return computation->AddInstruction(
- HloInstruction::CreateConstant(literal.CloneToUnique()));
+ HloInstruction::CreateConstant(literal.Clone()));
}
}
@@ -540,7 +555,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
// If a literal is all the same element replace it with a scalar broadcast.
if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
constant->literal().IsAllFirst()) {
- std::unique_ptr<Literal> unique_scalar = MakeUnique<Literal>(
+ Literal unique_scalar(
LiteralUtil::GetFirstScalarLiteral(constant->literal()));
HloInstruction* scalar = computation_->AddInstruction(
HloInstruction::CreateConstant(std::move(unique_scalar)));
@@ -548,6 +563,14 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
constant,
HloInstruction::CreateBroadcast(constant->shape(), scalar, {}));
}
+
+ // If a literal is an increasing sequence from zero, replace it with an iota.
+ if (ShapeUtil::Rank(constant->shape()) == 1 &&
+ ShapeUtil::ElementsIn(constant->shape()) > 1 &&
+ constant->literal().IsR1Iota()) {
+ return ReplaceWithNewInstruction(
+ constant, HloInstruction::CreateIota(constant->shape(), 0));
+ }
return Status::OK();
}
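For reference, the shape of the check the new constant-to-iota rule depends on, modeled standalone over a plain float vector; the real IsR1Iota() lives on Literal and is generic over element types, so this is an assumption-level sketch, not the XLA implementation:

    #include <cstddef>
    #include <vector>

    // True iff values are exactly 0, 1, 2, ... in order. The size() > 1 guard
    // mirrors the ElementsIn(shape) > 1 condition above: a one-element
    // sequence is not worth turning into an iota.
    bool LooksLikeR1Iota(const std::vector<float>& values) {
      for (std::size_t i = 0; i < values.size(); ++i) {
        if (values[i] != static_cast<float>(i)) return false;
      }
      return values.size() > 1;
    }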
@@ -575,7 +598,7 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
namespace {
template <typename T>
Status InvertConstant(const HloInstruction& constant, Literal* result) {
- return result->Populate<T>([&](tensorflow::gtl::ArraySlice<int64> indices) {
+ return result->Populate<T>([&](absl::Span<const int64> indices) {
return T{1.0} / constant.literal().Get<T>(indices);
});
}
@@ -662,7 +685,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
return Status::OK();
}
auto inverse = computation_->AddInstruction(
- HloInstruction::CreateConstant((new_literal.CloneToUnique())));
+ HloInstruction::CreateConstant((new_literal.Clone())));
TF_ASSIGN_OR_RETURN(auto new_divide,
MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
return ReplaceInstruction(divide, new_divide);
@@ -827,18 +850,18 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcat(
TF_ASSIGN_OR_RETURN(
HloInstruction * optimized_lhs_concat,
- OptimizeDotOfConcatHelper(dot->shape(), lhs, lhs_contracting_dim, rhs,
+ OptimizeDotOfConcatHelper(*dot, lhs, lhs_contracting_dim, rhs,
rhs_contracting_dim, /*swapped=*/false));
if (optimized_lhs_concat) {
return optimized_lhs_concat;
}
- return OptimizeDotOfConcatHelper(dot->shape(), rhs, rhs_contracting_dim, lhs,
+ return OptimizeDotOfConcatHelper(*dot, rhs, rhs_contracting_dim, lhs,
lhs_contracting_dim, /*swapped=*/true);
}
StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
- const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim,
+ const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim,
HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped) {
bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate &&
lhs->concatenate_dimension() == lhs_contracting_dim &&
@@ -936,12 +959,13 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
new_dot_rhs = rhs_slice;
}
- auto* new_dot = computation_->AddInstruction(HloInstruction::CreateDot(
- dot_shape, new_dot_lhs, new_dot_rhs, new_dot_dnums));
+ auto* new_dot = computation_->AddInstruction(
+ HloInstruction::CreateDot(dot.shape(), new_dot_lhs, new_dot_rhs,
+ new_dot_dnums, dot.precision_config()));
if (add_result) {
add_result = computation_->AddInstruction(HloInstruction::CreateBinary(
- dot_shape, HloOpcode::kAdd, add_result, new_dot));
+ dot.shape(), HloOpcode::kAdd, add_result, new_dot));
} else {
add_result = new_dot;
}
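This hunk is the first of many below with the same shape: HloInstruction::CreateDot() now takes the precision config, so a rewrite that splits one dot into several must forward the original's config or the per-operand precisions silently reset to defaults. A minimal sketch with stand-in types:

    #include <vector>

    struct PrecisionConfig { std::vector<int> operand_precision; };
    struct DotOp { PrecisionConfig precision_config; };

    // Each replacement dot inherits the original's per-operand precision
    // instead of getting a default-constructed config.
    DotOp MakeReplacementDot(const DotOp& original) {
      DotOp piece;
      piece.precision_config = original.precision_config;
      return piece;
    }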
@@ -1038,8 +1062,9 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
const int n =
right_operand->shape().dimensions(1 - rhs_contracting_dimension);
auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n});
- auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot(
- memoized_shape, left_operand, right_operand, dnums));
+ auto* memoized_inst = computation_->AddInstruction(
+ HloInstruction::CreateDot(memoized_shape, left_operand, right_operand,
+ dnums, dot->precision_config()));
// Get pair {start, 0} or {0, start}.
HloInstruction* original_start_indices =
lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1);
@@ -1135,8 +1160,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
dot_dimension_numbers.add_rhs_contracting_dimensions(0);
auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot(
ShapeUtil::PermuteDimensions({1, 0}, dot->shape()),
- rhs->mutable_operand(0), lhs->mutable_operand(0),
- dot_dimension_numbers));
+ rhs->mutable_operand(0), lhs->mutable_operand(0), dot_dimension_numbers,
+ dot->precision_config()));
return ReplaceWithNewInstruction(
dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0}));
}
@@ -1232,9 +1257,8 @@ namespace {
// return value = {1, 3}
//
// Precondition: input_dim_indices is sorted.
-std::pair<bool, std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
- const HloInstruction* hlo,
- tensorflow::gtl::ArraySlice<int64> input_dim_indices) {
+absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
+ const HloInstruction* hlo, absl::Span<const int64> input_dim_indices) {
CHECK_EQ(HloOpcode::kReshape, hlo->opcode());
CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end()));
@@ -1252,11 +1276,11 @@ std::pair<bool, std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
}
if (i >= unmodified_dims.size() ||
unmodified_dims[i].first != input_dim_index) {
- return std::make_pair(false, std::vector<int64>());
+ return absl::nullopt;
}
output_dim_indices.push_back(unmodified_dims[i].second);
}
- return std::make_pair(true, output_dim_indices);
+ return output_dim_indices;
}
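The signature change above swaps the pair<bool, vector> convention for an optional, so "no mapping exists" can no longer be confused with "mapping exists and is empty". A sketch of the call-site difference, using C++17 std::optional as a stand-in for absl::optional:

    #include <optional>
    #include <vector>

    std::optional<std::vector<long long>> TryMapDims(bool mappable) {
      if (!mappable) return std::nullopt;   // was: {false, {}}
      return std::vector<long long>{1, 3};  // was: {true, {1, 3}}
    }

    int CountMappedDims(bool mappable) {
      auto dims = TryMapDims(mappable);
      if (!dims.has_value()) return -1;  // was: if (!dims.first)
      return static_cast<int>(dims->size());
    }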
// Returns true if the output of "instruction" is a permutation of the
@@ -1385,6 +1409,15 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
return Status::OK();
}
+ // broadcast(iota) -> iota.
+ if (operand->opcode() == HloOpcode::kIota) {
+ return ReplaceWithNewInstruction(
+ broadcast,
+ HloInstruction::CreateIota(
+ broadcast->shape(),
+ dims[Cast<HloIotaInstruction>(operand)->iota_dimension()]));
+ }
+
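The broadcast(iota) rule rests on one fact: a broadcast's dimensions() maps each operand dimension to its position in the output, so the fused iota's dimension is just dims[iota_dimension]. A standalone model of that remapping (plain integers, not HLO types):

    #include <vector>

    // dims[i] is the output dimension that operand dimension i maps to,
    // exactly the role HloInstruction::dimensions() plays for a broadcast.
    long long FusedIotaDimension(const std::vector<long long>& dims,
                                 long long operand_iota_dim) {
      return dims[operand_iota_dim];
    }

    // Example (matches the MergeBroadcastAndIota test far below):
    // iota([2,2], dim=1) broadcast to [2,2,2] with dims={0,2}
    //   -> iota([2,2,2], dim=FusedIotaDimension({0,2}, 1)) = dim 2.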
// Merge two consecutive broadcasts into a single one.
if (operand->opcode() == HloOpcode::kBroadcast) {
std::vector<int64> new_dimensions;
@@ -1439,6 +1472,19 @@ Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) {
return Status::OK();
}
+Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) {
+ // iota -> zero if the iota dimension never produces an element other than
+ // zero.
+ auto* iota = Cast<HloIotaInstruction>(instruction);
+ if (iota->shape().dimensions(iota->iota_dimension()) <= 1) {
+ auto zero = computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(iota->shape().element_type()).Clone()));
+ return ReplaceWithNewInstruction(
+ iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {}));
+ }
+ return Status::OK();
+}
+
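The HandleIota guard is easiest to see in one dimension: an iota writes value i at index i along its iota dimension, so with a bound of 0 or 1 the only value ever written is 0 and the instruction is a broadcast of zero in disguise. A 1-D model:

    #include <vector>

    // iota(n) = {0, 1, ..., n-1}; for n <= 1 this is all zeros, which is what
    // lets the pass replace the iota with broadcast(0). For n > 1 the ramp is
    // real and the iota must stay.
    std::vector<int> Iota1D(int n) {
      std::vector<int> v(n);
      for (int i = 0; i < n; ++i) v[i] = i;
      return v;
    }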
Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) {
return ReplaceWithNewInstruction(
@@ -1535,7 +1581,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
if (IsAll(rhs, 0)) {
auto one = HloInstruction::CreateConstant(
- LiteralUtil::One(power->shape().element_type()).CloneToUnique());
+ LiteralUtil::One(power->shape().element_type()).Clone());
std::unique_ptr<HloInstruction> ones;
if (ShapeUtil::IsScalar(power->shape())) {
ones = std::move(one);
@@ -1570,7 +1616,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
if (IsAll(rhs, -1)) {
auto* one = computation_->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::One(rhs->shape().element_type()).CloneToUnique()));
+ LiteralUtil::One(rhs->shape().element_type()).Clone()));
// Explicitly broadcast scalar 1 to the output shape, to avoid implicit
// broadcast in divide HLO as we are trying to eliminate implicit
@@ -1713,12 +1759,25 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) {
auto opt_dims = ReshapeLeavesDimensionsUnmodified(
reshape, reshape->operand(0)->dimensions());
- if (opt_dims.first) {
+ if (opt_dims.has_value()) {
return ReplaceWithNewInstruction(
reshape,
HloInstruction::CreateBroadcast(
reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0),
- opt_dims.second));
+ *opt_dims));
+ }
+ }
+
+ // reshape(iota) -> iota.
+ if (operand->opcode() == HloOpcode::kIota) {
+ auto* iota = Cast<HloIotaInstruction>(operand);
+ auto opt_dims =
+ ReshapeLeavesDimensionsUnmodified(reshape, {iota->iota_dimension()});
+ if (opt_dims.has_value()) {
+ CHECK_EQ(opt_dims->size(), 1);
+ return ReplaceWithNewInstruction(
+ reshape,
+ HloInstruction::CreateIota(reshape->shape(), opt_dims->front()));
}
}
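Worked instances of the reshape(iota) rule, matching the IotaAndReshape_* tests added further down in this diff:

    // iota([3,2,4], dim=2) reshaped to [6,1,1,4]: dims [3,2] collapse into 6,
    // but dim 2 (bound 4) passes through unmodified as output dim 3, so the
    // whole expression folds to iota([6,1,1,4], dim=3).
    //
    // iota([3,2], dim=1) reshaped to [6]: the iota dimension is mixed into
    // the collapsed dimension, ReshapeLeavesDimensionsUnmodified returns
    // nullopt, and the reshape is left alone.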
@@ -1752,8 +1811,8 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
}
auto is_unstrided_slice = [](const HloInstruction* hlo) {
- return c_all_of(hlo->slice_strides(),
- [](int64 stride) { return stride == 1; });
+ return absl::c_all_of(hlo->slice_strides(),
+ [](int64 stride) { return stride == 1; });
};
if (slice->operand(0)->opcode() == HloOpcode::kSlice &&
is_unstrided_slice(slice) && is_unstrided_slice(slice->operand(0))) {
@@ -1815,7 +1874,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
auto arg = reduce->mutable_operand(0);
auto init_value = reduce->mutable_operand(1);
- tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+ absl::Span<const int64> dimensions(reduce->dimensions());
HloComputation* function = reduce->to_apply();
if (ShapeUtil::IsZeroElementArray(arg->shape()) ||
ShapeUtil::IsZeroElementArray(reduce->shape())) {
@@ -1930,7 +1989,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
// This should make fusion easier or use less memory bandwidth in the unfused
// case.
if (arg->opcode() == HloOpcode::kConcatenate &&
- c_linear_search(reduce->dimensions(), arg->concatenate_dimension())) {
+ absl::c_linear_search(reduce->dimensions(),
+ arg->concatenate_dimension())) {
HloInstruction* old_reduce = nullptr;
for (HloInstruction* operand : arg->operands()) {
HloInstruction* new_reduce = computation_->AddInstruction(
@@ -1983,9 +2043,9 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
VLOG(10) << "Considering folding Pad: " << pad->ToString()
<< "\ninto reduce-window: " << reduce_window->ToString()
- << (convert != nullptr ? tensorflow::strings::StrCat(
- "\nvia convert: ", convert->ToString())
- : "");
+ << (convert != nullptr
+ ? absl::StrCat("\nvia convert: ", convert->ToString())
+ : "");
// Do not fold interior padding into ReduceWindow since the backends do not
// support it.
@@ -2011,7 +2071,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
if (!converted_pad_literal.ok()) {
return false;
}
- return *converted_pad_literal.ValueOrDie() == reduce_init_literal;
+ return converted_pad_literal.ValueOrDie() == reduce_init_literal;
};
// The pad value is usually a constant, so we handle that case and do not
// try to get more fancy about proving equivalence in cases beyond that.
@@ -2161,40 +2221,157 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
return Status::OK();
}
-Status AlgebraicSimplifierVisitor::HandleConvolution(
+StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvInputPad(
HloInstruction* convolution) {
- auto lhs = convolution->mutable_operand(0);
- auto rhs = convolution->mutable_operand(1);
- if (ShapeUtil::IsZeroElementArray(lhs->shape()) ||
- ShapeUtil::IsZeroElementArray(rhs->shape())) {
- return ReplaceWithNewInstruction(
- convolution,
- HloInstruction::CreateBroadcast(
- convolution->shape(),
- computation_->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(convolution->shape().element_type())
- .CloneToUnique())),
- {}));
+ auto* lhs = convolution->mutable_operand(0);
+ auto* rhs = convolution->mutable_operand(1);
+ const auto& window = convolution->window();
+ const ConvolutionDimensionNumbers& dnums =
+ convolution->convolution_dimension_numbers();
+
+ if (lhs->opcode() != HloOpcode::kPad) {
+ return false;
+ }
+
+ // Convolution's padding is always zero, so bail if the kPad is adding
+ // something other than zero.
+ if (!IsAll(lhs->operand(1), 0)) {
+ return false;
+ }
+
+ const auto& padding = lhs->padding_config();
+
+ // Can't pad batch or feature dims.
+ for (int64 dim :
+ {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) {
+ const auto& p = padding.dimensions(dim);
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0) {
+ return false;
+ }
+ }
+
+ // Compute the window which is the result of merging the kPad and the
+ // convolution's existing window.
+ Window new_window = window;
+ for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) {
+ auto& w = *new_window.mutable_dimensions(dim);
+ const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim));
+ // Edge padding composes with itself in the straightforward way, but
+ // composing interior padding is nontrivial, and we cowardly refuse to
+ // think about it. If we see interior padding in either the kPad or conv,
+ // bail if there's any sort of padding in the other.
+ if (p.interior_padding() != 0 &&
+ (w.padding_low() != 0 || w.padding_high() != 0 ||
+ w.base_dilation() != 1)) {
+ return false;
+ }
+ if (w.base_dilation() != 1 &&
+ (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0)) {
+ return false;
+ }
+
+ w.set_padding_low(w.padding_low() + p.edge_padding_low());
+ w.set_padding_high(w.padding_high() + p.edge_padding_high());
+ if (p.interior_padding() != 0) {
+ CHECK_EQ(w.base_dilation(), 1);
+ w.set_base_dilation(1 + p.interior_padding());
+ }
+ }
+
+ auto new_conv = convolution->CloneWithNewOperands(
+ convolution->shape(), {lhs->mutable_operand(0), rhs});
+ new_conv->set_window(new_window);
+ TF_RETURN_IF_ERROR(
+ ReplaceWithNewInstruction(convolution, std::move(new_conv)));
+ return true;
+}
+
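The per-dimension merge above is small enough to state standalone: edge padding adds into the window's padding, and interior padding k becomes base dilation k+1, with the two bail-outs keeping the features from compounding. A sketch with stand-in types (not the Window/PaddingConfig protos):

    struct WindowDim { long long pad_lo = 0, pad_hi = 0, base_dilation = 1; };
    struct PadDim    { long long edge_lo = 0, edge_hi = 0, interior = 0; };

    // Returns false for the combinations the pass refuses to reason about:
    // interior padding meeting existing padding or dilation on the same dim.
    bool MergePadIntoWindowDim(const PadDim& p, WindowDim* w) {
      if (p.interior != 0 &&
          (w->pad_lo != 0 || w->pad_hi != 0 || w->base_dilation != 1)) {
        return false;
      }
      if (w->base_dilation != 1 &&
          (p.edge_lo != 0 || p.edge_hi != 0 || p.interior != 0)) {
        return false;
      }
      w->pad_lo += p.edge_lo;
      w->pad_hi += p.edge_hi;
      // e.g. interior padding 3 on a clean dim -> base dilation 4, matching
      // the "lhs_dilate=4x7" testcase in the new ConvInputPaddingTest.
      if (p.interior != 0) w->base_dilation = 1 + p.interior;
      return true;
    }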
+StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvFilterPad(
+ HloInstruction* convolution) {
+ auto* lhs = convolution->mutable_operand(0);
+ auto* rhs = convolution->mutable_operand(1);
+ const ConvolutionDimensionNumbers& dnums =
+ convolution->convolution_dimension_numbers();
+
+ if (rhs->opcode() != HloOpcode::kPad) {
+ return false;
+ }
+
+ // Convolution's padding is always zero, so bail if the kPad is adding
+ // something other than zero.
+ if (!IsAll(rhs->operand(1), 0)) {
+ return false;
}
+
+ const auto& padding = rhs->padding_config();
+
+ // Can't pad or dilate feature dims.
+ for (int64 dim : {dnums.kernel_input_feature_dimension(),
+ dnums.kernel_output_feature_dimension()}) {
+ const auto& p = padding.dimensions(dim);
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0) {
+ return false;
+ }
+ }
+
+ // Compute the window which is the result of merging the kPad and the
+ // convolution's existing window.
+ Window new_window = convolution->window();
+ for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) {
+ auto& w = *new_window.mutable_dimensions(dim);
+ const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim));
+
+ // We can only do this transformation if p adds dilation to the filter --
+ // edge padding on the filter is not supported in conv.
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) {
+ return false;
+ }
+
+ // Nothing to do if the kPad for this dim is entirely a nop.
+ if (p.interior_padding() == 0) {
+ continue;
+ }
+
+ // We cowardly refuse to think about how dilation composes with itself;
+ // bail if both the kPad and conv have dilation on this dimension.
+ if (w.window_dilation() > 1) {
+ return false;
+ }
+ CHECK_EQ(w.window_dilation(), 1);
+ w.set_window_dilation(1 + p.interior_padding());
+ w.set_size(rhs->operand(0)->shape().dimensions(
+ dnums.kernel_spatial_dimensions(dim)));
+ }
+
+ auto new_conv = convolution->CloneWithNewOperands(
+ convolution->shape(), {lhs, rhs->mutable_operand(0)});
+ new_conv->set_window(new_window);
+ TF_RETURN_IF_ERROR(
+ ReplaceWithNewInstruction(convolution, std::move(new_conv)));
+ return true;
+}
+
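The filter-side merge differs in two ways worth pinning down: edge padding on the filter has no convolution equivalent at all, and interior padding k becomes window (rhs) dilation k+1, after which the window size must be re-measured on the unpadded filter. A sketch with stand-in types:

    struct FilterWindowDim { long long size = 0, window_dilation = 1; };
    struct FilterPadDim { long long edge_lo = 0, edge_hi = 0, interior = 0; };

    bool MergeFilterPadIntoWindowDim(const FilterPadDim& p,
                                     long long unpadded_filter_bound,
                                     FilterWindowDim* w) {
      // Edge padding on the filter is not expressible in a conv window.
      if (p.edge_lo != 0 || p.edge_hi != 0) return false;
      if (p.interior == 0) return true;  // kPad is a nop on this dim.
      // Refuse to compose dilation with dilation.
      if (w->window_dilation > 1) return false;
      w->window_dilation = 1 + p.interior;
      w->size = unpadded_filter_bound;  // measured on the unpadded filter
      return true;
    }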
+StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
+ HloInstruction* convolution) {
+ auto* lhs = convolution->mutable_operand(0);
+ auto* rhs = convolution->mutable_operand(1);
const auto& window = convolution->window();
+ const ConvolutionDimensionNumbers& dnums =
+ convolution->convolution_dimension_numbers();
+
if (!enable_conv_simplification_) {
- return Status::OK();
+ return false;
}
- // HandleConvolution tries to replace a convolution with a DOT instruction.
- //
- // Only add when bitcasts can be used:
- // - if bitcasts are not supported, then reshapes could be used but will
- // end up with another copy.
- // - if bitcasts are supported, the simplifier will be called again with
- // bitcasts_ == true.
- // TODO(cwhipkey): b/31337498, make this layout insensitive.
+ // TODO(b/31337498): For now, we cowardly refuse to do this optimization in
+ // layout-insensitive mode, for fear of adding nontrivial reshapes.
if (!is_layout_sensitive_) {
- return Status::OK();
+ return false;
}
- const ConvolutionDimensionNumbers& dnums =
- convolution->convolution_dimension_numbers();
const Shape& input_shape = lhs->shape();
const Shape& filter_shape = rhs->shape();
const Shape& convolution_shape = convolution->shape();
@@ -2205,7 +2382,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
// Require the spatial dimensions in the kernel to have a bound of one.
for (int64 i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) {
- return Status::OK();
+ return false;
}
}
@@ -2216,7 +2393,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
// for a 1x1 window, so window dilation is no problem.
if (window_util::HasStride(window) || window_util::HasPadding(window) ||
window_util::HasBaseDilation(window)) {
- return Status::OK();
+ return false;
}
// Also, the shapes must align for a rowmajor matmul:
@@ -2242,7 +2419,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
dnums.kernel_input_feature_dimension()) <
PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
dnums.kernel_output_feature_dimension()))) {
- return Status::OK();
+ return false;
}
auto add_bitcast = [&](Shape shape, HloInstruction* operand) {
@@ -2284,7 +2461,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
if (!valid_bitcast_callback_(input_shape, new_input_shape) ||
!valid_bitcast_callback_(filter_shape, new_filter_shape) ||
!valid_bitcast_callback_(dot_output_shape, convolution_shape)) {
- return Status::OK();
+ return false;
}
auto new_lhs = add_bitcast(new_input_shape, lhs);
@@ -2293,8 +2470,47 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
dot_dimension_numbers.add_lhs_contracting_dimensions(1);
dot_dimension_numbers.add_rhs_contracting_dimensions(0);
auto dot = computation_->AddInstruction(HloInstruction::CreateDot(
- dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers));
- return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot));
+ dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers,
+ convolution->precision_config()));
+
+ TF_RETURN_IF_ERROR(
+ ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)));
+ return true;
+}
+
+Status AlgebraicSimplifierVisitor::HandleConvolution(
+ HloInstruction* convolution) {
+ // Zero-sized input or filter.
+ if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) ||
+ ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) {
+ return ReplaceWithNewInstruction(
+ convolution,
+ HloInstruction::CreateBroadcast(
+ convolution->shape(),
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(convolution->shape().element_type()))),
+ {}));
+ }
+
+ // Try to merge padding/dilation of the input with the convolution's window.
+ TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution));
+ if (folded_input_pad) {
+ return Status::OK();
+ }
+
+ // Try to merge dilation of the filter with the convolution's window.
+ TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution));
+ if (folded_filter_pad) {
+ return Status::OK();
+ }
+
+ // Try to replace the convolution with a kDot instruction.
+ TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution));
+ if (replaced_with_dot) {
+ return Status::OK();
+ }
+
+ return Status::OK();
}
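The rewritten handler is a chain of StatusOr<bool> simplifiers: each helper either rewrites the convolution and reports true, or declines with false and the next one gets a try. The dispatch shape, abstracted away from HLO and error plumbing (hypothetical names, sketch only):

    #include <functional>
    #include <vector>

    // Runs candidate rewrites in order; the first one that fires wins.
    // TF_ASSIGN_OR_RETURN-style error propagation is elided here.
    bool RunFirstApplicable(const std::vector<std::function<bool()>>& rewrites) {
      for (const auto& rewrite : rewrites) {
        if (rewrite()) return true;
      }
      return false;
    }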
bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape(
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h
index c48196e861..b864c372fa 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.h
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h
@@ -47,7 +47,7 @@ class AlgebraicSimplifier : public HloPassInterface {
enable_dot_strength_reduction_(enable_dot_strength_reduction),
enable_conv_simplification_(enable_conv_simplification) {}
~AlgebraicSimplifier() override = default;
- tensorflow::StringPiece name() const override { return "algsimp"; }
+ absl::string_view name() const override { return "algsimp"; }
// Run algebraic simplification on the given computation. Returns whether the
// computation was changed.
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 5837391d75..3fc1ba2427 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -18,11 +18,15 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
@@ -34,13 +38,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-
-using ::testing::ElementsAre;
namespace xla {
namespace {
+using ::testing::ElementsAre;
+
namespace op = xla::testing::opcode_matchers;
AlgebraicSimplifier::ValidBitcastCallback bitcasting_callback() {
@@ -290,6 +293,21 @@ TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) {
EXPECT_THAT(root, op::Constant());
}
+TEST_F(AlgebraicSimplifierTest, ConstantToIota) {
+ HloComputation::Builder builder(TestName());
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f})));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_THAT(root, op::Constant());
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_THAT(root, op::Iota());
+}
+
// Test that A - 0 is simplified to A
TEST_F(AlgebraicSimplifierTest, SubZero) {
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
@@ -513,7 +531,7 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
HloInstruction::CreateParameter(0, r1f32, "param0"));
HloInstruction* constant =
builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::CreateR1<float>({0.f, 1.f, 2.f})));
+ LiteralUtil::CreateR1<float>({1.f, 2.f, 3.f})));
builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide,
param0, constant));
@@ -1026,7 +1044,8 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) {
dim->set_window_reversal(false);
   // Create convolution.
builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums));
+ ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(builder.Build());
HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
@@ -1820,6 +1839,126 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) {
op::Reshape(op::Broadcast(param)));
}
+TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(HloInstruction::CreateIota(
+ ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), 2));
+ Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2});
+ builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Iota());
+ EXPECT_TRUE(
+ ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape));
+}
+
+TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(
+ HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {1, 1}), 0));
+ auto result_shape = iota->shape();
+
+ auto computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Iota());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+
+ auto root = computation->root_instruction();
+ EXPECT_THAT(root, op::Broadcast(op::Constant()));
+ EXPECT_EQ(0.0f, root->operand(0)->literal().GetFirstElement<float>());
+ EXPECT_TRUE(
+ ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape));
+}
+
+TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2_6) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(
+ HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2}), 1));
+ builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), iota));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+}
+
+TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(
+ HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4}), 2));
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), iota));
+
+ HloComputation* computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Iota());
+ EXPECT_EQ(Cast<HloIotaInstruction>(computation->root_instruction())
+ ->iota_dimension(),
+ 3);
+}
+
+TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(
+ HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 2}), 2));
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {6, 1, 1, 2}), iota));
+
+ HloComputation* computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Iota());
+ const int64 iota_dim =
+ Cast<HloIotaInstruction>(computation->root_instruction())
+ ->iota_dimension();
+ EXPECT_THAT(iota_dim, ::testing::AnyOf(1, 2, 3));
+}
+
+TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) {
+ HloComputation::Builder builder(TestName());
+ auto iota = builder.AddInstruction(
+ HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), 2));
+ builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6, 8}), iota));
+
+ HloComputation* computation = module().AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota()));
+}
+
TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
HloComputation::Builder builder(TestName());
HloInstruction* param =
@@ -2006,6 +2145,269 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) {
EXPECT_THAT(computation->root_instruction(), op::Tuple(keys, values));
}
+// Used for TEST_Ps that test merging (or not) of a kPad instruction into a
+// convolution's Window.
+struct ConvPaddingTestcase {
+ ConvPaddingTestcase(absl::string_view padding,
+ absl::string_view orig_conv_window,
+ absl::string_view expected_conv_window)
+ : ConvPaddingTestcase(padding, orig_conv_window, expected_conv_window,
+ /*pad_value=*/0) {}
+
+ ConvPaddingTestcase(absl::string_view padding,
+ absl::string_view orig_conv_window,
+ absl::string_view expected_conv_window, float pad_value)
+ : padding(padding),
+ orig_conv_window(orig_conv_window),
+ expected_conv_window(expected_conv_window),
+ pad_value(pad_value) {}
+
+ string ToString() const {
+ return absl::StrFormat(
+ "padding=%s, orig_conv_window=%s, expected_conv_window=%s, "
+ "pad_value=%f",
+ padding, orig_conv_window, expected_conv_window, pad_value);
+ }
+
+ string padding;
+ string orig_conv_window;
+ string expected_conv_window;
+ float pad_value;
+};
+
+// ConvInputPaddingTest (and its one associated TEST_P testcase) checks that a
+// computation that does
+//
+// conv(pad(param0, padding=padding), param1), window=orig_conv_window
+//
+// gets transformed by AlgebraicSimplifier to
+//
+// conv(param0, param1), window=expected_conv_window
+//
+// or, if expected_conv_window is the empty string, checks that
+// AlgebraicSimplifier does *not* transform the original convolution.
+class ConvInputPaddingTest
+ : public AlgebraicSimplifierTest,
+ public ::testing::WithParamInterface<ConvPaddingTestcase> {};
+
+INSTANTIATE_TEST_CASE_P(
+ ConvInputPaddingTestCases, ConvInputPaddingTest,
+ ::testing::ValuesIn(std::vector<ConvPaddingTestcase>{
+ // Merge this edge padding into the conv.
+ {"0_0x0_0x1_1x2_2", "", "pad=1_1x2_2"},
+ // Merge this edge padding with the conv's edge padding.
+ {"0_0x0_0x1_2x3_4", "pad=10_10x20_20", "pad=11_12x23_24"},
+ // Merge this interior-padded kPad with the unpadded conv. The 3x6
+ // interior padding gets transformed to 4x7 conv lhs dilation.
+ {"0_0x0_0x1_2_3x4_5_6", "", "pad=1_2x4_5 lhs_dilate=4x7"},
+ // kPad has dilation on one dim, conv has it on the other; merge them.
+ {"0_0x0_0x0_0_1x0_0_0", "lhs_dilate=1x10", "lhs_dilate=2x10"},
+ // kPad has dilation and edge padding on one dim, conv has them on the
+ // other; merge them.
+ {"0_0x0_0x0_1_1x0_0_0", "pad=0_0x3_0 lhs_dilate=1x10",
+ "pad=0_1x3_0 lhs_dilate=2x10"},
+
+ // Don't transform if the pad value is nonzero.
+ {"0_0x0_0x1_1x2_2", "", "", /*pad_value=*/1},
+
+ // We refuse to transform the following because on some dimension, one
+ // of the kPad and conv has dilation and the other has some sort of
+ // padding.
+ {"0_0x0_0x0_0_1x0_0", "pad=1_0x0_0", ""},
+ {"0_0x0_0x0_0_1x0_0", "pad=0_1x0_0", ""},
+ {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""},
+ {"0_0x0_0x1_0_0x0_0", "lhs_dilate=2x1", ""},
+ {"0_0x0_0x0_1_0x0_0", "lhs_dilate=2x1", ""},
+ {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""},
+
+ // We can't merge feature or batch padding into the conv.
+ {"1_0x0_0x0_0x0_0", "", ""},
+ {"0_0x1_0x0_0x0_0", "", ""},
+ }));
+
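The testcase strings use the HLO text syntax, on the assumption that padding dims are 'x'-separated and each dim is 'low_high' or 'low_high_interior' (windows likewise use 'pad=', 'lhs_dilate=', 'rhs_dilate='). A tiny standalone parser for the padding triples, to make the table above readable; it illustrates the format only and is not the real ParsePaddingConfig:

    #include <sstream>
    #include <string>
    #include <vector>

    struct PadTriple { long long lo = 0, hi = 0, interior = 0; };

    // "0_0x1_2_3" -> {{0,0,0}, {1,2,3}}. Assumes well-formed input.
    std::vector<PadTriple> ParsePadString(const std::string& s) {
      std::vector<PadTriple> dims;
      std::stringstream per_dim(s);
      std::string dim;
      while (std::getline(per_dim, dim, 'x')) {
        std::stringstream fields(dim);
        std::string field;
        std::vector<long long> v;
        while (std::getline(fields, field, '_')) v.push_back(std::stoll(field));
        PadTriple p;
        p.lo = v[0];
        p.hi = v[1];
        if (v.size() > 2) p.interior = v[2];
        dims.push_back(p);
      }
      return dims;
    }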
+TEST_P(ConvInputPaddingTest, DoTest) {
+ ConvPaddingTestcase testcase = GetParam();
+
+ // It would be better to put the testcase's ToString into the test name, but
+ // gUnit has constraints on what can go into test names, and any reasonable
+ // implementation of ToString() seems to violate them.
+ SCOPED_TRACE(testcase.ToString());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto* input = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1024, 128, 100, 100}), // bf01
+ "input"));
+ auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0(testcase.pad_value)));
+
+ PaddingConfig padding_config =
+ ParsePaddingConfig(testcase.padding).ValueOrDie();
+ auto* lhs_pad = builder.AddInstruction(HloInstruction::CreatePad(
+ ShapeInference::InferPadShape(input->shape(), pad_value->shape(),
+ padding_config)
+ .ValueOrDie(),
+ input, pad_value, padding_config));
+
+ auto* filter = builder.AddInstruction(HloInstruction::CreateParameter(
+ 1,
+ ShapeUtil::MakeShape(
+ F32, {lhs_pad->shape().dimensions(1), 256, 3, 3}), // io01
+ "input"));
+
+ ConvolutionDimensionNumbers dnums =
+ ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie();
+ Window window =
+ ParseWindow(absl::StrCat("size=3x3 ", testcase.orig_conv_window))
+ .ValueOrDie();
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(),
+ /*feature_group_count=*/1, window,
+ dnums)
+ .ValueOrDie(),
+ lhs_pad, filter, /*feature_group_count=*/1, window, dnums,
+ DefaultPrecisionConfig(2)));
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ if (testcase.expected_conv_window.empty()) {
+ ASSERT_FALSE(simplifier.Run(module).ValueOrDie());
+ } else {
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
+ auto* conv = module->entry_computation()->root_instruction();
+ SCOPED_TRACE(module->ToString());
+ ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter()));
+ EXPECT_EQ(window_util::ToString(conv->window()),
+ absl::StrCat("size=3x3 ", testcase.expected_conv_window));
+ }
+}
+
+// ConvFilterPaddingTest (and its one associated TEST_P) checks that a
+// computation that does
+//
+// conv(param0, pad(param1, padding=padding)), window=orig_conv_window
+//
+// gets transformed by AlgebraicSimplifier to
+//
+// conv(param0, param1), window=expected_conv_window
+//
+// or, if expected_conv_window is the empty string, checks that
+// AlgebraicSimplifier does *not* transform the original convolution.
+class ConvFilterPaddingTest
+ : public AlgebraicSimplifierTest,
+ public ::testing::WithParamInterface<ConvPaddingTestcase> {};
+
+INSTANTIATE_TEST_CASE_P(
+ ConvFilterPaddingTestCases, ConvFilterPaddingTest,
+ ::testing::ValuesIn(std::vector<ConvPaddingTestcase>{
+        // Can only merge interior padding on the filter's spatial dimensions;
+        // all other paddings (edge padding and interior padding on the channel
+        // dims) should be rejected out of hand.
+ {"1_0_0x0_0_0x0_0x0_0", "", ""},
+ {"0_1_0x0_0_0x0_0x0_0", "", ""},
+ {"0_0_1x0_0_0x0_0x0_0", "", ""},
+ {"0_0_0x1_0_0x0_0x0_0", "", ""},
+ {"0_0_0x0_1_0x0_0x0_0", "", ""},
+ {"0_0_0x0_0_1x0_0x0_0", "", ""},
+ {"0_0_0x0_0_0x1_0x0_0", "", ""},
+ {"0_0_0x0_0_0x0_1x0_0", "", ""},
+ {"0_0_0x0_0_0x0_0x1_0", "", ""},
+ {"0_0_0x0_0_0x0_0x0_1", "", ""},
+
+        // Interior padding on the spatial dims can be merged into the conv, so
+        // long as the conv and pad don't have interior padding on the same dim.
+ {"0_0x0_0x0_0_5x0_0", "", "rhs_dilate=6x1"},
+ {"0_0x0_0x0_0x0_0_10", "", "rhs_dilate=1x11"},
+ {"0_0x0_0x0_0_10x0_0_100", "", "rhs_dilate=11x101"},
+ {"0_0x0_0x0_0_1x0_0", "rhs_dilate=1x10", "rhs_dilate=2x10"},
+ {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x1", "rhs_dilate=10x6"},
+
+ // Can't merge if for a given dim there's interior padding on both the
+ // pad and conv.
+ {"0_0x0_0x0_0_1x0_0", "rhs_dilate=2x10", ""},
+ {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x2", ""},
+
+ // Don't transform if the pad value is nonzero.
+ {"0_0x0_0x0_0_5x0_0", "", "", /*pad_value=*/1},
+ }));
+
+TEST_P(ConvFilterPaddingTest, DoIt) {
+ ConvPaddingTestcase testcase = GetParam();
+
+ // It would be better to put the testcase's ToString into the test name, but
+ // gUnit has constraints on what can go into test names, and any reasonable
+ // implementation of ToString() seems to violate them.
+ SCOPED_TRACE(testcase.ToString());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0(testcase.pad_value)));
+ auto* filter = builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {128, 256, 3, 3}), // io01
+ "input"));
+ PaddingConfig padding_config =
+ ParsePaddingConfig(testcase.padding).ValueOrDie();
+ auto* rhs_pad = builder.AddInstruction(HloInstruction::CreatePad(
+ ShapeInference::InferPadShape(filter->shape(), pad_value->shape(),
+ padding_config)
+ .ValueOrDie(),
+ filter, pad_value, padding_config));
+
+ auto* input = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0,
+ ShapeUtil::MakeShape(
+ F32, {1024, rhs_pad->shape().dimensions(0), 100, 100}), // bf01
+ "input"));
+
+ ConvolutionDimensionNumbers dnums =
+ ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie();
+ Window window = ParseWindow(absl::StrFormat("size=%dx%d %s",
+ rhs_pad->shape().dimensions(2),
+ rhs_pad->shape().dimensions(3),
+ testcase.orig_conv_window))
+ .ValueOrDie();
+
+ // Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place
+ // after the transformation.
+ PrecisionConfig precision_config;
+ precision_config.add_operand_precision(PrecisionConfig::HIGH);
+ precision_config.add_operand_precision(PrecisionConfig::HIGHEST);
+
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(),
+ /*feature_group_count=*/1, window,
+ dnums)
+ .ValueOrDie(),
+ input, rhs_pad, /*feature_group_count=*/1, window, dnums,
+ precision_config));
+
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ if (testcase.expected_conv_window.empty()) {
+ ASSERT_FALSE(simplifier.Run(module).ValueOrDie());
+ } else {
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
+ auto* conv = module->entry_computation()->root_instruction();
+ SCOPED_TRACE(module->ToString());
+ ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter()));
+ EXPECT_EQ(window_util::ToString(conv->window()),
+ absl::StrFormat("size=%dx%d %s",
+ conv->operand(1)->shape().dimensions(2),
+ conv->operand(1)->shape().dimensions(3),
+ testcase.expected_conv_window));
+ EXPECT_THAT(Cast<HloConvolutionInstruction>(conv)
+ ->precision_config()
+ .operand_precision(),
+ ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::HIGHEST));
+ }
+}
+
TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
struct ConvTestOptions {
int in_batch = 10;
@@ -2037,7 +2439,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
// Builds a convolution from <options> and runs algebraic simplification on
// the computation. Returns a string description of the result of
// simplification.
- auto build_and_simplify = [&options]() -> string {
+ auto build_and_simplify = [&]() -> string {
HloComputation::Builder b(TestName());
Window window;
@@ -2109,7 +2511,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
auto out_dims = in_dims;
out_dims[in_channel_idx] = options.f_output_channels;
- auto make_shape = [](tensorflow::gtl::ArraySlice<int64> dims,
+ auto make_shape = [](absl::Span<const int64> dims,
bool minor_to_major_layout) {
if (minor_to_major_layout) {
return ShapeUtil::MakeShapeWithLayout(F32, dims, {0, 1, 2, 3});
@@ -2126,8 +2528,9 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
HloInstruction* filter =
b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter"));
- b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter,
- window, dnums));
+ b.AddInstruction(HloInstruction::CreateConvolve(
+ out_shape, input, filter,
+ /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
// TODO(b/80488902): verify this module.
auto module = HloTestBase::CreateNewModule();
@@ -2143,9 +2546,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
root->operand(0)->opcode() == HloOpcode::kDot) {
auto lhs_shape = root->operand(0)->operand(0)->shape();
auto rhs_shape = root->operand(0)->operand(1)->shape();
- return tensorflow::strings::StrCat(
- tensorflow::str_util::Join(lhs_shape.dimensions(), "x"), " DOT ",
- tensorflow::str_util::Join(rhs_shape.dimensions(), "x"));
+ return absl::StrCat(absl::StrJoin(lhs_shape.dimensions(), "x"), " DOT ",
+ absl::StrJoin(rhs_shape.dimensions(), "x"));
}
return "UNEXPECTED CHANGE";
};
@@ -2506,7 +2908,8 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums,
+ DefaultPrecisionConfig(2)));
std::unique_ptr<HloComputation> dot_computation(builder.Build());
HloComputation::Builder call_builder(TestName() + ".Call");
@@ -2529,9 +2932,9 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) {
HloComputation::Builder builder(TestName());
const float constant_scalar = 7.3f;
std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
- std::unique_ptr<Literal> value = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(constant_scalar).get(),
- LiteralUtil::CreateR1<float>(constant_vector).get()});
+ Literal elements[] = {LiteralUtil::CreateR0<float>(constant_scalar),
+ LiteralUtil::CreateR1<float>(constant_vector)};
+ Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
builder.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
auto computation = module().AddEntryComputation(builder.Build());
@@ -2648,6 +3051,47 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) {
EXPECT_THAT(root->dimensions(), ElementsAre(1, 3));
}
+// Test that a broadcast of an iota can be merged to one iota.
+TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) {
+ HloComputation::Builder builder(TestName());
+ Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
+ HloInstruction* iota =
+ builder.AddInstruction(HloInstruction::CreateIota(r2f32, 1));
+ Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2});
+ builder.AddInstruction(HloInstruction::CreateBroadcast(r3f32, iota, {0, 2}));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_THAT(root, op::Iota());
+ EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2);
+}
+
+// Test that a broadcast of an iota can be merged to one iota.
+TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) {
+ HloComputation::Builder builder(TestName());
+ Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3});
+ HloInstruction* iota =
+ builder.AddInstruction(HloInstruction::CreateIota(r3f32, 1));
+ Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3});
+ builder.AddInstruction(
+ HloInstruction::CreateBroadcast(r4f32, iota, {1, 2, 3}));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_THAT(root, op::Iota());
+ EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2);
+}
+
struct PadReduceWindowEffectiveBroadcastCase {
std::vector<int64> input_spatials;
std::vector<int64> symmetric_pad_spatials;
@@ -2660,11 +3104,10 @@ struct PadReduceWindowEffectiveBroadcastCase {
bool should_become_broadcast;
string ToTestCaseName() const {
- return tensorflow::strings::StrCat(
- tensorflow::str_util::Join(input_spatials, ","), ";",
- tensorflow::str_util::Join(symmetric_pad_spatials, ","), ";",
- tensorflow::str_util::Join(reduce_window_spatials, ","), ";", prepend_a,
- ";", should_become_broadcast);
+ return absl::StrCat(absl::StrJoin(input_spatials, ","), ";",
+ absl::StrJoin(symmetric_pad_spatials, ","), ";",
+ absl::StrJoin(reduce_window_spatials, ","), ";",
+ prepend_a, ";", should_become_broadcast);
}
};
@@ -2682,8 +3125,8 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {
// a and b are parallel bounds we can either turn into a B F S0 S1 or
// `B S0 S1 F` kind of pattern.
- auto decorate_spatials = [&param](tensorflow::gtl::ArraySlice<int64> spatials,
- int64 a, int64 b) {
+ auto decorate_spatials = [&param](absl::Span<const int64> spatials, int64 a,
+ int64 b) {
std::vector<int64> result;
if (param.prepend_a) {
result.push_back(a);
@@ -2818,8 +3261,8 @@ TEST_P(DotStrengthReductionTest, DotStrengthReduction) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- builder.AddInstruction(
- HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(
+ dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
@@ -2894,8 +3337,8 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) {
dot_dnums.add_rhs_contracting_dimensions(0);
Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
- builder.AddInstruction(
- HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(
+ dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
@@ -2958,8 +3401,8 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) {
dot_dnums.add_rhs_contracting_dimensions(0);
Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
- builder.AddInstruction(
- HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(
+ dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
@@ -3076,8 +3519,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
int64 dot_row_size = 1;
int64 dot_col_size = spec.n;
Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
- builder.AddInstruction(
- HloInstruction::CreateDot(dot_shape, ds, rhs, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(
+ dot_shape, ds, rhs, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
@@ -3146,8 +3589,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
int64 dot_row_size = spec.m;
int64 dot_col_size = 1;
Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
- builder.AddInstruction(
- HloInstruction::CreateDot(dot_shape, lhs, ds, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(
+ dot_shape, lhs, ds, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc
index 51ebc4763b..1ed6142dce 100644
--- a/tensorflow/compiler/xla/service/allocation_tracker.cc
+++ b/tensorflow/compiler/xla/service/allocation_tracker.cc
@@ -17,15 +17,15 @@ limitations under the License.
#include <utility>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -69,8 +69,7 @@ StatusOr<GlobalDataHandle> AllocationTracker::RegisterInternal(
return InvalidArgument(
"AllocationTracker for platform %s cannot register buffer from "
"platform %s",
- backend_->platform()->Name().c_str(),
- shaped_buffer.platform()->Name().c_str());
+ backend_->platform()->Name(), shaped_buffer.platform()->Name());
}
}
@@ -91,8 +90,9 @@ StatusOr<GlobalDataHandle> AllocationTracker::RegisterInternal(
// If ShapedBufferTy is ScopedShapedBuffer, release the ScopedShapedBuffer
// into a regular ShapedBuffer, which is stored in
// handle_to_shaped_buffers_.
- handle_to_shaped_buffers_[handle].emplace_back(MakeUnique<ShapedBuffer>(
- ReleaseIfScopedShapedBuffer(std::move(shaped_buffer))));
+ handle_to_shaped_buffers_[handle].emplace_back(
+ absl::make_unique<ShapedBuffer>(
+ ReleaseIfScopedShapedBuffer(std::move(shaped_buffer))));
}
GlobalDataHandle result;
@@ -124,7 +124,7 @@ Status AllocationTracker::Unregister(const GlobalDataHandle& data) {
// "handle does not exist".
auto it = handle_to_shaped_buffers_.find(data.handle());
if (it == handle_to_shaped_buffers_.end()) {
- return NotFound("no allocation record for global data handle: %lld",
+ return NotFound("no allocation record for global data handle: %d",
data.handle());
}
for (auto& shaped_buffer : it->second) {
@@ -143,7 +143,7 @@ StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
// the same for all buffers across replicas.
const ShapedBuffer* shaped_buffer = replicated_buffers[0];
if (!ShapeUtil::IsTuple(shaped_buffer->on_host_shape())) {
- return InvalidArgument("global data handle %lld is not a tuple",
+ return InvalidArgument("global data handle %d is not a tuple",
data.handle());
}
// If the on-host representation is a tuple, then the on-device one should be
@@ -200,14 +200,14 @@ StatusOr<std::vector<const ShapedBuffer*>> AllocationTracker::ResolveInternal(
VLOG(2) << "resolve:" << data.handle();
auto it = handle_to_shaped_buffers_.find(data.handle());
if (it == handle_to_shaped_buffers_.end()) {
- return NotFound("no allocation record for global data handle: %lld",
+ return NotFound("no allocation record for global data handle: %d",
data.handle());
}
std::vector<const ShapedBuffer*> replicated_buffers;
for (const auto& shaped_buffer : it->second) {
if (shaped_buffer == nullptr) {
- return InvalidArgument(
- "global data handle %lld was previously deallocated", data.handle());
+ return InvalidArgument("global data handle %d was previously deallocated",
+ data.handle());
}
replicated_buffers.push_back(shaped_buffer.get());
}
diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc
index d12be3e007..5c180cbdd4 100644
--- a/tensorflow/compiler/xla/service/backend.cc
+++ b/tensorflow/compiler/xla/service/backend.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/memory/memory.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
@@ -111,11 +112,11 @@ StatusOr<StreamPool::Ptr> Backend::BorrowStream(se::StreamExecutor* executor) {
return stream_pools_.at(executor).BorrowStream(executor);
}
-Backend::Backend(
- se::Platform* platform, Compiler* compiler,
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors,
- TransferManager* transfer_manager, ComputationPlacer* computation_placer,
- int intra_op_parallelism_threads)
+Backend::Backend(se::Platform* platform, Compiler* compiler,
+ absl::Span<se::StreamExecutor* const> stream_executors,
+ TransferManager* transfer_manager,
+ ComputationPlacer* computation_placer,
+ int intra_op_parallelism_threads)
: platform_(platform),
compiler_(compiler),
transfer_manager_(transfer_manager),
@@ -127,8 +128,8 @@ Backend::Backend(
}
}
// Create a memory allocator for the valid stream executors.
- memory_allocator_ =
- MakeUnique<StreamExecutorMemoryAllocator>(platform, stream_executors);
+ memory_allocator_ = absl::make_unique<StreamExecutorMemoryAllocator>(
+ platform, stream_executors);
CHECK(!stream_executors_.empty())
<< "Service found no devices for backend " << platform_->Name() << '.';
@@ -176,7 +177,7 @@ StatusOr<se::StreamExecutor*> Backend::stream_executor(
}
}
return InvalidArgument("device %s not supported by XLA service",
- device_name(device_ordinal).c_str());
+ device_name(device_ordinal));
}
StatusOr<bool> Backend::devices_equivalent(int device_ordinal_a,
diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h
index 1bc3796fa4..a2dafbe803 100644
--- a/tensorflow/compiler/xla/service/backend.h
+++ b/tensorflow/compiler/xla/service/backend.h
@@ -21,6 +21,8 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@@ -28,8 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -130,7 +130,7 @@ class Backend {
// Return a string identifier for the given device, e.g. "GPU:3".
string device_name(int device_ordinal) const {
- return tensorflow::strings::StrCat(platform_->Name(), ":", device_ordinal);
+ return absl::StrCat(platform_->Name(), ":", device_ordinal);
}
// Returns true if the devices with the given ordinals are equivalent from
@@ -149,7 +149,7 @@ class Backend {
private:
struct EigenThreadPoolWrapper;
Backend(se::Platform* platform, Compiler* compiler,
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors,
+ absl::Span<se::StreamExecutor* const> stream_executors,
TransferManager* transfer_manager,
ComputationPlacer* computation_placer,
int intra_op_parallelism_threads);
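device_name above now builds its string with absl::StrCat rather than tensorflow::strings::StrCat; the two are drop-in compatible for mixed string/number concatenation. A minimal sketch with made-up values:

#include <iostream>
#include <string>
#include "absl/strings/str_cat.h"

int main() {
  std::string platform = "GPU";  // assumed platform name
  int device_ordinal = 3;
  // StrCat concatenates strings and numbers in one pass, no format string.
  std::cout << absl::StrCat(platform, ":", device_ordinal) << "\n";  // GPU:3
  return 0;
}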
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
index 2099916509..eda026ac56 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
@@ -62,7 +63,8 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot(
new_dim_numbers.rhs_contracting_dimensions(0) - degenerate_dims.size());
TF_ASSIGN_OR_RETURN(HloInstruction * new_dot,
- MakeDotHlo(new_lhs, new_rhs, new_dim_numbers));
+ MakeDotHlo(new_lhs, new_rhs, new_dim_numbers,
+ batch_dot->precision_config()));
TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped,
MakeReshapeHlo(batch_dot->shape(), new_dot));
@@ -76,7 +78,7 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot(
return true;
}
-tensorflow::StringPiece BatchDotSimplification::name() const {
+absl::string_view BatchDotSimplification::name() const {
return "batch-dot-simplification";
}
@@ -84,10 +86,10 @@ StatusOr<bool> BatchDotSimplification::Run(HloModule* module) {
bool changed = false;
std::vector<HloInstruction*> dot_instrs;
for (HloComputation* computation : module->MakeNonfusionComputations()) {
- c_copy_if(computation->instructions(), std::back_inserter(dot_instrs),
- [](HloInstruction* instr) {
- return instr->opcode() == HloOpcode::kDot;
- });
+ absl::c_copy_if(computation->instructions(), std::back_inserter(dot_instrs),
+ [](HloInstruction* instr) {
+ return instr->opcode() == HloOpcode::kDot;
+ });
}
for (HloInstruction* dot_instr : dot_instrs) {
TF_ASSIGN_OR_RETURN(bool elided_batch_dim_from_one,
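The c_copy_if call above picks up its absl:: prefix: absl::c_copy_if from absl/algorithm/container.h is the range form of std::copy_if, taking a whole container instead of an iterator pair. A self-contained sketch of the same filtering pattern:

#include <iostream>
#include <iterator>
#include <vector>
#include "absl/algorithm/container.h"

int main() {
  // Stand-in for collecting the kDot instructions out of a computation.
  std::vector<int> values = {1, 2, 3, 4, 5, 6};
  std::vector<int> evens;
  absl::c_copy_if(values, std::back_inserter(evens),
                  [](int v) { return v % 2 == 0; });
  for (int v : evens) std::cout << v << " ";  // prints: 2 4 6
  return 0;
}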
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h
index c0ca8d8eba..79d37f08d3 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification.h
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h
@@ -28,7 +28,7 @@ namespace xla {
class BatchDotSimplification : public HloPassInterface {
public:
StatusOr<bool> Run(HloModule* module) override;
- tensorflow::StringPiece name() const override;
+ absl::string_view name() const override;
private:
StatusOr<bool> ElideDegenerateBatchDimensionFromBatchDot(
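The name() override here, like the other HLO passes in this change, returns absl::string_view instead of tensorflow::StringPiece. Returning a view of a string literal is safe because literals have static storage duration, so no allocation or copy is involved. A minimal sketch, with a hypothetical getter standing in for HloPassInterface::name():

#include <iostream>
#include "absl/strings/string_view.h"

// Hypothetical pass-name getter in the style of the overrides above.
absl::string_view PassName() { return "batch-dot-simplification"; }

int main() {
  std::cout << PassName() << "\n";
  return 0;
}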
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc
index c4cd60c120..30d33e0d35 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@@ -33,9 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -43,7 +43,7 @@ namespace xla {
namespace {
-using tensorflow::gtl::optional;
+using absl::optional;
// BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm
// operations into smaller operations.
@@ -205,11 +205,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
const Shape feature_shape = scale->shape();
auto zero_literal = LiteralUtil::CreateR0(0.0f);
- TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
- TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
auto epsilon = add(HloInstruction::CreateBroadcast(
operand_shape,
add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {}));
@@ -331,7 +331,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
const Shape feature_shape = scale->shape();
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
- TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
operand_shape,
computation_->AddInstruction(
@@ -464,11 +464,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
const int64 elements_per_feature_int64 = size_in_elements / feature_count;
auto zero_literal = LiteralUtil::CreateR0(0.0f);
- TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
- TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
+ TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
auto epsilon_scalar =
add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
auto epsilon_activation = add(
@@ -560,7 +560,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
auto elements_per_feature_literal =
LiteralUtil::CreateR0<float>(elements_per_feature_int64);
TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
- elements_per_feature_literal->Convert(ptype));
+ elements_per_feature_literal.Convert(ptype));
auto elements_per_feature = add(
HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));
auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output,
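The '->Convert' to '.Convert' changes in this file reflect Literal becoming a value type: LiteralUtil factories now return Literal by value rather than std::unique_ptr<Literal>. A hedged sketch of that migration pattern, using a hypothetical Literal-like struct rather than the real xla::Literal:

#include <iostream>

// Hypothetical stand-in for a value-semantics Literal.
struct Literal {
  double value;
  Literal Convert(int /*ptype*/) const { return Literal{value}; }
};

// Was: std::unique_ptr<Literal> CreateR0(double);
Literal CreateR0(double v) { return Literal{v}; }

int main() {
  Literal zero = CreateR0(0.0);
  Literal converted = zero.Convert(/*ptype=*/0);  // '.' instead of '->'
  std::cout << converted.value << "\n";
  return 0;
}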
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h
index 7ae202c583..76e32174f3 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.h
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.h
@@ -36,7 +36,7 @@ class BatchNormExpander : public HloPassInterface {
rewrite_inference_op_(rewrite_inference_op),
rewrite_grad_op_(rewrite_grad_op) {}
~BatchNormExpander() = default;
- tensorflow::StringPiece name() const override { return "batchnorm_expander"; }
+ absl::string_view name() const override { return "batchnorm_expander"; }
// Run operation expander on the given computation. Returns whether the
// computation was changed.
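Elsewhere in this change, tensorflow::gtl::optional and gtl::nullopt map one-for-one onto absl::optional and absl::nullopt (see the all_reduce_id arguments in the tests below). A minimal sketch of the absl form:

#include <iostream>
#include "absl/types/optional.h"

int main() {
  // Mirrors the all_reduce_id parameter: disengaged by default.
  absl::optional<int> all_reduce_id = absl::nullopt;
  std::cout << (all_reduce_id.has_value() ? "set" : "unset") << "\n";
  all_reduce_id = 7;
  std::cout << all_reduce_id.value() << "\n";
  return 0;
}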
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
index a725351462..aba0d9bb5b 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
@@ -18,9 +18,9 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -32,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace {
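The test above also drops xla/ptr_util.h for absl/memory/memory.h; absl::make_unique is the C++11-compatible equivalent of std::make_unique, as in the StreamExecutorMemoryAllocator construction earlier in this diff. A minimal sketch with a made-up type:

#include <iostream>
#include <memory>
#include "absl/memory/memory.h"

struct Allocator {  // hypothetical stand-in for a memory allocator
  explicit Allocator(int id) : id(id) {}
  int id;
};

int main() {
  std::unique_ptr<Allocator> a = absl::make_unique<Allocator>(7);
  std::cout << a->id << "\n";
  return 0;
}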
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc
index 1b8b2d2045..d63287539d 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc
@@ -15,12 +15,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
index c939838709..5dcd31b83d 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
@@ -37,7 +37,7 @@ class BFloat16ConversionFolding : public HloPassInterface {
: bfloat16_support_(bfloat16_support) {}
~BFloat16ConversionFolding() override = default;
- tensorflow::StringPiece name() const override { return "bfloat16-fold"; }
+ absl::string_view name() const override { return "bfloat16-fold"; }
// Run BF16 conversion folding on the given computation. Returns whether the
// computation was changed.
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
index 7cf05ca443..5f93740887 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@@ -65,8 +65,12 @@ class TestBFloat16Support : public BFloat16Support {
}
};
-class BFloat16ConversionFoldingTest : public HloTestBase {
+class BFloat16ConversionFoldingTest : public HloVerifiedTestBase {
protected:
+ BFloat16ConversionFoldingTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/true) {}
+
bool FoldConversions(HloModule* module) {
TestBFloat16Support bfloat16_support_;
BFloat16ConversionFolding fold(&bfloat16_support_);
@@ -102,7 +106,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(FoldConversions(module.get()));
+ EXPECT_TRUE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), add1);
EXPECT_EQ(add0->shape().element_type(), BF16);
@@ -137,7 +141,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(FoldConversions(module.get()));
+ EXPECT_FALSE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), convert2);
EXPECT_EQ(mul0->shape().element_type(), F32);
@@ -172,7 +176,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(FoldConversions(module.get()));
+ EXPECT_FALSE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), convert2);
EXPECT_EQ(sub0->shape().element_type(), F32);
@@ -202,7 +206,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(FoldConversions(module.get()));
+ EXPECT_FALSE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), convert1);
EXPECT_EQ(gte->shape().element_type(), F32);
@@ -235,8 +239,8 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
HloInstruction* crs =
builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b},
- sum, /*replica_group_ids=*/{}, /*barrier=*/"",
- /*all_reduce_id=*/tensorflow::gtl::nullopt));
+ sum, /*replica_groups=*/{}, /*barrier=*/"",
+ /*all_reduce_id=*/absl::nullopt));
HloInstruction* gte_a = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(f32_shape, crs, 0));
HloInstruction* gte_b = builder.AddInstruction(
@@ -248,7 +252,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(FoldConversions(module.get()));
+ EXPECT_TRUE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), tuple);
EXPECT_EQ(tuple->operand(0), gte_a);
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc
index 16e99b5722..d5b1148058 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc
@@ -15,12 +15,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -34,11 +35,6 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
Status DefaultAction(HloInstruction* hlo) override;
- // Special handling for cross-replica-sum and sort which can have a tuple
- // output.
- Status HandleCrossReplicaSum(HloInstruction* crs) override;
- Status HandleSort(HloInstruction* sort) override;
-
static bool Run(HloComputation* computation,
const BFloat16Support* bfloat16_support) {
BFloat16NormalizationVisitor visitor(computation, bfloat16_support);
@@ -73,8 +69,7 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
// Inserts conversion HLOs to replace the called computations' BF16
// operands/outputs to F32.
Status ConvertCalledComputations(
- HloInstruction* hlo,
- tensorflow::gtl::ArraySlice<HloComputation*> bf16_called_comps);
+ HloInstruction* hlo, absl::Span<HloComputation* const> bf16_called_comps);
HloComputation* computation_;
const BFloat16Support* bfloat16_support_;
@@ -118,8 +113,7 @@ Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand(
}
Status BFloat16NormalizationVisitor::ConvertCalledComputations(
- HloInstruction* hlo,
- tensorflow::gtl::ArraySlice<HloComputation*> bf16_called_comps) {
+ HloInstruction* hlo, absl::Span<HloComputation* const> bf16_called_comps) {
std::map<HloComputation*, HloComputation*> cloned_computations;
for (auto& comp : bf16_called_comps) {
auto cloned = comp->parent()->AddEmbeddedComputation(comp->Clone());
@@ -150,23 +144,6 @@ Status BFloat16NormalizationVisitor::ConvertCalledComputations(
return Status::OK();
}
-Status BFloat16NormalizationVisitor::HandleCrossReplicaSum(
- HloInstruction* crs) {
- if (!ShapeUtil::IsTuple(crs->shape())) {
- return HandleInstruction(crs);
- } else {
- return HandleMultipleOutputs(crs);
- }
-}
-
-Status BFloat16NormalizationVisitor::HandleSort(HloInstruction* sort) {
- if (!ShapeUtil::IsTuple(sort->shape())) {
- return HandleInstruction(sort);
- } else {
- return HandleMultipleOutputs(sort);
- }
-}
-
Status BFloat16NormalizationVisitor::HandleMultipleOutputs(
HloInstruction* hlo) {
std::vector<PrimitiveType> operand_types(hlo->operand_count());
@@ -380,6 +357,12 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) {
hlo->opcode() == HloOpcode::kConditional) {
return Status::OK();
}
+ // TODO(b/112040122): Correctly normalize variadic reduce.
+ if ((hlo->opcode() == HloOpcode::kSort ||
+ hlo->opcode() == HloOpcode::kCrossReplicaSum) &&
+ ShapeUtil::IsTuple(hlo->shape())) {
+ return HandleMultipleOutputs(hlo);
+ }
return HandleInstruction(hlo);
}
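The refactoring above removes the dedicated HandleCrossReplicaSum and HandleSort overrides and instead routes tuple-shaped kSort and kCrossReplicaSum through DefaultAction, which forwards them to HandleMultipleOutputs. A self-contained sketch of that dispatch shape, with hypothetical stand-in types:

#include <iostream>

enum class Opcode { kAdd, kSort, kCrossReplicaSum };

struct Instr {
  Opcode opcode;
  bool tuple_shaped;  // stand-in for ShapeUtil::IsTuple(hlo->shape())
};

void HandleMultipleOutputs(const Instr&) { std::cout << "multi-output\n"; }
void HandleInstruction(const Instr&) { std::cout << "single-output\n"; }

void DefaultAction(const Instr& instr) {
  // Tuple-output sort/cross-replica-sum take the multi-output path;
  // everything else falls through to the ordinary handler.
  if ((instr.opcode == Opcode::kSort ||
       instr.opcode == Opcode::kCrossReplicaSum) &&
      instr.tuple_shaped) {
    HandleMultipleOutputs(instr);
    return;
  }
  HandleInstruction(instr);
}

int main() {
  DefaultAction({Opcode::kSort, /*tuple_shaped=*/true});   // multi-output
  DefaultAction({Opcode::kAdd, /*tuple_shaped=*/false});   // single-output
  return 0;
}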
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.h b/tensorflow/compiler/xla/service/bfloat16_normalization.h
index 2a60fe0af3..30b6346312 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization.h
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization.h
@@ -31,7 +31,7 @@ class BFloat16Normalization : public HloPassInterface {
: bfloat16_support_(bfloat16_support) {}
~BFloat16Normalization() override = default;
- tensorflow::StringPiece name() const override { return "bf16-normalization"; }
+ absl::string_view name() const override { return "bf16-normalization"; }
// Run BF16 normalization on the given computation. Returns whether the
// computation was changed.
@@ -54,7 +54,7 @@ class BFloat16MixedPrecisionRemoval : public HloPassInterface {
~BFloat16MixedPrecisionRemoval() override = default;
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "bf16-mixed-precision-removal";
}
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index f9f1f64998..cef0eba14e 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@@ -68,15 +68,20 @@ class TestBFloat16Support : public BFloat16Support {
}
};
-class BFloat16NormalizationTest : public HloTestBase {
+class BFloat16NormalizationTest : public HloVerifiedTestBase {
protected:
+ BFloat16NormalizationTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/true) {}
+
bool Normalize(HloModule* module) {
TestBFloat16Support bfloat16_support_;
BFloat16Normalization normalization(&bfloat16_support_);
StatusOr<bool> result = normalization.Run(module);
EXPECT_IS_OK(result.status());
- HloVerifier verifier(/*allow_mixed_precision=*/true);
+ HloVerifier verifier(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/true);
EXPECT_IS_OK(verifier.Run(module).status());
return result.ValueOrDie();
@@ -104,7 +109,7 @@ TEST_F(BFloat16NormalizationTest, NoopIfSupported) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(Normalize(module.get()));
+ EXPECT_FALSE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), add1);
EXPECT_EQ(add0->shape().element_type(), BF16);
@@ -132,7 +137,7 @@ TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
EXPECT_EQ(computation->root_instruction()->operand(0), mul1);
@@ -162,7 +167,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
EXPECT_EQ(computation->root_instruction()->operand(0), sub1);
@@ -200,7 +205,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) {
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), reduce);
EXPECT_EQ(reduce->called_computations().size(), 1);
@@ -251,14 +256,14 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
HloInstruction* crs =
builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction,
- /*replica_group_ids=*/{}, /*barrier=*/"",
- /*all_reduce_id=*/tensorflow::gtl::nullopt));
+ /*replica_groups=*/{}, /*barrier=*/"",
+ /*all_reduce_id=*/absl::nullopt));
HloInstruction* gte = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), gte);
EXPECT_EQ(gte->shape().element_type(), BF16);
@@ -285,7 +290,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) {
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), gte);
EXPECT_EQ(gte->shape().element_type(), BF16);
@@ -307,13 +312,16 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfig precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfig::DEFAULT);
HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums));
+ HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums, precision_config));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
EXPECT_EQ(dot->shape().element_type(), F32);
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
index 2fb401c428..545a6ecfb1 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
@@ -407,7 +407,7 @@ void BFloat16Propagation::AdjustCalledComputationParameters(
HloInstruction* hlo) {
auto adjust_computation =
[this, hlo](HloComputation* computation,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ absl::Span<HloInstruction* const> operands) {
// Adjust parameters.
CHECK_EQ(operands.size(), computation->num_parameters());
for (int64 i = 0; i < operands.size(); ++i) {
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h
index 02b8cad089..1ee64971ab 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.h
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h
@@ -64,9 +64,7 @@ class BFloat16Propagation : public HloPassInterface {
~BFloat16Propagation() override = default;
- tensorflow::StringPiece name() const override {
- return "bfloat16-propagation";
- }
+ absl::string_view name() const override { return "bfloat16-propagation"; }
// Runs the pass on the given module. Returns whether the module was changed
// (precision reductions were added).
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
index 69b654d30e..e032b5c624 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -55,8 +55,12 @@ class TestBFloat16Support : public BFloat16Support {
}
};
-class BFloat16PropagationTest : public HloTestBase {
+class BFloat16PropagationTest : public HloVerifiedTestBase {
protected:
+ BFloat16PropagationTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/true) {}
+
// Runs the propagation pass on the given module, and returns whether the
// module is changed after this pass.
bool PropagatePrecision(HloModule* module) {
@@ -77,6 +81,16 @@ class BFloat16PropagationTest : public HloTestBase {
inst->users()[0]->opcode() == HloOpcode::kConvert &&
inst->users()[0]->shape().element_type() == BF16;
}
+
+ std::unique_ptr<HloInstruction> CreateDot(const Shape& shape,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums,
+ DefaultPrecisionConfig(2));
+ }
};
// Tests that BF16 can propagate through select over non-tuple buffers, but not
@@ -95,22 +109,22 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) {
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
HloInstruction* add1 = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b));
- HloInstruction* pred = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kEq, a, b));
+ HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(PRED, {2, 4}), HloOpcode::kEq, a, b));
HloInstruction* sel = builder.AddInstruction(
HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1));
HloInstruction* xpose =
builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {4, 2}), sel, {1, 0}));
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, a));
- HloInstruction* root = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
+ HloInstruction* dot = builder.AddInstruction(
+ CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, a));
+ HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), root);
EXPECT_TRUE(OutputsBF16(xpose));
@@ -136,13 +150,12 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_a)));
HloInstruction* b = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b)));
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, a, b));
+ HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a, b));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), dot);
EXPECT_TRUE(OutputsBF16(dot->operand(0)));
@@ -150,10 +163,10 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_a)),
+ LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_a)),
dot->operand(0)->literal()));
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_b)),
+ LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_b)),
dot->operand(1)->literal()));
}
@@ -189,8 +202,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTuples) {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
tuple0->shape(), tuple1, 0)),
0));
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs));
+ HloInstruction* dot = builder.AddInstruction(
+ CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs));
HloInstruction* output_tuple =
builder.AddInstruction(HloInstruction::CreateTuple({dot, add2}));
@@ -198,7 +211,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTuples) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), output_tuple);
EXPECT_TRUE(OutputsBF16(xpose));
@@ -231,13 +244,13 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) {
HloInstruction::CreateGetTupleElement(add1->shape(), tuple, 1));
// lhs is the transpose of add1, and rhs is a get-tuple-element aliasing add1.
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs));
+ HloInstruction* dot = builder.AddInstruction(
+ CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), dot);
EXPECT_TRUE(OutputsBF16(add1));
@@ -249,7 +262,7 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) {
// Tests that a non-fusion computation's root should not be changed.
TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) {
auto builder = HloComputation::Builder(TestName());
- Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
+ Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
HloInstruction* a =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
@@ -258,8 +271,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) {
HloInstruction* add = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add, add));
+ HloInstruction* dot = builder.AddInstruction(CreateDot(shape, add, add));
HloInstruction* tuple =
builder.AddInstruction(HloInstruction::CreateTuple({add, dot}));
@@ -267,7 +279,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(PropagatePrecision(module.get()));
+ EXPECT_FALSE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), tuple);
EXPECT_FALSE(OutputsBF16(add));
@@ -277,7 +289,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) {
TEST_F(BFloat16PropagationTest, PropagateThroughFusion) {
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
- Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
+ Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
@@ -303,15 +315,14 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) {
HloInstruction::CreateGetTupleElement(shape, p_f1, 0));
HloInstruction* b_f1 = builder_f1.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, p_f1, 1));
- HloInstruction* dot = builder_f1.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, a_f1, b_f1));
+ HloInstruction* dot = builder_f1.AddInstruction(CreateDot(shape, a_f1, b_f1));
auto comp_f1 = module->AddEmbeddedComputation(builder_f1.Build());
auto fusion1 = builder.AddInstruction(HloInstruction::CreateFusion(
dot->shape(), HloInstruction::FusionKind::kCustom, {fusion0}, comp_f1));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), fusion1);
EXPECT_TRUE(OutputsBF16(add));
@@ -326,7 +337,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) {
TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) {
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
- Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
+ Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
@@ -340,15 +351,15 @@ TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) {
builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
HloInstruction* add_f = builder_f.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f));
- HloInstruction* dot_f = builder_f.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add_f, add_f));
+ HloInstruction* dot_f =
+ builder_f.AddInstruction(CreateDot(shape, add_f, add_f));
auto comp_f = module->AddEmbeddedComputation(builder_f.Build());
auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
dot_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, comp_f));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(PropagatePrecision(module.get()));
+ EXPECT_FALSE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), fusion);
}
@@ -390,12 +401,11 @@ TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) {
HloInstruction::CreateGetTupleElement(shape, fusion, 0));
HloInstruction* gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, fusion, 1));
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, gte0, gte1));
+ HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), dot);
EXPECT_TRUE(OutputsBF16(gte0));
@@ -440,12 +450,12 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) {
HloInstruction* xpose =
builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {4, 2}), gte0, {1, 0}));
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, gte1));
+ HloInstruction* dot = builder.AddInstruction(
+ CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, gte1));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), dot);
EXPECT_FALSE(OutputsBF16(add0));
@@ -472,31 +482,36 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) {
auto builder_cond = HloComputation::Builder("cond");
auto cond_param = builder_cond.AddInstruction(
HloInstruction::CreateParameter(0, shape, "cond_param"));
- auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, cond_param, cond_param));
+ auto cond_dot =
+ builder_cond.AddInstruction(CreateDot(shape, cond_param, cond_param));
auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
- builder_cond.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, {1, 1})),
- builder_cond.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1}))));
+ builder_cond.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond.AddInstruction(
+ HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
+ cond_dot, {0, 0}, {1, 1}, {1, 1})))),
+ builder_cond.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2},
+ {1, 1}))))));
auto cond = module->AddEmbeddedComputation(builder_cond.Build());
auto builder_body = HloComputation::Builder("body");
auto body_param = builder_body.AddInstruction(
HloInstruction::CreateParameter(0, shape, "body_param"));
- auto body_dot = builder_body.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, body_param, body_param));
+ auto body_dot =
+ builder_body.AddInstruction(CreateDot(shape, body_param, body_param));
auto body = module->AddEmbeddedComputation(builder_body.Build());
auto while_hlo = builder.AddInstruction(
HloInstruction::CreateWhile(shape, cond, body, add));
- auto dot = builder.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, while_hlo, while_hlo));
+ auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), dot);
EXPECT_TRUE(
@@ -528,10 +543,16 @@ TEST_F(BFloat16PropagationTest,
HloInstruction::CreateParameter(0, shape, "cond_param"));
builder_cond.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
- builder_cond.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond_param, {0, 0}, {1, 1}, {1, 1})),
- builder_cond.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond_param, {1, 1}, {2, 2}, {1, 1}))));
+ builder_cond.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {0, 0}, {1, 1},
+ {1, 1})))),
+ builder_cond.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {1, 1}, {2, 2},
+ {1, 1}))))));
auto cond = module->AddEmbeddedComputation(builder_cond.Build());
auto builder_body = HloComputation::Builder("body");
@@ -552,11 +573,10 @@ TEST_F(BFloat16PropagationTest,
auto while_hlo = builder.AddInstruction(
HloInstruction::CreateWhile(shape, cond, body, add));
- auto dot = builder.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, while_hlo, while_hlo));
+ auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(PropagatePrecision(module.get()));
+ EXPECT_FALSE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), dot);
EXPECT_FALSE(OutputsBF16(add));
EXPECT_FALSE(OutputsBF16(body_fusion));
@@ -593,14 +613,20 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) {
// This add should prevent RHS from using BF16
auto cond_add_rhs = builder_cond.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs));
- auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, cond_lhs, cond_add_rhs));
+ auto cond_dot =
+ builder_cond.AddInstruction(CreateDot(shape, cond_lhs, cond_add_rhs));
builder_cond.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
- builder_cond.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, {1, 1})),
- builder_cond.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1}))));
+ builder_cond.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond.AddInstruction(
+ HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
+ cond_dot, {0, 0}, {1, 1}, {1, 1})))),
+ builder_cond.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2},
+ {1, 1}))))));
auto cond = module->AddEmbeddedComputation(builder_cond.Build());
auto builder_body = HloComputation::Builder("body");
@@ -610,10 +636,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) {
HloInstruction::CreateGetTupleElement(shape, body_param, 0));
auto body_rhs = builder_body.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, body_param, 1));
- auto body_dot1 = builder_body.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs));
- auto body_dot2 = builder_body.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_rhs, body_lhs));
+ auto body_dot1 =
+ builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs));
+ auto body_dot2 =
+ builder_body.AddInstruction(CreateDot(shape, body_rhs, body_lhs));
auto body_transpose = builder_body.AddInstruction(
HloInstruction::CreateTranspose(shape, body_dot2, {0, 1}));
builder_body.AddInstruction(
@@ -627,11 +653,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) {
HloInstruction::CreateGetTupleElement(shape, while_hlo, 0));
auto rhs = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, while_hlo, 1));
- auto dot = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs));
+ auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), dot);
EXPECT_TRUE(OutputsBF16(lhs));
@@ -683,14 +708,20 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
auto cond0_add_rhs =
builder_cond0.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs));
- auto cond0_dot = builder_cond0.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, cond0_lhs, cond0_add_rhs));
+ auto cond0_dot =
+ builder_cond0.AddInstruction(CreateDot(shape, cond0_lhs, cond0_add_rhs));
builder_cond0.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
- builder_cond0.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond0_dot, {0, 0}, {1, 1}, {1, 1})),
- builder_cond0.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond0_dot, {1, 1}, {2, 2}, {1, 1}))));
+ builder_cond0.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond0.AddInstruction(
+ HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
+ cond0_dot, {0, 0}, {1, 1}, {1, 1})))),
+ builder_cond0.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond0.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {1, 1}), cond0_dot, {1, 1}, {2, 2},
+ {1, 1}))))));
auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build());
// Condition computation for the second while.
@@ -705,14 +736,20 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
auto cond1_add_lhs =
builder_cond1.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs));
- auto cond1_dot = builder_cond1.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, cond1_add_lhs, cond1_rhs));
+ auto cond1_dot =
+ builder_cond1.AddInstruction(CreateDot(shape, cond1_add_lhs, cond1_rhs));
builder_cond1.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
- builder_cond1.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond1_dot, {0, 0}, {1, 1}, {1, 1})),
- builder_cond1.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond1_dot, {1, 1}, {2, 2}, {1, 1}))));
+ builder_cond1.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond1.AddInstruction(
+ HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
+ cond1_dot, {0, 0}, {1, 1}, {1, 1})))),
+ builder_cond1.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond1.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {1, 1}), cond1_dot, {1, 1}, {2, 2},
+ {1, 1}))))));
auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build());
// Body computation shared by both whiles.
@@ -723,8 +760,8 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
HloInstruction::CreateGetTupleElement(shape, body_param, 0));
auto body_rhs = builder_body.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, body_param, 1));
- auto body_dot = builder_body.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs));
+ auto body_dot =
+ builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs));
builder_body.AddInstruction(
HloInstruction::CreateTuple({body_dot, body_rhs}));
auto body = module->AddEmbeddedComputation(builder_body.Build());
@@ -734,23 +771,22 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
auto while1 = builder.AddInstruction(
HloInstruction::CreateWhile(tuple1->shape(), cond1, body, tuple1));
- auto lhs = builder.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot,
- builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(shape, while0, 0)),
- builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(shape, while0, 1))));
- auto rhs = builder.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot,
- builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(shape, while1, 0)),
- builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(shape, while1, 1))));
- auto dot = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs));
+ auto lhs = builder.AddInstruction(
+ CreateDot(shape,
+ builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, while0, 0)),
+ builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, while0, 1))));
+ auto rhs = builder.AddInstruction(
+ CreateDot(shape,
+ builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, while1, 0)),
+ builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, while1, 1))));
+ auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_FALSE(OutputsBF16(body_dot));
EXPECT_FALSE(OutputsBF16(body_rhs));
EXPECT_FALSE(OutputsBF16(body_lhs));
@@ -792,7 +828,7 @@ TEST_F(BFloat16PropagationTest, NoopConversionRemoved) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), add2);
EXPECT_EQ(add2->operand(0), add0);
@@ -821,15 +857,14 @@ TEST_F(BFloat16PropagationTest, TupleDomain) {
HloInstruction::CreateGetTupleElement(shape, domain, 0));
HloInstruction* b_gte = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, domain, 1));
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_gte, b_gte));
+ HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a_gte, b_gte));
HloInstruction* root = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), root);
// Test that BF16 is propagated through the domain.
@@ -867,15 +902,15 @@ TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) {
HloInstruction::CreateTranspose(shape, a_gte, {0, 1}));
HloInstruction* b_trans = builder.AddInstruction(
HloInstruction::CreateTranspose(shape, b_gte, {0, 1}));
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_trans, b_trans));
+ HloInstruction* dot =
+ builder.AddInstruction(CreateDot(shape, a_trans, b_trans));
HloInstruction* root = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), root);
EXPECT_TRUE(OutputsBF16(a_trans));
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index cfd26fc778..65fa951afe 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -22,13 +22,14 @@ limitations under the License.
#include <ostream>
#include <utility>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_value_containers.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -36,20 +37,15 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
+namespace {
+using absl::StrAppend;
+using absl::StrAppendFormat;
using ::tensorflow::gtl::FlatMap;
using ::tensorflow::gtl::FlatSet;
-using ::tensorflow::strings::Appendf;
using ::tensorflow::strings::HumanReadableNumBytes;
-using ::tensorflow::strings::Printf;
-using ::tensorflow::strings::StrAppend;
-
-namespace {
template <typename T>
string ColocatedBufferSetsToString(const T& container, const char* title) {
@@ -61,12 +57,65 @@ string ColocatedBufferSetsToString(const T& container, const char* title) {
return result;
}
-// Walk the call graph of the HLO module and place each computation into either
-// thread_local_computations or global_computations depending upon whether the
-// computation requires thread-local allocations or global allocations. The
-// elements in thread_local_computations and global_computations are in post
-// order (if computation A has an instruction which calls computation B, then A
-// will appear after B in the vector).
+// Checks that the points-to set of 'instruction' is unambiguous and distinct
+// (ensured by CopyInsertion), then adds the buffer from the points-to set at
+// 'index' to 'colocated_set'.
+const LogicalBuffer* AddBufferToColocatedSet(
+ const HloInstruction* instruction, const ShapeIndex& index,
+ const TuplePointsToAnalysis& points_to_analysis,
+ std::vector<const LogicalBuffer*>* colocated_set) {
+ // CopyInsertion ensures root points-to set is unambiguous and distinct.
+ const auto& points_to = points_to_analysis.GetPointsToSet(instruction);
+ DCHECK(!points_to.IsAmbiguous());
+ colocated_set->push_back(points_to.element(index)[0]);
+ return colocated_set->back();
+}
+
+// Given the interference map of a graph (the list of interfering node indices
+// for each node), perform graph coloring such that interfering nodes are
+// assigned to different colors. Returns the assigned color of the nodes, where
+// the colors are represented as integer values [0, color_count).
+std::vector<int64> ColorInterferenceGraph(
+ const std::vector<std::vector<int64>>& interference_map) {
+ const int64 node_count = interference_map.size();
+
+ // Sort the nodes such that we assign nodes with more interference first. This
+ // relies on the common heuristic of assigning the most constrained node
+ // first, but it would be good to investigate other ordering heuristics too.
+ std::vector<int64> nodes(node_count);
+ std::iota(nodes.begin(), nodes.end(), 0);
+ std::sort(nodes.begin(), nodes.end(),
+ [&interference_map](const int64 i, const int64 j) {
+ return interference_map[i].size() > interference_map[j].size();
+ });
+
+ const int64 kColorUnassigned = -1;
+ std::vector<int64> assigned_colors(node_count, kColorUnassigned);
+ for (int64 node : nodes) {
+ // Mark the colors that are already assigned to the neighbors.
+ std::vector<bool> available_colors(node_count, true);
+ for (int64 neighbor : interference_map[node]) {
+ int64 color = assigned_colors[neighbor];
+ if (color != kColorUnassigned) {
+ available_colors[color] = false;
+ }
+ }
+
+ // Find the color that is not yet assigned to the neighbors.
+ int64 color = kColorUnassigned;
+ for (color = 0; color < available_colors.size(); ++color) {
+ if (available_colors[color]) {
+ break;
+ }
+ }
+ CHECK_NE(color, kColorUnassigned);
+ assigned_colors[node] = color;
+ }
+ return assigned_colors;
+}
+
+} // namespace
+
Status GatherComputationsByAllocationType(
const HloModule* module,
std::vector<const HloComputation*>* thread_local_computations,
@@ -107,7 +156,7 @@ Status GatherComputationsByAllocationType(
return InvalidArgument(
"computation %s has conflicting allocation requirements (global "
"and thread-local)",
- computation->name().c_str());
+ computation->name());
}
if (is_thread_local) {
@@ -130,7 +179,7 @@ Status GatherComputationsByAllocationType(
return InvalidArgument(
"computation %s cannot contain call/while op because it "
"requires thread-local buffer allocations",
- computation->name().c_str());
+ computation->name());
}
worklist.push_back(std::make_pair(subcomputation,
false)); // Not thread local.
@@ -147,9 +196,8 @@ Status GatherComputationsByAllocationType(
true)); // Thread local.
break;
default:
- return InternalError(
- "Unexpected calling opcode: %s",
- HloOpcodeString(instruction->opcode()).c_str());
+ return InternalError("Unexpected calling opcode: %s",
+ HloOpcodeString(instruction->opcode()));
}
}
}
@@ -169,65 +217,6 @@ Status GatherComputationsByAllocationType(
return Status::OK();
}
-// Checks that points-to set of 'instruction' is unambiguous and distinct
-// (ensured by CopyInsertion), then adds the buffer from the points-to set at
-// 'index' to 'colocated_set'.
-const LogicalBuffer* AddBufferToColocatedSet(
- const HloInstruction* instruction, const ShapeIndex& index,
- const TuplePointsToAnalysis& points_to_analysis,
- std::vector<const LogicalBuffer*>* colocated_set) {
- // CopyInsertion ensures root points-to set is unambiguous and distinct.
- const auto& points_to = points_to_analysis.GetPointsToSet(instruction);
- DCHECK(!points_to.IsAmbiguous());
- colocated_set->push_back(points_to.element(index)[0]);
- return colocated_set->back();
-}
-
-// Given the interference map of a graph (the list of interfering node indices
-// for each node), perform graph coloring such that interfering nodes are
-// assigned to different colors. Returns the assigned color of the nodes, where
-// the colors are represented as integer values [0, color_count).
-std::vector<int64> ColorInterferenceGraph(
- const std::vector<std::vector<int64>>& interference_map) {
- const int64 node_count = interference_map.size();
-
- // Sort the nodes such that we assign nodes with more interference first. This
- // relies on the common heuristic of assigning the most constrained node
- // first, but it would be good to investigate other ordering heuristics too.
- std::vector<int64> nodes(node_count);
- std::iota(nodes.begin(), nodes.end(), 0);
- std::sort(nodes.begin(), nodes.end(),
- [&interference_map](const int64 i, const int64 j) {
- return interference_map[i].size() > interference_map[j].size();
- });
-
- const int64 kColorUnassigned = -1;
- std::vector<int64> assigned_colors(node_count, kColorUnassigned);
- for (int64 node : nodes) {
- // Mark the colors that are already assigned to the neighbors.
- std::vector<bool> available_colors(node_count, true);
- for (int64 neighbor : interference_map[node]) {
- int64 color = assigned_colors[neighbor];
- if (color != kColorUnassigned) {
- available_colors[color] = false;
- }
- }
-
- // Find the color that is not yet assigned to the neighbors.
- int64 color = kColorUnassigned;
- for (color = 0; color < available_colors.size(); ++color) {
- if (available_colors[color]) {
- break;
- }
- }
- CHECK_NE(color, kColorUnassigned);
- assigned_colors[node] = color;
- }
- return assigned_colors;
-}
-
-} // namespace
-
size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const {
uint64 h = std::hash<int64>()(s.index());
h = tensorflow::Hash64Combine(h, std::hash<int64>()(s.offset()));
@@ -236,8 +225,8 @@ size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const {
}
string BufferAllocation::Slice::ToString() const {
- return tensorflow::strings::StrCat("{index:", index(), ", offset:", offset_,
- ", size:", size_, "}");
+ return absl::StrCat("{index:", index(), ", offset:", offset_,
+ ", size:", size_, "}");
}
BufferAllocation::Slice BufferAllocation::GetSlice(
@@ -298,7 +287,7 @@ BufferAllocationProto BufferAllocation::ToProto() const {
string BufferAllocation::ToString() const {
string output;
- Appendf(&output, "allocation %lld: %p, size %lld", index_, this, size());
+ StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size());
if (color().value() != 0) {
StrAppend(&output, ", color ", color().value());
}
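For context on the %lld -> %d change above: absl::StrFormat and
absl::StrAppendFormat are type-safe, so %d accepts integral arguments of any
width, unlike the printf-style Appendf it replaces. A minimal standalone
sketch (function and values are hypothetical, not from this commit):

    #include <cstdint>
    #include <string>
    #include "absl/strings/str_format.h"

    std::string Describe(int64_t index, size_t size) {
      std::string out;
      // %d works for int64_t and size_t alike; no %lld/%zu modifiers needed.
      absl::StrAppendFormat(&out, "allocation %d, size %d", index, size);
      return out;
    }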
@@ -330,11 +319,10 @@ string BufferAllocation::ToString() const {
});
for (const LogicalBuffer* buffer : sorted_buffers) {
const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer);
- StrAppend(&output,
- tensorflow::strings::Printf(
- " %s [%lld,%lld]: %s\n", buffer->ToString().c_str(),
- offset_size.offset, offset_size.size,
- ShapeUtil::HumanStringWithLayout(buffer->shape()).c_str()));
+ StrAppend(&output, absl::StrFormat(
+ " %s [%d,%d]: %s\n", buffer->ToString(),
+ offset_size.offset, offset_size.size,
+ ShapeUtil::HumanStringWithLayout(buffer->shape())));
}
return output;
}
@@ -427,7 +415,7 @@ StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueSlice(
return FailedPrecondition(
"BufferAllocation::Slice for instruction %s at index %s cannot "
"be determined at compile-time.",
- instruction->name().c_str(), index.ToString().c_str());
+ instruction->name(), index.ToString());
}
} else {
VLOG(3) << "No allocation";
@@ -436,7 +424,7 @@ StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueSlice(
if (result.allocation() == nullptr) {
return FailedPrecondition(
"BufferAllocation::Slice not assigned for instruction %s at index %s",
- instruction->name().c_str(), index.ToString().c_str());
+ instruction->name(), index.ToString());
}
return result;
}
@@ -627,19 +615,25 @@ Status BufferAssignment::ComputeSummaryStats() {
stats_.total_allocation_bytes += allocation.size();
}
- // Only compute total fragmentation if all computations are sequential.
- SequentialHloOrdering::HloModuleSequence module_sequence;
+ // Only compute total fragmentation if all computations have schedules.
+ HloSchedule schedule(module_);
+ bool schedule_complete = true;
for (const auto& computation : module_->computations()) {
- const std::vector<const HloInstruction*>* sequence =
- liveness_->hlo_ordering().SequentialOrder(*computation);
- if (sequence != nullptr) {
- module_sequence.emplace(computation, *sequence);
+ if (!computation->IsFusionComputation()) {
+ const std::vector<const HloInstruction*>* sequence =
+ liveness_->hlo_ordering().SequentialOrder(*computation);
+ if (sequence == nullptr) {
+ schedule_complete = false;
+ } else {
+ schedule.set_sequence(computation, *sequence);
+ }
}
}
- if (module_sequence.size() == module_->computation_count()) {
+ if (schedule_complete) {
+ TF_RETURN_IF_ERROR(schedule.Verify());
TF_ASSIGN_OR_RETURN(
const int64 min_size,
- HeapSimulator::MinimumMemoryForModule(module_sequence, buffer_size_));
+ HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_));
stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size;
}
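The hunk above is the new scheduling idiom: build an HloSchedule keyed by the
module, fill in per-computation sequences, and verify before use. A condensed
sketch using only calls that appear in this diff (error handling elided;
`module`, `computation`, and `order` are assumed to be in scope):

    HloSchedule schedule(module);               // module: const HloModule*
    schedule.set_sequence(computation, order);  // one sequence per computation
    TF_RETURN_IF_ERROR(schedule.Verify());      // rejects incomplete schedules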
@@ -648,39 +642,38 @@ Status BufferAssignment::ComputeSummaryStats() {
string BufferAssignment::Stats::ToString() const {
string s;
- Appendf(&s, "BufferAssignment stats:\n");
- Appendf(&s, " parameter allocation: %10s\n",
- HumanReadableNumBytes(parameter_allocation_bytes).c_str());
- Appendf(&s, " constant allocation: %10s\n",
- HumanReadableNumBytes(constant_allocation_bytes).c_str());
- Appendf(&s, " maybe_live_out allocation: %10s\n",
- HumanReadableNumBytes(maybe_live_out_allocation_bytes).c_str());
- Appendf(&s, " preallocated temp allocation: %10s\n",
- HumanReadableNumBytes(preallocated_temp_allocation_bytes).c_str());
+ StrAppendFormat(&s, "BufferAssignment stats:\n");
+ StrAppendFormat(&s, " parameter allocation: %10s\n",
+ HumanReadableNumBytes(parameter_allocation_bytes));
+ StrAppendFormat(&s, " constant allocation: %10s\n",
+ HumanReadableNumBytes(constant_allocation_bytes));
+ StrAppendFormat(&s, " maybe_live_out allocation: %10s\n",
+ HumanReadableNumBytes(maybe_live_out_allocation_bytes));
+ StrAppendFormat(&s, " preallocated temp allocation: %10s\n",
+ HumanReadableNumBytes(preallocated_temp_allocation_bytes));
if (preallocated_temp_fragmentation_bytes >= 0) {
const double percent = 100. * preallocated_temp_fragmentation_bytes /
preallocated_temp_allocation_bytes;
- Appendf(
+ StrAppendFormat(
&s, " preallocated temp fragmentation: %10s (%.2f%%)\n",
- HumanReadableNumBytes(preallocated_temp_fragmentation_bytes).c_str(),
- percent);
+ HumanReadableNumBytes(preallocated_temp_fragmentation_bytes), percent);
}
- Appendf(&s, " total allocation: %10s\n",
- HumanReadableNumBytes(total_allocation_bytes).c_str());
+ StrAppendFormat(&s, " total allocation: %10s\n",
+ HumanReadableNumBytes(total_allocation_bytes));
if (total_fragmentation_bytes >= 0) {
const double percent =
100. * total_fragmentation_bytes / total_allocation_bytes;
- Appendf(&s, " total fragmentation: %10s (%.2f%%)\n",
- HumanReadableNumBytes(total_fragmentation_bytes).c_str(), percent);
+ StrAppendFormat(&s, " total fragmentation: %10s (%.2f%%)\n",
+ HumanReadableNumBytes(total_fragmentation_bytes), percent);
}
return s;
}
string BufferAssignment::ToString() const {
string output;
- tensorflow::strings::StrAppend(&output, "BufferAssignment:\n");
+ absl::StrAppend(&output, "BufferAssignment:\n");
for (auto& allocation : allocations_) {
- tensorflow::strings::StrAppend(&output, allocation.ToString());
+ absl::StrAppend(&output, allocation.ToString());
}
return output;
}
@@ -1076,7 +1069,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
// since buffers for kCall, kWhile, and kConditional sub-computations are
// only live for the duration of their calling instructions.
VLOG(1) << "Running whole-module heap simulation";
- SequentialHloOrdering::HloModuleSequence module_sequence;
+ HloSchedule schedule(&assignment->module());
FlatSet<const LogicalBuffer*> all_buffers_to_assign;
for (const auto& pair : buffers_to_assign_sequentially) {
const HloComputation* computation = pair.first;
@@ -1084,7 +1077,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
const std::vector<const HloInstruction*>* instruction_sequence =
hlo_ordering.SequentialOrder(*computation);
CHECK(instruction_sequence != nullptr) << computation->name();
- module_sequence[computation] = *instruction_sequence;
+ schedule.set_sequence(computation, *instruction_sequence);
all_buffers_to_assign.insert(buffers_to_assign.begin(),
buffers_to_assign.end());
}
@@ -1100,9 +1093,9 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
options.buffers_to_assign = &buffer_value_set;
TF_ASSIGN_OR_RETURN(
const HeapSimulator::Result result,
- HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>(
- MakeUnique<LazyBestFitHeap>(alignment)),
- assignment->module(), module_sequence,
+ HeapSimulator::Run(absl::make_unique<DecreasingSizeRunsHeap>(
+ absl::make_unique<LazyBestFitHeap>(alignment)),
+ assignment->module(), schedule,
assignment->points_to_analysis(),
assignment->buffer_size_, options));
AssignBuffersFromHeapSimulator(result, assignment,
@@ -1130,11 +1123,12 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
options.buffers_to_assign = &buffer_value_set;
TF_ASSIGN_OR_RETURN(
const HeapSimulator::Result result,
- HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>(
- MakeUnique<LazyBestFitHeap>(alignment)),
- *computation, *instruction_sequence,
- assignment->points_to_analysis(),
- assignment->buffer_size_, options));
+ HeapSimulator::Run(
+ absl::make_unique<DecreasingSizeRunsHeap>(
+ absl::make_unique<LazyBestFitHeap>(alignment)),
+ *computation, HloInstructionSequence(*instruction_sequence),
+ assignment->points_to_analysis(), assignment->buffer_size_,
+ options));
AssignBuffersFromHeapSimulator(result, assignment,
single_colored_set.first);
}
@@ -1646,7 +1640,8 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
XLA_VLOG_LINES(3, liveness->ToString());
XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString());
- // Can't use MakeUnique because BufferAssignment constructor is private.
+ // Can't use absl::make_unique because BufferAssignment constructor is
+ // private.
std::unique_ptr<BufferAssignment> assignment(
new BufferAssignment(module, std::move(liveness), std::move(buffer_size),
std::move(color_alignment)));
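The comment above reflects a general C++ constraint, illustrated by this
hypothetical class (not from the codebase): absl::make_unique constructs the
object inside Abseil, where a private constructor is inaccessible, while a
member or friend can still reach it with a direct `new`.

    #include <memory>

    class Widget {
     public:
      static std::unique_ptr<Widget> Create() {
        // OK: Create() is a member, so the private constructor is reachable.
        return std::unique_ptr<Widget>(new Widget());
      }

     private:
      Widget() = default;
    };

    // absl::make_unique<Widget>() would not compile here: the constructor is
    // private to the helper inside absl that performs the construction.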
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h
index 94495290c1..24ba7c16f5 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.h
+++ b/tensorflow/compiler/xla/service/buffer_assignment.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
@@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
@@ -41,6 +41,17 @@ limitations under the License.
namespace xla {
+// Walk the call graph of the HLO module and place each computation into either
+// thread_local_computations or global_computations depending upon whether the
+// computation requires thread-local allocations or global allocations. The
+// elements in thread_local_computations and global_computations are in post
+// order (if computation A has an instruction which calls computation B, then A
+// will appear after B in the vector).
+Status GatherComputationsByAllocationType(
+ const HloModule* module,
+ std::vector<const HloComputation*>* thread_local_computations,
+ std::vector<const HloComputation*>* global_computations);
+
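+// A hedged illustration of the post-order guarantee documented above, with a
+// made-up call graph:
+//
+//   // Suppose entry calls body, and body calls combiner, and all three land
+//   // in global_computations. Callee-before-caller ordering means:
+//   //   global_computations = {combiner, body, entry}
+//   // so a pass iterating in order always visits a computation after every
+//   // computation it calls.
+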
// This class abstracts an allocation of contiguous memory which can hold the
// values described by LogicalBuffers. Each LogicalBuffer occupies a sub-range
// of the allocation, represented by a Slice. A single BufferAllocation may hold
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index eccb146a0d..795beb9ff5 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -21,8 +21,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
@@ -30,16 +30,18 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
@@ -79,15 +81,14 @@ const std::vector<const HloInstruction*> GetInstructions(HloInstruction* root) {
return main_list.GetInstructions();
}
-class BufferAssignmentTest : public HloTestBase {
+class BufferAssignmentTest : public HloVerifiedTestBase {
protected:
- BufferAssignmentTest() {}
~BufferAssignmentTest() override {}
std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
int64 alignment = 1) {
return BufferAssigner::Run(
- module, xla::MakeUnique<DependencyHloOrdering>(module),
+ module, absl::make_unique<DependencyHloOrdering>(module),
backend().compiler()->BufferSizeBytesFunction(),
[alignment](LogicalBuffer::Color) { return alignment; },
/*allow_input_output_aliasing=*/false,
@@ -98,7 +99,7 @@ class BufferAssignmentTest : public HloTestBase {
std::unique_ptr<BufferAssignment> RunBufferAssignmentNoBuffersForConstants(
HloModule* module, int64 alignment = 1) {
return BufferAssigner::Run(
- module, xla::MakeUnique<DependencyHloOrdering>(module),
+ module, absl::make_unique<DependencyHloOrdering>(module),
backend().compiler()->BufferSizeBytesFunction(),
[alignment](LogicalBuffer::Color) { return alignment; },
/*allow_input_output_aliasing=*/false,
@@ -109,7 +110,7 @@ class BufferAssignmentTest : public HloTestBase {
std::unique_ptr<BufferAssignment> RunColoredBufferAssignment(
HloModule* module, BufferLiveness::Colorer colorer, int64 alignment = 1) {
return BufferAssigner::Run(
- module, xla::MakeUnique<DependencyHloOrdering>(module),
+ module, absl::make_unique<DependencyHloOrdering>(module),
backend().compiler()->BufferSizeBytesFunction(),
[alignment](LogicalBuffer::Color) { return alignment; },
/*allow_input_output_aliasing=*/false,
@@ -119,15 +120,12 @@ class BufferAssignmentTest : public HloTestBase {
std::unique_ptr<BufferAssignment> RunBufferAssignmentWithInstructionSequence(
HloModule* module,
- tensorflow::gtl::ArraySlice<const HloInstruction*> instruction_sequence,
+ absl::Span<const HloInstruction* const> instruction_sequence,
int64 alignment = 1) {
- SequentialHloOrdering::HloModuleSequence module_sequence;
- module_sequence[module->entry_computation()] =
- std::vector<const HloInstruction*>(instruction_sequence.begin(),
- instruction_sequence.end());
+ HloSchedule schedule(module);
+ schedule.set_sequence(module->entry_computation(), instruction_sequence);
return BufferAssigner::Run(
- module,
- xla::MakeUnique<SequentialHloOrdering>(module, module_sequence),
+ module, absl::make_unique<SequentialHloOrdering>(schedule),
backend().compiler()->BufferSizeBytesFunction(),
[alignment](LogicalBuffer::Color) { return alignment; },
/*allow_input_output_aliasing=*/false,
@@ -147,6 +145,17 @@ class BufferAssignmentTest : public HloTestBase {
return builder.Build();
}
+ std::unique_ptr<HloComputation> BuildReduceComputation(const string& name) {
+ auto builder = HloComputation::Builder(name);
+ auto param =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
+ auto param2 =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "y"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, param2));
+ return builder.Build();
+ }
+
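+  // The helper above gives reduce a genuine two-parameter combiner (map's
+  // x+1 computation takes only one parameter). A sketch of how such a
+  // computation is typically wired into a reduce; operand names here are
+  // hypothetical:
+  //
+  //   auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
+  //       r0f32_, input_vector, zero_constant,
+  //       /*dimensions_to_reduce=*/{0}, reduce_computation));
+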
// Builds a simple compare-to-limit (x < 4) computation for a While.
//
// condition:
@@ -163,8 +172,8 @@ class BufferAssignmentTest : public HloTestBase {
HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
auto index = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(const4->shape(), param, 0));
- builder.AddInstruction(
- HloInstruction::CreateBinary(r0f32_, HloOpcode::kLt, index, const4));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, index, const4));
return builder.Build();
}
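For the kLt change above: XLA comparisons produce a PRED-typed result, so the
instruction's shape must be ShapeUtil::MakeShape(PRED, {}) rather than r0f32_.
The old f32 shape presumably went unnoticed because HloTestBase did not verify
modules; this commit switches the fixture to HloVerifiedTestBase, which does.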
@@ -311,12 +320,12 @@ TEST_F(BufferAssignmentTest, ScalarConstant) {
module->AddEntryComputation(builder.Build());
{
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
EXPECT_TRUE(buffers->HasTopLevelAllocation(const0));
}
{
- auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get());
+ auto buffers = RunBufferAssignmentNoBuffersForConstants(module);
EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
}
}
@@ -335,13 +344,13 @@ TEST_F(BufferAssignmentTest, BufferForConst) {
module->AddEntryComputation(builder.Build());
{
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
EXPECT_TRUE(buffers->HasTopLevelAllocation(const0));
EXPECT_TRUE(buffers->HasTopLevelAllocation(const1));
GetAssignedOutputAllocation(*buffers, add);
}
{
- auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get());
+ auto buffers = RunBufferAssignmentNoBuffersForConstants(module);
EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
EXPECT_FALSE(buffers->HasTopLevelAllocation(const1));
GetAssignedOutputAllocation(*buffers, add);
@@ -363,7 +372,7 @@ TEST_F(BufferAssignmentTest, HasAllocationAt) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
// Make sure that HasAllocationAt() agrees with what HasTopLevelAllocation()
// reports for the instruction directly.
EXPECT_EQ(buffers->HasTopLevelAllocation(tuple),
@@ -386,7 +395,7 @@ TEST_F(BufferAssignmentTest, BufferForOutputConst) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
// The copy node now has an output buffer.
GetAssignedOutputAllocation(*buffers, copy);
}
@@ -400,12 +409,14 @@ TEST_F(BufferAssignmentTest, Basic) {
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
+ auto broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
- f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
+ f32vec100_, HloOpcode::kMultiply, broadcast, param0));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -413,7 +424,7 @@ TEST_F(BufferAssignmentTest, Basic) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
// Distinct input buffers were assigned for parameters.
BufferAllocation paramscalar_buffer =
@@ -447,12 +458,14 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) {
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
+ auto broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
- f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
+ f32vec100_, HloOpcode::kMultiply, broadcast, param0));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -472,7 +485,7 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) {
return Status::OK();
};
- auto buffers = RunColoredBufferAssignment(module.get(), colorer);
+ auto buffers = RunColoredBufferAssignment(module, colorer);
// Distinct input buffers were assigned for parameters.
BufferAllocation paramscalar_buffer =
@@ -506,12 +519,14 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) {
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
+ auto broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
- f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
+ f32vec100_, HloOpcode::kMultiply, broadcast, param0));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -539,7 +554,7 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) {
return Status::OK();
};
- auto buffers = RunColoredBufferAssignment(module.get(), colorer);
+ auto buffers = RunColoredBufferAssignment(module, colorer);
// Distinct input buffers were assigned for parameters.
BufferAllocation paramscalar_buffer =
@@ -576,12 +591,14 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
+ auto broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
- f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
+ f32vec100_, HloOpcode::kMultiply, broadcast, param0));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
auto sub = builder.AddInstruction(
@@ -589,7 +606,7 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
// Input buffers were assigned for parameters.
BufferAllocation paramscalar_buffer =
@@ -640,7 +657,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) {
EXPECT_EQ(3, level1.size()) << "Invalid nested add+1 size";
// Assigns buffers and fetches sizes.
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
int64 size0 = ValidateBuffers(level0, *buffers);
int64 size1 = ValidateBuffers(level1, *buffers);
@@ -675,10 +692,10 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
// output. (Reuse is not safe in the general case: reduce reshapes its input,
// and an out-of-order reduction could overwrite an element before it is read.)
//
- // param0[100] --- (exp1) --- (exp2) --- (reduce x+1) --- (exp3)
+ // param0[100] --- (exp1) --- (exp2) --- (reduce x+y) --- (exp3)
auto module = CreateNewModule();
auto reduce_computation =
- module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1"));
+ module->AddEmbeddedComputation(BuildReduceComputation("f32+f32"));
auto builder = HloComputation::Builder(TestName());
auto param0 = builder.AddInstruction(
@@ -699,7 +716,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
module->AddEntryComputation(builder.Build());
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
const std::vector<const HloInstruction*> instrs = GetInstructions(exp3);
ValidateBuffers(instrs, *buffers);
@@ -755,7 +772,7 @@ TEST_F(BufferAssignmentTest, ExampleWhile) {
EXPECT_EQ(8, levelb.size()) << "Invalid nested body size";
// Assigns buffers and fetches sizes.
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
int64 size0 = ValidateBuffers(level0, *buffers);
int64 sizec = ValidateBuffers(levelc, *buffers);
int64 sizeb = ValidateBuffers(levelb, *buffers);
@@ -820,7 +837,7 @@ TEST_F(BufferAssignmentTest, ExampleConditional) {
EXPECT_EQ(2, true_instrs.size());
EXPECT_EQ(2, false_instrs.size());
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
ValidateBuffers(conditional_instrs, *buffers);
ValidateBuffers(true_instrs, *buffers);
ValidateBuffers(false_instrs, *buffers);
@@ -858,7 +875,7 @@ TEST_F(BufferAssignmentTest, UnaryOpReuseChain) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// tanh and exp2 can reuse exp1's buffer
EXPECT_TRUE(assignment->HasTopLevelAllocation(exp1));
@@ -887,7 +904,7 @@ TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// negate and broadcast should share a buffer.
EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
@@ -920,7 +937,7 @@ TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// The instructions should not share buffers.
EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
@@ -957,7 +974,7 @@ TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// The instructions should not share buffers.
EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
@@ -992,7 +1009,7 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// The broadcast output buffer cannot be shared.
EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
@@ -1024,7 +1041,7 @@ TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// negate and broadcast should share a buffer.
EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
@@ -1062,7 +1079,7 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// The broadcast output buffer cannot be shared.
EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
@@ -1106,7 +1123,7 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) {
HloInstruction::CreateMap(vec_shape, {call}, map_computation));
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// Allocations for the map computation should be thread-local and not
// live-out.
@@ -1155,7 +1172,7 @@ TEST_F(BufferAssignmentTest, TupleParameterAsOutput) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// There should be four allocations: one for the vector of pointers, and one
// for each tuple element.
@@ -1191,7 +1208,7 @@ TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// Only some of the elements of the input param are liveout.
EXPECT_FALSE(
@@ -1228,13 +1245,14 @@ TEST_F(BufferAssignmentTest, TupleConstantAsOutput) {
// Test that a tuple constant which is forwarded to the computation output
// is properly handled.
auto builder = HloComputation::Builder(TestName());
+ Literal elements[] = {LiteralUtil::CreateR0<int64>(0),
+ LiteralUtil::CreateR0<int64>(1)};
builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
- LiteralUtil::CreateR0<int64>(1).get()})));
+ LiteralUtil::MakeTuple({&elements[0], &elements[1]})));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
EXPECT_EQ(3, assignment->Allocations().size());
}
@@ -1248,7 +1266,7 @@ TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) {
/*operands=*/{}, /*custom_call_target=*/"foo_function"));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
EXPECT_EQ(3, assignment->Allocations().size());
EXPECT_TRUE(
@@ -1279,7 +1297,7 @@ TEST_F(BufferAssignmentTest, TupleCallAsOutput) {
HloInstruction::CreateCall(tuple_shape, {param}, sub_computation));
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
EXPECT_EQ(2, assignment->Allocations().size());
// Buffers for call are colocated with the sub-computation.
@@ -1341,7 +1359,7 @@ TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) {
module->AddEntryComputation(std::move(a_computation));
module->AddEmbeddedComputation(std::move(b_computation));
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// Buffers for call are colocated with the sub-computations.
EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{}),
@@ -1377,7 +1395,7 @@ TEST_F(BufferAssignmentTest, BitcastAsOutput) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// Bitcast should get the same allocation as the param.
EXPECT_EQ(1, assignment->Allocations().size());
@@ -1404,7 +1422,7 @@ TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// Select shallow copies one of its operands so it defines its own top-level
// buffer and receives its own allocation.
@@ -1442,7 +1460,7 @@ TEST_F(BufferAssignmentTest, TupleBufferNotReused) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(module);
// There should be no buffer reuse. The copy should not reuse the tuple
// buffer.
@@ -1471,17 +1489,20 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot_ab = builder.AddInstruction(
- HloInstruction::CreateDot(shape_2x4, param_a, param_b, dot_dnums));
- auto dot_bc = builder.AddInstruction(
- HloInstruction::CreateDot(shape_3x4, param_b, param_c, dot_dnums));
+ PrecisionConfig precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfig::DEFAULT);
+ auto dot_ab = builder.AddInstruction(HloInstruction::CreateDot(
+ shape_2x4, param_a, param_b, dot_dnums, precision_config));
+ auto dot_bc = builder.AddInstruction(HloInstruction::CreateDot(
+ shape_3x4, param_b, param_c, dot_dnums, precision_config));
builder.AddInstruction(
- HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 1));
+ HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0));
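+  // Two details in this hunk: concatenating a [2,4] with a [3,4] into
+  // shape_5x4 must join along dimension 0 (2 + 3 = 5; dimension 1 would
+  // require equal row counts), and Resize(2, PrecisionConfig::DEFAULT) gives
+  // the dot one operand_precision entry per operand, matching the new
+  // precision_config parameter on CreateDot.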
// Run buffer assignment with alignment=1.
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto assignment = RunBufferAssignment(module.get(), /*alignment=*/1);
+ auto assignment = RunBufferAssignment(module, /*alignment=*/1);
// There are 5 allocations: 3 parameters, 1 output, and 1 temp.
EXPECT_EQ(5, assignment->Allocations().size());
@@ -1500,7 +1521,7 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) {
EXPECT_EQ(80, slice_bc.allocation()->size());
// Re-run buffer assignment with alignment=64.
- assignment = RunBufferAssignment(module.get(), /*alignment=*/64);
+ assignment = RunBufferAssignment(module, /*alignment=*/64);
EXPECT_EQ(5, assignment->Allocations().size());
slice_ab = assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie();
slice_bc = assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie();
@@ -1531,12 +1552,14 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) {
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
+ auto broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
- f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
+ f32vec100_, HloOpcode::kMultiply, broadcast, param0));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
builder.AddInstruction(HloInstruction::CreateBinary(
@@ -1544,16 +1567,13 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
- // Trivially, the set of peak memory logical buffer(s) of an allocation with a
- // single logical buffer should be exactly the logical buffer in that
- // allocation.
const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
const std::vector<const LogicalBuffer*>& peak_buffers =
mul_buffer.PeakMemoryLogicalBuffers();
ASSERT_EQ(peak_buffers.size(), 1);
- EXPECT_EQ(peak_buffers[0]->instruction(), mul);
+ EXPECT_EQ(peak_buffers[0]->instruction(), broadcast);
}
TEST_F(BufferAssignmentTest, PeakBuffers) {
@@ -1589,7 +1609,7 @@ TEST_F(BufferAssignmentTest, PeakBuffers) {
module->AddEntryComputation(builder.Build());
auto buffers = RunBufferAssignmentWithInstructionSequence(
- module.get(), {param, log, rev, neg, concat, root});
+ module, {param, log, rev, neg, concat, root});
// The temporary buffer should hold the 4 interior instructions.
const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, concat);
@@ -1645,7 +1665,7 @@ TEST_F(BufferAssignmentTest, PeakBuffersWhile) {
ShapeUtil::MakeShape(F32, {123, 123, 123}), bcast, {0}));
module->AddEntryComputation(builder.Build());
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(module);
const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, bcast);
const std::vector<const LogicalBuffer*>& peak_buffers =
buffer.PeakMemoryLogicalBuffers();
@@ -1695,15 +1715,13 @@ ENTRY main {
}
)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(hlo_text));
-
+ ParseAndVerifyModule(hlo_text);
HloInstruction* constant_1 =
- module->entry_computation()->GetInstructionWithName("constant.1.1");
+ module().entry_computation()->GetInstructionWithName("constant.1.1");
HloInstruction* constant_2 =
- module->entry_computation()->GetInstructionWithName("constant.1.2");
+ module().entry_computation()->GetInstructionWithName("constant.1.2");
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(&module());
{
const BufferAllocation& allocation_for_const_1 =
@@ -1732,7 +1750,7 @@ ENTRY main {
}
}
-class WhileBufferAssignmentTest : public HloTestBase {
+class WhileBufferAssignmentTest : public HloVerifiedTestBase {
protected:
std::unique_ptr<HloComputation> BuildWhileConditionComputation(
const string& name) {
@@ -1766,10 +1784,10 @@ class WhileBufferAssignmentTest : public HloTestBase {
std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
int64 alignment = 1) {
- auto sequence =
- ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie();
+ HloSchedule schedule =
+ ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie();
return BufferAssigner::Run(
- module, xla::MakeUnique<SequentialHloOrdering>(module, sequence),
+ module, absl::make_unique<SequentialHloOrdering>(schedule),
ByteSizeOf,
[alignment](LogicalBuffer::Color) { return alignment; },
/*allow_input_output_aliasing=*/false,
@@ -1805,9 +1823,9 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
auto zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
auto output0 = builder.AddInstruction(
- HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+ HloInstruction::CreateBroadcast(data_shape_, zero, {}));
auto output1 = builder.AddInstruction(
- HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+ HloInstruction::CreateBroadcast(data_shape_, zero, {}));
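+  // On the {1} -> {} changes here and in the sibling tests: the broadcast
+  // dimensions list maps each operand dimension to an output dimension, and a
+  // scalar operand has rank 0, so the list must be empty. The old {1}
+  // presumably survived only because the unverified test base never ran the
+  // HLO verifier over these modules.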
auto cond0 =
module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
@@ -1831,8 +1849,8 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
module->AddEntryComputation(builder.Build());
- RunCopyInsertion(module.get());
- auto assignment = RunBufferAssignment(module.get());
+ RunCopyInsertion(module);
+ auto assignment = RunBufferAssignment(module);
// Verify 'input0' and read-only use while0{0} alias.
EXPECT_EQ(assignment->GetUniqueSlice(input0, {}).ConsumeValueOrDie(),
@@ -1888,20 +1906,20 @@ ENTRY %test_module {
ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={}
})";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(module_str));
+ ParseAndVerifyModule(module_str);
// Run CopyInsertion and check that the graph constructed above needs no
// copies inserted for BufferAssignment to run.
- int64 instruction_count = module->instruction_count();
+ int64 instruction_count = module().instruction_count();
CopyInsertion copy_insertion;
- ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
- ASSERT_EQ(instruction_count, module->instruction_count());
+ ASSERT_IS_OK(copy_insertion.Run(&module()).status());
+ ASSERT_EQ(instruction_count, module().instruction_count());
// Get the instructions in the module.
- const HloInstruction* bcast = module->entry_computation()->root_instruction();
+ const HloInstruction* bcast =
+ module().entry_computation()->root_instruction();
const HloInstruction* param =
- module->entry_computation()->parameter_instruction(0);
+ module().entry_computation()->parameter_instruction(0);
ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
const HloInstruction* while1 = bcast->operand(0);
ASSERT_EQ(while1->opcode(), HloOpcode::kWhile);
@@ -1909,7 +1927,7 @@ ENTRY %test_module {
ASSERT_EQ(while0->opcode(), HloOpcode::kWhile);
// Run buffer assignment.
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(&module());
TF_ASSERT_OK_AND_ASSIGN(auto slice_param,
assignment->GetUniqueSlice(param, {}));
TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
@@ -1956,20 +1974,20 @@ ENTRY %test_module {
ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={}
})";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(module_str));
+ ParseAndVerifyModule(module_str);
// Run CopyInsertion and check that the graph constructed above needs no
// copies inserted for BufferAssignment to run.
- int64 instruction_count = module->instruction_count();
+ int64 instruction_count = module().instruction_count();
CopyInsertion copy_insertion;
- ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
- ASSERT_EQ(instruction_count, module->instruction_count());
+ ASSERT_IS_OK(copy_insertion.Run(&module()).status());
+ ASSERT_EQ(instruction_count, module().instruction_count());
// Get the instructions in the module.
- const HloInstruction* bcast = module->entry_computation()->root_instruction();
+ const HloInstruction* bcast =
+ module().entry_computation()->root_instruction();
const HloInstruction* constant =
- module->entry_computation()->GetInstructionWithName("constant.42");
+ module().entry_computation()->GetInstructionWithName("constant.42");
ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
const HloInstruction* while1 = bcast->operand(0);
ASSERT_EQ(while1->opcode(), HloOpcode::kWhile);
@@ -1977,7 +1995,7 @@ ENTRY %test_module {
ASSERT_EQ(while0->opcode(), HloOpcode::kWhile);
// Run buffer assignment.
- auto assignment = RunBufferAssignment(module.get());
+ auto assignment = RunBufferAssignment(&module());
TF_ASSERT_OK_AND_ASSIGN(auto slice_constant,
assignment->GetUniqueSlice(constant, {}));
TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
@@ -2070,24 +2088,31 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
// any copies inserted for BufferAssignment to run.
int64 instruction_count = module->instruction_count();
CopyInsertion copy_insertion;
- ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
+ ASSERT_IS_OK(copy_insertion.Run(module).status());
ASSERT_EQ(instruction_count, module->instruction_count());
// Create a sequential order among all the instructions in the entry
// computation, since the issue this test stresses depends on the order the
// nodes are traversed during BufferAssignment.
- SequentialHloOrdering::HloModuleSequence sequence;
- sequence[module->entry_computation()] = {
- token, infeed, infeed_data, while0, while1, zero, add, while2, tuple};
+ TF_ASSERT_OK_AND_ASSIGN(
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape(),
+ /*pointer_size=*/sizeof(void*));
+ }));
+ schedule.set_sequence(
+ module->entry_computation(),
+ {token, infeed, infeed_data, while0, while1, zero, add, while2, tuple});
+ TF_ASSERT_OK(schedule.Verify());
+
TF_ASSERT_OK_AND_ASSIGN(
auto assignment,
- BufferAssigner::Run(
- module.get(),
- xla::MakeUnique<SequentialHloOrdering>(module.get(), sequence),
- backend().compiler()->BufferSizeBytesFunction(),
- [](LogicalBuffer::Color) { return 1; },
- /*allow_input_output_aliasing=*/false,
- /*allocate_buffers_for_constants=*/true));
+ BufferAssigner::Run(module,
+ absl::make_unique<SequentialHloOrdering>(schedule),
+ backend().compiler()->BufferSizeBytesFunction(),
+ [](LogicalBuffer::Color) { return 1; },
+ /*allow_input_output_aliasing=*/false,
+ /*allocate_buffers_for_constants=*/true));
// The result tuple elements must be assigned with different buffers.
TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0}));
@@ -2120,7 +2145,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
auto zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
auto output0 = builder.AddInstruction(
- HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+ HloInstruction::CreateBroadcast(data_shape_, zero, {}));
auto cond0 =
module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
@@ -2141,8 +2166,8 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0));
module->AddEntryComputation(builder.Build());
- RunCopyInsertion(module.get());
- auto assignment = RunBufferAssignment(module.get());
+ RunCopyInsertion(module);
+ auto assignment = RunBufferAssignment(module);
// while0 and while1 buffers should be completely aligned.
EXPECT_EQ(assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie(),
@@ -2184,13 +2209,13 @@ TEST_F(BufferAssignmentTest, TwoCalls) {
{
FlattenCallGraph flatten;
- TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
}
- RunCopyInsertion(module.get());
- auto assignment = RunBufferAssignment(module.get());
+ RunCopyInsertion(module);
+ auto assignment = RunBufferAssignment(module);
EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment));
}
@@ -2214,15 +2239,14 @@ ENTRY Main {
}
)";
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<HloModule> module,
- HloRunner::CreateModuleFromString(
- hlo_text, legacy_flags::GetDebugOptionsFromFlags()));
+ HloModuleConfig config;
+ config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
+ ParseAndVerifyModule(hlo_text, config);
- auto buffers = RunBufferAssignment(module.get());
+ auto buffers = RunBufferAssignment(&module());
- HloComputation* main = module->entry_computation();
- HloComputation* callee = module->GetComputationWithName("Callee");
+ HloComputation* main = module().entry_computation();
+ HloComputation* callee = module().GetComputationWithName("Callee");
EXPECT_NE(callee, nullptr);
HloInstruction* param0 = callee->parameter_instruction(0);
@@ -2245,29 +2269,6 @@ ENTRY Main {
GetAllocation(*buffers, param0, {1, 1}));
}
-static bool IsPostOrderTraversal(
- const std::vector<const HloInstruction*>& sequence) {
- tensorflow::gtl::FlatSet<const HloInstruction*> seen_so_far;
- auto has_not_been_seen_yet = [&](const HloInstruction* instruction) {
- return seen_so_far.count(instruction) == 0;
- };
-
- for (auto instruction : sequence) {
- if (std::any_of(instruction->operands().begin(),
- instruction->operands().end(), has_not_been_seen_yet) ||
- std::any_of(instruction->control_predecessors().begin(),
- instruction->control_predecessors().end(),
- has_not_been_seen_yet)) {
- return false; // Not a post order.
- }
- if (!seen_so_far.insert(instruction).second) {
- return false; // Not a "traversal".
- }
- }
-
- return true;
-}
-
TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
@@ -2282,14 +2283,14 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
auto weights0 = builder.AddInstruction(
HloInstruction::CreateParameter(1, data_shape_, "weights0"));
auto output0 = builder.AddInstruction(
- HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+ HloInstruction::CreateBroadcast(data_shape_, zero, {}));
auto input1 = builder.AddInstruction(
HloInstruction::CreateParameter(2, data_shape_, "input1"));
auto weights1 = builder.AddInstruction(
HloInstruction::CreateParameter(3, data_shape_, "weights1"));
auto output1 = builder.AddInstruction(
- HloInstruction::CreateBroadcast(data_shape_, one, {1}));
+ HloInstruction::CreateBroadcast(data_shape_, one, {}));
auto cond =
module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
@@ -2309,41 +2310,40 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
HloInstruction::CreateGetTupleElement(data_shape_, while0, 0));
auto gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, while1, 1));
- auto root_add = builder.AddInstruction(HloInstruction::CreateBinary(
- while0->shape(), HloOpcode::kAdd, gte0, gte1));
+ auto root_add = builder.AddInstruction(
+ HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, gte0, gte1));
module->AddEntryComputation(builder.Build());
{
FlattenCallGraph flatten;
- TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module));
EXPECT_TRUE(result);
}
- RunCopyInsertion(module.get());
+ RunCopyInsertion(module);
- auto sequence =
- ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie();
+ HloSchedule schedule =
+ ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie();
- // To trigger b/38494731, we want a specific Hlo sequence for the
+ // To trigger b/38494731, we want a specific Hlo schedule for the
// root computation, so we overwrite that entry with a manually
// crafted sequence.
- sequence[module->entry_computation()] = {
- input1, weights1, one, output1, while1->operand(0), while1,
- input0, weights0, zero, output0, while0->operand(0), while0,
- gte0, gte1, root_add};
+ schedule.set_sequence(module->entry_computation(),
+ {input1, weights1, one, output1, while1->operand(0),
+ while1, input0, weights0, zero, output0,
+ while0->operand(0), while0, gte0, gte1, root_add});
- // If this ASSERT_TRUE fails, we constructed a bogus sequence above
- // and this test itself is buggy.
- ASSERT_TRUE(IsPostOrderTraversal(sequence[module->entry_computation()]));
+ // If this ASSERT fails, we constructed a bogus sequence above and this test
+ // itself is buggy.
+ TF_ASSERT_OK(schedule.Verify());
auto assignment =
- BufferAssigner::Run(
- module.get(),
- xla::MakeUnique<SequentialHloOrdering>(module.get(), sequence),
- ByteSizeOf, [](LogicalBuffer::Color) { return 1; },
- /*allow_input_output_aliasing=*/false,
- /*allocate_buffers_for_constants=*/true)
+ BufferAssigner::Run(module,
+ absl::make_unique<SequentialHloOrdering>(schedule),
+ ByteSizeOf, [](LogicalBuffer::Color) { return 1; },
+ /*allow_input_output_aliasing=*/false,
+ /*allocate_buffers_for_constants=*/true)
.ConsumeValueOrDie();
EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
@@ -2361,9 +2361,9 @@ TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
auto zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
auto output0 = builder.AddInstruction(
- HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+ HloInstruction::CreateBroadcast(data_shape_, zero, {}));
auto output1 = builder.AddInstruction(
- HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+ HloInstruction::CreateBroadcast(data_shape_, zero, {}));
auto cond0 =
module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
@@ -2394,8 +2394,8 @@ TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
HloInstruction::CreateGetTupleElement(data_shape_, while1, 2));
module->AddEntryComputation(builder.Build());
- RunCopyInsertion(module.get());
- auto assignment = RunBufferAssignment(module.get());
+ RunCopyInsertion(module);
+ auto assignment = RunBufferAssignment(module);
// Get BufferAllocation for root instruction.
auto* root_alloc = assignment->GetUniqueTopLevelSlice(while1_out)
.ConsumeValueOrDie()
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc
index 810d597e73..9b2783a214 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -28,8 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -75,27 +75,25 @@ Status BufferLiveness::Analyze() {
string BufferLiveness::ToString() const {
std::vector<string> pieces;
- pieces.push_back(tensorflow::strings::Printf("BufferLiveness(module=%s):",
- module_->name().c_str()));
+ pieces.push_back(
+ absl::StrFormat("BufferLiveness(module=%s):", module_->name()));
pieces.push_back("HloOrdering:");
pieces.push_back(hlo_ordering_->ToString());
- pieces.push_back(tensorflow::strings::Printf("Aliased buffers:"));
+ pieces.push_back("Aliased buffers:");
for (const LogicalBuffer* buffer : aliased_buffers_) {
- pieces.push_back(
- tensorflow::strings::Printf(" %s", buffer->ToString().c_str()));
+ pieces.push_back(absl::StrFormat(" %s", buffer->ToString()));
}
- pieces.push_back(tensorflow::strings::Printf("Live out buffers:"));
+ pieces.push_back("Live out buffers:");
for (const LogicalBuffer* buffer : maybe_live_out_buffers_) {
- pieces.push_back(
- tensorflow::strings::Printf(" %s", buffer->ToString().c_str()));
+ pieces.push_back(absl::StrFormat(" %s", buffer->ToString()));
}
- return tensorflow::str_util::Join(pieces, "\n");
+ return absl::StrJoin(pieces, "\n");
}
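The Join -> StrJoin swap above is mechanical; for reference, a minimal sketch
of the Abseil call (function and values are hypothetical):

    #include <string>
    #include <vector>
    #include "absl/strings/str_join.h"

    std::string JoinPieces() {
      std::vector<std::string> pieces = {"header:", "  line one", "  line two"};
      // Produces "header:\n  line one\n  line two".
      return absl::StrJoin(pieces, "\n");
    }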
bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
const LogicalBuffer& b) const {
- TF_CHECK_OK(points_to_analysis_->VerifyBuffer(a));
- TF_CHECK_OK(points_to_analysis_->VerifyBuffer(b));
+ TF_DCHECK_OK(points_to_analysis_->VerifyBuffer(a));
+ TF_DCHECK_OK(points_to_analysis_->VerifyBuffer(b));
if (!hlo_ordering_->ExecutesBefore(a.instruction(), b.instruction())) {
return false;
diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
index 4a927b5767..17e5090505 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -18,14 +18,16 @@ limitations under the License.
#include <memory>
#include <string>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
@@ -119,8 +121,8 @@ TEST_F(BufferLivenessTest, ElementwiseChain) {
module->AddEntryComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate));
@@ -165,11 +167,11 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) {
auto module = CreateNewModule();
HloComputation* entry = module->AddEntryComputation(builder.Build());
- SequentialHloOrdering::HloModuleSequence sequence;
- sequence.insert({entry, {param0, negate, param1, exp, add}});
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(entry, {param0, negate, param1, exp, add});
auto liveness =
- BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>(
- module.get(), sequence))
+ BufferLiveness::Run(module.get(),
+ absl::make_unique<SequentialHloOrdering>(schedule))
.ConsumeValueOrDie();
// Entry parameters interfere as if they are defined simultaneously at
@@ -215,8 +217,8 @@ TEST_F(BufferLivenessTest, NonElementwiseOperand) {
module->AddEntryComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp));
@@ -249,8 +251,8 @@ TEST_F(BufferLivenessTest, OverlappedBuffers) {
module->AddEntryComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate));
@@ -290,12 +292,11 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- SequentialHloOrdering::HloModuleSequence module_sequence;
- std::vector<const HloInstruction*> order = {param, negate, exp, add};
- module_sequence.emplace(computation, order);
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(computation, {param, negate, exp, add});
auto liveness =
- BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>(
- module.get(), module_sequence))
+ BufferLiveness::Run(module.get(),
+ absl::make_unique<SequentialHloOrdering>(schedule))
.ConsumeValueOrDie();
EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate));
@@ -338,13 +339,13 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build(add));
- SequentialHloOrdering::HloModuleSequence module_sequence;
- std::vector<const HloInstruction*> order = {param, add, recv,
- recv_done, send, send_done};
- module_sequence.emplace(computation, order);
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(computation,
+ {param, add, token, recv, recv_done, send, send_done});
+ TF_ASSERT_OK(schedule.Verify());
auto liveness =
- BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>(
- module.get(), module_sequence))
+ BufferLiveness::Run(module.get(),
+ absl::make_unique<SequentialHloOrdering>(schedule))
.ConsumeValueOrDie();
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add));
@@ -376,8 +377,8 @@ TEST_F(BufferLivenessTest, TupleLiveOut) {
module->AddEntryComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// All buffers should be live out except the param
@@ -412,8 +413,8 @@ TEST_F(BufferLivenessTest, EmbeddedComputation) {
module->AddEntryComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// Buffers in different computations should always interfere.
@@ -439,22 +440,22 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
// computation. The buffer containing {0, 1} is copied by GetTupleElement, and
// the buffers containing {3} and 3 are dead.
auto builder = HloComputation::Builder(TestName());
- auto inner_tuple0 =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
- LiteralUtil::CreateR0<int64>(1).get()});
- auto inner_tuple1 =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(3).get()});
+ Literal elements0[] = {LiteralUtil::CreateR0<int64>(0),
+ LiteralUtil::CreateR0<int64>(1)};
+ auto inner_tuple0 = LiteralUtil::MakeTuple({&elements0[0], &elements0[1]});
+ Literal element1 = LiteralUtil::CreateR0<int64>(3);
+ auto inner_tuple1 = LiteralUtil::MakeTuple({&element1});
auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()})));
+ LiteralUtil::MakeTuple({&inner_tuple0, &inner_tuple1})));
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
- inner_tuple0->shape(), tuple_constant, 0));
+ inner_tuple0.shape(), tuple_constant, 0));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// Only the element buffers of the tuple constant which are pointed to by
@@ -518,8 +519,8 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) {
module->AddEmbeddedComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// We compare tuple element pairs that are input/output to the computation:
@@ -580,8 +581,8 @@ TEST_F(BufferLivenessTest, DependentTupleElements) {
module->AddEmbeddedComputation(builder.Build());
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()))
+ BufferLiveness::Run(
+ module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// We compare tuple element pairs that are input/output to the computation:
@@ -610,11 +611,8 @@ TEST_F(BufferLivenessTest, DependentTupleElements) {
class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
protected:
// Builds and runs a computation (see test case computation graphs below).
- // Runs BufferLiveness on this computation.
- // Returns whether buffer interference is detected between tuple-shaped
- // parameter and root instructions at tuple element 1.
- bool Run(const bool update_uses_tuple_element1,
- const bool fuse_gte0 = false) {
+ std::unique_ptr<HloModule> BuildModule(const bool update_uses_tuple_element1,
+ const bool fuse_gte0) {
auto builder = HloComputation::Builder(TestName());
// Create param0 Tuple.
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
@@ -645,12 +643,12 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
// Create output tuple.
- auto tuple_root = builder.AddInstruction(
+ builder.AddInstruction(
HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
// Build module and get reference to entry computation.
auto module = CreateNewModule();
- module->AddEntryComputation(BuildDummyComputation());
- auto* computation = module->AddEmbeddedComputation(builder.Build());
+ module->AddEntryComputation(builder.Build());
+ auto* computation = module->entry_computation();
// Create fusion instruction based on number of tuple element 1 users.
if (update_uses_tuple_element1) {
computation->CreateFusionInstruction(
@@ -666,16 +664,39 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
computation->CreateFusionInstruction({gte0},
HloInstruction::FusionKind::kLoop);
}
+ return module;
+ }
+ // Returns whether buffer interference is detected between tuple-shaped
+ // parameter and root instructions at tuple element 1.
+ bool Run(const bool update_uses_tuple_element1,
+ const bool fuse_gte0 = false) {
+ auto module = BuildModule(update_uses_tuple_element1, fuse_gte0);
// Run BufferLiveness on 'module'.
- auto liveness =
- BufferLiveness::Run(
- module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get()))
- .ConsumeValueOrDie();
+ auto liveness = BufferLiveness::Run(
+ module.get(),
+ absl::make_unique<DependencyHloOrdering>(module.get()))
+ .ConsumeValueOrDie();
+    // Return whether or not buffer interference is detected between
// 'tuple_param0' and 'tuple_root' at shape index '{1}'.
+ auto tuple_param0 = FindInstruction(module.get(), "param0");
+ auto tuple_root = module->entry_computation()->root_instruction();
return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
}
+ bool RunWithHloDataflowAnalysis(const bool update_uses_tuple_element1,
+ const bool fuse_gte0 = false) {
+ auto module = BuildModule(update_uses_tuple_element1, fuse_gte0);
+    // Run HloDataflowAnalysis on 'module'.
+ auto dataflow = HloDataflowAnalysis::Run(*module).ConsumeValueOrDie();
+ auto hlo_ordering = absl::make_unique<DependencyHloOrdering>(module.get());
+    // Return whether or not buffer interference is detected between
+ // 'tuple_param0' and 'tuple_root' at shape index '{1}'.
+ auto tuple_param0 = FindInstruction(module.get(), "param0");
+ auto tuple_root = module->entry_computation()->root_instruction();
+ return hlo_ordering->MayInterfere(
+ dataflow->GetUniqueValueAt(tuple_param0, {1}),
+ dataflow->GetUniqueValueAt(tuple_root, {1}), *dataflow);
+ }
};
// Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion)
@@ -693,6 +714,8 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
//
TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) {
EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false));
+ EXPECT_FALSE(
+ RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false));
}
// Tests that live ranges of buffers Param0[1] and Tuple[1] (which aliases
@@ -712,6 +735,8 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) {
//
TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) {
EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false, /*fuse_gte0=*/true));
+ EXPECT_FALSE(RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false,
+ /*fuse_gte0=*/true));
}
// Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion)
@@ -736,6 +761,7 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) {
//
TEST_F(FusedDynamicUpdateSliceLivenessTest, WithInterference) {
EXPECT_TRUE(Run(/*update_uses_tuple_element1=*/true));
+ EXPECT_TRUE(RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/true));
}
class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
@@ -780,10 +806,10 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
module->AddEntryComputation(BuildDummyComputation());
module->AddEmbeddedComputation(builder.Build());
// Run BufferLiveness on 'module'.
- auto liveness =
- BufferLiveness::Run(
- module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get()))
- .ConsumeValueOrDie();
+ auto liveness = BufferLiveness::Run(
+ module.get(),
+ absl::make_unique<DependencyHloOrdering>(module.get()))
+ .ConsumeValueOrDie();
// Return whether or not buffer interference is detected between
// 'tuple_param0' and 'tuple_root' at shape index '{1}'.
return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
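Several hunks in this test file repeat one refactoring: the removed SequentialHloOrdering::HloModuleSequence map is replaced by an HloSchedule bound to the module, populated per computation with set_sequence and optionally checked with Verify(). A condensed sketch of that pattern, using only calls that appear in the diff itself ('module', 'computation', and the instruction pointers come from the surrounding test fixture):

    // Attach a per-computation instruction order to the module, check it,
    // then hand it to SequentialHloOrdering.
    HloSchedule schedule(module.get());
    schedule.set_sequence(computation, {param, negate, exp, add});
    TF_ASSERT_OK(schedule.Verify());
    auto liveness =
        BufferLiveness::Run(module.get(),
                            absl::make_unique<SequentialHloOrdering>(schedule))
            .ConsumeValueOrDie();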
diff --git a/tensorflow/compiler/xla/service/buffer_value.cc b/tensorflow/compiler/xla/service/buffer_value.cc
index 2bc556a9e2..fdf822c666 100644
--- a/tensorflow/compiler/xla/service/buffer_value.cc
+++ b/tensorflow/compiler/xla/service/buffer_value.cc
@@ -17,11 +17,10 @@ limitations under the License.
#include <iosfwd>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/buffer_value.h b/tensorflow/compiler/xla/service/buffer_value.h
index f4be16e084..69b3646356 100644
--- a/tensorflow/compiler/xla/service/buffer_value.h
+++ b/tensorflow/compiler/xla/service/buffer_value.h
@@ -19,12 +19,12 @@ limitations under the License.
#include <functional>
#include <string>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/int_type.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc
index 985ff30e80..23b2a32709 100644
--- a/tensorflow/compiler/xla/service/call_graph.cc
+++ b/tensorflow/compiler/xla/service/call_graph.cc
@@ -17,21 +17,21 @@ limitations under the License.
#include <queue>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
-using ::tensorflow::strings::Appendf;
-using ::tensorflow::strings::StrCat;
+using absl::StrAppendFormat;
+using absl::StrCat;
string CallContextToString(CallContext context) {
switch (context) {
@@ -71,10 +71,10 @@ CallContext GetInstructionCallContext(HloOpcode opcode) {
}
string CallSite::ToString() const {
- return StrCat(instruction()->name(), " calls in context ",
- CallContextToString(context()), ": ",
- tensorflow::str_util::Join(
- called_computations(), ", ",
+ return StrCat(
+ instruction()->name(), " calls in context ",
+ CallContextToString(context()), ": ",
+ absl::StrJoin(called_computations(), ", ",
[](string* out, const HloComputation* computation) {
out->append(computation->name());
}));
@@ -237,8 +237,8 @@ void CallGraph::SetCallContexts() {
/* static */
std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
- // Constructor for CallGraph is private so MakeUnique can't be used.
- auto call_graph = WrapUnique<CallGraph>(new CallGraph(module));
+ // Constructor for CallGraph is private so absl::make_unique can't be used.
+ auto call_graph = absl::WrapUnique<CallGraph>(new CallGraph(module));
VLOG(2) << "Building call graph for:";
XLA_VLOG_LINES(2, module->ToString());
@@ -356,20 +356,20 @@ CallGraph::NearestAncestorsInSameComputation(HloInstruction* a,
string CallGraph::ToString() const {
string out;
- Appendf(&out, "Call graph for module %s:\n", module_->name().c_str());
+ StrAppendFormat(&out, "Call graph for module %s:\n", module_->name());
for (const CallGraphNode& node : nodes()) {
- Appendf(&out, "Computation %s:\n", node.computation()->name().c_str());
- Appendf(&out, " calls:\n");
+ StrAppendFormat(&out, "Computation %s:\n", node.computation()->name());
+ StrAppendFormat(&out, " calls:\n");
for (const HloComputation* callee : node.callees()) {
- Appendf(&out, " %s\n", callee->name().c_str());
+ StrAppendFormat(&out, " %s\n", callee->name());
}
- Appendf(&out, " called by:\n");
+ StrAppendFormat(&out, " called by:\n");
for (const HloComputation* caller : node.callers()) {
- Appendf(&out, " %s\n", caller->name().c_str());
+ StrAppendFormat(&out, " %s\n", caller->name());
}
- Appendf(&out, " callsites:\n");
+ StrAppendFormat(&out, " callsites:\n");
for (const CallSite& callsite : node.callsites()) {
- Appendf(&out, " %s\n", callsite.ToString().c_str());
+ StrAppendFormat(&out, " %s\n", callsite.ToString());
}
}
return out;
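The CallGraph::Build hunk above shows the one case where absl::make_unique is not enough: make_unique must invoke the constructor itself, so a private constructor forces absl::WrapUnique around a raw new performed by the class's own factory. A toy sketch of the same constraint, with a hypothetical Widget class:

    #include <memory>

    #include "absl/memory/memory.h"

    class Widget {
     public:
      static std::unique_ptr<Widget> Create() {
        // absl::make_unique<Widget>() would fail to compile here because the
        // constructor is private. WrapUnique only adopts the raw pointer, and
        // this static member may call `new Widget()` directly.
        return absl::WrapUnique(new Widget());
      }

     private:
      Widget() = default;
    };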
diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h
index 97d3811508..3af2ab5edf 100644
--- a/tensorflow/compiler/xla/service/call_graph.h
+++ b/tensorflow/compiler/xla/service/call_graph.h
@@ -15,8 +15,8 @@ limitations under the License.
// Call graph for an HLO module.
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_
#include <ostream>
@@ -272,4 +272,4 @@ class CallGraph {
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_
diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc
index cc80b74843..34f3f914d5 100644
--- a/tensorflow/compiler/xla/service/call_graph_test.cc
+++ b/tensorflow/compiler/xla/service/call_graph_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -31,7 +31,7 @@ namespace {
using ::testing::UnorderedElementsAre;
-class CallGraphTest : public HloTestBase {
+class CallGraphTest : public HloVerifiedTestBase {
protected:
// Build and return a trivial computation taking and returning a scalar.
std::unique_ptr<HloComputation> MakeScalarComputation(
@@ -96,7 +96,7 @@ TEST_F(CallGraphTest, SingletonComputation) {
auto module = CreateNewModule();
HloComputation* computation =
module->AddEntryComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(1, call_graph->nodes().size());
EXPECT_TRUE(call_graph->IsFlattened());
@@ -118,7 +118,7 @@ TEST_F(CallGraphTest, UnreachableComputation) {
HloComputation* unreachable_computation =
module->AddEmbeddedComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(2, call_graph->nodes().size());
const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
@@ -140,7 +140,7 @@ TEST_F(CallGraphTest, ParallelComputation) {
HloComputation* entry_computation = module->AddEntryComputation(
MakeMappingComputation(map_computation, /*callsites=*/5));
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(2, call_graph->nodes().size());
const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
@@ -169,7 +169,7 @@ TEST_F(CallGraphTest, SequentialComputations) {
HloComputation* entry_computation = module->AddEntryComputation(
MakeCallingComputation(called_computation, /*callsites=*/3));
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(2, call_graph->nodes().size());
// The called computation is only called from one other computation, but there
@@ -210,7 +210,7 @@ TEST_F(CallGraphTest, ContextBothComputations) {
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(2, call_graph->nodes().size());
EXPECT_FALSE(call_graph->IsFlattened());
@@ -259,7 +259,7 @@ TEST_F(CallGraphTest, ComputationWithConditional) {
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(3, call_graph->nodes().size());
@@ -328,7 +328,7 @@ TEST_F(CallGraphTest, ComplexGraph) {
entry_computation = module->AddEntryComputation(builder.Build());
}
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(5, call_graph->nodes().size());
EXPECT_FALSE(call_graph->IsFlattened());
@@ -452,7 +452,7 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) {
entry_computation = module->AddEntryComputation(builder.Build());
}
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(5, call_graph->nodes().size());
// Verify NearestAncestorsInSameComputation for various instructions in the
@@ -482,7 +482,7 @@ TEST_F(CallGraphTest, VisitSingletonComputation) {
auto module = CreateNewModule();
HloComputation* computation =
module->AddEntryComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
std::vector<HloComputation*> visited;
TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) {
@@ -499,7 +499,7 @@ TEST_F(CallGraphTest, VisitUnreachableComputation) {
module->AddEntryComputation(MakeScalarComputation());
HloComputation* unreachable_computation =
module->AddEmbeddedComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
// Test visitation of only reachable nodes.
{
@@ -533,7 +533,7 @@ TEST_F(CallGraphTest, VisitWithError) {
// Test that the call graph visitor properly propagates errors.
auto module = CreateNewModule();
module->AddEntryComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
Status status = call_graph->VisitNodes(
[](const CallGraphNode&) { return InternalError("Visitation failed"); });
diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc
index 256d05a73e..1d42140444 100644
--- a/tensorflow/compiler/xla/service/call_inliner.cc
+++ b/tensorflow/compiler/xla/service/call_inliner.cc
@@ -96,7 +96,7 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
if (it == subcomputation_hlo_to_new_hlo_.end()) {
return NotFound(
"Could not find mapping from subcomputation HLO %s to a cloned HLO.",
- subcomputation_hlo->ToString().c_str());
+ subcomputation_hlo->ToString());
}
return it->second;
}
diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h
index a8345a394d..c5cd88b9ea 100644
--- a/tensorflow/compiler/xla/service/call_inliner.h
+++ b/tensorflow/compiler/xla/service/call_inliner.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE__CALL_INLINER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE__CALL_INLINER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CALL_INLINER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CALL_INLINER_H_
#include <deque>
@@ -35,11 +35,11 @@ class CallInliner : public HloPassInterface {
static StatusOr<InlinedInstructionMap> Inline(HloInstruction* call);
~CallInliner() override = default;
- tensorflow::StringPiece name() const override { return "CallInliner"; }
+ absl::string_view name() const override { return "CallInliner"; }
StatusOr<bool> Run(HloModule* module) override;
};
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE__CALL_INLINER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CALL_INLINER_H_
diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc
index ff968bca29..5d85a3f173 100644
--- a/tensorflow/compiler/xla/service/call_inliner_test.cc
+++ b/tensorflow/compiler/xla/service/call_inliner_test.cc
@@ -18,9 +18,9 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -32,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace op = xla::testing::opcode_matchers;
diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc
index 13008efed1..3c2d1ae6d8 100644
--- a/tensorflow/compiler/xla/service/channel_tracker.cc
+++ b/tensorflow/compiler/xla/service/channel_tracker.cc
@@ -15,14 +15,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/channel_tracker.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
@@ -73,20 +73,20 @@ ChannelHandle ChannelTracker::AllocateHandle(ChannelHandle::ChannelType type) {
Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) {
if (opaque_to_channel_.count(handle.handle()) == 0) {
- return NotFound("channel handle not found: %lld", handle.handle());
+ return NotFound("channel handle not found: %d", handle.handle());
}
Channel& channel = opaque_to_channel_[handle.handle()];
if (channel.type == ChannelHandle::HOST_TO_DEVICE) {
return FailedPrecondition(
"host-to-device channels cannot be used with a Send operation; "
- "channel handle: %lld",
+ "channel handle: %d",
handle.handle());
}
if (channel.has_sender) {
return FailedPrecondition(
"when registering send, passed a channel handle that is already used "
- "by a sender: %lld",
+ "by a sender: %d",
handle.handle());
}
channel.has_sender = true;
@@ -95,13 +95,13 @@ Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) {
Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) {
if (opaque_to_channel_.count(handle.handle()) == 0) {
- return NotFound("channel handle not found: %lld", handle.handle());
+ return NotFound("channel handle not found: %d", handle.handle());
}
Channel& channel = opaque_to_channel_[handle.handle()];
if (channel.type == ChannelHandle::DEVICE_TO_HOST) {
return FailedPrecondition(
"device-to-host channels cannot be used with a Recv operation; "
- "channel handle: %lld",
+ "channel handle: %d",
handle.handle());
}
@@ -109,7 +109,7 @@ Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) {
if (channel.receiver_count >= 1) {
return FailedPrecondition(
"when registering recv, passed a channel handle that is already used "
- "by a receiver: %lld",
+ "by a receiver: %d",
handle.handle());
}
channel.receiver_count += 1;
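The format-string churn in this file (%lld becoming %d) follows from the same migration: these error helpers now format through absl::StrFormat-style checked conversions, where %d matches any integral argument regardless of width, so the long-long length modifiers are dropped. A small sketch, assuming only absl (DescribeMissingHandle is a hypothetical name):

    #include <cstdint>
    #include <string>

    #include "absl/strings/str_format.h"

    // With absl::StrFormat the conversion is checked against the argument
    // type at compile time, and %d accepts any integral type, including
    // int64_t -- no %lld is needed and no width mismatch is possible.
    std::string DescribeMissingHandle(int64_t handle) {
      return absl::StrFormat("channel handle not found: %d", handle);
    }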
diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h
index d773558c28..52037bf9b5 100644
--- a/tensorflow/compiler/xla/service/channel_tracker.h
+++ b/tensorflow/compiler/xla/service/channel_tracker.h
@@ -18,12 +18,12 @@ limitations under the License.
#include <map>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index 7426672a7a..e5a6c28478 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
@@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -62,7 +62,7 @@ CompileOnlyService::CompileOnlyService(const ServiceOptions& options,
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileOnlyService::CompileAheadOfTime(
- const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
+ const absl::Span<const AotXlaComputationInstance> computations,
const AotCompilationOptions& options,
std::unique_ptr<AotCompilationMetadata>* metadata) {
std::vector<std::unique_ptr<HloModule>> hlo_modules;
@@ -76,9 +76,9 @@ CompileOnlyService::CompileAheadOfTime(
if (!directory_path.empty()) {
HloSnapshot hlo_snapshot;
*hlo_snapshot.mutable_hlo()->mutable_hlo_module() = instance.computation;
- string filename = tensorflow::strings::StrCat(
- "computation_", instance.computation.id(), "__",
- instance.computation.entry_computation_name());
+ string filename =
+ absl::StrCat("computation_", instance.computation.id(), "__",
+ instance.computation.entry_computation_name());
const string& per_host_path = tensorflow::io::JoinPath(
directory_path, tensorflow::port::Hostname());
diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h
index 1ac950bdd6..61136a3e11 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.h
+++ b/tensorflow/compiler/xla/service/compile_only_service.h
@@ -50,12 +50,12 @@ class CompileOnlyService : public Service {
// |CompileOnlyClient::CompileAheadOfTime| for additional details.
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(
- const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
+ const absl::Span<const AotXlaComputationInstance> computations,
const AotCompilationOptions& options);
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(
- const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
+ const absl::Span<const AotXlaComputationInstance> computations,
const AotCompilationOptions& options,
std::unique_ptr<AotCompilationMetadata>* metadata);
diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc
index 6b3b9820f0..687ecafe0c 100644
--- a/tensorflow/compiler/xla/service/compiler.cc
+++ b/tensorflow/compiler/xla/service/compiler.cc
@@ -101,7 +101,7 @@ Compiler::GetPlatformCompilers() {
return NotFound(
"could not find registered compiler for platform %s -- check "
"target linkage",
- platform->Name().c_str());
+ platform->Name());
}
// And then we invoke the factory, placing the result into the mapping.
diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h
index 34f7fe12ca..1fdda31c34 100644
--- a/tensorflow/compiler/xla/service/compiler.h
+++ b/tensorflow/compiler/xla/service/compiler.h
@@ -26,6 +26,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -34,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc
index cb61f3da39..af8f7f1027 100644
--- a/tensorflow/compiler/xla/service/computation_layout.cc
+++ b/tensorflow/compiler/xla/service/computation_layout.cc
@@ -17,9 +17,9 @@ limitations under the License.
#include <algorithm>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
@@ -52,9 +52,8 @@ string ComputationLayout::ToString() const {
for (auto& param_layout : parameter_layouts_) {
params.push_back(param_layout.ToString());
}
- return tensorflow::strings::StrCat("(",
- tensorflow::str_util::Join(params, ", "),
- ") => ", result_layout_.ToString());
+ return absl::StrCat("(", absl::StrJoin(params, ", "), ") => ",
+ result_layout_.ToString());
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc
index 187ce568cb..2210a8578a 100644
--- a/tensorflow/compiler/xla/service/computation_placer.cc
+++ b/tensorflow/compiler/xla/service/computation_placer.cc
@@ -19,8 +19,9 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -29,12 +30,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
-using tensorflow::strings::StrAppend;
-using tensorflow::strings::StrCat;
+using absl::StrAppend;
+using absl::StrCat;
namespace xla {
@@ -60,8 +60,8 @@ DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) {
"computation_count=%d",
proto.replica_count(), proto.computation_count());
}
- auto assignment = MakeUnique<DeviceAssignment>(proto.replica_count(),
- proto.computation_count());
+ auto assignment = absl::make_unique<DeviceAssignment>(
+ proto.replica_count(), proto.computation_count());
for (int computation = 0; computation < proto.computation_count();
++computation) {
const auto& computation_device = proto.computation_devices(computation);
@@ -132,7 +132,7 @@ StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
return NotFound(
"could not find registered computation placer for platform %s -- check "
"target linkage",
- platform->Name().c_str());
+ platform->Name());
}
if (it->second.placer == nullptr) {
@@ -156,7 +156,7 @@ ComputationPlacer::GetPlatformComputationPlacers() {
} // namespace xla
static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() {
- return xla::MakeUnique<xla::ComputationPlacer>();
+ return absl::make_unique<xla::ComputationPlacer>();
}
static bool InitModule() {
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc
index b7be3ba605..4ea3a13f28 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier.cc
+++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -28,8 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.h b/tensorflow/compiler/xla/service/conditional_simplifier.h
index 063261e26d..3de50cbd7f 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier.h
+++ b/tensorflow/compiler/xla/service/conditional_simplifier.h
@@ -16,10 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace xla {
@@ -27,9 +27,7 @@ namespace xla {
// with their true or false computation as appropriate.
class ConditionalSimplifier : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override {
- return "simplify-conditional";
- }
+ absl::string_view name() const override { return "simplify-conditional"; }
StatusOr<bool> Run(HloModule* module) override;
};
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
index 45252fc1ee..0ac4a65ec6 100644
--- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
@@ -18,9 +18,9 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -214,8 +214,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
expanded_filter = add(HloInstruction::CreateConcatenate(
expanded_filter_shape, concat_operands, input_feature_dim));
}
- auto zero = add(HloInstruction::CreateConstant(MakeUnique<Literal>(
- LiteralUtil::Zero(expanded_filter_shape.element_type()))));
+ auto zero = add(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(expanded_filter_shape.element_type())));
auto zero_filter =
add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
auto new_filter = add(
@@ -223,7 +223,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
filter_mask, expanded_filter, zero_filter));
auto new_convolution = HloInstruction::CreateConvolve(
convolution->shape(), convolution->mutable_operand(0), new_filter,
- convolution->window(), dim_numbers, /*feature_group_count=*/1);
+ /*feature_group_count=*/1, convolution->window(), dim_numbers,
+ convolution->precision_config());
TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
convolution, std::move(new_convolution)));
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
index f213cc8709..498894737f 100644
--- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
@@ -16,10 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace xla {
@@ -29,7 +29,7 @@ class ConvolutionFeatureGroupConverter : public HloPassInterface {
public:
ConvolutionFeatureGroupConverter() {}
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "convolution-feature-group-converter";
}
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index 3e39c1bab1..b65dfef9c9 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/copy_insertion.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
@@ -31,18 +33,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
-
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
-
namespace {
+using absl::StrAppend;
+
bool IsEntryParameterValue(const HloValue& value) {
const HloComputation* computation = value.defining_instruction()->parent();
return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
@@ -381,7 +378,7 @@ class CopyRemover {
}
string ToString() const {
- string out = StrCat("CopyRemover, module ", module_->name(), "\n");
+ string out = absl::StrCat("CopyRemover, module ", module_->name(), "\n");
StrAppend(&out, " Buffer values, in dependency order:\n");
for (const HloBuffer& buffer : alias_analysis_.buffers()) {
StrAppend(&out, " HloBuffer ", buffer.id(), ":\n");
@@ -482,7 +479,7 @@ class CopyRemover {
// 'values' an entry is created in value_to_node which indicates the
// respective ValueNode representing that value.
void AddValueList(
- tensorflow::gtl::ArraySlice<const HloValue*> values,
+ absl::Span<const HloValue* const> values,
tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>* value_to_node) {
ValueNode* tail = nullptr;
ValueNode* head = nullptr;
@@ -863,16 +860,16 @@ class CopyRemover {
for (const ValueNode* p = head; p != nullptr; p = Next(*p)) {
values.push_back(p->value);
}
- return StrCat("{",
- Join(values, ", ",
- [](string* s, const HloValue* value) {
- StrAppend(s, value->ToShortString());
- }),
- "}");
+ return absl::StrCat("{",
+ absl::StrJoin(values, ", ",
+ [](string* s, const HloValue* value) {
+ StrAppend(s, value->ToShortString());
+ }),
+ "}");
}
string ToString() const {
- string out = StrCat("BufferValueTracker:\n");
+ string out = absl::StrCat("BufferValueTracker:\n");
StrAppend(&out, " Def-use chains in each buffer:\n");
for (const ValueNode* head : value_lists_) {
StrAppend(&out, " Buffer defined by ", head->value->ToShortString(),
@@ -880,10 +877,10 @@ class CopyRemover {
const ValueNode* p = head;
do {
StrAppend(&out, " ", p->value->ToShortString(), ", uses: ",
- Join(p->uses, "; ",
- [](string* s, const HloUse* use) {
- StrAppend(s, use->ToString());
- }),
+ absl::StrJoin(p->uses, "; ",
+ [](string* s, const HloUse* use) {
+ StrAppend(s, use->ToString());
+ }),
"\n");
p = p->next;
@@ -960,16 +957,11 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
return Status::OK();
}
-// Add copies to address special constraints on the roots of computations not
-// related to live range interference:
-//
-// (1) Entry computation root must be unambiguous and distinct.
-//
-// (2) Any computation called by a kCall instruction must have an
-// unambiguous root.
-//
-// (3) Constants and parameters cannot be live out of the entry computation
-//
+Status CopyInsertion::AddSpecialCaseCopies(HloModule* module) {
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
+ return AddSpecialCaseCopies(*call_graph, module);
+}
+
Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
HloModule* module) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
@@ -1065,15 +1057,6 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
for (HloInstruction* user : users) {
TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
}
- // Special case copies are not eligible for later copy elision passes.
- indices_to_copy.ForEachElement([&](const ShapeIndex& index, bool has_copy) {
- if (has_copy) {
- HloInstruction* copy = *copies_added.mutable_element(index);
- if (copy != nullptr) {
- copy->SetCopyElisionAllowed(false);
- }
- }
- });
if (instruction == instruction->parent()->root_instruction()) {
instruction->parent()->set_root_instruction(deep_copy);
}
@@ -1081,10 +1064,10 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
return Status::OK();
}
-Status CopyInsertion::VerifyNoLiveRangeInterference(HloModule* module) {
+Status CopyInsertion::VerifyNoLiveRangeInterference(const HloOrdering& ordering,
+ HloModule* module) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
HloAliasAnalysis::Run(module, fusion_can_share_buffer_));
- DependencyHloOrdering ordering(module);
TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering));
return Status::OK();
}
@@ -1101,8 +1084,7 @@ Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kCopy &&
- instruction->CopyElisionAllowed()) {
+ if (instruction->opcode() == HloOpcode::kCopy) {
TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status());
}
}
@@ -1168,10 +1150,10 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
TF_RETURN_IF_ERROR(dce.Run(module).status());
- TF_DCHECK_OK(VerifyNoLiveRangeInterference(module));
+ DependencyHloOrdering dep_ordering(module);
+ TF_DCHECK_OK(VerifyNoLiveRangeInterference(dep_ordering, module));
- DependencyHloOrdering ordering(module);
- TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, module));
+ TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(dep_ordering, module));
TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
@@ -1179,7 +1161,8 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
TF_RETURN_IF_ERROR(dce.Run(module).status());
- TF_DCHECK_OK(VerifyNoLiveRangeInterference(module));
+ TF_DCHECK_OK(
+ VerifyNoLiveRangeInterference(DependencyHloOrdering(module), module));
MaybeDumpModule("after copy insertion", *module);
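The CopyRemover hunks above also exercise absl::StrJoin's formatter overload: a lambda that appends each element's rendering to the output string, replacing tensorflow::str_util::Join with identical semantics. A standalone sketch, with a hypothetical Value struct standing in for HloValue:

    #include <string>
    #include <vector>

    #include "absl/strings/str_cat.h"
    #include "absl/strings/str_join.h"

    struct Value {
      std::string short_string;
    };

    // Mirrors the ValueListToString hunk above: join pointers through a
    // formatter lambda instead of converting each element up front.
    std::string ValueListToString(const std::vector<const Value*>& values) {
      return absl::StrCat(
          "{",
          absl::StrJoin(values, ", ",
                        [](std::string* s, const Value* v) {
                          absl::StrAppend(s, v->short_string);
                        }),
          "}");
    }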
diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h
index 5ba64b78a3..d308f6bc84 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.h
+++ b/tensorflow/compiler/xla/service/copy_insertion.h
@@ -45,7 +45,7 @@ namespace xla {
// InstructionAliasSet::IsDistinct return true.
class CopyInsertion : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "copy-insertion"; }
+ absl::string_view name() const override { return "copy-insertion"; }
// fusion_can_share_buffer: backend specific function that decides whether a
// fusion can share buffer with its operand.
@@ -77,15 +77,29 @@ class CopyInsertion : public HloPassInterface {
Status RemoveUnnecessaryCopies(const HloOrdering& ordering,
HloModule* module);
- private:
- // Verifies that no HLO values have interfering live ranged assuming the
- // ordering used by copy insertion.
- Status VerifyNoLiveRangeInterference(HloModule* module);
+ // Add copies to address special constraints on the roots of computations not
+ // related to live range interference:
+ //
+ // (1) Entry computation root must be unambiguous and distinct.
+ //
+ // (2) Any computation called by a kCall instruction must have an
+ // unambiguous root.
+ //
+ // (3) Constants and parameters cannot be live out of the entry computation
+ //
+ Status AddSpecialCaseCopies(HloModule* module);
- Status AddCopiesToResolveInterference(HloModule* module);
+ // Verifies that no HLO values have interfering live ranges using the given
+ // ordering.
+ Status VerifyNoLiveRangeInterference(const HloOrdering& ordering,
+ HloModule* module);
+ private:
+  // Overload which requires the caller to pass in a call graph.
Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module);
+ Status AddCopiesToResolveInterference(HloModule* module);
+
// Backend specific function that decides whether a fusion can share buffer
// with its operand.
HloDataflowAnalysis::FusionCanShareBufferFunction fusion_can_share_buffer_;
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index fe1ef78533..8cc522a59e 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -50,6 +50,8 @@ cc_library(
"//tensorflow/compiler/xla/service/cpu:cpu_runtime",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
alwayslink = True, # Contains per-platform transfer manager registration
)
@@ -62,6 +64,7 @@ cc_library(
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -85,6 +88,10 @@ cc_library(
":ir_emitter",
":parallel_task_assignment",
":simple_orc_jit",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ ":target_machine_features",
+ "@com_google_absl//absl/types:span",
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
"//tensorflow/compiler/xla/service:scatter_expander",
"//tensorflow/compiler/xla:literal",
@@ -115,7 +122,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:hlo_proto_util",
- "//tensorflow/compiler/xla/service:hlo_scheduling",
+ "//tensorflow/compiler/xla/service:hlo_memory_scheduler",
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:indexed_array_analysis",
@@ -178,6 +185,7 @@ cc_library(
":runtime_single_threaded_conv2d",
":runtime_single_threaded_fft",
":runtime_single_threaded_matmul",
+ "@com_google_absl//absl/memory",
"@llvm//:execution_engine",
"@llvm//:core",
"@llvm//:mc", # fixdeps: keep
@@ -229,6 +237,9 @@ cc_library(
"//tensorflow/compiler/xla/service:tuple_points_to_analysis",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
"@llvm//:orc_jit",
],
)
@@ -271,11 +282,15 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+ "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
"@llvm//:code_gen",
"@llvm//:core",
"@llvm//:support",
@@ -320,6 +335,8 @@ cc_library(
"//tensorflow/compiler/xla/service/cpu:cpu_runtime",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
],
)
@@ -330,12 +347,12 @@ cc_library(
hdrs = ["parallel_loop_emitter.h"],
deps = [
":ir_emission_utils",
- "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings:str_format",
"@llvm//:core",
],
)
@@ -362,6 +379,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
],
)
@@ -382,6 +400,7 @@ tf_cc_binary(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -395,6 +414,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings:str_format",
"@llvm//:mc",
"@llvm//:mc_disassembler",
"@llvm//:object",
@@ -418,6 +438,7 @@ cc_library(
"//tensorflow/compiler/xla/service:llvm_compiler",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
"@llvm//:analysis",
"@llvm//:core",
"@llvm//:ipo",
@@ -446,6 +467,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -634,6 +656,8 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//third_party/eigen3",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -646,8 +670,11 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -742,6 +769,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -773,6 +801,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -794,6 +823,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -810,6 +840,8 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
"//tensorflow/compiler/xla/service:hlo_pass",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
],
)
@@ -846,6 +878,7 @@ cc_library(
deps = [
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -893,6 +926,8 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
"@llvm//:support",
],
@@ -913,6 +948,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -938,6 +974,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
diff --git a/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc b/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc
index 408fe0f5bf..1942ea1a2a 100644
--- a/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc
+++ b/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc
@@ -40,7 +40,7 @@ std::vector<BufferInfo> CreateBufferInfosFromBufferAssignment(
}
std::vector<int32> CreateArgIndexTableFromBufferInfos(
- tensorflow::gtl::ArraySlice<BufferInfo> buffer_infos) {
+ absl::Span<const BufferInfo> buffer_infos) {
std::vector<int32> result;
for (int64 i = 0; i < buffer_infos.size(); i++) {
if (buffer_infos[i].is_entry_parameter()) {
diff --git a/tensorflow/compiler/xla/service/cpu/buffer_info_util.h b/tensorflow/compiler/xla/service/cpu/buffer_info_util.h
index 05de70c726..e9ee928ab2 100644
--- a/tensorflow/compiler/xla/service/cpu/buffer_info_util.h
+++ b/tensorflow/compiler/xla/service/cpu/buffer_info_util.h
@@ -16,9 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_
+#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace cpu {
@@ -34,7 +34,7 @@ CreateBufferInfosFromBufferAssignment(
// If this function returns V then entry parameter i has buffer allocation index
// V[i].
std::vector<int32> CreateArgIndexTableFromBufferInfos(
- tensorflow::gtl::ArraySlice<::tensorflow::cpu_function_runtime::BufferInfo>
+ absl::Span<const ::tensorflow::cpu_function_runtime::BufferInfo>
buffer_infos);
} // namespace cpu
} // namespace xla
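
The migration above is mechanical: tensorflow::gtl::ArraySlice<T> becomes absl::Span<const T> (and MutableArraySlice<T> becomes Span<T>), with the include switching from tensorflow/core/lib/gtl/array_slice.h to absl/types/span.h. A minimal standalone sketch of the read-only-view pattern used by CreateArgIndexTableFromBufferInfos, with plain integers standing in for BufferInfo (the names here are illustrative, not part of the XLA API):

    #include <cstdint>
    #include <vector>
    #include "absl/types/span.h"

    // Walk a read-only view and build an index table, as the function above
    // does over BufferInfo. absl::Span<const T> accepts a std::vector<T>,
    // a C array, or an initializer list at the call site.
    std::vector<int32_t> ArgIndexTable(absl::Span<const int64_t> params) {
      std::vector<int32_t> table;
      for (size_t i = 0; i < params.size(); ++i) {
        if (params[i] >= 0) {  // stand-in for is_entry_parameter()
          table.push_back(static_cast<int32_t>(i));
        }
      }
      return table;
    }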
diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
index 128eea4828..73b03440cb 100644
--- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
+++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
@@ -35,7 +36,6 @@ limitations under the License.
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/AlwaysInliner.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@@ -205,7 +205,7 @@ void CompilerFunctor::AddTargetInfoPasses(
llvm::legacy::PassManagerBase* passes) const {
llvm::Triple target_triple(target_machine_->getTargetTriple());
auto target_library_info_impl =
- MakeUnique<llvm::TargetLibraryInfoImpl>(target_triple);
+ absl::make_unique<llvm::TargetLibraryInfoImpl>(target_triple);
target_library_info_impl->addVectorizableFunctions(
VectorFunctionsForTargetLibraryInfoImpl());
passes->add(
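
This hunk is representative of the commit-wide replacement of xla::MakeUnique/xla::WrapUnique (from the now-removed ptr_util.h) with absl::make_unique/absl::WrapUnique from absl/memory/memory.h. A minimal standalone sketch of the two calls, independent of the LLVM types above:

    #include <memory>
    #include "absl/memory/memory.h"

    struct Widget { int id; };

    int main() {
      // absl::make_unique constructs in place (std::make_unique for C++11
      // codebases); here it value-initializes, so a->id == 0.
      std::unique_ptr<Widget> a = absl::make_unique<Widget>();
      // absl::WrapUnique adopts an existing raw pointer, which is why it
      // suits factory APIs such as createTargetMachine() later in this diff.
      std::unique_ptr<Widget> b = absl::WrapUnique(new Widget{42});
      return (a->id == 0 && b->id == 42) ? 0 : 1;
    }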
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
index 0985b9297f..2d9978404c 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
@@ -130,8 +130,9 @@ StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
// change the dimension mapping but not the dimension sizes. For
// example, input height and width are the same as before the reshapes.
HloInstruction* new_conv = module->entry_computation()->AddInstruction(
- HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel,
- hlo->window(), new_dnums));
+ HloInstruction::CreateConvolve(
+ new_conv_shape, new_input, new_kernel, hlo->feature_group_count(),
+ hlo->window(), new_dnums, hlo->precision_config()));
// Reshape the output back to the shape of the original convolution.
TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction(
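
The convolution builder's signature grew two parameters in this change: the feature group count and the precision configuration are now passed explicitly instead of being implied. A hedged sketch of the updated call shape, reusing only names visible in the hunk above (computation stands in for module->entry_computation()):

    // Before: CreateConvolve(shape, lhs, rhs, window, dnums)
    // After: group count and precision config are explicit arguments.
    HloInstruction* new_conv = computation->AddInstruction(
        HloInstruction::CreateConvolve(
            new_conv_shape, new_input, new_kernel,
            /*feature_group_count=*/hlo->feature_group_count(),
            hlo->window(), new_dnums, hlo->precision_config()));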
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
index e6fd1499ed..59437e88af 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
@@ -38,7 +38,7 @@ class ConvCanonicalization : public HloPassInterface {
: target_machine_features_(*target_machine_features) {}
~ConvCanonicalization() override {}
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "convolution-canonicalization";
}
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
index 547d4c696d..2083f440fd 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -32,7 +32,7 @@ namespace cpu {
using ::testing::ElementsAre;
-class ConvCanonicalizationTest : public HloTestBase {
+class ConvCanonicalizationTest : public HloVerifiedTestBase {
public:
ConvCanonicalizationTest() {
for (int i = 0; i < 2; ++i) {
@@ -84,7 +84,8 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(
F32, {kOutputFeatureCount, kBatchSize, output_size, output_size}),
- input, kernel, conv_window_, dnums));
+ input, kernel, /*feature_group_count=*/1, conv_window_, dnums,
+ DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -95,7 +96,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
});
ConvCanonicalization conv_canonicalization(&target_machine_features);
- EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(conv_canonicalization.Run(module).ValueOrDie());
const HloInstruction* output_reshape = entry_computation->root_instruction();
EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode());
@@ -146,7 +147,8 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) {
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(
F32, {kBatchSize, output_size, output_size, kOutputFeatureCount}),
- input, kernel, conv_window_, dnums));
+ input, kernel, /*feature_group_count=*/1, conv_window_, dnums,
+ DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -156,7 +158,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
});
ConvCanonicalization conv_canonicalization(&target_machine_features);
- EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(conv_canonicalization.Run(module).ValueOrDie());
}
} // namespace cpu
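
The switch from HloTestBase to HloVerifiedTestBase in this test changes ownership: the verified fixture owns the module and hands tests a raw HloModule*, so pass invocations drop the .get(). A minimal sketch of the resulting pattern (MyPass and BuildComputation are hypothetical placeholders):

    class MyPassTest : public HloVerifiedTestBase {};

    TEST_F(MyPassTest, RunsOnVerifiedModule) {
      // Under HloVerifiedTestBase, CreateNewModule() returns a fixture-owned
      // raw pointer, and the module is verified when the test finishes.
      auto module = CreateNewModule();
      module->AddEntryComputation(BuildComputation());  // hypothetical helper
      MyPass pass;
      EXPECT_TRUE(pass.Run(module).ValueOrDie());  // Run(module), not module.get()
    }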
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index c7e766ca6b..18fc144efe 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -26,6 +26,8 @@ limitations under the License.
// IWYU pragma: no_include "llvm/Config/Disassemblers.def.inc"
// IWYU pragma: no_include "llvm/Config/Targets.def.inc"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/Function.h"
@@ -42,7 +44,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
@@ -76,12 +77,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
@@ -101,8 +102,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace cpu {
@@ -235,15 +234,15 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault {
std::unordered_map<const HloInstruction*, int64>* hlo_to_profile_idx_;
const std::unordered_map<const HloInstruction*, int64>& assigned_indices_;
};
-} // namespace
-Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
- llvm::TargetMachine* target_machine) {
- LLVMTargetMachineFeatures target_machine_features(target_machine);
+} // namespace
- // Optimization pipeline.
- HloPassPipeline pipeline("CPU");
- pipeline.AddInvariantChecker<HloVerifier>();
+Status CpuCompiler::RunHloPassesThroughLayoutAssn(
+ HloModule* module, bool /*is_aot_compile*/,
+ LLVMTargetMachineFeatures* target_machine_features) {
+ HloPassPipeline pipeline("HLO passes through layout assignment");
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
pipeline.AddPass<CpuHloSupportChecker>();
ReducePrecisionInsertion::AddPasses(
@@ -260,11 +259,12 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
pipeline.AddPass<BatchDotSimplification>();
pipeline.AddPass<DotDecomposer>();
pipeline.AddPass<ConvolutionFeatureGroupConverter>();
- pipeline.AddPass<ConvCanonicalization>(&target_machine_features);
+ pipeline.AddPass<ConvCanonicalization>(target_machine_features);
{
auto& pass =
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
- pass.AddInvariantChecker<HloVerifier>();
+ pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
pass.AddPass<BatchNormExpander>(
/*rewrite_training_op=*/true,
@@ -278,7 +278,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
// BatchNormExpander can create zero-sized ops, so zero-sized HLO
// elimination has to come after that pass.
- pipeline.AddPass<ZeroSizedHloElimination>();
+ pass.AddPass<ZeroSizedHloElimination>();
pass.AddPass<WhileLoopInvariantCodeMotion>();
pass.AddPass<TupleSimplifier>();
@@ -291,10 +291,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
}
pipeline.AddPass<IndexedArrayAnalysisPrinterPass>();
pipeline.AddPass<TransposeFolding>(
- [&target_machine_features](
- const HloInstruction& dot,
+ [&](const HloInstruction& dot,
const TransposeFolding::OperandIndices& candidate_operands) {
- return PotentiallyImplementedAsEigenDot(dot, target_machine_features)
+ return PotentiallyImplementedAsEigenDot(dot, *target_machine_features)
? candidate_operands
: TransposeFolding::OperandIndices{};
},
@@ -309,12 +308,28 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
pipeline.AddPass<CpuLayoutAssignment>(
- module->mutable_entry_computation_layout(), &target_machine_features);
+ module->mutable_entry_computation_layout(), target_machine_features);
+ return pipeline.Run(module).status();
+}
+
+Status CpuCompiler::RunHloPassesAfterLayoutAssn(
+ HloModule* module, bool is_aot_compile,
+ LLVMTargetMachineFeatures* target_machine_features) {
+ HloPassPipeline pipeline("HLO passes after layout assignment");
+ // After layout assignment, use a layout-sensitive verifier.
+ auto& after_layout_assn =
+ pipeline.AddPass<HloPassPipeline>("after layout assignment");
+ after_layout_assn.AddInvariantChecker<HloVerifier>(
+ /*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
+
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
{
auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
- "after layout assignement");
+ "simplification after layout assignement");
+ pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
pass.AddPass<HloPassFix<AlgebraicSimplifier>>(
/*is_layout_sensitive=*/true,
[](const Shape&, const Shape&) { return true; },
@@ -322,7 +337,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
pass.AddPass<HloDCE>();
pass.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
}
+
pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
+
// Outline ops in the entry computation into calls to subcomputations.
const int max_parallelism =
module->config().intra_op_parallelism_threads() > 0
@@ -335,14 +352,14 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
// binary size (and most AOT applications are single-threaded).
// TODO(b/29630486) Support multi-threaded AOT.
pipeline.AddPass<ParallelTaskAssigner>(
- max_parallelism, ShapeSizeBytesFunction(), &target_machine_features);
+ max_parallelism, ShapeSizeBytesFunction(), target_machine_features);
}
- // Copy insertion should be performed immediately before IR emission to avoid
- // inserting unnecessary copies (later pass adds an instruction which
- // materializes the value) or missing a necessary copy (later pass removes an
- // instruction which materializes a value). DCE must be run immediately before
- // (and sometime after) copy insertion, to avoid dead code from interfering
- // with the rewrites.
+ // Copy insertion should be performed immediately before IR emission to
+ // avoid inserting unnecessary copies (later pass adds an instruction which
+ // materializes the value) or missing a necessary copy (later pass removes
+ // an instruction which materializes a value). DCE must be run immediately
+  // before (and sometimes after) copy insertion, to avoid dead code from
+ // interfering with the rewrites.
pipeline.AddPass<HloDCE>();
pipeline.AddPass<FlattenCallGraph>();
pipeline.AddPass<CpuCopyInsertion>();
@@ -350,6 +367,15 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
return pipeline.Run(module).status();
}
+Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
+ llvm::TargetMachine* target_machine) {
+ LLVMTargetMachineFeatures target_machine_features(target_machine);
+ TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn(module, is_aot_compile,
+ &target_machine_features));
+ return RunHloPassesAfterLayoutAssn(module, is_aot_compile,
+ &target_machine_features);
+}
+
namespace {
// Align buffers to 16-byte boundaries.
@@ -453,7 +479,7 @@ Status CreateHloProfilingArtifacts(
computation_to_profile_idx,
std::unique_ptr<HloProfileIndexMap>* hlo_profile_index_map,
std::unique_ptr<HloProfilePrinterData>* hlo_profile_printer_data) {
- *hlo_profile_index_map = MakeUnique<HloProfileIndexMap>(module);
+ *hlo_profile_index_map = absl::make_unique<HloProfileIndexMap>(module);
const HloComputation& entry_computation = *module.entry_computation();
TF_ASSIGN_OR_RETURN(
@@ -520,11 +546,11 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
&pre_optimization_ir_hook, &post_optimization_ir_hook));
// Compile must be thread-safe so create a new LLVM context for the module.
- auto llvm_context = xla::MakeUnique<llvm::LLVMContext>();
+ auto llvm_context = absl::make_unique<llvm::LLVMContext>();
auto llvm_module =
- xla::MakeUnique<llvm::Module>("__compute_module", *llvm_context);
+ absl::make_unique<llvm::Module>("__compute_module", *llvm_context);
- auto jit = xla::MakeUnique<SimpleOrcJIT>(
+ auto jit = absl::make_unique<SimpleOrcJIT>(
CompilerTargetOptions(module->config()),
CodeGenOptLevel(module->config()),
options::OptimizeForSizeRequested(module->config()),
@@ -558,20 +584,17 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
// computation. Using this sequence enables tighter buffer liveness analysis
// and reduced memory usage (as compared to using DependencyHloOrdering).
TF_ASSIGN_OR_RETURN(
- SequentialHloOrdering::HloModuleSequence module_sequence,
- ScheduleComputationsInModule(*module, BufferSizeBytesFunction(),
- DFSMemoryScheduler));
+ HloSchedule schedule,
+ ScheduleModule(*module, BufferSizeBytesFunction(), DFSMemoryScheduler));
- // Run buffer analysis on the HLO graph. This analysis figures out which
- // temporary buffers are required to run the computation.
+ // Run buffer allocation on the HLO graph.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<BufferAssignment> assignment,
- BufferAssigner::Run(
- module.get(),
- xla::MakeUnique<SequentialHloOrdering>(module.get(), module_sequence),
- BufferSizeBytesFunction(), memory_alignment,
- /*allow_input_output_aliasing=*/false,
- /*allocate_buffers_for_constants=*/true));
+ BufferAssigner::Run(module.get(),
+ absl::make_unique<SequentialHloOrdering>(schedule),
+ BufferSizeBytesFunction(), memory_alignment,
+ /*allow_input_output_aliasing=*/false,
+ /*allocate_buffers_for_constants=*/true));
// BufferAssignment::ToString() includes a header, so no need for us to
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
@@ -602,9 +625,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
}
TF_RETURN_IF_ERROR(
ir_emitter
- .EmitComputation(embedded_computation, embedded_computation->name(),
- /*is_top_level_computation=*/false,
- &module_sequence.at(embedded_computation))
+ .EmitComputation(
+ embedded_computation, embedded_computation->name(),
+ /*is_top_level_computation=*/false,
+ &schedule.sequence(embedded_computation).instructions())
.status());
}
string function_name_prefix = entry_computation->name().empty()
@@ -612,9 +636,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
: entry_computation->name();
TF_ASSIGN_OR_RETURN(
llvm::Function * entry_function,
- ir_emitter.EmitComputation(entry_computation, function_name_prefix,
- /*is_top_level_computation=*/true,
- &module_sequence.at(entry_computation)));
+ ir_emitter.EmitComputation(
+ entry_computation, function_name_prefix,
+ /*is_top_level_computation=*/true,
+ &schedule.sequence(entry_computation).instructions()));
string function_name = [&]() {
llvm::SmallVector<char, 40> function_name_vector;
@@ -679,8 +704,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
const llvm::Target* target =
llvm::TargetRegistry::lookupTarget(triple.getTriple(), error);
if (target == nullptr) {
- return InternalError("TargetRegistry::lookupTarget failed: %s",
- error.c_str());
+ return InternalError("TargetRegistry::lookupTarget failed: %s", error);
}
llvm::Reloc::Model reloc_model = llvm::Reloc::Static;
@@ -716,7 +740,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
llvm::StringRef cpu_name = llvm_ir::AsStringRef(options.cpu_name());
llvm::StringRef features = llvm_ir::AsStringRef(options.features());
llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel(modules[0]->config());
- std::unique_ptr<llvm::TargetMachine> target_machine = WrapUnique(
+ std::unique_ptr<llvm::TargetMachine> target_machine = absl::WrapUnique(
target->createTargetMachine(triple.getTriple(), cpu_name, features,
CompilerTargetOptions(modules[0]->config()),
reloc_model, llvm::None, opt_level));
@@ -747,20 +771,18 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
VLOG(2) << "After optimization:";
XLA_VLOG_LINES(2, module->ToString());
- TF_ASSIGN_OR_RETURN(
- SequentialHloOrdering::HloModuleSequence module_sequence,
- ScheduleComputationsInModule(*module, BufferSizeBytesFunction()));
+ TF_ASSIGN_OR_RETURN(HloSchedule schedule,
+ ScheduleModule(*module, BufferSizeBytesFunction()));
// Run buffer analysis on the HLO graph. This analysis figures out which
// temporary buffers are required to run the computation.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<BufferAssignment> assignment,
- BufferAssigner::Run(
- module,
- xla::MakeUnique<SequentialHloOrdering>(module, module_sequence),
- BufferSizeBytesFunction(), memory_alignment,
- /*allow_input_output_aliasing=*/false,
- /*allocate_buffers_for_constants=*/true));
+ BufferAssigner::Run(module,
+ absl::make_unique<SequentialHloOrdering>(schedule),
+ BufferSizeBytesFunction(), memory_alignment,
+ /*allow_input_output_aliasing=*/false,
+ /*allocate_buffers_for_constants=*/true));
// BufferAssignment::ToString() includes a header, so no need for us to
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
@@ -800,18 +822,18 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
}
TF_RETURN_IF_ERROR(
ir_emitter
- .EmitComputation(embedded_computation,
- embedded_computation->name(),
- /*is_top_level_computation=*/false,
- &module_sequence.at(embedded_computation))
+ .EmitComputation(
+ embedded_computation, embedded_computation->name(),
+ /*is_top_level_computation=*/false,
+ &schedule.sequence(embedded_computation).instructions())
.status());
}
const string& entry_point_name = options.entry_point_name();
- TF_ASSIGN_OR_RETURN(
- llvm::Function * entry_function,
- ir_emitter.EmitComputation(computation, entry_point_name,
- /*is_top_level_computation=*/true,
- &module_sequence.at(computation)));
+ TF_ASSIGN_OR_RETURN(llvm::Function * entry_function,
+ ir_emitter.EmitComputation(
+ computation, entry_point_name,
+ /*is_top_level_computation=*/true,
+ &schedule.sequence(computation).instructions()));
CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name));
@@ -851,7 +873,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
assignment->GetUniqueTopLevelOutputSlice());
- results.emplace_back(MakeUnique<CpuAotCompilationResult>(
+ results.emplace_back(absl::make_unique<CpuAotCompilationResult>(
std::move(object_file_data), std::move(buffer_infos),
result_slice.index(), std::move(hlo_profile_printer_data)));
}
@@ -874,7 +896,7 @@ HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const {
static bool InitModule() {
xla::Compiler::RegisterCompilerFactory(
stream_executor::host::kHostPlatformId,
- []() { return xla::MakeUnique<xla::cpu::CpuCompiler>(); });
+ []() { return absl::make_unique<xla::cpu::CpuCompiler>(); });
return true;
}
static bool module_initialized = InitModule();
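
Two related API moves run through this file: scheduling now lives in hlo_memory_scheduler.h and returns an HloSchedule (ScheduleModule replacing ScheduleComputationsInModule), and SequentialHloOrdering is constructed from that schedule directly rather than from a module-plus-sequence pair. A condensed sketch of the new flow, using only names that appear in the hunks above (module.get() as in RunBackend; CompileAheadOfTime passes a raw pointer):

    // Schedule the module, then hand the schedule to buffer assignment.
    TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                        ScheduleModule(*module, BufferSizeBytesFunction()));
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<BufferAssignment> assignment,
        BufferAssigner::Run(module.get(),
                            absl::make_unique<SequentialHloOrdering>(schedule),
                            BufferSizeBytesFunction(), memory_alignment,
                            /*allow_input_output_aliasing=*/false,
                            /*allocate_buffers_for_constants=*/true));
    // Per-computation instruction order is now queried from the schedule:
    const auto& sequence = schedule.sequence(entry_computation);
    // ... handed to IR emission as &sequence.instructions() ...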
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
index 04e1c48872..f2af923782 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
@@ -18,13 +18,14 @@ limitations under the License.
#include <memory>
+#include "absl/types/span.h"
#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
+#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -157,6 +158,16 @@ class CpuCompiler : public LLVMCompiler {
Status RunHloPasses(HloModule* module, bool is_aot_compile,
llvm::TargetMachine* target_machine);
+ // Runs HLO passes up to and including layout assignment.
+ Status RunHloPassesThroughLayoutAssn(
+ HloModule* module, bool /*is_aot_compile*/,
+ LLVMTargetMachineFeatures* target_machine_features);
+
+ // Runs HLO passes after layout assignment.
+ Status RunHloPassesAfterLayoutAssn(
+ HloModule* module, bool is_aot_compile,
+ LLVMTargetMachineFeatures* target_machine_features);
+
TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler);
};
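
With this header change, RunHloPasses becomes a thin driver over the two phases, and each phase can install an HloVerifier invariant checker with the appropriate layout sensitivity. A condensed sketch of the pattern, matching the pipeline construction added in cpu_compiler.cc above:

    HloPassPipeline pre("HLO passes through layout assignment");
    pre.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                         /*allow_mixed_precision=*/false);
    // ... layout-insensitive passes, ending with CpuLayoutAssignment ...

    HloPassPipeline post("HLO passes after layout assignment");
    post.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
                                          /*allow_mixed_precision=*/false);
    // ... layout-sensitive simplification, DCE, CSE, copy insertion ...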
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
index 3313d1e6eb..d49f7d7cc2 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COPY_INSERTION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COPY_INSERTION_H_
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
@@ -32,11 +32,11 @@ namespace xla {
// (module-scoped).
class CpuCopyInsertion : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "copy-insertion"; }
+ absl::string_view name() const override { return "copy-insertion"; }
StatusOr<bool> Run(HloModule* module) override;
};
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COPY_INSERTION_H_
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
index 4db7fa446e..c9fb34be1c 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -52,7 +52,7 @@ int64 CountCopies(const HloModule& module) {
return count;
}
-class CpuCopyInsertionTest : public HloTestBase {
+class CpuCopyInsertionTest : public HloVerifiedTestBase {
protected:
void InsertCopies(HloModule* module) {
CpuCopyInsertion copy_insertion;
@@ -90,7 +90,7 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) {
module->AddEntryComputation(builder.Build());
- InsertCopies(module.get());
+ InsertCopies(module);
EXPECT_EQ(CountCopies(*module), 3);
@@ -127,7 +127,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) {
module->AddEntryComputation(builder.Build());
- InsertCopies(module.get());
+ InsertCopies(module);
EXPECT_EQ(CountCopies(*subcomputation), 2);
EXPECT_THAT(subcomputation->root_instruction(),
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index c376864c3e..29abf38e43 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -22,6 +22,9 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
@@ -35,9 +38,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
@@ -75,9 +75,9 @@ CpuExecutable::CpuExecutable(
StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
std::vector<OwningDeviceMemory>>>
-CpuExecutable::CreateTempArray(
+CpuExecutable::CreateBufferTable(
DeviceMemoryAllocator* memory_allocator, int device_ordinal,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ absl::Span<const ShapedBuffer* const> arguments) {
std::vector<se::DeviceMemoryBase> unowning_buffers(
assignment_->Allocations().size());
std::vector<OwningDeviceMemory> owning_buffers(
@@ -136,19 +136,19 @@ CpuExecutable::CreateTempArray(
Status CpuExecutable::ExecuteComputeFunction(
const ExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
+ absl::Span<const se::DeviceMemoryBase> buffers,
HloExecutionProfile* hlo_execution_profile) {
// The calling convention for JITed functions is:
//
// void function(void* result, const void* run_options, void** args_array,
- // void** temps_array)
+ // void** buffer_table)
//
// result: Points at the result.
// run_options: the ExecutableRunOptions object.
// args_array: null
- // temps_array: An array of pointers, containing pointers to temporary buffers
- // required by the executable adn pointers to entry computation
- // parameters.
+ // buffer_table: An array of pointers, containing pointers to temporary
+  //    buffers required by the executable and pointers to entry computation
+ // parameters.
//
uint64 start_micros = tensorflow::Env::Default()->NowMicros();
@@ -171,20 +171,19 @@ Status CpuExecutable::ExecuteComputeFunction(
void* result_buffer = buffer_pointers[result_slice.index()];
if (VLOG_IS_ON(3)) {
VLOG(3) << "Executing compute function:";
- VLOG(3) << tensorflow::strings::Printf(
- " func(void* result, void* params[null], void* temps[%zu], "
- "uint64 profile_counters[%zu])",
+ VLOG(3) << absl::StrFormat(
+ " func(void* result, void* params[null], void* buffer_table[%u], "
+ "uint64 profile_counters[%u])",
buffer_pointers.size(), profile_counters_size);
- VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer);
+ VLOG(3) << absl::StrFormat(" result = %p", result_buffer);
auto ptr_printer = [](string* out, const void* p) {
- tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p));
+ absl::StrAppend(out, absl::StrFormat("%p", p));
};
VLOG(3) << " params = nullptr";
- VLOG(3) << tensorflow::strings::Printf(
- " temps = [%s]",
- tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str());
- VLOG(3) << tensorflow::strings::Printf(" profile_counters = %p",
- profile_counters);
+ VLOG(3) << absl::StrFormat(
+ " buffer_table = [%s]",
+ absl::StrJoin(buffer_pointers, ", ", ptr_printer));
+ VLOG(3) << absl::StrFormat(" profile_counters = %p", profile_counters);
}
compute_function_(result_buffer, run_options, nullptr, buffer_pointers.data(),
@@ -209,7 +208,7 @@ Status CpuExecutable::ExecuteComputeFunction(
StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::MutableArraySlice<OwningDeviceMemory> buffers) {
+ absl::Span<OwningDeviceMemory> buffers) {
se::Stream* stream = run_options->stream();
ScopedShapedBuffer result_buffer(
/*on_host_shape=*/result_shape(),
@@ -247,7 +246,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) {
TF_ASSIGN_OR_RETURN(
auto result,
@@ -258,7 +257,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream(
StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ absl::Span<const ShapedBuffer* const> arguments) {
if (hlo_profiling_enabled()) {
return Unimplemented(
"Asynchronous execution on stream with hlo profiling is not yet "
@@ -269,7 +268,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) {
if (GetRootPointsToSet().IsAmbiguous()) {
return Unimplemented("Points-to set of root instruction is ambiguous");
@@ -283,11 +282,12 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
std::vector<se::DeviceMemoryBase> unowning_buffers;
TF_ASSIGN_OR_RETURN(
std::tie(unowning_buffers, owning_buffers),
- CreateTempArray(memory_allocator, stream->parent()->device_ordinal(),
- arguments));
+ CreateBufferTable(memory_allocator, stream->parent()->device_ordinal(),
+ arguments));
- TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
- CreateResultShapedBuffer(run_options, &owning_buffers));
+ TF_ASSIGN_OR_RETURN(
+ ScopedShapedBuffer result,
+ CreateResultShapedBuffer(run_options, absl::MakeSpan(owning_buffers)));
// At this point, `unowning_buffers` contains unowning pointers to all of our
// buffers, and `buffers` contains owning pointers to the non-live-out
@@ -300,7 +300,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
//
// We also need to change the types of some of the variables we capture:
// run_options needs to change from a pointer to a value type, and arguments
- // needs to change from an ArraySlice into a vector. We use a struct instead
+ // needs to change from a Span into a vector. We use a struct instead
// of a lambda to make this explicit.
struct AsyncRunTask {
CpuExecutable* executable;
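
The logging rewrite above is another recurring pattern in this commit: tensorflow::strings::Printf/StrAppend/Join give way to absl::StrFormat/StrAppend/StrJoin. Because absl's formatting is type-checked, %u accepts any unsigned integral (no %zu needed) and %s takes std::string without .c_str(). A standalone sketch:

    #include <iostream>
    #include <string>
    #include <vector>
    #include "absl/strings/str_cat.h"
    #include "absl/strings/str_format.h"
    #include "absl/strings/str_join.h"

    int main() {
      std::vector<const void*> ptrs = {nullptr, &ptrs};
      auto ptr_printer = [](std::string* out, const void* p) {
        absl::StrAppend(out, absl::StrFormat("%p", p));
      };
      // %u is checked at compile time and accepts size_t directly.
      std::cout << absl::StrFormat("buffer_table[%u]", ptrs.size()) << "\n";
      std::cout << absl::StrFormat(
                       "[%s]", absl::StrJoin(ptrs, ", ", ptr_printer))
                << "\n";
      return 0;
    }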
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
index 96e53de57e..3c3c047bfe 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
@@ -57,12 +57,12 @@ class CpuExecutable : public Executable {
StatusOr<ScopedShapedBuffer> ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) override;
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) override;
+ absl::Span<const ShapedBuffer* const> arguments) override;
// This should be called after set_ir_module_string.
const string& ir_module_string() const { return ir_module_string_; }
@@ -74,9 +74,10 @@ class CpuExecutable : public Executable {
static int64 ShapeSizeBytes(const Shape& shape);
// Type of the computation function we expect in the JIT.
- using ComputeFunctionType = void (*)(
- void* /*result*/, const ExecutableRunOptions* /*run_options*/,
- const void** /*args*/, void** /*temps*/, int64* /*profile_counters*/);
+ using ComputeFunctionType =
+ void (*)(void* /*result*/, const ExecutableRunOptions* /*run_options*/,
+ const void** /*args*/, void** /*buffer_table*/,
+ int64* /*profile_counters*/);
const ComputeFunctionType& compute_function() const {
return compute_function_;
@@ -92,18 +93,18 @@ class CpuExecutable : public Executable {
// exists) must out-live the task.
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStreamImpl(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile);
- // Creates an array suitable for passing as the "temps" argument to the JIT
- // compiled function pointer.
+ // Creates an array suitable for passing as the "buffer_table" argument to the
+ // JIT compiled function pointer.
//
// Returns (unowning_buffers, owning_buffers) where:
//
- // - unowning_buffers.data() can be passed as the temps argument as-is and
- // includes pointers to the scratch storage required by the computation,
- // the live-out buffer into which the result will be written and entry
- // computation parameters.
+ // - unowning_buffers.data() can be passed as the buffer_table argument as-is
+ // and includes pointers to the scratch storage required by the
+  //    computation, the live-out buffer into which the result will be written,
+ // and entry computation parameters.
//
// - owning_buffers contains owning pointers to the buffers that were
// allocated by this routine. This routine allocates buffers for temporary
@@ -111,22 +112,21 @@ class CpuExecutable : public Executable {
// result.
StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
std::vector<OwningDeviceMemory>>>
- CreateTempArray(DeviceMemoryAllocator* memory_allocator, int device_ordinal,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
+ CreateBufferTable(DeviceMemoryAllocator* memory_allocator, int device_ordinal,
+ absl::Span<const ShapedBuffer* const> arguments);
// Calls the generated function performing the computation with the given
// arguments using the supplied buffers.
- Status ExecuteComputeFunction(
- const ExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
- HloExecutionProfile* hlo_execution_profile);
+ Status ExecuteComputeFunction(const ExecutableRunOptions* run_options,
+ absl::Span<const se::DeviceMemoryBase> buffers,
+ HloExecutionProfile* hlo_execution_profile);
// Creates a ScopedShapedBuffer for holding the result of the computation,
// moving buffers out of allocated_buffers and into the result as appropriate.
// The addresses are set according to buffer assignment.
StatusOr<ScopedShapedBuffer> CreateResultShapedBuffer(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::MutableArraySlice<OwningDeviceMemory> buffers);
+ absl::Span<OwningDeviceMemory> buffers);
// Returns the points-to set of the root instruction of the entry
// computation. Uses points-to analysis from buffer assignment.
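
The renamed buffer_table parameter is easiest to see at the call boundary. A short sketch of how the JITed entry point is invoked through the typedef above, mirroring the call in cpu_executable.cc (profile_counters may be null when HLO profiling is off):

    // void function(void* result, const ExecutableRunOptions* run_options,
    //               const void** args /* always null */, void** buffer_table,
    //               int64* profile_counters);
    compute_function_(result_buffer, run_options, /*args=*/nullptr,
                      buffer_pointers.data(), profile_counters);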
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc
index 7bd4741a04..7fbe0fa157 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc
@@ -34,9 +34,8 @@ StatusOr<bool> CpuHloSupportChecker::Run(HloModule* module) {
return xla::Unimplemented(
"CPU backend does not support HLO instruction %s with shape "
"containing a sparse layout: %s",
- instruction->ToString().c_str(),
- ShapeUtil::HumanStringWithLayout(instruction->shape())
- .c_str());
+ instruction->ToString(),
+ ShapeUtil::HumanStringWithLayout(instruction->shape()));
}
return Status::OK();
}));
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
index 2924b63659..6af724b2a5 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
@@ -28,9 +28,7 @@ class CpuHloSupportChecker : public HloPassInterface {
CpuHloSupportChecker() = default;
~CpuHloSupportChecker() override = default;
- tensorflow::StringPiece name() const override {
- return "cpu_hlo_support_checker";
- }
+ absl::string_view name() const override { return "cpu_hlo_support_checker"; }
// Note: always returns false (no instructions are ever modified by this
// pass).
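
Every pass touched by this commit migrates name() the same way: tensorflow::StringPiece becomes absl::string_view, usually collapsing to one line. A minimal sketch of the shape (MyCheckerPass is a hypothetical example, not part of XLA):

    class MyCheckerPass : public HloPassInterface {
     public:
      // Returning a view of a string literal is safe: the literal has
      // static storage duration, so the view never dangles.
      absl::string_view name() const override { return "my-checker-pass"; }
      StatusOr<bool> Run(HloModule* module) override;  // declaration only
    };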
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc
index 0f463e6de6..be1208fb2d 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -25,7 +25,7 @@ namespace {
using ::testing::HasSubstr;
-class CpuHloSupportCheckerTest : public HloTestBase {
+class CpuHloSupportCheckerTest : public HloVerifiedTestBase {
protected:
CpuHloSupportChecker& checker() { return checker_; }
@@ -45,7 +45,7 @@ TEST_F(CpuHloSupportCheckerTest, Add) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK(checker().Run(module.get()).status());
+ TF_ASSERT_OK(checker().Run(module).status());
}
TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
@@ -60,7 +60,7 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- Status status = checker().Run(module.get()).status();
+ Status status = checker().Run(module).status();
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
EXPECT_THAT(status.error_message(),
HasSubstr("CPU backend does not support"));
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
index b40d264c03..f9cd61bea3 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
@@ -35,7 +35,7 @@ bool CanBeLoopFused(const HloInstruction& hlo) {
hlo.opcode() == HloOpcode::kDynamicSlice ||
hlo.opcode() == HloOpcode::kDynamicUpdateSlice ||
hlo.opcode() == HloOpcode::kGather ||
- hlo.opcode() == HloOpcode::kPad ||
+ hlo.opcode() == HloOpcode::kIota || hlo.opcode() == HloOpcode::kPad ||
hlo.opcode() == HloOpcode::kReshape ||
hlo.opcode() == HloOpcode::kReverse ||
hlo.opcode() == HloOpcode::kSlice ||
@@ -78,7 +78,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
}
if (!CanBeLoopFused(*producer)) {
- VLOG(2) << "Producer is not fusile.";
+ VLOG(2) << "Producer is not fusible.";
return false;
}
@@ -140,7 +140,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
}
if (CanBeLoopFused(*consumer)) {
- VLOG(2) << "Fusing: consumer is elementwise or fusile.";
+ VLOG(2) << "Fusing: consumer is elementwise or fusible.";
return true;
}
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index 8cbdc36f84..7d99b914d4 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -18,11 +18,13 @@ limitations under the License.
#include <algorithm>
#include <set>
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
namespace op = xla::testing::opcode_matchers;
@@ -37,7 +39,11 @@ std::unique_ptr<HloInstruction> MakeDot(const Shape& shape, HloInstruction* lhs,
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums);
+ PrecisionConfig precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfig::DEFAULT);
+ return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums,
+ precision_config);
}
TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) {
@@ -566,7 +572,7 @@ TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) {
HloOpcode::kParameter, HloOpcode::kParameter});
}
-TEST_F(OpcodeFusionTest, MessOfFusileNodes) {
+TEST_F(OpcodeFusionTest, MessOfFusibleNodes) {
auto module = CreateNewModule();
HloComputation::Builder builder(TestName());
@@ -691,8 +697,8 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name,
auto* addend = builder.AddInstruction(
HloInstruction::CreateParameter(2, dot_shape, "param2"));
- auto* dot = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
+ auto* dot =
+ builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
builder.AddInstruction(
HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend));
@@ -773,8 +779,8 @@ class GatherLoopFusionTest
TEST_P(GatherLoopFusionTest, GatherLoopFusion) {
const GatherLoopFusionTestSpec& spec = GetParam();
- string hlo_string = tensorflow::strings::StrCat(
- "HloModule ", spec.test_name, "\n\n", spec.hlo_computation_text);
+ string hlo_string = absl::StrCat("HloModule ", spec.test_name, "\n\n",
+ spec.hlo_computation_text);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseHloString(hlo_string));
@@ -792,11 +798,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
gather = s32[3,2] gather(operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
one = s32[] constant(1)
one_broadcasted = s32[3,2] broadcast(one), dimensions={}
ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted)
@@ -808,11 +814,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,3,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
one = s32[] constant(1)
one_broadcasted = s32[2,3,2] broadcast(one), dimensions={}
ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted)
@@ -824,11 +830,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=2,
- window_bounds={1, 1}
+ slice_sizes={1, 1}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
@@ -840,11 +846,11 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
@@ -856,11 +862,11 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
@@ -872,11 +878,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
gather = s32[1,1] gather(operand, indices),
- output_window_dims={0,1},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={0,1},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
one = s32[] constant(1)
one_broadcasted = s32[1,1] broadcast(one), dimensions={}
ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted)
@@ -888,11 +894,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,1,1] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
one = s32[] constant(1)
one_broadcasted = s32[2,1,1] broadcast(one), dimensions={}
ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted)
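
The gather attribute renames above track the HLO text syntax change: output_window_dims -> offset_dims, elided_window_dims -> collapsed_slice_dims, gather_dims_to_operand_dims -> start_index_map, and window_bounds -> slice_sizes. A hedged sketch of parsing a module written in the new syntax, following the ParseHloString pattern used earlier in this test:

    const char* const kHloText = R"(
    HloModule gather_example

    ENTRY main {
      operand = s32[3,3] parameter(0)
      indices = s32[2] parameter(1)
      ROOT gather = s32[3,2] gather(operand, indices),
          offset_dims={0}, collapsed_slice_dims={1}, start_index_map={1},
          index_vector_dim=1, slice_sizes={3, 1}
    }
    )";
    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                            ParseHloString(kHloText));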
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
index aa872d5ec9..bfecbd6e01 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
@@ -34,8 +34,8 @@ namespace cpu {
// instruction stream.
namespace {
-using ::tensorflow::gtl::nullopt;
-using ::tensorflow::gtl::optional;
+using absl::nullopt;
+using absl::optional;
using ShouldMakeOperandColMajorCache =
tensorflow::gtl::FlatMap<const HloInstruction*, bool>;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
index 3681d12d8d..4668f3872d 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
@@ -39,7 +40,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace op = xla::testing::opcode_matchers;
@@ -70,7 +70,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) {
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
auto result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
+ CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
@@ -107,9 +107,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) {
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
auto dot_a_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs));
+ CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs));
auto dot_b_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs));
+ CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs));
builder.AddInstruction(HloInstruction::CreateBinary(
result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result));
@@ -151,9 +151,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) {
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
auto dot_a_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs));
+ CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs));
auto dot_b_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs));
+ CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs));
auto tuple_result = builder.AddInstruction(
HloInstruction::CreateTuple({dot_a_result, dot_b_result}));
@@ -189,7 +189,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) {
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateParameter(0, rhs_shape, "param0"));
auto dot_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
+ CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
@@ -229,7 +229,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) {
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(rhs_shape, constant, 1));
auto dot_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
+ CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
@@ -276,8 +276,8 @@ static StatusOr<DotOutputFusionLayoutAssignmentResult> RunDotOutputFusion(
HloInstruction::CreateParameter(1, dot_shape, "param1"));
HloInstruction* dot_rhs = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateFromShape(dot_rhs_shape)));
- HloInstruction* dot_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
+ HloInstruction* dot_result =
+ builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
HloInstruction* add_result;
if (dot_operand_idx_in_add == 0) {
add_result = builder.AddInstruction(HloInstruction::CreateBinary(
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
index 3ed7876715..b8ace57026 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
@@ -15,8 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace {
@@ -45,17 +46,16 @@ bool VectorizedReduceDisabled(const HloModuleConfig& config) {
return extra_options_map.count(kXlaOptimizeForSizeCpuOption) > 0;
}
-tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor(
- const HloModuleConfig& config) {
+absl::optional<int64> LlvmIrGemvTilingFactor(const HloModuleConfig& config) {
const auto& extra_options_map =
config.debug_options().xla_backend_extra_options();
auto it = extra_options_map.find(kLlvmIrDotTilingFactor);
int64 tiling_factor;
if (it != extra_options_map.end() &&
- tensorflow::strings::safe_strto64(it->second, &tiling_factor)) {
+ absl::SimpleAtoi(it->second, &tiling_factor)) {
return tiling_factor;
}
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
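For reference, absl::SimpleAtoi is a drop-in replacement for tensorflow::strings::safe_strto64 here: both return false on malformed input, which is what makes the lookup-and-parse idiom above work. A minimal standalone sketch of the same idiom, with illustrative map and key names that are not from the XLA sources:

#include <cstdint>
#include <map>
#include <string>

#include "absl/strings/numbers.h"
#include "absl/types/optional.h"

// Returns the parsed value if `key` is present and holds a valid int64.
absl::optional<int64_t> ParseIntOption(
    const std::map<std::string, std::string>& options,
    const std::string& key) {
  auto it = options.find(key);
  int64_t value;
  if (it != options.end() && absl::SimpleAtoi(it->second, &value)) {
    return value;
  }
  return absl::nullopt;
}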
bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) {
@@ -64,38 +64,37 @@ bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) {
return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0;
}
-static tensorflow::StringPiece RemoveSuffix(tensorflow::StringPiece str,
- tensorflow::StringPiece suffix) {
+static absl::string_view RemoveSuffix(absl::string_view str,
+ absl::string_view suffix) {
CHECK_GE(str.size(), suffix.size());
CHECK_EQ(str.substr(str.size() - suffix.size()), suffix);
return str.substr(0, str.size() - suffix.size());
}
-tensorflow::gtl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize(
+absl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize(
const HloModuleConfig& config) {
const auto& extra_options_map =
config.debug_options().xla_backend_extra_options();
auto it = extra_options_map.find(kLlvmIrGemmTileSize);
if (it == extra_options_map.end()) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
- std::vector<string> tile_components =
- tensorflow::str_util::Split(it->second, ':');
+ std::vector<string> tile_components = absl::StrSplit(it->second, ':');
CHECK_EQ(tile_components.size(), 3);
int64 tile_size_m;
int64 tile_size_k;
int64 tile_size_n_in_vector_width;
- CHECK(tensorflow::strings::safe_strto64(tile_components[0], &tile_size_m));
- CHECK(tensorflow::strings::safe_strto64(tile_components[1], &tile_size_k));
+ CHECK(absl::SimpleAtoi(tile_components[0], &tile_size_m));
+ CHECK(absl::SimpleAtoi(tile_components[1], &tile_size_k));
- tensorflow::StringPiece tile_size_n_in_vector_width_str =
+ absl::string_view tile_size_n_in_vector_width_str =
RemoveSuffix(tile_components[2], "*vectwidth");
- CHECK(tensorflow::strings::safe_strto64(tile_size_n_in_vector_width_str,
- &tile_size_n_in_vector_width));
+ CHECK(absl::SimpleAtoi(tile_size_n_in_vector_width_str,
+ &tile_size_n_in_vector_width));
return std::tuple<int64, int64, int64>(tile_size_m, tile_size_k,
tile_size_n_in_vector_width);
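The option string parsed above has the shape "M:K:N*vectwidth"; for example "3:1:8*vectwidth" (an illustrative value, not a recommended setting) selects tile_size_m = 3, tile_size_k = 1, and a tile of eight vector widths in n. A hedged sketch of the same split/strip/parse sequence outside the emitter:

#include <cassert>
#include <cstdint>
#include <string>
#include <tuple>
#include <vector>

#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"

// Parses "M:K:N*vectwidth" into (M, K, N); like the CHECKs above, it
// treats malformed input as a hard error (via assert, in debug builds).
std::tuple<int64_t, int64_t, int64_t> ParseTileSpec(absl::string_view spec) {
  std::vector<std::string> parts = absl::StrSplit(spec, ':');
  assert(parts.size() == 3);
  int64_t m = 0, k = 0, n = 0;
  bool ok = absl::SimpleAtoi(parts[0], &m) && absl::SimpleAtoi(parts[1], &k);
  absl::string_view n_str = parts[2];
  const absl::string_view suffix = "*vectwidth";
  ok = ok && n_str.size() > suffix.size() &&
       n_str.substr(n_str.size() - suffix.size()) == suffix;
  if (ok) {
    n_str.remove_suffix(suffix.size());
    ok = absl::SimpleAtoi(n_str, &n);
  }
  assert(ok && "malformed tile-size spec");
  return {m, k, n};
}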
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h
index 429b9e16cb..47c7eb13b6 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h
@@ -27,9 +27,8 @@ namespace options {
bool OptimizeForSizeRequested(const HloModuleConfig& config);
bool VectorizedReduceDisabled(const HloModuleConfig& config);
bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config);
-tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor(
- const HloModuleConfig& config);
-tensorflow::gtl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize(
+absl::optional<int64> LlvmIrGemvTilingFactor(const HloModuleConfig& config);
+absl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize(
const HloModuleConfig& config);
} // namespace options
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 639064040f..8a44c384bb 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <functional>
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
index 2ac950e6d9..1ae3aa5711 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
@@ -19,16 +19,16 @@ limitations under the License.
#include <string>
#include <tuple>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -46,7 +46,7 @@ std::unique_ptr<Array2D<float>> MaybeTransposeArray2D(const Array2D<T>& array,
if (transpose) {
std::swap(output_width, output_height);
}
- auto output = MakeUnique<Array2D<float>>(output_height, output_width);
+ auto output = absl::make_unique<Array2D<float>>(output_height, output_width);
for (int y = 0; y < array.height(); y++) {
for (int x = 0; x < array.width(); x++) {
if (transpose) {
@@ -93,7 +93,7 @@ std::unique_ptr<Array2D<float>> EigenMatrixMultiply(const Array2D<float>& a,
  // Since we're going to transpose c before returning it, swap the order of the
  // dimension sizes to ensure the returned array is properly dimensioned.

- auto c_transpose = MakeUnique<Array2D<float>>(n, m);
+ auto c_transpose = absl::make_unique<Array2D<float>>(n, m);
if (single_threaded) {
__xla_cpu_runtime_EigenSingleThreadedMatMulF32(
nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(),
@@ -142,10 +142,10 @@ class EigenMatMulTest : public CpuRuntimeTest,
bool transpose_rhs = std::get<2>(info.param);
bool single_threaded = std::get<3>(info.param);
- return tensorflow::strings::Printf(
- "EigenMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n,
- transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "",
- single_threaded ? "single" : "multi");
+ return absl::StrFormat("EigenMatMul_%d_%d_%d_%s%s%s_threaded", shape.m,
+ shape.k, shape.n, transpose_lhs ? "Tlhs_" : "",
+ transpose_rhs ? "Trhs_" : "",
+ single_threaded ? "single" : "multi");
}
};
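A practical detail of this migration: tensorflow::strings::Printf needed the platform-sensitive %lld for int64, while absl::StrFormat type-checks its arguments at compile time, so a plain %d accepts any integer width. A tiny illustrative sketch:

#include <cstdint>
#include <string>

#include "absl/strings/str_format.h"

std::string MatMulTestName(int64_t m, int64_t k, int64_t n,
                           bool single_threaded) {
  // %d matches any integral argument; a type mismatch is a compile-time
  // error rather than silent run-time corruption.
  return absl::StrFormat("EigenMatMul_%d_%d_%d_%s_threaded", m, k, n,
                         single_threaded ? "single" : "multi");
}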
@@ -178,10 +178,10 @@ class MKLMatMulTest : public CpuRuntimeTest,
bool transpose_rhs = std::get<2>(info.param);
bool single_threaded = std::get<3>(info.param);
- return tensorflow::strings::Printf(
- "MKLMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n,
- transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "",
- single_threaded ? "single" : "multi");
+ return absl::StrFormat("MKLMatMul_%d_%d_%d_%s%s%s_threaded", shape.m,
+ shape.k, shape.n, transpose_lhs ? "Tlhs_" : "",
+ transpose_rhs ? "Trhs_" : "",
+ single_threaded ? "single" : "multi");
}
};
@@ -204,7 +204,7 @@ std::unique_ptr<Array2D<float>> MKLMatrixMultiply(const Array2D<float>& a,
// Since we're going to transpose c before returning it, swap the order of the
// dimension sizes to ensure the returned array is properly dimensioned.
- auto c_transpose = MakeUnique<Array2D<float>>(n, m);
+ auto c_transpose = absl::make_unique<Array2D<float>>(n, m);
if (single_threaded) {
__xla_cpu_runtime_MKLSingleThreadedMatMulF32(
nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(),
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
index 59bc7e0e16..5519a43b2f 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
@@ -103,7 +104,7 @@ Status CpuTransferManager::TransferLiteralToInfeed(
if (ShapeUtil::IsNestedTuple(shape)) {
return Unimplemented(
"Infeed with a nested tuple shape is not supported: %s",
- ShapeUtil::HumanString(literal.shape()).c_str());
+ ShapeUtil::HumanString(literal.shape()));
}
// For a tuple, we transfer each of its elements to the device and
@@ -151,11 +152,11 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor,
int64 size,
const void* source) {
if (size > std::numeric_limits<int32>::max()) {
- return InvalidArgument("Infeed shape is too large: needs %lld bytes", size);
+ return InvalidArgument("Infeed shape is too large: needs %d bytes", size);
}
if (size <= 0) {
- return InvalidArgument("Infeed shape must have positive size; got %lld",
+ return InvalidArgument("Infeed shape must have positive size; got %d",
size);
}
@@ -178,7 +179,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
int64 size = GetByteSizeRequirement(literal_shape);
// Note: OSS build didn't like implicit conversion from
// literal_shape.dimensions() to the array slice on 2017-07-10.
- tensorflow::gtl::ArraySlice<int64> dimensions(
+ absl::Span<const int64> dimensions(
tensorflow::bit_cast<const int64*>(literal_shape.dimensions().data()),
literal_shape.dimensions().size());
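absl::Span<const int64> is a non-owning pointer-plus-length view and is the direct replacement for tensorflow::gtl::ArraySlice throughout this change. A minimal self-contained sketch of both ways to build one (names illustrative):

#include <cstdint>
#include <iostream>
#include <vector>

#include "absl/types/span.h"

// Read-only consumers take a Span by value; it is two words, like a pointer.
void PrintDims(absl::Span<const int64_t> dims) {
  for (int64_t d : dims) std::cout << d << ' ';
  std::cout << '\n';
}

int main() {
  std::vector<int64_t> dims = {2, 3, 4};
  // A Span does not own its storage and must not outlive it.
  absl::Span<const int64_t> view(dims.data(), dims.size());
  PrintDims(view);  // explicit pointer + length, as in the outfeed code
  PrintDims(dims);  // containers also convert implicitly
}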
TF_ASSIGN_OR_RETURN(
@@ -224,7 +225,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
StatusOr<Shape> CpuTransferManager::TransferTupleBuffersFromOutfeed(
se::StreamExecutor* executor,
- tensorflow::gtl::ArraySlice<std::pair<void*, int64>> buffer_data) {
+ absl::Span<const std::pair<void*, int64>> buffer_data) {
return TransferBuffersFromOutfeedInternal(executor, buffer_data,
/*is_tuple=*/true);
}
@@ -237,18 +238,17 @@ StatusOr<Shape> CpuTransferManager::TransferArrayBufferFromOutfeed(
StatusOr<Shape> CpuTransferManager::TransferBuffersFromOutfeedInternal(
se::StreamExecutor* executor,
- tensorflow::gtl::ArraySlice<std::pair<void*, int64>> buffer_data,
- bool is_tuple) {
+ absl::Span<const std::pair<void*, int64>> buffer_data, bool is_tuple) {
std::vector<std::unique_ptr<CpuOutfeedBuffer>> buffers;
for (auto b : buffer_data) {
int64 size = b.second;
if (size > std::numeric_limits<int32>::max()) {
- return InvalidArgument("Outfeed shape is too large: needs %lld bytes",
+ return InvalidArgument("Outfeed shape is too large: needs %d bytes",
size);
}
if (size <= 0) {
- return InvalidArgument("Outfeed shape must have positive size; got %lld",
+ return InvalidArgument("Outfeed shape must have positive size; got %d",
size);
}
@@ -256,7 +256,7 @@ StatusOr<Shape> CpuTransferManager::TransferBuffersFromOutfeedInternal(
VLOG(2)
<< "Enqueueing outfeed buffer (for the device to populate) of length "
<< size_32 << "B";
- buffers.emplace_back(MakeUnique<CpuOutfeedBuffer>(b.first, size_32));
+ buffers.emplace_back(absl::make_unique<CpuOutfeedBuffer>(b.first, size_32));
}
std::vector<cpu::runtime::XfeedBuffer*> buffer_pointers;
@@ -283,7 +283,7 @@ StatusOr<Shape> CpuTransferManager::TransferBuffersFromOutfeedInternal(
} // namespace xla
static std::unique_ptr<xla::TransferManager> CreateCpuTransferManager() {
- return xla::MakeUnique<xla::CpuTransferManager>();
+ return absl::make_unique<xla::CpuTransferManager>();
}
static bool InitModule() {
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
index 80ef953d53..361d4b9c84 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
@@ -13,18 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TRANSFER_MANAGER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TRANSFER_MANAGER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_TRANSFER_MANAGER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_TRANSFER_MANAGER_H_
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h"
#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
@@ -56,7 +56,7 @@ class CpuTransferManager : public GenericTransferManager {
// Helper that transfers a tuple of element buffers from the device's outfeed.
StatusOr<Shape> TransferTupleBuffersFromOutfeed(
se::StreamExecutor* executor,
- tensorflow::gtl::ArraySlice<std::pair<void*, int64>> buffer_data);
+ absl::Span<const std::pair<void*, int64>> buffer_data);
// Helper that transfers an array buffer from the device's outfeed.
StatusOr<Shape> TransferArrayBufferFromOutfeed(se::StreamExecutor* executor,
@@ -68,12 +68,11 @@ class CpuTransferManager : public GenericTransferManager {
// for the given buffers.
StatusOr<Shape> TransferBuffersFromOutfeedInternal(
se::StreamExecutor* executor,
- tensorflow::gtl::ArraySlice<std::pair<void*, int64>> buffer_data,
- bool is_tuple);
+ absl::Span<const std::pair<void*, int64>> buffer_data, bool is_tuple);
TF_DISALLOW_COPY_AND_ASSIGN(CpuTransferManager);
};
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TRANSFER_MANAGER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_TRANSFER_MANAGER_H_
diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.cc b/tensorflow/compiler/xla/service/cpu/disassembler.cc
index e4c674e227..3ae64142cd 100644
--- a/tensorflow/compiler/xla/service/cpu/disassembler.cc
+++ b/tensorflow/compiler/xla/service/cpu/disassembler.cc
@@ -21,13 +21,13 @@ limitations under the License.
#include <type_traits>
#include <vector>
+#include "absl/strings/str_format.h"
#include "llvm/MC/MCInst.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -151,7 +151,7 @@ StatusOr<DisassemblerResult> Disassembler::DisassembleObjectFile(
size = 1;
}
- ostream << tensorflow::strings::Printf("0x%08lx", index) << " ";
+ ostream << absl::StrFormat("0x%08lx", index) << " ";
if (decode_status == llvm::MCDisassembler::Success) {
// For branches, try to determine the actual address and emit it as an
@@ -163,7 +163,7 @@ StatusOr<DisassemblerResult> Disassembler::DisassembleObjectFile(
uint64_t target;
if (inst_analysis_->evaluateBranch(
instruction, section_address + index, size, target)) {
- annotation = tensorflow::strings::Printf("[0x%08lx]", target);
+ annotation = absl::StrFormat("[0x%08lx]", target);
}
}
inst_printer_->printInst(&instruction, ostream, annotation.c_str(),
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index f2ac742b6e..99fa707c95 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
@@ -79,7 +80,7 @@ class MemoryTile {
// `minor_dim_offset`}.
//
// Note: `major_dim_offset` is a parameter to the constructor.
- void StoreTile(tensorflow::gtl::ArraySlice<llvm::Value*> tile,
+ void StoreTile(absl::Span<llvm::Value* const> tile,
llvm::Value* minor_dim_offset) const {
CHECK_EQ(tile.size(), pointers_.size());
for (int64 i = 0; i < pointers_.size(); i++) {
@@ -146,9 +147,9 @@ class GemvConfig {
bool has_addend() const { return has_addend_; }
string GetCacheKey() const {
- return tensorflow::strings::StrCat(
- name_, "_", PrimitiveType_Name(scalar_type()), "_", tile_rows(), "_",
- tile_cols(), "_", m(), "_", k(), has_addend() ? "_with_addend" : "");
+ return absl::StrCat(name_, "_", PrimitiveType_Name(scalar_type()), "_",
+ tile_rows(), "_", tile_cols(), "_", m(), "_", k(),
+ has_addend() ? "_with_addend" : "");
}
protected:
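The cache key built here is what lets structurally identical GEMV kernels be emitted once and reused: every field that affects the generated IR is folded into one string. Restated as a free function (field names as above; the example output in the comment is purely illustrative):

#include <cstdint>
#include <string>

#include "absl/strings/str_cat.h"

// Produces keys such as "gemv_F32_8_6_1024_1024_with_addend"; configs that
// agree on every field map to the same key and thus share one kernel.
std::string GemvCacheKey(const std::string& name, const std::string& type,
                         int64_t tile_rows, int64_t tile_cols, int64_t m,
                         int64_t k, bool has_addend) {
  return absl::StrCat(name, "_", type, "_", tile_rows, "_", tile_cols, "_", m,
                      "_", k, has_addend ? "_with_addend" : "");
}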
@@ -621,19 +622,19 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
}
// This class implements a tiled matrix multiplication algorithm, intended for
-// use as the innermost GEBP loop in a GEMM kernel (GEBP is described in "Goto,
-// Kazushige, and Robert Van De Geijn. "High-performance implementation of the
-// level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 (2008):
-// 4).
+// multiplying small matrices that don't need cache tiling.
+//
+// In the future this can be used as the innermost GEBP loop in a GEMM kernel as
+// described in "Goto, Kazushige, and Robert A. Geijn. "Anatomy of
+// high-performance matrix multiplication." ACM Transactions on Mathematical
+// Software (TOMS) 34.3 (2008): 12.".
//
// This only supports canonical dot operations (i.e. where the lhs contraction
// dimension is 1 and the rhs contraction dimension is 0) over row major
// matrices.
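As a scalar reference for what this emitter computes (setting aside the vectorization, tiling, and residue handling described below), a canonical dot over dense row-major buffers is just the textbook triple loop. A sketch under those assumptions:

#include <cstdint>

// result is m x n, lhs is m x k, rhs is k x n, all dense and row-major,
// with the lhs contracting on its second dimension and the rhs on its first.
void ReferenceCanonicalDot(int64_t m, int64_t k, int64_t n, const float* lhs,
                           const float* rhs, float* result) {
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int64_t p = 0; p < k; ++p) {
        acc += lhs[i * k + p] * rhs[p * n + j];
      }
      result[i * n + j] = acc;
    }
  }
}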
-class MatrixMatrixBlockPanelEmitter {
+class TiledSmallGemmEmitter {
public:
- // Describe the dimensions of the GEBP kernel. These will usually not be the
- // dimensions of the GEMM itself, the GEMM will usually be broken up into GEBP
- // kernels with smaller dimensions.
+ // Describe the dimensions of the kernel.
class Dimensions {
public:
explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {}
@@ -642,9 +643,7 @@ class MatrixMatrixBlockPanelEmitter {
int64 k() const { return k_; }
int64 n() const { return n_; }
- string ToString() const {
- return tensorflow::strings::StrCat(m(), "x", k(), "x", n());
- }
+ string ToString() const { return absl::StrCat(m(), "x", k(), "x", n()); }
private:
const int64 m_;
@@ -652,9 +651,9 @@ class MatrixMatrixBlockPanelEmitter {
const int64 n_;
};
- // Represents the configuration of the GEBP emitter. The LLVM IR emitted by
- // the emitter, modulo the LLVM values holding the input and output buffers,
- // must be a function of the instance of `Config` passed to it.
+ // Represents the configuration of the emitter. The LLVM IR emitted by the
+ // emitter, modulo the LLVM values holding the input and output buffers, must
+ // be a function of the instance of `Config` passed to it.
//
// `dims` holds the matrix multiplication dimensions.
//
@@ -687,10 +686,10 @@ class MatrixMatrixBlockPanelEmitter {
tile_size_k_(tile_size_k) {}
string GetCacheKey() const {
- return tensorflow::strings::StrCat(
- "gebp_", PrimitiveType_Name(scalar_type()), "_", dims().ToString(),
- "_", max_vectorization_width(), "_", min_vectorization_width(), "_",
- tile_size_m(), "_", tile_size_k());
+ return absl::StrCat("gemm_", PrimitiveType_Name(scalar_type()), "_",
+ dims().ToString(), "_", max_vectorization_width(),
+ "_", min_vectorization_width(), "_", tile_size_m(),
+ "_", tile_size_k());
}
PrimitiveType scalar_type() const { return scalar_type_; }
@@ -712,11 +711,11 @@ class MatrixMatrixBlockPanelEmitter {
int64 tile_size_k_;
};
- // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies
+ // Creates an instance of TiledSmallGemmEmitter that matrix-multiplies
// `lhs` with `rhs` and stores the result in `result`.
- explicit MatrixMatrixBlockPanelEmitter(Config config, llvm::Value* lhs,
- llvm::Value* rhs, llvm::Value* result,
- llvm::IRBuilder<>* b)
+ explicit TiledSmallGemmEmitter(Config config, llvm::Value* lhs,
+ llvm::Value* rhs, llvm::Value* result,
+ llvm::IRBuilder<>* b)
: lhs_(lhs),
rhs_(rhs),
result_(result),
@@ -780,9 +779,9 @@ class MatrixMatrixBlockPanelEmitter {
KernelSupportLibrary ksl_;
};
-void MatrixMatrixBlockPanelEmitter::Emit() { HandleResiduesOnN(); }
+void TiledSmallGemmEmitter::Emit() { HandleResiduesOnN(); }
-void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() {
+void TiledSmallGemmEmitter::HandleResiduesOnN() {
// We can only iterate the `n` dimension for an extent that is divisible by
// the vectorization width. So we emit an outer loop that first processes the
// largest extent in `n` that is divisible by max_vectorization_width, then
@@ -799,7 +798,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() {
int64 n_end = dims().n() - (dims().n() % current_vectorization_width);
if (n_start != n_end) {
VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_,
- "gebp");
+ "gemm");
HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end));
n_start = n_end;
}
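Concretely, each pass claims the largest prefix of the remaining n extent that is a multiple of the current vector width, then the width drops for the residue. Assuming the width halves between passes down to the scalar epilogue (consistent with the structure shown here), n = 23 with max width 8 splits into [0, 16) at width 8, [16, 20) at width 4, [20, 22) at width 2, and [22, 23) at width 1. A sketch of that split:

#include <cstdint>
#include <iostream>

// Prints the [start, end) ranges and the vector width used for each,
// mirroring how HandleResiduesOnN peels residues off the n dimension.
void PrintResidueSplit(int64_t n, int64_t max_width) {
  int64_t start = 0;
  for (int64_t w = max_width; w >= 1; w /= 2) {
    int64_t end = n - n % w;  // largest multiple of w not exceeding n
    if (start != end) {
      std::cout << '[' << start << ", " << end << ") at width " << w << '\n';
      start = end;
    }
  }
}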
@@ -813,7 +812,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() {
}
if (n_start != dims().n()) {
- VectorSupportLibrary vsl(scalar_type(), 1, b_, "gebp");
+ VectorSupportLibrary vsl(scalar_type(), 1, b_, "gemm");
ksl_.ForReturnVoid("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) {
llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1));
HandleResiduesOnK(&vsl, n_i, n_i_next);
@@ -821,9 +820,9 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() {
}
}
-void MatrixMatrixBlockPanelEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl,
- llvm::Value* n_start,
- llvm::Value* n_end) {
+void TiledSmallGemmEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl,
+ llvm::Value* n_start,
+ llvm::Value* n_end) {
int64 k_start = 0;
int64 k_end = dims().k() - (dims().k() % tile_size_k());
if (k_end != k_start) {
@@ -838,7 +837,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl,
}
}
-void MatrixMatrixBlockPanelEmitter::HandleResiduesOnM(
+void TiledSmallGemmEmitter::HandleResiduesOnM(
VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) {
const int64 m_end = dims().m() - dims().m() % tile_size_m();
@@ -921,7 +920,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnM(
// +-------------------+-------------------+-------------------+---------
// | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ...
// +-------------------+-------------------+-------------------+---------
-void MatrixMatrixBlockPanelEmitter::EmitTiledGemm(
+void TiledSmallGemmEmitter::EmitTiledGemm(
VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end,
int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) {
@@ -1001,12 +1000,22 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot,
return dot_emitter.Emit();
}
-bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
+bool DotOpEmitter::EmitSmallGemmIfProfitable(
const DotOpEmitter::MatMultDims& mat_mult_dims) {
- if (!EnableExperimentalLlvmIrGemm() || ShouldUseMultiThreadedEigen()) {
+ if (ShouldUseMultiThreadedEigen()) {
return false;
}
+ if (!EnableExperimentalLlvmIrGemm()) {
+ // TODO(sanjoy): We should make these numbers micro-arch specific.
+ bool small_gemm = mat_mult_dims.k <= 128 &&
+ ((mat_mult_dims.m <= 32 && mat_mult_dims.n <= 128) ||
+ (mat_mult_dims.m <= 128 && mat_mult_dims.n <= 32));
+ if (!small_gemm) {
+ return false;
+ }
+ }
+
if (mat_mult_dims.lhs_non_canonical || mat_mult_dims.rhs_non_canonical) {
return false;
}
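Absent the experimental flag, the new gate only fires for skinny products: the contracted dimension k stays at or below 128 and at least one output dimension is narrow. The same thresholds as a standalone predicate (which, per the TODO above, are not yet tuned per micro-architecture):

#include <cstdint>

// True when a single-threaded in-IR kernel is considered profitable:
// small k, and either few rows with a modest n or few columns with a
// modest m.
bool IsSmallGemm(int64_t m, int64_t k, int64_t n) {
  return k <= 128 && ((m <= 32 && n <= 128) || (m <= 128 && n <= 32));
}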
@@ -1054,15 +1063,15 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) =
GetGemmTileSize();
- MatrixMatrixBlockPanelEmitter::Config config(
+ TiledSmallGemmEmitter::Config config(
/*scalar_type=*/primitive_type,
- MatrixMatrixBlockPanelEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n},
+ TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n},
/*max_vectorization_width=*/max_target_vector_width,
/*max_vector_count=*/tile_size_n_in_vector_width,
/*min_vectorization_width=*/std::min<int64>(4, max_target_vector_width),
/*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k);
- VLOG(2) << "Emitting GEBP kernel in LLVM IR with config "
+ VLOG(2) << "Emitting GEMM kernel in LLVM IR with config "
<< config.GetCacheKey();
const bool enable_fast_math =
@@ -1075,10 +1084,10 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
/*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), lhs,
rhs, target,
[this, config](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* target) {
- MatrixMatrixBlockPanelEmitter gebp_emitter(config, /*lhs=*/lhs,
- /*rhs=*/rhs,
- /*result=*/target, b_);
- gebp_emitter.Emit();
+ TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs,
+ /*rhs=*/rhs,
+ /*result=*/target, b_);
+ small_gemm_emitter.Emit();
});
return true;
@@ -1136,7 +1145,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
}
if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) {
- return EmitExperimentalGebpDotIfEnabled(mat_mult_dims);
+ return EmitSmallGemmIfProfitable(mat_mult_dims);
}
int64 tiling_factor = GetGemvTilingFactor();
@@ -1458,7 +1467,7 @@ Status DotOpEmitter::EmitCallToRuntime() {
break;
default:
return Unimplemented("Invalid type %s for dot operation",
- PrimitiveType_Name(type).c_str());
+ PrimitiveType_Name(type));
}
llvm::Type* float_ptr_type = float_type->getPointerTo();
@@ -1610,7 +1619,7 @@ bool PotentiallyImplementedAsEigenDot(
// For vector-matrix dot products, it is always profitable to make the Rhs
// column major.
-tensorflow::gtl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
+absl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
const HloInstruction& hlo) {
if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() == 2 &&
hlo.shape().dimensions(0) == 1) {
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
index 590032fbe9..4c2041b556 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_
+#include "absl/strings/string_view.h"
#include "llvm/IR/IRBuilder.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
@@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -38,7 +38,7 @@ bool PotentiallyImplementedAsEigenDot(
// Returns the index for an operand to `hlo` that should ideally be column
// major. Returns nullopt if there is no such operand or if `hlo` is not a dot
// or a fusion containing a dot.
-tensorflow::gtl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
+absl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
const HloInstruction& hlo);
// Returns true to indicate that we can generate a tiled LLVM IR implementation
@@ -121,7 +121,7 @@ class DotOpEmitter {
// of rank 2 as well).
MatMultDims GetMatMultDims() const;
- bool EmitExperimentalGebpDotIfEnabled(const MatMultDims& mat_mult_dims);
+ bool EmitSmallGemmIfProfitable(const MatMultDims& mat_mult_dims);
// When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector
// registers.
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index db54454707..c8312d80bd 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -30,15 +30,16 @@ limitations under the License.
namespace xla {
namespace cpu {
-StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(
- PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const {
+StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
+ llvm::Value* lhs,
+ llvm::Value* rhs) {
string function_name;
bool cast_result_to_fp16 = false;
switch (prim_type) {
case F16:
cast_result_to_fp16 = true;
- lhs = b_->CreateFPCast(lhs, b_->getFloatTy());
- rhs = b_->CreateFPCast(rhs, b_->getFloatTy());
+ lhs = FPCast(lhs, b_->getFloatTy());
+ rhs = FPCast(rhs, b_->getFloatTy());
TF_FALLTHROUGH_INTENDED;
case F32:
function_name = "atan2f";
@@ -58,21 +59,21 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(
function->setDoesNotThrow();
function->setDoesNotAccessMemory();
// Create an instruction to call the function.
- llvm::Value* result = b_->CreateCall(function, {lhs, rhs});
+ llvm::Value* result = Call(function, {lhs, rhs});
if (cast_result_to_fp16) {
- result = b_->CreateFPCast(result, b_->getHalfTy());
+ result = FPCast(result, b_->getHalfTy());
}
return result;
}
-StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
+ llvm::Value* value) {
bool cast_result_to_fp16 = false;
string function_name;
switch (prim_type) {
case F16:
cast_result_to_fp16 = true;
- value = b_->CreateFPCast(value, b_->getFloatTy());
+ value = FPCast(value, b_->getFloatTy());
TF_FALLTHROUGH_INTENDED;
case F32:
function_name = "tanhf";
@@ -91,16 +92,16 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(
function->setDoesNotThrow();
function->setDoesNotAccessMemory();
// Create an instruction to call the function.
- llvm::Value* result = b_->CreateCall(function, value);
+ llvm::Value* result = Call(function, value);
if (cast_result_to_fp16) {
- result = b_->CreateFPCast(result, b_->getHalfTy());
+ result = FPCast(result, b_->getHalfTy());
}
return result;
}
llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) const {
+ const HloToElementGeneratorMap& operand_to_generator) {
if (hlo->opcode() == HloOpcode::kMap) {
return [this, hlo, &operand_to_generator](
const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
index 76833e765d..e3fba9306b 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
@@ -36,13 +36,13 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
llvm_ir::ElementGenerator MakeElementGenerator(
const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) const override;
+ const HloToElementGeneratorMap& operand_to_generator) override;
protected:
StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
- llvm::Value* rhs) const override;
+ llvm::Value* rhs) override;
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
IrEmitter* ir_emitter_;
};
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 6f433b4f30..df8c2a636b 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -27,6 +27,9 @@ limitations under the License.
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/types/span.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
@@ -64,11 +67,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
@@ -100,12 +100,17 @@ IrEmitter::IrEmitter(
b_.setFastMathFlags(llvm_ir::GetFastMathFlags(
/*fast_math_enabled=*/hlo_module_config_.debug_options()
.xla_cpu_enable_fast_math()));
+ Status s = GatherComputationsByAllocationType(
+ &hlo_module, &thread_local_computations_, &global_computations_);
+ absl::c_sort(thread_local_computations_);
+ absl::c_sort(global_computations_);
+ TF_CHECK_OK(s) << "Should have failed buffer assignment.";
}
StatusOr<llvm::Function*> IrEmitter::EmitComputation(
HloComputation* computation, const string& function_name_prefix,
bool is_top_level_computation,
- std::vector<const HloInstruction*>* instruction_order) {
+ const std::vector<const HloInstruction*>* instruction_order) {
string function_name = name_uniquer_.GetUniqueName(function_name_prefix);
VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix
<< "]; ordered? " << (instruction_order != nullptr);
@@ -170,9 +175,9 @@ IrEmitter::~IrEmitter() {}
Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
VLOG(2) << "HandleBitcast: " << bitcast->ToString();
emitted_value_[bitcast] =
- b_.CreateBitCast(GetEmittedValueFor(bitcast->operand(0)),
- IrShapeType(bitcast->shape())->getPointerTo(),
- AsStringRef(IrName(bitcast)));
+ BitCast(GetEmittedValueFor(bitcast->operand(0)),
+ IrShapeType(bitcast->shape())->getPointerTo(),
+ AsStringRef(IrName(bitcast)));
return Status::OK();
}
@@ -230,9 +235,8 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) {
// Use the elemental emitter for array shapes.
return DefaultAction(copy);
}
- return Unimplemented(
- "unsupported operand type %s for copy instruction",
- PrimitiveType_Name(copy->shape().element_type()).c_str());
+ return Unimplemented("unsupported operand type %s for copy instruction",
+ PrimitiveType_Name(copy->shape().element_type()));
}
// Calculate the alignment of a buffer allocated for a given primitive type.
@@ -338,10 +342,10 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
// Write the tuple index table.
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
assignment_.GetUniqueSlice(infeed, {0}));
- llvm::Value* data_address = EmitTempBufferPointer(data_slice, data_shape);
+ llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape);
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice,
assignment_.GetUniqueSlice(infeed, {1}));
- llvm::Value* token_address = EmitTempBufferPointer(
+ llvm::Value* token_address = EmitBufferPointer(
token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1));
llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_,
module_);
@@ -364,9 +368,9 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
// Only the outer tuple buffer's target address is obtained from
// GetEmittedValueFor, to handle the case when Infeed is the root
// instruction. Target addresses for internal elements can be obtained
- // from EmitTempBufferPointer.
+ // from EmitBufferPointer.
llvm::Value* tuple_element_address =
- EmitTempBufferPointer(buffer, tuple_element_shape);
+ EmitBufferPointer(buffer, tuple_element_shape);
TF_RETURN_IF_ERROR(EmitXfeedTransfer(
XfeedKind::kInfeed, tuple_element_shape, tuple_element_address));
@@ -389,7 +393,7 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
int64 length = ByteSizeOf(shape);
if (length <= 0 || length > std::numeric_limits<int32>::max()) {
return InvalidArgument(
- "xfeed (infeed or outfeed) buffer length %lld is outside the valid "
+ "xfeed (infeed or outfeed) buffer length %d is outside the valid "
"size range",
length);
}
@@ -440,27 +444,33 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
// of size exactly 'length_32', and the runtime is responsible for
// check-failing the process if there is a mismatch, versus passing us back a
// buffer that we might overrun.
- llvm::Value* acquired_pointer = b_.CreateCall(
- acquire_func,
- {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)});
+ llvm::Value* acquired_pointer =
+ Call(acquire_func,
+ {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)});
if (kind == XfeedKind::kInfeed) {
// Copy to the program buffer address from the acquired buffer.
- b_.CreateMemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer,
- /*SrcAlign=*/1, length_32);
+ MemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer,
+ /*SrcAlign=*/1, length_32);
} else {
// Outfeed -- copy from the in-program address to the acquired buffer.
- b_.CreateMemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address,
- /*SrcAlign=*/1, length_32);
+ MemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address,
+ /*SrcAlign=*/1, length_32);
}
- b_.CreateCall(release_func, {b_.getInt32(length_32), acquired_pointer,
- shape_ptr, b_.getInt32(shape_length)});
+ Call(release_func, {b_.getInt32(length_32), acquired_pointer, shape_ptr,
+ b_.getInt32(shape_length)});
return Status::OK();
}
Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
+ // Outfeed produces no useful result, but it does return a token[] that can be
+  // threaded through to other side-effecting operations to ensure ordering. In
+ // the IR emitter we treat this token as a normal u8[] and thus need to insert
+ // an entry for it in emitted_value_.
+ TF_RETURN_IF_ERROR(EmitTargetAddressForOp(outfeed));
+
HloInstruction* operand = outfeed->operands()[0];
const Shape& operand_shape = operand->shape();
@@ -501,8 +511,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
llvm::Value* IrEmitter::EmitElementalMap(
const HloMapInstruction& map_instr,
- tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
- tensorflow::StringPiece name) {
+ absl::Span<llvm::Value* const> elemental_operands, absl::string_view name) {
return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name);
}
@@ -519,8 +528,8 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
"reduce_window_accumulator_address", &b_,
MinimumAlignmentForPrimitiveType(operand_element_type));
- b_.CreateStore(b_.CreateLoad(GetEmittedValueFor(reduce_window->operand(1))),
- accumulator_address);
+ Store(Load(GetEmittedValueFor(reduce_window->operand(1))),
+ accumulator_address);
llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &b_);
std::vector<int64> window_size;
@@ -537,22 +546,21 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
llvm::Value* in_bounds_condition = nullptr;
for (size_t i = 0; i < index.size(); ++i) {
llvm::Value* strided_index =
- b_.CreateNSWMul(index[i], b_.getInt64(window.dimensions(i).stride()));
- input_index[i] =
- b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]),
- b_.getInt64(window.dimensions(i).padding_low()));
+ NSWMul(index[i], b_.getInt64(window.dimensions(i).stride()));
+ input_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]),
+ b_.getInt64(window.dimensions(i).padding_low()));
    // We need to check if 0 <= input_index[i] < bound, as otherwise we are in
    // the padding and can skip the computation. That is equivalent to
// input_index[i] < bound as an *unsigned* comparison, since a negative
// value will wrap to a large positive value.
- llvm::Value* index_condition = b_.CreateICmpULT(
- input_index[i],
- b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
+ llvm::Value* index_condition =
+ ICmpULT(input_index[i],
+ b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
if (in_bounds_condition == nullptr) {
in_bounds_condition = index_condition;
} else {
- in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition);
+ in_bounds_condition = And(in_bounds_condition, index_condition);
}
}
CHECK(in_bounds_condition != nullptr);
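The unsigned-comparison trick in the comment above deserves spelling out: casting a signed index to unsigned maps every negative value beyond any valid bound, so one unsigned less-than implements both halves of 0 <= i && i < bound (given bound >= 0, which holds for shape dimensions). A minimal illustration:

#include <cassert>
#include <cstdint>

// Negative indices wrap to huge unsigned values and fail `< bound` exactly
// like indices at or past the bound.
bool InBounds(int64_t index, int64_t bound) {
  return static_cast<uint64_t>(index) < static_cast<uint64_t>(bound);
}

int main() {
  assert(InBounds(0, 4));
  assert(InBounds(3, 4));
  assert(!InBounds(-1, 4));  // -1 wraps to 2^64 - 1
  assert(!InBounds(4, 4));
}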
@@ -565,12 +573,12 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
llvm_ir::IrArray input_array(GetIrArrayFor(operand));
llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_);
llvm::Value* result = EmitThreadLocalCall(
- *reduce_window->to_apply(),
- {b_.CreateLoad(accumulator_address), input_value}, "reducer_function");
- b_.CreateStore(result, accumulator_address);
+ *reduce_window->to_apply(), {Load(accumulator_address), input_value},
+ "reducer_function");
+ Store(result, accumulator_address);
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
- return b_.CreateLoad(accumulator_address);
+ return Load(accumulator_address);
}
Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
@@ -647,7 +655,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"),
[this, init_value](const llvm_ir::IrArray::Index& target_index) {
llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
- return b_.CreateLoad(init_value_addr);
+ return Load(init_value_addr);
}));
// Create a loop to iterate over the source array to scatter to the output.
@@ -667,7 +675,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
b_.getInt64Ty(), b_.getInt32(rank), "selected_index_address", &b_);
llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
b_.getInt1Ty(), "initialized_flag_address", &b_);
- b_.CreateStore(b_.getInt1(false), initialized_flag_address);
+ Store(b_.getInt1(false), initialized_flag_address);
// Create the inner loop to iterate over the window.
llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), &b_);
@@ -685,15 +693,14 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
llvm_ir::IrArray::Index operand_index(b_.getInt64Ty(), source_index.size());
llvm::Value* in_bounds_condition = b_.getTrue();
for (int64 i = 0; i < rank; ++i) {
- llvm::Value* strided_index = b_.CreateNSWMul(
- source_index[i], b_.getInt64(window.dimensions(i).stride()));
- operand_index[i] =
- b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]),
- b_.getInt64(window.dimensions(i).padding_low()));
- llvm::Value* index_condition = b_.CreateICmpULT(
- operand_index[i],
- b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
- in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition);
+ llvm::Value* strided_index =
+ NSWMul(source_index[i], b_.getInt64(window.dimensions(i).stride()));
+ operand_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]),
+ b_.getInt64(window.dimensions(i).padding_low()));
+ llvm::Value* index_condition =
+ ICmpULT(operand_index[i],
+ b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
+ in_bounds_condition = And(in_bounds_condition, index_condition);
}
CHECK(in_bounds_condition != nullptr);
@@ -703,7 +710,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
- b_.CreateLoad(initialized_flag_address), "initialized", &b_);
+ Load(initialized_flag_address), "initialized", &b_);
// If the initialized_flag is false, initialize the selected value and index
// with the currently visiting operand.
@@ -712,38 +719,37 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
[&](const llvm_ir::IrArray::Index& operand_index) {
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot =
- b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)});
- b_.CreateStore(operand_index[i], selected_index_address_slot);
+ InBoundsGEP(selected_index_address, {b_.getInt32(i)});
+ Store(operand_index[i], selected_index_address_slot);
}
};
llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
llvm::Value* operand_data =
operand_array.EmitReadArrayElement(operand_index, &b_);
- b_.CreateStore(operand_data, selected_value_address);
+ Store(operand_data, selected_value_address);
save_operand_index(operand_index);
- b_.CreateStore(b_.getInt1(true), initialized_flag_address);
+ Store(b_.getInt1(true), initialized_flag_address);
// If the initialized_flag is true, call the `select` function to potentially
// update the selected value and index with the currently visiting operand.
SetToFirstInsertPoint(if_initialized.true_block, &b_);
llvm::Value* operand_address =
operand_array.EmitArrayElementAddress(operand_index, &b_);
- llvm::Value* operand_element = b_.CreateLoad(operand_address);
+ llvm::Value* operand_element = Load(operand_address);
llvm::Value* result = EmitThreadLocalCall(
*select_and_scatter->select(),
- {b_.CreateLoad(selected_value_address), operand_element},
- "select_function");
+ {Load(selected_value_address), operand_element}, "select_function");
// If the 'select' function returns false, update the selected value and the
// index to the currently visiting operand.
- llvm::Value* cond = b_.CreateICmpNE(
+ llvm::Value* cond = ICmpNE(
result,
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
"boolean_predicate");
llvm_ir::LlvmIfData if_select_lhs =
llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
- b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address);
+ Store(Load(operand_address), selected_value_address);
save_operand_index(operand_index);
// After iterating over the window elements, scatter the source element to
@@ -754,8 +760,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
llvm_ir::IrArray::Index selected_index(source_index.GetType());
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot =
- b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)});
- selected_index.push_back(b_.CreateLoad(selected_index_address_slot));
+ InBoundsGEP(selected_index_address, {b_.getInt32(i)});
+ selected_index.push_back(Load(selected_index_address_slot));
}
llvm_ir::IrArray source_array(GetIrArrayFor(source));
llvm::Value* source_value =
@@ -837,7 +843,7 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
lhs_llvm_type, "convolution_sum_address", &b_,
MinimumAlignmentForPrimitiveType(lhs_element_type));
llvm::Value* constant_zero = llvm::Constant::getNullValue(lhs_llvm_type);
- b_.CreateStore(constant_zero, sum_address);
+ Store(constant_zero, sum_address);
llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_);
std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
@@ -846,7 +852,7 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
loops
.AddLoop(
0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
- tensorflow::strings::StrCat("k", i))
+ absl::StrCat("k", i))
->GetIndVarValue();
}
llvm::Value* input_feature =
@@ -864,11 +870,11 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
llvm::Value* kernel_index,
const WindowDimension& window_dim) {
llvm::Value* strided_index =
- b_.CreateNSWMul(output_index, b_.getInt64(window_dim.stride()));
- llvm::Value* dilated_kernel_index = b_.CreateNSWMul(
- kernel_index, b_.getInt64(window_dim.window_dilation()));
- return b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, dilated_kernel_index),
- b_.getInt64(window_dim.padding_low()));
+ NSWMul(output_index, b_.getInt64(window_dim.stride()));
+ llvm::Value* dilated_kernel_index =
+ NSWMul(kernel_index, b_.getInt64(window_dim.window_dilation()));
+ return NSWSub(NSWAdd(strided_index, dilated_kernel_index),
+ b_.getInt64(window_dim.padding_low()));
};
std::vector<llvm::Value*> input_spatial(num_spatial_dims);
for (int i = 0; i < num_spatial_dims; ++i) {
@@ -885,9 +891,8 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
// Also need to check that the input coordinates are not in one of the
// holes created by base dilation.
const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) {
- llvm::Value* remainder =
- b_.CreateSRem(input_index, b_.getInt64(base_dilation));
- return b_.CreateICmpEQ(remainder, b_.getInt64(0));
+ llvm::Value* remainder = SRem(input_index, b_.getInt64(base_dilation));
+ return ICmpEQ(remainder, b_.getInt64(0));
};
llvm::Value* in_bounds_condition = b_.getInt1(true);
@@ -895,17 +900,17 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
llvm::ConstantInt* input_bound = b_.getInt64(window_util::DilatedBound(
lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
window.dimensions(i).base_dilation()));
- llvm::Value* dim_in_bound = b_.CreateICmpULT(input_spatial[i], input_bound);
+ llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound);
llvm::Value* dim_not_in_hole =
not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
- llvm::Value* dim_ok = b_.CreateAnd(dim_in_bound, dim_not_in_hole);
- in_bounds_condition = b_.CreateAnd(in_bounds_condition, dim_ok);
+ llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole);
+ in_bounds_condition = And(in_bounds_condition, dim_ok);
}
// Now we need to map the dilated base coordinates back to the actual
// data indices on the lhs.
const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) {
- return b_.CreateSDiv(input_index, b_.getInt64(base_dilation));
+ return SDiv(input_index, b_.getInt64(base_dilation));
};
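Base dilation conceptually inserts base_dilation - 1 holes between adjacent input elements, so in the dilated coordinate space only multiples of base_dilation carry data; with base_dilation = 3, dilated positions 0, 3, 6 map back to data indices 0, 1, 2. The two lambdas above, restated on plain integers:

#include <cstdint>

// A dilated coordinate carries real data only when it lands on a multiple
// of the dilation factor; anything else falls in a zero-padding hole.
bool NotInHole(int64_t dilated_index, int64_t base_dilation) {
  return dilated_index % base_dilation == 0;
}

// Maps a non-hole dilated coordinate back to the undilated data index.
int64_t Undilate(int64_t dilated_index, int64_t base_dilation) {
  return dilated_index / base_dilation;
}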
for (int i = 0; i < num_spatial_dims; ++i) {
input_spatial[i] =
@@ -930,8 +935,8 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
for (int i = 0; i < num_spatial_dims; ++i) {
kernel_index[dnums.kernel_spatial_dimensions(i)] =
window.dimensions(i).window_reversal()
- ? b_.CreateNSWSub(b_.getInt64(window.dimensions(i).size() - 1),
- kernel_spatial[i])
+ ? NSWSub(b_.getInt64(window.dimensions(i).size() - 1),
+ kernel_spatial[i])
: kernel_spatial[i];
}
@@ -940,13 +945,13 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
llvm_ir::IrArray input_array(GetIrArrayFor(lhs));
llvm::Value* product =
- b_.CreateFMul(input_array.EmitReadArrayElement(input_index, &b_),
- kernel_array.EmitReadArrayElement(kernel_index, &b_));
- llvm::Value* sum = b_.CreateFAdd(b_.CreateLoad(sum_address), product);
- b_.CreateStore(sum, sum_address);
+ FMul(input_array.EmitReadArrayElement(input_index, &b_),
+ kernel_array.EmitReadArrayElement(kernel_index, &b_));
+ llvm::Value* sum = FAdd(Load(sum_address), product);
+ Store(sum, sum_address);
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
- return b_.CreateLoad(sum_address);
+ return Load(sum_address);
}
Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
@@ -1072,34 +1077,32 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
conv_func->setCallingConv(llvm::CallingConv::C);
conv_func->setDoesNotThrow();
conv_func->setOnlyAccessesArgMemory();
- b_.CreateCall(
- conv_func,
- {
- GetExecutableRunOptionsArgument(),
- b_.CreateBitCast(GetEmittedValueFor(convolution), ir_ptr_type),
- b_.CreateBitCast(lhs_address, ir_ptr_type),
- b_.CreateBitCast(rhs_address, ir_ptr_type),
- b_.getInt64(input_batch),
- b_.getInt64(input_rows),
- b_.getInt64(input_cols),
- b_.getInt64(input_channels),
- b_.getInt64(kernel_rows),
- b_.getInt64(kernel_cols),
- b_.getInt64(kernel_channels),
- b_.getInt64(kernel_filters),
- b_.getInt64(output_rows),
- b_.getInt64(output_cols),
- b_.getInt64(row_stride),
- b_.getInt64(col_stride),
- b_.getInt64(padding_top),
- b_.getInt64(padding_bottom),
- b_.getInt64(padding_left),
- b_.getInt64(padding_right),
- b_.getInt64(lhs_row_dilation),
- b_.getInt64(lhs_col_dilation),
- b_.getInt64(rhs_row_dilation),
- b_.getInt64(rhs_col_dilation),
- });
+ Call(conv_func, {
+ GetExecutableRunOptionsArgument(),
+ BitCast(GetEmittedValueFor(convolution), ir_ptr_type),
+ BitCast(lhs_address, ir_ptr_type),
+ BitCast(rhs_address, ir_ptr_type),
+ b_.getInt64(input_batch),
+ b_.getInt64(input_rows),
+ b_.getInt64(input_cols),
+ b_.getInt64(input_channels),
+ b_.getInt64(kernel_rows),
+ b_.getInt64(kernel_cols),
+ b_.getInt64(kernel_channels),
+ b_.getInt64(kernel_filters),
+ b_.getInt64(output_rows),
+ b_.getInt64(output_cols),
+ b_.getInt64(row_stride),
+ b_.getInt64(col_stride),
+ b_.getInt64(padding_top),
+ b_.getInt64(padding_bottom),
+ b_.getInt64(padding_left),
+ b_.getInt64(padding_right),
+ b_.getInt64(lhs_row_dilation),
+ b_.getInt64(lhs_col_dilation),
+ b_.getInt64(rhs_row_dilation),
+ b_.getInt64(rhs_col_dilation),
+ });
return Status::OK();
}
@@ -1159,15 +1162,14 @@ Status IrEmitter::HandleFft(HloInstruction* fft) {
fft_func->setDoesNotThrow();
fft_func->setOnlyAccessesInaccessibleMemOrArgMem();
const int fft_rank = fft_length.size();
- b_.CreateCall(
- fft_func,
- {GetExecutableRunOptionsArgument(),
- b_.CreateBitCast(GetEmittedValueFor(fft), int8_ptr_type),
- b_.CreateBitCast(operand_address, int8_ptr_type),
- b_.getInt32(fft->fft_type()), b_.getInt32(fft_rank),
- b_.getInt64(input_batch), b_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
- b_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
- b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)});
+ Call(fft_func,
+ {GetExecutableRunOptionsArgument(),
+ BitCast(GetEmittedValueFor(fft), int8_ptr_type),
+ BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()),
+ b_.getInt32(fft_rank), b_.getInt64(input_batch),
+ b_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
+ b_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
+ b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)});
return Status::OK();
}
@@ -1203,11 +1205,11 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) {
const Shape& operand_shape = crs->operand(i)->shape();
CHECK(ShapeUtil::IsArray(operand_shape))
<< "Operands to cross-replica-sum must be arrays: " << crs->ToString();
- operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape));
+ operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
// TODO(b/63762267): Be more aggressive about specifying alignment.
- b_.CreateMemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr,
- /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape));
+ MemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr,
+ /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape));
}
llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_, module_);
return Status::OK();
@@ -1457,7 +1459,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction(
const ReductionGenerator& reduction_generator,
const llvm_ir::IrArray::Index& output_index,
const ShardedVectorType& accumulator_type, HloInstruction* init_value,
- HloInstruction* arg, gtl::ArraySlice<int64> dimensions,
+ HloInstruction* arg, absl::Span<const int64> dimensions,
unsigned element_alignment) {
ShardedVector accumulator;
accumulator.reserve(accumulator_type.size());
@@ -1466,19 +1468,19 @@ IrEmitter::EmitInnerLoopForVectorizedReduction(
accumulator_shard_type, "accumulator", &b_, 0));
}
- llvm::Value* init_value_ssa = b_.CreateLoad(GetEmittedValueFor(init_value));
+ llvm::Value* init_value_ssa = Load(GetEmittedValueFor(init_value));
for (llvm::Value* accumulator_shard : accumulator) {
llvm::Value* initial_value;
auto shard_type = accumulator_shard->getType()->getPointerElementType();
if (auto vector_type = llvm::dyn_cast<llvm::VectorType>(shard_type)) {
initial_value =
- b_.CreateVectorSplat(vector_type->getNumElements(), init_value_ssa);
+ VectorSplat(vector_type->getNumElements(), init_value_ssa);
} else {
initial_value = init_value_ssa;
}
- b_.CreateAlignedStore(initial_value, accumulator_shard, element_alignment);
+ AlignedStore(initial_value, accumulator_shard, element_alignment);
}
llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"),
@@ -1500,24 +1502,24 @@ IrEmitter::EmitInnerLoopForVectorizedReduction(
}
CHECK(output_index.end() == it);
- llvm::Value* input_address = b_.CreateBitCast(
+ llvm::Value* input_address = BitCast(
arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy());
for (int i = 0; i < accumulator.size(); i++) {
auto input_address_typed =
- b_.CreateBitCast(input_address, accumulator[i]->getType());
+ BitCast(input_address, accumulator[i]->getType());
auto current_accumulator_value =
- b_.CreateAlignedLoad(accumulator[i], element_alignment);
- auto addend = b_.CreateAlignedLoad(input_address_typed, element_alignment);
+ AlignedLoad(accumulator[i], element_alignment);
+ auto addend = AlignedLoad(input_address_typed, element_alignment);
arg_array.AnnotateLoadStoreInstructionWithMetadata(addend);
auto reduced_result =
reduction_generator(&b_, current_accumulator_value, addend);
- b_.CreateAlignedStore(reduced_result, accumulator[i], element_alignment);
+ AlignedStore(reduced_result, accumulator[i], element_alignment);
if (i != (accumulator.size() - 1)) {
- input_address = b_.CreateConstInBoundsGEP1_32(reduced_result->getType(),
- input_address_typed, 1);
+ input_address = ConstInBoundsGEP1_32(reduced_result->getType(),
+ input_address_typed, 1);
}
}
@@ -1526,8 +1528,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction(
ShardedVector result_ssa;
result_ssa.reserve(accumulator.size());
for (auto accumulator_shard : accumulator) {
- result_ssa.push_back(
- b_.CreateAlignedLoad(accumulator_shard, element_alignment));
+ result_ssa.push_back(AlignedLoad(accumulator_shard, element_alignment));
}
return result_ssa;
}
@@ -1536,25 +1537,25 @@ void IrEmitter::EmitShardedVectorStore(
llvm::Value* store_address, const std::vector<llvm::Value*>& value_to_store,
const int alignment, const llvm_ir::IrArray& containing_array) {
for (int i = 0; i < value_to_store.size(); i++) {
- auto store_address_typed = b_.CreateBitCast(
- store_address,
- llvm::PointerType::getUnqual(value_to_store[i]->getType()));
+ auto store_address_typed =
+ BitCast(store_address,
+ llvm::PointerType::getUnqual(value_to_store[i]->getType()));
- auto store_instruction = b_.CreateAlignedStore(
- value_to_store[i], store_address_typed, alignment);
+ auto store_instruction =
+ AlignedStore(value_to_store[i], store_address_typed, alignment);
containing_array.AnnotateLoadStoreInstructionWithMetadata(
store_instruction);
if (i != (value_to_store.size() - 1)) {
- store_address = b_.CreateConstInBoundsGEP1_32(
- value_to_store[i]->getType(), store_address_typed, 1);
+ store_address = ConstInBoundsGEP1_32(value_to_store[i]->getType(),
+ store_address_typed, 1);
}
}
}
StatusOr<bool> IrEmitter::EmitVectorizedReduce(
HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
- gtl::ArraySlice<int64> dimensions, HloComputation* function,
+ absl::Span<const int64> dimensions, HloComputation* function,
string* failure_reason) {
if (!ReductionPreservesLayout(*reduce)) {
return false;
@@ -1620,9 +1621,8 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i);
int64 start_index = 0;
int64 end_index = reduce->shape().dimensions(dimension);
- std::unique_ptr<llvm_ir::ForLoop> loop =
- loop_nest.AddLoop(start_index, end_index,
- tensorflow::strings::Printf("dim.%lld", dimension));
+ std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop(
+ start_index, end_index, absl::StrFormat("dim.%d", dimension));
array_index[dimension] = loop->GetIndVarValue();
}
@@ -1641,9 +1641,9 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
int64 start_index = 0;
int64 end_index = (innermost_dimension_size / vectorization_factor) *
vectorization_factor;
- std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop(
- start_index, end_index, vectorization_factor,
- tensorflow::strings::Printf("dim.%lld", innermost_dimension));
+ std::unique_ptr<llvm_ir::ForLoop> loop =
+ loop_nest.AddLoop(start_index, end_index, vectorization_factor,
+ absl::StrFormat("dim.%d", innermost_dimension));
array_index[innermost_dimension] = loop->GetIndVarValue();
SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &b_);
@@ -1705,7 +1705,7 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
HloReduceInstruction* reduce, const llvm_ir::IrArray::Index& index) {
const HloInstruction* arg = reduce->mutable_operand(0);
const HloInstruction* init_value = reduce->mutable_operand(1);
- gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+ absl::Span<const int64> dimensions(reduce->dimensions());
// Initialize an accumulator with init_value.
PrimitiveType accumulator_type = reduce->shape().element_type();
@@ -1713,8 +1713,8 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator",
&b_, MinimumAlignmentForPrimitiveType(accumulator_type));
llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
- llvm::Value* load_init_value = b_.CreateLoad(init_value_addr);
- b_.CreateStore(load_init_value, accumulator_addr);
+ llvm::Value* load_init_value = Load(init_value_addr);
+ Store(load_init_value, accumulator_addr);
// The enclosing loops go over all the target elements. Now we have to compute
// the actual target element. For this, we build a new loop nest to iterate
@@ -1747,12 +1747,12 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
// Apply the reduction function to the loaded value.
llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_);
llvm::Value* result = EmitThreadLocalCall(
- *reduce->to_apply(), {b_.CreateLoad(accumulator_addr), input_element},
+ *reduce->to_apply(), {Load(accumulator_addr), input_element},
"reduce_function");
- b_.CreateStore(result, accumulator_addr);
+ Store(result, accumulator_addr);
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
- return b_.CreateLoad(accumulator_addr);
+ return Load(accumulator_addr);
}
Status IrEmitter::HandleReduce(HloInstruction* reduce) {
@@ -1762,7 +1762,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
}
auto arg = reduce->mutable_operand(0);
auto init_value = reduce->mutable_operand(1);
- gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+ absl::Span<const int64> dimensions(reduce->dimensions());
HloComputation* function = reduce->to_apply();
if (!options::VectorizedReduceDisabled(hlo_module_config_)) {
string vectorization_failure_reason;
@@ -1990,7 +1990,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) {
[this, pad](const llvm_ir::IrArray::Index& target_index) {
const HloInstruction* padding_value = pad->operand(1);
llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value);
- return b_.CreateLoad(padding_value_addr);
+ return Load(padding_value_addr);
}));
// Create a loop to iterate over the operand elements and update the output
@@ -2012,10 +2012,10 @@ Status IrEmitter::HandlePad(HloInstruction* pad) {
const PaddingConfig& padding_config = pad->padding_config();
llvm_ir::IrArray::Index output_index(operand_index.GetType());
for (size_t i = 0; i < operand_index.size(); ++i) {
- llvm::Value* offset = b_.CreateMul(
- operand_index[i],
- b_.getInt64(padding_config.dimensions(i).interior_padding() + 1));
- llvm::Value* index = b_.CreateAdd(
+ llvm::Value* offset =
+ Mul(operand_index[i],
+ b_.getInt64(padding_config.dimensions(i).interior_padding() + 1));
+ llvm::Value* index = Add(
offset, b_.getInt64(padding_config.dimensions(i).edge_padding_low()));
output_index.push_back(index);
}
@@ -2102,7 +2102,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
{}, &b_, computation->name(),
/*return_value_buffer=*/emitted_value_[call],
/*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/GetTempBuffersArgument(),
+ /*buffer_table_arg=*/GetBufferTableArgument(),
/*profile_counters_arg=*/GetProfileCountersArgument());
HloInstruction* root = computation->root_instruction();
@@ -2117,8 +2117,8 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
}
Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
- gtl::ArraySlice<HloInstruction*> operands(custom_call->operands());
- tensorflow::StringPiece custom_call_target(custom_call->custom_call_target());
+ absl::Span<HloInstruction* const> operands(custom_call->operands());
+ absl::string_view custom_call_target(custom_call->custom_call_target());
llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
llvm::AllocaInst* operands_alloca =
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
@@ -2126,10 +2126,10 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
for (size_t i = 0; i < operands.size(); ++i) {
const HloInstruction* operand = operands[i];
llvm::Value* operand_as_i8ptr =
- b_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type);
+ PointerCast(GetEmittedValueFor(operand), i8_ptr_type);
llvm::Value* slot_in_operands_alloca =
- b_.CreateInBoundsGEP(operands_alloca, {b_.getInt64(i)});
- b_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca);
+ InBoundsGEP(operands_alloca, {b_.getInt64(i)});
+ Store(operand_as_i8ptr, slot_in_operands_alloca);
}
auto* custom_call_ir_function =
llvm::cast<llvm::Function>(module_->getOrInsertFunction(
@@ -2141,9 +2141,9 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call));
auto* output_address_arg =
- b_.CreatePointerCast(GetEmittedValueFor(custom_call), i8_ptr_type);
+ PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type);
- b_.CreateCall(custom_call_ir_function, {output_address_arg, operands_alloca});
+ Call(custom_call_ir_function, {output_address_arg, operands_alloca});
return Status::OK();
}
@@ -2170,8 +2170,8 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
return InternalError(
"instruction %s %s does not share slice with "
"instruction %s %s",
- a->ToString().c_str(), slice_a.ToString().c_str(),
- b->ToString().c_str(), slice_b.ToString().c_str());
+ a->ToString(), slice_a.ToString(), b->ToString(),
+ slice_b.ToString());
}
return Status::OK();
};
@@ -2202,15 +2202,14 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
llvm::BasicBlock* header_bb = llvm::BasicBlock::Create(
module_->getContext(), AsStringRef(IrName(xla_while, "header")),
compute_function_->function());
- b_.CreateBr(header_bb);
+ Br(header_bb);
b_.SetInsertPoint(header_bb);
// Calls the condition function to determine whether to proceed with the
// body. It must return a bool, so use the scalar call form.
EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond"));
- llvm::Value* while_predicate = b_.CreateICmpNE(
- b_.CreateLoad(
- GetBufferForGlobalCallReturnValue(*xla_while->while_condition())),
+ llvm::Value* while_predicate = ICmpNE(
+ Load(GetBufferForGlobalCallReturnValue(*xla_while->while_condition())),
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0));
// Branches to the body or to the while exit depending on the condition.
@@ -2219,7 +2218,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
compute_function_->function());
llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create(
module_->getContext(), AsStringRef(IrName(xla_while, "exit")));
- b_.CreateCondBr(while_predicate, body_bb, exit_bb);
+ CondBr(while_predicate, body_bb, exit_bb);
// Calls the body function from the body block.
b_.SetInsertPoint(body_bb);
@@ -2228,7 +2227,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body"));
// Finishes with a branch back to the header.
- b_.CreateBr(header_bb);
+ Br(header_bb);
// Adds the exit block to the function and sets the insert point there.
compute_function_->function()->getBasicBlockList().push_back(exit_bb);
@@ -2238,7 +2237,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
}
StatusOr<bool> IrEmitter::EmitFastConcatenate(
- HloInstruction* concatenate, gtl::ArraySlice<HloInstruction*> operands,
+ HloInstruction* concatenate, absl::Span<HloInstruction* const> operands,
string* failure_reason) {
if (ShouldEmitParallelLoopFor(*concatenate)) {
*failure_reason =
@@ -2275,7 +2274,6 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
output_min2maj.end());
llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
- llvm::Type* i8_type = b_.getInt8Ty();
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate));
llvm_ir::IrArray target_array = GetIrArrayFor(concatenate);
@@ -2298,9 +2296,9 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
// Contiguous subregions from each operand to the concatenate contribute to a
// contiguous subregion in the target buffer starting at target_region_begin.
llvm::Value* target_region_begin =
- b_.CreateBitCast(target_array.EmitArrayElementAddress(
- outer_dims_index, &b_, "target_region"),
- i8_ptr_type);
+ BitCast(target_array.EmitArrayElementAddress(outer_dims_index, &b_,
+ "target_region"),
+ i8_ptr_type);
int64 byte_offset_into_target_region = 0;
int64 inner_dims_product =
@@ -2314,13 +2312,12 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
for (HloInstruction* operand : operands) {
const Shape& input_shape = operand->shape();
llvm_ir::IrArray source_array = GetIrArrayFor(operand);
- llvm::Value* copy_source_address = b_.CreateBitCast(
+ llvm::Value* copy_source_address = BitCast(
source_array.EmitArrayElementAddress(outer_dims_index, &b_, "src_addr"),
i8_ptr_type);
llvm::Value* copy_target_address =
- b_.CreateGEP(i8_type, target_region_begin,
- b_.getInt64(byte_offset_into_target_region));
+ GEP(target_region_begin, b_.getInt64(byte_offset_into_target_region));
EmitTransferElements(
copy_target_address, copy_source_address,
@@ -2352,15 +2349,15 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
llvm_ir::PrimitiveTypeToIrType(primitive_type, module_));
if (element_count == 1) {
- auto* load_instruction = b_.CreateAlignedLoad(
- b_.CreateBitCast(source, primitive_ptr_type), element_alignment);
+ auto* load_instruction =
+ AlignedLoad(BitCast(source, primitive_ptr_type), element_alignment);
source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction);
- auto* store_instruction = b_.CreateAlignedStore(
- load_instruction, b_.CreateBitCast(target, primitive_ptr_type),
- element_alignment);
+ auto* store_instruction =
+ AlignedStore(load_instruction, BitCast(target, primitive_ptr_type),
+ element_alignment);
target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
} else {
- auto* memcpy_instruction = b_.CreateMemCpy(
+ auto* memcpy_instruction = MemCpy(
target, /*DstAlign=*/element_alignment, source,
/*SrcAlign=*/element_alignment, element_count * primitive_type_size);
@@ -2376,7 +2373,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
}
Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
- gtl::ArraySlice<HloInstruction*> operands(concatenate->operands());
+ absl::Span<HloInstruction* const> operands(concatenate->operands());
string failure_reason;
TF_ASSIGN_OR_RETURN(
bool successful,
@@ -2422,9 +2419,9 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) {
// cond_result = true_computation(true_operand)
// else
// cond_result = false_computation(false_operand)
- llvm::LoadInst* pred_value = b_.CreateLoad(
- GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value");
- llvm::Value* pred_cond = b_.CreateICmpNE(
+ llvm::LoadInst* pred_value =
+ Load(GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value");
+ llvm::Value* pred_cond = ICmpNE(
pred_value,
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
"boolean_predicate");
@@ -2450,11 +2447,6 @@ Status IrEmitter::HandleAfterAll(HloInstruction* gen_token) {
return Status::OK();
}
-Status IrEmitter::HandleIota(HloInstruction* iota) {
- // TODO(b/64798317): implement iota on CPU.
- return Unimplemented("Iota is not implemented on CPU.");
-}
-
Status IrEmitter::HandleRng(HloInstruction* rng) {
ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
for (const HloInstruction* operand : rng->operands()) {
@@ -2511,8 +2503,8 @@ llvm::Value* IrEmitter::GetProfileCounterCommon(
int64 prof_counter_idx = it->second;
string counter_name = IrName("prof_counter", hlo.name());
- return b_.CreateGEP(GetProfileCountersArgument(),
- b_.getInt64(prof_counter_idx), AsStringRef(counter_name));
+ return GEP(GetProfileCountersArgument(), b_.getInt64(prof_counter_idx),
+ AsStringRef(counter_name));
}
void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b,
@@ -2630,15 +2622,15 @@ llvm::Value* IrEmitter::GetProfileCountersArgument() {
return compute_function_->profile_counters_arg();
}
-llvm::Value* IrEmitter::GetTempBuffersArgument() {
- return compute_function_->temp_buffers_arg();
+llvm::Value* IrEmitter::GetBufferTableArgument() {
+ return compute_function_->buffer_table_arg();
}
llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() {
return compute_function_->exec_run_options_arg();
}
-llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer(
+llvm::Value* IrEmitter::EmitThreadLocalBufferPointer(
const BufferAllocation::Slice& slice, const Shape& target_shape) {
const BufferAllocation& allocation = *slice.allocation();
llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
@@ -2666,8 +2658,7 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer(
llvm::Value* params = compute_function_->parameters_arg();
llvm::Value* param_address_offset =
llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_);
- llvm::LoadInst* param_address_untyped =
- b_.CreateLoad(param_address_offset);
+ llvm::LoadInst* param_address_untyped = Load(param_address_offset);
if (!ShapeUtil::IsOpaque(target_shape)) {
AttachAlignmentMetadataForLoad(param_address_untyped, target_shape);
@@ -2687,25 +2678,23 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer(
auto buf_it = thread_local_buffers_.find(key);
if (buf_it == thread_local_buffers_.end()) {
llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry(
- IrShapeType(shape),
- tensorflow::strings::StrCat("thread_local", slice.ToString()), &b_,
- MinimumAlignmentForShape(target_shape));
+ IrShapeType(shape), absl::StrCat("thread_local", slice.ToString()),
+ &b_, MinimumAlignmentForShape(target_shape));
auto it_inserted_pair = thread_local_buffers_.insert({key, buffer});
CHECK(it_inserted_pair.second);
buf_it = it_inserted_pair.first;
}
return buf_it->second;
}();
- return b_.CreateBitCast(tempbuf_address,
- IrShapeType(target_shape)->getPointerTo());
+ return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo());
}
-llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
+llvm::Value* IrEmitter::EmitGlobalBufferPointer(
const BufferAllocation::Slice& slice, const Shape& target_shape) {
const BufferAllocation& allocation = *slice.allocation();
llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
- GetTempBuffersArgument(), slice.index(), &b_);
- llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr);
+ GetBufferTableArgument(), slice.index(), &b_);
+ llvm::LoadInst* tempbuf_address_base = Load(tempbuf_address_ptr);
if (hlo_module_config_.debug_options()
.xla_llvm_enable_invariant_load_metadata()) {
tempbuf_address_base->setMetadata(
@@ -2719,20 +2708,20 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
if (slice.offset() > 0) {
// Adjust the address to account for the slice offset.
tempbuf_address_untyped =
- b_.CreateInBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset()));
+ InBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset()));
}
- return b_.CreateBitCast(tempbuf_address_untyped,
- IrShapeType(target_shape)->getPointerTo());
+ return BitCast(tempbuf_address_untyped,
+ IrShapeType(target_shape)->getPointerTo());
}
-llvm::Value* IrEmitter::EmitTempBufferPointer(
- const BufferAllocation::Slice& slice, const Shape& target_shape) {
+llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice,
+ const Shape& target_shape) {
if (slice.allocation()->is_thread_local()) {
- return EmitThreadLocalTempBufferPointer(slice, target_shape);
+ return EmitThreadLocalBufferPointer(slice, target_shape);
} else if (slice.allocation()->is_constant()) {
return FindOrDie(constant_buffer_to_global_, slice.allocation()->index());
} else {
- return EmitGlobalTempBufferPointer(slice, target_shape);
+ return EmitGlobalBufferPointer(slice, target_shape);
}
}
@@ -2740,7 +2729,7 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
const Shape& target_shape = op->shape();
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
assignment_.GetUniqueTopLevelSlice(op));
- llvm::Value* addr = EmitTempBufferPointer(slice, target_shape);
+ llvm::Value* addr = EmitBufferPointer(slice, target_shape);
addr->setName(AsStringRef(IrName(op)));
emitted_value_[op] = addr;
return Status::OK();
@@ -2753,7 +2742,7 @@ Status IrEmitter::EmitTargetElementLoop(
}
Status IrEmitter::EmitTargetElementLoop(
- HloInstruction* target_op, tensorflow::StringPiece desc,
+ HloInstruction* target_op, absl::string_view desc,
const llvm_ir::ElementGenerator& element_generator) {
VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString();
@@ -2769,8 +2758,7 @@ Status IrEmitter::EmitTargetElementLoop(
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
assignment_.GetUniqueSlice(target_op, {i}));
const Shape& element_shape = ShapeUtil::GetSubshape(target_shape, {i});
- llvm::Value* op_target_address =
- EmitTempBufferPointer(slice, element_shape);
+ llvm::Value* op_target_address = EmitBufferPointer(slice, element_shape);
output_arrays.push_back(
llvm_ir::IrArray(op_target_address, element_shape));
}
@@ -2808,15 +2796,15 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source,
llvm::Value* destination_value = GetEmittedValueFor(&destination);
int64 source_size = ByteSizeOf(source.shape());
// TODO(b/63762267): Be more aggressive about specifying alignment.
- b_.CreateMemCpy(destination_value, /*DstAlign=*/1, source_value,
- /*SrcAlign=*/1, source_size);
+ MemCpy(destination_value, /*DstAlign=*/1, source_value,
+ /*SrcAlign=*/1, source_size);
return Status::OK();
}
Status IrEmitter::ElementTypesSameAndSupported(
const HloInstruction& instruction,
- gtl::ArraySlice<const HloInstruction*> operands,
- gtl::ArraySlice<PrimitiveType> supported_types) {
+ absl::Span<const HloInstruction* const> operands,
+ absl::Span<const PrimitiveType> supported_types) {
for (auto operand : operands) {
TF_RET_CHECK(
ShapeUtil::SameElementType(operands[0]->shape(), operand->shape()));
@@ -2827,8 +2815,8 @@ Status IrEmitter::ElementTypesSameAndSupported(
if (std::find(supported_types.begin(), supported_types.end(),
primitive_type) == supported_types.end()) {
return Unimplemented("unsupported operand type %s in op %s",
- PrimitiveType_Name(primitive_type).c_str(),
- HloOpcodeString(instruction.opcode()).c_str());
+ PrimitiveType_Name(primitive_type),
+ HloOpcodeString(instruction.opcode()));
}
return Status::OK();
}
@@ -2846,9 +2834,10 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
}
llvm::Value* IrEmitter::EmitThreadLocalCall(
- const HloComputation& callee,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
- tensorflow::StringPiece name) {
+ const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
+ absl::string_view name) {
+ CHECK(absl::c_binary_search(thread_local_computations_, &callee));
+
const Shape& return_shape = callee.root_instruction()->shape();
// Lifting this restriction to allow "small" arrays should be easy. Allowing
@@ -2863,38 +2852,39 @@ llvm::Value* IrEmitter::EmitThreadLocalCall(
CHECK(!parameter->getType()->isPointerTy());
llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry(
parameter->getType(), "arg_addr", &b_);
- b_.CreateStore(parameter, parameter_addr);
+ Store(parameter, parameter_addr);
parameter_addrs.push_back(parameter_addr);
}
llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(return_type, module_),
- tensorflow::strings::StrCat(name, "_retval_addr"), &b_,
+ absl::StrCat(name, "_retval_addr"), &b_,
MinimumAlignmentForPrimitiveType(return_type));
- b_.CreateCall(
- FindOrDie(emitted_functions_, &callee),
- GetArrayFunctionCallArguments(
- parameter_addrs, &b_, name,
- /*return_value_buffer=*/return_value_buffer,
- /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/
- llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
- /*profile_counters_arg=*/GetProfileCountersArgument()));
+ Call(FindOrDie(emitted_functions_, &callee),
+ GetArrayFunctionCallArguments(
+ parameter_addrs, &b_, name,
+ /*return_value_buffer=*/return_value_buffer,
+ /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
+ /*buffer_table_arg=*/
+ llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
+ /*profile_counters_arg=*/GetProfileCountersArgument()));
- return b_.CreateLoad(return_value_buffer);
+ return Load(return_value_buffer);
}
void IrEmitter::EmitGlobalCall(const HloComputation& callee,
- tensorflow::StringPiece name) {
- b_.CreateCall(FindOrDie(emitted_functions_, &callee),
- GetArrayFunctionCallArguments(
- /*parameter_addresses=*/{}, &b_, name,
- /*return_value_buffer=*/
- llvm::Constant::getNullValue(b_.getInt8PtrTy()),
- /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/GetTempBuffersArgument(),
- /*profile_counters_arg=*/GetProfileCountersArgument()));
+ absl::string_view name) {
+ CHECK(absl::c_binary_search(global_computations_, &callee));
+
+ Call(FindOrDie(emitted_functions_, &callee),
+ GetArrayFunctionCallArguments(
+ /*parameter_addresses=*/{}, &b_, name,
+ /*return_value_buffer=*/
+ llvm::Constant::getNullValue(b_.getInt8PtrTy()),
+ /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
+ /*buffer_table_arg=*/GetBufferTableArgument(),
+ /*profile_counters_arg=*/GetProfileCountersArgument()));
}
llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
@@ -2906,7 +2896,7 @@ llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
const BufferAllocation::Slice root_buffer =
assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie();
- return EmitTempBufferPointer(root_buffer, root_inst->shape());
+ return EmitBufferPointer(root_buffer, root_inst->shape());
}
} // namespace cpu
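Note on the b_.CreateFoo(...) -> Foo(...) rewrites above: they are enabled by the new IrBuilderMixin CRTP base that ir_emitter.h picks up below. The real mixin lives in llvm_ir/ir_builder_mixin.h; the following is only a minimal sketch of the forwarding pattern, assuming nothing beyond a builder() accessor on the derived class (the sketch's class name is illustrative):

#include <utility>

#include "llvm/IR/IRBuilder.h"

// Minimal sketch of the CRTP forwarding behind the Load()/BitCast()/etc.
// shorthand used in ir_emitter.cc. Each helper forwards its arguments to
// the corresponding Create* method on the derived class's IRBuilder.
template <typename Derived>
class IrBuilderMixinSketch {
 protected:
  template <typename... Args>
  llvm::LoadInst* Load(Args&&... args) {
    return derived()->builder()->CreateLoad(std::forward<Args>(args)...);
  }

  template <typename... Args>
  llvm::Value* BitCast(Args&&... args) {
    return derived()->builder()->CreateBitCast(std::forward<Args>(args)...);
  }

 private:
  Derived* derived() { return static_cast<Derived*>(this); }
};

This keeps call sites short without hiding that every helper still goes through the single IRBuilder owned by the emitter.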
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index c9a1dab62d..3df99464ba 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -23,6 +23,8 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
@@ -39,13 +41,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -55,13 +56,14 @@ namespace cpu {
// This class is the top-level API for the XLA HLO --> LLVM IR compiler. It
// implements the DfsHloVisitor interface and emits HLO computations as LLVM IR
// functions.
-class IrEmitter : public DfsHloVisitorWithDefault {
+class IrEmitter : public DfsHloVisitorWithDefault,
+ public IrBuilderMixin<IrEmitter> {
public:
// Create a new LLVM IR emitter.
//
// hlo_module: the HLO module we are emitting IR for.
- // assignment: a BufferAssignment from which we know which temporary buffers
- // are used by the HLO nodes.
+ // assignment: a BufferAssignment from which we know which buffers are used by
+ // the HLO nodes.
// llvm_module: the LLVM module to emit IR into.
// instruction_to_profile_idx: the mapping from HLO instructions to their
// index in the profiling array.
@@ -96,18 +98,21 @@ class IrEmitter : public DfsHloVisitorWithDefault {
StatusOr<llvm::Function*> EmitComputation(
HloComputation* computation, const string& function_name_prefix,
bool is_top_level_computation,
- std::vector<const HloInstruction*>* instruction_order);
+ const std::vector<const HloInstruction*>* instruction_order);
llvm::IRBuilder<>* b() { return &b_; }
+ // builder() is for IrBuilderMixin.
+ llvm::IRBuilder<>* builder() { return &b_; }
+
// Emit an LLVM global variable for every constant buffer allocation.
Status EmitConstantGlobals();
// Emit code to map one element according to `map_instr`.
llvm::Value* EmitElementalMap(
const HloMapInstruction& map_instr,
- tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
- tensorflow::StringPiece name);
+ absl::Span<llvm::Value* const> elemental_operands,
+ absl::string_view name);
protected:
//
@@ -152,7 +157,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleConditional(HloInstruction* conditional) override;
Status HandleScatter(HloInstruction* scatter) override;
Status HandleAfterAll(HloInstruction* gen_token) override;
- Status HandleIota(HloInstruction* iota) override;
Status HandleRng(HloInstruction* rng) override;
Status FinishVisit(HloInstruction* root) override;
@@ -215,31 +219,28 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// argument of the computation function being emitted by this emitter.
llvm::Value* GetExecutableRunOptionsArgument();
- // Get the llvm::Value* that represents the "temps" argument of the
+ // Get the llvm::Value* that represents the "buffer_table" argument of the
// computation function being emitted by this emitter.
- llvm::Value* GetTempBuffersArgument();
+ llvm::Value* GetBufferTableArgument();
- // Helper for EmitTempBufferPointer.
- llvm::Value* EmitGlobalTempBufferPointer(const BufferAllocation::Slice& slice,
- const Shape& target_shape);
+ // Helper for EmitBufferPointer.
+ llvm::Value* EmitGlobalBufferPointer(const BufferAllocation::Slice& slice,
+ const Shape& target_shape);
- // Helper for EmitTempBufferPointer.
- llvm::Value* EmitThreadLocalTempBufferPointer(
+ // Helper for EmitBufferPointer.
+ llvm::Value* EmitThreadLocalBufferPointer(
const BufferAllocation::Slice& slice, const Shape& target_shape);
// Emits code that computes the address of the given buffer allocation slice.
- //
- // TODO(sanjoy): This should be renamed to reflect that it no longer provides
- // access to just temporaries.
- llvm::Value* EmitTempBufferPointer(const BufferAllocation::Slice& slice,
- const Shape& target_shape);
+ llvm::Value* EmitBufferPointer(const BufferAllocation::Slice& slice,
+ const Shape& target_shape);
// Emits a function into the current module. This can be used for
// computations embedded inside other computations, such as the
// function that a map operation applies.
StatusOr<llvm::Function*> EmitFunction(
HloComputation* function, // The function to emit.
- tensorflow::StringPiece
+ absl::string_view
function_name_suffix); // Used for LLVM IR register names.
// Emits a call to a thread local function (e.g. to the computation nested
@@ -248,17 +249,15 @@ class IrEmitter : public DfsHloVisitorWithDefault {
//
// `parameters` holds the *scalar values* that need to be passed to the
// callee. The return value is the scalar returned by the callee.
- llvm::Value* EmitThreadLocalCall(
- const HloComputation& callee,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
- tensorflow::StringPiece name);
+ llvm::Value* EmitThreadLocalCall(const HloComputation& callee,
+ absl::Span<llvm::Value* const> parameters,
+ absl::string_view name);
// Emits a call to a "global" function (e.g. to the computation nested within
// a kWhile or a kCall). Buffer assignment unambiguously assigns buffers to
// the parameters and return values for these computations so there is no need
// to explicitly pass parameters or return results.
- void EmitGlobalCall(const HloComputation& callee,
- tensorflow::StringPiece name);
+ void EmitGlobalCall(const HloComputation& callee, absl::string_view name);
// Returns the buffer to which a global call to `callee` would have written
// its result.
@@ -268,8 +267,8 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// match and are of one of the given supported types.
Status ElementTypesSameAndSupported(
const HloInstruction& instruction,
- tensorflow::gtl::ArraySlice<const HloInstruction*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> supported_types);
+ absl::Span<const HloInstruction* const> operands,
+ absl::Span<const PrimitiveType> supported_types);
// Emit IR to perform a computation for every element in the given target op.
// This produces a series of nested loops (one for each dimension of the op's
@@ -285,7 +284,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
HloInstruction* target_op,
const llvm_ir::ElementGenerator& element_generator);
Status EmitTargetElementLoop(
- HloInstruction* target_op, tensorflow::StringPiece desc,
+ HloInstruction* target_op, absl::string_view desc,
const llvm_ir::ElementGenerator& element_generator);
// Emits a memcpy from the source instruction's result value to the
@@ -316,10 +315,12 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// concepts that generalize over other vectorizable operations. We should
// consider pulling out these abstractions into a VectorizingIrEmitter or
// something similar.
- StatusOr<bool> EmitVectorizedReduce(
- HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
- tensorflow::gtl::ArraySlice<int64> dimensions, HloComputation* function,
- string* failure_reason);
+ StatusOr<bool> EmitVectorizedReduce(HloInstruction* reduce,
+ HloInstruction* arg,
+ HloInstruction* init_value,
+ absl::Span<const int64> dimensions,
+ HloComputation* function,
+ string* failure_reason);
// We'd like to keep one or two cache lines' worth of data in registers
// without generating IR with illegal (e.g. excessively large or
@@ -369,16 +370,15 @@ class IrEmitter : public DfsHloVisitorWithDefault {
const ReductionGenerator& reduction_generator,
const llvm_ir::IrArray::Index& output_index,
const ShardedVectorType& accumulator_type, HloInstruction* init_value,
- HloInstruction* arg, tensorflow::gtl::ArraySlice<int64> dimensions,
+ HloInstruction* arg, absl::Span<const int64> dimensions,
unsigned element_alignment);
// Tries to emit a fast concatenate operation using memcpy. Returns true if
// successful, and false on failure. On failure, sets "failure_reason" to a
// string describing why it could not emit a fast concatenate.
- StatusOr<bool> EmitFastConcatenate(
- HloInstruction* concatenate,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- string* failure_reason);
+ StatusOr<bool> EmitFastConcatenate(HloInstruction* concatenate,
+ absl::Span<HloInstruction* const> operands,
+ string* failure_reason);
// Emits LLVM IR to transfer "element_count" elements of type "primitive_type"
// from the address "source" to the address "target".
@@ -387,8 +387,8 @@ class IrEmitter : public DfsHloVisitorWithDefault {
const llvm_ir::IrArray& target_array,
const llvm_ir::IrArray& source_array);
- // Assignment of the temporary buffers needed by the computation and their
- // shape information.
+ // Assignment of the buffers needed by the computation and their shape
+ // information.
const BufferAssignment& assignment_;
// The LLVM module into which IR will be emitted.
@@ -568,6 +568,9 @@ class IrEmitter : public DfsHloVisitorWithDefault {
tensorflow::gtl::FlatMap<BufferAllocation::Index, llvm::Constant*>
constant_buffer_to_global_;
+ std::vector<const HloComputation*> thread_local_computations_;
+ std::vector<const HloComputation*> global_computations_;
+
TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter);
};
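The ArraySlice/StringPiece -> absl::Span/absl::string_view substitutions in this header are drop-in: both old and new types are non-owning views, so call sites keep passing vectors and strings directly. A standalone sketch (not from the patch; the function name is made up):

#include <cstdint>
#include <vector>

#include "absl/types/span.h"

// absl::Span is a non-owning view, like the gtl::ArraySlice it replaces:
// constructing one from a std::vector copies no elements.
int64_t ElementCount(absl::Span<const int64_t> dims) {
  int64_t count = 1;
  for (int64_t d : dims) count *= d;
  return count;
}

// std::vector<int64_t> dims = {2, 3, 4};
// ElementCount(dims);  // == 24, via implicit vector -> Span conversion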
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc
index 2db4d000f5..adfb8392bf 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_function.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/ir_function.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@@ -77,19 +78,20 @@ void IrFunction::Initialize(const string& function_name,
const bool optimize_for_size_requested,
const bool enable_fast_math) {
// The function signature is:
- // void function(i8* retval, i8* run_options, i8** params, i8** temps,
+ // void function(i8* retval, i8* run_options, i8** params,
+ // i8** buffer_table,
// i64* dynamic_loop_bounds, i64* prof_counters)
//
// For thread local functions:
// retval: points to the returned value.
// params: address of an array with pointers to parameters.
- // temps: is null
+ // buffer_table: is null
//
// For global functions:
// retval: is null
// params: is null
- // temps: address of an array with pointers to temporary buffers and entry
- // computation parameters.
+ // buffer_table: address of an array with pointers to temporary buffers and
+ // entry computation parameters (but not to constant buffers).
//
// Therefore, the generated function's signature (FunctionType) is statically
// determined - parameter unpacking is done in code generated into the
@@ -115,7 +117,7 @@ void IrFunction::Initialize(const string& function_name,
// \---------/ \---------/ \-----------/
//
// /---------------------------------------------\
- // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 |
+ // buffer_table---> | buff 0 | buff 1 | ..... | buff N-1 |
// | addr | addr | | addr |
// \---------------------------------------------/
// | | |
@@ -133,9 +135,9 @@ void IrFunction::Initialize(const string& function_name,
// prof counters -> | counter 0 | counter 1 | ..... | counter N-1 |
// \---------------------------------------------/
- // Even though the type of params and temps is void** in the host's view, in
- // LLVM IR this is represented by i8*, similarly to void*. It's up to the code
- // to use GEPs to unravel the indirection layers.
+ // Even though the type of params and buffer_table is void** in the host's
+ // view, in LLVM IR this is represented by i8*, similarly to void*. It's up to
+ // the code to use GEPs to unravel the indirection layers.
llvm::FunctionType* function_type = llvm::FunctionType::get(
/*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()),
/*Params=*/
@@ -159,8 +161,8 @@ void IrFunction::Initialize(const string& function_name,
exec_run_options_arg_ = &*arg_iter;
(++arg_iter)->setName("params");
parameters_arg_ = &*arg_iter;
- (++arg_iter)->setName("temps");
- temp_buffers_arg_ = &*arg_iter;
+ (++arg_iter)->setName("buffer_table");
+ buffer_table_arg_ = &*arg_iter;
if (num_dynamic_loop_bounds_ > 0) {
(++arg_iter)->setName("dynamic_loop_bounds");
dynamic_loop_bounds_arg_ = &*arg_iter;
@@ -189,7 +191,7 @@ void IrFunction::Initialize(const string& function_name,
llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
CHECK_GT(num_dynamic_loop_bounds_, 0);
CHECK_LT(offset, num_dynamic_loop_bounds_ * 2);
- string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset);
+ string name = absl::StrCat("dynamic_loop_bound_", offset);
return b_->CreateLoad(b_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_),
b_->getInt64(offset), AsStringRef(name)));
}
@@ -199,10 +201,10 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
// Returns an array of compute function call arguments (including parameter
// address buffer).
std::vector<llvm::Value*> GetArrayFunctionCallArguments(
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
- llvm::IRBuilder<>* b, tensorflow::StringPiece name,
- llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg,
- llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) {
+ absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b,
+ absl::string_view name, llvm::Value* return_value_buffer,
+ llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg,
+ llvm::Value* profile_counters_arg) {
llvm::Value* parameter_addresses_buffer;
if (parameter_addresses.empty()) {
@@ -211,13 +213,13 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments(
} else {
parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()),
- tensorflow::strings::StrCat(name, "_parameter_addresses"), b);
+ absl::StrCat(name, "_parameter_addresses"), b);
for (size_t i = 0; i < parameter_addresses.size(); ++i) {
llvm::Value* parameter_as_i8ptr =
b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(),
- AsStringRef(tensorflow::strings::StrCat(
- name, "_parameter_", i, "_address_as_i8ptr")));
+ AsStringRef(absl::StrCat(name, "_parameter_", i,
+ "_address_as_i8ptr")));
llvm::Value* slot_in_param_addresses =
b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)});
b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
@@ -229,7 +231,7 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments(
};
std::vector<llvm::Value*> arguments{
to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg),
- parameter_addresses_buffer, temp_buffers_arg};
+ parameter_addresses_buffer, buffer_table_arg};
if (profile_counters_arg != nullptr) {
arguments.push_back(profile_counters_arg);
}
@@ -320,8 +322,7 @@ Status EmitCallToParallelForkJoin(
/*Linkage=*/llvm::GlobalValue::PrivateLinkage,
/*Initializer=*/partitions_array,
/*Name=*/
- AsStringRef(
- tensorflow::strings::StrCat(name, "_parallel_dimension_partitions")));
+ AsStringRef(absl::StrCat(name, "_parallel_dimension_partitions")));
// Add argument specifying parallel dimension partitions.
fork_join_arguments.push_back(
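The signature comments in IrFunction::Initialize above pin down the ABI of the emitted compute function. In host terms it can be read as the following alias (a sketch; the alias name is illustrative, and the runtime's own ComputeFunctionType in runtime_fork_join.cc is the authoritative spelling):

#include <cstdint>

// Host-side reading of the emitted compute function. Per the comments above:
// thread-local callers pass retval and params with buffer_table == nullptr;
// global callers pass buffer_table with retval and params == nullptr.
using EmittedComputeFn = void (*)(void* retval, const void* run_options,
                                  const void** params, void** buffer_table,
                                  const int64_t* dynamic_loop_bounds,
                                  uint64_t* prof_counters);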
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h
index a41cbb64cd..623a5f185f 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_function.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_
+#include "absl/types/span.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
@@ -24,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace cpu {
@@ -80,8 +80,9 @@ class IrFunction {
// Get the llvm::Value* that represents this function's parameters argument.
llvm::Value* parameters_arg() { return parameters_arg_; }
- // Get the llvm::Value* that represents this functions "temps" argument.
- llvm::Value* temp_buffers_arg() { return temp_buffers_arg_; }
+ // Get the llvm::Value* that represents this function's "buffer_table"
+ // argument.
+ llvm::Value* buffer_table_arg() { return buffer_table_arg_; }
// Get the llvm::Value* that represents this function's "prof_counters"
// argument.
@@ -108,17 +109,17 @@ class IrFunction {
llvm::Argument* result_arg_;
llvm::Value* exec_run_options_arg_;
llvm::Value* parameters_arg_;
- llvm::Value* temp_buffers_arg_;
+ llvm::Value* buffer_table_arg_;
llvm::Value* dynamic_loop_bounds_arg_ = nullptr;
llvm::Value* profile_counters_arg_;
};
// Returns an array of compute function call argument ir values.
std::vector<llvm::Value*> GetArrayFunctionCallArguments(
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
- llvm::IRBuilder<>* b, tensorflow::StringPiece name,
- llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg,
- llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg);
+ absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b,
+ absl::string_view name, llvm::Value* return_value_buffer,
+ llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg,
+ llvm::Value* profile_counters_arg);
// Emits a call to a runtime fork/join function which dispatches parallel
// calls to 'parallel_function' (and joins threads before returning).
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
index 8560e4296a..f8441c3e34 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
@@ -15,9 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
namespace cpu {
@@ -30,8 +30,8 @@ ParallelLoopEmitter::ParallelLoopEmitter(
dynamic_loop_bounds_(dynamic_loop_bounds) {}
std::vector<llvm_ir::IrArray::Index>
-ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name, llvm::Type* index_type) {
+ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
+ llvm::Type* index_type) {
CHECK_NE(index_type, nullptr);
CHECK(!ShapeUtil::IsTuple(shape_));
@@ -52,15 +52,15 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
llvm::Value* end_index = (*dynamic_loop_bounds_)[bounds_index].second;
std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop(
- /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension),
- start_index, end_index);
+ /*suffix=*/absl::StrFormat("dim.%d", dimension), start_index,
+ end_index);
array_index[dimension] = loop->GetIndVarValue();
} else {
// Emit static loop bounds for this dimension.
std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop(
/*start_index=*/0,
/*end_index=*/shape_.dimensions(dimension),
- /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension));
+ /*suffix=*/absl::StrFormat("dim.%d", dimension));
array_index[dimension] = loop->GetIndVarValue();
}
}
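The Printf("dim.%lld", ...) -> StrFormat("dim.%d", ...) changes here work because absl::StrFormat is type-checked and "%d" accepts any integer width, so no length modifier is needed for int64. A minimal standalone sketch:

#include <cstdint>
#include <string>

#include "absl/strings/str_format.h"

// absl::StrFormat checks the format string against the argument types at
// compile time; "%d" works for any integral width, including int64.
std::string DimLabel(int64_t dimension) {
  return absl::StrFormat("dim.%d", dimension);
}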
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
index 076c683ca5..a604e1db22 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
@@ -61,7 +61,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
~ParallelLoopEmitter() override = default;
std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name, llvm::Type* index_type) override;
+ absl::string_view loop_name, llvm::Type* index_type) override;
private:
const DynamicLoopBounds* dynamic_loop_bounds_;
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
index 4fa5984b04..b4c0c09ec0 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
@@ -109,7 +111,7 @@ ParallelTaskAssignment::ParallelTaskAssignment(
: target_machine_features_(*target_machine_features) {
VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism;
// Run cost analysis on 'module'.
- auto cost_analysis = MakeUnique<HloCostAnalysis>(shape_size);
+ auto cost_analysis = absl::make_unique<HloCostAnalysis>(shape_size);
HloComputation* computation = module->entry_computation();
Status status = computation->root_instruction()->Accept(cost_analysis.get());
if (status.ok()) {
@@ -216,8 +218,7 @@ bool ParallelTaskAssigner::AssignParallelTasksHelper(
// Outline 'instruction' in 'computation' for parallel task assignment.
auto* call = module->OutlineExpressionFromComputation(
- {instruction},
- tensorflow::strings::StrCat("parallel_", instruction->name()),
+ {instruction}, absl::StrCat("parallel_", instruction->name()),
computation);
// Set assigned dimension partitioning to 'instruction'.
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
index 8becc8fa23..a99cd99c14 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
@@ -73,7 +73,7 @@ class ParallelTaskAssigner : public HloPassInterface {
target_machine_features_(*target_machine_features) {}
~ParallelTaskAssigner() override {}
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "cpu-parallel-task-assigner";
}
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
index ee272b5f4f..fad76338a5 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
@@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace {
@@ -36,7 +35,7 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase {
cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_;
ParallelTaskAssignmentTest()
- : target_machine_features_([](int64 shape_size) {
+ : HloVerifiedTestBase(), target_machine_features_([](int64 shape_size) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
}) {}
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc
index a5f34908d7..2d9492eacf 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc
+++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc
@@ -61,7 +61,7 @@ using ComputeFunctionType = void (*)(void*, const void*, const void**, void**,
//
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin(
void* result_ptr, const void* run_options_ptr, const void** params,
- void** temps, uint64* prof_counters, int32 num_partitions,
+ void** buffer_table, uint64* prof_counters, int32 num_partitions,
int64* partitions, int32 num_partitioned_dims, void* function_ptr) {
VLOG(2) << "ParallelForkJoin ENTRY"
<< " num_partitions: " << num_partitions
@@ -81,9 +81,9 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin(
for (int32 i = 1; i < num_partitions; ++i) {
const int64 offset = i * stride;
run_options->intra_op_thread_pool()->enqueueNoNotification(
- [i, function, result_ptr, run_options_ptr, temps, prof_counters,
+ [i, function, result_ptr, run_options_ptr, buffer_table, prof_counters,
partitions, offset, &bc]() {
- function(result_ptr, run_options_ptr, nullptr, temps,
+ function(result_ptr, run_options_ptr, nullptr, buffer_table,
&partitions[offset], prof_counters);
bc.DecrementCount();
VLOG(3) << "ParallelForkJoin partition " << i << " done.";
@@ -91,7 +91,7 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin(
}
// Call first compute function inline.
- function(result_ptr, run_options_ptr, params, temps, &partitions[0],
+ function(result_ptr, run_options_ptr, params, buffer_table, &partitions[0],
prof_counters);
VLOG(3) << "ParallelForkJoin partition 0 done.";
bc.Wait();
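__xla_cpu_runtime_ParallelForkJoin above follows a standard fork/join shape: enqueue partitions 1..N-1 on the intra-op thread pool, run partition 0 inline on the calling thread, then block until the counter drains. A library-agnostic sketch of the same shape, using std::thread in place of the Eigen pool and blocking counter (illustrative only):

#include <cstdint>
#include <functional>
#include <thread>
#include <vector>

// Same fork/join structure as the runtime above, minus the thread pool:
// workers handle partitions 1..N-1, partition 0 runs inline, and join()
// stands in for the blocking counter. Requires num_partitions >= 1.
void ForkJoinSketch(int32_t num_partitions,
                    const std::function<void(int32_t)>& run_partition) {
  std::vector<std::thread> workers;
  workers.reserve(num_partitions - 1);
  for (int32_t i = 1; i < num_partitions; ++i) {
    workers.emplace_back([i, &run_partition] { run_partition(i); });
  }
  run_partition(0);  // First partition runs on the calling thread.
  for (std::thread& t : workers) t.join();
}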
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h
index 1cf0ec6e3d..a279c7d2d6 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h
+++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h
@@ -24,7 +24,7 @@ extern "C" {
// threads before returning. See comments in runtime_fork_join.cc for details.
extern void __xla_cpu_runtime_ParallelForkJoin(
void* result_ptr, const void* run_options_ptr, const void** params,
- void** temps, tensorflow::uint64* prof_counters,
+ void** buffer_table, tensorflow::uint64* prof_counters,
tensorflow::int32 num_partitions, tensorflow::int64* partitions,
tensorflow::int32 num_partitioned_dims, void* function_ptr);
diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
index f227e4ae13..55d5925642 100644
--- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc
+++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -37,21 +37,20 @@ int main(int argc, char** argv) {
xla::LocalClient* client(xla::ClientLibrary::LocalClientOrDie());
// Transfer parameters.
- std::unique_ptr<xla::Literal> param0_literal =
+ xla::Literal param0_literal =
xla::LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
std::unique_ptr<xla::GlobalData> param0_data =
- client->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> param1_literal =
- xla::LiteralUtil::CreateR2<float>(
- {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}});
+ xla::Literal param1_literal = xla::LiteralUtil::CreateR2<float>(
+ {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}});
std::unique_ptr<xla::GlobalData> param1_data =
- client->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client->TransferToServer(param1_literal).ConsumeValueOrDie();
// Build computation.
xla::XlaBuilder builder("");
- auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Add(p1, p0, {0});
xla::StatusOr<xla::XlaComputation> computation_status = builder.Build();
@@ -59,17 +58,16 @@ int main(int argc, char** argv) {
// Execute and transfer result of computation.
xla::ExecutionProfile profile;
- xla::StatusOr<std::unique_ptr<xla::Literal>> result =
- client->ExecuteAndTransfer(
- computation,
- /*arguments=*/{param0_data.get(), param1_data.get()},
- /*execution_options=*/nullptr,
- /*execution_profile=*/&profile);
- std::unique_ptr<xla::Literal> actual = result.ConsumeValueOrDie();
+ xla::StatusOr<xla::Literal> result = client->ExecuteAndTransfer(
+ computation,
+ /*arguments=*/{param0_data.get(), param1_data.get()},
+ /*execution_options=*/nullptr,
+ /*execution_profile=*/&profile);
+ xla::Literal actual = result.ConsumeValueOrDie();
- LOG(INFO) << tensorflow::strings::Printf("computation took %lldns",
- profile.compute_time_ns());
- LOG(INFO) << actual->ToString();
+ LOG(INFO) << absl::StrFormat("computation took %dns",
+ profile.compute_time_ns());
+ LOG(INFO) << actual.ToString();
return 0;
}
diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
index ae80a6f497..1a3d82de95 100644
--- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
@@ -19,14 +19,14 @@ limitations under the License.
#include <random>
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
namespace cpu {
namespace {
-class ShapePartitionAssignerTest : public HloTestBase {
+class ShapePartitionAssignerTest : public HloVerifiedTestBase {
protected:
typedef std::vector<int64> Vec;
@@ -91,7 +91,7 @@ TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) {
expected_partitions);
}
-class ShapePartitionIteratorTest : public HloTestBase {
+class ShapePartitionIteratorTest : public HloVerifiedTestBase {
protected:
typedef std::vector<std::pair<int64, int64>> Partition;
};
@@ -102,22 +102,22 @@ TEST_F(ShapePartitionIteratorTest, Shape53WithLayout10) {
{
ShapePartitionIterator iterator(shape, {1});
EXPECT_EQ(1, iterator.GetTotalPartitionCount());
- EXPECT_TRUE(ContainersEqual(Partition({{0, 5}}), iterator.GetPartition(0)));
+ EXPECT_TRUE(absl::c_equal(Partition({{0, 5}}), iterator.GetPartition(0)));
}
{
ShapePartitionIterator iterator(shape, {2});
EXPECT_EQ(2, iterator.GetTotalPartitionCount());
- EXPECT_TRUE(ContainersEqual(Partition({{0, 2}}), iterator.GetPartition(0)));
- EXPECT_TRUE(ContainersEqual(Partition({{2, 3}}), iterator.GetPartition(1)));
+ EXPECT_TRUE(absl::c_equal(Partition({{0, 2}}), iterator.GetPartition(0)));
+ EXPECT_TRUE(absl::c_equal(Partition({{2, 3}}), iterator.GetPartition(1)));
}
{
ShapePartitionIterator iterator(shape, {3});
EXPECT_EQ(3, iterator.GetTotalPartitionCount());
- EXPECT_TRUE(ContainersEqual(Partition({{0, 1}}), iterator.GetPartition(0)));
- EXPECT_TRUE(ContainersEqual(Partition({{1, 1}}), iterator.GetPartition(1)));
- EXPECT_TRUE(ContainersEqual(Partition({{2, 3}}), iterator.GetPartition(2)));
+ EXPECT_TRUE(absl::c_equal(Partition({{0, 1}}), iterator.GetPartition(0)));
+ EXPECT_TRUE(absl::c_equal(Partition({{1, 1}}), iterator.GetPartition(1)));
+ EXPECT_TRUE(absl::c_equal(Partition({{2, 3}}), iterator.GetPartition(2)));
}
}
@@ -128,24 +128,24 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) {
ShapePartitionIterator iterator(shape, {1, 1});
EXPECT_EQ(1, iterator.GetTotalPartitionCount());
EXPECT_TRUE(
- ContainersEqual(Partition({{0, 5}, {0, 3}}), iterator.GetPartition(0)));
+ absl::c_equal(Partition({{0, 5}, {0, 3}}), iterator.GetPartition(0)));
}
{
ShapePartitionIterator iterator(shape, {2, 2});
EXPECT_EQ(4, iterator.GetTotalPartitionCount());
EXPECT_TRUE(
- ContainersEqual(Partition({{0, 2}, {0, 1}}), iterator.GetPartition(0)));
+ absl::c_equal(Partition({{0, 2}, {0, 1}}), iterator.GetPartition(0)));
EXPECT_TRUE(
- ContainersEqual(Partition({{0, 2}, {1, 2}}), iterator.GetPartition(1)));
+ absl::c_equal(Partition({{0, 2}, {1, 2}}), iterator.GetPartition(1)));
EXPECT_TRUE(
- ContainersEqual(Partition({{2, 3}, {0, 1}}), iterator.GetPartition(2)));
+ absl::c_equal(Partition({{2, 3}, {0, 1}}), iterator.GetPartition(2)));
EXPECT_TRUE(
- ContainersEqual(Partition({{2, 3}, {1, 2}}), iterator.GetPartition(3)));
+ absl::c_equal(Partition({{2, 3}, {1, 2}}), iterator.GetPartition(3)));
}
}
-class RandomShapePartitionIteratorTest : public HloTestBase {
+class RandomShapePartitionIteratorTest : public HloVerifiedTestBase {
protected:
typedef std::vector<std::pair<int64, int64>> Partition;
RandomShapePartitionIteratorTest()
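
ContainersEqual is replaced by absl::c_equal throughout this test. A self-contained sketch of the call, using plain std types rather than the test's aliases:

#include <cstdint>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"

// absl::c_equal returns true iff both ranges have the same length and
// element-wise equal contents, which is what the removed helper checked.
bool PartitionsEqualSketch() {
  using Partition = std::vector<std::pair<int64_t, int64_t>>;
  Partition expected = {{0, 2}, {2, 3}};
  Partition actual = {{0, 2}, {2, 3}};
  return absl::c_equal(expected, actual);  // true here
}
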
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index be772cfb7e..bf98064647 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -20,13 +20,13 @@ limitations under the License.
#include <list>
#include <utility>
+#include "absl/memory/memory.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/IR/Mangler.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/Host.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h"
@@ -170,15 +170,14 @@ namespace {
bool RegisterKnownJITSymbols() {
CustomCallTargetRegistry* registry = CustomCallTargetRegistry::Global();
-#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \
- do { \
- auto* function_address = \
- reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \
- registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \
- function_address); \
- CHECK_EQ( \
- tensorflow::StringPiece(xla::cpu::runtime::k##base_name##SymbolName), \
- "__xla_cpu_runtime_" #base_name); \
+#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \
+ do { \
+ auto* function_address = \
+ reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \
+ registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \
+ function_address); \
+ CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \
+ "__xla_cpu_runtime_" #base_name); \
} while (false)
REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue);
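
The CHECK_EQ in the macro above now compares an absl::string_view against the string literal directly, with no c_str() round-trip. A small sketch; the symbol name below is illustrative only, not one registered by the patch:

#include "absl/strings/string_view.h"

// string_view wraps the constant without copying; operator== against a
// literal performs a length-checked byte comparison.
bool SymbolNameMatchesSketch() {
  constexpr absl::string_view kSymbol = "__xla_cpu_runtime_ExampleSymbol";
  return kSymbol == "__xla_cpu_runtime_ExampleSymbol";  // true
}
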
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index 181cec3cdd..c55206eee7 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -48,9 +48,11 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -94,6 +96,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:filecheck",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
"@llvm//:core",
],
)
@@ -108,6 +111,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
],
)
@@ -118,9 +122,11 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
"//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
index 6fcce42eaa..18ee25ba91 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
@@ -19,10 +19,11 @@ limitations under the License.
#include <cctype>
#include <string>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
@@ -69,8 +70,7 @@ TEST_P(CpuEigenDotOperationTest, SimpleDotOp) {
HloInstruction* rhs = builder.AddInstruction(
HloInstruction::CreateParameter(1, param_shape, "input"));
- builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(param_shape, lhs, rhs));
+ builder.AddInstruction(CreateCanonicalDot(param_shape, lhs, rhs));
CompileAndCheck(builder.Build(), spec.filecheck_lines);
}
@@ -87,8 +87,7 @@ TEST_P(CpuEigenDotOperationTest, DotTransposeOp) {
HloInstruction* lhs_transposed = builder.AddInstruction(
HloInstruction::CreateTranspose(param_shape, lhs, {1, 0}));
- builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(param_shape, lhs_transposed, rhs));
+ builder.AddInstruction(CreateCanonicalDot(param_shape, lhs_transposed, rhs));
CompileAndCheck(builder.Build(), spec.filecheck_lines);
}
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
index d98856fdbf..1deb412064 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
@@ -17,15 +17,15 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
@@ -34,7 +34,7 @@ namespace xla {
namespace cpu {
namespace {
-class CpuFusionTest : public HloTestBase {
+class CpuFusionTest : public HloVerifiedTestBase {
protected:
CpuFusionTest() {}
@@ -45,7 +45,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
auto builder = HloComputation::Builder(TestName());
auto input_literal1 = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
auto input_literal2 = LiteralUtil::CreateR1<float>({-2.0, -42.0, 2.0});
- Shape vshape = input_literal1->shape();
+ Shape vshape = input_literal1.shape();
auto input1 = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal1)));
@@ -61,7 +61,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
module->AddEntryComputation(builder.Build());
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
// The computation root instruction was fused. Verify the fusion instruction
// is now the root.
@@ -75,16 +75,16 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
EXPECT_EQ(4, fusion_instruction->fused_instruction_count());
// Compile and execute the computation.
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
// Check the output correctness.
- LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, *result, error_spec_);
+ LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, result, error_spec_);
}
TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
auto builder = HloComputation::Builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
- Shape vshape = input_literal->shape();
+ Shape vshape = input_literal.shape();
auto input = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
@@ -108,7 +108,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
module->AddEntryComputation(builder.Build());
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
// The computation root instruction was fused. Verify the fusion instruction
// is now the root.
@@ -122,20 +122,19 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
EXPECT_EQ(8, fusion_instruction->fused_instruction_count());
// Compile and execute the computation.
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
// Check the output correctness.
- LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, *result,
- error_spec_);
+ LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, result, error_spec_);
}
-TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) {
- // Test a chain of fusable ops with a non-fusable op (a reduce) thrown in the
+TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
+ // Test a chain of fusible ops with a non-fusible op (a reduce) thrown in the
// middle.
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
- Shape vshape = input_literal->shape();
+ Shape vshape = input_literal.shape();
auto input = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
@@ -184,7 +183,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) {
module->AddEntryComputation(builder.Build());
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
// The computation root instruction was fused. Verify the fusion instruction
// is now the root.
@@ -209,11 +208,11 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) {
<< fusion_instruction2->fused_instructions_computation()->ToString();
// Compile and execute the computation.
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
// Check the output correctness.
LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0, 14.0, 40.0, 40.0},
- *result, error_spec_);
+ result, error_spec_);
}
TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
@@ -232,7 +231,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
// each fusion instruction to ensure that negate is not duplicated.
auto builder = HloComputation::Builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
- Shape vshape = input_literal->shape();
+ Shape vshape = input_literal.shape();
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
@@ -256,7 +255,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
// Run fusion.
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
auto fusion1 = result->operand(0);
auto fusion2 = result->operand(1);
@@ -315,7 +314,7 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) {
module->AddEntryComputation(builder.Build());
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
// The only fusion instruction should be operand 0 of the tuple (formerly
// negate1).
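
The move to HloVerifiedTestBase changes ownership: the fixture owns the module and verifies it on teardown, so passes receive the raw module pointer and execution works on a clone. A sketch condensed from the hunks above, assuming only the fixture API they show:

// Fixture-owned module, pass run on the raw pointer, execution on a clone
// so the fixture can still verify the original module afterwards.
TEST_F(CpuFusionTest, PatternSketch) {
  auto builder = HloComputation::Builder(TestName());
  auto literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
  builder.AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
  auto module = CreateNewModule();  // owned by the test fixture
  module->AddEntryComputation(builder.Build());
  CpuInstructionFusion fusion;
  EXPECT_TRUE(fusion.Run(module).ValueOrDie());  // no .get(): raw pointer
  auto result = ExecuteAndTransfer(module->Clone(), {});  // keep the original
  (void)result;
}
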
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
index c35569c661..5cc6d01c0f 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
@@ -58,52 +58,52 @@ class InfeedTest : public ClientLibraryTestBase {
};
TEST_F(InfeedTest, SingleInfeedR0Bool) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
+ TestInfeedRoundTrip(LiteralUtil::CreateR0<bool>(true));
}
TEST_F(InfeedTest, SingleInfeedR1U32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
+ TestInfeedRoundTrip(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}
TEST_F(InfeedTest, SingleInfeedR2F32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
+ TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
}
TEST_F(InfeedTest, SingleInfeedR3F32) {
TestInfeedRoundTrip(
- *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+ LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});
- TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0minor));
- TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0major));
}
TEST_F(InfeedTest, SingleInfeedR4S32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR4(
+ TestInfeedRoundTrip(LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
TEST_F(InfeedTest, SingleInfeedTuple) {
- TestInfeedRoundTrip(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
- LiteralUtil::CreateR0<bool>(false).get()}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<uint32>({1, 2, 3}),
+ LiteralUtil::CreateR0<bool>(false)}));
}
TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
- TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTuple({}));
}
// Tests the Infeed operation used in a while loop, as in the code below. The
@@ -157,21 +157,21 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) {
// Send 5 Infeed data of shape F32[3].
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({1, 2, 3})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({1, 2, 3})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({4, 5, 6})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({4, 5, 6})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({7, 8, 9})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({7, 8, 9})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({10, 11, 12})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({10, 11, 12})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({13, 14, 15})));
+ client_->TransferToInfeed(LiteralUtil::CreateR1<float>({13, 14, 15})));
delete computation_thread; // Joins the thread.
auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();
// Only the first 3 infeed data should be added.
- LiteralTestUtil::ExpectR0Near<float>(45.0f, *result_literal, ErrorSpec{1e-7});
+ LiteralTestUtil::ExpectR0Near<float>(45.0f, result_literal, ErrorSpec{1e-7});
}
// Tests two Infeed operations with a total order. The order is enforced by
@@ -250,17 +250,17 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
// Send the first 4 Infeed data of shape Tuple(F32[2], PRED).
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2}),
+ LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({3, 4}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({3, 4}),
+ LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({5, 6}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({5, 6}),
+ LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8}).get(),
- LiteralUtil::CreateR0<bool>(false).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({7, 8}),
+ LiteralUtil::CreateR0<bool>(false)})));
// Asynchronously launch the execution on the device.
std::unique_ptr<GlobalData> result;
@@ -275,21 +275,21 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
// Infeed data, and send the rest of the Infeed data of shape Tuple(F32[3], PRED).
sleep(1);
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2, 3}),
+ LiteralUtil::CreateR0<bool>(true)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8, 9}).get(),
- LiteralUtil::CreateR0<bool>(false).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({7, 8, 9}),
+ LiteralUtil::CreateR0<bool>(false)})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
- LiteralUtil::CreateR0<bool>(true).get()})));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({4, 5, 6}),
+ LiteralUtil::CreateR0<bool>(true)})));
// Wait for the execution to be done, and transfer the result.
delete computation_thread; // Joins the thread.
auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();
// Only the first 6 infeed data should be added.
- LiteralTestUtil::ExpectR0Near<float>(66.0f, *result_literal, ErrorSpec{1e-7});
+ LiteralTestUtil::ExpectR0Near<float>(66.0f, result_literal, ErrorSpec{1e-7});
}
} // namespace
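
MakeTuple over .get() pointers becomes MakeTupleFromSlices over value literals, matching the value-semantics Literal API. A minimal sketch of the new call shape:

#include "tensorflow/compiler/xla/literal_util.h"

// Builds a (u32[3], pred[]) tuple literal directly from value literals; the
// temporaries are viewed as slices and copied into the resulting tuple.
xla::Literal MakeTupleSketch() {
  return xla::LiteralUtil::MakeTupleFromSlices(
      {xla::LiteralUtil::CreateR1<xla::uint32>({1, 2, 3}),
       xla::LiteralUtil::CreateR0<bool>(true)});
}
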
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc
index 973aac8766..a434c04a98 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc
@@ -17,10 +17,10 @@ limitations under the License.
#include <cctype>
#include <string>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
@@ -32,9 +32,9 @@ const char* const kTriple_android_arm = "armv7-none-android";
struct IntrinsicTestSpec {
HloOpcode opcode;
- tensorflow::StringPiece triple;
- tensorflow::StringPiece features;
- tensorflow::StringPiece check_lines;
+ absl::string_view triple;
+ absl::string_view features;
+ absl::string_view check_lines;
};
// Tests that unary functions get lowered using intrinsic calls.
@@ -65,9 +65,8 @@ class CpuUnaryIntrinsicTest
features = "";
}
- return tensorflow::strings::StrCat(opcode.c_str(), "_On_", triple.c_str(),
- features.empty() ? "" : "_With",
- features.c_str());
+ return absl::StrCat(opcode, "_On_", triple,
+ (features.empty() ? "" : "_With"), features);
}
};
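
absl::StrCat concatenates string_view pieces directly, so the c_str() calls from the tensorflow::strings version drop out. The test-name helper above reduces to this pattern:

#include <string>
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"

// Each argument is formatted and appended in one pass; no intermediate
// C-string conversions are needed.
std::string MakeTestNameSketch(absl::string_view opcode,
                               absl::string_view triple,
                               absl::string_view features) {
  return absl::StrCat(opcode, "_On_", triple,
                      (features.empty() ? "" : "_With"), features);
}
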
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
index 01daed4bcd..7af51db55a 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
@@ -16,9 +16,9 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -41,8 +41,7 @@ class CpuNoAliasTest : public CpuCodegenTest {};
TEST_F(CpuNoAliasTest, Concat) {
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto param_shape = ShapeUtil::MakeShape(F32, {2, 2});
HloInstruction* param_x = builder.AddInstruction(
HloInstruction::CreateParameter(0, param_shape, "x"));
@@ -62,7 +61,8 @@ TEST_F(CpuNoAliasTest, Concat) {
// Now that we have an HLO module, build an llvm_ir::AliasAnalysis for it.
auto status_or_buffer_assn = BufferAssigner::Run(
- hlo_module.get(), MakeUnique<DependencyHloOrdering>(hlo_module.get()),
+ hlo_module.get(),
+ absl::make_unique<DependencyHloOrdering>(hlo_module.get()),
backend().compiler()->BufferSizeBytesFunction(),
[](LogicalBuffer::Color) { return /*alignment=*/1; });
ASSERT_EQ(status_or_buffer_assn.status(), Status::OK());
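
xla::MakeUnique from the removed ptr_util.h gives way to absl::make_unique, a drop-in replacement. A generic sketch, independent of the HLO types above:

#include <memory>
#include "absl/memory/memory.h"

struct OrderingSketch {
  explicit OrderingSketch(int id) : id(id) {}
  int id;
};

// absl::make_unique forwards its arguments to the constructor and returns
// std::unique_ptr<T>, exactly like the helper it replaces.
std::unique_ptr<OrderingSketch> MakeOrderingSketch() {
  return absl::make_unique<OrderingSketch>(42);
}
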
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc
index 780c07f819..e2c7af541e 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc
@@ -54,6 +54,33 @@ CHECK: private constant [48 x i8]
/*match_optimized_ir=*/false);
}
+TEST_F(CpuOutfeedTest, OutfeedTokenInTuple) {
+ const string hlo_text = R"(
+HloModule OutfeedTokenInTuple
+
+ENTRY main {
+ const = f32[] constant(42)
+ epoch = token[] after-all()
+ outfeed.tok = token[] outfeed(const, epoch)
+ ROOT root = (token[], f32[]) tuple(outfeed.tok, const)
+}
+)";
+
+ string filecheck_pattern = R"(
+CHECK: Outfeed
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(hlo_text));
+
+ CpuAotCompilationOptions options{
+ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"",
+ /*entry_point_name=*/"entry",
+ /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static};
+
+ CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern,
+ /*match_optimized_ir=*/false);
+}
} // namespace
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
index 3274be8d9d..1bd4b59dd6 100644
--- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
+++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
+#include "absl/algorithm/container.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@@ -422,12 +423,12 @@ TileVariable::TileVariable(VectorSupportLibrary* vector_support,
std::vector<llvm::Value*> TileVariable::Get() const {
std::vector<llvm::Value*> result;
- c_transform(storage_, std::back_inserter(result),
- [&](VectorVariable vect_var) { return vect_var.Get(); });
+ absl::c_transform(storage_, std::back_inserter(result),
+ [&](VectorVariable vect_var) { return vect_var.Get(); });
return result;
}
-void TileVariable::Set(tensorflow::gtl::ArraySlice<llvm::Value*> value) {
+void TileVariable::Set(absl::Span<llvm::Value* const> value) {
CHECK_EQ(value.size(), storage_.size());
for (int64 i = 0, e = value.size(); i < e; i++) {
storage_[i].Set(value[i]);
diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h
index c728f6df0a..5690d2be2f 100644
--- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h
+++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h
@@ -18,12 +18,12 @@ limitations under the License.
#include <string>
+#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace cpu {
@@ -324,7 +324,7 @@ class TileVariable {
std::vector<llvm::Value*> initial_value);
std::vector<llvm::Value*> Get() const;
- void Set(tensorflow::gtl::ArraySlice<llvm::Value*> value);
+ void Set(absl::Span<llvm::Value* const> value);
private:
std::vector<VectorVariable> storage_;
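
tensorflow::gtl::ArraySlice<T*> becomes absl::Span<T* const> in these signatures. A self-contained sketch with plain ints, showing where the const lands and the implicit conversion from a vector:

#include <vector>
#include "absl/types/span.h"

// A Span is a non-owning view; here 'const' applies to the pointer elements
// (they cannot be reseated through the view), while the pointees stay
// mutable. A std::vector<int*> converts to this parameter implicitly.
int CountNonNullSketch(absl::Span<int* const> values) {
  int n = 0;
  for (int* v : values) {
    if (v != nullptr) ++n;
  }
  return n;
}
// Usage: std::vector<int*> ptrs = {&a, &b}; CountNonNullSketch(ptrs);
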
diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc
index 47543b2082..b9e47f5aad 100644
--- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc
@@ -37,7 +37,7 @@ void XfeedQueueManager::Reset() {
}
void XfeedQueueManager::EnqueueBuffersAtomically(
- tensorflow::gtl::ArraySlice<XfeedBuffer*> buffers) {
+ absl::Span<XfeedBuffer* const> buffers) {
tensorflow::mutex_lock l(mu_);
bool was_empty = enqueued_buffers_.empty();
for (XfeedBuffer* b : buffers) {
diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h
index b4ace23260..990ff94ba2 100644
--- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h
+++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h
@@ -22,10 +22,10 @@ limitations under the License.
#include <deque>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mutex.h"
namespace xla {
@@ -63,8 +63,7 @@ class XfeedQueueManager {
// called when the buffer will no longer be accessed by the XfeedManager,
// either as a result of a call to Reset or because the runtime has dequeued
// and used the buffer.
- void EnqueueBuffersAtomically(
- tensorflow::gtl::ArraySlice<XfeedBuffer*> buffers);
+ void EnqueueBuffersAtomically(absl::Span<XfeedBuffer* const> buffers);
// Blocks until the queue is non-empty, then returns the buffer at the head of
// the queue. Sets the current buffer to be the returned buffer. It is an
diff --git a/tensorflow/compiler/xla/service/defuser.h b/tensorflow/compiler/xla/service/defuser.h
index 56b28fd22d..c326beb899 100644
--- a/tensorflow/compiler/xla/service/defuser.h
+++ b/tensorflow/compiler/xla/service/defuser.h
@@ -29,7 +29,7 @@ class Defuser : public HloPassInterface {
public:
Defuser() {}
~Defuser() override {}
- tensorflow::StringPiece name() const override { return "defuser"; }
+ absl::string_view name() const override { return "defuser"; }
// Run defusion on the given module. Returns whether the module was
// changed.
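
HLO passes now report their name as absl::string_view. A minimal pass skeleton, assuming only the HloPassInterface surface visible in these headers:

// The view refers to a string literal, which has static storage duration,
// so it never dangles; Run reports whether the module changed.
class ExamplePassSketch : public HloPassInterface {
 public:
  absl::string_view name() const override { return "example-pass-sketch"; }
  StatusOr<bool> Run(HloModule* module) override {
    return false;  // sketch: no change made
  }
};
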
diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc
index 48e4471499..ba2a674d9a 100644
--- a/tensorflow/compiler/xla/service/despecializer.cc
+++ b/tensorflow/compiler/xla/service/despecializer.cc
@@ -27,9 +27,7 @@ namespace {
class ControlDepRemover : public HloPassInterface {
public:
ControlDepRemover() = default;
- tensorflow::StringPiece name() const override {
- return "control-dep-remover";
- }
+ absl::string_view name() const override { return "control-dep-remover"; }
StatusOr<bool> Run(HloModule* module) override {
bool changed = false;
diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h
index cc1695b7f8..7be70add2f 100644
--- a/tensorflow/compiler/xla/service/despecializer.h
+++ b/tensorflow/compiler/xla/service/despecializer.h
@@ -33,7 +33,7 @@ namespace xla {
class Despecializer : public HloPassInterface {
public:
Despecializer();
- tensorflow::StringPiece name() const override { return "despecializer"; }
+ absl::string_view name() const override { return "despecializer"; }
StatusOr<bool> Run(HloModule* module) override;
private:
diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc
index e228bb56bc..edbcb25247 100644
--- a/tensorflow/compiler/xla/service/device_memory_allocator.cc
+++ b/tensorflow/compiler/xla/service/device_memory_allocator.cc
@@ -25,7 +25,7 @@ namespace xla {
StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
const se::Platform* platform,
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors)
+ absl::Span<se::StreamExecutor* const> stream_executors)
: DeviceMemoryAllocator(platform),
stream_executors_(stream_executors.begin(), stream_executors.end()) {}
@@ -36,9 +36,8 @@ StatusOr<OwningDeviceMemory> StreamExecutorMemoryAllocator::Allocate(
se::DeviceMemoryBase result = stream_executor->AllocateArray<uint8>(size);
if (size > 0 && result == nullptr) {
return ResourceExhausted(
- "Failed to allocate request for %s (%lluB) on device ordinal %d",
- tensorflow::strings::HumanReadableNumBytes(size).c_str(), size,
- device_ordinal);
+ "Failed to allocate request for %s (%uB) on device ordinal %d",
+ tensorflow::strings::HumanReadableNumBytes(size), size, device_ordinal);
}
return OwningDeviceMemory(result, device_ordinal, this);
}
@@ -61,12 +60,12 @@ StatusOr<se::StreamExecutor*> StreamExecutorMemoryAllocator::GetStreamExecutor(
}
if (device_ordinal >= stream_executors_.size()) {
return InvalidArgument(
- "device ordinal value (%d) >= number of devices (%zu)", device_ordinal,
+ "device ordinal value (%d) >= number of devices (%u)", device_ordinal,
stream_executors_.size());
}
if (stream_executors_[device_ordinal] == nullptr) {
return NotFound("Device %s:%d present but not supported",
- platform()->Name().c_str(), device_ordinal);
+ platform()->Name(), device_ordinal);
}
return stream_executors_[device_ordinal];
}
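
The error helpers here now use absl::StrFormat-style typed formatting: %s binds a std::string or string_view directly, and integer specifiers accept any width, which is why the c_str() calls and the %llu/%zu width juggling disappear. An illustrative sketch using absl::StrFormat itself:

#include <cstdint>
#include <string>
#include "absl/strings/str_format.h"

// %s takes the std::string as-is; %u and %d work for any integer width.
std::string AllocErrorSketch(const std::string& human_size, uint64_t size,
                             int device_ordinal) {
  return absl::StrFormat(
      "Failed to allocate request for %s (%uB) on device ordinal %d",
      human_size, size, device_ordinal);
}
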
diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.h b/tensorflow/compiler/xla/service/device_memory_allocator.h
index d87b86caf0..a2308ee7a4 100644
--- a/tensorflow/compiler/xla/service/device_memory_allocator.h
+++ b/tensorflow/compiler/xla/service/device_memory_allocator.h
@@ -18,10 +18,10 @@ limitations under the License.
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/owning_device_memory.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
@@ -80,7 +80,7 @@ class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator {
public:
StreamExecutorMemoryAllocator(
const se::Platform* platform,
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors);
+ absl::Span<se::StreamExecutor* const> stream_executors);
StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
bool retry_on_failure) override;
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc
index 2172ae0a29..3e7373adc5 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc
@@ -28,14 +28,14 @@ template <typename HloInstructionPtr>
Status DfsHloVisitorBase<HloInstructionPtr>::HandleElementwiseUnary(
HloInstructionPtr hlo) {
return Unimplemented("DfsHloVisitor::HandleElementwiseUnary: %s",
- HloOpcodeString(hlo->opcode()).c_str());
+ HloOpcodeString(hlo->opcode()));
}
template <typename HloInstructionPtr>
Status DfsHloVisitorBase<HloInstructionPtr>::HandleElementwiseBinary(
HloInstructionPtr hlo) {
return Unimplemented("DfsHloVisitor::HandleElementwiseBinary: %s",
- HloOpcodeString(hlo->opcode()).c_str());
+ HloOpcodeString(hlo->opcode()));
}
template <typename HloInstructionPtr>
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 86d57581f8..5761573791 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -19,14 +19,14 @@ limitations under the License.
#include <type_traits>
#include <vector>
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -107,6 +107,7 @@ class DfsHloVisitorBase {
virtual Status HandleFft(HloInstructionPtr fft) = 0;
virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0;
virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0;
+ virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0;
virtual Status HandleCompare(HloInstructionPtr hlo) {
return HandleElementwiseBinary(hlo);
}
@@ -208,7 +209,6 @@ class DfsHloVisitorBase {
virtual Status HandleInfeed(HloInstructionPtr hlo) = 0;
virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0;
- virtual Status HandleHostCompute(HloInstructionPtr hlo) = 0;
virtual Status HandleRng(HloInstructionPtr hlo) = 0;
virtual Status HandleReverse(HloInstructionPtr hlo) = 0;
virtual Status HandleSort(HloInstructionPtr hlo) = 0;
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index 617a5a2eb4..4cd10ab06c 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -16,14 +16,14 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -94,8 +94,11 @@ class DfsHloVisitorWithDefaultBase
Status HandleCrossReplicaSum(HloInstructionPtr crs) override {
return DefaultAction(crs);
}
- Status HandleAllToAll(HloInstructionPtr crs) override {
- return DefaultAction(crs);
+ Status HandleAllToAll(HloInstructionPtr hlo) override {
+ return DefaultAction(hlo);
+ }
+ Status HandleCollectivePermute(HloInstructionPtr hlo) override {
+ return DefaultAction(hlo);
}
Status HandleRng(HloInstructionPtr random) override {
return DefaultAction(random);
@@ -106,9 +109,6 @@ class DfsHloVisitorWithDefaultBase
Status HandleOutfeed(HloInstructionPtr outfeed) override {
return DefaultAction(outfeed);
}
- Status HandleHostCompute(HloInstructionPtr host_compute) override {
- return DefaultAction(host_compute);
- }
Status HandleReverse(HloInstructionPtr reverse) override {
return DefaultAction(reverse);
}
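
A concrete visitor only overrides the handlers it needs; every other opcode, including the newly added collective permute, falls through to DefaultAction. A minimal sketch against the interface above:

// Counts collective-permute instructions and ignores everything else.
class CollectivePermuteCounterSketch : public DfsHloVisitorWithDefault {
 public:
  Status DefaultAction(HloInstruction* hlo) override { return Status::OK(); }
  Status HandleCollectivePermute(HloInstruction* hlo) override {
    ++count_;
    return Status::OK();
  }
  int count_ = 0;
};
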
diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc
index 12faed6967..b2ba261790 100644
--- a/tensorflow/compiler/xla/service/dot_decomposer.cc
+++ b/tensorflow/compiler/xla/service/dot_decomposer.cc
@@ -134,8 +134,9 @@ Status DecomposeBatchDot(HloInstruction* dot) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot_r2 = computation->AddInstruction(HloInstruction::CreateDot(
- dot_shape_r2, lhs_slice_r2, rhs_slice_r2, dot_dnums));
+ auto dot_r2 = computation->AddInstruction(
+ HloInstruction::CreateDot(dot_shape_r2, lhs_slice_r2, rhs_slice_r2,
+ dot_dnums, dot->precision_config()));
// Reshape Dot to R3 so we can concat along batch dimension.
auto dot_r3 = computation->AddInstruction(
diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h
index 1959b687f1..fc38e31700 100644
--- a/tensorflow/compiler/xla/service/dot_decomposer.h
+++ b/tensorflow/compiler/xla/service/dot_decomposer.h
@@ -29,7 +29,7 @@ class DotDecomposer : public HloPassInterface {
DotDecomposer(bool decompose_batch_dot = true)
: decompose_batch_dot_(decompose_batch_dot) {}
~DotDecomposer() = default;
- tensorflow::StringPiece name() const override { return "dot_decomposer"; }
+ absl::string_view name() const override { return "dot_decomposer"; }
// Run DotDecomposer pass on computations in 'module'.
// Returns whether the 'module' was changed.
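
HloInstruction::CreateDot now takes the precision config explicitly, and the decomposer hunk above forwards the original dot's config. A sketch of that pattern; the config's proto type is not spelled out in the patch, so it is passed through opaquely here:

// Rebuild a rank-2 dot with canonical contracting dimensions, carrying over
// the original instruction's precision config unchanged.
std::unique_ptr<HloInstruction> RebuildDotSketch(const Shape& shape,
                                                 HloInstruction* lhs,
                                                 HloInstruction* rhs,
                                                 const HloInstruction* dot) {
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  return HloInstruction::CreateDot(shape, lhs, rhs, dnums,
                                   dot->precision_config());
}
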
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 2e9d6be2de..4bb1e071d8 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -21,11 +21,15 @@ limitations under the License.
#include <vector>
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
@@ -38,17 +42,16 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/random/random.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
+using absl::StrCat;
using llvm_ir::AsStringRef;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;
-using tensorflow::strings::StrCat;
namespace {
@@ -203,7 +206,7 @@ llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value,
} // namespace
StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const {
+ const HloInstruction* op, llvm::Value* operand_value) {
if (op->opcode() == HloOpcode::kCopy) {
return operand_value;
} else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
@@ -217,7 +220,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const {
+ const HloInstruction* op, llvm::Value* operand_value) {
switch (op->opcode()) {
case HloOpcode::kConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
@@ -229,14 +232,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
}
if (to_type == PRED) {
return b_->CreateZExt(
- b_->CreateICmpNE(operand_value, llvm::ConstantInt::get(
- operand_value->getType(), 0)),
+ ICmpNE(operand_value,
+ llvm::ConstantInt::get(operand_value->getType(), 0)),
llvm_ir::PrimitiveTypeToIrType(PRED, module_));
}
if (primitive_util::IsIntegralType(to_type)) {
- return b_->CreateIntCast(
- operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_),
- primitive_util::IsSignedIntegralType(from_type));
+ return IntCast(operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, module_),
+ primitive_util::IsSignedIntegralType(from_type));
}
if (primitive_util::IsFloatingPointType(to_type)) {
if (to_type == BF16) {
@@ -252,19 +255,17 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
primitive_util::ComplexComponentType(to_type), module_);
if (primitive_util::IsSignedIntegralType(from_type)) {
return EmitComposeComplex(
- op, b_->CreateSIToFP(operand_value, to_ir_component_type),
- nullptr);
+ op, SIToFP(operand_value, to_ir_component_type), nullptr);
}
if (primitive_util::IsUnsignedIntegralType(from_type) ||
from_type == PRED) {
return EmitComposeComplex(
- op, b_->CreateUIToFP(operand_value, to_ir_component_type),
- nullptr);
+ op, UIToFP(operand_value, to_ir_component_type), nullptr);
}
}
return Unimplemented("conversion from primitive type %s to %s",
- PrimitiveType_Name(from_type).c_str(),
- PrimitiveType_Name(to_type).c_str());
+ PrimitiveType_Name(from_type),
+ PrimitiveType_Name(to_type));
}
case HloOpcode::kBitcastConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
@@ -275,14 +276,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
}
if (primitive_util::BitWidth(from_type) ==
primitive_util::BitWidth(to_type)) {
- return b_->CreateBitCast(
- operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
+ return BitCast(operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
return InvalidArgument(
"bitcast conversion from primitive type %s to %s with unequal "
"bit-widths (%u versus %u) ",
- PrimitiveType_Name(from_type).c_str(),
- PrimitiveType_Name(to_type).c_str(),
+ PrimitiveType_Name(from_type), PrimitiveType_Name(to_type),
primitive_util::BitWidth(from_type),
primitive_util::BitWidth(to_type));
}
@@ -292,10 +292,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
if (is_signed) {
auto type =
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
- auto zero = llvm::ConstantInt::get(type, 0);
- auto cmp = b_->CreateICmpSGE(operand_value, zero);
- return b_->CreateSelect(cmp, operand_value,
- b_->CreateNeg(operand_value));
+ auto cmp = ICmpSGE(operand_value, GetZero(type));
+ return Select(cmp, operand_value, Neg(operand_value));
} else {
return operand_value;
}
@@ -307,44 +305,37 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
{operand_value->getType()}, b_);
}
case HloOpcode::kSign: {
- bool is_signed =
- primitive_util::IsSignedIntegralType(op->shape().element_type());
+ CHECK(primitive_util::IsSignedIntegralType(op->shape().element_type()))
+ << op->shape().element_type();
auto type =
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
- auto zero = llvm::ConstantInt::get(type, 0);
- auto cmp = b_->CreateICmpEQ(operand_value, zero);
- if (is_signed) {
- auto ashr =
- b_->CreateAShr(operand_value, type->getIntegerBitWidth() - 1);
- return b_->CreateSelect(cmp, zero, b_->CreateOr(ashr, 1));
- } else {
- return b_->CreateSelect(cmp, zero, llvm::ConstantInt::get(type, 1));
- }
+ auto cmp = ICmpEQ(operand_value, GetZero(type));
+ auto ashr = AShr(operand_value, type->getIntegerBitWidth() - 1);
+ return Select(cmp, GetZero(type), Or(ashr, 1));
}
case HloOpcode::kNegate:
- return b_->CreateNeg(operand_value);
+ return Neg(operand_value);
case HloOpcode::kNot: {
auto type = op->shape().element_type();
if (type == PRED) {
// It is not sufficient to just call CreateNot() here because a PRED
// is represented as an i8 and the truth value is stored only in the
// bottom bit.
- return b_->CreateZExt(
- b_->CreateNot(b_->CreateTrunc(operand_value, b_->getInt1Ty())),
- llvm_ir::PrimitiveTypeToIrType(PRED, module_));
+ return b_->CreateZExt(Not(Trunc(operand_value, b_->getInt1Ty())),
+ llvm_ir::PrimitiveTypeToIrType(PRED, module_));
} else if (primitive_util::IsIntegralType(type)) {
- return b_->CreateNot(operand_value);
+ return Not(operand_value);
}
return Unimplemented("unary op Not is not defined for type '%d'", type);
}
default:
return Unimplemented("unary integer op '%s'",
- HloOpcodeString(op->opcode()).c_str());
+ HloOpcodeString(op->opcode()));
}
}
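// Note: the short helpers used from here on (Neg, Not, Trunc, Select,
// ICmpEQ, FMul, ...) are assumed to be thin forwarding wrappers over the
// wrapped llvm::IRBuilder, which is what the mechanical rewrites in this
// file suggest. Sketch of the assumed expansion for one of them:
static llvm::Value* FMulSketch(llvm::IRBuilder<>* b, llvm::Value* lhs,
                               llvm::Value* rhs) {
  return b->CreateFMul(lhs, rhs);  // what FMul(lhs, rhs) stands for
}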
StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const {
+ const HloInstruction* op, llvm::Value* operand_value) {
switch (op->opcode()) {
case HloOpcode::kConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
@@ -361,8 +352,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
}
return EmitComposeComplex(
op,
- b_->CreateFPCast(operand_value, llvm_ir::PrimitiveTypeToIrType(
- to_component_type, module_)),
+ FPCast(operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)),
nullptr);
}
if (from_type == BF16) {
@@ -378,26 +369,25 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
}
if (to_type == PRED) {
return b_->CreateZExt(
- b_->CreateFCmpUNE(
- operand_value,
- llvm::ConstantFP::get(operand_value->getType(), 0.0)),
+ FCmpUNE(operand_value,
+ llvm::ConstantFP::get(operand_value->getType(), 0.0)),
llvm_ir::PrimitiveTypeToIrType(PRED, module_));
}
if (primitive_util::IsFloatingPointType(to_type)) {
- return b_->CreateFPCast(
- operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
+ return FPCast(operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
if (primitive_util::IsSignedIntegralType(to_type)) {
- return b_->CreateFPToSI(
- operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
+ return FPToSI(operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
if (primitive_util::IsUnsignedIntegralType(to_type)) {
- return b_->CreateFPToUI(
- operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
+ return FPToUI(operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
return Unimplemented("unhandled conversion operation: %s => %s",
- PrimitiveType_Name(from_type).c_str(),
- PrimitiveType_Name(to_type).c_str());
+ PrimitiveType_Name(from_type),
+ PrimitiveType_Name(to_type));
}
case HloOpcode::kBitcastConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
@@ -408,14 +398,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
}
if (primitive_util::BitWidth(from_type) ==
primitive_util::BitWidth(to_type)) {
- return b_->CreateBitCast(
- operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
+ return BitCast(operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
return InvalidArgument(
"bitcast conversion from primitive type %s to %s with unequal "
"bit-widths (%u versus %u) ",
- PrimitiveType_Name(from_type).c_str(),
- PrimitiveType_Name(to_type).c_str(),
+ PrimitiveType_Name(from_type), PrimitiveType_Name(to_type),
primitive_util::BitWidth(from_type),
primitive_util::BitWidth(to_type));
}
@@ -453,11 +442,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
// TODO(b/32151903): Ensure consistent sign behavior for -0.0.
auto type = operand_value->getType();
auto zero = llvm::ConstantFP::get(type, 0.0);
- auto oeq = b_->CreateFCmpOEQ(operand_value, zero);
- auto olt = b_->CreateFCmpOLT(operand_value, zero);
- return b_->CreateSelect(
- oeq, zero,
- b_->CreateSelect(olt, llvm::ConstantFP::get(type, -1.0),
+ auto oeq = FCmpOEQ(operand_value, zero);
+ auto olt = FCmpOLT(operand_value, zero);
+ return Select(oeq, zero,
+ Select(olt, llvm::ConstantFP::get(type, -1.0),
llvm::ConstantFP::get(type, 1.0)));
}
case HloOpcode::kIsFinite: {
@@ -467,24 +455,24 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
auto abs_value = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::fabs, {operand_value}, {type}, b_);
auto infinity = llvm::ConstantFP::getInfinity(type);
- auto not_infinite = b_->CreateFCmpONE(abs_value, infinity);
+ auto not_infinite = FCmpONE(abs_value, infinity);
return b_->CreateZExt(not_infinite,
llvm_ir::PrimitiveTypeToIrType(PRED, module_));
}
case HloOpcode::kNegate:
- return b_->CreateFNeg(operand_value);
+ return FNeg(operand_value);
case HloOpcode::kReal:
return operand_value;
case HloOpcode::kImag:
return llvm::ConstantFP::get(operand_value->getType(), 0.0);
default:
return Unimplemented("unary floating-point op '%s'",
- HloOpcodeString(op->opcode()).c_str());
+ HloOpcodeString(op->opcode()));
}
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const {
+ const HloInstruction* op, llvm::Value* operand_value) {
PrimitiveType input_type = op->operand(0)->shape().element_type();
PrimitiveType component_type =
primitive_util::IsComplexType(input_type)
@@ -496,12 +484,11 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto a = EmitExtractReal(operand_value);
auto b = EmitExtractImag(operand_value);
llvm::Type* llvm_ty = a->getType();
- auto sum_sq = b_->CreateFAdd(b_->CreateFMul(a, a), b_->CreateFMul(b, b));
+ auto sum_sq = FAdd(FMul(a, a), FMul(b, b));
TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a));
auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
- return EmitComposeComplex(op, b_->CreateFMul(one_half, log_sum_sq),
- angle);
+ return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
}
case HloOpcode::kLog1p: {
// log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
@@ -509,14 +496,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto b = EmitExtractImag(operand_value);
llvm::Type* llvm_ty = a->getType();
auto one = llvm::ConstantFP::get(llvm_ty, 1.0);
- auto a_plus_one = b_->CreateFAdd(a, one);
- auto sum_sq = b_->CreateFAdd(b_->CreateFMul(a_plus_one, a_plus_one),
- b_->CreateFMul(b, b));
+ auto a_plus_one = FAdd(a, one);
+ auto sum_sq = FAdd(FMul(a_plus_one, a_plus_one), FMul(b, b));
TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a_plus_one));
auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
- return EmitComposeComplex(op, b_->CreateFMul(one_half, log_sum_sq),
- angle);
+ return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
}
case HloOpcode::kConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
@@ -530,11 +515,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
primitive_util::ComplexComponentType(to_type);
auto to_ir_component_type =
llvm_ir::PrimitiveTypeToIrType(to_component_type, module_);
- return EmitComposeComplex(op,
- b_->CreateFPCast(EmitExtractReal(operand_value),
- to_ir_component_type),
- b_->CreateFPCast(EmitExtractImag(operand_value),
- to_ir_component_type));
+ return EmitComposeComplex(
+ op, FPCast(EmitExtractReal(operand_value), to_ir_component_type),
+ FPCast(EmitExtractImag(operand_value), to_ir_component_type));
}
case HloOpcode::kExp: {
// e^(a+bi) = e^a*(cos(b)+sin(b)i)
@@ -544,8 +527,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
TF_ASSIGN_OR_RETURN(
auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
- return EmitComposeComplex(op, b_->CreateFMul(exp_a, cos_b),
- b_->CreateFMul(exp_a, sin_b));
+ return EmitComposeComplex(op, FMul(exp_a, cos_b), FMul(exp_a, sin_b));
}
case HloOpcode::kExpm1: {
// e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
@@ -556,8 +538,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
TF_ASSIGN_OR_RETURN(
auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0);
- auto real_result = b_->CreateFSub(b_->CreateFMul(exp_a, cos_b), one);
- auto imag_result = b_->CreateFMul(exp_a, sin_b);
+ auto real_result = FSub(FMul(exp_a, cos_b), one);
+ auto imag_result = FMul(exp_a, sin_b);
return EmitComposeComplex(op, real_result, imag_result);
}
case HloOpcode::kCos: {
@@ -572,14 +554,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto b = EmitExtractImag(operand_value);
auto type = a->getType();
TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
- auto half_exp_b = b_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
- auto half_exp_neg_b =
- b_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
+ auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
+ auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
- return EmitComposeComplex(
- op, b_->CreateFMul(cos_a, b_->CreateFAdd(half_exp_neg_b, half_exp_b)),
- b_->CreateFMul(sin_a, b_->CreateFSub(half_exp_neg_b, half_exp_b)));
+ return EmitComposeComplex(op,
+ FMul(cos_a, FAdd(half_exp_neg_b, half_exp_b)),
+ FMul(sin_a, FSub(half_exp_neg_b, half_exp_b)));
}
case HloOpcode::kSin: {
// sin(z) = .5i(e^(-iz) - e^(iz))
@@ -595,14 +576,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto b = EmitExtractImag(operand_value);
auto type = a->getType();
TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
- auto half_exp_b = b_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
- auto half_exp_neg_b =
- b_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
+ auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
+ auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
- return EmitComposeComplex(
- op, b_->CreateFMul(sin_a, b_->CreateFAdd(half_exp_b, half_exp_neg_b)),
- b_->CreateFMul(cos_a, b_->CreateFSub(half_exp_b, half_exp_neg_b)));
+ return EmitComposeComplex(op,
+ FMul(sin_a, FAdd(half_exp_b, half_exp_neg_b)),
+ FMul(cos_a, FSub(half_exp_b, half_exp_neg_b)));
}
case HloOpcode::kTanh: {
/*
@@ -630,74 +610,63 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a));
TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b));
TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b));
- auto exp_neg_a =
- b_->CreateFDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a);
- auto exp_2a_minus_exp_neg_2a = b_->CreateFSub(
- b_->CreateFMul(exp_a, exp_a), b_->CreateFMul(exp_neg_a, exp_neg_a));
- auto cos_b_sq = b_->CreateFMul(cos_b, cos_b);
- auto sin_b_sq = b_->CreateFMul(sin_b, sin_b);
- auto real_num =
- b_->CreateFAdd(b_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a),
- b_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a));
- auto cos_b_sin_b = b_->CreateFMul(cos_b, sin_b);
- auto exp_a_plus_exp_neg_a = b_->CreateFAdd(exp_a, exp_neg_a);
+ auto exp_neg_a = FDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a);
+ auto exp_2a_minus_exp_neg_2a =
+ FSub(FMul(exp_a, exp_a), FMul(exp_neg_a, exp_neg_a));
+ auto cos_b_sq = FMul(cos_b, cos_b);
+ auto sin_b_sq = FMul(sin_b, sin_b);
+ auto real_num = FAdd(FMul(cos_b_sq, exp_2a_minus_exp_neg_2a),
+ FMul(sin_b_sq, exp_2a_minus_exp_neg_2a));
+ auto cos_b_sin_b = FMul(cos_b, sin_b);
+ auto exp_a_plus_exp_neg_a = FAdd(exp_a, exp_neg_a);
auto exp_a_plus_exp_neg_a_sq =
- b_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a);
- auto exp_a_minus_exp_neg_a = b_->CreateFSub(exp_a, exp_neg_a);
+ FMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a);
+ auto exp_a_minus_exp_neg_a = FSub(exp_a, exp_neg_a);
auto exp_a_minus_exp_neg_a_sq =
- b_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a);
- auto imag_num = b_->CreateFMul(
- cos_b_sin_b,
- b_->CreateFSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq));
- auto denom =
- b_->CreateFAdd(b_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq),
- b_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq));
- return EmitComposeComplex(op, b_->CreateFDiv(real_num, denom),
- b_->CreateFDiv(imag_num, denom));
+ FMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a);
+ auto imag_num = FMul(
+ cos_b_sin_b, FSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq));
+ auto denom = FAdd(FMul(cos_b_sq, exp_a_plus_exp_neg_a_sq),
+ FMul(sin_b_sq, exp_a_minus_exp_neg_a_sq));
+ return EmitComposeComplex(op, FDiv(real_num, denom),
+ FDiv(imag_num, denom));
}
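A compact way to check the algebra above, in LaTeX:

    \tanh(a+bi) = \frac{\sinh 2a + i\,\sin 2b}{\cosh 2a + \cos 2b}

Using e^{2a} - e^{-2a} = 2\sinh 2a, (e^{a}+e^{-a})^2 - (e^{a}-e^{-a})^2 = 4, and \cos^2 b + \sin^2 b = 1, the real_num/denom and imag_num/denom quotients reduce to exactly this form.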
case HloOpcode::kAbs: {
- auto sum_sq =
- b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(operand_value),
- EmitExtractReal(operand_value)),
- b_->CreateFMul(EmitExtractImag(operand_value),
- EmitExtractImag(operand_value)));
+ auto sum_sq = FAdd(
+ FMul(EmitExtractReal(operand_value), EmitExtractReal(operand_value)),
+ FMul(EmitExtractImag(operand_value), EmitExtractImag(operand_value)));
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq},
{sum_sq->getType()}, b_);
}
case HloOpcode::kSign: { // Sign(c) = c / |c|
- auto sum_sq =
- b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(operand_value),
- EmitExtractReal(operand_value)),
- b_->CreateFMul(EmitExtractImag(operand_value),
- EmitExtractImag(operand_value)));
+ auto sum_sq = FAdd(
+ FMul(EmitExtractReal(operand_value), EmitExtractReal(operand_value)),
+ FMul(EmitExtractImag(operand_value), EmitExtractImag(operand_value)));
auto cplx_abs = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, b_);
auto type = cplx_abs->getType();
auto zero = llvm::ConstantFP::get(type, 0.0);
- auto oeq = b_->CreateFCmpOEQ(cplx_abs, zero);
- return b_->CreateSelect(
+ auto oeq = FCmpOEQ(cplx_abs, zero);
+ return Select(
oeq, EmitComposeComplex(op, zero, zero),
- EmitComposeComplex(
- op, b_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs),
- b_->CreateFDiv(EmitExtractImag(operand_value), cplx_abs)));
+ EmitComposeComplex(op, FDiv(EmitExtractReal(operand_value), cplx_abs),
+ FDiv(EmitExtractImag(operand_value), cplx_abs)));
}
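The Select above encodes the scalar semantics below; a minimal host-side sketch for reference (ComplexSign is a hypothetical name, not emitter code):

    #include <cmath>
    #include <complex>

    // Scalar reference for the kSign lowering: c / |c|, with |c| == 0
    // mapped to 0+0i instead of NaN (the oeq Select above).
    std::complex<float> ComplexSign(std::complex<float> c) {
      float abs = std::hypot(c.real(), c.imag());  // sqrt(re^2 + im^2)
      if (abs == 0.0f) return {0.0f, 0.0f};
      return {c.real() / abs, c.imag() / abs};
    }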
case HloOpcode::kNegate:
- return EmitComposeComplex(op,
- b_->CreateFNeg(EmitExtractReal(operand_value)),
- b_->CreateFNeg(EmitExtractImag(operand_value)));
+ return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)),
+ FNeg(EmitExtractImag(operand_value)));
case HloOpcode::kReal:
return EmitExtractReal(operand_value);
case HloOpcode::kImag:
return EmitExtractImag(operand_value);
default:
return Unimplemented("unary complex op '%s'",
- HloOpcodeString(op->opcode()).c_str());
+ HloOpcodeString(op->opcode()));
}
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const {
+ const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
PrimitiveType operand_type = op->operand(0)->shape().element_type();
if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
operand_type == PRED) {
@@ -712,21 +681,20 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const {
+ const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
switch (op->opcode()) {
case HloOpcode::kComplex:
return EmitComposeComplex(op, lhs_value, rhs_value);
case HloOpcode::kAdd:
- return b_->CreateFAdd(lhs_value, rhs_value);
+ return FAdd(lhs_value, rhs_value);
case HloOpcode::kSubtract:
- return b_->CreateFSub(lhs_value, rhs_value);
+ return FSub(lhs_value, rhs_value);
case HloOpcode::kMultiply:
- return b_->CreateFMul(lhs_value, rhs_value);
+ return FMul(lhs_value, rhs_value);
case HloOpcode::kDivide:
- return b_->CreateFDiv(lhs_value, rhs_value);
+ return FDiv(lhs_value, rhs_value);
case HloOpcode::kRemainder:
- return b_->CreateFRem(lhs_value, rhs_value);
+ return FRem(lhs_value, rhs_value);
// LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
// comparisons always return false when one of the operands is NaN, whereas
// unordered comparisons return true.
@@ -763,66 +731,52 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value);
default:
return Unimplemented("binary floating point op '%s'",
- HloOpcodeString(op->opcode()).c_str());
+ HloOpcodeString(op->opcode()));
}
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const {
+ const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
switch (op->opcode()) {
case HloOpcode::kAdd:
- return EmitComposeComplex(op,
- b_->CreateFAdd(EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value)),
- b_->CreateFAdd(EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value)));
+ return EmitComposeComplex(
+ op, FAdd(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
+ FAdd(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
case HloOpcode::kSubtract:
- return EmitComposeComplex(op,
- b_->CreateFSub(EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value)),
- b_->CreateFSub(EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value)));
+ return EmitComposeComplex(
+ op, FSub(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
+ FSub(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
case HloOpcode::kMultiply:
return EmitComposeComplex(
op,
- b_->CreateFSub(b_->CreateFMul(EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value)),
- b_->CreateFMul(EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value))),
- b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(lhs_value),
- EmitExtractImag(rhs_value)),
- b_->CreateFMul(EmitExtractImag(lhs_value),
- EmitExtractReal(rhs_value))));
+ FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
+ FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))),
+ FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
+ FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))));
case HloOpcode::kDivide: {
// (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di))
// = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
auto rhs_sum_sq =
- b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(rhs_value),
- EmitExtractReal(rhs_value)),
- b_->CreateFMul(EmitExtractImag(rhs_value),
- EmitExtractImag(rhs_value)));
+ FAdd(FMul(EmitExtractReal(rhs_value), EmitExtractReal(rhs_value)),
+ FMul(EmitExtractImag(rhs_value), EmitExtractImag(rhs_value)));
auto type = rhs_sum_sq->getType();
auto zero = llvm::ConstantFP::get(type, 0.0);
- auto oeq = b_->CreateFCmpOEQ(rhs_sum_sq, zero);
- auto real_inf_or_nan = b_->CreateFDiv(EmitExtractReal(lhs_value), zero);
- auto imag_inf_or_nan = b_->CreateFDiv(EmitExtractImag(lhs_value), zero);
- return b_->CreateSelect(
+ auto oeq = FCmpOEQ(rhs_sum_sq, zero);
+ auto real_inf_or_nan = FDiv(EmitExtractReal(lhs_value), zero);
+ auto imag_inf_or_nan = FDiv(EmitExtractImag(lhs_value), zero);
+ return Select(
oeq, EmitComposeComplex(op, real_inf_or_nan, imag_inf_or_nan),
- EmitComposeComplex(
- op,
- b_->CreateFDiv(
- b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value)),
- b_->CreateFMul(EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value))),
- rhs_sum_sq),
- b_->CreateFDiv(
- b_->CreateFSub(b_->CreateFMul(EmitExtractImag(lhs_value),
- EmitExtractReal(rhs_value)),
- b_->CreateFMul(EmitExtractReal(lhs_value),
- EmitExtractImag(rhs_value))),
- rhs_sum_sq)));
+ EmitComposeComplex(op,
+ FDiv(FAdd(FMul(EmitExtractReal(lhs_value),
+ EmitExtractReal(rhs_value)),
+ FMul(EmitExtractImag(lhs_value),
+ EmitExtractImag(rhs_value))),
+ rhs_sum_sq),
+ FDiv(FSub(FMul(EmitExtractImag(lhs_value),
+ EmitExtractReal(rhs_value)),
+ FMul(EmitExtractReal(lhs_value),
+ EmitExtractImag(rhs_value))),
+ rhs_sum_sq)));
}
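In scalar form, the guarded division above looks like the following sketch (ComplexDivide is illustrative; like the emitted IR, it relies on IEEE x/0 producing inf or NaN componentwise):

    #include <complex>

    // Scalar reference for the kDivide lowering.
    std::complex<float> ComplexDivide(std::complex<float> l,
                                      std::complex<float> r) {
      float denom = r.real() * r.real() + r.imag() * r.imag();
      if (denom == 0.0f) {
        return {l.real() / denom, l.imag() / denom};  // inf/NaN components
      }
      return {(l.real() * r.real() + l.imag() * r.imag()) / denom,
              (l.imag() * r.real() - l.real() * r.imag()) / denom};
    }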
// LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
// comparisons always return false when one of the operands is NaN, whereas
@@ -832,21 +786,19 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
// unordered comparison. This makes x != y equivalent to !(x == y), and
// matches C++'s semantics.
case HloOpcode::kEq:
- return b_->CreateAnd(
- llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
- EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value), b_),
- llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
- EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value), b_));
+ return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
+ EmitExtractReal(lhs_value),
+ EmitExtractReal(rhs_value), b_),
+ llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
+ EmitExtractImag(lhs_value),
+ EmitExtractImag(rhs_value), b_));
case HloOpcode::kNe:
- return b_->CreateOr(
- llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
- EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value), b_),
- llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
- EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value), b_));
+ return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
+ EmitExtractReal(lhs_value),
+ EmitExtractReal(rhs_value), b_),
+ llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
+ EmitExtractImag(lhs_value),
+ EmitExtractImag(rhs_value), b_));
case HloOpcode::kPower: {
// (a+bi)^(c+di) =
@@ -858,45 +810,43 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
auto b = EmitExtractImag(lhs_value);
auto c = EmitExtractReal(rhs_value);
auto d = EmitExtractImag(rhs_value);
- auto aa_p_bb = b_->CreateFAdd(b_->CreateFMul(a, a), b_->CreateFMul(b, b));
+ auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b));
auto one_half = llvm::ConstantFP::get(a->getType(), 0.5);
- auto half_c = b_->CreateFMul(one_half, c);
+ auto half_c = FMul(one_half, c);
TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c,
EmitPow(component_type, aa_p_bb, half_c));
- auto neg_d = b_->CreateFNeg(d);
+ auto neg_d = FNeg(d);
TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a));
- auto neg_d_arg_lhs = b_->CreateFMul(neg_d, arg_lhs);
+ auto neg_d_arg_lhs = FMul(neg_d, arg_lhs);
TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs,
EmitExp(component_type, neg_d_arg_lhs));
- auto coeff = b_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
+ auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb));
- auto half_d = b_->CreateFMul(one_half, d);
- auto q = b_->CreateFAdd(b_->CreateFMul(c, arg_lhs),
- b_->CreateFMul(half_d, ln_aa_p_bb));
+ auto half_d = FMul(one_half, d);
+ auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb));
TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q));
TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q));
- return EmitComposeComplex(op, b_->CreateFMul(coeff, cos_q),
- b_->CreateFMul(coeff, sin_q));
+ return EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q));
}
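Written out fully with \theta = \operatorname{atan2}(b, a), the identity this case implements is, in LaTeX:

    (a+bi)^{c+di} = (a^2+b^2)^{c/2}\, e^{-d\theta} \left(\cos q + i \sin q\right),
    \qquad q = c\,\theta + \tfrac{d}{2} \ln(a^2+b^2)

coeff above is the real magnitude (a^2+b^2)^{c/2} e^{-d\theta}, and q is the phase.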
default:
return Unimplemented("binary complex op '%s'",
- HloOpcodeString(op->opcode()).c_str());
+ HloOpcodeString(op->opcode()));
}
}
llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
- llvm::Value* rhs_value) const {
+ llvm::Value* rhs_value) {
return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_);
}
llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
- llvm::Value* rhs_value) const {
+ llvm::Value* rhs_value) {
return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
- llvm::Value* x) const {
+ llvm::Value* x) {
if (prim_type != F32) {
// TODO(b/34339814): Implement inverse erf for F64.
return Unimplemented(
@@ -906,12 +856,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
auto getFloat = [&](const float f) {
return llvm::ConstantFP::get(b_->getFloatTy(), f);
};
- auto multiply_add = [&](tensorflow::gtl::ArraySlice<float> coefficients,
+ auto multiply_add = [&](absl::Span<const float> coefficients,
llvm::Value* w) {
llvm::Value* p = getFloat(coefficients.front());
- coefficients.pop_front();
+ coefficients.remove_prefix(1);
for (float coefficient : coefficients) {
- p = b_->CreateFAdd(b_->CreateFMul(p, w), getFloat(coefficient));
+ p = FAdd(FMul(p, w), getFloat(coefficient));
}
return p;
};
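multiply_add is Horner's rule over the coefficient span, front coefficient first. A standalone scalar sketch of the same evaluation (MultiplyAdd is a hypothetical helper):

    // Horner evaluation of p(w) = c[0]*w^(n-1) + c[1]*w^(n-2) + ... + c[n-1].
    float MultiplyAdd(const float* c, int n, float w) {
      float p = c[0];
      for (int i = 1; i < n; ++i) {
        p = p * w + c[i];  // mirrors FAdd(FMul(p, w), getFloat(c[i]))
      }
      return p;
    }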
@@ -931,25 +881,24 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration(
module_, llvm::Intrinsic::log, {b_->getFloatTy()});
- llvm::Value* w = b_->CreateFNeg(b_->CreateCall(
- logf_fn, {b_->CreateFMul(b_->CreateFSub(getFloat(1.0f), x),
- b_->CreateFAdd(getFloat(1.0f), x))}));
+ llvm::Value* w = FNeg(
+ Call(logf_fn, {FMul(FSub(getFloat(1.0f), x), FAdd(getFloat(1.0f), x))}));
llvm::Value* p_addr =
llvm_ir::EmitAllocaAtFunctionEntry(b_->getFloatTy(), "p.addr", b_);
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- b_->CreateFCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_);
+ FCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_);
// Handle true BB.
SetToFirstInsertPoint(if_data.true_block, b_);
{
- llvm::Value* lw = b_->CreateFSub(w, getFloat(2.5f));
- tensorflow::gtl::ArraySlice<float> lq{
+ llvm::Value* lw = FSub(w, getFloat(2.5f));
+ absl::Span<const float> lq{
2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
-4.39150654e-06f, 0.00021858087f, -0.00125372503f,
-0.00417768164f, 0.246640727f, 1.50140941f};
llvm::Value* p = multiply_add(lq, lw);
- b_->CreateStore(p, p_addr);
+ Store(p, p_addr);
}
// Handle false BB.
@@ -958,76 +907,73 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});
- llvm::Value* gw =
- b_->CreateFSub(b_->CreateCall(sqrtf_fn, {w}), getFloat(3.0f));
- tensorflow::gtl::ArraySlice<float> gq{
+ llvm::Value* gw = FSub(Call(sqrtf_fn, w), getFloat(3.0f));
+ absl::Span<const float> gq{
-0.000200214257f, 0.000100950558f, 0.00134934322f,
-0.00367342844f, 0.00573950773f, -0.0076224613f,
0.00943887047f, 1.00167406f, 2.83297682f};
llvm::Value* p = multiply_add(gq, gw);
- b_->CreateStore(p, p_addr);
+ Store(p, p_addr);
}
SetToFirstInsertPoint(if_data.after_block, b_);
- llvm::Value* p = b_->CreateLoad(p_addr);
- return b_->CreateFMul(p, x);
+ llvm::Value* p = Load(p_addr);
+ return FMul(p, x);
}
-StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(PrimitiveType prim_type,
+ llvm::Value* value) {
// Compute erfcinv(value) by calculating erfinv(1.0 - value).
auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
auto one = llvm::ConstantFP::get(type, 1.0);
- return EmitErfInv(prim_type, b_->CreateFSub(one, value));
+ return EmitErfInv(prim_type, FSub(one, value));
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
- llvm::Value* value) const {
+ llvm::Value* value) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value},
{value->getType()}, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
- llvm::Value* value) const {
+ llvm::Value* value) {
auto x = value;
auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
auto one = llvm::ConstantFP::get(type, 1.0);
auto negative_half = llvm::ConstantFP::get(type, -0.5);
// When x is large, the naive evaluation of ln(x + 1) is more
// accurate than the Taylor series.
- TF_ASSIGN_OR_RETURN(auto for_large_x,
- EmitLog(prim_type, b_->CreateFAdd(x, one)));
+ TF_ASSIGN_OR_RETURN(auto for_large_x, EmitLog(prim_type, FAdd(x, one)));
   // The Taylor series for ln(x+1) is x - x^2/2 + x^3/3 - ….
- auto for_small_x =
- b_->CreateFMul(b_->CreateFAdd(b_->CreateFMul(negative_half, x), one), x);
+ auto for_small_x = FMul(FAdd(FMul(negative_half, x), one), x);
const auto kAntilogarithmIsSmallThreshold = 1e-4;
auto abs_x =
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
- auto x_is_small = b_->CreateFCmpOLT(
+ auto x_is_small = FCmpOLT(
abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold));
- return b_->CreateSelect(x_is_small, for_small_x, for_large_x);
+ return Select(x_is_small, for_small_x, for_large_x);
}
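A scalar sketch of the branch the final Select encodes, under the same 1e-4 threshold (Log1pSketch is illustrative, not emitter code):

    #include <cmath>

    // log1p via two regimes: near zero, the two-term Taylor truncation
    // x - x^2/2, written as (-0.5*x + 1)*x; otherwise log(x + 1).
    float Log1pSketch(float x) {
      const float kThreshold = 1e-4f;
      float for_small_x = (-0.5f * x + 1.0f) * x;
      float for_large_x = std::log(x + 1.0f);
      return std::fabs(x) < kThreshold ? for_small_x : for_large_x;
    }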
StatusOr<llvm::Value*> ElementalIrEmitter::EmitSin(PrimitiveType prim_type,
- llvm::Value* value) const {
+ llvm::Value* value) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value},
{value->getType()}, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
- llvm::Value* value) const {
+ llvm::Value* value) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value},
{value->getType()}, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
- llvm::Value* value) const {
+ llvm::Value* value) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
{value->getType()}, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
- llvm::Value* value) const {
+ llvm::Value* value) {
auto x = value;
auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
auto one = llvm::ConstantFP::get(type, 1.0);
@@ -1035,40 +981,40 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
// When the exponent is large, the naive evaluation of e^(x) - 1 is more
// accurate than the Taylor series.
TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value));
- auto for_large_x = b_->CreateFSub(exp_x, one);
+ auto for_large_x = FSub(exp_x, one);
// The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + ….
// We want exp(x)-1 which is x + x^2/2 + x^3/6 + ….
- auto x_squared = b_->CreateFAdd(x, x);
- auto x_squared_over_two = b_->CreateFMul(x_squared, half);
- auto for_small_x = b_->CreateFAdd(x, x_squared_over_two);
+    auto x_squared = FMul(x, x);
+ auto x_squared_over_two = FMul(x_squared, half);
+ auto for_small_x = FAdd(x, x_squared_over_two);
const auto kExponentIsSmallThreshold = 1e-5;
auto abs_x =
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
- auto x_is_small = b_->CreateFCmpOLT(
- abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold));
- return b_->CreateSelect(x_is_small, for_small_x, for_large_x);
+ auto x_is_small =
+ FCmpOLT(abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold));
+ return Select(x_is_small, for_small_x, for_large_x);
}
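The same two-regime pattern in scalar form, with the Taylor term spelled out as x + x^2/2 (Expm1Sketch is illustrative):

    #include <cmath>

    // expm1 via two regimes: near zero, x + x^2/2 from the Taylor
    // series of e^x - 1; otherwise exp(x) - 1 directly.
    float Expm1Sketch(float x) {
      const float kThreshold = 1e-5f;
      float for_small_x = x + (x * x) * 0.5f;
      float for_large_x = std::exp(x) - 1.0f;
      return std::fabs(x) < kThreshold ? for_small_x : for_large_x;
    }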
StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
llvm::Value* lhs,
- llvm::Value* rhs) const {
+ llvm::Value* rhs) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
{lhs->getType()}, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
llvm::Value* lhs,
- llvm::Value* rhs) const {
+ llvm::Value* rhs) {
return Unimplemented("atan2");
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
- llvm::Value* value) const {
+ llvm::Value* value) {
return Unimplemented("tanh");
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
- const HloInstruction* hlo, llvm::Value* x) const {
+ const HloInstruction* hlo, llvm::Value* x) {
if (hlo->operand(0)->shape().element_type() != F32) {
return Unimplemented("reduce-precision only implemented for F32");
}
@@ -1099,23 +1045,103 @@ static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b,
return b->CreateSelect(shift_amt_in_range, shift_result, saturated_value);
}
+llvm::Value* ElementalIrEmitter::GetOne(llvm::Type* type) {
+ return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 1);
+}
+
+llvm::Value* ElementalIrEmitter::GetZero(llvm::Type* type) {
+ return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 0);
+}
+
+llvm::Value* ElementalIrEmitter::GetIntSMin(llvm::Type* type) {
+ auto* integer_type = llvm::cast<llvm::IntegerType>(type);
+ return llvm::ConstantInt::get(integer_type, llvm::APInt::getSignedMinValue(
+ integer_type->getBitWidth()));
+}
+
+llvm::Value* ElementalIrEmitter::GetMinusOne(llvm::Type* type) {
+ auto* integer_type = llvm::cast<llvm::IntegerType>(type);
+ return llvm::ConstantInt::get(
+ integer_type, llvm::APInt::getAllOnesValue(integer_type->getBitWidth()));
+}
+
+llvm::Value* ElementalIrEmitter::IsZero(llvm::Value* v) {
+ return ICmpEQ(v, llvm::ConstantInt::get(v->getType(), 0));
+}
+
+llvm::Value* ElementalIrEmitter::IsIntMinDivisionOverflow(llvm::Value* lhs,
+ llvm::Value* rhs) {
+ return And(ICmpEQ(lhs, GetIntSMin(lhs->getType())),
+ ICmpEQ(rhs, GetMinusOne(rhs->getType())));
+}
+
+llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs,
+ llvm::Value* rhs,
+ bool is_signed) {
+ // Integer division overflow behavior:
+ //
+ // X / 0 == -1
+  //   INT_SMIN /s -1 == INT_SMIN
+
+ if (!is_signed) {
+ llvm::Value* udiv_is_unsafe = IsZero(rhs);
+ llvm::Value* safe_rhs = Select(udiv_is_unsafe, GetOne(lhs->getType()), rhs);
+ llvm::Value* safe_div = UDiv(lhs, safe_rhs);
+ return Select(udiv_is_unsafe, GetMinusOne(lhs->getType()), safe_div);
+ }
+
+ llvm::Value* has_zero_divisor = IsZero(rhs);
+ llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
+ llvm::Value* sdiv_is_unsafe = Or(has_int_min_overflow, has_zero_divisor);
+ llvm::Value* safe_rhs = Select(sdiv_is_unsafe, GetOne(lhs->getType()), rhs);
+ llvm::Value* safe_div = SDiv(lhs, safe_rhs);
+
+ return Select(
+ has_zero_divisor, GetMinusOne(lhs->getType()),
+ Select(has_int_min_overflow, GetIntSMin(lhs->getType()), safe_div));
+}
+
+llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs,
+ llvm::Value* rhs,
+ bool is_signed) {
+ // Integer remainder overflow behavior:
+ //
+ // X % 0 == X
+  //   INT_SMIN %s -1 == 0
+
+ if (!is_signed) {
+ llvm::Value* urem_is_unsafe = IsZero(rhs);
+ llvm::Value* safe_rhs = Select(urem_is_unsafe, GetOne(lhs->getType()), rhs);
+ llvm::Value* safe_rem = URem(lhs, safe_rhs);
+ return Select(urem_is_unsafe, lhs, safe_rem);
+ }
+
+ llvm::Value* has_zero_divisor = IsZero(rhs);
+ llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
+ llvm::Value* srem_is_unsafe = Or(has_int_min_overflow, has_zero_divisor);
+ llvm::Value* safe_rhs = Select(srem_is_unsafe, GetOne(lhs->getType()), rhs);
+ llvm::Value* safe_rem = SRem(lhs, safe_rhs);
+
+ return Select(
+ has_zero_divisor, lhs,
+ Select(has_int_min_overflow, GetZero(lhs->getType()), safe_rem));
+}
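Both helpers pick a safe divisor first, so the hardware divide never sees an undefined input, and then patch the result afterwards. A scalar sketch of the signed divide (SafeSDiv is illustrative):

    #include <climits>

    // Signed division with the defined overflow behavior above:
    //   x / 0        -> -1
    //   INT_MIN / -1 -> INT_MIN   (instead of UB)
    int SafeSDiv(int lhs, int rhs) {
      bool zero_divisor = (rhs == 0);
      bool int_min_overflow = (lhs == INT_MIN && rhs == -1);
      int safe_rhs = (zero_divisor || int_min_overflow) ? 1 : rhs;
      int safe_div = lhs / safe_rhs;  // never traps
      if (zero_divisor) return -1;
      if (int_min_overflow) return INT_MIN;
      return safe_div;
    }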
+
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value,
- bool is_signed) const {
+ bool is_signed) {
switch (op->opcode()) {
// TODO(jingyue): add the "nsw" attribute for signed types.
case HloOpcode::kAdd:
- return b_->CreateAdd(lhs_value, rhs_value);
+ return Add(lhs_value, rhs_value);
case HloOpcode::kSubtract:
- return b_->CreateSub(lhs_value, rhs_value);
+ return Sub(lhs_value, rhs_value);
case HloOpcode::kMultiply:
- return b_->CreateMul(lhs_value, rhs_value);
+ return Mul(lhs_value, rhs_value);
case HloOpcode::kDivide:
- return is_signed ? b_->CreateSDiv(lhs_value, rhs_value)
- : b_->CreateUDiv(lhs_value, rhs_value);
+ return EmitIntegerDivide(lhs_value, rhs_value, is_signed);
case HloOpcode::kRemainder:
- return is_signed ? b_->CreateSRem(lhs_value, rhs_value)
- : b_->CreateURem(lhs_value, rhs_value);
+ return EmitIntegerRemainder(lhs_value, rhs_value, is_signed);
case HloOpcode::kEq:
return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
rhs_value, b_);
@@ -1143,11 +1169,11 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
case HloOpcode::kMaximum:
return EmitIntegralMax(lhs_value, rhs_value, is_signed);
case HloOpcode::kAnd:
- return b_->CreateAnd(lhs_value, rhs_value);
+ return And(lhs_value, rhs_value);
case HloOpcode::kOr:
- return b_->CreateOr(lhs_value, rhs_value);
+ return Or(lhs_value, rhs_value);
case HloOpcode::kXor:
- return b_->CreateXor(lhs_value, rhs_value);
+ return Xor(lhs_value, rhs_value);
// Shifting out bits >= the number of bits in the type being shifted
// produces a poison value in LLVM which is basically "deferred undefined
@@ -1156,43 +1182,43 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
// UB.
case HloOpcode::kShiftRightArithmetic:
return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
- b_->CreateAShr(lhs_value, rhs_value),
+ AShr(lhs_value, rhs_value),
/*saturate_to_sign_bit=*/true);
case HloOpcode::kShiftLeft:
return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
- b_->CreateShl(lhs_value, rhs_value),
+ Shl(lhs_value, rhs_value),
/*saturate_to_sign_bit=*/false);
case HloOpcode::kShiftRightLogical:
return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
- b_->CreateLShr(lhs_value, rhs_value),
+ LShr(lhs_value, rhs_value),
/*saturate_to_sign_bit=*/false);
default:
return Unimplemented("binary integer op '%s'",
- HloOpcodeString(op->opcode()).c_str());
+ HloOpcodeString(op->opcode()));
}
}
llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
llvm::Value* rhs_value,
- bool is_signed) const {
- return b_->CreateSelect(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE
- : llvm::ICmpInst::ICMP_UGE,
- lhs_value, rhs_value),
- lhs_value, rhs_value);
+ bool is_signed) {
+ return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE
+ : llvm::ICmpInst::ICMP_UGE,
+ lhs_value, rhs_value),
+ lhs_value, rhs_value);
}
llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
llvm::Value* rhs_value,
- bool is_signed) const {
- return b_->CreateSelect(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE
- : llvm::ICmpInst::ICMP_ULE,
- lhs_value, rhs_value),
- lhs_value, rhs_value);
+ bool is_signed) {
+ return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE
+ : llvm::ICmpInst::ICMP_ULE,
+ lhs_value, rhs_value),
+ lhs_value, rhs_value);
}
llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo,
- int64 operand_no) const {
+ int64 operand_no) {
CHECK(hlo.IsElementwise())
<< "HLO " << hlo.ToString() << " is not elementwise.";
@@ -1233,7 +1259,7 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) const {
+ const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) {
TF_ASSIGN_OR_RETURN(llvm::Value * a_or_mean,
operand_to_generator.at(hlo->operand(0))(index));
TF_ASSIGN_OR_RETURN(llvm::Value * b_or_sigma,
@@ -1251,17 +1277,17 @@ StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution(
// Perform the division using the float type with the same number of bits
// as the raw value to avoid overflow.
if (raw_value_size_in_bits == 32) {
- elem_value = b_->CreateUIToFP(elem_value, b_->getFloatTy());
- elem_value = b_->CreateFDiv(
- elem_value, llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32)));
+ elem_value = UIToFP(elem_value, b_->getFloatTy());
+ elem_value = FDiv(elem_value,
+ llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32)));
} else {
- elem_value = b_->CreateUIToFP(elem_value, b_->getDoubleTy());
- elem_value = b_->CreateFDiv(
+ elem_value = UIToFP(elem_value, b_->getDoubleTy());
+ elem_value = FDiv(
elem_value, llvm::ConstantFP::get(b_->getDoubleTy(), std::exp2(64)));
}
if (elem_ir_ty != elem_value->getType()) {
- elem_value = b_->CreateFPTrunc(elem_value, elem_ir_ty);
+ elem_value = FPTrunc(elem_value, elem_ir_ty);
}
}
@@ -1269,9 +1295,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution(
switch (hlo->random_distribution()) {
case RNG_UNIFORM: {
if (elem_ir_ty->isFloatingPointTy()) {
- return b_->CreateFAdd(
- b_->CreateFMul(b_->CreateFSub(b_or_sigma, a_or_mean), elem_value),
- a_or_mean);
+ return FAdd(FMul(FSub(b_or_sigma, a_or_mean), elem_value), a_or_mean);
} else {
// To generate a uniform random value in [a, b) from a raw random sample
// in range [0, 2^N), we let range = b - a and return
@@ -1284,22 +1308,21 @@ StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution(
// the same cost as if the whole warp were to re-sample. So an
// efficient re-sampling implementation on GPU would need to do
// nontrivial work to share entropy between threads in the warp.
- auto range = b_->CreateSub(b_or_sigma, a_or_mean);
- return b_->CreateAdd(a_or_mean, b_->CreateURem(elem_value, range));
+ auto range = Sub(b_or_sigma, a_or_mean);
+ return Add(a_or_mean, URem(elem_value, range));
}
}
case RNG_NORMAL: {
TF_ASSIGN_OR_RETURN(
llvm::Value * r,
- EmitErfcInv(elem_prim_ty,
- b_->CreateFMul(llvm::ConstantFP::get(elem_ir_ty, 2.0),
- elem_value)));
- return b_->CreateFAdd(b_->CreateFMul(r, b_or_sigma), a_or_mean);
+ EmitErfcInv(elem_prim_ty, FMul(llvm::ConstantFP::get(elem_ir_ty, 2.0),
+ elem_value)));
+ return FAdd(FMul(r, b_or_sigma), a_or_mean);
}
default:
return InvalidArgument(
"unhandled distribution %s",
- RandomDistribution_Name(hlo->random_distribution()).c_str());
+ RandomDistribution_Name(hlo->random_distribution()));
}
}
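The float path above maps N raw bits into [0, 1) by dividing by 2^N, then rescales into [a, b). A scalar sketch of the 32-bit case (UniformFromBits is illustrative; the real code also handles the 64-bit/double case):

    #include <cmath>
    #include <cstdint>

    // Map a raw 32-bit sample to a uniform float in [a, b).
    float UniformFromBits(uint32_t raw, float a, float b) {
      float unit = static_cast<float>(raw) / std::exp2(32.0f);  // [0, 1)
      return (b - a) * unit + a;  // FAdd(FMul(FSub(b, a), unit), a)
    }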
@@ -1414,8 +1437,7 @@ std::array<llvm::Value*, 4> CalculateSampleValues(
// Precondition: the RNG instruction is not fused.
llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator(
const HloInstruction* hlo,
- const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator)
- const {
+ const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) {
VLOG(3) << "Using philox RNG algorithm";
CHECK(!hlo->IsFused());
// A random number generated by the per module random number generator.
@@ -1438,7 +1460,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator(
// Load the global state variable for the Philox RNG algorithm.
llvm::GlobalVariable* rng_state_ptr =
llvm_ir::GetOrCreateVariableForPhiloxRngState(module_, b_);
- llvm::Value* rng_state = b_->CreateLoad(rng_state_ptr, "rng_state_value");
+ llvm::Value* rng_state = Load(rng_state_ptr, "rng_state_value");
// Build and return the elemental IR generator to generate a random value for
// the element corresponding to the current thread.
@@ -1464,8 +1486,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator(
// element within the sample.
llvm::Value* elems_per_sample_value =
llvm::ConstantInt::get(index_ty, elems_per_sample);
- llvm::Value* sample_idx = b_->CreateUDiv(elem_idx, elems_per_sample_value);
- llvm::Value* elem_offset = b_->CreateURem(elem_idx, elems_per_sample_value);
+ llvm::Value* sample_idx = UDiv(elem_idx, elems_per_sample_value);
+ llvm::Value* elem_offset = URem(elem_idx, elems_per_sample_value);
std::array<llvm::Value*, 4> counter_values = CalculateSampleValues(
sample_idx, hlo_random_value, global_random_number, rng_state, b_);
@@ -1473,18 +1495,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator(
// Store the four counter_values into the sample_address alloca so we can
// load the elem_offset'th one below.
for (int idx = 0; idx < 4; ++idx) {
- b_->CreateStore(counter_values[idx],
- b_->CreateInBoundsGEP(sample_address, b_->getInt32(idx)));
+ Store(counter_values[idx],
+ InBoundsGEP(sample_address, b_->getInt32(idx)));
}
llvm::Type* int64_ty = b_->getInt64Ty();
CHECK(elems_per_sample == 2 || elems_per_sample == 4);
llvm::Type* raw_value_ty = elems_per_sample == 2 ? int64_ty : int32_ty;
// Retrieve the raw value for the current element from the current sample.
- llvm::Value* raw_elem_value = b_->CreateLoad(
- b_->CreateInBoundsGEP(
- b_->CreatePointerCast(sample_address, raw_value_ty->getPointerTo()),
- elem_offset),
+ llvm::Value* raw_elem_value = Load(
+ InBoundsGEP(PointerCast(sample_address, raw_value_ty->getPointerTo()),
+ elem_offset),
"raw_elem_value");
return ConvertValueForDistribution(hlo, operand_to_generator, index,
@@ -1495,7 +1516,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator(
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalSelect(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const {
+ const llvm_ir::IrArray::Index& index) {
TF_ASSIGN_OR_RETURN(llvm::Value * pred_value,
operand_to_generator.at(hlo->operand(0))(
ElementwiseSourceIndex(index, *hlo, 0)));
@@ -1505,14 +1526,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalSelect(
TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value,
operand_to_generator.at(hlo->operand(2))(
ElementwiseSourceIndex(index, *hlo, 2)));
- return b_->CreateSelect(b_->CreateTrunc(pred_value, b_->getInt1Ty()),
- on_true_value, on_false_value);
+ return Select(Trunc(pred_value, b_->getInt1Ty()), on_true_value,
+ on_false_value);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalClamp(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const {
+ const llvm_ir::IrArray::Index& index) {
TF_ASSIGN_OR_RETURN(llvm::Value * min_value,
operand_to_generator.at(hlo->operand(0))(
ElementwiseSourceIndex(index, *hlo, 0)));
@@ -1531,14 +1552,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalClamp(
max_value, EmitIntegralMax(min_value, arg_value, is_signed), is_signed);
} else {
return Unimplemented("Clamp unimplemented for %s",
- PrimitiveType_Name(prim_type).c_str());
+ PrimitiveType_Name(prim_type));
}
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& target_index) const {
+ const llvm_ir::IrArray::Index& target_index) {
const int64 concat_dim = hlo->dimensions(0);
auto source_index = target_index;
@@ -1560,9 +1581,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
}
llvm_ir::SetToFirstInsertPoint(exit_block, b_);
- llvm::PHINode* output = b_->CreatePHI(
- llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
- hlo->operands().size());
+ llvm::PHINode* output =
+ PHI(llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
+ hlo->operands().size());
auto prior_insert_point = b_->GetInsertPoint();
b_->SetInsertPoint(init_block);
@@ -1577,9 +1598,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
auto concat_dim_size =
llvm::ConstantInt::get(source_index[concat_dim]->getType(),
operand->shape().dimensions(concat_dim));
- b_->CreateCondBr(
- b_->CreateICmpULT(source_index[concat_dim], concat_dim_size),
- true_block, false_block);
+ CondBr(ICmpULT(source_index[concat_dim], concat_dim_size), true_block,
+ false_block);
// Create the terminator of the true block before calling operand
// generators, because they require non-degenerate basic blocks.
@@ -1592,11 +1612,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
// Subtract the size of the concat dimension of the current operand
// from the source index.
b_->SetInsertPoint(false_block);
- source_index[concat_dim] =
- b_->CreateSub(source_index[concat_dim], concat_dim_size);
+ source_index[concat_dim] = Sub(source_index[concat_dim], concat_dim_size);
}
- b_->CreateUnreachable();
+ Unreachable();
b_->SetInsertPoint(exit_block, prior_insert_point);
return output;
}
@@ -1604,7 +1623,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const {
+ const llvm_ir::IrArray::Index& index) {
// Emit IR to read dynamic start indices from hlo->operand(1).
const HloInstruction* input_hlo = hlo->operand(0);
const int64 rank = ShapeUtil::Rank(input_hlo->shape());
@@ -1621,7 +1640,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
// Clamp the start index so that the sliced portion fits in the operand:
// start_index = clamp(start_index, 0, operand_dim_size - output_dim_size)
- start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type);
+ start_index_value = SExtOrTrunc(start_index_value, index_type);
int64 largest_valid_start_index =
input_hlo->shape().dimensions(i) - hlo->shape().dimensions(i);
CHECK_GE(largest_valid_start_index, 0);
@@ -1641,7 +1660,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
for (int64 i = 0; i < rank; ++i) {
// Emit IR which computes:
// input_index = start_index + offset_index
- input_index[i] = b_->CreateAdd(slice_start_index[i], index[i]);
+ input_index[i] = Add(slice_start_index[i], index[i]);
}
return operand_to_generator.at(input_hlo)(input_index);
}
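The clamp keeps the whole slice in bounds rather than bounds-checking each element read. Per dimension, in scalar form (ClampSliceStart is illustrative):

    #include <algorithm>
    #include <cstdint>

    // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size)
    int64_t ClampSliceStart(int64_t start, int64_t operand_dim_size,
                            int64_t output_dim_size) {
      int64_t largest_valid_start = operand_dim_size - output_dim_size;
      return std::min(std::max(start, int64_t{0}), largest_valid_start);
    }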
@@ -1649,7 +1668,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const {
+ const llvm_ir::IrArray::Index& index) {
const Shape& operand_shape = hlo->operand(0)->shape();
const Shape& indices_shape = hlo->operand(1)->shape();
const Shape& output_shape = hlo->shape();
@@ -1672,22 +1691,21 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
std::vector<int64> operand_to_output_dim(operand_shape.dimensions_size(), -1);
for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0;
i < e; i++) {
- if (c_binary_search(dim_numbers.elided_window_dims(), i)) {
+ if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
operand_index.push_back(index.GetConstantWithIndexType(0));
} else {
- int64 output_window_dim =
- dim_numbers.output_window_dims(operand_index_dim++);
+ int64 output_window_dim = dim_numbers.offset_dims(operand_index_dim++);
operand_to_output_dim[i] = output_window_dim;
operand_index.push_back(index[output_window_dim]);
}
}
- // This is the index of the index vector in the gather_indices tensor.
+ // This is the index of the index vector in the start_indices tensor.
IrArray::Index gather_index_index(index_type);
{
std::vector<llvm::Value*> gather_index_index_components;
for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) {
- if (!c_binary_search(dim_numbers.output_window_dims(), i)) {
+ if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) {
gather_index_index.push_back(index[i]);
}
}
@@ -1699,8 +1717,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) {
llvm::Value* gather_dim_component_extended =
- b_->CreateSExtOrTrunc(index_component, index_type);
- int64 operand_dim = dim_numbers.gather_dims_to_operand_dims(dim);
+ SExtOrTrunc(index_component, index_type);
+ int64 operand_dim = dim_numbers.start_index_map(dim);
int64 output_dim = operand_to_output_dim[operand_dim];
// If 'output_dim' is -1, it means 'operand_dim' is an elided window dim.
// This means we set the iteration index to 0, so for the purpose of the
@@ -1723,8 +1741,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
gather_dim_component_extended, is_signed),
is_signed);
- operand_index[operand_dim] = b_->CreateAdd(
- operand_index[operand_dim], gather_dim_component_extended_inbound);
+ operand_index[operand_dim] =
+ Add(operand_index[operand_dim], gather_dim_component_extended_inbound);
};
if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) {
@@ -1748,7 +1766,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const {
+ const llvm_ir::IrArray::Index& index) {
const HloInstruction* input_hlo = hlo->operand(0);
const HloInstruction* update_hlo = hlo->operand(1);
const HloInstruction* start_hlo = hlo->operand(2);
@@ -1771,7 +1789,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
// Clamp the start index so that the update region fits in the operand.
// start_index = clamp(start_index, 0, input_dim_size - update_dim_size)
- start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type);
+ start_index_value = SExtOrTrunc(start_index_value, index_type);
llvm::Value* update_dim_size =
index_typed_const(update_hlo->shape().dimensions(i));
int64 largest_valid_start_index =
@@ -1787,14 +1805,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
start_index_value->setName(
AsStringRef(IrName(hlo, StrCat("start_idx", i))));
slice_start_index[i] = start_index_value;
- slice_limit_index[i] = b_->CreateAdd(slice_start_index[i], update_dim_size);
-
- slice_intersection = b_->CreateAnd(
- slice_intersection, b_->CreateICmpSGE(index[i], slice_start_index[i]),
- "slice_intersection");
- slice_intersection = b_->CreateAnd(
- slice_intersection, b_->CreateICmpSLT(index[i], slice_limit_index[i]),
- "slice_intersection");
+ slice_limit_index[i] = Add(slice_start_index[i], update_dim_size);
+
+ slice_intersection =
+ And(slice_intersection, ICmpSGE(index[i], slice_start_index[i]),
+ "slice_intersection");
+ slice_intersection =
+ And(slice_intersection, ICmpSLT(index[i], slice_limit_index[i]),
+ "slice_intersection");
}
// Emit:
@@ -1811,26 +1829,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
// Compute update index for intersection case.
llvm_ir::IrArray::Index update_index(index.GetType(), rank);
for (int64 i = 0; i < rank; ++i) {
- update_index[i] = b_->CreateSub(index[i], slice_start_index[i]);
+ update_index[i] = Sub(index[i], slice_start_index[i]);
}
TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
operand_to_generator.at(update_hlo)(update_index));
- b_->CreateStore(true_value, ret_value_addr);
+ Store(true_value, ret_value_addr);
// Handle false BB (return data from 'input')
SetToFirstInsertPoint(if_data.false_block, b_);
TF_ASSIGN_OR_RETURN(llvm::Value * false_value,
operand_to_generator.at(input_hlo)(index));
- b_->CreateStore(false_value, ret_value_addr);
+ Store(false_value, ret_value_addr);
SetToFirstInsertPoint(if_data.after_block, b_);
- return b_->CreateLoad(ret_value_addr);
+ return Load(ret_value_addr);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& padded_index) const {
+ const llvm_ir::IrArray::Index& padded_index) {
auto index = padded_index;
llvm::Value* in_bounds = b_->getTrue();
for (size_t i = 0; i < index.size(); ++i) {
@@ -1838,26 +1856,22 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
return llvm::ConstantInt::get(index[i]->getType(), n);
};
const auto& pad_dim = hlo->padding_config().dimensions(i);
- index[i] =
- b_->CreateSub(index[i], index_typed_const(pad_dim.edge_padding_low()));
- in_bounds = b_->CreateAnd(in_bounds,
- b_->CreateICmpSGE(index[i], index_typed_const(0)),
- "in_bounds");
- in_bounds = b_->CreateAnd(
+ index[i] = Sub(index[i], index_typed_const(pad_dim.edge_padding_low()));
+ in_bounds =
+ And(in_bounds, ICmpSGE(index[i], index_typed_const(0)), "in_bounds");
+ in_bounds = And(
in_bounds,
- b_->CreateICmpEQ(
+ ICmpEQ(
index_typed_const(0),
- b_->CreateURem(index[i],
- index_typed_const(pad_dim.interior_padding() + 1))),
- "in_bounds");
- index[i] = b_->CreateSDiv(
- index[i], index_typed_const(pad_dim.interior_padding() + 1));
- in_bounds = b_->CreateAnd(
- in_bounds,
- b_->CreateICmpSLT(
- index[i],
- index_typed_const(hlo->operand(0)->shape().dimensions(i))),
+ URem(index[i], index_typed_const(pad_dim.interior_padding() + 1))),
"in_bounds");
+ index[i] =
+ SDiv(index[i], index_typed_const(pad_dim.interior_padding() + 1));
+ in_bounds =
+ And(in_bounds,
+ ICmpSLT(index[i],
+ index_typed_const(hlo->operand(0)->shape().dimensions(i))),
+ "in_bounds");
}
// if (in_bounds) {
@@ -1873,26 +1887,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
SetToFirstInsertPoint(if_data.true_block, b_);
TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
operand_to_generator.at(hlo->operand(0))(index));
- b_->CreateStore(operand_value, ret_value_addr);
+ Store(operand_value, ret_value_addr);
SetToFirstInsertPoint(if_data.false_block, b_);
TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
operand_to_generator.at(hlo->operand(1))(
IrArray::Index(index.GetType())));
- b_->CreateStore(padding_value, ret_value_addr);
+ Store(padding_value, ret_value_addr);
SetToFirstInsertPoint(if_data.after_block, b_);
// Don't create phi(operand_value, padding_value) here, because invoking
// operand_to_generator may create new basic blocks, making the parent
// of operand_value or padding_value no longer a predecessor of
// if_data.after_block.
- return b_->CreateLoad(ret_value_addr);
+ return Load(ret_value_addr);
}
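Per dimension, the generator inverts the padding map: shift off the low edge padding, require the index to land on an interior stride, then divide by the stride. A scalar sketch of one dimension (UnpadIndex is illustrative):

    #include <cstdint>

    // Maps a padded-output index back to an operand index; returns false
    // when the index falls in edge or interior padding.
    bool UnpadIndex(int64_t index, int64_t edge_low, int64_t interior,
                    int64_t operand_dim_size, int64_t* operand_index) {
      int64_t shifted = index - edge_low;
      if (shifted < 0) return false;                     // low edge padding
      if (shifted % (interior + 1) != 0) return false;   // interior padding
      *operand_index = shifted / (interior + 1);
      return *operand_index < operand_dim_size;          // high edge padding
    }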
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& dot_result_index) const {
+ const llvm_ir::IrArray::Index& dot_result_index) {
auto lhs_generator = operand_to_generator.at(hlo->operand(0));
auto rhs_generator = operand_to_generator.at(hlo->operand(1));
@@ -1920,8 +1934,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
llvm::Value* accumulator_alloca =
llvm_ir::EmitAllocaAtFunctionEntry(primitive_type_llvm, "dot_acc", b_);
- b_->CreateStore(llvm::Constant::getNullValue(primitive_type_llvm),
- accumulator_alloca);
+ Store(llvm::Constant::getNullValue(primitive_type_llvm), accumulator_alloca);
SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), b_);
@@ -1943,42 +1956,37 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
}
rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue());
- llvm::Value* current_accumulator = b_->CreateLoad(accumulator_alloca);
+ llvm::Value* current_accumulator = Load(accumulator_alloca);
TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
llvm::Value* next_accumulator;
if (primitive_util::IsComplexType(primitive_type)) {
- llvm::Value* product_real = b_->CreateFSub(
- b_->CreateFMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
- b_->CreateFMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
- llvm::Value* product_imag = b_->CreateFAdd(
- b_->CreateFMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
- b_->CreateFMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value)));
- next_accumulator = b_->CreateInsertValue(
+ llvm::Value* product_real =
+ FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
+ FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
+ llvm::Value* product_imag =
+ FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
+ FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value)));
+ next_accumulator = InsertValue(
current_accumulator,
- b_->CreateFAdd(EmitExtractReal(current_accumulator), product_real),
- {0});
- next_accumulator = b_->CreateInsertValue(
+ FAdd(EmitExtractReal(current_accumulator), product_real), {0});
+ next_accumulator = InsertValue(
next_accumulator,
- b_->CreateFAdd(EmitExtractImag(current_accumulator), product_imag),
- {1});
+ FAdd(EmitExtractImag(current_accumulator), product_imag), {1});
} else if (primitive_util::IsFloatingPointType(primitive_type)) {
- next_accumulator = b_->CreateFAdd(current_accumulator,
- b_->CreateFMul(lhs_value, rhs_value));
+ next_accumulator = FAdd(current_accumulator, FMul(lhs_value, rhs_value));
} else {
- next_accumulator =
- b_->CreateAdd(current_accumulator, b_->CreateMul(lhs_value, rhs_value));
+ next_accumulator = Add(current_accumulator, Mul(lhs_value, rhs_value));
}
- b_->CreateStore(next_accumulator, accumulator_alloca);
+ Store(next_accumulator, accumulator_alloca);
SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_);
- return b_->CreateLoad(accumulator_alloca);
+ return Load(accumulator_alloca);
}
llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
const HloInstruction* hlo,
- const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator)
- const {
+ const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) {
switch (hlo->opcode()) {
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
@@ -2072,10 +2080,10 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
const HloInstruction* operand = hlo->operand(0);
auto source_index = target_index;
for (int64 dim : hlo->dimensions()) {
- source_index[dim] = b_->CreateSub(
- llvm::ConstantInt::get(target_index[dim]->getType(),
- hlo->shape().dimensions(dim) - 1),
- target_index[dim]);
+ source_index[dim] =
+ Sub(llvm::ConstantInt::get(target_index[dim]->getType(),
+ hlo->shape().dimensions(dim) - 1),
+ target_index[dim]);
}
return operand_to_generator.at(operand)(source_index);
};
@@ -2089,6 +2097,61 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(),
hlo->dimensions(), b_));
};
+ case HloOpcode::kIota:
+ return [this, hlo](
+ const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
+ auto* iota = Cast<HloIotaInstruction>(hlo);
+ PrimitiveType element_type = iota->shape().element_type();
+ IrArray::Index elem_index =
+ ShapeUtil::Rank(iota->shape()) > 1
+ ? target_index.SourceIndexOfBroadcast(
+ iota->shape(),
+ ShapeUtil::MakeShapeWithDescendingLayout(
+ element_type,
+ {iota->shape().dimensions(iota->iota_dimension())}),
+ {iota->iota_dimension()}, b_)
+ : target_index;
+ llvm::Value* elem_index_linear = elem_index.linear();
+ if (elem_index_linear == nullptr) {
+ std::vector<int64> iota_bound = {
+ iota->shape().dimensions(iota->iota_dimension())};
+ elem_index_linear = elem_index.Linearize(iota_bound, b_);
+ }
+ Shape component_shape =
+ ShapeUtil::ElementIsComplex(iota->shape())
+ ? ShapeUtil::ComplexComponentShape(iota->shape())
+ : iota->shape();
+ PrimitiveType component_element_type = component_shape.element_type();
+ llvm::Value* iota_result;
+ if (ShapeUtil::ElementIsIntegral(component_shape)) {
+ iota_result = b_->CreateIntCast(
+ elem_index_linear,
+ llvm_ir::PrimitiveTypeToIrType(component_element_type, module_),
+ /*isSigned=*/false);
+ } else {
+ TF_RET_CHECK(ShapeUtil::ElementIsFloating(component_shape))
+ << component_element_type;
+ llvm::Type* float_ir_type;
+ if (component_element_type == BF16) {
+ float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_);
+ } else {
+ float_ir_type =
+ llvm_ir::PrimitiveTypeToIrType(component_element_type, module_);
+ }
+ llvm::Value* float_val =
+ b_->CreateUIToFP(elem_index_linear, float_ir_type);
+ if (component_element_type == BF16) {
+ iota_result = EmitF32ToBF16(float_val, b_);
+ } else {
+ iota_result = float_val;
+ }
+ }
+ if (ShapeUtil::ElementIsComplex(iota->shape())) {
+ return EmitComposeComplex(iota, iota_result, nullptr);
+ } else {
+ return iota_result;
+ }
+ };
case HloOpcode::kSlice:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
@@ -2154,28 +2217,28 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
default:
return [hlo](const IrArray::Index& index) {
return Unimplemented("Unhandled opcode for elemental IR emission: %s",
- HloOpcodeString(hlo->opcode()).c_str());
+ HloOpcodeString(hlo->opcode()));
};
}
}
-llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) const {
- return b_->CreateExtractValue(value, {0});
+llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) {
+ return ExtractValue(value, {0});
}
-llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) const {
- return b_->CreateExtractValue(value, {1});
+llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) {
+ return ExtractValue(value, {1});
}
llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
llvm::Value* real,
- llvm::Value* imag) const {
+ llvm::Value* imag) {
auto cplx_type =
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
- auto complex = b_->CreateInsertValue(
- llvm::ConstantAggregateZero::get(cplx_type), real, {0});
+ auto complex =
+ InsertValue(llvm::ConstantAggregateZero::get(cplx_type), real, {0});
if (imag != nullptr) {
- complex = b_->CreateInsertValue(complex, imag, {1});
+ complex = InsertValue(complex, imag, {1});
}
return complex;
}
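All of the FAdd/FMul/Select/Store shorthands introduced throughout this file come from the IrBuilderMixin base that the header change below adds. The pattern is CRTP forwarding onto builder(); a minimal sketch of the idea (IrBuilderMixinSketch and its members are illustrative, not the actual mixin):

    #include <utility>
    #include "llvm/IR/IRBuilder.h"

    // CRTP mixin: the derived class exposes builder(), and the mixin
    // turns b_->CreateFAdd(...) call sites into plain FAdd(...).
    template <typename Derived>
    class IrBuilderMixinSketch {
     public:
      template <typename... Args>
      llvm::Value* FAdd(Args&&... args) {
        return derived()->builder()->CreateFAdd(std::forward<Args>(args)...);
      }
      // ...one thin forwarding wrapper per IRBuilder Create* method...

     private:
      Derived* derived() { return static_cast<Derived*>(this); }
    };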
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
index 1598a4dd85..d3e2acaabd 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
@@ -23,12 +23,13 @@ limitations under the License.
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
-class ElementalIrEmitter {
+class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
public:
using HloToElementGeneratorMap =
std::unordered_map<const HloInstruction*, llvm_ir::ElementGenerator>;
@@ -40,100 +41,114 @@ class ElementalIrEmitter {
virtual ~ElementalIrEmitter() = default;
virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op,
- llvm::Value* operand_value) const;
+ llvm::Value* operand_value);
virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op,
llvm::Value* lhs_value,
- llvm::Value* rhs_value) const;
+ llvm::Value* rhs_value);
// Returns a function to generate an element of the output of `hlo`, given a
// map of functions to generate elements of its operands.
virtual llvm_ir::ElementGenerator MakeElementGenerator(
const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) const;
+ const HloToElementGeneratorMap& operand_to_generator);
- llvm::IRBuilder<>* b() const { return b_; }
- llvm::Module* module() const { return module_; }
+ llvm::IRBuilder<>* b() { return b_; }
+
+ // builder() is for IrBuilderMixin.
+ llvm::IRBuilder<>* builder() { return b_; }
+
+ llvm::Module* module() { return module_; }
protected:
- virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const;
+ virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(const HloInstruction* op,
+ llvm::Value* operand_value);
+
+ virtual StatusOr<llvm::Value*> EmitFloatUnaryOp(const HloInstruction* op,
+ llvm::Value* operand_value);
- virtual StatusOr<llvm::Value*> EmitFloatUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const;
+ virtual StatusOr<llvm::Value*> EmitComplexUnaryOp(const HloInstruction* op,
+ llvm::Value* operand_value);
- virtual StatusOr<llvm::Value*> EmitComplexUnaryOp(
- const HloInstruction* op, llvm::Value* operand_value) const;
+ llvm::Value* IsZero(llvm::Value* v);
+ llvm::Value* IsIntMinDivisionOverflow(llvm::Value* lhs, llvm::Value* rhs);
+ llvm::Value* GetZero(llvm::Type* type);
+ llvm::Value* GetOne(llvm::Type* type);
+ llvm::Value* GetIntSMin(llvm::Type* type);
+ llvm::Value* GetMinusOne(llvm::Type* type);
+
+ llvm::Value* EmitIntegerDivide(llvm::Value* lhs, llvm::Value* rhs,
+ bool is_signed);
+ llvm::Value* EmitIntegerRemainder(llvm::Value* lhs, llvm::Value* rhs,
+ bool is_signed);
virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op,
llvm::Value* lhs_value,
llvm::Value* rhs_value,
- bool is_signed) const;
+ bool is_signed);
- virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const;
+ virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
+ llvm::Value* lhs_value,
+ llvm::Value* rhs_value);
- virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const;
+ virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(const HloInstruction* op,
+ llvm::Value* lhs_value,
+ llvm::Value* rhs_value);
virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value,
- llvm::Value* rhs_value) const;
+ llvm::Value* rhs_value);
virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value,
- llvm::Value* rhs_value) const;
+ llvm::Value* rhs_value);
llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
- bool is_signed) const;
+ bool is_signed);
llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
- bool is_signed) const;
+ bool is_signed);
virtual StatusOr<llvm::Value*> EmitErfInv(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type,
- llvm::Value* lhs,
- llvm::Value* rhs) const;
+ llvm::Value* lhs, llvm::Value* rhs);
virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitLog1p(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type,
- llvm::Value* lhs,
- llvm::Value* rhs) const;
+ llvm::Value* lhs, llvm::Value* rhs);
virtual StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
- llvm::Value* value) const;
+ llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo,
- llvm::Value* x) const;
+ llvm::Value* x);
- virtual llvm::Value* EmitExtractReal(llvm::Value* value) const;
- virtual llvm::Value* EmitExtractImag(llvm::Value* value) const;
+ virtual llvm::Value* EmitExtractReal(llvm::Value* value);
+ virtual llvm::Value* EmitExtractImag(llvm::Value* value);
// Composes a complex struct. imag may be nullptr for simple cast operations.
llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real,
- llvm::Value* imag) const;
+ llvm::Value* imag);
// A helper method for MakeElementGenerator. Given an elementwise op `hlo` and
// the target array index, computes the source array index of its
@@ -142,50 +157,50 @@ class ElementalIrEmitter {
// Precondition: `hlo` is an elementwise op.
llvm_ir::IrArray::Index ElementwiseSourceIndex(
const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo,
- int64 operand_no) const;
+ int64 operand_no);
// Identifier of the thread, unique among all threads on the device.
- virtual llvm::Value* EmitThreadId() const { return b_->getIntN(128, 0); }
+ virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); }
StatusOr<llvm::Value*> EmitElementalSelect(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const;
+ const llvm_ir::IrArray::Index& index);
StatusOr<llvm::Value*> EmitElementalClamp(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const;
+ const llvm_ir::IrArray::Index& index);
StatusOr<llvm::Value*> EmitElementalConcatenate(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& target_index) const;
+ const llvm_ir::IrArray::Index& target_index);
StatusOr<llvm::Value*> EmitElementalDynamicSlice(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const;
+ const llvm_ir::IrArray::Index& index);
StatusOr<llvm::Value*> EmitElementalGather(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const;
+ const llvm_ir::IrArray::Index& index);
StatusOr<llvm::Value*> EmitElementalDynamicUpdateSlice(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index) const;
+ const llvm_ir::IrArray::Index& index);
StatusOr<llvm::Value*> EmitElementalPad(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& padded_index) const;
+ const llvm_ir::IrArray::Index& padded_index);
StatusOr<llvm::Value*> EmitElementalDot(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& dot_result_index) const;
+ const llvm_ir::IrArray::Index& dot_result_index);
llvm::IRBuilder<>* const b_;
@@ -200,13 +215,13 @@ class ElementalIrEmitter {
// random number generation algorithm.
llvm_ir::ElementGenerator MakePhiloxRngElementGenerator(
const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) const;
+ const HloToElementGeneratorMap& operand_to_generator);
// Converts the raw value generated by a random number generation algorithm
// to the distribution requested by the RNG HloInstruction.
StatusOr<llvm::Value*> ConvertValueForDistribution(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
- const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) const;
+ const llvm_ir::IrArray::Index& index, llvm::Value* raw_value);
};
} // namespace xla
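
The const qualifiers are dropped throughout this header because the class now derives from IrBuilderMixin, whose forwarded helpers (InsertValue, ExtractValue, and friends) mutate the underlying llvm::IRBuilder<> on every call. A minimal sketch of the CRTP forwarding pattern, assuming the real mixin in ir_builder_mixin.h is a larger version of this (MixinSketch is an illustrative name):

#include "llvm/IR/IRBuilder.h"

template <typename Derived>
class MixinSketch {
 public:
  // Forwards to the derived emitter's builder(); CRTP avoids a virtual call.
  llvm::Value* InsertValue(llvm::Value* agg, llvm::Value* val,
                           llvm::ArrayRef<unsigned> idxs) {
    return derived()->builder()->CreateInsertValue(agg, val, idxs);
  }
  llvm::Value* ExtractValue(llvm::Value* agg, llvm::ArrayRef<unsigned> idxs) {
    return derived()->builder()->CreateExtractValue(agg, idxs);
  }

 private:
  Derived* derived() { return static_cast<Derived*>(this); }
};

This is why the header adds the builder() accessor alongside b().
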
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
index addb016b04..852f34e06d 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
@@ -24,12 +24,11 @@ limitations under the License.
namespace xla {
namespace {
-using tensorflow::gtl::nullopt;
+using absl::nullopt;
class ElementalIrEmitterExecutionTest : public HloTestBase {
protected:
- void RunTest(const string& hlo_text,
- tensorflow::gtl::ArraySlice<Literal*> args) {
+ void RunTest(const string& hlo_text, absl::Span<Literal* const> args) {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
@@ -57,9 +56,9 @@ ENTRY main {
}
)";
- std::unique_ptr<Literal> lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}});
- std::unique_ptr<Literal> rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}});
- RunTest(hlo_text, {lhs.get(), rhs.get()});
+ Literal lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}});
+ Literal rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}});
+ RunTest(hlo_text, {&lhs, &rhs});
}
} // namespace
} // namespace xla
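
The ArraySlice-to-Span migration in this and the following files is mechanical: tensorflow::gtl::ArraySlice<T> becomes absl::Span<const T>, and a slice of pointers becomes absl::Span<T* const>, which keeps the pointees mutable while fixing the span elements themselves. A minimal sketch (Sum and Example are illustrative names):

#include <vector>
#include "absl/types/span.h"

// absl::Span is a non-owning view; containers convert to it implicitly.
int Sum(absl::Span<const int> values) {
  int total = 0;
  for (int v : values) total += v;
  return total;
}

void Example() {
  std::vector<int> args = {1, 2, 3};
  int total = Sum(args);  // views the vector's storage, no copy
  (void)total;
}
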
diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc
index fd75847d0c..47c56e2f7f 100644
--- a/tensorflow/compiler/xla/service/executable.cc
+++ b/tensorflow/compiler/xla/service/executable.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/executable.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/status.h"
@@ -22,16 +24,14 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
-using tensorflow::gtl::ArraySlice;
namespace xla {
StatusOr<std::vector<ScopedShapedBuffer>> Executable::ExecuteOnStreams(
- ArraySlice<const ServiceExecutableRunOptions> run_options,
- ArraySlice<ArraySlice<const ShapedBuffer*>> arguments) {
+ absl::Span<const ServiceExecutableRunOptions> run_options,
+ absl::Span<const absl::Span<const ShapedBuffer* const>> arguments) {
TF_RET_CHECK(run_options.size() == arguments.size());
std::vector<ScopedShapedBuffer> return_values;
@@ -62,7 +62,7 @@ StatusOr<std::vector<ScopedShapedBuffer>> Executable::ExecuteOnStreams(
StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStreamWrapper(
const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile,
- ArraySlice<const ShapedBuffer*> arguments) {
+ absl::Span<const ShapedBuffer* const> arguments) {
se::Stream* stream = run_options->stream();
std::unique_ptr<se::Timer> timer;
if (profile != nullptr) {
@@ -76,8 +76,8 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStreamWrapper(
std::unique_ptr<HloExecutionProfile> profile_ptr =
module_config().debug_options().xla_hlo_profile() &&
hlo_profiling_enabled()
- ? MakeUnique<HloExecutionProfile>(&hlo_profile_printer_data(),
- &hlo_profile_index_map())
+ ? absl::make_unique<HloExecutionProfile>(&hlo_profile_printer_data(),
+ &hlo_profile_index_map())
: nullptr;
StatusOr<ScopedShapedBuffer> return_value =
@@ -154,9 +154,9 @@ Status Executable::DumpHloSnapshot() {
const string& directory_path =
module_config().debug_options().xla_dump_executions_to();
const auto& module = hlo_snapshot_->hlo().hlo_module();
- string filename = tensorflow::strings::Printf(
- "computation_%lld__%s__execution_%lld", module.id(),
- module.entry_computation_name().c_str(), ++execution_count_);
+ string filename =
+ absl::StrFormat("computation_%d__%s__execution_%d", module.id(),
+ module.entry_computation_name(), ++execution_count_);
return Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot_);
}
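
Unlike tensorflow::strings::Printf, absl::StrFormat type-checks its arguments: %d accepts any integral width (which is why the %lld specifiers above can become %d for int64 values), and %s takes a std::string directly, with no .c_str() needed. A small sketch with illustrative values:

#include <cstdint>
#include <string>
#include "absl/strings/str_format.h"

int main() {
  int64_t module_id = 7;
  std::string name = "entry";
  // %d works for any integer width; %s takes std::string without c_str().
  std::string filename =
      absl::StrFormat("computation_%d__%s__execution_%d", module_id, name, 1);
  return filename.empty() ? 1 : 0;
}
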
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
index 98eaeee30a..3a6780f2a6 100644
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -18,7 +18,10 @@ limitations under the License.
#include <memory>
#include <utility>
+#include <vector>
+#include "absl/types/span.h"
+#include "absl/types/variant.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@@ -26,18 +29,33 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
+#include "tensorflow/compiler/xla/service/owning_device_memory.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace xla {
+// ExecutionOutput encapsulates the output buffers of an execution and the
+// leftover buffers to be released by the caller.
+struct ExecutionOutput {
+ ExecutionOutput(ScopedShapedBuffer result,
+ std::vector<OwningDeviceMemory> to_be_released)
+ : result(std::move(result)), to_be_released(std::move(to_be_released)) {}
+ ScopedShapedBuffer result;
+
+ // Leftover buffers for the caller to release. Elements in this list are
+ // donated input memory buffers that are not reused by XLA as outputs.
+ std::vector<OwningDeviceMemory> to_be_released;
+};
+
// A given platform's compiler will produce an Executable -- this is a uniform
// interface that is used for launching compiled programs across platforms.
class Executable {
@@ -63,25 +81,46 @@ class Executable {
// Returns a shaped buffer containing the result of the computation.
virtual StatusOr<ScopedShapedBuffer> ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) = 0;
// Same as ExecuteOnStream(), but this call is non-blocking and returns as
// soon as all of the operations are enqueued for launch on the stream.
virtual StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) = 0;
+ absl::Span<const ShapedBuffer* const> arguments) = 0;
+
+ // Starts the given program executing on the given stream/executor.
+ //
+ // `arguments` are ShapeTrees containing the input parameters. For each
+ // element in a shape tree, if the element holds ownership of the memory, it
+ // is considered donated and XLA may reuse it as an output buffer. XLA is
+ // also responsible for freeing all donated inputs.
+ //
+ // If an input is donated to XLA but is not reused as an output, it is
+ // returned as a leftover buffer for the caller to release.
+ virtual StatusOr<ExecutionOutput> ExecuteOnStream(
+ const ServiceExecutableRunOptions* run_options,
+ std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> arguments,
+ HloExecutionProfile* hlo_execution_profile) {
+ return Unimplemented(
+ "MaybeOwningDeviceMemory version of overload is not implemented ");
+ }
+
+ virtual StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
+ const ServiceExecutableRunOptions* run_options,
+ std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> arguments) {
+ return Unimplemented(
+ "MaybeOwningDeviceMemory version of overload is not implemented ");
+ }
// Same as ExecuteOnStream(), but runs this executable on multiple
// streams. arguments[i] contains the arguments to the execution on
// run_options[i]->stream() and the returned value is at index i of the
// returned vector.
virtual StatusOr<std::vector<ScopedShapedBuffer>> ExecuteOnStreams(
- tensorflow::gtl::ArraySlice<const ServiceExecutableRunOptions>
- run_options,
- tensorflow::gtl::ArraySlice<
- tensorflow::gtl::ArraySlice<const ShapedBuffer*>>
- arguments);
+ absl::Span<const ServiceExecutableRunOptions> run_options,
+ absl::Span<const absl::Span<const ShapedBuffer* const>> arguments);
// Populates `hlo_execution_profile` from `executor`. This is implicit in any
// Execute* API call that takes a hlo_execution_profile argument, but must be
@@ -97,7 +136,7 @@ class Executable {
// given ExecutionProfile if non-null.
StatusOr<ScopedShapedBuffer> ExecuteOnStreamWrapper(
const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
+ absl::Span<const ShapedBuffer* const> arguments);
// Returns the ExecutionProfile from executing on the device. This includes
// the number of cycles taken for the computation or the compilation time.
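
A hypothetical caller-side sketch of the donation contract introduced above. RunDonated and donated_args are assumed names, and the sketch assumes OwningDeviceMemory returns its buffer to the allocator on destruction; only Executable, ExecutionOutput, ScopedShapedBuffer, and the new overload come from this diff:

Status RunDonated(
    Executable* executable, const ServiceExecutableRunOptions* run_options,
    std::vector<ShapeTree<MaybeOwningDeviceMemory>> donated_args) {
  TF_ASSIGN_OR_RETURN(
      ExecutionOutput out,
      executable->ExecuteOnStream(run_options, std::move(donated_args),
                                  /*hlo_execution_profile=*/nullptr));
  ScopedShapedBuffer result = std::move(out.result);
  (void)result;  // in real code, the result buffers would be consumed here
  // Donated inputs that XLA did not reuse as outputs come back in
  // to_be_released; clearing the vector frees them (assumed RAII behavior).
  out.to_be_released.clear();
  return Status::OK();
}
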
diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc
index 228c3fac95..997db7c058 100644
--- a/tensorflow/compiler/xla/service/execution_tracker.cc
+++ b/tensorflow/compiler/xla/service/execution_tracker.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <utility>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -53,8 +53,8 @@ ExecutionHandle ExecutionTracker::Register(Backend* backend,
tensorflow::mutex_lock lock(execution_mutex_);
int64 handle = next_handle_++;
auto inserted = handle_to_execution_.emplace(
- handle,
- MakeUnique<AsyncExecution>(backend, std::move(streams), profile, result));
+ handle, absl::make_unique<AsyncExecution>(backend, std::move(streams),
+ profile, result));
CHECK(inserted.second);
ExecutionHandle execution_handle;
@@ -66,7 +66,7 @@ Status ExecutionTracker::Unregister(const ExecutionHandle& handle) {
tensorflow::mutex_lock lock(execution_mutex_);
auto it = handle_to_execution_.find(handle.handle());
if (it == handle_to_execution_.end()) {
- return NotFound("no execution record for execution handle: %lld",
+ return NotFound("no execution record for execution handle: %d",
handle.handle());
}
handle_to_execution_.erase(handle.handle());
@@ -78,7 +78,7 @@ StatusOr<const AsyncExecution*> ExecutionTracker::Resolve(
tensorflow::mutex_lock lock(execution_mutex_);
auto it = handle_to_execution_.find(handle.handle());
if (it == handle_to_execution_.end()) {
- return NotFound("no execution record for execution handle: %lld",
+ return NotFound("no execution record for execution handle: %d",
handle.handle());
}
return it->second.get();
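
The MakeUnique-to-absl::make_unique change here and elsewhere in the diff is a drop-in rename; absl/memory also supplies WrapUnique for adopting an already-constructed raw pointer, as BufferAllocations::Builder::Build does later in this diff. A tiny sketch (Record and Example are illustrative names):

#include "absl/memory/memory.h"

struct Record {
  explicit Record(int id) : id(id) {}
  int id;
};

void Example() {
  // C++11-compatible equivalent of std::make_unique.
  auto a = absl::make_unique<Record>(1);
  // WrapUnique adopts ownership of a raw pointer.
  auto b = absl::WrapUnique(new Record(2));
  (void)a;
  (void)b;
}
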
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.h b/tensorflow/compiler/xla/service/flatten_call_graph.h
index d3efab3614..3cccec9862 100644
--- a/tensorflow/compiler/xla/service/flatten_call_graph.h
+++ b/tensorflow/compiler/xla/service/flatten_call_graph.h
@@ -28,7 +28,7 @@ namespace xla {
// points-to analysis (see b/36865746 for details).
class FlattenCallGraph : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "flatten-call-graph"; }
+ absl::string_view name() const override { return "flatten-call-graph"; }
// Duplicates computations called from multiple call- or while-nodes to
// flatten the call graph.
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
index 8f6608241e..5fbd73a536 100644
--- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
+++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -30,7 +30,7 @@ limitations under the License.
namespace xla {
namespace {
-class FlattenCallGraphTest : public HloTestBase {
+class FlattenCallGraphTest : public HloVerifiedTestBase {
protected:
// Build and return a trivial computation taking and returning a scalar.
std::unique_ptr<HloComputation> MakeScalarComputation() {
@@ -139,9 +139,9 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) {
}
{
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module);
const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation);
EXPECT_EQ(1, c_node.caller_callsites().size());
}
@@ -176,15 +176,15 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
}
{
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
EXPECT_EQ(2, cond_node.caller_callsites().size());
}
{
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
EXPECT_EQ(1, cond_node.caller_callsites().size());
}
@@ -211,9 +211,9 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) {
module->AddEntryComputation(
MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry"));
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(7, module->computation_count());
const CallGraphNode& c_node = call_graph->GetNode(c_computation);
@@ -243,9 +243,9 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) {
module->AddEntryComputation(builder.Build());
EXPECT_EQ(2, module->computation_count());
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
// The true and false computations must now be different.
EXPECT_EQ(3, module->computation_count());
diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc
index e3a42d0d06..cb86c98579 100644
--- a/tensorflow/compiler/xla/service/gather_expander.cc
+++ b/tensorflow/compiler/xla/service/gather_expander.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include <utility>
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gather_expander.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
@@ -24,88 +25,87 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
namespace xla {
-using tensorflow::gtl::ArraySlice;
static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast(
- HloInstruction* gather_indices, int64 index_vector_dim) {
- const Shape& gather_indices_shape = gather_indices->shape();
+ HloInstruction* start_indices, int64 index_vector_dim) {
+ const Shape& start_indices_shape = start_indices->shape();
- if (gather_indices_shape.dimensions_size() == index_vector_dim) {
- return gather_indices;
+ if (start_indices_shape.dimensions_size() == index_vector_dim) {
+ return start_indices;
}
- if (index_vector_dim == (gather_indices_shape.dimensions_size() - 1)) {
- return gather_indices;
+ if (index_vector_dim == (start_indices_shape.dimensions_size() - 1)) {
+ return start_indices;
}
std::vector<int64> permutation;
- permutation.reserve(gather_indices_shape.dimensions_size());
- for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) {
+ permutation.reserve(start_indices_shape.dimensions_size());
+ for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
if (i != index_vector_dim) {
permutation.push_back(i);
}
}
permutation.push_back(index_vector_dim);
- return MakeTransposeHlo(gather_indices, permutation);
+ return MakeTransposeHlo(start_indices, permutation);
}
-// Canonicalizes the gather_indices tensors so that we only have deal with some
+// Canonicalizes the start_indices tensors so that we only have to deal with some
// specific cases in the while loop that does the heavy lifting.
//
// See the "High Level Algorithm" section for a broader picture.
static StatusOr<HloInstruction*> CanonicalizeGatherIndices(
- HloInstruction* gather_indices, int64 index_vector_dim) {
+ HloInstruction* start_indices, int64 index_vector_dim) {
// Transpose the non-index-vector dimensions to the front.
TF_ASSIGN_OR_RETURN(
- HloInstruction * transposed_gather_indices,
- TransposeIndexVectorDimToLast(gather_indices, index_vector_dim));
+ HloInstruction * transposed_start_indices,
+ TransposeIndexVectorDimToLast(start_indices, index_vector_dim));
bool indices_are_scalar =
- index_vector_dim == gather_indices->shape().dimensions_size();
+ index_vector_dim == start_indices->shape().dimensions_size();
- // The number of dimensions in gather_indices that are index dimensions.
- const int64 index_dims_in_gather_indices = indices_are_scalar ? 0 : 1;
+ // The number of dimensions in start_indices that are index dimensions.
+ const int64 index_dims_in_start_indices = indices_are_scalar ? 0 : 1;
- // If there is only one index (i.e. gather_indices has rank 1 and this gather
+ // If there is only one index (i.e. start_indices has rank 1 and this gather
// is really just a dynamic slice) add a leading degenerate dimension for
// uniformity. Otherwise create a "collapsed" leading dimension that subsumes
// all of the non-index-vector dimensions.
- const Shape& shape = transposed_gather_indices->shape();
- if (shape.dimensions_size() == index_dims_in_gather_indices) {
- return PrependDegenerateDims(transposed_gather_indices, 1);
+ const Shape& shape = transposed_start_indices->shape();
+ if (shape.dimensions_size() == index_dims_in_start_indices) {
+ return PrependDegenerateDims(transposed_start_indices, 1);
} else {
- // Collapse all but the dimensions (0 or 1) in gather_indices containing the
+ // Collapse all but the dimensions (0 or 1) in start_indices containing the
// index vectors.
return CollapseFirstNDims(
- transposed_gather_indices,
- shape.dimensions_size() - index_dims_in_gather_indices);
+ transposed_start_indices,
+ shape.dimensions_size() - index_dims_in_start_indices);
}
}
// Expands out or contracts away the gather dimensions in the accumulator
// produced by the while loop.
-static StatusOr<HloInstruction*> AdjustGatherDimsInAccumulator(
- const Shape& gather_indices_shape, HloInstruction* accumulator,
+static StatusOr<HloInstruction*> AdjustBatchDimsInAccumulator(
+ const Shape& start_indices_shape, HloInstruction* accumulator,
int64 index_vector_dim) {
- std::vector<int64> output_gather_dim_bounds;
- output_gather_dim_bounds.reserve(gather_indices_shape.dimensions_size());
- for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) {
+ std::vector<int64> batch_dim_bounds;
+ batch_dim_bounds.reserve(start_indices_shape.dimensions_size());
+ for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
if (i != index_vector_dim) {
- output_gather_dim_bounds.push_back(gather_indices_shape.dimensions(i));
+ batch_dim_bounds.push_back(start_indices_shape.dimensions(i));
}
}
- if (output_gather_dim_bounds.empty()) {
- // If output_gather_dim_bounds is empty we must be lowering a (effectively)
+ if (batch_dim_bounds.empty()) {
+ // If batch_dim_bounds is empty we must be lowering an (effectively)
// dynamic-slice. In that case, there is a leading degenerate gather
// dimension that we added to make this special case play well with the
// general while loop which we need to remove now.
return ElideDegenerateDims(accumulator, {0});
}
- return ExpandFirstDimIntoNDims(accumulator, output_gather_dim_bounds);
+ return ExpandFirstDimIntoNDims(accumulator, batch_dim_bounds);
}
-// Expand an index vector from the gather_indices tensor into a vector that can
+// Expand an index vector from the start_indices tensor into a vector that can
// be used to dynamic-slice out of the gather operand.
static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
HloInstruction* index_vector, const GatherDimensionNumbers& dim_numbers,
@@ -121,10 +121,8 @@ static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
std::vector<HloInstruction*> expanded_index_components;
for (int i = 0; i < operand_rank; i++) {
- int64 index_vector_dim_index =
- FindIndex(dim_numbers.gather_dims_to_operand_dims(), i);
- if (index_vector_dim_index !=
- dim_numbers.gather_dims_to_operand_dims_size()) {
+ int64 index_vector_dim_index = FindIndex(dim_numbers.start_index_map(), i);
+ if (index_vector_dim_index != dim_numbers.start_index_map_size()) {
TF_ASSIGN_OR_RETURN(
HloInstruction * component_to_concat,
MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index},
@@ -147,10 +145,10 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
const GatherDimensionNumbers& dim_numbers = gather.gather_dimension_numbers();
CHECK_EQ(incoming_loop_state.size(), 3);
HloInstruction* const operand = incoming_loop_state[0];
- HloInstruction* const gather_indices = incoming_loop_state[1];
+ HloInstruction* const start_indices = incoming_loop_state[1];
HloInstruction* const output_accumulator = incoming_loop_state[2];
- bool has_scalar_indices = gather_indices->shape().dimensions_size() == 1;
+ bool has_scalar_indices = start_indices->shape().dimensions_size() == 1;
CHECK_EQ(has_scalar_indices,
dim_numbers.index_vector_dim() ==
gather.operand(1)->shape().dimensions_size());
@@ -163,24 +161,24 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
HloInstruction* index_vector;
if (has_scalar_indices) {
- // In this case gather_indices has rank 1 and induction_var_as_vector (of
+ // In this case start_indices has rank 1 and induction_var_as_vector (of
// shape {1}) is an index into this rank 1 tensor.
TF_ASSIGN_OR_RETURN(
index_vector,
- MakeDynamicSliceHlo(gather_indices, induction_var_as_vector, {1}));
+ MakeDynamicSliceHlo(start_indices, induction_var_as_vector, {1}));
} else {
- // In this case gather_indices has rank 2 and induction_var_as_vector (of
+ // In this case start_indices has rank 2 and induction_var_as_vector (of
// shape {1}) is an index into just the first dimension of this rank 2
// tensor.
TF_ASSIGN_OR_RETURN(
- HloInstruction * index_into_gather_indices,
+ HloInstruction * index_into_start_indices,
PadVectorWithZeros(induction_var_as_vector,
/*zeros_to_prepend=*/0, /*zeros_to_append=*/1));
- int64 index_vector_size = gather_indices->shape().dimensions(1);
+ int64 index_vector_size = start_indices->shape().dimensions(1);
TF_ASSIGN_OR_RETURN(
HloInstruction * index_vector_2d,
- MakeDynamicSliceHlo(gather_indices, index_into_gather_indices,
+ MakeDynamicSliceHlo(start_indices, index_into_start_indices,
{1, index_vector_size}));
TF_ASSIGN_OR_RETURN(index_vector,
@@ -194,26 +192,26 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice,
MakeDynamicSliceHlo(operand, gathered_slice_start,
- gather.gather_window_bounds()));
+ gather.gather_slice_sizes()));
TF_ASSIGN_OR_RETURN(
- HloInstruction * gathered_slice_with_dims_elided,
+ HloInstruction* const gathered_slice_with_dims_collapsed,
ElideDegenerateDims(gathered_slice,
- AsInt64Slice(dim_numbers.elided_window_dims())));
+ AsInt64Slice(dim_numbers.collapsed_slice_dims())));
TF_ASSIGN_OR_RETURN(
- HloInstruction * gathered_slice_for_update,
- PrependDegenerateDims(gathered_slice_with_dims_elided, 1));
+ HloInstruction* const gathered_slice_for_update,
+ PrependDegenerateDims(gathered_slice_with_dims_collapsed, 1));
TF_ASSIGN_OR_RETURN(
- HloInstruction * index_vector_into_accumulator,
+ HloInstruction* const index_vector_into_accumulator,
PadVectorWithZeros(
induction_var_as_vector, /*zeros_to_prepend=*/0,
/*zeros_to_append=*/
- gathered_slice_with_dims_elided->shape().dimensions_size()));
+ gathered_slice_with_dims_collapsed->shape().dimensions_size()));
TF_ASSIGN_OR_RETURN(
- HloInstruction * updated_accumulator,
+ HloInstruction* const updated_accumulator,
MakeDynamicUpdateSliceHlo(output_accumulator, gathered_slice_for_update,
index_vector_into_accumulator));
@@ -221,19 +219,19 @@ static StatusOr<std::vector<HloInstruction*>> GatherLoopBody(
// WhileUtil::MakeCountedLoop functions takes care of the induction variable
// and the while loop exit condition.
return StatusOr<std::vector<HloInstruction*>>{
- {operand, gather_indices, updated_accumulator}};
+ {operand, start_indices, updated_accumulator}};
}
static StatusOr<HloInstruction*> CreateGatherLoopAccumulatorInitValue(
HloComputation* computation, PrimitiveType element_type,
- ArraySlice<int64> window_bounds, int64 gather_loop_trip_count,
+ absl::Span<const int64> slice_sizes, int64 gather_loop_trip_count,
const GatherDimensionNumbers& dim_numbers) {
std::vector<int64> accumulator_state_shape_dims;
- accumulator_state_shape_dims.reserve(1 + window_bounds.size());
+ accumulator_state_shape_dims.reserve(1 + slice_sizes.size());
accumulator_state_shape_dims.push_back(gather_loop_trip_count);
- for (int64 i = 0; i < window_bounds.size(); i++) {
- if (!c_binary_search(dim_numbers.elided_window_dims(), i)) {
- accumulator_state_shape_dims.push_back(window_bounds[i]);
+ for (int64 i = 0; i < slice_sizes.size(); i++) {
+ if (!absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
+ accumulator_state_shape_dims.push_back(slice_sizes[i]);
}
}
return BroadcastZeros(computation, element_type,
@@ -241,23 +239,23 @@ static StatusOr<HloInstruction*> CreateGatherLoopAccumulatorInitValue(
}
// `accumulator` is almost the tensor the gather operation would have produced,
-// except that it has the dimensions in the wrong order -- the gather dimensions
-// are the major dimensions and the window dimensions are the minor dimensions.
+// except that it has the dimensions in the wrong order -- the batch dimensions
+// are the major dimensions and the offset dimensions are the minor dimensions.
// Fix this up with a transpose.
-static StatusOr<HloInstruction*> PermuteGatherAndWindowDims(
- HloInstruction* accumulator, ArraySlice<int64> output_window_dims,
+static StatusOr<HloInstruction*> PermuteBatchAndOffsetDims(
+ HloInstruction* accumulator, absl::Span<const int64> offset_dims,
int64 output_rank) {
std::vector<int64> permutation;
permutation.reserve(output_rank);
- int64 gather_idx_counter = 0;
- int64 window_idx_counter = output_rank - output_window_dims.size();
+ int64 batch_idx_counter = 0;
+ int64 offset_idx_counter = output_rank - offset_dims.size();
for (int64 i = 0; i < output_rank; i++) {
- bool is_window_dim = c_binary_search(output_window_dims, i);
- if (is_window_dim) {
- permutation.push_back(window_idx_counter++);
+ bool is_offset_dim = absl::c_binary_search(offset_dims, i);
+ if (is_offset_dim) {
+ permutation.push_back(offset_idx_counter++);
} else {
- permutation.push_back(gather_idx_counter++);
+ permutation.push_back(batch_idx_counter++);
}
}
@@ -268,11 +266,11 @@ static StatusOr<HloInstruction*> PermuteGatherAndWindowDims(
//
// We perform the following steps in sequence:
//
-// 1. We canonicalize the gather_indices tensor such that it has rank
+// 1. We canonicalize the start_indices tensor such that it has rank
// 2 (i.e. is a matrix) where each row is an index vector into the
// operand.
// 2. We iterate over the set of indices in the canonicalized
-// gather_indices tensor using a while loop, accumulating slices
+// start_indices tensor using a while loop, accumulating slices
// of the operand tensor into an accumulator using
// DynamicUpdateSlice.
// 3. The accumulator result from the while loop from (2) is then
@@ -287,11 +285,11 @@ static StatusOr<HloInstruction*> PermuteGatherAndWindowDims(
// operand = s32[3,3] parameter(0)
// indices = s32[2,2] parameter(1)
// ROOT gather = s32[2,3,2] gather(operand, indices),
-// output_window_dims={1},
-// elided_window_dims={1},
-// gather_dims_to_operand_dims={1},
+// offset_dims={1},
+// collapsed_slice_dims={1},
+// start_index_map={1},
// index_vector_dim=2,
-// window_bounds={3, 1}
+// slice_sizes={3, 1}
// }
//
// We'd first reshape indices to s32[4,1], where each row is an index
@@ -305,8 +303,8 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
HloComputation* computation = gather_instr->parent();
HloInstruction* operand = gather_instr->mutable_operand(0);
- HloInstruction* gather_indices = gather_instr->mutable_operand(1);
- const Shape& gather_indices_shape = gather_indices->shape();
+ HloInstruction* start_indices = gather_instr->mutable_operand(1);
+ const Shape& start_indices_shape = start_indices->shape();
const Shape& output_shape = gather_instr->shape();
int64 output_rank = output_shape.dimensions_size();
@@ -314,9 +312,9 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
gather_instr->gather_dimension_numbers();
int64 gather_loop_trip_count = 1;
- for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) {
+ for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
if (i != dim_numbers.index_vector_dim()) {
- gather_loop_trip_count *= gather_indices_shape.dimensions(i);
+ gather_loop_trip_count *= start_indices_shape.dimensions(i);
}
}
@@ -324,27 +322,27 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
return Unimplemented(
"Gather operations with more than 2147483647 gather indices are not "
"supported. This error occurred for %s.",
- gather_instr->ToString().c_str());
+ gather_instr->ToString());
}
- TF_ASSIGN_OR_RETURN(HloInstruction * canonical_gather_indices,
- CanonicalizeGatherIndices(
- gather_indices, dim_numbers.index_vector_dim()));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * canonical_start_indices,
+ CanonicalizeGatherIndices(start_indices, dim_numbers.index_vector_dim()));
CHECK_EQ(gather_loop_trip_count,
- canonical_gather_indices->shape().dimensions(0));
+ canonical_start_indices->shape().dimensions(0));
TF_ASSIGN_OR_RETURN(
HloInstruction * accumulator_init,
CreateGatherLoopAccumulatorInitValue(
computation, output_shape.element_type(),
- gather_instr->gather_window_bounds(), gather_loop_trip_count,
+ gather_instr->gather_slice_sizes(), gather_loop_trip_count,
gather_instr->gather_dimension_numbers()));
StatusOr<std::vector<HloInstruction*>> gather_loop_result_or_error =
WhileUtil::MakeCountedLoop(
computation, gather_loop_trip_count,
- {operand, canonical_gather_indices, accumulator_init},
+ {operand, canonical_start_indices, accumulator_init},
[&](HloInstruction* indvar,
const std::vector<HloInstruction*>& loop_state) {
return GatherLoopBody(*gather_instr, indvar, loop_state);
@@ -356,13 +354,13 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather(
HloInstruction* accumulator_result = gather_loop_result.back();
TF_ASSIGN_OR_RETURN(
- HloInstruction * accumulator_with_output_gather_dims_decanonicalized,
- AdjustGatherDimsInAccumulator(gather_indices->shape(), accumulator_result,
- dim_numbers.index_vector_dim()));
+ HloInstruction* const accumulator_with_batch_dims_decanonicalized,
+ AdjustBatchDimsInAccumulator(start_indices->shape(), accumulator_result,
+ dim_numbers.index_vector_dim()));
- return PermuteGatherAndWindowDims(
- accumulator_with_output_gather_dims_decanonicalized,
- AsInt64Slice(dim_numbers.output_window_dims()), output_rank);
+ return PermuteBatchAndOffsetDims(accumulator_with_batch_dims_decanonicalized,
+ AsInt64Slice(dim_numbers.offset_dims()),
+ output_rank);
}
StatusOr<bool> GatherExpander::Run(HloModule* module) {
@@ -375,8 +373,8 @@ StatusOr<bool> GatherExpander::Run(HloModule* module) {
std::vector<HloInstruction*> gather_instrs;
for (HloComputation* computation : module->MakeNonfusionComputations()) {
- c_copy_if(computation->instructions(), std::back_inserter(gather_instrs),
- is_nontrivial_gather);
+ absl::c_copy_if(computation->instructions(),
+ std::back_inserter(gather_instrs), is_nontrivial_gather);
}
for (HloInstruction* inst : gather_instrs) {
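
A standalone sketch of the permutation built by PermuteBatchAndOffsetDims above: offset dimensions keep their relative order at the back of the output while batch dimensions fill the front. It assumes offset_dims is sorted ascending, as the binary search in the real code requires (MakePermutation is an illustrative name):

#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<int64_t> MakePermutation(const std::vector<int64_t>& offset_dims,
                                     int64_t output_rank) {
  std::vector<int64_t> permutation;
  permutation.reserve(output_rank);
  int64_t batch_idx = 0;
  int64_t offset_idx = output_rank - offset_dims.size();
  for (int64_t i = 0; i < output_rank; ++i) {
    if (std::binary_search(offset_dims.begin(), offset_dims.end(), i)) {
      permutation.push_back(offset_idx++);
    } else {
      permutation.push_back(batch_idx++);
    }
  }
  return permutation;
}

// For the gather example above (offset_dims={1}, output rank 3) this yields
// {0, 2, 1}.
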
diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h
index c1fc8574da..7bd9ea5984 100644
--- a/tensorflow/compiler/xla/service/gather_expander.h
+++ b/tensorflow/compiler/xla/service/gather_expander.h
@@ -25,7 +25,7 @@ namespace xla {
// nevertheless have a minimum level of support.
class GatherExpander : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "gather_expander"; }
+ absl::string_view name() const override { return "gather_expander"; }
StatusOr<bool> Run(HloModule* module) override;
private:
diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc
index 020ffcd106..141dd4d6f1 100644
--- a/tensorflow/compiler/xla/service/gather_expander_test.cc
+++ b/tensorflow/compiler/xla/service/gather_expander_test.cc
@@ -28,11 +28,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2147483647,5] parameter(1)
ROOT gather = s32[2147483647,3,5] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
@@ -55,11 +55,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[3,2] gather(operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index 0ce2db907b..bec02e14f9 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -42,8 +42,7 @@ se::Platform::Id GenericTransferManager::PlatformId() const {
}
Status GenericTransferManager::WriteSingleTupleIndexTable(
- se::Stream* stream,
- tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> elements,
+ se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
const Shape& shape, se::DeviceMemoryBase* region) {
TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape));
@@ -126,7 +125,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
device_memory.size());
// Element is array-shaped: transfer array data to device buffer.
const auto subliteral = LiteralSlice(literal, index);
- std::unique_ptr<Literal> relayed_out_literal;
+ Literal relayed_out_literal;
const void* source;
if (LayoutUtil::Equal(device_subshape.layout(),
subliteral.shape().layout())) {
@@ -139,7 +138,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
// Relayout data before transferring.
relayed_out_literal = subliteral.Relayout(device_subshape.layout(),
/*shape_index=*/{});
- source = relayed_out_literal->untyped_data();
+ source = relayed_out_literal.untyped_data();
TF_RETURN_IF_ERROR(TransferBufferToDevice(
stream,
/*size=*/GetByteSizeRequirement(device_subshape), source,
@@ -163,7 +162,7 @@ Status GenericTransferManager::TransferLiteralFromOutfeed(
}
Status GenericTransferManager::ResetDevices(
- tensorflow::gtl::ArraySlice<se::StreamExecutor*>
+ absl::Span<se::StreamExecutor* const>
/*executors*/) {
return Unimplemented(
"Device reset is not yet supported on this platform (b/30481585)");
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h
index 6c1a21587a..86c8b1c145 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h
@@ -55,15 +55,13 @@ class GenericTransferManager : public TransferManager {
const Shape& literal_shape,
MutableBorrowingLiteral literal) override;
- Status ResetDevices(
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> executors) override;
+ Status ResetDevices(absl::Span<se::StreamExecutor* const> executors) override;
int64 GetByteSizeRequirement(const Shape& shape) const override;
protected:
Status WriteSingleTupleIndexTable(
- se::Stream* stream,
- tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> elements,
+ se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
const Shape& shape, se::DeviceMemoryBase* region) override;
private:
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 8ef72850dc..64b9683628 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -56,6 +56,8 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -91,6 +93,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_reachability",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -105,8 +108,12 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -126,6 +133,8 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
],
)
@@ -165,12 +174,14 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/compiler/xla/service:while_loop_analysis",
"//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
"//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+ "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin",
"//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library",
"//tensorflow/compiler/xla/service/llvm_ir:kernel_tiling",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
@@ -180,6 +191,12 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
"@llvm//:support",
],
@@ -224,6 +241,8 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:math_ops",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
"@llvm//:support",
],
@@ -243,6 +262,8 @@ cc_library(
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
)
@@ -257,6 +278,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:ptr_util",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
)
@@ -337,6 +359,11 @@ cc_library(
"//tensorflow/core/platform/default/build_config:cufft_plugin",
"//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep
"//tensorflow/stream_executor",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
],
)
@@ -345,6 +372,8 @@ cc_library(
srcs = ["ir_emission_utils.cc"],
hdrs = ["ir_emission_utils.h"],
deps = [
+ ":backend_configs",
+ ":cudnn_convolution_runner",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
@@ -370,9 +399,13 @@ cc_library(
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -390,6 +423,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/strings",
],
)
@@ -420,7 +454,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:shape_inference",
- "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:test",
],
@@ -431,6 +465,7 @@ cc_library(
srcs = ["instruction_fusion.cc"],
hdrs = ["instruction_fusion.h"],
deps = [
+ ":gpu_fusible",
":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -451,6 +486,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -460,12 +496,14 @@ cc_library(
srcs = ["multi_output_fusion.cc"],
hdrs = ["multi_output_fusion.h"],
deps = [
+ ":gpu_fusible",
":instruction_fusion",
":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:multi_output_fusion",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -483,6 +521,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -506,6 +545,7 @@ cc_library(
srcs = ["fusion_merger.cc"],
hdrs = ["fusion_merger.h"],
deps = [
+ ":gpu_fusible",
":instruction_fusion",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
@@ -513,6 +553,8 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
],
)
@@ -544,6 +586,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_creation_utils",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:shape_inference",
+ "@com_google_absl//absl/memory",
],
)
@@ -600,6 +643,7 @@ cc_library(
"//tensorflow/compiler/xla/service/gpu:infeed_manager",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
"@llvm//:core",
],
alwayslink = True, # Contains per-platform transfer manager registration
@@ -616,9 +660,9 @@ cc_library(
":gpu_constants",
":gpu_copy_insertion",
":gpu_executable",
+ ":gpu_hlo_schedule",
":gpu_hlo_support_checker",
":gpu_layout_assignment",
- ":hlo_schedule",
":instruction_fusion",
":ir_emission_utils",
":ir_emitter",
@@ -639,7 +683,6 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_liveness",
"//tensorflow/compiler/xla/service:call_inliner",
"//tensorflow/compiler/xla/service:conditional_simplifier",
- "//tensorflow/compiler/xla/service:convolution_feature_group_converter",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",
"//tensorflow/compiler/xla/service:hlo",
@@ -670,6 +713,10 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
],
alwayslink = True, # Contains compiler registration
@@ -702,8 +749,8 @@ cc_library(
":xfeed_queue",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
],
)
@@ -718,6 +765,7 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -756,39 +804,44 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep
+ "@com_google_absl//absl/strings",
],
)
cc_library(
- name = "hlo_schedule",
- srcs = ["hlo_schedule.cc"],
- hdrs = ["hlo_schedule.h"],
+ name = "gpu_hlo_schedule",
+ srcs = ["gpu_hlo_schedule.cc"],
+ hdrs = ["gpu_hlo_schedule.h"],
deps = [
":stream_assignment",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:buffer_value",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_memory_scheduler",
"//tensorflow/compiler/xla/service:hlo_ordering",
"//tensorflow/compiler/xla/service:hlo_reachability",
- "//tensorflow/compiler/xla/service:hlo_scheduling",
+ "@com_google_absl//absl/memory",
],
)
tf_cc_test(
- name = "hlo_schedule_test",
+ name = "gpu_hlo_schedule_test",
srcs = [
- "hlo_schedule_test.cc",
+ "gpu_hlo_schedule_test.cc",
],
deps = [
- ":hlo_schedule",
+ ":gpu_hlo_schedule",
":stream_assignment",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -839,7 +892,9 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:stream_executor_no_cuda",
],
)
@@ -852,6 +907,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
@@ -868,9 +924,8 @@ cc_library(
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:hlo_parser",
- "//tensorflow/compiler/xla/service:hlo_runner",
- "//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/strings",
],
)
@@ -889,3 +944,26 @@ xla_test(
"//tensorflow/core:test_main",
],
)
+
+cc_library(
+ name = "gpu_fusible",
+ srcs = ["gpu_fusible.cc"],
+ hdrs = ["gpu_fusible.h"],
+ deps = [
+ ":ir_emission_utils",
+ "//tensorflow/compiler/xla/service:hlo",
+ ],
+)
+
+tf_cc_test(
+ name = "gpu_fusible_test",
+ srcs = ["gpu_fusible_test.cc"],
+ deps = [
+ ":gpu_fusible",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "@com_google_absl//absl/strings",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
index 537295292b..528209abc7 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
+++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
@@ -17,8 +17,8 @@ limitations under the License.
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -40,7 +40,7 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
const BufferAssignment* buffer_assignment, int device_ordinal,
DeviceMemoryAllocator* memory_allocator) {
const int64 num_buffers = buffer_assignment->Allocations().size();
- auto buffer_allocations = WrapUnique(new BufferAllocations(
+ auto buffer_allocations = absl::WrapUnique(new BufferAllocations(
num_buffers, device_ordinal, memory_allocator, buffer_assignment));
for (BufferAllocation::Index i = 0; i < num_buffers; ++i) {
@@ -62,7 +62,7 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
if (reinterpret_cast<uintptr_t>(address.opaque()) % expected_alignment !=
0) {
return InternalError(
- "Address of registered buffer %lld must be a multiple of %llx, but "
+ "Address of registered buffer %d must be a multiple of %x, but "
"was %p",
i, kEntryParameterAlignBytes, address.opaque());
}
@@ -83,7 +83,7 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
0) {
return InternalError(
"Address returned by memory_allocator->Allocate must be a "
- "multiple of %llx, but was %p",
+ "multiple of 0x%x, but was %p",
kXlaAllocatedBufferAlignBytes, buffer.opaque());
}
// We do manual memory management within BufferAllocations. Be sure not
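The message rewrites above follow XLA's move from printf-style formatting to absl's typed formatting, where %d accepts any integer width and length modifiers such as %lld are unnecessary. A minimal standalone sketch of the same pattern; the function name and message here are illustrative, not part of the patch:

    #include <cstdint>
    #include <string>

    #include "absl/strings/str_format.h"

    // absl::StrFormat checks arguments against the spec: %d handles int64_t
    // without a length modifier, %x prints the alignment in hex, %p a pointer.
    std::string AlignmentError(int64_t index, int64_t align, const void* addr) {
      return absl::StrFormat(
          "Address of registered buffer %d must be a multiple of 0x%x, but was %p",
          index, align, addr);
    }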
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h
index f13eab0dd7..14186b8faa 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h
+++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h
@@ -20,10 +20,10 @@ limitations under the License.
#include <set>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
index 6a285a6b98..13c83c9199 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
+++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
@@ -16,9 +16,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
#include <cmath>
+#include "absl/strings/str_replace.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace gpu {
@@ -74,9 +74,8 @@ ENTRY MaxDifference {
%error = f32[SIZE] divide(%sub_abs, %denominator)
ROOT %max_diff = f32[] reduce(%error, %zero_constant), dimensions={0}, to_apply=MaxF32
})";
- auto size_string = std::to_string(num_elements);
- return tensorflow::str_util::StringReplace(
- kF16CompHloText, "SIZE", {size_string.data(), size_string.size()}, true);
+ return absl::StrReplaceAll(kF16CompHloText,
+ {{"SIZE", absl::StrCat(num_elements)}});
}
StatusOr<F16BufferComparator> F16BufferComparator::Create(
@@ -125,7 +124,7 @@ StatusOr<F16BufferComparator> F16BufferComparator::Create(
StatusOr<bool> F16BufferComparator::CompareEqualImpl(
se::DeviceMemory<Eigen::half> test_buffer) {
if (ref_buffer_.root_buffer().size() != test_buffer.size()) {
- return InternalError("Mismatched buffer size: %lld vs %lld",
+ return InternalError("Mismatched buffer size: %d vs %d",
ref_buffer_.root_buffer().size(), test_buffer.size());
}
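The comparator's HLO text is now generated with absl::StrReplaceAll over a substitution map instead of tensorflow::str_util::StringReplace. A self-contained sketch of that substitution, using a toy template rather than the real kF16CompHloText:

    #include <cstdint>
    #include <string>

    #include "absl/strings/str_cat.h"
    #include "absl/strings/str_replace.h"

    // Replaces every occurrence of "SIZE" in one pass; the map form extends
    // naturally to several placeholders without repeated scans.
    std::string MakeHloText(int64_t num_elements) {
      static constexpr char kTemplate[] = "param = f16[SIZE] parameter(0)";
      return absl::StrReplaceAll(kTemplate,
                                 {{"SIZE", absl::StrCat(num_elements)}});
    }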
diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
index 5780e0af40..9ed523998b 100644
--- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -59,7 +59,7 @@ Status ConditionalThunk::ExecuteOnStream(
Status block_status = stream->BlockHostUntilDone();
if (!block_status.ok()) {
return InternalError("Failed to retrieve predicate value on stream %p: %s.",
- stream, block_status.error_message().c_str());
+ stream, block_status.error_message());
}
// Execute the true or the false computation depending on the value of the
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 7833a4077e..3a23ac1d63 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -17,12 +17,12 @@ limitations under the License.
#include <string>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -31,61 +31,32 @@ namespace gpu {
using se::dnn::AlgorithmDesc;
-ConvolutionThunk::ConvolutionThunk(
- CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer,
- const BufferAllocation::Slice& filter_buffer,
- const BufferAllocation::Slice& output_buffer,
- const BufferAllocation::Slice& tuple_result_buffer,
- const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape,
- const Shape& filter_shape, const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dim_nums, int64 algorithm,
- bool tensor_ops_enabled, const HloInstruction* hlo)
- : Thunk(Kind::kConvolution, hlo),
- convolution_kind_(convolution_kind),
- input_buffer_(input_buffer),
- filter_buffer_(filter_buffer),
- output_buffer_(output_buffer),
- tuple_result_buffer_(tuple_result_buffer),
- scratch_buffer_(scratch_buffer),
- input_shape_(input_shape),
- filter_shape_(filter_shape),
- output_shape_(output_shape),
- window_(window),
- dim_nums_(dim_nums),
- algorithm_(algorithm),
- tensor_ops_enabled_(tensor_ops_enabled) {}
-
Status ConvolutionThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
- se::DeviceMemoryBase input_data =
- buffer_allocations.GetDeviceAddress(input_buffer_);
- se::DeviceMemoryBase filter_data =
- buffer_allocations.GetDeviceAddress(filter_buffer_);
- se::DeviceMemoryBase output_data =
- buffer_allocations.GetDeviceAddress(output_buffer_);
+ CudnnConvParams params;
+
+ params.input_buf = buffer_allocations.GetDeviceAddress(input_buffer_);
+ params.filter_buf = buffer_allocations.GetDeviceAddress(filter_buffer_);
+ params.output_buf = buffer_allocations.GetDeviceAddress(output_buffer_);
se::DeviceMemoryBase scratch =
buffer_allocations.GetDeviceAddress(scratch_buffer_);
- se::dnn::AlgorithmConfig algorithm_config(
- se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_));
+ TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, &params));
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
- TF_RETURN_IF_ERROR(RunCudnnConvolution(
- convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data,
- filter_data, output_data, scratch, window_, dim_nums_, algorithm_config,
- stream));
+ TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch, stream));
// Figure out which of output/input/filter is the result produced by
// this op, and write the result tuple.
void* result_ptr = [&] {
- switch (convolution_kind_) {
+ switch (params.kind) {
case CudnnConvKind::kForward:
- return output_data.opaque();
+ return params.output_buf.opaque();
case CudnnConvKind::kBackwardInput:
- return input_data.opaque();
+ return params.input_buf.opaque();
case CudnnConvKind::kBackwardFilter:
- return filter_data.opaque();
+ return params.filter_buf.opaque();
}
}();
void* ptrs[] = {result_ptr, scratch.opaque()};
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index d76ca6698d..d7d1f91fba 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_THUNK_H_
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
@@ -23,16 +24,16 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
namespace gpu {
-// This class stores everything that StreamExecutor needs to launch a BNN
+// This class stores everything that StreamExecutor needs to launch a DNN
// convolution. It is generated by IrEmitter.
//
// This is thread-compatible.
@@ -41,26 +42,24 @@ class ConvolutionThunk : public Thunk {
// Constructs a thunk for launching a DNN convolution. When run, it will
// write a tuple (result, scratch_memory) into `tuple_result_buffer`.
//
- // `algorithm` is a cudnn algorithm number. `algorithm == -1` indicates that
- // we should use the default (i.e. baseline) cudnn algorithm.
- //
// Note that "output" here doesn't refer to the output from running this
// thunk, but rather to the "output" of a hypothetical forward convolution
// that corresponds to this input+filter+output triple. That is, the result
// generated by this thunk is "output" for forward convs, "input" for
// backward-input convs, and "filter" for backward-filter convs.
- //
- // Semantics of null hlo_instruction argument are as in Thunk.
- ConvolutionThunk(CudnnConvKind convolution_kind,
- const BufferAllocation::Slice& input_buffer,
- const BufferAllocation::Slice& filter_buffer,
- const BufferAllocation::Slice& output_buffer,
- const BufferAllocation::Slice& tuple_result_buffer,
- const BufferAllocation::Slice& scratch_buffer,
- const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dim_nums, int64 algorithm,
- bool tensor_ops_enabled, const HloInstruction* hlo);
+ ConvolutionThunk(const HloCustomCallInstruction* cudnn_call,
+ BufferAllocation::Slice input_slice,
+ BufferAllocation::Slice filter_slice,
+ BufferAllocation::Slice output_slice,
+ BufferAllocation::Slice scratch_slice,
+ BufferAllocation::Slice tuple_result_slice)
+ : Thunk(Kind::kConvolution, cudnn_call),
+ cudnn_call_(cudnn_call),
+ input_buffer_(std::move(input_slice)),
+ filter_buffer_(std::move(filter_slice)),
+ output_buffer_(std::move(output_slice)),
+ scratch_buffer_(std::move(scratch_slice)),
+ tuple_result_buffer_(std::move(tuple_result_slice)) {}
ConvolutionThunk(const ConvolutionThunk&) = delete;
ConvolutionThunk& operator=(const ConvolutionThunk&) = delete;
@@ -71,35 +70,12 @@ class ConvolutionThunk : public Thunk {
HloExecutionProfiler* profiler) override;
private:
- class ScratchAllocator;
-
- Status Convolve(const se::dnn::BatchDescriptor& input_descriptor,
- se::DeviceMemory<float> input_data,
- const se::dnn::FilterDescriptor& filter_descriptor,
- se::DeviceMemory<float> filter_data,
- const se::dnn::BatchDescriptor& output_descriptor,
- se::DeviceMemory<float> output_data,
- const se::dnn::ConvolutionDescriptor& convolution_descriptor,
- const se::dnn::AlgorithmConfig& algorithm_config,
- se::Stream* stream, ScratchAllocator* scratch_allocator,
- se::dnn::ProfileResult* profile_result);
-
- const CudnnConvKind convolution_kind_;
-
- const BufferAllocation::Slice input_buffer_;
- const BufferAllocation::Slice filter_buffer_;
- const BufferAllocation::Slice output_buffer_;
- const BufferAllocation::Slice tuple_result_buffer_;
- const BufferAllocation::Slice scratch_buffer_;
-
- const Shape input_shape_;
- const Shape filter_shape_;
- const Shape output_shape_;
-
- const Window window_;
- const ConvolutionDimensionNumbers dim_nums_;
- int64 algorithm_;
- bool tensor_ops_enabled_;
+ const HloCustomCallInstruction* cudnn_call_;
+ BufferAllocation::Slice input_buffer_;
+ BufferAllocation::Slice filter_buffer_;
+ BufferAllocation::Slice output_buffer_;
+ BufferAllocation::Slice scratch_buffer_;
+ BufferAllocation::Slice tuple_result_buffer_;
};
} // namespace gpu
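The constructor shrinks because shapes, window, dimension numbers, and the algorithm now travel in CudnnConvParams, which ExecuteOnStream repopulates from the custom call. Which buffer holds the thunk's result still depends on the convolution kind, as the header comment above describes. A hedged, self-contained restatement of that mapping with plain C++ stand-ins, not the XLA types:

    #include <cstddef>

    enum class ConvKind { kForward, kBackwardInput, kBackwardFilter };

    struct ConvParams {
      ConvKind kind;
      void* input_buf;
      void* filter_buf;
      void* output_buf;
    };

    // "Output" names the output of the hypothetical forward convolution, so
    // the result lives in output for forward convs, input for backward-input
    // convs, and filter for backward-filter convs.
    void* ResultPtr(const ConvParams& p) {
      switch (p.kind) {
        case ConvKind::kForward:        return p.output_buf;
        case ConvKind::kBackwardInput:  return p.input_buf;
        case ConvKind::kBackwardFilter: return p.filter_buf;
      }
      return nullptr;  // unreachable for valid kinds
    }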
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
index e09cde9abf..6e2e330edd 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
@@ -54,9 +54,7 @@ namespace gpu {
// BatchNormRewriter.
class CudnnBatchNormRewriter : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override {
- return "cudnn_batchnorm_rewriter";
- }
+ absl::string_view name() const override { return "cudnn_batchnorm_rewriter"; }
StatusOr<bool> Run(HloModule* module) override;
};
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
index 7b172812c3..bc3c6f72f6 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
@@ -17,12 +17,11 @@ limitations under the License.
#include <string>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index 225d2ee3ce..c607aea1a8 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -14,24 +14,26 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
-#include "tensorflow/core/lib/gtl/optional.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
namespace xla {
namespace gpu {
namespace {
+using absl::optional;
using se::DeviceMemoryBase;
using se::dnn::AlgorithmConfig;
using se::dnn::AlgorithmDesc;
-using tensorflow::gtl::optional;
class ScratchAllocator : public se::ScratchAllocator {
public:
@@ -59,8 +61,8 @@ StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
if (byte_size > GetMemoryLimitInBytes(stream)) {
return se::port::Status(
se::port::error::RESOURCE_EXHAUSTED,
- tensorflow::strings::Printf(
- "Allocating %lld bytes exceeds the memory limit of %lld bytes.",
+ absl::StrFormat(
+ "Allocating %d bytes exceeds the memory limit of %d bytes.",
byte_size, GetMemoryLimitInBytes(stream)));
}
@@ -128,14 +130,14 @@ std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
string AlgorithmToString(const AlgorithmDesc& algo) {
if (algo.tensor_ops_enabled()) {
- return tensorflow::strings::StrCat(algo.algo_id(), "+TC");
+ return absl::StrCat(algo.algo_id(), "+TC");
}
- return tensorflow::strings::StrCat(algo.algo_id());
+ return absl::StrCat(algo.algo_id());
}
string NumBytesToString(int64 bytes) {
- return tensorflow::strings::StrCat(
- tensorflow::strings::HumanReadableNumBytes(bytes), " (", bytes, "B)");
+ return absl::StrCat(tensorflow::strings::HumanReadableNumBytes(bytes), " (",
+ bytes, "B)");
}
// Acquires a process-global lock on the device pointed to by the given
@@ -175,9 +177,14 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
// caching would speed up compilation a lot.
StatusOr<std::tuple<int64, bool, int64>>
CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dnums, HloInstruction* instr) {
+ const HloCustomCallInstruction* instr) {
+ CudnnConvParams params;
+ TF_RETURN_IF_ERROR(PopulateCudnnConvParams(instr, &params));
+
+ const Shape& input_shape = *params.input_shape;
+ const Shape& filter_shape = *params.filter_shape;
+ const Shape& output_shape = *params.output_shape;
+
CHECK_EQ(input_shape.element_type(), filter_shape.element_type());
CHECK_EQ(input_shape.element_type(), output_shape.element_type());
// TODO(timshen): for now only check fp16. It can be expanded to other types,
@@ -191,6 +198,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
// concurrently and then run them sequentially.
tensorflow::mutex_lock lock = LockGpu(stream_exec_);
+ // Make sure any previous activity on this executor is done. We don't want to
+ // interfere with programs that are still running on the GPU.
+ if (!stream_exec_->SynchronizeAllActivity()) {
+ return InternalError("Failed to synchronize GPU for autotuning.");
+ }
+
// Create a stream for us to do our work on.
se::Stream stream{stream_exec_};
stream.Init();
@@ -203,9 +216,8 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
if (allocator_ != nullptr) {
allocator = allocator_;
} else {
- se_allocator.emplace(
- stream_exec_->platform(),
- tensorflow::gtl::ArraySlice<se::StreamExecutor*>({stream_exec_}));
+ se_allocator.emplace(stream_exec_->platform(),
+ absl::Span<se::StreamExecutor* const>({stream_exec_}));
allocator = &*se_allocator;
}
@@ -213,13 +225,13 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
// use a ScratchAllocator for this instead of calling allocator_ directly so
// that our allocations don't leak.
ScratchAllocator input_output_allocator(device_ordinal, allocator);
- TF_ASSIGN_OR_RETURN(DeviceMemoryBase input_buf,
+ TF_ASSIGN_OR_RETURN(params.input_buf,
input_output_allocator.AllocateBytes(
&stream, ShapeUtil::ByteSizeOf(input_shape)));
- TF_ASSIGN_OR_RETURN(DeviceMemoryBase filter_buf,
+ TF_ASSIGN_OR_RETURN(params.filter_buf,
input_output_allocator.AllocateBytes(
&stream, ShapeUtil::ByteSizeOf(filter_shape)));
- TF_ASSIGN_OR_RETURN(DeviceMemoryBase output_buf,
+ TF_ASSIGN_OR_RETURN(params.output_buf,
input_output_allocator.AllocateBytes(
&stream, ShapeUtil::ByteSizeOf(output_shape)));
@@ -233,8 +245,8 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
CHECK_EQ(0, left_over_bytes % 2);
constexpr float kBroadcastedConstant = 0.1f;
- Eigen::half halfs[2] = {Eigen::half(kBroadcastedConstant),
- Eigen::half(kBroadcastedConstant)};
+ static const Eigen::half halfs[2] = {Eigen::half(kBroadcastedConstant),
+ Eigen::half(kBroadcastedConstant)};
uint32 bits;
static_assert(sizeof(bits) == sizeof(halfs), "");
memcpy(&bits, halfs, sizeof(bits));
@@ -246,33 +258,32 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
static_cast<char*>(buffer.opaque()) + aligned_size, left_over_bytes);
stream.ThenMemcpy(&left_over, halfs, left_over_bytes);
};
- initialize_f16(input_buf);
- initialize_f16(filter_buf);
- initialize_f16(output_buf);
+ initialize_f16(params.input_buf);
+ initialize_f16(params.filter_buf);
+ initialize_f16(params.output_buf);
} else {
// Although we don't have evidence this matters, zero out the buffers before
// autotuning. It's conceivable that using uninitialized memory as the
// inputs might affect performance if e.g. the inputs contain denormals, and
// this is easy enough.
- stream.ThenMemZero(&input_buf, input_buf.size())
- .ThenMemZero(&filter_buf, filter_buf.size())
- .ThenMemZero(&output_buf, output_buf.size());
+ stream.ThenMemZero(&params.input_buf, params.input_buf.size())
+ .ThenMemZero(&params.filter_buf, params.filter_buf.size())
+ .ThenMemZero(&params.output_buf, params.output_buf.size());
}
- TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
DeviceMemoryBase* result_buf = [&] {
- switch (kind) {
+ switch (params.kind) {
case CudnnConvKind::kBackwardFilter:
- return &filter_buf;
+ return &params.filter_buf;
case CudnnConvKind::kBackwardInput:
- return &input_buf;
+ return &params.input_buf;
case CudnnConvKind::kForward:
- return &output_buf;
+ return &params.output_buf;
}
}();
const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo(
- input_shape, output_shape, dnums, stream_exec_);
+ input_shape, output_shape, *params.dnums, stream_exec_);
se::dnn::ProfileResult best_result;
int64 best_result_bytes_used = 0;
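The fp16 path above packs two copies of the half-precision constant into one 32-bit word so the whole buffer can be filled with a 32-bit memset. A runnable sketch of just the bit-packing step, with a raw uint16_t standing in for Eigen::half:

    #include <cstdint>
    #include <cstring>

    // Two identical 16-bit patterns fill one 32-bit word; broadcasting that
    // word across a buffer writes the constant into every half-sized slot.
    uint32_t BroadcastHalfBits(uint16_t half_bits) {
      const uint16_t halfs[2] = {half_bits, half_bits};
      uint32_t bits;
      static_assert(sizeof(bits) == sizeof(halfs), "two halves fill one word");
      std::memcpy(&bits, halfs, sizeof(bits));
      return bits;
    }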
@@ -282,20 +293,23 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
// this algorithm considered correct, though.
optional<AlgorithmDesc> first_algorithm;
for (const AlgorithmDesc& alg :
- GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) {
+ GetAlgorithms(params.kind, use_winograd_nonfused, stream_exec_)) {
ScratchAllocator scratch_allocator(device_ordinal, allocator);
se::dnn::ProfileResult profile_result;
VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
<< instr->ToString();
- bool launch_ok =
- RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- input_buf, filter_buf, output_buf,
- &scratch_allocator, window, dnums,
- AlgorithmConfig(alg), &stream, &profile_result)
- .ok();
+ params.algorithm = AlgorithmConfig(alg);
+ bool launch_ok = RunCudnnConvolution(params, &scratch_allocator, &stream,
+ &profile_result)
+ .ok();
if (launch_ok && profile_result.is_valid()) {
+ const bool crash_on_checking_failure =
+ instr->GetModule()
+ ->config()
+ .debug_options()
+ .xla_gpu_crash_on_verification_failures();
if (comparator.has_value()) {
StatusOr<bool> result = comparator->CompareEqual(
se::DeviceMemory<Eigen::half>(*result_buf));
@@ -304,6 +318,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
<< AlgorithmToString(*first_algorithm) << " against "
<< AlgorithmToString(alg) << " for " << instr->ToString()
<< ": " << result.status();
+ CHECK(!crash_on_checking_failure);
} else if (!result.ValueOrDie()) {
LOG(ERROR) << "Results mismatch between different convolution "
"algorithms. This is likely a bug in convolution, or "
@@ -311,6 +326,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
<< instr->ToString() << " for "
<< AlgorithmToString(*first_algorithm) << " vs "
<< AlgorithmToString(alg);
+ CHECK(!crash_on_checking_failure);
}
} else if (cross_check_enabled) {
auto comp = F16BufferComparator::Create(
@@ -322,6 +338,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
} else {
LOG(ERROR) << "Fail to initialize buffer comparator: "
<< comp.status() << ", instruction: " << instr->ToString();
+ CHECK(!crash_on_checking_failure);
}
}
int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
@@ -353,38 +370,15 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
return InternalError(
"All algorithms tried for convolution %s failed. Falling back to "
"default algorithm.",
- instr->ToString().c_str());
+ instr->ToString());
}
StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
HloInstruction* instr) {
CHECK(IsCustomCallToDnnConvolution(*instr));
- const auto& call_target = instr->custom_call_target();
- const auto& lhs_shape = instr->operand(0)->shape();
- const auto& rhs_shape = instr->operand(1)->shape();
- const auto& conv_result_shape = instr->shape().tuple_shapes(0);
- StatusOr<std::tuple<int64, bool, int64>> alg_scratch_and_tc;
- if (call_target == kCudnnConvForwardCallTarget) {
- alg_scratch_and_tc =
- PickBestAlgorithm(CudnnConvKind::kForward, /*input_shape=*/lhs_shape,
- /*filter_shape=*/rhs_shape,
- /*output_shape=*/conv_result_shape, instr->window(),
- instr->convolution_dimension_numbers(), instr);
- } else if (call_target == kCudnnConvBackwardInputCallTarget) {
- alg_scratch_and_tc = PickBestAlgorithm(
- CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape,
- /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(),
- instr->convolution_dimension_numbers(), instr);
- } else if (call_target == kCudnnConvBackwardFilterCallTarget) {
- alg_scratch_and_tc = PickBestAlgorithm(
- CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape,
- /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape,
- instr->window(), instr->convolution_dimension_numbers(), instr);
- } else {
- LOG(FATAL) << "Unknown custom call target for cudnn conv: "
- << instr->ToString();
- }
+ StatusOr<std::tuple<int64, bool, int64>> alg_scratch_and_tc =
+ PickBestAlgorithm(Cast<HloCustomCallInstruction>(instr));
if (!alg_scratch_and_tc.ok()) {
LOG(ERROR) << alg_scratch_and_tc.status();
@@ -414,14 +408,9 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
backend_config.set_algorithm(algorithm);
backend_config.set_tensor_ops_enabled(tensor_ops_enabled);
- HloInstruction* new_call =
- computation->AddInstruction(HloInstruction::CreateCustomCall(
- new_call_shape,
- {instr->mutable_operand(0), instr->mutable_operand(1)},
- instr->custom_call_target()));
- new_call->set_window(instr->window());
- new_call->set_convolution_dimension_numbers(
- instr->convolution_dimension_numbers());
+ HloInstruction* new_call = computation->AddInstruction(
+ instr->CloneWithNewOperands(new_call_shape, {instr->mutable_operand(0),
+ instr->mutable_operand(1)}));
TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config));
// Repackage new_call so it has the same shape as the original call, namely
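Net effect of the changes in this file: the picker synchronizes the device, allocates and initializes the buffers, runs each candidate algorithm through the shared params struct, and keeps the fastest one that launched and profiled cleanly. A heavily simplified, self-contained sketch of that selection loop; none of these names are the real XLA or StreamExecutor API:

    #include <cstdint>
    #include <optional>
    #include <vector>

    struct Algo { int64_t id; bool tensor_ops; };

    struct ProfileResult {
      bool valid = false;  // false if the launch failed or was not profiled
      double elapsed_ms = 0;
    };

    // Keep the fastest candidate that produced a valid profile.
    std::optional<Algo> PickBest(const std::vector<Algo>& candidates,
                                 ProfileResult (*profile)(const Algo&)) {
      std::optional<Algo> best;
      double best_ms = 1e300;
      for (const Algo& alg : candidates) {
        const ProfileResult r = profile(alg);
        if (r.valid && r.elapsed_ms < best_ms) {
          best_ms = r.elapsed_ms;
          best = alg;
        }
      }
      return best;
    }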
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
index 8b7749628a..f79b113f8f 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
@@ -16,12 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
@@ -39,7 +40,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
Compiler* compiler)
: stream_exec_(stream_exec), allocator_(allocator), compiler_(compiler) {}
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "cudnn-convolution-algorithm-picker";
}
@@ -49,9 +50,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
StatusOr<bool> RunOnComputation(HloComputation* computation);
StatusOr<bool> RunOnInstruction(HloInstruction* instr);
StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dnums, HloInstruction* instr);
+ const HloCustomCallInstruction* instr);
se::StreamExecutor* stream_exec_; // never null
DeviceMemoryAllocator* allocator_; // may be null
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index 905b5ee876..3d1266355b 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"
+#include <cstdlib>
#include <numeric>
#include <vector>
@@ -59,6 +60,9 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
HloInstruction* conv) {
const auto no_match_result =
std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
+ if (conv->feature_group_count() > 1) {
+ return no_match_result;
+ }
// Step 1: match the instruction pattern without considering the paddings and
// dimension numbers just yet. We may need some generic pattern matcher
// similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h
@@ -213,42 +217,55 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
// Try to match a backward input pattern that contains "conv".
// Precondition: "conv" is a kConvolution.
-std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
- HloInstruction* conv) {
+std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*>
+MatchBackwardInput(HloInstruction* conv) {
const auto no_match_result =
- std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
+ std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);
+
+ // TODO(b/31709653): cuDNN also supports grouped convolutions for the
+ // backward input convolution, but at least as of version 7.1.4 it is
+ // slower. This needs to be re-evaluated for future cuDNN versions.
+ // Note that the necessary code already exists below; to enable it,
+ // remove the following early return.
+ if (conv->feature_group_count() > 1) {
+ return no_match_result;
+ }
// Match instruction pattern.
CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
HloInstruction* reverse_filter = conv->mutable_operand(1);
-
- // Match the reverse of the filter.
ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
- const auto& kernel_spatial_dims = dnums.kernel_spatial_dimensions();
- if (reverse_filter->opcode() == HloOpcode::kReverse) {
- if (kernel_spatial_dims.size() != reverse_filter->dimensions().size() ||
- !std::is_permutation(kernel_spatial_dims.begin(),
- kernel_spatial_dims.end(),
- reverse_filter->dimensions().begin())) {
- VLOG(1)
- << "Backward input convolution should reverse all kernel dimensions.";
- return no_match_result;
- }
- } else {
- // Possibly 1x1 filter.
- for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) {
- if (conv->window().dimensions(i).size() != 1) {
- VLOG(1) << "The reverse filter is neither a kReverse nor a 1x1 filter: "
- << reverse_filter->ToString();
- return no_match_result;
- }
- }
- if (!window_util::HasBaseDilation(conv->window())) {
- VLOG(1) << conv->ToString()
- << " is a regular forward convolution. No need "
- "to fold it to a backward input convolution.";
- return no_match_result;
- }
+
+ // We pattern-match to a backwards input conv if:
+ //
+ // - all spatial dims of the filter are reversed
+ //
+ // OR
+ //
+ // - filter is 1x1 or a constant AND
+ // - conv has base dilation (otherwise this is just a regular forward conv).
+ //
+ // The final criterion above is just for canonicalization; cudnn seems to run
+ // just as fast if we canonicalize 1x1/constant filters without base dilation
+ // to forward or backward convs. We canonicalize to forward conv because (a)
+ // it's more natural (constant filters usually show up when doing inference,
+ // and having backwards convolutions in inference graphs would be weird), and
+ // (b) cudnn has special fusions for forward conv plus bias and activation,
+ // and we want to pattern-match to that after running this pass.
+ bool is_reversed_filter =
+ reverse_filter->opcode() == HloOpcode::kReverse &&
+ absl::c_is_permutation(dnums.kernel_spatial_dimensions(),
+ reverse_filter->dimensions());
+ bool is_1x1_filter =
+ absl::c_all_of(conv->window().dimensions(),
+ [](const WindowDimension& d) { return d.size() == 1; });
+ if (!is_reversed_filter &&
+ !(window_util::HasBaseDilation(conv->window()) &&
+ (reverse_filter->IsConstant() || is_1x1_filter))) {
+ VLOG(1) << "Can't match to backwards convolution. Either filter is not "
+ "kReverse, or it's not a base-dialted conv with a 1x1 or "
+ "constant filter.";
+ return no_match_result;
}
// Match padding and dilation of the forward convolution.
@@ -373,23 +390,64 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
}
}
- // Fuse the matched HLOs into a backward convolution instruction.
- //
- // If the reverse is omitted (for 1x1 filters) in the original pattern, we add
- // it back in the fusion instruction so that later passes (such as
- // PadInsertion) can handle such fusion instructions easily.
- if (reverse_filter->opcode() != HloOpcode::kReverse) {
- reverse_filter = reverse_filter->parent()->AddInstruction(
- HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
- AsInt64Slice(kernel_spatial_dims)));
- TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter));
- }
+ // OK, it's a match! Switch the input feature dimension with the output
+ // feature dimension. This is the way cuDNN expects it to be.
dnums.set_kernel_input_feature_dimension(
conv->convolution_dimension_numbers().kernel_output_feature_dimension());
dnums.set_kernel_output_feature_dimension(
conv->convolution_dimension_numbers().kernel_input_feature_dimension());
- return std::make_tuple(true, new_window, dnums);
+ // If we matched against a constant, we need to add a reverse op that can be
+ // subsumed by the cuDNN call. algebraic-simplifier will later remove any
+ // unnecessary reverses.
+ if (reverse_filter->opcode() != HloOpcode::kReverse &&
+ reverse_filter->IsConstant()) {
+ // Create a double-reverse, which is a nop.
+ HloComputation* c = conv->parent();
+ reverse_filter = c->AddInstruction(HloInstruction::CreateReverse(
+ reverse_filter->shape(), reverse_filter,
+ AsInt64Slice(dnums.kernel_spatial_dimensions())));
+ reverse_filter = c->AddInstruction(HloInstruction::CreateReverse(
+ reverse_filter->shape(), reverse_filter,
+ AsInt64Slice(dnums.kernel_spatial_dimensions())));
+ TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter));
+ }
+
+ // Calculate the 'rhs' that goes into the backward input convolution.
+ HloInstruction* rhs = reverse_filter;
+ // One reverse is subsumed by the cuDNN call.
+ if (rhs->opcode() == HloOpcode::kReverse) {
+ rhs = rhs->mutable_operand(0);
+ }
+ if (conv->feature_group_count() == 1) {
+ return std::make_tuple(true, new_window, dnums, rhs);
+ }
+
+ // Handle grouped convolutions. Because we swapped the input feature dimension
+ // with the output feature dimension, we need to also reshape the kernel so
+ // that the 'feature_group_count' parameter still makes sense. The
+ // 'feature_group_count' parameter essentially specifies how often the
+ // 'kernel_input_feature_dimension' is repeated. So when we swap these
+ // dimensions, we need to divide the new 'kernel_input_feature_dimension' by
+ // 'feature_group_count' and multiply the new
+ // 'kernel_output_feature_dimension' by 'feature_group_count'.
+ Shape new_shape = rhs->shape();
+ int64 input_feature_dimension = dnums.kernel_input_feature_dimension();
+ int64 output_feature_dimension = dnums.kernel_output_feature_dimension();
+
+ // In the backward convolution case, the spatial dimensions become the
+ // feature dimensions, and we are guaranteed that the spatial dimensions are
+ // adjacent.
+ CHECK_EQ(std::abs(input_feature_dimension - output_feature_dimension), 1LL);
+ int64 input_features = new_shape.dimensions(input_feature_dimension);
+ int64 output_features = new_shape.dimensions(output_feature_dimension);
+ new_shape.set_dimensions(input_feature_dimension,
+ input_features / conv->feature_group_count());
+ new_shape.set_dimensions(output_feature_dimension,
+ output_features * conv->feature_group_count());
+ HloComputation* c = conv->parent();
+ rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs));
+ return std::make_tuple(true, new_window, dnums, rhs);
}
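A worked instance of the grouped-convolution reshape described above, as a standalone sketch: with feature_group_count = 2 and the swapped dimensions holding 8 input and 4 output features, the kernel is reshaped to 8 / 2 = 4 input and 4 * 2 = 8 output features. dims here is an illustrative stand-in for Shape::dimensions, not the XLA API:

    #include <cstdint>

    // After swapping the kernel input/output feature dimensions, rescale both
    // so feature_group_count still repeats the input feature dimension the
    // intended number of times.
    void RescaleGroupedKernelFeatures(int64_t* dims, int in_dim, int out_dim,
                                      int64_t feature_group_count) {
      dims[in_dim] /= feature_group_count;   // e.g. 8 / 2 -> 4
      dims[out_dim] *= feature_group_count;  // e.g. 4 * 2 -> 8
    }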
// Tries to rewrite a single convolution into a call to cudnn.
@@ -400,30 +458,28 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
bool match;
Window window;
ConvolutionDimensionNumbers dnums;
+ HloInstruction* rhs;
std::tie(match, window, dnums) = MatchBackwardFilter(conv);
if (match) {
return CreateCudnnConvBackwardFilter(
conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1),
- window, dnums);
+ window, dnums, conv->feature_group_count());
}
- std::tie(match, window, dnums) = MatchBackwardInput(conv);
+ std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv);
if (match) {
- // Backward input conv subsumes the conv plus the reverse in operand 1.
- HloInstruction* reverse = conv->mutable_operand(1);
- CHECK_EQ(reverse->opcode(), HloOpcode::kReverse);
- HloInstruction* rhs = reverse->mutable_operand(0);
-
- return CreateCudnnConvBackwardInput(
- conv->shape(), conv->mutable_operand(0), rhs, window, dnums);
+ return CreateCudnnConvBackwardInput(conv->shape(),
+ conv->mutable_operand(0), rhs, window,
+ dnums, conv->feature_group_count());
}
// If all else fails, try a forward convolution.
if (CanImplementAsCudnnForwardConv(conv)) {
return CreateCudnnConvForward(conv->shape(), conv->mutable_operand(0),
conv->mutable_operand(1), conv->window(),
- conv->convolution_dimension_numbers());
+ conv->convolution_dimension_numbers(),
+ conv->feature_group_count());
}
return nullptr;
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
index 0c0578d888..fbe7e98494 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
@@ -26,7 +26,7 @@ namespace gpu {
// backwards-input convolutions into CustomCall HLOs that call into cuDNN.
class CudnnConvolutionRewriter : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "cudnn-convolution-rewriter";
}
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
index 65588b6aaf..d237f8930b 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
@@ -32,10 +32,13 @@ namespace gpu {
namespace {
namespace op = xla::testing::opcode_matchers;
+using ::testing::_;
-class CudnnConvolutionRewriterTest : public HloTestBase {
+class CudnnConvolutionRewriterTest : public HloVerifiedTestBase {
public:
- CudnnConvolutionRewriterTest() {
+ CudnnConvolutionRewriterTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false) {
for (int i = 0; i < 2; ++i) {
WindowDimension* window_dim = default_conv_window_.add_dimensions();
window_dim->set_size(1);
@@ -104,17 +107,17 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) {
conv_window.mutable_dimensions(1)->set_size(2);
conv_window.mutable_dimensions(1)->set_window_dilation(2);
builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeInference::InferConvolveShape(activations->shape(),
- gradients->shape(), conv_window,
- tf_default_dnums_for_backward_filter_)
+ ShapeInference::InferConvolveShape(
+ activations->shape(), gradients->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_filter_)
.ConsumeValueOrDie(),
- activations, gradients, conv_window,
- tf_default_dnums_for_backward_filter_));
+ activations, gradients, /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
@@ -132,17 +135,17 @@ TEST_F(CudnnConvolutionRewriterTest,
Window conv_window = default_conv_window_;
conv_window.mutable_dimensions(1)->set_size(3);
builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeInference::InferConvolveShape(activations->shape(),
- gradients->shape(), conv_window,
- tf_default_dnums_for_backward_filter_)
+ ShapeInference::InferConvolveShape(
+ activations->shape(), gradients->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_filter_)
.ConsumeValueOrDie(),
- activations, gradients, conv_window,
- tf_default_dnums_for_backward_filter_));
+ activations, gradients, /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
@@ -167,12 +170,13 @@ TEST_F(CudnnConvolutionRewriterTest,
}
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients,
- conv_window, tf_default_dnums_for_backward_filter_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
@@ -197,12 +201,13 @@ TEST_F(CudnnConvolutionRewriterTest,
}
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients,
- conv_window, tf_default_dnums_for_backward_filter_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
@@ -225,12 +230,13 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) {
}
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients,
- conv_window, tf_default_dnums_for_backward_filter_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
@@ -269,18 +275,19 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) {
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output,
- /*rhs=*/reverse_kernel, conv_window, conv_dnums));
+ /*rhs=*/reverse_kernel, /*feature_group_count=*/1, conv_window,
+ conv_dnums, DefaultPrecisionConfig(2)));
// Verify the convolution's shape is consistent with ShapeInference.
CHECK(ShapeUtil::Compatible(
- conv->shape(),
- ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(), conv_window, conv_dnums)
- .ValueOrDie()));
+ conv->shape(), ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(),
+ /*feature_group_count=*/1, conv_window, conv_dnums)
+ .ValueOrDie()));
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
ASSERT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
@@ -316,16 +323,16 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) {
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeInference::InferConvolveShape(output->shape(), kernel->shape(),
- conv_window,
+ /*feature_group_count=*/1, conv_window,
tf_default_dnums_for_backward_input_)
.ConsumeValueOrDie(),
- /*lhs=*/output, /*rhs=*/kernel, conv_window,
- tf_default_dnums_for_backward_input_));
+ /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
@@ -347,17 +354,18 @@ TEST_F(CudnnConvolutionRewriterTest,
1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));
builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeInference::InferConvolveShape(output->shape(), kernel->shape(),
- default_conv_window_,
- tf_default_dnums_for_backward_input_)
+ ShapeInference::InferConvolveShape(
+ output->shape(), kernel->shape(), /*feature_group_count=*/1,
+ default_conv_window_, tf_default_dnums_for_backward_input_)
.ConsumeValueOrDie(),
- /*lhs=*/output, /*rhs=*/kernel, default_conv_window_,
- tf_default_dnums_for_backward_input_));
+ /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
+ default_conv_window_, tf_default_dnums_for_backward_input_,
+ DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(
entry_computation->root_instruction(),
op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
@@ -399,18 +407,20 @@ TEST_F(CudnnConvolutionRewriterTest,
}
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
- conv_window, tf_default_dnums_for_backward_input_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
// Verify the convolution's shape is consistent with ShapeInference.
CHECK(ShapeUtil::Compatible(
- conv->shape(), ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(), conv_window,
- tf_default_dnums_for_backward_input_)
- .ValueOrDie()));
+ conv->shape(),
+ ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_input_)
+ .ValueOrDie()));
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
ASSERT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
@@ -446,18 +456,20 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
}
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
- conv_window, tf_default_dnums_for_backward_input_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
// Verify the convolution's shape is consistent with ShapeInference.
CHECK(ShapeUtil::Compatible(
- conv->shape(), ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(), conv_window,
- tf_default_dnums_for_backward_input_)
- .ValueOrDie()));
+ conv->shape(),
+ ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_input_)
+ .ValueOrDie()));
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(
entry_computation->root_instruction(),
op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
@@ -499,18 +511,20 @@ TEST_F(CudnnConvolutionRewriterTest,
forward_conv_col_dim->set_base_dilation(2);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel,
- conv_window, tf_default_dnums_for_backward_input_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
// Verify the convolution's shape is consistent with ShapeInference.
CHECK(ShapeUtil::Compatible(
- conv->shape(), ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(), conv_window,
- tf_default_dnums_for_backward_input_)
- .ValueOrDie()));
+ conv->shape(),
+ ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_input_)
+ .ValueOrDie()));
auto module = CreateNewModule();
const HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
ASSERT_THAT(entry_computation->root_instruction(),
op::GetTupleElement(
op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
@@ -551,23 +565,51 @@ TEST_F(CudnnConvolutionRewriterTest,
forward_conv_col_dim->set_padding_high(2);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel,
- conv_window, tf_default_dnums_for_backward_input_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
// Verify the convolution's shape is consistent with ShapeInference.
CHECK(ShapeUtil::Compatible(
- conv->shape(), ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(), conv_window,
- tf_default_dnums_for_backward_input_)
- .ValueOrDie()));
+ conv->shape(),
+ ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_input_)
+ .ValueOrDie()));
auto module = CreateNewModule();
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(RunPass(module.get()));
+ EXPECT_TRUE(RunPass(module));
EXPECT_THAT(
entry_computation->root_instruction(),
op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
}
+// Check that we will materialize a reversed version of a constant in order to
+// pattern-match a backwards input convolution.
+TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveConstantFilter) {
+ Array4D<float> constant_arr(4, 4, 2, 2);
+ constant_arr.FillIota(0);
+ string constant_str =
+ LiteralUtil::CreateR4FromArray4D(constant_arr).ToString();
+ ParseAndVerifyModule(absl::StrFormat(R"(
+ HloModule test
+
+ ENTRY entry_computation {
+ param0 = f32[128,2,16,16]{3,2,1,0} parameter(0)
+ constant = f32[4,4,2,2]{3,2,1,0} constant(%s)
+ ROOT convolution = f32[128,2,32,32]{3,2,1,0} convolution(param0, constant),
+ window={size=4x4 pad=2_2x2_2 lhs_dilate=2x2},
+ dim_labels=bf01_01oi->bf01, feature_group_count=1
+ })",
+ constant_str));
+ EXPECT_TRUE(RunPass(&module()));
+ EXPECT_THAT(
+ module().entry_computation()->root_instruction(),
+ op::GetTupleElement(op::CustomCall(kCudnnConvBackwardInputCallTarget, _,
+ op::Reverse(op::Constant())),
+ 0));
+}
+
} // anonymous namespace
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 7b0d9e53d6..2a86ac265e 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -56,7 +57,7 @@ class ScratchBufAllocator : public se::ScratchAllocator {
"Can't allocate twice from a ScratchBufAllocator.");
}
if (byte_size > scratch_.size()) {
- return se::port::InternalError(tensorflow::strings::StrCat(
+ return se::port::InternalError(absl::StrCat(
"Can't allocate ", byte_size,
" bytes from a ScratchBufAllocator of size ", scratch_.size()));
}
@@ -71,13 +72,22 @@ class ScratchBufAllocator : public se::ScratchAllocator {
};
template <typename T>
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, DeviceMemory<T> input_buf,
- DeviceMemory<T> filter_buf, DeviceMemory<T> output_buf,
- se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm,
- Stream* stream, ProfileResult* profile_result /*= nullptr*/) {
+Status RunCudnnConvolutionImpl(CudnnConvParams params,
+ se::ScratchAllocator* scratch_allocator,
+ se::Stream* stream,
+ se::dnn::ProfileResult* profile_result) {
+ CudnnConvKind kind = params.kind;
+ const Shape& input_shape = *params.input_shape;
+ const Shape& filter_shape = *params.filter_shape;
+ const Shape& output_shape = *params.output_shape;
+ DeviceMemory<T> input_buf(params.input_buf);
+ DeviceMemory<T> filter_buf(params.filter_buf);
+ DeviceMemory<T> output_buf(params.output_buf);
+ const Window& window = *params.window;
+ const ConvolutionDimensionNumbers& dnums = *params.dnums;
+ int64 feature_group_count = params.feature_group_count;
+ AlgorithmConfig algorithm = params.algorithm;
+
VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id();
VLOG(3) << "tensor_ops_enabled: "
<< algorithm.algorithm().tensor_ops_enabled();
@@ -143,6 +153,7 @@ Status RunCudnnConvolution(
}
ConvolutionDescriptor convolution_descriptor(effective_num_dimensions);
+ convolution_descriptor.set_group_count(feature_group_count);
for (int dim = 0; dim < num_dimensions; ++dim) {
convolution_descriptor
.set_zero_padding(
@@ -196,8 +207,8 @@ Status RunCudnnConvolution(
if (!stream->ok()) {
return InternalError(
- "Unable to launch convolution with type %s and algorithm (%lld, %lld)",
- CudnnConvKindToString(kind).c_str(), algorithm.algorithm().algo_id(),
+ "Unable to launch convolution with type %s and algorithm (%d, %d)",
+ CudnnConvKindToString(kind), algorithm.algorithm().algo_id(),
algorithm.algorithm_no_scratch().algo_id());
}
return Status::OK();
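Aside: the formatting changes in this hunk (%lld -> %d, dropped c_str()) suggest that InternalError now forwards to absl::StrFormat-style formatting, which deduces integer widths and accepts std::string for %s. A minimal standalone sketch of the same idiom (hypothetical function, not from the patch):

    #include <cstdint>
    #include <string>
    #include "absl/strings/str_format.h"

    std::string FormatLaunchError(const std::string& kind, int64_t algo,
                                  int64_t algo_no_scratch) {
      // %d deduces the integer width (works for int64), and %s takes a
      // std::string directly, so no "%lld" or c_str() is needed.
      return absl::StrFormat(
          "Unable to launch convolution with type %s and algorithm (%d, %d)",
          kind, algo, algo_no_scratch);
    }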
@@ -216,54 +227,31 @@ string CudnnConvKindToString(CudnnConvKind kind) {
}
}
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, se::DeviceMemoryBase input_buf,
- se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
- se::DeviceMemoryBase scratch_buf, const Window& window,
- const ConvolutionDimensionNumbers& dnums,
- se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
- se::dnn::ProfileResult* profile_result) {
+Status RunCudnnConvolution(CudnnConvParams params,
+ se::DeviceMemoryBase scratch_buf, se::Stream* stream,
+ se::dnn::ProfileResult* profile_result) {
ScratchBufAllocator scratch_allocator(scratch_buf);
- return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- input_buf, filter_buf, output_buf,
- &scratch_allocator, window, dnums, algorithm,
- stream, profile_result);
+ return RunCudnnConvolution(params, &scratch_allocator, stream,
+ profile_result);
}
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, se::DeviceMemoryBase input_buf,
- se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
- se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums,
- se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
- se::dnn::ProfileResult* profile_result) {
- PrimitiveType output_primitive_type = output_shape.element_type();
+Status RunCudnnConvolution(CudnnConvParams params,
+ se::ScratchAllocator* scratch_allocator,
+ se::Stream* stream,
+ se::dnn::ProfileResult* profile_result) {
+ PrimitiveType output_primitive_type = params.output_shape->element_type();
switch (output_primitive_type) {
case F16:
- return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<Eigen::half>(input_buf),
- se::DeviceMemory<Eigen::half>(filter_buf),
- se::DeviceMemory<Eigen::half>(output_buf),
- scratch_allocator, window, dnums, algorithm,
- stream, profile_result);
+ return RunCudnnConvolutionImpl<Eigen::half>(params, scratch_allocator,
+ stream, profile_result);
case F32:
- return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<float>(input_buf),
- se::DeviceMemory<float>(filter_buf),
- se::DeviceMemory<float>(output_buf),
- scratch_allocator, window, dnums, algorithm,
- stream, profile_result);
+ return RunCudnnConvolutionImpl<float>(params, scratch_allocator, stream,
+ profile_result);
case F64:
- return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<double>(input_buf),
- se::DeviceMemory<double>(filter_buf),
- se::DeviceMemory<double>(output_buf),
- scratch_allocator, window, dnums, algorithm,
- stream, profile_result);
+ return RunCudnnConvolutionImpl<double>(params, scratch_allocator, stream,
+ profile_result);
default:
- LOG(FATAL) << ShapeUtil::HumanString(output_shape);
+ LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape);
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
index 944e4ac686..381aa37a1b 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
@@ -47,6 +47,20 @@ enum class CudnnConvKind {
kBackwardFilter, // input + output => filter
};
+struct CudnnConvParams {
+ CudnnConvKind kind;
+ const Shape* input_shape;
+ const Shape* filter_shape;
+ const Shape* output_shape;
+ se::DeviceMemoryBase input_buf;
+ se::DeviceMemoryBase filter_buf;
+ se::DeviceMemoryBase output_buf;
+ const Window* window;
+ const ConvolutionDimensionNumbers* dnums;
+ int64 feature_group_count;
+ se::dnn::AlgorithmConfig algorithm;
+};
+
// Converts a CudnnConvKind value to a string.
string CudnnConvKindToString(CudnnConvKind kind);
@@ -55,10 +69,9 @@ string CudnnConvKindToString(CudnnConvKind kind);
// Note that depending on the value of CudnnConvKind, the result of this call
// may be written into input_buf, filter_buf, or output_buf!
//
-// At the moment we only support cudnn convolutions over float and half, and
-// convolution with half data type is implemented with cudnn PSEUDO_HALF
-// configuration, that is, the input values are half and the internal
-// computation type is float.
+// At the moment, convolution with the half data type is implemented with the
+// cudnn PSEUDO_HALF configuration, that is, the input values are half and the
+// internal computation type is float.
//
// We provide one overload which takes a scratch buffer, and another which takes
// an allocator which is responsible for allocating the scratch space. In
@@ -70,23 +83,14 @@ string CudnnConvKindToString(CudnnConvKind kind);
// allocator and take note of how much memory is used. The next time you call
// the same conv, you can provide an explicitly preallocated scratch buffer of
// that size, if you like.
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, se::DeviceMemoryBase input_buf,
- se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
- se::DeviceMemoryBase scratch_buf, const Window& window,
- const ConvolutionDimensionNumbers& dnums,
- se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
- se::dnn::ProfileResult* profile_result = nullptr);
+Status RunCudnnConvolution(CudnnConvParams params,
+ se::DeviceMemoryBase scratch_buf, se::Stream* stream,
+ se::dnn::ProfileResult* profile_result = nullptr);
-Status RunCudnnConvolution(
- CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, se::DeviceMemoryBase input_buf,
- se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
- se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums,
- se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
- se::dnn::ProfileResult* profile_result = nullptr);
+Status RunCudnnConvolution(CudnnConvParams params,
+ se::ScratchAllocator* scratch_allocator,
+ se::Stream* stream,
+ se::dnn::ProfileResult* profile_result = nullptr);
} // namespace gpu
} // namespace xla
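Taken together, the new CudnnConvParams struct and the two RunCudnnConvolution overloads replace the long positional parameter lists. A condensed caller sketch (hypothetical helper with all shapes and buffers supplied by the caller; illustrative only, not from the patch):

    // Demonstrates the new calling convention: populate the params struct
    // once, then use either overload. The first call lets an allocator size
    // the scratch space (e.g. during autotuning); the second reuses a
    // preallocated buffer of the observed size.
    Status RunConvTwice(const Shape& in_shape, const Shape& filter_shape,
                        const Shape& out_shape, se::DeviceMemoryBase in_buf,
                        se::DeviceMemoryBase filter_buf,
                        se::DeviceMemoryBase out_buf, const Window& window,
                        const ConvolutionDimensionNumbers& dnums,
                        se::dnn::AlgorithmConfig algo,
                        se::ScratchAllocator* scratch_allocator,
                        se::DeviceMemoryBase scratch_buf, se::Stream* stream) {
      CudnnConvParams params;
      params.kind = CudnnConvKind::kForward;
      params.input_shape = &in_shape;
      params.filter_shape = &filter_shape;
      params.output_shape = &out_shape;
      params.input_buf = in_buf;
      params.filter_buf = filter_buf;
      params.output_buf = out_buf;
      params.window = &window;
      params.dnums = &dnums;
      params.feature_group_count = 1;
      params.algorithm = algo;
      TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch_allocator, stream));
      return RunCudnnConvolution(params, scratch_buf, stream);
    }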
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index 9b6de115ad..c1aaa4bf04 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -23,6 +23,8 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
// IWYU pragma: no_include "llvm/IR/Attributes.gen.inc"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
@@ -43,16 +45,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace gpu {
+using absl::StrAppend;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;
-using tensorflow::strings::StrAppend;
namespace {
// Returns whether operand is a floating-point literal with the given value.
@@ -74,10 +74,8 @@ GpuElementalIrEmitter::GpuElementalIrEmitter(
compute_nested_(std::move(compute_nested)) {}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) const {
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
// The libdevice math functions differentiate between "double" and "float" by
// appending an 'f' to the function's name. libdevice doesn't have f16 math
// functions, so we convert the operands to f32 before calling the function
@@ -94,7 +92,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall(
for (int64 i = 0; i < operands.size(); ++i) {
if (input_types[i] == F16) {
converted_operands[i] =
- b_->CreateFPCast(converted_operands[i], b_->getFloatTy());
+ FPCast(converted_operands[i], b_->getFloatTy());
converted_input_types[i] = F32;
}
}
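A standalone illustration of the name-munging convention described above (hypothetical helper, not from the patch): libdevice names the float variant by appending 'f' (__nv_log -> __nv_logf) and has no f16 variants, so f16 operands are upcast to f32 and use the 'f' name too, while the double variant keeps the bare name.

    #include <string>

    // Returns the libdevice callee for the given base name; is_f64 selects
    // the double variant, everything else (f32, and f16 computed in f32)
    // gets the 'f' suffix.
    std::string MungeLibdeviceName(const std::string& callee, bool is_f64) {
      return is_f64 ? callee : callee + "f";
    }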
@@ -107,22 +105,20 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall(
break;
default:
return Unimplemented("Bad type for libdevice math call: %s",
- PrimitiveType_Name(output_type).c_str());
+ PrimitiveType_Name(output_type));
}
llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
converted_input_types, output_type)
.ValueOrDie();
if (cast_result_to_fp16) {
- result = b_->CreateFPCast(result, b_->getHalfTy());
+ result = FPCast(result, b_->getHalfTy());
}
return result;
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) const {
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
// llvm intrinsics differentiate between half/float/double functions via
// the suffixes ".f16", ".f32" and ".f64".
string munged_callee = callee_name;
@@ -138,22 +134,20 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(
break;
default:
return Unimplemented("Bad type for llvm intrinsic math call: %s",
- PrimitiveType_Name(output_type).c_str());
+ PrimitiveType_Name(output_type));
}
return EmitMathCall(munged_callee, operands, input_types, output_type);
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) const {
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
// Binary math functions are of type [T] -> T.
for (PrimitiveType input_type : input_types) {
if (output_type != input_type) {
return Unimplemented("Input type ≠ output type: %s ≠ %s",
- PrimitiveType_Name(input_type).c_str(),
- PrimitiveType_Name(output_type).c_str());
+ PrimitiveType_Name(input_type),
+ PrimitiveType_Name(output_type));
}
}
@@ -163,8 +157,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const {
+ const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
PrimitiveType output_type = op->shape().element_type();
@@ -183,8 +176,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const {
+ const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
CHECK_EQ(op->opcode(), HloOpcode::kPower);
PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
@@ -218,7 +210,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
// TODO(jlebar): Does this happen with fastmath disabled? If not, should
// we force-enable it?
TF_ASSIGN_OR_RETURN(auto* sqrt, make_sqrt());
- return b_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt);
+ return FDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt);
}
VLOG(10) << "emitting pow as regular call to pow(): " << op->ToString();
@@ -227,55 +219,56 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitErfcInv(
- PrimitiveType prim_type, llvm::Value* value) const {
+ PrimitiveType prim_type, llvm::Value* value) {
return EmitLibdeviceMathCall("__nv_erfcinv", {value}, {prim_type}, prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog(PrimitiveType prim_type,
+ llvm::Value* value) {
return EmitLibdeviceMathCall("__nv_log", {value}, {prim_type}, prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog1p(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
+ llvm::Value* value) {
return EmitLibdeviceMathCall("__nv_log1p", {value}, {prim_type}, prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSin(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSin(PrimitiveType prim_type,
+ llvm::Value* value) {
return EmitLibdeviceMathCall("__nv_sin", {value}, {prim_type}, prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitCos(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitCos(PrimitiveType prim_type,
+ llvm::Value* value) {
return EmitLibdeviceMathCall("__nv_cos", {value}, {prim_type}, prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(PrimitiveType prim_type,
+ llvm::Value* value) {
return EmitLibdeviceMathCall("__nv_exp", {value}, {prim_type}, prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExpm1(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
+ llvm::Value* value) {
return EmitLibdeviceMathCall("__nv_expm1", {value}, {prim_type}, prim_type);
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type,
llvm::Value* lhs,
- llvm::Value* rhs) const {
+ llvm::Value* rhs) {
return EmitLibdeviceMathCall("__nv_pow", {lhs, rhs}, {prim_type, prim_type},
prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(
- PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
+ llvm::Value* lhs,
+ llvm::Value* rhs) {
return EmitLibdeviceMathCall("__nv_atan2", {lhs, rhs}, {prim_type, prim_type},
prim_type);
}
-StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(
- PrimitiveType prim_type, llvm::Value* value) const {
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
+ llvm::Value* value) {
// Emit a fast approximation of tanh instead of calling __nv_tanh.
// __nv_tanh is particularly bad because it contains branches, thus
// preventing LLVM's load-store vectorizer from working its magic across a
@@ -285,17 +278,15 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(
// Upcast F16 to F32 if necessary.
llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
- llvm::Value* input = b_->CreateFPCast(value, type);
+ llvm::Value* input = FPCast(value, type);
llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
- return b_->CreateFPCast(fast_tanh, value->getType());
+ return FPCast(fast_tanh, value->getType());
}
llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type,
- tensorflow::gtl::ArraySlice<llvm::Attribute::AttrKind> attributes) const {
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
+ absl::Span<const llvm::Attribute::AttrKind> attributes) {
std::vector<llvm::Type*> ir_input_types;
for (PrimitiveType input_type : input_types) {
ir_input_types.push_back(
@@ -315,29 +306,28 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
callee->addFnAttr(attribute);
}
- return b_->CreateCall(callee, llvm_ir::AsArrayRef(operands));
+ return Call(callee, llvm_ir::AsArrayRef(operands));
}
-llvm::Value* GpuElementalIrEmitter::EmitThreadId() const {
- llvm::Value* block_id = b_->CreateIntCast(
- llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
- {}, {}, b_),
- b_->getIntNTy(128), /*isSigned=*/true, "block.id");
- llvm::Value* thread_id_in_block = b_->CreateIntCast(
- llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x,
- {}, {}, b_),
- b_->getIntNTy(128), /*isSigned=*/true, "thread.id");
- llvm::Value* threads_per_block = b_->CreateIntCast(
- llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x,
- {}, {}, b_),
- b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
- return b_->CreateNSWAdd(b_->CreateNSWMul(block_id, threads_per_block),
- thread_id_in_block);
+llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
+ llvm::Value* block_id =
+ IntCast(llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_),
+ b_->getIntNTy(128), /*isSigned=*/true, "block.id");
+ llvm::Value* thread_id_in_block =
+ IntCast(llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_),
+ b_->getIntNTy(128), /*isSigned=*/true, "thread.id");
+ llvm::Value* threads_per_block =
+ IntCast(llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x, {}, {}, b_),
+ b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
+ return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
}
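For reference, the three PTX special registers read above map onto the familiar CUDA built-ins, so the emitted IR computes the usual linear thread id. A plain C++ restatement (illustration only, not from the patch):

    #include <cstdint>

    // block_id = ctaid.x (blockIdx.x), threads_per_block = ntid.x
    // (blockDim.x), thread_id_in_block = tid.x (threadIdx.x).
    int64_t GlobalThreadIdX(int64_t block_id, int64_t threads_per_block,
                            int64_t thread_id_in_block) {
      return block_id * threads_per_block + thread_id_in_block;
    }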
llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) const {
+ const HloToElementGeneratorMap& operand_to_generator) {
switch (hlo->opcode()) {
case HloOpcode::kMap:
return [=, &operand_to_generator](
@@ -383,7 +373,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
operand_to_generator.at(hlo->operand(1))(
IrArray::Index(index.GetType())));
- b_->CreateStore(init_value, accum_ptr);
+ Store(init_value, accum_ptr);
}
llvm::Type* index_type = index.GetType();
@@ -405,22 +395,21 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
IrArray::Index input_index(index_type, index.size());
llvm::Value* in_bounds = b_->getInt1(true);
for (size_t i = 0; i < index.size(); ++i) {
- llvm::Value* stridden_index = b_->CreateNSWMul(
+ llvm::Value* stridden_index = NSWMul(
index[i], index_typed_const(window.dimensions(i).stride()));
- input_index[i] = b_->CreateNSWSub(
- b_->CreateNSWAdd(stridden_index, window_index[i]),
- index_typed_const(window.dimensions(i).padding_low()));
+ input_index[i] =
+ NSWSub(NSWAdd(stridden_index, window_index[i]),
+ index_typed_const(window.dimensions(i).padding_low()));
// We must check whether 0 ≤ input_index[i] < bound, as otherwise
// we are in the pad and so can skip the computation. This
// comparison is equivalent to the unsigned comparison
// input_index[i] < bound, as a negative value wraps to a large
// positive value.
- in_bounds = b_->CreateAnd(
- in_bounds,
- b_->CreateICmpULT(
- input_index[i],
- index_typed_const(operand->shape().dimensions(i))));
+ in_bounds =
+ And(in_bounds,
+ ICmpULT(input_index[i],
+ index_typed_const(operand->shape().dimensions(i))));
}
llvm_ir::LlvmIfData if_data =
@@ -432,12 +421,11 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
operand_to_generator.at(operand)(input_index));
TF_ASSIGN_OR_RETURN(
llvm::Value * accum_value,
- compute_nested_(*hlo->to_apply(),
- {b_->CreateLoad(accum_ptr), input_value}));
- b_->CreateStore(accum_value, accum_ptr);
+ compute_nested_(*hlo->to_apply(), {Load(accum_ptr), input_value}));
+ Store(accum_value, accum_ptr);
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
- return b_->CreateLoad(accum_ptr);
+ return Load(accum_ptr);
};
case HloOpcode::kReduce:
// TODO(b/112040122): This should be supported.
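The in-bounds check in the hunk above relies on a standard trick: a single unsigned comparison covers both 0 <= i and i < bound, because a negative signed value reinterpreted as unsigned becomes larger than any valid bound. A self-contained sketch (not from the patch):

    #include <cstdint>

    // Negative indices wrap to huge unsigned values, so one compare suffices.
    bool IndexInBounds(int64_t index, int64_t bound) {
      return static_cast<uint64_t>(index) < static_cast<uint64_t>(bound);
    }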
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
index 84454d31bb..e8b56a39ce 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
@@ -30,7 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace gpu {
@@ -38,9 +38,9 @@ namespace gpu {
class GpuElementalIrEmitter : public ElementalIrEmitter {
public:
// A NestedComputer computes an element of the output of the given computation
- // given an ArraySlice of its input elements.
+ // given a Span of its input elements.
using NestedComputer = std::function<StatusOr<llvm::Value*>(
- const HloComputation&, tensorflow::gtl::ArraySlice<llvm::Value*>)>;
+ const HloComputation&, absl::Span<llvm::Value* const>)>;
GpuElementalIrEmitter(const HloModuleConfig& hlo_module_config,
llvm::Module* module, llvm::IRBuilder<>* b,
@@ -48,85 +48,77 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
llvm_ir::ElementGenerator MakeElementGenerator(
const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) const override;
+ const HloToElementGeneratorMap& operand_to_generator) override;
protected:
- StatusOr<llvm::Value*> EmitFloatBinaryOp(
- const HloInstruction* op, llvm::Value* lhs_value,
- llvm::Value* rhs_value) const override;
+ StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
+ llvm::Value* lhs_value,
+ llvm::Value* rhs_value) override;
StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
StatusOr<llvm::Value*> EmitLog1p(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type, llvm::Value* lhs,
- llvm::Value* rhs) const override;
+ llvm::Value* rhs) override;
StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
- llvm::Value* rhs) const override;
+ llvm::Value* rhs) override;
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
- llvm::Value* value) const override;
+ llvm::Value* value) override;
- llvm::Value* EmitThreadId() const override;
+ llvm::Value* EmitThreadId() override;
private:
// Emits IR for op, which must have opcode kPower.
StatusOr<llvm::Value*> EmitPowerOp(const HloInstruction* op,
llvm::Value* lhs_value,
- llvm::Value* rhs_value) const;
+ llvm::Value* rhs_value);
// Emits IR to call a device function named "callee_name" on the given
// operands. Returns the IR value that represents the return value.
llvm::Value* EmitDeviceFunctionCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_type,
- PrimitiveType output_type,
- tensorflow::gtl::ArraySlice<llvm::Attribute::AttrKind> attributes) const;
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_type, PrimitiveType output_type,
+ absl::Span<const llvm::Attribute::AttrKind> attributes);
// Emits IR to call an LLVM intrinsic of type [T] -> T. Adjusts
// callee_name according to T. Returns the IR value that represents the
// return value of the function.
StatusOr<llvm::Value*> EmitLlvmIntrinsicMathCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) const;
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
// Emits IR to call a libdevice function of type [T] -> T. Adjusts
// callee_name according to T. Returns the IR value that represents the
// return value of the function.
StatusOr<llvm::Value*> EmitLibdeviceMathCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) const;
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
// Emits IR to call a function of type [T] -> T. Does not munge callee_name.
// Returns the IR value that represents the return value of the function.
StatusOr<llvm::Value*> EmitMathCall(
- const string& callee_name,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
- PrimitiveType output_type) const;
+ const string& callee_name, absl::Span<llvm::Value* const> operands,
+ absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
const HloModuleConfig& hlo_module_config_;
NestedComputer compute_nested_;
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
index 0cdddf8bcf..ca4a605af5 100644
--- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
@@ -17,11 +17,11 @@ limitations under the License.
#include <string>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -43,8 +43,8 @@ StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes(
if (byte_size > GetMemoryLimitInBytes(stream)) {
return se::port::Status(
se::port::error::RESOURCE_EXHAUSTED,
- tensorflow::strings::Printf(
- "Allocating %lld bytes exceeds the memory limit of %lld bytes.",
+ absl::StrFormat(
+ "Allocating %d bytes exceeds the memory limit of %d bytes.",
byte_size, GetMemoryLimitInBytes(stream)));
}
@@ -92,8 +92,7 @@ string FftTypeToString(se::fft::Type type) {
} // namespace
-FftThunk::FftThunk(FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length,
+FftThunk::FftThunk(FftType fft_type, absl::Span<const int64> fft_length,
const BufferAllocation::Slice& input_buffer,
const BufferAllocation::Slice& output_buffer,
const Shape& input_shape, const Shape& output_shape,
@@ -213,7 +212,7 @@ Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
return Status::OK();
}
return InternalError("Unable to launch fft for thunk %p with type %s", this,
- FftTypeToString(fft_type_).c_str());
+ FftTypeToString(fft_type_));
}
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
index 8c53be5077..2be50e08bd 100644
--- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
@@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
@@ -62,7 +62,7 @@ class FftThunk : public Thunk {
public:
// Constructs a thunk for launching an FFT on a stream.
// Semantics of null hlo_instruction argument are as in Thunk.
- FftThunk(FftType fft_type, tensorflow::gtl::ArraySlice<int64> fft_length,
+ FftThunk(FftType fft_type, absl::Span<const int64> fft_length,
const BufferAllocation::Slice& input_buffer,
const BufferAllocation::Slice& output_buffer,
const Shape& input_shape, const Shape& output_shape,
diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
index 2fd2206324..88f0b4d71c 100644
--- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -28,7 +28,7 @@ ForThunk::ForThunk(const int64 loop_limit,
const HloInstruction* hlo)
: Thunk(Kind::kWhile, hlo),
loop_limit_(loop_limit),
- body_thunk_sequence_(MakeUnique<SequentialThunk>(
+ body_thunk_sequence_(absl::make_unique<SequentialThunk>(
// Pass nullptr as the HloInstruction* to the body_thunk_sequence_
// constructor because this SequentialThunk is logically "part of"
// this ForThunk, and shouldn't be profiled separately from it.
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
index 3cd30b754c..30c1f90889 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
@@ -18,12 +18,14 @@ limitations under the License.
#include <algorithm>
#include <vector>
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace gpu {
@@ -64,10 +66,11 @@ double CalculateBytesReadByFusionParameter(HloInstruction* param) {
// Slice for a more accurate estimate of bytes read.
double bytes = 0.0;
for (auto& instruction : instructions) {
- if (c_all_of(instruction->users(), [](const HloInstruction* instruction) {
- return instruction->opcode() == HloOpcode::kSlice ||
- instruction->opcode() == HloOpcode::kDynamicSlice;
- })) {
+ if (absl::c_all_of(
+ instruction->users(), [](const HloInstruction* instruction) {
+ return instruction->opcode() == HloOpcode::kSlice ||
+ instruction->opcode() == HloOpcode::kDynamicSlice;
+ })) {
// All users are slice: accumulate bytes of all user slice instructions.
for (auto& user : instruction->users()) {
bytes += ShapeUtil::ByteSizeOf(user->shape());
@@ -223,10 +226,11 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
// Skip 'fusion' instruction if we cannot merge into all of its users.
// Merging into all users enables the removal of 'fusion' from the
// computation.
- if (!c_all_of(fusion->users(), [](const HloInstruction* user) {
+ if (!absl::c_all_of(fusion->users(), [&](const HloInstruction* user) {
return user->opcode() == HloOpcode::kFusion &&
(user->fusion_kind() == HloInstruction::FusionKind::kLoop ||
- user->fusion_kind() == HloInstruction::FusionKind::kInput);
+ (user->fusion_kind() == HloInstruction::FusionKind::kInput &&
+ LayoutsAreReduceInputFusionFriendly(*fusion, *user)));
})) {
VLOG(3) << "Not merging " << fusion->name()
<< ": Some of its users are not loop/input fusion kernels.";
@@ -241,11 +245,11 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
// If 'fusion' has just one user, then an earlier fusion pass chose not to
// fuse this producer/consumer pair (likely because of expensive instruction
// re-use by the consumer), and so we honor that choice here as well.
- if (c_any_of(fusion->fused_instructions(),
- [](const HloInstruction* instruction) {
- return instruction->opcode() != HloOpcode::kParameter &&
- GpuInstructionFusion::IsExpensive(*instruction);
- })) {
+ if (absl::c_any_of(fusion->fused_instructions(),
+ [](const HloInstruction* instruction) {
+ return instruction->opcode() != HloOpcode::kParameter &&
+ GpuInstructionFusion::IsExpensive(*instruction);
+ })) {
VLOG(3) << "Not merging " << fusion->name()
<< ": Contains one or more expensive instructions.";
++num_fail_expensive_fused_instruction_;
@@ -287,11 +291,10 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
<< " flops_to_bytes_ratio: " << CalculateFlopsToBytesRatio(fusion)
<< " merged_to_current_bytes_ratio: " << merged_to_current_bytes_ratio
<< " into users { "
- << tensorflow::str_util::Join(users, ", ",
- [](string* out, HloInstruction* user) {
- tensorflow::strings::StrAppend(
- out, user->name());
- })
+ << absl::StrJoin(users, ", ",
+ [](string* out, HloInstruction* user) {
+ absl::StrAppend(out, user->name());
+ })
<< " }";
// Remove 'fusion' instruction.
CHECK_EQ(0, fusion->user_count());
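A minimal standalone example of absl::StrJoin with a custom formatter, the pattern adopted in the log message above (illustrative only, not from the patch):

    #include <string>
    #include <vector>
    #include "absl/strings/str_cat.h"
    #include "absl/strings/str_join.h"

    std::string JoinNames(const std::vector<std::string>& names) {
      // The formatter receives the output string and each element in turn.
      return absl::StrJoin(names, ", ",
                           [](std::string* out, const std::string& name) {
                             absl::StrAppend(out, name);
                           });
    }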
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h
index 4c523a66de..7e3f5775b8 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h
@@ -34,7 +34,7 @@ namespace gpu {
//
class FusionMerger : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "fusion merger"; }
+ absl::string_view name() const override { return "fusion merger"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
index b22bb1d39b..7cc869ed9e 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
@@ -286,6 +286,39 @@ TEST_F(FusionMergerTest, WillMergeIntoInputFusion) {
op::Fusion(op::Parameter()));
}
+TEST_F(FusionMergerTest, WillNotMergeReduceUnfriendlyLayouts) {
+ auto module = ParseHloString(R"(
+ HloModule m
+
+ f1_computation {
+ f1_p0 = f32[16,16,256]{0,1,2} parameter(0)
+ add = f32[16,16,256]{0,1,2} add(f1_p0, f1_p0)
+ // Note that the copy changes the layout from {0,1,2} to {2,1,0}.
+ ROOT f1_root = f32[16,16,256]{2,1,0} copy(add)
+ }
+
+ add_computation {
+ add_lhs = f32[] parameter(0)
+ add_rhs = f32[] parameter(1)
+ ROOT add_root = f32[] add(add_lhs, add_rhs)
+ }
+
+ f2_computation {
+ f2_p0 = f32[16,16,256]{2,1,0} parameter(0)
+ f2_zero = f32[] constant(0)
+ ROOT f2_root = f32[] reduce(f2_p0, f2_zero), dimensions={0,1,2},
+ to_apply=add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[16,16,256]{0,1,2} parameter(0)
+ f1 = f32[16,16,256]{2,1,0} fusion(p0), kind=kLoop, calls=f1_computation
+ ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation
+ })")
+ .ValueOrDie();
+ EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie());
+}
+
} // namespace
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
index 74282c568c..9c4a490366 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
@@ -17,8 +17,8 @@ limitations under the License.
#include <functional>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
@@ -186,7 +186,7 @@ StatusOr<se::blas::AlgorithmType> DoGemmAutotune(
}
return InternalError(
- "Unable to autotune cuBLAS gemm on stream %p; none of the %zu algorithms "
+ "Unable to autotune cuBLAS gemm on stream %p; none of the %u algorithms "
"ran successfully",
stream, algorithms.size());
}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
index 0c6f9b511f..8ffae18fe8 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
@@ -27,7 +27,7 @@ namespace gpu {
// inserting kCopy instructions.
class GpuCopyInsertion : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "copy-insertion"; }
+ absl::string_view name() const override { return "copy-insertion"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 7060837904..31a9f9b1be 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -112,7 +112,7 @@ Status GpuExecutable::ExecuteThunks(
//
// TODO(jlebar): Should we cache the results of HloInstruction::ToString(),
// since we expect it to be an expensive call?
- tensorflow::gtl::optional<ScopedAnnotation> op_annotation;
+ absl::optional<ScopedAnnotation> op_annotation;
if (top_level_annotation.IsEnabled()) {
op_annotation.emplace(
thunk->hlo_instruction() != nullptr
@@ -144,7 +144,7 @@ Status GpuExecutable::ExecuteThunks(
TF_RETURN_IF_ERROR(
thunk->ExecuteOnStream(buffer_allocations, stream, &profiler));
if (thunk_schedule_->Depended(thunk)) {
- auto finish_event = MakeUnique<se::Event>(main_stream->parent());
+ auto finish_event = absl::make_unique<se::Event>(main_stream->parent());
finish_event->Init();
stream->ThenRecordEvent(finish_event.get());
thunk_to_finish_event[thunk] = std::move(finish_event);
@@ -160,7 +160,7 @@ Status GpuExecutable::ExecuteThunks(
if (!block_status.ok()) {
return InternalError(
"Failed to complete all kernels launched on stream %p: %s",
- main_stream, block_status.error_message().c_str());
+ main_stream, block_status.error_message());
}
}
@@ -234,7 +234,7 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) {
StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) {
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
@@ -260,10 +260,9 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
if (buffer.is_null() && buffer.size() > 0) {
return FailedPrecondition(
"Cannot run XLA computation because pointer to (sub-)buffer at "
- "index %s of parameter %lld was null. All pointers to "
- "(sub-)buffers must not be null, unless the (sub-)buffer has zero "
- "elements.",
- allocation.param_shape_index().ToString().c_str(), param_no);
+ "index %s of parameter %d was null. All pointers to (sub-)buffers "
+ "must not be null, unless the (sub-)buffer has zero elements.",
+ allocation.param_shape_index().ToString(), param_no);
}
buffer_allocations_builder.RegisterBuffer(i, buffer);
@@ -326,7 +325,7 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ absl::Span<const ShapedBuffer* const> arguments) {
// TODO(b/30671675): Implement asynchronous execution mode.
return Unimplemented(
"Asynchronous execution on stream is not yet supported on GPU.");
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index c7ce6d0acb..38b0f8f15b 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -19,6 +19,9 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/executable.h"
@@ -32,10 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -78,12 +78,12 @@ class GpuExecutable : public Executable {
// match the compute capability passed to this object's constructor.
StatusOr<ScopedShapedBuffer> ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) override;
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) override;
+ absl::Span<const ShapedBuffer* const> arguments) override;
private:
// If `block_host_until_done` is false, execution will not block the host
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
new file mode 100644
index 0000000000..2d31fd5570
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
@@ -0,0 +1,84 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
+
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+void AppendParams(const HloInstruction& instr,
+ std::vector<HloInstruction*>* params) {
+ if (instr.opcode() == HloOpcode::kFusion) {
+ params->insert(std::end(*params), std::begin(instr.fused_parameters()),
+ std::end(instr.fused_parameters()));
+ } else {
+ for (HloInstruction* operand : instr.operands()) {
+ params->push_back(operand);
+ }
+ }
+}
+} // namespace
+
+bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
+ const HloInstruction& reduce) {
+ std::vector<HloInstruction*> params;
+ AppendParams(producer, &params);
+ AppendParams(reduce, &params);
+ int64 max_rank = -1;
+ const Layout* max_rank_layout = nullptr;
+ for (HloInstruction* param : params) {
+ if (ShapeUtil::IsArray(param->shape()) &&
+ ShapeUtil::Rank(param->shape()) > max_rank) {
+ max_rank = ShapeUtil::Rank(param->shape());
+ max_rank_layout = &param->shape().layout();
+ }
+ }
+ return absl::c_all_of(params, [&](HloInstruction* param) {
+ return (!ShapeUtil::IsArray(param->shape())) ||
+ (ShapeUtil::Rank(param->shape()) < max_rank) ||
+ (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout));
+ });
+}
+
+bool IsInputFusibleReduction(const HloInstruction& instr) {
+ if (instr.IsMultiOutputFusion()) {
+ for (const HloInstruction* operand :
+ instr.fused_expression_root()->operands()) {
+ if (IsReductionToVector(*operand)) {
+ CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput)
+ << " Multi-output fusion rooted at reduction-to-vector ops must be "
+ "of kind kInput: "
+ << instr.ToString();
+ return true;
+ }
+ }
+ return false;
+ } else if (instr.opcode() == HloOpcode::kFusion) {
+ if (IsReductionToVector(*instr.fused_expression_root())) {
+ CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput)
+ << " Fusion rooted at reduction-to-vector op must be of kind kInput: "
+ << instr.ToString();
+ return true;
+ }
+ return false;
+ }
+ return IsReductionToVector(instr);
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h
new file mode 100644
index 0000000000..f7c24a0d5b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_FUSIBLE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_FUSIBLE_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+
+// TODO(b/112957171): Extract logic to determine fusibility of HLO ops from
+// GpuInstructionFusion, FusionMerger, and GpuMultiOutputFusion.
+
+namespace xla {
+namespace gpu {
+
+// The code emitted for reduce-rooted input fusions (EmitReductionToVector)
+// suffers from poor data locality if the layouts of input parameters differ. In
+// such situations it is better not to fuse. Only input params with the
+// maximum rank are considered; params with smaller ranks will be broadcast
+// and have not been observed to cause data locality issues.
+// TODO(b/111977086): Improve reduce emitters to remove this limitation.
+bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
+ const HloInstruction& reduce);
+
+// Whether `instr` is fusible as the root of a reduce input fusion, i.e. `instr`
+// is either an unfused reduction-to-vector op, an input fusion rooted at a
+// reduction-to-vector op, or a multi-output input fusion with at least one
+// reduction-to-vector op root.
+// Note that reduction ops are lowered in different ways. Reduce input fusions
+// are lowered by IrEmitterUnnested::EmitReductionToVector and must be rooted at
+// reduction-to-vector ops. Other reduction ops are lowered by
+// GpuElementalIrEmitter and fused like elementwise ops.
+bool IsInputFusibleReduction(const HloInstruction& instr);
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_FUSIBLE_H_
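The FusionMerger change earlier in this patch shows the intended use of these predicates; a condensed sketch of that gating logic (hypothetical pass code, not from the patch):

    // Decline to merge a producer into an input fusion when the parameter
    // layouts would defeat EmitReductionToVector's data locality.
    bool ShouldMergeIntoUser(const HloInstruction& producer,
                             const HloInstruction& user) {
      if (user.opcode() != HloOpcode::kFusion) return false;
      if (user.fusion_kind() == HloInstruction::FusionKind::kLoop) return true;
      return user.fusion_kind() == HloInstruction::FusionKind::kInput &&
             LayoutsAreReduceInputFusionFriendly(producer, user);
    }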
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc
new file mode 100644
index 0000000000..d91b7bc61f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc
@@ -0,0 +1,332 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace gpu {
+
+using GpuFusibleTest = HloTestBase;
+
+const char kModulePrefix[] = R"(
+ HloModule test_module
+ scalar_add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+ })";
+
+TEST_F(GpuFusibleTest,
+ LayoutsAreReduceInputFusionFriendly_ElementwiseProducer) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ ENTRY entry {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ c0 = f32[] constant(0)
+ exp = f32[2,2,2]{2,1,0} exponential(p0)
+ ROOT reduce = f32[2,2]{1,0} reduce(exp, c0), dimensions={2}, to_apply=scalar_add
+ })"))
+ .ValueOrDie();
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* reduce =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce);
+ const HloInstruction* exp =
+ module->entry_computation()->root_instruction()->operand(0);
+ ASSERT_EQ(exp->opcode(), HloOpcode::kExp);
+ EXPECT_TRUE(LayoutsAreReduceInputFusionFriendly(*exp, *reduce));
+}
+
+TEST_F(GpuFusibleTest,
+ LayoutsAreReduceInputFusionFriendly_MixedLayoutProducer) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ mixed_input_layouts_computation {
+ p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
+ p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
+ copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1)
+ c0 = f16[] constant(0)
+ broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={}
+ greater-than = pred[128,1024,32,32]{1,3,2,0} greater-than(copy, broadcast)
+ ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
+ }
+ fused_reduce {
+ p0.2 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
+ convert = f32[128,1024,32,32]{1,3,2,0} convert(p0.2)
+ c0.2 = f32[] constant(0)
+ ROOT reduce = f32[1024]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add
+ }
+ ENTRY entry {
+ p0 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
+ p1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
+ loop_fusion = f16[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=mixed_input_layouts_computation
+ reduce_fusion = f32[1024]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce
+ ROOT root = (f32[1024]{0}, f16[128,1024,32,32]{1,3,2,0}) tuple(reduce_fusion, loop_fusion)
+ })"))
+ .ValueOrDie();
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* reduce_fusion =
+ module->entry_computation()->root_instruction()->operand(0);
+ ASSERT_EQ(reduce_fusion->fused_expression_root()->opcode(),
+ HloOpcode::kReduce);
+ const HloInstruction* loop_fusion =
+ module->entry_computation()->root_instruction()->operand(1);
+ ASSERT_EQ(loop_fusion->fused_expression_root()->opcode(), HloOpcode::kSelect);
+ EXPECT_FALSE(
+ LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce_fusion));
+}
+
+TEST_F(GpuFusibleTest, LayoutsAreReduceInputFusionFriendly_CopyProducer) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ fused_reduce {
+ p0.1 = f32[128,1024,32,32]{1,3,2,0} parameter(0)
+ c0.1 = f32[] constant(0)
+ ROOT reduce = f32[1024]{0} reduce(p0.1, c0.1), dimensions={0,2,3}, to_apply=scalar_add
+ }
+ ENTRY entry {
+ p0 = f16[128,1024,32,32]{3,2,1,0} parameter(0)
+ copy = f32[128,1024,32,32]{1,3,2,0} copy(p0)
+ ROOT reduce_fusion = f32[1024]{0} fusion(copy), kind=kInput, calls=fused_reduce
+ })"))
+ .ValueOrDie();
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* reduce =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce->fused_expression_root()->opcode(), HloOpcode::kReduce);
+ const HloInstruction* copy =
+ module->entry_computation()->root_instruction()->operand(0);
+ ASSERT_EQ(copy->opcode(), HloOpcode::kCopy);
+ EXPECT_FALSE(LayoutsAreReduceInputFusionFriendly(*copy, *reduce));
+}
+
+TEST_F(GpuFusibleTest,
+ LayoutsAreReduceInputFusionFriendly_LayoutChangingFusionProducer) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ layout_changing_computation {
+ p0.1 = f16[128,1024,32,32]{3,2,1,0} parameter(0)
+ p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
+ c0 = f16[] constant(0)
+ broadcast = f16[128,1024,32,32]{3,2,1,0} broadcast(c0), dimensions={}
+ greater-than = pred[128,1024,32,32]{3,2,1,0} greater-than(p1.1, broadcast)
+ select = f16[128,1024,32,32]{3,2,1,0} select(greater-than, p0.1, broadcast)
+ ROOT root = f16[128,1024,32,32]{1,3,2,0} copy(select)
+ }
+ fused_reduce {
+ p0.2 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
+ convert = f32[128,1024,32,32]{1,3,2,0} convert(p0.2)
+ c0.2 = f32[] constant(0)
+ ROOT reduce = f32[1024]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add
+ }
+ ENTRY entry {
+ p0 = f16[128,1024,32,32]{3,2,1,0} parameter(0)
+ p1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
+ loop_fusion = f16[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=layout_changing_computation
+ ROOT reduce_fusion = f32[1024]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce
+ })"))
+ .ValueOrDie();
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* reduce_fusion =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce_fusion->fused_expression_root()->opcode(),
+ HloOpcode::kReduce);
+ const HloInstruction* loop_fusion =
+ module->entry_computation()->root_instruction()->operand(0);
+ ASSERT_EQ(loop_fusion->fused_expression_root()->opcode(), HloOpcode::kCopy);
+ EXPECT_FALSE(
+ LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce_fusion));
+}
+
+TEST_F(GpuFusibleTest,
+ LayoutsAreReduceInputFusionFriendly_ConsiderMaximumRanksParamsOnly) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ broadcasting_computation {
+ p0.1 = f32[128,1024,32,32]{1,3,2,0} parameter(0)
+ p1.1 = f32[128]{0} parameter(1)
+ broadcast = f32[128,1024,32,32]{1,3,2,0} broadcast(p1.1), dimensions={0}
+ ROOT add = f32[128,1024,32,32]{1,3,2,0} add(p0.1, broadcast)
+ }
+ ENTRY entry {
+ p0 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
+ p1 = f16[128]{0} parameter(1)
+ loop_fusion = f32[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=broadcasting_computation
+ c0.2 = f32[] constant(0)
+ ROOT reduce = f32[128,1024]{0,1} reduce(loop_fusion, c0.2), dimensions={0,2,3}, to_apply=scalar_add
+ })"))
+ .ValueOrDie();
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* reduce =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce);
+ const HloInstruction* loop_fusion =
+ module->entry_computation()->root_instruction()->operand(0);
+ ASSERT_EQ(loop_fusion->fused_expression_root()->opcode(), HloOpcode::kAdd);
+ EXPECT_TRUE(LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce));
+}
+
+TEST_F(GpuFusibleTest, IsInputFusibleReduction_ReductionToVector) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ ENTRY entry {
+ c0 = f32[] parameter(0)
+ p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ // Reduction-to-vector lowered by IrEmitterUnnested.
+ ROOT reduce = f32[512]{0} reduce(p1, c0), dimensions={0,2,3}, to_apply=scalar_add
+ })"))
+ .ValueOrDie();
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* reduce =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce);
+ EXPECT_TRUE(IsInputFusibleReduction(*reduce));
+}
+
+TEST_F(GpuFusibleTest, IsInputFusibleReduction_ElementalReduction) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ ENTRY entry {
+ c0 = f32[] parameter(0)
+ p1 = f32[8,512,5,16,1,1]{5,4,3,2,1,0} parameter(1)
+ // Reduction lowered by GpuElementalIrEmitter.
+ ROOT reduce = f32[8,512,5,1,1]{4,3,2,1,0} reduce(p1, c0), dimensions={3}, to_apply=scalar_add
+ })"))
+ .ValueOrDie();
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* reduce =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce);
+ EXPECT_FALSE(IsInputFusibleReduction(*reduce));
+}
+
+TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputInputReduceFusion) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ fused_reduction {
+ c0 = f32[] parameter(0)
+ p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ ROOT reduce = f32[128,512]{1,0} reduce(p1, c0), dimensions={2,3}, to_apply=scalar_add
+ }
+ ENTRY entry {
+ p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ ROOT fusion = f32[128,512]{1,0} fusion(p0), kind=kInput, calls=fused_reduction
+ })"))
+ .ValueOrDie();
+ const HloInstruction* reduce =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion);
+ EXPECT_TRUE(IsInputFusibleReduction(*reduce));
+}
+
+TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputLoopReduceFusion) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ fused_reduction {
+ c0 = f32[] parameter(0)
+ p1 = f32[8,512,5,16,1,1]{5,4,3,2,1,0} parameter(1)
+ ROOT reduce = f32[8,5,1,1]{3,2,1,0} reduce(p1, c0), dimensions={1,3}, to_apply=scalar_add
+ }
+ ENTRY entry {
+ p0 = f32[8,512,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ ROOT fusion = f32[8,5,1,1]{3,2,1,0} fusion(p0), kind=kLoop, calls=fused_reduction
+ })"))
+ .ValueOrDie();
+ const HloInstruction* reduce =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion);
+ EXPECT_FALSE(IsInputFusibleReduction(*reduce));
+}
+
+TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputInputReduceFusion) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ fused_reduction {
+ c0 = f32[] parameter(0)
+ p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ reduce.0 = f32[128,512]{1,0} reduce(p1, c0), dimensions={2,3}, to_apply=scalar_add
+ reduce.1 = f32[128,512]{1,0} reduce(p1, c0), dimensions={2,3}, to_apply=scalar_add
+ ROOT root = (f32[128,512]{1,0}, f32[128,512]{1,0}) tuple(reduce.0, reduce.1)
+ }
+ ENTRY entry {
+ p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ ROOT fusion = (f32[128,512]{1,0}, f32[128,512]{1,0}) fusion(p0), kind=kInput, calls=fused_reduction
+ })"))
+ .ValueOrDie();
+ const HloInstruction* reduce =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion);
+ EXPECT_TRUE(IsInputFusibleReduction(*reduce));
+}
+
+TEST_F(GpuFusibleTest,
+ IsInputFusibleReduction_MultiOutputInputReduceFusionWithExtraOutputs) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ fused_reduction {
+ c0 = f32[] parameter(0)
+ p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ reduce = f32[128,512]{1,0} reduce(p1, c0), dimensions={2,3}, to_apply=scalar_add
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(p1, p1)
+ ROOT root = (f32[128,512]{1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(reduce, mul)
+ }
+ ENTRY entry {
+ p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ ROOT fusion = (f32[128,512]{1,0}, f32[128,512,28,28]{3,2,1,0}) fusion(p0), kind=kInput, calls=fused_reduction
+ })"))
+ .ValueOrDie();
+ const HloInstruction* reduce =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion);
+ EXPECT_TRUE(IsInputFusibleReduction(*reduce));
+}
+
+TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputLoopReduceFusion) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ fused_reduction {
+ c0 = f32[] parameter(0)
+ p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ reduce.0 = f32[512,28]{1,0} reduce(p1, c0), dimensions={0,2}, to_apply=scalar_add
+ reduce.1 = f32[512,28]{1,0} reduce(p1, c0), dimensions={0,2}, to_apply=scalar_add
+ ROOT root = (f32[512,28]{1,0}, f32[512,28]{1,0}) tuple(reduce.0, reduce.1)
+ }
+ ENTRY entry {
+ p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ ROOT fusion = (f32[512,28]{1,0}, f32[512,28]{1,0}) fusion(p0), kind=kLoop, calls=fused_reduction
+ })"))
+ .ValueOrDie();
+ const HloInstruction* reduce =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion);
+ EXPECT_FALSE(IsInputFusibleReduction(*reduce));
+}
+
+TEST_F(GpuFusibleTest,
+ IsInputFusibleReduction_MultiOutputLoopFusionReduceAndElementwiseOp) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ fused_reduction {
+ c0 = f32[] parameter(0)
+ p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ reduce = f32[512,28]{1,0} reduce(p1, c0), dimensions={0,2}, to_apply=scalar_add
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(p1, p1)
+ ROOT root = (f32[512,28]{1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(reduce, mul)
+ }
+ ENTRY entry {
+ p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ ROOT fusion = (f32[512,28]{1,0}, f32[128,512,28,28]{3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_reduction
+ })"))
+ .ValueOrDie();
+ const HloInstruction* reduce =
+ module->entry_computation()->root_instruction();
+ ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion);
+ EXPECT_FALSE(IsInputFusibleReduction(*reduce));
+}
+
+} // namespace gpu
+} // namespace xla
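
Taken together, these tests pin down the predicate's contract: unfused reductions-to-vector and kInput fusions rooted in one qualify; kLoop fusions and elemental reductions do not; and a multi-output kInput fusion qualifies if at least one tuple element is such a reduction, even alongside extra element-wise outputs. A sketch of a predicate with that contract, for orientation only (the real implementation lives in gpu_fusible.cc and may differ in detail):

// Sketch only; mirrors the behavior the tests above expect.
bool IsInputFusibleReductionSketch(const HloInstruction& instr) {
  if (instr.opcode() == HloOpcode::kReduce) {
    return IsReductionToVector(instr);
  }
  if (instr.opcode() != HloOpcode::kFusion ||
      instr.fusion_kind() != HloInstruction::FusionKind::kInput) {
    return false;  // Rejects kLoop reduce fusions outright.
  }
  const HloInstruction* root = instr.fused_expression_root();
  if (root->opcode() == HloOpcode::kTuple) {
    // Multi-output fusion: one reduction-to-vector root suffices; extra
    // element-wise outputs are tolerated.
    for (const HloInstruction* op : root->operands()) {
      if (op->opcode() == HloOpcode::kReduce && IsReductionToVector(*op)) {
        return true;
      }
    }
    return false;
  }
  return root->opcode() == HloOpcode::kReduce && IsReductionToVector(*root);
}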
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
index 19de37b0fb..02a0d028c1 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
@@ -17,12 +17,13 @@ limitations under the License.
#include <memory>
#include <unordered_map>
-#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
@@ -59,8 +60,8 @@ GpuHloOrdering::GpuHloOrdering(
: PredecessorHloOrdering(module) {
// The entry computation has a total order when there's only one stream.
if (stream_assignment.StreamCount() == 1) {
- entry_sequence_ =
- MakeUnique<std::vector<const HloInstruction*>>(thunk_launch_order);
+ entry_sequence_ = absl::make_unique<std::vector<const HloInstruction*>>(
+ thunk_launch_order);
}
// The ordering of instructions for the entry computation is determined by the
@@ -75,7 +76,7 @@ GpuHloOrdering::GpuHloOrdering(
// same-stream predecessors of each instruction.
// Compute the set of all instructions we will want to set reachability on.
- auto predecessor_map = MakeUnique<HloReachabilityMap>(
+ auto predecessor_map = absl::make_unique<HloReachabilityMap>(
module->entry_computation()->MakeInstructionPostOrder());
// The most recently visited instruction per stream.
@@ -184,13 +185,13 @@ void BFSLaunchOrder(const HloComputation* computation,
} // end namespace
-HloSchedule::HloSchedule() {}
+GpuHloSchedule::GpuHloSchedule() {}
/* static */
-StatusOr<std::unique_ptr<HloSchedule>> HloSchedule::Build(
+StatusOr<std::unique_ptr<GpuHloSchedule>> GpuHloSchedule::Build(
const HloModule& module, const StreamAssignment& stream_assignment,
int64 pointer_size) {
- std::unique_ptr<HloSchedule> schedule(new HloSchedule);
+ std::unique_ptr<GpuHloSchedule> schedule(new GpuHloSchedule);
// Initialize thunk_launch_order_, the total order of thunk launches.
const HloComputation* entry_computation = module.entry_computation();
@@ -198,17 +199,18 @@ StatusOr<std::unique_ptr<HloSchedule>> HloSchedule::Build(
// All kernels are launched on a single stream, so there's no loss of
// concurrency by optimizing for minimal memory usage.
TF_ASSIGN_OR_RETURN(
- schedule->thunk_launch_order_,
- ScheduleOneComputation(
+ HloInstructionSequence sequence,
+ ScheduleComputation(
*entry_computation, [pointer_size](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size);
}));
+ schedule->thunk_launch_order_ = sequence.instructions();
} else {
// BFS tends to increase concurrency, but also increases memory usage.
BFSLaunchOrder(entry_computation, &schedule->thunk_launch_order_);
}
- schedule->hlo_ordering_ = MakeUnique<GpuHloOrdering>(
+ schedule->hlo_ordering_ = absl::make_unique<GpuHloOrdering>(
&module, stream_assignment, schedule->thunk_launch_order_);
return std::move(schedule);
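
For orientation, the BFS branch can be pictured as a ready-queue topological walk: an instruction launches once all of its operands have launched, which favors breadth (and therefore cross-stream concurrency) over the memory-minimizing list schedule used in the single-stream branch. An illustrative sketch under that reading (BFSLaunchOrder itself, defined earlier in this file, is the authoritative version):

// Illustrative only: emit each instruction once its operands are emitted.
void BfsLaunchOrderSketch(const HloComputation* computation,
                          std::vector<const HloInstruction*>* launch_order) {
  std::deque<const HloInstruction*> queue;
  std::unordered_map<const HloInstruction*, int64> pending;
  for (const HloInstruction* hlo : computation->instructions()) {
    const int64 n = hlo->unique_operands().size();
    if (n == 0) {
      queue.push_back(hlo);  // Parameters and constants are ready at once.
    } else {
      pending[hlo] = n;
    }
  }
  while (!queue.empty()) {
    const HloInstruction* hlo = queue.front();
    queue.pop_front();
    launch_order->push_back(hlo);
    // users() is deduplicated, matching the unique_operands() count above.
    for (const HloInstruction* user : hlo->users()) {
      if (--pending[user] == 0) {
        queue.push_back(user);
      }
    }
  }
}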
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h
index 1ce7a48ac8..07a7fc67aa 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_SCHEDULE_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_SCHEDULE_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SCHEDULE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SCHEDULE_H_
#include <memory>
#include <vector>
@@ -33,12 +33,14 @@ namespace gpu {
// launches, because thunks may be scheduled onto concurrent streams. This
// schedule is used by BufferAssigner to determine buffer liveness (i.e. to
// minimize allocations), and also by ThunkSchedule to determine the thunk
-// launch order.
-class HloSchedule {
+// launch order. This class differs from xla::HloSchedule in that
+// xla::HloSchedule represents a total order of all instructions in the module,
+// for backends that execute HLO instructions strictly sequentially.
+class GpuHloSchedule {
public:
- // Constructs an HloSchedule for the given module, based on the given stream
- // assignment.
- static StatusOr<std::unique_ptr<HloSchedule>> Build(
+ // Constructs a GpuHloSchedule for the given module, based on the given
+ // stream assignment.
+ static StatusOr<std::unique_ptr<GpuHloSchedule>> Build(
const HloModule& module, const StreamAssignment& stream_assignment,
int64 pointer_size);
@@ -56,7 +58,7 @@ class HloSchedule {
}
private:
- HloSchedule();
+ GpuHloSchedule();
std::vector<const HloInstruction*> thunk_launch_order_;
std::unique_ptr<HloOrdering> hlo_ordering_;
@@ -65,4 +67,4 @@ class HloSchedule {
} // namespace gpu
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_SCHEDULE_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SCHEDULE_H_
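
A minimal usage sketch, inside a Status-returning function and assuming a module and its StreamAssignment are already in hand (these are the same calls the tests in gpu_hlo_schedule_test.cc make):

// Sketch: build the schedule and consume its two products.
TF_ASSIGN_OR_RETURN(
    std::unique_ptr<GpuHloSchedule> schedule,
    GpuHloSchedule::Build(*module, *stream_assignment, /*pointer_size=*/8));
// Total order in which thunks are launched onto streams.
const std::vector<const HloInstruction*>& launch_order =
    schedule->ThunkLaunchOrder();
// Partial order used to reason about buffer liveness.
std::unique_ptr<HloOrdering> ordering = schedule->ConsumeHloOrdering();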
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
index 45f0a1c645..b857fa775a 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
@@ -13,32 +13,34 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"
#include <algorithm>
#include <unordered_set>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
namespace gpu {
-class HloScheduleTest : public HloTestBase {
+class GpuHloScheduleTest : public HloVerifiedTestBase {
protected:
using HloVec = std::vector<const HloInstruction*>;
// Pre-canned shapes.
Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2});
- static std::unique_ptr<HloSchedule> BuildHloSchedule(
+ static std::unique_ptr<GpuHloSchedule> BuildGpuHloSchedule(
const HloModule& module, const StreamAssignment& streams) {
- return HloSchedule::Build(module, streams, /*pointer_size=*/8)
+ return GpuHloSchedule::Build(module, streams, /*pointer_size=*/8)
.ConsumeValueOrDie();
}
@@ -47,7 +49,7 @@ class HloScheduleTest : public HloTestBase {
auto debug_options = GetDebugOptionsForTest();
debug_options.set_xla_gpu_disable_multi_streaming(false);
config.set_debug_options(debug_options);
- return MakeUnique<HloModule>("test_module", config);
+ return absl::make_unique<HloModule>("test_module", config);
}
HloVec RemoveHlo(const HloVec& input,
@@ -64,7 +66,7 @@ class HloScheduleTest : public HloTestBase {
// Test of a single stream, where data dependencies fully determine the
// execution order.
-TEST_F(HloScheduleTest, SequentialMatMul) {
+TEST_F(GpuHloScheduleTest, SequentialMatMul) {
HloComputation::Builder builder("entry_computation");
HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
@@ -72,10 +74,10 @@ TEST_F(HloScheduleTest, SequentialMatMul) {
/*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
- HloInstruction* dot1 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
- HloInstruction* dot2 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z));
+ HloInstruction* dot1 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
+ HloInstruction* dot2 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(dot2));
@@ -84,7 +86,7 @@ TEST_F(HloScheduleTest, SequentialMatMul) {
EXPECT_EQ(streams->StreamNumberForHlo(*dot1),
streams->StreamNumberForHlo(*dot2));
- auto schedule = BuildHloSchedule(*module, *streams);
+ auto schedule = BuildGpuHloSchedule(*module, *streams);
// Remove parameters, which are unordered.
EXPECT_EQ(RemoveHlo(schedule->ThunkLaunchOrder(), {x, y, z}),
HloVec({dot1, dot2}));
@@ -122,7 +124,7 @@ TEST_F(HloScheduleTest, SequentialMatMul) {
// Test of a single stream, where data dependencies do not fully determine the
// execution order, but the stream assignment does.
-TEST_F(HloScheduleTest, SequentialAdd) {
+TEST_F(GpuHloScheduleTest, SequentialAdd) {
HloComputation::Builder builder("entry_computation");
HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
@@ -146,7 +148,7 @@ TEST_F(HloScheduleTest, SequentialAdd) {
EXPECT_EQ(streams->StreamNumberForHlo(*add1),
streams->StreamNumberForHlo(*add3));
- auto schedule = BuildHloSchedule(*module, *streams);
+ auto schedule = BuildGpuHloSchedule(*module, *streams);
// Remove parameters, which are unordered.
EXPECT_EQ(RemoveHlo(schedule->ThunkLaunchOrder(), {x, y, z}),
HloVec({add1, add2, add3}));
@@ -194,18 +196,18 @@ TEST_F(HloScheduleTest, SequentialAdd) {
}
// Test of two streams.
-TEST_F(HloScheduleTest, ConcurrentMatMul) {
+TEST_F(GpuHloScheduleTest, ConcurrentMatMul) {
HloComputation::Builder builder("entry_computation");
HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
- HloInstruction* dot1 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
- HloInstruction* dot2 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, y, x));
- HloInstruction* add = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, dot2));
+ HloInstruction* dot1 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
+ HloInstruction* dot2 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, y, x));
+ HloInstruction* add =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, dot2));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(add));
@@ -214,7 +216,7 @@ TEST_F(HloScheduleTest, ConcurrentMatMul) {
EXPECT_NE(streams->StreamNumberForHlo(*dot1),
streams->StreamNumberForHlo(*dot2));
- auto schedule = BuildHloSchedule(*module, *streams);
+ auto schedule = BuildGpuHloSchedule(*module, *streams);
// Remove parameters, which are unordered.
HloVec thunk_launch_order = RemoveHlo(schedule->ThunkLaunchOrder(), {x, y});
EXPECT_TRUE(thunk_launch_order == HloVec({dot1, dot2, add}) ||
@@ -250,7 +252,7 @@ TEST_F(HloScheduleTest, ConcurrentMatMul) {
}
// Test of multiple streams.
-TEST_F(HloScheduleTest, LatticeMatMul) {
+TEST_F(GpuHloScheduleTest, LatticeMatMul) {
// d00 -- layer 0
// / \
// d10 d11 -- layer 1
@@ -265,26 +267,26 @@ TEST_F(HloScheduleTest, LatticeMatMul) {
params.reserve(6);
for (int i = 0; i < 6; ++i) {
params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
- i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i))));
+ i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i))));
}
HloInstruction* d00 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3]));
- HloInstruction* d10 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00));
- HloInstruction* d11 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4]));
- HloInstruction* d20 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10));
- HloInstruction* d21 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11));
- HloInstruction* d22 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5]));
- HloInstruction* d30 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21));
- HloInstruction* d31 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22));
- HloInstruction* d40 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31));
+ CreateCanonicalDot(f32_2x2_, params[2], params[3]));
+ HloInstruction* d10 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[1], d00));
+ HloInstruction* d11 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d00, params[4]));
+ HloInstruction* d20 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[0], d10));
+ HloInstruction* d21 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d10, d11));
+ HloInstruction* d22 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d11, params[5]));
+ HloInstruction* d30 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d20, d21));
+ HloInstruction* d31 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d21, d22));
+ HloInstruction* d40 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(d40));
@@ -306,7 +308,7 @@ TEST_F(HloScheduleTest, LatticeMatMul) {
// We don't check the thunk launch order, since there are many valid total
// orders, and it's annoying to express.
- auto schedule = BuildHloSchedule(*module, *streams);
+ auto schedule = BuildGpuHloSchedule(*module, *streams);
auto order = schedule->ConsumeHloOrdering();
const HloVec all_params(
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc
index 4944c41f7d..4268fb2c7a 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc
@@ -34,9 +34,8 @@ StatusOr<bool> GpuHloSupportChecker::Run(HloModule* module) {
return xla::Unimplemented(
"GPU backend does not support HLO instruction %s with shape "
"containing a sparse layout: %s",
- instruction->ToString().c_str(),
- ShapeUtil::HumanStringWithLayout(instruction->shape())
- .c_str());
+ instruction->ToString(),
+ ShapeUtil::HumanStringWithLayout(instruction->shape()));
}
return Status::OK();
}));
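
The dropped .c_str() calls reflect XLA's switch from printf-style formatting to absl::StrFormat-backed error helpers, where %s binds to std::string or string_view directly and %d is width-agnostic for integral arguments. A standalone illustration in plain Abseil, independent of XLA:

#include <cstdint>
#include <string>
#include "absl/strings/str_format.h"

std::string FormatDemo() {
  const std::string name = "reduce.1";
  // No .c_str(): %s consumes std::string; %d accepts any integer width,
  // so the printf-era %lld distinction disappears.
  return absl::StrFormat("instruction %s needs %d bytes", name,
                         int64_t{4096});
}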
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
index d63e213d2b..bbb3340760 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
@@ -28,9 +28,7 @@ class GpuHloSupportChecker : public HloPassInterface {
GpuHloSupportChecker() = default;
~GpuHloSupportChecker() override = default;
- tensorflow::StringPiece name() const override {
- return "gpu_hlo_support_checker";
- }
+ absl::string_view name() const override { return "gpu_hlo_support_checker"; }
// Note: always returns false (no instructions are ever modified by this
// pass).
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc
index 0a4089df4c..27a4d0b601 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -25,7 +25,7 @@ namespace {
using ::testing::HasSubstr;
-class GpuHloSupportCheckerTest : public HloTestBase {
+class GpuHloSupportCheckerTest : public HloVerifiedTestBase {
protected:
GpuHloSupportChecker& checker() { return checker_; }
@@ -45,7 +45,7 @@ TEST_F(GpuHloSupportCheckerTest, Add) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK(checker().Run(module.get()).status());
+ TF_ASSERT_OK(checker().Run(module).status());
}
TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
@@ -60,7 +60,7 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- Status status = checker().Run(module.get()).status();
+ Status status = checker().Run(module).status();
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
EXPECT_THAT(status.error_message(),
HasSubstr("GPU backend does not support"));
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
index 286547ebae..fbc8ddf599 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
@@ -119,7 +120,7 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) {
for (const Shape& input_shape : AllLayoutsOf(shape)) {
for (const Shape& result_shape : AllLayoutsOf(shape)) {
- SCOPED_TRACE(tensorflow::strings::StrCat(
+ SCOPED_TRACE(absl::StrCat(
"input_shape=", ShapeUtil::HumanStringWithLayout(input_shape),
", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape)));
@@ -192,7 +193,7 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) {
// Enumerate all combinations of shapes.
for (const Shape& input_shape : AllLayoutsOf(shape)) {
for (const Shape& result_shape : AllLayoutsOf(shape)) {
- SCOPED_TRACE(tensorflow::strings::StrCat(
+ SCOPED_TRACE(absl::StrCat(
"input_shape=", ShapeUtil::HumanStringWithLayout(input_shape),
", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape)));
@@ -265,7 +266,7 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) {
for (const Shape& input_shape : AllLayoutsOf(shape)) {
for (const Shape& result_shape : AllLayoutsOf(shape)) {
for (int constrained_param_no : {0, 4}) {
- SCOPED_TRACE(tensorflow::strings::StrCat(
+ SCOPED_TRACE(absl::StrCat(
"input_shape=", ShapeUtil::HumanStringWithLayout(input_shape),
", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape)));
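
The SCOPED_TRACE edits are a like-for-like swap: absl::StrCat keeps the old variadic, type-mixing interface of tensorflow::strings::StrCat. For reference:

#include <string>
#include "absl/strings/str_cat.h"

// Mixes literals, std::string, and numbers; returns std::string.
std::string MakeTrace(const std::string& input, const std::string& result) {
  return absl::StrCat("input_shape=", input, ", result_shape=", result);
}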
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
index a2f53f8446..f3c2744292 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "llvm/IR/DataLayout.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
@@ -83,7 +84,7 @@ Status GpuTransferManager::EnqueueBuffersToInfeed(
Status block_status = stream->BlockHostUntilDone();
if (!block_status.ok()) {
return InternalError("Failed to complete data transfer on stream %p: %s",
- stream, block_status.error_message().c_str());
+ stream, block_status.error_message());
}
infeed_manager->EnqueueDestination(std::move(buffers));
@@ -96,7 +97,7 @@ Status GpuTransferManager::EnqueueBuffersToInfeed(
StatusOr<InfeedBuffer> GpuTransferManager::TransferBufferToInfeedInternal(
se::StreamExecutor* executor, int64 size, const void* source) {
if (size > std::numeric_limits<int32>::max()) {
- return InvalidArgument("Infeed shape is too large: needs %lld bytes", size);
+ return InvalidArgument("Infeed shape is too large: needs %d bytes", size);
}
if (size == 0) {
@@ -160,9 +161,10 @@ Status GpuTransferManager::TransferLiteralFromOutfeed(
if (ShapeUtil::IsTuple(shape)) {
return;
}
- *buffer = MakeUnique<gpu::OutfeedBuffer>(GetByteSizeRequirement(shape));
+ *buffer = absl::make_unique<gpu::OutfeedBuffer>(
+ GetByteSizeRequirement(shape));
(*buffer)->set_destination(
- MakeUnique<MutableBorrowingLiteral>(literal, index));
+ absl::make_unique<MutableBorrowingLiteral>(literal, index));
});
// Give the tree of buffers to the outfeed manager. The device will fill it
@@ -179,7 +181,7 @@ Status GpuTransferManager::TransferLiteralFromOutfeed(
} // namespace xla
static std::unique_ptr<xla::TransferManager> CreateNVPTXTransferManager() {
- return xla::MakeUnique<xla::gpu::GpuTransferManager>(
+ return absl::make_unique<xla::gpu::GpuTransferManager>(
/*id=*/stream_executor::cuda::kCudaPlatformId,
/*pointer_size=*/llvm::DataLayout(xla::gpu::NVPTXCompiler::kDataLayout)
.getPointerSize(0 /* default address space */));
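
The MakeUnique-to-absl::make_unique edits in this file are mechanical; both forward their arguments to new and return std::unique_ptr. A self-contained illustration (OutfeedSlot is a hypothetical stand-in type):

#include <cstdint>
#include <memory>
#include "absl/memory/memory.h"

struct OutfeedSlot {
  explicit OutfeedSlot(int64_t byte_size) : byte_size(byte_size) {}
  int64_t byte_size;
};

// Before: xla::MakeUnique<OutfeedSlot>(n)   (from the removed ptr_util.h)
// After:  absl::make_unique<OutfeedSlot>(n) -- identical semantics.
std::unique_ptr<OutfeedSlot> slot = absl::make_unique<OutfeedSlot>(int64_t{64});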
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
index 7929042869..fa88816bc8 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TRANSFER_MANAGER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TRANSFER_MANAGER_H_
#include <vector>
@@ -61,4 +61,4 @@ class GpuTransferManager : public GenericTransferManager {
} // namespace gpu
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TRANSFER_MANAGER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc
index 1722676930..b9c21e8edb 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -33,7 +34,7 @@ namespace gpu {
namespace {
void InitAndStartTimer(std::stack<std::unique_ptr<se::Timer>>* timers,
se::Stream* stream) {
- timers->push(MakeUnique<se::Timer>(stream->parent()));
+ timers->push(absl::make_unique<se::Timer>(stream->parent()));
stream->InitTimer(timers->top().get()).ThenStartTimer(timers->top().get());
}
@@ -115,7 +116,7 @@ HloExecutionProfiler::MakeScopedInstructionProfiler(
CHECK(hlo_instructions_.insert(hlo_instruction).second)
<< hlo_instruction->name();
}
- return MakeUnique<ScopedInstructionProfiler>(this, hlo_instruction);
+ return absl::make_unique<ScopedInstructionProfiler>(this, hlo_instruction);
}
} // namespace gpu
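
MakeScopedInstructionProfiler returns an RAII handle, so a call site brackets exactly the work it wants attributed to one instruction. A hedged sketch; the profiler, thunk, stream, and the precise ExecuteOnStream signature here are assumptions, not code from this patch:

// Sketch: attribute one thunk's execution window to its HLO. The handle
// starts a timer on construction and records the elapsed time for `hlo`
// when it goes out of scope.
{
  std::unique_ptr<ScopedInstructionProfiler> op_profiler =
      profiler->MakeScopedInstructionProfiler(hlo);
  TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream));
}  // Timing for `hlo` stops here.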
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
index 8c11cd0541..51627402b4 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
+#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
@@ -24,20 +25,18 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace gpu {
-using tensorflow::strings::StrAppend;
-using tensorflow::strings::StrCat;
+using absl::StrAppend;
+using absl::StrCat;
void HloToIrBindings::EmitBasePointersForHlos(
- tensorflow::gtl::ArraySlice<const HloInstruction*> io_hlos,
- tensorflow::gtl::ArraySlice<const HloInstruction*> non_io_hlos) {
+ absl::Span<const HloInstruction* const> io_hlos,
+ absl::Span<const HloInstruction* const> non_io_hlos) {
// I/O HLOs are bound to the arguments of the current IR function. I.e.,
//
// void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) {
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h
index eee40b0e91..c0edae530c 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h
+++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_map>
+#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace gpu {
@@ -45,8 +45,8 @@ class HloToIrBindings {
alias_analysis_(module, *buffer_assignment_, &b_->getContext()) {}
void EmitBasePointersForHlos(
- tensorflow::gtl::ArraySlice<const HloInstruction*> io_hlos,
- tensorflow::gtl::ArraySlice<const HloInstruction*> non_io_hlos);
+ absl::Span<const HloInstruction* const> io_hlos,
+ absl::Span<const HloInstruction* const> non_io_hlos);
// Rebinds the given HLO to the LLVM IR value that represent its address.
void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value,
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc
index c5f0cdf6cd..a4364b0deb 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
namespace xla {
namespace gpu {
@@ -24,7 +24,7 @@ se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) {
tensorflow::mutex_lock l(host_to_device_stream_mu_);
if (host_to_device_executor_ == nullptr) {
host_to_device_executor_ = executor;
- host_to_device_stream_ = MakeUnique<se::Stream>(executor);
+ host_to_device_stream_ = absl::make_unique<se::Stream>(executor);
host_to_device_stream_->Init();
}
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
index fee6d2af3b..8c3a026740 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
@@ -96,7 +96,7 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
Status block_status = stream->BlockHostUntilDone();
if (!block_status.ok()) {
return InternalError("Failed to complete data transfer on stream %p: %s",
- stream, block_status.error_message().c_str());
+ stream, block_status.error_message());
}
VLOG(2) << "Infeeding to GPU complete";
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
index 0f2c83aeb2..4d5d8e99f8 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
@@ -26,7 +27,7 @@ namespace gpu {
namespace {
-bool IsFusile(const HloInstruction& hlo) {
+bool IsFusible(const HloInstruction& hlo) {
// Don't fuse get-tuple-element on GPU: We can, but it's slower than not
// fusing. We never generate kernels for unfused GTEs. Instead, if an
// unfused GTE is an input to a kernel (including a fusion kernel), we
@@ -41,7 +42,7 @@ bool IsFusile(const HloInstruction& hlo) {
hlo.opcode() == HloOpcode::kDynamicUpdateSlice ||
hlo.opcode() == HloOpcode::kFusion ||
hlo.opcode() == HloOpcode::kGather ||
- hlo.opcode() == HloOpcode::kPad ||
+ hlo.opcode() == HloOpcode::kIota || hlo.opcode() == HloOpcode::kPad ||
hlo.opcode() == HloOpcode::kReduce ||
hlo.opcode() == HloOpcode::kReduceWindow ||
hlo.opcode() == HloOpcode::kReshape ||
@@ -221,6 +222,13 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
return false;
}
+ // Do not fuse into reduce input fusions if the resulting kernel would suffer
+ // from poor data locality (due to unfriendly input layouts).
+ if (IsInputFusibleReduction(*consumer) &&
+ !LayoutsAreReduceInputFusionFriendly(*producer, *consumer)) {
+ return false;
+ }
+
// We can't fuse library calls, so if a user of such an op could become a
// bitcast, leave it unfused. See `xla::InstructionFusion::ShouldFuse` for
// further rationale.
@@ -245,7 +253,7 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
return true;
}
- if (!IsFusile(*producer) || !IsFusile(*consumer) ||
+ if (!IsFusible(*producer) || !IsFusible(*consumer) ||
!InstructionFusion::ShouldFuse(consumer, operand_index)) {
return false;
}
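
With the locality guard in place, driving the pass is unchanged; a minimal sketch using the standard pass-pipeline machinery, inside a Status-returning function:

// Sketch: run GPU instruction fusion over a module. The new
// LayoutsAreReduceInputFusionFriendly check fires inside ShouldFuse, so
// callers need no changes to pick it up.
HloPassPipeline pipeline("gpu-fusion");
pipeline.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module.get()));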
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index 8d0522bd8f..96bfe0c12e 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
namespace op = xla::testing::opcode_matchers;
@@ -111,8 +112,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) {
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
- auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot(
- ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
+ auto dot1 = builder.AddInstruction(
+ CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1));
@@ -128,8 +129,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) {
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
- auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot(
- ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
+ auto dot1 = builder.AddInstruction(
+ CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1}));
@@ -171,6 +172,78 @@ TEST_F(InstructionFusionTest, BroadcastIntoReduce) {
op::Reduce(op::Broadcast(op::Constant()), op::Constant()));
}
+TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduce) {
+ auto module = ParseHloString(R"(
+ HloModule test_module
+
+ add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+ }
+
+ ENTRY entry {
+ p0 = f32[16,16,16,16]{3,2,1,0} parameter(0)
+ copy = f32[16,16,16,16]{0,1,2,3} copy(p0)
+ constant.1 = f32[] constant(0)
+ ROOT reduce = f32[16] reduce(copy, constant.1), dimensions={0,1,2}, to_apply=add
+ })")
+ .ValueOrDie();
+
+ EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
+}
+
+TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduceFusion) {
+ auto module = ParseHloString(R"(
+ HloModule test_module
+
+ add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+ }
+
+ fused_reduce {
+ p0.1 = f32[16,16,16,16]{0,1,2,3} parameter(0)
+ mul = f32[16,16,16,16]{0,1,2,3} multiply(p0.1, p0.1)
+ c0.1 = f32[] constant(0)
+ ROOT root = f32[] reduce(mul, c0.1), dimensions={0,1,2,3}, to_apply=add
+ }
+
+ ENTRY entry {
+ p0 = f32[16,16,16,16]{3,2,1,0} parameter(0)
+ copy = f32[16,16,16,16]{0,1,2,3} copy(p0)
+ fusion = f32[] fusion(copy), kind=kInput, calls=fused_reduce
+ ROOT root = (f32[]) tuple(fusion)
+ })")
+ .ValueOrDie();
+
+ EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
+}
+
+TEST_F(InstructionFusionTest, FuseLayoutChangingOpWithElementwise) {
+ auto module = ParseHloString(R"(
+ HloModule test_module
+ ENTRY entry {
+ p0 = f32[16,16,16,16]{3,2,1,0} parameter(0)
+ copy = f32[16,16,16,16]{0,1,2,3} copy(p0)
+ ROOT add = f32[16,16,16,16]{0,1,2,3} add(copy, copy)
+ })")
+ .ValueOrDie();
+
+ EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Fusion());
+ EXPECT_THAT(root->fused_expression_root(), op::Add(op::Copy(), op::Copy()));
+}
+
TEST_F(InstructionFusionTest, BitcastIntoAdd) {
auto module = ParseHloString(R"(
HloModule test_module
@@ -365,7 +438,7 @@ static StatusOr<const HloInstruction*> FindHloInstruction(
}
return NotFound(
"Computation '%s' does not contain an instruction with op code '%s'.",
- computation.name().c_str(), HloOpcodeString(op).c_str());
+ computation.name(), HloOpcodeString(op));
}
TEST_F(InstructionFusionTest, MultiOutputFusion) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index c349063c71..22f43bc08b 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -144,10 +145,12 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
IsCustomCallToDnnConvolution(hlo);
}
-static HloInstruction* CreateCudnnConv(
- const char* call_target, const Shape& shape, HloInstruction* lhs,
- HloInstruction* rhs, const Window& window,
- const ConvolutionDimensionNumbers& dnums) {
+static HloInstruction* CreateCudnnConv(const char* call_target,
+ const Shape& shape, HloInstruction* lhs,
+ HloInstruction* rhs,
+ const Window& window,
+ const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count) {
HloComputation* computation = lhs->parent();
// This call returns a tuple of (conv_result, scratch_memory), where
@@ -165,28 +168,34 @@ static HloInstruction* CreateCudnnConv(
HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target));
custom_call->set_window(window);
custom_call->set_convolution_dimension_numbers(dnums);
+ custom_call->set_feature_group_count(feature_group_count);
return custom_call;
}
-HloInstruction* CreateCudnnConvForward(
- const Shape& shape, HloInstruction* input, HloInstruction* kernel,
- const Window& window, const ConvolutionDimensionNumbers& dnums) {
+HloInstruction* CreateCudnnConvForward(const Shape& shape,
+ HloInstruction* input,
+ HloInstruction* kernel,
+ const Window& window,
+ const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count) {
return CreateCudnnConv(kCudnnConvForwardCallTarget, shape, input, kernel,
- window, dnums);
+ window, dnums, feature_group_count);
}
HloInstruction* CreateCudnnConvBackwardInput(
const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter,
- const Window& window, const ConvolutionDimensionNumbers& dnums) {
+ const Window& window, const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count) {
return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, shape, output,
- reverse_filter, window, dnums);
+ reverse_filter, window, dnums, feature_group_count);
}
HloInstruction* CreateCudnnConvBackwardFilter(
const Shape& shape, HloInstruction* input, HloInstruction* output,
- const Window& window, const ConvolutionDimensionNumbers& dnums) {
+ const Window& window, const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count) {
return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, shape, input,
- output, window, dnums);
+ output, window, dnums, feature_group_count);
}
bool IsReductionToVector(const HloInstruction& reduce) {
@@ -215,8 +224,8 @@ bool IsReductionToVector(const HloInstruction& reduce) {
// This emits a device-side call to
// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see
// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
-llvm::Value* EmitPrintf(tensorflow::StringPiece fmt,
- tensorflow::gtl::ArraySlice<llvm::Value*> arguments,
+llvm::Value* EmitPrintf(absl::string_view fmt,
+ absl::Span<llvm::Value* const> arguments,
llvm::IRBuilder<>* builder) {
std::vector<llvm::Type*> argument_types;
for (auto argument : arguments) {
@@ -279,5 +288,42 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
value->getType());
}
+Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call,
+ CudnnConvParams* params) {
+ TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
+ custom_call->backend_config<CudnnConvBackendConfig>());
+ const auto& target = custom_call->custom_call_target();
+ const auto& lhs_shape = custom_call->operand(0)->shape();
+ const auto& rhs_shape = custom_call->operand(1)->shape();
+ const auto& conv_result_shape = custom_call->shape().tuple_shapes(0);
+
+ params->window = &custom_call->window();
+ params->dnums = &custom_call->convolution_dimension_numbers();
+ params->feature_group_count = custom_call->feature_group_count();
+ params->algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc(
+ backend_config.algorithm(), backend_config.tensor_ops_enabled()));
+
+ if (target == kCudnnConvForwardCallTarget) {
+ params->kind = CudnnConvKind::kForward;
+ params->input_shape = &lhs_shape;
+ params->filter_shape = &rhs_shape;
+ params->output_shape = &conv_result_shape;
+ } else if (target == kCudnnConvBackwardInputCallTarget) {
+ params->kind = CudnnConvKind::kBackwardInput;
+ params->input_shape = &conv_result_shape;
+ params->filter_shape = &rhs_shape;
+ params->output_shape = &lhs_shape;
+ } else if (target == kCudnnConvBackwardFilterCallTarget) {
+ params->kind = CudnnConvKind::kBackwardFilter;
+ params->input_shape = &lhs_shape;
+ params->filter_shape = &conv_result_shape;
+ params->output_shape = &rhs_shape;
+ } else {
+ LOG(FATAL) << "Unexpected custom call target: "
+ << custom_call->custom_call_target();
+ }
+ return Status::OK();
+}
+
} // namespace gpu
} // namespace xla
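
A usage sketch for the new helper, assuming `instr` is already known to be one of the three cudnn convolution custom-call targets (Cast comes from hlo_casting_utils.h; the surrounding Status-returning context is assumed):

// Sketch: translate a cudnn convolution custom-call into runner params.
// Buffer pointers inside `params` are deliberately left unset, as the
// function's comment notes.
CudnnConvParams params;
const auto* custom_call = Cast<HloCustomCallInstruction>(instr);
TF_RETURN_IF_ERROR(PopulateCudnnConvParams(custom_call, &params));
// params.kind now distinguishes kForward / kBackwardInput / kBackwardFilter,
// with input/filter/output shapes assigned to match.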
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index 5d23a3d018..09c455cc1e 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -20,7 +20,9 @@ limitations under the License.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
// TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they
// don't belong in "ir_emission_utils".
@@ -109,15 +111,20 @@ bool IsCustomCallToDnnConvolution(const HloInstruction& hlo);
//
// The created cudnn call will use the default cudnn algorithm and no scratch
// space.
-HloInstruction* CreateCudnnConvForward(
- const Shape& shape, HloInstruction* input, HloInstruction* kernel,
- const Window& window, const ConvolutionDimensionNumbers& dnums);
+HloInstruction* CreateCudnnConvForward(const Shape& shape,
+ HloInstruction* input,
+ HloInstruction* kernel,
+ const Window& window,
+ const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count);
HloInstruction* CreateCudnnConvBackwardInput(
const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter,
- const Window& window, const ConvolutionDimensionNumbers& dnums);
+ const Window& window, const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count);
HloInstruction* CreateCudnnConvBackwardFilter(
const Shape& shape, HloInstruction* input, HloInstruction* output,
- const Window& window, const ConvolutionDimensionNumbers& dnums);
+ const Window& window, const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count);
// Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm
// or cuDNN convolution.
@@ -126,8 +133,8 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo);
bool IsReductionToVector(const HloInstruction& reduce);
// Emits call to "vprintf" with given format and arguments.
-llvm::Value* EmitPrintf(tensorflow::StringPiece fmt,
- tensorflow::gtl::ArraySlice<llvm::Value*> arguments,
+llvm::Value* EmitPrintf(absl::string_view fmt,
+ absl::Span<llvm::Value* const> arguments,
llvm::IRBuilder<>* builder);
// Emits code to shuffle data between threads of a warp. This has the same
@@ -143,6 +150,11 @@ llvm::Value* EmitPrintf(tensorflow::StringPiece fmt,
llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
llvm::IRBuilder<>* builder);
+// Populates params using custom_call, which must be a custom-call to a cudnn
+// convolution. Does not modify any buffers in the params.
+Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call,
+ CudnnConvParams* params);
+
} // namespace gpu
} // namespace xla
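
A hedged call-site sketch for the migrated EmitPrintf signature; `value` (an llvm::Value* of floating-point type) and `b` (an llvm::IRBuilder<>*) are assumed to exist in the emitter:

// The braced list converts to absl::Span<llvm::Value* const> and the
// literal to absl::string_view; this lowers to the device-side vprintf
// call described in ir_emission_utils.cc.
EmitPrintf("partial result = %f\n", {value}, b);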
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 6675dbd3f9..b7c37bcf3c 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
+#include "absl/algorithm/container.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
@@ -140,7 +141,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
Status IrEmitter::EmitCallToNestedComputation(
const HloComputation& nested_computation,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands, llvm::Value* output) {
+ absl::Span<llvm::Value* const> operands, llvm::Value* output) {
TF_RET_CHECK(nested_computation.num_parameters() > 0);
llvm::Function*& emitted_function =
computation_to_ir_function_[&nested_computation];
@@ -155,7 +156,7 @@ Status IrEmitter::EmitCallToNestedComputation(
std::vector<llvm::Value*> arguments(operands.begin(), operands.end());
arguments.push_back(output);
arguments.push_back(bindings_.GetTempBufferBase());
- b_.CreateCall(emitted_function, arguments);
+ Call(emitted_function, arguments);
return Status::OK();
}
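
The unqualified Call/Load/Store/AtomicRMW spellings appearing from here on point at thin wrappers that forward to the member IRBuilder, letting call sites drop the b_.Create prefix. A sketch of that wrapper pattern; the names and hook used by XLA's actual helper layer are assumptions:

// Sketch: CRTP mixin whose helpers forward to the wrapped llvm::IRBuilder<>.
template <typename Derived>
class BuilderShorthandSketch {
 protected:
  llvm::CallInst* Call(llvm::Value* callee,
                       llvm::ArrayRef<llvm::Value*> args) {
    return builder()->CreateCall(callee, args);
  }
  llvm::LoadInst* Load(llvm::Value* ptr, const llvm::Twine& name = "") {
    return builder()->CreateLoad(ptr, name);
  }

 private:
  llvm::IRBuilder<>* builder() {
    return static_cast<Derived*>(this)->builder_for_mixin();  // Hypothetical hook.
  }
};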
@@ -177,7 +178,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
computation.root_instruction()->shape().element_type();
bool is_atomic_integral = element_type == S32 || element_type == U32 ||
element_type == S64 || element_type == U64;
- llvm::Value* source = b_.CreateLoad(source_address, "source");
+ llvm::Value* source = Load(source_address, "source");
if (root_opcode == HloOpcode::kAdd) {
// NVPTX supports atomicAdd on F32 and integer types.
if (element_type == F32) {
@@ -189,8 +190,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
}
if (is_atomic_integral) {
// integral + integral
- b_.CreateAtomicRMW(llvm::AtomicRMWInst::Add, output_address, source,
- llvm::AtomicOrdering::SequentiallyConsistent);
+ AtomicRMW(llvm::AtomicRMWInst::Add, output_address, source,
+ llvm::AtomicOrdering::SequentiallyConsistent);
return true;
}
}
@@ -201,8 +202,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
auto opcode = primitive_util::IsSignedIntegralType(element_type)
? llvm::AtomicRMWInst::Max
: llvm::AtomicRMWInst::UMax;
- b_.CreateAtomicRMW(opcode, output_address, source,
- llvm::AtomicOrdering::SequentiallyConsistent);
+ AtomicRMW(opcode, output_address, source,
+ llvm::AtomicOrdering::SequentiallyConsistent);
return true;
}
@@ -211,8 +212,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
auto opcode = primitive_util::IsSignedIntegralType(element_type)
? llvm::AtomicRMWInst::Min
: llvm::AtomicRMWInst::UMin;
- b_.CreateAtomicRMW(opcode, output_address, source,
- llvm::AtomicOrdering::SequentiallyConsistent);
+ AtomicRMW(opcode, output_address, source,
+ llvm::AtomicOrdering::SequentiallyConsistent);
return true;
}
@@ -291,10 +292,10 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
// cas_old_output_address and cas_new_output_address point to the scratch
// memory where we store the old and new values for the repeated atomicCAS
// operations.
- llvm::Value* cas_old_output_address = b_.CreateAlloca(
- atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address");
- llvm::Value* cas_new_output_address = b_.CreateAlloca(
- atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address");
+ llvm::Value* cas_old_output_address =
+ Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address");
+ llvm::Value* cas_new_output_address =
+ Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address");
// Emit preparation code to the preheader.
llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock();
@@ -308,29 +309,26 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
CHECK_EQ((element_size % sizeof(char)), 0);
llvm::Type* address_int_type =
module_->getDataLayout().getIntPtrType(output_address_type);
- atomic_memory_address = b_.CreatePtrToInt(output_address, address_int_type);
+ atomic_memory_address = PtrToInt(output_address, address_int_type);
llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3);
- llvm::Value* offset = b_.CreateAnd(atomic_memory_address, mask);
+ llvm::Value* offset = And(atomic_memory_address, mask);
mask = llvm::ConstantInt::get(address_int_type, -4);
- atomic_memory_address = b_.CreateAnd(atomic_memory_address, mask);
+ atomic_memory_address = And(atomic_memory_address, mask);
atomic_memory_address =
- b_.CreateIntToPtr(atomic_memory_address, atomic_address_type);
- binop_output_address = b_.CreateAdd(
- b_.CreatePtrToInt(cas_new_output_address, address_int_type), offset);
+ IntToPtr(atomic_memory_address, atomic_address_type);
binop_output_address =
- b_.CreateIntToPtr(binop_output_address, element_address_type);
+ Add(PtrToInt(cas_new_output_address, address_int_type), offset);
+ binop_output_address = IntToPtr(binop_output_address, element_address_type);
} else {
- atomic_memory_address =
- b_.CreateBitCast(output_address, atomic_address_type);
+ atomic_memory_address = BitCast(output_address, atomic_address_type);
binop_output_address =
- b_.CreateBitCast(cas_new_output_address, element_address_type);
+ BitCast(cas_new_output_address, element_address_type);
}
// Use the value from the memory that atomicCAS operates on to initialize
// cas_old_output.
- llvm::Value* cas_old_output =
- b_.CreateLoad(atomic_memory_address, "cas_old_output");
- b_.CreateStore(cas_old_output, cas_old_output_address);
+ llvm::Value* cas_old_output = Load(atomic_memory_address, "cas_old_output");
+ Store(cas_old_output, cas_old_output_address);
llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock(
b_.GetInsertPoint(), "atomic_op_loop_exit");
@@ -343,32 +341,29 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
// Emit the body of the loop that repeatedly invokes atomicCAS.
//
// Use cas_old_output to initialize cas_new_output.
- cas_old_output = b_.CreateLoad(cas_old_output_address, "cas_old_output");
- b_.CreateStore(cas_old_output, cas_new_output_address);
+ cas_old_output = Load(cas_old_output_address, "cas_old_output");
+ Store(cas_old_output, cas_new_output_address);
// Emits code to calculate new_output = operation(old_output, source);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
computation, {binop_output_address, source_address},
binop_output_address));
- llvm::Value* cas_new_output =
- b_.CreateLoad(cas_new_output_address, "cas_new_output");
+ llvm::Value* cas_new_output = Load(cas_new_output_address, "cas_new_output");
// Emit code to perform the atomicCAS operation
// (cas_old_output, success) = atomicCAS(memory_address, cas_old_output,
// cas_new_output);
- llvm::Value* ret_value = b_.CreateAtomicCmpXchg(
- atomic_memory_address, cas_old_output, cas_new_output,
- llvm::AtomicOrdering::SequentiallyConsistent,
- llvm::AtomicOrdering::SequentiallyConsistent);
+ llvm::Value* ret_value =
+ AtomicCmpXchg(atomic_memory_address, cas_old_output, cas_new_output,
+ llvm::AtomicOrdering::SequentiallyConsistent,
+ llvm::AtomicOrdering::SequentiallyConsistent);
// Extract the memory value returned from atomicCAS and store it as
// cas_old_output.
- b_.CreateStore(b_.CreateExtractValue(ret_value, 0, "cas_old_output"),
- cas_old_output_address);
+ Store(ExtractValue(ret_value, 0, "cas_old_output"), cas_old_output_address);
// Extract the success bit returned from atomicCAS and generate a
// conditional branch on the success bit.
- b_.CreateCondBr(b_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb,
- loop_body_bb);
+ CondBr(ExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb);
// Set the insertion point to the exit basic block so that the caller of
// this method can continue emitting code to the right place.
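[Aside, not part of the change: the retry loop this hunk emits is the classic CAS pattern. A minimal CUDA sketch, assuming a float add stands in for the nested computation (the emitter accepts any two-parameter computation):

__device__ float AtomicOpViaCas(float* address, float source) {
  // atomicCAS compares raw bits, so view the destination as an int.
  int* address_as_int = reinterpret_cast<int*>(address);
  int old_bits = *address_as_int;  // plays the role of cas_old_output
  int assumed;
  do {
    assumed = old_bits;
    // cas_new_output = operation(cas_old_output, source)
    float new_value = __int_as_float(assumed) + source;
    old_bits = atomicCAS(address_as_int, assumed, __float_as_int(new_value));
  } while (assumed != old_bits);  // the "success" bit: retry until no one intervened
  return __int_as_float(old_bits);
}]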
@@ -383,8 +378,8 @@ Status IrEmitter::EmitAtomicOperationForNestedComputation(
// TODO(b/30258929): We only accept binary computations so far.
return Unimplemented(
"We only support atomic functions with exactly two parameters, but "
- "computation %s has %lld.",
- computation.name().c_str(), computation.num_parameters());
+ "computation %s has %d.",
+ computation.name(), computation.num_parameters());
}
if (MaybeEmitDirectAtomicOperation(computation, output_address,
@@ -471,10 +466,10 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
if (ShapeUtil::ElementIsComplex(lhs_shape)) {
auto value = MultiplyComplex(lhs_value, rhs_value, &b_);
result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType());
- result = b_.CreateInsertValue(result, value.first, {0});
- result = b_.CreateInsertValue(result, value.second, {1});
+ result = InsertValue(result, value.first, {0});
+ result = InsertValue(result, value.second, {1});
} else {
- result = b_.CreateFMul(lhs_value, rhs_value);
+ result = FMul(lhs_value, rhs_value);
}
target_array.EmitWriteArrayElement(/*index=*/element_index, result, &b_);
return Status::OK();
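[Aside: as a worked identity for the complex branch above, MultiplyComplex computes the standard complex product (a + bi)(c + di) = (ac - bd) + (ad + bc)i, with value.first holding the real part and value.second the imaginary part (see the FAdd accumulation in a later hunk, which adds value.first to the real component); the two InsertValue calls pack them into struct fields {0} and {1}.]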
@@ -518,7 +513,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
// We don't have to iterate over the batch dimensions in both arrays, so we
// simplify the loop nest of the rhs.
for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) {
- DCHECK(c_linear_search(dnums.lhs_batch_dimensions(), i));
+ DCHECK(absl::c_linear_search(dnums.lhs_batch_dimensions(), i));
rhs_index[i] = lhs_index[i];
}
@@ -558,21 +553,21 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
&*reduction_loop->GetBodyBasicBlock()->getFirstInsertionPt());
llvm::Value* lhs_element = lhs_array.EmitReadArrayElement(lhs_index, &b_);
llvm::Value* rhs_element = rhs_array.EmitReadArrayElement(rhs_index, &b_);
- llvm::Value* accum = b_.CreateLoad(accum_address);
+ llvm::Value* accum = Load(accum_address);
llvm::Value* updated_accum;
if (ShapeUtil::ElementIsComplex(lhs_shape)) {
auto value = MultiplyComplex(lhs_element, rhs_element, &b_);
llvm::Value* accum_real = Real(accum, &b_);
- llvm::Value* real_sum = b_.CreateFAdd(accum_real, value.first);
- updated_accum = b_.CreateInsertValue(accum, real_sum, {0});
+ llvm::Value* real_sum = FAdd(accum_real, value.first);
+ updated_accum = InsertValue(accum, real_sum, {0});
llvm::Value* accum_imag = Imag(accum, &b_);
- llvm::Value* imag_sum = b_.CreateFAdd(accum_imag, value.second);
- updated_accum = b_.CreateInsertValue(updated_accum, imag_sum, {1});
+ llvm::Value* imag_sum = FAdd(accum_imag, value.second);
+ updated_accum = InsertValue(updated_accum, imag_sum, {1});
} else {
- llvm::Value* product = b_.CreateFMul(lhs_element, rhs_element);
- updated_accum = b_.CreateFAdd(accum, product);
+ llvm::Value* product = FMul(lhs_element, rhs_element);
+ updated_accum = FAdd(accum, product);
}
- b_.CreateStore(updated_accum, accum_address);
+ Store(updated_accum, accum_address);
// After the reduction loop exits, store the accumulator into the target
// address. The index into the target address is the concatenation of the rhs
@@ -594,7 +589,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &b_);
target_array.EmitWriteArrayElement(
target_index,
- b_.CreateLoad(accum_address), // The value written to the target array.
+ Load(accum_address), // The value written to the target array.
&b_);
// Set the IR builder insert point to the exit basic block of the outer most
@@ -638,17 +633,16 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
}
auto arg = reduce->operand(0);
auto init_value = reduce->operand(1);
- tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+ absl::Span<const int64> dimensions(reduce->dimensions());
HloComputation* function = reduce->to_apply();
return EmitTargetElementLoop(
*reduce,
[=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
// Initialize an accumulator with init_value.
llvm::AllocaInst* accumulator_addr =
- b_.CreateAlloca(llvm_ir::PrimitiveTypeToIrType(
+ Alloca(llvm_ir::PrimitiveTypeToIrType(
reduce->shape().element_type(), module_));
- b_.CreateStore(b_.CreateLoad(GetBasePointer(*init_value)),
- accumulator_addr);
+ Store(Load(GetBasePointer(*init_value)), accumulator_addr);
// The enclosing loops go over all the target elements. Now we have to
// compute the actual target element. For this, we build a new loop nest
@@ -685,7 +679,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
*function, {accumulator_addr, input_address}, accumulator_addr));
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
- return b_.CreateLoad(accumulator_addr);
+ return Load(accumulator_addr);
});
}
@@ -752,14 +746,9 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
"to a cudnn CustomCall using CudnnBatchNormRewriter.");
}
-Status IrEmitter::HandleIota(HloInstruction*) {
- // TODO(b/64798317): implement iota on GPU.
- return Unimplemented("Iota is not implemented on GPU.");
-}
-
StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
const HloComputation& computation,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements) {
+ absl::Span<llvm::Value* const> parameter_elements) {
llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(
computation.root_instruction()->shape().element_type(), module_),
@@ -768,11 +757,26 @@ StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
for (llvm::Value* parameter_element : parameter_elements) {
parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
parameter_element->getType(), "parameter_buffer", &b_));
- b_.CreateStore(parameter_element, parameter_buffers.back());
+ Store(parameter_element, parameter_buffers.back());
}
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers,
return_buffer));
- return b_.CreateLoad(return_buffer);
+ return Load(return_buffer);
+}
+
+std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(
+ const HloInstruction& hlo) {
+ std::vector<llvm_ir::IrArray> output_arrays;
+ if (ShapeUtil::IsTuple(hlo.shape())) {
+ int64 num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
+ output_arrays.reserve(num_outputs);
+ for (int64 i = 0; i < num_outputs; ++i) {
+ output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
+ }
+ } else {
+ output_arrays.push_back(GetIrArray(hlo, hlo));
+ }
+ return output_arrays;
}
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index 561c683879..8805201480 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -22,6 +22,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
@@ -35,13 +37,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -64,7 +65,8 @@ namespace gpu {
// IrEmitterUnnested, but the code is generated using FusedIrEmitter, which is
// not a subclass of gpu::IrEmitter, and in fact is better understood as an IR
// generator generator. See comments on that class.
-class IrEmitter : public DfsHloVisitorWithDefault {
+class IrEmitter : public DfsHloVisitorWithDefault,
+ public IrBuilderMixin<IrEmitter> {
public:
IrEmitter(const IrEmitter&) = delete;
IrEmitter& operator=(const IrEmitter&) = delete;
@@ -95,10 +97,11 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleBatchNormInference(HloInstruction* batch_norm) override;
Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
Status HandleBatchNormGrad(HloInstruction* batch_norm) override;
- Status HandleIota(HloInstruction* iota) override;
Status FinishVisit(HloInstruction* root) override { return Status::OK(); }
+ llvm::IRBuilder<>* builder() { return &b_; }
+
protected:
// Constructs an IrEmitter with the given IrEmitter context.
// ir_emitter_context is owned by the caller and should outlive the IrEmitter
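[Aside on the mechanics of the new base class: IrBuilderMixin is a CRTP helper that lets IrEmitter write Load(x) instead of b_.CreateLoad(x), reaching the builder through the builder() accessor added above. A minimal sketch of the idea with two wrappers, an assumption about its shape rather than the actual ir_builder_mixin.h contents:

#include <utility>
#include "llvm/IR/IRBuilder.h"

template <typename Derived>
class IrBuilderMixin {
 protected:
  // Each wrapper perfect-forwards its arguments to the matching
  // IRBuilder::Create* call on the builder owned by the derived class.
  template <typename... Args>
  llvm::LoadInst* Load(Args&&... args) {
    return mixin_builder()->CreateLoad(std::forward<Args>(args)...);
  }
  template <typename... Args>
  llvm::StoreInst* Store(Args&&... args) {
    return mixin_builder()->CreateStore(std::forward<Args>(args)...);
  }

 private:
  llvm::IRBuilder<>* mixin_builder() {
    return static_cast<Derived*>(this)->builder();
  }
};]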
@@ -121,6 +124,12 @@ class IrEmitter : public DfsHloVisitorWithDefault {
llvm::Value* GetBasePointer(const HloInstruction& inst) const {
return bindings_.GetBasePointer(inst);
}
+
+ // Generates the IrArray for each output of an HLO instruction and returns
+ // a vector containing those IrArrays.
+ std::vector<llvm_ir::IrArray> ConstructIrArrayForOutputs(
+ const HloInstruction& hlo);
+
// A convenient helper for calling BufferAssignment::GetUniqueSlice.
BufferAllocation::Slice GetAllocationSlice(
const HloInstruction& hlo, const ShapeIndex& index = {}) const {
@@ -140,9 +149,9 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// Emits a call in IR to the given nested computation with the given operands
// and output. If no IR function has been previously emitted for the
// computation, also emits such a function.
- Status EmitCallToNestedComputation(
- const HloComputation& nested_computation,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands, llvm::Value* output);
+ Status EmitCallToNestedComputation(const HloComputation& nested_computation,
+ absl::Span<llvm::Value* const> operands,
+ llvm::Value* output);
// Emits an atomic operation that implements `nested_computation` in the
// sequentially consistent memory model. `output_address` and `source_address`
@@ -196,7 +205,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
StatusOr<llvm::Value*> ComputeNestedElement(
const HloComputation& computation,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements);
+ absl::Span<llvm::Value* const> parameter_elements);
// Emits an atomic operation that implements `nested_computation` in the
// sequentially consistent memory model. `output_address` and `source_address`
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
index 5c827e5f9c..66c65f6975 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
@@ -119,21 +119,11 @@ Status IrEmitterNested::EmitTargetElementLoop(
// For MOF we give the loop emitter an array for every output it should
// generate.
if (hlo.IsMultiOutputFusion()) {
- const int64 num_elems = ShapeUtil::TupleElementCount(hlo.shape());
- std::vector<llvm_ir::IrArray> target_arrays;
- target_arrays.reserve(num_elems);
- for (int64 i = 0; i != num_elems; ++i) {
- target_arrays.push_back(GetIrArray(hlo, hlo, {i}));
- }
+ std::vector<llvm_ir::IrArray> target_arrays =
+ ConstructIrArrayForOutputs(hlo);
TF_RETURN_IF_ERROR(
llvm_ir::LoopEmitter(element_generator, target_arrays, &b_).EmitLoop());
-
- std::vector<llvm::Value*> tuple_operand_ptrs;
- tuple_operand_ptrs.reserve(num_elems);
- for (const llvm_ir::IrArray& array : target_arrays) {
- tuple_operand_ptrs.push_back(array.GetBasePointer());
- }
- llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_);
+ llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_, module_);
return Status::OK();
}
return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_)
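[Aside: the new EmitTuple overload taking IrArrays subsumes exactly the pointer-collecting loop this hunk deletes, i.e. it is assumed to reduce to

  std::vector<llvm::Value*> tuple_operand_ptrs;
  for (const llvm_ir::IrArray& array : target_arrays) {
    tuple_operand_ptrs.push_back(array.GetBasePointer());
  }
  llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_);

so the nested emitter's behavior is unchanged.]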
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 1e81cbde35..b669881026 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -21,6 +21,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
+#include "absl/algorithm/container.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
@@ -29,7 +35,6 @@ limitations under the License.
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
@@ -56,6 +61,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -76,8 +82,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -85,13 +89,12 @@ namespace gpu {
namespace {
+using absl::InlinedVector;
+using absl::nullopt;
+using absl::optional;
+using absl::StrCat;
using llvm_ir::IrArray;
using llvm_ir::IrName;
-using tensorflow::gtl::ArraySlice;
-using tensorflow::gtl::InlinedVector;
-using tensorflow::gtl::nullopt;
-using tensorflow::gtl::optional;
-using tensorflow::strings::StrCat;
// If a dimension is smaller than this, untiled transposition may be more
// efficient.
@@ -173,7 +176,7 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) {
llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
const HloInstruction& inst,
- tensorflow::gtl::ArraySlice<const BufferAllocation*> args) {
+ absl::Span<const BufferAllocation* const> args) {
// Compute the kernel name. The opcode string may contain "-" which cannot be
// in a PTX function name, so sanitize the name before uniquifying it.
string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName(
@@ -314,13 +317,13 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size,
};
// Check the size of input tensors
- if (!c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) {
+ if (!absl::c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) {
return i64_ty;
}
// Check the size of the internal result tensors
if (unnested_hlo->opcode() == HloOpcode::kFusion) {
- if (!c_all_of(
+ if (!absl::c_all_of(
unnested_hlo->fused_instructions_computation()->instructions(),
hlo_shape_in_range)) {
return i64_ty;
@@ -383,7 +386,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
int64 feature_index_value = feature_index->literal().Get<int64>({});
thunk_sequence_->emplace_back(
- MakeUnique<CudnnBatchNormForwardInferenceThunk>(
+ absl::make_unique<CudnnBatchNormForwardInferenceThunk>(
/*operand=*/GetAllocationSlice(*custom_call->operand(0)),
/*scale=*/GetAllocationSlice(*custom_call->operand(1)),
/*offset=*/GetAllocationSlice(*custom_call->operand(2)),
@@ -413,7 +416,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
auto output_mean = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
auto output_inv_stddev = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie();
thunk_sequence_->emplace_back(
- MakeUnique<CudnnBatchNormForwardTrainingThunk>(
+ absl::make_unique<CudnnBatchNormForwardTrainingThunk>(
/*operand=*/GetAllocationSlice(*custom_call->operand(0)),
/*scale=*/GetAllocationSlice(*custom_call->operand(1)),
/*offset=*/GetAllocationSlice(*custom_call->operand(2)),
@@ -443,85 +446,54 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
auto output_grad_scale = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
auto output_grad_offset =
assn.GetUniqueSlice(custom_call, {2}).ValueOrDie();
- thunk_sequence_->emplace_back(MakeUnique<CudnnBatchNormBackwardThunk>(
- /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
- /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
- /*mean=*/GetAllocationSlice(*custom_call->operand(2)),
- /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)),
- /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)),
- /*epsilon=*/epsilon_value,
- /*feature_index=*/feature_index_value,
- /*output_grad_data=*/output_grad_data,
- /*output_grad_scale=*/output_grad_scale,
- /*output_grad_offset=*/output_grad_offset,
- /*output_tuple=*/GetAllocationSlice(*custom_call),
- /*hlo=*/custom_call));
+ thunk_sequence_->emplace_back(
+ absl::make_unique<CudnnBatchNormBackwardThunk>(
+ /*operand=*/GetAllocationSlice(*custom_call->operand(0)),
+ /*scale=*/GetAllocationSlice(*custom_call->operand(1)),
+ /*mean=*/GetAllocationSlice(*custom_call->operand(2)),
+ /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)),
+ /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)),
+ /*epsilon=*/epsilon_value,
+ /*feature_index=*/feature_index_value,
+ /*output_grad_data=*/output_grad_data,
+ /*output_grad_scale=*/output_grad_scale,
+ /*output_grad_offset=*/output_grad_offset,
+ /*output_tuple=*/GetAllocationSlice(*custom_call),
+ /*hlo=*/custom_call));
return Status::OK();
}
if (IsCustomCallToDnnConvolution(*custom_call)) {
const auto& assn = ir_emitter_context_->buffer_assignment();
- const auto& lhs_shape = custom_call->operand(0)->shape();
- const auto& rhs_shape = custom_call->operand(1)->shape();
- const auto& conv_result_shape = custom_call->shape().tuple_shapes(0);
auto lhs_slice = GetAllocationSlice(*custom_call->operand(0));
auto rhs_slice = GetAllocationSlice(*custom_call->operand(1));
auto tuple_result_slice = GetAllocationSlice(*custom_call);
auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
- TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
- custom_call->backend_config<CudnnConvBackendConfig>());
const auto& target = custom_call->custom_call_target();
- std::unique_ptr<ConvolutionThunk> thunk;
+ BufferAllocation::Slice input_slice, filter_slice, output_slice;
+
if (target == kCudnnConvForwardCallTarget) {
- thunk = MakeUnique<ConvolutionThunk>(
- CudnnConvKind::kForward,
- /*input_buffer=*/lhs_slice,
- /*filter_buffer=*/rhs_slice,
- /*output_buffer=*/conv_result_slice,
- /*tuple_result_buffer=*/tuple_result_slice,
- /*scratch_buffer=*/scratch_slice,
- /*input_shape=*/lhs_shape,
- /*filter_shape=*/rhs_shape,
- /*output_shape=*/conv_result_shape, //
- custom_call->window(), custom_call->convolution_dimension_numbers(),
- backend_config.algorithm(), backend_config.tensor_ops_enabled(),
- custom_call);
+ input_slice = lhs_slice;
+ filter_slice = rhs_slice;
+ output_slice = conv_result_slice;
} else if (target == kCudnnConvBackwardInputCallTarget) {
- thunk = MakeUnique<ConvolutionThunk>(
- CudnnConvKind::kBackwardInput,
- /*input_buffer=*/conv_result_slice,
- /*filter_buffer=*/rhs_slice,
- /*output_buffer=*/lhs_slice,
- /*tuple_result_buffer=*/tuple_result_slice,
- /*scratch_buffer=*/scratch_slice,
- /*input_shape=*/conv_result_shape,
- /*filter_shape=*/rhs_shape,
- /*output_shape=*/lhs_shape, //
- custom_call->window(), custom_call->convolution_dimension_numbers(),
- backend_config.algorithm(), backend_config.tensor_ops_enabled(),
- custom_call);
+ input_slice = conv_result_slice;
+ filter_slice = rhs_slice;
+ output_slice = lhs_slice;
} else if (target == kCudnnConvBackwardFilterCallTarget) {
- thunk = MakeUnique<ConvolutionThunk>(
- CudnnConvKind::kBackwardFilter,
- /*input_buffer=*/lhs_slice,
- /*filter_buffer=*/conv_result_slice,
- /*output_buffer=*/rhs_slice,
- /*tuple_result_buffer=*/tuple_result_slice,
- /*scratch_buffer=*/scratch_slice,
- /*input_shape=*/lhs_shape,
- /*filter_shape=*/conv_result_shape,
- /*output_shape=*/rhs_shape, //
- custom_call->window(), custom_call->convolution_dimension_numbers(),
- backend_config.algorithm(), backend_config.tensor_ops_enabled(),
- custom_call);
+ input_slice = lhs_slice;
+ filter_slice = conv_result_slice;
+ output_slice = rhs_slice;
} else {
LOG(FATAL) << "Unexpected custom call target: "
<< custom_call->custom_call_target();
}
- thunk_sequence_->emplace_back(std::move(thunk));
+ thunk_sequence_->emplace_back(absl::make_unique<ConvolutionThunk>(
+ Cast<HloCustomCallInstruction>(custom_call), input_slice, filter_slice,
+ output_slice, scratch_slice, tuple_result_slice));
return Status::OK();
}
@@ -552,10 +524,10 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
}
VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString();
std::vector<std::unique_ptr<Thunk>> thunks;
- ArraySlice<HloInstruction*> output_instructions =
+ absl::Span<HloInstruction* const> output_instructions =
root->opcode() == HloOpcode::kTuple
? root->operands()
- : ArraySlice<HloInstruction*>(&root, 1);
+ : absl::Span<HloInstruction* const>(&root, 1);
// For multi-output fusion emit an initializer for each tuple element.
// Otherwise it's sufficient to just initialize the single output.
@@ -576,7 +548,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
thunks.push_back(
BuildKernelThunk(fusion, /*implements_whole_instruction=*/false));
thunk_sequence_->emplace_back(
- MakeUnique<SequentialThunk>(std::move(thunks), fusion));
+ absl::make_unique<SequentialThunk>(std::move(thunks), fusion));
std::vector<IrArray> parameter_arrays;
for (HloInstruction* operand : fusion->operands()) {
parameter_arrays.push_back(GetIrArray(*operand, *fusion));
@@ -714,8 +686,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
Status IrEmitterUnnested::EmitExtraOutputsForReduce(
const HloInstruction* reduce, const IrArray::Index& index,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens) {
for (int i = 0; i != extra_output_gens.size(); ++i) {
const HloInstruction* output = reduce->parent()->FusionInstruction();
@@ -725,19 +696,18 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce(
"extra_output_element_address");
TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
extra_output_gens[i].first(index));
- b_.CreateStore(extra_output_ir_value, extra_output_address);
+ Store(extra_output_ir_value, extra_output_address);
}
return Status::OK();
}
Status IrEmitterUnnested::EmitReductionToScalar(
HloInstruction* reduce, const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens) {
// Number of elements processed by a single thread.
constexpr int64 kTileSize = 16;
@@ -798,8 +768,7 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// // RoundUpToNextMultipleOf(Ceil(num_elems / kTileSize), warpSize),
// //
// // and threads_per_block is a multiple of warpSize.
- // reduce_kernel<<<num_blocks, threads_per_block>>>();
- //
+ // reduce_kernel<<<num_blocks, threads_per_block>>>();
auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status {
const int num_reduces = reducers.size();
llvm::Type* element_ir_type =
@@ -807,17 +776,17 @@ Status IrEmitterUnnested::EmitReductionToScalar(
std::vector<llvm::Value*> partial_reduction_result_addresses;
for (int i = 0; i != num_reduces; ++i) {
llvm::Value* partial_reduction_result_address =
- b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr,
- "partial_reduction_result." + llvm::Twine(i));
+ Alloca(element_ir_type, /*ArraySize=*/nullptr,
+ "partial_reduction_result." + llvm::Twine(i));
TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
init_value_gens[i](IrArray::Index(index_ty)));
- b_.CreateStore(init_ir_value, partial_reduction_result_address);
+ Store(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
}
llvm::Value* x_in_tiles = tile_index[0];
- x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty);
+ x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty);
// Emit an inner for-loop that reduces the elements in the tile.
auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
@@ -829,15 +798,14 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// Emit the body of the partial reduction loop.
llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
&b_);
- llvm::Value* x = b_.CreateNSWAdd(
- b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)),
- tile_element_loop->GetIndVarValue());
+ llvm::Value* x =
+ NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileSize)),
+ tile_element_loop->GetIndVarValue());
// Unless we know the tile is entirely in bounds, we have to emit an
// x-in-bounds check before reading from the input.
if (!tile_in_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- b_.CreateICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds",
- &b_);
+ ICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", &b_);
// Emit code that reads the input element and accumulates it to
// the partial reduction result.
@@ -846,11 +814,11 @@ Status IrEmitterUnnested::EmitReductionToScalar(
IrArray::Index input_index(
/*linear=*/x, input_shape, &b_);
- llvm::Value* input_address = b_.CreateAlloca(element_ir_type);
+ llvm::Value* input_address = Alloca(element_ir_type);
for (int i = 0; i != num_reduces; ++i) {
TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
input_gens[i](input_index));
- b_.CreateStore(input_ir_value, input_address);
+ Store(input_ir_value, input_address);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], input_address},
@@ -861,14 +829,14 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's
// immediately beyond the tile.
- llvm::Value* x_end = b_.CreateNSWAdd(
- index_typed_constant(kTileSize),
- b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)));
+ llvm::Value* x_end =
+ NSWAdd(index_typed_constant(kTileSize),
+ NSWMul(x_in_tiles, index_typed_constant(kTileSize)));
// The tile is entirely in bounds if all_threads_in_bounds or
// x_end <= num_elems.
llvm::Value* tile_in_bounds =
- b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(num_elems)),
- b_.getInt1(all_threads_in_bounds));
+ Or(ICmpULE(x_end, index_typed_constant(num_elems)),
+ b_.getInt1(all_threads_in_bounds));
llvm_ir::LlvmIfData if_tile_in_bounds_data =
llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &b_);
llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, &b_);
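[Aside, a concrete instance of the predicate above, assuming num_elems = 100 and kTileSize = 16: tile x_in_tiles = 5 has x_end = 16 + 5 * 16 = 96 <= 100, so its body is emitted without per-element checks, while the last tile, x_in_tiles = 6, has x_end = 112 > 100 and gets the emitted x-in-bounds test.]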
@@ -889,20 +857,18 @@ Status IrEmitterUnnested::EmitReductionToScalar(
for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1;
shuffle_distance /= 2) {
llvm::Value* result_from_other_lane =
- b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane");
+ Alloca(element_ir_type, nullptr, "result_from_other_lane");
for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* partial_reduction_result = b_.CreateLoad(
- b_.CreateBitCast(partial_reduction_result_addresses[i],
- shuffle_ir_type->getPointerTo()),
- "partial_reduction_result");
+ llvm::Value* partial_reduction_result =
+ Load(BitCast(partial_reduction_result_addresses[i],
+ shuffle_ir_type->getPointerTo()),
+ "partial_reduction_result");
CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0)
<< "Requires block size a multiple of the warp size, otherwise we "
"will read undefined elements.";
- b_.CreateStore(
- EmitFullWarpShuffleDown(partial_reduction_result,
- b_.getInt32(shuffle_distance), &b_),
- b_.CreateBitCast(result_from_other_lane,
- shuffle_ir_type->getPointerTo()));
+ Store(EmitFullWarpShuffleDown(partial_reduction_result,
+ b_.getInt32(shuffle_distance), &b_),
+ BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo()));
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], result_from_other_lane},
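[Aside: the loop above is a standard intra-warp tree reduction; the BitCast round-trip through shuffle_ir_type evidently lets other element types reuse the same shuffle path. A CUDA sketch of the reduction, assuming a float add reducer (an illustration, not the emitted code):

__device__ float WarpReduceSum(float value) {
  // Each step folds in the value held by the lane `distance` above this one;
  // after the final step, lane 0 holds the reduction of the whole warp.
  for (int distance = warpSize / 2; distance >= 1; distance /= 2) {
    value += __shfl_down_sync(0xffffffffu, value, distance);
  }
  return value;
}]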
@@ -917,10 +883,9 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// lane 0 (which holds the partially accumulated result for its warp) to the
// output element.
llvm::Value* lane_id =
- b_.CreateURem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id");
+ URem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id");
llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
- b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero",
- &b_);
+ ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_);
llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_);
for (int i = 0; i != num_reduces; ++i) {
@@ -952,12 +917,11 @@ Status IrEmitterUnnested::EmitReductionToScalar(
Status IrEmitterUnnested::EmitColumnReduction(
int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens) {
// Divide the input matrix into tiles of size KxL. For example, when the
// input matrix is 4x4, K=2, and L=1 the tiled matrix looks like
@@ -1040,12 +1004,12 @@ Status IrEmitterUnnested::EmitColumnReduction(
for (int i = 0; i != num_reduces; ++i) {
for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
llvm::Value* partial_reduction_result_address =
- b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr,
- "partial_reduction_result." +
- llvm::Twine(i * kTileWidth + x_offset));
+ Alloca(element_ir_type, /*ArraySize=*/nullptr,
+ "partial_reduction_result." +
+ llvm::Twine(i * kTileWidth + x_offset));
TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
init_value_gens[i](IrArray::Index(index_ty)));
- b_.CreateStore(init_ir_value, partial_reduction_result_address);
+ Store(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
}
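[Aside on the indexing scheme this sets up: the flat vector holds kTileWidth partial results per reducer, so reducer i, column x_offset lives at slot i * kTileWidth + x_offset (used when the reducer is invoked further down). For example, with num_reduces = 2 and kTileWidth = 2 the layout is [r0c0, r0c1, r1c0, r1c1], and i = 1, x_offset = 1 indexes slot 3.]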
@@ -1056,8 +1020,8 @@ Status IrEmitterUnnested::EmitColumnReduction(
llvm::Value* y_in_tiles = tile_index[0];
llvm::Value* x_in_tiles = tile_index[1];
- y_in_tiles = b_.CreateZExtOrTrunc(y_in_tiles, index_ty);
- x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty);
+ y_in_tiles = ZExtOrTrunc(y_in_tiles, index_ty);
+ x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty);
auto emit_tile_element_loop = [=](bool tile_in_y_bounds,
bool tile_in_x_bounds) -> Status {
@@ -1069,34 +1033,32 @@ Status IrEmitterUnnested::EmitColumnReduction(
// Emit the body of the partial reduction loop.
llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
&b_);
- llvm::Value* y = b_.CreateNSWAdd(
- b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight)),
- tile_element_loop->GetIndVarValue());
+ llvm::Value* y =
+ NSWAdd(NSWMul(y_in_tiles, index_typed_constant(kTileHeight)),
+ tile_element_loop->GetIndVarValue());
// Unless we know that y is in bounds, we have to emit a check before
// reading from the input.
if (!tile_in_y_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- b_.CreateICmpULT(y, index_typed_constant(height)), "y_in_bounds",
- &b_);
+ ICmpULT(y, index_typed_constant(height)), "y_in_bounds", &b_);
// Emit code that reads the input element and accumulates it to
// the partial reduction result.
llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_);
}
for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
- llvm::Value* x = b_.CreateNSWAdd(
- b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)),
- index_typed_constant(x_offset));
+ llvm::Value* x =
+ NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)),
+ index_typed_constant(x_offset));
// Unless we know that x is in bounds, we have to emit a check before
// reading from the input.
if (!tile_in_x_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- b_.CreateICmpULT(x, index_typed_constant(width)), "x_in_bounds",
- &b_);
+ ICmpULT(x, index_typed_constant(width)), "x_in_bounds", &b_);
llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_);
}
- llvm::Value* input_address = b_.CreateAlloca(element_ir_type);
+ llvm::Value* input_address = Alloca(element_ir_type);
// {y,x} is an index to input_matrix_shape [height,width]. We need to
// convert that to an index to input_shape (the shape of the operand of
// "reduce"). This conversion is composed of a transposition from
@@ -1123,7 +1085,7 @@ Status IrEmitterUnnested::EmitColumnReduction(
for (int i = 0; i != num_reduces; ++i) {
TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
input_gens[i](input_index));
- b_.CreateStore(input_ir_value, input_address);
+ Store(input_ir_value, input_address);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i * kTileWidth + x_offset],
@@ -1138,20 +1100,20 @@ Status IrEmitterUnnested::EmitColumnReduction(
// y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location
// that's immediately beyond the tile.
- llvm::Value* y_end = b_.CreateNSWAdd(
- index_typed_constant(kTileHeight),
- b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight)));
+ llvm::Value* y_end =
+ NSWAdd(index_typed_constant(kTileHeight),
+ NSWMul(y_in_tiles, index_typed_constant(kTileHeight)));
// x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location
// that's immediately beyond the tile.
- llvm::Value* x_end = b_.CreateNSWAdd(
- index_typed_constant(kTileWidth),
- b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)));
+ llvm::Value* x_end =
+ NSWAdd(index_typed_constant(kTileWidth),
+ NSWMul(x_in_tiles, index_typed_constant(kTileWidth)));
llvm::Value* tile_in_y_bounds =
- b_.CreateOr(b_.CreateICmpULE(y_end, index_typed_constant(height)),
- b_.getInt1(height % kTileHeight == 0));
+ Or(ICmpULE(y_end, index_typed_constant(height)),
+ b_.getInt1(height % kTileHeight == 0));
llvm::Value* tile_in_x_bounds =
- b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(width)),
- b_.getInt1(width % kTileWidth == 0));
+ Or(ICmpULE(x_end, index_typed_constant(width)),
+ b_.getInt1(width % kTileWidth == 0));
// The tile is in y bounds if "height" is a multiple of kTileHeight or
// y_end <= height.
llvm_ir::LlvmIfData if_tile_in_y_bounds_data =
@@ -1185,9 +1147,9 @@ Status IrEmitterUnnested::EmitColumnReduction(
reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
for (int i = 0; i != num_reduces; ++i) {
for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
- llvm::Value* x = b_.CreateNSWAdd(
- b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)),
- index_typed_constant(x_offset));
+ llvm::Value* x =
+ NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)),
+ index_typed_constant(x_offset));
llvm::Value* output_address =
GetIrArray(*output, *output, reduce_output_shapes[i])
.EmitArrayElementAddress(
@@ -1243,12 +1205,11 @@ static std::pair<int64, int64> ComputeTilingSchemeForReduction(
Status IrEmitterUnnested::EmitRowReduction(
int64 depth, int64 height, int64 width, HloInstruction* reduce,
const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens) {
// A naive algorithm is:
// 1. Divide the x dimension of the input tensor into tiles of size 1x1xX.
@@ -1376,11 +1337,11 @@ Status IrEmitterUnnested::EmitRowReduction(
std::vector<llvm::Value*> partial_reduction_result_addresses;
for (int i = 0; i != num_reduces; ++i) {
llvm::Value* partial_reduction_result_address =
- b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr,
- "partial_reduction_result." + llvm::Twine(i));
+ Alloca(element_ir_type, /*ArraySize=*/nullptr,
+ "partial_reduction_result." + llvm::Twine(i));
TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
init_value_gens[i](IrArray::Index(index_ty)));
- b_.CreateStore(init_ir_value, partial_reduction_result_address);
+ Store(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
}
@@ -1389,22 +1350,20 @@ Status IrEmitterUnnested::EmitRowReduction(
llvm::Value* y = tile_index[1];
llvm::Value* x_tile = tile_index[2];
- x_tile = b_.CreateZExtOrTrunc(x_tile, index_ty);
+ x_tile = ZExtOrTrunc(x_tile, index_ty);
llvm::Value* warp_id =
- b_.CreateUDiv(x_tile, index_typed_constant(kWarpSize), "warp_id");
+ UDiv(x_tile, index_typed_constant(kWarpSize), "warp_id");
llvm::Value* lane_id =
- b_.CreateURem(x_tile, index_typed_constant(kWarpSize), "lane_id");
+ URem(x_tile, index_typed_constant(kWarpSize), "lane_id");
// The x-location of the last element in this z-x-tile.
// last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size);
- llvm::Value* last_x = b_.CreateNSWAdd(
+ llvm::Value* last_x = NSWAdd(
lane_id,
- b_.CreateNSWMul(
- index_typed_constant(kWarpSize),
- b_.CreateNSWAdd(
- index_typed_constant(x_tile_size - 1),
- b_.CreateNSWMul(warp_id, index_typed_constant(x_tile_size)))));
+ NSWMul(index_typed_constant(kWarpSize),
+ NSWAdd(index_typed_constant(x_tile_size - 1),
+ NSWMul(warp_id, index_typed_constant(x_tile_size)))));
KernelSupportLibrary ksl(
&b_,
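[Aside, to sanity-check the last_x formula above: with kWarpSize = 32, x_tile_size = 4, warp_id = 1, and lane_id = 5, last_x = 5 + 32 * (3 + 1 * 4) = 5 + 32 * 7 = 229, which matches the per-iteration formula x = lane_id + warpSize * (x_indvar + warp_id * x_tile_size) at its largest x_indvar of x_tile_size - 1.]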
@@ -1416,9 +1375,8 @@ Status IrEmitterUnnested::EmitRowReduction(
auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds,
int64 x_tile_loop_bound) -> Status {
auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status {
- llvm::Value* z = b_.CreateNSWAdd(
- z_indvar,
- b_.CreateNSWMul(index_typed_constant(z_tile_size), z_tile));
+ llvm::Value* z =
+ NSWAdd(z_indvar, NSWMul(index_typed_constant(z_tile_size), z_tile));
TF_RETURN_IF_ERROR(ksl.For(
"x_tile",
/*start=*/index_typed_constant(0),
@@ -1426,22 +1384,20 @@ Status IrEmitterUnnested::EmitRowReduction(
/*step=*/1, [&](llvm::Value* x_indvar) -> Status {
// x = lane_id +
// warpSize * (element_id_in_x_tile + warp_id * x_tile_size);
- llvm::Value* x = b_.CreateNSWAdd(
+ llvm::Value* x = NSWAdd(
lane_id,
- b_.CreateNSWMul(
- index_typed_constant(kWarpSize),
- b_.CreateNSWAdd(
- x_indvar, b_.CreateNSWMul(
- warp_id, llvm::ConstantInt::get(
- index_ty, x_tile_size)))));
+ NSWMul(index_typed_constant(kWarpSize),
+ NSWAdd(x_indvar,
+ NSWMul(warp_id, llvm::ConstantInt::get(
+ index_ty, x_tile_size)))));
// Unless we know the x-tile is entirely in bounds, we have to
// emit an x-in-bounds check before reading from the input.
if (!x_tile_in_bounds) {
llvm_ir::LlvmIfData if_x_in_bounds_data =
llvm_ir::EmitIfThenElse(
- b_.CreateICmpULT(x, index_typed_constant(width)),
- "x_in_bounds", &b_);
+ ICmpULT(x, index_typed_constant(width)), "x_in_bounds",
+ &b_);
// Points b_ to the then-block.
llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block,
&b_);
@@ -1449,7 +1405,7 @@ Status IrEmitterUnnested::EmitRowReduction(
// Emit code that reads the input element and accumulates it
// to the partial reduction result.
- llvm::Value* input_address = b_.CreateAlloca(element_ir_type);
+ llvm::Value* input_address = Alloca(element_ir_type);
{
// {z,y,x} is an index to input_3d_tensor_shape
// [depth,height,width]. We need to convert that to an index
@@ -1480,7 +1436,7 @@ Status IrEmitterUnnested::EmitRowReduction(
for (int i = 0; i != num_reduces; ++i) {
TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
input_gens[i](input_index));
- b_.CreateStore(input_ir_value, input_address);
+ Store(input_ir_value, input_address);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], input_address},
@@ -1500,8 +1456,8 @@ Status IrEmitterUnnested::EmitRowReduction(
};
llvm::Value* tile_in_bounds =
- b_.CreateOr(b_.getInt1(width % (x_tile_size * kWarpSize) == 0),
- b_.CreateICmpULT(last_x, index_typed_constant(width)));
+ Or(b_.getInt1(width % (x_tile_size * kWarpSize) == 0),
+ ICmpULT(last_x, index_typed_constant(width)));
TF_RETURN_IF_ERROR(
ksl.If(tile_in_bounds,
@@ -1529,20 +1485,18 @@ Status IrEmitterUnnested::EmitRowReduction(
for (int shuffle_distance = 16; shuffle_distance >= 1;
shuffle_distance /= 2) {
llvm::Value* result_from_other_lane =
- b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane");
+ Alloca(element_ir_type, nullptr, "result_from_other_lane");
for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* partial_reduction_result = b_.CreateLoad(
- b_.CreateBitCast(partial_reduction_result_addresses[i],
- shuffle_ir_type->getPointerTo()),
- "partial_reduction_result");
+ llvm::Value* partial_reduction_result =
+ Load(BitCast(partial_reduction_result_addresses[i],
+ shuffle_ir_type->getPointerTo()),
+ "partial_reduction_result");
CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0)
<< "Requires block size a multiple of the warp size, otherwise we "
"will read undefined elements.";
- b_.CreateStore(
- EmitFullWarpShuffleDown(partial_reduction_result,
- b_.getInt32(shuffle_distance), &b_),
- b_.CreateBitCast(result_from_other_lane,
- shuffle_ir_type->getPointerTo()));
+ Store(EmitFullWarpShuffleDown(partial_reduction_result,
+ b_.getInt32(shuffle_distance), &b_),
+ BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo()));
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], result_from_other_lane},
@@ -1557,8 +1511,7 @@ Status IrEmitterUnnested::EmitRowReduction(
// lane 0 (which holds the partially accumulated result for its warp) to the
// output element.
llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
- b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero",
- &b_);
+ ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_);
llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_);
for (int i = 0; i != num_reduces; ++i) {
llvm::Value* output_address =
@@ -1604,13 +1557,12 @@ Status IrEmitterUnnested::EmitRowReduction(
// elementwise.
Status IrEmitterUnnested::EmitReductionToVector(
HloInstruction* reduce, const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<const int64> dimensions_to_reduce,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens) {
// This emission requires "reduce" to have an input layout. It is either set
// by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for
@@ -1705,7 +1657,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
}
auto input = reduce->operand(0);
auto init_value = reduce->operand(1);
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce(reduce->dimensions());
+ absl::Span<const int64> dimensions_to_reduce(reduce->dimensions());
HloComputation* reducer = reduce->to_apply();
// HandleReduce specializes reduction from a multi-dimensional array to a 1D
// array. The specialized version requires an initializer thunk that
@@ -1718,7 +1670,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
thunks.push_back(
BuildKernelThunk(reduce, /*implements_whole_instruction=*/false));
thunk_sequence_->emplace_back(
- MakeUnique<SequentialThunk>(std::move(thunks), reduce));
+ absl::make_unique<SequentialThunk>(std::move(thunks), reduce));
return EmitReductionToVector(
reduce, input->shape(), {[&](const IrArray::Index& index) {
@@ -1738,7 +1690,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
bool all_tuple_elements_have_buffer =
- c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) {
+ absl::c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) {
return ir_emitter_context_->buffer_assignment()
.GetUniqueTopLevelSlice(tuple_element)
.ok();
@@ -1760,7 +1712,7 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
for (const HloInstruction* tuple_element : tuple->operands()) {
tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element));
}
- thunk_sequence_->emplace_back(MakeUnique<TupleThunk>(
+ thunk_sequence_->emplace_back(absl::make_unique<TupleThunk>(
tuple_element_buffers, GetAllocationSlice(*tuple), tuple));
return Status::OK();
}
@@ -1792,8 +1744,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
thunks.push_back(std::move(initializer_thunk));
thunks.push_back(BuildKernelThunk(select_and_scatter,
/*implements_whole_instruction=*/false));
- thunk_sequence_->emplace_back(
- MakeUnique<SequentialThunk>(std::move(thunks), select_and_scatter));
+ thunk_sequence_->emplace_back(absl::make_unique<SequentialThunk>(
+ std::move(thunks), select_and_scatter));
// TODO(b/31410564): Implement dilation rate for select-and-scatter.
if (window_util::HasDilation(window)) {
@@ -1842,7 +1794,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
&b_);
llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
b_.getInt1Ty(), "initialized_flag_address", &b_);
- b_.CreateStore(b_.getInt1(false), initialized_flag_address);
+ Store(b_.getInt1(false), initialized_flag_address);
// Create the inner loop to iterate over the window.
llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_,
@@ -1863,15 +1815,15 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
IrArray::Index operand_index(index_type, source_index.size());
llvm::Value* in_bounds_condition = b_.getInt1(true);
for (int64 i = 0; i < rank; ++i) {
- llvm::Value* strided_index = b_.CreateNSWMul(
+ llvm::Value* strided_index = NSWMul(
source_index[i], index_typed_constant(window.dimensions(i).stride()));
- operand_index[i] = b_.CreateNSWSub(
- b_.CreateNSWAdd(strided_index, window_index[i]),
- index_typed_constant(window.dimensions(i).padding_low()));
- llvm::Value* index_condition = b_.CreateICmpULT(
+ operand_index[i] =
+ NSWSub(NSWAdd(strided_index, window_index[i]),
+ index_typed_constant(window.dimensions(i).padding_low()));
+ llvm::Value* index_condition = ICmpULT(
operand_index[i],
index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i)));
- in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition);
+ in_bounds_condition = And(in_bounds_condition, index_condition);
}
CHECK(in_bounds_condition != nullptr);
@@ -1881,7 +1833,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
- b_.CreateLoad(initialized_flag_address), "initialized", &b_);
+ Load(initialized_flag_address), "initialized", &b_);
// If the initialized_flag is false, initialize the selected value and index
// with the currently visiting operand.
@@ -1889,16 +1841,16 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
const auto save_operand_index = [&](const IrArray::Index& operand_index) {
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot =
- b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)});
- b_.CreateStore(operand_index[i], selected_index_address_slot);
+ InBoundsGEP(selected_index_address, {b_.getInt32(i)});
+ Store(operand_index[i], selected_index_address_slot);
}
};
IrArray operand_array = GetIrArray(*operand, *select_and_scatter);
llvm::Value* operand_data =
operand_array.EmitReadArrayElement(operand_index, &b_);
- b_.CreateStore(operand_data, selected_value_address);
+ Store(operand_data, selected_value_address);
save_operand_index(operand_index);
- b_.CreateStore(b_.getInt1(true), initialized_flag_address);
+ Store(b_.getInt1(true), initialized_flag_address);
// If the initialized_flag is true, call the `select` function to
// potentially update the selected value and index with the currently
@@ -1914,11 +1866,11 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*select_and_scatter->select(),
{selected_value_address, operand_address}, select_return_buffer));
- llvm::Value* result = b_.CreateLoad(select_return_buffer);
+ llvm::Value* result = Load(select_return_buffer);
// If the 'select' function returns false, update the selected value and the
// index to the currently visiting operand.
- llvm::Value* cond = b_.CreateICmpNE(
+ llvm::Value* cond = ICmpNE(
result,
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(
PRED, ir_emitter_context_->llvm_module()),
@@ -1927,7 +1879,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
llvm_ir::LlvmIfData if_select_lhs =
llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
- b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address);
+ Store(Load(operand_address), selected_value_address);
save_operand_index(operand_index);
// After iterating over the window elements, scatter the source element to
@@ -1939,8 +1891,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
IrArray::Index selected_index(operand_index.GetType());
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot =
- b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)});
- selected_index.push_back(b_.CreateLoad(selected_index_address_slot));
+ InBoundsGEP(selected_index_address, {b_.getInt32(i)});
+ selected_index.push_back(Load(selected_index_address_slot));
}
llvm::Value* source_value_address =
GetIrArray(*source, *select_and_scatter)
@@ -2018,7 +1970,7 @@ Status IrEmitterUnnested::HandleRng(HloInstruction* rng) {
thunks.push_back(std::move(rng_thunk));
thunks.push_back(std::move(increment_seed_thunk));
thunk_sequence_->emplace_back(
- MakeUnique<SequentialThunk>(std::move(thunks), rng));
+ absl::make_unique<SequentialThunk>(std::move(thunks), rng));
return Status::OK();
}
@@ -2043,7 +1995,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
auto values_destination = GetAllocationSlice(*sort, values_shape_index);
if (keys_destination != GetAllocationSlice(*keys)) {
- thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+ thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
/*source_address=*/GetAllocationSlice(*keys),
/*destination_buffer=*/keys_destination,
/*mem_size=*/ShapeUtil::ByteSizeOf(keys->shape()), nullptr));
@@ -2051,7 +2003,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
if (values != nullptr && values_destination != GetAllocationSlice(*values)) {
// TODO(b/26783907): Figure out why we never seem to share buffers for
// key/value sort.
- thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+ thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
/*source_address=*/GetAllocationSlice(*values),
/*destination_buffer=*/values_destination,
/*mem_size=*/ShapeUtil::ByteSizeOf(values->shape()), nullptr));
@@ -2095,15 +2047,15 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
TF_RETURN_IF_ERROR(llvm_ir::EmitSortInPlace(
dimension_to_sort, GetIrArray(*sort, *sort, keys_shape_index),
- values != nullptr ? tensorflow::gtl::make_optional<IrArray>(
+ values != nullptr ? absl::make_optional<IrArray>(
GetIrArray(*sort, *sort, values_shape_index))
- : tensorflow::gtl::nullopt,
+ : absl::nullopt,
IrName(sort), xor_mask, &b_, &launch_dimensions));
}
}
thunk_sequence_->emplace_back(
- MakeUnique<SequentialThunk>(std::move(thunks), sort));
+ absl::make_unique<SequentialThunk>(std::move(thunks), sort));
return Status::OK();
}
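[Aside on xor_mask, an inference from the call pattern rather than anything documented in this hunk: EmitSortInPlace emits a sorting-network stage in which element i is compared with its partner i ^ xor_mask. A host-side C++ sketch of one such exchange stage, assuming an ascending comparison throughout (the real bitonic network varies the direction by stage; this is an illustration, not XLA code):

#include <cstdint>
#include <utility>
#include <vector>

void BitonicExchangeStage(std::vector<float>& keys, int64_t xor_mask) {
  for (int64_t i = 0; i < static_cast<int64_t>(keys.size()); ++i) {
    int64_t partner = i ^ xor_mask;
    // Visit each pair once and put it in order; on the GPU this is one
    // comparison per thread pair within a kernel launch.
    if (partner > i && partner < static_cast<int64_t>(keys.size()) &&
        keys[partner] < keys[i]) {
      std::swap(keys[i], keys[partner]);
    }
  }
}]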
@@ -2130,7 +2082,7 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) {
if (crs->operand_count() == 1) {
CHECK(ShapeUtil::IsArray(crs->operand(0)->shape()))
<< "Operands to cross-replica-sum must be arrays: " << crs->ToString();
- thunk_sequence_->push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+ thunk_sequence_->push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
/*source_address=*/GetAllocationSlice(*crs->operand(0)),
/*destination_buffer=*/GetAllocationSlice(*crs),
/*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs));
@@ -2145,17 +2097,17 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) {
tuple_element_buffers.push_back(ir_emitter_context_->buffer_assignment()
.GetUniqueSlice(crs, {i})
.ValueOrDie());
- thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+ thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
/*source_address=*/GetAllocationSlice(*crs->operand(i)),
/*destination_buffer=*/tuple_element_buffers.back(),
/*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr));
}
// Output a tuple of the buffers above.
- thunks.push_back(MakeUnique<TupleThunk>(tuple_element_buffers,
- GetAllocationSlice(*crs), nullptr));
+ thunks.push_back(absl::make_unique<TupleThunk>(
+ tuple_element_buffers, GetAllocationSlice(*crs), nullptr));
thunk_sequence_->push_back(
- MakeUnique<SequentialThunk>(std::move(thunks), crs));
+ absl::make_unique<SequentialThunk>(std::move(thunks), crs));
return Status::OK();
}
@@ -2305,7 +2257,7 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
for (const auto& kv : hlo_slices) {
buffers_needed.insert(kv.second.first.allocation());
}
- tensorflow::gtl::optional<const BufferAllocation*> temp_buffer;
+ absl::optional<const BufferAllocation*> temp_buffer;
for (const BufferAllocation& alloc : buffer_assn.Allocations()) {
if (alloc.IsPreallocatedTempBuffer()) {
if (!temp_buffer.has_value()) {
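
tensorflow::gtl::optional was by this point an alias for absl::optional, so spelling the absl type directly changes nothing semantically. A short sketch of the find-first pattern above, using a hypothetical Alloc in place of BufferAllocation:

  #include <vector>
  #include "absl/types/optional.h"

  struct Alloc { bool preallocated_temp; };  // hypothetical stand-in

  // Returns the first preallocated temp buffer, or nullptr if none exists.
  const Alloc* FindTempBuffer(const std::vector<Alloc>& allocations) {
    absl::optional<const Alloc*> temp;
    for (const Alloc& a : allocations) {
      if (a.preallocated_temp && !temp.has_value()) temp = &a;
    }
    return temp.value_or(nullptr);
  }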
@@ -2322,10 +2274,10 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
// We'll pass a pointer to each of the elements of `buffers` to our kernel, in
// this order.
std::vector<const BufferAllocation*> non_constant_buffers;
- c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers),
- [](const BufferAllocation* allocation) {
- return !allocation->is_constant();
- });
+ absl::c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers),
+ [](const BufferAllocation* allocation) {
+ return !allocation->is_constant();
+ });
std::sort(non_constant_buffers.begin(), non_constant_buffers.end(),
[](const BufferAllocation* a, const BufferAllocation* b) {
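
absl::c_copy_if (absl/algorithm/container.h) is the whole-container counterpart of std::copy_if, matching the removed XLA-local c_copy_if. An illustrative filter:

  #include <iterator>
  #include <vector>
  #include "absl/algorithm/container.h"

  // Keeps only the even values, preserving container order.
  std::vector<int> Evens(const std::vector<int>& xs) {
    std::vector<int> out;
    absl::c_copy_if(xs, std::back_inserter(out),
                    [](int x) { return x % 2 == 0; });
    return out;
  }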
@@ -2364,8 +2316,8 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
*slice.allocation())));
CHECK_NE(loc, nullptr);
} else {
- loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()),
- {b_.getInt64(slice.offset())});
+ loc = InBoundsGEP(kernel_args.at(slice.allocation()),
+ {b_.getInt64(slice.offset())});
}
// If gte_index is nonempty, we have to dereference `loc` to get to the
@@ -2373,8 +2325,8 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
llvm::Type* int8_double_pointer =
llvm::PointerType::get(b_.getInt8PtrTy(), /*AddressSpace=*/0);
for (int64 idx : gte_index) {
- loc = b_.CreateBitCast(loc, int8_double_pointer);
- loc = b_.CreateLoad(b_.CreateInBoundsGEP(loc, {b_.getInt64(idx)}));
+ loc = BitCast(loc, int8_double_pointer);
+ loc = Load(InBoundsGEP(loc, {b_.getInt64(idx)}));
}
bindings_.BindHloToIrValue(*instr, loc, index);
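
The newly unprefixed InBoundsGEP, BitCast, and Load calls still target the same b_ builder, presumably routed through a mixin-style helper so emitters can drop the b_.CreateXxx boilerplate. A rough sketch of how such a forwarding mixin can look (the class name and the builder() hook are illustrative, not the exact XLA API):

  #include "llvm/ADT/ArrayRef.h"
  #include "llvm/IR/IRBuilder.h"

  // CRTP mixin: Derived must expose an IRBuilder via builder().
  template <typename Derived>
  class BuilderMixin {
   public:
    llvm::LoadInst* Load(llvm::Value* ptr) { return b()->CreateLoad(ptr); }
    llvm::Value* BitCast(llvm::Value* v, llvm::Type* ty) {
      return b()->CreateBitCast(v, ty);
    }
    llvm::Value* InBoundsGEP(llvm::Value* ptr,
                             llvm::ArrayRef<llvm::Value*> idx) {
      return b()->CreateInBoundsGEP(ptr, idx);
    }
   private:
    llvm::IRBuilder<>* b() { return static_cast<Derived*>(this)->builder(); }
  };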
@@ -2389,7 +2341,7 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
llvm::ConstantPointerNull::get(b_.getInt8PtrTy()));
}
- return MakeUnique<KernelThunk>(
+ return absl::make_unique<KernelThunk>(
non_constant_buffers, llvm_ir::AsString(kernel->getName()),
implements_whole_instruction ? inst : nullptr, unroll_factor);
}
@@ -2398,7 +2350,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk(
const HloInstruction* inst) {
const HloInstruction* operand = inst->operand(0);
CHECK_EQ(HloOpcode::kConstant, operand->opcode());
- return MakeUnique<HostToDeviceCopyThunk>(
+ return absl::make_unique<HostToDeviceCopyThunk>(
/*source_address=*/operand->literal().untyped_data(),
/*destination_buffer=*/GetAllocationSlice(*inst),
/*mem_size=*/
@@ -2410,7 +2362,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk(
std::unique_ptr<Thunk> IrEmitterUnnested::BuildDeviceToDeviceCopyThunk(
const HloInstruction* inst) {
const HloInstruction* operand = inst->operand(0);
- return MakeUnique<DeviceToDeviceCopyThunk>(
+ return absl::make_unique<DeviceToDeviceCopyThunk>(
/*source_address=*/GetAllocationSlice(*operand),
/*destination_buffer=*/GetAllocationSlice(*inst),
/*mem_size=*/
@@ -2430,7 +2382,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk(
.GetUniqueSlice(inst, index)
.ConsumeValueOrDie();
});
- return MakeUnique<InfeedThunk>(slices, inst);
+ return absl::make_unique<InfeedThunk>(slices, inst);
}
std::unique_ptr<Thunk> IrEmitterUnnested::BuildOutfeedThunk(
@@ -2447,7 +2399,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildOutfeedThunk(
*slice = status_or_slice.ConsumeValueOrDie();
}
});
- return MakeUnique<OutfeedThunk>(std::move(slices), inst);
+ return absl::make_unique<OutfeedThunk>(std::move(slices), inst);
}
namespace {
@@ -2470,7 +2422,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
if (inst->opcode() == HloOpcode::kDot) {
const HloInstruction* lhs = inst->operand(0);
const HloInstruction* rhs = inst->operand(1);
- return MakeUnique<GemmThunk>(
+ return absl::make_unique<GemmThunk>(
GetAllocationSlice(*lhs), // The buffer assigned to LHS.
GetAllocationSlice(*rhs), // The buffer assigned to RHS.
GetAllocationSlice(*inst), // The output buffer.
@@ -2512,7 +2464,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
const HloInstruction* rhs =
inst->operand(rhs_parameter->parameter_number());
- return MakeUnique<GemmThunk>(
+ return absl::make_unique<GemmThunk>(
GetAllocationSlice(*lhs), // The buffer assigned to LHS.
GetAllocationSlice(*rhs), // The buffer assigned to RHS.
GetAllocationSlice(*inst), // The output buffer.
@@ -2529,23 +2481,24 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
std::unique_ptr<Thunk> IrEmitterUnnested::BuildFftThunk(
const HloInstruction* inst) {
const HloInstruction* operand = inst->operand(0);
- return MakeUnique<FftThunk>(inst->fft_type(), inst->fft_length(),
- /*input_buffer=*/GetAllocationSlice(*operand),
- /*output_buffer=*/GetAllocationSlice(*inst),
- /*input_shape=*/operand->shape(),
- /*output_shape=*/inst->shape(), inst);
+ return absl::make_unique<FftThunk>(
+ inst->fft_type(), inst->fft_length(),
+ /*input_buffer=*/GetAllocationSlice(*operand),
+ /*output_buffer=*/GetAllocationSlice(*inst),
+ /*input_shape=*/operand->shape(),
+ /*output_shape=*/inst->shape(), inst);
}
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
- const HloInstruction* hlo, const ShapeIndex& index) {
+ HloInstruction* hlo, const ShapeIndex& index) {
bool fused = HloOpcode::kFusion == hlo->opcode();
- const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo;
- const HloInstruction* init_value_operand = [&] {
+ HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo;
+ HloInstruction* init_value_operand = [&] {
switch (inst->opcode()) {
case HloOpcode::kSelectAndScatter:
- return inst->operand(2);
+ return inst->mutable_operand(2);
case HloOpcode::kReduce:
- return inst->operand(1);
+ return inst->mutable_operand(1);
case HloOpcode::kTuple:
CHECK(hlo->IsMultiOutputFusion())
<< ": " << hlo->ToString() << " is not a multi-output fusion.";
@@ -2553,7 +2506,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
<< ": Found '" << inst->operand(index.back())->opcode() << "' in "
<< inst->ToString() << " but expected 'reduce'.";
// For multi-output fusion look through the tuple.
- return inst->operand(index.back())->operand(1);
+ return inst->mutable_operand(index.back())->mutable_operand(1);
default:
LOG(FATAL) << "Opcode " << inst->opcode()
<< " should not need an initializer.";
@@ -2580,11 +2533,11 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
// Are all the bytes of this scalar equal to 0? If so, we can create a
// MemzeroThunk.
- ArraySlice<uint8> literal_bytes(
+ absl::Span<const uint8> literal_bytes(
reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes);
- if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
- return {
- MakeUnique<MemzeroThunk>(GetAllocationSlice(*hlo, index), nullptr)};
+ if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
+ return {absl::make_unique<MemzeroThunk>(GetAllocationSlice(*hlo, index),
+ nullptr)};
}
// If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
@@ -2601,7 +2554,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
memcpy(&pattern16, literal_bytes.data(), sizeof(pattern16));
}
uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
- return {MakeUnique<Memset32BitValueThunk>(
+ return {absl::make_unique<Memset32BitValueThunk>(
pattern32, GetAllocationSlice(*hlo, index), nullptr)};
}
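
Two small idioms above are worth spelling out: absl::Span<const uint8> gives a non-owning view over the literal's bytes for the all-zero (MemzeroThunk) test, and an 8- or 16-bit pattern can be splatted into a 32-bit memset word because every aligned word then carries the identical bit pattern (0xABCD becomes 0xABCDABCD). A self-contained sketch of both:

  #include <cstddef>
  #include <cstdint>
  #include "absl/algorithm/container.h"
  #include "absl/types/span.h"

  // All-zero check over a non-owning byte view.
  bool AllBytesZero(const void* data, size_t n) {
    absl::Span<const uint8_t> bytes(static_cast<const uint8_t*>(data), n);
    return absl::c_all_of(bytes, [](uint8_t b) { return b == 0; });
  }

  // Splat a 16-bit pattern into a 32-bit memset value: 0xABCD -> 0xABCDABCD.
  uint32_t Splat16To32(uint16_t pattern16) {
    return uint32_t{pattern16} | (uint32_t{pattern16} << 16);
  }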
@@ -2612,7 +2565,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
literal_bytes.size() - 4) == 0) {
uint32 word;
memcpy(&word, literal_bytes.data(), sizeof(word));
- return {MakeUnique<Memset32BitValueThunk>(
+ return {absl::make_unique<Memset32BitValueThunk>(
word, GetAllocationSlice(*hlo, index), nullptr)};
}
}
@@ -2625,28 +2578,35 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
ir_emitter_context_->device_description());
UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
ir_emitter_context_->llvm_module());
- // If the init_value was fused into this reduce we have to generate it first.
- if (fused && init_value_operand->opcode() != HloOpcode::kParameter) {
- CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode());
- const Literal& literal = init_value_operand->literal();
- llvm::Constant* initializer =
- llvm_ir::ConvertLiteralToIrConstant(literal, module_);
+ if (fused) {
+ // If init_value was fused into this reduce we have to generate it first.
+ std::vector<IrArray> parameter_arrays;
+ for (HloInstruction* operand : hlo->operands()) {
+ parameter_arrays.push_back(GetIrArray(*operand, *hlo));
+ }
+ GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
+ ir_emitter_context_->llvm_module(),
+ &b_, GetNestedComputer());
- llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
- *module_, initializer->getType(),
- /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer,
- /*Name=*/"");
- global_for_const->setAlignment(kConstantBufferAlignBytes);
- bindings_.BindHloToIrValue(*init_value_operand, global_for_const);
+ FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
+ TF_RETURN_IF_ERROR(init_value_operand->Accept(&fused_emitter));
+ TF_RETURN_IF_ERROR(
+ ParallelLoopEmitter(fused_emitter.GetGenerator(init_value_operand),
+ GetIrArray(*hlo, *hlo, index), launch_dimensions,
+ &b_)
+ .EmitLoop(IrName(hlo)));
+ } else {
+ // In the unfused case the element is already there, just read from it.
+ TF_RETURN_IF_ERROR(ParallelLoopEmitter(
+ [=](const IrArray::Index& index) {
+ return GetIrArray(*init_value, *hlo)
+ .EmitReadArrayElement(index, &b_);
+ },
+ GetIrArray(*hlo, *hlo, index), launch_dimensions,
+ &b_)
+ .EmitLoop(IrName(hlo)));
}
- TF_RETURN_IF_ERROR(ParallelLoopEmitter(
- [=](const IrArray::Index& index) {
- return GetIrArray(*init_value, *hlo)
- .EmitReadArrayElement(index, &b_);
- },
- GetIrArray(*hlo, *hlo, index), launch_dimensions, &b_)
- .EmitLoop(IrName(hlo)));
// Clean up state left behind by emitting the loop above. (This is normally
// done in IrEmitterUnnested::Postprocess().)
@@ -2670,8 +2630,7 @@ Status CheckHloBuffersShareAllocation(
if (slice_a != slice_b) {
return InternalError(
"instruction %s %s does not share allocation with instruction %s %s",
- a->ToString().c_str(), slice_a.ToString().c_str(),
- b->ToString().c_str(), slice_b.ToString().c_str());
+ a->ToString(), slice_a.ToString(), b->ToString(), slice_b.ToString());
}
return Status::OK();
}
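
Dropping the .c_str() calls works because the error constructors now use absl-style formatting, where %s binds std::string (and string_view) directly instead of going through C varargs. A sketch of the equivalent formatting, assuming absl::StrFormat semantics:

  #include <string>
  #include "absl/strings/str_format.h"

  std::string ShareError(const std::string& a, const std::string& b) {
    return absl::StrFormat(
        "instruction %s does not share allocation with instruction %s", a, b);
  }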
@@ -2764,7 +2723,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk(
ir_emitter_context_);
TF_CHECK_OK(body->Accept(&ir_emitter_body));
- return MakeUnique<WhileThunk>(
+ return absl::make_unique<WhileThunk>(
GetAllocationSlice(*condition->root_instruction()), // cond result
ir_emitter_condition.ConsumeThunkSequence(),
ir_emitter_body.ConsumeThunkSequence(), hlo);
@@ -2782,8 +2741,8 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
ir_emitter_context_);
TF_CHECK_OK(body->Accept(&ir_emitter_body));
- return MakeUnique<ForThunk>(loop_limit,
- ir_emitter_body.ConsumeThunkSequence(), hlo);
+ return absl::make_unique<ForThunk>(
+ loop_limit, ir_emitter_body.ConsumeThunkSequence(), hlo);
}
std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
@@ -2803,7 +2762,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
ir_emitter_context_);
TF_CHECK_OK(false_computation->Accept(&ir_emitter_false));
- return MakeUnique<ConditionalThunk>(
+ return absl::make_unique<ConditionalThunk>(
GetAllocationSlice(*hlo->operand(0)),
GetAllocationSlice(*hlo->operand(1)),
GetAllocationSlice(*hlo->operand(2)),
@@ -2836,10 +2795,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
}
// For multioutput fusion, we need to emit each operand and the root.
- std::vector<IrArray> output_arrays;
- for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) {
- output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
- }
+ std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(hlo);
TF_RETURN_IF_ERROR(
ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions,
&b_, unroll_factor)
@@ -2847,12 +2803,9 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
GetIndexTypeForKernel(
&hlo, launch_dimensions.launch_bound(), &b_)));
- std::vector<llvm::Value*> tuple_operand_ptrs;
- for (int64 i = 0; i < output_arrays.size(); ++i) {
- tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
- }
b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
- llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_);
+ llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_);
+
return Status::OK();
}
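
llvm_ir::EmitTuple evidently gained an overload taking the output IrArrays themselves, so the hand-rolled base-pointer loop disappears from every call site. Illustratively (ArrayLike is a hypothetical stand-in for llvm_ir::IrArray), the extraction such an overload would perform internally looks like:

  #include <vector>
  #include "absl/types/span.h"

  struct ArrayLike {  // hypothetical stand-in
    void* base;
    void* GetBasePointer() const { return base; }
  };

  std::vector<void*> BasePointers(absl::Span<const ArrayLike> arrays) {
    std::vector<void*> ptrs;
    ptrs.reserve(arrays.size());
    for (const ArrayLike& a : arrays) ptrs.push_back(a.GetBasePointer());
    return ptrs;
  }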
@@ -2864,34 +2817,19 @@ Status IrEmitterUnnested::EmitTargetElementLoop(
static_cast<KernelThunk*>(LastThunk()));
}
-int IrEmitterUnnested::ConstructIrArrayForOutputs(
- const HloInstruction& hlo, std::vector<IrArray>* output_arrays) {
- int64 num_outputs = 1;
- if (hlo.IsMultiOutputFusion()) {
- num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
- output_arrays->reserve(num_outputs);
- for (int64 i = 0; i < num_outputs; ++i) {
- output_arrays->push_back(GetIrArray(hlo, hlo, {i}));
- }
- } else {
- output_arrays->push_back(GetIrArray(hlo, hlo));
- }
- return num_outputs;
-}
-
-int IrEmitterUnnested::ConstructIrArrayForInputs(
- const HloInstruction& hlo, std::vector<IrArray>* param_arrays) {
- int64 num_params = hlo.operands().size();
- param_arrays->reserve(num_params);
+std::vector<IrArray> IrEmitterUnnested::ConstructIrArrayForInputs(
+ const HloInstruction& hlo) {
+ std::vector<IrArray> param_arrays;
+ param_arrays.reserve(hlo.operands().size());
for (const HloInstruction* param : hlo.operands()) {
- param_arrays->push_back(GetIrArray(*param, hlo));
+ param_arrays.push_back(GetIrArray(*param, hlo));
}
- return num_params;
+ return param_arrays;
}
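
Returning the vector by value, instead of filling an out-parameter and returning a count, loses nothing: the return is elided or moved, and callers recover the count from size(), as EmitHlo021Tile now does. A minimal sketch of the new shape:

  #include <vector>

  std::vector<int> ConstructInputs(int n) {
    std::vector<int> v;
    v.reserve(n);
    for (int i = 0; i < n; ++i) v.push_back(i);
    return v;  // elided or moved; callers use v.size() for the count
  }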
int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
const HloInstruction& hlo, const std::vector<IrArray>& output_arrays,
- tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ absl::Span<const int64> reduced_output_dims,
std::vector<Shape>* output_reduced_shapes,
std::vector<IrArray>* output_in_reduced_shape_arrays) {
int64 num_outputs = 1;
@@ -2918,7 +2856,7 @@ int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape(
const HloInstruction& hlo, const std::vector<IrArray>& param_arrays,
const std::vector<llvm::Value*>& param_buffers,
- tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ absl::Span<const int64> reduced_output_dims,
std::vector<Shape>* param_reduced_shapes,
std::vector<IrArray>* param_in_reduced_shape_arrays) {
int64 num_params = hlo.operands().size();
@@ -3059,18 +2997,18 @@ void EmitTiledElementalCodeWithBoundsCheck(
// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient
// to launch fewer blocks so each transposes many tiles.
LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
- HloInstruction* hlo, tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
- tensorflow::gtl::ArraySlice<int64> tiled_param_ids) {
+ HloInstruction* hlo, absl::Span<const int64> reduced_output_dims,
+ absl::Span<const int64> tiled_param_ids) {
// Parameters for the tiling algorithm.
constexpr int64 kTileSize = 32;
constexpr int64 kNumRows = 4;
constexpr int64 kThreadsPerTile = kTileSize * kNumRows;
// Construct IrArrays for the inputs and outputs.
- std::vector<IrArray> output_arrays;
- int64 num_outputs = ConstructIrArrayForOutputs(*hlo, &output_arrays);
- std::vector<IrArray> param_arrays;
- int64 num_params = ConstructIrArrayForInputs(*hlo, &param_arrays);
+ std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo);
+ int64 num_outputs = output_arrays.size();
+ std::vector<IrArray> param_arrays = ConstructIrArrayForInputs(*hlo);
+ int64 num_params = param_arrays.size();
// Allocate shared memory buffers to store the tiled inputs.
std::vector<llvm::Value*> param_shmem_buffers(num_params, nullptr);
@@ -3105,7 +3043,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
CeilOfRatio<int64>(output_dims_in_tiles[i], kTileSize);
}
const int64 num_tiles =
- c_accumulate(output_dims_in_tiles, 1, std::multiplies<int64>());
+ absl::c_accumulate(output_dims_in_tiles, 1, std::multiplies<int64>());
LaunchDimensions launch_dimensions(num_tiles, kThreadsPerTile);
llvm::Type* index_ty =
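
absl::c_accumulate is std::accumulate over a whole container; here it folds the per-dimension tile counts into the total block count. Sketch:

  #include <cstdint>
  #include <functional>
  #include <vector>
  #include "absl/algorithm/container.h"

  // Product of tiles per dimension, e.g. {1, 4, 8} -> 32 blocks.
  int64_t NumTiles(const std::vector<int64_t>& dims_in_tiles) {
    return absl::c_accumulate(dims_in_tiles, int64_t{1},
                              std::multiplies<int64_t>());
  }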
@@ -3151,9 +3089,8 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
const IrArray::Index output_tile_origin = [&] {
IrArray::Index index = output_tile_index;
for (int i = 1; i < 3; ++i) {
- index[i] =
- b_.CreateMul(output_tile_index[i], index_typed_constant(kTileSize),
- "tile_origin." + std::to_string(i));
+ index[i] = Mul(output_tile_index[i], index_typed_constant(kTileSize),
+ "tile_origin." + std::to_string(i));
}
return index;
}();
@@ -3166,12 +3103,12 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
std::vector<llvm::Value*> output_tile_bounds(3);
for (int i = 1; i < 3; ++i) {
// Only last row or column may not have full size.
- output_tile_bounds[i] = b_.CreateSelect(
- b_.CreateICmpEQ(output_tile_index[i],
- index_typed_constant(output_dims_in_tiles[i] - 1)),
- index_typed_constant(reduced_output_dims[i] -
- (output_dims_in_tiles[i] - 1) * kTileSize),
- index_typed_constant(kTileSize), "kTileSize");
+ output_tile_bounds[i] =
+ Select(ICmpEQ(output_tile_index[i],
+ index_typed_constant(output_dims_in_tiles[i] - 1)),
+ index_typed_constant(reduced_output_dims[i] -
+ (output_dims_in_tiles[i] - 1) * kTileSize),
+ index_typed_constant(kTileSize), "kTileSize");
}
KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
@@ -3189,7 +3126,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
// Adds `addend` to the given `dim` of `index`.
auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) {
- index[dim] = b_.CreateAdd(index[dim], addend);
+ index[dim] = Add(index[dim], addend);
return index;
};
const IrArray::Index input_index =
@@ -3205,10 +3142,9 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
llvm::Value* shmem_buffer = param_shmem_buffers[id];
// TODO(jlebar): Add AA metadata to this store. Tile buffers are
// global variables, so LLVM can't infer much about it.
- b_.CreateStore(
- input_in_logical_shape.EmitReadArrayElement(index, &b_,
- "input_element"),
- b_.CreateGEP(shmem_buffer, {index_typed_constant(0), y_loc, x}));
+ Store(input_in_logical_shape.EmitReadArrayElement(index, &b_,
+ "input_element"),
+ GEP(shmem_buffer, {index_typed_constant(0), y_loc, x}));
}
});
@@ -3229,9 +3165,9 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
output_index, "output", output_tile_bounds[2], output_tile_bounds[1],
[&](const IrArray::Index& index, llvm::Value* y_loc) {
// TODO(jlebar): Add AA metadata to this load.
- llvm::Instruction* load_from_shmem_buffer = b_.CreateLoad(
- b_.CreateGEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}),
- "output_element");
+ llvm::Instruction* load_from_shmem_buffer =
+ Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}),
+ "output_element");
output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
index, load_from_shmem_buffer, &b_);
});
@@ -3259,7 +3195,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
output_in_reduced_shape_arrays.size());
for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) {
output_in_reduced_shape_arrays[i].EmitWriteArrayElement(
- index, b_.CreateExtractValue(output_value, i), &b_);
+ index, ExtractValue(output_value, i), &b_);
}
} else {
output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
@@ -3270,12 +3206,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
// For multioutput fusion, emit a tuple with all the individual outputs.
if (hlo->IsMultiOutputFusion()) {
- std::vector<llvm::Value*> tuple_operand_ptrs;
- for (int64 i = 0; i < output_arrays.size(); ++i) {
- tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
- }
- llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), tuple_operand_ptrs, &b_,
- module_);
+ llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), output_arrays, &b_, module_);
}
return launch_dimensions;
@@ -3308,7 +3239,7 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
if (!reduced_dims_021.has_value()) {
reduced_dims_021 = curr_reduced_dims_021;
}
- if (!ContainersEqual(*reduced_dims_021, curr_reduced_dims_021)) {
+ if (!absl::c_equal(*reduced_dims_021, curr_reduced_dims_021)) {
// There is more than one possible transpose. Instead of picking one
// transpose, we simply give up here.
return false;
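
absl::c_equal compares two containers element-wise and returns false on a size mismatch, which matches the semantics of the removed ContainersEqual. Sketch:

  #include <cstdint>
  #include <vector>
  #include "absl/algorithm/container.h"

  bool SameDims(const std::vector<int64_t>& a, const std::vector<int64_t>& b) {
    return absl::c_equal(a, b);
  }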
@@ -3341,7 +3272,7 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
// if there's a Right Choice.
//
// This is only sound if tiled transposes are the only place where we use
- // shared memory in fusions. If in the future other fusile ops use shared
+ // shared memory in fusions. If in the future other fusible ops use shared
// memory, we'll have to adjust this heuristic.
constexpr int kMinBlocksPerCore = 3;
constexpr int64 kShmemPerCore = 48 * 1024;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 5254419907..bd5db72051 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -105,13 +105,12 @@ class IrEmitterUnnested : public IrEmitter {
// This kernel takes as arguments pointers to the given buffer allocations.
llvm::Function* BuildKernelPrototype(
const HloInstruction& inst,
- tensorflow::gtl::ArraySlice<const BufferAllocation*> args);
+ absl::Span<const BufferAllocation* const> args);
// Helper for writing extra outputs from inside a reduce kernel.
Status EmitExtraOutputsForReduce(
const HloInstruction* reduce, const llvm_ir::IrArray::Index& index,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens);
// EmitColumnReduction and EmitRowReduction emit code for column and row
@@ -127,12 +126,11 @@ class IrEmitterUnnested : public IrEmitter {
Status EmitColumnReduction(
int64 height, int64 width, HloInstruction* reduce,
const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens);
// Emits code that reduces a 3D tensor of shape [depth x height x width] to a
@@ -143,23 +141,21 @@ class IrEmitterUnnested : public IrEmitter {
Status EmitRowReduction(
int64 depth, int64 height, int64 width, HloInstruction* reduce,
const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens);
// Emits code that reduces a tensor of arbitrary rank to a scalar.
Status EmitReductionToScalar(
HloInstruction* reduce, const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens);
// Figures out whether `reduce` is a row or column reduction, and which
@@ -180,13 +176,12 @@ class IrEmitterUnnested : public IrEmitter {
// Prerequisite: `IsReductionToVector(*reduce)`
Status EmitReductionToVector(
HloInstruction* reduce, const Shape& input_shape,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
- tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers,
- tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
- tensorflow::gtl::ArraySlice<
- std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ absl::Span<const llvm_ir::ElementGenerator> input_gens,
+ absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
+ absl::Span<const int64> dimensions_to_reduce,
+ absl::Span<HloComputation* const> reducers,
+ absl::Span<const ShapeIndex> reduce_output_shapes,
+ absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens);
// Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel
@@ -195,18 +190,15 @@ class IrEmitterUnnested : public IrEmitter {
// Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and
// returns the launch dimensions for the kernel. This is a helper to support
// the implementation of CheckAndEmitHloWithTile021.
- LaunchDimensions EmitHlo021Tile(
- HloInstruction* hlo,
- tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
- tensorflow::gtl::ArraySlice<int64> tiled_param_ids);
- // Generates the IrArray for each output of hlo and returns the number of
- // outputs.
- int ConstructIrArrayForOutputs(const HloInstruction& hlo,
- std::vector<llvm_ir::IrArray>* output_arrays);
- // Generates the IrArray for each input of hlo and returns the number of
- // inputs.
- int ConstructIrArrayForInputs(const HloInstruction& hlo,
- std::vector<llvm_ir::IrArray>* param_arrays);
+ LaunchDimensions EmitHlo021Tile(HloInstruction* hlo,
+ absl::Span<const int64> reduced_output_dims,
+ absl::Span<const int64> tiled_param_ids);
+
+ // Generates the IrArray for each input of an hlo and returns a vector that
+ // contains those IrArrays.
+ std::vector<llvm_ir::IrArray> ConstructIrArrayForInputs(
+ const HloInstruction& hlo);
+
// For each output of the `hlo` instruction, constructs the reduced shape for
// the output with the given `reduced_output_dims` and cast the original
// output IrArray element in `output_arrays` to the reduced shape. Returns
@@ -214,7 +206,7 @@ class IrEmitterUnnested : public IrEmitter {
int ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
const HloInstruction& hlo,
const std::vector<llvm_ir::IrArray>& output_arrays,
- tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ absl::Span<const int64> reduced_output_dims,
std::vector<Shape>* output_reduced_shapes,
std::vector<llvm_ir::IrArray>* output_in_reduced_shape_arrays);
// For each input of the `hlo` instruction, checks its value in
@@ -226,7 +218,7 @@ class IrEmitterUnnested : public IrEmitter {
const HloInstruction& hlo,
const std::vector<llvm_ir::IrArray>& param_arrays,
const std::vector<llvm::Value*>& param_buffers,
- tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ absl::Span<const int64> reduced_output_dims,
std::vector<Shape>* param_reduced_shapes,
std::vector<llvm_ir::IrArray>* param_in_reduced_shape_arrays);
@@ -250,7 +242,7 @@ class IrEmitterUnnested : public IrEmitter {
// Returns a thunk that, given a reduce or select-and-scatter op, initializes
// its memory to the appropriate initial value.
StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunk(
- const HloInstruction* hlo, const ShapeIndex& index = {});
+ HloInstruction* hlo, const ShapeIndex& index = {});
// Returns a thunk that calls host-to-device cuMemcpy to implement `inst`.
std::unique_ptr<Thunk> BuildHostToDeviceCopyThunk(const HloInstruction* inst);
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
index e76823ad10..e09b8fbd3b 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
@@ -15,22 +15,22 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
namespace gpu {
-KernelThunk::KernelThunk(
- tensorflow::gtl::ArraySlice<const BufferAllocation*> args,
- const string& kernel_name, const HloInstruction* hlo_instruction,
- int unroll_factor)
+KernelThunk::KernelThunk(absl::Span<const BufferAllocation* const> args,
+ const string& kernel_name,
+ const HloInstruction* hlo_instruction,
+ int unroll_factor)
: Thunk(Kind::kKernel, hlo_instruction),
args_(args.begin(), args.end()),
kernel_name_(kernel_name),
@@ -41,11 +41,7 @@ Status KernelThunk::Initialize(const GpuExecutable& executable,
tensorflow::mutex_lock lock(mutex_);
if (!loader_spec_) {
loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size()));
- tensorflow::StringPiece ptx = executable.ptx();
- // Convert tensorflow::StringPiece to se::port::StringPiece because
- // StreamExecutor uses the latter.
- loader_spec_->AddCudaPtxInMemory(
- se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_);
+ loader_spec_->AddCudaPtxInMemory(executable.ptx(), kernel_name_);
if (!executable.cubin().empty()) {
loader_spec_->AddCudaCubinInMemory(
@@ -63,7 +59,7 @@ Status KernelThunk::Initialize(const GpuExecutable& executable,
if (kernel_cache_.end() == it) {
it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first;
if (!executor->GetKernel(*loader_spec_, &it->second)) {
- return InternalError("Unable to load kernel %s", kernel_name_.c_str());
+ return InternalError("Unable to load kernel %s", kernel_name_);
}
}
@@ -95,7 +91,7 @@ Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
VLOG(3) << "Launching " << kernel->name();
// Launch the kernel with potentially multiple blocks and threads.
static constexpr int kKernelArgsLimit = 1024;
- auto kernel_args = MakeUnique<se::KernelArgsArray<kKernelArgsLimit>>();
+ auto kernel_args = absl::make_unique<se::KernelArgsArray<kKernelArgsLimit>>();
for (const BufferAllocation* arg : args_) {
const auto& buf = buffer_allocations.GetDeviceAddress(arg->index());
kernel_args->add_device_memory_argument(buf);
@@ -107,7 +103,7 @@ Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
stream, se::ThreadDim(launch_dimensions.threads_per_block()),
se::BlockDim(launch_dimensions.block_count()), *kernel,
*kernel_args)) {
- return InternalError("Unable to launch kernel %s", kernel_name_.c_str());
+ return InternalError("Unable to launch kernel %s", kernel_name_);
}
return Status::OK();
}
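
Because absl::Span is a non-owning view, KernelThunk's constructor immediately copies it into owned storage (args_(args.begin(), args.end())); the Span is only borrowed for the duration of the call. The shape of that idiom, with a hypothetical holder class:

  #include <string>
  #include <vector>
  #include "absl/types/span.h"

  class ArgsHolder {  // hypothetical stand-in for KernelThunk
   public:
    explicit ArgsHolder(absl::Span<const std::string* const> args)
        : args_(args.begin(), args.end()) {}  // copy the view into a vector
   private:
    std::vector<const std::string*> args_;
  };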
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
index d751de50ad..f63db5c369 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -47,7 +47,7 @@ class KernelThunk : public Thunk {
// Constructs a thunk for the given kernel.
//
// `hlo_instruction` is as in Thunk. Other arguments are as the class members.
- KernelThunk(tensorflow::gtl::ArraySlice<const BufferAllocation*> args,
+ KernelThunk(absl::Span<const BufferAllocation* const> args,
const string& kernel_name, const HloInstruction* hlo_instruction,
int unroll_factor);
KernelThunk(const KernelThunk&) = delete;
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
index eb93efc560..698d2d51cc 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
@@ -34,6 +34,9 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
"@llvm//:amdgpu_code_gen",
"@llvm//:analysis",
"@llvm//:bit_reader",
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc
index 12a8a59488..85bc58cb44 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc
@@ -15,14 +15,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -86,10 +86,11 @@ void IrDumpingPassManager::run(llvm::Module &module) {
const llvm::PassInfo *PI =
llvm::PassRegistry::getPassRegistry()->getPassInfo(P->getPassID());
const string basename = ReplaceFilenameExtension(
- tensorflow::io::Basename(input_filename_),
- tensorflow::strings::Printf(
+ absl::string_view(tensorflow::io::Basename(input_filename_)),
+ absl::StrFormat(
"pass-%02d.before.%s.ll", i,
- (PI == nullptr ? "unknown" : PI->getPassArgument().data())));
+ absl::string_view(PI == nullptr ? "unknown"
+ : PI->getPassArgument().data())));
llvm::legacy::PassManager::add(
new DumpIrPass(tensorflow::io::JoinPath(output_dir_, basename)));
}
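
absl::StrFormat replaces tensorflow::strings::Printf with the same conversion syntax but compile-time checking of literal format strings, and its %s accepts absl::string_view, which is what the explicit string_view casts above feed it. A sketch of the dump-filename pattern:

  #include <string>
  #include "absl/strings/str_format.h"
  #include "absl/strings/string_view.h"

  std::string DumpFileName(int i, absl::string_view pass_name) {
    return absl::StrFormat("pass-%02d.before.%s.ll", i, pass_name);
  }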
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
index ff4ae1f9ef..8751e3a9c2 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
@@ -20,13 +20,15 @@ limitations under the License.
#include <string>
#include <utility>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
@@ -54,10 +56,7 @@ limitations under the License.
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "llvm/Transforms/Scalar.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"
@@ -107,8 +106,7 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path,
<< ", " << compute_capability.second << ") ."
<< "Defaulting to libdevice for compute_" << libdevice_version;
}
- return tensorflow::strings::StrCat("libdevice.compute_", libdevice_version,
- ".10.bc");
+ return absl::StrCat("libdevice.compute_", libdevice_version, ".10.bc");
}
// Gets the GPU name as it's known to LLVM for a given compute capability. If
@@ -138,15 +136,16 @@ static string GetSmName(std::pair<int, int> compute_capability) {
<< "Defaulting to telling LLVM that we're compiling for sm_"
<< sm_version;
}
- return tensorflow::strings::StrCat("sm_", sm_version);
+ return absl::StrCat("sm_", sm_version);
}
// Convenience function for producing a name of a temporary compilation product
// from the input filename.
string MakeNameForTempProduct(const std::string& input_filename,
- tensorflow::StringPiece extension) {
- return ReplaceFilenameExtension(
- tensorflow::io::Basename(llvm_ir::AsString(input_filename)), extension);
+ absl::string_view extension) {
+ return ReplaceFilenameExtension(absl::string_view(tensorflow::io::Basename(
+ llvm_ir::AsString(input_filename))),
+ extension);
}
// Initializes LLVM passes. Uses the PassRegistry mechanism.
@@ -167,7 +166,7 @@ void InitializePasses(llvm::PassRegistry* pass_registry) {
// Returns the TargetMachine, given a triple.
std::unique_ptr<llvm::TargetMachine> GetTargetMachine(
- llvm::Triple triple, tensorflow::StringPiece cpu_name,
+ llvm::Triple triple, absl::string_view cpu_name,
const HloModuleConfig& hlo_module_config) {
std::string error;
const llvm::Target* target = TargetRegistry::lookupTarget("", triple, error);
@@ -205,7 +204,7 @@ std::unique_ptr<llvm::TargetMachine> GetTargetMachine(
default:
codegen_opt_level = CodeGenOpt::None;
}
- return WrapUnique(target->createTargetMachine(
+ return absl::WrapUnique(target->createTargetMachine(
triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx60", target_options,
Optional<Reloc::Model>(RelocModel), Optional<CodeModel::Model>(CMModel),
codegen_opt_level));
@@ -243,9 +242,9 @@ void AddOptimizationPasses(unsigned opt_level, unsigned size_level,
}
// Emits the given module to a bit code file.
-void EmitBitcodeToFile(const Module& module, tensorflow::StringPiece filename) {
+void EmitBitcodeToFile(const Module& module, absl::string_view filename) {
std::error_code error_code;
- llvm::ToolOutputFile outfile(filename.ToString().c_str(), error_code,
+ llvm::ToolOutputFile outfile(string(filename).c_str(), error_code,
llvm::sys::fs::F_None);
if (error_code) {
LOG(FATAL) << "opening bitcode file for writing: " << error_code.message();
@@ -266,8 +265,9 @@ string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) {
// get creative to add a suffix.
string module_id(llvm_ir::AsString(module->getModuleIdentifier()));
IrDumpingPassManager codegen_passes(
- ReplaceFilenameExtension(tensorflow::io::Basename(module_id),
- "-nvptx.dummy"),
+ ReplaceFilenameExtension(
+ absl::string_view(tensorflow::io::Basename(module_id)),
+ "-nvptx.dummy"),
"", false);
codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass(
llvm::Triple(module->getTargetTriple())));
@@ -332,8 +332,8 @@ Status LinkLibdeviceIfNecessary(llvm::Module* module,
return !GV.hasName() || (GVS.count(GV.getName()) == 0);
});
})) {
- return tensorflow::errors::Internal(tensorflow::strings::StrCat(
- "Error linking libdevice from ", libdevice_path));
+ return tensorflow::errors::Internal(
+ absl::StrCat("Error linking libdevice from ", libdevice_path));
}
return Status::OK();
}
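
absl::StrCat, used at several sites in this file, concatenates heterogeneous pieces (string literals, views, integers) without a format string. For instance the libdevice filename built above:

  #include <string>
  #include "absl/strings/str_cat.h"

  // e.g. version 35 -> "libdevice.compute_35.10.bc"
  std::string LibdeviceFileName(int version) {
    return absl::StrCat("libdevice.compute_", version, ".10.bc");
  }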
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h
index 54e0e140de..9654175bfa 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h
@@ -20,11 +20,11 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/strings/string_view.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace xla {
namespace gpu {
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc
index 9ef9bc3a50..3b2c3591d9 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc
@@ -17,13 +17,13 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/SourceMgr.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace {
@@ -52,14 +52,13 @@ std::unique_ptr<llvm::Module> LoadIRModule(const string& filename,
return module;
}
-string ReplaceFilenameExtension(tensorflow::StringPiece filename,
- tensorflow::StringPiece new_extension) {
+string ReplaceFilenameExtension(absl::string_view filename,
+ absl::string_view new_extension) {
auto pos = filename.rfind('.');
- tensorflow::StringPiece stem =
- pos == tensorflow::StringPiece::npos
- ? filename
- : tensorflow::StringPiece(filename.data(), pos);
- return tensorflow::strings::StrCat(stem, ".", new_extension);
+ absl::string_view stem = pos == absl::string_view::npos
+ ? filename
+ : absl::string_view(filename.data(), pos);
+ return absl::StrCat(stem, ".", new_extension);
}
} // namespace gpu
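
The rewritten ReplaceFilenameExtension keeps the same logic: take the stem before the last '.', then StrCat the new extension; string_view keeps the stem a non-owning slice until StrCat materializes the result. An equivalent formulation using substr, for illustration:

  #include <string>
  #include "absl/strings/str_cat.h"
  #include "absl/strings/string_view.h"

  std::string ReplaceExt(absl::string_view filename, absl::string_view ext) {
    auto pos = filename.rfind('.');
    absl::string_view stem =
        pos == absl::string_view::npos ? filename : filename.substr(0, pos);
    return absl::StrCat(stem, ".", ext);
  }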
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h
index a6daeca95a..60f4926849 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h
@@ -18,8 +18,8 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace llvm {
class LLVMContext;
@@ -41,8 +41,8 @@ std::unique_ptr<llvm::Module> LoadIRModule(const string& filename,
//
// For example:
// ReplaceFilenameExtension("/foo/baz.txt", "cc") --> "/foo/baz.cc"
-string ReplaceFilenameExtension(tensorflow::StringPiece filename,
- tensorflow::StringPiece new_extension);
+string ReplaceFilenameExtension(absl::string_view filename,
+ absl::string_view new_extension);
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
index c62bae0628..c21f76f6eb 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -23,7 +23,9 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -48,7 +50,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
// If possible, we want to pick a reduce operand of the fusion root,
// because it has the most constraints.
for (const auto* inst : fused_expression_root->operands()) {
- if (inst->opcode() == HloOpcode::kReduce) {
+ if (IsReductionToVector(*inst)) {
return inst;
}
}
@@ -63,7 +65,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
auto get_element_shape = [&](const HloInstruction* element_instr) {
// Special handling of kReduce instructions -- the fusion
// applies to the first operand.
- if (element_instr->opcode() == HloOpcode::kReduce) {
+ if (IsReductionToVector(*element_instr)) {
return element_instr->operand(0)->shape();
}
return element_instr->shape();
@@ -85,65 +87,16 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
get_element_shape(element_instr_1), get_element_shape(element_instr_2));
}
-namespace {
-bool IsInputFusibleReduction(HloInstruction* instr) {
- if (instr->IsMultiOutputFusion()) {
- for (const HloInstruction* operand :
- instr->fused_expression_root()->operands()) {
- if (operand->opcode() == HloOpcode::kReduce) {
- CHECK(instr->fusion_kind() == HloInstruction::FusionKind::kInput)
- << " Reduce multi-output fusion " << instr->ToString()
- << " must be an input fusion.";
- return true;
- }
- }
- return false;
- } else if (instr->opcode() == HloOpcode::kFusion) {
- // The loop emitter can handle to-vector reduce fusions. Such reduce
- // fusions have the fusion kind kLoop rather than kInput. We do not fuse
- // to-vector reduce fusions, because the resulting fusions may no longer be
- // supported by loop emitter.
- return IsReductionToVector(*instr->fused_expression_root());
- } else {
- return IsReductionToVector(*instr);
- }
-}
-
-// The code emitted for reduction suffers from poor data locality if the layouts
-// of input parameters differ. In such situtations it is beneficial not to fuse.
-// We consider input params with maximum rank only. Params with smaller ranks
-// will be broadcasted and have not been observed to cause data locality issues.
-// TODO(b/111977086): Improve reduce emitters to remove this limitation.
-bool ReduceFriendlyInputLayouts(HloInstruction* instr) {
- std::vector<HloInstruction*> params;
- if (instr->opcode() == HloOpcode::kFusion) {
- params = instr->fused_parameters();
- } else {
- for (HloInstruction* operand : instr->operands()) {
- params.push_back(operand);
- }
- }
- int64 max_rank = 0;
- const Layout* max_rank_layout;
- for (HloInstruction* param : params) {
- if (ShapeUtil::Rank(param->shape()) > max_rank) {
- max_rank = ShapeUtil::Rank(param->shape());
- max_rank_layout = &param->shape().layout();
- }
- }
- return c_all_of(params, [&](HloInstruction* param) {
- return (ShapeUtil::Rank(param->shape()) < max_rank) ||
- (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout));
- });
-}
-
-} // namespace
-
bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) {
- // We can fuse reduces and loop fusions.
- return IsInputFusibleReduction(instr) ||
- (instr->opcode() == HloOpcode::kFusion &&
- instr->fusion_kind() == HloInstruction::FusionKind::kLoop);
+ // We can fuse reduces and loop fusions. Elementwise instructions can be fused
+ // with any other instruction.
+ // TODO(b/112957171): This should use the same isFusible logic as
+ // instruction_fusion.
+ return instr->IsFusible() &&
+ (IsInputFusibleReduction(*instr) ||
+ (instr->opcode() == HloOpcode::kFusion &&
+ instr->fusion_kind() == HloInstruction::FusionKind::kLoop) ||
+ instr->IsElementwise());
}
int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1,
@@ -177,11 +130,12 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1,
// merge into bigger loop fusions and input (reduce) fusions become fusions
// with multiple reduce outputs. We could fuse reduce and loop fusions
// together too (the result being an input fusion) if we find cases where this
- // improves things.
+ // improves things. Also disable fusing standalone input-fusible reduces into
+ // loop fusions.
CHECK(instr1->opcode() == HloOpcode::kFusion);
if ((instr2->opcode() == HloOpcode::kFusion &&
instr1->fusion_kind() != instr2->fusion_kind()) ||
- (instr2->opcode() != HloOpcode::kFusion &&
+ (IsReductionToVector(*instr2) &&
instr1->fusion_kind() == HloInstruction::FusionKind::kLoop)) {
return false;
}
@@ -197,7 +151,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
tensorflow::gtl::FlatSet<HloInstruction*> to_fuse;
// Keep a list of the instructions to fuse after making all the fusion
// decisions. We first aggressively add instructions to potential_fusion_list,
- // then filter out instructions that will be no longer fusable because of
+ // then filter out instructions that will be no longer fusible because of
// reachability change. This avoids recalculating reachability on a large set
// of instructions.
std::vector<std::pair<HloInstruction*, HloInstruction*>>
@@ -212,8 +166,8 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
VLOG(3) << consumer->name() << " has no users.";
continue;
}
- if (!IsInputFusibleReduction(consumer)) {
- VLOG(3) << consumer->name() << " is not an input-fusable reduction.";
+ if (!IsInputFusibleReduction(*consumer)) {
+ VLOG(3) << consumer->name() << " is not an input-fusible reduction.";
continue;
}
VLOG(3) << consumer->name()
@@ -222,8 +176,8 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
auto consumer_operands = consumer->operands();
for (size_t i = 0; i < consumer_operands.size(); ++i) {
HloInstruction* producer = consumer_operands[i];
- if (!producer->IsFusable()) {
- VLOG(3) << producer->name() << " is not fusable.";
+ if (!producer->IsFusible()) {
+ VLOG(3) << producer->name() << " is not fusible.";
continue;
}
const bool is_loop_fusion =
@@ -237,7 +191,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
VLOG(3) << producer->name() << " has an incompatible shape.";
continue;
}
- if (!ReduceFriendlyInputLayouts(producer)) {
+ if (!LayoutsAreReduceInputFusionFriendly(*producer, *consumer)) {
VLOG(3) << producer->name() << " has inputs with mixed layouts.";
continue;
}
@@ -248,7 +202,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
}
// Do not fuse a producer if the other operands of the fusion are
// reachable from the producer, this would create a cycle.
- if (c_any_of(consumer_operands, [&](HloInstruction* operand) {
+ if (absl::c_any_of(consumer_operands, [&](HloInstruction* operand) {
return producer != operand &&
reachability()->IsReachable(producer, operand);
})) {
@@ -263,12 +217,12 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
}
}
- // Filter out pairs that will be no longer fusable because of reachability
+ // Filter out pairs that will be no longer fusible because of reachability
// change.
for (auto& fusion_pair : potential_fusion_list) {
HloInstruction* producer = fusion_pair.first;
HloInstruction* consumer = fusion_pair.second;
- if (!c_any_of(consumer->operands(), [&](HloInstruction* operand) {
+ if (!absl::c_any_of(consumer->operands(), [&](HloInstruction* operand) {
return producer != operand &&
reachability()->IsReachable(producer, operand);
})) {
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
index 67ca5d49ee..f0b4d67ab8 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
@@ -22,7 +22,7 @@ namespace xla {
namespace gpu {
// Multi-output fusion of sibling and producer-consumer instructions for the
-// Jellyfish backend.
+// GPU backend.
class GpuMultiOutputFusion : public MultiOutputFusion {
public:
GpuMultiOutputFusion();
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
index 14f157a5e5..8a6e5327e0 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
@@ -15,19 +15,19 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-
-namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace gpu {
+namespace op = xla::testing::opcode_matchers;
+
using MultiOutputFusionTest = HloTestBase;
const char kModulePrefix[] = R"(
@@ -47,7 +47,7 @@ const char kModulePrefix[] = R"(
TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
// Fusion with reduce instruction root and a sibling reduce instruction
// sharing the same input param.
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_computation {
p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
@@ -74,7 +74,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
}
TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p1.1 = f32[6400]{0} parameter(1)
mul = f32[6400]{0} multiply(p1.1, p1.1)
@@ -101,7 +101,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
}
TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p1.1 = f32[10,10]{1,0} parameter(1)
mul = f32[10,10]{1,0} multiply(p1.1, p1.1)
@@ -130,7 +130,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) {
TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceFusions) {
// Two sibling fusions with reduce instruction roots sharing the same input
// param.
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
@@ -165,7 +165,7 @@ TEST_F(MultiOutputFusionTest,
MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) {
// Multi-output fusion with two reduce instructions root and a sibling reduce
// instruction sharing the same input param.
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_computation (p0: f32[128,512,28,28]) -> (f32[512], f32[512]) {
const.1 = f32[] constant(1)
p0.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
@@ -198,7 +198,7 @@ TEST_F(MultiOutputFusionTest,
MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) {
// Verify that if we already have a multi-output fusion that we prefer to pick
// a reduce op from its operands for checking shape compatibility.
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p1.1 = f32[10,10]{1,0} parameter(1)
mul = f32[10,10]{1,0} multiply(p1.1, p1.1)
@@ -228,7 +228,7 @@ TEST_F(MultiOutputFusionTest,
}
TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p0.1 = f32[6400]{0} parameter(0)
ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
@@ -256,8 +256,136 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) {
op::Tuple(op::Multiply(), op::Divide()));
}
+TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopReduceToInputFusion) {
+ // Fusing a reduce into a loop fusion would require changing the fusion kind.
+ // That's not supported yet.
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[6400]{0} parameter(0)
+ ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
+ }
+
+ ENTRY entry {
+ p0 = f32[6400]{0} parameter(0)
+ fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1
+ const.2 = f32[] constant(0)
+ reduce = f32[] reduce(p0, const.2), dimensions={0}, to_apply=scalar_add_computation
+ ROOT root = (f32[6400]{0}, f32[]) tuple(fusion.1, reduce)
+ })"))
+ .ValueOrDie();
+ ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[6400]{0} parameter(0)
+ ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
+ }
+
+ ENTRY entry {
+ p0 = f32[6400]{0} parameter(0)
+ fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1
+ const.2 = f32[] constant(1)
+ div = f32[6400]{0} divide(p0, const.2)
+ ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, div)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* fusion =
+ module->entry_computation()->root_instruction()->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Multiply(), op::Divide()));
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ ROOT mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
+ }
+
+ fused_computation_2 {
+ p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ const.2 = f32[] constant(0)
+ ROOT reduce = f32[8,1,5,1,1]{4,3,2,1,0} reduce(p0.2, const.2), dimensions={3}, to_apply=scalar_add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ fusion.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_1
+ fusion.2 = f32[8,1,5,1,1]{4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
+ ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,1,1]{4,3,2,1,0}) tuple(fusion.1, fusion.2)
+ })"))
+ .ValueOrDie();
+ ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+}
+
+TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
+ exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1)
+ ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp)
+ }
+
+ fused_computation_2 {
+ p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ const.2 = f32[] constant(0)
+ ROOT add = f32[8,1,5,16,1,1]{5,4,3,2,1,0} add(p0.2, const.2)
+ }
+
+ ENTRY entry {
+ p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1
+ fusion.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
+ gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
+ gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
+ ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(gte0, gte1, fusion.2)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* fusion =
+ module->entry_computation()->root_instruction()->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Multiply(), op::Exp(), op::Add()));
+}
+
+TEST_F(MultiOutputFusionTest,
+ MultiOutputFusionSiblingLoopAndMultiOutputLoopDifferentShapes) {
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
+ exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1)
+ ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp)
+ }
+
+ fused_computation_2 {
+ p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ const.2 = f32[] constant(0)
+ ROOT reduce = f32[8,1,5,1,1]{4,3,2,1,0} reduce(p0.2, const.2), dimensions={3}, to_apply=scalar_add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
+ fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1
+ fusion.2 = f32[8,1,5,1,1]{4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
+ gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
+ gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
+ ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,1,1]{4,3,2,1,0}) tuple(gte0, gte1, fusion.2)
+ })"))
+ .ValueOrDie();
+ ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+}
+
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
ENTRY reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
c0 = f32[] constant(0)
@@ -277,7 +405,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
}
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_add {
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
@@ -304,7 +432,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
}
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_select {
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
c0 = f32[] constant(0)
@@ -345,7 +473,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
}
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_element_wise {
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
@@ -372,7 +500,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) {
TEST_F(MultiOutputFusionTest,
ProducerConsumerFusionFp16LoopFusionAndReduceFusion) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_select {
p1.1 = f16[2,2,2]{2,1,0} parameter(1)
c0 = f16[] constant(0)
@@ -413,7 +541,7 @@ TEST_F(MultiOutputFusionTest,
TEST_F(MultiOutputFusionTest,
ProducerConsumerFusionReduceUnfriendlyLoopFusion) {
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
mixed_input_layouts_computation {
p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
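
A note on the assertion idiom shared by the tests above: after a successful multi-output fusion, the entry computation's root is a tuple of get-tuple-element reads from the new fusion, which is why the checks reach the fusion via two operand hops. Sketched with hypothetical names:

// After GpuMultiOutputFusion rewrites the entry computation, it looks like:
//   fusion = (f32[...], f32[...]) fusion(p0), kind=kLoop, calls=...
//   gte0 = f32[...] get-tuple-element(fusion), index=0
//   gte1 = f32[...] get-tuple-element(fusion), index=1
//   ROOT root = (f32[...], f32[...]) tuple(gte0, gte1)
// so root_instruction()->operand(0) is gte0, and ->operand(0)->operand(0)
// is the multi-output fusion whose fused_expression_root() the opcode
// matchers inspect.
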
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 6c1eab4f8c..dfdcf1875d 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -21,20 +21,21 @@ limitations under the License.
#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
#include <utility>
+#include "absl/memory/memory.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
-#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
@@ -43,9 +44,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
-#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
@@ -85,7 +86,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cuda_libdevice_path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@@ -140,7 +140,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
Compiler* compiler) {
{
HloPassPipeline pipeline("optimization");
- pipeline.AddInvariantChecker<HloVerifier>();
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
pipeline.AddPass<GpuHloSupportChecker>();
ReducePrecisionInsertion::AddPasses(
&pipeline, hlo_module->config().debug_options(),
@@ -156,7 +157,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
{
auto& pass =
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
- pass.AddInvariantChecker<HloVerifier>();
+ pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
// If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls
// where possible. Not every batchnorm op can be implemented as a call to
@@ -203,9 +205,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// Convert convolutions into CustomCalls to cudnn, then canonicalize them
// (PadInsertion).
HloPassPipeline pipeline("conv_canonicalization");
- pipeline.AddInvariantChecker<HloVerifier>();
- // TODO(b/31709653): Directly use the grouped convolution support of Cudnn.
- pipeline.AddPass<ConvolutionFeatureGroupConverter>();
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
pipeline.AddPass<CudnnConvolutionRewriter>();
pipeline.AddPass<PadInsertion>();
if (IsVoltaOrLater(*stream_exec)) {
@@ -214,13 +215,29 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// pairs that TupleSimplifier fixes.
pipeline.AddPass<TupleSimplifier>();
}
+ // CudnnConvolutionRewriter, PadInsertion and PadForTensorCores may add
+ // instructions which can be simplified by constant folding.
+ pipeline.AddPass<HloConstantFolding>();
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
}
{
- HloPassPipeline pipeline("layout_assignment");
+ // Run layout assignment in a separate pipeline from
+ // "post-layout-assignment" because we want everything after layout
+ // assignment to have a layout-sensitive invariant-checker, but
+ // HloPassPipeline also runs its invariant checker before any passes are
+ // run, meaning the pipeline that contains layout assignment cannot contain
+ // a layout-sensitive verifier!
+ HloPassPipeline pipeline("layout assignment");
pipeline.AddPass<GpuLayoutAssignment>(
hlo_module->mutable_entry_computation_layout(), stream_exec);
+ TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
+ }
+
+ {
+ HloPassPipeline pipeline("post-layout_assignment");
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
@@ -266,17 +283,20 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
{
HloPassFix<HloPassPipeline> fusion("fusion");
- fusion.AddInvariantChecker<HloVerifier>();
+ fusion.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
fusion.AddPass<FusionMerger>();
fusion.AddPass<GpuMultiOutputFusion>();
fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
/*only_fusion_computations=*/true);
+ fusion.AddPass<HloDCE>();
TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
HloPassPipeline reduce_pipeline("reduce-precision");
- reduce_pipeline.AddInvariantChecker<HloVerifier>();
+ reduce_pipeline.AddInvariantChecker<HloVerifier>(
+ /*layout_sensitive=*/true, /*allow_mixed_precision=*/false);
ReducePrecisionInsertion::AddPasses(
&reduce_pipeline, hlo_module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
@@ -302,7 +322,8 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
// (b/27180329). Therefore, in that case, we set the output to be a copy of
// the parameter.
HloPassPipeline pipeline("GPU-ir-emit-prepare");
- pipeline.AddInvariantChecker<HloVerifier>();
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
// Copy insertion should be performed immediately before IR emission to avoid
// inserting unnecessary copies (later pass adds an instruction which
@@ -352,9 +373,9 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) {
string vmaj_str, vmin_str, vdot_str;
if (!RE2::PartialMatch(out, R"(\bV(\d+)\.(\d+)\.(\d+)\b)", &vmaj_str,
&vmin_str, &vdot_str) ||
- !tensorflow::strings::safe_strto64(vmaj_str, &vmaj) ||
- !tensorflow::strings::safe_strto64(vmin_str, &vmin) ||
- !tensorflow::strings::safe_strto64(vdot_str, &vdot)) {
+ !absl::SimpleAtoi(vmaj_str, &vmaj) ||
+ !absl::SimpleAtoi(vmin_str, &vmin) ||
+ !absl::SimpleAtoi(vdot_str, &vdot)) {
LOG(WARNING) << "Couldn't parse ptxas version in output of " << ptxas_path
<< " --version:\n"
<< out;
@@ -466,7 +487,7 @@ StatusOr<std::vector<uint8>> CompilePtx(const string& ptx, int cc_major,
tensorflow::SubProcess ptxas_info_dumper;
std::vector<string> ptxas_args = {
ptxas_path, ptx_path, "-o", cubin_path,
- tensorflow::strings::StrCat("-arch=sm_", cc_major, cc_minor)};
+ absl::StrCat("-arch=sm_", cc_major, cc_minor)};
if (VLOG_IS_ON(2)) {
ptxas_args.push_back("-v");
}
@@ -544,8 +565,8 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
// must also be used to determine the thunk launch schedule.
std::unique_ptr<StreamAssignment> stream_assignment = AssignStreams(*module);
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<HloSchedule> hlo_schedule,
- HloSchedule::Build(*module, *stream_assignment, pointer_size_));
+ std::unique_ptr<GpuHloSchedule> hlo_schedule,
+ GpuHloSchedule::Build(*module, *stream_assignment, pointer_size_));
// Run buffer analysis on the HLO graph. This analysis figures out which
// temporary buffers are required to run the computation.
@@ -674,7 +695,7 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
// Write PTX to IR dump directory, if IR dumping was requested.
if (!ir_dump_directory.empty()) {
const string ptx_outfile = tensorflow::io::JoinPath(
- ir_dump_directory, tensorflow::strings::StrCat(module->name(), ".ptx"));
+ ir_dump_directory, absl::StrCat(module->name(), ".ptx"));
auto status = [&] {
auto* env = tensorflow::Env::Default();
TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(ir_dump_directory));
@@ -690,7 +711,7 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
const std::vector<uint8> cubin =
CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor);
- auto thunk_schedule = MakeUnique<ThunkSchedule>(
+ auto thunk_schedule = absl::make_unique<ThunkSchedule>(
ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment),
hlo_schedule->ThunkLaunchOrder());
VLOG(2) << "Printing the thunk schedule...";
@@ -704,7 +725,7 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
cost_analysis.set_bytes_per_second(
stream_exec->GetDeviceDescription().memory_bandwidth());
TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
- profile_index_map = MakeUnique<HloProfileIndexMap>(*module);
+ profile_index_map = absl::make_unique<HloProfileIndexMap>(*module);
profile_printer =
CreateHloProfilePrinterData(*profile_index_map, cost_analysis);
}
@@ -813,7 +834,7 @@ se::Platform::Id NVPTXCompiler::PlatformId() const {
static bool InitModule() {
xla::Compiler::RegisterCompilerFactory(
stream_executor::cuda::kCudaPlatformId,
- []() { return xla::MakeUnique<xla::gpu::NVPTXCompiler>(); });
+ []() { return absl::make_unique<xla::gpu::NVPTXCompiler>(); });
return true;
}
static bool module_initialized = InitModule();
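
The pipeline split in OptimizeHloModule above deserves a standalone restatement: because HloPassPipeline runs its invariant checkers before the first pass executes, a layout-sensitive HloVerifier can only be attached to a pipeline whose module already has layouts. A minimal sketch of the resulting structure, using the calls visible in this hunk (the wrapper function name is hypothetical):

Status RunLayoutAssignmentSplit(HloModule* hlo_module,
                                se::StreamExecutor* stream_exec) {
  {
    // No layout-sensitive checker here; it would fire before
    // GpuLayoutAssignment has assigned any layouts.
    HloPassPipeline pipeline("layout assignment");
    pipeline.AddPass<GpuLayoutAssignment>(
        hlo_module->mutable_entry_computation_layout(), stream_exec);
    TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
  }
  {
    // Every pipeline from this point on can verify layouts as an invariant.
    HloPassPipeline pipeline("post-layout_assignment");
    pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
                                              /*allow_mixed_precision=*/false);
    TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
  }
  return Status::OK();
}
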
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
index d4d2909f1b..8e97774750 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
@@ -20,13 +20,13 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc
index 4aaf0c9e14..2fa170964e 100644
--- a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc
index b99d998c4d..e0f3e84a4c 100644
--- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc
@@ -96,7 +96,7 @@ Status OutfeedThunk::ExecuteOnStream(
Status block_status = stream->BlockHostUntilDone();
if (!block_status.ok()) {
return InternalError("Failed to complete data transfer on stream %p: %s",
- stream, block_status.error_message().c_str());
+ stream, block_status.error_message());
}
VLOG(2) << "Outfeeding from GPU complete";
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
index 79f7d31816..b0061fa655 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
@@ -23,8 +23,6 @@ limitations under the License.
namespace xla {
namespace gpu {
-using tensorflow::gtl::ArraySlice;
-
// We want the input/output feature counts of an f16 conv to be multiples of 8,
// because without this cudnn can't use tensor cores on the conv.
static constexpr int64 kDesiredNumFeaturesFactor = 8;
@@ -42,7 +40,7 @@ static constexpr double kMaxBytesTouchedIncrease = 1.2;
// Pads the given dimensions in the given shape up to a multiple of
// kDesiredNumFeaturesFactor.
-static Shape PadShape(Shape s, ArraySlice<int64> dims) {
+static Shape PadShape(Shape s, absl::Span<const int64> dims) {
for (int64 dim : dims) {
int64 dim_to_pad_size = s.dimensions(dim);
int64 new_dim_to_pad_size =
@@ -64,8 +62,8 @@ static HloInstruction* PadInstruction(HloInstruction* instr,
HloComputation* comp = instr->parent();
const Shape& shape = instr->shape();
- auto* zero = comp->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(shape.element_type()).CloneToUnique()));
+ auto* zero = comp->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape));
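
For context on PadShape above: each dimension in dims is rounded up to the next multiple of kDesiredNumFeaturesFactor (8). A minimal sketch of that rounding, with a hypothetical helper name:

// Rounds size up to the next multiple of factor; factor must be positive.
static int64 RoundUpToMultiple(int64 size, int64 factor) {
  return ((size + factor - 1) / factor) * factor;
}
// e.g. RoundUpToMultiple(60, 8) == 64: an f16 conv with 60 input features
// would be padded to 64 features, then sliced back after the convolution.
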
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
index 192359f026..11dc56a64f 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
@@ -32,9 +32,7 @@ namespace gpu {
// TODO(jlebar): Also pad dots.
class PadForTensorCores : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override {
- return "pad for tensor cores";
- }
+ absl::string_view name() const override { return "pad for tensor cores"; }
StatusOr<bool> Run(HloModule* module) override;
};
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
index 99e7580b82..5c92b0dcb8 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
@@ -29,7 +29,7 @@ namespace {
namespace op = xla::testing::opcode_matchers;
using ::testing::_;
-using PadForTensorCoresTest = HloVerifiedTestBase;
+class PadForTensorCoresTest : public HloVerifiedTestBase {};
TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) {
ParseAndVerifyModule(R"(
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index b22040eee1..2a6415d0b6 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
@@ -67,9 +68,8 @@ HloInstruction* MaybePaddedAndSlicedInput(
conv_window.dimensions(i).base_dilation() - 1);
}
PrimitiveType element_type = input->shape().element_type();
- HloInstruction* padding =
- computation->AddInstruction(HloInstruction::CreateConstant(
- MakeUnique<Literal>(LiteralUtil::Zero(element_type))));
+ HloInstruction* padding = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
input = MakePadHlo(input, padding, padding_config).ValueOrDie();
}
@@ -124,9 +124,8 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window,
HloComputation* computation = kernel->parent();
PrimitiveType element_type = kernel->shape().element_type();
- HloInstruction* padding =
- computation->AddInstruction(HloInstruction::CreateConstant(
- MakeUnique<Literal>(LiteralUtil::Zero(element_type))));
+ HloInstruction* padding = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
return MakePadHlo(kernel, padding, padding_config).ValueOrDie();
}
} // namespace
@@ -165,9 +164,9 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) {
Shape old_conv_shape = conv->shape().tuple_shapes(0);
VLOG(1) << "Canonicalizing forward conv";
- auto new_conv = CreateCudnnConvForward(old_conv_shape, new_input, new_kernel,
- new_conv_window,
- conv->convolution_dimension_numbers());
+ auto new_conv = CreateCudnnConvForward(
+ old_conv_shape, new_input, new_kernel, new_conv_window,
+ conv->convolution_dimension_numbers(), conv->feature_group_count());
VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n "
<< new_conv->ToString();
TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
@@ -235,9 +234,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
// Create a new backward convolution replacing the old one.
HloComputation* computation = backward_conv->parent();
HloInstruction* output = backward_conv->mutable_operand(1);
- HloInstruction* padding = computation->AddInstruction(
- HloInstruction::CreateConstant(MakeUnique<Literal>(
- LiteralUtil::Zero(input->shape().element_type()))));
+ HloInstruction* padding =
+ computation->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(input->shape().element_type())));
HloInstruction* padded_input =
MakePadHlo(input, padding, input_padding_config).ValueOrDie();
@@ -246,7 +245,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);
HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter(
backward_conv_shape, padded_input, output, new_backward_conv_window,
- backward_conv_dnums);
+ backward_conv_dnums, backward_conv->feature_group_count());
VLOG(1) << "Canonicalizing backward filter conv";
VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n "
@@ -311,7 +310,7 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput(
new_backward_conv_shape, output, filter, new_backward_conv_window,
- backward_conv_dnums);
+ backward_conv_dnums, backward_conv->feature_group_count());
// The CustomCall created above returns a tuple (conv_result, scratch_memory).
// Extract out the two elements.
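
The repeated rewrite in this file is a single API migration: HloInstruction::CreateConstant now takes a Literal by value, so the MakeUnique<Literal>(...) wrapper goes away. Old versus new, as a sketch:

// Before: CreateConstant took std::unique_ptr<Literal>.
//   HloInstruction* padding = computation->AddInstruction(
//       HloInstruction::CreateConstant(
//           MakeUnique<Literal>(LiteralUtil::Zero(element_type))));
// After: LiteralUtil::Zero returns a Literal that is moved in directly.
HloInstruction* padding = computation->AddInstruction(
    HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
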
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h
index 67e51509e4..a622e894ed 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h
@@ -26,7 +26,7 @@ namespace gpu {
// padding, so that they can be lowered to cuDNN convolution.
class PadInsertion : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "pad insertion"; }
+ absl::string_view name() const override { return "pad insertion"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
index 3838fee674..8154d75d23 100644
--- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
@@ -40,7 +40,7 @@ ParallelLoopEmitter::ParallelLoopEmitter(
ParallelLoopEmitter::ParallelLoopEmitter(
const llvm_ir::ElementGenerator& target_element_generator,
- tensorflow::gtl::ArraySlice<llvm_ir::IrArray> target_arrays,
+ absl::Span<const llvm_ir::IrArray> target_arrays,
const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b,
int unroll_factor)
: LoopEmitter(target_element_generator, target_arrays, b),
@@ -57,8 +57,8 @@ ParallelLoopEmitter::ParallelLoopEmitter(
unroll_factor_(unroll_factor) {}
std::vector<llvm_ir::IrArray::Index>
-ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name, llvm::Type* index_type) {
+ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
+ llvm::Type* index_type) {
// Emit the following code in LLVM IR:
// linear_index = blockIdx.x * blockDim.x + threadIdx.x;
// if (linear_index < num_elements) {
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
index b82a23419d..f32ea1ce4c 100644
--- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
@@ -47,18 +47,17 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
//
// This is used in multi-output fusion. target_element_generator should
// produce a struct with N elements, one for each of target_arrays.
- ParallelLoopEmitter(
- const llvm_ir::ElementGenerator& target_element_generator,
- tensorflow::gtl::ArraySlice<llvm_ir::IrArray> target_arrays,
- const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b,
- int unroll_factor = 1);
+ ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator,
+ absl::Span<const llvm_ir::IrArray> target_arrays,
+ const LaunchDimensions& launch_dimensions,
+ llvm::IRBuilder<>* b, int unroll_factor = 1);
ParallelLoopEmitter(const ParallelLoopEmitter&) = delete;
ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete;
~ParallelLoopEmitter() override = default;
std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name, llvm::Type* index_type) override;
+ absl::string_view loop_name, llvm::Type* index_type) override;
private:
// The thread and block dimension to parallelize the loop on.
diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
index d3fd0544fb..cf9f102d31 100644
--- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
@@ -18,15 +18,15 @@ limitations under the License.
#include <ostream>
#include <string>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/bits.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -34,9 +34,8 @@ namespace gpu {
std::ostream& operator<<(std::ostream& out,
const LaunchDimensions& launch_dims) {
- out << tensorflow::strings::Printf("[block: %lld, thread: %lld]",
- launch_dims.block_count(),
- launch_dims.threads_per_block());
+ out << absl::StrFormat("[block: %d, thread: %d]", launch_dims.block_count(),
+ launch_dims.threads_per_block());
return out;
}
@@ -91,9 +90,9 @@ LaunchDimensions CalculateLaunchDimensions(
}
int64 block_count = CeilOfRatio(num_elements, threads_per_block);
- VLOG(2) << tensorflow::strings::Printf(
+ VLOG(2) << absl::StrFormat(
"Initialized the block count to ceil(# of elements / threads per "
- "block) = ceil(%lld/%lld) = %lld",
+ "block) = ceil(%d/%d) = %d",
num_elements, threads_per_block, block_count);
return LaunchDimensions(block_count, threads_per_block);
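
The block count above is plain ceiling division. A worked example of the arithmetic behind CeilOfRatio, matching the VLOG message in this hunk:

// CeilOfRatio(num_elements, threads_per_block)
//   == (num_elements + threads_per_block - 1) / threads_per_block.
// With num_elements = 1000000 and threads_per_block = 256:
//   block_count = (1000000 + 255) / 256 = 3907
// and the VLOG(2) above prints
//   "... = ceil(1000000/256) = 3907".
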
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc
index 0806dd5161..5b6cf2c04d 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
@@ -119,7 +119,7 @@ int ComputeStreamToAssign(
} // namespace
std::unique_ptr<StreamAssignment> AssignStreams(const HloModule& module) {
- auto stream_assignment = MakeUnique<StreamAssignment>();
+ auto stream_assignment = absl::make_unique<StreamAssignment>();
const HloComputation& computation = *module.entry_computation();
std::unique_ptr<HloReachabilityMap> reachability =
computation.ComputeReachability();
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
index 6f4bb0580e..c4f43cc9a6 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
@@ -15,25 +15,27 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
namespace gpu {
-class StreamAssignmentTest : public HloTestBase {
+class StreamAssignmentTest : public HloVerifiedTestBase {
protected:
std::unique_ptr<HloModule> CreateNewModule() {
HloModuleConfig config;
auto debug_options = GetDebugOptionsForTest();
debug_options.set_xla_gpu_disable_multi_streaming(false);
config.set_debug_options(debug_options);
- return MakeUnique<HloModule>("test_module", config);
+ return absl::make_unique<HloModule>("test_module", config);
}
// Pre-canned shapes.
@@ -48,10 +50,10 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) {
/*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
- HloInstruction* dot1 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
- HloInstruction* dot2 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z));
+ HloInstruction* dot1 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
+ HloInstruction* dot2 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(dot2));
@@ -67,10 +69,10 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) {
/*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
- HloInstruction* dot1 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
- HloInstruction* dot2 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, y, x));
+ HloInstruction* dot1 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
+ HloInstruction* dot2 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, y, x));
HloInstruction* add = builder.AddInstruction(
HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2));
@@ -97,26 +99,26 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) {
params.reserve(6);
for (int i = 0; i < 6; ++i) {
params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
- i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i))));
+ i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i))));
}
HloInstruction* d00 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3]));
- HloInstruction* d10 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00));
- HloInstruction* d11 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4]));
- HloInstruction* d20 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10));
- HloInstruction* d21 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11));
- HloInstruction* d22 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5]));
- HloInstruction* d30 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21));
- HloInstruction* d31 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22));
- HloInstruction* d40 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31));
+ CreateCanonicalDot(f32_2x2_, params[2], params[3]));
+ HloInstruction* d10 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[1], d00));
+ HloInstruction* d11 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d00, params[4]));
+ HloInstruction* d20 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[0], d10));
+ HloInstruction* d21 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d10, d11));
+ HloInstruction* d22 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d11, params[5]));
+ HloInstruction* d30 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d20, d21));
+ HloInstruction* d31 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d21, d22));
+ HloInstruction* d40 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(d40));
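
The dot-construction churn above tracks one API move: the tests now call an unqualified CreateCanonicalDot, picked up from the newly included tensorflow/compiler/xla/tests/test_utils.h, instead of the old HloInstruction::CreateCanonicalDot static. In sketch form:

// Before:
//   builder.AddInstruction(
//       HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
// After (free function from tests/test_utils.h):
builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
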
diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
index 05b305ea4c..08ff52211a 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/util.h"
namespace xla {
namespace gpu {
@@ -53,8 +55,9 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
input_layout.push_back(dnums.input_feature_dimension());
break;
default:
- return tensorflow::errors::Internal("Invalid input layout: ",
- DataLayoutString(input));
+ return InternalError("Invalid input layout %s for conv with dnums %s",
+ DataLayoutString(input),
+ ConvolutionDimensionNumbersToString(dnums));
}
std::vector<int64> filter_layout;
@@ -74,8 +77,9 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
filter_layout.push_back(dnums.kernel_input_feature_dimension());
break;
default:
- return tensorflow::errors::Internal("Invalid filter layout: ",
- FilterLayoutString(filter));
+ return InternalError("Invalid filter layout %s for conv with dnums %s",
+ FilterLayoutString(filter),
+ ConvolutionDimensionNumbersToString(dnums));
}
std::vector<int64> output_layout;
@@ -95,8 +99,9 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
output_layout.push_back(dnums.output_feature_dimension());
break;
default:
- return tensorflow::errors::Internal("Invalid output layout: ",
- DataLayoutString(output));
+ return InternalError("Invalid output layout %s for conv with dnums %s",
+ DataLayoutString(output),
+ ConvolutionDimensionNumbersToString(dnums));
}
return std::make_tuple(LayoutUtil::MakeLayoutFromMajorToMinor(input_layout),
@@ -128,8 +133,9 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
} else if (LayoutUtil::Equal(input, nhwc_input)) {
input_layout = DataLayout::kBatchYXDepth;
} else {
- return tensorflow::errors::Internal("Invalid input layout: ",
- input.ShortDebugString());
+ return InternalError("Invalid input layout %s for conv with dnums %s",
+ LayoutUtil::HumanString(input),
+ ConvolutionDimensionNumbersToString(dnums));
}
FilterLayout filter_layout;
@@ -138,8 +144,9 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
} else if (LayoutUtil::Equal(filter, nhwc_filter)) {
filter_layout = FilterLayout::kOutputYXInput;
} else {
- return tensorflow::errors::Internal("Invalid filter layout: ",
- filter.ShortDebugString());
+ return InternalError("Invalid filter layout %s for conv with dnums %s",
+ LayoutUtil::HumanString(filter),
+ ConvolutionDimensionNumbersToString(dnums));
}
DataLayout output_layout;
@@ -148,8 +155,9 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
} else if (LayoutUtil::Equal(output, nhwc_output)) {
output_layout = DataLayout::kBatchYXDepth;
} else {
- return tensorflow::errors::Internal("Invalid output layout: ",
- output.ShortDebugString());
+ return InternalError("Invalid output layout %s for conv with dnums %s",
+ LayoutUtil::HumanString(output),
+ ConvolutionDimensionNumbersToString(dnums));
}
return std::make_tuple(input_layout, filter_layout, output_layout);
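
The error-message rewrites above also change the calling convention: tensorflow::errors::Internal concatenated its arguments, while the InternalError helper from xla/util.h takes a printf-style format string, which is why %s placeholders appear here and the .c_str() call was dropped in outfeed_thunk.cc earlier in this diff. A before/after sketch:

// Before: variadic concatenation.
//   return tensorflow::errors::Internal("Invalid input layout: ",
//                                       DataLayoutString(input));
// After: printf-style formatting; strings are passed directly, and the
// message now carries the convolution dimension numbers as well.
return InternalError("Invalid input layout %s for conv with dnums %s",
                     DataLayoutString(input),
                     ConvolutionDimensionNumbersToString(dnums));
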
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index 4fad3f46cf..db4a33dc56 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -35,13 +35,13 @@ cc_library(
"requires-gpu-sm35",
],
deps = [
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:gpu_plugin",
"//tensorflow/compiler/xla/service/gpu:gpu_executable",
"//tensorflow/compiler/xla/tests:filecheck",
"//tensorflow/compiler/xla/tests:llvm_irgen_test_base",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
)
@@ -60,6 +60,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -94,6 +95,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -150,6 +152,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
@@ -168,6 +171,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc
index 4b8415fe91..79e77d4c4d 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
#include "tensorflow/core/platform/logging.h"
@@ -32,15 +32,14 @@ std::unique_ptr<HloModule> GpuCodegenTest::CreateNewModuleWithFTZ(bool ftz) {
debug_options.add_xla_disable_hlo_passes("constant_folding");
config.set_debug_options(debug_options);
- return MakeUnique<HloModule>(TestName(), config);
+ return absl::make_unique<HloModule>(TestName(), config);
}
void GpuCodegenTest::CompileAndVerifyPtx(std::unique_ptr<HloModule> hlo_module,
const string& pattern) {
std::unique_ptr<Executable> executable =
std::move(CompileToExecutable(std::move(hlo_module)).ValueOrDie());
- string ptx_str =
- std::string(static_cast<GpuExecutable*>(executable.get())->ptx());
+ string ptx_str(static_cast<GpuExecutable*>(executable.get())->ptx());
StatusOr<bool> filecheck_result = RunFileCheck(ptx_str, pattern);
ASSERT_TRUE(filecheck_result.ok());
EXPECT_TRUE(filecheck_result.ValueOrDie());
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
index ce69e058e6..780539c164 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
@@ -16,9 +16,9 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -38,8 +38,7 @@ class GpuCopyTest : public GpuCodegenTest {};
TEST_F(GpuCopyTest, UseMemcpy) {
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
builder.AddInstruction(HloInstruction::CreateUnary(
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc
index e5958165ef..a06576df7b 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
index cca35316f0..15d1e269cc 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
@@ -27,13 +27,22 @@ namespace {
class GpuKernelTilingTest : public GpuCodegenTest {
protected:
- GpuKernelTilingTest() {
+ GpuKernelTilingTest() {}
+
+ // Most tests in this file want to skip layout assignment, but a few need it
+ // enabled.
+ HloModuleConfig ConfigWithLayoutAssignment() {
+ return GetModuleConfigForTest();
+ }
+
+ HloModuleConfig ConfigWithoutLayoutAssignment() {
+ HloModuleConfig config;
auto debug_options = HloTestBase::GetDebugOptionsForTest();
- config_.set_debug_options(debug_options);
// Disable layout-assignment to use the preassigned layouts.
- debug_options.add_xla_disable_hlo_passes("layout_assignment");
+ debug_options.add_xla_disable_hlo_passes("layout-assignment");
+ config.set_debug_options(debug_options);
+ return config;
}
- HloModuleConfig config_;
};
TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) {
@@ -46,7 +55,13 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) {
})";
// Check that a call to llvm.nvvm.barrier0 is generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ //
+ // We must enable layout assignment in order for this test to work correctly.
+ // AlgebraicSimplifier removes copy1; it's added back by layout assignment,
+ // which respects the module's entry computation layout. But if we don't run
+ // layout assignment...well, nobody else adds the copy back.
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @copy
@@ -68,8 +83,11 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithSmallDimensionsNotTiled) {
ROOT copy1 = f16[2,3,64]{1,0,2} copy(para0)
})";
- // Check that a call to llvm.nvvm.barrier0 is not generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ // Check that a call to llvm.nvvm.barrier0 is not generated. As in
+ // UnnestedTransposeWithProperDimensionsTiled, we must run layout assignment
+ // here.
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @copy
@@ -95,7 +113,8 @@ TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) {
})";
// Check that a call to llvm.nvvm.barrier0 is generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @fusion
@@ -128,7 +147,8 @@ TEST_F(GpuKernelTilingTest, MultipleOutputFusionWithOnePossibleTransposeTiled) {
})";
// Check that a call to llvm.nvvm.barrier0 is generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @fusion
@@ -162,7 +182,8 @@ TEST_F(GpuKernelTilingTest,
})";
// Check that a call to llvm.nvvm.barrier0 is not generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @fusion
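
With the fixture change at the top of this file, each test now selects its module config explicitly instead of sharing config_. A usage sketch (kHloString stands in for the per-test HLO text):

TEST_F(GpuKernelTilingTest, SketchOfConfigSelection) {
  // Tests relying on preassigned layouts parse with layout assignment
  // disabled; tests that need copies re-materialized enable it.
  auto hlo_module =
      ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie();
  // ...then CompileAndVerifyIr(std::move(hlo_module), ...) as above.
}
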
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc
index 6c9ae7bada..6a9ecd9dae 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc
index c42e5704a4..15198865bd 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc
index 9622936306..0f2d5568ca 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc
@@ -138,6 +138,9 @@ TEST_F(GpuUnrollingTest, UnrollMultiOutputFusion) {
HloModuleConfig config;
auto debug_options = HloTestBase::GetDebugOptionsForTest();
debug_options.set_xla_gpu_max_kernel_unroll_factor(2);
+ // Disable layout assignment for this test. Layout assignment does not expect
+ // fusions to be present, and so it does the wrong thing.
+ debug_options.add_xla_disable_hlo_passes("layout-assignment");
config.set_debug_options(debug_options);
const char *const kMultiOutputFusionModule = R"(
diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
index 9072b30317..f8120a5fa0 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
@@ -53,40 +53,40 @@ class InfeedTest : public ClientLibraryTestBase {
};
TEST_F(InfeedTest, SingleInfeedR0Bool) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
+ TestInfeedRoundTrip(LiteralUtil::CreateR0<bool>(true));
}
TEST_F(InfeedTest, SingleInfeedR1U32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
+ TestInfeedRoundTrip(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}
TEST_F(InfeedTest, SingleInfeedR2F32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
+ TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
}
TEST_F(InfeedTest, SingleInfeedR3F32) {
TestInfeedRoundTrip(
- *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+ LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});
- TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0minor));
- TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
r3_dim0major));
}
TEST_F(InfeedTest, SingleInfeedR4S32) {
- TestInfeedRoundTrip(*LiteralUtil::CreateR4(
+ TestInfeedRoundTrip(LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
@@ -95,26 +95,26 @@ TEST_F(InfeedTest, SingleInfeedR4S32) {
TEST_F(InfeedTest, LargeInfeed) {
Array4D<float> array(80, 100, 8, 128);
array.FillIota(1.0f);
- TestInfeedRoundTrip(*LiteralUtil::CreateR4FromArray4D<float>(array));
+ TestInfeedRoundTrip(LiteralUtil::CreateR4FromArray4D<float>(array));
}
TEST_F(InfeedTest, SingleInfeedTuple) {
- TestInfeedRoundTrip(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
- LiteralUtil::CreateR0<bool>(false).get()}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<uint32>({1, 2, 3}),
+ LiteralUtil::CreateR0<bool>(false)}));
}
TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
- TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTuple({}));
}
// Tests that a large tuple infeed can be handled.
TEST_F(InfeedTest, SingleInfeedLargeTuple) {
Array4D<float> array(40, 100, 8, 128);
array.FillIota(1.0f);
- TestInfeedRoundTrip(*LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR4FromArray4D<float>(array).get(),
- LiteralUtil::CreateR0<int32>(5).get()}));
+ TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR4FromArray4D<float>(array),
+ LiteralUtil::CreateR0<int32>(5)}));
}
} // namespace
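
Every change in this test file is the same migration: LiteralUtil factories now return Literal by value instead of std::unique_ptr<Literal>, so the * dereferences and .get() calls disappear, and tuples of freshly built literals are assembled with MakeTupleFromSlices. Old versus new:

// Before: pointer semantics.
//   TestInfeedRoundTrip(*LiteralUtil::MakeTuple(
//       {LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
//        LiteralUtil::CreateR0<bool>(false).get()}));
// After: value semantics.
TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
    {LiteralUtil::CreateR1<uint32>({1, 2, 3}),
     LiteralUtil::CreateR0<bool>(false)}));
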
diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc
index bdb062837c..141f321938 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc
+++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc
@@ -144,16 +144,15 @@ const std::list<const Thunk*>& ThunkSchedule::DependsOn(
string ThunkSchedule::ToString() const {
string result = "Total order:\n";
for (Thunk* thunk : thunk_total_order_) {
- tensorflow::strings::StrAppend(&result, "\t",
- thunk->hlo_instruction()->ToString(), "\n");
+ absl::StrAppend(&result, "\t", thunk->hlo_instruction()->ToString(), "\n");
}
- tensorflow::strings::StrAppend(&result, "Dependencies:\n");
+ absl::StrAppend(&result, "Dependencies:\n");
for (const auto& entry : depends_on_) {
const Thunk* dependent = entry.first;
for (const Thunk* dependency : entry.second) {
- tensorflow::strings::StrAppend(
- &result, "\t", dependent->hlo_instruction()->name(), " depends on ",
- dependency->hlo_instruction()->name(), "\n");
+ absl::StrAppend(&result, "\t", dependent->hlo_instruction()->name(),
+ " depends on ", dependency->hlo_instruction()->name(),
+ "\n");
}
}
return result;
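This hunk is a pure namespace migration: tensorflow::strings::StrAppend becomes absl::StrAppend with identical variadic, append-in-place semantics. A self-contained sketch (absl/strings/str_cat.h provides both StrCat and StrAppend):

#include <string>
#include "absl/strings/str_cat.h"

std::string result = "Total order:\n";
// Appends all pieces in place; no intermediate string is materialized.
absl::StrAppend(&result, "\t", "some-thunk", "\n");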
diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
index 8579b1545f..989b542ff4 100644
--- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
@@ -25,7 +26,7 @@ Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
HloExecutionProfiler* profiler) {
auto size = tuple_element_buffers_.size();
- auto tuple_element_buffer_addresses = MakeUnique<void*[]>(size);
+ auto tuple_element_buffer_addresses = absl::make_unique<void*[]>(size);
for (int i = 0; i != size; ++i) {
tuple_element_buffer_addresses[i] =
buffer_allocations.GetDeviceAddress(tuple_element_buffers_[i]).opaque();
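xla::MakeUnique gives way to absl::make_unique here, including the array form. A small sketch of that overload:

#include "absl/memory/memory.h"

// make_unique<T[]>(n) allocates and value-initializes n elements; elements
// are reached through operator[], as with tuple_element_buffer_addresses.
auto addrs = absl::make_unique<void*[]>(4);
addrs[0] = nullptr;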
diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h
index 2d5735d6c4..dcdbf2cf3c 100644
--- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h
@@ -18,12 +18,12 @@ limitations under the License.
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
@@ -34,8 +34,7 @@ namespace gpu {
// issue (b/31336476).
class TupleThunk : public Thunk {
public:
- TupleThunk(tensorflow::gtl::ArraySlice<BufferAllocation::Slice>
- tuple_element_buffers,
+ TupleThunk(absl::Span<const BufferAllocation::Slice> tuple_element_buffers,
const BufferAllocation::Slice& dest_buffer,
const HloInstruction* hlo_instruction)
: Thunk(Kind::kTuple, hlo_instruction),
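tensorflow::gtl::ArraySlice<T> becomes absl::Span<const T>: a non-owning, cheaply copyable view over contiguous storage, so existing callers passing a std::vector or a braced list compile unchanged. An illustrative sketch (names hypothetical):

#include <cstdint>
#include <vector>
#include "absl/types/span.h"

int64_t TotalBytes(absl::Span<const int64_t> sizes) {
  int64_t total = 0;
  for (int64_t s : sizes) total += s;  // iterates the view; nothing is copied
  return total;
}
// std::vector<int64_t> v = {8, 16}; TotalBytes(v); TotalBytes({32, 64});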
diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
index d81d87e7dc..c4754fe378 100644
--- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -34,9 +34,9 @@ WhileThunk::WhileThunk(
// and body_thunk_sequence_ constructors because these SequentialThunks
// are logically "part of" this WhileThunk, and shouldn't be profiled
// separately from it.
- condition_thunk_sequence_(MakeUnique<SequentialThunk>(
+ condition_thunk_sequence_(absl::make_unique<SequentialThunk>(
std::move(*condition_thunk_sequence), nullptr)),
- body_thunk_sequence_(MakeUnique<SequentialThunk>(
+ body_thunk_sequence_(absl::make_unique<SequentialThunk>(
std::move(*body_thunk_sequence), nullptr)) {}
Status WhileThunk::Initialize(const GpuExecutable& executable,
@@ -70,7 +70,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
if (!block_status.ok()) {
return InternalError(
"Failed to complete all kernels launched on stream %p: %s", stream,
- block_status.error_message().c_str());
+ block_status.error_message());
}
if (!condition_result) {
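The dropped .c_str() reflects XLA's error helpers moving to a typed formatter in which %s accepts std::string directly; passing a C string is no longer required. This is an inference from the call-site change only:

// Before: InternalError("... stream %p: %s", stream,
//                       block_status.error_message().c_str());
// After (typed formatting accepts std::string for %s):
//   InternalError("... stream %p: %s", stream, block_status.error_message());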
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
index c5f3906356..40183de96e 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
@@ -118,7 +118,8 @@ class WhileTransformerTest : public HloTestBase {
}
void RunCopyInsertionPass() {
- HloVerifier verifier;
+ HloVerifier verifier(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
TF_ASSERT_OK(verifier.Run(module_.get()).status());
CopyInsertion copy_insertion;
TF_ASSERT_OK(copy_insertion.Run(module_.get()).status());
diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc
index aa89567ee8..ef70b68877 100644
--- a/tensorflow/compiler/xla/service/graphviz_example.cc
+++ b/tensorflow/compiler/xla/service/graphviz_example.cc
@@ -22,9 +22,10 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/types.h"
@@ -43,8 +43,7 @@ namespace {
// Adds a computation to the given HLO module which adds a scalar constant to
// its parameter and returns the result.
HloComputation* AddScalarConstantComputation(int64 addend, HloModule* module) {
- auto builder =
- HloComputation::Builder(tensorflow::strings::StrCat("add_", addend));
+ auto builder = HloComputation::Builder(absl::StrCat("add_", addend));
auto x_value = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "x_value"));
auto half = builder.AddInstruction(
@@ -84,7 +83,7 @@ HloComputation* CallForwardingComputation(HloComputation* computation,
// the module.
std::unique_ptr<HloModule> MakeBigGraph() {
HloModuleConfig config;
- auto module = MakeUnique<HloModule>("BigGraph", config);
+ auto module = absl::make_unique<HloModule>("BigGraph", config);
auto builder = HloComputation::Builder("TestBigGraphvizGraph");
@@ -113,8 +112,11 @@ std::unique_ptr<HloModule> MakeBigGraph() {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(vshape, clamp, param_v0, dot_dnums));
+ PrecisionConfig precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ /*new_size=*/2, PrecisionConfig::DEFAULT);
+ auto dot = builder.AddInstruction(HloInstruction::CreateDot(
+ vshape, clamp, param_v0, dot_dnums, precision_config));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({dot, param_s, clamp}));
auto scalar = builder.AddInstruction(
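CreateDot now takes an explicit PrecisionConfig with one entry per operand. The test diffs further down abbreviate this as DefaultPrecisionConfig(2); a plausible definition of such a helper, matching the Resize call above (the real helper lives in the HLO test utilities):

PrecisionConfig DefaultPrecisionConfig(int operand_count) {
  PrecisionConfig config;
  config.mutable_operand_precision()->Resize(operand_count,
                                             PrecisionConfig::DEFAULT);
  return config;
}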
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index 4005fc0d11..e0f3a7e0e2 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/util.h"
@@ -28,13 +29,13 @@ using tensorflow::gtl::FlatSet;
/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
- const SequentialHloOrdering::HloModuleSequence& module_sequence,
+ const HloSchedule& schedule,
const LogicalBuffer::SizeFunction& size_function) {
- if (module_sequence.empty()) {
+ if (schedule.empty()) {
return 0;
}
- const HloModule* module = module_sequence.begin()->first->parent();
+ const HloModule* module = schedule.module();
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(module));
@@ -45,37 +46,37 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
// bound, by minimizing the liveness of sub-computations.
TF_ASSIGN_OR_RETURN(
HeapSimulator::Result result,
- HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), *module,
- module_sequence, *points_to_analysis, size_function));
+ HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(), *module,
+ schedule, *points_to_analysis, size_function));
return result.heap_size;
}
/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
- const HloComputation& computation,
- const std::vector<const HloInstruction*>& sequence,
+ const HloComputation& computation, const HloInstructionSequence& sequence,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
memory_by_computation) {
TF_ASSIGN_OR_RETURN(
HeapSimulator::Result result,
- HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation,
- sequence, points_to_analysis, size_function,
- HeapSimulator::Options(), memory_by_computation));
+ HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(),
+ computation, sequence, points_to_analysis,
+ size_function, HeapSimulator::Options(),
+ memory_by_computation));
return result.heap_size;
}
/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
- const SequentialHloOrdering::HloModuleSequence& module_sequence,
+ const HloSchedule& schedule,
const TuplePointsToAnalysis& points_to_analysis,
const BufferValue::SizeFunction& size_fn, const Options& options) {
- HeapSimulator heap(std::move(algorithm), size_fn, options, &module_sequence);
+ HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule);
const HloComputation* entry_computation = module.entry_computation();
- const std::vector<const HloInstruction*>& instruction_sequence =
- FindOrDie(module_sequence, entry_computation);
+ const HloInstructionSequence& instruction_sequence =
+ schedule.sequence(entry_computation);
TF_RETURN_IF_ERROR(heap.RunComputation(
*entry_computation, instruction_sequence, points_to_analysis));
return heap.Finish();
@@ -84,13 +85,13 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
- const std::vector<const HloInstruction*>& instruction_sequence,
+ const HloInstructionSequence& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis,
const BufferValue::SizeFunction& size_fn, const Options& options,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
memory_by_computation) {
HeapSimulator heap(std::move(algorithm), size_fn, options,
- /*module_sequence=*/nullptr, memory_by_computation);
+ /*schedule=*/nullptr, memory_by_computation);
TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
points_to_analysis));
return heap.Finish();
@@ -100,7 +101,7 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
// 'instruction_sequence'.
Status HeapSimulator::RunComputation(
const HloComputation& computation,
- const std::vector<const HloInstruction*>& instruction_sequence,
+ const HloInstructionSequence& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis) {
VLOG(3) << "Computation:\n" << computation.ToString();
// The goal here is to minimize memory usage, assuming the given sequential
@@ -131,7 +132,8 @@ Status HeapSimulator::RunComputation(
// set of instructions that need to be visited contains all users of all
// aliases, that is, all users of all instructions that have the buffer
// contained in their points-to set.
- for (const HloInstruction* instruction : instruction_sequence) {
+ for (const HloInstruction* instruction :
+ instruction_sequence.instructions()) {
const PointsToSet& points_to =
points_to_analysis.GetPointsToSet(instruction);
const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet();
@@ -142,7 +144,7 @@ Status HeapSimulator::RunComputation(
}
} else {
// A GetTupleElement doesn't need to keep all of its operand's buffers
- // alive. It only needs the buffers that relate to the element its
+ // alive. It only needs the buffers that relate to the element it's
// extracting, and the tuple it's extracting from, but not the buffers
// for the other elements.
for (const BufferValue* buffer : points_to.element({})) {
@@ -164,7 +166,8 @@ Status HeapSimulator::RunComputation(
std::vector<const BufferValue*> dead_buffers_to_free;
std::vector<const BufferValue*> operand_buffers_to_free;
- for (const HloInstruction* instruction : instruction_sequence) {
+ for (const HloInstruction* instruction :
+ instruction_sequence.instructions()) {
const TuplePointsToAnalysis::BufferDefinitionVector&
buffers_defined_by_instruction =
points_to_analysis.GetBuffersDefinedByInstruction(instruction);
@@ -275,22 +278,22 @@ Status HeapSimulator::RunComputation(
*memory_by_computation_);
}
- // If the whole module is sequential, we can save memory by running the
- // heap-simulation for sub-computations inline. E.g. the buffers for the
- // condition and body of a kWhile instruction are only live for the duration
- // of the instruction itself.
+ // If all computations in the module have been scheduled, we can save memory
+ // by running the heap-simulation for sub-computations inline. E.g. the
+ // buffers for the condition and body of a kWhile instruction are only live
+ // for the duration of the instruction itself.
//
// The order that the sub-computations are simulated does not affect
- // correctness; since the whole module is sequential, we know that the
+ // correctness; since the whole module has been scheduled, we know that the
// sub-computations will never be run concurrently.
- if (module_sequence_ != nullptr) {
+ if (schedule_ != nullptr) {
if (instruction->opcode() == HloOpcode::kCall ||
instruction->opcode() == HloOpcode::kConditional ||
instruction->opcode() == HloOpcode::kWhile) {
for (const HloComputation* called_computation :
instruction->called_computations()) {
- const std::vector<const HloInstruction*>& called_sequence =
- FindOrDie(*module_sequence_, called_computation);
+ const HloInstructionSequence& called_sequence =
+ schedule_->sequence(called_computation);
TF_RETURN_IF_ERROR(RunComputation(
*called_computation, called_sequence, points_to_analysis));
}
@@ -341,16 +344,16 @@ Status HeapSimulator::RunComputation(
HeapSimulator::HeapSimulator(
std::unique_ptr<HeapAlgorithm> algorithm,
const BufferValue::SizeFunction& size_fn, const Options& options,
- const SequentialHloOrdering::HloModuleSequence* module_sequence,
+ const HloSchedule* schedule,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
memory_by_computation)
- : no_fragmentation_stats_(MakeUnique<NoFragmentationStatsHeap>()),
+ : no_fragmentation_stats_(absl::make_unique<NoFragmentationStatsHeap>()),
algorithm_(std::move(algorithm)),
size_fn_(size_fn),
options_(options),
- module_sequence_(module_sequence),
+ schedule_(schedule),
memory_by_computation_(memory_by_computation) {
- debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr);
+ debug_trace_.set_whole_module_simulation(schedule_ != nullptr);
}
HeapSimulator::~HeapSimulator() {}
@@ -378,9 +381,10 @@ void HeapSimulator::Alloc(const BufferValue* buffer,
allocated_buffers_.insert(buffer);
const int64 size = size_fn_(*buffer);
- algorithm_->Alloc(buffer, size);
- no_fragmentation_stats_->Alloc(buffer, size);
-
+ const HloInstruction* instruction_to_calc_aliasing =
+ memory_by_computation_ == nullptr ? nullptr : instruction;
+ algorithm_->Alloc(buffer, size, instruction_to_calc_aliasing);
+ no_fragmentation_stats_->Alloc(buffer, size, instruction_to_calc_aliasing);
FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction,
nullptr);
}
@@ -518,6 +522,18 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) {
}
}
+void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size,
+ const HloInstruction* instruction) {
+ // The output buffer of while/call/conditional is always aliased with the
+ // output buffer of the root instruction in the body. Don't double count.
+ if (instruction == nullptr ||
+ (instruction->opcode() != HloOpcode::kWhile &&
+ instruction->opcode() != HloOpcode::kCall &&
+ instruction->opcode() != HloOpcode::kConditional)) {
+ Alloc(buffer, size);
+ }
+}
+
void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
const HloInstruction* instruction,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
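The whole-module entry points now take an HloSchedule, which owns one HloInstructionSequence per non-fusion computation, instead of the free-floating FlatMap keyed by computation. A minimal caller sketch, with module and size_fn assumed in scope:

HloSchedule schedule(module);
schedule.set_sequence(module->entry_computation(),
                      {/* entry instructions in execution order */});
TF_RETURN_IF_ERROR(schedule.Verify());  // checks sequences match the module
TF_ASSIGN_OR_RETURN(
    int64 min_bytes, HeapSimulator::MinimumMemoryForModule(schedule, size_fn));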
diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h
index 811a6042df..ffbf947d5a 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.h
+++ b/tensorflow/compiler/xla/service/heap_simulator.h
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -36,6 +37,7 @@ namespace xla {
// Forward declare classes defined below.
class HeapAlgorithm;
+class NoFragmentationStatsHeap;
// HeapSimulator assigns buffer offsets by running a simulation of a regular
// memory heap with Alloc and Free calls. It only works for completely
@@ -87,23 +89,22 @@ class HeapSimulator {
// Returns the minimum memory required to compute an HLO module where all
// computations have been scheduled (represented by the given
- // module_sequence), assuming no fragmentation.
+ // schedule), assuming no fragmentation.
static StatusOr<int64> MinimumMemoryForModule(
- const SequentialHloOrdering::HloModuleSequence& module_sequence,
+ const HloSchedule& schedule,
const LogicalBuffer::SizeFunction& size_function);
// Returns the minimum memory required to compute the given computation,
// assuming no fragmentation.
static StatusOr<int64> MinimumMemoryForComputation(
- const HloComputation& computation,
- const std::vector<const HloInstruction*>& sequence,
+ const HloComputation& computation, const HloInstructionSequence& sequence,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
memory_by_computation = nullptr);
// Run the heap simulation with the given algorithm, assuming the given
- // module_sequence, which must contain a topologically-consistent total
+ // schedule, which must contain a topologically-consistent total
// ordering of all instructions within each computation. The result is invalid
// if instructions are not run in exactly this sequence.
//
@@ -111,12 +112,12 @@ class HeapSimulator {
// to running on a per-computation basis, since we can re-use buffer space for
// called sub-computations.
//
- static StatusOr<Result> Run(
- std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
- const SequentialHloOrdering::HloModuleSequence& module_sequence,
- const TuplePointsToAnalysis& points_to_analysis,
- const BufferValue::SizeFunction& size_fn,
- const Options& options = Options());
+ static StatusOr<Result> Run(std::unique_ptr<HeapAlgorithm> algorithm,
+ const HloModule& module,
+ const HloSchedule& schedule,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const BufferValue::SizeFunction& size_fn,
+ const Options& options = Options());
// Same as above, but runs on a single computation. The 'instruction_sequence'
// must contain a topologically-consistent total ordering of all instructions
@@ -125,7 +126,7 @@ class HeapSimulator {
static StatusOr<Result> Run(
std::unique_ptr<HeapAlgorithm> algorithm,
const HloComputation& computation,
- const std::vector<const HloInstruction*>& instruction_sequence,
+ const HloInstructionSequence& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis,
const BufferValue::SizeFunction& size_fn,
const Options& options = Options(),
@@ -133,21 +134,19 @@ class HeapSimulator {
memory_by_computation = nullptr);
private:
- // If 'module_sequence' is non-null, it is used to find kCall and kWhile
+ // If 'schedule' is non-null, it is used to find kCall and kWhile
// sub-computations, and the heap simulation for those sub-computations will
// be run recursively. I.e. the simulation is run over the whole module.
- HeapSimulator(
- std::unique_ptr<HeapAlgorithm> algorithm,
- const BufferValue::SizeFunction& size_fn, const Options& options,
- const SequentialHloOrdering::HloModuleSequence* module_sequence = nullptr,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
- memory_by_computation = nullptr);
+ HeapSimulator(std::unique_ptr<HeapAlgorithm> algorithm,
+ const BufferValue::SizeFunction& size_fn,
+ const Options& options, const HloSchedule* schedule = nullptr,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation = nullptr);
~HeapSimulator();
- Status RunComputation(
- const HloComputation& computation,
- const std::vector<const HloInstruction*>& instruction_sequence,
- const TuplePointsToAnalysis& points_to_analysis);
+ Status RunComputation(const HloComputation& computation,
+ const HloInstructionSequence& instruction_sequence,
+ const TuplePointsToAnalysis& points_to_analysis);
bool IgnoreBuffer(const BufferValue* buffer) const;
void Alloc(const BufferValue* buffer, const HloInstruction* instruction);
@@ -161,15 +160,18 @@ class HeapSimulator {
const HloInstruction* instruction,
const BufferValue* shared_with_canonical);
- const std::unique_ptr<HeapAlgorithm> no_fragmentation_stats_;
+ // Counterintuitive: the algorithm_ itself can be a NoFragmentationStatsHeap,
+ // in which case we are calculating the same allocs/frees twice in the
+ // simulation.
+ const std::unique_ptr<NoFragmentationStatsHeap> no_fragmentation_stats_;
const std::unique_ptr<HeapAlgorithm> algorithm_;
const BufferValue::SizeFunction size_fn_;
const Options options_;
- // module_sequence_ is set by buffer assignment, and memory_by_computation_ is
+ // schedule_ is set by buffer assignment, and memory_by_computation_ is
// set by hlo scheduling. Then, in RunComputation, we check both in order to
// handle subcomputations. It would be good to unify the handling of
// subcomputations, but it's not clear how.
- const SequentialHloOrdering::HloModuleSequence* module_sequence_;
+ const HloSchedule* schedule_;
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
memory_by_computation_;
@@ -216,6 +218,21 @@ class HeapAlgorithm {
// Alloc allocates a buffer of 'size' bytes.
virtual void Alloc(const BufferValue* buffer, int64 size) = 0;
+ // NoFragmentationStatsHeap overrides this method.
+ virtual void Alloc(const BufferValue* buffer, int64 size,
+ const HloInstruction* instruction) {
+ Alloc(buffer, size);
+ }
+
+ // Takes memory usage of subcomputations into account when calculating the
+ // memory usage of a computation. Currently, we don't handle buffer aliasing
+ // between computations entirely correctly. We are careful to not double count
+ // for the output buffers of whiles/conds/calls. But we don't take into
+ // account other aliases, such as for the while init. A more thorough solution
+ // would require something like BufferAssignment::BuildColocatedBufferSets.
+ // TODO(b/65835246):
+ // Since TuplePointsToAnalysis is being replaced with a module-aware alias
+ // analysis, it's not worth making major changes to HeapSimulator now.
virtual void AccountForSubcomputationMemory(
const HloInstruction* instruction,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
@@ -240,6 +257,9 @@ class NoFragmentationStatsHeap : public HeapAlgorithm {
void Alloc(const BufferValue* buffer, int64 size) override;
+ void Alloc(const BufferValue* buffer, int64 size,
+ const HloInstruction* instruction) override;
+
void AccountForSubcomputationMemory(
const HloInstruction* instruction,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
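Only NoFragmentationStatsHeap overrides the new three-argument Alloc; the base class forwards to the two-argument form, so existing algorithms remain source-compatible. A sketch, assuming Free and Finish are the remaining pure virtuals:

class CountingHeap : public HeapAlgorithm {
 public:
  void Alloc(const BufferValue* buffer, int64 size) override { used_ += size; }
  void Free(const BufferValue* buffer, int64 size) override { used_ -= size; }
  Result Finish() override { return Result(); }  // no offsets assigned here

 private:
  int64 used_ = 0;
};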
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index b41dc66fe9..957c4a6891 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -28,13 +29,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
namespace {
-class MinimumMemoryForSequenceTest : public HloTestBase {};
+class MinimumMemoryForSequenceTest : public HloVerifiedTestBase {};
TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
auto module = CreateNewModule();
@@ -84,13 +86,16 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
};
- SequentialHloOrdering::HloModuleSequence module_sequence;
- module_sequence[cond_computation] = {cond_param, cond_iter, cond_data,
- cond_lt};
- module_sequence[body_computation] = {body_param};
- module_sequence[entry_computation] = {iter, data, tuple, while_op};
- EXPECT_EQ(56, HeapSimulator::MinimumMemoryForModule(module_sequence, size_fn)
- .ValueOrDie());
+ HloSchedule schedule(module);
+ schedule.set_sequence(cond_computation,
+ {cond_param, cond_iter, cond_data, cond_lt});
+ schedule.set_sequence(body_computation, {body_param});
+ schedule.set_sequence(entry_computation, {iter, data, tuple, while_op});
+ TF_ASSERT_OK(schedule.Verify());
+
+ EXPECT_EQ(
+ 56,
+ HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie());
}
const char kAlloc[] = "Alloc";
@@ -137,7 +142,7 @@ class HeapSimulatorTracker {
const string& name, std::unique_ptr<HloComputation> computation,
const std::vector<const HloInstruction*>& instruction_sequence) {
HloModuleConfig config;
- module_ = MakeUnique<HloModule>(name, config);
+ module_ = absl::make_unique<HloModule>(name, config);
module_->AddEntryComputation(std::move(computation));
points_to_analysis_ =
TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
@@ -146,17 +151,18 @@ class HeapSimulatorTracker {
// the secondary sorting criteria of DecreasingSizeRunsHeap to sort calls by
// buffer id, for determinism in the tests.
auto zero_size = [](const BufferValue& buffer) { return 0; };
- auto algorithm = MakeUnique<DecreasingSizeRunsHeap>(
- MakeUnique<HeapCallRecorder>(&actual_calls_));
- result_ = HeapSimulator::Run(
- std::move(algorithm), *module_->entry_computation(),
- instruction_sequence, *points_to_analysis_, zero_size)
- .ConsumeValueOrDie();
+ auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>(
+ absl::make_unique<HeapCallRecorder>(&actual_calls_));
+ result_ =
+ HeapSimulator::Run(std::move(algorithm), *module_->entry_computation(),
+ HloInstructionSequence(instruction_sequence),
+ *points_to_analysis_, zero_size)
+ .ConsumeValueOrDie();
}
explicit HeapSimulatorTracker(const string& name) {
HloModuleConfig config;
- module_ = MakeUnique<HloModule>(name, config);
+ module_ = absl::make_unique<HloModule>(name, config);
}
// Similar to the single entry computation constructor above, but runs the
@@ -167,11 +173,12 @@ class HeapSimulatorTracker {
TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
// Construct the module sequence grouped by computation.
- SequentialHloOrdering::HloModuleSequence module_sequence;
+ HloSchedule schedule(module_.get());
tensorflow::gtl::FlatMap<const HloInstruction*, int> reverse_position;
for (int i = 0; i < full_module_sequence.size(); ++i) {
const HloInstruction* instruction = full_module_sequence[i];
- module_sequence[instruction->parent()].push_back(instruction);
+ schedule.GetOrCreateSequence(instruction->parent())
+ .push_back(instruction);
reverse_position[instruction] = full_module_sequence.size() - i;
}
@@ -182,10 +189,10 @@ class HeapSimulatorTracker {
auto size_fn = [&reverse_position](const BufferValue& buffer) {
return reverse_position[buffer.instruction()];
};
- auto algorithm = MakeUnique<DecreasingSizeRunsHeap>(
- MakeUnique<HeapCallRecorder>(&actual_calls_));
- result_ = HeapSimulator::Run(std::move(algorithm), *module_,
- module_sequence, *points_to_analysis_, size_fn)
+ auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>(
+ absl::make_unique<HeapCallRecorder>(&actual_calls_));
+ result_ = HeapSimulator::Run(std::move(algorithm), *module_, schedule,
+ *points_to_analysis_, size_fn)
.ConsumeValueOrDie();
}
@@ -226,7 +233,7 @@ class HeapSimulatorTracker {
HeapSimulator::Result result_;
};
-class HeapSimulatorTest : public HloTestBase {
+class HeapSimulatorTest : public HloVerifiedTestBase {
protected:
HeapSimulatorTest() {}
~HeapSimulatorTest() override {}
@@ -365,8 +372,8 @@ TEST_F(HeapSimulatorTest, MultiplyDot) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
+ auto dot = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
// The buffer for dot is the output, and it cannot be shared with the buffer
// for mul, since dot isn't elementwise.
@@ -401,8 +408,8 @@ TEST_F(HeapSimulatorTest, MultiplyDotAdd) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
+ auto dot = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA));
@@ -439,10 +446,10 @@ TEST_F(HeapSimulatorTest, MultiplyDotDot) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot0 = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
- auto dot1 = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums));
+ auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
+ auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));
// The buffer for dot1 is the output. No buffers can be shared. The buffer
// for mul is freed before the end, since it's no longer used after dot0
@@ -480,10 +487,10 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot0 = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
- auto dot1 = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums));
+ auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
+ auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));
auto tuple =
builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1}));
@@ -675,7 +682,8 @@ class HeapAlgorithmTestBase : public ::testing::Test {
const BufferValue::Id id = buffers_.size();
auto const0 = builder_.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
- buffers_.emplace_back(MakeUnique<HloValue>(id, const0, ShapeIndex{}));
+ buffers_.emplace_back(
+ absl::make_unique<HloValue>(id, const0, ShapeIndex{}));
return buffers_.back().get();
}
@@ -724,7 +732,8 @@ class DecreasingSizeRunsHeapTest : public HeapAlgorithmTestBase {};
TEST_F(DecreasingSizeRunsHeapTest, Empty) {
CallSequence call_sequence;
- DecreasingSizeRunsHeap heap(MakeUnique<HeapCallRecorder>(&call_sequence));
+ DecreasingSizeRunsHeap heap(
+ absl::make_unique<HeapCallRecorder>(&call_sequence));
heap.Finish();
EXPECT_EQ(call_sequence, CallSequence({
{kFinish, nullptr},
@@ -733,7 +742,8 @@ TEST_F(DecreasingSizeRunsHeapTest, Empty) {
TEST_F(DecreasingSizeRunsHeapTest, Simple) {
CallSequence call_sequence;
- DecreasingSizeRunsHeap heap(MakeUnique<HeapCallRecorder>(&call_sequence));
+ DecreasingSizeRunsHeap heap(
+ absl::make_unique<HeapCallRecorder>(&call_sequence));
heap.Alloc(buffer_a_, 10);
heap.Alloc(buffer_b_, 20);
heap.Alloc(buffer_c_, 30);
@@ -760,7 +770,8 @@ TEST_F(DecreasingSizeRunsHeapTest, Simple) {
TEST_F(DecreasingSizeRunsHeapTest, Mixed) {
CallSequence call_sequence;
- DecreasingSizeRunsHeap heap(MakeUnique<HeapCallRecorder>(&call_sequence));
+ DecreasingSizeRunsHeap heap(
+ absl::make_unique<HeapCallRecorder>(&call_sequence));
heap.Alloc(buffer_a_, 10);
heap.Alloc(buffer_b_, 20);
heap.Free(buffer_b_, 20);
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 9d24b42401..93ec2c9438 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;
// Serialization of HloInstruction.
-// Next ID: 51
+// Next ID: 53
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@@ -46,6 +46,8 @@ message HloInstructionProto {
reserved "control_predecessor_names";
reserved 6;
reserved "called_computation_names";
+ reserved 44;
+ reserved "replica_group_ids";
string name = 1;
string opcode = 2;
@@ -139,7 +141,7 @@ message HloInstructionProto {
// Gather dimension numbers.
xla.GatherDimensionNumbers gather_dimension_numbers = 33;
- repeated int64 gather_window_bounds = 34;
+ repeated int64 gather_slice_sizes = 34;
// Compute Host.
string channel_name = 41;
@@ -158,9 +160,6 @@ message HloInstructionProto {
string backend_config = 43;
// Cross replica op fields.
- // TODO(b/112107579): remove replica_group_ids field and always use
- // replica_groups.
- repeated int64 replica_group_ids = 44;
repeated ReplicaGroup replica_groups = 49;
int64 all_reduce_id = 45;
string cross_replica_sum_barrier = 46;
@@ -171,6 +170,12 @@ message HloInstructionProto {
bool is_host_transfer = 47;
xla.ScatterDimensionNumbers scatter_dimension_numbers = 48;
+
+ // Precision configuration for the instruction. Has backend-specific meaning.
+ xla.PrecisionConfig precision_config = 51;
+
+ // Collective permute field.
+ repeated SourceTarget source_target_pairs = 52;
}
// Serialization of HloComputation.
@@ -194,6 +199,17 @@ message HloComputationProto {
int64 root_id = 6;
}
+// Serialization of an HLO schedule. An HLO schedule contains a total order of
+// instructions for each non-fusion computation in the module.
+message HloScheduleProto {
+ message InstructionSequence {
+ repeated int64 instruction_ids = 1;
+ }
+
+ // Map from computation id to sequence.
+ map<int64, InstructionSequence> sequences = 1;
+}
+
// Serialization of HloModule.
message HloModuleProto {
string name = 1;
@@ -209,16 +225,9 @@ message HloModuleProto {
// The id of this module.
int64 id = 5;
-}
-// Serialization of HloOrdering.
-message HloOrderingProto {
- // NOTE: currently only sequential orderings are serialized.
- message SequentialComputation {
- string computation_name = 1;
- repeated string instruction_names = 2;
- }
- repeated SequentialComputation sequential_computations = 1;
+ // The schedule for this module.
+ HloScheduleProto schedule = 7;
}
// Serialization of LogicalBuffer.
@@ -317,8 +326,10 @@ message BufferAssignmentProto {
// Grouping message that contains all of the information above.
message HloProto {
+ reserved 2;
+ reserved "hlo_ordering";
+
HloModuleProto hlo_module = 1;
- HloOrderingProto hlo_ordering = 2;
BufferAssignmentProto buffer_assignment = 3;
}
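Schedules are now serialized by instruction id inside HloModuleProto (field 7), replacing the name-based HloOrderingProto, and HloProto field 2 is reserved so old wire data stays decodable. Building the new message via the generated protobuf API (the ids are hypothetical):

HloScheduleProto schedule_proto;
HloScheduleProto::InstructionSequence seq;
seq.add_instruction_ids(12);
seq.add_instruction_ids(13);
(*schedule_proto.mutable_sequences())[/*computation_id=*/7] = seq;
*module_proto.mutable_schedule() = schedule_proto;  // HloModuleProto field 7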
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
index e8a4b034b4..0986da65cb 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -28,15 +30,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
+using absl::StrAppend;
// Data structure used to construct the alias analysis. Thrown away after alias
// analysis is complete. This data structure keeps track of which sets of
@@ -414,7 +412,7 @@ Status HloAliasAnalysis::Verify() const {
}
string HloAliasAnalysis::ToString() const {
- string out = StrCat("HloAliasAnalysis, module ", module_->name(), "\n");
+ string out = absl::StrCat("HloAliasAnalysis, module ", module_->name(), "\n");
StrAppend(&out, " Buffers at each position:\n");
for (const HloComputation* computation : module_->computations()) {
for (const HloInstruction* instruction : computation->instructions()) {
@@ -457,7 +455,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
VLOG(2) << "HloAliasAnalysis::Run on module " << module->name();
XLA_VLOG_LINES(2, module->ToString());
- auto alias_analysis = WrapUnique(new HloAliasAnalysis(module));
+ auto alias_analysis = absl::WrapUnique(new HloAliasAnalysis(module));
TF_ASSIGN_OR_RETURN(alias_analysis->dataflow_analysis_,
HloDataflowAnalysis::Run(*module, /*ssa_form=*/true,
/*bitcast_defines_value=*/false,
@@ -537,10 +535,10 @@ bool HloAliasAnalysis::HasLiveRangeInterference(
if (ordering.MayInterfere(*values[i - 1], *values[i],
dataflow_analysis())) {
VLOG(1) << "In buffer " << buffer.id() << " containing values:\n "
- << Join(values, ", ",
- [](string* out, const HloValue* value) {
- StrAppend(out, value->ToShortString());
- })
+ << absl::StrJoin(values, ", ",
+ [](string* out, const HloValue* value) {
+ StrAppend(out, value->ToShortString());
+ })
<< "\nValue " << values[i - 1]->ToShortString()
<< " may interfere with value " << values[i]->ToShortString();
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h
index 1fea544730..e345804537 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -29,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
index da94ab5346..0cd0ab36fc 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
@@ -39,15 +39,17 @@ namespace {
using ::testing::UnorderedElementsAre;
-class HloAliasAnalysisTest : public HloTestBase {
+class HloAliasAnalysisTest : public HloVerifiedTestBase {
protected:
- HloAliasAnalysisTest() : module_(CreateNewModule()) {}
+ HloAliasAnalysisTest() : HloVerifiedTestBase() {
+ module_ = CreateNewModule();
+ }
// Run alias analysis on the member module. For convenience returns a
// reference to the generated analysis stored in analysis_.
HloAliasAnalysis& RunAnalysis() {
hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before alias analysis");
- analysis_ = HloAliasAnalysis::Run(module_.get(),
+ analysis_ = HloAliasAnalysis::Run(module_,
/*fusion_can_share_buffer=*/nullptr)
.ConsumeValueOrDie();
return *analysis_;
@@ -91,7 +93,7 @@ class HloAliasAnalysisTest : public HloTestBase {
// never occurs, but HLO graphs with interference can be explicitly
// constructed.
bool AnyValuesInSameBufferInterfere() {
- DependencyHloOrdering ordering(module_.get());
+ DependencyHloOrdering ordering(module_);
for (const HloBuffer& buffer : analysis_->buffers()) {
for (const HloValue* value_a : buffer.values()) {
for (const HloValue* value_b : buffer.values()) {
@@ -108,7 +110,7 @@ class HloAliasAnalysisTest : public HloTestBase {
return false;
}
- std::unique_ptr<HloModule> module_;
+ HloModule* module_;
std::unique_ptr<HloAliasAnalysis> analysis_;
const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
@@ -461,7 +463,7 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) {
module_->AddEntryComputation(builder.Build());
FlattenCallGraph flattener;
- TF_ASSERT_OK(flattener.Run(module_.get()).status());
+ TF_ASSERT_OK(flattener.Run(module_).status());
const HloAliasAnalysis& analysis = RunAnalysis();
@@ -835,7 +837,7 @@ TEST_F(HloAliasAnalysisTest, BitcastInterference) {
const HloAliasAnalysis& analysis = RunAnalysis();
- DependencyHloOrdering ordering(module_.get());
+ DependencyHloOrdering ordering(module_);
EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering));
}
@@ -877,24 +879,26 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) {
{
// Dependency ordering should interfere because the negate and while are
// unordered.
- DependencyHloOrdering ordering(module_.get());
+ DependencyHloOrdering ordering(module_);
EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering));
}
  // For a sequential order, there is interference iff the negate is after
  // the while.
- SequentialHloOrdering::HloModuleSequence sequence;
- sequence[body] = {body_param, body_root};
- sequence[condition] = {cond_param, cond_root};
+ HloSchedule schedule(module_);
+ schedule.set_sequence(body, {body_param, body_root});
+ schedule.set_sequence(condition, {cond_param, cond_root});
{
- sequence[entry] = {init, xla_while, negate, entry_root};
- SequentialHloOrdering ordering(module_.get(), sequence);
+ schedule.set_sequence(entry, {init, xla_while, negate, entry_root});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering));
}
{
- sequence[entry] = {init, negate, xla_while, entry_root};
- SequentialHloOrdering ordering(module_.get(), sequence);
+ schedule.set_sequence(entry, {init, negate, xla_while, entry_root});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering));
}
}
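SequentialHloOrdering is now constructed from a verified HloSchedule rather than a (module, sequence-map) pair, and set_sequence can be re-invoked to swap one computation's order between checks, as the test above does. Once built, the ordering answers the usual happens-before queries (a sketch; ExecutesBefore is inherited from HloOrdering):

SequentialHloOrdering ordering(schedule);
bool before = ordering.ExecutesBefore(negate, xla_while);  // per the schedule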
diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc
index e16413f361..6c11a073b7 100644
--- a/tensorflow/compiler/xla/service/hlo_buffer.cc
+++ b/tensorflow/compiler/xla/service/hlo_buffer.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -27,15 +29,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrCat;
-
bool HloBuffer::operator==(const HloBuffer& other) const {
bool equal = id() == other.id();
if (equal) {
@@ -59,10 +56,11 @@ std::vector<HloPosition> HloBuffer::ComputePositions() const {
}
string HloBuffer::ToString() const {
- return StrCat("HloBuffer ", id_, ", values: ",
- Join(values_, ", ", [](string* result, const HloValue* value) {
- result->append(value->ToShortString());
- }));
+ return absl::StrCat(
+ "HloBuffer ", id_, ", values: ",
+ absl::StrJoin(values_, ", ", [](string* result, const HloValue* value) {
+ result->append(value->ToShortString());
+ }));
}
std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer) {
diff --git a/tensorflow/compiler/xla/service/hlo_buffer.h b/tensorflow/compiler/xla/service/hlo_buffer.h
index 4873463b2e..a88c87e46c 100644
--- a/tensorflow/compiler/xla/service/hlo_buffer.h
+++ b/tensorflow/compiler/xla/service/hlo_buffer.h
@@ -84,7 +84,7 @@ class HloBuffer {
return a->id() == b->id();
}
- HloBuffer(Id id, tensorflow::gtl::ArraySlice<const HloValue*> values)
+ HloBuffer(Id id, absl::Span<const HloValue* const> values)
: id_(id), values_(values.begin(), values.end()) {}
// Return the unique identifier for this HloBuffer.
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 441288da1a..233d2199d1 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -23,9 +23,13 @@ limitations under the License.
#include <set>
#include <sstream>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -36,13 +40,11 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
-using ::tensorflow::strings::StrCat;
+using absl::StrCat;
std::unique_ptr<HloComputation> HloComputation::Builder::Build(
HloInstruction* root_instruction) {
@@ -56,8 +58,8 @@ std::unique_ptr<HloComputation> HloComputation::Builder::Build(
HloInstruction* root =
root_instruction ? root_instruction : last_added_instruction_;
CHECK_NE(nullptr, root);
- return WrapUnique(new HloComputation(name_, parameter_count, &instructions_,
- root, fusion_instruction_));
+ return absl::WrapUnique(new HloComputation(
+ name_, parameter_count, &instructions_, root, fusion_instruction_));
}
HloComputation::HloComputation(
@@ -135,7 +137,7 @@ string RenameFusionParameter(const string& original_name, int64 new_param_no) {
}
string after_param = original_name.substr(index + param_underscore.size());
int64 numeric_suffix;
- if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) {
+ if (absl::SimpleAtoi(after_param, &numeric_suffix)) {
return StrCat(original_name.substr(0, index + param_underscore.size()),
new_param_no);
}
@@ -317,11 +319,12 @@ void ComputeComputationPostOrder(
}
}
-enum State { kVisiting, kVisited };
+} // namespace
-void ComputeInstructionPostOrder(
+void HloComputation::ComputeInstructionPostOrder(
+ const HloComputation::ChannelDependencyMap& channel_dependency_map,
std::vector<HloInstruction*>* post_order, HloInstruction* root,
- tensorflow::gtl::FlatMap<HloInstruction*, State>* visited) {
+ tensorflow::gtl::FlatMap<HloInstruction*, VisitState>* visited) const {
std::vector<HloInstruction*> dfs_stack;
dfs_stack.push_back(root);
while (!dfs_stack.empty()) {
@@ -354,16 +357,71 @@ void ComputeInstructionPostOrder(
for (HloInstruction* op : current->control_predecessors()) {
dfs_stack.emplace_back(op);
}
+
+ // Add inputs for send->recv_done dependencies and cross-replica-sum
+ // dependencies.
+ switch (current->opcode()) {
+ case HloOpcode::kRecvDone: {
+ auto it = channel_dependency_map.find(current->channel_id());
+ if (it != channel_dependency_map.end()) {
+ for (HloInstruction* op : it->second) {
+ dfs_stack.emplace_back(op);
+ }
+ }
+ break;
+ }
+ case HloOpcode::kCrossReplicaSum: {
+ auto all_reduce_id = current->all_reduce_id();
+ if (all_reduce_id) {
+ auto it = channel_dependency_map.find(all_reduce_id.value());
+ if (it != channel_dependency_map.end()) {
+ for (HloInstruction* op : it->second) {
+ dfs_stack.emplace_back(op);
+ }
+ }
+ }
+ break;
+ }
+ default:
+ break;
+ }
}
}
-} // namespace
+HloComputation::ChannelDependencyMap
+HloComputation::ComputeChannelDependencies() const {
+ ChannelDependencyMap channel_dependency_map;
+ for (const auto& instruction : instructions_) {
+ switch (instruction->opcode()) {
+ case HloOpcode::kSend: {
+ channel_dependency_map[instruction->channel_id()].push_back(
+ instruction.get());
+ break;
+ }
+ case HloOpcode::kCrossReplicaSum: {
+ auto all_reduce_id = instruction->all_reduce_id();
+ if (all_reduce_id) {
+ auto& dependencies = channel_dependency_map[all_reduce_id.value()];
+ absl::c_copy(instruction->operands(),
+ std::back_inserter(dependencies));
+ absl::c_copy(instruction->control_predecessors(),
+ std::back_inserter(dependencies));
+ }
+ break;
+ }
+ default:
+ break;
+ }
+ }
+ return channel_dependency_map;
+}
std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
+ auto channel_dependency_map = ComputeChannelDependencies();
std::vector<HloInstruction*> post_order;
post_order.reserve(instruction_count());
std::vector<HloInstruction*> trace_instructions;
- tensorflow::gtl::FlatMap<HloInstruction*, State> visited;
+ tensorflow::gtl::FlatMap<HloInstruction*, VisitState> visited;
for (auto& instruction : instructions_) {
if (instruction->opcode() == HloOpcode::kTrace) {
// Trace instructions aren't handled by the DFS visitor. Add trace
@@ -371,7 +429,8 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
// users).
trace_instructions.push_back(instruction.get());
} else if (instruction->users().empty()) {
- ComputeInstructionPostOrder(&post_order, instruction.get(), &visited);
+ ComputeInstructionPostOrder(channel_dependency_map, &post_order,
+ instruction.get(), &visited);
}
}
post_order.insert(post_order.end(), trace_instructions.begin(),
@@ -405,6 +464,14 @@ std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
}
string HloComputation::ToString(const HloPrintOptions& options) const {
+ return ToString(options, MakeInstructionPostOrder());
+}
+
+string HloComputation::ToString(
+ const HloPrintOptions& options,
+ absl::Span<const HloInstruction* const> instruction_order) const {
+ CHECK_EQ(instruction_order.size(), instruction_count());
+
std::ostringstream s;
for (int i = 0; i < options.indent_amount(); i++) {
s << " ";
@@ -427,7 +494,9 @@ string HloComputation::ToString(const HloPrintOptions& options) const {
new_options.set_indent_amount(options.indent_amount() + 1)
.set_is_in_nested_computation(true);
CanonicalNameMap name_map;
- for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
+ for (const HloInstruction* instruction : instruction_order) {
+ CHECK_EQ(this, instruction->parent());
+
for (int i = 0; i < new_options.indent_amount(); i++) {
s << " ";
}
@@ -493,13 +562,13 @@ HloComputation::CreateFromProto(
return to_proto_id[a.get()] < to_proto_id[b.get()];
});
- return WrapUnique(new HloComputation(proto.name(), parameter_count,
- &instructions, root,
- /*fusion_instruction=*/nullptr));
+ return absl::WrapUnique(new HloComputation(proto.name(), parameter_count,
+ &instructions, root,
+ /*fusion_instruction=*/nullptr));
}
void HloComputation::FuseInstructionsInto(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
+ absl::Span<HloInstruction* const> instructions_to_fuse,
HloInstruction* fusion_instruction) {
CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
HloInstruction* root = instructions_to_fuse.front();
@@ -518,7 +587,7 @@ void HloComputation::FuseInstructionsInto(
}
HloInstruction* HloComputation::CreateFusionInstruction(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
+ absl::Span<HloInstruction* const> instructions_to_fuse,
HloInstruction::FusionKind fusion_kind) {
HloInstruction* root = instructions_to_fuse.front();
HloInstruction* fusion_instruction = AddInstruction(
@@ -566,16 +635,15 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
if (instruction->parent() != this) {
return FailedPrecondition(
"Can't deep copy instruction %s: instruction is not in computation %s",
- instruction->name().c_str(), name().c_str());
+ instruction->name(), name());
}
if (indices_to_copy != nullptr &&
!ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) {
return FailedPrecondition(
"Can't deep copy instruction %s: given shape tree of indices to copy "
"has incompatible shapes: %s vs. %s",
- instruction->name().c_str(),
- ShapeUtil::HumanString(instruction->shape()).c_str(),
- ShapeUtil::HumanString(indices_to_copy->shape()).c_str());
+ instruction->name(), ShapeUtil::HumanString(instruction->shape()),
+ ShapeUtil::HumanString(indices_to_copy->shape()));
}
ShapeIndex index;
@@ -605,7 +673,7 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyInstructionWithCustomCopier(
if (instruction->parent() != this) {
return FailedPrecondition(
"Can't deep copy instruction %s: instruction is not in computation %s",
- instruction->name().c_str(), name().c_str());
+ instruction->name(), name());
}
ShapeIndex index;
return DeepCopyHelper(instruction, &index, copy_leaf);
@@ -624,6 +692,9 @@ ProgramShape HloComputation::ComputeProgramShape() const {
}
bool HloComputation::operator==(const HloComputation& other) const {
+ if (this == &other) {
+ return true;
+ }
std::set<std::pair<const HloInstruction*, const HloInstruction*>> visited;
std::function<bool(const HloInstruction*, const HloInstruction*)> eq =
[&visited, &eq](const HloInstruction* a, const HloInstruction* b) {
@@ -674,13 +745,37 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability()
const {
const auto& all = MakeInstructionPostOrder();
- auto result = MakeUnique<HloReachabilityMap>(all);
+ auto result = absl::make_unique<HloReachabilityMap>(all);
+ auto channel_dependency_map = ComputeChannelDependencies();
std::vector<HloInstruction*> inputs;
for (const HloInstruction* hlo : all) {
inputs.assign(hlo->operands().begin(), hlo->operands().end());
inputs.insert(inputs.end(), hlo->control_predecessors().begin(),
hlo->control_predecessors().end());
+
+ switch (hlo->opcode()) {
+ case HloOpcode::kRecvDone: {
+ auto it = channel_dependency_map.find(hlo->channel_id());
+ if (it != channel_dependency_map.end()) {
+ absl::c_copy(it->second, std::back_inserter(inputs));
+ }
+ break;
+ }
+ case HloOpcode::kCrossReplicaSum: {
+ auto all_reduce_id = hlo->all_reduce_id();
+ if (all_reduce_id) {
+ auto it = channel_dependency_map.find(all_reduce_id.value());
+ if (it != channel_dependency_map.end()) {
+ absl::c_copy(it->second, std::back_inserter(inputs));
+ }
+ }
+ break;
+ }
+ default:
+ break;
+ }
+
result->FastSetReachabilityToUnion(inputs, hlo);
}
return result;
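
Note: the new branches above union channel-linked producers into a consumer's
input set before reachability is computed. A minimal standalone sketch of that
union pattern, using std::unordered_map and std::vector in place of XLA's
FlatMap and InlinedVector (all names here are illustrative, not XLA API):

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <unordered_map>
#include <vector>

struct Instr;  // stand-in for HloInstruction

using ChannelDeps = std::unordered_map<int64_t, std::vector<Instr*>>;

// Append the channel-linked producers (e.g. the matching send for a
// recv-done) to the ordinary operand/control-predecessor inputs.
void AddChannelInputs(const ChannelDeps& deps, int64_t channel_id,
                      std::vector<Instr*>* inputs) {
  auto it = deps.find(channel_id);
  if (it != deps.end()) {
    std::copy(it->second.begin(), it->second.end(),
              std::back_inserter(*inputs));
  }
}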
@@ -723,11 +818,10 @@ std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const {
}
}
VLOG(3) << "Unreachable roots:"
- << tensorflow::str_util::Join(
- unreachable_roots, "\n\t",
- [](string* out, const HloInstruction* hlo) {
- tensorflow::strings::StrAppend(out, hlo->ToString());
- });
+ << absl::StrJoin(unreachable_roots, "\n\t",
+ [](string* out, const HloInstruction* hlo) {
+ absl::StrAppend(out, hlo->ToString());
+ });
return unreachable_roots;
}
@@ -829,7 +923,7 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
HloCloneContext* context, const string& suffix) {
std::unique_ptr<HloCloneContext> context_ptr;
if (context == nullptr) {
- context_ptr = MakeUnique<HloCloneContext>(parent(), suffix);
+ context_ptr = absl::make_unique<HloCloneContext>(parent(), suffix);
context = context_ptr.get();
}
@@ -898,12 +992,11 @@ void HloComputation::UniquifyName(NameUniquer* name_uniquer) {
name_ = name_uniquer->GetUniqueName(name_);
}
-HloInstruction* HloComputation::GetInstructionWithName(
- tensorflow::StringPiece name) {
+HloInstruction* HloComputation::GetInstructionWithName(absl::string_view name) {
auto instructions_in_computation = instructions();
- auto it = c_find_if(instructions_in_computation, [&](HloInstruction* instr) {
- return instr->name() == name;
- });
+ auto it = absl::c_find_if(
+ instructions_in_computation,
+ [&](HloInstruction* instr) { return instr->name() == name; });
return it == instructions_in_computation.end() ? nullptr : *it;
}
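
Note: the ArraySlice-to-Span renames in this file are mechanical:
tensorflow::gtl::ArraySlice<T> becomes absl::Span<const T>, and
ArraySlice<T*> becomes absl::Span<T* const>, a constant view over mutable
pointers. A minimal sketch of how such a parameter behaves at call sites
(illustrative code, not from the commit):

#include <cstdint>
#include <vector>
#include "absl/types/span.h"

// A span parameter accepts vectors, arrays, and braced lists without copying.
int64_t Sum(absl::Span<const int64_t> xs) {
  int64_t total = 0;
  for (int64_t x : xs) total += x;
  return total;
}

void Demo() {
  std::vector<int64_t> v = {1, 2, 3};
  Sum(v);          // view over a vector
  Sum({4, 5, 6});  // view over an initializer list
}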
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 49ed65910f..91c5234a6f 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -25,6 +25,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
@@ -39,7 +40,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
@@ -170,6 +170,11 @@ class HloComputation {
string ToString() const { return ToString(HloPrintOptions()); }
string ToString(const HloPrintOptions& options) const;
+ // Overload which accepts an order to emit the instructions in.
+ string ToString(
+ const HloPrintOptions& options,
+ absl::Span<const HloInstruction* const> instruction_order) const;
+
// Returns a serialized representation of this computation.
HloComputationProto ToProto() const;
@@ -237,7 +242,7 @@ class HloComputation {
// removed if they have no uses after fusion (this is necessarily true for at
// least the root).
HloInstruction* CreateFusionInstruction(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
+ absl::Span<HloInstruction* const> instructions_to_fuse,
HloInstruction::FusionKind fusion_kind);
// Create a deep copy of the given instruction and return the instruction
@@ -367,7 +372,7 @@ class HloComputation {
// Returns the instruction in this computation that has name `name`. Returns
  // null if there is no such instruction.
- HloInstruction* GetInstructionWithName(tensorflow::StringPiece name);
+ HloInstruction* GetInstructionWithName(absl::string_view name);
int64 unique_id() const { return unique_id_; }
@@ -385,7 +390,7 @@ class HloComputation {
//
// Pre-condition: fusion_instruction's opcode is kFusion.
void FuseInstructionsInto(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
+ absl::Span<HloInstruction* const> instructions_to_fuse,
HloInstruction* fusion_instruction);
// Internal helper for recursive copying of an instruction. Creates and
@@ -399,6 +404,20 @@ class HloComputation {
// Internal helper to collect unreachable roots.
std::vector<HloInstruction*> CollectUnreachableRoots() const;
+  // Returns a map from channel id to the directed dependencies of that
+  // channel's instructions: for a send/recv pair this is the send
+  // instruction, and for a cross-replica-sum it is the union of the
+  // dependencies of all participating instructions.
+ using ChannelDependencyMap =
+ tensorflow::gtl::FlatMap<int64, absl::InlinedVector<HloInstruction*, 1>>;
+ ChannelDependencyMap ComputeChannelDependencies() const;
+
+ enum VisitState { kVisiting, kVisited };
+ void ComputeInstructionPostOrder(
+ const HloComputation::ChannelDependencyMap& channel_dependency_map,
+ std::vector<HloInstruction*>* post_order, HloInstruction* root,
+ tensorflow::gtl::FlatMap<HloInstruction*, VisitState>* visited) const;
+
string name_;
int64 unique_id_;
HloInstruction* root_instruction_;
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc
index e4c5470331..2aaaef1d36 100644
--- a/tensorflow/compiler/xla/service/hlo_computation_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc
@@ -601,8 +601,11 @@ TEST_F(HloComputationTest, Stringification) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfig precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfig::DEFAULT);
builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
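
Note: this and the two following test hunks thread a PrecisionConfig through
HloInstruction::CreateDot, which now requires it. The recurring setup,
annotated (Resize is the standard proto2 RepeatedField call; two entries cover
the dot's lhs and rhs operands):

PrecisionConfig precision_config;
// One operand_precision entry per dot operand, both left at DEFAULT.
precision_config.mutable_operand_precision()->Resize(
    2, PrecisionConfig::DEFAULT);
builder.AddInstruction(
    HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config));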
@@ -633,8 +636,11 @@ TEST_F(HloComputationTest, StringificationIndent) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfig precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfig::DEFAULT);
builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
@@ -666,8 +672,11 @@ TEST_F(HloComputationTest, StringificationCanonical) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfig precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfig::DEFAULT);
builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
@@ -691,6 +700,27 @@ TEST_F(HloComputationTest, StringificationCanonical) {
EXPECT_EQ(computation->ToString(options), expected_computation2);
}
-} // namespace
+TEST_F(HloComputationTest, ChannelReachability) {
+ const Shape shape = ShapeUtil::MakeShape(F32, {5, 7});
+ HloComputation::Builder builder("ChannelReachability");
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param"));
+ auto token0 = builder.AddInstruction(HloInstruction::CreateToken());
+ auto send =
+ builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1));
+ auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
+ auto token1 = builder.AddInstruction(HloInstruction::CreateToken());
+ auto recv =
+ builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1));
+ auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build(recv_done));
+ auto reachability = computation->ComputeReachability();
+ EXPECT_TRUE(reachability->IsReachable(param, recv_done));
+ EXPECT_FALSE(reachability->IsReachable(send, recv));
+ EXPECT_FALSE(reachability->IsReachable(send_done, recv));
+}
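
Note: the expectations follow from the edge that ComputeChannelDependencies
injects for channel 1, which targets recv-done rather than recv. The effective
dependency edges (a sketch, not code from the commit):

  param  -> send -> send_done
  token1 -> recv -> recv_done
  send   -> recv_done            (injected channel-1 edge)

Hence param reaches recv_done through send, while neither send nor send_done
reaches recv: the injected edge bypasses recv itself.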
+
+} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
index 7229031c0c..f837816cea 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@@ -38,7 +39,7 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
// Limit the constant folding to 0 iterations to skip folding loops. This
// retains the behavior from before while loop support in HloEvaluator and may
// be revised.
- auto evaluator = MakeUnique<HloEvaluator>(/*max_loop_iterations=*/0);
+ auto evaluator = absl::make_unique<HloEvaluator>(/*max_loop_iterations=*/0);
XLA_VLOG_LINES(2,
"HloConstantFolding::Run(), before:\n" + module->ToString());
@@ -51,9 +52,7 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
computation->root_instruction() != instruction) {
continue;
}
- // Skip Constant, Parameter, Reduce, and AfterAll operation.
- // TODO(b/35975797): Enable Reduce operation once arbitrary computation
- // are supported by the evaluator.
+  // Skip Constant, Parameter, Tuple, and AfterAll operations.
// TODO(b/64407269): Enable Tuple once the timeout issue is resolved.
// TODO(b/110532604): Enable AfterAll once AfterAll requires at least one
// operand in which case constant folding will be impossible and this
@@ -61,7 +60,6 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
if (instruction->opcode() == HloOpcode::kParameter ||
instruction->opcode() == HloOpcode::kConstant ||
instruction->opcode() == HloOpcode::kTuple ||
- instruction->opcode() == HloOpcode::kReduce ||
instruction->opcode() == HloOpcode::kAfterAll) {
continue;
}
@@ -73,14 +71,15 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
    // Broadcasts dramatically increase the size of constants, which is often
    // detrimental to performance and memory capacity, so do not fold
    // broadcasts; iota is skipped for the same reason.
- if (instruction->opcode() == HloOpcode::kBroadcast) {
+ if (instruction->opcode() == HloOpcode::kBroadcast ||
+ instruction->opcode() == HloOpcode::kIota) {
continue;
}
- std::unique_ptr<Literal> result = evaluator->TryEvaluate(instruction);
+ Literal result;
// Currently we skip unimplemented operations.
// TODO(b/35975797): Fold constant computations for more operations.
- if (result == nullptr) {
+ if (!evaluator->TryEvaluate(instruction, &result)) {
VLOG(2) << "Constant folding failed for instruction: "
<< instruction->ToString();
continue;
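
Note: TryEvaluate's signature changed from returning std::unique_ptr<Literal>
(null on failure) to filling a caller-owned Literal and returning a success
flag. A toy sketch of the new calling convention, with a stand-in Literal type
(not the XLA class):

#include <string>

struct Literal { std::string data; };  // stand-in for xla::Literal

// New style: success flag plus out-parameter; no heap allocation needed.
bool TryEvaluate(int instruction_id, Literal* result) {
  if (instruction_id < 0) return false;  // unimplemented / unsupported op
  result->data = "folded";
  return true;
}

void Caller() {
  Literal result;
  if (!TryEvaluate(42, &result)) {
    // skip folding this instruction
  }
}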
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.h b/tensorflow/compiler/xla/service/hlo_constant_folding.h
index 331480bd02..4557983a9c 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.h
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.h
@@ -25,7 +25,7 @@ namespace xla {
// computation on constants.
class HloConstantFolding : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "constant_folding"; }
+ absl::string_view name() const override { return "constant_folding"; }
// Run constant folding operations on the given module. Returns whether the
// module was changed (constant expressions folded).
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index 64a42c1efc..4da42844bd 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@@ -104,8 +105,8 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
TEST_F(HloConstantFoldingTest, Concatenate) {
const struct TestConfig {
int concat_dimension;
- tensorflow::gtl::ArraySlice<int64> dimensions;
- tensorflow::gtl::ArraySlice<int64> concat_sizes;
+ absl::Span<const int64> dimensions;
+ absl::Span<const int64> concat_sizes;
} test_configs[] = {
{1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}},
{3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}},
@@ -174,7 +175,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
TF_ASSERT_OK_AND_ASSIGN(auto literal,
LiteralUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
- auto literal_clone = literal->Literal::CloneToUnique();
+ auto literal_clone = literal.Clone();
HloInstruction* literal_instruction = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
@@ -195,12 +196,52 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
bool matched = true;
root->literal().EachCell<NativeT>(
- [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT value) {
+ [&](absl::Span<const int64> indices, NativeT value) {
std::vector<int64> rindexes = Permute(permutation, indices);
- matched = matched && (value == literal_clone->Get<NativeT>(rindexes));
+ matched = matched && (value == literal_clone.Get<NativeT>(rindexes));
});
EXPECT_TRUE(matched);
}
+const char* const kConstantFoldReduce = R"(
+ HloModule ConstantFoldReduce
+
+ add {
+ a = s32[] parameter(0)
+ b = s32[] parameter(1)
+ ROOT add = s32[] add(a, b)
+ }
+
+ ENTRY r {
+ x = s32[3] constant({1, 2, 3})
+ init = s32[] constant(0)
+ ROOT reduce = s32[] reduce(x, init), dimensions={0}, to_apply=add
+ })";
+
+TEST_F(HloConstantFoldingTest, ConstantFoldReduce) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(kConstantFoldReduce));
+ HloConstantFolding const_folder;
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ EXPECT_TRUE(result);
+
+ EXPECT_EQ(6, module->entry_computation()
+ ->root_instruction()
+ ->literal()
+ .GetFirstElement<int32>());
+}
+
+TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(kConstantFoldReduce));
+ HloInstruction* add = module->computations().begin()->root_instruction();
+ LayoutUtil::ClearLayout(add->mutable_shape());
+ HloConstantFolding const_folder;
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ EXPECT_FALSE(result);
+
+ EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce());
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 1bbb0ff08e..a502fff9a0 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -227,6 +227,14 @@ Status HloCostAnalysis::HandleCopy(const HloInstruction*) {
return Status::OK();
}
+Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) {
+ // Domain does not have any computation or data transfer.
+ current_should_compute_bottleneck_time_ = false;
+ current_properties_[kBytesAccessedKey] = 0;
+ current_properties_[kOptimalSecondsKey] = 0;
+ return Status::OK();
+}
+
Status HloCostAnalysis::HandleDot(const HloInstruction* dot) {
const Shape& lhs_shape = dot->operand(0)->shape();
const Shape& rhs_shape = dot->operand(1)->shape();
@@ -258,10 +266,6 @@ Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) {
return Status::OK();
}
-Status HloCostAnalysis::HandleHostCompute(const HloInstruction*) {
- return Status::OK();
-}
-
Status HloCostAnalysis::HandleMap(const HloInstruction* map) {
// Compute properties of the mapped function.
TF_ASSIGN_OR_RETURN(const Properties sub_properties,
@@ -278,15 +282,21 @@ Status HloCostAnalysis::HandleMap(const HloInstruction* map) {
}
Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) {
- auto arg = reduce->operand(0);
HloComputation* function = reduce->to_apply();
// Compute the cost of the user function.
TF_ASSIGN_OR_RETURN(const Properties sub_properties,
ProcessSubcomputation(function));
// Compute the cost of all elements for this Reduce operation.
- int64 reduction_count = ShapeUtil::ElementsIn(arg->shape()) -
- ShapeUtil::ElementsIn(reduce->shape());
+ // This counts the number of times the reduction function is applied, so it
+ // does not need to be multiplied by the number of input tensors - that's
+ // already "priced in" by the sub-computation doing more work.
+ auto arg = reduce->operand(0);
+ auto output_shape = ShapeUtil::IsArray(reduce->shape())
+ ? reduce->shape()
+ : reduce->shape().tuple_shapes(0);
+ int64 reduction_count =
+ ShapeUtil::ElementsIn(arg->shape()) - ShapeUtil::ElementsIn(output_shape);
for (const auto& property : sub_properties) {
if (property.first != kBytesAccessedKey) {
current_properties_[property.first] = property.second * reduction_count;
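
Note: reduction_count is the number of times the reducer is applied, i.e.
total input elements minus output elements, with the output shape taken from
the first tuple element for multi-output reduces. For example, reducing
f32[4,5] along dimension 0 to f32[5] applies the reducer 20 - 5 = 15 times. A
tiny self-contained check of that arithmetic:

#include <cassert>
#include <cstdint>
#include <initializer_list>

int64_t ElementsIn(std::initializer_list<int64_t> dims) {
  int64_t n = 1;
  for (int64_t d : dims) n *= d;
  return n;
}

int main() {
  // f32[4,5] reduced over dim 0 -> f32[5]: the reducer runs 15 times.
  int64_t reduction_count = ElementsIn({4, 5}) - ElementsIn({5});
  assert(reduction_count == 15);
}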
@@ -505,8 +515,9 @@ Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
valid_position_counts.push_back(valid_position_count);
}
- const int64 fma_count =
- input_feature * output_feature * batch * Product(valid_position_counts);
+ const int64 fma_count = (input_feature / convolution->feature_group_count()) *
+ output_feature * batch *
+ Product(valid_position_counts);
current_properties_[kFlopsKey] = fma_count * kFmaFlops;
return Status::OK();
}
@@ -544,15 +555,10 @@ Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) {
}
Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) {
- // TODO(b/110096724): Compute correct cost here.
- double flops = 0.0;
- ShapeUtil::ForEachSubshape(hlo->shape(),
- [&](const Shape& subshape, const ShapeIndex&) {
- if (ShapeUtil::IsArray(subshape)) {
- flops += ShapeUtil::ElementsIn(subshape);
- }
- });
- current_properties_[kFlopsKey] = flops;
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) {
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index 193a04bea0..46b4bbeef2 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -23,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -67,14 +67,15 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleRecvDone(const HloInstruction* recv_done) override;
Status HandleConvert(const HloInstruction* convert) override;
Status HandleCopy(const HloInstruction* copy) override;
+ Status HandleDomain(const HloInstruction* domain) override;
Status HandleDot(const HloInstruction* dot) override;
Status HandleConvolution(const HloInstruction* convolution) override;
Status HandleFft(const HloInstruction* fft) override;
Status HandleCrossReplicaSum(const HloInstruction* crs) override;
Status HandleAllToAll(const HloInstruction* hlo) override;
+ Status HandleCollectivePermute(const HloInstruction* hlo) override;
Status HandleInfeed(const HloInstruction* infeed) override;
Status HandleOutfeed(const HloInstruction* outfeed) override;
- Status HandleHostCompute(const HloInstruction* host_compute) override;
Status HandleRng(const HloInstruction* random) override;
Status HandleReverse(const HloInstruction* reverse) override;
Status HandleSort(const HloInstruction* sort) override;
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index 2c854eea18..d76ce9ecbc 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -203,6 +203,35 @@ TEST_F(HloCostAnalysisTest, Convolution) {
sizeof(float) * (10 * 20 + 3 * 3 + 8 * 18));
}
+TEST_F(HloCostAnalysisTest, ConvolutionWithFeatureGroup) {
+ XlaBuilder builder("convolution");
+ auto input = Parameter(
+ &builder, 0,
+ ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/120, /*y_dim=*/10,
+ /*x_dim=*/20}),
+ "input");
+ auto kernel = Parameter(
+ &builder, 1,
+ ShapeUtil::MakeShape(F32, {/*p_dim=*/120, /*z_dim=*/1, /*y_dim=*/3,
+ /*x_dim=*/3}),
+ "kernel");
+ Conv(input, kernel, {1, 1}, Padding::kValid, /*feature_group_count=*/120);
+
+ // Run HLO cost analysis.
+ auto hlo_module = BuildHloGraph(&builder);
+ HloCostAnalysis analysis(ShapeSize);
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+  // The output shape is [1x120x8x18]; each output element requires 3x3
+  // FMAs, and each FMA counts as 2 flops.
+ EXPECT_EQ(analysis.flop_count(), 120 * 8 * 18 * 2 * 3 * 3);
+
+ // Bytes accessed is sum of inputs and output.
+ EXPECT_EQ(analysis.bytes_accessed(),
+ sizeof(float) * (120 * 10 * 20 + 120 * 3 * 3 + 120 * 8 * 18));
+}
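
Note: the expected flop count follows from the grouped fma_count formula
above. Assuming valid position counts of 8*3 and 18*3 per spatial dimension
(output positions times kernel taps under valid padding), fma_count =
(120/120) * 120 * 1 * (8*3) * (18*3) = 155520, and flops = 2 * fma_count =
311040, matching 120 * 8 * 18 * 2 * 3 * 3. A tiny check:

#include <cassert>

int main() {
  const long long fma = (120 / 120) * 120LL * 1 * (8 * 3) * (18 * 3);
  assert(2 * fma == 120LL * 8 * 18 * 2 * 3 * 3);  // both equal 311040
}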
+
TEST_F(HloCostAnalysisTest, Reduce) {
XlaBuilder builder("reduce");
auto input =
@@ -415,7 +444,7 @@ TEST_F(FusionCostAnalysis, NoLayout) {
TEST_F(HloCostAnalysisTest, TupleCost) {
HloCostAnalysis analysis(ShapeSize);
{
- XlaBuilder builder("matmul");
+ XlaBuilder builder("tuple");
auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {123}), "x");
auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {42}), "y");
Tuple(&builder, {x, y});
@@ -430,6 +459,30 @@ TEST_F(HloCostAnalysisTest, TupleCost) {
EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2);
}
+using DomainCostAnalysis = HloTestBase;
+TEST_F(DomainCostAnalysis, DomainCost) {
+ HloCostAnalysis analysis(ShapeSize);
+
+ HloComputation::Builder builder("domain");
+ auto x = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {123}), "x"));
+ auto y = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {42}), "y"));
+ auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({x, y}));
+ auto domain = builder.AddInstruction(
+ HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr));
+
+ auto hlo_module = CreateNewModule();
+ hlo_module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain);
+ ASSERT_IS_OK(domain->Accept(&analysis));
+
+ EXPECT_EQ(analysis.flop_count(*domain), 0);
+ EXPECT_EQ(analysis.transcendental_count(*domain), 0);
+ EXPECT_EQ(analysis.bytes_accessed(*domain), 0);
+}
+
TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) {
XlaBuilder builder("BaseDilatedConvolution");
auto input = Parameter(
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index 858992a326..b76c50bb5b 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -14,15 +14,16 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
-using tensorflow::gtl::ArraySlice;
-using tensorflow::strings::StrCat;
+using absl::StrCat;
StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
HloInstruction* rhs) {
@@ -48,9 +49,9 @@ StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
}
StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
- ArraySlice<int64> start_indices,
- ArraySlice<int64> limit_indices,
- ArraySlice<int64> strides) {
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides) {
HloComputation* computation = operand->parent();
TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape(
operand->shape(), start_indices,
@@ -60,19 +61,22 @@ StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
}
StatusOr<HloInstruction*> MakeConvolveHlo(
- HloInstruction* lhs, HloInstruction* rhs, const Window& window,
- const ConvolutionDimensionNumbers& dimension_numbers) {
+ HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
+ const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
+ const PrecisionConfig& precision_config) {
HloComputation* computation = lhs->parent();
CHECK_EQ(computation, rhs->parent());
- TF_ASSIGN_OR_RETURN(Shape convolve_shape, ShapeInference::InferConvolveShape(
- lhs->shape(), rhs->shape(),
- window, dimension_numbers));
+ TF_ASSIGN_OR_RETURN(Shape convolve_shape,
+ ShapeInference::InferConvolveShape(
+ lhs->shape(), rhs->shape(), feature_group_count,
+ window, dimension_numbers));
return computation->AddInstruction(HloInstruction::CreateConvolve(
- convolve_shape, lhs, rhs, window, dimension_numbers));
+ convolve_shape, lhs, rhs, feature_group_count, window, dimension_numbers,
+ precision_config));
}
StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
- ArraySlice<int64> dimensions) {
+ absl::Span<const int64> dimensions) {
HloComputation* computation = operand->parent();
TF_ASSIGN_OR_RETURN(
Shape transpose_shape,
@@ -89,15 +93,15 @@ StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
}
StatusOr<HloInstruction*> MakeReshapeHlo(
- ArraySlice<int64> result_shape_dim_bounds, HloInstruction* operand) {
+ absl::Span<const int64> result_shape_dim_bounds, HloInstruction* operand) {
Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
result_shape_dim_bounds);
return MakeReshapeHlo(new_shape, operand);
}
-StatusOr<HloInstruction*> MakeDynamicSliceHlo(HloInstruction* operand,
- HloInstruction* start_indices,
- ArraySlice<int64> slice_sizes) {
+StatusOr<HloInstruction*> MakeDynamicSliceHlo(
+ HloInstruction* operand, HloInstruction* start_indices,
+ absl::Span<const int64> slice_sizes) {
HloComputation* computation = operand->parent();
CHECK_EQ(computation, start_indices->parent());
TF_ASSIGN_OR_RETURN(
@@ -123,8 +127,8 @@ StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
}
StatusOr<HloInstruction*> MakeBroadcastHlo(
- HloInstruction* operand, ArraySlice<int64> broadcast_dimensions,
- ArraySlice<int64> result_shape_bounds) {
+ HloInstruction* operand, absl::Span<const int64> broadcast_dimensions,
+ absl::Span<const int64> result_shape_bounds) {
HloComputation* computation = operand->parent();
Shape broadcast_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
result_shape_bounds);
@@ -144,18 +148,18 @@ StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
HloInstruction::CreateGetTupleElement(gte_shape, operand, index));
}
-StatusOr<HloInstruction*> MakeConcatHlo(ArraySlice<HloInstruction*> operands,
- int64 dimension) {
+StatusOr<HloInstruction*> MakeConcatHlo(
+ absl::Span<HloInstruction* const> operands, int64 dimension) {
CHECK_GT(operands.size(), 0);
HloComputation* computation = operands[0]->parent();
- CHECK(c_all_of(operands, [&](HloInstruction* instr) {
+ CHECK(absl::c_all_of(operands, [&](HloInstruction* instr) {
return instr->parent() == computation;
}));
std::vector<const Shape*> operand_shapes;
- c_transform(operands, std::back_inserter(operand_shapes),
- [](HloInstruction* instr) { return &instr->shape(); });
+ absl::c_transform(operands, std::back_inserter(operand_shapes),
+ [](HloInstruction* instr) { return &instr->shape(); });
TF_ASSIGN_OR_RETURN(Shape concat_shape, ShapeInference::InferConcatOpShape(
operand_shapes, dimension));
@@ -164,19 +168,19 @@ StatusOr<HloInstruction*> MakeConcatHlo(ArraySlice<HloInstruction*> operands,
}
StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
- const DotDimensionNumbers& dim_numbers) {
+ const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfig& precision_config) {
HloComputation* computation = lhs->parent();
CHECK_EQ(computation, rhs->parent());
TF_ASSIGN_OR_RETURN(
Shape dot_shape,
ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers));
- return computation->AddInstruction(
- HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers));
+ return computation->AddInstruction(HloInstruction::CreateDot(
+ dot_shape, lhs, rhs, dim_numbers, precision_config));
}
-StatusOr<HloInstruction*> MakeMapHlo(
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* map_computation) {
+StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
+ HloComputation* map_computation) {
CHECK(!operands.empty()) << "Map Hlo requires at least one operand.";
HloComputation* computation = operands.front()->parent();
std::vector<const Shape*> operand_shapes;
@@ -228,19 +232,19 @@ StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
const Shape& operand_shape = operand->shape();
new_shape_dims.reserve(n + operand_shape.dimensions_size());
new_shape_dims.insert(new_shape_dims.begin(), n, 1);
- c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims));
+ absl::c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims));
return MakeReshapeHlo(new_shape_dims, operand);
}
StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
- HloInstruction* operand, ArraySlice<int64> expanded_dims) {
+ HloInstruction* operand, absl::Span<const int64> expanded_dims) {
CHECK_GT(operand->shape().dimensions_size(), 0);
CHECK_EQ(operand->shape().dimensions(0), Product(expanded_dims));
std::vector<int64> expanded_shape_dim_bounds;
expanded_shape_dim_bounds.reserve(expanded_dims.size() +
operand->shape().dimensions_size() - 1);
- c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds));
+ absl::c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds));
std::copy(operand->shape().dimensions().begin() + 1,
operand->shape().dimensions().end(),
std::back_inserter(expanded_shape_dim_bounds));
@@ -249,9 +253,9 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
return MakeReshapeHlo(new_shape, operand);
}
-StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand,
- ArraySlice<int64> dims_to_elide) {
- CHECK(c_is_sorted(dims_to_elide));
+StatusOr<HloInstruction*> ElideDegenerateDims(
+ HloInstruction* operand, absl::Span<const int64> dims_to_elide) {
+ CHECK(absl::c_is_sorted(dims_to_elide));
const Shape& input_shape = operand->shape();
// First accumulate in reverse
@@ -268,15 +272,15 @@ StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand,
}
}
- c_reverse(new_shape_dim_bounds);
+ absl::c_reverse(new_shape_dim_bounds);
Shape output_shape =
ShapeUtil::MakeShape(input_shape.element_type(), new_shape_dim_bounds);
return MakeReshapeHlo(output_shape, operand);
}
StatusOr<HloInstruction*> InsertDegenerateDims(
- HloInstruction* operand, ArraySlice<int64> dims_to_insert) {
- CHECK(c_is_sorted(dims_to_insert));
+ HloInstruction* operand, absl::Span<const int64> dims_to_insert) {
+ CHECK(absl::c_is_sorted(dims_to_insert));
const Shape& operand_shape = operand->shape();
int64 output_shape_rank =
@@ -317,26 +321,25 @@ StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
padding_config_dim.set_edge_padding_high(zeros_to_append);
*padding_config.add_dimensions() = padding_config_dim;
- HloInstruction* zero = computation->AddInstruction(
- HloInstruction::CreateConstant(MakeUnique<Literal>(
- LiteralUtil::Zero(operand->shape().element_type()))));
+ HloInstruction* zero =
+ computation->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(operand->shape().element_type())));
return MakePadHlo(operand, zero, padding_config);
}
StatusOr<HloInstruction*> BroadcastZeros(
HloComputation* computation, PrimitiveType element_type,
- ArraySlice<int64> broadcast_dimensions) {
- HloInstruction* zero =
- computation->AddInstruction(HloInstruction::CreateConstant(
- MakeUnique<Literal>(LiteralUtil::Zero(element_type))));
+ absl::Span<const int64> broadcast_dimensions) {
+ HloInstruction* zero = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
/*result_shape_bounds=*/broadcast_dimensions);
}
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
- ArraySlice<const Shape*> domain, const Shape& range,
- tensorflow::StringPiece name) {
- HloComputation::Builder b{std::string(name)};
+ absl::Span<const Shape* const> domain, const Shape& range,
+ absl::string_view name) {
+ HloComputation::Builder b{string(name)};
int64 param_idx = 0;
for (const Shape* param_shape : domain) {
b.AddInstruction(HloInstruction::CreateParameter(
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h
index 5ff8946fb0..b22058abb4 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.h
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h
@@ -40,21 +40,22 @@ StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
// Creates a slice HLO instruction and adds it to the computation containing
// `operand`.
-StatusOr<HloInstruction*> MakeSliceHlo(
- HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
+StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides);
// Creates a convolution HLO instruction and adds it to the computation
// containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
StatusOr<HloInstruction*> MakeConvolveHlo(
- HloInstruction* lhs, HloInstruction* rhs, const Window& window,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
+ const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
+ const PrecisionConfig& precision_config);
// Creates a transpose HLO instruction and adds it to the computation containing
// `operand`.
-StatusOr<HloInstruction*> MakeTransposeHlo(
- HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dimensions);
+StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
+ absl::Span<const int64> dimensions);
// Creates a reshape HLO instruction and adds it to the computation containing
// `operand`.
@@ -62,15 +63,14 @@ StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
HloInstruction* operand);
StatusOr<HloInstruction*> MakeReshapeHlo(
- tensorflow::gtl::ArraySlice<int64> result_shape_dim_bounds,
- HloInstruction* operand);
+ absl::Span<const int64> result_shape_dim_bounds, HloInstruction* operand);
// Creates a dynamic-slice HLO instruction and adds it to the computation
// containing `operand` and `start_indices` (`operand` and `start_indices` must
// be in the same computation).
StatusOr<HloInstruction*> MakeDynamicSliceHlo(
HloInstruction* operand, HloInstruction* start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
// Creates a dynamic-update-slice HLO instruction and adds it to the computation
// containing `operand`, `update` and `start_indices` (`operand`, `update` and
@@ -82,9 +82,8 @@ StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
// Creates a broadcast HLO instruction and adds it to the computation containing
// `operand`.
StatusOr<HloInstruction*> MakeBroadcastHlo(
- HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions,
- tensorflow::gtl::ArraySlice<int64> result_shape_bounds);
+ HloInstruction* operand, absl::Span<const int64> broadcast_dimensions,
+ absl::Span<const int64> result_shape_bounds);
// Creates a GetTupleElement HLO instruction and adds it to the computation
// containing `operand`.
@@ -95,18 +94,18 @@ StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
// containing `operands` (`operands` must be non-empty and every element must be
// contained in the same computation).
StatusOr<HloInstruction*> MakeConcatHlo(
- tensorflow::gtl::ArraySlice<HloInstruction*> operands, int64 dimension);
+ absl::Span<HloInstruction* const> operands, int64 dimension);
// Creates a Dot HLO instruction and adds it to the computation containing `lhs`
// and `rhs` (both must be in the same computation).
StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
- const DotDimensionNumbers& dim_numbers);
+ const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfig& precision_config);
// Creates a Map HLO instruction and adds it to the computation containing the
// operands. All operands must be in the same computation.
-StatusOr<HloInstruction*> MakeMapHlo(
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* map_computation);
+StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
+ HloComputation* map_computation);
// -----------------------------------------------------------------------------
// Some other miscellaneous helpers to generate common HLO patterns. All of
@@ -138,7 +137,7 @@ StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
// For instance if `operand` has shape f32[200,9,7] and expanded_dims is
// {2,5,20} the result is `operand` reshaped to [2,5,20,9,7].
StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
- HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> expanded_dims);
+ HloInstruction* operand, absl::Span<const int64> expanded_dims);
// Elides (via reshape) a set of degenerate dimensions (dimensions containing
// exactly one element), `dims_to_elide` from `operand`. Every dimension in
@@ -148,7 +147,7 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
// For example if `operand` is of shape f32[19,1,20,1,7,1,9] and dims_to_elide
// is {1,5} then the result is `operand` reshaped to [19,20,1,7,9].
StatusOr<HloInstruction*> ElideDegenerateDims(
- HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dims_to_elide);
+ HloInstruction* operand, absl::Span<const int64> dims_to_elide);
// Inserts (via reshape) a set of degenerate dimensions (dimensions containing
// exactly one element), `dims_to_insert` into `operand`. The dimensions in
@@ -158,7 +157,7 @@ StatusOr<HloInstruction*> ElideDegenerateDims(
// For example, if `operand` is of shape f32[12,21,8,34] and dims_to_insert is
// {0, 2}, then the result is `operand` reshaped to [1,12,1,21,8,34].
StatusOr<HloInstruction*> InsertDegenerateDims(
- HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dims_to_insert);
+ HloInstruction* operand, absl::Span<const int64> dims_to_insert);
// Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the
// front and `zeros_to_append` zeros in the back.
@@ -171,13 +170,13 @@ StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
// broadcast instruction is emitted into `computation`.
StatusOr<HloInstruction*> BroadcastZeros(
HloComputation* computation, PrimitiveType element_type,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
// Creates a HLO computation that takes arguments of type `domain` and produces
// a value of type `range`.
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
- tensorflow::gtl::ArraySlice<const Shape*> domain, const Shape& range,
- tensorflow::StringPiece name);
+ absl::Span<const Shape* const> domain, const Shape& range,
+ absl::string_view name);
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
index 60d3e71757..e07a196d11 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
@@ -14,23 +14,22 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace {
-using tensorflow::gtl::ArraySlice;
-class HloCreationUtilsTest : public HloTestBase {
+class HloCreationUtilsTest : public HloVerifiedTestBase {
protected:
- static std::unique_ptr<HloModule> CreateModuleWithProgramShape(
- PrimitiveType primitive_type, ArraySlice<int64> input_shape_dims,
- ArraySlice<int64> output_shape_dims, HloInstruction** param,
+ HloModule* CreateModuleWithProgramShape(
+ PrimitiveType primitive_type, absl::Span<const int64> input_shape_dims,
+ absl::Span<const int64> output_shape_dims, HloInstruction** param,
HloComputation** entry_computation) {
Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims);
Shape output_shape =
@@ -48,27 +47,27 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
- S32,
- /*input_shape_dims=*/{2}, /*output_shape_dims=*/{2}, &param,
- &entry_computation);
+ HloModule* module = CreateModuleWithProgramShape(S32,
+ /*input_shape_dims=*/{2},
+ /*output_shape_dims=*/{2},
+ &param, &entry_computation);
TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_1_dims_collapsed,
CollapseFirstNDims(param, 1));
entry_computation->set_root_instruction(first_1_dims_collapsed);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR1<int32>({3, 4})}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({3, 4}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR1<int32>({3, 4}));
}
TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
+ HloModule* module = CreateModuleWithProgramShape(
S32,
/*input_shape_dims=*/{2, 3, 2}, /*output_shape_dims=*/{6, 2}, &param,
&entry_computation);
@@ -79,13 +78,13 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module,
{LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})}));
- CHECK_EQ(*result_literal,
- *LiteralUtil::CreateR2<int32>(
+ CHECK_EQ(result_literal,
+ LiteralUtil::CreateR2<int32>(
{{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}}));
}
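
Note: these test hunks track Literal's migration from pointer to value
semantics: HloEvaluator::Evaluate now returns StatusOr<Literal> rather than
StatusOr<std::unique_ptr<Literal>>, comparisons drop the dereferences, and
(earlier in this diff) HloInstruction::CreateConstant takes the Literal by
value. A toy before/after of the calling convention, with a stand-in type:

#include <memory>

struct Literal { int value = 0; };  // stand-in for xla::Literal

// Old style: heap-allocated result, accessed through a pointer.
std::unique_ptr<Literal> EvaluateOld() { return std::make_unique<Literal>(); }

// New style: the Literal is movable, so it is returned by value.
Literal EvaluateNew() { return Literal{}; }

void Demo() {
  auto old_result = EvaluateOld();
  bool eq_old = old_result->value == 0;  // *old_result, old_result->...
  Literal new_result = EvaluateNew();
  bool eq_new = new_result.value == 0;   // direct value access
  (void)eq_old; (void)eq_new;
}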
@@ -93,10 +92,10 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
- S32,
- /*input_shape_dims=*/{2}, /*output_shape_dims=*/{1, 2}, &param,
- &entry_computation);
+ HloModule* module = CreateModuleWithProgramShape(S32,
+ /*input_shape_dims=*/{2},
+ /*output_shape_dims=*/{1, 2},
+ &param, &entry_computation);
TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_1_degenerate_dim_prepended,
PrependDegenerateDims(param, 1));
@@ -104,17 +103,17 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {LiteralUtil::CreateR1<int32>({9, 10})}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9, 10}}));
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(*module,
+ {LiteralUtil::CreateR1<int32>({9, 10})}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9, 10}}));
}
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
+ HloModule* module = CreateModuleWithProgramShape(
S32,
/*input_shape_dims=*/{2}, /*output_shape_dims=*/{1, 1, 2}, &param,
&entry_computation);
@@ -125,37 +124,37 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {LiteralUtil::CreateR1<int32>({9, 10})}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR3<int32>({{{9, 10}}}));
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(*module,
+ {LiteralUtil::CreateR1<int32>({9, 10})}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR3<int32>({{{9, 10}}}));
}
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
- S32,
- /*input_shape_dims=*/{}, /*output_shape_dims=*/{1, 1}, &param,
- &entry_computation);
+ HloModule* module = CreateModuleWithProgramShape(S32,
+ /*input_shape_dims=*/{},
+ /*output_shape_dims=*/{1, 1},
+ &param, &entry_computation);
TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended,
PrependDegenerateDims(param, 2));
entry_computation->set_root_instruction(with_2_degenerate_dims_prepended);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {LiteralUtil::CreateR0<int32>(9)}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9}}));
+ TF_ASSERT_OK_AND_ASSIGN(
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(*module, {LiteralUtil::CreateR0<int32>(9)}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9}}));
}
TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
+ HloModule* module = CreateModuleWithProgramShape(
S32,
/*input_shape_dims=*/{6}, /*output_shape_dims=*/{3, 1, 2}, &param,
&entry_computation);
@@ -166,21 +165,21 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6})}));
- CHECK_EQ(*result_literal,
- *LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
+ CHECK_EQ(result_literal,
+ LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
}
TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
- S32,
- /*input_shape_dims=*/{2}, /*output_shape_dims=*/{6}, &param,
- &entry_computation);
+ HloModule* module = CreateModuleWithProgramShape(S32,
+ /*input_shape_dims=*/{2},
+ /*output_shape_dims=*/{6},
+ &param, &entry_computation);
TF_ASSERT_OK_AND_ASSIGN(
HloInstruction * zero_padded_param,
@@ -188,20 +187,20 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
entry_computation->set_root_instruction(zero_padded_param);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR1<int32>({3, 4})}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
}
TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
- S32,
- /*input_shape_dims=*/{}, /*output_shape_dims=*/{2, 2}, &param,
- &entry_computation);
+ HloModule* module = CreateModuleWithProgramShape(S32,
+ /*input_shape_dims=*/{},
+ /*output_shape_dims=*/{2, 2},
+ &param, &entry_computation);
TF_ASSERT_OK_AND_ASSIGN(
HloInstruction * zeros,
@@ -209,20 +208,20 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
entry_computation->set_root_instruction(zeros);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {LiteralUtil::CreateR0<int32>(0)}));
- CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
+ TF_ASSERT_OK_AND_ASSIGN(
+ Literal result_literal,
+ evaluator.Evaluate<Literal>(*module, {LiteralUtil::CreateR0<int32>(0)}));
+ CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
}
TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
HloInstruction* param;
HloComputation* entry_computation;
- std::unique_ptr<HloModule> module = CreateModuleWithProgramShape(
- F32,
- /*input_shape_dims=*/{}, /*output_shape_dims=*/{2, 2}, &param,
- &entry_computation);
+ HloModule* module = CreateModuleWithProgramShape(F32,
+ /*input_shape_dims=*/{},
+ /*output_shape_dims=*/{2, 2},
+ &param, &entry_computation);
TF_ASSERT_OK_AND_ASSIGN(
HloInstruction * zeros,
@@ -230,11 +229,11 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
entry_computation->set_root_instruction(zeros);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
+ TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
+ evaluator.Evaluate<Literal>(
*module, {LiteralUtil::CreateR0<float>(0.0f)}));
- CHECK_EQ(*result_literal,
- *LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
+ CHECK_EQ(result_literal,
+ LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc
index 06484f4012..b59c9ba3ed 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -34,7 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/hash/hash.h"
namespace xla {
@@ -103,6 +104,9 @@ int64 CseHash(const HloInstruction* instruction) {
for (auto operand : instruction->operands()) {
hash = tensorflow::Hash64Combine(hash, operand->unique_id());
}
+ if (instruction->opcode() == HloOpcode::kConstant) {
+ hash = tensorflow::Hash64Combine(hash, instruction->literal().Hash());
+ }
return hash;
}
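
Note: mixing the literal's own hash into CseHash makes equal constants hash
alike, so CSE can find them in one bucket. A standalone sketch of the combine
pattern; the mixer below stands in for tensorflow::Hash64Combine and the
struct for HloInstruction (illustrative only):

#include <cstdint>
#include <vector>

// Stand-in for tensorflow::Hash64Combine: any decent 64-bit mixer works.
uint64_t Combine(uint64_t seed, uint64_t value) {
  return seed ^ (value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2));
}

struct FakeInstr {
  int opcode = 0;
  uint64_t shape_hash = 0;
  std::vector<int> operand_ids;
  bool is_constant = false;
  uint64_t literal_hash = 0;  // only meaningful for constants
};

uint64_t CseHash(const FakeInstr& in) {
  uint64_t h = Combine(in.opcode, in.shape_hash);
  for (int id : in.operand_ids) h = Combine(h, static_cast<uint64_t>(id));
  if (in.is_constant) h = Combine(h, in.literal_hash);  // the new step
  return h;
}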
diff --git a/tensorflow/compiler/xla/service/hlo_cse.h b/tensorflow/compiler/xla/service/hlo_cse.h
index 5e2b348bdd..a28c03599a 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.h
+++ b/tensorflow/compiler/xla/service/hlo_cse.h
@@ -34,7 +34,7 @@ class HloCSE : public HloPassInterface {
: is_layout_sensitive_(is_layout_sensitive),
only_fusion_computations_(only_fusion_computations) {}
~HloCSE() override = default;
- tensorflow::StringPiece name() const override { return "cse"; }
+ absl::string_view name() const override { return "cse"; }
// Run CSE on the given module. Returns whether the module was changed (common
// subexpressions were found and eliminated).
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index 90fbaa37c5..9b18b0284f 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -20,16 +20,16 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
@@ -44,7 +44,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
-class HloCseTest : public HloTestBase {
+class HloCseTest : public HloVerifiedTestBase {
protected:
HloCseTest() {}
};
@@ -65,15 +65,15 @@ TEST_F(HloCseTest, CombineTwoConstants) {
EXPECT_EQ(3, computation->instruction_count());
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
EXPECT_EQ(2, computation->instruction_count());
HloInstruction* constant = *computation->instructions().begin();
EXPECT_EQ(42.0f, constant->literal().Get<float>({}));
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR0<float>(84.0);
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
@@ -96,16 +96,16 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
EXPECT_THAT(add, op::Add(constant1, constant2));
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
EXPECT_EQ(2, computation->instruction_count());
auto first_operand = add->operand(0);
EXPECT_THAT(first_operand, ::testing::AnyOf(constant1, constant2));
EXPECT_THAT(add, op::Add(first_operand, first_operand));
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
@@ -128,14 +128,14 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
EXPECT_THAT(add, op::Add(constant1, constant2));
HloCSE cse(/*is_layout_sensitive=*/true);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(module).ValueOrDie());
EXPECT_EQ(3, computation->instruction_count());
EXPECT_THAT(add, op::Add(constant1, constant2));
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
@@ -177,7 +177,7 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
EXPECT_EQ(20, computation->instruction_count());
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
// CSE will remove both the second float(42.0f) and the corresponding
// convert/cast.
@@ -209,7 +209,7 @@ TEST_F(HloCseTest, NonscalarConstants) {
op::Tuple(common_constant1, common_constant2, uncommon_constant));
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
EXPECT_EQ(3, computation->instruction_count());
auto first_operand = tuple->operand(0);
@@ -240,7 +240,7 @@ TEST_F(HloCseTest, IdenticalInstructions) {
EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3));
HloCSE cse(/*is_layout_sensitive=*/true);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
EXPECT_EQ(3, computation->instruction_count());
auto first_operand = tuple->operand(0);
@@ -250,7 +250,7 @@ TEST_F(HloCseTest, IdenticalInstructions) {
// Test two identical while loops with same inputs
TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesSameInput) {
- auto module = ParseHloString(R"(
+ ParseAndVerifyModule(R"(
HloModule WhileLoopsIdenticalConditionsAndBodiesSameInput
%body (param: (f32[], f32[])) -> (f32[], f32[]) {
@@ -278,21 +278,20 @@ f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT
%while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
condition=%condition.1, body=%body
}
- )")
- .ValueOrDie();
+ )");
- auto computation = module->entry_computation();
+ auto computation = module().entry_computation();
EXPECT_EQ(5, computation->instruction_count());
HloCSE cse(true);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(&module()).ValueOrDie());
EXPECT_EQ(4, computation->instruction_count());
}
// Test two while loops with same conditions, same inputs, but different
// bodies
TEST_F(HloCseTest, WhileLoopsIdenticalConditionsSameInputAndDifferentBodies) {
- auto module = ParseHloString(R"(
+ ParseAndVerifyModule(R"(
HloModule WhileLoopsIdenticalConditionsSameInputAndDifferentBodies
%body (param: (f32[], f32[])) -> (f32[], f32[]) {
@@ -329,20 +328,19 @@ index=1 %sub = f32[] subtract(f32[] %get-tuple-element.2, f32[]
condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[],
f32[]) %tuple.1), condition=%condition.1, body=%body2
}
- )")
- .ValueOrDie();
+ )");
- auto computation = module->entry_computation();
+ auto computation = module().entry_computation();
EXPECT_EQ(5, computation->instruction_count());
HloCSE cse(true);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(&module()).ValueOrDie());
EXPECT_EQ(5, computation->instruction_count());
}
// Test two identical while loops with different inputs
TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesDifferentInput) {
- auto module = ParseHloString(R"(
+ ParseAndVerifyModule(R"(
HloModule WhileLoopsIdenticalConditionsAndBodiesDifferentInput
%body (param: (f32[], f32[])) -> (f32[], f32[]) {
@@ -373,21 +371,20 @@ f32[] constant(2) %tuple.2 = (f32[], f32[]) tuple(f32[] %constant.4, f32[]
condition=%condition.1, body=%body
}
- )")
- .ValueOrDie();
+ )");
- auto computation = module->entry_computation();
+ auto computation = module().entry_computation();
EXPECT_EQ(8, computation->instruction_count());
HloCSE cse(true);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(&module()).ValueOrDie());
EXPECT_EQ(8, computation->instruction_count());
}
// Test two while loops with identical bodies and same inputs, but different
// conditions
TEST_F(HloCseTest, WhileLoopsIdenticalBodiesAndInputDifferntConditions) {
- auto module = ParseHloString(R"(
+ ParseAndVerifyModule(R"(
HloModule WhileLoopsIdenticalBodiesAndInputDifferntConditions
%body (param: (f32[], f32[])) -> (f32[], f32[]) {
@@ -414,14 +411,13 @@ f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2)
%while = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[],
f32[]) %tuple.1), condition=%condition.1, body=%body
- })")
- .ValueOrDie();
+ })");
- auto computation = module->entry_computation();
+ auto computation = module().entry_computation();
EXPECT_EQ(5, computation->instruction_count());
HloCSE cse(true);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(&module()).ValueOrDie());
EXPECT_EQ(5, computation->instruction_count());
}
@@ -450,7 +446,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) {
EXPECT_THAT(tuple, op::Tuple(exp1, exp2));
HloCSE cse(/*is_layout_sensitive=*/true);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(module).ValueOrDie());
EXPECT_EQ(4, computation->instruction_count());
EXPECT_THAT(tuple, op::Tuple(exp1, exp2));
@@ -481,7 +477,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) {
EXPECT_THAT(tuple, op::Tuple(exp1, exp2));
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
EXPECT_EQ(3, computation->instruction_count());
auto first_operand = tuple->operand(0);
@@ -516,7 +512,7 @@ TEST_F(HloCseTest, FusionInternalCSE) {
EXPECT_EQ(5, fused_computation->instruction_count());
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
EXPECT_EQ(4, fused_computation->instruction_count());
auto root = fused_computation->root_instruction();
@@ -565,7 +561,7 @@ TEST_F(HloCseTest, IdenticalExpressions) {
EXPECT_THAT(tuple, op::Tuple(op::Add(negate1, exp1), op::Add(negate2, exp2)));
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
EXPECT_EQ(5, computation->instruction_count());
auto operand = tuple->operand(0);
@@ -599,7 +595,7 @@ TEST_F(HloCseTest, DoNotCombineRng) {
uint32 count_before = computation->instruction_count();
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(module).ValueOrDie());
uint32 count_after = computation->instruction_count();
EXPECT_EQ(count_before, count_after);
@@ -653,7 +649,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) {
VLOG(3) << "before: " << module->ToString();
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(module).ValueOrDie());
VLOG(3) << "after: " << module->ToString();
@@ -663,7 +659,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) {
}
TEST_F(HloCseTest, CompareComputations) {
- auto module = ParseHloString(R"(
+ ParseAndVerifyModule(R"(
HloModule m
add_computation {
@@ -684,12 +680,11 @@ TEST_F(HloCseTest, CompareComputations) {
r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation
r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2
ROOT f2 = (f32[],f32[]) tuple(r1, r2)
- })")
- .ValueOrDie();
+ })");
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
- HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_TRUE(cse.Run(&module()).ValueOrDie());
+ HloInstruction* root = module().entry_computation()->root_instruction();
EXPECT_EQ(root->operand(0), root->operand(1));
}
@@ -708,13 +703,13 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) {
EXPECT_EQ(2, computation->instruction_count());
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(module).ValueOrDie());
EXPECT_EQ(2, computation->instruction_count());
}
TEST_F(HloCseTest, Domain) {
- auto module = ParseHloString(R"(
+ ParseAndVerifyModule(R"(
HloModule module
ENTRY %entry {
%param = f32[] parameter(0), sharding={maximal device=0}
@@ -735,13 +730,11 @@ ENTRY %entry {
domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}}
%add = f32[] add(%domain.3, %domain.4)
ROOT %sub = f32[] subtract(%add, %domain.5)
-})")
- .ValueOrDie();
+})");
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
- LOG(INFO) << "AAAAA " << module->ToString();
- const HloInstruction* sub = module->entry_computation()->root_instruction();
+ EXPECT_TRUE(cse.Run(&module()).ValueOrDie());
+ const HloInstruction* sub = module().entry_computation()->root_instruction();
const HloInstruction* add = sub->operand(0);
EXPECT_EQ(add->operand(0), add->operand(1));
EXPECT_NE(add->operand(0), sub->operand(1));
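Every test above follows the same HloVerifiedTestBase recipe: ParseAndVerifyModule() replaces ParseHloString(...).ValueOrDie() and stores the module in the fixture, module() (or &module()) reaches it, and ExecuteAndTransfer takes module->Clone() rather than std::move(module) because the fixture keeps ownership (and, judging by the base class name, verifies the module). A condensed sketch of the migrated shape, using only calls visible in this diff:

  #include "tensorflow/compiler/xla/service/hlo_cse.h"
  #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"

  namespace xla {
  namespace {

  class CseSketchTest : public HloVerifiedTestBase {};

  TEST_F(CseSketchTest, MigratedPattern) {
    ParseAndVerifyModule(R"(
  HloModule m
  ENTRY e {
    c0 = f32[] constant(42)
    c1 = f32[] constant(42)
    ROOT add = f32[] add(c0, c1)
  })");
    HloCSE cse(/*is_layout_sensitive=*/false);
    // Run() takes a raw HloModule*; the fixture still owns it.
    EXPECT_TRUE(cse.Run(&module()).ValueOrDie());
  }

  }  // namespace
  }  // namespace xla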
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index bbfb0c253f..6a63681996 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -19,8 +19,10 @@ limitations under the License.
#include <queue>
#include <vector>
+#include "absl/container/inlined_vector.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -29,8 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -46,8 +46,7 @@ namespace {
//
// In this case, we should be able to reuse p0 and output, although p0 has
// multiple uses.
-bool MultiDynamicSliceUseShareSameIndices(
- tensorflow::gtl::ArraySlice<HloUse> uses) {
+bool MultiDynamicSliceUseShareSameIndices(absl::Span<const HloUse> uses) {
if (uses.empty()) {
return false;
}
@@ -78,8 +77,8 @@ bool MultiDynamicSliceUseShareSameIndices(
} // namespace
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
+using absl::StrAppend;
+using absl::StrCat;
HloDataflowAnalysis::HloDataflowAnalysis(
const HloModule& module, bool ssa_form, bool bitcast_defines_value,
@@ -93,7 +92,7 @@ HloDataflowAnalysis::HloDataflowAnalysis(
bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
const HloInstruction* inst) {
tensorflow::gtl::FlatSet<const HloInstruction*> visited;
- tensorflow::gtl::InlinedVector<const HloInstruction*, 4> stack;
+ absl::InlinedVector<const HloInstruction*, 4> stack;
stack.push_back(inst);
while (!stack.empty()) {
const HloInstruction* current = stack.back();
@@ -221,7 +220,7 @@ string HloDataflowAnalysis::ToString() const {
bool HloDataflowAnalysis::Phi(
HloInstruction* instruction,
- tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
+ absl::Span<const InstructionValueSet* const> inputs) {
CHECK(ssa_form_);
VLOG(4) << "Phi(" << instruction->name() << ")";
VLOG(5) << "instruction value set = "
@@ -837,7 +836,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
return Unimplemented(
"Computation %s is called in both a parallel (eg, kMap) and "
"sequential (eg, kCall) context",
- computation->name().c_str());
+ computation->name());
}
if (call_graph_node.caller_callsites().empty() ||
call_graph_node.context() == CallContext::kParallel) {
@@ -886,7 +885,7 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
XLA_VLOG_LINES(2, module.ToString());
- auto dataflow_analysis = WrapUnique(new HloDataflowAnalysis(
+ auto dataflow_analysis = absl::WrapUnique(new HloDataflowAnalysis(
module, ssa_form, bitcast_defines_value, fusion_can_share_buffer));
TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
@@ -976,28 +975,22 @@ Status HloDataflowAnalysis::Verify() const {
bool HloDataflowAnalysis::DoesNotUseOperandBuffer(
const HloInstruction* operand, const ShapeIndex& index,
const HloInstruction* user) const {
- CHECK(user->IsUserOf(operand))
- << "user: " << user->ToString() << " operand: " << operand->ToString();
- if (user->opcode() == HloOpcode::kFusion &&
- user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
- // Find fusion parameter associated with 'operand'.
- HloInstruction* fusion_param =
- user->fused_parameter(user->operand_index(operand));
- // Iterate through all users of all uses of the fusion parameter value.
- // Return false if any uses are detected, returns true otherwise.
- const HloValue& value = GetValueDefinedAt(fusion_param, index);
- return value.uses().empty();
- } else {
- // Return false if no value at 'operand' and 'index' is used at 'user'.
- for (const HloValue* value : GetValueSet(operand, index).values()) {
- for (const HloUse& use : value->uses()) {
- if (use.instruction == user) {
- return false;
+ // Return false if no value at 'operand' and 'index' is used at 'user'.
+ for (const HloValue* value : GetValueSet(operand, index).values()) {
+ for (const HloUse& use : value->uses()) {
+ if (use.instruction == user) {
+ if (user->opcode() == HloOpcode::kFusion &&
+ user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
+ HloInstruction* fusion_param =
+ user->fused_parameter(use.operand_number);
+ const HloValue& value =
+ GetValueDefinedAt(fusion_param, use.operand_index);
+ return value.uses().empty();
}
+ return false;
}
}
}
-
return true;
}
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
index f4abc7a7c7..e62c1c2ac8 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
@@ -25,6 +25,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -34,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
@@ -138,7 +138,8 @@ class HloDataflowAnalysis {
// Returns true if 'user' cannot possibly use the buffer at 'index' in
// 'operand'. Returns false otherwise.
//
- // REQUIRES: 'operand' is an operand of 'user'.
+ // 'operand' does not have to be an operand of 'user'. This can be the case
+ // with indirect uses.
bool DoesNotUseOperandBuffer(const HloInstruction* operand,
const ShapeIndex& index,
const HloInstruction* user) const;
@@ -201,7 +202,7 @@ class HloDataflowAnalysis {
// the given instruction. If skip_top_level is true, then the top level of the
// value set of 'instruction' is not modified.
bool Phi(HloInstruction* instruction,
- tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs);
+ absl::Span<const InstructionValueSet* const> inputs);
// Updates the positions of the HloValues in the output of the given
// instruction. This should be called after the instruction value set of
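The gtl::ArraySlice<T> spots migrate to absl::Span<const T>; with T itself a pointer type, that is how the doubled const in absl::Span<const InstructionValueSet* const> arises, and it keeps the conversion from ordinary pointer containers implicit. A stand-alone illustration with generic names:

  #include <vector>
  #include "absl/types/span.h"

  int SumDeref(absl::Span<const int* const> xs) {
    int total = 0;
    for (const int* x : xs) total += *x;
    return total;
  }

  int main() {
    int a = 1, b = 2;
    std::vector<int*> ptrs = {&a, &b};
    // int** converts to const int* const*, so the vector binds to the span.
    return SumDeref(ptrs) == 3 ? 0 : 1;  // exits 0
  }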
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 4755c4a0cf..510d6360a1 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -1261,9 +1262,10 @@ TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) {
auto entry = module_->AddEntryComputation(builder.Build());
RunAnalysis(GetParam());
- SequentialHloOrdering::HloModuleSequence sequence;
- sequence.insert({entry, {param0, negate, param1, exp, add}});
- SequentialHloOrdering ordering(module_.get(), sequence);
+ HloSchedule schedule(module_.get());
+ schedule.set_sequence(entry, {param0, negate, param1, exp, add});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
// Entry parameters interfere as if they are defined simultaneously at
// the very beginning.
@@ -1339,14 +1341,16 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) {
bool ssa_form = GetParam();
RunAnalysis(ssa_form);
- SequentialHloOrdering::HloModuleSequence sequence;
- sequence.insert({entry, {param, xla_while}});
- sequence.insert({condition, {cond_param, cond_constant}});
+ HloSchedule schedule(module_.get());
+ schedule.set_sequence(entry, {param, xla_while});
+ schedule.set_sequence(condition, {cond_param, cond_constant});
// Construct the order such that 'constant' and its use 'exp' are before
// body_param.
- sequence.insert({body, {constant, exp, body_param, add}});
+ schedule.set_sequence(
+ body, {constant, exp, body_param, add, dead_constant, dead_negate});
+ TF_ASSERT_OK(schedule.Verify());
- SequentialHloOrdering ordering(module_.get(), sequence);
+ SequentialHloOrdering ordering(schedule);
// 'add' is live out of the body and will interfere with later instructions
// such as 'dead_constant' and 'dead_negate'.
@@ -1476,11 +1480,10 @@ TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) {
auto entry = module_->AddEntryComputation(builder.Build());
RunAnalysis(GetParam());
- SequentialHloOrdering::HloModuleSequence sequence;
- std::vector<const HloInstruction*> order = {param, negate, exp, add};
- sequence.emplace(entry, order);
-
- SequentialHloOrdering ordering(module_.get(), sequence);
+ HloSchedule schedule(module_.get());
+ schedule.set_sequence(entry, {param, negate, exp, add});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate));
EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp));
@@ -1963,6 +1966,54 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion));
}
+// Similar to FusedDynamicUpdateSlice above, but tests indirect uses of the
+// parameter tuple.
+TEST_F(DoesNotUseOperandBufferTest, IndirectUses) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+ auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
+ auto t0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 0));
+ auto t1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 1));
+ // Swap the tuple elements.
+ auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({t1, t0}));
+
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
+
+ // Create a DynamicUpdateSlice instruction of tuple element 1.
+ auto starts = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
+ auto update = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
+ auto dynamic_update_slice =
+ builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+ data_shape, gte1, update, starts));
+ builder.AddInstruction(
+ HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {dynamic_update_slice, starts, update, gte1},
+ HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ // The fusion instruction never uses tuple element 0, but does use element 1.
+ EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion));
+ EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion));
+ // The same holds for the parameter tuple, except that the tuple elements are
+ // swapped in 'tuple'.
+ EXPECT_TRUE(
+ dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {1}, fusion));
+ EXPECT_FALSE(
+ dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {0}, fusion));
+}
+
class CanShareOperandBufferWithUserTest : public HloDataflowAnalysisTestBase {};
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
@@ -2286,8 +2337,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfig precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfig::DEFAULT);
auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(data_shape, a, b, dot_dnums));
+ HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config));
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
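Note what schedule.Verify() buys: in WhileParameters_Sequential the body sequence now has to list dead_constant and dead_negate as well, which suggests a schedule must cover every instruction of its computation, something the old HloModuleSequence never checked. The minimal shape of the migrated setup, reusing only calls and fixture names visible in the hunks above:

  HloSchedule schedule(module_.get());
  schedule.set_sequence(entry, {param, negate, exp, add});
  TF_ASSERT_OK(schedule.Verify());  // rejects sequences that omit instructions
  SequentialHloOrdering ordering(schedule);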
diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h
index 4e244494d6..1fe69b1395 100644
--- a/tensorflow/compiler/xla/service/hlo_dce.h
+++ b/tensorflow/compiler/xla/service/hlo_dce.h
@@ -36,7 +36,7 @@ namespace xla {
class HloDCE : public HloPassInterface {
public:
~HloDCE() override {}
- tensorflow::StringPiece name() const override { return "dce"; }
+ absl::string_view name() const override { return "dce"; }
// Run the pass on the given module. Returns whether the module was changed
// (instructions were removed).
diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc
index 26e3736e01..3b5cde2996 100644
--- a/tensorflow/compiler/xla/service/hlo_dce_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc
@@ -17,9 +17,9 @@ limitations under the License.
#include <memory>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc
index 78955db0da..72185698c9 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc
@@ -31,31 +31,10 @@ class HloDomainIsolator::RunContext {
StatusOr<bool> Run();
private:
- // Inserts a kDomain instruction between parent and operand, in case
- // the attribute (ie, sharding) values change between instruction and operand.
- // Returns the newly inserted kDomain instruction, or nullptr if no kDomain
- // instruction was necessary.
- StatusOr<HloInstruction*> CreateDomain(HloInstruction* instruction,
- HloInstruction* parent,
- HloInstruction* operand);
-
HloModule* module_;
HloDomainIsolator* isolator_;
};
-StatusOr<HloInstruction*> HloDomainIsolator::RunContext::CreateDomain(
- HloInstruction* instruction, HloInstruction* parent,
- HloInstruction* operand) {
- HloInstruction* domain = nullptr;
- std::unique_ptr<HloInstruction> domain_instruction =
- isolator_->creator_(instruction, operand);
- if (domain_instruction != nullptr) {
- domain = operand->parent()->AddInstruction(std::move(domain_instruction));
- TF_RETURN_IF_ERROR(operand->ReplaceUseWith(parent, domain));
- }
- return domain;
-}
-
StatusOr<bool> HloDomainIsolator::RunContext::Run() {
hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Isolator");
@@ -71,16 +50,16 @@ StatusOr<bool> HloDomainIsolator::RunContext::Run() {
// When applying multiple domains, we could end up stacking more than
// one in one edge, so here we want to build the effective
// (kDomain-less) instruction->operand edge.
- HloInstruction* parent = instruction;
- while (operand->opcode() == HloOpcode::kDomain) {
- parent = operand;
- operand = operand->mutable_operand(0);
+ HloInstruction* root = operand;
+ while (root->opcode() == HloOpcode::kDomain) {
+ root = root->mutable_operand(0);
}
// Check whether a kDomain is necessary between instruction and operand.
- TF_ASSIGN_OR_RETURN(HloInstruction * domain,
- CreateDomain(instruction, parent, operand));
+ HloInstruction* domain =
+ isolator_->creator_(instruction, root, operand);
if (domain != nullptr) {
VLOG(4) << "New domain: " << domain->ToString();
+ TF_RETURN_IF_ERROR(operand->ReplaceUseWith(instruction, domain));
++added_domains;
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h
index eded3e78ee..d36631fc2f 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h
@@ -34,14 +34,16 @@ class HloDomainIsolator : public HloPassInterface {
public:
// Creates a new kDomain instruction for the edge between the use instruction
// (the first HloInstruction argument), and the operand instruction (the
- // second HloInstruction argument).
+ // third HloInstruction argument) if the interesting attribute of the
+ // instruction differs from the attribute of the root (the second
+ // HloInstruction argument).
// Returns nullptr in case no domain separation is necessary.
- using DomainCreator = std::function<std::unique_ptr<HloInstruction>(
- HloInstruction*, HloInstruction*)>;
+ using DomainCreator = std::function<HloInstruction*(
+ HloInstruction*, HloInstruction*, HloInstruction*)>;
explicit HloDomainIsolator(DomainCreator creator);
- tensorflow::StringPiece name() const override { return "domain_isolator"; }
+ absl::string_view name() const override { return "domain_isolator"; }
StatusOr<bool> Run(HloModule* module) override;
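Under the new DomainCreator contract the functor receives the user instruction, the root (the operand with any stacked kDomains peeled away), and the raw operand, and must itself add the kDomain to the computation; the isolator only rewires the use afterwards. A skeletal creator under those assumptions, where AttributeOf and MyMetadata are hypothetical placeholders for the attribute accessor and a DomainMetadata subclass:

  HloInstruction* MyDomainCreator(HloInstruction* instruction,
                                  HloInstruction* root,
                                  HloInstruction* operand) {
    if (AttributeOf(instruction) == AttributeOf(root)) {  // hypothetical helper
      return nullptr;  // No separation needed across this edge.
    }
    // Build and insert the kDomain; the isolator calls ReplaceUseWith later.
    return operand->parent()->AddInstruction(HloInstruction::CreateDomain(
        operand->shape(), operand,
        absl::make_unique<MyMetadata>(root),            // operand-side metadata
        absl::make_unique<MyMetadata>(instruction)));   // user-side metadata
  }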
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc
index 9e096320db..113fd18eae 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <algorithm>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/types.h"
@@ -25,14 +26,14 @@ namespace xla {
/* static */ StatusOr<std::unique_ptr<HloDomainMap>> HloDomainMap::Create(
HloComputation* computation, string domain_kind) {
- auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind)));
+ auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind)));
TF_RETURN_IF_ERROR(domain_map->Populate(computation));
return std::move(domain_map);
}
/* static */ StatusOr<std::unique_ptr<HloDomainMap>> HloDomainMap::Create(
HloModule* module, string domain_kind) {
- auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind)));
+ auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind)));
for (HloComputation* computation : module->computations()) {
TF_RETURN_IF_ERROR(domain_map->Populate(computation));
}
@@ -50,20 +51,24 @@ int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const {
return FindOrDefault(instruction_to_domain_, instruction, -1);
}
+int64 HloDomainMap::GetDomainMetadataId(HloInstruction* instruction) const {
+ return FindOrDie(domain_metadata_id_, instruction);
+}
+
Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain);
// We only check operands, so we are sure to not process the empty domain from
// both sides.
for (HloInstruction* operand : instruction->unique_operands()) {
if (IsDomainInstruction(operand)) {
- auto domain = MakeUnique<DomainMetadata::Domain>();
+ auto domain = absl::make_unique<DomainMetadata::Domain>();
domain->enter_domains.insert(operand);
domain->exit_domains.insert(instruction);
TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
}
}
if (instruction == instruction->parent()->root_instruction()) {
- auto domain = MakeUnique<DomainMetadata::Domain>();
+ auto domain = absl::make_unique<DomainMetadata::Domain>();
domain->enter_domains.insert(instruction);
TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
}
@@ -71,6 +76,11 @@ Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
}
Status HloDomainMap::Populate(HloComputation* computation) {
+ InstructionOrderMap instructions_post_order;
+ int64 count = 0;
+ for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
+ instructions_post_order.insert(std::make_pair(instruction, count++));
+ }
for (HloInstruction* instruction : computation->instructions()) {
if (IsDomainInstruction(instruction)) {
// If this is a kDomain of the kind we are currently processing, check
@@ -84,9 +94,46 @@ Status HloDomainMap::Populate(HloComputation* computation) {
continue;
}
TF_ASSIGN_OR_RETURN(std::unique_ptr<DomainMetadata::Domain> domain,
- CreateDomain(instruction));
+ CreateDomain(instruction, instructions_post_order));
TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
}
+ TF_RETURN_IF_ERROR(PopulateDomainMetadataMap());
+ return Status::OK();
+}
+
+Status HloDomainMap::PopulateDomainMetadataMap() {
+ auto hash = [](const DomainMetadata* m) { return m->Hash(); };
+ auto equal = [](const DomainMetadata* a, const DomainMetadata* b) {
+ return a->Matches(*b);
+ };
+ tensorflow::gtl::FlatMap<const DomainMetadata*, int64, decltype(hash),
+ decltype(equal)>
+ domain_metadata(1024, hash, equal);
+
+ for (auto& domain : instruction_domains_) {
+ int64 domain_metadata_id = -1;
+ if (!domain->enter_domains.empty()) {
+ const HloInstruction* domain_instruction = *domain->enter_domains.begin();
+ domain_metadata_id =
+ domain_metadata
+ .insert({&domain_instruction->user_side_metadata(),
+ domain_metadata.size() + 1})
+ .first->second;
+ } else if (!domain->exit_domains.empty()) {
+ const HloInstruction* domain_instruction = *domain->exit_domains.begin();
+ domain_metadata_id =
+ domain_metadata
+ .insert({&domain_instruction->operand_side_metadata(),
+ domain_metadata.size() + 1})
+ .first->second;
+ } else {
+ domain_metadata_id = 0;
+ }
+ TF_RET_CHECK(domain_metadata_id >= 0);
+ for (HloInstruction* instruction : domain->instructions) {
+ domain_metadata_id_[instruction] = domain_metadata_id;
+ }
+ }
return Status::OK();
}
@@ -142,10 +189,12 @@ Status HloDomainMap::ExpandDomain(HloInstruction* instruction,
}
StatusOr<std::unique_ptr<DomainMetadata::Domain>> HloDomainMap::CreateDomain(
- HloInstruction* instruction) const {
- auto domain = MakeUnique<DomainMetadata::Domain>();
+ HloInstruction* instruction,
+ const InstructionOrderMap& instructions_order) const {
+ auto domain = absl::make_unique<DomainMetadata::Domain>();
TF_RETURN_IF_ERROR(ExpandDomain(instruction, domain.get()));
- domain->instructions = MakeNonDomainInstructions(domain->reach_set);
+ domain->instructions =
+ MakeNonDomainInstructions(domain->reach_set, instructions_order);
return std::move(domain);
}
@@ -167,7 +216,8 @@ bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const {
/* static */ std::vector<HloInstruction*>
HloDomainMap::MakeNonDomainInstructions(
- const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set) {
+ const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set,
+ const InstructionOrderMap& instructions_order) {
std::vector<HloInstruction*> instructions;
instructions.reserve(instruction_set.size());
for (HloInstruction* instruction : instruction_set) {
@@ -175,9 +225,10 @@ HloDomainMap::MakeNonDomainInstructions(
instructions.push_back(instruction);
}
}
+ // Sort instructions according to instructions_order.
std::sort(instructions.begin(), instructions.end(),
- [](HloInstruction* a, HloInstruction* b) {
- return a->unique_id() < b->unique_id();
+ [&instructions_order](HloInstruction* a, HloInstruction* b) {
+ return instructions_order.at(a) < instructions_order.at(b);
});
return instructions;
}
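PopulateDomainMetadataMap above assigns one integer per equivalence class of metadata by keying a hash map on pointers while hashing and comparing through the pointees (Hash() and Matches()). The same interning technique in plain C++, with a toy Metadata type standing in for DomainMetadata:

  #include <cstdint>
  #include <functional>
  #include <string>
  #include <unordered_map>

  struct Metadata {
    std::string kind;
    size_t Hash() const { return std::hash<std::string>()(kind); }
    bool Matches(const Metadata& other) const { return kind == other.kind; }
  };

  int main() {
    auto hash = [](const Metadata* m) { return m->Hash(); };
    auto equal = [](const Metadata* a, const Metadata* b) {
      return a->Matches(*b);
    };
    std::unordered_map<const Metadata*, int64_t, decltype(hash),
                       decltype(equal)>
        ids(/*bucket_count=*/16, hash, equal);

    Metadata a{"sharding"}, b{"sharding"}, c{"opname"};
    // insert() is a no-op for a key an existing entry Matches(), so equivalent
    // metadata objects at different addresses share one id.
    int64_t id_a = ids.insert({&a, ids.size() + 1}).first->second;
    int64_t id_b = ids.insert({&b, ids.size() + 1}).first->second;
    int64_t id_c = ids.insert({&c, ids.size() + 1}).first->second;
    return (id_a == id_b && id_a != id_c) ? 0 : 1;  // exits 0
  }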
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h
index 1ca7159725..56b557d7ce 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.h
@@ -69,7 +69,17 @@ class HloDomainMap {
// instruction is not found within any domain.
int64 GetDomainId(HloInstruction* instruction) const;
+ // Returns the unique id of the domain metadata for the domain the given
+ // instruction belongs to. The given instruction must not be a kDomain
+ // instruction since each domain instruction is associated with 2 domains.
+ int64 GetDomainMetadataId(HloInstruction* instruction) const;
+
private:
+ // Map used for representing instruction ordering, i.e.
+ // order_map[a] < order_map[b] means a must be ordered before b.
+ using InstructionOrderMap =
+ tensorflow::gtl::FlatMap<const HloInstruction*, int64>;
+
HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {}
// Check if the kDomain instruction is facing (via its operand link) another
@@ -95,16 +105,23 @@ class HloDomainMap {
// Creates a domain data structure using the ExpandDomain() API.
StatusOr<std::unique_ptr<DomainMetadata::Domain>> CreateDomain(
- HloInstruction* instruction) const;
+ HloInstruction* instruction,
+ const InstructionOrderMap& instructions_order) const;
// Out of an instruction set, returns a vector of all the ones which are not
// a kDomain kind.
static std::vector<HloInstruction*> MakeNonDomainInstructions(
- const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set);
+ const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set,
+ const InstructionOrderMap& instructions_order);
+
+ // Populates domain_metadata_id_ that maps each HloInstruction to the unique
+ // ID of its associated domain metadata.
+ Status PopulateDomainMetadataMap();
string domain_kind_;
std::vector<std::unique_ptr<DomainMetadata::Domain>> instruction_domains_;
tensorflow::gtl::FlatMap<HloInstruction*, int64> instruction_to_domain_;
+ tensorflow::gtl::FlatMap<HloInstruction*, int64> domain_metadata_id_;
};
} // namespace xla
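The InstructionOrderMap replaces unique_id() as the sort key so that a domain's instruction list comes out in computation post-order, i.e. operands before users. The core of that trick in isolation, with strings standing in for HloInstruction* and post_order assumed to be an already topologically sorted traversal:

  #include <algorithm>
  #include <cstdint>
  #include <string>
  #include <unordered_map>
  #include <vector>

  int main() {
    // The post-order of a tiny graph: param -> negate -> exp -> add.
    std::vector<std::string> post_order = {"param", "negate", "exp", "add"};

    // Build the order map: order[a] < order[b] means a comes before b.
    std::unordered_map<std::string, int64_t> order;
    int64_t count = 0;
    for (const auto& instr : post_order) order.emplace(instr, count++);

    // An arbitrarily ordered subset (e.g. a domain's reach set)...
    std::vector<std::string> domain = {"add", "param", "exp"};
    // ...sorted back into dependency-respecting order.
    std::sort(domain.begin(), domain.end(),
              [&order](const std::string& a, const std::string& b) {
                return order.at(a) < order.at(b);
              });
    return domain.front() == "param" ? 0 : 1;  // exits 0
  }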
diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h
index f855f2a1fc..302807f816 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h
@@ -20,10 +20,10 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -44,7 +44,10 @@ class DomainMetadata {
// two domains of different kind intersect each other.
tensorflow::gtl::FlatSet<HloInstruction*> reach_set;
- // The same instructions in reach_set, but purged from kDomain instructions.
+ // The same instructions in reach_set, but purged of kDomain instructions
+ // and ordered according to their computation graph post-order, i.e.
+ // if instructions[pos_a] depends on instructions[pos_b], then pos_a >
+ // pos_b.
std::vector<HloInstruction*> instructions;
// If we consider a graph edge as an arrow oriented from the operand to the
@@ -63,12 +66,15 @@ class DomainMetadata {
// Returns the metadata type. A unique identifier which describes the real
// metadata type.
- virtual tensorflow::StringPiece Kind() const = 0;
+ virtual absl::string_view Kind() const = 0;
// Compares the metadata object with another one and returns true if the
// two matches.
virtual bool Matches(const DomainMetadata& other) const = 0;
+ // Returns the hash value of the metadata.
+ virtual size_t Hash() const = 0;
+
// Returns a string representation of the metadata.
virtual string ToString() const = 0;
};
diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h
index c859e05f02..97bc8ef604 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_remover.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h
@@ -35,13 +35,13 @@ class HloDomainRemover : public HloPassInterface {
// instructions in it with the same attributes (ie, sharding), a normalizer
// function is tasked with applying attribute normalization on the instructions
// within such domain.
- HloDomainRemover(tensorflow::StringPiece kind,
+ HloDomainRemover(absl::string_view kind,
std::function<Status(const DomainMetadata::Domain&,
const DomainMetadata* metadata)>
normalizer)
- : kind_(kind.ToString()), normalizer_(std::move(normalizer)) {}
+ : kind_(kind), normalizer_(std::move(normalizer)) {}
- tensorflow::StringPiece name() const override { return "domain_remover"; }
+ absl::string_view name() const override { return "domain_remover"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc
index 70271be304..43e74d2f6f 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
@@ -45,9 +46,8 @@ class HloDomainTest : public HloVerifiedTestBase {
// Checks whether there is a kDomain instruction in the edge between the
// instruction and the operand.
- bool HasDomainEdge(HloModule* module,
- tensorflow::StringPiece instruction_name,
- tensorflow::StringPiece operand_name) {
+ bool HasDomainEdge(HloModule* module, absl::string_view instruction_name,
+ absl::string_view operand_name) {
HloInstruction* instruction = FindInstruction(module, instruction_name);
HloInstruction* operand = FindInstruction(module, operand_name);
CHECK_NE(instruction, nullptr);
@@ -65,7 +65,7 @@ class HloDomainTest : public HloVerifiedTestBase {
return false;
}
- StatusOr<HloModule*> ParseModule(tensorflow::StringPiece hlo_string) {
+ StatusOr<HloModule*> ParseModule(absl::string_view hlo_string) {
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
ParseAndVerifyModule(hlo_string, config);
@@ -80,10 +80,10 @@ class OpNameMetadata : public DomainMetadata {
explicit OpNameMetadata(string opname) : opname_(std::move(opname)) {}
std::unique_ptr<DomainMetadata> Clone() const override {
- return MakeUnique<OpNameMetadata>(opname_);
+ return absl::make_unique<OpNameMetadata>(opname_);
}
- tensorflow::StringPiece Kind() const override { return KindName(); }
+ absl::string_view Kind() const override { return KindName(); }
bool Matches(const DomainMetadata& other) const override {
const OpNameMetadata* other_ptr =
@@ -97,25 +97,28 @@ class OpNameMetadata : public DomainMetadata {
string ToString() const override { return opname_; }
- static tensorflow::StringPiece KindName() { return "opname"; }
+ static absl::string_view KindName() { return "opname"; }
+
+ size_t Hash() const override { return std::hash<string>()(opname_); }
private:
string opname_;
};
// Creator function for OpNameMetadata domains.
-std::unique_ptr<HloInstruction> OpNameDomainCreator(HloInstruction* instruction,
- HloInstruction* operand) {
- if (instruction->metadata().op_name() == operand->metadata().op_name()) {
+HloInstruction* OpNameDomainCreator(HloInstruction* instruction,
+ HloInstruction* root,
+ HloInstruction* operand) {
+ if (instruction->metadata().op_name() == root->metadata().op_name()) {
return nullptr;
}
std::unique_ptr<DomainMetadata> operand_side_metadata =
- MakeUnique<OpNameMetadata>(operand->metadata().op_name());
+ absl::make_unique<OpNameMetadata>(root->metadata().op_name());
std::unique_ptr<DomainMetadata> user_side_metadata =
- MakeUnique<OpNameMetadata>(instruction->metadata().op_name());
- return HloInstruction::CreateDomain(operand->shape(), operand,
- std::move(operand_side_metadata),
- std::move(user_side_metadata));
+ absl::make_unique<OpNameMetadata>(instruction->metadata().op_name());
+ return operand->parent()->AddInstruction(HloInstruction::CreateDomain(
+ operand->shape(), operand, std::move(operand_side_metadata),
+ std::move(user_side_metadata)));
}
Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain,
@@ -142,7 +145,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
- HloDomainIsolator isolator(CreateShardingDomain);
+ HloDomainIsolator isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(isolator_changed);
@@ -184,7 +187,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
- HloDomainIsolator isolator(CreateShardingDomain);
+ HloDomainIsolator isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(!isolator_changed);
}
@@ -211,7 +214,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
- HloDomainIsolator isolator(CreateShardingDomain);
+ HloDomainIsolator isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(isolator_changed);
@@ -248,7 +251,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
- HloDomainIsolator isolator(CreateShardingDomain);
+ HloDomainIsolator isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_FALSE(isolator_changed);
}
@@ -302,7 +305,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
- HloDomainIsolator sharding_isolator(CreateShardingDomain);
+ HloDomainIsolator sharding_isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed,
sharding_isolator.Run(module));
EXPECT_TRUE(sharding_isolator_changed);
@@ -344,7 +347,8 @@ ENTRY entry {
token = token[] after-all()
infeed = ((f32[4], f32[4]), token[]) infeed(token),
sharding={{maximal device=1}, {maximal device=0}, {maximal device=0}}
- infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0
+ infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0,
+ sharding={{maximal device=1}, {maximal device=0}}
gte0 = f32[4] get-tuple-element(infeed.data), index=0
gte1 = f32[4] get-tuple-element(infeed.data), index=1
copy0 = f32[4] copy(gte0)
@@ -356,7 +360,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
- HloDomainIsolator isolator(CreateShardingDomain);
+ HloDomainIsolator isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(isolator_changed);
@@ -378,11 +382,8 @@ ENTRY entry {
// \ /
// TUPLE
// |
- HloInstruction* infeed = FindInstruction(module, "infeed");
- ASSERT_NE(infeed, nullptr);
- HloInstruction* infeed_data =
- infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
- ShapeUtil::GetTupleElementShape(infeed->shape(), 0), infeed, 0));
+ HloInstruction* infeed_data = FindInstruction(module, "infeed.data");
+ ASSERT_NE(infeed_data, nullptr);
auto infeed_data_users = infeed_data->users();
HloInstruction* new_gte0 = infeed_data->parent()->AddInstruction(
@@ -445,7 +446,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
- HloDomainIsolator isolator(CreateShardingDomain);
+ HloDomainIsolator isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(isolator_changed);
@@ -474,8 +475,8 @@ ENTRY entry {
TEST_F(HloDomainTest, DumpParseNullSharding) {
auto builder = HloComputation::Builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {});
- auto sharding_md_0 = MakeUnique<ShardingMetadata>(nullptr);
- auto sharding_md_1 = MakeUnique<ShardingMetadata>(nullptr);
+ auto sharding_md_0 = absl::make_unique<ShardingMetadata>(nullptr);
+ auto sharding_md_1 = absl::make_unique<ShardingMetadata>(nullptr);
HloInstruction* param =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p"));
HloInstruction* domain = builder.AddInstruction(HloInstruction::CreateDomain(
@@ -490,6 +491,7 @@ TEST_F(HloDomainTest, DumpParseNullSharding) {
ASSERT_TRUE(ParseModule(hlo_string).status().ok());
}
+// Tuple inputs are domain instructions.
TEST_F(HloDomainTest, DomainTuple) {
const char* const hlo_string = R"(
HloModule Module
@@ -497,14 +499,15 @@ HloModule Module
ENTRY entry {
p0 = f32[4] parameter(0), sharding={maximal device=0}
cst = u32[] constant(0), sharding={maximal device=1}
- tpl = (u32[], f32[4]) tuple(cst, p0), sharding={{maximal device=1}, {maximal device=0}}
+ tpl = (u32[], f32[4]) tuple(cst, p0),
+ sharding={{maximal device=1}, {maximal device=0}}
ROOT gte = f32[4] get-tuple-element(tpl), index=1, sharding={maximal device=0}
}
)";
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
- HloDomainIsolator isolator(CreateShardingDomain);
+ HloDomainIsolator isolator(ShardingDomainCreator{});
TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(isolator_changed);
@@ -523,5 +526,168 @@ ENTRY entry {
tpl->sharding());
}
+TEST_F(HloDomainTest, MultiDomainMultiUser) {
+ const char* const hlo_string = R"(
+ HloModule Module
+
+ENTRY %entry (p0: (f32[4], f32[4])) -> (f32[4], f32[4], f32[4]) {
+ %p0 = (f32[4], f32[4]) parameter(0)
+ %a = f32[4]{0} get-tuple-element(%p0), index=0
+ %domain = f32[4] domain(%a),
+ domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
+ %b = f32[4] get-tuple-element(%p0), index=1
+ %domain.1 = f32[4] domain(%b),
+ domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
+ %c = f32[4] add(%domain, %domain.1), sharding={maximal device=1}
+ %domain.2 = f32[4] domain(%c),
+ domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
+ %d = f32[4] subtract(%domain, %c),
+ sharding={maximal device=1}, metadata={op_name="D"}
+ %domain.3 = f32[4] domain(%d),
+ domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
+ %e = f32[4] multiply(%c, %d),
+ sharding={maximal device=1}, metadata={op_name="D"}
+ %f = f32[4] add(f32[4]{0} %e, f32[4]{0} %c), sharding={maximal device=1}
+ %domain.4 = f32[4]{0} domain(%f),
+ domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
+ ROOT %g = (f32[4], f32[4], f32[4]) tuple(%domain.2, %domain.3, %domain.4)
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
+ LOG(INFO) << "Original module:\n" << module->ToString();
+
+ HloDomainIsolator opname_isolator(OpNameDomainCreator);
+ TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed,
+ opname_isolator.Run(module));
+ EXPECT_TRUE(opname_isolator_changed);
+
+ EXPECT_TRUE(HasDomainEdge(module, "c", "a"));
+ EXPECT_TRUE(HasDomainEdge(module, "c", "b"));
+ EXPECT_TRUE(HasDomainEdge(module, "d", "a"));
+ EXPECT_TRUE(HasDomainEdge(module, "d", "c"));
+ EXPECT_FALSE(HasDomainEdge(module, "e", "d"));
+
+ HloDomainRemover sharding_remover(ShardingMetadata::KindName(),
+ ShardingMetadata::NormalizeShardingDomain);
+ TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed,
+ sharding_remover.Run(module));
+ EXPECT_TRUE(sharding_remover_changed);
+
+ HloDomainRemover opname_remover(OpNameMetadata::KindName(),
+ OpNameDomainNormalizer);
+ TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed,
+ opname_remover.Run(module));
+ EXPECT_TRUE(opname_remover_changed);
+
+ EXPECT_FALSE(HasDomainEdge(module, "c", "a"));
+ EXPECT_FALSE(HasDomainEdge(module, "c", "b"));
+ EXPECT_FALSE(HasDomainEdge(module, "d", "a"));
+ EXPECT_FALSE(HasDomainEdge(module, "d", "c"));
+}
+
+// Emulate instructions inserted at top and bottom within nested tuple domain.
+TEST_F(HloDomainTest, DomainTupleTopBottomInsert) {
+ const char* const hlo_string = R"(
+HloModule Module
+
+ENTRY entry {
+ p0 = f32[4] parameter(0), sharding={maximal device=1}
+ p1 = (f32[5], f32[6]) parameter(1),
+ sharding={{maximal device=1}, {maximal device=0}}
+ tuple.0 = (f32[4], (f32[5], f32[6])) tuple(p0, p1),
+ sharding={{maximal device=1}, {maximal device=1}, {maximal device=0}}
+ ROOT res = (f32[5], f32[6]) get-tuple-element(tuple.0), index=1,
+ sharding={{maximal device=1}, {maximal device=0}}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
+
+ HloDomainIsolator isolator(ShardingDomainCreator{});
+ TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
+ EXPECT_TRUE(isolator_changed);
+
+ // Clear sharding of tuple.0 instruction, in order to test domain sharding
+ // application.
+ auto tuple0 = FindInstruction(module, "tuple.0");
+ tuple0->clear_sharding();
+
+ // Insert the following instructions above and below tuple.0, to emulate
+ // other passes' effects:
+ // COPY.0
+ // \ /
+ // TUPLE.0
+ // / \
+ // COPY.1 \
+ // / \
+ // GTE.0 GTE.1
+ // | |
+ // | COPY.2
+ // \ /
+ // \ /
+ // TUPLE.1
+ // |
+ auto tuple0_users = tuple0->users();
+ auto computation = tuple0->parent();
+ HloInstruction* copy0 = computation->AddInstruction(
+ HloInstruction::CreateUnary(tuple0->operand(1)->shape(), HloOpcode::kCopy,
+ tuple0->mutable_operand(1)));
+ TF_EXPECT_OK(tuple0->ReplaceOperandWith(1, copy0));
+
+ HloInstruction* copy1 = computation->AddInstruction(
+ HloInstruction::CreateUnary(tuple0->shape(), HloOpcode::kCopy, tuple0));
+ HloInstruction* gte0 =
+ computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+ ShapeUtil::GetTupleElementShape(copy1->shape(), 0), copy1, 0));
+ HloInstruction* gte1 =
+ computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+ ShapeUtil::GetTupleElementShape(tuple0->shape(), 1), tuple0, 1));
+ HloInstruction* copy2 = computation->AddInstruction(
+ HloInstruction::CreateUnary(gte1->shape(), HloOpcode::kCopy, gte1));
+ HloInstruction* tuple1 =
+ computation->AddInstruction(HloInstruction::CreateTuple({gte0, copy2}));
+
+ for (HloInstruction* user : tuple0_users) {
+ TF_EXPECT_OK(tuple0->ReplaceUseWith(user, tuple1));
+ }
+
+ HloDomainRemover remover(ShardingMetadata::KindName(),
+ ShardingMetadata::NormalizeShardingDomain);
+ TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module));
+ EXPECT_TRUE(remover_changed);
+
+ EXPECT_TRUE(tuple0->has_sharding());
+ EXPECT_EQ(HloSharding::Tuple(tuple0->shape(), {HloSharding::AssignDevice(1),
+ HloSharding::AssignDevice(1),
+ HloSharding::AssignDevice(0)}),
+ tuple0->sharding());
+
+ EXPECT_TRUE(copy0->has_sharding());
+ EXPECT_EQ(HloSharding::Tuple(copy0->shape(), {HloSharding::AssignDevice(1),
+ HloSharding::AssignDevice(0)}),
+ copy0->sharding());
+
+ // copy1 only has partial sharding information (from gte.0), so in the end
+ // it gets no sharding at all. During propagation, though, it still forwards
+ // the information from gte.0, enabling tuple.0 to be fully sharded.
+ EXPECT_FALSE(copy1->has_sharding());
+
+ EXPECT_TRUE(gte0->has_sharding());
+ EXPECT_EQ(HloSharding::AssignDevice(1), gte0->sharding());
+
+ EXPECT_TRUE(gte1->has_sharding());
+ EXPECT_EQ(HloSharding::Tuple(gte1->shape(), {HloSharding::AssignDevice(1),
+ HloSharding::AssignDevice(0)}),
+ gte1->sharding());
+
+ EXPECT_TRUE(copy2->has_sharding());
+ EXPECT_EQ(HloSharding::Tuple(copy2->shape(), {HloSharding::AssignDevice(1),
+ HloSharding::AssignDevice(0)}),
+ copy2->sharding());
+
+ EXPECT_TRUE(tuple1->has_sharding());
+ EXPECT_EQ(tuple0->sharding(), tuple1->sharding());
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.cc b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc
index 751fc677e2..dc514ae3e5 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc
@@ -52,7 +52,7 @@ Status HloDomainVerifier::RunContext::PopulateDomainKinds() {
TF_RET_CHECK(instruction->user_side_metadata().Kind() ==
instruction->operand_side_metadata().Kind())
<< instruction->ToString();
- kinds.insert(instruction->user_side_metadata().Kind().ToString());
+ kinds.insert(string(instruction->user_side_metadata().Kind()));
}
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.h b/tensorflow/compiler/xla/service/hlo_domain_verifier.h
index 8e53cf97f8..81d6d69a8c 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.h
@@ -33,7 +33,7 @@ class HloDomainVerifier : public HloPassInterface {
public:
HloDomainVerifier(std::vector<string> kinds) : kinds_(std::move(kinds)) {}
- tensorflow::StringPiece name() const override { return "domain_verifier"; }
+ absl::string_view name() const override { return "domain_verifier"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
index b9244b8e9e..72006e17e7 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
@@ -151,7 +151,11 @@ StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
}
TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString();
- if (!HasOperandType(hlo, eliminate_type_)) {
+ bool nullary = hlo->operands().empty();
+ bool wrong_element_type = hlo->shape().element_type() == eliminate_type_;
+ bool should_eliminate_type = (nullary && wrong_element_type) ||
+ HasOperandType(hlo, eliminate_type_);
+ if (!should_eliminate_type) {
// If this CHECK fires, then this was an instruction that does not take
// the elimination type as an operand but it does return it. This pass
// does not have a feature to change the output type in that case, so
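The new guard also routes nullary instructions (e.g. a constant, which has no operands) whose result type is the eliminated type into the conversion path, where HasOperandType alone would have skipped them. The decision logic in isolation, as a self-contained sketch:

  #include <vector>

  enum class Type { kBF16, kF32 };

  // Mirrors the new guard: convert when a nullary op produces the eliminated
  // type, or when any operand carries it.
  bool ShouldEliminateType(Type result, const std::vector<Type>& operands,
                           Type eliminate) {
    bool nullary = operands.empty();
    bool wrong_element_type = result == eliminate;
    bool has_operand_of_type = false;
    for (Type t : operands) has_operand_of_type |= (t == eliminate);
    return (nullary && wrong_element_type) || has_operand_of_type;
  }

  int main() {
    // A BF16 constant (nullary) is now picked up for conversion.
    return ShouldEliminateType(Type::kBF16, {}, Type::kBF16) ? 0 : 1;
  }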
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h
index 2b109225d0..44ded2c2fa 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter.h
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h
@@ -32,9 +32,7 @@ class HloElementTypeConverter : public HloPassInterface {
HloElementTypeConverter(PrimitiveType eliminate_type,
PrimitiveType replace_with_type);
- tensorflow::StringPiece name() const override {
- return "element_type_converter";
- }
+ absl::string_view name() const override { return "element_type_converter"; }
// Runs the pass on the module and returns whether the module was modified.
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 51353eea6e..064b86493d 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -23,13 +23,15 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -43,7 +45,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
@@ -52,12 +53,9 @@ namespace xla {
namespace {
-using tensorflow::gtl::ArraySlice;
-
template <typename OperandT>
-StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
- LiteralSlice lhs_literal,
- LiteralSlice rhs_literal) {
+StatusOr<Literal> Compare(const Shape& shape, HloOpcode opcode,
+ LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
std::function<bool(OperandT, OperandT)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
@@ -95,19 +93,20 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
<< HloOpcodeString(opcode);
}
- auto result = MakeUnique<Literal>(shape);
- TF_RETURN_IF_ERROR(result->Populate<bool>([&](ArraySlice<int64> multi_index) {
- return compare_op(lhs_literal.Get<OperandT>(multi_index),
- rhs_literal.Get<OperandT>(multi_index));
- }));
+ Literal result(shape);
+ TF_RETURN_IF_ERROR(
+ result.Populate<bool>([&](absl::Span<const int64> multi_index) {
+ return compare_op(lhs_literal.Get<OperandT>(multi_index),
+ rhs_literal.Get<OperandT>(multi_index));
+ }));
return std::move(result);
}
template <>
-StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
- const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal,
- LiteralSlice rhs_literal) {
+StatusOr<Literal> Compare<complex64>(const Shape& shape, HloOpcode opcode,
+ LiteralSlice lhs_literal,
+ LiteralSlice rhs_literal) {
std::function<bool(complex64, complex64)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
@@ -125,11 +124,12 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
<< HloOpcodeString(opcode);
}
- auto result = MakeUnique<Literal>(shape);
- TF_RETURN_IF_ERROR(result->Populate<bool>([&](ArraySlice<int64> multi_index) {
- return compare_op(lhs_literal.Get<complex64>(multi_index),
- rhs_literal.Get<complex64>(multi_index));
- }));
+ Literal result(shape);
+ TF_RETURN_IF_ERROR(
+ result.Populate<bool>([&](absl::Span<const int64> multi_index) {
+ return compare_op(lhs_literal.Get<complex64>(multi_index),
+ rhs_literal.Get<complex64>(multi_index));
+ }));
return std::move(result);
}
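Both Compare overloads now follow the same by-value pattern: construct a Literal locally, populate it, and move it into the returned StatusOr. Literal is move-only, so the explicit std::move hands the buffer over without a deep copy and without the old unique_ptr indirection. A stripped-down sketch of the pattern (MakeFilled is a hypothetical helper, not part of the evaluator):

    StatusOr<Literal> MakeFilled(const Shape& shape, bool value) {
      Literal result(shape);
      TF_RETURN_IF_ERROR(
          result.Populate<bool>([&](absl::Span<const int64> /*multi_index*/) {
            return value;  // every element gets the same value
          }));
      return std::move(result);  // moved into the StatusOr, no copy
    }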
@@ -138,49 +138,62 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
HloEvaluator::HloEvaluator(int64 max_loop_iterations)
: max_loop_iterations_(max_loop_iterations) {
- typed_visitors_[PRED] = MakeUnique<HloEvaluatorTypedVisitor<bool>>(this);
- typed_visitors_[U8] = MakeUnique<HloEvaluatorTypedVisitor<uint8>>(this);
- typed_visitors_[U16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
- return Unimplemented(
- "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: "
- "U16.");
- });
- typed_visitors_[U32] = MakeUnique<HloEvaluatorTypedVisitor<uint32>>(this);
- typed_visitors_[U64] = MakeUnique<HloEvaluatorTypedVisitor<uint64>>(this);
- typed_visitors_[S8] = MakeUnique<HloEvaluatorTypedVisitor<int8>>(this);
- typed_visitors_[S16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
- return Unimplemented(
- "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: "
- "S16.");
- });
- typed_visitors_[S32] = MakeUnique<HloEvaluatorTypedVisitor<int32>>(this);
- typed_visitors_[S64] = MakeUnique<HloEvaluatorTypedVisitor<int64>>(this);
+ typed_visitors_[PRED] =
+ absl::make_unique<HloEvaluatorTypedVisitor<bool>>(this);
+ typed_visitors_[U8] =
+ absl::make_unique<HloEvaluatorTypedVisitor<uint8>>(this);
+ typed_visitors_[U16] =
+ absl::make_unique<FunctionVisitor>([](HloInstruction*) {
+ return Unimplemented(
+ "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: "
+ "U16.");
+ });
+ typed_visitors_[U32] =
+ absl::make_unique<HloEvaluatorTypedVisitor<uint32>>(this);
+ typed_visitors_[U64] =
+ absl::make_unique<HloEvaluatorTypedVisitor<uint64>>(this);
+ typed_visitors_[S8] = absl::make_unique<HloEvaluatorTypedVisitor<int8>>(this);
+ typed_visitors_[S16] =
+ absl::make_unique<FunctionVisitor>([](HloInstruction*) {
+ return Unimplemented(
+ "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: "
+ "S16.");
+ });
+ typed_visitors_[S32] =
+ absl::make_unique<HloEvaluatorTypedVisitor<int32>>(this);
+ typed_visitors_[S64] =
+ absl::make_unique<HloEvaluatorTypedVisitor<int64>>(this);
typed_visitors_[F16] =
- MakeUnique<HloEvaluatorTypedVisitor<Eigen::half, float>>(this);
- typed_visitors_[F32] = MakeUnique<HloEvaluatorTypedVisitor<float>>(this);
- typed_visitors_[F64] = MakeUnique<HloEvaluatorTypedVisitor<double>>(this);
- typed_visitors_[C64] = MakeUnique<HloEvaluatorTypedVisitor<complex64>>(this);
+ absl::make_unique<HloEvaluatorTypedVisitor<Eigen::half, float>>(this);
+ typed_visitors_[F32] =
+ absl::make_unique<HloEvaluatorTypedVisitor<float>>(this);
+ typed_visitors_[F64] =
+ absl::make_unique<HloEvaluatorTypedVisitor<double>>(this);
+ typed_visitors_[C64] =
+ absl::make_unique<HloEvaluatorTypedVisitor<complex64>>(this);
// Most of the evaluator computations we use don't support BF16 (e.g.,
  // std::ceil, std::tanh). To make the evaluator work with BF16, we set all
// elementwise computations to be done in F32 and do BF16<->F32 conversion
// around the input and the output of the computations.
typed_visitors_[BF16] =
- MakeUnique<HloEvaluatorTypedVisitor<bfloat16, float>>(this);
-
- typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
- return Unimplemented(
- "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE.");
- });
- typed_visitors_[OPAQUE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
- return Unimplemented(
- "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE.");
- });
+ absl::make_unique<HloEvaluatorTypedVisitor<bfloat16, float>>(this);
+
+ typed_visitors_[TUPLE] =
+ absl::make_unique<FunctionVisitor>([](HloInstruction*) {
+ return Unimplemented(
+ "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE.");
+ });
+ typed_visitors_[OPAQUE] =
+ absl::make_unique<FunctionVisitor>([](HloInstruction*) {
+ return Unimplemented(
+ "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE.");
+ });
}
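Beyond the MakeUnique to absl::make_unique rename, the constructor is best read as a dispatch table: one typed visitor per primitive type, with FunctionVisitor stubs that surface Unimplemented for types the evaluator cannot handle. A stripped-down sketch of that table shape (placeholder types and enum keys, not the real visitor classes):

    #include <map>
    #include <memory>
    #include "absl/memory/memory.h"

    struct VisitorBase {
      virtual ~VisitorBase() = default;
    };
    template <typename NativeT>
    struct TypedVisitor : VisitorBase {};

    // Hypothetical integer keys; the real table indexes by xla::PrimitiveType.
    std::map<int, std::unique_ptr<VisitorBase>> BuildTable() {
      std::map<int, std::unique_ptr<VisitorBase>> table;
      table[0] = absl::make_unique<TypedVisitor<bool>>();   // e.g. PRED
      table[1] = absl::make_unique<TypedVisitor<float>>();  // e.g. F32
      return table;
    }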
template <typename LiteralPtr>
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
- const HloModule& module, ArraySlice<LiteralPtr> arg_literals) {
+StatusOr<Literal> HloEvaluator::Evaluate(
+ const HloModule& module, absl::Span<const LiteralPtr> arg_literals) {
XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString());
evaluated_.clear();
@@ -192,12 +205,23 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this));
return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction())
- .CloneToUnique();
+ .Clone();
+}
+
+template <>
+StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
+ const HloModule& module, absl::Span<const Literal> arg_literals) {
+ std::vector<const Literal*> arg_literal_ptrs;
+ for (const auto& literal_ptr : arg_literals) {
+ arg_literal_ptrs.push_back(&literal_ptr);
+ }
+ return Evaluate<const Literal*>(module, arg_literal_ptrs);
}
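The full specialization above bridges the two literal-passing styles: a span of Literal values is adapted to the pointer-based primary template by collecting element addresses. The same adapter shape in isolation (generic names, not the evaluator's signatures):

    #include <vector>
    #include "absl/types/span.h"

    template <typename T>
    int CountViaPointers(absl::Span<const T* const> items) {
      return static_cast<int>(items.size());
    }

    template <typename T>
    int CountValues(absl::Span<const T> items) {
      // Adapt values to the pointer form; the pointers borrow from the span's
      // backing storage and must not outlive it.
      std::vector<const T*> ptrs;
      ptrs.reserve(items.size());
      for (const T& item : items) ptrs.push_back(&item);
      return CountViaPointers<T>(ptrs);
    }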
template <typename LiteralPtr>
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
- const HloComputation& computation, ArraySlice<LiteralPtr> arg_literals) {
+StatusOr<Literal> HloEvaluator::Evaluate(
+ const HloComputation& computation,
+ absl::Span<const LiteralPtr> arg_literals) {
CHECK(computation.parent() != nullptr);
XLA_VLOG_LINES(
2, "HloEvaluator::Evaluate computation:\n" + computation.ToString());
@@ -209,14 +233,23 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
}
TF_RETURN_IF_ERROR(computation.Accept(this));
- return GetEvaluatedLiteralFor(computation.root_instruction()).CloneToUnique();
+ return GetEvaluatedLiteralFor(computation.root_instruction()).Clone();
+}
+
+template <>
+StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
+ const HloComputation& computation, absl::Span<const Literal> arg_literals) {
+ std::vector<const Literal*> arg_literal_ptrs;
+ for (const auto& literal_ptr : arg_literals) {
+ arg_literal_ptrs.push_back(&literal_ptr);
+ }
+ return Evaluate<const Literal*>(computation, arg_literal_ptrs);
}
template <typename LiteralPtr>
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
- HloInstruction* instruction, ArraySlice<LiteralPtr> arg_literals) {
+StatusOr<Literal> HloEvaluator::Evaluate(
+ HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals) {
TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
- TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
evaluated_.clear();
arg_literals_.clear();
@@ -233,18 +266,27 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
<< input_literal->ToString();
TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape()));
- evaluated_[operand] = input_literal->CloneToUnique();
+ evaluated_[operand] = input_literal->Clone();
}
}
TF_RETURN_IF_ERROR(Preprocess(instruction));
TF_RETURN_IF_ERROR(instruction->Visit(this));
TF_RETURN_IF_ERROR(Postprocess(instruction));
- return GetEvaluatedLiteralFor(instruction).CloneToUnique();
+ return GetEvaluatedLiteralFor(instruction).Clone();
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
- HloInstruction* instruction) {
+template <>
+StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
+ HloInstruction* instruction, absl::Span<const Literal> arg_literals) {
+ std::vector<const Literal*> arg_literal_ptrs;
+ for (const auto& literal : arg_literals) {
+ arg_literal_ptrs.push_back(&literal);
+ }
+ return Evaluate<const Literal*>(instruction, arg_literal_ptrs);
+}
+
+StatusOr<Literal> HloEvaluator::Evaluate(HloInstruction* instruction) {
if (instruction->opcode() == HloOpcode::kParameter) {
return tensorflow::errors::FailedPrecondition(
"Cannot evaluate a parameter.");
@@ -253,7 +295,6 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
return tensorflow::errors::FailedPrecondition(
"Not all operands are constants.");
}
- TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
arg_literals_.clear();
evaluated_.clear();
@@ -261,21 +302,22 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
TF_RETURN_IF_ERROR(Preprocess(instruction));
TF_RETURN_IF_ERROR(instruction->Visit(this));
TF_RETURN_IF_ERROR(Postprocess(instruction));
- return GetEvaluatedLiteralFor(instruction).CloneToUnique();
+ return GetEvaluatedLiteralFor(instruction).Clone();
}
-std::unique_ptr<Literal> HloEvaluator::TryEvaluate(
- HloInstruction* instruction) {
+bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) {
+ CHECK(result != nullptr);
auto result_or = Evaluate(instruction);
if (!result_or.ok()) {
VLOG(1) << "TryEvaluate failed:" << result_or.status();
- return nullptr;
+ return false;
}
- return result_or.ConsumeValueOrDie();
+ *result = result_or.ConsumeValueOrDie();
+ return true;
}
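Call sites adapt to the new TryEvaluate contract by supplying the output Literal up front; the boolean says whether it was filled. A hedged usage sketch (evaluator and instruction stand in for whatever the caller holds):

    Literal constant_value;
    if (evaluator.TryEvaluate(instruction, &constant_value)) {
      // Success: constant_value now holds the evaluated result.
      VLOG(2) << "Folded to " << constant_value.ToString();
    } else {
      // Failure: constant_value is untouched; fall back to the slow path.
    }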
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
+StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
const HloInstruction* instruction,
const std::unordered_map<const HloInstruction*, const Literal*>&
substitutions) {
@@ -286,7 +328,7 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
owned_operands.push_back(operand->Clone());
} else {
owned_operands.push_back(
- HloInstruction::CreateConstant(it->second->CloneToUnique()));
+ HloInstruction::CreateConstant(it->second->Clone()));
}
}
@@ -303,12 +345,12 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
return result;
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp(
+StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
std::unique_ptr<HloInstruction> lhs_instr =
- HloInstruction::CreateConstant(lhs.CloneToUnique());
+ HloInstruction::CreateConstant(lhs.Clone());
std::unique_ptr<HloInstruction> rhs_instr =
- HloInstruction::CreateConstant(rhs.CloneToUnique());
+ HloInstruction::CreateConstant(rhs.Clone());
std::unique_ptr<HloInstruction> cloned_instruction =
HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(),
@@ -318,10 +360,10 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp(
return result;
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
+StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
HloOpcode opcode, const Literal& operand) {
std::unique_ptr<HloInstruction> operand_instr =
- HloInstruction::CreateConstant(operand.CloneToUnique());
+ HloInstruction::CreateConstant(operand.Clone());
std::unique_ptr<HloInstruction> cloned_instruction =
HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get());
@@ -330,13 +372,14 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
return result;
}
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp(
- const DotDimensionNumbers& dim_numbers, const Literal& lhs,
+StatusOr<Literal> HloEvaluator::EvaluateDotOp(
+ const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfig& precision_config, const Literal& lhs,
const Literal& rhs) {
std::unique_ptr<HloInstruction> lhs_instr =
- HloInstruction::CreateConstant(lhs.CloneToUnique());
+ HloInstruction::CreateConstant(lhs.Clone());
std::unique_ptr<HloInstruction> rhs_instr =
- HloInstruction::CreateConstant(rhs.CloneToUnique());
+ HloInstruction::CreateConstant(rhs.Clone());
TF_ASSIGN_OR_RETURN(
Shape dot_shape,
@@ -344,7 +387,7 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp(
std::unique_ptr<HloInstruction> cloned_instruction =
HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),
- dim_numbers);
+ dim_numbers, precision_config);
return Evaluate(cloned_instruction.get());
}
@@ -357,7 +400,7 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
<< ", but input literal shape is: "
<< ShapeUtil::HumanString(input_literal->shape());
- evaluated_[parameter] = input_literal->CloneToUnique();
+ evaluated_[parameter] = input_literal->Clone();
return Status::OK();
}
@@ -378,7 +421,7 @@ Status HloEvaluator::HandleTranspose(HloInstruction* transpose) {
}
Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
- ArraySlice<HloInstruction*> operands(concatenate->operands());
+ absl::Span<HloInstruction* const> operands(concatenate->operands());
// The result concatenate dimension is going to be the sum of all
// concatenate dimensions of the operands taking part of the operation.
const Shape& reference_shape = operands[0]->shape();
@@ -407,7 +450,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
for (auto operand : operands) {
const Shape& operand_shape = operand->shape();
- TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
+ TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
AsInt64Slice(operand_shape.dimensions())));
dest_indices[concat_dim] +=
@@ -423,7 +466,7 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) {
if (!ShapeUtil::ElementIsFloating(operand->shape())) {
return InvalidArgument(
"expected element type in shape to be float for IsFinite op, got: %s",
- PrimitiveType_Name(operand->shape().element_type()).c_str());
+ PrimitiveType_Name(operand->shape().element_type()));
}
switch (operand->shape().element_type()) {
@@ -464,9 +507,9 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) {
return Unimplemented(
"Implicit broadcasting is currently unsupported in HLO evaluator "
"Shape Mismatch: %s vs %s vs %s",
- ShapeUtil::HumanString(compare->shape()).c_str(),
- ShapeUtil::HumanString(lhs->shape()).c_str(),
- ShapeUtil::HumanString(rhs->shape()).c_str());
+ ShapeUtil::HumanString(compare->shape()),
+ ShapeUtil::HumanString(lhs->shape()),
+ ShapeUtil::HumanString(rhs->shape()));
}
TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type());
@@ -555,43 +598,41 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
return Status::OK();
}
-// Returns an ShapeUtil::IndexIterationSpace that iterates over the output
-// gather dimensions while keeping the rest of the output dimensions clamped to
-// 0.
-ShapeUtil::IndexIterationSpace IterationSpaceForOutputGatherIndices(
+// Returns a ShapeUtil::IndexIterationSpace that iterates over the output batch
+// dimensions while keeping the rest of the output dimensions clamped to 0.
+ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices(
const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) {
int64 output_rank = output_shape.dimensions_size();
std::vector<int64> index_base(output_rank, 0);
std::vector<int64> index_count;
index_count.reserve(output_rank);
for (int64 i = 0; i < output_rank; i++) {
- bool is_output_gather_dim =
- !c_binary_search(dim_numbers.output_window_dims(), i);
- index_count.push_back(is_output_gather_dim ? output_shape.dimensions(i)
- : 1);
+ bool is_output_batch_dim =
+ !absl::c_binary_search(dim_numbers.offset_dims(), i);
+ index_count.push_back(is_output_batch_dim ? output_shape.dimensions(i) : 1);
}
return {std::move(index_base), std::move(index_count),
std::vector<int64>(output_rank, 1)};
}
-// Return an ShapeUtil::IndexIterationSpace that iterates over the output window
+// Returns a ShapeUtil::IndexIterationSpace that iterates over the output slice
// dimensions while keeping the rest of the output dimensions clamped to 0.
-ShapeUtil::IndexIterationSpace IterationSpaceForOutputWindowIndices(
- int64 output_rank, ArraySlice<int64> window_bounds,
+ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices(
+ int64 output_rank, absl::Span<const int64> slice_sizes,
const GatherDimensionNumbers& dim_numbers) {
std::vector<int64> index_base(output_rank, 0);
std::vector<int64> index_count(output_rank, 1);
- int64 window_bounds_idx = 0;
+ int64 slice_sizes_idx = 0;
for (int64 i = 0; i < output_rank; i++) {
bool is_output_window_dim =
- c_binary_search(dim_numbers.output_window_dims(), i);
+ absl::c_binary_search(dim_numbers.offset_dims(), i);
if (is_output_window_dim) {
- while (c_binary_search(dim_numbers.elided_window_dims(),
- window_bounds_idx)) {
- window_bounds_idx++;
+ while (absl::c_binary_search(dim_numbers.collapsed_slice_dims(),
+ slice_sizes_idx)) {
+ slice_sizes_idx++;
}
- index_count[i] = window_bounds[window_bounds_idx++];
+ index_count[i] = slice_sizes[slice_sizes_idx++];
}
}
@@ -599,30 +640,30 @@ ShapeUtil::IndexIterationSpace IterationSpaceForOutputWindowIndices(
std::vector<int64>(output_rank, 1)};
}
-// This functor computes the contribution of gather_indices to an input index
+// This functor computes the contribution of start_indices to an input index
// corresponding to an output index. That is, given an output index I, it picks
-// out the gather output indices in I and uses them to look up a gather index,
-// G, from the gather indices tensor, and expands G into the input space
-// according to gather_dims_to_operand_dims.
-class OutputGatherIndexToInputIndex {
+// out the batch indices in I and uses them to look up a starting index, G, from
+// the start indices tensor, and expands G into the input space according to
+// start_index_map.
+class OutputBatchIndexToInputIndex {
public:
// The constructor does some setup work that is amortized across all
// iterations.
- explicit OutputGatherIndexToInputIndex(
+ explicit OutputBatchIndexToInputIndex(
const GatherDimensionNumbers* dim_numbers, const Shape& input_shape,
- const Shape& output_shape, const Literal* gather_indices)
- : dim_numbers_(*dim_numbers), gather_indices_(*gather_indices) {
+ const Shape& output_shape, const Literal* start_indices)
+ : dim_numbers_(*dim_numbers), start_indices_(*start_indices) {
for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
- output_dim_is_gather_dims_.push_back(
- !c_binary_search(dim_numbers_.output_window_dims(), i));
+ output_dim_is_batch_dims_.push_back(
+ !absl::c_binary_search(dim_numbers_.offset_dims(), i));
}
for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
int64 index_of_input_dim_in_index_vector =
- std::distance(dim_numbers_.gather_dims_to_operand_dims().begin(),
- c_find(dim_numbers_.gather_dims_to_operand_dims(), i));
+ std::distance(dim_numbers_.start_index_map().begin(),
+ absl::c_find(dim_numbers_.start_index_map(), i));
if (index_of_input_dim_in_index_vector ==
- dim_numbers_.gather_dims_to_operand_dims_size()) {
+ dim_numbers_.start_index_map_size()) {
input_dim_value_to_index_vector_.push_back(-1);
} else {
input_dim_value_to_index_vector_.push_back(
@@ -630,14 +671,14 @@ class OutputGatherIndexToInputIndex {
}
}
- index_vector_index_.resize(gather_indices_.shape().dimensions_size());
+ index_vector_index_.resize(start_indices_.shape().dimensions_size());
input_index_.resize(input_shape.dimensions_size());
int64 index_vector_size =
- gather_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
+ start_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
index_vector_.resize(index_vector_size);
}
- // Returns the contribution of gather_indices to the input index corresponding
+ // Returns the contribution of start_indices to the input index corresponding
// to output_index. See gather_inner_loop_body.
//
// This is conceptually a stateless transformation from output_index to the
@@ -650,24 +691,25 @@ class OutputGatherIndexToInputIndex {
// index_vector_index_ and index_vector on every invocation, we reuse the
// same storage for all invocations.
//
- // This returns an arrayslice into memory owned by the class.
- StatusOr<ArraySlice<int64>> operator()(ArraySlice<int64> output_index) {
+ // This returns a Span into memory owned by the class.
+ StatusOr<absl::Span<const int64>> operator()(
+ absl::Span<const int64> output_index) {
PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index);
TF_RETURN_IF_ERROR(FetchIndexVector());
PropagateIndexVectorToInputIndex();
- return ArraySlice<int64>(input_index_);
+ return absl::Span<const int64>(input_index_);
}
private:
- // Propagates the gather index dimensions from the output index into
+ // Propagates the batch dimensions from the output index into
// index_vector_index_ by mutating index_vector_index_ in place. Does not
// update the dim_numbers.index_vector_dim() dimension -- that's the dimension
// we iterate over in FetchIndexVector.
void PropagateOutputIndexGatherDimsToIndexVectorIndex(
- ArraySlice<int64> output_index) {
+ absl::Span<const int64> output_index) {
int64 index_vector_index_i = 0;
for (int64 i = 0, e = output_index.size(); i < e; i++) {
- if (!output_dim_is_gather_dims_[i]) {
+ if (!output_dim_is_batch_dims_[i]) {
continue;
}
@@ -679,14 +721,14 @@ class OutputGatherIndexToInputIndex {
}
}
- // Populates index_vector_ by iterating over gather_indices_ according to
+ // Populates index_vector_ by iterating over start_indices_ according to
// index_vector_index_.
Status FetchIndexVector() {
int64 index_vector_dim = dim_numbers_.index_vector_dim();
for (int64 i = 0, e = index_vector_.size(); i < e; i++) {
index_vector_index_[index_vector_dim] = i;
- TF_ASSIGN_OR_RETURN(index_vector_[i], gather_indices_.GetIntegralAsS64(
- index_vector_index_));
+ TF_ASSIGN_OR_RETURN(index_vector_[i],
+ start_indices_.GetIntegralAsS64(index_vector_index_));
}
return Status::OK();
}
@@ -708,40 +750,39 @@ class OutputGatherIndexToInputIndex {
// PropagateIndexVectorToInputIndex.
std::vector<int64> input_dim_value_to_index_vector_;
- // output_dim_is_gather_dims_[i] is true iff the output index i is a gather
+ // output_dim_is_batch_dims_[i] is true iff the output index i is a batch
// dimension.
- std::vector<bool> output_dim_is_gather_dims_;
+ std::vector<bool> output_dim_is_batch_dims_;
- // The buffer into which we construct an index into gather_indices_ to fetch
+ // The buffer into which we construct an index into start_indices_ to fetch
// the index vector.
std::vector<int64> index_vector_index_;
- // The index vector fetched from gather_indices_.
+ // The index vector fetched from start_indices_.
std::vector<int64> index_vector_;
- // The result computed by this functor. operator() returns an ArraySlice into
+ // The result computed by this functor. operator() returns a Span into
// this vector.
std::vector<int64> input_index_;
const GatherDimensionNumbers& dim_numbers_;
- const Literal& gather_indices_;
+ const Literal& start_indices_;
};
-// This functor computes the contribution of the window indices in an output
+// This functor computes the contribution of the offset indices in an output
// index to an input index. That is, given an output index I it picks out the
-// output window indices in I and expands it into a window index into the input
-// shape.
-class OutputWindowIndexToInputIndex {
+// output offset indices in I and expands them into an index into the input shape.
+class OutputOffsetIndexToInputIndex {
public:
// The constructor does some setup work that is amortized across all
// iterations.
- explicit OutputWindowIndexToInputIndex(
+ explicit OutputOffsetIndexToInputIndex(
const GatherDimensionNumbers& dim_numbers, const Shape& input_shape,
const Shape& output_shape) {
std::vector<int64> window_index_to_output_index;
int64 output_index_count = 0;
for (int64 i = 0; i < output_shape.dimensions_size(); i++) {
- if (c_binary_search(dim_numbers.output_window_dims(), i)) {
+ if (absl::c_binary_search(dim_numbers.offset_dims(), i)) {
window_index_to_output_index.push_back(output_index_count++);
} else {
output_index_count++;
@@ -750,7 +791,7 @@ class OutputWindowIndexToInputIndex {
int64 window_dim_count = 0;
for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
- if (c_binary_search(dim_numbers.elided_window_dims(), i)) {
+ if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
input_dim_value_to_output_index_.push_back(-1);
} else {
input_dim_value_to_output_index_.push_back(
@@ -769,10 +810,11 @@ class OutputWindowIndexToInputIndex {
// gather input index on every invocation we reuse the same storage for the
// result (input_index_), mutating it in place.
//
- // This returns an arrayslice into memory owned by the class.
- StatusOr<ArraySlice<int64>> operator()(ArraySlice<int64> output_index) {
+ // This returns a Span into memory owned by the class.
+ StatusOr<absl::Span<const int64>> operator()(
+ absl::Span<const int64> output_index) {
PropagateOutputIndexWindowDimsToInputIndex(output_index);
- return ArraySlice<int64>(input_index_);
+ return absl::Span<const int64>(input_index_);
}
// Returns for a given 'input_dim' the corresponding output dimension index,
@@ -785,7 +827,7 @@ class OutputWindowIndexToInputIndex {
// Propagates window dimensions from the output index to input_index_ by
// mutating input_index_ in place.
void PropagateOutputIndexWindowDimsToInputIndex(
- ArraySlice<int64> output_index) {
+ absl::Span<const int64> output_index) {
for (int64 i = 0, e = input_index_.size(); i < e; i++) {
if (input_dim_value_to_output_index_[i] != -1) {
input_index_[i] = output_index[input_dim_value_to_output_index_[i]];
@@ -801,119 +843,117 @@ class OutputWindowIndexToInputIndex {
// PropagateOutputIndexWindowDimsToInputIndex.
std::vector<int64> input_dim_value_to_output_index_;
- // The result computed by this functor. operator() returns an ArraySlice into
+ // The result computed by this functor. operator() returns a Span into
// this vector.
std::vector<int64> input_index_;
};
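Taken together, the two functors decompose every gather output index: the batch coordinates select a start index, clamped so the slice stays in bounds, and the offset coordinates walk within the slice. Per input dimension the arithmetic reduces to clamp-then-add; a scalar sketch with illustrative names (int64 as in the surrounding code, <algorithm> assumed):

    // input_index = clamp(start, 0, operand_dim - slice_dim) + offset
    int64 ComposeInputIndex(int64 start, int64 offset, int64 operand_dim,
                            int64 slice_dim) {
      int64 clamped =
          std::min(operand_dim - slice_dim, std::max(int64{0}, start));
      return clamped + offset;
    }

For example, with operand_dim = 4 and slice_dim = 3, a start index of 5 clamps to 1, so offsets 0..2 read elements 1..3 instead of running off the end.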
// Reshapes the gather indices input to have a trailing degenerate `1` dimension
// if necessary. Hands over the ownership of the newly created literal (if
-// there is one) to `reshaped_gather_indices`.
+// there is one) to `reshaped_start_indices`.
static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
- int64 index_vector_dim, const Literal& gather_indices,
- std::unique_ptr<Literal>* reshaped_gather_indices) {
- if (gather_indices.shape().dimensions_size() != index_vector_dim) {
- return std::cref(gather_indices);
+ int64 index_vector_dim, const Literal& start_indices,
+ Literal* reshaped_start_indices) {
+ if (start_indices.shape().dimensions_size() != index_vector_dim) {
+ return std::cref(start_indices);
}
- std::vector<int64> new_shape(gather_indices.shape().dimensions().begin(),
- gather_indices.shape().dimensions().end());
+ std::vector<int64> new_shape(start_indices.shape().dimensions().begin(),
+ start_indices.shape().dimensions().end());
new_shape.push_back(1);
- TF_ASSIGN_OR_RETURN(*reshaped_gather_indices,
- gather_indices.Reshape(new_shape));
- return std::cref(**reshaped_gather_indices);
+ TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
+ start_indices.Reshape(new_shape));
+ return std::cref(*reshaped_start_indices);
}
Status HloEvaluator::HandleGather(HloInstruction* gather) {
- std::unique_ptr<Literal> result = Literal::CreateFromShape(gather->shape());
+ Literal result = Literal::CreateFromShape(gather->shape());
const Shape& shape = gather->shape();
const GatherDimensionNumbers& dim_numbers =
gather->gather_dimension_numbers();
const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
- std::unique_ptr<Literal> reshaped_gather_indices;
+ Literal reshaped_start_indices;
TF_ASSIGN_OR_RETURN(
- const Literal& gather_indices,
+ const Literal& start_indices,
ReshapedGatherIndices(dim_numbers.index_vector_dim(),
GetEvaluatedLiteralFor(gather->operand(1)),
- &reshaped_gather_indices));
+ &reshaped_start_indices));
// We iterate over the gather dimensions in the output shape in an outer loop
// nest, and iterate over the window dimensions in the output shape in an
// inner loop nest.
- ShapeUtil::IndexIterationSpace gather_indices_iteration_space =
- IterationSpaceForOutputGatherIndices(shape, dim_numbers);
- ShapeUtil::IndexIterationSpace window_indices_iteration_space =
- IterationSpaceForOutputWindowIndices(
- shape.dimensions_size(), gather->gather_window_bounds(), dim_numbers);
+ ShapeUtil::IndexIterationSpace start_indices_iteration_space =
+ IterationSpaceForOutputBatchIndices(shape, dim_numbers);
+ ShapeUtil::IndexIterationSpace offset_indices_iteration_space =
+ IterationSpaceForOutputOffsetIndices(
+ shape.dimensions_size(), gather->gather_slice_sizes(), dim_numbers);
// Scratch buffers that hold an index in the output shape and the
// corresponding index in the input shape.
std::vector<int64> input_index(operand.shape().dimensions_size());
std::vector<int64> output_index(gather->shape().dimensions_size());
- std::vector<int64> input_gather_index_clamped(
- operand.shape().dimensions_size());
+ std::vector<int64> input_index_clamped(operand.shape().dimensions_size());
- OutputGatherIndexToInputIndex output_gather_index_to_input_index(
+ OutputBatchIndexToInputIndex output_batch_index_to_input_index(
&gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
- /*output_shape=*/shape, &gather_indices);
- OutputWindowIndexToInputIndex output_window_index_to_input_index(
+ /*output_shape=*/shape, &start_indices);
+ OutputOffsetIndexToInputIndex output_offset_index_to_input_index(
gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
/*output_shape=*/shape);
const Shape& operand_shape = operand.shape();
auto gather_inner_loop_body =
- [&](ArraySlice<int64> output_window_index,
- ArraySlice<int64> input_gather_index,
- ArraySlice<int64> output_gather_index) -> StatusOr<bool> {
+ [&](absl::Span<const int64> output_window_index,
+ absl::Span<const int64> input_gather_index,
+ absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
TF_ASSIGN_OR_RETURN(
- ArraySlice<int64> input_window_index,
- output_window_index_to_input_index(output_window_index));
+ absl::Span<const int64> input_window_index,
+ output_offset_index_to_input_index(output_window_index));
for (int i = 0, e = output_index.size(); i < e; i++) {
output_index[i] = output_gather_index[i] + output_window_index[i];
DCHECK_LT(output_index[i], shape.dimensions(i));
}
for (int i = 0, e = input_gather_index.size(); i < e; i++) {
int64 output_dim =
- output_window_index_to_input_index.input_dim_value_to_output_index(i);
+ output_offset_index_to_input_index.input_dim_value_to_output_index(i);
// If 'output_dim' is -1, it means 'i' is an elided window dim. This means
// we set the iteration index to 0, so for the purpose of the following
// calculations we can consider the output dimension size to be 1.
int64 output_dim_size =
output_dim == -1 ? 1 : shape.dimensions(output_dim);
// Clamp the gather index so that the gather region fits in the operand.
- // input_gather_index_clamped[i] = clamp(input_gather_index[i], 0,
+ // input_index_clamped[i] = clamp(input_gather_index[i], 0,
// operand_shape.dimensions(i) -
// output_dim_size);
- input_gather_index_clamped[i] =
+ input_index_clamped[i] =
std::min(operand_shape.dimensions(i) - output_dim_size,
std::max(0LL, input_gather_index[i]));
}
for (int i = 0, e = input_index.size(); i < e; i++) {
- input_index[i] = input_gather_index_clamped[i] + input_window_index[i];
+ input_index[i] = input_index_clamped[i] + input_window_index[i];
DCHECK_GE(input_index[i], 0);
DCHECK_LT(input_index[i], operand_shape.dimensions(i));
}
TF_RETURN_IF_ERROR(
- result->CopyElementFrom(operand, input_index, output_index));
+ result.CopyElementFrom(operand, input_index, output_index));
return true;
};
auto gather_outer_loop_body =
- [&](ArraySlice<int64> output_gather_index) -> StatusOr<bool> {
- TF_ASSIGN_OR_RETURN(
- ArraySlice<int64> input_gather_index,
- output_gather_index_to_input_index(output_gather_index));
+ [&](absl::Span<const int64> output_gather_index) -> StatusOr<bool> {
+ TF_ASSIGN_OR_RETURN(absl::Span<const int64> input_gather_index,
+ output_batch_index_to_input_index(output_gather_index));
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
- shape, window_indices_iteration_space,
+ shape, offset_indices_iteration_space,
std::bind(gather_inner_loop_body, std::placeholders::_1,
input_gather_index, output_gather_index)));
return true;
};
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
- shape, gather_indices_iteration_space, gather_outer_loop_body));
+ shape, start_indices_iteration_space, gather_outer_loop_body));
evaluated_[gather] = std::move(result);
return Status::OK();
}
@@ -929,8 +969,14 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
// Checks that operand's dimensions are the same as the broadcast's
// dimensions along the dimensions to be broadcasted.
for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
- TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) ==
- operand.shape().dimensions(i));
+ auto operand_dim_size = operand.shape().dimensions(i);
+ auto broadcast_dim_size =
+ broadcast->shape().dimensions(broadcast->dimensions(i));
+ TF_RET_CHECK(operand_dim_size == broadcast_dim_size) << absl::StreamFormat(
+ "Operand dimension %d is broadcast to output dimension %d, but the "
+ "sizes of these two dims do not match (%d vs %d): %s",
+ i, broadcast->dimensions(i), operand_dim_size, broadcast_dim_size,
+ broadcast->ToString());
}
TF_ASSIGN_OR_RETURN(
@@ -960,18 +1006,16 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
- evaluated_[get_tuple_element] = MakeUnique<Literal>(
- ShapeUtil::GetTupleElementShape(operand->shape(), index));
- return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal,
- /*dest_shape_index=*/{},
- /*src_shape_index=*/{index});
+ evaluated_[get_tuple_element] =
+ Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index));
+ return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal,
+ /*dest_shape_index=*/{},
+ /*src_shape_index=*/{index});
}
Status HloEvaluator::HandleCopy(HloInstruction* copy) {
TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
-
- auto result = GetEvaluatedLiteralFor(copy->operand(0)).CloneToUnique();
- evaluated_[copy] = std::move(result);
+ evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone();
return Status::OK();
}
@@ -987,7 +1031,7 @@ Status HloEvaluator::HandleCall(HloInstruction* call) {
}
HloEvaluator embedded_evaluator;
- std::unique_ptr<Literal> result =
+ Literal result =
embedded_evaluator.Evaluate<const Literal*>(*computation, arg_literals)
.ConsumeValueOrDie();
@@ -1019,7 +1063,7 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) {
}
HloEvaluator embedded_evaluator;
- std::unique_ptr<Literal> result =
+ Literal result =
embedded_evaluator
.Evaluate<const Literal*>(*readded_computation, arg_literals)
.ConsumeValueOrDie();
@@ -1039,7 +1083,7 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) {
auto* false_computation = conditional->false_computation();
HloEvaluator embedded_evaluator;
- std::unique_ptr<Literal> result;
+ Literal result;
if (pred.Get<bool>({})) {
result = embedded_evaluator
.Evaluate<const Literal*>(*true_computation,
@@ -1064,9 +1108,9 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) {
// If predicate is of scalar type, no element-wise selection would be needed.
if (ShapeUtil::IsScalar(pred.shape())) {
if (pred.Get<bool>({})) {
- evaluated_[select] = on_true.CloneToUnique();
+ evaluated_[select] = on_true.Clone();
} else {
- evaluated_[select] = on_false.CloneToUnique();
+ evaluated_[select] = on_false.Clone();
}
return Status::OK();
}
@@ -1080,9 +1124,9 @@ Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) {
const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2));
if (pred.Get<bool>({})) {
- evaluated_[tuple_select] = on_true.CloneToUnique();
+ evaluated_[tuple_select] = on_true.Clone();
} else {
- evaluated_[tuple_select] = on_false.CloneToUnique();
+ evaluated_[tuple_select] = on_false.Clone();
}
return Status::OK();
}
@@ -1091,23 +1135,23 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
HloComputation* cond_comp = while_hlo->while_condition();
HloComputation* body_comp = while_hlo->while_body();
// Initialize the loop carried valued with the input to the While instruction.
- auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).CloneToUnique();
+ auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone();
bool keep_going = true;
int64 iteration_count = 0;
HloEvaluator cond_evaluator(max_loop_iterations_);
HloEvaluator loop_body_evaluator(max_loop_iterations_);
while (keep_going) {
if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) {
- return InvalidArgument("Loop %s exceeded loop iteration limit (%lld).",
- while_hlo->name().c_str(), max_loop_iterations_);
+ return InvalidArgument("Loop %s exceeded loop iteration limit (%d).",
+ while_hlo->name(), max_loop_iterations_);
}
- TF_ASSIGN_OR_RETURN(auto cond_val, cond_evaluator.Evaluate<Literal*>(
- *cond_comp, {lcv.get()}));
- keep_going = cond_val->GetFirstElement<bool>();
+ TF_ASSIGN_OR_RETURN(auto cond_val,
+ cond_evaluator.Evaluate<Literal*>(*cond_comp, {&lcv}));
+ keep_going = cond_val.GetFirstElement<bool>();
if (keep_going) {
TF_ASSIGN_OR_RETURN(auto body_val, loop_body_evaluator.Evaluate<Literal*>(
- *body_comp, {lcv.get()}));
- VLOG(3) << "Loop iteration result: " << body_val->ToString();
+ *body_comp, {&lcv}));
+ VLOG(3) << "Loop iteration result: " << body_val.ToString();
lcv = std::move(body_val);
cond_evaluator.ResetVisitStates();
loop_body_evaluator.ResetVisitStates();
@@ -1122,9 +1166,9 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
// hoops to make this work.
namespace {
template <typename KeyType, typename ValueType>
-StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
- HloInstruction* sort, const Literal& keys_literal,
- const Literal& values_literal) {
+StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort,
+ const Literal& keys_literal,
+ const Literal& values_literal) {
auto rank = ShapeUtil::Rank(keys_literal.shape());
TF_RET_CHECK(
ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape()))
@@ -1162,56 +1206,55 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
result_keys.push_back(key_value.first);
result_values.push_back(key_value.second);
}
- auto result_keys_literal = MakeUnique<Literal>(keys_literal.shape());
- result_keys_literal->PopulateR1(
- tensorflow::gtl::ArraySlice<KeyType>(result_keys));
- auto result_values_literal = MakeUnique<Literal>(values_literal.shape());
- result_values_literal->PopulateR1(
- tensorflow::gtl::ArraySlice<ValueType>(result_values));
+ Literal result_keys_literal(keys_literal.shape());
+ result_keys_literal.PopulateR1(absl::Span<const KeyType>(result_keys));
+ Literal result_values_literal(values_literal.shape());
+ result_values_literal.PopulateR1(
+ absl::Span<const ValueType>(result_values));
return std::make_pair(std::move(result_keys_literal),
std::move(result_values_literal));
};
- std::unique_ptr<Literal> result_tuple;
+ Literal result_tuple;
if (rank == 1) {
auto result_pair = sort_r1(keys_literal, values_literal);
- result_tuple = LiteralUtil::MakeTuple(
- {result_pair.first.get(), result_pair.second.get()});
+ result_tuple =
+ LiteralUtil::MakeTuple({&result_pair.first, &result_pair.second});
} else {
// For R2 sort, the desired semantics are to sort each matrix row
// independently.
- auto keys_result_literal = MakeUnique<Literal>(keys_literal.shape());
- auto values_result_literal = MakeUnique<Literal>(values_literal.shape());
+ Literal keys_result_literal(keys_literal.shape());
+ Literal values_result_literal(values_literal.shape());
int64 r1_length = keys_literal.shape().dimensions(1);
for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) {
TF_ASSIGN_OR_RETURN(auto keys_r1_slice,
keys_literal.Slice({row, 0}, {row + 1, r1_length})
- ->Reshape({r1_length}));
+ .Reshape({r1_length}));
TF_ASSIGN_OR_RETURN(auto values_r1_slice,
values_literal.Slice({row, 0}, {row + 1, r1_length})
- ->Reshape({r1_length}));
- auto r1_result_pair = sort_r1(*keys_r1_slice, *values_r1_slice);
+ .Reshape({r1_length}));
+ auto r1_result_pair = sort_r1(keys_r1_slice, values_r1_slice);
TF_ASSIGN_OR_RETURN(auto sorted_keys,
- r1_result_pair.first->Reshape({1, r1_length}));
+ r1_result_pair.first.Reshape({1, r1_length}));
TF_ASSIGN_OR_RETURN(auto sorted_values,
- r1_result_pair.second->Reshape({1, r1_length}));
- TF_RETURN_IF_ERROR(keys_result_literal->CopySliceFrom(
- *sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
- TF_RETURN_IF_ERROR(values_result_literal->CopySliceFrom(
- *sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
+ r1_result_pair.second.Reshape({1, r1_length}));
+ TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom(
+ sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
+ TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom(
+ sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
}
- result_tuple = LiteralUtil::MakeTuple(
- {keys_result_literal.get(), values_result_literal.get()});
+ result_tuple =
+ LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal});
}
- VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString();
+ VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();
return std::move(result_tuple);
}
template <typename KeyType>
-StatusOr<std::unique_ptr<Literal>> EvaluateSortCurried(
- HloInstruction* sort, const Literal& keys_literal,
- const Literal& values_literal) {
+StatusOr<Literal> EvaluateSortCurried(HloInstruction* sort,
+ const Literal& keys_literal,
+ const Literal& values_literal) {
switch (sort->operand(1)->shape().element_type()) {
case F32:
return EvaluateSortInternal<KeyType, float>(sort, keys_literal,
@@ -1230,9 +1273,9 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortCurried(
}
}
-StatusOr<std::unique_ptr<Literal>> EvaluateSort(HloInstruction* sort,
- const Literal& keys_literal,
- const Literal& values_literal) {
+StatusOr<Literal> EvaluateSort(HloInstruction* sort,
+ const Literal& keys_literal,
+ const Literal& values_literal) {
switch (sort->operand(0)->shape().element_type()) {
case F32:
return EvaluateSortCurried<float>(sort, keys_literal, values_literal);
@@ -1253,7 +1296,7 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) {
const int64 rank = ShapeUtil::Rank(sort->operand(0)->shape());
if (sort_dim != rank - 1) {
return Unimplemented(
- "Trying to support along dimension %lld, which is not the last "
+ "Trying to sort along dimension %d, which is not the last "
"dimension",
sort_dim);
}
@@ -1272,9 +1315,25 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) {
}
}
+Status HloEvaluator::HandleReduce(HloInstruction* reduce) {
+ if (!ShapeUtil::IsTuple(reduce->shape())) {
+ return DefaultAction(reduce);
+ } else {
+ auto first_element_type = reduce->shape().tuple_shapes(0).element_type();
+ for (const auto& tuple_shape : reduce->shape().tuple_shapes()) {
+ if (tuple_shape.element_type() != first_element_type) {
+ return Unimplemented(
+ "Reduce with several outputs that have mixed element types is "
+ "unsupported");
+ }
+ }
+ return reduce->Visit(typed_visitors_.at(first_element_type).get());
+ }
+}
+
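HandleReduce above forwards tuple-shaped (variadic) reduces to a single typed visitor, which is only sound when every output shares one element type; mixed-type multi-output reduce stays Unimplemented. The guard in isolation (a sketch, the pass keeps it inline):

    // True iff every element of a tuple shape shares the first element type.
    bool AllOutputsSameElementType(const Shape& tuple_shape) {
      auto first = tuple_shape.tuple_shapes(0).element_type();
      for (const auto& shape : tuple_shape.tuple_shapes()) {
        if (shape.element_type() != first) return false;
      }
      return true;
    }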
Status HloEvaluator::Preprocess(HloInstruction* hlo) {
VLOG(2) << "About to visit HLO: " << hlo->ToString();
- return Status::OK();
+ return ShapeUtil::ValidateShape(hlo->shape());
}
Status HloEvaluator::Postprocess(HloInstruction* hlo) {
@@ -1285,27 +1344,14 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) {
// Explicit instantiation of templatized Evaluate* methods.
//
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<const Literal*>(const HloModule& module,
- ArraySlice<const Literal*> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
- const HloModule& module, ArraySlice<std::unique_ptr<Literal>> arg_literals);
-
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<const Literal*>(const HloComputation& computation,
- ArraySlice<const Literal*> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
+template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
+ const HloModule& module, absl::Span<const Literal* const> arg_literals);
+
+template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
const HloComputation& computation,
- ArraySlice<std::unique_ptr<Literal>> arg_literals);
-
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<const Literal*>(HloInstruction* instruction,
- ArraySlice<const Literal*> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
- HloInstruction* instruction,
- ArraySlice<std::unique_ptr<Literal>> arg_literals);
+ absl::Span<const Literal* const> arg_literals);
+
+template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
+ HloInstruction* instruction, absl::Span<const Literal* const> arg_literals);
} // namespace xla
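The block above is the usual companion of templates defined in a .cc file: explicit instantiation gives the linker concrete symbols for each LiteralPtr the header promises, and after this change only the const Literal* form needs instantiating (the by-value Literal form is a full specialization defined earlier in this file). The mechanism in miniature (Twice is illustrative):

    // twice.h would declare:  template <typename T> T Twice(T value);
    // twice.cc defines it and instantiates every type callers may use:
    template <typename T>
    T Twice(T value) {
      return value + value;
    }

    template int Twice<int>(int);  // emitted for the linker
    template double Twice<double>(double);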
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index a4c37ef328..21e676d671 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -18,7 +18,8 @@ limitations under the License.
#include <memory>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
@@ -47,12 +47,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// Precondition: The indices of arg_literals correspond to the parameter
// numbers of the HLO parameters in the computation. See comment below for an
// example.
- // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
+ // `LiteralPtr` accepts either Literal or const Literal*
// type.
template <typename LiteralPtr>
- StatusOr<std::unique_ptr<Literal>> Evaluate(
- const HloModule& module,
- tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals);
+ StatusOr<Literal> Evaluate(const HloModule& module,
+ absl::Span<const LiteralPtr> arg_literals);
// Evaluates an HLO computation and an array of pointers to literals.
// Returns the evaluated result as a literal if successful.
@@ -70,12 +69,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// where Parameter0 has parameter_number 0 and Parameter1 has parameter_number
// 1 in this computation. The input literals array will then have its first
// literal map to Parameter0 and the second map to Parameter1.
- // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
+ // `LiteralPtr` accepts either Literal or const Literal*
// type.
template <typename LiteralPtr>
- StatusOr<std::unique_ptr<Literal>> Evaluate(
- const HloComputation& computation,
- tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals);
+ StatusOr<Literal> Evaluate(const HloComputation& computation,
+ absl::Span<const LiteralPtr> arg_literals);
// Evaluates a single HLO instruction and an array of pointers to literals.
  // Returns the evaluated result as a literal if successful.
@@ -83,42 +81,43 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// 1. argument literals correspond to the input instruction's parameters in
// their post-ordering.
// 2. the instruction's operands must be of either Parameter or Constant type.
- // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
+ // `LiteralPtr` accepts either Literal or const Literal*
// type.
template <typename LiteralPtr>
- StatusOr<std::unique_ptr<Literal>> Evaluate(
- HloInstruction* instruction,
- tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals);
+ StatusOr<Literal> Evaluate(HloInstruction* instruction,
+ absl::Span<const LiteralPtr> arg_literals);
// Evaluates a single HLO instruction with constant operands.
  // Returns the evaluated result as a literal if successful.
// Precondition:
// 1. all operands of the input instruction are constants.
// 2. the instruction is not a Parameter operation.
- StatusOr<std::unique_ptr<Literal>> Evaluate(HloInstruction* instruction);
+ StatusOr<Literal> Evaluate(HloInstruction* instruction);
- // Same as Evaluate, except returning nullptr on error.
- std::unique_ptr<Literal> TryEvaluate(HloInstruction* instruction);
+ // Same as Evaluate, except returning false on error and accepts an output
+ // pointer.
+ bool TryEvaluate(HloInstruction* instruction, Literal* result);
// Evaluates a single HLO instruction, substituting the given literals for
// some of the instruction's operands.
//
// For example, given instruction = op(A, B, C) and the map
// {A = x, C = y}, this evaluates op(x, B, y).
- StatusOr<std::unique_ptr<Literal>> EvaluateWithSubstitutions(
+ StatusOr<Literal> EvaluateWithSubstitutions(
const HloInstruction* instruction,
const std::unordered_map<const HloInstruction*, const Literal*>&
substitutions);
- StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseBinaryOp(
- HloOpcode opcode, const Literal& lhs, const Literal& rhs);
+ StatusOr<Literal> EvaluateElementwiseBinaryOp(HloOpcode opcode,
+ const Literal& lhs,
+ const Literal& rhs);
- StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseUnaryOp(
- HloOpcode opcode, const Literal& operand);
+ StatusOr<Literal> EvaluateElementwiseUnaryOp(HloOpcode opcode,
+ const Literal& operand);
- StatusOr<std::unique_ptr<Literal>> EvaluateDotOp(
- const DotDimensionNumbers& dim_numbers, const Literal& lhs,
- const Literal& rhs);
+ StatusOr<Literal> EvaluateDotOp(const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfig& precision_config,
+ const Literal& lhs, const Literal& rhs);
protected:
// Make HloEvaluatorTypedVisitor a friend because it is logically part of this
@@ -185,6 +184,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
Status HandleSort(HloInstruction* sort) override;
+ Status HandleReduce(HloInstruction* reduce) override;
+
// Returns the already-evaluated literal result for the instruction.
// A Constant instruction is considered evaluated and its literal will be
// returned directly without looking up the cache.
@@ -196,7 +197,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
auto it = evaluated_.find(hlo);
CHECK(it != evaluated_.end())
<< "could not find evaluated value for: " << hlo->ToString();
- return *(it->second);
+ return it->second;
}
// Tracks the HLO instruction and its evaluated literal result.
@@ -204,12 +205,13 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// that are no longer a parent for any other subsequent instruction in
  // post-ordering.
// Must be cleared for each evaluation.
- tensorflow::gtl::FlatMap<const HloInstruction*, std::unique_ptr<Literal>>
- evaluated_;
+ // Storing Literal in place requires the container to have pointer stability,
+ // so we cannot use FlatMap any more.
+ std::unordered_map<const HloInstruction*, Literal> evaluated_;
private:
template <typename ReturnT, typename NativeT>
- static StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
+ static StatusOr<Literal> ElementWiseUnaryOpImpl(
HloInstruction* instruction,
const std::function<ReturnT(NativeT)>& unary_op,
const Literal& operand_literal) {
@@ -222,13 +224,13 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
return Unimplemented(
"Implicit broadcasting is currently unsupported in HLO evaluator "
"Shape Mismatch: %s vs %s",
- ShapeUtil::HumanString(shape).c_str(),
- ShapeUtil::HumanString(operand->shape()).c_str());
+ ShapeUtil::HumanString(shape),
+ ShapeUtil::HumanString(operand->shape()));
}
- auto result = MakeUnique<Literal>(shape);
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ Literal result(shape);
+ TF_RETURN_IF_ERROR(
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return unary_op(operand_literal.Get<NativeT>(multi_index));
}));
return std::move(result);
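The evaluated_ comment earlier in this diff explains the container switch: GetEvaluatedLiteralFor hands out const Literal& into the map, so inserting later results must never relocate existing values. std::unordered_map guarantees exactly that (rehashing invalidates iterators but not references to elements), while an open-addressed flat map moves its values around. A self-contained sketch of the property being relied on:

    #include <cassert>
    #include <string>
    #include <unordered_map>

    int main() {
      std::unordered_map<int, std::string> map;
      map[1] = "one";
      const std::string& ref = map.at(1);
      for (int i = 2; i < 10000; ++i) map[i] = "filler";  // forces rehashes
      assert(ref == "one");  // guaranteed: rehashing preserves references
      return 0;
    }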
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 3ac6d68df3..16411eb078 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
@@ -51,12 +52,11 @@ static std::array<bool, 2> use_bf16_params{true, false};
class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
public HloVerifiedTestBase {
protected:
- HloEvaluatorTest() : use_bfloat16_(GetParam()) {
- evaluator_ = MakeUnique<HloEvaluator>();
+ HloEvaluatorTest() : HloVerifiedTestBase(), use_bfloat16_(GetParam()) {
+ evaluator_ = absl::make_unique<HloEvaluator>();
}
- std::unique_ptr<Literal> Evaluate(
- tensorflow::gtl::ArraySlice<const Literal*> arg_literals = {}) {
+ Literal Evaluate(absl::Span<const Literal* const> arg_literals = {}) {
if (use_bfloat16_) {
// In BF16 mode, we convert all F32 type to BF16 and evaluate the module.
auto type_converter = HloElementTypeConverter(F32, BF16);
@@ -68,39 +68,37 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
std::unique_ptr<HloEvaluator> evaluator_;
- void TestUnaryOp(HloOpcode opcode, std::unique_ptr<Literal> expected,
- std::unique_ptr<Literal> input, float aabs = 0) {
+ void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input,
+ float aabs = 0) {
HloComputation::Builder b(TestName());
auto c1 =
b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
- b.AddInstruction(
- HloInstruction::CreateUnary(expected->shape(), opcode, c1));
+ b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
- auto element_type = expected->shape().element_type();
+ auto element_type = expected.shape().element_type();
if (element_type == F32 || element_type == F64) {
ErrorSpec error(aabs);
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, error));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, error));
} else {
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
}
- void TestBinaryOp(HloOpcode opcode, std::unique_ptr<Literal> expected,
- std::unique_ptr<Literal> lhs,
- std::unique_ptr<Literal> rhs) {
+ void TestBinaryOp(HloOpcode opcode, Literal expected, Literal lhs,
+ Literal rhs) {
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
b.AddInstruction(
- HloInstruction::CreateBinary(expected->shape(), opcode, c1, c2));
+ HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
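  // For illustration, a hypothetical call under the new by-value API:
  //   TestUnaryOp(HloOpcode::kNegate,
  //               /*expected=*/LiteralUtil::CreateR1<float>({-1.f, -2.f}),
  //               /*input=*/LiteralUtil::CreateR1<float>({1.f, 2.f}));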
bool use_bfloat16_;
@@ -116,7 +114,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) {
auto value = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
auto high = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
- Shape shape = low->shape();
+ Shape shape = low.shape();
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
@@ -125,11 +123,11 @@ TEST_P(HloEvaluatorTest, DoesClamp) {
HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({{0, 4}, {2, 4}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
@@ -137,7 +135,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
auto value = LiteralUtil::CreateR2<float>({{-1.f, 0.f}, {1.f, 2.f}});
auto high = LiteralUtil::CreateR0<float>(1.f);
- Shape shape = value->shape();
+ Shape shape = value.shape();
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
@@ -146,11 +144,11 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({{0, 0}, {1, 1}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
// Verifies that HloEvaluator evaluates an HLO instruction that performs select
@@ -160,7 +158,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) {
auto on_true = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
auto on_false = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
- Shape shape = on_true->shape();
+ Shape shape = on_true.shape();
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(pred)));
auto c2 =
@@ -171,11 +169,11 @@ TEST_P(HloEvaluatorTest, DoesSelect) {
HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate({});
+ Literal result = Evaluate({});
auto expected = LiteralUtil::CreateR2<float>({{2, 5}, {0, 4}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
// Verifies that HloEvaluator evaluates an HLO instruction that performs
@@ -294,7 +292,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) {
auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
auto rhs2 = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
- std::vector<const Literal*> args = {lhs.get(), rhs.get(), rhs2.get()};
+ std::vector<const Literal*> args = {&lhs, &rhs, &rhs2};
Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
@@ -312,11 +310,11 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) {
lhs_instruction, param_rhs2));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate(args);
+ Literal result = Evaluate(args);
auto expected = LiteralUtil::CreateR2<int64>({{4, -16}, {-196, 12}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
// Verifies that the Reshape operation is correctly evaluated.
@@ -326,7 +324,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) {
TF_ASSERT_OK_AND_ASSIGN(auto literal,
LiteralUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
- auto literal_clone = literal->CloneToUnique();
+ auto literal_clone = literal.Clone();
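+ // Note: Clone() returns a deep copy of the Literal by value, replacing the
+ // former CloneToUnique() and its unique_ptr indirection.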
HloInstruction* literal_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
@@ -336,14 +334,13 @@ TEST_P(HloEvaluatorTest, DoesReshape) {
HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate({});
+ Literal result = Evaluate({});
using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
- result->EachCell<NativeT>(
- [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT value) {
- std::vector<int64> rindexes = Permute(permutation, indices);
- EXPECT_NEAR(value, literal_clone->Get<NativeT>(rindexes), 0.031250);
- });
+ result.EachCell<NativeT>([&](absl::Span<const int64> indices, NativeT value) {
+ std::vector<int64> rindexes = Permute(permutation, indices);
+ EXPECT_NEAR(value, literal_clone.Get<NativeT>(rindexes), 0.031250);
+ });
}
// Verifies that the Broadcast operation is correctly evaluated.
@@ -355,12 +352,12 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) {
HloInstruction* literal_instruction = b.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
b.AddInstruction(HloInstruction::CreateBroadcast(
- output_literal->shape(), literal_instruction, {1, 2}));
+ output_literal.shape(), literal_instruction, {1, 2}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate({});
+ Literal result = Evaluate({});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal));
}
TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
@@ -373,13 +370,13 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
HloInstruction::CreateConstant(std::move(input_literal)));
  // Broadcast dimensions should be empty in the case of scalars.
b.AddInstruction(HloInstruction::CreateBroadcast(
- output_literal->shape(), literal_instruction,
+ output_literal.shape(), literal_instruction,
/*broadcast_dimensions=*/{}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate({});
+ Literal result = Evaluate({});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal));
}
TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
@@ -397,11 +394,11 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<int64>(
{{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
@@ -419,10 +416,10 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR1<int64>({100, 200});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
@@ -431,17 +428,17 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
auto input_literal = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
auto expected =
LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
- ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(),
- expected->shape()));
+ ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(),
+ expected.shape()));
HloInstruction* constant = b.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
- b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant));
+ b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
@@ -451,17 +448,17 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
{{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1}));
auto expected = LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0}));
- ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(),
- expected->shape()));
+ ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(),
+ expected.shape()));
HloInstruction* constant = b.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
- b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant));
+ b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
PaddingConfig CreatePaddingConfig(
@@ -494,12 +491,12 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) {
shape, operand_instruction, padding_value_instruction, padding_config));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<int32>(
{{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
@@ -521,9 +518,9 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
- auto expected_array = MakeUnique<Array4D<float>>(8, 5, 1, 1);
+ auto expected_array = absl::make_unique<Array4D<float>>(8, 5, 1, 1);
expected_array->Fill(kPadValue);
(*expected_array)(1, 0, 0, 0) = 1.0f;
(*expected_array)(1, 2, 0, 0) = 2.0f;
@@ -534,7 +531,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
auto expected = LiteralUtil::CreateR4FromArray4D<float>(*expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, NegativePadding2D) {
@@ -547,7 +544,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
// { 9, 10, 11 },
// { 13, 14, 15 },
// }
- auto input_array = MakeUnique<Array2D<float>>(4, 3);
+ auto input_array = absl::make_unique<Array2D<float>>(4, 3);
input_array->FillUnique(1.0f);
auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
HloInstruction* input_instruction =
@@ -565,10 +562,10 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 }
- auto expected_array = MakeUnique<Array2D<float>>(1, 5);
+ auto expected_array = absl::make_unique<Array2D<float>>(1, 5);
(*expected_array)(0, 0) = 7.0f;
(*expected_array)(0, 1) = 2.718f;
(*expected_array)(0, 2) = 2.718f;
@@ -576,7 +573,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
(*expected_array)(0, 4) = 2.718f;
auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(0.031250)));
}
TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
@@ -588,7 +585,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
// { 9, 10, 11 },
// { 13, 14, 15 },
// }
- auto input_array = MakeUnique<Array2D<float>>(4, 3);
+ auto input_array = absl::make_unique<Array2D<float>>(4, 3);
input_array->FillUnique(1.0f);
auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
HloInstruction* input_instruction =
@@ -610,12 +607,12 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
- auto expected_array = MakeUnique<Array2D<float>>(0, 9);
+ auto expected_array = absl::make_unique<Array2D<float>>(0, 9);
auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
@@ -628,7 +625,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
// { 3 },
// { 4 },
// }
- auto lhs_array = MakeUnique<Array2D<float>>(4, 1);
+ auto lhs_array = absl::make_unique<Array2D<float>>(4, 1);
lhs_array->FillUnique(1.0f);
auto lhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*lhs_array);
HloInstruction* lhs_instruction =
@@ -645,10 +642,11 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
- rhs_instruction, dot_dnums));
+ rhs_instruction, dot_dnums,
+ DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// clang-format off
auto expected_array = Array2D<float>({
@@ -660,7 +658,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
// clang-format on
auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
@@ -679,7 +677,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
// { 3, 4 },
// { 5, 6 },
// }
- auto rhs_array = MakeUnique<Array2D<float>>(3, 2);
+ auto rhs_array = absl::make_unique<Array2D<float>>(3, 2);
rhs_array->FillUnique(1.0f);
auto rhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*rhs_array);
HloInstruction* rhs_instruction =
@@ -690,14 +688,15 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
dot_dnums.add_lhs_contracting_dimensions(0);
dot_dnums.add_rhs_contracting_dimensions(0);
b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
- rhs_instruction, dot_dnums));
+ rhs_instruction, dot_dnums,
+ DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR1<float>({22.f, 28.f});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
@@ -710,7 +709,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
// { 9, 10, 11 },
// { 13, 14, 15 },
// }
- auto lhs_array = MakeUnique<Array2D<float>>(4, 3);
+ auto lhs_array = absl::make_unique<Array2D<float>>(4, 3);
lhs_array->FillUnique(1.0f);
auto lhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*lhs_array);
HloInstruction* lhs_instruction =
@@ -722,7 +721,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
// { 3, 4 },
// { 5, 6 },
// }
- auto rhs_array = MakeUnique<Array2D<float>>(3, 2);
+ auto rhs_array = absl::make_unique<Array2D<float>>(3, 2);
rhs_array->FillUnique(1.0f);
auto rhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*rhs_array);
HloInstruction* rhs_instruction =
@@ -733,10 +732,11 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
- rhs_instruction, dot_dnums));
+ rhs_instruction, dot_dnums,
+ DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected_array = Array2D<float>({
{22.f, 28.f},
@@ -746,7 +746,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
});
auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, SimpleConv1D) {
@@ -784,17 +784,18 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) {
dnums.set_kernel_input_feature_dimension(1);
dnums.add_kernel_spatial_dimensions(2);
- const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 3});
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 3});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array3D<float> expected_array = {{{11.f, 18.f, 9.f}}};
auto expected = LiteralUtil::CreateR3FromArray3D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
@@ -838,12 +839,13 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
ConvolutionDimensionNumbers dnums =
XlaBuilder::CreateDefaultConvDimensionNumbers(2);
- const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 4, 4);
// clang-format off
@@ -856,7 +858,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
// clang-format on
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
@@ -921,22 +923,23 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
dnums.add_kernel_spatial_dimensions(3);
dnums.add_kernel_spatial_dimensions(1);
- const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// clang-format off
// Result dimensions: [feature=1, height=1, batch=1, width=2]
Array4D<float> expected_array({{{{2514, 2685}}}});
- Array4D<float> expected_array_bf16({{{{2512, 2672}}}});
+ Array4D<float> expected_array_bf16({{{{2512, 2688}}}});
// clang-format on
auto expected = LiteralUtil::CreateR4FromArray4D<float>(
use_bfloat16_ ? expected_array_bf16 : expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
@@ -998,22 +1001,23 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
dnums.add_kernel_spatial_dimensions(3);
dnums.add_kernel_spatial_dimensions(1);
- const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// clang-format off
// Result dimensions: [feature=1, height=1, batch=1, width=2]
Array4D<float> expected_array({{{{2514, 2685}}}});
- Array4D<float> expected_array_bf16({{{{2512, 2672}}}});
+ Array4D<float> expected_array_bf16({{{{2512, 2688}}}});
// clang-format on
auto expected = LiteralUtil::CreateR4FromArray4D<float>(
use_bfloat16_ ? expected_array_bf16 : expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
@@ -1057,12 +1061,13 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
ConvolutionDimensionNumbers dnums =
XlaBuilder::CreateDefaultConvDimensionNumbers(2);
- const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7});
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 7, 7);
expected_array.FillWithYX(Array2D<float>({
@@ -1076,7 +1081,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
}));
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
@@ -1120,12 +1125,13 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
ConvolutionDimensionNumbers dnums =
XlaBuilder::CreateDefaultConvDimensionNumbers(2);
- const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8});
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 8, 8);
expected_array.FillWithYX(Array2D<float>({
@@ -1140,7 +1146,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
}));
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest,
@@ -1191,12 +1197,13 @@ TEST_P(HloEvaluatorTest,
ConvolutionDimensionNumbers dnums =
XlaBuilder::CreateDefaultConvDimensionNumbers(2);
- const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3});
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
Array4D<float> expected_array(1, 1, 9, 3);
expected_array.FillWithYX(Array2D<float>({
@@ -1212,7 +1219,68 @@ TEST_P(HloEvaluatorTest,
}));
auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
+}
+
+TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) {
+ HloComputation::Builder b(TestName());
+ std::vector<int64> input_dims = {1, 2, 2, 4};
+ std::vector<int64> filter_dims = {2, 2, 2, 8};
+ Shape input_shape = ShapeUtil::MakeShapeWithType<float>(input_dims);
+ Shape filter_shape = ShapeUtil::MakeShapeWithType<float>(filter_dims);
+ // Tensorflow dimension numbers for 2D convolution.
+ ConvolutionDimensionNumbers dnums;
+ dnums.set_input_batch_dimension(0);
+ dnums.set_output_batch_dimension(0);
+ dnums.add_input_spatial_dimensions(1);
+ dnums.add_output_spatial_dimensions(1);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(2);
+ dnums.set_input_feature_dimension(3);
+ dnums.set_output_feature_dimension(3);
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.add_kernel_spatial_dimensions(1);
+ dnums.set_kernel_input_feature_dimension(2);
+ dnums.set_kernel_output_feature_dimension(3);
+
+ Window window;
+ WindowDimension dim;
+ dim.set_size(2);
+ dim.set_stride(1);
+ dim.set_padding_low(0);
+ dim.set_padding_high(0);
+ dim.set_window_dilation(1);
+ dim.set_base_dilation(1);
+ *window.add_dimensions() = dim;
+ *window.add_dimensions() = dim;
+
+ std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
+ std::iota(input_elems.begin(), input_elems.end(), -7);
+ auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
+ auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
+ HloInstruction* lhs_instruction =
+ b.AddInstruction(HloInstruction::CreateConstant(std::move(input_r4)));
+
+ std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
+ std::iota(filter_elems.begin(), filter_elems.end(), -31);
+ auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
+ auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
+ HloInstruction* rhs_instruction =
+ b.AddInstruction(HloInstruction::CreateConstant(std::move(filter_r4)));
+
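+ // With feature_group_count=2 below, the 4 input features are split into two
+ // groups of 2; each group is convolved with its half of the 8 output
+ // features (4 per group).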
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 8});
+ b.AddInstruction(HloInstruction::CreateConvolve(
+ shape, lhs_instruction, rhs_instruction,
+ /*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2)));
+ module().AddEntryComputation(b.Build());
+
+ Literal result = Evaluate();
+
+ Array4D<float> expected_array(1, 1, 1, 8);
+ expected_array.FillWithYX(
+ Array2D<float>({{668, 664, 660, 656, 668, 680, 692, 704}}));
+ auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
@@ -1245,9 +1313,8 @@ TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) {
module().AddEntryComputation(b.Build());
HloEvaluator hlo_eval;
- std::unique_ptr<Literal> result =
- hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
- LiteralTestUtil::ExpectR0Equal<float>(kNumElements, *result);
+ Literal result = hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
+ LiteralTestUtil::ExpectR0Equal<float>(kNumElements, result);
}
// Reducing many numbers should be fast because it doesn't create
@@ -1297,7 +1364,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) {
// { 1, 2, 3 },
// { 5, 6, 7 },
// }
- auto arg_array = MakeUnique<Array2D<float>>(2, 3);
+ auto arg_array = absl::make_unique<Array2D<float>>(2, 3);
arg_array->FillUnique(1.0f);
auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
@@ -1324,11 +1391,11 @@ TEST_P(HloEvaluatorTest, ReduceAdd) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR1<float>({6, 18});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ReduceWindowMax) {
@@ -1339,7 +1406,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) {
// { 1, 2, 3 },
// { 5, 6, 7 },
// }
- auto arg_array = MakeUnique<Array2D<float>>(2, 3);
+ auto arg_array = absl::make_unique<Array2D<float>>(2, 3);
arg_array->FillUnique(1.0f);
auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
@@ -1376,10 +1443,10 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({{6, 7}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
@@ -1390,7 +1457,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
// { 1, 2, 3 },
// { 5, 6, 7 },
// }
- auto arg_array = MakeUnique<Array2D<float>>(2, 3);
+ auto arg_array = absl::make_unique<Array2D<float>>(2, 3);
arg_array->FillUnique(1.0f);
auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
@@ -1433,10 +1500,10 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({{1, 3, 5}, {5, 11, 13}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
@@ -1444,7 +1511,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
// arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time.
std::vector<int64> input_dims(6, 4);
- std::unique_ptr<Literal> arg_literal =
+ Literal arg_literal =
LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
HloInstruction* arg_instruction =
@@ -1494,12 +1561,12 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4};
- std::unique_ptr<Literal> result_literal =
+ Literal result_literal =
LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
- EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result));
}
TEST_P(HloEvaluatorTest, StridedSlice) {
@@ -1511,7 +1578,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) {
// { 9, 10, 11, 12, 13 },
// { 17, 18, 19, 20, 21 },
// }
- auto operand_array = MakeUnique<Array2D<float>>(3, 5);
+ auto operand_array = absl::make_unique<Array2D<float>>(3, 5);
operand_array->FillUnique(1.0f);
auto operand_literal =
LiteralUtil::CreateR2FromArray2D<float>(*operand_array);
@@ -1526,14 +1593,14 @@ TEST_P(HloEvaluatorTest, StridedSlice) {
/*strides=*/{2, 3}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({
{3},
{19},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DynamicSlice) {
@@ -1544,7 +1611,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) {
// { 1, 2, 3, 4 },
// { 5, 6, 7, 8 },
// }
- auto operand_array = MakeUnique<Array2D<float>>(2, 4);
+ auto operand_array = absl::make_unique<Array2D<float>>(2, 4);
operand_array->FillUnique(1.0f);
auto operand_literal =
LiteralUtil::CreateR2FromArray2D<float>(*operand_array);
@@ -1560,14 +1627,14 @@ TEST_P(HloEvaluatorTest, DynamicSlice) {
start_indices, {2, 3}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({
{2, 3, 4},
{6, 7, 8},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
// Verifies that the HloEvaluator's implementation goes along with existing
@@ -1580,7 +1647,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) {
// { 1, 2, 3, 4 },
// { 5, 6, 7, 8 },
// }
- auto operand_array = MakeUnique<Array2D<float>>(2, 4);
+ auto operand_array = absl::make_unique<Array2D<float>>(2, 4);
operand_array->FillUnique(1.0f);
auto operand_literal =
LiteralUtil::CreateR2FromArray2D<float>(*operand_array);
@@ -1596,14 +1663,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) {
start_indices, {2, 3}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<float>({
{2, 3, 4},
{6, 7, 8},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
@@ -1614,7 +1681,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
// { 1, 2, 3 },
// { 5, 6, 7 },
// }
- auto operand_array = MakeUnique<Array2D<double>>(2, 3);
+ auto operand_array = absl::make_unique<Array2D<double>>(2, 3);
operand_array->FillUnique(1.0);
auto operand_literal =
LiteralUtil::CreateR2FromArray2D<double>(*operand_array);
@@ -1633,14 +1700,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
shape, operand, update, start_indices));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<double>({
{1, -2, -3},
{5, -6, -7},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, SetAndGetTuples) {
@@ -1651,7 +1718,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) {
// { 1, 2, 3 },
// { 5, 6, 7 },
// }
- auto operand_array = MakeUnique<Array2D<double>>(2, 3);
+ auto operand_array = absl::make_unique<Array2D<double>>(2, 3);
operand_array->FillUnique(1.0);
auto operand_literal2 =
LiteralUtil::CreateR2FromArray2D<double>(*operand_array);
@@ -1669,14 +1736,14 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto expected = LiteralUtil::CreateR2<double>({
{1, 2, 3},
{5, 6, 7},
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
@@ -1687,7 +1754,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
// { 1, 2, 3 },
// { 5, 6, 7 },
// }
- auto operand_array = MakeUnique<Array2D<double>>(2, 3);
+ auto operand_array = absl::make_unique<Array2D<double>>(2, 3);
operand_array->FillUnique(1.0);
HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant(
@@ -1708,16 +1775,14 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
auto result_inner_literal =
LiteralUtil::CreateR2FromArray2D<double>(*operand_array);
- auto expected = LiteralUtil::MakeTuple({
- result_inner_literal.get(),
- result_inner_literal.get(),
- });
+ auto expected =
+ LiteralUtil::MakeTuple({&result_inner_literal, &result_inner_literal});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, Reverse) {
@@ -1748,7 +1813,7 @@ TEST_P(HloEvaluatorTest, Reverse) {
b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1}));
module().AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = Evaluate();
+ Literal result = Evaluate();
// clang-format off
auto expected = LiteralUtil::CreateR4FromArray4D<float>({
@@ -1770,7 +1835,7 @@ TEST_P(HloEvaluatorTest, Reverse) {
});
// clang-format on
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
@@ -1786,12 +1851,13 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
// Evaluate add with param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}.
HloEvaluator evaluator;
+ Literal param0_literal = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
+ Literal square_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40});
auto result = evaluator.EvaluateWithSubstitutions(
- add, {{param0, LiteralUtil::CreateR1<float>({1, 2, 3, 4}).get()},
- {square, LiteralUtil::CreateR1<float>({10, 20, 30, 40}).get()}});
+ add, {{param0, &param0_literal}, {square, &square_literal}});
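+ // Note: CreateR1 now returns a Literal by value rather than a unique_ptr,
+ // so the named locals above provide addresses that stay valid for the
+ // substitution map.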
TF_ASSERT_OK(result.status());
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
+ LiteralUtil::CreateR1<float>({11, 22, 33, 44}), result.ValueOrDie()));
}
// Check that EvaluateWithSubstitutions works if one of the operands to the op
@@ -1811,11 +1877,12 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) {
// Evaluate add with square = {10, 20, 30, 40}.
HloEvaluator evaluator;
- auto result = evaluator.EvaluateWithSubstitutions(
- add, {{square, LiteralUtil::CreateR1<float>({10, 20, 30, 40}).get()}});
+ Literal square_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40});
+ auto result =
+ evaluator.EvaluateWithSubstitutions(add, {{square, &square_literal}});
TF_ASSERT_OK(result.status());
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
+ LiteralUtil::CreateR1<float>({11, 22, 33, 44}), result.ValueOrDie()));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
@@ -1826,21 +1893,20 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[2,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1, 3}
+ slice_sizes={1, 3}
}
)";
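  // Note the renamed gather attributes used above: output_window_dims ->
  // offset_dims, elided_window_dims -> collapsed_slice_dims,
  // gather_dims_to_operand_dims -> start_index_map, and window_bounds ->
  // slice_sizes.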
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
@@ -1851,21 +1917,20 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[3,2] gather(operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
@@ -1876,22 +1941,21 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,3,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR3<int32>(
+ LiteralUtil::CreateR3<int32>(
{{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
@@ -1902,23 +1966,22 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest,
@@ -1930,23 +1993,22 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
@@ -1957,21 +2019,19 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[1,1] gather(operand, indices),
- output_window_dims={0,1},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={0,1},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{5}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ Literal start_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{5}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
@@ -1982,21 +2042,20 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,1,1] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR3<int32>({{{8}}, {{5}}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR3<int32>({{{8}}, {{5}}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
@@ -2007,20 +2066,18 @@ ENTRY main {
operand = s32[3,0] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[2,0] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1, 0}
+ slice_sizes={1, 0}
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{}, {}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{}, {}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) {
@@ -2031,21 +2088,21 @@ ENTRY main {
operand = s32[3] parameter(0)
indices = s32[2,2,1] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1}
+ slice_sizes={1}
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
- std::unique_ptr<Literal> gather_indices =
+ Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
+ Literal start_indices =
LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}}),
+ Evaluate({&operand, &start_indices})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) {
@@ -2070,15 +2127,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) {
@@ -2103,15 +2158,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates =
LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) {
@@ -2137,15 +2191,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) {
@@ -2171,15 +2223,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_F32) {
@@ -2205,17 +2255,15 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<float>(
+ Literal operand = LiteralUtil::CreateR2<float>(
{{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({2, 1});
- std::unique_ptr<Literal> updates =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({2, 1});
+ Literal updates =
LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>(
+ LiteralUtil::CreateR2<float>(
{{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()}),
- ErrorSpec{0.1, 0.01}));
+ Evaluate({&operand, &scatter_indices, &updates}), ErrorSpec{0.1, 0.01}));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) {
@@ -2241,15 +2289,13 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) {
@@ -2275,15 +2321,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ Literal updates = LiteralUtil::CreateR3<int32>(
{{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}),
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ LiteralUtil::CreateR2<int32>({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}),
+ Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) {
@@ -2308,21 +2353,18 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
- std::unique_ptr<Literal> expected =
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
+ Literal expected =
LiteralUtil::CreateR3<int32>({{{-10, 10}, {-2, 2}, {-3, 3}}, //
{{-40, 40}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest,
@@ -2348,21 +2390,18 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
- std::unique_ptr<Literal> expected =
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
+ Literal expected =
LiteralUtil::CreateR3<int32>({{{-20, 20}, {-10, 10}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) {
@@ -2387,16 +2426,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{10}});
- std::unique_ptr<Literal> expected =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10}});
+ Literal expected =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) {
@@ -2421,17 +2458,14 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
- std::unique_ptr<Literal> expected =
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ Literal updates = LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
+ Literal expected =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) {
@@ -2456,13 +2490,11 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{}, {}});
+ Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{}, {}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *operand,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ operand, Evaluate({&operand, &scatter_indices, &updates})));
}
TEST_P(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) {
@@ -2489,16 +2521,13 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
- std::unique_ptr<Literal> scatter_indices =
+ Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
+ Literal scatter_indices =
LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
- std::unique_ptr<Literal> expected =
- LiteralUtil::CreateR1<int32>({10, 61, 32});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
+ Literal expected = LiteralUtil::CreateR1<int32>({10, 61, 32});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *expected,
- *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+ expected, Evaluate({&operand, &scatter_indices, &updates})));
}
// Verifies that HloEvaluator evaluates an HLO instruction that performs
@@ -2517,6 +2546,30 @@ TEST_P(HloEvaluatorTest, DoesCompareBF16) {
std::move(rhs));
}
+TEST_P(HloEvaluatorTest, Bf16Reduction) {
+ const string hlo_text = R"(
+HloModule Bf16Reduction
+
+add_bf16 (lhs: bf16[], rhs: bf16[]) -> bf16[] {
+ lhs = bf16[] parameter(0)
+ rhs = bf16[] parameter(1)
+ ROOT add = bf16[] add(bf16[] lhs, bf16[] rhs)
+}
+
+ENTRY main {
+ arg0 = bf16[4]{0} parameter(0)
+ init = bf16[] constant(0)
+ ROOT %reduce = bf16[] reduce(arg0, init), dimensions={0}, to_apply=add_bf16
+}
+)";
+ ParseAndVerifyModule(hlo_text);
+
+ Literal arg = LiteralUtil::CreateR1<bfloat16>(
+ {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)});
+ Literal expected = LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));
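+  // 1 + 3 + (-2) + 42 = 44, which is exactly representable in bfloat16.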
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate({&arg})));
+}
+
INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest,
::testing::ValuesIn(use_bf16_params));
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 084b49b478..7f090a52db 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -16,11 +16,16 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
+#include "absl/algorithm/container.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/memory/memory.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/core/lib/core/casts.h"
-#include "tensorflow/core/lib/gtl/optional.h"
namespace xla {
@@ -86,6 +91,29 @@ bool SafeLess(const NativeT& a, const NativeT& b) {
// of this class.
template <typename ReturnT, typename ElementwiseT = ReturnT>
class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
+ private:
+  // Get the value in the given literal, static_cast to a double.
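+  // (This lets HandleReduce, below, accumulate floating-point additions in
+  // double precision.)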
+ template <
+ typename NativeT,
+ typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
+ double GetAsDouble(const Literal& literal,
+ absl::Span<const int64> input_index) {
+ return static_cast<double>(literal.Get<NativeT>(input_index));
+ }
+
+  // Specialization for complex types. In this case it is not possible to
+  // static_cast the value to a double, so just CHECK-fail. This method is not used
+ // at run-time, but must be available at compile-time to keep the compiler
+ // happy.
+ template <
+ typename NativeT,
+ typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
+ double GetAsDouble(const Literal& literal,
+ absl::Span<const int64> input_index) {
+ LOG(FATAL) << "Trying to get complex literal as double: "
+ << literal.ToString();
+ }
+
public:
explicit HloEvaluatorTypedVisitor(HloEvaluator* p) : parent_(p) {}
@@ -117,7 +145,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
Status DefaultAction(HloInstruction* hlo_instruction) override {
return Unimplemented("unhandled HLO ops for HloEvaluator: %s.",
- HloOpcodeString(hlo_instruction->opcode()).c_str());
+ HloOpcodeString(hlo_instruction->opcode()));
}
// TODO(b/35950897): many of the stl functions used in the handlers are not
@@ -218,15 +246,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
Status HandleConvert(HloInstruction* convert) override {
const HloInstruction* operand = convert->operand(0);
TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
+ TF_ASSIGN_OR_RETURN(Literal result,
parent_->GetEvaluatedLiteralFor(operand).Convert(
convert->shape().element_type()));
- if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
+ if (LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) {
parent_->evaluated_[convert] = std::move(result);
} else {
- parent_->evaluated_[convert] =
- result->Relayout(convert->shape().layout());
+ parent_->evaluated_[convert] = result.Relayout(convert->shape().layout());
}
return Status::OK();
}
@@ -234,15 +261,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
Status HandleBitcastConvert(HloInstruction* convert) override {
const HloInstruction* operand = convert->operand(0);
TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
+ TF_ASSIGN_OR_RETURN(Literal result,
parent_->GetEvaluatedLiteralFor(operand).BitcastConvert(
convert->shape().element_type()));
- if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
+ if (LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) {
parent_->evaluated_[convert] = std::move(result);
} else {
- parent_->evaluated_[convert] =
- result->Relayout(convert->shape().layout());
+ parent_->evaluated_[convert] = result.Relayout(convert->shape().layout());
}
return Status::OK();
}
@@ -525,7 +551,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
- Status HandleDivide(HloInstruction* divide) override {
+ template <
+ typename NativeT,
+ typename std::enable_if<std::is_floating_point<NativeT>::value ||
+ is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleDivide(HloInstruction* divide) {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide],
ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem,
ElementwiseT rhs_elem) {
@@ -535,6 +565,46 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename NativeT,
+ typename std::enable_if<std::is_signed<NativeT>::value &&
+ std::is_integral<NativeT>::value>::type* =
+ nullptr>
+ Status HandleDivide(HloInstruction* divide) {
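+    // Signed integer division with the C++-undefined cases given defined
+    // results: x / 0 evaluates to -1, and the overflowing case (minimum
+    // value divided by -1) evaluates to the dividend.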
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[divide],
+ ElementWiseBinaryOp(
+ divide,
+ [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) -> ElementwiseT {
+ if (rhs_elem == 0) {
+ return static_cast<ElementwiseT>(-1);
+ }
+ if (rhs_elem == -1 &&
+ lhs_elem == std::numeric_limits<ElementwiseT>::min()) {
+ return lhs_elem;
+ }
+ return lhs_elem / rhs_elem;
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<std::is_unsigned<NativeT>::value>::type* =
+ nullptr>
+ Status HandleDivide(HloInstruction* divide) {
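+    // Unsigned integer division: x / 0 evaluates to the maximum value of
+    // the type instead of being undefined behavior.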
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide],
+ ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem,
+ ElementwiseT rhs_elem) {
+ return rhs_elem == 0
+ ? std::numeric_limits<ElementwiseT>::max()
+ : (lhs_elem / rhs_elem);
+ }));
+ return Status::OK();
+ }
+
+ Status HandleDivide(HloInstruction* divide) {
+ return HandleDivide<ElementwiseT>(divide);
+ }
+
+ template <typename NativeT,
typename std::enable_if<std::is_integral<NativeT>::value>::type* =
nullptr>
Status HandleMaximum(HloInstruction* maximum) {
@@ -620,9 +690,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
- template <
- typename NativeT,
- typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
+ template <typename NativeT, typename std::enable_if<std::is_floating_point<
+ NativeT>::value>::type* = nullptr>
Status HandleRemainder(HloInstruction* remainder) {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder],
ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el,
@@ -632,6 +701,40 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
+ template <typename NativeT,
+ typename std::enable_if<std::is_unsigned<NativeT>::value>::type* =
+ nullptr>
+ Status HandleRemainder(HloInstruction* remainder) {
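+    // Unsigned integer remainder: x % 0 evaluates to x instead of being
+    // undefined behavior.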
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder],
+ ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el,
+ ElementwiseT rhs_el) {
+ return rhs_el == 0 ? lhs_el : (lhs_el % rhs_el);
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<std::is_signed<NativeT>::value &&
+ std::is_integral<NativeT>::value>::type* =
+ nullptr>
+ Status HandleRemainder(HloInstruction* remainder) {
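+    // Signed integer remainder with the C++-undefined cases given defined
+    // results: x % 0 evaluates to x, and the overflowing case (minimum
+    // value modulo -1) evaluates to 0.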
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[remainder],
+ ElementWiseBinaryOp(
+ remainder,
+ [](ElementwiseT lhs_el, ElementwiseT rhs_el) -> ElementwiseT {
+ if (rhs_el == 0) {
+ return lhs_el;
+ }
+ if (rhs_el == -1 &&
+ lhs_el == std::numeric_limits<ElementwiseT>::min()) {
+ return 0;
+ }
+ return lhs_el % rhs_el;
+ }));
+ return Status::OK();
+ }
+
template <
typename NativeT,
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
@@ -873,10 +976,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
<< ShapeUtil::HumanString(inferred_return_shape);
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
- auto result = MakeUnique<Literal>(result_shape);
+ Literal result(result_shape);
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> out_index) {
+ TF_RETURN_IF_ERROR(
+ result.Populate<ReturnT>([&](absl::Span<const int64> out_index) {
std::vector<int64> from_index(out_index.begin(), out_index.end());
for (const int64 dim : reverse_dimensions) {
from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim];
@@ -916,9 +1019,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
CHECK_EQ(num_spatial_dims + 2, lhs_rank);
CHECK_EQ(num_spatial_dims + 2, rhs_rank);
- TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape,
- window, dnums));
+ TF_ASSIGN_OR_RETURN(
+ auto inferred_return_shape,
+ ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, conv->feature_group_count(), window, dnums));
CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
<< "return shape set to: " << ShapeUtil::HumanString(result_shape)
<< " but is inferred to be: "
@@ -941,10 +1045,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto lhs_literal_data = lhs_literal.data<ReturnT>();
auto rhs_literal_data = rhs_literal.data<ReturnT>();
+ int64 feature_group_count = conv->feature_group_count();
+
auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window,
&lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data,
- rhs_literal_data](
- tensorflow::gtl::ArraySlice<int64> out_index) {
+ rhs_literal_data,
+ feature_group_count](const absl::Span<const int64> out_index) {
// Dimension number applicable for input (lhs).
const int64 input_batch_dim = dnums.input_batch_dimension();
const int64 input_z_dim = dnums.input_feature_dimension();
@@ -955,7 +1061,22 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const int64 output_batch_dim = dnums.output_batch_dimension();
const int64 output_z_dim = dnums.output_feature_dimension();
- const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim);
+ const int64 input_z_size =
+ ShapeUtil::GetDimension(lhs_shape, input_z_dim);
+ // The size of an input feature group.
+ const int64 input_feature_group_size = input_z_size / feature_group_count;
+
+ const int64 output_z_size =
+ ShapeUtil::GetDimension(rhs_shape, kernel_output_z_dim);
+ // The output feature dimension is a concatenation of convolution results
+ // from the different groups.
+ const int64 output_feature_group_size =
+ output_z_size / feature_group_count;
+
+ // Calculate the group index to which the current output index
+ // belongs.
+ const int64 feature_group_index =
+ out_index[output_z_dim] / output_feature_group_size;
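+      // For example, with input_z_size = 4 and feature_group_count = 2, the
+      // first half of the output features reads input features [0, 2) and
+      // the second half reads input features [2, 4).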
ElementwiseT result_val = static_cast<ElementwiseT>(0);
DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(),
@@ -963,7 +1084,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Convolve input feature with kernel.
do {
- for (int64 iz = 0; iz < z_size; ++iz) {
+ for (int64 rhs_iz = 0; rhs_iz < input_feature_group_size; ++rhs_iz) {
+ const int64 iz =
+ feature_group_index * input_feature_group_size + rhs_iz;
+
int64 lhs_linear_index = 0;
lhs_linear_index += out_index[output_batch_dim] *
lhs_dim_multipliers[input_batch_dim];
@@ -972,7 +1096,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
int64 rhs_linear_index = 0;
rhs_linear_index += out_index[output_z_dim] *
rhs_dim_multipliers[kernel_output_z_dim];
- rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim];
+ rhs_linear_index += rhs_iz * rhs_dim_multipliers[kernel_input_z_dim];
// Find corresponding spatial dimension index for input (lhs).
for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
@@ -1025,13 +1149,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
static_cast<ElementwiseT>(rhs_literal_data[rhs_linear_index]);
}
- } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));
+ } while (IndexUtil::BumpIndices(window_shape,
+ absl::MakeSpan(rhs_spatial_index)));
return static_cast<ReturnT>(result_val);
};
- auto result = MakeUnique<Literal>(result_shape);
- TF_RETURN_IF_ERROR(result->PopulateParallel<ReturnT>(func));
+ Literal result(result_shape);
+ TF_RETURN_IF_ERROR(result.PopulateParallel<ReturnT>(func));
parent_->evaluated_[conv] = std::move(result);
return Status::OK();
@@ -1078,7 +1203,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// result_index_locations[i] contains one or two pointers to the locations
// in lhs_index or rhs_index where the i'th result index should go.
- tensorflow::gtl::InlinedVector<std::pair<int64*, int64*>, kInlineRank>
+ absl::InlinedVector<std::pair<int64*, int64*>, kInlineRank>
result_index_locations;
result_index_locations.reserve(lhs_rank + rhs_rank - 2);
@@ -1093,20 +1218,20 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Then we have the LHS and RHS non-contracting dimensions, if any:
for (int64 i = 0; i < lhs_rank; i++) {
if (i != lhs_contracting_dimension &&
- !ArrayContains(AsInt64Slice(dnums.lhs_batch_dimensions()), i)) {
+ !absl::c_linear_search(dnums.lhs_batch_dimensions(), i)) {
result_index_locations.push_back({&lhs_index[i], nullptr});
}
}
for (int64 i = 0; i < rhs_rank; i++) {
if (i != rhs_contracting_dimension &&
- !ArrayContains(AsInt64Slice(dnums.rhs_batch_dimensions()), i)) {
+ !absl::c_linear_search(dnums.rhs_batch_dimensions(), i)) {
result_index_locations.push_back({&rhs_index[i], nullptr});
}
}
- auto result = MakeUnique<Literal>(dot->shape());
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> result_index) {
+ Literal result(dot->shape());
+ TF_RETURN_IF_ERROR(
+ result.Populate<ReturnT>([&](absl::Span<const int64> result_index) {
ElementwiseT result_val = static_cast<ElementwiseT>(0);
for (int64 i = 0; i < result_index.size(); i++) {
@@ -1153,24 +1278,22 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Create new HLO of padded shape with padding value.
ReturnT scalar =
parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
- auto result = MakeUnique<Literal>(pad->shape());
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&scalar](tensorflow::gtl::ArraySlice<int64> multi_index) {
- return scalar;
- }));
+ Literal result(pad->shape());
+ TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
+ [&scalar](absl::Span<const int64> multi_index) { return scalar; }));
const Literal& evaluated_operand =
parent_->GetEvaluatedLiteralFor(pad->operand(0));
std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()),
0);
- std::vector<int64> target_index(ShapeUtil::Rank(result->shape()), 0);
+ std::vector<int64> target_index(ShapeUtil::Rank(result.shape()), 0);
  // Loop through each element of the operand, assigning each to the
// corresponding index of the resulting padded literal.
const PaddingConfig& pad_config = pad->padding_config();
- auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) {
+ auto func = [&](absl::Span<const int64> input_index) {
for (auto i = 0; i < input_index.size(); ++i) {
// Interior padding occurs logically before edge padding, so in the case
// of negative edge padding elements are removed from the
@@ -1186,8 +1309,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return true;
}
}
- result->Set<ReturnT>(target_index,
- evaluated_operand.Get<ReturnT>(input_index));
+ result.Set<ReturnT>(target_index,
+ evaluated_operand.Get<ReturnT>(input_index));
return true;
};
@@ -1314,16 +1437,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename NativeT>
- StatusOr<std::unique_ptr<Literal>> MapImpl(HloInstruction* map) {
+ StatusOr<Literal> MapImpl(HloInstruction* map) {
auto operands = map->operands();
HloComputation* computation = map->to_apply();
- auto result = MakeUnique<Literal>(map->shape());
+ Literal result(map->shape());
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
- std::vector<std::unique_ptr<Literal>> arg_literals;
+ TF_RETURN_IF_ERROR(
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+ std::vector<Literal> arg_literals;
arg_literals.reserve(operands.size());
// Construct scalar literal parameters to be passed to the map
@@ -1338,16 +1461,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
arg_literals.push_back(std::move(curr_val_literal));
}
- std::unique_ptr<Literal> computed_result =
- embedded_evaluator
- .Evaluate<std::unique_ptr<Literal>>(*computation,
- arg_literals)
+ Literal computed_result =
+ embedded_evaluator.Evaluate<Literal>(*computation, arg_literals)
.ConsumeValueOrDie();
          // Clear visit states so that we can use the evaluator again on
// the same computation.
embedded_evaluator.ResetVisitStates();
- return computed_result->Get<ReturnT>({});
+ return computed_result.Get<ReturnT>({});
}));
return std::move(result);
}
@@ -1432,10 +1553,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
[](const ReturnT& a, const ReturnT& b) {
return SafeLess<ReturnT>(a, b);
});
- auto result_literal = MakeUnique<Literal>(keys_literal.shape());
- result_literal->PopulateR1(
- tensorflow::gtl::ArraySlice<ReturnT>(result_data));
- VLOG(3) << "HandleSort result_literal: " << result_literal->ToString();
+ Literal result_literal(keys_literal.shape());
+ result_literal.PopulateR1(absl::Span<const ReturnT>(result_data));
+ VLOG(3) << "HandleSort result_literal: " << result_literal.ToString();
return result_literal;
};
@@ -1444,16 +1564,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
} else {
// For R2 sort, the desired semantics are to sort each matrix row
// independently.
- auto result_literal = MakeUnique<Literal>(keys_literal.shape());
+ Literal result_literal(keys_literal.shape());
int64 r1_length = keys->shape().dimensions(1);
for (int64 row = 0; row < keys->shape().dimensions(0); ++row) {
TF_ASSIGN_OR_RETURN(auto r1_slice,
keys_literal.Slice({row, 0}, {row + 1, r1_length})
- ->Reshape({r1_length}));
- auto r1_result = sort_r1(*r1_slice);
- TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length}));
- TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
- *r1_result, {0, 0}, {row, 0}, {1, r1_length}));
+ .Reshape({r1_length}));
+ auto r1_result = sort_r1(r1_slice);
+ TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length}));
+ TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
+ r1_result, {0, 0}, {row, 0}, {1, r1_length}));
}
parent_->evaluated_[sort] = std::move(result_literal);
}
@@ -1472,20 +1592,20 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return HandleSort<ReturnT>(sort);
}
- Status HandleReduce(HloInstruction* reduce) override {
- // TODO(b/112040122): Support variadic reduce.
- if (!ShapeUtil::IsArray(reduce->shape())) {
- return Unimplemented("Variadic reduce is not supported in the Evaluator");
- }
- auto arg = reduce->operand(0);
- auto init_value = reduce->operand(1);
- tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+ Status HandleReduce(HloInstruction* hlo) override {
+ HloReduceInstruction* reduce = Cast<HloReduceInstruction>(hlo);
+ int64 num_args = reduce->inputs().size();
+ bool has_tuple_output = ShapeUtil::IsTuple(reduce->shape());
+ absl::Span<const int64> dimensions(reduce->dimensions());
HloComputation* function = reduce->to_apply();
- TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) ==
- ShapeUtil::Rank(arg->shape()) - dimensions.size());
+
+ absl::InlinedVector<const Shape*, 1> operand_shapes;
+ for (const HloInstruction* operand : reduce->operands()) {
+ operand_shapes.push_back(&operand->shape());
+ }
TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
ShapeInference::InferReduceShape(
- {&arg->shape(), &init_value->shape()},
+ operand_shapes,
/*dimensions_to_reduce=*/dimensions,
/*to_apply=*/function->ComputeProgramShape()));
TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape))
@@ -1493,14 +1613,23 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
<< " but is inferred to be: "
<< ShapeUtil::HumanString(inferred_return_shape);
- const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg);
- VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString();
- const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value);
- VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString();
- TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
- auto init_scalar = init_literal.Get<ReturnT>({});
+ absl::InlinedVector<const Literal*, 1> arg_literals(num_args);
+ absl::InlinedVector<const Literal*, 1> init_literals(num_args);
+ for (int64 i = 0; i < num_args; ++i) {
+ arg_literals[i] = &parent_->GetEvaluatedLiteralFor(reduce->inputs()[i]);
+ VLOG(3) << "HandleReduce arg_literal: " << arg_literals[i]->ToString();
+ init_literals[i] =
+ &parent_->GetEvaluatedLiteralFor(reduce->init_values()[i]);
+ VLOG(3) << "HandleReduce init_literal: " << init_literals[i]->ToString();
+ TF_RET_CHECK(ShapeUtil::IsScalar(init_literals[i]->shape()));
+ }
- const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions());
+ // All args and results have the same dimensions, so pick an arbitrary one.
+ const Shape& arg_shape = arg_literals[0]->shape();
+ const Shape& result_shape = ShapeUtil::IsTuple(reduce->shape())
+ ? reduce->shape().tuple_shapes(0)
+ : reduce->shape();
+ const auto arg_dimensions = AsInt64Slice(arg_shape.dimensions());
std::vector<int64> arg_dim_steps(arg_dimensions.size());
std::vector<int64> arg_dim_counts(arg_dimensions.size());
for (const int64 dim : dimensions) {
@@ -1518,61 +1647,107 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- auto result = MakeUnique<Literal>(reduce->shape());
- // For each resulting dimension, calculate and assign computed value.
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
- ReturnT result_val = init_scalar;
+ absl::InlinedVector<Literal, 1> results(num_args);
+ for (int64 i = 0; i < num_args; ++i) {
+ results[i] = Literal(result_shape);
+ }
- std::vector<int64> base(arg_dimensions.size());
- for (int64 i = 0; i < multi_index.size(); ++i) {
- base[result_to_arg_index[i]] = multi_index[i];
- }
+ Status eval_status;
+ // For each resulting dimension, calculate and assign computed values.
+ // This is really wasteful when num_args > 1, since we re-run the
+  // reduction num_args times. The alternative is to teach Populate() about
+ // tuples, which we should probably do.
+ absl::InlinedVector<ReturnT, 1> init_scalars(num_args);
+ for (int i = 0; i < num_args; ++i) {
+ init_scalars[i] = init_literals[i]->Get<ReturnT>({});
+ }
+
+ for (int64 input = 0; input < num_args; ++input) {
+ TF_RETURN_IF_ERROR(results[input].Populate<ReturnT>(
+ [&](absl::Span<const int64> multi_index) {
+ if (!eval_status.ok()) {
+ return init_scalars[input];
+ }
+ absl::InlinedVector<ReturnT, 1> result_values(init_scalars.begin(),
+ init_scalars.end());
+ std::vector<int64> base(arg_dimensions.size());
+ for (int64 i = 0; i < multi_index.size(); ++i) {
+ base[result_to_arg_index[i]] = multi_index[i];
+ }
- // When the reduction is addition of floats, accumulate in a double
- // for better precision. Also, avoid creating Literals for the
- // intermediate results; it's much faster.
- if (ShapeUtil::ElementIsFloating(init_literal.shape()) &&
- IsScalarAdd(function)) {
- double computed_result = 0;
- auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) {
- computed_result += arg_literal.Get<float>(input_index);
+ // When the reduction is addition of floats, accumulate in a double
+ // for better precision. Also, avoid creating Literals for the
+ // intermediate results; it's much faster.
+ if (ShapeUtil::ElementIsFloating(init_literals[0]->shape()) &&
+ IsScalarAdd(function)) {
+ CHECK_EQ(num_args, 1);
+ double computed_result = 0;
+ auto func = [&](absl::Span<const int64> input_index) {
+ computed_result +=
+ GetAsDouble<ReturnT>(*arg_literals[0], input_index);
+ return true;
+ };
+ ShapeUtil::ForEachIndex(arg_literals[0]->shape(), base,
+ arg_dim_counts, arg_dim_steps, func);
+ return static_cast<ReturnT>(computed_result);
+ }
+ auto func =
+ [&](absl::Span<const int64> input_index) -> StatusOr<bool> {
+ absl::InlinedVector<ReturnT, 1> arg_values(num_args);
+ for (int64 i = 0; i < num_args; ++i) {
+ arg_values[i] = arg_literals[i]->Get<ReturnT>(input_index);
+ }
+
+ // Evaluate computation with specified literal operands.
+ absl::InlinedVector<Literal, 1> embedded_operands;
+ for (ReturnT value : result_values) {
+ embedded_operands.push_back(
+ LiteralUtil::CreateR0<ReturnT>(value));
+ }
+ for (ReturnT value : arg_values) {
+ embedded_operands.push_back(
+ LiteralUtil::CreateR0<ReturnT>(value));
+ }
+ absl::InlinedVector<Literal*, 1> embedded_operands_ptrs(
+ embedded_operands.size());
+ std::transform(embedded_operands.begin(), embedded_operands.end(),
+ embedded_operands_ptrs.begin(),
+ [](Literal& literal) { return &literal; });
+
+ TF_ASSIGN_OR_RETURN(Literal computed_result,
+ embedded_evaluator.Evaluate<const Literal*>(
+ *function, embedded_operands_ptrs));
+ // Clear visit states so that we can use the evaluator again on
+ // the same computation.
+ embedded_evaluator.ResetVisitStates();
+                // Assign the computed result to result_values.
+ if (!has_tuple_output) {
+ result_values[0] = computed_result.Get<ReturnT>({});
+ } else {
+ for (int64 i = 0; i < num_args; ++i) {
+ result_values[i] = computed_result.Get<ReturnT>(
+ /*multi_index=*/{}, /*shape_index=*/{i});
+ }
+ }
return true;
};
- ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts,
- arg_dim_steps, func);
- return static_cast<ReturnT>(computed_result);
- }
- auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) {
- auto curr_val = arg_literal.Get<ReturnT>(input_index);
-
- // Evaluate computation with specified literal operands.
- auto curr_val_literal = LiteralUtil::CreateR0<ReturnT>(curr_val);
- auto result_val_literal =
- LiteralUtil::CreateR0<ReturnT>(result_val);
-
- std::unique_ptr<Literal> computed_result =
- embedded_evaluator
- .Evaluate<const Literal*>(
- *function,
- {result_val_literal.get(), curr_val_literal.get()})
- .ConsumeValueOrDie();
- // Clear visit states so that we can use the evaluator again on
- // the same computation.
- embedded_evaluator.ResetVisitStates();
- // Assign computed result to result_val.
- result_val = computed_result->Get<ReturnT>({});
- return true;
- };
- // Computes one element of the result, reducing all dimensions that
- // contribute to that element.
- ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts,
- arg_dim_steps, func);
- return result_val;
- }));
-
- parent_->evaluated_[reduce] = std::move(result);
- return Status::OK();
+ // Computes one element of the result, reducing all dimensions that
+ // contribute to that element.
+ eval_status = ShapeUtil::ForEachIndexWithStatus(
+ arg_shape, base, arg_dim_counts, arg_dim_steps, func);
+ return result_values[input];
+ }));
+ }
+ if (!has_tuple_output) {
+ parent_->evaluated_[reduce] = std::move(results[0]);
+ } else {
+ Literal tuple_result(reduce->shape());
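+    // Move each per-input result into the corresponding element of the
+    // result tuple.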
+ for (int64 i = 0; i < num_args; ++i) {
+ TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i}));
+ }
+ parent_->evaluated_[reduce] = std::move(tuple_result);
+ }
+ return eval_status;
}
bool IsScalarAdd(HloComputation* computation) {
@@ -1599,13 +1774,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
auto init_scalar = init_literal.Get<ReturnT>({});
- auto result = MakeUnique<Literal>(select_and_scatter->shape());
+ Literal result(select_and_scatter->shape());
// Initialize result array with the init value.
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> output_index) {
- return init_scalar;
- }));
+ TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
+ [&](absl::Span<const int64> output_index) { return init_scalar; }));
std::vector<int64> window_dimension_sizes;
for (const auto& window_dimension : window.dimensions()) {
@@ -1643,8 +1816,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
  // 2. Using the selected index, scatter the value from `source` to result. We
  // do this by iterating through the window, comparing each index with
// the selected index.
- tensorflow::gtl::optional<ReturnT> selected_val;
- tensorflow::gtl::optional<std::vector<int64>> selected_index;
+ absl::optional<ReturnT> selected_val;
+ absl::optional<std::vector<int64>> selected_index;
IterateThroughWindow(
window_shape, window, operand_literal.shape(), source_index,
@@ -1654,15 +1827,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
selected_val = curr_val;
selected_index = operand_index;
}
- curr_val_literal->Set({}, curr_val);
- selected_val_literal->Set({}, *selected_val);
- std::unique_ptr<Literal> computed_result =
+ curr_val_literal.Set({}, curr_val);
+ selected_val_literal.Set({}, *selected_val);
+ Literal computed_result =
embedded_evaluator
.Evaluate<const Literal*>(
- *select,
- {selected_val_literal.get(), curr_val_literal.get()})
+ *select, {&selected_val_literal, &curr_val_literal})
.ConsumeValueOrDie();
- bool selected = !computed_result->Get<bool>({});
+ bool selected = !computed_result.Get<bool>({});
if (selected) {
selected_val = curr_val;
selected_index = operand_index;
@@ -1676,22 +1848,23 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (std::equal(operand_index.begin(), operand_index.end(),
selected_index->begin())) {
auto source = source_literal.Get<ReturnT>(source_index);
- auto scattered = result->Get<ReturnT>(operand_index);
- source_literal_scatter->Set({}, source);
- scattered_literal->Set({}, scattered);
- std::unique_ptr<Literal> computed_result =
+ auto scattered = result.Get<ReturnT>(operand_index);
+ source_literal_scatter.Set({}, source);
+ scattered_literal.Set({}, scattered);
+ Literal computed_result =
embedded_evaluator
- .Evaluate<const Literal*>(*scatter,
- {source_literal_scatter.get(),
- scattered_literal.get()})
+ .Evaluate<const Literal*>(
+ *scatter,
+ {&source_literal_scatter, &scattered_literal})
.ConsumeValueOrDie();
- result->Set(operand_index, computed_result->Get<ReturnT>({}));
+ result.Set(operand_index, computed_result.Get<ReturnT>({}));
            // Clear visit states so that we can use the evaluator again
// on the same computation.
embedded_evaluator.ResetVisitStates();
}
});
- } while (IndexUtil::BumpIndices(source->shape(), &source_index));
+ } while (
+ IndexUtil::BumpIndices(source->shape(), absl::MakeSpan(source_index)));
parent_->evaluated_[select_and_scatter] = std::move(result);
return Status::OK();
@@ -1735,10 +1908,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape()));
HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
- auto result = MakeUnique<Literal>(reduce_window->shape());
+ Literal result(reduce_window->shape());
// For each resulting dimension, calculate and assign computed value.
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> output_index) {
+ TF_RETURN_IF_ERROR(
+ result.Populate<ReturnT>([&](absl::Span<const int64> output_index) {
ReturnT result_val = init_scalar;
std::fill(window_index.begin(), window_index.end(), 0);
@@ -1754,18 +1927,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
LiteralUtil::CreateR0<ReturnT>(curr_val);
const auto result_val_literal =
LiteralUtil::CreateR0<ReturnT>(result_val);
- std::unique_ptr<Literal> computed_result =
+ Literal computed_result =
embedded_evaluator
.Evaluate<const Literal*>(
- *function,
- {result_val_literal.get(), curr_val_literal.get()})
+ *function, {&result_val_literal, &curr_val_literal})
.ConsumeValueOrDie();
          // Clear visit states so that we can use the evaluator again
// on the same computation.
embedded_evaluator.ResetVisitStates();
- result_val = computed_result->Get<ReturnT>({});
+ result_val = computed_result.Get<ReturnT>({});
});
return result_val;
@@ -1780,7 +1952,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// literal (if there is one) to `reshaped_indices`.
StatusOr<std::reference_wrapper<const Literal>> ReshapedScatterIndices(
int64 index_vector_dim, const Literal& indices,
- std::unique_ptr<Literal>* reshaped_indices) {
+ Literal* reshaped_indices) {
if (indices.shape().dimensions_size() != index_vector_dim) {
return std::cref(indices);
}
@@ -1789,7 +1961,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
indices.shape().dimensions().end());
new_shape.push_back(1);
TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape));
- return std::cref(**reshaped_indices);
+ return std::cref(*reshaped_indices);
}
  // Returns a ShapeUtil::IndexIterationSpace that iterates over the update
@@ -1802,7 +1974,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> index_count(updates_rank, 1);
for (int64 i = 0; i < updates_rank; i++) {
bool is_update_scatter_dim =
- !c_binary_search(dim_numbers.update_window_dims(), i);
+ !absl::c_binary_search(dim_numbers.update_window_dims(), i);
if (is_update_scatter_dim) {
index_count[i] = updates_shape.dimensions(i);
}
@@ -1821,7 +1993,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> index_count(updates_rank, 1);
for (int64 i = 0; i < updates_rank; i++) {
bool is_update_window_dim =
- c_binary_search(dim_numbers.update_window_dims(), i);
+ absl::c_binary_search(dim_numbers.update_window_dims(), i);
if (is_update_window_dim) {
index_count[i] = updates_shape.dimensions(i);
}
@@ -1848,7 +2020,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
: dim_numbers_(*dim_numbers), scatter_indices_(*scatter_indices) {
for (int64 i = 0; i < updates_shape.dimensions_size(); i++) {
update_dim_is_scatter_dims_.push_back(
- !c_binary_search(dim_numbers_.update_window_dims(), i));
+ !absl::c_binary_search(dim_numbers_.update_window_dims(), i));
}
for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
@@ -1883,13 +2055,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// index_vector_index_ and index_vector on every invocation, we reuse the
// same storage for all invocations.
//
- // This returns an arrayslice into memory owned by the class.
- StatusOr<tensorflow::gtl::ArraySlice<int64>> operator()(
- tensorflow::gtl::ArraySlice<int64> update_index) {
+ // This returns a Span into memory owned by the class.
+ StatusOr<absl::Span<const int64>> operator()(
+ absl::Span<const int64> update_index) {
PropagateUpdateIndexScatterDimsToIndexVectorIndex(update_index);
TF_RETURN_IF_ERROR(FetchIndexVector());
PropagateIndexVectorToInputIndex();
- return tensorflow::gtl::ArraySlice<int64>(input_index_);
+ return absl::Span<const int64>(input_index_);
}
private:
@@ -1898,7 +2070,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// update the dim_numbers.index_vector_dim() dimension -- that's the
// dimension we iterate over in FetchIndexVector.
void PropagateUpdateIndexScatterDimsToIndexVectorIndex(
- tensorflow::gtl::ArraySlice<int64> update_index) {
+ absl::Span<const int64> update_index) {
int64 index_vector_index_i = 0;
for (int64 i = 0, e = update_index.size(); i < e; i++) {
if (!update_dim_is_scatter_dims_[i]) {
@@ -1953,7 +2125,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// The index vector fetched from scatter_indices_.
std::vector<int64> index_vector_;
- // The result computed by this functor. operator() returns an ArraySlice
+ // The result computed by this functor. operator() returns a Span
// into this vector.
std::vector<int64> input_index_;
@@ -1978,7 +2150,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> window_index_to_update_index;
int64 update_index_count = 0;
for (int64 i = 0; i < updates_shape.dimensions_size(); i++) {
- if (c_binary_search(dim_numbers.update_window_dims(), i)) {
+ if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) {
window_index_to_update_index.push_back(update_index_count++);
} else {
update_index_count++;
@@ -1987,7 +2159,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
int64 window_dim_count = 0;
for (int64 i = 0; i < input_shape.dimensions_size(); i++) {
- if (c_binary_search(dim_numbers.inserted_window_dims(), i)) {
+ if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) {
input_dim_value_to_update_index_.push_back(-1);
} else {
input_dim_value_to_update_index_.push_back(
@@ -2006,11 +2178,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// scatter input index on every invocation we reuse the same storage for the
// result (input_index_), mutating it in place.
//
- // This returns an arrayslice into memory owned by the class.
- StatusOr<tensorflow::gtl::ArraySlice<int64>> operator()(
- tensorflow::gtl::ArraySlice<int64> update_index) {
+ // This returns a Span into memory owned by the class.
+ StatusOr<absl::Span<const int64>> operator()(
+ absl::Span<const int64> update_index) {
PropagateUpdateIndexWindowDimsToInputIndex(update_index);
- return tensorflow::gtl::ArraySlice<int64>(input_index_);
+ return absl::Span<const int64>(input_index_);
}
// Returns for a given 'input_dim' the corresponding update dimension index,
@@ -2023,7 +2195,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Propagates window dimensions from the update index to input_index_ by
// mutating input_index_ in place.
void PropagateUpdateIndexWindowDimsToInputIndex(
- tensorflow::gtl::ArraySlice<int64> update_index) {
+ absl::Span<const int64> update_index) {
for (int64 i = 0, e = input_index_.size(); i < e; i++) {
if (input_dim_value_to_update_index_[i] != -1) {
input_index_[i] = update_index[input_dim_value_to_update_index_[i]];
@@ -2039,7 +2211,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// PropagateUpdateIndexWindowDimsToInputIndex.
std::vector<int64> input_dim_value_to_update_index_;
- // The result computed by this functor. operator() returns an ArraySlice
+ // The result computed by this functor. operator() returns a Span
// into this vector.
std::vector<int64> input_index_;
};
@@ -2049,7 +2221,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
scatter->scatter_dimension_numbers();
const Literal& operand =
parent_->GetEvaluatedLiteralFor(scatter->operand(0));
- std::unique_ptr<Literal> reshaped_scatter_indices;
+ Literal reshaped_scatter_indices;
TF_ASSIGN_OR_RETURN(const Literal& scatter_indices,
ReshapedScatterIndices(dim_numbers.index_vector_dim(),
parent_->GetEvaluatedLiteralFor(
@@ -2079,15 +2251,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Initialize the result with the operand. This makes it easier to handle
// the updates even when the indices are repeated.
- std::unique_ptr<Literal> result = operand.CloneToUnique();
+ Literal result = operand.Clone();
HloEvaluator embedded_evaluator;
auto scatter_inner_loop_body =
- [&](tensorflow::gtl::ArraySlice<int64> update_window_index,
- tensorflow::gtl::ArraySlice<int64> input_scatter_index,
- tensorflow::gtl::ArraySlice<int64> update_scatter_index)
- -> StatusOr<bool> {
+ [&](absl::Span<const int64> update_window_index,
+ absl::Span<const int64> input_scatter_index,
+ absl::Span<const int64> update_scatter_index) -> StatusOr<bool> {
TF_ASSIGN_OR_RETURN(
- tensorflow::gtl::ArraySlice<int64> input_window_index,
+ absl::Span<const int64> input_window_index,
update_window_index_to_input_index(update_window_index));
for (int i = 0, e = update_index.size(); i < e; i++) {
update_index[i] = update_scatter_index[i] + update_window_index[i];
@@ -2119,31 +2290,30 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
auto result_value_literal =
- LiteralUtil::CreateR0<ReturnT>(result->Get<ReturnT>(input_index));
+ LiteralUtil::CreateR0<ReturnT>(result.Get<ReturnT>(input_index));
auto update_value_literal =
LiteralUtil::CreateR0<ReturnT>(updates.Get<ReturnT>(update_index));
- std::unique_ptr<Literal> updated_result =
+ Literal updated_result =
embedded_evaluator
.Evaluate<const Literal*>(
*scatter->to_apply(),
- {result_value_literal.get(), update_value_literal.get()})
+ {&result_value_literal, &update_value_literal})
.ConsumeValueOrDie();
      // Clear visit states so that we can use the evaluator again on the
// same computation.
embedded_evaluator.ResetVisitStates();
- result->Set<ReturnT>(input_index, updated_result->Get<ReturnT>({}));
+ result.Set<ReturnT>(input_index, updated_result.Get<ReturnT>({}));
return true;
};
auto scatter_outer_loop_body =
- [&](tensorflow::gtl::ArraySlice<int64> update_scatter_index)
- -> StatusOr<bool> {
+ [&](absl::Span<const int64> update_scatter_index) -> StatusOr<bool> {
TF_ASSIGN_OR_RETURN(
- tensorflow::gtl::ArraySlice<int64> input_scatter_index,
+ absl::Span<const int64> input_scatter_index,
update_scatter_index_to_input_index(update_scatter_index));
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
updates_shape, window_indices_iteration_space,
- [&](tensorflow::gtl::ArraySlice<int64> update_window_index) {
+ [&](absl::Span<const int64> update_window_index) {
return scatter_inner_loop_body(
update_window_index, input_scatter_index, update_scatter_index);
}));
@@ -2171,7 +2341,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const int64 rank = ShapeUtil::Rank(operand->shape());
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
- auto func = [&](tensorflow::gtl::ArraySlice<int64> out_index) {
+ auto func = [&](absl::Span<const int64> out_index) {
DimensionVector operand_index(rank);
for (int64 i = 0; i < rank; ++i) {
operand_index[i] =
@@ -2182,7 +2352,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto result = LiteralUtil::CreateFromDimensions(
shape.element_type(), AsInt64Slice(shape.dimensions()));
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
+ TF_RETURN_IF_ERROR(result.Populate<ReturnT>(func));
parent_->evaluated_[slice] = std::move(result);
return Status::OK();
}
@@ -2387,11 +2557,21 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::is_same<NativeT, float>::value ||
std::is_same<NativeT, int32>::value ||
std::is_same<NativeT, uint32>::value>::type* = nullptr>
- Status HandleIota(HloInstruction* iota) {
- auto result = MakeUnique<Literal>(iota->shape());
- auto data = result->data<ReturnT>();
+ Status HandleIota(HloInstruction* instruction) {
+ auto* iota = Cast<HloIotaInstruction>(instruction);
+ std::vector<NativeT> data(iota->shape().dimensions(iota->iota_dimension()));
std::iota(data.begin(), data.end(), 0);
- parent_->evaluated_[iota] = std::move(result);
+ auto result = LiteralUtil::CreateR1<NativeT>(data);
+
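+  // For shapes of rank > 1, broadcast the R1 iota along iota_dimension;
+  // e.g. an iota of shape [3, 4] with iota_dimension=1 broadcasts
+  // {0, 1, 2, 3} across the three rows.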
+ if (ShapeUtil::Rank(iota->shape()) > 1) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[iota],
+ result.Broadcast(iota->shape(), {iota->iota_dimension()}));
+ } else {
+ TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1);
+ parent_->evaluated_[iota] = std::move(result);
+ }
+
return Status::OK();
}
template <typename NativeT,
@@ -2432,7 +2612,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// bound, call `f` with the base index.
static void IterateThroughWindow(
const Shape& window_shape, const Window& window, const Shape& base_shape,
- const tensorflow::gtl::ArraySlice<int64>& window_count_index,
+ const absl::Span<const int64>& window_count_index,
const std::function<void(const std::vector<int64>&)>& f) {
const int64 rank = ShapeUtil::Rank(base_shape);
DimensionVector window_index(rank);
@@ -2451,13 +2631,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
if (!out_of_bound) {
f(base_index);
}
- } while (IndexUtil::BumpIndices(window_shape, &window_index));
+ } while (
+ IndexUtil::BumpIndices(window_shape, absl::MakeSpan(window_index)));
}
template <typename IndexT>
- StatusOr<std::unique_ptr<Literal>> DynamicSlice(
- const Literal& operand_literal, const Literal& start_indices_literal,
- const Shape& result_shape) {
+ StatusOr<Literal> DynamicSlice(const Literal& operand_literal,
+ const Literal& start_indices_literal,
+ const Shape& result_shape) {
auto start_indices_typed = start_indices_literal.data<IndexT>();
std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
@@ -2470,9 +2651,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
std::vector<int64> operand_indices(start.size());
- auto result = MakeUnique<Literal>(result_shape);
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ Literal result(result_shape);
+ TF_RETURN_IF_ERROR(
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
for (int64 i = 0; i < operand_indices.size(); ++i) {
CHECK_GE(multi_index[i] + start[i], 0);
operand_indices[i] = multi_index[i] + start[i];
@@ -2486,12 +2667,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename IndexT>
- StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice(
- const Literal& operand_literal, const Literal& update_literal,
- const Literal& start_indices_literal) {
- auto result = operand_literal.CloneToUnique();
+ StatusOr<Literal> DynamicUpdateSlice(const Literal& operand_literal,
+ const Literal& update_literal,
+ const Literal& start_indices_literal) {
+ auto result = operand_literal.Clone();
auto start_indices_typed = start_indices_literal.data<IndexT>();
- const auto rank = ShapeUtil::Rank(result->shape());
+ const auto rank = ShapeUtil::Rank(result.shape());
std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
// Clamp the update start indices so the slice is in-bounds w.r.t the
@@ -2499,15 +2680,15 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
for (int64 i = 0; i < rank; ++i) {
start[i] = std::min<int64>(
std::max<int64>(0, start[i]),
- result->shape().dimensions(i) - update_literal.shape().dimensions(i));
+ result.shape().dimensions(i) - update_literal.shape().dimensions(i));
}
std::vector<int64> result_index(rank, 0);
- auto func = [&](tensorflow::gtl::ArraySlice<int64> update_index) {
+ auto func = [&](absl::Span<const int64> update_index) {
std::transform(update_index.begin(), update_index.end(), start.begin(),
result_index.begin(), std::plus<int64>());
- result->Set<ReturnT>(result_index,
- update_literal.Get<ReturnT>(update_index));
+ result.Set<ReturnT>(result_index,
+ update_literal.Get<ReturnT>(update_index));
return true;
};
@@ -2520,7 +2701,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return std::move(result);
}
- StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
+ StatusOr<Literal> ElementWiseUnaryOp(
HloInstruction* instruction,
const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
const Literal& operand_literal =
@@ -2533,7 +2714,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return std::move(result_literal);
}
- StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp(
+ StatusOr<Literal> ElementWiseBinaryOp(
HloInstruction* instruction,
const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
binary_op) {
@@ -2548,18 +2729,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Unimplemented(
"Implicit broadcasting is currently unsupported in HLO evaluator "
"Shape Mismatch: %s vs %s vs %s: ",
- ShapeUtil::HumanString(shape).c_str(),
- ShapeUtil::HumanString(lhs->shape()).c_str(),
- ShapeUtil::HumanString(rhs->shape()).c_str());
+ ShapeUtil::HumanString(shape), ShapeUtil::HumanString(lhs->shape()),
+ ShapeUtil::HumanString(rhs->shape()));
}
const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
- auto result = MakeUnique<Literal>(shape);
+ Literal result(shape);
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ TF_RETURN_IF_ERROR(
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return ConvertBinaryFunction(binary_op)(
lhs_literal.Get<ReturnT>(multi_index),
rhs_literal.Get<ReturnT>(multi_index));
@@ -2568,7 +2748,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename LhsType, typename RhsType, typename EhsType>
- StatusOr<std::unique_ptr<Literal>> ElementwiseTernaryOp(
+ StatusOr<Literal> ElementwiseTernaryOp(
HloInstruction* instruction,
const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) {
const auto shape = instruction->shape();
@@ -2584,20 +2764,19 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Unimplemented(
"Implicit broadcasting is currently unsupported in HLO evaluator "
"Shape Mismatch: %s vs %s vs %s vs %s: ",
- ShapeUtil::HumanString(shape).c_str(),
- ShapeUtil::HumanString(lhs->shape()).c_str(),
- ShapeUtil::HumanString(rhs->shape()).c_str(),
- ShapeUtil::HumanString(ehs->shape()).c_str());
+ ShapeUtil::HumanString(shape), ShapeUtil::HumanString(lhs->shape()),
+ ShapeUtil::HumanString(rhs->shape()),
+ ShapeUtil::HumanString(ehs->shape()));
}
const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);
- auto result = MakeUnique<Literal>(shape);
+ Literal result(shape);
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+ TF_RETURN_IF_ERROR(
+ result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
return ternary_op(lhs_literal.Get<LhsType>(multi_index),
rhs_literal.Get<RhsType>(multi_index),
ehs_literal.Get<EhsType>(multi_index));
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
index c3ccbf0f0c..de3d7a1677 100644
--- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/human_readable_profile_builder.h"
@@ -49,7 +51,7 @@ std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData(
size_t profile_counters_size = hlo_profile_index_map.total_count();
std::unique_ptr<HloProfilePrinterData> profile_printer_data =
- MakeUnique<HloProfilePrinterData>();
+ absl::make_unique<HloProfilePrinterData>();
profile_printer_data->set_profile_counters_size(profile_counters_size);
profile_printer_data->mutable_computation_infos()->Reserve(
hlo_profile_index_map.computation_count());
@@ -67,11 +69,11 @@ std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData(
// The profile indices were computed deterministically in
// HloProfileIndexMap::HloProfileIndexMap.
- c_sort(computation_and_profile_idx_list,
- [](const std::pair<const HloComputation*, int64>& left,
- const std::pair<const HloComputation*, int64>& right) {
- return left.second < right.second;
- });
+ absl::c_sort(computation_and_profile_idx_list,
+ [](const std::pair<const HloComputation*, int64>& left,
+ const std::pair<const HloComputation*, int64>& right) {
+ return left.second < right.second;
+ });
for (const auto& pair : computation_and_profile_idx_list) {
CHECK_LT(pair.second, profile_counters_size);
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
index eba80c0f19..460ae2b5ec 100644
--- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
@@ -14,15 +14,15 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace {
-using tensorflow::strings::StrCat;
+using absl::StrCat;
using ::testing::AllOf;
using ::testing::ContainsRegex;
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 1efa6eb5bd..d52f4e5a61 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -26,6 +26,12 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_replace.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
@@ -37,50 +43,25 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"
-using ::tensorflow::Env;
-using ::tensorflow::WriteStringToFile;
-using ::tensorflow::gtl::nullopt;
-using ::tensorflow::gtl::optional;
-using ::tensorflow::io::JoinPath;
-using ::tensorflow::str_util::Join;
-using ::tensorflow::str_util::StringReplace;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
-
namespace xla {
namespace hlo_graph_dumper {
namespace {
-// Helpers for Printf and Appendf.
-template <typename T>
-struct PrintfConvert {
- const T& operator()(const T& t) const { return t; }
-};
-template <>
-struct PrintfConvert<string> {
- const char* operator()(const string& s) const { return s.c_str(); }
-};
-
-// Like tensorflow::strings::Printf/Appendf, but you don't need to call c_str()
-// on strings.
-template <typename... Ts>
-string Printf(const char* fmt, const Ts&... ts) {
- return tensorflow::strings::Printf(fmt, PrintfConvert<Ts>()(ts)...);
-}
-template <typename... Ts>
-void Appendf(string* s, const char* fmt, const Ts&... ts) {
- tensorflow::strings::Appendf(s, fmt, PrintfConvert<Ts>()(ts)...);
-}
+using absl::nullopt;
+using absl::optional;
+using absl::StrAppend;
+using absl::StrCat;
+using absl::StrFormat;
+using absl::StrJoin;
+using tensorflow::Env;
+using tensorflow::WriteStringToFile;
+using tensorflow::io::JoinPath;
// Used to indicate how we should treat a given HloInstruction in the graph:
// should we treat it as normal, hide it, and so on?
@@ -139,12 +120,23 @@ class NodeFilter {
std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
};
+// We arbitrarily set this as the boundary between "large" and "small"
+// instructions.
+bool IsSmall(const HloInstruction* instr) {
+ if (ShapeUtil::IsOpaque(instr->shape()) ||
+ ShapeUtil::IsToken(instr->shape())) {
+ return true;
+ }
+ return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096;
+}
+
// Node color schemes, used by NodeColorAttributes.
enum ColorScheme {
kBlue,
kBrown,
kDarkBlue,
kDarkGreen,
+ kDarkOrange,
kDarkRed,
kGray,
kGreen,
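The IsSmall helper above draws its "small" boundary from ShapeUtil::ElementsInRecursive, which sums element counts across all array subshapes of a tuple (the same computation as the TotalElementsInShape helper deleted later in this file). A toy model of that recursion, with ToyShape as an illustrative stand-in for xla::Shape:

#include <vector>

struct ToyShape {
  std::vector<long long> dims;        // array case: dimension sizes
  std::vector<ToyShape> tuple_elems;  // tuple case: element shapes
};

long long ElementsInRecursive(const ToyShape& shape) {
  if (!shape.tuple_elems.empty()) {
    long long total = 0;
    for (const ToyShape& elem : shape.tuple_elems) {
      total += ElementsInRecursive(elem);  // recurse into tuple elements
    }
    return total;
  }
  long long elements = 1;
  for (long long dim : shape.dims) elements *= dim;
  return elements;
}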
@@ -177,6 +169,10 @@ NodeColors NodeColorsForScheme(ColorScheme color) {
return NodeColors{"filled", "#1565c0", "#003c8f", "white"};
case kDarkGreen:
return NodeColors{"filled", "#2e7d32", "#005005", "white"};
+ case kDarkOrange:
+ // This is more of a "medium" orange, made to look close to kOrange;
+ // there's probably room for a darker weight if desired.
+ return NodeColors{"filled", "#ffb74d", "#c88719", "black"};
case kDarkRed:
return NodeColors{"filled", "#b71c1c", "#7f0000", "white"};
case kGray:
@@ -209,17 +205,15 @@ NodeColors NodeColorsForScheme(ColorScheme color) {
string NodeColorAttributes(ColorScheme color) {
NodeColors node_colors = NodeColorsForScheme(color);
- return Printf(
- R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")",
- node_colors.style, node_colors.font_color, node_colors.stroke_color,
- node_colors.fill_color);
+ return StrFormat(R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")",
+ node_colors.style, node_colors.font_color,
+ node_colors.stroke_color, node_colors.fill_color);
}
// Replaces <> with &lt;&gt;, so that this string is safe(er) for use in a
// graphviz HTML-like string.
-string HtmlLikeStringSanitize(tensorflow::StringPiece s) {
- return StringReplace(StringReplace(s, "<", "&lt;", /*replace_all=*/true), ">",
- "&gt;", /*replace_all=*/true);
+string HtmlLikeStringSanitize(absl::string_view s) {
+ return absl::StrReplaceAll(s, {{"<", "&lt;"}, {">", "&gt;"}});
}
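absl::StrReplaceAll performs all substitutions in one pass over the input, which is why the two nested StringReplace calls collapse to a single expression. A minimal usage sketch of the two absl calls introduced here:

#include <string>
#include "absl/strings/str_format.h"
#include "absl/strings/str_replace.h"

int main() {
  // Escape < and > in one pass, as HtmlLikeStringSanitize now does.
  std::string safe =
      absl::StrReplaceAll("a<b>", {{"<", "&lt;"}, {">", "&gt;"}});
  // safe == "a&lt;b&gt;"

  // StrFormat is the type-safe replacement for the local Printf helper.
  std::string attrs = absl::StrFormat(
      R"(style="%s", fillcolor="%s")", "filled", "#ffb74d");
}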
// Tries to generate a human-readable one-word description of the given
@@ -322,11 +316,11 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
// Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax).
class HloDotDumper {
public:
- HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label,
+ HloDotDumper(const HloComputation* computation, absl::string_view label,
const DebugOptions& debug_options, bool show_backend_config,
const HloExecutionProfile* profile, NodeFilter filter)
: computation_(computation),
- label_(std::string(label)),
+ label_(label),
debug_options_(debug_options),
show_backend_config_(show_backend_config),
profile_(profile),
@@ -448,7 +442,7 @@ string HloDotDumper::Dump() {
}
string HloDotDumper::Header() {
- const char* fmt = R"(digraph G {
+ constexpr char fmt[] = R"(digraph G {
rankdir = TB;
compound = true;
label = <<b>%s</b>>;
@@ -457,7 +451,7 @@ labelloc = t;
tooltip = " ";
// DOT graphs accept a stylesheet as a URI. So naturally, an inline
// stylesheet is a data URI!
-stylesheet="
+stylesheet=<
data:text/css,
@import url(https://fonts.googleapis.com/css?family=Roboto:400,700);
svg text {
@@ -466,7 +460,7 @@ stylesheet="
}
%s
-"
+>
)";
@@ -481,8 +475,8 @@ stylesheet="
}
if (profile_ != nullptr) {
auto cycles = profile_->total_cycles_executed(*computation_);
- Appendf(&graph_label, "<br/>total cycles = %lld (%s)", cycles,
- tensorflow::strings::HumanReadableNum(cycles));
+ absl::StrAppendFormat(&graph_label, "<br/>total cycles = %d (%s)", cycles,
+ tensorflow::strings::HumanReadableNum(cycles));
}
// Create CSS rules that say, when you hover over the given node or cluster,
@@ -509,14 +503,14 @@ stylesheet="
// One could imagine other ways of writing this CSS rule that involve
// less duplication, but this way seems to be relatively performant.
edge_css_rules.push_back(
- Printf(" #%s%d:hover ~ #edge%lld text { fill: %s; }\n"
- " #%s%d:hover ~ #edge%lld path { "
- "stroke: %s; stroke-width: .2em; }\n"
- " #%s%d:hover ~ #edge%lld polygon { "
- "fill: %s; stroke: %s; stroke-width: .2em; }\n",
- elem_type, elem_id, edge_id, color, //
- elem_type, elem_id, edge_id, color, //
- elem_type, elem_id, edge_id, color, color));
+ StrFormat(" #%s%d:hover ~ #edge%d text { fill: %s; }\n"
+ " #%s%d:hover ~ #edge%d path { "
+ "stroke: %s; stroke-width: .2em; }\n"
+ " #%s%d:hover ~ #edge%d polygon { "
+ "fill: %s; stroke: %s; stroke-width: .2em; }\n",
+ elem_type, elem_id, edge_id, color, //
+ elem_type, elem_id, edge_id, color, //
+ elem_type, elem_id, edge_id, color, color));
};
// The "to_node" value may be a NULL, indicating that this points to the
@@ -559,10 +553,10 @@ stylesheet="
}
}
- return Printf(fmt, graph_label, Join(edge_css_rules, "\n"));
+ return StrFormat(fmt, graph_label, StrJoin(edge_css_rules, "\n"));
}
-string HloDotDumper::Footer() { return StrCat(Join(edges_, "\n"), "\n}"); }
+string HloDotDumper::Footer() { return StrCat(StrJoin(edges_, "\n"), "\n}"); }
bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
@@ -600,9 +594,9 @@ string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp,
VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name()
<< " as " << next_edge_id_;
edge_ids_.insert({{from, parent_instr}, next_edge_id_++});
- const char* edge_fmt =
+ constexpr char edge_fmt[] =
R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)";
- edges_.push_back(Printf(
+ edges_.push_back(StrFormat(
edge_fmt, InstructionId(from), InstructionId(parent_instr),
SubcomputationId(subcomp), subcomp->name(), parent_instr->name()));
}
@@ -619,9 +613,10 @@ string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp,
string subcomp_label, style;
if (parent_instr->opcode() == HloOpcode::kFusion) {
- subcomp_label = Printf("Fused expression for <b>%s</b><br/>%s",
- HtmlLikeStringSanitize(parent_instr->name()),
- HtmlLikeStringSanitize(parent_instr->ToCategory()));
+ subcomp_label =
+ StrFormat("Fused expression for <b>%s</b><br/>%s",
+ HtmlLikeStringSanitize(parent_instr->name()),
+ HtmlLikeStringSanitize(parent_instr->ToCategory()));
string extra_info = GetInstructionNodeExtraInfo(parent_instr);
if (!extra_info.empty()) {
StrAppend(&subcomp_label, "<br/>", extra_info);
@@ -647,18 +642,18 @@ string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp,
strokecolor = highlight ? "#b71c1c" : "#c2c2c2";
}
style =
- Printf(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")",
- fillcolor, strokecolor);
+ StrFormat(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")",
+ fillcolor, strokecolor);
} else {
- subcomp_label = Printf("Subcomputation for <b>%s</b><br/>%s",
- HtmlLikeStringSanitize(parent_instr->name()),
- HtmlLikeStringSanitize(subcomp->name()));
+ subcomp_label = StrFormat("Subcomputation for <b>%s</b><br/>%s",
+ HtmlLikeStringSanitize(parent_instr->name()),
+ HtmlLikeStringSanitize(subcomp->name()));
style = "style=rounded; color=black;";
}
string comp_body = DumpComputation(subcomp);
- const char* computation_fmt = R"(subgraph %s {
+ constexpr char computation_fmt[] = R"(subgraph %s {
%s
label = <%s>;
labelloc = t;
@@ -667,7 +662,7 @@ tooltip = " ";
} // %s
)";
- return Printf(computation_fmt, id, style, subcomp_label, comp_body, id);
+ return StrFormat(computation_fmt, id, style, subcomp_label, comp_body, id);
}
string HloDotDumper::DumpComputation(const HloComputation* comp) {
@@ -718,11 +713,11 @@ string HloDotDumper::DumpRootTag() {
VLOG(2) << "Adding edge from " << from->name() << " to root tag as "
<< next_edge_id_;
edge_ids_.insert({{from, to}, next_edge_id_++});
- edges_.push_back(Printf(R"(%s -> %s [tooltip=" "];)", from_id, to_id));
+ edges_.push_back(StrFormat(R"(%s -> %s [tooltip=" "];)", from_id, to_id));
- return Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
- "\n",
- to_id, node_body, node_shape, NodeColorAttributes(color));
+ return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
+ "\n",
+ to_id, node_body, node_shape, NodeColorAttributes(color));
}
static const HloConstantInstruction* TryGetFusionParameterConstant(
@@ -817,10 +812,10 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
}
}
- return Printf(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)"
- "\n",
- InstructionId(instr), node_body, node_shape, node_metadata,
- NodeColorAttributes(color));
+ return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)"
+ "\n",
+ InstructionId(instr), node_body, node_shape, node_metadata,
+ NodeColorAttributes(color));
}
string HloDotDumper::GetInstructionNodeInlinedOperands(
@@ -833,7 +828,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
// enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which
// is just noise.
if (ShapeUtil::IsZeroElementArray(shape)) {
- return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape()));
+ return StrFormat("{} (%s)", ShapeUtil::HumanString(constant->shape()));
}
// Print the literal value of constants with <= K elements.
@@ -848,19 +843,19 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
// collected from profiling tools. Those constants may not have a valid
// literal.
if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) {
- return Printf("%s (%s)", constant->literal().ToString(),
- ShapeUtil::HumanString(constant->shape()));
+ return StrFormat("%s (%s)", constant->literal().ToString(),
+ ShapeUtil::HumanString(constant->shape()));
}
// Otherwise, print e.g. "%constant.42 (s32[100])".
string constant_name;
- if (tensorflow::str_util::StartsWith(constant->name(), "constant")) {
+ if (absl::StartsWith(constant->name(), "constant")) {
constant_name = constant->name();
} else {
constant_name = StrCat("constant ", constant->name());
}
- return Printf("%s %s", constant_name,
- ShapeUtil::HumanString(constant->shape()));
+ return StrFormat("%s %s", constant_name,
+ ShapeUtil::HumanString(constant->shape()));
};
std::vector<string> lines;
@@ -881,7 +876,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
TryGetFusionParameterConstant(operand)) {
operand_str = stringify_constant(constant);
} else {
- operand_str = Printf("Parameter %lld", operand->parameter_number());
+ operand_str = StrFormat("Parameter %d", operand->parameter_number());
}
} else {
operand_str = operand->name();
@@ -890,13 +885,13 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
if (operand_str) {
if (instr->operand_count() > 1) {
- lines.push_back(Printf("<b>operand %lld</b> = %s", i, *operand_str));
+ lines.push_back(StrFormat("<b>operand %d</b> = %s", i, *operand_str));
} else {
- lines.push_back(Printf("<b>operand</b> = %s", *operand_str));
+ lines.push_back(StrFormat("<b>operand</b> = %s", *operand_str));
}
}
}
- return Join(lines, "<br/>");
+ return StrJoin(lines, "<br/>");
}
ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
@@ -913,7 +908,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
sharding_colors_.emplace(instr->sharding(), color);
return color;
}
- const auto kParameterColor = kOrange;
+
+ // Choose different weights of orange for small vs large parameters. This
+ // distinction is often important, especially in fusion nodes.
+ auto parameter_color = IsSmall(instr) ? kOrange : kDarkOrange;
// Special case: If this instruction has a parameter merged into it, paint it
// the same color as a parameter. Unless the merged-in parameter is a
@@ -925,7 +923,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
ShouldMergeIntoUsers(operand) &&
TryGetFusionParameterConstant(operand) == nullptr;
})) {
- return kParameterColor;
+ return parameter_color;
}
// Pick different colors or shapes for instructions which are particularly
@@ -1035,7 +1033,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kReducePrecision:
return kRed;
case HloOpcode::kParameter:
- return kParameterColor;
+ return parameter_color;
case HloOpcode::kBatchNormGrad:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormTraining:
@@ -1049,6 +1047,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
return kGray;
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kAllToAll:
+ case HloOpcode::kCollectivePermute:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kRecv:
@@ -1059,7 +1058,6 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kCustomCall:
- case HloOpcode::kHostCompute:
case HloOpcode::kWhile:
return kDarkGreen;
case HloOpcode::kConstant:
@@ -1080,14 +1078,13 @@ string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) {
string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
// If we have a parameter, put the param number in the name.
if (instr->opcode() == HloOpcode::kParameter) {
- return Printf("<b>Parameter %lld</b>", instr->parameter_number());
+ return StrFormat("<b>Parameter %d</b>", instr->parameter_number());
}
// The HLO instruction name usually contains the opcode, e.g. "%add.42" is
// an add instruction. In this case we render just the name.
- if (tensorflow::str_util::StartsWith(instr->name(),
- HloOpcodeString(instr->opcode()))) {
- return Printf("<b>%s</b>", HtmlLikeStringSanitize(instr->name()));
+ if (absl::StartsWith(instr->name(), HloOpcodeString(instr->opcode()))) {
+ return StrFormat("<b>%s</b>", HtmlLikeStringSanitize(instr->name()));
}
string extended_opcode =
StrCat(HloOpcodeString(instr->opcode()),
@@ -1095,8 +1092,8 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
? ""
: StrCat(":", xla::ToString(instr->fusion_kind())));
// If the name does not contain the opcode, render both.
- return Printf("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode),
- HtmlLikeStringSanitize(instr->name()));
+ return StrFormat("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode),
+ HtmlLikeStringSanitize(instr->name()));
}
string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
@@ -1105,16 +1102,16 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name()));
}
if (!instr->metadata().op_type().empty()) {
- lines.push_back(Printf(
+ lines.push_back(StrFormat(
"op_type: %s", HtmlLikeStringSanitize(instr->metadata().op_type())));
}
if (!instr->metadata().source_file().empty() &&
instr->metadata().source_line() != 0) {
- lines.push_back(Printf("op_type: %s", instr->metadata().source_file(),
- instr->metadata().source_line()));
+ lines.push_back(StrFormat("source: %s:%d", instr->metadata().source_file(),
+ instr->metadata().source_line()));
}
- return Join(lines, "<br/>");
+ return StrJoin(lines, "<br/>");
}
string HloDotDumper::GetInstructionNodeBackendConfig(
@@ -1161,13 +1158,12 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
constexpr int kMaxShapeLen = 64;
if (instr_shape.length() > kMaxShapeLen) {
instr_shape = StrCat(
- tensorflow::StringPiece(instr_shape).substr(0, kMaxShapeLen - 3),
- "...");
+ absl::string_view(instr_shape).substr(0, kMaxShapeLen - 3), "...");
}
lines.push_back(instr_shape);
}
if (debug_options_.xla_hlo_graph_addresses()) {
- lines.push_back(Printf("[%p]", instr));
+ lines.push_back(StrFormat("[%p]", instr));
}
if (profile_ != nullptr) {
double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr);
@@ -1175,25 +1171,11 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
profile_->total_cycles_executed(*instr->parent());
if (hlo_cycles_executed > 0 && total_cycles_executed > 0) {
lines.push_back(
- Printf("%% of cycles executed=%.2f",
- 100 * hlo_cycles_executed / total_cycles_executed));
+ StrFormat("%% of cycles executed=%.2f",
+ 100 * hlo_cycles_executed / total_cycles_executed));
}
}
- return Join(lines, "<br/>");
-}
-
-// Gets the total number of array elements in the given shape. For tuples, this
-// is the sum of all the sizes of all of the array elements recursively in the
-// tuple.
-static int64 TotalElementsInShape(const Shape& shape) {
- int64 elems = 0;
- ShapeUtil::ForEachSubshape(
- shape, [&](const Shape& subshape, const ShapeIndex& /*index*/) {
- if (ShapeUtil::IsArray(subshape)) {
- elems += ShapeUtil::ElementsIn(subshape);
- }
- });
- return elems;
+ return StrJoin(lines, "<br/>");
}
void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
@@ -1211,20 +1193,19 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
string edge_label;
if (instr->operand_count() > 1 && !control_edge) {
- edge_label = Printf(R"( headlabel="%lld", labeldistance=2)", operand_num);
+ edge_label =
+ StrFormat(R"( headlabel="%d", labeldistance=2)", operand_num);
} else if (control_edge) {
edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\"";
}
// We print "small" arrays using a hollow arrowhead and "large" arrays using
- // a filled arrowhead. For now, we use an arbitrary cutoff for what "big"
- // means.
- bool is_big_array = TotalElementsInShape(from->shape()) >= 4096;
-
- const char* kEdgeFmt = R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)";
- edges_.push_back(Printf(kEdgeFmt, InstructionId(from), InstructionId(to),
- (is_big_array ? "normal" : "empty"), from->name(),
- to->name(), edge_label));
+ // a filled arrowhead.
+ constexpr char kEdgeFmt[] =
+ R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)";
+ edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to),
+ (IsSmall(from) ? "empty" : "normal"),
+ from->name(), to->name(), edge_label));
};
// Add edges from instr's operands to instr. Parameters within fusion
@@ -1265,14 +1246,14 @@ string HloDotDumper::GetInstructionTrivialComputationStr(
continue;
}
if (instr->called_computations().size() == 1) {
- lines.push_back(Printf("Subcomputation: <b>%s</b>",
- HtmlLikeStringSanitize(*computation_type)));
+ lines.push_back(StrFormat("Subcomputation: <b>%s</b>",
+ HtmlLikeStringSanitize(*computation_type)));
} else {
- lines.push_back(Printf("Subcomputation %lld: <b>%s</b>", i,
- HtmlLikeStringSanitize(*computation_type)));
+ lines.push_back(StrFormat("Subcomputation %d: <b>%s</b>", i,
+ HtmlLikeStringSanitize(*computation_type)));
}
}
- return Join(lines, "<br/>");
+ return StrJoin(lines, "<br/>");
}
const HloInstruction* HloDotDumper::GetNodeForEdge(
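The Footer change above is representative of the whole file: Join becomes absl::StrJoin with identical semantics. Stand-alone sketch of the edge-list footer:

#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"

int main() {
  std::vector<std::string> edges = {"a -> b;", "b -> c;"};
  std::string footer = absl::StrCat(absl::StrJoin(edges, "\n"), "\n}");
  // footer == "a -> b;\nb -> c;\n}"
}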
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
index 1d7a062c55..064c53252c 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -23,12 +24,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla.pb.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace {
-using ::tensorflow::strings::StrCat;
+using absl::StrCat;
using ::testing::HasSubstr;
string TestName() {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 7371fde79b..85fa3ce964 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -21,10 +21,17 @@ limitations under the License.
#include <unordered_set>
#include <utility>
+#include "absl/algorithm/container.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/escaping.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -39,17 +46,15 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/human_readable_json.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
-using tensorflow::str_util::CEscape;
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
+using absl::CEscape;
+using absl::StrAppend;
+using absl::StrCat;
+using absl::StrJoin;
/* static */
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
@@ -108,7 +113,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
std::vector<int64> fft_length(proto.fft_length().begin(),
proto.fft_length().end());
instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(),
- tensorflow::gtl::ArraySlice<int64>(fft_length));
+ absl::Span<const int64>(fft_length));
break;
}
case HloOpcode::kSend:
@@ -153,16 +158,26 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
CreateConcatenate(proto.shape(), all_operands(), proto.dimensions(0));
break;
case HloOpcode::kReduce:
- TF_RET_CHECK(proto.operand_ids_size() == 2)
- << "Reduce instruction should have 2 operands but sees "
+ TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
+ << "Reduce instruction should have an even number of operands but "
+ "sees "
<< proto.operand_ids_size();
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "Reduce instruction should have 1 called computation but sees "
<< proto.called_computation_ids_size();
- instruction = CreateReduce(proto.shape(), operands(0), operands(1),
- std::vector<int64>(proto.dimensions().begin(),
- proto.dimensions().end()),
- computations(0));
+ {
+ const auto reduce_operands = all_operands();
+ auto inputs = absl::MakeSpan(reduce_operands)
+ .subspan(0, reduce_operands.size() / 2);
+ auto init_values =
+ absl::MakeSpan(reduce_operands)
+ .subspan(reduce_operands.size() / 2, reduce_operands.size());
+ instruction =
+ CreateReduce(proto.shape(), inputs, init_values,
+ std::vector<int64>(proto.dimensions().begin(),
+ proto.dimensions().end()),
+ computations(0));
+ }
break;
case HloOpcode::kSort: {
TF_RET_CHECK(proto.operand_ids_size() == 1 ||
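The variadic kReduce encoding packs all inputs first and all init values second, so the decoder above splits the operand list in half with absl::Span::subspan. Note that subspan clamps an over-long length to the end of the span, so passing reduce_operands.size() as the length (as above) is equivalent to omitting it. A sketch with ints standing in for HloInstruction*:

#include <vector>
#include "absl/types/span.h"

int main() {
  std::vector<int> reduce_operands = {10, 20, 1, 2};  // 2 inputs, 2 inits
  auto all = absl::MakeSpan(reduce_operands);
  auto inputs = all.subspan(0, reduce_operands.size() / 2);
  auto init_values = all.subspan(reduce_operands.size() / 2);  // to the end
  return inputs.size() == 2 && init_values.size() == 2 ? 0 : 1;
}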
@@ -224,7 +239,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
Literal::CreateFromProto(proto.literal()));
instruction = CreateConstant(std::move(literal));
} else {
- instruction = MakeUnique<HloConstantInstruction>(proto.shape());
+ instruction = absl::make_unique<HloConstantInstruction>(proto.shape());
}
break;
}
@@ -235,7 +250,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.has_literal());
TF_ASSIGN_OR_RETURN(auto literal,
Literal::CreateFromProto(proto.literal()));
- instruction = CreateTrace(literal->GetR1U8AsString(), operands(0));
+ instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
break;
}
case HloOpcode::kFusion: {
@@ -281,12 +296,12 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
case HloOpcode::kInfeed: {
const Shape& data_shape =
ShapeUtil::GetTupleElementShape(proto.shape(), 0);
- CHECK_EQ(proto.operand_ids_size(), 1);
+ TF_RET_CHECK(proto.operand_ids_size() == 1);
instruction =
CreateInfeed(data_shape, operands(0), proto.infeed_config());
} break;
case HloOpcode::kOutfeed:
- CHECK_EQ(proto.operand_ids_size(), 2);
+ TF_RET_CHECK(proto.operand_ids_size() == 2);
instruction = CreateOutfeed(proto.outfeed_shape(), operands(0),
operands(1), proto.outfeed_config());
break;
@@ -294,15 +309,15 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "CrossReplicaSum should have 1 called computation but sees "
<< proto.called_computation_ids_size();
- tensorflow::gtl::optional<int64> all_reduce_id;
+ absl::optional<int64> all_reduce_id;
if (proto.all_reduce_id() > 0) {
all_reduce_id = proto.all_reduce_id();
}
instruction = CreateCrossReplicaSum(
proto.shape(), all_operands(), computations(0),
- /*replica_group_ids=*/
- std::vector<int64>(proto.replica_group_ids().begin(),
- proto.replica_group_ids().end()),
+ /*replica_groups=*/
+ std::vector<ReplicaGroup>(proto.replica_groups().begin(),
+ proto.replica_groups().end()),
/*barrier=*/proto.cross_replica_sum_barrier(),
/*all_reduce_id=*/all_reduce_id);
break;
@@ -312,21 +327,35 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
proto.shape(), all_operands(),
/*replica_groups=*/
std::vector<ReplicaGroup>(proto.replica_groups().begin(),
- proto.replica_groups().end()),
- /*barrier=*/proto.cross_replica_sum_barrier());
+ proto.replica_groups().end()));
break;
}
- case HloOpcode::kConvolution:
+ case HloOpcode::kCollectivePermute: {
+ std::vector<std::pair<int64, int64>> source_target_pairs(
+ proto.source_target_pairs_size());
+ for (int i = 0; i < source_target_pairs.size(); i++) {
+ source_target_pairs[i].first = proto.source_target_pairs(i).source();
+ source_target_pairs[i].second = proto.source_target_pairs(i).target();
+ }
+ instruction = CreateCollectivePermute(proto.shape(), operands(0),
+ source_target_pairs);
+ break;
+ }
+ case HloOpcode::kConvolution: {
TF_RET_CHECK(proto.operand_ids_size() == 2)
<< "Convolution instruction should have 2 operands but sees "
<< proto.operand_ids_size();
TF_RET_CHECK(proto.has_window());
TF_RET_CHECK(proto.has_convolution_dimension_numbers());
+ PrecisionConfig precision_config = proto.precision_config();
+ precision_config.mutable_operand_precision()->Resize(
+ proto.operand_ids_size(), PrecisionConfig::DEFAULT);
instruction = CreateConvolve(
- proto.shape(), operands(0), operands(1), proto.window(),
- proto.convolution_dimension_numbers(),
- std::max(static_cast<int64>(proto.feature_group_count()), 1LL));
+ proto.shape(), operands(0), operands(1),
+ std::max<int64>(proto.feature_group_count(), 1), proto.window(),
+ proto.convolution_dimension_numbers(), precision_config);
break;
+ }
case HloOpcode::kReduceWindow:
TF_RET_CHECK(proto.operand_ids_size() == 2)
<< "ReduceWindow instruction should have 2 operands but sees "
@@ -360,11 +389,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
->set_convolution_dimension_numbers(
proto.convolution_dimension_numbers());
}
- break;
- case HloOpcode::kHostCompute:
- instruction =
- CreateHostCompute(proto.shape(), all_operands(), proto.channel_name(),
- proto.cost_estimate_ns());
+ static_cast<HloCustomCallInstruction*>(instruction.get())
+ ->set_feature_group_count(
+ std::max(static_cast<int64>(proto.feature_group_count()), 1LL));
break;
case HloOpcode::kPad:
TF_RET_CHECK(proto.operand_ids_size() == 2)
@@ -379,7 +406,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< "DynamicSlice instruction should have 2 operands but sees "
<< proto.operand_ids_size();
std::vector<int64> slice_sizes(proto.dynamic_slice_sizes_size());
- c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin());
+ absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin());
instruction = CreateDynamicSlice(proto.shape(), operands(0), operands(1),
slice_sizes);
break;
@@ -391,14 +418,14 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.has_gather_dimension_numbers())
<< "Gather instruction should have GatherDimensionNumbers set.";
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers =
- MakeUnique<GatherDimensionNumbers>(proto.gather_dimension_numbers());
- std::vector<int64> gather_window_bounds;
- for (int64 bound : proto.gather_window_bounds()) {
- gather_window_bounds.push_back(bound);
+ absl::make_unique<GatherDimensionNumbers>(
+ proto.gather_dimension_numbers());
+ std::vector<int64> gather_slice_sizes;
+ for (int64 bound : proto.gather_slice_sizes()) {
+ gather_slice_sizes.push_back(bound);
}
- instruction =
- CreateGather(proto.shape(), operands(0), operands(1),
- *gather_dimension_numbers, gather_window_bounds);
+ instruction = CreateGather(proto.shape(), operands(0), operands(1),
+ *gather_dimension_numbers, gather_slice_sizes);
break;
}
case HloOpcode::kScatter: {
@@ -410,15 +437,44 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "Scatter instruction should have 1 called computation but sees "
<< proto.called_computation_ids_size();
- auto scatter_dimension_numbers = MakeUnique<ScatterDimensionNumbers>(
- proto.scatter_dimension_numbers());
+ auto scatter_dimension_numbers =
+ absl::make_unique<ScatterDimensionNumbers>(
+ proto.scatter_dimension_numbers());
instruction =
CreateScatter(proto.shape(), operands(0), operands(1), operands(2),
computations(0), *scatter_dimension_numbers);
break;
}
+ case HloOpcode::kIota:
+ TF_RET_CHECK(proto.dimensions_size() <= 1)
+ << "Iota instruction should have at most 1 dimension but sees "
+ << proto.dimensions_size();
+ instruction = CreateIota(proto.shape(), proto.dimensions(0));
+ break;
+ case HloOpcode::kDot: {
+ TF_RET_CHECK(proto.has_dot_dimension_numbers())
+ << "Dot instruction should have dot_dimension_numbers.";
+ TF_RET_CHECK(proto.operand_ids_size() == 2)
+ << "Dot instruction should have 2 operands but sees "
+ << proto.operand_ids_size();
+ PrecisionConfig precision_config = proto.precision_config();
+ precision_config.mutable_operand_precision()->Resize(
+ proto.operand_ids_size(), PrecisionConfig::DEFAULT);
+ instruction = absl::make_unique<HloDotInstruction>(
+ proto.shape(), operands(0), operands(1),
+ proto.dot_dimension_numbers(), precision_config);
+ break;
+ }
+ case HloOpcode::kDomain:
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "Domain instruction should have 1 operands but sees "
+ << proto.operand_ids_size();
+ instruction = absl::make_unique<HloDomainInstruction>(
+ proto.shape(), operands(0), /*operand_side_metadata=*/nullptr,
+ /*user_side_metadata=*/nullptr);
+ break;
default: {
- instruction = WrapUnique(new HloInstruction(opcode, proto.shape()));
+ instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape()));
for (const int64 operand_id : proto.operand_ids()) {
TF_RET_CHECK(ContainsKey(instruction_map, operand_id))
<< "No instruction with id " << operand_id;
@@ -438,6 +494,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
computation_map.at(computation_id));
}
}
+ TF_RET_CHECK(!proto.has_precision_config())
+ << instruction->opcode() << proto.DebugString();
+ TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode();
break;
}
}
@@ -447,11 +506,6 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->metadata_ = proto.metadata();
instruction->backend_config_ = proto.backend_config();
- if (proto.has_dot_dimension_numbers()) {
- instruction->dot_dimension_numbers_ =
- MakeUnique<DotDimensionNumbers>(proto.dot_dimension_numbers());
- }
-
if (proto.has_sharding()) {
TF_ASSIGN_OR_RETURN(const auto& sharding,
HloSharding::FromProto(proto.sharding()));
@@ -463,44 +517,46 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter(
int64 parameter_number, const Shape& shape, const string& name) {
- return MakeUnique<HloParameterInstruction>(parameter_number, shape, name);
+ return absl::make_unique<HloParameterInstruction>(parameter_number, shape,
+ name);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTrace(
const string& tag, HloInstruction* operand) {
- return MakeUnique<HloTraceInstruction>(tag, operand);
+ return absl::make_unique<HloTraceInstruction>(tag, operand);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
- std::unique_ptr<Literal> literal) {
- return MakeUnique<HloConstantInstruction>(std::move(literal));
+ Literal literal) {
+ return absl::make_unique<HloConstantInstruction>(std::move(literal));
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateIota(
- const Shape& shape) {
- return WrapUnique(new HloInstruction(HloOpcode::kIota, shape));
+ const Shape& shape, int64 iota_dimension) {
+ return absl::make_unique<HloIotaInstruction>(shape, iota_dimension);
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateGetTupleElement(const Shape& shape,
HloInstruction* operand, int64 index) {
- return MakeUnique<HloGetTupleElementInstruction>(shape, operand, index);
+ return absl::make_unique<HloGetTupleElementInstruction>(shape, operand,
+ index);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
const Shape& shape, RandomDistribution distribution,
- tensorflow::gtl::ArraySlice<HloInstruction*> parameters) {
- return MakeUnique<HloRngInstruction>(shape, distribution, parameters);
+ absl::Span<HloInstruction* const> parameters) {
+ return absl::make_unique<HloRngInstruction>(shape, distribution, parameters);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
const Shape& shape, HloOpcode opcode,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ absl::Span<HloInstruction* const> operands) {
if (opcode == HloOpcode::kCopy) {
// It is impossible to copy an opaque shape; we don't know how big it is.
CHECK(!ShapeUtil::IsOpaque(shape));
}
- auto instruction = WrapUnique(new HloInstruction(opcode, shape));
+ auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
for (auto operand : operands) {
instruction->AppendOperand(operand);
}
@@ -519,7 +575,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kClz:
- case HloOpcode::kDomain:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFloor:
@@ -551,7 +606,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
case HloOpcode::kAtan2:
case HloOpcode::kDivide:
case HloOpcode::kComplex:
- case HloOpcode::kDot:
case HloOpcode::kEq:
case HloOpcode::kGe:
case HloOpcode::kGt:
@@ -597,54 +651,40 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateVariadic(
const Shape& shape, HloOpcode opcode,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ absl::Span<HloInstruction* const> operands) {
CHECK_EQ(HloOpcode::kTuple, opcode);
return CreateNary(shape, opcode, operands);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* map_computation) {
- return MakeUnique<HloMapInstruction>(shape, operands, map_computation);
+ return absl::make_unique<HloMapInstruction>(shape, operands, map_computation);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count) {
- return MakeUnique<HloConvolutionInstruction>(
- shape, lhs, rhs, window, dimension_numbers, feature_group_count);
+ int64 feature_group_count, const Window& window,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ const PrecisionConfig& precision_config) {
+ return absl::make_unique<HloConvolutionInstruction>(
+ shape, lhs, rhs, feature_group_count, window, dimension_numbers,
+ precision_config);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
const Shape& shape, HloInstruction* operand, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length) {
- return MakeUnique<HloFftInstruction>(shape, operand, fft_type, fft_length);
+ absl::Span<const int64> fft_length) {
+ return absl::make_unique<HloFftInstruction>(shape, operand, fft_type,
+ fft_length);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const DotDimensionNumbers& dimension_numbers) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
- instruction->AppendOperand(lhs);
- instruction->AppendOperand(rhs);
- instruction->dot_dimension_numbers_ =
- MakeUnique<DotDimensionNumbers>(dimension_numbers);
- return instruction;
-}
-
-/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCanonicalDot(
- const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) {
- CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2);
- CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2);
-
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
- instruction->AppendOperand(lhs);
- instruction->AppendOperand(rhs);
- instruction->dot_dimension_numbers_ = MakeUnique<DotDimensionNumbers>();
- instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1);
- instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0);
- return instruction;
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfig& precision_config) {
+ return absl::make_unique<HloDotInstruction>(
+ shape, lhs, rhs, dimension_numbers, precision_config);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -652,48 +692,55 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
HloInstruction* operand,
const int exponent_bits,
const int mantissa_bits) {
- return MakeUnique<HloReducePrecisionInstruction>(
+ return absl::make_unique<HloReducePrecisionInstruction>(
shape, operand, exponent_bits, mantissa_bits);
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateCrossReplicaSum(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids,
- tensorflow::StringPiece barrier,
- const tensorflow::gtl::optional<int64>& all_reduce_id) {
- return MakeUnique<HloAllReduceInstruction>(
- shape, operands, reduce_computation, replica_group_ids, barrier,
+ const std::vector<ReplicaGroup>& replica_groups, absl::string_view barrier,
+ const absl::optional<int64>& all_reduce_id) {
+ return absl::make_unique<HloAllReduceInstruction>(
+ shape, operands, reduce_computation, replica_groups, barrier,
all_reduce_id);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- const std::vector<ReplicaGroup>& replica_groups,
- tensorflow::StringPiece barrier) {
- return MakeUnique<HloAllToAllInstruction>(shape, operands, replica_groups,
- barrier);
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ const std::vector<ReplicaGroup>& replica_groups) {
+ return absl::make_unique<HloAllToAllInstruction>(shape, operands,
+ replica_groups);
+}
+
+/* static */ std::unique_ptr<HloInstruction>
+HloInstruction::CreateCollectivePermute(
+ const Shape& shape, HloInstruction* operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs) {
+ return absl::make_unique<HloCollectivePermuteInstruction>(
+ shape, operand, source_target_pairs);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
const Shape& infeed_shape, HloInstruction* token_operand,
const string& config) {
- return MakeUnique<HloInfeedInstruction>(infeed_shape, token_operand, config);
+ return absl::make_unique<HloInfeedInstruction>(infeed_shape, token_operand,
+ config);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
const Shape& outfeed_shape, HloInstruction* operand,
- HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) {
- return MakeUnique<HloOutfeedInstruction>(outfeed_shape, operand,
- token_operand, outfeed_config);
+ HloInstruction* token_operand, absl::string_view outfeed_config) {
+ return absl::make_unique<HloOutfeedInstruction>(
+ outfeed_shape, operand, token_operand, outfeed_config);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
HloInstruction* operand, HloInstruction* token, int64 channel_id,
bool is_host_transfer) {
- return MakeUnique<HloSendInstruction>(operand, token, channel_id,
- is_host_transfer);
+ return absl::make_unique<HloSendInstruction>(operand, token, channel_id,
+ is_host_transfer);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
@@ -701,14 +748,15 @@ HloInstruction::CreateCrossReplicaSum(
auto send_operand = DynCast<HloSendInstruction>(operand);
CHECK(send_operand != nullptr)
<< "SendDone must take the context operand from Send";
- return MakeUnique<HloSendDoneInstruction>(send_operand, is_host_transfer);
+ return absl::make_unique<HloSendDoneInstruction>(send_operand,
+ is_host_transfer);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
const Shape& shape, HloInstruction* token, int64 channel_id,
bool is_host_transfer) {
- return MakeUnique<HloRecvInstruction>(shape, token, channel_id,
- is_host_transfer);
+ return absl::make_unique<HloRecvInstruction>(shape, token, channel_id,
+ is_host_transfer);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
@@ -716,19 +764,20 @@ HloInstruction::CreateCrossReplicaSum(
auto recv_operand = DynCast<HloRecvInstruction>(operand);
CHECK(recv_operand != nullptr)
<< "RecvDone must take the context operand from Recv";
- return MakeUnique<HloRecvDoneInstruction>(recv_operand, is_host_transfer);
+ return absl::make_unique<HloRecvDoneInstruction>(recv_operand,
+ is_host_transfer);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
- return MakeUnique<HloReverseInstruction>(shape, operand, dimensions);
+ absl::Span<const int64> dimensions) {
+ return absl::make_unique<HloReverseInstruction>(shape, operand, dimensions);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll(
- tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ absl::Span<HloInstruction* const> operands) {
CHECK(!operands.empty());
- auto instruction = WrapUnique(
+ auto instruction = absl::WrapUnique(
new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
for (auto operand : operands) {
instruction->AppendOperand(operand);
@@ -737,14 +786,15 @@ HloInstruction::CreateCrossReplicaSum(
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateToken() {
- return WrapUnique(
+ return absl::WrapUnique(
new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
const Shape& shape, HloComputation* condition, HloComputation* body,
HloInstruction* init) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
+ auto instruction =
+ absl::WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
instruction->AppendOperand(init);
// Body comes before condition computation in the vector.
instruction->called_computations_.push_back(body);
@@ -757,7 +807,7 @@ HloInstruction::CreateCrossReplicaSum(
HloInstruction* true_computation_arg, HloComputation* true_computation,
HloInstruction* false_computation_arg, HloComputation* false_computation) {
auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
+ absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
instruction->AppendOperand(pred);
instruction->AppendOperand(true_computation_arg);
instruction->AppendOperand(false_computation_arg);
@@ -771,18 +821,17 @@ HloInstruction::CreateCrossReplicaSum(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides) {
- return MakeUnique<HloSliceInstruction>(shape, operand, start_indices,
- limit_indices, strides);
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
+ return absl::make_unique<HloSliceInstruction>(shape, operand, start_indices,
+ limit_indices, strides);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes) {
- return MakeUnique<HloDynamicSliceInstruction>(shape, operand, start_indices,
- slice_sizes);
+ absl::Span<const int64> slice_sizes) {
+ return absl::make_unique<HloDynamicSliceInstruction>(
+ shape, operand, start_indices, slice_sizes);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -790,8 +839,8 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
HloInstruction* operand,
HloInstruction* update,
HloInstruction* start_indices) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape));
+ auto instruction = absl::WrapUnique(
+ new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape));
instruction->AppendOperand(operand);
instruction->AppendOperand(update);
instruction->AppendOperand(start_indices);
@@ -799,14 +848,16 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
int64 dimension) {
- return MakeUnique<HloConcatenateInstruction>(shape, operands, dimension);
+ return absl::make_unique<HloConcatenateInstruction>(shape, operands,
+ dimension);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvert(
const Shape& shape, HloInstruction* operand) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kConvert, shape));
+ auto instruction =
+ absl::WrapUnique(new HloInstruction(HloOpcode::kConvert, shape));
instruction->AppendOperand(operand);
return instruction;
}
@@ -815,38 +866,38 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
HloInstruction::CreateBitcastConvert(const Shape& shape,
HloInstruction* operand) {
auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape));
+ absl::WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape));
instruction->AppendOperand(operand);
return instruction;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ absl::Span<const int64> dimensions_to_reduce,
HloComputation* reduce_computation) {
- auto instruction = WrapUnique(new HloReduceInstruction(
+ auto instruction = absl::WrapUnique(new HloReduceInstruction(
shape, {operand, init_value}, dimensions_to_reduce, reduce_computation));
return std::move(instruction);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::gtl::ArraySlice<HloInstruction*> init_values,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ absl::Span<HloInstruction* const> init_values,
+ absl::Span<const int64> dimensions_to_reduce,
HloComputation* reduce_computation) {
std::vector<HloInstruction*> all_args;
all_args.reserve(operands.size() * 2);
all_args.insert(all_args.end(), operands.begin(), operands.end());
all_args.insert(all_args.end(), init_values.begin(), init_values.end());
- return MakeUnique<HloReduceInstruction>(shape, all_args, dimensions_to_reduce,
- reduce_computation);
+ return absl::make_unique<HloReduceInstruction>(
+ shape, all_args, dimensions_to_reduce, reduce_computation);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
const Window& window, HloComputation* reduce_computation) {
- return MakeUnique<HloReduceWindowInstruction>(shape, operand, init_value,
- window, reduce_computation);
+ return absl::make_unique<HloReduceWindowInstruction>(
+ shape, operand, init_value, window, reduce_computation);
}
/* static */ std::unique_ptr<HloInstruction>
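This variadic CreateReduce is the encoder matching the CreateFromProto decoder earlier: operands first, init values appended after, which is what lets the proto side split the list in half. Sketch of the packing step, with ints standing in for HloInstruction*:

#include <vector>
#include "absl/types/span.h"

std::vector<int> PackReduceArgs(absl::Span<const int> operands,
                                absl::Span<const int> init_values) {
  std::vector<int> all_args;
  all_args.reserve(operands.size() + init_values.size());
  all_args.insert(all_args.end(), operands.begin(), operands.end());
  all_args.insert(all_args.end(), init_values.begin(), init_values.end());
  return all_args;  // inputs occupy [0, n), init values [n, 2n)
}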
@@ -855,7 +906,7 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape,
HloInstruction* scale,
HloInstruction* offset, float epsilon,
int64 feature_index) {
- return MakeUnique<HloBatchNormTrainingInstruction>(
+ return absl::make_unique<HloBatchNormTrainingInstruction>(
shape, operand, scale, offset, epsilon, feature_index);
}
@@ -864,7 +915,7 @@ HloInstruction::CreateBatchNormInference(
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
float epsilon, int64 feature_index) {
- return MakeUnique<HloBatchNormInferenceInstruction>(
+ return absl::make_unique<HloBatchNormInferenceInstruction>(
shape, operand, scale, offset, mean, variance, epsilon, feature_index);
}
@@ -874,9 +925,9 @@ HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
HloInstruction* variance,
HloInstruction* grad_output, float epsilon,
int64 feature_index) {
- return MakeUnique<HloBatchNormGradInstruction>(shape, operand, scale, mean,
- variance, grad_output, epsilon,
- feature_index);
+ return absl::make_unique<HloBatchNormGradInstruction>(
+ shape, operand, scale, mean, variance, grad_output, epsilon,
+ feature_index);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -884,15 +935,15 @@ HloInstruction::CreateSelectAndScatter(
const Shape& shape, HloInstruction* operand, HloComputation* select,
const Window& window, HloInstruction* source, HloInstruction* init_value,
HloComputation* scatter) {
- return MakeUnique<HloSelectAndScatterInstruction>(
+ return absl::make_unique<HloSelectAndScatterInstruction>(
shape, operand, select, window, source, init_value, scatter);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return MakeUnique<HloBroadcastInstruction>(shape, operand,
- broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions) {
+ return absl::make_unique<HloBroadcastInstruction>(shape, operand,
+ broadcast_dimensions);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -950,8 +1001,8 @@ HloInstruction::CreateBroadcastSequence(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePad(
const Shape& shape, HloInstruction* operand, HloInstruction* padding_value,
const PaddingConfig& padding_config) {
- return MakeUnique<HloPadInstruction>(shape, operand, padding_value,
- padding_config);
+ return absl::make_unique<HloPadInstruction>(shape, operand, padding_value,
+ padding_config);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
@@ -960,34 +1011,36 @@ HloInstruction::CreateBroadcastSequence(
ShapeUtil::ElementsIn(operand->shape()))
<< "shape: " << ShapeUtil::HumanString(shape)
<< " operand: " << ShapeUtil::HumanString(operand->shape());
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReshape, shape));
+ auto instruction =
+ absl::WrapUnique(new HloInstruction(HloOpcode::kReshape, shape));
instruction->AppendOperand(operand);
return instruction;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
- return MakeUnique<HloTransposeInstruction>(shape, operand, dimensions);
+ absl::Span<const int64> dimensions) {
+ return absl::make_unique<HloTransposeInstruction>(shape, operand, dimensions);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
const Shape& shape, int64 dimension, HloInstruction* keys,
HloInstruction* values) {
- return MakeUnique<HloSortInstruction>(shape, dimension, keys, values);
+ return absl::make_unique<HloSortInstruction>(shape, dimension, keys, values);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
- return MakeUnique<HloFusionInstruction>(shape, fusion_kind, fused_root);
+ return absl::make_unique<HloFusionInstruction>(shape, fusion_kind,
+ fused_root);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
const Shape& shape, FusionKind fusion_kind,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ absl::Span<HloInstruction* const> operands,
HloComputation* fusion_computation) {
- return MakeUnique<HloFusionInstruction>(shape, fusion_kind, operands,
- fusion_computation);
+ return absl::make_unique<HloFusionInstruction>(shape, fusion_kind, operands,
+ fusion_computation);
}
void HloInstruction::set_single_sharding(const HloSharding& sharding) {
@@ -1019,7 +1072,6 @@ bool HloInstruction::HasSideEffectNoRecurse() const {
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kTrace:
- case HloOpcode::kHostCompute:
return true;
case HloOpcode::kCrossReplicaSum:
return all_reduce_id().has_value();
@@ -1042,10 +1094,10 @@ bool HloInstruction::HasSideEffect() const {
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* computation) {
std::unique_ptr<HloInstruction> instruction =
- WrapUnique(new HloInstruction(HloOpcode::kCall, shape));
+ absl::WrapUnique(new HloInstruction(HloOpcode::kCall, shape));
for (auto operand : operands) {
instruction->AppendOperand(operand);
}
@@ -1054,21 +1106,14 @@ bool HloInstruction::HasSideEffect() const {
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece custom_call_target) {
- return MakeUnique<HloCustomCallInstruction>(shape, operands,
- custom_call_target);
-}
-
-/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateHostCompute(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) {
- return MakeUnique<HloHostComputeInstruction>(shape, operands, channel_name,
- cost_estimate_ns);
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ absl::string_view custom_call_target) {
+ return absl::make_unique<HloCustomCallInstruction>(shape, operands,
+ custom_call_target);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
- tensorflow::gtl::ArraySlice<HloInstruction*> elements) {
+ absl::Span<HloInstruction* const> elements) {
std::vector<Shape> element_shapes;
for (auto element : elements) {
element_shapes.push_back(element->shape());
@@ -1078,11 +1123,11 @@ bool HloInstruction::HasSideEffect() const {
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather(
- const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices,
+ const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds) {
- return MakeUnique<HloGatherInstruction>(shape, operand, gather_indices,
- gather_dim_numbers, window_bounds);
+ absl::Span<const int64> slice_sizes) {
+ return absl::make_unique<HloGatherInstruction>(
+ shape, operand, start_indices, gather_dim_numbers, slice_sizes);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
@@ -1090,25 +1135,22 @@ bool HloInstruction::HasSideEffect() const {
HloInstruction* scatter_indices, HloInstruction* updates,
HloComputation* update_computation,
const ScatterDimensionNumbers& scatter_dim_numbers) {
- return MakeUnique<HloScatterInstruction>(shape, operand, scatter_indices,
- updates, update_computation,
- scatter_dim_numbers);
+ return absl::make_unique<HloScatterInstruction>(
+ shape, operand, scatter_indices, updates, update_computation,
+ scatter_dim_numbers);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(
const Shape& shape, HloInstruction* operand,
std::unique_ptr<DomainMetadata> operand_side_metadata,
std::unique_ptr<DomainMetadata> user_side_metadata) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDomain, shape));
- instruction->operand_side_metadata_ = std::move(operand_side_metadata);
- instruction->user_side_metadata_ = std::move(user_side_metadata);
- instruction->AppendOperand(operand);
- return instruction;
+ return absl::make_unique<HloDomainInstruction>(
+ shape, operand, std::move(operand_side_metadata),
+ std::move(user_side_metadata));
}
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
VLOG(3) << "CloneWithNewOperands:\n " << ToString();
VLOG(3) << " new operands:";
@@ -1147,19 +1189,21 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kReducePrecision:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kAllToAll:
+ case HloOpcode::kCollectivePermute:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kConvolution:
case HloOpcode::kCustomCall:
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
- case HloOpcode::kHostCompute:
case HloOpcode::kPad:
case HloOpcode::kDynamicSlice:
case HloOpcode::kSort:
case HloOpcode::kGather:
case HloOpcode::kScatter:
case HloOpcode::kIota:
+ case HloOpcode::kDot:
+ case HloOpcode::kDomain:
clone = CloneWithNewOperandsImpl(shape, new_operands, context);
break;
// Unary ops.
@@ -1232,11 +1276,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CHECK_EQ(new_operands.size(), 1);
clone = CreateBitcastConvert(shape, new_operands[0]);
break;
- case HloOpcode::kDot:
- CHECK_EQ(new_operands.size(), 2);
- clone = CreateDot(shape, new_operands[0], new_operands[1],
- *dot_dimension_numbers_);
- break;
case HloOpcode::kReshape:
CHECK_EQ(new_operands.size(), 1);
clone = CreateReshape(shape, new_operands[0]);
@@ -1261,12 +1300,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
true_computation(), new_operands[2],
false_computation());
break;
- case HloOpcode::kDomain:
- CHECK_EQ(new_operands.size(), 1);
- clone =
- CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(),
- user_side_metadata_->Clone());
- break;
case HloOpcode::kAfterAll:
if (new_operands.empty()) {
clone = CreateToken();
@@ -1275,6 +1308,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
}
break;
}
+  // SetupDerivedInstruction will set up the precision_config_ field.
SetupDerivedInstruction(clone.get());
clone->set_parent(parent_);
clone->set_raw_backend_config_string(backend_config_);
@@ -1340,7 +1374,7 @@ std::unique_ptr<HloInstruction> HloInstruction::Clone(
  // If the name ends with .suffix[0-9]+ then replace it with a suffix with the
// numeric value incremented.
int64 numeric_suffix;
- if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) {
+ if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) {
clone->name_ =
StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1);
} else {
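
tensorflow::strings::safe_strto64 becomes absl::SimpleAtoi here; both parse an integer and report success through the return value. A minimal standalone usage sketch (Abseil only):

#include <cstdint>
#include "absl/strings/numbers.h"

int main() {
  int64_t n = 0;
  // SimpleAtoi returns false on bad input instead of aborting, so the caller
  // above can take the else branch when the suffix is not numeric.
  bool ok = absl::SimpleAtoi("42", &n);
  return (ok && n == 42) ? 0 : 1;
}
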
@@ -1458,7 +1492,7 @@ void HloInstruction::AppendOperand(HloInstruction* operand) {
}
void HloInstruction::RemoveOperandsAtAscendingIndices(
- tensorflow::gtl::ArraySlice<int> ascending_indices) {
+ absl::Span<const int> ascending_indices) {
if (ascending_indices.empty()) {
return;
}
@@ -1561,11 +1595,6 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kAfterAll:
return false;
- // Check dot dimension numbers.
- case HloOpcode::kDot:
- return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
- other.dot_dimension_numbers());
-
// Remaining instructions with special values.
case HloOpcode::kCall:
return eq_computations(to_apply(), other.to_apply());
@@ -1581,10 +1610,6 @@ bool HloInstruction::IdenticalSlowPath(
return false;
}
- case HloOpcode::kDomain:
- return operand_side_metadata().Matches(other.operand_side_metadata()) &&
- user_side_metadata().Matches(other.user_side_metadata());
-
// Ops migrated to subclasses should never come to this line.
// TODO(b/80131774): Remove this switch when migration is complete.
case HloOpcode::kBatchNormTraining:
@@ -1615,15 +1640,17 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kOutfeed:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kAllToAll:
+ case HloOpcode::kCollectivePermute:
case HloOpcode::kConvolution:
case HloOpcode::kCustomCall:
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
- case HloOpcode::kHostCompute:
case HloOpcode::kPad:
case HloOpcode::kDynamicSlice:
case HloOpcode::kGather:
case HloOpcode::kScatter:
+ case HloOpcode::kDot:
+ case HloOpcode::kDomain:
LOG(FATAL) << "Base class impl called for opcode with subclass: "
<< opcode();
}
@@ -1813,7 +1840,7 @@ void HloInstruction::set_false_computation(HloComputation* false_computation) {
string HloInstruction::SignatureString() const {
string operands =
- Join(operands_, ", ", [](string* out, HloInstruction* operand) {
+ StrJoin(operands_, ", ", [](string* out, HloInstruction* operand) {
StrAppend(out, ShapeUtil::HumanString(operand->shape()));
});
return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape()));
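
SignatureString and the printers below move from the old Join helper to absl::StrJoin with an explicit formatter lambda. A standalone sketch of that call shape (Abseil only, not XLA code):

#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"

int main() {
  std::vector<int> xs = {1, 2, 3};
  // The three-argument overload takes a formatter that appends each element's
  // rendering to *out, exactly the shape of the lambdas in this patch.
  std::string s = absl::StrJoin(xs, ", ", [](std::string* out, int x) {
    absl::StrAppend(out, "%", x);
  });
  return (s == "%1, %2, %3") ? 0 : 1;
}
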
@@ -1833,7 +1860,7 @@ string HloInstruction::ToString(const HloPrintOptions& options) const {
}
bool HloInstruction::IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const {
+ const absl::optional<int64>& operand_idx) const {
switch (opcode_) {
// Unary elementwise operations.
case HloOpcode::kAbs:
@@ -1954,13 +1981,13 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap(
const HloPrintOptions& options,
CanonicalNameMap* canonical_name_map) const {
string operands;
- tensorflow::gtl::ArraySlice<HloInstruction*> slice(operands_);
+ absl::Span<HloInstruction* const> slice(operands_);
const int64 kMaxOperandsToShowIfCompact = 4;
if (options.compact_operands() &&
slice.size() > kMaxOperandsToShowIfCompact) {
slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
}
- operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) {
+ operands = StrJoin(slice, ", ", [&](string* out, HloInstruction* operand) {
  // If the operand has already been deleted, put `null` in the string output.
if (operand == nullptr) {
StrAppend(out, "null ");
@@ -1980,7 +2007,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap(
} else if (!options.compact_operands()) {
str.push_back(PrintName(operand->name(), options));
}
- StrAppend(out, Join(str, " "));
+ StrAppend(out, StrJoin(str, " "));
});
const int64 remaining = operands_.size() - slice.size();
if (slice.size() != operands_.size()) {
@@ -1993,10 +2020,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
const HloPrintOptions& options) const {
std::vector<string> extra = ExtraAttributesToStringImpl(options);
- if (dot_dimension_numbers_ != nullptr) {
- extra.push_back(DotDimensionNumbersToString());
- }
-
if (options.print_subcomputation_mode() ==
HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
if (opcode() == HloOpcode::kWhile) {
@@ -2022,11 +2045,11 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
StrCat("to_apply=", PrintName(to_apply()->name(), options)));
} else if (!called_computations().empty()) {
extra.push_back(StrCat(
- "calls=", Join(called_computations(), ", ",
- [&](string* out, const HloComputation* computation) {
- StrAppend(out,
- PrintName(computation->name(), options));
- })));
+ "calls=",
+ StrJoin(called_computations(), ", ",
+ [&](string* out, const HloComputation* computation) {
+ StrAppend(out, PrintName(computation->name(), options));
+ })));
}
} else if (options.print_subcomputation_mode() ==
HloPrintOptions::PrintSubcomputationMode::kFullBodies) {
@@ -2059,12 +2082,12 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
break;
default:
if (!called_computations().empty()) {
- extra.push_back(
- StrCat("calls=\n",
- Join(called_computations(), ", ",
- [&](string* out, const HloComputation* computation) {
- StrAppend(out, computation->ToString(new_options));
- })));
+ extra.push_back(StrCat(
+ "calls=\n",
+ StrJoin(called_computations(), ", ",
+ [&](string* out, const HloComputation* computation) {
+ StrAppend(out, computation->ToString(new_options));
+ })));
}
break;
}
@@ -2073,30 +2096,25 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
if (has_sharding()) {
extra.push_back(StrCat("sharding=", sharding().ToString()));
}
- if (!control_predecessors_.empty()) {
+ if (options.print_control_dependencies() && !control_predecessors_.empty()) {
extra.push_back(StrCat("control-predecessors={",
- Join(control_predecessors_, ", ",
- [&](string* out, HloInstruction* pre) {
- StrAppend(out,
- PrintName(pre->name(), options));
- }),
+ StrJoin(control_predecessors_, ", ",
+ [&](string* out, HloInstruction* pre) {
+ StrAppend(out,
+ PrintName(pre->name(), options));
+ }),
"}"));
}
- if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
- extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
- "\", entry=", user_side_metadata_->ToString(),
- ", exit=", operand_side_metadata_->ToString(), "}"));
- }
return extra;
}
string HloInstruction::ToShortString() const {
return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(",
- Join(operands_, ", ",
- [](string* out, HloInstruction* operand) {
- StrAppend(out, "%", operand->name());
- }),
+ StrJoin(operands_, ", ",
+ [](string* out, HloInstruction* operand) {
+ StrAppend(out, "%", operand->name());
+ }),
")");
}
@@ -2124,10 +2142,6 @@ HloInstructionProto HloInstruction::ToProto() const {
}
}
- if (dot_dimension_numbers_ != nullptr) {
- *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_;
- }
-
if (has_sharding()) {
*proto.mutable_sharding() = sharding().ToProto();
}
@@ -2156,7 +2170,7 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }
-bool HloInstruction::IsFusable() const {
+bool HloInstruction::IsFusible() const {
// Instructions which are traced should not be fused.
if (tracing()) {
return false;
@@ -2262,6 +2276,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleCrossReplicaSum(this);
case HloOpcode::kAllToAll:
return visitor->HandleAllToAll(this);
+ case HloOpcode::kCollectivePermute:
+ return visitor->HandleCollectivePermute(this);
case HloOpcode::kTuple:
return visitor->HandleTuple(this);
case HloOpcode::kMap:
@@ -2330,8 +2346,6 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleInfeed(this);
case HloOpcode::kOutfeed:
return visitor->HandleOutfeed(this);
- case HloOpcode::kHostCompute:
- return visitor->HandleHostCompute(this);
case HloOpcode::kRng:
return visitor->HandleRng(this);
case HloOpcode::kWhile:
@@ -2370,15 +2384,14 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return InternalError(
"Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - "
"please file a bug for XLA.",
- HloOpcodeString(opcode_).c_str());
+ HloOpcodeString(opcode_));
}
// Explicit instantiations.
template Status HloInstruction::Visit(DfsHloVisitor* visitor);
template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
-using DFSStack =
- tensorflow::gtl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
+using DFSStack = absl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
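
DFSStack keeps its inline capacity of 16 entries; only the container type changes to absl::InlinedVector. A minimal sketch of the container's behavior (Abseil only):

#include <utility>
#include "absl/container/inlined_vector.h"

int main() {
  // Up to 16 elements live inside the object itself; the container only
  // heap-allocates once it grows past that, which keeps shallow DFS walks
  // allocation-free.
  absl::InlinedVector<std::pair<int, int>, 16> dfs_stack;
  dfs_stack.push_back({0, 7});
  return (dfs_stack.size() == 1) ? 0 : 1;
}
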
// Push "child" onto the dfs_stack if not already visited. Returns false if a
// cycle was detected, and true otherwise.
@@ -2454,7 +2467,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
return FailedPrecondition(
"A cycle is detected while visiting instruction %s",
- current_node->ToString().c_str());
+ current_node->ToString());
}
}
@@ -2463,7 +2476,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
return FailedPrecondition(
"A cycle is detected while visiting instruction %s",
- current_node->ToString().c_str());
+ current_node->ToString());
}
}
}
@@ -2623,7 +2636,7 @@ bool HloInstruction::IsElementwiseBinary() const {
}
bool HloInstruction::IsElementwise() const {
- return IsElementwiseImpl(tensorflow::gtl::nullopt);
+ return IsElementwiseImpl(absl::nullopt);
}
bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const {
@@ -2711,10 +2724,13 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
case HloOpcode::kTranspose:
return UseKind::kUsePermutingElements;
case HloOpcode::kPad:
- case HloOpcode::kReduce:
// Pad reuses the padding value but not the padded array elements.
- // Reduce reuses the init value but not the operand array elements.
return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements;
+ case HloOpcode::kReduce:
+ // Reduce reuses the init values but not the operand array elements.
+ return i >= Cast<HloReduceInstruction>(this)->input_count()
+ ? UseKind::kReuse
+ : UseKind::kUsePermutingElements;
case HloOpcode::kFusion:
// Uses the memoizing, recursive computation defined above.
return FusionReusesParamElements::Compute(i, *fused_expression_root());
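
A standalone sketch of the new operand-index logic for kReduce above: a variadic reduce takes k inputs followed by k init values, so indices at or past input_count() refer to init values, which every output element reuses. (The enum and helper here are stand-ins, not XLA definitions.)

#include <cstdint>

enum class UseKind { kReuse, kUsePermutingElements };

// Mirrors the `i >= input_count()` test in OperandElementUse for kReduce.
UseKind ReduceOperandUse(int64_t i, int64_t input_count) {
  return i >= input_count ? UseKind::kReuse : UseKind::kUsePermutingElements;
}

int main() {
  // A two-input reduce has operands {in0, in1, init0, init1}.
  return (ReduceOperandUse(3, 2) == UseKind::kReuse &&
          ReduceOperandUse(0, 2) == UseKind::kUsePermutingElements)
             ? 0
             : 1;
}
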
@@ -2779,7 +2795,7 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
if (kind_name == "kCustom") {
return HloInstruction::FusionKind::kCustom;
}
- return InvalidArgument("Unknown fusion kind: %s", kind_name.c_str());
+ return InvalidArgument("Unknown fusion kind: %s", kind_name);
}
string PaddingConfigToString(const PaddingConfig& padding) {
@@ -2788,7 +2804,7 @@ string PaddingConfigToString(const PaddingConfig& padding) {
[](const PaddingConfig::PaddingConfigDimension& dim) {
return dim.interior_padding() != 0;
});
- return Join(
+ return StrJoin(
padding.dimensions(), "x",
[&](string* out, const PaddingConfig::PaddingConfigDimension& dim) {
StrAppend(
@@ -2812,11 +2828,15 @@ string OpMetadataToString(const OpMetadata& metadata) {
if (metadata.source_line() != 0) {
result.push_back(StrCat("source_line=", metadata.source_line()));
}
- return Join(result, " ");
+ return StrJoin(result, " ");
}
string RandomDistributionToString(const RandomDistribution& distribution) {
- return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution));
+ return absl::AsciiStrToLower(RandomDistribution_Name(distribution));
+}
+
+string PrecisionToString(const PrecisionConfig::Precision& precision) {
+ return absl::AsciiStrToLower(PrecisionConfig::Precision_Name(precision));
}
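
RandomDistributionToString and the new PrecisionToString both lower-case a proto enum name via absl::AsciiStrToLower. A one-line sketch (Abseil only):

#include <string>
#include "absl/strings/ascii.h"

int main() {
  // Proto enum names such as "HIGHEST" become the lower-case spellings the
  // HLO text format uses.
  return (absl::AsciiStrToLower("HIGHEST") == "highest") ? 0 : 1;
}
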
string ConvolutionDimensionNumbersToString(
@@ -2844,31 +2864,8 @@ string ConvolutionDimensionNumbersToString(
output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i);
}
- return StrCat(Join(lhs_dims, ""), "_", Join(rhs_dims, ""), "->",
- Join(output_dims, ""));
-}
-
-string HloInstruction::DotDimensionNumbersToString() const {
- std::vector<string> result;
- if (dot_dimension_numbers_ == nullptr) {
- return "";
- }
- const DotDimensionNumbers& dnums = *dot_dimension_numbers_;
- if (!dnums.lhs_batch_dimensions().empty()) {
- result.push_back(StrCat("lhs_batch_dims={",
- Join(dnums.lhs_batch_dimensions(), ","), "}"));
- }
- result.push_back(StrCat("lhs_contracting_dims={",
- Join(dnums.lhs_contracting_dimensions(), ","), "}"));
-
- if (!dnums.rhs_batch_dimensions().empty()) {
- result.push_back(StrCat("rhs_batch_dims={",
- Join(dnums.rhs_batch_dimensions(), ","), "}"));
- }
- result.push_back(StrCat("rhs_contracting_dims={",
- Join(dnums.rhs_contracting_dimensions(), ","), "}"));
-
- return Join(result, ", ");
+ return StrCat(StrJoin(lhs_dims, ""), "_", StrJoin(rhs_dims, ""), "->",
+ StrJoin(output_dims, ""));
}
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
@@ -2882,7 +2879,26 @@ StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
}
return map;
}();
- auto found = map->find(tensorflow::str_util::Lowercase(name));
+ auto found = map->find(absl::AsciiStrToLower(name));
+ if (found == map->end()) {
+ return InvalidArgument("Unknown distribution");
+ }
+ return found->second;
+}
+
+StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name) {
+ static std::unordered_map<string, PrecisionConfig::Precision>* map = [] {
+ static auto* map =
+ new std::unordered_map<string, PrecisionConfig::Precision>;
+ for (int i = 0; i < PrecisionConfig::Precision_ARRAYSIZE; i++) {
+ if (PrecisionConfig::Precision_IsValid(i)) {
+ auto value = static_cast<PrecisionConfig::Precision>(i);
+ (*map)[PrecisionToString(value)] = value;
+ }
+ }
+ return map;
+ }();
+ auto found = map->find(absl::AsciiStrToLower(name));
if (found == map->end()) {
return InvalidArgument("Unknown distribution");
}
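
StringToPrecision follows the same lazily-built, intentionally-leaked static map idiom as StringToRandomDistribution above. A standalone sketch of the idiom (the table contents here are placeholders, not the real precision names):

#include <string>
#include <unordered_map>

int Lookup(const std::string& name) {
  // Built once by an immediately-invoked lambda on first call; the map is
  // never freed, which is fine for a process-lifetime lookup table.
  static const auto* map = [] {
    auto* m = new std::unordered_map<std::string, int>;
    (*m)["default"] = 0;
    (*m)["highest"] = 2;
    return m;
  }();
  auto it = map->find(name);
  return it == map->end() ? -1 : it->second;
}

int main() { return (Lookup("highest") == 2) ? 0 : 1; }
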
@@ -2929,6 +2945,16 @@ Status HloInstruction::set_backend_config(
return ret;
}
+const PrecisionConfig& HloInstruction::precision_config() const {
+ if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
+ return convolution->precision_config();
+ }
+ if (auto* dot = DynCast<HloDotInstruction>(this)) {
+ return dot->precision_config();
+ }
+ LOG(FATAL) << "Unimplemented method.";
+}
+
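
The new accessor dispatches on the concrete subclass with DynCast and aborts for opcodes that carry no precision config. A standalone sketch of the same dispatch shape using plain dynamic_cast (hypothetical types, not the XLA Cast machinery):

#include <cstdlib>
#include <iostream>

struct Config { int bits = 32; };
struct Instr { virtual ~Instr() = default; };
struct DotInstr : Instr { Config config; };
struct ConvInstr : Instr { Config config; };

const Config& GetConfig(const Instr& instr) {
  // Try each subclass that owns the field; anything else is a caller bug.
  if (auto* conv = dynamic_cast<const ConvInstr*>(&instr)) return conv->config;
  if (auto* dot = dynamic_cast<const DotInstr*>(&instr)) return dot->config;
  std::cerr << "Unimplemented method.\n";
  std::abort();
}

int main() {
  DotInstr dot;
  return (GetConfig(dot).bits == 32) ? 0 : 1;
}
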
HloModule* HloInstruction::GetModule() const {
if (parent_) {
return parent_->parent();
@@ -3132,31 +3158,25 @@ const string& HloInstruction::outfeed_config() const {
return Cast<HloOutfeedInstruction>(this)->outfeed_config();
}
-const std::vector<int64>& HloInstruction::replica_group_ids() const {
- return Cast<HloAllReduceInstruction>(this)->replica_group_ids();
+const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
+ return Cast<HloCollectiveInstruction>(this)->replica_groups();
}
-const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
- return Cast<HloAllToAllInstruction>(this)->replica_groups();
+const std::vector<std::pair<int64, int64>>&
+HloInstruction::source_target_pairs() const {
+ return Cast<HloCollectivePermuteInstruction>(this)->source_target_pairs();
}
string HloInstruction::cross_replica_sum_barrier() const {
- if (opcode() == HloOpcode::kCrossReplicaSum) {
- return Cast<HloAllReduceInstruction>(this)->cross_replica_sum_barrier();
- }
- return Cast<HloAllToAllInstruction>(this)->cross_replica_sum_barrier();
+ return Cast<HloAllReduceInstruction>(this)->cross_replica_sum_barrier();
}
void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) {
- if (opcode() == HloOpcode::kCrossReplicaSum) {
- return Cast<HloAllReduceInstruction>(this)->set_cross_replica_sum_barrier(
- barrier);
- }
- return Cast<HloAllToAllInstruction>(this)->set_cross_replica_sum_barrier(
+ return Cast<HloAllReduceInstruction>(this)->set_cross_replica_sum_barrier(
barrier);
}
-tensorflow::gtl::optional<int64> HloInstruction::all_reduce_id() const {
+absl::optional<int64> HloInstruction::all_reduce_id() const {
return Cast<HloAllReduceInstruction>(this)->all_reduce_id();
}
@@ -3183,7 +3203,15 @@ void HloInstruction::set_convolution_dimension_numbers(
}
int64 HloInstruction::feature_group_count() const {
- return Cast<HloConvolutionInstruction>(this)->feature_group_count();
+ if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
+ return convolution->feature_group_count();
+ }
+ return Cast<HloCustomCallInstruction>(this)->feature_group_count();
+}
+
+void HloInstruction::set_feature_group_count(int64 feature_group_count) {
+ Cast<HloCustomCallInstruction>(this)->set_feature_group_count(
+ feature_group_count);
}
HloComputation* HloInstruction::select() const {
@@ -3206,10 +3234,6 @@ const string& HloInstruction::custom_call_target() const {
return Cast<HloCustomCallInstruction>(this)->custom_call_target();
}
-const string& HloInstruction::channel_name() const {
- return Cast<HloHostComputeInstruction>(this)->channel_name();
-}
-
const PaddingConfig& HloInstruction::padding_config() const {
return Cast<HloPadInstruction>(this)->padding_config();
}
@@ -3226,9 +3250,8 @@ const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const {
return Cast<HloGatherInstruction>(this)->gather_dimension_numbers();
}
-tensorflow::gtl::ArraySlice<int64> HloInstruction::gather_window_bounds()
- const {
- return Cast<HloGatherInstruction>(this)->gather_window_bounds();
+absl::Span<const int64> HloInstruction::gather_slice_sizes() const {
+ return Cast<HloGatherInstruction>(this)->gather_slice_sizes();
}
const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
@@ -3236,4 +3259,15 @@ const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
return Cast<HloScatterInstruction>(this)->scatter_dimension_numbers();
}
+const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const {
+ return Cast<HloDotInstruction>(this)->dot_dimension_numbers();
+}
+
+const DomainMetadata& HloInstruction::operand_side_metadata() const {
+ return Cast<HloDomainInstruction>(this)->operand_side_metadata();
+}
+
+const DomainMetadata& HloInstruction::user_side_metadata() const {
+ return Cast<HloDomainInstruction>(this)->user_side_metadata();
+}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index b3eee90099..4f6cac1396 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -32,6 +32,11 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "absl/container/inlined_vector.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -45,10 +50,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -80,6 +82,7 @@ class HloPrintOptions {
print_operand_shape_(true),
print_program_shape_(true),
print_percent_(true),
+ print_control_dependencies_(true),
canonicalize_instruction_names_(false),
indent_amount_(0),
is_in_nested_computation_(false) {}
@@ -92,7 +95,8 @@ class HloPrintOptions {
.set_print_backend_config(false)
.set_print_operand_shape(false)
.set_print_program_shape(false)
- .set_print_percent(false);
+ .set_print_percent(false)
+ .set_print_control_dependencies(false);
}
// Options to produce the canonical string representing an isomorphic
@@ -101,10 +105,12 @@ class HloPrintOptions {
return HloPrintOptions()
.set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies)
.set_print_metadata(false)
+ .set_print_backend_config(false)
.set_compact_operands(true)
.set_print_operand_shape(true)
.set_print_program_shape(false)
.set_print_percent(false)
+ .set_print_control_dependencies(false)
.set_canonicalize_instruction_names(true);
}
@@ -150,6 +156,12 @@ class HloPrintOptions {
return *this;
}
+ // If true, control dependencies will be printed.
+ HloPrintOptions& set_print_control_dependencies(bool value) {
+ print_control_dependencies_ = value;
+ return *this;
+ }
+
  // If true, only some of the operands will be printed out, and their names will
// be omitted (note that in this case the text will not be parsable).
HloPrintOptions& set_compact_operands(bool value) {
@@ -182,11 +194,14 @@ class HloPrintOptions {
return print_subcomputation_mode_;
}
bool print_metadata() const { return print_metadata_; }
- bool print_backend_config() const { return print_metadata_; }
+ bool print_backend_config() const { return print_backend_config_; }
bool compact_operands() const { return compact_operands_; }
bool print_operand_shape() const { return print_operand_shape_; }
bool print_program_shape() const { return print_program_shape_; }
bool print_percent() const { return print_percent_; }
+ bool print_control_dependencies() const {
+ return print_control_dependencies_;
+ }
bool canonicalize_instruction_names() const {
return canonicalize_instruction_names_;
}
@@ -202,6 +217,7 @@ class HloPrintOptions {
bool print_operand_shape_;
bool print_program_shape_;
bool print_percent_;
+ bool print_control_dependencies_;
bool canonicalize_instruction_names_;
int indent_amount_;
bool is_in_nested_computation_;
@@ -220,7 +236,7 @@ class CanonicalNameMap {
return iter->second;
}
- string new_name = tensorflow::strings::StrCat("tmp_", index++);
+ string new_name = absl::StrCat("tmp_", index++);
canonical_name_map[old_name] = new_name;
return new_name;
}
@@ -343,11 +359,11 @@ class HloInstruction {
const string& name);
// Creates a literal constant instruction.
- static std::unique_ptr<HloInstruction> CreateConstant(
- std::unique_ptr<Literal> literal);
+ static std::unique_ptr<HloInstruction> CreateConstant(Literal literal);
// Creates an Iota instruction.
- static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape);
+ static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape,
+ int64 iota_dimension);
// Creates a get tuple element instruction.
static std::unique_ptr<HloInstruction> CreateGetTupleElement(
@@ -361,7 +377,7 @@ class HloInstruction {
// random numbers from a given distribution.
static std::unique_ptr<HloInstruction> CreateRng(
const Shape& shape, RandomDistribution distribution,
- tensorflow::gtl::ArraySlice<HloInstruction*> parameters);
+ absl::Span<HloInstruction* const> parameters);
// Creates a unary instruction (one operand).
// Precondition: opcode must be a legitimate unary operation.
@@ -388,39 +404,34 @@ class HloInstruction {
// Precondition: opcode must be a legitimate variadic operation.
static std::unique_ptr<HloInstruction> CreateVariadic(
const Shape& shape, HloOpcode opcode,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands);
+ absl::Span<HloInstruction* const> operands);
// Creates a map instruction, where the computation (given by the handle) is
// applied element-wise to every element in operands (across the operands,
// at a given index)
static std::unique_ptr<HloInstruction> CreateMap(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* map_computation);
// Creates a convolution op, where rhs is the convolutional filter
// and window describes how the filter is applied to lhs.
static std::unique_ptr<HloInstruction> CreateConvolve(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const Window& window,
+ int64 feature_group_count, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1);
+ const PrecisionConfig& precision_config);
// Creates an FFT op, of the type indicated by fft_type.
static std::unique_ptr<HloInstruction> CreateFft(
const Shape& shape, HloInstruction* operand, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length);
+ absl::Span<const int64> fft_length);
// Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
// dimensions specified in 'dimension_numbers'.
static std::unique_ptr<HloInstruction> CreateDot(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const DotDimensionNumbers& dimension_numbers);
-
- // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1
- // of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS
- // and the RHS must be of rank 2.
- static std::unique_ptr<HloInstruction> CreateCanonicalDot(
- const Shape& shape, HloInstruction* lhs, HloInstruction* rhs);
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfig& precision_config);
// Creates a reduce-precision op, where operand is the data to reduce in
// precision, and exponent_bits and mantissa_bits describe the precision to
@@ -433,9 +444,10 @@ class HloInstruction {
//
// `reduction_computation`: the reduction function.
//
- // `replica_group_ids`: maps replica ids to subgroup ids. If empty, all
- // replicas belong to one group. Allreduce will be applied within subgroups.
- // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means,
+  // `replica_groups`: each ReplicaGroup contains a list of replica ids. If
+  // empty, all replicas belong to one group in the order of 0 - (n-1).
+  // Allreduce will be applied within subgroups.
+  // For example, if we have 4 replicas, then replica_groups={{0,2},{1,3}} means
// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
//
// `all_reduce_id`: for Allreduce nodes from different modules, if they have
@@ -444,11 +456,10 @@ class HloInstruction {
//
// TODO(b/79737069): Rename this to AllReduce.
static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids,
- tensorflow::StringPiece barrier,
- const tensorflow::gtl::optional<int64>& all_reduce_id);
+ const std::vector<ReplicaGroup>& replica_groups,
+ absl::string_view barrier, const absl::optional<int64>& all_reduce_id);
// This op handles the communication of an Alltoall operation. On each core,
// the operands are N ops in the same shape, where N is the number of cores
@@ -463,12 +474,18 @@ class HloInstruction {
// within replica 1, 2, 3, and in the gather phase, the received blocks will
// be concatenated in the order of 1, 2, 3; another Alltoall will be applied
// within replica 4, 5, 0, and the concatenation order is 4, 5, 0.
- //
- // TODO(b/110096724): This is NOT YET ready to use.
static std::unique_ptr<HloInstruction> CreateAllToAll(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- const std::vector<ReplicaGroup>& replica_groups,
- tensorflow::StringPiece barrier);
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ const std::vector<ReplicaGroup>& replica_groups);
+
+  // Creates a communication instruction that permutes data across replicas.
+  // Data is sent/received according to the (source_replica_id,
+  // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a
+  // target_replica_id in any pair, the output on that replica is a tensor
+  // consisting of 0(s) in `shape`.
+ static std::unique_ptr<HloInstruction> CreateCollectivePermute(
+ const Shape& shape, HloInstruction* operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs);
// Creates a conversion instruction, where operand is the data to convert and
// shape is the target shape for the conversion.
@@ -493,7 +510,7 @@ class HloInstruction {
// which is a TOKEN.
static std::unique_ptr<HloInstruction> CreateOutfeed(
const Shape& outfeed_shape, HloInstruction* operand,
- HloInstruction* token_operand, tensorflow::StringPiece outfeed_config);
+ HloInstruction* token_operand, absl::string_view outfeed_config);
// Creates an asynchronous send instruction with the given channel id, which
// initiates sending the operand data to a unique receive instruction in
@@ -526,17 +543,15 @@ class HloInstruction {
// start/limit indices.
static std::unique_ptr<HloInstruction> CreateSlice(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices, absl::Span<const int64> strides);
// Creates a slice instruction, where the first operand is sliced by
// start indices specified in the second operand, and by size specified in
// 'slice_sizes'.
static std::unique_ptr<HloInstruction> CreateDynamicSlice(
const Shape& shape, HloInstruction* operand,
- HloInstruction* start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ HloInstruction* start_indices, absl::Span<const int64> slice_sizes);
// Creates a dynamic update slice instruction, which updates a slice
// of 'operand' with 'update' and 'start_indices'.
@@ -547,7 +562,7 @@ class HloInstruction {
// Creates a concatenate instruction, where the operands are concatenated on
// the provided dimension.
static std::unique_ptr<HloInstruction> CreateConcatenate(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
int64 dimension);
// Creates a reduce instruction, where the computation (given by the handle)
@@ -559,7 +574,7 @@ class HloInstruction {
// f(f(init, value0), value1), ...)
static std::unique_ptr<HloInstruction> CreateReduce(
const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ absl::Span<const int64> dimensions_to_reduce,
HloComputation* reduce_computation);
// A more general, multiple-argument version of the above.
@@ -574,9 +589,9 @@ class HloInstruction {
// ...
// TODO(b/112040122): Add support to this in HLO passes and in backends.
static std::unique_ptr<HloInstruction> CreateReduce(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::gtl::ArraySlice<HloInstruction*> init_values,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ absl::Span<HloInstruction* const> init_values,
+ absl::Span<const int64> dimensions_to_reduce,
HloComputation* reduce_computation);
// Creates a reduce-window instruction, where the computation (given
@@ -613,7 +628,7 @@ class HloInstruction {
// Creates a broadcast instruction.
static std::unique_ptr<HloInstruction> CreateBroadcast(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
// Creates a sequence of instructions that performs an explicit broadcast of
// the operand to the target shape.
@@ -643,7 +658,7 @@ class HloInstruction {
// Creates a transpose instruction which permutes the operand dimensions.
static std::unique_ptr<HloInstruction> CreateTranspose(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ absl::Span<const int64> dimensions);
// Creates a sort op, with a keys operand, and an optional values operand.
static std::unique_ptr<HloInstruction> CreateSort(
@@ -667,9 +682,9 @@ class HloInstruction {
static std::unique_ptr<HloInstruction> CreateGather(
const Shape& shape, HloInstruction* operand,
- HloInstruction* gather_indices,
+ HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ absl::Span<const int64> slice_sizes);
static std::unique_ptr<HloInstruction> CreateScatter(
const Shape& shape, HloInstruction* operand,
@@ -693,43 +708,37 @@ class HloInstruction {
static std::unique_ptr<HloInstruction> CreateFusion(
const Shape& shape, FusionKind fusion_kind,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ absl::Span<HloInstruction* const> operands,
HloComputation* fusion_computation);
// Creates a call instruction that applies the given computation on the given
// operands. "shape" is the resultant shape.
static std::unique_ptr<HloInstruction> CreateCall(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* computation);
// Creates a custom call instruction that applies the given custom call target
// to the given operands. "shape" is the resultant shape.
static std::unique_ptr<HloInstruction> CreateCustomCall(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece custom_call_target);
-
- // Creates a HostCompute instruction, which records host-side control and
- // data dependencies for use in instruction scheduling.
- static std::unique_ptr<HloInstruction> CreateHostCompute(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece channel_name, const int64 cost_estimate_ns);
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ absl::string_view custom_call_target);
// Creates a tuple instruction with the given elements. This is a convenience
// wrapper around CreateVariadic.
static std::unique_ptr<HloInstruction> CreateTuple(
- tensorflow::gtl::ArraySlice<HloInstruction*> elements);
+ absl::Span<HloInstruction* const> elements);
// Creates a reverse instruction, which reverses the order of the elements
// in the specified dimensions.
static std::unique_ptr<HloInstruction> CreateReverse(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ absl::Span<const int64> dimensions);
  // Creates an AfterAll instruction used for joining or creating new values of
// token type which thread through side-effecting operations. Operands must
// all be tokens, and there must be at least one operand.
static std::unique_ptr<HloInstruction> CreateAfterAll(
- tensorflow::gtl::ArraySlice<HloInstruction*> operands);
+ absl::Span<HloInstruction* const> operands);
// Creates an AfterAll instruction which creates a token type out of thin air
  // (no operands). This is a separate method from CreateAfterAll to facilitate
@@ -766,7 +775,7 @@ class HloInstruction {
int64 operand_count() const { return operands_.size(); }
// Returns the vector of operands of this instruction.
- using InstructionVector = tensorflow::gtl::InlinedVector<HloInstruction*, 2>;
+ using InstructionVector = absl::InlinedVector<HloInstruction*, 2>;
const InstructionVector& operands() const { return operands_; }
// Returns the vector of unique operands, in the same order they are found
@@ -1030,7 +1039,7 @@ class HloInstruction {
// Returns true if this instruction can be legally fused into a fusion
// instruction.
- bool IsFusable() const;
+ bool IsFusible() const;
// Returns the sharding applied to this operator.
// REQUIRES: has_sharding() is true.
@@ -1038,21 +1047,26 @@ class HloInstruction {
CHECK(has_sharding());
return *sharding_;
}
+ std::shared_ptr<const HloSharding> sharding_ptr() const { return sharding_; }
+
// Returns the sharding applied to this operator, or default_ if none exists.
const HloSharding& sharding_or_default(const HloSharding& default_) const {
return sharding_ ? *sharding_ : default_;
}
// Returns the sharding unique device, if any.
- tensorflow::gtl::optional<int64> sharding_unique_device() const {
+ absl::optional<int64> sharding_unique_device() const {
if (sharding_ == nullptr) {
- return tensorflow::gtl::optional<int64>();
+ return absl::optional<int64>();
}
return sharding_->UniqueDevice();
}
// Sets the sharding of this operator. Should only be called by HloModule or
// HloComputation methods.
void set_sharding(const HloSharding& sharding) {
- sharding_ = MakeUnique<HloSharding>(sharding);
+ sharding_ = std::make_shared<const HloSharding>(sharding);
+ }
+ void set_sharding(std::shared_ptr<const HloSharding> sharding) {
+ sharding_ = std::move(sharding);
}
void set_single_sharding(const HloSharding& sharding);
// Sets a sharding that assigns the current instruction to device.
@@ -1072,15 +1086,6 @@ class HloInstruction {
return other->has_sharding() ? sharding() == other->sharding() : false;
}
- // Retrieves the operand side metadata of a kDomain instruction.
- const DomainMetadata& operand_side_metadata() const {
- return *operand_side_metadata_;
- }
- // Retrieves the user side metadata of a kDomain instruction.
- const DomainMetadata& user_side_metadata() const {
- return *user_side_metadata_;
- }
-
// When creating a new instruction which either replaces, or shifts up (kCopy
  // insertion case), another instruction, we need to make sure that certain
// properties of the new instruction are copied into the derived one. As of
@@ -1088,28 +1093,6 @@ class HloInstruction {
// instruction.
void SetupDerivedInstruction(HloInstruction* derived_instruction) const;
- // TODO(b/80249101): Remove these methods once HLO scheduling and copy
- // insertion are integrated, and we don't need to run a separate pass
- // of copy elision anymore.
- bool CopyElisionAllowed() const {
- CHECK_EQ(HloOpcode::kCopy, opcode_);
- return copy_elision_allowed_;
- }
-
- void SetCopyElisionAllowed(bool value) {
- CHECK_EQ(HloOpcode::kCopy, opcode_);
- copy_elision_allowed_ = value;
- }
-
- // Returns data on the dimension numbers used for a dot operation.
- const DotDimensionNumbers& dot_dimension_numbers() const {
- CHECK(dot_dimension_numbers_ != nullptr);
- return *dot_dimension_numbers_;
- }
-
- // Returns the dump string of the dot dimension numbers.
- string DotDimensionNumbersToString() const;
-
// Clones the HLO instruction. The clone will have the same opcode, shape, and
// operands. After creation the clone has no uses. "this" (the instruction
// cloned from) is not changed. Suffix is the string to append to the name of
@@ -1120,8 +1103,7 @@ class HloInstruction {
// Clones the HLO instruction as above but with new shape and operands.
std::unique_ptr<HloInstruction> CloneWithNewOperands(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context = nullptr) const;
// Returns the computations this instruction directly calls (if any).
@@ -1253,6 +1235,16 @@ class HloInstruction {
static StatusOr<string> BackendConfigToRawString(
const tensorflow::protobuf::Message& proto);
+  // Returns the information used to tell the implementation about what sort
+  // of precision is requested. The meaning of the field is backend specific.
+  // At the moment, it is only supported for kConvolution and kDot.
+  // Transformations of one kDot or kConvolution into another will preserve
+  // this information. Transformations to other HLOs will not preserve this
+  // information, but it is presumed that the alternate lowering is strictly
+  // superior.
+ // Precondition: opcode must be kConvolution or kDot.
+ const PrecisionConfig& precision_config() const;
+
// Sets the debug metadata for this instruction.
void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
const OpMetadata& metadata() const { return metadata_; }
@@ -1421,18 +1413,18 @@ class HloInstruction {
// Returns the shape for the Outfeed instruction.
const Shape& outfeed_shape() const;
- // Delegates to HloAllReduceInstruction::replica_group_ids.
- const std::vector<int64>& replica_group_ids() const;
-
- // Delegates to HloAllToAllInstruction::replica_groups.
+ // Delegates to HloCollectiveInstruction::replica_groups.
const std::vector<ReplicaGroup>& replica_groups() const;
+ // Delegates to HloCollectivePermuteInstruction::source_target_pairs.
+ const std::vector<std::pair<int64, int64>>& source_target_pairs() const;
+
// Delegates to HloAllReduceInstruction::cross_replica_sum_barrier.
string cross_replica_sum_barrier() const;
void set_cross_replica_sum_barrier(const string& barrier);
// Delegates to HloAllReduceInstruction::all_reduce_id.
- tensorflow::gtl::optional<int64> all_reduce_id() const;
+ absl::optional<int64> all_reduce_id() const;
// Returns data on the window in a windowed operation such as
// convolution.
@@ -1460,6 +1452,8 @@ class HloInstruction {
// dimension and output feature dimension.
int64 feature_group_count() const;
+ void set_feature_group_count(int64 feature_group_count);
+
// Delegates to HloSelectAndScatterInstruction::select.
HloComputation* select() const;
@@ -1475,9 +1469,6 @@ class HloInstruction {
// Delegates to HloCustomCallInstruction::custom_call_target.
const string& custom_call_target() const;
- // Delegates to HloHostComputeInstruction::channel_name.
- const string& channel_name() const;
-
// Delegates to HloPadInstruction::padding_config.
const PaddingConfig& padding_config() const;
@@ -1489,12 +1480,21 @@ class HloInstruction {
// Delegates to HloGatherInstruction::gather_dimension_numbers.
const GatherDimensionNumbers& gather_dimension_numbers() const;
- // Delegates to HloGatherInstruction::gather_window_bounds.
- tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const;
+ // Delegates to HloGatherInstruction::gather_slice_sizes.
+ absl::Span<const int64> gather_slice_sizes() const;
// Delegates to HloScatterInstruction::scatter_dimension_numbers().
const ScatterDimensionNumbers& scatter_dimension_numbers() const;
+ // Delegates to HloDotInstruction::dot_dimension_numbers().
+ const DotDimensionNumbers& dot_dimension_numbers() const;
+
+ // Delegates to HloDomainInstruction::operand_side_metadata().
+ const DomainMetadata& operand_side_metadata() const;
+
+ // Delegates to HloDomainInstruction::user_side_metadata().
+ const DomainMetadata& user_side_metadata() const;
+
// Old methods kept for smooth subclassing transition END.
protected:
@@ -1516,7 +1516,7 @@ class HloInstruction {
// Removes a list of operands with the given indices in ascending order.
void RemoveOperandsAtAscendingIndices(
- tensorflow::gtl::ArraySlice<int> ascending_indices);
+ absl::Span<const int> ascending_indices);
void AppendComputation(HloComputation* computation) {
called_computations_.push_back(computation);
@@ -1546,8 +1546,7 @@ class HloInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
virtual std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
// TODO(b/80131774): This should be pure virtual.
LOG(FATAL) << "Unimplemented method.";
@@ -1565,7 +1564,7 @@ class HloInstruction {
// NOTE: For all instructions other than kFusion, being elementwise on one of
// the operands is equivalent to being elementwise on all the operands.
virtual bool IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const;
+ const absl::optional<int64>& operand_idx) const;
// Prints an instruction to a string.
//
// The canonical string representation needs to name operands and instruction
@@ -1593,7 +1592,7 @@ class HloInstruction {
// Creates an n-ary elementwise operation.
static std::unique_ptr<HloInstruction> CreateNary(
const Shape& shape, HloOpcode opcode,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands);
+ absl::Span<HloInstruction* const> operands);
// Adds a user for this instruction.
void AddUser(HloInstruction* user);
@@ -1635,18 +1634,11 @@ class HloInstruction {
// Result shape of this instruction.
Shape shape_;
- // Describes the dimension numbers used for a dot.
- std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_;
-
- // Used to tag kCopy instructions that are eligible for copy elision.
- bool copy_elision_allowed_ = true;
-
// The sharding, if one exists.
- std::unique_ptr<HloSharding> sharding_;
-
- // Fields used by the kDomain instruction.
- std::unique_ptr<DomainMetadata> operand_side_metadata_;
- std::unique_ptr<DomainMetadata> user_side_metadata_;
+  // Uses std::shared_ptr to allow reuse of the same sharding object between
+  // HloInstructions and other components, as HloSharding can be very large
+  // for tuples with many elements.
+ std::shared_ptr<const HloSharding> sharding_;
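
The comment above explains the switch from unique_ptr to shared_ptr<const HloSharding>. A minimal sketch of the aliasing it enables (hypothetical stand-in types, not XLA code):

#include <memory>
#include <string>

struct Sharding { std::string repr; };  // stand-in for HloSharding
struct Instr { std::shared_ptr<const Sharding> sharding; };

int main() {
  auto s = std::make_shared<const Sharding>(Sharding{"{devices=[2]0,1}"});
  Instr a{s};
  Instr b{s};  // both instructions alias one immutable sharding object
  return (a.sharding.get() == b.sharding.get() && s.use_count() == 3) ? 0 : 1;
}
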
// Computations called by this instruction.
std::vector<HloComputation*> called_computations_;
@@ -1683,10 +1675,12 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
string PaddingConfigToString(const PaddingConfig& padding);
string OpMetadataToString(const OpMetadata& metadata);
string RandomDistributionToString(const RandomDistribution& distribution);
+string PrecisionToString(const PrecisionConfig::Precision& precision);
string ConvolutionDimensionNumbersToString(
const ConvolutionDimensionNumbers& dnums);
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
+StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name);
std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 8a694dde80..c1b7c3832b 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -29,7 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
@@ -39,10 +39,8 @@ namespace {
using ::testing::ElementsAre;
using ::testing::UnorderedElementsAre;
-class HloInstructionTest : public HloTestBase {
+class HloInstructionTest : public HloVerifiedTestBase {
protected:
- HloInstructionTest() {}
-
Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
};
@@ -53,7 +51,7 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault {
public:
Status DefaultAction(HloInstruction* hlo_instruction) override {
return Unimplemented("not implemented %s",
- HloOpcodeString(hlo_instruction->opcode()).c_str());
+ HloOpcodeString(hlo_instruction->opcode()));
}
Status HandleParameter(HloInstruction* parameter) override {
@@ -1086,16 +1084,14 @@ TEST_F(HloInstructionTest, PartiallyElementwise) {
TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) {
// Fused expression:
- //
- // x y
- // \ / \
- // min broadcast
+ // y
+ // /
+ // x broadcast
+ // \ / |
+ // min |
// \ /
// sub
//
- // The fusion instruction is elementwise on `x` because the only path from x
- // to sub contains only elementwise operations. It is not elementwise on `y`
- // because the path y->broadcast->sub is not all elementwise.
const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
const Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
@@ -1104,10 +1100,10 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) {
builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x"));
HloInstruction* y =
builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "y"));
- HloInstruction* min = builder.AddInstruction(
- HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, x, y));
HloInstruction* broadcast =
- builder.AddInstruction(HloInstruction::CreateBroadcast(r1f32, y, {0}));
+ builder.AddInstruction(HloInstruction::CreateBroadcast(r1f32, y, {}));
+ HloInstruction* min = builder.AddInstruction(
+ HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, x, broadcast));
HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
r1f32, HloOpcode::kSubtract, min, broadcast));
@@ -1118,10 +1114,10 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) {
EXPECT_FALSE(fusion->IsElementwise());
for (int64 operand_idx = 0; operand_idx < fusion->operand_count();
++operand_idx) {
- if (fusion->operand(operand_idx) == x) {
- EXPECT_TRUE(fusion->IsElementwiseOnOperand(operand_idx));
- } else {
+ if (fusion->operand(operand_idx) == y) {
EXPECT_FALSE(fusion->IsElementwiseOnOperand(operand_idx));
+ } else {
+ EXPECT_TRUE(fusion->IsElementwiseOnOperand(operand_idx));
}
}
}
@@ -1151,8 +1147,8 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
@@ -1192,8 +1188,8 @@ TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(s, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ s, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
@@ -1243,12 +1239,12 @@ TEST_F(HloInstructionTest, NestedFusionEquality) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums));
+ auto dot = builder.AddInstruction(HloInstruction::CreateDot(
+ data_shape, a, b_t, dot_dnums, DefaultPrecisionConfig(2)));
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto add_operand = builder.AddInstruction(
- HloInstruction::CreateBroadcast(data_shape, one, {1}));
+ HloInstruction::CreateBroadcast(data_shape, one, {}));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape, HloOpcode::kAdd, dot, add_operand));
auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -1324,8 +1320,8 @@ TEST_F(HloInstructionTest, Stringification) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto options = HloPrintOptions().set_print_metadata(false);
@@ -1355,7 +1351,7 @@ TEST_F(HloInstructionTest, Stringification) {
TEST_F(HloInstructionTest, StringifyGather_0) {
Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
- Shape gather_indices_tensor_shape =
+ Shape start_indices_tensor_shape =
ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
Shape gather_result_shape =
ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});
@@ -1363,19 +1359,18 @@ TEST_F(HloInstructionTest, StringifyGather_0) {
HloComputation::Builder builder("Gather");
HloInstruction* input = builder.AddInstruction(
HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
- HloInstruction* gather_indices =
+ HloInstruction* start_indices =
builder.AddInstruction(HloInstruction::CreateParameter(
- 1, gather_indices_tensor_shape, "gather_indices"));
-
- HloInstruction* gather_instruction =
- builder.AddInstruction(HloInstruction::CreateGather(
- gather_result_shape, input, gather_indices,
- HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
- /*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ 1, start_indices_tensor_shape, "start_indices"));
+
+ HloInstruction* gather_instruction = builder.AddInstruction(
+ HloInstruction::CreateGather(gather_result_shape, input, start_indices,
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4),
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -1383,15 +1378,15 @@ TEST_F(HloInstructionTest, StringifyGather_0) {
EXPECT_EQ(gather_instruction->ToString(),
"%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
"gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
- "s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), "
- "output_window_dims={4,5,6,7,8}, elided_window_dims={}, "
- "gather_dims_to_operand_dims={0,1,2,3,4}, "
- "index_vector_dim=4, window_bounds={30,29,28,27,26}");
+ "s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), "
+ "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, "
+ "start_index_map={0,1,2,3,4}, "
+ "index_vector_dim=4, slice_sizes={30,29,28,27,26}");
}
TEST_F(HloInstructionTest, StringifyGather_1) {
Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
- Shape gather_indices_tensor_shape =
+ Shape start_indices_tensor_shape =
ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
Shape gather_result_shape =
ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});
@@ -1399,19 +1394,18 @@ TEST_F(HloInstructionTest, StringifyGather_1) {
HloComputation::Builder builder("Gather");
HloInstruction* input = builder.AddInstruction(
HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
- HloInstruction* gather_indices =
+ HloInstruction* start_indices =
builder.AddInstruction(HloInstruction::CreateParameter(
- 1, gather_indices_tensor_shape, "gather_indices"));
-
- HloInstruction* gather_instruction =
- builder.AddInstruction(HloInstruction::CreateGather(
- gather_result_shape, input, gather_indices,
- HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
- /*index_vector_dim=*/2),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ 1, start_indices_tensor_shape, "start_indices"));
+
+ HloInstruction* gather_instruction = builder.AddInstruction(
+ HloInstruction::CreateGather(gather_result_shape, input, start_indices,
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/2),
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -1419,10 +1413,10 @@ TEST_F(HloInstructionTest, StringifyGather_1) {
EXPECT_EQ(gather_instruction->ToString(),
"%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
"gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
- "s64[10,9,5,7,6]{4,3,2,1,0} %gather_indices), "
- "output_window_dims={4,5,6,7,8}, elided_window_dims={}, "
- "gather_dims_to_operand_dims={0,1,2,3,4}, "
- "index_vector_dim=2, window_bounds={30,29,28,27,26}");
+ "s64[10,9,5,7,6]{4,3,2,1,0} %start_indices), "
+ "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, "
+ "start_index_map={0,1,2,3,4}, "
+ "index_vector_dim=2, slice_sizes={30,29,28,27,26}");
}
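The two gather stringification tests above track a wholesale renaming in GatherDimensionNumbers: output_window_dims is now offset_dims, elided_window_dims is now collapsed_slice_dims, gather_dims_to_operand_dims is now start_index_map, window_bounds is now slice_sizes, and the indices operand is start_indices instead of gather_indices. A minimal sketch of the renamed builder call, with the values taken from StringifyGather_0:

    GatherDimensionNumbers dnums = HloGatherInstruction::MakeGatherDimNumbers(
        /*offset_dims=*/{4, 5, 6, 7, 8},       // was output_window_dims
        /*collapsed_slice_dims=*/{},           // was elided_window_dims
        /*start_index_map=*/{0, 1, 2, 3, 4},   // was gather_dims_to_operand_dims
        /*index_vector_dim=*/4);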
TEST_F(HloInstructionTest, StringifyScatter) {
@@ -1491,8 +1485,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationFusion) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto options = HloPrintOptions().Canonical();
@@ -1533,8 +1527,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationWhile) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
@@ -1589,8 +1583,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
@@ -1745,5 +1739,23 @@ TEST_F(HloInstructionTest, CloneDnumsOnCustomCall) {
<< clone->convolution_dimension_numbers().DebugString();
}
+TEST_F(HloInstructionTest, PreserveOperandPrecisionOnCloneConv) {
+ constexpr char kHloString[] = R"(
+ HloModule test_module
+ ENTRY test {
+ arg0 = f32[1,2,1] parameter(0)
+ arg1 = f32[1,1,1] parameter(1)
+ ROOT conv = f32[1,2,1] convolution(arg0, arg1), window={size=1},
+ dim_labels=b0f_0io->b0f, operand_precision={high,default}
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kHloString));
+ auto* conv = module->entry_computation()->root_instruction();
+
+ auto clone = conv->Clone();
+ EXPECT_THAT(
+ clone->precision_config().operand_precision(),
+ ::testing::ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::DEFAULT));
+}
+
} // namespace
} // namespace xla
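Every CreateDot call site in these tests gains a trailing PrecisionConfig argument, passed as DefaultPrecisionConfig(2). The helper's definition sits outside this hunk; a plausible reconstruction, assuming it simply fills one DEFAULT entry per dot operand:

    // Hypothetical sketch of the DefaultPrecisionConfig test helper; the real
    // definition lives with the test fixture, not in this diff.
    PrecisionConfig DefaultPrecisionConfig(int operands) {
      PrecisionConfig precision_config;
      precision_config.mutable_operand_precision()->Resize(
          operands, PrecisionConfig::DEFAULT);
      return precision_config;
    }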
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 233cdda7b0..e92882c22a 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -17,6 +17,12 @@ limitations under the License.
#include <deque>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/escaping.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -27,10 +33,10 @@ limitations under the License.
namespace xla {
namespace {
-using ::tensorflow::str_util::CEscape;
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
+using absl::CEscape;
+using absl::StrAppend;
+using absl::StrCat;
+using absl::StrJoin;
bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
const HloInstruction* operand) {
@@ -41,6 +47,27 @@ bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
return instruction->IsElementwiseOnOperand(operand_index);
});
}
+
+string PrecisionConfigToString(const PrecisionConfig& precision_config) {
+ if (absl::c_all_of(precision_config.operand_precision(), [](int32 precision) {
+ return static_cast<PrecisionConfig::Precision>(precision) ==
+ PrecisionConfig::DEFAULT;
+ })) {
+ return "";
+ }
+
+ return StrCat(
+ "operand_precision={",
+ StrJoin(
+ precision_config.operand_precision(), ",",
+ [](string* out, int32 precision) {
+ CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision;
+ StrAppend(out,
+ PrecisionToString(
+ static_cast<PrecisionConfig::Precision>(precision)));
+ }),
+ "}");
+}
} // namespace
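PrecisionConfigToString returns the empty string whenever every operand precision is DEFAULT, so unannotated HLO keeps printing exactly as before and only non-default configs show up in ToString(). Expected behavior, inferred from this function and the operand_precision={high,default} HLO string in the tests (which also suggests PrecisionToString emits lowercase names):

    PrecisionConfig c;
    c.add_operand_precision(PrecisionConfig::DEFAULT);
    c.add_operand_precision(PrecisionConfig::DEFAULT);
    // PrecisionConfigToString(c) == ""
    c.set_operand_precision(0, PrecisionConfig::HIGH);
    // PrecisionConfigToString(c) == "operand_precision={high,default}"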
HloBatchNormInstruction::HloBatchNormInstruction(
@@ -85,11 +112,10 @@ HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction(
std::unique_ptr<HloInstruction>
HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 3);
- return MakeUnique<HloBatchNormTrainingInstruction>(
+ return absl::make_unique<HloBatchNormTrainingInstruction>(
shape, new_operands[0], new_operands[1], new_operands[2], epsilon(),
feature_index());
}
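The bulk of this file's changes repeat two mechanical migrations, first visible in the hunk above and applied to every subclass below: tensorflow::gtl::ArraySlice<T> parameters become absl::Span<const T> (absl::Span<HloInstruction* const> for operand lists), and MakeUnique becomes absl::make_unique. Shown once here rather than annotated per hunk:

    // Before:
    //   tensorflow::gtl::ArraySlice<HloInstruction*> new_operands
    //   return MakeUnique<T>(...);
    // After:
    //   absl::Span<HloInstruction* const> new_operands
    //   return absl::make_unique<T>(...);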
@@ -107,11 +133,10 @@ HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction(
std::unique_ptr<HloInstruction>
HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 5);
- return MakeUnique<HloBatchNormInferenceInstruction>(
+ return absl::make_unique<HloBatchNormInferenceInstruction>(
shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
new_operands[4], epsilon(), feature_index());
}
@@ -129,18 +154,17 @@ HloBatchNormGradInstruction::HloBatchNormGradInstruction(
std::unique_ptr<HloInstruction>
HloBatchNormGradInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 5);
- return MakeUnique<HloBatchNormGradInstruction>(
+ return absl::make_unique<HloBatchNormGradInstruction>(
shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
new_operands[4], epsilon(), feature_index());
}
-HloFftInstruction::HloFftInstruction(
- const Shape& shape, HloInstruction* operand, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length)
+HloFftInstruction::HloFftInstruction(const Shape& shape,
+ HloInstruction* operand, FftType fft_type,
+ absl::Span<const int64> fft_length)
: HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) {
fft_length_.assign(fft_length.begin(), fft_length.end());
AppendOperand(operand);
@@ -158,7 +182,7 @@ HloInstructionProto HloFftInstruction::ToProto() const {
std::vector<string> HloFftInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
return {StrCat("fft_type=", FftType_Name(fft_type())),
- StrCat("fft_length={", Join(fft_length(), ","), "}")};
+ StrCat("fft_length={", StrJoin(fft_length(), ","), "}")};
}
bool HloFftInstruction::IdenticalSlowPath(
@@ -171,12 +195,11 @@ bool HloFftInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloFftInstruction>(shape, new_operands[0], fft_type_,
- fft_length_);
+ return absl::make_unique<HloFftInstruction>(shape, new_operands[0], fft_type_,
+ fft_length_);
}
HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
@@ -226,12 +249,11 @@ HloSendInstruction::HloSendInstruction(HloInstruction* operand,
}
std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloSendInstruction>(new_operands[0], new_operands[1],
- channel_id(), is_host_transfer());
+ return absl::make_unique<HloSendInstruction>(
+ new_operands[0], new_operands[1], channel_id(), is_host_transfer());
}
HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
@@ -244,11 +266,10 @@ HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
std::unique_ptr<HloInstruction>
HloSendDoneInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloSendDoneInstruction>(
+ return absl::make_unique<HloSendDoneInstruction>(
Cast<HloSendInstruction>(new_operands[0]), is_host_transfer());
}
@@ -265,11 +286,10 @@ HloRecvInstruction::HloRecvInstruction(const Shape& shape,
}
std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloRecvInstruction>(
+ return absl::make_unique<HloRecvInstruction>(
ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(),
is_host_transfer());
}
@@ -287,35 +307,69 @@ HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand,
std::unique_ptr<HloInstruction>
HloRecvDoneInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloRecvDoneInstruction>(
+ return absl::make_unique<HloRecvDoneInstruction>(
Cast<HloRecvInstruction>(new_operands[0]), is_host_transfer());
}
-HloAllReduceInstruction::HloAllReduceInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* reduce_computation,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids,
- tensorflow::StringPiece barrier,
- const tensorflow::gtl::optional<int64>& all_reduce_id)
- : HloInstruction(HloOpcode::kCrossReplicaSum, shape),
- replica_group_ids_(replica_group_ids.begin(), replica_group_ids.end()),
- cross_replica_sum_barrier_(barrier.begin(), barrier.end()),
- all_reduce_id_(all_reduce_id) {
+HloCollectiveInstruction::HloCollectiveInstruction(
+ HloOpcode opcode, const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ const std::vector<ReplicaGroup>& replica_groups)
+ : HloInstruction(opcode, shape), replica_groups_(replica_groups) {
for (auto operand : operands) {
AppendOperand(operand);
}
- AppendComputation(reduce_computation);
}
-HloInstructionProto HloAllReduceInstruction::ToProto() const {
+HloInstructionProto HloCollectiveInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
- for (int64 i : replica_group_ids_) {
- proto.add_replica_group_ids(i);
+ *proto.mutable_replica_groups() = {replica_groups_.begin(),
+ replica_groups_.end()};
+ return proto;
+}
+
+std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& /*options*/) const {
+ std::vector<string> result;
+ std::vector<string> replica_group_str;
+ for (const ReplicaGroup& group : replica_groups()) {
+ replica_group_str.push_back(
+ StrCat("{", StrJoin(group.replica_ids(), ","), "}"));
}
+ result.push_back(
+ StrCat("replica_groups={", StrJoin(replica_group_str, ","), "}"));
+ return result;
+}
+
+bool HloCollectiveInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ /*eq_computations*/) const {
+ const auto& casted_other =
+ static_cast<const HloCollectiveInstruction&>(other);
+ return absl::c_equal(replica_groups(), casted_other.replica_groups(),
+ [](const ReplicaGroup& a, const ReplicaGroup& b) {
+ return absl::c_equal(a.replica_ids(), b.replica_ids());
+ });
+}
+
+HloAllReduceInstruction::HloAllReduceInstruction(
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ HloComputation* reduce_computation,
+ const std::vector<ReplicaGroup>& replica_groups, absl::string_view barrier,
+ const absl::optional<int64>& all_reduce_id)
+ : HloCollectiveInstruction(HloOpcode::kCrossReplicaSum, shape, operands,
+ replica_groups),
+ cross_replica_sum_barrier_(barrier),
+ all_reduce_id_(all_reduce_id) {
+ AppendComputation(reduce_computation);
+}
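This hunk replaces the flat replica_group_ids list with an HloCollectiveInstruction base class that owns replica_groups and centralizes their proto round-trip, printing, and equality; HloAllReduceInstruction keeps only what is specific to it (the reduction computation, the barrier string, and all_reduce_id). A sketch of how groups are built and rendered, based on the printing code above:

    // Two groups of two replicas; ToString() renders them as
    //   replica_groups={{0,1},{2,3}}
    ReplicaGroup g0, g1;
    g0.add_replica_ids(0);
    g0.add_replica_ids(1);
    g1.add_replica_ids(2);
    g1.add_replica_ids(3);
    std::vector<ReplicaGroup> replica_groups = {g0, g1};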
+
+HloInstructionProto HloAllReduceInstruction::ToProto() const {
+ HloInstructionProto proto = HloCollectiveInstruction::ToProto();
  // Proto3 is so sad: scalar fields have no presence bit, so all_reduce_id is
  // serialized only when it is actually set.
if (all_reduce_id_) {
proto.set_all_reduce_id(*all_reduce_id_);
@@ -325,9 +379,9 @@ HloInstructionProto HloAllReduceInstruction::ToProto() const {
}
std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
- const HloPrintOptions& /*options*/) const {
- std::vector<string> result = {
- StrCat("replica_group_ids={", Join(replica_group_ids(), ","), "}")};
+ const HloPrintOptions& options) const {
+ std::vector<string> result =
+ HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
if (!cross_replica_sum_barrier().empty()) {
result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\""));
}
@@ -342,7 +396,7 @@ bool HloAllReduceInstruction::IdenticalSlowPath(
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const {
const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other);
- return replica_group_ids() == casted_other.replica_group_ids() &&
+ return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) &&
eq_computations(to_apply(), casted_other.to_apply()) &&
cross_replica_sum_barrier() ==
casted_other.cross_replica_sum_barrier() &&
@@ -351,78 +405,80 @@ bool HloAllReduceInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloAllReduceInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* /*context*/) const {
- return MakeUnique<HloAllReduceInstruction>(
- shape, new_operands, to_apply(), replica_group_ids(),
+ return absl::make_unique<HloAllReduceInstruction>(
+ shape, new_operands, to_apply(), replica_groups(),
cross_replica_sum_barrier(), all_reduce_id());
}
HloAllToAllInstruction::HloAllToAllInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- const std::vector<ReplicaGroup>& replica_groups,
- tensorflow::StringPiece barrier)
- : HloInstruction(HloOpcode::kAllToAll, shape),
- replica_groups_(replica_groups),
- cross_replica_sum_barrier_(barrier.begin(), barrier.end()) {
- for (auto operand : operands) {
- AppendOperand(operand);
- }
-}
-
-bool HloAllToAllInstruction::IdenticalSlowPath(
- const HloInstruction& other,
- const std::function<bool(const HloComputation*, const HloComputation*)>&
- eq_computations) const {
- const auto& casted_other = static_cast<const HloAllToAllInstruction&>(other);
- return ContainersEqual(replica_groups(), casted_other.replica_groups(),
- [](const ReplicaGroup& a, const ReplicaGroup& b) {
- return ContainersEqual(a.replica_ids(),
- b.replica_ids());
- }) &&
- cross_replica_sum_barrier() ==
- casted_other.cross_replica_sum_barrier();
-}
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ const std::vector<ReplicaGroup>& replica_groups)
+ : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
+ replica_groups) {}
std::unique_ptr<HloInstruction>
HloAllToAllInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* /*context*/) const {
- return MakeUnique<HloAllToAllInstruction>(
- shape, new_operands, replica_groups(), cross_replica_sum_barrier());
+ return absl::make_unique<HloAllToAllInstruction>(shape, new_operands,
+ replica_groups());
}
-std::vector<string> HloAllToAllInstruction::ExtraAttributesToStringImpl(
- const HloPrintOptions& options) const {
- std::vector<string> result;
- std::vector<string> replica_group_str;
- for (const ReplicaGroup& group : replica_groups()) {
- replica_group_str.push_back(
- StrCat("{", Join(group.replica_ids(), ","), "}"));
- }
- result.push_back(
- StrCat("replica_groups={", Join(replica_group_str, ","), "}"));
+HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
+ const Shape& shape, HloInstruction* operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs)
+ : HloInstruction(HloOpcode::kCollectivePermute, shape),
+ source_target_pairs_(source_target_pairs) {
+ AppendOperand(operand);
+}
- if (!cross_replica_sum_barrier().empty()) {
- result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\""));
+HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (const auto& pair : source_target_pairs()) {
+ auto* proto_pair = proto.add_source_target_pairs();
+ proto_pair->set_source(pair.first);
+ proto_pair->set_target(pair.second);
}
+ return proto;
+}
+std::vector<string>
+HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& /*options*/) const {
+ std::vector<string> result;
+ std::vector<string> strs;
+ for (const auto& pair : source_target_pairs()) {
+ strs.push_back(StrCat("{", pair.first, ",", pair.second, "}"));
+ }
+ result.push_back(StrCat("source_target_pairs={", StrJoin(strs, ","), "}"));
return result;
}
-HloInstructionProto HloAllToAllInstruction::ToProto() const {
- HloInstructionProto proto = HloInstruction::ToProto();
- *proto.mutable_replica_groups() = {replica_groups_.begin(),
- replica_groups_.end()};
- proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_);
- return proto;
+bool HloCollectivePermuteInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ /*eq_computations*/) const {
+ const auto& casted_other =
+ static_cast<const HloCollectivePermuteInstruction&>(other);
+ return absl::c_equal(source_target_pairs(),
+ casted_other.source_target_pairs(),
+ [](const std::pair<int64, int64>& a,
+ const std::pair<int64, int64>& b) { return a == b; });
}
-HloReverseInstruction::HloReverseInstruction(
- const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions)
+std::unique_ptr<HloInstruction>
+HloCollectivePermuteInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+ HloCloneContext* /*context*/) const {
+ return absl::make_unique<HloCollectivePermuteInstruction>(
+ shape, new_operands[0], source_target_pairs());
+}
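HloCollectivePermuteInstruction is new here: a single-operand collective whose routing is a list of (source, target) replica pairs rather than replica groups, which is why it derives from HloInstruction directly instead of the new collective base class. A construction sketch, with the shape and operand assumed to exist in the caller:

    // `shape` and `operand` are placeholders for the caller's values.
    std::vector<std::pair<int64, int64>> pairs = {{0, 1}, {1, 2}, {2, 0}};
    auto permute = absl::make_unique<HloCollectivePermuteInstruction>(
        shape, operand, pairs);
    // ToString() includes: source_target_pairs={{0,1},{1,2},{2,0}}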
+
+HloReverseInstruction::HloReverseInstruction(const Shape& shape,
+ HloInstruction* operand,
+ absl::Span<const int64> dimensions)
: HloInstruction(HloOpcode::kReverse, shape),
dimensions_(dimensions.begin(), dimensions.end()) {
AppendOperand(operand);
@@ -438,7 +494,7 @@ HloInstructionProto HloReverseInstruction::ToProto() const {
std::vector<string> HloReverseInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+ return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
}
bool HloReverseInstruction::IdenticalSlowPath(
@@ -450,16 +506,15 @@ bool HloReverseInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloReverseInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloReverseInstruction>(shape, new_operands[0],
- dimensions());
+ return absl::make_unique<HloReverseInstruction>(shape, new_operands[0],
+ dimensions());
}
HloConcatenateInstruction::HloConcatenateInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
int64 dimension)
: HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) {
for (auto operand : operands) {
@@ -477,7 +532,7 @@ HloInstructionProto HloConcatenateInstruction::ToProto() const {
std::vector<string> HloConcatenateInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+ return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
}
bool HloConcatenateInstruction::IdenticalSlowPath(
@@ -491,16 +546,15 @@ bool HloConcatenateInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloConcatenateInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
- return MakeUnique<HloConcatenateInstruction>(shape, new_operands,
- dimensions(0));
+ return absl::make_unique<HloConcatenateInstruction>(shape, new_operands,
+ dimensions(0));
}
HloReduceInstruction::HloReduceInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> args,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ const Shape& shape, absl::Span<HloInstruction* const> args,
+ absl::Span<const int64> dimensions_to_reduce,
HloComputation* reduce_computation)
: HloInstruction(HloOpcode::kReduce, shape),
dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) {
@@ -520,7 +574,7 @@ HloInstructionProto HloReduceInstruction::ToProto() const {
std::vector<string> HloReduceInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+ return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
}
bool HloReduceInstruction::IdenticalSlowPath(
@@ -535,12 +589,11 @@ bool HloReduceInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
- CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloReduceInstruction>(shape, new_operands, dimensions(),
- to_apply());
+ CHECK_EQ(new_operands.size() % 2, 0);
+ return absl::make_unique<HloReduceInstruction>(shape, new_operands,
+ dimensions(), to_apply());
}
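Note the relaxed invariant in the reduce clone path: CHECK_EQ(new_operands.size(), 2) becomes CHECK_EQ(new_operands.size() % 2, 0), i.e. reduce is now variadic, taking N inputs followed by N matching init values. A sketch under that reading (the operand and computation names are placeholders):

    // Two inputs plus two init values clone into one multi-output reduce.
    std::vector<HloInstruction*> args = {input0, input1, init0, init1};
    auto reduce = absl::make_unique<HloReduceInstruction>(
        tuple_shape, args, /*dimensions_to_reduce=*/{0}, reduce_computation);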
HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension,
@@ -563,7 +616,7 @@ HloInstructionProto HloSortInstruction::ToProto() const {
std::vector<string> HloSortInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+ return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
}
bool HloSortInstruction::IdenticalSlowPath(
@@ -575,17 +628,17 @@ bool HloSortInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
HloInstruction* keys = new_operands[0];
HloInstruction* values = new_operands.size() == 2 ? new_operands[1] : nullptr;
- return MakeUnique<HloSortInstruction>(shape, dimensions(0), keys, values);
+ return absl::make_unique<HloSortInstruction>(shape, dimensions(0), keys,
+ values);
}
HloTransposeInstruction::HloTransposeInstruction(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions)
+ absl::Span<const int64> dimensions)
: HloInstruction(HloOpcode::kTranspose, shape),
dimensions_(dimensions.begin(), dimensions.end()) {
CHECK_EQ(shape.dimensions().size(), dimensions.size());
@@ -595,7 +648,7 @@ HloTransposeInstruction::HloTransposeInstruction(
Permute(dimensions, shape.dimensions()).begin()))
<< "shape: " << ShapeUtil::HumanString(shape)
<< ", operand->shape(): " << ShapeUtil::HumanString(shape)
- << ", dimensions: {" << Join(dimensions, ", ") << "}";
+ << ", dimensions: {" << StrJoin(dimensions, ", ") << "}";
AppendOperand(operand);
}
@@ -616,7 +669,7 @@ HloInstructionProto HloTransposeInstruction::ToProto() const {
std::vector<string> HloTransposeInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+ return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
}
bool HloTransposeInstruction::IdenticalSlowPath(
@@ -629,17 +682,16 @@ bool HloTransposeInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloTransposeInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloTransposeInstruction>(shape, new_operands[0],
- dimensions());
+ return absl::make_unique<HloTransposeInstruction>(shape, new_operands[0],
+ dimensions());
}
HloBroadcastInstruction::HloBroadcastInstruction(
const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimension)
+ absl::Span<const int64> broadcast_dimension)
: HloInstruction(HloOpcode::kBroadcast, shape),
dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) {
AppendOperand(operand);
@@ -655,7 +707,7 @@ HloInstructionProto HloBroadcastInstruction::ToProto() const {
std::vector<string> HloBroadcastInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+ return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
}
bool HloBroadcastInstruction::IdenticalSlowPath(
@@ -668,17 +720,16 @@ bool HloBroadcastInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloBroadcastInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloBroadcastInstruction>(shape, new_operands[0],
- dimensions());
+ return absl::make_unique<HloBroadcastInstruction>(shape, new_operands[0],
+ dimensions());
}
-HloMapInstruction::HloMapInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* map_computation)
+HloMapInstruction::HloMapInstruction(const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ HloComputation* map_computation)
: HloInstruction(HloOpcode::kMap, shape) {
for (auto operand : operands) {
AppendOperand(operand);
@@ -699,7 +750,7 @@ HloInstructionProto HloMapInstruction::ToProto() const {
}
bool HloMapInstruction::IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const {
+ const absl::optional<int64>& operand_idx) const {
if (!dimensions().empty()) {
// Check that the map is executed in elementwise compatible dimensions.
if (dimensions().size() != shape().dimensions_size()) {
@@ -716,7 +767,7 @@ bool HloMapInstruction::IsElementwiseImpl(
std::vector<string> HloMapInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+ return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")};
}
bool HloMapInstruction::IdenticalSlowPath(
@@ -727,17 +778,16 @@ bool HloMapInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloMapInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
- return MakeUnique<HloMapInstruction>(shape, new_operands, to_apply());
+ return absl::make_unique<HloMapInstruction>(shape, new_operands, to_apply());
}
-HloSliceInstruction::HloSliceInstruction(
- const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides)
+HloSliceInstruction::HloSliceInstruction(const Shape& shape,
+ HloInstruction* operand,
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides)
: HloInstruction(HloOpcode::kSlice, shape),
slice_starts_(start_indices.begin(), start_indices.end()),
slice_limits_(limit_indices.begin(), limit_indices.end()),
@@ -774,7 +824,7 @@ std::vector<string> HloSliceInstruction::ExtraAttributesToStringImpl(
bounds.push_back(
StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]"));
}
- return {StrCat("slice={", Join(bounds, ", "), "}")};
+ return {StrCat("slice={", StrJoin(bounds, ", "), "}")};
}
bool HloSliceInstruction::IdenticalSlowPath(
@@ -788,16 +838,15 @@ bool HloSliceInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloSliceInstruction>(shape, new_operands[0], slice_starts_,
- slice_limits_, slice_strides_);
+ return absl::make_unique<HloSliceInstruction>(
+ shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_);
}
-HloConstantInstruction::HloConstantInstruction(std::unique_ptr<Literal> literal)
- : HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()),
+HloConstantInstruction::HloConstantInstruction(Literal literal)
+ : HloInstruction(HloOpcode::kConstant, literal.shape()),
literal_(std::move(literal)) {}
HloConstantInstruction::HloConstantInstruction(const Shape& shape)
@@ -805,14 +854,14 @@ HloConstantInstruction::HloConstantInstruction(const Shape& shape)
HloInstructionProto HloConstantInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
- if (literal_ != nullptr) {
+ if (literal_.has_value()) {
*proto.mutable_literal() = literal_->ToProto();
}
return proto;
}
bool HloConstantInstruction::IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const {
+ const absl::optional<int64>& operand_idx) const {
return true;
}
@@ -827,7 +876,7 @@ void HloConstantInstruction::RelayoutConstant(const Layout& new_layout,
if (!mutable_array_subshape->has_layout() ||
!LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
- literal_ = literal_->Relayout(new_layout, shape_index);
+ *literal_ = literal_->Relayout(new_layout, shape_index);
*mutable_array_subshape->mutable_layout() = new_layout;
}
}
@@ -842,10 +891,10 @@ bool HloConstantInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloConstantInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
- return MakeUnique<HloConstantInstruction>(literal_->CloneToUnique());
+ CHECK(literal_.has_value());
+ return absl::make_unique<HloConstantInstruction>(literal_->Clone());
}
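HloConstantInstruction stops holding the literal behind std::unique_ptr and stores it by value (in what is evidently an absl::optional, given the has_value()/operator-> mix): construction takes the Literal by move, presence checks use has_value(), and cloning copies with literal_->Clone() rather than CloneToUnique(). The new call convention, assuming LiteralUtil::CreateR0 now returns a Literal by value as the rest of this migration implies:

    Literal scalar = LiteralUtil::CreateR0<float>(42.0f);
    auto constant =
        absl::make_unique<HloConstantInstruction>(std::move(scalar));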
string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
@@ -853,14 +902,14 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
CanonicalNameMap* canonical_name_map) const {
string operands;
// For constants, show the actual value in place of an empty operand list.
- if (literal_ != nullptr &&
+ if (literal_.has_value() &&
((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) ||
options.print_large_constants())) {
// Literal::ToString emits multidimensional arrays over multiple
// lines. Compact this into one line by stripping out white space.
string tmp = literal().ToString();
std::replace(tmp.begin(), tmp.end(), '\n', ' ');
- std::vector<string> v = tensorflow::str_util::Split(tmp, ' ');
+ std::vector<string> v = absl::StrSplit(tmp, ' ');
bool first = true;
// Concatenate elements in "v" with spaces separating them, but ignoring
// empty entries.
@@ -888,7 +937,7 @@ HloTraceInstruction::HloTraceInstruction(const string& tag,
HloInstructionProto HloTraceInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
- *proto.mutable_literal() = literal_->ToProto();
+ *proto.mutable_literal() = literal_.ToProto();
return proto;
}
@@ -900,8 +949,7 @@ bool HloTraceInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloTraceInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode());
}
@@ -919,7 +967,7 @@ HloFusionInstruction::HloFusionInstruction(const Shape& shape,
HloFusionInstruction::HloFusionInstruction(
const Shape& shape, FusionKind fusion_kind,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ absl::Span<HloInstruction* const> operands,
HloComputation* fusion_computation)
: HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) {
for (auto operand : operands) {
@@ -952,7 +1000,7 @@ HloInstructionProto HloFusionInstruction::ToProto() const {
}
bool HloFusionInstruction::IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const {
+ const absl::optional<int64>& operand_idx) const {
if (!operand_idx.has_value()) {
for (auto* fused : fused_instructions()) {
if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) {
@@ -1155,7 +1203,7 @@ HloInstruction* HloFusionInstruction::FuseInstructionInternal(
HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
HloInstruction* instruction_to_fuse, bool add_output) {
- CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString();
+ CHECK(instruction_to_fuse->IsFusible()) << instruction_to_fuse->ToString();
VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString();
HloInstruction* clone = nullptr;
if (called_computations().empty()) {
@@ -1326,8 +1374,7 @@ bool HloFusionInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
HloModule* module = context != nullptr ? context->module() : GetModule();
HloComputation* new_fused_computation = nullptr;
@@ -1339,8 +1386,8 @@ std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
new_fused_computation = module->AddEmbeddedComputation(
fused_instructions_computation()->Clone("clone", context));
}
- return MakeUnique<HloFusionInstruction>(shape, fusion_kind(), new_operands,
- new_fused_computation);
+ return absl::make_unique<HloFusionInstruction>(
+ shape, fusion_kind(), new_operands, new_fused_computation);
}
Status HloFusionInstruction::DeduplicateFusionOperands() {
@@ -1365,7 +1412,7 @@ Status HloFusionInstruction::DeduplicateFusionOperands() {
HloRngInstruction::HloRngInstruction(
const Shape& shape, RandomDistribution distribution,
- tensorflow::gtl::ArraySlice<HloInstruction*> parameters)
+ absl::Span<HloInstruction* const> parameters)
: HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) {
for (HloInstruction* param : parameters) {
AppendOperand(param);
@@ -1384,7 +1431,7 @@ std::vector<string> HloRngInstruction::ExtraAttributesToStringImpl(
}
bool HloRngInstruction::IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const {
+ const absl::optional<int64>& operand_idx) const {
return true;
}
@@ -1396,10 +1443,10 @@ bool HloRngInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloRngInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
- return MakeUnique<HloRngInstruction>(shape, distribution_, new_operands);
+ return absl::make_unique<HloRngInstruction>(shape, distribution_,
+ new_operands);
}
HloParameterInstruction::HloParameterInstruction(int64 parameter_number,
@@ -1432,10 +1479,10 @@ bool HloParameterInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloParameterInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
- return MakeUnique<HloParameterInstruction>(parameter_number_, shape, name());
+ return absl::make_unique<HloParameterInstruction>(parameter_number_, shape,
+ name());
}
HloGetTupleElementInstruction::HloGetTupleElementInstruction(
@@ -1467,12 +1514,11 @@ bool HloGetTupleElementInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloGetTupleElementInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloGetTupleElementInstruction>(shape, new_operands[0],
- tuple_index());
+ return absl::make_unique<HloGetTupleElementInstruction>(
+ shape, new_operands[0], tuple_index());
}
HloReducePrecisionInstruction::HloReducePrecisionInstruction(
@@ -1510,11 +1556,10 @@ bool HloReducePrecisionInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloReducePrecisionInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloReducePrecisionInstruction>(
+ return absl::make_unique<HloReducePrecisionInstruction>(
shape, new_operands[0], exponent_bits(), mantissa_bits());
}
@@ -1551,20 +1596,20 @@ bool HloInfeedInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloInfeedInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloInfeedInstruction>(infeed_shape(), new_operands[0],
- infeed_config());
+ return absl::make_unique<HloInfeedInstruction>(
+ infeed_shape(), new_operands[0], infeed_config());
}
-HloOutfeedInstruction::HloOutfeedInstruction(
- const Shape& outfeed_shape, HloInstruction* operand,
- HloInstruction* token_operand, tensorflow::StringPiece outfeed_config)
+HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape,
+ HloInstruction* operand,
+ HloInstruction* token_operand,
+ absl::string_view outfeed_config)
: HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()),
outfeed_shape_(outfeed_shape),
- outfeed_config_(outfeed_config.begin(), outfeed_config.end()) {
+ outfeed_config_(outfeed_config) {
CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape))
<< "Outfeed shape " << outfeed_shape
<< " must be compatible with operand shape " << operand->shape();
@@ -1596,22 +1641,23 @@ bool HloOutfeedInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloOutfeedInstruction>(outfeed_shape(), new_operands[0],
- new_operands[1], outfeed_config());
+ return absl::make_unique<HloOutfeedInstruction>(
+ outfeed_shape(), new_operands[0], new_operands[1], outfeed_config());
}
HloConvolutionInstruction::HloConvolutionInstruction(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count)
+ int64 feature_group_count, const Window& window,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ const PrecisionConfig& precision_config)
: HloInstruction(HloOpcode::kConvolution, shape),
+ feature_group_count_(feature_group_count),
window_(window),
convolution_dimension_numbers_(dimension_numbers),
- feature_group_count_(feature_group_count) {
+ precision_config_(precision_config) {
if (window_util::HasBaseDilation(window)) {
SetAndSanitizeName(StrCat(name(), "-base-dilated"));
}
@@ -1638,6 +1684,8 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const {
*proto.mutable_window() = window_;
*proto.mutable_convolution_dimension_numbers() =
convolution_dimension_numbers_;
+ proto.set_feature_group_count(feature_group_count_);
+ *proto.mutable_precision_config() = precision_config_;
return proto;
}
@@ -1649,7 +1697,15 @@ std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
}
extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString(
convolution_dimension_numbers_)));
- extra.push_back(StrCat("feature_group_count=", feature_group_count_));
+ if (feature_group_count_ != 1) {
+ extra.push_back(StrCat("feature_group_count=", feature_group_count_));
+ }
+
+ string precision_config_string = PrecisionConfigToString(precision_config_);
+ if (!precision_config_string.empty()) {
+ extra.push_back(precision_config_string);
+ }
+
return extra;
}
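Convolution printing now suppresses feature_group_count when it is 1 and appends the operand-precision string only when non-empty, so ToString() stays unchanged for ordinary convolutions. The resulting attribute lists, sketched for both cases (the window and dim_labels values are illustrative):

    // feature_group_count == 1, all-DEFAULT precisions:
    //   window={size=1}, dim_labels=b0f_0io->b0f
    // feature_group_count == 2, operand precisions {HIGH, DEFAULT}:
    //   window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=2,
    //   operand_precision={high,default}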
@@ -1659,21 +1715,25 @@ bool HloConvolutionInstruction::IdenticalSlowPath(
eq_computations) const {
const auto& casted_other =
static_cast<const HloConvolutionInstruction&>(other);
+  if (feature_group_count_ != casted_other.feature_group_count()) {
+ return false;
+ }
return protobuf_util::ProtobufEquals(window(), casted_other.window()) &&
protobuf_util::ProtobufEquals(
convolution_dimension_numbers(),
- casted_other.convolution_dimension_numbers());
+ casted_other.convolution_dimension_numbers()) &&
+ protobuf_util::ProtobufEquals(precision_config(),
+ casted_other.precision_config());
}
std::unique_ptr<HloInstruction>
HloConvolutionInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloConvolutionInstruction>(
- shape, new_operands[0], new_operands[1], window(),
- convolution_dimension_numbers_, feature_group_count_);
+ return absl::make_unique<HloConvolutionInstruction>(
+ shape, new_operands[0], new_operands[1], feature_group_count_, window(),
+ convolution_dimension_numbers_, precision_config_);
}
HloReduceWindowInstruction::HloReduceWindowInstruction(
@@ -1712,11 +1772,10 @@ bool HloReduceWindowInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloReduceWindowInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloReduceWindowInstruction>(
+ return absl::make_unique<HloReduceWindowInstruction>(
shape, new_operands[0], new_operands[1], window(), to_apply());
}
@@ -1761,21 +1820,20 @@ bool HloSelectAndScatterInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 3);
- return MakeUnique<HloSelectAndScatterInstruction>(
+ return absl::make_unique<HloSelectAndScatterInstruction>(
shape, new_operands[0], select(), window(), new_operands[1],
new_operands[2], scatter());
}
HloCustomCallInstruction::HloCustomCallInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece custom_call_target)
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ absl::string_view custom_call_target)
: HloInstruction(HloOpcode::kCustomCall, shape),
- custom_call_target_(custom_call_target.begin(),
- custom_call_target.end()) {
+ custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
+ feature_group_count_(1) {
for (auto operand : operands) {
AppendOperand(operand);
}
@@ -1791,6 +1849,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
*convolution_dimension_numbers_;
}
proto.set_custom_call_target(custom_call_target_);
+ proto.set_feature_group_count(feature_group_count_);
return proto;
}
@@ -1805,6 +1864,9 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
"dim_labels=",
ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_)));
}
+ if (feature_group_count_ != 1) {
+ extra.push_back(StrCat("feature_group_count=", feature_group_count_));
+ }
// By contract, we print the custom call target even if
// options.print_subcomputation_mode() == kOff, because the call target is not
// an HloComputation.
@@ -1832,60 +1894,28 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
casted_other.convolution_dimension_numbers()))) {
return false;
}
+ if (feature_group_count_ != casted_other.feature_group_count_) {
+ return false;
+ }
return custom_call_target_ == casted_other.custom_call_target_;
}
std::unique_ptr<HloInstruction>
HloCustomCallInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
- auto cloned = MakeUnique<HloCustomCallInstruction>(shape, new_operands,
- custom_call_target());
+ auto cloned = absl::make_unique<HloCustomCallInstruction>(
+ shape, new_operands, custom_call_target());
if (window_ != nullptr) {
cloned->set_window(*window_);
}
if (convolution_dimension_numbers_ != nullptr) {
cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_);
}
+ cloned->set_feature_group_count(feature_group_count_);
return std::move(cloned);
}
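HloCustomCallInstruction grows the same feature_group_count attribute (defaulting to 1), threaded through ToProto, printing, equality, and clone, presumably so convolutions lowered to custom calls keep their grouping. A usage sketch; the target name and operand list are placeholders:

    auto call = absl::make_unique<HloCustomCallInstruction>(
        shape, operands, /*custom_call_target=*/"grouped_conv");
    call->set_feature_group_count(2);  // printed as feature_group_count=2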
-HloHostComputeInstruction::HloHostComputeInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece channel_name, const int64 cost_estimate_ns)
- : HloInstruction(HloOpcode::kHostCompute, shape),
- channel_name_(channel_name.begin(), channel_name.end()),
- cost_estimate_ns_(cost_estimate_ns) {
- for (auto operand : operands) {
- AppendOperand(operand);
- }
-}
-
-HloInstructionProto HloHostComputeInstruction::ToProto() const {
- HloInstructionProto proto = HloInstruction::ToProto();
- proto.set_channel_name(channel_name_);
- proto.set_cost_estimate_ns(cost_estimate_ns_);
- return proto;
-}
-
-bool HloHostComputeInstruction::IdenticalSlowPath(
- const HloInstruction& other,
- const std::function<bool(const HloComputation*, const HloComputation*)>&
- eq_computations) const {
- // Not yet supported.
- return false;
-}
-
-std::unique_ptr<HloInstruction>
-HloHostComputeInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
- HloCloneContext* context) const {
- return MakeUnique<HloHostComputeInstruction>(
- shape, new_operands, channel_name_, cost_estimate_ns_);
-}
-
HloPadInstruction::HloPadInstruction(const Shape& shape,
HloInstruction* operand,
HloInstruction* padding_value,
@@ -1916,17 +1946,16 @@ bool HloPadInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloPadInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloPadInstruction>(shape, new_operands[0], new_operands[1],
- padding_config_);
+ return absl::make_unique<HloPadInstruction>(shape, new_operands[0],
+ new_operands[1], padding_config_);
}
HloDynamicSliceInstruction::HloDynamicSliceInstruction(
const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes)
+ absl::Span<const int64> slice_sizes)
: HloInstruction(HloOpcode::kDynamicSlice, shape),
dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
AppendOperand(operand);
@@ -1943,8 +1972,8 @@ HloInstructionProto HloDynamicSliceInstruction::ToProto() const {
std::vector<string> HloDynamicSliceInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {
- StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")};
+ return {StrCat("dynamic_slice_sizes={", StrJoin(dynamic_slice_sizes(), ","),
+ "}")};
}
bool HloDynamicSliceInstruction::IdenticalSlowPath(
@@ -1956,60 +1985,57 @@ bool HloDynamicSliceInstruction::IdenticalSlowPath(
std::unique_ptr<HloInstruction>
HloDynamicSliceInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloDynamicSliceInstruction>(
+ return absl::make_unique<HloDynamicSliceInstruction>(
shape, new_operands[0], new_operands[1], dynamic_slice_sizes_);
}
HloGatherInstruction::HloGatherInstruction(
- const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices,
+ const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds)
+ absl::Span<const int64> slice_sizes)
: HloInstruction(HloOpcode::kGather, shape) {
AppendOperand(operand);
- AppendOperand(gather_indices);
+ AppendOperand(start_indices);
gather_dimension_numbers_ =
- MakeUnique<GatherDimensionNumbers>(gather_dim_numbers);
- c_copy(window_bounds, std::back_inserter(gather_window_bounds_));
+ absl::make_unique<GatherDimensionNumbers>(gather_dim_numbers);
+ absl::c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_));
}
string HloGatherInstruction::GatherDimensionNumbersToString() const {
CHECK(gather_dimension_numbers_ != nullptr);
- string output_window_dims =
- StrCat("output_window_dims={",
- Join(gather_dimension_numbers_->output_window_dims(), ","), "}");
- string elided_window_dims =
- StrCat("elided_window_dims={",
- Join(gather_dimension_numbers_->elided_window_dims(), ","), "}");
- string gather_dims_to_operand_dims = StrCat(
- "gather_dims_to_operand_dims={",
- Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}");
+ string offset_dims =
+ StrCat("offset_dims={",
+ StrJoin(gather_dimension_numbers_->offset_dims(), ","), "}");
+ string collapsed_slice_dims = StrCat(
+ "collapsed_slice_dims={",
+ StrJoin(gather_dimension_numbers_->collapsed_slice_dims(), ","), "}");
+ string start_index_map =
+ StrCat("start_index_map={",
+ StrJoin(gather_dimension_numbers_->start_index_map(), ","), "}");
string index_vector_dim = StrCat(
"index_vector_dim=", gather_dimension_numbers_->index_vector_dim());
- return Join<std::initializer_list<string>>(
- {output_window_dims, elided_window_dims, gather_dims_to_operand_dims,
- index_vector_dim},
+ return StrJoin<std::initializer_list<string>>(
+ {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim},
", ");
}
/* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers(
- tensorflow::gtl::ArraySlice<int64> output_window_dims,
- tensorflow::gtl::ArraySlice<int64> elided_window_dims,
- tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
- int64 index_vector_dim) {
+ absl::Span<const int64> offset_dims,
+ absl::Span<const int64> collapsed_slice_dims,
+ absl::Span<const int64> start_index_map, int64 index_vector_dim) {
GatherDimensionNumbers gather_dim_numbers;
- for (int64 output_window_dim : output_window_dims) {
- gather_dim_numbers.add_output_window_dims(output_window_dim);
+ for (int64 output_window_dim : offset_dims) {
+ gather_dim_numbers.add_offset_dims(output_window_dim);
}
- for (int64 elided_window_dim : elided_window_dims) {
- gather_dim_numbers.add_elided_window_dims(elided_window_dim);
+ for (int64 elided_window_dim : collapsed_slice_dims) {
+ gather_dim_numbers.add_collapsed_slice_dims(elided_window_dim);
}
- for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) {
- gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim);
+ for (int64 gather_dim_to_input_dim : start_index_map) {
+ gather_dim_numbers.add_start_index_map(gather_dim_to_input_dim);
}
gather_dim_numbers.set_index_vector_dim(index_vector_dim);
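A minimal usage sketch of the renamed factory (argument values hypothetical;
the comments map each parameter onto its pre-rename field name):

    GatherDimensionNumbers dnums = HloGatherInstruction::MakeGatherDimNumbers(
        /*offset_dims=*/{1},           // was output_window_dims
        /*collapsed_slice_dims=*/{0},  // was elided_window_dims
        /*start_index_map=*/{0},       // was gather_dims_to_operand_dims
        /*index_vector_dim=*/1);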
@@ -2019,8 +2045,8 @@ string HloGatherInstruction::GatherDimensionNumbersToString() const {
HloInstructionProto HloGatherInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
*proto.mutable_gather_dimension_numbers() = gather_dimension_numbers();
- for (int64 bound : gather_window_bounds()) {
- proto.add_gather_window_bounds(bound);
+ for (int64 bound : gather_slice_sizes()) {
+ proto.add_gather_slice_sizes(bound);
}
return proto;
}
@@ -2028,7 +2054,7 @@ HloInstructionProto HloGatherInstruction::ToProto() const {
std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
return {GatherDimensionNumbersToString(),
- StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")};
+ StrCat("slice_sizes={", StrJoin(gather_slice_sizes(), ","), "}")};
}
bool HloGatherInstruction::IdenticalSlowPath(
@@ -2039,17 +2065,16 @@ bool HloGatherInstruction::IdenticalSlowPath(
return protobuf_util::ProtobufEquals(
gather_dimension_numbers(),
casted_other.gather_dimension_numbers()) &&
- gather_window_bounds() == casted_other.gather_window_bounds();
+ gather_slice_sizes() == casted_other.gather_slice_sizes();
}
std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloGatherInstruction>(
+ return absl::make_unique<HloGatherInstruction>(
shape, new_operands[0], new_operands[1], gather_dimension_numbers(),
- gather_window_bounds());
+ gather_slice_sizes());
}
HloScatterInstruction::HloScatterInstruction(
@@ -2063,24 +2088,24 @@ HloScatterInstruction::HloScatterInstruction(
AppendOperand(updates);
AppendComputation(update_computation);
scatter_dimension_numbers_ =
- MakeUnique<ScatterDimensionNumbers>(scatter_dim_numbers);
+ absl::make_unique<ScatterDimensionNumbers>(scatter_dim_numbers);
}
string HloScatterInstruction::ScatterDimensionNumbersToString() const {
- string update_window_dims =
- StrCat("update_window_dims={",
- Join(scatter_dimension_numbers().update_window_dims(), ","), "}");
+ string update_window_dims = StrCat(
+ "update_window_dims={",
+ StrJoin(scatter_dimension_numbers().update_window_dims(), ","), "}");
string inserted_window_dims = StrCat(
"inserted_window_dims={",
- Join(scatter_dimension_numbers().inserted_window_dims(), ","), "}");
+ StrJoin(scatter_dimension_numbers().inserted_window_dims(), ","), "}");
string scatter_dims_to_operand_dims = StrCat(
"scatter_dims_to_operand_dims={",
- Join(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","),
+ StrJoin(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","),
"}");
string index_vector_dim = StrCat(
"index_vector_dim=", scatter_dimension_numbers().index_vector_dim());
- return Join<std::initializer_list<string>>(
+ return StrJoin<std::initializer_list<string>>(
{update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims,
index_vector_dim},
", ");
@@ -2088,9 +2113,9 @@ string HloScatterInstruction::ScatterDimensionNumbersToString() const {
/* static */ ScatterDimensionNumbers
HloScatterInstruction::MakeScatterDimNumbers(
- tensorflow::gtl::ArraySlice<int64> update_window_dims,
- tensorflow::gtl::ArraySlice<int64> inserted_window_dims,
- tensorflow::gtl::ArraySlice<int64> scatter_dims_to_operand_dims,
+ absl::Span<const int64> update_window_dims,
+ absl::Span<const int64> inserted_window_dims,
+ absl::Span<const int64> scatter_dims_to_operand_dims,
int64 index_vector_dim) {
ScatterDimensionNumbers scatter_dim_numbers;
for (int64 update_window_dim : update_window_dims) {
@@ -2130,13 +2155,150 @@ bool HloScatterInstruction::IdenticalSlowPath(
}
std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 3);
- return MakeUnique<HloScatterInstruction>(
+ return absl::make_unique<HloScatterInstruction>(
shape, new_operands[0], new_operands[1], new_operands[2], to_apply(),
scatter_dimension_numbers());
}
+HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension)
+ : HloInstruction(HloOpcode::kIota, shape),
+ iota_dimension_(iota_dimension) {}
+
+HloInstructionProto HloIotaInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ proto.add_dimensions(iota_dimension());
+ return proto;
+}
+
+std::vector<string> HloIotaInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("iota_dimension=", iota_dimension())};
+}
+
+bool HloIotaInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloIotaInstruction&>(other);
+ return iota_dimension() == casted_other.iota_dimension();
+}
+
+std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+ HloCloneContext* context) const {
+ return absl::make_unique<HloIotaInstruction>(shape, iota_dimension());
+}
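A worked example of the new opcode's semantics (values follow directly from
the iota definition):

    // Iota fills each element with its index along iota_dimension.
    // For shape f32[2,3]:
    //   iota_dimension=0  =>  [[0,0,0],
    //                          [1,1,1]]
    //   iota_dimension=1  =>  [[0,1,2],
    //                          [0,1,2]]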
+
+HloDotInstruction::HloDotInstruction(
+ const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfig& precision_config)
+ : HloInstruction(HloOpcode::kDot, shape),
+ dot_dimension_numbers_(dimension_numbers),
+ precision_config_(precision_config) {
+ AppendOperand(lhs);
+ AppendOperand(rhs);
+}
+
+HloInstructionProto HloDotInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_;
+ *proto.mutable_precision_config() = precision_config_;
+ return proto;
+}
+
+std::vector<string> HloDotInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ std::vector<string> extra = {DotDimensionNumbersToString()};
+
+ string precision_config_string = PrecisionConfigToString(precision_config_);
+ if (!precision_config_string.empty()) {
+ extra.push_back(precision_config_string);
+ }
+ return extra;
+}
+
+bool HloDotInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloDotInstruction&>(other);
+ return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
+ casted_other.dot_dimension_numbers()) &&
+ protobuf_util::ProtobufEquals(precision_config(),
+ casted_other.precision_config());
+}
+
+std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 2);
+ return absl::make_unique<HloDotInstruction>(
+ shape, new_operands[0], new_operands[1], dot_dimension_numbers_,
+ precision_config_);
+}
+
+string HloDotInstruction::DotDimensionNumbersToString() const {
+ std::vector<string> result;
+ const DotDimensionNumbers& dnums = dot_dimension_numbers_;
+ if (!dnums.lhs_batch_dimensions().empty()) {
+ result.push_back(StrCat("lhs_batch_dims={",
+ StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
+ }
+ result.push_back(StrCat("lhs_contracting_dims={",
+ StrJoin(dnums.lhs_contracting_dimensions(), ","),
+ "}"));
+
+ if (!dnums.rhs_batch_dimensions().empty()) {
+ result.push_back(StrCat("rhs_batch_dims={",
+ StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
+ }
+ result.push_back(StrCat("rhs_contracting_dims={",
+ StrJoin(dnums.rhs_contracting_dimensions(), ","),
+ "}"));
+
+ return StrJoin(result, ", ");
+}
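For reference, a sketch of the attribute string this produces for a
hypothetical batched matmul (batch dimension 0, contracting dimensions 2
and 1):

    // dnums: lhs_batch={0}, lhs_contracting={2},
    //        rhs_batch={0}, rhs_contracting={1}
    // DotDimensionNumbersToString() =>
    //   "lhs_batch_dims={0}, lhs_contracting_dims={2}, "
    //   "rhs_batch_dims={0}, rhs_contracting_dims={1}"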
+
+HloDomainInstruction::HloDomainInstruction(
+ const Shape& shape, HloInstruction* operand,
+ std::unique_ptr<DomainMetadata> operand_side_metadata,
+ std::unique_ptr<DomainMetadata> user_side_metadata)
+ : HloInstruction(HloOpcode::kDomain, shape),
+ operand_side_metadata_(std::move(operand_side_metadata)),
+ user_side_metadata_(std::move(user_side_metadata)) {
+ AppendOperand(operand);
+}
+
+std::vector<string> HloDomainInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
+ return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
+ "\", entry=", user_side_metadata_->ToString(),
+ ", exit=", operand_side_metadata_->ToString(), "}")};
+ }
+ return {};
+}
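A sketch of the attribute this prints for a sharding-kind domain; the entry
and exit strings depend entirely on the DomainMetadata subclass's ToString(),
so the values below are illustrative only:

    // domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}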
+
+bool HloDomainInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloDomainInstruction&>(other);
+ return operand_side_metadata().Matches(
+ casted_other.operand_side_metadata()) &&
+ user_side_metadata().Matches(casted_other.user_side_metadata());
+}
+
+std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return absl::make_unique<HloDomainInstruction>(
+ shape, new_operands[0], operand_side_metadata_->Clone(),
+ user_side_metadata_->Clone());
+}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 546949bc72..2d7bc83855 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -18,6 +18,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
namespace xla {
@@ -66,8 +67,7 @@ class HloBatchNormTrainingInstruction : public HloBatchNormInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
@@ -81,8 +81,7 @@ class HloBatchNormInferenceInstruction : public HloBatchNormInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
@@ -96,8 +95,7 @@ class HloBatchNormGradInstruction : public HloBatchNormInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
@@ -105,7 +103,7 @@ class HloFftInstruction : public HloInstruction {
public:
explicit HloFftInstruction(const Shape& shape, HloInstruction* operand,
FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length);
+ absl::Span<const int64> fft_length);
FftType fft_type() const { return fft_type_; }
const std::vector<int64>& fft_length() const { return fft_length_; }
@@ -123,8 +121,7 @@ class HloFftInstruction : public HloInstruction {
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// Describes FFT type for an FFT instruction.
@@ -173,8 +170,7 @@ class HloSendInstruction : public HloSendRecvInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
@@ -186,8 +182,7 @@ class HloSendDoneInstruction : public HloSendRecvInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
@@ -199,8 +194,7 @@ class HloRecvInstruction : public HloSendRecvInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
@@ -212,24 +206,41 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction {
private:
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
};
-class HloAllReduceInstruction : public HloInstruction {
+class HloCollectiveInstruction : public HloInstruction {
+ public:
+ const std::vector<ReplicaGroup>& replica_groups() const {
+ return replica_groups_;
+ }
+
+ protected:
+ explicit HloCollectiveInstruction(
+ HloOpcode opcode, const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ const std::vector<ReplicaGroup>& replica_groups);
+
+ HloInstructionProto ToProto() const override;
+
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+
+ std::vector<ReplicaGroup> replica_groups_;
+};
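A sketch of how a subclass now wires into the shared base (the constructor
body is assumed here; the real definition lives in hlo_instructions.cc):

    HloAllToAllInstruction::HloAllToAllInstruction(
        const Shape& shape, absl::Span<HloInstruction* const> operands,
        const std::vector<ReplicaGroup>& replica_groups)
        : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
                                   replica_groups) {}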
+
+class HloAllReduceInstruction : public HloCollectiveInstruction {
public:
explicit HloAllReduceInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids,
- tensorflow::StringPiece barrier,
- const tensorflow::gtl::optional<int64>& all_reduce_id);
-
- // Returns the group ids of each replica for CrossReplicaSum op.
- const std::vector<int64>& replica_group_ids() const {
- return replica_group_ids_;
- }
+ const std::vector<ReplicaGroup>& replica_groups,
+ absl::string_view barrier, const absl::optional<int64>& all_reduce_id);
// Returns the barrier config used for the CrossReplicaSum implementation of
// each backend.
@@ -240,9 +251,7 @@ class HloAllReduceInstruction : public HloInstruction {
cross_replica_sum_barrier_ = barrier;
}
- tensorflow::gtl::optional<int64> all_reduce_id() const {
- return all_reduce_id_;
- }
+ absl::optional<int64> all_reduce_id() const { return all_reduce_id_; }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -257,41 +266,42 @@ class HloAllReduceInstruction : public HloInstruction {
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
- // The group id of each replica for CrossReplicaSum.
- std::vector<int64> replica_group_ids_;
-
// The string representation of the barrier config used for CrossReplicaSum.
string cross_replica_sum_barrier_;
  // For Allreduce nodes from different modules, if they have the same
  // all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will not be
  // applied across modules.
- tensorflow::gtl::optional<int64> all_reduce_id_;
+ absl::optional<int64> all_reduce_id_;
};
-class HloAllToAllInstruction : public HloInstruction {
+class HloAllToAllInstruction : public HloCollectiveInstruction {
public:
explicit HloAllToAllInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operand,
- const std::vector<ReplicaGroup>& replica_groups,
- tensorflow::StringPiece barrier);
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ const std::vector<ReplicaGroup>& replica_groups);
- const std::vector<ReplicaGroup>& replica_groups() const {
- return replica_groups_;
- }
+ private:
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+ HloCloneContext* context) const override;
+};
- // TODO(b/110096724): rename this.
- void set_cross_replica_sum_barrier(string barrier) {
- cross_replica_sum_barrier_ = barrier;
- }
- string cross_replica_sum_barrier() const {
- return cross_replica_sum_barrier_;
+class HloCollectivePermuteInstruction : public HloInstruction {
+ public:
+ explicit HloCollectivePermuteInstruction(
+ const Shape& shape, HloInstruction* operand,
+ const std::vector<std::pair<int64, int64>>& source_target_pairs);
+
+ const std::vector<std::pair<int64, int64>>& source_target_pairs() const {
+ return source_target_pairs_;
}
+ // Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
private:
@@ -304,20 +314,16 @@ class HloAllToAllInstruction : public HloInstruction {
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
- std::vector<ReplicaGroup> replica_groups_;
-
- // The string representation of the barrier config.
- string cross_replica_sum_barrier_;
+ const std::vector<std::pair<int64, int64>> source_target_pairs_;
};
class HloReverseInstruction : public HloInstruction {
public:
explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ absl::Span<const int64> dimensions);
// Returns the dimension sizes or numbers associated with this instruction.
const std::vector<int64>& dimensions() const override { return dimensions_; }
int64 dimensions(int64 index) const override { return dimensions()[index]; }
@@ -333,8 +339,7 @@ class HloReverseInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::vector<int64> dimensions_;
@@ -342,9 +347,9 @@ class HloReverseInstruction : public HloInstruction {
class HloConcatenateInstruction : public HloInstruction {
public:
- explicit HloConcatenateInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- int64 dimension);
+ explicit HloConcatenateInstruction(const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ int64 dimension);
// Returns the dimension sizes or numbers associated with this instruction.
const std::vector<int64>& dimensions() const override { return dimensions_; }
int64 dimensions(int64 index) const override { return dimensions()[index]; }
@@ -362,8 +367,7 @@ class HloConcatenateInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::vector<int64> dimensions_;
@@ -371,26 +375,28 @@ class HloConcatenateInstruction : public HloInstruction {
class HloReduceInstruction : public HloInstruction {
public:
- explicit HloReduceInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> args,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
- HloComputation* reduce_computation);
+ explicit HloReduceInstruction(const Shape& shape,
+ absl::Span<HloInstruction* const> args,
+ absl::Span<const int64> dimensions_to_reduce,
+ HloComputation* reduce_computation);
// Returns the dimension sizes or numbers associated with this instruction.
const std::vector<int64>& dimensions() const override { return dimensions_; }
int64 dimensions(int64 index) const override { return dimensions()[index]; }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
+  // Returns the number of input arrays (and, consequently, the number of
+  // init values) this reduce has.
+ int64 input_count() const { return operand_count() / 2; }
+
// Returns the input tensors to be reduced.
- tensorflow::gtl::ArraySlice<HloInstruction*> inputs() const {
- return tensorflow::gtl::ArraySlice<HloInstruction*>(operands(), 0,
- operand_count() / 2);
+ absl::Span<HloInstruction* const> inputs() const {
+ return absl::MakeSpan(operands()).subspan(0, input_count());
}
// Returns the init values of the reduction.
- tensorflow::gtl::ArraySlice<HloInstruction*> init_values() const {
- return tensorflow::gtl::ArraySlice<HloInstruction*>(
- operands(), operand_count() / 2, operand_count());
+ absl::Span<HloInstruction* const> init_values() const {
+ return absl::MakeSpan(operands()).subspan(input_count(), operand_count());
}
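Note on the accessors above: absl::Span::subspan(pos, len) clamps len to
size() - pos, so passing operand_count() as the length of init_values() still
yields exactly the trailing half of the operands. A standalone sketch:

    std::vector<int> v = {1, 2, 3, 4};             // inputs {1,2}, inits {3,4}
    auto inits = absl::MakeSpan(v).subspan(2, 4);  // len clamped from 4 to 2
    CHECK_EQ(inits.size(), 2);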
private:
@@ -402,8 +408,7 @@ class HloReduceInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::vector<int64> dimensions_;
@@ -431,8 +436,7 @@ class HloSortInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::vector<int64> dimensions_;
@@ -440,9 +444,8 @@ class HloSortInstruction : public HloInstruction {
class HloTransposeInstruction : public HloInstruction {
public:
- explicit HloTransposeInstruction(
- const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ explicit HloTransposeInstruction(const Shape& shape, HloInstruction* operand,
+ absl::Span<const int64> dimensions);
// Returns the dimension sizes or numbers associated with this instruction.
const std::vector<int64>& dimensions() const override { return dimensions_; }
int64 dimensions(int64 index) const override { return dimensions()[index]; }
@@ -460,8 +463,7 @@ class HloTransposeInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::vector<int64> dimensions_;
@@ -469,9 +471,8 @@ class HloTransposeInstruction : public HloInstruction {
class HloBroadcastInstruction : public HloInstruction {
public:
- explicit HloBroadcastInstruction(
- const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimension);
+ explicit HloBroadcastInstruction(const Shape& shape, HloInstruction* operand,
+ absl::Span<const int64> broadcast_dimension);
// Returns the dimension sizes or numbers associated with this instruction.
const std::vector<int64>& dimensions() const override { return dimensions_; }
int64 dimensions(int64 index) const override { return dimensions()[index]; }
@@ -487,8 +488,7 @@ class HloBroadcastInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::vector<int64> dimensions_;
@@ -496,9 +496,9 @@ class HloBroadcastInstruction : public HloInstruction {
class HloMapInstruction : public HloInstruction {
public:
- explicit HloMapInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* map_computation);
+ explicit HloMapInstruction(const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ HloComputation* map_computation);
// Returns the dimension sizes or numbers associated with this instruction.
const std::vector<int64>& dimensions() const override { return dimensions_; }
int64 dimensions(int64 index) const override { return dimensions()[index]; }
@@ -507,7 +507,7 @@ class HloMapInstruction : public HloInstruction {
private:
bool IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const override;
+ const absl::optional<int64>& operand_idx) const override;
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
bool IdenticalSlowPath(
@@ -516,8 +516,7 @@ class HloMapInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::vector<int64> dimensions_;
@@ -526,9 +525,9 @@ class HloMapInstruction : public HloInstruction {
class HloSliceInstruction : public HloInstruction {
public:
explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
+ absl::Span<const int64> start_indices,
+ absl::Span<const int64> limit_indices,
+ absl::Span<const int64> strides);
HloInstructionProto ToProto() const override;
@@ -567,8 +566,7 @@ class HloSliceInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// Describes the [begin, end) index range for a slice.
@@ -582,13 +580,13 @@ class HloSliceInstruction : public HloInstruction {
class HloConstantInstruction : public HloInstruction {
public:
- explicit HloConstantInstruction(std::unique_ptr<Literal> literal);
+ explicit HloConstantInstruction(Literal literal);
// Used when the literal is too large and dropped.
explicit HloConstantInstruction(const Shape& shape);
// Returns the literal associated with this instruction.
const Literal& literal() const { return *literal_; }
  // Returns whether there is a literal associated with this instruction.
- bool HasLiteral() const { return literal_ != nullptr; }
+ bool HasLiteral() const { return literal_.has_value(); }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -600,7 +598,7 @@ class HloConstantInstruction : public HloInstruction {
private:
bool IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const override;
+ const absl::optional<int64>& operand_idx) const override;
bool IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
@@ -610,18 +608,16 @@ class HloConstantInstruction : public HloInstruction {
CanonicalNameMap* canonical_name_map) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
- // TODO(b/36360764): Remove unique_ptr wrapping.
- std::unique_ptr<Literal> literal_;
+ absl::optional<Literal> literal_;
};
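With literal_ now an absl::optional<Literal>, the "literal dropped" state is
absl::nullopt rather than a null pointer; a sketch of the shape-only path:

    // Constructed without a Literal payload, e.g. when the literal was too
    // large and dropped; only the shape is retained.
    HloConstantInstruction big(ShapeUtil::MakeShape(F32, {1024}));
    CHECK(!big.HasLiteral());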
class HloTraceInstruction : public HloInstruction {
public:
explicit HloTraceInstruction(const string& tag, HloInstruction* operand);
// Returns a tag to be used in tracing.
- string TracingTag() const { return literal_->GetR1U8AsString(); }
+ string TracingTag() const { return literal_.GetR1U8AsString(); }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -632,11 +628,9 @@ class HloTraceInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
- // TODO(b/36360764): Remove unique_ptr wrapping.
- std::unique_ptr<Literal> literal_;
+ Literal literal_;
};
class HloFusionInstruction : public HloInstruction {
@@ -644,10 +638,9 @@ class HloFusionInstruction : public HloInstruction {
explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
HloInstruction* fused_root);
- explicit HloFusionInstruction(
- const Shape& shape, FusionKind fusion_kind,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* fusion_computation);
+ explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
+ absl::Span<HloInstruction* const> operands,
+ HloComputation* fusion_computation);
string ToCategory() const override;
// Returns a serialized representation of this instruction.
@@ -751,7 +744,7 @@ class HloFusionInstruction : public HloInstruction {
bool add_output = false);
bool IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const override;
+ const absl::optional<int64>& operand_idx) const override;
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
bool IdenticalSlowPath(
@@ -760,8 +753,7 @@ class HloFusionInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// The type of the fusion. Used by kFusion only.
@@ -770,9 +762,9 @@ class HloFusionInstruction : public HloInstruction {
class HloRngInstruction : public HloInstruction {
public:
- explicit HloRngInstruction(
- const Shape& shape, RandomDistribution distribution,
- tensorflow::gtl::ArraySlice<HloInstruction*> parameters);
+ explicit HloRngInstruction(const Shape& shape,
+ RandomDistribution distribution,
+ absl::Span<HloInstruction* const> parameters);
// Returns the random distribution for this rng node.
RandomDistribution random_distribution() const { return distribution_; }
// Returns a serialized representation of this instruction.
@@ -780,7 +772,7 @@ class HloRngInstruction : public HloInstruction {
private:
bool IsElementwiseImpl(
- const tensorflow::gtl::optional<int64>& operand_idx) const override;
+ const absl::optional<int64>& operand_idx) const override;
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
bool IdenticalSlowPath(
@@ -789,8 +781,7 @@ class HloRngInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// The distribution requested for random number generation.
@@ -815,8 +806,7 @@ class HloParameterInstruction : public HloInstruction {
CanonicalNameMap* canonical_name_map) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
int64 parameter_number_ = 0;
@@ -840,8 +830,7 @@ class HloGetTupleElementInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
int64 tuple_index_ = -1;
@@ -869,8 +858,7 @@ class HloReducePrecisionInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// The bit sizes for a reduce-precision operation.
@@ -907,8 +895,7 @@ class HloInfeedInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// The string representation of the infeed configuration.
@@ -920,7 +907,7 @@ class HloOutfeedInstruction : public HloInstruction {
explicit HloOutfeedInstruction(const Shape& outfeed_shape,
HloInstruction* operand,
HloInstruction* token_operand,
- tensorflow::StringPiece outfeed_config);
+ absl::string_view outfeed_config);
// Returns the shape for the Outfeed instruction.
const Shape& outfeed_shape() const {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_));
@@ -940,8 +927,7 @@ class HloOutfeedInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// Shape of outfeed request.
@@ -954,9 +940,9 @@ class HloConvolutionInstruction : public HloInstruction {
public:
explicit HloConvolutionInstruction(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const Window& window,
+ int64 feature_group_count, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count);
+ const PrecisionConfig& precision_config);
const Window& window() const override { return window_; }
void set_window(const Window& window) override { window_ = window; }
const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
@@ -969,6 +955,16 @@ class HloConvolutionInstruction : public HloInstruction {
// The number of feature groups. Must be a divisor of the input feature
// dimension and output feature dimension.
int64 feature_group_count() const { return feature_group_count_; }
+
+  // Returns the information used to tell the implementation what sort of
+  // precision is requested. The meaning of the field is backend specific. At
+  // the moment, it is only supported for kConvolution and kDot.
+  // Transformations from one kDot or kConvolution to another preserve this
+  // information; transformations to other HLOs do not, but it is presumed
+  // that the alternate lowering is strictly superior.
+ const PrecisionConfig& precision_config() const { return precision_config_; }
+
string ToCategory() const override;
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -982,15 +978,18 @@ class HloConvolutionInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
- Window window_;
- // Describes the dimension numbers used for a convolution.
- ConvolutionDimensionNumbers convolution_dimension_numbers_;
// The number of feature groups. Must be a divisor of the input feature
// dimension and output feature dimension.
int64 feature_group_count_;
+ // Describes the window used for a convolution.
+ Window window_;
+ // Describes the dimension numbers used for a convolution.
+ ConvolutionDimensionNumbers convolution_dimension_numbers_;
+ // Information used to communicate to the implementation about the algorithm
+ // used to produce results. See the documentation on precision_config().
+ PrecisionConfig precision_config_;
};
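A sketch of how a backend might consult the new field (accessor names follow
the generated proto API for a repeated enum field; the lowering decision is
illustrative):

    const PrecisionConfig& cfg = conv->precision_config();
    for (int i = 0; i < cfg.operand_precision_size(); ++i) {
      if (cfg.operand_precision(i) == PrecisionConfig::HIGHEST) {
        // Choose a slower but more precise kernel for operand i.
      }
    }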
class HloReduceWindowInstruction : public HloInstruction {
@@ -1014,8 +1013,7 @@ class HloReduceWindowInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
Window window_;
};
@@ -1063,24 +1061,23 @@ class HloSelectAndScatterInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
Window window_;
};
class HloCustomCallInstruction : public HloInstruction {
public:
- explicit HloCustomCallInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece custom_call_target);
+ explicit HloCustomCallInstruction(const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ absl::string_view custom_call_target);
const Window& window() const override {
CHECK(window_ != nullptr);
return *window_;
}
void set_window(const Window& window) override {
- window_ = MakeUnique<Window>(window);
+ window_ = absl::make_unique<Window>(window);
}
const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
@@ -1091,9 +1088,13 @@ class HloCustomCallInstruction : public HloInstruction {
void set_convolution_dimension_numbers(
const ConvolutionDimensionNumbers& dnums) {
convolution_dimension_numbers_ =
- MakeUnique<ConvolutionDimensionNumbers>(dnums);
+ absl::make_unique<ConvolutionDimensionNumbers>(dnums);
}
const string& custom_call_target() const { return custom_call_target_; }
+ void set_feature_group_count(int64 feature_group_count) {
+ feature_group_count_ = feature_group_count;
+ }
+ int64 feature_group_count() const { return feature_group_count_; }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -1106,8 +1107,7 @@ class HloCustomCallInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// Name of a global symbol to call, only present for kCustomCall.
string custom_call_target_;
@@ -1115,33 +1115,8 @@ class HloCustomCallInstruction : public HloInstruction {
std::unique_ptr<Window> window_;
// Describes the dimension numbers used for a convolution.
std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
-};
-
-class HloHostComputeInstruction : public HloInstruction {
- public:
- explicit HloHostComputeInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece channel_name, const int64 cost_estimate_ns);
- // Returns the channel name associated with the instruction. The name is
- // used to identify host Send/Recv operations.
- const string& channel_name() const { return channel_name_; }
- // Returns a serialized representation of this instruction.
- HloInstructionProto ToProto() const override;
-
- private:
- bool IdenticalSlowPath(
- const HloInstruction& other,
- const std::function<bool(const HloComputation*, const HloComputation*)>&
- eq_computations) const override;
- // Implementation for non-common logic of CloneWithNewOperands.
- std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
- HloCloneContext* context) const override;
- // Name to use for host send/recv channels.
- string channel_name_;
- // Estimate of the duration of a host computation in nanoseconds.
- int64 cost_estimate_ns_ = 0;
+ // The number of feature groups. This is used for grouped convolutions.
+ int64 feature_group_count_;
};
class HloPadInstruction : public HloInstruction {
@@ -1163,8 +1138,7 @@ class HloPadInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// The padding configuration that describes the edge padding and interior
@@ -1174,10 +1148,10 @@ class HloPadInstruction : public HloInstruction {
class HloDynamicSliceInstruction : public HloInstruction {
public:
- explicit HloDynamicSliceInstruction(
- const Shape& shape, HloInstruction* operand,
- HloInstruction* start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ explicit HloDynamicSliceInstruction(const Shape& shape,
+ HloInstruction* operand,
+ HloInstruction* start_indices,
+ absl::Span<const int64> slice_sizes);
// Old methods kept for smooth subclassing transition END.
// Returns the size of the slice in the given dimension for a dynamic
// slice node.
@@ -1199,8 +1173,7 @@ class HloDynamicSliceInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// Describes the [start, start + size) range size for a dynamic slice
@@ -1212,15 +1185,15 @@ class HloGatherInstruction : public HloInstruction {
public:
explicit HloGatherInstruction(
const Shape& shape, HloInstruction* operand,
- HloInstruction* gather_indices,
+ HloInstruction* start_indices,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ absl::Span<const int64> slice_sizes);
const GatherDimensionNumbers& gather_dimension_numbers() const {
CHECK(gather_dimension_numbers_ != nullptr);
return *gather_dimension_numbers_;
}
- tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const {
- return gather_window_bounds_;
+ absl::Span<const int64> gather_slice_sizes() const {
+ return gather_slice_sizes_;
}
// Returns the dump string of the gather dimension numbers.
string GatherDimensionNumbersToString() const;
@@ -1229,10 +1202,9 @@ class HloGatherInstruction : public HloInstruction {
// Creates an instance of GatherDimensionNumbers.
static GatherDimensionNumbers MakeGatherDimNumbers(
- tensorflow::gtl::ArraySlice<int64> output_window_dims,
- tensorflow::gtl::ArraySlice<int64> elided_window_dims,
- tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
- int64 index_vector_dim);
+ absl::Span<const int64> offset_dims,
+ absl::Span<const int64> collapsed_slice_dims,
+ absl::Span<const int64> start_index_map, int64 index_vector_dim);
private:
std::vector<string> ExtraAttributesToStringImpl(
@@ -1242,12 +1214,11 @@ class HloGatherInstruction : public HloInstruction {
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const override;
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
- std::vector<int64> gather_window_bounds_;
+ std::vector<int64> gather_slice_sizes_;
};
class HloScatterInstruction : public HloInstruction {
@@ -1268,9 +1239,9 @@ class HloScatterInstruction : public HloInstruction {
// Creates an instance of ScatterDimensionNumbers.
static ScatterDimensionNumbers MakeScatterDimNumbers(
- tensorflow::gtl::ArraySlice<int64> update_window_dims,
- tensorflow::gtl::ArraySlice<int64> inserted_window_dims,
- tensorflow::gtl::ArraySlice<int64> scatter_dims_to_operand_dims,
+ absl::Span<const int64> update_window_dims,
+ absl::Span<const int64> inserted_window_dims,
+ absl::Span<const int64> scatter_dims_to_operand_dims,
int64 index_vector_dim);
private:
@@ -1282,13 +1253,114 @@ class HloScatterInstruction : public HloInstruction {
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_;
};
+class HloIotaInstruction : public HloInstruction {
+ public:
+ explicit HloIotaInstruction(const Shape& shape, int64 iota_dimension);
+  // Returns the dimension along which this iota runs.
+ int64 iota_dimension() const { return iota_dimension_; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+ HloCloneContext* context) const override;
+
+ const int64 iota_dimension_;
+};
+
+class HloDotInstruction : public HloInstruction {
+ public:
+ // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
+ // dimensions specified in 'dimension_numbers'.
+ explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs,
+ HloInstruction* rhs,
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfig& precision_config);
+
+ // Returns data on the dimension numbers used for a dot operation.
+ const DotDimensionNumbers& dot_dimension_numbers() const {
+ return dot_dimension_numbers_;
+ }
+
+  // Returns the information used to tell the implementation what sort of
+  // precision is requested. The meaning of the field is backend specific. At
+  // the moment, it is only supported for kConvolution and kDot.
+  // Transformations from one kDot or kConvolution to another preserve this
+  // information; transformations to other HLOs do not, but it is presumed
+  // that the alternate lowering is strictly superior.
+ const PrecisionConfig& precision_config() const { return precision_config_; }
+
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+ HloCloneContext* context) const override;
+ // Returns the dump string of the dot dimension numbers.
+ string DotDimensionNumbersToString() const;
+
+ // Describes the dimension numbers used for a dot.
+ DotDimensionNumbers dot_dimension_numbers_;
+
+ // Information used to communicate to the implementation about the algorithm
+ // used to produce results. See the documentation on precision_config().
+ PrecisionConfig precision_config_;
+};
+
+class HloDomainInstruction : public HloInstruction {
+ public:
+ explicit HloDomainInstruction(
+ const Shape& shape, HloInstruction* operand,
+ std::unique_ptr<DomainMetadata> operand_side_metadata,
+ std::unique_ptr<DomainMetadata> user_side_metadata);
+
+ // Retrieves the operand side metadata of a kDomain instruction.
+ const DomainMetadata& operand_side_metadata() const {
+ return *operand_side_metadata_;
+ }
+ // Retrieves the user side metadata of a kDomain instruction.
+ const DomainMetadata& user_side_metadata() const {
+ return *user_side_metadata_;
+ }
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+ HloCloneContext* context) const override;
+
+ std::unique_ptr<DomainMetadata> operand_side_metadata_;
+ std::unique_ptr<DomainMetadata> user_side_metadata_;
+};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc
index 8e0d38b6a6..d9be841dd7 100644
--- a/tensorflow/compiler/xla/service/hlo_lexer.cc
+++ b/tensorflow/compiler/xla/service/hlo_lexer.cc
@@ -17,20 +17,20 @@ limitations under the License.
#include <unordered_map>
+#include "absl/strings/escaping.h"
+#include "absl/strings/numbers.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/regexp.h"
namespace xla {
-
-using ::tensorflow::StringPiece;
-
namespace {
+using absl::string_view;
+
constexpr int kEOF = -1;
constexpr int kError = -2;
@@ -66,12 +66,12 @@ bool HloLexer::CanDereference(const char* ptr) const {
return ptr < buf_.end() && ptr >= buf_.begin();
}
-tensorflow::StringPiece HloLexer::StringPieceFromPointers(
- const char* begin, const char* end) const {
+absl::string_view HloLexer::StringPieceFromPointers(const char* begin,
+ const char* end) const {
CHECK(begin <= end);
CHECK(begin == buf_.end() || CanDereference(begin));
CHECK(end == buf_.end() || CanDereference(end));
- return tensorflow::StringPiece(begin, end - begin);
+ return absl::string_view(begin, end - begin);
}
tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers(
@@ -235,7 +235,7 @@ TokKind HloLexer::LexIdentifier() {
return TokKind::kAttributeName;
}
- tensorflow::StringPiece identifier =
+ absl::string_view identifier =
StringPieceFromPointers(token_start_, current_ptr_);
// See if this is a keyword.
@@ -269,7 +269,7 @@ TokKind HloLexer::LexIdentifier() {
}
}
- str_val_ = std::string(identifier);
+ str_val_ = string(identifier);
return TokKind::kIdent;
}
@@ -306,8 +306,7 @@ TokKind HloLexer::LexNumberOrPattern() {
R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"};
if (RE2::Consume(&consumable, *float_pattern)) {
current_ptr_ = consumable.begin();
- tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(),
- &decimal_val_);
+ CHECK(absl::SimpleAtod(string(token_start_, current_ptr_), &decimal_val_));
return TokKind::kDecimal;
}
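Both absl parsers report malformed input through their return value, which is
what the CHECK above and the branch in the integer case below rely on. A
standalone sketch:

    double d;
    CHECK(absl::SimpleAtod("3.25e2", &d));  // d == 325.0
    int64 i;
    if (!absl::SimpleAtoi("not-a-number", &i)) {
      // Failure is signaled by the return value, not by errno.
    }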
@@ -339,7 +338,7 @@ TokKind HloLexer::LexNumberOrPattern() {
if (RE2::Consume(&consumable, *int_pattern)) {
current_ptr_ = consumable.begin();
auto slice = StringPieceFromPointers(token_start_, current_ptr_);
- if (tensorflow::strings::safe_strto64(slice, &int64_val_)) {
+ if (absl::SimpleAtoi(slice, &int64_val_)) {
return TokKind::kInt;
}
LOG(ERROR) << "Failed to parse int literal: " << slice;
@@ -365,6 +364,7 @@ std::pair<unsigned, unsigned> HloLexer::GetLineAndColumn(LocTy location) const {
line_no = line_no_cache_.line_no_of_query;
}
for (; ptr != location; ptr++) {
+ CHECK_LT(ptr, buf_.end());
if (*ptr == '\n') {
line_no++;
}
@@ -374,24 +374,24 @@ std::pair<unsigned, unsigned> HloLexer::GetLineAndColumn(LocTy location) const {
line_no_cache_.last_query = ptr;
line_no_cache_.line_no_of_query = line_no;
size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n');
- if (line_offset == tensorflow::StringPiece::npos) {
+ if (line_offset == absl::string_view::npos) {
line_offset = 0;
}
return {line_no, ptr - start - line_offset};
}
-tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const {
+absl::string_view HloLexer::GetLine(LocTy loc) const {
if (!CanDereference(loc)) {
return "LINE OUT OF RANGE";
}
size_t line_start =
StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n');
- const char* start = line_start == tensorflow::StringPiece::npos
+ const char* start = line_start == absl::string_view::npos
? buf_.begin()
: buf_.begin() + line_start + 1;
size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n');
const char* end =
- line_end == tensorflow::StringPiece::npos ? buf_.end() : loc + line_end;
+ line_end == absl::string_view::npos ? buf_.end() : loc + line_end;
return StringPieceFromPointers(start, end);
}
@@ -403,10 +403,10 @@ TokKind HloLexer::LexString() {
static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"};
if (RE2::Consume(&consumable, *escaping_pattern)) {
current_ptr_ = consumable.begin();
- tensorflow::StringPiece raw =
+ absl::string_view raw =
StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1);
string error;
- if (!tensorflow::str_util::CUnescape(raw, &str_val_, &error)) {
+ if (!absl::CUnescape(raw, &str_val_, &error)) {
LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error;
return TokKind::kError;
}
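
The numeric-lexing hunks above replace tensorflow::strings::safe_strtod/safe_strto64 with absl::SimpleAtod/absl::SimpleAtoi, both of which report success via their return value instead of silently truncating, which is what lets the float case be wrapped in a CHECK once the regex has already vetted the text. A minimal standalone sketch of those semantics (not part of the patch; assumes only absl/strings/numbers.h):

#include <cstdint>
#include <iostream>
#include "absl/strings/numbers.h"

int main() {
  double d = 0.0;
  int64_t i = 0;
  // Both parsers report success via their return value.
  bool ok_d = absl::SimpleAtod("3.5e2", &d);  // true, d == 350.0
  bool ok_i = absl::SimpleAtoi("12345", &i);  // true, i == 12345
  bool bad = absl::SimpleAtoi("12x", &i);     // false: trailing junk rejected
  std::cout << ok_d << " " << d << " " << ok_i << " " << i << " " << bad << "\n";
  return 0;
}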
diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h
index 003ac34ace..3e2f8bcd52 100644
--- a/tensorflow/compiler/xla/service/hlo_lexer.h
+++ b/tensorflow/compiler/xla/service/hlo_lexer.h
@@ -18,10 +18,10 @@ limitations under the License.
#include <string>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_token.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/types.h"
@@ -34,7 +34,7 @@ namespace xla {
// it directly.
class HloLexer {
public:
- explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) {
+ explicit HloLexer(absl::string_view buf) : buf_(buf) {
current_ptr_ = buf_.begin();
}
@@ -77,7 +77,7 @@ class HloLexer {
std::pair<unsigned, unsigned> GetLineAndColumn(LocTy location) const;
// Returns the whole line given the location.
- tensorflow::StringPiece GetLine(LocTy loc) const;
+ absl::string_view GetLine(LocTy loc) const;
private:
// Returns the current character. If it's neither the end of input buffer nor
@@ -89,8 +89,8 @@ class HloLexer {
  // Creates a StringPiece with the given begin and end. Exits if begin > end,
// or it's out of the range of the current buffer.
- tensorflow::StringPiece StringPieceFromPointers(const char* begin,
- const char* end) const;
+ absl::string_view StringPieceFromPointers(const char* begin,
+ const char* end) const;
tensorflow::RegexpStringPiece RegexpStringPieceFromPointers(
const char* begin, const char* end) const;
@@ -107,11 +107,11 @@ class HloLexer {
TokKind LexNumberOrPattern();
TokKind LexString();
- const tensorflow::StringPiece buf_;
+ const absl::string_view buf_;
const char* current_ptr_;
// Information about the current token.
- const char* token_start_;
+ const char* token_start_ = nullptr;
TokKind current_kind_;
string str_val_;
Shape shape_val_;
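
StringPieceFromPointers keeps its name but now returns an absl::string_view built straight from a (pointer, length) pair. A hedged standalone sketch of that pointer-range pattern (the names here are illustrative, not from the patch):

#include <iostream>
#include "absl/strings/string_view.h"

// Builds a non-owning view over [begin, end), mirroring the lexer's helper.
// The caller must keep the underlying buffer alive while the view is in use.
absl::string_view ViewFromPointers(const char* begin, const char* end) {
  return absl::string_view(begin, end - begin);
}

int main() {
  const char buf[] = "HloModule example";
  absl::string_view word = ViewFromPointers(buf, buf + 9);
  std::cout << word << "\n";  // prints "HloModule"
  return 0;
}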
diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
index 43c41ece6e..3a1dd471c6 100644
--- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
@@ -17,8 +17,9 @@ limitations under the License.
#include <deque>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -29,17 +30,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
+namespace {
using Worklist = std::deque<const HloInstruction*>;
using Workset = std::unordered_set<const HloInstruction*>;
-namespace {
-
void AddToWorklist(const HloInstruction* instruction, Worklist* worklist,
Workset* workset) {
if (workset->count(instruction) == 0) {
@@ -296,7 +294,7 @@ StatusOr<std::unique_ptr<HloLivenessAnalysis>> HloLivenessAnalysis::Run(
VLOG(1) << "HloLivenessAnalysis::Run on module " << module.name();
XLA_VLOG_LINES(2, module.ToString());
- auto liveness_analysis = WrapUnique(new HloLivenessAnalysis(module));
+ auto liveness_analysis = absl::WrapUnique(new HloLivenessAnalysis(module));
liveness_analysis->RunAnalysis();
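
The WrapUnique to absl::WrapUnique change is a pure namespace move. WrapUnique exists for exactly the situation in this file: adopting a raw `new` of a type whose constructor is private, where absl::make_unique cannot be used. A minimal illustration under that assumption (hypothetical class, not XLA code):

#include <iostream>
#include <memory>
#include "absl/memory/memory.h"

class Analysis {
 public:
  static std::unique_ptr<Analysis> Run() {
    // absl::make_unique<Analysis>() would not compile: the constructor is
    // private here. WrapUnique only adopts an already-constructed pointer.
    return absl::WrapUnique(new Analysis());
  }
  int value() const { return 42; }

 private:
  Analysis() = default;
};

int main() {
  std::cout << Analysis::Run()->value() << "\n";
  return 0;
}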
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc
index 7e4b883435..5269cad94d 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.cc
+++ b/tensorflow/compiler/xla/service/hlo_matchers.cc
@@ -15,15 +15,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace testing {
-using ::tensorflow::str_util::Join;
-
bool HloMatcher::MatchAndExplain(
const HloInstruction* instruction,
::testing::MatchResultListener* listener) const {
@@ -210,8 +208,8 @@ bool HloDotWithContractingDimsMatcher::MatchAndExplain(
dim_nums.lhs_contracting_dimensions(0) != lhs_contracting_dim_) {
*listener << instruction->ToString()
<< " has wrong lhs_contracting_dimensions (got {"
- << Join(dim_nums.lhs_contracting_dimensions(), ",") << "} want {"
- << lhs_contracting_dim_ << "})";
+ << absl::StrJoin(dim_nums.lhs_contracting_dimensions(), ",")
+ << "} want {" << lhs_contracting_dim_ << "})";
return false;
}
@@ -219,8 +217,8 @@ bool HloDotWithContractingDimsMatcher::MatchAndExplain(
dim_nums.rhs_contracting_dimensions(0) != rhs_contracting_dim_) {
*listener << instruction->ToString()
<< " has wrong rhs_contracting_dimensions (got {"
- << Join(dim_nums.rhs_contracting_dimensions(), ",") << "} want {"
- << rhs_contracting_dim_ << "})";
+ << absl::StrJoin(dim_nums.rhs_contracting_dimensions(), ",")
+ << "} want {" << rhs_contracting_dim_ << "})";
return false;
}
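
tensorflow::str_util::Join and absl::StrJoin are drop-in equivalents for the container-plus-separator form used by these matchers. A standalone sketch reproducing the "{0,2,3}"-style dimension dump:

#include <iostream>
#include <vector>
#include "absl/strings/str_join.h"

int main() {
  std::vector<int> dims = {0, 2, 3};
  // StrJoin formats each element and interleaves the separator.
  std::cout << "{" << absl::StrJoin(dims, ",") << "}\n";  // {0,2,3}
  return 0;
}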
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index c577b4359a..5502e565b6 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -16,10 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/core/lib/gtl/optional.h"
namespace xla {
namespace testing {
@@ -120,8 +120,7 @@ class HloShapeAndLayoutMatcher
class HloShardingMatcher
: public ::testing::MatcherInterface<const HloInstruction*> {
public:
- explicit HloShardingMatcher(
- const tensorflow::gtl::optional<HloSharding>& sharding)
+ explicit HloShardingMatcher(const absl::optional<HloSharding>& sharding)
: sharding_(sharding) {}
bool MatchAndExplain(const HloInstruction* instruction,
@@ -129,7 +128,7 @@ class HloShardingMatcher
void DescribeTo(std::ostream* os) const override;
private:
- tensorflow::gtl::optional<HloSharding> sharding_;
+ absl::optional<HloSharding> sharding_;
};
// Matches a Dot HLO instruction with specific LHS and RHS contracting
@@ -189,6 +188,7 @@ HLO_MATCHER(Fusion);
HLO_MATCHER(Ge);
HLO_MATCHER(AfterAll);
HLO_MATCHER(Gt);
+HLO_MATCHER(Iota);
HLO_MATCHER(Infeed);
HLO_MATCHER(IsFinite);
HLO_MATCHER(Le);
@@ -307,7 +307,7 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> Shape(
return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(shape));
}
inline ::testing::Matcher<const ::xla::HloInstruction*> Shape(
- tensorflow::StringPiece shape) {
+ absl::string_view shape) {
return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(
ShapeUtil::ParseShapeString(shape).ValueOrDie()));
}
@@ -317,7 +317,7 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
new ::xla::testing::HloShapeAndLayoutMatcher(shape));
}
inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
- tensorflow::StringPiece shape) {
+ absl::string_view shape) {
return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher(
ShapeUtil::ParseShapeString(shape).ValueOrDie()));
}
@@ -330,14 +330,14 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
}
// Matcher for Sharding from sharding string
inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
- tensorflow::StringPiece sharding) {
+ absl::string_view sharding) {
return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher(
ParseSharding(sharding).ValueOrDie()));
}
// Verifies that no HloSharding is set for an HLO instruction.
inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() {
return ::testing::MakeMatcher(
- new ::xla::testing::HloShardingMatcher(tensorflow::gtl::nullopt));
+ new ::xla::testing::HloShardingMatcher(absl::nullopt));
}
inline ::testing::Matcher<const ::xla::HloInstruction*> Dot(
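
HloShardingMatcher stores an absl::optional<HloSharding>, with absl::nullopt meaning "expect no sharding at all" (the NoSharding() factory above). A small standalone sketch of that optional-as-tristate idiom, with std::string standing in for HloSharding to stay self-contained:

#include <iostream>
#include <string>
#include "absl/types/optional.h"

// nullopt encodes "no sharding expected"; a value encodes "expect exactly
// this sharding" -- the same tristate the matcher relies on.
bool Matches(const absl::optional<std::string>& expected,
             const absl::optional<std::string>& actual) {
  return expected == actual;
}

int main() {
  std::cout << Matches(absl::nullopt, absl::nullopt) << "\n";  // 1
  std::cout << Matches(std::string("{maximal device=0}"), absl::nullopt)
            << "\n";  // 0
  return 0;
}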
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
index 27cc5361cd..c7ec88d450 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.cc
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
@@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include <map>
+#include <queue>
#include <utility>
#include <vector>
@@ -28,16 +29,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
-using ::tensorflow::strings::HumanReadableNumBytes;
-
namespace xla {
-
namespace {
+using ::tensorflow::strings::HumanReadableNumBytes;
+
// Class implementing a list scheduler of HLO instructions. It produces a
// sequence that minimizes memory usage by preferring to schedule the node that
// frees the largest buffers and defines the smallest outputs.
@@ -71,7 +70,7 @@ class ListScheduler {
public:
// Construct and return a memory-minimizing sequence of HLO instructions
// containing the given HLO computation.
- static StatusOr<std::vector<const HloInstruction*>> Run(
+ static StatusOr<HloInstructionSequence> Run(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -230,8 +229,8 @@ class ListScheduler {
return {BytesFreedIfScheduled(entry), entry.instruction->user_count()};
}
- std::vector<const HloInstruction*> CreateSchedule() {
- std::vector<const HloInstruction*> schedule;
+ HloInstructionSequence CreateSchedule() {
+ HloInstructionSequence schedule;
// Populate the ready list with instructions which have no operands or
// control predecessors.
@@ -375,7 +374,7 @@ int64 SumLogicalBufferSizes(
return size;
}
-StatusOr<std::vector<const HloInstruction*>> ScheduleComputationHelper(
+StatusOr<HloInstructionSequence> ScheduleComputationHelper(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -393,7 +392,7 @@ StatusOr<std::vector<const HloInstruction*>> ScheduleComputationHelper(
} // namespace
-StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
+StatusOr<HloInstructionSequence> DFSMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -444,7 +443,7 @@ StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
// Construct a total order based on DFS post-order, visiting operands in
// decreasing cumulative extra user order, then by cumulative size, with a
// tiebreaker by name for determinism.
- std::vector<const HloInstruction*> sequence;
+ HloInstructionSequence sequence;
FunctionVisitor visitor([&sequence](HloInstruction* hlo) {
sequence.push_back(hlo);
return Status::OK();
@@ -464,7 +463,7 @@ StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
return sequence;
} // namespace xla
-StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
+StatusOr<HloInstructionSequence> ListMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -474,18 +473,16 @@ StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
memory_by_computation);
}
-StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler(
+StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
memory_by_computation) {
- const auto& post_order = computation.MakeInstructionPostOrder();
- return std::vector<const HloInstruction*>{post_order.begin(),
- post_order.end()};
+ return HloInstructionSequence(computation.MakeInstructionPostOrder());
}
-StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
+StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -500,7 +497,7 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
// List wins for most of our benchmarks; postorder-based schedulers win for
// some RNNs.
TF_ASSIGN_OR_RETURN(
- std::vector<const HloInstruction*> list_sequence,
+ HloInstructionSequence list_sequence,
ListMemoryScheduler(computation, points_to_analysis, size_function,
memory_by_computation));
TF_ASSIGN_OR_RETURN(const int64 list_memory,
@@ -509,7 +506,7 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
size_function, &memory_by_computation));
VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
- TF_ASSIGN_OR_RETURN(std::vector<const HloInstruction*> dfs_sequence,
+ TF_ASSIGN_OR_RETURN(HloInstructionSequence dfs_sequence,
DFSMemoryScheduler(computation, points_to_analysis,
size_function, memory_by_computation));
TF_ASSIGN_OR_RETURN(const int64 dfs_memory,
@@ -519,7 +516,7 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
TF_ASSIGN_OR_RETURN(
- std::vector<const HloInstruction*> post_order_sequence,
+ HloInstructionSequence post_order_sequence,
PostOrderMemoryScheduler(computation, points_to_analysis, size_function,
memory_by_computation));
TF_ASSIGN_OR_RETURN(const int64 post_order_memory,
@@ -546,32 +543,35 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
}
}
-StatusOr<SequentialHloOrdering::HloModuleSequence> ScheduleComputationsInModule(
+StatusOr<HloSchedule> ScheduleModule(
const HloModule& module, const LogicalBuffer::SizeFunction& size_function,
const MemorySchedulerAlgorithm& algorithm) {
- SequentialHloOrdering::HloModuleSequence sequence;
+ HloSchedule schedule(&module);
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(&module));
tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
for (const auto* computation : module.MakeComputationPostOrder()) {
if (!computation->IsFusionComputation()) {
- TF_ASSIGN_OR_RETURN(auto one_computation_sequence,
+ TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence,
ScheduleComputationHelper(
*computation, *points_to_analysis, size_function,
algorithm, memory_by_computation));
memory_by_computation[computation] =
HeapSimulator::MinimumMemoryForComputation(
- *computation, one_computation_sequence, *points_to_analysis,
+ *computation, computation_sequence, *points_to_analysis,
size_function, &memory_by_computation)
.ValueOrDie();
- sequence[computation] = std::move(one_computation_sequence);
+ schedule.set_sequence(computation, std::move(computation_sequence));
}
}
- VLOG(1) << "Module schedule:\n" << sequence;
- return sequence;
+ VLOG(1) << "Module schedule:\n" << schedule;
+
+ TF_RETURN_IF_ERROR(schedule.Verify());
+
+ return std::move(schedule);
}
-StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation(
+StatusOr<HloInstructionSequence> ScheduleComputation(
const HloComputation& computation,
const LogicalBuffer::SizeFunction& size_function) {
CHECK(!computation.IsFusionComputation());
@@ -582,4 +582,22 @@ StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation(
size_function, nullptr, empty_map);
}
+HloMemoryScheduler::HloMemoryScheduler(
+ const LogicalBuffer::SizeFunction& size_function,
+ const MemorySchedulerAlgorithm& algorithm)
+ : size_function_(size_function), algorithm_(algorithm) {}
+
+StatusOr<bool> HloMemoryScheduler::Run(HloModule* module) {
+ TF_ASSIGN_OR_RETURN(HloSchedule schedule,
+ ScheduleModule(*module, size_function_, algorithm_));
+ TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
+ return true;
+}
+
+StatusOr<bool> HloDescheduler::Run(HloModule* module) {
+ bool changed = module->has_schedule();
+ module->clear_schedule();
+ return changed;
+}
+
} // namespace xla
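
The greedy rule the ListScheduler comment describes is visible in GetPriority above: ready candidates are ranked by the pair (bytes freed if scheduled, user count) and the maximum wins. A toy standalone sketch of that ranking with made-up byte counts (no XLA types involved):

#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct Candidate {
  std::string name;
  long long bytes_freed;  // memory released if scheduled now
  int user_count;         // secondary key of the priority pair
};

int main() {
  std::vector<Candidate> ready = {
      {"transpose", 0, 1}, {"free_p1", 4096, 1}, {"small_out", 4096, 3}};
  // Lexicographic max over (bytes_freed, user_count), as in GetPriority.
  auto best = std::max_element(
      ready.begin(), ready.end(), [](const Candidate& a, const Candidate& b) {
        return std::make_pair(a.bytes_freed, a.user_count) <
               std::make_pair(b.bytes_freed, b.user_count);
      });
  std::cout << "schedule next: " << best->name << "\n";  // small_out
  return 0;
}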
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
index 2b33ccc8bf..5e02868eba 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.h
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
@@ -13,14 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -32,14 +34,14 @@ namespace xla {
// 'computation' that minimizes peak memory, given a points-to analysis result
// that describes buffer aliasing, together with a target-specific size function
// that maps a tensor's logical size to its padded size.
-typedef std::function<StatusOr<std::vector<const HloInstruction*>>(
+typedef std::function<StatusOr<HloInstructionSequence>(
const HloComputation&, const TuplePointsToAnalysis&,
const LogicalBuffer::SizeFunction&,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&)>
MemorySchedulerAlgorithm;
// List scheduler
-StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
+StatusOr<HloInstructionSequence> ListMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -47,7 +49,7 @@ StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
memory_by_computation);
// DFS-order scheduler
-StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
+StatusOr<HloInstructionSequence> DFSMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -55,7 +57,7 @@ StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
memory_by_computation);
// Naive Post Order scheduler
-StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler(
+StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -65,26 +67,57 @@ StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler(
// The default scheduling algorithm. Runs both the list scheduler
// and the DFS scheduler, and chooses whichever returns a lower min-memory,
// not accounting for fragmentation.
-StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
+StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
memory_by_computation);
-// Returns an HloModuleSequence which seeks to minimize the memory required for
+// Returns an HloSchedule which seeks to minimize the memory required for
// the computation. size_function is the function returning the number of bytes
// required for a LogicalBuffer.
-StatusOr<SequentialHloOrdering::HloModuleSequence> ScheduleComputationsInModule(
+StatusOr<HloSchedule> ScheduleModule(
const HloModule& module, const LogicalBuffer::SizeFunction& size_function,
const MemorySchedulerAlgorithm& algorithm = {});
// Computes the schedule for a single computation.
// Currently only used by the GPU backend.
-StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation(
+StatusOr<HloInstructionSequence> ScheduleComputation(
const HloComputation& computation,
const LogicalBuffer::SizeFunction& size_function);
+// A pass which schedules the HLO instructions in a module. The HloModule's
+// schedule field is set to the resulting HloSchedule using
+// HloModule::set_schedule.
+class HloMemoryScheduler : public HloPassInterface {
+ public:
+ // size_function is the function returning the number of bytes required for a
+ // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not
+ // specified, then DefaultMemoryScheduler is used.
+ HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function,
+ const MemorySchedulerAlgorithm& algorithm = {});
+ ~HloMemoryScheduler() override = default;
+ absl::string_view name() const override { return "hlo-memory-scheduler"; }
+
+ StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ LogicalBuffer::SizeFunction size_function_;
+ MemorySchedulerAlgorithm algorithm_;
+};
+
+// A trivial pass which clears the schedule currently set on the
+// HloModule. After this pass runs, HloModule::has_schedule will return false.
+class HloDescheduler : public HloPassInterface {
+ public:
+ HloDescheduler() = default;
+ ~HloDescheduler() override = default;
+ absl::string_view name() const override { return "hlo-descheduler"; }
+
+ StatusOr<bool> Run(HloModule* module) override;
+};
+
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
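
Taken together, the declarations above turn scheduling into an ordinary pass: construct HloMemoryScheduler with a size function, Run it, and the module carries its HloSchedule until an HloDescheduler clears it. A hedged sketch of that flow, compilable only inside an XLA build and mirroring the test usage (the wrapper function is hypothetical):

#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

Status ScheduleThenClear(HloModule* module) {
  // Size function: bytes required for each logical buffer.
  HloMemoryScheduler scheduler([](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  });
  TF_ASSIGN_OR_RETURN(bool changed, scheduler.Run(module));
  (void)changed;  // module->has_schedule() is now true.

  // Undo the side effect, e.g. before passes that would invalidate it.
  HloDescheduler descheduler;
  return descheduler.Run(module).status();
}

}  // namespace xla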
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc
index 9ec983c2bc..1b9e9bfc77 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc
@@ -13,13 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include <memory>
#include <string>
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
@@ -28,6 +30,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
@@ -64,21 +67,34 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape());
- }));
+ HloMemoryScheduler scheduler([](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ });
+ ASSERT_FALSE(module->has_schedule());
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, scheduler.Run(module.get()));
+ EXPECT_TRUE(changed);
+ ASSERT_TRUE(module->has_schedule());
+ TF_ASSERT_OK(module->schedule().Verify());
+
// Verify that all instructions are in the sequence.
- EXPECT_EQ(module->entry_computation()->instruction_count(),
- sequence.at(module->entry_computation()).size());
+ const std::vector<const HloInstruction*>& sequence =
+ module->schedule().sequence(module->entry_computation()).instructions();
+ EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());
// The first instruction should be the parameter and the last the root "sub".
- EXPECT_EQ(param, sequence.at(module->entry_computation()).front());
- EXPECT_EQ(sub, sequence.at(module->entry_computation()).back());
+ EXPECT_EQ(param, sequence.front());
+ EXPECT_EQ(sub, sequence.back());
- SequentialHloOrdering ordering(module.get(), sequence);
+ SequentialHloOrdering ordering(module->schedule());
EXPECT_TRUE(ordering.ExecutesBefore(add, negate));
+
+ // Clear the schedule using the descheduling pass.
+ HloDescheduler descheduler;
+ EXPECT_TRUE(module->has_schedule());
+ TF_ASSERT_OK_AND_ASSIGN(bool descheduler_changed,
+ descheduler.Run(module.get()));
+ EXPECT_TRUE(descheduler_changed);
+ EXPECT_FALSE(module->has_schedule());
}
TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) {
@@ -106,28 +122,26 @@ ENTRY root {
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
};
TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler));
+ HloSchedule schedule,
+ ScheduleModule(*module, size_fn, ListMemoryScheduler));
// Verify that all instructions are in the sequence.
- EXPECT_EQ(module->entry_computation()->instruction_count(),
- sequence.at(module->entry_computation()).size());
+ const std::vector<const HloInstruction*>& sequence =
+ schedule.sequence(module->entry_computation()).instructions();
+ EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());
std::unordered_map<string, const HloInstruction*> instructions_by_name;
- for (const HloInstruction* instruction :
- sequence.at(module->entry_computation())) {
+ for (const HloInstruction* instruction : sequence) {
instructions_by_name[instruction->name()] = instruction;
}
// The first instruction should be the parameter and the last the root.
- EXPECT_EQ(instructions_by_name.at("param"),
- sequence.at(module->entry_computation()).front());
- EXPECT_EQ(instructions_by_name.at("result"),
- sequence.at(module->entry_computation()).back());
+ EXPECT_EQ(instructions_by_name.at("param"), sequence.front());
+ EXPECT_EQ(instructions_by_name.at("result"), sequence.back());
// Instructions "d" and "e" will both be schedulable at the same time, but
// instruction "d" allows us to free the buffer of "p1", so the list scheduler
// should prefer it.
- SequentialHloOrdering ordering(module.get(), sequence);
+ SequentialHloOrdering ordering(schedule);
EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("d"),
instructions_by_name.at("e")));
}
@@ -218,13 +232,13 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
return ShapeUtil::ByteSizeOf(buffer.shape());
};
TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler));
+ HloSchedule schedule,
+ ScheduleModule(*module, size_fn, ListMemoryScheduler));
// Verify that all instructions are in the sequence.
auto entry_computation = module->entry_computation();
EXPECT_EQ(entry_computation->instruction_count(),
- sequence.at(entry_computation).size());
- SequentialHloOrdering ordering(module.get(), sequence);
+ schedule.sequence(entry_computation).size());
+ SequentialHloOrdering ordering(schedule);
// This schedule is an example of List's greedy heuristics being suboptimal.
// The while_loop is more expensive than transpose, so it would have been
// better to schedule it first, instead of during the busy time.
@@ -241,13 +255,13 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
// HeapSimulator doesn't account for subcomputations
EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation(
- *entry_computation, sequence.at(entry_computation),
+ *entry_computation, schedule.sequence(entry_computation),
*points_to_analysis, size_fn)
.ValueOrDie());
- // HeapSimulator accounts for subcomputations. The max mem doesn't change
- // because the while body isn't live during the peak.
- EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation(
- *entry_computation, sequence.at(entry_computation),
+ // HeapSimulator accounts for subcomputations. The output buffer is aliased,
+ // so we don't double count.
+ EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation(
+ *entry_computation, schedule.sequence(entry_computation),
*points_to_analysis, size_fn, &memory_by_computation)
.ValueOrDie());
}
@@ -267,7 +281,7 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
auto abs_abs1 = builder.AddInstruction(
HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const));
auto tuple = builder.AddInstruction(HloInstruction::CreateTuple(
- tensorflow::gtl::ArraySlice<HloInstruction*>({abs_abs1})));
+ absl::Span<HloInstruction* const>({abs_abs1})));
auto tuple_elm = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(r1f32, tuple, 0));
@@ -279,19 +293,18 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module,
- [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(
- buffer.shape(), TUPLE_SIZE);
- },
- ListMemoryScheduler));
+ TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
+ ScheduleModule(*module,
+ [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(
+ buffer.shape(), TUPLE_SIZE);
+ },
+ ListMemoryScheduler));
// Verify that all instructions are in the sequence.
EXPECT_EQ(module->entry_computation()->instruction_count(),
- sequence.at(module->entry_computation()).size());
- SequentialHloOrdering ordering(module.get(), sequence);
+ schedule.sequence(module->entry_computation()).size());
+ SequentialHloOrdering ordering(schedule);
// tuple allocates the tuple buffer and doesn't free anything.
// abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0.
// abs_abs2 should be scheduled before tuple by List.
@@ -330,18 +343,18 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
auto fusion = computation->CreateFusionInstruction(
{tuple, mul, add}, HloInstruction::FusionKind::kLoop);
- TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(
- *module,
- [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape(), 2);
- },
- ListMemoryScheduler));
+ TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
+ ScheduleModule(*module,
+ [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(
+ buffer.shape(), 2);
+ },
+ ListMemoryScheduler));
// Verify that all instructions are in the sequence.
EXPECT_EQ(module->entry_computation()->instruction_count(),
- sequence.at(module->entry_computation()).size());
- SequentialHloOrdering ordering(module.get(), sequence);
+ schedule.sequence(module->entry_computation()).size());
+ SequentialHloOrdering ordering(schedule);
// fusion allocates memory for the tuple elements and doesn't free anything,
// so it's more expensive than exp.
EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion));
@@ -350,7 +363,6 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
auto module = CreateNewModule();
const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
- const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});
// param != 0
// Needs 17 bytes
@@ -390,12 +402,12 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
return ShapeUtil::ByteSizeOf(buffer.shape());
};
TF_ASSERT_OK_AND_ASSIGN(
- SequentialHloOrdering::HloModuleSequence sequence,
- ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler));
+ HloSchedule schedule,
+ ScheduleModule(*module, size_fn, ListMemoryScheduler));
// Verify that all instructions are in the sequence.
auto entry_computation = module->entry_computation();
- EXPECT_EQ(entry_computation->instruction_count(),
- sequence.at(entry_computation).size());
+ EXPECT_EQ(module->entry_computation()->instruction_count(),
+ schedule.sequence(module->entry_computation()).size());
tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
memory_by_computation[cond_computation] = 17;
@@ -405,12 +417,13 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
// HeapSimulator doesn't account for subcomputations
EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation(
- *entry_computation, sequence.at(entry_computation),
+ *entry_computation, schedule.sequence(entry_computation),
*points_to_analysis, size_fn)
.ValueOrDie());
- // HeapSimulator accounts for subcomputations
- EXPECT_EQ(33, HeapSimulator::MinimumMemoryForComputation(
- *entry_computation, sequence.at(entry_computation),
+ // HeapSimulator accounts for subcomputations. Cond is the largest one.
+ // The output buffer of the while is aliased.
+ EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation(
+ *entry_computation, schedule.sequence(entry_computation),
*points_to_analysis, size_fn, &memory_by_computation)
.ValueOrDie());
}
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 55ff073d3f..cfe906d9c5 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -22,12 +22,14 @@ limitations under the License.
#include <unordered_set>
#include <utility>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -49,6 +51,13 @@ StatusOr<HloInstruction*> HloModule::LaunderConstInstructionFromModule(
return const_cast<HloInstruction*>(hlo);
}
+Status HloModule::set_schedule(HloSchedule schedule) {
+ TF_RET_CHECK(schedule.module() == this);
+ TF_RETURN_IF_ERROR(schedule.Verify());
+ schedule_ = std::move(schedule);
+ return Status::OK();
+}
+
HloComputation* HloModule::AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
bool uniquify_names) {
@@ -197,12 +206,23 @@ void HloModule::ReplaceComputations(
string HloModule::ToString(const HloPrintOptions& options) const {
std::ostringstream s;
- s << "HloModule " << name() << "\n\n";
+ s << "HloModule " << name();
+ if (has_schedule()) {
+ TF_CHECK_OK(schedule().Verify());
+ s << ", is_scheduled=true";
+ }
+ s << "\n\n";
for (const HloComputation* computation : MakeComputationPostOrder()) {
if (computation == entry_computation()) {
s << "ENTRY ";
}
- s << computation->ToString(options) << "\n\n";
+ if (has_schedule() && schedule().is_computation_scheduled(computation)) {
+ s << computation->ToString(
+ options, schedule().sequence(computation).instructions())
+ << "\n\n";
+ } else {
+ s << computation->ToString(options) << "\n\n";
+ }
}
return s.str();
}
@@ -220,6 +240,9 @@ HloModuleProto HloModule::ToProto() const {
}
proto.add_computations()->Swap(&computation_proto);
}
+ if (has_schedule()) {
+ *proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
+ }
return proto;
}
@@ -274,7 +297,7 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
}
TF_RET_CHECK(entry != nullptr);
- auto module = MakeUnique<HloModule>(proto.name(), module_config);
+ auto module = absl::make_unique<HloModule>(proto.name(), module_config);
// Sort the computations in the proto id's order.
std::sort(computations.begin(), computations.end(),
@@ -308,6 +331,13 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
}
}
+ if (proto.has_schedule()) {
+ TF_ASSIGN_OR_RETURN(
+ HloSchedule schedule,
+ HloSchedule::CreateFromProto(module.get(), proto.schedule()));
+ TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
+ }
+
return std::move(module);
}
@@ -352,7 +382,7 @@ bool IsUsedOutsideSubcomputation(
} // anonymous namespace
HloInstruction* HloModule::OutlineExpressionFromComputation(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_outline,
+ absl::Span<HloInstruction* const> instructions_to_outline,
const string& outlined_computation_name, HloComputation* computation) {
auto builder = HloComputation::Builder(outlined_computation_name);
@@ -409,7 +439,7 @@ HloInstruction* HloModule::OutlineExpressionFromComputation(
string error_message =
"The subcomputation to outline has multiple outputs:\n";
for (HloInstruction* output : outputs) {
- tensorflow::strings::StrAppend(&error_message, output->ToString(), "\n");
+ absl::StrAppend(&error_message, output->ToString(), "\n");
}
LOG(FATAL) << error_message;
}
@@ -507,7 +537,7 @@ std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const {
std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n";
- auto module = MakeUnique<HloModule>(name_ + "-" + suffix, config_);
+ auto module = absl::make_unique<HloModule>(name_ + "-" + suffix, config_);
HloCloneContext context(module.get(), suffix);
auto cloned_computation = entry_computation_->Clone(suffix, &context);
@@ -535,12 +565,11 @@ uint64 HloModule::RandomNew64() const {
return rng_();
}
-HloComputation* HloModule::GetComputationWithName(
- tensorflow::StringPiece name) {
+HloComputation* HloModule::GetComputationWithName(absl::string_view name) {
auto computations_in_module = computations();
- auto it = c_find_if(computations_in_module, [&](HloComputation* computation) {
- return computation->name() == name;
- });
+ auto it = absl::c_find_if(
+ computations_in_module,
+ [&](HloComputation* computation) { return computation->name() == name; });
return it == computations_in_module.end() ? nullptr : *it;
}
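
Because ToProto now serializes the schedule and CreateFromProto re-attaches and re-verifies it, a scheduled module survives a proto round trip. A hedged sketch of that invariant using only the APIs added above (XLA-internal; the function name is hypothetical):

namespace xla {

StatusOr<bool> ScheduleSurvivesRoundTrip(const HloModule& module) {
  HloModuleProto proto = module.ToProto();  // includes schedule when set
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> clone,
                      HloModule::CreateFromProto(proto, module.config()));
  // CreateFromProto already ran set_schedule, which verifies the schedule.
  return clone->has_schedule() == module.has_schedule();
}

}  // namespace xla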
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index d2e726a0db..26fd1b2438 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -24,16 +24,18 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
@@ -142,7 +144,7 @@ class HloModule {
// Returns the computation in this module that has the name `name`. Returns
// null if there is no such computation.
- HloComputation* GetComputationWithName(tensorflow::StringPiece name);
+ HloComputation* GetComputationWithName(absl::string_view name);
// Gets the number of computations in this module.
int64 computation_count() const { return computations_.size(); }
@@ -192,7 +194,7 @@ class HloModule {
// order (root of outlined instructions last). TODO(jingyue): take a set of
// instructions and topologically sort them.
HloInstruction* OutlineExpressionFromComputation(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_outline,
+ absl::Span<HloInstruction* const> instructions_to_outline,
const string& outlined_computation_name, HloComputation* computation);
// Returns a randomly generated uint64.
@@ -235,6 +237,19 @@ class HloModule {
StatusOr<HloInstruction*> LaunderConstInstructionFromModule(
const HloInstruction* hlo);
+ // Sets the schedule of the module to the given schedule.
+ Status set_schedule(HloSchedule schedule);
+
+ // Clears the schedule of the module.
+ void clear_schedule() { schedule_.reset(); }
+
+ // Returns true if the module has a schedule set.
+ bool has_schedule() const { return schedule_.has_value(); }
+
+  // Returns the schedule of the module. CHECK fails if no schedule is set.
+ const HloSchedule& schedule() const { return *schedule_; }
+ HloSchedule& schedule() { return *schedule_; }
+
private:
HloComputation* AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
@@ -262,6 +277,11 @@ class HloModule {
static std::atomic<int> next_unique_module_id_;
// A unique id to label modules with.
int unique_id_;
+
+  // The HloSchedule of the module. The schedule, if it exists, contains a
+ // sequential order of instructions for each non-fusion computation in the
+ // module.
+ absl::optional<HloSchedule> schedule_;
};
} // namespace xla
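
The schedule accessors above are a thin veneer over absl::optional: has_schedule is has_value(), clear_schedule is reset(), and schedule() dereferences. A standalone sketch of that holder pattern (hypothetical class, std::string in place of HloSchedule):

#include <iostream>
#include <string>
#include <utility>
#include "absl/types/optional.h"

class Holder {
 public:
  void set(std::string s) { value_ = std::move(s); }   // set_schedule analogue
  void clear() { value_.reset(); }                     // clear_schedule analogue
  bool has() const { return value_.has_value(); }      // has_schedule analogue
  const std::string& get() const { return *value_; }   // schedule() analogue

 private:
  absl::optional<std::string> value_;
};

int main() {
  Holder h;
  std::cout << h.has() << "\n";  // 0
  h.set("entry: param, add, sub");
  std::cout << h.has() << " " << h.get() << "\n";  // 1 entry: param, add, sub
  return 0;
}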
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc
index 07a8c798db..9bfa3a5f45 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_config.cc
@@ -18,15 +18,15 @@ limitations under the License.
#include <atomic>
#include <vector>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
-using tensorflow::strings::StrAppend;
+using absl::StrAppend;
HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape,
bool ignore_layouts)
@@ -39,15 +39,14 @@ void HloModuleConfig::SetDefaultComputationLayout(
}
string HloModuleConfig::compilation_cache_key() const {
- string key =
- tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled());
+ string key = absl::StrCat("profiling=", hlo_profiling_enabled());
StrAppend(&key, "::(");
std::vector<string> params;
for (const ShapeLayout& param_layout :
entry_computation_layout_->parameter_layouts()) {
params.push_back(param_layout.shape().DebugString());
}
- StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ",
+ StrAppend(&key, absl::StrJoin(params, ", "), ") => ",
entry_computation_layout_->result_shape().SerializeAsString());
if (seed() != 0) {
// TODO(b/32083678): force recompilation to reset global state.
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h
index 074e9c9070..68c18836eb 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.h
+++ b/tensorflow/compiler/xla/service/hlo_module_config.h
@@ -18,11 +18,11 @@ limitations under the License.
#include <string>
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/optional.h"
namespace xla {
@@ -72,15 +72,6 @@ class HloModuleConfig {
return debug_options_.xla_hlo_profile();
}
- // Sets/returns whether this is a "host module". Host modules are used to
- // record the data- and control-flow dependencies of host side computation
- // that communicates with compiled code. They are used for analysis and
- // scheduling purposes, but no code is generated.
- bool is_host_module() const { return is_host_module_; }
- void set_is_host_module(bool is_host_module) {
- is_host_module_ = is_host_module;
- }
-
// Sets/returns the module seed set during execution.
void set_seed(uint64 seed) { seed_ = seed; }
uint64 seed() const { return seed_; }
@@ -113,10 +104,7 @@ class HloModuleConfig {
private:
// If you add new members, be sure to update compilation_cache_key.
- tensorflow::gtl::optional<ComputationLayout> entry_computation_layout_;
-
- // Whether this is a 'host module'.
- bool is_host_module_ = false;
+ absl::optional<ComputationLayout> entry_computation_layout_;
// Module/graph-level seed handle.
uint64 seed_ = 0;
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h
index 29024085c1..12ca2340a6 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce.h
+++ b/tensorflow/compiler/xla/service/hlo_module_dce.h
@@ -31,7 +31,7 @@ namespace xla {
class HloModuleDCE : public HloPassInterface {
public:
~HloModuleDCE() override {}
- tensorflow::StringPiece name() const override { return "hlo-module-dce"; }
+ absl::string_view name() const override { return "hlo-module-dce"; }
// Run the pass on the given module. Returns whether the module was changed
// (instructions were removed).
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index 10bf9ffd6c..9c01862a4b 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -19,9 +19,10 @@ limitations under the License.
#include <string>
#include <utility>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
@@ -59,7 +60,7 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const {
/* static */ StatusOr<std::unique_ptr<HloModuleGroupMetadata>>
HloModuleGroupMetadata::Build(const std::vector<HloModule*>& modules) {
- auto metadata = MakeUnique<HloModuleGroupMetadata>(modules);
+ auto metadata = absl::make_unique<HloModuleGroupMetadata>(modules);
TF_RETURN_IF_ERROR(metadata->Build());
return std::move(metadata);
}
@@ -131,6 +132,14 @@ Status HloModuleGroupMetadata::Build() {
if (VLOG_IS_ON(4)) {
DumpCollectedStats();
}
+
+ for (HloModule* module : modules_) {
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
+ TuplePointsToAnalysis::Run(module));
+ points_to_analyses_[module] = std::move(points_to_analysis);
+ }
+
return Status::OK();
}
@@ -163,7 +172,7 @@ Status HloModuleGroupMetadata::VerifyCompanionSets() const {
ss << " " << hlo->name() << std::endl;
}
ss << "has multiple instructions on the same device";
- return FailedPrecondition("%s", ss.str().c_str());
+ return FailedPrecondition("%s", ss.str());
}
}
}
@@ -204,6 +213,10 @@ const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel(
return channels_[channel_id_map_.at(channel_id)];
}
+bool HloModuleGroupMetadata::HasChannel(int64 channel_id) const {
+ return channel_id_map_.find(channel_id) != channel_id_map_.end();
+}
+
HloComputation* HloModuleGroupMetadata::PeerComputation(
const HloInstruction* instruction) const {
CHECK(IsChannelInstruction(instruction));
@@ -267,15 +280,14 @@ int64 HloModuleGroupMetadata::GetModuleId(const HloModule* module) const {
LOG(FATAL) << "unknown module";
}
-tensorflow::gtl::optional<int64> HloModuleGroupMetadata::GetInstructionDevice(
+absl::optional<int64> HloModuleGroupMetadata::GetInstructionDevice(
const HloInstruction& instruction) const {
// The module group metadata can be created in both "single module, multiple
// devices" and "multiple modules, no explicit devices" fashions.
// The API returns an optional even though the current implementation always
// returns a device, to account for cases where we cannot guess a device.
// In such cases, VerifyChannelInstructions() will return proper errors.
- tensorflow::gtl::optional<int64> device =
- instruction.sharding_unique_device();
+ absl::optional<int64> device = instruction.sharding_unique_device();
if (!device) {
device = GetModuleId(instruction.parent()->parent());
}
@@ -283,10 +295,7 @@ tensorflow::gtl::optional<int64> HloModuleGroupMetadata::GetInstructionDevice(
}
int64 HloModuleGroupMetadata::GetDeviceModulesCount() const {
- return std::count_if(modules_.begin(), modules_.end(),
- [](const HloModule* module) {
- return !module->config().is_host_module();
- });
+ return modules_.size();
}
Status HloModuleGroupMetadata::RecordInstructions() {
@@ -383,7 +392,7 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1,
if (!ContainsKey(companion_set_index_, instruction1) &&
!ContainsKey(companion_set_index_, instruction2)) {
companion_sets_.push_back(
- tensorflow::MakeUnique<std::unordered_set<HloInstruction*>>());
+ absl::make_unique<std::unordered_set<HloInstruction*>>());
auto companion_set = companion_sets_.back().get();
companion_set->insert(instruction1);
companion_set->insert(instruction2);
@@ -411,16 +420,16 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1,
Status HloModuleGroupMetadata::VerifyChannelInstructions() {
for (const Channel& channel : channels_) {
if (channel.send == nullptr) {
- return FailedPrecondition("missing send for id : %lld", channel.id);
+ return FailedPrecondition("missing send for id : %d", channel.id);
}
if (channel.recv == nullptr) {
- return FailedPrecondition("missing recv for id : %lld", channel.id);
+ return FailedPrecondition("missing recv for id : %d", channel.id);
}
if (channel.send_done == nullptr) {
- return FailedPrecondition("missing send-done for id : %lld", channel.id);
+ return FailedPrecondition("missing send-done for id : %d", channel.id);
}
if (channel.recv_done == nullptr) {
- return FailedPrecondition("missing recv-done for id : %lld", channel.id);
+ return FailedPrecondition("missing recv-done for id : %d", channel.id);
}
}
@@ -436,33 +445,33 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() {
auto send_done_device = GetInstructionDevice(*channel.send_done);
if (!send_device) {
return FailedPrecondition("send instruction must have a device: %s",
- channel.send->ToString().c_str());
+ channel.send->ToString());
}
if (!send_done_device) {
return FailedPrecondition("send_done instruction must have a device: %s",
- channel.send_done->ToString().c_str());
+ channel.send_done->ToString());
}
if (*send_device != *send_done_device) {
return FailedPrecondition(
- "send and send-done (channel=%lld) must be on the same device: %lld "
- "vs. %lld",
+ "send and send-done (channel=%d) must be on the same device: %d "
+ "vs. %d",
channel.id, *send_device, *send_done_device);
}
auto recv_device = GetInstructionDevice(*channel.recv);
auto recv_done_device = GetInstructionDevice(*channel.recv_done);
if (!recv_done_device) {
return FailedPrecondition("recv_done instruction must have a device: %s",
- channel.recv_done->ToString().c_str());
+ channel.recv_done->ToString());
}
if (*recv_device != *recv_done_device) {
return FailedPrecondition(
- "recv and recv-done (channel=%lld) must be on the same device: %lld "
- "vs. %lld",
+ "recv and recv-done (channel=%d) must be on the same device: %d "
+ "vs. %d",
channel.id, *recv_device, *recv_done_device);
}
if (*send_device == *recv_device) {
return FailedPrecondition(
- "send and recv (channel=%lld) must be on different devices: %lld",
+ "send and recv (channel=%d) must be on different devices: %d",
channel.id, *send_device);
}
}
@@ -483,7 +492,7 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() {
!CheckCompanionPathsCompatibility(
path, GetCompanionsPath(channel.recv_done))) {
return FailedPrecondition(
- "Nest companion paths do not match for channel %lld", channel.id);
+ "Nest companion paths do not match for channel %d", channel.id);
}
}
return Status::OK();
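
The format-string edits in this file (dropping .c_str(), relaxing %lld to %d) match XLA's error constructors moving to absl::StrFormat-style typed formatting, where %d accepts any integral width and %s takes std::string directly; treat that rationale as an inference from the diff rather than something it states. A standalone absl::StrFormat sketch:

#include <cstdint>
#include <iostream>
#include <string>
#include "absl/strings/str_format.h"

int main() {
  int64_t channel_id = 42;
  std::string what = "send-done";
  // %d works for any integer width and %s accepts std::string, so neither
  // %lld nor .c_str() is needed.
  std::cout << absl::StrFormat("missing %s for id : %d", what, channel_id)
            << "\n";
  return 0;
}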
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
index 1b256cd00e..768b0c7eb3 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
@@ -22,14 +22,15 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -125,6 +126,9 @@ class HloModuleGroupMetadata {
// Returns the Channel instance for the given channel id.
const Channel& GetChannel(int64 channel_id) const;
+  // Returns whether the given channel id exists in the metadata.
+ bool HasChannel(int64 channel_id) const;
+
// Returns the all-reduce instructions with the same all_reduce_id.
const std::vector<HloInstruction*>& GetAllReduceGroup(
int64 all_reduce_id) const;
@@ -156,7 +160,7 @@ class HloModuleGroupMetadata {
// Retrieves the device an instruction is assigned to. Either from the
// sharding information, or from the ordinal of the module the instruction
// is in.
- tensorflow::gtl::optional<int64> GetInstructionDevice(
+ absl::optional<int64> GetInstructionDevice(
const HloInstruction& instruction) const;
// Returns the number of modules for devices (excluding the host module).
@@ -194,6 +198,10 @@ class HloModuleGroupMetadata {
// Returns the maximum channel id or all_reduce_id used in the module group.
int64 max_channel_id() const { return max_channel_id_; }
+ TuplePointsToAnalysis* points_to_analysis(HloModule* module) const {
+ return points_to_analyses_.at(module).get();
+ }
+
private:
Status Build();
@@ -268,6 +276,9 @@ class HloModuleGroupMetadata {
// The modules that this metadata was built from.
const std::vector<HloModule*>& modules_;
+
+ tensorflow::gtl::FlatMap<HloModule*, std::unique_ptr<TuplePointsToAnalysis>>
+ points_to_analyses_;
};
} // namespace xla
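
A minimal usage sketch for the two additions above (HasChannel() and the absl::optional return of GetInstructionDevice()), assuming a fully built HloModuleGroupMetadata named `metadata`, a channel id `channel_id`, and an HloInstruction* `instr`, all hypothetical:

    // HasChannel() lets callers probe before calling GetChannel() on an
    // id that may not exist in the metadata.
    if (metadata.HasChannel(channel_id)) {
      const auto& channel = metadata.GetChannel(channel_id);
      VLOG(2) << "send: " << channel.send->ToString();
    }
    // GetInstructionDevice() now returns absl::optional<int64> rather than
    // tensorflow::gtl::optional<int64>; call sites test for presence first.
    absl::optional<int64> device = metadata.GetInstructionDevice(*instr);
    if (device.has_value()) {
      VLOG(2) << "assigned to device " << *device;
    }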
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
index 0dc5676148..d83ee71490 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
@@ -22,7 +22,10 @@ limitations under the License.
#include <string>
#include <utility>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -30,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -94,12 +96,14 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
add_unique_predecessor(control_predecessor);
}
}
- if (instruction->opcode() == HloOpcode::kRecvDone) {
+ if (instruction->opcode() == HloOpcode::kRecvDone &&
+ !DynCast<HloRecvDoneInstruction>(instruction)->is_host_transfer()) {
// Send is a remote predecessor of RecvDone.
HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
add_unique_predecessor(send);
}
- if (instruction->opcode() == HloOpcode::kSend) {
+ if (instruction->opcode() == HloOpcode::kSend &&
+ !DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
// Recv is a remote predecessor of Send.
HloInstruction* recv_done =
metadata_.GetChannel(instruction->channel_id()).recv_done;
@@ -170,14 +174,16 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
add_unique_successor(control_successor);
}
}
- if (instruction->opcode() == HloOpcode::kRecv) {
+ if (instruction->opcode() == HloOpcode::kRecv &&
+ !DynCast<HloRecvInstruction>(instruction)->is_host_transfer()) {
// Send is a remote successor of Recv.
const HloInstruction* recv_done = instruction->users().front();
CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
add_unique_successor(send);
}
- if (instruction->opcode() == HloOpcode::kSend) {
+ if (instruction->opcode() == HloOpcode::kSend &&
+ !DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
// RecvDone is a remote successor of Send.
HloInstruction* recv_done =
metadata_.GetChannel(instruction->channel_id()).recv_done;
@@ -187,7 +193,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
}
std::vector<HloInstruction*> HloModuleGroupUtil::RootInstructions(
- tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+ absl::Span<HloComputation* const> computations) {
std::vector<HloInstruction*> roots;
for (HloComputation* computation : computations) {
for (HloInstruction* instruction : computation->instructions()) {
@@ -264,8 +270,8 @@ Status HloModuleGroupUtil::VisitTopologicalOrder(
string cyclic_instructions;
for (const auto& state : *visit_state) {
if (state.second == VisitState::kVisiting) {
- tensorflow::strings::StrAppend(&cyclic_instructions,
- state.first->ToString(), "\n");
+ absl::StrAppend(&cyclic_instructions, state.first->ToString(),
+ "\n");
}
}
// TODO(b/64305524): Improve the error message to print out the
@@ -276,7 +282,7 @@ Status HloModuleGroupUtil::VisitTopologicalOrder(
"following nodes. Note that the order of the nodes is arbitrary "
"and that the list may include nodes that are not part of the "
"cycle.\n%s",
- predecessor->ToString().c_str(), cyclic_instructions.c_str());
+ predecessor->ToString(), cyclic_instructions);
}
stack.push(predecessor);
}
@@ -287,7 +293,7 @@ Status HloModuleGroupUtil::VisitTopologicalOrder(
}
Status HloModuleGroupUtil::VerifyComputations(
- tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+ absl::Span<HloComputation* const> computations) {
auto visit_function =
[&](HloInstruction* instruction,
const std::vector<HloInstruction*>& instruction_group) {
@@ -318,7 +324,7 @@ Status HloModuleGroupUtil::VerifyComputations(
StatusOr<std::unique_ptr<HloReachabilityMap>>
HloModuleGroupUtil::ComputeReachability(
- tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+ absl::Span<HloComputation* const> computations) {
std::vector<HloInstruction*> post_order;
auto visit_function =
[&](HloInstruction* instruction,
@@ -332,7 +338,7 @@ HloModuleGroupUtil::ComputeReachability(
TF_RETURN_IF_ERROR(
VisitTopologicalOrder(&visit_states, visit_function, root));
}
- auto reachability = MakeUnique<HloReachabilityMap>(post_order);
+ auto reachability = absl::make_unique<HloReachabilityMap>(post_order);
for (HloInstruction* hlo : post_order) {
reachability->FastSetReachabilityToUnion(GlobalPredecessors(hlo), hlo);
}
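
Call sites generally need no changes for the ArraySlice-to-Span migration in this file, since std::vector binds implicitly to absl::Span; a sketch, assuming an HloModuleGroupUtil `util` and an HloModule* `module`:

    std::vector<HloComputation*> computations =
        module->MakeNonfusionComputations();
    // The vector converts to absl::Span<HloComputation* const> at the call.
    std::vector<HloInstruction*> roots = util.RootInstructions(computations);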
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.h b/tensorflow/compiler/xla/service/hlo_module_group_util.h
index c25ca1aff5..309c23045d 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
@@ -56,7 +56,7 @@ class HloModuleGroupUtil {
// Returns the root instructions of the computations.
std::vector<HloInstruction*> RootInstructions(
- tensorflow::gtl::ArraySlice<HloComputation*> computations);
+ absl::Span<HloComputation* const> computations);
// Visit state of each instruction during DFS traversal.
enum VisitState {
@@ -93,15 +93,14 @@ class HloModuleGroupUtil {
HloInstruction* root);
// Verifies that the computations are well-formed (e.g., no cycles).
- Status VerifyComputations(
- tensorflow::gtl::ArraySlice<HloComputation*> computations);
+ Status VerifyComputations(absl::Span<HloComputation* const> computations);
// Below Reachability utils resemble those in HloComputation, except that
// they can handle instructions across multiple computations.
//
// Creates the reachability map for the instructions in the computations.
StatusOr<std::unique_ptr<HloReachabilityMap>> ComputeReachability(
- tensorflow::gtl::ArraySlice<HloComputation*> computations);
+ absl::Span<HloComputation* const> computations);
// Updates the reachability of the given instruction, taking the global
// predecessors and successors into account.
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
index 236f450086..400bd4d947 100644
--- a/tensorflow/compiler/xla/service/hlo_module_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -15,21 +15,26 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace {
+namespace op = ::xla::testing::opcode_matchers;
+
class HloModuleTest : public HloTestBase {
protected:
HloModuleTest() {}
@@ -44,7 +49,7 @@ class HloModuleTest : public HloTestBase {
// Creates a computation which calls the given zero-parameter computations.
std::unique_ptr<HloComputation> CreateCallComputation(
- tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+ absl::Span<HloComputation* const> computations) {
auto builder = HloComputation::Builder("Call");
for (auto computation : computations) {
builder.AddInstruction(
@@ -194,6 +199,60 @@ TEST_F(HloModuleTest, UniqueModuleId) {
EXPECT_NE(module_a->unique_id(), module_b->unique_id());
}
+TEST_F(HloModuleTest, ProtoSerializationWithoutSchedule) {
+ const string text = R"(
+HloModule axpy_module
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %x = f32[2,4]{1,0} parameter(1)
+ %y = f32[2,4]{1,0} parameter(2)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_FALSE(module->has_schedule());
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module_copy,
+ HloModule::CreateFromProto(module->ToProto(), module->config()));
+ ASSERT_FALSE(module_copy->has_schedule());
+}
+
+TEST_F(HloModuleTest, ProtoSerializationWithSchedule) {
+ const string text = R"(
+HloModule axpy_module, is_scheduled=true
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %x = f32[2,4]{1,0} parameter(1)
+ %y = f32[2,4]{1,0} parameter(2)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_TRUE(module->has_schedule());
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module_copy,
+ HloModule::CreateFromProto(module->ToProto(), module->config()));
+ ASSERT_TRUE(module_copy->has_schedule());
+ TF_ASSERT_OK(module_copy->schedule().Verify());
+ EXPECT_EQ(module_copy->schedule().sequences().size(), 1);
+ ASSERT_TRUE(module_copy->schedule().is_computation_scheduled(
+ module_copy->entry_computation()));
+ EXPECT_THAT(
+ module_copy->schedule()
+ .sequence(module_copy->entry_computation())
+ .instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(),
+ op::Broadcast(), op::Multiply(), op::Add()));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc
index d1eaf35785..2d4e38589f 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.cc
+++ b/tensorflow/compiler/xla/service/hlo_opcode.cc
@@ -39,7 +39,7 @@ StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name) {
});
auto it = opcode_map->find(opcode_name);
if (it == opcode_map->end()) {
- return InvalidArgument("Unknown opcode: %s", opcode_name.c_str());
+ return InvalidArgument("Unknown opcode: %s", opcode_name);
}
return it->second;
}
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index ec279867e5..e6bfb8025d 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -58,6 +58,7 @@ namespace xla {
V(kCall, "call", kHloOpcodeIsVariadic) \
V(kCeil, "ceil") \
V(kClamp, "clamp") \
+ V(kCollectivePermute, "collective-permute") \
V(kClz, "count-leading-zeros") \
V(kComplex, "complex") \
V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \
@@ -85,7 +86,6 @@ namespace xla {
V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \
V(kGetTupleElement, "get-tuple-element") \
V(kGt, "greater-than", kHloOpcodeIsComparison) \
- V(kHostCompute, "host-compute") \
V(kImag, "imag") \
V(kInfeed, "infeed") \
V(kIota, "iota") \
@@ -156,7 +156,7 @@ enum HloOpcodeProperty {
// Returns a string representation of the opcode.
string HloOpcodeString(HloOpcode opcode);
-// Returns a string representation of the opcode.
+// Retrieves the opcode enum by name if the opcode exists.
StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name);
inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) {
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index 6c1e015f77..f1dc08bafa 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -18,6 +18,9 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -25,8 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -252,14 +253,36 @@ bool HloOrdering::LiveRangeStrictlyBefore(
VLOG(4) << a << " not defined before " << b;
return false;
}
+
+ if (a.live_out_of_module()) {
+ VLOG(4) << a << " is live out of module and defined before " << b;
+ return false;
+ }
+
// All uses of 'a' must be before 'b' is defined.
for (const HloUse& use : a.uses()) {
+ if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(),
+ use.instruction)) {
+ continue;
+ }
if (!UseIsBeforeValueDefinition(use, b, dataflow)) {
VLOG(4) << "use of " << a << " (" << use << ") not before " << b
<< " is defined";
return false;
}
}
+
+ if (a.instruction()->parent() == b.instruction()->parent()) {
+ for (const HloPosition& position : a.positions()) {
+ if (position.instruction ==
+ a.instruction()->parent()->root_instruction()) {
+ VLOG(4) << a << " is live out of computation and defined before " << b
+ << " which is in same computation";
+ return false;
+ }
+ }
+ }
+
return true;
}
@@ -270,23 +293,6 @@ bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b,
!LiveRangeStrictlyBefore(b, a, dataflow);
}
-HloOrderingProto HloOrdering::ToProto() const {
- HloOrderingProto proto;
- for (const auto& computation : module_->computations()) {
- const std::vector<const HloInstruction*>* sequence =
- SequentialOrder(*computation);
- if (sequence != nullptr) {
- HloOrderingProto::SequentialComputation* proto_computation =
- proto.add_sequential_computations();
- proto_computation->set_computation_name(computation->name());
- for (const HloInstruction* instruction : *sequence) {
- *proto_computation->add_instruction_names() = instruction->name();
- }
- }
- }
- return proto;
-}
-
PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module)
: HloOrdering(module) {}
@@ -302,22 +308,20 @@ string PredecessorHloOrdering::ToStringHelper(const string& name) const {
std::vector<string> pieces;
pieces.push_back(name);
for (auto* computation : module_->MakeNonfusionComputations()) {
- pieces.push_back(tensorflow::strings::Printf("computation %s:",
- computation->name().c_str()));
+ pieces.push_back(absl::StrFormat("computation %s:", computation->name()));
const auto all = computation->MakeInstructionPostOrder();
for (auto instruction : all) {
- pieces.push_back(tensorflow::strings::Printf(
- " %s predecessors:", instruction->name().c_str()));
+ pieces.push_back(
+ absl::StrFormat(" %s predecessors:", instruction->name()));
for (auto predecessor : all) {
if (predecessors_.at(computation)
->IsReachable(predecessor, instruction)) {
- pieces.push_back(
- tensorflow::strings::Printf(" %s", predecessor->name().c_str()));
+ pieces.push_back(absl::StrFormat(" %s", predecessor->name()));
}
}
}
}
- return tensorflow::str_util::Join(pieces, "\n");
+ return absl::StrJoin(pieces, "\n");
}
DependencyHloOrdering::DependencyHloOrdering(const HloModule* module)
@@ -334,15 +338,24 @@ string DependencyHloOrdering::ToString() const {
return ToStringHelper("DependencyHloOrdering");
}
-SequentialHloOrdering::SequentialHloOrdering(
- const HloModule* module, const HloModuleSequence& module_sequence)
- : HloOrdering(module), module_sequence_(module_sequence) {
+SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule)
+ : HloOrdering(schedule.module()), schedule_(schedule) {
+ Initialize();
+}
+
+SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule)
+ : HloOrdering(schedule.module()), schedule_(std::move(schedule)) {
+ Initialize();
+}
+
+void SequentialHloOrdering::Initialize() {
// Create a map from instruction to its order position.
- for (auto computation_order : module_sequence_) {
- const std::vector<const HloInstruction*>& order = computation_order.second;
+ TF_DCHECK_OK(schedule_.Verify());
+ for (const auto& computation_sequence : schedule_.sequences()) {
+ const std::vector<const HloInstruction*>& order =
+ computation_sequence.second.instructions();
for (int i = 0; i < order.size(); ++i) {
- DCHECK_EQ(0, order_position_.count(order[i]));
- order_position_.emplace(order[i], i);
+ InsertOrDie(&order_position_, order[i], i);
}
}
}
@@ -360,50 +373,13 @@ bool SequentialHloOrdering::ExecutesBeforeInSameComputation(
const std::vector<const HloInstruction*>*
SequentialHloOrdering::SequentialOrder(
const HloComputation& computation) const {
- auto find_it = module_sequence_.find(&computation);
- return find_it == module_sequence_.end() ? nullptr : &find_it->second;
+ return schedule_.is_computation_scheduled(&computation)
+ ? &schedule_.sequence(&computation).instructions()
+ : nullptr;
}
string SequentialHloOrdering::ToString() const {
- std::vector<string> pieces;
- pieces.push_back("SequentialHloOrdering");
- for (auto* computation : module_->computations()) {
- pieces.push_back(tensorflow::strings::Printf("computation %s order:",
- computation->name().c_str()));
- // Gather all instructions in the module sequence for this computation and
- // sort them by their position.
- std::vector<const HloInstruction*> instructions;
- for (auto& instruction_position : order_position_) {
- const HloInstruction* instruction = instruction_position.first;
- if (instruction->parent() == computation) {
- instructions.push_back(instruction);
- }
- }
- std::sort(instructions.begin(), instructions.end(),
- [this](const HloInstruction* a, const HloInstruction* b) {
- return order_position_.at(a) < order_position_.at(b);
- });
- for (auto instruction : instructions) {
- pieces.push_back(
- tensorflow::strings::Printf(" %s", instruction->name().c_str()));
- }
- }
- return tensorflow::str_util::Join(pieces, "\n");
-}
-
-std::ostream& operator<<(
- std::ostream& out,
- const SequentialHloOrdering::HloModuleSequence& module_sequence) {
- for (auto computation_pair : module_sequence) {
- const HloComputation* computation = computation_pair.first;
- const std::vector<const HloInstruction*>& computation_sequence =
- computation_pair.second;
- out << "Computation " << computation->name() << ":\n";
- for (auto* instruction : computation_sequence) {
- out << " " << instruction->name() << "\n";
- }
- }
- return out;
+ return absl::StrCat("SequentialHloOrdering\n", schedule_.ToString());
}
} // namespace xla
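
The schedule-based constructors replace the old HloModuleSequence map; a construction sketch in the style of the tests further down, assuming `module`, an entry computation `entry`, and instructions `param` and `root`:

    HloSchedule schedule(module.get());
    schedule.set_sequence(entry, {param, root});
    TF_CHECK_OK(schedule.Verify());
    SequentialHloOrdering ordering(schedule);
    // Scheduled computations now report a sequential order.
    const std::vector<const HloInstruction*>* order =
        ordering.SequentialOrder(*entry);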
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h
index 985f3fa64d..b0361c3f02 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.h
+++ b/tensorflow/compiler/xla/service/hlo_ordering.h
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -71,10 +72,6 @@ class HloOrdering {
virtual string ToString() const = 0;
- // Returns the serialized representation of this ordering.
- // Only sequential computation orders are represented.
- HloOrderingProto ToProto() const;
-
protected:
// Returns true if instruction 'a' executes before instruction 'b'.
// Precondition: 'a' and 'b' are in the same computation.
@@ -183,17 +180,8 @@ class DependencyHloOrdering : public PredecessorHloOrdering {
// interference is reduced relative to DependencyHloOrdering.
class SequentialHloOrdering : public HloOrdering {
public:
- // TODO(dimvar): HloModuleSequence is not a good name because it sounds like
- // a sequence of modules, instead of a map of schedules for all computations
- // in a module. We should change it at some point.
- //
- // A sequence of instructions for each computation in the module.
- using HloModuleSequence =
- tensorflow::gtl::FlatMap<const HloComputation*,
- std::vector<const HloInstruction*>>;
-
- SequentialHloOrdering(const HloModule* module,
- const HloModuleSequence& module_sequence);
+ SequentialHloOrdering(const HloSchedule& schedule);
+ SequentialHloOrdering(HloSchedule&& schedule);
~SequentialHloOrdering() override = default;
// Returns the sequential instruction order for the given computation.
@@ -203,10 +191,12 @@ class SequentialHloOrdering : public HloOrdering {
string ToString() const override;
protected:
+ void Initialize();
+
bool ExecutesBeforeInSameComputation(const HloInstruction* a,
const HloInstruction* b) const override;
- const HloModuleSequence module_sequence_;
+ const HloSchedule schedule_;
// The position of every instruction in the HLO module in its respective
// computation sequence (a value of zero indicates the instruction is first in
@@ -217,10 +207,6 @@ class SequentialHloOrdering : public HloOrdering {
tensorflow::gtl::FlatMap<const HloInstruction*, int> order_position_;
};
-std::ostream& operator<<(
- std::ostream& out,
- const SequentialHloOrdering::HloModuleSequence& module_sequence);
-
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ORDERING_H_
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index 126d3a2d9c..00970bcda3 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -23,11 +23,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
@@ -376,5 +377,104 @@ ENTRY root {
dataflow->GetValueDefinedAt(add_3)));
}
+TEST_F(HloOrderingTest,
+ ValuesLiveOutOfModuleInterfereWithInstructionsAfterRoot) {
+ // Tests that values that are live out of the module interfere with values
+ // defined after the root instruction. That is:
+ //
+ // %param = param(0)
+ // ROOT %root = negate(%param)
+ // %dead = Constant(123.0)
+ //
+ // %root should interfere with %dead.
+ auto module = CreateNewModule();
+ const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "param"));
+ HloInstruction* root = builder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param));
+ HloInstruction* dead = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
+ HloComputation* entry =
+ module->AddEntryComputation(builder.Build(/*root_instruction=*/root));
+
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(entry, {param, root, dead});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
+ HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
+
+ EXPECT_TRUE(ordering.ExecutesBefore(root, dead));
+ EXPECT_FALSE(ordering.ExecutesBefore(dead, root));
+
+ EXPECT_FALSE(ordering.LiveRangeStrictlyBefore(
+ dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead),
+ *dataflow));
+
+ EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root),
+ dataflow->GetValueDefinedAt(dead),
+ *dataflow));
+}
+
+TEST_F(HloOrderingTest,
+ ValuesLiveOutOfComputationInterfereWithInstructionsAfterRoot) {
+ // Tests that values that are live out of a computation interfere with
+ // values defined after the root instruction of the computation. That is:
+ //
+ // subcomputation:
+ // %param = param(0)
+ // ROOT %root = negate(%param)
+ // %dead = Constant(123.0)
+ //
+ // entry computation:
+ // %c = constant(42.0)
+ // ROOT %call = call({%c}), subcomputation
+ //
+ // %root should interfere with %dead.
+ auto module = CreateNewModule();
+ const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+
+ auto subbuilder = HloComputation::Builder(TestName() + ".sub");
+ HloInstruction* param = subbuilder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "param"));
+ HloInstruction* root = subbuilder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param));
+ HloInstruction* dead = subbuilder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
+ HloComputation* subcomputation = module->AddEmbeddedComputation(
+ subbuilder.Build(/*root_instruction=*/root));
+
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* c = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ HloInstruction* call = builder.AddInstruction(
+ HloInstruction::CreateCall(scalar_shape, {c}, subcomputation));
+ HloComputation* entry = module->AddEntryComputation(builder.Build());
+
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(subcomputation, {param, root, dead});
+ schedule.set_sequence(entry, {c, call});
+ TF_ASSERT_OK(schedule.Verify());
+ SequentialHloOrdering ordering(schedule);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
+ HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
+
+ EXPECT_TRUE(ordering.ExecutesBefore(root, dead));
+ EXPECT_FALSE(ordering.ExecutesBefore(dead, root));
+
+ EXPECT_FALSE(ordering.LiveRangeStrictlyBefore(
+ dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead),
+ *dataflow));
+
+ EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root),
+ dataflow->GetValueDefinedAt(dead),
+ *dataflow));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index eb48337cd7..11caa89c54 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -15,39 +15,56 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
namespace {
-using ::tensorflow::StringPiece;
-using ::tensorflow::gtl::optional;
-using ::tensorflow::str_util::Join;
-using ::tensorflow::str_util::Split;
-using ::tensorflow::str_util::SplitAndParseAsInts;
-using ::tensorflow::strings::Printf;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
+using absl::nullopt;
+using absl::optional;
+using absl::StrAppend;
+using absl::StrCat;
+using absl::StrFormat;
+using absl::StrJoin;
const double kF16max = 65504;
+// Creates and returns a schedule built from the order of the instructions in
+// the HloComputation::instructions() vectors in the module.
+HloSchedule ScheduleFromInstructionOrder(const HloModule* module) {
+ HloSchedule schedule(module);
+ for (const HloComputation* computation : module->computations()) {
+ if (!computation->IsFusionComputation()) {
+ for (const HloInstruction* instruction : computation->instructions()) {
+ schedule.GetOrCreateSequence(computation).push_back(instruction);
+ }
+ }
+ }
+ return schedule;
+}
+
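A usage sketch for the helper above, assuming `module` is a parsed std::unique_ptr<HloModule>:

    HloSchedule schedule = ScheduleFromInstructionOrder(module.get());
    TF_CHECK_OK(schedule.Verify());
    // Every non-fusion computation now has a sequence in instruction order.
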
// Parser for the HloModule::ToString() format text.
class HloParser {
public:
using LocTy = HloLexer::LocTy;
- explicit HloParser(StringPiece str, const HloModuleConfig& config)
+ explicit HloParser(absl::string_view str, const HloModuleConfig& config)
: lexer_(str), config_(config) {}
// Runs the parser. Returns false if an error occurred.
@@ -57,14 +74,29 @@ class HloParser {
std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); }
// Returns the error information.
- string GetError() const { return Join(error_, "\n"); }
+ string GetError() const { return StrJoin(error_, "\n"); }
// Stand-alone parsing utils for various aggregate data types.
StatusOr<HloSharding> ParseShardingOnly();
StatusOr<Window> ParseWindowOnly();
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
+ StatusOr<PaddingConfig> ParsePaddingConfigOnly();
+
+ // Stand-alone parsing utility for a single instruction's worth of text.
+ Status ParseSingleInstruction(HloComputation::Builder* builder,
+ string* root_name);
private:
+ // Locates an instruction with the given name in the instruction_pool_ or
+ // returns nullptr.
+ //
+ // If the missing_instruction_hook_ is registered and a "shape" is provided,
+ // the hook will be called and may satisfy the request for the given
+ // instruction. This is useful when we reify parameters as they're resolved;
+ // i.e. for ParseSingleInstruction.
+ std::pair<HloInstruction*, LocTy>* FindInstruction(
+ const string& name, const optional<Shape>& shape = nullopt);
+
// ParseXXX returns false if an error occurred.
bool ParseHloModule();
bool ParseComputations();
@@ -73,16 +105,13 @@ class HloParser {
string* root_name);
bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
bool ParseControlPredecessors(HloInstruction* instruction);
- bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
- bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
- bool ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape);
- bool ParseDenseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
- bool ParseSparseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape);
+ bool ParseLiteral(Literal* literal, const Shape& shape);
+ bool ParseTupleLiteral(Literal* literal, const Shape& shape);
+ bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
+ bool ParseDenseLiteral(Literal* literal, const Shape& shape);
+ bool ParseSparseLiteral(Literal* literal, const Shape& shape);
template <typename LiteralNativeT>
- bool ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
- const Shape& shape);
+ bool ParseSparseLiteralHelper(Literal* literal, const Shape& shape);
// Sets the sub-value of literal at the given index to the given value. The
// literal's shape must have the default layout.
@@ -138,6 +167,7 @@ class HloParser {
kFusionKind,
kDistribution,
kDomain,
+ kPrecisionList,
};
struct AttrConfig {
@@ -203,6 +233,7 @@ class HloParser {
bool ParseWindowPad(std::vector<std::vector<tensorflow::int64>>* pad);
bool ParseSliceRanges(SliceRanges* result);
+ bool ParsePrecisionList(std::vector<PrecisionConfig::Precision>* result);
bool ParseInt64List(const TokKind start, const TokKind end,
const TokKind delim,
std::vector<tensorflow::int64>* result);
@@ -221,6 +252,7 @@ class HloParser {
bool ParseFftType(FftType* result);
bool ParseFusionKind(HloInstruction::FusionKind* result);
bool ParseRandomDistribution(RandomDistribution* result);
+ bool ParsePrecision(PrecisionConfig::Precision* result);
bool ParseInt64(tensorflow::int64* result);
bool ParseDouble(double* result);
bool ParseBool(bool* result);
@@ -233,8 +265,8 @@ class HloParser {
bool CanBeParamListToShape();
// Logs the current parsing line and the given message. Always returns false.
- bool TokenError(StringPiece msg);
- bool Error(LocTy loc, StringPiece msg);
+ bool TokenError(absl::string_view msg);
+ bool Error(LocTy loc, absl::string_view msg);
// If the current token is 'kind', eats it (i.e. lexes the next token) and
// returns true.
@@ -265,24 +297,55 @@ class HloParser {
std::vector<std::unique_ptr<HloComputation>> computations_;
const HloModuleConfig config_;
std::vector<string> error_;
+
+ // Function that gets invoked when we try to resolve an instruction in
+ // instruction_pool_ but fail to do so.
+ std::function<std::pair<HloInstruction*, LocTy>*(string,
+ const optional<Shape>&)>
+ missing_instruction_hook_;
};
-bool HloParser::Error(LocTy loc, StringPiece msg) {
+bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) {
+ for (const auto& split : absl::StrSplit(s, delim)) {
+ int64 val;
+ if (!absl::SimpleAtoi(split, &val)) {
+ return false;
+ }
+ out->push_back(val);
+ }
+ return true;
+}
+
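A quick sketch of the helper above; the expected contents are inferred from the implementation:

    std::vector<int64> dims;
    if (SplitToInt64s("2,4,8", ',', &dims)) {
      // dims == {2, 4, 8}; a non-numeric field would have returned false.
    }
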
+// Creates replica groups from the provided nested array. groups[i] represents
+// the replica ids for group 'i'.
+std::vector<ReplicaGroup> CreateReplicaGroups(
+ absl::Span<const std::vector<int64>> groups) {
+ std::vector<ReplicaGroup> replica_groups;
+ absl::c_transform(groups, std::back_inserter(replica_groups),
+ [](const std::vector<int64>& ids) {
+ ReplicaGroup group;
+ *group.mutable_replica_ids() = {ids.begin(), ids.end()};
+ return group;
+ });
+ return replica_groups;
+}
+
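Similarly for CreateReplicaGroups, assuming two groups of two replicas each:

    std::vector<ReplicaGroup> groups = CreateReplicaGroups({{0, 1}, {2, 3}});
    // groups[0] holds replica ids {0, 1}; groups[1] holds {2, 3}.
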
+bool HloParser::Error(LocTy loc, absl::string_view msg) {
auto line_col = lexer_.GetLineAndColumn(loc);
const unsigned line = line_col.first;
const unsigned col = line_col.second;
std::vector<string> error_lines;
error_lines.push_back(
StrCat("was parsing ", line, ":", col, ": error: ", msg));
- error_lines.push_back(std::string(lexer_.GetLine(loc)));
+ error_lines.emplace_back(lexer_.GetLine(loc));
error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^"));
- error_.push_back(Join(error_lines, "\n"));
+ error_.push_back(StrJoin(error_lines, "\n"));
VLOG(1) << "Error: " << error_.back();
return false;
}
-bool HloParser::TokenError(StringPiece msg) {
+bool HloParser::TokenError(absl::string_view msg) {
return Error(lexer_.GetLoc(), msg);
}
@@ -291,6 +354,17 @@ bool HloParser::Run() {
return ParseHloModule();
}
+std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction(
+ const string& name, const optional<Shape>& shape) {
+ std::pair<HloInstruction*, LocTy>* instr =
+ tensorflow::gtl::FindOrNull(instruction_pool_, name);
+ // Potentially call the missing instruction hook.
+ if (instr == nullptr && missing_instruction_hook_ != nullptr) {
+ return missing_instruction_hook_(name, shape);
+ }
+ return instr;
+}
+
// ::= 'HloModule' name computations
bool HloParser::ParseHloModule() {
if (lexer_.GetKind() != TokKind::kw_HloModule) {
@@ -304,9 +378,25 @@ bool HloParser::ParseHloModule() {
return false;
}
- module_ = MakeUnique<HloModule>(name, config_);
+ absl::optional<bool> is_scheduled;
+ std::unordered_map<string, AttrConfig> attrs;
+ attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled};
+ if (!ParseAttributes(attrs)) {
+ return false;
+ }
+
+ module_ = absl::make_unique<HloModule>(name, config_);
+
+ if (!ParseComputations()) {
+ return false;
+ }
+
+ if (is_scheduled.has_value() && *is_scheduled) {
+ TF_CHECK_OK(
+ module_->set_schedule(ScheduleFromInstructionOrder(module_.get())));
+ }
- return ParseComputations();
+ return true;
}
// computations ::= (computation)+
@@ -357,7 +447,7 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) {
if (!ParseName(&name)) {
return false;
}
- auto builder = MakeUnique<HloComputation::Builder>(name);
+ auto builder = absl::make_unique<HloComputation::Builder>(name);
LocTy shape_loc = nullptr;
Shape shape;
@@ -370,8 +460,7 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) {
return false;
}
- std::pair<HloInstruction*, LocTy>* root_node =
- tensorflow::gtl::FindOrNull(instruction_pool_, root_name);
+ std::pair<HloInstruction*, LocTy>* root_node = FindInstruction(root_name);
// This means some instruction was marked as ROOT but we didn't find it in the
// pool, which should not happen.
if (!root_name.empty() && root_node == nullptr) {
@@ -485,7 +574,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kConstant: {
- std::unique_ptr<Literal> literal;
+ Literal literal;
if (!ParseToken(TokKind::kLparen,
"expects '(' before constant literal") ||
!ParseLiteral(&literal, shape) ||
@@ -498,11 +587,15 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kIota: {
+ optional<tensorflow::int64> iota_dimension;
+ attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64,
+ &iota_dimension};
if (!ParseOperands(&operands, /*expected_size=*/0) ||
!ParseAttributes(attrs)) {
return false;
}
- instruction = builder->AddInstruction(HloInstruction::CreateIota(shape));
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateIota(shape, *iota_dimension));
break;
}
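
With iota_dimension now required, the textual form the parser accepts looks roughly like this (shape and name are placeholders):

    %iota = s32[4,8]{1,0} iota(), iota_dimension=0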
// Unary ops.
@@ -597,31 +690,29 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kCrossReplicaSum: {
+ optional<std::vector<std::vector<int64>>> tmp_groups;
optional<HloComputation*> to_apply;
optional<std::vector<int64>> replica_group_ids;
optional<string> barrier;
optional<int64> all_reduce_id;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
- attrs["replica_group_ids"] = {
- /*required=*/false, AttrTy::kBracedInt64List, &replica_group_ids};
+ attrs["replica_groups"] = {/*required=*/false,
+ AttrTy::kBracedInt64ListList, &tmp_groups};
attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier};
attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64,
&all_reduce_id};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
- if (replica_group_ids) {
- instruction =
- builder->AddInstruction(HloInstruction::CreateCrossReplicaSum(
- shape, operands, *to_apply, *replica_group_ids,
- barrier ? *barrier : "", all_reduce_id));
- } else {
- instruction =
- builder->AddInstruction(HloInstruction::CreateCrossReplicaSum(
- shape, operands, *to_apply, {}, barrier ? *barrier : "",
- all_reduce_id));
+ std::vector<ReplicaGroup> replica_groups;
+ if (tmp_groups) {
+ replica_groups = CreateReplicaGroups(*tmp_groups);
}
+ instruction =
+ builder->AddInstruction(HloInstruction::CreateCrossReplicaSum(
+ shape, operands, *to_apply, replica_groups,
+ barrier ? *barrier : "", all_reduce_id));
break;
}
case HloOpcode::kAllToAll: {
@@ -629,21 +720,36 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
optional<string> barrier;
attrs["replica_groups"] = {/*required=*/false,
AttrTy::kBracedInt64ListList, &tmp_groups};
- attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
std::vector<ReplicaGroup> replica_groups;
if (tmp_groups) {
- c_transform(*tmp_groups, std::back_inserter(replica_groups),
- [](const std::vector<int64>& ids) {
- ReplicaGroup group;
- *group.mutable_replica_ids() = {ids.begin(), ids.end()};
- return group;
- });
+ replica_groups = CreateReplicaGroups(*tmp_groups);
+ }
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateAllToAll(shape, operands, replica_groups));
+ break;
+ }
+ case HloOpcode::kCollectivePermute: {
+ optional<std::vector<std::vector<int64>>> source_targets;
+ attrs["source_target_pairs"] = {
+ /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets};
+ if (!ParseOperands(&operands, /*expected_size=*/1) ||
+ !ParseAttributes(attrs)) {
+ return false;
}
- instruction = builder->AddInstruction(HloInstruction::CreateAllToAll(
- shape, operands, replica_groups, barrier ? *barrier : ""));
+ std::vector<std::pair<int64, int64>> pairs(source_targets->size());
+ for (int i = 0; i < pairs.size(); i++) {
+ if ((*source_targets)[i].size() != 2) {
+ return TokenError(
+ "expects 'source_target_pairs=' to be a list of pairs");
+ }
+ pairs[i].first = (*source_targets)[i][0];
+ pairs[i].second = (*source_targets)[i][1];
+ }
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateCollectivePermute(shape, operands[0], pairs));
break;
}
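
A sketch of the text this new case parses (names, shape, and the pair list are placeholders):

    %permute = f32[128]{0} collective-permute(f32[128]{0} %x),
        source_target_pairs={{0,1},{1,2},{2,0}}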
case HloOpcode::kReshape: {
@@ -831,6 +937,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
AttrTy::kConvolutionDimensionNumbers, &dnums};
attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
&feature_group_count};
+ optional<std::vector<PrecisionConfig::Precision>> operand_precision;
+ attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
+ &operand_precision};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
@@ -841,9 +950,17 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (!feature_group_count) {
feature_group_count = 1;
}
+ PrecisionConfig precision_config;
+ if (operand_precision) {
+ *precision_config.mutable_operand_precision() = {
+ operand_precision->begin(), operand_precision->end()};
+ } else {
+ precision_config.mutable_operand_precision()->Resize(
+ operands.size(), PrecisionConfig::DEFAULT);
+ }
instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
- shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums,
- feature_group_count.value()));
+ shape, /*lhs=*/operands[0], /*rhs=*/operands[1],
+ feature_group_count.value(), *window, *dnums, precision_config));
break;
}
case HloOpcode::kFft: {
@@ -916,11 +1033,11 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
instruction = builder->AddInstruction(HloInstruction::CreateReduce(
shape, /*operands=*/
- tensorflow::gtl::ArraySlice<HloInstruction*>(operands, 0,
- operands.size() / 2),
+ absl::Span<HloInstruction* const>(operands).subspan(
+ 0, operands.size() / 2),
/*init_values=*/
- tensorflow::gtl::ArraySlice<HloInstruction*>(
- operands, operands.size() / 2, operands.size()),
+ absl::Span<HloInstruction* const>(operands).subspan(
+ operands.size() / 2, operands.size()),
*dimensions_to_reduce, *reduce_computation));
break;
}
@@ -1159,11 +1276,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
optional<string> custom_call_target;
optional<Window> window;
optional<ConvolutionDimensionNumbers> dnums;
+ optional<int64> feature_group_count;
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
&custom_call_target};
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
attrs["dim_labels"] = {/*required=*/false,
AttrTy::kConvolutionDimensionNumbers, &dnums};
+ attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
+ &feature_group_count};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
@@ -1175,20 +1295,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (dnums.has_value()) {
instruction->set_convolution_dimension_numbers(*dnums);
}
- break;
- }
- case HloOpcode::kHostCompute: {
- optional<string> channel_name;
- optional<tensorflow::int64> cost_estimate_ns;
- attrs["channel_name"] = {/*required=*/true, AttrTy::kString,
- &channel_name};
- attrs["cost_estimate_ns"] = {/*required=*/true, AttrTy::kInt64,
- &cost_estimate_ns};
- if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
- return false;
+ if (feature_group_count.has_value()) {
+ instruction->set_feature_group_count(*feature_group_count);
}
- instruction = builder->AddInstruction(HloInstruction::CreateHostCompute(
- shape, operands, *channel_name, *cost_estimate_ns));
break;
}
case HloOpcode::kDot: {
@@ -1204,6 +1313,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
optional<std::vector<tensorflow::int64>> rhs_batch_dims;
attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
&rhs_batch_dims};
+ optional<std::vector<PrecisionConfig::Precision>> operand_precision;
+ attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
+ &operand_precision};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
@@ -1228,27 +1340,35 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
rhs_batch_dims->end()};
}
- instruction = builder->AddInstruction(
- HloInstruction::CreateDot(shape, operands[0], operands[1], dnum));
+ PrecisionConfig precision_config;
+ if (operand_precision) {
+ *precision_config.mutable_operand_precision() = {
+ operand_precision->begin(), operand_precision->end()};
+ } else {
+ precision_config.mutable_operand_precision()->Resize(
+ operands.size(), PrecisionConfig::DEFAULT);
+ }
+
+ instruction = builder->AddInstruction(HloInstruction::CreateDot(
+ shape, operands[0], operands[1], dnum, precision_config));
break;
}
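
The optional operand_precision list feeds the new PrecisionConfig argument; a sketch of the textual form (the lowercase token spellings mirror PrecisionConfig::Precision and are an assumption here):

    %dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %a, f32[2,2]{1,0} %b),
        lhs_contracting_dims={1}, rhs_contracting_dims={0},
        operand_precision={highest,default}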
case HloOpcode::kGather: {
- optional<std::vector<tensorflow::int64>> output_window_dims;
- attrs["output_window_dims"] = {
- /*required=*/true, AttrTy::kBracedInt64List, &output_window_dims};
- optional<std::vector<tensorflow::int64>> elided_window_dims;
- attrs["elided_window_dims"] = {
- /*required=*/true, AttrTy::kBracedInt64List, &elided_window_dims};
- optional<std::vector<tensorflow::int64>> gather_dims_to_operand_dims;
- attrs["gather_dims_to_operand_dims"] = {/*required=*/true,
- AttrTy::kBracedInt64List,
- &gather_dims_to_operand_dims};
+ optional<std::vector<tensorflow::int64>> offset_dims;
+ attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List,
+ &offset_dims};
+ optional<std::vector<tensorflow::int64>> collapsed_slice_dims;
+ attrs["collapsed_slice_dims"] = {
+ /*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims};
+ optional<std::vector<tensorflow::int64>> start_index_map;
+ attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List,
+ &start_index_map};
optional<tensorflow::int64> index_vector_dim;
attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
&index_vector_dim};
- optional<std::vector<tensorflow::int64>> window_bounds;
- attrs["window_bounds"] = {/*required=*/true, AttrTy::kBracedInt64List,
- &window_bounds};
+ optional<std::vector<tensorflow::int64>> slice_sizes;
+ attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List,
+ &slice_sizes};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
@@ -1257,14 +1377,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
GatherDimensionNumbers dim_numbers =
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/*output_window_dims,
- /*elided_window_dims=*/*elided_window_dims,
- /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims,
+ /*offset_dims=*/*offset_dims,
+ /*collapsed_slice_dims=*/*collapsed_slice_dims,
+ /*start_index_map=*/*start_index_map,
/*index_vector_dim=*/*index_vector_dim);
instruction = builder->AddInstruction(HloInstruction::CreateGather(
- shape, /*operand=*/operands[0], /*gather_indices=*/operands[1],
- dim_numbers, *window_bounds));
+ shape, /*operand=*/operands[0], /*start_indices=*/operands[1],
+ dim_numbers, *slice_sizes));
break;
}
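
Under the renamed gather attributes, a well-formed instruction reads roughly as follows (shapes are placeholders, chosen so the output agrees with slice_sizes minus the collapsed dimension):

    %gather = f32[3,5]{1,0} gather(f32[5,5]{1,0} %operand, s32[3,1]{1,0} %indices),
        offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0},
        index_vector_dim=1, slice_sizes={1,5}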
case HloOpcode::kScatter: {
@@ -1510,14 +1630,14 @@ bool HloParser::ParseDomain(DomainData* domain) {
return false;
}
if (*kind == ShardingMetadata::KindName()) {
- auto entry_sharding_ptr = MakeUnique<HloSharding>(
+ auto entry_sharding_ptr = absl::make_unique<HloSharding>(
HloSharding::FromProto(*entry_sharding).ValueOrDie());
- auto exit_sharding_ptr = MakeUnique<HloSharding>(
+ auto exit_sharding_ptr = absl::make_unique<HloSharding>(
HloSharding::FromProto(*exit_sharding).ValueOrDie());
domain->entry_metadata =
- MakeUnique<ShardingMetadata>(std::move(entry_sharding_ptr));
+ absl::make_unique<ShardingMetadata>(std::move(entry_sharding_ptr));
domain->exit_metadata =
- MakeUnique<ShardingMetadata>(std::move(exit_sharding_ptr));
+ absl::make_unique<ShardingMetadata>(std::move(exit_sharding_ptr));
} else {
return TokenError(StrCat("unsupported domain kind: ", *kind));
}
@@ -1537,11 +1657,9 @@ bool HloParser::ParseInstructionNames(
if (!ParseName(&name)) {
return Error(loc, "expects a instruction name");
}
- std::pair<HloInstruction*, LocTy>* instr =
- tensorflow::gtl::FindOrNull(instruction_pool_, name);
+ std::pair<HloInstruction*, LocTy>* instr = FindInstruction(name);
if (!instr) {
- return TokenError(
- Printf("instruction '%s' is not defined", name.c_str()));
+ return TokenError(StrFormat("instruction '%s' is not defined", name));
}
instructions->push_back(instr->first);
} while (EatIfPresent(TokKind::kComma));
@@ -1689,8 +1807,7 @@ bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) {
// literal
// ::= tuple
// ::= non_tuple
-bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) {
return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape)
: ParseNonTupleLiteral(literal, shape);
}
@@ -1700,8 +1817,7 @@ bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
// literal_list
// ::= /*empty*/
// ::= literal (',' literal)*
-bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) {
if (!EatShapeAndCheckCompatible(shape)) {
return TokenError(StrCat("expects tuple constant in shape ",
ShapeUtil::HumanString(shape)));
@@ -1709,8 +1825,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
return false;
}
- std::vector<std::unique_ptr<Literal>> elements(
- ShapeUtil::TupleElementCount(shape));
+ std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape));
if (lexer_.GetKind() == TokKind::kRparen) {
// empty
@@ -1736,8 +1851,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
// ::= rank01
// ::= rank2345
// rank2345 ::= shape sparse_or_nested_array
-bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
if (LayoutUtil::IsSparseArray(shape)) {
return ParseSparseLiteral(literal, shape);
}
@@ -1746,8 +1860,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
return ParseDenseLiteral(literal, shape);
}
-bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) {
const tensorflow::int64 rank = ShapeUtil::Rank(shape);
if (rank > 1 && !EatShapeAndCheckCompatible(shape)) {
return false;
@@ -1770,10 +1883,10 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
std::vector<tensorflow::int64> elems_seen_until_dim(
elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim);
return StrCat("[",
- Join(elems_seen_until_dim, ",",
- [](string* out, const tensorflow::int64& num_elems) {
- StrAppend(out, num_elems - 1);
- }),
+ StrJoin(elems_seen_until_dim, ",",
+ [](string* out, const tensorflow::int64& num_elems) {
+ StrAppend(out, num_elems - 1);
+ }),
"]");
};
do {
@@ -1783,17 +1896,17 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
case TokKind::kLbrace: {
nest_level++;
if (nest_level > rank) {
- return TokenError(Printf(
- "expects nested array in rank %lld, but sees larger", rank));
+ return TokenError(absl::StrFormat(
+ "expects nested array in rank %d, but sees larger", rank));
}
if (nest_level > 1) {
elems_seen_per_dim[nest_level - 2]++;
if (elems_seen_per_dim[nest_level - 2] >
shape.dimensions(nest_level - 2)) {
- return TokenError(Printf(
- "expects %lld elements in the %sth element, but sees more",
+ return TokenError(absl::StrFormat(
+ "expects %d elements in the %sth element, but sees more",
shape.dimensions(nest_level - 2),
- get_index_str(nest_level - 2).c_str()));
+ get_index_str(nest_level - 2)));
}
}
lexer_.Lex();
@@ -1802,9 +1915,9 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
case TokKind::kRbrace: {
nest_level--;
if (elems_seen_per_dim[nest_level] != shape.dimensions(nest_level)) {
- return TokenError(Printf(
- "expects %lld elements in the %sth element, but sees %lld",
- shape.dimensions(nest_level), get_index_str(nest_level).c_str(),
+ return TokenError(absl::StrFormat(
+ "expects %d elements in the %sth element, but sees %d",
+ shape.dimensions(nest_level), get_index_str(nest_level),
elems_seen_per_dim[nest_level]));
}
elems_seen_per_dim[nest_level] = 0;
@@ -1825,15 +1938,15 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
if (rank > 0) {
if (nest_level != rank) {
return TokenError(
- Printf("expects nested array in rank %lld, but sees %lld", rank,
- nest_level));
+ absl::StrFormat("expects nested array in rank %d, but sees %d",
+ rank, nest_level));
}
elems_seen_per_dim[rank - 1]++;
if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) {
- return TokenError(
- Printf("expects %lld elements on the minor-most dimension, but "
- "sees more",
- shape.dimensions(rank - 1)));
+ return TokenError(absl::StrFormat(
+ "expects %d elements on the minor-most dimension, but "
+ "sees more",
+ shape.dimensions(rank - 1)));
}
}
if (lexer_.GetKind() == TokKind::kw_true ||
@@ -1841,7 +1954,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
// TODO(congliu): bool type literals with rank >= 1 are actually
// printed in a compact form instead of "true" or "false". Fix that.
if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true,
- linear_index++, literal->get())) {
+ linear_index++, literal)) {
return false;
}
lexer_.Lex();
@@ -1852,7 +1965,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
return Error(loc, StrCat("expects integer for primitive type: ",
PrimitiveType_Name(shape.element_type())));
}
- if (!SetValueInLiteral(value, linear_index++, literal->get())) {
+ if (!SetValueInLiteral(value, linear_index++, literal)) {
return false;
}
} else if (primitive_util::IsFloatingPointType(shape.element_type())) {
@@ -1863,7 +1976,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
loc, StrCat("expect floating point value for primitive type: ",
PrimitiveType_Name(shape.element_type())));
}
- if (!SetValueInLiteral(value, linear_index++, literal->get())) {
+ if (!SetValueInLiteral(value, linear_index++, literal)) {
return false;
}
} else {
@@ -1875,12 +1988,11 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
} // end of switch
} while (nest_level > 0);
- *literal = (*literal)->Relayout(shape.layout());
+ *literal = literal->Relayout(shape.layout());
return true;
}
-bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) {
if (!EatShapeAndCheckCompatible(shape)) {
return false;
}
@@ -1920,13 +2032,12 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
}
template <typename LiteralNativeT>
-bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
- const Shape& shape) {
+bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) {
std::vector<tensorflow::int64> index;
tensorflow::int64 rank = ShapeUtil::Rank(shape);
- *literal = MakeUnique<Literal>(shape);
+ *literal = Literal(shape);
if (!ParseToken(TokKind::kLbrace,
"expects '{' at the beginning of a sparse literal")) {
@@ -1960,7 +2071,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
return Error(
index_loc,
StrCat("invalid multi-dimension index for shape with rank ", rank,
- ": [", Join(index, ", "), "]"));
+ ": [", StrJoin(index, ", "), "]"));
}
}
if (!ParseToken(TokKind::kColon,
@@ -2000,7 +2111,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
return false;
}
- if ((*literal)->sparse_element_count() + 1 ==
+ if (literal->sparse_element_count() + 1 ==
LayoutUtil::MaxSparseElements(shape.layout())) {
return Error(
lexer_.GetLoc(),
@@ -2008,10 +2119,10 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
ShapeUtil::HumanStringWithLayout(shape)));
}
- (*literal)->AppendSparseElement(index, value);
+ literal->AppendSparseElement(index, value);
}
- (*literal)->SortSparseElements();
+ literal->SortSparseElements();
return true;
}
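
[Editorial sketch, not part of the commit: the signature change from std::unique_ptr<Literal>* to Literal* follows Literal's move to value semantics. Assuming the surrounding API, a caller inside the parser now looks roughly like:]

  // Value-semantics calling convention (illustrative):
  Literal literal;                           // stack value, no MakeUnique<Literal>
  if (!ParseSparseLiteral(&literal, shape)) return false;
  literal.SortSparseElements();              // direct member calls, no (*literal)->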
@@ -2021,6 +2132,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
// ::= operand (, operand)*
// operand ::= (shape)? name
bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
+ CHECK(operands != nullptr);
if (!ParseToken(TokKind::kLparen,
"expects '(' at the beginning of operands")) {
return false;
@@ -2031,9 +2143,10 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
do {
LocTy loc = lexer_.GetLoc();
string name;
+ optional<Shape> shape;
if (CanBeShape()) {
- Shape shape;
- if (!ParseShape(&shape)) {
+ shape.emplace();
+ if (!ParseShape(&shape.value())) {
return false;
}
}
@@ -2041,8 +2154,8 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
return false;
}
std::pair<HloInstruction*, LocTy>* instruction =
- tensorflow::gtl::FindOrNull(instruction_pool_, name);
- if (!instruction) {
+ FindInstruction(name, shape);
+ if (instruction == nullptr) {
return Error(loc, StrCat("instruction does not exist: ", name));
}
operands->push_back(instruction->first);
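
[Editorial note: FindInstruction(name, shape) replaces the direct instruction_pool_ lookup; judging from ParseSingleInstruction further below, it falls back to missing_instruction_hook_ for unknown names. A hedged sketch of that assumed behavior:]

  // Assumed shape of FindInstruction (inferred from ParseSingleInstruction):
  std::pair<HloInstruction*, LocTy>* instruction =
      tensorflow::gtl::FindOrNull(instruction_pool_, name);
  if (instruction == nullptr && missing_instruction_hook_ != nullptr) {
    instruction = missing_instruction_hook_(name, shape);  // may synthesize a parameter
  }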
@@ -2053,6 +2166,7 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands,
const int expected_size) {
+ CHECK(operands != nullptr);
LocTy loc = lexer_.GetLoc();
if (!ParseOperands(operands)) {
return false;
@@ -2086,8 +2200,8 @@ bool HloParser::ParseSubAttributes(
for (const auto& attr_it : attrs) {
if (attr_it.second.required &&
seen_attrs.find(attr_it.first) == seen_attrs.end()) {
- return Error(loc, Printf("sub-attribute %s is expected but not seen",
- attr_it.first.c_str()));
+ return Error(loc, StrFormat("sub-attribute %s is expected but not seen",
+ attr_it.first));
}
}
return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes");
@@ -2107,8 +2221,8 @@ bool HloParser::ParseAttributes(
for (const auto& attr_it : attrs) {
if (attr_it.second.required &&
seen_attrs.find(attr_it.first) == seen_attrs.end()) {
- return Error(loc, Printf("attribute %s is expected but not seen",
- attr_it.first.c_str()));
+ return Error(loc, StrFormat("attribute %s is expected but not seen",
+ attr_it.first));
}
}
return true;
@@ -2124,7 +2238,7 @@ bool HloParser::ParseAttributeHelper(
}
VLOG(1) << "Parsing attribute " << name;
if (!seen_attrs->insert(name).second) {
- return Error(loc, Printf("attribute %s already exists", name.c_str()));
+ return Error(loc, StrFormat("attribute %s already exists", name));
}
auto attr_it = attrs.find(name);
if (attr_it == attrs.end()) {
@@ -2134,13 +2248,13 @@ bool HloParser::ParseAttributeHelper(
} else {
allowed_attrs = StrCat(
"Allowed attributes: ",
- Join(attrs, ", ",
- [&](string* out, const std::pair<string, AttrConfig>& kv) {
- StrAppend(out, kv.first);
- }));
+ StrJoin(attrs, ", ",
+ [&](string* out, const std::pair<string, AttrConfig>& kv) {
+ StrAppend(out, kv.first);
+ }));
}
- return Error(loc, Printf("unexpected attribute \"%s\". %s", name.c_str(),
- allowed_attrs.c_str()));
+ return Error(loc, StrFormat("unexpected attribute \"%s\". %s", name,
+ allowed_attrs));
}
AttrTy attr_type = attr_it->second.attr_type;
void* attr_out_ptr = attr_it->second.result;
@@ -2322,10 +2436,20 @@ bool HloParser::ParseAttributeHelper(
case AttrTy::kDomain: {
return ParseDomain(static_cast<DomainData*>(attr_out_ptr));
}
+ case AttrTy::kPrecisionList: {
+ std::vector<PrecisionConfig::Precision> result;
+ if (!ParsePrecisionList(&result)) {
+ return false;
+ }
+ static_cast<optional<std::vector<PrecisionConfig::Precision>>*>(
+ attr_out_ptr)
+ ->emplace(result);
+ return true;
+ }
}
}();
if (!success) {
- return Error(loc, Printf("error parsing attribute %s", name.c_str()));
+ return Error(loc, StrFormat("error parsing attribute %s", name));
}
return true;
}
@@ -2440,20 +2564,24 @@ bool HloParser::ParseConvolutionDimensionNumbers(
}
string str = lexer_.GetStrVal();
- // The str is expected to have 3 items, lhs, rhs, out, and it must looks like
+  // The str is expected to have 3 items (lhs, rhs, out), and it must look like
// lhs_rhs->out, that is, the first separator is "_" and the second is "->".
- // So we replace the "->" with "_" and then split on "_".
- str = tensorflow::str_util::StringReplace(str, /*oldsub=*/"->",
- /*newsub=*/"_",
- /*replace_all=*/false);
- std::vector<string> lhs_rhs_out = Split(str, "_");
- if (lhs_rhs_out.size() != 3) {
+ std::vector<string> split1 = absl::StrSplit(str, "_");
+ if (split1.size() != 2) {
+ LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
+ << str;
+ }
+ std::vector<string> split2 = absl::StrSplit(split1[1], "->");
+ if (split2.size() != 2) {
LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
<< str;
}
+ absl::string_view lhs = split1[0];
+ absl::string_view rhs = split2[0];
+ absl::string_view out = split2[1];
- const tensorflow::int64 rank = lhs_rhs_out[0].length();
- if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) {
+ const tensorflow::int64 rank = lhs.length();
+ if (rank != rhs.length() || rank != out.length()) {
return TokenError(
"convolution lhs, rhs, and output must have the same rank");
}
@@ -2468,8 +2596,7 @@ bool HloParser::ParseConvolutionDimensionNumbers(
// lhs
{
- const string& lhs = lhs_rhs_out[0];
- if (!is_unique(lhs)) {
+ if (!is_unique(string(lhs))) {
return TokenError(
StrCat("expects unique lhs dimension numbers, but sees ", lhs));
}
@@ -2486,14 +2613,13 @@ bool HloParser::ParseConvolutionDimensionNumbers(
dnums->set_input_spatial_dimensions(c - '0', i);
} else {
return TokenError(
- Printf("expects [0-%lldbf] in lhs dimension numbers", rank - 1));
+ StrFormat("expects [0-%dbf] in lhs dimension numbers", rank - 1));
}
}
}
// rhs
{
- const string& rhs = lhs_rhs_out[1];
- if (!is_unique(rhs)) {
+ if (!is_unique(string(rhs))) {
return TokenError(
StrCat("expects unique rhs dimension numbers, but sees ", rhs));
}
@@ -2510,14 +2636,13 @@ bool HloParser::ParseConvolutionDimensionNumbers(
dnums->set_kernel_spatial_dimensions(c - '0', i);
} else {
return TokenError(
- Printf("expects [0-%lldio] in rhs dimension numbers", rank - 1));
+ StrFormat("expects [0-%dio] in rhs dimension numbers", rank - 1));
}
}
}
// output
{
- const string& out = lhs_rhs_out[2];
- if (!is_unique(out)) {
+ if (!is_unique(string(out))) {
return TokenError(
StrCat("expects unique output dimension numbers, but sees ", out));
}
@@ -2533,8 +2658,8 @@ bool HloParser::ParseConvolutionDimensionNumbers(
} else if (c < '0' + rank && c >= '0') {
dnums->set_output_spatial_dimensions(c - '0', i);
} else {
- return TokenError(
- Printf("expects [0-%lldbf] in output dimension numbers", rank - 1));
+ return TokenError(StrFormat(
+ "expects [0-%dbf] in output dimension numbers", rank - 1));
}
}
}
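
[For orientation (editorial, not in the diff): dim_labels strings such as "b0f_0io->b0f" use 'b'/'f' for batch and feature, 'i'/'o' for kernel input/output features, and digits for spatial dimensions, so the rank-1 example parses as:]

  // Illustrative expectations for dim_labels=b0f_0io->b0f:
  ConvolutionDimensionNumbers dnums =
      ParseConvolutionDimensionNumbers("b0f_0io->b0f").ValueOrDie();
  CHECK_EQ(dnums.input_batch_dimension(), 0);      // 'b' at index 0 of "b0f"
  CHECK_EQ(dnums.input_spatial_dimensions(0), 1);  // '0' at index 1
  CHECK_EQ(dnums.input_feature_dimension(), 2);    // 'f' at index 2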
@@ -2580,9 +2705,10 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) {
}
const auto& range = ranges.back();
if (range.size() != 2 && range.size() != 3) {
- return Error(loc, Printf("expects [start:limit:step] or [start:limit], "
- "but sees %ld elements.",
- range.size()));
+ return Error(loc,
+ StrFormat("expects [start:limit:step] or [start:limit], "
+ "but sees %d elements.",
+ range.size()));
}
} while (EatIfPresent(TokKind::kComma));
@@ -2594,6 +2720,24 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) {
return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
}
+// precisionlist ::= start precision_elements end
+// precision_elements
+// ::= /*empty*/
+// ::= precision_val (delim precision_val)*
+bool HloParser::ParsePrecisionList(
+ std::vector<PrecisionConfig::Precision>* result) {
+ auto parse_and_add_item = [&]() {
+ PrecisionConfig::Precision item;
+ if (!ParsePrecision(&item)) {
+ return false;
+ }
+ result->push_back(item);
+ return true;
+ };
+ return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
+ parse_and_add_item);
+}
+
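
[An editorial illustration of the text this grammar accepts, matching the operand_precision attribute exercised in the tests below:]

  // Parsing "{high,default}" (assuming StringToPrecision's usual mapping):
  std::vector<PrecisionConfig::Precision> precisions;
  if (!ParsePrecisionList(&precisions)) return false;
  // precisions now holds {PrecisionConfig::HIGH, PrecisionConfig::DEFAULT}.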
// int64list ::= start int64_elements end
// int64_elements
// ::= /*empty*/
@@ -2750,14 +2894,13 @@ bool HloParser::ParseDxD(const string& name,
std::vector<tensorflow::int64>* result) {
LocTy loc = lexer_.GetLoc();
if (!result->empty()) {
- return Error(loc,
- Printf("sub-attribute '%s=' already exists", name.c_str()));
+ return Error(loc, StrFormat("sub-attribute '%s=' already exists", name));
}
// 1D
if (lexer_.GetKind() == TokKind::kInt) {
tensorflow::int64 number;
if (!ParseInt64(&number)) {
- return Error(loc, Printf("expects sub-attribute '%s=i'", name.c_str()));
+ return Error(loc, StrFormat("expects sub-attribute '%s=i'", name));
}
result->push_back(number);
return true;
@@ -2765,9 +2908,8 @@ bool HloParser::ParseDxD(const string& name,
// 2D or higher.
if (lexer_.GetKind() == TokKind::kDxD) {
string str = lexer_.GetStrVal();
- if (!SplitAndParseAsInts(str, 'x', result)) {
- return Error(loc,
- Printf("expects sub-attribute '%s=ixj...'", name.c_str()));
+ if (!SplitToInt64s(str, 'x', result)) {
+ return Error(loc, StrFormat("expects sub-attribute '%s=ixj...'", name));
}
lexer_.Lex();
return true;
@@ -2785,10 +2927,9 @@ bool HloParser::ParseWindowPad(
return TokenError("expects window pad pattern, e.g., '0_0x3_3'");
}
string str = lexer_.GetStrVal();
- std::vector<string> padding_str = Split(str, 'x');
- for (int i = 0; i < padding_str.size(); i++) {
+ for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
std::vector<tensorflow::int64> low_high;
- if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) ||
+ if (!SplitToInt64s(padding_dim_str, '_', &low_high) ||
low_high.size() != 2) {
return Error(loc,
"expects padding_low and padding_high separated by '_'");
@@ -2809,10 +2950,9 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) {
}
LocTy loc = lexer_.GetLoc();
string str = lexer_.GetStrVal();
- std::vector<string> padding_str = Split(str, 'x');
- for (const auto& padding_dim_str : padding_str) {
+ for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
std::vector<tensorflow::int64> padding_dim;
- if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) ||
+ if (!SplitToInt64s(padding_dim_str, '_', &padding_dim) ||
(padding_dim.size() != 2 && padding_dim.size() != 3)) {
return Error(loc,
"expects padding config pattern like 'low_high_interior' or "
@@ -2864,9 +3004,8 @@ bool HloParser::ParseOpcode(HloOpcode* result) {
string val = lexer_.GetStrVal();
auto status_or_result = StringToHloOpcode(val);
if (!status_or_result.ok()) {
- return TokenError(
- Printf("expects opcode but sees: %s, error: %s", val.c_str(),
- status_or_result.status().error_message().c_str()));
+ return TokenError(StrFormat("expects opcode but sees: %s, error: %s", val,
+ status_or_result.status().error_message()));
}
*result = status_or_result.ValueOrDie();
lexer_.Lex();
@@ -2880,7 +3019,7 @@ bool HloParser::ParseFftType(FftType* result) {
}
string val = lexer_.GetStrVal();
if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) {
- return TokenError(Printf("expects fft type but sees: %s", val.c_str()));
+ return TokenError(StrFormat("expects fft type but sees: %s", val));
}
lexer_.Lex();
return true;
@@ -2894,9 +3033,9 @@ bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) {
string val = lexer_.GetStrVal();
auto status_or_result = StringToFusionKind(val);
if (!status_or_result.ok()) {
- return TokenError(
- Printf("expects fusion kind but sees: %s, error: %s", val.c_str(),
- status_or_result.status().error_message().c_str()));
+ return TokenError(StrFormat("expects fusion kind but sees: %s, error: %s",
+ val,
+ status_or_result.status().error_message()));
}
*result = status_or_result.ValueOrDie();
lexer_.Lex();
@@ -2912,8 +3051,25 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) {
auto status_or_result = StringToRandomDistribution(val);
if (!status_or_result.ok()) {
return TokenError(
- Printf("expects random distribution but sees: %s, error: %s",
- val.c_str(), status_or_result.status().error_message().c_str()));
+ StrFormat("expects random distribution but sees: %s, error: %s", val,
+ status_or_result.status().error_message()));
+ }
+ *result = status_or_result.ValueOrDie();
+ lexer_.Lex();
+ return true;
+}
+
+bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) {
+ VLOG(1) << "ParsePrecision";
+ if (lexer_.GetKind() != TokKind::kIdent) {
+ return TokenError("expects random distribution");
+ }
+ string val = lexer_.GetStrVal();
+ auto status_or_result = StringToPrecision(val);
+ if (!status_or_result.ok()) {
+ return TokenError(StrFormat("expects precision but sees: %s, error: %s",
+ val,
+ status_or_result.status().error_message()));
}
*result = status_or_result.ValueOrDie();
lexer_.Lex();
@@ -3007,7 +3163,7 @@ StatusOr<HloSharding> HloParser::ParseShardingOnly() {
lexer_.Lex();
OpSharding op_sharding;
if (!ParseSharding(&op_sharding)) {
- return InvalidArgument("Syntax error:\n%s", GetError().c_str());
+ return InvalidArgument("Syntax error:\n%s", GetError());
}
if (lexer_.GetKind() != TokKind::kEof) {
return InvalidArgument("Syntax error:\nExtra content after sharding");
@@ -3019,7 +3175,7 @@ StatusOr<Window> HloParser::ParseWindowOnly() {
lexer_.Lex();
Window window;
if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) {
- return InvalidArgument("Syntax error:\n%s", GetError().c_str());
+ return InvalidArgument("Syntax error:\n%s", GetError());
}
if (lexer_.GetKind() != TokKind::kEof) {
return InvalidArgument("Syntax error:\nExtra content after window");
@@ -3032,7 +3188,7 @@ HloParser::ParseConvolutionDimensionNumbersOnly() {
lexer_.Lex();
ConvolutionDimensionNumbers dnums;
if (!ParseConvolutionDimensionNumbers(&dnums)) {
- return InvalidArgument("Syntax error:\n%s", GetError().c_str());
+ return InvalidArgument("Syntax error:\n%s", GetError());
}
if (lexer_.GetKind() != TokKind::kEof) {
return InvalidArgument(
@@ -3041,40 +3197,104 @@ HloParser::ParseConvolutionDimensionNumbersOnly() {
return dnums;
}
+StatusOr<PaddingConfig> HloParser::ParsePaddingConfigOnly() {
+ lexer_.Lex();
+ PaddingConfig padding_config;
+ if (!ParsePaddingConfig(&padding_config)) {
+ return InvalidArgument("Syntax error:\n%s", GetError());
+ }
+ if (lexer_.GetKind() != TokKind::kEof) {
+ return InvalidArgument("Syntax error:\nExtra content after PaddingConfig");
+ }
+ return padding_config;
+}
+
+Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder,
+ string* root_name) {
+ TF_RET_CHECK(missing_instruction_hook_ == nullptr);
+
+ // The missing instruction hook we register creates the shaped instruction on
+ // the fly as a parameter and returns it.
+ int64 parameter_count = 0;
+ missing_instruction_hook_ =
+ [this, builder, &parameter_count](
+ string name,
+ const optional<Shape>& shape) -> std::pair<HloInstruction*, LocTy>* {
+ if (!shape.has_value()) {
+ Error(lexer_.GetLoc(),
+ StrCat("Operand ", name,
+ " had no shape in HLO text; cannot create parameter for "
+ "single-instruction module."));
+ return nullptr;
+ }
+ HloInstruction* parameter = builder->AddInstruction(
+ HloInstruction::CreateParameter(parameter_count++, *shape, name));
+ instruction_pool_[name] = {parameter, lexer_.GetLoc()};
+ return tensorflow::gtl::FindOrNull(instruction_pool_, name);
+ };
+
+ // Prime the lexer.
+ lexer_.Lex();
+
+ // Parse the instruction with the registered hook.
+ if (!ParseInstruction(builder, root_name)) {
+ return InvalidArgument("Syntax error:\n%s", GetError());
+ }
+ return Status::OK();
+}
+
} // namespace
StatusOr<std::unique_ptr<HloModule>> ParseHloString(
- tensorflow::StringPiece str, const HloModuleConfig& config) {
+ absl::string_view str, const HloModuleConfig& config) {
HloParser parser(str, config);
if (!parser.Run()) {
- return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str());
+ return InvalidArgument("Syntax error:\n%s", parser.GetError());
}
return parser.ConsumeHloModule();
}
-StatusOr<std::unique_ptr<HloModule>> ParseHloString(
- tensorflow::StringPiece str) {
+StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str) {
HloModuleConfig config;
return ParseHloString(str, config);
}
-StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str) {
+StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
+ absl::string_view str, absl::string_view name) {
+ HloModuleConfig config;
+ HloParser parser(str, config);
+ auto builder = absl::make_unique<HloComputation::Builder>(string(name));
+ string root_name;
+ TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name));
+ std::unique_ptr<HloComputation> computation = builder->Build();
+ auto module = absl::make_unique<HloModule>(string(name), config);
+ module->AddEntryComputation(std::move(computation));
+ return std::move(module);
+}
+
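
[A usage sketch for the new entry point; the instruction text is illustrative, and the SingleOp test below checks the same behavior through opcode matchers:]

  auto module_or = ParseHloOpToModule(
      "%add = f32[2,4]{1,0} add(f32[2,4]{1,0} %x, f32[2,4]{1,0} %y)");
  if (module_or.ok()) {
    // %x and %y were synthesized as parameter(0)/parameter(1) by the
    // missing-instruction hook; the add is the entry computation's root.
    HloComputation* entry = module_or.ValueOrDie()->entry_computation();
    (void)entry;
  }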
+StatusOr<HloSharding> ParseSharding(absl::string_view str) {
HloModuleConfig config;
HloParser parser(str, config);
return parser.ParseShardingOnly();
}
-StatusOr<Window> ParseWindow(tensorflow::StringPiece str) {
+StatusOr<Window> ParseWindow(absl::string_view str) {
HloModuleConfig config;
HloParser parser(str, config);
return parser.ParseWindowOnly();
}
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
- tensorflow::StringPiece str) {
+ absl::string_view str) {
HloModuleConfig config;
HloParser parser(str, config);
return parser.ParseConvolutionDimensionNumbersOnly();
}
+StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str) {
+ HloModuleConfig config;
+ HloParser parser(str, config);
+ return parser.ParsePaddingConfigOnly();
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h
index 3f3a51215e..1882a184da 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.h
+++ b/tensorflow/compiler/xla/service/hlo_parser.h
@@ -16,7 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_lexer.h"
@@ -32,27 +33,34 @@ namespace xla {
// The API of the HLO parser. Given a string in the HloModule::ToString()
// format, parses the string and creates an HloModule with the given config.
StatusOr<std::unique_ptr<HloModule>> ParseHloString(
- tensorflow::StringPiece str, const HloModuleConfig& config);
+ absl::string_view str, const HloModuleConfig& config);
+
+// Parses the text for a single HLO operation into an HLO module with a function
+// that runs that operation (with the same parameters) as its entry computation.
+StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
+ absl::string_view str, absl::string_view name = "single_op");
// The API of the HLO parser. Given a string in the HloModule::ToString()
// format, parses the string and creates an HloModule with default config.
-StatusOr<std::unique_ptr<HloModule>> ParseHloString(
- tensorflow::StringPiece str);
+StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str);
// Parses the result of HloSharding::ToString(), e.g. "{replicated}".
-StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str);
+StatusOr<HloSharding> ParseSharding(absl::string_view str);
// Parses the result of window_util::ToString(const Window&).
-StatusOr<Window> ParseWindow(tensorflow::StringPiece str);
+StatusOr<Window> ParseWindow(absl::string_view str);
// Parses the result of ConvolutionDimensionNumbersToString(), e.g.
// "b0f_0io->b0f".
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
- tensorflow::StringPiece str);
+ absl::string_view str);
// Parses sharding from str. str is supposed to contain the body of the
// sharding, i.e. just the rhs of the "sharding={...}" attribute string.
-StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str);
+StatusOr<HloSharding> ParseSharding(absl::string_view str);
+
+// Parses the result of PaddingConfigToString(), e.g. "0_0x1_1".
+StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str);
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 6fa3c63d83..cca50fab54 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -16,17 +16,21 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include <string>
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
-
namespace {
-using ::tensorflow::StringPiece;
+namespace op = ::xla::testing::opcode_matchers;
+using absl::string_view;
struct TestData {
string test_name;
@@ -380,7 +384,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
%input = f32[1,2,1]{2,1,0} parameter(0)
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
%filter = f32[1,1,1]{2,1,0} parameter(1)
- ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=1
+ ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default}
}
)"
@@ -393,7 +397,7 @@ R"(HloModule ConvolveR2_module
ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] {
%input = f32[1,2]{1,0} parameter(0)
%filter = f32[1,1]{1,0} parameter(1)
- ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf, feature_group_count=1
+ ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf
}
)"
@@ -406,7 +410,7 @@ R"(HloModule ConvolveBackward_module
ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] {
%input = f32[128,7,7,512]{0,3,2,1} parameter(0)
%filter = f32[3,3,512,512]{3,2,1,0} parameter(1)
- ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f, feature_group_count=1
+ ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f
}
)"
@@ -752,10 +756,10 @@ ENTRY %sparse_f32_r1 () -> f32[9] {
"gather",
R"(HloModule StringifyGather
-ENTRY %Gather (input_tensor: f32[50,49,48,47,46], gather_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] {
+ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] {
%input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
- %gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
- ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26}
+ %start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
+ ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}
}
)"
@@ -1030,8 +1034,8 @@ R"(HloModule gather
ENTRY Gather {
input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
- gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
- ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26}
+ start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
+ ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}
}
)"
@@ -1049,7 +1053,7 @@ add {
ENTRY CRS {
input = f32[8]{0} parameter(0)
- ROOT crs = f32[8]{0} cross-replica-sum(input), replica_group_ids={}, to_apply=add
+ ROOT crs = f32[8]{0} cross-replica-sum(input), replica_groups={}, to_apply=add
}
)"
@@ -1067,7 +1071,7 @@ add {
ENTRY CrossReplicaSumWithSubgroups {
input = f32[128,32]{0,1} parameter(0)
- ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_group_ids={0,0,1,1}, barrier="abc", to_apply=add
+ ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_groups={{0,1},{2,3}}, barrier="abc", to_apply=add
}
)"
@@ -1091,7 +1095,19 @@ R"(HloModule AllToAllWithSubgroups
ENTRY AllToAllWithSubgroups {
input = f32[128,32]{0,1} parameter(0)
- ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}}, barrier="abc"
+ ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}}
+}
+
+)"
+},
+// collective-permute
+{
+"CollectivePermute",
+R"(HloModule CollectivePermute
+
+ENTRY CollectivePermute {
+ input = f32[128,32]{0,1} parameter(0)
+ ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}}
}
)"
@@ -1102,31 +1118,44 @@ ENTRY AllToAllWithSubgroups {
R"(HloModule iota
ENTRY Iota {
- ROOT iota = f32[100]{0} iota()
+ ROOT iota = f32[100]{0} iota(), iota_dimension=0
}
)"
},
-// custom-call with window and dim_labels
+// custom-call with window, dim_labels and feature_group_count
{
-"CustomCallWithWindowAndDimLabels",
-R"(HloModule CustomCallWithWindowAndDimLabels
+"CustomCallWithWindowAndDimLabelsAndFeatureGroupCount",
+R"(HloModule CustomCallWithWindowAndDimLabelsAndFeatureGroupCount
ENTRY Computation {
- ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="target"
+ ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, feature_group_count=2, custom_call_target="target"
}
)"
+ },
+// is_scheduled=true attribute
+{
+"ScheduledModule",
+R"(HloModule scheduled_module, is_scheduled=true
+
+ENTRY Sort {
+ keys = f32[1024]{0} parameter(0)
+ values = s32[1024]{0} parameter(1)
+ ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}
}
- });
+
+)"
+}
+});
// clang-format on
}
class HloParserTest : public ::testing::Test,
public ::testing::WithParamInterface<TestData> {
protected:
- static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
- EXPECT_TRUE(tensorflow::str_util::StrContains(s, expected))
+ static void ExpectHasSubstr(string_view s, string_view expected) {
+ EXPECT_TRUE(absl::StrContains(s, expected))
<< "'" << s << "' does not contain '" << expected << "'";
}
@@ -1390,15 +1419,14 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
)";
- ExpectHasSubstr(ParseHloString(tensorflow::strings::StrCat(
- prefix, ",dim_labels=00_01_10", suffix))
- .status()
- .error_message(),
- "expects dim labels pattern");
+ ExpectHasSubstr(
+ ParseHloString(absl::StrCat(prefix, ",dim_labels=00_01_10", suffix))
+ .status()
+ .error_message(),
+ "expects dim labels pattern");
ExpectHasSubstr(
- ParseHloString(tensorflow::strings::StrCat(
- prefix, ",dim_labels=010_1100->010", suffix))
+ ParseHloString(absl::StrCat(prefix, ",dim_labels=010_1100->010", suffix))
.status()
.error_message(),
"must have the same rank");
@@ -1712,6 +1740,25 @@ TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) {
EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums));
}
+TEST_F(HloParserTest, ParsePaddingConfigNoInteriorPadding) {
+ const string original = "0_1x2_3";
+ TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig(original));
+ EXPECT_EQ(original, PaddingConfigToString(dnums));
+}
+
+TEST_F(HloParserTest, ParsePaddingConfigInteriorPadding) {
+ const string original = "0_1_0x2_3_4";
+ TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig(original));
+ EXPECT_EQ(original, PaddingConfigToString(dnums));
+}
+
+TEST_F(HloParserTest, ParsePaddingConfigInteriorPaddingImplicitZeroDim) {
+ TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig("0_1x2_3_4"));
+ // The extra "_0" gets added to the canonical string because the other dim has
+ // interior padding.
+ EXPECT_EQ("0_1_0x2_3_4", PaddingConfigToString(dnums));
+}
+
TEST_F(HloParserTest, NontupleInfeed) {
const string original = R"(HloModule nontuple_infeed:
ENTRY nontuple_infeed {
@@ -1722,5 +1769,128 @@ ENTRY nontuple_infeed {
"infeed must have a non-empty tuple shape");
}
+TEST(HloParserSingleOpTest, SingleOp) {
+ const string text =
+ "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, "
+ "f32[2,4]{1,0} %x)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text));
+ const HloComputation* computation = module->entry_computation();
+ ASSERT_NE(computation, nullptr);
+ EXPECT_THAT(computation->root_instruction(),
+ op::Multiply(op::Parameter(0), op::Parameter(1)));
+}
+
+TEST(HloParserSingleOpTest, SingleOpNoShapesProducesError) {
+ const string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)";
+ StatusOr<std::unique_ptr<HloModule>> module = ParseHloOpToModule(text);
+ ASSERT_TRUE(!module.status().ok());
+ LOG(INFO) << "Status: " << module.status();
+ EXPECT_THAT(
+ module.status().ToString(),
+ ::testing::HasSubstr("Operand broadcast had no shape in HLO text"));
+}
+
+TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) {
+ const string text =
+ R"(%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text));
+ const HloComputation* computation = module->entry_computation();
+ ASSERT_NE(computation, nullptr);
+ EXPECT_THAT(computation->root_instruction(),
+ op::Convolution(op::Parameter(0), op::Parameter(1)));
+ auto* convolution =
+ Cast<HloConvolutionInstruction>(computation->root_instruction());
+ EXPECT_EQ(convolution->feature_group_count(), 1);
+}
+
+TEST_F(HloParserTest, IsScheduledIsFalse) {
+ const string text = R"(
+HloModule axpy_module, is_scheduled=false
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %x = f32[2,4]{1,0} parameter(1)
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ %y = f32[2,4]{1,0} parameter(2)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_FALSE(module->has_schedule());
+}
+
+TEST_F(HloParserTest, IsScheduledNotPresent) {
+ const string text = R"(
+HloModule axpy_module
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %x = f32[2,4]{1,0} parameter(1)
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ %y = f32[2,4]{1,0} parameter(2)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_FALSE(module->has_schedule());
+}
+
+TEST_F(HloParserTest, IsScheduledIsTrue) {
+ const string text = R"(
+HloModule axpy_module, is_scheduled=true
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %x = f32[2,4]{1,0} parameter(1)
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ %y = f32[2,4]{1,0} parameter(2)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_TRUE(module->has_schedule());
+ TF_ASSERT_OK(module->schedule().Verify());
+ EXPECT_EQ(module->schedule().sequences().size(), 1);
+ ASSERT_TRUE(
+ module->schedule().is_computation_scheduled(module->entry_computation()));
+ EXPECT_THAT(
+ module->schedule().sequence(module->entry_computation()).instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Broadcast(), op::Parameter(),
+ op::Multiply(), op::Parameter(), op::Add()));
+}
+
+TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) {
+  // As above but with a different schedule order.
+ const string text = R"(
+HloModule axpy_module, is_scheduled=true
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %x = f32[2,4]{1,0} parameter(1)
+ %y = f32[2,4]{1,0} parameter(2)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_TRUE(module->has_schedule());
+ TF_ASSERT_OK(module->schedule().Verify());
+ EXPECT_EQ(module->schedule().sequences().size(), 1);
+ ASSERT_TRUE(
+ module->schedule().is_computation_scheduled(module->entry_computation()));
+ EXPECT_THAT(
+ module->schedule().sequence(module->entry_computation()).instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(),
+ op::Broadcast(), op::Multiply(), op::Add()));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_pass_fix.h b/tensorflow/compiler/xla/service/hlo_pass_fix.h
index 28194deb0e..791b1a97b0 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_fix.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_fix.h
@@ -45,7 +45,7 @@ class HloPassFix : public Pass {
++iteration_count;
if (iteration_count == limit) {
LOG(ERROR)
- << "Unexpectedly number of iterations in HLO passes ("
+ << "Unexpectedly high number of iterations in HLO passes ("
<< iteration_count
<< ")\nIf compilation hangs here, please file a bug with XLA.";
}
diff --git a/tensorflow/compiler/xla/service/hlo_pass_interface.h b/tensorflow/compiler/xla/service/hlo_pass_interface.h
index 0cddf8fb8f..f1ad0f9b01 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_interface.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_interface.h
@@ -29,7 +29,7 @@ namespace xla {
class HloPassInterface {
public:
virtual ~HloPassInterface() = default;
- virtual tensorflow::StringPiece name() const = 0;
+ virtual absl::string_view name() const = 0;
// Run the pass on the given HLO module. Return whether it modified the
// module.
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
index d8f1ab916b..6e4ed0de62 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -17,22 +17,23 @@ limitations under the License.
#include <functional>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
-
namespace xla {
-
namespace {
+
+using absl::StrAppend;
+using absl::StrCat;
+
void DumpModuleGraph(const HloModule& module, const string& message) {
hlo_graph_dumper::MaybeDumpHloModule(module, message);
VLOG(3) << "HLO " << message << ":";
@@ -48,9 +49,9 @@ void DumpModuleProto(const HloModule& module, const string& dump_to,
tensorflow::mutex_lock lock(mu);
const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++;
- const string mod_name = SanitizeFileName(tensorflow::strings::Printf(
- "module_%04d.%04lld.%s.after_%s", module.unique_id(), pass_number,
- pipeline_name.c_str(), pass_name.c_str()));
+ const string mod_name = SanitizeFileName(
+ absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(),
+ pass_number, pipeline_name, pass_name));
TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(MakeHloProto(module),
dump_to, mod_name));
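
[Editorial aside: absl::StrFormat type-checks its arguments, so the width-only %04d works for int64 values and the old %04lld length modifier is unnecessary:]

  // %d accepts any integral type under absl::StrFormat's checked formatting.
  const int64 pass_number = 7;
  const string mod_name = absl::StrFormat("%04d", pass_number);  // "0007"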
@@ -68,7 +69,7 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
repeated_field.end());
if (!disabled_passes.empty()) {
VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
- << tensorflow::str_util::Join(disabled_passes, ", ");
+ << absl::StrJoin(disabled_passes, ", ");
}
auto run_invariant_checkers = [this,
@@ -90,7 +91,7 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
return Status::OK();
};
- string prefix = std::string(name()) + ": pipeline start";
+ string prefix = StrCat(name(), ": pipeline start");
bool changed = false;
string message;
TF_RETURN_IF_ERROR(
@@ -98,12 +99,12 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
const string xla_dump_per_pass_hlo_proto_to =
module->config().debug_options().xla_dump_per_pass_hlo_proto_to();
if (!xla_dump_per_pass_hlo_proto_to.empty()) {
- DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to,
- std::string(name()), "pipeline_start");
+ DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()),
+ "pipeline_start");
}
for (auto& pass : passes_) {
- if (disabled_passes.count(std::string(pass->name())) > 0) {
+ if (disabled_passes.count(string(pass->name())) > 0) {
VLOG(1) << " Skipping HLO pass " << pass->name()
<< ", disabled by --xla_disable_hlo_passes";
continue;
@@ -120,8 +121,8 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
TF_RETURN_IF_ERROR(
run_invariant_checkers(StrCat("after running pass: ", pass->name())));
if (!xla_dump_per_pass_hlo_proto_to.empty()) {
- DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to,
- std::string(name()), std::string(pass->name()));
+ DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()),
+ string(pass->name()));
}
changed |= changed_this_pass;
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
index a42d7e59fe..1d41a4dac1 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
@@ -21,7 +21,7 @@ limitations under the License.
#include <string>
#include <vector>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -34,7 +34,7 @@ namespace xla {
class HloPassPipeline : public HloPassInterface {
public:
explicit HloPassPipeline(const string& name) : name_(name) {}
- tensorflow::StringPiece name() const override { return name_; }
+ absl::string_view name() const override { return name_; }
// Add a pass to the pipeline. It should be called with the arguments for the
// pass constructor:
diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc
index 3460679558..b9c0b0c4ee 100644
--- a/tensorflow/compiler/xla/service/hlo_proto_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc
@@ -23,11 +23,8 @@ namespace xla {
HloProto MakeHloProto(const HloModule& module,
const BufferAssignment& assignment) {
- HloOrderingProto proto_ordering =
- assignment.liveness().hlo_ordering().ToProto();
BufferAssignmentProto proto_assignment = assignment.ToProto();
HloProto proto = MakeHloProto(module);
- proto.mutable_hlo_ordering()->Swap(&proto_ordering);
proto.mutable_buffer_assignment()->Swap(&proto_assignment);
return proto;
}
diff --git a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc
index b9cca13870..c3cacd7ce6 100644
--- a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace {
diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc
index 01b088a957..961930f0a8 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability.cc
+++ b/tensorflow/compiler/xla/service/hlo_reachability.cc
@@ -18,7 +18,7 @@ limitations under the License.
namespace xla {
HloReachabilityMap::HloReachabilityMap(
- tensorflow::gtl::ArraySlice<const HloInstruction*> instructions)
+ absl::Span<const HloInstruction* const> instructions)
: size_(instructions.size()) {
bit_vectors_.reserve(size_);
for (const HloInstruction* hlo : instructions) {
@@ -29,7 +29,7 @@ HloReachabilityMap::HloReachabilityMap(
}
bool HloReachabilityMap::SetReachabilityToUnion(
- tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ absl::Span<const HloInstruction* const> inputs,
const HloInstruction* instruction) {
BitVector& bit_vector = GetBitVector(instruction);
tmp_bit_vector_ = bit_vector;
@@ -38,13 +38,13 @@ bool HloReachabilityMap::SetReachabilityToUnion(
}
void HloReachabilityMap::FastSetReachabilityToUnion(
- tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ absl::Span<const HloInstruction* const> inputs,
const HloInstruction* instruction) {
SetReachabilityToUnionHelper(inputs, instruction, &GetBitVector(instruction));
}
void HloReachabilityMap::SetReachabilityToUnionHelper(
- tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ absl::Span<const HloInstruction* const> inputs,
const HloInstruction* instruction, BitVector* bit_vector) {
// If instruction is part of inputs, don't reset the bit_vector.
if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) {
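
[Editorial aside: absl::Span<const HloInstruction* const> binds implicitly to any contiguous container of instruction pointers, so call sites need no changes. A hypothetical caller (names are illustrative):]

  // inputs and instr are hypothetical; the vector converts to a Span implicitly.
  std::vector<const HloInstruction*> inputs = {param0, param1};
  reachability_map.SetReachabilityToUnion(inputs, instr);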
diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h
index 48215d32a8..b66a2aa4bd 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability.h
+++ b/tensorflow/compiler/xla/service/hlo_reachability.h
@@ -19,10 +19,10 @@ limitations under the License.
#include <list>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/types.h"
@@ -42,7 +42,7 @@ class HloReachabilityMap {
// Sets up a graph with no edges and where the nodes correspond to the given
// instructions.
explicit HloReachabilityMap(
- tensorflow::gtl::ArraySlice<const HloInstruction*> instructions);
+ absl::Span<const HloInstruction* const> instructions);
// Set the reachability set of 'instruction' to the union of the reachability
// sets of 'inputs'. Upon return, IsReachable(x, instruction) where
@@ -54,13 +54,12 @@ class HloReachabilityMap {
// vector in the internal graph of this HloReachabilityMap for the given
// instruction and does not transitively update any other part of the
// adjacency matrix.
- bool SetReachabilityToUnion(
- tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
- const HloInstruction* instruction);
+ bool SetReachabilityToUnion(absl::Span<const HloInstruction* const> inputs,
+ const HloInstruction* instruction);
// As above, but faster because it does not check if the reachability changed.
void FastSetReachabilityToUnion(
- tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ absl::Span<const HloInstruction* const> inputs,
const HloInstruction* instruction);
// Sets entry so that IsReachable(a, b) will return true
@@ -141,7 +140,7 @@ class HloReachabilityMap {
// Helper for SetReachabilityToUnion/FastSetReachabilityToUnion.
void SetReachabilityToUnionHelper(
- tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ absl::Span<const HloInstruction* const> inputs,
const HloInstruction* instruction, BitVector* bit_vector);
// Return the index of the given instruction. The value is used to index into
diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc
index 585c95972b..d9848cee0b 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc
@@ -20,13 +20,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
namespace xla {
namespace {
-class HloReachabilityTest : public HloTestBase {};
+class HloReachabilityTest : public HloVerifiedTestBase {};
TEST_F(HloReachabilityTest, Reachability) {
// Construct and test a reachability graph of the following form:
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index cf0be30c7a..bd6dd79b67 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -20,34 +20,33 @@ limitations under the License.
#include <set>
#include <string>
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
-#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
-using ::tensorflow::strings::HumanReadableNumBytes;
-
namespace xla {
-
namespace {
+using ::tensorflow::strings::HumanReadableNumBytes;
+
// Potential optimizations:
// . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue
// of candidates.
@@ -88,7 +87,7 @@ bool CanBeRematerialized(
// Type holding a unique identifier for each Buffer object.
using BufferId = int64;
-using BufferIdList = tensorflow::gtl::InlinedVector<BufferId, 3>;
+using BufferIdList = absl::InlinedVector<BufferId, 3>;
// We wrap HloInstruction* with an Item that holds auxiliary
// per-instruction state.
@@ -123,7 +122,7 @@ struct Item {
int64 position;
};
-using ItemList = tensorflow::gtl::InlinedVector<Item*, 3>;
+using ItemList = absl::InlinedVector<Item*, 3>;
// Class which maintains an ordered list of instructions with fast insertion
// before arbitrary elements.
@@ -202,15 +201,14 @@ class InstructionList {
// On object construction this ordinal is precisely the instruction's index
// in the list. Later, instructions inserted via InsertBefore receive
// duplicate values. However, monotonicity is preserved.
- void InsertBeforeInstructions(
- Item* to_insert, tensorflow::gtl::ArraySlice<Item*> before_instructions) {
+ void InsertBeforeInstructions(Item* to_insert,
+ absl::Span<Item* const> before_instructions) {
VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name()
<< " before {"
- << tensorflow::str_util::Join(before_instructions, ", ",
- [](string* out, Item* item) {
- tensorflow::strings::StrAppend(
- out, item->instruction->name());
- })
+ << absl::StrJoin(before_instructions, ", ",
+ [](string* out, Item* item) {
+ absl::StrAppend(out, item->instruction->name());
+ })
<< "}";
// Find the minimal position number of any instruction in
@@ -393,10 +391,9 @@ class MemoryUsageTracker {
int64 unfinished_user_count;
string ToString() const {
- return tensorflow::strings::StrCat(
- "Buffer ", id, " (defined by ",
- defining_instruction->instruction->name(), ", size ", size,
- " bytes)");
+ return absl::StrCat("Buffer ", id, " (defined by ",
+ defining_instruction->instruction->name(), ", size ",
+ size, " bytes)");
}
};
@@ -740,29 +737,27 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
}
string MemoryUsageTracker::ToString() const {
- string output = tensorflow::strings::StrCat("MemoryUsageTracker for ",
- computation_->name(), "\n");
- tensorflow::strings::StrAppend(
- &output, "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (",
- memory_usage(), " bytes)");
+ string output =
+ absl::StrCat("MemoryUsageTracker for ", computation_->name(), "\n");
+ absl::StrAppend(&output,
+ "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (",
+ memory_usage(), " bytes)");
for (auto* item = instruction_list_.first(); item != nullptr;
item = instruction_list_.next(item)) {
const HloInstruction* instruction = item->instruction;
string inprogress = item == in_progress_item_ ? " in-progress" : "";
string placed = item->placed ? " placed" : "";
- tensorflow::strings::StrAppend(&output, " ", instruction->name(),
- inprogress, placed, "\n Defines:\n");
+ absl::StrAppend(&output, " ", instruction->name(), inprogress, placed,
+ "\n Defines:\n");
for (BufferId buffer_id : item->buffers_defined) {
const Buffer& buffer = buffers_[buffer_id];
string live = IsCurrentlyLive(buffer_id) ? " live" : "";
- tensorflow::strings::StrAppend(&output, " ", buffer.ToString(), live,
- ", ", buffer.unfinished_user_count,
- " unfinished uses\n");
+ absl::StrAppend(&output, " ", buffer.ToString(), live, ", ",
+ buffer.unfinished_user_count, " unfinished uses\n");
}
- tensorflow::strings::StrAppend(&output, " Uses:\n");
+ absl::StrAppend(&output, " Uses:\n");
for (BufferId buffer_id : item->buffers_used) {
- tensorflow::strings::StrAppend(&output, " ",
- buffers_[buffer_id].ToString(), "\n");
+ absl::StrAppend(&output, " ", buffers_[buffer_id].ToString(), "\n");
}
}
return output;
@@ -780,10 +775,9 @@ bool MemoryUsageTracker::Check() const {
CHECK(elements_are_unique(defined_buffers))
<< "Instruction " << instruction->name()
<< " does not have unique defined buffers: "
- << tensorflow::str_util::Join(
+ << absl::StrJoin(
defined_buffers, ", ", [this](string* out, BufferId buffer_id) {
- tensorflow::strings::StrAppend(
- out, buffers_.at(buffer_id).ToString());
+ absl::StrAppend(out, buffers_.at(buffer_id).ToString());
});
for (const Buffer& buffer : buffers_) {
@@ -803,10 +797,9 @@ bool MemoryUsageTracker::Check() const {
CHECK(elements_are_unique(used_buffers))
<< "Instruction " << instruction->name()
<< " does not have unique used buffers: "
- << tensorflow::str_util::Join(
+ << absl::StrJoin(
used_buffers, ", ", [this](string* out, BufferId buffer_id) {
- tensorflow::strings::StrAppend(
- out, buffers_.at(buffer_id).ToString());
+ absl::StrAppend(out, buffers_.at(buffer_id).ToString());
});
}
for (const Buffer& buffer : buffers_) {
@@ -968,8 +961,7 @@ StatusOr<int64> HloRematerialization::CalledComputationsMemoryUsage(
}
StatusOr<bool> HloRematerialization::RematerializeComputation(
- HloComputation* computation,
- SequentialHloOrdering::HloModuleSequence* sequence,
+ HloComputation* computation, HloSchedule* schedule,
int64 memory_limit_bytes) {
VLOG(1) << "Rematerializing computation " << computation->name()
<< " with limit " << HumanReadableNumBytes(memory_limit_bytes);
@@ -977,7 +969,8 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
<< HumanReadableNumBytes(computation_peak_memory_.at(computation));
CHECK(!ContainsKey(rematerialized_computations_, computation));
- InstructionList instruction_list(sequence->at(computation));
+ InstructionList instruction_list(
+ schedule->sequence(computation).instructions());
MemoryUsageTracker memory_tracker(computation, size_function_,
*points_to_analysis_, instruction_list);
bool changed = false;
@@ -1151,7 +1144,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
0, memory_limit_bytes - memory_tracker.memory_usage());
TF_ASSIGN_OR_RETURN(
bool subcomputation_changed,
- RematerializeComputation(called_computation, sequence,
+ RematerializeComputation(called_computation, schedule,
subcomputation_memory_limit_bytes));
changed |= subcomputation_changed;
}
@@ -1185,12 +1178,12 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
computation_peak_memory_.at(computation) = peak_memory;
// Update order to include rematerialized instructions.
- auto& dst = sequence->at(computation);
- dst.clear();
+ HloInstructionSequence& sequence = schedule->GetOrCreateSequence(computation);
+ sequence.clear();
for (auto* item = instruction_list.first(); item != nullptr;
item = instruction_list.next(item)) {
const HloInstruction* instruction = item->instruction;
- dst.push_back(instruction);
+ sequence.push_back(instruction);
}
rematerialized_computations_.insert(computation);
@@ -1200,16 +1193,12 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
return changed;
}
-StatusOr<bool> HloRematerialization::Run(
- HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence,
- int64 memory_limit_bytes, RematerializationSizes* sizes,
- CopyInsertion* copy_insertion) {
- // The sequence is constructed entirely by this method.
- TF_RET_CHECK(sequence->empty());
-
+StatusOr<bool> HloRematerialization::Run(HloModule* module) {
VLOG(1) << "HloRematerialization() with memory limit of "
- << HumanReadableNumBytes(memory_limit_bytes);
+ << HumanReadableNumBytes(memory_limit_bytes_);
+ XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
+ TF_RET_CHECK(module->has_schedule());
TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
// Adjust memory limit to account for the output of the entry
@@ -1225,39 +1214,23 @@ StatusOr<bool> HloRematerialization::Run(
});
const int64 adjusted_memory_limit_bytes =
- memory_limit_bytes - module_output_size;
+ memory_limit_bytes_ - module_output_size;
VLOG(1) << "Adjusted memory limit accounting for output ("
<< HumanReadableNumBytes(module_output_size)
<< "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes);
- XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
- // Create initial sequence of HLO instructions.
- TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule(
- *module,
- [this](const BufferValue& buffer) {
- return size_function_(buffer.shape());
- },
- scheduler_algorithm_));
- if (copy_insertion) {
- // We run a separate pass of copy elision here because the sequential
- // ordering from the HLO schedule allows for more copies to be eliminated.
- // TODO(b/80249101): Instead of a separate copy elision pass, use the
- // ordering from the HLO schedule directly for copy insertion.
- SequentialHloOrdering ordering(module, *sequence);
- TF_RETURN_IF_ERROR(
- copy_insertion->RemoveUnnecessaryCopies(ordering, module));
- }
-
// Compute peak memory usage of all computations in the module called in a
// sequential context.
call_graph_ = CallGraph::Build(module);
TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
- [this, sequence](const CallGraphNode& node) -> Status {
+ [this, module](const CallGraphNode& node) -> Status {
if (node.context() == CallContext::kSequential) {
TF_ASSIGN_OR_RETURN(
computation_peak_memory_[node.computation()],
ComputePeakMemory(node.computation(),
- sequence->at(node.computation())));
+ module->schedule()
+ .sequence(node.computation())
+ .instructions()));
}
return Status::OK();
},
@@ -1275,9 +1248,10 @@ StatusOr<bool> HloRematerialization::Run(
// Subcomputations called by the entry computation will also be
// rematerialized.
- TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation(
- module->entry_computation(), sequence,
- adjusted_memory_limit_bytes));
+ TF_ASSIGN_OR_RETURN(
+ bool changed,
+ RematerializeComputation(module->entry_computation(), &module->schedule(),
+ adjusted_memory_limit_bytes));
// Rematerialization can introduce dead code. This occurs if all uses of an
// instruction are replaced with rematerializations of the instruction.
@@ -1286,30 +1260,7 @@ StatusOr<bool> HloRematerialization::Run(
// After DCE, the module sequence may include instructions which no longer
// exist.
- for (const auto* computation : module->MakeNonfusionComputations()) {
- if (sequence->at(computation).size() != computation->instruction_count()) {
- // A size mismatch between the computation instruction count and the size
- // of the ordering of instructions can only be caused by DCE. Rebuild the
- // order by removing the deleted instructions from the order.
- tensorflow::gtl::FlatSet<const HloInstruction*> instruction_set;
- for (const auto& instruction : computation->instructions()) {
- instruction_set.insert(instruction);
- }
- // Move the old order into a temporary vector, then build new order
- // inplace.
- std::vector<const HloInstruction*>& order = sequence->at(computation);
- std::vector<const HloInstruction*> old_order;
- using std::swap;
- swap(order, old_order);
- std::copy_if(old_order.begin(), old_order.end(),
- std::back_inserter(order),
- [&instruction_set](const HloInstruction* instruction) {
- return ContainsKey(instruction_set, instruction);
- });
- TF_RET_CHECK(sequence->at(computation).size() ==
- computation->instruction_count());
- }
- }
+ TF_RETURN_IF_ERROR(module->schedule().Update());
VLOG(1) << "Rematerialized " << instructions_rematerialized_
<< " instructions in module " << module->name() << "; "
<< net_instructions_added_ << " net instructions added";
@@ -1326,34 +1277,22 @@ StatusOr<bool> HloRematerialization::Run(
<< HumanReadableNumBytes(reduced_peak_memory) << " ("
<< reduced_peak_memory << " bytes)";
- if (sizes != nullptr) {
- sizes->before_bytes = before_peak_memory;
- sizes->after_bytes = current_peak_memory;
+ if (sizes_ != nullptr) {
+ sizes_->before_bytes = before_peak_memory;
+ sizes_->after_bytes = current_peak_memory;
}
XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString());
- if (current_peak_memory > memory_limit_bytes) {
- LOG(WARNING) << tensorflow::strings::Printf(
- "Can't reduce memory use below %s (%lld bytes) by rematerialization; "
- "only reduced to %s (%lld bytes)",
- HumanReadableNumBytes(memory_limit_bytes).c_str(), memory_limit_bytes,
- HumanReadableNumBytes(current_peak_memory).c_str(),
- current_peak_memory);
+ if (current_peak_memory > memory_limit_bytes_) {
+ LOG(WARNING) << absl::StrFormat(
+ "Can't reduce memory use below %s (%d bytes) by rematerialization; "
+ "only reduced to %s (%d bytes)",
+ HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_,
+ HumanReadableNumBytes(current_peak_memory), current_peak_memory);
}
return changed;
}
-/* static */ StatusOr<bool> HloRematerialization::RematerializeAndSchedule(
- const HloRematerialization::ShapeSizeFunction& size_function,
- int64 memory_limit_bytes, HloModule* hlo_module,
- MemorySchedulerAlgorithm scheduler_algorithm,
- SequentialHloOrdering::HloModuleSequence* sequence,
- RematerializationSizes* sizes, CopyInsertion* copy_insertion) {
- HloRematerialization remat(scheduler_algorithm, size_function);
- return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes,
- copy_insertion);
-}
-
} // namespace xla
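The rewritten Run() above replaces the hand-rolled post-DCE sequence fixup with HloSchedule::Update() and requires a schedule to be bound to the module up front. A minimal sketch of the new calling convention, mirroring the updated test helper further below (the wrapper RematerializeWithBudget is illustrative, not part of this patch):

#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

// Illustrative wrapper: bind a schedule to the module, then rematerialize
// under a byte budget. Run() TF_RET_CHECKs module->has_schedule().
StatusOr<bool> RematerializeWithBudget(HloModule* module,
                                       int64 memory_limit_bytes) {
  HloMemoryScheduler scheduler(
      [](const BufferValue& buffer) {
        return ShapeUtil::ByteSizeOf(buffer.shape());
      },
      DefaultMemoryScheduler);
  TF_RETURN_IF_ERROR(scheduler.Run(module).status());

  HloRematerialization remat(
      [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); },
      memory_limit_bytes, /*sizes=*/nullptr);
  return remat.Run(module);
}

}  // namespace xla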
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index 2ec004350a..e2aaf18b3e 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -17,16 +17,23 @@
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
-#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
namespace xla {
-class HloRematerialization {
+// HLO pass which rematerializes instructions to reduce peak memory use, where
+// memory use is defined as the total size of all live HLO instruction
+// values. Parameters and constants are included in memory use estimates.
+//
+// CSE will undo the effects of this optimization and should not be run after
+// this pass. In general, this pass should be run very late, immediately before
+// code generation.
+class HloRematerialization : public HloPassInterface {
public:
using ShapeSizeFunction = std::function<int64(const Shape&)>;
@@ -37,10 +44,7 @@ class HloRematerialization {
int64 after_bytes;
};
- // Rematerialize HLO instructions in the given module to reduce peak memory
- // use below memory_limit_bytes where memory use is defined as the total size
- // of all live HLO instruction values. Parameters and constants are included
- // in memory use estimates. Method parameters:
+ // Constructor parameters:
//
// size_function: Function which returns the size in bytes of the top-level
// buffer of the given shape.
@@ -48,60 +52,34 @@ class HloRematerialization {
  // memory_limit_bytes: The threshold, in bytes, to which rematerialization
  //   attempts to reduce memory use.
//
- // hlo_module: HLO module to rematerialize instructions in.
- //
- // sequence: Should point to an empty HloModuleSequence. Upon return
- // contains the HLO instruction order which was used for
- // rematerialization. This is the order in which HLO instructions should
- // be emitted to minimize memory use.
- //
- // sizes: Optional outparam that indicates the peak memory usage of the HLO
- // module before/after rematerialization.
- //
- // copy_insertion: If non-null, run copy elision after scheduling. This
- // pass is used to eliminate copies that were inserted by copy insertion
- // before HLO scheduling.
- //
- // TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy
- // insertion is integrated with HLO scheduling.
- //
- // Returns whether any instructions were rematerialized. If memory use is
- // already below the given limit then no instructions are rematerialized and
- // false is returned.
- //
- // CSE will undo the effects of this optimization and should not be run after
- // this pass. In general, this pass should be run very late immediately before
- // code generation.
- static StatusOr<bool> RematerializeAndSchedule(
- const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
- HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm,
- SequentialHloOrdering::HloModuleSequence* sequence,
- RematerializationSizes* sizes, CopyInsertion* copy_insertion = nullptr);
-
- protected:
- HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm,
- const ShapeSizeFunction& size_function)
- : scheduler_algorithm_(scheduler_algorithm),
- size_function_(size_function) {}
+ // sizes: Pointer to data structure which records the peak memory usage of
+ // the HLO module before/after rematerialization. Values are set during
+ // Run(). Can be nullptr.
+ HloRematerialization(const ShapeSizeFunction& size_function,
+ int64 memory_limit_bytes, RematerializationSizes* sizes)
+ : size_function_(size_function),
+ memory_limit_bytes_(memory_limit_bytes),
+ sizes_(sizes) {}
~HloRematerialization() {}
+ absl::string_view name() const override { return "rematerialization"; }
+
// Runs rematerialization on the given module. Returns whether the module was
- // changed. memory_limit is the target maximum peak memory usage by the
- // module. sequence should be an empty HloModuleSequence. Upon return sequence
- // contains the memory-minimizing order in which to emit the HLO instructions.
- StatusOr<bool> Run(HloModule* module,
- SequentialHloOrdering::HloModuleSequence* sequence,
- int64 memory_limit, RematerializationSizes* sizes,
- CopyInsertion* copy_insertion);
+ // changed. Requires that the module has a schedule set
+ // (HloModule::has_schedule() is true) before running. Returns whether any
+ // instructions were rematerialized. If memory use is already below the limit
+ // specified in the constructor then no instructions are rematerialized and
+ // false is returned.
+ StatusOr<bool> Run(HloModule* module) override;
+ protected:
  // Rematerializes instructions within the given computation. 'schedule'
  // constrains the order in which the computation's instructions will be
  // emitted in the backend. Rematerialized instructions will be added to the
  // HLO computation and inserted into 'schedule'.
- StatusOr<bool> RematerializeComputation(
- HloComputation* computation,
- SequentialHloOrdering::HloModuleSequence* sequence,
- int64 computation_memory_limit);
+ StatusOr<bool> RematerializeComputation(HloComputation* computation,
+ HloSchedule* schedule,
+ int64 memory_limit_bytes);
// Computes and returns the peak memory used by the given computation. The
// peak memory is the maximum total size of all live HLO instruction values at
@@ -122,6 +100,14 @@ class HloRematerialization {
// Function which computes the size of the top-level buffer of a shape.
const ShapeSizeFunction size_function_;
+ // The threshold, in bytes, to which rematerialization attempts to reduce
+ // memory use.
+ const int64 memory_limit_bytes_;
+
+ // Pointer to data structure which records the peak memory usage of the HLO
+ // module before/after rematerialization.
+ RematerializationSizes* sizes_;
+
// Call graph of the hlo_module.
std::unique_ptr<CallGraph> call_graph_;
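With HloRematerialization now an HloPassInterface subclass, it can be registered in a pass pipeline like any other pass. A hedged sketch (only the constructor signature comes from this header; the pipeline wiring is illustrative):

#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

// Illustrative: register rematerialization as the final pass, after memory
// scheduling and with no CSE afterwards (see the class comment above).
void AddLateRematerialization(
    HloPassPipeline* pipeline, int64 memory_limit_bytes,
    HloRematerialization::RematerializationSizes* sizes) {
  pipeline->AddPass<HloRematerialization>(
      [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); },
      memory_limit_bytes, sizes);
}

}  // namespace xla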
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index ac8c97d380..f7e82fb1f8 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -36,7 +36,7 @@ namespace op = xla::testing::opcode_matchers;
using ::testing::_;
-class HloRematerializationTest : public HloTestBase {
+class HloRematerializationTest : public HloVerifiedTestBase {
protected:
// Creates and returns a computation which can benefit from
// rematerialization. The computation looks like:
@@ -141,13 +141,16 @@ class HloRematerializationTest : public HloTestBase {
return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
}
- StatusOr<bool> RunHloRematerialization(
- int64 memory_limit_bytes, HloModule* module,
- SequentialHloOrdering::HloModuleSequence* sequence) {
+ StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
+ HloModule* module) {
TF_EXPECT_OK(verifier().Run(module).status());
- return HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler,
- sequence, /*sizes=*/nullptr);
+ HloMemoryScheduler scheduler(
+ [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); },
+ DefaultMemoryScheduler);
+ TF_EXPECT_OK(scheduler.Run(module).status());
+ HloRematerialization remat(ByteSizeOf, memory_limit_bytes,
+ /*sizes=*/nullptr);
+ return remat.Run(module);
}
// Various shapes used in the canned computations.
@@ -170,12 +173,11 @@ TEST_F(HloRematerializationTest, SingleComputation) {
const HloInstruction* concat = slice->operand(0);
const HloInstruction* bcast = concat->operand(0);
- SequentialHloOrdering::HloModuleSequence sequence;
// Computation requires 16KB without rematerialization, but uses only 12KB
// with rematerialization so pick a memory limit between these values (14KB).
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/14 * 1024,
- module.get(), &sequence));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/14 * 1024, module));
EXPECT_TRUE(changed);
// Root should not have changed.
@@ -187,9 +189,13 @@ TEST_F(HloRematerializationTest, SingleComputation) {
  // The rematerialized broadcast should be immediately before the concat in the
// sequence.
- EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 2],
+ EXPECT_EQ(module->schedule()
+ .sequence(computation)
+ .instructions()[computation->instruction_count() - 2],
concat);
- EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 3],
+ EXPECT_EQ(module->schedule()
+ .sequence(computation)
+ .instructions()[computation->instruction_count() - 3],
remat_bcast);
}
@@ -203,10 +209,9 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) {
EXPECT_EQ(computation->instruction_count(), 8);
- SequentialHloOrdering::HloModuleSequence sequence;
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/20 * 1024,
- module.get(), &sequence));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/20 * 1024, module));
  // No instructions should have been rematerialized.
EXPECT_FALSE(changed);
@@ -242,10 +247,9 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
// The body computation uses 16KB and the entry computation uses 2KB at the
// while so the peak memory use of the module is 18KB. Set the memory limit a
// bit lower (17KB) to force rematerialization of the entry computation.
- SequentialHloOrdering::HloModuleSequence sequence;
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/17 * 1024,
- module.get(), &sequence));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/17 * 1024, module));
EXPECT_TRUE(changed);
// Only the entry computation should have a rematerialized instruction added.
@@ -276,10 +280,9 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
EXPECT_EQ(entry_computation->instruction_count(), 7);
EXPECT_EQ(body_computation->instruction_count(), 8);
- SequentialHloOrdering::HloModuleSequence sequence;
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/15 * 1024,
- module.get(), &sequence));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/15 * 1024, module));
EXPECT_TRUE(changed);
// Both computations should have rematerialized instructions added.
@@ -316,10 +319,9 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
// If all computations are maximally rematerialized then peak memory usage is
// ~12K so pick something slightly larger.
- SequentialHloOrdering::HloModuleSequence sequence;
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/13 * 1024,
- module.get(), &sequence));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/13 * 1024, module));
EXPECT_TRUE(changed);
// All computations should have rematerialized instructions added.
@@ -382,14 +384,13 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) {
ASSERT_EQ(count_rngs(entry_computation), 1);
const int64 original_instruction_count =
entry_computation->instruction_count();
- SequentialHloOrdering::HloModuleSequence sequence;
  // Pick a memory limit somewhere between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
TF_ASSERT_OK_AND_ASSIGN(
- bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_),
- module.get(), &sequence));
+ bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module));
EXPECT_TRUE(changed);
// The rng should not have been rematerialized.
EXPECT_EQ(count_rngs(entry_computation), 1);
@@ -476,13 +477,12 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
EXPECT_EQ(add_3->operand(0), bcast);
EXPECT_EQ(add_4->operand(0), bcast);
- SequentialHloOrdering::HloModuleSequence sequence;
  // Pick a memory limit somewhere between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/22 * 1024,
- module.get(), &sequence));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/22 * 1024, module));
EXPECT_TRUE(changed);
// The broadcast should have been rematerialized 3 times.
@@ -571,13 +571,12 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
EXPECT_EQ(entry_computation->instruction_count(), 8);
- SequentialHloOrdering::HloModuleSequence sequence;
  // Pick a memory limit somewhere between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/22 * 1024,
- module.get(), &sequence));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/22 * 1024, module));
// Rematerialization should only occur if the rematerializable instruction has
// no indirect uses.
if (indirectly_used) {
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index b2725e2918..fa7f216321 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -19,9 +19,9 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/memory/memory.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -32,7 +32,7 @@ limitations under the License.
namespace xla {
/*static*/ StatusOr<std::unique_ptr<HloModule>>
-HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string,
+HloRunner::CreateModuleFromString(const absl::string_view hlo_string,
const DebugOptions& debug_options) {
HloModuleConfig config;
config.set_debug_options(debug_options);
@@ -106,7 +106,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::TransferLiteralToDevice(
}
StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
- const tensorflow::gtl::ArraySlice<const Literal*> literals) {
+ const absl::Span<const Literal* const> literals) {
std::vector<ScopedShapedBuffer> buffers;
for (const Literal* literal : literals) {
CHECK(literal != nullptr);
@@ -118,16 +118,16 @@ StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
}
StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
- const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> literals) {
+ const absl::Span<const Literal> literals) {
std::vector<const Literal*> literal_pointers;
literal_pointers.reserve(literals.size());
for (const auto& literal : literals) {
- literal_pointers.push_back(literal.get());
+ literal_pointers.push_back(&literal);
}
return TransferLiteralsToDevice(literal_pointers);
}
-StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice(
+StatusOr<Literal> HloRunner::TransferLiteralFromDevice(
const ShapedBuffer& buffer) {
TF_ASSIGN_OR_RETURN(
auto stream, backend().BorrowStream(backend().default_stream_executor()));
@@ -135,10 +135,10 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice(
buffer);
}
-StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
+StatusOr<Literal> HloRunner::Execute(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<const Literal*> arguments,
- bool run_hlo_passes, ExecutionProfile* profile) {
+ const absl::Span<const Literal* const> arguments, bool run_hlo_passes,
+ ExecutionProfile* profile) {
TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
TransferLiteralsToDevice(arguments));
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
@@ -150,15 +150,15 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
return TransferLiteralFromDevice(result);
}
-StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
- std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> arguments,
- bool run_hlo_passes, ExecutionProfile* profile) {
+StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
+ const absl::Span<const Literal> arguments,
+ bool run_hlo_passes,
+ ExecutionProfile* profile) {
// Construct a vector of plain pointers for the arguments.
std::vector<const Literal*> argument_pointers;
argument_pointers.reserve(arguments.size());
for (const auto& argument : arguments) {
- argument_pointers.push_back(argument.get());
+ argument_pointers.push_back(&argument);
}
return Execute(
/*module=*/std::move(module),
@@ -169,8 +169,8 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- bool run_hlo_passes, ExecutionProfile* profile) {
+ const absl::Span<const ShapedBuffer* const> arguments, bool run_hlo_passes,
+ ExecutionProfile* profile) {
// Get service run options.
se::Stream stream(backend().default_stream_executor());
stream.Init();
@@ -190,8 +190,8 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<ScopedShapedBuffer> arguments,
- bool run_hlo_passes, ExecutionProfile* profile) {
+ const absl::Span<const ScopedShapedBuffer> arguments, bool run_hlo_passes,
+ ExecutionProfile* profile) {
std::vector<const ShapedBuffer*> argument_pointers;
argument_pointers.reserve(arguments.size());
for (const auto& argument : arguments) {
@@ -204,7 +204,7 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
/*profile=*/profile);
}
-StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
+StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
std::unique_ptr<HloModule> module,
const ReplicatedExecuteOptions& options) {
TF_ASSIGN_OR_RETURN(
@@ -226,14 +226,13 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
// no arguments.
std::vector<const ShapedBuffer*> argument_buffer_ptrs(
options.num_replicas * options.arguments.size() + 1);
- std::vector<tensorflow::gtl::ArraySlice<const ShapedBuffer*>>
- argument_buffer_slices;
+ std::vector<absl::Span<const ShapedBuffer* const>> argument_buffer_slices;
int64 index = 0;
for (int64 i = 0; i < options.num_replicas; ++i) {
int64 device = device_assignment(i, 0);
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
backend().stream_executor(device));
- streams.push_back(MakeUnique<se::Stream>(executor));
+ streams.push_back(absl::make_unique<se::Stream>(executor));
streams.back()->Init();
service_run_options.emplace_back(GetServiceRunOptionsForDevice(
device, streams.back().get(), &device_assignment));
@@ -260,7 +259,7 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
num_threads += options.num_replicas;
}
if (num_threads > 0) {
- pool = MakeUnique<tensorflow::thread::ThreadPool>(
+ pool = absl::make_unique<tensorflow::thread::ThreadPool>(
tensorflow::Env::Default(), "infeed_outfeed",
/*num_threads=*/num_threads);
}
@@ -291,9 +290,9 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
VLOG(1) << "Starting outfeed on device " << device;
for (int64 step = 1;
options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
- auto literal = MakeUnique<Literal>();
+ Literal literal;
TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
- executor, options.outfeed_shape, literal.get()));
+ executor, options.outfeed_shape, &literal));
if (options.outfeed_values != nullptr) {
options.outfeed_values->push_back(std::move(literal));
}
@@ -311,10 +310,10 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
argument_buffer_slices));
LOG(INFO) << "Replicated execution terminated";
- std::vector<std::unique_ptr<Literal>> exec_results;
+ std::vector<Literal> exec_results;
for (int64 i = 0; i < options.num_replicas; ++i) {
TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone());
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+ TF_ASSIGN_OR_RETURN(Literal literal,
backend().transfer_manager()->TransferLiteralFromDevice(
streams[i].get(), results[i]));
exec_results.push_back(std::move(literal));
diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h
index 65537f07f5..2e934bf66a 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.h
+++ b/tensorflow/compiler/xla/service/hlo_runner.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
@@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
@@ -72,7 +72,7 @@ class HloRunner {
// A pointer to a vector where the outfeed values will be stored. If
// nullptr, the values will be read and discarded.
- std::vector<std::unique_ptr<Literal>>* outfeed_values = nullptr;
+ std::vector<Literal>* outfeed_values = nullptr;
// Whether the HLO passes should be run on the input module. Usually
// saved modules are coming from after the HLO pass pipeline, so triggering
@@ -87,8 +87,7 @@ class HloRunner {
// Converts an HloModule from the given hlo textual IR string (in
// HloModule::ToString format).
static StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString(
- const tensorflow::StringPiece hlo_string,
- const DebugOptions& debug_options);
+ const absl::string_view hlo_string, const DebugOptions& debug_options);
// Reads the proto file in xla.HloProto format, creates and returns the
// HloModule.
@@ -105,43 +104,42 @@ class HloRunner {
// Transfers data between the host and device.
StatusOr<ScopedShapedBuffer> TransferLiteralToDevice(const Literal& literal);
StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
- const tensorflow::gtl::ArraySlice<const Literal*> literals);
+ const absl::Span<const Literal* const> literals);
StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
- const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> literals);
- StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
- const ShapedBuffer& buffer);
+ const absl::Span<const Literal> literals);
+ StatusOr<Literal> TransferLiteralFromDevice(const ShapedBuffer& buffer);
// Executes the given module with given literals as input and returns the
// result as a Literal.
//
// If run_hlo_passes is false, the module will be executed without Hlo
// optimization.
- StatusOr<std::unique_ptr<Literal>> Execute(
- std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<const Literal*> arguments,
- bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
+ StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
+ const absl::Span<const Literal* const> arguments,
+ bool run_hlo_passes = true,
+ ExecutionProfile* profile = nullptr);
- StatusOr<std::unique_ptr<Literal>> Execute(
- std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> arguments,
- bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
+ StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
+ const absl::Span<const Literal> arguments,
+ bool run_hlo_passes = true,
+ ExecutionProfile* profile = nullptr);
// As Execute(), but accepts and returns device buffers instead of host
// buffers.
StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const absl::Span<const ShapedBuffer* const> arguments,
bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<ScopedShapedBuffer> arguments,
+ const absl::Span<const ScopedShapedBuffer> arguments,
bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
// Executes a given HLO module into a set of replicas, and returns a map
// with the replica number as key, and the corresponding returned literal as
// value.
- StatusOr<std::vector<std::unique_ptr<Literal>>> ExecuteReplicated(
+ StatusOr<std::vector<Literal>> ExecuteReplicated(
std::unique_ptr<HloModule> module,
const ReplicatedExecuteOptions& options);
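The HloRunner migration above swaps tensorflow::gtl::ArraySlice for absl::Span and unique_ptr<Literal> for value-type Literal throughout the interface. A hedged sketch of a caller under the new signatures (module construction and runner setup elided; assumes LiteralUtil already returns Literal by value):

#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"

namespace xla {

// Illustrative: execute a two-parameter module on freshly built literals.
StatusOr<Literal> RunWithTwoScalars(HloRunner* runner,
                                    std::unique_ptr<HloModule> module) {
  Literal lhs = LiteralUtil::CreateR0<float>(1.0f);
  Literal rhs = LiteralUtil::CreateR0<float>(2.0f);
  // The braced list binds to absl::Span<const Literal* const>; the result
  // arrives as a move-only Literal rather than std::unique_ptr<Literal>.
  return runner->Execute(std::move(module), {&lhs, &rhs});
}

}  // namespace xla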
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc
new file mode 100644
index 0000000000..3fc5dbeb02
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_schedule.cc
@@ -0,0 +1,343 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
+
+#include <queue>
+#include <vector>
+
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace xla {
+
+/* static */ StatusOr<HloSchedule> HloSchedule::CreateFromProto(
+ const HloModule* module, const HloScheduleProto& proto) {
+ tensorflow::gtl::FlatMap<int64, const HloComputation*> id_to_computation;
+ for (const HloComputation* computation : module->computations()) {
+ id_to_computation[computation->unique_id()] = computation;
+ }
+
+ HloSchedule schedule(module);
+ for (const auto& id_sequence : proto.sequences()) {
+ int64 computation_id = id_sequence.first;
+
+ auto comp_it = id_to_computation.find(computation_id);
+ TF_RET_CHECK(comp_it != id_to_computation.end())
+ << "No computation exists in HLO module with id " << computation_id;
+ const HloComputation* computation = comp_it->second;
+
+ tensorflow::gtl::FlatMap<int64, const HloInstruction*> id_to_instruction;
+ for (const HloInstruction* instruction : computation->instructions()) {
+ id_to_instruction[instruction->unique_id()] = instruction;
+ }
+
+ HloInstructionSequence& sequence =
+ schedule.GetOrCreateSequence(computation);
+ for (const int64 instruction_id : id_sequence.second.instruction_ids()) {
+ auto instr_it = id_to_instruction.find(instruction_id);
+ TF_RET_CHECK(instr_it != id_to_instruction.end())
+ << "No instruction exists in HLO computation " << computation->name()
+ << " with id " << instruction_id;
+ sequence.push_back(instr_it->second);
+ }
+ }
+ TF_RETURN_IF_ERROR(schedule.Verify());
+ return std::move(schedule);
+}
+
+StatusOr<HloScheduleProto> HloSchedule::ToProto() const {
+ TF_RETURN_IF_ERROR(Verify());
+ HloScheduleProto proto;
+ for (const auto& id_sequence : sequences_) {
+ int64 computation_id = id_sequence.first;
+ const HloInstructionSequence& sequence = id_sequence.second;
+ HloScheduleProto::InstructionSequence& proto_sequence =
+ (*proto.mutable_sequences())[computation_id];
+ proto_sequence.mutable_instruction_ids()->Reserve(sequence.size());
+ for (const int64 id : sequence.ids()) {
+ proto_sequence.add_instruction_ids(id);
+ }
+ }
+ return std::move(proto);
+}
+
+void HloSchedule::set_sequence(
+ const HloComputation* computation,
+ absl::Span<const HloInstruction* const> sequence) {
+ set_sequence(computation, HloInstructionSequence(sequence));
+}
+
+void HloSchedule::set_sequence(const HloComputation* computation,
+ HloInstructionSequence sequence) {
+ CHECK(computation->parent() == module_);
+ sequences_[computation->unique_id()] = std::move(sequence);
+}
+
+HloInstructionSequence& HloSchedule::GetOrCreateSequence(
+ const HloComputation* computation) {
+ auto it = sequences_.find(computation->unique_id());
+ if (it == sequences_.end()) {
+ // No sequence found for computation. Create and return an empty one.
+ CHECK(computation->parent() == module_);
+ return sequences_[computation->unique_id()];
+ } else {
+ return it->second;
+ }
+}
+
+const HloInstructionSequence& HloSchedule::sequence(
+ const HloComputation* computation) const {
+ return sequences_.at(computation->unique_id());
+}
+
+Status HloSchedule::UpdateComputationSchedule(
+ const HloComputation* computation) {
+ // Map from unique ID to HloInstruction pointer for instructions in the
+ // computation.
+ tensorflow::gtl::FlatMap<int, const HloInstruction*> id_to_instruction;
+ for (const HloInstruction* instruction : computation->instructions()) {
+ InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction);
+ }
+
+ // Set of all HloInstructions in the schedule.
+ tensorflow::gtl::FlatSet<int> ids_in_schedule;
+ for (int id : sequences_.at(computation->unique_id()).ids()) {
+ InsertOrDie(&ids_in_schedule, id);
+ }
+
+ // Map from HloInstruction X to newly added instructions (instruction is in
+ // computation, but not in schedule) which use X. If an instruction is not in
+ // the map, then it has no users which are newly added instructions.
+ tensorflow::gtl::FlatMap<const HloInstruction*,
+ std::vector<const HloInstruction*>>
+ new_instruction_uses;
+
+ // For each newly added instruction, this is the count of the instruction's
+ // operands that have not yet been scheduled. When this value reaches zero,
+ // then the instruction may be placed in the schedule.
+ tensorflow::gtl::FlatMap<const HloInstruction*, int>
+ unscheduled_operand_count;
+
+ // Create a worklist of newly added instructions which are ready to be added
+ // to the schedule. Initialize worklist with those that have zero operands.
+ std::queue<const HloInstruction*> worklist;
+
+ for (const HloInstruction* instruction : computation->instructions()) {
+ if (ids_in_schedule.count(instruction->unique_id()) == 0) {
+ // This is a newly added instruction which is not in the schedule.
+ if (instruction->operands().empty()) {
+ worklist.push(instruction);
+ } else {
+ for (const HloInstruction* operand : instruction->operands()) {
+ new_instruction_uses[operand].push_back(instruction);
+ }
+ unscheduled_operand_count[instruction] = instruction->operand_count();
+ }
+ }
+ }
+
+ // Update the schedule with the newly added instructions, and remove any
+ // instructions no longer in the graph.
+ HloInstructionSequence new_sequence;
+
+ // Lambda which schedules all instructions on the worklist.
+ auto schedule_worklist = [&]() {
+ while (!worklist.empty()) {
+ const HloInstruction* instruction = worklist.front();
+ worklist.pop();
+ new_sequence.push_back(instruction);
+ std::vector<const HloInstruction*>* new_users =
+ tensorflow::gtl::FindOrNull(new_instruction_uses, instruction);
+ if (new_users != nullptr) {
+ // This just-scheduled instruction has users which are newly added to
+ // the module. Update the number of unscheduled operands and push the
+ // newly added instruction to the worklist if it is ready to
+ // schedule.
+ for (const HloInstruction* new_user : *new_users) {
+ unscheduled_operand_count.at(new_user)--;
+ CHECK_GE(unscheduled_operand_count.at(new_user), 0);
+ if (unscheduled_operand_count.at(new_user) == 0) {
+ worklist.push(new_user);
+ }
+ }
+ }
+ }
+ };
+
+ schedule_worklist();
+ for (int id : sequences_.at(computation->unique_id()).ids()) {
+ auto it = id_to_instruction.find(id);
+ if (it == id_to_instruction.end()) {
+ // This instruction in the schedule is no longer in the module. Do not add
+ // it to the new schedule.
+ continue;
+ }
+ worklist.push(it->second);
+ schedule_worklist();
+ }
+
+ set_sequence(computation, std::move(new_sequence));
+ return Status::OK();
+}
+
+Status HloSchedule::Update() {
+ // The schedule must contain a sequence for every non-fusion computation in
+ // the module, but can have sequences for computations which no longer exist
+ // (these are removed).
+ std::vector<HloComputation*> nonfusion_computations =
+ module_->MakeNonfusionComputations();
+ for (const HloComputation* computation : nonfusion_computations) {
+ TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1)
+ << "Computation " << computation->name() << " not in HloSchedule.";
+ }
+ if (sequences_.size() > nonfusion_computations.size()) {
+ // Schedule contains some computations which have been removed from the
+ // HloModule. Remove them from the schedule as well.
+ tensorflow::gtl::FlatSet<int64> nonfusion_computations_ids;
+ for (const HloComputation* computation : nonfusion_computations) {
+ nonfusion_computations_ids.insert(computation->unique_id());
+ }
+ for (auto it = sequences_.begin(); it != sequences_.end();) {
+ if (nonfusion_computations_ids.count(it->first) == 0) {
+ it = sequences_.erase(it);
+ } else {
+ it++;
+ }
+ }
+ }
+ CHECK_EQ(sequences_.size(), nonfusion_computations.size());
+
+ for (const HloComputation* computation : nonfusion_computations) {
+ TF_RETURN_IF_ERROR(UpdateComputationSchedule(computation));
+ }
+
+ TF_RETURN_IF_ERROR(Verify());
+ return Status::OK();
+}
+
+Status HloSchedule::Verify() const {
+ VLOG(2) << "VerifySchedule()";
+ XLA_VLOG_LINES(3, module_->ToString());
+ XLA_VLOG_LINES(2, ToString());
+
+ // Verify that the schedule contains exactly the same set of non-fusion
+ // computations as the module currently does.
+ std::vector<HloComputation*> nonfusion_computations =
+ module_->MakeNonfusionComputations();
+ TF_RET_CHECK(nonfusion_computations.size() == sequences_.size())
+ << "Schedule has " << sequences_.size() << " sequences, but module has "
+ << nonfusion_computations.size() << " non-fusion computations";
+ for (const HloComputation* computation : nonfusion_computations) {
+ TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1)
+ << "Computation " << computation->name()
+ << " missing from HLO schedule.";
+ }
+
+ // For each computation verify the set of instructions is the same and that
+ // each dependency and control edge is honored.
+ for (const HloComputation* computation : nonfusion_computations) {
+ tensorflow::gtl::FlatMap<const HloInstruction*, int> instruction_position;
+ int pos = 0;
+ for (const HloInstruction* instruction :
+ sequence(computation).instructions()) {
+ TF_RET_CHECK(instruction_position.insert({instruction, pos}).second)
+ << "Instruction " << instruction->name()
+ << " appears more than once in the schedule";
+ pos++;
+ }
+
+ TF_RET_CHECK(instruction_position.size() ==
+ computation->instruction_count());
+ for (const HloInstruction* instruction : computation->instructions()) {
+ TF_RET_CHECK(instruction_position.count(instruction) == 1)
+ << "Instruction " << instruction->name() << " is not in schedule";
+ }
+
+ for (const HloInstruction* instruction : computation->instructions()) {
+ for (const HloInstruction* operand : instruction->operands()) {
+ TF_RET_CHECK(instruction_position.at(operand) <
+ instruction_position.at(instruction))
+ << "Instruction " << instruction->name()
+ << " is not scheduled after its operand " << operand->name();
+ }
+
+ for (const HloInstruction* pred : instruction->control_predecessors()) {
+ TF_RET_CHECK(instruction_position.at(pred) <
+ instruction_position.at(instruction))
+ << "Instruction " << instruction->name()
+ << " is not scheduled after its control predecessor "
+ << pred->name();
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+namespace {
+
+// Returns the computation in the given module with the given unique ID. Returns
+// nullptr if no such computation exists.
+const HloComputation* IdToComputation(const HloModule* module, int64 id) {
+ for (const HloComputation* computation : module->computations()) {
+ if (computation->unique_id() == id) {
+ return computation;
+ }
+ }
+ return nullptr;
+}
+
+} // namespace
+
+string HloSchedule::ToString() const {
+ std::vector<string> pieces;
+
+ pieces.push_back("HloSchedule");
+ for (const auto& id_sequence : sequences_) {
+ const HloComputation* computation =
+ IdToComputation(module_, id_sequence.first);
+ if (computation == nullptr) {
+ // The computation is not in the module and may have been deleted, so it is
+ // not safe to dereference any HLO pointers. Just use the HLO unique IDs
+ // stored in this object.
+ pieces.push_back(
+ absl::StrFormat("computation with id %d (no longer in HLO module):",
+ id_sequence.first));
+ for (int id : id_sequence.second.ids()) {
+ pieces.push_back(absl::StrCat(" ", id));
+ }
+ } else {
+ pieces.push_back(absl::StrFormat("computation %s:", computation->name()));
+ for (const HloInstruction* instruction :
+ id_sequence.second.instructions()) {
+ pieces.push_back(absl::StrCat(" ", instruction->name()));
+ }
+ }
+ }
+ return absl::StrJoin(pieces, "\n");
+}
+
+std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule) {
+ out << schedule.ToString();
+ return out;
+}
+
+} // namespace xla
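Both serialization entry points defined above route through Verify(), so a schedule only round-trips if it is consistent with the module. A minimal round-trip sketch (the helper name is illustrative):

#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

// Illustrative: serialize a schedule and rebuild it against the same module.
Status RoundTripSchedule(const HloModule* module, const HloSchedule& schedule) {
  TF_ASSIGN_OR_RETURN(HloScheduleProto proto, schedule.ToProto());
  // Sequences are keyed by unique IDs, so reconstruction only requires that
  // the module still contain computations and instructions with those IDs.
  TF_ASSIGN_OR_RETURN(HloSchedule restored,
                      HloSchedule::CreateFromProto(module, proto));
  return restored.Verify();
}

}  // namespace xla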
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h
new file mode 100644
index 0000000000..270fe6039f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_schedule.h
@@ -0,0 +1,158 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
+
+#include <vector>
+
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
+#include "tensorflow/compiler/xla/status.h"
+
+namespace xla {
+
+class HloModule;
+
+// Class representing a sequence of HLO instructions such as the sequential
+// execution order of an HLO computation.
+class HloInstructionSequence {
+ public:
+ HloInstructionSequence() = default;
+ explicit HloInstructionSequence(
+ absl::Span<const HloInstruction* const> instructions) {
+ for (const HloInstruction* instruction : instructions) {
+ push_back(instruction);
+ }
+ }
+
+ // Adds the instruction to the end of the sequence.
+ void push_back(const HloInstruction* instruction) {
+ instruction_sequence_.push_back(instruction);
+ id_sequence_.push_back(instruction->unique_id());
+ }
+
+ // Clears the sequence of all instructions.
+ void clear() {
+ instruction_sequence_.clear();
+ id_sequence_.clear();
+ }
+
+ int64 size() const { return instruction_sequence_.size(); }
+
+ // Returns the sequence of HLO instructions.
+ const std::vector<const HloInstruction*>& instructions() const {
+ return instruction_sequence_;
+ }
+
+ // Returns the unique IDs of the instructions in the sequence (in order).
+ const std::vector<int>& ids() const { return id_sequence_; }
+
+ private:
+ // The sequence as HloInstructions.
+ std::vector<const HloInstruction*> instruction_sequence_;
+
+ // The sequence of HLO instructions, represented by their unique IDs. The
+ // sequence is stored as both HloInstructions and unique IDs because the
+ // sequence may be referenced after transformations to the HLO graph and HLO
+ // pointers can be invalidated or recycled in this process (see
+ // HloSchedule::Update).
+ std::vector<int> id_sequence_;
+};
+
+// A class representing a sequential schedule of instructions for an HLO
+// module. A complete HLO schedule contains an instruction sequence for every
+// non-fusion computation in the HLO module.
+class HloSchedule {
+ public:
+ explicit HloSchedule(const HloModule* module) : module_(module) {}
+
+ // (De)Serialize an HloSchedule to/from a HloScheduleProto.
+ static StatusOr<HloSchedule> CreateFromProto(const HloModule* module,
+ const HloScheduleProto& proto);
+ StatusOr<HloScheduleProto> ToProto() const;
+
+ // Returns a reference to the sequence for the given computation.
+ const HloInstructionSequence& sequence(
+ const HloComputation* computation) const;
+
+ // Returns the sequence for the given computation. An empty sequence is
+ // created if none exists for the computation.
+ HloInstructionSequence& GetOrCreateSequence(
+ const HloComputation* computation);
+
+ // Sets the sequence for the given computation to the given sequence.
+ void set_sequence(const HloComputation* computation,
+ absl::Span<const HloInstruction* const> sequence);
+ void set_sequence(const HloComputation* computation,
+ HloInstructionSequence sequence);
+
+ // Returns a map from HloComputation unique ID to instruction sequence. The
+ // map contains all sequences in the schedule.
+ const tensorflow::gtl::FlatMap<int64, HloInstructionSequence>& sequences()
+ const {
+ return sequences_;
+ }
+
+ // Returns true if the schedule has a sequence for the given computation.
+ bool is_computation_scheduled(const HloComputation* computation) const {
+ return sequences_.count(computation->unique_id()) == 1;
+ }
+
+ // Updates the schedule such that it is (again) a valid schedule for the
+ // module. This is used to update a schedule after the HLO module has been
+ // transformed in some way. In general, the only transformations to the module
+ // for which a schedule can be updated are the addition or removal of
+ // instructions and the removal of computations. Updating the schedule after
+ // new dependencies have been added between existing instructions is not
+ // supported and may result in an error status being returned.
+ //
+ // Instructions in the module which also exist in the given schedule will
+ // remain in the same order in the updated schedule. Instructions which exist
+ // in the module but not in the given schedule will be placed as early as
+ // possible in the updated schedule.
+ Status Update();
+
+ // Verifies that the given schedule is valid for the given module.
+ // Specifically, the schedule contains exactly the instructions in the
+ // non-fusion computations in the module and every dependency in the module is
+ // satisfied in the schedule.
+ Status Verify() const;
+
+ string ToString() const;
+
+ bool empty() const { return sequences_.empty(); }
+
+ const HloModule* module() const { return module_; }
+
+ private:
+ // Updates the instruction sequence for the given computation.
+ Status UpdateComputationSchedule(const HloComputation* computation);
+
+ const HloModule* module_;
+
+ // A map from computation unique ID to instruction sequence. Unique IDs are
+ // used rather than HloComputation pointers because HLO pointers are not
+ // unique across HLO transformations, since pointers may be recycled.
+ tensorflow::gtl::FlatMap<int64, HloInstructionSequence> sequences_;
+};
+
+std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
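Taken together, the intended lifecycle is: build a schedule, transform the module, then repair the schedule incrementally instead of rescheduling from scratch. A hedged sketch grounded in the tests that follow (ScheduleModule comes from hlo_memory_scheduler.h; the wrapper name is illustrative):

#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

// Illustrative: schedule, mutate via DCE, then incrementally repair.
Status ScheduleTransformRepair(HloModule* module) {
  TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                      ScheduleModule(*module, [](const BufferValue& buffer) {
                        return ShapeUtil::ByteSizeOf(buffer.shape());
                      }));
  // DCE only removes instructions, which Update() can repair around.
  HloDCE dce;
  TF_RETURN_IF_ERROR(dce.Run(module).status());
  // Surviving instructions keep their relative order; any newly added ones
  // would be placed as early as their operands allow.
  TF_RETURN_IF_ERROR(schedule.Update());
  return schedule.Verify();
}

}  // namespace xla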
diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc
new file mode 100644
index 0000000000..1424569ac1
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc
@@ -0,0 +1,341 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
+
+#include <memory>
+#include <string>
+
+#include "absl/algorithm/container.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+class HloScheduleTest : public HloTestBase {};
+
+TEST_F(HloScheduleTest, UpdateScheduleUnchangedModule) {
+ // Updating the schedule of an unchanged HLO module should not affect the
+ // schedule at all.
+ const string module_str = R"(
+HloModule UpdateScheduleUnchanged
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ c = f32[] constant(42.0)
+ sum = f32[] add(a, b)
+ neg = f32[] negate(c)
+ ROOT root = f32[] multiply(sum, neg)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ }));
+ const std::vector<const HloInstruction*>& entry_schedule =
+ schedule.sequence(module->entry_computation()).instructions();
+
+ EXPECT_EQ(entry_schedule.size(), 6);
+
+ TF_ASSERT_OK(schedule.Update());
+ TF_ASSERT_OK(schedule.Verify());
+
+ EXPECT_EQ(entry_schedule,
+ schedule.sequence(module->entry_computation()).instructions());
+}
+
+TEST_F(HloScheduleTest, UpdateScheduleWithNewInstructions) {
+ // Add some additional instructions to a module and verify the schedule can be
+ // updated.
+ const string module_str = R"(
+HloModule UpdateScheduleWithNewInstructions
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ c = f32[] constant(42.0)
+ sum = f32[] add(a, b)
+ neg = f32[] negate(c)
+ ROOT root = f32[] multiply(sum, neg)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ }));
+
+ HloComputation* entry = module->entry_computation();
+ const Shape shape = entry->root_instruction()->shape();
+ HloInstruction* constant = entry->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kSubtract, constant, entry->root_instruction()));
+ entry->set_root_instruction(sub);
+
+ auto in_schedule = [&](const HloInstruction* hlo) {
+ return absl::c_linear_search(schedule.sequence(entry).instructions(), hlo);
+ };
+
+ EXPECT_EQ(schedule.sequence(entry).size(), 6);
+ EXPECT_FALSE(in_schedule(constant));
+ EXPECT_FALSE(in_schedule(sub));
+
+ ASSERT_IS_NOT_OK(schedule.Verify());
+ TF_ASSERT_OK(schedule.Update());
+ TF_ASSERT_OK(schedule.Verify());
+
+ EXPECT_EQ(schedule.sequence(entry).size(), 8);
+ EXPECT_TRUE(in_schedule(constant));
+ EXPECT_TRUE(in_schedule(sub));
+}
+
+TEST_F(HloScheduleTest, UpdateScheduleWithAddedAndDeletedInstruction) {
+ // Add and delete some instructions from a module and verify that the schedule
+ // can be updated successfully.
+ const string module_str = R"(
+HloModule UpdateScheduleWithAddedAndDeletedInstruction
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ c = f32[] constant(42.0)
+ sum = f32[] add(a, b)
+ neg = f32[] negate(c)
+ ROOT root = f32[] multiply(sum, neg)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ }));
+
+ // Set the entry root to some expression containing just a parameter and a
+ // constant.
+ HloComputation* entry = module->entry_computation();
+ HloInstruction* constant = entry->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ HloInstruction* new_root = entry->AddInstruction(
+ HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract,
+ constant, entry->parameter_instruction(0)));
+ entry->set_root_instruction(new_root);
+
+ // DCE should remove everything but the parameters and the newly added code.
+ HloDCE dce;
+ TF_ASSERT_OK(dce.Run(module.get()).status());
+
+ EXPECT_EQ(schedule.sequence(entry).size(), 6);
+
+ ASSERT_IS_NOT_OK(schedule.Verify());
+ TF_ASSERT_OK(schedule.Update());
+ TF_ASSERT_OK(schedule.Verify());
+
+ EXPECT_EQ(schedule.sequence(entry).size(), 4);
+}
+
+TEST_F(HloScheduleTest, UpdateScheduleWithCompletelyReplacedModule) {
+ // Completely replace a module with an entirely new set of instructions and
+ // verify that the schedule can be updated successfully.
+ const string module_str = R"(
+HloModule UpdateScheduleWithCompletelyReplacedModule
+
+ENTRY main {
+ a = f32[] constant(42.0)
+ b = f32[] constant(123.0)
+ ROOT sum = f32[] add(a, b)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ }));
+
+ // Replace the entry computation with the negation of a constant.
+ HloComputation* entry = module->entry_computation();
+ HloInstruction* constant = entry->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kNegate, constant));
+ entry->set_root_instruction(new_root);
+
+ // DCE the old instructions.
+ HloDCE dce;
+ TF_ASSERT_OK(dce.Run(module.get()).status());
+
+ EXPECT_EQ(schedule.sequence(entry).size(), 3);
+
+ ASSERT_IS_NOT_OK(schedule.Verify());
+ TF_ASSERT_OK(schedule.Update());
+ TF_ASSERT_OK(schedule.Verify());
+
+ EXPECT_EQ(schedule.sequence(entry).size(), 2);
+}
+
+TEST_F(HloScheduleTest, UpdateScheduleWithMultipleComputations) {
+ // Create changes to more than one computation in an HLO module and verify
+ // that the schedule can be updated.
+ const string module_str = R"(
+HloModule UpdateScheduleWithMultipleComputations
+
+%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
+ %param.1 = (s32[], token[]) parameter(0)
+ %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
+ %constant.1 = s32[] constant(1)
+ %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
+ %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
+ %after-all = token[] after-all(token[] %get-tuple-element.2)
+ ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
+}
+
+%Cond (param: (s32[], token[])) -> pred[] {
+ %param = (s32[], token[]) parameter(0)
+ %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
+ %constant = s32[] constant(42)
+ ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
+}
+
+ENTRY %WhileLoop () -> s32[] {
+ %zero = s32[] constant(0)
+ %init_token = token[] after-all()
+ %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
+ %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
+ ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape(),
+ /*pointer_size=*/sizeof(void*));
+ }));
+
+ const HloInstruction* xla_while =
+ module->entry_computation()->root_instruction()->operand(0);
+ HloComputation* body = xla_while->while_body();
+ HloComputation* cond = xla_while->while_condition();
+
+ // Negate the root of the cond.
+ cond->set_root_instruction(cond->AddInstruction(
+ HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}),
+ HloOpcode::kNot, cond->root_instruction())));
+
+ // Replace the body with a computation which just passes through its
+ // parameter.
+ body->set_root_instruction(body->parameter_instruction(0));
+
+ // DCE the dead code in the body.
+ HloDCE dce;
+ TF_ASSERT_OK(dce.Run(module.get()).status());
+
+ EXPECT_EQ(schedule.sequence(body).size(), 7);
+ EXPECT_EQ(schedule.sequence(cond).size(), 4);
+
+ ASSERT_IS_NOT_OK(schedule.Verify());
+ TF_ASSERT_OK(schedule.Update());
+ TF_ASSERT_OK(schedule.Verify());
+
+ EXPECT_EQ(schedule.sequence(body).size(), 1);
+ EXPECT_EQ(schedule.sequence(cond).size(), 5);
+}
+
+TEST_F(HloScheduleTest, UpdateScheduleComputationRemoved) {
+ // Remove computations from a module and verify the schedule can be updated.
+ const string module_str = R"(
+HloModule UpdateScheduleWithMultipleComputations
+
+%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
+ %param.1 = (s32[], token[]) parameter(0)
+ %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
+ %constant.1 = s32[] constant(1)
+ %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
+ %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
+ %after-all = token[] after-all(token[] %get-tuple-element.2)
+ ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
+}
+
+%Cond (param: (s32[], token[])) -> pred[] {
+ %param = (s32[], token[]) parameter(0)
+ %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
+ %constant = s32[] constant(42)
+ ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
+}
+
+ENTRY %WhileLoop () -> s32[] {
+ %zero = s32[] constant(0)
+ %init_token = token[] after-all()
+ %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
+ %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
+ ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(
+ HloSchedule schedule,
+ ScheduleModule(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape(),
+ /*pointer_size=*/sizeof(void*));
+ }));
+
+ HloInstruction* xla_while =
+ module->entry_computation()->root_instruction()->mutable_operand(0);
+ HloInstruction* init = xla_while->mutable_operand(0);
+
+ // Replace the while with its init value. The conditional and body
+ // computations should then be dead.
+ TF_ASSERT_OK(xla_while->ReplaceAllUsesWith(init));
+
+ // DCE the dead code in the body.
+ HloDCE dce;
+ ASSERT_EQ(module->computation_count(), 3);
+ TF_ASSERT_OK(dce.Run(module.get()).status());
+ ASSERT_EQ(module->computation_count(), 1);
+
+ ASSERT_IS_NOT_OK(schedule.Verify());
+ TF_ASSERT_OK(schedule.Update());
+ TF_ASSERT_OK(schedule.Verify());
+}
+
+} // namespace
+} // namespace xla
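The tests above all exercise the same contract: build a schedule, mutate the module, observe Verify() fail, call Update(), and observe Verify() pass with deleted instructions dropped and new ones appended. A minimal standalone model of that contract (standard C++ only; the Node type and UpdateSchedule helper are hypothetical stand-ins, not the real HloSchedule API):

    #include <algorithm>
    #include <cassert>
    #include <set>
    #include <vector>

    struct Node {
      std::vector<const Node*> operands;
    };

    // Drops nodes that no longer exist and appends new nodes once all of
    // their operands are scheduled, preserving the relative order of
    // surviving nodes -- the behavior the tests above check via sequence
    // sizes and in_schedule().
    std::vector<const Node*> UpdateSchedule(
        std::vector<const Node*> schedule,
        const std::vector<const Node*>& live) {
      std::set<const Node*> live_set(live.begin(), live.end());
      schedule.erase(
          std::remove_if(schedule.begin(), schedule.end(),
                         [&](const Node* n) { return !live_set.count(n); }),
          schedule.end());
      std::set<const Node*> done(schedule.begin(), schedule.end());
      bool changed = true;
      while (changed) {
        changed = false;
        for (const Node* n : live) {
          if (done.count(n)) continue;
          bool ready = std::all_of(
              n->operands.begin(), n->operands.end(),
              [&](const Node* op) { return done.count(op) > 0; });
          if (ready) {
            schedule.push_back(n);
            done.insert(n);
            changed = true;
          }
        }
      }
      assert(done.size() == live.size());  // every live node got scheduled
      return schedule;
    }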
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 879fb3bbab..de7e6b53d4 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -15,13 +15,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrCat;
+using absl::StrCat;
+using absl::StrJoin;
HloSharding HloSharding::AssignDevice(int64 device_id) {
return HloSharding(device_id);
@@ -53,9 +54,8 @@ HloSharding HloSharding::Tuple(const ShapeTree<HloSharding>& sub_shardings) {
return HloSharding(flattened_list);
}
-HloSharding HloSharding::Tuple(
- const Shape& tuple_shape,
- tensorflow::gtl::ArraySlice<HloSharding> shardings) {
+HloSharding HloSharding::Tuple(const Shape& tuple_shape,
+ absl::Span<const HloSharding> shardings) {
CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape);
for (auto& sharding : shardings) {
CHECK(!sharding.IsTuple()) << sharding.ToString();
@@ -71,12 +71,9 @@ HloSharding HloSharding::SingleTuple(const Shape& tuple_shape,
const HloSharding& sharding) {
CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape);
CHECK(!sharding.IsTuple()) << sharding.ToString();
- int64 leaf_count = ShapeUtil::GetLeafCount(tuple_shape);
+ int64 leaf_count = RequiredLeaves(tuple_shape);
std::vector<HloSharding> flattened_list;
- flattened_list.reserve(leaf_count);
- for (int64 i = 0; i < leaf_count; ++i) {
- flattened_list.push_back(sharding);
- }
+ flattened_list.resize(leaf_count, sharding);
return HloSharding(flattened_list);
}
@@ -92,7 +89,7 @@ string HloSharding::ToString() const {
for (const HloSharding& element : tuple_elements_) {
parts.push_back(element.ToString());
}
- return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}");
+ return StrCat("{", absl::StrJoin(parts, ", "), "}");
}
if (replicated_) {
@@ -101,8 +98,8 @@ string HloSharding::ToString() const {
return StrCat(
"{maximal device=", static_cast<int64>(*tile_assignment_.begin()), "}");
} else {
- return StrCat("{devices=[", Join(tile_assignment_.dimensions(), ","), "]",
- Join(tile_assignment_, ","), "}");
+ return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","),
+ "]", StrJoin(tile_assignment_, ","), "}");
}
}
@@ -144,7 +141,7 @@ std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
CHECK(!maximal_);
CHECK(!IsTuple());
std::vector<int64> ret_index;
- tile_assignment_.Each([&](tensorflow::gtl::ArraySlice<int64> index, int64 d) {
+ tile_assignment_.Each([&](absl::Span<const int64> index, int64 d) {
if (d == device) {
ret_index = {index.begin(), index.end()};
}
@@ -153,8 +150,7 @@ std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
return ret_index;
}
-int64 HloSharding::DeviceForTileIndex(
- tensorflow::gtl::ArraySlice<int64> index) const {
+int64 HloSharding::DeviceForTileIndex(absl::Span<const int64> index) const {
CHECK(!replicated_);
CHECK(!IsTuple());
if (maximal_) {
@@ -244,16 +240,16 @@ StatusOr<HloSharding> HloSharding::GetTupleSharding(const Shape& shape) const {
return Tuple(ShapeTree<HloSharding>(shape, *this));
}
-tensorflow::gtl::optional<int64> HloSharding::UniqueDevice() const {
+absl::optional<int64> HloSharding::UniqueDevice() const {
if (IsTuple()) {
if (tuple_elements_.empty()) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
- tensorflow::gtl::optional<int64> unique_device;
+ absl::optional<int64> unique_device;
for (auto& tuple_sharding : tuple_elements_) {
auto device = tuple_sharding.UniqueDevice();
if (!device || (unique_device && *device != *unique_device)) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
unique_device = device;
}
@@ -262,7 +258,7 @@ tensorflow::gtl::optional<int64> HloSharding::UniqueDevice() const {
if (!replicated_ && maximal_) {
return static_cast<int64>(*tile_assignment_.begin());
}
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
int64 HloSharding::GetUniqueDevice() const {
@@ -321,7 +317,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
Status status = Status::OK();
std::set<int64> seen_cores;
tile_assignment_.Each(
- [&](tensorflow::gtl::ArraySlice<int64> indices, int32 core) {
+ [&](absl::Span<const int64> indices, int32 core) {
// Don't overwrite a bad status, so we report the first error.
if (status.ok()) {
if (core >= num_devices) {
@@ -431,29 +427,39 @@ Shape HloSharding::TileShape(const Shape& shape) const {
HloSharding HloSharding::GetSubSharding(const Shape& shape,
const ShapeIndex& index) const {
CHECK(IsTuple());
-
- Shape sub_shape = ShapeUtil::GetSubshape(shape, index);
- ShapeTree<HloSharding> sub_shape_tree(sub_shape, Replicate());
- sub_shape_tree.CopySubtreeFrom(GetAsShapeTree(shape), index, {});
- return ShapeUtil::IsTuple(sub_shape) ? Tuple(sub_shape_tree)
- : sub_shape_tree.element(ShapeIndex({}));
+ int64 sharding_index = 0;
+ const Shape* sub_shape = &shape;
+ for (int64 idx : index) {
+ for (int64 i = 0; i < idx; ++i) {
+ sharding_index +=
+ ShapeUtil::GetLeafCount(ShapeUtil::GetSubshape(*sub_shape, {i}));
+ }
+ sub_shape = &ShapeUtil::GetSubshape(*sub_shape, {idx});
+ }
+ if (ShapeUtil::IsTuple(*sub_shape)) {
+ auto begin_it = tuple_elements_.begin() + sharding_index;
+ std::vector<HloSharding> sub_shardings(
+ begin_it, begin_it + ShapeUtil::GetLeafCount(*sub_shape));
+ return HloSharding::Tuple(*sub_shape, sub_shardings);
+ } else {
+ return tuple_elements_[sharding_index];
+ }
}
-tensorflow::gtl::optional<HloSharding> HloSharding::ExtractSingleSharding()
- const {
+absl::optional<HloSharding> HloSharding::ExtractSingleSharding() const {
if (!IsTuple()) {
return *this;
}
for (int64 i = 1; i < tuple_elements_.size(); ++i) {
if (tuple_elements_[0] != tuple_elements_[i]) {
- return tensorflow::gtl::optional<HloSharding>();
+ return absl::nullopt;
}
}
return tuple_elements_.front();
}
size_t HloSharding::Hash() const {
- if (!tuple_) {
+ if (tuple_) {
size_t h = 0;
for (const auto& element : tuple_elements_) {
h = tensorflow::Hash64Combine(h, element.Hash());
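The rewritten GetSubSharding above avoids building a ShapeTree and instead computes a direct offset into the flattened pre-order list of leaf shardings. A self-contained model of that index arithmetic (the Shape type is a hypothetical stand-in; note the real code also reserves one entry for the root of an empty tuple, per the header comment below):

    #include <cstdint>
    #include <vector>

    struct Shape {
      std::vector<Shape> tuple;  // empty => leaf (array) shape in this model
      bool IsTuple() const { return !tuple.empty(); }
    };

    int64_t LeafCount(const Shape& s) {
      if (!s.IsTuple()) return 1;
      int64_t n = 0;
      for (const Shape& e : s.tuple) n += LeafCount(e);
      return n;
    }

    // Flattened pre-order index of the first leaf under shape[index...]:
    // the offset is the total leaf count of every sibling that precedes the
    // chosen element at each step down the path. A tuple sub-shape then owns
    // the slice [offset, offset + LeafCount(sub)) of the flattened list.
    int64_t LeafOffset(const Shape& shape, const std::vector<int64_t>& index) {
      int64_t offset = 0;
      const Shape* sub = &shape;
      for (int64_t idx : index) {
        for (int64_t i = 0; i < idx; ++i) offset += LeafCount(sub->tuple[i]);
        sub = &sub->tuple[idx];
      }
      return offset;
    }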
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index 894783e5d1..9775505f86 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -23,12 +23,12 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -66,7 +66,7 @@ class HloSharding {
// shardings must match the number of leaf nodes in tuple_shape. For
// empty tuples, the shardings array must have one element.
static HloSharding Tuple(const Shape& tuple_shape,
- tensorflow::gtl::ArraySlice<HloSharding> shardings);
+ absl::Span<const HloSharding> shardings);
// Creates a new sharding for a tuple type, with a single input sharding
// repeated on each leaf.
@@ -132,7 +132,7 @@ class HloSharding {
// Returns the device that should execute the given tile.
// It is an error to call this if is_replicated() is true.
// REQUIRES: !IsTuple()
- int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice<int64> index) const;
+ int64 DeviceForTileIndex(absl::Span<const int64> index) const;
// Given a device ID, returns the offset within the specified shape of the
// tile that should be executed on the given core. This returns the lower
@@ -151,7 +151,7 @@ class HloSharding {
// span a single device, the return value will be empty.
// In order for a sharding to span a single device, every leaf sharding must
// be maximal and not replicated, and the used device must match.
- tensorflow::gtl::optional<int64> UniqueDevice() const;
+ absl::optional<int64> UniqueDevice() const;
// Retrieves the unique device or fails with a CHECK.
int64 GetUniqueDevice() const;
@@ -182,7 +182,7 @@ class HloSharding {
// be returned. If it is a tuple, and all the tuple elements are common, the
// common element will be returned. Otherwise the optional will contain no
// value.
- tensorflow::gtl::optional<HloSharding> ExtractSingleSharding() const;
+ absl::optional<HloSharding> ExtractSingleSharding() const;
bool operator==(const HloSharding& other) const {
return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
@@ -260,9 +260,9 @@ class HloSharding {
bool maximal_;
bool tuple_;
Array<int64> tile_assignment_;
- // Only non-empty when tuple_ is true, but because empty tuples are allowed
- // may also be empty even then. This is a flattened list of all the leaf
- // shardings in a tuple shape, by pre-order walk (ShapeTree iterator order).
+ // Only non-empty when tuple_ is true. If a tuple is empty then one entry is
+ // present for the root. This is a flattened list of all the leaf shardings in
+ // a tuple shape, by pre-order walk (ShapeTree iterator order).
std::vector<HloSharding> tuple_elements_;
};
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
index a2c1d39d0d..e3f4a9852a 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -23,6 +24,23 @@ namespace xla {
namespace {
+// AssignmentKind and kUnassignedDevice are used during tuple domain sharding
+// propagation in order to distinguish among three cases:
+// kUnassigned: no assignment has occurred
+// kAssigned: at least one assignment has occurred
+// kConflict: no assignment has occurred because of conflicting propagations,
+// which occurs when multiple users of an instruction have different
+// shardings.
+enum class AssignmentKind { kUnassigned, kAssigned, kConflict };
+
+// kUnassignedDevice can only be assigned to tuple leaf shardings to indicate
+// absence of sharding information for that particular sub-sharding during
+// sharding propagation. It makes it possible to express tuple shardings with
+// partial information. At the end of the propagation, the sharding of
+// tuple-shaped instructions that still use kUnassignedDevice is cleared.
+// TODO(b/112883246): Centralized enum of reserved devices.
+constexpr int64 kUnassignedDevice = -2;
+
struct PassThrough {
PassThrough(HloInstruction* user, HloInstruction* operand)
: user(user), operand(operand) {}
@@ -117,13 +135,17 @@ Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain,
return Status::OK();
}
-std::unique_ptr<HloSharding> CloneShardingForDomain(
- const HloSharding& sharding) {
- auto single_sharding = sharding.ExtractSingleSharding();
+// For tuple shardings, if every element has the same sharding, treat the
+// tuple as a single-element sharding: each domain we insert can block some
+// optimizations, so we want to create as few domain separations as possible.
+std::shared_ptr<const HloSharding> CloneShardingForDomain(
+ std::shared_ptr<const HloSharding> sharding) {
+ auto single_sharding = sharding->ExtractSingleSharding();
if (!single_sharding) {
- return MakeUnique<HloSharding>(sharding);
+ return sharding;
}
- return MakeUnique<HloSharding>(*single_sharding);
+ return std::make_shared<const HloSharding>(*single_sharding);
}
Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain,
@@ -142,108 +164,174 @@ Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain,
return Status::OK();
}
-// Retrieves the sharding of a tuple shaped instruction in form of a ShapeTree.
-// If the instruction has no sharding, a ShapeTree with HloSharding::Replicate()
-// sharding will be returned.
-ShapeTree<HloSharding> GetTupleSharding(HloInstruction* tuple) {
- if (tuple->has_sharding()) {
- return tuple->sharding().GetAsShapeTree(tuple->shape());
+// Return the ShapeTree<HloSharding> of the user argument. The user argument
+// is assumed to be a user of the instruction argument.
+// If user is a tuple instruction, return the tuple subsharding corresponding to
+// the operand matching the instruction argument, because that is the
+// subsharding corresponding to instruction.
+ShapeTree<HloSharding> GetShardingTreeFromUser(
+ const HloInstruction& instruction, const HloInstruction& user) {
+ if (user.opcode() == HloOpcode::kTuple) {
+ return user.sharding()
+ .GetSubSharding(user.shape(), {user.operand_index(&instruction)})
+ .GetAsShapeTree(instruction.shape());
+ }
+ return user.sharding().GetAsShapeTree(user.shape());
+}
+
+// Assign rhs to lhs. If rhs is unassigned (assigned to kUnassignedDevice)
+// then no assignment is made. Therefore kUnassignedDevice is never propagated.
+// kConflict is returned if lhs is already assigned and rhs is assigned to a
+// different device.
+StatusOr<AssignmentKind> AssignLeafSharding(HloSharding* lhs,
+ const HloSharding& rhs) {
+ TF_RET_CHECK(!lhs->IsTuple() && !rhs.IsTuple());
+ if (rhs.UsesDevice(kUnassignedDevice)) {
+ return AssignmentKind::kUnassigned;
}
- return ShapeTree<HloSharding>(tuple->shape(), HloSharding::Replicate());
+ if (lhs->UsesDevice(kUnassignedDevice)) {
+ *lhs = rhs;
+ return AssignmentKind::kAssigned;
+ }
+ return lhs->UniqueDevice() != rhs.UniqueDevice()
+ ? AssignmentKind::kConflict
+ : AssignmentKind::kUnassigned;
}
-// Retrieves the sharding of operand, asked from a user instruction which is
-// within domain. If operand is a kDomain, it means that sharding argument is
-// the operand sharding, otherwise the operand's own sharding will be returned.
-const HloSharding* GetOperandSharding(const HloInstruction* operand,
+// Assigns the whole rhs tree to lhs_tree, starting at lhs_it.
+// In case of conflicting assignment AssignmentKind::kConflict is returned. In
+// this case lhs_tree is partially assigned, up to the conflicting leaf. It is
+// up to the caller to discard the partial assignment in case of conflict.
+StatusOr<AssignmentKind> AssignTreeSharding(
+ ShapeTree<HloSharding>* lhs_tree, ShapeTree<HloSharding>::iterator lhs_it,
+ const ShapeTree<HloSharding>& rhs_tree) {
+ AssignmentKind assigned = AssignmentKind::kUnassigned;
+ auto rhs_it = rhs_tree.begin();
+ for (; lhs_it != lhs_tree->end() && rhs_it != rhs_tree.end();
+ ++lhs_it, ++rhs_it) {
+ // TODO(b/112885211): Add ShapeTree::IsLeaf(const ShapeTreeIterator &it)
+ if (rhs_tree.IsLeaf(rhs_it->first)) {
+ TF_RET_CHECK(lhs_tree->IsLeaf(lhs_it->first));
+ TF_ASSIGN_OR_RETURN(AssignmentKind sub_assigned,
+ AssignLeafSharding(&lhs_it->second, rhs_it->second));
+ if (sub_assigned == AssignmentKind::kConflict) {
+ // In case of conflict we return conflict to the caller. At this point
+ // partial assignments to lhs_tree may have been made already. It is up
+ // to the caller to discard the partial assignment in case of conflict.
+ return AssignmentKind::kConflict;
+ } else if (sub_assigned == AssignmentKind::kAssigned) {
+ assigned = sub_assigned;
+ }
+ }
+ }
+ TF_RET_CHECK(rhs_it == rhs_tree.end());
+ return assigned;
+}
+
+StatusOr<bool> ApplyShardingFromUsers(HloInstruction* instruction,
const DomainMetadata::Domain& domain,
- const HloSharding& sharding) {
- // Here the user of operand is within the domain instruction set, and since it
- // is user of operand, we need to look into the enter_domains set. If this is
- // not a kDomain within the user domains set, then return the operand
- // sharding, if any.
- if (operand->opcode() != HloOpcode::kDomain ||
- domain.enter_domains.count(const_cast<HloInstruction*>(operand)) == 0) {
- return operand->has_sharding() ? &operand->sharding() : nullptr;
+ const HloSharding& domain_sharding) {
+ if (instruction->users().empty()) {
+ // No sharding from users, use domain_sharding, after checking
+ // compatibility.
+ TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape()) &&
+ ShapeUtil::GetLeafCount(instruction->shape()) ==
+ domain_sharding.tuple_elements().size());
+ instruction->set_sharding(domain_sharding);
+ return true;
+ }
+ AssignmentKind assigned = AssignmentKind::kUnassigned;
+ // The sharding_tree leaves are initialized to kUnassignedDevice. Only Tuple
+ // subshardings can result in a final sharding assignment containing
+ // kUnassignedDevice leaves, in case some tuple indexes are not used, or are
+ // used by users that don't have a sharding.
+ // Non-tuple shardings are either assigned to a real sharding, or are not
+ // assigned at all. As such they will never get assigned to kUnassignedDevice.
+  // In any case, kUnassignedDevice is never propagated; see the
+  // implementation of AssignLeafSharding.
+ ShapeTree<HloSharding> sharding_tree(
+ instruction->shape(), HloSharding::AssignDevice(kUnassignedDevice));
+ for (HloInstruction* user : instruction->users()) {
+ if (user->opcode() == HloOpcode::kDomain &&
+ domain.exit_domains.count(const_cast<HloInstruction*>(user)) > 0) {
+ // If a user is a domain and it is registered in the domain exits, then
+ // the instruction sharding is taken directly from the domain, and no
+ // further users need to be visited.
+ instruction->set_sharding(domain_sharding);
+ return true;
+ }
+ if (!user->has_sharding()) {
+ continue;
+ }
+ AssignmentKind sub_assigned = AssignmentKind::kUnassigned;
+ ShapeTree<HloSharding> user_sharding_tree =
+ GetShardingTreeFromUser(*instruction, *user);
+ if (ShapeUtil::IsTuple(instruction->shape())) {
+ // For tuple-shaped instructions collect individual tuple subshardings
+ // from the uses, and then combine them into the tuple sharding.
+ // If the user is a GTE its sharding concerns only the subtree of
+ // sharding_tree at index user->tuple_index, otherwise the whole
+ // sharding_tree is affected.
+ ShapeTree<HloSharding>::iterator sharding_tree_begin =
+ user->opcode() == HloOpcode::kGetTupleElement
+ ? sharding_tree.find({user->tuple_index()})
+ : sharding_tree.begin();
+ TF_ASSIGN_OR_RETURN(
+ sub_assigned, AssignTreeSharding(&sharding_tree, sharding_tree_begin,
+ user_sharding_tree));
+ } else {
+      // Non-tuple shape: assign the sharding common to the users.
+ TF_RET_CHECK(user_sharding_tree.leaf_count() == 1)
+ << "Expected non-tuple user sharding";
+ TF_ASSIGN_OR_RETURN(
+ sub_assigned,
+ AssignTreeSharding(&sharding_tree, sharding_tree.begin(),
+ user_sharding_tree));
+ }
+
+ if (sub_assigned == AssignmentKind::kConflict) {
+ // In case of conflict we don't assign any sharding.
+ return false;
+ } else if (sub_assigned == AssignmentKind::kAssigned) {
+ assigned = sub_assigned;
+ }
+ }
+
+ if (assigned == AssignmentKind::kAssigned) {
+ if (ShapeUtil::IsTuple(instruction->shape())) {
+ instruction->set_sharding(HloSharding::Tuple(sharding_tree));
+ } else {
+ TF_RET_CHECK(sharding_tree.leaf_count() == 1);
+ instruction->set_sharding(sharding_tree.leaf_begin()->second);
+ }
+ return true;
}
- // At this point operand is a kDomain of the currently processed domain, so we
- // can refer to sharding as the domain sharding.
- return &sharding;
+ return false;
}
// Tries to propagate the sharding information into the instructions that are
-// part of the domain, in a post order manner (operand propagate to user).
+// part of the domain, in a reverse post-order manner (user shardings
+// propagate to their operand instructions).
StatusOr<int64> ApplyDomainShardingPass(const DomainMetadata::Domain& domain,
- const HloSharding& sharding) {
+ const HloSharding& domain_sharding) {
int64 assigned = 0;
- for (HloInstruction* instruction : domain.instructions) {
+ // domain.instructions are ordered in a post-order manner. As we do
+ // user->operand propagation we process instructions in reverse order. In so
+ // doing we are guaranteed to process all users before their operands.
+ for (auto it = domain.instructions.rbegin(); it != domain.instructions.rend();
+ ++it) {
+ HloInstruction* instruction = *it;
if (instruction->has_sharding()) {
continue;
}
- if (instruction->opcode() == HloOpcode::kGetTupleElement) {
- HloInstruction* tuple = instruction->mutable_operand(0);
- const HloSharding* tuple_sharding =
- GetOperandSharding(tuple, domain, sharding);
- if (tuple_sharding != nullptr) {
- if (tuple_sharding->IsTuple()) {
- HloSharding sub_sharding = tuple_sharding->GetSubSharding(
- tuple->shape(), {instruction->tuple_index()});
- VLOG(4) << " " << instruction->name() << " to sharding "
- << sub_sharding;
- instruction->set_sharding(sub_sharding);
- } else {
- SetSingleSharding(instruction, *tuple_sharding);
- }
- ++assigned;
- }
- } else if (instruction->opcode() == HloOpcode::kTuple) {
- int64 tuple_assigned = 0;
- ShapeTree<HloSharding> shape_tree = GetTupleSharding(instruction);
- for (int64 i = 0; i < instruction->operand_count(); ++i) {
- const HloSharding* operand_sharding =
- GetOperandSharding(instruction->operand(i), domain, sharding);
- if (operand_sharding != nullptr) {
- HloSharding operand_subsharding = HloSharding::Replicate();
- if (operand_sharding == &sharding) {
- operand_subsharding =
- sharding.GetSubSharding(instruction->shape(), {i});
- operand_sharding = &operand_subsharding;
- }
- if (shape_tree.element({i}) != *operand_sharding) {
- *shape_tree.mutable_element({i}) = *operand_sharding;
- ++tuple_assigned;
- }
- }
- }
- if (tuple_assigned > 0) {
- HloSharding tuple_sharding = HloSharding::Tuple(shape_tree);
- VLOG(4) << " " << instruction->name() << " to sharding "
- << tuple_sharding;
- instruction->set_sharding(tuple_sharding);
- ++assigned;
- }
- } else {
- // If all the operand of the given instruction has the same single device
- // assignment, assign that device to this instruction as well.
- const HloSharding* common_sharding = nullptr;
- for (const HloInstruction* operand : instruction->operands()) {
- const HloSharding* operand_sharding =
- GetOperandSharding(operand, domain, sharding);
- if (operand_sharding != nullptr) {
- if (common_sharding != nullptr &&
- *common_sharding != *operand_sharding) {
- common_sharding = nullptr;
- break;
- }
- common_sharding = operand_sharding;
- }
- }
- if (common_sharding != nullptr) {
- VLOG(4) << " " << instruction->name() << " to sharding "
- << *common_sharding;
- instruction->set_sharding(*common_sharding);
- ++assigned;
- }
+ // Take the sharding from the users.
+ TF_ASSIGN_OR_RETURN(
+ bool instruction_assigned,
+ ApplyShardingFromUsers(instruction, domain, domain_sharding));
+ if (instruction_assigned) {
+ ++assigned;
+ VLOG(4) << " " << instruction->name() << " to sharding "
+ << instruction->sharding();
}
}
return assigned;
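A short standalone illustration of why the reverse iteration above is sound: a post-order listing emits every operand before its users, so walking the list backwards visits each user before any of its operands, which is exactly what user-to-operand sharding propagation needs (hypothetical Node type, standard C++ only):

    #include <cassert>
    #include <set>
    #include <vector>

    struct Node {
      std::vector<Node*> operands;
    };

    void PostOrder(Node* n, std::set<Node*>* seen, std::vector<Node*>* out) {
      if (!seen->insert(n).second) return;
      for (Node* op : n->operands) PostOrder(op, seen, out);
      out->push_back(n);  // operands are emitted before their users
    }

    int main() {
      Node a, b, c;  // c = f(a, b)
      c.operands = {&a, &b};
      std::set<Node*> seen;
      std::vector<Node*> order;
      PostOrder(&c, &seen, &order);
      // Reverse traversal: no operand is ever visited before its user.
      std::set<Node*> visited;
      for (auto it = order.rbegin(); it != order.rend(); ++it) {
        for (Node* op : (*it)->operands) assert(visited.count(op) == 0);
        visited.insert(*it);
      }
      return 0;
    }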
@@ -261,83 +349,40 @@ Status ApplyDomainSharding(const DomainMetadata::Domain& domain,
return ApplyDomainSingleSharding(domain, *single_sharding);
}
VLOG(1) << "Assigning non-trivial sharding " << sharding;
- for (;;) {
- TF_ASSIGN_OR_RETURN(int64 assigned,
- ApplyDomainShardingPass(domain, sharding));
- if (assigned == 0) {
- break;
- }
- }
+ TF_RETURN_IF_ERROR(ApplyDomainShardingPass(domain, sharding).status());
+
int64 unassigned = 0;
for (HloInstruction* instruction : domain.instructions) {
if (!instruction->has_sharding()) {
LOG(WARNING) << "Unassigned instruction: " << instruction->ToString();
++unassigned;
+ } else {
+      // Un-set the sharding of tuples whose sub-shardings are assigned to
+      // kUnassignedDevice. When in doubt, it is better to leave the entire
+      // tuple unassigned and let the device placer decide for it.
+ if (instruction->sharding().UsesDevice(kUnassignedDevice)) {
+ TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape()))
+ << "Only tuples can have kUnassignedDevice sub shardings";
+ instruction->clear_sharding();
+ }
}
}
// Should we error out if unassigned > 0?
return Status::OK();
}
-// Creates a kDomain instruction to be placed between instruction and operand.
-// The kDomain instruction will be created only if the sharding differ between
-// the instruction and the operand.
-std::unique_ptr<HloInstruction> CreateDomain(HloInstruction* instruction,
- HloInstruction* operand) {
- const HloSharding* instruction_sharding =
- instruction->has_sharding() ? &instruction->sharding() : nullptr;
- const HloSharding* operand_sharding =
- operand->has_sharding() ? &operand->sharding() : nullptr;
- // No need for domain if they both have no sharding.
- if (instruction_sharding == nullptr && operand_sharding == nullptr) {
- return nullptr;
- }
- // No need for domain if they match.
- if (instruction_sharding != nullptr && operand_sharding != nullptr &&
- ShardingMatches(*instruction_sharding, *operand_sharding)) {
- return nullptr;
- }
- std::unique_ptr<HloSharding> real_instruction_sharding;
- std::unique_ptr<HloSharding> real_operand_sharding;
- if (instruction_sharding != nullptr) {
- real_instruction_sharding = CloneShardingForDomain(*instruction_sharding);
- }
- if (operand_sharding != nullptr) {
- real_operand_sharding = CloneShardingForDomain(*operand_sharding);
- }
- VLOG(3) << "Creating domain:";
- VLOG(3) << " Instruction: " << instruction->name();
- VLOG(3) << " Operand: " << operand->name();
- VLOG(3) << " User side sharding: "
- << (real_instruction_sharding != nullptr
- ? real_instruction_sharding->ToString()
- : "None");
- VLOG(3) << " Operand side sharding: "
- << (real_operand_sharding != nullptr
- ? real_operand_sharding->ToString()
- : "None");
-
- std::unique_ptr<DomainMetadata> operand_side_metadata =
- MakeUnique<ShardingMetadata>(std::move(real_operand_sharding));
- std::unique_ptr<DomainMetadata> user_side_metadata =
- MakeUnique<ShardingMetadata>(std::move(real_instruction_sharding));
- return HloInstruction::CreateDomain(operand->shape(), operand,
- std::move(operand_side_metadata),
- std::move(user_side_metadata));
-}
-
-StatusOr<std::unique_ptr<HloSharding>> ExtractOriginalCommonSharding(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
+StatusOr<std::shared_ptr<const HloSharding>> ExtractOriginalCommonSharding(
+ absl::Span<HloInstruction* const> instructions) {
// If we are here, all the instructions being passed had the same sharding
// (or no sharding), by the means of the ShardingMatches() API.
// As such, no kDomain was inserted, and here we are asked to extract the
// original common sharding.
// All the instructions passed to this API are part of the same computation.
- const HloSharding* sharding = nullptr;
+ std::shared_ptr<const HloSharding> sharding;
for (HloInstruction* instruction : instructions) {
if (instruction->has_sharding()) {
if (sharding == nullptr) {
- sharding = &instruction->sharding();
+ sharding = instruction->sharding_ptr();
} else {
TF_RET_CHECK(ShardingMatches(*sharding, instruction->sharding()))
<< "Sharding " << *sharding << " does not match the one in "
@@ -346,10 +391,10 @@ StatusOr<std::unique_ptr<HloSharding>> ExtractOriginalCommonSharding(
}
}
if (sharding == nullptr) {
- return std::unique_ptr<HloSharding>();
+ return std::shared_ptr<const HloSharding>();
}
VLOG(4) << "Extracted sharding is " << *sharding;
- return CloneShardingForDomain(*sharding);
+ return CloneShardingForDomain(sharding);
}
} // namespace
@@ -357,9 +402,9 @@ StatusOr<std::unique_ptr<HloSharding>> ExtractOriginalCommonSharding(
std::unique_ptr<DomainMetadata> ShardingMetadata::Clone() const {
std::unique_ptr<HloSharding> sharding;
if (sharding_ != nullptr) {
- sharding = MakeUnique<HloSharding>(*sharding_);
+ sharding = absl::make_unique<HloSharding>(*sharding_);
}
- return MakeUnique<ShardingMetadata>(std::move(sharding));
+ return absl::make_unique<ShardingMetadata>(std::move(sharding));
}
bool ShardingMetadata::Matches(const DomainMetadata& other) const {
@@ -377,6 +422,13 @@ bool ShardingMetadata::Matches(const DomainMetadata& other) const {
: false;
}
+size_t ShardingMetadata::Hash() const {
+ if (sharding_ != nullptr) {
+ return sharding_->Hash();
+ }
+ return static_cast<size_t>(0x297814aaad196e6dULL);
+}
+
string ShardingMetadata::ToString() const {
return sharding_ != nullptr ? sharding_->ToString() : "{}";
}
@@ -403,7 +455,7 @@ Status ShardingMetadata::NormalizeShardingDomain(
TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding));
}
} else {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloSharding> sharding,
+ TF_ASSIGN_OR_RETURN(std::shared_ptr<const HloSharding> sharding,
ExtractOriginalCommonSharding(domain.instructions));
if (sharding != nullptr) {
VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString();
@@ -415,9 +467,75 @@ Status ShardingMetadata::NormalizeShardingDomain(
return Status::OK();
}
-std::unique_ptr<HloInstruction> CreateShardingDomain(
- HloInstruction* instruction, HloInstruction* operand) {
- return CreateDomain(instruction, operand);
+// Creates a kDomain instruction to be placed between instruction and operand.
+// The kDomain instruction will be created only if the shardings differ
+// between the instruction and the operand.
+HloInstruction* ShardingDomainCreator::operator()(HloInstruction* instruction,
+ HloInstruction* root,
+ HloInstruction* operand) {
+ auto instruction_sharding = instruction->sharding_ptr();
+ auto root_sharding = root->sharding_ptr();
+ // No need for domain if they both have no sharding.
+ if (instruction_sharding == nullptr && root_sharding == nullptr) {
+ return nullptr;
+ }
+ // No need for domain if they match.
+ if (instruction_sharding != nullptr && root_sharding != nullptr &&
+ ShardingMatches(*instruction_sharding, *root_sharding)) {
+ return nullptr;
+ }
+
+ if (instruction_sharding != nullptr) {
+ instruction_sharding = CloneShardingForDomain(instruction_sharding);
+ }
+ if (root_sharding != nullptr) {
+ root_sharding = CloneShardingForDomain(root_sharding);
+ }
+
+ auto it = domain_cse_map_.find({operand, instruction_sharding});
+ if (it != domain_cse_map_.end()) {
+ return it->second;
+ }
+
+ VLOG(3) << "Creating domain:";
+ VLOG(3) << " Instruction: " << instruction->name();
+ VLOG(3) << " Operand: " << operand->name();
+ VLOG(3) << " User side sharding: "
+ << (instruction_sharding != nullptr ? instruction_sharding->ToString()
+ : "None");
+ VLOG(3) << " Operand side sharding: "
+ << (root_sharding != nullptr ? root_sharding->ToString() : "None");
+
+ HloInstruction* domain =
+ operand->parent()->AddInstruction(HloInstruction::CreateDomain(
+ operand->shape(), operand,
+ absl::make_unique<ShardingMetadata>(root_sharding),
+ absl::make_unique<ShardingMetadata>(instruction_sharding)));
+ domain_cse_map_.emplace(DomainCseMapKey{operand, instruction_sharding},
+ domain);
+ return domain;
+}
+
+bool ShardingDomainCreator::DomainCseMapKey::operator==(
+ const ShardingDomainCreator::DomainCseMapKey& other) const {
+ if (instruction != other.instruction) {
+ return false;
+ }
+ if (sharding == nullptr && other.sharding == nullptr) {
+ return true;
+ }
+ if (sharding == nullptr || other.sharding == nullptr) {
+ return false;
+ }
+ return *sharding == *other.sharding;
+}
+
+size_t ShardingDomainCreator::DomainCseMapHasher::operator()(
+ const ShardingDomainCreator::DomainCseMapKey& key) const {
+ return tensorflow::Hash64Combine(
+ std::hash<const HloInstruction*>{}(key.instruction),
+ key.sharding ? key.sharding->Hash()
+ : static_cast<size_t>(0x297814aaad196e6dULL));
}
} // namespace xla
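The DomainCseMapKey equality and hasher added above can be modeled with standard containers alone. A hedged sketch (Sharding, Key, and the int mapped value are stand-ins for the XLA types; the sentinel constant mirrors the one used above for a null sharding):

    #include <cstddef>
    #include <cstdint>
    #include <functional>
    #include <memory>
    #include <string>
    #include <unordered_map>

    struct Sharding {
      std::string repr;  // stand-in for HloSharding's payload
      bool operator==(const Sharding& o) const { return repr == o.repr; }
      size_t Hash() const { return std::hash<std::string>{}(repr); }
    };

    struct Key {
      const void* instruction;
      std::shared_ptr<const Sharding> sharding;  // may be null
      bool operator==(const Key& o) const {
        if (instruction != o.instruction) return false;
        if (sharding == nullptr || o.sharding == nullptr)
          return sharding == o.sharding;  // equal only if both are null
        return *sharding == *o.sharding;  // otherwise compare by value
      }
    };

    struct KeyHasher {
      size_t operator()(const Key& k) const {
        size_t h = std::hash<const void*>{}(k.instruction);
        size_t s = k.sharding ? k.sharding->Hash()
                              : static_cast<size_t>(0x297814aaad196e6dULL);
        // Simple hash-combine, standing in for tensorflow::Hash64Combine.
        return h ^ (s + 0x9e3779b97f4a7c15ULL + (h << 6) + (h >> 2));
      }
    };

    // Two keys with the same operand and value-equal shardings collide here,
    // so the creator returns the previously built kDomain instead of a new one.
    using DomainCseMap = std::unordered_map<Key, int, KeyHasher>;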
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
index 5e01fc0e22..e3ae82a070 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
@@ -16,31 +16,33 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
// A DomainMetadata implementation that internally wraps a sharding attribute.
class ShardingMetadata : public DomainMetadata {
public:
- explicit ShardingMetadata(std::unique_ptr<HloSharding> sharding)
+ explicit ShardingMetadata(std::shared_ptr<const HloSharding> sharding)
: sharding_(std::move(sharding)) {}
std::unique_ptr<DomainMetadata> Clone() const override;
- tensorflow::StringPiece Kind() const override { return KindName(); }
+ absl::string_view Kind() const override { return KindName(); }
bool Matches(const DomainMetadata& other) const override;
+ size_t Hash() const override;
+
string ToString() const override;
const HloSharding* sharding() const { return sharding_.get(); }
- static tensorflow::StringPiece KindName() { return "sharding"; }
+ static absl::string_view KindName() { return "sharding"; }
static StatusOr<const ShardingMetadata*> ToShardingMetadata(
const DomainMetadata* metadata);
@@ -55,15 +57,33 @@ class ShardingMetadata : public DomainMetadata {
const DomainMetadata* metadata);
private:
- std::unique_ptr<HloSharding> sharding_;
+ std::shared_ptr<const HloSharding> sharding_;
};
-// Given an HLO graph edge between instruction and one of its operands, creates
-// a ShardingMetadata based kDomain instruction if the sharding between
-// instruction and operand changes. Returns nullptr if there is no need for a
-// domain separation.
-std::unique_ptr<HloInstruction> CreateShardingDomain(
- HloInstruction* instruction, HloInstruction* operand);
+// If the sharding between root and instruction changes then returns a
+// ShardingMetadata-based kDomain instruction that can be used to separate
+// operand and instruction.
+// Returns nullptr if there is no need for a domain separation.
+class ShardingDomainCreator {
+ public:
+ HloInstruction* operator()(HloInstruction* instruction, HloInstruction* root,
+ HloInstruction* operand);
+
+ private:
+  // Map from (instruction, user sharding) pairs to the domain instructions
+  // created for them, used to CSE identical domains.
+ struct DomainCseMapKey {
+ const HloInstruction* instruction;
+ std::shared_ptr<const HloSharding> sharding;
+
+ bool operator==(const DomainCseMapKey& other) const;
+ };
+ struct DomainCseMapHasher {
+ size_t operator()(const DomainCseMapKey& key) const;
+ };
+ std::unordered_map<DomainCseMapKey, HloInstruction*, DomainCseMapHasher>
+ domain_cse_map_;
+};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index 45fc300fca..80634677e7 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -29,8 +29,8 @@ limitations under the License.
namespace xla {
namespace {
-Array<int64> MakeArray(tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> contents) {
+Array<int64> MakeArray(absl::Span<const int64> dimensions,
+ absl::Span<const int64> contents) {
Array<int64> a(dimensions);
std::copy(contents.begin(), contents.end(), a.begin());
return a;
@@ -115,6 +115,13 @@ TEST_F(HloShardingTest, Tile) {
}
}
+// Tests that empty tuple is supported.
+TEST_F(HloShardingTest, EmptySingleTuple) {
+ HloSharding sharding = HloSharding::SingleTuple(ShapeUtil::MakeTupleShape({}),
+ HloSharding::AssignDevice(0));
+ EXPECT_TRUE(sharding.ExtractSingleSharding());
+}
+
TEST_F(HloShardingTest, NestedTuple) {
// nested_tuple_shape = (f32[], (f32[3]), f32[4, 6])
Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({
diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
index 2ef38821af..d1cf644f82 100644
--- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
+++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
@@ -24,7 +24,7 @@ namespace xla {
// one arbitrarily to use and delete the others.
class HloSubcomputationUnification : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "subcomputation-unification";
}
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
index b78bfa0cdf..4876533449 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -21,28 +23,25 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-
-using ::tensorflow::GraphDef;
-using ::tensorflow::NodeDef;
-using ::tensorflow::TensorShapeProto;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
-using ::tensorflow::str_util::Join;
namespace xla {
namespace hlo_graph_dumper {
namespace {
+using absl::StrAppend;
+using absl::StrCat;
+using tensorflow::GraphDef;
+using tensorflow::NodeDef;
+using tensorflow::TensorShapeProto;
+
string GetOpDefName(const HloInstruction* instruction) {
string name = StrCat("hlo-", HloOpcodeString(instruction->opcode()));
- tensorflow::str_util::TitlecaseString(&name, "-");
+ tensorflow::str_util::TitlecaseString(&name, "-"); // non-absl ok
name.erase(std::remove(name.begin(), name.end(), '-'), name.end());
if (instruction->opcode() == HloOpcode::kFusion) {
string fusion_name = ToString(instruction->fusion_kind());
- StrAppend(&name, tensorflow::StringPiece(fusion_name).substr(1));
+ StrAppend(&name, absl::string_view(fusion_name).substr(1));
}
return name;
}
@@ -166,7 +165,9 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction,
layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape());
} else {
layout_string = StrCat(
- "{", Join(LayoutUtil::MinorToMajor(instruction->shape()), ","), "}");
+ "{",
+ absl::StrJoin(LayoutUtil::MinorToMajor(instruction->shape()), ","),
+ "}");
}
attrs["layout"].set_s(layout_string);
}
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
index 1e2b31a1f2..6fd734a2b9 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
@@ -24,7 +24,7 @@ namespace {
using ::tensorflow::GraphDef;
-class HloTfGraphBuilderTest : public HloTestBase {
+class HloTfGraphBuilderTest : public HloVerifiedTestBase {
protected:
HloTfGraphBuilderTest() {}
HloTfGraphBuilder generator_;
diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc
index 7fd99fc930..773fc7d225 100644
--- a/tensorflow/compiler/xla/service/hlo_value.cc
+++ b/tensorflow/compiler/xla/service/hlo_value.cc
@@ -18,8 +18,10 @@ limitations under the License.
#include <algorithm>
#include <utility>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -30,16 +32,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
+using absl::StrAppend;
+using absl::StrCat;
const Shape& HloPosition::shape() const {
return ShapeUtil::GetSubshape(instruction->shape(), index);
@@ -150,7 +149,7 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index,
} // namespace
void HloValue::SetPositionsAndComputeUses(
- tensorflow::gtl::ArraySlice<HloPosition> positions) {
+ absl::Span<const HloPosition> positions) {
CHECK_EQ(positions_.size(), 1) << "SetPositions should only be called once.";
// The positions must be unique and should not contain the defining position
@@ -216,14 +215,14 @@ void HloValueSet::SortAndUniquifyValues() {
}
string HloValueSet::ToString() const {
- return StrCat("HloValueSet: ",
- Join(values_, ", ", [](string* result, const HloValue* value) {
- result->append(value->ToShortString());
- }));
+ return StrCat(
+ "HloValueSet: ",
+ absl::StrJoin(values_, ", ", [](string* result, const HloValue* value) {
+ result->append(value->ToShortString());
+ }));
}
-bool HloValueSet::AssignUnionOf(
- tensorflow::gtl::ArraySlice<const HloValueSet*> inputs) {
+bool HloValueSet::AssignUnionOf(absl::Span<const HloValueSet* const> inputs) {
HloValueSet union_set;
for (const HloValueSet* input : inputs) {
for (const HloValue* value : input->values()) {
@@ -254,7 +253,7 @@ std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set) {
}
bool InstructionValueSet::AssignUnionOf(
- tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
+ absl::Span<const InstructionValueSet* const> inputs) {
CHECK_GT(inputs.size(), 0);
for (int i = 1; i < inputs.size(); ++i) {
DCHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape()));
diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h
index a1151f65e0..b6670d409b 100644
--- a/tensorflow/compiler/xla/service/hlo_value.h
+++ b/tensorflow/compiler/xla/service/hlo_value.h
@@ -20,13 +20,13 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -108,8 +108,7 @@ class HloValue : public BufferValue {
// Sets the positions in the module at which the HloValue appears. Updates
// uses. Should be called once and only once. The defining position should not
// be included in 'positions' as this is set at construction time.
- void SetPositionsAndComputeUses(
- tensorflow::gtl::ArraySlice<HloPosition> positions);
+ void SetPositionsAndComputeUses(absl::Span<const HloPosition> positions);
// Returns whether this value is a phi value.
bool is_phi() const { return is_phi_; }
@@ -186,14 +185,14 @@ class HloValueSet {
public:
HloValueSet() = default;
- explicit HloValueSet(tensorflow::gtl::ArraySlice<const HloValue*> values)
+ explicit HloValueSet(absl::Span<const HloValue* const> values)
: values_(values.begin(), values.end()) {
SortAndUniquifyValues();
}
// Sets this value set to the union of the given value sets. Returns whether
// this value set changed.
- bool AssignUnionOf(tensorflow::gtl::ArraySlice<const HloValueSet*> inputs);
+ bool AssignUnionOf(absl::Span<const HloValueSet* const> inputs);
// Return the vector of HloValues in the set. Values in the vector are unique
// and stably sorted by value id.
@@ -247,8 +246,7 @@ class InstructionValueSet : public ShapeTree<HloValueSet> {
// Sets this value set to the union of the given value sets. Returns whether
// this value set changed.
- bool AssignUnionOf(
- tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs);
+ bool AssignUnionOf(absl::Span<const InstructionValueSet* const> inputs);
string ToString() const;
};
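For reference, AssignUnionOf's contract (the receiver becomes the sorted, de-duplicated union of the inputs, and the call reports whether the receiver changed) in a standalone form, with ints standing in for the HloValue pointers ordered by id:

    #include <algorithm>
    #include <vector>

    bool AssignUnionOf(std::vector<int>* dst,
                       const std::vector<const std::vector<int>*>& inputs) {
      std::vector<int> merged;
      for (const std::vector<int>* in : inputs)
        merged.insert(merged.end(), in->begin(), in->end());
      std::sort(merged.begin(), merged.end());
      merged.erase(std::unique(merged.begin(), merged.end()), merged.end());
      const bool changed = (merged != *dst);
      *dst = std::move(merged);
      return changed;
    }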
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 949a4d1110..50f39cbcb5 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -15,11 +15,13 @@ limitations under the License.
#include <set>
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -84,8 +86,8 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
const Shape expected,
ShapeInference::InferConvolveShape(
convolution->operand(0)->shape(), convolution->operand(1)->shape(),
- convolution->window(), convolution->convolution_dimension_numbers(),
- convolution->feature_group_count()));
+ convolution->feature_group_count(), convolution->window(),
+ convolution->convolution_dimension_numbers()));
return CheckShape(convolution, expected);
}
@@ -115,6 +117,11 @@ Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) {
ShapeInference::InferAllToAllTupleShape(operand_shapes));
}
+Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) {
+ return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape(
+ hlo->operand(0)->shape()));
+}
+
Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape(
reduce_precision->operand(0)->shape(),
@@ -122,39 +129,32 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
reduce_precision->mantissa_bits()));
}
-namespace {
-
-Status CheckIsTokenOperand(const HloInstruction* instruction,
- int64 operand_no) {
+Status ShapeVerifier::CheckIsTokenOperand(const HloInstruction* instruction,
+ int64 operand_no) {
const HloInstruction* token = instruction->operand(operand_no);
if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) {
return InternalError(
- "Expected operand %lld to be token-shaped, actual shape is "
+ "Expected operand %d to be token-shaped, actual shape is "
"%s:\n%s",
- operand_no, ShapeUtil::HumanString(token->shape()).c_str(),
- instruction->ToString().c_str());
+ operand_no, StringifyShape(token->shape()), instruction->ToString());
}
return Status::OK();
}
-Status CheckOperandAndParameter(const HloInstruction* instruction,
- int64 operand_number,
- const HloComputation* computation,
- int64 parameter_number) {
+Status ShapeVerifier::CheckOperandAndParameter(
+ const HloInstruction* instruction, int64 operand_number,
+ const HloComputation* computation, int64 parameter_number) {
const HloInstruction* operand = instruction->operand(operand_number);
const HloInstruction* parameter =
computation->parameter_instruction(parameter_number);
- if (!ShapeUtil::Compatible(operand->shape(), parameter->shape())) {
+ if (!ShapesSame(operand->shape(), parameter->shape())) {
return InternalError("Operand %s shape does not match parameter's %s in %s",
- operand->ToString().c_str(),
- parameter->ToString().c_str(),
- instruction->ToString().c_str());
+ operand->ToString(), parameter->ToString(),
+ instruction->ToString());
}
return Status::OK();
}
-} // namespace
-
Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0));
@@ -171,22 +171,16 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) {
// Outfeed has a separate shape field for the value which is outfed to the
// host. The shape of the instruction itself is always a token.
- if (!ShapeUtil::Compatible(outfeed->outfeed_shape(),
- outfeed->operand(0)->shape())) {
+ if (!ShapesSame(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) {
return InternalError(
- "Expected outfeed shape to be compatible with operand's shape %s, "
+ "Expected outfeed shape to be equal to operand's shape %s, "
"actual shape is %s:\n%s",
- ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(),
- ShapeUtil::HumanString(outfeed->outfeed_shape()).c_str(),
- outfeed->ToString().c_str());
+ StringifyShape(outfeed->operand(0)->shape()),
+ StringifyShape(outfeed->outfeed_shape()), outfeed->ToString());
}
return CheckShape(outfeed, ShapeUtil::MakeTokenShape());
}
-Status ShapeVerifier::HandleHostCompute(HloInstruction*) {
- return Status::OK();
-}
-
bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0,
const Shape& shape_1,
const Shape& result_shape) {
@@ -200,7 +194,7 @@ bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0,
Status ShapeVerifier::HandleRng(HloInstruction* instruction) {
if (instruction->operand_count() != 2) {
return InternalError("Expected two operands for Rng instruction: %s",
- instruction->ToString().c_str());
+ instruction->ToString());
}
const Shape& shape_0 = instruction->operand(0)->shape();
@@ -208,14 +202,14 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) {
if (!ShapeUtil::IsScalar(shape_0) || !ShapeUtil::IsScalar(shape_1)) {
return InternalError(
"Expected scalar types for the two operands of Rng instruction: %s",
- instruction->ToString().c_str());
+ instruction->ToString());
}
if (!HasCompatibleElementTypes(shape_0, shape_1, instruction->shape())) {
return InternalError(
"Expected compatible element types for the result and the two operands"
" of Rng instruction: %s",
- instruction->ToString().c_str());
+ instruction->ToString());
}
PrimitiveType element_type = shape_0.element_type();
@@ -228,7 +222,7 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) {
"Element type not supported."
" Expected element to be of floating point type, integral type or"
" predicate type for RngUniform: %s",
- instruction->ToString().c_str());
+ instruction->ToString());
}
break;
@@ -237,13 +231,13 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) {
return InternalError(
"Element type not supported."
" Expected element to be FloatingPointType for RngNormal: %s",
- instruction->ToString().c_str());
+ instruction->ToString());
}
break;
default:
return InternalError(
"Invalid Rng distribution %s",
- RandomDistribution_Name(instruction->random_distribution()).c_str());
+ RandomDistribution_Name(instruction->random_distribution()));
}
return Status::OK();
@@ -262,8 +256,8 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) {
return InternalError(
"Expected sort to have to have the same dimensions for the keys and "
"the values. Keys shape is: %s\n, Values shape is: %s",
- ShapeUtil::HumanString(sort->operand(0)->shape()).c_str(),
- ShapeUtil::HumanString(sort->operand(1)->shape()).c_str());
+ StringifyShape(sort->operand(0)->shape()),
+ StringifyShape(sort->operand(1)->shape()));
}
return CheckVariadicShape(sort);
}
@@ -272,10 +266,18 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
return CheckShape(constant, constant->literal().shape());
}
-Status ShapeVerifier::HandleIota(HloInstruction* iota) {
- return ShapeUtil::Rank(iota->shape()) == 1
- ? Status::OK()
- : InternalError("Iota only supports arrays of rank 1.");
+Status ShapeVerifier::HandleIota(HloInstruction* instruction) {
+ auto* iota = Cast<HloIotaInstruction>(instruction);
+ const int64 rank = ShapeUtil::Rank(iota->shape());
+ if (rank == 0) {
+ return InternalError("Iota does not support scalars.");
+ }
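+ // The iota dimension indexes into the result shape, so it must lie in
+ // [0, rank). E.g. (hypothetical) for iota = s32[4,8] iota() with
+ // iota_dimension=1, each row is filled with 0,1,...,7.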
+ int64 iota_dimension = iota->iota_dimension();
+ if (iota_dimension >= rank) {
+ return InternalError(
+ "The iota dimension cannot go beyond the operation rank.");
+ }
+ return Status::OK();
}
Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) {
@@ -286,14 +288,13 @@ Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) {
}
Status ShapeVerifier::HandleReduce(HloInstruction* reduce) {
- if (!ShapeUtil::IsArray(reduce->shape())) {
- return InvalidArgument("Variadic reduce is not supported.");
+ std::vector<const Shape*> operand_shapes;
+ for (const HloInstruction* operand : reduce->operands()) {
+ operand_shapes.push_back(&operand->shape());
}
- return CheckShape(
- reduce,
- ShapeInference::InferReduceShape(
- {&reduce->operand(0)->shape(), &reduce->operand(1)->shape()},
- reduce->dimensions(), reduce->to_apply()->ComputeProgramShape()));
+ return CheckShape(reduce, ShapeInference::InferReduceShape(
+ operand_shapes, reduce->dimensions(),
+ reduce->to_apply()->ComputeProgramShape()));
}
Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) {
@@ -337,7 +338,18 @@ Status ShapeVerifier::HandleParameter(HloInstruction* hlo) {
return Status::OK();
}
-Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); }
+Status ShapeVerifier::HandleFusion(HloInstruction* fusion) {
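+ // A fusion's fused parameters mirror its operands one-to-one, so e.g. a
+ // (hypothetical) fusion over operands (f32[2], f32[3]) must have fused
+ // parameters of shapes f32[2] and f32[3], compared via ShapesSame below.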
+ for (HloInstruction* fused_param : fusion->fused_parameters()) {
+ int64 param_no = fused_param->parameter_number();
+ if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) {
+ return InternalError(
+ "Shape mismatch between parameter number %d and its operand in "
+ "%s.",
+ param_no, fusion->ToString());
+ }
+ }
+ return Status::OK();
+}
Status ShapeVerifier::HandleCall(HloInstruction* call) {
for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) {
@@ -419,12 +431,11 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0));
const Shape& conditional_shape =
xla_while->while_condition()->root_instruction()->shape();
- if (!ShapeUtil::Compatible(conditional_shape,
- ShapeUtil::MakeShape(PRED, {}))) {
+ if (!ShapesSame(conditional_shape, ShapeUtil::MakeShape(PRED, {}))) {
return InternalError(
"Conditional computation shape does not lead to a scalar predicate "
"shape: %s",
- ShapeUtil::HumanString(conditional_shape).c_str());
+ StringifyShape(conditional_shape));
}
// The shape of kWhile should match the shape of the body computation it
// calls.
@@ -555,7 +566,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
return InternalError(
"Seen floating point types of different precisions in "
"%s, but mixed precision is disallowed.",
- instruction->ToString().c_str());
+ instruction->ToString());
}
return Status::OK();
}));
@@ -572,7 +583,7 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) {
gather,
ShapeInference::InferGatherShape(
gather->operand(0)->shape(), gather->operand(1)->shape(),
- gather->gather_dimension_numbers(), gather->gather_window_bounds()));
+ gather->gather_dimension_numbers(), gather->gather_slice_sizes()));
}
Status ShapeVerifier::HandleScatter(HloInstruction* scatter) {
@@ -602,53 +613,51 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
}
// Check if the output shape matches the expected shape.
- bool compatible;
+ //
// We treat BF16 and F32 as compatible types if mixed precision is allowed,
// but only when the instruction defines the BF16/F32 buffer.
- switch (instruction->opcode()) {
- case HloOpcode::kTupleSelect:
- // TupleSelect only defines the top-level buffer, which in this case is
- // the tuple, so we cannot allow mixed precision.
- compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape);
- break;
- case HloOpcode::kGetTupleElement:
- case HloOpcode::kTuple:
- // Tuple and GetTupleElement do not define BF16/F32 buffers, so mixed
- // precision is disallowed.
- case HloOpcode::kConstant:
- case HloOpcode::kBitcast:
- case HloOpcode::kBitcastConvert:
- case HloOpcode::kCall:
- case HloOpcode::kConditional:
- case HloOpcode::kConvert:
- case HloOpcode::kCustomCall:
- case HloOpcode::kInfeed:
- case HloOpcode::kOutfeed:
- case HloOpcode::kParameter:
- case HloOpcode::kRecv:
- case HloOpcode::kRecvDone:
- case HloOpcode::kSend:
- case HloOpcode::kSendDone:
- case HloOpcode::kWhile:
- // The above opcodes should match the expected shapes exactly.
- compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape);
- break;
- default:
- if (allow_mixed_precision_) {
- compatible = ShapeUtil::CompatibleIgnoringFpPrecision(
- instruction->shape(), inferred_shape);
- } else {
- compatible =
- ShapeUtil::Compatible(instruction->shape(), inferred_shape);
- }
- }
- if (!compatible) {
+ bool equal = [&] {
+ switch (instruction->opcode()) {
+ // The opcodes below can't have implicit layout conversions, nor can they
+ // implicitly transform f32 -> bf16. Fundamentally these are either
+ // reinterpreting existing data (e.g. kBitcast) or shuffling data around
+ // without modifying it (e.g. kGetTupleElement, kTupleSelect).
+ case HloOpcode::kBitcast:
+ case HloOpcode::kCall:
+ case HloOpcode::kConditional:
+ case HloOpcode::kConstant:
+ case HloOpcode::kCustomCall:
+ case HloOpcode::kGetTupleElement:
+ case HloOpcode::kInfeed:
+ case HloOpcode::kOutfeed:
+ case HloOpcode::kParameter:
+ case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
+ case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
+ case HloOpcode::kTuple:
+ case HloOpcode::kTupleSelect:
+ case HloOpcode::kWhile:
+ return ShapesSame(instruction->shape(), inferred_shape);
+
+ // We allow arbitrary layout and f32->bf16 transformations on all other
+ // instructions, although this may be made more strict pending discussion
+ // in b/112709536.
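+ // E.g. (hypothetical) an instruction shape of bf16[4] is accepted
+ // against an inferred shape of f32[4] when allow_mixed_precision_ is
+ // set, but rejected otherwise.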
+ default:
+ if (allow_mixed_precision_) {
+ return ShapeUtil::CompatibleIgnoringFpPrecision(instruction->shape(),
+ inferred_shape);
+ } else {
+ return ShapeUtil::Compatible(instruction->shape(), inferred_shape);
+ }
+ }
+ }();
+ if (!equal) {
return InternalError(
- "Expected instruction to have shape compatible with %s, actual "
+ "Expected instruction to have shape equal to %s, actual "
"shape is %s:\n%s",
- ShapeUtil::HumanString(inferred_shape).c_str(),
- ShapeUtil::HumanString(instruction->shape()).c_str(),
- instruction->ToString().c_str());
+ StringifyShape(inferred_shape), StringifyShape(instruction->shape()),
+ instruction->ToString());
}
return Status::OK();
}
@@ -690,12 +699,11 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) {
instruction->opcode(), instruction->operands()));
}
-string ComputationsToString(
- tensorflow::gtl::ArraySlice<HloComputation*> computations) {
- return tensorflow::str_util::Join(
- computations, ",", [](string* s, const HloComputation* computation) {
- s->append(computation->name());
- });
+string ComputationsToString(absl::Span<HloComputation* const> computations) {
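+ // E.g. (hypothetical) computations named "a", "b", "c" join to "a,b,c".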
+ return absl::StrJoin(computations, ",",
+ [](string* s, const HloComputation* computation) {
+ s->append(computation->name());
+ });
}
// Verifies various invariants about the structure of the HLO:
@@ -713,23 +721,23 @@ Status VerifyHloStructure(HloModule* module) {
for (const HloComputation* computation : module->computations()) {
if (computation->parent() == nullptr) {
return InternalError("Computation %s has a null parent pointer",
- computation->name().c_str());
+ computation->name());
}
if (computation->parent() != module) {
return InternalError(
"Computation %s parent() does not point to parent module",
- computation->name().c_str());
+ computation->name());
}
for (const HloInstruction* instruction : computation->instructions()) {
if (instruction->parent() == nullptr) {
return InternalError("Instruction %s has a null parent pointer",
- instruction->name().c_str());
+ instruction->name());
}
if (instruction->parent() != computation) {
return InternalError(
"Instruction %s parent() does not point to parent computation",
- instruction->name().c_str());
+ instruction->name());
}
}
}
@@ -746,9 +754,8 @@ Status VerifyHloStructure(HloModule* module) {
return InternalError(
"Operand %d (%s) of instruction %s is in a different "
"computation: %s vs %s",
- i, operand->name().c_str(), instruction->name().c_str(),
- operand->parent()->name().c_str(),
- instruction->parent()->name().c_str());
+ i, operand->name(), instruction->name(),
+ operand->parent()->name(), instruction->parent()->name());
}
}
}
@@ -764,7 +771,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
"Instruction of fused computation does not match expected "
"instruction "
"%s.",
- fusion->ToString().c_str());
+ fusion->ToString());
}
// Fused root instruction and fused parameters must all be owned by the
@@ -778,7 +785,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
if (fused_root == instruction) {
if (root_owned) {
return InternalError("Root appears more than once in %s.",
- fusion->ToString().c_str());
+ fusion->ToString());
}
root_owned = true;
}
@@ -786,7 +793,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
if (fused_parameters[i] == instruction) {
if (parameter_owned[i]) {
return InternalError("Parameter appears more than once in %s.",
- fusion->ToString().c_str());
+ fusion->ToString());
}
parameter_owned[i] = true;
}
@@ -794,20 +801,19 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
}
if (!root_owned) {
return InternalError("Root not found in computation of %s.",
- fusion->ToString().c_str());
+ fusion->ToString());
}
// Make sure all the parameter_owned entries are set
for (int i = 0; i < parameter_owned.size(); i++) {
if (!parameter_owned[i]) {
return InternalError("Parameter %d not found in computation of %s.", i,
- fusion->ToString().c_str());
+ fusion->ToString());
}
}
// Fused root must have no users.
if (fused_root->user_count() != 0) {
- return InternalError("Root of %s may not have users.",
- fusion->ToString().c_str());
+ return InternalError("Root of %s may not have users.", fusion->ToString());
}
// All uses of fused instructions must be in the fusion computation, and
@@ -817,54 +823,46 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
if (instruction != fused_root) {
if (instruction->user_count() == 0) {
return InternalError("Non-root instruction %s in %s must have users.",
- instruction->ToString().c_str(),
- fusion->ToString().c_str());
+ instruction->ToString(), fusion->ToString());
}
for (auto& user : instruction->users()) {
if (fused_computation != user->parent()) {
return InternalError(
"Non-root instruction %s in %s may not have external users.",
- instruction->ToString().c_str(), fusion->ToString().c_str());
+ instruction->ToString(), fusion->ToString());
}
}
}
}
// Fused parameter instructions must be numbered contiguously and match up
- // (shapes compatible) with their respective operand.
+ // (shapes equal) with their respective operand.
CHECK_EQ(fusion->operands().size(), fused_parameters.size());
std::vector<bool> parameter_numbers(fused_parameters.size(), false);
for (auto fused_param : fused_parameters) {
int64 param_no = fused_param->parameter_number();
if (param_no < 0) {
- return InternalError("Unexpected negative parameter number %lld in %s.",
- param_no, fusion->ToString().c_str());
+ return InternalError("Unexpected negative parameter number %d in %s.",
+ param_no, fusion->ToString());
}
if (param_no >= fused_parameters.size()) {
return InternalError(
- "Unexpected parameter number %lld in %s: higher then number of "
+ "Unexpected parameter number %d in %s: higher then number of "
"parameters %lu.",
- param_no, fusion->ToString().c_str(), fused_parameters.size());
+ param_no, fusion->ToString(), fused_parameters.size());
}
if (parameter_numbers[param_no]) {
return InternalError(
- "Did not expect parameter number %lld more than once in %s.",
- param_no, fusion->ToString().c_str());
+ "Did not expect parameter number %d more than once in %s.", param_no,
+ fusion->ToString());
}
parameter_numbers[param_no] = true;
- if (!ShapeUtil::Compatible(fused_param->shape(),
- fusion->operand(param_no)->shape())) {
- return InternalError(
- "Shape mismatch between parameter number %lld and its operand in "
- "%s.",
- param_no, fusion->ToString().c_str());
- }
}
// Make sure all the parameter_numbers entries were seen.
for (int i = 0; i < parameter_numbers.size(); i++) {
if (!parameter_numbers[i]) {
return InternalError("Did not see parameter number %d in %s.", i,
- fusion->ToString().c_str());
+ fusion->ToString());
}
}
@@ -879,18 +877,18 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) {
auto* while_body = instruction->while_body();
if (while_cond->num_parameters() != 1) {
return FailedPrecondition(
- "While condition must have exactly 1 parameter; had %lld : %s",
- while_cond->num_parameters(), while_cond->ToString().c_str());
+ "While condition must have exactly 1 parameter; had %d : %s",
+ while_cond->num_parameters(), while_cond->ToString());
}
if (while_body->num_parameters() != 1) {
return FailedPrecondition(
- "While body must have exactly 1 parameter; had %lld : %s",
- while_body->num_parameters(), while_body->ToString().c_str());
+ "While body must have exactly 1 parameter; had %d : %s",
+ while_body->num_parameters(), while_body->ToString());
}
if (instruction->operand_count() != 1) {
return FailedPrecondition(
- "While loop must have exactly one operand; had %lld : %s",
- instruction->operand_count(), instruction->ToString().c_str());
+ "While loop must have exactly one operand; had %d : %s",
+ instruction->operand_count(), instruction->ToString());
}
return Status::OK();
}
@@ -898,16 +896,14 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) {
Status HloVerifier::CheckConditionalInstruction(HloInstruction* instruction) {
if (instruction->true_computation()->num_parameters() != 1) {
return FailedPrecondition(
- "True computation %s of %s must have 1 parameter insted of %lld",
- instruction->true_computation()->name().c_str(),
- instruction->ToString().c_str(),
+ "True computation %s of %s must have 1 parameter insted of %d",
+ instruction->true_computation()->name(), instruction->ToString(),
instruction->true_computation()->num_parameters());
}
if (instruction->false_computation()->num_parameters() != 1) {
return FailedPrecondition(
- "False computation %s of %s must have 1 parameter insted of %lld",
- instruction->false_computation()->name().c_str(),
- instruction->ToString().c_str(),
+ "False computation %s of %s must have 1 parameter insted of %d",
+ instruction->false_computation()->name(), instruction->ToString(),
instruction->false_computation()->num_parameters());
}
return Status::OK();
@@ -920,11 +916,11 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) {
if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) {
return FailedPrecondition(
"Implicit broadcast is not allowed in HLO."
- "Found non-compatible shapes for instruction %s.\n"
+ "Found different shapes for instruction %s.\n"
"output: %s\noperand: %s\n",
- HloOpcodeString(instruction->opcode()).c_str(),
- ShapeUtil::HumanString(out_shape).c_str(),
- ShapeUtil::HumanString(operand_shape).c_str());
+ HloOpcodeString(instruction->opcode()),
+ ShapeUtil::HumanString(out_shape),
+ ShapeUtil::HumanString(operand_shape));
}
}
return Status::OK();
@@ -955,7 +951,7 @@ Status VerifyEntryAndExitShapes(const HloModule& module) {
if (ShapeContainsToken(param->shape())) {
return InternalError(
"Entry parameter %d is or contains a token shape: %s", i,
- ShapeUtil::HumanString(param->shape()).c_str());
+ ShapeUtil::HumanString(param->shape()));
}
}
return Status::OK();
@@ -967,9 +963,9 @@ Status CheckSameChannel(const HloInstruction* instr1,
if (instr1->channel_id() != instr2->channel_id()) {
return InternalError(
"Expected to have the same channel id, actual channel ids are: %s "
- "(%lld), %s (%lld)",
- instr1->ToString().c_str(), instr1->channel_id(),
- instr2->ToString().c_str(), instr2->channel_id());
+ "(%d), %s (%d)",
+ instr1->ToString(), instr1->channel_id(), instr2->ToString(),
+ instr2->channel_id());
}
return Status::OK();
}
@@ -990,7 +986,7 @@ Status CheckSameIsHostTransfer(const HloInstruction* instr1,
"Expected instructions to have the same is-host-transfer property: "
"%s, "
"%s ",
- instr1->ToString().c_str(), instr2->ToString().c_str());
+ instr1->ToString(), instr2->ToString());
}
return Status::OK();
}
@@ -1007,12 +1003,12 @@ Status VerifySendsAndRecvs(const HloModule& module) {
host_channels.insert({sendrecv->channel_id(), sendrecv});
if (!it_inserted.second) {
return FailedPrecondition(
- "Channel %lld is used for multiple host send/recv instructions: "
+ "Channel %d is used for multiple host send/recv instructions: "
"%s "
"and "
"%s",
- sendrecv->channel_id(), sendrecv->ToString().c_str(),
- it_inserted.first->second->ToString().c_str());
+ sendrecv->channel_id(), sendrecv->ToString(),
+ it_inserted.first->second->ToString());
}
}
@@ -1071,9 +1067,9 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RET_CHECK(instruction->parent() == computation);
if (instruction->opcode() == HloOpcode::kFusion) {
TF_RETURN_IF_ERROR(CheckFusionInstruction(instruction));
- TF_RET_CHECK(
- ContainersEqual(instruction->called_computations(),
- {instruction->fused_instructions_computation()}))
+ TF_RET_CHECK(instruction->called_computations() ==
+ absl::Span<HloComputation* const>(
+ {instruction->fused_instructions_computation()}))
<< "Fusion HLO calls computations other than the "
"fused_instructions_computation: "
<< instruction->ToString()
@@ -1127,6 +1123,11 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module));
+ // If the module has a schedule, it must be valid.
+ if (module->has_schedule()) {
+ TF_RETURN_IF_ERROR(module->schedule().Verify());
+ }
+
return false;
}
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index c942fab08e..42e3027bf1 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
namespace xla {
@@ -27,9 +28,9 @@ namespace xla {
// TODO(b/26024837): Check output shape for all instruction types.
class ShapeVerifier : public DfsHloVisitor {
public:
- explicit ShapeVerifier() : allow_mixed_precision_(false) {}
- explicit ShapeVerifier(bool allow_mixed_precision)
- : allow_mixed_precision_(allow_mixed_precision) {}
+ explicit ShapeVerifier(bool layout_sensitive, bool allow_mixed_precision)
+ : layout_sensitive_(layout_sensitive),
+ allow_mixed_precision_(allow_mixed_precision) {}
Status HandleElementwiseUnary(HloInstruction* hlo) override;
Status HandleElementwiseBinary(HloInstruction* hlo) override;
@@ -46,6 +47,7 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleFft(HloInstruction* fft) override;
Status HandleCrossReplicaSum(HloInstruction* crs) override;
Status HandleAllToAll(HloInstruction* hlo) override;
+ Status HandleCollectivePermute(HloInstruction* hlo) override;
Status HandleReducePrecision(HloInstruction* reduce_precision) override;
Status HandleInfeed(HloInstruction*) override;
Status HandleOutfeed(HloInstruction*) override;
@@ -63,7 +65,6 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleFusion(HloInstruction*) override;
Status HandleCall(HloInstruction* call) override;
Status HandleCustomCall(HloInstruction*) override;
- Status HandleHostCompute(HloInstruction*) override;
Status HandleSlice(HloInstruction* slice) override;
Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
Status HandleDynamicUpdateSlice(
@@ -106,13 +107,42 @@ class ShapeVerifier : public DfsHloVisitor {
Status CheckVariadicShape(const HloInstruction* instruction);
private:
- // Return true if the shapes of the two operands have the same element type,
- // and the result shape either has the same element type as the operand
- // shapes or mixed precision is allowed and the result shape and the operand
- // shapes have floating point element types.
+ // Helpers that switch on layout_sensitive_.
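+ //
+ // E.g. (hypothetical shapes) f32[2,3]{1,0} vs. f32[2,3]{0,1}: ShapesSame
+ // returns false when layout_sensitive_ is set, since ShapeUtil::Equal
+ // compares layouts, but true otherwise, since ShapeUtil::Compatible
+ // ignores them.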
+ bool ShapesSame(const Shape& a, const Shape& b) {
+ return layout_sensitive_ ? ShapeUtil::Equal(a, b)
+ : ShapeUtil::Compatible(a, b);
+ }
+ bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b) {
+ return layout_sensitive_ ? ShapeUtil::EqualIgnoringFpPrecision(a, b)
+ : ShapeUtil::CompatibleIgnoringFpPrecision(a, b);
+ }
+ string StringifyShape(const Shape& s) {
+ return layout_sensitive_ ? ShapeUtil::HumanStringWithLayout(s)
+ : ShapeUtil::HumanString(s);
+ }
+
+ // Checks that the given operand of the given instruction is of type TOKEN.
+ Status CheckIsTokenOperand(const HloInstruction* instruction,
+ int64 operand_no);
+
+ // Checks that the shape of the given operand of the given instruction matches
+ // the given parameter of the given computation.
+ Status CheckOperandAndParameter(const HloInstruction* instruction,
+ int64 operand_number,
+ const HloComputation* computation,
+ int64 parameter_number);
+
+ // Returns true if the shapes of the two operands have the same element type,
+ // and the result shape either has the same element type as the operand shapes
+ // or mixed precision is allowed and the result shape and the operand shapes
+ // have floating point element types.
bool HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1,
const Shape& result_shape);
+ // If the verifier is layout-sensitive, shapes must be equal to what's
+ // expected. Otherwise, the shapes must simply be compatible.
+ bool layout_sensitive_;
+
// Whether the inputs and output of an instruction can contain both F32s and
// BF16s. Tuples that include both F32s and BF16s are allowed regardless of
// this flag.
@@ -125,14 +155,10 @@ class HloVerifier : public HloPassInterface {
public:
using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;
- // Uses standard shape inference.
- explicit HloVerifier()
- : shape_verifier_factory_(
- [] { return MakeUnique<ShapeVerifier>(false); }) {}
-
- explicit HloVerifier(bool allow_mixed_precision)
- : shape_verifier_factory_([allow_mixed_precision] {
- return MakeUnique<ShapeVerifier>(allow_mixed_precision);
+ explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision)
+ : shape_verifier_factory_([layout_sensitive, allow_mixed_precision] {
+ return absl::make_unique<ShapeVerifier>(layout_sensitive,
+ allow_mixed_precision);
}) {}
// Uses custom shape verification.
@@ -140,10 +166,9 @@ class HloVerifier : public HloPassInterface {
: shape_verifier_factory_(std::move(shape_verifier_factory)) {}
~HloVerifier() override = default;
- tensorflow::StringPiece name() const override { return "verifier"; }
+ absl::string_view name() const override { return "verifier"; }
- // Note: always returns false (no instructions are ever modified by this
- // pass).
+ // Never returns true; no instructions are ever modified by this pass.
StatusOr<bool> Run(HloModule* module) override;
private:
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index d764964f3c..8f0423bb1c 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -34,16 +34,20 @@ namespace {
using ::testing::HasSubstr;
+// This class cannot be converted to use HloVerifiedTestBase. It explicitly
+// uses HloTestBase to create and test malformed HLOs.
class HloVerifierTest : public HloTestBase {
public:
HloVerifierTest()
- : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/false) {}
+ : HloTestBase(/*verifier_layout_sensitive=*/false,
+ /*allow_mixed_precision_in_hlo_verifier=*/false) {}
};
class HloVerifierTestAllowMixedPrecision : public HloTestBase {
public:
HloVerifierTestAllowMixedPrecision()
- : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/true) {}
+ : HloTestBase(/*verifier_layout_sensitive=*/false,
+ /*allow_mixed_precision_in_hlo_verifier=*/true) {}
};
TEST_F(HloVerifierTest, NullInstructionParent) {
@@ -275,5 +279,84 @@ TEST_F(HloVerifierTest, RngElementTypeNotSupported) {
EXPECT_THAT(status.error_message(), HasSubstr("Element type not supported"));
}
+TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) {
+ // This testcase can't be written using textual HLO, because it doesn't parse
+ // negative interior padding. That's probably a feature. :)
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {100}), "param"));
+ PaddingConfig padding_config;
+ padding_config.add_dimensions()->set_interior_padding(-1);
+ builder.AddInstruction(HloInstruction::CreatePad(
+ ShapeUtil::MakeShape(F32, {100}), param,
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(F32))),
+ padding_config));
+
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("Interior padding cannot be negative"));
+}
+
+TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) {
+ // This testcase can't be written using textual HLO, because it doesn't parse
+ // negative interior padding. That's probably a feature. :)
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {100}), "param"));
+ PaddingConfig padding_config;
+ padding_config.add_dimensions()->set_interior_padding(-1);
+ builder.AddInstruction(HloInstruction::CreatePad(
+ ShapeUtil::MakeShape(F32, {100}), param,
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone())),
+ padding_config));
+
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
+ HasSubstr("Interior padding cannot be negative"));
+}
+
+// Simple module containing a convolution as the root.
+static const char* const kConvHloString = R"(
+HloModule module
+ENTRY entry_computation {
+ param0 = f16[128,128,56,56] parameter(0)
+ param1 = f16[3,3,128,128] parameter(1)
+ zero_f16 = f16[] constant(0)
+ ROOT conv = f16[128,128,28,28] convolution(param0, param1),
+ window={size=3x3 stride=2x2}, dim_labels=bf01_01io->bf01
+})";
+
+TEST_F(HloVerifierTest, ConvNegativeWindowDilationNotAllowed) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kConvHloString));
+ auto* conv = module->entry_computation()->root_instruction();
+ Window w = conv->window();
+ w.mutable_dimensions(0)->set_window_dilation(-1);
+ conv->set_window(w);
+
+ EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
+ HasSubstr("non-positive window dilation factor"));
+}
+
+TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kConvHloString));
+ auto* conv = module->entry_computation()->root_instruction();
+ Window w = conv->window();
+ w.mutable_dimensions(0)->set_base_dilation(-1);
+ conv->set_window(w);
+
+ EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
+ HasSubstr("non-positive base area dilation factor"));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc
index bb5b40a8a8..e76b93107c 100644
--- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc
+++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc
@@ -14,27 +14,27 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/human_readable_profile_builder.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/metric_table_report.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
-using tensorflow::strings::Appendf;
+using absl::StrAppend;
+using absl::StrAppendFormat;
+using absl::StrCat;
+using absl::StrFormat;
using tensorflow::strings::HumanReadableElapsedTime;
using tensorflow::strings::HumanReadableNumBytes;
-using tensorflow::strings::Printf;
-using tensorflow::strings::StrAppend;
-using tensorflow::strings::StrCat;
string HumanReadableProfileBuilder::ToString() const {
string s;
- Appendf(&s, "Execution profile for %s: (%s @ f_nom)\n",
- computation_name_.c_str(),
- HumanReadableElapsedTime(CyclesToSeconds(total_cycles_)).c_str());
+ StrAppendFormat(&s, "Execution profile for %s: (%s @ f_nom)\n",
+ computation_name_,
+ HumanReadableElapsedTime(CyclesToSeconds(total_cycles_)));
int64 cumulative_cycles = 0;
auto print_op = [&](const OpInfo& op, bool is_total = false) {
@@ -56,7 +56,7 @@ string HumanReadableProfileBuilder::ToString() const {
if (op.bytes_accessed > op.cycles) {
bytes_per_cycle = StrCat(HumanReadableNumBytes(bpc), "/cycle");
} else {
- bytes_per_cycle = Printf("%.3fB/cycle", bpc);
+ bytes_per_cycle = StrFormat("%.3fB/cycle", bpc);
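+ // (absl::StrFormat returns the formatted string directly, e.g.
+ // StrFormat("%.3fB/cycle", 2.0) yields "2.000B/cycle".)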
}
}
@@ -77,27 +77,24 @@ string HumanReadableProfileBuilder::ToString() const {
// columns in the output.
cycles_percent_str = "100.% 100Σ";
} else {
- cycles_percent_str =
- Printf("%5.2f%% %2.0fΣ", cycles_percent, cumulative_cycles_percent);
+ cycles_percent_str = StrFormat("%5.2f%% %2.0fΣ", cycles_percent,
+ cumulative_cycles_percent);
}
double nsecs = op.cycles / clock_rate_ghz_;
- Appendf(
+ StrAppendFormat(
&s,
- "%15lld cycles (%s) :: %12.1f usec %22s :: %18s :: %18s :: %14s :: "
+ "%15d cycles (%s) :: %12.1f usec %22s :: %18s :: %18s :: %14s :: "
"%16s :: %s\n",
- op.cycles, cycles_percent_str.c_str(), CyclesToMicroseconds(op.cycles),
+ op.cycles, cycles_percent_str, CyclesToMicroseconds(op.cycles),
op.optimal_seconds < 0
? ""
- : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(),
- op.flop_count <= 0
- ? ""
- : HumanReadableNumFlops(op.flop_count, nsecs).c_str(),
+ : StrFormat("(%12.1f optimal)", op.optimal_seconds * 1e6),
+ op.flop_count <= 0 ? "" : HumanReadableNumFlops(op.flop_count, nsecs),
op.transcendental_count <= 0
? ""
- : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs)
- .c_str(),
- bytes_per_sec.c_str(), bytes_per_cycle.c_str(), op.name.c_str());
+ : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs),
+ bytes_per_sec, bytes_per_cycle, op.name);
};
float optimal_seconds_sum = 0.0;
diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.h b/tensorflow/compiler/xla/service/human_readable_profile_builder.h
index 6f56c3aa82..925111fa1f 100644
--- a/tensorflow/compiler/xla/service/human_readable_profile_builder.h
+++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.h
@@ -18,8 +18,8 @@ limitations under the License.
#include <vector>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -29,10 +29,10 @@ namespace xla {
// computation, suitable for consumption by humans.
class HumanReadableProfileBuilder {
public:
- explicit HumanReadableProfileBuilder(tensorflow::StringPiece computation_name,
+ explicit HumanReadableProfileBuilder(absl::string_view computation_name,
int64 total_cycles,
double clock_rate_ghz)
- : computation_name_(std::string(computation_name)),
+ : computation_name_(computation_name),
total_cycles_(total_cycles),
clock_rate_ghz_(clock_rate_ghz) {
CHECK_GE(clock_rate_ghz, 1e-9);
@@ -43,15 +43,13 @@ class HumanReadableProfileBuilder {
// Adds an operation to the profile. If you don't know the number of
// floating-point ops or bytes touched by the op, or if you don't know how
// fast it would run optimally, pass -1 for that param.
- void AddOp(tensorflow::StringPiece op_name,
- tensorflow::StringPiece short_name,
- tensorflow::StringPiece category, int64 cycles, int64 flop_count,
+ void AddOp(absl::string_view op_name, absl::string_view short_name,
+ absl::string_view category, int64 cycles, int64 flop_count,
int64 transcendental_count, int64 bytes_accessed,
float optimal_seconds) {
- op_infos_.push_back({std::string(op_name), std::string(short_name),
- std::string(category), cycles, flop_count,
- transcendental_count, bytes_accessed,
- optimal_seconds});
+ op_infos_.push_back({string(op_name), string(short_name), string(category),
+ cycles, flop_count, transcendental_count,
+ bytes_accessed, optimal_seconds});
}
// Gets the human-readable profile.
diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
index aa325dc8a3..85bb4a8b24 100644
--- a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
+++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
@@ -30,7 +30,7 @@ class ImplicitBroadcastRemover : public HloPassInterface {
ImplicitBroadcastRemover() {}
~ImplicitBroadcastRemover() override {}
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "implicit-broadcast-remover";
}
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 3531b7223f..06f0e1ed25 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -14,13 +14,16 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
+
+#include "absl/algorithm/container.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
-#include "tensorflow/core/lib/gtl/optional.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace gtl = ::tensorflow::gtl;
@@ -31,32 +34,29 @@ using UnknownArray = Analysis::UnknownArray;
using ConstantArray = Analysis::ConstantArray;
using ReshapedArray = Analysis::ReshapedArray;
using ScalarIndexedArray = Analysis::ScalarIndexedArray;
-using tensorflow::gtl::ArraySlice;
-using tensorflow::str_util::Join;
+using absl::StrJoin;
} // namespace
string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) {
switch (root->kind()) {
case Array::kUnknown: {
auto* unknown_tensor = root->as<UnknownArray>();
- return tensorflow::strings::StrCat("%",
- unknown_tensor->instruction().name());
+ return absl::StrCat("%", unknown_tensor->instruction().name());
}
case Array::kConstant: {
if (print_constants) {
string contents = root->as<ConstantArray>()->literal()->ToString();
- return tensorflow::strings::StrCat(
- "(constant ", ShapeUtil::HumanString(root->shape()), " ", contents,
- ")");
+ return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()),
+ " ", contents, ")");
}
- return tensorflow::strings::StrCat(
- "(constant ", ShapeUtil::HumanString(root->shape()), ")");
+ return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()),
+ ")");
}
case Array::kReshaped: {
ReshapedArray* reshaped_array = root->as<ReshapedArray>();
- return tensorflow::strings::StrCat(
+ return absl::StrCat(
"(reshape ", ToString(reshaped_array->operand(), print_constants),
" to ", ShapeUtil::HumanString(reshaped_array->shape()), ")");
}
@@ -67,11 +67,11 @@ string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) {
string name = root->kind() == Array::kScalarIndexedConstant
? "scalar-indexed-const"
: "scalar-indexed";
- return tensorflow::strings::StrCat(
+ return absl::StrCat(
"(", name, " ", ToString(indexed_array->source(), print_constants),
" ", ToString(indexed_array->indices(), print_constants), " ",
indexed_array->source_dim(), "->[",
- Join(indexed_array->output_dims(), ","), "])");
+ StrJoin(indexed_array->output_dims(), ","), "])");
}
}
}
@@ -92,7 +92,7 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache(
// Depth first search over the DAG, invoking ComputeArrayFor in post order.
// The HLO instructions already in the cache are considered leaves.
- gtl::InlinedVector<const HloInstruction*, 4> stack;
+ absl::InlinedVector<const HloInstruction*, 4> stack;
enum DfsState { kDiscovered, kVisited };
gtl::FlatMap<const HloInstruction*, DfsState> dfs_state_map;
@@ -153,7 +153,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayFor(
TF_ASSIGN_OR_RETURN(
computed_array,
ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(),
- instr->gather_window_bounds(),
+ instr->gather_slice_sizes(),
FindOrDie(cache_, instr->operand(0)),
FindOrDie(cache_, instr->operand(1))));
} else if (instr->opcode() == HloOpcode::kReshape) {
@@ -165,6 +165,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayFor(
TF_ASSIGN_OR_RETURN(
computed_array,
ComputeArrayForDot(instr->shape(), instr->dot_dimension_numbers(),
+ instr->precision_config(),
FindOrDie(cache_, instr->operand(0)),
FindOrDie(cache_, instr->operand(1))));
} else {
@@ -185,7 +186,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForConstant(
StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldGatherOfGather(
ScalarIndexedArray* source, Array* indices, int64 source_dim,
- tensorflow::gtl::ArraySlice<int64> output_dims, Shape shape) {
+ absl::Span<const int64> output_dims, Shape shape) {
// We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)).
// `source` is the inner Gather(A, X).
@@ -251,24 +252,22 @@ StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldGatherOfGather(
StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather(
const Shape& shape, const GatherDimensionNumbers& dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source,
- Array* indices) {
+ absl::Span<const int64> slice_sizes, Array* source, Array* indices) {
if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) {
VLOG(3) << "ComputeArrayForGather: indices are not scalar";
return nullptr;
}
- CHECK_EQ(dim_numbers.gather_dims_to_operand_dims_size(), 1);
+ CHECK_EQ(dim_numbers.start_index_map_size(), 1);
- // We can also handle dim_numbers.elided_window_dims_size() == 0 here, should
- // it become relevant.
+ // We can also handle dim_numbers.collapsed_slice_dims_size() == 0 here,
+ // should it become relevant.
- if (dim_numbers.elided_window_dims_size() != 1 ||
- dim_numbers.elided_window_dims(0) !=
- dim_numbers.gather_dims_to_operand_dims(0)) {
+ if (dim_numbers.collapsed_slice_dims_size() != 1 ||
+ dim_numbers.collapsed_slice_dims(0) != dim_numbers.start_index_map(0)) {
VLOG(3) << "ComputeArrayForGather: gather operations must elide "
- "gather_dims_to_operand_dims[0] and "
- "gather_dims_to_operand_dims[0] only";
+ "start_index_map[0] and "
+ "start_index_map[0] only";
return nullptr;
}
@@ -277,27 +276,27 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather(
// arrays from an array of size [7,4,6]. We check that condition down below:
for (int64 i = 0, e = source->shape().dimensions_size(); i < e; i++) {
- if (i != dim_numbers.elided_window_dims(0) &&
- source->shape().dimensions(i) != window_bounds[i]) {
- VLOG(3) << "ComputeArrayForGather: window_bounds[" << i
+ if (i != dim_numbers.collapsed_slice_dims(0) &&
+ source->shape().dimensions(i) != slice_sizes[i]) {
+ VLOG(3) << "ComputeArrayForGather: slice_sizes[" << i
<< "] != source->shape().dimensions(" << i << ") -- "
- << source->shape().dimensions(i) << " vs. " << window_bounds[i]
- << " with dim_numbers.elided_window_dims(0) = "
- << dim_numbers.elided_window_dims(0);
+ << source->shape().dimensions(i) << " vs. " << slice_sizes[i]
+ << " with dim_numbers.collapsed_slice_dims(0) = "
+ << dim_numbers.collapsed_slice_dims(0);
return nullptr;
}
}
- int64 source_dim = dim_numbers.gather_dims_to_operand_dims(0);
+ int64 source_dim = dim_numbers.start_index_map(0);
std::vector<int64> output_dims;
for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) {
- if (!c_binary_search(dim_numbers.output_window_dims(), i)) {
+ if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) {
output_dims.push_back(i);
}
}
if (auto* indexed = dynamic_cast<ScalarIndexedArray*>(source)) {
- if (c_linear_search(indexed->output_dims(), source_dim)) {
+ if (absl::c_linear_search(indexed->output_dims(), source_dim)) {
return FoldGatherOfGather(indexed, indices, source_dim, output_dims,
shape);
}
@@ -314,8 +313,8 @@ namespace {
// Returns an index into `values` such that the product of the range
// [values.begin()+index, values.end()) is equal to `product`. If there is no
// such index, return -1. All integers in `values` must be positive.
-int64 FindSuffixWithProduct(ArraySlice<int64> values, int64 product) {
- DCHECK(c_all_of(values, [](int64 value) { return value > 0; }));
+int64 FindSuffixWithProduct(absl::Span<const int64> values, int64 product) {
+ DCHECK(absl::c_all_of(values, [](int64 value) { return value > 0; }));
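+ // E.g. (hypothetical) values = {2,3,4}: product 12 gives index 1 (3*4),
+ // product 24 gives index 0, and product 5 gives -1.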
int64 current_product = 1;
int64 i;
@@ -343,7 +342,8 @@ struct ReshapePassthroughDimPair {
// The returned vector of pairs is sorted in both the result_dim and the
// operand_dim components.
std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
- ArraySlice<int64> operand_shape, ArraySlice<int64> result_shape) {
+ absl::Span<const int64> operand_shape,
+ absl::Span<const int64> result_shape) {
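+ // E.g. for the reshape from [3,2,7] to [6,7] used in an example later in
+ // this file, only the size-7 dimension passes through, giving the single
+ // pair {result_dim=1, operand_dim=2}.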
// A reshape can be seen as an index mapping from output index to input index:
//
// (i_0, ..., i_n) = f(o_0, ..., o_m)
@@ -378,8 +378,8 @@ std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
CHECK_NE(candidate_operand_dim, 0)
<< "result_dim = " << result_dim
<< ", result_subarray_size = " << result_subarray_size
- << ", result_shape = [" << Join(result_shape, ",") << "]"
- << ", operand_shape = [" << Join(operand_shape, ",") << "]";
+ << ", result_shape = [" << StrJoin(result_shape, ",") << "]"
+ << ", operand_shape = [" << StrJoin(operand_shape, ",") << "]";
if (candidate_operand_dim != -1 &&
result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) {
@@ -389,26 +389,27 @@ std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
result_subarray_size *= result_shape[result_dim];
}
- c_reverse(result);
+ absl::c_reverse(result);
if (VLOG_IS_ON(3)) {
std::vector<string> result_strings;
- c_transform(result, std::back_inserter(result_strings),
- [](ReshapePassthroughDimPair value) {
- return tensorflow::strings::StrCat(value.result_dim, "->",
- value.operand_dim);
- });
- VLOG(3) << "For a reshape from [" << Join(operand_shape, ",") << "] to ["
- << Join(result_shape, ",") << "] passthrough indices are ["
- << Join(result_strings, ",") << "] (legend: `result`->`operand`)";
+ absl::c_transform(result, std::back_inserter(result_strings),
+ [](ReshapePassthroughDimPair value) {
+ return absl::StrCat(value.result_dim, "->",
+ value.operand_dim);
+ });
+ VLOG(3) << "For a reshape from [" << StrJoin(operand_shape, ",") << "] to ["
+ << StrJoin(result_shape, ",") << "] passthrough indices are ["
+ << StrJoin(result_strings, ",")
+ << "] (legend: `result`->`operand`)";
}
- DCHECK(c_is_sorted(
+ DCHECK(absl::c_is_sorted(
result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) {
return lhs.result_dim < rhs.result_dim;
}));
- DCHECK(c_is_sorted(
+ DCHECK(absl::c_is_sorted(
result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) {
return lhs.operand_dim < rhs.operand_dim;
}));
@@ -419,30 +420,31 @@ std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
// Return true if `dim` is stated as a passthrough operand dim in
// `passthrough_dims`.
bool IsReshapePassthroughOperandDim(
- ArraySlice<ReshapePassthroughDimPair> passthrough_dims, int64 dim) {
- return c_any_of(passthrough_dims,
- [&](ReshapePassthroughDimPair passthrough_dim_pair) {
- return passthrough_dim_pair.operand_dim == dim;
- });
+ absl::Span<const ReshapePassthroughDimPair> passthrough_dims, int64 dim) {
+ return absl::c_any_of(passthrough_dims,
+ [&](ReshapePassthroughDimPair passthrough_dim_pair) {
+ return passthrough_dim_pair.operand_dim == dim;
+ });
}
// Maps `operand_dim` which must be a passthrough operand dimension to its
// corresponding passthrough result dimension based on `passthrough_dims`.
int64 MapPassthroughOperandDimToResultDim(
- ArraySlice<ReshapePassthroughDimPair> passthrough_dims, int64 operand_dim) {
- auto it = c_find_if(passthrough_dims,
- [&](ReshapePassthroughDimPair passthrough_dim_pair) {
- return passthrough_dim_pair.operand_dim == operand_dim;
- });
+ absl::Span<const ReshapePassthroughDimPair> passthrough_dims,
+ int64 operand_dim) {
+ auto it = absl::c_find_if(
+ passthrough_dims, [&](ReshapePassthroughDimPair passthrough_dim_pair) {
+ return passthrough_dim_pair.operand_dim == operand_dim;
+ });
CHECK(it != passthrough_dims.end());
return it->result_dim;
}
-int64 FindSourcePositionForPassthroughResultDim(ArraySlice<int64> operand_shape,
- ArraySlice<int64> result_shape,
- int64 source_passthrough_dim) {
+int64 FindSourcePositionForPassthroughResultDim(
+ absl::Span<const int64> operand_shape, absl::Span<const int64> result_shape,
+ int64 source_passthrough_dim) {
VLOG(3) << "FindSourcePositionForPassthroughResultDim(["
- << Join(operand_shape, ",") << "], [" << Join(result_shape, ",")
+ << StrJoin(operand_shape, ",") << "], [" << StrJoin(result_shape, ",")
<< "], " << source_passthrough_dim << ")";
int64 indexed_source_subarray_size =
@@ -454,8 +456,8 @@ int64 FindSourcePositionForPassthroughResultDim(ArraySlice<int64> operand_shape,
Shape StripDegenerateDimensions(const Shape& shape) {
DimensionVector new_dims;
- c_copy_if(shape.dimensions(), std::back_inserter(new_dims),
- [](int64 dim) { return dim != 1; });
+ absl::c_copy_if(shape.dimensions(), std::back_inserter(new_dims),
+ [](int64 dim) { return dim != 1; });
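+ // E.g. f32[1,4,1,5] strips to f32[4,5].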
return ShapeUtil::MakeShape(shape.element_type(), new_dims);
}
}; // namespace
@@ -498,7 +500,7 @@ IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims(
for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) {
if (shape.dimensions(i) == 1) {
degenerate_dims_seen++;
- } else if (ArrayContains(operand->output_dims(), i)) {
+ } else if (absl::c_linear_search(operand->output_dims(), i)) {
new_output_dims.push_back(i - degenerate_dims_seen);
}
}
@@ -518,8 +520,7 @@ IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims(
}
StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::ReshapeToAddDegenerateDims(
- ScalarIndexedArray* operand,
- tensorflow::gtl::ArraySlice<int64> degenerate_dims) {
+ ScalarIndexedArray* operand, absl::Span<const int64> degenerate_dims) {
if (degenerate_dims.empty()) {
return operand;
}
@@ -531,7 +532,7 @@ StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::ReshapeToAddDegenerateDims(
// element is true iff the i'th component of the result index is an output
// index.
- gtl::InlinedVector<bool, 6> output_dims_bitvector(
+ absl::InlinedVector<bool, 6> output_dims_bitvector(
operand->shape().dimensions_size());
for (int64 output_dim : operand->output_dims()) {
output_dims_bitvector[output_dim] = true;
@@ -553,8 +554,8 @@ StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::ReshapeToAddDegenerateDims(
}();
DimensionVector new_result_shape_dims;
- c_copy(operand->shape().dimensions(),
- std::back_inserter(new_result_shape_dims));
+ absl::c_copy(operand->shape().dimensions(),
+ std::back_inserter(new_result_shape_dims));
for (int64 degenerate_dim : degenerate_dims) {
InsertAt(&new_result_shape_dims, degenerate_dim, 1);
}
@@ -695,8 +696,8 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims(
operand_dim);
};
- if (!c_all_of(scalar_indexed->output_dims(),
- is_reshape_passthrough_operand_dim)) {
+ if (!absl::c_all_of(scalar_indexed->output_dims(),
+ is_reshape_passthrough_operand_dim)) {
VLOG(3) << "Not all output dims are passthrough dims "
<< ToString(scalar_indexed);
return nullptr;
@@ -735,11 +736,11 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims(
// operand = s32[3,5,2] constant({...})
// indices = s32[7] parameter(0)
// gather = s32[3,2,7] gather(operand, indices),
- // output_window_dims={0,1},
- // elided_window_dims={1},
- // gather_dims_to_operand_dims={1},
+ // offset_dims={0,1},
+ // collapsed_slice_dims={1},
+ // start_index_map={1},
// index_vector_dim=1,
- // window_bounds={3,1,2}
+ // slice_sizes={3,1,2}
// reshape = s32[6,7] reshape(gather)
//
// In this case the gather maps to:
@@ -754,9 +755,9 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims(
if (source_dim_for_new_scalar_indexed_node == -1) {
VLOG(3) << "Could not compute the source dim for the new scalar indexed "
"node: scalar_indexed_source_shape = ["
- << Join(scalar_indexed_source_shape.dimensions(), ",")
+ << StrJoin(scalar_indexed_source_shape.dimensions(), ",")
<< "] and new_scalar_indexed_source_shape = ["
- << Join(new_scalar_indexed_source_shape, ",") << "]";
+ << StrJoin(new_scalar_indexed_source_shape, ",") << "]";
return nullptr;
}
@@ -764,8 +765,8 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims(
&new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node,
scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim()));
- CHECK_EQ(c_accumulate(new_scalar_indexed_source_shape, 1LL,
- std::multiplies<int64>()),
+ CHECK_EQ(absl::c_accumulate(new_scalar_indexed_source_shape, 1LL,
+ std::multiplies<int64>()),
ShapeUtil::ElementsIn(scalar_indexed_source_shape));
CHECK(IsReshapePassthroughOperandDim(
@@ -781,9 +782,9 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims(
};
std::vector<int64> output_dims_for_new_scalar_indexed_node;
- c_transform(scalar_indexed->output_dims(),
- std::back_inserter(output_dims_for_new_scalar_indexed_node),
- map_passthrough_operand_dim_to_result_dim);
+ absl::c_transform(scalar_indexed->output_dims(),
+ std::back_inserter(output_dims_for_new_scalar_indexed_node),
+ map_passthrough_operand_dim_to_result_dim);
TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal,
TakeOwnership(scalar_indexed->literal().Reshape(
@@ -872,13 +873,14 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
return nullptr;
}
- ArraySlice<int64> broadcast_dims = broadcast_instr->dimensions();
+ absl::Span<const int64> broadcast_dims = broadcast_instr->dimensions();
auto is_broadcasted_dim = [&](int64 output_dim) {
- return c_find(broadcast_dims, output_dim) == broadcast_dims.end();
+ return absl::c_find(broadcast_dims, output_dim) == broadcast_dims.end();
};
// All of the output dims must be "broadcasted" dims for the other operand.
- if (!c_all_of(scalar_indexed_const->output_dims(), is_broadcasted_dim)) {
+ if (!absl::c_all_of(scalar_indexed_const->output_dims(),
+ is_broadcasted_dim)) {
return nullptr;
}
@@ -894,7 +896,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
// The scalar-indexed node "removes" the source dim and "inserts" the output
// dims. We do the opposite here to undo the scalar-indexed operation.
- ArraySlice<int64> output_dims = scalar_indexed_const->output_dims();
+ absl::Span<const int64> output_dims = scalar_indexed_const->output_dims();
for (int64 i = output_dims.size() - 1; i >= 0; --i) {
CHECK(simulated_index[output_dims[i]] == IndexComponent::Broadcasted);
EraseAt(&simulated_index, output_dims[i]);
@@ -916,7 +918,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
// inner_broadcast_result is the Broadcast'(Const0) bit in
// BinaryOp(Broadcast'(Const0), Const1)
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> inner_broadcast_result,
+ Literal inner_broadcast_result,
broadcast_const_operand->literal().Broadcast(
scalar_indexed_const->source()->shape(), new_inner_broadcast_dims));
@@ -926,12 +928,12 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
TF_ASSIGN_OR_RETURN(
literal_for_new_source,
TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
- opcode, scalar_indexed_const->literal(), *inner_broadcast_result)));
+ opcode, scalar_indexed_const->literal(), inner_broadcast_result)));
} else {
TF_ASSIGN_OR_RETURN(
literal_for_new_source,
TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
- opcode, *inner_broadcast_result, scalar_indexed_const->literal())));
+ opcode, inner_broadcast_result, scalar_indexed_const->literal())));
}
ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
@@ -970,15 +972,15 @@ namespace {
// Returns the non-contracting non-batch dimension (as per `contracting_dims`
// and `batch_dims`) if there is exactly one, otherwise returns nullopt.
-gtl::optional<int64> GetOnlyNonContractingNonBatchDim(
- int64 rank, ArraySlice<int64> contracting_dims,
- ArraySlice<int64> batch_dims) {
- gtl::optional<int64> result;
+absl::optional<int64> GetOnlyNonContractingNonBatchDim(
+ int64 rank, absl::Span<const int64> contracting_dims,
+ absl::Span<const int64> batch_dims) {
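+ // E.g. (hypothetical) rank 3 with contracting_dims={2} and batch_dims={0}
+ // leaves exactly dim 1; with two leftover dims this returns absl::nullopt.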
+ absl::optional<int64> result;
for (int64 dim = 0; dim < rank; dim++) {
- if (!ArrayContains(contracting_dims, dim) &&
- !ArrayContains(batch_dims, dim)) {
+ if (!absl::c_linear_search(contracting_dims, dim) &&
+ !absl::c_linear_search(batch_dims, dim)) {
if (result.has_value()) {
- return gtl::nullopt;
+ return absl::nullopt;
}
result = dim;
}
@@ -995,10 +997,10 @@ gtl::optional<int64> GetOnlyNonContractingNonBatchDim(
// `contracting_dims` and `batch_dims` are the contracting and batch dimensions
// of whatever operand `indexed_array` is to the dot (LHS or RHS).
bool CanFoldDotIntoIndexedArray(
- tensorflow::StringPiece tag,
- Analysis::ScalarIndexedConstantArray* indexed_array,
- ArraySlice<int64> contracting_dims, ArraySlice<int64> batch_dims) {
- gtl::optional<int64> non_contracting_non_batch_dim =
+ absl::string_view tag, Analysis::ScalarIndexedConstantArray* indexed_array,
+ absl::Span<const int64> contracting_dims,
+ absl::Span<const int64> batch_dims) {
+ absl::optional<int64> non_contracting_non_batch_dim =
GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()),
contracting_dims, batch_dims);
if (!non_contracting_non_batch_dim.has_value()) {
@@ -1029,7 +1031,8 @@ bool CanFoldDotIntoIndexedArray(
StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
- ScalarIndexedConstantArray* lhs, ConstantArray* rhs) {
+ const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs,
+ ConstantArray* rhs) {
VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " "
          << ToString(rhs) << ")";
if (!CanFoldDotIntoIndexedArray(
@@ -1044,9 +1047,10 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
new_dim_numbers.set_lhs_contracting_dimensions(
0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1));
- TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
- TakeOwnership(HloEvaluator{}.EvaluateDotOp(
- new_dim_numbers, lhs->literal(), *rhs->literal())));
+ TF_ASSIGN_OR_RETURN(
+ Literal * literal_for_new_source,
+ TakeOwnership(HloEvaluator{}.EvaluateDotOp(
+ new_dim_numbers, precision_config, lhs->literal(), *rhs->literal())));
// The new source dimension is wherever the non-batch non-contracting LHS
// dimension "went".
@@ -1062,7 +1066,8 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
- ConstantArray* lhs, ScalarIndexedConstantArray* rhs) {
+ const PrecisionConfig& precision_config, ConstantArray* lhs,
+ ScalarIndexedConstantArray* rhs) {
VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " "
          << ToString(rhs) << ")";
if (!CanFoldDotIntoIndexedArray(
@@ -1078,9 +1083,10 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
new_dim_numbers.set_rhs_contracting_dimensions(
0, rhs->source_dim() == (rhs_rank - 1) ? (rhs_rank - 2) : (rhs_rank - 1));
- TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
- TakeOwnership(HloEvaluator{}.EvaluateDotOp(
- new_dim_numbers, *lhs->literal(), rhs->literal())));
+ TF_ASSIGN_OR_RETURN(
+ Literal * literal_for_new_source,
+ TakeOwnership(HloEvaluator{}.EvaluateDotOp(
+ new_dim_numbers, precision_config, *lhs->literal(), rhs->literal())));
// The new source dimension is wherever the non-batch non-contracting RHS
// dimension "went".
@@ -1094,8 +1100,8 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
}
StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
- const Shape& shape, const DotDimensionNumbers& dim_numbers, Array* lhs,
- Array* rhs) {
+ const Shape& shape, const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfig& precision_config, Array* lhs, Array* rhs) {
// Intuitively, if
//
// - The LHS of a dot product is a gathered sequence of rows from a constant
@@ -1118,6 +1124,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
dynamic_cast<ScalarIndexedConstantArray*>(lhs)) {
if (auto* rhs_constant = dynamic_cast<ConstantArray*>(rhs)) {
return ComputeArrayForDotWithIndexedLhs(shape, dim_numbers,
+ precision_config,
lhs_indexed_array, rhs_constant);
}
}
@@ -1125,7 +1132,8 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
if (auto* rhs_indexed_array =
dynamic_cast<ScalarIndexedConstantArray*>(rhs)) {
if (auto* lhs_constant = dynamic_cast<ConstantArray*>(lhs)) {
- return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers, lhs_constant,
+ return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers,
+ precision_config, lhs_constant,
rhs_indexed_array);
}
}
@@ -1133,7 +1141,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
return nullptr;
}
-tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const {
+absl::string_view IndexedArrayAnalysisPrinterPass::name() const {
return "indexed-array-analysis-printer-pass";
}
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index e923dc39f7..df9cbab915 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -188,9 +188,7 @@ class IndexedArrayAnalysis {
// `output_dims` are the dimensions in the output array that are being used
// to compute an index into the `indices` array. See the class
// documentation and the overview for more details.
- tensorflow::gtl::ArraySlice<int64> output_dims() const {
- return output_dims_;
- }
+ absl::Span<const int64> output_dims() const { return output_dims_; }
private:
explicit ScalarIndexedArray(Array* source, Array* indices, int64 source_dim,
@@ -265,19 +263,21 @@ class IndexedArrayAnalysis {
StatusOr<Array*> ComputeArrayForGather(
const Shape& shape, const GatherDimensionNumbers& dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source,
- Array* indices);
+ absl::Span<const int64> slice_sizes, Array* source, Array* indices);
StatusOr<Array*> ComputeArrayForDotWithIndexedLhs(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
- ScalarIndexedConstantArray* lhs, ConstantArray* rhs);
+ const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs,
+ ConstantArray* rhs);
StatusOr<Array*> ComputeArrayForDotWithIndexedRhs(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
- ConstantArray* lhs, ScalarIndexedConstantArray* rhs);
+ const PrecisionConfig& precision_config, ConstantArray* lhs,
+ ScalarIndexedConstantArray* rhs);
StatusOr<Array*> ComputeArrayForDot(const Shape& shape,
const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfig& precision_config,
Array* lhs, Array* rhs);
// This tries to fold a ScalarIndexedArray which has another
@@ -303,7 +303,7 @@ class IndexedArrayAnalysis {
// G1 = [Arr[i] for i in I2]
StatusOr<ScalarIndexedArray*> FoldGatherOfGather(
ScalarIndexedArray* source, Array* indices, int64 source_dim,
- tensorflow::gtl::ArraySlice<int64> output_dims, Shape shape);
+ absl::Span<const int64> output_dims, Shape shape);
// Reshapes a scalar-indexed node to remove the degenerate dimensions in its
// output. The result is always a scalar-indexed node.
@@ -313,8 +313,7 @@ class IndexedArrayAnalysis {
// Reshapes a scalar-indexed node such that the result has the degenerate
// dimensions `degenerate_dims`. The result is always a scalar-indexed node.
StatusOr<ScalarIndexedArray*> ReshapeToAddDegenerateDims(
- ScalarIndexedArray* operand,
- tensorflow::gtl::ArraySlice<int64> degenerate_dims);
+ ScalarIndexedArray* operand, absl::Span<const int64> degenerate_dims);
StatusOr<ScalarIndexedArray*> FoldReshapeOfGather(
const Shape& shape, ScalarIndexedConstantArray* operand);
@@ -348,21 +347,19 @@ class IndexedArrayAnalysis {
}
}
- Literal* TakeOwnership(std::unique_ptr<Literal> literal) {
+ Literal* TakeOwnership(Literal literal) {
owned_literals_.push_back(std::move(literal));
- return owned_literals_.back().get();
+ return &owned_literals_.back();
}
- StatusOr<Literal*> TakeOwnership(
- StatusOr<std::unique_ptr<Literal>> literal_or_error) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
- std::move(literal_or_error));
+ StatusOr<Literal*> TakeOwnership(StatusOr<Literal> literal_or_error) {
+ TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error));
owned_literals_.push_back(std::move(literal));
- return owned_literals_.back().get();
+ return &owned_literals_.back();
}
std::vector<std::unique_ptr<Array>> owned_tensors_;
- std::vector<std::unique_ptr<Literal>> owned_literals_;
+ std::vector<Literal> owned_literals_;
tensorflow::gtl::FlatMap<const HloInstruction*, Array*> cache_;
};
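
One subtlety in the value-based TakeOwnership above: owned_literals_ is now a std::vector<Literal>, and vector growth relocates its elements, so a pointer returned by an earlier call is invalidated by a later push_back that triggers reallocation. The pattern is safe only if callers never hold a returned Literal* across subsequent TakeOwnership calls. A deque-based variant (illustrative, not what this patch does) keeps returned addresses stable:

    #include <deque>
    #include <utility>

    #include "tensorflow/compiler/xla/literal.h"

    class LiteralPool {
     public:
      // std::deque never relocates existing elements on push_back, so the
      // returned pointer stays valid for the lifetime of the pool.
      xla::Literal* TakeOwnership(xla::Literal literal) {
        owned_.push_back(std::move(literal));
        return &owned_.back();
      }

     private:
      std::deque<xla::Literal> owned_;
    };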
@@ -371,7 +368,7 @@ class IndexedArrayAnalysis {
// unconditionally add to the regular HLO pass pipeline.
class IndexedArrayAnalysisPrinterPass : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override;
+ absl::string_view name() const override;
StatusOr<bool> Run(HloModule* module) override;
};
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
index 5f4b42799b..2d03aebc1a 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
@@ -82,11 +82,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[5] parameter(1)
ROOT gather = s32[5,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3}
+ slice_sizes={1,3}
}
)";
@@ -102,11 +102,11 @@ ENTRY main {
operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}})
indices = s32[5] parameter(0)
ROOT gather = s32[5,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3}
+ slice_sizes={1,3}
}
)";
@@ -122,11 +122,11 @@ ENTRY main {
operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}})
indices = s32[5,2] parameter(0)
ROOT gather = s32[5] gather(operand, indices),
- output_window_dims={},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1}
+ slice_sizes={1,1}
}
)";
@@ -141,11 +141,11 @@ ENTRY main {
operand = s32[3,3,1] parameter(0)
indices = s32[5] parameter(1)
ROOT gather = s32[5,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,2},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0,2},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3,1}
+ slice_sizes={1,3,1}
}
)";
@@ -160,11 +160,11 @@ ENTRY main {
operand = s32[3,3,1] parameter(0)
indices = s32[5] parameter(1)
ROOT gather = s32[5,2,3] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={2},
- gather_dims_to_operand_dims={0},
+ offset_dims={1,2},
+ collapsed_slice_dims={2},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={2,3,1}
+ slice_sizes={2,3,1}
}
)";
@@ -179,11 +179,11 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[5] parameter(1)
ROOT gather = s32[5,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,2}
+ slice_sizes={1,2}
}
)";
@@ -199,17 +199,17 @@ ENTRY main {
indices_a = s32[5] parameter(0)
indices_b = s32[2] parameter(1)
gather_a = s32[5,3] gather(operand, indices_a),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3}
+ slice_sizes={1,3}
ROOT gather_b = s32[2,3] gather(gather_a, indices_b),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3}
+ slice_sizes={1,3}
}
)";
@@ -228,17 +228,17 @@ ENTRY main {
indices_a = s32[5,7] parameter(1)
indices_b = s32[2] parameter(2)
gather_a = s32[5,3,7] gather(operand, indices_a),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3,1}
+ slice_sizes={3,1}
ROOT gather_b = s32[5,3,2] gather(gather_a, indices_b),
- output_window_dims={0,1},
- elided_window_dims={2},
- gather_dims_to_operand_dims={2},
+ offset_dims={0,1},
+ collapsed_slice_dims={2},
+ start_index_map={2},
index_vector_dim=1,
- window_bounds={5,3,1}
+ slice_sizes={5,3,1}
}
)";
@@ -256,17 +256,17 @@ ENTRY main {
indices_a = s32[2] parameter(1)
indices_b = s32[5,7] parameter(2)
gather_a = s32[2,6] gather(operand, indices_a),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,6}
+ slice_sizes={1,6}
ROOT gather_b = s32[5,6,7] gather(gather_a, indices_b),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,6}
+ slice_sizes={1,6}
}
)";
@@ -284,17 +284,17 @@ ENTRY main {
indices_a = s32[5,7] parameter(1)
indices_b = s32[4,8] parameter(2)
gather_a = s32[5,3,7] gather(operand, indices_a),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3,1}
+ slice_sizes={3,1}
ROOT gather_b = s32[4,5,3,8] gather(gather_a, indices_b),
- output_window_dims={1,2},
- elided_window_dims={2},
- gather_dims_to_operand_dims={2},
+ offset_dims={1,2},
+ collapsed_slice_dims={2},
+ start_index_map={2},
index_vector_dim=2,
- window_bounds={5,3,1}
+ slice_sizes={5,3,1}
}
)";
@@ -312,11 +312,11 @@ ENTRY main {
operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}})
indices = s32[5] parameter(0)
gather = s32[5,4] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT reshape = s32[5,2,2] reshape(gather)
}
)";
@@ -333,11 +333,11 @@ ENTRY main {
operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}})
indices = s32[5,7] parameter(0)
gather = s32[5,4,7] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT reshape = s32[5,2,2,7] reshape(gather)
}
)";
@@ -358,11 +358,11 @@ ENTRY main {
{{1,2,3,4,5,6},{1,2,3,4,5,6}}})
indices = s32[5,7] parameter(0)
gather = s32[5,2,6,7] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1,2},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,2,6}
+ slice_sizes={1,2,6}
ROOT reshape = s32[5,3,4,7] reshape(gather)
}
)";
@@ -381,11 +381,11 @@ ENTRY main {
{1,2,3,4,5,6},{1,2,3,4,5,6}})
indices = s32[1] parameter(0)
gather = s32[1,6] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,6}
+ slice_sizes={1,6}
ROOT reshape = s32[1,1,6] reshape(gather)
}
)";
@@ -408,14 +408,14 @@ ENTRY main {
operand = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 1, 2, 3 } })
i.0 = s64[1,3]{1,0} parameter(0)
- g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), output_window_dims={2},
- elided_window_dims={0}, gather_dims_to_operand_dims={0},
- index_vector_dim=2, window_bounds={1,3}
+ g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), offset_dims={2},
+ collapsed_slice_dims={0}, start_index_map={0},
+ index_vector_dim=2, slice_sizes={1,3}
i.1 = s64[1] parameter(1)
- g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), output_window_dims={0,2},
- elided_window_dims={1}, gather_dims_to_operand_dims={1},
- index_vector_dim=1, window_bounds={1,1,3}
+ g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), offset_dims={0,2},
+ collapsed_slice_dims={1}, start_index_map={1},
+ index_vector_dim=1, slice_sizes={1,1,3}
ROOT reshape = s32[1,3]{1,0} reshape(g.1)
}
@@ -441,11 +441,11 @@ ENTRY main {
operand = s32[1,6] constant(s32[1,6]{{1,2,3,4,5,6}})
indices = s32[1] parameter(0)
gather = s32[1,6] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,6}
+ slice_sizes={1,6}
ROOT reshape = s32[1,1,6] reshape(gather)
}
)";
@@ -469,11 +469,11 @@ ENTRY main {
{1,2,3,4,5,6},{1,2,3,4,5,6}}})
indices = s32[1] parameter(0)
gather = s32[1,1,6] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1,2},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={1,1,6}
+ slice_sizes={1,1,6}
ROOT reshape = s32[1,1,1,6] reshape(gather)
}
)";
@@ -500,11 +500,11 @@ ENTRY main {
{1,2,3,4,5,6},{1,2,3,4,5,6}})
indices = s32[1,5] parameter(0)
gather = s32[1,5,6] gather(operand, indices),
- output_window_dims={2},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={2},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,6}
+ slice_sizes={1,6}
ROOT reshape = s32[1,1,5,6] reshape(gather)
}
)";
@@ -530,11 +530,11 @@ ENTRY main {
operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}})
indices = s32[5,6] parameter(0)
gather = s32[5,4,6] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT reshape = s32[5,2,2,2,3] reshape(gather)
}
)";
@@ -562,11 +562,11 @@ ENTRY main {
{{1,2},{3,4},{5,6},{7,8},{9,10}}})
indices = s32[7] parameter(0)
gather = s32[3,2,7] gather(operand, indices),
- output_window_dims={0,1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0,1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3,1,2}
+ slice_sizes={3,1,2}
ROOT reshape = s32[6,7] reshape(gather)
}
)";
@@ -594,11 +594,11 @@ ENTRY main {
{{1},{2},{3},{4}}})
indices = s32[5,6] parameter(0)
gather = s32[5,4,6,1] gather(operand, indices),
- output_window_dims={1,3},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1,3},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=2,
- window_bounds={1,4,1}
+ slice_sizes={1,4,1}
ROOT reshape = s32[5,2,2,2,3,1] reshape(gather)
}
)";
@@ -623,20 +623,20 @@ ENTRY main {
operand = f32[3,4] constant(f32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
indices = s32[5] parameter(0)
gather = f32[5,4] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT tanh = f32[5,4] tanh(gather)
}
)";
AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
(scalar-indexed-const (constant f32[3,4] f32[3,4] {
- { 0.761594176, 0.964027584, 0.995054781, 0.999329329 },
- { 0.761594176, 0.995054781, 0.964027584, 0.999329329 },
- { 0.999329329, 0.995054781, 0.964027584, 0.761594176 }
+ { 0.761594, 0.964028, 0.995055, 0.999329 },
+ { 0.761594, 0.995055, 0.964028, 0.999329 },
+ { 0.999329, 0.995055, 0.964028, 0.761594 }
}) %indices 0->[0]))");
}
@@ -650,11 +650,11 @@ ENTRY main {
constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
indices = s32[5] parameter(0)
gather = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT add = s32[5,4] add(gather, constant_broadcasted)
}
)";
@@ -678,11 +678,11 @@ ENTRY main {
constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
indices = s32[5] parameter(0)
gather = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT sub = s32[5,4] subtract(gather, constant_broadcasted)
}
)";
@@ -706,11 +706,11 @@ ENTRY main {
constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
indices = s32[5] parameter(0)
gather = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT sub = s32[5,4] subtract(constant_broadcasted, gather)
}
)";
@@ -733,11 +733,11 @@ ENTRY main {
constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1}
indices = s32[5] parameter(0)
gather = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT add = s32[5,4] add(gather, constant_broadcasted)
}
)";
@@ -760,11 +760,11 @@ ENTRY main {
constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0}
indices = s32[5] parameter(0)
gather = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT add = s32[5,4] add(gather, constant_broadcasted)
}
)";
@@ -808,11 +808,11 @@ ENTRY main {
dot_rhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
indices = s32[5] parameter(0)
dot_lhs = s32[5,4] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,4}
+ slice_sizes={1,4}
ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
@@ -835,11 +835,11 @@ ENTRY main {
dot_rhs_constant = s32[3,3] constant(s32[3,3]{{1,2,3},{4,5,6},{7,8,9}})
indices = s32[5] parameter(0)
dot_lhs = s32[3,5] gather(gather_operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3,1}
+ slice_sizes={3,1}
ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={0}, rhs_contracting_dims={0}
}
)";
@@ -863,11 +863,11 @@ ENTRY main {
dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
indices = s32[5] parameter(0)
dot_rhs = s32[3,5] gather(gather_operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3,1}
+ slice_sizes={3,1}
ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
@@ -892,11 +892,11 @@ ENTRY main {
dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
indices = s32[5] parameter(0)
dot_rhs = s32[5,3] gather(gather_operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1,3}
+ slice_sizes={1,3}
ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={1}
}
)";
@@ -921,11 +921,11 @@ ENTRY main {
dot_lhs_constant = s32[2,2,3] constant(s32[2,2,3]{{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}})
indices = s32[4] parameter(0)
dot_rhs = s32[2,3,4] gather(gather_operand, indices),
- output_window_dims={0,1},
- elided_window_dims={2},
- gather_dims_to_operand_dims={2},
+ offset_dims={0,1},
+ collapsed_slice_dims={2},
+ start_index_map={2},
index_vector_dim=1,
- window_bounds={2,3,1}
+ slice_sizes={2,3,1}
ROOT dot = s32[2,2,4] dot(dot_lhs_constant, dot_rhs),
lhs_contracting_dims={2}, rhs_contracting_dims={1},
lhs_batch_dims={0}, rhs_batch_dims={0}
@@ -952,11 +952,11 @@ ENTRY main {
dot_rhs_constant = s32[2,3] constant(s32[2,3]{{1,2,3},{4,5,6}})
indices = s32[2] parameter(0)
dot_lhs = s32[3,2] gather(gather_operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3,1}
+ slice_sizes={3,1}
ROOT dot = s32[3,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
diff --git a/tensorflow/compiler/xla/service/inliner.cc b/tensorflow/compiler/xla/service/inliner.cc
index 5c193fceb9..5fd779ebf9 100644
--- a/tensorflow/compiler/xla/service/inliner.cc
+++ b/tensorflow/compiler/xla/service/inliner.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/inliner.h b/tensorflow/compiler/xla/service/inliner.h
index a523811f6c..efa8ed3abc 100644
--- a/tensorflow/compiler/xla/service/inliner.h
+++ b/tensorflow/compiler/xla/service/inliner.h
@@ -27,7 +27,7 @@ namespace xla {
class Inliner : public HloPassInterface {
public:
~Inliner() override = default;
- tensorflow::StringPiece name() const override { return "inline"; }
+ absl::string_view name() const override { return "inline"; }
// Run inlining on the given computation. Returns whether the computation was
// changed.
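
Returning absl::string_view from these renamed name() overrides is safe because the views are bound to string literals, which have static storage duration. A minimal sketch:

    #include "absl/strings/string_view.h"

    class ExamplePass {
     public:
      // The literal "example" lives for the whole program, so a by-value
      // view of it never dangles.
      absl::string_view name() const { return "example"; }
    };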
diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc
index 32937b33b3..93a74dbfa6 100644
--- a/tensorflow/compiler/xla/service/inliner_test.cc
+++ b/tensorflow/compiler/xla/service/inliner_test.cc
@@ -18,8 +18,8 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -71,7 +71,7 @@ TEST_F(InlinerTest, MapMax) {
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto expected = LiteralUtil::CreateR1<float>({4, 3, 3, 4});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
// Test that `constant` function is changed to `broadcast`.
@@ -105,7 +105,7 @@ TEST_F(InlinerTest, MapConstant) {
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto expected = LiteralUtil::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
TEST_F(InlinerTest, MapSubtractOppositeOrder) {
@@ -143,7 +143,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto expected = LiteralUtil::CreateR1<float>({3, 1, -1, -3});
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index f33942d679..8c907eae0c 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -121,6 +122,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kConvolution:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kAllToAll:
+ case HloOpcode::kCollectivePermute:
case HloOpcode::kCustomCall:
case HloOpcode::kDivide:
case HloOpcode::kDomain:
@@ -130,7 +132,6 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kFft:
case HloOpcode::kFusion:
case HloOpcode::kGather:
- case HloOpcode::kHostCompute:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kMap:
@@ -171,7 +172,8 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) {
});
return std::count_if(hlo->operands().begin(), hlo->operands().end(),
[output_rank](HloInstruction* operand) {
- if (operand->opcode() == HloOpcode::kBroadcast) {
+ if (operand->opcode() == HloOpcode::kBroadcast ||
+ operand->opcode() == HloOpcode::kIota) {
return false;
}
if (operand->opcode() == HloOpcode::kConstant &&
@@ -189,13 +191,13 @@ bool InstructionFusion::CanFuseOnAllPaths(
if (consumer == producer) {
return true;
}
- if (!consumer->IsFusable()) {
+ if (!consumer->IsFusible()) {
return false;
}
for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) {
auto* consumer_operand = consumer->mutable_operand(i);
// If the operand is not on a path to the producer, it doesn't matter
- // whether it's fusable.
+ // whether it's fusible.
if (!reachability_->IsReachable(producer, consumer_operand)) {
continue;
}
@@ -205,7 +207,7 @@ bool InstructionFusion::CanFuseOnAllPaths(
}
// The producer is reachable from consumer_operand which means we need
// to be able to fuse consumer_operand into consumer in order for
- // producer to be fusable into consumer on all paths.
+ // producer to be fusible into consumer on all paths.
// Perform the recursive step: make sure producer can be fused into
// consumer_operand on all paths.
if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_duplicate)) {
@@ -216,8 +218,8 @@ bool InstructionFusion::CanFuseOnAllPaths(
}
InstructionFusion::HloInstructionSet
-InstructionFusion::ComputeGloballyUnfusable(
- tensorflow::gtl::ArraySlice<HloInstruction*> post_order) {
+InstructionFusion::ComputeGloballyUnfusible(
+ absl::Span<HloInstruction* const> post_order) {
// Forbid fusion of producers that:
// a) Need to be duplicated, unless they can be fused into all consumers
// via all paths.
@@ -270,19 +272,19 @@ InstructionFusion::ComputeGloballyUnfusable(
// all of its consumers on all paths.
//
// That means, that for:
- // A --> B (fusable)
- // \-> C (non-fusable)
+ // A --> B (fusible)
+ // \-> C (non-fusible)
  // A will not be allowed to be fused into B, as it cannot be fused into C.
//
// Similarly, for:
// A -------------> B
// \-> C -> D -/
// If:
- // - A is fusable into B and C, and D is fusable into B
- // - C is *not* fusable into D
+ // - A is fusible into B and C, and D is fusible into B
+ // - C is *not* fusible into D
  // A will not be allowed to be fused into B, as it cannot be fused via
// all paths.
- if (producer->IsFusable() &&
+ if (producer->IsFusible() &&
CanFuseOnAllPaths(producer, consumer, do_not_duplicate)) {
continue;
}
@@ -318,7 +320,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
InsertOrDie(&post_order_index, post_order[i], i);
}
- HloInstructionSet do_not_duplicate = ComputeGloballyUnfusable(post_order);
+ HloInstructionSet do_not_duplicate = ComputeGloballyUnfusible(post_order);
// Instruction fusion effectively fuses edges in the computation graph
// (producer instruction -> consumer instruction) so we iterate over all
@@ -341,7 +343,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
// consistent.
post_order_index.erase(instruction);
- if (!instruction->IsFusable() &&
+ if (!instruction->IsFusible() &&
instruction->opcode() != HloOpcode::kFusion) {
continue;
}
@@ -413,7 +415,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
for (int64 i : sorted_operand_numbers) {
HloInstruction* operand = instruction->mutable_operand(i);
- if (!operand->IsFusable()) {
+ if (!operand->IsFusible()) {
continue;
}
@@ -497,7 +499,7 @@ HloInstruction* InstructionFusion::FuseIntoMultiOutput(
bool InstructionFusion::MultiOutputFusionCreatesCycle(
HloInstruction* producer, HloInstruction* consumer) {
- return c_any_of(
+ return absl::c_any_of(
consumer->operands(), [&](const HloInstruction* consumer_operand) {
// The fusion algorithm traverses the HLO graph in reverse post order.
        // Thus `consumers` is visited before its operands (including
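
absl::c_any_of, used above, is the container-wide counterpart of std::any_of. A standalone sketch of the shape of this check (toy types, not the real HloInstruction API):

    #include <vector>

    #include "absl/algorithm/container.h"

    struct Node {
      std::vector<const Node*> operands;
    };

    // True iff any operand of `consumer` is `producer` itself; a stand-in
    // for the reachability test in the real cycle check.
    bool HasOperand(const Node& consumer, const Node* producer) {
      return absl::c_any_of(consumer.operands,
                            [&](const Node* op) { return op == producer; });
    }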
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index f73ca9adf7..00b658959a 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -36,7 +36,7 @@ class InstructionFusion : public HloPassInterface {
bool may_duplicate = true)
: is_expensive_(is_expensive), may_duplicate_(may_duplicate) {}
~InstructionFusion() override = default;
- tensorflow::StringPiece name() const override { return "fusion"; }
+ absl::string_view name() const override { return "fusion"; }
// Run instruction fusion on the given computation. Returns whether the
// computation was changed (instructions were fused).
@@ -122,8 +122,8 @@ class InstructionFusion : public HloPassInterface {
// Computes the set of nodes that we do not want to fuse into any of their
// consumers based on a global analysis of the HLO graph.
- HloInstructionSet ComputeGloballyUnfusable(
- tensorflow::gtl::ArraySlice<HloInstruction*> post_order);
+ HloInstructionSet ComputeGloballyUnfusible(
+ absl::Span<HloInstruction* const> post_order);
// Used to determine if an HLO is expensive. Expensive operations will not be
// duplicated.
diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
index 9e7a15f033..da1ad90959 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
@@ -158,7 +158,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) {
.ValueOrDie());
}
-TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) {
+TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusible) {
HloComputation::Builder builder(TestName());
auto shape = ShapeUtil::MakeShape(F32, {16, 16});
auto param0 =
@@ -216,7 +216,7 @@ TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) {
EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString();
}
-TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
+TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) {
// Make sure we do not duplicate the add, as we cannot fuse through the rng.
//
// p0 -> add -------------------------> sub
@@ -309,7 +309,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString();
// A variant of the above that allows the algorithm to put add2 into the set
- // of unfusable ops to short-circuit the decision whether add1 should be fused
+ // of unfusible ops to short-circuit the decision whether add1 should be fused
// into sub2.
//
// /---------------\
diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD
index 8652599dc6..146c9052f1 100644
--- a/tensorflow/compiler/xla/service/interpreter/BUILD
+++ b/tensorflow/compiler/xla/service/interpreter/BUILD
@@ -12,12 +12,11 @@ cc_library(
srcs = ["interpreter_transfer_manager.cc"],
hdrs = ["interpreter_transfer_manager.h"],
deps = [
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:generic_transfer_manager",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/compiler/xla/service/interpreter:platform_id",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
],
alwayslink = True, # Contains per-platform transfer manager registration
)
@@ -32,8 +31,6 @@ cc_library(
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:algebraic_simplifier",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:computation_placer",
@@ -54,6 +51,7 @@ cc_library(
"//tensorflow/compiler/xla/service:while_loop_simplifier",
"//tensorflow/core:lib",
"//tensorflow/stream_executor",
+ "@com_google_absl//absl/memory",
],
alwayslink = True, # Contains compiler registration
)
@@ -79,7 +77,6 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:hlo",
@@ -91,6 +88,8 @@ cc_library(
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
)
@@ -116,5 +115,6 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_headers_lib",
+ "@com_google_absl//absl/types:span",
],
)
diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc
index 9f8f4bda87..bb69cb9c47 100644
--- a/tensorflow/compiler/xla/service/interpreter/compiler.cc
+++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <string>
#include <utility>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
@@ -69,8 +69,8 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
// Create executable from only the Hlo module.
std::unique_ptr<Executable> executable =
- xla::MakeUnique<InterpreterExecutable>(std::move(hlo_module),
- xla::MakeUnique<HloEvaluator>());
+ absl::make_unique<InterpreterExecutable>(
+ std::move(hlo_module), absl::make_unique<HloEvaluator>());
return std::move(executable);
}
@@ -103,11 +103,11 @@ HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction()
static bool InitModule() {
xla::Compiler::RegisterCompilerFactory(
se::interpreter::kXlaInterpreterPlatformId, []() {
- return xla::MakeUnique<xla::interpreter::InterpreterCompiler>();
+ return absl::make_unique<xla::interpreter::InterpreterCompiler>();
});
xla::ComputationPlacer::RegisterComputationPlacer(
se::interpreter::kXlaInterpreterPlatformId,
- []() { return xla::MakeUnique<xla::ComputationPlacer>(); });
+ []() { return absl::make_unique<xla::ComputationPlacer>(); });
return true;
}
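
xla::MakeUnique was a project-local stand-in for C++14's std::make_unique; absl::make_unique, adopted throughout this patch, is the Abseil equivalent and also works under C++11. A sketch:

    #include <memory>

    #include "absl/memory/memory.h"

    struct Widget {
      explicit Widget(int size) : size(size) {}
      int size;
    };

    std::unique_ptr<Widget> MakeWidget() {
      // Identical in spirit to std::make_unique<Widget>(4).
      return absl::make_unique<Widget>(4);
    }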
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index 8d40c08d55..a06d6113e8 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -21,8 +21,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/interpreter/executor.h"
@@ -47,7 +47,7 @@ InterpreterExecutable::~InterpreterExecutable() {}
StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) {
se::Stream* stream = run_options->stream();
se::StreamExecutor* executor = stream->parent();
@@ -73,30 +73,29 @@ StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteOnStream(
// Transform the ShapedBuffer arguments into literals which the evaluator
// consumes.
- std::vector<std::unique_ptr<Literal>> arg_literals;
+ std::vector<Literal> arg_literals;
for (int64 p = 0; p < computation->num_parameters(); ++p) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> arg_literal,
+ TF_ASSIGN_OR_RETURN(Literal arg_literal,
transfer_manager->TransferLiteralFromDevice(
run_options->stream(), *arguments[p]));
arg_literals.push_back(std::move(arg_literal));
}
// Execute the graph using the HloEvaluator.
- std::unique_ptr<Literal> result_literal;
+ Literal result_literal;
{
tensorflow::mutex_lock lock(evaluator_lock_);
- TF_ASSIGN_OR_RETURN(result_literal,
- evaluator_->Evaluate<std::unique_ptr<Literal>>(
- *computation, arg_literals));
+ TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate<Literal>(
+ *computation, arg_literals));
}
// Transform the result literal back into a ShapedBuffer.
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
transfer_manager->AllocateScopedShapedBuffer(
- result_literal->shape(), run_options->allocator(),
+ result_literal.shape(), run_options->allocator(),
executor->device_ordinal()));
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
- run_options->stream(), *result_literal, result));
+ run_options->stream(), result_literal, result));
uint64 end_micros = tensorflow::Env::Default()->NowMicros();
@@ -111,7 +110,7 @@ StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteOnStream(
StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ absl::Span<const ShapedBuffer* const> arguments) {
return tensorflow::errors::Unimplemented(
"ExecuteAsyncOnStream is not yet supported on Interpreter.");
}
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h
index 91d8148d26..3b1ebce0c7 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.h
+++ b/tensorflow/compiler/xla/service/interpreter/executable.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
@@ -29,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -48,13 +48,13 @@ class InterpreterExecutable : public Executable {
StatusOr<ScopedShapedBuffer> ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) override
LOCKS_EXCLUDED(evaluator_lock_);
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) override;
+ absl::Span<const ShapedBuffer* const> arguments) override;
static int64 ShapeSizeBytes(const Shape& shape);
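
absl::Span<const ShapedBuffer* const>, which replaces tensorflow::gtl::ArraySlice<const ShapedBuffer*> here, is a non-owning read-only view over a contiguous array of pointers; containers with a matching element type convert implicitly. A sketch with a toy element type:

    #include <vector>

    #include "absl/types/span.h"

    struct Buffer {};  // stand-in for ShapedBuffer

    int CountNonNull(absl::Span<const Buffer* const> buffers) {
      int n = 0;
      for (const Buffer* b : buffers) {
        if (b != nullptr) ++n;  // elements are readable, not reassignable
      }
      return n;
    }

    // Usage: a std::vector<const Buffer*> converts implicitly.
    //   std::vector<const Buffer*> v = {...};
    //   int live = CountNonNull(v);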
diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h
index db6b910b32..fbb9945784 100644
--- a/tensorflow/compiler/xla/service/interpreter/executor.h
+++ b/tensorflow/compiler/xla/service/interpreter/executor.h
@@ -22,9 +22,9 @@ limitations under the License.
#include <functional>
#include <memory>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_description.h"
@@ -47,7 +47,7 @@ limitations under the License.
namespace stream_executor {
namespace interpreter {
-using Args = tensorflow::gtl::ArraySlice<DeviceMemoryBase>;
+using Args = absl::Span<const DeviceMemoryBase>;
class XlaInterpreterExecutor : public internal::StreamExecutorInterface {
public:
diff --git a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc
index d27cd7502f..7955ee5cf3 100644
--- a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <memory>
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
@@ -31,7 +31,7 @@ InterpreterTransferManager::InterpreterTransferManager()
static std::unique_ptr<xla::TransferManager>
CreateInterpreterTransferManager() {
- return xla::MakeUnique<xla::InterpreterTransferManager>();
+ return absl::make_unique<xla::InterpreterTransferManager>();
}
static bool InitModule() {
diff --git a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h
index 2b44f30821..b732230fdd 100644
--- a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_TRANSFER_MANAGER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_TRANSFER_MANAGER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_INTERPRETER_TRANSFER_MANAGER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_INTERPRETER_TRANSFER_MANAGER_H_
#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
#include "tensorflow/core/platform/macros.h"
@@ -33,4 +33,4 @@ class InterpreterTransferManager : public GenericTransferManager {
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_TRANSFER_MANAGER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_INTERPRETER_TRANSFER_MANAGER_H_
diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc
index 42c2c28997..c9b40d3c61 100644
--- a/tensorflow/compiler/xla/service/interpreter/platform.cc
+++ b/tensorflow/compiler/xla/service/interpreter/platform.cc
@@ -17,13 +17,14 @@ limitations under the License.
#include <utility>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/interpreter/executor.h"
#include "tensorflow/stream_executor/device_options.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/ptr_util.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/status_macros.h"
-#include "tensorflow/stream_executor/lib/stringprintf.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/platform.h"
@@ -70,15 +71,15 @@ port::StatusOr<StreamExecutor*> XlaInterpreterPlatform::GetExecutor(
port::StatusOr<std::unique_ptr<StreamExecutor>>
XlaInterpreterPlatform::GetUncachedExecutor(
const StreamExecutorConfig& config) {
- auto executor = MakeUnique<StreamExecutor>(
- this, MakeUnique<XlaInterpreterExecutor>(config.plugin_config));
+ auto executor = absl::make_unique<StreamExecutor>(
+ this, absl::make_unique<XlaInterpreterExecutor>(config.plugin_config));
auto init_status = executor->Init(config.ordinal, config.device_options);
if (!init_status.ok()) {
return port::Status{
port::error::INTERNAL,
- port::Printf(
+ absl::StrFormat(
"failed initializing StreamExecutor for device ordinal %d: %s",
- config.ordinal, init_status.ToString().c_str())};
+ config.ordinal, init_status.ToString())};
}
return std::move(executor);
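
absl::StrFormat, substituted for port::Printf in this hunk, checks argument types against the format string at compile time and binds %s to std::string or string_view directly, which is why the .c_str() call disappears above. A standalone sketch:

    #include <string>

    #include "absl/strings/str_format.h"

    std::string FormatInitError(int ordinal, const std::string& detail) {
      // %d and %s are type-checked against the arguments when compiling.
      return absl::StrFormat(
          "failed initializing StreamExecutor for device ordinal %d: %s",
          ordinal, detail);
    }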
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 805fdb2d5b..082bf8bffe 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -26,9 +26,13 @@ limitations under the License.
#include <string>
#include <tuple>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -48,21 +52,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
namespace xla {
-// For now moving only one API here, but we should have a single top level
-// anonymous namespace, instead of three or four spread all over this file.
-namespace {
-
-} // namespace
-
std::ostream& operator<<(std::ostream& out,
const LayoutConstraint& constraint) {
out << constraint.ToString();
@@ -77,9 +71,8 @@ BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout,
}
string BufferLayoutConstraint::ToString() const {
- return tensorflow::strings::Printf("BufferLayoutConstraint %s: %s",
- buffer_->ToString().c_str(),
- LayoutUtil::HumanString(layout_).c_str());
+ return absl::StrFormat("BufferLayoutConstraint %s: %s", buffer_->ToString(),
+ LayoutUtil::HumanString(layout_));
}
OperandLayoutConstraint::OperandLayoutConstraint(
@@ -98,15 +91,14 @@ OperandLayoutConstraint::OperandLayoutConstraint(
}
string OperandLayoutConstraint::ToString() const {
- return tensorflow::strings::Printf(
- "OperandLayoutConstraint %s, operand %lld: %s",
- instruction_->name().c_str(), operand_no_,
- shape_layout_.ToString().c_str());
+ return absl::StrFormat("OperandLayoutConstraint %s, operand %d: %s",
+ instruction_->name(), operand_no_,
+ shape_layout_.ToString());
}
string ResultLayoutConstraint::ToString() const {
- return tensorflow::strings::Printf("ResultLayoutConstraint: %s",
- shape_layout_.ToString().c_str());
+ return absl::StrFormat("ResultLayoutConstraint: %s",
+ shape_layout_.ToString());
}
LayoutConstraints::LayoutConstraints(
@@ -137,7 +129,7 @@ PointsToSet::BufferSet* LayoutConstraints::GetBufferSet(
}
auto& buffer_set =
buffer_sets_cache_
- .emplace(instruction, MakeUnique<PointsToSet::BufferSet>())
+ .emplace(instruction, absl::make_unique<PointsToSet::BufferSet>())
.first->second;
const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction);
points_to_set.ForEachElement(
@@ -174,8 +166,7 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout,
return FailedPrecondition(
"Layout of buffer %s cannot be constrained because buffer is not "
"array-shaped, has shape: %s",
- buffer.ToString().c_str(),
- ShapeUtil::HumanString(buffer.shape()).c_str());
+ buffer.ToString(), ShapeUtil::HumanString(buffer.shape()));
}
TF_RETURN_IF_ERROR(
LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()));
@@ -191,9 +182,8 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout,
return FailedPrecondition(
"Buffer %s already has the layout constraint %s, cannot add "
"incompatible constraint %s",
- buffer.ToString().c_str(),
- LayoutUtil::HumanString(curr_constraint.layout()).c_str(),
- LayoutUtil::HumanString(layout).c_str());
+ buffer.ToString(), LayoutUtil::HumanString(curr_constraint.layout()),
+ LayoutUtil::HumanString(layout));
}
iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs);
} else {
@@ -227,11 +217,11 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
}
if (curr_shape_layout->mandatory()) {
return FailedPrecondition(
- "Operand %lld of instruction %s already has a layout constraint "
+ "Operand %d of instruction %s already has a layout constraint "
"%s, cannot add incompatible constraint %s",
- operand_no, instruction->name().c_str(),
- curr_shape_layout->shape_layout().ToString().c_str(),
- ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str());
+ operand_no, instruction->name(),
+ curr_shape_layout->shape_layout().ToString(),
+ ShapeUtil::HumanStringWithLayout(shape_with_layout));
}
}
@@ -240,9 +230,9 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
// layouts beyond this immediate use and is complicated to handle.
if (OperandBufferForwarded(instruction, operand_no)) {
return FailedPrecondition(
- "Cannot constraint layout of operand %lld of instruction %s "
+ "Cannot constraint layout of operand %d of instruction %s "
"because instruction forwards operand's LogicalBuffer(s)",
- operand_no, instruction->name().c_str());
+ operand_no, instruction->name());
}
auto key = std::make_pair(instruction, operand_no);
@@ -284,8 +274,8 @@ Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout,
return FailedPrecondition(
"Result of computation %s already has the layout constraint %s, "
"cannot add incompatible constraint %s",
- computation_->name().c_str(), curr_shape_layout->ToString().c_str(),
- ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str());
+ computation_->name(), curr_shape_layout->ToString(),
+ ShapeUtil::HumanStringWithLayout(shape_with_layout));
}
// New constraint matches existing constraint. Nothing to do.
return Status::OK();
@@ -307,9 +297,8 @@ Status LayoutConstraints::SetInstructionLayout(
if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) {
return FailedPrecondition(
"Instruction %s of shape %s cannot be assigned incompatible layout %s",
- instruction->name().c_str(),
- ShapeUtil::HumanString(instruction->shape()).c_str(),
- ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str());
+ instruction->name(), ShapeUtil::HumanString(instruction->shape()),
+ ShapeUtil::HumanStringWithLayout(shape_with_layout));
}
// Create a BufferLayoutConstraint for each array shape in the output of the
@@ -368,31 +357,27 @@ const ShapeLayout* LayoutConstraints::ResultLayout() const {
string LayoutConstraints::ToString() const {
string output;
- tensorflow::strings::StrAppend(&output, "LayoutConstraints for computation ",
- computation_->name(), ":\n");
+ absl::StrAppend(&output, "LayoutConstraints for computation ",
+ computation_->name(), ":\n");
for (auto* instruction : computation_->MakeInstructionPostOrder()) {
- tensorflow::strings::StrAppend(&output, " ", instruction->ToShortString(),
- "\n");
+ absl::StrAppend(&output, " ", instruction->ToShortString(), "\n");
for (int64 i = 0; i < instruction->operand_count(); ++i) {
if (OperandLayout(instruction, i) != nullptr) {
- tensorflow::strings::StrAppend(
- &output, " operand (", i,
- "): ", OperandLayout(instruction, i)->ToString(), "\n");
+ absl::StrAppend(&output, " operand (", i,
+ "): ", OperandLayout(instruction, i)->ToString(), "\n");
}
}
for (const LogicalBuffer* buffer :
points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
if (BufferLayout(*buffer) != nullptr) {
- tensorflow::strings::StrAppend(
- &output, " ", buffer->ToString(), " : ",
- LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n");
+ absl::StrAppend(&output, " ", buffer->ToString(), " : ",
+ LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n");
}
}
}
if (ResultLayout() != nullptr) {
- tensorflow::strings::StrAppend(&output, " => ", ResultLayout()->ToString(),
- "\n");
+ absl::StrAppend(&output, " => ", ResultLayout()->ToString(), "\n");
}
return output;
}
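The hunks above are a mechanical migration from tensorflow::strings::StrAppend to absl::StrAppend. A minimal, self-contained sketch of the replacement call shape (names here are illustrative, not from this patch):

#include <string>
#include "absl/strings/str_cat.h"

std::string DescribeOperand(const std::string& name, int operand_no) {
  std::string output;
  // absl::StrAppend accepts any mix of string-like and numeric arguments,
  // so operand_no needs no explicit conversion.
  absl::StrAppend(&output, "    operand (", operand_no, "): ", name, "\n");
  return output;
}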
@@ -763,7 +748,7 @@ Status CheckParameterLayout(HloInstruction* parameter,
return InternalError(
"parameter instruction %s does not match layout of computation "
"shape: %s",
- parameter->ToString().c_str(), parameter_layout.ToString().c_str());
+ parameter->ToString(), parameter_layout.ToString());
}
return Status::OK();
}
@@ -774,8 +759,8 @@ Status CheckConstantLayout(HloInstruction* constant) {
constant->shape())) {
return InternalError(
"constant instruction %s does not match the layout of its literal %s",
- constant->ToString().c_str(),
- ShapeUtil::HumanStringWithLayout(constant->literal().shape()).c_str());
+ constant->ToString(),
+ ShapeUtil::HumanStringWithLayout(constant->literal().shape()));
}
return Status::OK();
}
@@ -870,8 +855,7 @@ void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction,
? instruction.sharding().GetSubSharding(instruction.shape(), index)
: instruction.sharding();
// We propagate the sharding to the copied instruction only if it is a
- // special sharding, like tiled ones, or special devices like the
- // HostCompute module.
+ // special sharding, like tiled ones.
  // Otherwise it is preferable to leave the new instruction without a device,
  // and let the automatic device placer choose the best location.
auto device = sharding.UniqueDevice();
@@ -908,13 +892,10 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
return InternalError(
"Layout of instruction %s at index {%s} does not match "
"source LogicalBuffer %s: %s vs %s",
- instruction->name().c_str(),
- tensorflow::str_util::Join(index, ",").c_str(),
- buffer->ToString().c_str(),
- ShapeUtil::HumanStringWithLayout(instruction_subshape)
- .c_str(),
- ShapeUtil::HumanStringWithLayout(buffer->shape())
- .c_str());
+ instruction->name(), absl::StrJoin(index, ","),
+ buffer->ToString(),
+ ShapeUtil::HumanStringWithLayout(instruction_subshape),
+ ShapeUtil::HumanStringWithLayout(buffer->shape()));
}
}
}
@@ -998,17 +979,18 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
CHECK(ShapeUtil::IsArray(instruction->shape()));
CHECK(ShapeUtil::IsArray(operand->shape()));
- if (instruction->IsElementwiseOnOperand(operand_no) &&
- !ShapeUtil::IsScalar(operand->shape()) &&
+ if (!ShapeUtil::IsScalar(operand->shape()) &&
ShapeUtil::Rank(operand->shape()) ==
- ShapeUtil::Rank(instruction->shape())) {
- // Assign operands the same layout as the instruction, so that
+ ShapeUtil::Rank(instruction->shape()) &&
+ InstructionRequiresInputLayoutEqualToOutputLayout(instruction)) {
+ // Propagate the result layout to the operand layout if the instruction
+    // requires the same layout for the result and the operand.
+ //
+ // For elementwise operations, using the same layout for the operands and
+ // the result also has the following benefits:
// 1) the elementwise operation can reuse its operand's buffer, and
// 2) the input and output elements can reuse the same linear index.
- //
- // TODO(jingyue): Other operations, such as kSlice and kConcat, can benefit
- // from assigning the same layout to input and output.
- return MakeUnique<Layout>(output_layout);
+ return absl::make_unique<Layout>(output_layout);
}
if (instruction->opcode() == HloOpcode::kReshape) {
@@ -1031,13 +1013,13 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
*operand_shape.mutable_layout() =
LayoutUtil::GetDefaultLayoutForShape(operand_shape);
if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) {
- return MakeUnique<Layout>(operand_shape.layout());
+ return absl::make_unique<Layout>(operand_shape.layout());
}
if (ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)) {
*operand_shape.mutable_layout() = output_layout;
if (ShapeUtil::ReshapeIsBitcast(operand_shape,
output_shape_with_layout)) {
- return MakeUnique<Layout>(output_layout);
+ return absl::make_unique<Layout>(output_layout);
}
}
auto aligned_operand_shape =
@@ -1046,7 +1028,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
auto operand_layout = aligned_operand_shape.value().layout();
TF_CHECK_OK(
LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape));
- return MakeUnique<Layout>(operand_layout);
+ return absl::make_unique<Layout>(operand_layout);
}
}
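For the reshape branch above, a worked example of the bitcast check, assuming the standard ShapeUtil helpers (a sketch, not lines from this patch):

// Row-major [2,3] and rank-1 [6] lay elements out identically in memory,
// so the reshape moves no data and the layout can be forwarded copy-free.
Shape from = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});  // row-major
Shape to = ShapeUtil::MakeShapeWithLayout(F32, {6}, {0});
bool is_bitcast = ShapeUtil::ReshapeIsBitcast(from, to);  // true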
@@ -1062,7 +1044,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major);
TF_CHECK_OK(
LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
- return MakeUnique<Layout>(operand_layout);
+ return absl::make_unique<Layout>(operand_layout);
}
return nullptr;
@@ -1076,11 +1058,11 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
CHECK(ShapeUtil::IsArray(user->shape()) &&
ShapeUtil::IsArray(operand->shape()));
- if (user->IsElementwiseOnOperand(operand_no) &&
- !ShapeUtil::IsScalar(operand->shape()) &&
- ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape())) {
+ if (!ShapeUtil::IsScalar(operand->shape()) &&
+ ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape()) &&
+ InstructionRequiresInputLayoutEqualToOutputLayout(user)) {
// Assign users the same layout as the operand.
- return MakeUnique<Layout>(operand_layout);
+ return absl::make_unique<Layout>(operand_layout);
}
if (user->opcode() == HloOpcode::kReshape) {
@@ -1103,13 +1085,13 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
*output_shape.mutable_layout() =
LayoutUtil::GetDefaultLayoutForShape(output_shape);
if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) {
- return MakeUnique<Layout>(output_shape.layout());
+ return absl::make_unique<Layout>(output_shape.layout());
}
if (ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)) {
*output_shape.mutable_layout() = operand_layout;
if (ShapeUtil::ReshapeIsBitcast(output_shape,
operand_shape_with_layout)) {
- return MakeUnique<Layout>(operand_layout);
+ return absl::make_unique<Layout>(operand_layout);
}
}
auto aligned_user_shape =
@@ -1118,7 +1100,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
auto user_layout = aligned_user_shape.value().layout();
TF_CHECK_OK(
LayoutUtil::ValidateLayoutForShape(user_layout, output_shape));
- return MakeUnique<Layout>(user_layout);
+ return absl::make_unique<Layout>(user_layout);
}
}
@@ -1134,7 +1116,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
}
Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major);
TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
- return MakeUnique<Layout>(user_layout);
+ return absl::make_unique<Layout>(user_layout);
}
return nullptr;
@@ -1385,7 +1367,7 @@ StatusOr<Layout> InferArrayLayout(
// This should not happen because we've assigned layouts to all
// instructions preceding this one.
return InternalError("LogicalBuffer %s does not have a layout",
- source_buffer->ToString().c_str());
+ source_buffer->ToString());
}
if (first_buffer_layout == nullptr) {
@@ -1400,9 +1382,8 @@ StatusOr<Layout> InferArrayLayout(
return FailedPrecondition(
"Array at index {%s} in instruction %s aliases buffers %s "
"and %s which have different layouts",
- tensorflow::str_util::Join(index, ",").c_str(),
- instruction->name().c_str(), source_buffers[0]->ToString().c_str(),
- source_buffer->ToString().c_str());
+ absl::StrJoin(index, ","), instruction->name(),
+ source_buffers[0]->ToString(), source_buffer->ToString());
}
}
@@ -1570,7 +1551,7 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
// present in the IR before layout assignment is a bug.
return InternalError(
"Unexpected bitcast operation seen during layout assignment: %s.",
- instruction->ToString().c_str());
+ instruction->ToString());
}
if (instruction->opcode() != HloOpcode::kInfeed) {
LayoutUtil::ClearLayout(instruction->mutable_shape());
@@ -1822,6 +1803,107 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
return true;
}
+bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout(
+ const HloInstruction* instruction) {
+ switch (instruction->opcode()) {
+ case HloOpcode::kAbs:
+ case HloOpcode::kAdd:
+ case HloOpcode::kAnd:
+ case HloOpcode::kAtan2:
+ case HloOpcode::kBitcastConvert:
+ case HloOpcode::kCeil:
+ case HloOpcode::kClamp:
+ case HloOpcode::kClz:
+ case HloOpcode::kComplex:
+ case HloOpcode::kConcatenate:
+ case HloOpcode::kConditional:
+ case HloOpcode::kConvert:
+ case HloOpcode::kCos:
+ case HloOpcode::kCrossReplicaSum:
+ case HloOpcode::kAllToAll:
+ case HloOpcode::kCollectivePermute:
+ case HloOpcode::kCustomCall:
+ case HloOpcode::kDivide:
+ case HloOpcode::kDynamicSlice:
+ case HloOpcode::kDynamicUpdateSlice:
+ case HloOpcode::kEq:
+ case HloOpcode::kExp:
+ case HloOpcode::kExpm1:
+ case HloOpcode::kFft:
+ case HloOpcode::kFloor:
+ case HloOpcode::kGe:
+ case HloOpcode::kGt:
+ case HloOpcode::kImag:
+ case HloOpcode::kIsFinite:
+ case HloOpcode::kLe:
+ case HloOpcode::kLog:
+ case HloOpcode::kLog1p:
+ case HloOpcode::kLt:
+ case HloOpcode::kMap:
+ case HloOpcode::kMaximum:
+ case HloOpcode::kMinimum:
+ case HloOpcode::kMultiply:
+ case HloOpcode::kNe:
+ case HloOpcode::kNegate:
+ case HloOpcode::kNot:
+ case HloOpcode::kOr:
+ case HloOpcode::kXor:
+ case HloOpcode::kPad:
+ case HloOpcode::kPower:
+ case HloOpcode::kReal:
+ case HloOpcode::kReducePrecision:
+ case HloOpcode::kReduceWindow:
+ case HloOpcode::kRemainder:
+ case HloOpcode::kReverse:
+ case HloOpcode::kRoundNearestAfz:
+ case HloOpcode::kSelect:
+ case HloOpcode::kSelectAndScatter:
+ case HloOpcode::kShiftLeft:
+ case HloOpcode::kShiftRightArithmetic:
+ case HloOpcode::kShiftRightLogical:
+ case HloOpcode::kSign:
+ case HloOpcode::kSin:
+ case HloOpcode::kSlice:
+ case HloOpcode::kSort:
+ case HloOpcode::kSubtract:
+ case HloOpcode::kTanh:
+ case HloOpcode::kTupleSelect:
+ case HloOpcode::kWhile:
+ return true;
+ case HloOpcode::kBatchNormGrad:
+ case HloOpcode::kBatchNormInference:
+ case HloOpcode::kBatchNormTraining:
+ case HloOpcode::kBitcast:
+ case HloOpcode::kBroadcast:
+ case HloOpcode::kCall:
+ case HloOpcode::kConstant:
+ case HloOpcode::kConvolution:
+ case HloOpcode::kCopy:
+ case HloOpcode::kDomain:
+ case HloOpcode::kDot:
+ case HloOpcode::kFusion:
+ case HloOpcode::kGather:
+ case HloOpcode::kGetTupleElement:
+ case HloOpcode::kInfeed:
+ case HloOpcode::kIota:
+ case HloOpcode::kOutfeed:
+ case HloOpcode::kParameter:
+ case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
+ case HloOpcode::kReduce:
+ case HloOpcode::kReshape:
+ case HloOpcode::kRng:
+ case HloOpcode::kScatter:
+ case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
+ case HloOpcode::kAfterAll:
+ case HloOpcode::kTrace:
+ case HloOpcode::kTranspose:
+ case HloOpcode::kTuple:
+ return false;
+ }
+}
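Because the new query is declared virtual in layout_assignment.h (see the header diff below), a backend pass can relax it per opcode. A hedged sketch of such an override; the subclass name and the kSort carve-out are illustrative assumptions, not part of this patch:

class ExampleLayoutAssignment : public LayoutAssignment {
 public:
  using LayoutAssignment::LayoutAssignment;

  bool InstructionRequiresInputLayoutEqualToOutputLayout(
      const HloInstruction* instruction) override {
    // Hypothetical: let this backend pick sort operand layouts freely.
    if (instruction->opcode() == HloOpcode::kSort) {
      return false;
    }
    return LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout(
        instruction);
  }
};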
+
Status LayoutAssignment::Init() {
computation_layouts_.clear();
*entry_computation_layout_ = saved_entry_computation_layout_;
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index f9e8dbea2f..cf545031d3 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -297,12 +297,17 @@ class LayoutAssignment : public HloPassInterface {
ComputationLayout* entry_computation_layout,
ChannelLayoutConstraints* channel_constraints = nullptr);
~LayoutAssignment() override {}
- tensorflow::StringPiece name() const override { return "layout-assignment"; }
+ absl::string_view name() const override { return "layout-assignment"; }
// Assign layouts to the given module. Returns whether the module was changed
// (any layouts were changed).
StatusOr<bool> Run(HloModule* module) override;
+ // Returns true if the instruction requires that operands with the same rank
+  // as the output have the same layout as the output.
+ virtual bool InstructionRequiresInputLayoutEqualToOutputLayout(
+ const HloInstruction* instruction);
+
protected:
// These methods, invoked by PropagateConstraints, propagate a layout
// constraint to its neighbors (i.e. operands and users) in order to minimize
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index a16fa75e30..752a61476d 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
@@ -34,13 +35,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace op = xla::testing::opcode_matchers;
@@ -49,7 +49,7 @@ namespace {
using ::testing::ElementsAre;
-class LayoutAssignmentTest : public HloTestBase {
+class LayoutAssignmentTest : public HloVerifiedTestBase {
protected:
void AssignLayouts(HloModule* module,
ComputationLayout* entry_computation_layout,
@@ -59,7 +59,7 @@ class LayoutAssignmentTest : public HloTestBase {
EXPECT_IS_OK(layout_assignment.Run(module).status());
}
- std::vector<int64> LayoutOf(HloModule* module, tensorflow::StringPiece name) {
+ std::vector<int64> LayoutOf(HloModule* module, absl::string_view name) {
auto minor_to_major =
FindInstruction(module, name)->shape().layout().minor_to_major();
return std::vector<int64>(minor_to_major.begin(), minor_to_major.end());
@@ -91,7 +91,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) {
*computation_layout.mutable_parameter_layout(0) = shape_layout;
*computation_layout.mutable_parameter_layout(1) = shape_layout;
*computation_layout.mutable_result_layout() = shape_layout;
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout()));
@@ -127,7 +127,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) {
*computation_layout.mutable_parameter_layout(1) = row_major;
*computation_layout.mutable_result_layout() = col_major;
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout()));
EXPECT_TRUE(LayoutUtil::Equal(
@@ -145,7 +145,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>(
{{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
- Shape ashape = constant_literal1->shape();
+ Shape ashape = constant_literal1.shape();
auto constant1 = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(constant_literal1)));
@@ -172,7 +172,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
ComputationLayout computation_layout(computation->ComputeProgramShape());
*computation_layout.mutable_result_layout() = shape_layout;
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(LayoutUtil::Equal(
layout, fusion->fused_parameter(0)->shape().layout()));
@@ -213,7 +213,7 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
ComputationLayout computation_layout(
module->entry_computation()->ComputeProgramShape());
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(
LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape()));
@@ -243,7 +243,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) {
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1));
+ tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -255,7 +255,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) {
TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
result_shape));
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape()));
}
@@ -294,7 +294,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
result_shape));
LayoutAssignment layout_assignment(&computation_layout);
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
// Layout assignment should have deep copied the result of the computation to
// address the layout conflict. This results in several Tuple() and
@@ -310,7 +310,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
EXPECT_TRUE(
AlgebraicSimplifier(/*is_layout_sensitive=*/true,
[](const Shape&, const Shape&) { return false; })
- .Run(module.get())
+ .Run(module)
.ValueOrDie());
HloInstruction* root = module->entry_computation()->root_instruction();
// Verify layout of the root and the root's operands.
@@ -352,7 +352,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) {
*computation_layout.mutable_parameter_layout(0) =
ShapeLayout(ashape_with_layout);
*computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
auto log_minor_to_major =
AsInt64Slice(log->shape().layout().minor_to_major());
@@ -393,7 +393,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) {
*computation_layout.mutable_parameter_layout(0) =
ShapeLayout(ashape_with_layout);
*computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(
LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout()));
@@ -432,7 +432,7 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) {
ShapeLayout(input_shape_with_layout);
*computation_layout.mutable_result_layout() =
ShapeLayout(output_shape_with_layout);
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_THAT(broadcast->shape().layout().minor_to_major(),
ElementsAre(0, 1, 2));
@@ -457,13 +457,13 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, f32_4, "param"));
auto broadcast = builder.AddInstruction(
- HloInstruction::CreateBroadcast(f32_34, param, {3}));
+ HloInstruction::CreateBroadcast(f32_34, param, {1}));
auto transpose = builder.AddInstruction(
HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0}));
auto tanh = builder.AddInstruction(
HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast));
auto broadcast2 = builder.AddInstruction(
- HloInstruction::CreateBroadcast(f32_234, tanh, {2}));
+ HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2}));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({transpose, broadcast2}));
auto module = CreateNewModule();
@@ -485,7 +485,7 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
*computation_layout.mutable_result_layout() =
ShapeLayout(ShapeUtil::MakeTupleShape(
{transpose_shape_with_layout, broadcast2_shape_with_layout}));
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1));
EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0));
@@ -551,7 +551,7 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) {
*computation_layout.mutable_parameter_layout(1) =
ShapeLayout(param1_shape_with_layout);
OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout);
- EXPECT_IS_OK(layout_assignment.Run(module.get()).status());
+ EXPECT_IS_OK(layout_assignment.Run(module).status());
EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode());
EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(),
@@ -575,7 +575,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) {
HloComputation* computation =
module->AddEntryComputation(builder.Build(transpose));
ComputationLayout computation_layout(computation->ComputeProgramShape());
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
transpose->shape(), {2, 3, 0, 1}));
}
@@ -593,7 +593,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) {
HloComputation* computation =
module->AddEntryComputation(builder.Build(transpose));
ComputationLayout computation_layout(computation->ComputeProgramShape());
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
transpose->shape(), {2, 3, 0, 1}));
}
@@ -659,18 +659,18 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
- module =
+ std::unique_ptr<HloModule> compiled_module =
backend()
.compiler()
- ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
EXPECT_EQ(Status::OK(), backend()
.compiler()
- ->RunBackend(std::move(module),
+ ->RunBackend(std::move(compiled_module),
backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.status());
@@ -699,9 +699,9 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
ComputationLayout computation_layout(
- module->entry_computation()->ComputeProgramShape());
+ module().entry_computation()->ComputeProgramShape());
Shape param_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}),
ShapeUtil::MakeTupleShape({
@@ -713,19 +713,19 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
param_shape));
computation_layout.mutable_result_layout()->ResetLayout(
LayoutUtil::MakeLayout({2, 1, 0}));
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(&module(), &computation_layout);
- EXPECT_THAT(LayoutOf(module.get(), "gte0"), ElementsAre(0, 1, 2));
- EXPECT_THAT(LayoutOf(module.get(), "gte1a"), ElementsAre(1, 2, 0));
- EXPECT_THAT(LayoutOf(module.get(), "gte1b"), ElementsAre(2, 0, 1));
- EXPECT_THAT(LayoutOf(module.get(), "fresult"), ElementsAre(2, 1, 0));
- EXPECT_THAT(FindInstruction(module.get(), "gte1")
+ EXPECT_THAT(LayoutOf(&module(), "gte0"), ElementsAre(0, 1, 2));
+ EXPECT_THAT(LayoutOf(&module(), "gte1a"), ElementsAre(1, 2, 0));
+ EXPECT_THAT(LayoutOf(&module(), "gte1b"), ElementsAre(2, 0, 1));
+ EXPECT_THAT(LayoutOf(&module(), "fresult"), ElementsAre(2, 1, 0));
+ EXPECT_THAT(FindInstruction(&module(), "gte1")
->shape()
.tuple_shapes(0)
.layout()
.minor_to_major(),
ElementsAre(1, 2, 0));
- EXPECT_THAT(FindInstruction(module.get(), "gte1")
+ EXPECT_THAT(FindInstruction(&module(), "gte1")
->shape()
.tuple_shapes(1)
.layout()
@@ -785,7 +785,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
HloComputation* computation = module->AddEntryComputation(builder.Build());
ComputationLayout computation_layout(computation->ComputeProgramShape());
- AssignLayouts(module.get(), &computation_layout);
+ AssignLayouts(module, &computation_layout);
const HloInstruction* true_root = true_computation->root_instruction();
const HloInstruction* false_root = false_computation->root_instruction();
@@ -812,7 +812,7 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
ComputationLayout computation_layout(
module->entry_computation()->ComputeProgramShape());
LayoutAssignment layout_assignment(&computation_layout);
- Status error_status = layout_assignment.Run(module.get()).status();
+ Status error_status = layout_assignment.Run(module).status();
EXPECT_FALSE(error_status.ok());
EXPECT_THAT(
error_status.error_message(),
@@ -839,9 +839,9 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
}
)";
- auto module = ParseHloString(module_str).ValueOrDie();
+ ParseAndVerifyModule(module_str);
ComputationLayout computation_layout(
- module->entry_computation()->ComputeProgramShape());
+ module().entry_computation()->ComputeProgramShape());
Shape param_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
TF_ASSERT_OK(
@@ -851,14 +851,151 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
LayoutUtil::MakeLayout({1, 0}));
ChannelLayoutConstraints channel_constraints;
- AssignLayouts(module.get(), &computation_layout, &channel_constraints);
+ AssignLayouts(&module(), &computation_layout, &channel_constraints);
- EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1));
- EXPECT_THAT(LayoutOf(module.get(), "root"), ElementsAre(1, 0));
- EXPECT_TRUE(
- ShapeUtil::Equal(ShapeUtil::GetSubshape(
- FindInstruction(module.get(), "send")->shape(), {0}),
- ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
+ EXPECT_THAT(LayoutOf(&module(), "gte"), ElementsAre(0, 1));
+ EXPECT_THAT(LayoutOf(&module(), "root"), ElementsAre(1, 0));
+ EXPECT_TRUE(ShapeUtil::Equal(
+ ShapeUtil::GetSubshape(FindInstruction(&module(), "send")->shape(), {0}),
+ ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
+}
+
+TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
+ const char* module_str = R"(
+ HloModule CopySliceOperandToAvoidImplicitLayoutChange
+
+ ENTRY CopySliceOperandToAvoidImplicitLayoutChange {
+ par0 = f32[3,4]{1,0} parameter(0)
+ par1 = f32[4,5]{0,1} parameter(1)
+ slice0 = f32[3,4] slice(par1), slice={[1:4],[1:5]}
+ ROOT add0 = f32[3,4]{1,0} add(par0,slice0)
+ }
+ )";
+
+ ParseAndVerifyModule(module_str);
+ auto compiled_module =
+ backend()
+ .compiler()
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
+ /*device_allocator=*/nullptr)
+ .ConsumeValueOrDie();
+ HloInstruction* root =
+ compiled_module->entry_computation()->root_instruction();
+ Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
+ EXPECT_THAT(root, op::Add(op::Parameter(),
+ op::Slice(AllOf(op::Copy(op::Parameter(1)),
+ op::ShapeWithLayout(shape_copy)))));
+}
+
+TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
+ const char* module_str = R"(
+ HloModule CopyDSliceOperandToAvoidImplicitLayoutChange
+
+ ENTRY CopyDSliceOperandToAvoidImplicitLayoutChange {
+ par0 = f32[3,4]{1,0} parameter(0)
+ par1 = f32[4,5]{0,1} parameter(1)
+ par2 = s32[2] parameter(2)
+ dslice0 = f32[3,4] dynamic-slice(par1, par2), dynamic_slice_sizes={3,4}
+ ROOT add0 = f32[3,4]{1,0} add(par0,dslice0)
+ }
+ )";
+
+ ParseAndVerifyModule(module_str);
+ auto compiled_module =
+ backend()
+ .compiler()
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
+ /*device_allocator=*/nullptr)
+ .ConsumeValueOrDie();
+ HloInstruction* root =
+ compiled_module->entry_computation()->root_instruction();
+ Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
+ EXPECT_THAT(root,
+ op::Add(op::Parameter(),
+ op::DynamicSlice(AllOf(op::Copy(op::Parameter(1)),
+ op::ShapeWithLayout(shape_copy)),
+ op::Parameter(2))));
+}
+
+TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
+ const char* module_str = R"(
+ HloModule CopyConcatOperandToAvoidImplicitLayoutChange
+
+ ENTRY CopyConcatOperandToAvoidImplicitLayoutChange {
+ par0 = f32[3,8]{1,0} parameter(0)
+ par1 = f32[3,5]{0,1} parameter(1)
+ par2 = f32[3,3]{1,0} parameter(2)
+ concat0 = f32[3,8] concatenate(f32[3,5] par1, f32[3,3] par2),
+ dimensions={1}
+ ROOT add0 = f32[3,8]{1,0} add(par0,concat0)
+ }
+ )";
+
+ ParseAndVerifyModule(module_str);
+ auto compiled_module =
+ backend()
+ .compiler()
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
+ /*device_allocator=*/nullptr)
+ .ConsumeValueOrDie();
+ HloInstruction* root =
+ compiled_module->entry_computation()->root_instruction();
+ Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0});
+ EXPECT_THAT(root,
+ op::Add(op::Parameter(),
+ op::Concatenate(AllOf(op::Copy(op::Parameter(1)),
+ op::ShapeWithLayout(shape_copy)),
+ op::Parameter(2))));
+}
+
+TEST_F(LayoutAssignmentTest,
+ ConvolutionOperandWithImplicitLayoutChangeNotCopied) {
+ const char* module_str = R"(
+ HloModule ConvolutionOperandWithImplicitLayoutChangeNotCopied
+
+ ENTRY ConvolutionOperandWithImplicitLayoutChangeNotCopied {
+ par0 = f32[128,3,230,230]{2,3,1,0} parameter(0)
+ par1 = f32[7,7,3,64]{3,2,0,1} parameter(1)
+ ROOT convolution0 = f32[128,64,112,112]{3,2,1,0} convolution(par0, par1),
+ window={size=7x7 stride=2x2}, dim_labels=bf01_01io->bf01,
+ feature_group_count=1
+ }
+ )";
+
+ ParseAndVerifyModule(module_str);
+ auto compiled_module =
+ backend()
+ .compiler()
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
+ /*device_allocator=*/nullptr)
+ .ConsumeValueOrDie();
+ HloInstruction* root =
+ compiled_module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Convolution(op::Parameter(0), op::Parameter(1)));
+}
+
+TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) {
+ const char* module_str = R"(
+ HloModule PropagatingLayoutFromResultToOperand
+
+ ENTRY PropagatingLayoutFromResultToOperand {
+ par0 = f32[4,5]{1,0} parameter(0)
+ ROOT slice0 = f32[3,4]{0,1} slice(par0), slice={[1:4],[1:5]}
+ }
+ )";
+
+ ParseAndVerifyModule(module_str);
+ auto compiled_module =
+ backend()
+ .compiler()
+ ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
+ /*device_allocator=*/nullptr)
+ .ConsumeValueOrDie();
+ HloInstruction* root =
+ compiled_module->entry_computation()->root_instruction();
+ Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1});
+ EXPECT_THAT(root, op::Slice(AllOf(op::Copy(op::Parameter(0)),
+ op::ShapeWithLayout(shape_copy))));
}
} // namespace
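The {...} suffixes in the HLO snippets above are minor-to-major layout lists. A minimal sketch of the notation, assuming the standard ShapeUtil/LayoutUtil helpers:

// f32[4,5]{1,0}: dimension 1 is minor, i.e. row-major.
Shape row_major = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
// f32[4,5]{0,1}: dimension 0 is minor, i.e. column-major.
Shape col_major = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1});
// The two layouts differ, which is what forces the copies asserted above:
// LayoutUtil::Equal(row_major.layout(), col_major.layout()) == false.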
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
index cdd3daf73b..540bbb7c7a 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/BUILD
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -38,6 +38,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:logical_buffer",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
],
)
@@ -69,6 +70,8 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
"@llvm//:support",
"@llvm//:target",
@@ -88,6 +91,9 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
],
)
@@ -103,6 +109,8 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
],
)
@@ -120,6 +128,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings:str_format",
"@llvm//:core",
],
)
@@ -133,9 +142,7 @@ cc_library(
":llvm_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
"@llvm//:core",
@@ -159,6 +166,7 @@ cc_library(
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
],
)
@@ -193,7 +201,10 @@ cc_library(
"//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter",
"//tensorflow/compiler/xla/service/gpu:partition_assignment",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
"@llvm//:core",
+ "@llvm//:support",
],
)
@@ -208,6 +219,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
"@llvm//:core",
],
)
@@ -219,7 +231,7 @@ cc_library(
deps = [
":llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
- "//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
"@llvm//:core",
],
)
@@ -230,6 +242,7 @@ cc_library(
hdrs = ["buffer_assignment_util.h"],
deps = [
"//tensorflow/compiler/xla/service:buffer_assignment",
+ "@com_google_absl//absl/strings",
],
)
@@ -242,3 +255,12 @@ cc_library(
"@llvm//:core",
],
)
+
+cc_library(
+ name = "ir_builder_mixin",
+ srcs = [],
+ hdrs = ["ir_builder_mixin.h"],
+ deps = [
+ "@llvm//:core",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
index fe9eab93aa..8d9fa99d82 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_
+#include "absl/strings/str_cat.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -23,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace llvm_ir {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
index fe5ec1cc66..b6ae4932f5 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
@@ -61,7 +61,7 @@ ENTRY while3 {
; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:[0-9]+]]
;
; CHECK-LABEL: @condition(i8* %retval, i8* noalias %run_options, i8** noalias %params
-; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %temps, i64 0
+; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %buffer_table, i64 0
; CHECK: %[[cond_state_buf_untyped:.*]] = load i8*, i8** %[[cond_state_buf_ptr]]
; CHECK: %[[cond_state_buf_typed:.*]] = bitcast i8* %[[cond_state_buf_untyped]] to float*
; CHECK: load float, float* %[[cond_state_buf_typed]], !alias.scope ![[alias_scope_md_for_store]], !noalias ![[noalias_md_for_load:.*]]
diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc
index 4eb5d9fb47..bdce4a171b 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
+#include "absl/strings/str_cat.h"
namespace xla {
namespace llvm_ir {
@@ -48,7 +49,7 @@ string ConstantBufferAllocationToGlobalName(
c = '_';
}
}
- return tensorflow::strings::StrCat("buffer_for_", instr_name);
+ return absl::StrCat("buffer_for_", instr_name);
}
const Literal& LiteralForConstantAllocation(
diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
index 27fbb11e2e..cc2e862f2e 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
@@ -40,7 +40,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl(
const Shape& update_shape, const ElementGenerator& start_indices_generator,
bool is_signed, ElementGenerator update_array_generator,
const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions,
- tensorflow::StringPiece name, llvm::IRBuilder<>* b) {
+ absl::string_view name, llvm::IRBuilder<>* b) {
const Shape& output_shape = output_array.GetShape();
// Read start indices from start_indices_generator.
@@ -99,10 +99,10 @@ static Status EmitDynamicUpdateSliceInPlaceImpl(
return LoopEmitter(loop_body_emitter, update_shape, b).EmitLoop(name);
}
-Status EmitDynamicUpdateSliceInPlace(
- tensorflow::gtl::ArraySlice<IrArray> operand_arrays,
- const IrArray& output_array, tensorflow::StringPiece name,
- llvm::IRBuilder<>* b) {
+Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays,
+ const IrArray& output_array,
+ absl::string_view name,
+ llvm::IRBuilder<>* b) {
VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name;
// No need to use operand_arrays[0], the input array of the
@@ -130,8 +130,7 @@ Status EmitDynamicUpdateSliceInPlace(
//
// Emits a sequential loop if launch_dimensions is null.
static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
- HloInstruction* fusion,
- tensorflow::gtl::ArraySlice<IrArray> fusion_operand_arrays,
+ HloInstruction* fusion, absl::Span<const IrArray> fusion_operand_arrays,
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) {
CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
@@ -174,8 +173,7 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
}
Status EmitFusedDynamicUpdateSliceInPlace(
- HloInstruction* fusion,
- tensorflow::gtl::ArraySlice<IrArray> fusion_operand_arrays,
+ HloInstruction* fusion, absl::Span<const IrArray> fusion_operand_arrays,
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
llvm::IRBuilder<>* b) {
return EmitFusedDynamicUpdateSliceInPlaceImpl(
@@ -184,8 +182,7 @@ Status EmitFusedDynamicUpdateSliceInPlace(
}
Status EmitParallelFusedDynamicUpdateSliceInPlace(
- HloInstruction* fusion,
- tensorflow::gtl::ArraySlice<IrArray> fusion_operand_arrays,
+ HloInstruction* fusion, absl::Span<const IrArray> fusion_operand_arrays,
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) {
return EmitFusedDynamicUpdateSliceInPlaceImpl(
diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h
index 3502577d23..fb3e4eb97c 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h
@@ -63,26 +63,24 @@ inline bool CanEmitFusedDynamicUpdateSliceInPlace(
// Emits IR for running the given dynamic-update-slice op in-place -- that is,
// where the input and output buffers share the same slice, so we can simply
// modify the input/output buffer without touching any of the other elements.
-Status EmitDynamicUpdateSliceInPlace(
- tensorflow::gtl::ArraySlice<IrArray> operand_arrays,
- const IrArray& output_array, tensorflow::StringPiece name,
- llvm::IRBuilder<>* b);
+Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays,
+ const IrArray& output_array,
+ absl::string_view name,
+ llvm::IRBuilder<>* b);
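A plain-C++ model of the in-place semantics the comment above describes, illustrative only (the real helpers emit LLVM IR over IrArrays):

#include <algorithm>
#include <cstdint>
#include <vector>

// "In place" means operand and result share a buffer, so only the updated
// window of `buffer` is written.
void DynamicUpdateSlice1D(std::vector<float>& buffer,
                          const std::vector<float>& update, int64_t start) {
  // XLA clamps start indices so the update window stays in bounds.
  start = std::max<int64_t>(
      0, std::min<int64_t>(start, buffer.size() - update.size()));
  std::copy(update.begin(), update.end(), buffer.begin() + start);
}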
// Given a loop-fusion node whose root is a dynamic-update-slice op whose
// array-to-be-updated and output share the same buffer slice, emits
// (sequential) code for a fusion node that does the dynamic-update-slice in
// place.
Status EmitFusedDynamicUpdateSliceInPlace(
- HloInstruction* fusion,
- tensorflow::gtl::ArraySlice<IrArray> fusion_operand_arrays,
+ HloInstruction* fusion, absl::Span<const IrArray> fusion_operand_arrays,
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
llvm::IRBuilder<>* b);
// Same as EmitFusedDynamicUpdateSliceInPlace, except emits a parallel loop with
// the given launch dimensions.
Status EmitParallelFusedDynamicUpdateSliceInPlace(
- HloInstruction* fusion,
- tensorflow::gtl::ArraySlice<IrArray> fusion_operand_arrays,
+ HloInstruction* fusion, absl::Span<const IrArray> fusion_operand_arrays,
const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b);
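The ArraySlice-to-Span changes in these files follow one pattern. A minimal, self-contained sketch of why the migration is call-site compatible (illustrative names only):

#include <vector>
#include "absl/types/span.h"

// absl::Span is a non-owning view over contiguous storage, like the old
// tensorflow::gtl::ArraySlice.
int Sum(absl::Span<const int> values) {
  int total = 0;
  for (int v : values) total += v;
  return total;
}

// Vectors, arrays, and braced lists all convert implicitly at call sites:
//   std::vector<int> v = {1, 2, 3};   Sum(v) == 6;   Sum({4, 5}) == 9.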
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
index 72ede377e1..b606c993a2 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
@@ -98,7 +98,7 @@ Status FusedIrEmitter::HandleGetTupleElement(
return Unimplemented(
"GetTupleElement fusion currently only supports"
" parameter operands, but found operand: %s",
- operand->name().c_str());
+ operand->name());
}
// Emit code to lookup tuple element pointer, and store it in 'gte_values_'.
llvm::Value* tuple_element_ptr = llvm_ir::EmitGetTupleElement(
@@ -147,7 +147,7 @@ Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) {
}
Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) {
- tensorflow::gtl::ArraySlice<HloInstruction*> operands(tuple->operands());
+ absl::Span<HloInstruction* const> operands(tuple->operands());
std::vector<llvm::Type*> operand_elemental_ir_types;
for (HloInstruction* operand : operands) {
operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType(
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
index 30471480c4..44d21fa750 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <map>
#include <unordered_map>
+#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@@ -29,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
@@ -54,7 +54,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault {
public:
using Generator = llvm_ir::ElementGenerator;
- FusedIrEmitter(tensorflow::gtl::ArraySlice<llvm_ir::IrArray> parameter_arrays,
+ FusedIrEmitter(absl::Span<const llvm_ir::IrArray> parameter_arrays,
ElementalIrEmitter* elemental_emitter)
: parameter_arrays_(parameter_arrays),
tiled_parameter_info_(nullptr),
@@ -94,7 +94,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault {
private:
// Arrays of parameters of fusion instruction
- tensorflow::gtl::ArraySlice<llvm_ir::IrArray> parameter_arrays_;
+ absl::Span<const llvm_ir::IrArray> parameter_arrays_;
const llvm_ir::TiledParameterInfo* tiled_parameter_info_;
ElementalIrEmitter* elemental_emitter_;
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
index 2b6caee6aa..67f7423121 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
@@ -73,7 +73,7 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
Delinearize(&multidim_, linear, shape, b);
}
-IrArray::Index::Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
+IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
llvm::Value* linear, const Shape& shape)
: multidim_(multidim.begin(), multidim.end()),
linear_(linear),
@@ -92,7 +92,7 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
<< " should have a layout.";
}
-IrArray::Index::Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
+IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
const Shape& shape, llvm::IRBuilder<>* b)
: multidim_(multidim.begin(), multidim.end()),
layout_(shape.layout()),
@@ -147,16 +147,15 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape(
// indices in the same common factor.
for (ssize_t k = common_factors.size() - 2; k >= 0; --k) {
llvm::Value* logical_linear_index =
- Index(tensorflow::gtl::ArraySlice<llvm::Value*>(
- multidim_, common_factors[k].second,
+ Index(absl::Span<llvm::Value* const>(multidim_).subspan(
+ common_factors[k].second,
common_factors[k + 1].second - common_factors[k].second),
index_type_)
- .Linearize(
- tensorflow::gtl::ArraySlice<int64>(
- AsInt64Slice(output_shape.dimensions()),
- common_factors[k].second,
- common_factors[k + 1].second - common_factors[k].second),
- builder);
+ .Linearize(AsInt64Slice(output_shape.dimensions())
+ .subspan(common_factors[k].second,
+ common_factors[k + 1].second -
+ common_factors[k].second),
+ builder);
// Delinearizes logical_linear_index for the source array in row-major
// collapsed order. The first rank-1 indices are the remainder of the
// linear index by each dimension size.
@@ -185,9 +184,8 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape(
}
IrArray::Index IrArray::Index::SourceIndexOfSlice(
- const Shape& shape, tensorflow::gtl::ArraySlice<int64> starts,
- tensorflow::gtl::ArraySlice<int64> strides,
- llvm::IRBuilder<>* builder) const {
+ const Shape& shape, absl::Span<const int64> starts,
+ absl::Span<const int64> strides, llvm::IRBuilder<>* builder) const {
Index source_index(index_type_, multidim_.size());
for (int i = 0; i < multidim_.size(); ++i) {
int64 stride = strides[i];
@@ -208,7 +206,7 @@ IrArray::Index IrArray::Index::SourceIndexOfSlice(
IrArray::Index IrArray::Index::SourceIndexOfTranspose(
const Shape& shape, const Shape& operand_shape,
- tensorflow::gtl::ArraySlice<int64> dimension_mapping,
+ absl::Span<const int64> dimension_mapping,
llvm::IRBuilder<>* builder) const {
std::vector<llvm::Value*> operand_multidim_index =
Permute(dimension_mapping, multidim());
@@ -257,7 +255,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBitcast(
IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
const Shape& shape, const Shape& operand_shape,
- tensorflow::gtl::ArraySlice<int64> dimension_mapping,
+ absl::Span<const int64> dimension_mapping,
llvm::IRBuilder<>* builder) const {
int64 rank = ShapeUtil::Rank(operand_shape);
std::vector<llvm::Value*> source_index(rank);
@@ -322,9 +320,8 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
return Index(source_index, linear, operand_shape);
}
-llvm::Value* IrArray::Index::Linearize(
- tensorflow::gtl::ArraySlice<int64> dimensions,
- llvm::IRBuilder<>* builder) const {
+llvm::Value* IrArray::Index::Linearize(absl::Span<const int64> dimensions,
+ llvm::IRBuilder<>* builder) const {
// Each dimension is multiplied by the product of the sizes of all
// earlier dimensions and added to the accumulator logical_linear_index.
CHECK_EQ(size(), dimensions.size());
@@ -342,9 +339,9 @@ llvm::Value* IrArray::Index::Linearize(
return logical_linear_index;
}
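The comment above Linearize describes a Horner-style accumulation. A plain-C++ analogue over concrete integers (the real method emits LLVM IR, so this only models the arithmetic):

#include <cstdint>
#include <vector>

int64_t LinearizeRowMajor(const std::vector<int64_t>& index,
                          const std::vector<int64_t>& dims) {
  int64_t linear = 0;
  for (size_t i = 0; i < dims.size(); ++i) {
    // Scaling before adding weights each coordinate by the product of all
    // more-minor dimension sizes, matching the accumulation above.
    linear = linear * dims[i] + index[i];
  }
  return linear;
}
// LinearizeRowMajor({1, 2}, {3, 4}) == 1 * 4 + 2 == 6.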
-llvm::Value* IrArray::EmitArrayElementAddress(
- const IrArray::Index& index, llvm::IRBuilder<>* b,
- tensorflow::StringPiece name) const {
+llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
+ llvm::IRBuilder<>* b,
+ absl::string_view name) const {
if (ShapeUtil::IsScalar(*shape_)) {
// Special handling of scalars: a scalar pretends to have the same value for
// every index, thus effectively implementing broadcasting of its value
@@ -402,7 +399,7 @@ void IrArray::AnnotateLoadStoreInstructionWithMetadata(
llvm::Value* IrArray::EmitReadArrayElement(const Index& index,
llvm::IRBuilder<>* b,
- tensorflow::StringPiece name) const {
+ absl::string_view name) const {
llvm::Value* element_address = EmitArrayElementAddress(index, b, name);
llvm::LoadInst* load = b->CreateLoad(element_address);
AnnotateLoadStoreInstructionWithMetadata(load);
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
index 28ca793e3e..f4b05f29c3 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
@@ -19,13 +19,14 @@ limitations under the License.
#include <map>
#include <vector>
+#include "absl/algorithm/container.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -69,7 +70,7 @@ class IrArray {
// Constructs an index from multi-dimensional index "multidim". The linear
// index is set to nullptr.
- explicit Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
+ explicit Index(absl::Span<llvm::Value* const> multidim,
llvm::Type* index_ty = nullptr)
: multidim_(multidim.begin(), multidim.end()) {
if (size() == 0) {
@@ -81,7 +82,7 @@ class IrArray {
}
}
CHECK_NE(index_type_, nullptr);
- CHECK(c_all_of(multidim, [&](llvm::Value* v) {
+ CHECK(absl::c_all_of(multidim, [&](llvm::Value* v) {
return index_type_ == v->getType();
}));
}
@@ -98,14 +99,14 @@ class IrArray {
// that it indexes into.
//
// Precondition: "shape" has a layout.
- Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
- const Shape& shape, llvm::IRBuilder<>* b);
+ Index(absl::Span<llvm::Value* const> multidim, const Shape& shape,
+ llvm::IRBuilder<>* b);
// Constructs an index from both a multi-dimensional index and a linear
// index. "shape" has the same meaning as that in the constructor that takes
// only a linear index.
- Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
- llvm::Value* linear, const Shape& shape);
+ Index(absl::Span<llvm::Value* const> multidim, llvm::Value* linear,
+ const Shape& shape);
const std::vector<llvm::Value*>& multidim() const { return multidim_; }
llvm::Value* linear() const { return linear_; }
@@ -144,17 +145,15 @@ class IrArray {
// by starting indices `starts` and stride values `strides`.
//
// Precondition: "this" is an index into a slice whose shape is `shape`.
- Index SourceIndexOfSlice(const Shape& shape,
- tensorflow::gtl::ArraySlice<int64> starts,
- tensorflow::gtl::ArraySlice<int64> strides,
+ Index SourceIndexOfSlice(const Shape& shape, absl::Span<const int64> starts,
+ absl::Span<const int64> strides,
llvm::IRBuilder<>* builder) const;
// Given that "this" is the target index of a transpose from `operand_shape`
// to `shape` with the given dimension mapping, returns the source index.
- Index SourceIndexOfTranspose(
- const Shape& shape, const Shape& operand_shape,
- tensorflow::gtl::ArraySlice<int64> dimension_mapping,
- llvm::IRBuilder<>* builder) const;
+ Index SourceIndexOfTranspose(const Shape& shape, const Shape& operand_shape,
+ absl::Span<const int64> dimension_mapping,
+ llvm::IRBuilder<>* builder) const;
// Given that "this" is the target index of a bitcast from `operand_shape`
// to `shape`, returns the source index.
@@ -163,14 +162,13 @@ class IrArray {
// Given that "this" is the target index of a broadcast from `operand_shape`
// to `shape` with the given dimension mapping, returns the source index.
- Index SourceIndexOfBroadcast(
- const Shape& shape, const Shape& operand_shape,
- tensorflow::gtl::ArraySlice<int64> dimension_mapping,
- llvm::IRBuilder<>* builder) const;
+ Index SourceIndexOfBroadcast(const Shape& shape, const Shape& operand_shape,
+ absl::Span<const int64> dimension_mapping,
+ llvm::IRBuilder<>* builder) const;
// Linearizes the index into the given shape, i.e. reshapes it to rank-1 and
// returns the index into the sole dimension 0 of the new shape.
- llvm::Value* Linearize(tensorflow::gtl::ArraySlice<int64> dimensions,
+ llvm::Value* Linearize(absl::Span<const int64> dimensions,
llvm::IRBuilder<>* builder) const;
llvm::Type* GetType() const { return index_type_; }
@@ -240,7 +238,7 @@ class IrArray {
// The optional name is useful for debugging when looking at
// the emitted LLVM IR.
llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b,
- tensorflow::StringPiece name = "") const;
+ absl::string_view name = "") const;
// Attach metadata this IrArray instance knows about to "instruction".
void AnnotateLoadStoreInstructionWithMetadata(
@@ -254,7 +252,7 @@ class IrArray {
// The optional name is useful for debugging when looking at
// the emitted LLVM IR.
llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b,
- tensorflow::StringPiece name = "") const;
+ absl::string_view name = "") const;
// Emit IR to write the given value to the array element at the given index.
void EmitWriteArrayElement(const Index& index, llvm::Value* value,
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h
new file mode 100644
index 0000000000..abc06fb7b4
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h
@@ -0,0 +1,400 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_
+
+#include "llvm/IR/IRBuilder.h"
+
+namespace xla {
+
+// Mixin class that injects more ergonomic versions of llvm::IRBuilder methods
+// into a class. Intended to be used as a CRTP base class, like:
+//
+// class MyIrEmitter : public IrBuilderMixin<MyIrEmitter> {
+// llvm::IRBuilder<>* builder() { return builder_; }
+//
+// void EmitFoo(HloInstruction* foo) {
+// Add(Mul(...), FPToUI(...));
+// }
+// };
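+//
+// The method bodies below call mixin_builder(); in this CRTP pattern it is
+// expected to resolve to the derived class's builder, e.g. (a sketch, not a
+// line from this patch):
+//
+//   llvm::IRBuilder<>* mixin_builder() {
+//     return static_cast<Derived*>(this)->builder();
+//   }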
+
+template <typename Derived>
+class IrBuilderMixin {
+ protected:
+ template <class... Args>
+ llvm::Value* Add(Args&&... args) {
+ return mixin_builder()->CreateAdd(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::LoadInst* AlignedLoad(Args&&... args) {
+ return mixin_builder()->CreateAlignedLoad(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::StoreInst* AlignedStore(Args&&... args) {
+ return mixin_builder()->CreateAlignedStore(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::AllocaInst* Alloca(Args&&... args) {
+ return mixin_builder()->CreateAlloca(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* And(Args&&... args) {
+ return mixin_builder()->CreateAnd(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* AtomicCmpXchg(Args&&... args) {
+ return mixin_builder()->CreateAtomicCmpXchg(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* AtomicRMW(Args&&... args) {
+ return mixin_builder()->CreateAtomicRMW(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* BitCast(Args&&... args) {
+ return mixin_builder()->CreateBitCast(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Br(Args&&... args) {
+ return mixin_builder()->CreateBr(std::forward<Args>(args)...);
+ }
+
+ llvm::CallInst* Call(llvm::Value* callee,
+ llvm::ArrayRef<llvm::Value*> args = llvm::None,
+ const llvm::Twine& name = "",
+ llvm::MDNode* fp_math_tag = nullptr) {
+ return mixin_builder()->CreateCall(callee, args, name, fp_math_tag);
+ }
+
+ template <class... Args>
+ llvm::BranchInst* CondBr(Args&&... args) {
+ return mixin_builder()->CreateCondBr(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* ConstInBoundsGEP1_32(Args&&... args) {
+ return mixin_builder()->CreateConstInBoundsGEP1_32(
+ std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FAdd(Args&&... args) {
+ return mixin_builder()->CreateFAdd(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FMul(Args&&... args) {
+ return mixin_builder()->CreateFMul(std::forward<Args>(args)...);
+ }
+
+ llvm::Value* GEP(llvm::Value* ptr, llvm::ArrayRef<llvm::Value*> idx_list,
+ const llvm::Twine& name = "") {
+ return mixin_builder()->CreateGEP(ptr, idx_list, name);
+ }
+
+ template <class... Args>
+ llvm::Value* ICmpEQ(Args&&... args) {
+ return mixin_builder()->CreateICmpEQ(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* ICmpNE(Args&&... args) {
+ return mixin_builder()->CreateICmpNE(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* ICmpULE(Args&&... args) {
+ return mixin_builder()->CreateICmpULE(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* ICmpULT(Args&&... args) {
+ return mixin_builder()->CreateICmpULT(std::forward<Args>(args)...);
+ }
+
+ llvm::Value* InBoundsGEP(llvm::Value* ptr,
+ llvm::ArrayRef<llvm::Value*> idx_list,
+ const llvm::Twine& name = "") {
+ return mixin_builder()->CreateInBoundsGEP(ptr, idx_list, name);
+ }
+
+ llvm::Value* ExtractValue(llvm::Value* agg, llvm::ArrayRef<unsigned> idxs,
+ const llvm::Twine& name = "") {
+ return mixin_builder()->CreateExtractValue(agg, idxs, name);
+ }
+
+ llvm::Value* InsertValue(llvm::Value* agg, llvm::Value* val,
+ llvm::ArrayRef<unsigned> idxs,
+ const llvm::Twine& name = "") {
+ return mixin_builder()->CreateInsertValue(agg, val, idxs, name);
+ }
+
+ template <class... Args>
+ llvm::Value* IntToPtr(Args&&... args) {
+ return mixin_builder()->CreateIntToPtr(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::LoadInst* Load(Args&&... args) {
+ return mixin_builder()->CreateLoad(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::CallInst* MemCpy(Args&&... args) {
+ return mixin_builder()->CreateMemCpy(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Mul(Args&&... args) {
+ return mixin_builder()->CreateMul(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* NSWAdd(Args&&... args) {
+ return mixin_builder()->CreateNSWAdd(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* NSWMul(Args&&... args) {
+ return mixin_builder()->CreateNSWMul(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* NSWSub(Args&&... args) {
+ return mixin_builder()->CreateNSWSub(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Or(Args&&... args) {
+ return mixin_builder()->CreateOr(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* PointerCast(Args&&... args) {
+ return mixin_builder()->CreatePointerCast(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* PtrToInt(Args&&... args) {
+ return mixin_builder()->CreatePtrToInt(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* SDiv(Args&&... args) {
+ return mixin_builder()->CreateSDiv(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Select(Args&&... args) {
+ return mixin_builder()->CreateSelect(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* SRem(Args&&... args) {
+ return mixin_builder()->CreateSRem(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::StoreInst* Store(Args&&... args) {
+ return mixin_builder()->CreateStore(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* UDiv(Args&&... args) {
+ return mixin_builder()->CreateUDiv(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* URem(Args&&... args) {
+ return mixin_builder()->CreateURem(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* VectorSplat(Args&&... args) {
+ return mixin_builder()->CreateVectorSplat(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* ZExtOrTrunc(Args&&... args) {
+ return mixin_builder()->CreateZExtOrTrunc(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* AShr(Args&&... args) {
+ return mixin_builder()->CreateAShr(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FCmpOEQ(Args&&... args) {
+ return mixin_builder()->CreateFCmpOEQ(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FCmpOLT(Args&&... args) {
+ return mixin_builder()->CreateFCmpOLT(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FCmpONE(Args&&... args) {
+ return mixin_builder()->CreateFCmpONE(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FCmpUNE(Args&&... args) {
+ return mixin_builder()->CreateFCmpUNE(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FDiv(Args&&... args) {
+ return mixin_builder()->CreateFDiv(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FNeg(Args&&... args) {
+ return mixin_builder()->CreateFNeg(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FPCast(Args&&... args) {
+ return mixin_builder()->CreateFPCast(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FPToSI(Args&&... args) {
+ return mixin_builder()->CreateFPToSI(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FPToUI(Args&&... args) {
+ return mixin_builder()->CreateFPToUI(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FPTrunc(Args&&... args) {
+ return mixin_builder()->CreateFPTrunc(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FRem(Args&&... args) {
+ return mixin_builder()->CreateFRem(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* FSub(Args&&... args) {
+ return mixin_builder()->CreateFSub(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* ICmpSGE(Args&&... args) {
+ return mixin_builder()->CreateICmpSGE(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* ICmpSLT(Args&&... args) {
+ return mixin_builder()->CreateICmpSLT(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* IntCast(Args&&... args) {
+ return mixin_builder()->CreateIntCast(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* LShr(Args&&... args) {
+ return mixin_builder()->CreateLShr(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* MemSet(Args&&... args) {
+ return mixin_builder()->CreateMemSet(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Neg(Args&&... args) {
+ return mixin_builder()->CreateNeg(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Not(Args&&... args) {
+ return mixin_builder()->CreateNot(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::PHINode* PHI(Args&&... args) {
+ return mixin_builder()->CreatePHI(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* RetVoid(Args&&... args) {
+ return mixin_builder()->CreateRetVoid(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* SExtOrTrunc(Args&&... args) {
+ return mixin_builder()->CreateSExtOrTrunc(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Shl(Args&&... args) {
+ return mixin_builder()->CreateShl(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* SIToFP(Args&&... args) {
+ return mixin_builder()->CreateSIToFP(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Sub(Args&&... args) {
+ return mixin_builder()->CreateSub(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Trunc(Args&&... args) {
+ return mixin_builder()->CreateTrunc(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* UIToFP(Args&&... args) {
+ return mixin_builder()->CreateUIToFP(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Unreachable(Args&&... args) {
+ return mixin_builder()->CreateUnreachable(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
+ llvm::Value* Xor(Args&&... args) {
+ return mixin_builder()->CreateXor(std::forward<Args>(args)...);
+ }
+
+ private:
+ llvm::IRBuilder<>* mixin_builder() {
+ return static_cast<Derived*>(this)->builder();
+ }
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_
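
To make the ergonomics concrete, here is a minimal sketch of a class adopting the mixin (hypothetical names, not part of this change). The one contract is that builder() must be reachable from the CRTP base, hence it is public:

class MyEmitter : public IrBuilderMixin<MyEmitter> {
 public:
  explicit MyEmitter(llvm::IRBuilder<>* b) : b_(b) {}

  // Required by IrBuilderMixin: the base reaches the wrapped builder here.
  llvm::IRBuilder<>* builder() { return b_; }

  llvm::Value* EmitAxpy(llvm::Value* a, llvm::Value* x, llvm::Value* y) {
    // Reads as FAdd(FMul(a, x), y) rather than
    // b_->CreateFAdd(b_->CreateFMul(a, x), y).
    return FAdd(FMul(a, x), y);
  }

 private:
  llvm::IRBuilder<>* b_;
};
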
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
index b79567369a..bd0139f85b 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
@@ -19,7 +19,7 @@ limitations under the License.
namespace xla {
Status KernelSupportLibrary::For(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ absl::string_view name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
const std::function<Status(llvm::Value*, bool)>& for_body_generator) {
return If(b_->CreateICmpSLT(start, end), [&]() -> Status {
@@ -30,7 +30,7 @@ Status KernelSupportLibrary::For(
}
Status KernelSupportLibrary::For(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ absl::string_view name, llvm::Value* start, llvm::Value* end,
llvm::Value* step, bool peel_first_iteration,
const std::function<Status(llvm::Value*, llvm::Value*)>&
for_body_generator) {
@@ -56,7 +56,7 @@ Status KernelSupportLibrary::For(
}
Status KernelSupportLibrary::If(
- tensorflow::StringPiece name, llvm::Value* condition,
+ absl::string_view name, llvm::Value* condition,
const std::function<Status()>& true_block_generator,
const std::function<Status()>& false_block_generator) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(condition, name, b_);
@@ -70,7 +70,7 @@ Status KernelSupportLibrary::If(
void KernelSupportLibrary::EmitAndCallOutlinedKernel(
bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b,
- tensorflow::StringPiece kernel_name,
+ absl::string_view kernel_name,
KernelSupportLibrary::ArgumentVector arguments,
const std::function<void(KernelSupportLibrary::ArgumentVector)>&
kernel_body_generator) {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
index b00f903d56..43fec311f1 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
@@ -13,17 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_
#include <string>
+#include "absl/strings/string_view.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace xla {
// A thin wrapper around llvm_loop.h to make code generating structured control
@@ -49,13 +49,13 @@ class KernelSupportLibrary {
//     `for_body_generator(/*ind_var=*/i, /*is_first_iteration=*/false)`;
// }
Status For(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ absl::string_view name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
const std::function<Status(llvm::Value* ind_var,
bool is_first_iteration)>& for_body_generator);
void ForReturnVoid(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ absl::string_view name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
for_body_generator) {
@@ -67,7 +67,7 @@ class KernelSupportLibrary {
}));
}
- Status For(tensorflow::StringPiece name, int64 start, int64 end, int64 step,
+ Status For(absl::string_view name, int64 start, int64 end, int64 step,
const std::function<Status(llvm::Value* ind_var,
bool is_first_iteration)>&
for_body_generator) {
@@ -77,7 +77,7 @@ class KernelSupportLibrary {
}
void ForReturnVoid(
- tensorflow::StringPiece name, int64 start, int64 end, int64 step,
+ absl::string_view name, int64 start, int64 end, int64 step,
const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
for_body_generator) {
ForReturnVoid(name, /*start=*/b_->getInt64(start),
@@ -99,13 +99,13 @@ class KernelSupportLibrary {
// for (i64 i = `start`; i s< `end`; i += `step`)
//     `for_body_generator(/*ind_var=*/i,
//                          /*is_first_iteration=*/(i != `start`))`;
- Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ Status For(absl::string_view name, llvm::Value* start, llvm::Value* end,
llvm::Value* step, bool peel_first_iteration,
const std::function<Status(llvm::Value* ind_var,
llvm::Value* is_first_iteration)>&
for_body_generator);
- void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start,
+ void ForReturnVoid(absl::string_view name, llvm::Value* start,
llvm::Value* end, llvm::Value* step,
bool peel_first_iteration,
const std::function<void(llvm::Value* ind_var,
@@ -119,7 +119,7 @@ class KernelSupportLibrary {
}));
}
- Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ Status For(absl::string_view name, llvm::Value* start, llvm::Value* end,
int64 step, bool peel_first_iteration,
const std::function<Status(llvm::Value* ind_var,
llvm::Value* is_first_iteration)>&
@@ -129,7 +129,7 @@ class KernelSupportLibrary {
peel_first_iteration, for_body_generator);
}
- void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start,
+ void ForReturnVoid(absl::string_view name, llvm::Value* start,
llvm::Value* end, int64 step, bool peel_first_iteration,
const std::function<void(llvm::Value* ind_var,
llvm::Value* is_first_iteration)>&
@@ -140,7 +140,7 @@ class KernelSupportLibrary {
}
Status For(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ absl::string_view name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
return For(name, start, end, step,
@@ -151,7 +151,7 @@ class KernelSupportLibrary {
}
void ForReturnVoid(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ absl::string_view name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
ForReturnVoid(name, start, end, step,
@@ -162,8 +162,7 @@ class KernelSupportLibrary {
}
Status For(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
- int64 step,
+ absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step,
const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
return For(name, start, end, llvm::ConstantInt::get(start->getType(), step),
/*peel_first_iteration=*/false,
@@ -173,8 +172,7 @@ class KernelSupportLibrary {
}
void ForReturnVoid(
- tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
- int64 step,
+ absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step,
const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
ForReturnVoid(name, start, end,
llvm::ConstantInt::get(start->getType(), step),
@@ -182,7 +180,7 @@ class KernelSupportLibrary {
}
Status For(
- tensorflow::StringPiece name, int64 start, int64 end, int64 step,
+ absl::string_view name, int64 start, int64 end, int64 step,
const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
return For(name, /*start=*/b_->getInt64(start),
/*end=*/b_->getInt64(end),
@@ -190,7 +188,7 @@ class KernelSupportLibrary {
}
void ForReturnVoid(
- tensorflow::StringPiece name, int64 start, int64 end, int64 step,
+ absl::string_view name, int64 start, int64 end, int64 step,
const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
ForReturnVoid(name, /*start=*/b_->getInt64(start),
/*end=*/b_->getInt64(end),
@@ -203,7 +201,7 @@ class KernelSupportLibrary {
// `true_block_generator()`;
// else
// `false_block_generator()`;
- Status If(tensorflow::StringPiece name, llvm::Value* condition,
+ Status If(absl::string_view name, llvm::Value* condition,
const std::function<Status()>& true_block_generator,
const std::function<Status()>& false_block_generator =
[]() -> Status { return Status::OK(); });
@@ -222,7 +220,7 @@ class KernelSupportLibrary {
IfReturnVoid("", condition, true_block_generator, false_block_generator);
}
- void IfReturnVoid(tensorflow::StringPiece name, llvm::Value* condition,
+ void IfReturnVoid(absl::string_view name, llvm::Value* condition,
const std::function<void()>& true_block_generator,
const std::function<void()>& false_block_generator = []() {
}) {
@@ -237,7 +235,7 @@ class KernelSupportLibrary {
}));
}
- using ArgumentVector = tensorflow::gtl::ArraySlice<llvm::Value*>;
+ using ArgumentVector = absl::Span<llvm::Value* const>;
// Generates the following control flow structure:
//
@@ -259,13 +257,13 @@ class KernelSupportLibrary {
// Currently we only support at most one nullptr value in `arguments`.
static void EmitAndCallOutlinedKernel(
bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b,
- tensorflow::StringPiece kernel_name, ArgumentVector arguments,
+ absl::string_view kernel_name, ArgumentVector arguments,
const std::function<void(ArgumentVector)>& kernel_body_generator);
// Thin wrappers around the more general EmitAndCallOutlinedKernel above.
static void EmitAndCallOutlinedKernel(
bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b,
- tensorflow::StringPiece kernel_name, llvm::Value* arg0, llvm::Value* arg1,
+ absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1,
llvm::Value* arg2,
const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*)>&
kernel_body_generator) {
@@ -278,7 +276,7 @@ class KernelSupportLibrary {
static void EmitAndCallOutlinedKernel(
bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b,
- tensorflow::StringPiece kernel_name, llvm::Value* arg0, llvm::Value* arg1,
+ absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1,
llvm::Value* arg2, llvm::Value* arg3,
const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*,
llvm::Value*)>& kernel_body_generator) {
@@ -296,4 +294,4 @@ class KernelSupportLibrary {
};
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_
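
A sketch of a typical call site after the rename (hypothetical emitter code, assuming the usual constructor taking the builder); string literals bind to absl::string_view with no change for callers:

KernelSupportLibrary ksl(b);
TF_RETURN_IF_ERROR(ksl.For(
    "init_loop", /*start=*/0, /*end=*/128, /*step=*/1,
    [&](llvm::Value* ind_var) -> Status {
      // ... emit one iteration of the body here, indexed by ind_var ...
      return Status::OK();
    }));
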
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc
index 35b3941272..e5fbdbd51b 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc
@@ -28,7 +28,7 @@ namespace {
// Returns the indices of the first elements of all consecutive subarrays of the
// given array. For example:
// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
-std::vector<size_t> ConsecutiveSegments(tensorflow::gtl::ArraySlice<int64> xs) {
+std::vector<size_t> ConsecutiveSegments(absl::Span<const int64> xs) {
std::vector<size_t> is = {0};
for (size_t i = 1; i < xs.size(); ++i) {
if (1 != xs[i] - xs[i - 1]) {
@@ -40,8 +40,7 @@ std::vector<size_t> ConsecutiveSegments(tensorflow::gtl::ArraySlice<int64> xs) {
// Merges the sequences of dimensions of the given shape which start at the
// given indices `segs`.
-Shape MergeDimensions(tensorflow::gtl::ArraySlice<size_t> segs,
- const Shape& shape) {
+Shape MergeDimensions(absl::Span<const size_t> segs, const Shape& shape) {
std::vector<int64> dimensions;
for (size_t i = 1; i <= segs.size(); ++i) {
dimensions.push_back(std::accumulate(
@@ -55,10 +54,10 @@ Shape MergeDimensions(tensorflow::gtl::ArraySlice<size_t> segs,
}
} // namespace
-tensorflow::gtl::optional<std::vector<int64> > FindTranspose021(
- const Shape& a, const Shape& b) {
+absl::optional<std::vector<int64> > FindTranspose021(const Shape& a,
+ const Shape& b) {
if (!ShapeUtil::CompatibleIgnoringElementType(a, b)) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
std::vector<int64> perm(a.dimensions().size());
@@ -88,7 +87,7 @@ tensorflow::gtl::optional<std::vector<int64> > FindTranspose021(
return dims_021;
}
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
IrArray::Index GetUnreducedOutputIndex(
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
index ccb9b8ba3e..5ea05b3188 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
@@ -36,8 +36,8 @@ namespace llvm_ir {
// If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the
// reduced shape of `b` or the 0-2-1 shape.
-tensorflow::gtl::optional<std::vector<int64> > FindTranspose021(const Shape& a,
- const Shape& b);
+absl::optional<std::vector<int64> > FindTranspose021(const Shape& a,
+ const Shape& b);
// Return the unreduced output index corresponding to the given reduced output
// index.
@@ -50,7 +50,7 @@ IrArray::Index GetUnreducedOutputIndex(
// for 021 transpose.
class TiledParameterInfo {
public:
- TiledParameterInfo(tensorflow::gtl::ArraySlice<llvm::Value*> param_buffers,
+ TiledParameterInfo(absl::Span<llvm::Value* const> param_buffers,
llvm::Value* y, llvm::Value* x)
: param_buffers_(param_buffers), y_(y), x_(x) {}
@@ -67,7 +67,7 @@ class TiledParameterInfo {
private:
// Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr
// if the parameter is not tiled.
- tensorflow::gtl::ArraySlice<llvm::Value*> param_buffers_;
+ absl::Span<llvm::Value* const> param_buffers_;
// The y coordinate within a tile.
llvm::Value* y_;
// The x coordinate within a tile.
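
For orientation, a sketch of how FindTranspose021 is queried (hypothetical shapes, assuming the usual ShapeUtil factory):

// b is a 0-2-1 transpose of a: dimension 0 kept, dimensions 1 and 2 swapped.
Shape a = ShapeUtil::MakeShape(F32, {8, 32, 16});
Shape b = ShapeUtil::MakeShape(F32, {8, 16, 32});
if (absl::optional<std::vector<int64>> dims = llvm_ir::FindTranspose021(a, b)) {
  // *dims holds the merged dimension sizes of the normalized 0-2-1 view,
  // which the tiled transpose emitter consumes.
}
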
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
index ba7f94834c..219a9f221f 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
@@ -25,19 +26,17 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
namespace llvm_ir {
-ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
+ForLoop::ForLoop(absl::string_view prefix, absl::string_view suffix,
llvm::Value* start_index, llvm::Value* end_index,
llvm::Value* step, UnrollMode unroll_mode,
bool prevent_vectorization)
- : prefix_(std::string(prefix)),
- suffix_(std::string(suffix)),
+ : prefix_(prefix),
+ suffix_(suffix),
start_index_(start_index),
end_index_(end_index),
step_(step),
@@ -46,9 +45,9 @@ ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
prevent_vectorization_(prevent_vectorization) {}
/* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop(
- tensorflow::StringPiece prefix, llvm::Value* start_index,
- llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* b,
- UnrollMode unroll_mode, bool prevent_vectorization) {
+ absl::string_view prefix, llvm::Value* start_index, llvm::Value* end_index,
+ llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode,
+ bool prevent_vectorization) {
std::unique_ptr<ForLoop> loop(new ForLoop(prefix, /*suffix=*/"", start_index,
end_index, step, unroll_mode,
prevent_vectorization));
@@ -168,16 +167,16 @@ std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata(llvm::IRBuilder<>* b) {
return result;
}
-string ForLoop::GetQualifiedName(tensorflow::StringPiece name) {
+string ForLoop::GetQualifiedName(absl::string_view name) {
return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_));
}
-llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name,
+llvm::BasicBlock* ForLoop::CreateLoopBB(absl::string_view name,
llvm::IRBuilder<>* b) {
return CreateBasicBlock(insert_before_bb_, GetQualifiedName(name), b);
}
-std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
+std::unique_ptr<ForLoop> ForLoopNest::AddLoop(absl::string_view suffix,
llvm::Value* start_index,
llvm::Value* end_index,
UnrollMode unroll_mode,
@@ -186,12 +185,9 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
unroll_mode, prevent_vectorization);
}
-std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
- llvm::Value* start_index,
- llvm::Value* end_index,
- llvm::Value* stride,
- UnrollMode unroll_mode,
- bool prevent_vectorization) {
+std::unique_ptr<ForLoop> ForLoopNest::AddLoop(
+ absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index,
+ llvm::Value* stride, UnrollMode unroll_mode, bool prevent_vectorization) {
if (inner_loop_body_bb_ != nullptr) {
// Create this loop inside the previous one.
b_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt());
@@ -216,7 +212,7 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
int64 end_index,
- tensorflow::StringPiece suffix,
+ absl::string_view suffix,
UnrollMode unroll_mode,
bool prevent_vectorization) {
CHECK_LE(start_index, end_index);
@@ -227,7 +223,7 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
int64 end_index, int64 stride,
- tensorflow::StringPiece suffix,
+ absl::string_view suffix,
UnrollMode unroll_mode,
bool prevent_vectorization) {
CHECK_LE(start_index, end_index);
@@ -238,22 +234,22 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
}
IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,
- tensorflow::StringPiece suffix) {
+ absl::string_view suffix) {
std::vector<int64> dimensions(ShapeUtil::Rank(shape));
std::iota(dimensions.begin(), dimensions.end(), 0);
return AddLoopsForShapeOnDimensions(shape, dimensions, suffix);
}
IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions(
- const Shape& shape, tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::StringPiece suffix) {
+ const Shape& shape, absl::Span<const int64> dimensions,
+ absl::string_view suffix) {
llvm_ir::IrArray::Index index(index_type_, shape.dimensions_size());
for (int64 dimension : dimensions) {
std::unique_ptr<llvm_ir::ForLoop> loop = AddLoop(
/*start_index=*/0,
/*end_index=*/shape.dimensions(dimension),
/*suffix=*/
- llvm_ir::IrName(suffix, tensorflow::strings::StrCat(dimension)));
+ llvm_ir::IrName(suffix, absl::StrCat(dimension)));
index[dimension] = loop->GetIndVarValue();
}
return index;
@@ -261,7 +257,7 @@ IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions(
IrArray::Index ForLoopNest::EmitOperandArrayLoopNest(
const llvm_ir::IrArray& operand_array, int64 dimension_to_skip,
- tensorflow::StringPiece name_suffix) {
+ absl::string_view name_suffix) {
// Prepares the dimension list we will use to emit the loop nest. Outermost
// loops are added first. Add loops in major-to-minor order, and skip the
// 'dimension_to_skip' dimension.
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
index a4fed5c8dc..ac3bba3c9f 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
@@ -19,15 +19,15 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -78,7 +78,7 @@ class ForLoop {
// `unroll_mode` specifies the desired LLVM unrolling behavior for generated
// loop.
static std::unique_ptr<ForLoop> EmitForLoop(
- tensorflow::StringPiece prefix, llvm::Value* start_index,
+ absl::string_view prefix, llvm::Value* start_index,
llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* b,
UnrollMode unroll_mode = llvm_ir::UnrollMode::kDefaultUnroll,
bool prevent_vectorization = false);
@@ -133,19 +133,18 @@ class ForLoop {
// Allow ForLoopNest to call this private constructor.
friend class ForLoopNest;
- ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
+ ForLoop(absl::string_view prefix, absl::string_view suffix,
llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step,
UnrollMode unroll_mode, bool prevent_vectorization);
// Emit the loop at the insert point of the builder.
void Emit(llvm::IRBuilder<>* b);
- llvm::BasicBlock* CreateLoopBB(tensorflow::StringPiece name,
- llvm::IRBuilder<>* b);
+ llvm::BasicBlock* CreateLoopBB(absl::string_view name, llvm::IRBuilder<>* b);
// Creates a name for an LLVM construct, appending prefix_ and suffix_, if
// they are set.
- string GetQualifiedName(tensorflow::StringPiece name);
+ string GetQualifiedName(absl::string_view name);
// Return a list of metadata nodes that should be associated with the
// llvm::Loop for this `ForLoop`.
@@ -182,9 +181,9 @@ class ForLoopNest {
SetIndexType(index_ty);
}
- ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* b,
+ ForLoopNest(absl::string_view name, llvm::IRBuilder<>* b,
llvm::Type* index_ty = nullptr)
- : name_(std::string(name)),
+ : name_(name),
outer_loop_preheader_bb_(nullptr),
outer_loop_exit_bb_(nullptr),
inner_loop_body_bb_(nullptr),
@@ -197,14 +196,14 @@ class ForLoopNest {
// been added then emit loop inside the body of the last added loop.
// unroll_mode is used to emit metadata that controls LLVM unrolling.
std::unique_ptr<ForLoop> AddLoop(
- tensorflow::StringPiece suffix, llvm::Value* start_index,
+ absl::string_view suffix, llvm::Value* start_index,
llvm::Value* end_index, llvm::Value* stride,
UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll,
bool prevent_vectorization = false);
// Like the above, except that it defaults to a stride of one.
std::unique_ptr<ForLoop> AddLoop(
- tensorflow::StringPiece suffix, llvm::Value* start_index,
+ absl::string_view suffix, llvm::Value* start_index,
llvm::Value* end_index,
UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll,
bool prevent_vectorization = false);
@@ -213,13 +212,13 @@ class ForLoopNest {
// end index are constant.
std::unique_ptr<ForLoop> AddLoop(
int64 start_index, int64 end_index, int64 stride,
- tensorflow::StringPiece suffix,
+ absl::string_view suffix,
UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll,
bool prevent_vectorization = false);
// Like the above, except that it defaults to a stride of one.
std::unique_ptr<ForLoop> AddLoop(
- int64 start_index, int64 end_index, tensorflow::StringPiece suffix,
+ int64 start_index, int64 end_index, absl::string_view suffix,
UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll,
bool prevent_vectorization = false);
@@ -234,8 +233,7 @@ class ForLoopNest {
// within the shape. One possible order for that sequence would be:
//
// (0,0), (0,1), (0,2), (1,0), (1,1), (1,2)
- IrArray::Index AddLoopsForShape(const Shape& shape,
- tensorflow::StringPiece suffix);
+ IrArray::Index AddLoopsForShape(const Shape& shape, absl::string_view suffix);
// Add a loop for each dimension in "dimensions". "suffix" is the
// name suffix of the indvar and basic blocks in this new loop nest.
@@ -244,8 +242,8 @@ class ForLoopNest {
// size equals the rank of shape and there is a null for each
// dimension that is not in "dimensions".
IrArray::Index AddLoopsForShapeOnDimensions(
- const Shape& shape, tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::StringPiece suffix);
+ const Shape& shape, absl::Span<const int64> dimensions,
+ absl::string_view suffix);
// Emits a series of nested loops for iterating over an operand array. Loops
// are constructed in major to minor dimension layout order. No loop is
@@ -256,7 +254,7 @@ class ForLoopNest {
// basic blocks) constructed by this method.
IrArray::Index EmitOperandArrayLoopNest(const llvm_ir::IrArray& operand_array,
int64 dimension_to_skip,
- tensorflow::StringPiece name_suffix);
+ absl::string_view name_suffix);
// Convenience methods which return particular basic blocks of the outermost
// or innermost loops. These methods return nullptr if no loops have been
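
A sketch of the nest builder in use after the migration (hypothetical emitter code; `shape` and `b` come from the surrounding context):

llvm_ir::ForLoopNest nest("copy", b);
// One loop per dimension of `shape`, outermost first; `index` carries the
// induction variable of each loop.
llvm_ir::IrArray::Index index = nest.AddLoopsForShape(shape, "elem");
// Position the builder in the innermost body and emit code that uses
// index[0], index[1], ... as the per-dimension coordinates.
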
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index e6126881af..1a53c026be 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/MDBuilder.h"
@@ -34,8 +36,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@@ -61,7 +61,7 @@ string AsString(const std::string& str) {
return string(str.data(), str.length());
}
-llvm::StringRef AsStringRef(tensorflow::StringPiece str) {
+llvm::StringRef AsStringRef(absl::string_view str) {
return llvm::StringRef(str.data(), str.size());
}
@@ -83,11 +83,10 @@ string DumpModuleToString(const llvm::Module& module) {
return AsString(buffer_string);
}
-llvm::Value* EmitCallToIntrinsic(
- llvm::Intrinsic::ID intrinsic_id,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<llvm::Type*> overloaded_types,
- llvm::IRBuilder<>* b) {
+llvm::Value* EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id,
+ absl::Span<llvm::Value* const> operands,
+ absl::Span<llvm::Type* const> overloaded_types,
+ llvm::IRBuilder<>* b) {
llvm::Module* module = ModuleFromIRBuilder(b);
llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
module, intrinsic_id, AsArrayRef(overloaded_types));
@@ -262,15 +261,17 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
}
llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
- tensorflow::StringPiece name,
+ absl::string_view name,
llvm::IRBuilder<>* b,
int alignment) {
return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, b, alignment);
}
-llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(
- llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name,
- llvm::IRBuilder<>* b, int alignment) {
+llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type,
+ llvm::Value* element_count,
+ absl::string_view name,
+ llvm::IRBuilder<>* b,
+ int alignment) {
llvm::IRBuilder<>::InsertPoint insert_point = b->saveIP();
llvm::Function* function = b->GetInsertBlock()->getParent();
b->SetInsertPoint(&function->getEntryBlock(),
@@ -285,7 +286,7 @@ llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(
}
llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
- tensorflow::StringPiece name,
+ absl::string_view name,
llvm::IRBuilder<>* b) {
return llvm::BasicBlock::Create(
/*Context=*/b->getContext(),
@@ -294,27 +295,25 @@ llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
/*InsertBefore*/ insert_before);
}
-LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name,
+LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
llvm::IRBuilder<>* b, bool emit_else) {
llvm_ir::LlvmIfData if_data;
if_data.if_block = b->GetInsertBlock();
if_data.true_block =
- CreateBasicBlock(nullptr, tensorflow::strings::StrCat(name, "-true"), b);
+ CreateBasicBlock(nullptr, absl::StrCat(name, "-true"), b);
if_data.false_block =
- emit_else ? CreateBasicBlock(
- nullptr, tensorflow::strings::StrCat(name, "-false"), b)
+ emit_else ? CreateBasicBlock(nullptr, absl::StrCat(name, "-false"), b)
: nullptr;
// Add a terminator to the if block, if necessary.
if (if_data.if_block->getTerminator() == nullptr) {
b->SetInsertPoint(if_data.if_block);
- if_data.after_block = CreateBasicBlock(
- nullptr, tensorflow::strings::StrCat(name, "-after"), b);
+ if_data.after_block =
+ CreateBasicBlock(nullptr, absl::StrCat(name, "-after"), b);
b->CreateBr(if_data.after_block);
} else {
if_data.after_block = if_data.if_block->splitBasicBlock(
- b->GetInsertPoint(),
- AsStringRef(tensorflow::strings::StrCat(name, "-after")));
+ b->GetInsertPoint(), AsStringRef(absl::StrCat(name, "-after")));
}
// Our basic block should now end with an unconditional branch. Remove it;
@@ -413,14 +412,14 @@ string IrName(string a) {
return a;
}
-string IrName(tensorflow::StringPiece a, tensorflow::StringPiece b) {
+string IrName(absl::string_view a, absl::string_view b) {
if (!a.empty() && !b.empty()) {
- return IrName(tensorflow::strings::StrCat(a, ".", b));
+ return IrName(absl::StrCat(a, ".", b));
}
- return IrName(tensorflow::strings::StrCat(a, b));
+ return IrName(absl::StrCat(a, b));
}
-string IrName(const HloInstruction* a, tensorflow::StringPiece b) {
+string IrName(const HloInstruction* a, absl::string_view b) {
return IrName(a->name(), b);
}
@@ -556,7 +555,7 @@ std::map<int, llvm::MDNode*> MergeMetadata(
return result;
}
-static string GetProcessUniqueIrFileName(tensorflow::StringPiece prefix) {
+static string GetProcessUniqueIrFileName(absl::string_view prefix) {
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
static NameUniquer* uniquer = new NameUniquer(/*separator=*/"-");
@@ -584,18 +583,16 @@ Status DumpIRToDirectory(const string& directory_name,
// XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously
// dumped from the same process in such cases.
string unique_and_safe_file_name = GetProcessUniqueIrFileName(
- tensorflow::strings::StrCat("ir-", SanitizeFileName(hlo_module_name), "-",
- optimized ? "with" : "no", "-opt"));
+ absl::StrCat("ir-", SanitizeFileName(hlo_module_name), "-",
+ optimized ? "with" : "no", "-opt"));
string ir_file_name = tensorflow::io::JoinPath(
- directory_name,
- tensorflow::strings::StrCat(unique_and_safe_file_name, ".ll"));
+ directory_name, absl::StrCat(unique_and_safe_file_name, ".ll"));
// For some models the embedded constants can be huge, so also dump the module
// with the constants stripped to get IR that is easier to manipulate.
string ir_no_constant_initializers_file_name = tensorflow::io::JoinPath(
- directory_name,
- tensorflow::strings::StrCat(unique_and_safe_file_name, "-noconst.ll"));
+ directory_name, absl::StrCat(unique_and_safe_file_name, "-noconst.ll"));
TF_RETURN_IF_ERROR(CreateAndWriteStringToFile(
directory_name, ir_file_name, DumpModuleToString(llvm_module)));
@@ -607,8 +604,7 @@ Status DumpIRToDirectory(const string& directory_name,
llvm::Function* CreateFunction(llvm::FunctionType* function_type,
llvm::GlobalValue::LinkageTypes linkage,
bool enable_fast_math, bool optimize_for_size,
- tensorflow::StringPiece name,
- llvm::Module* module) {
+ absl::string_view name, llvm::Module* module) {
llvm::Function* function =
llvm::Function::Create(function_type, linkage, AsStringRef(name), module);
function->setCallingConv(llvm::CallingConv::C);
@@ -638,7 +634,7 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) {
fake_argv_storage.push_back("");
for (const auto& it : options) {
// Skip options the XLA backend itself consumes.
- if (!tensorflow::str_util::StartsWith(it.first, "xla_")) {
+ if (!absl::StartsWith(it.first, "xla_")) {
if (it.second.empty()) {
fake_argv_storage.push_back(it.first);
} else {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
index 0958398534..f59baff263 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
@@ -20,6 +20,8 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
@@ -32,8 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
namespace llvm {
@@ -47,11 +47,11 @@ namespace llvm_ir {
// Convert a std::string (used by LLVM's interfaces) to string.
string AsString(const std::string& str);
-// Convert a tensorflow::StringPiece to a llvm::StringRef. Note: both
-// tensorflow::StringPiece and llvm::StringRef are non-owning pointers into a
+// Convert an absl::string_view to an llvm::StringRef. Note: both
+// absl::string_view and llvm::StringRef are non-owning pointers into a
// string in memory. This method is used to feed strings to LLVM
// & Clang APIs that expect llvm::StringRef.
-llvm::StringRef AsStringRef(tensorflow::StringPiece str);
+llvm::StringRef AsStringRef(absl::string_view str);
template <typename T>
llvm::ArrayRef<T> AsArrayRef(const std::vector<T>& vec) {
@@ -59,7 +59,7 @@ llvm::ArrayRef<T> AsArrayRef(const std::vector<T>& vec) {
}
template <typename T>
-llvm::ArrayRef<T> AsArrayRef(const tensorflow::gtl::ArraySlice<T>& slice) {
+llvm::ArrayRef<T> AsArrayRef(const absl::Span<const T>& slice) {
return llvm::ArrayRef<T>(slice.data(), slice.size());
}
@@ -88,8 +88,8 @@ string DumpModuleToString(const llvm::Module& module);
// - removing all '%'s.
//
string IrName(string a);
-string IrName(tensorflow::StringPiece a, tensorflow::StringPiece b);
-string IrName(const HloInstruction* a, tensorflow::StringPiece b = "");
+string IrName(absl::string_view a, absl::string_view b);
+string IrName(const HloInstruction* a, absl::string_view b = "");
// Removes special characters from a function name.
//
@@ -101,11 +101,10 @@ string SanitizeFunctionName(string function_name);
// intrinsics (for example, "minnum") must include a type in overloaded_types
// for each overloaded type. Typically, overloaded intrinsics have only a single
// overloaded type.
-llvm::Value* EmitCallToIntrinsic(
- llvm::Intrinsic::ID intrinsic_id,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- tensorflow::gtl::ArraySlice<llvm::Type*> overloaded_types,
- llvm::IRBuilder<>* b);
+llvm::Value* EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id,
+ absl::Span<llvm::Value* const> operands,
+ absl::Span<llvm::Type* const> overloaded_types,
+ llvm::IRBuilder<>* b);
// Emit float max. Emits the maxnum intrinsic if fast math is disabled, or
// fcmp+select otherwise.
@@ -164,21 +163,23 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
// This can be useful to avoid e.g. executing an alloca every time
// through a loop.
llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
- tensorflow::StringPiece name,
+ absl::string_view name,
llvm::IRBuilder<>* b,
int alignment = 0);
// As EmitAllocaAtFunctionEntry, but allocates element_count entries
// instead of a single element.
-llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(
- llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name,
- llvm::IRBuilder<>* b, int alignment = 0);
+llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type,
+ llvm::Value* element_count,
+ absl::string_view name,
+ llvm::IRBuilder<>* b,
+ int alignment = 0);
// Creates a basic block with the same context and function as for the
// builder. Inserts at the end of the function if insert_before is
// null.
llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
- tensorflow::StringPiece name,
+ absl::string_view name,
llvm::IRBuilder<>* b);
// Struct with data on a conditional branch in a diamond shape created
@@ -210,7 +211,7 @@ struct LlvmIfData {
// Currently the insertion point of the builder must be a well-formed
// block with a terminator. If you need to use this for a
// non-terminated block, just make the function able to do that too.
-LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name,
+LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
llvm::IRBuilder<>* b, bool emit_else = true);
// Emits a compare operation between "lhs" and "rhs" with the given predicate,
@@ -285,8 +286,7 @@ Status DumpIRToDirectory(const string& directory_name,
llvm::Function* CreateFunction(llvm::FunctionType* function_type,
llvm::GlobalValue::LinkageTypes linkage,
bool enable_fast_math, bool optimize_for_size,
- tensorflow::StringPiece name,
- llvm::Module* module);
+ absl::string_view name, llvm::Module* module);
// Extracts the xla_backend_extra_options from `config` and passes those that
// don't start with xla_ to LLVM.
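
The diamond helper is easiest to see in use; a minimal sketch against the signatures above (hypothetical emitter code):

llvm_ir::LlvmIfData if_data =
    llvm_ir::EmitIfThenElse(condition, "bounds_check", b, /*emit_else=*/true);
llvm_ir::SetToFirstInsertPoint(if_data.true_block, b);
// ... emit the then-branch ...
llvm_ir::SetToFirstInsertPoint(if_data.false_block, b);
// ... emit the else-branch ...
llvm_ir::SetToFirstInsertPoint(if_data.after_block, b);
// ... both paths rejoin here ...
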
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
index 36f5fa1952..0dc120e0b0 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
@@ -18,13 +18,13 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
@@ -69,7 +69,7 @@ static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion(
}
LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
- tensorflow::gtl::ArraySlice<IrArray> target_arrays,
+ absl::Span<const IrArray> target_arrays,
llvm::IRBuilder<>* b)
: body_emitter_(MakeBodyEmitterForMultiOutputFusion(
target_element_generator,
@@ -86,7 +86,7 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
}
std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name, llvm::Type* index_type) {
+ absl::string_view loop_name, llvm::Type* index_type) {
CHECK_NE(index_type, nullptr);
if (ShapeUtil::IsScalar(shape_)) {
// No loop needed, so set exit_bb_ to nullptr.
@@ -105,7 +105,7 @@ std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock(
std::unique_ptr<ForLoop> loop = loop_nest.AddLoop(
/*start_index=*/0,
/*end_index=*/shape_.dimensions(dimension),
- /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension));
+ /*suffix=*/absl::StrFormat("dim.%d", dimension));
array_index[dimension] = loop->GetIndVarValue();
}
@@ -122,7 +122,7 @@ std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock(
return {array_index};
}
-Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name,
+Status LoopEmitter::EmitLoop(absl::string_view loop_name,
llvm::Type* index_type) {
if (index_type == nullptr) {
index_type = b_->getInt64Ty();
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
index c4f5c82086..a537c00066 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
@@ -53,8 +53,7 @@ class LoopEmitter {
// This is used for multi-output fusion. target_element_generator must
// produce an LLVM struct with N elements.
LoopEmitter(const ElementGenerator& target_element_generator,
- tensorflow::gtl::ArraySlice<IrArray> target_arrays,
- llvm::IRBuilder<>* b);
+ absl::Span<const IrArray> target_arrays, llvm::IRBuilder<>* b);
LoopEmitter(const LoopEmitter&) = delete;
LoopEmitter& operator=(const LoopEmitter&) = delete;
@@ -69,10 +68,10 @@ class LoopEmitter {
}
virtual std::vector<IrArray::Index> EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name, llvm::Type* index_type);
+ absl::string_view loop_name, llvm::Type* index_type);
// Emits a complete loop nest for every element in the given shape.
- Status EmitLoop(tensorflow::StringPiece loop_name = "",
+ Status EmitLoop(absl::string_view loop_name = "",
llvm::Type* index_type = nullptr);
protected:
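
For contrast with the multi-output form, a sketch of the single-target use (hypothetical; assumes the companion constructor taking one IrArray):

llvm_ir::LoopEmitter emitter(element_generator, target_array, b);
// Walks every element of target_array's shape and stores the value produced
// by element_generator at each index.
TF_RETURN_IF_ERROR(emitter.EmitLoop("relu", /*index_type=*/b->getInt64Ty()));
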
diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
index e546f5cc4a..944c79580c 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
@@ -16,6 +16,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
@@ -29,8 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -42,7 +43,7 @@ namespace {
void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index,
const IrArray::Index& compare_keys_index,
const IrArray& keys_array,
- const tensorflow::gtl::optional<IrArray>& values_array,
+ const absl::optional<IrArray>& values_array,
llvm::IRBuilder<>* b) {
// if (is_smaller_index &&
// compare_keys[dimension_to_sort] < dimension_to_sort_bound)
@@ -59,15 +60,39 @@ void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index,
SetToFirstInsertPoint(if_data.true_block, b);
auto key1 = keys_array.EmitReadArrayElement(keys_index, b);
auto key2 = keys_array.EmitReadArrayElement(compare_keys_index, b);
+ auto compare_key1 = key1;
+ auto compare_key2 = key2;
auto key_type = keys_array.GetShape().element_type();
+ bool is_signed_comparison = true;
+ if (primitive_util::IsFloatingPointType(key_type)) {
+ // We would like a total order of floating point numbers so that the sort
+ // has a predictable behavior in the presence of NaNs. Rather than using
+ // floating point comparison, we use the following trick:
+ // If f is a float, and
+ // x = bit_cast<int32>(f);
+ // y = x < 0 ? 0x7FFFFFFF - x : x;
+ // then y is ordered as an int32 such that finite values have the obvious
+ // order, -0 is ordered before 0, and -NaN and NaN appear at the beginning
+ // and end of the ordering.
+ auto k = b->getInt(llvm::APInt::getSignedMaxValue(
+ key1->getType()->getPrimitiveSizeInBits()));
+ auto comparison_type = k->getType();
+ auto zero = llvm::ConstantInt::get(comparison_type, 0);
+ auto maybe_flip = [&](llvm::Value* v) {
+ return b->CreateSelect(b->CreateICmp(llvm::ICmpInst::ICMP_SLT, v, zero),
+ b->CreateSub(k, v), v);
+ };
+ compare_key1 = b->CreateBitCast(key1, comparison_type);
+ compare_key2 = b->CreateBitCast(key2, comparison_type);
+ compare_key1 = maybe_flip(compare_key1);
+ compare_key2 = maybe_flip(compare_key2);
+ } else if (!primitive_util::IsSignedIntegralType(key_type)) {
+ is_signed_comparison = false;
+ }
auto comparison =
- primitive_util::IsFloatingPointType(key_type)
- // TODO(b/26783907): Figure out how to handle NaNs.
- ? b->CreateFCmp(llvm::FCmpInst::FCMP_ULT, key2, key1)
- : b->CreateICmp(primitive_util::IsSignedIntegralType(key_type)
- ? llvm::ICmpInst::ICMP_SLT
- : llvm::ICmpInst::ICMP_ULT,
- key2, key1);
+ b->CreateICmp(is_signed_comparison ? llvm::ICmpInst::ICMP_SLT
+ : llvm::ICmpInst::ICMP_ULT,
+ compare_key2, compare_key1);
// If key2 < key1
auto if_smaller_data =
EmitIfThenElse(comparison, "is_smaller_than", b, /*emit_else=*/false);
@@ -87,8 +112,8 @@ void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index,
} // namespace
Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
- const tensorflow::gtl::optional<IrArray>& values_array,
- tensorflow::StringPiece name, llvm::Value* xor_mask,
+ const absl::optional<IrArray>& values_array,
+ absl::string_view name, llvm::Value* xor_mask,
llvm::IRBuilder<>* b,
const gpu::LaunchDimensions* launch_dimensions) {
const Shape& keys_shape = keys_array.GetShape();
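
The bit trick in the comment above can be checked in isolation. A standalone sketch in plain C++ (the helper name is illustrative; unsigned arithmetic stands in for the wrapping subtraction the emitted IR performs):

#include <cstdint>
#include <cstring>

// Maps a float to an int32 whose signed order is a total order on floats:
// finite values keep their usual order, -0.0 sorts before +0.0, and
// negative/positive NaNs land at the two extremes.
int32_t TotalOrderKey(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));  // x = bit_cast<int32>(f)
  if (static_cast<int32_t>(bits) < 0) {
    bits = 0x7FFFFFFFu - bits;  // y = 0x7FFFFFFF - x, with wraparound
  }
  return static_cast<int32_t>(bits);
}

// e.g. TotalOrderKey(-2.0f) < TotalOrderKey(-0.0f) < TotalOrderKey(0.0f)
//      < TotalOrderKey(2.0f).
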
diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h
index 8458744c6b..527ed10374 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h
@@ -16,12 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -31,8 +31,8 @@ namespace llvm_ir {
// implements the inner loop of BitonicSort. If 'launch_dimensions' is nullptr,
// the inner compare loop will not be parallelized.
Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
- const tensorflow::gtl::optional<IrArray>& values_array,
- tensorflow::StringPiece name, llvm::Value* xor_mask,
+ const absl::optional<IrArray>& values_array,
+ absl::string_view name, llvm::Value* xor_mask,
llvm::IRBuilder<>* b,
const gpu::LaunchDimensions* launch_dimensions);
} // namespace llvm_ir
diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
index 11ed6ee59f..a60643bc75 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
@@ -64,8 +64,7 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred,
}
}
-void EmitTuple(const IrArray& tuple,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
+void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands,
llvm::IRBuilder<>* b, llvm::Module* module) {
for (size_t i = 0; i < operands.size(); ++i) {
auto* store = b->CreateStore(
@@ -76,6 +75,16 @@ void EmitTuple(const IrArray& tuple,
}
}
+void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers,
+ llvm::IRBuilder<>* b, llvm::Module* module) {
+ std::vector<llvm::Value*> buffer_ptrs;
+ buffer_ptrs.reserve(buffers.size());
+ absl::c_transform(
+ buffers, std::back_inserter(buffer_ptrs),
+ [](const llvm_ir::IrArray& buffer) { return buffer.GetBasePointer(); });
+ llvm_ir::EmitTuple(tuple, buffer_ptrs, b, module);
+}
+
llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
int alignment, llvm::Value* operand,
llvm::IRBuilder<>* b, llvm::Module* module) {
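
The new overload's body is a common Abseil projection idiom; a small self-contained sketch, with a hypothetical Buffer type standing in for llvm_ir::IrArray:

    #include <iterator>
    #include <vector>

    #include "absl/algorithm/container.h"
    #include "absl/types/span.h"

    struct Buffer {
      void* base = nullptr;
      void* GetBasePointer() const { return base; }
    };

    // Project each buffer to its base pointer, as the new EmitTuple does.
    std::vector<void*> CollectBasePointers(absl::Span<const Buffer> buffers) {
      std::vector<void*> ptrs;
      ptrs.reserve(buffers.size());
      absl::c_transform(buffers, std::back_inserter(ptrs),
                        [](const Buffer& b) { return b.GetBasePointer(); });
      return ptrs;
    }
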
diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
index cf6bf5d0b1..94340b91d8 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
@@ -16,10 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_TUPLE_OPS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_TUPLE_OPS_H_
+#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
// Utilities for emitting LLVM IR related to HLO tuples.
@@ -65,8 +65,12 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred,
// A tuple is an array of pointers, one for each operand. Each pointer points to
// the output buffer of its corresponding operand.
-void EmitTuple(const IrArray& tuple,
- tensorflow::gtl::ArraySlice<llvm::Value*> operands,
+void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands,
+ llvm::IRBuilder<>* b, llvm::Module* module);
+
+// Similar to EmitTuple above, except that the output buffers are provided as
+// IrArrays.
+void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers,
llvm::IRBuilder<>* b, llvm::Module* module);
// A tuple is an array of pointers, one for each operand. Each pointer points to
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 5e02096ee5..0d0fb7946a 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -19,10 +19,12 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/executable.h"
@@ -37,7 +39,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -73,7 +74,7 @@ namespace {
// If the parameter number is invalid for this computation, nullopt is
// returned. When the return value has_value(), nullptr will never be
// the held value.
-tensorflow::gtl::optional<const OpMetadata*> ParameterMetadata(
+absl::optional<const OpMetadata*> ParameterMetadata(
const XlaComputation& computation, int parameter_number) {
for (const HloComputationProto& comp : computation.proto().computations()) {
if (comp.id() == computation.proto().entry_computation_id()) {
@@ -81,14 +82,14 @@ tensorflow::gtl::optional<const OpMetadata*> ParameterMetadata(
if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
instr.parameter_number() == parameter_number) {
if (!instr.has_metadata()) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
return &instr.metadata();
}
}
}
}
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
ExecutionOptions CreateExecutionOptions(
@@ -140,7 +141,7 @@ ExecutionOptions CreateExecutionOptions(
StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
const XlaComputation& computation,
- const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const absl::Span<const Shape* const> argument_layouts,
const ExecutableBuildOptions& build_options) {
const HloModuleProto& proto = computation.proto();
TF_RET_CHECK(proto.has_program_shape());
@@ -149,7 +150,7 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
// Validate incoming layouts.
if (argument_layouts.size() != program_shape.parameters_size()) {
return InvalidArgument(
- "Invalid number of arguments for computation: expected %d, got %zu.",
+ "Invalid number of arguments for computation: expected %d, got %u.",
program_shape.parameters_size(), argument_layouts.size());
}
@@ -158,7 +159,7 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
TF_RETURN_IF_ERROR(
ShapeUtil::ValidateShapeWithOptionalLayout(argument_shape));
if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) {
- tensorflow::gtl::optional<const OpMetadata*> metadata =
+ absl::optional<const OpMetadata*> metadata =
ParameterMetadata(computation, /*parameter_number=*/i);
auto metadata_string = [&metadata]() -> string {
if (!metadata.has_value()) {
@@ -167,16 +168,15 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
CHECK(metadata.value() != nullptr);
const OpMetadata& m = *metadata.value();
if (!m.source_file().empty()) {
- return tensorflow::strings::Printf(
- " (%s:%d)", m.source_file().c_str(), m.source_line());
+ return absl::StrFormat(" (%s:%d)", m.source_file(), m.source_line());
}
return "";
};
return InvalidArgument(
"Invalid argument shape for argument %d%s, expected %s, got %s.", i,
- metadata_string().c_str(),
- ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
- ShapeUtil::HumanString(argument_shape).c_str());
+ metadata_string(),
+ ShapeUtil::HumanString(program_shape.parameters(i)),
+ ShapeUtil::HumanString(argument_shape));
}
}
if (build_options.result_layout() != nullptr) {
@@ -214,7 +214,7 @@ StatusOr<const ShapedBuffer*> LocalService::GlobalDataToShapedBuffer(
TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data));
if (replica_number >= buffers.size()) {
return InvalidArgument(
- "replica_number %d out of range; must be less than num_replicas = %zu.",
+ "replica_number %d out of range; must be less than num_replicas = %u.",
replica_number, buffers.size());
}
return buffers[replica_number];
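
The ParameterMetadata helper converted above relies on a two-level "not found" convention; a minimal sketch under the same convention, with a std::map as an illustrative stand-in for the HLO proto lookup:

    #include <map>
    #include <string>

    #include "absl/types/optional.h"

    // absl::nullopt means "no such parameter"; an engaged value never holds
    // nullptr, matching the comment on ParameterMetadata.
    absl::optional<const std::string*> FindMetadata(
        const std::map<int, std::string>& metadata, int parameter_number) {
      auto it = metadata.find(parameter_number);
      if (it == metadata.end()) {
        return absl::nullopt;
      }
      return &it->second;
    }
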
diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h
index 8f707ea904..3b4f0b5083 100644
--- a/tensorflow/compiler/xla/service/local_service.h
+++ b/tensorflow/compiler/xla/service/local_service.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/backend.h"
@@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
@@ -48,7 +48,7 @@ class LocalService : public Service {
// compiler is responsible for freeing any memory it allocates this way.
StatusOr<std::unique_ptr<Executable>> CompileExecutable(
const XlaComputation& computation,
- const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const absl::Span<const Shape* const> argument_layouts,
const ExecutableBuildOptions& build_options);
// Returns the device ordinal that corresponds to the given replica number.
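
A note on the migrated signature's const placement: in absl::Span<const Shape* const> the span is a read-only view over an array of pointers, so neither the pointer elements nor the pointed-to objects can be modified through it. A brief sketch with int in place of Shape:

    #include "absl/types/span.h"

    int SumAll(absl::Span<const int* const> xs) {
      int total = 0;
      for (const int* x : xs) {
        total += *x;  // *x is readable; neither xs[i] nor *x is assignable.
      }
      return total;
    }
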
diff --git a/tensorflow/compiler/xla/service/logical_buffer.cc b/tensorflow/compiler/xla/service/logical_buffer.cc
index c742d35a7b..e1f56727bd 100644
--- a/tensorflow/compiler/xla/service/logical_buffer.cc
+++ b/tensorflow/compiler/xla/service/logical_buffer.cc
@@ -15,11 +15,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
@@ -34,11 +34,10 @@ LogicalBuffer::~LogicalBuffer() {}
string LogicalBuffer::ToString() const {
string color_string;
if (has_color()) {
- color_string = tensorflow::strings::StrCat(" @", color().value());
+ color_string = absl::StrCat(" @", color().value());
}
- return tensorflow::strings::StrCat(instruction_->name(), "[",
- tensorflow::str_util::Join(index_, ","),
- "](#", id(), color_string, ")");
+ return absl::StrCat(instruction_->name(), "[", absl::StrJoin(index_, ","),
+ "](#", id(), color_string, ")");
}
} // namespace xla
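
The rewritten ToString shows the usual StrCat/StrJoin replacement for the old tensorflow::strings helpers; a standalone sketch with illustrative names:

    #include <cstdint>
    #include <string>
    #include <vector>

    #include "absl/strings/str_cat.h"
    #include "absl/strings/str_join.h"

    std::string BufferLabel(const std::string& name,
                            const std::vector<int64_t>& index, int64_t id) {
      // StrJoin's default formatter handles integral elements directly.
      return absl::StrCat(name, "[", absl::StrJoin(index, ","), "](#", id,
                          ")");
    }
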
diff --git a/tensorflow/compiler/xla/service/logical_buffer.h b/tensorflow/compiler/xla/service/logical_buffer.h
index f9ba5a5547..ceacab4ed7 100644
--- a/tensorflow/compiler/xla/service/logical_buffer.h
+++ b/tensorflow/compiler/xla/service/logical_buffer.h
@@ -18,13 +18,13 @@ limitations under the License.
#include <string>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/int_type.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
index d631fb5ee4..eaa09591b7 100644
--- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
+++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
@@ -89,7 +90,7 @@ void LogicalBufferAnalysis::NewLogicalBuffer(HloInstruction* instruction,
const ShapeIndex& index) {
CHECK_EQ(logical_buffers_.size(), next_buffer_id_);
logical_buffers_.emplace_back(
- MakeUnique<LogicalBuffer>(instruction, index, next_buffer_id_));
+ absl::make_unique<LogicalBuffer>(instruction, index, next_buffer_id_));
output_buffers_[std::make_pair(instruction, index)] =
logical_buffers_.back().get();
diff --git a/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc b/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc
new file mode 100644
index 0000000000..8269842426
--- /dev/null
+++ b/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc
@@ -0,0 +1,41 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
+#include "absl/types/variant.h"
+namespace xla {
+
+se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase() {
+ if (HasOwnership()) {
+ return absl::get<OwningDeviceMemory>(mem_).AsDeviceMemoryBase();
+ } else {
+ return absl::get<se::DeviceMemoryBase>(mem_);
+ }
+}
+
+bool MaybeOwningDeviceMemory::HasOwnership() const {
+ return absl::holds_alternative<OwningDeviceMemory>(mem_);
+}
+
+absl::optional<OwningDeviceMemory> MaybeOwningDeviceMemory::Release() {
+ if (!HasOwnership()) {
+ return {};
+ }
+ OwningDeviceMemory result = std::move(absl::get<OwningDeviceMemory>(mem_));
+ mem_ = result.AsDeviceMemoryBase();
+ return absl::make_optional<OwningDeviceMemory>(std::move(result));
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/maybe_owning_device_memory.h b/tensorflow/compiler/xla/service/maybe_owning_device_memory.h
new file mode 100644
index 0000000000..82e7f1183c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/maybe_owning_device_memory.h
@@ -0,0 +1,70 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MAYBE_OWNING_DEVICE_MEMORY_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_MAYBE_OWNING_DEVICE_MEMORY_H_
+
+#include "absl/types/optional.h"
+#include "absl/types/variant.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/owning_device_memory.h"
+
+namespace xla {
+
+// MaybeOwningDeviceMemory represents either owned or unowned device memory.
+// Like std::variant<OwningDeviceMemory, se::DeviceMemoryBase>. When the
+// object goes out of scope, it will free the underlying memory if it owns
+// it.
+class MaybeOwningDeviceMemory {
+ public:
+ MaybeOwningDeviceMemory() = default;
+ explicit MaybeOwningDeviceMemory(OwningDeviceMemory owned)
+ : mem_(std::move(owned)) {}
+ explicit MaybeOwningDeviceMemory(se::DeviceMemoryBase unowned)
+ : mem_(unowned) {}
+ MaybeOwningDeviceMemory(MaybeOwningDeviceMemory&&) = default;
+ ~MaybeOwningDeviceMemory() = default;
+
+ MaybeOwningDeviceMemory& operator=(se::DeviceMemoryBase unowned) {
+ mem_ = unowned;
+ return *this;
+ }
+
+ MaybeOwningDeviceMemory& operator=(OwningDeviceMemory owned) {
+ mem_ = std::move(owned);
+ return *this;
+ }
+
+ MaybeOwningDeviceMemory& operator=(MaybeOwningDeviceMemory&&) = default;
+
+ // Fetches the underlying DeviceMemoryBase from a MaybeOwningDeviceMemory. The
+ // caller of this function is *not* responsible for freeing the memory.
+ se::DeviceMemoryBase AsDeviceMemoryBase();
+
+ // Releases the OwningDeviceMemory without freeing it, moving ownership of
+ // the memory buffer from the object to the caller.
+ //
+ // Returns nullopt if HasOwnership() == false.
+ absl::optional<OwningDeviceMemory> Release();
+
+ // Returns true if this object owns the underlying device memory.
+ bool HasOwnership() const;
+
+ private:
+ absl::variant<OwningDeviceMemory, se::DeviceMemoryBase> mem_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MAYBE_OWNING_DEVICE_MEMORY_H_
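
A self-contained miniature of the ownership pattern in this new class, with std::unique_ptr<int> standing in for OwningDeviceMemory and a raw int* for se::DeviceMemoryBase:

    #include <memory>
    #include <utility>

    #include "absl/types/optional.h"
    #include "absl/types/variant.h"

    class MaybeOwning {
     public:
      explicit MaybeOwning(std::unique_ptr<int> owned)
          : mem_(std::move(owned)) {}
      explicit MaybeOwning(int* unowned) : mem_(unowned) {}

      bool HasOwnership() const {
        return absl::holds_alternative<std::unique_ptr<int>>(mem_);
      }

      // Like Release() above: hand ownership to the caller while keeping an
      // unowned view of the same buffer.
      absl::optional<std::unique_ptr<int>> Release() {
        if (!HasOwnership()) {
          return absl::nullopt;
        }
        std::unique_ptr<int> owned =
            std::move(absl::get<std::unique_ptr<int>>(mem_));
        mem_ = owned.get();
        return absl::make_optional(std::move(owned));
      }

     private:
      absl::variant<std::unique_ptr<int>, int*> mem_;
    };
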
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc
index 4166ef5baf..b9ec31c497 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc
@@ -262,7 +262,7 @@ void MultiOutputFusion::RecomputeReachability() {
void MultiOutputFusion::UpdateReachability(
HloInstruction* instr1, HloInstruction* instr2,
- tensorflow::gtl::ArraySlice<HloInstruction*> instrs_to_update,
+ absl::Span<HloInstruction* const> instrs_to_update,
const std::function<bool(HloInstruction*)>& skip) {
for (auto instr : instrs_to_update) {
if (skip != nullptr && skip(instr)) {
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h
index 0019cd7254..d2c52651c4 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.h
@@ -19,10 +19,10 @@ limitations under the License.
#include <queue>
#include <vector>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace xla {
@@ -48,9 +48,7 @@ class MultiOutputFusion : public HloPassInterface {
public:
MultiOutputFusion(int64 fuel) : fuel_(fuel) {}
- tensorflow::StringPiece name() const override {
- return "multi_output_fusion";
- }
+ absl::string_view name() const override { return "multi_output_fusion"; }
// Run multi-output fusion on the given module. Returns whether the module
// was changed.
@@ -94,7 +92,7 @@ class MultiOutputFusion : public HloPassInterface {
// Update the reachability map after fusing instr1 and instr2.
void UpdateReachability(
HloInstruction* instr1, HloInstruction* instr2,
- tensorflow::gtl::ArraySlice<HloInstruction*> instrs_to_update,
+ absl::Span<HloInstruction* const> instrs_to_update,
const std::function<bool(HloInstruction*)>& skip = nullptr);
// Hook for multi-output fusion along producer-consumer edges.
@@ -104,17 +102,17 @@ class MultiOutputFusion : public HloPassInterface {
// InstructionFusion instead.
virtual bool DoProducerConsumerMultiOutputFusion();
- private:
- // Update the internal data structures after instr1 and instr2 are fused into
- // one fusion instruction.
- void Update(HloInstruction* instr1, HloInstruction* instr2);
-
// Optimization fuel is a compiler debugging technique that makes an
// optimization pass stop what it is doing after having made N changes to the
// program, where N is the fuel. By varying N, this can be used to find the
// first single change that makes a test fail.
int64 fuel_;
+ private:
+ // Update the internal data structures after instr1 and instr2 are fused into
+ // one fusion instruction.
+ void Update(HloInstruction* instr1, HloInstruction* instr2);
+
// Computation for the pass.
HloComputation* computation_;
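
The comment kept with fuel_ describes a general compiler-debugging technique; a minimal self-contained sketch (class and method names are illustrative):

    #include <cstdint>

    class FueledPass {
     public:
      explicit FueledPass(int64_t fuel) : fuel_(fuel) {}

      // Returns true, consuming one unit of fuel, if the pass may still make
      // a change; bisecting on the initial fuel isolates the first change
      // that breaks a test.
      bool ConsumeFuel() { return fuel_-- > 0; }

     private:
      int64_t fuel_;
    };
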
diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc
index f6e7578a89..bd8fb17a23 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.cc
+++ b/tensorflow/compiler/xla/service/name_uniquer.cc
@@ -15,8 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/name_uniquer.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -52,8 +53,8 @@ NameUniquer::NameUniquer(const string& separator) {
return result;
}
-string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) {
- string root = GetSanitizedName(prefix.empty() ? "name" : std::string(prefix));
+string NameUniquer::GetUniqueName(absl::string_view prefix) {
+ string root = GetSanitizedName(prefix.empty() ? "name" : string(prefix));
// Strip away numeric suffix (if any). Only recognize separator if it is in
// the middle of the name.
@@ -63,20 +64,22 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) {
if (separator_index != string::npos && (separator_index > 0) &&
(separator_index < root.size() - 1)) {
string after_suffix = root.substr(separator_index + 1);
- if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) {
+ if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) {
has_numeric_suffix = true;
// Remove numeric suffix from root.
root = root.substr(0, separator_index);
+ } else {
+ // absl::SimpleAtoi may modify numeric_suffix even if it returns false.
+ numeric_suffix = 0;
}
}
SequentialIdGenerator& id_generator = generated_names_[root];
numeric_suffix = id_generator.RegisterId(numeric_suffix);
if (numeric_suffix == 0) {
- return has_numeric_suffix ? tensorflow::strings::StrCat(root, separator_, 0)
- : root;
+ return has_numeric_suffix ? absl::StrCat(root, separator_, 0) : root;
}
- tensorflow::strings::StrAppend(&root, separator_, numeric_suffix);
+ absl::StrAppend(&root, separator_, numeric_suffix);
return root;
}
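
The new else-branch guards a subtle Abseil contract; a tiny sketch of the same defensive reset (helper name illustrative):

    #include <cstdint>

    #include "absl/strings/numbers.h"
    #include "absl/strings/string_view.h"

    int64_t ParseSuffixOrZero(absl::string_view s) {
      int64_t n = 0;
      if (!absl::SimpleAtoi(s, &n)) {
        n = 0;  // SimpleAtoi may clobber n even when it returns false.
      }
      return n;
    }
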
diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h
index 4423d61069..6dd89c240f 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.h
+++ b/tensorflow/compiler/xla/service/name_uniquer.h
@@ -18,8 +18,8 @@ limitations under the License.
#include <string>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
@@ -38,7 +38,7 @@ class NameUniquer {
// Get a sanitized unique name in a string, with an optional prefix for
// convenience.
- string GetUniqueName(tensorflow::StringPiece prefix = "");
+ string GetUniqueName(absl::string_view prefix = "");
// Sanitizes and returns the name. Unallowed characters will be replaced with
// '_'. The result will match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*".
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h
index ac6ea4c72f..4869db79e7 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher.h
+++ b/tensorflow/compiler/xla/service/pattern_matcher.h
@@ -16,11 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace xla {
@@ -622,7 +622,7 @@ template <typename Previous>
class HloInstructionPatternNameImpl {
public:
explicit HloInstructionPatternNameImpl(const Previous& previous,
- tensorflow::StringPiece name)
+ absl::string_view name)
: previous_(previous), name_(name) {}
bool Match(const ::xla::HloInstruction* inst) const {
@@ -631,7 +631,7 @@ class HloInstructionPatternNameImpl {
private:
Previous previous_;
- tensorflow::StringPiece name_;
+ absl::string_view name_;
};
// An HloInstructionPattern implementation that matches only if the instruction
@@ -784,7 +784,7 @@ class HloInstructionPattern {
// Modifies the pattern to match only if the instruction has the given name.
HloInstructionPattern<HloInstructionType, HloInstructionPatternNameImpl<Impl>>
- WithName(tensorflow::StringPiece name) const {
+ WithName(absl::string_view name) const {
return HloInstructionPattern<HloInstructionType,
HloInstructionPatternNameImpl<Impl>>(
HloInstructionPatternNameImpl<Impl>(impl_, name), matched_inst_);
@@ -918,6 +918,7 @@ Op(::xla::HloInstruction** matched_inst) {
}
XLA_NULLOP_PATTERN(Constant)
XLA_NULLOP_PATTERN(Parameter)
+XLA_NULLOP_PATTERN(Iota)
#undef XLA_NULLOP_PATTERN
// Helpers for unary instructions.
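
A hedged usage sketch for the newly added Iota pattern, assuming the matcher namespace alias conventionally used with this header:

    // namespace m = match;
    // HloInstruction* iota = nullptr;
    // if (Match(root, m::Add(m::Iota(&iota), m::Constant()))) {
    //   // root is an add of an iota and a constant; iota is now bound.
    // }
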
diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc
index 39fe3c7835..178a78ede0 100644
--- a/tensorflow/compiler/xla/service/platform_util.cc
+++ b/tensorflow/compiler/xla/service/platform_util.cc
@@ -19,20 +19,19 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
-using tensorflow::str_util::Lowercase;
-
// Minimum supported CUDA compute capability is 3.5.
constexpr int kMinCudaComputeCapabilityMajor = 3;
constexpr int kMinCudaComputeCapabilityMinor = 5;
@@ -43,7 +42,7 @@ constexpr char kInterpreter[] = "interpreter";
namespace {
string CanonicalPlatformName(const string& name) {
- string platform_str = Lowercase(name);
+ string platform_str = absl::AsciiStrToLower(name);
// "cpu" and "host" mean the same thing.
if (platform_str == "cpu") {
platform_str = "host";
@@ -90,41 +89,54 @@ PlatformUtil::GetSupportedPlatforms() {
if (platforms.empty()) {
return NotFound("no platforms found");
} else if (platforms.size() == 1) {
- return platforms[0];
+ se::Platform* platform = platforms[0];
+ if (!platform->Initialized()) {
+ TF_RETURN_IF_ERROR(platform->Initialize({}));
+ }
+ return platform;
}
// Multiple platforms present and we can't pick a reasonable default.
- string platforms_string = tensorflow::str_util::Join(
+ string platforms_string = absl::StrJoin(
platforms, ", ",
[](string* out, const se::Platform* p) { out->append(p->Name()); });
return InvalidArgument(
"must specify platform because more than one platform found: %s",
- platforms_string.c_str());
+ platforms_string);
}
/* static */ StatusOr<se::Platform*> PlatformUtil::GetDefaultPlatform() {
TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms());
+
+ se::Platform* platform = nullptr;
if (platforms.empty()) {
return NotFound("no platforms found");
} else if (platforms.size() == 1) {
- return platforms[0];
+ platform = platforms[0];
} else if (platforms.size() == 2) {
for (int i = 0; i < 2; i++) {
- if (Lowercase(platforms[i]->Name()) == kInterpreter &&
- Lowercase(platforms[1 - i]->Name()) != kInterpreter) {
- return platforms[1 - i];
+ if (absl::AsciiStrToLower(platforms[i]->Name()) == kInterpreter &&
+ absl::AsciiStrToLower(platforms[1 - i]->Name()) != kInterpreter) {
+ platform = platforms[1 - i];
+ break;
}
}
}
+ if (platform != nullptr) {
+ if (!platform->Initialized()) {
+ TF_RETURN_IF_ERROR(platform->Initialize({}));
+ }
+ return platform;
+ }
// Multiple platforms present and we can't pick a reasonable default.
- string platforms_string = tensorflow::str_util::Join(
+ string platforms_string = absl::StrJoin(
platforms, ", ",
[](string* out, const se::Platform* p) { out->append(p->Name()); });
return InvalidArgument(
"must specify platform because more than one platform (except for the "
"interpreter platform) found: %s",
- platforms_string.c_str());
+ platforms_string);
}
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatform(
@@ -132,11 +144,14 @@ PlatformUtil::GetSupportedPlatforms() {
string platform_str = CanonicalPlatformName(platform_name);
TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
for (se::Platform* platform : platforms) {
- if (Lowercase(platform->Name()) == platform_str) {
+ if (absl::AsciiStrToLower(platform->Name()) == platform_str) {
+ if (!platform->Initialized()) {
+ TF_RETURN_IF_ERROR(platform->Initialize({}));
+ }
return platform;
}
}
- return InvalidArgument("platform %s not found", platform_name.c_str());
+ return InvalidArgument("platform %s not found", platform_name);
}
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatformExceptFor(
@@ -146,23 +161,27 @@ PlatformUtil::GetSupportedPlatforms() {
TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
std::vector<se::Platform*> matched;
for (se::Platform* platform : platforms) {
- if (Lowercase(platform->Name()) != platform_name) {
+ if (absl::AsciiStrToLower(platform->Name()) != platform_name) {
matched.push_back(platform);
}
}
if (matched.empty()) {
return InvalidArgument("unable to find platform that is not %s",
- platform_name.c_str());
+ platform_name);
}
if (matched.size() == 1) {
- return matched[0];
+ auto platform = matched[0];
+ if (!platform->Initialized()) {
+ TF_RETURN_IF_ERROR(platform->Initialize({}));
+ }
+ return platform;
}
- string matched_string = tensorflow::str_util::Join(
+ string matched_string = absl::StrJoin(
matched, ", ",
[](string* out, const se::Platform* p) { out->append(p->Name()); });
return InvalidArgument(
"found multiple platforms %s, but expected one platform except for %s",
- matched_string.c_str(), platform_name.c_str());
+ matched_string, platform_name);
}
// Returns whether the device underlying the given StreamExecutor is supported
@@ -193,7 +212,7 @@ static bool IsDeviceSupported(se::StreamExecutor* executor) {
PlatformUtil::GetStreamExecutors(se::Platform* platform) {
int device_count = platform->VisibleDeviceCount();
if (device_count <= 0) {
- return NotFound("no %s devices found", platform->Name().c_str());
+ return NotFound("no %s devices found", platform->Name());
}
if (platform->id() == se::host::kHostPlatformId) {
// On host "devices", StreamExecutor exports a device for each hardware
@@ -232,7 +251,7 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) {
if (std::all_of(stream_executors.begin(), stream_executors.end(),
[](se::StreamExecutor* s) { return s == nullptr; })) {
return InternalError("no supported devices found for platform %s",
- platform->Name().c_str());
+ platform->Name());
}
return stream_executors;
}
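
Several call sites above use StrJoin's custom-formatter overload to join platform names; a standalone sketch, with Named as a hypothetical stand-in for se::Platform:

    #include <string>
    #include <vector>

    #include "absl/strings/str_join.h"

    struct Named {
      std::string name;
      const std::string& Name() const { return name; }
    };

    std::string JoinNames(const std::vector<const Named*>& items) {
      return absl::StrJoin(items, ", ",
                           [](std::string* out, const Named* p) {
                             out->append(p->Name());
                           });
    }
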
diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
index afde3cf95c..256b231e3a 100644
--- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h
+++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
@@ -59,7 +59,7 @@ class ReducePrecisionInsertion : public HloPassInterface {
~ReducePrecisionInsertion() override{};
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "reduce-precision-insertion";
}
diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc
index ca86c5d13e..4df746fca9 100644
--- a/tensorflow/compiler/xla/service/reshape_mover.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover.cc
@@ -38,6 +38,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include <algorithm>
+
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -374,7 +376,7 @@ StatusOr<bool> TryReshapeMoveOnCandidates(
removed = false;
for (auto operand : nontrivial_operands) {
- if (c_any_of(operand->users(), [&](HloInstruction* user) {
+ if (absl::c_any_of(operand->users(), [&](HloInstruction* user) {
return !reshape_candidates->count(user);
})) {
for (auto* user : operand->users()) {
diff --git a/tensorflow/compiler/xla/service/reshape_mover.h b/tensorflow/compiler/xla/service/reshape_mover.h
index 1f59e3b314..1e86a0823a 100644
--- a/tensorflow/compiler/xla/service/reshape_mover.h
+++ b/tensorflow/compiler/xla/service/reshape_mover.h
@@ -26,7 +26,7 @@ namespace xla {
// them inputward also.
class ReshapeMover : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "reshape-mover"; }
+ absl::string_view name() const override { return "reshape-mover"; }
StatusOr<bool> Run(HloModule* module) override;
};
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc
index ccb9fb3e3a..fcf269eee9 100644
--- a/tensorflow/compiler/xla/service/reshape_mover_test.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc
@@ -15,9 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/reshape_mover.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -28,13 +28,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-
-namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
-using ReshapeMoverTest = HloVerifiedTestBase;
+
+namespace op = xla::testing::opcode_matchers;
+
+class ReshapeMoverTest : public HloVerifiedTestBase {};
TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) {
HloComputation::Builder builder(TestName());
diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc
index 45ca731153..2f4b2667c4 100644
--- a/tensorflow/compiler/xla/service/scatter_expander.cc
+++ b/tensorflow/compiler/xla/service/scatter_expander.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/scatter_expander.h"
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
@@ -25,7 +26,6 @@ limitations under the License.
namespace xla {
-using tensorflow::gtl::ArraySlice;
// Transposes the given scatter_indices such that the index_vector_dim becomes
// the most-minor dimension.
@@ -86,13 +86,13 @@ static StatusOr<HloInstruction*> CanonicalizeScatterIndices(
// major dimensions and all the window dimensions appear in the minor
// dimensions.
static StatusOr<HloInstruction*> PermuteScatterAndWindowDims(
- HloInstruction* updates, ArraySlice<int64> update_window_dims) {
+ HloInstruction* updates, absl::Span<const int64> update_window_dims) {
std::vector<int64> permutation;
const int64 updates_rank = ShapeUtil::Rank(updates->shape());
permutation.reserve(updates_rank);
for (int64 i = 0; i < updates_rank; ++i) {
- bool is_scatter_dim = !c_binary_search(update_window_dims, i);
+ bool is_scatter_dim = !absl::c_binary_search(update_window_dims, i);
if (is_scatter_dim) {
permutation.push_back(i);
}
@@ -290,7 +290,7 @@ StatusOr<HloInstruction*> ScatterExpander::ExpandScatter(
return Unimplemented(
"Scatter operations with more than 2147483647 scatter indices are not "
"supported. This error occurred for %s.",
- scatter->ToString().c_str());
+ scatter->ToString());
}
// Canonicalize the scatter_indices, after which the size of its most-major
diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h
index 8f735e877d..14f062c89c 100644
--- a/tensorflow/compiler/xla/service/scatter_expander.h
+++ b/tensorflow/compiler/xla/service/scatter_expander.h
@@ -22,7 +22,7 @@ namespace xla {
class ScatterExpander : public HloPassInterface {
public:
- tensorflow::StringPiece name() const override { return "scatter_expander"; }
+ absl::string_view name() const override { return "scatter_expander"; }
StatusOr<bool> Run(HloModule* module) override;
private:
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 1dbf540d13..922ebdf0e3 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -20,10 +20,12 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@@ -46,8 +48,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -55,24 +55,22 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/ptr_util.h"
-using ::tensorflow::strings::Printf;
-using ::tensorflow::strings::StrCat;
-
namespace xla {
-
namespace {
+using absl::StrCat;
+using absl::StrFormat;
+
// Records the arguments used to invoke a computation in an HloSnapshot proto.
-Status RecordArguments(
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- se::Stream* stream, TransferManager* transfer_manager,
- HloSnapshot* module) {
+Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
+ se::Stream* stream, TransferManager* transfer_manager,
+ HloSnapshot* module) {
module->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> literal,
+ Literal literal,
transfer_manager->TransferLiteralFromDevice(stream, *argument));
- *module->add_arguments() = literal->ToProto();
+ *module->add_arguments() = literal.ToProto();
}
return Status::OK();
}
@@ -82,9 +80,9 @@ Status RecordResult(const ShapedBuffer& result, se::Stream* stream,
TransferManager* transfer_manager, HloSnapshot* module) {
module->clear_result();
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> literal,
+ Literal literal,
transfer_manager->TransferLiteralFromDevice(stream, result));
- *module->mutable_result() = literal->ToProto();
+ *module->mutable_result() = literal.ToProto();
return Status::OK();
}
@@ -148,19 +146,19 @@ Service::Service(const ServiceOptions& options,
CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas())
<< "Requested more replicas than there are devices.";
}
- LOG(INFO) << Printf(
+ LOG(INFO) << StrFormat(
"XLA service %p executing computations on platform %s. Devices:", this,
- execute_backend_->platform()->Name().c_str());
+ execute_backend_->platform()->Name());
for (int i = 0; i < execute_backend_->device_count(); ++i) {
if (execute_backend_->device_ordinal_supported(i)) {
se::StreamExecutor* executor =
execute_backend_->stream_executor(i).ValueOrDie();
const auto& description = executor->GetDeviceDescription();
- LOG(INFO) << Printf(" StreamExecutor device (%d): %s, %s", i,
- description.name().c_str(),
- description.platform_version().c_str());
+ LOG(INFO) << StrFormat(" StreamExecutor device (%d): %s, %s", i,
+ description.name(),
+ description.platform_version());
} else {
- LOG(INFO) << Printf(" StreamExecutor device (%d) not supported", i);
+ LOG(INFO) << StrFormat(" StreamExecutor device (%d) not supported", i);
}
}
} else {
@@ -200,16 +198,16 @@ Status Service::ValidateResultShape(const Shape& client_shape,
return InvalidArgument(
"Shape used to set computation result layout %s is not compatible "
"with result shape %s",
- ShapeUtil::HumanStringWithLayout(client_shape).c_str(),
- ShapeUtil::HumanString(result_shape).c_str());
+ ShapeUtil::HumanStringWithLayout(client_shape),
+ ShapeUtil::HumanString(result_shape));
}
return Status::OK();
}
StatusOr<std::vector<std::vector<const ShapedBuffer*>>>
Service::ResolveAndValidateArguments(
- tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors) {
+ absl::Span<const GlobalDataHandle* const> arguments,
+ absl::Span<se::StreamExecutor* const> stream_executors) {
CHECK_EQ(options_.number_of_replicas(), stream_executors.size());
std::vector<std::vector<const ShapedBuffer*>> replicated_arguments;
replicated_arguments.resize(options_.number_of_replicas());
@@ -231,9 +229,9 @@ Service::ResolveAndValidateArguments(
return InvalidArgument(
"argument %lu is on device %s:%d but computation will be executed "
"on device %s",
- i, shaped_buffer->platform()->Name().c_str(),
+ i, shaped_buffer->platform()->Name(),
shaped_buffer->device_ordinal(),
- execute_backend_->device_name(replica_device_ordinal).c_str());
+ execute_backend_->device_name(replica_device_ordinal));
}
replicated_arguments[replica].push_back(shaped_buffer);
}
@@ -243,13 +241,13 @@ Service::ResolveAndValidateArguments(
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape,
- tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
+ absl::Span<const Shape* const> argument_shapes,
const ExecutionOptions* execution_options) {
- auto config = MakeUnique<HloModuleConfig>(program_shape);
+ auto config = absl::make_unique<HloModuleConfig>(program_shape);
ComputationLayout* computation_layout =
config->mutable_entry_computation_layout();
if (program_shape.parameters_size() != argument_shapes.size()) {
- return InvalidArgument("computation takes %d parameters, but %zu given",
+ return InvalidArgument("computation takes %d parameters, but %u given",
program_shape.parameters_size(),
argument_shapes.size());
}
@@ -261,8 +259,8 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
return InvalidArgument(
"Argument does not match shape of computation parameter %d: want "
"%s, got %s",
- i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
- ShapeUtil::HumanString(*argument_shapes[i]).c_str());
+ i, ShapeUtil::HumanString(program_shape.parameters(i)),
+ ShapeUtil::HumanString(*argument_shapes[i]));
}
TF_RETURN_IF_ERROR(
computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
@@ -300,7 +298,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
const ExecutionOptions& execution_options) {
std::vector<const Shape*> argument_shapes;
for (const auto* arg : arguments) {
@@ -314,7 +312,7 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
DeviceMemoryAllocator* device_allocator) {
- VLOG(1) << Printf("BuildExecutable on service %p", this);
+ VLOG(1) << StrFormat("BuildExecutable on service %p", this);
// Dump computation proto state if flag is set.
std::vector<std::unique_ptr<HloSnapshot>> hlo_snapshots;
@@ -326,12 +324,11 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
if (directory_path.empty() && execution_directory_path.empty()) {
continue;
}
- auto hlo_snapshot = MakeUnique<HloSnapshot>();
+ auto hlo_snapshot = absl::make_unique<HloSnapshot>();
*hlo_snapshot->mutable_hlo()->mutable_hlo_module() = *module_protos[i];
if (!directory_path.empty()) {
- string filename =
- Printf("computation_%lld__%s", module_protos[i]->id(),
- module_protos[i]->entry_computation_name().c_str());
+ string filename = StrFormat("computation_%d__%s", module_protos[i]->id(),
+ module_protos[i]->entry_computation_name());
TF_RETURN_IF_ERROR(
Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot));
}
@@ -369,12 +366,10 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
StatusOr<std::vector<GlobalDataHandle>>
Service::ExecuteParallelAndRegisterResult(
- tensorflow::gtl::ArraySlice<Executable*> executables,
- tensorflow::gtl::ArraySlice<std::vector<std::vector<const ShapedBuffer*>>>
- arguments,
- Backend* backend, tensorflow::gtl::ArraySlice<DeviceHandle> device_handles,
- tensorflow::gtl::ArraySlice<string> result_tags,
- ExecutionProfile* profile) {
+ absl::Span<Executable* const> executables,
+ absl::Span<const std::vector<std::vector<const ShapedBuffer*>>> arguments,
+ Backend* backend, absl::Span<const DeviceHandle> device_handles,
+ absl::Span<const string> result_tags, ExecutionProfile* profile) {
// Streams where the computation are launched, so we can wait on the streams
// to complete.
std::vector<StreamPool::Ptr> streams;
@@ -409,7 +404,8 @@ Service::ExecuteParallelAndRegisterResult(
streams.push_back(std::move(stream));
if (replica == 0 && profile != nullptr) {
- timers.push_back(MakeUnique<se::Timer>(streams.back()->parent()));
+ timers.push_back(
+ absl::make_unique<se::Timer>(streams.back()->parent()));
streams.back()
->InitTimer(timers.back().get())
.ThenStartTimer(timers.back().get());
@@ -453,8 +449,8 @@ Service::ExecuteParallelAndRegisterResult(
for (int64 i = 0; i < streams.size(); ++i) {
Status block_status = streams[i]->BlockHostUntilDone();
if (!block_status.ok()) {
- return InternalError("failed to complete execution for stream %lld: %s",
- i, block_status.error_message().c_str());
+ return InternalError("failed to complete execution for stream %d: %s", i,
+ block_status.error_message());
}
}
@@ -512,8 +508,7 @@ Service::ExecuteParallelAndRegisterResult(
StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
Executable* executable,
- const tensorflow::gtl::ArraySlice<std::vector<const ShapedBuffer*>>
- arguments,
+ const absl::Span<const std::vector<const ShapedBuffer*>> arguments,
Backend* backend, const string& result_tag, ExecutionProfile* profile) {
// Set up streams.
std::vector<StreamPool::Ptr> streams;
@@ -556,8 +551,7 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
// TODO(b/69985541): Support profiling also on this path.
- std::vector<tensorflow::gtl::ArraySlice<const ShapedBuffer*>>
- replicated_arguments;
+ std::vector<absl::Span<const ShapedBuffer* const>> replicated_arguments;
for (const auto& arg : arguments) {
replicated_arguments.push_back(arg);
}
@@ -579,7 +573,7 @@ StatusOr<std::vector<se::StreamExecutor*>> Service::GetExecutors(
if (requests_size > 1 && execution_options.device_handles_size() > 1) {
return InvalidArgument(
"Parallel requests with multiple device handles is not supported. "
- "Found %lld parallel requests, with request %lld containing %d device "
+ "Found %d parallel requests, with request %d containing %d device "
"handles.",
requests_size, request_index, execution_options.device_handles_size());
}
@@ -596,7 +590,7 @@ StatusOr<std::vector<se::StreamExecutor*>> Service::GetExecutors(
StatusOr<std::vector<std::vector<const ShapedBuffer*>>> Service::GetArguments(
const ExecutionOptions& execution_options,
- tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments) {
+ absl::Span<const GlobalDataHandle* const> arguments) {
// Resolve the allocations for the arguments of the computation, and create
// a vector of device memory offsets for the arguments from the allocations.
// In the case of partitioned computations, assume all arguments go on the
@@ -744,8 +738,8 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
}
if (available_device_count < arg->device_count() * replica_count) {
return ResourceExhausted(
- "Requested device count (%lld) exceeds the number of available devices "
- "on the target (%lld)",
+ "Requested device count (%d) exceeds the number of available devices "
+ "on the target (%d)",
arg->device_count(), available_device_count);
}
@@ -795,12 +789,12 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
const HloModuleProto& module_proto,
std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) {
- VLOG(1) << Printf(
+ VLOG(1) << StrFormat(
"BuildExecutable on service %p with serialized module proto: %s", this,
- module_proto.name().c_str());
+ module_proto.name());
// Dump computation proto state if flag is set.
- auto hlo_snapshot = MakeUnique<HloSnapshot>();
+ auto hlo_snapshot = absl::make_unique<HloSnapshot>();
const string& directory_path =
module_config->debug_options().xla_dump_computations_to();
const string& execution_directory_path =
@@ -808,8 +802,8 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
if (!directory_path.empty() || !execution_directory_path.empty()) {
*hlo_snapshot->mutable_hlo()->mutable_hlo_module() = module_proto;
if (!directory_path.empty()) {
- string filename = Printf("computation_%lld__%s", module_proto.id(),
- module_proto.entry_computation_name().c_str());
+ string filename = StrFormat("computation_%d__%s", module_proto.id(),
+ module_proto.entry_computation_name());
TF_RETURN_IF_ERROR(
Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot));
}
@@ -934,16 +928,15 @@ Status Service::TransferToClient(const TransferToClientRequest* arg,
shaped_buffer->device_ordinal()));
TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> result_literal,
+ Literal result_literal,
execute_backend_->transfer_manager()->TransferLiteralFromDevice(
stream.get(), *shaped_buffer));
- if (LayoutUtil::LayoutsInShapesEqual(*return_shape,
- result_literal->shape())) {
- *result->mutable_literal() = result_literal->ToProto();
+ if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal.shape())) {
+ *result->mutable_literal() = result_literal.ToProto();
} else {
*result->mutable_literal() =
- result_literal->Relayout(*return_shape)->ToProto();
+ result_literal.Relayout(*return_shape).ToProto();
}
return Status::OK();
}
@@ -954,7 +947,7 @@ namespace {
// shape and DeviceMemoryBase values of the clone are identical to the original.
std::unique_ptr<ShapedBuffer> CloneShapedBufferOnDevice(
const ShapedBuffer& shaped_buffer, int device_ordinal) {
- auto clone = MakeUnique<ShapedBuffer>(
+ auto clone = absl::make_unique<ShapedBuffer>(
shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape(),
shaped_buffer.platform(), device_ordinal);
clone->buffers() = shaped_buffer.buffers();
@@ -965,9 +958,9 @@ std::unique_ptr<ShapedBuffer> CloneShapedBufferOnDevice(
Status Service::TransferToServer(const TransferToServerRequest* arg,
TransferToServerResponse* result) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+ TF_ASSIGN_OR_RETURN(Literal literal,
Literal::CreateFromProto(arg->literal()));
- const Shape& shape = literal->shape();
+ const Shape& shape = literal.shape();
std::vector<se::StreamExecutor*> replicas;
if (arg->has_device_handle()) {
@@ -989,7 +982,7 @@ Status Service::TransferToServer(const TransferToServerRequest* arg,
TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor));
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralToDevice(
- stream.get(), *literal, shaped_buffer));
+ stream.get(), literal, shaped_buffer));
replicated_buffers.emplace_back(std::move(shaped_buffer));
}
TF_ASSIGN_OR_RETURN(*result->mutable_data(),
@@ -1009,8 +1002,7 @@ Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
"%s",
StrCat("The replica_id=", arg->replica_id(),
" on TransferToInfeedRequest not in range [0, replica_count=",
- replica_count, ").")
- .c_str());
+ replica_count, ")."));
}
se::StreamExecutor* executor;
@@ -1025,10 +1017,10 @@ Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
executor = replicas[arg->replica_id()];
}
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+ TF_ASSIGN_OR_RETURN(Literal literal,
Literal::CreateFromProto(arg->literal()));
- return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
- executor, *literal);
+ return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor,
+ literal);
}
Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
@@ -1036,8 +1028,7 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
const int64 replica_count = options_.number_of_replicas();
if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
return FailedPrecondition(
- "The replica_id=%lld on TransferFromOutfeedRequest not in range [0, "
- "%lld)",
+ "The replica_id=%d on TransferFromOutfeedRequest not in range [0, %d)",
arg->replica_id(), replica_count);
}
@@ -1057,8 +1048,8 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
- executor, arg->shape_with_layout(), *literal));
- *result->mutable_literal() = literal->ToProto();
+ executor, arg->shape_with_layout(), literal));
+ *result->mutable_literal() = literal.ToProto();
return Status::OK();
}
@@ -1093,18 +1084,17 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
HloModule::CreateFromProto(arg->computation(), config));
HloEvaluator evaluator;
- TF_ASSIGN_OR_RETURN(auto result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, /*arg_literals=*/{}));
+ TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate<Literal>(
+ *module, /*arg_literals=*/{}));
// Since the result layout has no effect on the Evaluator's results, do an
// explicit relayout here.
//
// TODO(b/77824332): Make HloEvaluator take care of the re-layout.
if (arg->has_output_layout()) {
- result_literal = result_literal->Relayout(arg->output_layout());
+ result_literal = result_literal.Relayout(arg->output_layout());
}
- *result->mutable_literal() = result_literal->ToProto();
+ *result->mutable_literal() = result_literal.ToProto();
return Status::OK();
}
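
Throughout this file the Printf-to-StrFormat migration drops .c_str() calls and the %lld/%zu length modifiers; a brief sketch of why that is safe, with illustrative names:

    #include <cstdint>
    #include <string>

    #include "absl/strings/str_format.h"

    std::string Describe(const std::string& name, int64_t id,
                         size_t num_args) {
      // StrFormat deduces argument types: %s accepts std::string directly,
      // and %d/%u accept any integral width.
      return absl::StrFormat("computation_%d__%s takes %u arguments", id,
                             name, num_args);
    }
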
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 47d196fb2a..44c5248b15 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/allocation_tracker.h"
@@ -37,7 +38,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -176,7 +176,7 @@ class Service : public ServiceInterface {
// class.
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
const ExecutionOptions& execution_options);
// Picks a parallel response and fills the result.
@@ -191,7 +191,7 @@ class Service : public ServiceInterface {
// Prepare the arguments for executing parallel.
StatusOr<std::vector<std::vector<const ShapedBuffer*>>> GetArguments(
const ExecutionOptions& execution_options,
- tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments);
+ absl::Span<const GlobalDataHandle* const> arguments);
protected:
friend class LocalExecutable;
@@ -207,14 +207,14 @@ class Service : public ServiceInterface {
// the corresponding replica.
StatusOr<std::vector<std::vector<const ShapedBuffer*>>>
ResolveAndValidateArguments(
- tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors);
+ absl::Span<const GlobalDataHandle* const> arguments,
+ absl::Span<se::StreamExecutor* const> stream_executors);
// Create a Hlo module config for the given program shape and arguments.
// execution_options is optional; if not given a default is used.
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape,
- tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
+ absl::Span<const Shape* const> argument_shapes,
const ExecutionOptions* execution_options);
// Builds an Executable for the given parameters.
@@ -242,21 +242,17 @@ class Service : public ServiceInterface {
// ExecutionProfile object which will be filled in with profile data.
StatusOr<GlobalDataHandle> ExecuteAndRegisterResult(
Executable* executable,
- const tensorflow::gtl::ArraySlice<std::vector<const ShapedBuffer*>>
- arguments,
+ const absl::Span<const std::vector<const ShapedBuffer*>> arguments,
Backend* backend, const string& result_tag, ExecutionProfile* profile);
  // Runs the given executables with the given arguments and registers the result
// from each executable in the allocation tracker. The handles of the result
// from the tracker are returned.
StatusOr<std::vector<GlobalDataHandle>> ExecuteParallelAndRegisterResult(
- tensorflow::gtl::ArraySlice<Executable*> executables,
- tensorflow::gtl::ArraySlice<std::vector<std::vector<const ShapedBuffer*>>>
- arguments,
- Backend* backend,
- tensorflow::gtl::ArraySlice<DeviceHandle> device_handles,
- tensorflow::gtl::ArraySlice<string> result_tags,
- ExecutionProfile* profile);
+ absl::Span<Executable* const> executables,
+ absl::Span<const std::vector<std::vector<const ShapedBuffer*>>> arguments,
+ Backend* backend, absl::Span<const DeviceHandle> device_handles,
+ absl::Span<const string> result_tags, ExecutionProfile* profile);
// Executes a single computation which has more than one target device.
// The N devices are expected to all return an empty tuple, but one, which
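
Every signature change in this header follows the same pattern: tensorflow::gtl::ArraySlice<T> becomes absl::Span<const T>, and a slice of pointers becomes a span of const pointers (absl::Span<const ShapedBuffer* const>) so the elements cannot be reseated through the view. A small self-contained sketch of the call-site behavior, assuming Abseil is available; the function names are illustrative only:

#include <cstdint>
#include <iostream>
#include <vector>

#include "absl/types/span.h"

// Spans are cheap, non-owning views; vectors and arrays convert implicitly.
int64_t Sum(absl::Span<const int64_t> values) {
  int64_t total = 0;
  for (int64_t v : values) total += v;
  return total;
}

// Span of const pointers: neither the pointers nor the pointees can be
// modified through the view.
int64_t SumThroughPointers(absl::Span<const int64_t* const> ptrs) {
  int64_t total = 0;
  for (const int64_t* p : ptrs) total += *p;
  return total;
}

int main() {
  std::vector<int64_t> v = {1, 2, 3};
  std::vector<const int64_t*> ps = {&v[0], &v[1], &v[2]};
  std::cout << Sum(v) << " " << SumThroughPointers(ps) << "\n";  // 6 6
  return 0;
}
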
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index ec5743a777..74bdf2a2e3 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -21,6 +21,11 @@ limitations under the License.
#include <set>
#include <string>
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -28,44 +33,37 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/math/math_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
-using tensorflow::str_util::Join;
-using tensorflow::strings::Printf;
-
namespace xla {
-
namespace {
+using absl::StrFormat;
+using absl::StrJoin;
+
// Returns true if no element is present in slice more than once.
-bool AllUnique(tensorflow::gtl::ArraySlice<int64> slice) {
+bool AllUnique(absl::Span<const int64> slice) {
return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
}
-Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) {
+Status ExpectArray(const Shape& shape, absl::string_view op_type) {
if (!ShapeUtil::IsArray(shape)) {
return InvalidArgument("Expected array argument for %s, but got %s.",
- std::string(op_type).c_str(),
- ShapeUtil::HumanString(shape).c_str());
+ string(op_type), ShapeUtil::HumanString(shape));
}
return Status::OK();
}
-Status VerifyReducerShape(
- const ProgramShape& reducer_shape,
- tensorflow::gtl::ArraySlice<const Shape*> init_value_shapes,
- tensorflow::gtl::ArraySlice<PrimitiveType> input_element_types,
- int64 inputs) {
+Status VerifyReducerShape(const ProgramShape& reducer_shape,
+ absl::Span<const Shape* const> init_value_shapes,
+ absl::Span<const PrimitiveType> input_element_types,
+ int64 inputs) {
if (reducer_shape.parameters_size() != inputs * 2) {
return InvalidArgument(
- "Reduction function must take %lld parameters, but "
+ "Reduction function must take %d parameters, but "
"takes %d parameter(s).",
inputs * 2, reducer_shape.parameters_size());
}
@@ -75,7 +73,7 @@ Status VerifyReducerShape(
if (ShapeUtil::IsArray(accumulator_shape)) {
if (inputs != 1) {
return InvalidArgument(
- "Reduction function must produce a tuple with %lld elements, but "
+ "Reduction function must produce a tuple with %d elements, but "
"produces a scalar",
inputs);
}
@@ -83,8 +81,8 @@ Status VerifyReducerShape(
} else if (ShapeUtil::IsTuple(accumulator_shape)) {
if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) {
return InvalidArgument(
- "Reduction function must produce a tuple with %lld elements, but has "
- "%lld elements",
+ "Reduction function must produce a tuple with %d elements, but has "
+ "%d elements",
inputs, ShapeUtil::TupleElementCount(accumulator_shape));
}
for (const Shape& element_shape : accumulator_shape.tuple_shapes()) {
@@ -94,7 +92,7 @@ Status VerifyReducerShape(
return InvalidArgument(
"Reduction function must produce a scalar or tuple of scalars, but has "
"shape: %s",
- ShapeUtil::HumanString(accumulator_shape).c_str());
+ ShapeUtil::HumanString(accumulator_shape));
}
for (const Shape* element_shape : accumulator_subshapes) {
@@ -102,7 +100,7 @@ Status VerifyReducerShape(
return InvalidArgument(
"Reduction function must return a scalar or tuple of scalars but "
"returns shape: %s",
- ShapeUtil::HumanString(accumulator_shape).c_str());
+ ShapeUtil::HumanString(accumulator_shape));
}
}
@@ -113,19 +111,19 @@ Status VerifyReducerShape(
if (!ShapeUtil::Compatible(*accumulator_subshapes[i],
reducer_shape.parameters(i))) {
return InvalidArgument(
- "Reduction function's %lld-th parameter shape differs from the "
+ "Reduction function's %d-th parameter shape differs from the "
"result shape: %s vs %s",
- i, ShapeUtil::HumanString(reducer_shape.parameters(i)).c_str(),
- ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str());
+ i, ShapeUtil::HumanString(reducer_shape.parameters(i)),
+ ShapeUtil::HumanString(*accumulator_subshapes[i]));
}
// Check that init_value's shapes are suitable for reducer_shape.
if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i],
*init_value_shapes[i])) {
return InvalidArgument(
- "Reduction function's accumulator shape at index %lld differs from "
+ "Reduction function's accumulator shape at index %d differs from "
"the init_value shape: %s vs %s",
- i, ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str(),
- ShapeUtil::HumanString(*init_value_shapes[i]).c_str());
+ i, ShapeUtil::HumanString(*accumulator_subshapes[i]),
+ ShapeUtil::HumanString(*init_value_shapes[i]));
}
// Check that the inputs can be passed in as the non-accumulator arguments.
const Shape input_element_shape =
@@ -133,11 +131,11 @@ Status VerifyReducerShape(
if (!ShapeUtil::CompatibleIgnoringFpPrecision(
input_element_shape, reducer_shape.parameters(inputs + i))) {
return InvalidArgument(
- "Reduction function's %lld-th parameter shape differs from the "
+ "Reduction function's %d-th parameter shape differs from the "
"input type element type: %s vs %s",
inputs + i,
- ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(),
- ShapeUtil::HumanString(input_element_shape).c_str());
+ ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
+ ShapeUtil::HumanString(input_element_shape));
}
// Check that the accumulator and inputs to the reducer function match.
// If the accumulator is scalar, it must have the same type as the inputs
@@ -147,11 +145,11 @@ Status VerifyReducerShape(
if (!ShapeUtil::CompatibleIgnoringFpPrecision(
*accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) {
return InvalidArgument(
- "Reduction function's %lld-th parameter shape must "
+ "Reduction function's %d-th parameter shape must "
"match the result shape, but got %s vs %s.",
inputs + i,
- ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(),
- ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str());
+ ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
+ ShapeUtil::HumanString(*accumulator_subshapes[i]));
}
}
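
Most of the string changes in this file fall out of replacing tensorflow::strings::Printf with absl::StrFormat: %d accepts any integral width (so the %lld specifiers for int64 go away), %s accepts std::string and absl::string_view directly (so the .c_str() calls go away), and argument/specifier mismatches are rejected at compile time. A brief illustration, assuming Abseil is available:

#include <cstdint>
#include <iostream>
#include <string>

#include "absl/strings/str_format.h"

int main() {
  int64_t inputs = 2;
  std::string shape = "f32[4,4]";

  // %d works for any integral type, including int64; no %lld needed.
  // %s takes std::string directly; no .c_str() needed.
  std::string msg = absl::StrFormat(
      "Reduction function must take %d parameters, but shape is %s.",
      inputs * 2, shape);
  std::cout << msg << "\n";

  // A type mismatch such as absl::StrFormat("%d", shape) would be a
  // compile-time error, where printf-style formatting only failed at run time.
  return 0;
}
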
@@ -164,7 +162,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
bool allow_negative_padding) {
if (window.dimensions_size() != ShapeUtil::Rank(base_shape)) {
return InvalidArgument(
- "Window has dimension %d but base shape has dimension %lld.",
+ "Window has dimension %d but base shape has dimension %d.",
window.dimensions_size(), ShapeUtil::Rank(base_shape));
}
@@ -173,29 +171,29 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
const auto& dim = window.dimensions(i);
if (dim.size() <= 0) {
return InvalidArgument("Window %s has a non-positive dimension.",
- window.DebugString().c_str());
+ window.DebugString());
}
if (dim.stride() <= 0) {
return InvalidArgument("Window %s has a non-positive stride.",
- window.DebugString().c_str());
+ window.DebugString());
}
if (!allow_negative_padding && dim.padding_low() < 0) {
return InvalidArgument("Window %s has a negative low padding.",
- window.DebugString().c_str());
+ window.DebugString());
}
if (!allow_negative_padding && dim.padding_high() < 0) {
return InvalidArgument("Window %s has a negative high padding.",
- window.DebugString().c_str());
+ window.DebugString());
}
if (dim.base_dilation() < 1) {
return InvalidArgument(
"Window %s has a non-positive base area dilation factor.",
- window.DebugString().c_str());
+ window.DebugString());
}
if (dim.window_dilation() < 1) {
return InvalidArgument(
"Window %s has a non-positive window dilation factor.",
- window.DebugString().c_str());
+ window.DebugString());
}
const int64 dilated_base = window_util::DilatedBound(
@@ -233,11 +231,12 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
switch (opcode) {
case HloOpcode::kFloor:
case HloOpcode::kCeil:
+ case HloOpcode::kRoundNearestAfz:
if (!ShapeUtil::ElementIsFloating(shape)) {
return InvalidArgument(
- "Expected element type in shape to be floating for floor/ceil "
- "operation; got %s.",
- PrimitiveType_Name(shape.element_type()).c_str());
+ "Expected element type in shape to be floating for %s operation; "
+ "got %s.",
+ HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
}
return shape;
case HloOpcode::kCos:
@@ -250,9 +249,9 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
if (!ShapeUtil::ElementIsFloating(shape) &&
!ShapeUtil::ElementIsComplex(shape)) {
return InvalidArgument(
- "Expected element type in shape to be floating or complex for "
- "sin/cos/exp/log/tanh operation; got %s.",
- PrimitiveType_Name(shape.element_type()).c_str());
+ "Expected element type in shape to be floating or complex for %s "
+ "operation; got %s.",
+ HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
}
return shape;
case HloOpcode::kReal:
@@ -264,19 +263,47 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
} else {
return InvalidArgument(
"Expected element type in shape to be floating or complex for "
- "real/imag operation; got %s.",
- PrimitiveType_Name(shape.element_type()).c_str());
+ "%s operation; got %s.",
+ HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
}
case HloOpcode::kAbs:
if (ShapeUtil::ElementIsComplex(shape)) {
return ShapeUtil::ChangeElementType(
shape, primitive_util::ComplexComponentType(shape.element_type()));
+ } else if (ShapeUtil::ElementIsSigned(shape)) {
+ return shape;
+ } else {
+ return InvalidArgument(
+ "Expected element type in shape to be floating or complex for "
+ "%s operation; got %s.",
+ HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
}
- return shape;
case HloOpcode::kClz:
+ if (!ShapeUtil::ElementIsIntegral(shape)) {
+ return InvalidArgument(
+ "Expected an integral element type in argument to Clz "
+ "operation; got %s.",
+ PrimitiveType_Name(shape.element_type()));
+ }
+ return shape;
case HloOpcode::kNegate:
- case HloOpcode::kRoundNearestAfz:
+ if (!ShapeUtil::ElementIsIntegral(shape) &&
+ !ShapeUtil::ElementIsFloating(shape) &&
+ !ShapeUtil::ElementIsComplex(shape)) {
+ return InvalidArgument(
+ "Expected element type in shape to be integral, floating or "
+ "complex for %s operation; got %s.",
+ HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
+ }
+ return shape;
case HloOpcode::kSign:
+ if (!ShapeUtil::ElementIsSigned(shape) &&
+ !ShapeUtil::ElementIsComplex(shape)) {
+ return InvalidArgument(
+ "Expected element type in shape to be signed or complex for "
+ "%s operation; got %s.",
+ HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
+ }
return shape;
case HloOpcode::kNot:
@@ -285,7 +312,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return InvalidArgument(
"Expected pred or an integral element type in argument to Not "
"operation; got %s.",
- PrimitiveType_Name(shape.element_type()).c_str());
+ PrimitiveType_Name(shape.element_type()));
}
return shape;
@@ -295,25 +322,24 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
"Expected element type in shape to be floating "
"point for IsFinite "
"operation; got %s.",
- PrimitiveType_Name(shape.element_type()).c_str());
+ PrimitiveType_Name(shape.element_type()));
}
return ShapeUtil::ChangeElementType(shape, PRED);
default:
return InvalidArgument(
"Unknown operation for unary shape inference: \"%s\".",
- HloOpcodeString(opcode).c_str());
+ HloOpcodeString(opcode));
}
}
/* static */ StatusOr<Shape> ShapeInference::InferConcatOpShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const int64 dimension) {
+ absl::Span<const Shape* const> arg_shapes, const int64 dimension) {
if (arg_shapes.empty()) {
return InvalidArgument("Concatenate expects at least one argument.");
}
if (dimension < 0 || dimension >= ShapeUtil::Rank(*arg_shapes[0])) {
- return InvalidArgument("Concatenate dimension out of bounds: %lld.",
+ return InvalidArgument("Concatenate dimension out of bounds: %d.",
dimension);
}
const Shape* arg_shape = nullptr;
@@ -327,17 +353,16 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
}
if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) {
return InvalidArgument(
- "Cannot concatenate arrays with different ranks: %lld (%s) vs %lld "
+ "Cannot concatenate arrays with different ranks: %d (%s) vs %d "
"(%s).",
- ShapeUtil::Rank(*arg_shape),
- ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape),
- ShapeUtil::HumanString(*shape).c_str());
+ ShapeUtil::Rank(*arg_shape), ShapeUtil::HumanString(*arg_shape),
+ ShapeUtil::Rank(*shape), ShapeUtil::HumanString(*shape));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) {
return InvalidArgument(
"Cannot concatenate arrays with different element types: %s vs %s.",
- PrimitiveType_Name(arg_shape->element_type()).c_str(),
- PrimitiveType_Name(shape->element_type()).c_str());
+ PrimitiveType_Name(arg_shape->element_type()),
+ PrimitiveType_Name(shape->element_type()));
}
for (int64 dimension_number = 0;
dimension_number < ShapeUtil::Rank(*arg_shape); ++dimension_number) {
@@ -350,9 +375,9 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return InvalidArgument(
"Cannot concatenate arrays that differ in dimensions other than "
"the one being concatenated (the other array dimensions must be "
- "the same): %s vs %s in dimension %lld.",
- ShapeUtil::HumanString(*arg_shape).c_str(),
- ShapeUtil::HumanString(*shape).c_str(), dimension);
+ "the same): %s vs %s in dimension %d.",
+ ShapeUtil::HumanString(*arg_shape), ShapeUtil::HumanString(*shape),
+ dimension);
}
}
element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape);
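
The concatenate checks above require equal ranks and compatible element types, with every dimension except the concatenated one matching exactly; only the concatenated dimension is summed. A hand-rolled sketch of that rule — a hypothetical helper, not the XLA API:

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Infers the result dimensions of concatenating `shapes` along `dimension`,
// mirroring the checks above: equal ranks, equal non-concat dimensions.
std::vector<int64_t> InferConcatDims(
    const std::vector<std::vector<int64_t>>& shapes, int64_t dimension) {
  if (shapes.empty()) throw std::invalid_argument("need at least one operand");
  std::vector<int64_t> result = shapes[0];
  if (dimension < 0 || dimension >= static_cast<int64_t>(result.size()))
    throw std::invalid_argument("concatenate dimension out of bounds");
  for (size_t i = 1; i < shapes.size(); ++i) {
    if (shapes[i].size() != result.size())
      throw std::invalid_argument("ranks differ");
    for (size_t d = 0; d < result.size(); ++d) {
      if (static_cast<int64_t>(d) == dimension) {
        result[d] += shapes[i][d];  // only this dimension accumulates
      } else if (shapes[i][d] != result[d]) {
        throw std::invalid_argument("non-concatenated dimensions differ");
      }
    }
  }
  return result;
}

int main() {
  // f32[2,3] concat f32[4,3] along dimension 0 -> f32[6,3].
  auto dims = InferConcatDims({{2, 3}, {4, 3}}, 0);
  std::cout << dims[0] << "," << dims[1] << "\n";  // 6,3
  return 0;
}
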
@@ -367,7 +392,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
}
/* static */ StatusOr<Shape> ShapeInference::InferAfterAllShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes) {
+ absl::Span<const Shape* const> arg_shapes) {
for (const Shape* arg_shape : arg_shapes) {
if (arg_shape->element_type() != TOKEN) {
return InvalidArgument(
@@ -384,8 +409,8 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
!primitive_util::IsComplexType(new_element_type)) {
return Unimplemented(
"Conversion from complex to real type %s => %s is not implemented.",
- ShapeUtil::HumanString(operand_shape).c_str(),
- PrimitiveType_Name(new_element_type).c_str());
+ ShapeUtil::HumanString(operand_shape),
+ PrimitiveType_Name(new_element_type));
}
if (!ShapeUtil::IsArray(operand_shape) ||
!primitive_util::IsArrayType(new_element_type)) {
@@ -394,8 +419,8 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
// are valid. For now we just reject them, though.
return InvalidArgument(
"Convert does not allow non-arrays, so cannot convert from %s to %s.",
- ShapeUtil::HumanString(operand_shape).c_str(),
- PrimitiveType_Name(new_element_type).c_str());
+ ShapeUtil::HumanString(operand_shape),
+ PrimitiveType_Name(new_element_type));
}
return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
@@ -407,8 +432,8 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
if (primitive_util::IsComplexType(old_element_type) !=
primitive_util::IsComplexType(new_element_type)) {
return InvalidArgument("Conversion from complex to real type %s => %s.",
- ShapeUtil::HumanString(operand_shape).c_str(),
- PrimitiveType_Name(new_element_type).c_str());
+ ShapeUtil::HumanString(operand_shape),
+ PrimitiveType_Name(new_element_type));
}
if (!ShapeUtil::IsArray(operand_shape) ||
!primitive_util::IsArrayType(new_element_type)) {
@@ -417,15 +442,15 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
// are valid. For now we just reject them, though.
return InvalidArgument(
"Cannot convert from or to tuple type; requested conversion: %s => %s.",
- ShapeUtil::HumanString(operand_shape).c_str(),
- PrimitiveType_Name(new_element_type).c_str());
+ ShapeUtil::HumanString(operand_shape),
+ PrimitiveType_Name(new_element_type));
}
if (primitive_util::BitWidth(old_element_type) !=
primitive_util::BitWidth(new_element_type)) {
return InvalidArgument(
"Cannot bitcast types with different bit-widths: %s => %s.",
- PrimitiveType_Name(old_element_type).c_str(),
- PrimitiveType_Name(new_element_type).c_str());
+ PrimitiveType_Name(old_element_type),
+ PrimitiveType_Name(new_element_type));
}
return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
@@ -438,7 +463,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return InvalidArgument(
"Expected element type in shape to be floating point for "
"ReducePrecision operation; got %s.",
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(operand_shape.element_type()));
}
if (exponent_bits < 1) {
// One exponent bit is necessary to distinguish 0 from infinity. Having
@@ -470,21 +495,29 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return InvalidArgument(
"The rank of the operand and the padding configuration do not match: "
"%s vs %s.",
- ShapeUtil::HumanString(operand_shape).c_str(),
- padding_config.ShortDebugString().c_str());
+ ShapeUtil::HumanString(operand_shape),
+ padding_config.ShortDebugString());
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
padding_value_shape)) {
return InvalidArgument(
"The element types of the operands to Pad do not match.");
}
+ if (absl::c_any_of(padding_config.dimensions(),
+ [](const PaddingConfig::PaddingConfigDimension& p) {
+ return p.interior_padding() < 0;
+ })) {
+ return InvalidArgument("Interior padding cannot be negative: %s",
+ padding_config.ShortDebugString());
+ }
+
std::vector<int64> dimensions(ShapeUtil::Rank(operand_shape));
for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
- dimensions[i] = operand_shape.dimensions(i) +
- padding_config.dimensions(i).edge_padding_low() +
- padding_config.dimensions(i).edge_padding_high() +
+ const auto& p = padding_config.dimensions(i);
+ dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() +
+ p.edge_padding_high() +
std::max<int64>(operand_shape.dimensions(i) - 1, 0LL) *
- padding_config.dimensions(i).interior_padding();
+ p.interior_padding();
}
return ShapeUtil::MakeShape(
ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape),
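
The new guard rejects negative interior padding before the size computation runs; the computation itself is unchanged: each dimension grows by the edge padding on both sides plus one interior-padding run between every pair of existing elements. A worked sketch of the arithmetic, as a hypothetical helper:

#include <algorithm>
#include <cstdint>
#include <iostream>

// Mirrors the dimension arithmetic above:
//   out = in + edge_low + edge_high + max(in - 1, 0) * interior.
int64_t PaddedDim(int64_t in, int64_t edge_low, int64_t edge_high,
                  int64_t interior) {
  return in + edge_low + edge_high + std::max<int64_t>(in - 1, 0) * interior;
}

int main() {
  // 5 elements, 1 low + 2 high edge padding, 1 interior element between
  // neighbors: 5 + 1 + 2 + 4*1 = 12.
  std::cout << PaddedDim(5, 1, 2, 1) << "\n";  // 12
  // Interior padding never applies to empty or single-element dimensions.
  std::cout << PaddedDim(1, 0, 0, 3) << "\n";  // 1
  return 0;
}
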
@@ -515,22 +548,22 @@ Status ValidateDotDimensionNumbers(
const Shape& lhs, const Shape& rhs,
const DotDimensionNumbers& dimension_numbers) {
// Check that dimension numbers are in range.
- auto dims_in_range =
- [](const int64 rank, tensorflow::gtl::ArraySlice<int64> contracting_dims,
- tensorflow::gtl::ArraySlice<int64> batch_dims) -> bool {
+ auto dims_in_range = [](const int64 rank,
+ absl::Span<const int64> contracting_dims,
+ absl::Span<const int64> batch_dims) -> bool {
auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; };
return std::all_of(contracting_dims.begin(), contracting_dims.end(),
in_range) &&
std::all_of(batch_dims.begin(), batch_dims.end(), in_range);
};
- tensorflow::gtl::ArraySlice<int64> lhs_contracting_dimensions =
+ absl::Span<const int64> lhs_contracting_dimensions =
AsInt64Slice(dimension_numbers.lhs_contracting_dimensions());
- tensorflow::gtl::ArraySlice<int64> rhs_contracting_dimensions =
+ absl::Span<const int64> rhs_contracting_dimensions =
AsInt64Slice(dimension_numbers.rhs_contracting_dimensions());
- tensorflow::gtl::ArraySlice<int64> lhs_batch_dimensions =
+ absl::Span<const int64> lhs_batch_dimensions =
AsInt64Slice(dimension_numbers.lhs_batch_dimensions());
- tensorflow::gtl::ArraySlice<int64> rhs_batch_dimensions =
+ absl::Span<const int64> rhs_batch_dimensions =
AsInt64Slice(dimension_numbers.rhs_batch_dimensions());
if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions,
@@ -538,12 +571,12 @@ Status ValidateDotDimensionNumbers(
!dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions,
rhs_batch_dimensions)) {
return InvalidArgument("A dimension number is out of range in Dot: %s.",
- dimension_numbers.DebugString().c_str());
+ dimension_numbers.DebugString());
}
// Check that dimension numbers are unique.
- auto dims_unique = [](tensorflow::gtl::ArraySlice<int64> contracting_dims,
- tensorflow::gtl::ArraySlice<int64> batch_dims) -> bool {
+ auto dims_unique = [](absl::Span<const int64> contracting_dims,
+ absl::Span<const int64> batch_dims) -> bool {
tensorflow::gtl::FlatSet<int64> dim_set;
auto is_unique = [&dim_set](int64 i) -> bool {
return dim_set.insert(i).second;
@@ -556,7 +589,7 @@ Status ValidateDotDimensionNumbers(
if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) ||
!dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) {
return InvalidArgument("A dimension number is not unique in Dot: %s.",
- dimension_numbers.DebugString().c_str());
+ dimension_numbers.DebugString());
}
// Check that the count of non-contracting-non-batch dimensions is in {0, 1}.
@@ -601,14 +634,13 @@ Status ValidateDotDimensionNumbers(
TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot"));
auto fail = [lhs, rhs](const string& addendum) -> Status {
- string message = tensorflow::strings::Printf(
- "Cannot infer shape for dot operation: %s <dot> %s.",
- ShapeUtil::HumanString(lhs).c_str(),
- ShapeUtil::HumanString(rhs).c_str());
+ string message =
+ StrFormat("Cannot infer shape for dot operation: %s <dot> %s.",
+ ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs));
if (!addendum.empty()) {
message += " " + addendum;
}
- return InvalidArgument("%s", message.c_str());
+ return InvalidArgument("%s", message);
};
// Check if both element types are the same.
@@ -704,9 +736,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
} else {
return InvalidArgument(
"Binary op %s with incompatible shapes: %s and %s.",
- HloOpcodeString(operation).c_str(),
- ShapeUtil::HumanString(lhs).c_str(),
- ShapeUtil::HumanString(rhs).c_str());
+ HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
+ ShapeUtil::HumanString(rhs));
}
}
return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
@@ -715,20 +746,20 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
const Shape& smaller_shape, const Shape& larger_shape,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
// Reject "magic" inference for binops on different shapes, requiring
// the user to provide an explicit broadcast dimension in this case.
// See b/25177275 for more details.
return InvalidArgument("Automatic shape inference not supported: %s and %s",
- ShapeUtil::HumanString(smaller_shape).c_str(),
- ShapeUtil::HumanString(larger_shape).c_str());
+ ShapeUtil::HumanString(smaller_shape),
+ ShapeUtil::HumanString(larger_shape));
} else if (broadcast_dimensions.size() != ShapeUtil::Rank(smaller_shape)) {
return InvalidArgument(
"Size of broadcast_dimensions has to match lower-rank operand's "
"rank; "
- " lower-rank operand's rank is %lld, size of broadcast_dimensions is "
- "%zu.",
+ " lower-rank operand's rank is %d, size of broadcast_dimensions is "
+ "%u.",
ShapeUtil::Rank(smaller_shape), broadcast_dimensions.size());
}
@@ -778,12 +809,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
int64 dimension_to_match = broadcast_dimensions.at(i);
if (dimension_to_match < 0) {
return InvalidArgument(
- "Broadcast dimension number (%lld) cannot be negative.",
+ "Broadcast dimension number (%d) cannot be negative.",
dimension_to_match);
}
if (dimension_to_match >= larger_shape.dimensions_size()) {
return InvalidArgument(
- "Broadcast dimension number (%lld) too large; higher-rank "
+ "Broadcast dimension number (%d) too large; higher-rank "
"operand has rank %d.",
dimension_to_match, larger_shape.dimensions_size());
}
@@ -795,16 +826,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (small_dimension_size != large_dimension_size &&
small_dimension_size != 1 && large_dimension_size != 1) {
return InvalidArgument(
- "Broadcast dimension %d mismatch: %lld != %lld; %s and %s.", i,
+ "Broadcast dimension %d mismatch: %d != %d; %s and %s.", i,
small_dimension_size, large_dimension_size,
- ShapeUtil::HumanString(smaller_shape).c_str(),
- ShapeUtil::HumanString(larger_shape).c_str());
+ ShapeUtil::HumanString(smaller_shape),
+ ShapeUtil::HumanString(larger_shape));
}
// Make sure the broadcast dimensions are listed in a strictly increasing
// order.
if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) {
return InvalidArgument(
- "Broadcast dimensions order is wrong: %lld comes after %lld.",
+ "Broadcast dimensions order is wrong: %d comes after %d.",
dimension_to_match, broadcast_dimensions.at(i - 1));
}
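
Taken together, the broadcast checks above enforce three rules: each broadcast dimension number must name a valid dimension of the higher-rank operand, the paired dimension sizes must match unless one of them is 1, and the dimension numbers must be strictly increasing. A compact sketch of those rules as a hypothetical helper:

#include <cstdint>
#include <iostream>
#include <vector>

// Returns true when a lower-rank shape can broadcast in-dim against the
// given dimensions of a higher-rank shape: sizes match or one side is 1,
// and the broadcast dimension numbers strictly increase.
bool InDimBroadcastOk(const std::vector<int64_t>& small_dims,
                      const std::vector<int64_t>& large_dims,
                      const std::vector<int64_t>& broadcast_dims) {
  if (broadcast_dims.size() != small_dims.size()) return false;
  for (size_t i = 0; i < broadcast_dims.size(); ++i) {
    int64_t d = broadcast_dims[i];
    if (d < 0 || d >= static_cast<int64_t>(large_dims.size())) return false;
    if (i > 0 && broadcast_dims[i - 1] >= d) return false;  // must increase
    if (small_dims[i] != large_dims[d] && small_dims[i] != 1 &&
        large_dims[d] != 1)
      return false;
  }
  return true;
}

int main() {
  // f32[3] against f32[2,3] with broadcast_dimensions={1}: ok.
  std::cout << InDimBroadcastOk({3}, {2, 3}, {1}) << "\n";  // 1
  // Mismatched size with neither side 1: rejected.
  std::cout << InDimBroadcastOk({4}, {2, 3}, {1}) << "\n";  // 0
  return 0;
}
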
@@ -816,15 +847,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
HloOpcode operation, const Shape& lhs, const Shape& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ absl::Span<const int64> broadcast_dimensions) {
TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation"));
TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation"));
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return InvalidArgument(
"Binary op %s with different element types: %s and %s.",
- HloOpcodeString(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(),
- ShapeUtil::HumanString(rhs).c_str());
+ HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
+ ShapeUtil::HumanString(rhs));
}
if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) {
@@ -873,21 +904,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
HloOpcode opcode, const Shape& lhs, const Shape& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- VLOG(2) << tensorflow::strings::Printf(
+ absl::Span<const int64> broadcast_dimensions) {
+ VLOG(2) << StrFormat(
"inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
- HloOpcodeString(opcode).c_str(), ShapeUtil::HumanString(lhs).c_str(),
- ShapeUtil::HumanString(rhs).c_str(),
- Join(broadcast_dimensions, ", ").c_str());
+ HloOpcodeString(opcode), ShapeUtil::HumanString(lhs),
+ ShapeUtil::HumanString(rhs), StrJoin(broadcast_dimensions, ", "));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
- TF_RETURN_IF_ERROR(
- ExpectArray(lhs, tensorflow::strings::StrCat("lhs of binary operation ",
- HloOpcodeString(opcode))));
- TF_RETURN_IF_ERROR(
- ExpectArray(rhs, tensorflow::strings::StrCat("rhs of binary operation ",
- HloOpcodeString(opcode))));
+ TF_RETURN_IF_ERROR(ExpectArray(
+ lhs, absl::StrCat("lhs of binary operation ", HloOpcodeString(opcode))));
+ TF_RETURN_IF_ERROR(ExpectArray(
+ rhs, absl::StrCat("rhs of binary operation ", HloOpcodeString(opcode))));
switch (opcode) {
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
@@ -909,7 +937,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Expected element type in shape to be floating for complex compose "
"operation; got %s.",
- PrimitiveType_Name(lhs.element_type()).c_str());
+ PrimitiveType_Name(lhs.element_type()));
}
TF_ASSIGN_OR_RETURN(const Shape& shape,
InferElementwiseBinaryOpShape(opcode, lhs, rhs,
@@ -928,7 +956,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Expected pred or integral type in argument to and/or operation; "
"got %s.",
- PrimitiveType_Name(lhs.element_type()).c_str());
+ PrimitiveType_Name(lhs.element_type()));
}
return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
broadcast_dimensions);
@@ -946,8 +974,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
default:
return Unimplemented(
"Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.",
- HloOpcodeString(opcode).c_str(), lhs.ShortDebugString().c_str(),
- rhs.ShortDebugString().c_str());
+ HloOpcodeString(opcode), lhs.ShortDebugString(),
+ rhs.ShortDebugString());
}
}
@@ -970,14 +998,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
case HloOpcode::kTupleSelect:
return InferTupleSelectShape(lhs, rhs, ehs);
default:
- return InvalidArgument("Unknown operation %s.",
- HloOpcodeString(opcode).c_str());
+ return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
}
}
/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
- HloOpcode opcode,
- tensorflow::gtl::ArraySlice<const HloInstruction*> operands) {
+ HloOpcode opcode, absl::Span<const HloInstruction* const> operands) {
std::vector<const Shape*> operand_shapes;
operand_shapes.reserve(operands.size());
for (const HloInstruction* operand : operands) {
@@ -987,8 +1013,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
- HloOpcode opcode,
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
+ HloOpcode opcode, absl::Span<const Shape* const> operand_shapes) {
for (const Shape* shape : operand_shapes) {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape));
}
@@ -1010,8 +1035,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Sort keys and values dimensions must match. "
"Keys shape is: %s\n, Values shape is: %s",
- ShapeUtil::HumanString(*operand_shapes[0]).c_str(),
- ShapeUtil::HumanString(*operand_shapes[1]).c_str());
+ ShapeUtil::HumanString(*operand_shapes[0]),
+ ShapeUtil::HumanString(*operand_shapes[1]));
}
return ShapeUtil::MakeTupleShape(
{*operand_shapes[0], *operand_shapes[1]});
@@ -1019,15 +1044,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument("Unexpected number of operands for sort");
}
default:
- return InvalidArgument("Unknown operation %s.",
- HloOpcodeString(opcode).c_str());
+ return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
}
}
/* static */ StatusOr<Shape> ShapeInference::InferMapShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const ProgramShape& to_apply,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
+ absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
+ absl::Span<const int64> dimensions) {
if (arg_shapes.empty()) {
return InvalidArgument("Map expects at least one argument.");
}
@@ -1058,7 +1081,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Map operation requires all operands to have the same shape; got: "
"%s.",
- Join(pieces, ", ").c_str());
+ StrJoin(pieces, ", "));
}
// Check that dimensions.size == arg_shape.dimensions_size() (we currently
@@ -1066,7 +1089,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (dimensions.size() != arg_shape->dimensions_size()) {
return InvalidArgument(
"Map applied to a subset of dimensions currently not supported: "
- "arg_dimension_size: %d, requested_map_dimensions_size: %zu.",
+ "arg_dimension_size: %d, requested_map_dimensions_size: %u.",
arg_shape->dimensions_size(), dimensions.size());
}
@@ -1075,7 +1098,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (dimensions[i] != i) {
return InvalidArgument(
"Map requires monotonically increasing dimension numbers; got: %s.",
- Join(dimensions, ", ").c_str());
+ StrJoin(dimensions, ", "));
}
}
@@ -1083,7 +1106,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (arg_shapes.size() != to_apply.parameters_size()) {
return InvalidArgument(
"Map applied function arity must match number of arguments; got: "
- "arity: %d, arguments: %zu.",
+ "arity: %d, arguments: %u.",
to_apply.parameters_size(), arg_shapes.size());
}
@@ -1092,7 +1115,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (!ShapeUtil::IsScalar(output_shape)) {
return InvalidArgument(
"Mapped computation's result has to be a scalar; got: %s.",
- ShapeUtil::HumanString(output_shape).c_str());
+ ShapeUtil::HumanString(output_shape));
}
for (int i = 0; i < to_apply.parameters_size(); ++i) {
@@ -1102,7 +1125,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Mapped computation's parameter has to be a scalar; "
"got parameter %d shape: %s.",
- i, ShapeUtil::HumanString(parameter_shape).c_str());
+ i, ShapeUtil::HumanString(parameter_shape));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape,
@@ -1110,8 +1133,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Mapped computation's parameter type has to match argument element "
"type; got parameter %d shape: %s, argument shape: %s.",
- i, ShapeUtil::HumanString(parameter_shape).c_str(),
- ShapeUtil::HumanString(*arg_shape).c_str());
+ i, ShapeUtil::HumanString(parameter_shape),
+ ShapeUtil::HumanString(*arg_shape));
}
}
@@ -1140,35 +1163,35 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Expected feature_index of batch-norm-training to be "
"smaller than the rank of operand_shape; "
- "got feature_index %lld, and rank %lld.",
+ "got feature_index %d, and rank %d.",
feature_index, ShapeUtil::Rank(operand_shape));
}
if (feature_index < 0) {
return InvalidArgument(
"Expected feature_index of batch-norm-training to "
- "be a non-negative number, got %lld.",
+ "be a non-negative number, got %d.",
feature_index);
}
if (ShapeUtil::Rank(operand_shape) < 1) {
return InvalidArgument(
"Expected the rank of operand to "
- "batch-norm-training to be at least 1; got %lld.",
+ "batch-norm-training to be at least 1; got %d.",
ShapeUtil::Rank(operand_shape));
}
if (ShapeUtil::Rank(offset_shape) != 1) {
return InvalidArgument(
"Offset input of batch-norm-training must have"
- " rank 1, but has rank %lld.",
+ " rank 1, but has rank %d.",
ShapeUtil::Rank(offset_shape));
}
if (ShapeUtil::Rank(scale_shape) != 1) {
return InvalidArgument(
"Scale input of batch-norm-training must have"
- " rank 1, but has rank %lld.",
+ " rank 1, but has rank %d.",
ShapeUtil::Rank(scale_shape));
}
@@ -1176,7 +1199,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"The operand to batch-norm-training must have a floating point "
"element type, but the shape is %s.",
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(operand_shape.element_type()));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
@@ -1185,8 +1208,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"The inputs should have the same element type for batch-norm-training, "
"but the shape of offset factor is %s "
"and the shape of operand is %s.",
- PrimitiveType_Name(offset_shape.element_type()).c_str(),
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(offset_shape.element_type()),
+ PrimitiveType_Name(operand_shape.element_type()));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
@@ -1195,8 +1218,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"The inputs should have the same element type for batch-norm-training, "
"but the shape of scale factor is %s "
"and the shape of operand is %s.",
- PrimitiveType_Name(scale_shape.element_type()).c_str(),
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(scale_shape.element_type()),
+ PrimitiveType_Name(operand_shape.element_type()));
}
const int64 feature_count = operand_shape.dimensions(feature_index);
@@ -1206,16 +1229,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
return InvalidArgument(
"The size of offset factor should be the same as feature count,"
- "but the size of offset factor is %lld "
- "and the feature count is %lld.",
+ "but the size of offset factor is %d "
+ "and the feature count is %d.",
ShapeUtil::GetDimension(offset_shape, 0), feature_count);
}
if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
return InvalidArgument(
"The size of scale factor should be the same as feature count,"
- "but the size of scale factor is %lld "
- "and the feature count is %lld.",
+ "but the size of scale factor is %d "
+ "and the feature count is %d.",
ShapeUtil::GetDimension(scale_shape, 0), feature_count);
}
@@ -1250,35 +1273,35 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Expected feature_index of batch-norm-inference to be "
"smaller than the rank of operand_shape; "
- "got feature_index %lld, and rank %lld.",
+ "got feature_index %d, and rank %d.",
feature_index, ShapeUtil::Rank(operand_shape));
}
if (feature_index < 0) {
return InvalidArgument(
"Expected feature_index of batch-norm-inference to "
- "be a non-negative number, got %lld.",
+ "be a non-negative number, got %d.",
feature_index);
}
if (ShapeUtil::Rank(operand_shape) < 1) {
return InvalidArgument(
"Expected the rank of operand to "
- "batch-norm-inference to be at least 1; got %lld.",
+ "batch-norm-inference to be at least 1; got %d.",
ShapeUtil::Rank(operand_shape));
}
if (ShapeUtil::Rank(offset_shape) != 1) {
return InvalidArgument(
"Offset input of batch-norm-inference must have"
- " rank 1, but has rank %lld.",
+ " rank 1, but has rank %d.",
ShapeUtil::Rank(offset_shape));
}
if (ShapeUtil::Rank(scale_shape) != 1) {
return InvalidArgument(
"Scale input of batch-norm-inference must have"
- " rank 1, but has rank %lld.",
+ " rank 1, but has rank %d.",
ShapeUtil::Rank(scale_shape));
}
@@ -1286,7 +1309,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"The operand to batch-norm-inference must have a floating point "
"element type, but the shape is %s.",
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(operand_shape.element_type()));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
@@ -1296,8 +1319,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"batch-norm-inference, "
"but the shape of offset factor is %s "
"and the shape of operand is %s.",
- PrimitiveType_Name(offset_shape.element_type()).c_str(),
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(offset_shape.element_type()),
+ PrimitiveType_Name(operand_shape.element_type()));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
@@ -1307,8 +1330,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"batch-norm-inference, "
"but the shape of scale factor is %s "
"and the shape of operand is %s.",
- PrimitiveType_Name(scale_shape.element_type()).c_str(),
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(scale_shape.element_type()),
+ PrimitiveType_Name(operand_shape.element_type()));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
@@ -1318,8 +1341,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"batch-norm-inference, "
"but the shape of mean is %s "
"and the shape of operand is %s.",
- PrimitiveType_Name(mean_shape.element_type()).c_str(),
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(mean_shape.element_type()),
+ PrimitiveType_Name(operand_shape.element_type()));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape,
@@ -1329,8 +1352,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"batch-norm-inference, "
"but the shape of variance is %s "
"and the shape of operand is %s.",
- PrimitiveType_Name(mean_shape.element_type()).c_str(),
- PrimitiveType_Name(variance_shape.element_type()).c_str());
+ PrimitiveType_Name(mean_shape.element_type()),
+ PrimitiveType_Name(variance_shape.element_type()));
}
const int64 feature_count = operand_shape.dimensions(feature_index);
@@ -1340,32 +1363,32 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
return InvalidArgument(
"The size of offset factor should be the same as feature count,"
- "but the size of offset factor is %lld "
- "and the feature count is %lld.",
+ "but the size of offset factor is %d "
+ "and the feature count is %d.",
ShapeUtil::GetDimension(offset_shape, 0), feature_count);
}
if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
return InvalidArgument(
"The size of scale factor should be the same as feature count,"
- "but the size of scale factor is %lld "
- "and the feature count is %lld.",
+ "but the size of scale factor is %d "
+ "and the feature count is %d.",
ShapeUtil::GetDimension(scale_shape, 0), feature_count);
}
if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
return InvalidArgument(
"The size of mean should be the same as feature count,"
- "but the size of mean is %lld "
- "and the feature count is %lld.",
+ "but the size of mean is %d "
+ "and the feature count is %d.",
ShapeUtil::GetDimension(mean_shape, 0), feature_count);
}
if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) {
return InvalidArgument(
"The size of variance should be the same as feature count,"
- "but the size of variance is %lld "
- "and the feature count is %lld.",
+ "but the size of variance is %d "
+ "and the feature count is %d.",
ShapeUtil::GetDimension(variance_shape, 0), feature_count);
}
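
The batch-norm checks in this file all reduce to one invariant: offset, scale, mean, and variance must be rank-1 vectors whose length equals the operand's size along feature_index. A compact sketch of that invariant, as a hypothetical helper:

#include <cstdint>
#include <iostream>
#include <vector>

// Checks the shared batch-norm invariant: every per-feature input is a
// rank-1 vector whose length equals operand.dimensions(feature_index).
bool BatchNormAuxShapesOk(const std::vector<int64_t>& operand_dims,
                          int64_t feature_index,
                          const std::vector<std::vector<int64_t>>& aux_shapes) {
  if (feature_index < 0 ||
      feature_index >= static_cast<int64_t>(operand_dims.size()))
    return false;
  const int64_t feature_count = operand_dims[feature_index];
  for (const auto& shape : aux_shapes) {
    if (shape.size() != 1 || shape[0] != feature_count) return false;
  }
  return true;
}

int main() {
  // NHWC operand f32[8,32,32,64], feature_index=3: aux inputs must be [64].
  std::cout << BatchNormAuxShapesOk({8, 32, 32, 64}, 3,
                                    {{64}, {64}, {64}, {64}})
            << "\n";                                                  // 1
  std::cout << BatchNormAuxShapesOk({8, 32, 32, 64}, 3, {{32}}) << "\n";  // 0
  return 0;
}
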
@@ -1395,36 +1418,36 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Expected feature_index of batch-norm-grad to be "
"smaller than the rank of operand_shape; "
- "got feature_index %lld, and rank %lld.",
+ "got feature_index %d, and rank %d.",
feature_index, ShapeUtil::Rank(operand_shape));
}
if (ShapeUtil::Rank(operand_shape) != ShapeUtil::Rank(output_grad_shape)) {
return InvalidArgument(
"Expected operand_shape of batch-norm-grad to have the same rank as"
- " output_grad_shape; got rank(oprand_shape) %lld, and"
- " rank(output_grad_shape) %lld.",
+         " output_grad_shape; got rank(operand_shape) %d, and"
+ " rank(output_grad_shape) %d.",
ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(output_grad_shape));
}
if (ShapeUtil::Rank(mean_shape) != 1) {
return InvalidArgument(
"Mean input of batch-norm-grad must have"
- " rank 1, but has rank %lld.",
+ " rank 1, but has rank %d.",
ShapeUtil::Rank(mean_shape));
}
if (ShapeUtil::Rank(scale_shape) != 1) {
return InvalidArgument(
"Scale input of batch-norm-grad must have"
- " rank 1, but has rank %lld.",
+ " rank 1, but has rank %d.",
ShapeUtil::Rank(scale_shape));
}
if (ShapeUtil::Rank(var_shape) != 1) {
return InvalidArgument(
"Var input of batch-norm-grad must have"
- " rank 1, but has rank %lld.",
+ " rank 1, but has rank %d.",
ShapeUtil::Rank(var_shape));
}
@@ -1432,14 +1455,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"The operand to batch-norm-grad must have a floating point "
"element type, but the shape is %s.",
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(operand_shape.element_type()));
}
if (!ShapeUtil::ElementIsFloating(output_grad_shape)) {
return InvalidArgument(
"The output_grad to batch-norm-grad must have a floating point "
"element type, but the shape is %s.",
- PrimitiveType_Name(output_grad_shape.element_type()).c_str());
+ PrimitiveType_Name(output_grad_shape.element_type()));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape,
@@ -1448,8 +1471,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"The inputs should have the same element type for batch-norm-grad, "
"but the element type of output_grad is %s "
"and the element type of operand is %s.",
- PrimitiveType_Name(output_grad_shape.element_type()).c_str(),
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(output_grad_shape.element_type()),
+ PrimitiveType_Name(operand_shape.element_type()));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
@@ -1458,8 +1481,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"The inputs should have the same element type for batch-norm-grad, "
"but the element type of scale factor is %s "
"and the element type of operand is %s.",
- PrimitiveType_Name(scale_shape.element_type()).c_str(),
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(scale_shape.element_type()),
+ PrimitiveType_Name(operand_shape.element_type()));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
@@ -1468,8 +1491,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"The inputs should have the same element type for batch-norm-grad, "
"but the element type of mean is %s "
"and the element type of operand is %s.",
- PrimitiveType_Name(mean_shape.element_type()).c_str(),
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(mean_shape.element_type()),
+ PrimitiveType_Name(operand_shape.element_type()));
}
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape,
@@ -1478,8 +1501,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"The inputs should have the same element type for batch-norm-grad, "
"but the element type of mean is %s "
"and the element type of operand is %s.",
- PrimitiveType_Name(mean_shape.element_type()).c_str(),
- PrimitiveType_Name(operand_shape.element_type()).c_str());
+ PrimitiveType_Name(mean_shape.element_type()),
+ PrimitiveType_Name(operand_shape.element_type()));
}
const int64 feature_count = operand_shape.dimensions(feature_index);
@@ -1490,24 +1513,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
return InvalidArgument(
"The size of mean should be the same as feature count,"
- "but the size of offset factor is %lld "
- "and the feature count is %lld.",
+         "but the size of mean is %d "
+ "and the feature count is %d.",
ShapeUtil::GetDimension(mean_shape, 0), feature_count);
}
if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
return InvalidArgument(
"The size of scale factor should be the same as feature count,"
- "but the size of scale factor is %lld "
- "and the feature count is %lld.",
+ "but the size of scale factor is %d "
+ "and the feature count is %d.",
ShapeUtil::GetDimension(scale_shape, 0), feature_count);
}
if (ShapeUtil::GetDimension(var_shape, 0) != feature_count) {
return InvalidArgument(
"The size of variance should be the same as feature count,"
- "but the size of variance is %lld "
- "and the feature count is %lld.",
+ "but the size of variance is %d "
+ "and the feature count is %d.",
ShapeUtil::GetDimension(var_shape, 0), feature_count);
}
@@ -1517,8 +1540,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
ShapeUtil::GetDimension(output_grad_shape, i)) {
return InvalidArgument(
"The bounds of operand shape should be the same as output_grad's,"
- "but the bound of operand_shape at dimension %lld is %lld "
- "and the bound of output_grad_shape is %lld.",
+ "but the bound of operand_shape at dimension %d is %d "
+ "and the bound of output_grad_shape is %d.",
i, ShapeUtil::GetDimension(operand_shape, i),
ShapeUtil::GetDimension(output_grad_shape, i));
}
@@ -1529,23 +1552,22 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
- const Shape& lhs, const Shape& rhs, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count) {
+ const Shape& lhs, const Shape& rhs, int64 feature_group_count,
+ const Window& window, const ConvolutionDimensionNumbers& dnums) {
TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return InvalidArgument(
"Convolution with different element types: %s and %s.",
- ShapeUtil::HumanString(lhs).c_str(),
- ShapeUtil::HumanString(rhs).c_str());
+ ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs));
}
if (dnums.input_spatial_dimensions_size() !=
dnums.kernel_spatial_dimensions_size()) {
return InvalidArgument(
"Both arguments to convolution must have same number of dimensions.\n"
"Window: %s",
- window.DebugString().c_str());
+ window.DebugString());
}
const int num_spatial_dims = dnums.input_spatial_dimensions_size();
@@ -1553,19 +1575,19 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Window must have same number of dimensions as dimension numbers.\n"
"Window: %s\nDimension numbers: %s.",
- window.DebugString().c_str(), dnums.DebugString().c_str());
+ window.DebugString(), dnums.DebugString());
}
const int num_dims = num_spatial_dims + 2;
if (ShapeUtil::Rank(lhs) != num_dims) {
return InvalidArgument(
"The LHS argument to a convolution should have rank %d; lhs: %s.",
- num_dims, ShapeUtil::HumanString(lhs).c_str());
+ num_dims, ShapeUtil::HumanString(lhs));
}
if (ShapeUtil::Rank(rhs) != num_dims) {
return InvalidArgument(
"The RHS argument to a convolution should have rank %d; lhs: %s.",
- num_dims, ShapeUtil::HumanString(lhs).c_str());
+ num_dims, ShapeUtil::HumanString(lhs));
}
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
@@ -1602,26 +1624,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
!std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) {
return InvalidArgument(
"A dimension number is out of range in convolution: %s.",
- dnums.DebugString().c_str());
+ dnums.DebugString());
}
if (input_dnums != expected_dnums) {
return InvalidArgument(
"Input dimensions of convolution must contain each dimension exactly "
"once: %s.",
- dnums.DebugString().c_str());
+ dnums.DebugString());
}
if (window_dnums != expected_dnums) {
return InvalidArgument(
"Window dimensions of convolution must contain each dimension exactly "
"once: %s.",
- dnums.DebugString().c_str());
+ dnums.DebugString());
}
if (output_dnums != expected_dnums) {
return InvalidArgument(
"Output dimensions of convolution must contain each dimension exactly "
"once: %s.",
- dnums.DebugString().c_str());
+ dnums.DebugString());
}
std::vector<int64> input_spatial_dims(num_spatial_dims);
@@ -1642,13 +1664,23 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (input_features != kernel_input_features * feature_group_count) {
return InvalidArgument(
- "Expected LHS feature dimension (value %lld) to match RHS "
- "input feature dimension * feature_group_count (value %lld); "
+ "Expected LHS feature dimension (value %d) to match RHS "
+ "input feature dimension * feature_group_count (value %d); "
"got <conv>(%s, %s)\n"
"Dimension numbers: {%s}.",
input_features, kernel_input_features * feature_group_count,
- ShapeUtil::HumanString(lhs).c_str(),
- ShapeUtil::HumanString(rhs).c_str(), dnums.DebugString().c_str());
+ ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
+ dnums.DebugString());
+ }
+ if (kernel_output_features % feature_group_count > 0) {
+ return InvalidArgument(
+ "Expected output feature dimension (value %d) to be divisible by "
+ "feature_group_count (value %d); "
+ "got <conv>(%s, %s)\n"
+ "Dimension numbers: {%s}.",
+ kernel_output_features, feature_group_count,
+ ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
+ dnums.DebugString());
}
std::vector<int64> window_dims(num_spatial_dims);
for (int i = 0; i < num_spatial_dims; ++i) {
@@ -1660,8 +1692,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
"RHS shape: %s\n\t"
"Window: {%s}\n\t"
"Dimension numbers: {%s}.",
- ShapeUtil::HumanString(rhs).c_str(), window.ShortDebugString().c_str(),
- dnums.ShortDebugString().c_str());
+ ShapeUtil::HumanString(rhs), window.ShortDebugString(),
+ dnums.ShortDebugString());
}
Shape base_shape =
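
With feature groups, the two feature-dimension checks above become: the LHS feature count must equal the kernel's input-feature count times feature_group_count, and (new in this change) the kernel's output-feature count must be divisible by feature_group_count. A sketch of just these two checks, as a hypothetical helper:

#include <cstdint>
#include <iostream>

// Mirrors the two feature-dimension checks above for grouped convolution.
bool ConvFeatureCountsOk(int64_t input_features, int64_t kernel_input_features,
                         int64_t kernel_output_features,
                         int64_t feature_group_count) {
  if (input_features != kernel_input_features * feature_group_count)
    return false;  // LHS features must cover all groups.
  if (kernel_output_features % feature_group_count != 0)
    return false;  // output features must split evenly across groups.
  return true;
}

int main() {
  // Grouped case: 8 input features, 4 groups, kernel sees 2 features each.
  std::cout << ConvFeatureCountsOk(8, 2, 8, 4) << "\n";  // 1
  // 6 output features cannot be split across 4 groups.
  std::cout << ConvFeatureCountsOk(8, 2, 6, 4) << "\n";  // 0
  return 0;
}
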
@@ -1684,32 +1716,32 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferFftShape(
const Shape& in, const FftType fft_type,
- const tensorflow::gtl::ArraySlice<int64> fft_length) {
+ const absl::Span<const int64> fft_length) {
const int64 fft_rank = fft_length.size();
if (fft_rank < 1 || fft_rank > 3) {
- return InvalidArgument("FFT only supports ranks 1-3; got %lld.", fft_rank);
+ return InvalidArgument("FFT only supports ranks 1-3; got %d.", fft_rank);
}
-#define RET_CHECK_RANK(x) \
- if (x.dimensions_size() < fft_rank) { \
- return InvalidArgument( \
- "FFT of rank %lld requires input of at least " \
- "same rank; got input of rank %d", \
- fft_rank, x.dimensions_size()); \
+#define RET_CHECK_RANK(x) \
+ if (x.dimensions_size() < fft_rank) { \
+ return InvalidArgument( \
+ "FFT of rank %d requires input of at least " \
+ "same rank; got input of rank %d", \
+ fft_rank, x.dimensions_size()); \
}
switch (fft_type) {
case FFT:
case IFFT:
if (in.element_type() != C64) {
return InvalidArgument("%s requires C64 input type, found %s.",
- FftType_Name(fft_type).c_str(),
- PrimitiveType_Name(in.element_type()).c_str());
+ FftType_Name(fft_type),
+ PrimitiveType_Name(in.element_type()));
}
RET_CHECK_RANK(in);
return in;
case RFFT: {
if (in.element_type() != F32) {
return InvalidArgument("RFFT requires F32 input type, found %s.",
- PrimitiveType_Name(in.element_type()).c_str());
+ PrimitiveType_Name(in.element_type()));
}
RET_CHECK_RANK(in);
for (int i = 0; i < fft_rank; i++) {
@@ -1717,7 +1749,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
fft_length[i]) {
return InvalidArgument(
"RFFT requires innermost dimensions match fft_length but "
- "dimension %lld is %lld and should be %lld.",
+ "dimension %d is %d and should be %d.",
in.dimensions_size() - fft_rank + i,
in.dimensions(in.dimensions_size() - fft_rank + i),
fft_length[i]);
@@ -1731,7 +1763,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
case IRFFT: {
if (in.element_type() != C64) {
return InvalidArgument("IRFFT requires C64 input type, found %s.",
- PrimitiveType_Name(in.element_type()).c_str());
+ PrimitiveType_Name(in.element_type()));
}
RET_CHECK_RANK(in);
Shape result = ShapeUtil::ComplexComponentShape(in);
@@ -1740,7 +1772,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
fft_length[i]) {
return InvalidArgument(
"IRFFT requires all but one innermost dimensions match "
- "fft_length, but dimension %lld is %lld and should be %lld.",
+ "fft_length, but dimension %d is %d and should be %d.",
in.dimensions_size() - fft_rank + i,
in.dimensions(in.dimensions_size() - fft_rank + i),
fft_length[i]);
@@ -1750,7 +1782,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
fft_length[fft_rank - 1] / 2 + 1) {
return InvalidArgument(
"IRFFT requires innermost dimension matches fft_length/2+1, but "
- "dimension %d is %lld and should be %lld.",
+ "dimension %d is %d and should be %d.",
in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1),
fft_length[fft_rank - 1] / 2 + 1);
}
@@ -1765,7 +1797,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
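The RFFT/IRFFT checks above pivot on one arithmetic fact: a real-to-complex FFT of length n keeps only n/2 + 1 complex outputs, and the inverse expects exactly that truncated size back. Worked out in isolation (editor's sketch):

    #include <cstdint>
    #include <cassert>

    // Innermost output dimension of an RFFT, and the size IRFFT expects back.
    int64_t RfftInnermostOutput(int64_t fft_length) { return fft_length / 2 + 1; }

    int main() {
      assert(RfftInnermostOutput(8) == 5);
      assert(RfftInnermostOutput(9) == 5);  // odd lengths truncate the same way
      return 0;
    }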
/* static */ StatusOr<Shape> ShapeInference::InferCrossReplicaSumShape(
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
+ absl::Span<const Shape* const> operand_shapes) {
for (const Shape* operand_shape : operand_shapes) {
TF_RETURN_IF_ERROR(
ExpectArray(*operand_shape, "operand of cross replica sum"));
@@ -1786,18 +1818,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
TF_RET_CHECK(split_count > 0);
if (split_dimension >= ShapeUtil::Rank(shape) || split_dimension < 0) {
return InvalidArgument(
- "AllToAll split_dimension %lld is out-of-bounds in shape %s.",
- split_dimension, ShapeUtil::HumanString(shape).c_str());
+ "AllToAll split_dimension %d is out-of-bounds in shape %s.",
+ split_dimension, ShapeUtil::HumanString(shape));
}
if (concat_dimension >= ShapeUtil::Rank(shape) || concat_dimension < 0) {
return InvalidArgument(
- "AllToAll concat_dimension %lld is out-of-bounds in shape %s.",
- concat_dimension, ShapeUtil::HumanString(shape).c_str());
+ "AllToAll concat_dimension %d is out-of-bounds in shape %s.",
+ concat_dimension, ShapeUtil::HumanString(shape));
}
if (shape.dimensions(split_dimension) % split_count != 0) {
return InvalidArgument(
- "AllToAll split dimension size %lld must be dividable by split_count "
- "%lld.",
+ "AllToAll split dimension size %d must be dividable by split_count "
+ "%d.",
shape.dimensions(split_dimension), split_count);
}
std::vector<int64> new_dimensions(shape.dimensions().begin(),
@@ -1808,7 +1840,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
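To see why the divisibility check above matters: AllToAll shards the split dimension across split_count participants and concatenates along the concat dimension, so the inferred shape divides one bound and multiplies the other. The tail of this hunk is elided, so the helper below is the editor's reconstruction of that bookkeeping, not the patch's code:

    #include <cstdint>
    #include <vector>

    // Assumed AllToAll shape arithmetic: split_dim shrinks by split_count,
    // concat_dim grows by split_count.
    std::vector<int64_t> AllToAllDims(std::vector<int64_t> dims, int64_t split_dim,
                                      int64_t concat_dim, int64_t split_count) {
      dims[split_dim] /= split_count;
      dims[concat_dim] *= split_count;
      return dims;
    }

    int main() {
      auto out = AllToAllDims({8, 4}, /*split_dim=*/0, /*concat_dim=*/1,
                              /*split_count=*/4);
      return (out[0] == 2 && out[1] == 16) ? 0 : 1;
    }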
/* static */ StatusOr<Shape> ShapeInference::InferAllToAllTupleShape(
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
+ absl::Span<const Shape* const> operand_shapes) {
// An Alltoall HLO instruction receives N operands (with the same shape) and
// returns a tuple that contains N array shapes.
TF_RET_CHECK(!operand_shapes.empty());
@@ -1817,17 +1849,23 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"HLO all-to-all has operands with different shapes: the 0th "
"operand shape %s, but the %dth operand has shape %s.",
- ShapeUtil::HumanString(*operand_shapes[0]).c_str(), i,
- ShapeUtil::HumanString(*operand_shapes[i]).c_str());
+ ShapeUtil::HumanString(*operand_shapes[0]), i,
+ ShapeUtil::HumanString(*operand_shapes[i]));
}
}
return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes);
}
+/* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteShape(
+ const Shape& shape) {
+ TF_RET_CHECK(ShapeUtil::IsArray(shape));
+ return shape;
+}
+
/* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ absl::Span<const Shape* const> arg_shapes,
+ absl::Span<const int64> dimensions_to_reduce,
const ProgramShape& to_apply) {
if (arg_shapes.empty()) {
return InvalidArgument("Reduce must have at least 2 arguments, has 0");
@@ -1839,17 +1877,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
int64 num_reduced_args = arg_shapes.size() / 2;
- tensorflow::gtl::ArraySlice<const Shape*> reduced_args(arg_shapes, 0,
- num_reduced_args);
+ auto reduced_args = arg_shapes.subspan(0, num_reduced_args);
// Check that all of the reduced tensors have the same dimensions. The element
// types may be different.
for (int64 i = 1; i < num_reduced_args; ++i) {
if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) {
return InvalidArgument(
"All reduced tensors must have the sime dimension. Tensor 0 has "
- "shape %s, Tensor %lld has shape %s",
- ShapeUtil::HumanString(*reduced_args[0]).c_str(), i,
- ShapeUtil::HumanString(*reduced_args[i]).c_str());
+ "shape %s, Tensor %d has shape %s",
+ ShapeUtil::HumanString(*reduced_args[0]), i,
+ ShapeUtil::HumanString(*reduced_args[i]));
}
}
@@ -1859,14 +1896,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Shape& arg = *reduced_args[0];
for (int64 dimension : dimensions_to_reduce) {
if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) {
- return InvalidArgument(
- "Reducing out-of-bounds dimension %lld in shape %s.", dimension,
- ShapeUtil::HumanString(arg).c_str());
+ return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.",
+ dimension, ShapeUtil::HumanString(arg));
}
}
- tensorflow::gtl::ArraySlice<const Shape*> init_values(
- arg_shapes, num_reduced_args, arg_shapes.size());
+ auto init_values = arg_shapes.subspan(num_reduced_args, arg_shapes.size());
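One subtlety in the two subspan calls above: absl::Span::subspan(pos, len) clamps len to the elements remaining after pos, so passing arg_shapes.size() as the length simply means "everything from pos onward". A tiny sketch of that behavior (editor's note, assuming Abseil):

    #include <cstdio>
    #include "absl/types/span.h"

    int main() {
      const int data[] = {10, 20, 30, 40};
      absl::Span<const int> all(data);
      // len is clamped: this yields the 2 trailing elements, not 4.
      absl::Span<const int> tail = all.subspan(2, all.size());
      std::printf("%zu\n", tail.size());  // prints 2
    }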
std::vector<PrimitiveType> element_types;
for (const Shape* arg : reduced_args) {
element_types.push_back(arg->element_type());
@@ -1934,16 +1969,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Select function's first parameter shape currently must "
"match the operand element shape, but got %s vs %s.",
- ShapeUtil::HumanString(select_shape.parameters(0)).c_str(),
- ShapeUtil::HumanString(operand_element_shape).c_str());
+ ShapeUtil::HumanString(select_shape.parameters(0)),
+ ShapeUtil::HumanString(operand_element_shape));
}
if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
select_shape.parameters(1))) {
return InvalidArgument(
"Select function's second parameter shape currently must "
"match the operand element shape, but got %s vs %s.",
- ShapeUtil::HumanString(select_shape.parameters(1)).c_str(),
- ShapeUtil::HumanString(operand_element_shape).c_str());
+ ShapeUtil::HumanString(select_shape.parameters(1)),
+ ShapeUtil::HumanString(operand_element_shape));
}
// Check if the scatter function has a proper shape as a reduction.
@@ -1961,43 +1996,40 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Source shape does not match the shape of window-reduced operand: "
"source(%s), window-reduced operand(%s).",
- ShapeUtil::HumanString(source_shape).c_str(),
- ShapeUtil::HumanString(window_result_shape).c_str());
+ ShapeUtil::HumanString(source_shape),
+ ShapeUtil::HumanString(window_result_shape));
}
return operand_shape;
}
/* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
- const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
- tensorflow::gtl::ArraySlice<int64> limits,
- tensorflow::gtl::ArraySlice<int64> strides) {
+ const Shape& arg, absl::Span<const int64> starts,
+ absl::Span<const int64> limits, absl::Span<const int64> strides) {
auto error = [&](const string& message) {
return InvalidArgument(
"%s in slice operation; argument shape: %s; starts: {%s}; limits: "
"{%s}; strides: {%s}.",
- message.c_str(), ShapeUtil::HumanString(arg).c_str(),
- Join(starts, ",").c_str(), Join(limits, ",").c_str(),
- Join(strides, ",").c_str());
+ message, ShapeUtil::HumanString(arg), StrJoin(starts, ","),
+ StrJoin(limits, ","), StrJoin(strides, ","));
};
TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice"));
- VLOG(2) << tensorflow::strings::Printf(
- "slicing shape %s starts={%s} limits={%s}",
- ShapeUtil::HumanString(arg).c_str(), Join(starts, ", ").c_str(),
- Join(limits, ", ").c_str());
+ VLOG(2) << StrFormat("slicing shape %s starts={%s} limits={%s}",
+ ShapeUtil::HumanString(arg), StrJoin(starts, ", "),
+ StrJoin(limits, ", "));
if (starts.size() != limits.size()) {
- return error(Printf("slice start and limit sizes differ: %zu vs %zu",
- starts.size(), limits.size()));
+ return error(StrFormat("slice start and limit sizes differ: %u vs %u",
+ starts.size(), limits.size()));
}
if (starts.size() != strides.size()) {
- return error(Printf("slice start and strides sizes differ: %zu vs %zu",
- starts.size(), strides.size()));
+ return error(StrFormat("slice start and strides sizes differ: %u vs %u",
+ starts.size(), strides.size()));
}
if (starts.size() != ShapeUtil::Rank(arg)) {
return InvalidArgument(
- "Slice index count does not match argument rank: %zu vs %lld.",
+ "Slice index count does not match argument rank: %u vs %d.",
starts.size(), ShapeUtil::Rank(arg));
}
@@ -2007,27 +2039,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
int64 limit_index = limits[dimension];
int64 stride = strides[dimension];
if (start_index < 0) {
- return InvalidArgument("Negative start index to slice: %lld.",
- start_index);
+ return InvalidArgument("Negative start index to slice: %d.", start_index);
}
if (limit_index > arg.dimensions(dimension)) {
return error(
- Printf("limit index (%lld) must be less than or equal to dimension "
- "size (%lld)",
- limit_index, arg.dimensions(dimension)));
- }
- VLOG(2) << tensorflow::strings::Printf("starts[%lld] = %lld", dimension,
- start_index);
- VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension,
- limit_index);
+ StrFormat("limit index (%d) must be less than or equal to dimension "
+ "size (%d)",
+ limit_index, arg.dimensions(dimension)));
+ }
+ VLOG(2) << StrFormat("starts[%d] = %d", dimension, start_index);
+ VLOG(2) << StrFormat("limits[%d] = %d", dimension, limit_index);
if (start_index > limit_index) {
return error(
- Printf("limit index (%lld) must be greater or equal to "
- "start index (%lld) in slice with positive stride",
- limit_index, start_index));
+ StrFormat("limit index (%d) must be greater or equal to "
+ "start index (%d) in slice with positive stride",
+ limit_index, start_index));
}
if (stride <= 0) {
- return InvalidArgument("Stride (%lld) must be positive.", stride);
+ return InvalidArgument("Stride (%d) must be positive.", stride);
}
sizes.push_back((limit_index - start_index + stride - 1) / stride);
}
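The push_back above is the usual ceiling division: a slice over [start, limit) with stride s has ceil((limit - start) / s) elements. In isolation (editor's sketch):

    #include <cstdint>
    #include <cassert>

    int64_t SliceDimSize(int64_t start, int64_t limit, int64_t stride) {
      return (limit - start + stride - 1) / stride;  // ceil((limit-start)/stride)
    }

    int main() {
      assert(SliceDimSize(0, 16, 1) == 16);
      assert(SliceDimSize(0, 16, 3) == 6);  // elements 0,3,6,9,12,15
      assert(SliceDimSize(4, 16, 4) == 3);  // elements 4,8,12
      return 0;
    }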
@@ -2037,20 +2066,19 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
const Shape& operand_shape, const Shape& start_indices_shape,
- tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ absl::Span<const int64> slice_sizes) {
TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice"));
TF_RETURN_IF_ERROR(
ExpectArray(start_indices_shape, "start indices of dynamic slice"));
- VLOG(2) << tensorflow::strings::Printf(
+ VLOG(2) << StrFormat(
"slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
- ShapeUtil::HumanString(operand_shape).c_str(),
- ShapeUtil::HumanString(start_indices_shape).c_str(),
- Join(slice_sizes, ", ").c_str());
+ ShapeUtil::HumanString(operand_shape),
+ ShapeUtil::HumanString(start_indices_shape), StrJoin(slice_sizes, ", "));
if (ShapeUtil::Rank(start_indices_shape) != 1) {
return InvalidArgument(
- "Dynamic slice start indices of rank %lld must be rank1.",
+ "Dynamic slice start indices of rank %d must be rank1.",
ShapeUtil::Rank(start_indices_shape));
}
@@ -2062,16 +2090,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const int64 start_num_dims = start_indices_shape.dimensions(0);
if (ShapeUtil::Rank(operand_shape) != start_num_dims) {
return InvalidArgument(
- "Dynamic slice start number of dimensions %lld (%s) must match rank "
- "%lld of slice input (%s).",
- start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(),
- ShapeUtil::Rank(operand_shape),
- ShapeUtil::HumanString(operand_shape).c_str());
+ "Dynamic slice start number of dimensions %d (%s) must match rank "
+ "%d of slice input (%s).",
+ start_num_dims, ShapeUtil::HumanString(start_indices_shape),
+ ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape));
}
if (slice_sizes.size() != ShapeUtil::Rank(operand_shape)) {
return InvalidArgument(
- "Dynamic slice index count does not match argument rank: %zu vs %lld.",
+ "Dynamic slice index count does not match argument rank: %u vs %d.",
slice_sizes.size(), ShapeUtil::Rank(operand_shape));
}
@@ -2079,16 +2106,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const int64 input_dim_size = operand_shape.dimensions(dim);
const int64 slice_dim_size = slice_sizes[dim];
if (slice_dim_size < 0) {
- return InvalidArgument("Negative size index to dynamic slice: %lld.",
+ return InvalidArgument("Negative size index to dynamic slice: %d.",
slice_dim_size);
}
if (slice_dim_size > input_dim_size) {
return InvalidArgument(
- "Slice dim size %lld greater than dynamic slice dimension: %lld.",
+ "Slice dim size %d greater than dynamic slice dimension: %d.",
slice_dim_size, input_dim_size);
}
- VLOG(2) << tensorflow::strings::Printf("slice_sizes[%lld] = %lld", dim,
- slice_dim_size);
+ VLOG(2) << StrFormat("slice_sizes[%d] = %d", dim, slice_dim_size);
}
return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes);
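Net effect of the checks in this function: a dynamic slice's result shape is simply slice_sizes with the operand's element type, valid whenever each size fits inside the matching operand dimension. A compact restatement (editor's sketch, hypothetical names):

    #include <cstdint>
    #include <vector>

    bool DynamicSliceSizesOk(const std::vector<int64_t>& operand_dims,
                             const std::vector<int64_t>& slice_sizes) {
      if (operand_dims.size() != slice_sizes.size()) return false;
      for (size_t i = 0; i < operand_dims.size(); ++i) {
        if (slice_sizes[i] < 0 || slice_sizes[i] > operand_dims[i]) return false;
      }
      return true;
    }

    int main() { return DynamicSliceSizesOk({4, 8}, {2, 8}) ? 0 : 1; }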
@@ -2104,16 +2130,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape,
"start indices of dynamic update slice"));
- VLOG(2) << tensorflow::strings::Printf(
+ VLOG(2) << StrFormat(
"updating slice of shape %s at dynamic start_indices %s with update "
"shape %s",
- ShapeUtil::HumanString(operand_shape).c_str(),
- ShapeUtil::HumanString(start_indices_shape).c_str(),
- ShapeUtil::HumanString(update_shape).c_str());
+ ShapeUtil::HumanString(operand_shape),
+ ShapeUtil::HumanString(start_indices_shape),
+ ShapeUtil::HumanString(update_shape));
if (ShapeUtil::Rank(start_indices_shape) != 1) {
return InvalidArgument(
- "Dynamic update slice start indices of rank %lld must be rank1.",
+ "Dynamic update slice start indices of rank %d must be rank1.",
ShapeUtil::Rank(start_indices_shape));
}
@@ -2125,17 +2151,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const int64 start_num_dims = start_indices_shape.dimensions(0);
if (ShapeUtil::Rank(operand_shape) != start_num_dims) {
return InvalidArgument(
- "Dynamic update slice start number of dimensions %lld (%s) must match "
- "rank %lld of slice input (%s).",
- start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(),
- ShapeUtil::Rank(operand_shape),
- ShapeUtil::HumanString(operand_shape).c_str());
+ "Dynamic update slice start number of dimensions %d (%s) must match "
+ "rank %d of slice input (%s).",
+ start_num_dims, ShapeUtil::HumanString(start_indices_shape),
+ ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape));
}
if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) {
return InvalidArgument(
"Dynamic update slice update rank does not match argument rank: "
- "%lld vs %lld.",
+ "%d vs %d.",
ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape));
}
@@ -2144,8 +2169,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Dynamic update slice update element type does not match argument. "
"operand.element_type: %s vs update.element_type: %s.",
- PrimitiveType_Name(operand_shape.element_type()).c_str(),
- PrimitiveType_Name(update_shape.element_type()).c_str());
+ PrimitiveType_Name(operand_shape.element_type()),
+ PrimitiveType_Name(update_shape.element_type()));
}
for (int64 dim = 0; dim < ShapeUtil::Rank(operand_shape); ++dim) {
@@ -2153,23 +2178,22 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const int64 update_dim_size = update_shape.dimensions(dim);
if (update_dim_size < 0) {
return InvalidArgument(
- "Size index %lld to dynamic update slice must be >= 0.",
+ "Size index %d to dynamic update slice must be >= 0.",
update_dim_size);
}
if (update_dim_size > input_dim_size) {
return InvalidArgument(
- "Update dim size %lld greater than dynamic slice dimension: %lld.",
+ "Update dim size %d greater than dynamic slice dimension: %d.",
update_dim_size, input_dim_size);
}
- VLOG(2) << tensorflow::strings::Printf("update_sizes[%lld] = %lld", dim,
- update_dim_size);
+ VLOG(2) << StrFormat("update_sizes[%d] = %d", dim, update_dim_size);
}
return operand_shape;
}
/*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
- const Shape& operand_shape, tensorflow::gtl::ArraySlice<int64> dimensions) {
+ const Shape& operand_shape, absl::Span<const int64> dimensions) {
TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse"));
if (!AllUnique(dimensions)) {
return InvalidArgument("a dimension number is duplicated in reverse");
@@ -2177,8 +2201,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
for (int64 dimension : dimensions) {
if (dimension >= ShapeUtil::Rank(operand_shape) || dimension < 0) {
return InvalidArgument(
- "One of the reverse dimensions (%lld) is out-of-bounds in shape %s.",
- dimension, ShapeUtil::HumanString(operand_shape).c_str());
+ "One of the reverse dimensions (%d) is out-of-bounds in shape %s.",
+ dimension, ShapeUtil::HumanString(operand_shape));
}
}
return operand_shape;
@@ -2189,14 +2213,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (!ShapeUtil::IsTuple(arg)) {
return InvalidArgument(
"Cannot infer shape: attempting to index into non-tuple: %s.",
- ShapeUtil::HumanString(arg).c_str());
+ ShapeUtil::HumanString(arg));
}
if (index >= arg.tuple_shapes_size()) {
return InvalidArgument(
- "Cannot infer shape: attempt to index out of tuple bounds: %lld "
+ "Cannot infer shape: attempt to index out of tuple bounds: %d "
">= %d in shape %s.",
- index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg).c_str());
+ index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg));
}
return arg.tuple_shapes(index);
@@ -2216,17 +2240,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
auto shape_string = [&]() {
- return tensorflow::strings::Printf(
- "Condition: %s; body: %s; init: %s.",
- ShapeUtil::HumanString(condition).c_str(),
- ShapeUtil::HumanString(body).c_str(),
- ShapeUtil::HumanString(init).c_str());
+ return StrFormat(
+ "Condition: %s; body: %s; init: %s.", ShapeUtil::HumanString(condition),
+ ShapeUtil::HumanString(body), ShapeUtil::HumanString(init));
};
// Check the shapes of computation parameters and return types.
if (!ShapeUtil::ShapeIs(condition.result(), PRED, {})) {
return InvalidArgument("Condition must return a boolean; got %s.",
- shape_string().c_str());
+ shape_string());
}
if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) ||
!ShapeUtil::Compatible(body.result(), body.parameters(0)) ||
@@ -2234,7 +2256,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"The parameter of condition and body, the result of the body, and init "
"must all have the same shape; got %s.",
- shape_string().c_str());
+ shape_string());
}
return init;
@@ -2246,7 +2268,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const ProgramShape& false_computation) {
if (!ShapeUtil::ShapeIs(predicate, PRED, {})) {
return InvalidArgument("Predicate must be a boolean; got %s.",
- ShapeUtil::HumanString(predicate).c_str());
+ ShapeUtil::HumanString(predicate));
}
if (true_computation.parameters_size() != 1) {
@@ -2255,15 +2277,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
if (!ShapeUtil::Compatible(true_computation.parameters(0), true_operand)) {
auto true_shape_string = [&]() {
- return tensorflow::strings::Printf(
- "true_operand: %s; true_computation: %s",
- ShapeUtil::HumanString(true_operand).c_str(),
- ShapeUtil::HumanString(true_computation).c_str());
+ return StrFormat("true_operand: %s; true_computation: %s",
+ ShapeUtil::HumanString(true_operand),
+ ShapeUtil::HumanString(true_computation));
};
return InvalidArgument(
"true_operand must match the shape of the only parameter of "
"true_computation: got %s.",
- true_shape_string().c_str());
+ true_shape_string());
}
if (false_computation.parameters_size() != 1) {
@@ -2272,38 +2293,37 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
if (!ShapeUtil::Compatible(false_computation.parameters(0), false_operand)) {
auto false_shape_string = [&]() {
- return tensorflow::strings::Printf(
- "false_operand: %s; false_computation: %s",
- ShapeUtil::HumanString(false_operand).c_str(),
- ShapeUtil::HumanString(false_computation).c_str());
+ return StrFormat("false_operand: %s; false_computation: %s",
+ ShapeUtil::HumanString(false_operand),
+ ShapeUtil::HumanString(false_computation));
};
return InvalidArgument(
"false_operand must match the shape of the only parameter of "
"false_computation: got %s.",
- false_shape_string().c_str());
+ false_shape_string());
}
if (!ShapeUtil::Compatible(true_computation.result(),
false_computation.result())) {
auto shape_string = [&]() {
- return tensorflow::strings::Printf(
+ return StrFormat(
"true_computation result: %s; false_computation result: %s.",
- ShapeUtil::HumanString(true_computation.result()).c_str(),
- ShapeUtil::HumanString(false_computation.result()).c_str());
+ ShapeUtil::HumanString(true_computation.result()),
+ ShapeUtil::HumanString(false_computation.result()));
};
return InvalidArgument(
"the result of true_computation and false_computation must have the "
"same shape: got %s.",
- shape_string().c_str());
+ shape_string());
}
return true_computation.result();
}
/* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
+ const Shape& operand, absl::Span<const int64> broadcast_sizes) {
TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast"));
for (int64 size : broadcast_sizes) {
if (size < 0) {
- return InvalidArgument("Broadcast with negative dimension size %lld.",
+ return InvalidArgument("Broadcast with negative dimension size %d.",
size);
}
}
@@ -2317,8 +2337,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes) {
+ const Shape& operand, absl::Span<const int64> dimensions,
+ absl::Span<const int64> new_sizes) {
TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape"));
Shape inferred_shape =
@@ -2328,11 +2348,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
return InvalidArgument(
- "Reshape operation has mismatched element counts: from=%lld (%s) "
- "to=%lld (%s).",
- ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand).c_str(),
+ "Reshape operation has mismatched element counts: from=%d (%s) "
+ "to=%d (%s).",
+ ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand),
ShapeUtil::ElementsIn(inferred_shape),
- ShapeUtil::HumanString(inferred_shape).c_str());
+ ShapeUtil::HumanString(inferred_shape));
}
std::vector<int64> indices(ShapeUtil::Rank(operand));
@@ -2343,14 +2363,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Reshape dimensions [%s] are not a permutation of the operand "
"dimensions (operand shape is %s).",
- Join(dimensions, ",").c_str(), ShapeUtil::HumanString(operand).c_str());
+ StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
}
return inferred_shape;
}
/* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
+ const Shape& operand, absl::Span<const int64> dimensions) {
TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
std::vector<int64> indices(ShapeUtil::Rank(operand));
@@ -2378,9 +2398,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) ||
!ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) {
return InvalidArgument("Clamp with different operand types: %s, %s, %s.",
- ShapeUtil::HumanString(min).c_str(),
- ShapeUtil::HumanString(operand).c_str(),
- ShapeUtil::HumanString(max).c_str());
+ ShapeUtil::HumanString(min),
+ ShapeUtil::HumanString(operand),
+ ShapeUtil::HumanString(max));
}
if (((ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) ||
ShapeUtil::IsScalar(min)) &&
@@ -2397,9 +2417,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return ShapeUtil::ChangeElementType(min, operand.element_type());
}
}
- return Unimplemented(
- "%s, %s <clamp> %s is not implemented.", min.ShortDebugString().c_str(),
- max.ShortDebugString().c_str(), operand.ShortDebugString().c_str());
+ return Unimplemented("%s, %s <clamp> %s is not implemented.",
+ min.ShortDebugString(), max.ShortDebugString(),
+ operand.ShortDebugString());
}
// TODO(b/36794510): Make broadcast semantics more consistent, by supporting
@@ -2410,13 +2430,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) {
return InvalidArgument(
"Operands to select must be the same shape; got %s and %s.",
- ShapeUtil::HumanString(on_true).c_str(),
- ShapeUtil::HumanString(on_false).c_str());
+ ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false));
}
if (pred.element_type() != PRED) {
return InvalidArgument(
"Select's pred operand must have PRED element type; got %s.",
- ShapeUtil::HumanString(pred).c_str());
+ ShapeUtil::HumanString(pred));
}
if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) ||
ShapeUtil::IsScalar(pred)) {
@@ -2429,7 +2448,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Select operation with non-scalar predicate with dimensionality "
" different from the other operands: %s.",
- ShapeUtil::HumanString(pred).c_str());
+ ShapeUtil::HumanString(pred));
}
}
@@ -2440,38 +2459,36 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (!ShapeUtil::Compatible(on_true, on_false)) {
return InvalidArgument(
"Operands to tuple-select must be the same shape; got %s and %s.",
- ShapeUtil::HumanString(on_true).c_str(),
- ShapeUtil::HumanString(on_false).c_str());
+ ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false));
}
if (pred.element_type() != PRED) {
return InvalidArgument(
"TupleSelect's pred operand must have PRED element type; got %s.",
- ShapeUtil::HumanString(pred).c_str());
+ ShapeUtil::HumanString(pred));
}
if (!ShapeUtil::IsScalar(pred)) {
return InvalidArgument(
"TupleSelect operation with non-scalar predicate: %s.",
- ShapeUtil::HumanString(pred).c_str());
+ ShapeUtil::HumanString(pred));
}
return on_true;
}
/* static */ StatusOr<Shape> ShapeInference::InferCallShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const ProgramShape& to_apply) {
+ absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply) {
// The applied function's arity equals the number of arguments.
if (arg_shapes.size() != to_apply.parameters_size()) {
string computation_signature = ShapeUtil::HumanString(to_apply);
string argument_shapes =
- Join(arg_shapes, ", ", [](string* out, const Shape* shape) {
- tensorflow::strings::StrAppend(out, ShapeUtil::HumanString(*shape));
+ StrJoin(arg_shapes, ", ", [](string* out, const Shape* shape) {
+ absl::StrAppend(out, ShapeUtil::HumanString(*shape));
});
return InvalidArgument(
"Call applied function arity must match number of arguments; got: "
- "arity: %d, arguments: %zu; computation signature: %s; argument "
+ "arity: %d, arguments: %u; computation signature: %s; argument "
"shapes: [%s].",
- to_apply.parameters_size(), arg_shapes.size(),
- computation_signature.c_str(), argument_shapes.c_str());
+ to_apply.parameters_size(), arg_shapes.size(), computation_signature,
+ argument_shapes);
}
// All arguments must be compatible with the program shape.
@@ -2482,8 +2499,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InvalidArgument(
"Call parameter must match argument; got parameter %d shape: %s, "
"argument shape: %s.",
- i, ShapeUtil::HumanString(param_shape).c_str(),
- ShapeUtil::HumanString(arg_shape).c_str());
+ i, ShapeUtil::HumanString(param_shape),
+ ShapeUtil::HumanString(arg_shape));
}
}
@@ -2491,202 +2508,198 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
static Status ValidateGatherDimensionNumbers(
- const Shape& input_shape,
- tensorflow::gtl::ArraySlice<int64> gather_indices_shape,
+ const Shape& input_shape, absl::Span<const int64> start_indices_shape,
const GatherDimensionNumbers& dim_numbers) {
- if (!c_is_sorted(dim_numbers.output_window_dims())) {
+ if (!absl::c_is_sorted(dim_numbers.offset_dims())) {
return InvalidArgument(
"Output window dimensions in gather op must be ascending; got: %s.",
- Join(dim_numbers.output_window_dims(), ", ").c_str());
+ StrJoin(dim_numbers.offset_dims(), ", "));
}
- if (c_adjacent_find(dim_numbers.output_window_dims()) !=
- dim_numbers.output_window_dims().end()) {
+ if (absl::c_adjacent_find(dim_numbers.offset_dims()) !=
+ dim_numbers.offset_dims().end()) {
return InvalidArgument(
"Output window dimensions in gather op must not repeat; got: %s.",
- Join(dim_numbers.output_window_dims(), ", ").c_str());
+ StrJoin(dim_numbers.offset_dims(), ", "));
}
- const int64 output_window_dim_count = dim_numbers.output_window_dims_size();
+ const int64 output_offset_dim_count = dim_numbers.offset_dims_size();
const int64 output_shape_rank =
- output_window_dim_count + gather_indices_shape.size() - 1;
+ output_offset_dim_count + start_indices_shape.size() - 1;
- for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) {
- int64 window_index = dim_numbers.output_window_dims(i);
- if (window_index < 0 || window_index >= output_shape_rank) {
+ for (int i = 0; i < dim_numbers.offset_dims_size(); ++i) {
+ int64 offset_dim = dim_numbers.offset_dims(i);
+ if (offset_dim < 0 || offset_dim >= output_shape_rank) {
return InvalidArgument(
- "Window index %d in gather op is out of bounds; got %lld, but should "
- "have been in [0,%lld).",
- i, window_index, output_shape_rank);
+ "Offset dimension %d in gather op is out of bounds; got %d, but "
+ "should "
+ "have been in [0,%d).",
+ i, offset_dim, output_shape_rank);
}
}
- if (dim_numbers.gather_dims_to_operand_dims_size() !=
- gather_indices_shape[dim_numbers.index_vector_dim()]) {
+ if (dim_numbers.start_index_map_size() !=
+ start_indices_shape[dim_numbers.index_vector_dim()]) {
return InvalidArgument(
- "Gather op has %d elements in gather_dims_to_operand_dims and the "
- "bound of dimension index_vector_dim=%lld of gather_indices is "
- "%lld. These two numbers must be equal.",
- dim_numbers.gather_dims_to_operand_dims_size(),
- dim_numbers.index_vector_dim(),
- gather_indices_shape[dim_numbers.index_vector_dim()]);
+ "Gather op has %d elements in start_index_map and the "
+ "bound of dimension index_vector_dim=%d of start_indices is "
+ "%d. These two numbers must be equal.",
+ dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(),
+ start_indices_shape[dim_numbers.index_vector_dim()]);
}
- for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) {
- int64 gather_dim_to_input_dim = dim_numbers.gather_dims_to_operand_dims(i);
- if (gather_dim_to_input_dim < 0 ||
- gather_dim_to_input_dim >= input_shape.dimensions_size()) {
+ for (int i = 0; i < dim_numbers.start_index_map_size(); i++) {
+ int64 operand_dim_for_start_index_i = dim_numbers.start_index_map(i);
+ if (operand_dim_for_start_index_i < 0 ||
+ operand_dim_for_start_index_i >= input_shape.dimensions_size()) {
return InvalidArgument(
- "Invalid gather_dims_to_operand_dims mapping; domain is [0, %d), "
- "got: %d->%lld.",
- input_shape.dimensions_size(), i, gather_dim_to_input_dim);
+ "Invalid start_index_map; domain is [0, %d), got: %d->%d.",
+ input_shape.dimensions_size(), i, operand_dim_for_start_index_i);
}
}
- std::vector<int64> sorted_gather_dims_to_operand_dims(
- dim_numbers.gather_dims_to_operand_dims().begin(),
- dim_numbers.gather_dims_to_operand_dims().end());
+ std::vector<int64> sorted_start_index_map(
+ dim_numbers.start_index_map().begin(),
+ dim_numbers.start_index_map().end());
- c_sort(sorted_gather_dims_to_operand_dims);
+ absl::c_sort(sorted_start_index_map);
- if (c_adjacent_find(sorted_gather_dims_to_operand_dims) !=
- sorted_gather_dims_to_operand_dims.end()) {
+ if (absl::c_adjacent_find(sorted_start_index_map) !=
+ sorted_start_index_map.end()) {
return InvalidArgument(
- "Repeated dimensions are not allowed in gather_dims_to_operand_dims; "
+ "Repeated dimensions are not allowed in start_index_map; "
"got: %s.",
- Join(dim_numbers.gather_dims_to_operand_dims(), ", ").c_str());
+ StrJoin(dim_numbers.start_index_map(), ", "));
}
- for (int64 elided_dim : dim_numbers.elided_window_dims()) {
- if (elided_dim < 0 || elided_dim >= input_shape.dimensions_size()) {
+ for (int64 collapsed_dim : dim_numbers.collapsed_slice_dims()) {
+ if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) {
return InvalidArgument(
- "Invalid elided_window_dims set in gather op; valid range is [0, "
- "%d), got: %lld.",
- input_shape.dimensions_size(), elided_dim);
+ "Invalid collapsed_slice_dims set in gather op; valid range is [0, "
+ "%d), got: %d.",
+ input_shape.dimensions_size(), collapsed_dim);
}
}
- if (!c_is_sorted(dim_numbers.elided_window_dims())) {
+ if (!absl::c_is_sorted(dim_numbers.collapsed_slice_dims())) {
return InvalidArgument(
- "elided_window_dims in gather op must be sorted; got: %s",
- Join(dim_numbers.elided_window_dims(), ", ").c_str());
+ "collapsed_slice_dims in gather op must be sorted; got: %s",
+ StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
}
- if (c_adjacent_find(dim_numbers.elided_window_dims()) !=
- dim_numbers.elided_window_dims().end()) {
+ if (absl::c_adjacent_find(dim_numbers.collapsed_slice_dims()) !=
+ dim_numbers.collapsed_slice_dims().end()) {
return InvalidArgument(
- "Repeated dimensions not allowed in elided_window_dims in gather op; "
+ "Repeated dimensions not allowed in collapsed_slice_dims in gather op; "
"got: %s.",
- Join(dim_numbers.elided_window_dims(), ", ").c_str());
+ StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
}
return Status::OK();
}
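The renamed gather vocabulary (offset_dims, collapsed_slice_dims, start_index_map) still obeys the rank identity used a few lines below: output rank = len(offset_dims) + rank(start_indices) - 1, the index vector dimension being consumed by the lookup. For concreteness (editor's sketch):

    #include <cstdint>
    #include <cassert>

    int64_t GatherOutputRank(int64_t num_offset_dims,
                             int64_t expanded_start_indices_rank) {
      // One dimension of start_indices (index_vector_dim) is consumed.
      return num_offset_dims + expanded_start_indices_rank - 1;
    }

    int main() {
      // Two offset dims and rank-3 start_indices give a rank-4 gather result.
      assert(GatherOutputRank(2, 3) == 4);
      return 0;
    }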
/*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
- const Shape& input_shape, const Shape& gather_indices_shape,
+ const Shape& input_shape, const Shape& start_indices_shape,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds) {
+ absl::Span<const int64> slice_sizes) {
TF_RETURN_IF_ERROR(
ExpectArray(input_shape, "input tensor operand gather op"));
TF_RETURN_IF_ERROR(
- ExpectArray(gather_indices_shape, "gather indices operand of gather op"));
+ ExpectArray(start_indices_shape, "gather indices operand of gather op"));
- if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) {
+ if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
return InvalidArgument(
"Gather indices parameter must be an integral tensor; got %s.",
- ShapeUtil::HumanString(gather_indices_shape).c_str());
+ ShapeUtil::HumanString(start_indices_shape));
}
// We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if
  // index_vector_dim is rank(P). The bounds of this expanded shape are
- // stored in expanded_gather_indices_shape.
+ // stored in expanded_start_indices_shape.
- if (gather_indices_shape.dimensions_size() <
+ if (start_indices_shape.dimensions_size() <
gather_dim_numbers.index_vector_dim() ||
gather_dim_numbers.index_vector_dim() < 0) {
return InvalidArgument(
- "Gather index leaf dimension must be within [0, rank(gather_indices) + "
- "1). rank(gather_indices) is %d and gather index leaf dimension is "
- "%lld.",
- gather_indices_shape.dimensions_size(),
+ "Gather index leaf dimension must be within [0, rank(start_indices) + "
+ "1). rank(start_indices) is %d and gather index leaf dimension is "
+ "%d.",
+ start_indices_shape.dimensions_size(),
gather_dim_numbers.index_vector_dim());
}
- std::vector<int64> expanded_gather_indices_shape;
- expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size());
- c_copy(gather_indices_shape.dimensions(),
- std::back_inserter(expanded_gather_indices_shape));
- if (expanded_gather_indices_shape.size() ==
+ std::vector<int64> expanded_start_indices_shape;
+ expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size());
+ absl::c_copy(start_indices_shape.dimensions(),
+ std::back_inserter(expanded_start_indices_shape));
+ if (expanded_start_indices_shape.size() ==
gather_dim_numbers.index_vector_dim()) {
- expanded_gather_indices_shape.push_back(1);
+ expanded_start_indices_shape.push_back(1);
}
TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers(
- input_shape, expanded_gather_indices_shape, gather_dim_numbers));
+ input_shape, expanded_start_indices_shape, gather_dim_numbers));
- if (window_bounds.size() != input_shape.dimensions_size()) {
+ if (slice_sizes.size() != input_shape.dimensions_size()) {
return InvalidArgument(
- "Gather op must have one window bound for every input dimension; got: "
- "len(window_bounds)=%lu, input_shape.rank=%d.",
- window_bounds.size(), input_shape.dimensions_size());
+ "Gather op must have one slice size for every input dimension; got: "
+ "len(slice_sizes)=%lu, input_shape.rank=%d.",
+ slice_sizes.size(), input_shape.dimensions_size());
}
- if (window_bounds.size() !=
- gather_dim_numbers.output_window_dims_size() +
- gather_dim_numbers.elided_window_dims_size()) {
+ if (slice_sizes.size() !=
+ gather_dim_numbers.offset_dims_size() +
+ gather_dim_numbers.collapsed_slice_dims_size()) {
return InvalidArgument(
- "All components of the window index in a gather op must either be a "
- "output window index or explicitly elided; got len(window_bounds)=%lu, "
- "output_window_bounds=%s, elided_window_bounds=%s.",
- window_bounds.size(),
- Join(gather_dim_numbers.output_window_dims(), ",").c_str(),
- Join(gather_dim_numbers.elided_window_dims(), ",").c_str());
+ "All components of the offset index in a gather op must either be a "
+ "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, "
+ "output_slice_sizes=%s, collapsed_slice_dims=%s.",
+ slice_sizes.size(), StrJoin(gather_dim_numbers.offset_dims(), ","),
+ StrJoin(gather_dim_numbers.collapsed_slice_dims(), ","));
}
- for (int i = 0; i < window_bounds.size(); i++) {
- int64 window_bound = window_bounds[i];
- int64 corresponding_input_bound = input_shape.dimensions(i);
- if (window_bound < 0 || window_bound > corresponding_input_bound) {
+ for (int i = 0; i < slice_sizes.size(); i++) {
+ int64 slice_size = slice_sizes[i];
+ int64 corresponding_input_size = input_shape.dimensions(i);
+ if (slice_size < 0 || slice_size > corresponding_input_size) {
return InvalidArgument(
- "Window bound at index %d in gather op is out of range, must be "
- "within "
- "[0, %lld), got %lld.",
- i, corresponding_input_bound + 1, window_bound);
+ "Slice size at index %d in gather op is out of range, must be "
+ "within [0, %d), got %d.",
+ i, corresponding_input_size + 1, slice_size);
}
}
- for (int i = 0; i < gather_dim_numbers.elided_window_dims_size(); i++) {
- if (window_bounds[gather_dim_numbers.elided_window_dims(i)] != 1) {
+ for (int i = 0; i < gather_dim_numbers.collapsed_slice_dims_size(); i++) {
+ if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] != 1) {
return InvalidArgument(
- "Gather op can only elide window indices with bound 1, but bound is "
- "%lld for index %lld at position %d.",
- window_bounds[gather_dim_numbers.elided_window_dims(i)],
- gather_dim_numbers.elided_window_dims(i), i);
+ "Gather op can only collapse slice dims with bound 1, but bound is "
+ "%d for index %d at position %d.",
+ slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)],
+ gather_dim_numbers.collapsed_slice_dims(i), i);
}
}
- int64 result_rank = gather_dim_numbers.output_window_dims_size() +
- (expanded_gather_indices_shape.size() - 1);
- int64 window_dims_seen = 0;
+ int64 result_rank = gather_dim_numbers.offset_dims_size() +
+ (expanded_start_indices_shape.size() - 1);
+ int64 offset_dims_seen = 0;
int64 gather_dims_seen = 0;
std::vector<int64> output_dim_bounds;
output_dim_bounds.reserve(result_rank);
for (int64 i = 0; i < result_rank; i++) {
int64 current_bound;
bool is_window_index =
- c_binary_search(gather_dim_numbers.output_window_dims(), i);
+ absl::c_binary_search(gather_dim_numbers.offset_dims(), i);
if (is_window_index) {
- while (c_binary_search(gather_dim_numbers.elided_window_dims(),
- window_dims_seen)) {
- window_dims_seen++;
+ while (absl::c_binary_search(gather_dim_numbers.collapsed_slice_dims(),
+ offset_dims_seen)) {
+ offset_dims_seen++;
}
- current_bound = window_bounds[window_dims_seen++];
+ current_bound = slice_sizes[offset_dims_seen++];
} else {
if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) {
gather_dims_seen++;
}
- current_bound = expanded_gather_indices_shape[gather_dims_seen++];
+ current_bound = expanded_start_indices_shape[gather_dims_seen++];
}
output_dim_bounds.push_back(current_bound);
@@ -2698,48 +2711,47 @@ static Status ValidateGatherDimensionNumbers(
namespace {
Status ValidateScatterDimensionNumbers(
- const Shape& operand_shape,
- tensorflow::gtl::ArraySlice<int64> scatter_indices_shape,
+ const Shape& operand_shape, absl::Span<const int64> scatter_indices_shape,
const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
// Validate update_window_dims in ScatterDimensionNumbers.
- if (!c_is_sorted(dim_numbers.update_window_dims())) {
+ if (!absl::c_is_sorted(dim_numbers.update_window_dims())) {
return InvalidArgument(
"update_window_dims in scatter op must be sorted; got: %s.",
- Join(dim_numbers.update_window_dims(), ", ").c_str());
+ StrJoin(dim_numbers.update_window_dims(), ", "));
}
- if (c_adjacent_find(dim_numbers.update_window_dims()) !=
+ if (absl::c_adjacent_find(dim_numbers.update_window_dims()) !=
dim_numbers.update_window_dims().end()) {
return InvalidArgument(
"update_window_dims in scatter op must not repeat; got: %s.",
- Join(dim_numbers.update_window_dims(), ", ").c_str());
+ StrJoin(dim_numbers.update_window_dims(), ", "));
}
const int64 updates_rank = ShapeUtil::Rank(updates_shape);
for (int64 window_dim : dim_numbers.update_window_dims()) {
if (window_dim < 0 || window_dim >= updates_rank) {
return InvalidArgument(
"Invalid update_window_dims set in scatter op; valid range is [0, "
- "%lld). got: %lld.",
+ "%d). got: %d.",
updates_rank, window_dim);
}
}
// Validate inserted_window_dims in ScatterDimensionNumbers.
- if (!c_is_sorted(dim_numbers.inserted_window_dims())) {
+ if (!absl::c_is_sorted(dim_numbers.inserted_window_dims())) {
return InvalidArgument(
"inserted_window_dims in scatter op must be sorted; got: %s.",
- Join(dim_numbers.inserted_window_dims(), ", ").c_str());
+ StrJoin(dim_numbers.inserted_window_dims(), ", "));
}
- if (c_adjacent_find(dim_numbers.inserted_window_dims()) !=
+ if (absl::c_adjacent_find(dim_numbers.inserted_window_dims()) !=
dim_numbers.inserted_window_dims().end()) {
return InvalidArgument(
"inserted_window_dims in scatter op must not repeat; got: %s.",
- Join(dim_numbers.inserted_window_dims(), ", ").c_str());
+ StrJoin(dim_numbers.inserted_window_dims(), ", "));
}
for (int64 inserted_dim : dim_numbers.inserted_window_dims()) {
if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) {
return InvalidArgument(
"Invalid inserted_window_dims set in scatter op; valid range is [0, "
- "%d), got: %lld.",
+ "%d), got: %d.",
operand_shape.dimensions_size(), inserted_dim);
}
}
@@ -2749,7 +2761,7 @@ Status ValidateScatterDimensionNumbers(
scatter_indices_shape[dim_numbers.index_vector_dim()]) {
return InvalidArgument(
"Scatter op has %d elements in scatter_dims_to_operand_dims and the "
- "bound of dimension index_vector_dim=%lld of scatter_indices is %lld. "
+ "bound of dimension index_vector_dim=%d of scatter_indices is %d. "
"These two numbers must be equal.",
dim_numbers.scatter_dims_to_operand_dims_size(),
dim_numbers.index_vector_dim(),
@@ -2762,20 +2774,20 @@ Status ValidateScatterDimensionNumbers(
scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) {
return InvalidArgument(
"Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), "
- "got: %d->%lld.",
+ "got: %d->%d.",
operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim);
}
}
std::vector<int64> sorted_scatter_dims_to_operand_dims(
dim_numbers.scatter_dims_to_operand_dims().begin(),
dim_numbers.scatter_dims_to_operand_dims().end());
- c_sort(sorted_scatter_dims_to_operand_dims);
- if (c_adjacent_find(sorted_scatter_dims_to_operand_dims) !=
+ absl::c_sort(sorted_scatter_dims_to_operand_dims);
+ if (absl::c_adjacent_find(sorted_scatter_dims_to_operand_dims) !=
sorted_scatter_dims_to_operand_dims.end()) {
return InvalidArgument(
"Repeated dimensions not allowed in scatter_dims_to_operand_dims; "
"got: %s.",
- Join(dim_numbers.scatter_dims_to_operand_dims(), ", ").c_str());
+ StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", "));
}
return Status::OK();
@@ -2796,7 +2808,7 @@ Status ValidateScatterDimensionNumbers(
if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) {
return InvalidArgument(
"Scatter indices parameter must be an integral tensor; got %s.",
- ShapeUtil::HumanString(scatter_indices_shape).c_str());
+ ShapeUtil::HumanString(scatter_indices_shape));
}
if (scatter_indices_shape.dimensions_size() <
@@ -2805,7 +2817,7 @@ Status ValidateScatterDimensionNumbers(
return InvalidArgument(
"Scatter index leaf dimension must be within [0, rank(scatter_indices)"
" + 1). rank(scatter_indices) is %d and scatter index leaf dimension "
- "is %lld.",
+ "is %d.",
scatter_indices_shape.dimensions_size(),
scatter_dim_numbers.index_vector_dim());
}
@@ -2827,7 +2839,7 @@ Status ValidateScatterDimensionNumbers(
int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 +
scatter_dim_numbers.update_window_dims_size();
if (ShapeUtil::Rank(updates_shape) != expected_updates_rank) {
- return InvalidArgument("Updates tensor must be of rank %lld; got %lld.",
+ return InvalidArgument("Updates tensor must be of rank %d; got %d.",
expected_updates_rank,
ShapeUtil::Rank(updates_shape));
}
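The expected_updates_rank formula above mirrors the gather identity: updates carry one dimension per update window dim plus one per scatter batch dim, i.e. rank(updates) = (rank(expanded scatter_indices) - 1) + len(update_window_dims). As a standalone check (editor's sketch):

    #include <cstdint>
    #include <cassert>

    int64_t ExpectedUpdatesRank(int64_t expanded_scatter_indices_rank,
                                int64_t num_update_window_dims) {
      return (expanded_scatter_indices_rank - 1) + num_update_window_dims;
    }

    int main() {
      // Rank-3 expanded indices and 2 window dims require rank-4 updates.
      assert(ExpectedUpdatesRank(3, 2) == 4);
      return 0;
    }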
@@ -2837,32 +2849,32 @@ Status ValidateScatterDimensionNumbers(
scatter_dim_numbers));
int64 inserted_dims_seen = 0;
- std::vector<int64> max_update_window_bounds;
+ std::vector<int64> max_update_slice_sizes;
for (int i = 0; i < operand_shape.dimensions_size(); ++i) {
if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() &&
scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) {
++inserted_dims_seen;
} else {
- max_update_window_bounds.push_back(operand_shape.dimensions(i));
+ max_update_slice_sizes.push_back(operand_shape.dimensions(i));
}
}
for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) {
auto update_window_dim = scatter_dim_numbers.update_window_dims(i);
if (updates_shape.dimensions(update_window_dim) >
- max_update_window_bounds[i]) {
+ max_update_slice_sizes[i]) {
return InvalidArgument(
"Bounds of the window dimensions of updates must not exceed the "
"bounds of the corresponding dimensions of operand. For dimension "
- "%lld, updates bound is %lld, operand bound is %lld.",
+ "%d, updates bound is %d, operand bound is %d.",
update_window_dim, updates_shape.dimensions(update_window_dim),
- max_update_window_bounds[i]);
+ max_update_slice_sizes[i]);
}
}
int64 scatter_dims_seen = 0;
for (int64 i = 0; i < ShapeUtil::Rank(updates_shape); ++i) {
bool is_update_window_dim =
- c_binary_search(scatter_dim_numbers.update_window_dims(), i);
+ absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i);
if (is_update_window_dim) {
continue;
}
@@ -2874,8 +2886,8 @@ Status ValidateScatterDimensionNumbers(
return InvalidArgument(
"Bounds of the scatter dimensions of updates must be same as the "
"bounds of the corresponding dimensions of scatter indices. For "
- "scatter dimension %lld, updates bound is %lld, scatter_indices "
- "bound is %lld.",
+ "scatter dimension %d, updates bound is %d, scatter_indices "
+ "bound is %d.",
i, updates_shape.dimensions(i),
expanded_scatter_indices_shape[scatter_dims_seen]);
}
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index bfd79a4433..96a0ee165d 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -21,12 +21,12 @@ limitations under the License.
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -55,7 +55,7 @@ class ShapeInference {
// given input shapes.
static StatusOr<Shape> InferBinaryOpShape(
HloOpcode opcode, const Shape& lhs, const Shape& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
static StatusOr<Shape> InferBinaryOpShape(HloOpcode opcode,
const HloInstruction* lhs,
const HloInstruction* rhs);
@@ -73,18 +73,15 @@ class ShapeInference {
// Infers the shape produced by applying the given variadic operation to the
// given input operand shapes.
static StatusOr<Shape> InferVariadicOpShape(
- HloOpcode opcode,
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
+ HloOpcode opcode, absl::Span<const Shape* const> operand_shapes);
static StatusOr<Shape> InferVariadicOpShape(
- HloOpcode opcode,
- tensorflow::gtl::ArraySlice<const HloInstruction*> operands);
+ HloOpcode opcode, absl::Span<const HloInstruction* const> operands);
// Infers the shape produced by applying the given mapping computation shape
// to the given operand shapes.
static StatusOr<Shape> InferMapShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const ProgramShape& to_apply,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
+ absl::Span<const int64> dimensions);
// Infers the shape produced by InferBatchNormTraining with the given
// operands.
@@ -111,19 +108,18 @@ class ShapeInference {
// Infers the shape produced by applying the given convolutional
// filter (rhs) to lhs in the way specified by the fields on window.
static StatusOr<Shape> InferConvolveShape(
- const Shape& lhs, const Shape& rhs, const Window& window,
- const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1);
+ const Shape& lhs, const Shape& rhs, int64 feature_group_count,
+ const Window& window,
+ const ConvolutionDimensionNumbers& dimension_numbers);
// Infers the shape produced by the given FFT type on the given operand.
- static StatusOr<Shape> InferFftShape(
- const Shape& in, FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length);
+ static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type,
+ absl::Span<const int64> fft_length);
// Infers the shape produced by a cross replica sum with the given operand
// shapes.
static StatusOr<Shape> InferCrossReplicaSumShape(
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
+ absl::Span<const Shape* const> operand_shapes);
// Infers final shape of an Alltoall operation that is created by the xla
// builder.
@@ -134,7 +130,10 @@ class ShapeInference {
// Infers the shape of an HLO all-to-all instruction.
static StatusOr<Shape> InferAllToAllTupleShape(
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
+ absl::Span<const Shape* const> operand_shapes);
+
+ // Infers the shape of a collective permute operation.
+ static StatusOr<Shape> InferCollectivePermuteShape(const Shape& shape);
// Infers the shape produced by applying the given reduction computation
// shape to the given input operand shape.
@@ -143,8 +142,8 @@ class ShapeInference {
// index as the leading parameter, and the program shape should match
// accordingly (or an error will result).
static StatusOr<Shape> InferReduceShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ absl::Span<const Shape* const> arg_shapes,
+ absl::Span<const int64> dimensions_to_reduce,
const ProgramShape& to_apply);
// Infers the shape produced by applying the given computation to the operand
@@ -162,24 +161,23 @@ class ShapeInference {
// Infers the shape produced by a reverse operation that reverses the order
// of the elements in the given dimensions.
- static StatusOr<Shape> InferReverseShape(
- const Shape& operand_shape,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ static StatusOr<Shape> InferReverseShape(const Shape& operand_shape,
+ absl::Span<const int64> dimensions);
// Infers the shape produced by a slice operation spanning from the starts to
// the limits in the original shape's dimensions.
//
// e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16]
- static StatusOr<Shape> InferSliceShape(
- const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
- tensorflow::gtl::ArraySlice<int64> limits,
- tensorflow::gtl::ArraySlice<int64> strides);
+ static StatusOr<Shape> InferSliceShape(const Shape& arg,
+ absl::Span<const int64> starts,
+ absl::Span<const int64> limits,
+ absl::Span<const int64> strides);
// Infers the shape produced by a dynamic slice operation of size specified
// in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'.
static StatusOr<Shape> InferDynamicSliceShape(
const Shape& operand_shape, const Shape& start_indices_shape,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ absl::Span<const int64> slice_sizes);
// Infers the shape produced by a dynamic update slice operation based
// on the shape of operand and update.
@@ -210,30 +208,30 @@ class ShapeInference {
// Infers the shape produced by a broadcast operation.
static StatusOr<Shape> InferBroadcastShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+ const Shape& operand, absl::Span<const int64> broadcast_sizes);
// Infers the shape produced by a reshape operation from the element type of
// its operand and the new dimension sizes specified.
- static StatusOr<Shape> InferReshapeShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
+ static StatusOr<Shape> InferReshapeShape(const Shape& operand,
+ absl::Span<const int64> dimensions,
+ absl::Span<const int64> new_sizes);
// Infers the shape produced by a transpose operation from the element type of
// its operand and its dimensions field.
static StatusOr<Shape> InferTransposeShape(
- const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions);
+ const Shape& operand, absl::Span<const int64> dimensions);
// Helper that infers the shape produced by performing a concatenate operation
// with the given operand shapes.
static StatusOr<Shape> InferConcatOpShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, int64 dimension);
+ absl::Span<const Shape* const> arg_shapes, int64 dimension);
// Infers the shape produced by a kAfterAll. Trivially this shape is always a
// TOKEN shape. However, ShapeInference serves two purposes: inferring shapes
// and checking operand shapes. This method verifies that the operand shapes
// are all TOKENs.
static StatusOr<Shape> InferAfterAllShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes);
+ absl::Span<const Shape* const> arg_shapes);
// Helper that validates the given operand shape can be converted to the
// target output_shape via a convert instruction -- the requirement is that
@@ -263,8 +261,7 @@ class ShapeInference {
// Helper that validates the given arg_shapes are compatible with the shape of
// the to_apply parameters, and returns the to_apply result shape.
static StatusOr<Shape> InferCallShape(
- tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
- const ProgramShape& to_apply);
+ absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply);
// Helper that infers the shape produced by performing a dot operation with
// the given LHS and RHS shapes.
@@ -276,9 +273,9 @@ class ShapeInference {
// with the given input shape, gather indices shape and gather dimension
// numbers.
static StatusOr<Shape> InferGatherShape(
- const Shape& input_shape, const Shape& gather_indices_shape,
+ const Shape& input_shape, const Shape& start_indices_shape,
const GatherDimensionNumbers& gather_dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
+ absl::Span<const int64> slice_sizes);
// Helper that validates the given input shape, scatter indices shape, updates
// shape, and scatter dimension numbers that constitute a scatter operation,
@@ -296,7 +293,7 @@ class ShapeInference {
// even in the presence of broadcasting of one of the operands over the other.
static StatusOr<Shape> InferElementwiseBinaryOpShape(
HloOpcode operation, const Shape& lhs, const Shape& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
// Helper for inferring the shape of Clamp ops.
static StatusOr<Shape> InferClampShape(const Shape& min, const Shape& operand,
@@ -324,7 +321,7 @@ class ShapeInference {
// smaller_shape is broadcast to.
static StatusOr<Shape> InferInDimBroadcastShape(
const Shape& smaller_shape, const Shape& larger_shape,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ absl::Span<const int64> broadcast_dimensions);
TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference);
};
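
[Reviewer note] The ArraySlice-to-Span migration in this header is source-compatible at call sites: absl::Span<const int64>, like the old gtl::ArraySlice, binds implicitly to braced initializer lists, vectors, and plain arrays. A minimal standalone sketch (not part of this patch; plain int64_t stands in for XLA's int64 typedef):

#include <cstdint>
#include <vector>
#include "absl/types/span.h"

// Mirrors the new parameter style used by InferSliceShape and friends.
int64_t Product(absl::Span<const int64_t> dims) {
  int64_t p = 1;
  for (int64_t d : dims) p *= d;
  return p;
}

void Demo() {
  std::vector<int64_t> v = {2, 3};
  Product(v);          // a vector converts with no copy
  Product({4, 5, 6});  // so does a braced list
}
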
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index a73fa181cd..864ed43118 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -17,18 +17,17 @@ limitations under the License.
#include <string>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace {
-using ::tensorflow::gtl::ArraySlice;
using ::testing::ContainsRegex;
using ::testing::HasSubstr;
@@ -58,9 +57,9 @@ class ReduceShapeInferenceTest : public ShapeInferenceTest {
// Helper that runs reduce shape inference with the input 'arg' and given
// dimensions to reduce, and checks the inferred shape is as expected. The
// element type here is hard-coded to F32.
- void ExpectInferredReduceShape(
- const Shape& expected_inferred_shape, const Shape& arg,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
+ void ExpectInferredReduceShape(const Shape& expected_inferred_shape,
+ const Shape& arg,
+ absl::Span<const int64> dimensions_to_reduce) {
ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
auto inferred_status = ShapeInference::InferReduceShape(
{&arg, &f32_}, dimensions_to_reduce, to_apply);
@@ -252,7 +251,7 @@ TEST_F(ShapeInferenceTest, ClampBadShapes) {
TEST_F(ShapeInferenceTest, Complex) {
auto complex_shape = [&](const Shape& lhs, const Shape& rhs,
- const tensorflow::gtl::ArraySlice<int64>& bcast) {
+ const absl::Span<const int64>& bcast) {
return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs,
bcast);
};
@@ -420,8 +419,8 @@ TEST_F(ShapeInferenceTest, Convolve) {
dim1->set_padding_high(0);
dim1->set_window_dilation(1);
dim1->set_base_dilation(1);
- auto inferred_status =
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ auto inferred_status = ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
ASSERT_IS_OK(inferred_status.status());
Shape inferred_shape = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
@@ -465,8 +464,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) {
dim1->set_padding_high(1);
dim1->set_window_dilation(2);
dim1->set_base_dilation(1);
- auto inferred_status =
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ auto inferred_status = ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
ASSERT_IS_OK(inferred_status.status());
Shape inferred_shape = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}),
@@ -510,8 +509,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) {
dim1->set_padding_high(1);
dim1->set_window_dilation(1);
dim1->set_base_dilation(2);
- auto inferred_status =
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ auto inferred_status = ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
ASSERT_IS_OK(inferred_status.status());
Shape inferred_shape = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}),
@@ -548,8 +547,8 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
dim1->set_stride(2);
dim1->set_padding_low(1);
dim1->set_padding_high(1);
- auto inferred_status =
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ auto inferred_status = ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
HasSubstr("each dimension exactly once"));
@@ -1654,11 +1653,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGather) {
ShapeInference::InferGatherShape(
matrix_64_48_, s64_vector_32_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
+ /*offset_dims=*/{0},
+ /*collapsed_slice_dims=*/{1},
+ /*start_index_map=*/{1},
/*index_vector_dim=*/1),
- /*window_bounds=*/{64, 1}));
+ /*slice_sizes=*/{64, 1}));
EXPECT_TRUE(
ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32})))
<< ShapeUtil::HumanString(gather_shape);
@@ -1669,11 +1668,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherV2) {
ShapeInference::InferGatherShape(
matrix_64_48_, s64_vector_32_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{1},
- /*elided_window_dims=*/{0},
- /*gather_dims_to_operand_dims=*/{0},
+ /*offset_dims=*/{1},
+ /*collapsed_slice_dims=*/{0},
+ /*start_index_map=*/{0},
/*index_vector_dim=*/1),
- /*window_bounds=*/{1, 48}));
+ /*slice_sizes=*/{1, 48}));
EXPECT_TRUE(
ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48})))
<< ShapeUtil::HumanString(gather_shape);
@@ -1684,11 +1683,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherNd) {
ShapeInference::InferGatherShape(
matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4},
- /*elided_window_dims=*/{0},
- /*gather_dims_to_operand_dims=*/{0},
+ /*offset_dims=*/{4},
+ /*collapsed_slice_dims=*/{0},
+ /*start_index_map=*/{0},
/*index_vector_dim=*/4),
- /*window_bounds=*/{1, 48}));
+ /*slice_sizes=*/{1, 48}));
EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48})))
<< ShapeUtil::HumanString(gather_shape);
@@ -1700,11 +1699,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
EXPECT_TRUE(ShapeUtil::Equal(
gather_shape,
ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26})))
@@ -1717,11 +1716,11 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/2),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
EXPECT_TRUE(ShapeUtil::Equal(
gather_shape,
@@ -1735,11 +1734,11 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/0),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
EXPECT_TRUE(ShapeUtil::Equal(
gather_shape,
@@ -1749,16 +1748,15 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
TEST_F(ScatterGatherShapeInferenceTest, NoOutputGatherDims) {
// This is equivalent to a dynamic slice.
- TF_ASSERT_OK_AND_ASSIGN(
- Shape gather_shape,
- ShapeInference::InferGatherShape(
- f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
- HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0, 1, 2, 3, 4},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
- /*index_vector_dim=*/0),
- /*window_bounds=*/{30, 29, 28, 27, 26}));
+ TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
+ ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*offset_dims=*/{0, 1, 2, 3, 4},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/0),
+ /*slice_sizes=*/{30, 29, 28, 27, 26}));
EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26})))
@@ -1772,11 +1770,11 @@ TEST_F(ScatterGatherShapeInferenceTest, ScalarGatherIndices) {
ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0, 1, 2, 3},
- /*elided_window_dims=*/{0},
- /*gather_dims_to_operand_dims=*/{0},
+ /*offset_dims=*/{0, 1, 2, 3},
+ /*collapsed_slice_dims=*/{0},
+ /*start_index_map=*/{0},
/*index_vector_dim=*/0),
- /*window_bounds=*/{1, 30, 29, 28, 27}));
+ /*slice_sizes=*/{1, 30, 29, 28, 27}));
EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
ShapeUtil::MakeShape(F32, {30, 29, 28, 27})))
@@ -1787,11 +1785,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TupleShapedTensorInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
tuple_shape_, s64_vector_32_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
+ /*offset_dims=*/{0},
+ /*collapsed_slice_dims=*/{1},
+ /*start_index_map=*/{1},
/*index_vector_dim=*/1),
- /*window_bounds=*/{64, 1});
+ /*slice_sizes=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
HasSubstr("Expected array argument for input"))
@@ -1802,11 +1800,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
s64_vector_32_, tuple_shape_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
+ /*offset_dims=*/{0},
+ /*collapsed_slice_dims=*/{1},
+ /*start_index_map=*/{1},
/*index_vector_dim=*/0),
- /*window_bounds=*/{64, 1});
+ /*slice_sizes=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
HasSubstr("Expected array argument for gather indices"))
@@ -1817,11 +1815,11 @@ TEST_F(ScatterGatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
s64_vector_32_, vector_32_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{0},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{1},
+ /*offset_dims=*/{0},
+ /*collapsed_slice_dims=*/{1},
+ /*start_index_map=*/{1},
/*index_vector_dim=*/0),
- /*window_bounds=*/{64, 1});
+ /*slice_sizes=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
HasSubstr("Gather indices parameter must be an integral tensor"))
@@ -1833,11 +1831,11 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 8, 7},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 8, 7},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
statusor.status().error_message(),
@@ -1850,11 +1848,11 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 7},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 7},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
statusor.status().error_message(),
@@ -1867,14 +1865,14 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 99, 100, 101},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 99, 100, 101},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Window index 2 in gather op is out of bounds"))
+ HasSubstr("Offset dimension 2 in gather op is out of bounds"))
<< statusor.status();
}
@@ -1883,14 +1881,14 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 9},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 9},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Window index 4 in gather op is out of bounds"))
+ HasSubstr("Offset dimension 4 in gather op is out of bounds"))
<< statusor.status();
}
@@ -1899,16 +1897,16 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{4},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{4},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
statusor.status().error_message(),
- HasSubstr("All components of the window index in a gather op must either "
- "be a output window index or explicitly elided"))
+ HasSubstr("All components of the offset index in a gather op must either "
+ "be a offset dimension or explicitly collapsed"))
<< statusor.status();
}
@@ -1917,14 +1915,14 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{0, 1, 2, 3, 19},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{0, 1, 2, 3, 19},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Invalid elided_window_dims set in gather op; valid "
+ HasSubstr("Invalid collapsed_slice_dims set in gather op; valid "
"range is [0, 5), got: 19"))
<< statusor.status();
}
@@ -1934,16 +1932,15 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{0, 1, 2, 3, 3},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{0, 1, 2, 3, 3},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
- EXPECT_THAT(
- statusor.status().error_message(),
- HasSubstr(
- "Repeated dimensions not allowed in elided_window_dims in gather op"))
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Repeated dimensions not allowed in "
+ "collapsed_slice_dims in gather op"))
<< statusor.status();
}
@@ -1952,17 +1949,16 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
- EXPECT_THAT(
- statusor.status().error_message(),
- HasSubstr("Gather op has 4 elements in gather_dims_to_operand_dims and "
- "the bound of dimension index_vector_dim=4 of "
- "gather_indices is 5. These two numbers must be equal."))
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Gather op has 4 elements in start_index_map and "
+ "the bound of dimension index_vector_dim=4 of "
+ "start_indices is 5. These two numbers must be equal."))
<< statusor.status();
}
@@ -1971,16 +1967,14 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 7},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
- EXPECT_THAT(
- statusor.status().error_message(),
- HasSubstr("Invalid gather_dims_to_operand_dims mapping; domain is "
- "[0, 5), got: 4->7"))
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Invalid start_index_map; domain is [0, 5), got: 4->7"))
<< statusor.status();
}
@@ -1989,16 +1983,15 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 3},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
statusor.status().error_message(),
- HasSubstr(
- "Repeated dimensions are not allowed in gather_dims_to_operand_dims"))
+ HasSubstr("Repeated dimensions are not allowed in start_index_map"))
<< statusor.status();
}
@@ -2007,14 +2000,14 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{2, 1},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{2, 1},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{1, 1, 28, 27, 26});
+ /*slice_sizes=*/{1, 1, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("elided_window_dims in gather op must be sorted"))
+ HasSubstr("collapsed_slice_dims in gather op must be sorted"))
<< statusor.status();
}
@@ -2023,15 +2016,15 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7},
- /*elided_window_dims=*/{2},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7},
+ /*collapsed_slice_dims=*/{2},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 1, 300, 26});
+ /*slice_sizes=*/{30, 29, 1, 300, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Window bound at index 3 in gather op is out of range, "
- "must be within [0, 48), got 300"))
+ HasSubstr("Slice size at index 3 in gather op is out of range, "
+ "must be within [0, 48), got 300."))
<< statusor.status();
}
@@ -2040,16 +2033,15 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 26});
+ /*slice_sizes=*/{30, 29, 28, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
statusor.status().error_message(),
- HasSubstr(
- "Gather op must have one window bound for every input dimension"))
+ HasSubstr("Gather op must have one slice size for every input dimension"))
<< statusor.status();
}
@@ -2058,15 +2050,15 @@ TEST_F(ScatterGatherShapeInferenceTest,
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7},
- /*elided_window_dims=*/{1},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7},
+ /*collapsed_slice_dims=*/{1},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/4),
- /*window_bounds=*/{30, 29, 28, 26, 20});
+ /*slice_sizes=*/{30, 29, 28, 26, 20});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Gather op can only elide window indices with bound 1, "
- "but bound is 29 for index 1 at position 0"))
+ HasSubstr("Gather op can only collapse slice dims with bound 1, "
+ "but bound is 29 for index 1 at position 0."))
<< statusor.status();
}
@@ -2074,16 +2066,16 @@ TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
HloGatherInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/{4, 5, 6, 7, 8},
- /*elided_window_dims=*/{},
- /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*offset_dims=*/{4, 5, 6, 7, 8},
+ /*collapsed_slice_dims=*/{},
+ /*start_index_map=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/32),
- /*window_bounds=*/{30, 29, 28, 27, 26});
+ /*slice_sizes=*/{30, 29, 28, 27, 26});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
HasSubstr("Gather index leaf dimension must be within [0, "
- "rank(gather_indices) + 1)"))
+ "rank(start_indices) + 1)"))
<< statusor.status();
}
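
[Reviewer note] Besides the Span migration, this test file tracks the gather terminology rename: output_window_dims -> offset_dims, elided_window_dims -> collapsed_slice_dims, gather_dims_to_operand_dims -> start_index_map, window_bounds -> slice_sizes, and gather_indices -> start_indices. A sketch of the updated calling convention, mirroring the TensorFlowGather case above:

// Gather one 64-element column per index: operand f32[64,48], indices s64[32].
Shape operand = ShapeUtil::MakeShape(F32, {64, 48});
Shape indices = ShapeUtil::MakeShape(S64, {32});
StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
    operand, indices,
    HloGatherInstruction::MakeGatherDimNumbers(
        /*offset_dims=*/{0},
        /*collapsed_slice_dims=*/{1},
        /*start_index_map=*/{1},
        /*index_vector_dim=*/1),
    /*slice_sizes=*/{64, 1});
// Expected inferred shape: f32[64,32], as the test asserts.
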
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc
index 7d7dcac10b..921a984589 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.cc
+++ b/tensorflow/compiler/xla/service/shaped_buffer.cc
@@ -18,20 +18,19 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
-using ::tensorflow::strings::Appendf;
-
ShapedBuffer::ShapedBuffer(const Shape& on_host_shape,
const Shape& on_device_shape,
const se::Platform* platform, int device_ordinal)
@@ -76,7 +75,7 @@ void ShapedBuffer::clear() {
}
string ShapedBuffer::ToString() const {
- string s = tensorflow::strings::StrCat(
+ string s = absl::StrCat(
"ShapedBuffer(", platform_->Name(), ":", device_ordinal(),
"), on-host shape=" + ShapeUtil::HumanStringWithLayout(on_host_shape()),
", on-device shape=" +
@@ -92,9 +91,9 @@ string ShapedBuffer::ToString() const {
shape_str = ShapeUtil::HumanStringWithLayout(subshape);
}
const se::DeviceMemoryBase& memory = buffer(index);
- Appendf(&s, " %s%p (%lld bytes) : %s\n",
- string(index.size() * 2, ' ').c_str(), memory.opaque(),
- memory.size(), shape_str.c_str());
+ absl::StrAppendFormat(&s, " %s%p (%d bytes) : %s\n",
+ string(index.size() * 2, ' '), memory.opaque(),
+ memory.size(), shape_str);
});
return s;
}
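
[Reviewer note] absl::StrAppendFormat checks its format string at compile time, which is why the old %lld width specifiers become plain %d (absl deduces the integer width) and std::string arguments no longer need .c_str(). A standalone sketch of the pattern:

#include <cstdint>
#include <string>
#include "absl/strings/str_format.h"

void AppendLine(std::string* s) {
  int64_t bytes = 4096;
  std::string shape_str = "f32[32,32]";
  // %d accepts any integer width; %s binds to std::string directly.
  absl::StrAppendFormat(s, "  (%d bytes) : %s\n", bytes, shape_str);
}
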
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h
index 905a7e82e6..e1d26da4a2 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.h
+++ b/tensorflow/compiler/xla/service/shaped_buffer.h
@@ -20,11 +20,11 @@ limitations under the License.
#include <ostream>
#include <string>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
@@ -84,6 +84,14 @@ class ShapedBuffer {
*buffers_.mutable_element(index) = buffer;
}
+ // Sets all buffers.
+ //
+ // Precondition: buffers.shape == on_device_shape_
+ void set_buffers(ShapeTree<se::DeviceMemoryBase> buffers) {
+ CHECK(ShapeUtil::Equal(buffers.shape(), on_device_shape_));
+ buffers_ = std::move(buffers);
+ }
+
// Returns the underlying ShapeTree containing all the device addresses in the
// ShapedBuffer.
const ShapeTree<se::DeviceMemoryBase>& buffers() const { return buffers_; }
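
[Reviewer note] The new set_buffers replaces the whole device-address tree in one move; the CHECK enforces the documented precondition that the incoming tree's shape equals on_device_shape(). A hypothetical caller (AllocateFor is invented for illustration):

// Build a fully-populated address tree, then move it into the buffer.
ShapeTree<se::DeviceMemoryBase> addresses(buffer.on_device_shape());
for (auto& index_and_mem : addresses) {
  index_and_mem.second = AllocateFor(index_and_mem.first);  // hypothetical
}
buffer.set_buffers(std::move(addresses));  // CHECK-fails on shape mismatch
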
diff --git a/tensorflow/compiler/xla/service/shaped_buffer_test.cc b/tensorflow/compiler/xla/service/shaped_buffer_test.cc
index 0fc2436679..d69e6362e9 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer_test.cc
+++ b/tensorflow/compiler/xla/service/shaped_buffer_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -34,7 +35,7 @@ TEST(ShapedBufferTest, ScopedShapeBufferAsShapedBufferB71629047) {
xla::StreamExecutorMemoryAllocator allocator(platform, executors);
const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {});
const int kDeviceOrdinal = 0;
- auto scoped_buffer = tensorflow::MakeUnique<xla::ScopedShapedBuffer>(
+ auto scoped_buffer = absl::make_unique<xla::ScopedShapedBuffer>(
shape, shape, &allocator, kDeviceOrdinal);
std::unique_ptr<xla::ShapedBuffer> buffer = std::move(scoped_buffer);
buffer = nullptr;
diff --git a/tensorflow/compiler/xla/service/source_map_util.cc b/tensorflow/compiler/xla/service/source_map_util.cc
index 8cbaac7b37..dd53c7531b 100644
--- a/tensorflow/compiler/xla/service/source_map_util.cc
+++ b/tensorflow/compiler/xla/service/source_map_util.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/source_map_util.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
@@ -26,11 +27,10 @@ Status InvalidParameterArgumentV(const OpMetadata& op_metadata,
string message;
tensorflow::strings::Appendv(&message, format, args);
if (!op_metadata.source_file().empty()) {
- tensorflow::strings::Appendf(&message, " (%s:%d)",
- op_metadata.source_file().c_str(),
- op_metadata.source_line());
+ absl::StrAppendFormat(&message, " (%s:%d)", op_metadata.source_file(),
+ op_metadata.source_line());
}
- return InvalidArgument("%s", message.c_str());
+ return InvalidArgument("%s", message);
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/source_map_util.h b/tensorflow/compiler/xla/service/source_map_util.h
index 18e2651abb..c5a7e17cb4 100644
--- a/tensorflow/compiler/xla/service/source_map_util.h
+++ b/tensorflow/compiler/xla/service/source_map_util.h
@@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_
-#define TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SOURCE_MAP_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_SOURCE_MAP_UTIL_H_
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/core/platform/macros.h"
@@ -24,23 +25,40 @@ namespace xla {
namespace source_map_util {
// Creates an INVALID_ARGUMENT status with the given format string.
+template <typename... Args>
+Status InvalidParameterArgument(const OpMetadata& op_metadata,
+ const absl::FormatSpec<Args...>& format,
+ const Args&... args) {
+ string message = absl::StrFormat(format, args...);
+ if (!op_metadata.source_file().empty()) {
+ absl::StrAppendFormat(&message, " (%s:%d)", op_metadata.source_file(),
+ op_metadata.source_line());
+ }
+ return InvalidArgument("%s", message);
+}
+
+// Creates an INVALID_ARGUMENT status with the given format string.
//
// Also, attempts to extract the OpMetadata for parameter_number on executable
// and append it to the status message for source mapping to user code.
//
// executable may be nullptr, but parameter_number should not be out of bounds
// or a CHECK-failure may occur.
+template <typename... Args>
Status InvalidParameterArgument(Executable* executable, int parameter_number,
- const char* format, ...)
- TF_PRINTF_ATTRIBUTE(3, 4);
-
-// As above, but takes the parameter metadata directly instead of extracting it
-// from the executable.
-Status InvalidParameterArgument(const OpMetadata& op_metadata,
- const char* format, ...)
- TF_PRINTF_ATTRIBUTE(2, 3);
+ const absl::FormatSpec<Args...>& format,
+ const Args&... args) {
+ if (executable != nullptr && executable->has_module()) {
+ const HloModule& module = executable->module();
+ const HloComputation& computation = *module.entry_computation();
+ HloInstruction* param = computation.parameter_instruction(parameter_number);
+ const OpMetadata& metadata = param->metadata();
+ return InvalidParameterArgument(metadata, format, args...);
+ }
+ return InvalidArgument(format, args...);
+}
} // namespace source_map_util
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SOURCE_MAP_UTIL_H_
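
[Reviewer note] Replacing varargs + TF_PRINTF_ATTRIBUTE with absl::FormatSpec templates moves format/argument checking to compile time and lets callers pass std::string to %s directly. A sketch of a call against the new signature (executable and shape assumed in scope):

Status status = source_map_util::InvalidParameterArgument(
    executable, /*parameter_number=*/0,
    "parameter %d has mismatched shape %s", 0,
    ShapeUtil::HumanString(shape));  // no .c_str() needed
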
diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc
index c0582c6a2d..5d1cd1c442 100644
--- a/tensorflow/compiler/xla/service/stream_pool.cc
+++ b/tensorflow/compiler/xla/service/stream_pool.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/stream_pool.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -35,7 +35,7 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
if (!stream) {
// Create a new stream.
- stream = MakeUnique<se::Stream>(executor);
+ stream = absl::make_unique<se::Stream>(executor);
stream->Init();
VLOG(1) << stream->DebugStreamPointers()
<< " StreamPool created new stream";
diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc
index 32d368a904..a21e586efa 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/transfer_manager.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -27,7 +29,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/notification.h"
-using ::tensorflow::strings::StrCat;
+using absl::StrCat;
namespace xla {
/* static */ tensorflow::mutex
@@ -40,9 +42,9 @@ TransferManager::GetPlatformTransferManagers() {
return r;
}
-StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
+StatusOr<Literal> TransferManager::TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer) {
- StatusOr<std::unique_ptr<Literal>> ret;
+ StatusOr<Literal> ret;
se::Stream* substream = stream->GetOrCreateSubStream();
substream->ThenWaitFor(stream);
@@ -61,7 +63,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
if (!s.ok()) {
return s;
}
- return MakeUnique<Literal>(std::move(literal));
+ return std::move(literal);
}
Status TransferManager::TransferLiteralFromDevice(
@@ -97,10 +99,10 @@ Status TransferManager::TransferLiteralToDevice(
return substream->BlockHostUntilDone();
}
-StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
+StatusOr<Literal> TransferManager::TransferArrayFromDevice(
se::Stream* stream, const Shape& shape,
const se::DeviceMemoryBase& source) {
- StatusOr<std::unique_ptr<Literal>> ret;
+ StatusOr<Literal> ret;
// Implement the synchronous version by waiting on the asynchronous version.
// Use a substream so that if we are called from a HostCallback we don't
// deadlock.
@@ -120,7 +122,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
if (!s.ok()) {
return s;
}
- return MakeUnique<Literal>(std::move(literal));
+ return std::move(literal);
}
Status TransferManager::TransferArrayToDevice(
@@ -147,7 +149,7 @@ Status TransferManager::TransferArrayToDeviceAsync(
if (dest.size() < GetByteSizeRequirement(on_device_shape)) {
return FailedPrecondition(
"Allocation on device not large enough for array: "
- "%lld < %lld",
+ "%d < %d",
dest.size(), GetByteSizeRequirement(on_device_shape));
}
ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape,
@@ -164,12 +166,12 @@ void TransferManager::TransferArrayFromDevice(
auto error = StrCat("Shape ", ShapeUtil::HumanString(shape),
" has a differently shaped representation on-device: ",
ShapeUtil::HumanString(HostShapeToDeviceShape(shape)));
- return done(FailedPrecondition("%s", error.c_str()));
+ return done(FailedPrecondition("%s", error));
}
if (source.size() < GetByteSizeRequirement(shape)) {
return done(
FailedPrecondition("Allocation on device not large enough for array: "
- "%lld < %lld",
+ "%d < %d",
source.size(), GetByteSizeRequirement(shape)));
}
ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape,
@@ -201,7 +203,7 @@ void TransferManager::TransferArrayFromDevice(
return NotFound(
"could not find registered transfer manager for platform %s -- check "
"target linkage",
- platform->Name().c_str());
+ platform->Name());
}
if (it->second.manager == nullptr) {
@@ -252,7 +254,7 @@ Status TransferManager::TransferBufferFromDevice(
if (source.size() < size) {
return FailedPrecondition(
"Source allocation on device not large enough for data tranfer: "
- "%lld < %lld",
+ "%d < %d",
source.size(), size);
}
stream->ThenMemcpy(destination, source, size);
@@ -265,7 +267,7 @@ Status TransferManager::TransferBufferToDevice(
if (destination->size() < size) {
return FailedPrecondition(
"Destination allocation on device not large enough for data tranfer: "
- "%lld < %lld",
+ "%d < %d",
destination->size(), size);
}
stream->ThenMemcpy(destination, source, size);
@@ -276,9 +278,8 @@ StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer(
const Shape& on_host_shape, DeviceMemoryAllocator* allocator,
int device_ordinal) {
if (!LayoutUtil::HasLayout(on_host_shape)) {
- return InvalidArgument(
- "Shape must have a layout: %s",
- ShapeUtil::HumanStringWithLayout(on_host_shape).c_str());
+ return InvalidArgument("Shape must have a layout: %s",
+ ShapeUtil::HumanStringWithLayout(on_host_shape));
}
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape));
const Shape on_device_shape = HostShapeToDeviceShape(on_host_shape);
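
[Reviewer note] Returning StatusOr<Literal> by value instead of StatusOr<std::unique_ptr<Literal>> leans on Literal's move support and removes a heap indirection. Callers update roughly as follows (stream and device_buffer assumed in scope):

// Before: StatusOr<std::unique_ptr<Literal>>, then literal->shape().
// After: value semantics throughout.
TF_ASSIGN_OR_RETURN(
    Literal literal,
    transfer_manager->TransferLiteralFromDevice(stream, device_buffer));
const Shape& host_shape = literal.shape();
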
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index 475a2e5c14..f952e64af2 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -20,12 +20,12 @@ limitations under the License.
#include <set>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -57,7 +57,7 @@ class TransferManager {
// without waiting for any other operation on a stream to complete.
//
// This function should be avoided in favor of the asynchronous version below.
- virtual StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
+ virtual StatusOr<Literal> TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer);
virtual Status TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer,
@@ -113,9 +113,9 @@ class TransferManager {
Status TransferArrayToDeviceAsync(se::Stream* stream,
const LiteralSlice& literal,
const se::DeviceMemoryBase& dest);
- StatusOr<std::unique_ptr<Literal>> TransferArrayFromDevice(
- se::Stream* stream, const Shape& shape,
- const se::DeviceMemoryBase& source);
+ StatusOr<Literal> TransferArrayFromDevice(se::Stream* stream,
+ const Shape& shape,
+ const se::DeviceMemoryBase& source);
// Transfers the given literal into the Infeed interface of the device,
// using the given executor.
@@ -130,7 +130,7 @@ class TransferManager {
// Resets the devices associated with this transfer manager.
virtual Status ResetDevices(
- tensorflow::gtl::ArraySlice<se::StreamExecutor*> executor) = 0;
+ absl::Span<se::StreamExecutor* const> executor) = 0;
// Given an allocated ShapedBuffer, constructs the tuple index table(s) in
// each buffer of the given ShapedBuffer corresponding to tuple shapes. If the
@@ -152,6 +152,26 @@ class TransferManager {
const Shape& on_host_shape, DeviceMemoryAllocator* allocator,
int device_ordinal);
+ // The given ShapedBuffer holds a handle to allocated memory, but it is not
+ // in the general case legal to immediately copy or access that allocated
+ // memory because queued operations on the device may alias that memory.
+ // Memory ordering is enforced by the Stream's happens-before relationship
+ // which allows eager deallocation and reallocation of buffers host-side even
+ // if the device hasn't finished with them.
+ //
+ // In certain cases, it can be known that a ShapedBuffer does not have any
+ // conflicting accesses on the device and thus is eligible to be accessed at
+ // any time from the host.
+ //
+ // This function returns true if device_buffer can be accessed immediately
+ // without waiting for the Stream's previously enqueued items. This only
+ // returns true if all subbuffers in device_buffer can be accessed
+ // immediately.
+ virtual bool CanShapedBufferBeAccessedNow(
+ se::StreamExecutor* executor, const ShapedBuffer& device_buffer) const {
+ return false;
+ }
+
/////
// The TransferManager class also serves as a point to register objects for
// the various platforms.
@@ -191,8 +211,7 @@ class TransferManager {
// to construct a tuple index table in the platform-specific tuple
// representation.
virtual Status WriteSingleTupleIndexTable(
- se::Stream* stream,
- tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> elements,
+ se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
const Shape& shape, se::DeviceMemoryBase* region) = 0;
private:
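
[Reviewer note] The base-class CanShapedBufferBeAccessedNow conservatively answers false; a backend that can prove all enqueued work has completed may override it to unlock immediate host access. A hypothetical override (class name invented; other required overrides omitted):

class MySyncTransferManager : public TransferManager {
  bool CanShapedBufferBeAccessedNow(
      se::StreamExecutor* executor,
      const ShapedBuffer& device_buffer) const override {
    // This hypothetical backend runs synchronously, so every prior
    // operation touching the buffer has already completed.
    return true;
  }
  // ...remaining TransferManager overrides elided in this sketch...
};
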
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc
index 49e1f87319..7c1f4b5cc6 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding.cc
@@ -108,7 +108,7 @@ Status FoldTransposeIntoDot(InstructionOperandsPair pair) {
}
std::unique_ptr<HloInstruction> new_dot = HloInstruction::CreateDot(
- dot->shape(), new_lhs, new_rhs, new_dim_numbers);
+ dot->shape(), new_lhs, new_rhs, new_dim_numbers, dot->precision_config());
return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot));
}
@@ -177,7 +177,8 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
}
auto new_conv = HloInstruction::CreateConvolve(
- convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums);
+ convolution.shape(), new_lhs, new_rhs, convolution.feature_group_count(),
+ convolution.window(), new_dnums, convolution.precision_config());
TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
&convolution, std::move(new_conv)));
diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h
index 71e8446452..3e5aa2db60 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.h
+++ b/tensorflow/compiler/xla/service/transpose_folding.h
@@ -49,7 +49,7 @@ class TransposeFolding : public HloPassInterface {
explicit TransposeFolding(
TransposableGemmOperandsFn transposable_gemm_operands,
TransposableConvOperandsFn transposable_conv_operands);
- tensorflow::StringPiece name() const override { return "transpose-folding"; }
+ absl::string_view name() const override { return "transpose-folding"; }
StatusOr<bool> Run(HloModule* module) override;
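
[Reviewer note] HLO pass names migrate from tensorflow::StringPiece to absl::string_view; derived passes only change the return type of name(). Hypothetical pass for illustration:

class MyPass : public HloPassInterface {
 public:
  absl::string_view name() const override { return "my-pass"; }  // was StringPiece
  StatusOr<bool> Run(HloModule* module) override;
};
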
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc
index 58f767e913..79b5c09abb 100644
--- a/tensorflow/compiler/xla/service/transpose_folding_test.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc
@@ -240,10 +240,12 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) {
transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
- x->shape(), transpose_y->shape(), window, dnums);
+ x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window,
+ dnums);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- conv_shape.ValueOrDie(), x, transpose_y, window, dnums));
+ conv_shape.ValueOrDie(), x, transpose_y,
+ /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
@@ -293,10 +295,12 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) {
transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
- x->shape(), transpose_y->shape(), window, dnums);
+ x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window,
+ dnums);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- conv_shape.ValueOrDie(), x, transpose_y, window, dnums));
+ conv_shape.ValueOrDie(), x, transpose_y,
+ /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
@@ -351,10 +355,12 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
- transpose_x->shape(), y->shape(), window, dnums);
+ transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window,
+ dnums);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
+ conv_shape.ValueOrDie(), transpose_x, y,
+ /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
@@ -415,10 +421,12 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) {
dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
- transpose_x->shape(), y->shape(), window, dnums);
+ transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window,
+ dnums);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
+ conv_shape.ValueOrDie(), transpose_x, y,
+ /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
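
[Reviewer note] CreateConvolve and CreateDot now take feature_group_count and a PrecisionConfig explicitly; these tests build a default config via the DefaultPrecisionConfig(2) helper. The equivalent manual construction, as used in the FusedDotAdd hunk later in this diff, is:

PrecisionConfig precision_config;
precision_config.mutable_operand_precision()->Resize(
    /*new_size=*/2, PrecisionConfig::DEFAULT);  // one entry per operand
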
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index 0447807a41..6fed7c76d0 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -19,6 +19,10 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -26,17 +30,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
string BufferAlias::ToString() const {
- return tensorflow::strings::StrCat("BufferAlias(", instruction_->name(), "[",
- tensorflow::str_util::Join(index_, ","),
- "])");
+ return absl::StrCat("BufferAlias(", instruction_->name(), "[",
+ absl::StrJoin(index_, ","), "])");
}
std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) {
@@ -360,7 +360,7 @@ Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) {
}
Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) {
- tensorflow::gtl::ArraySlice<HloInstruction*> operands(tuple->operands());
+ absl::Span<HloInstruction* const> operands(tuple->operands());
PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple);
points_to_set.AddPointedToBuffer(
logical_buffer_analysis_->GetBuffer(tuple, /*index=*/{}),
@@ -441,7 +441,7 @@ PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet(
PerInstruction* pi = PerInst(instruction);
CHECK(pi->points_to_set == nullptr)
<< "instruction should not have been present in the map.";
- auto set = MakeUnique<PointsToSet>(&instruction->shape());
+ auto set = absl::make_unique<PointsToSet>(&instruction->shape());
pi->points_to_set = std::move(set);
// Return *set using the iterator returned by emplace.
return *pi->points_to_set;
@@ -462,21 +462,20 @@ Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const {
return FailedPrecondition(
"LogicalBuffer %s is ill-defined: instruction %s does not define a "
"buffer at that index",
- buffer.ToString().c_str(), buffer.instruction()->name().c_str());
+ buffer.ToString(), buffer.instruction()->name());
}
}
if (buffer.id() < 0 ||
buffer.id() >= logical_buffer_analysis_->num_logical_buffers()) {
- return FailedPrecondition(
- "LogicalBuffer %s is ill-defined: invalid id %lld",
- buffer.ToString().c_str(), buffer.id());
+ return FailedPrecondition("LogicalBuffer %s is ill-defined: invalid id %d",
+ buffer.ToString(), buffer.id());
}
if (GetBuffer(buffer.id()).instruction() != buffer.instruction() ||
GetBuffer(buffer.id()).index() != buffer.index()) {
return FailedPrecondition(
"LogicalBuffer %s is ill-defined: buffer with same id differs: %s",
- buffer.ToString().c_str(), GetBuffer(buffer.id()).ToString().c_str());
+ buffer.ToString(), GetBuffer(buffer.id()).ToString());
}
return Status::OK();
@@ -495,8 +494,7 @@ StatusOr<const LogicalBuffer*> TuplePointsToAnalysis::GetBufferDefinedAt(
if (buffers.size() != 1 || buffers[0]->instruction() != instruction) {
return FailedPrecondition(
"instruction %s does not define buffer at index {%s}",
- instruction->name().c_str(),
- tensorflow::str_util::Join(index, ",").c_str());
+ instruction->name(), absl::StrJoin(index, ","));
}
return buffers[0];
}
@@ -557,13 +555,12 @@ PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet(
}
string TuplePointsToAnalysis::ToString() const {
- string output = tensorflow::strings::Printf(
- "TuplePointsToSet for module %s:\n", module_->name().c_str());
+ string output =
+ absl::StrFormat("TuplePointsToSet for module %s:\n", module_->name());
for (const auto* computation : module_->MakeNonfusionComputations()) {
const char* entry =
computation == module_->entry_computation() ? "entry " : "";
- tensorflow::strings::StrAppend(&output, entry, "computation ",
- computation->name(), ":\n");
+ absl::StrAppend(&output, entry, "computation ", computation->name(), ":\n");
for (const HloInstruction* instruction :
computation->MakeInstructionPostOrder()) {
InstructionToString(instruction, &output);
@@ -575,12 +572,11 @@ string TuplePointsToAnalysis::ToString() const {
}
}
- tensorflow::strings::StrAppend(&output, "LogicalBuffers:\n");
+ absl::StrAppend(&output, "LogicalBuffers:\n");
for (const auto& b : logical_buffer_analysis_->logical_buffers()) {
- tensorflow::strings::StrAppend(&output, " buffer ", b->ToString(), ":\n");
+ absl::StrAppend(&output, " buffer ", b->ToString(), ":\n");
for (const BufferAlias& alias : logical_buffer_aliases_.at(b->id())) {
- tensorflow::strings::StrAppend(&output, " alias ", alias.ToString(),
- "\n");
+ absl::StrAppend(&output, " alias ", alias.ToString(), "\n");
}
}
return output;
@@ -589,20 +585,18 @@ string TuplePointsToAnalysis::ToString() const {
void TuplePointsToAnalysis::InstructionToString(
const HloInstruction* instruction, string* output) const {
const string prefix = instruction->IsFused() ? " " : "";
- tensorflow::strings::StrAppend(output, prefix, " instruction ",
- instruction->ToShortString(), ":\n");
+ absl::StrAppend(output, prefix, " instruction ",
+ instruction->ToShortString(), ":\n");
const PointsToSet& points_to_set = GetPointsToSet(instruction);
points_to_set.ForEachElement([&prefix, &output](
const ShapeIndex& index,
const PointsToSet::BufferList& points_to) {
- tensorflow::strings::StrAppend(
- output, prefix, " {", tensorflow::str_util::Join(index, ","), "}: ",
- tensorflow::str_util::Join(
- points_to, ", ",
- [](string* out, const LogicalBuffer* source) {
- out->append(source->ToString());
- }),
- "\n");
+ absl::StrAppend(output, prefix, " {", absl::StrJoin(index, ","), "}: ",
+ absl::StrJoin(points_to, ", ",
+ [](string* out, const LogicalBuffer* source) {
+ out->append(source->ToString());
+ }),
+ "\n");
});
}
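
[Reviewer note] absl::StrJoin takes the same (container, separator, formatter) triple as the old str_util::Join, so the lambda formatters above port one-for-one. Standalone sketch:

#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"

std::string Squares() {
  std::vector<int> v = {1, 2, 3};
  return absl::StrJoin(v, ", ", [](std::string* out, int n) {
    absl::StrAppend(out, n * n);
  });  // "1, 4, 9"
}
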
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
index 686bb05328..a9e8a51e09 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
@@ -23,6 +23,8 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/container/inlined_vector.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -33,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
@@ -109,7 +110,7 @@ class PointsToSet {
// Add a tuple source instruction for the given index.
void add_tuple_source(const ShapeIndex& index, HloInstruction* tuple);
- using BufferList = tensorflow::gtl::InlinedVector<const LogicalBuffer*, 1>;
+ using BufferList = absl::InlinedVector<const LogicalBuffer*, 1>;
// Return the list of logical buffers for the subshape at index.
const BufferList& element(const ShapeIndex& index) const {
@@ -203,7 +204,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
// logical buffer The buffer alias set is the inverse of the points-to set.
// That is, LogicalBuffer B is in the points-to set of instruction I at index
// N iff instruction I, index N is a BufferAlias of B.
- using BufferAliasVector = tensorflow::gtl::InlinedVector<BufferAlias, 1>;
+ using BufferAliasVector = absl::InlinedVector<BufferAlias, 1>;
const BufferAliasVector& GetBufferAliases(const LogicalBuffer& buffer) const;
// Returns the number of logical buffers in the module
@@ -226,8 +227,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
// instructions produce a single buffer (the top-level buffer), some produce
// no buffers (eg bitcast), and some produce more than one buffer (eg,
// tuple-shaped parameters).
- using BufferDefinitionVector =
- tensorflow::gtl::InlinedVector<const LogicalBuffer*, 1>;
+ using BufferDefinitionVector = absl::InlinedVector<const LogicalBuffer*, 1>;
const BufferDefinitionVector& GetBuffersDefinedByInstruction(
const HloInstruction* instruction) const;
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index 10d382e8ab..e9a07b14ed 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -72,9 +72,8 @@ class TuplePointsToAnalysisTest : public HloTestBase {
// Checks that the given points-to set contains exactly (unordered) the given
// LogicalBuffers.
- void ExpectHasBuffers(
- const PointsToSet::BufferList& points_to_set,
- tensorflow::gtl::ArraySlice<const LogicalBuffer*> buffers) {
+ void ExpectHasBuffers(const PointsToSet::BufferList& points_to_set,
+ absl::Span<const LogicalBuffer* const> buffers) {
std::vector<const LogicalBuffer*> vec(buffers.begin(), buffers.end());
EXPECT_THAT(points_to_set, UnorderedElementsAreArray(vec));
}
@@ -83,7 +82,7 @@ class TuplePointsToAnalysisTest : public HloTestBase {
// top-level buffers of the given instructions.
void ExpectHasTopLevelBuffers(
const PointsToSet::BufferList& points_to_set,
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
+ absl::Span<HloInstruction* const> instructions) {
PointsToSet::BufferList buffers;
for (auto instruction : instructions) {
buffers.push_back(GetBuffer(instruction, /*index=*/{}));
@@ -94,7 +93,7 @@ class TuplePointsToAnalysisTest : public HloTestBase {
// Overload which takes a set instead of a vector.
void ExpectHasTopLevelBuffers(
const PointsToSet::BufferSet& points_to_set,
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
+ absl::Span<HloInstruction* const> instructions) {
ExpectHasTopLevelBuffers(
PointsToSet::BufferList(points_to_set.begin(), points_to_set.end()),
instructions);
@@ -104,8 +103,7 @@ class TuplePointsToAnalysisTest : public HloTestBase {
// aliases which are exactly (unordered) the given instruction/index pairs.
void ExpectHasBufferAliases(
const HloInstruction* instruction, const ShapeIndex& index,
- tensorflow::gtl::ArraySlice<std::pair<HloInstruction*, ShapeIndex>>
- expected) {
+ absl::Span<const std::pair<HloInstruction*, ShapeIndex>> expected) {
const LogicalBuffer* buffer =
points_to_analysis_->GetBufferDefinedAt(instruction, index)
.ValueOrDie();
@@ -557,10 +555,10 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) {
// Construct a tuple constant and kCopy it. Verify the points-to set of the
// copy correctly points into the nested elements of the constant.
auto builder = HloComputation::Builder(TestName());
- auto tuple_constant = builder.AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
- LiteralUtil::CreateR1<float>({2.0, 42}).get()})));
+ Literal elements[] = {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}),
+ LiteralUtil::CreateR1<float>({2.0, 42})};
+ auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::MakeTuple({&elements[0], &elements[1]})));
auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
tuple_constant->shape(), HloOpcode::kCopy, tuple_constant));
@@ -1066,8 +1064,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfig precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ /*new_size=*/2, PrecisionConfig::DEFAULT);
auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(data_shape, a, b, dot_dnums));
+ HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config));
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
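Note where const lands in these signatures: tensorflow::gtl::ArraySlice<HloInstruction*> becomes absl::Span<HloInstruction* const>, a read-only view over still-mutable pointers. A minimal sketch; SumDerefs is a hypothetical helper, not from the patch:

#include <vector>
#include "absl/types/span.h"

// The span itself is immutable (no element may be reassigned), but the ints
// the pointers refer to remain mutable, matching ArraySlice<T*> semantics.
int SumDerefs(absl::Span<int* const> ptrs) {
  int total = 0;
  for (int* p : ptrs) total += *p;
  return total;
}

int main() {
  int a = 1, b = 2;
  std::vector<int*> v = {&a, &b};
  return SumDerefs(v) == 3 ? 0 : 1;  // Span converts implicitly from vector
}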
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h
index 7509501883..8c91d6e69d 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier.h
+++ b/tensorflow/compiler/xla/service/tuple_simplifier.h
@@ -30,7 +30,7 @@ class TupleSimplifier : public HloPassInterface {
TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {}
explicit TupleSimplifier(bool exclude_entry_computation);
~TupleSimplifier() override {}
- tensorflow::StringPiece name() const override { return "tuple-simplifier"; }
+ absl::string_view name() const override { return "tuple-simplifier"; }
// Run tuple simplification on the given computation. Returns whether the
// computation was changed.
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
index 39b693872d..516754e211 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
-class TupleSimplifierTest : public HloTestBase {
+class TupleSimplifierTest : public HloVerifiedTestBase {
protected:
void Run(HloModule* module, bool change_expected) {
TupleSimplifier simplifier;
@@ -68,7 +68,7 @@ TEST_F(TupleSimplifierTest, TupleOfParameters) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- Run(module.get(), /*change_expected=*/false);
+ Run(module, /*change_expected=*/false);
}
TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
@@ -81,7 +81,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- Run(module.get(), /*change_expected=*/false);
+ Run(module, /*change_expected=*/false);
}
TEST_F(TupleSimplifierTest, GteOfTuple) {
@@ -103,7 +103,7 @@ TEST_F(TupleSimplifierTest, GteOfTuple) {
EXPECT_THAT(computation->root_instruction(), gte);
- Run(module.get(), /*change_expected=*/true);
+ Run(module, /*change_expected=*/true);
EXPECT_THAT(computation->root_instruction(), param1);
}
@@ -131,7 +131,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleChain) {
EXPECT_THAT(computation->root_instruction(),
op::Negate(op::GetTupleElement(op::Tuple())));
- Run(module.get(), /*change_expected=*/true);
+ Run(module, /*change_expected=*/true);
EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter()));
}
@@ -162,7 +162,7 @@ TEST_F(TupleSimplifierTest, NestedGteOfTuples) {
EXPECT_THAT(computation->root_instruction(), element);
- Run(module.get(), /*change_expected=*/true);
+ Run(module, /*change_expected=*/true);
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -187,7 +187,7 @@ TEST_F(TupleSimplifierTest, TupleOfGteInstructions) {
EXPECT_THAT(computation->root_instruction(), tuple);
- Run(module.get(), /*change_expected=*/true);
+ Run(module, /*change_expected=*/true);
EXPECT_THAT(computation->root_instruction(), tuple_param);
}
@@ -212,7 +212,7 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) {
EXPECT_THAT(computation->root_instruction(), tuple);
- Run(module.get(), /*change_expected=*/false);
+ Run(module, /*change_expected=*/false);
EXPECT_THAT(computation->root_instruction(), tuple);
}
@@ -281,7 +281,7 @@ TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) {
entry = module->AddEntryComputation(builder.Build());
}
- Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/ true);
+ Run(module, /*change_expected=*/true, /*exclude_entry=*/true);
EXPECT_THAT(c0->root_instruction(), p0);
EXPECT_THAT(c1->root_instruction(), p1);
diff --git a/tensorflow/compiler/xla/service/tuple_util.cc b/tensorflow/compiler/xla/service/tuple_util.cc
index 4a530bb0b2..cfb0c787d0 100644
--- a/tensorflow/compiler/xla/service/tuple_util.cc
+++ b/tensorflow/compiler/xla/service/tuple_util.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/tuple_util.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
@@ -40,7 +40,7 @@ namespace xla {
/*static*/ HloInstruction* TupleUtil::AppendSuffix(
HloInstruction* input_tuple,
- tensorflow::gtl::ArraySlice<HloInstruction*> trailing_values) {
+ absl::Span<HloInstruction* const> trailing_values) {
CHECK(ShapeUtil::IsTuple(input_tuple->shape()));
HloComputation* computation = input_tuple->parent();
diff --git a/tensorflow/compiler/xla/service/tuple_util.h b/tensorflow/compiler/xla/service/tuple_util.h
index e5ff9aaa83..bc5aac09f2 100644
--- a/tensorflow/compiler/xla/service/tuple_util.h
+++ b/tensorflow/compiler/xla/service/tuple_util.h
@@ -38,7 +38,7 @@ class TupleUtil {
// `input_tuple`.
static HloInstruction* AppendSuffix(
HloInstruction* input_tuple,
- tensorflow::gtl::ArraySlice<HloInstruction*> trailing_values);
+ absl::Span<HloInstruction* const> trailing_values);
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc
index af2cb6dc2a..541b117e02 100644
--- a/tensorflow/compiler/xla/service/while_loop_analysis.cc
+++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc
@@ -18,8 +18,8 @@ limitations under the License.
namespace xla {
-using tensorflow::gtl::nullopt;
-using tensorflow::gtl::optional;
+using absl::nullopt;
+using absl::optional;
// Finds and returns the non-constant operand in instr.
//
@@ -183,8 +183,7 @@ optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
HloEvaluator evaluator(/*max_loop_iterations=*/0);
auto* while_init = while_op->mutable_operand(0);
auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx);
- StatusOr<std::unique_ptr<Literal>> indvar_init_result =
- evaluator.Evaluate(indvar_init);
+ StatusOr<Literal> indvar_init_result = evaluator.Evaluate(indvar_init);
if (!indvar_init_result.ok()) {
VLOG(2) << "Couldn't evaluate induction variable init: "
<< indvar_init_result.status();
@@ -197,32 +196,27 @@ optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
// The initial value of the induction variable.
- std::unique_ptr<Literal> indvar_iter_val =
- std::move(indvar_init_result).ValueOrDie();
+ Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie();
for (int64 trip_count = 0; trip_count != max_value_returned + 1;
++trip_count) {
auto* while_cond = while_op->while_condition();
auto* while_cond_root = while_cond->root_instruction();
auto* while_cond_indvar = NonConstantOperand(while_cond_root);
- StatusOr<std::unique_ptr<Literal>> result =
- evaluator.EvaluateWithSubstitutions(
- while_cond_root, {{while_cond_indvar, indvar_iter_val.get()}});
+ StatusOr<Literal> result = evaluator.EvaluateWithSubstitutions(
+ while_cond_root, {{while_cond_indvar, &indvar_iter_val}});
if (!result.ok()) {
VLOG(2) << "Couldn't evaluate while cond: " << result.status();
return nullopt;
}
- if (result.ValueOrDie()->data<bool>() ==
- tensorflow::gtl::ArraySlice<bool>{false}) {
+ if (result.ValueOrDie().data<bool>() == absl::Span<const bool>{false}) {
VLOG(2) << "Loop has static trip count of " << trip_count;
return trip_count;
}
// Calculate the value of the induction variable after one iteration of the
// loop, and check whether the while condition is true with this new value.
- StatusOr<std::unique_ptr<Literal>> indvar_next_result =
- evaluator.EvaluateWithSubstitutions(
- while_body_indvar_update,
- {{while_body_indvar, indvar_iter_val.get()}});
+ StatusOr<Literal> indvar_next_result = evaluator.EvaluateWithSubstitutions(
+ while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}});
if (!indvar_next_result.ok()) {
VLOG(2) << "Couldn't evaluate induction variable update: "
<< indvar_next_result.status();
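The rewritten condition leans on absl::Span's deep, element-wise operator==, and on the fact that Span<const T> can be built from a braced one-element list. A tiny self-contained sketch of that comparison:

#include <cassert>
#include "absl/types/span.h"

int main() {
  const bool cond_result[] = {false};
  // Element-wise comparison, as in the while-cond check above.
  assert(absl::Span<const bool>(cond_result) == absl::Span<const bool>{false});
  const bool other[] = {true};
  assert(absl::Span<const bool>(other) != absl::Span<const bool>{false});
  return 0;
}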
diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.h b/tensorflow/compiler/xla/service/while_loop_analysis.h
index bf59813e8c..bf497f4892 100644
--- a/tensorflow/compiler/xla/service/while_loop_analysis.h
+++ b/tensorflow/compiler/xla/service/while_loop_analysis.h
@@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/core/lib/gtl/optional.h"
namespace xla {
@@ -25,8 +25,8 @@ namespace xla {
// nullopt otherwise. max_value_returned limits the number of steps that are
// evaluated while trying to brute force a loop trip count; trip counts larger
// than max_value_returned result in nullopt.
-tensorflow::gtl::optional<int64> ComputeWhileLoopTripCount(
- HloInstruction *while_op, int64 max_value_returned = 128);
+absl::optional<int64> ComputeWhileLoopTripCount(HloInstruction *while_op,
+ int64 max_value_returned = 128);
} // namespace xla
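A hypothetical call site for the new signature, assuming while_op points at a kWhile instruction; the absl::optional result is tested exactly as the old tensorflow::gtl::optional was (LogTripCount is illustrative, not from the patch):

#include <iostream>
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/while_loop_analysis.h"

void LogTripCount(xla::HloInstruction* while_op) {
  absl::optional<xla::int64> trip_count =
      xla::ComputeWhileLoopTripCount(while_op, /*max_value_returned=*/128);
  if (trip_count.has_value()) {
    std::cout << "static trip count: " << *trip_count << "\n";
  } else {
    std::cout << "trip count not statically known\n";  // nullopt case
  }
}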
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
index 62af45128a..56145822be 100644
--- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
+++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
@@ -14,10 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
+#include "absl/algorithm/container.h"
+#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
namespace xla {
@@ -32,7 +33,7 @@ static Status ReplaceUsesWhileKeepingLoopInvariance(
std::vector<HloInstruction*> users;
users.reserve(old_instr->user_count());
- c_copy(old_instr->users(), std::back_inserter(users));
+ absl::c_copy(old_instr->users(), std::back_inserter(users));
for (auto* user : users) {
for (int64 i = 0, e = user->operand_count(); i < e; i++) {
@@ -108,10 +109,10 @@ StatusOr<bool> WhileLoopConstantSinking::Run(HloModule* module) {
//
// This will let us sink the constant into the outer while first and then
// into the inner while in a single run of this pass.
- c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
- [](const HloInstruction* instr) {
- return instr->opcode() == HloOpcode::kWhile;
- });
+ absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
+ [](const HloInstruction* instr) {
+ return instr->opcode() == HloOpcode::kWhile;
+ });
}
for (HloInstruction* while_instr : while_instrs) {
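absl::c_copy_if is the container-based form of std::copy_if used throughout this commit: it takes the whole range rather than a begin()/end() pair. A minimal sketch with stand-in data (integer "opcodes" instead of HloInstructions):

#include <iterator>
#include <vector>
#include "absl/algorithm/container.h"

int main() {
  const int kWhile = 4;  // stand-in opcode value
  std::vector<int> opcodes = {1, kWhile, 2, kWhile};
  std::vector<int> while_ops;
  // Same shape as the filter above: container, output iterator, predicate.
  absl::c_copy_if(opcodes, std::back_inserter(while_ops),
                  [&](int op) { return op == kWhile; });
  return while_ops.size() == 2 ? 0 : 1;
}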
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
index 21fb8568a8..2dba7d7f75 100644
--- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
+++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
@@ -54,7 +54,7 @@ class WhileLoopConstantSinking : public HloPassInterface {
public:
~WhileLoopConstantSinking() override = default;
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "while-loop-invariant-code-motion";
}
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
index 09ddcffb22..e8fe33e626 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
@@ -14,18 +14,19 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
+#include "absl/algorithm/container.h"
+#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/service/tuple_util.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
namespace xla {
+using absl::InlinedVector;
using tensorflow::gtl::FlatMap;
using tensorflow::gtl::FlatSet;
-using tensorflow::gtl::InlinedVector;
// Copies `to_hoist` to the computation containing `while_instr`, hoisting its
// operands as needed. All of its transitive operands are expected to be either
@@ -65,8 +66,8 @@ static void CreateLoopInvariantCopy(
};
InlinedVector<HloInstruction*, 4> new_operands;
- c_transform(old_instruction->operands(), std::back_inserter(new_operands),
- get_new_operand);
+ absl::c_transform(old_instruction->operands(),
+ std::back_inserter(new_operands), get_new_operand);
HloInstruction* new_instruction =
parent_of_while->AddInstruction(old_instruction->CloneWithNewOperands(
@@ -109,6 +110,7 @@ bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually(
case HloOpcode::kBitcast:
case HloOpcode::kBroadcast:
+ case HloOpcode::kIota:
case HloOpcode::kReshape:
case HloOpcode::kReverse:
case HloOpcode::kSlice:
@@ -197,7 +199,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody(
op->opcode() == HloOpcode::kConstant;
};
- if (!c_all_of(instruction->operands(), is_invariant)) {
+ if (!absl::c_all_of(instruction->operands(), is_invariant)) {
continue;
}
@@ -257,10 +259,10 @@ StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) {
bool changed = false;
std::vector<HloInstruction*> while_instrs;
for (auto* comp : module->computations()) {
- c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
- [](const HloInstruction* instr) {
- return instr->opcode() == HloOpcode::kWhile;
- });
+ absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
+ [](const HloInstruction* instr) {
+ return instr->opcode() == HloOpcode::kWhile;
+ });
}
for (HloInstruction* while_instr : while_instrs) {
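Likewise, absl::c_transform maps a whole container into an output iterator, mirroring std::transform without the explicit iterator pair. A short sketch with illustrative names:

#include <iterator>
#include <string>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"

int main() {
  std::vector<int> operand_ids = {0, 1, 2};
  std::vector<std::string> names;
  // One element of output per element of input, as in get_new_operand above.
  absl::c_transform(operand_ids, std::back_inserter(names),
                    [](int id) { return absl::StrCat("operand.", id); });
  return names[2] == "operand.2" ? 0 : 1;
}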
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
index 8e6cc87875..2cdf20ce80 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
@@ -38,7 +38,7 @@ class WhileLoopInvariantCodeMotion : public HloPassInterface {
: hoist_constants_(hoist_constants) {}
~WhileLoopInvariantCodeMotion() override = default;
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "while-loop-invariant-code-motion";
}
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
index dd8697e680..6a7bfe3f12 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
@@ -14,17 +14,16 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/while_loop_analysis.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/optional.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
-using tensorflow::gtl::nullopt;
-using tensorflow::gtl::optional;
+using absl::optional;
// Determines whether the given instruction is a send/recv node, or has a
// subcomputation which contains a send/recv node.
@@ -237,12 +236,11 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
<< "Instruction " << user->ToString(print_no_metadata)
<< " should be unused (except by root of while body), but has "
"users: {"
- << tensorflow::str_util::Join(
- user->users(), ", ",
- [&](string* out, const HloInstruction* instr) {
- tensorflow::strings::StrAppend(
- out, instr->ToString(print_no_metadata));
- })
+ << absl::StrJoin(user->users(), ", ",
+ [&](string* out, const HloInstruction* instr) {
+ absl::StrAppend(
+ out, instr->ToString(print_no_metadata));
+ })
<< "}";
replacements.emplace(user, nullptr);
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h
index 3d3e1d60f2..78024f14dc 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.h
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h
@@ -33,9 +33,7 @@ namespace xla {
class WhileLoopSimplifier : public HloPassInterface {
public:
~WhileLoopSimplifier() override {}
- tensorflow::StringPiece name() const override {
- return "simplify-while-loops";
- }
+ absl::string_view name() const override { return "simplify-while-loops"; }
StatusOr<bool> Run(HloModule* module) override;
};
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
index 2e1571943e..1c892ba179 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
@@ -15,11 +15,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_replace.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
namespace {
@@ -64,10 +65,8 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) {
}
)";
- string hlo_string = tensorflow::str_util::StringReplace(
- hlo_string_template, "{{LOOP_BOUND}}",
- tensorflow::strings::StrCat(42 + num_iters),
- /*replace_all=*/true);
+ string hlo_string = absl::StrReplaceAll(
+ hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}});
ParseAndVerifyModule(hlo_string);
}
@@ -103,10 +102,8 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound(
}
)";
- string hlo_string = tensorflow::str_util::StringReplace(
- hlo_string_template, "{{LOOP_BOUND}}",
- tensorflow::strings::StrCat(42 + num_iters),
- /*replace_all=*/true);
+ string hlo_string = absl::StrReplaceAll(
+ hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}});
ParseAndVerifyModule(hlo_string);
}
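absl::StrReplaceAll takes a list of {from, to} substitutions and replaces every occurrence in a single pass, which is why the old /*replace_all=*/true argument disappears. A standalone sketch of the template substitution above:

#include <string>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"

int main() {
  const std::string hlo_template = "while(i < {{LOOP_BOUND}}) { body }";
  const int num_iters = 2;
  // Every occurrence of the placeholder is rewritten; more {from, to} pairs
  // may be passed in the same braced list for simultaneous substitutions.
  std::string hlo = absl::StrReplaceAll(
      hlo_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}});
  return hlo == "while(i < 44) { body }" ? 0 : 1;
}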
diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc
index 1ef17b9d7d..f90ac91f9d 100644
--- a/tensorflow/compiler/xla/service/while_util.cc
+++ b/tensorflow/compiler/xla/service/while_util.cc
@@ -14,15 +14,16 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_util.h"
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/tuple_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
-using tensorflow::strings::StrCat;
+using absl::StrCat;
static StatusOr<HloComputation*> WidenWhileCondition(
HloComputation* narrow_condition, const Shape& wide_shape) {
@@ -93,7 +94,7 @@ WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) {
/*static*/ StatusOr<WhileUtil::MakeInstructionsLiveInResult>
WhileUtil::MakeInstructionsLiveIn(
HloInstruction* while_instr,
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
+ absl::Span<HloInstruction* const> instructions) {
CHECK(ShapeUtil::IsTuple(while_instr->shape()));
int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size();
@@ -206,7 +207,7 @@ static StatusOr<HloInstruction*> MakeInitTupleFromInitValues(
HloInstruction* zero = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
init_values_with_indvar.push_back(zero);
- c_copy(init_values, std::back_inserter(init_values_with_indvar));
+ absl::c_copy(init_values, std::back_inserter(init_values_with_indvar));
return computation->AddInstruction(
HloInstruction::CreateTuple(init_values_with_indvar));
}
@@ -215,8 +216,9 @@ static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) {
std::vector<Shape> loop_state_shape_components;
loop_state_shape_components.reserve(init_values.size() + 1);
loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {}));
- c_transform(init_values, std::back_inserter(loop_state_shape_components),
- [](HloInstruction* instr) { return instr->shape(); });
+ absl::c_transform(init_values,
+ std::back_inserter(loop_state_shape_components),
+ [](HloInstruction* instr) { return instr->shape(); });
return ShapeUtil::MakeTupleShape(loop_state_shape_components);
}
diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h
index e67636d80f..b1c4486887 100644
--- a/tensorflow/compiler/xla/service/while_util.h
+++ b/tensorflow/compiler/xla/service/while_util.h
@@ -55,7 +55,7 @@ class WhileUtil {
// that contains `while_instr`.
static StatusOr<MakeInstructionsLiveInResult> MakeInstructionsLiveIn(
HloInstruction* while_instr,
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions);
+ absl::Span<HloInstruction* const> instructions);
using LoopStateTy = std::vector<HloInstruction*>;
using LoopBodyGeneratorTy = std::function<StatusOr<LoopStateTy>(
diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc
index 2ccb919acf..5e69419333 100644
--- a/tensorflow/compiler/xla/service/while_util_test.cc
+++ b/tensorflow/compiler/xla/service/while_util_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/while_util.h"
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
@@ -206,7 +207,7 @@ ENTRY main {
auto is_while = [](const HloInstruction* instr) {
return instr->opcode() == HloOpcode::kWhile;
};
- EXPECT_EQ(c_count_if(main->instructions(), is_while), 1);
+ EXPECT_EQ(absl::c_count_if(main->instructions(), is_while), 1);
}
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
index 8763e588c4..a7f0e207eb 100644
--- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
+++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
@@ -24,7 +24,7 @@ namespace xla {
class ZeroSizedHloElimination : public HloPassInterface {
public:
StatusOr<bool> Run(HloModule* module) override;
- tensorflow::StringPiece name() const override {
+ absl::string_view name() const override {
return "zero_sized_hlo_elimination";
}
};
diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc
index caad31d6ce..d44db89d57 100644
--- a/tensorflow/compiler/xla/shape_layout.cc
+++ b/tensorflow/compiler/xla/shape_layout.cc
@@ -25,8 +25,8 @@ namespace xla {
Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) {
if (!ShapeUtil::Compatible(other_shape, shape_)) {
return InvalidArgument("Shape %s is not compatible with shape %s",
- ShapeUtil::HumanString(other_shape).c_str(),
- ShapeUtil::HumanString(shape()).c_str());
+ ShapeUtil::HumanString(other_shape),
+ ShapeUtil::HumanString(shape()));
}
shape_ = other_shape;
return Status::OK();
@@ -35,8 +35,8 @@ Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) {
Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const {
if (!ShapeUtil::Compatible(*to_shape, shape_)) {
return InvalidArgument("Shape %s is not compatible with shape %s",
- ShapeUtil::HumanString(*to_shape).c_str(),
- ShapeUtil::HumanString(shape()).c_str());
+ ShapeUtil::HumanString(*to_shape),
+ ShapeUtil::HumanString(shape()));
}
*to_shape = shape_;
return Status::OK();
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
index c74dd648ad..df610102b4 100644
--- a/tensorflow/compiler/xla/shape_tree.h
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -21,16 +21,16 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/memory/memory.h"
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -224,14 +224,13 @@ class ShapeTree {
// REQUIRES: index must exist in the ShapeTree.
iterator find(ShapeIndexView index) {
Node* element = Lookup(index);
- return iterator(&nodes_, typename std::vector<Node>::iterator(element),
- /*iterate_leaves_only=*/false);
+ auto element_iter = nodes_.begin() + (element - &nodes_[0]);
+ return iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false);
}
const_iterator find(ShapeIndexView index) const {
Node* element = Lookup(index);
- return iterator(&nodes_,
- typename std::vector<Node>::const_iterator(element),
- /*iterate_leaves_only=*/false);
+ auto element_iter = nodes_.cbegin() + (element - &nodes_[0]);
+ return const_iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false);
}
// Returns the number of leaf nodes in the tree.
@@ -262,6 +261,25 @@ class ShapeTree {
template <typename Fn>
Status ForEachMutableElementWithStatus(const Fn& func);
+ // Maps each element to generate a new tree with the same shape.
+ template <typename U>
+ ShapeTree<U> Map(const std::function<U(const T&)>& func) {
+ ShapeTree<U> result(shape_storage_);
+ ForEachElement([&](const ShapeIndex& index, const T& t) {
+ *result.mutable_element(index) = func(t);
+ });
+ return result;
+ }
+
+ template <typename U>
+ ShapeTree<U> Map(const std::function<U(T*)>& func) {
+ ShapeTree<U> result(shape_storage_);
+ ForEachMutableElement([&](const ShapeIndex& index, T* t) {
+ *result.mutable_element(index) = func(t);
+ });
+ return result;
+ }
+
// Copy the subtree of values from 'other' rooted at ShapeIndex
// 'source_base_index' into the subtree of value in this ShapeTree rooted at
// 'target_base_index'.
@@ -463,9 +481,6 @@ template <typename T>
ShapeTree<T>::ShapeTree(Shape shape)
: shape_storage_(std::make_shared<Shape>(std::move(shape))),
shape_(shape_storage_.get()) {
- // The shape_ field is just used to hold the structure of the shape.
- // It should not be relied upon to store layout information.
- LayoutUtil::ClearLayout(shape_storage_.get());
const int64 count = CountSubshapes(*shape_);
nodes_.reserve(count);
nodes_.emplace_back(ShapeIndex{});
@@ -502,9 +517,6 @@ template <typename T>
ShapeTree<T>::ShapeTree(Shape shape, const T& init_value)
: shape_storage_(std::make_shared<Shape>(std::move(shape))),
shape_(shape_storage_.get()) {
- // The shape_ field is just used to hold the structure of the shape.
- // It should not be relied upon to store layout information.
- LayoutUtil::ClearLayout(shape_storage_.get());
const int64 count = CountSubshapes(*shape_);
nodes_.reserve(count);
nodes_.emplace_back(ShapeIndex{}, init_value);
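Hypothetical usage of the Map() members introduced above: derive a tree with the same tuple structure but a different payload type. The names sizes/is_set are illustrative, not from the patch:

#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"

void Example() {
  xla::Shape shape = xla::ShapeUtil::MakeTupleShape(
      {xla::ShapeUtil::MakeShape(xla::F32, {8}),
       xla::ShapeUtil::MakeShape(xla::S32, {})});
  xla::ShapeTree<int> sizes(shape, 0);
  *sizes.mutable_element({0}) = 8;
  // Same tree structure, bool payload: true wherever a size was recorded.
  xla::ShapeTree<bool> is_set =
      sizes.Map<bool>([](const int& v) { return v != 0; });
  CHECK(is_set.element({0}));
  CHECK(!is_set.element({1}));
}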
diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc
index c4c958be4a..c8ff55e784 100644
--- a/tensorflow/compiler/xla/shape_tree_test.cc
+++ b/tensorflow/compiler/xla/shape_tree_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_tree.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -242,7 +243,7 @@ TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) {
TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) {
ShapeTree<std::unique_ptr<int>> shape_tree{tuple_shape_};
EXPECT_EQ(shape_tree.element({2}).get(), nullptr);
- *shape_tree.mutable_element({2}) = MakeUnique<int>(42);
+ *shape_tree.mutable_element({2}) = absl::make_unique<int>(42);
EXPECT_EQ(*shape_tree.element({2}), 42);
}
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index b69c346f1e..9772c06bce 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -22,6 +22,14 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/ascii.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/strip.h"
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/overflow_util.h"
@@ -30,26 +38,22 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"
namespace xla {
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
+using absl::StrAppend;
+using absl::StrCat;
string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); }
string ShapeIndexView::ToString() const {
- return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}");
+ return StrCat("{", absl::StrJoin(indices_, ","), "}");
}
bool ShapeIndexView::operator==(const ShapeIndexView& other) const {
@@ -91,11 +95,11 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts,
}
if (ShapeUtil::IsTuple(lhs)) {
- return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
- [=](const Shape& l, const Shape& r) {
- return CompareShapes(l, r, compare_layouts,
- ignore_fp_precision);
- });
+ return absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(),
+ [=](const Shape& l, const Shape& r) {
+ return CompareShapes(l, r, compare_layouts,
+ ignore_fp_precision);
+ });
} else if (!ShapeUtil::IsArray(lhs)) {
// Non-tuple, non-array types such as opaque and token types are trivially
// the same.
@@ -107,13 +111,13 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts,
return false;
}
if (LayoutUtil::IsDenseArray(lhs)) {
- if (!ContainersEqual(LayoutUtil::MinorToMajor(lhs),
- LayoutUtil::MinorToMajor(rhs))) {
+ if (!absl::c_equal(LayoutUtil::MinorToMajor(lhs),
+ LayoutUtil::MinorToMajor(rhs))) {
VLOG(3) << "CompareShapes: lhs layout != rhs layout";
return false;
}
- if (!ContainersEqual(lhs.layout().padded_dimensions(),
- rhs.layout().padded_dimensions())) {
+ if (!absl::c_equal(lhs.layout().padded_dimensions(),
+ rhs.layout().padded_dimensions())) {
VLOG(3)
<< "CompareShapes: lhs padded_dimensions != rhs padded_dimensions";
return false;
@@ -135,15 +139,15 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts,
// Constructs and returns the new shape with the given minor_to_major order in
// its Layout.
StatusOr<Shape> MakeShapeWithLayoutInternal(
- PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> minor_to_major) {
+ PrimitiveType element_type, absl::Span<const int64> dimensions,
+ absl::Span<const int64> minor_to_major) {
if (dimensions.size() != minor_to_major.size()) {
return InvalidArgument("Dimensions size is %ld, but layout size is %ld.",
dimensions.size(), minor_to_major.size());
}
if (element_type == OPAQUE || element_type == TUPLE) {
return InvalidArgument("Unsupported element type: %s",
- PrimitiveType_Name(element_type).c_str());
+ PrimitiveType_Name(element_type));
}
Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
auto min2maj = shape.mutable_layout()->mutable_minor_to_major();
@@ -210,8 +214,8 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
return program_shape;
}
-/* static */ Shape ShapeUtil::MakeShape(
- PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions) {
+/* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type,
+ absl::Span<const int64> dimensions) {
CHECK(IsArrayPrimitiveType(element_type));
Shape result;
PopulateShape(element_type, dimensions, &result);
@@ -219,21 +223,21 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
}
/* static */ Shape ShapeUtil::MakeShapeWithLayout(
- PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> minor_to_major) {
+ PrimitiveType element_type, absl::Span<const int64> dimensions,
+ absl::Span<const int64> minor_to_major) {
return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major)
.ValueOrDie();
}
/* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout(
- PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions) {
+ PrimitiveType element_type, absl::Span<const int64> dimensions) {
std::vector<int64> layout(dimensions.size());
std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
return MakeShapeWithLayout(element_type, dimensions, layout);
}
/* static */ Shape ShapeUtil::MakeShapeWithSparseLayout(
- PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
+ PrimitiveType element_type, absl::Span<const int64> dimensions,
int64 max_sparse_elements) {
CHECK(IsArrayPrimitiveType(element_type));
Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
@@ -252,9 +256,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return MakeShapeWithDescendingLayout(shape.element_type(), dims);
}
-/* static */ void ShapeUtil::PopulateShape(
- PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
- Shape* shape) {
+/* static */ void ShapeUtil::PopulateShape(PrimitiveType element_type,
+ absl::Span<const int64> dimensions,
+ Shape* shape) {
shape->Clear();
shape->set_element_type(element_type);
for (int64 dimension : dimensions) {
@@ -264,8 +268,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
TF_DCHECK_OK(ValidateShape(*shape));
}
-/* static */ Shape ShapeUtil::MakeTupleShape(
- tensorflow::gtl::ArraySlice<Shape> shapes) {
+/* static */ Shape ShapeUtil::MakeTupleShape(absl::Span<const Shape> shapes) {
Shape result;
result.set_element_type(TUPLE);
result.mutable_tuple_shapes()->Reserve(shapes.size());
@@ -449,14 +452,14 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
namespace {
// Class to memoize the computation of
-// tensorflow::str_util::Lowercase(PrimitiveType_Name(p))
+// absl::AsciiStrToLower(PrimitiveType_Name(p))
// for all PrimitiveType values "p"
class PrimitiveTypeNameGenerator {
public:
PrimitiveTypeNameGenerator() {
for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
if (PrimitiveType_IsValid(i)) {
- lowercase_name_[i] = tensorflow::str_util::Lowercase(
+ lowercase_name_[i] = absl::AsciiStrToLower(
PrimitiveType_Name(static_cast<PrimitiveType>(i)));
}
}
@@ -487,8 +490,7 @@ StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) {
}();
auto found = name_to_type->find(name);
if (found == name_to_type->end()) {
- return InvalidArgument("Invalid element type string: \"%s\".",
- name.c_str());
+ return InvalidArgument("Invalid element type string: \"%s\".", name);
}
return found->second;
}
@@ -507,7 +509,7 @@ StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) {
return text;
}
return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[",
- tensorflow::str_util::Join(shape.dimensions(), ","), "]");
+ absl::StrJoin(shape.dimensions(), ","), "]");
}
/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
@@ -543,30 +545,29 @@ StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) {
: "(unknown)",
": ", HumanString(shape)));
}
- return StrCat("(", tensorflow::str_util::Join(parameters, ", "), ") -> ",
+ return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ",
HumanString(program_shape.result()));
}
namespace {
// Parses shapes with simple recursive descent structure -- consumes from the
// front of s and passes that view recursively as required.
-StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
- tensorflow::str_util::RemoveLeadingWhitespace(s);
+StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) {
+ *s = StripLeadingAsciiWhitespace(*s);
- if (tensorflow::str_util::ConsumePrefix(s, "(")) { // Tuple.
+ if (absl::ConsumePrefix(s, "(")) { // Tuple.
std::vector<Shape> shapes;
bool must_end = false;
while (true) {
- if (tensorflow::str_util::ConsumePrefix(s, ")")) {
+ if (absl::ConsumePrefix(s, ")")) {
break;
} else if (must_end) {
- return InvalidArgument("Expected end of tuple; got: \"%s\"",
- std::string(*s).c_str());
+ return InvalidArgument("Expected end of tuple; got: \"%s\"", *s);
}
shapes.emplace_back();
TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s));
- tensorflow::str_util::RemoveLeadingWhitespace(s);
- must_end = !tensorflow::str_util::ConsumePrefix(s, ",");
+ *s = StripLeadingAsciiWhitespace(*s);
+ must_end = !absl::ConsumePrefix(s, ",");
}
return ShapeUtil::MakeTupleShape(shapes);
}
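The parser above consumes tokens from the front of a string_view: absl::ConsumePrefix both tests for a prefix and advances the view past it on a match. A minimal sketch of the same pattern on a toy tuple grammar; CountTupleElements is hypothetical:

#include <iostream>
#include "absl/strings/ascii.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"

// Counts single-character elements in "(a, b, c)" style input; returns -1 on
// malformed input. Mirrors the strip-then-consume loop used above.
int CountTupleElements(absl::string_view* s) {
  *s = absl::StripLeadingAsciiWhitespace(*s);
  if (!absl::ConsumePrefix(s, "(")) return -1;  // not a tuple
  int count = 0;
  while (!absl::ConsumePrefix(s, ")")) {
    if (s->empty()) return -1;  // unterminated tuple
    s->remove_prefix(1);        // consume one element
    ++count;
    absl::ConsumePrefix(s, ",");
    *s = absl::StripLeadingAsciiWhitespace(*s);
  }
  return count;
}

int main() {
  absl::string_view input = "(a, b, c)";
  std::cout << CountTupleElements(&input) << "\n";  // prints 3
}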
@@ -575,9 +576,9 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
string dimensions_string;
string format_string;
string layout_string;
- // tensorflow::StringPiece is not compatible with internal RE2 StringPiece, so
+ // absl::string_view is not compatible with internal RE2 StringPiece, so
// we convert into the RE2-consumable type and then consume the corresponding
- // amount from our StringPiece type.
+ // amount from our string_view type.
static LazyRE2 shape_pattern = {
"^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?"};
tensorflow::RegexpStringPiece s_consumable(s->data(), s->size());
@@ -585,12 +586,12 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
&dimensions_string, &format_string, &layout_string)) {
size_t consumed = s->size() - s_consumable.size();
s->remove_prefix(consumed);
- auto string_to_int64 = [&s](const string& input) -> StatusOr<int64> {
+ auto string_to_int64 = [&s](absl::string_view input) -> StatusOr<int64> {
int64 element;
- if (!tensorflow::strings::safe_strto64(input.c_str(), &element)) {
+ if (!absl::SimpleAtoi(input, &element)) {
return InvalidArgument(
- "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"",
- input.c_str(), std::string(*s).c_str());
+ "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", input,
+ *s);
}
return element;
};
@@ -598,7 +599,7 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
auto comma_list_to_int64s =
[string_to_int64](const string& input) -> StatusOr<std::vector<int64>> {
std::vector<int64> results;
- for (const string& piece : tensorflow::str_util::Split(input, ',')) {
+ for (const auto& piece : absl::StrSplit(input, ',', absl::SkipEmpty())) {
TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece));
results.push_back(element);
}
@@ -614,7 +615,7 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
StringToPrimitiveType(element_type_string));
if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE) {
return InvalidArgument("Invalid element type string: \"%s\".",
- element_type_string.c_str());
+ element_type_string);
}
Shape result;
@@ -644,17 +645,14 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
return std::move(result);
}
- return InvalidArgument("Invalid shape string to parse: \"%s\"",
- std::string(*s).c_str());
+ return InvalidArgument("Invalid shape string to parse: \"%s\"", *s);
}
} // namespace
-/* static */ StatusOr<Shape> ShapeUtil::ParseShapeString(
- tensorflow::StringPiece s) {
+/* static */ StatusOr<Shape> ShapeUtil::ParseShapeString(absl::string_view s) {
TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s));
if (!s.empty()) {
- return InvalidArgument("Invalid shape string to parse: \"%s\"",
- std::string(s).c_str());
+ return InvalidArgument("Invalid shape string to parse: \"%s\"", s);
}
return shape;
}
@@ -663,7 +661,7 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
const Shape& rhs) {
CHECK(ShapeUtil::IsArray(lhs));
CHECK(ShapeUtil::IsArray(rhs));
- return ContainersEqual(lhs.dimensions(), rhs.dimensions());
+ return absl::c_equal(lhs.dimensions(), rhs.dimensions());
}
/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
@@ -677,8 +675,8 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
return IsArray(rhs) && SameDimensions(lhs, rhs);
} else if (lhs.element_type() == TUPLE) {
return rhs.element_type() == TUPLE &&
- ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
- CompatibleIgnoringElementType);
+ absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(),
+ CompatibleIgnoringElementType);
} else {
// Opaque, token, etc. types are vacuously compatible.
return lhs.element_type() == rhs.element_type();
@@ -692,8 +690,8 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
CompatibleIgnoringElementType(lhs, rhs);
} else if (lhs.element_type() == TUPLE) {
return rhs.element_type() == TUPLE &&
- ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
- CompatibleIgnoringFpPrecision);
+ absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(),
+ CompatibleIgnoringFpPrecision);
} else {
// Opaque, token, etc. types are vacuously compatible.
return lhs.element_type() == rhs.element_type();
@@ -792,7 +790,7 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout());
} else {
CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString();
- tensorflow::gtl::ArraySlice<int64> padded_dimensions =
+ absl::Span<const int64> padded_dimensions =
LayoutUtil::PaddedDimensions(shape);
if (!padded_dimensions.empty()) {
CHECK_EQ(Rank(shape), padded_dimensions.size());
@@ -819,7 +817,7 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
const Shape& shape) {
if (shape.element_type() == PRIMITIVE_TYPE_INVALID) {
return InvalidArgument("shape has invalid element type: %s",
- shape.ShortDebugString().c_str());
+ shape.ShortDebugString());
}
if (shape.element_type() == TUPLE) {
if (shape.dimensions_size() != 0) {
@@ -842,21 +840,21 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
if (shape.dimensions_size() != 0) {
return InvalidArgument(
"shape has %s element type, but has dimensions field: %s",
- LowercasePrimitiveTypeName(shape.element_type()).c_str(),
- shape.ShortDebugString().c_str());
+ LowercasePrimitiveTypeName(shape.element_type()),
+ shape.ShortDebugString());
}
if (shape.has_layout()) {
return InvalidArgument(
"shape has %s element type, but has layout field: %s",
- LowercasePrimitiveTypeName(shape.element_type()).c_str(),
- shape.ShortDebugString().c_str());
+ LowercasePrimitiveTypeName(shape.element_type()),
+ shape.ShortDebugString());
}
return Status::OK();
}
if (Rank(shape) != shape.dimensions_size()) {
return InvalidArgument(
- "shape's rank is mismatched with dimension count; rank=%lld "
+ "shape's rank is mismatched with dimension count; rank=%d "
"dimensions_size=%d",
Rank(shape), shape.dimensions_size());
}
@@ -864,9 +862,8 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
int64 dimension = shape.dimensions(i);
if (dimension < 0) {
return InvalidArgument(
- "shape's dimensions must not be < 0; dimension at index %lld was "
- "%lld",
- i, dimension);
+ "shape's dimensions must not be < 0; dimension at index %d was %d", i,
+ dimension);
}
}
@@ -931,7 +928,7 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
if (shape_size < 0) {
return InvalidArgument("Shape %s size may overflow int64.",
- ShapeUtil::HumanString(shape).c_str());
+ ShapeUtil::HumanString(shape));
}
VLOG(3) << "Shape size is valid: " << shape_size;
@@ -991,7 +988,7 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
i >= return_shape->tuple_shapes_size()) {
return InvalidArgument(
"Shape index %s not a valid subshape index for tuple with shape %s",
- index.ToString().c_str(), shape.DebugString().c_str());
+ index.ToString(), shape.DebugString());
}
return_shape = &return_shape->tuple_shapes(i);
}
@@ -1037,7 +1034,7 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
/* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) {
CHECK(ShapeUtil::IsArray(shape));
- return ArrayContains<int64>(AsInt64Slice(shape.dimensions()), 1);
+ return absl::c_linear_search(shape.dimensions(), 1);
}
namespace {
@@ -1117,7 +1114,7 @@ Status ForEachMutableSubshapeHelper(
}
/* static */ Shape ShapeUtil::PermuteDimensions(
- tensorflow::gtl::ArraySlice<int64> permutation, const Shape& shape) {
+ absl::Span<const int64> permutation, const Shape& shape) {
Shape new_shape = shape;
new_shape.clear_dimensions();
for (auto dim : Permute(permutation, shape.dimensions())) {
@@ -1172,8 +1169,7 @@ Status ForEachMutableSubshapeHelper(
CHECK(TransposeIsBitcast(shape, new_shape, InversePermutation(permutation)))
<< "shape=" << HumanStringWithLayout(shape)
<< ", new_shape=" << HumanStringWithLayout(new_shape)
- << ", permutation={" << tensorflow::str_util::Join(permutation, ",")
- << "}";
+ << ", permutation={" << absl::StrJoin(permutation, ",") << "}";
}
return new_shape;
}
@@ -1262,7 +1258,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
/* static */ bool ShapeUtil::TransposeIsBitcast(
const Shape& input_shape, const Shape& output_shape,
- tensorflow::gtl::ArraySlice<int64> dimension_mapping) {
+ absl::Span<const int64> dimension_mapping) {
CHECK(LayoutUtil::HasLayout(input_shape) &&
LayoutUtil::HasLayout(output_shape));
@@ -1289,7 +1285,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
// apply(input_dimensions, I) =
// apply((dimension_mapping * output_dimensions), I)
// input_dimensions = dimension_mapping * output_dimensions
- return ContainersEqual(
+ return absl::c_equal(
ComposePermutations(dimension_mapping,
AsInt64Slice(output_shape.layout().minor_to_major())),
input_shape.layout().minor_to_major());
@@ -1460,7 +1456,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
check_input_unit_indices(output_shape, input_shape);
}
-/* static */ tensorflow::gtl::optional<Shape> ShapeUtil::AlignLayouts(
+/* static */ absl::optional<Shape> ShapeUtil::AlignLayouts(
const Shape& input_shape, const Shape& output_shape) {
CHECK(IsArray(input_shape));
CHECK(IsArray(output_shape));
@@ -1499,7 +1495,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
if (input_dimension_product < output_dimension_product ||
j == output_rank) {
if (i == input_rank) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
dimension_to_alignment_index[i] = alignment.size() - 1;
input_dimension_product *= input_shape.dimensions(i);
@@ -1510,7 +1506,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
}
}
if (input_dimension_product != output_dimension_product) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
// We also need to store an end element so that we know where the last
// alignment part ends.
@@ -1554,7 +1550,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
for (int64 j = 0; j < num_non_trivial_dimensions_in_alignment_part;
++i, ++j) {
if (i == input_rank) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
// Skip trivial dimensions with a bound of 1.
if (input_shape.dimensions(input_dimension_numbers[i]) == 1) {
@@ -1567,7 +1563,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
if (dimension_to_alignment_index[input_dimension_numbers[i]] !=
current_alignment_index ||
input_dimension_numbers[i] > current_dimension_number) {
- return tensorflow::gtl::nullopt;
+ return absl::nullopt;
}
current_dimension_number = input_dimension_numbers[i];
}
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index d6f17fc965..8234fcdd3f 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -22,6 +22,9 @@ limitations under the License.
#include <initializer_list>
#include <string>
+#include "absl/container/inlined_vector.h"
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -30,9 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
@@ -74,7 +74,7 @@ class ShapeIndex {
// push_front is O(n^2), but shapes don't usually have a ton of dimensions.
void push_front(int64 value) { indices_.insert(indices_.begin(), value); }
- using container_type = tensorflow::gtl::InlinedVector<int64, 2>;
+ using container_type = absl::InlinedVector<int64, 2>;
container_type::const_iterator begin() const { return indices_.begin(); }
container_type::const_iterator end() const { return indices_.end(); }
@@ -131,12 +131,12 @@ class ShapeIndexView {
}
ShapeIndexView ConsumeFront() const {
ShapeIndexView result = *this;
- result.indices_.pop_front();
+ result.indices_.remove_prefix(1);
return result;
}
ShapeIndexView ConsumeBack() const {
ShapeIndexView result = *this;
- result.indices_.pop_back();
+ result.indices_.remove_suffix(1);
return result;
}
ShapeIndex ToShapeIndex() const { return ShapeIndex(begin(), end()); }
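absl::Span renames ArraySlice's pop_front()/pop_back() to remove_prefix(n) and remove_suffix(n); both shrink the view in place without touching the underlying storage. A tiny sketch:

#include <cassert>
#include "absl/types/span.h"

int main() {
  const int indices[] = {1, 2, 3, 4};
  absl::Span<const int> view(indices);
  view.remove_prefix(1);  // ConsumeFront: drop the leading 1
  view.remove_suffix(1);  // ConsumeBack: drop the trailing 4
  assert(view == absl::Span<const int>({2, 3}));  // deep equality
  return 0;
}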
@@ -147,7 +147,7 @@ class ShapeIndexView {
string ToString() const;
private:
- tensorflow::gtl::ArraySlice<int64> indices_;
+ absl::Span<const int64> indices_;
};
std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index);
@@ -228,7 +228,7 @@ class ShapeUtil {
// Parses a ShapeUtil::HumanString-format shape string back into a shape
// object.
- static StatusOr<Shape> ParseShapeString(tensorflow::StringPiece s);
+ static StatusOr<Shape> ParseShapeString(absl::string_view s);
// Returns whether the LHS and RHS shapes have the same dimensions; note: does
// not check element type.
@@ -328,7 +328,7 @@ class ShapeUtil {
static Shape ChangeElementType(const Shape& original, PrimitiveType type);
// Creates a tuple shape from a slice of element shapes within the tuple.
- static Shape MakeTupleShape(tensorflow::gtl::ArraySlice<Shape> shapes);
+ static Shape MakeTupleShape(absl::Span<const Shape> shapes);
// Creates an opaque shape. These are generally used for threading a context
// into a custom operation.
@@ -355,31 +355,29 @@ class ShapeUtil {
// Constructs a new shape with the given element type and sequence of
// dimensions.
static Shape MakeShape(PrimitiveType element_type,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ absl::Span<const int64> dimensions);
// Creates a Shape with element type corresponding to T and the given
// dimensions
template <typename T>
- static Shape MakeShapeWithType(
- tensorflow::gtl::ArraySlice<int64> dimensions) {
+ static Shape MakeShapeWithType(absl::Span<const int64> dimensions) {
return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(),
dimensions);
}
// Constructs a new shape with the given minor_to_major order in its Layout.
// Returns a value shape such that shape.has_layout().
- static Shape MakeShapeWithLayout(
- PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> minor_to_major);
+ static Shape MakeShapeWithLayout(PrimitiveType element_type,
+ absl::Span<const int64> dimensions,
+ absl::Span<const int64> minor_to_major);
- static Shape MakeShapeWithSparseLayout(
- PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
- int64 max_sparse_elements);
+ static Shape MakeShapeWithSparseLayout(PrimitiveType element_type,
+ absl::Span<const int64> dimensions,
+ int64 max_sparse_elements);
// Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
static Shape MakeShapeWithDescendingLayout(
- PrimitiveType element_type,
- tensorflow::gtl::ArraySlice<int64> dimensions);
+ PrimitiveType element_type, absl::Span<const int64> dimensions);
// Returns a new Shape based on the given Shape with low-dimension-major
// layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions
@@ -391,8 +389,7 @@ class ShapeUtil {
// As MakeShape, but the object to write to is passed in.
static void PopulateShape(PrimitiveType element_type,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- Shape* shape);
+ absl::Span<const int64> dimensions, Shape* shape);
// Validates that the provided shape satisfies invariants.
static Status ValidateShape(const Shape& shape);
@@ -539,7 +536,7 @@ class ShapeUtil {
// !HasLayout(shape) ||
// TransposeIsBitcast(shape, PermuteDimensions(permutation, shape),
// InversePermutation(permutation)).
- static Shape PermuteDimensions(tensorflow::gtl::ArraySlice<int64> permutation,
+ static Shape PermuteDimensions(absl::Span<const int64> permutation,
const Shape& shape);
// If we can go from `shape_pre` to `shape_post` by merely inserting or
@@ -580,9 +577,9 @@ class ShapeUtil {
// to its input and thus may be replaced with a bitcast.
//
// Precondition: Both input_shape and output_shape have explicit layouts.
- static bool TransposeIsBitcast(
- const Shape& input_shape, const Shape& output_shape,
- tensorflow::gtl::ArraySlice<int64> dimension_mapping);
+ static bool TransposeIsBitcast(const Shape& input_shape,
+ const Shape& output_shape,
+ absl::Span<const int64> dimension_mapping);
// Returns whether a reshape from "input_shape" to "output_shape" is a
// bitcast.
@@ -597,8 +594,8 @@ class ShapeUtil {
// layout). The layout of 'input_shape' is kept fixed. Returns
// 'output_shape_with_layout' if such a layout can be found, and an error
// otherwise.
- static tensorflow::gtl::optional<Shape> AlignLayouts(
- const Shape& input_shape, const Shape& output_shape);
+ static absl::optional<Shape> AlignLayouts(const Shape& input_shape,
+ const Shape& output_shape);
// Returns a shape with the given dimension deleted.
// For example:
@@ -621,12 +618,12 @@ class ShapeUtil {
// continue, or false otherwise.
//
// visitor_function must be a callable of type
- // StatusOr<bool>(ArraySlice<int64>) or compatible.
+ // StatusOr<bool>(Span<const int64>) or compatible.
template <typename FnType>
static Status ForEachIndexWithStatus(const Shape& shape,
- tensorflow::gtl::ArraySlice<int64> base,
- tensorflow::gtl::ArraySlice<int64> count,
- tensorflow::gtl::ArraySlice<int64> incr,
+ absl::Span<const int64> base,
+ absl::Span<const int64> count,
+ absl::Span<const int64> incr,
const FnType& visitor_function) {
return ForEachIndexInternal(shape, base, count, incr, visitor_function);
}
@@ -648,13 +645,12 @@ class ShapeUtil {
}
template <typename FnType>
- static void ForEachIndex(const Shape& shape,
- tensorflow::gtl::ArraySlice<int64> base,
- tensorflow::gtl::ArraySlice<int64> count,
- tensorflow::gtl::ArraySlice<int64> incr,
+ static void ForEachIndex(const Shape& shape, absl::Span<const int64> base,
+ absl::Span<const int64> count,
+ absl::Span<const int64> incr,
const FnType& visitor_function) {
ForEachIndexWithStatus(shape, base, count, incr,
- [&](tensorflow::gtl::ArraySlice<int64> indices) {
+ [&](absl::Span<const int64> indices) {
return StatusOr<bool>(visitor_function(indices));
})
.IgnoreError();
@@ -676,7 +672,7 @@ class ShapeUtil {
template <typename FnType>
static void ForEachIndex(const Shape& shape, const FnType& visitor_function) {
ForEachIndexWithStatus(shape,
- [&](tensorflow::gtl::ArraySlice<int64> indices) {
+ [&](absl::Span<const int64> indices) {
return StatusOr<bool>(visitor_function(indices));
})
.IgnoreError();
@@ -687,18 +683,18 @@ class ShapeUtil {
// matter.
//
// visitor_function must be a callable of type
- // void(ArraySlice<int64>) or compatible.
+ // void(Span<const int64>) or compatible.
template <typename FnType>
static void ForEachIndexParallel(const Shape& shape,
- tensorflow::gtl::ArraySlice<int64> base,
- tensorflow::gtl::ArraySlice<int64> count,
- tensorflow::gtl::ArraySlice<int64> incr,
+ absl::Span<const int64> base,
+ absl::Span<const int64> count,
+ absl::Span<const int64> incr,
const FnType& visitor_function) {
// The parallel version of ForEachIndexInternal can never fail.
CHECK(ForEachIndexInternal(
shape, base, count, incr,
- [&visitor_function](tensorflow::gtl::ArraySlice<int64> indexes)
- -> StatusOr<bool> {
+ [&visitor_function](
+ absl::Span<const int64> indexes) -> StatusOr<bool> {
visitor_function(indexes);
return true;
},
@@ -720,9 +716,9 @@ class ShapeUtil {
template <typename FnType>
static Status ForEachIndexInternal(const Shape& shape,
- tensorflow::gtl::ArraySlice<int64> base,
- tensorflow::gtl::ArraySlice<int64> count,
- tensorflow::gtl::ArraySlice<int64> incr,
+ absl::Span<const int64> base,
+ absl::Span<const int64> count,
+ absl::Span<const int64> incr,
const FnType& visitor_function,
bool parallel = false) {
if (ShapeUtil::IsZeroElementArray(shape)) {
@@ -737,13 +733,13 @@ class ShapeUtil {
int64 n = -1;
std::vector<int64> indexes(base.begin(), base.end());
const int kNumThreads = tensorflow::port::NumSchedulableCPUs();
- tensorflow::gtl::optional<tensorflow::thread::ThreadPool> pool;
+ absl::optional<tensorflow::thread::ThreadPool> pool;
if (parallel) {
pool.emplace(tensorflow::Env::Default(), "foreach", kNumThreads);
}
while (n < rank) {
- if (pool != tensorflow::gtl::nullopt) {
+ if (pool != absl::nullopt) {
pool->Schedule(
[indexes, &visitor_function] { visitor_function(indexes); });
} else {
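The shape_util.h changes above follow one mechanical mapping: tensorflow::gtl::ArraySlice<T> becomes absl::Span<const T>, gtl::MutableArraySlice<T> becomes absl::Span<T>, and gtl::optional/gtl::InlinedVector become their absl counterparts. A minimal standalone sketch of what call sites look like after the migration (hypothetical Sum/Fill helpers, int64_t standing in for TF's int64 alias):

#include <cstdint>
#include <vector>
#include "absl/types/span.h"

// Hypothetical helpers: read-only parameters take Span<const T>,
// mutable ones take Span<T>.
int64_t Sum(absl::Span<const int64_t> xs) {  // was gtl::ArraySlice<int64>
  int64_t total = 0;
  for (int64_t x : xs) total += x;
  return total;
}

void Fill(absl::Span<int64_t> xs, int64_t v) {  // was gtl::MutableArraySlice<int64>
  for (int64_t& x : xs) x = v;
}

int main() {
  std::vector<int64_t> dims = {10, 100, 1000};
  Sum(dims);                      // containers convert implicitly, as before
  Sum({1, 2, 3});                 // so do initializer lists
  Fill(absl::MakeSpan(dims), 0);  // mutable spans built via absl::MakeSpan
  return 0;
}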
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index e5dd62ae9a..6ca4085aaf 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include <numeric>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
@@ -23,8 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace {
@@ -705,11 +705,10 @@ TEST(ShapeUtilTest, ForEachIndex) {
Shape shape = ShapeUtil::MakeShape(F32, data.dimensions);
// Increments at every invocation.
int invocations = 0;
- auto increment_func =
- [&invocations](tensorflow::gtl::ArraySlice<int64> indexes) {
- invocations++;
- return true;
- };
+ auto increment_func = [&invocations](absl::Span<const int64> indexes) {
+ invocations++;
+ return true;
+ };
std::vector<int64> zero_base(data.dimensions.size(), 0);
std::vector<int64> step(data.dimensions.size(), 1);
@@ -726,8 +725,7 @@ TEST(ShapeUtilTest, ForEachIndexWithStatus) {
// Increments at every invocation.
int invocations = 0;
auto increment_func =
- [&invocations](
- tensorflow::gtl::ArraySlice<int64> indexes) -> StatusOr<bool> {
+ [&invocations](absl::Span<const int64> indexes) -> StatusOr<bool> {
if (++invocations == 5) {
return Unimplemented("Cannot increment beyond 5.");
}
@@ -748,7 +746,7 @@ TEST(ShapeUtilTest, ForEachIndexParallel) {
Shape shape = ShapeUtil::MakeShape(F32, {10, 10});
int64 output[10][10];
int init = 5;
- auto set_func = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
+ auto set_func = [&](absl::Span<const int64> indexes) {
output[indexes[0]][indexes[1]] = init + indexes[0] + indexes[1];
};
@@ -849,13 +847,13 @@ TEST(ShapeUtilTest, PermuteDimensionsLayout) {
std::iota(layout.begin(), layout.end(), 0);
do {
Shape s = ShapeUtil::MakeShapeWithLayout(F32, {10, 100, 1000}, layout);
- SCOPED_TRACE(tensorflow::strings::StrCat("s=", ShapeUtil::HumanString(s)));
+ SCOPED_TRACE(absl::StrCat("s=", ShapeUtil::HumanString(s)));
std::vector<int64> permutation(3);
std::iota(permutation.begin(), permutation.end(), 0);
do {
- SCOPED_TRACE(tensorflow::strings::StrCat(
- "permutation=", tensorflow::str_util::Join(permutation, ",")));
+ SCOPED_TRACE(
+ absl::StrCat("permutation=", absl::StrJoin(permutation, ",")));
// TransposeIsBitcast takes the inverse of the permutation that
// PermuteDimensions takes.
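The ForEachIndex tests above rely on the base/count/incr contract documented in shape_util.h: visit every index vector reachable from base by steps of incr while each index stays below base + count. A standalone sketch of that iteration, under the assumption that it matches the documented contract (the real ForEachIndexInternal also respects layout; the row-major ordering here is an illustrative assumption):

#include <cstdint>
#include <functional>
#include <vector>
#include "absl/types/span.h"

void ForEachIndexSketch(
    absl::Span<const int64_t> base, absl::Span<const int64_t> count,
    absl::Span<const int64_t> incr,
    const std::function<void(absl::Span<const int64_t>)>& visit) {
  const int64_t rank = base.size();
  for (int64_t d = 0; d < rank; ++d) {
    if (count[d] <= 0) return;  // empty extent: nothing to visit
  }
  std::vector<int64_t> indexes(base.begin(), base.end());
  while (true) {
    visit(indexes);
    int64_t d = rank - 1;
    for (; d >= 0; --d) {  // advance like an odometer
      indexes[d] += incr[d];
      if (indexes[d] < base[d] + count[d]) break;
      indexes[d] = base[d];
    }
    if (d < 0) return;  // wrapped past the most-significant dimension
  }
}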
diff --git a/tensorflow/compiler/xla/sparse_index_array.cc b/tensorflow/compiler/xla/sparse_index_array.cc
index 31844abd89..1c135dda86 100644
--- a/tensorflow/compiler/xla/sparse_index_array.cc
+++ b/tensorflow/compiler/xla/sparse_index_array.cc
@@ -33,7 +33,7 @@ SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank,
}
SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank,
- tensorflow::gtl::ArraySlice<int64> indices)
+ absl::Span<const int64> indices)
: SparseIndexArray(max_indices, rank,
std::vector<int64>(indices.begin(), indices.end())) {}
@@ -48,25 +48,24 @@ int64 SparseIndexArray::index_count() const {
return indices_.size() / rank_;
}
-tensorflow::gtl::ArraySlice<int64> SparseIndexArray::At(
+absl::Span<const int64> SparseIndexArray::At(
int64 sparse_element_number) const {
CHECK_GT(rank_, 0);
CHECK_GE(sparse_element_number, 0);
CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size());
- return tensorflow::gtl::ArraySlice<int64>(
+ return absl::Span<const int64>(
indices_.data() + rank_ * sparse_element_number, rank_);
}
-tensorflow::gtl::MutableArraySlice<int64> SparseIndexArray::At(
- int64 sparse_element_number) {
+absl::Span<int64> SparseIndexArray::At(int64 sparse_element_number) {
CHECK_GT(rank_, 0);
CHECK_GE(sparse_element_number, 0);
CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size());
- return tensorflow::gtl::MutableArraySlice<int64>(
- indices_.data() + rank_ * sparse_element_number, rank_);
+ return absl::Span<int64>(indices_.data() + rank_ * sparse_element_number,
+ rank_);
}
-void SparseIndexArray::Append(tensorflow::gtl::ArraySlice<int64> index) {
+void SparseIndexArray::Append(absl::Span<const int64> index) {
CHECK_GT(rank_, 0);
CHECK_EQ(index.size(), rank_);
indices_.insert(indices_.end(), index.begin(), index.end());
@@ -90,12 +89,12 @@ bool SparseIndexArray::Validate(const Shape& shape) const {
if (num_indices < 2) {
return true;
}
- tensorflow::gtl::ArraySlice<int64> last = At(0);
+ absl::Span<const int64> last = At(0);
if (!IndexUtil::IndexInBounds(shape, last)) {
return false;
}
for (int64 n = 1; n < num_indices; ++n) {
- tensorflow::gtl::ArraySlice<int64> next = At(n);
+ absl::Span<const int64> next = At(n);
if (!IndexUtil::IndexInBounds(shape, next)) {
return false;
}
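At() above encodes the storage convention for sparse indices: one flat int64 buffer holds rank entries per sparse element, so element n occupies the half-open range [n*rank, n*rank + rank). A two-line standalone restatement (hypothetical AtSketch, int64_t for TF's int64):

#include <cstdint>
#include <vector>
#include "absl/types/span.h"

// Element n of a rank-r sparse index array is the r-entry slice at offset n*r.
absl::Span<const int64_t> AtSketch(const std::vector<int64_t>& indices,
                                   int64_t rank, int64_t n) {
  return absl::Span<const int64_t>(indices.data() + rank * n, rank);
}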
diff --git a/tensorflow/compiler/xla/sparse_index_array.h b/tensorflow/compiler/xla/sparse_index_array.h
index f2ce22d672..a96d483462 100644
--- a/tensorflow/compiler/xla/sparse_index_array.h
+++ b/tensorflow/compiler/xla/sparse_index_array.h
@@ -20,10 +20,11 @@ limitations under the License.
#include <vector>
+#include "absl/container/inlined_vector.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
@@ -64,7 +65,7 @@ class SparseIndexArray {
SparseIndexArray(int64 max_indices, int64 rank,
std::vector<int64> indices = {});
SparseIndexArray(int64 max_indices, int64 rank,
- tensorflow::gtl::ArraySlice<int64> indices);
+ absl::Span<const int64> indices);
// Returns the number of elements represented by the indices stored in the
// array.
@@ -72,12 +73,12 @@ class SparseIndexArray {
// Returns a slice that refers to the given sparse index number. The argument
// must be in the range [0, element_count()).
- tensorflow::gtl::ArraySlice<int64> At(int64 sparse_element_number) const;
- tensorflow::gtl::MutableArraySlice<int64> At(int64 sparse_element_number);
+ absl::Span<const int64> At(int64 sparse_element_number) const;
+ absl::Span<int64> At(int64 sparse_element_number);
// Adds the given index at the end of the array. The new size of the
// SparseIndexArray must not exceed `max_indices`.
- void Append(tensorflow::gtl::ArraySlice<int64> index);
+ void Append(absl::Span<const int64> index);
// Removes all indices from the array.
void Clear();
@@ -95,8 +96,8 @@ class SparseIndexArray {
int64 max_indices() const { return max_indices_; }
// Returns a pointer to the int64 array that holds the sparse indices.
- tensorflow::gtl::MutableArraySlice<int64> mutable_data() { return &indices_; }
- tensorflow::gtl::ArraySlice<int64> data() const { return indices_; }
+ absl::Span<int64> mutable_data() { return absl::MakeSpan(indices_); }
+ absl::Span<const int64> data() const { return indices_; }
// Sorts this sparse index array along with the set of corresponding values.
// The indices and values are sorted in the lexicographic order of the
@@ -114,7 +115,7 @@ class SparseIndexArray {
// std::cout << v[0] << ", " << v[1] << ", " << v[2] << std::endl;
//
template <typename NativeT>
- void SortWithValues(tensorflow::gtl::MutableArraySlice<NativeT> values);
+ void SortWithValues(absl::Span<NativeT> values);
private:
std::vector<int64> indices_;
@@ -123,8 +124,7 @@ class SparseIndexArray {
};
template <typename NativeT>
-void SparseIndexArray::SortWithValues(
- tensorflow::gtl::MutableArraySlice<NativeT> values) {
+void SparseIndexArray::SortWithValues(absl::Span<NativeT> values) {
int64 num_elements = index_count();
CHECK_EQ(values.size(), num_elements);
std::vector<int64> sort_order;
@@ -139,7 +139,7 @@ void SparseIndexArray::SortWithValues(
// Reorder the array elements according to sort_order. Work through the array
// and follow cycles so we can do the reorder in-place.
- tensorflow::gtl::InlinedVector<int64, 8> saved_index(rank());
+ absl::InlinedVector<int64, 8> saved_index(rank());
for (int64 i = 0; i < num_elements; ++i) {
// sort_order[i] == -1 indicates the element has already been copied.
if (sort_order[i] < 0) {
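The comment above ("follow cycles so we can do the reorder in-place") describes the classic in-place permutation-application trick. A standalone sketch, assuming sort_order[i] names the old position whose element belongs at position i, with double standing in for the templated NativeT:

#include <cstddef>
#include <cstdint>
#include <vector>

// Applies values_new[i] = values_old[sort_order[i]] in place, marking each
// consumed slot of sort_order with -1 so every cycle is walked exactly once.
void ApplyPermutationInPlace(std::vector<double>& values,
                             std::vector<int64_t> sort_order) {
  for (size_t i = 0; i < values.size(); ++i) {
    if (sort_order[i] < 0) continue;  // already placed by an earlier cycle
    const double saved = values[i];   // head of this cycle
    size_t j = i;
    while (true) {
      const size_t src = static_cast<size_t>(sort_order[j]);
      sort_order[j] = -1;
      if (src == i) {
        values[j] = saved;  // cycle closes; drop in the saved head
        break;
      }
      values[j] = values[src];
      j = src;
    }
  }
}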
diff --git a/tensorflow/compiler/xla/sparse_index_array_test.cc b/tensorflow/compiler/xla/sparse_index_array_test.cc
index 7377f88958..e54057c400 100644
--- a/tensorflow/compiler/xla/sparse_index_array_test.cc
+++ b/tensorflow/compiler/xla/sparse_index_array_test.cc
@@ -33,7 +33,7 @@ TEST(SparseIndexArrayTest, Sort) {
std::vector<double> values = {
12.0, 13.0, 11.0, 15.0, 14.0, 16.0,
};
- a.SortWithValues<double>(&values);
+ a.SortWithValues<double>(absl::MakeSpan(values));
ASSERT_EQ(a.data(), std::vector<int64>({1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5,
6, 7, 6, 7, 8}));
ASSERT_EQ(values, std::vector<double>({11.0, 12.0, 13.0, 14.0, 15.0, 16.0}));
diff --git a/tensorflow/compiler/xla/status_macros.cc b/tensorflow/compiler/xla/status_macros.cc
index a6b1f9004f..b88fe367d7 100644
--- a/tensorflow/compiler/xla/status_macros.cc
+++ b/tensorflow/compiler/xla/status_macros.cc
@@ -17,9 +17,8 @@ limitations under the License.
#include <algorithm>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stacktrace.h"
@@ -37,8 +36,7 @@ static void LogError(const Status& status, const char* filename, int line,
if (TF_PREDICT_TRUE(log_severity != tensorflow::NUM_SEVERITIES)) {
string stack_trace;
if (should_log_stack_trace) {
- stack_trace =
- tensorflow::strings::StrCat("\n", tensorflow::CurrentStackTrace());
+ stack_trace = absl::StrCat("\n", tensorflow::CurrentStackTrace());
}
switch (log_severity) {
case tensorflow::INFO:
@@ -142,17 +140,15 @@ Status MakeErrorStream::Impl::GetStatus() {
is_done_ = true;
const string& stream_str = stream_.str();
- const string str =
- prior_message_handling_ == kAppendToPriorMessage
- ? tensorflow::strings::StrCat(prior_message_, stream_str)
- : tensorflow::strings::StrCat(stream_str, prior_message_);
+ const string str = prior_message_handling_ == kAppendToPriorMessage
+ ? absl::StrCat(prior_message_, stream_str)
+ : absl::StrCat(stream_str, prior_message_);
if (TF_PREDICT_FALSE(str.empty())) {
- return MakeError(file_, line_, code_,
- tensorflow::strings::StrCat(
- str, "Error without message at ", file_, ":", line_),
- true /* should_log */,
- tensorflow::ERROR /* log_severity */,
- should_log_stack_trace_);
+ return MakeError(
+ file_, line_, code_,
+ absl::StrCat(str, "Error without message at ", file_, ":", line_),
+ true /* should_log */, tensorflow::ERROR /* log_severity */,
+ should_log_stack_trace_);
} else {
return MakeError(file_, line_, code_, str, should_log_, log_severity_,
should_log_stack_trace_);
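The reflowed ternary above preserves the original ordering logic: kAppendToPriorMessage appends the streamed text to the prior message, otherwise the streamed text comes first. A hypothetical free-function restatement (the second enumerator name is an assumption; only kAppendToPriorMessage appears in the diff):

#include <string>
#include "absl/strings/str_cat.h"

enum PriorMessageHandling { kAppendToPriorMessage, kPrependToPriorMessage };

// Hypothetical restatement of the message-ordering logic in GetStatus().
std::string CombineMessages(PriorMessageHandling mode, const std::string& prior,
                            const std::string& streamed) {
  return mode == kAppendToPriorMessage ? absl::StrCat(prior, streamed)
                                       : absl::StrCat(streamed, prior);
}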
diff --git a/tensorflow/compiler/xla/test.h b/tensorflow/compiler/xla/test.h
index 87a8c5f3a5..a657554dc2 100644
--- a/tensorflow/compiler/xla/test.h
+++ b/tensorflow/compiler/xla/test.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPLIER_XLA_TEST_H_
-#define TENSORFLOW_COMPLIER_XLA_TEST_H_
+#ifndef TENSORFLOW_COMPILER_XLA_TEST_H_
+#define TENSORFLOW_COMPILER_XLA_TEST_H_
// This header includes gmock.h and enables the use of gmock matchers in tests
// in third_party/tensorflow/compiler/xla.
@@ -45,4 +45,4 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
-#endif // TENSORFLOW_COMPLIER_XLA_TEST_H_
+#endif // TENSORFLOW_COMPILER_XLA_TEST_H_
diff --git a/tensorflow/compiler/xla/test_helpers.h b/tensorflow/compiler/xla/test_helpers.h
index 8918350135..3ede5e6e38 100644
--- a/tensorflow/compiler/xla/test_helpers.h
+++ b/tensorflow/compiler/xla/test_helpers.h
@@ -19,9 +19,9 @@ limitations under the License.
#include <list>
#include <vector>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/test.h"
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 76f2e519ae..d0bda45cf8 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -43,6 +43,7 @@ cc_library(
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
alwayslink = True,
)
@@ -68,14 +69,14 @@ cc_library(
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/core:lib",
- "//tensorflow/core:stream_executor_headers_lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
)
@@ -98,6 +99,9 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
],
)
@@ -113,7 +117,6 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:backend",
@@ -127,6 +130,10 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
],
)
@@ -144,6 +151,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -187,7 +195,6 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:global_data",
@@ -201,6 +208,9 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -274,6 +284,8 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//third_party/eigen3",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
)
@@ -385,6 +397,8 @@ xla_test(
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
],
)
@@ -551,6 +565,8 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -567,8 +583,7 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
+ "@com_google_absl//absl/types:span",
],
)
@@ -591,8 +606,8 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/types:span",
],
)
@@ -614,12 +629,11 @@ xla_test(
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
- "//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -665,6 +679,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
],
)
@@ -683,7 +698,6 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
@@ -691,6 +705,7 @@ xla_test(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
)
@@ -715,10 +730,8 @@ xla_test(
deps = [
":client_library_test_base",
":hlo_test_base",
- "//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
- "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
@@ -742,7 +755,6 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
@@ -750,6 +762,7 @@ xla_test(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
)
@@ -825,7 +838,10 @@ xla_test(
timeout = "long",
srcs = ["convolution_test.cc"],
shard_count = 25,
- deps = CONVOLUTION_TEST_DEPS,
+ deps = CONVOLUTION_TEST_DEPS + [
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ ],
)
xla_test(
@@ -835,7 +851,10 @@ xla_test(
backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]},
backends = ["gpu"],
shard_count = 25,
- deps = CONVOLUTION_TEST_DEPS,
+ deps = CONVOLUTION_TEST_DEPS + [
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ ],
)
xla_test(
@@ -886,6 +905,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -919,6 +939,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
)
@@ -995,6 +1016,10 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1068,6 +1093,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -1103,7 +1129,6 @@ xla_test(
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
@@ -1121,6 +1146,9 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1149,6 +1177,9 @@ xla_test_library(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1156,6 +1187,7 @@ xla_test(
name = "reduce_window_test",
timeout = "long",
srcs = [],
+ shard_count = 20,
tags = [
"enable_for_xla_interpreter",
"optonly",
@@ -1211,6 +1243,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -1221,12 +1254,12 @@ xla_test(
"enable_for_xla_interpreter",
],
deps = [
- ":client_library_test_base",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
)
@@ -1237,12 +1270,12 @@ xla_test(
"enable_for_xla_interpreter",
],
deps = [
- ":client_library_test_base",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
)
@@ -1286,6 +1319,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -1351,6 +1385,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -1401,7 +1436,6 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
@@ -1412,6 +1446,9 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -1426,11 +1463,11 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1444,14 +1481,12 @@ xla_test(
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
- "//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
@@ -1461,7 +1496,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
- "//tensorflow/core:test",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1481,6 +1516,8 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
],
)
@@ -1541,17 +1578,16 @@ xla_test(
],
deps = [
"//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
+ "@com_google_absl//absl/algorithm:container",
],
)
@@ -1616,8 +1652,8 @@ xla_test(
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1630,12 +1666,13 @@ xla_test(
deps = [
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1648,7 +1685,6 @@ xla_test(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
- "//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:global_data",
@@ -1659,6 +1695,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/strings",
],
)
@@ -1752,6 +1789,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
],
)
@@ -1773,6 +1811,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/stream_executor",
+ "@com_google_absl//absl/memory",
"@llvm//:core",
],
)
@@ -1793,6 +1832,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1806,15 +1846,11 @@ xla_test(
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_parser",
- "//tensorflow/compiler/xla/service:hlo_runner",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
@@ -1824,6 +1860,8 @@ xla_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//third_party/eigen3",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1831,18 +1869,12 @@ xla_test(
name = "multioutput_fusion_test",
srcs = ["multioutput_fusion_test.cc"],
deps = [
- "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_runner",
- "//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -1850,6 +1882,9 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
)
@@ -1876,7 +1911,6 @@ xla_test(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -1884,6 +1918,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -1978,16 +2013,15 @@ xla_test(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
- "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/types:span",
],
)
@@ -2010,6 +2044,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
],
)
@@ -2051,6 +2086,7 @@ xla_test(
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -2087,6 +2123,7 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
"//tensorflow/core:test",
],
)
@@ -2094,19 +2131,15 @@ xla_test(
xla_test(
name = "iota_test",
srcs = ["iota_test.cc"],
- blacklisted_backends = [
- "cpu",
- "gpu",
- ],
+ shard_count = 30,
tags = [
"enable_for_xla_interpreter",
+ # Requires optimized builds; iota_test_cpu is very slow in fastbuild.
+ "optonly",
],
deps = [
":client_library_test_base",
- ":literal_test_util",
":xla_internal_test_main",
- "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/core:lib",
- "//tensorflow/core:test",
],
)
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 74f2e36f82..c257566fb2 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -225,10 +226,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
0x8000000000000000LL,
0x8000000000000000LL,
1};
- std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
- auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
+ Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
+ auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
std::unique_ptr<GlobalData> lhs_data =
- client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
+ client_->TransferToServer(lhs_literal).ConsumeValueOrDie();
std::vector<uint64> rhs{1,
0x7FFFFFFFFFFFFFFLL,
@@ -239,10 +240,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
0,
1,
0x8000000000000000LL};
- std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
- auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
+ Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
+ auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
std::unique_ptr<GlobalData> rhs_data =
- client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();
+ client_->TransferToServer(rhs_literal).ConsumeValueOrDie();
Add(lhs_param, rhs_param);
@@ -265,10 +266,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
1,
0,
-1};
- std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<int64>({lhs});
- auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
+ Literal lhs_literal = LiteralUtil::CreateR1<int64>({lhs});
+ auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
std::unique_ptr<GlobalData> lhs_data =
- client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
+ client_->TransferToServer(lhs_literal).ConsumeValueOrDie();
std::vector<int64> rhs{-1,
0,
@@ -278,10 +279,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
0x7FFFFFFFFFFFFFFLL,
0x7FFFFFFFFFFFFFFFLL,
0x7FFFFFFFFFFFFFFFLL};
- std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<int64>({rhs});
- auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
+ Literal rhs_literal = LiteralUtil::CreateR1<int64>({rhs});
+ auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
std::unique_ptr<GlobalData> rhs_data =
- client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();
+ client_->TransferToServer(rhs_literal).ConsumeValueOrDie();
Sub(lhs_param, rhs_param);
@@ -293,6 +294,22 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
ComputeAndCompareR1<int64>(&b, expected, {lhs_data.get(), rhs_data.get()});
}
+XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) {
+ XlaBuilder b(TestName());
+
+ std::vector<uint64> lhs{static_cast<uint64>(0x8000000000000000ULL)};
+ Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
+ auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
+
+ std::vector<uint64> rhs{static_cast<uint64>(0x7FFFFFFFFFFFFFFFULL)};
+ Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
+ auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
+
+ Lt(lhs_param, rhs_param);
+
+ ComputeAndCompare(&b, {std::move(lhs_literal), std::move(rhs_literal)});
+}
+
TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
const int count = GetParam();
XlaBuilder builder(TestName());
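The test churn in this file is one pattern repeated: LiteralUtil::CreateR1 now returns Literal by value rather than std::unique_ptr<Literal>, so call sites drop the '*' and '->' and can std::move a literal when handing off ownership (as CmpTwoConstantU64s does above). A standalone sketch with a hypothetical stand-in type:

#include <utility>
#include <vector>

// Hypothetical stand-in for xla::Literal, only to show the ownership change.
struct FakeLiteral {
  std::vector<float> data;
  size_t shape() const { return data.size(); }  // placeholder for .shape()
};

FakeLiteral CreateR1(std::vector<float> values) {  // was std::unique_ptr<...>
  return FakeLiteral{std::move(values)};           // returned by value
}

void TransferToServer(const FakeLiteral& literal) { /* const ref, no '*' */ }

int main() {
  FakeLiteral lit = CreateR1({1.0f, 2.0f});
  TransferToServer(lit);              // pass by reference, no dereference
  FakeLiteral sunk = std::move(lit);  // or transfer ownership explicitly
  (void)sunk;
  return 0;
}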
@@ -303,16 +320,16 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
b_values.push_back(2 * i / static_cast<float>(count + 2));
}
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({a_values});
+ Literal a_literal = LiteralUtil::CreateR1<float>({a_values});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
auto a_constant = ConstantR1<float>(&builder, a_values);
- auto a_param = Parameter(&builder, 0, a_literal->shape(), "a_param");
+ auto a_param = Parameter(&builder, 0, a_literal.shape(), "a_param");
- std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR1<float>({b_values});
+ Literal b_literal = LiteralUtil::CreateR1<float>({b_values});
std::unique_ptr<GlobalData> b_data =
- client_->TransferToServer(*b_literal).ConsumeValueOrDie();
- auto b_constant = Parameter(&builder, 1, a_literal->shape(), "b_param");
+ client_->TransferToServer(b_literal).ConsumeValueOrDie();
+ auto b_constant = Parameter(&builder, 1, a_literal.shape(), "b_param");
auto b_param = ConstantR1<float>(&builder, b_values);
auto sum1 = Add(a_constant, b_constant);
@@ -411,7 +428,65 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) {
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
-XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) {
+class IntegerDivideOpTest : public ArrayElementwiseOpTest {
+ protected:
+ template <typename T>
+ void TestDivRem(absl::Span<const T> dividends, absl::Span<const T> divisors,
+ absl::Span<const T> quotients,
+ absl::Span<const T> remainders) {
+ {
+ XlaBuilder builder(TestName());
+ XlaOp dividend;
+ XlaOp divisor;
+ auto dividend_data =
+ CreateR1Parameter<T>(dividends, 0, "dividend", &builder, &dividend);
+ auto divisor_data =
+ CreateR1Parameter<T>(divisors, 1, "divisor", &builder, &divisor);
+ Div(dividend, divisor);
+
+ ComputeAndCompareR1<T>(&builder, quotients,
+ {dividend_data.get(), divisor_data.get()});
+ }
+
+ // Test with a compile-time constant divisor.
+ {
+ XlaBuilder builder(TestName());
+ XlaOp dividend;
+ auto dividend_data =
+ CreateR1Parameter<T>(dividends, 0, "dividend", &builder, &dividend);
+ Div(dividend, ConstantR1<T>(&builder, divisors));
+
+ ComputeAndCompareR1<T>(&builder, quotients, {dividend_data.get()});
+ }
+
+ {
+ XlaBuilder builder(TestName());
+ XlaOp dividend;
+ XlaOp divisor;
+ auto dividend_data =
+ CreateR1Parameter<T>(dividends, 0, "dividend", &builder, &dividend);
+ auto divisor_data =
+ CreateR1Parameter<T>(divisors, 1, "divisor", &builder, &divisor);
+ Rem(dividend, divisor);
+
+ ComputeAndCompareR1<T>(&builder, remainders,
+ {dividend_data.get(), divisor_data.get()});
+ }
+
+ // Test with a compile-time constant divisor.
+ {
+ XlaBuilder builder(TestName());
+ XlaOp dividend;
+ auto dividend_data =
+ CreateR1Parameter<T>(dividends, 0, "dividend", &builder, &dividend);
+ Rem(dividend, ConstantR1<T>(&builder, divisors));
+
+ ComputeAndCompareR1<T>(&builder, remainders, {dividend_data.get()});
+ }
+ }
+};
+
+XLA_TEST_F(IntegerDivideOpTest, DivS32s) {
// clang-format off
// Some interesting values to test.
std::vector<int32> vals = {
@@ -435,58 +510,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) {
}
}
- {
- XlaBuilder builder(TestName());
- XlaOp dividend;
- XlaOp divisor;
- auto dividend_data =
- CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, &dividend);
- auto divisor_data =
- CreateR1Parameter<int32>(divisors, 1, "divisor", &builder, &divisor);
- Div(dividend, divisor);
-
- ComputeAndCompareR1<int32>(&builder, quotients,
- {dividend_data.get(), divisor_data.get()});
- }
-
- // Test with a compile-time constant divisor.
- {
- XlaBuilder builder(TestName());
- XlaOp dividend;
- auto dividend_data =
- CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, &dividend);
- Div(dividend, ConstantR1<int32>(&builder, divisors));
-
- ComputeAndCompareR1<int32>(&builder, quotients, {dividend_data.get()});
- }
-
- {
- XlaBuilder builder(TestName());
- XlaOp dividend;
- XlaOp divisor;
- auto dividend_data =
- CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, &dividend);
- auto divisor_data =
- CreateR1Parameter<int32>(divisors, 1, "divisor", &builder, &divisor);
- Rem(dividend, divisor);
-
- ComputeAndCompareR1<int32>(&builder, remainders,
- {dividend_data.get(), divisor_data.get()});
- }
+ TestDivRem<int32>(dividends, divisors, quotients, remainders);
+}
- // Test with a compile-time constant divisor.
- {
- XlaBuilder builder(TestName());
- XlaOp dividend;
- auto dividend_data =
- CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, &dividend);
- Rem(dividend, ConstantR1<int32>(&builder, divisors));
+XLA_TEST_F(IntegerDivideOpTest, SignedOverflow) {
+ std::vector<int32> dividends = {5, INT32_MIN}, divisors = {0, -1},
+ quotients = {-1, INT32_MIN}, remainders = {5, 0};
- ComputeAndCompareR1<int32>(&builder, remainders, {dividend_data.get()});
- }
+ TestDivRem<int32>(dividends, divisors, quotients, remainders);
}
-XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) {
+XLA_TEST_F(IntegerDivideOpTest, DivU32s) {
// clang-format off
// Some interesting values to test.
std::vector<uint32> vals = {
@@ -506,53 +540,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) {
}
}
- {
- XlaBuilder builder(TestName());
- XlaOp dividend;
- XlaOp divisor;
- auto dividend_data = CreateR1Parameter<uint32>(dividends, 0, "dividend",
- &builder, &dividend);
- auto divisor_data =
- CreateR1Parameter<uint32>(divisors, 1, "divisor", &builder, &divisor);
- Div(dividend, divisor);
-
- ComputeAndCompareR1<uint32>(&builder, quotients,
- {dividend_data.get(), divisor_data.get()});
- }
-
- {
- XlaBuilder builder(TestName());
- XlaOp dividend;
- auto dividend_data = CreateR1Parameter<uint32>(dividends, 0, "dividend",
- &builder, &dividend);
- Div(dividend, ConstantR1<uint32>(&builder, divisors));
-
- ComputeAndCompareR1<uint32>(&builder, quotients, {dividend_data.get()});
- }
-
- {
- XlaBuilder builder(TestName());
- XlaOp dividend;
- XlaOp divisor;
- auto dividend_data = CreateR1Parameter<uint32>(dividends, 0, "dividend",
- &builder, &dividend);
- auto divisor_data =
- CreateR1Parameter<uint32>(divisors, 1, "divisor", &builder, &divisor);
- Rem(dividend, divisor);
-
- ComputeAndCompareR1<uint32>(&builder, remainders,
- {dividend_data.get(), divisor_data.get()});
- }
+ TestDivRem<uint32>(dividends, divisors, quotients, remainders);
+}
- {
- XlaBuilder builder(TestName());
- XlaOp dividend;
- auto dividend_data = CreateR1Parameter<uint32>(dividends, 0, "dividend",
- &builder, &dividend);
- Rem(dividend, ConstantR1<uint32>(&builder, divisors));
+XLA_TEST_F(IntegerDivideOpTest, UnsignedOverflow) {
+ std::vector<int32> dividends = {5}, divisors = {0}, quotients = {-1},
+ remainders = {5};
- ComputeAndCompareR1<uint32>(&builder, remainders, {dividend_data.get()});
- }
+ TestDivRem<int32>(dividends, divisors, quotients, remainders);
}
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) {
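The SignedOverflow and UnsignedOverflow tests above pin down the edge-case expectations encoded in their value vectors: x / 0 evaluates to -1, x % 0 to x, INT32_MIN / -1 to INT32_MIN, and INT32_MIN % -1 to 0 (note that UnsignedOverflow, as written, still exercises the int32 path). A standalone sketch of those expectations, which plain C++ '/' and '%' cannot express because these inputs are undefined behavior:

#include <climits>
#include <cstdint>

// Mirrors the quotients/remainders encoded in the tests above; the edge cases
// are handled explicitly since native division would be UB for them.
int32_t ExpectedDiv(int32_t x, int32_t y) {
  if (y == 0) return -1;                            // x / 0 -> -1
  if (x == INT32_MIN && y == -1) return INT32_MIN;  // wraps instead of trapping
  return x / y;
}

int32_t ExpectedRem(int32_t x, int32_t y) {
  if (y == 0) return x;                             // x % 0 -> x
  if (x == INT32_MIN && y == -1) return 0;
  return x % y;
}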
@@ -1426,12 +1421,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};
std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> param_literal = LiteralUtil::CreateR1<float>(values);
+ Literal param_literal = LiteralUtil::CreateR1<float>(values);
std::unique_ptr<GlobalData> param_data =
- client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param_literal).ConsumeValueOrDie();
auto sum = ConstantR0<float>(&b, 0.0f);
- auto param = Parameter(&b, 0, param_literal->shape(), "param");
+ auto param = Parameter(&b, 0, param_literal.shape(), "param");
for (float exponent : exponents) {
sum = Add(sum, Pow(param, ConstantR0<float>(&b, exponent)));
}
@@ -1454,14 +1449,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Pow(Exp(param0), param1);
std::vector<float> expected(values0.size());
@@ -1479,14 +1474,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Log(Pow(param0, param1));
std::vector<float> expected(values0.size());
@@ -1504,14 +1499,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Mul(Exp(param0), Exp(param1));
std::vector<float> expected(values0.size());
@@ -1529,14 +1524,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
Div(param0, Exp(param1));
std::vector<float> expected(values0.size());
@@ -1555,20 +1550,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
+ Literal literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
- client_->TransferToServer(*literal2).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
- auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
+ client_->TransferToServer(literal2).ConsumeValueOrDie();
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
+ auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
Div(Div(param0, param1), param2);
std::vector<float> expected(values0.size());
@@ -1587,21 +1582,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
+ Literal literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
- client_->TransferToServer(*literal2).ConsumeValueOrDie();
+ client_->TransferToServer(literal2).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
- auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
+ auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
Div(param0, Div(param1, param2));
std::vector<float> expected(values0.size());
@@ -1620,21 +1615,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f};
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
+ Literal literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
- client_->TransferToServer(*literal2).ConsumeValueOrDie();
+ client_->TransferToServer(literal2).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
- auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
+ auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
Div(param0, Pow(param1, param2));
std::vector<float> expected(values0.size());
@@ -1654,26 +1649,26 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) {
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
std::vector<float> values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f};
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+ Literal literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+ Literal literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
+ Literal literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
- client_->TransferToServer(*literal2).ConsumeValueOrDie();
+ client_->TransferToServer(literal2).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal3 = LiteralUtil::CreateR1<float>(values3);
+ Literal literal3 = LiteralUtil::CreateR1<float>(values3);
std::unique_ptr<GlobalData> data3 =
- client_->TransferToServer(*literal3).ConsumeValueOrDie();
+ client_->TransferToServer(literal3).ConsumeValueOrDie();
- auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
- auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
- auto param3 = Parameter(&b, 3, literal3->shape(), "param2");
+ auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
+ auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
+ auto param3 = Parameter(&b, 3, literal3.shape(), "param3");
Div(Div(param0, param1), Div(param2, param3));
std::vector<float> expected(values0.size());
@@ -2100,18 +2095,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) {
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ Literal param1_literal =
LiteralUtil::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Add(p0, p1);
ComputeAndCompareR1<float>(&builder, {8.3f, 4.5f, 6.7f, 11.1f},
@@ -2122,18 +2117,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ Literal param1_literal =
LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Add(p0, p1);
Array3D<float> expected(0, 7, 0);
@@ -2144,13 +2139,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
auto a = ConstantR1<float>(&builder, {1.1f, 2.2f, 3.3f, 4.4f});
- auto p = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto p = Parameter(&builder, 0, param0_literal.shape(), "param0");
Add(a, p);
ComputeAndCompareR1<float>(&builder, {2.2f, 4.4f, 6.6f, 9.9f},
@@ -2210,9 +2205,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
0.08, -1.24, -0.92, 0.49, 1.17, -0.45, -1.31, -1.44, -0.13, -1.31,
-0.79, 1.41, 1.21, 1.05});
TF_ASSERT_OK_AND_ASSIGN(auto input_data,
- client_->TransferToServer(*input_literal));
+ client_->TransferToServer(input_literal));
- auto input = Parameter(&builder, 0, input_literal->shape(), "input");
+ auto input = Parameter(&builder, 0, input_literal.shape(), "input");
Tanh(input);
ComputeAndCompareR1<float>(
@@ -2243,7 +2238,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
// Just to help make sense of the scales here -- exp(89) saturates float32 and
// exp(-10) is smaller than our error spec.
- std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1<float>(
+ Literal input_literal = LiteralUtil::CreateR1<float>(
{1.02, -0.32, 0.85, 0.9, 1.23, -0.91, -0.49, 0.8, -1.31,
-1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05, -195.6, -194.5,
-193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5, -17.4,
@@ -2256,16 +2251,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
78.3, 79.4, 80.5, 81.6, 82.7, 83.8, 84.9, 85.2, 86.3,
86.4, 86.5, 87.6, 87.7, 87.8, 87.9});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
- client_->TransferToServer(*input_literal));
+ client_->TransferToServer(input_literal));
- auto input = Parameter(&builder, 0, input_literal->shape(), "input");
+ auto input = Parameter(&builder, 0, input_literal.shape(), "input");
Exp(input);
std::vector<float> expected_result;
- int64 input_size = input_literal->shape().dimensions(0);
+ int64 input_size = input_literal.shape().dimensions(0);
expected_result.reserve(input_size);
for (int64 i = 0; i < input_size; i++) {
- expected_result.push_back(std::exp(input_literal->Get<float>({i})));
+ expected_result.push_back(std::exp(input_literal.Get<float>({i})));
}
ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
@@ -2277,7 +2272,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
// implementation on XLA CPU.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1<float>(
+ Literal input_literal = LiteralUtil::CreateR1<float>(
{-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198,
-167, 1.29, 1.41, 1.25, 13.5, 11.7, 17.9,
198, 167, 1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04, 1.84e+04,
@@ -2294,16 +2289,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
1.7e+31, 1.44e+31, 1.1e+31, 1.4e+32, 1.67e+32, 1.96e+33, 1.11e+33,
1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
- client_->TransferToServer(*input_literal));
+ client_->TransferToServer(input_literal));
- auto input = Parameter(&builder, 0, input_literal->shape(), "input");
+ auto input = Parameter(&builder, 0, input_literal.shape(), "input");
Log(input);
std::vector<float> expected_result;
- int64 input_size = input_literal->shape().dimensions(0);
+ int64 input_size = input_literal.shape().dimensions(0);
expected_result.reserve(input_size);
for (int64 i = 0; i < input_size; i++) {
- expected_result.push_back(std::log(input_literal->Get<float>({i})));
+ expected_result.push_back(std::log(input_literal.Get<float>({i})));
}
ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
@@ -2469,10 +2464,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
auto cmp_dim_1 = Eq(v, m, /*broadcast_dimensions=*/{0});
Tuple(&builder, {cmp_dim_0, cmp_dim_1});
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}).get(),
- LiteralUtil::CreateR2<bool>({{true, false}, {false, false}}).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}),
+ LiteralUtil::CreateR2<bool>({{true, false}, {false, false}})});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
@@ -2825,10 +2820,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
std::iota(r1.begin(), r1.end(), 1.0);
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
- auto a = ConstantLiteral(&builder, *a_literal);
+ Literal a_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
+ auto a = ConstantLiteral(&builder, a_literal);
auto b = ConstantR1<float>(&builder, r1);
Add(a, b, {1});
@@ -2890,11 +2884,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) {
XlaBuilder builder(TestName());
auto x_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
auto y_literal = LiteralUtil::CreateR1<float>({4, 5});
- auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
- auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
+ auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
+ auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
- auto x = Parameter(&builder, 0, x_literal->shape(), "x");
- auto y = Parameter(&builder, 1, y_literal->shape(), "y");
+ auto x = Parameter(&builder, 0, x_literal.shape(), "x");
+ auto y = Parameter(&builder, 1, y_literal.shape(), "y");
auto slice = Slice(x, {1}, {2}, {1});
Sub(slice, y);
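Note: the recurring change in array_elementwise_ops_test.cc is a mechanical migration of Literal from std::unique_ptr<Literal> to a value type, so call sites drop the dereference and the arrow. A minimal sketch of the new idiom, assuming the ClientLibraryTestBase fixture (client_) used by these tests:

    XlaBuilder b("example");
    Literal literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
    std::unique_ptr<GlobalData> data =
        client_->TransferToServer(literal).ConsumeValueOrDie();  // takes const Literal& now
    auto param = Parameter(&b, 0, literal.shape(), "param");     // '.' instead of '->'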
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index 24b17b7100..bc2ba151a3 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
@@ -41,7 +42,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/math/math_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -63,7 +63,7 @@ class BatchNormalizationTest
{5.0f, 4.4f}, // p2
});
input_array_.FillWithPZ(pz);
- input_literal_ = std::move(*LiteralUtil::CreateR4FromArray4D(input_array_));
+ input_literal_ = LiteralUtil::CreateR4FromArray4D(input_array_);
CHECK_EQ(kSamples, input_array_.planes());
CHECK_EQ(kZ, input_array_.depth());
CHECK_EQ(kY, input_array_.height());
@@ -242,14 +242,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) {
BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
- {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
- .get(),
- LiteralUtil::CreateR1<float>({4, 5}).get(),
- LiteralUtil::CreateR1<float>({5, 5}).get()});
+ {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}),
+ LiteralUtil::CreateR1<float>({4, 5}),
+ LiteralUtil::CreateR1<float>({5, 5})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
@@ -267,14 +266,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
- {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
- .get(),
- LiteralUtil::CreateR1<float>({4, 5}).get(),
- LiteralUtil::CreateR1<float>({5, 5}).get()});
+ {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}),
+ LiteralUtil::CreateR1<float>({4, 5}),
+ LiteralUtil::CreateR1<float>({5, 5})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
@@ -298,13 +296,12 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
BatchNormTraining(h0, h1, h2,
/*epsilon=*/1, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
- .get(),
- LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
- LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f)),
+ LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)),
+ LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f))});
- ComputeAndCompareTuple(&builder, *expected,
+ ComputeAndCompareTuple(&builder, expected,
{operand.get(), scale.get(), offset.get()},
ErrorSpec(0.1));
}
@@ -331,14 +328,13 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) {
BatchNormTraining(h0, h1, h2,
/*epsilon=*/-100, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR3FromArray3D<float>(
- {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
- .get(),
- LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
- LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});
+ {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}),
+ LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)),
+ LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f))});
- ComputeAndCompareTuple(&builder, *expected,
+ ComputeAndCompareTuple(&builder, expected,
{operand.get(), scale.get(), offset.get()},
ErrorSpec(0.1));
}
@@ -363,14 +359,13 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) {
BatchNormGrad(operand, scale, mean, var, grad_output,
/*epsilon=*/0.0, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
- {{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
- .get(),
- LiteralUtil::CreateR1<float>({0, 0}).get(),
- LiteralUtil::CreateR1<float>({16, 20}).get()});
+ {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}),
+ LiteralUtil::CreateR1<float>({0, 0}),
+ LiteralUtil::CreateR1<float>({16, 20})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
struct BatchNormTestParam {
@@ -382,7 +377,7 @@ struct BatchNormTestParam {
friend ::std::ostream& operator<<(::std::ostream& os,
const BatchNormTestParam& p) {
- os << "bounds={" << tensorflow::str_util::Join(p.bounds, ", ") << "}, ";
+ os << "bounds={" << absl::StrJoin(p.bounds, ", ") << "}, ";
os << "feature_index=" << p.feature_index << ", ";
os << "random_value_mean=" << p.random_value_mean << ", ";
os << "random_value_var=" << p.random_value_var;
@@ -522,22 +517,22 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
auto input_activations =
- Parameter(&builder, 0, input_literal->shape(), "input");
+ Parameter(&builder, 0, input_literal.shape(), "input");
auto scale_activations =
- Parameter(&builder, 1, scale_literal->shape(), "offset");
+ Parameter(&builder, 1, scale_literal.shape(), "scale");
auto offset_activations =
- Parameter(&builder, 2, offset_literal->shape(), "scale");
+ Parameter(&builder, 2, offset_literal.shape(), "offset");
- auto expected = LiteralUtil::MakeTuple(
- {expected_normalized.get(), LiteralUtil::CreateR1<float>(mean).get(),
- LiteralUtil::CreateR1<float>(var).get()});
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {expected_normalized, LiteralUtil::CreateR1<float>(mean),
+ LiteralUtil::CreateR1<float>(var)});
std::unique_ptr<GlobalData> input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
- client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+ client_->TransferToServer(scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> offset_data =
- client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
+ client_->TransferToServer(offset_literal).ConsumeValueOrDie();
BatchNormTraining(input_activations, scale_activations, offset_activations,
epsilon, feature_index);
@@ -547,7 +542,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
// testcase.
execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
ComputeAndCompareTuple(
- &builder, *expected,
+ &builder, expected,
{input_data.get(), scale_data.get(), offset_data.get()},
ErrorSpec(0.01, 1));
}
@@ -622,27 +617,27 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) {
auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
auto input_activations =
- Parameter(&builder, 0, input_literal->shape(), "input");
+ Parameter(&builder, 0, input_literal.shape(), "input");
auto scale_activations =
- Parameter(&builder, 1, scale_literal->shape(), "offset");
+ Parameter(&builder, 1, scale_literal.shape(), "scale");
auto offset_activations =
- Parameter(&builder, 2, offset_literal->shape(), "scale");
- auto mean_activations = Parameter(&builder, 3, mean_literal->shape(), "mean");
+ Parameter(&builder, 2, offset_literal.shape(), "offset");
+ auto mean_activations = Parameter(&builder, 3, mean_literal.shape(), "mean");
auto variance_activations =
- Parameter(&builder, 4, var_literal->shape(), "variance");
+ Parameter(&builder, 4, var_literal.shape(), "variance");
Array4D<float> expected = normalized;
std::unique_ptr<GlobalData> input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
- client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+ client_->TransferToServer(scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> offset_data =
- client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
+ client_->TransferToServer(offset_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> mean_data =
- client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
+ client_->TransferToServer(mean_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> variance_data =
- client_->TransferToServer(*var_literal).ConsumeValueOrDie();
+ client_->TransferToServer(var_literal).ConsumeValueOrDie();
BatchNormInference(input_activations, scale_activations, offset_activations,
mean_activations, variance_activations, epsilon,
@@ -811,40 +806,37 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
auto grad_output_literal =
LiteralUtil::CreateR4FromArray4D<float>(grad_output_array);
- auto input_parameter =
- Parameter(&builder, 0, input_literal->shape(), "input");
- auto scale_parameter =
- Parameter(&builder, 1, scale_literal->shape(), "scale");
- auto mean_parameter = Parameter(&builder, 2, mean_literal->shape(), "mean");
- auto var_parameter = Parameter(&builder, 3, var_literal->shape(), "variance");
+ auto input_parameter = Parameter(&builder, 0, input_literal.shape(), "input");
+ auto scale_parameter = Parameter(&builder, 1, scale_literal.shape(), "scale");
+ auto mean_parameter = Parameter(&builder, 2, mean_literal.shape(), "mean");
+ auto var_parameter = Parameter(&builder, 3, var_literal.shape(), "variance");
auto grad_output_parameter =
- Parameter(&builder, 4, grad_output_literal->shape(), "grad_output");
+ Parameter(&builder, 4, grad_output_literal.shape(), "grad_output");
std::unique_ptr<GlobalData> input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
- client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+ client_->TransferToServer(scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> mean_data =
- client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
+ client_->TransferToServer(mean_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> var_data =
- client_->TransferToServer(*var_literal).ConsumeValueOrDie();
+ client_->TransferToServer(var_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> grad_output_data =
- client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie();
+ client_->TransferToServer(grad_output_literal).ConsumeValueOrDie();
BatchNormGrad(input_parameter, scale_parameter, mean_parameter, var_parameter,
grad_output_parameter, epsilon, feature_index);
- auto expected =
- LiteralUtil::MakeTuple({expected_grad_activation.get(),
- LiteralUtil::CreateR1<float>(grad_scale).get(),
- LiteralUtil::CreateR1<float>(grad_offset).get()});
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {expected_grad_activation, LiteralUtil::CreateR1<float>(grad_scale),
+ LiteralUtil::CreateR1<float>(grad_offset)});
// Run all HLO passes during this test. In particular, ClientLibraryTestBase
// disables constant folding, but we want it enabled for our zero-sized tensor
// testcase.
execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
- ComputeAndCompareTuple(&builder, *expected,
+ ComputeAndCompareTuple(&builder, expected,
{input_data.get(), scale_data.get(), mean_data.get(),
var_data.get(), grad_output_data.get()},
ErrorSpec(0.01, 1));
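Note: the tuple-building changes in batch_normalization_test.cc follow the same value-semantics migration: LiteralUtil::MakeTuple over a vector of .get() pointers becomes LiteralUtil::MakeTupleFromSlices over the Literal values themselves. A short sketch, assuming the span-of-Literal signature implied by this diff:

    auto expected = LiteralUtil::MakeTupleFromSlices(
        {LiteralUtil::CreateR1<float>({4, 5}),
         LiteralUtil::CreateR1<float>({5, 5})});
    ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));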
diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc
index 6c20f654fe..e9728e636f 100644
--- a/tensorflow/compiler/xla/tests/bfloat16_test.cc
+++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc
@@ -65,7 +65,7 @@ XLA_TEST_F(Bfloat16Test, LogOperation) {
Log(x);
ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(1.387f), {},
- error_spec_);
+ ErrorSpec(0.01, 0.01));
}
XLA_TEST_F(Bfloat16Test, NegateScalarF16) {
@@ -95,22 +95,19 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-1.6875f)},
{static_cast<bfloat16>(-2.04f)}},
{{static_cast<bfloat16>(0.105f)}, {static_cast<bfloat16>(0.66f)}}},
{{{static_cast<bfloat16>(1.89f)}, {static_cast<bfloat16>(3.35f)}},
- {{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}})
- .get(),
+ {{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}}),
LiteralUtil::CreateR1<bfloat16>(
- {static_cast<bfloat16>(4), static_cast<bfloat16>(5)})
- .get(),
+ {static_cast<bfloat16>(4), static_cast<bfloat16>(5)}),
LiteralUtil::CreateR1<bfloat16>(
- {static_cast<bfloat16>(5), static_cast<bfloat16>(5)})
- .get()});
+ {static_cast<bfloat16>(5), static_cast<bfloat16>(5)})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02));
}
XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
@@ -139,21 +136,18 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
BatchNormGrad(operand, scale, mean, var, grad_output,
/*epsilon=*/0.0, kFeatureIndex);
- auto expected = LiteralUtil::MakeTuple(
+ auto expected = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-3.f)}, {static_cast<bfloat16>(-3.f)}},
{{static_cast<bfloat16>(-1.f)}, {static_cast<bfloat16>(-1.f)}}},
{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(1.f)}},
- {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}})
- .get(),
+ {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}}),
LiteralUtil::CreateR1<bfloat16>(
- {static_cast<bfloat16>(0), static_cast<bfloat16>(0)})
- .get(),
+ {static_cast<bfloat16>(0), static_cast<bfloat16>(0)}),
LiteralUtil::CreateR1<bfloat16>(
- {static_cast<bfloat16>(16), static_cast<bfloat16>(20)})
- .get()});
+ {static_cast<bfloat16>(16), static_cast<bfloat16>(20)})});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01));
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01));
}
} // namespace
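Note: besides the tuple changes, the bfloat16 tests widen their comparison tolerances. ErrorSpec carries an absolute bound and a relative bound, and bfloat16 keeps only about two to three significant decimal digits, so the relative term is loosened here; the exact accept/reject rule is defined by the literal comparison code, not sketched below:

    ErrorSpec spec(0.01, 0.02);  // absolute bound 0.01, relative bound 0.02
    ComputeAndCompareTuple(&builder, expected, {}, spec);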
diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
index 1d28e85b16..dde19fb65d 100644
--- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
@@ -53,29 +53,31 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
}
}
- std::unique_ptr<GlobalData> MakeR3Data(
- tensorflow::gtl::ArraySlice<int64> bounds,
- tensorflow::gtl::ArraySlice<int64> minor_to_major, Shape* r3_shape,
- Array3D<float>* r3_array, float start, float end, int seed) {
+ std::unique_ptr<GlobalData> MakeR3Data(absl::Span<const int64> bounds,
+ absl::Span<const int64> minor_to_major,
+ Shape* r3_shape,
+ Array3D<float>* r3_array, float start,
+ float end, int seed) {
*r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
r3_array->FillRandom(start, end, seed);
- auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout(
+ auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array).Relayout(
LayoutUtil::MakeLayout(minor_to_major));
std::unique_ptr<GlobalData> r3_global_data =
- client_->TransferToServer(*r3_data).ConsumeValueOrDie();
+ client_->TransferToServer(r3_data).ConsumeValueOrDie();
return r3_global_data;
}
- std::unique_ptr<GlobalData> MakeR2Data(
- tensorflow::gtl::ArraySlice<int64> bounds,
- tensorflow::gtl::ArraySlice<int64> minor_to_major, Shape* r2_shape,
- Array2D<float>* r2_array, float start, float end, int seed) {
+ std::unique_ptr<GlobalData> MakeR2Data(absl::Span<const int64> bounds,
+ absl::Span<const int64> minor_to_major,
+ Shape* r2_shape,
+ Array2D<float>* r2_array, float start,
+ float end, int seed) {
*r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
r2_array->FillRandom(start, end, seed);
- auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout(
+ auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array).Relayout(
LayoutUtil::MakeLayout(minor_to_major));
std::unique_ptr<GlobalData> r2_global_data =
- client_->TransferToServer(*r2_data).ConsumeValueOrDie();
+ client_->TransferToServer(r2_data).ConsumeValueOrDie();
return r2_global_data;
}
@@ -291,7 +293,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
XlaBuilder b(TestName());
Add(ConstantR2<float>(&b, {{1.0, 5.0}}),
- ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
+ ConstantLiteral(&b, LiteralUtil::CreateR3<float>(
{{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
/*broadcast_dimensions=*/{1, 2});
@@ -299,7 +301,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
LiteralUtil::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
{{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
struct R3ImplicitBroadcastSpec {
@@ -348,7 +350,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
Array3D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1],
spec.output_bounds[2]);
- auto Each = ([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
+ auto Each = ([&](absl::Span<const int64> indices, float* value) {
float r3_implicit = r3_implicit_array(indices[0] % spec.input_bounds[0],
indices[1] % spec.input_bounds[1],
indices[2] % spec.input_bounds[2]);
@@ -368,8 +370,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
}
auto expected = LiteralUtil::CreateR3FromArray3D(expected_array);
ComputeAndCompareLiteral(
- &builder, *expected,
- {r3_implicit_global_data.get(), r3_global_data.get()},
+ &builder, expected, {r3_implicit_global_data.get(), r3_global_data.get()},
ErrorSpec(1e-7, 1e-7));
}
@@ -393,89 +394,89 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}});
- ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()},
+ ComputeAndCompareLiteral(&b, expected, {r3.get(), r1.get()},
ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}, {2}}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1}, {2}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
XlaBuilder b(TestName());
auto r1 =
- ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
+ ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) {
XlaBuilder b(TestName());
auto r1 =
- ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
+ ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) {
XlaBuilder b(TestName());
auto r1 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
struct R2ImplicitBroadcastSpec {
@@ -616,7 +617,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) {
auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
ComputeAndCompareLiteral(
- &builder, *expected,
+ &builder, expected,
{r2_implicit_global_data1.get(), r2_global_data.get(),
r2_implicit_global_data2.get()},
ErrorSpec(1e-6, 1e-6));
@@ -628,65 +629,63 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances,
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}}));
- auto r2 =
- ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}}));
+ auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
Add(r2, r1);
auto expected = LiteralUtil::CreateR2<float>({{2, 4}, {4, 6}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1}, {2}}));
- auto r2 =
- ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1}, {2}}));
+ auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
Add(r2, r1);
auto expected = LiteralUtil::CreateR2<float>({{2, 3}, {5, 6}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1, {0});
auto expected = LiteralUtil::CreateR3<float>(
{{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r1, r3, {1});
auto expected = LiteralUtil::CreateR3<float>(
{{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r1, r3, {2});
auto expected = LiteralUtil::CreateR3<float>(
{{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
@@ -695,7 +694,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
auto r1_1 = ConstantR1<float>(&b, {100, 200});
auto r1_2 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
for (int i = 0; i < 3; ++i) {
r3 = Add(r1_0, r3, {0});
r3 = Add(r3, r1_1, {1});
@@ -707,7 +706,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
{{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}},
{{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
@@ -728,7 +727,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
{{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}},
{{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
@@ -737,7 +736,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
XlaBuilder b(TestName());
Add(ConstantR2<float>(&b, {{1.0, 5.0}, {1.0, 5.0}}),
- ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
+ ConstantLiteral(&b, LiteralUtil::CreateR3<float>(
{{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
/*broadcast_dimensions=*/{1, 2});
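Note: the helper-signature rewrites in broadcast_simple_test.cc are part of the broader tensorflow::gtl::ArraySlice -> absl::Span migration; absl::Span<const int64> is a non-owning view that binds to braced lists, std::vector, and plain arrays alike. An illustrative, hypothetical helper in the same style:

    #include "absl/types/span.h"

    int64 NumElements(absl::Span<const int64> bounds) {
      int64 n = 1;
      for (int64 b : bounds) n *= b;  // product of the dimension bounds
      return n;
    }

    // Accepts either form without copying:
    //   NumElements({2, 3, 4});
    //   NumElements(bounds_vector);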
diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc
index c7b94b5bba..9966e4606e 100644
--- a/tensorflow/compiler/xla/tests/broadcast_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -46,8 +46,8 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR0<float>(42.0),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR0<float>(42.0), result,
+ error_spec_));
}
XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
@@ -63,7 +63,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result,
+ LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), result,
error_spec_));
}
@@ -86,12 +86,12 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
- LiteralSlice(*result, {0}), error_spec_));
+ LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
+ LiteralSlice(result, {0}), error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
- LiteralSlice(*result, {1}), error_spec_));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
+ LiteralSlice(result, {1}), error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
@@ -107,7 +107,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), *result,
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), result,
error_spec_));
}
@@ -126,7 +126,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), *result,
+ LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), result,
error_spec_));
}
@@ -143,9 +143,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
- {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
- *result, error_spec_));
+ LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
+ {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
+ result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
@@ -166,9 +166,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
Array2D<float> pz({{1, 2}, {1, 2}});
expected.FillWithPZ(pz);
- EXPECT_TRUE(
- LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
@@ -197,9 +196,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
}
expected.FillWithYX(yx);
- EXPECT_TRUE(
- LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
@@ -220,8 +218,8 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(r4_array),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(r4_array),
+ result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
@@ -240,9 +238,8 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
Array4D<float> expected(64, 64, 3, 3);
expected.Fill(1.0f);
- EXPECT_TRUE(
- LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
@@ -263,9 +260,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
Array4D<float> expected(3, 3, 2, 2);
expected.FillWithYX(to_broadcast);
- EXPECT_TRUE(
- LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
@@ -295,9 +291,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- EXPECT_TRUE(
- LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc
index b1d18210ea..8b31e53707 100644
--- a/tensorflow/compiler/xla/tests/call_test.cc
+++ b/tensorflow/compiler/xla/tests/call_test.cc
@@ -77,8 +77,7 @@ class CallOpTest : public ClientLibraryTestBase {
XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR0F32IdentityComputation();
- auto constant =
- ConstantLiteral(&builder, *LiteralUtil::CreateR0<float>(42.0));
+ auto constant = ConstantLiteral(&builder, LiteralUtil::CreateR0<float>(42.0));
Call(&builder, callee, {constant});
ComputeAndCompareR0<float>(&builder, 42.0, {}, ErrorSpec(0.01f));
@@ -87,8 +86,8 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) {
XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR1S0F32AdditionComputation();
- auto x = ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({}));
- auto y = ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({}));
+ auto x = ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({}));
+ auto y = ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({}));
Call(&builder, callee, {x, y});
ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.01f));
@@ -98,9 +97,9 @@ XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR1S2F32AdditionComputation();
auto x =
- ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
+ ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
auto y =
- ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({2.0f, 3.0f}));
+ ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({2.0f, 3.0f}));
Call(&builder, callee, {x, y});
ComputeAndCompareR1<float>(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f));
@@ -133,7 +132,7 @@ XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> start,
- client_->TransferToServer(*LiteralUtil::CreateR0<float>(1.0f)));
+ client_->TransferToServer(LiteralUtil::CreateR0<float>(1.0f)));
ComputeAndCompareR0<float>(&builder3, 10.0f, {start.get()}, ErrorSpec(0.0f));
}
@@ -141,10 +140,10 @@ XLA_TEST_F(CallOpTest, CallR0F32Tuple) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR0F32TupleComputation();
auto elem = LiteralUtil::CreateR0<float>(42.0);
- auto tuple = LiteralUtil::MakeTuple({elem.get()});
- Call(&builder, callee, {ConstantLiteral(&builder, *elem)});
+ auto tuple = LiteralUtil::MakeTuple({&elem});
+ Call(&builder, callee, {ConstantLiteral(&builder, elem)});
- ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f));
+ ComputeAndCompareTuple(&builder, tuple, {}, ErrorSpec(0.01f));
}
} // namespace
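Note: where a tuple still needs to reference existing literals without copying, the new code takes the address of the Literal value instead of calling .get() on a unique_ptr, as in CallR0F32Tuple above. The same pattern in isolation, assuming the pointer-span overload of MakeTuple:

    auto elem = LiteralUtil::CreateR0<float>(42.0);  // elem is a Literal value
    auto tuple = LiteralUtil::MakeTuple({&elem});    // span of const Literal*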
diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
index a4eb57fc7b..2f1510ff69 100644
--- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
+++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
@@ -38,14 +38,14 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) {
XlaBuilder builder("add_two_params");
auto param_literal = LiteralUtil::CreateR1<float>({1.1f, 2.2f});
- auto p0 = Parameter(&builder, 0, param_literal->shape(), "param0");
- auto p1 = Parameter(&builder, 1, param_literal->shape(), "param1");
+ auto p0 = Parameter(&builder, 0, param_literal.shape(), "param0");
+ auto p1 = Parameter(&builder, 1, param_literal.shape(), "param1");
Add(p0, p1);
auto param0_data =
- client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param_literal).ConsumeValueOrDie();
auto param1_data =
- client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param_literal).ConsumeValueOrDie();
auto computation_status = builder.Build();
ASSERT_IS_OK(computation_status.status());
@@ -86,12 +86,12 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) {
auto computation = computation_status.ConsumeValueOrDie();
auto f32_literal = LiteralUtil::CreateR0<float>(1.1f);
- auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie();
+ auto f32_data = client_->TransferToServer(f32_literal).ConsumeValueOrDie();
auto f32_4_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
auto f32_4_data =
- client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie();
+ client_->TransferToServer(f32_4_literal).ConsumeValueOrDie();
auto u8_4_literal = LiteralUtil::CreateR1U8("hola");
- auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie();
+ auto u8_4_data = client_->TransferToServer(u8_4_literal).ConsumeValueOrDie();
// Match
auto status = client_->Execute(
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index 59d917054b..fbdf0fcb65 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -17,18 +17,18 @@ limitations under the License.
#include <string>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -95,15 +95,14 @@ string ClientLibraryTestBase::TestName() const {
}
StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ XlaBuilder* builder, absl::Span<GlobalData* const> arguments) {
// Build the computation, as a convenience.
TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
return client_->Execute(computation, arguments, &execution_options_);
}
-StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
+ const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
ExecutionOptions execution_options = execution_options_;
if (shape_with_output_layout != nullptr) {
@@ -114,18 +113,16 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
&execution_options);
}
-StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
+ XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
// Build the computation, as a convenience.
TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
return ExecuteAndTransfer(computation, arguments, shape_with_output_layout);
}
-StatusOr<std::unique_ptr<Literal>>
-ClientLibraryTestBase::ExecuteAndTransferReference(
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransferReference(
+ const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout) {
ExecutionOptions execution_options = execution_options_;
if (shape_with_output_layout != nullptr) {
@@ -138,7 +135,7 @@ ClientLibraryTestBase::ExecuteAndTransferReference(
}
string ClientLibraryTestBase::ExecuteToString(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ XlaBuilder* builder, absl::Span<GlobalData* const> arguments) {
auto computation_status = builder->Build();
if (!computation_status.ok()) {
return computation_status.status().ToString();
@@ -150,29 +147,28 @@ string ClientLibraryTestBase::ExecuteToString(
if (!result.ok()) {
return result.status().ToString();
} else {
- return result.ValueOrDie()->ToString();
+ return result.ValueOrDie().ToString();
}
}
void ClientLibraryTestBase::ComputeAndCompareR1(
XlaBuilder* builder, const tensorflow::core::Bitmap& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
- std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ absl::Span<GlobalData* const> arguments) {
+ Literal expected_literal = LiteralUtil::CreateR1(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
void ClientLibraryTestBase::ComputeAndCompareLiteral(
XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
- const Shape* shape_with_layout) {
+ absl::Span<GlobalData* const> arguments, const Shape* shape_with_layout) {
EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
shape_with_layout));
}
void ClientLibraryTestBase::ComputeAndCompareLiteral(
XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
+ absl::Span<GlobalData* const> arguments, ErrorSpec error,
const Shape* shape_with_layout) {
EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
error, shape_with_layout));
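Note: the signature migration in client_library_test_base.cc also removes a level of indirection from the return type: ExecuteAndTransfer now yields StatusOr<Literal> rather than StatusOr<std::unique_ptr<Literal>>. A hedged sketch of a caller under the post-change signatures:

    StatusOr<Literal> result = ExecuteAndTransfer(&builder, arguments);
    if (!result.ok()) {
      return result.status().ToString();
    }
    return result.ValueOrDie().ToString();  // Literal by value, no '->'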
@@ -180,12 +176,12 @@ void ClientLibraryTestBase::ComputeAndCompareLiteral(
Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
const xla::XlaComputation& computation, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
const std::function<void(const Literal& actual,
const string& error_message)>& verify_output) {
// Try with no layout requirement.
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments));
- verify_output(*actual, "");
+ verify_output(actual, "");
// Try with all output layouts.
std::vector<int64> minor_to_major(ShapeUtil::Rank(expected.shape()));
@@ -196,8 +192,8 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
AsInt64Slice(expected.shape().dimensions()), minor_to_major);
TF_ASSIGN_OR_RETURN(auto actual,
ExecuteAndTransfer(computation, arguments, &layout));
- verify_output(*actual, tensorflow::strings::StrCat(
- "Test with output layout: ",
+ verify_output(actual,
+ absl::StrCat("Test with output layout: ",
ShapeUtil::HumanStringWithLayout(layout)));
} while (std::next_permutation(minor_to_major.begin(), minor_to_major.end()));
return Status::OK();
@@ -205,7 +201,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
const xla::XlaComputation& computation, const Literal& /*expected*/,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
const std::function<void(const Literal& actual,
const string& error_message)>& verify_output,
const Shape* output_with_layout) {
@@ -221,9 +217,9 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
TF_ASSIGN_OR_RETURN(auto literal,
client_->Transfer(*arguments[index], nullptr));
// Skip tuples because they don't have a rank.
- if (ShapeUtil::IsTuple(literal->shape())) {
+ if (ShapeUtil::IsTuple(literal.shape())) {
layout_strings.push_back(
- ShapeUtil::HumanStringWithLayout(literal->shape()));
+ ShapeUtil::HumanStringWithLayout(literal.shape()));
arguments_with_layout.push_back(arguments[index]);
TF_RETURN_IF_ERROR(choose(index + 1));
arguments_with_layout.pop_back();
@@ -231,15 +227,15 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
return Status::OK();
}
- std::vector<int64> minor_to_major(ShapeUtil::Rank(literal->shape()));
+ std::vector<int64> minor_to_major(ShapeUtil::Rank(literal.shape()));
std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
do {
auto literal_relayout =
- literal->Relayout(LayoutUtil::MakeLayout(minor_to_major));
+ literal.Relayout(LayoutUtil::MakeLayout(minor_to_major));
layout_strings.push_back(
- ShapeUtil::HumanStringWithLayout(literal_relayout->shape()));
+ ShapeUtil::HumanStringWithLayout(literal_relayout.shape()));
TF_ASSIGN_OR_RETURN(auto data,
- client_->TransferToServer(*literal_relayout));
+ client_->TransferToServer(literal_relayout));
arguments_with_layout.push_back(data.get());
TF_RETURN_IF_ERROR(choose(index + 1));
arguments_with_layout.pop_back();
@@ -252,15 +248,14 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
// Every argument has an assigned layout.
TF_ASSIGN_OR_RETURN(
auto actual,
- ExecuteAndTransfer(
- computation,
- tensorflow::gtl::ArraySlice<GlobalData*>(arguments_with_layout),
- output_with_layout));
+ ExecuteAndTransfer(computation,
+ absl::Span<GlobalData* const>(arguments_with_layout),
+ output_with_layout));
string error_message = "Test with input layouts: ";
for (const auto& str : layout_strings) {
- tensorflow::strings::StrAppend(&error_message, str, " ");
+ absl::StrAppend(&error_message, str, " ");
}
- verify_output(*actual, error_message);
+ verify_output(actual, error_message);
return Status::OK();
};
@@ -269,7 +264,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments_passed_in,
+ absl::Span<GlobalData* const> arguments_passed_in,
const Shape* shape_with_layout) {
std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
arguments_passed_in.end());
@@ -290,19 +285,15 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
if (ShapeUtil::ElementIsFloating(expected.shape()) ||
ShapeUtil::ElementIsComplex(expected.shape())) {
LOG(WARNING) << "performing exact comparison of floating point numbers";
- } else {
- TF_RET_CHECK(ShapeUtil::ElementIsIntegral(expected.shape()) ||
- expected.shape().element_type() == PRED)
- << ShapeUtil::HumanString(expected.shape());
}
// We allow using a float expected literal for a bfloat16 output. In this
// case, we need to convert the expected literal to bfloat16.
const Literal* expected_ptr = &expected;
- std::unique_ptr<Literal> converted_expected;
+ Literal converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
- expected_ptr = converted_expected.get();
+ expected_ptr = &converted_expected;
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
ShapeUtil::ForEachMutableSubshape(
@@ -327,14 +318,14 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
shape_with_layout));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual));
return Status::OK();
}
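// A minimal sketch of the bfloat16 path above, using only calls shown in this
// patch (ConvertF32ToBF16 now returns a Literal by value rather than a
// unique_ptr): when use_bfloat16_ is set, the F32 expected literal is
// converted so it compares against the BF16-typed actual output.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"

namespace xla {
void Bf16ExpectedSketch() {
  Literal expected = LiteralUtil::CreateR0<float>(1.5f);
  Literal converted = LiteralUtil::ConvertF32ToBF16(expected);
  const Literal* expected_ptr = &converted;  // comparisons go through this
  (void)expected_ptr;
}
}  // namespace xla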
Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments_passed_in,
- ErrorSpec error, const Shape* shape_with_layout) {
+ absl::Span<GlobalData* const> arguments_passed_in, ErrorSpec error,
+ const Shape* shape_with_layout) {
std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
arguments_passed_in.end());
@@ -350,17 +341,15 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
}
- TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()) ||
- ShapeUtil::ElementIsComplex(expected.shape()));
TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
// We allow using a float expected literal for a bfloat16 output. In this
// case, we need to convert the expected literal to bfloat16.
const Literal* expected_ptr = &expected;
- std::unique_ptr<Literal> converted_expected;
+ Literal converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
- expected_ptr = converted_expected.get();
+ expected_ptr = &converted_expected;
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
ShapeUtil::ForEachMutableSubshape(
@@ -386,13 +375,13 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
shape_with_layout));
- EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error));
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error));
return Status::OK();
}
void ClientLibraryTestBase::ComputeAndCompareR1U8(
- XlaBuilder* builder, tensorflow::StringPiece expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ XlaBuilder* builder, absl::string_view expected,
+ absl::Span<GlobalData* const> arguments) {
auto actual_status = ExecuteAndTransfer(builder, arguments);
EXPECT_IS_OK(actual_status.status());
if (!actual_status.ok()) {
@@ -401,66 +390,65 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
auto actual = actual_status.ConsumeValueOrDie();
// Turn the expected value into a literal.
- std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1U8(expected);
+ Literal expected_literal = LiteralUtil::CreateR1U8(expected);
- VLOG(1) << "expected: " << expected_literal->ToString();
- VLOG(1) << "actual: " << actual->ToString();
+ VLOG(1) << "expected: " << expected_literal.ToString();
+ VLOG(1) << "actual: " << actual.ToString();
- EXPECT_EQ(expected, actual->GetR1U8AsString());
+ EXPECT_EQ(expected, actual.GetR1U8AsString());
}
void ClientLibraryTestBase::ComputeAndCompareTuple(
XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ absl::Span<GlobalData* const> arguments) {
auto actual_status = ExecuteAndTransfer(builder, arguments);
EXPECT_IS_OK(actual_status.status());
if (!actual_status.ok()) {
return;
}
auto actual = actual_status.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
void ClientLibraryTestBase::ComputeAndCompareTuple(
XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ absl::Span<GlobalData* const> arguments, ErrorSpec error) {
auto actual_status = ExecuteAndTransfer(builder, arguments);
EXPECT_IS_OK(actual_status.status());
if (!actual_status.ok()) {
return;
}
auto actual = actual_status.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error));
}
void ClientLibraryTestBase::ComputeAndCompare(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<Literal> arguments) {
+ XlaBuilder* builder, absl::Span<const Literal> arguments) {
auto status_or_data = ComputeValueAndReference(builder, arguments);
EXPECT_IS_OK(status_or_data);
if (!status_or_data.ok()) {
return;
}
- std::unique_ptr<Literal> reference, result;
+ Literal reference, result;
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(reference, result));
}
void ClientLibraryTestBase::ComputeAndCompare(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<Literal> arguments,
- ErrorSpec error) {
+ XlaBuilder* builder, absl::Span<const Literal> arguments, ErrorSpec error) {
auto status_or_data = ComputeValueAndReference(builder, arguments);
EXPECT_IS_OK(status_or_data);
if (!status_or_data.ok()) {
return;
}
- std::unique_ptr<Literal> reference, result;
+ Literal reference, result;
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error));
+ EXPECT_TRUE(LiteralTestUtil::Near(reference, result, error));
}
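// A minimal standalone sketch of the std::tie pattern above: Literal is now a
// move-only value type, and assigning a pair rvalue through std::tie
// move-assigns each element into the existing locals (unique_ptr stands in
// for Literal here).
#include <memory>
#include <tuple>
#include <utility>

std::pair<std::unique_ptr<int>, std::unique_ptr<int>> MakePair() {
  return {std::make_unique<int>(1), std::make_unique<int>(2)};
}

void TieSketch() {
  std::unique_ptr<int> reference, result;
  std::tie(reference, result) = MakePair();  // moves each pair element
}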
-StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
+StatusOr<std::pair<Literal, Literal>>
ClientLibraryTestBase::ComputeValueAndReference(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<Literal> arguments) {
+ XlaBuilder* builder, absl::Span<const Literal> arguments) {
  // Transfer the arguments to the executor service. We put the unique_ptrs
// into a vector to keep the data alive on the service until the end of this
// function.
@@ -546,7 +534,7 @@ XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() {
std::unique_ptr<Array2D<float>> ClientLibraryTestBase::CreatePatternedMatrix(
int rows, int cols, float offset) {
- auto array = MakeUnique<Array2D<float>>(rows, cols);
+ auto array = absl::make_unique<Array2D<float>>(rows, cols);
for (int64 row = 0; row < rows; ++row) {
for (int64 col = 0; col < cols; ++col) {
(*array)(row, col) = col + (row * 1000.0f) + offset;
@@ -561,7 +549,7 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols,
int cols_padded) {
CHECK_GE(rows_padded, rows);
CHECK_GE(cols_padded, cols);
- auto array = MakeUnique<Array2D<float>>(rows_padded, cols_padded, 0.0);
+ auto array = absl::make_unique<Array2D<float>>(rows_padded, cols_padded, 0.0);
for (int64 row = 0; row < rows; ++row) {
for (int64 col = 0; col < cols; ++col) {
(*array)(row, col) = col + (row * 1000.0f);
@@ -580,8 +568,8 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
XlaBuilder* builder) {
return ConstantLiteral(builder, use_bfloat16_
- ? *LiteralUtil::ConvertF32ToBF16(literal)
- : literal);
+ ? LiteralUtil::ConvertF32ToBF16(literal)
+ : LiteralSlice(literal));
}
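// Why the explicit LiteralSlice(literal) cast above: the two arms of ?: must
// share a common type, and Literal is move-only in this patch, so the common
// type cannot be a Literal copy. Both arms instead convert to the cheap
// LiteralSlice view; the temporary from ConvertF32ToBF16 lives until the end
// of the full expression, so the immediate use is safe. A standalone analogue
// with standard types (MakeTemp/UseView are hypothetical):
#include <string>
#include <string_view>

std::string MakeTemp() { return "converted"; }
void UseView(std::string_view v) { (void)v; }

void TernarySketch(bool flag, const std::string& s) {
  // Wrapping both arms in the view type gives the ternary a single cheap
  // common type; the string temporary outlives the UseView call.
  UseView(flag ? std::string_view(MakeTemp()) : std::string_view(s));
}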
std::unique_ptr<GlobalData>
@@ -611,7 +599,7 @@ Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) {
Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16(
const Literal& literal) {
if (use_bfloat16_) {
- return std::move(*LiteralUtil::ConvertF32ToBF16(literal));
+ return LiteralUtil::ConvertF32ToBF16(literal);
}
return literal.Clone();
}
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index b04a3b105c..9d32f4f517 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -21,6 +21,9 @@ limitations under the License.
#include <type_traits>
#include <vector>
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -30,14 +33,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -49,8 +49,8 @@ namespace xla {
// use_bfloat16_params with that value. Returns the result.
template <typename TestCase>
std::vector<TestCase> ExpandUseBfloat16(
- tensorflow::gtl::ArraySlice<bool> use_bfloat16_params,
- tensorflow::gtl::ArraySlice<TestCase> specs) {
+ absl::Span<const bool> use_bfloat16_params,
+ absl::Span<const TestCase> specs) {
std::vector<TestCase> expanded;
for (bool use_bfloat16 : use_bfloat16_params) {
for (const auto& spec : specs) {
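// A hypothetical usage sketch of ExpandUseBfloat16, assuming a TestCase type
// with the use_bfloat16 field that the function fills in (DemoSpec and
// ExpandDemo are invented names):
struct DemoSpec {
  bool use_bfloat16;
  int size;
};

std::vector<DemoSpec> ExpandDemo() {
  // Yields 4 specs: {false,2}, {false,4}, {true,2}, {true,4}, because the
  // outer loop above iterates use_bfloat16_params first.
  return ExpandUseBfloat16<DemoSpec>({false, true},
                                     {DemoSpec{false, 2}, DemoSpec{false, 4}});
}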
@@ -93,29 +93,29 @@ class ClientLibraryTestBase : public ::testing::Test {
// execution options. Modify execution_options_ in your test if you want to
// customize the options.
StatusOr<std::unique_ptr<GlobalData>> Execute(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ XlaBuilder* builder, absl::Span<GlobalData* const> arguments);
- StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ StatusOr<Literal> ExecuteAndTransfer(
+ XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout = nullptr);
- StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
+ StatusOr<Literal> ExecuteAndTransfer(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout = nullptr);
  // This executes the computation via the reference client (which connects an
  // interpreter backend). The result is used as the expected value of the
  // computation.
- StatusOr<std::unique_ptr<Literal>> ExecuteAndTransferReference(
+ StatusOr<Literal> ExecuteAndTransferReference(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
const Shape* shape_with_output_layout = nullptr);
// Run a computation and return its value as a string. If an error
// occurs, then instead return the error as a string.
string ExecuteToString(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ absl::Span<GlobalData* const> arguments);
// Convenience methods for building and running a computation, transferring
// the result, and comparing it to the expected value(s). Methods are
@@ -125,102 +125,98 @@ class ClientLibraryTestBase : public ::testing::Test {
// for integral types without the ErrorSpec parameter.
template <typename NativeT>
void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ absl::Span<GlobalData* const> arguments);
template <typename NativeT>
void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
ErrorSpec error);
template <typename NativeT>
void ComputeAndCompareR1(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<NativeT> expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ absl::Span<const NativeT> expected,
+ absl::Span<GlobalData* const> arguments);
template <typename NativeT>
void ComputeAndCompareR1(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<NativeT> expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<const NativeT> expected,
+ absl::Span<GlobalData* const> arguments,
ErrorSpec error);
// As above, but uses a bitmap to hold the predicate vector to avoid
// deficiencies of vector<bool>.
void ComputeAndCompareR1(XlaBuilder* builder,
const tensorflow::core::Bitmap& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ absl::Span<GlobalData* const> arguments);
template <typename NativeT>
void ComputeAndCompareR2(XlaBuilder* builder,
const Array2D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ absl::Span<GlobalData* const> arguments);
template <typename NativeT>
void ComputeAndCompareR2(XlaBuilder* builder,
const Array2D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
ErrorSpec error);
template <typename NativeT>
void ComputeAndCompareR3(XlaBuilder* builder,
const Array3D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ absl::Span<GlobalData* const> arguments);
template <typename NativeT>
void ComputeAndCompareR3(XlaBuilder* builder,
const Array3D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
ErrorSpec error);
template <typename NativeT>
void ComputeAndCompareR4(XlaBuilder* builder,
const Array4D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ absl::Span<GlobalData* const> arguments);
template <typename NativeT>
void ComputeAndCompareR4(XlaBuilder* builder,
const Array4D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
ErrorSpec error);
// Build and run the computation and compare the result with the given
// literal. shape_with_layout indicates the result layout to request when
// calling Execute.
- void ComputeAndCompareLiteral(
- XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
- const Shape* shape_with_layout = nullptr);
- void ComputeAndCompareLiteral(
- XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
- const Shape* shape_with_layout = nullptr);
+ void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected,
+ absl::Span<GlobalData* const> arguments,
+ const Shape* shape_with_layout = nullptr);
+ void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected,
+ absl::Span<GlobalData* const> arguments,
+ ErrorSpec error,
+ const Shape* shape_with_layout = nullptr);
// ComputeAndCompare variant which returns an error status.
Status ComputeAndCompareLiteralWithStatus(
XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
const Shape* shape_with_layout = nullptr);
Status ComputeAndCompareLiteralWithStatus(
XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
+ absl::Span<GlobalData* const> arguments, ErrorSpec error,
const Shape* shape_with_layout = nullptr);
  // Compare the result of the computation to a string. In XLA, strings are
// represented using rank-1 U8 shapes.
- void ComputeAndCompareR1U8(
- XlaBuilder* builder, tensorflow::StringPiece expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ void ComputeAndCompareR1U8(XlaBuilder* builder, absl::string_view expected,
+ absl::Span<GlobalData* const> arguments);
// Convenience method for running a built computation, transferring the
// result, and comparing it to the expected tuple literal.
- void ComputeAndCompareTuple(
- XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
- void ComputeAndCompareTuple(
- XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error);
+ void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected,
+ absl::Span<GlobalData* const> arguments);
+ void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected,
+ absl::Span<GlobalData* const> arguments,
+ ErrorSpec error);
// Convenience method for running a built computation and comparing the result
// with the reference result.
void ComputeAndCompare(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<Literal> arguments);
+ absl::Span<const Literal> arguments);
void ComputeAndCompare(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<Literal> arguments,
- ErrorSpec error);
+ absl::Span<const Literal> arguments, ErrorSpec error);
// Create scalar operations for use in reductions.
XlaComputation CreateScalarRelu();
@@ -286,7 +282,7 @@ class ClientLibraryTestBase : public ::testing::Test {
template <class T>
XlaOp AddParam(const Array<T>& argument, XlaBuilder* builder) {
- return AddParam(*LiteralUtil::CreateFromArray(argument), builder);
+ return AddParam(LiteralUtil::CreateFromArray(argument), builder);
}
// Creates a constant instruction with the given literal. When the
@@ -301,14 +297,14 @@ class ClientLibraryTestBase : public ::testing::Test {
template <typename NativeT>
XlaOp CreateConstantFromArray(const Array<NativeT>& array,
XlaBuilder* builder) {
- return CreateConstantFromLiteral(*LiteralUtil::CreateFromArray(array),
+ return CreateConstantFromLiteral(LiteralUtil::CreateFromArray(array),
builder);
}
// Same as CreateConstantFromArray, but for scalars.
template <typename NativeT>
XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) {
- return CreateConstantFromLiteral(*LiteralUtil::CreateR0<NativeT>(value),
+ return CreateConstantFromLiteral(LiteralUtil::CreateR0<NativeT>(value),
builder);
}
@@ -337,7 +333,7 @@ class ClientLibraryTestBase : public ::testing::Test {
// converted to bfloat16.
template <typename NativeT>
std::unique_ptr<GlobalData> CreateR1Parameter(
- tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number,
+ absl::Span<const NativeT> values, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle);
// Creates a parameter instruction that wraps the given constant array
@@ -379,9 +375,8 @@ class ClientLibraryTestBase : public ::testing::Test {
// Executes the computation and calculates the expected reference value using
// the reference client. Returns two literals in the order of (expected,
// actual).
- StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
- ComputeValueAndReference(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<Literal> arguments);
+ StatusOr<std::pair<Literal, Literal>> ComputeValueAndReference(
+ XlaBuilder* builder, absl::Span<const Literal> arguments);
Client* client_;
Client* ref_client_; // To compute reference result.
@@ -390,12 +385,12 @@ class ClientLibraryTestBase : public ::testing::Test {
private:
Status ComputeAndCompareLiteralWithAllOutputLayouts(
const xla::XlaComputation& computation, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
const std::function<void(const Literal& actual,
const string& error_message)>& verify_output);
Status ComputeAndCompareLiteralWithAllInputLayouts(
const xla::XlaComputation& computation, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
const std::function<void(const Literal& actual,
const string& error_message)>& verify_output,
const Shape* output_with_layout = nullptr);
@@ -415,130 +410,126 @@ class ClientLibraryTestBase : public ::testing::Test {
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR0(
XlaBuilder* builder, NativeT expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR0<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ absl::Span<GlobalData* const> arguments) {
+ Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR0(
XlaBuilder* builder, NativeT expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ absl::Span<GlobalData* const> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, bfloat16>::value ||
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR0<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR1(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR1<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ XlaBuilder* builder, absl::Span<const NativeT> expected,
+ absl::Span<GlobalData* const> arguments) {
+ Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR1(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ XlaBuilder* builder, absl::Span<const NativeT> expected,
+ absl::Span<GlobalData* const> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, bfloat16>::value ||
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR1<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR2(
XlaBuilder* builder, const Array2D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
- std::unique_ptr<Literal> expected_literal =
+ absl::Span<GlobalData* const> arguments) {
+ Literal expected_literal =
LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR2(
XlaBuilder* builder, const Array2D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ absl::Span<GlobalData* const> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, bfloat16>::value ||
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
- std::unique_ptr<Literal> expected_literal =
+ Literal expected_literal =
LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR3(
XlaBuilder* builder, const Array3D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
- std::unique_ptr<Literal> expected_literal =
+ absl::Span<GlobalData* const> arguments) {
+ Literal expected_literal =
LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR3(
XlaBuilder* builder, const Array3D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ absl::Span<GlobalData* const> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, bfloat16>::value ||
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
- std::unique_ptr<Literal> expected_literal =
+ Literal expected_literal =
LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR4(
XlaBuilder* builder, const Array4D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
- std::unique_ptr<Literal> expected_literal =
+ absl::Span<GlobalData* const> arguments) {
+ Literal expected_literal =
LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments);
}
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR4(
XlaBuilder* builder, const Array4D<NativeT>& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ absl::Span<GlobalData* const> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, bfloat16>::value ||
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
- std::unique_ptr<Literal> expected_literal =
+ Literal expected_literal =
LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
arguments, error);
}
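// A minimal standalone sketch of the static_assert guard repeated above,
// with standard-library stand-ins for bfloat16/half/complex64: requesting an
// ErrorSpec comparison for an exact (integral) element type becomes a
// compile-time error instead of a silently meaningless tolerance.
#include <complex>
#include <type_traits>

template <typename NativeT>
void RequireApproximateType() {
  static_assert(std::is_floating_point<NativeT>::value ||
                    std::is_same<NativeT, std::complex<float>>::value,
                "Float or complex type required when specifying an ErrorSpec");
}
// RequireApproximateType<float>();  // compiles
// RequireApproximateType<int>();    // fails with the message above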
@@ -546,27 +537,27 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
NativeT value, int64 parameter_number, const string& name,
XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0(value);
- if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralUtil::ConvertF32ToBF16(*literal);
+ Literal literal = LiteralUtil::CreateR0(value);
+ if (use_bfloat16_ && literal.shape().element_type() == F32) {
+ literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}
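// A hypothetical usage sketch of CreateR0Parameter, assuming a test fixture
// derived from ClientLibraryTestBase (DemoTest and the test name are
// invented; the other calls appear elsewhere in this patch):
class DemoTest : public ClientLibraryTestBase {};

XLA_TEST_F(DemoTest, UsesR0Parameter) {
  XlaBuilder builder(TestName());
  XlaOp param;
  std::unique_ptr<GlobalData> data = CreateR0Parameter<float>(
      42.0f, /*parameter_number=*/0, "p0", &builder, &param);
  Add(param, param);  // consume the parameter in the computation
  ComputeAndCompareR0<float>(&builder, 84.0f, {data.get()}, ErrorSpec(1e-5));
}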
template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
- tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number,
+ absl::Span<const NativeT> values, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1(values);
- if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralUtil::ConvertF32ToBF16(*literal);
+ Literal literal = LiteralUtil::CreateR1(values);
+ if (use_bfloat16_ && literal.shape().element_type() == F32) {
+ literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}
@@ -574,13 +565,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
const Array2D<NativeT>& array_2d, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR2FromArray2D(array_2d);
- if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralUtil::ConvertF32ToBF16(*literal);
+ Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d);
+ if (use_bfloat16_ && literal.shape().element_type() == F32) {
+ literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}
@@ -588,13 +579,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
const Array3D<NativeT>& array_3d, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(array_3d);
- if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralUtil::ConvertF32ToBF16(*literal);
+ Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d);
+ if (use_bfloat16_ && literal.shape().element_type() == F32) {
+ literal = LiteralUtil::ConvertF32ToBF16(literal);
}
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
return data;
}
@@ -613,7 +604,7 @@ template <typename NativeT>
std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
const int rows, const int cols, NativeT min_value, NativeT max_value,
uint32 seed) {
- auto result = MakeUnique<Array2D<NativeT>>(rows, cols);
+ auto result = absl::make_unique<Array2D<NativeT>>(rows, cols);
PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
for (int y = 0; y < rows; ++y) {
for (int x = 0; x < cols; ++x) {
diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc
index c898dacf48..6f2ca84bb6 100644
--- a/tensorflow/compiler/xla/tests/client_test.cc
+++ b/tensorflow/compiler/xla/tests/client_test.cc
@@ -55,16 +55,15 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) {
std::unique_ptr<GlobalData> data,
client_->Execute(computation, {}, &execution_options));
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR2WithLayout<int32>(
- {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
+ Literal expected_literal = LiteralUtil::CreateR2WithLayout<int32>(
+ {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
TF_ASSERT_OK_AND_ASSIGN(
- auto computed, client_->Transfer(*data, &expected_literal->shape()));
+ auto computed, client_->Transfer(*data, &expected_literal.shape()));
ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
- expected_literal->shape(), computed->shape()));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+ expected_literal.shape(), computed.shape()));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
}
@@ -91,19 +90,19 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) {
auto result,
client_->ExecuteAndTransfer(computation, {}, &execution_options));
LiteralTestUtil::ExpectR2Equal<int32>({{1, 2}, {3, 4}},
- LiteralSlice(*result, {0}));
+ LiteralSlice(result, {0}));
LiteralTestUtil::ExpectR2Equal<int32>({{10, 20}, {30, 40}},
- LiteralSlice(*result, {1}));
+ LiteralSlice(result, {1}));
- EXPECT_TRUE(ShapeUtil::IsTuple(result->shape()));
- EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape()));
+ EXPECT_TRUE(ShapeUtil::IsTuple(result.shape()));
+ EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.shape()));
EXPECT_TRUE(ShapeUtil::Equal(
- ShapeUtil::GetTupleElementShape(result->shape(), 0),
+ ShapeUtil::GetTupleElementShape(result.shape(), 0),
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
/*minor_to_major=*/{0, 1})));
EXPECT_TRUE(ShapeUtil::Equal(
- ShapeUtil::GetTupleElementShape(result->shape(), 1),
+ ShapeUtil::GetTupleElementShape(result.shape(), 1),
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
/*minor_to_major=*/{1, 0})));
}
@@ -114,7 +113,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> const_arg,
client_->TransferToServer(
- *LiteralUtil::CreateR2<int32>({{5, 6}, {7, 8}})));
+ LiteralUtil::CreateR2<int32>({{5, 6}, {7, 8}})));
XlaBuilder b(TestName() + ".add");
Add(Parameter(&b, 0, shape, "param_0"),
@@ -140,9 +139,9 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
TF_ASSERT_OK_AND_ASSIGN(
auto result_literal,
- client_->Transfer(*results[0], &expected_result->shape()));
+ client_->Transfer(*results[0], &expected_result.shape()));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_result, result_literal));
}
} // namespace
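// A minimal sketch of the tuple-element views used above, with constructor
// usage exactly as shown in this patch: LiteralSlice(result, {i}) is a
// non-owning view of tuple element i, replacing the old unique_ptr deref.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"

namespace xla {
void TupleElementViewSketch() {
  Literal tuple = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<int32>(1), LiteralUtil::CreateR0<int32>(2)});
  LiteralSlice first(tuple, {0});   // element 0; no copy made
  LiteralSlice second(tuple, {1});  // element 1; no copy made
  (void)first;
  (void)second;
}
}  // namespace xla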
diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
index 7c52c9fbbb..6ef7ca035f 100644
--- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc
+++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -30,7 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
@@ -38,25 +38,24 @@ namespace {
class CompilationCacheTest : public ClientLibraryTestBase {
public:
- void ExecuteComputationR0F32(
- const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, float expected_result,
- bool expect_cache_hit) {
+ void ExecuteComputationR0F32(const XlaComputation& computation,
+ absl::Span<GlobalData* const> arguments,
+ float expected_result, bool expect_cache_hit) {
ExecutionProfile execution_profile;
- std::unique_ptr<Literal> result =
+ Literal result =
client_
->ExecuteAndTransfer(computation, arguments,
/*execution_options=*/&execution_options_,
&execution_profile)
.ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR0<float>(expected_result), *result, error_spec_));
+ LiteralUtil::CreateR0<float>(expected_result), result, error_spec_));
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
}
void ExecuteComputationR2F32(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ absl::Span<GlobalData* const> arguments,
std::initializer_list<std::initializer_list<float>> expected_result,
bool expect_cache_hit) {
ExecutionProfile execution_profile;
@@ -64,10 +63,9 @@ class CompilationCacheTest : public ClientLibraryTestBase {
->Execute(computation, arguments,
&execution_options_, &execution_profile)
.ConsumeValueOrDie();
- std::unique_ptr<Literal> result =
- client_->Transfer(*data_handle).ConsumeValueOrDie();
+ Literal result = client_->Transfer(*data_handle).ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>(expected_result), *result, error_spec_));
+ LiteralUtil::CreateR2<float>(expected_result), result, error_spec_));
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
}
@@ -89,13 +87,13 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) {
XLA_TEST_F(CompilationCacheTest,
DISABLED_ComputationCalledWithDifferentParameters) {
std::unique_ptr<GlobalData> data_42 =
- client_->TransferToServer(*LiteralUtil::CreateR0<float>(42.0f))
+ client_->TransferToServer(LiteralUtil::CreateR0<float>(42.0f))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> data_123 =
- client_->TransferToServer(*LiteralUtil::CreateR0<float>(123.0f))
+ client_->TransferToServer(LiteralUtil::CreateR0<float>(123.0f))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> data_456 =
- client_->TransferToServer(*LiteralUtil::CreateR0<float>(456.0f))
+ client_->TransferToServer(LiteralUtil::CreateR0<float>(456.0f))
.ConsumeValueOrDie();
XlaBuilder builder(TestName());
@@ -146,12 +144,12 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) {
auto rowmaj_array = LiteralUtil::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0}));
auto rowmaj_handle =
- client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie();
+ client_->TransferToServer(rowmaj_array).ConsumeValueOrDie();
auto colmaj_array = LiteralUtil::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}));
auto colmaj_handle =
- client_->TransferToServer(*colmaj_array).ConsumeValueOrDie();
+ client_->TransferToServer(colmaj_array).ConsumeValueOrDie();
XlaBuilder builder(TestName());
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0");
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc
index 5a06d061f0..3b0414a604 100644
--- a/tensorflow/compiler/xla/tests/compute_constant_test.cc
+++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/match.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -69,9 +69,9 @@ class ComputeConstantTest : public ::testing::Test {
LOG(FATAL) << "invalid client_type value";
}
- StatusOr<std::unique_ptr<Literal>> ComputeConstantLiteral(
- Client* client, const XlaOp& operand, XlaBuilder* builder,
- Layout* output_layout = nullptr) {
+ StatusOr<Literal> ComputeConstantLiteral(Client* client, const XlaOp& operand,
+ XlaBuilder* builder,
+ Layout* output_layout = nullptr) {
TF_ASSIGN_OR_RETURN(auto subgraph, builder->BuildConstantSubGraph(operand));
TF_ASSIGN_OR_RETURN(auto computed,
client->ComputeConstant(subgraph, output_layout));
@@ -83,7 +83,7 @@ class ComputeConstantTest : public ::testing::Test {
XlaBuilder* builder) {
TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand,
builder, nullptr));
- return literal->Get<Scalar>({});
+ return literal.Get<Scalar>({});
}
bool IsConstant(const XlaOp& operand, XlaBuilder* builder) {
@@ -145,8 +145,8 @@ TEST_F(ComputeConstantTest, DirectParamMissing) {
EXPECT_FALSE(IsConstant(computation, &b));
auto value = ComputeConstantScalar<float>(client, computation, &b);
- EXPECT_TRUE(tensorflow::str_util::StrContains(value.status().ToString(),
- "depends on a parameter"))
+ EXPECT_TRUE(
+ absl::StrContains(value.status().ToString(), "depends on a parameter"))
<< value.status();
}
}
@@ -161,8 +161,8 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) {
EXPECT_FALSE(IsConstant(computation, &b));
auto value = ComputeConstantScalar<float>(client, computation, &b);
- EXPECT_TRUE(tensorflow::str_util::StrContains(value.status().ToString(),
- "depends on a parameter"))
+ EXPECT_TRUE(
+ absl::StrContains(value.status().ToString(), "depends on a parameter"))
<< value.status();
}
}
@@ -206,9 +206,8 @@ TEST_F(ComputeConstantTest, NonScalarAdd) {
TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b));
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR1<int32>({4, 6});
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+ Literal expected_literal = LiteralUtil::CreateR1<int32>({4, 6});
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
@@ -221,8 +220,8 @@ TEST_F(ComputeConstantTest, IntegerDivide) {
TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b));
- std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR0<int32>(5);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+ Literal expected_literal = LiteralUtil::CreateR0<int32>(5);
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
@@ -241,12 +240,11 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
ConstantR2<int32>(&b, {{10, 20}, {30, 40}})),
&b, &layout_proto));
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateR2WithLayout<int32>(
- {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout));
+ Literal expected_literal = LiteralUtil::CreateR2WithLayout<int32>(
+ {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout));
ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
- expected_literal->shape(), computed->shape()));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+ expected_literal.shape(), computed.shape()));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
}
}
}
diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc
index be017477d8..9811a015e9 100644
--- a/tensorflow/compiler/xla/tests/concat_test.cc
+++ b/tensorflow/compiler/xla/tests/concat_test.cc
@@ -536,8 +536,8 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
auto x_literal = LiteralUtil::CreateR0<float>(2.f);
auto y_literal = LiteralUtil::CreateR0<float>(3.f);
- auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
- auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
+ auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
+ auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
auto x = Parameter(&builder, 0, f32_scalar, "x");
@@ -559,12 +559,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) {
auto x_literal = LiteralUtil::CreateR1<float>({2.0f, 3.0f, 5.0f, 6.0f});
auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
- auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
- auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
- auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
+ auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
+ auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
+ auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
- auto x = Parameter(&builder, 0, x_literal->shape(), "x");
+ auto x = Parameter(&builder, 0, x_literal.shape(), "x");
auto y = Parameter(&builder, 1, f32_scalar, "y");
auto z = Parameter(&builder, 2, f32_scalar, "z");
auto bcast = Broadcast(y, {5});
@@ -587,12 +587,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) {
auto x_literal = LiteralUtil::CreateR3FromArray3D<float>(x3d);
auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
- auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
- auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
- auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
+ auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
+ auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
+ auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
- auto x = Parameter(&builder, 0, x_literal->shape(), "x");
+ auto x = Parameter(&builder, 0, x_literal.shape(), "x");
auto y = Parameter(&builder, 1, f32_scalar, "y");
  auto z = Parameter(&builder, 2, f32_scalar, "z");
auto y_bcast = Broadcast(y, {1, 5, 7});
diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc
index b27c1044ba..32cac499c7 100644
--- a/tensorflow/compiler/xla/tests/conditional_test.cc
+++ b/tensorflow/compiler/xla/tests/conditional_test.cc
@@ -359,8 +359,8 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
ComputeAndCompareTuple(
&builder,
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(12.0f).get(),
- LiteralUtil::CreateR0<float>(25.0f).get()}),
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<float>(12.0f),
+ LiteralUtil::CreateR0<float>(25.0f)}),
{pred_arg.get()}, error_spec_);
}
@@ -375,12 +375,11 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
Conditional(pred, operands, CreateR1TupleCeilComputation(), operands,
CreateR1TupleFloorComputation());
- ComputeAndCompareTuple(
- &builder,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<float>({13.0f, 16.0f}).get(),
- LiteralUtil::CreateR1<float>({26.0f, 30.0f}).get()}),
- {pred_arg.get()}, error_spec_);
+ ComputeAndCompareTuple(&builder,
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>({13.0f, 16.0f}),
+ LiteralUtil::CreateR1<float>({26.0f, 30.0f})}),
+ {pred_arg.get()}, error_spec_);
}
// Test true and false computations that return a tuple of a predicate, a
@@ -415,13 +414,12 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands,
false_builder_result.ConsumeValueOrDie());
- ComputeAndCompareTuple(
- &builder,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<bool>(true).get(),
- LiteralUtil::CreateR0<float>(12.2f).get(),
- LiteralUtil::CreateR1<float>({12.8f, 14.6f}).get()}),
- {pred_arg.get()}, error_spec_);
+ ComputeAndCompareTuple(&builder,
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<bool>(true),
+ LiteralUtil::CreateR0<float>(12.2f),
+ LiteralUtil::CreateR1<float>({12.8f, 14.6f})}),
+ {pred_arg.get()}, error_spec_);
}
// Test true and false computations that return a nested tuple.
@@ -463,15 +461,13 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
ComputeAndCompareTuple(
&builder,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(46.6f).get(),
- LiteralUtil::CreateR1<float>({54.4f, 58.4f}).get()})
- .get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<float>({62.1f, 67.4f}).get(),
- LiteralUtil::CreateR0<float>(9.3f).get()})
- .get()}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(46.6f),
+ LiteralUtil::CreateR1<float>({54.4f, 58.4f})}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>({62.1f, 67.4f}),
+ LiteralUtil::CreateR0<float>(9.3f)})}),
{pred_arg.get()}, error_spec_);
}
@@ -633,8 +629,8 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
ComputeAndCompareTuple(
&builder,
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(a).get(),
- LiteralUtil::CreateR0<float>(b).get()}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(a), LiteralUtil::CreateR0<float>(b)}),
{x_arg.get(), y_arg.get()}, error_spec_);
};
@@ -642,5 +638,57 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
test_swap(11.24f, 5.55f);
}
+// Test conditional that duplicates tuple elements in the then and else
+// computations. This is a regression test for b/112550242.
+XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
+ const Shape scalar = ShapeUtil::MakeShape(S32, {});
+ const Shape tuple2 = ShapeUtil::MakeTupleShape({scalar, scalar});
+ XlaComputation then_comp;
+ {
+ XlaBuilder builder(TestName() + ".then");
+ auto p = Parameter(&builder, 0, tuple2, "then.p");
+ auto e0 = GetTupleElement(p, 0);
+ auto e1 = GetTupleElement(p, 1);
+ Tuple(&builder, {e0, e1, e0});
+ then_comp = builder.Build().ConsumeValueOrDie();
+ }
+ XlaComputation else_comp;
+ {
+ XlaBuilder builder(TestName() + ".else");
+ auto p = Parameter(&builder, 0, tuple2, "else.p");
+ auto e0 = GetTupleElement(p, 0);
+ auto e1 = GetTupleElement(p, 1);
+ Tuple(&builder, {e0, e1, e1});
+ else_comp = builder.Build().ConsumeValueOrDie();
+ }
+
+ {
+ // Pred is true case.
+ std::vector<Literal> args;
+ args.push_back(
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
+ LiteralUtil::CreateR0<int32>(-42)}));
+ args.push_back(LiteralUtil::CreateR0<bool>(true));
+ XlaBuilder builder(TestName() + ".main");
+ auto p = Parameter(&builder, 0, tuple2, "p0");
+ auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
+ Conditional(p_pred, p, then_comp, p, else_comp);
+ ComputeAndCompare(&builder, args);
+ }
+ {
+ // Pred is false case.
+ std::vector<Literal> args;
+ args.push_back(
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
+ LiteralUtil::CreateR0<int32>(-42)}));
+ args.push_back(LiteralUtil::CreateR0<bool>(false));
+ XlaBuilder builder(TestName() + ".main");
+ auto p = Parameter(&builder, 0, tuple2, "p0");
+ auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
+ Conditional(p_pred, p, then_comp, p, else_comp);
+ ComputeAndCompare(&builder, args);
+ }
+}
+
} // namespace
} // namespace xla
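// Side-by-side summary of the MakeTuple -> MakeTupleFromSlices migration
// applied throughout this file (old form per the removed lines above):
//
//   old: LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(1.0f).get(),
//                                LiteralUtil::CreateR0<float>(2.0f).get()})
//        builds a std::unique_ptr<Literal> from raw Literal pointers;
//
//   new: LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<float>(1.0f),
//                                          LiteralUtil::CreateR0<float>(2.0f)})
//        builds a Literal value from LiteralSlice views, so the .get() calls
//        and unique_ptr dereferences disappear.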
diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc
index 4937574831..72ff1e74a4 100644
--- a/tensorflow/compiler/xla/tests/constants_test.cc
+++ b/tensorflow/compiler/xla/tests/constants_test.cc
@@ -110,7 +110,7 @@ TEST_F(ConstantsTest, Small_2x2) {
TEST_F(ConstantsTest, Empty_3x0x2) {
XlaBuilder builder(TestName());
- ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D<float>(
+ ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D<float>(
Array3D<float>(3, 0, 2)));
ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 2), {});
@@ -126,7 +126,7 @@ TEST_F(ConstantsTest, Small_2x2x2) {
{{5.f, 6.f}, // y0
{7.f, 8.f}}, // y1
});
- ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D<float>(array3d));
+ ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D<float>(array3d));
ComputeAndCompareR3<float>(&builder, array3d, {});
}
@@ -140,12 +140,11 @@ TEST_F(ConstantsTest, Small_3x2x1x1) {
{5.0f, 4.4f}, // p2
});
input_array.FillWithPZ(pz);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4D(input_array);
+ Literal input_literal = LiteralUtil::CreateR4FromArray4D(input_array);
{
XlaBuilder builder(TestName());
- ConstantLiteral(&builder, *input_literal);
+ ConstantLiteral(&builder, input_literal);
ComputeAndCompareR4<float>(&builder, input_array, {}, error_spec_);
}
@@ -159,23 +158,21 @@ TEST_F(ConstantsTest, Small_3x2x1x1) {
// TODO(b/29263943): Support tuple constants.
TEST_F(ConstantsTest, DISABLED_TupleConstant) {
XlaBuilder builder(TestName());
- ConstantLiteral(&builder,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
- LiteralUtil::CreateR1<float>({2.0, 42}).get()}));
+ ConstantLiteral(&builder, LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}),
+ LiteralUtil::CreateR1<float>({2.0, 42})}));
- std::unique_ptr<Literal> result =
- ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();
+ Literal result = ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();
LiteralTestUtil::ExpectR2Near<float>({{1.0}, {2.0}},
- LiteralSlice(*result, {0}), error_spec_);
- LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, LiteralSlice(*result, {1}),
+ LiteralSlice(result, {0}), error_spec_);
+ LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, LiteralSlice(result, {1}),
error_spec_);
}
TEST_F(ConstantsTest, Token) {
XlaBuilder builder(TestName());
- ConstantLiteral(&builder, *LiteralUtil::CreateToken());
+ ConstantLiteral(&builder, LiteralUtil::CreateToken());
// TODO(b/80000000): tokens cannot be returned from computations.
Tuple(&builder, {});
TF_ASSERT_OK(Execute(&builder, {}).status());
diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc
index 1adc68cc48..5f063e6784 100644
--- a/tensorflow/compiler/xla/tests/convert_test.cc
+++ b/tensorflow/compiler/xla/tests/convert_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -209,10 +210,10 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) {
static_cast<int64>(0x8000008000000000LL),
static_cast<int64>(0x8000010000000000LL),
};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int64>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<int64>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, F32);
@@ -228,10 +229,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) {
std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff,
0x80000000, 0x80000001, 0x80000002, 0x80000003,
0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<uint32>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, F32);
@@ -246,10 +247,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
XlaBuilder builder(TestName());
std::vector<float> arg{0.0f, 1.0f, 16777216.0f,
16777218.0f, 2147483647.0f, 4294967040.0f};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<float>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, U32);
@@ -263,10 +264,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
XlaBuilder builder(TestName());
std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<uint32>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, S64);
@@ -280,10 +281,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) {
XlaBuilder builder(TestName());
std::vector<int32> arg{0, 1, 0x1000, -1, -0x1000};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int32>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<int32>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, S64);
@@ -317,10 +318,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) {
9223370937343148032.f,
-9223371487098961920.f,
-9223370937343148032.f};
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg});
- auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+ Literal arg_literal = LiteralUtil::CreateR1<float>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
- client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+ client_->TransferToServer(arg_literal).ConsumeValueOrDie();
ConvertElementType(arg_param, S64);
@@ -447,15 +448,15 @@ std::vector<float> GetInterestingF16ConversionTestCases() {
XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) {
std::vector<float> test_cases = GetInterestingF16ConversionTestCases();
std::vector<half> input;
- c_transform(test_cases, std::back_inserter(input),
- [](float f) { return Eigen::half(f); });
+ absl::c_transform(test_cases, std::back_inserter(input),
+ [](float f) { return Eigen::half(f); });
std::vector<float> expected_output;
- c_transform(input, std::back_inserter(expected_output),
- [](Eigen::half h) { return static_cast<float>(h); });
+ absl::c_transform(input, std::back_inserter(expected_output),
+ [](Eigen::half h) { return static_cast<float>(h); });
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> dot_lhs_handle,
- client_->TransferToServer(*LiteralUtil::CreateR1<half>(input)));
+ client_->TransferToServer(LiteralUtil::CreateR1<half>(input)));
XlaBuilder builder(TestName());
ConvertElementType(
@@ -470,12 +471,12 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) {
XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) {
std::vector<float> input = GetInterestingF16ConversionTestCases();
std::vector<half> expected_output;
- c_transform(input, std::back_inserter(expected_output),
- [](float f) { return Eigen::half(f); });
+ absl::c_transform(input, std::back_inserter(expected_output),
+ [](float f) { return Eigen::half(f); });
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> dot_lhs_handle,
- client_->TransferToServer(*LiteralUtil::CreateR1<float>(input)));
+ client_->TransferToServer(LiteralUtil::CreateR1<float>(input)));
XlaBuilder builder(TestName());
ConvertElementType(
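The c_transform call sites pick up the absl:: prefix because the container algorithms now come from absl/algorithm/container.h (added above); absl::c_transform is the whole-container wrapper around std::transform. A self-contained sketch of the call shape:

#include <iterator>
#include <vector>

#include "absl/algorithm/container.h"

std::vector<double> WidenAll(const std::vector<float>& floats) {
  std::vector<double> doubles;
  // Applies the lambda to every element of `floats` and appends the results,
  // exactly like std::transform over floats.begin()/floats.end().
  absl::c_transform(floats, std::back_inserter(doubles),
                    [](float f) { return static_cast<double>(f); });
  return doubles;
}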
diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
index 7b6bbc4f57..fd98bf29b8 100644
--- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
@@ -17,11 +17,11 @@ limitations under the License.
#include <array>
#include <memory>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -88,13 +88,12 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidOutputDimensionNumbers) {
XLA_TEST_F(ConvolutionDimensionNumbersTest,
TwoConvsWithDifferentDimensionNumbers) {
- auto input_array = MakeUnique<Array4D<float>>(2, 3, 5, 5);
+ auto input_array = absl::make_unique<Array4D<float>>(2, 3, 5, 5);
input_array->FillWithMultiples(0.1);
- auto weight_array = MakeUnique<Array4D<float>>(4, 3, 1, 1);
+ auto weight_array = absl::make_unique<Array4D<float>>(4, 3, 1, 1);
weight_array->FillWithMultiples(0.2);
auto weight_data =
- client_
- ->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array))
+ client_->TransferToServer(LiteralUtil::CreateR4FromArray4D(*weight_array))
.ConsumeValueOrDie();
XlaBuilder builder(TestName());
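Here the local xla::MakeUnique helper (from the removed ptr_util.h include) gives way to absl::make_unique, which has identical semantics: forward the arguments to the constructor and return a std::unique_ptr. A minimal sketch mirroring the call above:

#include <memory>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/array4d.h"

// Equivalent to the removed MakeUnique<Array4D<float>>(2, 3, 5, 5) call:
// allocates the array and returns an owning std::unique_ptr.
std::unique_ptr<xla::Array4D<float>> MakeFilledInput() {
  auto input = absl::make_unique<xla::Array4D<float>>(2, 3, 5, 5);
  input->FillWithMultiples(0.1f);
  return input;
}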
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index 689928aee4..070b092d18 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include <memory>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
@@ -26,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -35,8 +36,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -71,16 +70,16 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest {
const int kKernelSizeY = 2;
const int kOutputActivationSizeZ = 256;
const int kMiniBatchSize = 4;
- auto alhs =
- MakeUnique<Array4D<T>>(kMiniBatchSize, kInputActivationSizeZ,
- kInputActivationSizeY, kInputActivationSizeX);
+ auto alhs = absl::make_unique<Array4D<T>>(
+ kMiniBatchSize, kInputActivationSizeZ, kInputActivationSizeY,
+ kInputActivationSizeX);
alhs->FillWithMultiples(static_cast<T>(1.0f));
ASSERT_EQ(3, alhs->width());
ASSERT_EQ(3, alhs->height());
- auto arhs =
- MakeUnique<Array4D<T>>(kOutputActivationSizeZ, kInputActivationSizeZ,
- kKernelSizeY, kKernelSizeX);
+ auto arhs = absl::make_unique<Array4D<T>>(kOutputActivationSizeZ,
+ kInputActivationSizeZ,
+ kKernelSizeY, kKernelSizeX);
Array2D<T> rhs_raster({
{1.0f, 0.0f}, // row 0
{0.0f, 0.0f}, // row 1
@@ -124,8 +123,8 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest {
}));
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))},
+ {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@@ -158,8 +157,8 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest {
{7.0f, 8.0f},
}));
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))},
+ {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@@ -193,8 +192,8 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest {
}));
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))},
+ {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@@ -225,8 +224,8 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest {
{{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}}));
// clang-format on
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))},
+ {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
};
@@ -250,10 +249,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
Array3D<float> expected({{{510, 610, 710, 810}}});
auto input_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@@ -285,10 +284,10 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
Array3D<T> expected({{{570.0f, 670.0f, 770.0f}}});
auto input_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<T>(&builder, expected,
@@ -320,10 +319,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
Array3D<float> expected({{{190, 320, 230, 380, 270, 440, 310, 500}}});
auto input_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@@ -351,10 +350,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) {
Array3D<float> expected({{{510, 0, 610, 0, 710, 0, 810}}});
auto input_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@@ -387,10 +386,10 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
{{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}});
auto input_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+ client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<T>(&builder, expected,
@@ -436,23 +435,23 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
iota(input_elems.begin(), input_elems.end(), 1.0f);
auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
- auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r5 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota(filter_elems.begin(), filter_elems.end(), 1.0f);
auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
- auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r5 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
auto expected_r1 = LiteralUtil::CreateR1<float>(
{19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446,
38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470});
- auto expected_r5 = expected_r1->Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie();
+ auto expected_r5 = expected_r1.Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie();
- auto input_literal = client_->TransferToServer(*input_r5).ConsumeValueOrDie();
+ auto input_literal = client_->TransferToServer(input_r5).ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*filter_r5).ConsumeValueOrDie();
+ client_->TransferToServer(filter_r5).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *expected_r5,
+ ComputeAndCompareLiteral(&builder, expected_r5,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@@ -499,23 +498,23 @@ class Convolve2D_1x3x3x5_3x3x5x3_Valid : public ConvolutionTest {
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
iota_int_init_value(input_elems, 1);
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
- auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota_int_init_value(filter_elems, 1);
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
- auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
auto expected_r1 = LiteralUtil::CreateR1<T>(
{static_cast<T>(92115), static_cast<T>(93150), static_cast<T>(94185)});
- auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
+ auto expected_r4 = expected_r1.Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
auto input_literal =
- client_->TransferToServer(*input_r4).ConsumeValueOrDie();
+ client_->TransferToServer(input_r4).ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
+ client_->TransferToServer(filter_r4).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *expected_r4,
+ ComputeAndCompareLiteral(&builder, expected_r4,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@@ -559,12 +558,12 @@ class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest {
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
iota_int_init_value(input_elems, 1);
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
- auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota_int_init_value(filter_elems, 1);
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
- auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
auto expected_r1 = LiteralUtil::CreateR1<T>(
{static_cast<T>(16029), static_cast<T>(16218), static_cast<T>(16407),
@@ -572,14 +571,14 @@ class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest {
static_cast<T>(18369), static_cast<T>(18576), static_cast<T>(18783),
static_cast<T>(19620), static_cast<T>(19836), static_cast<T>(20052),
static_cast<T>(20925), static_cast<T>(21150), static_cast<T>(21375)});
- auto expected_r4 = expected_r1->Reshape({1, 1, 1, 15}).ConsumeValueOrDie();
+ auto expected_r4 = expected_r1.Reshape({1, 1, 1, 15}).ConsumeValueOrDie();
auto input_literal =
- client_->TransferToServer(*input_r4).ConsumeValueOrDie();
+ client_->TransferToServer(input_r4).ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
+ client_->TransferToServer(filter_r4).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *expected_r4,
+ ComputeAndCompareLiteral(&builder, expected_r4,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@@ -625,26 +624,26 @@ class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest {
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
iota_int_init_value(input_elems, 1);
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
- auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota_int_init_value(filter_elems, 1);
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
- auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
auto expected_r1 = LiteralUtil::CreateR1<T>(
{static_cast<T>(5076), static_cast<T>(5160), static_cast<T>(5244),
static_cast<T>(5328), static_cast<T>(6164), static_cast<T>(6264),
static_cast<T>(6364), static_cast<T>(6464), static_cast<T>(7380),
static_cast<T>(7496), static_cast<T>(7612), static_cast<T>(7728)});
- auto expected_r4 = expected_r1->Reshape({1, 1, 1, 12}).ConsumeValueOrDie();
+ auto expected_r4 = expected_r1.Reshape({1, 1, 1, 12}).ConsumeValueOrDie();
auto input_literal =
- client_->TransferToServer(*input_r4).ConsumeValueOrDie();
+ client_->TransferToServer(input_r4).ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
+ client_->TransferToServer(filter_r4).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *expected_r4,
+ ComputeAndCompareLiteral(&builder, expected_r4,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@@ -693,8 +692,8 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,
expected_result.Fill(0);
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(param0)),
- std::move(*LiteralUtil::CreateFromArray(param1))},
+ {LiteralUtil::CreateFromArray(param0),
+ LiteralUtil::CreateFromArray(param1)},
error_spec_);
}
@@ -750,26 +749,25 @@ class Convolve1D1WindowTestBase
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
static_cast<T>(1.0f));
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
- auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ auto input_r3 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
static_cast<T>(1.0f));
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
- auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto filter_r3 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
std::vector<T> expect_elems(batch * output_feature * num_windows,
static_cast<T>(window_size * input_feature));
auto expected_r1 = LiteralUtil::CreateR1<T>(expect_elems);
- auto expected_r3 =
- expected_r1->Reshape({batch, num_windows, output_feature})
- .ConsumeValueOrDie();
+ auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature})
+ .ConsumeValueOrDie();
auto input_literal =
- client_->TransferToServer(*input_r3).ConsumeValueOrDie();
+ client_->TransferToServer(input_r3).ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*filter_r3).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *expected_r3,
+ client_->TransferToServer(filter_r3).ConsumeValueOrDie();
+ ComputeAndCompareLiteral(&builder, expected_r3,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
@@ -869,8 +867,8 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
}));
ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))},
+ {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)},
error_spec_);
}
@@ -892,9 +890,44 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) {
Array4D<float> filter_data(1, 1, 1, 2);
filter_data.FillIota(10);
- ComputeAndCompare(&builder,
- {std::move(*LiteralUtil::CreateFromArray(input_data)),
- std::move(*LiteralUtil::CreateFromArray(filter_data))});
+ ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data),
+ LiteralUtil::CreateFromArray(filter_data)});
+}
+
+XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) {
+ XlaBuilder builder(TestName());
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 64, 100, 100});
+ Array4D<float> input_data(1, 64, 100, 100);
+ input_data.FillRandom(/*value=*/0.023, 0.001, /*seed=*/45321);
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {7, 7, 1, 64});
+ Array4D<float> filter_data(7, 7, 1, 64);
+  filter_data.FillRandom(/*value=*/0.023, 0.001, /*seed=*/45320);
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = ConstantR4FromArray4D(&builder, filter_data);
+
+ // Specify bf01_01io->bf01 as dimension numbers.
+ ConvolutionDimensionNumbers dnums;
+ // Input
+ dnums.set_input_feature_dimension(1);
+ dnums.set_input_batch_dimension(0);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_input_spatial_dimensions(3);
+ // Kernel
+ dnums.set_kernel_input_feature_dimension(2);
+ dnums.set_kernel_output_feature_dimension(3);
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.add_kernel_spatial_dimensions(1);
+ // Output
+ dnums.set_output_batch_dimension(0);
+ dnums.set_output_feature_dimension(1);
+ dnums.add_output_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(3);
+ ConvGeneral(input, filter, /*window_strides=*/{1, 1},
+ /*padding=*/{{3, 3}, {3, 3}}, /*dimension_numbers=*/dnums,
+ /*feature_group_count=*/64);
+
+ ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data)},
+ error_spec_);
}
class ConvolutionHloTest : public HloTestBase {};
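The new grouped-convolution test builds its ConvolutionDimensionNumbers inline; the proto simply assigns logical roles (batch, feature, spatial) to dimension indices of each operand. A helper expressing the same bf01_01io->bf01 numbering, shown as an illustrative sketch rather than part of the patch:

// Input/output laid out as bf01 (batch, feature, spatial y, spatial x);
// kernel as 01io (spatial y, spatial x, input feature, output feature).
xla::ConvolutionDimensionNumbers MakeBf01_01io_Bf01() {
  xla::ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);
  dnums.add_kernel_spatial_dimensions(1);
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(3);
  dnums.set_output_batch_dimension(0);
  dnums.set_output_feature_dimension(1);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(3);
  return dnums;
}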
diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
index 6784c16715..ba3e9c436e 100644
--- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
@@ -1335,23 +1335,23 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) {
auto gradients_flat = LiteralUtil::CreateR1<float>({1});
auto gradients_literal =
- gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
- auto gradients = ConstantLiteral(&builder, *gradients_literal);
+ gradients_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
+ auto gradients = ConstantLiteral(&builder, gradients_literal);
auto weights_flat = LiteralUtil::CreateR1<float>({1, 10, 100});
auto weights_literal =
- weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
- auto weights = ConstantLiteral(&builder, *weights_literal);
+ weights_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
+ auto weights = ConstantLiteral(&builder, weights_literal);
auto expected_flat = LiteralUtil::CreateR1<float>({10});
auto expected_literal =
- expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
+ expected_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
auto mirrored_weights = Rev(weights, {2, 3, 4});
ConvWithGeneralPadding(gradients, mirrored_weights,
/*window_strides=*/{1, 1, 1},
/*padding=*/{{0, 0}, {0, 0}, {1, 1}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_);
+ ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_);
}
XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
@@ -1359,17 +1359,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
auto activations_flat = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
auto activations_literal =
- activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie();
- auto activations = ConstantLiteral(&builder, *activations_literal);
+ activations_flat.Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie();
+ auto activations = ConstantLiteral(&builder, activations_literal);
auto gradients_flat = LiteralUtil::CreateR1<float>({100, 10, 1});
auto gradients_literal =
- gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
- auto gradients = ConstantLiteral(&builder, *gradients_literal);
+ gradients_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
+ auto gradients = ConstantLiteral(&builder, gradients_literal);
auto expected_flat = LiteralUtil::CreateR1<float>({13, 24, 130});
auto expected_literal =
- expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
+ expected_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
auto forward_conv =
ConvGeneralDilated(activations, gradients,
@@ -1379,7 +1379,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
XlaBuilder::CreateDefaultConvDimensionNumbers(
/*num_spatial_dims=*/3));
Transpose(forward_conv, {0, 1, 2, 3, 4});
- ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_);
+ ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_);
}
} // namespace
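Literal::Reshape follows the same value-semantics migration: it returns StatusOr<Literal> rather than StatusOr<std::unique_ptr<Literal>>, so ConsumeValueOrDie() yields a Literal that is passed straight to ConstantLiteral. A sketch of the pattern these hunks repeat, assuming an XlaBuilder from the surrounding test:

xla::XlaOp WeightsConstant(xla::XlaBuilder* builder) {
  xla::Literal flat = xla::LiteralUtil::CreateR1<float>({1, 10, 100});
  // Reshape produces a rank-5 literal; ConsumeValueOrDie() moves the
  // Literal out of the StatusOr instead of returning a pointer.
  xla::Literal r5 = flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
  return xla::ConstantLiteral(builder, r5);
}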
diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc
index 5ef273e5a2..1407e68d9a 100644
--- a/tensorflow/compiler/xla/tests/copy_test.cc
+++ b/tensorflow/compiler/xla/tests/copy_test.cc
@@ -16,10 +16,10 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -40,49 +40,48 @@ class CopyOpTest : public HloTestBase {
protected:
void TestCopyOp(const Literal& literal) {
auto builder = HloComputation::Builder(TestName());
- auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(literal.CloneToUnique()));
+ auto constant =
+ builder.AddInstruction(HloInstruction::CreateConstant(literal.Clone()));
builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kCopy, constant));
auto computation = builder.Build();
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
- EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result));
+ Literal result = ExecuteAndTransfer(std::move(module), {});
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3);
void TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, size_t n4,
- tensorflow::gtl::ArraySlice<int64> permutation);
+ absl::Span<const int64> permutation);
};
XLA_TEST_F(CopyOpTest, CopyR0Bool) {
- TestCopyOp(*LiteralUtil::CreateR0<bool>(true));
+ TestCopyOp(LiteralUtil::CreateR0<bool>(true));
}
XLA_TEST_F(CopyOpTest, CopyR1S0U32) {
- TestCopyOp(*LiteralUtil::CreateR1<uint32>({}));
+ TestCopyOp(LiteralUtil::CreateR1<uint32>({}));
}
XLA_TEST_F(CopyOpTest, CopyR1S3U32) {
- TestCopyOp(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
+ TestCopyOp(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}
XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) {
- TestCopyOp(
- *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+ TestCopyOp(LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
XLA_TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) {
- TestCopyOp(*LiteralUtil::CreateR4(
+ TestCopyOp(LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
XLA_TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) {
- TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
+ TestCopyOp(LiteralUtil::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
}
XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
@@ -90,7 +89,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
// Copy literal to device to use as parameter.
auto literal = LiteralUtil::CreateR0<float>(42.0);
- Shape shape = literal->shape();
+ Shape shape = literal.shape();
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param0"));
@@ -102,9 +101,8 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
- std::unique_ptr<Literal> result =
- ExecuteAndTransfer(std::move(module), {literal.get()});
- LiteralTestUtil::ExpectR0Near<float>(42.0f, *result, error_spec_);
+ Literal result = ExecuteAndTransfer(std::move(module), {&literal});
+ LiteralTestUtil::ExpectR0Near<float>(42.0f, result, error_spec_);
}
XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) {
@@ -123,19 +121,17 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) {
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
- LiteralTestUtil::ExpectR2Near<float>({{1.0, 2.0}, {3.0, 4.0}}, *result,
+ Literal result = ExecuteAndTransfer(std::move(module), {});
+ LiteralTestUtil::ExpectR2Near<float>({{1.0, 2.0}, {3.0, 4.0}}, result,
error_spec_);
}
XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal =
- LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
// Reverse the minor-to-major order of the literal.
- Layout* literal_layout =
- literal->mutable_shape_do_not_use()->mutable_layout();
+ Layout* literal_layout = literal.mutable_shape_do_not_use()->mutable_layout();
ASSERT_EQ(2, literal_layout->minor_to_major_size());
literal_layout->mutable_minor_to_major()->SwapElements(0, 1);
@@ -149,11 +145,11 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
+ Literal result = ExecuteAndTransfer(std::move(module), {});
// The result of the computation has the default layout, which is the inverse
// of the layout of the source literal.
- LiteralTestUtil::ExpectR2Near<float>({{1.0, 3.0}, {2.0, 4.0}}, *result,
+ LiteralTestUtil::ExpectR2Near<float>({{1.0, 3.0}, {2.0, 4.0}}, result,
error_spec_);
}
@@ -169,7 +165,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) {
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(a);
+ Literal literal = LiteralUtil::CreateR3FromArray3D(a);
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
@@ -182,14 +178,14 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) {
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
ForceResultLayout(module.get(), LayoutUtil::MakeLayout({1, 2, 0}));
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
+ Literal result = ExecuteAndTransfer(std::move(module), {});
- LiteralTestUtil::ExpectR3EqualArray3D(a, *result);
+ LiteralTestUtil::ExpectR3EqualArray3D(a, result);
}
-void CopyOpTest::TestCopyConstantLayoutR4(
- size_t n1, size_t n2, size_t n3, size_t n4,
- tensorflow::gtl::ArraySlice<int64> permutation) {
+void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3,
+ size_t n4,
+ absl::Span<const int64> permutation) {
Array4D<int32> a(n1, n2, n3, n4);
for (size_t i = 0; i < n1; ++i) {
for (size_t j = 0; j < n2; ++j) {
@@ -203,7 +199,7 @@ void CopyOpTest::TestCopyConstantLayoutR4(
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR4FromArray4D(a);
+ Literal literal = LiteralUtil::CreateR4FromArray4D(a);
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
@@ -216,9 +212,9 @@ void CopyOpTest::TestCopyConstantLayoutR4(
auto module = CreateNewModule();
module->AddEntryComputation(std::move(computation));
ForceResultLayout(module.get(), LayoutUtil::MakeLayout(permutation));
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
+ Literal result = ExecuteAndTransfer(std::move(module), {});
- LiteralTestUtil::ExpectR4EqualArray4D(a, *result);
+ LiteralTestUtil::ExpectR4EqualArray4D(a, result);
}
XLA_TEST_F(CopyOpTest, CopyConstantR3Layout021_SingleIncompleteTilePerLayer) {
@@ -250,11 +246,11 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) {
XlaBuilder builder(TestName());
Parameter(&builder, 0, in_shape, "input");
- auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie();
+ auto input_data = client_->TransferToServer(empty).ConsumeValueOrDie();
auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape)
.ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual));
+ EXPECT_TRUE(LiteralTestUtil::Equal(empty, actual));
}
} // namespace
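HloInstruction::CreateConstant takes ownership of its literal, and since Literal is move-only, tests that want to keep the original now call Clone() (a deep copy returned by value) where they previously called CloneToUnique(). A sketch of the idiom TestCopyOp uses above:

// Clone() produces an independent deep copy, so the caller keeps `literal`
// for later comparison while the clone is moved into the instruction.
void AddConstant(xla::HloComputation::Builder* builder,
                 const xla::Literal& literal) {
  builder->AddInstruction(
      xla::HloInstruction::CreateConstant(literal.Clone()));
}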
diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
index d12a4e7fcd..410732c07b 100644
--- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
+++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
@@ -46,7 +46,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) {
auto module =
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
- EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()}));
+ EXPECT_EQ(literal, ExecuteAndTransfer(std::move(module), {&literal}));
}
XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
@@ -68,9 +68,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
- EXPECT_EQ(
- *LiteralUtil::MakeTuple({literal0.get(), literal1.get()}),
- *ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()}));
+ EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}),
+ ExecuteAndTransfer(std::move(module), {&literal0, &literal1}));
}
// On the GPU backend, constants get special handling. Someone might pass a
@@ -95,8 +94,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) {
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
- EXPECT_EQ(*LiteralUtil::MakeTuple({literal0.get(), literal1.get()}),
- *ExecuteAndTransfer(std::move(module), {literal0.get()}));
+ EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}),
+ ExecuteAndTransfer(std::move(module), {&literal0}));
}
} // namespace
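With Literal as a value type, LiteralUtil::MakeTuple takes pointers to existing literals and returns the tuple Literal by value, and literals compare element-wise with operator==, so EXPECT_EQ needs no dereferencing. A hedged sketch (the test name is illustrative only):

TEST(LiteralTupleSketch, MakeTupleFromPointers) {
  xla::Literal a = xla::LiteralUtil::CreateR1<float>({1, 2, 3});
  xla::Literal b = xla::LiteralUtil::CreateR1<float>({10, 20});
  // MakeTuple deep-copies the pointed-to literals into a new tuple Literal.
  xla::Literal tuple = xla::LiteralUtil::MakeTuple({&a, &b});
  EXPECT_EQ(tuple, xla::LiteralUtil::MakeTuple({&a, &b}));
}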
diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc
index 13c777835e..a693fa3595 100644
--- a/tensorflow/compiler/xla/tests/custom_call_test.cc
+++ b/tensorflow/compiler/xla/tests/custom_call_test.cc
@@ -16,9 +16,9 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -80,8 +80,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) {
module->AddEntryComputation(builder.Build());
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
- LiteralTestUtil::ExpectR0Near<float>(44.0f, *result, error_spec_);
+ Literal result = ExecuteAndTransfer(std::move(module), {});
+ LiteralTestUtil::ExpectR0Near<float>(44.0f, result, error_spec_);
}
XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
@@ -101,8 +101,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
module->AddEntryComputation(builder.Build());
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
- LiteralTestUtil::ExpectR0Near<float>(10.0f, *result, error_spec_);
+ Literal result = ExecuteAndTransfer(std::move(module), {});
+ LiteralTestUtil::ExpectR0Near<float>(10.0f, result, error_spec_);
}
XLA_TEST_F(CustomCallTest,
@@ -125,9 +125,9 @@ XLA_TEST_F(CustomCallTest,
module->AddEntryComputation(b.Build());
- std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
+ Literal result = ExecuteAndTransfer(std::move(module), {});
LiteralTestUtil::ExpectR3EqualArray3D<float>(
- Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, *result);
+ Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result);
}
class CustomCallClientAPITest : public ClientLibraryTestBase {};
diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc
index 5f234f36a8..86fd1ceb13 100644
--- a/tensorflow/compiler/xla/tests/deallocation_test.cc
+++ b/tensorflow/compiler/xla/tests/deallocation_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include <memory>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -24,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace {
@@ -36,7 +36,7 @@ class DeallocationTest : public ClientLibraryTestBase {
// Build and execute the given computation then verify the results can be
// transferred from the device successfully.
std::unique_ptr<GlobalData> ExecuteAndCheckTransfer(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ XlaBuilder* builder, absl::Span<GlobalData* const> arguments) {
XlaComputation computation = builder->Build().ConsumeValueOrDie();
auto global_data =
client_->Execute(computation, arguments, &execution_options_)
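tensorflow::gtl::ArraySlice<T> becomes absl::Span throughout: a read-only view is spelled absl::Span<const T>, and a view over pointers, as in this signature, is absl::Span<GlobalData* const>, meaning const pointers to mutable GlobalData. A self-contained sketch (GlobalDataStub stands in for xla::GlobalData):

#include <cstddef>
#include <vector>

#include "absl/types/span.h"

struct GlobalDataStub {};

// The span's elements (the pointers) are immutable; the pointees are not.
// Spans bind to vectors, arrays, and brace-enclosed lists without copying.
size_t CountArgs(absl::Span<GlobalDataStub* const> arguments) {
  return arguments.size();
}

// Usage:
//   std::vector<GlobalDataStub*> args;
//   CountArgs(args);  // from a vector
//   CountArgs({});    // from a brace-enclosed list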
diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
index 2db6503afa..e0f23b0fa8 100644
--- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
@@ -42,7 +42,7 @@ class DeconstructTupleTest : public ClientLibraryTestBase {
// Build and execute the given computation then verify the results can be
// transferred from the device successfully.
std::unique_ptr<GlobalData> ExecuteAndCheckTransfer(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ XlaBuilder* builder, absl::Span<GlobalData* const> arguments) {
XlaComputation computation = builder->Build().ConsumeValueOrDie();
auto global_data =
client_->Execute(computation, arguments, &execution_options_)
@@ -64,11 +64,11 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) {
// Try copying the elements back and comparing it
auto handles = result_status.ConsumeValueOrDie();
- std::unique_ptr<Literal> literal;
+ Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
}
TEST_F(DeconstructTupleTest, DeconstructTupleTwice) {
@@ -86,19 +86,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) {
auto handles1 = result_status1.ConsumeValueOrDie();
auto handles2 = result_status2.ConsumeValueOrDie();
- std::unique_ptr<Literal> literal;
+ Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[0]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[1]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
handles1[0].reset();
handles1[1].reset();
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[0]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[1]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
}
XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) {
@@ -116,15 +116,15 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) {
// the same as handle[3] and handle[1] should be the same as handle[2].
auto handles = result_status.ConsumeValueOrDie();
- std::unique_ptr<Literal> literal;
+ Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[3]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
}
TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) {
@@ -142,19 +142,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) {
// should not have been deallocated because of reference counting.
global_data.reset();
- std::unique_ptr<Literal> literal;
+ Literal literal;
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
- LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
/// Try deallocating one of the repeated elements, then copy
handles[0].reset();
TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
- LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+ LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
}
TEST_F(DeconstructTupleTest, DeconstructNonTuple) {
@@ -170,10 +170,9 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) {
XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
- LiteralUtil::CreateR1<float>({3.14f, -100.25f});
+ Literal param0_literal = LiteralUtil::CreateR1<float>({3.14f, -100.25f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0");
Tuple(&builder, {p});
auto global_data = ExecuteAndCheckTransfer(&builder, {param0_data.get()});
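Reusing one `Literal literal;` across several TF_ASSERT_OK_AND_ASSIGN calls works because the macro expands to a status check followed by a move-assignment, and a default-constructed Literal is a valid empty value. Roughly what each call site above does; a simplified sketch (the real macro uses a uniquely named temporary):

xla::Literal literal;
auto transfer_or = client_->Transfer(*handles[0]);
ASSERT_TRUE(transfer_or.ok()) << transfer_or.status();
// Move-assigns the transferred Literal into the reused local.
literal = std::move(transfer_or.ValueOrDie());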
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 0e9e92ed99..0171f51583 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -67,16 +68,16 @@ XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) {
XlaOp param;
auto param_data = CreateParameterAndTransferLiteral(
0,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}).get(),
- LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}}).get()}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}),
+ LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}})}),
"arg0", &builder, &param);
auto lhs = GetTupleElement(param, 0);
auto rhs = GetTupleElement(param, 1);
Dot(lhs, rhs);
ComputeAndCompareLiteral(&builder,
- *LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
+ LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
{param_data.get()});
}
@@ -195,11 +196,11 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) {
auto lhs_handle =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>(
+ ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
{{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}}))
.ConsumeValueOrDie();
auto rhs_handle = this->client_
- ->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>(
+ ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
{{1.0f}, {2.0f}, {3.0f}, {4.0f}}))
.ConsumeValueOrDie();
@@ -218,14 +219,14 @@ class SquareMatrixDot : public DotOperationTest {
void TestImpl(bool lhs_row_major, bool rhs_row_major) {
auto lhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 2.0f}, {3.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 6.0f}, {7.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(rhs_row_major))))
@@ -261,16 +262,14 @@ string PrintDotTestParam(
const ::testing::TestParamInfo<DotTestParam>& test_param) {
const DotTestParam& param = test_param.param;
if (param.has_addend) {
- return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n,
- "_MajorToMinor",
- param.dot_lhs_row_major ? "T" : "F",
- param.dot_rhs_row_major ? "T" : "F",
- param.addend_row_major ? "T" : "F");
+ return absl::StrCat(param.m, "x", param.k, "x", param.n, "_MajorToMinor",
+ param.dot_lhs_row_major ? "T" : "F",
+ param.dot_rhs_row_major ? "T" : "F",
+ param.addend_row_major ? "T" : "F");
} else {
- return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n,
- "_MajorToMinor",
- param.dot_lhs_row_major ? "T" : "F",
- param.dot_rhs_row_major ? "T" : "F");
+ return absl::StrCat(param.m, "x", param.k, "x", param.n, "_MajorToMinor",
+ param.dot_lhs_row_major ? "T" : "F",
+ param.dot_rhs_row_major ? "T" : "F");
}
}
@@ -287,24 +286,23 @@ void ParametricDotTest::TestImpl() {
std::unique_ptr<Array2D<NativeT>> dot_lhs_data =
MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.k);
- std::unique_ptr<Literal> dot_lhs_lit =
- LiteralUtil::CreateR2FromArray2DWithLayout(
- *dot_lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(
- param.dot_lhs_row_major)));
+ Literal dot_lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
+ *dot_lhs_data, LayoutUtil::MakeLayout(
+ MinorToMajorForIsRowMajor(param.dot_lhs_row_major)));
std::unique_ptr<GlobalData> dot_lhs_handle =
- client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie();
+ client_->TransferToServer(dot_lhs_lit).ConsumeValueOrDie();
std::unique_ptr<Array2D<NativeT>> dot_rhs_data =
MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.k, param.n);
Layout rhs_layout = LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(param.dot_rhs_row_major));
- std::unique_ptr<Literal> dot_rhs_lit =
+ Literal dot_rhs_lit =
LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout);
std::unique_ptr<GlobalData> dot_rhs_handle =
- client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie();
+ client_->TransferToServer(dot_rhs_lit).ConsumeValueOrDie();
std::unique_ptr<Array2D<NativeT>> addend_data;
- std::unique_ptr<Literal> addend_lit;
+ Literal addend_lit;
std::unique_ptr<GlobalData> addend_handle;
if (param.has_addend) {
@@ -312,7 +310,7 @@ void ParametricDotTest::TestImpl() {
addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
*addend_data, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(param.addend_row_major)));
- addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie();
+ addend_handle = client_->TransferToServer(addend_lit).ConsumeValueOrDie();
}
XlaBuilder builder(TestName());
@@ -478,14 +476,14 @@ class NonsquareMatrixDot : public DotOperationTest {
void TestImpl(bool lhs_row_major, bool rhs_row_major) {
auto lhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(rhs_row_major))))
@@ -512,12 +510,12 @@ XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); }
XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
auto lhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>(
+ ->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
{{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>(
+ ->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
{{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie();
@@ -585,7 +583,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2});
auto x_data = this->client_
- ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+ ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1000.0f, 100.0f}, {10.0f, 1.0f}},
{{2000.0f, 200.0f}, {20.0f, 2.0f}}},
{{{3000.0f, 300.0f}, {30.0f, 3.0f}},
@@ -593,7 +591,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
.ConsumeValueOrDie();
auto y_data =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+ ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
{{{11.0f, 22.0f}, {33.0f, 44.0f}},
{{55.0f, 66.0f}, {77.0f, 88.0f}}}}))
@@ -631,13 +629,13 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) {
auto x_data =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>(
+ ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
.ConsumeValueOrDie();
auto y_data =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>(
+ ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}))
.ConsumeValueOrDie();
@@ -669,7 +667,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {
auto x_data =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+ ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
{{{9.0f, 10.0f}, {11.0f, 12.0f}},
{{13.0f, 14.0f}, {15.0f, 16.0f}}}}))
@@ -677,7 +675,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {
auto y_data =
this->client_
- ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+ ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}},
{{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}}))
.ConsumeValueOrDie();
@@ -709,14 +707,14 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) {
auto lhs_handle =
this->client_
->TransferToServer(
- *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ LiteralUtil::CreateR2FromArray2DWithLayout<T>(
*lhs, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
this->client_
->TransferToServer(
- *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ LiteralUtil::CreateR2FromArray2DWithLayout<T>(
*rhs, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(row_major))))
.ConsumeValueOrDie();
@@ -779,15 +777,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
TF_ASSERT_OK_AND_ASSIGN(
auto arg_0_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_1_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_2_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
Array2D<T> expected({{53.0f, 74.0f}, {45.0f, 66.0f}});
this->template ComputeAndCompareR2<T>(
@@ -828,15 +826,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
TF_ASSERT_OK_AND_ASSIGN(
auto arg_0_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_1_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_2_value,
this->client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
+ LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
Array2D<T> expected({{38.0f, 36.0f}, {93.0f, 91.0f}});
this->template ComputeAndCompareR2<T>(
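The dot_operation_test.cc hunks above are all instances of one mechanical migration: the LiteralUtil::Create* factories now return a Literal value instead of std::unique_ptr<Literal>, so call sites drop the leading '*' when handing the literal to Client::TransferToServer. A minimal sketch of the new call-site shape, assuming a connected xla::Client* (the helper name TransferExample is illustrative, not part of this patch):

    #include <memory>

    #include "tensorflow/compiler/xla/client/client.h"
    #include "tensorflow/compiler/xla/literal_util.h"

    namespace xla {
    // Transfers a small literal to the service and keeps the device handle.
    std::unique_ptr<GlobalData> TransferExample(Client* client) {
      // Literal is move-only and cheap to move; no dereference is needed.
      Literal lit = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
      return client->TransferToServer(lit).ConsumeValueOrDie();
    }
    }  // namespace xla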
diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
index 7f6f203a1b..7501c6d957 100644
--- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
@@ -114,23 +114,23 @@ class DynamicSliceTest : public ClientLibraryTestBase {
}
template <typename IndexT, typename DataT>
- void RunR1(tensorflow::gtl::ArraySlice<int> input_values_int,
+ void RunR1(absl::Span<const int> input_values_int,
const std::vector<IndexT> slice_starts,
const std::vector<int64>& slice_sizes,
- tensorflow::gtl::ArraySlice<int> expected_values_int) {
+ absl::Span<const int> expected_values_int) {
// bfloat16 has explicit constructors, so it does not implicitly convert the
// way built-in types do, which is why we can't take the parameter as an
- // ArraySlice<DataT>. We also can't convert it to a vector, because
- // vector<bool> is special so that it cannot be an ArraySlice<bool>, which
+ // Span<DataT>. We also can't convert it to a vector, because
+ // vector<bool> is special so that it cannot be a Span<bool>, which
// is what the code below wants. So instead we do this.
Literal input_values =
- std::move(*LiteralUtil::CreateR1(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ LiteralUtil::CreateR1(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie();
Literal expected_values =
- std::move(*LiteralUtil::CreateR1(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR1(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -150,13 +150,13 @@ class DynamicSliceTest : public ClientLibraryTestBase {
const std::vector<int64>& slice_sizes,
const Array2D<int>& expected_values_int) {
Literal input_values =
- std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR2FromArray2D(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_values =
- std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -176,13 +176,13 @@ class DynamicSliceTest : public ClientLibraryTestBase {
const std::vector<int64>& slice_sizes,
const Array3D<int>& expected_values_int) {
Literal input_values =
- std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR3FromArray3D(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_values =
- std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -359,17 +359,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
void RunR0(int input_value_int, int update_value_int,
const std::vector<IndexT> slice_starts, int expected_value_int) {
Literal input_value =
- std::move(*LiteralUtil::CreateR0(input_value_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR0(input_value_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal update_value =
- std::move(*LiteralUtil::CreateR0(update_value_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR0(update_value_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_value =
- std::move(*LiteralUtil::CreateR0(expected_value_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR0(expected_value_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -385,22 +385,22 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
}
template <typename IndexT, typename DataT>
- void RunR1(tensorflow::gtl::ArraySlice<int> input_values_int,
- tensorflow::gtl::ArraySlice<int> update_values_int,
+ void RunR1(absl::Span<const int> input_values_int,
+ absl::Span<const int> update_values_int,
const std::vector<IndexT> slice_starts,
- tensorflow::gtl::ArraySlice<int> expected_values_int) {
+ absl::Span<const int> expected_values_int) {
Literal input_values =
- std::move(*LiteralUtil::CreateR1(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR1(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal update_values =
- std::move(*LiteralUtil::CreateR1(update_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR1(update_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_values =
- std::move(*LiteralUtil::CreateR1(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR1(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -421,17 +421,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
const Array2D<int>& expected_values_int) {
Literal input_values =
- std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR2FromArray2D(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal update_values =
- std::move(*LiteralUtil::CreateR2FromArray2D(update_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR2FromArray2D(update_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_values =
- std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -452,17 +452,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
const Array3D<int>& expected_values_int) {
Literal input_values =
- std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR3FromArray3D(input_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal update_values =
- std::move(*LiteralUtil::CreateR3FromArray3D(update_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR3FromArray3D(update_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
Literal expected_values =
- std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int)
- ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
- .ValueOrDie());
+ std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int)
+ .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+ .ValueOrDie());
XlaBuilder builder(TestName());
// Initialize and transfer dynamic slice start indices parameter.
@@ -529,9 +529,8 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
template <typename NativeT>
void DumpArray(const string& name, const Array3D<NativeT> values) {
- std::unique_ptr<Literal> literal =
- LiteralUtil::CreateR3FromArray3D<NativeT>(values);
- LOG(INFO) << name << ":" << literal->ToString();
+ Literal literal = LiteralUtil::CreateR3FromArray3D<NativeT>(values);
+ LOG(INFO) << name << ":" << literal.ToString();
}
};
@@ -719,7 +718,7 @@ void BM_DynamicSlice(int num_iters) {
auto input_literal = LiteralUtil::CreateR4(
{{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
- auto input = ConstantLiteral(&builder, *input_literal);
+ auto input = ConstantLiteral(&builder, input_literal);
// Create dynamic slice start indices as a parameter: shape [4]
auto start_indices_shape = ShapeUtil::MakeShape(S32, {4});
@@ -740,7 +739,7 @@ void BM_DynamicSlice(int num_iters) {
auto stream =
client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
- stream.get(), *start_indices_literal, buffer));
+ stream.get(), start_indices_literal, buffer));
std::unique_ptr<LocalExecutable> executable =
client
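The Convert chains rewritten in dynamic_ops_test.cc all follow the same shape: Literal::Convert is now a value method returning StatusOr<Literal>, so a typed test literal can be built in a single expression with no pointer hops. A sketch of that pattern under the assumption that DataT is a native XLA element type (the helper name MakeTypedR1 is illustrative):

    #include "absl/types/span.h"
    #include "tensorflow/compiler/xla/literal_util.h"
    #include "tensorflow/compiler/xla/primitive_util.h"

    namespace xla {
    template <typename DataT>
    Literal MakeTypedR1(absl::Span<const int> values) {
      // CreateR1 builds an S32 literal; Convert re-types it to DataT's
      // primitive type and returns StatusOr<Literal>.
      return LiteralUtil::CreateR1(values)
          .Convert(primitive_util::NativeToPrimitiveType<DataT>())
          .ValueOrDie();
    }
    }  // namespace xla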
diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc
index 5116e60ca6..b08ece0e63 100644
--- a/tensorflow/compiler/xla/tests/execution_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc
@@ -31,7 +31,7 @@ XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> input,
client_->TransferToServer(
- *LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256)));
+ LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256)));
XlaBuilder b(TestName() + ".add");
Dot(Parameter(&b, 0, shape, "param_0"), Parameter(&b, 1, shape, "param_1"));
diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
index bf1de02ba9..738f2600d4 100644
--- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
+++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
@@ -38,7 +38,7 @@ class ExhaustiveF32ElementwiseOpTest
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> input_literal =
+ Literal input_literal =
LiteralUtil::CreateFromDimensions(F32, {input_size});
for (int64 i = begin; i < end; i++) {
if (i >= known_incorrect_range.first &&
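With input_literal now a plain Literal, the exhaustive test fills the buffer in place through Literal's mutating accessors. A hedged sketch of that idiom; CreateFromDimensions and Literal::Set are the APIs used here, while the function name and fill values are illustrative:

    #include "tensorflow/compiler/xla/literal_util.h"

    namespace xla {
    Literal MakeLinearF32(int64 n) {
      // Allocates a rank-1 F32 literal of shape [n], then fills it.
      Literal lit = LiteralUtil::CreateFromDimensions(F32, {n});
      for (int64 i = 0; i < n; ++i) {
        lit.Set<float>({i}, static_cast<float>(i));  // element i := i
      }
      return lit;
    }
    }  // namespace xla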
diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc
index 39cc6c5927..3be9657db4 100644
--- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc
+++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc
@@ -16,13 +16,13 @@ limitations under the License.
#include <limits>
#include <string>
+#include "absl/strings/str_join.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -37,10 +37,9 @@ class FloorCeilTest : public ClientLibraryTestBase {
};
// Runs a computation and comparison on expected vs f(input)
- void TestR1F32(tensorflow::gtl::ArraySlice<float> input,
- tensorflow::gtl::ArraySlice<float> expected, Function f) {
- LOG(INFO) << "input: {" << tensorflow::str_util::Join(expected, ", ")
- << "}";
+ void TestR1F32(absl::Span<const float> input,
+ absl::Span<const float> expected, Function f) {
+ LOG(INFO) << "input: {" << absl::StrJoin(input, ", ") << "}";
XlaBuilder builder(TestName());
auto c = ConstantR1<float>(&builder, input);
if (f == kCeil) {
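floor_ceil_test.cc swaps two utilities at once: absl::Span<const float> replaces tensorflow::gtl::ArraySlice<float> in the signatures, and absl::StrJoin replaces tensorflow::str_util::Join in the log line. Both are drop-in renames; a small sketch (the function name LogValues is illustrative):

    #include "absl/strings/str_join.h"
    #include "absl/types/span.h"
    #include "tensorflow/core/platform/logging.h"

    // Joins the elements with ", " exactly as str_util::Join used to.
    void LogValues(absl::Span<const float> input) {
      LOG(INFO) << "input: {" << absl::StrJoin(input, ", ") << "}";
    }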
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index 792be0d3fc..9c94acb437 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -22,13 +22,14 @@ limitations under the License.
#define EIGEN_USE_THREADS
+#include "absl/memory/memory.h"
+#include "absl/types/span.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -42,14 +43,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"
-using tensorflow::gtl::ArraySlice;
-
namespace xla {
namespace {
@@ -113,26 +111,26 @@ class FusionTest : public HloTestBase {
hlos[0] = builder.AddInstruction(std::move(root_hlo));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(
- ArraySlice<HloInstruction*>(hlos, 0, Arity + 1),
+ absl::Span<HloInstruction* const>(hlos).subspan(0, Arity + 1),
HloInstruction::FusionKind::kLoop);
auto expected = LiteralUtil::CreateR2FromArray2D(answer_data);
auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
if (primitive_util::IsFloatingPointType(prim_type)) {
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4)));
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, ErrorSpec(1e-4)));
} else {
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
}
private:
template <typename T>
- T ComputeElementwiseAnswer(HloOpcode opcode, ArraySlice<float> xs);
+ T ComputeElementwiseAnswer(HloOpcode opcode, absl::Span<const float> xs);
};
template <>
float FusionTest::ComputeElementwiseAnswer<float>(HloOpcode opcode,
- ArraySlice<float> xs) {
+ absl::Span<const float> xs) {
switch (opcode) {
case HloOpcode::kAdd:
return xs[0] + xs[1];
@@ -157,7 +155,7 @@ float FusionTest::ComputeElementwiseAnswer<float>(HloOpcode opcode,
template <>
bool FusionTest::ComputeElementwiseAnswer<bool>(HloOpcode opcode,
- ArraySlice<float> xs) {
+ absl::Span<const float> xs) {
switch (opcode) {
case HloOpcode::kEq:
return xs[0] == xs[1];
@@ -224,8 +222,8 @@ XLA_TEST_F(FusionTest, Test) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{0.5}, {2.72}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
+ LiteralUtil::CreateR2<float>({{0.5}, {2.72}}),
+ ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
// Test whether we emit appropriate code for parameters of fusion instructions.
@@ -250,8 +248,8 @@ XLA_TEST_F(FusionTest, Parameter) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{-1.0, 0.0, 1.0}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
+ LiteralUtil::CreateR2<float>({{-1.0, 0.0, 1.0}}),
+ ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
@@ -285,7 +283,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
// Every element of result should be y = x^2 = 4.0.
for (int i = 0; i < rand_dim0_size; ++i) {
for (int j = 0; j < dim1_size; ++j) {
- EXPECT_EQ(4.0, result->Get<float>({i, j}));
+ EXPECT_EQ(4.0, result.Get<float>({i, j}));
}
}
}
@@ -310,8 +308,8 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Near(
- *LiteralUtil::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
+ LiteralUtil::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
+ ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
XLA_TEST_F(FusionTest, ReshapeToScalar) {
@@ -325,8 +323,8 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(5),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(5),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
@@ -340,8 +338,8 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
@@ -355,8 +353,8 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
@@ -370,8 +368,8 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(7),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(7),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape__1by1by1) {
@@ -385,8 +383,8 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR3<int32>({{{7}}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR3<int32>({{{7}}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape__) {
@@ -400,8 +398,8 @@ XLA_TEST_F(FusionTest, Reshape__) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(7),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(7),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
@@ -415,8 +413,8 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Transpose_2by3) {
@@ -430,8 +428,8 @@ XLA_TEST_F(FusionTest, Transpose_2by3) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Transpose_3by3) {
@@ -445,8 +443,8 @@ XLA_TEST_F(FusionTest, Transpose_3by3) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reverse) {
@@ -461,8 +459,8 @@ XLA_TEST_F(FusionTest, Reverse) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({3, 2, 1}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({3, 2, 1}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, ReverseNegate) {
@@ -479,8 +477,8 @@ XLA_TEST_F(FusionTest, ReverseNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-3, -2, -1}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-3, -2, -1}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, BroadcastNegate) {
@@ -497,8 +495,8 @@ XLA_TEST_F(FusionTest, BroadcastNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-1, -1}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-1, -1}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, SliceNegate) {
@@ -515,8 +513,8 @@ XLA_TEST_F(FusionTest, SliceNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-1, -3}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-1, -3}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, DynamicSliceNegate) {
@@ -537,8 +535,8 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-2, -3}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-2, -3}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, ReshapeNegate) {
@@ -554,9 +552,9 @@ XLA_TEST_F(FusionTest, ReshapeNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1},
HloInstruction::FusionKind::kLoop);
- EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{-1, -2}, {-3, -4}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, -2}, {-3, -4}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, TransposeNegate) {
@@ -572,9 +570,9 @@ XLA_TEST_F(FusionTest, TransposeNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1},
HloInstruction::FusionKind::kLoop);
- EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{-1, -3}, {-2, -4}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, -3}, {-2, -4}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
std::unique_ptr<HloComputation> MakeReduceTestComputation() {
@@ -601,11 +599,11 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2},
- HloInstruction::FusionKind::kLoop);
+ HloInstruction::FusionKind::kInput);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(15),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(15),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
@@ -626,8 +624,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(-15),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(-15),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
@@ -676,8 +674,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralUtil::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
// When a constant (or other op) which has multiple users is imported
@@ -712,8 +710,8 @@ XLA_TEST_F(FusionTest, SharedConstant) {
EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({8}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({8}),
+ ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); }
@@ -784,19 +782,17 @@ ENTRY main {
}
)";
- std::unique_ptr<Literal> operand =
- LiteralUtil::CreateR2<float>({{0., 0.}, {1., 0.}});
+ Literal operand = LiteralUtil::CreateR2<float>({{0., 0.}, {1., 0.}});
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseHloString(hlo_text, config));
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
- test_runner_.Execute(std::move(module), {operand.get()},
- /*run_hlo_passes=*/false));
+ TF_ASSERT_OK_AND_ASSIGN(Literal result,
+ test_runner_.Execute(std::move(module), {&operand},
+ /*run_hlo_passes=*/false));
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::CreateR3<float>({{{0.}, {0.76159415595}}, {{0.}, {0.}}}),
- *result));
+ LiteralUtil::CreateR3<float>({{{0.}, {0.76159415595}}, {{0.}, {0.}}}),
+ result));
}
class FusionClientLibraryTest : public ClientLibraryTestBase {};
@@ -823,16 +819,16 @@ XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) {
// where overflow is OK.
Array2D<uint32> arr(32, 32);
arr.FillUnique();
- std::unique_ptr<Literal> l1 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout(
+ Literal l1 = LiteralUtil::CreateR2FromArray2D(arr).Relayout(
LayoutUtil::MakeLayout({0, 1}));
- std::unique_ptr<Literal> l2 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout(
+ Literal l2 = LiteralUtil::CreateR2FromArray2D(arr).Relayout(
LayoutUtil::MakeLayout({1, 0}));
- XlaOp p0 = AddParam(*l1, &b);
+ XlaOp p0 = AddParam(l1, &b);
XlaOp sum = p0;
for (int i = 1; i < kNumParams; ++i) {
- auto pN = AddParam((i % 2 == 0 ? *l1 : *l2), &b);
+ auto pN = AddParam((i % 2 == 0 ? l1 : l2), &b);
sum = sum + p0 * pN * pN;
}
@@ -881,19 +877,19 @@ void BM_ParallelFusion(int num_iters) {
auto param0_literal =
LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1);
ScopedShapedBuffer buffer0 =
- client->LiteralToShapedBuffer(*param0_literal, device_ordinal)
+ client->LiteralToShapedBuffer(param0_literal, device_ordinal)
.ConsumeValueOrDie();
auto param1_literal =
LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1);
ScopedShapedBuffer buffer1 =
- client->LiteralToShapedBuffer(*param1_literal, device_ordinal)
+ client->LiteralToShapedBuffer(param1_literal, device_ordinal)
.ConsumeValueOrDie();
auto param2_literal =
LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1);
ScopedShapedBuffer buffer2 =
- client->LiteralToShapedBuffer(*param2_literal, device_ordinal)
+ client->LiteralToShapedBuffer(param2_literal, device_ordinal)
.ConsumeValueOrDie();
// Build executable.
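The fusion_test hunk above replaces ArraySlice's three-argument (pointer, pos, len) constructor, which absl::Span does not provide, with an explicit subspan over the whole array. An equivalence sketch (array contents are illustrative):

    #include "absl/types/span.h"

    void SubspanExample() {
      int vals[] = {1, 2, 3, 4, 5};
      // Old: tensorflow::gtl::ArraySlice<int>(vals, 0, 3)
      // New: span the whole array, then take the prefix of length 3.
      absl::Span<const int> prefix = absl::Span<const int>(vals).subspan(0, 3);
      (void)prefix;  // prefix now views {1, 2, 3}
    }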
diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc
index b77bece85a..daa89398a6 100644
--- a/tensorflow/compiler/xla/tests/gather_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc
@@ -25,17 +25,16 @@ limitations under the License.
namespace xla {
namespace {
-using tensorflow::gtl::nullopt;
+using absl::nullopt;
class GatherOperationTest : public HloTestBase {
protected:
void RunTest(const string& hlo_text, Literal* operand,
- Literal* gather_indices) {
- RunTest(hlo_text, {operand, gather_indices});
+ Literal* start_indices) {
+ RunTest(hlo_text, {operand, start_indices});
}
- void RunTest(const string& hlo_text,
- tensorflow::gtl::ArraySlice<Literal*> args) {
+ void RunTest(const string& hlo_text, absl::Span<Literal* const> args) {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
@@ -52,18 +51,17 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[2,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1, 3}
+ slice_sizes={1, 3}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) {
@@ -74,18 +72,17 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[3,2] gather(operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherMultipleBatchDims) {
@@ -96,18 +93,17 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,3,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_0) {
@@ -118,18 +114,18 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=2,
- window_bounds={1, 1}
+ slice_sizes={1, 1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ Literal start_indices =
LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_1) {
@@ -140,18 +136,18 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2,2] parameter(1)
ROOT gather = s32[2,1,1,2] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=2,
- window_bounds={1, 1}
+ slice_sizes={1, 1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ Literal start_indices =
LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNd) {
@@ -162,20 +158,19 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdNonDefaultIndexVectorDim) {
@@ -186,20 +181,19 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, DynamicSlice) {
@@ -210,18 +204,17 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[1,1] gather(operand, indices),
- output_window_dims={0,1},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={0,1},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, BatchDynamicSlice) {
@@ -232,18 +225,17 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
ROOT gather = s32[2,1,1] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, ZeroDimBounds) {
@@ -254,17 +246,16 @@ ENTRY main {
operand = s32[3,0] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[2,0] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1, 0}
+ slice_sizes={1, 0}
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) {
@@ -278,19 +269,19 @@ ENTRY main {
operand = s32[3,3]{1,0} parameter(0)
indices = s32[6,2]{1,0} parameter(1)
gather = s32[6,1,1]{2,1,0} gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1}
+ slice_sizes={1,1}
ROOT result = s32[6]{0} reshape(gather)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>(
+ Literal start_indices = LiteralUtil::CreateR2<int32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) {
@@ -304,19 +295,19 @@ ENTRY main {
operand = s32[3,3]{1,0} parameter(0)
indices = u32[6,2]{1,0} parameter(1)
gather = s32[6,1,1]{2,1,0} gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1}
+ slice_sizes={1,1}
ROOT result = s32[6]{0} reshape(gather)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<uint32>(
+ Literal start_indices = LiteralUtil::CreateR2<uint32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, NegativeIndex) {
@@ -330,19 +321,19 @@ ENTRY main {
operand = s32[3,3]{1,0} parameter(0)
indices = s32[6,2]{1,0} parameter(1)
gather = s32[6,1,1]{2,1,0} gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1}
+ slice_sizes={1,1}
ROOT result = s32[6]{0} reshape(gather)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>(
+ Literal start_indices = LiteralUtil::CreateR2<int32>(
{{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) {
@@ -356,19 +347,19 @@ ENTRY main {
operand = u32[3,3]{1,0} parameter(0)
indices = s32[6,2]{1,0} parameter(1)
gather = u32[6,1,1]{2,1,0} gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1}
+ slice_sizes={1,1}
ROOT result = u32[6]{0} reshape(gather)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<uint32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>(
+ Literal start_indices = LiteralUtil::CreateR2<int32>(
{{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, OneScalarIndex) {
@@ -379,17 +370,17 @@ ENTRY main {
operand = s32[2,3,2]{2,1,0} parameter(0)
index = s32[] parameter(1)
ROOT gather = s32[1,3,2]{2,1,0} gather(operand, index),
- output_window_dims={0,1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0},
+ offset_dims={0,1,2},
+ collapsed_slice_dims={},
+ start_index_map={0},
index_vector_dim=0,
- window_bounds={1,3,2}
+ slice_sizes={1,3,2}
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>(
+ Literal operand = LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR0<int32>(1);
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ Literal start_indices = LiteralUtil::CreateR0<int32>(1);
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, ScalarResult) {
@@ -400,16 +391,16 @@ ENTRY main {
operand = s32[4]{0} parameter(0)
index = s32[] parameter(1)
ROOT gather = s32[] gather(operand, index),
- output_window_dims={},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=0,
- window_bounds={1}
+ slice_sizes={1}
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR0<int32>(1);
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ Literal operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
+ Literal start_indices = LiteralUtil::CreateR0<int32>(1);
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, ZeroSizedResult) {
@@ -420,17 +411,17 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[0] parameter(1)
ROOT gather = s32[0,3] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0},
- gather_dims_to_operand_dims={0},
+ offset_dims={1},
+ collapsed_slice_dims={0},
+ start_index_map={0},
index_vector_dim=1,
- window_bounds={1, 3}
+ slice_sizes={1, 3}
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR1<int32>({});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherV2) {
@@ -441,21 +432,20 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
gather = s32[3,2] gather(operand, indices),
- output_window_dims={0},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={0},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=1,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
one = s32[] constant(1)
one_broadcasted = s32[3,2] broadcast(one), dimensions={}
ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherMultipleBatchDims) {
@@ -466,21 +456,20 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,3,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={1},
- gather_dims_to_operand_dims={1},
+ offset_dims={1},
+ collapsed_slice_dims={1},
+ start_index_map={1},
index_vector_dim=2,
- window_bounds={3, 1}
+ slice_sizes={3, 1}
one = s32[] constant(1)
one_broadcasted = s32[2,3,2] broadcast(one), dimensions={}
ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNdMultipleBatchDims) {
@@ -491,21 +480,21 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=2,
- window_bounds={1, 1}
+ slice_sizes={1, 1}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
+ Literal start_indices =
LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNd) {
@@ -516,23 +505,22 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=1,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest,
@@ -544,23 +532,22 @@ ENTRY main {
operand = s32[3,3,2] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,2] gather(operand, indices),
- output_window_dims={1},
- elided_window_dims={0,1},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1},
+ collapsed_slice_dims={0,1},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1,2}
+ slice_sizes={1,1,2}
one = s32[] constant(1)
one_broadcasted = s32[2,2] broadcast(one), dimensions={}
ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedDynamicSlice) {
@@ -571,21 +558,20 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
gather = s32[1,1] gather(operand, indices),
- output_window_dims={0,1},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={0,1},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
one = s32[] constant(1)
one_broadcasted = s32[1,1] broadcast(one), dimensions={}
ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ Literal start_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, FusedBatchDynamicSlice) {
@@ -596,21 +582,20 @@ ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2,2] parameter(1)
gather = s32[2,1,1] gather(operand, indices),
- output_window_dims={1,2},
- elided_window_dims={},
- gather_dims_to_operand_dims={0,1},
+ offset_dims={1,2},
+ collapsed_slice_dims={},
+ start_index_map={0,1},
index_vector_dim=0,
- window_bounds={1,1}
+ slice_sizes={1,1}
one = s32[] constant(1)
one_broadcasted = s32[2,1,1] broadcast(one), dimensions={}
ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted)
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices =
- LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- RunTest(hlo_text, operand.get(), gather_indices.get());
+ Literal start_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ RunTest(hlo_text, &operand, &start_indices);
}
class GatherClientLibraryTest : public ClientLibraryTestBase {};
@@ -622,11 +607,11 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
// operand = s32[3,3] parameter(0)
// indices = s32[2] parameter(1)
// ROOT gather = s32[2,3] gather(operand, indices),
- // output_window_dims={1},
- // elided_window_dims={0},
- // gather_dims_to_operand_dims={0},
+ // offset_dims={1},
+ // collapsed_slice_dims={0},
+ // start_index_map={0},
// index_vector_dim=1,
- // window_bounds={1, 3}
+ // slice_sizes={1, 3}
// }
XlaBuilder builder("gather_basic");
@@ -637,9 +622,9 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
auto operand = Parameter(&builder, 0, operand_shape, "operand");
auto indices = Parameter(&builder, 1, indices_shape, "indices");
GatherDimensionNumbers dim_numbers;
- dim_numbers.add_output_window_dims(1);
- dim_numbers.add_elided_window_dims(0);
- dim_numbers.add_gather_dims_to_operand_dims(0);
+ dim_numbers.add_offset_dims(1);
+ dim_numbers.add_collapsed_slice_dims(0);
+ dim_numbers.add_start_index_map(0);
dim_numbers.set_index_vector_dim(1);
Gather(operand, indices, dim_numbers, {1, 3});
@@ -647,10 +632,10 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> operand_arg,
client_->TransferToServer(
- *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> indices_arg,
- client_->TransferToServer(*LiteralUtil::CreateR1<int32>({0, 2})));
+ client_->TransferToServer(LiteralUtil::CreateR1<int32>({0, 2})));
TF_ASSERT_OK_AND_ASSIGN(std::vector<xla::DeviceHandle> devices,
client_->GetDeviceHandles(1));
xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions();
@@ -664,10 +649,9 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
TF_ASSERT_OK_AND_ASSIGN(
std::vector<std::unique_ptr<xla::GlobalData>> result_data,
client_->ExecuteParallel(computation_instances));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
+ TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
client_->Transfer(*(result_data[0])));
- LiteralTestUtil::ExpectR2Equal<int32>({{1, 2, 3}, {7, 8, 9}},
- *result_literal);
+ LiteralTestUtil::ExpectR2Equal<int32>({{1, 2, 3}, {7, 8, 9}}, result_literal);
}
} // namespace
} // namespace xla
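Beyond the unique_ptr cleanup, gather_operation_test.cc tracks a wholesale rename of the gather dimension numbers: output_window_dims becomes offset_dims, elided_window_dims becomes collapsed_slice_dims, gather_dims_to_operand_dims becomes start_index_map, and window_bounds becomes slice_sizes, which is passed straight to Gather. The builder-side rename, distilled from the hunk above (BasicGather is an illustrative wrapper name; the dimension values mirror the test):

    #include "tensorflow/compiler/xla/client/xla_builder.h"

    namespace xla {
    XlaOp BasicGather(XlaOp operand, XlaOp indices) {
      GatherDimensionNumbers dnums;
      dnums.add_offset_dims(1);           // was add_output_window_dims
      dnums.add_collapsed_slice_dims(0);  // was add_elided_window_dims
      dnums.add_start_index_map(0);       // was add_gather_dims_to_operand_dims
      dnums.set_index_vector_dim(1);
      // slice_sizes (formerly window_bounds) is an argument of Gather itself.
      return Gather(operand, indices, dnums, /*slice_sizes=*/{1, 3});
    }
    }  // namespace xla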
diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc
index 51450314b6..1115e50fe3 100644
--- a/tensorflow/compiler/xla/tests/half_test.cc
+++ b/tensorflow/compiler/xla/tests/half_test.cc
@@ -126,9 +126,8 @@ INSTANTIATE_TEST_CASE_P(half, UnaryPredTest,
::testing::Values(UnaryPredTestParam{
[](half x) { return isfinite(x); }, &IsFinite}));
-using BinaryBuildFuncTy =
- std::function<void(const xla::XlaOp& x, const xla::XlaOp& y,
- tensorflow::gtl::ArraySlice<int64>)>;
+using BinaryBuildFuncTy = std::function<void(
+ const xla::XlaOp& x, const xla::XlaOp& y, absl::Span<const int64>)>;
struct BinaryOpTestParam {
std::function<half(half, half)> compute_func;
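half_test.cc applies the same Span substitution inside a std::function alias. A sketch of declaring and invoking such a callback, assuming the xla_builder binary builders that take broadcast dimensions as absl::Span<const int64> (the wrapper name and lambda body are illustrative):

    #include <functional>

    #include "absl/types/span.h"
    #include "tensorflow/compiler/xla/client/xla_builder.h"

    namespace xla {
    using BinaryBuildFuncTy = std::function<void(
        const XlaOp& x, const XlaOp& y, absl::Span<const int64>)>;

    void BuildAdd(XlaOp x, XlaOp y) {
      // Add() accepts broadcast dimensions as a Span; {} means none.
      BinaryBuildFuncTy build = [](const XlaOp& x, const XlaOp& y,
                                   absl::Span<const int64> dims) {
        Add(x, y, dims);
      };
      build(x, y, {});
    }
    }  // namespace xla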
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index 64e361f14f..bdd4fd7e3d 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -20,9 +20,11 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
@@ -32,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -41,9 +42,8 @@ namespace xla {
namespace {
-using tensorflow::StringPiece;
-using tensorflow::gtl::ArraySlice;
-using tensorflow::gtl::optional;
+using absl::optional;
+using absl::string_view;
constexpr char kInterpreter[] = "interpreter";
@@ -85,21 +85,24 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
} // namespace
-HloTestBase::HloTestBase(bool allow_mixed_precision_in_hlo_verifier)
+HloTestBase::HloTestBase(bool verifier_layout_sensitive,
+ bool allow_mixed_precision_in_hlo_verifier)
: HloTestBase(GetTestPlatform(), GetReferencePlatform(),
+ verifier_layout_sensitive,
allow_mixed_precision_in_hlo_verifier) {}
HloTestBase::HloTestBase(se::Platform* test_platform,
se::Platform* reference_platform,
+ bool verifier_layout_sensitive,
bool allow_mixed_precision_in_hlo_verifier)
: test_runner_(test_platform), reference_runner_(reference_platform) {
- hlo_verifier_ =
- MakeUnique<HloVerifier>(allow_mixed_precision_in_hlo_verifier);
+ hlo_verifier_ = absl::make_unique<HloVerifier>(
+ /*layout_sensitive=*/verifier_layout_sensitive,
+ /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier);
}
-/* static */
std::unique_ptr<HloModule> HloTestBase::CreateNewModule(const string& name) {
- return MakeUnique<HloModule>(name, GetModuleConfigForTest());
+ return absl::make_unique<HloModule>(name, GetModuleConfigForTest());
}
/* static */
@@ -117,7 +120,14 @@ StatusOr<bool> HloTestBase::RunHloPass(HloPassInterface* hlo_pass,
return status_or;
}
-/*static*/
+/* static */
+PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) {
+ PrecisionConfig precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfig::DEFAULT);
+ return precision_config;
+}
+
DebugOptions HloTestBase::GetDebugOptionsForTest() {
auto debug_options = legacy_flags::GetDebugOptionsFromFlags();
// TODO(b/38354253): Change tests to use Parameters instead of Constants.
@@ -126,24 +136,21 @@ DebugOptions HloTestBase::GetDebugOptionsForTest() {
return debug_options;
}
-StatusOr<std::unique_ptr<Literal>> HloTestBase::Execute(
- std::unique_ptr<HloModule> module,
- tensorflow::gtl::ArraySlice<Literal*> arguments) {
+StatusOr<Literal> HloTestBase::Execute(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments) {
return test_runner_.Execute(std::move(module), arguments);
}
-std::unique_ptr<Literal> HloTestBase::ExecuteNoHloPasses(
- std::unique_ptr<HloModule> module,
- tensorflow::gtl::ArraySlice<Literal*> arguments) {
+Literal HloTestBase::ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments) {
return test_runner_
.Execute(std::move(module), arguments,
/*run_hlo_passes=*/false)
.ValueOrDie();
}
-std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
- std::unique_ptr<HloModule> module,
- tensorflow::gtl::ArraySlice<Literal*> arguments) {
+Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments) {
return test_runner_.Execute(std::move(module), arguments).ValueOrDie();
}
@@ -166,7 +173,8 @@ StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule(
}
StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
- std::unique_ptr<HloModule> module, const ArraySlice<Literal*> arguments,
+ std::unique_ptr<HloModule> module,
+ const absl::Span<Literal* const> arguments,
const optional<ErrorSpec>& error, bool run_hlo_passes,
const std::function<void(HloModule*)>& reference_preprocessor) {
TF_RETURN_IF_ERROR(hlo_verifier_->Run(module.get()).status());
@@ -180,12 +188,13 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
TF_ASSIGN_OR_RETURN(auto reference,
reference_runner_.Execute(std::move(reference_module),
arguments, run_hlo_passes));
- return LiteralTestUtil::NearOrEqual(/*expected=*/*reference, /*actual=*/*test,
+ return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test,
error);
}
::testing::AssertionResult HloTestBase::RunAndCompare(
- std::unique_ptr<HloModule> module, const ArraySlice<Literal*> arguments,
+ std::unique_ptr<HloModule> module,
+ const absl::Span<Literal* const> arguments,
const optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
auto result =
@@ -198,7 +207,8 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
}
::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
- std::unique_ptr<HloModule> module, const ArraySlice<Literal*> arguments,
+ std::unique_ptr<HloModule> module,
+ const absl::Span<Literal* const> arguments,
const optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
auto result =
@@ -213,13 +223,12 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
::testing::AssertionResult HloTestBase::RunAndCompare(
std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
- const auto& fake_arguments =
- MakeFakeArguments(module.get()).ConsumeValueOrDie();
+ auto fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie();
std::vector<Literal*> fake_argument_ptrs;
- c_transform(
+ absl::c_transform(
fake_arguments, std::back_inserter(fake_argument_ptrs),
- [](const std::unique_ptr<Literal>& literal) { return literal.get(); });
+ [](const Literal& literal) { return const_cast<Literal*>(&literal); });
return RunAndCompare(std::move(module), fake_argument_ptrs, error,
reference_preprocessor);
@@ -231,17 +240,16 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
const auto& fake_arguments =
MakeFakeArguments(module.get()).ConsumeValueOrDie();
std::vector<Literal*> fake_argument_ptrs;
- c_transform(
+ absl::c_transform(
fake_arguments, std::back_inserter(fake_argument_ptrs),
- [](const std::unique_ptr<Literal>& literal) { return literal.get(); });
+ [](const Literal& literal) { return const_cast<Literal*>(&literal); });
return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error,
reference_preprocessor);
}
::testing::AssertionResult HloTestBase::RunAndCompare(
- const StringPiece hlo_string,
- const tensorflow::gtl::optional<ErrorSpec>& error,
+ string_view hlo_string, const absl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
auto module_or_status =
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
@@ -254,7 +262,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
reference_preprocessor);
}
-::testing::AssertionResult HloTestBase::Run(const StringPiece hlo_string) {
+::testing::AssertionResult HloTestBase::Run(string_view hlo_string) {
auto module_or_status =
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
if (!module_or_status.ok()) {
@@ -266,9 +274,9 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
MakeFakeArguments(module_or_status.ValueOrDie().get())
.ConsumeValueOrDie();
std::vector<Literal*> fake_argument_ptrs;
- c_transform(
+ absl::c_transform(
fake_arguments, std::back_inserter(fake_argument_ptrs),
- [](const std::unique_ptr<Literal>& literal) { return literal.get(); });
+ [](const Literal& literal) { return const_cast<Literal*>(&literal); });
return test_runner_
.Execute(std::move(module_or_status.ValueOrDie()),
fake_argument_ptrs, /*run_hlo_passes=*/true)
@@ -278,7 +286,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
}
::testing::AssertionResult HloTestBase::RunAndCompareFromFile(
- const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error,
+ const string& filename, const absl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
auto module_or_status =
HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest());
@@ -291,8 +299,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
}
::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses(
- const StringPiece hlo_string,
- const tensorflow::gtl::optional<ErrorSpec>& error,
+ string_view hlo_string, const absl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
auto module_or_status =
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
@@ -306,7 +313,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
}
::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile(
- const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error,
+ const string& filename, const absl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor) {
auto module_or_status =
HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest());
@@ -319,10 +326,10 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
}
HloComputation* HloTestBase::FindComputation(HloModule* module,
- tensorflow::StringPiece name) {
+ absl::string_view name) {
auto computations = module->computations();
- auto it = c_find_if(computations,
- [&](HloComputation* c) { return c->name() == name; });
+ auto it = absl::c_find_if(
+ computations, [&](HloComputation* c) { return c->name() == name; });
if (it == computations.end()) {
return nullptr;
}
@@ -330,11 +337,11 @@ HloComputation* HloTestBase::FindComputation(HloModule* module,
}
HloInstruction* HloTestBase::FindInstruction(HloModule* module,
- tensorflow::StringPiece name) {
+ absl::string_view name) {
for (const HloComputation* c : module->computations()) {
auto instructions = c->instructions();
- auto it = c_find_if(instructions,
- [&](HloInstruction* i) { return i->name() == name; });
+ auto it = absl::c_find_if(
+ instructions, [&](HloInstruction* i) { return i->name() == name; });
if (it != instructions.end()) {
return *it;
}
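FindComputation and FindInstruction above move from XLA's local c_find_if helper to absl::c_find_if; the absl container algorithms are thin wrappers that apply the corresponding std:: algorithm to a whole range. A standalone sketch of the same lookup shape:

    #include <string>
    #include <vector>
    #include "absl/algorithm/container.h"

    int main() {
      std::vector<std::string> names = {"add", "mul", "iota"};
      // Equivalent to std::find_if(names.begin(), names.end(), pred).
      auto it = absl::c_find_if(
          names, [](const std::string& n) { return n == "iota"; });
      return it != names.end() ? 0 : 1;
    }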
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index c860c416f1..0ae4bdc104 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -20,6 +20,8 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -31,8 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"
@@ -72,8 +72,7 @@ class HloTestBase : public ::testing::Test {
// options from command-line flags. If you want a fresh HloModule object to
// which you can then add HloComputations, it's recommended to use this method
// in your tests.
- static std::unique_ptr<HloModule> CreateNewModule(
- const string& name = TestName());
+ std::unique_ptr<HloModule> CreateNewModule(const string& name = TestName());
// Runs the hlo_pass with the provided module and returns the result. This
// function also verifies that the module remains unchanged when hlo_pass
@@ -81,17 +80,21 @@ class HloTestBase : public ::testing::Test {
static StatusOr<bool> RunHloPass(HloPassInterface* hlo_pass,
HloModule* module);
+ static PrecisionConfig DefaultPrecisionConfig(int operands);
+
protected:
// This uses the interpreter backend as the reference backend and
// automatically finds another supported backend as the test backend. If the
// interpreter is the only supported backend, it will be both the test backend
// and the reference backend.
- HloTestBase(bool allow_mixed_precision_in_hlo_verifier = true);
+ HloTestBase(bool verifier_layout_sensitive = false,
+ bool allow_mixed_precision_in_hlo_verifier = true);
// If your test doesn't use interpreter as the reference backend, you can use
// this constructor. Note that your test target is responsible for linking in
// both needed backends.
HloTestBase(se::Platform* test_platform, se::Platform* reference_platform,
+ bool verifier_layout_sensitive = false,
bool allow_mixed_precision_in_hlo_verifier = true);
~HloTestBase() override {}
@@ -99,29 +102,29 @@ class HloTestBase : public ::testing::Test {
// Populates debug options from command-line flags and adjusts the options for
// testing. It is recommended to use this when you need to pass in
// DebugOptions, e.g. when creating a module from a string or a file.
- static DebugOptions GetDebugOptionsForTest();
+ //
+ // This function is virtual so tests can specify an alternative set of debug
+ // options (e.g. disabling additional passes).
+ virtual DebugOptions GetDebugOptionsForTest();
// Gets an HloModuleConfig with options appropriate for tests.
- static HloModuleConfig GetModuleConfigForTest() {
+ HloModuleConfig GetModuleConfigForTest() {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
return config;
}
// Executes the given module and returns the result as a Literal.
- StatusOr<std::unique_ptr<Literal>> Execute(
- std::unique_ptr<HloModule> module,
- tensorflow::gtl::ArraySlice<Literal*> arguments);
+ StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments);
// Same as above, except the module will be executed without running any HLO
// passes on it.
- std::unique_ptr<Literal> ExecuteNoHloPasses(
- std::unique_ptr<HloModule> module,
- tensorflow::gtl::ArraySlice<Literal*> arguments);
+ Literal ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments);
- std::unique_ptr<Literal> ExecuteAndTransfer(
- std::unique_ptr<HloModule> module,
- tensorflow::gtl::ArraySlice<Literal*> arguments);
+ Literal ExecuteAndTransfer(std::unique_ptr<HloModule> module,
+ absl::Span<Literal* const> arguments);
// Executes the given hlo module on two backends and compares results.
//
@@ -136,8 +139,8 @@ class HloTestBase : public ::testing::Test {
// modified.
::testing::AssertionResult RunAndCompare(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<Literal*> arguments,
- const tensorflow::gtl::optional<ErrorSpec>& error,
+ const absl::Span<Literal* const> arguments,
+ const absl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
TF_MUST_USE_RESULT;
@@ -145,23 +148,21 @@ class HloTestBase : public ::testing::Test {
// optimization.
::testing::AssertionResult RunAndCompareNoHloPasses(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<Literal*> arguments,
- const tensorflow::gtl::optional<ErrorSpec>& error,
+ const absl::Span<Literal* const> arguments,
+ const absl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
TF_MUST_USE_RESULT;
// Executes an hlo module with fake inputs and compares the results.
::testing::AssertionResult RunAndCompare(
- std::unique_ptr<HloModule> module,
- const tensorflow::gtl::optional<ErrorSpec>& error,
+ std::unique_ptr<HloModule> module, const absl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
TF_MUST_USE_RESULT;
// Same as above, except that the module will be executed without HLO
// optimization.
::testing::AssertionResult RunAndCompareNoHloPasses(
- std::unique_ptr<HloModule> module,
- const tensorflow::gtl::optional<ErrorSpec>& error,
+ std::unique_ptr<HloModule> module, const absl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
TF_MUST_USE_RESULT;
@@ -169,23 +170,23 @@ class HloTestBase : public ::testing::Test {
// input. Module can be passed in directly, or parsed from an hlo_string,
// or loaded from a file.
::testing::AssertionResult RunAndCompare(
- const tensorflow::StringPiece hlo_string,
- const tensorflow::gtl::optional<ErrorSpec>& error,
+ const absl::string_view hlo_string,
+ const absl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
TF_MUST_USE_RESULT;
- ::testing::AssertionResult Run(const tensorflow::StringPiece hlo_string)
+ ::testing::AssertionResult Run(const absl::string_view hlo_string)
TF_MUST_USE_RESULT;
::testing::AssertionResult RunAndCompareFromFile(
- const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error,
+ const string& filename, const absl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
TF_MUST_USE_RESULT;
::testing::AssertionResult RunAndCompareNoHloPasses(
- const tensorflow::StringPiece hlo_string,
- const tensorflow::gtl::optional<ErrorSpec>& error,
+ const absl::string_view hlo_string,
+ const absl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
TF_MUST_USE_RESULT;
::testing::AssertionResult RunAndCompareNoHloPassesFromFile(
- const string& filename, const tensorflow::gtl::optional<ErrorSpec>& error,
+ const string& filename, const absl::optional<ErrorSpec>& error,
const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
TF_MUST_USE_RESULT;
@@ -228,10 +229,8 @@ class HloTestBase : public ::testing::Test {
//
// This is useful for tests which create HLOs from a string and then want to
// inspect a particular computation or instruction.
- HloComputation* FindComputation(HloModule* module,
- tensorflow::StringPiece name);
- HloInstruction* FindInstruction(HloModule* module,
- tensorflow::StringPiece name);
+ HloComputation* FindComputation(HloModule* module, absl::string_view name);
+ HloInstruction* FindInstruction(HloModule* module, absl::string_view name);
// Returns an HLO verifier constructed for the test backend.
HloVerifier& verifier() const { return *hlo_verifier_; }
@@ -261,8 +260,8 @@ class HloTestBase : public ::testing::Test {
// error happens before the results are computed, returns the error status.
StatusOr<::testing::AssertionResult> RunAndCompareInternal(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<Literal*> arguments,
- const tensorflow::gtl::optional<ErrorSpec>& error, bool run_hlo_passes,
+ const absl::Span<Literal* const> arguments,
+ const absl::optional<ErrorSpec>& error, bool run_hlo_passes,
const std::function<void(HloModule*)>& reference_preprocessor);
};
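With the two constructor parameters added above, a fixture opts in to layout-sensitive or mixed-precision verification instead of building its own HloVerifier. A hypothetical derived fixture (the class name is ours) showing the intended usage:

    class LayoutSensitiveHloTest : public HloTestBase {
     protected:
      LayoutSensitiveHloTest()
          : HloTestBase(/*verifier_layout_sensitive=*/true,
                        /*allow_mixed_precision_in_hlo_verifier=*/false) {}
    };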
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
index ad1f5b9eed..8f86c528d0 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -24,8 +25,11 @@ limitations under the License.
namespace xla {
-HloVerifiedTestBase::HloVerifiedTestBase()
- : shape_verifier_(MakeUnique<ShapeVerifier>()) {}
+HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive,
+ bool allow_mixed_precision)
+ : HloTestBase(
+ /*verifier_layout_sensitive=*/layout_sensitive,
+ /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision) {}
HloVerifiedTestBase::~HloVerifiedTestBase() {
// We can't call the ASSERT or EXPECT test macros in destructors, so we
@@ -50,8 +54,7 @@ void HloVerifiedTestBase::TearDown() {
}
void HloVerifiedTestBase::VerifyModule(HloModule* module) {
- HloVerifier verifier(/*allow_mixed_precision=*/true);
- xla::StatusOr<bool> mutated = verifier.Run(module);
+ xla::StatusOr<bool> mutated = verifier().Run(module);
if (!mutated.ok()) {
ADD_FAILURE() << "HloVerifier failed: " << mutated.status();
} else {
@@ -72,7 +75,7 @@ HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) {
return modules_.back().get();
}
-void HloVerifiedTestBase::ParseAndVerifyModule(tensorflow::StringPiece hlo_text,
+void HloVerifiedTestBase::ParseAndVerifyModule(absl::string_view hlo_text,
const HloModuleConfig& config) {
CHECK(!module_) << "Called ParseAndVerifyModule when test already has a module.";
TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text, config));
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
index 5b28c01c36..8fbc4fa753 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
@@ -29,7 +29,8 @@ namespace xla {
// performs verification on that module on tear-down.
class HloVerifiedTestBase : public HloTestBase {
protected:
- HloVerifiedTestBase();
+ explicit HloVerifiedTestBase(bool layout_sensitive = false,
+ bool allow_mixed_precision = false);
~HloVerifiedTestBase() override;
// Constructs a default shape verifier.
@@ -44,32 +45,28 @@ class HloVerifiedTestBase : public HloTestBase {
// Returns the default HloModule, lazily creating it if necessary via
// HloTestBase::CreateNewModule().
HloModule& module();
- void ParseAndVerifyModule(tensorflow::StringPiece hlo_text,
+ void ParseAndVerifyModule(absl::string_view hlo_text,
const HloModuleConfig& config = HloModuleConfig());
- // Sets the shape-size function used during hlo verification. If this isn't
- // called, a default ShapeVerifier is used instead.
- void SetShapeVerifier(std::unique_ptr<ShapeVerifier> shape_verifier) {
- shape_verifier_ = std::move(shape_verifier);
- }
-
// Creates a new module for a test, and stores it in modules_ so it can be
// verified. Intentionally hides HloTestBase::CreateNewModule, to prevent
// creation of unverified modules.
HloModule* CreateNewModule(const string& name = TestName());
+ private:
+ void VerifyModule(HloModule* module);
+
// It is confusing to store modules created by module() and CreateNewModule()
// in different fields, but it allows us to migrate tests to
// HloVerifiedTestBase more easily, so it's a win because we can verify more
// modules. See b/80488902.
- private:
+ //
// Lazily populated. Access via module().
std::unique_ptr<HloModule> module_;
// Populated by calls to CreateNewModule.
std::vector<std::unique_ptr<HloModule>> modules_;
- std::unique_ptr<ShapeVerifier> shape_verifier_;
+
bool tear_down_called_ = false;
- static void VerifyModule(HloModule* module);
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/iota_test.cc b/tensorflow/compiler/xla/tests/iota_test.cc
index 17ac95ae01..310f349592 100644
--- a/tensorflow/compiler/xla/tests/iota_test.cc
+++ b/tensorflow/compiler/xla/tests/iota_test.cc
@@ -23,40 +23,95 @@ limitations under the License.
namespace xla {
namespace {
-class IotaTest : public ClientLibraryTestBase {
- public:
- explicit IotaTest(se::Platform* platform = nullptr)
- : ClientLibraryTestBase(platform) {}
- template <typename T>
- std::vector<T> GetExpected(const int64 num_elements) {
- std::vector<T> result(num_elements);
- std::iota(result.begin(), result.end(), 0);
- return result;
+template <typename T>
+std::vector<T> GetR1Expected(const int64 num_elements) {
+ std::vector<T> result(num_elements);
+ std::iota(result.begin(), result.end(), 0);
+ return result;
+}
+
+class IotaR1Test
+ : public ClientLibraryTestBase,
+ public ::testing::WithParamInterface<std::tuple<PrimitiveType, int>> {};
+
+TEST_P(IotaR1Test, DoIt) {
+ const auto& spec = GetParam();
+ const auto element_type = std::get<0>(spec);
+ const int64 num_elements = std::get<1>(spec);
+ XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type));
+ Iota(&builder, element_type, num_elements);
+ if (element_type == F32) {
+ ComputeAndCompareR1<float>(&builder, GetR1Expected<float>(num_elements), {},
+ ErrorSpec{0.0001});
+ } else if (element_type == U32) {
+ ComputeAndCompareR1<uint32>(&builder, GetR1Expected<uint32>(num_elements),
+ {});
+ } else {
+ CHECK_EQ(element_type, S32);
+ ComputeAndCompareR1<int32>(&builder, GetR1Expected<int32>(num_elements),
+ {});
}
-};
-
-XLA_TEST_F(IotaTest, SimpleR1) {
- for (int num_elements = 1; num_elements < 10000001; num_elements *= 10) {
- {
- XlaBuilder builder(TestName() + "_f32");
- IotaGen(&builder, F32, num_elements);
- ComputeAndCompareR1<float>(&builder, GetExpected<float>(num_elements), {},
- ErrorSpec{0.0001});
- }
- {
- XlaBuilder builder(TestName() + "_u32");
- IotaGen(&builder, U32, num_elements);
- ComputeAndCompareR1<uint32>(&builder, GetExpected<uint32>(num_elements),
- {});
- }
- {
- XlaBuilder builder(TestName() + "_s32");
- IotaGen(&builder, S32, num_elements);
- ComputeAndCompareR1<int32>(&builder, GetExpected<int32>(num_elements),
- {});
- }
+}
+
+INSTANTIATE_TEST_CASE_P(IotaR1TestInstantiation, IotaR1Test,
+ ::testing::Combine(::testing::Values(F32, U32, S32),
+ ::testing::Range(/*start=*/10,
+ /*end=*/10001,
+ /*step=*/10)));
+
+class IotaR2Test : public ClientLibraryTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<PrimitiveType, int, int>> {};
+
+TEST_P(IotaR2Test, DoIt) {
+ const auto& spec = GetParam();
+ const auto element_type = std::get<0>(spec);
+ const int64 num_elements = std::get<1>(spec);
+ const int64 iota_dim = std::get<2>(spec);
+ XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type));
+ std::vector<int64> dimensions = {42};
+ dimensions.insert(dimensions.begin() + iota_dim, num_elements);
+ Iota(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim);
+ if (primitive_util::IsFloatingPointType(element_type)) {
+ ComputeAndCompare(&builder, {}, ErrorSpec{0.0001});
+ } else {
+ ComputeAndCompare(&builder, {});
}
}
+INSTANTIATE_TEST_CASE_P(IotaR2TestInstantiation, IotaR2Test,
+ ::testing::Combine(::testing::Values(F32, S32),
+ ::testing::Range(/*start=*/10,
+ /*end=*/1001,
+ /*step=*/10),
+ ::testing::Values(0, 1)));
+
+class IotaR3Test : public ClientLibraryTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<PrimitiveType, int, int>> {};
+
+TEST_P(IotaR3Test, DoIt) {
+ const auto& spec = GetParam();
+ const auto element_type = std::get<0>(spec);
+ const int64 num_elements = std::get<1>(spec);
+ const int64 iota_dim = std::get<2>(spec);
+ XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type));
+ std::vector<int64> dimensions = {42, 19};
+ dimensions.insert(dimensions.begin() + iota_dim, num_elements);
+ Iota(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim);
+ if (primitive_util::IsFloatingPointType(element_type)) {
+ ComputeAndCompare(&builder, {}, ErrorSpec{0.0001});
+ } else {
+ ComputeAndCompare(&builder, {});
+ }
+}
+
+INSTANTIATE_TEST_CASE_P(IotaR3TestInstantiation, IotaR3Test,
+ ::testing::Combine(::testing::Values(F32, S32),
+ ::testing::Range(/*start=*/10,
+ /*end=*/1001,
+ /*step=*/10),
+ ::testing::Values(0, 1, 2)));
+
} // namespace
} // namespace xla
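The iota rewrite above trades one test with hand-rolled loops for gtest value-parameterized tests, so every (type, size, dimension) combination runs, reports, and shards as an independent test case. A self-contained sketch of the Combine/Range pattern, using the same INSTANTIATE_TEST_CASE_P macro as the code of this era:

    #include <tuple>
    #include "gtest/gtest.h"

    class SizeTest : public ::testing::TestWithParam<std::tuple<int, int>> {};

    TEST_P(SizeTest, ProductIsPositive) {
      const int rows = std::get<0>(GetParam());
      const int cols = std::get<1>(GetParam());
      EXPECT_GT(rows * cols, 0);
    }

    // Instantiates rows in {1, 2} crossed with cols in {10, 20, ..., 90}.
    INSTANTIATE_TEST_CASE_P(AllSizes, SizeTest,
                            ::testing::Combine(::testing::Values(1, 2),
                                               ::testing::Range(10, 100, 10)));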
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc
index cde1dcd9cd..554eb24d44 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util.cc
@@ -15,9 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/literal_comparison.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
@@ -35,8 +35,7 @@ void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) {
int64 now_usec = tensorflow::Env::Default()->NowMicros();
string filename = tensorflow::io::JoinPath(
tensorflow::testing::TmpDir(),
- tensorflow::strings::Printf("tempfile-%s-%llx-%s", get_hostname().c_str(),
- now_usec, name.c_str()));
+ absl::StrFormat("tempfile-%s-%x-%s", get_hostname(), now_usec, name));
TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), filename,
literal.ToProto()));
LOG(ERROR) << "wrote to " << name << " file: " << filename;
@@ -94,7 +93,7 @@ void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual,
/* static */ ::testing::AssertionResult LiteralTestUtil::NearOrEqual(
const LiteralSlice& expected, const LiteralSlice& actual,
- const tensorflow::gtl::optional<ErrorSpec>& error) {
+ const absl::optional<ErrorSpec>& error) {
if (error.has_value()) {
VLOG(1) << "Expects near";
return StatusToAssertion(literal_comparison::Near(
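In the tempfile hunk above, note that %llx becomes %x and the c_str() calls disappear: absl::StrFormat is type-safe, deducing integer width from the argument itself and binding std::string to %s directly. A standalone sketch (values are illustrative):

    #include <cstdint>
    #include <iostream>
    #include <string>
    #include "absl/strings/str_format.h"

    int main() {
      const std::string host = "testhost";
      const int64_t now_usec = 1536690000123456;
      // %x prints any integer type in hex; no length modifier needed.
      std::string filename =
          absl::StrFormat("tempfile-%s-%x-%s", host, now_usec, "expected");
      std::cout << filename << "\n";
      return 0;
    }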
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h
index 31a099c15f..43cca91f64 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.h
+++ b/tensorflow/compiler/xla/tests/literal_test_util.h
@@ -21,6 +21,8 @@ limitations under the License.
#include <random>
#include <string>
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -32,8 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -62,7 +62,7 @@ class LiteralTestUtil {
static void ExpectR0Equal(NativeT expected, const LiteralSlice& actual);
template <typename NativeT>
- static void ExpectR1Equal(tensorflow::gtl::ArraySlice<NativeT> expected,
+ static void ExpectR1Equal(absl::Span<const NativeT> expected,
const LiteralSlice& actual);
template <typename NativeT>
static void ExpectR2Equal(
@@ -102,7 +102,7 @@ class LiteralTestUtil {
const ErrorSpec& error);
template <typename NativeT>
- static void ExpectR1Near(tensorflow::gtl::ArraySlice<NativeT> expected,
+ static void ExpectR1Near(absl::Span<const NativeT> expected,
const LiteralSlice& actual, const ErrorSpec& error);
template <typename NativeT>
@@ -146,7 +146,7 @@ class LiteralTestUtil {
// will be compared recursively.
static ::testing::AssertionResult NearOrEqual(
const LiteralSlice& expected, const LiteralSlice& actual,
- const tensorflow::gtl::optional<ErrorSpec>& error) TF_MUST_USE_RESULT;
+ const absl::optional<ErrorSpec>& error) TF_MUST_USE_RESULT;
private:
TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil);
@@ -155,20 +155,20 @@ class LiteralTestUtil {
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected,
const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR0<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR0<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR1Equal(
- tensorflow::gtl::ArraySlice<NativeT> expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR1<NativeT>(expected), actual));
+ absl::Span<const NativeT> expected, const LiteralSlice& actual) {
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR1<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2Equal(
std::initializer_list<std::initializer_list<NativeT>> expected,
const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR2<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR2<NativeT>(expected), actual));
}
template <typename NativeT>
@@ -176,46 +176,46 @@ template <typename NativeT>
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
expected,
const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR3<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR3<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2EqualArray2D(
const Array2D<NativeT>& expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR2FromArray2D(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR2FromArray2D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3EqualArray3D(
const Array3D<NativeT>& expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR3FromArray3D(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR3FromArray3D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR4EqualArray4D(
const Array4D<NativeT>& expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*LiteralUtil::CreateR4FromArray4D(expected), actual));
+ EXPECT_TRUE(Equal(LiteralUtil::CreateR4FromArray4D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected,
const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR0<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR0<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR1Near(
- tensorflow::gtl::ArraySlice<NativeT> expected, const LiteralSlice& actual,
+ absl::Span<const NativeT> expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR1<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR1<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2Near(
std::initializer_list<std::initializer_list<NativeT>> expected,
const LiteralSlice& actual, const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR2<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR2<NativeT>(expected), actual, error));
}
template <typename NativeT>
@@ -223,7 +223,7 @@ template <typename NativeT>
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
expected,
const LiteralSlice& actual, const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR3<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR3<NativeT>(expected), actual, error));
}
template <typename NativeT>
@@ -232,28 +232,28 @@ template <typename NativeT>
std::initializer_list<std::initializer_list<NativeT>>>>
expected,
const LiteralSlice& actual, const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR4<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR4<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2NearArray2D(
const Array2D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR2FromArray2D(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR2FromArray2D(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3NearArray3D(
const Array3D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR3FromArray3D(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR3FromArray3D(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR4NearArray4D(
const Array4D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*LiteralUtil::CreateR4FromArray4D(expected), actual, error));
+ EXPECT_TRUE(Near(LiteralUtil::CreateR4FromArray4D(expected), actual, error));
}
} // namespace xla
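A subtlety in the template bodies above: the freshly created Literal values bind to the LiteralSlice parameters of Equal/Near as temporaries, which is safe because a LiteralSlice is a non-owning view and the temporary lives until the end of the full expression. A hedged sketch of the rule:

    // Safe: the Literal temporary outlives the Equal(...) call.
    EXPECT_TRUE(LiteralTestUtil::Equal(
        LiteralUtil::CreateR0<float>(2.0f), actual));

    // Unsafe: a stored view would dangle once the statement ends.
    //   LiteralSlice slice = LiteralUtil::CreateR0<float>(2.0f);
    //   EXPECT_EQ("2", slice.ToString());  // reads freed memory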
diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
index f297b2b847..b6f9b8156b 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
@@ -20,9 +20,9 @@ limitations under the License.
#include <vector>
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -31,11 +31,11 @@ namespace xla {
namespace {
TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) {
- std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple({
- LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR0<int32>(64).get(),
+ Literal literal = LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::CreateR0<int32>(42),
+ LiteralUtil::CreateR0<int32>(64),
});
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, literal));
}
TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
@@ -43,15 +43,15 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
// un-fail an assertion failure. The CHECK-failure is death, so we can make a
// death assertion.
auto unequal_things_are_equal = [] {
- std::unique_ptr<Literal> lhs = LiteralUtil::MakeTuple({
- LiteralUtil::CreateR0<int32>(42).get(),
- LiteralUtil::CreateR0<int32>(64).get(),
+ Literal lhs = LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::CreateR0<int32>(42),
+ LiteralUtil::CreateR0<int32>(64),
});
- std::unique_ptr<Literal> rhs = LiteralUtil::MakeTuple({
- LiteralUtil::CreateR0<int32>(64).get(),
- LiteralUtil::CreateR0<int32>(42).get(),
+ Literal rhs = LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::CreateR0<int32>(64),
+ LiteralUtil::CreateR0<int32>(42),
});
- CHECK(LiteralTestUtil::Equal(*lhs, *rhs)) << "LHS and RHS are unequal";
+ CHECK(LiteralTestUtil::Equal(lhs, rhs)) << "LHS and RHS are unequal";
};
ASSERT_DEATH(unequal_things_are_equal(), "LHS and RHS are unequal");
}
@@ -61,7 +61,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
auto two = LiteralUtil::CreateR0<float>(2);
auto four = LiteralUtil::CreateR0<float>(4);
ErrorSpec error(0.001);
- CHECK(LiteralTestUtil::Near(*two, *four, error)) << "two is not near four";
+ CHECK(LiteralTestUtil::Near(two, four, error)) << "two is not near four";
};
tensorflow::Env* env = tensorflow::Env::Default();
@@ -80,20 +80,20 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
std::vector<string> results;
TF_CHECK_OK(env->GetMatchingPaths(pattern, &results));
- LOG(INFO) << "results: [" << tensorflow::str_util::Join(results, ", ") << "]";
+ LOG(INFO) << "results: [" << absl::StrJoin(results, ", ") << "]";
EXPECT_EQ(3, results.size());
for (const string& result : results) {
LiteralProto literal_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result,
&literal_proto));
- std::unique_ptr<Literal> literal =
+ Literal literal =
Literal::CreateFromProto(literal_proto).ConsumeValueOrDie();
if (result.find("expected") != string::npos) {
- EXPECT_EQ("2", literal->ToString());
+ EXPECT_EQ("2", literal.ToString());
} else if (result.find("actual") != string::npos) {
- EXPECT_EQ("4", literal->ToString());
+ EXPECT_EQ("4", literal.ToString());
} else if (result.find("mismatches") != string::npos) {
- EXPECT_EQ("true", literal->ToString());
+ EXPECT_EQ("true", literal.ToString());
} else {
FAIL() << "unknown file in temporary directory: " << result;
}
@@ -103,10 +103,11 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) {
auto expected = LiteralUtil::CreateR1<int32>({1, 2, 3});
auto actual = LiteralUtil::CreateR1<int32>({4, 5, 6});
- ::testing::AssertionResult result =
- LiteralTestUtil::Equal(*expected, *actual);
- EXPECT_THAT(result.message(), ::testing::HasSubstr("expected: {1, 2, 3}"));
- EXPECT_THAT(result.message(), ::testing::HasSubstr("actual: {4, 5, 6}"));
+ ::testing::AssertionResult result = LiteralTestUtil::Equal(expected, actual);
+ EXPECT_THAT(result.message(),
+ ::testing::HasSubstr("Expected literal:\n{1, 2, 3}"));
+ EXPECT_THAT(result.message(),
+ ::testing::HasSubstr("Actual literal:\n{4, 5, 6}"));
}
TEST(LiteralTestUtilTest, NearComparatorR1) {
@@ -114,7 +115,7 @@ TEST(LiteralTestUtilTest, NearComparatorR1) {
{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
auto b = LiteralUtil::CreateR1<float>(
{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
- EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
+ EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
}
TEST(LiteralTestUtilTest, NearComparatorR1Nan) {
@@ -122,7 +123,7 @@ TEST(LiteralTestUtilTest, NearComparatorR1Nan) {
{0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
auto b = LiteralUtil::CreateR1<float>(
{0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
- EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
+ EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
}
TEST(LiteralTestUtil, NearComparatorDifferentLengths) {
@@ -130,8 +131,8 @@ TEST(LiteralTestUtil, NearComparatorDifferentLengths) {
{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
auto b =
LiteralUtil::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7});
- EXPECT_FALSE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
- EXPECT_FALSE(LiteralTestUtil::Near(*b, *a, ErrorSpec{0.0001}));
+ EXPECT_FALSE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
+ EXPECT_FALSE(LiteralTestUtil::Near(b, a, ErrorSpec{0.0001}));
}
} // namespace
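These test updates show the tuple half of the Literal migration: MakeTuple took raw Literal pointers (hence the .get() calls), while MakeTupleFromSlices takes LiteralSlice views and copies them into a new tuple Literal, so value Literals can be passed directly. A minimal sketch, assuming the XLA headers these tests already include:

    Literal MakeExampleTuple() {
      // Each element converts implicitly to LiteralSlice; the resulting
      // tuple owns copies of the element data.
      return LiteralUtil::MakeTupleFromSlices(
          {LiteralUtil::CreateR0<int32>(42),
           LiteralUtil::CreateR1<float>({1.0f, 2.0f})});
    }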
diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
index e719da54d4..8d65869557 100644
--- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
+++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
@@ -125,7 +126,7 @@ class LLVMCompilerTest : public ::testing::Test {
static std::unique_ptr<HloModule> CreateNewModule() {
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
- return MakeUnique<HloModule>(TestName(), config);
+ return absl::make_unique<HloModule>(TestName(), config);
}
};
diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc
index 6fc1115097..0487d31409 100644
--- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc
+++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc
@@ -51,8 +51,9 @@ void LlvmIrGenTestBase::CompileAndVerifyIr(
std::unique_ptr<HloModule> hlo_module, const string& pattern,
bool match_optimized_ir) {
SetIrHook(match_optimized_ir);
- TF_ASSERT_OK(CompileToExecutable(std::move(hlo_module)).status());
+ Status status = CompileToExecutable(std::move(hlo_module)).status();
ResetIrHook();
+ TF_ASSERT_OK(status);
StatusOr<bool> filecheck_result = RunFileCheck(ir_, pattern);
TF_ASSERT_OK(filecheck_result.status());
@@ -73,9 +74,10 @@ void LlvmIrGenTestBase::CompileAheadOfTimeAndVerifyIr(
std::unique_ptr<HloModule> hlo_module, const AotCompilationOptions& options,
const string& pattern, bool match_optimized_ir) {
SetIrHook(match_optimized_ir);
- TF_ASSERT_OK(
- CompileToAotCompilationResult(std::move(hlo_module), options).status());
+ Status status =
+ CompileToAotCompilationResult(std::move(hlo_module), options).status();
ResetIrHook();
+ TF_ASSERT_OK(status);
StatusOr<bool> filecheck_result = RunFileCheck(ir_, pattern);
ASSERT_TRUE(filecheck_result.ok());
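The two llvm_irgen_test_base hunks above are a behavior fix, not just a restructuring: TF_ASSERT_OK returns from the enclosing function on failure, so asserting before ResetIrHook() could leave the IR hook installed for later tests. Capturing the Status first guarantees the hook is reset on every path; the generic pattern (DoWork and Cleanup are placeholders):

    // Anti-pattern: ASSERT-style macros return early on failure.
    //   TF_ASSERT_OK(DoWork());
    //   Cleanup();               // skipped if DoWork() failed
    //
    // Fixed pattern: run, clean up unconditionally, then assert.
    Status status = DoWork();
    Cleanup();
    TF_ASSERT_OK(status);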
diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
index e2cd5bcc5a..dbdd20daf0 100644
--- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include <memory>
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
@@ -24,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -45,7 +45,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) {
TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform());
auto x_array =
- LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
int64 allocation_count_before = allocator_->allocation_count();
@@ -53,12 +53,12 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) {
// deallocation happen on the right allocator.
ExecutableRunOptions options;
options.set_allocator(allocator);
- tensorflow::gtl::optional<ScopedShapedBuffer> result =
+ absl::optional<ScopedShapedBuffer> result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {},
DefaultExecutableBuildOptions(), options);
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_);
+ {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(*result), error_spec_);
// At least one allocation should have been performed when executing the
// computation.
@@ -92,7 +92,7 @@ XLA_TEST_F(LocalClientAllocationTest, RunOnDevices) {
computation, {}, ExecutableBuildOptions().set_device_ordinal(d),
ExecutableRunOptions().set_device_ordinal(d).set_allocator(allocator));
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);
+ {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
// At least one allocation should have been performed when executing the
// computation.
diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
index 1a823cf189..a99b43f469 100644
--- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
@@ -58,7 +58,7 @@ XLA_TEST_F(LocalClientExecuteTest, Constant) {
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
- LiteralTestUtil::ExpectR0Near<float>(123.f, *ShapedBufferToLiteral(result),
+ LiteralTestUtil::ExpectR0Near<float>(123.f, ShapedBufferToLiteral(result),
error_spec_);
}
@@ -68,10 +68,10 @@ XLA_TEST_F(LocalClientExecuteTest, AddScalars) {
auto y = ConstantR0<float>(&builder, 123.0f);
Add(x, y);
- auto x_value = LiteralToShapedBuffer(*LiteralUtil::CreateR0<float>(42.0f));
+ auto x_value = LiteralToShapedBuffer(LiteralUtil::CreateR0<float>(42.0f));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_value});
- LiteralTestUtil::ExpectR0Near<float>(165.f, *ShapedBufferToLiteral(result),
+ LiteralTestUtil::ExpectR0Near<float>(165.f, ShapedBufferToLiteral(result),
error_spec_);
}
@@ -81,10 +81,10 @@ XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) {
auto y = ConstantR1<float>(&builder, {});
Add(x, y);
- auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({}));
+ auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array});
- LiteralTestUtil::ExpectR1Near<float>({}, *ShapedBufferToLiteral(result),
+ LiteralTestUtil::ExpectR1Near<float>({}, ShapedBufferToLiteral(result),
error_spec_);
}
@@ -95,11 +95,11 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectors) {
Add(x, y);
auto x_array =
- LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array});
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);
+ {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
}
XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) {
@@ -109,14 +109,14 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) {
Add(x, y);
auto x_array =
- LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
ExecutionProfile profile;
ScopedShapedBuffer result = ExecuteLocallyOrDie(
builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions(),
DefaultExecutableRunOptions().set_execution_profile(&profile));
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);
+ {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
EXPECT_GT(profile.compute_and_transfer_time_ns(), 0);
}
@@ -128,13 +128,13 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
auto computation = builder.Build().ConsumeValueOrDie();
// Create x as a col-major array.
- auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout(
+ auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})));
EXPECT_TRUE(LayoutUtil::Equal(x_array.on_device_shape().layout(),
LayoutUtil::MakeLayout({0, 1})));
// Create y as a row-major array.
- auto y_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout(
+ auto y_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout(
{{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0})));
EXPECT_TRUE(LayoutUtil::Equal(y_array.on_device_shape().layout(),
LayoutUtil::MakeLayout({1, 0})));
@@ -142,15 +142,15 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
ScopedShapedBuffer result_colmaj =
ExecuteLocallyOrDie(computation, {&x_array, &y_array});
LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
- *ShapedBufferToLiteral(result_colmaj),
+ ShapedBufferToLiteral(result_colmaj),
error_spec_);
// Run with the parameter values in a different order.
ScopedShapedBuffer result_param_swap =
ExecuteLocallyOrDie(computation, {&y_array, &x_array});
- LiteralTestUtil::ExpectR2Near<float>(
- {{11.0f, 22.0f}, {33.0f, 44.0f}},
- *ShapedBufferToLiteral(result_param_swap), error_spec_);
+ LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
+ ShapedBufferToLiteral(result_param_swap),
+ error_spec_);
}
XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
@@ -161,9 +161,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
auto computation = builder.Build().ConsumeValueOrDie();
auto x_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
auto y_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+ LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
// Run with col-major result layout.
ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie(
@@ -174,7 +174,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
EXPECT_TRUE(LayoutUtil::Equal(result_colmaj.on_device_shape().layout(),
LayoutUtil::MakeLayout({0, 1})));
LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
- *ShapedBufferToLiteral(result_colmaj),
+ ShapedBufferToLiteral(result_colmaj),
error_spec_);
// Run with row-major result layout.
@@ -186,7 +186,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj.on_device_shape().layout(),
LayoutUtil::MakeLayout({1, 0})));
LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
- *ShapedBufferToLiteral(result_rowmaj),
+ ShapedBufferToLiteral(result_rowmaj),
error_spec_);
}
@@ -198,9 +198,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
auto computation = builder.Build().ConsumeValueOrDie();
auto x_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
auto y_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+ LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(computation, {&x_array, &y_array});
@@ -208,13 +208,13 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape()));
EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape()));
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {0}));
+ LiteralSlice(result_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
- LiteralSlice(*result_literal, {1}));
+ LiteralSlice(result_literal, {1}));
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {2}));
+ LiteralSlice(result_literal, {2}));
}
XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
@@ -226,9 +226,9 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
auto computation = builder.Build().ConsumeValueOrDie();
auto x_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
auto y_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+ LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(computation, {&x_array, &y_array});
@@ -236,15 +236,15 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape()));
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape()));
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {1}));
+ LiteralSlice(result_literal, {1}));
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {0, 0}));
+ LiteralSlice(result_literal, {0, 0}));
LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
- LiteralSlice(*result_literal, {0, 1}));
+ LiteralSlice(result_literal, {0, 1}));
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {0, 2}));
+ LiteralSlice(result_literal, {0, 2}));
}
XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
@@ -255,7 +255,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
Tuple(&builder, {x, y});
auto array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
ExecutableBuildOptions options = DefaultExecutableBuildOptions();
Shape shape_with_layout = ShapeUtil::MakeTupleShape(
@@ -268,11 +268,11 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&array, &array},
options, DefaultExecutableRunOptions());
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {0}));
+ LiteralSlice(result_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {1}));
+ LiteralSlice(result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
@@ -298,15 +298,15 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
Tuple(&builder, {array_sum, vector_diff});
auto computation = builder.Build().ConsumeValueOrDie();
- auto x_literal = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
- LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0}).get()});
- auto y_literal = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<float>({2.0, 4.0, 6.0}).get(),
- LiteralUtil::CreateR2<float>({{55.0, 44.0}, {33.0, 22.0}}).get()});
+ auto x_literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+ LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})});
+ auto y_literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>({2.0, 4.0, 6.0}),
+ LiteralUtil::CreateR2<float>({{55.0, 44.0}, {33.0, 22.0}})});
- auto x_buffer = LiteralToShapedBuffer(*x_literal);
- auto y_buffer = LiteralToShapedBuffer(*y_literal);
+ auto x_buffer = LiteralToShapedBuffer(x_literal);
+ auto y_buffer = LiteralToShapedBuffer(y_literal);
ScopedShapedBuffer result =
ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer});
@@ -314,11 +314,11 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape()));
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape()));
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>({{56.0f, 46.0f}, {36.0f, 26.0f}},
- LiteralSlice(*result_literal, {0}));
+ LiteralSlice(result_literal, {0}));
LiteralTestUtil::ExpectR1Equal<float>({40.0f, 71.0f, 117.0f},
- LiteralSlice(*result_literal, {1}));
+ LiteralSlice(result_literal, {1}));
}
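
Here LiteralUtil::MakeTuple, which took raw pointers obtained with .get() on each element, gives way to MakeTupleFromSlices, which accepts LiteralSlice views (a Literal converts implicitly) and returns the assembled tuple by value. A sketch using the same element values as the test above:

    xla::Literal x_literal = xla::LiteralUtil::MakeTupleFromSlices(
        {xla::LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
         xla::LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})});
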
XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
@@ -344,21 +344,20 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
Tuple(&builder, {negate_array, vector_sum});
auto computation = builder.Build().ConsumeValueOrDie();
- auto arg_literal = LiteralUtil::MakeTuple(
- {LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
- LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0}).get()})
- .get(),
- LiteralUtil::CreateR1<float>({222.0, -2.0, 10.0}).get()});
- auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+ auto arg_literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+ LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})}),
+ LiteralUtil::CreateR1<float>({222.0, -2.0, 10.0})});
+ auto arg_buffer = LiteralToShapedBuffer(arg_literal);
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4}},
- LiteralSlice(*result_literal, {0}));
+ LiteralSlice(result_literal, {0}));
LiteralTestUtil::ExpectR1Equal<float>({264.0, 73.0, 133.0},
- LiteralSlice(*result_literal, {1}));
+ LiteralSlice(result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
@@ -377,24 +376,24 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
Tuple(&builder, {Neg(element_0), Add(element_1, element_1)});
auto computation = builder.Build().ConsumeValueOrDie();
- auto arg_literal = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
- LiteralUtil::CreateR2<float>({{11.0, 3.0}, {4.0, 5.0}}).get()});
- auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+ auto arg_literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+ LiteralUtil::CreateR2<float>({{11.0, 3.0}, {4.0, 5.0}})});
+ auto arg_buffer = LiteralToShapedBuffer(arg_literal);
ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer});
- std::unique_ptr<Literal> result_0_literal = ShapedBufferToLiteral(result_0);
+ Literal result_0_literal = ShapedBufferToLiteral(result_0);
LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4.0}},
- LiteralSlice(*result_0_literal, {0}));
+ LiteralSlice(result_0_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>({{22.0, 6.0}, {8.0, 10}},
- LiteralSlice(*result_0_literal, {1}));
+ LiteralSlice(result_0_literal, {1}));
ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0});
- std::unique_ptr<Literal> result_1_literal = ShapedBufferToLiteral(result_1);
+ Literal result_1_literal = ShapedBufferToLiteral(result_1);
LiteralTestUtil::ExpectR2Equal<float>({{1.0, 2.0}, {3.0, 4.0}},
- LiteralSlice(*result_1_literal, {0}));
+ LiteralSlice(result_1_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>({{44.0, 12.0}, {16.0, 20}},
- LiteralSlice(*result_1_literal, {1}));
+ LiteralSlice(result_1_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
@@ -427,20 +426,19 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
// Feed in a tuple where each two-element vector element is {tuple_index,
// -tuple_index}.
- std::vector<std::unique_ptr<Literal>> arg_elements;
+ std::vector<Literal> arg_elements;
for (int i = 0; i < kElementCount; ++i) {
arg_elements.push_back(LiteralUtil::CreateR1<float>({1.0f * i, -1.0f * i}));
}
- std::unique_ptr<Literal> arg_literal =
- LiteralUtil::MakeTupleOwned(std::move(arg_elements));
- auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+ Literal arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_elements));
+ auto arg_buffer = LiteralToShapedBuffer(arg_literal);
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
for (int i = 0; i < kElementCount; ++i) {
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), error_spec_);
+ {2.0f * i, 0.0f}, LiteralSlice(result_literal, {i}), error_spec_);
}
}
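
For tuples assembled element by element, the container type changes from std::vector<std::unique_ptr<Literal>> to std::vector<Literal>, and MakeTupleOwned consumes the vector by move. A condensed form of the loop above:

    std::vector<xla::Literal> elements;
    for (int i = 0; i < kElementCount; ++i) {
      elements.push_back(
          xla::LiteralUtil::CreateR1<float>({1.0f * i, -1.0f * i}));
    }
    xla::Literal arg_literal =
        xla::LiteralUtil::MakeTupleOwned(std::move(elements));
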
@@ -476,9 +474,9 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) {
auto computation = builder.Build().ConsumeValueOrDie();
// Construct the argument to pass to the computation.
- std::vector<std::unique_ptr<Literal>> outer_tuple_elements;
+ std::vector<Literal> outer_tuple_elements;
for (int i = 0; i < kFanout; ++i) {
- std::vector<std::unique_ptr<Literal>> inner_tuple_elements;
+ std::vector<Literal> inner_tuple_elements;
for (int j = 0; j < kFanout; ++j) {
inner_tuple_elements.push_back(LiteralUtil::CreateR0<float>(i + j));
}
@@ -487,16 +485,16 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) {
}
auto arg_literal =
LiteralUtil::MakeTupleOwned(std::move(outer_tuple_elements));
- auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+ auto arg_buffer = LiteralToShapedBuffer(arg_literal);
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
for (int i = 0; i < kFanout; ++i) {
for (int j = 0; j < kFanout; ++j) {
- LiteralTestUtil::ExpectR0Near<float>(
- i + j + i * kFanout + j, LiteralSlice(*result_literal, {i, j}),
- error_spec_);
+ LiteralTestUtil::ExpectR0Near<float>(i + j + i * kFanout + j,
+ LiteralSlice(result_literal, {i, j}),
+ error_spec_);
}
}
}
@@ -525,23 +523,23 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) {
auto computation = builder.Build().ConsumeValueOrDie();
// Construct the argument to pass to the computation.
- std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR0<float>(123.0);
+ Literal arg_literal = LiteralUtil::CreateR0<float>(123.0);
for (int i = 0; i < kTupleDepth; ++i) {
- std::vector<std::unique_ptr<Literal>> arg_vector;
+ std::vector<Literal> arg_vector;
arg_vector.push_back(std::move(arg_literal));
arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_vector));
}
- auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+ auto arg_buffer = LiteralToShapedBuffer(arg_literal);
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
- std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+ Literal result_literal = ShapedBufferToLiteral(result);
ShapeIndex index;
for (int i = 0; i < kTupleDepth; ++i) {
index.push_back(0);
}
LiteralTestUtil::ExpectR0Equal<float>(165.0,
- LiteralSlice(*result_literal, index));
+ LiteralSlice(result_literal, index));
}
XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
@@ -552,7 +550,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
Add(x, y);
auto x_array =
- LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f}));
+ LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f}));
auto execute_status =
ExecuteLocally(builder.Build().ValueOrDie(), {&x_array});
@@ -568,7 +566,7 @@ XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) {
Neg(x);
auto x_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
+ LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
auto execute_status =
ExecuteLocally(builder.Build().ValueOrDie(), {&x_array});
@@ -585,7 +583,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) {
Neg(x);
auto x_array = LiteralToShapedBuffer(
- *LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
+ LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
auto execute_status = ExecuteLocally(
builder.Build().ValueOrDie(), {&x_array},
DefaultExecutableBuildOptions().set_result_layout(
@@ -622,7 +620,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) {
DefaultExecutableRunOptions().set_device_ordinal(d));
EXPECT_EQ(d, result.device_ordinal());
LiteralTestUtil::ExpectR0Equal<float>(42.0f,
- *ShapedBufferToLiteral(result));
+ ShapedBufferToLiteral(result));
}
}
}
@@ -666,8 +664,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnStream) {
  // As a check, verify that the computation ran on the device associated
// with the stream. This is a weak check, but stronger verification is hard.
EXPECT_EQ(d, result.device_ordinal());
- LiteralTestUtil::ExpectR0Equal<float>(42.0f,
- *ShapedBufferToLiteral(result));
+ LiteralTestUtil::ExpectR0Equal<float>(42.0f, ShapedBufferToLiteral(result));
}
}
@@ -745,11 +742,11 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) {
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
- std::unique_ptr<Literal> tuple_literal = ShapedBufferToLiteral(result);
+ Literal tuple_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR1Equal<float>({2.0f, 4.0f, 6.0f},
- LiteralSlice(*tuple_literal, {0}));
+ LiteralSlice(tuple_literal, {0}));
LiteralTestUtil::ExpectR1Equal<float>({1.0f, 2.0f, 3.0f},
- LiteralSlice(*tuple_literal, {1}));
+ LiteralSlice(tuple_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
@@ -768,7 +765,7 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
executable_status.ConsumeValueOrDie();
auto x_array =
- LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
ScopedShapedBuffer result =
executable->Run({&x_array}, DefaultExecutableRunOptions())
.ConsumeValueOrDie();
@@ -778,7 +775,7 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
->BlockHostUntilDone());
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);
+ {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
}
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) {
@@ -792,33 +789,33 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) {
TF_ASSERT_OK_AND_ASSIGN(
auto transferred_literal,
local_client_->ShapedBufferToLiteral(shaped_buffer));
- EXPECT_EQ(literal, *transferred_literal);
+ EXPECT_EQ(literal, transferred_literal);
};
// Array shapes.
- test_to_device_and_back(*LiteralUtil::CreateR0<float>(42.0));
- test_to_device_and_back(*LiteralUtil::CreateR0<bool>(true));
- test_to_device_and_back(*LiteralUtil::CreateR1<float>({1.0, 42.0, 744.4}));
+ test_to_device_and_back(LiteralUtil::CreateR0<float>(42.0));
+ test_to_device_and_back(LiteralUtil::CreateR0<bool>(true));
+ test_to_device_and_back(LiteralUtil::CreateR1<float>({1.0, 42.0, 744.4}));
test_to_device_and_back(
- *LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
- test_to_device_and_back(*LiteralUtil::CreateR2<int32>({{2, 1}, {4444, 56}}));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
+ test_to_device_and_back(LiteralUtil::CreateR2<int32>({{2, 1}, {4444, 56}}));
// Null shape (empty tuple).
- test_to_device_and_back(*LiteralUtil::MakeTuple({}));
+ test_to_device_and_back(LiteralUtil::MakeTuple({}));
// Non-nested tuples.
- test_to_device_and_back(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(12223.0).get()}));
- test_to_device_and_back(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1.0, -42.0}).get(),
- LiteralUtil::CreateR0<float>(123456.0).get()}));
+ test_to_device_and_back(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(12223.0)}));
+ test_to_device_and_back(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>({1.0, -42.0}),
+ LiteralUtil::CreateR0<float>(123456.0)}));
// Nested tuple.
- test_to_device_and_back(*LiteralUtil::MakeTuple(
- {LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1.0, -42.0}).get(),
- LiteralUtil::CreateR0<float>(123456.0).get()})
- .get(),
- LiteralUtil::CreateR0<bool>(false).get()}));
+ test_to_device_and_back(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>({1.0, -42.0}),
+ LiteralUtil::CreateR0<float>(123456.0)}),
+ LiteralUtil::CreateR0<bool>(false)}));
}
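
The EXPECT_EQ(literal, transferred_literal) form relies on Literal comparing by value, shape plus contents, where the old code had to dereference both unique_ptrs first. A minimal illustration, assuming value equality behaves as the assertions above imply:

    xla::Literal a = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f});
    xla::Literal b = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f});
    EXPECT_EQ(a, b);  // compares shape and contents, not pointer identity
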
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
@@ -832,17 +829,17 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
TF_ASSERT_OK_AND_ASSIGN(
auto transferred_literal,
local_client_->ShapedBufferToLiteral(shaped_buffer));
- EXPECT_EQ(literal, *transferred_literal);
+ EXPECT_EQ(literal, transferred_literal);
};
test_to_device_and_back(
- *LiteralUtil::CreateR2<double>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
- test_to_device_and_back(*LiteralUtil::CreateR2<int64>({{2, 1}, {4444, 56}}));
+ LiteralUtil::CreateR2<double>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
+ test_to_device_and_back(LiteralUtil::CreateR2<int64>({{2, 1}, {4444, 56}}));
test_to_device_and_back(
- *LiteralUtil::CreateR2<uint64>({{20000000000ULL, 1}, {4444, 56}}));
- test_to_device_and_back(*LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<double>({1.0, -42.0}).get(),
- LiteralUtil::CreateR0<int64>(123456789000LL).get()}));
+ LiteralUtil::CreateR2<uint64>({{20000000000ULL, 1}, {4444, 56}}));
+ test_to_device_and_back(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<double>({1.0, -42.0}),
+ LiteralUtil::CreateR0<int64>(123456789000LL)}));
}
XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
@@ -852,7 +849,7 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
auto constant = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f});
Add(in, constant);
- std::unique_ptr<Literal> result;
+ Literal result;
std::unique_ptr<tensorflow::Thread> thread(
tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "execute_thread", [&] {
@@ -861,13 +858,13 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
}));
ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
- *LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
+ LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
local_client_->default_device_ordinal()));
// Join the thread.
thread.reset();
- LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, *result);
+ LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
}
XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) {
@@ -884,14 +881,14 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) {
[&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); }));
ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
- *LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
+ LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
local_client_->default_device_ordinal()));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
+ TF_ASSERT_OK_AND_ASSIGN(Literal result,
local_client_->TransferFromOutfeedLocal(
shape, local_client_->default_device_ordinal()));
- LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, *result);
+ LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
}
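
TransferToInfeedLocal now takes the literal by const reference, and TransferFromOutfeedLocal yields a Literal by value, so results are read with '.' rather than '->'. The pattern, condensed from the test above:

    ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
        xla::LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
        local_client_->default_device_ordinal()));
    TF_ASSERT_OK_AND_ASSIGN(
        xla::Literal result,
        local_client_->TransferFromOutfeedLocal(
            shape, local_client_->default_device_ordinal()));
    LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
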
// Benchmark that measures the overhead of the LocalClient API when running a
@@ -922,8 +919,8 @@ void BM_LocalClientOverhead(int num_iters) {
auto literal = LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, 0, 0}});
auto stream =
client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
- ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(stream.get(), *literal,
- buffer));
+ ASSERT_IS_OK(
+ transfer_manager->TransferLiteralToDevice(stream.get(), literal, buffer));
const int kWarmups = 2;
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc
index eaddf756db..f90ef22d2d 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.cc
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc
@@ -18,11 +18,11 @@ limitations under the License.
#include <vector>
+#include "absl/memory/memory.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -136,7 +136,7 @@ ScopedShapedBuffer LocalClientTestBase::LiteralToShapedBuffer(
.ConsumeValueOrDie();
}
-std::unique_ptr<Literal> LocalClientTestBase::ShapedBufferToLiteral(
+Literal LocalClientTestBase::ShapedBufferToLiteral(
const ShapedBuffer& shaped_buffer) {
return local_client_->ShapedBufferToLiteral(shaped_buffer)
.ConsumeValueOrDie();
@@ -156,7 +156,7 @@ ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const {
ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ absl::Span<const ShapedBuffer* const> arguments) {
return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(),
DefaultExecutableRunOptions())
.ConsumeValueOrDie();
@@ -164,7 +164,7 @@ ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
const ExecutableBuildOptions& build_options,
const ExecutableRunOptions& run_options) {
return ExecuteLocally(computation, arguments, build_options, run_options)
@@ -173,14 +173,14 @@ ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ absl::Span<const ShapedBuffer* const> arguments) {
return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(),
DefaultExecutableRunOptions());
}
StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
const ExecutableBuildOptions& build_options,
const ExecutableRunOptions& run_options) {
std::vector<const Shape*> argument_layouts(arguments.size());
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h
index b4477e9a6b..4027c7b124 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.h
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
@@ -31,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -86,26 +86,25 @@ class LocalClientTestBase : public ::testing::Test {
// Construct and return a literal containing the array represented by
// shaped_buffer.
- std::unique_ptr<Literal> ShapedBufferToLiteral(
- const ShapedBuffer& shaped_buffer);
+ Literal ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer);
  // Execute the given computation on the local client, with and without
  // options.
StatusOr<ScopedShapedBuffer> ExecuteLocally(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
+ absl::Span<const ShapedBuffer* const> arguments);
StatusOr<ScopedShapedBuffer> ExecuteLocally(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
const ExecutableBuildOptions& build_options,
const ExecutableRunOptions& run_options);
ScopedShapedBuffer ExecuteLocallyOrDie(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
+ absl::Span<const ShapedBuffer* const> arguments);
ScopedShapedBuffer ExecuteLocallyOrDie(
const XlaComputation& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ absl::Span<const ShapedBuffer* const> arguments,
const ExecutableBuildOptions& build_options,
const ExecutableRunOptions& run_options);
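
tensorflow::gtl::ArraySlice is replaced by absl::Span throughout these signatures. Both are non-owning views, so brace-initialized call sites keep compiling; the element type const ShapedBuffer* const reads as an immutable span of pointers to immutable buffers. A sketch with a hypothetical ExecuteAll helper:

    #include "absl/types/span.h"

    void ExecuteAll(absl::Span<const xla::ShapedBuffer* const> arguments) {
      for (const xla::ShapedBuffer* arg : arguments) {
        // ... run the computation against *arg ...
      }
    }

    // Call sites are unchanged: ExecuteAll({&x_buffer, &y_buffer});
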
diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc
index 0732e195d4..4d327a6fe9 100644
--- a/tensorflow/compiler/xla/tests/map_test.cc
+++ b/tensorflow/compiler/xla/tests/map_test.cc
@@ -169,11 +169,11 @@ class MapTest : public ClientLibraryTestBase {
TEST_F(MapTest, MapEachElemPlusOneR0) {
  // Applies (lambda (x) (+ x 1)) to an input scalar.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(42.0);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(42.0);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOne(), {});
ComputeAndCompareR0<float>(&builder, 43.0, {param0_data.get()},
@@ -183,11 +183,11 @@ TEST_F(MapTest, MapEachElemPlusOneR0) {
XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+ Literal param0_literal = LiteralUtil::CreateR1<float>({});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOne(), {0});
ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()},
@@ -197,12 +197,12 @@ XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) {
TEST_F(MapTest, MapEachElemPlusOneR1S4) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOne(), {0});
ComputeAndCompareR1<float>(&builder, {3.2f, 4.3f, 5.4f, 6.5f},
@@ -211,12 +211,12 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) {
TEST_F(MapTest, MapEachF32ElementToS32Constant) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateScalarOne<int32>(), {0});
ComputeAndCompareR1<int32>(&builder, {1, 1, 1, 1}, {param0_data.get()});
@@ -224,12 +224,12 @@ TEST_F(MapTest, MapEachF32ElementToS32Constant) {
TEST_F(MapTest, MapEachF32ElementToU32Constant) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateScalarOne<uint32>(), {0});
ComputeAndCompareR1<uint32>(&builder, {1, 1, 1, 1}, {param0_data.get()});
@@ -238,12 +238,12 @@ TEST_F(MapTest, MapEachF32ElementToU32Constant) {
TEST_F(MapTest, MapEachElemLongerChainR1) {
// Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOneTimesItself(), {0});
ComputeAndCompareR1<float>(
@@ -255,11 +255,11 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then
// maps (lambda (x) (* x 2)) on the result.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+ Literal param0_literal = LiteralUtil::CreateR1<float>({});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0});
Map(&builder, {map1}, CreateMulByTwo(), {0});
@@ -271,12 +271,12 @@ TEST_F(MapTest, MapMultipleMapsR1S4) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4, and then
// maps (lambda (x) (* x 2)) on the result.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0});
Map(&builder, {map1}, CreateMulByTwo(), {0});
@@ -287,12 +287,12 @@ TEST_F(MapTest, MapMultipleMapsR1S4) {
TEST_F(MapTest, MapEachElemPlusOneR2) {
// Maps (lambda (x) (+ x 1)) onto an input R2F32 vector.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2<float>(
+ Literal param0_literal = LiteralUtil::CreateR2<float>(
{{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param}, CreateAdderToOne(), {0, 1});
Array2D<float> expected_array(
@@ -342,17 +342,17 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) {
TEST_F(MapTest, MapBinaryAdder) {
// Maps (lambda (x y) (+ x y)) onto two R1F32 vectors.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
+ Literal param1_literal =
LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, CreateScalarAddComputation(F32, &builder),
{0});
@@ -365,18 +365,18 @@ TEST_F(MapTest, MapBinaryAdder) {
// for Map that used to fail in shape inference (b/28989438).
XLA_TEST_F(MapTest, AddWithMixedLayouts) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2WithLayout(
+ Literal param0_literal = LiteralUtil::CreateR2WithLayout(
{{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0}));
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR2WithLayout(
+ Literal param1_literal = LiteralUtil::CreateR2WithLayout(
{{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder),
{0, 1});
@@ -391,18 +391,18 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) {
XLA_TEST_F(MapTest, AddR3_3x0x2) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ Literal param1_literal =
LiteralUtil::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder),
{0, 1, 2});
@@ -413,22 +413,22 @@ XLA_TEST_F(MapTest, AddR3_3x0x2) {
TEST_F(MapTest, MapTernaryAdder) {
// Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
+ Literal param1_literal =
LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param2_literal =
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
+ Literal param2_literal =
LiteralUtil::CreateR1<float>({-10.0f, -100.0f, -900.0f, -400.0f});
std::unique_ptr<GlobalData> param2_data =
- client_->TransferToServer(*param2_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param2_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
- auto param2 = Parameter(&builder, 2, param2_literal->shape(), "param2");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
+ auto param2 = Parameter(&builder, 2, param2_literal.shape(), "param2");
Map(&builder, {param0, param1, param2}, CreateTernaryAdder(), {0});
ComputeAndCompareR1<float>(
@@ -475,17 +475,17 @@ TEST_F(MapTest, MapOperantionWithBuildError) {
Add(x, y);
auto error_add = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal =
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
+ Literal param1_literal =
LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, error_add, {0});
StatusOr<XlaComputation> computation_status = builder.Build();
@@ -513,15 +513,15 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) {
Pow(x, y);
auto power = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(2.0f);
- std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR0<float>(5.0f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(2.0f);
+ Literal param1_literal = LiteralUtil::CreateR0<float>(5.0f);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, power, {});
ComputeAndCompareR0<float>(&builder, 32.0f,
@@ -540,15 +540,15 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) {
Sub(y, x); // note that this is y - x, not x - y
auto sub_opposite = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(2.0f);
- std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR0<float>(5.0f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(2.0f);
+ Literal param1_literal = LiteralUtil::CreateR0<float>(5.0f);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param1_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
- auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
Map(&builder, {param0, param1}, sub_opposite, {});
ComputeAndCompareR0<float>(
@@ -565,11 +565,11 @@ TEST_F(MapTestWithFullOpt, MapSquare) {
Mul(x, x);
auto square = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(10.0f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(10.0f);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
Map(&builder, {param0}, square, {});
ComputeAndCompareR0<float>(&builder, 100.0f, {param0_data.get()},
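
Every case in map_test.cc follows the same migration: the parameter literal is held by value, TransferToServer takes it by const reference, and its shape is read with .shape() instead of ->shape(). The repeated pattern, sketched with the fixture's client_ and builder as above:

    xla::Literal param0_literal =
        xla::LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
    std::unique_ptr<xla::GlobalData> param0_data =
        client_->TransferToServer(param0_literal).ConsumeValueOrDie();
    auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
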
diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
index da8c42d465..3f278115e0 100644
--- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
@@ -17,12 +17,14 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -32,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -62,11 +63,11 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) {
});
Exp(data);
- std::unique_ptr<Literal> expected =
+ Literal expected =
LiteralUtil::CreateR2FromArray2D<T>({{2.71828f, 1.00000f}, // row 0
{0.36788f, 1.64872f}}); // row 1
- this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5));
+ this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-5));
}
XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) {
@@ -91,10 +92,10 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) {
});
Map(&builder, {data}, add_half, {0, 1});
- std::unique_ptr<Literal> expected =
+ Literal expected =
LiteralUtil::CreateR2FromArray2D<T>({{1.5f, 0.5f}, // row 0
{-0.5f, 1.0f}}); // row 1
- this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5));
+ this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-5));
}
XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) {
@@ -110,10 +111,10 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) {
});
Max(lhs, rhs);
- std::unique_ptr<Literal> expected =
+ Literal expected =
LiteralUtil::CreateR2FromArray2D<T>({{7.0f, 6.0f}, // row 0
{3.0f, -4.0f}}); // row 1
- this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6));
+ this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6));
}
struct TestLinspaceMaxParam {
@@ -133,10 +134,9 @@ class TestLinspaceMaxParametric
float from = -128.0, to = 256.0;
std::unique_ptr<Array2D<T>> alhs =
MakeLinspaceArray2D<T>(from, to, rows, cols);
- auto arhs = MakeUnique<Array2D<T>>(rows, cols, static_cast<T>(1.0f));
+ auto arhs = absl::make_unique<Array2D<T>>(rows, cols, static_cast<T>(1.0f));
- XlaBuilder builder(
- tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols));
+ XlaBuilder builder(absl::StrFormat("max_%dx%d_linspace", rows, cols));
auto lhs = ConstantR2FromArray2D<T>(&builder, *alhs);
auto rhs = ConstantR2FromArray2D<T>(&builder, *arhs);
Max(lhs, rhs);
@@ -158,7 +158,7 @@ class TestLinspaceMaxParametric
string PrintTestLinspaceMaxParam(
const ::testing::TestParamInfo<TestLinspaceMaxParam>& test_param) {
const TestLinspaceMaxParam& param = test_param.param;
- return tensorflow::strings::StrCat(param.rows, "r", param.cols, "c");
+ return absl::StrCat(param.rows, "r", param.cols, "c");
}
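
The string and memory helpers also move to Abseil: MakeUnique becomes absl::make_unique, tensorflow::strings::Printf becomes absl::StrFormat (whose %d is type-safe for any integer width, so the old %lld specifier can be dropped), and StrCat now resolves to absl::StrCat. A sketch, assuming int64 rows and cols as in the test:

    #include "absl/memory/memory.h"
    #include "absl/strings/str_cat.h"
    #include "absl/strings/str_format.h"

    auto arhs = absl::make_unique<xla::Array2D<float>>(rows, cols, 1.0f);
    std::string name = absl::StrFormat("max_%dx%d_linspace", rows, cols);
    std::string id = absl::StrCat(rows, "r", cols, "c");
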
#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
@@ -200,14 +200,12 @@ class MatOpsDotAddTest
TF_ASSERT_OK_AND_ASSIGN(
auto lhs_handle,
- client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
- lhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
+ client_->TransferToServer(LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ lhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
TF_ASSERT_OK_AND_ASSIGN(
auto rhs_handle,
- client_->TransferToServer(
- *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
- rhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
+ client_->TransferToServer(LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ rhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
XlaBuilder builder(TestName());
auto lhs_arg = Parameter(&builder, 0, lhs_shape, "lhs");
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index eb06b115da..56aaeb0e68 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -19,10 +19,12 @@ limitations under the License.
#include <new>
#include <utility>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -36,7 +38,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
@@ -46,18 +47,26 @@ limitations under the License.
namespace xla {
namespace {
-using ::tensorflow::gtl::ArraySlice;
-
class MultiOutputFusionTest : public HloTestBase {
protected:
MultiOutputFusionTest() { error_spec_ = ErrorSpec{0.0001, 1e-2}; }
+ // Layout assignment assumes that there are no fusions in the input graph.
+ // Since the purpose of this test is to send pre-fused graphs to XLA, we have
+ // to do layout assignment ourselves.
+ DebugOptions GetDebugOptionsForTest() override {
+ auto opts = HloTestBase::GetDebugOptionsForTest();
+ opts.add_xla_disable_hlo_passes("layout-assignment");
+ return opts;
+ }
+
void RunTest2D(bool manual_fusion, int64 size) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
- const Shape elem_shape0 = ShapeUtil::MakeShape(F32, {});
- const Shape elem_shape2 = ShapeUtil::MakeShape(F32, {size, size});
+ const Shape elem_shape0 = ShapeUtil::MakeShapeWithLayout(F32, {}, {});
+ const Shape elem_shape2 =
+ ShapeUtil::MakeShapeWithLayout(F32, {size, size}, {1, 0});
auto const0 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(8.0f)));
@@ -80,13 +89,13 @@ class MultiOutputFusionTest : public HloTestBase {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(elem_shape2, sub, add2, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ elem_shape2, sub, add2, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = hlo_module->AddEntryComputation(builder.Build(dot));
if (manual_fusion) {
- auto tuple = computation->AddInstruction(HloInstruction::CreateTuple(
- ArraySlice<HloInstruction*>({sub, add2}, 0, 2)));
+ auto tuple =
+ computation->AddInstruction(HloInstruction::CreateTuple({sub, add2}));
auto gte0 = computation->AddInstruction(
HloInstruction::CreateGetTupleElement(elem_shape2, tuple, 0));
auto gte1 = computation->AddInstruction(
@@ -100,23 +109,25 @@ class MultiOutputFusionTest : public HloTestBase {
nullptr);
}
- Literal arg1(ShapeUtil::MakeShape(F32, {size, size}));
+ Literal arg1(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size}));
arg1.PopulateWithValue<float>(2.5f);
- Literal expect(ShapeUtil::MakeShape(F32, {size, size}));
+ Literal expect(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size}));
expect.PopulateWithValue<float>(size * 1.5f * 3.5f);
+ Literal literal_r0 = LiteralUtil::CreateR0<float>(-9.0f);
auto actual =
- ExecuteAndTransfer(std::move(hlo_module),
- {LiteralUtil::CreateR0<float>(-9.0f).get(), &arg1});
- EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
+ ExecuteAndTransfer(std::move(hlo_module), {&literal_r0, &arg1});
+ EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_));
}
void RunTest1D(bool manual_fusion, int size) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
- const Shape elem_shape_F32 = ShapeUtil::MakeShape(F32, {size});
- const Shape elem_shape_U8 = ShapeUtil::MakeShape(F64, {size});
+ const Shape elem_shape_F32 =
+ ShapeUtil::MakeShapeWithDescendingLayout(F32, {size});
+ const Shape elem_shape_U8 =
+ ShapeUtil::MakeShapeWithDescendingLayout(F64, {size});
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, elem_shape_F32, "0"));
auto param1 = builder.AddInstruction(
@@ -136,17 +147,18 @@ class MultiOutputFusionTest : public HloTestBase {
HloInstruction* reshape =
builder.AddInstruction(HloInstruction::CreateReshape(
- ShapeUtil::MakeShape(F32, {size, 1}), add));
+ ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, 1}), add));
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(0);
dot_dnums.add_rhs_contracting_dimensions(0);
HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
- ShapeUtil::MakeShape(F32, {1}), sub, reshape, dot_dnums));
+ ShapeUtil::MakeShapeWithDescendingLayout(F32, {1}), sub, reshape,
+ dot_dnums, DefaultPrecisionConfig(2)));
auto computation = hlo_module->AddEntryComputation(builder.Build(dot));
if (manual_fusion) {
- auto tuple = computation->AddInstruction(HloInstruction::CreateTuple(
- ArraySlice<HloInstruction*>({sub_U8, add}, 0, 2)));
+ auto tuple = computation->AddInstruction(
+ HloInstruction::CreateTuple({sub_U8, add}));
auto gte0 = computation->AddInstruction(
HloInstruction::CreateGetTupleElement(elem_shape_U8, tuple, 0));
@@ -161,15 +173,14 @@ class MultiOutputFusionTest : public HloTestBase {
nullptr);
}
- Literal input0(ShapeUtil::MakeShape(F32, {size}));
+ Literal input0(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size}));
input0.PopulateWithValue(2.5f);
- Literal input1(ShapeUtil::MakeShape(F64, {size}));
+ Literal input1(ShapeUtil::MakeShapeWithDescendingLayout(F64, {size}));
input1.PopulateWithValue(1.);
- Literal expect =
- std::move(*LiteralUtil::CreateR1<float>({size * 1.5f * 3.5f}));
+ Literal expect = LiteralUtil::CreateR1<float>({size * 1.5f * 3.5f});
auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1});
- EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_));
}
};
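
Two API shifts run through this fixture: HloInstruction::CreateDot now takes an explicit precision configuration (DefaultPrecisionConfig(2), a helper used above, supplies defaults for the two operands), and because the layout-assignment pass is disabled for these pre-fused graphs, shapes carry explicit layouts via MakeShapeWithLayout or MakeShapeWithDescendingLayout. A sketch of the dot construction, with builder, sub, and add2 as in RunTest2D:

    const xla::Shape shape =
        xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {size, size}, {1, 0});
    xla::DotDimensionNumbers dot_dnums;
    dot_dnums.add_lhs_contracting_dimensions(1);
    dot_dnums.add_rhs_contracting_dimensions(0);
    xla::HloInstruction* dot = builder.AddInstruction(
        xla::HloInstruction::CreateDot(shape, sub, add2, dot_dnums,
                                       DefaultPrecisionConfig(2)));
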
@@ -206,10 +217,9 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) {
LiteralUtil::CreateR0<float>(1.0)),
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<float>(3.0),
LiteralUtil::CreateR0<int32>(4)));
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)), *result));
+ LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)), result));
}
XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
@@ -235,9 +245,8 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, -1.0});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
- LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0, 1.0}, *result);
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
+ LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0, 1.0}, result);
}
XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
@@ -268,9 +277,8 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
- LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0}, *result);
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
+ LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0}, result);
}
const char* const kScalarOps = R"(
@@ -291,7 +299,7 @@ const char* const kScalarOps = R"(
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionMinor)) {
- const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
c0 = f32[] constant(0)
@@ -312,18 +320,17 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionMajor)) {
- const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
c0 = f32[] constant(0)
@@ -344,18 +351,17 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionScalar)) {
- const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
c0 = f32[] constant(0)
@@ -377,18 +383,17 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({14, 22}),
- LiteralUtil::CreateR1<float>({36, 64}),
- LiteralUtil::CreateR1<float>({66, 138})),
- *result));
+ LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({14, 22}),
+ LiteralUtil::CreateR1<float>({36, 64}),
+ LiteralUtil::CreateR1<float>({66, 138})),
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionMinorWithExtraOutput)) {
- const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
c0 = f32[] constant(0)
@@ -410,19 +415,18 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}),
LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionMajorWithExtraOutput)) {
- const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
c0 = f32[] constant(0)
@@ -444,20 +448,19 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
LiteralUtil::CreateR3<float>(
{{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionScalarWithExtraOutput)) {
- const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
c0 = f32[] constant(0)
@@ -480,21 +483,20 @@ XLA_TEST_F(MultiOutputFusionTest,
.ValueOrDie();
auto param =
LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR1<float>({14, 22}),
LiteralUtil::CreateR3<float>(
{{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
LiteralUtil::CreateR3<float>(
{{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionNonConstInit)) {
- const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
init1 = f32[] parameter(1)
@@ -518,18 +520,18 @@ XLA_TEST_F(MultiOutputFusionTest,
LiteralUtil::CreateR3<float>({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
auto init1 = LiteralUtil::CreateR0<float>(5);
auto init2 = LiteralUtil::CreateR0<float>(6);
- std::unique_ptr<Literal> result = ExecuteNoHloPasses(
- std::move(module), {param.get(), init1.get(), init2.get()});
+ Literal result =
+ ExecuteNoHloPasses(std::move(module), {&param, &init1, &init2});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{167, 172}, {176, 180}}),
LiteralUtil::CreateR2<float>({{6, 6}, {6, 8}})),
- *result));
+ result));
}
XLA_TEST_F(MultiOutputFusionTest,
DISABLED_ON_CPU(MultiOutputReduceFusionDifferentElementTypes)) {
- const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ const string testcase = absl::StrCat(kScalarOps, R"(
fused_reduce (p0: f16[2,2,2]) -> (f32[2,2], f32[2,2], f16[2,2,2]) {
p0 = f16[2,2,2]{2,1,0} parameter(0)
convert = f32[2,2,2]{2,1,0} convert(p0)
@@ -553,10 +555,9 @@ XLA_TEST_F(MultiOutputFusionTest,
auto param = LiteralUtil::CreateR3<Eigen::half>(
{{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}},
{{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}});
- std::unique_ptr<Literal> result =
- ExecuteNoHloPasses(std::move(module), {param.get()});
+ Literal result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}}),
LiteralUtil::CreateR3<Eigen::half>(
@@ -564,7 +565,7 @@ XLA_TEST_F(MultiOutputFusionTest,
{Eigen::half(3), Eigen::half(4)}},
{{Eigen::half(5), Eigen::half(6)},
{Eigen::half(7), Eigen::half(8)}}})),
- *result));
+ result));
}
} // namespace
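The recurring change in this hunk, and in most of the test diffs that follow, is that xla::Literal became a move-only value type: factory helpers like LiteralUtil::CreateR3 now return a Literal by value, execution helpers accept and return values, and call sites pass &literal where they used to pass literal.get(). A minimal stand-alone sketch of the same ownership pattern, using a hypothetical FakeLiteral in place of the real XLA classes:

    #include <utility>
    #include <vector>

    // Hypothetical stand-in for xla::Literal after this change; the real type
    // is move-only, but a plain aggregate keeps the sketch short.
    struct FakeLiteral {
      std::vector<float> data;
    };

    FakeLiteral MakeR1(std::vector<float> values) {
      return FakeLiteral{std::move(values)};
    }

    // Callees that used to receive {literal.get(), ...} now receive {&literal, ...}.
    float FirstElement(const std::vector<const FakeLiteral*>& args) {
      return args.at(0)->data.at(0);
    }

    int main() {
      // Before: std::unique_ptr<FakeLiteral> param = ...; FirstElement({param.get()});
      FakeLiteral param = MakeR1({1, 2, 3});  // held by value, no heap wrapper
      return FirstElement({&param}) == 1.0f ? 0 : 1;
    }

Nothing about lifetime changes at the call sites; the value simply lives on the stack instead of behind a unique_ptr, which removes a dereference from every use.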
diff --git a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc
index 0a0426adcb..f2460822a6 100644
--- a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc
+++ b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc
@@ -70,7 +70,7 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) {
GetTupleElement(result_tuple, 0);
TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build());
- std::unique_ptr<xla::Literal> comp_result;
+ Literal comp_result;
std::unique_ptr<tensorflow::Thread> thread(
tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "execute_thread", [&] {
@@ -81,41 +81,41 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) {
VLOG(1) << "Transferring trip count to computation";
// Transfer number of iterations to Infeed.
TF_ASSERT_OK(
- local_client_->TransferToInfeed(*LiteralUtil::CreateR0<int32_t>(1)));
+ local_client_->TransferToInfeed(LiteralUtil::CreateR0<int32_t>(1)));
// Pick up value from outfeed
{
VLOG(1) << "Reading from condition outfeed";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
+ TF_ASSERT_OK_AND_ASSIGN(Literal r,
local_client_->TransferFromOutfeed(&int_shape));
- EXPECT_EQ(r->Get<int32>({}), 1);
+ EXPECT_EQ(r.Get<int32>({}), 1);
}
VLOG(1) << "Writing data to infeed";
// Transfer some stuff to Infeed for use inside of loop.
TF_ASSERT_OK(local_client_->TransferToInfeed(
- *LiteralUtil::CreateR1<int32_t>({10, 20})));
+ LiteralUtil::CreateR1<int32_t>({10, 20})));
// Pick up value from outfeed
{
VLOG(1) << "Reading from body outfeed";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
+ TF_ASSERT_OK_AND_ASSIGN(Literal r,
local_client_->TransferFromOutfeed(&xfeed_shape));
- EXPECT_EQ(r->Get<int32>({0}), 11);
- EXPECT_EQ(r->Get<int32>({1}), 21);
+ EXPECT_EQ(r.Get<int32>({0}), 11);
+ EXPECT_EQ(r.Get<int32>({1}), 21);
}
{
VLOG(1) << "Reading from condition outfeed";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
+ TF_ASSERT_OK_AND_ASSIGN(Literal r,
local_client_->TransferFromOutfeed(&int_shape));
- EXPECT_EQ(r->Get<int32>({}), 0);
+ EXPECT_EQ(r.Get<int32>({}), 0);
}
// Joins the thread
thread.reset();
- EXPECT_EQ(comp_result->Get<int32>({}), 0);
+ EXPECT_EQ(comp_result.Get<int32>({}), 0);
}
XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) {
@@ -145,7 +145,7 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) {
TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build());
- std::unique_ptr<xla::Literal> comp_result;
+ Literal comp_result;
std::unique_ptr<tensorflow::Thread> thread(
tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "execute_thread", [&] {
@@ -154,12 +154,12 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) {
}));
TF_ASSERT_OK(
- local_client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
+ local_client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
+ TF_ASSERT_OK_AND_ASSIGN(Literal r,
local_client_->TransferFromOutfeed(&result_shape));
- EXPECT_EQ(r->Get<bool>({}), true);
+ EXPECT_EQ(r.Get<bool>({}), true);
// Join the thread
thread.reset();
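In the outfeed tests, the out-parameter pattern changes along with the type: a default-constructed Literal, filled in later from another thread, replaces a null std::unique_ptr<xla::Literal>. Roughly the same shape, sketched with today's absl::StatusOr standing in for the StatusOr type and TF_ASSERT_OK_AND_ASSIGN macro the test uses (an assumption for illustration; the real macro additionally asserts on failure):

    #include <utility>
    #include "absl/status/statusor.h"

    struct Lit {
      int value = 0;
    };

    // Stand-in for LocalClient::TransferFromOutfeed, which returns a
    // StatusOr<Literal> in the patched code.
    absl::StatusOr<Lit> TransferFromOutfeed() { return Lit{42}; }

    int main() {
      Lit comp_result;  // default-constructed value; previously a null unique_ptr
      absl::StatusOr<Lit> maybe = TransferFromOutfeed();
      if (!maybe.ok()) return 1;
      // Roughly what TF_ASSERT_OK_AND_ASSIGN expands to after its OK check.
      comp_result = std::move(maybe).value();
      return comp_result.value == 42 ? 0 : 1;
    }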
diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc
index ca21b0b2ba..6e98167739 100644
--- a/tensorflow/compiler/xla/tests/pad_test.cc
+++ b/tensorflow/compiler/xla/tests/pad_test.cc
@@ -16,12 +16,12 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
@@ -93,8 +93,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) {
dimension->set_edge_padding_high(0);
dimension->set_interior_padding(0);
- Pad(AddParam(*LiteralUtil::CreateR1<float>({}), &b),
- AddParam(*LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
+ Pad(AddParam(LiteralUtil::CreateR1<float>({}), &b),
+ AddParam(LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
ComputeAndCompareR1<float>(&b, {}, {}, DefaultErrorSpec());
}
@@ -108,8 +108,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) {
dimension->set_edge_padding_high(4);
dimension->set_interior_padding(7);
- Pad(AddParam(*LiteralUtil::CreateR1<float>({}), &b),
- AddParam(*LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
+ Pad(AddParam(LiteralUtil::CreateR1<float>({}), &b),
+ AddParam(LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
ComputeAndCompareR1<float>(&b, std::vector<float>(5, 0.1), {},
DefaultErrorSpec());
}
@@ -123,8 +123,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) {
dimension->set_edge_padding_high(0);
dimension->set_interior_padding(1);
- Pad(AddParam(*LiteralUtil::CreateR1<float>({1, 2, 3}), &b),
- AddParam(*LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
+ Pad(AddParam(LiteralUtil::CreateR1<float>({1, 2, 3}), &b),
+ AddParam(LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
std::vector<float> expected({0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3});
ComputeAndCompareR1<float>(&b, expected, {}, DefaultErrorSpec());
}
@@ -132,7 +132,7 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) {
XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) {
XlaBuilder b(TestName());
Pad(AddParam(Array4D<float>(2, 0, 3, 2), &b),
- AddParam(*LiteralUtil::CreateR0<float>(1.5), &b),
+ AddParam(LiteralUtil::CreateR0<float>(1.5), &b),
r4_padding_on_dim0_dim1_);
ComputeAndCompareR4<float>(&b, Array4D<float>(5, 2, 3, 2, 1.5f), {},
DefaultErrorSpec());
@@ -140,7 +140,7 @@ XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) {
TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) {
XlaBuilder b(TestName());
- auto input = MakeUnique<Array4D<float>>(1, 1, 3, 2);
+ auto input = absl::make_unique<Array4D<float>>(1, 1, 3, 2);
Array2D<float> input_xy({
{1.0f, 2.0f}, // row 0
{3.0f, 4.0f}, // row 1
@@ -148,10 +148,10 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) {
});
input->FillWithYX(input_xy);
- Pad(AddParam(*input, &b), AddParam(*LiteralUtil::CreateR0<float>(1.5), &b),
+ Pad(AddParam(*input, &b), AddParam(LiteralUtil::CreateR0<float>(1.5), &b),
r4_padding_on_dim0_dim1_);
- auto expected = MakeUnique<Array4D<float>>(2, 3, 3, 2);
+ auto expected = absl::make_unique<Array4D<float>>(2, 3, 3, 2);
expected->Fill(1.5);
(*expected)(1, 0, 0, 0) = 1.0f;
(*expected)(1, 0, 0, 1) = 2.0f;
@@ -168,10 +168,10 @@ TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) {
const float pad_value = 1.5f;
Array4D<float> input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6});
Pad(AddParam(input, &b),
- AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b),
+ AddParam(LiteralUtil::CreateR0<float>(pad_value), &b),
r4_padding_on_dim0_dim1_);
- auto expected = MakeUnique<Array4D<float>>(8, 5, 1, 1);
+ auto expected = absl::make_unique<Array4D<float>>(8, 5, 1, 1);
expected->Fill(pad_value);
(*expected)(1, 0, 0, 0) = 1.0f;
(*expected)(1, 2, 0, 0) = 2.0f;
@@ -208,10 +208,10 @@ TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) {
const float pad_value = -5.123f;
Array4D<float> input_array(1, 1, 2, 3, {1, 2, 3, 4, 5, 6});
auto input = LiteralUtil::CreateR4FromArray4D<float>(input_array);
- input = input->Relayout(layout);
+ input = input.Relayout(layout);
- Pad(AddParam(*input, &b),
- AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
+ Pad(AddParam(input, &b),
+ AddParam(LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
Array4D<float> expected_array(1, 1, 5, 8);
expected_array.Fill(pad_value);
@@ -254,10 +254,10 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) {
input_array(0, 24, 6, 6) = 2.0f;
input_array(0, 17, 2, 5) = 3.0f;
auto input = LiteralUtil::CreateR4FromArray4D<float>(input_array);
- input = input->Relayout(layout);
+ input = input.Relayout(layout);
- Pad(AddParam(*input, &b),
- AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
+ Pad(AddParam(input, &b),
+ AddParam(LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
Array4D<float> expected_array(1, 25, 17, 11);
expected_array.Fill(pad_value);
@@ -269,7 +269,7 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) {
XLA_TEST_F(PadTest, Pad4DU8Array) {
XlaBuilder b(TestName());
- auto input = MakeUnique<Array4D<uint8>>(1, 1, 3, 2);
+ auto input = absl::make_unique<Array4D<uint8>>(1, 1, 3, 2);
Array2D<uint8> input_xy({
{1, 2}, // row 0
{3, 4}, // row 1
@@ -280,7 +280,7 @@ XLA_TEST_F(PadTest, Pad4DU8Array) {
Pad(AddParam(*input, &b), ConstantR0<uint8>(&b, 35),
r4_padding_on_dim0_dim1_);
- auto expected = MakeUnique<Array4D<uint8>>(2, 3, 3, 2);
+ auto expected = absl::make_unique<Array4D<uint8>>(2, 3, 3, 2);
expected->Fill(35);
(*expected)(1, 0, 0, 0) = 1;
(*expected)(1, 0, 0, 1) = 2;
@@ -301,13 +301,13 @@ XLA_TEST_F(PadTest, Pad4DPredArray) {
Pad(input, ConstantR0<bool>(&b, false), r4_padding_on_dim0_dim1_);
// For the same reason, use Select to convert boolean values to int32.
- auto zeros = MakeUnique<Array4D<int32>>(2, 3, 3, 2);
- auto ones = MakeUnique<Array4D<int32>>(2, 3, 3, 2);
+ auto zeros = absl::make_unique<Array4D<int32>>(2, 3, 3, 2);
+ auto ones = absl::make_unique<Array4D<int32>>(2, 3, 3, 2);
zeros->Fill(0);
ones->Fill(1);
Select(padded, AddParam(*ones, &b), AddParam(*zeros, &b));
- auto expected = MakeUnique<Array4D<int32>>(2, 3, 3, 2);
+ auto expected = absl::make_unique<Array4D<int32>>(2, 3, 3, 2);
expected->Fill(0);
(*expected)(1, 0, 0, 0) = 1;
(*expected)(1, 0, 0, 1) = 1;
@@ -321,7 +321,7 @@ XLA_TEST_F(PadTest, Pad4DPredArray) {
XLA_TEST_P(PadTestFloat, Large2DPad) {
XlaBuilder b(TestName());
- auto ones = MakeUnique<Array2D<float>>(4, 4);
+ auto ones = absl::make_unique<Array2D<float>>(4, 4);
ones->Fill(1.0f);
auto input = AddParam(*ones, &b);
PaddingConfig padding_config = MakeNoPaddingConfig(2);
@@ -331,7 +331,7 @@ XLA_TEST_P(PadTestFloat, Large2DPad) {
padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 +
100 * dim);
}
- Pad(input, AddParam(*LiteralUtil::CreateR0<float>(0.0f), &b), padding_config);
+ Pad(input, AddParam(LiteralUtil::CreateR0<float>(0.0f), &b), padding_config);
auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f);
ComputeAndCompareR2<float>(&b, *expected, {}, DefaultErrorSpec());
@@ -342,7 +342,7 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) {
constexpr int64 in_rows = 35;
constexpr int64 in_cols = 35;
- auto operand = MakeUnique<Array2D<float>>(in_rows, in_cols);
+ auto operand = absl::make_unique<Array2D<float>>(in_rows, in_cols);
operand->FillUnique(0.0f);
auto input = AddParam(*operand, &b);
@@ -353,8 +353,7 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) {
padding_config.mutable_dimensions(1)->set_edge_padding_low(6);
padding_config.mutable_dimensions(1)->set_edge_padding_high(4);
padding_config.mutable_dimensions(1)->set_interior_padding(2);
- Pad(input, AddParam(*LiteralUtil::CreateR0<float>(3.14f), &b),
- padding_config);
+ Pad(input, AddParam(LiteralUtil::CreateR0<float>(3.14f), &b), padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f);
ComputeAndCompareR2<float>(&b, *expected, {}, DefaultErrorSpec());
@@ -368,7 +367,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) {
constexpr int64 low_padding = 0;
int64 high_padding[2] = {5, 7};
constexpr int64 interior_padding = 0;
- auto operand = MakeUnique<Array2D<float>>(in_rows, in_cols);
+ auto operand = absl::make_unique<Array2D<float>>(in_rows, in_cols);
operand->FillUnique(1.0f);
auto input = AddParam(*operand, &b);
PaddingConfig padding_config = MakeNoPaddingConfig(2);
@@ -379,7 +378,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) {
padding_config.mutable_dimensions(dim)->set_interior_padding(
interior_padding);
}
- Pad(input, AddParam(*LiteralUtil::CreateR0<float>(2.718f), &b),
+ Pad(input, AddParam(LiteralUtil::CreateR0<float>(2.718f), &b),
padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -395,7 +394,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) {
int64 low_padding[2] = {-1, -2};
int64 high_padding[2] = {-3, 4};
constexpr int64 interior_padding = 0;
- auto operand = MakeUnique<Array2D<float>>(in_rows, in_cols);
+ auto operand = absl::make_unique<Array2D<float>>(in_rows, in_cols);
operand->FillUnique(1.0f);
auto input = AddParam(*operand, &b);
PaddingConfig padding_config = MakeNoPaddingConfig(2);
@@ -407,7 +406,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) {
padding_config.mutable_dimensions(dim)->set_interior_padding(
interior_padding);
}
- Pad(input, AddParam(*LiteralUtil::CreateR0<float>(2.718f), &b),
+ Pad(input, AddParam(LiteralUtil::CreateR0<float>(2.718f), &b),
padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -423,7 +422,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) {
int64 low_padding[2] = {4, -1};
int64 high_padding[2] = {-2, -4};
int64 interior_padding[2] = {1, 2};
- auto operand = MakeUnique<Array2D<float>>(in_rows, in_cols);
+ auto operand = absl::make_unique<Array2D<float>>(in_rows, in_cols);
operand->FillUnique(1.0f);
auto input = AddParam(*operand, &b);
PaddingConfig padding_config = MakeNoPaddingConfig(2);
@@ -435,7 +434,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) {
padding_config.mutable_dimensions(dim)->set_interior_padding(
interior_padding[dim]);
}
- Pad(input, AddParam(*LiteralUtil::CreateR0<float>(2.718f), &b),
+ Pad(input, AddParam(LiteralUtil::CreateR0<float>(2.718f), &b),
padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -446,19 +445,18 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) {
// Regression test for b/31827337.
XLA_TEST_P(PadTestFloat, ReducePad) {
XlaBuilder b(TestName());
- auto ones = MakeUnique<Array4D<float>>(2, 2, 2, 2);
+ auto ones = absl::make_unique<Array4D<float>>(2, 2, 2, 2);
ones->Fill(1.0);
auto input = AddParam(*ones, &b);
XlaComputation add = CreateScalarAddComputation(FloatType(), &b);
auto reduce =
- Reduce(input, AddParam(*LiteralUtil::CreateR0<float>(0.0), &b), add, {0});
+ Reduce(input, AddParam(LiteralUtil::CreateR0<float>(0.0), &b), add, {0});
PaddingConfig padding_config = MakeNoPaddingConfig(3);
padding_config.mutable_dimensions(0)->set_edge_padding_low(1);
padding_config.mutable_dimensions(0)->set_edge_padding_high(1);
- Pad(reduce, AddParam(*LiteralUtil::CreateR0<float>(0.0f), &b),
- padding_config);
+ Pad(reduce, AddParam(LiteralUtil::CreateR0<float>(0.0f), &b), padding_config);
Array3D<float> expected({{{0.0, 0.0}, {0.0, 0.0}},
{{2.0, 2.0}, {2.0, 2.0}},
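pad_test.cc adds two more mechanical swaps: xla::MakeUnique, a make_unique backport that came in through the removed ptr_util.h include, becomes absl::make_unique, and Literal::Relayout is now called on a value and returns a new value. The make_unique half is a drop-in replacement; a trivial sketch:

    #include <array>
    #include "absl/memory/memory.h"

    int main() {
      // Interchangeable with std::make_unique; a backport was used because
      // pre-C++14 toolchains were still supported at the time.
      auto expected = absl::make_unique<std::array<float, 6>>();
      expected->fill(1.5f);
      return (*expected)[0] == 1.5f ? 0 : 1;
    }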
diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc
index f6c762e7a4..dcb4c11c3c 100644
--- a/tensorflow/compiler/xla/tests/params_test.cc
+++ b/tensorflow/compiler/xla/tests/params_test.cc
@@ -42,10 +42,9 @@ class ParamsTest : public ClientLibraryTestBase {};
XLA_TEST_F(ParamsTest, ConstantR0F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
- LiteralUtil::CreateR0<float>(3.14159f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(3.14159f);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param0");
@@ -55,9 +54,9 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) {
XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+ Literal param0_literal = LiteralUtil::CreateR1<float>({});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {0}), "param0");
@@ -67,10 +66,9 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) {
XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
- LiteralUtil::CreateR1<float>({3.14f, -100.25f});
+ Literal param0_literal = LiteralUtil::CreateR1<float>({3.14f, -100.25f});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0");
@@ -81,9 +79,9 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) {
XLA_TEST_F(ParamsTest, ConstantR1U8Param) {
XlaBuilder builder(TestName());
string str("hello world");
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1U8(str);
+ Literal param0_literal = LiteralUtil::CreateR1U8(str);
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0,
ShapeUtil::MakeShape(U8, {static_cast<int64>(str.size())}),
@@ -94,10 +92,10 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) {
XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(3, 0));
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 0}), "param0");
@@ -107,10 +105,10 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) {
XLA_TEST_F(ParamsTest, ConstantR2F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2<float>(
+ Literal param0_literal = LiteralUtil::CreateR2<float>(
{{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ client_->TransferToServer(param0_literal).ConsumeValueOrDie();
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 2}), "param0");
@@ -123,15 +121,15 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) {
XLA_TEST_F(ParamsTest, TwoParameters) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
+ Literal literal0 = LiteralUtil::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, literal0->shape(), "param0");
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ auto param0 = Parameter(&builder, 0, literal0.shape(), "param0");
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>({10, 20});
+ Literal literal1 = LiteralUtil::CreateR1<float>({10, 20});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param1 = Parameter(&builder, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ auto param1 = Parameter(&builder, 1, literal1.shape(), "param1");
// Use both parameters
//
@@ -154,9 +152,9 @@ XLA_TEST_F(ParamsTest, TwoParameters) {
XLA_TEST_F(ParamsTest, MissingParameter) {
// Test that an error is returned when a computation with an incomplete set of
// parameters (parameter numbers not contiguous from 0) is executed.
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<float>(3.14159f);
+ Literal literal = LiteralUtil::CreateR0<float>(3.14159f);
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
+ client_->TransferToServer(literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {}), "param2");
@@ -168,15 +166,15 @@ XLA_TEST_F(ParamsTest, MissingParameter) {
XLA_TEST_F(ParamsTest, UnusedParameter) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
+ Literal literal0 = LiteralUtil::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
- Parameter(&builder, 0, literal0->shape(), "param0");
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
+ Parameter(&builder, 0, literal0.shape(), "param0");
- std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>({10, 20});
+ Literal literal1 = LiteralUtil::CreateR1<float>({10, 20});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
- Parameter(&builder, 1, literal1->shape(), "param1");
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
+ Parameter(&builder, 1, literal1.shape(), "param1");
ComputeAndCompareR1<float>(&builder, {10, 20},
{param0_data.get(), param1_data.get()},
@@ -188,18 +186,17 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) {
// unused expression.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
+ Literal literal0 = LiteralUtil::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
- client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ client_->TransferToServer(literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 =
- LiteralUtil::CreateR1<float>({10, 20, 30});
+ Literal literal1 = LiteralUtil::CreateR1<float>({10, 20, 30});
std::unique_ptr<GlobalData> param1_data =
- client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ client_->TransferToServer(literal1).ConsumeValueOrDie();
- auto param0 = Parameter(&builder, 0, literal0->shape(), "param0");
- auto param1 = Parameter(&builder, 1, literal1->shape(), "param1");
- auto param2 = Parameter(&builder, 2, literal1->shape(), "param2");
+ auto param0 = Parameter(&builder, 0, literal0.shape(), "param0");
+ auto param1 = Parameter(&builder, 1, literal1.shape(), "param1");
+ auto param2 = Parameter(&builder, 2, literal1.shape(), "param2");
// This add is unused.
Add(param1, param2);
@@ -233,10 +230,10 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) {
std::vector<float> sum_value = {{entry0, entry1}};
sum_value.resize(size);
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<float>(sum_value);
+ Literal literal = LiteralUtil::CreateR1<float>(sum_value);
param_data_owner.push_back(
- client_->TransferToServer(*literal).ConsumeValueOrDie());
- XlaOp param = Parameter(&builder, i, literal->shape(), "param");
+ client_->TransferToServer(literal).ConsumeValueOrDie());
+ XlaOp param = Parameter(&builder, i, literal.shape(), "param");
sum_handle = Add(sum_handle, param);
}
@@ -268,10 +265,10 @@ XLA_TEST_F(ParamsTest,
constexpr int kParamCount = 3000;
for (int i = 0; i < kParamCount; ++i) {
target += i;
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<float>(i);
+ Literal literal = LiteralUtil::CreateR0<float>(i);
param_data_owner.push_back(
- std::move(client_->TransferToServer(*literal)).ValueOrDie());
- XlaOp param = Parameter(&builder, i, literal->shape(), "param");
+ std::move(client_->TransferToServer(literal)).ValueOrDie());
+ XlaOp param = Parameter(&builder, i, literal.shape(), "param");
sum_handle = Add(sum_handle, param);
}
@@ -300,10 +297,10 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(
std::vector<XlaOp> params;
for (int i = 0; i < kParamCount; ++i) {
target += i;
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int32>({i, i});
+ Literal literal = LiteralUtil::CreateR1<int32>({i, i});
param_data_owner.push_back(
- std::move(client_->TransferToServer(*literal)).ValueOrDie());
- XlaOp param = Parameter(&builder, i, literal->shape(), "param");
+ std::move(client_->TransferToServer(literal)).ValueOrDie());
+ XlaOp param = Parameter(&builder, i, literal.shape(), "param");
params.push_back(param);
sum_handle = Add(sum_handle, param);
}
@@ -321,13 +318,14 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(
param_data.push_back(data.get());
}
- std::vector<std::unique_ptr<Literal>> elements;
+ std::vector<Literal> elements;
std::vector<const Literal*> ptrs;
+ elements.reserve(kParamCount);
for (int i = 0; i < kParamCount; ++i) {
elements.push_back(LiteralUtil::CreateR1<int32>({target + i, target + i}));
- ptrs.push_back(elements.back().get());
+ ptrs.push_back(&elements.back());
}
- ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data);
+ ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data);
}
// Test large number of parameters flowing into a while-loop.
@@ -356,23 +354,23 @@ XLA_TEST_F(ParamsTest,
std::vector<XlaOp> params;
std::vector<Shape> parameter_shapes;
for (int i = 0; i < kParamCount; ++i) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int32>({i, i});
+ Literal literal = LiteralUtil::CreateR1<int32>({i, i});
param_data_owner.push_back(
- std::move(client_->TransferToServer(*literal)).ValueOrDie());
- XlaOp param = Parameter(&builder, i, literal->shape(), "param");
+ std::move(client_->TransferToServer(literal)).ValueOrDie());
+ XlaOp param = Parameter(&builder, i, literal.shape(), "param");
params.push_back(param);
- parameter_shapes.push_back(literal->shape());
+ parameter_shapes.push_back(literal.shape());
}
// Add bool parameter for the loop condition. Use a parameter HLO instead of a
// constant because DCE may eliminate the while-body otherwise.
- std::unique_ptr<Literal> bool_literal = LiteralUtil::CreateR0<bool>(false);
+ Literal bool_literal = LiteralUtil::CreateR0<bool>(false);
param_data_owner.push_back(
- std::move(client_->TransferToServer(*bool_literal)).ValueOrDie());
+ std::move(client_->TransferToServer(bool_literal)).ValueOrDie());
XlaOp bool_param =
- Parameter(&builder, kParamCount, bool_literal->shape(), "bool_param");
+ Parameter(&builder, kParamCount, bool_literal.shape(), "bool_param");
params.push_back(bool_param);
- parameter_shapes.push_back(bool_literal->shape());
+ parameter_shapes.push_back(bool_literal.shape());
auto init = Tuple(&builder, params);
@@ -420,13 +418,14 @@ XLA_TEST_F(ParamsTest,
param_data.push_back(data.get());
}
- std::vector<std::unique_ptr<Literal>> elements;
+ std::vector<Literal> elements;
std::vector<const Literal*> ptrs;
+ elements.reserve(kParamCount);
for (int i = 0; i < kParamCount; ++i) {
elements.push_back(LiteralUtil::CreateR1<int32>({i, i}));
- ptrs.push_back(elements.back().get());
+ ptrs.push_back(&elements.back());
}
- ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data);
+ ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data);
}
#endif
@@ -443,9 +442,9 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) {
std::unique_ptr<GlobalData> data =
client_
- ->TransferToServer(*LiteralUtil::MakeTuple({
- LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
- LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
+ ->TransferToServer(LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::CreateR1<float>({1, 2, 3}),
+ LiteralUtil::CreateR1<float>({4, 5, 6}),
}))
.ConsumeValueOrDie();
@@ -457,34 +456,34 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) {
// Verifies that passing a 2x2 with {0, 1} layout returns the same value back
// when (transferred to the server and) passed through a parameter.
XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
+ Literal literal = LiteralUtil::CreateR2WithLayout<float>(
{{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1}));
XlaBuilder builder(TestName());
- Parameter(&builder, 0, literal->shape(), "input");
+ Parameter(&builder, 0, literal.shape(), "input");
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3));
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3));
}
// As above, but for {1, 0} layout.
XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
+ Literal literal = LiteralUtil::CreateR2WithLayout<float>(
{{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0}));
XlaBuilder builder(TestName());
- Parameter(&builder, 0, literal->shape(), "input");
+ Parameter(&builder, 0, literal.shape(), "input");
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3));
+ client_->TransferToServer(literal).ConsumeValueOrDie();
+ ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3));
}
XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR2<float>({
+ Literal literal = LiteralUtil::CreateR2<float>({
{1, 3},
{2, 4},
});
- const Shape original = literal->shape();
+ const Shape original = literal.shape();
{
// Reverse the layout present in original, and make that the layout of the
// literal.
@@ -492,9 +491,9 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
original.layout().minor_to_major().begin(),
original.layout().minor_to_major().end());
std::reverse(original_layout.begin(), original_layout.end());
- *literal->mutable_shape_do_not_use()->mutable_layout() =
+ *literal.mutable_shape_do_not_use()->mutable_layout() =
LayoutUtil::MakeLayout(original_layout);
- ASSERT_EQ(2, literal->Get<float>({0, 1}));
+ ASSERT_EQ(2, literal.Get<float>({0, 1}));
}
// Use the original shape in building the computation.
XlaBuilder builder(TestName());
@@ -503,7 +502,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
Slice(input, {0, 1}, {1, 2}, {1, 1});
std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*literal).ConsumeValueOrDie();
+ client_->TransferToServer(literal).ConsumeValueOrDie();
// Check that we got the off-diagonal value that we expected.
Array2D<float> expected(1, 1);
expected(0, 0) = 2;
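One detail in this file is easy to misread as an optimization: the added elements.reserve(kParamCount). It is required for correctness, because ptrs collects addresses of elements of the growing vector, and any push_back that reallocates would invalidate every address already taken. The hazard in isolation, with a stand-in struct:

    #include <vector>

    struct Lit {
      int v;
    };

    int main() {
      const int kParamCount = 8;
      std::vector<Lit> elements;
      std::vector<const Lit*> ptrs;
      // Required: later push_backs must not reallocate, or every address
      // already stored in ptrs would dangle.
      elements.reserve(kParamCount);
      for (int i = 0; i < kParamCount; ++i) {
        elements.push_back(Lit{i});
        ptrs.push_back(&elements.back());
      }
      return ptrs[3]->v == 3 ? 0 : 1;
    }

With the reserve in place, the backing array is allocated once up front, so the stored pointers stay valid for the loop's duration.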
diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc
index 2fc7f816b5..58539e6b06 100644
--- a/tensorflow/compiler/xla/tests/pred_test.cc
+++ b/tensorflow/compiler/xla/tests/pred_test.cc
@@ -31,7 +31,7 @@ class PredTest : public ClientLibraryTestBase {
protected:
void TestCompare(bool lhs, bool rhs, bool expected,
std::function<XlaOp(const xla::XlaOp&, const xla::XlaOp&,
- tensorflow::gtl::ArraySlice<int64>)>
+ absl::Span<const int64>)>
op) {
XlaBuilder builder(TestName());
XlaOp lhs_op = ConstantR0<bool>(&builder, lhs);
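pred_test.cc shows the span migration in its smallest form: tensorflow::gtl::ArraySlice<int64> becomes absl::Span<const int64>. ArraySlice was implicitly read-only; with Span the constness moves into the element type, and a mutable view would be spelled absl::Span<int64>. Minimal usage:

    #include <cstdint>
    #include <vector>
    #include "absl/types/span.h"

    // Read-only view, the equivalent of the old ArraySlice<int64>.
    int64_t Sum(absl::Span<const int64_t> xs) {
      int64_t s = 0;
      for (int64_t x : xs) s += x;
      return s;
    }

    int main() {
      std::vector<int64_t> v = {1, 2, 3};
      return Sum(v) == 6 ? 0 : 1;  // containers convert implicitly
    }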
diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc
index 326e13b386..8f2c26f0ee 100644
--- a/tensorflow/compiler/xla/tests/prng_test.cc
+++ b/tensorflow/compiler/xla/tests/prng_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include <limits>
#include <memory>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
@@ -26,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -37,9 +37,7 @@ namespace {
class PrngTest : public ClientLibraryTestBase {
protected:
template <typename T>
- std::unique_ptr<Literal> UniformTest(T a, T b,
- tensorflow::gtl::ArraySlice<int64> dims,
- int64 seed = 42);
+ Literal UniformTest(T a, T b, absl::Span<const int64> dims, int64 seed = 42);
// Computes the χ² statistic of a sample of the discrete uniform distribution
// of the given range size. `expected_count` is the number of times each
@@ -50,8 +48,8 @@ class PrngTest : public ClientLibraryTestBase {
};
template <typename T>
-std::unique_ptr<Literal> PrngTest::UniformTest(
- T a, T b, tensorflow::gtl::ArraySlice<int64> dims, int64 seed) {
+Literal PrngTest::UniformTest(T a, T b, absl::Span<const int64> dims,
+ int64 seed) {
XlaBuilder builder(TestName());
RngUniform(
ConstantR0<T>(&builder, a), ConstantR0<T>(&builder, b),
@@ -60,8 +58,8 @@ std::unique_ptr<Literal> PrngTest::UniformTest(
SetSeed(seed);
auto actual =
ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie();
- EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions()));
- actual->EachCell<T>([=](tensorflow::gtl::ArraySlice<int64>, T value) {
+ EXPECT_THAT(dims, ::testing::ElementsAreArray(actual.shape().dimensions()));
+ actual.EachCell<T>([=](absl::Span<const int64>, T value) {
EXPECT_LE(a, value);
EXPECT_LT(value, b);
});
@@ -116,11 +114,10 @@ XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16CountTests))) {
constexpr int64 count = 100;
for (int64 seed = 0; seed < count; ++seed) {
auto result = UniformTest<bfloat16>(low, high, {}, /*seed=*/seed);
- result->Literal::EachCell<bfloat16>(
- [&](tensorflow::gtl::ArraySlice<int64>, bfloat16 value) {
- int64 index = static_cast<int64>((value - low) / interval);
- counts[index]++;
- });
+ result.EachCell<bfloat16>([&](absl::Span<const int64>, bfloat16 value) {
+ int64 index = static_cast<int64>((value - low) / interval);
+ counts[index]++;
+ });
}
// Each bucket should have similar amount of counts. That is, not more than
// 10% of total counts. This mostly tests that we don't fall into a 1:2:2
@@ -149,8 +146,8 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count,
auto actual =
ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie();
std::vector<int32> counts(range_size, 0);
- actual->EachCell<int32>([&counts](tensorflow::gtl::ArraySlice<int64>,
- int32 value) { ++counts[value]; });
+ actual.EachCell<int32>(
+ [&counts](absl::Span<const int64>, int32 value) { ++counts[value]; });
int64 sum = 0;
for (int32 i = 0; i < range_size; ++i) {
sum += Square(static_cast<int64>(counts[i] - expected_count));
@@ -192,12 +189,12 @@ XLA_TEST_F(PrngTest, MapUsingRng) {
};
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR1<float>({2.2f, 5.3f, 4.4f, 5.5f});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> param0_data,
- client_->TransferToServer(*param0_literal));
+ client_->TransferToServer(param0_literal));
- auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
auto fn = build_sum_rng(builder);
Map(&builder, {param0}, fn, {0});
@@ -210,12 +207,11 @@ XLA_TEST_F(PrngTest, MapUsingRng) {
computation,
/*arguments=*/{param0_data.get()}, &execution_options));
- EXPECT_EQ(ShapeUtil::ElementsIn(actual->shape()),
- ShapeUtil::ElementsIn(param0_literal->shape()));
- for (int i = 0; i < ShapeUtil::ElementsIn(actual->shape()); ++i) {
- EXPECT_GE(actual->data<float>()[i], param0_literal->data<float>()[i]);
- EXPECT_LT(actual->data<float>()[i],
- param0_literal->data<float>()[i] + 1.0f);
+ EXPECT_EQ(ShapeUtil::ElementsIn(actual.shape()),
+ ShapeUtil::ElementsIn(param0_literal.shape()));
+ for (int i = 0; i < ShapeUtil::ElementsIn(actual.shape()); ++i) {
+ EXPECT_GE(actual.data<float>()[i], param0_literal.data<float>()[i]);
+ EXPECT_LT(actual.data<float>()[i], param0_literal.data<float>()[i] + 1.0f);
}
}
@@ -238,15 +234,15 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
ExecutionOptions execution_options2 = execution_options_;
execution_options2.set_seed(65);
- std::unique_ptr<Literal> result1;
+ Literal result1;
{
TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation());
TF_ASSERT_OK_AND_ASSIGN(
result1, client_->ExecuteAndTransfer(computation, /*arguments=*/{},
&execution_options1));
}
- std::unique_ptr<Literal> result2;
- std::unique_ptr<Literal> result3;
+ Literal result2;
+ Literal result3;
{
TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation());
TF_ASSERT_OK_AND_ASSIGN(
@@ -257,9 +253,9 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
&execution_options1));
}
- std::unique_ptr<Literal> result4;
- std::unique_ptr<Literal> result5;
- std::unique_ptr<Literal> result6;
+ Literal result4;
+ Literal result5;
+ Literal result6;
{
TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation());
TF_ASSERT_OK_AND_ASSIGN(
@@ -273,11 +269,11 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
&execution_options_));
}
- EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result2));
- EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result3));
- EXPECT_FALSE(LiteralTestUtil::Equal(*result1, *result4));
- EXPECT_FALSE(LiteralTestUtil::Equal(*result4, *result5));
- EXPECT_FALSE(LiteralTestUtil::Equal(*result5, *result6));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result1, result2));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result1, result3));
+ EXPECT_FALSE(LiteralTestUtil::Equal(result1, result4));
+ EXPECT_FALSE(LiteralTestUtil::Equal(result4, result5));
+ EXPECT_FALSE(LiteralTestUtil::Equal(result5, result6));
}
XLA_TEST_F(PrngTest, TenValuesN01) {
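prng_test.cc applies both migrations to callbacks as well: UniformTest returns a Literal by value, and EachCell visitors receive the cell's multi-index as absl::Span<const int64>. A sketch of that visitor shape over a hypothetical flat 1-D buffer rather than a real Literal:

    #include <cstdint>
    #include <functional>
    #include <vector>
    #include "absl/types/span.h"

    // Hypothetical 1-D EachCell: invokes fn with each cell's multi-index
    // (a single coordinate here) and value, mirroring the new signature.
    void EachCell(const std::vector<float>& cells,
                  const std::function<void(absl::Span<const int64_t>, float)>& fn) {
      for (size_t i = 0; i < cells.size(); ++i) {
        const int64_t index[] = {static_cast<int64_t>(i)};
        fn(index, cells[i]);
      }
    }

    int main() {
      float sum = 0;
      EachCell({1, 2, 3}, [&](absl::Span<const int64_t>, float v) { sum += v; });
      return sum == 6.0f ? 0 : 1;
    }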
diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
index a080dd1732..c9096fb29b 100644
--- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
@@ -15,11 +15,11 @@ limitations under the License.
#include <array>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -29,16 +29,13 @@ limitations under the License.
namespace xla {
namespace {
-namespace str_util = tensorflow::str_util;
-namespace strings = tensorflow::strings;
-
struct ReduceLayout {
std::array<int64, 4> input_minor_to_major;
std::array<int64, 3> output_minor_to_major;
string ToString() const {
- return strings::StrCat(str_util::Join(input_minor_to_major, "x"), "_",
- str_util::Join(output_minor_to_major, "x"));
+ return absl::StrCat(absl::StrJoin(input_minor_to_major, "x"), "_",
+ absl::StrJoin(output_minor_to_major, "x"));
}
};
@@ -95,7 +92,7 @@ XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_GPU(Reduce)) {
*reduce_input_shape->mutable_layout() =
LayoutUtil::MakeLayout(reduce_layout.input_minor_to_major);
- std::unique_ptr<Literal> reduce_input = LiteralUtil::CreateR4<float>(
+ Literal reduce_input = LiteralUtil::CreateR4<float>(
{{ /*i0=0*/
{/*i1=0*/
{-0.246092796, -0.179497838, -0.161181688},
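The string-helper migration here (tensorflow::strings::StrCat and tensorflow::str_util::Join to absl::StrCat and absl::StrJoin) is argument-for-argument mechanical, and absl::StrJoin takes any container directly. Reproducing what ReduceLayout::ToString computes for one layout pair:

    #include <array>
    #include <cstdint>
    #include <string>
    #include "absl/strings/str_cat.h"
    #include "absl/strings/str_join.h"

    int main() {
      std::array<int64_t, 4> input_minor_to_major = {3, 2, 1, 0};
      std::array<int64_t, 3> output_minor_to_major = {2, 1, 0};
      // Same output as the old str_util::Join/strings::StrCat pair.
      std::string name =
          absl::StrCat(absl::StrJoin(input_minor_to_major, "x"), "_",
                       absl::StrJoin(output_minor_to_major, "x"));
      return name == "3x2x1x0_2x1x0" ? 0 : 1;
    }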
diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
index 531648fe3e..26e2bfde5c 100644
--- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -57,8 +58,8 @@ static const int mantissa_sizes[] = {23, 10, 23, 10};
string TestDataToString(const ::testing::TestParamInfo<int> data) {
int i = data.param;
- return tensorflow::strings::StrCat(exponent_sizes[i], "_exponent_bits_",
- mantissa_sizes[i], "_mantissa_bits");
+ return absl::StrCat(exponent_sizes[i], "_exponent_bits_", mantissa_sizes[i],
+ "_mantissa_bits");
}
// The FPVAL macro allows us to write out the binary representation of the
@@ -230,11 +231,10 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal =
- LiteralUtil::CreateR1<float>({input_values});
+ Literal a_literal = LiteralUtil::CreateR1<float>({input_values});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
ReducePrecision(a, exponent_bits, mantissa_bits);
@@ -254,10 +254,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionBeforeFusion)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+ Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// Abs doesn't affect resolution.
auto abs = Abs(a);
@@ -283,10 +283,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+ Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// These two operations should be fused by any reasonable backend.
auto abs = Abs(a);
@@ -309,10 +309,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+ Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// These two operations should be fused by any reasonable backend.
auto abs = Abs(a);
@@ -333,10 +333,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionSkippedFusionContains)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+ Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// These two operations should be fused by any reasonable backend.
auto abs = Abs(a);
@@ -358,10 +358,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+ Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
+ auto a = Parameter(&builder, 0, a_literal.shape(), "a");
// These two operations should be fused by any reasonable backend.
auto abs = Abs(a);
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index 2065271a7f..83997cdac2 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -32,6 +32,9 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
@@ -51,7 +54,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -79,9 +81,9 @@ class ReduceTest : public ClientLibraryTestBase {
}, 4);
// clang-format on
CHECK(ShapeUtil::Equal(
- literal_3d_->shape(),
+ literal_3d_.shape(),
ShapeUtil::MakeShape(F32, {/*z=*/4, /*y=*/2, /*x=*/3})))
- << literal_3d_->shape().ShortDebugString();
+ << literal_3d_.shape().ShortDebugString();
}
// Runs an R1 => R0 reduction test with the given number of elements.
@@ -100,10 +102,9 @@ class ReduceTest : public ClientLibraryTestBase {
input_data[i] *= -1;
}
}
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR1(AsSlice(input_data));
+ Literal input_literal = LiteralUtil::CreateR1(AsSlice(input_data));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
float expected = 0.0;
for (float item : input_data) {
@@ -113,8 +114,7 @@ class ReduceTest : public ClientLibraryTestBase {
ErrorSpec(0.001));
}
- void RunR1ToR0PredTest(bool and_reduce,
- tensorflow::gtl::ArraySlice<int> input_data) {
+ void RunR1ToR0PredTest(bool and_reduce, absl::Span<const int> input_data) {
const int element_count = input_data.size();
XlaBuilder builder(TestName());
const Shape input_shape = ShapeUtil::MakeShape(S32, {element_count});
@@ -133,9 +133,9 @@ class ReduceTest : public ClientLibraryTestBase {
Reduce(pred_values, init_value, reduce,
/*dimensions_to_reduce=*/{0});
- std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1(input_data);
+ Literal input_literal = LiteralUtil::CreateR1(input_data);
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
bool expected = and_reduce;
for (bool item : input_data) {
@@ -174,12 +174,11 @@ class ReduceTest : public ClientLibraryTestBase {
Array2D<uint8> input_data(rows, cols);
input_data.FillRandom(0, 1);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
- input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
+ input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::array<bool, cols> expected;
for (int64 colno = 0; colno < cols; ++colno) {
@@ -208,12 +207,11 @@ class ReduceTest : public ClientLibraryTestBase {
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
- input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
+ input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
float expected = 0.0;
for (int64 rowno = 0; rowno < rows; ++rowno) {
@@ -236,12 +234,11 @@ class ReduceTest : public ClientLibraryTestBase {
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
- input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
+ input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::vector<float> expected;
for (int64 colno = 0; colno < cols; ++colno) {
@@ -259,8 +256,8 @@ class ReduceTest : public ClientLibraryTestBase {
void ComputeAndCompareGeneric(
typename std::enable_if<std::is_floating_point<NativeT>::value,
XlaBuilder>::type* builder,
- tensorflow::gtl::ArraySlice<NativeT> expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ absl::Span<const NativeT> expected,
+ absl::Span<GlobalData* const> arguments) {
ComputeAndCompareR1<NativeT>(builder, expected, arguments,
ErrorSpec(0.01, 1e-4));
}
@@ -269,8 +266,8 @@ class ReduceTest : public ClientLibraryTestBase {
void ComputeAndCompareGeneric(
typename std::enable_if<std::is_integral<NativeT>::value,
XlaBuilder>::type* builder,
- tensorflow::gtl::ArraySlice<NativeT> expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ absl::Span<const NativeT> expected,
+ absl::Span<GlobalData* const> arguments) {
ComputeAndCompareR1<NativeT>(builder, expected, arguments);
}
@@ -294,15 +291,14 @@ class ReduceTest : public ClientLibraryTestBase {
Array2D<NativeT> input_data(rows, cols);
input_data.FillUnique(initial_value);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
- input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
+ input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
// NativeT can be bool, and std::vector<bool> does not convert to
- // ArraySlice.
+ // Span.
std::unique_ptr<NativeT[]> expected(new NativeT[cols]);
for (int64 colno = 0; colno < cols; ++colno) {
NativeT column_result = initial_value;
@@ -314,7 +310,7 @@ class ReduceTest : public ClientLibraryTestBase {
}
ComputeAndCompareGeneric<NativeT>(
- &builder, tensorflow::gtl::ArraySlice<NativeT>(expected.get(), cols),
+ &builder, absl::Span<const NativeT>(expected.get(), cols),
{input_global_data.get()});
}
@@ -351,8 +347,8 @@ class ReduceTest : public ClientLibraryTestBase {
reference_reduction_function_for_uints, unsigned_int_identity);
}
- std::unique_ptr<Literal> literal_2d_;
- std::unique_ptr<Literal> literal_3d_;
+ Literal literal_2d_;
+ Literal literal_3d_;
uint32 seed_ = 0xdeadbeef;
};
@@ -449,11 +445,10 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) {
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
- input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1}));
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
+ input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::vector<float> expected;
for (int64 colno = 0; colno < cols; ++colno) {
@@ -481,11 +476,10 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) {
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2D(input_data);
- input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1}));
+ Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
+ input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::vector<float> expected;
for (int64 colno = 0; colno < cols; ++colno) {
@@ -510,10 +504,9 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) {
XlaOp transpose = Transpose(input, /*permutation=*/{1, 0, 2});
Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0});
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> input_data,
- MakeFakeLiteral(input_shape));
+ TF_ASSERT_OK_AND_ASSIGN(Literal input_data, MakeFakeLiteral(input_shape));
- ComputeAndCompare(&builder, {std::move(*input_data)}, ErrorSpec(0.01, 1e-4));
+ ComputeAndCompare(&builder, {std::move(input_data)}, ErrorSpec(0.01, 1e-4));
}
XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) {
@@ -530,10 +523,9 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) {
Array3D<float> input_data(rows, 2, cols / 2);
input_data.FillRandom(3.14f, 0.04);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR3FromArray3D(input_data);
+ Literal input_literal = LiteralUtil::CreateR3FromArray3D(input_data);
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
std::vector<float> expected;
for (int64 major = 0; major < 2; ++major) {
@@ -556,12 +548,11 @@ struct BoundsLayout {
};
void PrintTo(const BoundsLayout& spec, std::ostream* os) {
- *os << tensorflow::strings::Printf(
- "R%luToR%lu%s_%s_Reduce%s", spec.bounds.size(),
- spec.bounds.size() - spec.reduce_dims.size(),
- tensorflow::str_util::Join(spec.bounds, "x").c_str(),
- tensorflow::str_util::Join(spec.layout, "").c_str(),
- tensorflow::str_util::Join(spec.reduce_dims, "").c_str());
+ *os << absl::StrFormat("R%uToR%u%s_%s_Reduce%s", spec.bounds.size(),
+ spec.bounds.size() - spec.reduce_dims.size(),
+ absl::StrJoin(spec.bounds, "x"),
+ absl::StrJoin(spec.layout, ""),
+ absl::StrJoin(spec.reduce_dims, ""));
}
// Add-reduces a broadcasted scalar matrix among dimension 1 and 0.
@@ -595,7 +586,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) {
Array2D<float> input(300, 250);
input.FillRandom(214.0f);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
- Reduce(ConstantLiteral(&builder, *input_literal),
+ Reduce(ConstantLiteral(&builder, input_literal),
ConstantR0<float>(&builder, FLT_MIN), max, {0, 1});
auto input_max = FLT_MIN;
input.Each(
@@ -610,7 +601,7 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) {
Array2D<float> input(150, 130);
input.FillRandom(214.0f);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
- Reduce(ConstantLiteral(&builder, *input_literal),
+ Reduce(ConstantLiteral(&builder, input_literal),
ConstantR0<float>(&builder, FLT_MAX), min, {0, 1});
auto input_min = FLT_MAX;
@@ -627,7 +618,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) {
auto initial_value =
ConstantR0<uint32>(&builder, std::numeric_limits<uint32>::max());
- Reduce(ConstantLiteral(&builder, *input_literal), initial_value, min, {0, 1});
+ Reduce(ConstantLiteral(&builder, input_literal), initial_value, min, {0, 1});
ComputeAndCompareR0<uint32>(&builder, 1, {});
}
@@ -639,14 +630,14 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) {
auto initial_value =
ConstantR0<uint32>(&builder, std::numeric_limits<uint32>::min());
- Reduce(ConstantLiteral(&builder, *input_literal), initial_value, max, {0, 1});
+ Reduce(ConstantLiteral(&builder, input_literal), initial_value, max, {0, 1});
ComputeAndCompareR0<uint32>(&builder, 2, {});
}
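
The two unsigned-int tests above seed the reduction with the reducer's identity element: numeric_limits::max() for a min-reduce and numeric_limits::min() for a max-reduce. A plain-C++ sketch of the same pattern:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Min-reduce seeded with the identity of min over uint32_t, matching the
// initial_value chosen in UnsignedInt_MinReduce above.
uint32_t MinReduce(const std::vector<uint32_t>& xs) {
  uint32_t acc = std::numeric_limits<uint32_t>::max();
  for (uint32_t x : xs) acc = std::min(acc, x);
  return acc;
}

int main() {
  std::cout << MinReduce({3, 1, 4, 1, 5}) << "\n";  // 1
}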
// Reduces a matrix among dimension 1.
XLA_TEST_F(ReduceTest, Reduce2DAmong1) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_2d_);
+ auto m = ConstantLiteral(&builder, literal_2d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1});
@@ -657,7 +648,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong1) {
XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) {
// Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar).
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_2d_);
+ auto m = ConstantLiteral(&builder, literal_2d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0, 1});
@@ -667,7 +658,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) {
// Tests the 2D matrix ReduceToRow operation.
XLA_TEST_F(ReduceTest, Reduce2DAmongY) {
XlaBuilder builder("reduce_among_y");
- auto m = ConstantLiteral(&builder, *literal_2d_);
+ auto m = ConstantLiteral(&builder, literal_2d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0});
@@ -677,7 +668,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmongY) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1, 2});
@@ -687,7 +678,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0, 1});
@@ -697,7 +688,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) {
XLA_TEST_F(ReduceTest, ReduceR3ToR0) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0, 1, 2});
@@ -707,7 +698,7 @@ XLA_TEST_F(ReduceTest, ReduceR3ToR0) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0});
@@ -722,7 +713,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1});
@@ -739,7 +730,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) {
XlaBuilder builder(TestName());
- auto m = ConstantLiteral(&builder, *literal_3d_);
+ auto m = ConstantLiteral(&builder, literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {2});
@@ -824,12 +815,12 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) {
auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array);
input_literal =
- input_literal->Relayout(LayoutUtil::MakeLayout(GetParam().layout));
+ input_literal.Relayout(LayoutUtil::MakeLayout(GetParam().layout));
std::unique_ptr<GlobalData> input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
auto input_activations =
- Parameter(&builder, 0, input_literal->shape(), "input");
+ Parameter(&builder, 0, input_literal.shape(), "input");
XlaComputation add = CreateScalarAddComputation(F32, &builder);
Reduce(input_activations, ConstantR0<float>(&builder, 0.0f), add,
GetParam().reduce_dims);
@@ -866,21 +857,17 @@ INSTANTIATE_TEST_CASE_P(
BoundsLayout{{2, 300, 784}, {2, 1, 0}, {1}},
BoundsLayout{{2, 300, 784}, {2, 1, 0}, {0}}));
-// TODO(b/64093391) Disabled on GPU due to an assertion failure when running
-// IrEmitterUnnested::EmitInitializer() for the Reduce operator. Failed on
-// 2017-07-26.
-XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) {
+XLA_TEST_F(ReduceTest, OperationOnConstantAsInitValue) {
XlaBuilder builder(TestName());
XlaComputation max_f32 = CreateScalarMaxComputation(F32, &builder);
auto a = ConstantR0<float>(&builder, 2.0f);
auto a2 = Abs(a);
- std::unique_ptr<Literal> b_literal =
- LiteralUtil::CreateR1<float>({1.0f, 4.0f});
+ Literal b_literal = LiteralUtil::CreateR1<float>({1.0f, 4.0f});
std::unique_ptr<GlobalData> b_data =
- client_->TransferToServer(*b_literal).ConsumeValueOrDie();
- auto b = Parameter(&builder, 0, b_literal->shape(), "b");
+ client_->TransferToServer(b_literal).ConsumeValueOrDie();
+ auto b = Parameter(&builder, 0, b_literal.shape(), "b");
Reduce(b, a2, max_f32, {0});
ComputeAndCompareR0<float>(&builder, 4.0f, {b_data.get()});
@@ -907,9 +894,9 @@ class ReduceInitializerTest : public ReduceTest {
std::vector<T> input_arr(num_elems, std::numeric_limits<T>::lowest());
auto input_literal = LiteralUtil::CreateR1<T>(input_arr);
auto input_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
- Reduce(Parameter(&builder, 0, input_literal->shape(), "input"), init,
- max_fn, {0});
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
+ Reduce(Parameter(&builder, 0, input_literal.shape(), "input"), init, max_fn,
+ {0});
ComputeAndCompareR0<T>(&builder, initializer, {input_data.get()});
}
@@ -955,13 +942,12 @@ XLA_TEST_F(ReduceTest, ReduceIdentity) {
float operand[] = {42.0f};
float init = 58.5f;
float expected = 42.0f;
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR1<float>(operand);
+ Literal input_literal = LiteralUtil::CreateR1<float>(operand);
std::unique_ptr<GlobalData> input_global_data =
- client_->TransferToServer(*input_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> input_literal2 = LiteralUtil::CreateR0<float>(init);
+ client_->TransferToServer(input_literal).ConsumeValueOrDie();
+ Literal input_literal2 = LiteralUtil::CreateR0<float>(init);
std::unique_ptr<GlobalData> input_global_data2 =
- client_->TransferToServer(*input_literal2).ConsumeValueOrDie();
+ client_->TransferToServer(input_literal2).ConsumeValueOrDie();
ComputeAndCompareR0<float>(
&builder, expected, {input_global_data.get(), input_global_data2.get()},
ErrorSpec(0.0001));
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 92c93f08b2..d5de9650f1 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -18,6 +18,10 @@ limitations under the License.
#include <limits>
#include <memory>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
@@ -35,7 +39,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -54,7 +57,7 @@ class ReduceWindowTestBase : public ClientLibraryTestBase {
public:
ErrorSpec DefaultErrorSpec() const {
if (use_bfloat16()) {
- return ErrorSpec(1e-1, 5e-2);
+ return ErrorSpec(2e-1, 6e-2);
} else {
return ErrorSpec(1e-3, 1e-3);
}
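
The looser bfloat16 tolerances reflect bf16's seven explicit mantissa bits. A standalone sketch, under the assumption that bf16 can be modeled by truncating an fp32 to its top 16 bits, of the rounding scale involved:

#include <cstdint>
#include <cstring>
#include <iostream>

// Truncates an fp32 to bf16 precision by dropping the low 16 mantissa bits;
// real conversions may round to nearest, but the error scale is the same.
float TruncToBF16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits &= 0xFFFF0000u;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  // 257 is not representable in bf16 (the spacing at 256 is 2), a relative
  // error of ~0.4% -- hence the far looser ErrorSpec on the bf16 path.
  std::cout << TruncToBF16(257.0f) << "\n";  // 256
}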
@@ -67,10 +70,10 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
ReduceWindowTest() : builder_(TestName()) { set_use_bfloat16(GetParam()); }
void ReduceWindowAdd(const XlaOp& input,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
Padding padding) {
- auto init = CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0.0f),
+ auto init = CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f),
&builder_);
ReduceWindow(input, init,
CreateScalarAddComputation(FloatType(), &builder_),
@@ -78,8 +81,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
}
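
These signature changes are source-compatible because absl::Span<const int64>, like the gtl::ArraySlice it replaces, binds to vectors, C arrays, and braced initializer lists alike. A small sketch assuming only absl/types/span.h:

#include <cstdint>
#include <iostream>
#include <vector>
#include "absl/types/span.h"

int64_t Sum(absl::Span<const int64_t> xs) {
  int64_t total = 0;
  for (int64_t x : xs) total += x;
  return total;
}

int main() {
  std::vector<int64_t> strides = {1, 4, 1, 1};
  // Both a vector and a braced list convert implicitly, so existing calls
  // such as ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding)
  // keep compiling unchanged.
  std::cout << Sum(strides) << " " << Sum({32, 128}) << "\n";  // 7 160
}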
void ReduceWindowMax(const XlaOp& input,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
Padding padding) {
auto init =
CreateConstantFromLiteral(LiteralUtil::MinValue(F32), &builder_);
@@ -89,8 +92,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
}
void ReduceWindowMin(const XlaOp& input,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
+ absl::Span<const int64> window_dimensions,
+ absl::Span<const int64> window_strides,
Padding padding) {
auto init =
CreateConstantFromLiteral(LiteralUtil::MaxValue(F32), &builder_);
@@ -104,9 +107,9 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>({1, 1, 1, 1}), &builder_);
+ LiteralUtil::CreateR1<float>({1, 1, 1, 1}), &builder_);
const auto init_value =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0), &builder_);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0), &builder_);
TF_ASSERT_OK(builder_.first_error());
ReduceWindow(input, init_value,
CreateScalarAddComputation(FloatType(), &builder_),
@@ -121,31 +124,31 @@ TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
// Regression test for b/68964348.
TEST_P(ReduceWindowTest, R0ReduceWindow) {
const auto input =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(42.0), &builder_);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(42.0), &builder_);
const auto init =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(1.0), &builder_);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(1.0), &builder_);
ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_),
/*window_dimensions=*/{},
/*window_strides=*/{}, Padding::kSame);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR0<float>(43.0), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR0<float>(43.0), {},
ErrorSpec(0.00001));
}
TEST_P(ReduceWindowTest, Min3In5Stride2) {
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
+ LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
ReduceWindowMin(input, {3}, {2}, Padding::kValid);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({100, 1}),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({100, 1}),
{}, ErrorSpec(0.00001));
}
TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) {
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
+ LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1},
Padding::kSame);
ComputeAndCompareLiteral(&builder_,
- *LiteralUtil::CreateR1<float>({1000, 100, 10, 1, 1}),
+ LiteralUtil::CreateR1<float>({1000, 100, 10, 1, 1}),
{}, ErrorSpec(0.00001));
}
@@ -158,7 +161,7 @@ XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
{1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -173,7 +176,7 @@ TEST_P(ReduceWindowTest, NonSquareSmall) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
{1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -187,7 +190,7 @@ TEST_P(ReduceWindowTest, MiddleDimsSmall) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1},
{1, 2, 2, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -204,7 +207,7 @@ TEST_P(ReduceWindowTest, Along2ndMinorDim) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -226,8 +229,8 @@ TEST_P(ReduceWindowTest, AmongMajor2Dims) {
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+ DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
@@ -249,8 +252,8 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+ DefaultErrorSpec());
}
// Tests the super windowing logic w.r.t. handling a prime number of windows in a
@@ -274,8 +277,8 @@ TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) {
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+ DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) {
@@ -291,8 +294,8 @@ TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) {
auto result = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+ DefaultErrorSpec());
}
// Tests a reduction function that is not a simple add/min/max/etc.
@@ -310,12 +313,12 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
auto lhs = Parameter(b.get(), 0, scalar, "lhs");
auto rhs = Parameter(b.get(), 1, scalar, "rhs");
Min(Add(lhs, rhs),
- CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(8.0f), b.get()));
+ CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(8.0f), b.get()));
XlaComputation reduce_fn = b->BuildAndNoteError();
ReduceWindow(
input,
- CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0.0f), &builder_),
+ CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f), &builder_),
reduce_fn,
/*window_dimensions=*/{1, 1, 2, 1},
/*window_strides=*/{1, 1, 1, 1}, padding);
@@ -329,19 +332,18 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
/*window=*/{1, 1, 2, 1},
/*stride=*/{1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*expected),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
{}, DefaultErrorSpec());
}
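
Under the assumption that the XlaBuilder graph above evaluates the way it reads, the nonstandard reducer is simply a clamped add in plain C++:

#include <algorithm>
#include <iostream>

// Mirrors Min(Add(lhs, rhs), 8.0f) from the sub-computation built above.
float NonstandardReduce(float lhs, float rhs) {
  return std::min(lhs + rhs, 8.0f);
}

int main() {
  std::cout << NonstandardReduce(5.0f, 6.0f) << "\n";  // 8
}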
TEST_P(ReduceWindowTest, R4UnitWindow) {
Array4D<float> input_array(13, 12, 8, 15);
input_array.FillRandom(2.f, 2.f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input_array, LayoutUtil::MakeLayout({0, 3, 2, 1}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input_array, LayoutUtil::MakeLayout({0, 3, 2, 1}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "parameter", &builder_, &input);
+ 0, input_literal, "parameter", &builder_, &input);
Padding padding = Padding::kSame;
ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding);
@@ -349,7 +351,7 @@ TEST_P(ReduceWindowTest, R4UnitWindow) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1},
{1, 4, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
@@ -357,9 +359,9 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) {
std::vector<int64> input_dims(6, 8);
auto shape = ShapeUtil::MakeShape(F32, input_dims);
- auto arg_literal = MakeUnique<Literal>(shape);
- arg_literal->PopulateWithValue(1.0f);
- const auto input = CreateConstantFromLiteral(*arg_literal, &builder_);
+ Literal arg_literal(shape);
+ arg_literal.PopulateWithValue(1.0f);
+ const auto input = CreateConstantFromLiteral(arg_literal, &builder_);
Padding padding = Padding::kValid;
ReduceWindowAdd(input, {3, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding);
@@ -368,39 +370,38 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) {
std::vector<int64> output_dims = {6, 8, 6, 6, 8, 8};
Shape result_shape =
ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout);
- auto expected = MakeUnique<Literal>(result_shape);
- expected->PopulateWithValue(27.0f);
- ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec());
+ Literal expected(result_shape);
+ expected.PopulateWithValue(27.0f);
+ ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
}
XLA_TEST_P(ReduceWindowTest, R6Add) {
std::vector<int64> input_dims(6, 8);
auto shape = ShapeUtil::MakeShape(F32, input_dims);
- std::unique_ptr<Literal> arg_literal =
+ Literal arg_literal =
LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
- const auto input = CreateConstantFromLiteral(*arg_literal, &builder_);
+ const auto input = CreateConstantFromLiteral(arg_literal, &builder_);
Padding padding = Padding::kValid;
ReduceWindowAdd(input, {1, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding);
std::vector<int64> output_dims = {8, 8, 6, 6, 8, 8};
- std::unique_ptr<Literal> expected =
+ Literal expected =
LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 9.0f);
- ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
}
XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) {
Array4D<float> input_array(2, 1, 27, 119);
input_array.FillRandom(2.0f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "parameter", &builder_, &input);
+ 0, input_literal, "parameter", &builder_, &input);
int win_len = 1;
int stride = 8;
@@ -410,19 +411,18 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) {
Array4D<float> input_array(3, 2, 4, 64);
input_array.FillRandom(2.0f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "parameter", &builder_, &input);
+ 0, input_literal, "parameter", &builder_, &input);
int win_len = 3;
int stride = 1;
@@ -432,19 +432,18 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
Array4D<float> input_array(1, 3, 12, 200);
input_array.FillRandom(2.0f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "parameter", &builder_, &input);
+ 0, input_literal, "parameter", &builder_, &input);
int win_len = 8;
int stride = 5;
@@ -454,7 +453,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
@@ -475,18 +474,18 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) {
auto result = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+ DefaultErrorSpec());
}
XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) {
std::vector<float> input_vector(128 * 9, 1);
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>(input_vector), &builder_);
+ LiteralUtil::CreateR1<float>(input_vector), &builder_);
ReduceWindowAdd(input, {32}, {128}, Padding::kValid);
ComputeAndCompareLiteral(
&builder_,
- *LiteralUtil::CreateR1<float>({32, 32, 32, 32, 32, 32, 32, 32, 32}), {},
+ LiteralUtil::CreateR1<float>({32, 32, 32, 32, 32, 32, 32, 32, 32}), {},
DefaultErrorSpec());
}
@@ -501,9 +500,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>(input_vector), &builder_);
+ LiteralUtil::CreateR1<float>(input_vector), &builder_);
ReduceWindowAdd(input, {128}, {128}, Padding::kValid);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({1088}), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({1088}), {},
DefaultErrorSpec());
}
@@ -518,9 +517,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128) {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
const auto input = CreateConstantFromLiteral(
- *LiteralUtil::CreateR1<float>(input_vector), &builder_);
+ LiteralUtil::CreateR1<float>(input_vector), &builder_);
ReduceWindowAdd(input, {128}, {1}, Padding::kValid);
- ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({1088}), {},
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({1088}), {},
DefaultErrorSpec());
}
@@ -537,9 +536,8 @@ TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) {
auto res = ReferenceUtil::ReduceWindow2DAdd(
input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding);
- ComputeAndCompareLiteral(&builder_,
- *LiteralUtil::CreateFromArray<float>(*res), {},
- DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray<float>(*res),
+ {}, DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
@@ -553,9 +551,8 @@ TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3},
padding);
- ComputeAndCompareLiteral(&builder_,
- *LiteralUtil::CreateFromArray<float>(*res), {},
- DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray<float>(*res),
+ {}, DefaultErrorSpec());
}
INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest,
@@ -578,21 +575,20 @@ string R4ReduceWindowTestDataToString(
const ::testing::TestParamInfo<
::testing::tuple<R4ReduceWindowTestData, bool>>& data) {
const auto& param = ::testing::get<0>(data.param);
- string str = tensorflow::strings::StrCat(
- "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), //
- "__window_bounds_",
- tensorflow::str_util::Join(param.window_bounds, "x"), //
- "__strides_", tensorflow::str_util::Join(param.strides, "x"), //
- "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), //
- "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), //
- "__layout_", tensorflow::str_util::Join(param.layout, "_"), //
+ string str = absl::StrCat(
+ "base_bounds_", absl::StrJoin(param.base_bounds, "x"), //
+ "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), //
+ "__strides_", absl::StrJoin(param.strides, "x"), //
+ "__pad_low_", absl::StrJoin(param.pad_low, "x"), //
+ "__pad_high_", absl::StrJoin(param.pad_high, "x"), //
+ "__layout_", absl::StrJoin(param.layout, "_"), //
(param.reducer == kAdd) ? "_add" : "_max");
CHECK(param.reducer == kAdd || param.reducer == kMax);
// Test names are not allowed to contain the '-' character.
std::replace(str.begin(), str.end(), '-', 'n');
if (::testing::get<1>(data.param)) {
- str = tensorflow::strings::StrCat(str, "_bfloat16");
+ str = absl::StrCat(str, "_bfloat16");
}
return str;
}
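
A tiny standalone illustration of the sanitization step above: gtest-generated test names may contain only alphanumerics and underscores, so any '-' (e.g. from a negative bound or padding value in the parameter string) is rewritten to 'n'.

#include <algorithm>
#include <iostream>
#include <string>

int main() {
  std::string name = "pad_low_-5";  // hypothetical generated name fragment
  std::replace(name.begin(), name.end(), '-', 'n');
  std::cout << name << "\n";  // pad_low_n5
}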
@@ -611,12 +607,11 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
Array4D<float> input(param.base_bounds[0], param.base_bounds[1],
param.base_bounds[2], param.base_bounds[3]);
- input.FillIota(1);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout(param.layout));
+ input.FillRandom(0.1f, 0.1f);
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout(param.layout));
XlaOp parameter;
- auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
+ auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0",
&b, &parameter);
std::vector<std::pair<int64, int64>> padding(4);
@@ -625,9 +620,16 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
}
auto init_value =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
CHECK(param.reducer == kAdd || param.reducer == kMax);
- auto computation = param.reducer == kAdd
+ auto reducer = param.reducer;
+ if (use_bfloat16() && Product(param.window_bounds) > 128) {
+ // To avoid numerical issues, force the reducer to be kMax for large bf16
+ // windows.
+ reducer = kMax;
+ }
+
+ auto computation = reducer == kAdd
? CreateScalarAddComputation(FloatType(), &b)
: CreateScalarMaxComputation(FloatType(), &b);
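
Why Product(window_bounds) > 128 matters for an add reducer: once a bf16 running sum grows to where its spacing exceeds the addends, further additions are silently lost. A self-contained sketch, again modeling bf16 as fp32 truncated to its top 16 bits after every add:

#include <cstdint>
#include <cstring>
#include <iostream>

static float Trunc(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits &= 0xFFFF0000u;  // keep sign, exponent, 7 mantissa bits
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  float acc = 0.0f;
  for (int i = 0; i < 1024; ++i) acc = Trunc(acc + 0.1f);
  // The exact sum is 102.4, but the simulated bf16 accumulator stalls at 16,
  // where its spacing (0.125) first exceeds the 0.1 addend -- hence forcing
  // kMax, which is exact, for large windows.
  std::cout << acc << "\n";  // ~16
}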
ReduceWindowWithGeneralPadding(
@@ -638,8 +640,8 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
/*window_strides=*/param.strides,
/*padding=*/padding);
- CHECK(param.reducer == kAdd || param.reducer == kMax);
- auto reduce_func = param.reducer == kAdd
+ CHECK(reducer == kAdd || reducer == kMax);
+ auto reduce_func = reducer == kAdd
? +[](float a, float b) { return a + b; }
: +[](float a, float b) { return std::max(a, b); };
std::unique_ptr<Array4D<float>> expected =
@@ -650,12 +652,11 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
/*window=*/param.window_bounds,
/*stride=*/param.strides,
/*padding=*/padding);
- std::unique_ptr<Literal> expected_literal =
- LiteralUtil::CreateFromArray(*expected);
+ Literal expected_literal = LiteralUtil::CreateFromArray(*expected);
const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
- input_literal->shape().element_type(),
- AsInt64Slice(expected_literal->shape().dimensions()), param.layout);
- ComputeAndCompareLiteral(&b, *expected_literal, {input_arg.get()},
+ input_literal.shape().element_type(),
+ AsInt64Slice(expected_literal.shape().dimensions()), param.layout);
+ ComputeAndCompareLiteral(&b, expected_literal, {input_arg.get()},
DefaultErrorSpec(), &expected_shape_with_layout);
}
};
@@ -807,6 +808,22 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*pad_high=*/{1, 0, 0, 0},
/*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd},
+
+ R4ReduceWindowTestData{/*base_bounds=*/{8, 256, 256, 3},
+ /*window_bounds=*/{1, 64, 64, 1},
+ /*strides=*/{1, 64, 64, 1},
+ /*pad_low=*/{0, 0, 0, 0},
+ /*pad_high=*/{0, 0, 0, 0},
+ /*layout=*/{3, 0, 2, 1},
+ /*reducer=*/kAdd},
+
+ R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 8, 64},
+ /*window_bounds=*/{112, 112, 1, 8},
+ /*strides=*/{112, 112, 1, 8},
+ /*pad_low=*/{0, 0, 0, 0},
+ /*pad_high=*/{0, 0, 0, 0},
+ /*layout=*/{3, 2, 1, 0},
+ /*reducer=*/kAdd},
};
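
For the new non-overlapping cases, valid padding gives an output extent of (base - window) / stride + 1 per dimension, so the {8, 256, 256, 3} case with 64x64 windows and strides reduces each 256-wide spatial dimension to 4 cells:

#include <iostream>

// Valid-padding reduce-window output extent along one dimension.
int OutDim(int base, int window, int stride) {
  return (base - window) / stride + 1;
}

int main() {
  std::cout << OutDim(256, 64, 64) << "\n";  // 4
}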
INSTANTIATE_TEST_CASE_P(
@@ -928,21 +945,42 @@ struct R3ReduceWindowTestData {
{/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
/*strides=*/{1, 2, 2}, /*layout=*/{1, 0, 2},
/*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+ {/*base_bounds=*/{95, 202, 251}, /*window_bounds=*/{95, 202, 251},
+ /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
+ /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+ {/*base_bounds=*/{999, 57, 3}, /*window_bounds=*/{999, 57, 3},
+ /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
+ /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+ {/*base_bounds=*/{178, 302, 64}, /*window_bounds=*/{178, 302, 64},
+ /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
+ /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+ {/*base_bounds=*/{63, 261, 257}, /*window_bounds=*/{63, 261, 257},
+ /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
+ /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+ {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3},
+ /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
+ /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+ {/*base_bounds=*/{9999, 1, 1}, /*window_bounds=*/{9999, 1, 1},
+ /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
+ /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+ {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3},
+ /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
+ /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
};
string R3ReduceWindowTestDataToString(
const ::testing::TestParamInfo<
::testing::tuple<R3ReduceWindowTestData, bool>>& data) {
const auto& param = ::testing::get<0>(data.param);
- string str = tensorflow::strings::StrCat(
- "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"),
- "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"),
- "__strides_", tensorflow::str_util::Join(param.strides, "x"),
- "__padding_", param.padding == Padding::kSame ? "same" : "valid",
- "__layout_", param.layout[0], "_", param.layout[1], "_", param.layout[2],
- "__reducer_", param.reducer == kAdd ? "add" : "max");
+ string str = absl::StrCat(
+ "base_bounds_", absl::StrJoin(param.base_bounds, "x"), "__window_bounds_",
+ absl::StrJoin(param.window_bounds, "x"), "__strides_",
+ absl::StrJoin(param.strides, "x"), "__padding_",
+ param.padding == Padding::kSame ? "same" : "valid", "__layout_",
+ param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_",
+ param.reducer == kAdd ? "add" : "max");
if (::testing::get<1>(data.param)) {
- str = tensorflow::strings::StrCat(str, "_bfloat16");
+ str = absl::StrCat(str, "_bfloat16");
}
return str;
}
@@ -954,35 +992,41 @@ class R3ReduceWindowTest : public ReduceWindowTestBase,
R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
};
-TEST_P(R3ReduceWindowTest, Add) {
+TEST_P(R3ReduceWindowTest, DoIt) {
XlaBuilder b(TestName());
const auto& param = ::testing::get<0>(GetParam());
- CHECK(param.reducer == kAdd);
const float kInitValue = 0.0f;
Array3D<float> input(param.base_bounds[0], param.base_bounds[1],
- param.base_bounds[2], 1.0f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR3FromArray3DWithLayout(
- input, LayoutUtil::MakeLayout(param.layout));
+ param.base_bounds[2]);
+ input.FillRandom(0.1f, 0.1f);
+ Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout(
+ input, LayoutUtil::MakeLayout(param.layout));
+ auto reducer = param.reducer;
+ if (use_bfloat16()) {
+ input_literal = LiteralUtil::ConvertF32ToBF16(input_literal);
+ if (Product(param.window_bounds) > 128) {
+ // To avoid numerical issues, force the reducer to be kMax for large bf16
+ // windows.
+ reducer = kMax;
+ }
+ }
- XlaOp parameter;
- auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
- &b, &parameter);
+ XlaOp parameter = Parameter(&b, 0, input_literal.shape(), "input");
auto init_value =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
+
+ auto computation = reducer == kAdd
+ ? CreateScalarAddComputation(FloatType(), &b)
+ : CreateScalarMaxComputation(FloatType(), &b);
+
ReduceWindow(/*operand=*/parameter,
/*init_value=*/init_value,
- /*computation=*/CreateScalarAddComputation(FloatType(), &b),
+ /*computation=*/computation,
/*window_dimensions=*/param.window_bounds,
/*window_strides=*/param.strides, /*padding=*/param.padding);
- auto expected = ReferenceUtil::ReduceWindow3DAdd(
- /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds,
- /*stride=*/param.strides, /*padding=*/param.padding);
-
- ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected),
- {input_arg.get()}, DefaultErrorSpec());
+ ComputeAndCompare(&b, {std::move(input_literal)}, DefaultErrorSpec());
}
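
The Product(param.window_bounds) guard presumably multiplies the window extents together; a stand-in sketch under that assumption:

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Assumed behavior of the Product() helper used in the bf16 guards above.
int64_t Product(const std::vector<int64_t>& dims) {
  return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

int main() {
  // A {2, 3, 2} window: 12 <= 128, so the add reducer is kept even in bf16.
  std::cout << Product({2, 3, 2}) << "\n";  // 12
}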
INSTANTIATE_TEST_CASE_P(
@@ -1068,17 +1112,16 @@ string R2ReduceWindowTestDataToString(
const ::testing::TestParamInfo<
::testing::tuple<R2ReduceWindowTestData, bool>>& data) {
const auto& param = ::testing::get<0>(data.param);
- string str = tensorflow::strings::StrCat(
- "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), //
- "__window_bounds_",
- tensorflow::str_util::Join(param.window_bounds, "x"), //
- "__strides_", tensorflow::str_util::Join(param.strides, "x"), //
- "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"),
- "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"),
- "__layout_", param.layout[0], "_", param.layout[1], //
+ string str = absl::StrCat(
+ "base_bounds_", absl::StrJoin(param.base_bounds, "x"), //
+ "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), //
+ "__strides_", absl::StrJoin(param.strides, "x"), //
+ "__pad_low_", absl::StrJoin(param.pad_low, "x"), "__pad_high_",
+ absl::StrJoin(param.pad_high, "x"), "__layout_", param.layout[0], "_",
+ param.layout[1], //
"__reducer_", param.reducer == kAdd ? "add" : "max");
if (::testing::get<1>(data.param)) {
- str = tensorflow::strings::StrCat(str, "_bfloat16");
+ str = absl::StrCat(str, "_bfloat16");
}
return str;
}
@@ -1092,16 +1135,14 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
void DoIt() {
XlaBuilder b(TestName());
const auto& param = ::testing::get<0>(GetParam());
- CHECK(param.reducer == kAdd);
const float kInitValue = 0.0f;
Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR2FromArray2DWithLayout(
- input, LayoutUtil::MakeLayout(param.layout));
+ Literal input_literal = LiteralUtil::CreateR2FromArray2DWithLayout(
+ input, LayoutUtil::MakeLayout(param.layout));
XlaOp parameter;
- auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
+ auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0",
&b, &parameter);
std::vector<std::pair<int64, int64>> padding(2);
for (int i = 0; i < 2; ++i) {
@@ -1111,7 +1152,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
? CreateScalarAddComputation(FloatType(), &b)
: CreateScalarMaxComputation(FloatType(), &b);
auto init_value =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
ReduceWindowWithGeneralPadding(
/*operand=*/parameter,
/*init_value=*/init_value,
@@ -1127,7 +1168,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
/*window=*/param.window_bounds,
/*stride=*/param.strides, /*padding=*/padding);
- ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected),
+ ComputeAndCompareLiteral(&b, LiteralUtil::CreateFromArray(*expected),
{input_arg.get()}, DefaultErrorSpec());
}
};
@@ -1261,21 +1302,27 @@ struct R1ReduceWindowTestData {
/*pad_low=*/{5},
/*pad_high=*/{0},
/*reducer=*/Reducer::kAdd},
+
+ {/*base_bounds=*/{4096}, /*window_bounds=*/{4096},
+ /*strides=*/{1},
+ /*pad_low=*/{4095},
+ /*pad_high=*/{0},
+ /*reducer=*/Reducer::kMax},
};
string R1ReduceWindowTestDataToString(
const ::testing::TestParamInfo<
::testing::tuple<R1ReduceWindowTestData, bool>>& data) {
const auto& param = ::testing::get<0>(data.param);
- string str = tensorflow::strings::StrCat(
- "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"),
- "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"),
- "__strides_", tensorflow::str_util::Join(param.strides, "x"),
- "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"),
- "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"),
- "__reducer_", param.reducer == kAdd ? "add" : "max");
+ string str =
+ absl::StrCat("base_bounds_", absl::StrJoin(param.base_bounds, "x"),
+ "__window_bounds_", absl::StrJoin(param.window_bounds, "x"),
+ "__strides_", absl::StrJoin(param.strides, "x"),
+ "__pad_low_", absl::StrJoin(param.pad_low, "x"),
+ "__pad_high_", absl::StrJoin(param.pad_high, "x"),
+ "__reducer_", param.reducer == kAdd ? "add" : "max");
if (::testing::get<1>(data.param)) {
- str = tensorflow::strings::StrCat(str, "_bfloat16");
+ str = absl::StrCat(str, "_bfloat16");
}
return str;
}
@@ -1295,11 +1342,11 @@ TEST_P(R1ReduceWindowTest, DoIt) {
const float kInitValue = 0.0f;
std::vector<float> input_vector(param.base_bounds[0]);
std::iota(std::begin(input_vector), std::end(input_vector), 0);
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR1(tensorflow::gtl::ArraySlice<float>(input_vector));
+ Literal input_literal =
+ LiteralUtil::CreateR1(absl::Span<const float>(input_vector));
XlaOp parameter;
- auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
- &b, &parameter);
+ auto input_arg =
+ CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, &parameter);
std::vector<std::pair<int64, int64>> padding(1);
padding[0] = {param.pad_low[0], param.pad_high[0]};
@@ -1308,7 +1355,7 @@ TEST_P(R1ReduceWindowTest, DoIt) {
? CreateScalarAddComputation(FloatType(), &b)
: CreateScalarMaxComputation(FloatType(), &b);
auto init_value =
- CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+ CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
ReduceWindowWithGeneralPadding(
/*operand=*/parameter,
/*init_value=*/init_value,
@@ -1320,14 +1367,14 @@ TEST_P(R1ReduceWindowTest, DoIt) {
? +[](float a, float b) { return a + b; }
: +[](float a, float b) { return std::max(a, b); };
auto expected = ReferenceUtil::ReduceWindow1DGeneric(
- /*operand=*/tensorflow::gtl::ArraySlice<float>(input_vector),
+ /*operand=*/absl::Span<const float>(input_vector),
/*init=*/kInitValue,
/*reduce_func=*/reduce_func,
/*window=*/param.window_bounds,
/*stride=*/param.strides,
/*padding=*/padding);
- ComputeAndCompareLiteral(&b, *LiteralUtil::CreateR1<float>(*expected),
+ ComputeAndCompareLiteral(&b, LiteralUtil::CreateR1<float>(*expected),
{input_arg.get()}, DefaultErrorSpec());
}
@@ -1442,7 +1489,7 @@ ENTRY reduce-window-identity {
}
)";
- EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt));
+ EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt));
}
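
RunAndCompare takes an optional error spec; passing absl::nullopt (previously tensorflow::gtl::nullopt) requests exact comparison. A hedged sketch of that optional-tolerance pattern, not the real RunAndCompare signature:

#include <cmath>
#include <iostream>
#include "absl/types/optional.h"

// Disengaged optional => exact equality; engaged => compare within tol.
bool Compare(double a, double b, absl::optional<double> tol) {
  if (!tol) return a == b;
  return std::abs(a - b) <= *tol;
}

int main() {
  std::cout << Compare(1.0, 1.0, absl::nullopt) << " "
            << Compare(1.0, 1.05, 0.1) << "\n";  // 1 1
}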
XLA_TEST_F(HloTestBase, ReduceWindowS32) {
@@ -1461,7 +1508,7 @@ ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] {
}
)";
- EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt));
+ EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt));
}
XLA_TEST_F(HloTestBase, ReduceWindowF16) {
@@ -1480,7 +1527,7 @@ ENTRY %reduce-window (parameter.0: f16[81,8], parameter.1: f16[]) -> f16[82,8] {
}
)";
- EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt));
+ EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc
index d891451381..5cf87e565b 100644
--- a/tensorflow/compiler/xla/tests/replay_test.cc
+++ b/tensorflow/compiler/xla/tests/replay_test.cc
@@ -58,13 +58,13 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) {
ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
// Run it.
- std::unique_ptr<Literal> literal =
+ Literal literal =
client_
->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_)
.ConsumeValueOrDie();
// Expect 4.
- LiteralTestUtil::ExpectR0Equal<int32>(4, *literal);
+ LiteralTestUtil::ExpectR0Equal<int32>(4, literal);
}
XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
@@ -91,12 +91,12 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
// Run it.
std::unique_ptr<GlobalData> x_data =
- client_->TransferToServer(*LiteralUtil::CreateR0<int32>(2))
+ client_->TransferToServer(LiteralUtil::CreateR0<int32>(2))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> y_data =
- client_->TransferToServer(*LiteralUtil::CreateR0<int32>(3))
+ client_->TransferToServer(LiteralUtil::CreateR0<int32>(3))
.ConsumeValueOrDie();
- std::unique_ptr<Literal> literal =
+ Literal literal =
client_
->ExecuteAndTransfer(replayed,
/*arguments=*/{x_data.get(), y_data.get()},
@@ -104,7 +104,7 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
.ConsumeValueOrDie();
// Expect 5.
- LiteralTestUtil::ExpectR0Equal<int32>(5, *literal);
+ LiteralTestUtil::ExpectR0Equal<int32>(5, literal);
}
TEST_F(ReplayTest, MapPlusTwoOverR1) {
@@ -136,13 +136,13 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) {
ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
// Run it.
- std::unique_ptr<Literal> literal =
+ Literal literal =
client_
->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_)
.ConsumeValueOrDie();
// Expect result.
- LiteralTestUtil::ExpectR1Equal<int32>({3, 4, 5}, *literal);
+ LiteralTestUtil::ExpectR1Equal<int32>({3, 4, 5}, literal);
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc
index 368f5583c9..ae24eb5eb4 100644
--- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <random>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
@@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc
index 382d1b1ae7..dedc95b5ae 100644
--- a/tensorflow/compiler/xla/tests/reshape_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <random>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
@@ -35,7 +36,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -57,12 +57,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) {
input_array.Fill(1.0f);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -70,12 +70,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{});
auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -83,12 +83,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0});
auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -99,29 +99,29 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) {
input_array.Fill(1.0f);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter",
&builder, &parameter);
auto reshape = Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
/*new_sizes=*/{});
auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie();
auto expected_literal = LiteralUtil::CreateR0<float>(1.0f);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(1.0f);
+ Literal param0_literal = LiteralUtil::CreateR0<float>(1.0f);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0",
+ auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0",
&builder, &parameter);
auto a = Neg(parameter);
Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1});
auto expected_literal = LiteralUtil::CreateR1<float>({-1.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -130,25 +130,25 @@ XLA_TEST_P(ReshapeTest, Trivial0x3) {
Array2D<float> input_array(0, 3);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal =
+ Literal param0_literal =
LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(0, 3));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0",
+ auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -157,11 +157,11 @@ XLA_TEST_P(ReshapeTest, Trivial3x0) {
Array2D<float> input_array(3, 0);
auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -170,11 +170,11 @@ XLA_TEST_P(ReshapeTest, Trivial1x3) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -183,11 +183,11 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateR2<float>({{1.0f}, {2.0f}, {3.0f}});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
auto expected_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -196,12 +196,12 @@ XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0},
/*new_sizes=*/{2, 0});
auto expected_literal = LiteralUtil::CreateR2<float>({{}, {}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -211,13 +211,13 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) {
auto input_literal =
LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0},
/*new_sizes=*/{2, 3});
auto expected_literal =
LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -226,12 +226,12 @@ XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(0, 2));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
/*new_sizes=*/{2, 0});
auto expected_literal = LiteralUtil::CreateR2<float>({{}, {}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -241,14 +241,14 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) {
auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3);
auto input_literal = LiteralUtil::CreateFromArray(*simple);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
/*new_sizes=*/{3, 1});
auto expected = ReferenceUtil::TransposeArray2D(*simple);
auto expected_literal = LiteralUtil::CreateFromArray(*expected);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -258,14 +258,14 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) {
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
/*new_sizes=*/{3, 4});
auto expected = ReferenceUtil::TransposeArray2D(*a4x3);
auto expected_literal = LiteralUtil::CreateFromArray(*expected);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -274,11 +274,11 @@ XLA_TEST_P(ReshapeTest, Transpose0x4) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(0, 4));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Transpose(parameter, {1, 0});
auto expected_literal = LiteralUtil::CreateR2<float>({{}, {}, {}, {}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -288,13 +288,13 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) {
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Transpose(parameter, {1, 0});
auto expected = ReferenceUtil::TransposeArray2D(*a4x3);
auto expected_literal = LiteralUtil::CreateFromArray(*expected);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -304,13 +304,13 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffleZeroElements) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(6, 0));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
/*new_sizes=*/{2, 3, 0, 0});
auto expected_literal =
LiteralUtil::CreateFromArray(Array4D<float>(2, 3, 0, 0));
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -318,12 +318,12 @@ XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(Array4D<float>(2, 3, 4, 0));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3},
/*new_sizes=*/{24, 0});
auto expected_literal = LiteralUtil::CreateFromArray(Array2D<float>(24, 0));
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -334,14 +334,14 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) {
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
/*new_sizes=*/{2, 6});
auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6);
auto expected_literal = LiteralUtil::CreateFromArray(*expected);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -349,12 +349,12 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(0, 6));
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
/*new_sizes=*/{3, 0});
auto expected_literal = LiteralUtil::CreateFromArray(Array2D<float>(3, 0));
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -365,14 +365,14 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) {
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
/*new_sizes=*/{2, 6});
Array2D<float> expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f},
{8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}});
auto expected_literal = LiteralUtil::CreateFromArray(expected);
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -391,14 +391,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2},
/*new_sizes=*/{24});
auto expected_literal = LiteralUtil::CreateR1<float>(
{10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -406,7 +406,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2},
/*new_sizes=*/{8, 3});
@@ -418,7 +418,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) {
{35, 36, 37},
{40, 41, 42},
{45, 46, 47}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -426,14 +426,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
/*new_sizes=*/{24});
auto expected_literal = LiteralUtil::CreateR1<float>(
{10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -441,7 +441,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
/*new_sizes=*/{8, 3});
@@ -453,7 +453,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) {
{45, 16, 26},
{36, 46, 17},
{27, 37, 47}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -461,14 +461,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) {
XlaBuilder builder(TestName());
auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
/*new_sizes=*/{2, 6, 2});
auto expected_literal = LiteralUtil::CreateR3<float>(
{{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}},
{{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -494,14 +494,14 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) {
t2x2x2x3.FillWithYX(*filler2x3);
auto input_literal = LiteralUtil::CreateFromArray(t2x2x2x3);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3});
auto expected_literal = LiteralUtil::CreateR2<float>(
{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
6.0f}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -519,14 +519,14 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) {
t(1, 0, 1, 1) = 7;
auto input_literal = LiteralUtil::CreateFromArray(t);
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3},
/*new_sizes=*/{2, 4});
auto expected_literal =
LiteralUtil::CreateR2<float>({{0, 1, 2, 3}, {4, 5, 6, 7}});
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -547,7 +547,7 @@ XLA_TEST_P(ReshapeTest, ToScalar) {
Reshape(parameter, dimensions, {});
auto expected_literal = LiteralUtil::CreateR0<float>(83.0f);
- ComputeAndCompareLiteral(&b, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&b, expected_literal, {input.get()},
zero_error_spec_);
}
}
@@ -556,7 +556,7 @@ XLA_TEST_P(ReshapeTest, BadDimensions) {
XlaBuilder b(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b,
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b,
&parameter);
Reshape(parameter, {}, {});
EXPECT_THAT(
@@ -568,7 +568,7 @@ XLA_TEST_P(ReshapeTest, BadNewSizes) {
XlaBuilder b(TestName());
auto input_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b,
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b,
&parameter);
Reshape(parameter, {1}, {});
EXPECT_THAT(ExecuteToString(&b, {}),
@@ -604,7 +604,7 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
LayoutUtil::MakeLayout({0, 1, 2, 3}));
// clang-format on
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8});
@@ -619,27 +619,26 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
*execution_options.mutable_shape_with_output_layout() =
ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {2, 8},
{1, 0});
- std::unique_ptr<Literal> actual =
+ Literal actual =
client_
->ExecuteAndTransfer(computation, {input.get()}, &execution_options)
.ConsumeValueOrDie();
- std::unique_ptr<Literal> expected =
- LiteralUtil::CreateR2FromArray2D<float>(expected_array);
+ Literal expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
if (use_bfloat16()) {
- expected = LiteralUtil::ConvertF32ToBF16(*expected);
+ expected = LiteralUtil::ConvertF32ToBF16(expected);
}
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR2<float>({
+ Literal input_literal = LiteralUtil::CreateR2<float>({
{0, 1, 2, 3, 4, 5, 6, 7},
{100, 101, 102, 103, 104, 105, 106, 107},
{200, 201, 202, 203, 204, 205, 206, 207},
});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4});
@@ -653,20 +652,20 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
{{204, 205, 206, 207}}}
});
// clang-format on
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
// Tests R2->R4 reshape with the reshape dimensions {1, 0}.
XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR2<float>({
+ Literal input_literal = LiteralUtil::CreateR2<float>({
{0, 1, 2, 3, 4, 5, 6, 7},
{100, 101, 102, 103, 104, 105, 106, 107},
{200, 201, 202, 203, 204, 205, 206, 207},
});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&builder, &parameter);
Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4});
@@ -680,7 +679,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) {
{{206, 7, 107, 207}}}
});
// clang-format on
- ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+ ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
zero_error_spec_);
}
@@ -689,20 +688,17 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) {
std::mt19937 rng;
std::uniform_real_distribution<float> distribution;
Array4D<float> input(2, 1, 1, 1);
- input.Each(
- [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
- float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1});
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, *input_literal);
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ Literal expected = LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, input_literal);
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
zero_error_spec_);
}
@@ -711,20 +707,17 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) {
std::mt19937 rng;
std::uniform_real_distribution<float> distribution;
Array4D<float> input(2, 1, 4, 1);
- input.Each(
- [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
- float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2});
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, *input_literal);
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ Literal expected = LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, input_literal);
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
zero_error_spec_);
}
@@ -734,25 +727,23 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) {
std::mt19937 rng;
std::uniform_real_distribution<float> distribution;
Array4D<float> input(5, 10, 2, 3);
- input.Each(
- [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
- float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 2, 1, 3},
/*new_sizes=*/{5, 60});
Array2D<float> expected_array(5, 60);
- input.Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* cell) {
+ input.Each([&](absl::Span<const int64> indices, float* cell) {
expected_array(indices[0], indices[2] * 30 + indices[1] * 3 + indices[3]) =
*cell;
});
auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
zero_error_spec_);
}
@@ -762,14 +753,13 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
std::uniform_real_distribution<float> distribution;
Array4D<float> input_array(2, 3, 5, 7);
input_array.Each(
- [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
+ [&rng, &distribution](absl::Span<const int64> /* indices */,
float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input_array, LayoutUtil::MakeLayout({1, 2, 3, 0}));
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input_array, LayoutUtil::MakeLayout({1, 2, 3, 0}));
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{3, 0, 1, 2},
/*new_sizes=*/{7, 2, 3, 5});
XlaComputation computation = builder.Build().ConsumeValueOrDie();
@@ -778,7 +768,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
*execution_options.mutable_shape_with_output_layout() =
ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {7, 2, 3, 5},
{2, 3, 0, 1});
- std::unique_ptr<Literal> output_literal =
+ Literal output_literal =
client_
->ExecuteAndTransfer(computation, {input_data.get()},
&execution_options)
@@ -787,10 +777,10 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
// Since the reshape is a no-op, verify that it does not change the underlying
// data.
if (use_bfloat16()) {
- auto expected = LiteralUtil::ConvertF32ToBF16(*input_literal);
- EXPECT_EQ(expected->data<bfloat16>(), output_literal->data<bfloat16>());
+ auto expected = LiteralUtil::ConvertF32ToBF16(input_literal);
+ EXPECT_EQ(expected.data<bfloat16>(), output_literal.data<bfloat16>());
} else {
- EXPECT_EQ(input_literal->data<float>(), output_literal->data<float>());
+ EXPECT_EQ(input_literal.data<float>(), output_literal.data<float>());
}
}
@@ -801,12 +791,12 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) {
{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input",
+ auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input",
&builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 2, 3},
/*new_sizes=*/{1, 2, 3, 4});
- ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {input.get()});
+ ComputeAndCompareLiteral(&builder, literal_1x2x3x4, {input.get()});
}
XLA_TEST_P(ReshapeTest, R4ToR4Reshape) {
@@ -816,7 +806,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) {
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input",
+ auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input",
&builder, &parameter);
Reshape(parameter, /*dimensions=*/{1, 3, 2, 0},
/*new_sizes=*/{2, 4, 3, 1});
@@ -833,7 +823,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) {
{{16}, {20}, {24}}}});
// clang-format on
- ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {input.get()});
+ ComputeAndCompareLiteral(&builder, expected_2x4x3x1, {input.get()});
}
XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) {
@@ -842,27 +832,25 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) {
std::vector<int64> bounds = {2, 2, 2, 2};
std::vector<int64> new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]};
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
- input.Each(
- [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
- float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
/*new_sizes=*/new_bounds);
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
- ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal expected =
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal)
+ .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
// actually corresponds to a two minor transpose.
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
- zero_error_spec_, &expected->shape());
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+ zero_error_spec_, &expected.shape());
}
XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) {
@@ -871,27 +859,25 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) {
std::vector<int64> bounds = {1, 1, 250, 300};
std::vector<int64> new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]};
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
- input.Each(
- [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
- float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
/*new_sizes=*/new_bounds);
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
- ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal expected =
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal)
+ .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
// actually corresponds to a two minor transpose.
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
- zero_error_spec_, &expected->shape());
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+ zero_error_spec_, &expected.shape());
}
XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) {
@@ -900,27 +886,25 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) {
std::vector<int64> bounds = {5, 5, 1, 10};
std::vector<int64> new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]};
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
- input.Each(
- [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
- float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
/*new_sizes=*/new_bounds);
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
- ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal expected =
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal)
+ .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
// actually corresponds to a two minor transpose.
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
- zero_error_spec_, &expected->shape());
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+ zero_error_spec_, &expected.shape());
}
XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) {
@@ -930,27 +914,25 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) {
std::vector<int64> bounds = {5, 5, 10, 1};
std::vector<int64> new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]};
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
- input.Each(
- [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
- float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
/*new_sizes=*/new_bounds);
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
- ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ Literal expected =
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal)
+ .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
// actually corresponds to a two minor transpose.
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
- zero_error_spec_, &expected->shape());
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+ zero_error_spec_, &expected.shape());
}
XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
@@ -959,27 +941,25 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
std::vector<int64> bounds = {3, 3, 1, 3};
std::vector<int64> new_bounds = {bounds[1], bounds[0], bounds[2], bounds[3]};
Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
- input.Each(
- [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
- float* cell) { *cell = distribution(rng); });
- std::unique_ptr<Literal> input_literal =
- LiteralUtil::CreateR4FromArray4DWithLayout(
- input, LayoutUtil::MakeLayout({0, 1, 2, 3}));
+ input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({0, 1, 2, 3}));
XlaBuilder builder(TestName());
XlaOp parameter;
- auto input_data = CreateParameterAndTransferLiteral(
- 0, *input_literal, "input", &builder, &parameter);
+ auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+ &builder, &parameter);
Reshape(parameter, /*dimensions=*/{1, 0, 2, 3},
/*new_sizes=*/new_bounds);
- std::unique_ptr<Literal> expected =
- LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal)
- ->Relayout(input_literal->shape().layout());
+ Literal expected =
+ LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, input_literal)
+ .Relayout(input_literal.shape().layout());
// Specify the requested output shape explicitly to ensure that this reshape
// actually corresponds to a two minor transpose.
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
- zero_error_spec_, &expected->shape());
+ ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+ zero_error_spec_, &expected.shape());
}
#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16
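
Every reshape_test.cc hunk above applies the same pattern: xla::Literal becomes a move-only value type instead of being handled through std::unique_ptr<Literal>, so factories return the literal directly and call sites drop the leading '*'. Below is a minimal stand-alone sketch of that pointer-to-value migration, using a toy Value class rather than the real xla::Literal:

#include <cstddef>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

class Value {  // stand-in for xla::Literal, not the real class
 public:
  explicit Value(std::vector<float> data) : data_(std::move(data)) {}
  Value(Value&&) = default;             // movable...
  Value& operator=(Value&&) = default;
  Value(const Value&) = delete;         // ...but not copyable
  std::size_t size() const { return data_.size(); }

 private:
  std::vector<float> data_;
};

// Old style: the factory heap-allocates, consumers take const Value&,
// and every call site dereferences with '*'.
std::unique_ptr<Value> MakeOld() {
  return std::make_unique<Value>(std::vector<float>{1, 2, 3});
}

// New style: the factory returns the value; moves/RVO avoid deep copies.
Value MakeNew() { return Value(std::vector<float>{1, 2, 3}); }

void Consume(const Value& v) { std::cout << v.size() << "\n"; }

int main() {
  std::unique_ptr<Value> old_style = MakeOld();
  Consume(*old_style);  // before: '*' at each use
  Value new_style = MakeNew();
  Consume(new_style);   // after: plain value, same const& consumer
}

Because consumers like ComputeAndCompareLiteral already took const Literal&, only the producing and forwarding sites change, which is why most of these hunks are one-line '*' removals.
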
diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc
index 41e49b4003..4e55b0d7ac 100644
--- a/tensorflow/compiler/xla/tests/reverse_test.cc
+++ b/tensorflow/compiler/xla/tests/reverse_test.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include <memory>
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -37,16 +39,14 @@ static std::array<bool, 1> use_bfloat16_params{false};
#endif
struct ReverseSpec {
- tensorflow::gtl::ArraySlice<int64> input_dims;
- tensorflow::gtl::ArraySlice<int64> reversal;
+ absl::Span<const int64> input_dims;
+ absl::Span<const int64> reversal;
bool use_bfloat16;
string ToTestCaseName() const {
- return tensorflow::strings::Printf(
- "reverse_%s_in_dims_%s_%s",
- tensorflow::str_util::Join(input_dims, "x").c_str(),
- tensorflow::str_util::Join(reversal, "x").c_str(),
- use_bfloat16 ? "bf16" : "f32");
+ return absl::StrFormat(
+ "reverse_%s_in_dims_%s_%s", absl::StrJoin(input_dims, "x"),
+ absl::StrJoin(reversal, "x"), use_bfloat16 ? "bf16" : "f32");
}
};
@@ -83,26 +83,25 @@ TEST_P(FloatReverseTest, Reverses) {
ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, spec.input_dims)));
std::iota(input_vector.begin(), input_vector.end(), 0.0);
auto r1_literal = LiteralUtil::CreateR1<float>(input_vector);
- auto input_literal = r1_literal->Reshape(spec.input_dims).ConsumeValueOrDie();
+ auto input_literal = r1_literal.Reshape(spec.input_dims).ConsumeValueOrDie();
XlaBuilder builder(TestName());
- auto a = AddParam(*input_literal, &builder);
+ auto a = AddParam(input_literal, &builder);
Rev(a, spec.reversal);
- std::unique_ptr<Literal> expected = input_literal->CloneToUnique();
+ Literal expected = input_literal.Clone();
std::vector<int64> output_indices(spec.input_dims.size());
- expected->EachCell<float>(
- [&](tensorflow::gtl::ArraySlice<int64> indices, float) {
- for (int64 i = 0; i < indices.size(); ++i) {
- output_indices[i] = indices[i];
- }
- float value = input_literal->Get<float>(indices);
- for (int64 dim : spec.reversal) {
- output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim];
- }
- expected->Set<float>(output_indices, value);
- });
- ComputeAndCompareLiteral(&builder, *expected, {});
+ expected.EachCell<float>([&](absl::Span<const int64> indices, float) {
+ for (int64 i = 0; i < indices.size(); ++i) {
+ output_indices[i] = indices[i];
+ }
+ float value = input_literal.Get<float>(indices);
+ for (int64 dim : spec.reversal) {
+ output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim];
+ }
+ expected.Set<float>(output_indices, value);
+ });
+ ComputeAndCompareLiteral(&builder, expected, {});
}
INSTANTIATE_TEST_CASE_P(FloatReverseInstance, FloatReverseTest,
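
ToTestCaseName above trades tensorflow::strings::Printf plus str_util::Join for absl::StrFormat plus absl::StrJoin; the absl calls accept std::string arguments directly, which is why the .c_str() conversions disappear. A compilable sketch of the same call shape, with made-up dimension values:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"

int main() {
  std::vector<int64_t> input_dims = {2, 3, 4};
  std::vector<int64_t> reversal = {0, 2};
  const bool use_bfloat16 = false;
  // StrJoin formats integral elements itself; StrFormat takes the
  // resulting std::strings without any c_str().
  std::string name = absl::StrFormat(
      "reverse_%s_in_dims_%s_%s", absl::StrJoin(input_dims, "x"),
      absl::StrJoin(reversal, "x"), use_bfloat16 ? "bf16" : "f32");
  std::cout << name << "\n";  // reverse_2x3x4_in_dims_0x2_f32
}
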
diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
index a620fe1908..091a5d2cac 100644
--- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
+++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include <memory>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/layout_util.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/casts.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -38,7 +38,7 @@ namespace {
class RoundTripPackedLiteralTest : public ClientLibraryTestBase {
protected:
// Sends the literal to the server and retrieves it back.
- std::unique_ptr<Literal> RoundTripToServer(const Literal& original) {
+ Literal RoundTripToServer(const Literal& original) {
std::unique_ptr<GlobalData> data =
client_->TransferToServer(original).ConsumeValueOrDie();
return client_->Transfer(*data).ConsumeValueOrDie();
@@ -47,8 +47,7 @@ class RoundTripPackedLiteralTest : public ClientLibraryTestBase {
TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) {
string data(sizeof(float) * 2, 0);
- tensorflow::gtl::MutableArraySlice<float> floats(
- tensorflow::bit_cast<float*>(data.data()), 2);
+ absl::Span<float> floats(tensorflow::bit_cast<float*>(data.data()), 2);
floats[0] = 42.0;
floats[1] = 24.0;
@@ -60,18 +59,17 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) {
std::unique_ptr<tensorflow::RandomAccessFile> f;
TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
PackedLiteralReader reader(f.release());
- std::unique_ptr<Literal> actual =
+ Literal actual =
reader.Read(ShapeUtil::MakeShape(F32, {2})).ConsumeValueOrDie();
EXPECT_TRUE(reader.IsExhausted());
- EXPECT_EQ(42.0, actual->Get<float>({0}));
- EXPECT_EQ(24.0, actual->Get<float>({1}));
+ EXPECT_EQ(42.0, actual.Get<float>({0}));
+ EXPECT_EQ(24.0, actual.Get<float>({1}));
}
TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) {
string data(sizeof(float) * 4, 0);
- tensorflow::gtl::MutableArraySlice<float> floats(
- tensorflow::bit_cast<float*>(data.data()), 4);
+ absl::Span<float> floats(tensorflow::bit_cast<float*>(data.data()), 4);
// With x as the minor dimension, these will become:
floats[0] = 42.0; // y=0,x=0
floats[1] = 24.0; // y=0,x=1
@@ -89,24 +87,22 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) {
std::unique_ptr<tensorflow::RandomAccessFile> f;
TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
PackedLiteralReader reader(f.release());
- std::unique_ptr<Literal> actual =
- reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
- .ConsumeValueOrDie();
+ Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
+ .ConsumeValueOrDie();
EXPECT_TRUE(reader.IsExhausted());
- EXPECT_EQ(42.0f, actual->Get<float>({0, 0}));
- EXPECT_EQ(24.0f, actual->Get<float>({0, 1}));
- EXPECT_EQ(64.0f, actual->Get<float>({1, 0}));
- EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
+ EXPECT_EQ(42.0f, actual.Get<float>({0, 0}));
+ EXPECT_EQ(24.0f, actual.Get<float>({0, 1}));
+ EXPECT_EQ(64.0f, actual.Get<float>({1, 0}));
+ EXPECT_EQ(46.0f, actual.Get<float>({1, 1}));
- std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
- EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual));
+ Literal round_tripped = RoundTripToServer(actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual));
}
TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
string data(sizeof(float) * 4, 0);
- tensorflow::gtl::MutableArraySlice<float> floats(
- tensorflow::bit_cast<float*>(data.data()), 4);
+ absl::Span<float> floats(tensorflow::bit_cast<float*>(data.data()), 4);
// With y as the minor dimension, these will become:
floats[0] = 42.0; // y=0,x=0
floats[1] = 24.0; // y=1,x=0
@@ -124,18 +120,17 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
std::unique_ptr<tensorflow::RandomAccessFile> f;
TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
PackedLiteralReader reader(f.release());
- std::unique_ptr<Literal> actual =
- reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
- .ConsumeValueOrDie();
+ Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
+ .ConsumeValueOrDie();
EXPECT_TRUE(reader.IsExhausted());
- EXPECT_EQ(42.0f, actual->Get<float>({0, 0}));
- EXPECT_EQ(24.0f, actual->Get<float>({1, 0}));
- EXPECT_EQ(64.0f, actual->Get<float>({0, 1}));
- EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
+ EXPECT_EQ(42.0f, actual.Get<float>({0, 0}));
+ EXPECT_EQ(24.0f, actual.Get<float>({1, 0}));
+ EXPECT_EQ(64.0f, actual.Get<float>({0, 1}));
+ EXPECT_EQ(46.0f, actual.Get<float>({1, 1}));
- std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
- EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual));
+ Literal round_tripped = RoundTripToServer(actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual));
}
} // namespace
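
The packed-literal tests replace tensorflow::gtl::MutableArraySlice<float> with absl::Span<float>, a mutable non-owning view. Here is a small sketch of writing through such a view into a string-backed byte buffer; reinterpret_cast stands in for tensorflow::bit_cast, and like the test it assumes the buffer is suitably aligned for float:

#include <cstring>
#include <iostream>
#include <string>

#include "absl/types/span.h"

int main() {
  std::string data(sizeof(float) * 2, 0);
  absl::Span<float> floats(reinterpret_cast<float*>(&data[0]), 2);
  floats[0] = 42.0f;  // writes land in `data` through the view
  floats[1] = 24.0f;
  float first;
  std::memcpy(&first, data.data(), sizeof(float));
  std::cout << first << "\n";  // 42
}
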
diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
index a8193c2eac..cd5a531603 100644
--- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
+++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
@@ -39,69 +39,67 @@ class RoundTripTransferTest : public ClientLibraryTestBase {
void RoundTripTest(const Literal& original) {
std::unique_ptr<GlobalData> data =
client_->TransferToServer(original).ConsumeValueOrDie();
- std::unique_ptr<Literal> result =
- client_->Transfer(*data).ConsumeValueOrDie();
- EXPECT_TRUE(LiteralTestUtil::Equal(original, *result));
+ Literal result = client_->Transfer(*data).ConsumeValueOrDie();
+ EXPECT_TRUE(LiteralTestUtil::Equal(original, result));
}
};
TEST_F(RoundTripTransferTest, R0S32) {
- RoundTripTest(*LiteralUtil::CreateR0<int32>(42));
+ RoundTripTest(LiteralUtil::CreateR0<int32>(42));
}
TEST_F(RoundTripTransferTest, R0F32) {
- RoundTripTest(*LiteralUtil::CreateR0<float>(42.0));
+ RoundTripTest(LiteralUtil::CreateR0<float>(42.0));
}
TEST_F(RoundTripTransferTest, R1F32_Len0) {
- RoundTripTest(*LiteralUtil::CreateR1<float>({}));
+ RoundTripTest(LiteralUtil::CreateR1<float>({}));
}
TEST_F(RoundTripTransferTest, R1F32_Len2) {
- RoundTripTest(*LiteralUtil::CreateR1<float>({42.0, 64.0}));
+ RoundTripTest(LiteralUtil::CreateR1<float>({42.0, 64.0}));
}
TEST_F(RoundTripTransferTest, R1F32_Len256) {
std::vector<float> values(256);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+ RoundTripTest(LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R1F32_Len1024) {
std::vector<float> values(1024);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+ RoundTripTest(LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R1F32_Len1025) {
std::vector<float> values(1025);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+ RoundTripTest(LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R1F32_Len4096) {
std::vector<float> values(4096);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+ RoundTripTest(LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R2F32_Len10x0) {
- RoundTripTest(
- *LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(10, 0)));
+ RoundTripTest(LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(10, 0)));
}
TEST_F(RoundTripTransferTest, R2F32_Len2x2) {
- RoundTripTest(*LiteralUtil::CreateR2<float>({{42.0, 64.0}, {77.0, 88.0}}));
+ RoundTripTest(LiteralUtil::CreateR2<float>({{42.0, 64.0}, {77.0, 88.0}}));
}
TEST_F(RoundTripTransferTest, R3F32) {
RoundTripTest(
- *LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
- {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}));
+ LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
+ {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}));
}
TEST_F(RoundTripTransferTest, R4F32) {
- RoundTripTest(*LiteralUtil::CreateR4<float>({{
+ RoundTripTest(LiteralUtil::CreateR4<float>({{
{{10, 11, 12, 13}, {14, 15, 16, 17}},
{{18, 19, 20, 21}, {22, 23, 24, 25}},
{{26, 27, 28, 29}, {30, 31, 32, 33}},
@@ -109,36 +107,35 @@ TEST_F(RoundTripTransferTest, R4F32) {
}
TEST_F(RoundTripTransferTest, EmptyTuple) {
- RoundTripTest(*LiteralUtil::MakeTuple({}));
+ RoundTripTest(LiteralUtil::MakeTuple({}));
}
TEST_F(RoundTripTransferTest, TupleOfR1F32) {
RoundTripTest(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
- LiteralUtil::CreateR1<float>({3, 4}).get()}));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2}),
+ LiteralUtil::CreateR1<float>({3, 4})}));
}
TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) {
RoundTripTest(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({}).get(),
- LiteralUtil::CreateR1<float>({3, 4}).get()}));
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({}),
+ LiteralUtil::CreateR1<float>({3, 4})}));
}
TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) {
- RoundTripTest(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(1.0).get(),
- LiteralUtil::CreateR1<int>({2, 3}).get()}));
+ RoundTripTest(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(1.0), LiteralUtil::CreateR1<int>({2, 3})}));
}
// Below two tests are added to identify the cost of large data transfers.
TEST_F(RoundTripTransferTest, R2F32_Large) {
- RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512));
+ RoundTripTest(LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512));
}
TEST_F(RoundTripTransferTest, R4F32_Large) {
Array4D<float> array4d(2, 2, 256, 256);
array4d.FillWithMultiples(1.0f);
- RoundTripTest(*LiteralUtil::CreateR4FromArray4D<float>(array4d));
+ RoundTripTest(LiteralUtil::CreateR4FromArray4D<float>(array4d));
}
} // namespace
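
The tuple round-trip tests switch from LiteralUtil::MakeTuple, which took raw Literal* elements (hence the .get() calls on freshly created literals), to MakeTupleFromSlices, which consumes the literals themselves. A toy illustration of that call-site difference, using stand-in types rather than the XLA API:

#include <iostream>
#include <memory>
#include <vector>

struct Elem { int id; };  // stand-in for a tuple element

// Old shape: caller keeps unique_ptrs alive and passes .get() pointers.
int SumOld(const std::vector<Elem*>& elems) {
  int total = 0;
  for (const Elem* e : elems) total += e->id;
  return total;
}

// New shape: caller hands over the elements directly.
int SumNew(std::vector<Elem> elems) {
  int total = 0;
  for (const Elem& e : elems) total += e.id;
  return total;
}

int main() {
  auto a = std::make_unique<Elem>(Elem{1});
  auto b = std::make_unique<Elem>(Elem{2});
  std::cout << SumOld({a.get(), b.get()}) << "\n";  // 3
  std::cout << SumNew({Elem{1}, Elem{2}}) << "\n";  // 3
}

The value form is also less fragile: the old pattern relied on the unique_ptr temporaries surviving to the end of the full expression containing the MakeTuple call.
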
diff --git a/tensorflow/compiler/xla/tests/sample_text_test.cc b/tensorflow/compiler/xla/tests/sample_text_test.cc
index b4f2b74e3d..2b03a0b0b2 100644
--- a/tensorflow/compiler/xla/tests/sample_text_test.cc
+++ b/tensorflow/compiler/xla/tests/sample_text_test.cc
@@ -19,18 +19,18 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace {
-using tensorflow::gtl::nullopt;
+using absl::nullopt;
class SampleTextTest : public HloTestBase {};
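
The only substantive change in sample_text_test.cc is the optional migration: tensorflow::gtl::nullopt becomes absl::nullopt with identical semantics. A minimal sketch (toy function, not taken from the test):

#include <iostream>

#include "absl/types/optional.h"

// Returns the parity of n, or nullopt for negatives in this toy example.
absl::optional<int> MaybeParity(int n) {
  if (n < 0) return absl::nullopt;
  return n % 2;
}

int main() {
  absl::optional<int> p = MaybeParity(7);
  if (p.has_value()) std::cout << *p << "\n";        // 1
  std::cout << MaybeParity(-1).has_value() << "\n";  // 0
}
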
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
index e42c71eb28..1dd937a6d0 100644
--- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -17,6 +17,8 @@ limitations under the License.
#include <limits>
#include <memory>
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -30,8 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -46,9 +46,8 @@ class ScalarComputationsTest : public ClientLibraryTestBase {
// A template for building and running a binary comparison test.
template <typename NativeT>
void TestCompare(NativeT lhs, NativeT rhs, bool expected,
- std::function<XlaOp(const XlaOp&, const XlaOp&,
- tensorflow::gtl::ArraySlice<int64>)>
- op) {
+ const std::function<XlaOp(const XlaOp&, const XlaOp&,
+ absl::Span<const int64>)>& op) {
XlaBuilder builder(TestName());
XlaOp lhs_op = ConstantR0<NativeT>(&builder, lhs);
XlaOp rhs_op = ConstantR0<NativeT>(&builder, rhs);
@@ -58,9 +57,8 @@ class ScalarComputationsTest : public ClientLibraryTestBase {
template <typename NativeT>
void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected,
- std::function<XlaOp(const XlaOp&, const XlaOp&,
- tensorflow::gtl::ArraySlice<int64>)>
- op) {
+ const std::function<XlaOp(const XlaOp&, const XlaOp&,
+ absl::Span<const int64>)>& op) {
XlaBuilder builder(TestName());
XlaOp lhs_op = ConstantR0<NativeT>(&builder, lhs);
XlaOp rhs_op = ConstantR0<NativeT>(&builder, rhs);
@@ -163,9 +161,9 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) {
ConvertElementType(a, F32);
int64 value = 3LL << 35;
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR0<int64>(value);
+ Literal a_literal = LiteralUtil::CreateR0<int64>(value);
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
ComputeAndCompareR0<float>(&builder, static_cast<float>(value),
{a_data.get()});
}
@@ -227,20 +225,20 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) {
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR0<float>(2.1f);
- std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR0<float>(5.5f);
- std::unique_ptr<Literal> c_literal = LiteralUtil::CreateR0<float>(0.5f);
+ Literal a_literal = LiteralUtil::CreateR0<float>(2.1f);
+ Literal b_literal = LiteralUtil::CreateR0<float>(5.5f);
+ Literal c_literal = LiteralUtil::CreateR0<float>(0.5f);
std::unique_ptr<GlobalData> a_data =
- client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+ client_->TransferToServer(a_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> b_data =
- client_->TransferToServer(*b_literal).ConsumeValueOrDie();
+ client_->TransferToServer(b_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> c_data =
- client_->TransferToServer(*c_literal).ConsumeValueOrDie();
+ client_->TransferToServer(c_literal).ConsumeValueOrDie();
- XlaOp a = Parameter(&builder, 0, a_literal->shape(), "a");
- XlaOp b = Parameter(&builder, 1, b_literal->shape(), "b");
- XlaOp c = Parameter(&builder, 2, c_literal->shape(), "c");
+ XlaOp a = Parameter(&builder, 0, a_literal.shape(), "a");
+ XlaOp b = Parameter(&builder, 1, b_literal.shape(), "b");
+ XlaOp c = Parameter(&builder, 2, c_literal.shape(), "c");
Mul(Mul(a, b), c);
ComputeAndCompareR0<float>(&builder, 5.775f,
@@ -379,9 +377,9 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) {
auto dividend_literal = LiteralUtil::CreateR0<uint32>(dividend);
auto divisor_literal = LiteralUtil::CreateR0<uint32>(divisor);
TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
- client_->TransferToServer(*dividend_literal));
+ client_->TransferToServer(dividend_literal));
TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
- client_->TransferToServer(*divisor_literal));
+ client_->TransferToServer(divisor_literal));
auto actual_literal =
client_
->ExecuteAndTransfer(div_computation,
@@ -390,7 +388,7 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) {
.ConsumeValueOrDie();
auto expected_literal =
LiteralUtil::CreateR0<uint32>(dividend / divisor);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal));
}
}
}
@@ -421,9 +419,9 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) {
auto dividend_literal = LiteralUtil::CreateR0<uint32>(dividend);
auto divisor_literal = LiteralUtil::CreateR0<uint32>(divisor);
TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
- client_->TransferToServer(*dividend_literal));
+ client_->TransferToServer(dividend_literal));
TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
- client_->TransferToServer(*divisor_literal));
+ client_->TransferToServer(divisor_literal));
auto actual_literal =
client_
->ExecuteAndTransfer(rem_computation,
@@ -432,7 +430,7 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) {
.ConsumeValueOrDie();
auto expected_literal =
LiteralUtil::CreateR0<uint32>(dividend % divisor);
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal));
}
}
}
@@ -443,8 +441,8 @@ XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) {
auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x");
Rem(x, ConstantR0<int32>(&builder, 80000));
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<int32>(87919);
- TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*literal));
+ Literal literal = LiteralUtil::CreateR0<int32>(87919);
+ TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal));
ComputeAndCompareR0<int32>(&builder, 7919, {input_data.get()});
}
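
The TestCompare/TestMinMax helpers above make two independent changes: the comparison op is now a const std::function<...>& instead of a by-value std::function (avoiding a copy, and a possible allocation for capturing callables, on every call), and the broadcast-dimensions parameter becomes absl::Span<const int64>. A sketch of the resulting call shape, with invented names:

#include <cstdint>
#include <functional>
#include <iostream>

#include "absl/types/span.h"

using BinaryOp =
    std::function<int(int lhs, int rhs, absl::Span<const int64_t> dims)>;

// Taking `op` by const& mirrors the new signatures; the callable is
// invoked but never copied.
int Apply(int lhs, int rhs, const BinaryOp& op) {
  return op(lhs, rhs, /*dims=*/{});
}

int main() {
  BinaryOp add = [](int a, int b, absl::Span<const int64_t>) { return a + b; };
  std::cout << Apply(2, 3, add) << "\n";  // 5
}
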
diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc
index 922d70b752..d20dba028a 100644
--- a/tensorflow/compiler/xla/tests/scatter_test.cc
+++ b/tensorflow/compiler/xla/tests/scatter_test.cc
@@ -23,7 +23,7 @@ limitations under the License.
namespace xla {
namespace {
-using tensorflow::gtl::nullopt;
+using absl::nullopt;
class ScatterTest : public HloTestBase {
protected:
@@ -32,8 +32,7 @@ class ScatterTest : public HloTestBase {
RunTest(hlo_text, {operand, scatter_indices, updates});
}
- void RunTest(const string& hlo_text,
- tensorflow::gtl::ArraySlice<Literal*> args) {
+ void RunTest(const string& hlo_text, absl::Span<Literal* const> args) {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
@@ -63,13 +62,11 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) {
@@ -93,13 +90,12 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates =
LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) {
@@ -124,13 +120,11 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) {
@@ -155,13 +149,11 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_F32) {
@@ -186,13 +178,12 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<float>(
+ Literal operand = LiteralUtil::CreateR2<float>(
{{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({2, 1});
- std::unique_ptr<Literal> updates =
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({2, 1});
+ Literal updates =
LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_RepeatedIndices) {
@@ -217,13 +208,11 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_MultipleBatchDims) {
@@ -248,13 +237,12 @@ ENTRY main {
index_vector_dim=2
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ Literal updates = LiteralUtil::CreateR3<int32>(
{{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatterNd) {
@@ -278,15 +266,13 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatterNd_NonDefaultIndexVectorDim) {
@@ -310,15 +296,13 @@ ENTRY main {
index_vector_dim=0
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
{{-4, 4}, {-5, 5}, {-6, 6}}, //
{{-7, 7}, {-8, 8}, {-9, 9}}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, DynamicUpdateSlice) {
@@ -342,12 +326,11 @@ ENTRY main {
index_vector_dim=0
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({1, 1});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{10}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, BatchDynamicUpdateSlice) {
@@ -371,13 +354,11 @@ ENTRY main {
index_vector_dim=0
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ Literal updates = LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, ZeroDimBounds) {
@@ -401,11 +382,10 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> scatter_indices =
- LiteralUtil::CreateR1<int32>({0, 2});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{}, {}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{}, {}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, NoUpdateWindowDims) {
@@ -430,12 +410,11 @@ ENTRY main {
index_vector_dim=2
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
- std::unique_ptr<Literal> scatter_indices =
+ Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
+ Literal scatter_indices =
LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
- std::unique_ptr<Literal> updates =
- LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, OutOfBoundsIndex) {
@@ -459,13 +438,13 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<int32>(
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ Literal updates = LiteralUtil::CreateR3<int32>(
{{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, OutOfBoundsUnsignedIndex) {
@@ -489,13 +468,13 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<uint32>(
+ Literal scatter_indices = LiteralUtil::CreateR2<uint32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ Literal updates = LiteralUtil::CreateR3<int32>(
{{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, NegativeIndex) {
@@ -519,13 +498,13 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand =
+ Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<int32>(
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>(
{{2, 7}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ Literal updates = LiteralUtil::CreateR3<int32>(
{{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, OneScalarIndex) {
@@ -549,12 +528,12 @@ ENTRY main {
index_vector_dim=0
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>(
+ Literal operand = LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR0<int32>(1);
- std::unique_ptr<Literal> updates =
+ Literal scatter_indices = LiteralUtil::CreateR0<int32>(1);
+ Literal updates =
LiteralUtil::CreateR3<int32>({{{10, 20}, {30, 40}, {50, 60}}});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, ScalarUpdate) {
@@ -578,10 +557,10 @@ ENTRY main {
index_vector_dim=0
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR0<int32>(1);
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR0<int32>(25);
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
+ Literal scatter_indices = LiteralUtil::CreateR0<int32>(1);
+ Literal updates = LiteralUtil::CreateR0<int32>(25);
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, EmptyIndices) {
@@ -605,10 +584,10 @@ ENTRY main {
index_vector_dim=1
}
)";
- std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3});
- std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR1<int32>({});
- std::unique_ptr<Literal> updates = LiteralUtil::CreateR1<int32>({});
- RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+ Literal operand = LiteralUtil::CreateR1<int32>({1, 2, 3});
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({});
+ Literal updates = LiteralUtil::CreateR1<int32>({});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
} // namespace
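
The mechanical change running through these scatter tests is the migration from std::unique_ptr<Literal> to value-semantics Literal: literals are now created and passed by value, and callers hand RunTest raw pointers with '&'. A minimal sketch of the new pattern, assuming the post-migration LiteralUtil API (Literal is movable but not copyable):

  Literal operand = LiteralUtil::CreateR1<int32>({1, 2, 3});
  // Literal copy = operand;            // would not compile: Literal is move-only
  Literal moved = std::move(operand);   // transfers ownership; 'operand' is
                                        // left in a valid moved-from state
  std::vector<Literal> elements;
  elements.push_back(std::move(moved)); // containers hold values, filled by move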
diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
index e3d4f98dd7..f737b5158b 100644
--- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
+++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
@@ -42,8 +42,8 @@ struct SelectAndScatterTestParam {
std::vector<int64> operand_shape;
std::vector<int64> source_shape;
Padding padding_type;
- tensorflow::gtl::ArraySlice<int64> window_dimensions;
- tensorflow::gtl::ArraySlice<int64> window_strides;
+ absl::Span<const int64> window_dimensions;
+ absl::Span<const int64> window_strides;
};
class SelectAndScatterTest
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index b8ad6668f8..a40c2d7de6 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -18,6 +18,11 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -25,16 +30,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace {
-using ::tensorflow::str_util::Join;
-
class SliceTest : public ClientLibraryTestBase {};
TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) {
@@ -175,8 +176,8 @@ XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) {
XlaBuilder builder(TestName());
auto original = ConstantR4FromArray4D(&builder, values);
Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1});
- ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001),
- &expected_literal->shape());
+ ComputeAndCompareLiteral(&builder, expected_literal, {}, ErrorSpec(0.000001),
+ &expected_literal.shape());
}
struct R1Spec {
@@ -193,26 +194,26 @@ class SliceR1Test : public ClientLibraryTestBase,
protected:
template <typename NativeT>
void Run(const R1Spec& spec) {
- // This can't be an std::vector, since you can't grab an ArraySlice of a
+ // This can't be a std::vector, since you can't grab a Span of a
// vector<bool>.
- tensorflow::gtl::InlinedVector<NativeT, 1> input(spec.input_dim0);
+ absl::InlinedVector<NativeT, 1> input(spec.input_dim0);
std::iota(input.begin(), input.end(), NativeT());
auto literal = LiteralUtil::CreateR1<NativeT>(input);
XlaBuilder builder(TestName());
- auto original = Parameter(&builder, 0, literal->shape(), "p0");
+ auto original = Parameter(&builder, 0, literal.shape(), "p0");
Slice(original, {spec.slice_start}, {spec.slice_limit},
{spec.slice_stride});
// Ditto.
- tensorflow::gtl::InlinedVector<NativeT, 1> expected;
+ absl::InlinedVector<NativeT, 1> expected;
for (int i = spec.slice_start; i < spec.slice_limit;
i += spec.slice_stride) {
expected.push_back(i);
}
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
- client_->TransferToServer(*literal));
+ client_->TransferToServer(literal));
ComputeAndCompareR1<NativeT>(&builder, expected, {arg.get()});
}
};
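
The vector<bool> caveat in the comment above exists because std::vector<bool> is a packed-bit specialization with proxy references, so it exposes no contiguous bool storage for absl::Span to view. A hedged standalone illustration, not part of the test itself:

  #include <vector>
  #include "absl/types/span.h"

  std::vector<int> ints = {1, 2, 3};
  absl::Span<const int> int_view(ints);  // fine: contiguous int storage
  std::vector<bool> bits = {true, false};
  // absl::Span<const bool> bad(bits);   // would not compile: vector<bool>
  //                                     // has no bool* data() to point at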
@@ -222,9 +223,8 @@ class SliceR1LargeTest : public SliceR1Test {};
string SliceR1TestDataToString(const ::testing::TestParamInfo<R1Spec>& data) {
const R1Spec& spec = data.param;
- return ::tensorflow::strings::Printf("%lld_%lld_%lld_%lld", spec.input_dim0,
- spec.slice_start, spec.slice_limit,
- spec.slice_stride);
+ return absl::StrFormat("%d_%d_%d_%d", spec.input_dim0, spec.slice_start,
+ spec.slice_limit, spec.slice_stride);
}
XLA_TEST_P(SliceR1Test, DoIt_F32) { Run<float>(GetParam()); }
@@ -376,11 +376,11 @@ XLA_TEST_P(SliceR2Test, DoIt) {
input, LayoutUtil::MakeLayout(spec.layout));
XlaBuilder builder(TestName());
- auto a = Parameter(&builder, 0, literal->shape(), "p0");
+ auto a = Parameter(&builder, 0, literal.shape(), "p0");
Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
- client_->TransferToServer(*literal));
+ client_->TransferToServer(literal));
std::unique_ptr<Array2D<int32>> expected = ReferenceUtil::Slice2D(
input, spec.slice_starts, spec.slice_limits, spec.slice_strides);
ComputeAndCompareR2<int32>(&builder, *expected, {arg.get()});
@@ -448,13 +448,11 @@ struct R4Spec {
string R4SpecToString(const ::testing::TestParamInfo<R4Spec>& data) {
const R4Spec& spec = data.param;
- return tensorflow::strings::StrCat( //
- "input_", Join(spec.input_dims, "x"), //
- "__layout_", Join(spec.input_layout, ""), //
- "__starts_", Join(spec.slice_starts, "x"), //
- "__limits_", Join(spec.slice_limits, "x"), //
- "__strides_", Join(spec.slice_strides, "x") //
- );
+ return absl::StrCat("input_", absl::StrJoin(spec.input_dims, "x"),
+ "__layout_", absl::StrJoin(spec.input_layout, ""),
+ "__starts_", absl::StrJoin(spec.slice_starts, "x"),
+ "__limits_", absl::StrJoin(spec.slice_limits, "x"),
+ "__strides_", absl::StrJoin(spec.slice_strides, "x"));
}
class SliceR4Test : public ClientLibraryTestBase,
@@ -469,9 +467,9 @@ class SliceR4Test : public ClientLibraryTestBase,
XlaBuilder builder(TestName());
auto literal = LiteralUtil::CreateR4FromArray4DWithLayout(
values, LayoutUtil::MakeLayout(spec.input_layout));
- auto parameter = Parameter(&builder, 0, literal->shape(), "p0");
+ auto parameter = Parameter(&builder, 0, literal.shape(), "p0");
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
- client_->TransferToServer(*literal));
+ client_->TransferToServer(literal));
Slice(parameter, spec.slice_starts, spec.slice_limits, spec.slice_strides);
ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001));
}
diff --git a/tensorflow/compiler/xla/tests/test_macros.cc b/tensorflow/compiler/xla/tests/test_macros.cc
index be35ec6c6e..a9874a9186 100644
--- a/tensorflow/compiler/xla/tests/test_macros.cc
+++ b/tensorflow/compiler/xla/tests/test_macros.cc
@@ -20,7 +20,9 @@ limitations under the License.
#include <string>
#include <unordered_map>
-#include "tensorflow/core/lib/strings/str_util.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
@@ -44,7 +46,7 @@ ManifestT ReadManifest() {
string contents((std::istreambuf_iterator<char>(file_stream)),
std::istreambuf_iterator<char>());
- std::vector<string> lines = tensorflow::str_util::Split(contents, '\n');
+ std::vector<string> lines = absl::StrSplit(contents, '\n');
for (string& line : lines) {
auto comment = line.find("//");
if (comment != string::npos) {
@@ -53,8 +55,8 @@ ManifestT ReadManifest() {
if (line.empty()) {
continue;
}
- tensorflow::str_util::StripTrailingWhitespace(&line);
- std::vector<string> pieces = tensorflow::str_util::Split(line, ' ');
+ absl::StripTrailingAsciiWhitespace(&line);
+ std::vector<string> pieces = absl::StrSplit(line, ' ');
CHECK_GE(pieces.size(), 1);
auto& platforms = manifest[pieces[0]];
for (int64 i = 1; i < pieces.size(); ++i) {
@@ -73,8 +75,7 @@ string PrependDisabledIfIndicated(const string& test_case_name,
// First try full match: test_case_name.test_name
// If that fails, try to find just the test_case_name; this would disable all
// tests in the test case.
- auto it = manifest.find(
- tensorflow::strings::StrCat(test_case_name, ".", test_name));
+ auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name));
if (it == manifest.end()) {
it = manifest.find(test_case_name);
if (it == manifest.end()) {
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index faeec657b6..5155f0c652 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include <cmath>
+
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
namespace xla {
@@ -26,89 +29,102 @@ namespace {
template <typename FloatT, typename GeneratorT>
void PopulateWithRandomFloatingPointDataImpl(Literal* literal,
- std::minstd_rand0* engine) {
+ std::minstd_rand0* engine,
+ bool no_duplicates) {
CHECK(engine != nullptr);
CHECK_EQ(literal->shape().element_type(),
primitive_util::NativeToPrimitiveType<FloatT>());
- // Create uniform numbers between 1 and 1.125 to avoid creating denormal
- // numbers.
- std::uniform_real_distribution<GeneratorT> generator(1.0f, 1.125f);
- const bool should_index_bias = ShapeUtil::ElementsIn(literal->shape()) > 1000;
- TF_CHECK_OK(literal->Populate<FloatT>(
- [&](tensorflow::gtl::ArraySlice<int64> indices) {
- // Generate a random uniform number from -0.0625 and 0.0625 and bias it
- // with a position dependent number with mean 0.037109375. These number
- // should allow for long chains of accumulation without being too close
- // to zero or too large to accumulate all numbers accurately. Only do
- // this for large literals where the number of elements is much greater
- // than 47 otherwise only negative values are produced.
- //
- // The value is positionally biased using a product of the indices. Add
- // one to each index value to avoid collapsing to zero if any of the
- // indices are zero.
- int64 index_product = 1;
- for (int64 i : indices) {
- index_product *= (1 + i);
- }
- const int64 negative_bias = should_index_bias ? 47 : 0;
- FloatT index_bias =
- static_cast<FloatT>(index_product % 113 - negative_bias) /
- static_cast<FloatT>(256.0f);
- return static_cast<FloatT>(generator(*engine) - 1.0625f) + index_bias;
- }));
+ if (no_duplicates) {
+ // Duplicates may be generated if the number of elements in the literal
+ // exceeds the number of positive values supported by the type.
+ FloatT next_value = std::numeric_limits<FloatT>::min();
+ for (FloatT& value : literal->data<FloatT>()) {
+ value = next_value;
+ next_value =
+ std::nextafter(next_value, std::numeric_limits<FloatT>::max());
+ }
+ std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
+ *engine);
+ } else {
+ std::uniform_real_distribution<GeneratorT> generator(-0.1f, 0.2f);
+ for (FloatT& value : literal->data<FloatT>()) {
+ value = static_cast<FloatT>(generator(*engine));
+ }
+ }
}
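
A minimal standalone sketch of the nextafter-plus-shuffle idea used in the no_duplicates branch above (assumes plain float and is illustrative only, not the XLA helper itself):

  #include <algorithm>
  #include <cmath>
  #include <cstddef>
  #include <limits>
  #include <random>
  #include <vector>

  std::vector<float> MakeDistinctFloats(size_t n, std::minstd_rand0* engine) {
    std::vector<float> values(n);
    float next = std::numeric_limits<float>::min();  // smallest positive normal
    for (float& v : values) {
      v = next;
      next = std::nextafter(next, std::numeric_limits<float>::max());
    }
    // Adjacent representable floats are distinct, so shuffling randomizes the
    // order while preserving uniqueness (until the value range is exhausted).
    std::shuffle(values.begin(), values.end(), *engine);
    return values;
  }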
template <typename FloatT>
void PopulateWithRandomFloatingPointData(Literal* literal,
- std::minstd_rand0* engine) {
+ std::minstd_rand0* engine,
+ bool no_duplicates) {
CHECK(engine != nullptr);
- PopulateWithRandomFloatingPointDataImpl<FloatT, FloatT>(literal, engine);
+ PopulateWithRandomFloatingPointDataImpl<FloatT, FloatT>(literal, engine,
+ no_duplicates);
}
template <>
void PopulateWithRandomFloatingPointData<half>(Literal* literal,
- std::minstd_rand0* engine) {
+ std::minstd_rand0* engine,
+ bool no_duplicates) {
+ // no_duplicates is ignored for the half type. Unique values can only be
+ // generated for arrays with fewer than ~2**16 elements, and no_duplicates is
+ // best-effort anyway.
CHECK(engine != nullptr);
- PopulateWithRandomFloatingPointDataImpl<half, float>(literal, engine);
+ std::uniform_real_distribution<float> generator(-0.1f, 0.2f);
+ for (half& value : literal->data<half>()) {
+ value = static_cast<half>(generator(*engine));
+ }
}
-// The standard library does not have a case for bfloat16, unsurprisingly, so we
-// handle that one specially.
template <>
void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal,
- std::minstd_rand0* engine) {
+ std::minstd_rand0* engine,
+ bool no_duplicates) {
+ // no_duplicates is ignored for the bfloat16 type. Unique values can only be
+ // generated for arrays with fewer than ~2**16 elements, and no_duplicates is
+ // best-effort anyway.
CHECK(engine != nullptr);
- CHECK_EQ(literal->shape().element_type(), BF16);
- std::uniform_real_distribution<float> generator(-0.9f, 1.0f);
- TF_CHECK_OK(literal->Populate<bfloat16>(
- [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
- return static_cast<bfloat16>(generator(*engine));
- }));
+ std::uniform_real_distribution<float> generator(-0.1f, 0.2f);
+ for (bfloat16& value : literal->data<bfloat16>()) {
+ value = static_cast<bfloat16>(generator(*engine));
+ }
}
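
The ~2**16 figure in the two comments above is simply the width of the formats: half and bfloat16 are 16-bit types, so they admit at most 2^16 = 65536 distinct bit patterns (fewer usable finite values once NaNs and infinities are excluded), and by the pigeonhole principle any larger array must contain duplicates.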
template <typename IntT>
-void PopulateWithRandomIntegralData(Literal* literal,
- std::minstd_rand0* engine) {
+void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine,
+ bool no_duplicates) {
CHECK(engine != nullptr);
CHECK_EQ(literal->shape().element_type(),
primitive_util::NativeToPrimitiveType<IntT>());
- std::uniform_int_distribution<IntT> generator(
- std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
- TF_CHECK_OK(literal->Populate<IntT>(
- [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
- return generator(*engine);
- }));
+ if (no_duplicates && ShapeUtil::ElementsIn(literal->shape()) <
+ std::numeric_limits<IntT>::max()) {
+ std::iota(literal->data<IntT>().begin(), literal->data<IntT>().end(), 0);
+ std::shuffle(literal->data<IntT>().begin(), literal->data<IntT>().end(),
+ *engine);
+ } else {
+ std::uniform_int_distribution<IntT> generator(
+ std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
+ for (IntT& value : literal->data<IntT>()) {
+ value = generator(*engine);
+ }
+ }
}
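
A hedged standalone sketch of the iota-plus-shuffle branch above, including the guard that falls back to plain random draws when the element count exceeds the type's value range (illustrative int8_t version, not the XLA helper itself):

  #include <algorithm>
  #include <cstddef>
  #include <cstdint>
  #include <limits>
  #include <numeric>
  #include <random>
  #include <vector>

  std::vector<int8_t> MakeDistinctInts(size_t n, std::minstd_rand0* engine) {
    std::vector<int8_t> values(n);
    if (n < static_cast<size_t>(std::numeric_limits<int8_t>::max())) {
      std::iota(values.begin(), values.end(), int8_t{0});  // 0, 1, 2, ...
      std::shuffle(values.begin(), values.end(), *engine);
    } else {
      // Too many elements for distinct int8 values; duplicates are inevitable.
      std::uniform_int_distribution<int> gen(
          std::numeric_limits<int8_t>::lowest(),
          std::numeric_limits<int8_t>::max());
      for (int8_t& v : values) v = static_cast<int8_t>(gen(*engine));
    }
    return values;
  }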
// Similar to MakeFakeLiteral but takes a random number generator engine to
-// enable reusing the engine across randomly generated literals.
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
- const Shape& shape, std::minstd_rand0* engine) {
+// enable reusing the engine across randomly generated literals. 'no_duplicates'
+// indicates that there should be no duplicate values in each generated
+// array. This uniqueness is best-effort only. Some types (half and bfloat16)
+// are not supported, and uniqueness cannot be guaranteed if the number of
+// elements exceeds the number of different values supported by the type.
+StatusOr<Literal> MakeFakeLiteralInternal(const Shape& shape,
+ std::minstd_rand0* engine,
+ bool no_duplicates) {
if (ShapeUtil::IsTuple(shape)) {
- std::vector<std::unique_ptr<Literal>> elements;
+ std::vector<Literal> elements;
for (const Shape& element_shape : shape.tuple_shapes()) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
- MakeFakeLiteralInternal(element_shape, engine));
+ TF_ASSIGN_OR_RETURN(
+ Literal element,
+ MakeFakeLiteralInternal(element_shape, engine, no_duplicates));
elements.push_back(std::move(element));
}
return LiteralUtil::MakeTupleOwned(std::move(elements));
@@ -116,48 +132,52 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
if (engine == nullptr) {
return Literal::CreateFromShape(shape);
}
- auto literal = MakeUnique<Literal>(shape);
+ Literal literal(shape);
switch (shape.element_type()) {
case BF16:
- PopulateWithRandomFloatingPointData<bfloat16>(literal.get(), engine);
+ PopulateWithRandomFloatingPointData<bfloat16>(&literal, engine,
+ no_duplicates);
break;
case F16:
- PopulateWithRandomFloatingPointData<half>(literal.get(), engine);
+ PopulateWithRandomFloatingPointData<half>(&literal, engine,
+ no_duplicates);
break;
case F32:
- PopulateWithRandomFloatingPointData<float>(literal.get(), engine);
+ PopulateWithRandomFloatingPointData<float>(&literal, engine,
+ no_duplicates);
break;
case F64:
- PopulateWithRandomFloatingPointData<double>(literal.get(), engine);
+ PopulateWithRandomFloatingPointData<double>(&literal, engine,
+ no_duplicates);
break;
case S8:
- PopulateWithRandomIntegralData<int8>(literal.get(), engine);
+ PopulateWithRandomIntegralData<int8>(&literal, engine, no_duplicates);
break;
case U8:
- PopulateWithRandomIntegralData<uint8>(literal.get(), engine);
+ PopulateWithRandomIntegralData<uint8>(&literal, engine, no_duplicates);
break;
case S16:
- PopulateWithRandomIntegralData<int16>(literal.get(), engine);
+ PopulateWithRandomIntegralData<int16>(&literal, engine, no_duplicates);
break;
case U16:
- PopulateWithRandomIntegralData<uint16>(literal.get(), engine);
+ PopulateWithRandomIntegralData<uint16>(&literal, engine, no_duplicates);
break;
case S32:
- PopulateWithRandomIntegralData<int32>(literal.get(), engine);
+ PopulateWithRandomIntegralData<int32>(&literal, engine, no_duplicates);
break;
case U32:
- PopulateWithRandomIntegralData<uint32>(literal.get(), engine);
+ PopulateWithRandomIntegralData<uint32>(&literal, engine, no_duplicates);
break;
case S64:
- PopulateWithRandomIntegralData<int64>(literal.get(), engine);
+ PopulateWithRandomIntegralData<int64>(&literal, engine, no_duplicates);
break;
case U64:
- PopulateWithRandomIntegralData<uint64>(literal.get(), engine);
+ PopulateWithRandomIntegralData<uint64>(&literal, engine, no_duplicates);
break;
case PRED: {
std::uniform_int_distribution<int> generator(0, 1);
- TF_CHECK_OK(literal->Populate<bool>(
- [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+ TF_CHECK_OK(
+ literal.Populate<bool>([&](absl::Span<const int64> /*indices*/) {
return generator(*engine);
}));
break;
@@ -167,7 +187,7 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
break;
default:
return Unimplemented("Unsupported type for fake literal generation: %s",
- ShapeUtil::HumanString(shape).c_str());
+ ShapeUtil::HumanString(shape));
}
return std::move(literal);
}
@@ -176,6 +196,7 @@ enum class ConstantType { kUnknown, kZero, kOne };
// Return the constant type required by this computation, if known.
ConstantType GetInitValue(const HloComputation& computation) {
+ // TODO(b/77635120): Add init values for min, max, and their arg variants.
const HloInstruction* const root = computation.root_instruction();
if (computation.num_parameters() != 2 || root->operand_count() != 2 ||
root->operand(0)->opcode() != HloOpcode::kParameter ||
@@ -200,16 +221,16 @@ bool NeedsInitValue(const HloUse& use) {
const HloInstruction* const instruction = use.instruction;
const HloOpcode opcode = instruction->opcode();
const int64 op_num = use.operand_number;
- return (
- ((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) &&
- op_num == 1) ||
- (opcode == HloOpcode::kSelectAndScatter && op_num == 2));
+ return ((opcode == HloOpcode::kReduceWindow && op_num == 1) ||
+ (opcode == HloOpcode::kSelectAndScatter && op_num == 2) ||
+ (opcode == HloOpcode::kReduce &&
+ op_num >= instruction->operand_count() / 2));
}
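
The widened kReduce case reflects variadic reduce, whose operand list is N input arrays followed by N init values, so an operand is an init value exactly when its index falls in the second half. A hypothetical illustration:

  // reduce(array0, array1, init0, init1): operand_count() == 4, so operands
  // with op_num >= 4 / 2 == 2 (i.e. init0 and init1) need init-value handling.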
// Generate random values constrained to the input_shape minus the output_shape
// so that, for instance, dynamic slices do not wrap.
-std::unique_ptr<Literal> MakeRandomIndex(
- tensorflow::gtl::ArraySlice<int64> index_space, std::minstd_rand0* engine) {
+Literal MakeRandomIndex(absl::Span<const int64> index_space,
+ std::minstd_rand0* engine) {
std::vector<int32> start_indices(index_space.size());
if (engine != nullptr) {
for (int i = 0; i < index_space.size(); ++i) {
@@ -250,6 +271,11 @@ std::vector<HloInstruction*> FindConstrainedUses(
auto converted_uses = FindConstrainedUses(dataflow, *instruction);
constrained_uses.insert(constrained_uses.end(), converted_uses.begin(),
converted_uses.end());
+ } else if (opcode == HloOpcode::kSort &&
+ instruction->operand_count() == 2 && op_num == 0) {
+ // Operand 0 of sort is the array of keys used for key/value
+ // (two-operand) kSort instructions.
+ constrained_uses.push_back(instruction);
}
}
}
@@ -260,10 +286,11 @@ std::vector<HloInstruction*> FindConstrainedUses(
// no constrained uses in the dataflow graph. If such constraints exist,
// generate a constrained literal (either bounded in the case of indices, or
// zero in the case of init_values for reductions).
-StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
- const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses,
+StatusOr<Literal> CreateLiteralForConstrainedUses(
+ const absl::Span<HloInstruction* const> constrained_uses,
const HloInstruction& param, std::minstd_rand0* engine) {
std::vector<int64> index_space;
+ bool no_duplicates = false;
bool needs_constant = false;
ConstantType constant_type = ConstantType::kUnknown;
for (HloInstruction* use : constrained_uses) {
@@ -302,67 +329,98 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
constant_type = GetInitValue(*use->scatter());
break;
+ case HloOpcode::kSort:
+ no_duplicates = true;
+ break;
+
default:
return Unimplemented(
"Constrained operand generation not implemented for %s.",
- use->ToString().c_str());
+ use->ToString());
}
}
- if (!index_space.empty() && needs_constant) {
- return Unimplemented(
- "Conflicting operand generation constraints. Dynamically indexes a "
- "shape and is the init value of a reduction.");
+ int constraint_count = 0;
+ constraint_count += no_duplicates ? 1 : 0;
+ constraint_count += !index_space.empty() ? 1 : 0;
+ constraint_count += needs_constant ? 1 : 0;
+ if (constraint_count > 1) {
+ return Unimplemented("Conflicting operand generation constraints.");
}
if (!index_space.empty()) {
return MakeRandomIndex(index_space, engine);
} else if (needs_constant) {
switch (constant_type) {
case ConstantType::kZero:
- return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique();
+ return LiteralUtil::Zero(param.shape().element_type());
case ConstantType::kOne:
- return LiteralUtil::One(param.shape().element_type()).CloneToUnique();
+ return LiteralUtil::One(param.shape().element_type());
case ConstantType::kUnknown:
// We want the identity element for the computation, but we don't really
// know what it is - so any value we generate will be just as wrong.
- return MakeFakeLiteralInternal(param.shape(), engine);
+ return MakeFakeLiteralInternal(param.shape(), engine,
+ /*no_duplicates=*/false);
}
} else {
- return MakeFakeLiteralInternal(param.shape(), engine);
+ return MakeFakeLiteralInternal(param.shape(), engine, no_duplicates);
}
}
// Given a module entry parameter, use the dataflow analysis to see if a
// special case literal must be created, or if we can generate fake data.
-StatusOr<std::unique_ptr<Literal>> MakeConstrainedArgument(
- const HloDataflowAnalysis& dataflow, const HloInstruction& param,
- std::minstd_rand0* engine) {
+StatusOr<Literal> MakeConstrainedArgument(const HloDataflowAnalysis& dataflow,
+ const HloInstruction& param,
+ std::minstd_rand0* engine) {
const auto constrained_uses = FindConstrainedUses(dataflow, param);
return CreateLiteralForConstrainedUses(constrained_uses, param, engine);
}
} // namespace
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape,
- bool pseudo_random) {
- auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr;
- return MakeFakeLiteralInternal(shape, engine.get());
+StatusOr<Literal> MakeFakeLiteral(const Shape& shape, bool pseudo_random) {
+ auto engine =
+ pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
+ return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false);
+}
+
+StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
+ bool pseudo_random) {
+ auto engine =
+ pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
+ return MakeFakeArguments(module, engine.get());
}
-StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
- HloModule* const module, bool pseudo_random) {
+StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
+ std::minstd_rand0* engine) {
TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module));
const auto params = module->entry_computation()->parameter_instructions();
- auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr;
- std::vector<std::unique_ptr<Literal>> arguments(params.size());
+ std::vector<Literal> arguments(params.size());
for (int i = 0; i < params.size(); ++i) {
- arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], engine.get())
- .ValueOrDie();
+ arguments[i] =
+ MakeConstrainedArgument(*dataflow, *params[i], engine).ValueOrDie();
}
return std::move(arguments);
}
-Status VerifyHloModule(HloModule* const module, bool allow_mixed_precision) {
- return HloVerifier(allow_mixed_precision).Run(module).status();
+Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
+ bool allow_mixed_precision) {
+ return HloVerifier(/*layout_sensitive=*/layout_sensitive,
+ /*allow_mixed_precision=*/allow_mixed_precision)
+ .Run(module)
+ .status();
}
+std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2);
+ CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2);
+ PrecisionConfig precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfig::DEFAULT);
+ DotDimensionNumbers dot_dimension_numbers;
+ dot_dimension_numbers.add_lhs_contracting_dimensions(1);
+ dot_dimension_numbers.add_rhs_contracting_dimensions(0);
+ return absl::make_unique<HloDotInstruction>(
+ shape, lhs, rhs, dot_dimension_numbers, precision_config);
+}
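
A hedged usage sketch for CreateCanonicalDot; the builder 'b' and the rank-2 instructions 'lhs' (f32[16,8]) and 'rhs' (f32[8,32]) are assumed to exist and are illustrative only:

  Shape out_shape = ShapeUtil::MakeShape(F32, {16, 32});
  HloInstruction* dot =
      b.AddInstruction(CreateCanonicalDot(out_shape, lhs, rhs));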
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h
index e59f215a9a..b3c8a73905 100644
--- a/tensorflow/compiler/xla/tests/test_utils.h
+++ b/tensorflow/compiler/xla/tests/test_utils.h
@@ -20,14 +20,14 @@ limitations under the License.
#include <memory>
#include <random>
+#include "absl/memory/memory.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
-#include "tensorflow/stream_executor/platform.h"
namespace xla {
@@ -57,14 +57,23 @@ class PseudorandomGenerator {
// Generates fake data in a literal of the given shape, or returns an error
// status if the element type is currently unhandled for fake data
// generation. See below for documentation of pseudo_random.
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape,
- bool pseudo_random = true);
+StatusOr<Literal> MakeFakeLiteral(const Shape& shape,
+ bool pseudo_random = true);
// Generates a vector of arguments containing fake data. The number, shape, and
// layout of the arguments are appropriate for the given HLO module.
//
-// Will handle special cases such as making sure that indices used for dynamic
-// slices are bounded, reduces that call adds use 0 as an init value, etc.
+// A best-effort attempt is made to generate the data in a way that produces
+// stable computation results across platforms. Specifically:
+//
+// (1) Init values of reductions should be the identity of the reduction
+// computation.
+//
+// (2) Indices of dynamic slices and update slices should be in bounds.
+//
+// (3) Keys of key/value sorts should contain no duplicates.
+//
+// These constraints are best-effort only.
//
// If pseudo_random is true, the numbers will be generated deterministically
// in a pseudorandom way unless the values are constrained to
@@ -75,14 +84,26 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape,
// TODO(b/79942829): Make interesting argument generation fast enough that using
// pseudo_random does not save any noticeable amount of time so that the
// parameter can be removed.
-StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
- HloModule* const module, bool pseudo_random = true);
+StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
+ bool pseudo_random = true);
+
+// Overload which accepts a random number generator. This enables generation of
+// different random values with sequential calls to MakeFakeArguments by reusing
+// the same generator.
+StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
+ std::minstd_rand0* engine);
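
A brief usage sketch of the engine overload; the parsed 'module' is assumed to exist and is illustrative only:

  std::minstd_rand0 engine;
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> run1,
                          MakeFakeArguments(module.get(), &engine));
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> run2,
                          MakeFakeArguments(module.get(), &engine));
  // run1 and run2 differ because the engine state advances between the calls.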
// Check that a given module satisfies various constraints before trying to
// execute it.
-Status VerifyHloModule(HloModule* const module,
- bool allow_mixed_precision = false);
+Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
+ bool allow_mixed_precision);
+// Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 of
+// the LHS with dimension 0 of the RHS with no batch dimensions.
+// Both the LHS and the RHS must be of rank 2.
+std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
+ HloInstruction* lhs,
+ HloInstruction* rhs);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_
diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc
index 64d9e2031e..181e5cbe29 100644
--- a/tensorflow/compiler/xla/tests/test_utils_test.cc
+++ b/tensorflow/compiler/xla/tests/test_utils_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
@@ -84,10 +85,10 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2}
})")
.ValueOrDie();
- TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
MakeFakeArguments(module.get()));
ASSERT_EQ(args.size(), 3);
- const Literal& index_arg = *args[0];
+ const Literal& index_arg = args[0];
EXPECT_EQ(index_arg.Get<int32>({0}), 0);
@@ -113,10 +114,10 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param)
})")
.ValueOrDie();
- TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
MakeFakeArguments(module.get()));
ASSERT_EQ(args.size(), 5);
- const Literal& index_arg = *args[0];
+ const Literal& index_arg = args[0];
EXPECT_EQ(index_arg.Get<int32>({0}), 0);
@@ -127,5 +128,51 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
EXPECT_LE(index_arg.Get<int32>({2}), 3);
}
+XLA_TEST_F(TestUtilsTest, NoDuplicatesFloats) {
+ // Inputs which are sort keys in key/value sorts should have no duplicates.
+ auto module = ParseHloString(R"(
+HloModule sort.148.1589
+
+ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (f32[1048576], s32[1048576]) {
+ %parameter.0 = f32[1048576]{0} parameter(0)
+ %parameter.1 = s32[1048576]{0} parameter(1)
+ ROOT %sort.148.1589 = (f32[1048576]{0}, s32[1048576]{0}) sort(f32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0}
+}
+)")
+ .ValueOrDie();
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
+ MakeFakeArguments(module.get()));
+ ASSERT_EQ(args.size(), 2);
+ const Literal& key_arg = args[0];
+
+ tensorflow::gtl::FlatSet<uint32> key_set;
+ for (const float& value : key_arg.data<float>()) {
+ EXPECT_TRUE(key_set.insert(tensorflow::bit_cast<uint32>(value)).second);
+ }
+}
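
The float keys above are deduplicated by bit pattern rather than by float equality; a hedged standalone illustration of why that matters (+0.0f and -0.0f compare equal as floats yet are distinct sort keys):

  #include <cstdint>
  #include <cstring>

  uint32_t BitPattern(float f) {  // the moral equivalent of bit_cast<uint32>
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return u;
  }
  // BitPattern(0.0f) == 0x00000000 and BitPattern(-0.0f) == 0x80000000,
  // even though 0.0f == -0.0f under operator==.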
+
+XLA_TEST_F(TestUtilsTest, NoDuplicatesInt32) {
+ // Inputs which are sort keys in key/value sorts should have no duplicates.
+ auto module = ParseHloString(R"(
+HloModule sort.148.1589
+
+ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (s32[1048576], s32[1048576]) {
+ %parameter.0 = s32[1048576]{0} parameter(0)
+ %parameter.1 = s32[1048576]{0} parameter(1)
+ ROOT %sort.148.1589 = (s32[1048576]{0}, s32[1048576]{0}) sort(s32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0}
+}
+)")
+ .ValueOrDie();
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
+ MakeFakeArguments(module.get()));
+ ASSERT_EQ(args.size(), 2);
+ const Literal& key_arg = args[0];
+
+ tensorflow::gtl::FlatSet<int32> key_set;
+ for (const int32& value : key_arg.data<int32>()) {
+ EXPECT_TRUE(key_set.insert(value).second);
+ }
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc
index 2bdbd08309..b34fd0f2e8 100644
--- a/tensorflow/compiler/xla/tests/token_hlo_test.cc
+++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc
@@ -15,11 +15,10 @@ limitations under the License.
#include <array>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -35,9 +34,8 @@ XLA_TEST_F(TokenHloTest, SingleTokenInstruction) {
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- Execute(std::move(module), {}));
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken()));
+ TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {}));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken()));
}
XLA_TEST_F(TokenHloTest, TokenTree) {
@@ -51,9 +49,8 @@ XLA_TEST_F(TokenHloTest, TokenTree) {
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- Execute(std::move(module), {}));
- EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken()));
+ TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {}));
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken()));
}
XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) {
@@ -67,7 +64,10 @@ XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) {
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(42)));
module->AddEntryComputation(builder.Build());
- Status status = HloVerifier().Run(module.get()).status();
+ Status status =
+ HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false)
+ .Run(module.get())
+ .status();
ASSERT_IS_NOT_OK(status);
EXPECT_THAT(
status.error_message(),
@@ -84,7 +84,10 @@ XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) {
"param"));
module->AddEntryComputation(builder.Build());
- Status status = HloVerifier().Run(module.get()).status();
+ Status status =
+ HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false)
+ .Run(module.get())
+ .status();
ASSERT_IS_NOT_OK(status);
EXPECT_THAT(
status.error_message(),
@@ -101,7 +104,10 @@ XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) {
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(123)));
module->AddEntryComputation(builder.Build());
- Status status = HloVerifier().Run(module.get()).status();
+ Status status =
+ HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false)
+ .Run(module.get())
+ .status();
ASSERT_IS_NOT_OK(status);
EXPECT_THAT(status.error_message(),
::testing::HasSubstr(
@@ -185,9 +191,8 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] {
std::unique_ptr<HloModule> module,
HloRunner::CreateModuleFromString(module_string, debug_options));
auto arg = LiteralUtil::CreateR0<bool>(true);
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- Execute(std::move(module), {arg.get()}));
- EXPECT_EQ(42, result->Get<int32>({}));
+ TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {&arg}));
+ EXPECT_EQ(42, result.Get<int32>({}));
}
{
@@ -196,9 +201,8 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] {
std::unique_ptr<HloModule> module,
HloRunner::CreateModuleFromString(module_string, debug_options));
auto arg = LiteralUtil::CreateR0<bool>(false);
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- Execute(std::move(module), {arg.get()}));
- EXPECT_EQ(7, result->Get<int32>({}));
+ TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {&arg}));
+ EXPECT_EQ(7, result.Get<int32>({}));
}
}
diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc
index 125513ddfd..d6641d257a 100644
--- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc
+++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc
@@ -69,90 +69,90 @@ class TransferManagerTest : public LocalClientTestBase {
};
XLA_TEST_F(TransferManagerTest, TransferR0U32) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<uint32>(42);
- const Shape& shape = literal->shape();
+ Literal literal = LiteralUtil::CreateR0<uint32>(42);
+ const Shape& shape = literal.shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- LiteralTestUtil::ExpectR0Equal<uint32>(42, *result);
+ LiteralTestUtil::ExpectR0Equal<uint32>(42, result);
}
XLA_TEST_F(TransferManagerTest, TransferR1F32) {
- std::unique_ptr<Literal> literal =
+ Literal literal =
LiteralUtil::CreateR1<float>({1.25f, 2.5f, -17.0f, -20.125f});
- const Shape& shape = literal->shape();
+ const Shape& shape = literal.shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
LiteralTestUtil::ExpectR1Equal<float>({1.25f, 2.5f, -17.0f, -20.125f},
- *result);
+ result);
}
XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) {
std::vector<float> test_vector(1024 * 1024);
std::iota(test_vector.begin(), test_vector.end(), 0);
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<float>(test_vector);
- const Shape& shape = literal->shape();
+ Literal literal = LiteralUtil::CreateR1<float>(test_vector);
+ const Shape& shape = literal.shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- LiteralTestUtil::ExpectR1Equal<float>(test_vector, *result);
+ LiteralTestUtil::ExpectR1Equal<float>(test_vector, result);
}
XLA_TEST_F(TransferManagerTest, TransferR1U8) {
const char* test_string = "0123456789abcdef";
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1U8(test_string);
- const Shape& shape = literal->shape();
+ Literal literal = LiteralUtil::CreateR1U8(test_string);
+ const Shape& shape = literal.shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_EQ(result->GetR1U8AsString(), test_string);
+ EXPECT_EQ(result.GetR1U8AsString(), test_string);
}
XLA_TEST_F(TransferManagerTest, TransferR2F32) {
- std::unique_ptr<Literal> literal =
+ Literal literal =
LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
- const Shape& shape = literal->shape();
+ const Shape& shape = literal.shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result);
+ {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, result);
}
XLA_TEST_F(TransferManagerTest,
TransferR2F32AndChangeLayoutTransferringToDevice) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
+ Literal literal = LiteralUtil::CreateR2WithLayout<float>(
{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, LayoutUtil::MakeLayout({0, 1}));
const Shape ondevice_shape =
ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
@@ -160,101 +160,99 @@ XLA_TEST_F(TransferManagerTest,
// Round trip literal through device. Set the on-device layout to something
// different than the literal layout.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
EXPECT_FALSE(
- LayoutUtil::Equal(result->shape().layout(), literal->shape().layout()));
+ LayoutUtil::Equal(result.shape().layout(), literal.shape().layout()));
LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result);
+ {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, result);
}
XLA_TEST_F(TransferManagerTest, TransferTuple) {
- std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(123.0f).get(),
- LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
- LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()});
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ Literal literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(123.0f),
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}),
+ LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f})});
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) {
- std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple({});
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ Literal literal = LiteralUtil::MakeTuple({});
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
XLA_TEST_F(TransferManagerTest, TransferNestedTuple) {
- std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(123.0f).get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
- LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()})
- .get(),
- LiteralUtil::CreateR1<float>({-10.0f, 123.0f}).get()});
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ Literal literal = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(123.0f),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}),
+ LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f})}),
+ LiteralUtil::CreateR1<float>({-10.0f, 123.0f})});
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
XLA_TEST_F(TransferManagerTest, TransferComplexValue) {
- std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<complex64>(
+ Literal literal = LiteralUtil::CreateR1<complex64>(
{complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)});
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) {
- std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple(
+ Literal literal = LiteralUtil::MakeTupleFromSlices(
{LiteralUtil::CreateR1<complex64>(
- {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)})
- .get(),
- LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6}).get(),
- LiteralUtil::CreateR0<complex64>(complex64(0.3f, -0.4f)).get()});
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}),
+ LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6}),
+ LiteralUtil::CreateR0<complex64>(complex64(0.3f, -0.4f))});
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
}
XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) {
@@ -264,54 +262,52 @@ XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) {
// supported.
auto device_buffer = AllocateDeviceBuffer(ShapeUtil::MakeTokenShape());
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
- EXPECT_TRUE(LiteralTestUtil::Equal(*LiteralUtil::CreateToken(), *result));
+ EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateToken(), result));
}
XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) {
const int64 kIterationCount = 5000;
- std::unique_ptr<Literal> literal1 = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(123.0f).get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
- LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()})
- .get(),
- LiteralUtil::CreateR1<float>({-10.0f, 123.0f}).get()});
- std::unique_ptr<Literal> literal2 = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(456.0f).get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>({{5.0f, 7.0f}, {9.0f, 4.0f}}).get(),
- LiteralUtil::CreateR1<float>({44.0f, -11.0f, 3333333.3f}).get()})
- .get(),
- LiteralUtil::CreateR1<float>({-98.0f, 153.0f}).get()});
-
- auto device_buffer1 = AllocateDeviceBuffer(literal1->shape());
- auto device_buffer2 = AllocateDeviceBuffer(literal2->shape());
+ Literal literal1 = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(123.0f),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}),
+ LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f})}),
+ LiteralUtil::CreateR1<float>({-10.0f, 123.0f})});
+ Literal literal2 = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(456.0f),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>({{5.0f, 7.0f}, {9.0f, 4.0f}}),
+ LiteralUtil::CreateR1<float>({44.0f, -11.0f, 3333333.3f})}),
+ LiteralUtil::CreateR1<float>({-98.0f, 153.0f})});
+
+ auto device_buffer1 = AllocateDeviceBuffer(literal1.shape());
+ auto device_buffer2 = AllocateDeviceBuffer(literal2.shape());
auto stream1 = stream_;
auto stream2 = stream_->GetOrCreateSubStream();
- std::unique_ptr<Literal> result1, result2;
+ Literal result1, result2;
// Round trip literals through device in multiple streams asynchronously.
for (int i = 0; i < kIterationCount; ++i) {
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, *literal1,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, literal1,
device_buffer1));
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, *literal2,
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, literal2,
device_buffer2));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> this_result1,
+ Literal this_result1,
transfer_manager_->TransferLiteralFromDevice(stream1, device_buffer1));
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> this_result2,
+ Literal this_result2,
transfer_manager_->TransferLiteralFromDevice(stream2, device_buffer2));
result1 = std::move(this_result1);
result2 = std::move(this_result2);
}
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal1, *result1));
- EXPECT_TRUE(LiteralTestUtil::Equal(*literal2, *result2));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal1, result1));
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal2, result2));
}
class TransferDeviceToHostBenchmark : public TransferManagerTest {
@@ -323,20 +319,19 @@ class TransferDeviceToHostBenchmark : public TransferManagerTest {
tensorflow::testing::StopTiming();
SetUp();
- std::vector<std::unique_ptr<Literal>> tuple_elements;
+ std::vector<Literal> tuple_elements;
for (int i = 0; i < num_tuple_elements; ++i) {
tuple_elements.push_back(
LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size));
}
- std::unique_ptr<Literal> literal =
- LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
- TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
+ TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Literal> result,
+ Literal result,
transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
}
tensorflow::testing::StopTiming();
@@ -355,17 +350,16 @@ class TransferHostToDeviceBenchmark : public TransferManagerTest {
tensorflow::testing::StopTiming();
SetUp();
- std::vector<std::unique_ptr<Literal>> tuple_elements;
+ std::vector<Literal> tuple_elements;
for (int i = 0; i < num_tuple_elements; ++i) {
tuple_elements.push_back(
LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size));
}
- std::unique_ptr<Literal> literal =
- LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
- auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
+ auto device_buffer = AllocateDeviceBuffer(literal.shape());
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
- TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
device_buffer));
}
tensorflow::testing::StopTiming();
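The transfer_manager_test.cc hunks above are one instance of this commit's
central migration: xla::Literal moves from std::unique_ptr ownership to value
semantics. A minimal sketch of the call-site pattern, assuming the post-change
factories shown in this diff (UseLiteral is a hypothetical helper, not part of
the commit):

    // Before: factories returned std::unique_ptr<Literal>; callers wrote
    //   std::unique_ptr<Literal> lit = LiteralUtil::CreateR0<float>(1.0f);
    //   UseLiteral(*lit);
    // After: factories return Literal by value. Literal is move-only, so it
    // is passed by const reference at call sites like TransferLiteralToDevice.
    Literal lit = LiteralUtil::CreateR0<float>(1.0f);
    UseLiteral(lit);
    // MakeTupleFromSlices assembles a new tuple from the given literals; the
    // hunks above pass both named literals and temporaries as elements.
    Literal tuple = LiteralUtil::MakeTupleFromSlices(
        {lit, LiteralUtil::CreateR1<float>({2.0f, 3.0f})});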
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
index 97bbf80aff..619d2a388b 100644
--- a/tensorflow/compiler/xla/tests/tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include <initializer_list>
#include <memory>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -50,13 +51,13 @@ XLA_TEST_F(TupleTest, TupleConstant) {
{1.1f, 2.2f, 3.5f}, // row 0
{4.8f, 5.0f, 6.7f}, // row 1
};
- auto value = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(constant_scalar).get(),
- LiteralUtil::CreateR1<float>(constant_vector).get(),
- LiteralUtil::CreateR2<float>(constant_matrix).get()});
+ auto value = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(constant_scalar),
+ LiteralUtil::CreateR1<float>(constant_vector),
+ LiteralUtil::CreateR2<float>(constant_matrix)});
- ConstantLiteral(&builder, *value);
- ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
+ ConstantLiteral(&builder, value);
+ ComputeAndCompareTuple(&builder, value, {}, error_spec_);
}
// Tests a tuple made of scalar constants.
@@ -65,12 +66,12 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) {
const float constant_scalar1 = 7.3f;
const float constant_scalar2 = 1.2f;
- auto value = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(constant_scalar1).get(),
- LiteralUtil::CreateR0<float>(constant_scalar2).get()});
+ auto value = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(constant_scalar1),
+ LiteralUtil::CreateR0<float>(constant_scalar2)});
- ConstantLiteral(&builder, *value);
- ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
+ ConstantLiteral(&builder, value);
+ ComputeAndCompareTuple(&builder, value, {}, error_spec_);
}
// Tests the creation of tuple data.
@@ -87,11 +88,11 @@ XLA_TEST_F(TupleTest, TupleCreate) {
ConstantR1<float>(&builder, constant_vector),
ConstantR2<float>(&builder, constant_matrix)});
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(constant_scalar).get(),
- LiteralUtil::CreateR1<float>(constant_vector).get(),
- LiteralUtil::CreateR2<float>(constant_matrix).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(constant_scalar),
+ LiteralUtil::CreateR1<float>(constant_vector),
+ LiteralUtil::CreateR2<float>(constant_matrix)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
// Tests the creation of tuple data.
@@ -101,10 +102,9 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
Tuple(&builder,
{ConstantR0<float>(&builder, 7.0), ConstantR1<float>(&builder, {})});
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(7.0).get(),
- LiteralUtil::CreateR1<float>({}).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(7.0), LiteralUtil::CreateR1<float>({})});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
// Tests the creation of an empty tuple.
@@ -112,7 +112,7 @@ XLA_TEST_F(TupleTest, EmptyTupleCreate) {
XlaBuilder builder(TestName());
Tuple(&builder, {});
auto expected = LiteralUtil::MakeTuple({});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
// Trivial test for extracting a tuple element with GetTupleElement.
@@ -195,10 +195,10 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) {
ConstantR2<float>(&builder, constant_matrix)});
Tuple(&builder,
{GetTupleElement(tuple_data, 1), GetTupleElement(tuple_data, 0)});
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>(constant_matrix).get(),
- LiteralUtil::CreateR1<float>(constant_vector).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>(constant_matrix),
+ LiteralUtil::CreateR1<float>(constant_vector)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
@@ -217,11 +217,11 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
auto v1_v2 = Tuple(&b, {v1_gt, v2_gt}); // {false, true}
auto v2_v1 = Tuple(&b, {v2_gt, v1_gt}); // {true, false}
Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1);
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<bool>(direction).get(),
- LiteralUtil::CreateR0<bool>(!direction).get()});
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<bool>(direction),
+ LiteralUtil::CreateR0<bool>(!direction)});
- ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()},
+ ComputeAndCompareTuple(&b, expected, {v1_data.get(), v2_data.get()},
error_spec_);
}
}
@@ -286,10 +286,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) {
ConstantR1<float>(&builder, vec1)});
Select(ConstantR0<bool>(&builder, false), tuple12, tuple21);
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(),
- LiteralUtil::CreateR1<float>(vec1).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>(vec2), LiteralUtil::CreateR1<float>(vec1)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, TuplesInAMap) {
@@ -331,10 +330,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) {
ConstantR1<float>(&builder, vec1)});
Select(ConstantR0<bool>(&builder, true), tuple12, tuple21);
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec1).get(),
- LiteralUtil::CreateR1<float>(vec2).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>(vec1), LiteralUtil::CreateR1<float>(vec2)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
@@ -407,10 +405,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) {
Select(ConstantR0<bool>(&builder, false), tuple12, tuple21);
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(),
- LiteralUtil::CreateR1<float>(vec1).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>(vec2), LiteralUtil::CreateR1<float>(vec1)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, NestedTuples) {
@@ -422,12 +419,11 @@ XLA_TEST_F(TupleTest, NestedTuples) {
auto expected_v1 = LiteralUtil::CreateR1<float>({1.0, 2.0});
auto expected_s = LiteralUtil::CreateR0<float>(42.0);
auto expected_inner_tuple =
- LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()});
+ LiteralUtil::MakeTuple({&expected_v1, &expected_s});
auto expected_v2 = LiteralUtil::CreateR1<float>({22.0, 44.0});
- auto expected =
- LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()});
+ auto expected = LiteralUtil::MakeTuple({&expected_inner_tuple, &expected_v2});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
@@ -445,14 +441,12 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
std::unique_ptr<GlobalData> data =
client_
- ->TransferToServer(*LiteralUtil::MakeTuple({
- LiteralUtil::MakeTuple(
- {
- LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}).get(),
- LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}).get(),
- })
- .get(),
- LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}).get(),
+ ->TransferToServer(LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}),
+ LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}),
+ }),
+ LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}),
}))
.ConsumeValueOrDie();
@@ -483,40 +477,36 @@ XLA_TEST_F(TupleTest, ComplexTuples) {
std::unique_ptr<GlobalData> arg0 =
client_
- ->TransferToServer(*LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<complex64>({1, 2}).get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<complex64>({{10, 20}, {30, 40}})
- .get(),
+ ->TransferToServer(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<complex64>({1, 2}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<complex64>({{10, 20}, {30, 40}}),
LiteralUtil::CreateR2<complex64>(
{{{100, 200}, {300, 400}},
{{1000, 2000}, {3000, 4000}},
- {{10000, 20000}, {30000, 40000}}})
- .get()})
- .get()}))
+ {{10000, 20000}, {30000, 40000}}})})}))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> arg1 =
client_
->TransferToServer(
- *LiteralUtil::CreateR1<complex64>({{1, 2}, {1, -2}}))
+ LiteralUtil::CreateR1<complex64>({{1, 2}, {1, -2}}))
.ConsumeValueOrDie();
auto sum =
LiteralUtil::CreateR2<complex64>({{{111, 222}, {331, 442}},
{{1011, 2022}, {3031, 4042}},
{{10011, 20022}, {30031, 40042}}});
- auto prod = MakeUnique<Literal>(sum->shape());
- ASSERT_TRUE(prod->Populate<complex64>(
- [&sum](tensorflow::gtl::ArraySlice<int64> indexes) {
- return sum->Get<complex64>(indexes) *
- (indexes[indexes.size() - 1] == 0
- ? complex64(1, 2)
- : complex64(1, -2));
- })
+ Literal prod(sum.shape());
+ ASSERT_TRUE(prod.Populate<complex64>([&sum](absl::Span<const int64> indexes) {
+ return sum.Get<complex64>(indexes) *
+ (indexes[indexes.size() - 1] == 0
+ ? complex64(1, 2)
+ : complex64(1, -2));
+ })
.ok());
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::MakeTuple({prod.get(), sum.get()}).get(),
- LiteralUtil::CreateR0<complex64>({123, 456}).get()});
- ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()},
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::MakeTupleFromSlices({prod, sum}),
+ LiteralUtil::CreateR0<complex64>({123, 456})});
+ ComputeAndCompareTuple(&builder, expected, {arg0.get(), arg1.get()},
error_spec_);
}
@@ -540,10 +530,10 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) {
.ValueOrDie();
auto param =
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({1, 2, 3}));
- auto result = ExecuteNoHloPasses(std::move(module), {param.get()});
+ auto result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2<float>({{1, 2, 3}})),
- *result));
+ LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2<float>({{1, 2, 3}})),
+ result));
}
// Disabled on interpreter due to lack of outfeed.
@@ -580,16 +570,15 @@ XLA_TEST_F(TupleHloTest,
tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "execute_thread", [&] {
TF_EXPECT_OK(Execute(std::move(module),
- {param0.get(), param1.get(), param1.get(),
- param0.get(), param4.get()})
+ {&param0, &param1, &param1, &param0, &param4})
.status());
}));
auto expected =
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({2, 3}));
- auto literal = Literal::CreateFromShape(expected->shape());
+ auto literal = Literal::CreateFromShape(expected.shape());
TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
- backend().default_stream_executor(), expected->shape(), *literal));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *literal));
+ backend().default_stream_executor(), expected.shape(), literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal));
}
} // namespace
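The tuple_test.cc hunks show the two tuple-building paths left after the
migration: LiteralUtil::MakeTupleFromSlices consumes literal values (including
temporaries), while LiteralUtil::MakeTuple now takes raw pointers to named
Literal objects instead of .get() on unique_ptrs. A short sketch under those
assumptions:

    // Elements as values; temporaries are fine:
    Literal by_value = LiteralUtil::MakeTupleFromSlices(
        {LiteralUtil::CreateR0<float>(7.0f), LiteralUtil::CreateR1<float>({})});
    // Elements as pointers; the pointees must outlive the call:
    Literal a = LiteralUtil::CreateR0<float>(7.0f);
    Literal b = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
    Literal by_pointer = LiteralUtil::MakeTuple({&a, &b});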
diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc
index 20ae68ab74..4fbd7f2fb1 100644
--- a/tensorflow/compiler/xla/tests/unary_op_test.cc
+++ b/tensorflow/compiler/xla/tests/unary_op_test.cc
@@ -100,9 +100,9 @@ void UnaryOpTest::AbsTestHelper<complex64>() {
{-inf<float>(), 0}});
Abs(arg);
- std::unique_ptr<Literal> expected =
+ Literal expected =
LiteralUtil::CreateR1<float>({2, 25, 0, 0.5, inf<float>(), inf<float>()});
- ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
+ ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
}
template <>
@@ -113,9 +113,9 @@ void UnaryOpTest::SignTestHelper<complex64>() {
{{-2, 0}, {0, 25}, {0, 0}, {static_cast<float>(-0.0), 0}, {-1, 1}});
Sign(arg);
- std::unique_ptr<Literal> expected = LiteralUtil::CreateR1<complex64>(
+ Literal expected = LiteralUtil::CreateR1<complex64>(
{{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}});
- ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
+ ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
}
template <>
@@ -127,9 +127,8 @@ void UnaryOpTest::SignAbsTestHelper<complex64>() {
auto abs = Abs(arg);
Sub(Mul(sign, ConvertElementType(abs, C64)), arg);
- std::unique_ptr<Literal> expected =
- LiteralUtil::CreateR1<complex64>({0, 0, 0, 0});
- ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
+ Literal expected = LiteralUtil::CreateR1<complex64>({0, 0, 0, 0});
+ ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
}
XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) {
@@ -172,9 +171,8 @@ XLA_TEST_F(UnaryOpTest, SignTestR0) {
Add(sgnc, ConvertElementType(
Add(Add(sgnf0, sgnf), ConvertElementType(sgni, F32)), C64));
- std::unique_ptr<Literal> expected =
- LiteralUtil::CreateR0<complex64>({-2.6f, 0.8f});
- ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
+ Literal expected = LiteralUtil::CreateR0<complex64>({-2.6f, 0.8f});
+ ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
}
XLA_TEST_F(UnaryOpTest, SignTestR1) {
@@ -190,25 +188,6 @@ XLA_TEST_F(UnaryOpTest, SignAbsTestR1) {
SignAbsTestHelper<complex64>();
}
-XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) {
- XlaBuilder builder(TestName());
- auto arg = ConstantR1<unsigned int>(
- &builder, {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()});
- Abs(arg);
-
- ComputeAndCompareR1<unsigned int>(
- &builder, {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()}, {});
-}
-
-XLA_TEST_F(UnaryOpTest, UnsignedSignTestR1) {
- XlaBuilder builder(TestName());
- auto arg = ConstantR1<unsigned int>(
- &builder, {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()});
- Sign(arg);
-
- ComputeAndCompareR1<unsigned int>(&builder, {1, 1, 0, 1, 1}, {});
-}
-
XLA_TEST_F(UnaryOpTest, SignAbsTestR2) {
XlaBuilder builder(TestName());
auto arg = ConstantR2<float>(&builder, {{1.0, -2.0}, {-3.0, 4.0}});
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index 1bdf1867b9..7abd8651d5 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -348,9 +348,9 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
// have all reached 2.0.
auto expected_data =
LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f});
- auto expected = LiteralUtil::MakeTuple({expected_data.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple({&expected_data});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
@@ -401,11 +401,10 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
auto expected_w1 = LiteralUtil::CreateR1<float>({1.0f, 1.0f, 1.0f});
auto expected_w2 = LiteralUtil::CreateR1<float>({2.0f, 2.0f, 2.0f});
auto expected_w3 = LiteralUtil::CreateR1<float>({3.0f, 3.0f, 3.0f});
- auto expected =
- LiteralUtil::MakeTuple({expected_counter.get(), expected_w2.get(),
- expected_w3.get(), expected_w1.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple(
+ {&expected_counter, &expected_w2, &expected_w3, &expected_w1});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
@@ -510,10 +509,9 @@ TEST_F(WhileTest, WhileWithTupleResult) {
auto expected_counter = LiteralUtil::CreateR0<int32>(5);
auto expected_data = LiteralUtil::CreateR1<float>(
{5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f});
- auto expected =
- LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
TEST_F(WhileTest, WhileWithPredicateTupleResult) {
@@ -557,9 +555,9 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) {
auto expected_counter = LiteralUtil::CreateR0<int32>(5);
auto expected_predicate = LiteralUtil::CreateR0<bool>(true);
- auto expected = LiteralUtil::MakeTuple(
- {expected_counter.get(), expected_predicate.get()});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0));
+ auto expected =
+ LiteralUtil::MakeTuple({&expected_counter, &expected_predicate});
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0));
}
TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
@@ -602,10 +600,9 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
auto expected_counter = LiteralUtil::CreateR0<int32>(5);
auto expected_data = LiteralUtil::CreateR0<int32>(7);
- auto expected =
- LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
// Tests two while nodes when the result type T is a Tuple and the second
@@ -886,10 +883,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
auto expected_counter = LiteralUtil::CreateR0<int32>(5);
auto expected_data = LiteralUtil::CreateR1<float>(
{1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f});
- auto expected =
- LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
// Tests a while node when the result type T is a vector of S32.
@@ -977,11 +973,11 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) {
auto expected_element = LiteralUtil::CreateR1<float>({1, 1});
auto expected =
- LiteralUtil::MakeTuple({expected_element.get(), expected_element.get()});
+ LiteralUtil::MakeTuple({&expected_element, &expected_element});
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*LiteralUtil::CreateR1<float>({42, 42})));
- ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()},
+ client_->TransferToServer(LiteralUtil::CreateR1<float>({42, 42})));
+ ComputeAndCompareTuple(&outer, expected, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1005,7 +1001,7 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*LiteralUtil::CreateR1<float>({42, 42})));
+ client_->TransferToServer(LiteralUtil::CreateR1<float>({42, 42})));
ComputeAndCompareR1<float>(&outer, {1.0f, 1.0f}, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1031,7 +1027,7 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*LiteralUtil::CreateR0<float>(42)));
+ client_->TransferToServer(LiteralUtil::CreateR0<float>(42)));
ComputeAndCompareR0<float>(&outer, 43.0f, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1070,12 +1066,12 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*LiteralUtil::CreateR0<int32>(1)));
+ client_->TransferToServer(LiteralUtil::CreateR0<int32>(1)));
auto add1 = LiteralUtil::CreateR0<int32>(15);
auto add2 = LiteralUtil::CreateR0<int32>(16);
- auto expected = LiteralUtil::MakeTuple({add1.get(), add2.get()});
- ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()},
+ auto expected = LiteralUtil::MakeTuple({&add1, &add2});
+ ComputeAndCompareTuple(&outer, expected, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1228,7 +1224,7 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) {
GetTupleElement(while_instruction, 3);
TF_ASSERT_OK_AND_ASSIGN(
- auto param_value, client_->TransferToServer(*LiteralUtil::CreateR2<float>(
+ auto param_value, client_->TransferToServer(LiteralUtil::CreateR2<float>(
{{1.0, 2.0}, {-1.0, -2.0}})));
ComputeAndCompareR2<float>(
@@ -1258,9 +1254,9 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) {
XlaBuilder builder(TestName());
While(condition, body, ConstantR0<int32>(&builder, 0));
- TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
- TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
- TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(false)));
+ TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));
+ TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));
+ TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(false)));
ComputeAndCompareR0<int32>(&builder, 2, {});
}
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index 11f3efb1f3..db5a824de0 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -16,6 +16,10 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/algorithm/container.h"
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -29,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -81,8 +84,7 @@ struct ParsedProfileOutputLine {
Status ParseOneProfileOutputLine(
const string& line, bool expect_hlo,
gtl::FlatMap<string, ParsedProfileOutputLine>* parsed_results,
- tensorflow::gtl::ArraySlice<tensorflow::StringPiece> opcodes_to_ignore =
- {}) {
+ absl::Span<const absl::string_view> opcodes_to_ignore = {}) {
string separator = "[^:]*:: +";
string match_percentage = R"(\d+\.\d*% +\d+Σ)";
string match_cycles = R"((\d+) cycles +\( *()" + match_percentage + R"()\))";
@@ -99,7 +101,7 @@ Status ParseOneProfileOutputLine(
string match_opcode =
expect_hlo ? "%[^=]+= [^ ]+ ([^(]+)\\(.*" : "(\\[total\\])";
- string regexp_pattern = tensorflow::strings::StrCat(
+ string regexp_pattern = absl::StrCat(
" +", match_cycles, separator, match_usecs, separator, match_flops,
separator, match_trops, separator, match_bytes_per_sec, separator,
match_bytes_per_cycle, separator, match_opcode);
@@ -116,7 +118,7 @@ Status ParseOneProfileOutputLine(
", Regexp: ", regexp_pattern);
}
- if (!c_linear_search(opcodes_to_ignore, parsed_line.opcode)) {
+ if (!absl::c_linear_search(opcodes_to_ignore, parsed_line.opcode)) {
InsertOrDie(parsed_results, parsed_line.opcode, parsed_line);
}
@@ -142,14 +144,14 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
transfer_manager->AllocateScopedShapedBuffer(
lhs_arg_shape, allocator, backend->default_device_ordinal()));
TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
- stream_ptr.get(), *Literal::CreateFromShape(lhs_arg_shape), lhs_arg));
+ stream_ptr.get(), Literal::CreateFromShape(lhs_arg_shape), lhs_arg));
TF_ASSERT_OK_AND_ASSIGN(
ScopedShapedBuffer rhs_arg,
transfer_manager->AllocateScopedShapedBuffer(
rhs_arg_shape, allocator, backend->default_device_ordinal()));
TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
- stream_ptr.get(), *Literal::CreateFromShape(rhs_arg_shape), rhs_arg));
+ stream_ptr.get(), Literal::CreateFromShape(rhs_arg_shape), rhs_arg));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<LocalExecutable> local_executable,
@@ -169,10 +171,10 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
ServiceExecutableRunOptions run_options(
exec_run_options, /*borrow_stream=*/nullptr,
backend->eigen_intra_op_thread_pool());
+ std::vector<const ShapedBuffer*> args = {&lhs_arg, &rhs_arg};
TF_ASSERT_OK_AND_ASSIGN(
auto execution_result,
- executable->ExecuteOnStream(&run_options, {&lhs_arg, &rhs_arg},
- &hlo_execution_profile));
+ executable->ExecuteOnStream(&run_options, args, &hlo_execution_profile));
TF_ASSERT_OK(stream_ptr->BlockHostUntilDone());
(void)execution_result;
@@ -204,7 +206,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) {
rhs_shape);
std::vector<string> profile_output_lines =
- tensorflow::str_util::Split(profile_output, '\n');
+ absl::StrSplit(profile_output, '\n');
gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines;
@@ -291,22 +293,20 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) {
matrix_shape);
std::vector<string> profile_output_lines =
- tensorflow::str_util::Split(profile_output, '\n');
+ absl::StrSplit(profile_output, '\n');
auto while_body_profile_start =
- c_find_if(profile_output_lines, [](tensorflow::StringPiece s) {
- return tensorflow::str_util::StartsWith(s,
- "Execution profile for body");
+ absl::c_find_if(profile_output_lines, [](absl::string_view s) {
+ return absl::StartsWith(s, "Execution profile for body");
});
ASSERT_NE(while_body_profile_start, profile_output_lines.cend());
- auto while_body_profile_end =
- std::find_if(while_body_profile_start, profile_output_lines.end(),
- [](tensorflow::StringPiece s) {
- return tensorflow::str_util::StartsWith(
- s, "********** microseconds report **********");
- });
+ auto while_body_profile_end = std::find_if(
+ while_body_profile_start, profile_output_lines.end(),
+ [](absl::string_view s) {
+ return absl::StartsWith(s, "********** microseconds report **********");
+ });
// We emit a blank line before the "********** microseconds report **********"
// line.
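The xla_hlo_profile_test.cc hunks replace tensorflow::str_util and StringPiece
with their absl counterparts: absl::StrSplit, absl::StrCat, absl::StartsWith,
absl::string_view, and the container algorithms absl::c_find_if /
absl::c_linear_search. A standalone sketch of the idiom (profile_output stands
in for the string being parsed; not code from the commit):

    #include <string>
    #include <vector>
    #include "absl/algorithm/container.h"
    #include "absl/strings/match.h"
    #include "absl/strings/str_split.h"

    std::vector<std::string> lines = absl::StrSplit(profile_output, '\n');
    auto it = absl::c_find_if(lines, [](absl::string_view s) {
      return absl::StartsWith(s, "Execution profile for body");
    });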
diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc
index a075195618..15603619b6 100644
--- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc
+++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "absl/strings/match.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -32,16 +32,14 @@ GTEST_API_ int main(int argc, char** argv) {
// If the --benchmarks flag is passed in then only run the benchmarks, not the
// tests.
for (int i = 1; i < argc; i++) {
- tensorflow::StringPiece arg(argv[i]);
- if (arg == "--benchmarks" ||
- tensorflow::str_util::StartsWith(arg, "--benchmarks=")) {
+ absl::string_view arg(argv[i]);
+ if (arg == "--benchmarks" || absl::StartsWith(arg, "--benchmarks=")) {
const char* pattern = nullptr;
- if (tensorflow::str_util::StartsWith(arg, "--benchmarks=")) {
+ if (absl::StartsWith(arg, "--benchmarks=")) {
pattern = argv[i] + strlen("--benchmarks=");
} else {
// Handle flag of the form '--benchmarks foo' (no '=').
- if (i + 1 >= argc ||
- tensorflow::str_util::StartsWith(argv[i + 1], "--")) {
+ if (i + 1 >= argc || absl::StartsWith(argv[i + 1], "--")) {
LOG(ERROR) << "--benchmarks flag requires an argument.";
return 2;
}
diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc
index 897123d760..cdde88c135 100644
--- a/tensorflow/compiler/xla/text_literal_reader.cc
+++ b/tensorflow/compiler/xla/text_literal_reader.cc
@@ -20,25 +20,27 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/memory/memory.h"
+#include "absl/strings/match.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/strip.h"
#include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
-StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadPath(
- tensorflow::StringPiece path) {
- CHECK(!tensorflow::str_util::EndsWith(path, ".gz"))
+StatusOr<Literal> TextLiteralReader::ReadPath(absl::string_view path) {
+ CHECK(!absl::EndsWith(path, ".gz"))
<< "TextLiteralReader no longer supports reading .gz files";
std::unique_ptr<tensorflow::RandomAccessFile> file;
Status s =
@@ -54,34 +56,7 @@ StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadPath(
TextLiteralReader::TextLiteralReader(tensorflow::RandomAccessFile* file)
: file_(file) {}
-namespace {
-// This is an optimized version of tensorflow::str_util::Split which uses
-// StringPiece for the delimited strings and uses an out parameter for the
-// result to avoid vector creation/destruction.
-void SplitByDelimToStringPieces(tensorflow::StringPiece text, char delim,
- std::vector<tensorflow::StringPiece>* result) {
- result->clear();
-
- if (text.empty()) {
- return;
- }
-
- // The following loop is a little strange: its bound is text.size() + 1
- // instead of the more typical text.size().
- // The final iteration of the loop (when i is equal to text.size()) handles
- // the trailing token.
- size_t token_start = 0;
- for (size_t i = 0; i < text.size() + 1; i++) {
- if (i == text.size() || text[i] == delim) {
- tensorflow::StringPiece token(text.data() + token_start, i - token_start);
- result->push_back(token);
- token_start = i + 1;
- }
- }
-}
-} // namespace
-
-StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() {
+StatusOr<Literal> TextLiteralReader::ReadAllLines() {
tensorflow::io::RandomAccessInputStream stream(file_.get());
tensorflow::io::BufferedInputStream buf(&stream, 65536);
string shape_string;
@@ -90,63 +65,57 @@ StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() {
return s;
}
- tensorflow::StringPiece sp(shape_string);
- if (tensorflow::str_util::RemoveWhitespaceContext(&sp) > 0) {
- string tmp = std::string(sp);
- shape_string = tmp;
- }
+ absl::StripAsciiWhitespace(&shape_string);
TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::ParseShapeString(shape_string));
if (shape.element_type() != F32) {
return Unimplemented(
"unsupported element type for text literal reading: %s",
- ShapeUtil::HumanString(shape).c_str());
+ ShapeUtil::HumanString(shape));
}
- auto result = MakeUnique<Literal>(shape);
+ Literal result(shape);
const float fill = std::numeric_limits<float>::quiet_NaN();
- result->PopulateWithValue<float>(fill);
- std::vector<tensorflow::StringPiece> pieces;
- std::vector<tensorflow::StringPiece> coordinates;
+ result.PopulateWithValue<float>(fill);
+ std::vector<absl::string_view> pieces;
+ std::vector<absl::string_view> coordinates;
std::vector<int64> coordinate_values;
string line;
while (buf.ReadLine(&line).ok()) {
- SplitByDelimToStringPieces(line, ':', &pieces);
- tensorflow::StringPiece coordinates_string = pieces[0];
- tensorflow::StringPiece value_string = pieces[1];
- tensorflow::str_util::RemoveWhitespaceContext(&coordinates_string);
- tensorflow::str_util::RemoveWhitespaceContext(&value_string);
- if (!tensorflow::str_util::ConsumePrefix(&coordinates_string, "(")) {
+ pieces = absl::StrSplit(line, ':');
+ absl::string_view coordinates_string =
+ absl::StripAsciiWhitespace(pieces[0]);
+ absl::string_view value_string = absl::StripAsciiWhitespace(pieces[1]);
+ if (!absl::ConsumePrefix(&coordinates_string, "(")) {
return InvalidArgument(
- "expected '(' at the beginning of coordinates: \"%s\"", line.c_str());
+ "expected '(' at the beginning of coordinates: \"%s\"", line);
}
- if (!tensorflow::str_util::ConsumeSuffix(&coordinates_string, ")")) {
+ if (!absl::ConsumeSuffix(&coordinates_string, ")")) {
return InvalidArgument("expected ')' at the end of coordinates: \"%s\"",
- line.c_str());
+ line);
}
float value;
- if (!tensorflow::strings::safe_strtof(std::string(value_string).c_str(),
- &value)) {
+ if (!absl::SimpleAtof(value_string, &value)) {
return InvalidArgument("could not parse value as float: \"%s\"",
- std::string(value_string).c_str());
+ value_string);
}
- SplitByDelimToStringPieces(coordinates_string, ',', &coordinates);
+ coordinates = absl::StrSplit(coordinates_string, ',');
coordinate_values.clear();
- for (tensorflow::StringPiece piece : coordinates) {
+ for (absl::string_view piece : coordinates) {
int64 coordinate_value;
- if (!tensorflow::strings::safe_strto64(piece, &coordinate_value)) {
+ if (!absl::SimpleAtoi(piece, &coordinate_value)) {
return InvalidArgument(
"could not parse coordinate member as int64: \"%s\"",
- std::string(piece).c_str());
+ std::string(piece));
}
coordinate_values.push_back(coordinate_value);
}
if (coordinate_values.size() != shape.dimensions_size()) {
return InvalidArgument(
- "line did not have expected number of coordinates; want %d got %zu: "
+ "line did not have expected number of coordinates; want %d got %u: "
"\"%s\"",
- shape.dimensions_size(), coordinate_values.size(), line.c_str());
+ shape.dimensions_size(), coordinate_values.size(), line);
}
- result->Set<float>(coordinate_values, value);
+ result.Set<float>(coordinate_values, value);
}
return std::move(result);
}
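text_literal_reader.cc now returns StatusOr<Literal> rather than
StatusOr<std::unique_ptr<Literal>>; since Literal is move-only, the local is
explicitly moved into the StatusOr ("return std::move(result);" above). A
sketch of the post-change calling convention, matching the test update later
in this diff:

    StatusOr<Literal> reader_result = TextLiteralReader::ReadPath(fname);
    if (!reader_result.ok()) {
      // handle reader_result.status()
    }
    Literal literal = reader_result.ConsumeValueOrDie();  // moves the value out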
diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h
index 708e8c80d8..c40b43279f 100644
--- a/tensorflow/compiler/xla/text_literal_reader.h
+++ b/tensorflow/compiler/xla/text_literal_reader.h
@@ -18,11 +18,11 @@ limitations under the License.
#include <memory>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
@@ -41,8 +41,7 @@ class TextLiteralReader {
public:
// See class comment -- reads a file in its entirety (there must be only one
// literal in the text file path provided).
- static StatusOr<std::unique_ptr<Literal>> ReadPath(
- tensorflow::StringPiece path);
+ static StatusOr<Literal> ReadPath(absl::string_view path);
private:
// Ownership of file is transferred.
@@ -50,7 +49,7 @@ class TextLiteralReader {
// Parses a shape string on the first line, followed by lines of values to the
// end of the file.
- StatusOr<std::unique_ptr<Literal>> ReadAllLines();
+ StatusOr<Literal> ReadAllLines();
  // Owns the file being read.
std::unique_ptr<tensorflow::RandomAccessFile> file_;
diff --git a/tensorflow/compiler/xla/text_literal_reader_test.cc b/tensorflow/compiler/xla/text_literal_reader_test.cc
index 92f9b4f9f0..1fab4e3a08 100644
--- a/tensorflow/compiler/xla/text_literal_reader_test.cc
+++ b/tensorflow/compiler/xla/text_literal_reader_test.cc
@@ -42,16 +42,15 @@ TEST(TextLiteralReaderTest, ReadsR3File) {
tensorflow::WriteStringToFile(tensorflow::Env::Default(), fname, contents)
.ok());
- std::unique_ptr<Literal> literal =
- TextLiteralReader::ReadPath(fname).ConsumeValueOrDie();
+ Literal literal = TextLiteralReader::ReadPath(fname).ConsumeValueOrDie();
EXPECT_TRUE(
- ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal->shape()));
- EXPECT_EQ(42.5, literal->Get<float>({0, 0, 0}));
- EXPECT_EQ(43.5, literal->Get<float>({0, 0, 1}));
- EXPECT_EQ(44.5, literal->Get<float>({0, 0, 2}));
- EXPECT_EQ(45.5, literal->Get<float>({0, 1, 0}));
- EXPECT_EQ(46.5, literal->Get<float>({0, 1, 1}));
- EXPECT_EQ(47.5, literal->Get<float>({0, 1, 2}));
+ ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal.shape()));
+ EXPECT_EQ(42.5, literal.Get<float>({0, 0, 0}));
+ EXPECT_EQ(43.5, literal.Get<float>({0, 0, 1}));
+ EXPECT_EQ(44.5, literal.Get<float>({0, 0, 2}));
+ EXPECT_EQ(45.5, literal.Get<float>({0, 1, 0}));
+ EXPECT_EQ(46.5, literal.Get<float>({0, 1, 1}));
+ EXPECT_EQ(47.5, literal.Get<float>({0, 1, 2}));
}
} // namespace
diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc
index 24e0784741..7289ae7df6 100644
--- a/tensorflow/compiler/xla/text_literal_writer.cc
+++ b/tensorflow/compiler/xla/text_literal_writer.cc
@@ -17,23 +17,23 @@ limitations under the License.
#include <string>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
-/* static */ Status TextLiteralWriter::WriteToPath(
- const Literal& literal, tensorflow::StringPiece path) {
+/* static */ Status TextLiteralWriter::WriteToPath(const Literal& literal,
+ absl::string_view path) {
std::unique_ptr<tensorflow::WritableFile> f;
- auto s = tensorflow::Env::Default()->NewWritableFile(std::string(path), &f);
+ auto s = tensorflow::Env::Default()->NewWritableFile(string(path), &f);
if (!s.ok()) {
return s;
}
@@ -46,16 +46,14 @@ namespace xla {
Status status;
tensorflow::WritableFile* f_ptr = f.get();
literal.EachCellAsString(
- [f_ptr, &status](tensorflow::gtl::ArraySlice<int64> indices,
- const string& value) {
+ [f_ptr, &status](absl::Span<const int64> indices, const string& value) {
if (!status.ok()) {
return;
}
- string coordinates = tensorflow::strings::StrCat(
- "(", tensorflow::str_util::Join(indices, ", "), ")");
+ string coordinates =
+ absl::StrCat("(", absl::StrJoin(indices, ", "), ")");
- status = f_ptr->Append(
- tensorflow::strings::StrCat(coordinates, ": ", value, "\n"));
+ status = f_ptr->Append(absl::StrCat(coordinates, ": ", value, "\n"));
});
auto ignored = f->Close();
return status;
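In text_literal_writer.cc the per-cell callback's index parameter changes from
tensorflow::gtl::ArraySlice<int64> to absl::Span<const int64>, the drop-in
replacement used throughout this commit. A sketch of the callback shape, using
the surrounding code's int64/string typedefs and assuming the post-change
EachCellAsString signature:

    literal.EachCellAsString(
        [](absl::Span<const int64> indices, const string& value) {
          // indices holds the cell coordinates; value is its textual form.
          LOG(INFO) << "(" << absl::StrJoin(indices, ", ") << "): " << value;
        });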
diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h
index 159ac1b7e1..34de8572d6 100644
--- a/tensorflow/compiler/xla/text_literal_writer.h
+++ b/tensorflow/compiler/xla/text_literal_writer.h
@@ -16,11 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
#define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
@@ -37,8 +37,7 @@ namespace xla {
// This should be readable by xla::TextLiteralReader.
class TextLiteralWriter {
public:
- static Status WriteToPath(const Literal& literal,
- tensorflow::StringPiece path);
+ static Status WriteToPath(const Literal& literal, absl::string_view path);
private:
TF_DISALLOW_COPY_AND_ASSIGN(TextLiteralWriter);
diff --git a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc
index 4ea02faffc..5cbaf2fcc1 100644
--- a/tensorflow/compiler/xla/text_literal_writer_test.cc
+++ b/tensorflow/compiler/xla/text_literal_writer_test.cc
@@ -37,7 +37,7 @@ TEST(TextLiteralWriterTest, WritesFloatLiteral) {
});
string path =
tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/whatever");
- ASSERT_IS_OK(TextLiteralWriter::WriteToPath(*literal, path));
+ ASSERT_IS_OK(TextLiteralWriter::WriteToPath(literal, path));
string contents;
TF_CHECK_OK(tensorflow::ReadFileToString(tensorflow::Env::Default(), path,
&contents));
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index 40d28a57bf..3a086c66bb 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -24,6 +24,8 @@ tf_cc_binary(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/base",
+ "@com_google_absl//absl/strings",
],
)
@@ -42,6 +44,7 @@ cc_library(
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -67,6 +70,7 @@ tf_cc_binary(
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:interpreter_plugin",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -94,6 +98,7 @@ cc_library(
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
alwayslink = True,
)
@@ -172,6 +177,7 @@ tf_cc_binary(
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:interpreter_plugin",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
@@ -191,6 +197,9 @@ tf_cc_binary(
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:interpreter_plugin",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:span",
],
)
@@ -210,6 +219,7 @@ tf_cc_binary(
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:interpreter_plugin",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
],
)
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc
index f20dcef382..c866a13de7 100644
--- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -38,7 +39,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -46,7 +46,7 @@ limitations under the License.
namespace xla {
namespace tools {
-void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
+void RealMain(absl::Span<char* const> args) {
Client* client = ClientLibrary::LocalClientOrDie();
for (char* arg : args) {
HloSnapshot module;
@@ -77,8 +77,8 @@ int main(int argc, char** argv) {
}
tensorflow::port::InitMain(argv[0], &argc, &argv);
- tensorflow::gtl::ArraySlice<char*> args(argv, argc);
- args.pop_front(); // Pop off the binary name, argv[0]
+ absl::Span<char* const> args(argv, argc);
+ args.remove_prefix(1); // Pop off the binary name, argv[0]
xla::tools::RealMain(args);
return 0;
}
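The tools' entry points all switch argv handling from
tensorflow::gtl::ArraySlice<char*> plus pop_front() to absl::Span<char* const>
plus remove_prefix(); the same three-line pattern recurs in the files below. A
minimal self-contained sketch:

    #include "absl/types/span.h"

    int main(int argc, char** argv) {
      absl::Span<char* const> args(argv, argc);
      args.remove_prefix(1);  // drop the binary name, argv[0]
      for (char* arg : args) {
        // process each remaining argument
      }
      return 0;
    }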
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
index f0af0580c1..4375e7c138 100644
--- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
@@ -19,6 +19,9 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -29,9 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -44,16 +44,14 @@ class OperationDumper : public DfsHloVisitorWithDefault {
explicit OperationDumper(const string& path) : path_(path) {}
Status DefaultAction(HloInstruction* hlo) override {
- string params = tensorflow::str_util::Join(
+ string params = absl::StrJoin(
hlo->operands(), ", ", [](string* out, const HloInstruction* operand) {
- tensorflow::strings::StrAppend(
- out, ShapeUtil::HumanString(operand->shape()));
+ absl::StrAppend(out, ShapeUtil::HumanString(operand->shape()));
});
// Spit `op_name(params...) -> result_type :: path` to stdout.
- std::cout << tensorflow::strings::Printf(
- "%s :: (%s) -> %s :: %s\n", HloOpcodeString(hlo->opcode()).c_str(),
- params.c_str(), ShapeUtil::HumanString(hlo->shape()).c_str(),
- path_.c_str());
+ std::cout << absl::StrFormat("%s :: (%s) -> %s :: %s\n",
+ HloOpcodeString(hlo->opcode()), params,
+ ShapeUtil::HumanString(hlo->shape()), path_);
return Status::OK();
}
@@ -61,7 +59,7 @@ class OperationDumper : public DfsHloVisitorWithDefault {
string path_;
};
-void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
+void RealMain(absl::Span<char* const> args) {
LocalClient* client = ClientLibrary::LocalClientOrDie();
LocalService* local_service =
ClientLibrary::GetXlaService(client->platform());
@@ -106,8 +104,8 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
int main(int argc, char** argv) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
- tensorflow::gtl::ArraySlice<char*> args(argv, argc);
- args.pop_front(); // Pop off the binary name, argv[0]
+ absl::Span<char* const> args(argv, argc);
+ args.remove_prefix(1); // Pop off the binary name, argv[0]
xla::tools::RealMain(args);
return 0;
}
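dumped_computation_to_operation_list.cc also shows the string-formatting side
of the migration: absl::StrJoin with a lambda formatter replaces
tensorflow::str_util::Join, and absl::StrFormat replaces strings::Printf,
taking std::string arguments directly with no .c_str(). A standalone sketch
with hypothetical data:

    #include <iostream>
    #include <string>
    #include <vector>
    #include "absl/strings/str_cat.h"
    #include "absl/strings/str_format.h"
    #include "absl/strings/str_join.h"

    std::vector<int> operands = {1, 2, 3};
    std::string params =
        absl::StrJoin(operands, ", ", [](std::string* out, int v) {
          absl::StrAppend(out, "op", v);  // per-element formatting
        });
    std::cout << absl::StrFormat("f :: (%s) -> %s\n", params, "result");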
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
index f03e1b1f96..723569862c 100644
--- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -26,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -34,7 +34,7 @@ limitations under the License.
namespace xla {
namespace tools {
-void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool compile) {
+void RealMain(absl::Span<char* const> args, bool compile) {
LocalClient* client = ClientLibrary::LocalClientOrDie();
LocalService* local_service =
ClientLibrary::GetXlaService(client->platform());
@@ -102,8 +102,8 @@ int main(int argc, char** argv) {
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
QCHECK(argc > 1) << "\nERROR: must specify at least one module\n" << usage;
- tensorflow::gtl::ArraySlice<char*> args(argv, argc);
- args.pop_front(); // Pop off the binary name, argv[0]
+ absl::Span<char* const> args(argv, argc);
+ args.remove_prefix(1); // Pop off the binary name, argv[0]
xla::tools::RealMain(args, compile);
return 0;
}
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc
index dc5c106d02..07ef5ff656 100644
--- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -35,7 +36,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/service.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -45,7 +45,7 @@ using tensorflow::Env;
namespace xla {
namespace tools {
-void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
+void RealMain(absl::Span<char* const> args) {
Client* client = ClientLibrary::LocalClientOrDie();
for (char* arg : args) {
HloSnapshot module;
@@ -78,8 +78,8 @@ int main(int argc, char** argv) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
- tensorflow::gtl::ArraySlice<char*> args(argv, argc);
- args.pop_front(); // Pop off the binary name, argv[0]
+ absl::Span<char* const> args(argv, argc);
+ args.remove_prefix(1); // Pop off the binary name, argv[0]
xla::tools::RealMain(args);
return 0;
}
diff --git a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc
index eb7bff053b..0c3ec5934e 100644
--- a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc
+++ b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc
@@ -17,10 +17,10 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/base/casts.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow/core/platform/env.h"
@@ -67,9 +67,8 @@ int main(int argc, char** argv) {
floats.push_back(value);
}
- tensorflow::StringPiece content(
- tensorflow::bit_cast<const char*>(floats.data()),
- floats.size() * sizeof(float));
+ tensorflow::StringPiece content(absl::bit_cast<const char*>(floats.data()),
+ floats.size() * sizeof(float));
TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
output_file, content));
return 0;
diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index b4774233e5..0c41f227b3 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -40,6 +40,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/global_data.h"
@@ -59,7 +60,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -121,11 +121,10 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
}
} else { // use recorded data if available
for (const auto& proto : module.arguments()) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
- Literal::CreateFromProto(proto));
+ TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(proto));
TF_ASSIGN_OR_RETURN(
ScopedShapedBuffer data,
- client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0));
+ client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0));
scoped_shaped_buffer_arguments.push_back(std::move(data));
}
for (const auto& argument : scoped_shaped_buffer_arguments) {
@@ -160,13 +159,13 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
// concurrent infeed occur via the fake_infeed_shape, or when
// --generate_fake_infeed is passed and there exists an infeed operation in
// the HloSnapshot.
- tensorflow::gtl::optional<tensorflow::thread::ThreadPool> pool;
- std::unique_ptr<Literal> data;
+ absl::optional<tensorflow::thread::ThreadPool> pool;
+ Literal data;
if (provide_infeed) {
data = std::move(MakeFakeLiteral(infeed_shape)).ValueOrDie();
}
auto transfer_infeed = [&data, client]() {
- TF_CHECK_OK(client->TransferToInfeed(*data));
+ TF_CHECK_OK(client->TransferToInfeed(data));
};
if (provide_infeed) {
pool.emplace(tensorflow::Env::Default(), "infeed",
@@ -196,7 +195,7 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
StreamExecutorMemoryAllocator allocator(
client->platform(),
{client->platform()->ExecutorForDevice(0).ValueOrDie()});
- tensorflow::gtl::optional<ScopedShapedBuffer> result;
+ absl::optional<ScopedShapedBuffer> result;
for (int i = 0; i < opts.num_runs; ++i) {
// If xla_hlo_profile is enabled, print a noisy message before the last run,
// making it easier to separate this profile from the others in the logspam.
@@ -214,9 +213,9 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
<< "s: " << module.hlo().hlo_module().name();
}
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result_literal,
+ TF_ASSIGN_OR_RETURN(Literal result_literal,
client->ShapedBufferToLiteral(*result));
- return std::move(*result_literal);
+ return result_literal;
}
StatusOr<HloSnapshot> ParseInputFile(const string& filename,
@@ -250,10 +249,10 @@ StatusOr<HloSnapshot> ParseInputFile(const string& filename,
}
fprintf(stderr, "%s: is not HLO text. Nothing left to try.\n",
filename.c_str());
- return InvalidArgument("Could not parse %s.", filename.c_str());
+ return InvalidArgument("Could not parse %s.", filename);
}
-int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
+int RealMain(absl::Span<char* const> args, const Options& opts) {
LocalClient* client = ClientLibrary::LocalClientOrDie();
int exit_status = EXIT_SUCCESS;
@@ -305,11 +304,11 @@ int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
result.ToString().c_str());
auto& snapshot = snapshots[i];
if (snapshot.has_result()) {
- std::unique_ptr<Literal> literal =
+ Literal literal =
Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie();
fprintf(stdout, "was %s:%s\n",
ShapeUtil::HumanString(snapshot.result().shape()).c_str(),
- literal->ToString().c_str());
+ literal.ToString().c_str());
}
}
}
@@ -344,7 +343,7 @@ int main(int argc, char** argv) {
LOG(QFATAL) << usage;
}
- tensorflow::gtl::ArraySlice<char*> args(argv, argc);
- args.pop_front(); // Pop off the binary name, argv[0]
+ absl::Span<char* const> args(argv, argc);
+ args.remove_prefix(1); // Pop off the binary name, argv[0]
return xla::tools::RealMain(args, opts);
}
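
absl::optional is the drop-in replacement for gtl::optional used above: construction of the wrapped object is deferred until emplace(), which is the point of wrapping the ThreadPool. A minimal sketch of the pattern, with a plain struct standing in for tensorflow::thread::ThreadPool:

    #include <cstdio>
    #include <string>
    #include <utility>

    #include "absl/types/optional.h"

    // Stand-in for a resource that is expensive to construct.
    struct Worker {
      explicit Worker(std::string n) : name(std::move(n)) {}
      std::string name;
    };

    int main() {
      const bool needed = true;
      absl::optional<Worker> worker;  // Empty; no Worker constructed yet.
      if (needed) {
        worker.emplace("infeed");  // Construct in place only when required.
      }
      if (worker.has_value()) {
        std::printf("constructed: %s\n", worker->name.c_str());
      }
      return 0;
    }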
diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc
index 51909190a3..4f8852f8c1 100644
--- a/tensorflow/compiler/xla/tools/show_literal.cc
+++ b/tensorflow/compiler/xla/tools/show_literal.cc
@@ -40,8 +40,8 @@ int main(int argc, char **argv) {
xla::LiteralProto literal_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1],
&literal_proto));
- std::unique_ptr<xla::Literal> literal =
+ xla::Literal literal =
xla::Literal::CreateFromProto(literal_proto).ConsumeValueOrDie();
LOG(INFO) << "literal: " << literal_proto.ShortDebugString();
- fprintf(stderr, "%s\n", literal->ToString().c_str());
+ fprintf(stderr, "%s\n", literal.ToString().c_str());
}
diff --git a/tensorflow/compiler/xla/tools/show_signature.cc b/tensorflow/compiler/xla/tools/show_signature.cc
index 4e53fafcc9..cdf306dfd1 100644
--- a/tensorflow/compiler/xla/tools/show_signature.cc
+++ b/tensorflow/compiler/xla/tools/show_signature.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -37,7 +38,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -45,7 +45,7 @@ limitations under the License.
namespace xla {
namespace tools {
-void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
+void RealMain(absl::Span<char* const> args) {
Client* client = ClientLibrary::LocalClientOrDie();
for (char* arg : args) {
HloSnapshot module;
@@ -66,8 +66,8 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
int main(int argc, char** argv) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
- tensorflow::gtl::ArraySlice<char*> args(argv, argc);
- args.pop_front(); // Pop off the binary name, argv[0]
+ absl::Span<char* const> args(argv, argc);
+ args.remove_prefix(1); // Pop off the binary name, argv[0]
xla::tools::RealMain(args);
return 0;
}
diff --git a/tensorflow/compiler/xla/tools/show_text_literal.cc b/tensorflow/compiler/xla/tools/show_text_literal.cc
index 48c8374811..4b5c276bdf 100644
--- a/tensorflow/compiler/xla/tools/show_text_literal.cc
+++ b/tensorflow/compiler/xla/tools/show_text_literal.cc
@@ -36,16 +36,16 @@ int main(int argc, char **argv) {
LOG(QFATAL) << "Usage: " << argv[0] << " <path-to-serialized-literal-text>";
}
- std::unique_ptr<xla::Literal> literal =
+ xla::Literal literal =
xla::TextLiteralReader::ReadPath(argv[1]).ConsumeValueOrDie();
- LOG(INFO) << "literal: " << *literal;
- fprintf(stderr, "%s\n", literal->ToString().c_str());
- if (literal->shape().element_type() == xla::F32) {
- float min = *std::min_element(literal->data<float>().begin(),
- literal->data<float>().end());
- float max = *std::max_element(literal->data<float>().begin(),
- literal->data<float>().end());
+ LOG(INFO) << "literal: " << literal;
+ fprintf(stderr, "%s\n", literal.ToString().c_str());
+ if (literal.shape().element_type() == xla::F32) {
+ float min = *std::min_element(literal.data<float>().begin(),
+ literal.data<float>().end());
+ float max = *std::max_element(literal.data<float>().begin(),
+ literal.data<float>().end());
fprintf(stderr, "min: %a=%f\n", min, min);
fprintf(stderr, "max: %a=%f\n", max, max);
}
diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc
index e43498e381..68cab7387c 100644
--- a/tensorflow/compiler/xla/util.cc
+++ b/tensorflow/compiler/xla/util.cc
@@ -18,12 +18,13 @@ limitations under the License.
#include <stdarg.h>
#include <numeric>
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stacktrace.h"
@@ -54,111 +55,28 @@ ScopedLoggingTimer::~ScopedLoggingTimer() {
}
}
-Status AddStatus(Status prior, tensorflow::StringPiece context) {
+Status AddStatus(Status prior, absl::string_view context) {
CHECK(!prior.ok());
- return Status{prior.code(), tensorflow::strings::StrCat(
- context, ": ", prior.error_message())};
+ return Status{prior.code(),
+ absl::StrCat(context, ": ", prior.error_message())};
}
-Status AppendStatus(Status prior, tensorflow::StringPiece context) {
+Status AppendStatus(Status prior, absl::string_view context) {
CHECK(!prior.ok());
- return Status{prior.code(), tensorflow::strings::StrCat(prior.error_message(),
- ": ", context)};
+ return Status{prior.code(),
+ absl::StrCat(prior.error_message(), ": ", context)};
}
-// Implementation note: we can't common these out (without using macros) because
-// they all need to va_start/va_end their varargs in their frame.
-
-Status InvalidArgumentV(const char* format, va_list args) {
- string message;
- tensorflow::strings::Appendv(&message, format, args);
- return WithLogBacktrace(tensorflow::errors::InvalidArgument(message));
-}
-
-Status InvalidArgument(const char* format, ...) {
- va_list args;
- va_start(args, format);
- Status result = InvalidArgumentV(format, args);
- va_end(args);
- return result;
-}
-
-Status Unimplemented(const char* format, ...) {
- string message;
- va_list args;
- va_start(args, format);
- tensorflow::strings::Appendv(&message, format, args);
- va_end(args);
- return WithLogBacktrace(tensorflow::errors::Unimplemented(message));
-}
-
-Status InternalError(const char* format, ...) {
- string message;
- va_list args;
- va_start(args, format);
- tensorflow::strings::Appendv(&message, format, args);
- va_end(args);
- return WithLogBacktrace(tensorflow::errors::Internal(message));
-}
-
-Status FailedPrecondition(const char* format, ...) {
- string message;
- va_list args;
- va_start(args, format);
- tensorflow::strings::Appendv(&message, format, args);
- va_end(args);
- return WithLogBacktrace(tensorflow::errors::FailedPrecondition(message));
-}
-
-Status Cancelled(const char* format, ...) {
- string message;
- va_list args;
- va_start(args, format);
- tensorflow::strings::Appendv(&message, format, args);
- va_end(args);
- return WithLogBacktrace(tensorflow::errors::Cancelled(message));
-}
-
-Status ResourceExhausted(const char* format, ...) {
- string message;
- va_list args;
- va_start(args, format);
- tensorflow::strings::Appendv(&message, format, args);
- va_end(args);
- return WithLogBacktrace(tensorflow::errors::ResourceExhausted(message));
-}
-
-Status NotFound(const char* format, ...) {
- string message;
- va_list args;
- va_start(args, format);
- tensorflow::strings::Appendv(&message, format, args);
- va_end(args);
- return WithLogBacktrace(tensorflow::errors::NotFound(message));
-}
-
-Status Unavailable(const char* format, ...) {
- string message;
- va_list args;
- va_start(args, format);
- tensorflow::strings::Appendv(&message, format, args);
- va_end(args);
- return WithLogBacktrace(tensorflow::errors::Unavailable(message));
-}
-
-string Reindent(tensorflow::StringPiece original,
- const tensorflow::StringPiece indentation) {
- std::vector<string> pieces = tensorflow::str_util::Split(
- tensorflow::StringPiece(original.data(), original.size()), '\n');
- return tensorflow::str_util::Join(
- pieces, "\n", [indentation](string* out, string s) {
- tensorflow::StringPiece piece(s);
- tensorflow::str_util::RemoveWhitespaceContext(&piece);
- tensorflow::strings::StrAppend(out, indentation, piece);
- });
+string Reindent(absl::string_view original,
+ const absl::string_view indentation) {
+ std::vector<string> pieces =
+ absl::StrSplit(absl::string_view(original.data(), original.size()), '\n');
+ return absl::StrJoin(pieces, "\n", [indentation](string* out, string s) {
+ absl::StrAppend(out, indentation, absl::StripAsciiWhitespace(s));
+ });
}
-bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank) {
+bool IsPermutation(absl::Span<const int64> permutation, int64 rank) {
if (rank != permutation.size()) {
return false;
}
@@ -172,7 +90,7 @@ bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank) {
}
std::vector<int64> InversePermutation(
- tensorflow::gtl::ArraySlice<int64> input_permutation) {
+ absl::Span<const int64> input_permutation) {
DCHECK(IsPermutation(input_permutation, input_permutation.size()));
std::vector<int64> output_permutation(input_permutation.size(), -1);
for (size_t i = 0; i < input_permutation.size(); ++i) {
@@ -181,8 +99,8 @@ std::vector<int64> InversePermutation(
return output_permutation;
}
-std::vector<int64> ComposePermutations(tensorflow::gtl::ArraySlice<int64> p1,
- tensorflow::gtl::ArraySlice<int64> p2) {
+std::vector<int64> ComposePermutations(absl::Span<const int64> p1,
+ absl::Span<const int64> p2) {
CHECK_EQ(p1.size(), p2.size());
std::vector<int64> output;
for (size_t i = 0; i < p1.size(); ++i) {
@@ -191,7 +109,7 @@ std::vector<int64> ComposePermutations(tensorflow::gtl::ArraySlice<int64> p1,
return output;
}
-bool IsIdentityPermutation(tensorflow::gtl::ArraySlice<int64> permutation) {
+bool IsIdentityPermutation(absl::Span<const int64> permutation) {
for (int64 i = 0; i < permutation.size(); ++i) {
if (permutation[i] != i) {
return false;
@@ -212,7 +130,7 @@ PaddingConfig MakeNoPaddingConfig(int64 rank) {
}
PaddingConfig MakeEdgePaddingConfig(
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
+ absl::Span<const std::pair<int64, int64>> padding) {
PaddingConfig padding_config;
for (const std::pair<int64, int64>& dim : padding) {
auto dimension = padding_config.add_dimensions();
@@ -234,20 +152,20 @@ bool HasInteriorPadding(const PaddingConfig& config) {
namespace {
string HumanReadableNumOps(double flops, double nanoseconds,
- tensorflow::StringPiece op_prefix) {
+ absl::string_view op_prefix) {
if (nanoseconds == 0) {
- return tensorflow::strings::StrCat("NaN ", op_prefix, "OP/s");
+ return absl::StrCat("NaN ", op_prefix, "OP/s");
}
double nano_flops = flops / nanoseconds;
string throughput = tensorflow::strings::HumanReadableNum(
static_cast<int64>(nano_flops * 1e9));
- tensorflow::StringPiece sp(throughput);
+ absl::string_view sp(throughput);
// Use the more common "G(FLOPS)", rather than "B(FLOPS)"
- if (tensorflow::str_util::EndsWith(sp, "B") || // Ends in 'B', ignoring case
- tensorflow::str_util::EndsWith(sp, "b")) {
+ if (absl::EndsWith(sp, "B") || // Ends in 'B', ignoring case
+ absl::EndsWith(sp, "b")) {
*throughput.rbegin() = 'G';
}
- throughput += tensorflow::strings::StrCat(op_prefix, "OP/s");
+ throughput += absl::StrCat(op_prefix, "OP/s");
return throughput;
}
} // namespace
@@ -260,8 +178,7 @@ string HumanReadableNumTranscendentalOps(double trops, double nanoseconds) {
return HumanReadableNumOps(trops, nanoseconds, "TR");
}
-void LogLines(int sev, tensorflow::StringPiece text, const char* fname,
- int lineno) {
+void LogLines(int sev, absl::string_view text, const char* fname, int lineno) {
const int orig_sev = sev;
if (sev == tensorflow::FATAL) {
sev = tensorflow::ERROR;
@@ -275,7 +192,7 @@ void LogLines(int sev, tensorflow::StringPiece text, const char* fname,
size_t cur = 0;
while (cur < text.size()) {
size_t eol = text.find('\n', cur);
- if (eol == tensorflow::StringPiece::npos) {
+ if (eol == absl::string_view::npos) {
eol = text.size();
}
auto msg = text.substr(cur, eol - cur);
@@ -290,14 +207,13 @@ void LogLines(int sev, tensorflow::StringPiece text, const char* fname,
}
}
-int64 Product(tensorflow::gtl::ArraySlice<int64> xs) {
+int64 Product(absl::Span<const int64> xs) {
return std::accumulate(xs.begin(), xs.end(), static_cast<int64>(1),
std::multiplies<int64>());
}
-std::vector<std::pair<int64, int64>> CommonFactors(
- tensorflow::gtl::ArraySlice<int64> a,
- tensorflow::gtl::ArraySlice<int64> b) {
+std::vector<std::pair<int64, int64>> CommonFactors(absl::Span<const int64> a,
+ absl::Span<const int64> b) {
CHECK_EQ(Product(a), Product(b));
if (0 == Product(a)) {
return {std::make_pair(0, 0), std::make_pair(a.size(), b.size())};
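
The rewritten Reindent above is a compact tour of the str_util-to-Abseil mapping: absl::StrSplit materializes into a std::vector<string>, and absl::StrJoin takes the same formatter-lambda shape the old Join did. A distilled sketch:

    #include <cstdio>
    #include <string>
    #include <vector>

    #include "absl/strings/ascii.h"
    #include "absl/strings/str_cat.h"
    #include "absl/strings/str_join.h"
    #include "absl/strings/str_split.h"

    int main() {
      const std::string original = "  foo\n    bar";
      const std::string indentation = ">> ";
      // Split on newlines, then rejoin with uniform indentation, stripping
      // each line's original leading/trailing whitespace.
      std::vector<std::string> pieces = absl::StrSplit(original, '\n');
      std::string out = absl::StrJoin(
          pieces, "\n", [&](std::string* o, const std::string& s) {
            absl::StrAppend(o, indentation, absl::StripAsciiWhitespace(s));
          });
      std::printf("%s\n", out.c_str());  // ">> foo" then ">> bar"
      return 0;
    }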
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index 5ae099a462..8ce7416474 100644
--- a/tensorflow/compiler/xla/util.h
+++ b/tensorflow/compiler/xla/util.h
@@ -24,17 +24,20 @@ limitations under the License.
#include <type_traits>
#include <vector>
+#include "absl/algorithm/container.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -54,7 +57,7 @@ Status WithLogBacktrace(const Status& status);
// the InlinedVector will just behave like an std::vector<> and allocate the
// memory to store its values.
static constexpr int kInlineRank = 8;
-using DimensionVector = tensorflow::gtl::InlinedVector<int64, kInlineRank>;
+using DimensionVector = absl::InlinedVector<int64, kInlineRank>;
// RAII timer that logs with a given label the wall clock time duration in human
// readable form. This differs from base's ElapsedTimer primarily in that it
@@ -98,65 +101,63 @@ struct ScopedLoggingTimer {
uint64 start_micros;
};
-// Given a vector<T>, returns a MutableArraySlice<char> that points at its
+// Given a vector<T>, returns a Span<uint8> that points at its
// internals.
//
// Warning: if the vector is updated its storage pointer may change, so use this
// with caution (ideally in limited scopes with temporary lifetimes).
template <typename T>
-tensorflow::gtl::MutableArraySlice<uint8> MutableByteSlice(std::vector<T>* v) {
- return tensorflow::gtl::MutableArraySlice<uint8>(
- reinterpret_cast<uint8*>(v->data()), v->size() * sizeof(T));
+absl::Span<uint8> MutableByteSlice(std::vector<T>* v) {
+ return absl::Span<uint8>(reinterpret_cast<uint8*>(v->data()),
+ v->size() * sizeof(T));
}
// Turns an immutable slice of type T into an immutable slice of bytes with the
// same byte size.
template <typename T>
-tensorflow::gtl::ArraySlice<uint8> CastToByteSlice(
- tensorflow::gtl::ArraySlice<T> slice) {
- return tensorflow::gtl::ArraySlice<uint8>(
- reinterpret_cast<const uint8*>(slice.data()), slice.size() * sizeof(T));
+absl::Span<const uint8> CastToByteSlice(absl::Span<const T> slice) {
+ return absl::Span<const uint8>(reinterpret_cast<const uint8*>(slice.data()),
+ slice.size() * sizeof(T));
}
// Casts a byte slice to a non-byte type T, checking that the original slice
// length is a multiple of sizeof(T).
template <typename T>
-tensorflow::gtl::ArraySlice<T> CastByteSlice(
- tensorflow::gtl::ArraySlice<uint8> slice) {
+absl::Span<const T> CastByteSlice(absl::Span<const uint8> slice) {
CHECK_EQ(0, slice.size() % sizeof(T));
- return tensorflow::gtl::ArraySlice<T>(
- reinterpret_cast<const T*>(slice.data()), slice.size() / sizeof(T));
+ return absl::Span<const T>(reinterpret_cast<const T*>(slice.data()),
+ slice.size() / sizeof(T));
}
// Convenience function to force a vector to convert to an immutable slice.
template <typename T>
-tensorflow::gtl::ArraySlice<T> AsSlice(const std::vector<T>& v) {
- return tensorflow::gtl::ArraySlice<T>(v);
+absl::Span<const T> AsSlice(const std::vector<T>& v) {
+ return absl::Span<const T>(v);
}
-// Converts a mutable vector pointer into a MutableArraySlice of the same
+// Converts a mutable vector pointer into a Span of the same
// type.
template <typename T>
-tensorflow::gtl::MutableArraySlice<T> AsMutableSlice(std::vector<T>* v) {
- return tensorflow::gtl::MutableArraySlice<T>(v->data(), v->size());
+absl::Span<T> AsMutableSlice(std::vector<T>* v) {
+ return absl::Span<T>(v->data(), v->size());
}
// xla::int64 is not the same type as tensorflow::protobuf_int64 in open-source.
// Wrapper function that gives an int64 array slice view of a repeated int64
// protobuf field.
-static inline tensorflow::gtl::ArraySlice<int64> AsInt64Slice(
+static inline absl::Span<const int64> AsInt64Slice(
const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>& v) {
- tensorflow::gtl::ArraySlice<tensorflow::protobuf_int64> slice(v);
- return tensorflow::gtl::ArraySlice<int64>(
- reinterpret_cast<const int64*>(slice.data()), slice.size());
+ absl::Span<const tensorflow::protobuf_int64> slice(v);
+ return absl::Span<const int64>(reinterpret_cast<const int64*>(slice.data()),
+ slice.size());
}
// As above, but for uint64 types.
-static inline tensorflow::gtl::ArraySlice<uint64> AsUInt64Slice(
+static inline absl::Span<const uint64> AsUInt64Slice(
const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_uint64>& v) {
- tensorflow::gtl::ArraySlice<tensorflow::protobuf_uint64> slice(v);
- return tensorflow::gtl::ArraySlice<uint64>(
- reinterpret_cast<const uint64*>(slice.data()), slice.size());
+ absl::Span<const tensorflow::protobuf_uint64> slice(v);
+ return absl::Span<const uint64>(reinterpret_cast<const uint64*>(slice.data()),
+ slice.size());
}
// Compares two containers for equality. Returns true iff the two containers
@@ -172,7 +173,7 @@ template <typename Container1T,
typename ElementType = typename Container1T::value_type>
bool ContainersEqual(const Container1T& c1,
std::initializer_list<ElementType> il) {
- tensorflow::gtl::ArraySlice<ElementType> c2{il};
+ absl::Span<const ElementType> c2{il};
return ContainersEqual(c1, c2);
}
@@ -190,9 +191,9 @@ bool ContainersEqual(const Container1T& c1, const Container2T& c2,
// source and destination. The source starting index is src_base, while the
// destination one is dest_base.
template <typename D, typename S>
-void StridedCopy(tensorflow::gtl::MutableArraySlice<D> dest, int64 dest_base,
- int64 dest_stride, tensorflow::gtl::ArraySlice<S> src,
- int64 src_base, int64 src_stride, int64 count) {
+void StridedCopy(absl::Span<D> dest, int64 dest_base, int64 dest_stride,
+ absl::Span<const S> src, int64 src_base, int64 src_stride,
+ int64 count) {
for (; count > 0; --count, dest_base += dest_stride, src_base += src_stride) {
dest[dest_base] = static_cast<D>(src[src_base]);
}
@@ -201,46 +202,76 @@ void StridedCopy(tensorflow::gtl::MutableArraySlice<D> dest, int64 dest_base,
// Adds some context information to the error message in a
// Status. This is useful as Statuses are
// propagated upwards.
-Status AddStatus(Status prior, tensorflow::StringPiece context);
-Status AppendStatus(Status prior, tensorflow::StringPiece context);
-
-// Status error shorthands -- printfs the arguments to be
-// used as an error message and returns a status in the canonical
-// error space.
-Status InvalidArgument(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
-Status Unimplemented(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
-Status InternalError(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
-Status FailedPrecondition(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
-Status Cancelled(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
-Status ResourceExhausted(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
-Status NotFound(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
-Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
-
-// Passed-varargs variant of the InvalidArgument factory above.
-Status InvalidArgumentV(const char* format, va_list args);
+Status AddStatus(Status prior, absl::string_view context);
+Status AppendStatus(Status prior, absl::string_view context);
+
+// Status error shorthands -- StrFormat's the arguments to be used as an error
+// message and returns a status in the canonical error space.
+template <typename... Args>
+Status InvalidArgument(const absl::FormatSpec<Args...>& format,
+ const Args&... args) {
+ return WithLogBacktrace(
+ tensorflow::errors::InvalidArgument(absl::StrFormat(format, args...)));
+}
+template <typename... Args>
+Status Unimplemented(const absl::FormatSpec<Args...>& format,
+ const Args&... args) {
+ return WithLogBacktrace(
+ tensorflow::errors::Unimplemented(absl::StrFormat(format, args...)));
+}
+template <typename... Args>
+Status InternalError(const absl::FormatSpec<Args...>& format,
+ const Args&... args) {
+ return WithLogBacktrace(
+ tensorflow::errors::Internal(absl::StrFormat(format, args...)));
+}
+template <typename... Args>
+Status FailedPrecondition(const absl::FormatSpec<Args...>& format,
+ const Args&... args) {
+ return WithLogBacktrace(
+ tensorflow::errors::FailedPrecondition(absl::StrFormat(format, args...)));
+}
+template <typename... Args>
+Status Cancelled(const absl::FormatSpec<Args...>& format, const Args&... args) {
+ return WithLogBacktrace(
+ tensorflow::errors::Cancelled(absl::StrFormat(format, args...)));
+}
+template <typename... Args>
+Status ResourceExhausted(const absl::FormatSpec<Args...>& format,
+ const Args&... args) {
+ return WithLogBacktrace(
+ tensorflow::errors::ResourceExhausted(absl::StrFormat(format, args...)));
+}
+template <typename... Args>
+Status NotFound(const absl::FormatSpec<Args...>& format, const Args&... args) {
+ return WithLogBacktrace(
+ tensorflow::errors::NotFound(absl::StrFormat(format, args...)));
+}
+template <typename... Args>
+Status Unavailable(const absl::FormatSpec<Args...>& format,
+ const Args&... args) {
+ return WithLogBacktrace(
+ tensorflow::errors::Unavailable(absl::StrFormat(format, args...)));
+}
template <typename... Args>
Status InvalidArgumentStrCat(Args&&... concat) {
- return InvalidArgument(
- "%s", tensorflow::strings::StrCat(std::forward<Args>(concat)...).c_str());
+ return InvalidArgument("%s", absl::StrCat(std::forward<Args>(concat)...));
}
template <typename... Args>
Status UnimplementedStrCat(Args&&... concat) {
- return Unimplemented(
- "%s", tensorflow::strings::StrCat(std::forward<Args>(concat)...).c_str());
+ return Unimplemented("%s", absl::StrCat(std::forward<Args>(concat)...));
}
template <typename... Args>
Status InternalErrorStrCat(Args&&... concat) {
- return InternalError(
- "%s", tensorflow::strings::StrCat(std::forward<Args>(concat)...).c_str());
+ return InternalError("%s", absl::StrCat(std::forward<Args>(concat)...));
}
template <typename... Args>
Status ResourceExhaustedStrCat(Args&&... concat) {
- return ResourceExhausted(
- "%s", tensorflow::strings::StrCat(std::forward<Args>(concat)...).c_str());
+ return ResourceExhausted("%s", absl::StrCat(std::forward<Args>(concat)...));
}
// Splits the lines of the original, replaces leading whitespace with the prefix
@@ -249,11 +280,10 @@ Status ResourceExhaustedStrCat(Args&&... concat) {
//
// Note: even different amounts of leading whitespace on different lines will be
// uniformly replaced with "indentation".
-string Reindent(tensorflow::StringPiece original,
- tensorflow::StringPiece indentation);
+string Reindent(absl::string_view original, absl::string_view indentation);
// Checks whether permutation is a permutation of the [0, rank) integer range.
-bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank);
+bool IsPermutation(absl::Span<const int64> permutation, int64 rank);
// Applies `permutation` on `input` and returns the permuted array.
// For each i, output[permutation[i]] = input[i].
@@ -261,10 +291,11 @@ bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank);
// Precondition:
// 1. `permutation` is a permutation of 0..permutation.size()-1.
// 2. permutation.size() == input.size().
-template <template <typename...> class C, typename T>
-std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
- C<T> input) {
- tensorflow::gtl::ArraySlice<T> data(input);
+template <typename Container>
+std::vector<typename Container::value_type> Permute(
+ absl::Span<const int64> permutation, const Container& input) {
+ using T = typename Container::value_type;
+ absl::Span<const T> data(input);
CHECK(IsPermutation(permutation, data.size()));
std::vector<T> output(data.size());
for (size_t i = 0; i < permutation.size(); ++i) {
@@ -273,27 +304,16 @@ std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
return output;
}
-// Override of the above that works around compile failures with gcc 7.1.1.
-// For details see https://github.com/tensorflow/tensorflow/issues/10843
-// Hide this workaround from MSVC as it causes ambiguous error.
-#ifndef _MSC_VER
-template <typename T>
-std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
- const std::vector<T>& input) {
- return Permute<std::vector, T>(permutation, input);
-}
-#endif
-
// Inverts a permutation, i.e., output_permutation[input_permutation[i]] = i.
std::vector<int64> InversePermutation(
- tensorflow::gtl::ArraySlice<int64> input_permutation);
+ absl::Span<const int64> input_permutation);
// Composes two permutations: output[i] = p1[p2[i]].
-std::vector<int64> ComposePermutations(tensorflow::gtl::ArraySlice<int64> p1,
- tensorflow::gtl::ArraySlice<int64> p2);
+std::vector<int64> ComposePermutations(absl::Span<const int64> p1,
+ absl::Span<const int64> p2);
// Returns true iff permutation == {0, 1, 2, ...}.
-bool IsIdentityPermutation(tensorflow::gtl::ArraySlice<int64> permutation);
+bool IsIdentityPermutation(absl::Span<const int64> permutation);
template <typename Container>
int64 PositionInContainer(const Container& container, int64 value) {
@@ -312,7 +332,7 @@ string CommaSeparatedString(const Container& c, const char* prefix = "",
string comma_separated = prefix;
const char* separator = "";
for (const auto& entry : c) {
- tensorflow::strings::StrAppend(&comma_separated, separator, entry);
+ absl::StrAppend(&comma_separated, separator, entry);
separator = ", ";
}
comma_separated += suffix;
@@ -347,7 +367,7 @@ PaddingConfig MakeNoPaddingConfig(int64 rank);
// Returns a PaddingConfig object where 'padding' contains
// (low edge padding, high edge padding) pairs for each dimension.
PaddingConfig MakeEdgePaddingConfig(
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+ absl::Span<const std::pair<int64, int64>> padding);
// Returns true if the padding configuration has at least one dimension with
// non-zero interior padding.
@@ -394,8 +414,7 @@ string HumanReadableNumTranscendentalOps(double trops, double nanoseconds);
// Split the text into multiple lines and log each line with the given
// severity, filename, and line number.
-void LogLines(int sev, tensorflow::StringPiece text, const char* fname,
- int lineno);
+void LogLines(int sev, absl::string_view text, const char* fname, int lineno);
template <typename T>
inline bool IsPowerOfTwo(T x) {
@@ -415,7 +434,7 @@ std::unique_ptr<Derived> unique_ptr_static_cast(std::unique_ptr<Base> ptr) {
return std::unique_ptr<Derived>(static_cast<Derived*>(ptr.release()));
}
-int64 Product(tensorflow::gtl::ArraySlice<int64> xs);
+int64 Product(absl::Span<const int64> xs);
// Returns the start indices of consecutive non-overlapping subsequences of `a`
// and `b` with the same product, i.e. `(i, j)` so
@@ -428,130 +447,18 @@ int64 Product(tensorflow::gtl::ArraySlice<int64> xs);
//
// If the given shapes have non-zero size, returns the bounds of the shortest
// possible such subsequences; else, returns `{(0, 0), (a.size, b.size)}`.
-std::vector<std::pair<int64, int64>> CommonFactors(
- tensorflow::gtl::ArraySlice<int64> a, tensorflow::gtl::ArraySlice<int64> b);
+std::vector<std::pair<int64, int64>> CommonFactors(absl::Span<const int64> a,
+ absl::Span<const int64> b);
// Removes illegal characters from filenames.
string SanitizeFileName(string file_name);
-template <typename Container, typename Predicate>
-bool c_all_of(const Container& container, Predicate&& predicate) {
- return std::all_of(std::begin(container), std::end(container),
- std::forward<Predicate>(predicate));
-}
-
-template <typename Container, typename Predicate>
-bool c_any_of(const Container& container, Predicate&& predicate) {
- return std::any_of(std::begin(container), std::end(container),
- std::forward<Predicate>(predicate));
-}
-
-template <typename InputContainer, typename OutputIterator,
- typename UnaryOperation>
-OutputIterator c_transform(const InputContainer& input_container,
- OutputIterator output_iterator,
- UnaryOperation&& unary_op) {
- return std::transform(std::begin(input_container), std::end(input_container),
- output_iterator,
- std::forward<UnaryOperation>(unary_op));
-}
-
-template <class InputContainer, class OutputIterator, class UnaryPredicate>
-OutputIterator c_copy_if(const InputContainer& input_container,
- OutputIterator output_iterator,
- UnaryPredicate&& predicate) {
- return std::copy_if(std::begin(input_container), std::end(input_container),
- output_iterator, std::forward<UnaryPredicate>(predicate));
-}
-
-template <class InputContainer, class OutputIterator>
-OutputIterator c_copy(const InputContainer& input_container,
- OutputIterator output_iterator) {
- return std::copy(std::begin(input_container), std::end(input_container),
- output_iterator);
-}
-
-template <class InputContainer>
-void c_sort(InputContainer& input_container) {
- std::sort(std::begin(input_container), std::end(input_container));
-}
-
-template <class InputContainer, class Comparator>
-void c_sort(InputContainer& input_container, Comparator&& comparator) {
- std::sort(std::begin(input_container), std::end(input_container),
- std::forward<Comparator>(comparator));
-}
-
-template <typename Sequence, typename T>
-bool c_binary_search(const Sequence& sequence, T&& value) {
- return std::binary_search(std::begin(sequence), std::end(sequence),
- std::forward<T>(value));
-}
-
-template <typename C>
-bool c_is_sorted(const C& c) {
- return std::is_sorted(std::begin(c), std::end(c));
-}
-
-template <typename C, typename Compare>
-bool c_is_sorted(const C& c, Compare&& comp) {
- return std::is_sorted(std::begin(c), std::end(c),
- std::forward<Compare>(comp));
-}
-
-template <typename C>
-auto c_adjacent_find(C& c) -> decltype(std::begin(c)) {
- return std::adjacent_find(std::begin(c), std::end(c));
-}
-
-template <typename C, typename Pred>
-auto c_find_if(C& c, Pred&& pred) -> decltype(std::begin(c)) {
- return std::find_if(std::begin(c), std::end(c), std::forward<Pred>(pred));
-}
-
-template <typename C, typename Value>
-auto c_find(C& c, Value&& value) -> decltype(std::begin(c)) {
- return std::find(std::begin(c), std::end(c), std::forward<Value>(value));
-}
-
-template <typename Sequence>
-void c_reverse(Sequence& sequence) {
- std::reverse(std::begin(sequence), std::end(sequence));
-}
-
-template <typename Sequence, typename T, typename BinaryOp>
-typename std::decay<T>::type c_accumulate(const Sequence& sequence, T&& init,
- BinaryOp&& binary_op) {
- return std::accumulate(std::begin(sequence), std::end(sequence),
- std::forward<T>(init),
- std::forward<BinaryOp>(binary_op));
-}
-
-template <typename C, typename Pred>
-typename std::iterator_traits<
- decltype(std::begin(std::declval<C>()))>::difference_type
-c_count_if(const C& c, Pred&& pred) {
- return std::count_if(std::begin(c), std::end(c), std::forward<Pred>(pred));
-}
-
-// Determines whether `value` is present in `c`.
-template <typename C, typename T>
-bool c_linear_search(const C& c, T&& value) {
- auto last = std::end(c);
- return std::find(std::begin(c), last, std::forward<T>(value)) != last;
-}
-
template <typename C, typename Value>
int64 FindIndex(const C& c, Value&& value) {
- auto it = c_find(c, std::forward<Value>(value));
+ auto it = absl::c_find(c, std::forward<Value>(value));
return std::distance(c.begin(), it);
}
-template <typename T>
-bool ArrayContains(tensorflow::gtl::ArraySlice<T> c, const T& value) {
- return c_find(c, value) != c.end();
-}
-
template <typename C, typename Value>
void InsertAt(C* c, int64 index, Value&& value) {
c->insert(c->begin() + index, std::forward<Value>(value));
@@ -563,13 +470,13 @@ void EraseAt(C* c, int64 index) {
}
template <typename T>
-std::vector<T> ArraySliceToVector(tensorflow::gtl::ArraySlice<T> slice) {
+std::vector<T> ArraySliceToVector(absl::Span<const T> slice) {
return std::vector<T>(slice.begin(), slice.end());
}
-template <typename T, int N>
+template <typename T, size_t N>
std::vector<T> InlinedVectorToVector(
- const tensorflow::gtl::InlinedVector<T, N>& inlined_vector) {
+ const absl::InlinedVector<T, N>& inlined_vector) {
return std::vector<T>(inlined_vector.begin(), inlined_vector.end());
}
@@ -584,8 +491,8 @@ bool IsInt32(T x) {
template <typename T>
Status EraseElementFromVector(std::vector<T>* container, const T& value) {
- // c_find returns a const_iterator which does not seem to work on gcc 4.8.4,
- // and this breaks the ubuntu/xla_gpu build bot.
+ // absl::c_find returns a const_iterator which does not seem to work on
+ // gcc 4.8.4, and this breaks the ubuntu/xla_gpu build bot.
auto it = std::find(container->begin(), container->end(), value);
TF_RET_CHECK(it != container->end());
container->erase(it);
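
The header change above swaps the printf-style varargs factories for templates over absl::FormatSpec, which type-checks the format string against its arguments at compile time. A hedged standalone sketch of the same shape; MakeError is a hypothetical stand-in for WithLogBacktrace(tensorflow::errors::...):

    #include <cstdio>
    #include <string>

    #include "absl/strings/str_format.h"

    // Hypothetical stand-in for the Status-producing wrappers.
    std::string MakeError(std::string message) {
      return "INVALID_ARGUMENT: " + message;
    }

    // The FormatSpec parameter forces a compile-time check that, e.g., %d is
    // matched by an integral argument.
    template <typename... Args>
    std::string InvalidArgument(const absl::FormatSpec<Args...>& format,
                                const Args&... args) {
      return MakeError(absl::StrFormat(format, args...));
    }

    int main() {
      const std::string filename = "module.pb";
      std::printf("%s\n",
                  InvalidArgument("Could not parse %s at offset %d", filename, 42)
                      .c_str());
      return 0;
    }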
diff --git a/tensorflow/compiler/xla/util_test.cc b/tensorflow/compiler/xla/util_test.cc
index 288479c893..50a3c545fb 100644
--- a/tensorflow/compiler/xla/util_test.cc
+++ b/tensorflow/compiler/xla/util_test.cc
@@ -37,45 +37,6 @@ TEST(UtilTest, ReindentsDifferentNumberOfLeadingSpacesUniformly) {
EXPECT_EQ(want, got);
}
-// Some smoke tests for ContainersEqual. Keeping it simple since these are just
-// basic wrappers around std::equal.
-TEST(UtilTest, ContainersEqualDefault) {
- std::vector<int> c1 = {1, 2, 3, 4};
- std::vector<int> c2 = {1, 2, 3};
- std::vector<int> c3 = {};
- std::vector<int> c4 = {1, 2, 3, 4};
- std::vector<int> c5 = {1, 2, 3, 4, 5};
- std::vector<int> c6 = {1, 3, 4, 5};
-
- EXPECT_TRUE(ContainersEqual(c1, c4));
- EXPECT_TRUE(ContainersEqual(c4, c1));
- EXPECT_FALSE(ContainersEqual(c1, c2));
- EXPECT_FALSE(ContainersEqual(c2, c1));
- EXPECT_FALSE(ContainersEqual(c1, c3));
- EXPECT_FALSE(ContainersEqual(c3, c1));
- EXPECT_FALSE(ContainersEqual(c1, c5));
- EXPECT_FALSE(ContainersEqual(c5, c1));
- EXPECT_FALSE(ContainersEqual(c1, c6));
- EXPECT_FALSE(ContainersEqual(c6, c1));
-}
-
-TEST(UtilTest, ContainersEqualPredicate) {
- std::vector<int> c1 = {1, 2, 3, 4};
- std::vector<int> c2 = {10, 20, 30, 40};
-
- EXPECT_TRUE(ContainersEqual(
- c1, c2, [](const int& i1, const int& i2) { return i1 < i2; }));
- EXPECT_FALSE(ContainersEqual(
- c1, c2, [](const int& i1, const int& i2) { return i1 > i2; }));
-}
-
-TEST(UtilTest, ContainersEqualDifferentContainerTypes) {
- std::vector<int> c1 = {1, 2, 3, 4};
- std::list<int> c2 = {1, 2, 3, 4};
-
- EXPECT_TRUE(ContainersEqual(c1, c2));
-}
-
TEST(UtilTest, HumanReadableNumFlopsExample) {
ASSERT_EQ("1.00GFLOP/s", HumanReadableNumFlops(1e9, 1e9));
}
@@ -117,8 +78,8 @@ TEST(UtilTest, CommonFactors) {
/*.expected =*/{{0, 0}, {0, 1}, {2, 2}, {3, 2}, {4, 3}, {4, 4}}},
};
for (const auto& test_case : test_cases) {
- EXPECT_TRUE(ContainersEqual(test_case.expected,
- CommonFactors(test_case.a, test_case.b)));
+ EXPECT_TRUE(absl::c_equal(test_case.expected,
+ CommonFactors(test_case.a, test_case.b)));
}
}
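
The test change is representative of the deleted c_* wrappers: each maps one-to-one onto absl/algorithm/container.h (c_find becomes absl::c_find, ContainersEqual becomes absl::c_equal, and so on). A small sketch:

    #include <cassert>
    #include <list>
    #include <vector>

    #include "absl/algorithm/container.h"

    int main() {
      const std::vector<int> v = {1, 2, 3, 4};
      const std::list<int> l = {1, 2, 3, 4};

      // Whole-container equality, even across container types.
      assert(absl::c_equal(v, l));

      // Whole-container find; returns an iterator, like std::find.
      assert(absl::c_find(v, 3) != v.end());
      return 0;
    }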
diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc
index f11123ca24..8ea8dbab25 100644
--- a/tensorflow/compiler/xla/window_util.cc
+++ b/tensorflow/compiler/xla/window_util.cc
@@ -17,16 +17,15 @@ limitations under the License.
#include <vector>
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
namespace xla {
namespace window_util {
-Window MakeWindow(tensorflow::gtl::ArraySlice<int64> sizes) {
+Window MakeWindow(absl::Span<const int64> sizes) {
Window window;
for (int64 size : sizes) {
auto* dimension = window.add_dimensions();
@@ -38,7 +37,7 @@ Window MakeWindow(tensorflow::gtl::ArraySlice<int64> sizes) {
return window;
}
-PaddingConfig MakeSymmetricPadding(tensorflow::gtl::ArraySlice<int64> sizes) {
+PaddingConfig MakeSymmetricPadding(absl::Span<const int64> sizes) {
PaddingConfig config;
for (int64 size : sizes) {
auto* dimension = config.add_dimensions();
@@ -49,8 +48,8 @@ PaddingConfig MakeSymmetricPadding(tensorflow::gtl::ArraySlice<int64> sizes) {
}
/* static */ string ToString(const WindowDimension& dim) {
- using tensorflow::strings::StrAppend;
- using tensorflow::strings::StrCat;
+ using absl::StrAppend;
+ using absl::StrCat;
string str = StrCat("(size=", dim.size());
if (dim.stride() != 1) {
StrAppend(&str, ",stride=", dim.stride());
@@ -75,8 +74,8 @@ PaddingConfig MakeSymmetricPadding(tensorflow::gtl::ArraySlice<int64> sizes) {
}
string ToString(const Window& window) {
- using tensorflow::strings::StrAppend;
- using tensorflow::strings::StrCat;
+ using absl::StrAppend;
+ using absl::StrCat;
string str;
const auto add_field =
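
absl::StrCat and absl::StrAppend are drop-in replacements here, so ToString changes only in its using-declarations. A condensed sketch of the same build-up pattern, with a free function standing in for WindowDimension::ToString:

    #include <cstdio>
    #include <string>

    #include "absl/strings/str_cat.h"

    // Start with StrCat, then append optional fields only when they differ
    // from their defaults, mirroring the ToString above.
    std::string DimToString(int size, int stride) {
      std::string str = absl::StrCat("(size=", size);
      if (stride != 1) {
        absl::StrAppend(&str, ",stride=", stride);
      }
      absl::StrAppend(&str, ")");
      return str;
    }

    int main() {
      std::printf("%s\n", DimToString(3, 2).c_str());  // (size=3,stride=2)
      return 0;
    }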
diff --git a/tensorflow/compiler/xla/window_util.h b/tensorflow/compiler/xla/window_util.h
index ba473e2c8c..1fb9e855fc 100644
--- a/tensorflow/compiler/xla/window_util.h
+++ b/tensorflow/compiler/xla/window_util.h
@@ -16,22 +16,22 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_WINDOW_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_WINDOW_UTIL_H_
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace window_util {
// Creates a window with the given sizes in the dimensions and all strides set
// to 1.
-Window MakeWindow(tensorflow::gtl::ArraySlice<int64> sizes);
+Window MakeWindow(absl::Span<const int64> sizes);
// Creates a padding config with symmetrical padding in each dimension, of value
// given by sizes; e.g. {0, 1, 2} would create a R3 padding config that had zero
// pixels of padding in dimension 0, one pixel of padding symmetrically, on each
// side of dimension 1, and two pixels of padding symmetrically on dimension 2.
-PaddingConfig MakeSymmetricPadding(tensorflow::gtl::ArraySlice<int64> sizes);
+PaddingConfig MakeSymmetricPadding(absl::Span<const int64> sizes);
string ToString(const WindowDimension& dim);
string ToString(const Window& window);
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index 3b72eb17c6..b53f89d63b 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -195,8 +195,13 @@ message DebugOptions {
bool xla_cpu_enable_fast_math = 99;
bool xla_gpu_enable_fast_math = 100;
- // Extra options to pass to the compilation backend; specific interpretation
- // of these values is left to the backend.
+ // Crashes the program when any kind of verification fails, instead of just
+ // logging the failures. One example is cross checking of convolution results
+ // among different algorithms.
+ bool xla_gpu_crash_on_verification_failures = 101;
+
+ // Extra options to pass to the compilation backend (e.g. LLVM); specific
+ // interpretation of these values is left to the backend.
map<string, string> xla_backend_extra_options = 500;
}
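
DebugOptions is a plain proto message, so the new field is driven through the generated accessors. A hedged sketch from C++ (the generated header path and the "llvm_opt_level" key are illustrative assumptions; the setter names follow the standard protobuf convention):

    #include "tensorflow/compiler/xla/xla.pb.h"  // Assumed generated header.

    xla::DebugOptions MakeDebugOptions() {
      xla::DebugOptions opts;
      // Crash instead of logging when verification (e.g. convolution result
      // cross-checking) fails.
      opts.set_xla_gpu_crash_on_verification_failures(true);
      // Extra options are an opaque string map interpreted by the backend;
      // the key below is purely illustrative.
      (*opts.mutable_xla_backend_extra_options())["llvm_opt_level"] = "2";
      return opts;
    }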
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 4c35e93d38..dd329f1181 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -105,13 +105,14 @@ enum PaddingValue {
message PaddingConfig {
// Describes the padding configuration for a dimension.
message PaddingConfigDimension {
- // Padding amount on the low-end (next to the index 0).
+ // Padding amount on the low-end (next to the index 0). May be negative.
int64 edge_padding_low = 1;
- // Padding amount on the high-end (next to the highest index).
+ // Padding amount on the high-end (next to the highest index). May be
+ // negative.
int64 edge_padding_high = 2;
- // Padding amount between the elements.
+ // Padding amount between the elements. May not be negative.
int64 interior_padding = 3;
}
@@ -393,13 +394,14 @@ message WindowDimension {
// Dilation factor of the sliding window in this dimension. A dilation factor
// of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are
- // implicitly placed between each kernel element. See documentation for
- // convolution.
+ // implicitly placed between each kernel element. This value may not be less
+ // than 1. See documentation for convolution.
int64 window_dilation = 5;
// Dilation factor of the base area in this dimension. A dilation factor of 1
// means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly
- // placed between each base area element. See documentation for convolution.
+ // placed between each base area element. This value may not be less than 1.
+ // See documentation for convolution.
int64 base_dilation = 6;
// Window reversal means that this dimension was logically reversed before the
@@ -424,25 +426,25 @@ message GatherDimensionNumbers {
// "Window indices" is a term for a set of indices that index into the
// interior of a dynamic-slice from the input tensor, the starting indices for
// which were computed from output_gather_dims (see the operation semantic for
- // how this is defined) and the gather_indices tensor.
+ // how this is defined) and the start_indices tensor.
//
// The window indices for a specific output index Out is computed as:
//
// i = 0
// for (k : [0, input_tensor_shape.rank))
// window_indices[k] =
- // if k in elided_window_dims
+ // if k in collapsed_slice_dims
// then 0
- // else Out[output_window_dims[i++]]
- repeated int64 output_window_dims = 1;
- repeated int64 elided_window_dims = 2;
+ // else Out[offset_dims[i++]]
+ repeated int64 offset_dims = 1;
+ repeated int64 collapsed_slice_dims = 2;
- // This is interpreted as a map from i to gather_dims_to_operand_dims[i]. It
- // transforms the gather index looked up from the gather_indices tensor into
+ // This is interpreted as a map from i to start_index_map[i]. It
+ // transforms the gather index looked up from the start_indices tensor into
// the starting index in the input space.
- repeated int64 gather_dims_to_operand_dims = 3;
+ repeated int64 start_index_map = 3;
- // The dimension in the gather_indices input that contains the starting
+ // The dimension in the start_indices input that contains the starting
// indices.
int64 index_vector_dim = 4;
}
@@ -569,3 +571,24 @@ message ReplicaGroup {
// ids matters in some op (e.g., all-to-all).
repeated int64 replica_ids = 1;
}
+
+// Describes the source-target pair in the collective permute op.
+message SourceTarget {
+ int64 source = 1;
+ int64 target = 2;
+}
+
+// Used to indicate the precision configuration. It has backend-specific
+// meaning.
+message PrecisionConfig {
+ enum Precision {
+ DEFAULT = 0;
+ HIGH = 1;
+ HIGHEST = 2;
+
+ // Next: 3
+ }
+ repeated Precision operand_precision = 1;
+
+ // Next: 2
+}
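
With edge padding now allowed to be negative (trimming elements) and interior padding required to be non-negative, the padded extent of a dimension follows directly from the three fields. A worked sketch of the arithmetic, following the documented padding semantics:

    #include <cstdint>
    #include <cstdio>

    // Padded size of one dimension under PaddingConfigDimension:
    //   edge_padding_low + size + max(size - 1, 0) * interior_padding
    //     + edge_padding_high
    // (interior padding only lands between existing elements, hence size - 1).
    int64_t PaddedSize(int64_t size, int64_t low, int64_t high,
                       int64_t interior) {
      const int64_t between = size > 0 ? (size - 1) * interior : 0;
      return low + size + between + high;
    }

    int main() {
      // size=4, low=1, high=-1, interior=2  ->  1 + 4 + 6 - 1 = 10
      std::printf("%lld\n", static_cast<long long>(PaddedSize(4, 1, -1, 2)));
      return 0;
    }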
diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD
new file mode 100644
index 0000000000..2ff97914f8
--- /dev/null
+++ b/tensorflow/compiler/xrt/BUILD
@@ -0,0 +1,84 @@
+# Description: Operations defined for XRT
+
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = [
+ "//learning/brain:__subpackages__",
+ "//tensorflow/compiler/xrt:__subpackages__",
+ ],
+)
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_gen_op_libs",
+)
+load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
+
+xla_proto_library(
+ name = "xrt_proto",
+ srcs = ["xrt.proto"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/compiler/tf2xla:host_compute_metadata_proto",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_proto",
+ ],
+)
+
+cc_library(
+ name = "xrt_utils",
+ srcs = [
+ "xrt_compilation_cache.cc",
+ "xrt_device.cc",
+ "xrt_state.cc",
+ ],
+ hdrs = [
+ "xrt_compilation_cache.h",
+ "xrt_device.h",
+ "xrt_state.h",
+ ],
+ deps = [
+ "//tensorflow/compiler/jit:xla_device",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service:backend",
+ "//tensorflow/compiler/xla/service:device_memory_allocator",
+ "//tensorflow/compiler/xla/service:shaped_buffer",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/stream_executor",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
+
+tf_gen_op_libs(
+ op_lib_names = [
+ "xrt_compile_ops",
+ "xrt_state_ops",
+ "xrt_execute_op",
+ ],
+ deps = [
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "xrt_server",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":xrt_compile_ops_op_lib",
+ ":xrt_execute_op_op_lib",
+ ":xrt_state_ops_op_lib",
+ "//tensorflow/compiler/xrt/kernels:xrt_ops",
+ ],
+)
diff --git a/tensorflow/compiler/xrt/cc/BUILD b/tensorflow/compiler/xrt/cc/BUILD
new file mode 100644
index 0000000000..5c1e86b76b
--- /dev/null
+++ b/tensorflow/compiler/xrt/cc/BUILD
@@ -0,0 +1,20 @@
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_gen_op_wrappers_cc",
+)
+
+tf_gen_op_wrappers_cc(
+ name = "xrt_ops",
+ op_lib_names = [
+ "xrt_compile_ops",
+ "xrt_state_ops",
+ "xrt_execute_op",
+ ],
+ pkg = "//tensorflow/compiler/xrt",
+)
diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD
new file mode 100644
index 0000000000..9e3d2454d1
--- /dev/null
+++ b/tensorflow/compiler/xrt/kernels/BUILD
@@ -0,0 +1,69 @@
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = [
+ "//learning/brain:__subpackages__",
+ "//tensorflow/compiler/xrt:__subpackages__",
+ ],
+)
+
+cc_library(
+ name = "xrt_state_ops",
+ hdrs = ["xrt_state_ops.h"],
+ deps = [
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:compile_only_client",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_computation",
+ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
+ "//tensorflow/compiler/xla/service:compiler",
+ "//tensorflow/compiler/xla/service:computation_placer",
+ "//tensorflow/compiler/xla/service:hlo_proto",
+ "//tensorflow/compiler/xrt:xrt_proto",
+ "//tensorflow/compiler/xrt:xrt_utils",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "xrt_ops",
+ srcs = [
+ "xrt_compile_ops.cc",
+ "xrt_execute_op.cc",
+ "xrt_state_ops.cc",
+ ],
+ deps = [
+ ":xrt_state_ops",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_computation",
+ "//tensorflow/compiler/xla/service:compiler",
+ "//tensorflow/compiler/xla/service:computation_placer",
+ "//tensorflow/compiler/xrt:xrt_proto",
+ "//tensorflow/compiler/xrt:xrt_utils",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/stream_executor:stream_executor_headers_lib",
+ "@com_google_absl//absl/strings",
+ ],
+ alwayslink = 1,
+)
diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
new file mode 100644
index 0000000000..1d4f8d97f2
--- /dev/null
+++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
@@ -0,0 +1,239 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Classes for compiling XLA computations and managing handles that refer to
+// them.
+
+#include <cstdlib>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/compiler/xrt/xrt.pb.h"
+#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
+#include "tensorflow/compiler/xrt/xrt_device.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/proto_serialization.h"
+#include "tensorflow/core/platform/fingerprint.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+namespace {
+
+const int kDefaultCacheSize = 100;
+
+class XRTCompileOp : public OpKernel {
+ public:
+ explicit XRTCompileOp(OpKernelConstruction* ctx);
+ ~XRTCompileOp() override;
+ XRTCompileOp(const XRTCompileOp&) = delete;
+ XRTCompileOp& operator=(const XRTCompileOp&) = delete;
+
+ void Compute(OpKernelContext* ctx) override;
+
+ private:
+ Status Compile(OpKernelContext* ctx,
+ const xrt::XLAComputation& computation_proto,
+ std::unique_ptr<xla::LocalExecutable>* program);
+};
+
+Status CompilationCacheKey(const xrt::XLAComputation& computation,
+ string* key) {
+ string serialized;
+ TF_RET_CHECK(SerializeToStringDeterministic(computation, &serialized));
+ uint64 fingerprint = Fingerprint64(serialized);
+ *key = absl::StrCat(fingerprint);
+ return Status::OK();
+}
+
+XRTCompileOp::XRTCompileOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+Status XRTCompileOp::Compile(OpKernelContext* ctx,
+ const xrt::XLAComputation& computation_proto,
+ std::unique_ptr<xla::LocalExecutable>* program) {
+ const xrt::XLAComputationConfig& config = computation_proto.config();
+
+ // The default config value is 0; treat it as 1 for convenience.
+ int num_replicas = config.num_replicas() ? config.num_replicas() : 1;
+ TF_RET_CHECK(num_replicas == 1);
+ int num_cores_per_replica =
+ config.num_cores_per_replica() ? config.num_cores_per_replica() : 1;
+ TF_RET_CHECK(num_cores_per_replica == 1);
+ TF_RET_CHECK(config.per_core_program_shape_size() == 0);
+
+ // We are guaranteed that the underlying device object won't be deleted out
+ // from under us while the ScopedRef is live.
+ class XRTGenericDeviceAccessor::ScopedRef device_ref;
+ TF_RETURN_IF_ERROR(
+ XRTGenericDeviceAccessor::InitScopedRef(ctx, 0, &device_ref));
+
+ xla::LocalClient* client = device_ref.client();
+
+  // There is officially no way to use XLA in a client/server architecture
+  // where client and server are built from different revisions, because the
+  // XLA team does not want to give any guarantees about the stability of the
+  // Hlo proto. For Cloud TPU this is fine because server and client versions
+  // can be assumed to be synced to the same version. For general use the
+  // mechanism here (using a snapshot from XlaComputation) works as well as
+  // the "official" XLA client/server design, which serializes the same proto
+  // between client and server, so in reality it is probably fine.
+ TF_ASSIGN_OR_RETURN(xla::XlaComputation computation,
+ client->LoadSnapshot(computation_proto.hlo_snapshot()));
+
+ std::vector<const xla::Shape*> argument_layouts(
+ config.program_shape().parameters_size());
+ for (int i = 0; i < config.program_shape().parameters_size(); ++i) {
+ argument_layouts[i] = &config.program_shape().parameters(i);
+ }
+ xla::ExecutableBuildOptions build_options;
+ build_options.set_device_ordinal(client->default_device_ordinal());
+ build_options.set_result_layout(config.program_shape().result());
+ build_options.set_device_allocator(device_ref.backend()->memory_allocator());
+
+ VLOG(1) << "Building executable";
+ auto compile_result =
+ client->Compile(computation, argument_layouts, build_options);
+ if (!compile_result.ok()) {
+ return compile_result.status();
+ }
+ *program = std::move(compile_result.ValueOrDie());
+ return Status::OK();
+}
+
+void XRTCompileOp::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "XRTCompileOp::Compute";
+
+ ResourceMgr* rm;
+ OP_REQUIRES_OK(ctx, XRTGenericDeviceAccessor::GetResourceManager(ctx, &rm));
+
+ const Tensor& computation_input = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(computation_input.shape()),
+ errors::Internal("computation input should be a string scalar"));
+
+ xrt::XLAComputation computation_proto;
+ OP_REQUIRES(
+ ctx,
+ computation_proto.ParseFromString(computation_input.scalar<string>()()),
+ errors::InvalidArgument(
+ "Unable to parse computation input to XLAComputation"));
+
+ string key;
+ OP_REQUIRES_OK(ctx, CompilationCacheKey(computation_proto, &key));
+
+ // Process-wide cache of XLA executables.
+ XRTCompilationCache* cache;
+ OP_REQUIRES_OK(ctx,
+ rm->LookupOrCreate<XRTCompilationCache>(
+ rm->default_container(), kXRTCompilationCacheResourceName,
+ &cache, [](XRTCompilationCache** new_cache) {
+ *new_cache = new XRTCompilationCache(kDefaultCacheSize);
+ return Status::OK();
+ }));
+ core::ScopedUnref cache_unref(cache);
+
+ int64 uid;
+ OP_REQUIRES_OK(
+ ctx, cache->CompileIfKeyAbsent(
+ key, &uid, [&](std::unique_ptr<xla::LocalExecutable>* program) {
+ VLOG(1) << "Compiling XLA executable";
+ return Compile(ctx, computation_proto, program);
+ }));
+
+ Tensor output(DT_INT64, TensorShape({}));
+ output.scalar<int64>()() = uid;
+ ctx->set_output(0, output);
+}
+
+XRTCompileOp::~XRTCompileOp() = default;
+
+class XRTReleaseCompilationRefOp : public OpKernel {
+ public:
+ explicit XRTReleaseCompilationRefOp(OpKernelConstruction* ctx);
+ ~XRTReleaseCompilationRefOp() override;
+ XRTReleaseCompilationRefOp(const XRTReleaseCompilationRefOp&) = delete;
+ XRTReleaseCompilationRefOp& operator=(const XRTReleaseCompilationRefOp&) =
+ delete;
+
+ void Compute(OpKernelContext* ctx) override;
+};
+
+XRTReleaseCompilationRefOp::XRTReleaseCompilationRefOp(
+ OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+
+XRTReleaseCompilationRefOp::~XRTReleaseCompilationRefOp() = default;
+
+void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "XRTReleaseCompilationRefOp::Compute";
+
+ const Tensor& key_tensor = ctx->input(0);
+  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(key_tensor.shape()),
+              errors::Internal("computation key should be an int64 scalar"));
+ int64 uid = key_tensor.scalar<int64>()();
+
+ ResourceMgr* rm;
+ OP_REQUIRES_OK(ctx, XRTGenericDeviceAccessor::GetResourceManager(ctx, &rm));
+
+ // Process-wide cache of XLA executables.
+ XRTCompilationCache* cache;
+ OP_REQUIRES_OK(ctx, rm->Lookup<XRTCompilationCache>(
+ rm->default_container(),
+ kXRTCompilationCacheResourceName, &cache));
+ core::ScopedUnref cache_unref(cache);
+
+ OP_REQUIRES_OK(ctx, cache->Release(uid));
+
+ VLOG(2) << "Released computation handle " << uid;
+}
+
+} // namespace
+
+REGISTER_KERNEL_BUILDER(Name("XRTCompile")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("computation")
+ .HostMemory("handle"),
+ XRTCompileOp);
+REGISTER_KERNEL_BUILDER(Name("XRTCompile")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("computation")
+ .HostMemory("handle"),
+ XRTCompileOp);
+
+REGISTER_KERNEL_BUILDER(Name("XRTReleaseCompilationHandle")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("handle"),
+ XRTReleaseCompilationRefOp);
+REGISTER_KERNEL_BUILDER(Name("XRTReleaseCompilationHandle")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("handle"),
+ XRTReleaseCompilationRefOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
new file mode 100644
index 0000000000..257b054f16
--- /dev/null
+++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
@@ -0,0 +1,254 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/compiler/xrt/xrt.pb.h"
+#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
+#include "tensorflow/compiler/xrt/xrt_device.h"
+#include "tensorflow/compiler/xrt/xrt_state.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+
+namespace tensorflow {
+
+namespace {
+
+uint32 InitialRandomSeed() {
+  // Support for plumbing the TF seed through to XLA is being worked on. If a
+  // user wants deterministic behavior, their best option is to start from a
+  // known checkpoint. This also handles the issue that multiple random calls
+  // can be invoked in any order by the TF executor. Another option is to use
+  // stateless random ops, which have much cleaner semantics. If a user really
+  // wants to set a deterministic seed for XLA-based devices, this is the
+  // place to do it.
+ std::random_device rd;
+ // Make the starting value odd.
+ return rd() | 1;
+}
+
+uint32 GetXLARandomSeed() {
+  // We initialize the counter with an odd number and increment it by two
+  // every time. This ensures that it will never be zero, even after an
+  // overflow. When seeded with zero, some XLA backends can return all zeros
+  // instead of random numbers.
+ static std::atomic<uint32> counter(InitialRandomSeed());
+ return counter.fetch_add(2);
+}
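+
+// A worked example of the increment logic above (illustrative values only):
+// if the counter holds 0xFFFFFFFF (odd), fetch_add(2) wraps it to 0x00000001
+// modulo 2^32, which is again odd and nonzero, so the returned seed can never
+// be zero.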
+
+// Looks up the input `key` in the compilation cache.
+Status GetComputationCacheEntry(
+ XRTCompilationCache* cache, int64 key,
+ std::unique_ptr<XRTCompilationCacheEntryRef>* entry) {
+ TF_RETURN_IF_ERROR(cache->Lookup(key, entry));
+ return Status::OK();
+}
+
+// Looks up the computation's input handles and populates `input_tuples`,
+// `input_allocations` and `input_pointers` with the corresponding on-device
+// allocations. If `release_inputs` is true, the input handles are also
+// deleted from the resource manager.
+Status GetComputationInputs(OpKernelContext* context, ResourceMgr* rm,
+ bool release_inputs,
+ std::vector<XRTTupleAllocation*>* input_tuples,
+ std::vector<xla::ShapedBuffer>* input_allocations,
+ std::vector<xla::ShapedBuffer*>* input_pointers) {
+ OpInputList arg_list;
+ TF_RETURN_IF_ERROR(context->input_list("input_handles", &arg_list));
+
+ input_tuples->resize(arg_list.size());
+ input_pointers->resize(arg_list.size());
+ for (int i = 0; i < arg_list.size(); ++i) {
+ TF_RET_CHECK(TensorShapeUtils::IsScalar(arg_list[i].shape()));
+ int64 input_uid = arg_list[i].scalar<int64>()();
+ TF_RETURN_IF_ERROR(
+ XRTTupleAllocation::Lookup(rm, input_uid, &(*input_tuples)[i]));
+ if (release_inputs) {
+ // We are holding a reference to the tuple, so we can safely delete it
+ // from the resource manager here.
+ TF_RETURN_IF_ERROR(
+ XRTTupleAllocation::DeleteFromResourceManager(rm, input_uid));
+ VLOG(2) << "Released allocation handle " << input_uid;
+ }
+ XRTTupleAllocation* tuple = (*input_tuples)[i];
+ input_allocations->emplace_back(tuple->ToShapedBuffer());
+ }
+ for (int i = 0; i < arg_list.size(); ++i) {
+ (*input_pointers)[i] = &(*input_allocations)[i];
+ }
+ return Status::OK();
+}
+
+// XRTExecuteOp
+
+class XRTExecuteOp : public AsyncOpKernel {
+ public:
+ explicit XRTExecuteOp(OpKernelConstruction* context);
+ ~XRTExecuteOp() override;
+
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) override;
+
+ private:
+ Status DoWork(OpKernelContext* context);
+};
+
+XRTExecuteOp::XRTExecuteOp(OpKernelConstruction* context)
+ : AsyncOpKernel(context) {}
+
+void XRTExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
+ // Schedule onto the default queue, for unbounded concurrency. See b/73520706
+ Env::Default()->SchedClosure([this, context, done]() {
+ OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
+ done();
+ });
+}
+
+Status XRTExecuteOp::DoWork(OpKernelContext* context) {
+  VLOG(1) << "XRTExecuteOp::DoWork";
+ ResourceMgr* rm;
+ TF_RETURN_IF_ERROR(
+ XRTGenericDeviceAccessor::GetResourceManager(context, &rm));
+
+ const Tensor& execution_input = context->input(0);
+ TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_input.shape()));
+ int64 compilation_handle = execution_input.scalar<int64>()();
+
+ const Tensor& execution_config = context->input(1);
+ TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape()));
+ xrt::XRTExecutionConfig config_proto;
+ TF_RET_CHECK(
+ config_proto.ParseFromString(execution_config.scalar<string>()()));
+
+ int core_index_in_replica = config_proto.core_index_in_replica();
+ TF_RET_CHECK(core_index_in_replica == 0);
+ bool release_inputs = config_proto.release_input_handles();
+ bool release_compilation = config_proto.release_compilation_handle();
+
+ XRTCompilationCache* cache;
+ TF_RETURN_IF_ERROR(rm->Lookup<XRTCompilationCache>(
+ rm->default_container(), kXRTCompilationCacheResourceName, &cache));
+ core::ScopedUnref cache_unref(cache);
+
+ std::unique_ptr<XRTCompilationCacheEntryRef> entry;
+  TF_RETURN_IF_ERROR(
+      GetComputationCacheEntry(cache, compilation_handle, &entry));
+
+ if (release_compilation) {
+    // Drop the cache's reference to the executable; the entry reference
+    // looked up above keeps it alive for the duration of this execution.
+ TF_RETURN_IF_ERROR(cache->Release(compilation_handle));
+ VLOG(2) << "Released compilation handle " << compilation_handle;
+ }
+
+ std::vector<XRTTupleAllocation*> input_tuples;
+  // Make a cleanup object so that we can safely return on error paths
+  // without leaking references to the input allocations.
+ auto buffer_releaser = gtl::MakeCleanup([&input_tuples]() {
+ for (auto tuple : input_tuples) {
+ if (tuple != nullptr) {
+ tuple->Unref();
+ }
+ }
+ });
+ std::vector<xla::ShapedBuffer> input_allocations;
+ std::vector<xla::ShapedBuffer*> input_pointers;
+ TF_RETURN_IF_ERROR(GetComputationInputs(context, rm, release_inputs,
+ &input_tuples, &input_allocations,
+ &input_pointers));
+
+  // We are guaranteed that the underlying device object won't be deleted out
+  // from under us while the ScopedRef is live.
+ class XRTGenericDeviceAccessor::ScopedRef device_ref;
+ TF_RETURN_IF_ERROR(
+ XRTGenericDeviceAccessor::InitScopedRef(context, 0, &device_ref));
+
+ int rng_seed = config_proto.rng_seed();
+ if (rng_seed == 0) {
+ rng_seed = GetXLARandomSeed();
+ }
+
+ se::Stream* stream = context->op_device_context()
+ ? context->op_device_context()->stream()
+ : nullptr;
+
+ // Execute the computation.
+ VLOG(2) << "Executing computation.";
+ xla::ExecutableRunOptions run_options;
+ run_options.set_stream(stream);
+ run_options.set_allocator(device_ref.backend()->memory_allocator());
+ run_options.set_intra_op_thread_pool(&context->eigen_cpu_device());
+ run_options.set_rng_seed(rng_seed);
+
+ Env* env = Env::Default();
+ auto start_time = env->NowMicros();
+
+ xla::LocalExecutable* executable = entry->get().get_executable();
+ auto run_result = executable->Run(input_pointers, run_options);
+ if (!run_result.ok()) {
+ return run_result.status();
+ }
+
+ auto elapsed = env->NowMicros() - start_time;
+ VLOG(2) << "Elapsed time: " << elapsed << "us";
+
+ auto scoped_buffer = run_result.ConsumeValueOrDie();
+ auto shaped_buffer = scoped_buffer.release();
+ XRTTupleAllocation* output_tuple;
+ TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
+ shaped_buffer, device_ref.backend(), device_ref.device_ordinal(),
+ &output_tuple));
+
+ Tensor* output_tensor;
+ TF_RETURN_IF_ERROR(
+ context->allocate_output(0, TensorShape({}), &output_tensor));
+ int64 key;
+ TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key));
+ output_tensor->scalar<int64>()() = key;
+
+ return Status::OK();
+}
+
+XRTExecuteOp::~XRTExecuteOp() = default;
+
+} // namespace
+
+REGISTER_KERNEL_BUILDER(Name("XRTExecute")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("computation_handle")
+ .HostMemory("execution_config")
+ .HostMemory("input_handles")
+ .HostMemory("output_handle"),
+ XRTExecuteOp);
+
+REGISTER_KERNEL_BUILDER(Name("XRTExecute")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("computation_handle")
+ .HostMemory("execution_config")
+ .HostMemory("input_handles")
+ .HostMemory("output_handle"),
+ XRTExecuteOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc
new file mode 100644
index 0000000000..ffea592491
--- /dev/null
+++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc
@@ -0,0 +1,110 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Kernel registrations for the XRT state ops, which allocate XLA literals in
+// device memory and manage handles that refer to them. The op implementations
+// live in xrt_state_ops.h.
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xrt/kernels/xrt_state_ops.h"
+
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+
+namespace tensorflow {
+
+REGISTER_KERNEL_BUILDER(Name("XRTAllocate")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("allocation")
+ .HostMemory("handle"),
+ XRTAllocateOp<XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTAllocate")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("allocation")
+ .HostMemory("handle"),
+ XRTAllocateOp<XRTGenericDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTSubTuple")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("base_handle")
+ .HostMemory("shape_index")
+ .HostMemory("output_handle"),
+ XRTSubTupleOp<false, XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTSubTuple")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("base_handle")
+ .HostMemory("shape_index")
+ .HostMemory("output_handle"),
+ XRTSubTupleOp<false, XRTGenericDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTSubTupleAndRelease")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("base_handle")
+ .HostMemory("shape_index")
+ .HostMemory("output_handle"),
+ XRTSubTupleOp<true, XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTSubTupleAndRelease")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("base_handle")
+ .HostMemory("shape_index")
+ .HostMemory("output_handle"),
+ XRTSubTupleOp<true, XRTGenericDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTMakeTuple")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("tuple_description")
+ .HostMemory("input_handles")
+ .HostMemory("output_handle"),
+ XRTMakeTupleOp<XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTMakeTuple")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("tuple_description")
+ .HostMemory("input_handles")
+ .HostMemory("output_handle"),
+ XRTMakeTupleOp<XRTGenericDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("handle")
+ .HostMemory("literal"),
+ XRTReadLiteralOp<false, XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("handle")
+ .HostMemory("literal"),
+ XRTReadLiteralOp<false, XRTGenericDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("handle")
+ .HostMemory("literal"),
+ XRTReadLiteralOp<true, XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("handle")
+ .HostMemory("literal"),
+ XRTReadLiteralOp<true, XRTGenericDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("handle"),
+ XRTReleaseAllocationOp<XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("handle"),
+ XRTReleaseAllocationOp<XRTGenericDeviceAccessor>);
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
new file mode 100644
index 0000000000..54b06558ad
--- /dev/null
+++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
@@ -0,0 +1,424 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Classes for allocating XLA literals in device memory and managing handles
+// that refer to them.
+
+#ifndef TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_
+#define TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/compiler/xrt/xrt.pb.h"
+#include "tensorflow/compiler/xrt/xrt_device.h"
+#include "tensorflow/compiler/xrt/xrt_state.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Helper functions for templated ops.
+class XRTStateHelpers {
+ public:
+ // The Status return value allows us to use the
+ // TF_ASSIGN_OR_RETURN macro, which doesn't work within the body of an
+ // OpKernel::Compute method.
+ static Status MakeLiteral(const xla::LiteralProto& proto,
+ xla::Literal* literal) {
+ TF_ASSIGN_OR_RETURN(*literal, xla::Literal::CreateFromProto(proto));
+ return Status::OK();
+ }
+
+  // ParseTupleNode is the recursive function used to parse a recursive
+  // xrt::XLATupleNode proto and generate the xla::Shape of the 'spine', i.e.
+  // the tuple shape where every leaf is an existing allocation. As a
+  // side-effect it fills in input_vector by looking up allocations from
+  // handles in the input_tensor_list as they are referenced by nodes in the
+  // proto.
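+  //
+  // For illustration, a hypothetical proto (not taken from any test data) of
+  // the form
+  //   tuples { input_index: 0 }
+  //   tuples { tuples { input_index: 1 } tuples { input_index: 2 } }
+  // applied to three f32[2] allocations yields the spine shape
+  //   (f32[2], (f32[2], f32[2]))
+  // and fills input_vector slots 0, 1 and 2 with the looked-up allocations.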
+ static Status ParseTupleNode(
+ const xrt::XLATupleNode& tuple_node, const OpInputList& input_tensor_list,
+ std::vector<XRTTupleAllocation::ExpandedTupleInput>* input_vector,
+ xla::Shape* shape, ResourceMgr* rm) {
+ if (tuple_node.tuples_size() > 0) {
+ // This is an internal node in the proto so descend recursively.
+ xla::Shape dummy = xla::ShapeUtil::MakeShapeWithType<float>({});
+ std::vector<xla::Shape> subshapes(tuple_node.tuples_size(), dummy);
+ *xla::ShapeUtil::GetMutableSubshape(shape, {}) =
+ xla::ShapeUtil::MakeTupleShape(subshapes);
+ for (int i = 0; i < tuple_node.tuples_size(); ++i) {
+ TF_RETURN_IF_ERROR(ParseTupleNode(
+ tuple_node.tuples(i), input_tensor_list, input_vector,
+ xla::ShapeUtil::GetMutableSubshape(shape, {i}), rm));
+ }
+ } else {
+ // This is a leaf node in the proto so look up the referenced input.
+ int input_index = tuple_node.input_index();
+ if (input_index < 0 || input_index >= input_vector->size()) {
+ return errors::InvalidArgument("Invalid tuple input index ",
+ input_index, ": MakeTuple has ",
+ input_vector->size(), " inputs.");
+ }
+ bool release_this_input = tuple_node.release_input_handle();
+ XRTTupleAllocation::ExpandedTupleInput& input =
+ input_vector->at(input_index);
+ if (input.allocation != nullptr &&
+ (input.release_allocation_after_use || release_this_input)) {
+ return errors::InvalidArgument(
+ "Invalid tuple tree: input index ", input_index,
+ " is repeated but release_input_handle is true.");
+ }
+ if (input.allocation == nullptr) {
+ // We haven't dereferenced this handle yet.
+ TF_RET_CHECK(
+ TensorShapeUtils::IsScalar(input_tensor_list[input_index].shape()));
+ int64 key = input_tensor_list[input_index].scalar<int64>()();
+ TF_RETURN_IF_ERROR(
+ XRTTupleAllocation::Lookup(rm, key, &input.allocation));
+ input.release_allocation_after_use = release_this_input;
+ }
+ }
+ return Status::OK();
+ }
+
+  // Parses an xrt::XLATupleNode proto recursively and returns the
+  // corresponding ShapeTree where each leaf is an allocation corresponding to
+  // a handle in input_tensor_list. The ordinal of one of the allocations is
+  // returned in device_ordinal. Since it's not possible to specify an
+  // xrt::XLATupleNode with no leaves, device_ordinal will always be filled in
+  // by a successful call to ParseTupleTree.
+ static Status ParseTupleTree(
+ const xrt::XLATupleNode& tuple_tree_root,
+ const OpInputList& input_tensor_list,
+ std::vector<XRTTupleAllocation::ExpandedTupleInput>* input_vector,
+ xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput>* tuple_shape_tree,
+ int* device_ordinal, ResourceMgr* rm) {
+ // First get the shape of the 'spine' of the new tuple, where every leaf is
+ // an existing allocation. As a side-effect dereference the input handles
+ // into allocations in input_vector.
+ xla::Shape tuple_tree_shape;
+ TF_RETURN_IF_ERROR(ParseTupleNode(tuple_tree_root, input_tensor_list,
+ input_vector, &tuple_tree_shape, rm));
+ // Make the shape tree of allocations where the shape is the spine and each
+ // leaf is one of the allocations looked up in input_vector. Internal nodes
+ // have nullptr allocations.
+ *tuple_shape_tree = xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput>(
+ tuple_tree_shape);
+ tuple_shape_tree->ForEachMutableElement(
+ [&](const xla::ShapeIndex& index,
+ XRTTupleAllocation::ExpandedTupleInput* element) {
+ if (tuple_shape_tree->IsLeaf(index)) {
+ // Find the matching leaf in the proto tree.
+ const xrt::XLATupleNode* tuple_node = &tuple_tree_root;
+ for (int i = 0; i < index.size(); ++i) {
+ tuple_node = &tuple_node->tuples(index[i]);
+ }
+ // Copy the appropriate input allocation to the leaf of the
+ // tuple_shape_tree.
+ int input_index = tuple_node->input_index();
+ *element = input_vector->at(input_index);
+ CHECK(element->release_allocation_after_use ==
+ tuple_node->release_input_handle());
+ // We just need to know the device_ordinal of one of the
+ // allocations. We will validate later that they are all the same.
+ *device_ordinal = (*element).allocation->device_ordinal();
+ }
+ });
+ return Status::OK();
+ }
+};
+
+// Op that allocates memory for a literal and transfers it to the device.
+template <class DeviceAccessor>
+class XRTAllocateOp : public OpKernel {
+ public:
+ explicit XRTAllocateOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ ~XRTAllocateOp() override = default;
+ XRTAllocateOp(const XRTAllocateOp&) = delete;
+ XRTAllocateOp& operator=(const XRTAllocateOp&) = delete;
+
+ void Compute(OpKernelContext* ctx) override {
+ VLOG(1) << "XRTAllocateOp::Compute";
+
+ const Tensor& allocation_info = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(allocation_info.shape()),
+ errors::Internal("allocation input should be a string scalar"));
+ xrt::XLAAllocation allocation_proto;
+ OP_REQUIRES(
+ ctx,
+ allocation_proto.ParseFromString(allocation_info.scalar<string>()()),
+ errors::InvalidArgument(
+ "Unable to parse allocation input to XLAAllocation"));
+
+ xla::Literal literal;
+ OP_REQUIRES_OK(
+ ctx, XRTStateHelpers::MakeLiteral(allocation_proto.value(), &literal));
+
+ ResourceMgr* rm;
+ OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
+
+    // We are guaranteed that the underlying device object won't be deleted out
+    // from under us while the ScopedRef is live.
+ class DeviceAccessor::ScopedRef device_ref;
+ OP_REQUIRES_OK(ctx,
+ DeviceAccessor::InitScopedRef(
+ ctx, allocation_proto.device_ordinal(), &device_ref));
+
+ XRTTupleAllocation* allocation;
+ OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer(
+ literal, device_ref.backend(),
+ device_ref.device_ordinal(), &allocation));
+
+ // Intern takes ownership of our reference to allocation.
+ int64 key;
+ OP_REQUIRES_OK(ctx, allocation->Intern(rm, &key));
+
+ Tensor output(DT_INT64, TensorShape({}));
+ output.scalar<int64>()() = key;
+ ctx->set_output(0, output);
+ }
+};
+
+// Op that takes a tuple handle input and returns a handle to a sub-tuple of the
+// input.
+template <bool discard_, class DeviceAccessor>
+class XRTSubTupleOp : public OpKernel {
+ public:
+ explicit XRTSubTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ ~XRTSubTupleOp() override = default;
+ XRTSubTupleOp(const XRTSubTupleOp&) = delete;
+ XRTSubTupleOp& operator=(const XRTSubTupleOp&) = delete;
+
+ void Compute(OpKernelContext* ctx) override {
+ VLOG(1) << "XRTSubTupleOp::Compute";
+
+ const Tensor& handle_tensor = ctx->input(0);
+    OP_REQUIRES(
+        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
+        errors::Internal("base_handle input should be an int64 scalar"));
+ int64 allocation_handle = handle_tensor.scalar<int64>()();
+
+ const Tensor& subtuple_info = ctx->input(1);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsVector(subtuple_info.shape()),
+ errors::Internal("tuple index input should be an int32 vector"));
+ xla::ShapeIndex shape_index;
+ for (int i = 0; i < subtuple_info.dim_size(0); ++i) {
+ shape_index.push_back(subtuple_info.vec<int32>()(i));
+ }
+
+ ResourceMgr* rm;
+ OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
+
+ XRTTupleAllocation* allocation;
+ OP_REQUIRES_OK(
+ ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation));
+ core::ScopedUnref allocation_unref(allocation);
+
+ if (discard_) {
+ VLOG(2) << "Releasing handle " << allocation_handle;
+ OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager(
+ rm, allocation_handle));
+ }
+
+ XRTTupleAllocation* suballocation;
+ OP_REQUIRES_OK(
+ ctx, XRTTupleAllocation::MakeSubBuffer(allocation, shape_index,
+ &suballocation, !discard_));
+
+ // Intern takes ownership of our reference to suballocation.
+ int64 key;
+ OP_REQUIRES_OK(ctx, suballocation->Intern(rm, &key));
+
+ Tensor output(DT_INT64, TensorShape({}));
+ output.scalar<int64>()() = key;
+ ctx->set_output(0, output);
+ }
+};
+
+// Op that assembles a new tuple allocation from existing allocations.
+template <class DeviceAccessor>
+class XRTMakeTupleOp : public OpKernel {
+ public:
+ explicit XRTMakeTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ ~XRTMakeTupleOp() override = default;
+ XRTMakeTupleOp(const XRTMakeTupleOp&) = delete;
+ XRTMakeTupleOp& operator=(const XRTMakeTupleOp&) = delete;
+
+ void Compute(OpKernelContext* ctx) override {
+ VLOG(1) << "XRTMakeTupleOp::Compute";
+
+ const Tensor& tuple_info = ctx->input(0);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsScalar(tuple_info.shape()),
+ errors::Internal("tuple description input should be a string scalar"));
+ xrt::XLATupleNode tuple_proto;
+ OP_REQUIRES(
+ ctx, tuple_proto.ParseFromString(tuple_info.scalar<string>()()),
+ errors::InvalidArgument("Unable to parse tuple input to XLATupleNode"));
+
+ OpInputList arg_list;
+ OP_REQUIRES_OK(ctx, ctx->input_list("input_handles", &arg_list));
+
+    // For each input, input_vector holds the allocation it corresponds to and
+    // a flag indicating whether or not it should be released, i.e. discarded
+    // from the resource manager. This vector owns one ref on each allocation,
+    // which is freed on exit.
+ std::vector<XRTTupleAllocation::ExpandedTupleInput> input_vector(
+ arg_list.size());
+ auto cleanup = gtl::MakeCleanup([&input_vector] {
+ for (auto& input : input_vector) {
+ if (input.allocation != nullptr) {
+ input.allocation->Unref();
+ }
+ }
+ });
+
+ ResourceMgr* rm;
+ OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
+
+ xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput> tuple_shape_tree;
+    // device_ordinal is filled in by ParseTupleTree with the ordinal of one of
+    // the allocations. It is guaranteed that there is at least one allocation
+    // in any legal tree. We validate below, in XRTTupleAllocation::MakeTuple,
+    // that all the allocations are on the same device.
+ int device_ordinal;
+ OP_REQUIRES_OK(ctx, XRTStateHelpers::ParseTupleTree(
+ tuple_proto, arg_list, &input_vector,
+ &tuple_shape_tree, &device_ordinal, rm));
+
+    // We are guaranteed that the underlying device object won't be deleted out
+    // from under us while the ScopedRef is live.
+ class DeviceAccessor::ScopedRef device_ref;
+ OP_REQUIRES_OK(
+ ctx, DeviceAccessor::InitScopedRef(ctx, device_ordinal, &device_ref));
+
+ XRTTupleAllocation* output_allocation;
+ OP_REQUIRES_OK(ctx, XRTTupleAllocation::MakeTuple(
+ device_ref.backend(), device_ref.device_ordinal(),
+ tuple_shape_tree, &output_allocation));
+ // Add a ScopedUnref to simplify the error path while calling
+ // DeleteFromResourceManager.
+ core::ScopedUnref unref(output_allocation);
+ for (int i = 0; i < input_vector.size(); ++i) {
+ if (input_vector[i].release_allocation_after_use) {
+ OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager(
+ rm, arg_list[i].scalar<int64>()()));
+ }
+ }
+
+ // Intern takes ownership of a reference to output_allocation, so add
+ // another since the ScopedUnref will release one when this method exits.
+ output_allocation->Ref();
+ int64 key;
+ OP_REQUIRES_OK(ctx, output_allocation->Intern(rm, &key));
+
+ Tensor output(DT_INT64, TensorShape({}));
+ output.scalar<int64>()() = key;
+ ctx->set_output(0, output);
+ }
+};
+
+// Op that reads a device-resident tuple to host memory and returns it as a
+// literal.
+template <bool discard_, class DeviceAccessor>
+class XRTReadLiteralOp : public OpKernel {
+ public:
+ explicit XRTReadLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ ~XRTReadLiteralOp() override = default;
+ XRTReadLiteralOp(const XRTReadLiteralOp&) = delete;
+ XRTReadLiteralOp& operator=(const XRTReadLiteralOp&) = delete;
+
+ void Compute(OpKernelContext* ctx) override {
+ VLOG(1) << "XRTReadLiteralOp::Compute";
+
+ const Tensor& handle_tensor = ctx->input(0);
+    OP_REQUIRES(
+        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
+        errors::Internal("handle input should be an int64 scalar"));
+ int64 allocation_handle = handle_tensor.scalar<int64>()();
+
+ ResourceMgr* rm;
+ OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
+
+ XRTTupleAllocation* allocation;
+ OP_REQUIRES_OK(
+ ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation));
+ core::ScopedUnref allocation_unref(allocation);
+
+ if (discard_) {
+ VLOG(2) << "Releasing handle " << allocation_handle;
+ OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager(
+ rm, allocation_handle));
+ }
+
+    // We are guaranteed that the underlying device object won't be deleted out
+    // from under us while the ScopedRef is live.
+ class DeviceAccessor::ScopedRef device_ref;
+ OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
+ ctx, allocation->device_ordinal(), &device_ref));
+
+ xla::Literal literal;
+ OP_REQUIRES_OK(
+ ctx, allocation->ToLiteral(device_ref.backend(),
+ device_ref.device_ordinal(), &literal));
+ xla::LiteralProto literal_proto = literal.ToProto();
+
+ Tensor output(DT_STRING, TensorShape({}));
+ literal_proto.SerializeToString(&output.scalar<string>()());
+ ctx->set_output(0, output);
+ }
+};
+
+// Op that discards a handle to device memory.
+template <class DeviceAccessor>
+class XRTReleaseAllocationOp : public OpKernel {
+ public:
+ explicit XRTReleaseAllocationOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ ~XRTReleaseAllocationOp() override = default;
+ XRTReleaseAllocationOp(const XRTReleaseAllocationOp&) = delete;
+ XRTReleaseAllocationOp& operator=(const XRTReleaseAllocationOp&) = delete;
+
+ void Compute(OpKernelContext* ctx) override {
+ VLOG(1) << "XRTReleaseAllocationOp::Compute";
+
+ const Tensor& allocation_handle = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(allocation_handle.shape()),
+ errors::Internal("handle input should be an int64 scalar"));
+ int64 key = allocation_handle.scalar<int64>()();
+
+ ResourceMgr* rm;
+ OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
+
+ OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager(rm, key));
+
+ VLOG(2) << "Released allocation handle " << key;
+ }
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_
diff --git a/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc b/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc
new file mode 100644
index 0000000000..5cfc8711f9
--- /dev/null
+++ b/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc
@@ -0,0 +1,48 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+REGISTER_OP("XRTCompile")
+ .Input("computation: string")
+ .Output("handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Reads a computation proto, compiles it, and places it in the global compilation
+cache.
+
+'computation' is a serialized xrt::XLAComputation proto.
+'handle' is an identifier that can be used in other ops to refer to the
+computation.
+)");
+
+REGISTER_OP("XRTReleaseCompilationHandle")
+ .Input("handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::NoOutputs)
+ .Doc(
+ R"(
+Discards a computation from the compilation cache. The handle cannot be
+subsequently used.
+
+'handle' is an id returned from an XRTCompile Op.
+)");
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc
new file mode 100644
index 0000000000..fda4c31298
--- /dev/null
+++ b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+REGISTER_OP("XRTExecute")
+ .Attr("Ninputs: int")
+ .Input("computation_handle: int64")
+ .Input("execution_config: string")
+ .Input("input_handles: Ninputs * int64")
+ .Output("output_handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Runs a previously-compiled computation on a core. If
+execution_config.release_input_handles is true, the input handles are invalid
+after this op runs.
+
+'computation_handle' is an id returned by XRTCompile.
+'execution_config' is a serialized xrt::XRTExecutionConfig proto.
+'input_handles' is a list of ids of allocations, one per input to the compiled
+computation.
+'output_handle' is an identifier for the result of the compiled computation.
+'Ninputs' is the number of input handles.
+)");
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc
new file mode 100644
index 0000000000..07d025ce34
--- /dev/null
+++ b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc
@@ -0,0 +1,122 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+REGISTER_OP("XRTAllocate")
+ .Input("allocation: string")
+ .Output("handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Reads a literal proto and transfers it to device memory.
+
+'allocation' is a serialized xrt::XLAAllocation proto.
+'handle' is an id that can be used in other ops to refer to the allocation.
+)");
+
+REGISTER_OP("XRTSubTuple")
+ .Input("base_handle: int64")
+ .Input("shape_index: int32")
+ .Output("output_handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Returns a handle to a sub-tuple of an allocated tuple.
+
+'base_handle' is the id of the on-device allocation.
+'shape_index' is a vector of integers describing an XLA ShapeIndex.
+'output_handle' is an id that can be used in other ops to refer to the
+sub-tuple.
+)");
+
+REGISTER_OP("XRTSubTupleAndRelease")
+ .Input("base_handle: int64")
+ .Input("shape_index: int32")
+ .Output("output_handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Returns a handle to a sub-tuple of an allocated tuple, and releases the handle
+of the input tuple.
+
+'base_handle' is the id of the on-device allocation.
+'shape_index' is a vector of integers describing an XLA ShapeIndex.
+'output_handle' is an id that can be used by other ops to refer to the
+sub-tuple.
+)");
+
+REGISTER_OP("XRTMakeTuple")
+ .Attr("Ninputs: int")
+ .Input("tuple_description: string")
+ .Input("input_handles: Ninputs * int64")
+ .Output("output_handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Returns a handle to a new allocation constructed by assembling existing
+allocations in a tuple.
+
+'tuple_description' is a serialized xrt::XLATupleNode proto describing the
+shape of the output tuple, and whether each input handle should be aliased or
+released.
+'input_handles' is a list of input handles to assemble into the output tuple.
+'output_handle' is an id that can be used by other ops to refer to the new
+tuple.
+'Ninputs' is the number of input handles.
+)");
+
+REGISTER_OP("XRTReadLiteral")
+ .Input("handle: int64")
+ .Output("literal: string")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Copies an allocated tuple from device memory and returns it as a literal.
+
+'handle' is the id returned from the Op that produced the on-device allocation.
+'literal' is a serialized xla::LiteralProto proto.
+)");
+
+REGISTER_OP("XRTReadLiteralAndRelease")
+ .Input("handle: int64")
+ .Output("literal: string")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Copies an allocated tuple from device memory, returns it as a literal, and
+releases the handle.
+
+'handle' is the id returned from the Op that produced the on-device allocation.
+'literal' is a serialized xla::LiteralProto proto.
+)");
+
+REGISTER_OP("XRTReleaseAllocationHandle")
+ .Input("handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::NoOutputs)
+ .Doc(
+ R"(
+Discards an allocation from device memory. The handle cannot be subsequently
+used.
+
+'handle' is the id returned from the Op that produced the on-device allocation.
+)");
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD
new file mode 100644
index 0000000000..09ab4ed95f
--- /dev/null
+++ b/tensorflow/compiler/xrt/tests/BUILD
@@ -0,0 +1,65 @@
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = [
+ "//learning/brain:__subpackages__",
+ "//tensorflow/compiler:__subpackages__",
+ ],
+)
+
+load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test", "tf_cc_test")
+
+cc_library(
+ name = "raw_api_test_lib",
+ testonly = 1,
+ srcs = [
+ "raw_api_test.cc",
+ ],
+ deps = [
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:client_session",
+ "//tensorflow/cc:ops",
+ "//tensorflow/cc:scope",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_computation",
+ "//tensorflow/compiler/xrt:xrt_proto",
+ "//tensorflow/compiler/xrt:xrt_server",
+ "//tensorflow/compiler/xrt/cc:xrt_ops",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:tensorflow_opensource",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+tf_cc_test(
+ name = "raw_api_test_cpu",
+ size = "medium",
+ srcs = [],
+ args = ["--xla_test_device=XLA_CPU"],
+ deps = [
+ ":raw_api_test_lib",
+ "//tensorflow/compiler/jit:xla_cpu_device",
+ ],
+)
+
+tf_cuda_cc_test(
+ name = "raw_api_test_gpu",
+ size = "medium",
+ srcs = [],
+ args = ["--xla_test_device=XLA_GPU"],
+ tags = ["requires-gpu-sm35"],
+ deps = [
+ ":raw_api_test_lib",
+ "//tensorflow/compiler/jit:xla_gpu_device",
+ ],
+)
diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc
new file mode 100644
index 0000000000..2952feb16a
--- /dev/null
+++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc
@@ -0,0 +1,421 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/cc/client/client_session.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h"
+#include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h"
+#include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h"
+#include "tensorflow/compiler/xrt/xrt.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace {
+
+string* xla_test_device_ptr; // initial value set in main()
+
+string DeviceFromFlag() {
+ string xla_test_device = *xla_test_device_ptr;
+ return absl::StrCat("/device:", xla_test_device, ":0");
+}
+
+xla::LiteralProto TwoElementTuple() {
+ auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
+ auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
+ auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
+ return tuple.ToProto();
+}
+
+xla::LiteralProto ScalarLiteral() {
+ auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
+ return scalar.ToProto();
+}
+
+xla::LiteralProto NestedTuple() {
+ auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
+ auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
+ auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
+ auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
+ auto nested = xla::LiteralUtil::MakeTuple({&tuple, &scalar});
+ return nested.ToProto();
+}
+
+xla::LiteralProto MakeTuple0() {
+ auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
+ auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
+ auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
+ auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
+ auto nested0 = xla::LiteralUtil::MakeTuple({&scalar, &tuple});
+ auto nested1 = xla::LiteralUtil::MakeTuple({&scalar, &nested0});
+ return nested1.ToProto();
+}
+
+xla::LiteralProto FloatVector(absl::Span<const float> v) {
+ auto array = xla::LiteralUtil::CreateR1<float>(v);
+ return array.ToProto();
+}
+
+bool CompareLiteralProtos(const xla::LiteralProto& a,
+ const xla::LiteralProto& b) {
+ auto l_a = xla::Literal::CreateFromProto(a).ValueOrDie();
+ auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
+ bool equal = l_a == l_b;
+ if (!equal) {
+ LOG(INFO) << "LiteralProtos don't match " << a.DebugString()
+ << " != " << b.DebugString();
+ }
+ return equal;
+}
+
+bool CompareLiteralToLiteralProto(const xla::Literal& a,
+ const xla::LiteralProto& b) {
+ auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
+ bool equal = a == l_b;
+ if (!equal) {
+ LOG(INFO) << "Literal and LiteralProto don't match "
+ << a.ToProto().DebugString() << " != " << b.DebugString();
+ }
+ return equal;
+}
+
+xla::XlaComputation AddAndScale() {
+ xla::XlaBuilder builder("AddAndScale");
+ auto p0 = xla::Parameter(&builder, 0,
+ xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
+ auto p1 = xla::Parameter(&builder, 1,
+ xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
+ auto sum = xla::Add(p0, p1);
+ auto c = xla::ConstantR0<float>(&builder, 3.0f);
+ xla::Mul(sum, c);
+ return builder.Build().ValueOrDie();
+}
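+
+// For the inputs used in CompileAndExecute below ({1, 2} and {8, 5}), this
+// computes (1+8)*3 = 27 and (2+5)*3 = 21, matching the {27.0f, 21.0f} literal
+// the test expects.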
+
+xla::XlaComputation AddAndTuple() {
+ xla::XlaBuilder builder("AddAndTuple");
+ auto p0 = xla::Parameter(&builder, 0,
+ xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
+ auto p1 = xla::Parameter(&builder, 1,
+ xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
+ auto sum = xla::Add(p0, p1);
+ xla::Tuple(&builder, {sum});
+ return builder.Build().ValueOrDie();
+}
+
+void StoreComputationSnapshot(const xla::XlaComputation& computation,
+ xla::HloSnapshot* dst) {
+ auto snapshot = computation.Snapshot().ValueOrDie();
+ *dst = *snapshot;
+}
+
+TEST(RawApiTest, ReadAndWriteState) {
+ xrt::XLAAllocation alloc;
+ alloc.set_device_ordinal(0);
+ *alloc.mutable_value() = TwoElementTuple();
+
+ Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+ auto value =
+ ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
+ auto handle = ops::XRTAllocate(root, value);
+ auto read_back = ops::XRTReadLiteral(root, handle);
+ auto release = ops::XRTReleaseAllocationHandle(
+ root.WithControlDependencies(read_back), handle);
+ TF_ASSERT_OK(root.status());
+
+ tensorflow::ClientSession session(root);
+ std::vector<tensorflow::Tensor> outputs;
+ TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {read_back},
+ {release}, &outputs));
+
+ xla::LiteralProto response;
+ EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+
+ EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
+}
+
+TEST(RawApiTest, ReadAndWriteStateAutoFree) {
+ xrt::XLAAllocation alloc;
+ alloc.set_device_ordinal(0);
+ *alloc.mutable_value() = TwoElementTuple();
+
+ Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+ auto value =
+ ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
+ auto handle = ops::XRTAllocate(root, value);
+ auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
+ TF_ASSERT_OK(root.status());
+
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run({read_back}, &outputs));
+
+ xla::LiteralProto response;
+ EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+ EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
+}
+
+TEST(RawApiTest, SubBuffer) {
+ xrt::XLAAllocation alloc;
+ alloc.set_device_ordinal(0);
+ *alloc.mutable_value() = NestedTuple();
+
+ Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+ auto value =
+ ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
+ auto base_handle = ops::XRTAllocate(root, value);
+ auto index_0 = ops::Const(root.WithDevice("/device:CPU:0"), {0});
+ auto index_1 = ops::Const(root.WithDevice("/device:CPU:0"), {1});
+ auto index_00 = ops::Const(root.WithDevice("/device:CPU:0"), {0, 0});
+ auto sub_0 = ops::XRTSubTuple(root, base_handle, index_0);
+ auto sub_1 = ops::XRTSubTuple(root, base_handle, index_1);
+ auto sub_00 = ops::XRTSubTupleAndRelease(
+ root.WithControlDependencies(
+ {sub_0.output_handle.op(), sub_1.output_handle.op()}),
+ base_handle, index_00);
+ auto value_0 = ops::XRTReadLiteralAndRelease(root, sub_0);
+ auto value_1 = ops::XRTReadLiteralAndRelease(root, sub_1);
+ auto value_00 = ops::XRTReadLiteralAndRelease(root, sub_00);
+ TF_ASSERT_OK(root.status());
+
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run({value_0, value_1, value_00}, &outputs));
+
+ auto base_literal = xla::Literal::CreateFromProto(alloc.value()).ValueOrDie();
+ auto base_elements = base_literal.DecomposeTuple();
+ auto nested_0_elements = base_elements[0].Clone().DecomposeTuple();
+ xla::LiteralProto response_0;
+ EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar<string>()()));
+ EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[0], response_0));
+ xla::LiteralProto response_1;
+ EXPECT_TRUE(response_1.ParseFromString(outputs[1].scalar<string>()()));
+ EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[1], response_1));
+ xla::LiteralProto response_00;
+ EXPECT_TRUE(response_00.ParseFromString(outputs[2].scalar<string>()()));
+ EXPECT_TRUE(CompareLiteralToLiteralProto(nested_0_elements[0], response_00));
+}
+
+TEST(RawApiTest, MakeTuple) {
+ xrt::XLAAllocation alloc_0;
+ alloc_0.set_device_ordinal(0);
+ *alloc_0.mutable_value() = TwoElementTuple();
+ xrt::XLAAllocation alloc_1;
+ alloc_1.set_device_ordinal(0);
+ *alloc_1.mutable_value() = ScalarLiteral();
+
+ // The trivial tuple that just forwards its input and releases it.
+ xrt::XLATupleNode desc_0;
+ desc_0.set_input_index(0);
+ desc_0.set_release_input_handle(true);
+
+ xrt::XLATupleNode desc_1;
+ auto subdesc_10 = desc_1.add_tuples();
+ auto subdesc_11 = desc_1.add_tuples();
+ subdesc_10->set_input_index(0);
+ auto subdesc_110 = subdesc_11->add_tuples();
+ subdesc_110->set_input_index(0);
+ auto subdesc_111 = subdesc_11->add_tuples();
+ subdesc_111->set_input_index(1);
+
+ xrt::XLATupleNode desc_2;
+ auto subdesc_20 = desc_2.add_tuples();
+ auto subdesc_21 = desc_2.add_tuples();
+ subdesc_20->set_input_index(1);
+ subdesc_20->set_release_input_handle(true);
+ subdesc_21->set_input_index(0);
+ subdesc_21->set_release_input_handle(true);
+
+ Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+ auto value_0 =
+ ops::Const(root.WithDevice("/device:CPU:0"), alloc_0.SerializeAsString());
+ auto handle_0 = ops::XRTAllocate(root, value_0);
+ auto value_1 =
+ ops::Const(root.WithDevice("/device:CPU:0"), alloc_1.SerializeAsString());
+ auto handle_1 = ops::XRTAllocate(root, value_1);
+ auto tuple_0 =
+ ops::Const(root.WithDevice("/device:CPU:0"), desc_0.SerializeAsString());
+ auto handle_2 =
+ ops::XRTMakeTuple(root, tuple_0, {static_cast<Output>(handle_0)});
+ // handle_0 has now been released.
+ auto tuple_1 =
+ ops::Const(root.WithDevice("/device:CPU:0"), desc_1.SerializeAsString());
+ auto handle_3 = ops::XRTMakeTuple(
+ root, tuple_1,
+ {static_cast<Output>(handle_1), static_cast<Output>(handle_2)});
+ auto tuple_2 =
+ ops::Const(root.WithDevice("/device:CPU:0"), desc_2.SerializeAsString());
+ // Make sure this runs after handle_3 has completed, since it will free
+ // handle_1 and handle_2.
+ auto handle_4 = ops::XRTMakeTuple(
+ root.WithControlDependencies(handle_3), tuple_2,
+ {static_cast<Output>(handle_1), static_cast<Output>(handle_2)});
+ // handle_1 and handle_2 have now been released.
+
+ auto res_0 = ops::XRTReadLiteralAndRelease(root, handle_3);
+ auto res_1 = ops::XRTReadLiteralAndRelease(root, handle_4);
+ TF_ASSERT_OK(root.status());
+
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run({res_0, res_1}, &outputs));
+ xla::LiteralProto response_0;
+ EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar<string>()()));
+ xla::LiteralProto response_1;
+ EXPECT_TRUE(response_1.ParseFromString(outputs[1].scalar<string>()()));
+
+ auto expected_0 = MakeTuple0();
+ EXPECT_TRUE(CompareLiteralProtos(response_0, expected_0));
+ auto expected_1 = NestedTuple();
+ EXPECT_TRUE(CompareLiteralProtos(response_1, expected_1));
+}
+
+TEST(RawApiTest, CompileAndExecute) {
+ xrt::XLAAllocation p0;
+ p0.set_device_ordinal(0);
+ *p0.mutable_value() = FloatVector({1.0f, 2.0f});
+ xrt::XLAAllocation p1;
+ p1.set_device_ordinal(0);
+ *p1.mutable_value() = FloatVector({8.0f, 5.0f});
+
+ xrt::XLAComputation c;
+ auto config = c.mutable_config();
+ auto shapes = config->mutable_program_shape();
+ *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+ *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+ *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+ StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot());
+
+ xrt::XRTExecutionConfig e;
+ e.set_release_input_handles(true);
+ e.set_release_compilation_handle(true);
+
+ Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+ auto e_config =
+ ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
+ auto computation =
+ ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
+ auto c_handle = ops::XRTCompile(root, computation);
+ auto p0_value =
+ ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
+ auto p0_handle = ops::XRTAllocate(root, p0_value);
+ auto p1_value =
+ ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
+ auto p1_handle = ops::XRTAllocate(root, p1_value);
+ auto result = ops::XRTExecute(root, c_handle, e_config,
+ {Output(p0_handle), Output(p1_handle)});
+ auto read_back = ops::XRTReadLiteralAndRelease(root, result);
+ TF_ASSERT_OK(root.status());
+
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run({read_back}, &outputs));
+
+ xla::LiteralProto response;
+ EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+
+ auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
+ EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
+}
+
+TEST(RawApiTest, CompileAndExecuteReturnTuple) {
+ xrt::XLAAllocation p0;
+ p0.set_device_ordinal(0);
+ *p0.mutable_value() = FloatVector({1.0f, 2.0f});
+ xrt::XLAAllocation p1;
+ p1.set_device_ordinal(0);
+ *p1.mutable_value() = FloatVector({8.0f, 5.0f});
+
+ xrt::XLAComputation c;
+ auto config = c.mutable_config();
+ auto shapes = config->mutable_program_shape();
+ *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+ *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+ *shapes->mutable_result() = xla::ShapeUtil::MakeTupleShape(
+ {xla::ShapeUtil::MakeShape(xla::F32, {2})});
+ StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot());
+
+ xrt::XRTExecutionConfig e;
+ e.set_release_input_handles(true);
+ e.set_release_compilation_handle(true);
+
+ Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+ auto e_config =
+ ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
+ auto computation =
+ ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
+ auto c_handle = ops::XRTCompile(root, computation);
+ auto p0_value =
+ ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
+ auto p0_handle = ops::XRTAllocate(root, p0_value);
+ auto p1_value =
+ ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
+ auto p1_handle = ops::XRTAllocate(root, p1_value);
+ auto result = ops::XRTExecute(root, c_handle, e_config,
+ {Output(p0_handle), Output(p1_handle)});
+ auto read_back = ops::XRTReadLiteralAndRelease(root, result);
+ TF_ASSERT_OK(root.status());
+
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run({read_back}, &outputs));
+
+ xla::LiteralProto response;
+ EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+
+ auto sum = xla::LiteralUtil::CreateR1<float>({9.0f, 7.0f});
+ auto expected = xla::LiteralUtil::MakeTuple({&sum});
+ EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
+}
+
+} // namespace
+
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ tensorflow::xla_test_device_ptr = new tensorflow::string("XLA_CPU");
+ std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("xla_test_device", tensorflow::xla_test_device_ptr,
+ "Tensorflow device type to use for test, e.g., XLA_CPU"),
+ };
+ tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
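+
+// Illustrative invocation sketch (binary name hypothetical); the flag defined
+// above selects the device, e.g.:
+//
+// ./raw_api_test --xla_test_device=XLA_GPU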
diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto
new file mode 100644
index 0000000000..5678f0905f
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt.proto
@@ -0,0 +1,78 @@
+syntax = "proto3";
+
+package xrt;
+
+import "tensorflow/compiler/tf2xla/host_compute_metadata.proto";
+import "tensorflow/compiler/xla/xla_data.proto";
+import "tensorflow/compiler/xla/service/hlo.proto";
+
+// Options for an XLA compilation.
+message XLAComputationConfig {
+ // The number of replicas the computation will be run on. If this is left at
+ // the default (0), it is interpreted as 1.
+ int32 num_replicas = 1;
+ // The number of "model-parallel" cores per replica. If this is left at the
+ // default (0), it is interpreted as 1.
+ int32 num_cores_per_replica = 2;
+ // Optional metadata about host sends and recvs.
+ tensorflow.tf2xla.HostComputeMetadata host_compute_metadata = 3;
+
+ // The arg/result shapes for the whole computation.
+ xla.ProgramShape program_shape = 4;
+ // The arg/result shapes for each core of a model-parallel
+ // computation. per_core_args_and_result_shapes is optional for a
+ // single-core computation.
+ repeated xla.ProgramShape per_core_program_shape = 5;
+}
+
+// Options and XLA computation for a compilation.
+message XLAComputation {
+ XLAComputationConfig config = 1;
+ xla.HloSnapshot hlo_snapshot = 2;
+}
+
+// Literal to allocate space for, and transfer to, device memory.
+message XLAAllocation {
+ int32 device_ordinal = 1;
+ xla.LiteralProto value = 2;
+}
+
+// Node in a tree describing a tuple constructed from input handles. A
+// node is an internal node if tuples is non-empty, in which case
+// input_index and release_input_handle are ignored. Otherwise a node
+// is a leaf node. Each leaf XLATupleNode holds the index of an input
+// whose handle will be grafted onto the output tuple at that
+// location. If release_input_handle is true, that input handle will
+// be released and become invalid. Inputs may be repeated, in which
+// case the corresponding leaves of the output tuple will alias. If an
+// input is repeated, release_input_handle must be false for every
+// leaf where that input appears.
+//
+// For example, if input 0 has shape {} and input 1 has shape {2,3}
+// then the XLATupleNode with structure {1,{0,1}} corresponds to a
+// tuple with shape {{2,3},{{},{2,3}}}.
+message XLATupleNode {
+ int32 input_index = 1;
+ bool release_input_handle = 2;
+ repeated XLATupleNode tuples = 3;
+}
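+
+// As an illustrative sketch (mirroring the style of the C++ tests in this
+// change), the {1,{0,1}} structure above could be built as:
+//
+// xrt::XLATupleNode desc;
+// desc.add_tuples()->set_input_index(1);
+// auto* inner = desc.add_tuples();
+// inner->add_tuples()->set_input_index(0);
+// inner->add_tuples()->set_input_index(1);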
+
+// Options for an XLA execution.
+message XRTExecutionConfig {
+ // Local device to run on. This is present because the execute Op
+ // may be placed on a device such as CPU or TPU_SYSTEM that
+ // logically manages multiple cores.
+ int32 device_ordinal = 1;
+ // Which model-parallel computation to run from the compiled bundle.
+ int32 core_index_in_replica = 2;
+ // Optional key to disambiguate between executions. This is only
+ // needed if multiple host send/recvs may be outstanding
+ // concurrently with executions.
+ string execution_instance_key = 3;
+ // If non-zero, rng_seed to reset the core with.
+ uint32 rng_seed = 4;
+ // If true, release allocation handles on the inputs after running.
+ bool release_input_handles = 5;
+ // If true, release the handle to the computation after running.
+ bool release_compilation_handle = 6;
+}
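+
+// A minimal configuration sketch (illustrative; mirrors the tests in this
+// change, which free the inputs and the compiled computation after running):
+//
+// xrt::XRTExecutionConfig config;
+// config.set_release_input_handles(true);
+// config.set_release_compilation_handle(true);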
diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.cc b/tensorflow/compiler/xrt/xrt_compilation_cache.cc
new file mode 100644
index 0000000000..4844c7fb71
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_compilation_cache.cc
@@ -0,0 +1,263 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
+
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+const char* kXRTCompilationCacheResourceName = "xrt_compilation_cache";
+
+XRTCompilationCache::EntryRefImpl::EntryRefImpl(XRTCompilationCache* parent,
+ CompiledSubgraph* entry)
+ : parent_(parent), entry_(entry) {
+ entry_->Ref();
+}
+
+XRTCompilationCache::EntryRefImpl::~EntryRefImpl() {
+ parent_->DiscardEntryRef(entry_);
+}
+
+XRTCompilationCacheEntry XRTCompilationCache::EntryRefImpl::get() {
+ return XRTCompilationCacheEntry(entry_->program.get());
+}
+
+XRTCompilationCache::XRTCompilationCache(int max_number_of_entries)
+ : max_cache_entries_(max_number_of_entries) {
+ CHECK_GE(max_cache_entries_, 0);
+ VLOG(1) << "Created compilation cache max " << max_cache_entries_
+ << " entries.";
+}
+
+XRTCompilationCache::~XRTCompilationCache() {
+ VLOG(1) << "XRTCompilationCache::~XRTCompilationCache()";
+ while (!entries_by_last_use_.empty()) {
+ MarkOldestEntryForEviction();
+ }
+ // By the time the cache is deleted all reference holders should have already
+ // been deleted, since they were holding references to the cache. So all
+ // entries should be gone at this point.
+ CHECK_EQ(cache_.size(), 0);
+ CHECK_EQ(entries_by_uid_.size(), 0);
+ CHECK_EQ(cache_entries_, 0);
+ CHECK_EQ(marked_for_eviction_entries_, 0);
+}
+
+Status XRTCompilationCache::Release(int64 uid) {
+ absl::MutexLock lock(&mu_);
+ auto iter = entries_by_uid_.find(uid);
+
+ if (iter == entries_by_uid_.end()) {
+ return errors::NotFound("No cache entry found for uid ", uid);
+ }
+
+ DiscardEntryRefLocked(iter->second);
+
+ VLOG(1) << "After releasing entry " << uid << " refs cache is "
+ << cache_.size() << " entries ("
+ << cache_entries_ + marked_for_eviction_entries_
+ << "), marked for eviction "
+ << (cache_.size() - entries_by_last_use_.size()) << " entries ("
+ << marked_for_eviction_entries_ << ").";
+
+ return Status::OK();
+}
+
+void XRTCompilationCache::DiscardEntryRef(CompiledSubgraph* entry) {
+ absl::MutexLock lock(&mu_);
+ DiscardEntryRefLocked(entry);
+}
+
+void XRTCompilationCache::DiscardEntryRefLocked(CompiledSubgraph* entry) {
+ if (entry->RefCountIsOne()) {
+ // The last reference to this entry is going away, so really delete it from
+ // the cache in such a way that it can't be restored by being looked up
+ // again.
+
+ // Sanity-check that it has been marked for eviction.
+ CHECK(entries_by_last_use_.find(entry->last_use) ==
+ entries_by_last_use_.end());
+ // Update the counter tracking how much space is taken up by entries that
+ // are marked for eviction.
+ --marked_for_eviction_entries_;
+
+ // Remove the entry from the cache.
+ auto erased = cache_.erase(entry->key);
+ if (erased == 0) {
+ LOG(FATAL) << "Tried to discard nonexistent cache entry";
+ }
+ erased = entries_by_uid_.erase(entry->uid);
+ CHECK_EQ(erased, 1);
+ }
+ entry->Unref();
+}
+
+void XRTCompilationCache::MarkOldestEntryForEviction() {
+ CompiledSubgraph* entry_to_mark = entries_by_last_use_.begin()->second;
+ VLOG(1) << "Marking " << entry_to_mark->key << " for eviction";
+ entries_by_last_use_.erase(entry_to_mark->last_use);
+ --cache_entries_;
+ ++marked_for_eviction_entries_;
+ // Discard the cache's reference to entry. If steps are holding onto
+ // references to entry it won't be deleted until the last step holding it
+ // completes. It stays in the cache in the meantime and can be resurrected
+ // by a call to CompileIfKeyAbsent if that occurs before the last reference
+ // expires.
+ DiscardEntryRefLocked(entry_to_mark);
+}
+
+void XRTCompilationCache::LookupEntryMarkedForEviction(
+ CompiledSubgraph* entry) {
+ // The entry was previously marked for eviction (or is newly created) so
+ // unmark it. Add a reference (owned by the cache), update the cache size, and
+ // mark something old for eviction if necessary.
+ entry->Ref();
+ --marked_for_eviction_entries_;
+ ++cache_entries_;
+
+ // Mark the least-recently-used non-marked entry for eviction. Never mark the
+ // most-recently used entry (i.e., do nothing if entries_by_last_use_.size()
+ // == 1, which means there's only one entry not already marked for eviction),
+ // so that at least one entry persists in the cache even when
+ // max_cache_entries_ is exceeded.
+ while (entries_by_last_use_.size() > 1 &&
+ cache_entries_ > max_cache_entries_) {
+ MarkOldestEntryForEviction();
+ }
+}
+
+XRTCompilationCache::CompiledSubgraph* XRTCompilationCache::InitializeEntry(
+ const string& key,
+ const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
+ initialize_program) {
+ CompiledSubgraph* entry = new CompiledSubgraph();
+ entry->parent = this;
+ entry->key = key;
+ entry->uid = next_uid_++;
+ // Add the entry to the cache. Once the computation has been compiled,
+ // UpdateEntryAfterCompilation will be called to potentially mark old entries
+ // that don't fit any more for eviction.
+ //
+ // At this point there is one reference to entry, which is owned by the caller
+ // who created the entry. A second reference, owned by the cache, will be
+ // added below since we leave the entry in the 'marked for eviction' state
+ // here.
+ auto cache_inserted =
+ cache_.insert(std::pair<string, CompiledSubgraph*>(key, entry));
+ CHECK(cache_inserted.second);
+
+ // Initialize the program outside the lock so that other cache operations
+ // can proceed during the (potentially lengthy) initialization.
+ Status s;
+ std::unique_ptr<xla::LocalExecutable> program;
+ {
+ mu_.Unlock();
+ s = initialize_program(&program);
+ mu_.Lock();
+ }
+
+ // Add the entry to the uid index.
+ auto uid_inserted = entries_by_uid_.insert(
+ std::pair<int64, CompiledSubgraph*>(entry->uid, entry));
+ CHECK(uid_inserted.second);
+
+ entry->initialized = true;
+ entry->initialization_status = s;
+ if (s.ok()) {
+ entry->program = std::move(program);
+ }
+ // Count the entry in marked_for_eviction_entries_, since the counter will be
+ // adjusted down again when the newly-created entry gets unmarked.
+ ++marked_for_eviction_entries_;
+ return entry;
+}
+
+Status XRTCompilationCache::CompileIfKeyAbsent(
+ const string& key, int64* uid,
+ const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
+ compile_function) {
+ CompiledSubgraph* entry = nullptr;
+
+ absl::MutexLock lock(&mu_);
+ auto iter = cache_.find(key);
+
+ if (iter == cache_.end()) {
+ // The single ref on the newly-created entry is owned by the caller.
+ VLOG(1) << "Before adding new entry for key " << key << " cache is "
+ << cache_.size() << " entries ("
+ << cache_entries_ + marked_for_eviction_entries_
+ << "), marked for eviction "
+ << (cache_.size() - entries_by_last_use_.size()) << " entries ("
+ << marked_for_eviction_entries_ << ").";
+ entry = InitializeEntry(key, compile_function);
+ } else {
+ VLOG(1) << "Before refreshing entry for key " << key << " cache is "
+ << cache_.size() << " entries ("
+ << cache_entries_ + marked_for_eviction_entries_
+ << "), marked for eviction "
+ << (cache_.size() - entries_by_last_use_.size()) << " entries ("
+ << marked_for_eviction_entries_ << ").";
+ entry = iter->second;
+ // Make a new reference that is owned by the caller.
+ entry->Ref();
+ // Block if necessary until the subgraph has been initialized.
+ mu_.Await(absl::Condition(
+ +[](CompiledSubgraph* e) { return e->initialized; }, entry));
+ }
+
+ // Let the caller know the uid of the entry.
+ *uid = entry->uid;
+
+ // Remove the old LRU-table entry if it wasn't already marked for eviction.
+ auto erased = entries_by_last_use_.erase(entry->last_use);
+ // Update the LRU table indicating this entry is the most recently used.
+ entry->last_use = use_counter_++;
+ entries_by_last_use_[entry->last_use] = entry;
+ if (erased == 0) {
+ // The entry had been marked for eviction, or is newly created.
+ LookupEntryMarkedForEviction(entry);
+ }
+
+ VLOG(1) << "After refreshing entry for key " << key << " cache is "
+ << cache_.size() << " entries ("
+ << cache_entries_ + marked_for_eviction_entries_
+ << "), marked for eviction "
+ << (cache_.size() - entries_by_last_use_.size()) << " entries ("
+ << marked_for_eviction_entries_ << ").";
+
+ return entry->initialization_status;
+}
+
+Status XRTCompilationCache::Lookup(
+ int64 uid, std::unique_ptr<XRTCompilationCacheEntryRef>* entry) {
+ entry->reset();
+
+ absl::MutexLock lock(&mu_);
+ const auto iter = entries_by_uid_.find(uid);
+ if (iter == entries_by_uid_.end()) {
+ return errors::NotFound("No executable found for uid ", uid);
+ }
+ CompiledSubgraph* cache_entry = iter->second;
+ *entry = std::unique_ptr<XRTCompilationCacheEntryRef>(
+ new EntryRefImpl(this, cache_entry));
+ return Status::OK();
+}
+
+string XRTCompilationCache::DebugString() { return "XRTCompilationCache"; }
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.h b/tensorflow/compiler/xrt/xrt_compilation_cache.h
new file mode 100644
index 0000000000..c505299a45
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_compilation_cache.h
@@ -0,0 +1,238 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_
+#define TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_
+
+#include <memory>
+#include <string>
+
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/refcount.h"
+
+namespace tensorflow {
+
+extern const char* kXRTCompilationCacheResourceName;
+
+struct XRTCompilationCacheEntry {
+ explicit XRTCompilationCacheEntry(xla::LocalExecutable* executable)
+ : executable(executable) {}
+
+ // Returns a non-owned pointer to an immutable executable.
+ xla::LocalExecutable* get_executable() const { return executable; }
+
+ private:
+ xla::LocalExecutable* executable;
+};
+
+// Base class for a reference to a cached executable. A unique_ptr to a
+// XRTCompilationCacheEntryRef is returned by the cache Lookup methods below,
+// and ensures the underlying executable is not garbage-collected until the
+// client discards the ptr.
+class XRTCompilationCacheEntryRef {
+ public:
+ virtual ~XRTCompilationCacheEntryRef() = default;
+
+ // Returns a XRTCompilationCacheEntry that should not be used beyond the
+ // lifetime of the XRTCompilationCacheEntryRef.
+ virtual XRTCompilationCacheEntry get() = 0;
+};
+
+// Cache for compiled XLA executables.
+// TODO(b/112646171) rationalize this with the other compilation caches.
+//
+// Each key identifies a unique XLA computation, and the value is executable
+// generated by compiling the computation.
+//
+// When a computation is considered for compilation, the client calls
+//
+// auto key = <compute key for computation>;
+// auto compile_function = <lambda to compile computation into executable>;
+// int64 uid;
+// CompileIfKeyAbsent(computation_key, &uid, compile_function);
+//
+// where computation_key is the key computed for the computation. On success,
+// uid contains an identifier that can be used to look up the executable. If
+// the compiled executable is not already present in the cache,
+// compile_function is called to generate it.
+//
+// The caller is responsible for calling Release(uid) once for every
+// call to CompileIfKeyAbsent(key, ...) to discard the reference to the
+// compilation results, after the caller is sure it will not look up the
+// compiled executables again.
+//
+// Subsequently the client can call
+//
+// std::unique_ptr<XRTCompilationCacheEntryRef> entry;
+// Lookup(uid, &entry);
+// auto proto = entry->get();
+//
+// to access a cached executable.
+class XRTCompilationCache : public ResourceBase {
+ public:
+ // There is no way in general to discover the size taken by an XLA executable,
+ // so the cache is bounded by a fixed number of entries to determine when to
+ // start evicting programs. TODO(b/112592410) change this if the XLA API gets
+ // a mechanism to query size.
+ explicit XRTCompilationCache(int max_number_of_entries);
+ ~XRTCompilationCache() override;
+
+ // Ensures there is an entry for key present in the cache. By the time
+ // CompileIfKeyAbsent returns there is guaranteed to be an entry in the cache
+ // for key, and that entry will remain valid at least until Release is called
+ // on the returned uid. The first call to CompileIfKeyAbsent with a key that
+ // is not in the cache will evaluate compile_function to compute the value to
+ // use in the entry. Subsequent calls with the same key will block until
+ // compile_function completes. Other cache reads and inserts may proceed on
+ // other threads while compile_function is executing. The caller is
+ // responsible for calling Release(uid) to manually discard its reference to
+ // the compiled program, once it is sure it will not look up the compiled
+ // program again.
+ //
+ // compile_function should compile the computation represented by key and
+ // fill in its passed argument with the resulting xla::LocalExecutable. It
+ // should return OK if and only if compilation succeeds. The executable will
+ // be discarded on non-OK status.
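+ //
+ // A hypothetical sketch of the callback shape (CompileComputation is an
+ // invented placeholder for whatever produces the executable):
+ //
+ // auto compile_function =
+ // [&](std::unique_ptr<xla::LocalExecutable>* program) -> Status {
+ // return CompileComputation(program);
+ // };
+ // int64 uid;
+ // TF_RETURN_IF_ERROR(cache->CompileIfKeyAbsent(key, &uid, compile_function));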
+ Status CompileIfKeyAbsent(
+ const string& key, int64* uid,
+ const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
+ compile_function);
+
+ Status Release(int64 uid);
+
+ // Looks up an executable corresponding to uid. On success a pointer to an
+ // EntryRef holding the program is returned in entry.
+ Status Lookup(int64 uid, std::unique_ptr<XRTCompilationCacheEntryRef>* entry);
+
+ string DebugString() override;
+
+ private:
+ // An entry in the compilation cache. The entry is deleted once it has been
+ // marked for eviction from the cache _and_ all looked-up entries have been
+ // released. When the entry is first created, it is uninitialized and a
+ // client-supplied compilation function is run outside the cache's lock to
+ // generate the program to be stored in the entry. Any other client that
+ // requests the entry will block until it has been initialized. Each entry has
+ // a last_use value that is set from a monotonically-increasing counter in
+ // the cache whenever the entry is referenced. When the cache becomes full,
+ // entries are marked for eviction in LRU order.
+ struct CompiledSubgraph : public core::RefCounted {
+ ~CompiledSubgraph() override = default;
+
+ XRTCompilationCache* parent = nullptr; // Not owned.
+ bool initialized = false;
+ // The Status returned by the compilation function when the entry is
+ // initialized. This status will be returned to any client that requests the
+ // entry.
+ Status initialization_status;
+ // Counter to keep track of LRU entries for the eviction policy.
+ int64 last_use = -1;
+ // The unique key describing this entry.
+ string key;
+ // The uid describing this entry.
+ int64 uid;
+ // The compiled payload corresponding to the key.
+ std::unique_ptr<xla::LocalExecutable> program;
+ };
+
+ // Wrapper for a cache entry that holds a reference to the entry until the
+ // wrapper is deleted. This wrapper is the concrete type of
+ // XRTCompilationCacheEntryRef returned by Lookup.
+ class EntryRefImpl : public XRTCompilationCacheEntryRef {
+ public:
+ EntryRefImpl(XRTCompilationCache* parent, CompiledSubgraph* entry);
+ ~EntryRefImpl() override;
+
+ XRTCompilationCacheEntry get() override;
+
+ private:
+ XRTCompilationCache* parent_; // Not owned.
+ // A reference to entry_ is acquired in the constructor and released via
+ // parent->DiscardEntryRef in the destructor.
+ CompiledSubgraph* entry_;
+ };
+
+ // Releases one reference to entry. This is called by the cache when entry is
+ // marked for eviction; or by an EntryRefImpl when it is destroyed. Before the
+ // last reference to entry is released, entry is removed from cache_.
+ void DiscardEntryRef(CompiledSubgraph* entry);
+ void DiscardEntryRefLocked(CompiledSubgraph* entry)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Marks the oldest unmarked entry for eviction. Requires that there is at
+ // least one such entry.
+ void MarkOldestEntryForEviction() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Updates data structures to indicate that entry, which had been marked for
+ // eviction, has been looked up. This is called by CompileIfKeyAbsent when an
+ // entry is newly created, or an entry that has been marked for eviction but
+ // not yet evicted is looked up.
+ //
+ // First the entry is unmarked for eviction, i.e. the cache gains a reference
+ // to entry, entry's last_use field is set to be the most recent value of
+ // use_counter_ and entries_by_last_use_ is updated accordingly.
+ //
+ // Next, the size of the cache is examined to see if any other entries need to
+ // be marked for eviction now that entry has been unmarked. While the total
+ // number of unmarked cached entries is greater than max_cache_entries_,
+ // entries are marked for eviction in LRU order. The most recently used entry
+ // is never marked for eviction, so at least one entry remains in the cache
+ // even when max_cache_entries_ is exceeded, until it is replaced by something
+ // else.
+ void LookupEntryMarkedForEviction(CompiledSubgraph* entry)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Creates a new entry by running initialize_program and places it in the
+ // cache to be looked up by key. The new entry is in the 'marked for eviction'
+ // state (not present in entries_by_last_use_) and the caller is expected to
+ // call LookupEntryMarkedForEviction after InitializeEntry.
+ //
+ // **InitializeEntry releases mu_ during the call to initialize_program.**
+ CompiledSubgraph* InitializeEntry(
+ const string& key,
+ const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
+ initialize_program) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // The maximum number of entries that are stored in the cache before entries
+ // are marked for eviction.
+ const int max_cache_entries_;
+
+ mutable absl::Mutex mu_;
+ // The uid to assign to the next new entry created.
+ int64 next_uid_ GUARDED_BY(mu_) = 0;
+ // The total number of entries that are stored and not marked for eviction.
+ int cache_entries_ GUARDED_BY(mu_) = 0;
+ // The total number of entries that are marked for eviction.
+ int marked_for_eviction_entries_ GUARDED_BY(mu_) = 0;
+ // The value to assign to the last_use field of the next entry that is looked
+ // up.
+ int64 use_counter_ GUARDED_BY(mu_) = 0;
+ // All the executables that can be looked up in the cache, indexed by key. An
+ // entry is marked for eviction iff it is present in cache_ and not in
+ // entries_by_last_use_.
+ std::unordered_map<string, CompiledSubgraph*> cache_ GUARDED_BY(mu_);
+ // All the executable entries that can be looked up in the cache indexed by
+ // uid.
+ std::unordered_map<int64, CompiledSubgraph*> entries_by_uid_ GUARDED_BY(mu_);
+ // Map from last_use to entry, used to mark entries for eviction in LRU
+ // order. If an entry's last_use counter is not present as a key in
+ // entries_by_last_use_ then the entry has been marked for eviction.
+ std::map<int64, CompiledSubgraph*> entries_by_last_use_ GUARDED_BY(mu_);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_
diff --git a/tensorflow/compiler/xrt/xrt_device.cc b/tensorflow/compiler/xrt/xrt_device.cc
new file mode 100644
index 0000000000..ea40e6c895
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_device.cc
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Classes for managing access to XLA resources.
+
+#include "tensorflow/compiler/xrt/xrt_device.h"
+
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+/*static*/ Status XRTGenericDeviceAccessor::GetResourceManager(
+ OpKernelContext* ctx, ResourceMgr** rm) {
+ *rm = ctx->resource_manager();
+ return Status::OK();
+}
+
+/*static*/ Status XRTGenericDeviceAccessor::InitScopedRef(
+ OpKernelContext* ctx, int device_ordinal, ScopedRef* scoped_ref) {
+ const XlaDevice::Metadata* metadata;
+ TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata));
+ if (device_ordinal != metadata->device_ordinal()) {
+ return errors::Internal("XRT device ordinal requested ", device_ordinal,
+ " on device with ordinal ",
+ metadata->device_ordinal());
+ }
+ scoped_ref->Acquire(metadata->client());
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/xrt_device.h b/tensorflow/compiler/xrt/xrt_device.h
new file mode 100644
index 0000000000..1e3fddd2a7
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_device.h
@@ -0,0 +1,66 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Classes for keeping track of on-device state.
+
+#ifndef TENSORFLOW_COMPILER_XRT_XRT_DEVICE_H_
+#define TENSORFLOW_COMPILER_XRT_XRT_DEVICE_H_
+
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+
+namespace tensorflow {
+
+// This accessor is used for XLA CPU/GPU. It uses the device resource manager,
+// so e.g., on multi-GPU setups the compilation cache will not be shared across
+// devices.
+class XRTGenericDeviceAccessor {
+ public:
+ static Status GetResourceManager(OpKernelContext* ctx, ResourceMgr** rm);
+
+ // We use a ScopedRef pattern here even though it's not strictly necessary,
+ // just so that templated uses of this and the TPU accessor class will be as
+ // similar as possible.
+ class ScopedRef {
+ public:
+ ScopedRef() {}
+ ~ScopedRef() {}
+
+ ScopedRef(const ScopedRef&) = delete;
+ ScopedRef& operator=(const ScopedRef&) = delete;
+
+ // Returns the XLA device protected by this ScopedRef.
+ xla::LocalClient* client() { return client_; }
+ xla::Backend* backend() { return client_->mutable_backend(); }
+ int device_ordinal() { return 0; }
+
+ private:
+ // XRTGenericDeviceAccessor::InitScopedRef is the only way to initialize
+ // ScopedRef.
+ friend class XRTGenericDeviceAccessor;
+
+ void Acquire(xla::LocalClient* client) { client_ = client; }
+
+ xla::LocalClient* client_ = nullptr;
+ };
+
+ static Status InitScopedRef(OpKernelContext* ctx, int device_ordinal,
+ ScopedRef* scoped_ref);
+};
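+
+// A minimal usage sketch (hypothetical kernel code; assumes an
+// OpKernelContext* ctx whose op is placed on an XLA device):
+//
+// XRTGenericDeviceAccessor::ScopedRef device_ref;
+// OP_REQUIRES_OK(ctx, XRTGenericDeviceAccessor::InitScopedRef(
+// ctx, /*device_ordinal=*/0, &device_ref));
+// xla::Backend* backend = device_ref.backend();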
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_XRT_XRT_DEVICE_H_
diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc
new file mode 100644
index 0000000000..d05a1e7dcb
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_state.cc
@@ -0,0 +1,458 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Classes for allocating XLA literals in device memory and managing handles
+// that refer to them.
+
+#include "tensorflow/compiler/xrt/xrt_state.h"
+
+#include <stdint.h>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+
+namespace tensorflow {
+
+namespace {
+
+const char* kTupleContainer = "tuples";
+
+// Counter used to assign unique handles.
+mutex _uid_mutex(tensorflow::LINKER_INITIALIZED);
+int64 _uid GUARDED_BY(_uid_mutex) = 0;
+int64 get_uid() {
+ mutex_lock l(_uid_mutex);
+ return _uid++;
+}
+
+Status AllocateScopedShapedBuffer(
+ xla::Backend* backend, int device_ordinal, const xla::Shape& shape,
+ std::unique_ptr<xla::ScopedShapedBuffer>* buffer) {
+ auto transfer_manager = backend->transfer_manager();
+ auto allocator = backend->memory_allocator();
+ TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
+
+ // XLA may use a different representation on device than the representation on
+ // the host. XLA does not document any contract for the relationship between
+ // these representations. Right now, the device shape is always a superset
+ // of the host shape, meaning that for any valid ShapeIndex in the host shape
+ // that ShapeIndex is also valid in the device shape, but not vice versa. In
+ // particular, some host-side types are rewritten to be tuples. We rely on
+ // this property when making sub-buffers, because we assume that if the client
+ // requests the host-shape sub-buffer at index i, that will correspond to the
+ // right device-shape sub-buffer at the same index.
+ xla::Shape on_device_shape = transfer_manager->HostShapeToDeviceShape(shape);
+
+ // The ScopedShapedBuffer frees the buffers that have so far been allocated if
+ // it goes out of scope. That's useful if we return early as the result of an
+ // error allocating one of the later buffers.
+ *buffer = absl::make_unique<xla::ScopedShapedBuffer>(
+ shape, on_device_shape, allocator, device_ordinal);
+ for (auto& index_to_buffer : (*buffer)->buffers()) {
+ xla::Shape subshape =
+ xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first);
+ uint64 size = transfer_manager->GetByteSizeRequirement(subshape);
+ TF_ASSIGN_OR_RETURN(
+ xla::OwningDeviceMemory buffer,
+ allocator->Allocate(device_ordinal, size, /*retry_on_failure=*/false));
+ // Move our buffer into shaped_buffer, which takes ownership of it.
+ index_to_buffer.second = buffer.Forget();
+ VLOG(2) << "Allocated buffer at " << index_to_buffer.second.opaque()
+ << " index " << index_to_buffer.first.ToString();
+ }
+
+ TF_RETURN_IF_ERROR(
+ transfer_manager->WriteTupleIndexTables(stream.get(), *(buffer->get())));
+
+ return Status::OK();
+}
+
+} // namespace
+
+XRTBufferAllocation::XRTBufferAllocation(const se::DeviceMemoryBase& allocation,
+ int device_ordinal,
+ xla::DeviceMemoryAllocator* allocator)
+ : allocation_(allocation),
+ device_ordinal_(device_ordinal),
+ allocator_(allocator) {}
+
+XRTBufferAllocation::~XRTBufferAllocation() {
+ // Deallocate explicitly allows allocation_ to be null.
+ Status s = allocator_->Deallocate(device_ordinal_, allocation_);
+ // Nothing to do but check-fail here if memory data structures are corrupted.
+ CHECK(s.ok());
+ VLOG(2) << "Freed buffer at " << allocation_.opaque();
+}
+
+const se::DeviceMemoryBase& XRTBufferAllocation::allocation() {
+ return allocation_;
+}
+
+void XRTBufferAllocation::DiscardAllocation() {
+ // Replace the allocation with a null.
+ allocation_ = se::DeviceMemoryBase();
+}
+
+XRTTupleAllocation::XRTTupleAllocation(int device_ordinal,
+ xla::DeviceMemoryAllocator* allocator,
+ const xla::Shape& on_host_shape,
+ const xla::Shape& on_device_shape)
+ : device_ordinal_(device_ordinal),
+ allocator_(allocator),
+ on_host_shape_(on_host_shape),
+ on_device_shape_(on_device_shape),
+ buffers_(&on_device_shape_) {}
+
+XRTTupleAllocation::~XRTTupleAllocation() {
+ for (auto& buffer : buffers_) {
+ buffer.second->Unref();
+ }
+}
+
+/*static*/ Status XRTTupleAllocation::CreateAndTransfer(
+ const xla::Literal& literal, xla::Backend* backend, int device_ordinal,
+ XRTTupleAllocation** allocation) {
+ auto transfer_manager = backend->transfer_manager();
+ auto allocator = backend->memory_allocator();
+
+ std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
+ TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(
+ backend, device_ordinal, literal.shape(), &scoped_buffer));
+ TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
+ TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
+ stream.get(), literal, *scoped_buffer));
+
+ // By releasing the ScopedShapedBuffer we ensure that the underlying storage
+ // won't be freed when the buffer goes out of scope at the end of this
+ // call. To avoid a leak, there must be no error-case returns from here until
+ // the end of the method.
+ auto shaped_buffer = scoped_buffer->release();
+ *allocation = new XRTTupleAllocation(device_ordinal, allocator,
+ shaped_buffer.on_host_shape(),
+ shaped_buffer.on_device_shape());
+ (*allocation)
+ ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal);
+ return Status::OK();
+}
+
+/*static*/ Status XRTTupleAllocation::CreateFromBuffer(
+ const xla::ShapedBuffer& shaped_buffer, xla::Backend* backend,
+ int device_ordinal, XRTTupleAllocation** allocation) {
+ auto allocator = backend->memory_allocator();
+
+ *allocation = new XRTTupleAllocation(device_ordinal, allocator,
+ shaped_buffer.on_host_shape(),
+ shaped_buffer.on_device_shape());
+ (*allocation)
+ ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal);
+ return Status::OK();
+}
+
+Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal,
+ xla::Literal* literal) {
+ auto transfer_manager = backend->transfer_manager();
+ TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
+ TF_ASSIGN_OR_RETURN(*literal, transfer_manager->TransferLiteralFromDevice(
+ stream.get(), ToShapedBuffer()));
+ return Status::OK();
+}
+
+void XRTTupleAllocation::DiscardAllocation(
+ const xla::ShapeIndex& buffer_index) {
+ buffers_.element(buffer_index)->DiscardAllocation();
+}
+
+const xla::Shape& XRTTupleAllocation::on_host_shape() { return on_host_shape_; }
+
+const xla::Shape& XRTTupleAllocation::on_device_shape() {
+ return on_device_shape_;
+}
+
+int XRTTupleAllocation::device_ordinal() { return device_ordinal_; }
+
+const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() {
+ return buffers_.element({})->allocation();
+}
+
+/*static*/ Status XRTTupleAllocation::Lookup(ResourceMgr* rm, int64 key,
+ XRTTupleAllocation** allocation) {
+ string key_string = absl::StrCat(key);
+ TF_RETURN_IF_ERROR(rm->Lookup(kTupleContainer, key_string, allocation));
+ return Status::OK();
+}
+
+/*static*/ Status XRTTupleAllocation::DeleteFromResourceManager(ResourceMgr* rm,
+ int64 key) {
+ string key_string = absl::StrCat(key);
+ return rm->Delete<XRTTupleAllocation>(kTupleContainer, key_string);
+}
+
+// Helper typedef to make ShapeTree ForEach helper lambda signatures more
+// readable. They need a parameter of type const T&, where in this case T is
+// the following pointer type.
+typedef XRTBufferAllocation* XRTBufferAllocationPtr;
+
+/*static*/ Status XRTTupleAllocation::MakeSubBuffer(
+ XRTTupleAllocation* parent, const xla::ShapeIndex& subshape,
+ XRTTupleAllocation** allocation, bool alias_parent_allocation) {
+ TF_ASSIGN_OR_RETURN(
+ const xla::Shape* host_sub_shape,
+ xla::ShapeUtil::TryGetSubshape(parent->on_host_shape(), subshape));
+ TF_ASSIGN_OR_RETURN(
+ const xla::Shape* device_sub_shape,
+ xla::ShapeUtil::TryGetSubshape(parent->on_device_shape(), subshape));
+
+ *allocation =
+ new XRTTupleAllocation(parent->device_ordinal(), parent->allocator_,
+ *host_sub_shape, *device_sub_shape);
+ if (alias_parent_allocation) {
+ // Copy the subtree of allocations from the parent allocation.
+ (*allocation)->buffers_.CopySubtreeFrom(parent->buffers_, subshape, {});
+ // Increment the refcount on each aliased buffer.
+ (*allocation)
+ ->buffers_.ForEachElement(
+ [](const xla::ShapeIndex& index,
+ const XRTBufferAllocationPtr& buffer) { buffer->Ref(); });
+ } else {
+ // Find the buffers in the parent allocation that match the subtree, and
+ // move the parent allocation's buffer over to the new allocation.
+ (*allocation)
+ ->buffers_.ForEachMutableElement(
+ [&](const xla::ShapeIndex& index, XRTBufferAllocationPtr* buffer) {
+ // Extend the allocation's index to the parent's frame by adding
+ // subshape as a prefix.
+ xla::ShapeIndex parent_index = subshape;
+ for (int i = 0; i < index.size(); ++i) {
+ parent_index.push_back(index[i]);
+ }
+ *buffer = parent->buffers_.element(parent_index);
+ *parent->buffers_.mutable_element(parent_index) =
+ new XRTBufferAllocation(se::DeviceMemoryBase(),
+ parent->device_ordinal(),
+ parent->allocator_);
+ });
+ }
+
+ return Status::OK();
+}
+
+/* static */ Status XRTTupleAllocation::ExpandTreeOfTuples(
+ const xla::ShapeTree<ExpandedTupleInput>& elements, int device_ordinal,
+ xla::DeviceMemoryAllocator* allocator, xla::Shape* host_shape,
+ xla::Shape* device_shape) {
+ // Initialize both host and device shape to be the 'spine' of the new tuple
+ // shape, given by the shape of the tree of tuples.
+ *host_shape = elements.shape();
+ *device_shape = elements.shape();
+ // Now go over the leaves of the tree of tuples, and 'graft' the host/device
+ // shapes of the allocation at that leaf onto the expanded host/device shapes
+ // at the leaf position.
+ TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus(
+ [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) {
+ if (elements.IsLeaf(index)) {
+ if (element.allocation == nullptr) {
+ return errors::InvalidArgument(
+ "MakeTuple elements has a null internal node at index ",
+ index.ToString());
+ }
+ if (device_ordinal != element.allocation->device_ordinal() ||
+ allocator != element.allocation->allocator_) {
+ return errors::InvalidArgument(
+ "MakeTuple elements must all be allocated on the same device "
+ "as the destination.");
+ }
+ *xla::ShapeUtil::GetMutableSubshape(host_shape, index) =
+ element.allocation->on_host_shape();
+ *xla::ShapeUtil::GetMutableSubshape(device_shape, index) =
+ element.allocation->on_device_shape();
+ } else {
+ if (element.allocation != nullptr) {
+ return errors::InvalidArgument(
+ "MakeTuple elements has a non-null internal node at index ",
+ index.ToString());
+ }
+ }
+ return Status::OK();
+ }));
+ return Status::OK();
+}
+
+/*static*/ Status XRTTupleAllocation::MakeTuple(
+ xla::Backend* backend, int device_ordinal,
+ const xla::ShapeTree<ExpandedTupleInput>& elements,
+ XRTTupleAllocation** allocation) {
+ auto transfer_manager = backend->transfer_manager();
+ auto allocator = backend->memory_allocator();
+ TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
+
+ xla::Shape host_shape;
+ xla::Shape device_shape;
+ TF_RETURN_IF_ERROR(ExpandTreeOfTuples(elements, device_ordinal, allocator,
+ &host_shape, &device_shape));
+
+ // The aliasing is determined below based on whether or not all the inputs are
+ // released while being transferred. allocation_tmp is a local pointer that is
+ // copied to *allocation at the end only if the method succeeds.
+ auto allocation_tmp = new XRTTupleAllocation(device_ordinal, allocator,
+ host_shape, device_shape);
+ core::ScopedUnref allocation_unref(allocation_tmp);
+ // First allocate device memory for the new tuple index tables, one at each
+ // internal node of the elements tree. Do this in a separate pass into a
+ // ScopedShapedBuffer so that it's easy to free the newly-allocated memory if
+ // an allocation fails. Make sure the shape has a layout so that the code
+ // lower down that writes index tables will be happy.
+ xla::Shape spine_shape = elements.shape();
+ xla::LayoutUtil::SetToDefaultLayout(&spine_shape);
+ auto new_tuple_buffers = absl::make_unique<xla::ScopedShapedBuffer>(
+ spine_shape, spine_shape, allocator, device_ordinal);
+ TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus(
+ [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) {
+ if (!elements.IsLeaf(index)) {
+ xla::Shape subshape =
+ xla::ShapeUtil::GetSubshape(device_shape, index);
+ uint64 size = transfer_manager->GetByteSizeRequirement(subshape);
+ TF_ASSIGN_OR_RETURN(xla::OwningDeviceMemory buffer,
+ allocator->Allocate(device_ordinal, size,
+ /*retry_on_failure=*/false));
+ VLOG(2) << "Allocated buffer at " << buffer.opaque() << " index "
+ << index.ToString();
+ // Move the new buffer into new_tuple_buffers, which takes ownership
+ // of it.
+ new_tuple_buffers->set_buffer(std::move(buffer), index);
+ }
+ return Status::OK();
+ }));
+ // Transfer from the ScopedShapedBuffer to a ShapedBuffer, which does not own
+ // the newly-allocated index tables. Right now there's no owner for the new
+ // index tables, so next we will transfer ownership to the new allocation,
+ // taking care not to return early on any errors in the meantime.
+ xla::ShapedBuffer tuple_buffers = new_tuple_buffers->release();
+ // Now fill in the remaining data structures. After this ForEachElement
+ // completes:
+ // 1) Every leaf element of tuple_buffers will be the root buffer of
+ // an existing allocation, and every internal element of tuple_buffers
+ // will be a newly-allocated index table. tuple_buffers does not own any
+ // of these.
+ // 2) Every element of allocation_tmp->buffers_ will be a correctly
+ // constructed XRTBufferAllocation wrapping the necessary allocations. For
+ // buffers in existing allocations there will be a new reference owned by the
+ // new allocation, and for newly-allocated index tables there will be a
+ // single reference owned by the new allocation.
+ elements.ForEachElement([&](const xla::ShapeIndex& index,
+ const ExpandedTupleInput& element) {
+ if (elements.IsLeaf(index)) {
+ allocation_tmp->buffers_.CopySubtreeFrom(element.allocation->buffers_, {},
+ index);
+ tuple_buffers.set_buffer(element.allocation->root_allocation(), index);
+ if (element.release_allocation_after_use) {
+ // Transfer the references from element's buffers to the new allocation
+ // rather than incrementing the refcount. The caller should have
+ // validated that release_allocation_after_use is false if
+ // element.allocation appears in more than one leaf.
+ element.allocation->buffers_.ForEachMutableElement(
+ [&](const xla::ShapeIndex& index, XRTBufferAllocationPtr* buffer) {
+ *buffer = new XRTBufferAllocation(
+ se::DeviceMemoryBase(), element.allocation->device_ordinal(),
+ element.allocation->allocator_);
+ });
+ } else {
+ // Increment the refcount on each newly-aliased buffer.
+ element.allocation->buffers_.ForEachElement(
+ [](const xla::ShapeIndex& index,
+ const XRTBufferAllocationPtr& buffer) { buffer->Ref(); });
+ }
+ } else {
+ // This is an internal node of the tuple tree so take ownership of the
+ // newly-created index table.
+ *allocation_tmp->buffers_.mutable_element(index) =
+ new XRTBufferAllocation(tuple_buffers.buffer(index), device_ordinal,
+ allocator);
+ }
+ });
+ // Because the internal nodes of tuple_buffers are exactly the new index
+ // tables, WriteTupleIndexTables will write only the new index tables and not
+ // rewrite the index tables for the existing allocations.
+ TF_RETURN_IF_ERROR(
+ transfer_manager->WriteTupleIndexTables(stream.get(), tuple_buffers));
+
+ *allocation = allocation_tmp;
+ // Get another reference since allocation_tmp will be Unrefed automatically on
+ // exit.
+ (*allocation)->Ref();
+ return Status::OK();
+}
+
+Status XRTTupleAllocation::Intern(ResourceMgr* rm, int64* key) {
+ *key = get_uid();
+ string key_string = absl::StrCat(*key);
+ return rm->Create(kTupleContainer, key_string, this);
+}
+
+bool XRTTupleAllocation::IsExclusiveOwner() {
+ for (const auto& buffer : buffers_) {
+ if (!buffer.second->RefCountIsOne()) return false;
+ }
+ return true;
+}
+
+void XRTTupleAllocation::InitializeFromShapedBuffer(
+ const xla::ShapedBuffer& shaped_buffer,
+ xla::DeviceMemoryAllocator* allocator, int device_ordinal) {
+ for (auto& buffer : buffers_) {
+ // Make a reference-counted version of the allocated buffer.
+ buffer.second = new XRTBufferAllocation(shaped_buffer.buffer(buffer.first),
+ device_ordinal, allocator);
+ }
+}
+
+xla::ShapedBuffer XRTTupleAllocation::ToShapedBuffer() {
+ xla::ShapedBuffer shaped_buffer(on_host_shape(), on_device_shape(),
+ allocator_->platform(), device_ordinal_);
+ for (const auto& buffer : buffers_) {
+ shaped_buffer.set_buffer(buffer.second->allocation(), buffer.first);
+ }
+ return shaped_buffer;
+}
+
+xla::ShapeTree<xla::MaybeOwningDeviceMemory>
+XRTTupleAllocation::ToDeviceMemoryTree(bool release) {
+ xla::ShapeTree<xla::MaybeOwningDeviceMemory> shaped_tree(on_device_shape());
+ for (const auto& buffer : buffers_) {
+ if (!release) {
+ *shaped_tree.mutable_element(buffer.first) = buffer.second->allocation();
+ } else {
+ *shaped_tree.mutable_element(buffer.first) = xla::OwningDeviceMemory(
+ buffer.second->allocation(), device_ordinal_, allocator_);
+ DiscardAllocation(buffer.first);
+ }
+ }
+ return shaped_tree;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h
new file mode 100644
index 0000000000..73b5584e38
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_state.h
@@ -0,0 +1,208 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Classes for keeping track of on-device state.
+
+#ifndef TENSORFLOW_COMPILER_XRT_XRT_STATE_H_
+#define TENSORFLOW_COMPILER_XRT_XRT_STATE_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+
+namespace tensorflow {
+
+// TODO(misard) make this a Tensor if and when that makes sense.
+// A reference-counted wrapper around a buffer allocation. This maps an XLA
+// tuple index or a non-tuple XLA shape to a region of device memory. The device
+// memory buffer is freed when the reference count drops to zero.
+class XRTBufferAllocation : public core::RefCounted {
+ public:
+ XRTBufferAllocation(const se::DeviceMemoryBase& allocation,
+ int device_ordinal,
+ xla::DeviceMemoryAllocator* allocator);
+ ~XRTBufferAllocation() override;
+
+ // The region of device memory being wrapped.
+ const se::DeviceMemoryBase& allocation();
+
+ // Sets the DeviceMemoryBase to be null. DiscardAllocation should be called
+ // when ownership of the underlying buffer has been transferred, e.g., to an
+ // output buffer when input and output buffers are aliased during
+ // execution. The call to DiscardAllocation prevents any device buffer being
+ // freed when the reference count drops to zero.
+ void DiscardAllocation();
+
+ private:
+ se::DeviceMemoryBase allocation_;
+ int device_ordinal_;
+ xla::DeviceMemoryAllocator* allocator_;
+};
+
+// Entry in the resource manager corresponding to an allocation handle returned
+// to a client. The handle identifies an immutable tuple of data in device
+// memory. New handles can be created in three ways: by passing a literal in
+// which case device memory is allocated and the literal is transferred to that
+// memory; by aliasing a sub-shape of an existing tuple-shaped handle; or by
+// aliasing a vector of existing handles to create a new tuple. The underlying
+// storage is reference-counted. When a handle is released, the reference count
+// of each storage buffer is decremented, and buffers with no outstanding
+// references are freed.
+class XRTTupleAllocation : public ResourceBase {
+ public:
+ ~XRTTupleAllocation() override;
+
+ // Allocates new device memory buffers sufficient to store literal, transfers
+ // literal to that memory, and returns a XRTTupleAllocation handle to the
+ // allocated buffers.
+ static Status CreateAndTransfer(const xla::Literal& literal,
+ xla::Backend* backend, int device_ordinal,
+ XRTTupleAllocation** allocation);
+
+ // Wraps an existing ShapedBuffer in a new XRTTupleAllocation handle.
+ static Status CreateFromBuffer(const xla::ShapedBuffer& shaped_buffer,
+ xla::Backend* backend, int device_ordinal,
+ XRTTupleAllocation** allocation);
+
+ // Aliases a sub-shape of parent and returns a XRTTupleAllocation handle
+ // to the sub-shape. If alias_parent_allocation is true, the buffers in the
+ // sub-shape will be shared between parent and the returned allocation,
+ // otherwise the overlapping buffers in parent will be replaced by
+ // nullptr.
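+ //
+ // For example, an illustrative aliasing of element {1} of a tuple-shaped
+ // allocation 'parent':
+ //
+ // XRTTupleAllocation* sub;
+ // TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
+ // parent, /*subshape=*/{1}, &sub, /*alias_parent_allocation=*/true));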
+ static Status MakeSubBuffer(XRTTupleAllocation* parent,
+ const xla::ShapeIndex& subshape,
+ XRTTupleAllocation** allocation,
+ bool alias_parent_allocation);
+
+ // A structure describing a leaf of a tree of tuples to expand. Each leaf
+ // contains an allocation and indicates whether or not the allocation's handle
+ // should be freed after incorporating its buffers into the expanded tree.
+ struct ExpandedTupleInput {
+ XRTTupleAllocation* allocation;
+ bool release_allocation_after_use;
+ };
+
+ // Returns a handle to a new tuple where the subtree of the new tuple at an
+ // index corresponding to a leaf of 'elements' is constructed from the
+ // allocation (i.e., a tuple or array) pointed to by that leaf. If
+ // release_allocation_after_use is false at a leaf, the new tuple will alias
+ // the input allocation at that leaf, otherwise the input allocation will be
+ // released. Input allocations may be repeated (appear in more than one leaf)
+ // in which case the corresponding buffers in the output tuple will alias. If
+ // an input is repeated, release_allocation_after_use must be false for every
+ // leaf where that input appears. The latter property is not validated by
+ // MakeTuple and must be enforced by the caller.
+ static Status MakeTuple(xla::Backend* backend, int device_ordinal,
+ const xla::ShapeTree<ExpandedTupleInput>& elements,
+ XRTTupleAllocation** allocation);
+
+ // Retrieves the allocation interned under key from rm. The caller owns a
+ // reference to allocation after looking it up.
+ static Status Lookup(ResourceMgr* rm, int64 key,
+ XRTTupleAllocation** allocation);
+
+ // Deletes the reference in the rm to an allocation interned under key.
+ static Status DeleteFromResourceManager(ResourceMgr* rm, int64 key);
+
+ // Adds the allocation to a ResourceMgr and returns the key that will be used
+ // to retrieve it. Transfers a reference on *this to rm.
+ Status Intern(ResourceMgr* rm, int64* key);
+
+ // Copies the allocation from device to host and returns it in literal.
+ Status ToLiteral(xla::Backend* backend, int device_ordinal,
+ xla::Literal* literal);
+
+ // True if none of the buffers in the allocation are aliased by any other live
+ // handle.
+ bool IsExclusiveOwner();
+
+ // The ordinal of the device holding this tuple.
+ int device_ordinal();
+
+ // Returns the shape of the tuple as seen by the host.
+ const xla::Shape& on_host_shape();
+
+ // Returns the shape of the tuple as stored on the device.
+ const xla::Shape& on_device_shape();
+
+ // Returns the buffer pointed to by the root of the tuple.
+ const se::DeviceMemoryBase& root_allocation();
+
+ // Stops managing the storage for the allocation at buffer_index, e.g.,
+ // because it has been aliased to the output buffer of a computation.
+ void DiscardAllocation(const xla::ShapeIndex& buffer_index);
+
+ // Returns the tree of allocations as a ShapedBuffer. This tree may not have
+ // the same shape as on_host_shape.
+ xla::ShapedBuffer ToShapedBuffer();
+
+ // Returns the device memory tree of this allocation. If 'release' is set, the
+ // ownership of the device memory is transferred to the result.
+ xla::ShapeTree<xla::MaybeOwningDeviceMemory> ToDeviceMemoryTree(bool release);
+
+ string DebugString() override { return "XLA allocation handle"; }
+
+ private:
+ // Creates a new handle with (tuple) shape.
+ XRTTupleAllocation(int device_ordinal, xla::DeviceMemoryAllocator* allocator,
+ const xla::Shape& on_host_shape,
+ const xla::Shape& on_device_shape);
+
+ // Inherits the allocations represented in shaped_buffer, which must have the
+ // same shape as buffers_.
+ void InitializeFromShapedBuffer(const xla::ShapedBuffer& shaped_buffer,
+ xla::DeviceMemoryAllocator* allocator,
+ int device_ordinal);
+
+ // Takes a tree 'elements' where each leaf is an allocation, validates that
+ // they all reside on device_ordinal and are managed by allocator, and
+ // returns in host_shape and device_shape the host/device shapes of the
+ // expanded tree, where the shape of the allocation at each leaf is grafted
+ // on.
+ static Status ExpandTreeOfTuples(
+ const xla::ShapeTree<ExpandedTupleInput>& elements, int device_ordinal,
+ xla::DeviceMemoryAllocator* allocator, xla::Shape* host_shape,
+ xla::Shape* device_shape);
+
+ // Location of the memory that is being managed.
+ int device_ordinal_;
+ xla::DeviceMemoryAllocator* allocator_;
+
+ // The shape that the caller thinks the tuple has.
+ const xla::Shape on_host_shape_;
+ // The shape that the tuple has on device. Store this explicitly instead of
+ // using a shape stored in ShapeTree because ShapeTree discards the layout.
+ const xla::Shape on_device_shape_;
+ // The tree of reference-counted buffers, which uses on_device_shape_ as its
+ // shape.
+ xla::ShapeTree<XRTBufferAllocation*> buffers_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_XRT_XRT_STATE_H_
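To make the ownership rules above concrete, here is a minimal Python sketch of the reference-counting bookkeeping they describe. Buffer and TupleHandle are hypothetical stand-ins for illustration only; they are not XRT types, and the real classes track buffers per ShapeIndex through a DeviceMemoryAllocator.

# Illustrative-only model of the refcounting contract; not XRT code.
class Buffer(object):
  def __init__(self, name):
    self.name, self.refcount, self.freed = name, 0, False

class TupleHandle(object):
  def __init__(self, buffers):
    self.buffers = list(buffers)
    for b in self.buffers:
      b.refcount += 1

  def release(self):
    # Releasing a handle decrements each buffer's count; buffers with no
    # outstanding references are freed.
    for b in self.buffers:
      b.refcount -= 1
      if b.refcount == 0:
        b.freed = True

a, b = Buffer('a'), Buffer('b')
parent = TupleHandle([a, b])     # e.g. created by CreateAndTransfer
alias = TupleHandle([a])         # e.g. MakeSubBuffer with aliasing enabled
parent.release()                 # 'b' is freed; 'a' survives via the alias
assert b.freed and not a.freed
alias.release()
assert a.freed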
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 23bb783e22..798f499870 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -55,7 +55,6 @@ py_library(
"//tensorflow/contrib/integrate:integrate_py",
"//tensorflow/contrib/keras",
"//tensorflow/contrib/kernel_methods",
- "//tensorflow/contrib/kfac",
"//tensorflow/contrib/labeled_tensor",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/learn",
@@ -64,6 +63,7 @@ py_library(
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/contrib/linear_optimizer:sdca_estimator_py",
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
+ "//tensorflow/contrib/lite/python:lite",
"//tensorflow/contrib/lookup:lookup_py",
"//tensorflow/contrib/losses:losses_py",
"//tensorflow/contrib/losses:metric_learning_py",
@@ -129,13 +129,8 @@ py_library(
]) + if_not_windows([
"//tensorflow/contrib/bigtable", # depends on bigtable
"//tensorflow/contrib/cloud:cloud_py", # doesn't compile on Windows
- "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
- # TODO(aaroey): tensorrt dependency has to appear before tflite so the
- # build can resolve its flatbuffers symbols within the tensorrt library.
- # This is an issue with the tensorrt static library and will be fixed by
- # the next tensorrt release, so fix the order here after that.
"//tensorflow/contrib/tensorrt:init_py", # doesn't compile on windows
- "//tensorflow/contrib/lite/python:lite", # unix dependency, need to fix code
+ "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
]),
)
@@ -181,6 +176,7 @@ cc_library(
"//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib",
"//tensorflow/contrib/coder:all_ops",
"//tensorflow/contrib/data:dataset_ops_op_lib",
+ "//tensorflow/contrib/data:indexed_dataset_ops_op_lib",
"//tensorflow/contrib/factorization:all_ops",
"//tensorflow/contrib/framework:all_ops",
"//tensorflow/contrib/hadoop:dataset_ops_op_lib",
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index e18ea8df4d..9478e42b46 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -21,6 +21,14 @@ from __future__ import print_function
import os
+from tensorflow.python.tools import component_api_helper
+component_api_helper.package_hook(
+ parent_package_str=(
+ "tensorflow.contrib"),
+ child_package_str=(
+ "tensorflow_estimator.contrib.estimator"))
+del component_api_helper
+
# Add projects here, they will show up under tf.contrib.
from tensorflow.contrib import autograph
from tensorflow.contrib import batching
@@ -51,7 +59,6 @@ from tensorflow.contrib import input_pipeline
from tensorflow.contrib import integrate
from tensorflow.contrib import keras
from tensorflow.contrib import kernel_methods
-from tensorflow.contrib import kfac
from tensorflow.contrib import labeled_tensor
from tensorflow.contrib import layers
from tensorflow.contrib import learn
@@ -94,8 +101,7 @@ from tensorflow.contrib import tpu
from tensorflow.contrib import training
from tensorflow.contrib import util
from tensorflow.contrib.eager.python import tfe as eager
-if os.name != "nt":
- from tensorflow.contrib.lite.python import lite
+from tensorflow.contrib.lite.python import lite
from tensorflow.contrib.optimizer_v2 import optimizer_v2_symbols as optimizer_v2
from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field
from tensorflow.contrib.recurrent.python import recurrent_api as recurrent
diff --git a/tensorflow/contrib/android/asset_manager_filesystem.cc b/tensorflow/contrib/android/asset_manager_filesystem.cc
index 513d519eab..d14b2126a0 100644
--- a/tensorflow/contrib/android/asset_manager_filesystem.cc
+++ b/tensorflow/contrib/android/asset_manager_filesystem.cc
@@ -28,7 +28,7 @@ string RemoveSuffix(const string& name, const string& suffix) {
string output(name);
StringPiece piece(output);
str_util::ConsumeSuffix(&piece, suffix);
- return piece.ToString();
+ return string(piece);
}
// Closes the given AAsset when variable is destructed.
@@ -231,7 +231,7 @@ string AssetManagerFileSystem::NormalizeDirectoryPath(const string& fname) {
string AssetManagerFileSystem::RemoveAssetPrefix(const string& name) {
StringPiece piece(name);
str_util::ConsumePrefix(&piece, prefix_);
- return piece.ToString();
+ return string(piece);
}
bool AssetManagerFileSystem::DirectoryExists(const std::string& fname) {
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py
index b26c52294c..29dce13999 100644
--- a/tensorflow/contrib/autograph/converters/builtin_functions.py
+++ b/tensorflow/contrib/autograph/converters/builtin_functions.py
@@ -21,6 +21,8 @@ from __future__ import print_function
import gast
from tensorflow.contrib.autograph.core import converter
+from tensorflow.contrib.autograph.operators import py_builtins
+from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import templates
@@ -31,41 +33,32 @@ class BuiltinFunctionTransformer(converter.Base):
TF equivalent, like `len`.
"""
- def _convert_builtin(self, node):
+ def _convert_builtin(self, f, args, as_expression):
template = """
- ag__.utils.dynamic_builtin(func, args)
+ ag__.func(args)
"""
- return templates.replace(template, func=node.func, args=node.args)[0].value
-
- def _convert_print(self, node):
- template = """
- ag__.utils.dynamic_print(args)
- """
- return templates.replace(template, args=node.args)[0].value
+ if as_expression:
+ return templates.replace_as_expression(
+ template, func=py_builtins.overload_of(f).__name__, args=args)
+ else:
+ return templates.replace(
+ template, func=py_builtins.overload_of(f).__name__, args=args)
def visit_Call(self, node):
- self.generic_visit(node)
- # TODO(mdan): This won't work if the function was hidden.
- # TODO(mdan): Rely on the live_val and use inspect_utils.is_builtin instead.
- if (isinstance(node.func, gast.Name) and
- node.func.id in ('len', 'range', 'xrange', 'float', 'int')):
- return self._convert_builtin(node)
- # Print needs to be handled separately because it can be read as statement.
- if isinstance(node.func, gast.Name) and node.func.id == 'print':
- return self._convert_print(node)
+ node = self.generic_visit(node)
+ if anno.hasanno(node.func, 'live_val'):
+ live_val = anno.getanno(node.func, 'live_val')
+ if live_val in py_builtins.SUPPORTED_BUILTINS:
+ node = self._convert_builtin(live_val, node.args, as_expression=True)
return node
def visit_Print(self, node):
- self.generic_visit(node)
+ node = self.generic_visit(node)
args = node.values
# Following is the case when calling print(a, b)
if len(args) == 1 and isinstance(args[0], gast.Tuple):
args = args[0].elts
- template = """
- fname(args)
- """
- function_call = templates.replace(template, fname='print', args=args)[0]
- return self.visit(function_call)
+ return self._convert_builtin(print, args, as_expression=False)
def transform(node, ctx):
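A hedged sketch of the effect of this converter: calls to supported builtins are rewritten so they go through the overload table in py_builtins (compiled code references it via the ag__ alias; the direct overload_of call below is a stand-in for the template output, not the verbatim generated code).

# What a converted `len(a)` does at runtime, schematically.
from tensorflow.contrib.autograph.operators import py_builtins

def converted_len(a):
  # overload_of maps the builtin `len` to py_builtins.len_.
  return py_builtins.overload_of(len)(a)

print(converted_len([1, 2, 3]))  # plain Python values still work: 3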
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py
index d5c3e2c250..3e3a04f38b 100644
--- a/tensorflow/contrib/autograph/converters/builtin_functions_test.py
+++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py
@@ -23,6 +23,7 @@ import six
from tensorflow.contrib.autograph.converters import builtin_functions
from tensorflow.contrib.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -34,11 +35,11 @@ class BuiltinFunctionsTest(converter_testing.TestCase):
def test_fn(a):
return len(a)
- with self.converted(test_fn, builtin_functions, {'len': len},
- array_ops.shape) as result:
- with self.test_session() as sess:
- ops = result.test_fn(constant_op.constant([0, 0, 0]))
- self.assertEqual(sess.run(ops), 3)
+ with self.converted(test_fn, builtin_functions, {'len': len}) as result:
+ with self.cached_session() as sess:
+ p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
+ ops = result.test_fn(p)
+ self.assertEqual(sess.run(ops, {p: [0, 0, 0]}), 3)
def test_print(self):
@@ -49,7 +50,7 @@ class BuiltinFunctionsTest(converter_testing.TestCase):
return print(a)
with self.converted(test_fn, builtin_functions, {'print': print}) as result:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertPrints('a\n'):
sess.run(result.test_fn('a'))
@@ -62,7 +63,7 @@ class BuiltinFunctionsTest(converter_testing.TestCase):
return print(a, b, c)
with self.converted(test_fn, builtin_functions, {'print': print}) as result:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertPrints('a 1 [2, 3]\n'):
sess.run(
result.test_fn(
diff --git a/tensorflow/contrib/autograph/converters/call_trees_test.py b/tensorflow/contrib/autograph/converters/call_trees_test.py
index 8cdba659ee..ca4d1f2932 100644
--- a/tensorflow/contrib/autograph/converters/call_trees_test.py
+++ b/tensorflow/contrib/autograph/converters/call_trees_test.py
@@ -91,7 +91,7 @@ class CallTreesTest(converter_testing.TestCase):
setattr(a, 'foo', 'bar')
with self.converted(test_fn, call_trees, {'setattr': setattr}) as result:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
class Dummy(object):
pass
@@ -110,7 +110,7 @@ class CallTreesTest(converter_testing.TestCase):
with self.converted(test_fn, call_trees, {'np': np},
dtypes.int64) as result:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertTrue(isinstance(result.test_fn(), ops.Tensor))
self.assertIn(sess.run(result.test_fn()), (0, 1, 2))
@@ -129,7 +129,7 @@ class CallTreesTest(converter_testing.TestCase):
node = call_trees.transform(node, ctx)
with self.compiled(node, ns) as result:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result_tensor = result.test_fn(constant_op.constant(1))
self.assertEquals(sess.run(result_tensor), 3)
diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py
index 5a5a6ad63a..3530fbb2ec 100644
--- a/tensorflow/contrib/autograph/converters/control_flow.py
+++ b/tensorflow/contrib/autograph/converters/control_flow.py
@@ -95,6 +95,18 @@ class ControlFlowTransformer(converter.Base):
return 'no variables'
return ', '.join(map(str, symbol_set))
+ def _validate_no_live_vars_created(self, node):
+ body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
+ live_vars_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
+ live_vars_created_in_body = live_vars_out & body_scope.created
+ if live_vars_created_in_body:
+ raise ValueError(
+ 'The following variables are created inside the loop and used later:'
+ '\n%s\n'
+ 'Variables must be declared outside loops because loops may not'
+ ' necessarily execute.' % self._fmt_symbol_list(
+ live_vars_created_in_body))
+
def visit_If(self, node):
node = self.generic_visit(node)
@@ -197,13 +209,15 @@ class ControlFlowTransformer(converter.Base):
def visit_While(self, node):
self.generic_visit(node)
+ self._validate_no_live_vars_created(node)
+
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
body_closure = body_scope.modified - body_scope.created
all_referenced = body_scope.referenced
cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE)
cond_closure = set()
- for s in cond_scope.referenced:
+ for s in cond_scope.used:
for root in s.support_set:
if root not in body_scope.created:
cond_closure.add(root)
@@ -236,6 +250,7 @@ class ControlFlowTransformer(converter.Base):
node_body = ast_util.rename_symbols(node.body, ssf_map)
test = ast_util.rename_symbols(node.test, ssf_map)
+ # TODO(b/113118541) investigate the need-for and correctness-of extra_deps
template = """
def test_name(state_ssf):
return test
@@ -262,6 +277,8 @@ class ControlFlowTransformer(converter.Base):
def visit_For(self, node):
self.generic_visit(node)
+ self._validate_no_live_vars_created(node)
+
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
body_closure = body_scope.modified - body_scope.created
all_referenced = body_scope.referenced
@@ -294,7 +311,9 @@ class ControlFlowTransformer(converter.Base):
template = """
def extra_test_name(state_ssf):
return extra_test_expr
- def body_name(iterate, state_ssf):
+ def body_name(loop_vars, state_ssf):
+ # Workaround for PEP-3113
+ iterate = loop_vars
body
return state_ssf,
state_ast_tuple = ag__.for_stmt(
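The check added here rejects loops that create a variable which is still live after the loop, because the staged loop body may execute zero times. A minimal plain-Python sketch of the failing and passing shapes, mirroring the bad_while_loop/bad_for_loop cases in the updated tests:

# Sketch of the restriction enforced by _validate_no_live_vars_created.
def bad(n):
  while n > 0:
    n -= 1
    s = n        # created inside the loop body...
  return s       # ...and used afterwards: rejected, the loop may not run

def good(n):
  s = 0          # declared before the loop, so defined on all paths
  while n > 0:
    n -= 1
    s = n
  return s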
diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py
index ade3501426..1d04ba3ba6 100644
--- a/tensorflow/contrib/autograph/converters/control_flow_test.py
+++ b/tensorflow/contrib/autograph/converters/control_flow_test.py
@@ -33,7 +33,7 @@ class ControlFlowTest(converter_testing.TestCase):
inputs = (inputs,)
with self.converted(test_fn, control_flow, {},
constant_op.constant) as result:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(result.test_fn(*inputs)), expected)
def test_while_basic(self):
@@ -48,6 +48,24 @@ class ControlFlowTest(converter_testing.TestCase):
self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 5, 5))
+ def test_while_nested(self):
+
+ def test_fn(n):
+ i = 0
+ j = 0
+ s = 0
+ while i < n:
+ while j < i:
+ j += 3
+ u = i + j # 'u' is not defined within the inner loop
+ s += u
+ i += 1
+ j = 0
+ return s, i, j, n
+
+ self.assertTransformedResult(test_fn, constant_op.constant(5),
+ (25, 5, 0, 5))
+
def test_while_single_output(self):
def test_fn(n):
@@ -57,6 +75,17 @@ class ControlFlowTest(converter_testing.TestCase):
self.assertTransformedResult(test_fn, constant_op.constant(5), 0)
+ def test_while_variable_defined_in_body(self):
+ def bad_while_loop(n):
+ while n > 0:
+ n -= 1
+ s = n
+ return s
+
+ node, ctx = self.prepare(bad_while_loop, {})
+ with self.assertRaises(transformer.AutographParseError):
+ control_flow.transform(node, ctx)
+
def test_if_basic(self):
def test_fn(n):
@@ -89,7 +118,7 @@ class ControlFlowTest(converter_testing.TestCase):
return obj
with self.converted(test_fn, control_flow, {}) as result:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
res_obj = result.test_fn(constant_op.constant(1), TestClass(0, 0))
self.assertEqual(sess.run((res_obj.a, res_obj.b)), (-1, 0))
res_obj = result.test_fn(constant_op.constant(-1), TestClass(0, 0))
@@ -196,6 +225,23 @@ class ControlFlowTest(converter_testing.TestCase):
self.assertEqual(result.test_fn(5), 10)
self.assertEqual(eval_count[0], 1)
+ def test_for_variable_defined_in_body(self):
+ def bad_for_loop(n):
+ for i in range(n):
+ s = i
+ return s
+
+ node, ctx = self.prepare(bad_for_loop, {})
+ with self.assertRaises(transformer.AutographParseError):
+ control_flow.transform(node, ctx)
+
+ def test_for_tuple_unpacking(self):
+ def test_fn(x_list):
+ z = tf.constant(0) # pylint:disable=undefined-variable
+ for i, x in enumerate(x_list):
+ z = z + x + i
+ return z
+ self.assertTransformedResult(test_fn, [3, 3], 7)
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/converters/lists_test.py b/tensorflow/contrib/autograph/converters/lists_test.py
index 996e99ee61..c5e2dcf75e 100644
--- a/tensorflow/contrib/autograph/converters/lists_test.py
+++ b/tensorflow/contrib/autograph/converters/lists_test.py
@@ -65,7 +65,7 @@ class ListTest(converter_testing.TestCase):
ns = {'special_functions': special_functions}
with self.converted(test_fn, lists, ns) as result:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tl = result.test_fn()
r = list_ops.tensor_list_stack(tl, dtypes.int32)
self.assertAllEqual(sess.run(r), [1, 2, 3])
@@ -88,7 +88,7 @@ class ListTest(converter_testing.TestCase):
node = lists.transform(node, ctx)
with self.compiled(node, ns, dtypes.int32) as result:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ts, tl = result.test_fn()
r = list_ops.tensor_list_stack(tl, dtypes.int32)
self.assertAllEqual(sess.run(r), [1, 2])
@@ -122,7 +122,7 @@ class ListTest(converter_testing.TestCase):
node = lists.transform(node, ctx)
with self.compiled(node, {}, array_ops.stack, dtypes.int32) as result:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(result.test_fn()), [1, 2, 3])
# TODO(mdan): Add a test with tf.stack with axis kwarg.
diff --git a/tensorflow/contrib/autograph/converters/logical_expressions.py b/tensorflow/contrib/autograph/converters/logical_expressions.py
index 16eb1f0e3f..41c3424fa3 100644
--- a/tensorflow/contrib/autograph/converters/logical_expressions.py
+++ b/tensorflow/contrib/autograph/converters/logical_expressions.py
@@ -57,8 +57,8 @@ class LogicalExpressionTransformer(converter.Base):
gast.NotEq: 'tf.not_equal',
gast.Or: 'tf.logical_or',
gast.USub: 'tf.negative',
- gast.Is: 'autograph_utils.dynamic_is',
- gast.IsNot: 'autograph_utils.dynamic_is_not'
+ gast.Is: 'ag__.utils.dynamic_is',
+ gast.IsNot: 'ag__.utils.dynamic_is_not'
}
def _expect_simple_symbol(self, operand):
diff --git a/tensorflow/contrib/autograph/converters/logical_expressions_test.py b/tensorflow/contrib/autograph/converters/logical_expressions_test.py
index ca07de5e8a..409a73afba 100644
--- a/tensorflow/contrib/autograph/converters/logical_expressions_test.py
+++ b/tensorflow/contrib/autograph/converters/logical_expressions_test.py
@@ -33,7 +33,7 @@ class GradientsFunctionTest(converter_testing.TestCase):
with self.converted(test_fn, logical_expressions, {},
math_ops.equal) as result:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertTrue(sess.run(result.test_fn(1, 1)))
self.assertFalse(sess.run(result.test_fn(1, 2)))
@@ -44,9 +44,18 @@ class GradientsFunctionTest(converter_testing.TestCase):
with self.converted(test_fn, logical_expressions, {}, math_ops.logical_or,
math_ops.logical_and) as result:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertTrue(sess.run(result.test_fn(True, False, True)))
+ def test_ag_utils_lookup(self):
+ def test_fn(a, b):
+ return a is b or a is not b
+
+ with self.converted(test_fn, logical_expressions, {}, math_ops.logical_or
+ ) as result:
+ with self.cached_session() as sess:
+ self.assertTrue(sess.run(result.test_fn(True, False)))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py
index bee512abbc..5fe5114d4b 100644
--- a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py
+++ b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py
@@ -46,7 +46,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, state_ops.assign) as result:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
sess.run(v.initializer)
sess.run(result.test_fn(v))
@@ -67,7 +67,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, state_ops.assign) as result:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
sess.run(v.initializer)
sess.run(result.test_fn(v))
@@ -87,7 +87,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, control_flow_ops.Assert) as result:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
'expected in throw'):
sess.run(result.test_fn(constant_op.constant(-1)))
@@ -107,7 +107,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, state_ops.assign_add) as result:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
sess.run(v.initializer)
sess.run(result.test_fn(v))
@@ -128,7 +128,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
self.assertEqual(len(node.body[0].body), 1)
with self.compiled(node, {}, state_ops.assign, ops.name_scope) as result:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
sess.run(v.initializer)
sess.run(result.test_fn(v))
@@ -151,7 +151,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
with self.compiled(node, {}, state_ops.assign,
state_ops.assign_add) as result:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
sess.run(v.initializer)
sess.run(result.test_fn(v))
diff --git a/tensorflow/contrib/autograph/converters/slices_test.py b/tensorflow/contrib/autograph/converters/slices_test.py
index c822d53a4a..d74b2e025e 100644
--- a/tensorflow/contrib/autograph/converters/slices_test.py
+++ b/tensorflow/contrib/autograph/converters/slices_test.py
@@ -45,7 +45,7 @@ class SliceTest(converter_testing.TestCase):
node = slices.transform(node, ctx)
with self.compiled(node, {}, dtypes.int32) as result:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tl = list_ops.tensor_list_from_tensor(
[1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32))
y = result.test_fn(tl)
diff --git a/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md b/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md
index bcbb920cc5..c2427f5f4f 100644
--- a/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md
+++ b/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md
@@ -4,7 +4,7 @@ The `py_func` op requires specifying a
[data type](https://www.tensorflow.org/guide/tensors#data_types).
When wrapping a function with `py_func`, for instance using
-`@autograph.do_not_convert(run_mode=autograph.RunMode.PY_FUNC)`, you have two
+`@autograph.do_not_convert(run_as=autograph.RunMode.PY_FUNC)`, you have two
options to specify the returned data type:
* explicitly, with a specified `tf.DType` value
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/BUILD b/tensorflow/contrib/autograph/examples/integration_tests/BUILD
index 6c281485b4..3630b41fc8 100644
--- a/tensorflow/contrib/autograph/examples/integration_tests/BUILD
+++ b/tensorflow/contrib/autograph/examples/integration_tests/BUILD
@@ -23,7 +23,6 @@ py_test(
],
srcs_version = "PY2AND3",
tags = ["no_windows"],
- visibility = ["//visibility:public"],
deps = [
"//tensorflow:tensorflow_py",
],
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py b/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py
index f4b9159942..04a968be10 100644
--- a/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py
+++ b/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py
@@ -97,7 +97,7 @@ class ErrorsTest(tf.test.TestCase):
compiled_fn = ag.to_graph(test_fn)
with self.assertRaises(ag.TfRuntimeError) as error:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = compiled_fn(tf.constant([4, 8]))
with ag.improved_errors(compiled_fn):
sess.run(x)
@@ -144,7 +144,7 @@ class ErrorsTest(tf.test.TestCase):
# frame with "g" as the function name but because we don't yet add
# try/except blocks to inner functions the name is "tf__g".
with self.assertRaises(ag.TfRuntimeError) as error:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = compiled_fn(tf.constant([4, 8]))
with ag.improved_errors(compiled_fn):
sess.run(x)
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/list_literals_test.py b/tensorflow/contrib/autograph/examples/integration_tests/list_literals_test.py
index 680b6dbaf0..904246afb7 100644
--- a/tensorflow/contrib/autograph/examples/integration_tests/list_literals_test.py
+++ b/tensorflow/contrib/autograph/examples/integration_tests/list_literals_test.py
@@ -33,7 +33,7 @@ class ListLiteralsTest(tf.test.TestCase):
converted = ag.to_graph(list_used_as_tuple)
result = converted()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(result), [1, 2, 3])
diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py
index 276a387180..8b38d5d080 100644
--- a/tensorflow/contrib/autograph/impl/api.py
+++ b/tensorflow/contrib/autograph/impl/api.py
@@ -29,9 +29,9 @@ import six
from tensorflow.contrib.autograph.core import config
from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.impl import conversion
+from tensorflow.contrib.autograph.operators import py_builtins
from tensorflow.contrib.autograph.pyct import compiler
from tensorflow.contrib.autograph.pyct import inspect_utils
-from tensorflow.contrib.autograph.utils import builtins
from tensorflow.contrib.autograph.utils import py_func
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator
@@ -150,7 +150,7 @@ def converted_call(f, recursive, verbose, force_conversion, arg_types, *args,
unknown_arg_value = object() # Sentinel for arguments of unknown value
if inspect_utils.isbuiltin(f):
- return builtins.dynamic_builtin(f, *args, **kwargs)
+ return py_builtins.overload_of(f)(*args, **kwargs)
if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
# Regular functions
diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py
index 803fde9089..a4c6fed265 100644
--- a/tensorflow/contrib/autograph/impl/api_test.py
+++ b/tensorflow/contrib/autograph/impl/api_test.py
@@ -38,9 +38,6 @@ class ApiTest(test.TestCase):
def setUp(self):
config.COMPILED_IMPORT_STATEMENTS = (
'from __future__ import print_function',
- 'from tensorflow.contrib.autograph import utils'
- ' as autograph_utils',
- 'tf = autograph_utils.fake_tf()',
)
def test_decorator_recurses(self):
diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD
index 332d5dab19..29759bad79 100644
--- a/tensorflow/contrib/autograph/operators/BUILD
+++ b/tensorflow/contrib/autograph/operators/BUILD
@@ -22,6 +22,7 @@ py_library(
"__init__.py",
"control_flow.py",
"data_structures.py",
+ "py_builtins.py",
"slices.py",
],
srcs_version = "PY2AND3",
@@ -62,6 +63,16 @@ py_test(
)
py_test(
+ name = "py_builtins_test",
+ srcs = ["py_builtins_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":operators",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
name = "slices_test",
srcs = ["slices_test.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py
index 392cb60bcc..c4fbc260a2 100644
--- a/tensorflow/contrib/autograph/operators/__init__.py
+++ b/tensorflow/contrib/autograph/operators/__init__.py
@@ -45,6 +45,11 @@ from tensorflow.contrib.autograph.operators.data_structures import list_stack
from tensorflow.contrib.autograph.operators.data_structures import ListPopOpts
from tensorflow.contrib.autograph.operators.data_structures import ListStackOpts
from tensorflow.contrib.autograph.operators.data_structures import new_list
+from tensorflow.contrib.autograph.operators.py_builtins import float_
+from tensorflow.contrib.autograph.operators.py_builtins import int_
+from tensorflow.contrib.autograph.operators.py_builtins import len_
+from tensorflow.contrib.autograph.operators.py_builtins import print_
+from tensorflow.contrib.autograph.operators.py_builtins import range_
from tensorflow.contrib.autograph.operators.slices import get_item
from tensorflow.contrib.autograph.operators.slices import GetItemOpts
from tensorflow.contrib.autograph.operators.slices import set_item
diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py
index 9909e52164..9a66a6bb60 100644
--- a/tensorflow/contrib/autograph/operators/control_flow.py
+++ b/tensorflow/contrib/autograph/operators/control_flow.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.utils import builtins
+from tensorflow.contrib.autograph.operators import py_builtins
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
@@ -82,8 +82,8 @@ def _py_for_stmt(iter_, extra_test, body, init_state):
def _known_len_for_stmt(iter_, extra_test, body, init_state):
- """Overload of for_stmt that iterates over objects that define a length."""
- n = builtins.dynamic_len(iter_)
+ """Overload of for_stmt that iterates over objects that admit a length."""
+ n = py_builtins.len_(iter_)
def while_body(iterate_index, *state):
iterate = iter_[iterate_index]
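For reference, a hedged usage sketch of the dispatch: Tensor iterates take the _known_len_for_stmt path, which now sizes the loop with py_builtins.len_, while plain Python iterables fall back to an eager loop.

# Hedged usage sketch (illustrative, not part of this change).
from tensorflow.contrib.autograph.operators import control_flow

s = control_flow.for_stmt(
    range(5),                      # Python iterable: eager _py_for_stmt path
    extra_test=lambda s: True,
    body=lambda i, s: (s + i,),
    init_state=(0,))
assert s == (10,)                  # 0 + 1 + 2 + 3 + 4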
diff --git a/tensorflow/contrib/autograph/operators/control_flow_test.py b/tensorflow/contrib/autograph/operators/control_flow_test.py
index b14d7edba3..677b7f8f62 100644
--- a/tensorflow/contrib/autograph/operators/control_flow_test.py
+++ b/tensorflow/contrib/autograph/operators/control_flow_test.py
@@ -34,7 +34,7 @@ class ForLoopTest(test.TestCase):
extra_test=lambda s: True,
body=lambda i, s: (s + i,),
init_state=(0,))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual((10,), sess.run(s))
def test_python(self):
@@ -52,7 +52,7 @@ class ForLoopTest(test.TestCase):
extra_test=lambda s: True,
body=lambda i, s: (s + i,),
init_state=(0,))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual((10,), sess.run(s))
@@ -65,7 +65,7 @@ class WhileLoopTest(test.TestCase):
body=lambda i, s: (i + 1, s + i,),
init_state=(0, 0),
extra_deps=(n,))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual((5, 10), sess.run(results))
def test_python(self):
@@ -86,7 +86,8 @@ class IfStmtTest(test.TestCase):
cond=cond,
body=lambda: 1,
orelse=lambda: -1)
- with self.test_session() as sess:
+
+ with self.cached_session() as sess:
self.assertEqual(1, sess.run(test_if_stmt(constant_op.constant(True))))
self.assertEqual(-1, sess.run(test_if_stmt(constant_op.constant(False))))
diff --git a/tensorflow/contrib/autograph/operators/data_structures_test.py b/tensorflow/contrib/autograph/operators/data_structures_test.py
index 7ea11a839b..4b1e835d44 100644
--- a/tensorflow/contrib/autograph/operators/data_structures_test.py
+++ b/tensorflow/contrib/autograph/operators/data_structures_test.py
@@ -42,7 +42,7 @@ class ListTest(test.TestCase):
def test_tf_tensor_list_new(self):
l = data_structures.tf_tensor_list_new([3, 4, 5])
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(t), [3, 4, 5])
def test_tf_tensor_list_new_illegal_input(self):
@@ -63,7 +63,7 @@ class ListTest(test.TestCase):
def test_tf_tensor_array_new(self):
l = data_structures.tf_tensor_array_new([3, 4, 5])
t = l.stack()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(t), [3, 4, 5])
def test_tf_tensor_array_new_illegal_input(self):
@@ -88,14 +88,14 @@ class ListTest(test.TestCase):
l = data_structures.list_append(l, x)
t = list_ops.tensor_list_stack(l, element_dtype=x.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(t), [[1, 2, 3]])
def test_append_tensorarray(self):
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
l1 = data_structures.list_append(l, 1)
l2 = data_structures.list_append(l1, 2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(l1.stack()), [1])
self.assertAllEqual(sess.run(l2.stack()), [1, 2])
@@ -116,7 +116,7 @@ class ListTest(test.TestCase):
with self.assertRaises(NotImplementedError):
data_structures.list_pop(l, 0, opts)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
l, x = data_structures.list_pop(l, None, opts)
self.assertAllEqual(sess.run(x), [3, 4])
@@ -137,7 +137,7 @@ class ListTest(test.TestCase):
opts = data_structures.ListStackOpts(
element_dtype=initial_list.dtype, original_call=None)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t = data_structures.list_stack(l, opts)
self.assertAllEqual(sess.run(t), sess.run(initial_list))
diff --git a/tensorflow/contrib/autograph/operators/py_builtins.py b/tensorflow/contrib/autograph/operators/py_builtins.py
new file mode 100644
index 0000000000..c5730934e7
--- /dev/null
+++ b/tensorflow/contrib/autograph/operators/py_builtins.py
@@ -0,0 +1,225 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operators corresponding to Python builtin functions.
+
+List of built-in functions: https://docs.python.org/3/library/functions.html
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from tensorflow.contrib.autograph.utils import py_func
+from tensorflow.contrib.autograph.utils import tensors
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_parsing_ops
+from tensorflow.python.ops import gen_string_ops
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import math_ops
+
+
+UNDEFINED = object()
+
+
+def overload_of(f):
+ if f in SUPPORTED_BUILTINS:
+ return BUILTIN_FUNCTIONS_MAP[f.__name__]
+ return f
+
+
+def abs_(x):
+ if tensor_util.is_tensor(x):
+ return _tf_abs(x)
+ return _py_abs(x)
+
+
+def _tf_abs(x):
+ return math_ops.abs(x)
+
+
+def _py_abs(x):
+ return abs(x)
+
+
+def float_(x=0):
+ if tensor_util.is_tensor(x):
+ return _tf_float(x)
+ return _py_float(x)
+
+
+def _tf_float(x):
+ # TODO(mdan): We shouldn't assume float32.
+ if x.dtype == dtypes.string:
+ return gen_parsing_ops.string_to_number(x, out_type=dtypes.float32)
+ return math_ops.cast(x, dtype=dtypes.float32)
+
+
+def _py_float(x):
+ return float(x)
+
+
+def int_(x=0, base=UNDEFINED):
+ if tensor_util.is_tensor(x):
+ return _tf_int(x, base)
+ return _py_int(x, base)
+
+
+def _tf_int(x, base):
+ if base not in (10, UNDEFINED):
+ raise NotImplementedError('base {} not supported for int'.format(base))
+
+ # TODO(mdan): We shouldn't assume int32.
+ if x.dtype == dtypes.string:
+ return gen_parsing_ops.string_to_number(x, out_type=dtypes.int32)
+ return math_ops.cast(x, dtype=dtypes.int32)
+
+
+def _py_int(x, base):
+ if base is UNDEFINED:
+ return int(x)
+ return int(x, base)
+
+
+def len_(s):
+ if tensors.is_tensor_array(s):
+ return _tf_tensor_array_len(s)
+ elif tensors.is_tensor_list(s):
+ return _tf_tensor_list_len(s)
+ elif tensor_util.is_tensor(s):
+ return _tf_tensor_len(s)
+ return _py_len(s)
+
+
+def _tf_tensor_array_len(s):
+ return s.size()
+
+
+def _tf_tensor_list_len(s):
+ return list_ops.tensor_list_length(s)
+
+
+def _tf_tensor_len(s):
+ """Overload of len_ for Tensor arguments."""
+ # Statically shaped tensors: length is known ahead of time.
+ if s.shape.ndims and s.shape[0].value is not None:
+ return s.shape[0].value
+
+ # Shape with unknown dimensions: use the dynamic shape, but statically
+ # check that the tensor is not a scalar.
+ shape = array_ops.shape(s)
+
+ assert shape.shape, 'shape tensor of zero size? {}'.format(shape)
+
+ if shape.shape[0] == 0:
+ raise ValueError(
+ 'len requires a non-scalar tensor, got one of shape {}'.format(shape))
+
+ if shape.shape[0].value is not None:
+ return array_ops.shape(s)[0]
+
+ # Fully dynamic shape: use ops.
+ rank = array_ops.rank(s)
+
+ def raise_zero_rank_error():
+ msg = gen_string_ops.string_join(
+ ['len requires non-zero rank, got ',
+ gen_string_ops.as_string(rank)])
+ with ops.control_dependencies([control_flow_ops.Assert(False, [msg])]):
+ return constant_op.constant(0, dtype=dtypes.int32)
+
+ return control_flow_ops.cond(rank > 0, lambda: array_ops.shape(s)[0],
+ raise_zero_rank_error)
+
+
+def _py_len(s):
+ return len(s)
+
+
+def print_(*objects, **kwargs):
+ # Note: Python 2.6 doesn't support explicit keywords after starargs.
+ unknown_kwargs = tuple(
+ set(kwargs.keys()) - set(('sep', 'end', 'file', 'flush')))
+ if unknown_kwargs:
+ raise ValueError('invalid keyword arguments: {}'.format(unknown_kwargs))
+
+ # TODO(mdan): use logging_ops.Print when py_func is not supported.
+ return _tf_py_func_print(objects, kwargs)
+
+
+def _tf_py_func_print(objects, kwargs):
+ """Overload of print_ as a py_func implementation."""
+ override_kwargs = {k: v for k, v in kwargs.items() if v is not UNDEFINED}
+ if 'flush' not in override_kwargs:
+ # Defaulting to flushing the console in graph mode, which helps reduce
+ # garbled output in IPython.
+ override_kwargs['flush'] = True
+
+ def print_wrapper(*vals):
+ if six.PY3:
+ # TensorFlow doesn't seem to generate Unicode when passing strings to
+ # py_func. This causes the print to add a "b'" wrapper to the output,
+ # which is probably never what you want.
+ vals = tuple(
+ v.decode('utf-8') if isinstance(v, bytes) else v for v in vals)
+ six.print_(*vals, **override_kwargs)
+
+ return py_func.wrap_py_func(
+ print_wrapper, None, objects, use_dummy_return=True)
+
+
+def range_(start_or_stop, stop=UNDEFINED, step=UNDEFINED):
+ if any(tensor_util.is_tensor(s) for s in (start_or_stop, stop, step)):
+ return _tf_range(start_or_stop, stop, step)
+ return _py_range(start_or_stop, stop, step)
+
+
+def _tf_range(start_or_stop, stop, step):
+ # TODO(mdan): We should optimize this when a full tensor is not required.
+ if step is not UNDEFINED:
+ return math_ops.range(start_or_stop, stop, step)
+ if stop is not UNDEFINED:
+ return math_ops.range(start_or_stop, stop)
+ return math_ops.range(start_or_stop)
+
+
+def _py_range(start_or_stop, stop, step):
+ if step is not UNDEFINED:
+ return range(start_or_stop, stop, step)
+ if stop is not UNDEFINED:
+ return range(start_or_stop, stop)
+ return range(start_or_stop)
+
+
+SUPPORTED_BUILTINS = set((abs, float, int, len, print, range))
+
+if six.PY2:
+ SUPPORTED_BUILTINS.add(xrange)
+
+BUILTIN_FUNCTIONS_MAP = {
+ 'abs': abs_,
+ 'float': float_,
+ 'int': int_,
+ 'len': len_,
+ 'print': print_,
+ 'range': range_,
+ 'xrange': range_,
+}
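A short, hedged usage sketch of the new module: each overload dispatches on whether its argument is a Tensor (or tensor list/array), and omitted arguments carry the UNDEFINED sentinel internally so that an explicit None can be told apart from "argument not passed".

# Hedged usage sketch (illustrative, not part of the change itself).
import tensorflow as tf
from tensorflow.contrib.autograph.operators import py_builtins

print(py_builtins.len_([1, 2, 3]))              # plain Python list -> 3
n = py_builtins.len_(tf.constant([[1], [2]]))   # static shape -> Python int 2

r = py_builtins.range_(1, 3)                    # all-Python args -> range
rt = py_builtins.range_(tf.constant(3))         # Tensor arg -> math_ops.range

with tf.Session() as sess:                      # TF 1.x graph mode assumed
  print(sess.run(rt))                           # [0 1 2]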
diff --git a/tensorflow/contrib/autograph/operators/py_builtins_test.py b/tensorflow/contrib/autograph/operators/py_builtins_test.py
new file mode 100644
index 0000000000..4073c51785
--- /dev/null
+++ b/tensorflow/contrib/autograph/operators/py_builtins_test.py
@@ -0,0 +1,131 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for py_builtins module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+import six
+
+from tensorflow.contrib.autograph.operators import data_structures
+from tensorflow.contrib.autograph.operators import py_builtins
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.platform import test
+
+
+class PyBuiltinsTest(test.TestCase):
+
+ def test_abs(self):
+ self.assertEqual(py_builtins.abs_(-1), 1)
+ with self.test_session() as sess:
+ t = py_builtins.abs_(constant_op.constant(-1))
+ self.assertEqual(sess.run(t), 1)
+ t = py_builtins.abs_(constant_op.constant([-1, 2, -3]))
+ self.assertAllEqual(sess.run(t), [1, 2, 3])
+
+ def test_float(self):
+ self.assertEqual(py_builtins.float_(10), 10.0)
+ self.assertEqual(py_builtins.float_('10.0'), 10.0)
+ with self.test_session() as sess:
+ t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64))
+ self.assertEqual(sess.run(t), 1.0)
+ st = py_builtins.float_(constant_op.constant('1.0'))
+ self.assertEqual(sess.run(st), 1.0)
+
+ def test_int(self):
+ self.assertEqual(py_builtins.int_(10.0), 10)
+ self.assertEqual(py_builtins.int_('11', 2), 3)
+ with self.test_session() as sess:
+ t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64))
+ self.assertEqual(sess.run(t), 1)
+ st = py_builtins.int_(constant_op.constant('1'))
+ self.assertEqual(sess.run(st), 1)
+ st = py_builtins.int_(constant_op.constant('1'), 10)
+ self.assertEqual(sess.run(st), 1)
+
+ def test_int_unsupported_base(self):
+ t = constant_op.constant(1, dtype=dtypes.float64)
+ with self.assertRaises(NotImplementedError):
+ py_builtins.int_(t, 2)
+
+ def test_len(self):
+ self.assertEqual(py_builtins.len_([1, 2, 3]), 3)
+ with self.test_session() as sess:
+ t = py_builtins.len_(constant_op.constant([[1], [2], [3]]))
+ self.assertEqual(t, 3)
+ ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5))
+ self.assertEqual(sess.run(ta), 5)
+ tl = py_builtins.len_(data_structures.tf_tensor_list_new([3, 4, 5]))
+ self.assertEqual(sess.run(tl), 3)
+
+ def test_len_scalar(self):
+ with self.assertRaises(ValueError):
+ py_builtins.len_(constant_op.constant(1))
+
+ def test_len_dynamic_shape(self):
+ with self.test_session() as sess:
+ p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
+ t = py_builtins.len_(p)
+ self.assertEqual(sess.run(t, {p: [1, 2, 3]}), 3)
+
+ with self.assertRaises(errors_impl.InvalidArgumentError):
+ t = py_builtins.len_(p)
+ sess.run(t, {p: 1})
+
+ def test_print_tensors(self):
+ try:
+ out_capturer = six.StringIO()
+ sys.stdout = out_capturer
+ with self.test_session() as sess:
+ sess.run(py_builtins.print_(constant_op.constant('test message'), 1))
+ self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
+ finally:
+ sys.stdout = sys.__stdout__
+
+ def test_print_complex(self):
+ try:
+ out_capturer = six.StringIO()
+ sys.stdout = out_capturer
+ with self.test_session() as sess:
+ sess.run(
+ py_builtins.print_(constant_op.constant('test message'), [1, 2]))
+ self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
+ finally:
+ sys.stdout = sys.__stdout__
+
+ def test_range(self):
+ self.assertListEqual(list(py_builtins.range_(3)), [0, 1, 2])
+ self.assertListEqual(list(py_builtins.range_(1, 3)), [1, 2])
+ self.assertListEqual(list(py_builtins.range_(2, 0, -1)), [2, 1])
+
+ def test_range_tensor(self):
+ with self.test_session() as sess:
+ r = py_builtins.range_(constant_op.constant(3))
+ self.assertAllEqual(sess.run(r), [0, 1, 2])
+ r = py_builtins.range_(1, constant_op.constant(3))
+ self.assertAllEqual(sess.run(r), [1, 2])
+ r = py_builtins.range_(2, 0, constant_op.constant(-1))
+ self.assertAllEqual(sess.run(r), [2, 1])
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/autograph/operators/slices.py b/tensorflow/contrib/autograph/operators/slices.py
index 04fbeb2f6e..2b7f5ad922 100644
--- a/tensorflow/contrib/autograph/operators/slices.py
+++ b/tensorflow/contrib/autograph/operators/slices.py
@@ -22,6 +22,7 @@ import collections
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import tensor_array_ops
@@ -57,6 +58,8 @@ def get_item(target, i, opts):
elif tensor_util.is_tensor(target):
if target.dtype == dtypes.variant:
return _tf_tensor_list_get_item(target, i, opts)
+ elif target.dtype == dtypes.string and target.shape.ndims == 0:
+ return _tf_tensor_string_get_item(target, i)
else:
return _tf_tensor_get_item(target, i)
else:
@@ -82,6 +85,12 @@ def _tf_tensor_get_item(target, i):
return target[i]
+def _tf_tensor_string_get_item(target, i):
+ """Overload of get_item that stages a Tensor string read."""
+ x = gen_string_ops.substr(target, i, 1)
+ return x
+
+
def _py_get_item(target, i):
"""Overload of get_item that executes a Python list modification."""
return target[i]
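A hedged usage sketch of the new overload: indexing a scalar string Tensor now stages a one-character substr read instead of falling through to plain Tensor slicing. Assumes TF 1.x graph mode.

# Hedged usage sketch (illustrative).
import tensorflow as tf
from tensorflow.contrib.autograph.operators import slices

s = tf.constant('abcd')
ch = slices.get_item(s, 1, slices.GetItemOpts(element_dtype=s.dtype))
with tf.Session() as sess:
  print(sess.run(ch))  # b'b'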
diff --git a/tensorflow/contrib/autograph/operators/slices_test.py b/tensorflow/contrib/autograph/operators/slices_test.py
index d4aacb9d20..5255b7e2b6 100644
--- a/tensorflow/contrib/autograph/operators/slices_test.py
+++ b/tensorflow/contrib/autograph/operators/slices_test.py
@@ -32,7 +32,7 @@ class SlicesTest(test.TestCase):
l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape)
l = slices.set_item(l, 0, [5, 6])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
self.assertAllEqual(sess.run(t), [[5, 6], [3, 4]])
@@ -43,9 +43,24 @@ class SlicesTest(test.TestCase):
t = slices.get_item(
l, 1, slices.GetItemOpts(element_dtype=initial_list.dtype))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(t), [3, 4])
+ def test_get_item_tensor_string(self):
+ initial_str = constant_op.constant('abcd')
+ t = slices.get_item(initial_str, 1,
+ slices.GetItemOpts(element_dtype=initial_str.dtype))
+
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(t), b'b')
+
+ initial_list_str = constant_op.constant(['abcd', 'bcde'])
+ t = slices.get_item(initial_list_str, 1,
+ slices.GetItemOpts(element_dtype=initial_list_str.dtype))
+
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(t), b'bcde')
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD
index a0938b3e5f..fe630ef852 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD
@@ -22,9 +22,11 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/contrib/autograph/pyct",
"@gast_archive//:gast",
"@six_archive//:six",
+ # TODO(aqj) Revisit this dependency direction when pyct is more
+ # modularized
+ "//tensorflow/contrib/autograph/pyct",
],
)
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/__init__.py b/tensorflow/contrib/autograph/pyct/common_transformers/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/__init__.py
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
index e42f679cfe..d77c15915b 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
@@ -394,10 +394,16 @@ class AnfTransformer(transformer.Base):
# just recur.
def visit_List(self, node):
- return self._visit_strict_expression(node)
+ node = self.generic_visit(node)
+ if not isinstance(node.ctx, gast.Store):
+ self._ensure_fields_trivial(node)
+ return node
def visit_Tuple(self, node):
- return self._visit_strict_expression(node)
+ node = self.generic_visit(node)
+ if not isinstance(node.ctx, gast.Store):
+ self._ensure_fields_trivial(node)
+ return node
def transform(node, entity_info, gensym_source=None):
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
index 951974820c..1ffd4bbe55 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
@@ -165,6 +165,46 @@ class AnfTransformerTest(test.TestCase):
self.assert_body_anfs_as_expected(expected_result, test_function)
+ def test_nested_multi_value_assign(self):
+
+ def test_function(a, b, c):
+ x, y = a, a + b
+ (z, y), x = (c, y + b), x + a
+ return z, (y, x)
+
+ def expected_result(a, b, c):
+ tmp_1001 = a + b
+ x, y = a, tmp_1001
+ tmp_1002 = y + b
+ tmp_1003 = (c, tmp_1002)
+ tmp_1004 = x + a
+ (z, y), x = tmp_1003, tmp_1004
+ tmp_1005 = y, x
+ tmp_1006 = z, tmp_1005
+ return tmp_1006
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_deeply_nested_multi_value_assign(self):
+
+ def test_function(a):
+ [([(b, c), [d, e]], (f, g)), [(h, i, j), k]] = a
+ return [([(b, c), [d, e]], (f, g)), [(h, i, j), k]]
+
+ def expected_result(a):
+ [([(b, c), [d, e]], (f, g)), [(h, i, j), k]] = a
+ tmp_1001 = b, c
+ tmp_1002 = [d, e]
+ tmp_1003 = [tmp_1001, tmp_1002]
+ tmp_1004 = f, g
+ tmp_1005 = h, i, j
+ tmp_1006 = tmp_1003, tmp_1004
+ tmp_1007 = [tmp_1005, k]
+ tmp_1008 = [tmp_1006, tmp_1007]
+ return tmp_1008
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
def test_local_definition_and_binary_compare(self):
def test_function():
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
index 2d8f922a45..e7baa244b2 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
@@ -29,6 +29,11 @@ from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+# TODO(aqj): Do we need this? Do other builtins fail in similar ways?
+# See b/114389775 for a related bug in pyct
+# These symbols are legal in Python, but don't appear in the namespace.
+_special_symbols = {'range': range}
+
class LiveValueResolver(transformer.Base):
"""Annotates nodes with live values."""
@@ -66,6 +71,8 @@ class LiveValueResolver(transformer.Base):
# If the symbol value is for example a primitive, then it will not
# have a name.
pass
+ elif node.id in _special_symbols:
+ anno.setanno(node, 'live_val', _special_symbols[node.id])
else:
pass
# TODO(mdan): Should we raise an error here?
diff --git a/tensorflow/contrib/autograph/pyct/templates.py b/tensorflow/contrib/autograph/pyct/templates.py
index 5831d57ceb..d81c50f524 100644
--- a/tensorflow/contrib/autograph/pyct/templates.py
+++ b/tensorflow/contrib/autograph/pyct/templates.py
@@ -113,7 +113,7 @@ class ReplaceTransformer(gast.NodeTransformer):
if isinstance(node, gast.Attribute):
self._check_inner_children_have_context(node.value)
self._check_has_context(node)
- elif isinstance(node, gast.Tuple):
+ elif isinstance(node, (gast.Tuple, gast.List)):
for e in node.elts:
self._check_inner_children_have_context(e)
self._check_has_context(node)
@@ -142,7 +142,7 @@ class ReplaceTransformer(gast.NodeTransformer):
if isinstance(node, gast.Attribute):
self._set_inner_child_context(node.value, gast.Load())
node.ctx = ctx
- elif isinstance(node, gast.Tuple):
+ elif isinstance(node, (gast.Tuple, gast.List)):
for e in node.elts:
self._set_inner_child_context(e, ctx)
node.ctx = ctx
@@ -191,7 +191,7 @@ class ReplaceTransformer(gast.NodeTransformer):
# Preserve the target context.
for n in new_nodes:
- if isinstance(n, gast.Tuple):
+ if isinstance(n, (gast.Tuple, gast.List)):
for e in n.elts:
self._set_inner_child_context(e, node.ctx)
if isinstance(n, gast.Attribute):
diff --git a/tensorflow/contrib/autograph/pyct/templates_test.py b/tensorflow/contrib/autograph/pyct/templates_test.py
index 77e8ff62fd..074105ea50 100644
--- a/tensorflow/contrib/autograph/pyct/templates_test.py
+++ b/tensorflow/contrib/autograph/pyct/templates_test.py
@@ -110,6 +110,42 @@ class TemplatesTest(test.TestCase):
self.assertIsInstance(node.body[0].targets[0].value.ctx, gast.Load)
self.assertIsInstance(node.body[0].targets[0].value.value.ctx, gast.Load)
+ def test_replace_list_context(self):
+ template = """
+ def test_fn(foo):
+ foo = 0
+ """
+
+ node = templates.replace(template, foo=parser.parse_expression('[a, b]'))[0]
+ self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
+ self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
+ self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)
+
+ def test_replace_tuple_context(self):
+ template = """
+ def test_fn(foo):
+ foo = 0
+ """
+
+ node = templates.replace(template, foo=parser.parse_expression('(a, b)'))[0]
+ self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
+ self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
+ self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)
+
+ def test_replace_complex_context(self):
+ template = """
+ def test_fn(foo):
+ foo = 0
+ """
+
+ node = templates.replace(
+ template, foo=parser.parse_expression('bar(([a, b],)).baz'))[0]
+ self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
+ function_call_arg = node.body[0].targets[0].value.args[0]
+ self.assertIsInstance(function_call_arg.elts[0].ctx, gast.Load)
+ self.assertIsInstance(function_call_arg.elts[0].elts[0].ctx, gast.Load)
+ self.assertIsInstance(function_call_arg.elts[0].elts[1].ctx, gast.Load)
+
def test_replace_call_keyword(self):
template = """
def test_fn():
diff --git a/tensorflow/contrib/autograph/pyct/testing/BUILD b/tensorflow/contrib/autograph/pyct/testing/BUILD
index 9ef1ac9663..29a92444bb 100644
--- a/tensorflow/contrib/autograph/pyct/testing/BUILD
+++ b/tensorflow/contrib/autograph/pyct/testing/BUILD
@@ -34,8 +34,10 @@ py_test(
srcs = ["codegen_test.py"],
srcs_version = "PY2AND3",
tags = [
+ "manual",
"no_windows",
"nomsan",
+ "notap",
],
deps = [
":testing",
diff --git a/tensorflow/contrib/autograph/utils/BUILD b/tensorflow/contrib/autograph/utils/BUILD
index d2b399f19b..4504a5c7a3 100644
--- a/tensorflow/contrib/autograph/utils/BUILD
+++ b/tensorflow/contrib/autograph/utils/BUILD
@@ -20,12 +20,12 @@ py_library(
name = "utils",
srcs = [
"__init__.py",
- "builtins.py",
"context_managers.py",
"misc.py",
"multiple_dispatch.py",
"py_func.py",
"tensor_list.py",
+ "tensors.py",
"testing.py",
"type_check.py",
],
@@ -42,17 +42,6 @@ py_library(
)
py_test(
- name = "builtins_test",
- srcs = ["builtins_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_windows"],
- deps = [
- ":utils",
- "//tensorflow/python:client_testlib",
- ],
-)
-
-py_test(
name = "context_managers_test",
srcs = ["context_managers_test.py"],
srcs_version = "PY2AND3",
@@ -113,3 +102,13 @@ py_test(
"//tensorflow/python:list_ops",
],
)
+
+py_test(
+ name = "tensors_test",
+ srcs = ["tensors_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":utils",
+ "//tensorflow/python:client_testlib",
+ ],
+)
diff --git a/tensorflow/contrib/autograph/utils/__init__.py b/tensorflow/contrib/autograph/utils/__init__.py
index 57b5f74741..38e0a0a8f0 100644
--- a/tensorflow/contrib/autograph/utils/__init__.py
+++ b/tensorflow/contrib/autograph/utils/__init__.py
@@ -18,9 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.utils.builtins import dynamic_builtin
-from tensorflow.contrib.autograph.utils.builtins import dynamic_print
-from tensorflow.contrib.autograph.utils.builtins import dynamic_range
from tensorflow.contrib.autograph.utils.context_managers import control_dependency_on_returns
from tensorflow.contrib.autograph.utils.misc import alias_tensors
from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is
diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py
deleted file mode 100644
index 4dd440ef19..0000000000
--- a/tensorflow/contrib/autograph/utils/builtins.py
+++ /dev/null
@@ -1,143 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Builtin conversion utilities."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import sys
-
-import six
-
-from tensorflow.contrib.autograph.utils import py_func
-from tensorflow.contrib.autograph.utils import type_check
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import list_ops
-from tensorflow.python.ops import logging_ops
-from tensorflow.python.ops import math_ops
-
-
-def dynamic_builtin(f, *args, **kwargs):
- """Converts a builtin function call inline."""
- if f is len:
- return dynamic_len(*args, **kwargs)
- if six.PY2 and f is xrange:
- return dynamic_range(*args, **kwargs)
- if f is range:
- return dynamic_range(*args, **kwargs)
- if f is int:
- return dynamic_int(*args, **kwargs)
- if f is float:
- return dynamic_float(*args, **kwargs)
- if f is abs:
- return dynamic_abs(*args, **kwargs)
-
- raise NotImplementedError(
- 'The "%s" builtin is not yet supported.' % f.__name__)
-
-
-def dynamic_len(list_or_tensor):
- """Implementation of len using dynamic dispatch."""
- if _is_tensor_list(list_or_tensor):
- return list_ops.tensor_list_length(list_or_tensor)
- elif tensor_util.is_tensor(list_or_tensor):
- shape = list_or_tensor.shape
- if not shape.ndims:
- raise ValueError(
- 'len requires non-zero rank for tensor "%s"' % list_or_tensor)
- return array_ops.shape(list_or_tensor)[0]
- return len(list_or_tensor)
-
-
-def _is_tensor_list(list_or_tensor):
- return (tensor_util.is_tensor(list_or_tensor)
- and list_or_tensor.dtype == dtypes.variant)
-
-
-def dynamic_int(num_or_tensor, **kwargs):
- """Implementation of int() using dynamic dispatch."""
- if tensor_util.is_tensor(num_or_tensor):
- return math_ops.cast(num_or_tensor, dtype=dtypes.int32, **kwargs)
- return int(num_or_tensor)
-
-
-def dynamic_float(num_or_tensor, **kwargs):
- """Implementation of float() using dynamic dispatch."""
- if tensor_util.is_tensor(num_or_tensor):
- return math_ops.cast(num_or_tensor, dtype=dtypes.float32, **kwargs)
- return float(num_or_tensor)
-
-
-def dynamic_abs(num_or_tensor, **kwargs):
- if tensor_util.is_tensor(num_or_tensor):
- return math_ops.abs(num_or_tensor, **kwargs)
- else:
- return abs(num_or_tensor, **kwargs)
-
-
-def dynamic_range(start_or_stop, stop=None, step=None):
- """Implementation of range using dynamic dispatch."""
- if type_check.is_tensor(start_or_stop, stop, step):
- if step is not None:
- return math_ops.range(start_or_stop, stop, step)
- if stop is not None:
- return math_ops.range(start_or_stop, stop)
- return math_ops.range(start_or_stop)
-
- if step is not None:
- return range(start_or_stop, stop, step)
- elif stop is not None:
- return range(start_or_stop, stop)
- return range(start_or_stop)
-
-
-def is_tf_print_compatible(value):
- # TODO(mdan): Enable once we can reliably test this.
- # This is currently disabled because we can't capture the output of
- # op kernels from Python.
- del value
- return False
-
-
-def dynamic_print(*values):
- """Implementation of print using dynamic dispatch.
-
- The function attempts to use tf.Print if all the values are compatible.
- Otherwise, it will fall back to py_func.
-
- Args:
- *values: values to print
- Returns:
-    A dummy value indicating the print completed.
- """
-
- if all(map(is_tf_print_compatible, values)):
- return logging_ops.Print(1, values)
-
- def print_wrapper(*vals):
- if six.PY3:
- # TensorFlow doesn't seem to generate Unicode when passing strings to
- # py_func. This causes the print to add a "b'" wrapper to the output,
- # which is probably never what you want.
- vals = tuple(v.decode() if isinstance(v, bytes) else v for v in vals)
- print(*vals)
- # The flush helps avoid garbled output in IPython.
- sys.stdout.flush()
-
- return py_func.wrap_py_func(
- print_wrapper, None, values, use_dummy_return=True)
diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py
deleted file mode 100644
index b1cd5253bc..0000000000
--- a/tensorflow/contrib/autograph/utils/builtins_test.py
+++ /dev/null
@@ -1,145 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for builtins module."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import sys
-
-import six
-
-from tensorflow.contrib.autograph.utils import builtins
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.platform import test
-
-
-class BuiltinsTest(test.TestCase):
-
- def test_dynamic_len_tf_scalar(self):
- a = constant_op.constant(1)
-
- with self.assertRaisesRegexp(ValueError,
- 'len requires non-zero rank for tensor.*'):
- with self.test_session() as sess:
- sess.run(builtins.dynamic_builtin(len, a))
-
- def test_dynamic_len_tf_array(self):
- a = constant_op.constant([1, 2, 3])
-
- with self.test_session() as sess:
- self.assertEqual(3, sess.run(builtins.dynamic_builtin(len, a)))
-
- def test_dynamic_abs_tf_scalar(self):
- a = constant_op.constant(-1)
-
- with self.test_session() as sess:
- self.assertEqual(1, sess.run(builtins.dynamic_builtin(abs, a)))
-
- def test_dynamic_abs_tf_array(self):
- a = constant_op.constant([-1, 2, -3])
-
- with self.test_session() as sess:
- self.assertListEqual([1, 2, 3],
- list(sess.run(builtins.dynamic_builtin(abs, a))))
-
- def test_dynamic_abs_py_scalar(self):
- a = -1
- self.assertEqual(1, builtins.dynamic_builtin(abs, a))
-
- def test_dynamic_len_tf_matrix(self):
- a = constant_op.constant([[1, 2], [3, 4]])
-
- with self.test_session() as sess:
- self.assertEqual(2, sess.run(builtins.dynamic_builtin(len, a)))
-
- def test_dynamic_len_py_list(self):
- a = [3] * 5
-
- self.assertEqual(5, builtins.dynamic_builtin(len, a))
-
- def test_dynamic_range_all_python(self):
- self.assertListEqual(list(builtins.dynamic_builtin(range, 3)), [0, 1, 2])
- self.assertListEqual(list(builtins.dynamic_builtin(range, 1, 3)), [1, 2])
- self.assertListEqual(
- list(builtins.dynamic_builtin(range, 2, 0, -1)), [2, 1])
-
- def test_dynamic_range_tf(self):
- with self.test_session() as sess:
- self.assertAllEqual(
- sess.run(builtins.dynamic_builtin(range, constant_op.constant(3))),
- [0, 1, 2])
- self.assertAllEqual(
- sess.run(builtins.dynamic_builtin(range, 1, constant_op.constant(3))),
- [1, 2])
- self.assertAllEqual(
- sess.run(
- builtins.dynamic_builtin(range, 2, 0, constant_op.constant(-1))),
- [2, 1])
-
- def test_dynamic_range_detection(self):
- def range(x): # pylint:disable=redefined-builtin
- return x
-
- # Functions that just have the names of builtins are rejected.
- with self.assertRaises(NotImplementedError):
- self.assertEqual(builtins.dynamic_builtin(range, 1), 1)
- if six.PY2:
- self.assertListEqual(
- list(builtins.dynamic_builtin(xrange, 3)), [0, 1, 2])
- self.assertListEqual(
- list(builtins.dynamic_builtin(six.moves.range, 3)), [0, 1, 2])
- self.assertListEqual(
- list(builtins.dynamic_builtin(six.moves.xrange, 3)), [0, 1, 2])
-
- def test_casts(self):
- i = constant_op.constant(2, dtype=dtypes.int32)
- f = constant_op.constant(1.0, dtype=dtypes.float32)
-
- self.assertEqual(builtins.dynamic_builtin(int, i).dtype, dtypes.int32)
- self.assertEqual(builtins.dynamic_builtin(int, f).dtype, dtypes.int32)
- self.assertEqual(builtins.dynamic_builtin(float, i).dtype, dtypes.float32)
- self.assertEqual(builtins.dynamic_builtin(float, f).dtype, dtypes.float32)
-
- self.assertEqual(builtins.dynamic_builtin(int, True), 1)
- self.assertEqual(builtins.dynamic_builtin(int, False), 0)
- self.assertEqual(builtins.dynamic_builtin(float, True), 1.0)
- self.assertEqual(builtins.dynamic_builtin(float, False), 0.0)
-
- def test_dynamic_print_tf(self):
- try:
- out_capturer = six.StringIO()
- sys.stdout = out_capturer
- with self.test_session() as sess:
- sess.run(builtins.dynamic_print('test message', 1))
- self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
- finally:
- sys.stdout = sys.__stdout__
-
- def test_dynamic_print_complex(self):
- try:
- out_capturer = six.StringIO()
- sys.stdout = out_capturer
- with self.test_session() as sess:
- sess.run(builtins.dynamic_print('test message', [1, 2]))
- self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
- finally:
- sys.stdout = sys.__stdout__
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/autograph/utils/misc_test.py b/tensorflow/contrib/autograph/utils/misc_test.py
index 71e358c33e..968ea03df6 100644
--- a/tensorflow/contrib/autograph/utils/misc_test.py
+++ b/tensorflow/contrib/autograph/utils/misc_test.py
@@ -31,7 +31,7 @@ class MiscTest(test.TestCase):
new_a = alias_tensors(a)
self.assertFalse(new_a is a)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(1, sess.run(new_a))
def test_alias_tensors(self):
@@ -46,7 +46,7 @@ class MiscTest(test.TestCase):
self.assertTrue(new_v is v)
self.assertTrue(new_s is s)
self.assertTrue(new_l is l)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(1, sess.run(new_a))
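
For context: `cached_session()` returns one session that is reused across calls within a test case, whereas `test_session()` was being deprecated around this time; the assertions themselves are unchanged. A minimal sketch of the migrated pattern, assuming the TF 1.x test framework:

```python
# Sketch of the cached_session pattern these tests migrate to (TF 1.x).
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test

class ExampleTest(test.TestCase):

  def test_constant(self):
    a = constant_op.constant(1)
    # Returns the same underlying session on repeated calls in this test,
    # avoiding the cost of building a fresh session per `with` block.
    with self.cached_session() as sess:
      self.assertEqual(1, sess.run(a))

if __name__ == '__main__':
  test.main()
```
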
diff --git a/tensorflow/contrib/autograph/utils/py_func_test.py b/tensorflow/contrib/autograph/utils/py_func_test.py
index 2468263142..f60b57bcce 100644
--- a/tensorflow/contrib/autograph/utils/py_func_test.py
+++ b/tensorflow/contrib/autograph/utils/py_func_test.py
@@ -31,7 +31,7 @@ class PyFuncTest(test.TestCase):
def test_fn(a, b, c):
return a + b + c
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = py_func.wrap_py_func(test_fn, dtypes.int64,
(1, constant_op.constant(1), 1))
self.assertEqual(3, sess.run(result))
@@ -52,7 +52,7 @@ class PyFuncTest(test.TestCase):
def test_fn(a, b):
return a * b.foo
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass()))
self.assertEqual(35, sess.run(result))
result = py_func.wrap_py_func(test_fn, dtypes.int64,
@@ -69,7 +69,7 @@ class PyFuncTest(test.TestCase):
def test_fn(a, b, c, d):
return a * b.foo + c * d.foo
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass(5)), {
'c': 11,
'd': TestClass(13)
@@ -89,7 +89,7 @@ class PyFuncTest(test.TestCase):
def test_fn(_):
side_counter[0] += 1
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = py_func.wrap_py_func(test_fn, None, (5,), use_dummy_return=True)
self.assertEqual(1, sess.run(result))
self.assertEqual([1], side_counter)
diff --git a/tensorflow/contrib/autograph/utils/tensor_list_test.py b/tensorflow/contrib/autograph/utils/tensor_list_test.py
index d58489eb68..faaf7b7877 100644
--- a/tensorflow/contrib/autograph/utils/tensor_list_test.py
+++ b/tensorflow/contrib/autograph/utils/tensor_list_test.py
@@ -42,18 +42,18 @@ class TensorListTest(test.TestCase):
l = list_ops.empty_tensor_list(self._shape(()), dtypes.int32)
l = tl.dynamic_list_append(l, 1)
s = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(s), [1])
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
l = tl.dynamic_list_append(l, 1)
s = l.stack()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(s), [1])
l = tl.TensorList(self._shape(()), dtypes.int32)
l = tl.dynamic_list_append(l, 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(l[0]), 1)
def test_list_append_python(self):
@@ -107,7 +107,7 @@ class TensorListTest(test.TestCase):
l0 = l[0]
l[0] = b
l1 = l[0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
l0, l1, a, b = sess.run([l0, l1, a, b])
self.assertEqual(l0, a)
self.assertEqual(l1, b)
diff --git a/tensorflow/contrib/autograph/utils/tensors.py b/tensorflow/contrib/autograph/utils/tensors.py
new file mode 100644
index 0000000000..fa5db81a71
--- /dev/null
+++ b/tensorflow/contrib/autograph/utils/tensors.py
@@ -0,0 +1,41 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""This module defines tensor utilities not found in TensorFlow.
+
+These utilities are not defined in TensorFlow itself because they may not be
+fully robust, although they work in the vast majority of cases. We define them
+here so that their behavior can be consistently verified.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import tensor_array_ops
+
+
+def is_tensor_array(t):
+ return isinstance(t, tensor_array_ops.TensorArray)
+
+
+def is_tensor_list(t):
+ # TODO(mdan): This is just a heuristic.
+ # With TF lacking support for templated types, this is unfortunately the
+ # closest we can get right now. A dedicated op ought to be possible to
+ # construct.
+ return (tensor_util.is_tensor(t) and t.dtype == dtypes.variant and
+ not t.shape.ndims)
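
For context: a TensorList is represented as a `variant` tensor of rank zero (or unknown rank), which is why the predicate ends with `not t.shape.ndims`; `ndims` is `0` for scalars and `None` for unknown rank, and both are falsy. A tiny sketch of just that rank condition on a stub object, since the real check needs a live tensor:

```python
# Sketch: the rank condition inside is_tensor_list, shown on a stub shape.
# ndims is 0 for scalar tensors and None when the rank is unknown; the
# heuristic accepts both, and rejects tensors with a known positive rank.
class _StubShape(object):

  def __init__(self, ndims):
    self.ndims = ndims

for ndims, expected in [(0, True), (None, True), (1, False), (2, False)]:
  assert (not _StubShape(ndims).ndims) == expected
```
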
diff --git a/tensorflow/contrib/autograph/utils/tensors_test.py b/tensorflow/contrib/autograph/utils/tensors_test.py
new file mode 100644
index 0000000000..e855e0b6cb
--- /dev/null
+++ b/tensorflow/contrib/autograph/utils/tensors_test.py
@@ -0,0 +1,57 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensors module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.utils import tensors
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.platform import test
+
+
+class TensorsTest(test.TestCase):
+
+ def _simple_tensor_array(self):
+ return tensor_array_ops.TensorArray(dtypes.int32, size=3)
+
+ def _simple_tensor_list(self):
+ return list_ops.empty_tensor_list(
+ element_shape=constant_op.constant([1]), element_dtype=dtypes.int32)
+
+ def _simple_list_of_tensors(self):
+ return [constant_op.constant(1), constant_op.constant(2)]
+
+ def test_is_tensor_array(self):
+ self.assertTrue(tensors.is_tensor_array(self._simple_tensor_array()))
+ self.assertFalse(tensors.is_tensor_array(self._simple_tensor_list()))
+ self.assertFalse(tensors.is_tensor_array(constant_op.constant(1)))
+ self.assertFalse(tensors.is_tensor_array(self._simple_list_of_tensors()))
+ self.assertFalse(tensors.is_tensor_array(None))
+
+ def test_is_tensor_list(self):
+ self.assertFalse(tensors.is_tensor_list(self._simple_tensor_array()))
+ self.assertTrue(tensors.is_tensor_list(self._simple_tensor_list()))
+ self.assertFalse(tensors.is_tensor_list(constant_op.constant(1)))
+ self.assertFalse(tensors.is_tensor_list(self._simple_list_of_tensors()))
+ self.assertFalse(tensors.is_tensor_list(None))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
index 68ead2f760..9afe3df585 100644
--- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
@@ -14,8 +14,6 @@
# ==============================================================================
"""Monte Carlo integration and helpers.
-See the @{$python/contrib.bayesflow.monte_carlo} guide.
-
@@expectation
@@expectation_importance_sampler
@@expectation_importance_sampler_logspace
diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md
index b9abfa8295..f33eaf7e3d 100644
--- a/tensorflow/contrib/bigtable/README.md
+++ b/tensorflow/contrib/bigtable/README.md
@@ -324,8 +324,14 @@ If you encounter a log line that includes the following:
"filename":"/usr/share/grpc/roots.pem"
```
-you likely need to copy the [gRPC `roots.pem` file][grpcPem] to
-`/usr/share/grpc/roots.pem` on your local machine.
+you can resolve it with either of the following approaches:
+
+* copy the [gRPC `roots.pem` file][grpcPem] to
+  `/usr/share/grpc/roots.pem` on your local machine, which is the default
+  location where gRPC looks for this file, or
+* export the environment variable `GRPC_DEFAULT_SSL_ROOTS_FILE_PATH` to point
+  to the full path of the gRPC `roots.pem` file on your file system if it is
+  in a different location.
[grpcPem]: https://github.com/grpc/grpc/blob/master/etc/roots.pem
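
For reference: the environment-variable route can also be taken from Python, as long as it happens before gRPC creates its first secure channel. A sketch, where the path is an assumption to be replaced with the actual location of your `roots.pem`:

```python
# Sketch: point gRPC at a roots.pem outside /usr/share/grpc. The path is a
# placeholder; use wherever your copy of the file actually lives.
import os

os.environ['GRPC_DEFAULT_SSL_ROOTS_FILE_PATH'] = '/opt/certs/roots.pem'

# Set this before constructing any Bigtable client or secure channel;
# gRPC reads the variable when the first secure connection is created.
```
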
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
index a25a641cdb..6138d79126 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
@@ -172,6 +172,11 @@ class BigtableTableOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("BigtableTable").Device(DEVICE_CPU),
BigtableTableOp);
+} // namespace
+
+namespace data {
+namespace {
+
class ToBigtableOp : public AsyncOpKernel {
public:
explicit ToBigtableOp(OpKernelConstruction* ctx)
@@ -354,5 +359,6 @@ REGISTER_KERNEL_BUILDER(Name("DatasetToBigtable").Device(DEVICE_CPU),
ToBigtableOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h
index a2a5df1037..4652021fec 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h
@@ -79,6 +79,8 @@ class BigtableTableResource : public ResourceBase {
::google::cloud::bigtable::noex::Table table_;
};
+namespace data {
+
// BigtableReaderDatasetIterator is an abstract class for iterators from
// datasets that are "readers" (source datasets, not transformation datasets)
// that read from Bigtable.
@@ -138,6 +140,8 @@ class BigtableReaderDatasetIterator : public DatasetIterator<Dataset> {
::google::cloud::bigtable::RowReader::iterator iterator_ GUARDED_BY(mu_);
};
+} // namespace data
+
} // namespace tensorflow
#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
index bd32672aa9..11f530e82a 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
+namespace data {
namespace {
class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
@@ -226,4 +227,5 @@ REGISTER_KERNEL_BUILDER(Name("BigtableLookupDataset").Device(DEVICE_CPU),
BigtableLookupDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
index a803fdcb49..5cab729d9c 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
+namespace data {
namespace {
class BigtablePrefixKeyDatasetOp : public DatasetOpKernel {
@@ -111,4 +112,5 @@ REGISTER_KERNEL_BUILDER(Name("BigtablePrefixKeyDataset").Device(DEVICE_CPU),
BigtablePrefixKeyDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
index 5cd0371c79..4dc4647bd2 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
+namespace data {
namespace {
class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
@@ -117,4 +118,5 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("BigtableRangeKeyDataset").Device(DEVICE_CPU),
BigtableRangeKeyDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
index 6928d9423c..736775bdac 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
+namespace data {
namespace {
class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
@@ -205,4 +206,5 @@ REGISTER_KERNEL_BUILDER(
BigtableSampleKeyPairsDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
index a759fb5063..208b7b3e08 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
+namespace data {
namespace {
class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
@@ -118,4 +119,5 @@ REGISTER_KERNEL_BUILDER(Name("BigtableSampleKeysDataset").Device(DEVICE_CPU),
BigtableSampleKeysDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
index 78a920b077..9407855fe8 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
+namespace data {
namespace {
class BigtableScanDatasetOp : public DatasetOpKernel {
@@ -224,4 +225,5 @@ REGISTER_KERNEL_BUILDER(Name("BigtableScanDataset").Device(DEVICE_CPU),
BigtableScanDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD
index 8eac1243ef..f03eab510c 100644
--- a/tensorflow/contrib/boosted_trees/BUILD
+++ b/tensorflow/contrib/boosted_trees/BUILD
@@ -445,6 +445,7 @@ tf_kernel_library(
"//tensorflow/contrib/boosted_trees/proto:learner_proto_cc",
"//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc",
"//tensorflow/contrib/boosted_trees/proto:split_info_proto_cc",
+ "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
"//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource",
"//tensorflow/contrib/boosted_trees/resources:quantile_stream_resource",
"//tensorflow/core:framework_headers_lib",
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
index 870ce2442b..4c7a538b38 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
@@ -52,7 +52,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
center_bias=True,
use_core_libs=False,
output_leaf_index=False,
- override_global_step_value=None):
+ override_global_step_value=None,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeClassifier estimator instance.
Args:
@@ -94,6 +95,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
trees were trained), this parameter can be used to set the global step
to a large value, making it look like that number of training steps ran.
If None, no override of global step will happen.
+ num_quantiles: Number of quantiles to build for numeric feature values.
Raises:
ValueError: If learner_config is not valid.
@@ -134,7 +136,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
'logits_modifier_function': logits_modifier_function,
'use_core_libs': use_core_libs,
'output_leaf_index': output_leaf_index,
- 'override_global_step_value': override_global_step_value
+ 'override_global_step_value': override_global_step_value,
+ 'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
@@ -159,7 +162,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
center_bias=True,
use_core_libs=False,
output_leaf_index=False,
- override_global_step_value=None):
+ override_global_step_value=None,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeRegressor estimator instance.
Args:
@@ -201,6 +205,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
trees were trained), this parameter can be used to set the global step
to a large value, making it look like that number of training steps ran.
If None, no override of global step will happen.
+ num_quantiles: Number of quantiles to build for numeric feature values.
"""
head = head_lib.regression_head(
label_name=label_name,
@@ -224,7 +229,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
'center_bias': center_bias,
'use_core_libs': use_core_libs,
'output_leaf_index': False,
- 'override_global_step_value': override_global_step_value
+ 'override_global_step_value': override_global_step_value,
+ 'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
@@ -251,7 +257,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
center_bias=True,
use_core_libs=False,
output_leaf_index=False,
- override_global_step_value=None):
+ override_global_step_value=None,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeEstimator estimator instance.
Args:
@@ -289,6 +296,7 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
trees were trained), this parameter can be used to set the global step
to a large value, making it look like that number of training steps ran.
If None, no override of global step will happen.
+ num_quantiles: Number of quantiles to build for numeric feature values.
"""
super(GradientBoostedDecisionTreeEstimator, self).__init__(
model_fn=model.model_builder,
@@ -303,7 +311,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
'center_bias': center_bias,
'use_core_libs': use_core_libs,
'output_leaf_index': False,
- 'override_global_step_value': override_global_step_value
+ 'override_global_step_value': override_global_step_value,
+ 'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
@@ -329,7 +338,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
center_bias=False,
use_core_libs=False,
output_leaf_index=False,
- override_global_step_value=None):
+ override_global_step_value=None,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeRanker instance.
This is an estimator that can be trained off the pairwise data and can be
@@ -377,6 +387,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
trees were trained), this parameter can be used to set the global step
to a large value, making it look like that number of training steps ran.
If None, no override of global step will happen.
+ num_quantiles: Number of quantiles to build for numeric feature values.
+
Raises:
ValueError: If learner_config is not valid.
"""
@@ -395,7 +407,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
'use_core_libs': use_core_libs,
'output_leaf_index': output_leaf_index,
'ranking_model_pair_keys': ranking_model_pair_keys,
- 'override_global_step_value': override_global_step_value
+ 'override_global_step_value': override_global_step_value,
+ 'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
@@ -444,7 +457,8 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
feature_engineering_fn=None,
logits_modifier_function=None,
center_bias=True,
- output_leaf_index=False):
+ output_leaf_index=False,
+ num_quantiles=100):
"""Initializes a core version of GradientBoostedDecisionTreeEstimator.
Args:
@@ -474,6 +488,7 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
for example_prediction_result in result_dict:
# access leaf index list by example_prediction_result["leaf_index"]
# which contains one leaf index per tree
+ num_quantiles: Number of quantiles to build for numeric feature values.
"""
def _model_fn(features, labels, mode, config):
@@ -493,7 +508,8 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
'logits_modifier_function': logits_modifier_function,
'use_core_libs': True,
'output_leaf_index': output_leaf_index,
- 'override_global_step_value': None
+ 'override_global_step_value': None,
+ 'num_quantiles': num_quantiles,
},
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
@@ -517,7 +533,8 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
label_keys=None,
logits_modifier_function=None,
center_bias=False,
- output_leaf_index=False):
+ output_leaf_index=False,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeRanker instance.
This is an estimator that can be trained off the pairwise data and can be
@@ -552,6 +569,7 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
for result_dict in result_iter:
# access leaf index list by result_dict["leaf_index"]
# which contains one leaf index per tree
+ num_quantiles: Number of quantiles to build for numeric feature values.
Raises:
ValueError: If learner_config is not valid.
@@ -576,7 +594,8 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
'use_core_libs': True,
'output_leaf_index': output_leaf_index,
'ranking_model_pair_keys': ranking_model_pair_keys,
- 'override_global_step_value': None
+ 'override_global_step_value': None,
+ 'num_quantiles': num_quantiles,
},
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
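
For context: the new `num_quantiles` argument is threaded through every constructor above into `params`, and from there into the GBDT model builder. A hedged usage sketch; the learner configuration values here are illustrative assumptions:

```python
# Sketch: passing the new num_quantiles knob (default 100) to one of the
# estimators above. The learner_config values are illustrative only.
from tensorflow.contrib.boosted_trees.estimator_batch import estimator
from tensorflow.contrib.boosted_trees.proto import learner_pb2

learner_config = learner_pb2.LearnerConfig()
learner_config.num_classes = 2
learner_config.constraints.max_tree_depth = 3

classifier = estimator.GradientBoostedDecisionTreeClassifier(
    learner_config=learner_config,
    examples_per_layer=1000,
    num_trees=10,
    # Coarser quantile sketches for numeric features than the default of 100.
    num_quantiles=50)
```
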
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
index 68d710d713..c155128c0e 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
@@ -16,7 +16,10 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+
import tempfile
+import numpy as np
+
from tensorflow.contrib.boosted_trees.estimator_batch import estimator
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.layers.python.layers import feature_column as contrib_feature_column
@@ -26,6 +29,7 @@ from tensorflow.python.feature_column import feature_column_lib as core_feature_
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import gfile
@@ -473,6 +477,63 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase):
classifier.evaluate(input_fn=_multiclass_train_input_fn, steps=1)
classifier.predict(input_fn=_eval_input_fn)
+ def testWeightedCategoricalColumn(self):
+ head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 1
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ feature_columns = [
+ core_feature_column.weighted_categorical_column(
+ categorical_column=core_feature_column.
+ categorical_column_with_vocabulary_list(
+ key="word", vocabulary_list=["the", "cat", "dog"]),
+ weight_feature_key="weight")
+ ]
+
+ labels = np.array([[1], [1], [0], [0.]], dtype=np.float32)
+
+ def _make_input_fn():
+
+ def _input_fn():
+ features_dict = {}
+ # Sparse tensor representing
+ # example 0: "cat","the"
+        # example 1: "dog"
+ # example 2: -
+ # example 3: "the"
+        # Weights for the words are: cat - 5, dog - 6, and the - 1.
+ features_dict["word"] = sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1], [1, 0], [3, 0]],
+ values=constant_op.constant(
+ ["the", "cat", "dog", "the"], dtype=dtypes.string),
+ dense_shape=[4, 3])
+ features_dict["weight"] = sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1], [1, 0], [3, 0]],
+ values=[1., 5., 6., 1.],
+ dense_shape=[4, 3])
+ return features_dict, labels
+
+ return _input_fn
+
+ est = estimator.CoreGradientBoostedDecisionTreeEstimator(
+ head=head_fn,
+ learner_config=learner_config,
+ num_trees=1,
+ examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=feature_columns)
+
+ input_fn = _make_input_fn()
+ est.train(input_fn=input_fn, steps=100)
+ est.evaluate(input_fn=input_fn, steps=1)
+ est.predict(input_fn=input_fn)
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
index 04b46c3483..a6e422847d 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
@@ -81,6 +81,7 @@ def model_builder(features,
logits_modifier_function = params["logits_modifier_function"]
output_leaf_index = params["output_leaf_index"]
override_global_step_value = params.get("override_global_step_value", None)
+ num_quantiles = params["num_quantiles"]
if features is None:
raise ValueError("At least one feature must be specified.")
@@ -116,7 +117,8 @@ def model_builder(features,
logits_dimension=head.logits_dimension,
features=training_features,
use_core_columns=use_core_libs,
- output_leaf_index=output_leaf_index)
+ output_leaf_index=output_leaf_index,
+ num_quantiles=num_quantiles)
with ops.name_scope("gbdt", "gbdt_optimizer"):
predictions_dict = gbdt_model.predict(mode)
logits = predictions_dict["predictions"]
@@ -237,6 +239,7 @@ def ranking_model_builder(features,
output_leaf_index = params["output_leaf_index"]
ranking_model_pair_keys = params["ranking_model_pair_keys"]
override_global_step_value = params.get("override_global_step_value", None)
+ num_quantiles = params["num_quantiles"]
if features is None:
raise ValueError("At least one feature must be specified.")
@@ -299,7 +302,8 @@ def ranking_model_builder(features,
logits_dimension=head.logits_dimension,
features=main_features,
use_core_columns=use_core_libs,
- output_leaf_index=output_leaf_index)
+ output_leaf_index=output_leaf_index,
+ num_quantiles=num_quantiles)
with ops.name_scope("gbdt", "gbdt_optimizer"):
# Logits for inference.
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index d9e7a0f466..3b28ed77f3 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
+#include <limits>
#include <memory>
#include <string>
#include <vector>
@@ -325,13 +326,21 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
}
float best_gain = std::numeric_limits<float>::lowest();
- int64 best_bucket_idx = 0;
+ int64 best_bucket_id = 0;
std::vector<NodeStats> best_right_node_stats(num_elements, NodeStats(0));
std::vector<NodeStats> best_left_node_stats(num_elements, NodeStats(0));
std::vector<NodeStats> current_left_node_stats(num_elements, NodeStats(0));
std::vector<NodeStats> current_right_node_stats(num_elements, NodeStats(0));
- int64 current_bucket_id = 0;
+ int64 current_bucket_id = std::numeric_limits<int64>::max();
int64 last_bucket_id = -1;
+    // Find the lowest bucket id; this is the first bucket id to try.
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ const int start_index = partition_boundaries[root_idx];
+ if (bucket_ids(start_index, 0) < current_bucket_id) {
+ current_bucket_id = bucket_ids(start_index, 0);
+ }
+ }
// Index offsets for each of the partitions, used to access the gradients
// of a partition for the current bucket under consideration.
std::vector<int> current_layer_offsets(num_elements, 0);
@@ -373,6 +382,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
best_gain = gain_of_split;
best_left_node_stats = current_left_node_stats;
best_right_node_stats = current_right_node_stats;
+ best_bucket_id = current_bucket_id;
}
current_bucket_id = next_bucket_id;
}
@@ -383,22 +393,23 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
best_gain -= num_elements * state->tree_complexity_regularization();
ObliviousSplitInfo oblivious_split_info;
- auto* oblivious_dense_split = oblivious_split_info.mutable_split_node()
- ->mutable_dense_float_binary_split();
+ auto* oblivious_dense_split =
+ oblivious_split_info.mutable_split_node()
+ ->mutable_oblivious_dense_float_binary_split();
oblivious_dense_split->set_feature_column(state->feature_column_group_id());
- oblivious_dense_split->set_threshold(
- bucket_boundaries(bucket_ids(best_bucket_idx, 0)));
+ oblivious_dense_split->set_threshold(bucket_boundaries(best_bucket_id));
(*gains)(0) = best_gain;
for (int root_idx = 0; root_idx < num_elements; root_idx++) {
- auto* left_children = oblivious_split_info.add_children_leaves();
- auto* right_children = oblivious_split_info.add_children_leaves();
+ auto* left_child = oblivious_split_info.add_children();
+ auto* right_child = oblivious_split_info.add_children();
- state->FillLeaf(best_left_node_stats[root_idx], left_children);
- state->FillLeaf(best_right_node_stats[root_idx], right_children);
+ state->FillLeaf(best_left_node_stats[root_idx], left_child);
+ state->FillLeaf(best_right_node_stats[root_idx], right_child);
const int start_index = partition_boundaries[root_idx];
(*output_partition_ids)(root_idx) = partition_ids(start_index);
+ oblivious_split_info.add_children_parent_id(partition_ids(start_index));
}
oblivious_split_info.SerializeToString(&(*output_splits)(0));
}
@@ -728,6 +739,11 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
context->input("bias_feature_id", &bias_feature_id_t));
int64 bias_feature_id = bias_feature_id_t->scalar<int64>()();
+ const Tensor* weak_learner_type_t;
+ OP_REQUIRES_OK(context,
+ context->input("weak_learner_type", &weak_learner_type_t));
+ const int32 weak_learner_type = weak_learner_type_t->scalar<int32>()();
+
// Find the number of unique partitions before we allocate the output.
std::vector<int32> partition_boundaries;
std::vector<int32> non_empty_partitions;
@@ -756,20 +772,63 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
tensorflow::TTypes<int32>::Vec output_partition_ids =
output_partition_ids_t->vec<int32>();
+ // For a normal tree, we output a split per partition. For an oblivious
+ // tree, we output one split for all partitions of the layer.
+ int size_output = num_elements;
+ if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE &&
+ num_elements > 0) {
+ size_output = 1;
+ }
+
Tensor* gains_t = nullptr;
- OP_REQUIRES_OK(
- context, context->allocate_output("gains", TensorShape({num_elements}),
- &gains_t));
+ OP_REQUIRES_OK(context, context->allocate_output(
+ "gains", TensorShape({size_output}), &gains_t));
tensorflow::TTypes<float>::Vec gains = gains_t->vec<float>();
Tensor* output_splits_t = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(
- "split_infos", TensorShape({num_elements}),
- &output_splits_t));
+ OP_REQUIRES_OK(context, context->allocate_output("split_infos",
+ TensorShape({size_output}),
+ &output_splits_t));
tensorflow::TTypes<string>::Vec output_splits =
output_splits_t->vec<string>();
+ if (num_elements == 0) {
+ return;
+ }
SplitBuilderState state(context);
+ switch (weak_learner_type) {
+ case LearnerConfig::NORMAL_DECISION_TREE: {
+ ComputeNormalDecisionTree(
+ context, &state, normalizer_ratio, num_elements,
+ partition_boundaries, non_empty_partitions, bias_feature_id,
+ partition_ids, feature_ids, gradients_t, hessians_t,
+ &output_partition_ids, &gains, &output_splits);
+ break;
+ }
+ case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
+ ComputeObliviousDecisionTree(
+ context, &state, normalizer_ratio, num_elements,
+ partition_boundaries, non_empty_partitions, bias_feature_id,
+ partition_ids, feature_ids, gradients_t, hessians_t,
+ &output_partition_ids, &gains, &output_splits);
+ break;
+ }
+ }
+ }
+
+ private:
+ void ComputeNormalDecisionTree(
+ OpKernelContext* const context, SplitBuilderState* state,
+ const float normalizer_ratio, const int num_elements,
+ const std::vector<int32>& partition_boundaries,
+ const std::vector<int32>& non_empty_partitions,
+ const int64 bias_feature_id,
+ const tensorflow::TTypes<int32>::ConstVec& partition_ids,
+ const tensorflow::TTypes<int64>::ConstMatrix& feature_ids,
+ const Tensor* gradients_t, const Tensor* hessians_t,
+ tensorflow::TTypes<int32>::Vec* output_partition_ids,
+ tensorflow::TTypes<float>::Vec* gains,
+ tensorflow::TTypes<string>::Vec* output_splits) {
for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
float best_gain = std::numeric_limits<float>::lowest();
int start_index = partition_boundaries[non_empty_partitions[root_idx]];
@@ -779,7 +838,7 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
errors::InvalidArgument("Bias feature ID missing."));
GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index);
root_gradient_stats *= normalizer_ratio;
- NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats);
+ NodeStats root_stats = state->ComputeNodeStats(root_gradient_stats);
int32 best_feature_idx = 0;
NodeStats best_right_node_stats(0);
NodeStats best_left_node_stats(0);
@@ -790,8 +849,8 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
left_gradient_stats *= normalizer_ratio;
GradientStats right_gradient_stats =
root_gradient_stats - left_gradient_stats;
- NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats);
- NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats);
+ NodeStats left_stats = state->ComputeNodeStats(left_gradient_stats);
+ NodeStats right_stats = state->ComputeNodeStats(right_gradient_stats);
if (left_stats.gain + right_stats.gain > best_gain) {
best_gain = left_stats.gain + right_stats.gain;
best_left_node_stats = left_stats;
@@ -802,17 +861,132 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
SplitInfo split_info;
auto* equality_split = split_info.mutable_split_node()
->mutable_categorical_id_binary_split();
- equality_split->set_feature_column(state.feature_column_group_id());
+ equality_split->set_feature_column(state->feature_column_group_id());
equality_split->set_feature_id(feature_ids(best_feature_idx, 0));
auto* left_child = split_info.mutable_left_child();
auto* right_child = split_info.mutable_right_child();
- state.FillLeaf(best_left_node_stats, left_child);
- state.FillLeaf(best_right_node_stats, right_child);
- split_info.SerializeToString(&output_splits(root_idx));
- gains(root_idx) =
- best_gain - root_stats.gain - state.tree_complexity_regularization();
- output_partition_ids(root_idx) = partition_ids(start_index);
+ state->FillLeaf(best_left_node_stats, left_child);
+ state->FillLeaf(best_right_node_stats, right_child);
+ split_info.SerializeToString(&(*output_splits)(root_idx));
+ (*gains)(root_idx) =
+ best_gain - root_stats.gain - state->tree_complexity_regularization();
+ (*output_partition_ids)(root_idx) = partition_ids(start_index);
+ }
+ }
+
+ void ComputeObliviousDecisionTree(
+ OpKernelContext* const context, SplitBuilderState* state,
+ const float normalizer_ratio, const int num_elements,
+ const std::vector<int32>& partition_boundaries,
+ const std::vector<int32>& non_empty_partitions,
+ const int64 bias_feature_id,
+ const tensorflow::TTypes<int32>::ConstVec& partition_ids,
+ const tensorflow::TTypes<int64>::ConstMatrix& feature_ids,
+ const Tensor* gradients_t, const Tensor* hessians_t,
+ tensorflow::TTypes<int32>::Vec* output_partition_ids,
+ tensorflow::TTypes<float>::Vec* gains,
+ tensorflow::TTypes<string>::Vec* output_splits) {
+    // Holds the root stats for each node to be split.
+ std::vector<GradientStats> current_layer_stats;
+ current_layer_stats.reserve(num_elements);
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ const int start_index = partition_boundaries[root_idx];
+ // First feature ID in each partition should be the bias feature.
+ OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id,
+ errors::InvalidArgument("Bias feature ID missing."));
+ GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index);
+ root_gradient_stats *= normalizer_ratio;
+ current_layer_stats.push_back(root_gradient_stats);
}
+ float best_gain = std::numeric_limits<float>::lowest();
+ int64 best_feature_id = 0;
+ std::vector<NodeStats> best_right_node_stats(num_elements, NodeStats(0));
+ std::vector<NodeStats> best_left_node_stats(num_elements, NodeStats(0));
+ std::vector<NodeStats> current_left_node_stats(num_elements, NodeStats(0));
+ std::vector<NodeStats> current_right_node_stats(num_elements, NodeStats(0));
+ int64 current_feature_id = std::numeric_limits<int64>::max();
+ int64 last_feature_id = -1;
+    // Find the lowest feature id; this is the first feature id to try.
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ const int start_index = partition_boundaries[root_idx];
+ if (feature_ids(start_index + 1, 0) < current_feature_id) {
+ current_feature_id = feature_ids(start_index + 1, 0);
+ }
+ }
+    // Index offsets for each of the partitions, used to access the
+    // gradients of a partition for the current feature under consideration.
+    // Start at one because the zero index is for the bias.
+ std::vector<int> current_layer_offsets(num_elements, 1);
+    // The idea is to try every feature id in increasing order. In each
+    // iteration we calculate the gain of the layer using the current feature
+    // id as the split value, and we also obtain the next feature id to try.
+ while (current_feature_id > last_feature_id) {
+ last_feature_id = current_feature_id;
+ int64 next_feature_id = -1;
+ // Left gradient stats per node.
+ std::vector<GradientStats> left_gradient_stats(num_elements);
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ int idx =
+ current_layer_offsets[root_idx] + partition_boundaries[root_idx];
+ const int end_index = partition_boundaries[root_idx + 1];
+ if (idx < end_index && feature_ids(idx, 0) == current_feature_id) {
+ GradientStats g(*gradients_t, *hessians_t, idx);
+ g *= normalizer_ratio;
+ left_gradient_stats[root_idx] = g;
+ current_layer_offsets[root_idx]++;
+ idx++;
+ }
+ if (idx < end_index &&
+ (feature_ids(idx, 0) < next_feature_id || next_feature_id == -1)) {
+ next_feature_id = feature_ids(idx, 0);
+ }
+ }
+ float gain_of_split = 0.0;
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ GradientStats right_gradient_stats =
+ current_layer_stats[root_idx] - left_gradient_stats[root_idx];
+ NodeStats left_stat =
+ state->ComputeNodeStats(left_gradient_stats[root_idx]);
+ NodeStats right_stat = state->ComputeNodeStats(right_gradient_stats);
+ gain_of_split += left_stat.gain + right_stat.gain;
+ current_left_node_stats[root_idx] = left_stat;
+ current_right_node_stats[root_idx] = right_stat;
+ }
+ if (gain_of_split > best_gain) {
+ best_gain = gain_of_split;
+ best_left_node_stats = current_left_node_stats;
+ best_right_node_stats = current_right_node_stats;
+ best_feature_id = current_feature_id;
+ }
+ current_feature_id = next_feature_id;
+ }
+
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ best_gain -= state->ComputeNodeStats(current_layer_stats[root_idx]).gain;
+ }
+ best_gain -= num_elements * state->tree_complexity_regularization();
+
+ ObliviousSplitInfo oblivious_split_info;
+ auto* equality_split =
+ oblivious_split_info.mutable_split_node()
+ ->mutable_oblivious_categorical_id_binary_split();
+ equality_split->set_feature_column(state->feature_column_group_id());
+ equality_split->set_feature_id(best_feature_id);
+ (*gains)(0) = best_gain;
+
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ auto* left_child = oblivious_split_info.add_children();
+ auto* right_child = oblivious_split_info.add_children();
+
+ state->FillLeaf(best_left_node_stats[root_idx], left_child);
+ state->FillLeaf(best_right_node_stats[root_idx], right_child);
+
+ const int start_index = partition_boundaries[root_idx];
+ (*output_partition_ids)(root_idx) = partition_ids(start_index);
+ oblivious_split_info.add_children_parent_id(partition_ids(start_index));
+ }
+ oblivious_split_info.SerializeToString(&(*output_splits)(0));
}
};
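
For context: unlike a normal tree, an oblivious layer must pick a single feature id shared by every partition (node) of the layer, so `ComputeObliviousDecisionTree` sums per-partition gains for each candidate feature id and keeps the best total. A simplified Python sketch of that search, with gains given per partition rather than derived from gradient statistics as the kernel does:

```python
# Sketch of the layer-wide search in ComputeObliviousDecisionTree. Gains are
# supplied per partition here; the kernel derives them from NodeStats instead.
def best_oblivious_split(candidates_per_partition):
  """candidates_per_partition: one {feature_id: gain} dict per partition."""
  all_feature_ids = set()
  for candidates in candidates_per_partition:
    all_feature_ids.update(candidates)
  best_gain, best_feature_id = float('-inf'), None
  for feature_id in sorted(all_feature_ids):  # increasing id order
    # Partitions without stats for this feature are treated as zero gain
    # here; the kernel instead keeps an empty left child for them.
    gain = sum(c.get(feature_id, 0.0) for c in candidates_per_partition)
    if gain > best_gain:
      best_gain, best_feature_id = gain, feature_id
  return best_feature_id, best_gain

fid, gain = best_oblivious_split([{1: 0.5, 2: 0.1}, {2: 0.9}])
assert fid == 2 and abs(gain - 1.0) < 1e-9
```
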
diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc
index 6d9a6ee5a0..ab2853352a 100644
--- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc
@@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
+#include <vector>
+
#include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h"
#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h"
#include "tensorflow/contrib/boosted_trees/proto/split_info.pb.h"
+#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"
#include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -26,6 +29,7 @@ namespace boosted_trees {
namespace {
+using boosted_trees::learner::LearnerConfig;
using boosted_trees::learner::LearningRateConfig;
using boosted_trees::trees::Leaf;
using boosted_trees::trees::TreeNode;
@@ -42,6 +46,9 @@ struct SplitCandidate {
// Split info.
learner::SplitInfo split_info;
+
+ // Oblivious split info.
+ learner::ObliviousSplitInfo oblivious_split_info;
};
// Checks that the leaf is not empty.
@@ -343,7 +350,12 @@ class GrowTreeEnsembleOp : public OpKernel {
OP_REQUIRES_OK(context, context->input("learning_rate", &learning_rate_t));
float learning_rate = learning_rate_t->scalar<float>()();
- // Read seed that was used for dropout.
+ // Read the weak learner type to use.
+ const Tensor* weak_learner_type_t;
+ OP_REQUIRES_OK(context,
+ context->input("weak_learner_type", &weak_learner_type_t));
+ const int32 weak_learner_type = weak_learner_type_t->scalar<int32>()();
+
const Tensor* seed_t;
OP_REQUIRES_OK(context, context->input("dropout_seed", &seed_t));
// Cast seed to uint64.
@@ -363,9 +375,18 @@ class GrowTreeEnsembleOp : public OpKernel {
// Find best splits for each active partition.
std::map<int32, SplitCandidate> best_splits;
- FindBestSplitsPerPartition(context, partition_ids_list, gains_list,
- splits_list, &best_splits);
-
+ switch (weak_learner_type) {
+ case LearnerConfig::NORMAL_DECISION_TREE: {
+ FindBestSplitsPerPartitionNormal(context, partition_ids_list,
+ gains_list, splits_list, &best_splits);
+ break;
+ }
+ case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
+ FindBestSplitsPerPartitionOblivious(context, gains_list, splits_list,
+ &best_splits);
+ break;
+ }
+ }
// No-op if no new splits can be considered.
if (best_splits.empty()) {
LOG(WARNING) << "Not growing tree ensemble as no good splits were found.";
@@ -377,25 +398,34 @@ class GrowTreeEnsembleOp : public OpKernel {
OP_REQUIRES_OK(context,
context->input("max_tree_depth", &max_tree_depth_t));
const int32 max_tree_depth = max_tree_depth_t->scalar<int32>()();
-
// Update and retrieve the growable tree.
// If the tree is fully built and dropout was applied, it also adjusts the
// weights of dropped and the last tree.
boosted_trees::trees::DecisionTreeConfig* const tree_config =
UpdateAndRetrieveGrowableTree(ensemble_resource, learning_rate,
- dropout_seed, max_tree_depth);
-
+ dropout_seed, max_tree_depth,
+ weak_learner_type);
// Split tree nodes.
- for (auto& split_entry : best_splits) {
- SplitTreeNode(split_entry.first, &split_entry.second, tree_config,
- ensemble_resource);
+ switch (weak_learner_type) {
+ case LearnerConfig::NORMAL_DECISION_TREE: {
+ for (auto& split_entry : best_splits) {
+ SplitTreeNode(split_entry.first, &split_entry.second, tree_config,
+ ensemble_resource);
+ }
+ break;
+ }
+ case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
+ SplitTreeLayer(&best_splits[0], tree_config, ensemble_resource);
+ }
}
-
// Post-prune finalized tree if needed.
if (learner_config_.pruning_mode() ==
boosted_trees::learner::LearnerConfig::POST_PRUNE &&
ensemble_resource->LastTreeMetadata()->is_finalized()) {
VLOG(2) << "Post-pruning finalized tree.";
+ if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE) {
+ LOG(FATAL) << "Post-prunning is not implemented for Oblivious trees.";
+ }
PruneTree(tree_config);
// If after post-pruning the whole tree has no gain, remove the tree
@@ -409,10 +439,9 @@ class GrowTreeEnsembleOp : public OpKernel {
private:
// Helper method which effectively does a reduce over all split candidates
// and finds the best split for each partition.
- void FindBestSplitsPerPartition(
- OpKernelContext* const context,
- const OpInputList& partition_ids_list, const OpInputList& gains_list,
- const OpInputList& splits_list,
+ void FindBestSplitsPerPartitionNormal(
+ OpKernelContext* const context, const OpInputList& partition_ids_list,
+ const OpInputList& gains_list, const OpInputList& splits_list,
std::map<int32, SplitCandidate>* best_splits) {
// Find best split per partition going through every feature candidate.
// TODO(salehay): Is this worth parallelizing?
@@ -446,6 +475,90 @@ class GrowTreeEnsembleOp : public OpKernel {
}
}
+ void FindBestSplitsPerPartitionOblivious(
+ OpKernelContext* const context, const OpInputList& gains_list,
+ const OpInputList& splits_list,
+ std::map<int32, SplitCandidate>* best_splits) {
+ // Find the best layer split by going through every feature candidate.
+ for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) {
+ const auto& gains = gains_list[handler_id].vec<float>();
+ const auto& splits = splits_list[handler_id].vec<string>();
+ OP_REQUIRES(context, gains.size() == 1,
+ errors::InvalidArgument(
+ "Gains size must be one for oblivious weak learner: ",
+ gains.size(), " != ", 1));
+ OP_REQUIRES(context, splits.size() == 1,
+ errors::InvalidArgument(
+ "Splits size must be one for oblivious weak learner: ",
+ splits.size(), " != ", 1));
+ // Get current split candidate.
+ const auto& gain = gains(0);
+ const auto& serialized_split = splits(0);
+ SplitCandidate split;
+ split.handler_id = handler_id;
+ split.gain = gain;
+ OP_REQUIRES(
+ context, split.oblivious_split_info.ParseFromString(serialized_split),
+ errors::InvalidArgument("Unable to parse oblivious split info."));
+
+ auto split_info = split.oblivious_split_info;
+ CHECK(split_info.children_size() % 2 == 0)
+ << "The oblivious split should generate an even number of children: "
+ << split_info.children_size();
+
+ // If every node is pure, then we shouldn't split.
+ bool only_pure_nodes = true;
+ for (int idx = 0; idx < split_info.children_size(); idx += 2) {
+ if (IsLeafWellFormed(*split_info.mutable_children(idx)) &&
+ IsLeafWellFormed(*split_info.mutable_children(idx + 1))) {
+ only_pure_nodes = false;
+ break;
+ }
+ }
+ if (only_pure_nodes) {
+ VLOG(1) << "The oblivious split does not actually split anything.";
+ continue;
+ }
+
+ // Don't consider negative splits if we're pre-pruning the tree.
+ if (learner_config_.pruning_mode() == learner::LearnerConfig::PRE_PRUNE &&
+ gain < 0) {
+ continue;
+ }
+
+ // Take the split if we don't have a candidate yet.
+ auto best_split_it = best_splits->find(0);
+ if (best_split_it == best_splits->end()) {
+ best_splits->insert(std::make_pair(0, std::move(split)));
+ continue;
+ }
+
+ // Determine if we should update best split.
+ SplitCandidate& best_split = best_split_it->second;
+ trees::TreeNode current_node = split_info.split_node();
+ trees::TreeNode best_node = best_split.oblivious_split_info.split_node();
+ if (TF_PREDICT_FALSE(gain == best_split.gain)) {
+ // Tie break on node case preferring simpler tree node types.
+ VLOG(2) << "Attempting to tie break with smaller node case. "
+ << "(current split: " << current_node.node_case()
+ << ", best split: " << best_node.node_case() << ")";
+ if (current_node.node_case() < best_node.node_case()) {
+ best_split = std::move(split);
+ } else if (current_node.node_case() == best_node.node_case()) {
+ // Tie break on handler Id.
+ VLOG(2) << "Tie breaking with higher handler Id. "
+ << "(current split: " << handler_id
+ << ", best split: " << best_split.handler_id << ")";
+ if (handler_id > best_split.handler_id) {
+ best_split = std::move(split);
+ }
+ }
+ } else if (gain > best_split.gain) {
+ best_split = std::move(split);
+ }
+ }
+ }
+
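The tie-breaking above keeps candidate selection deterministic across handlers. As a compact sketch (hypothetical helper, not the kernel's API): higher gain wins; on equal gain the smaller proto node_case, i.e. the simpler node type, wins; if that also ties, the higher handler id wins.

```python
def is_better_candidate(gain, node_case, handler_id, best):
    # best is a (gain, node_case, handler_id) tuple for the incumbent split.
    best_gain, best_case, best_handler = best
    if gain != best_gain:
        return gain > best_gain
    if node_case != best_case:
        return node_case < best_case  # prefer simpler tree node types
    return handler_id > best_handler
```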
void UpdateTreeWeightsIfDropout(
boosted_trees::models::DecisionTreeEnsembleResource* const
ensemble_resource,
@@ -501,7 +614,7 @@ class GrowTreeEnsembleOp : public OpKernel {
boosted_trees::models::DecisionTreeEnsembleResource* const
ensemble_resource,
const float learning_rate, const uint64 dropout_seed,
- const int32 max_tree_depth) {
+ const int32 max_tree_depth, const int32 weak_learner_type) {
const auto num_trees = ensemble_resource->num_trees();
if (num_trees <= 0 ||
ensemble_resource->LastTreeMetadata()->is_finalized()) {
@@ -647,6 +760,71 @@ class GrowTreeEnsembleOp : public OpKernel {
}
}
+ void SplitTreeLayer(
+ SplitCandidate* split,
+ boosted_trees::trees::DecisionTreeConfig* tree_config,
+ boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) {
+ int depth = 0;
+ while (depth < tree_config->nodes_size() &&
+ tree_config->nodes(depth).node_case() != TreeNode::kLeaf) {
+ depth++;
+ }
+ CHECK(tree_config->nodes_size() > 0)
+ << "A tree must have at least one dummy leaf.";
+ // The number of new children.
+ int num_children = 1 << (depth + 1);
+ auto split_info = split->oblivious_split_info;
+ CHECK(num_children >= split_info.children_size())
+ << "Too many new children, expected <= " << num_children << " and got "
+ << split_info.children_size();
+ std::vector<trees::Leaf> new_leaves;
+ new_leaves.reserve(num_children);
+ int next_id = 0;
+ for (int idx = 0; idx < num_children / 2; idx++) {
+ trees::Leaf old_leaf =
+ *tree_config->mutable_nodes(depth + idx)->mutable_leaf();
+ // Check if a split was made for this leaf.
+ if (next_id < split_info.children_parent_id_size() &&
+ depth + idx == split_info.children_parent_id(next_id)) {
+ // Add left leaf.
+ new_leaves.push_back(*MergeLeafWeights(
+ old_leaf, split_info.mutable_children(2 * next_id)));
+ // Add right leaf.
+ new_leaves.push_back(*MergeLeafWeights(
+ old_leaf, split_info.mutable_children(2 * next_id + 1)));
+ next_id++;
+ } else {
+ // If there is no split for this leaf, just duplicate it.
+ new_leaves.push_back(old_leaf);
+ new_leaves.push_back(old_leaf);
+ }
+ }
+ CHECK(next_id == split_info.children_parent_id_size());
+ TreeNodeMetadata* split_metadata =
+ split_info.mutable_split_node()->mutable_node_metadata();
+ split_metadata->set_gain(split->gain);
+
+ TreeNode new_split = *split_info.mutable_split_node();
+ // Move old children to metadata.
+ for (int idx = depth; idx < tree_config->nodes_size(); idx++) {
+ *new_split.mutable_node_metadata()->add_original_oblivious_leaves() =
+ *tree_config->mutable_nodes(idx)->mutable_leaf();
+ }
+ // Add the new split to the tree_config in place before the children start.
+ *tree_config->mutable_nodes(depth) = new_split;
+ // Add the new children.
+ int nodes_size = tree_config->nodes_size();
+ for (int idx = 0; idx < num_children; idx++) {
+ if (idx + depth + 1 < nodes_size) {
+ // Update leaves that were already there.
+ *tree_config->mutable_nodes(idx + depth + 1)->mutable_leaf() =
+ new_leaves[idx];
+ } else {
+ // Add new leaves.
+ *tree_config->add_nodes()->mutable_leaf() = new_leaves[idx];
+ }
+ }
+ }
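SplitTreeLayer above turns the current layer of leaves into one split node plus twice as many leaves. A compact sketch of that reshaping, assuming scalar leaf weights so that MergeLeafWeights reduces to addition (names hypothetical):

```python
def split_tree_layer(nodes, num_splits, children, children_parent_id, new_split):
    """nodes: num_splits split nodes followed by the current layer of leaves."""
    old_leaves = nodes[num_splits:]
    new_leaves, next_id = [], 0
    for idx, leaf in enumerate(old_leaves):
        node_id = num_splits + idx
        if next_id < len(children_parent_id) and children_parent_id[next_id] == node_id:
            # This leaf was split: merge its weight into both new children.
            new_leaves += [leaf + children[2 * next_id],
                           leaf + children[2 * next_id + 1]]
            next_id += 1
        else:
            # No stats for this leaf: duplicate it so the layer stays complete.
            new_leaves += [leaf, leaf]
    return nodes[:num_splits] + [new_split] + new_leaves
```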
void PruneTree(boosted_trees::trees::DecisionTreeConfig* tree_config) {
// No-op if tree is empty.
if (tree_config->nodes_size() <= 0) {
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
index efe29216c2..35d727482b 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler
+from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops
from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops
from tensorflow.python.framework import constant_op
@@ -46,6 +47,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
multiclass_strategy,
init_stamp_token=0,
loss_uses_sum_reduction=False,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE,
name=None):
"""Initialize the internal state for this split handler.
@@ -66,6 +68,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
stamped objects.
loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
SUM or MEAN reduction was used for the loss.
+ weak_learner_type: Specifies the type of weak learner to use.
name: An optional handler name.
"""
super(EqualitySplitHandler, self).__init__(
@@ -85,6 +88,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
hessian_shape,
name="StatsAccumulator/{}".format(self._name))
self._sparse_int_column = sparse_int_column
+ self._weak_learner_type = weak_learner_type
def update_stats(self, stamp_token, example_partition_ids, gradients,
hessians, empty_gradients, empty_hessians, weights,
@@ -137,11 +141,18 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
# The bias is computed on gradients and hessians (and not
# filtered_gradients) which have exactly one value per example, so we
# don't double count a gradient in multivalent columns.
+ # Since unsorted_segment_sum can be numerically unstable, use 64-bit
+ # operations.
+ gradients64 = math_ops.cast(gradients, dtypes.float64)
+ hessians64 = math_ops.cast(hessians, dtypes.float64)
per_partition_gradients = math_ops.unsorted_segment_sum(
- gradients, mapped_partitions, array_ops.size(unique_partitions))
+ gradients64, mapped_partitions, array_ops.size(unique_partitions))
per_partition_hessians = math_ops.unsorted_segment_sum(
- hessians, mapped_partitions, array_ops.size(unique_partitions))
-
+ hessians64, mapped_partitions, array_ops.size(unique_partitions))
+ per_partition_gradients = math_ops.cast(per_partition_gradients,
+ dtypes.float32)
+ per_partition_hessians = math_ops.cast(per_partition_hessians,
+ dtypes.float32)
# Prepend a bias feature per partition that accumulates the stats for all
# examples in that partition.
# Bias is added to the stats even if there are no examples with values in
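The float64 round-trip above matters because segment sums accumulate many float32 values into one bucket. A standalone illustration of the failure mode, independent of the TensorFlow op itself:

```python
import numpy as np

# One large and many small contributions landing in the same partition.
big = np.float32(1e8)
small = np.float32(1.0)
print(big + small == big)                   # True: the 1.0 is lost in float32
print(np.float64(big) + np.float64(small))  # 100000001.0 survives in float64
```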
@@ -197,7 +208,8 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
tree_complexity_regularization=self._tree_complexity_regularization,
min_node_weight=self._min_node_weight,
bias_feature_id=_BIAS_FEATURE_ID,
- multiclass_strategy=self._multiclass_strategy))
+ multiclass_strategy=self._multiclass_strategy,
+ weak_learner_type=self._weak_learner_type))
# There are no warm-up rounds needed in the equality column handler. So we
# always return ready.
are_splits_ready = constant_op.constant(True)
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
index ef253e7cec..94ea7bc2eb 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
@@ -47,7 +47,7 @@ def get_empty_tensors(gradient_shape, hessian_shape):
class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
def testGenerateFeatureSplitCandidates(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Feature ID |
# i0 | (0.2, 0.12) | 0 | 1,2 |
@@ -169,10 +169,121 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(1, split_node.feature_id)
- def testGenerateFeatureSplitCandidatesSumReduction(self):
+ def testObliviousFeatureSplitGeneration(self):
with self.test_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Feature ID |
+ # i0 | (0.2, 0.12) | 1 | 1 |
+ # i1 | (-0.5, 0.07) | 1 | 2 |
+ # i2 | (1.2, 0.2) | 1 | 1 |
+ # i3 | (4.0, 0.13) | 2 | 2 |
+ gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
+ hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
+ partition_ids = [1, 1, 1, 2]
+ indices = [[0, 0], [1, 0], [2, 0], [3, 0]]
+ values = array_ops.constant([1, 2, 1, 2], dtype=dtypes.int64)
+
+ gradient_shape = tensor_shape.scalar()
+ hessian_shape = tensor_shape.scalar()
+ class_id = -1
+
+ split_handler = categorical_split_handler.EqualitySplitHandler(
+ l1_regularization=0.1,
+ l2_regularization=1,
+ tree_complexity_regularization=0,
+ min_node_weight=0,
+ sparse_int_column=sparse_tensor.SparseTensor(indices, values, [4, 1]),
+ feature_column_group_id=0,
+ gradient_shape=gradient_shape,
+ hessian_shape=hessian_shape,
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ init_stamp_token=0,
+ weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ empty_gradients, empty_hessians = get_empty_tensors(
+ gradient_shape, hessian_shape)
+ example_weights = array_ops.ones([4, 1], dtypes.float32)
+
+ update_1 = split_handler.update_stats_sync(
+ 0,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ update_2 = split_handler.update_stats_sync(
+ 0,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+
+ with ops.control_dependencies([update_1, update_2]):
+ are_splits_ready, partitions, gains, splits = (
+ split_handler.make_splits(0, 1, class_id))
+ are_splits_ready, partitions, gains, splits = (
+ sess.run([are_splits_ready, partitions, gains, splits]))
+ self.assertTrue(are_splits_ready)
+ self.assertAllEqual([1, 2], partitions)
+
+ # For partition 1.
+ # -(0.2 + 1.2 - 0.1) / (0.12 + 0.2 + 1)
+ expected_left_weight1 = -0.9848484848484846
+ # (0.2 + 1.2 - 0.1) ** 2 / (0.12 + 0.2 + 1)
+ expected_left_gain1 = 1.2803030303030298
+
+ # -(-0.5 + 0.1) / (0.07 + 1)
+ expected_right_weight1 = 0.37383177570093457
+
+ # (-0.5 + 0.1) ** 2 / (0.07 + 1)
+ expected_right_gain1 = 0.14953271028037385
+
+ # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
+ expected_bias_gain1 = 0.46043165467625885
+
+ split_info = split_info_pb2.ObliviousSplitInfo()
+ split_info.ParseFromString(splits[0])
+ # Children of partition 1.
+ left_child = split_info.children[0].vector
+ right_child = split_info.children[1].vector
+ split_node = split_info.split_node.oblivious_categorical_id_binary_split
+
+ self.assertEqual(0, split_node.feature_column)
+ self.assertEqual(1, split_node.feature_id)
+ self.assertAllClose([expected_left_weight1], left_child.value, 0.00001)
+ self.assertAllClose([expected_right_weight1], right_child.value, 0.00001)
+
+ # For partition2.
+ expected_left_weight2 = 0
+ expected_left_gain2 = 0
+ # -(4 - 0.1) / (0.13 + 1)
+ expected_right_weight2 = -3.4513274336283186
+ # (4 - 0.1) ** 2 / (0.13 + 1)
+ expected_right_gain2 = 13.460176991150442
+ # (4 - 0.1) ** 2 / (0.13 + 1)
+ expected_bias_gain2 = 13.460176991150442
+
+ # Children of partition 2.
+ left_child = split_info.children[2].vector
+ right_child = split_info.children[3].vector
+ self.assertAllClose([expected_left_weight2], left_child.value, 0.00001)
+ self.assertAllClose([expected_right_weight2], right_child.value, 0.00001)
+
+ self.assertAllClose(
+ expected_left_gain1 + expected_right_gain1 - expected_bias_gain1 +
+ expected_left_gain2 + expected_right_gain2 - expected_bias_gain2,
+ gains[0], 0.00001)
+
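The expected constants in the test above follow from the soft-thresholded leaf formulas its comments spell out. A standalone check with l1 = 0.1 and l2 = 1 (helper name hypothetical):

```python
def leaf_weight_and_gain(g, h, l1=0.1, l2=1.0):
    # Shrink the gradient toward zero by l1, then apply the l2-damped formulas.
    g = g - l1 if g > l1 else (g + l1 if g < -l1 else 0.0)
    return -g / (h + l2), g * g / (h + l2)

print(leaf_weight_and_gain(0.2 + 1.2, 0.12 + 0.2))  # (-0.98484..., 1.28030...)
print(leaf_weight_and_gain(-0.5, 0.07))             # (0.37383..., 0.14953...)
print(leaf_weight_and_gain(4.0, 0.13))              # (-3.45132..., 13.46017...)
```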
+ def testGenerateFeatureSplitCandidatesSumReduction(self):
+ with self.cached_session() as sess:
+ # The data looks like the following:
+ # Example | Gradients | Partition | Feature ID |
# i0 | (0.2, 0.12) | 0 | 1,2 |
# i1 | (-0.5, 0.07) | 0 | |
# i2 | (1.2, 0.2) | 0 | 2 |
@@ -293,7 +404,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(1, split_node.feature_id)
def testGenerateFeatureSplitCandidatesMulticlass(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Batch size is 4, 2 gradients per each instance.
gradients = array_ops.constant(
[[0.2, 0.1], [-0.5, 0.2], [1.2, 3.4], [4.0, -3.5]], shape=[4, 2])
@@ -371,7 +482,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(1, split_node.feature_id)
def testEmpty(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
partition_ids = [0, 0, 0, 1]
@@ -419,7 +530,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(len(splits), 0)
def testInactive(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
partition_ids = [0, 0, 0, 1]
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
index 6572f2f414..74b0ea6989 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
@@ -50,7 +50,7 @@ def get_empty_tensors(gradient_shape, hessian_shape):
class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
def testGenerateFeatureSplitCandidates(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Dense Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -183,17 +183,18 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.threshold, 0.00001)
def testObliviousFeatureSplitGeneration(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Dense Quantile |
- # i0 | (0.2, 0.12) | 0 | 2 |
- # i1 | (-0.5, 0.07) | 0 | 2 |
- # i2 | (1.2, 0.2) | 0 | 0 |
- # i3 | (4.0, 0.13) | 1 | 1 |
- dense_column = array_ops.constant([0.62, 0.62, 0.3, 0.52])
+ # i0 | (0.2, 0.12) | 1 | 3 |
+ # i1 | (-0.5, 0.07) | 1 | 3 |
+ # i2 | (1.2, 0.2) | 1 | 1 |
+ # i3 | (4.0, 0.13) | 2 | 2 |
+ dense_column = array_ops.placeholder(
+ dtypes.float32, shape=(4, 1), name="dense_column")
gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
- partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32)
+ partition_ids = array_ops.constant([1, 1, 1, 2], dtype=dtypes.int32)
class_id = -1
gradient_shape = tensor_shape.scalar()
@@ -230,87 +231,96 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
with ops.control_dependencies([update_1]):
are_splits_ready = split_handler.make_splits(
np.int64(0), np.int64(1), class_id)[0]
+ # Forcing the creation of four buckets.
+ are_splits_ready = sess.run(
+ [are_splits_ready],
+ feed_dict={dense_column: [[0.2], [0.62], [0.3], [0.52]]})[0]
- with ops.control_dependencies([are_splits_ready]):
- update_2 = split_handler.update_stats_sync(
- 1,
- partition_ids,
- gradients,
- hessians,
- empty_gradients,
- empty_hessians,
- example_weights,
- is_active=array_ops.constant([True, True]))
+ update_2 = split_handler.update_stats_sync(
+ 1,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
split_handler.make_splits(np.int64(1), np.int64(2), class_id))
- are_splits_ready, are_splits_ready2, partitions, gains, splits = (
- sess.run([
- are_splits_ready, are_splits_ready2, partitions, gains, splits
- ]))
+ # Only using the last three buckets.
+ are_splits_ready2, partitions, gains, splits = (
+ sess.run(
+ [are_splits_ready2, partitions, gains, splits],
+ feed_dict={dense_column: [[0.62], [0.62], [0.3], [0.52]]}))
# During the first iteration, inequality split handlers are not going to
# have any splits. Make sure that we return not_ready in that case.
self.assertFalse(are_splits_ready)
self.assertTrue(are_splits_ready2)
- self.assertAllEqual([0, 1], partitions)
+ self.assertAllEqual([1, 2], partitions)
oblivious_split_info = split_info_pb2.ObliviousSplitInfo()
oblivious_split_info.ParseFromString(splits[0])
- split_node = oblivious_split_info.split_node.dense_float_binary_split
-
+ split_node = oblivious_split_info.split_node
+ split_node = split_node.oblivious_dense_float_binary_split
self.assertAllClose(0.3, split_node.threshold, 0.00001)
self.assertEqual(0, split_node.feature_column)
- # Check the split on partition 0.
+ # Check the split on partition 1.
# -(1.2 - 0.1) / (0.2 + 1)
- expected_left_weight_0 = -0.9166666666666666
+ expected_left_weight_1 = -0.9166666666666666
- # expected_left_weight_0 * -(1.2 - 0.1)
- expected_left_gain_0 = 1.008333333333333
+ # expected_left_weight_1 * -(1.2 - 0.1)
+ expected_left_gain_1 = 1.008333333333333
# (-0.5 + 0.2 + 0.1) / (0.19 + 1)
- expected_right_weight_0 = 0.1680672
+ expected_right_weight_1 = 0.1680672
- # expected_right_weight_0 * -(-0.5 + 0.2 + 0.1))
- expected_right_gain_0 = 0.033613445378151252
+ # expected_right_weight_1 * -(-0.5 + 0.2 + 0.1))
+ expected_right_gain_1 = 0.033613445378151252
# (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
- expected_bias_gain_0 = 0.46043165467625896
+ expected_bias_gain_1 = 0.46043165467625896
- left_child = oblivious_split_info.children_leaves[0].vector
- right_child = oblivious_split_info.children_leaves[1].vector
+ left_child = oblivious_split_info.children[0].vector
+ right_child = oblivious_split_info.children[1].vector
- self.assertAllClose([expected_left_weight_0], left_child.value, 0.00001)
+ self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001)
- self.assertAllClose([expected_right_weight_0], right_child.value, 0.00001)
+ self.assertAllClose([expected_right_weight_1], right_child.value, 0.00001)
- # Check the split on partition 1.
- expected_left_weight_1 = 0
- expected_left_gain_1 = 0
+ # Check the split on partition 2.
+ expected_left_weight_2 = 0
+ expected_left_gain_2 = 0
# -(4 - 0.1) / (0.13 + 1)
- expected_right_weight_1 = -3.4513274336283186
- # expected_right_weight_1 * -(4 - 0.1)
- expected_right_gain_1 = 13.460176991150442
+ expected_right_weight_2 = -3.4513274336283186
+ # expected_right_weight_2 * -(4 - 0.1)
+ expected_right_gain_2 = 13.460176991150442
# (-4 + 0.1) ** 2 / (0.13 + 1)
- expected_bias_gain_1 = 13.460176991150442
+ expected_bias_gain_2 = 13.460176991150442
- left_child = oblivious_split_info.children_leaves[2].vector
- right_child = oblivious_split_info.children_leaves[3].vector
+ left_child = oblivious_split_info.children[2].vector
+ right_child = oblivious_split_info.children[3].vector
- self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001)
+ self.assertAllClose([expected_left_weight_2], left_child.value, 0.00001)
- self.assertAllClose([expected_right_weight_1], right_child.value, 0.00001)
+ self.assertAllClose([expected_right_weight_2], right_child.value, 0.00001)
# The layer gain is the sum of the gains of each partition
layer_gain = (
- expected_left_gain_0 + expected_right_gain_0 - expected_bias_gain_0) + (
- expected_left_gain_1 + expected_right_gain_1 - expected_bias_gain_1)
+ expected_left_gain_1 + expected_right_gain_1 - expected_bias_gain_1) + (
+ expected_left_gain_2 + expected_right_gain_2 - expected_bias_gain_2)
self.assertAllClose(layer_gain, gains[0], 0.00001)
+ # We have examples in both partitions, so we get both parent ids.
+ self.assertEqual(2, len(oblivious_split_info.children_parent_id))
+ self.assertEqual(1, oblivious_split_info.children_parent_id[0])
+ self.assertEqual(2, oblivious_split_info.children_parent_id[1])
+
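The reworked test feeds the dense column twice because inequality handlers are two-pass: the first batch only grows the quantile stream, and only the second batch is bucketized against the resulting boundaries and accumulated into per-bucket stats. A rough standalone sketch of that flow (helpers hypothetical, not the real quantile stream):

```python
import bisect

def build_boundaries(values, num_buckets):
    # Pass 1: derive approximate quantile boundaries from the batch.
    qs = sorted(values)
    step = max(len(qs) // num_buckets, 1)
    return qs[::step]

def bucketize(value, boundaries):
    # Pass 2: map each feature value to its quantile bucket.
    return bisect.bisect_left(boundaries, value)
```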
def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Dense Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -448,7 +458,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.threshold, 0.00001)
def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dense_column = array_ops.constant([0.52, 0.52, 0.3, 0.52])
# Batch size is 4, 2 gradients per each instance.
gradients = array_ops.constant(
@@ -536,7 +546,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.3, split_node.threshold, 1e-6)
def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dense_column = array_ops.constant([0.52, 0.52, 0.3, 0.52])
# Batch size is 4, 2 gradients per each instance.
gradients = array_ops.constant(
@@ -623,7 +633,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.3, split_node.threshold, 1e-6)
def testGenerateFeatureSplitCandidatesInactive(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Dense Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -698,7 +708,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(len(splits), 0)
def testGenerateFeatureSplitCandidatesWithTreeComplexity(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Dense Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -832,7 +842,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.threshold, 0.00001)
def testGenerateFeatureSplitCandidatesWithMinNodeWeight(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Dense Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -941,7 +951,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
def testGenerateFeatureSplitCandidates(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Sparse Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -1064,7 +1074,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.split.threshold)
def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Sparse Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -1197,7 +1207,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.split.threshold)
def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Batch is 4, 2 classes
gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3],
[4.0, -3]])
@@ -1292,7 +1302,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.split.threshold)
def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Batch is 4, 2 classes
gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3],
[4.0, -3]])
@@ -1387,7 +1397,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.split.threshold)
def testGenerateFeatureSplitCandidatesInactive(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Sparse Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -1465,7 +1475,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(len(splits), 0)
def testEmpty(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2])
# No values in this feature column in this mini-batch.
values = array_ops.constant([], dtype=dtypes.float32)
@@ -1535,7 +1545,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
def testEmptyBuckets(self):
"""Test that reproduces the case when quantile buckets were empty."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sparse_column = array_ops.sparse_placeholder(dtypes.float32)
# We have two batches - at first, a sparse feature is empty.
@@ -1628,7 +1638,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(len(splits), 0)
def testDegenerativeCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# One data example only, one leaf and thus one quantile bucket.The same
# situation is when all examples have the same values. This case was
# causing before a failure.
diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h
index 69bb8fd4ad..8d71a6cdbc 100644
--- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h
+++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h
@@ -36,12 +36,6 @@ class WeightedQuantilesSummary {
struct SummaryEntry {
SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min,
const WeightType& max) {
- // Explicitly initialize all of memory (including padding from memory
- // alignment) to allow the struct to be msan-resistant "plain old data".
- //
- // POD = http://en.cppreference.com/w/cpp/concept/PODType
- memset(this, 0, sizeof(*this));
-
value = v;
weight = w;
min_rank = min;
@@ -49,8 +43,6 @@ class WeightedQuantilesSummary {
}
SummaryEntry() {
- memset(this, 0, sizeof(*this));
-
value = ValueType();
weight = 0;
min_rank = 0;
diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc
index 0e5578693a..64921faf81 100644
--- a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc
+++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc
@@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
+#include <algorithm>
+
#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
#include "tensorflow/core/platform/macros.h"
-#include <algorithm>
-
namespace tensorflow {
namespace boosted_trees {
namespace trees {
@@ -28,14 +28,15 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config,
if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) {
return kInvalidLeaf;
}
-
// Traverse tree starting at the provided sub-root.
int32 node_id = sub_root_id;
+ // The index of the leaf that holds this example in the oblivious case.
+ int oblivious_leaf_idx = 0;
while (true) {
const auto& current_node = config.nodes(node_id);
switch (current_node.node_case()) {
case TreeNode::kLeaf: {
- return node_id;
+ return node_id + oblivious_leaf_idx;
}
case TreeNode::kDenseFloatBinarySplit: {
const auto& split = current_node.dense_float_binary_split();
@@ -100,6 +101,28 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config,
}
break;
}
+ case TreeNode::kObliviousDenseFloatBinarySplit: {
+ const auto& split = current_node.oblivious_dense_float_binary_split();
+ oblivious_leaf_idx <<= 1;
+ if (example.dense_float_features[split.feature_column()] >
+ split.threshold()) {
+ oblivious_leaf_idx++;
+ }
+ node_id++;
+ break;
+ }
+ case TreeNode::kObliviousCategoricalIdBinarySplit: {
+ const auto& split =
+ current_node.oblivious_categorical_id_binary_split();
+ oblivious_leaf_idx <<= 1;
+ const auto& features =
+ example.sparse_int_features[split.feature_column()];
+ if (features.find(split.feature_id()) == features.end()) {
+ oblivious_leaf_idx++;
+ }
+ node_id++;
+ break;
+ }
case TreeNode::NODE_NOT_SET: {
LOG(QFATAL) << "Invalid node in tree: " << current_node.DebugString();
break;
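Both oblivious cases above advance node_id by one per layer and fold the branch decision into oblivious_leaf_idx, so the leaf finally returned is the first leaf's node id plus a bit pattern with one bit per layer. A standalone sketch for the dense case (hypothetical names):

```python
def traverse_oblivious(splits, first_leaf_id, dense_features):
    leaf_idx = 0
    for feature_column, threshold in splits:  # one (column, threshold) per layer
        leaf_idx <<= 1
        if dense_features[feature_column] > threshold:
            leaf_idx += 1                     # went right at this layer
    return first_leaf_id + leaf_idx

# Matches the prediction test later in this diff: splits on columns 0, 1, 2
# with thresholds 5, 3, 1 send features (7, 1, 2) to leaf index 5, node 8.
assert traverse_oblivious([(0, 5.0), (1, 3.0), (2, 1.0)], 3, [7.0, 1.0, 2.0]) == 8
```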
@@ -165,6 +188,16 @@ void DecisionTree::LinkChildren(const std::vector<int32>& children,
split->set_right_id(*++children_it);
break;
}
+ case TreeNode::kObliviousDenseFloatBinarySplit: {
+ LOG(QFATAL)
+ << "Not implemented for the ObliviousDenseFloatBinarySplit case.";
+ break;
+ }
+ case TreeNode::kObliviousCategoricalIdBinarySplit: {
+ LOG(QFATAL)
+ << "Not implemented for the ObliviousCategoricalIdBinarySplit case.";
+ break;
+ }
case TreeNode::NODE_NOT_SET: {
LOG(QFATAL) << "A non-set node cannot have children.";
break;
@@ -199,6 +232,16 @@ std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) {
const auto& split = node.categorical_id_set_membership_binary_split();
return {split.left_id(), split.right_id()};
}
+ case TreeNode::kObliviousDenseFloatBinarySplit: {
+ LOG(QFATAL)
+ << "Not implemented for the ObliviousDenseFloatBinarySplit case.";
+ return {};
+ }
+ case TreeNode::kObliviousCategoricalIdBinarySplit: {
+ LOG(QFATAL)
+ << "Not implemented for the ObliviousCategoricalIdBinarySplit case.";
+ break;
+ }
case TreeNode::NODE_NOT_SET: {
return {};
}
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h b/tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h
index ec06787e1d..1f3672bf85 100644
--- a/tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h
+++ b/tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef TENSORFLOW_CONTRIB_LIB_UTILS_PARALLEL_FOR_H_
-#define TENSORFLOW_CONTRIB_LIB_UTILS_PARALLEL_FOR_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_PARALLEL_FOR_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_PARALLEL_FOR_H_
#include "tensorflow/core/lib/core/threadpool.h"
@@ -30,4 +30,4 @@ void ParallelFor(int64 batch_size, int64 desired_parallelism,
} // namespace boosted_trees
} // namespace tensorflow
-#endif // TENSORFLOW_CONTRIB_LIB_UTILS_PARALLEL_FOR_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_PARALLEL_FOR_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/random.h b/tensorflow/contrib/boosted_trees/lib/utils/random.h
index 546d344f55..249651e99e 100644
--- a/tensorflow/contrib/boosted_trees/lib/utils/random.h
+++ b/tensorflow/contrib/boosted_trees/lib/utils/random.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef TENSORFLOW_CONTRIB_LIB_UTILS_RANDOM_H_
-#define TENSORFLOW_CONTRIB_LIB_UTILS_RANDOM_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_RANDOM_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_RANDOM_H_
#include "tensorflow/core/lib/random/simple_philox.h"
@@ -36,4 +36,4 @@ inline int32 PoissonBootstrap(random::SimplePhilox* rng) {
} // namespace boosted_trees
} // namespace tensorflow
-#endif // TENSORFLOW_CONTRIB_LIB_UTILS_RANDOM_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_RANDOM_H_
diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
index 9b68a9de96..f1e12a028a 100644
--- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
@@ -179,6 +179,7 @@ REGISTER_OP("BuildCategoricalEqualitySplits")
.Input("tree_complexity_regularization: float")
.Input("min_node_weight: float")
.Input("multiclass_strategy: int32")
+ .Input("weak_learner_type: int32")
.Output("output_partition_ids: int32")
.Output("gains: float32")
.Output("split_infos: string")
@@ -224,6 +225,8 @@ min_node_weight: A scalar, minimum sum of example hessian needed in a child.
be considered.
multiclass_strategy: A scalar, specifying the multiclass handling strategy.
See LearnerConfig.MultiClassStrategy for valid values.
+weak_learner_type: A scalar, specifying the weak learner type to use.
+ See LearnerConfig.WeakLearnerType for valid values.
output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
for.
gains: A rank 1 tensor, for the computed gain for the created splits.
diff --git a/tensorflow/contrib/boosted_trees/ops/training_ops.cc b/tensorflow/contrib/boosted_trees/ops/training_ops.cc
index 22ac9edb72..604ec8e0bf 100644
--- a/tensorflow/contrib/boosted_trees/ops/training_ops.cc
+++ b/tensorflow/contrib/boosted_trees/ops/training_ops.cc
@@ -57,6 +57,7 @@ REGISTER_OP("GrowTreeEnsemble")
.Input("learning_rate: float")
.Input("dropout_seed: int64")
.Input("max_tree_depth: int32")
+ .Input("weak_learner_type: int32")
.Input("partition_ids: num_handlers * int32")
.Input("gains: num_handlers * float")
.Input("splits: num_handlers * string")
@@ -82,6 +83,7 @@ tree_ensemble_handle: Handle to the ensemble variable.
stamp_token: Stamp token for validating operation consistency.
next_stamp_token: Stamp token to be used for the next iteration.
learning_rate: Scalar learning rate.
+weak_learner_type: The type of weak learner to use.
partition_ids: List of Rank 1 Tensors containing partition Id per candidate.
gains: List of Rank 1 Tensors containing gains per candidate.
splits: List of Rank 1 Tensors containing serialized SplitInfo protos per candidate.
diff --git a/tensorflow/contrib/boosted_trees/proto/split_info.proto b/tensorflow/contrib/boosted_trees/proto/split_info.proto
index 850340f5c2..784977af39 100644
--- a/tensorflow/contrib/boosted_trees/proto/split_info.proto
+++ b/tensorflow/contrib/boosted_trees/proto/split_info.proto
@@ -19,8 +19,10 @@ message SplitInfo {
}
message ObliviousSplitInfo {
- // The split node with the feature_column and threshold defined.
tensorflow.boosted_trees.trees.TreeNode split_node = 1;
- // The new leaves of the tree.
- repeated tensorflow.boosted_trees.trees.Leaf children_leaves = 2;
+ repeated tensorflow.boosted_trees.trees.Leaf children = 2;
+ // For each child, children_parent_id stores the node_id of its parent when it
+ // was a leaf. The idx-th child corresponds to the (idx / 2)-th entry of
+ // children_parent_id.
+ repeated int32 children_parent_id = 3;
}
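A small sketch of how the two repeated fields pair up: children arrive as (left, right) pairs, and the pair at positions (2*i, 2*i + 1) belongs to the former leaf whose node id is children_parent_id[i] (helper hypothetical):

```python
def pair_children(children, children_parent_id):
    assert len(children) == 2 * len(children_parent_id)
    return {parent: (children[2 * i], children[2 * i + 1])
            for i, parent in enumerate(children_parent_id)}
```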
diff --git a/tensorflow/contrib/boosted_trees/proto/tree_config.proto b/tensorflow/contrib/boosted_trees/proto/tree_config.proto
index 81411aa84a..520b4f8b11 100644
--- a/tensorflow/contrib/boosted_trees/proto/tree_config.proto
+++ b/tensorflow/contrib/boosted_trees/proto/tree_config.proto
@@ -15,6 +15,8 @@ message TreeNode {
CategoricalIdBinarySplit categorical_id_binary_split = 5;
CategoricalIdSetMembershipBinarySplit
categorical_id_set_membership_binary_split = 6;
+ ObliviousDenseFloatBinarySplit oblivious_dense_float_binary_split = 7;
+ ObliviousCategoricalIdBinarySplit oblivious_categorical_id_binary_split = 8;
}
TreeNodeMetadata node_metadata = 777;
}
@@ -26,6 +28,9 @@ message TreeNodeMetadata {
// The original leaf node before this node was split.
Leaf original_leaf = 2;
+
+ // The original layer of leaves before that layer was converted to a split.
+ repeated Leaf original_oblivious_leaves = 3;
}
// Leaves can either hold dense or sparse information.
@@ -101,6 +106,28 @@ message CategoricalIdSetMembershipBinarySplit {
int32 right_id = 4;
}
+// Split rule for dense float features in the oblivious case.
+message ObliviousDenseFloatBinarySplit {
+ // Float feature column and split threshold describing
+ // the rule feature <= threshold.
+ int32 feature_column = 1;
+ float threshold = 2;
+ // We don't store children ids because the next node either represents the
+ // whole next layer of the tree or, starting with the next node, there are
+ // only leaves.
+}
+
+// Split rule for categorical features with a single feature Id in the oblivious
+// case.
+message ObliviousCategoricalIdBinarySplit {
+ // Categorical feature column and Id describing the rule feature == Id.
+ int32 feature_column = 1;
+ int64 feature_id = 2;
+ // We don't store children ids because the next node either represents the
+ // whole next layer of the tree or, starting with the next node, there are
+ // only leaves.
+}
+
// DecisionTreeConfig describes a list of connected nodes.
// Node 0 must be the root and can carry any payload including a leaf
// in the case of representing the bias.
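The "no children ids" comments rely on the layout these messages imply: a depth-d oblivious tree stores its d split nodes first, followed by the 2**d leaves of the last layer, so child positions are implicit. A tiny sketch of that layout:

```python
def oblivious_layout(depth):
    split_ids = list(range(depth))                     # one split per layer
    leaf_ids = list(range(depth, depth + 2 ** depth))  # contiguous leaf block
    return split_ids, leaf_ids

print(oblivious_layout(3))  # ([0, 1, 2], [3, ..., 10]), as in the tests below
```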
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py
index 63b9c5fddf..42d69645ac 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py
@@ -98,7 +98,7 @@ class ModelOpsTest(test_util.TensorFlowTestCase):
self._seed = 123
def testCreate(self):
- with self.test_session():
+ with self.cached_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
tree = tree_ensemble_config.trees.add()
_append_to_leaf(tree.nodes.add().leaf, 0, -0.4)
@@ -133,7 +133,7 @@ class ModelOpsTest(test_util.TensorFlowTestCase):
def testSerialization(self):
with ops.Graph().as_default() as graph:
- with self.test_session(graph):
+ with self.session(graph):
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
# Bias tree only for second class.
tree1 = tree_ensemble_config.trees.add()
@@ -164,7 +164,7 @@ class ModelOpsTest(test_util.TensorFlowTestCase):
serialized_config = serialized_config.eval()
with ops.Graph().as_default() as graph:
- with self.test_session(graph):
+ with self.session(graph):
tree_ensemble_handle2 = model_ops.tree_ensemble_variable(
stamp_token=9,
tree_ensemble_config=serialized_config,
@@ -204,14 +204,14 @@ class ModelOpsTest(test_util.TensorFlowTestCase):
self.assertAllClose(result.eval(), [[0.5, -0.2], [0, 1.0]])
def testRestore(self):
- # Calling self.test_session() without a graph specified results in
+ # Calling self.cached_session() without a graph specified results in
# TensorFlowTestCase caching the session and returning the same one
# every time. In this test, we need to create two different sessions
- # which is why we also create a graph and pass it to self.test_session()
+ # which is why we also create a graph and pass it to self.cached_session()
# to ensure no caching occurs under the hood.
save_path = os.path.join(self.get_temp_dir(), "restore-test")
with ops.Graph().as_default() as graph:
- with self.test_session(graph) as sess:
+ with self.session(graph) as sess:
# Prepare learner config.
learner_config = learner_pb2.LearnerConfig()
learner_config.num_classes = 2
@@ -288,7 +288,7 @@ class ModelOpsTest(test_util.TensorFlowTestCase):
# Start a second session. In that session the parameter nodes
# have not been initialized either.
with ops.Graph().as_default() as graph:
- with self.test_session(graph) as sess:
+ with self.session(graph) as sess:
tree_ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="restore_tree")
my_saver = saver.Saver()
@@ -311,7 +311,7 @@ class ModelOpsTest(test_util.TensorFlowTestCase):
self.assertAllClose(result.eval(), [[-1.1], [-1.1]])
def testUsedHandlers(self):
- with self.test_session():
+ with self.cached_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
tree_ensemble_config.growing_metadata.used_handler_ids.append(1)
tree_ensemble_config.growing_metadata.used_handler_ids.append(5)
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py
index cf55759aaa..46dfbdefeb 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py
@@ -96,6 +96,20 @@ def _set_float_split(split, feat_col, thresh, l_id, r_id, feature_dim_id=None):
split.dimension_id = feature_dim_id
+def _set_float_oblivious_split(split, feat_col, thresh):
+ """Helper method for building tree float splits.
+
+ Sets split feature column and threshold.
+
+ Args:
+ split: split node to update.
+ feat_col: feature column for the split.
+ thresh: threshold to split on forming rule x <= thresh.
+ """
+ split.feature_column = feat_col
+ split.threshold = thresh
+
+
def _set_categorical_id_split(split, feat_col, feat_id, l_id, r_id):
"""Helper method for building tree categorical id splits.
@@ -119,15 +133,17 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
def setUp(self):
"""Sets up the prediction tests.
- Create a batch of two examples having one dense float, two sparse float
+ Creates a batch of two examples having three dense float, two sparse float
single valued, one sparse float multidimensional and one sparse int
features. The data looks like the following:
- | Instance | Dense0 | SparseF0 | SparseF1 | SparseI0 | SparseM
- | 0 | 7 | -3 | | 9,1 | __, 5.0
- | 1 | -2 | | 4 | | 3, ___
+ |Instance |Dense0 |Dense1 |Dense2 |SparseF0 |SparseF1 |SparseI0 |SparseM
+ | 0 | 7 | 1 | 2 | -3 | | 9,1 | __, 5.0
+ | 1 | -2 | 2 | 0.5 | | 4 | | 3, ___
"""
super(PredictionOpsTest, self).setUp()
- self._dense_float_tensor = np.array([[7.0], [-2.0]])
+ self._dense_float_tensor1 = np.array([[7.0], [-2.0]])
+ self._dense_float_tensor2 = np.array([[1.0], [2.0]])
+ self._dense_float_tensor3 = np.array([[2.0], [0.5]])
self._sparse_float_indices1 = np.array([[0, 0]])
self._sparse_float_values1 = np.array([-3.0])
self._sparse_float_shape1 = np.array([2, 1])
@@ -153,7 +169,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
reduce_dim=False):
return prediction_ops.gradient_trees_prediction(
tree_ensemble_handle,
- self._seed, [self._dense_float_tensor],
+ self._seed, [self._dense_float_tensor1],
[self._sparse_float_indices1, self._sparse_float_indices2],
[self._sparse_float_values1, self._sparse_float_values2],
[self._sparse_float_shape1, self._sparse_float_shape2],
@@ -165,8 +181,27 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
center_bias=center_bias,
reduce_dim=reduce_dim)
+ def _get_predictions_oblivious_case(self,
+ tree_ensemble_handle,
+ learner_config,
+ apply_dropout=False,
+ apply_averaging=False,
+ center_bias=False,
+ reduce_dim=False):
+ return prediction_ops.gradient_trees_prediction(
+ tree_ensemble_handle,
+ self._seed, [
+ self._dense_float_tensor1, self._dense_float_tensor2,
+ self._dense_float_tensor3
+ ], [], [], [], [], [], [],
+ learner_config=learner_config,
+ apply_dropout=apply_dropout,
+ apply_averaging=apply_averaging,
+ center_bias=center_bias,
+ reduce_dim=reduce_dim)
+
def testEmptyEnsemble(self):
- with self.test_session():
+ with self.cached_session():
# Empty tree ensenble.
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
@@ -189,7 +224,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[], []], dropout_info.eval())
def testBiasEnsembleSingleClass(self):
- with self.test_session():
+ with self.cached_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
tree = tree_ensemble_config.trees.add()
tree_ensemble_config.tree_metadata.add().is_finalized = True
@@ -217,7 +252,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[], []], dropout_info.eval())
def testBiasEnsembleMultiClass(self):
- with self.test_session():
+ with self.cached_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
tree = tree_ensemble_config.trees.add()
tree_ensemble_config.tree_metadata.add().is_finalized = True
@@ -247,7 +282,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[], []], dropout_info.eval())
def testFullEnsembleSingleClass(self):
- with self.test_session():
+ with self.cached_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
# Bias tree.
tree1 = tree_ensemble_config.trees.add()
@@ -295,8 +330,55 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
# Empty dropout.
self.assertAllEqual([[], []], dropout_info.eval())
+ def testObliviousEnsemble(self):
+ with self.cached_session():
+ tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+ # Bias tree.
+ tree1 = tree_ensemble_config.trees.add()
+ tree_ensemble_config.tree_metadata.add().is_finalized = True
+ _append_to_leaf(tree1.nodes.add().leaf, 0, -0.4)
+
+ # Depth 3 tree.
+ tree2 = tree_ensemble_config.trees.add()
+ _set_float_oblivious_split(
+ tree2.nodes.add().oblivious_dense_float_binary_split, 0, 5.0)
+ _set_float_oblivious_split(
+ tree2.nodes.add().oblivious_dense_float_binary_split, 1, 3.0)
+ _set_float_oblivious_split(
+ tree2.nodes.add().oblivious_dense_float_binary_split, 2, 1.0)
+ for i in range(1, 9):
+ _append_to_leaf(tree2.nodes.add().leaf, 0, i / 10.0)
+
+ tree_ensemble_config.tree_weights.append(1.0)
+ tree_ensemble_config.tree_weights.append(1.0)
+
+ tree_ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0,
+ tree_ensemble_config=tree_ensemble_config.SerializeToString(),
+ name="full_ensemble")
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Prepare learner config.
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+
+ result, dropout_info = self._get_predictions_oblivious_case(
+ tree_ensemble_handle,
+ learner_config=learner_config.SerializeToString(),
+ reduce_dim=True)
+
+ # The first example will get bias -0.4 from the first tree and 0.6 from
+ # the 5th leaf of the second tree corresponding to node_id = 8, hence a
+ # prediction of 0.2.
+ # The second example will get bias -0.4 and 0.1 from the 0th leaf of the
+ # second tree corresponding to node_id = 3, hence a prediction of -0.3.
+ self.assertAllClose([[0.2], [-0.3]], result.eval())
+
+ # Empty dropout.
+ self.assertAllEqual([[], []], dropout_info.eval())
+
def testFullEnsembleWithMultidimensionalSparseSingleClass(self):
- with self.test_session():
+ with self.cached_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
# Bias tree.
tree1 = tree_ensemble_config.trees.add()
@@ -358,7 +440,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
result, dropout_info = prediction_ops.gradient_trees_prediction(
tree_ensemble_handle,
- self._seed, [self._dense_float_tensor], [
+ self._seed, [self._dense_float_tensor1], [
self._sparse_float_indices1, self._sparse_float_indices2,
self._sparse_float_indices_m
], [
@@ -384,7 +466,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[], []], dropout_info.eval())
def testExcludeNonFinalTree(self):
- with self.test_session():
+ with self.cached_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
# Bias tree.
tree1 = tree_ensemble_config.trees.add()
@@ -431,7 +513,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[], []], dropout_info.eval())
def testIncludeNonFinalTree(self):
- with self.test_session():
+ with self.cached_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
# Bias tree.
tree1 = tree_ensemble_config.trees.add()
@@ -482,7 +564,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
def testMetadataMissing(self):
# Sometimes we want to do prediction on trees that are not added to ensemble
# (for example in
- with self.test_session():
+ with self.cached_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
# Bias tree.
tree1 = tree_ensemble_config.trees.add()
@@ -530,7 +612,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
# For TREE_PER_CLASS strategy, predictions size is num_classes-1
def testFullEnsembleMultiClassTreePerClassStrategy(self):
- with self.test_session():
+ with self.cached_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
# Bias tree only for second class.
tree1 = tree_ensemble_config.trees.add()
@@ -581,7 +663,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
# This test is when leafs have SPARSE weights stored (class id and
# contribution).
def testFullEnsembleMultiNotClassTreePerClassStrategySparseVector(self):
- with self.test_session():
+ with self.cached_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
# Bias tree only for second class.
tree1 = tree_ensemble_config.trees.add()
@@ -631,7 +713,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
# will have the size of the number of classes.
# This test is when leafs have DENSE weights stored (weight for each class)
def testFullEnsembleMultiNotClassTreePerClassStrategyDenseVector(self):
- with self.test_session():
+ with self.cached_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
# Bias tree only for second class.
tree1 = tree_ensemble_config.trees.add()
@@ -678,7 +760,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[], []], dropout_info.eval())
def testDropout(self):
- with self.test_session():
+ with self.cached_session():
      # Empty tree ensemble.
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
# Add 1000 trees with some weights.
@@ -741,7 +823,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
# This is for normal non-batch mode where ensemble does not contain the tree
# that is being built currently.
num_trees = 10
- with self.test_session():
+ with self.cached_session():
# Empty tree ensemble.
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
# Add 10 trees with some weights.
@@ -809,7 +891,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
# This is batch mode where ensemble already contains the tree that we are
# building. This tree should never be dropped.
num_trees = 10
- with self.test_session():
+ with self.cached_session():
# Empty tree ensemble.
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
# Add 10 trees with some weights.
@@ -877,7 +959,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
dropout_info_center[0][num_dropped_center - 1])
def testDropoutSeed(self):
- with self.test_session():
+ with self.cached_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
# Add 10 trees with some weights.
for i in range(0, 999):
@@ -917,7 +999,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
# Different seed.
_, dropout_info_3 = prediction_ops.gradient_trees_prediction(
tree_ensemble_handle,
- 112314, [self._dense_float_tensor],
+ 112314, [self._dense_float_tensor1],
[self._sparse_float_indices1, self._sparse_float_indices2],
[self._sparse_float_values1, self._sparse_float_values2],
[self._sparse_float_shape1, self._sparse_float_shape2],
@@ -950,7 +1032,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
len(dropout_info_4.eval()[0]) + 1, len(dropout_info_1.eval()[0]))
def testDropOutZeroProb(self):
- with self.test_session():
+ with self.cached_session():
# Empty tree ensemble.
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
# Add 1000 trees with some weights.
@@ -993,7 +1075,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
self.assertAllClose(result.eval(), result_no_dropout.eval())
def testAveragingAllTrees(self):
- with self.test_session():
+ with self.cached_session():
# Empty tree ensemble.
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
adjusted_tree_ensemble_config = (
@@ -1057,7 +1139,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual(dropout_info.eval(), pattern_dropout_info.eval())
def testAveragingSomeTrees(self):
- with self.test_session():
+ with self.cached_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
adjusted_tree_ensemble_config = (
tree_config_pb2.DecisionTreeEnsembleConfig())
@@ -1138,7 +1220,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual(dropout_info_2.eval(), pattern_dropout_info.eval())
def testAverageMoreThanNumTreesExist(self):
- with self.test_session():
+ with self.cached_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
adjusted_tree_ensemble_config = (
tree_config_pb2.DecisionTreeEnsembleConfig())
@@ -1204,15 +1286,18 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase):
def setUp(self):
"""Sets up the prediction tests.
- Create a batch of two examples having one dense float, two sparse float and
- one sparse int features.
+ Create a batch of two examples having three dense float, two sparse float
+ and one sparse int features.
The data looks like the following:
- | Instance | Dense0 | SparseF0 | SparseF1 | SparseI0 |
- | 0 | 7 | -3 | | 9,1 |
- | 1 | -2 | | 4 | |
+ |Instance |Dense0 |Dense1 |Dense2 |SparseF0 |SparseF1 |SparseI0 |
+ | 0 | 7 | 1 | 2 | -3 | | 9,1 |
+ | 1 | -2 | 2 | 0.5 | | 4 | |
+
"""
super(PartitionExamplesOpsTest, self).setUp()
- self._dense_float_tensor = np.array([[7.0], [-2.0]])
+ self._dense_float_tensor1 = np.array([[7.0], [-2.0]])
+ self._dense_float_tensor2 = np.array([[1.0], [2.0]])
+ self._dense_float_tensor3 = np.array([[2.0], [0.5]])
self._sparse_float_indices1 = np.array([[0, 0]])
self._sparse_float_values1 = np.array([-3.0])
self._sparse_float_shape1 = np.array([2, 1])
@@ -1224,7 +1309,7 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase):
self._sparse_int_shape1 = np.array([2, 2])
def testEnsembleEmpty(self):
- with self.test_session():
+ with self.cached_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
tree_ensemble_handle = model_ops.tree_ensemble_variable(
@@ -1234,17 +1319,17 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase):
resources.initialize_resources(resources.shared_resources()).run()
result = prediction_ops.gradient_trees_partition_examples(
- tree_ensemble_handle, [self._dense_float_tensor], [
- self._sparse_float_indices1, self._sparse_float_indices2
- ], [self._sparse_float_values1, self._sparse_float_values2],
- [self._sparse_float_shape1,
- self._sparse_float_shape2], [self._sparse_int_indices1],
- [self._sparse_int_values1], [self._sparse_int_shape1])
+ tree_ensemble_handle, [self._dense_float_tensor1],
+ [self._sparse_float_indices1, self._sparse_float_indices2],
+ [self._sparse_float_values1, self._sparse_float_values2],
+ [self._sparse_float_shape1, self._sparse_float_shape2],
+ [self._sparse_int_indices1], [self._sparse_int_values1],
+ [self._sparse_int_shape1])
self.assertAllEqual([0, 0], result.eval())
def testTreeNonFinalized(self):
- with self.test_session():
+ with self.cached_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
# Depth 3 tree.
tree1 = tree_ensemble_config.trees.add()
@@ -1269,17 +1354,17 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase):
resources.initialize_resources(resources.shared_resources()).run()
result = prediction_ops.gradient_trees_partition_examples(
- tree_ensemble_handle, [self._dense_float_tensor], [
- self._sparse_float_indices1, self._sparse_float_indices2
- ], [self._sparse_float_values1, self._sparse_float_values2],
- [self._sparse_float_shape1,
- self._sparse_float_shape2], [self._sparse_int_indices1],
- [self._sparse_int_values1], [self._sparse_int_shape1])
+ tree_ensemble_handle, [self._dense_float_tensor1],
+ [self._sparse_float_indices1, self._sparse_float_indices2],
+ [self._sparse_float_values1, self._sparse_float_values2],
+ [self._sparse_float_shape1, self._sparse_float_shape2],
+ [self._sparse_int_indices1], [self._sparse_int_values1],
+ [self._sparse_int_shape1])
self.assertAllEqual([5, 3], result.eval())
def testTreeFinalized(self):
- with self.test_session():
+ with self.cached_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
# Depth 3 tree.
tree1 = tree_ensemble_config.trees.add()
@@ -1304,15 +1389,51 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase):
resources.initialize_resources(resources.shared_resources()).run()
result = prediction_ops.gradient_trees_partition_examples(
- tree_ensemble_handle, [self._dense_float_tensor], [
- self._sparse_float_indices1, self._sparse_float_indices2
- ], [self._sparse_float_values1, self._sparse_float_values2],
- [self._sparse_float_shape1,
- self._sparse_float_shape2], [self._sparse_int_indices1],
- [self._sparse_int_values1], [self._sparse_int_shape1])
+ tree_ensemble_handle, [self._dense_float_tensor1],
+ [self._sparse_float_indices1, self._sparse_float_indices2],
+ [self._sparse_float_values1, self._sparse_float_values2],
+ [self._sparse_float_shape1, self._sparse_float_shape2],
+ [self._sparse_int_indices1], [self._sparse_int_values1],
+ [self._sparse_int_shape1])
self.assertAllEqual([0, 0], result.eval())
+ def testObliviousTreeNonFinalized(self):
+ with self.cached_session():
+ tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+ # Depth 3 tree.
+ tree1 = tree_ensemble_config.trees.add()
+ _set_float_oblivious_split(
+ tree1.nodes.add().oblivious_dense_float_binary_split, 0, 5.0)
+ _set_float_oblivious_split(
+ tree1.nodes.add().oblivious_dense_float_binary_split, 1, 3.0)
+ _set_float_oblivious_split(
+ tree1.nodes.add().oblivious_dense_float_binary_split, 2, 1.0)
+ for i in range(1, 9):
+ _append_to_leaf(tree1.nodes.add().leaf, 0, i / 10.0)
+ tree_ensemble_config.tree_weights.append(1.0)
+ tree_ensemble_config.tree_metadata.add().is_finalized = False
+
+ tree_ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0,
+ tree_ensemble_config=tree_ensemble_config.SerializeToString(),
+ name="full_ensemble")
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ result = prediction_ops.gradient_trees_partition_examples(
+ tree_ensemble_handle, [
+ self._dense_float_tensor1,
+ self._dense_float_tensor2,
+ self._dense_float_tensor3
+ ], [], [], [], [], [], [])
+
+      # The first example goes right, left, right in the tree and the second
+      # example goes left, left, left. Since the depth of the tree is 3, the
+      # partition ids are as follows:
+      # First example: 3 + 5 = 8
+      # Second example: 3 + 0 = 3
+ self.assertAllEqual([8, 3], result.eval())
+
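The partition ids asserted above follow a simple layout (an assumption
consistent with the comment, not the op's implementation): an oblivious tree
of depth d stores its d split nodes first, so a leaf's partition id is d plus
the binary path from the root, taking 1 for each "go right". A sketch:

    def partition_id(directions):
        # directions is a list like ["right", "left", "right"], root first.
        path = 0
        for go in directions:
            path = (path << 1) | (1 if go == "right" else 0)
        return len(directions) + path

    assert partition_id(["right", "left", "right"]) == 8  # 3 + 0b101
    assert partition_id(["left", "left", "left"]) == 3    # 3 + 0b000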
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py
index 074623699d..848c42b686 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py
@@ -77,7 +77,7 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
example_weights = constant_op.constant(
[10, 1, 1, 1, 1, 1], dtype=dtypes.float32)
- with self.test_session():
+ with self.cached_session():
config = self._gen_config(0.33, 3)
dense_buckets, sparse_buckets = quantile_ops.quantile_buckets(
[dense_float_tensor_0], [sparse_indices_0, sparse_indices_m],
@@ -107,7 +107,7 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
"""
num_quantiles = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accumulator = quantile_ops.QuantileAccumulator(
init_stamp_token=0, num_quantiles=num_quantiles,
epsilon=0.001, name="q1")
@@ -119,7 +119,7 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
column=input_column,
example_weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(1, 23):
# start = 1, 2, 4, 7, 11, 16 ... (see comment above)
start = int((i * (i-1) / 2) + 1)
@@ -127,7 +127,7 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
{input_column: range(start, start+i),
weights: [1] * i})
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(accumulator.flush(stamp_token=0, next_stamp_token=1))
are_ready_flush, buckets = (accumulator.get_buckets(stamp_token=1))
buckets, are_ready_flush = (sess.run(
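Conceptually, the add/flush/get_buckets cycle above approximates the quantile
boundaries of everything streamed so far. A plain-NumPy sketch of the exact
computation the accumulator approximates (to within epsilon) for the column
built from the batches [1], [2, 3], [4, 5, 6], ... described above:

    import numpy as np

    column = np.concatenate(
        [np.arange(i * (i - 1) // 2 + 1, i * (i - 1) // 2 + 1 + i)
         for i in range(1, 23)])  # values 1..253
    boundaries = np.quantile(column, np.linspace(0.0, 1.0, 3 + 1))
    print(boundaries)  # roughly equal-mass bucket edges for num_quantiles=3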
@@ -142,7 +142,7 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
num_quantiles = 3
# set generate_quantiles to True since the test will generate fewer
# boundaries otherwise.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accumulator = quantile_ops.QuantileAccumulator(
init_stamp_token=0, num_quantiles=num_quantiles,
epsilon=0.001, name="q1", generate_quantiles=True)
@@ -154,7 +154,7 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
column=input_column,
example_weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
      # This input is generated from integers in the range [2030, 2060]
      # but represented with float16 precision. Integers <= 2048 are
# exactly represented, whereas numbers > 2048 are rounded; and hence
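Picking up the float16 point: a quick standalone check of that rounding
behavior (NumPy assumed):

    import numpy as np

    # float16 integers are exact up to 2048; above that the spacing is 2,
    # so odd integers fall halfway between neighbors and round to even.
    assert float(np.float16(2048)) == 2048.0
    assert float(np.float16(2049)) == 2048.0  # ties-to-even
    assert float(np.float16(2051)) == 2052.0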
@@ -174,7 +174,7 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
{input_column: inputs,
weights: [1] * len(inputs)})
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(accumulator.flush(stamp_token=0, next_stamp_token=1))
are_ready_flush, buckets = (accumulator.get_buckets(stamp_token=1))
buckets, are_ready_flush = (sess.run(
@@ -189,7 +189,7 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
# set generate_quantiles to True since the test will generate fewer
# boundaries otherwise.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accumulator = quantile_ops.QuantileAccumulator(
init_stamp_token=0, num_quantiles=num_quantiles,
epsilon=0.001, name="q1", generate_quantiles=True)
@@ -201,12 +201,12 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
column=input_column,
example_weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(update,
{input_column: inputs,
weights: [1] * len(inputs)})
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(accumulator.flush(stamp_token=0, next_stamp_token=1))
are_ready_flush, buckets = (accumulator.get_buckets(stamp_token=1))
buckets, are_ready_flush = (sess.run(
@@ -265,7 +265,7 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
[9900 9901 .. 9999]
All the batches have 1 for all the example weights.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accumulator = quantile_ops.QuantileAccumulator(
init_stamp_token=0, num_quantiles=3, epsilon=0.01, name="q1")
resources.initialize_resources(resources.shared_resources()).run()
@@ -275,7 +275,7 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
stamp_token=0,
column=dense_placeholder,
example_weights=weight_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(100):
dense_float = np.linspace(
i * 100, (i + 1) * 100 - 1, num=100).reshape(-1, 1)
@@ -284,7 +284,7 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
weight_placeholder: np.ones(shape=(100, 1), dtype=np.float32)
})
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(accumulator.flush(stamp_token=0, next_stamp_token=1))
are_ready_flush, buckets = (accumulator.get_buckets(stamp_token=1))
buckets, are_ready_flush = (sess.run([buckets, are_ready_flush]))
@@ -301,7 +301,7 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
[9900 9901 .. 9999]
All the batches have 1 for all the example weights.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accumulator = quantile_ops.QuantileAccumulator(
init_stamp_token=0, num_quantiles=3, epsilon=0.01, name="q1")
accumulator_2 = quantile_ops.QuantileAccumulator(
@@ -313,7 +313,7 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
stamp_token=0,
column=dense_placeholder,
example_weights=weight_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(100):
dense_float = np.linspace(
i * 100, (i + 1) * 100 - 1, num=100).reshape(-1, 1)
@@ -322,7 +322,7 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
weight_placeholder: np.ones(shape=(100, 1), dtype=np.float32)
})
- with self.test_session() as sess:
+ with self.cached_session() as sess:
summary = sess.run(
accumulator.flush_summary(stamp_token=0, next_stamp_token=1))
sess.run(
@@ -338,7 +338,7 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
accumulator = quantile_ops.QuantileAccumulator(
init_stamp_token=0, num_quantiles=3, epsilon=0.33, name="q0")
@@ -366,7 +366,7 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
self.assertEqual(True, are_ready_flush)
self.assertAllEqual([2, 4, 6.], buckets)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
accumulator = quantile_ops.QuantileAccumulator(
init_stamp_token=0, num_quantiles=3, epsilon=0.33, name="q0")
save = saver.Saver()
@@ -389,7 +389,7 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
accumulator = quantile_ops.QuantileAccumulator(
init_stamp_token=0, num_quantiles=3, epsilon=0.33, name="q0")
@@ -413,7 +413,7 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
self.assertAllEqual([1, 3, 5], buckets)
save.save(sess, save_path)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
accumulator = quantile_ops.QuantileAccumulator(
init_stamp_token=0, num_quantiles=3, epsilon=0.33, name="q0")
save = saver.Saver()
@@ -438,7 +438,7 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
[1] * (int(math.pow(2, 16)) + 1), dtype=dtypes.float32)
config = self._gen_config(0.1, 10)
- with self.test_session():
+ with self.cached_session():
dense_buckets, _ = quantile_ops.quantile_buckets(
[dense_float_tensor_0], [], [], [],
example_weights=example_weights,
@@ -464,7 +464,7 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
config = self._gen_config(0.1, 10)
- with self.test_session():
+ with self.cached_session():
dense_buckets, _ = quantile_ops.quantile_buckets(
[dense_float_tensor_0], [], [], [],
example_weights=example_weights,
@@ -533,7 +533,7 @@ class QuantilesOpTest(test_util.TensorFlowTestCase):
self._sparse_thresholds_m = [1, 2, 1000]
def testDenseFeaturesOnly(self):
- with self.test_session():
+ with self.cached_session():
dense_quantiles, _ = quantile_ops.quantiles(
[self._dense_float_tensor_0, self._dense_float_tensor_1], [],
[self._dense_thresholds_0, self._dense_thresholds_1], [], [])
@@ -546,7 +546,7 @@ class QuantilesOpTest(test_util.TensorFlowTestCase):
dense_quantiles[1].eval())
def testSparseFeaturesOnly(self):
- with self.test_session():
+ with self.cached_session():
_, sparse_quantiles = quantile_ops.quantiles([], [
self._sparse_values_0, self._sparse_values_1, self._sparse_values_2,
self._sparse_values_m
@@ -571,7 +571,7 @@ class QuantilesOpTest(test_util.TensorFlowTestCase):
sparse_quantiles[3].eval())
def testDenseAndSparseFeatures(self):
- with self.test_session():
+ with self.cached_session():
dense_quantiles, sparse_quantiles = quantile_ops.quantiles(
[self._dense_float_tensor_0, self._dense_float_tensor_1], [
self._sparse_values_0, self._sparse_values_1,
@@ -602,14 +602,14 @@ class QuantilesOpTest(test_util.TensorFlowTestCase):
sparse_quantiles[3].eval())
def testBucketizeWithInputBoundaries(self):
- with self.test_session():
+ with self.cached_session():
buckets = quantile_ops.bucketize_with_input_boundaries(
input=[1, 2, 3, 4, 5],
boundaries=[3])
self.assertAllEqual([0, 0, 1, 1, 1], buckets.eval())
def testBucketizeWithInputBoundaries2(self):
- with self.test_session():
+ with self.cached_session():
boundaries = constant_op.constant([3], dtype=dtypes.float32)
buckets = quantile_ops.bucketize_with_input_boundaries(
input=[1, 2, 3, 4, 5],
@@ -617,7 +617,7 @@ class QuantilesOpTest(test_util.TensorFlowTestCase):
self.assertAllEqual([0, 0, 1, 1, 1], buckets.eval())
def testBucketizeWithInputBoundaries3(self):
- with self.test_session():
+ with self.cached_session():
b = array_ops.placeholder(dtypes.float32)
buckets = quantile_ops.bucketize_with_input_boundaries(
input=[1, 2, 3, 4, 5],
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
index 2589504762..74917f7cde 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
@@ -33,7 +33,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
def testMakeDenseSplit(self):
"""Tests split handler op."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following after dividing by number of steps (2).
# Gradients | Partition | Dense Quantile |
# (1.2, 0.2) | 0 | 0 |
@@ -111,7 +111,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
def testMakeMulticlassDenseSplit(self):
"""Tests split handler op."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
partition_ids = array_ops.constant([0, 0, 1], dtype=dtypes.int32)
bucket_ids = array_ops.constant(
[[0, 0], [1, 0], [1, 0]], dtype=dtypes.int64)
@@ -153,7 +153,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
def testMakeDenseSplitEmptyInputs(self):
"""Tests empty inputs op."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
partition_ids = array_ops.constant([], dtype=dtypes.int32)
bucket_ids = array_ops.constant([[]], dtype=dtypes.int64)
gradients = array_ops.constant([])
@@ -183,7 +183,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
def testMakeSparseSplit(self):
"""Tests split handler op."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following after dividing by number of steps (2).
# Gradients | Partition | bucket ID |
# (0.9, 0.39) | 0 | -1 |
@@ -274,7 +274,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
def testMakeSparseSplitAllEmptyDimensions(self):
"""Tests split handler op when all dimensions have only bias bucket id."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following after dividing by number of steps (2).
# Gradients | Partition | Dimension | bucket ID |
# (0.9, 0.39) | 0 | 0 | -1 |
@@ -307,7 +307,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
def testMakeSparseMultidimensionalSplit(self):
"""Tests split handler op."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Num of steps is 2.
# The feature column is three dimensional.
# First dimension has bias bucket only, the second has bias bucket and
@@ -408,7 +408,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
"""Tests default direction is stable when no sparsity."""
random.seed(1123)
for _ in range(50):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
grad = random.random()
hessian = random.random()
# The data looks like the following (divide by the num of steps 2).
@@ -465,7 +465,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
def testMakeMulticlassSparseSplit(self):
"""Tests split handler op."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
partition_ids = array_ops.constant([0, 0, 0, 1, 1], dtype=dtypes.int32)
bucket_ids = array_ops.constant(
[[-1, 0], [0, 0], [1, 0], [-1, 0], [1, 0]], dtype=dtypes.int64)
@@ -514,7 +514,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
def testMakeCategoricalEqualitySplit(self):
"""Tests split handler op for categorical equality split."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following after dividing by number of steps (2).
# Gradients | Partition | Feature ID |
# (0.9, 0.39) | 0 | -1 |
@@ -541,7 +541,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
feature_column_group_id=0,
bias_feature_id=-1,
class_id=-1,
- multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
self.assertAllEqual([0, 1], partitions)
@@ -608,7 +609,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
def testMakeMulticlassCategoricalEqualitySplit(self):
"""Tests split handler op for categorical equality split in multiclass."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gradients = array_ops.constant([[1.8, 3.5], [2.4, 1.0], [0.4, 4.0],
[9.0, 3.1], [3.0, 0.8]])
@@ -637,7 +638,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
feature_column_group_id=0,
bias_feature_id=-1,
class_id=-1,
- multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN))
+ multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
self.assertAllEqual([0, 1], partitions)
@@ -655,7 +657,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
self.assertEqual(1, split_node.feature_id)
def testMakeCategoricalEqualitySplitEmptyInput(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gradients = []
hessians = []
partition_ids = []
@@ -674,7 +676,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
feature_column_group_id=0,
bias_feature_id=-1,
class_id=-1,
- multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = (sess.run([partitions, gains, splits]))
self.assertEqual(0, len(partitions))
self.assertEqual(0, len(gains))
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py
index 978bf530cd..05ce0884cc 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py
@@ -29,7 +29,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase):
"""Tests for scalar gradients and hessians accumulator."""
def testSimpleAcculumator(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accumulator = stats_accumulator_ops.StatsAccumulator(
stamp_token=0,
gradient_shape=tensor_shape.scalar(),
@@ -57,7 +57,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase):
self.assertAllClose(result[(2, 3, 0)], [0.3, 0.4])
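The result here is keyed by (partition_id, feature_id, dimension), with each
key accumulating a [gradient, hessian] pair. A plain-Python sketch of that
bookkeeping (hypothetical first entry; an illustration, not the op itself):

    from collections import defaultdict

    acc = defaultdict(lambda: [0.0, 0.0])
    for key, grad, hess in [((1, 2, 0), 0.1, 0.2), ((2, 3, 0), 0.3, 0.4)]:
        acc[key][0] += grad
        acc[key][1] += hess
    assert acc[(2, 3, 0)] == [0.3, 0.4]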
def testMultidimensionalAcculumator(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accumulator = stats_accumulator_ops.StatsAccumulator(
stamp_token=0,
gradient_shape=tensor_shape.scalar(),
@@ -86,7 +86,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase):
self.assertAllClose(result[(2, 3, 1)], [0.1, 0.2])
def testDropStaleUpdate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accumulator = stats_accumulator_ops.StatsAccumulator(
stamp_token=0,
gradient_shape=tensor_shape.scalar(),
@@ -118,7 +118,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase):
self.assertAllClose(result[(2, 3, 0)], [0.3, 0.4])
def testSerialize(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accumulator = stats_accumulator_ops.StatsAccumulator(
stamp_token=0,
gradient_shape=tensor_shape.scalar(),
@@ -159,7 +159,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase):
self.assertEqual(0, stamp_token)
def testDeserialize(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accumulator = stats_accumulator_ops.StatsAccumulator(
stamp_token=0,
gradient_shape=tensor_shape.scalar(),
@@ -196,7 +196,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase):
self.assertAllClose(result[(4, 6, 2)], [0.5, 0.7])
def testMakeSummary(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accumulator = stats_accumulator_ops.StatsAccumulator(
stamp_token=0,
gradient_shape=tensor_shape.scalar(),
@@ -218,7 +218,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase):
"""Tests for tensor gradients and hessians accumulator."""
def testSimpleAcculumator(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accumulator = stats_accumulator_ops.StatsAccumulator(
stamp_token=0,
gradient_shape=tensor_shape.TensorShape([2]),
@@ -256,7 +256,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase):
self.assertAllClose(result[(2, 3, 0)][1], [[0.05, 0.06], [0.07, 0.08]])
def testMultidimensionalAcculumator(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accumulator = stats_accumulator_ops.StatsAccumulator(
stamp_token=0,
gradient_shape=tensor_shape.TensorShape([2]),
@@ -294,7 +294,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase):
self.assertAllClose(result[(2, 3, 1)][1], [[0.05, 0.06], [0.07, 0.08]])
def testDropStaleUpdate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accumulator = stats_accumulator_ops.StatsAccumulator(
stamp_token=0,
gradient_shape=tensor_shape.TensorShape([2]),
@@ -331,7 +331,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase):
self.assertAllClose(result[(2, 3, 0)][1], [[0.05, 0.06], [0.07, 0.08]])
def testSerialize(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accumulator = stats_accumulator_ops.StatsAccumulator(
stamp_token=0,
gradient_shape=tensor_shape.TensorShape([2]),
@@ -381,7 +381,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase):
self.assertAllEqual(result_1[2, 3, 0][1], result_2[2, 3, 0][1])
def testDeserialize(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accumulator = stats_accumulator_ops.StatsAccumulator(
stamp_token=0,
gradient_shape=tensor_shape.TensorShape([2]),
@@ -425,7 +425,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase):
self.assertAllClose(result[(4, 5, 0)][1], [[0.07, 0.08], [0.09, 0.10]])
def testMakeSummary(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accumulator = stats_accumulator_ops.StatsAccumulator(
stamp_token=0,
gradient_shape=tensor_shape.TensorShape([2]),
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
index e39e1de8d1..86fd5770a0 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
@@ -91,6 +91,31 @@ def _gen_dense_split_info(fc, threshold, left_weight, right_weight):
return split.SerializeToString()
+def _gen_dense_oblivious_split_info(fc, threshold, leave_weights,
+ children_parent_id):
+ split_str = """
+ split_node {
+ oblivious_dense_float_binary_split {
+ feature_column: %d
+ threshold: %f
+ }
+ }""" % (fc, threshold)
+ for weight in leave_weights:
+ split_str += """
+ children {
+ vector {
+ value: %f
+ }
+ }""" % (
+ weight)
+ for x in children_parent_id:
+ split_str += """
+ children_parent_id: %d""" % (x)
+ split = split_info_pb2.ObliviousSplitInfo()
+ text_format.Merge(split_str, split)
+ return split.SerializeToString()
+
+
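For reference, a call such as _gen_dense_oblivious_split_info(0, 0.52,
[-4.375, 7.143], [0]) assembles (before serialization, modulo whitespace)
a text proto along these lines:

    split_node {
      oblivious_dense_float_binary_split {
        feature_column: 0
        threshold: 0.520000
      }
    }
    children {
      vector {
        value: -4.375000
      }
    }
    children {
      vector {
        value: 7.143000
      }
    }
    children_parent_id: 0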
def _gen_categorical_split_info(fc, feat_id, left_weight, right_weight):
split_str = """
split_node {
@@ -125,7 +150,7 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase):
def testCenterBias(self):
"""Tests bias centering for multiple iterations."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
tree_ensemble_handle = model_ops.tree_ensemble_variable(
@@ -276,7 +301,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowEmptyEnsemble(self):
"""Test growing an empty ensemble."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
tree_ensemble_handle = model_ops.tree_ensemble_variable(
@@ -324,7 +349,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
# Expect the simpler split from handler 1 to be chosen.
@@ -383,9 +409,122 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
self.assertEqual(stats.attempted_layers, 1)
self.assertProtoEquals(expected_result, tree_ensemble_config)
+ def testGrowEmptyEnsembleObliviousCase(self):
+ """Test growing an empty ensemble in the oblivious case."""
+ with self.cached_session() as session:
+ # Create empty ensemble.
+ tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+ tree_ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0,
+ tree_ensemble_config=tree_ensemble_config.SerializeToString(),
+ name="tree_ensemble")
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Prepare learner config.
+ learner_config = _gen_learner_config(
+ num_classes=2,
+ l1_reg=0,
+ l2_reg=0,
+ tree_complexity=0,
+ max_depth=1,
+ min_node_weight=0,
+ pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
+ growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
+
+ # Prepare handler inputs.
+ # Note that handlers 1 & 3 have the same gain but different splits.
+ handler1_partitions = np.array([0], dtype=np.int32)
+ handler1_gains = np.array([7.62], dtype=np.float32)
+ handler1_split = [
+ _gen_dense_oblivious_split_info(0, 0.52, [-4.375, 7.143], [0])
+ ]
+ handler2_partitions = np.array([0], dtype=np.int32)
+ handler2_gains = np.array([0.63], dtype=np.float32)
+ handler2_split = [
+ _gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24], [0])
+ ]
+ handler3_partitions = np.array([0], dtype=np.int32)
+ handler3_gains = np.array([7.62], dtype=np.float32)
+ handler3_split = [
+ _gen_dense_oblivious_split_info(0, 7, [-4.375, 7.143], [0])
+ ]
+
+ # Grow tree ensemble.
+ grow_op = training_ops.grow_tree_ensemble(
+ tree_ensemble_handle,
+ stamp_token=0,
+ next_stamp_token=1,
+ learning_rate=0.1,
+ partition_ids=[
+ handler1_partitions, handler2_partitions, handler3_partitions
+ ],
+ gains=[handler1_gains, handler2_gains, handler3_gains],
+ splits=[handler1_split, handler2_split, handler3_split],
+ learner_config=learner_config.SerializeToString(),
+ dropout_seed=123,
+ center_bias=True,
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
+ session.run(grow_op)
+
+      # Expect the split with the bigger handler_id (handler 3) to be chosen.
+ # The grown tree should be finalized as max tree depth is 1.
+ new_stamp, serialized = session.run(
+ model_ops.tree_ensemble_serialize(tree_ensemble_handle))
+ stats = session.run(
+ training_ops.tree_ensemble_stats(tree_ensemble_handle, stamp_token=1))
+ tree_ensemble_config.ParseFromString(serialized)
+ expected_result = """
+ trees {
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 0
+ threshold: 7
+ }
+ node_metadata {
+ gain: 7.62
+ original_oblivious_leaves {
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -4.375
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 7.143
+ }
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_metadata {
+ num_tree_weight_updates: 1
+ num_layers_grown: 1
+ is_finalized: true
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 1
+ }
+ """
+ self.assertEqual(new_stamp, 1)
+ self.assertEqual(stats.num_trees, 1)
+ self.assertEqual(stats.num_layers, 1)
+ self.assertEqual(stats.active_tree, 1)
+ self.assertEqual(stats.active_layer, 1)
+ self.assertEqual(stats.attempted_trees, 1)
+ self.assertEqual(stats.attempted_layers, 1)
+ self.assertProtoEquals(expected_result, tree_ensemble_config)
+
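The tie-break asserted above can be sketched in isolation (an assumption
matching the comment, not the op's actual code): among equal best gains, the
handler with the largest index wins.

    gains = [7.62, 0.63, 7.62]  # handlers 1..3
    best = max(range(len(gains)), key=lambda i: (gains[i], i))
    assert best == 2  # zero-based index of handler 3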
def testGrowExistingEnsembleTreeNotFinalized(self):
"""Test growing an existing ensemble with the last tree not finalized."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create existing ensemble with one root split
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
text_format.Merge("""
@@ -476,7 +615,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
# Expect the split for partition 1 to be chosen from handler 1 and
@@ -575,7 +715,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowExistingEnsembleTreeFinalized(self):
"""Test growing an existing ensemble with the last tree finalized."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create existing ensemble with one root split
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
text_format.Merge("""
@@ -661,7 +801,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
# Expect a new tree to be added with the split from handler 1.
@@ -757,7 +898,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowEnsemblePrePrune(self):
"""Test growing an ensemble with pre-pruning."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
tree_ensemble_handle = model_ops.tree_ensemble_variable(
@@ -798,7 +939,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
# Expect the ensemble to be empty.
@@ -823,7 +965,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowEnsemblePostPruneNone(self):
"""Test growing an empty ensemble."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
tree_ensemble_handle = model_ops.tree_ensemble_variable(
@@ -869,7 +1011,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
# Expect the simpler split from handler 1 to be chosen.
@@ -930,7 +1073,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowEnsemblePostPruneAll(self):
"""Test growing an ensemble with post-pruning."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
tree_ensemble_handle = model_ops.tree_ensemble_variable(
@@ -971,7 +1114,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
# Expect the split from handler 2 to be chosen despite the negative gain.
@@ -1053,7 +1197,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
# Expect the ensemble to be empty as post-pruning will prune
@@ -1079,7 +1224,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowEnsemblePostPrunePartial(self):
"""Test growing an ensemble with post-pruning."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
tree_ensemble_handle = model_ops.tree_ensemble_variable(
@@ -1120,7 +1265,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
# Expect the split from handler 2 to be chosen despite the negative gain.
@@ -1200,7 +1346,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
# Expect the negative gain split of partition 1 to be pruned and the
@@ -1280,7 +1427,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowEnsembleTreeLayerByLayer(self):
"""Test growing an existing ensemble with the last tree not finalized."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create existing ensemble with one root split
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
text_format.Merge("""
@@ -1371,7 +1518,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
# Expect the split for partition 1 to be chosen from handler 1 and
@@ -1470,9 +1618,721 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
self.assertEqual(stats.attempted_layers, 2)
self.assertProtoEquals(expected_result, tree_ensemble_config)
+ def testGrowEnsembleTreeLayerByLayerObliviousCase(self):
+    """Test growing an existing oblivious ensemble layer by layer."""
+ with self.cached_session() as session:
+ # Create existing ensemble with one root split
+ tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 4
+ threshold: 7
+ }
+ node_metadata {
+ gain: 7.62
+ original_oblivious_leaves {
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 7.143
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -4.375
+ }
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_metadata {
+ num_tree_weight_updates: 1
+ num_layers_grown: 1
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 1
+ }
+ """, tree_ensemble_config)
+ tree_ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0,
+ tree_ensemble_config=tree_ensemble_config.SerializeToString(),
+ name="tree_ensemble")
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Prepare learner config.
+ learner_config = _gen_learner_config(
+ num_classes=2,
+ l1_reg=0,
+ l2_reg=0,
+ tree_complexity=0,
+ max_depth=3,
+ min_node_weight=0,
+ pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
+ growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER)
+
+ # Prepare handler inputs.
+ handler1_partitions = np.array([0], dtype=np.int32)
+ handler1_gains = np.array([1.4], dtype=np.float32)
+ handler1_split = [
+ _gen_dense_oblivious_split_info(0, 0.21, [-6.0, 1.65, 1.0, -0.5],
+ [1, 2])
+ ]
+ handler2_partitions = np.array([0], dtype=np.int32)
+ handler2_gains = np.array([2.7], dtype=np.float32)
+ handler2_split = [
+ _gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24, 0.3, 0.4],
+ [1, 2])
+ ]
+ handler3_partitions = np.array([0], dtype=np.int32)
+ handler3_gains = np.array([1.7], dtype=np.float32)
+ handler3_split = [
+ _gen_dense_oblivious_split_info(0, 3, [-0.75, 1.93, 0.2, -0.1],
+ [1, 2])
+ ]
+
+ # Grow tree ensemble layer by layer.
+ grow_op = training_ops.grow_tree_ensemble(
+ tree_ensemble_handle,
+ stamp_token=0,
+ next_stamp_token=1,
+ learning_rate=0.1,
+ partition_ids=[
+ handler1_partitions, handler2_partitions, handler3_partitions
+ ],
+ gains=[handler1_gains, handler2_gains, handler3_gains],
+ splits=[handler1_split, handler2_split, handler3_split],
+ learner_config=learner_config.SerializeToString(),
+ dropout_seed=123,
+ center_bias=True,
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
+ session.run(grow_op)
+
+ # Expect the split for partition 1 to be chosen from handler 1 and
+ # the split for partition 2 to be chosen from handler 2.
+ # The grown tree should not be finalized as max tree depth is 3 and
+      # it has only grown 2 layers.
+ # The partition 1 split weights get added to original leaf weight 7.143.
+ # The partition 2 split weights get added to original leaf weight -4.375.
+ new_stamp, serialized = session.run(
+ model_ops.tree_ensemble_serialize(tree_ensemble_handle))
+ stats = session.run(
+ training_ops.tree_ensemble_stats(tree_ensemble_handle, stamp_token=1))
+ tree_ensemble_config.ParseFromString(serialized)
+ expected_result = """
+ trees {
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 4
+ threshold: 7
+ }
+ node_metadata {
+ gain: 7.62
+ original_oblivious_leaves {
+ }
+ }
+ }
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 0
+ threshold: 0.23
+ }
+ node_metadata {
+ gain: 2.7
+ original_oblivious_leaves {
+ vector {
+ value: 7.143
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: -4.375
+ }
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 6.543
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 7.383
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -4.075
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -3.975
+ }
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_metadata {
+ num_tree_weight_updates: 1
+ num_layers_grown: 2
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 2
+ }
+ """
+ self.assertEqual(new_stamp, 1)
+ self.assertEqual(stats.num_trees, 0)
+ self.assertEqual(stats.num_layers, 2)
+ self.assertEqual(stats.active_tree, 1)
+ self.assertEqual(stats.active_layer, 2)
+ self.assertEqual(stats.attempted_trees, 1)
+ self.assertEqual(stats.attempted_layers, 2)
+ self.assertProtoEquals(expected_result, tree_ensemble_config)
+
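The leaf values in the expectation above are just the original oblivious
leaves plus the chosen handler's deltas; a sketch of that arithmetic:

    original = [7.143, 7.143, -4.375, -4.375]  # each old leaf splits in two
    deltas = [-0.6, 0.24, 0.3, 0.4]            # handler 2's leave_weights
    new_leaves = [round(o + d, 3) for o, d in zip(original, deltas)]
    assert new_leaves == [6.543, 7.383, -4.075, -3.975]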
+ def testGrowEnsembleWithEmptyNodesMiddleCase(self):
+ """Test case: The middle existing leaves don't have examples."""
+ with self.cached_session() as session:
+ tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 4
+ threshold: 7
+ }
+ node_metadata {
+ gain: 7.62
+ original_oblivious_leaves {
+ }
+ }
+ }
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 1
+ threshold: 0.23
+ }
+ node_metadata {
+ gain: 2.7
+ original_oblivious_leaves {
+ vector {
+ value: 7.143
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: -4.375
+ }
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 6.543
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 7.5
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -4.075
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -3.975
+ }
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_metadata {
+ num_tree_weight_updates: 1
+ num_layers_grown: 2
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 2
+ }
+ """, tree_ensemble_config)
+ tree_ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0,
+ tree_ensemble_config=tree_ensemble_config.SerializeToString(),
+ name="tree_ensemble")
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Prepare learner config.
+ learner_config = _gen_learner_config(
+ num_classes=2,
+ l1_reg=0,
+ l2_reg=0,
+ tree_complexity=0,
+ max_depth=6,
+ min_node_weight=0,
+ pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
+ growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER)
+
+ # Prepare handler inputs.
+ handler1_partitions = np.array([0], dtype=np.int32)
+ handler1_gains = np.array([1.8], dtype=np.float32)
+ handler1_split = [
+ _gen_dense_oblivious_split_info(0, 0.9, [1.0, 2.0, 3.0, 4.0], [2, 5])
+ ]
+ # The tree currently has depth 2, so the ids for the four leaves are in
+ # the range [2, 6). In this test case we are assuming that our examples
+ # only fall in leaves 2 and 5.
+
+ # Grow tree ensemble layer by layer.
+ grow_op = training_ops.grow_tree_ensemble(
+ tree_ensemble_handle,
+ stamp_token=0,
+ next_stamp_token=1,
+ learning_rate=0.1,
+ partition_ids=[handler1_partitions],
+ gains=[handler1_gains],
+ splits=[handler1_split],
+ learner_config=learner_config.SerializeToString(),
+ dropout_seed=123,
+ center_bias=True,
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
+ session.run(grow_op)
+
+ new_stamp, serialized = session.run(
+ model_ops.tree_ensemble_serialize(tree_ensemble_handle))
+ stats = session.run(
+ training_ops.tree_ensemble_stats(tree_ensemble_handle, stamp_token=1))
+ tree_ensemble_config.ParseFromString(serialized)
+ expected_result = """
+ trees {
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 4
+ threshold: 7
+ }
+ node_metadata {
+ gain: 7.62
+ original_oblivious_leaves {
+ }
+ }
+ }
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 1
+ threshold: 0.23
+ }
+ node_metadata {
+ gain: 2.7
+ original_oblivious_leaves {
+ vector {
+ value: 7.143
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: -4.375
+ }
+ }
+ }
+ }
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 0
+ threshold: 0.9
+ }
+ node_metadata {
+ gain: 1.8
+ original_oblivious_leaves {
+ vector {
+ value: 6.543
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: 7.5
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: -4.075
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: -3.975
+ }
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 7.543
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 8.543
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 7.5
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 7.5
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -4.075
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -4.075
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -0.975
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 0.025
+ }
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_metadata {
+ num_tree_weight_updates: 1
+ num_layers_grown: 3
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 3
+ }
+ """
+ self.assertEqual(new_stamp, 1)
+ self.assertEqual(stats.num_trees, 0)
+ self.assertEqual(stats.num_layers, 3)
+ self.assertEqual(stats.active_tree, 1)
+ self.assertEqual(stats.active_layer, 3)
+ self.assertEqual(stats.attempted_trees, 1)
+ self.assertEqual(stats.attempted_layers, 3)
+ self.assertProtoEquals(expected_result, tree_ensemble_config)
+
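A sketch of how the empty middle leaves are carried forward in the
expectation above (an illustration consistent with the expected proto, not
the op's code): leaves named in children_parent_id split and receive the
deltas in order, while untouched leaves are duplicated to keep the tree full.

    old_leaves = {2: 6.543, 3: 7.5, 4: -4.075, 5: -3.975}
    deltas = iter([1.0, 2.0, 3.0, 4.0])
    split_ids = {2, 5}
    new_leaves = []
    for leaf_id in sorted(old_leaves):
        if leaf_id in split_ids:
            w = old_leaves[leaf_id]
            new_leaves += [w + next(deltas), w + next(deltas)]
        else:
            new_leaves += [old_leaves[leaf_id]] * 2
    new_leaves = [round(v, 3) for v in new_leaves]
    assert new_leaves == [7.543, 8.543, 7.5, 7.5,
                          -4.075, -4.075, -0.975, 0.025]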
+ def testGrowEnsembleWithEmptyNodesBorderCase(self):
+ """Test case: The first and last existing leaves don't have examples."""
+ with self.cached_session() as session:
+ tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 4
+ threshold: 7
+ }
+ node_metadata {
+ gain: 7.62
+ original_oblivious_leaves {
+ }
+ }
+ }
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 1
+ threshold: 0.23
+ }
+ node_metadata {
+ gain: 2.7
+ original_oblivious_leaves {
+ vector {
+ value: 7.143
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: -4.375
+ }
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 6.543
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 7.5
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -4.075
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -3.975
+ }
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_metadata {
+ num_tree_weight_updates: 1
+ num_layers_grown: 2
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 2
+ }
+ """, tree_ensemble_config)
+ tree_ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0,
+ tree_ensemble_config=tree_ensemble_config.SerializeToString(),
+ name="tree_ensemble")
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Prepare learner config.
+ learner_config = _gen_learner_config(
+ num_classes=2,
+ l1_reg=0,
+ l2_reg=0,
+ tree_complexity=0,
+ max_depth=6,
+ min_node_weight=0,
+ pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
+ growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER)
+
+ # Prepare handler inputs.
+ handler1_partitions = np.array([0], dtype=np.int32)
+ handler1_gains = np.array([1.8], dtype=np.float32)
+ handler1_split = [
+ _gen_dense_oblivious_split_info(0, 0.9, [1.0, 2.0, 3.0, 4.0], [3, 4])
+ ]
+ # The tree currently has depth 2, so the ids for the four leaves are in
+ # the range [2, 6). In this test case we are assuming that our examples
+ # only fall in leaves 3 and 4.
+
+ # Grow tree ensemble layer by layer.
+ grow_op = training_ops.grow_tree_ensemble(
+ tree_ensemble_handle,
+ stamp_token=0,
+ next_stamp_token=1,
+ learning_rate=0.1,
+ partition_ids=[handler1_partitions],
+ gains=[handler1_gains],
+ splits=[handler1_split],
+ learner_config=learner_config.SerializeToString(),
+ dropout_seed=123,
+ center_bias=True,
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
+ session.run(grow_op)
+
+ new_stamp, serialized = session.run(
+ model_ops.tree_ensemble_serialize(tree_ensemble_handle))
+ stats = session.run(
+ training_ops.tree_ensemble_stats(tree_ensemble_handle, stamp_token=1))
+ tree_ensemble_config.ParseFromString(serialized)
+ expected_result = """
+ trees {
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 4
+ threshold: 7
+ }
+ node_metadata {
+ gain: 7.62
+ original_oblivious_leaves {
+ }
+ }
+ }
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 1
+ threshold: 0.23
+ }
+ node_metadata {
+ gain: 2.7
+ original_oblivious_leaves {
+ vector {
+ value: 7.143
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: -4.375
+ }
+ }
+ }
+ }
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 0
+ threshold: 0.9
+ }
+ node_metadata {
+ gain: 1.8
+ original_oblivious_leaves {
+ vector {
+ value: 6.543
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: 7.5
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: -4.075
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: -3.975
+ }
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 6.543
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 6.543
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 8.5
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 9.5
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -1.075
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -0.075
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -3.975
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -3.975
+ }
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_metadata {
+ num_tree_weight_updates: 1
+ num_layers_grown: 3
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 3
+ }
+ """
+ self.assertEqual(new_stamp, 1)
+ self.assertEqual(stats.num_trees, 0)
+ self.assertEqual(stats.num_layers, 3)
+ self.assertEqual(stats.active_tree, 1)
+ self.assertEqual(stats.active_layer, 3)
+ self.assertEqual(stats.attempted_trees, 1)
+ self.assertEqual(stats.attempted_layers, 3)
+ self.assertProtoEquals(expected_result, tree_ensemble_config)
+
def testGrowExistingEnsembleTreeFinalizedWithDropout(self):
"""Test growing an existing ensemble with the last tree finalized."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create existing ensemble with one root split and one bias tree.
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
text_format.Merge("""
@@ -1575,7 +2435,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
# Expect a new tree to be added with the split from handler 1.
@@ -1596,7 +2457,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowExistingEnsembleTreeWithFeatureSelectionUsedHandlers(self):
"""Test growing a tree with feature selection."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create existing ensemble with one root split and one bias tree.
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
text_format.Merge("""
@@ -1700,7 +2561,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
_, serialized = session.run(
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index 20ff48c360..c7eb2493a8 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -218,6 +218,21 @@ def extract_features(features, feature_columns, use_core_columns):
sparse_int_shapes = []
for key in sorted(features.keys()):
tensor = features[key]
+ # TODO(nponomareva): consider iterating over feature columns instead.
+ if isinstance(tensor, tuple):
+ # Weighted categorical feature.
+ categorical_tensor = tensor[0]
+ weight_tensor = tensor[1]
+
+ shape = categorical_tensor.dense_shape
+ indices = array_ops.concat([
+ array_ops.slice(categorical_tensor.indices, [0, 0], [-1, 1]),
+ array_ops.expand_dims(
+ math_ops.to_int64(categorical_tensor.values), -1)
+ ], 1)
+ tensor = sparse_tensor.SparseTensor(
+ indices=indices, values=weight_tensor.values, dense_shape=shape)
+
if isinstance(tensor, sparse_tensor.SparseTensor):
if tensor.values.dtype == dtypes.float32:
sparse_float_names.append(key)
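To make the index rewrite in this hunk concrete, here is a minimal NumPy sketch of the same transformation: the first column of the sparse indices (the example row) is kept, the categorical value itself becomes the new column coordinate, and the weights become the values of the resulting `SparseTensor`. All input data below is invented for illustration.

```python
import numpy as np

# Hypothetical weighted categorical feature: each example row carries
# (category id, weight) pairs, stored as two aligned sparse tensors with the
# same indices.
cat_indices = np.array([[0, 0], [1, 0], [1, 1]], dtype=np.int64)
cat_values = np.array([5, 2, 7], dtype=np.int64)  # category ids
weight_values = np.array([0.5, 1.0, 0.25])        # per-category weights

# Mirror the patch: keep only the row coordinate of each entry, then use the
# category id itself as the new column coordinate.
rows = cat_indices[:, :1]            # slice(indices, [0, 0], [-1, 1])
cols = cat_values[:, None]           # expand_dims(to_int64(values), -1)
new_indices = np.concatenate([rows, cols], axis=1)

print(new_indices.tolist())    # [[0, 5], [1, 2], [1, 7]]
print(weight_values.tolist())  # values of the resulting weight SparseTensor
```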
@@ -289,7 +304,8 @@ class GradientBoostedDecisionTreeModel(object):
feature_columns=None,
use_core_columns=False,
output_leaf_index=False,
- output_leaf_index_modes=None):
+ output_leaf_index_modes=None,
+ num_quantiles=100):
"""Construct a new GradientBoostedDecisionTreeModel function.
Args:
@@ -312,6 +328,7 @@ class GradientBoostedDecisionTreeModel(object):
output_leaf_index_modes: A list of modes from (TRAIN, EVAL, INFER) which
dictates when leaf indices will be outputted. By default, leaf indices
are only outputted in INFER mode.
+ num_quantiles: Number of quantiles to build for numeric feature values.
Raises:
ValueError: if inputs are not valid.
@@ -384,6 +401,7 @@ class GradientBoostedDecisionTreeModel(object):
self._learner_config = learner_config
self._feature_columns = feature_columns
self._learner_config_serialized = learner_config.SerializeToString()
+ self._num_quantiles = num_quantiles
self._max_tree_depth = variables.Variable(
initial_value=self._learner_config.constraints.max_tree_depth)
self._attempted_trees = variables.Variable(
@@ -674,8 +692,8 @@ class GradientBoostedDecisionTreeModel(object):
loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction)
weak_learner_type = constant_op.constant(
self._learner_config.weak_learner_type)
- epsilon = 0.01
- num_quantiles = 100
+ num_quantiles = self._num_quantiles
+ epsilon = 1.0 / num_quantiles
strategy_tensor = constant_op.constant(strategy)
with ops.device(self._get_replica_device_setter(worker_device)):
# Create handlers for dense float columns
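The hunk above replaces the hard-coded quantile sketch parameters with a value derived from the new `num_quantiles` argument; with the default of 100 quantiles this reproduces the previous hard-coded epsilon of 0.01. A quick sketch of the relationship:

```python
# Error budget of the quantile sketch as a function of the requested number
# of quantiles; num_quantiles=100 reproduces the old hard-coded epsilon.
for num_quantiles in (10, 100, 1000):
    epsilon = 1.0 / num_quantiles
    print(num_quantiles, epsilon)  # 10 0.1 / 100 0.01 / 1000 0.001
```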
@@ -747,7 +765,8 @@ class GradientBoostedDecisionTreeModel(object):
hessian_shape=self._hessian_shape,
multiclass_strategy=strategy_tensor,
init_stamp_token=init_stamp_token,
- loss_uses_sum_reduction=loss_uses_sum_reduction))
+ loss_uses_sum_reduction=loss_uses_sum_reduction,
+ weak_learner_type=weak_learner_type))
fc_name_idx += 1
# Create ensemble stats variables.
@@ -1048,6 +1067,12 @@ class GradientBoostedDecisionTreeModel(object):
# Grow the ensemble given the current candidates.
sizes = array_ops.unstack(split_sizes)
partition_ids_list = list(array_ops.split(partition_ids, sizes, axis=0))
+      # When using an oblivious decision tree as the weak learner, each
+      # handler produces exactly one gain and one split, rather than one per
+      # partition.
+ if self._learner_config.weak_learner_type == (
+ learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE):
+ sizes = len(training_state.handlers)
+
gains_list = list(array_ops.split(gains, sizes, axis=0))
split_info_list = list(array_ops.split(split_infos, sizes, axis=0))
return training_ops.grow_tree_ensemble(
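As a minimal sketch of the size change above (with made-up numbers): `array_ops.split` accepts either a list of per-piece sizes or an integer number of equal pieces, so for oblivious trees splitting the gains into `len(handlers)` pieces yields exactly one candidate per handler.

```python
import numpy as np

gains = np.array([0.9, 1.8, 0.3])  # one gain per handler (oblivious mode)
num_handlers = 3                   # stands in for len(training_state.handlers)

# Integer second argument: split into that many equal pieces, one per handler.
gains_list = np.split(gains, num_handlers)
print([g.item() for g in gains_list])  # [0.9, 1.8, 0.3]
```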
@@ -1061,7 +1086,8 @@ class GradientBoostedDecisionTreeModel(object):
learner_config=self._learner_config_serialized,
dropout_seed=dropout_seed,
center_bias=self._center_bias,
- max_tree_depth=self._max_tree_depth)
+ max_tree_depth=self._max_tree_depth,
+ weak_learner_type=self._learner_config.weak_learner_type)
def _grow_ensemble_not_ready_fn():
# Don't grow the ensemble, just update the stamp.
@@ -1076,7 +1102,8 @@ class GradientBoostedDecisionTreeModel(object):
learner_config=self._learner_config_serialized,
dropout_seed=dropout_seed,
center_bias=self._center_bias,
- max_tree_depth=self._max_tree_depth)
+ max_tree_depth=self._max_tree_depth,
+ weak_learner_type=self._learner_config.weak_learner_type)
def _grow_ensemble_fn():
# Conditionally grow an ensemble depending on whether the splits
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
index f7867d882d..73e41bc457 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from google.protobuf import text_format
from tensorflow.contrib import layers
+from tensorflow.contrib import learn
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
from tensorflow.contrib.boosted_trees.python.ops import model_ops
@@ -314,6 +315,162 @@ class GbdtTest(test_util.TensorFlowTestCase):
}"""
self.assertProtoEquals(expected_tree, output.trees[0])
+ def testObliviousDecisionTreeAsWeakLearner(self):
+ with self.test_session():
+ ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.learning_rate_tuner.fixed.learning_rate = 1
+ learner_config.regularization.l1 = 0
+ learner_config.regularization.l2 = 0
+ learner_config.constraints.max_tree_depth = 2
+ learner_config.constraints.min_node_weight = 0
+ learner_config.weak_learner_type = (
+ learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
+ learner_config.pruning_mode = learner_pb2.LearnerConfig.PRE_PRUNE
+ learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER
+ features = {}
+ features["dense_float"] = array_ops.constant([[-2], [-1], [1], [2]],
+ dtypes.float32)
+
+ gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
+ is_chief=True,
+ num_ps_replicas=0,
+ center_bias=False,
+ ensemble_handle=ensemble_handle,
+ examples_per_layer=1,
+ learner_config=learner_config,
+ logits_dimension=1,
+ features=features)
+
+ predictions_dict = gbdt_model.predict(learn.ModeKeys.TRAIN)
+ predictions = predictions_dict["predictions"]
+ labels = array_ops.constant([[-2], [-1], [1], [2]], dtypes.float32)
+ weights = array_ops.ones([4, 1], dtypes.float32)
+
+ train_op = gbdt_model.train(
+ loss=math_ops.reduce_mean(
+ _squared_loss(labels, weights, predictions)),
+ predictions_dict=predictions_dict,
+ labels=labels)
+ variables.global_variables_initializer().run()
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # On first run, expect no splits to be chosen because the quantile
+ # buckets will not be ready.
+ train_op.run()
+ stamp_token, serialized = model_ops.tree_ensemble_serialize(
+ ensemble_handle)
+ output = tree_config_pb2.DecisionTreeEnsembleConfig()
+ output.ParseFromString(serialized.eval())
+ self.assertEquals(len(output.trees), 0)
+ self.assertEquals(len(output.tree_weights), 0)
+ self.assertEquals(stamp_token.eval(), 1)
+
+ # Second run.
+ train_op.run()
+ stamp_token, serialized = model_ops.tree_ensemble_serialize(
+ ensemble_handle)
+ output = tree_config_pb2.DecisionTreeEnsembleConfig()
+ output.ParseFromString(serialized.eval())
+ self.assertEquals(len(output.trees), 1)
+ self.assertAllClose(output.tree_weights, [1])
+ self.assertEquals(stamp_token.eval(), 2)
+ expected_tree = """
+ nodes {
+ oblivious_dense_float_binary_split {
+ threshold: -1.0
+ }
+ node_metadata {
+ gain: 4.5
+ original_oblivious_leaves {
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -1.5
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 1.5
+ }
+ }
+ }"""
+ self.assertProtoEquals(expected_tree, output.trees[0])
+ # Third run.
+ train_op.run()
+ stamp_token, serialized = model_ops.tree_ensemble_serialize(
+ ensemble_handle)
+ output = tree_config_pb2.DecisionTreeEnsembleConfig()
+ output.ParseFromString(serialized.eval())
+ self.assertEquals(len(output.trees), 1)
+ self.assertAllClose(output.tree_weights, [1])
+ self.assertEquals(stamp_token.eval(), 3)
+ expected_tree = """
+ nodes {
+ oblivious_dense_float_binary_split {
+ threshold: -1.0
+ }
+ node_metadata {
+ gain: 4.5
+ original_oblivious_leaves {
+ }
+ }
+ }
+ nodes {
+ oblivious_dense_float_binary_split {
+ threshold: -2.0
+ }
+ node_metadata {
+ gain: 0.25
+ original_oblivious_leaves {
+ vector {
+ value: -1.5
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: 1.5
+ }
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -2.0
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -1.0
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 1.5
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 1.5
+ }
+ }
+ }"""
+ self.assertProtoEquals(expected_tree, output.trees[0])
+
def testTrainFnChiefSparseAndDense(self):
"""Tests the train function with sparse and dense features."""
with self.test_session() as sess:
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py
index e92f0bb841..150d734db6 100644
--- a/tensorflow/contrib/checkpoint/__init__.py
+++ b/tensorflow/contrib/checkpoint/__init__.py
@@ -34,6 +34,9 @@ Checkpointable data structures:
Checkpoint management:
@@CheckpointManager
+
+Saving and restoring Python state:
+@@NumpyState
"""
from __future__ import absolute_import
@@ -41,6 +44,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker
+from tensorflow.contrib.checkpoint.python.python_state import NumpyState
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD
index 7b200a29bf..ada4168726 100644
--- a/tensorflow/contrib/checkpoint/python/BUILD
+++ b/tensorflow/contrib/checkpoint/python/BUILD
@@ -9,6 +9,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":containers",
+ ":python_state",
":split_dependency",
":visualize",
"//tensorflow/python/training/checkpointable:data_structures",
@@ -41,6 +42,33 @@ py_test(
)
py_library(
+ name = "python_state",
+ srcs = ["python_state.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/python/training/checkpointable:base",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "python_state_test",
+ srcs = ["python_state_test.py"],
+ deps = [
+ ":python_state",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:session",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python/training/checkpointable:util",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
name = "split_dependency",
srcs = ["split_dependency.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/checkpoint/python/python_state.py b/tensorflow/contrib/checkpoint/python/python_state.py
new file mode 100644
index 0000000000..9b11035b6d
--- /dev/null
+++ b/tensorflow/contrib/checkpoint/python/python_state.py
@@ -0,0 +1,166 @@
+"""Utilities for including Python state in TensorFlow checkpoints."""
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import numpy
+
+from tensorflow.python.training.checkpointable import base
+
+# pylint: disable=g-import-not-at-top
+try:
+ # In Python 2.x, use the faster string buffering option.
+ from cStringIO import StringIO as BytesIO
+except ImportError:
+ from io import BytesIO
+# pylint: enable=g-import-not-at-top
+
+
+class NumpyState(base.CheckpointableBase):
+ """A checkpointable object whose NumPy array attributes are saved/restored.
+
+ Example usage:
+
+ ```python
+ arrays = tf.contrib.checkpoint.NumpyState()
+ checkpoint = tf.train.Checkpoint(numpy_arrays=arrays)
+ arrays.x = numpy.zeros([3, 4])
+ save_path = checkpoint.save("/tmp/ckpt")
+ arrays.x[1, 1] = 4.
+ checkpoint.restore(save_path)
+ assert (arrays.x == numpy.zeros([3, 4])).all()
+
+ second_checkpoint = tf.train.Checkpoint(
+ numpy_arrays=tf.contrib.checkpoint.NumpyState())
+ # Attributes of NumpyState objects are created automatically by restore()
+ second_checkpoint.restore(save_path)
+ assert (second_checkpoint.numpy_arrays.x == numpy.zeros([3, 4])).all()
+ ```
+
+ Note that `NumpyState` objects re-create the attributes of the previously
+ saved object on `restore()`. This is in contrast to TensorFlow variables, for
+ which a `Variable` object must be created and assigned to an attribute.
+
+ This snippet works both when graph building and when executing eagerly. On
+ save, the NumPy array(s) are fed as strings to be saved in the checkpoint (via
+ a placeholder when graph building, or as a string constant when executing
+ eagerly). When restoring they skip the TensorFlow graph entirely, and so no
+ restore ops need be run. This means that restoration always happens eagerly,
+ rather than waiting for `checkpoint.restore(...).run_restore_ops()` like
+ TensorFlow variables when graph building.
+ """
+
+ def _lookup_dependency(self, name):
+ """Create placeholder NumPy arrays for to-be-restored attributes.
+
+ Typically `_lookup_dependency` is used to check by name whether a dependency
+ exists. We cheat slightly by creating a checkpointable object for `name` if
+ we don't already have one, giving us attribute re-creation behavior when
+ loading a checkpoint.
+
+ Args:
+ name: The name of the dependency being checked.
+ Returns:
+ An existing dependency if one exists, or a new `_NumpyWrapper` placeholder
+ dependency (which will generally be restored immediately).
+ """
+ value = super(NumpyState, self)._lookup_dependency(name)
+ if value is None:
+ value = _NumpyWrapper(numpy.array([]))
+ new_reference = base.CheckpointableReference(name=name, ref=value)
+ self._unconditional_checkpoint_dependencies.append(new_reference)
+ self._unconditional_dependency_names[name] = value
+ super(NumpyState, self).__setattr__(name, value)
+ return value
+
+ def __getattribute__(self, name):
+ """Un-wrap `_NumpyWrapper` objects when accessing attributes."""
+ value = super(NumpyState, self).__getattribute__(name)
+ if isinstance(value, _NumpyWrapper):
+ return value.array
+ return value
+
+ def __setattr__(self, name, value):
+ """Automatically wrap NumPy arrays assigned to attributes."""
+ # TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making
+ # ndarrays checkpointable natively and using standard checkpointable list
+ # tracking.
+ if isinstance(value, numpy.ndarray):
+ try:
+ existing = super(NumpyState, self).__getattribute__(name)
+ existing.array = value
+ return
+ except AttributeError:
+ value = _NumpyWrapper(value)
+ self._track_checkpointable(value, name=name, overwrite=True)
+ elif (name not in ("_setattr_tracking", "_update_uid")
+ and getattr(self, "_setattr_tracking", True)):
+ # Mixing restore()-created attributes with user-added checkpointable
+ # objects is tricky, since we can't use the `_lookup_dependency` trick to
+ # re-create attributes (we might accidentally steal the restoration for
+ # another checkpointable object). For now `NumpyState` objects must be
+ # leaf nodes. Theoretically we could add some extra arguments to
+ # `_lookup_dependency` to figure out whether we should create a NumPy
+ # array for the attribute or not.
+ raise NotImplementedError(
+ ("Assigned %s to the %s property of %s, which is not a NumPy array. "
+ "Currently mixing NumPy arrays and other checkpointable objects is "
+ "not supported. File a feature request if this limitation bothers "
+ "you.")
+ % (value, name, self))
+ super(NumpyState, self).__setattr__(name, value)
+
+
+class _NumpyWrapper(base.CheckpointableBase):
+ """Wraps a NumPy array for storage in an object-based checkpoint."""
+
+ def __init__(self, array):
+ """Specify a NumPy array to wrap.
+
+ Args:
+ array: The NumPy array to save and restore (may be overwritten).
+ """
+ self.array = array
+
+ def _serialize(self):
+ """Callback for `PythonStringStateSaveable` to serialize the array."""
+ string_file = BytesIO()
+ try:
+ numpy.save(string_file, self.array, allow_pickle=False)
+ serialized = string_file.getvalue()
+ finally:
+ string_file.close()
+ return serialized
+
+ def _deserialize(self, string_value):
+ """Callback for `PythonStringStateSaveable` to deserialize the array."""
+ string_file = BytesIO(string_value)
+ try:
+ self.array = numpy.load(string_file, allow_pickle=False)
+ finally:
+ string_file.close()
+
+ def _gather_saveables_for_checkpoint(self):
+ """Specify callbacks for saving and restoring `array`."""
+ return {
+ "array": functools.partial(
+ base.PythonStringStateSaveable,
+ state_callback=self._serialize,
+ restore_callback=self._deserialize)
+ }
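As a standalone sketch of the serialization scheme these callbacks implement: the array round-trips through `numpy.save`/`numpy.load` over an in-memory buffer, with pickling disabled so only raw array data is stored.

```python
import numpy
from io import BytesIO

original = numpy.arange(6.0).reshape(2, 3)

# _serialize: write the array to an in-memory buffer and keep the bytes.
buffer = BytesIO()
numpy.save(buffer, original, allow_pickle=False)
serialized = buffer.getvalue()  # this byte string lands in the checkpoint

# _deserialize: rebuild the array from the saved bytes.
restored = numpy.load(BytesIO(serialized), allow_pickle=False)
assert (restored == original).all()
```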
diff --git a/tensorflow/contrib/checkpoint/python/python_state_test.py b/tensorflow/contrib/checkpoint/python/python_state_test.py
new file mode 100644
index 0000000000..0439a4755e
--- /dev/null
+++ b/tensorflow/contrib/checkpoint/python/python_state_test.py
@@ -0,0 +1,101 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy
+
+from tensorflow.contrib.checkpoint.python import python_state
+from tensorflow.python.client import session
+from tensorflow.python.eager import test
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import variables
+from tensorflow.python.training.checkpointable import util
+
+
+class NumpyStateTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testSaveRestoreNumpyState(self):
+ directory = self.get_temp_dir()
+ prefix = os.path.join(directory, "ckpt")
+ save_state = python_state.NumpyState()
+ saver = util.Checkpoint(numpy=save_state)
+ save_state.a = numpy.ones([2, 2])
+ save_state.b = numpy.ones([2, 2])
+ save_state.b = numpy.zeros([2, 2])
+ self.assertAllEqual(numpy.ones([2, 2]), save_state.a)
+ self.assertAllEqual(numpy.zeros([2, 2]), save_state.b)
+ first_save_path = saver.save(prefix)
+ save_state.a[1, 1] = 2.
+ second_save_path = saver.save(prefix)
+
+ load_state = python_state.NumpyState()
+ loader = util.Checkpoint(numpy=load_state)
+ loader.restore(first_save_path).initialize_or_restore()
+ self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
+ self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
+ load_state.a[0, 0] = 42.
+ self.assertAllEqual([[42., 1.], [1., 1.]], load_state.a)
+ loader.restore(first_save_path).run_restore_ops()
+ self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
+ loader.restore(second_save_path).run_restore_ops()
+ self.assertAllEqual([[1., 1.], [1., 2.]], load_state.a)
+ self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
+
+ def testNoGraphPollution(self):
+ graph = ops.Graph()
+ with graph.as_default(), session.Session():
+ directory = self.get_temp_dir()
+ prefix = os.path.join(directory, "ckpt")
+ save_state = python_state.NumpyState()
+ saver = util.Checkpoint(numpy=save_state)
+ save_state.a = numpy.ones([2, 2])
+ save_path = saver.save(prefix)
+ saver.restore(save_path)
+ graph.finalize()
+ saver.save(prefix)
+ save_state.a = numpy.zeros([2, 2])
+ saver.save(prefix)
+ saver.restore(save_path)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNoMixedNumpyStateTF(self):
+ save_state = python_state.NumpyState()
+ save_state.a = numpy.ones([2, 2])
+ with self.assertRaises(NotImplementedError):
+ save_state.v = variables.Variable(1.)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDocstringExample(self):
+ arrays = python_state.NumpyState()
+ checkpoint = util.Checkpoint(numpy_arrays=arrays)
+ arrays.x = numpy.zeros([3, 4])
+ save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+ arrays.x[1, 1] = 4.
+ checkpoint.restore(save_path)
+ self.assertAllEqual(numpy.zeros([3, 4]), arrays.x)
+
+ second_checkpoint = util.Checkpoint(numpy_arrays=python_state.NumpyState())
+ second_checkpoint.restore(save_path)
+ self.assertAllEqual(numpy.zeros([3, 4]), second_checkpoint.numpy_arrays.x)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc
index 58fadffce3..e57a66b99f 100644
--- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc
+++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc
@@ -33,7 +33,7 @@ bool IsPartitionEmpty(const BigQueryTablePartition& partition) {
Status ParseJson(StringPiece json, Json::Value* result) {
Json::Reader reader;
- if (!reader.parse(json.ToString(), *result)) {
+ if (!reader.parse(string(json), *result)) {
return errors::Internal("Couldn't parse JSON response from BigQuery.");
}
return Status::OK();
diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h
index 1af43a3e10..f1fcaff73b 100644
--- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h
+++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_
-#define TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_
+#ifndef TENSORFLOW_CONTRIB_CLOUD_KERNELS_BIGQUERY_TABLE_ACCESSOR_H_
+#define TENSORFLOW_CONTRIB_CLOUD_KERNELS_BIGQUERY_TABLE_ACCESSOR_H_
#include <map>
#include <memory>
@@ -198,4 +198,4 @@ class BigQueryTableAccessor {
};
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_
+#endif // TENSORFLOW_CONTRIB_CLOUD_KERNELS_BIGQUERY_TABLE_ACCESSOR_H_
diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h
index fea6b15640..6f4d54ae4a 100644
--- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h
+++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_
-#define TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_
+#ifndef TENSORFLOW_CONTRIB_CLOUD_KERNELS_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_
+#define TENSORFLOW_CONTRIB_CLOUD_KERNELS_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_
#include <string>
@@ -401,4 +401,4 @@ const string kTestEmptyRow = R"({
} // namespace
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_
+#endif // TENSORFLOW_CONTRIB_CLOUD_KERNELS_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_
diff --git a/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py
index 493b3c6f1b..11e177cd0c 100644
--- a/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py
+++ b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py
@@ -197,7 +197,7 @@ class BigQueryReaderOpsTest(test.TestCase):
def _ReadAndCheckRowsUsingFeatures(self, num_rows):
self.server.handler.num_rows = num_rows
- with self.test_session() as sess:
+ with self.cached_session() as sess:
feature_configs = {
"int64_col":
parsing_ops.FixedLenFeature(
@@ -254,7 +254,7 @@ class BigQueryReaderOpsTest(test.TestCase):
num_rows = 10
self.server.handler.num_rows = num_rows
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = cloud.BigQueryReader(
project_id=_PROJECT,
dataset_id=_DATASET,
diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py
index 9b6c056d6c..4f2ecbcb17 100644
--- a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py
+++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py
@@ -26,7 +26,7 @@ class GcsConfigOpsTest(test.TestCase):
def testSetBlockCache(self):
cfg = gcs_config_ops.BlockCacheParams(max_bytes=1024*1024*1024)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gcs_config_ops.configure_gcs(sess, block_cache=cfg)
def testConfigureGcsHook(self):
@@ -36,7 +36,7 @@ class GcsConfigOpsTest(test.TestCase):
'type': 'authorized_user'}
hook = gcs_config_ops.ConfigureGcsHook(credentials=creds)
hook.begin()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run = lambda _, feed_dict=None, options=None, run_metadata=None: None
hook.after_create_session(sess, None)
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index 1ab150d74a..1056894f18 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -229,6 +229,10 @@ class TPUClusterResolver(ClusterResolver):
def get_master(self):
return self.master()
+ def get_job_name(self):
+ if self._shouldResolve():
+ return self._job_name
+
def cluster_spec(self):
"""Returns a ClusterSpec object based on the latest TPU information.
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt
index f6c928e2be..ebcabb4223 100644
--- a/tensorflow/contrib/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/CMakeLists.txt
@@ -364,7 +364,7 @@ if (tensorflow_ENABLE_MKL_SUPPORT)
list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn_copy_shared_to_destination)
include_directories(${mkldnn_INCLUDE_DIRS})
else (tensorflow_ENABLE_MKLDNN_SUPPORT)
- add_definitions(-DINTEL_MKL_ML)
+ add_definitions(-DINTEL_MKL_ML_ONLY)
endif()
endif (tensorflow_ENABLE_MKL_SUPPORT)
diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake
index 1d638e6402..479609458c 100644
--- a/tensorflow/contrib/cmake/external/nsync.cmake
+++ b/tensorflow/contrib/cmake/external/nsync.cmake
@@ -16,16 +16,16 @@ include (ExternalProject)
set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public)
set(nsync_URL https://github.com/google/nsync)
-set(nsync_TAG 1.20.0)
+set(nsync_TAG 1.20.1)
set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync)
set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install)
if(WIN32)
set(nsync_HEADERS "${nsync_BUILD}/public/*.h")
- set(nsync_STATIC_LIBRARIES ${nsync_INSTALL}/lib/nsync.lib)
+ set(nsync_STATIC_LIBRARIES ${nsync_INSTALL}/lib/nsync_cpp.lib)
else()
set(nsync_HEADERS "${nsync_BUILD}/public/*.h")
- set(nsync_STATIC_LIBRARIES ${nsync_INSTALL}/lib/libnsync.a)
+ set(nsync_STATIC_LIBRARIES ${nsync_INSTALL}/lib/libnsync_cpp.a)
endif()
ExternalProject_Add(nsync
@@ -35,12 +35,12 @@ ExternalProject_Add(nsync
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
BUILD_IN_SOURCE 1
BUILD_BYPRODUCTS ${nsync_STATIC_LIBRARIES}
- PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/nsync/CMakeLists.txt ${nsync_BUILD}
INSTALL_DIR ${nsync_INSTALL}
CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=Release
-DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
-DCMAKE_INSTALL_PREFIX:STRING=${nsync_INSTALL}
+ -DCMAKE_INSTALL_LIBDIR:STRING=lib
-DNSYNC_LANGUAGE:STRING=c++11)
set(nsync_HEADERS
diff --git a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt
deleted file mode 100644
index 6f059c7225..0000000000
--- a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt
+++ /dev/null
@@ -1,325 +0,0 @@
-cmake_minimum_required (VERSION 2.8.12)
-
-# nsync provides portable synchronization primitives, such as mutexes and
-# condition variables.
-project (nsync)
-
-# Set variable NSYNC_LANGUAGE to "c++11" to build with C++11
-# rather than C.
-
-# Some builds need position-independent code.
-set (CMAKE_POSITION_INDEPENDENT_CODE ON)
-
-# -----------------------------------------------------------------
-# Platform dependencies
-
-# Many platforms use these posix related sources; even Win32.
-set (NSYNC_POSIX_SRC
- "platform/posix/src/nsync_panic.c"
- "platform/posix/src/per_thread_waiter.c"
- "platform/posix/src/time_rep.c"
- "platform/posix/src/yield.c"
-)
-
-if (WIN32)
- # Suppress warnings to reduce build log size.
- add_definitions(/wd4267 /wd4244 /wd4800 /wd4503 /wd4554 /wd4996 /wd4348 /wd4018)
- add_definitions(/wd4099 /wd4146 /wd4267 /wd4305 /wd4307)
- add_definitions(/wd4715 /wd4722 /wd4723 /wd4838 /wd4309 /wd4334)
- add_definitions(/wd4003 /wd4244 /wd4267 /wd4503 /wd4506 /wd4800 /wd4996)
- add_definitions(/wd8029)
-endif()
-
-# Many of the string matches below use a literal "X" suffix on both sides.
-# This is because some versions of cmake treat (for example) "MSVC" (in quotes)
-# as a reference to the variable MSVC, thus the expression
-# "${CMAKE_C_COMPILER_ID}" STREQUAL "MSVC"
-# is false when ${CMAKE_C_COMPILER_ID} has the value "MSVC"! See
-# https://cmake.org/cmake/help/v3.1/policy/CMP0054.html
-
-# Pick the include directory for the operating system.
-if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/c++11")
- add_definitions ("-DNSYNC_USE_CPP11_TIMEPOINT -DNSYNC_ATOMIC_CPP11")
- set (NSYNC_OS_CPP_SRC
- "platform/c++11/src/per_thread_waiter.cc"
- "platform/c++11/src/yield.cc"
- "platform/c++11/src/time_rep_timespec.cc"
- "platform/c++11/src/nsync_panic.cc"
- )
- if ("${CMAKE_SYSTEM_NAME}X" STREQUAL "WindowsX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/win32")
- add_compile_options ("/TP")
- set (NSYNC_OS_SRC
- "platform/c++11/src/nsync_semaphore_mutex.cc"
- "platform/win32/src/clock_gettime.c"
- "platform/win32/src/pthread_key_win32.cc"
- ${NSYNC_OS_CPP_SRC}
- )
- set (NSYNC_TEST_OS_SRC
- "platform/win32/src/start_thread.c"
- )
- elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "DarwinX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/macos")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/posix")
- # Some versions of MacOS, such as Sierra, require _DARWIN_C_SOURCE
- # when including certin C++ standard header files, such as <mutex>.
- add_definitions ("-D_DARWIN_C_SOURCE")
- add_compile_options ("-std=c++11")
- set (NSYNC_OS_SRC
- ${NSYNC_OS_CPP_SRC}
- "platform/c++11/src/nsync_semaphore_mutex.cc"
- "platform/posix/src/clock_gettime.c"
- "platform/posix/src/nsync_semaphore_mutex.c"
- )
- set (NSYNC_TEST_OS_SRC
- "platform/posix/src/start_thread.c"
- )
- elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "LinuxX")
- include_directories (BEFORE "${PROJECT_SOURCE_DIR}/platform/c++11.futex")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/posix")
- add_compile_options ("-std=c++11")
- set (NSYNC_OS_SRC
- "platform/linux/src/nsync_semaphore_futex.c"
- ${NSYNC_OS_CPP_SRC}
- )
- set (NSYNC_TEST_OS_SRC
- "platform/posix/src/start_thread.c"
- )
- elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "NetBSDX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/posix")
- add_compile_options ("-std=c++11")
- set (NSYNC_OS_SRC
- "platform/c++11/src/nsync_semaphore_mutex.cc"
- ${NSYNC_OS_CPP_SRC}
- )
- set (NSYNC_TEST_OS_SRC
- "platform/posix/src/start_thread.c"
- )
- elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "FreeBSDX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/posix")
- add_compile_options ("-std=c++11")
- set (NSYNC_OS_SRC
- "platform/c++11/src/nsync_semaphore_mutex.cc"
- ${NSYNC_OS_CPP_SRC}
- )
- set (NSYNC_TEST_OS_SRC
- "platform/posix/src/start_thread.c"
- )
- elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "OpenBSDX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/posix")
- add_compile_options ("-std=c++11")
- set (NSYNC_OS_SRC
- "platform/c++11/src/nsync_semaphore_mutex.cc"
- ${NSYNC_OS_CPP_SRC}
- )
- set (NSYNC_TEST_OS_SRC
- "platform/posix/src/start_thread.c"
- )
- endif ()
-endif ()
-
-# Pick the include directory for the compiler.
-if ("${CMAKE_C_COMPILER_ID}X" STREQUAL "GNUX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/gcc")
- set (THREADS_HAVE_PTHREAD_ARG ON)
-elseif ("${CMAKE_C_COMPILER_ID}X" STREQUAL "ClangX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/clang")
- set (THREADS_HAVE_PTHREAD_ARG ON)
-elseif ("${CMAKE_C_COMPILER_ID}X" STREQUAL "MSVCX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/msvc")
-else ()
- message (WARNING "CMAKE_C_COMPILER_ID (${CMAKE_C_COMPILER_ID}) matched NOTHING")
-endif ()
-
-if (NOT "${NSYNC_LANGUAGE}X" STREQUAL "c++11X")
- if ("${CMAKE_SYSTEM_NAME}X" STREQUAL "WindowsX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/win32")
- set (NSYNC_OS_SRC
- ${NSYNC_POSIX_SRC}
- "platform/win32/src/clock_gettime.c"
- "platform/win32/src/init_callback_win32.c"
- "platform/win32/src/nanosleep.c"
- "platform/win32/src/nsync_semaphore_win32.c"
- "platform/win32/src/pthread_cond_timedwait_win32.c"
- "platform/win32/src/pthread_key_win32.cc"
- )
- set (NSYNC_TEST_OS_SRC
- "platform/win32/src/start_thread.c"
- )
- elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "DarwinX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/macos")
- set (NSYNC_POSIX ON)
- set (NSYNC_OS_EXTRA_SRC
- "platform/posix/src/clock_gettime.c"
- "platform/posix/src/nsync_semaphore_mutex.c"
- )
- include_directories ("${PROJECT_SOURCE_DIR}/platform/posix")
- elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "LinuxX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/linux")
- set (NSYNC_POSIX ON)
- set (NSYNC_OS_EXTRA_SRC
- "platform/linux/src/nsync_semaphore_futex.c"
- )
- elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "NetBSDX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/netbsd")
- set (NSYNC_POSIX ON)
- set (NSYNC_OS_EXTRA_SRC
- "platform/posix/src/nsync_semaphore_mutex.c"
- )
- elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "FreeBSDX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/freebsd")
- set (NSYNC_POSIX ON)
- set (NSYNC_OS_EXTRA_SRC
- "platform/posix/src/nsync_semaphore_mutex.c"
- )
- elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "OpenBSDX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/openbsd")
- set (NSYNC_POSIX ON)
- set (NSYNC_OS_EXTRA_SRC
- "platform/posix/src/nsync_semaphore_mutex.c"
- )
- endif ()
-endif ()
-
-if (NSYNC_POSIX)
- include_directories ("${PROJECT_SOURCE_DIR}/platform/posix")
- set (NSYNC_OS_SRC
- ${NSYNC_POSIX_SRC}
- ${NSYNC_OS_EXTRA_SRC}
- )
- set (NSYNC_TEST_OS_SRC
- "platform/posix/src/start_thread.c"
- )
-endif ()
-
-# Pick the include directory for the architecture.
-if (("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "x86_64X") OR
- ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "amd64X") OR
- ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "AMD64X"))
- include_directories ("${PROJECT_SOURCE_DIR}/platform/x86_64")
-elseif (("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "x86_32X") OR
- ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "i386X") OR
- ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "i686X"))
- include_directories ("${PROJECT_SOURCE_DIR}/platform/x86_32")
-elseif (("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "armv6lX") OR
- ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "armv7lX") OR
- ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "armX"))
- include_directories ("${PROJECT_SOURCE_DIR}/platform/arm")
-elseif (("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "aarch64X") OR
- ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "arm64X"))
- include_directories ("${PROJECT_SOURCE_DIR}/platform/aarch64")
-elseif (("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "ppcX") OR
- ("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "ppc32X"))
- include_directories ("${PROJECT_SOURCE_DIR}/platform/ppc32")
-elseif (("${CMAKE_SYSTEM_PROCESSOR}X" STREQUAL "ppc64X"))
- include_directories ("${PROJECT_SOURCE_DIR}/platform/ppc64")
-endif ()
-
-# Windows uses some include files from the posix directory also.
-if ("${CMAKE_SYSTEM_NAME}X" STREQUAL "WindowsX")
- include_directories ("${PROJECT_SOURCE_DIR}/platform/posix")
-endif ()
-
-# -----------------------------------------------------------------
-
-include_directories ("${PROJECT_SOURCE_DIR}/public")
-include_directories ("${PROJECT_SOURCE_DIR}/internal")
-
-set (NSYNC_SRC
- "internal/common.c"
- "internal/counter.c"
- "internal/cv.c"
- "internal/debug.c"
- "internal/dll.c"
- "internal/mu.c"
- "internal/mu_wait.c"
- "internal/note.c"
- "internal/once.c"
- "internal/sem_wait.c"
- "internal/time_internal.c"
- "internal/wait.c"
- ${NSYNC_OS_SRC}
-)
-add_library (nsync ${NSYNC_SRC})
-
-set (NSYNC_TEST_SRC
- "testing/array.c"
- "testing/atm_log.c"
- "testing/closure.c"
- "testing/smprintf.c"
- "testing/testing.c"
- "testing/time_extra.c"
- ${NSYNC_TEST_OS_SRC}
-)
-add_library (nsync_test ${NSYNC_TEST_SRC})
-
-set (NSYNC_TESTS
- "counter_test"
- "cv_mu_timeout_stress_test"
- "cv_test"
- "cv_wait_example_test"
- "dll_test"
- "mu_starvation_test"
- "mu_test"
- "mu_wait_example_test"
- "mu_wait_test"
- "note_test"
- "once_test"
- "pingpong_test"
- "wait_test"
-)
-
-if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X")
- foreach (s IN ITEMS ${NSYNC_SRC} ${NSYNC_TEST_SRC})
- SET_SOURCE_FILES_PROPERTIES ("${s}" PROPERTIES LANGUAGE CXX)
- endforeach (s)
- foreach (t IN ITEMS ${NSYNC_TESTS})
- SET_SOURCE_FILES_PROPERTIES ("testing/${t}.c" PROPERTIES LANGUAGE CXX)
- endforeach (t)
-endif ()
-
-enable_testing ()
-foreach (t IN ITEMS ${NSYNC_TESTS})
- add_executable (${t} "testing/${t}.c")
-endforeach (t)
-
-find_package (Threads REQUIRED)
-set (THREADS_PREFER_PTHREAD_FLAG ON)
-foreach (t IN ITEMS "nsync" "nsync_test" ${NSYNC_TESTS})
- if (THREADS_HAVE_PTHREAD_ARG)
- target_compile_options (${t} PUBLIC "-pthread")
- endif ()
- if (CMAKE_THREAD_LIBS_INIT)
- target_link_libraries (${t} "${CMAKE_THREAD_LIBS_INIT}")
- endif ()
-endforeach (t)
-
-foreach (t IN ITEMS ${NSYNC_TESTS})
- target_link_libraries (${t} nsync_test nsync)
- add_test (NAME ${t} COMMAND ${t})
-endforeach (t)
-
-install (TARGETS nsync
- LIBRARY DESTINATION lib COMPONENT RuntimeLibraries
- ARCHIVE DESTINATION lib COMPONENT Development)
-
-set (NSYNC_INCLUDES
- "public/nsync.h"
- "public/nsync_atomic.h"
- "public/nsync_counter.h"
- "public/nsync_cpp.h"
- "public/nsync_cv.h"
- "public/nsync_debug.h"
- "public/nsync_mu.h"
- "public/nsync_mu_wait.h"
- "public/nsync_note.h"
- "public/nsync_once.h"
- "public/nsync_time.h"
- "public/nsync_time_internal.h"
- "public/nsync_waiter.h"
-)
-
-foreach (NSYNC_INCLUDE ${NSYNC_INCLUDES})
- install (FILES ${NSYNC_INCLUDE} DESTINATION include COMPONENT Development)
-endforeach ()
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index a5a947f726..fb871acae9 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -4,6 +4,8 @@ tensorflow
tensorflow/core
tensorflow/core/example
tensorflow/core/framework
+tensorflow/core/kernels
+tensorflow/core/kernels/boosted_trees
tensorflow/core/lib
tensorflow/core/lib/core
tensorflow/core/profiler
@@ -245,10 +247,6 @@ tensorflow/contrib/kernel_methods/python
tensorflow/contrib/kernel_methods/python/mappers
tensorflow/contrib/kinesis/python
tensorflow/contrib/kinesis/python/ops
-tensorflow/contrib/kfac
-tensorflow/contrib/kfac/examples
-tensorflow/contrib/kfac/python
-tensorflow/contrib/kfac/python/ops
tensorflow/contrib/labeled_tensor
tensorflow/contrib/labeled_tensor/python
tensorflow/contrib/labeled_tensor/python/ops
diff --git a/tensorflow/contrib/coder/BUILD b/tensorflow/contrib/coder/BUILD
index 855c824ead..4bfd753bb1 100644
--- a/tensorflow/contrib/coder/BUILD
+++ b/tensorflow/contrib/coder/BUILD
@@ -3,6 +3,7 @@
package(default_visibility = [
"//learning/brain:__subpackages__",
+ "//research/vision/piedpiper:__subpackages__",
"//tensorflow:__subpackages__",
])
diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD
index bcee0b04c8..d7583be6d8 100644
--- a/tensorflow/contrib/compiler/BUILD
+++ b/tensorflow/contrib/compiler/BUILD
@@ -8,6 +8,7 @@ package_group(
packages = ["//tensorflow/..."],
)
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
py_library(
@@ -46,3 +47,36 @@ cuda_py_test(
],
xla_enabled = True,
)
+
+py_library(
+ name = "xla",
+ srcs = ["xla.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:util",
+ "//tensorflow/python/estimator:model_fn",
+ ],
+)
+
+tf_py_test(
+ name = "xla_test",
+ srcs = ["xla_test.py"],
+ additional_deps = [
+ ":xla",
+ "@six_archive//:six",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:control_flow_util",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ ],
+ tags = ["no_pip"],
+)
diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py
index a56a01b163..42b3b9f026 100644
--- a/tensorflow/contrib/compiler/jit_test.py
+++ b/tensorflow/contrib/compiler/jit_test.py
@@ -48,7 +48,7 @@ class JITTest(test.TestCase):
def compute(self, use_jit, compute_fn):
random_seed.set_random_seed(1234)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
with jit.experimental_jit_scope(use_jit):
r = compute_fn()
sess.run(variables.global_variables_initializer())
@@ -88,7 +88,7 @@ class JITTest(test.TestCase):
self.assertAllClose(v_false_1, v_true_1)
def testJITXlaScope(self):
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
with jit.experimental_jit_scope(True):
# XlaScope 0
a1 = constant_op.constant(1)
@@ -138,7 +138,8 @@ class JITTest(test.TestCase):
self.assertAllClose(v_false_1, v_true_1)
def testDefunNoJitScope(self):
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
+
@function.Defun(compiled=True, noinline=True)
def mulop(x1, x2):
return x1 * x2
@@ -153,7 +154,7 @@ class JITTest(test.TestCase):
self.assertEqual(b"function_mulop", func_attrs["_XlaScope"].s)
def testDefunInheritsJitScope(self):
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
with jit.experimental_jit_scope(True):
@function.Defun(compiled=True, noinline=True)
def mulop(x1, x2):
@@ -195,7 +196,7 @@ class CompilationEnabledInGradientTest(test.TestCase):
self.assertAllClose([[108]], x_grads.eval())
def testCompilationGradientScopeNames(self):
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
with jit.experimental_jit_scope():
# XlaScope 0
a1 = constant_op.constant([[1.]])
@@ -217,7 +218,7 @@ class CompilationEnabledInGradientTest(test.TestCase):
self.assertEqual(b"jit_scope_1", grad_a2.op.get_attr("_XlaScope"))
def testCompilationSeparateGradientScopeNames(self):
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
with jit.experimental_jit_scope(True, separate_compiled_gradients=True):
# XlaScope 0
a1 = constant_op.constant([[1.]])
@@ -241,7 +242,7 @@ class CompilationEnabledInGradientTest(test.TestCase):
grad_a2.op.get_attr("_XlaScope"))
def testPlaysNicelyWithDefun(self):
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
with jit.experimental_jit_scope(True):
@function.Defun(compiled=True, noinline=True)
def mulop(x1, x2):
@@ -266,7 +267,7 @@ class CompilationEnabledInGradientTest(test.TestCase):
self.assertAllClose([1.0, 1.0, 2.0], sess.run([x, r, g_r]))
def testPlaysNicelyWithDefunSeparateGradientScope(self):
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
with jit.experimental_jit_scope(True):
@function.Defun(
diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py
new file mode 100644
index 0000000000..60f5af1662
--- /dev/null
+++ b/tensorflow/contrib/compiler/xla.py
@@ -0,0 +1,208 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""xla provides experimental xla support API."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import compat
+
+_XLA_COMPILE_ATTR = '_xla_compile_id'
+_MAX_WARNING_LINES = 5
+
+# Operations that indicate some error in the user's graph. For example, an XLA
+# computation should not have any Placeholder op.
+_BLACKLISTED_OPS = set([
+ 'Placeholder',
+])
+
+# XLA doesn't currently support reading of intermediate tensors, thus some ops
+# are not supported.
+_UNSUPPORTED_OPS = set([
+ 'AudioSummary',
+ 'AudioSummaryV2',
+ 'HistogramSummary',
+ 'ImageSummary',
+ 'MergeSummary',
+ 'Print',
+ 'ScalarSummary',
+ 'TensorSummary',
+ 'TensorSummaryV2',
+])
+
+
+class XLACompileContext(control_flow_ops.XLAControlFlowContext):
+ """A `ControlFlowContext` for nodes inside an XLA computation cluster.
+
+  THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION; DO NOT USE DIRECTLY.
+
+  The primary role of `XLACompileContext` is to mark operators inside an
+  xla.compile() computation with the attribute "_xla_compile_id=XYZ", where XYZ
+  is a unique name.
+
+  `ControlFlowContext` is used to perform the annotation since it integrates
+  with TensorFlow constructs like ResourceVariables. For example, if a
+  `ResourceVariable` is constructed inside an xla.compile() block, the
+ `ResourceVariable` implementation can use
+ `with ops.control_dependencies(None)` to build the variable's definition
+ outside the compiled computation.
+ """
+
+ def __init__(self, name, pivot):
+ """Builds a new XLACompileContext.
+
+ Args:
+ name: a unique name for the context, used to populate the
+ `_xla_compile_id` attribute.
+ pivot: a pivot node. Nodes in the XLACompileContext that do not have any
+ inputs will have a control dependency on the pivot node. This ensures
+ that nodes are correctly included in any enclosing control flow
+ contexts.
+ """
+ super(XLACompileContext, self).__init__()
+ self._name = name
+ self._name_as_bytes = compat.as_bytes(name)
+ self._unsupported_ops = []
+ self._pivot = pivot
+
+ def report_unsupported_operations(self):
+ if self._unsupported_ops:
+ op_str = '\n'.join([
+ ' %s (%s)' % (op.type, op.name)
+ for op in self._unsupported_ops[:_MAX_WARNING_LINES]
+ ])
+ logging.warning('%d unsupported operations found: \n%s',
+ len(self._unsupported_ops), op_str)
+ if len(self._unsupported_ops) > _MAX_WARNING_LINES:
+ logging.warning('... and %d more',
+ len(self._unsupported_ops) - _MAX_WARNING_LINES)
+
+ def AddOp(self, op):
+    """Creates an op in XLACompileContext and notifies outer contexts recursively."""
+ # pylint: disable=protected-access
+ if op.type in _BLACKLISTED_OPS:
+ logging.error(
+ 'Operation of type %s (%s) is not supported in XLA. Execution will '
+ 'fail if this op is used in the graph. ', op.type, op.name)
+
+ # TODO(ycao): Automatically disable summaries instead of reporting them.
+ if op.type in _UNSUPPORTED_OPS:
+ self._unsupported_ops.append(op)
+
+ if any(x.dtype._is_ref_dtype for x in op.inputs):
+ raise NotImplementedError(
+ 'Non-resource Variables are not supported inside XLA computations '
+ '(operator name: %s)' % op.name)
+
+ if _XLA_COMPILE_ATTR in op.node_def.attr:
+ raise ValueError('XLA compiled computations cannot be nested, (operator '
+ 'name: %s)' % op.name)
+
+ op._set_attr(
+ _XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes))
+
+ op.graph.prevent_feeding(op)
+ op.graph.prevent_fetching(op)
+
+    # Remove any control edges from outer control flow contexts. These may
+    # cause mismatched frame errors, for example when one of the op's inputs
+    # is generated in a different While control flow context.
+ (internal_control_inputs,
+ external_control_inputs) = self._RemoveExternalControlEdges(op)
+
+ if not op.inputs:
+ # Add a control edge from the control pivot to this op.
+ if not internal_control_inputs:
+ # pylint: disable=protected-access
+ op._add_control_input(self._pivot)
+ # pylint: enable=protected-access
+ else:
+ for index in xrange(len(op.inputs)):
+ x = op.inputs[index]
+ real_x = self.AddValue(x)
+ if real_x != x:
+ op._update_input(index, real_x) # pylint: disable=protected-access
+
+ if external_control_inputs:
+ # Use an identity to pull control inputs as data inputs. Note that we
+ # ignore ops which don't have outputs. TODO(phawkins): fix that.
+ with ops.control_dependencies(None):
+ self.Enter()
+ external_control_inputs = [
+ array_ops.identity(x.outputs[0]).op
+ for x in external_control_inputs
+ if x.outputs
+ ]
+ self.Exit()
+ # pylint: disable=protected-access
+ op._add_control_inputs(external_control_inputs)
+ # pylint: enable=protected-access
+
+ # Mark op's outputs as seen by this context and any outer contexts.
+ output_names = [x.name for x in op.outputs]
+ context = self
+ while context is not None:
+ # pylint: disable=protected-access
+ context._values.update(output_names)
+ context = context._outer_context
+ # pylint: enable=protected-access
+
+ if self._outer_context:
+ self._outer_context.AddInnerOp(op)
+
+ def AddValue(self, val):
+ """Add `val` to the current context and its outer context recursively."""
+ if val.name in self._values:
+ # Use the real value if it comes from outer context.
+ result = self._external_values.get(val.name)
+ return val if result is None else result
+
+ result = val
+ self._values.add(val.name)
+ if self._outer_context:
+ result = self._outer_context.AddValue(val)
+ self._values.add(result.name)
+
+ self._external_values[val.name] = result
+
+ return result
+
+ def AddInnerOp(self, op):
+ self.AddOp(op)
+ if self._outer_context:
+ self._outer_context.AddInnerOp(op)
+
+ @property
+ def grad_state(self):
+    # Define the gradient loop state associated with the XLACompileContext to
+    # be None. The XLACompileContext does not get nested, and the grad_state
+    # outside it does not affect the graph inside, so the grad_state should
+    # behave as if this were the top-level gradient state.
+ return None
+
+ @property
+ def back_prop(self):
+ """Forwards to the enclosing while context, if any."""
+ if self.GetWhileContext():
+ return self.GetWhileContext().back_prop
+ return False
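A minimal, TensorFlow-free sketch of the two per-op checks `AddOp` performs (the op list below is invented for illustration): blacklisted types are logged as errors, while unsupported summary and `Print` types are collected and later reported by `report_unsupported_operations`.

```python
_BLACKLISTED = {'Placeholder'}
_UNSUPPORTED = {'Print', 'ScalarSummary', 'HistogramSummary'}

ops_in_cluster = ['Const', 'MatMul', 'Print', 'Placeholder']  # invented

unsupported = [t for t in ops_in_cluster if t in _UNSUPPORTED]
errors = [t for t in ops_in_cluster if t in _BLACKLISTED]
print(unsupported)  # ['Print']       -> warned about, collected for the report
print(errors)       # ['Placeholder'] -> logged as an error in AddOp
```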
diff --git a/tensorflow/contrib/compiler/xla_test.py b/tensorflow/contrib/compiler/xla_test.py
new file mode 100644
index 0000000000..a306b56f63
--- /dev/null
+++ b/tensorflow/contrib/compiler/xla_test.py
@@ -0,0 +1,180 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Tests for contrib.compiler.xla."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.compiler import xla
+from tensorflow.python import summary
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import summary_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import test
+
+
+class XLACompileContextTest(test.TestCase):
+
+ def create_test_xla_compile_context(self):
+ computation_name = ops.get_default_graph().unique_name('computation')
+ pivot = control_flow_ops.no_op(name=computation_name + '/pivot')
+ return xla.XLACompileContext(name=computation_name, pivot=pivot)
+
+ def test_report_unsupported_operations(self):
+ """Tests that unsupported operations are detected."""
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ dummy_tensor = constant_op.constant(1.1)
+ audio_summary = summary.audio('audio_summary', dummy_tensor, 0.5)
+ histogram_summary = summary.histogram('histogram_summary', dummy_tensor)
+ image_summary = summary.image('image_summary', dummy_tensor)
+ scalar_summary = summary.scalar('scalar_summary', dummy_tensor)
+ tensor_summary = summary_ops.tensor_summary('tensor_summary', dummy_tensor)
+ summary.merge(
+ [
+ audio_summary, histogram_summary, image_summary, scalar_summary,
+ tensor_summary
+ ],
+ name='merge_summary')
+ logging_ops.Print(dummy_tensor, [dummy_tensor], name='print_op')
+ context.Exit()
+
+ unsupported_ops_names = [op.name for op in context._unsupported_ops]
+ self.assertEqual(unsupported_ops_names, [
+ u'audio_summary', u'histogram_summary', u'image_summary',
+ u'scalar_summary', u'tensor_summary', u'merge_summary/merge_summary',
+ u'print_op'
+ ])
+
+ def test_resource_variable(self):
+ """Tests that resource variable usage is allowed."""
+ a = variable_scope.get_variable(
+ name='variable_a', shape=(1), use_resource=True)
+
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ state_ops.assign(a, a + 1)
+ context.Exit()
+
+ def test_non_resource_variable_error(self):
+ """Tests that non-resource variable usage is disallowed."""
+ a = variable_scope.get_variable(
+ name='variable_a', shape=(1), use_resource=False)
+
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ with self.assertRaisesRegexp(
+ NotImplementedError, 'Non-resource Variables are not supported inside '
+ r'XLA computations \(operator name: Assign\)'):
+ state_ops.assign(a, a + 1)
+ context.Exit()
+
+ def test_nested_xla_compile_error(self):
+ """Tests that nested XLA computation leads to fatal error."""
+ context1 = self.create_test_xla_compile_context()
+ context1.Enter()
+
+ context2 = self.create_test_xla_compile_context()
+ context2.Enter()
+ with self.assertRaisesRegexp(ValueError,
+ 'XLA compiled computations cannot be nested'):
+ constant_op.constant(1)
+ context2.Exit()
+ context1.Exit()
+
+ def test_xla_compile_attr(self):
+ """Tests that ops are tagged with XLA compile ID attribute."""
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ op = constant_op.constant(1)
+ context.Exit()
+ self.assertIn('_xla_compile_id', op.op.node_def.attr)
+
+ def test_op_without_input(self):
+ """Tests that ops without inputs depend on pivot correctly."""
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ op = constant_op.constant(1)
+ context.Exit()
+
+ self.assertIn(context._pivot, op.op.control_inputs)
+
+ def test_external_control_edges(self):
+ """Tests that external control edges are handled correctly."""
+ i = constant_op.constant(1)
+ op1 = constant_op.constant(1)
+
+ with ops.control_dependencies([op1]):
+ op2 = constant_op.constant(1)
+ self.assertIn(op1.op, op2.op.control_inputs)
+
+ def while_body(i):
+ del i # unused
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ with ops.control_dependencies([op1]):
+ op3 = constant_op.constant(1)
+ context.Exit()
+ self.assertNotIn(op1.op, op3.op.control_inputs)
+ return op3
+
+ control_flow_ops.while_loop(
+ cond=lambda i: math_ops.less(i, 10), body=while_body, loop_vars=[i])
+
+ def test_op_output_marked_as_seen(self):
+ """Tests that any op output is marked as seen in context."""
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ op = constant_op.constant(1)
+ context.Exit()
+
+ self.assertIn(op.name, context._values)
+
+ def testOpIsInContext(self):
+ """Tests that XLACompileContext is recognized as an XLA context."""
+ op1 = constant_op.constant(1)
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ op2 = constant_op.constant(2)
+ context.Exit()
+ self.assertFalse(control_flow_util.IsInXLAContext(op1.op))
+ self.assertTrue(control_flow_util.IsInXLAContext(op2.op))
+
+ def testOpPreventFeeding(self):
+ """Tests that ops created inside XLACompileContext can not be fed."""
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ op = constant_op.constant(1)
+ context.Exit()
+ self.assertFalse(op.graph.is_feedable(op.op))
+
+ def testOpPreventFetching(self):
+ """Tests that ops created inside XLACompileContext can not be fetched."""
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ op = constant_op.constant(1)
+ context.Exit()
+ self.assertFalse(op.graph.is_fetchable(op.op))
+
+
+if __name__ == '__main__':
+ test.main()
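
The hunks that follow are a mechanical migration of `test.TestCase` subclasses from `self.test_session()` to `self.cached_session()`. Both yield a session usable in a `with` block; `cached_session()` is the preferred spelling, and it caches one session per test so repeated calls reuse it. A minimal sketch of the pattern (the test body is illustrative):

    from tensorflow.python.framework import constant_op
    from tensorflow.python.platform import test


    class ExampleTest(test.TestCase):

      def test_runs_in_cached_session(self):
        x = constant_op.constant([1.0, 2.0])
        # cached_session() returns a session that is cached and reused
        # across calls inside the same test method.
        with self.cached_session() as sess:
          self.assertAllEqual(sess.run(x), [1.0, 2.0])


    if __name__ == '__main__':
      test.main()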
diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py
index 9b4bf62710..3e25079e02 100644
--- a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py
+++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py
@@ -75,7 +75,7 @@ class ExternalRegretOptimizerTest(test.TestCase):
multipliers3 = standard_ops.constant([0.4, 0.7, -0.2, 0.5, 0.1])
expected_projected_multipliers3 = np.array([0.2, 0.5, 0.0, 0.3, 0.0])
- with self.test_session() as session:
+ with self.cached_session() as session:
projected_multipliers1 = session.run(
external_regret_optimizer._project_multipliers_wrt_euclidean_norm(
multipliers1, 1.0))
@@ -122,7 +122,7 @@ class ExternalRegretOptimizerTest(test.TestCase):
]
multipliers = []
- with self.test_session() as session:
+ with self.cached_session() as session:
session.run(standard_ops.global_variables_initializer())
while len(multipliers) < len(expected_multipliers):
multipliers.append(session.run(optimizer.lagrange_multipliers))
diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py
index 34c4543dca..df0eced631 100644
--- a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py
+++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py
@@ -97,7 +97,7 @@ class SwapRegretOptimizerTest(test.TestCase):
matrix1 = np.matrix([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9], [0.4, 0.3, 0.0]])
matrix2 = np.matrix([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5], [0.4, 0.5, 0.3]])
- with self.test_session() as session:
+ with self.cached_session() as session:
eigenvector1 = session.run(
swap_regret_optimizer._maximal_eigenvector_power_method(
standard_ops.constant(matrix1)))
@@ -119,7 +119,7 @@ class SwapRegretOptimizerTest(test.TestCase):
expected_projected_matrix = np.array([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9],
[0.4, 0.3, 0.0]])
- with self.test_session() as session:
+ with self.cached_session() as session:
projected_matrix = session.run(
swap_regret_optimizer._project_stochastic_matrix_wrt_euclidean_norm(
matrix))
@@ -134,7 +134,7 @@ class SwapRegretOptimizerTest(test.TestCase):
expected_projected_matrix = np.array([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5],
[0.4, 0.5, 0.3]])
- with self.test_session() as session:
+ with self.cached_session() as session:
projected_matrix = session.run(
standard_ops.exp(
swap_regret_optimizer.
@@ -165,7 +165,7 @@ class SwapRegretOptimizerTest(test.TestCase):
]
matrices = []
- with self.test_session() as session:
+ with self.cached_session() as session:
session.run(standard_ops.global_variables_initializer())
while len(matrices) < len(expected_matrices):
matrices.append(session.run(optimizer.stochastic_matrix))
@@ -198,7 +198,7 @@ class SwapRegretOptimizerTest(test.TestCase):
]
matrices = []
- with self.test_session() as session:
+ with self.cached_session() as session:
session.run(standard_ops.global_variables_initializer())
while len(matrices) < len(expected_matrices):
matrices.append(session.run(optimizer.stochastic_matrix))
diff --git a/tensorflow/contrib/crf/__init__.py b/tensorflow/contrib/crf/__init__.py
index 615e62b16f..fe5e34d258 100644
--- a/tensorflow/contrib/crf/__init__.py
+++ b/tensorflow/contrib/crf/__init__.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""Linear-chain CRF layer.
-See the @{$python/contrib.crf} guide.
+See the [CRF](https://tensorflow.org/api_guides/python/contrib.crf) guide.
@@crf_binary_score
@@crf_decode
diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
index 8cfe142059..556d731840 100644
--- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
+++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
@@ -61,7 +61,7 @@ class CrfTest(test.TestCase):
for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list,
inputs_list,
tag_indices_list):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sequence_score = crf.crf_sequence_score(
inputs=array_ops.expand_dims(inputs, 0),
tag_indices=array_ops.expand_dims(tag_indices, 0),
@@ -96,7 +96,7 @@ class CrfTest(test.TestCase):
]
for sequence_lengths, inputs, tag_bitmap in zip(
sequence_lengths_list, inputs_list, tag_bitmap_list):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sequence_score = crf.crf_multitag_sequence_score(
inputs=array_ops.expand_dims(inputs, 0),
tag_bitmap=array_ops.expand_dims(tag_bitmap, 0),
@@ -124,7 +124,7 @@ class CrfTest(test.TestCase):
for dtype in (np.int32, np.int64):
tag_indices = np.array([1, 2, 1, 0], dtype=dtype)
sequence_lengths = np.array(3, dtype=np.int32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
unary_score = crf.crf_unary_score(
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
@@ -140,7 +140,7 @@ class CrfTest(test.TestCase):
transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
sequence_lengths = np.array(3, dtype=np.int32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
binary_score = crf.crf_binary_score(
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
@@ -176,7 +176,7 @@ class CrfTest(test.TestCase):
tag_indices_list):
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
all_sequence_scores = []
# Compare the dynamic program with brute force computation.
@@ -206,7 +206,7 @@ class CrfTest(test.TestCase):
"""
Test `crf_log_norm` when `sequence_lengths` contains one or more zeros.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = constant_op.constant(np.ones([2, 10, 5],
dtype=np.float32))
transition_params = constant_op.constant(np.ones([5, 5],
@@ -226,7 +226,7 @@ class CrfTest(test.TestCase):
sequence_lengths = np.array(3, dtype=np.int32)
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
all_sequence_log_likelihoods = []
# Make sure all probabilities sum to 1.
@@ -254,7 +254,7 @@ class CrfTest(test.TestCase):
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
all_sequence_scores = []
all_sequences = []
@@ -310,7 +310,7 @@ class CrfTest(test.TestCase):
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
all_sequence_scores = []
all_sequences = []
@@ -351,7 +351,7 @@ class CrfTest(test.TestCase):
"""
Test that crf_decode works when sequence_length contains one or more zeros.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = constant_op.constant(np.ones([2, 10, 5],
dtype=np.float32))
transition_params = constant_op.constant(np.ones([5, 5],
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
index 252ea1560d..fda1b9f1b3 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
@@ -802,7 +802,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
[single_cell_fn() for _ in range(num_layers)])
input_size = 3
save_graph = ops.Graph()
- with save_graph.as_default(), self.test_session(graph=save_graph):
+ with save_graph.as_default(), self.session(graph=save_graph):
save_layer = _MultiCellFn()
save_layer(inputs=array_ops.ones([1, input_size]),
state=save_layer.zero_state(1, dtypes.float32))
diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD
index 8bdbba83ef..9f710613dd 100644
--- a/tensorflow/contrib/data/BUILD
+++ b/tensorflow/contrib/data/BUILD
@@ -33,14 +33,22 @@ cc_library(
tf_custom_op_library(
name = "_dataset_ops.so",
- srcs = ["ops/dataset_ops.cc"],
- deps = ["//tensorflow/contrib/data/kernels:dataset_kernels"] +
- if_static(
- extra_deps = [":lib_proto_parsing_for_dataset_ops"],
- otherwise = [],
- ),
+ srcs = [
+ "ops/dataset_ops.cc",
+ "ops/indexed_dataset_ops.cc",
+ ],
+ deps = [
+ "//tensorflow/contrib/data/kernels:dataset_kernels",
+ "//tensorflow/contrib/data/kernels:indexed_dataset",
+ ] + if_static(
+ extra_deps = [":lib_proto_parsing_for_dataset_ops"],
+ otherwise = [],
+ ),
)
tf_gen_op_libs(
- op_lib_names = ["dataset_ops"],
+ op_lib_names = [
+ "dataset_ops",
+ "indexed_dataset_ops",
+ ],
)
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index dbfff9b4f8..baec238c62 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -20,11 +20,13 @@ be used in conjunction with the `tf.data.Dataset` API. Note that the
guarantees as `tf.data`, but we will provide deprecation advice in advance of
removing existing functionality.
-See @{$guide/datasets$Importing Data} for an overview.
+See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@Counter
@@CheckpointInputPipelineHook
@@CsvDataset
+@@LMDBDataset
+@@Optional
@@RandomDataset
@@Reducer
@@SqlDataset
@@ -37,7 +39,7 @@ See @{$guide/datasets$Importing Data} for an overview.
@@copy_to_device
@@dense_to_sparse_batch
@@enumerate_dataset
-
+@@get_next_as_optional
@@get_single_element
@@group_by_reducer
@@group_by_window
@@ -45,10 +47,10 @@ See @{$guide/datasets$Importing Data} for an overview.
@@make_batched_features_dataset
@@make_csv_dataset
@@make_saveable_from_iterator
-
@@map_and_batch
@@padded_batch_and_drop_remainder
@@parallel_interleave
+@@parse_example_dataset
@@prefetch_to_device
@@read_batch_features
@@rejection_resample
@@ -89,10 +91,12 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datase
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
+from tensorflow.contrib.data.python.ops.parsing_ops import parse_example_dataset
from tensorflow.contrib.data.python.ops.prefetching_ops import copy_to_device
from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
from tensorflow.contrib.data.python.ops.random_ops import RandomDataset
from tensorflow.contrib.data.python.ops.readers import CsvDataset
+from tensorflow.contrib.data.python.ops.readers import LMDBDataset
from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset
from tensorflow.contrib.data.python.ops.readers import make_csv_dataset
from tensorflow.contrib.data.python.ops.readers import read_batch_features
@@ -103,6 +107,8 @@ from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch
from tensorflow.contrib.data.python.ops.unique import unique
from tensorflow.contrib.data.python.ops.writers import TFRecordWriter
+from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
+from tensorflow.python.data.ops.optional_ops import Optional
# pylint: enable=unused-import
from tensorflow.python.util.all_util import remove_undocumented
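
Among the newly re-exported symbols, `Optional` and `get_next_as_optional` let a pipeline probe for the next element without relying on `OutOfRangeError` for control flow. A minimal sketch (the fallback value is illustrative; within a single `Session.run` call the underlying get-next op executes once, so guarding `get_value()` with a `cond` on `has_value()` is safe):

    from tensorflow.python.client import session
    from tensorflow.python.data.ops import dataset_ops
    from tensorflow.python.data.ops import iterator_ops
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import control_flow_ops

    ds = dataset_ops.Dataset.range(2)
    iterator = ds.make_initializable_iterator()
    opt = iterator_ops.get_next_as_optional(iterator)
    # Fall back to -1 when the pipeline is exhausted instead of raising.
    value = control_flow_ops.cond(
        opt.has_value(),
        lambda: opt.get_value(),
        lambda: constant_op.constant(-1, dtype=dtypes.int64))

    with session.Session() as sess:
      sess.run(iterator.initializer)
      print(sess.run(value))  # 0
      print(sess.run(value))  # 1
      print(sess.run(value))  # -1: the dataset is exhausted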
diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD
index 2e249f5c14..ec6cb37193 100644
--- a/tensorflow/contrib/data/kernels/BUILD
+++ b/tensorflow/contrib/data/kernels/BUILD
@@ -7,6 +7,31 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
cc_library(
+ name = "indexed_dataset_headers",
+ hdrs = ["indexed_dataset.h"],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+)
+
+cc_library(
+ name = "indexed_dataset",
+ srcs = [
+ "identity_indexed_dataset.cc",
+ "indexed_dataset.cc",
+ ],
+ deps = [
+ ":indexed_dataset_headers",
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
name = "prefetching_kernels",
srcs = ["prefetching_kernels.cc"],
deps = [
@@ -52,6 +77,17 @@ cc_library(
)
cc_library(
+ name = "lmdb_dataset_op",
+ srcs = ["lmdb_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@lmdb",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+)
+
+cc_library(
name = "threadpool_dataset_op",
srcs = ["threadpool_dataset_op.cc"],
deps = [
@@ -91,6 +127,8 @@ cc_library(
":csv_dataset_op",
":directed_interleave_dataset_op",
":ignore_errors_dataset_op",
+ ":indexed_dataset",
+ ":lmdb_dataset_op",
":prefetching_kernels",
":threadpool_dataset_op",
":unique_dataset_op",
diff --git a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc
index e36c9c0634..c19a609780 100644
--- a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
namespace tensorflow {
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -150,4 +151,5 @@ REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU),
AssertNextDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
index d242cfdf49..74107d5242 100644
--- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/zlib_inputstream.h"
namespace tensorflow {
+namespace data {
namespace {
class CSVDatasetOp : public DatasetOpKernel {
@@ -713,7 +714,7 @@ class CSVDatasetOp : public DatasetOpKernel {
component.scalar<string>()() =
dataset()->record_defaults_[output_idx].flat<string>()(0);
} else {
- component.scalar<string>()() = field.ToString();
+ component.scalar<string>()() = string(field);
}
break;
}
@@ -851,4 +852,5 @@ class CSVDatasetOp : public DatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("CSVDataset").Device(DEVICE_CPU), CSVDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
index ccf7ec1f84..a5321620bf 100644
--- a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -276,5 +276,5 @@ REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU),
DirectedInterleaveDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc b/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc
new file mode 100644
index 0000000000..c3cb45dbf7
--- /dev/null
+++ b/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc
@@ -0,0 +1,155 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/data/kernels/indexed_dataset.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class IdentityIndexedDatasetOp : public IndexedDatasetOpKernel {
+ public:
+ using IndexedDatasetOpKernel::IndexedDatasetOpKernel;
+
+ void MakeIndexedDataset(OpKernelContext* ctx,
+ IndexedDataset** output) override {
+ uint64 size = -1;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<uint64>(ctx, "size", &size));
+ OP_REQUIRES(ctx, size > 0, errors::InvalidArgument("`size` must be > 0"));
+ *output = new Dataset(ctx, size);
+ }
+
+ class Dataset : public IndexedDataset {
+ public:
+ Dataset(OpKernelContext* ctx, uint64 size)
+ : IndexedDataset(DatasetContext(ctx)), size_(size) {}
+
+ Status MaterializeDataset(
+ std::shared_ptr<MaterializedIndexedDataset>* materialized) override {
+ materialized->reset(new Materialized(this));
+ return Status::OK();
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes = new DataTypeVector({DT_UINT64});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}});
+ return *shapes;
+ }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::IdentityIndexedDataset")}));
+ }
+
+ string DebugString() const override {
+ return "IdentityIndexedDataset::Dataset";
+ }
+
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** node) const override {
+ return errors::Unimplemented(
+ "identity_indexed_dataset.AsGraphDefInternal");
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ if (cur_ < dataset()->size_) {
+ Tensor result_tensor(ctx->allocator({}), DT_UINT64, {});
+ result_tensor.scalar<uint64>()() = cur_++;
+ out_tensors->emplace_back(std::move(result_tensor));
+ *end_of_sequence = false;
+ return Status::OK();
+ }
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ uint64 cur_ GUARDED_BY(mu_) = 0;  // Next element to emit; starts at zero.
+ };
+
+ class Materialized : public MaterializedIndexedDataset {
+ public:
+ explicit Materialized(Dataset* dataset) : dataset_(dataset) {
+ dataset->Ref();
+ }
+
+ ~Materialized() override {
+ // TODO(saeta): Pull this into MaterializedIndexedDataset
+ dataset_->Unref();
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return dataset_->output_dtypes();
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return dataset_->output_shapes();
+ }
+
+ Status Get(IteratorContext&& ctx, uint64 index,
+ std::vector<Tensor>* out_tensors) const override {
+ LOG(INFO) << "Materialized(" << dataset_->size_ << ")::Get(" << index
+ << ")";
+ if (index >= dataset_->size_) {
+ // Note: use InvalidArgument instead of OutOfRange error because many
+ // things consider OutOfRange to be a "clean termination" error.
+ return errors::InvalidArgument(
+ "Index ", index,
+ " is out of range for this dataset. (Size is: ", dataset_->size_,
+ ".)");
+ }
+ Tensor result_tensor(ctx.allocator({}), DT_UINT64, {});
+ result_tensor.scalar<uint64>()() = index;
+ out_tensors->emplace_back(std::move(result_tensor));
+ return Status::OK();
+ }
+
+ Status Size(uint64* size) const override {
+ *size = dataset_->size_;
+ return Status::OK();
+ }
+
+ private:
+ const Dataset* const dataset_; // Not owned.
+ };
+
+ const uint64 size_;
+ std::shared_ptr<Materialized> materialized_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("IdentityIndexedDataset").Device(DEVICE_CPU),
+ IdentityIndexedDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc
index db24e60846..beec344534 100644
--- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -137,5 +137,5 @@ REGISTER_KERNEL_BUILDER(Name("IgnoreErrorsDataset").Device(DEVICE_CPU),
IgnoreErrorsDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.cc b/tensorflow/contrib/data/kernels/indexed_dataset.cc
new file mode 100644
index 0000000000..ced8ab0d60
--- /dev/null
+++ b/tensorflow/contrib/data/kernels/indexed_dataset.cc
@@ -0,0 +1,373 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/data/kernels/indexed_dataset.h"
+
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+Status VerifyTypesMatch(const DataTypeVector& expected,
+ const DataTypeVector& received) {
+ if (expected.size() != received.size()) {
+ return errors::InvalidArgument(
+ "Number of components does not match: expected ", expected.size(),
+ " types but got ", received.size(), ".");
+ }
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (expected[i] != received[i]) {
+ return errors::InvalidArgument("Data type mismatch at component ", i,
+ ": expected ", DataTypeString(expected[i]),
+ " but got ", DataTypeString(received[i]),
+ ".");
+ }
+ }
+ return Status::OK();
+}
+
+Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
+ const std::vector<PartialTensorShape>& received) {
+ if (expected.size() != received.size()) {
+ return errors::InvalidArgument(
+ "Number of components does not match: expected ", expected.size(),
+ " shapes but got ", received.size(), ".");
+ }
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (!expected[i].IsCompatibleWith(received[i])) {
+ return errors::InvalidArgument("Incompatible shapes at component ", i,
+ ": expected ", expected[i].DebugString(),
+ " but got ", received[i].DebugString(),
+ ".");
+ }
+ }
+
+ return Status::OK();
+}
+
+class MaterializedDatasetResource : public ResourceBase {
+ public:
+ MaterializedDatasetResource(
+ const DataTypeVector& output_dtypes,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : output_dtypes_(output_dtypes), output_shapes_(output_shapes) {}
+
+ string DebugString() override {
+ return "Materialized IndexedDataset resource";
+ }
+
+ Status Get(IteratorContext&& ctx, uint64 index,
+ std::vector<Tensor>* out_tensors) {
+ std::shared_ptr<MaterializedIndexedDataset> captured(materialized_);
+ if (captured) {
+ return captured->Get(std::move(ctx), index, out_tensors);
+ } else {
+ return errors::FailedPrecondition(
+ "Get() failed because the MaterializedIndexedDataset has not been "
+ "initialized. Ensure that you have run the materialization operation "
+ "for this MaterializedIndexedDataset before retrieving elements.");
+ }
+ }
+
+ // TODO(saeta): Implement Save and Restore
+
+ const DataTypeVector& output_dtypes() const { return output_dtypes_; }
+ const std::vector<PartialTensorShape>& output_shapes() const {
+ return output_shapes_;
+ }
+
+ Status set_materialized_dataset(
+ const std::shared_ptr<MaterializedIndexedDataset>& dataset) {
+ if (dataset) {
+ TF_RETURN_IF_ERROR(
+ VerifyTypesMatch(output_dtypes_, dataset->output_dtypes()));
+ TF_RETURN_IF_ERROR(
+ VerifyShapesCompatible(output_shapes_, dataset->output_shapes()));
+ }
+ materialized_ = dataset;
+ return Status::OK();
+ }
+
+ private:
+ std::shared_ptr<MaterializedIndexedDataset> materialized_;
+ const DataTypeVector output_dtypes_;
+ const std::vector<PartialTensorShape> output_shapes_;
+};
+
+// A wrapper class for storing an `IndexedDataset` instance in a DT_VARIANT
+// tensor. Objects of the wrapper class own a reference on an instance of an
+// `IndexedDataset`, and the wrapper's copy constructor and destructor take
+// care of managing the reference count.
+//
+// NOTE: This is not a feature-complete implementation of the DT_VARIANT
+// specification. In particular, we cannot currently serialize an arbitrary
+// `IndexedDataset` object, so the `Encode()` and `Decode()` methods are not
+// implemented.
+//
+// NOTE(saeta): When `IndexedDataset`s get merged into core, we can instead just
+// use `tensorflow::DatasetVariantWrapper`.
+class IndexedDatasetVariantWrapper {
+ public:
+ IndexedDatasetVariantWrapper() : dataset_(nullptr) {}
+
+ // Transfers ownership of `dataset` to `*this`.
+ explicit IndexedDatasetVariantWrapper(IndexedDataset* dataset)
+ : dataset_(dataset) {}
+
+ IndexedDatasetVariantWrapper(const IndexedDatasetVariantWrapper& other)
+ : dataset_(other.dataset_) {
+ if (dataset_) dataset_->Ref();
+ }
+
+ ~IndexedDatasetVariantWrapper() {
+ if (dataset_) dataset_->Unref();
+ }
+
+ IndexedDataset* get() const { return dataset_; }
+
+ string TypeName() const { return "tensorflow::IndexedDatasetVariantWrapper"; }
+ string DebugString() const {
+ if (dataset_) {
+ return dataset_->DebugString();
+ } else {
+ return "<Uninitialized IndexedDatasetVariantWrapper>";
+ }
+ }
+
+ void Encode(VariantTensorData* data) const {
+ LOG(ERROR) << "The Encode() method is not implemented for "
+ "IndexedDatasetVariantWrapper objects.";
+ }
+
+ bool Decode(const VariantTensorData& data) {
+ LOG(ERROR) << "The Decode() method is not implemented for "
+ "IndexedDatasetVariantWrapper objects.";
+ return false;
+ }
+
+ private:
+ IndexedDataset* const dataset_; // Owns one reference.
+};
+
+} // namespace
+
+Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor,
+ IndexedDataset** out_dataset) {
+ if (!(tensor.dtype() == DT_VARIANT &&
+ TensorShapeUtils::IsScalar(tensor.shape()))) {
+ return errors::InvalidArgument(
+ "IndexedDataset tensor must be a scalar of dtype DT_VARIANT.");
+ }
+ const Variant& variant = tensor.scalar<Variant>()();
+ const IndexedDatasetVariantWrapper* wrapper =
+ variant.get<IndexedDatasetVariantWrapper>();
+ if (wrapper == nullptr) {
+ return errors::InvalidArgument("Tensor must be an IndexedDataset object.");
+ }
+ *out_dataset = wrapper->get();
+ if (*out_dataset == nullptr) {
+ return errors::Internal("Read uninitialized IndexedDataset variant.");
+ }
+ return Status::OK();
+}
+
+Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset,
+ Tensor* tensor) {
+ if (!(tensor->dtype() == DT_VARIANT &&
+ TensorShapeUtils::IsScalar(tensor->shape()))) {
+ return errors::InvalidArgument(
+ "Dataset tensor must be a scalar of dtype DT_VARIANT.");
+ }
+ tensor->scalar<Variant>()() = IndexedDatasetVariantWrapper(dataset);
+ return Status::OK();
+}
+
+void IndexedDatasetOpKernel::Compute(OpKernelContext* ctx) {
+ IndexedDataset* dataset = nullptr;
+ MakeIndexedDataset(ctx, &dataset);
+
+ if (ctx->status().ok()) {
+ OP_REQUIRES(ctx, dataset != nullptr,
+ errors::Internal("MakeIndexedDataset did not correctly "
+ "construct the IndexedDataset"));
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
+ OP_REQUIRES_OK(ctx, StoreIndexedDatasetInVariantTensor(dataset, output));
+ }
+}
+
+namespace {
+
+class MaterializedHandleOp : public OpKernel {
+ public:
+ explicit MaterializedHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ ~MaterializedHandleOp() override {
+ if (resource_ != nullptr) {
+ resource_->Unref();
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->template Delete<MaterializedDatasetResource>(
+ cinfo_.container(), cinfo_.name())
+ .ok()) {
+ // Do nothing; the resource can have been deleted by session resets.
+ // Note: cargo-culted from $tf/core/framework/resource_op_kernel.h
+ }
+ }
+ }
+ }
+
+ void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ if (resource_ == nullptr) {
+ ResourceMgr* mgr = context->resource_manager();
+ OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
+
+ MaterializedDatasetResource* resource;
+ OP_REQUIRES_OK(context,
+ mgr->LookupOrCreate<MaterializedDatasetResource>(
+ cinfo_.container(), cinfo_.name(), &resource,
+ [this](MaterializedDatasetResource** ret)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ *ret = new MaterializedDatasetResource(
+ output_dtypes_, output_shapes_);
+ return Status::OK();
+ }));
+ Status s = VerifyResource(resource);
+ if (TF_PREDICT_FALSE(!s.ok())) {
+ resource->Unref();
+ context->SetStatus(s);
+ return;
+ }
+
+ resource_ = resource;
+ }
+ }
+ OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
+ context, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<MaterializedDatasetResource>()));
+ }
+
+ private:
+ // During the first Compute(), the resource is either created or looked up
+ // using shared_name. In the latter case, the resource found must be verified
+ // to be compatible with this op's configuration. The verification may fail,
+ // for example, when two graphs ask resources of the same shared name to have
+ // inconsistent output types or shapes.
+ Status VerifyResource(MaterializedDatasetResource* resource) {
+ TF_RETURN_IF_ERROR(
+ VerifyTypesMatch(output_dtypes_, resource->output_dtypes()));
+ TF_RETURN_IF_ERROR(
+ VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
+ return Status::OK();
+ }
+
+ mutex mu_;
+ ContainerInfo cinfo_; // Written once under mu_ then constant afterwards.
+ MaterializedDatasetResource* resource_ GUARDED_BY(mu_) = nullptr;
+ DataTypeVector output_dtypes_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
+// TODO(saeta): Make async.
+class MaterializeDatasetOp : public OpKernel {
+ public:
+ explicit MaterializeDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ IndexedDataset* dataset;
+ OP_REQUIRES_OK(ctx,
+ GetIndexedDatasetFromVariantTensor(ctx->input(0), &dataset));
+
+ MaterializedDatasetResource* materialized_resource;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
+ &materialized_resource));
+ core::ScopedUnref unref(materialized_resource);
+ std::shared_ptr<MaterializedIndexedDataset> materialized;
+ OP_REQUIRES_OK(ctx, dataset->MaterializeDataset(&materialized));
+ OP_REQUIRES_OK(
+ ctx, materialized_resource->set_materialized_dataset(materialized));
+ }
+};
+
+// TODO(saeta): Make async
+class IndexedDatasetGet : public OpKernel {
+ public:
+ explicit IndexedDatasetGet(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ MaterializedDatasetResource* materialized_resource;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0),
+ &materialized_resource));
+ auto cleanup = gtl::MakeCleanup([materialized_resource] {
+ materialized_resource->Unref(); // Note: can't use core::ScopedUnref.
+ });
+
+ const Tensor* index_t;
+ OP_REQUIRES_OK(ctx, ctx->input("index", &index_t));
+ // TODO(saeta): Support batch reads (indexes should be non-scalar!)
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(index_t->shape()),
+ errors::InvalidArgument("index must be a scalar"));
+ const uint64 index = index_t->scalar<uint64>()();
+
+ std::vector<Tensor> out_tensors;
+ Status s =
+ materialized_resource->Get(IteratorContext(ctx), index, &out_tensors);
+
+ // Note: Unref materialized_resource to avoid destruction races. (Important
+ // in a [future] async op implementation.)
+ cleanup.release()();
+
+ if (!s.ok()) {
+ ctx->SetStatus(s);
+ } else {
+ auto expected_shapes = materialized_resource->output_shapes();
+ auto expected_types = materialized_resource->output_dtypes();
+ for (size_t i = 0; i < out_tensors.size(); ++i) {
+ OP_REQUIRES(
+ ctx, expected_shapes[i].IsCompatibleWith(out_tensors[i].shape()),
+ errors::Internal(
+ "Materialized dataset output at index ", i,
+ " is incompatible with the expected shape. (Expected: ",
+ expected_shapes[i], ", got: ", out_tensors[i].shape(), ")"));
+ OP_REQUIRES(ctx, out_tensors[i].dtype() == expected_types[i],
+ errors::Internal("Materialized dataset output at index ", i,
+ " was not the expected dtype. (Expected: ",
+ expected_types[i],
+ ", got: ", out_tensors[i].dtype(), ")"));
+ ctx->set_output(i, out_tensors[i]);
+ }
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MaterializedIndexDatasetHandle").Device(DEVICE_CPU),
+ MaterializedHandleOp);
+REGISTER_KERNEL_BUILDER(Name("IndexedDatasetMaterialize").Device(DEVICE_CPU),
+ MaterializeDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("IndexedDatasetGet").Device(DEVICE_CPU),
+ IndexedDatasetGet);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.h b/tensorflow/contrib/data/kernels/indexed_dataset.h
new file mode 100644
index 0000000000..7aa2d3fdbc
--- /dev/null
+++ b/tensorflow/contrib/data/kernels/indexed_dataset.h
@@ -0,0 +1,119 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_
+#define TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace data {
+
+// TODO(saeta): Urgh, this is ugly.
+class MaterializedIndexedDataset {
+ public:
+ virtual ~MaterializedIndexedDataset() = default;
+
+ // Retrieve the element at a given index. The output tensors are stored in
+ // out_tensors.
+ //
+ // If `index` is greater than or equal to `Size()`, an error is returned (the
+ // identity implementation, for example, returns InvalidArgument rather than
+ // OutOfRange, since OutOfRange is treated as clean termination).
+ //
+ // Get is thread-safe.
+ virtual Status Get(IteratorContext&& ctx, uint64 index,
+ std::vector<Tensor>* out_tensors) const = 0;
+
+ // Size determines the number of elements in this IndexedDataset.
+ //
+ // Size is thread-safe.
+ virtual Status Size(uint64* size) const = 0;
+
+ // Returns a vector of DataType values, representing the respective
+ // element types of each tuple component in the outputs of this dataset.
+ virtual const DataTypeVector& output_dtypes() const = 0;
+
+ // Returns a vector of tensor shapes, representing the respective
+ // (and possibly partially defined) shapes of each tuple component
+ // in the outputs of this dataset.
+ virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
+};
+
+// IndexedDataset represents a dataset that supports random access in addition
+// to iterator-based sequential access.
+//
+// Note: IndexedDatasets are HIGHLY experimental at this time. Expect
+// significant (backwards incompatible) changes!
+class IndexedDataset : public DatasetBase {
+ public:
+ IndexedDataset(DatasetContext&& ctx) : DatasetBase(std::move(ctx)) {}
+
+ // Materialize (if necessary) the dataset, and return a pointer.
+ // TODO(saeta): Add in `IteratorContext* ctx` when materializing.
+ virtual Status MaterializeDataset(
+ std::shared_ptr<MaterializedIndexedDataset>* materialized) = 0;
+};
+
+// IndexedDatasetOpKernel abstracts away interfacing IndexedDatasets with the
+// rest of the TensorFlow runtime.
+//
+// Most `IndexedDataset`s will be private members of classes that inherit from
+// this class.
+class IndexedDatasetOpKernel : public OpKernel {
+ public:
+ IndexedDatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ void Compute(OpKernelContext* ctx) final;
+
+ protected:
+ // Subclasses should implement this method. It will be called during Compute
+ // execution.
+ virtual void MakeIndexedDataset(OpKernelContext* ctx,
+ IndexedDataset** output) = 0;
+
+ template <typename T>
+ Status ParseScalarArgument(OpKernelContext* ctx,
+ const StringPiece& argument_name, T* output) {
+ const Tensor* argument_t;
+ TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
+ if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
+ return errors::InvalidArgument(argument_name, " must be a scalar");
+ }
+ *output = argument_t->scalar<T>()();
+ return Status::OK();
+ }
+};
+
+// Validates and extracts an `IndexedDataset` object from `tensor`.
+//
+// `tensor` must have been written by a call to
+// `StoreIndexedDatasetInVariantTensor`.
+//
+// The retrieved pointer is a borrowed reference to the dataset, which is owned
+// by the tensor. The consumer must either acquire its own reference to the
+// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not
+// destroyed or mutated while the retrieved pointer is in use.
+Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor,
+ IndexedDataset** out_dataset);
+
+// Stores an `IndexedDataset` object in `tensor`.
+//
+// The ownership of `dataset` is transferred to `tensor`.
+Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset,
+ Tensor* tensor);
+
+} // namespace data
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_
diff --git a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc b/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc
new file mode 100644
index 0000000000..d233c1f8ec
--- /dev/null
+++ b/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc
@@ -0,0 +1,217 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <sys/stat.h>
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
+#include "tensorflow/core/platform/file_system.h"
+
+#include "lmdb.h" // NOLINT(build/include)
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class LMDBDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ const Tensor* filenames_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
+ OP_REQUIRES(
+ ctx, filenames_tensor->dims() <= 1,
+ errors::InvalidArgument("`filenames` must be a scalar or a vector."));
+
+ std::vector<string> filenames;
+ filenames.reserve(filenames_tensor->NumElements());
+ for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
+ filenames.push_back(filenames_tensor->flat<string>()(i));
+ }
+
+ *output = new Dataset(ctx, filenames);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const std::vector<string>& filenames)
+ : DatasetBase(DatasetContext(ctx)), filenames_(filenames) {}
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::LMDB")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes =
+ new DataTypeVector({DT_STRING, DT_STRING});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}, {}});
+ return *shapes;
+ }
+
+ string DebugString() const override { return "LMDBDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* filenames = nullptr;
+ TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ do {
+ if (mdb_cursor_) {
+ Tensor key_tensor(ctx->allocator({}), DT_STRING, {});
+ key_tensor.scalar<string>()() = string(
+ static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size);
+ out_tensors->emplace_back(std::move(key_tensor));
+
+ Tensor value_tensor(ctx->allocator({}), DT_STRING, {});
+ value_tensor.scalar<string>()() =
+ string(static_cast<const char*>(mdb_value_.mv_data),
+ mdb_value_.mv_size);
+ out_tensors->emplace_back(std::move(value_tensor));
+
+ int val;
+ val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT);
+ if (val != MDB_SUCCESS && val != MDB_NOTFOUND) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ if (val == MDB_NOTFOUND) {
+ ResetStreamsLocked();
+ ++current_file_index_;
+ }
+ *end_of_sequence = false;
+ return Status::OK();
+ }
+ if (current_file_index_ == dataset()->filenames_.size()) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
+ TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+ } while (true);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ return errors::Unimplemented(
+ "Checkpointing is currently not supported for LMDBDataset.");
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ return errors::Unimplemented(
+ "Checkpointing is currently not supported for LMDBDataset.");
+ }
+
+ private:
+ Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (current_file_index_ >= dataset()->filenames_.size()) {
+ return errors::InvalidArgument(
+ "current_file_index_:", current_file_index_,
+ " >= filenames_.size():", dataset()->filenames_.size());
+ }
+ const string& filename = dataset()->filenames_[current_file_index_];
+
+ int val = mdb_env_create(&mdb_env_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ int flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK;
+
+ struct stat source_stat;
+ if (stat(filename.c_str(), &source_stat) == 0 &&
+ S_ISREG(source_stat.st_mode)) {  // S_ISREG masks with S_IFMT first.
+ flags |= MDB_NOSUBDIR;
+ }
+ val = mdb_env_open(mdb_env_, filename.c_str(), flags, 0664);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_txn_begin(mdb_env_, nullptr, MDB_RDONLY, &mdb_txn_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_dbi_open(mdb_txn_, nullptr, 0, &mdb_dbi_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST);
+ if (val != MDB_SUCCESS && val != MDB_NOTFOUND) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ if (val == MDB_NOTFOUND) {
+ ResetStreamsLocked();
+ }
+ return Status::OK();
+ }
+ void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (mdb_env_ != nullptr) {
+ if (mdb_cursor_) {
+ mdb_cursor_close(mdb_cursor_);
+ mdb_cursor_ = nullptr;
+ }
+ mdb_dbi_close(mdb_env_, mdb_dbi_);
+ mdb_txn_abort(mdb_txn_);
+ mdb_env_close(mdb_env_);
+ mdb_txn_ = nullptr;
+ mdb_dbi_ = 0;
+ mdb_env_ = nullptr;
+ }
+ }
+ mutex mu_;
+ size_t current_file_index_ GUARDED_BY(mu_) = 0;
+ MDB_env* mdb_env_ GUARDED_BY(mu_) = nullptr;
+ MDB_txn* mdb_txn_ GUARDED_BY(mu_) = nullptr;
+ MDB_dbi mdb_dbi_ GUARDED_BY(mu_) = 0;
+ MDB_cursor* mdb_cursor_ GUARDED_BY(mu_) = nullptr;
+
+ MDB_val mdb_key_ GUARDED_BY(mu_);
+ MDB_val mdb_value_ GUARDED_BY(mu_);
+ };
+
+ const std::vector<string> filenames_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("LMDBDataset").Device(DEVICE_CPU), LMDBDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
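
With the kernel above and the `LMDBDataset` op registration later in this patch (`contrib/data/ops/dataset_ops.cc`), the contrib Python surface (see the `__init__.py` change earlier) exposes a reader whose elements are `(key, value)` pairs of strings, one per LMDB record. A minimal usage sketch (the database path is hypothetical; the constructor argument mirrors the op's `filenames` input):

    from tensorflow.python.client import session
    from tensorflow.python.framework import errors
    from tensorflow.contrib.data.python.ops.readers import LMDBDataset

    dataset = LMDBDataset('/tmp/example.mdb')  # hypothetical database path
    iterator = dataset.make_initializable_iterator()
    key, value = iterator.get_next()

    with session.Session() as sess:
      sess.run(iterator.initializer)
      try:
        while True:
          k, v = sess.run([key, value])  # both are byte strings
          print(k, v)
      except errors.OutOfRangeError:
        pass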
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
index 74df1e42a8..078de717e0 100644
--- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc
+++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
+namespace data {
namespace {
struct BufferElement {
@@ -548,7 +549,9 @@ class MultiDeviceIterator : public ResourceBase {
devices_(devices),
flib_def_(std::move(flib_def)),
pflr_(std::move(pflr)),
- lib_(lib) {}
+ lib_(lib) {
+ CHECK_NOTNULL(lib_);
+ }
string DebugString() override {
return strings::StrCat("MultiDeviceIterator for ", devices_.size(),
@@ -600,6 +603,11 @@ class MultiDeviceIterator : public ResourceBase {
return lib_def_;
}
+ FunctionLibraryRuntime* const lib() {
+ tf_shared_lock l(mu_);
+ return lib_;
+ }
+
private:
// A private class that uses a background thread to keep a per device buffer
// full.
@@ -930,8 +938,10 @@ class MultiDeviceIteratorInitOp : public OpKernel {
core::ScopedUnref unref(resource);
std::unique_ptr<IteratorBase> iterator;
- OP_REQUIRES_OK(ctx, dataset->MakeIterator(IteratorContext(ctx), "Iterator",
- &iterator));
+ IteratorContext iter_ctx(ctx);
+ iter_ctx.set_lib(resource->lib());
+ OP_REQUIRES_OK(
+ ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
int64 incarnation_id;
OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size,
&incarnation_id));
@@ -1105,5 +1115,6 @@ REGISTER_KERNEL_BUILDER(
Name("MultiDeviceIteratorFromStringHandle").Device(DEVICE_CPU),
MultiDeviceIteratorFromStringHandleOp);
-} // anonymous namespace
+} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
index ab584504a0..30fa97a636 100644
--- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
+namespace data {
namespace {
class ThreadPoolResource : public ResourceBase {
@@ -214,4 +215,5 @@ REGISTER_KERNEL_BUILDER(Name("ThreadPoolDataset").Device(DEVICE_CPU),
ThreadPoolDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/unique_dataset_op.cc b/tensorflow/contrib/data/kernels/unique_dataset_op.cc
index 6fbf5d2ebb..57fc5697a4 100644
--- a/tensorflow/contrib/data/kernels/unique_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/unique_dataset_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -219,5 +219,5 @@ REGISTER_KERNEL_BUILDER(Name("UniqueDataset").Device(DEVICE_CPU),
UniqueDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
index cc5e250ea1..ae104d55bd 100644
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ b/tensorflow/contrib/data/ops/dataset_ops.cc
@@ -266,4 +266,13 @@ REGISTER_OP("AssertNextDataset")
return shape_inference::ScalarShape(c);
});
+REGISTER_OP("LMDBDataset")
+ .Input("filenames: string")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape);
+
} // namespace tensorflow
diff --git a/tensorflow/contrib/data/ops/indexed_dataset_ops.cc b/tensorflow/contrib/data/ops/indexed_dataset_ops.cc
new file mode 100644
index 0000000000..cd9b7c68a0
--- /dev/null
+++ b/tensorflow/contrib/data/ops/indexed_dataset_ops.cc
@@ -0,0 +1,80 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("IdentityIndexedDataset")
+ .Input("size: uint64")
+ .Output("handle: variant")
+ .SetIsStateful()
+ .SetShapeFn(
+ shape_inference::ScalarShape); // TODO(saeta): check input shapes.
+
+///////////////////////////////////////////////////////////////////////////////
+// IndexedDataset Internals
+///////////////////////////////////////////////////////////////////////////////
+
+// Creates the handle.
+REGISTER_OP("MaterializedIndexDatasetHandle")
+ .Output("handle: resource")
+ .Attr("container: string")
+ .Attr("shared_name: string")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+// Materializes the `IndexedDataset` into the handle created above.
+REGISTER_OP("IndexedDatasetMaterialize")
+ .Input("dataset: variant")
+ .Input("materialized: resource")
+ .SetShapeFn(shape_inference::NoOutputs);
+
+namespace {
+
+Status GetShapeFn(shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ if (output_shapes.size() != c->num_outputs()) {
+ return errors::InvalidArgument(
+ "`output_shapes` must be the same length as `output_types` (",
+ output_shapes.size(), " vs. ", c->num_outputs(), ")");
+ }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+REGISTER_OP("IndexedDatasetGet")
+ .Input("materialized: resource")
+ .Input("index: uint64")
+ .Output("components: output_types")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(GetShapeFn)
+ .Doc(R"doc(
+Gets the element at `index` from the `materialized` IndexedDataset.
+)doc");
+
+} // namespace tensorflow
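The ops above split indexed-dataset materialization into three steps: create a resource handle, materialize an IndexedDataset into it, then fetch elements by index. A sketch of that protocol, mirroring the low-level test added later in this change (indexed_dataset_ops_test.py):

    from tensorflow.contrib.data.python.ops import contrib_op_loader  # pylint: disable=unused-import
    from tensorflow.contrib.data.python.ops import gen_dataset_ops
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import ops

    identity = gen_dataset_ops.identity_indexed_dataset(
        ops.convert_to_tensor(16, dtype=dtypes.uint64))
    handle = gen_dataset_ops.materialized_index_dataset_handle(
        container="", shared_name="",
        output_types=[dtypes.uint64], output_shapes=[[]])
    materialize = gen_dataset_ops.indexed_dataset_materialize(identity, handle)
    get_op = gen_dataset_ops.indexed_dataset_get(
        handle, ops.convert_to_tensor(3, dtype=dtypes.uint64),
        output_types=[dtypes.uint64], output_shapes=[[]])
    # Running `materialize` once makes `get_op` return [3].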
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 2b75aa2ca5..6f0111a2bd 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -4,7 +4,8 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test")
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "py_test")
py_test(
name = "batch_dataset_op_test",
@@ -134,12 +135,26 @@ py_test(
)
py_test(
+ name = "indexed_dataset_ops_test",
+ srcs = ["indexed_dataset_ops_test.py"],
+ deps = [
+ "//tensorflow/contrib/data/python/ops:contrib_op_loader",
+ "//tensorflow/contrib/data/python/ops:gen_dataset_ops",
+ "//tensorflow/contrib/data/python/ops:indexed_dataset_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "interleave_dataset_op_test",
size = "medium",
srcs = ["interleave_dataset_op_test.py"],
srcs_version = "PY2AND3",
tags = [
- "manual",
"no_oss",
"no_pip",
"notap",
@@ -180,6 +195,31 @@ py_test(
)
py_test(
+ name = "lmdb_dataset_op_test",
+ size = "medium",
+ srcs = ["lmdb_dataset_op_test.py"],
+ data = ["//tensorflow/core:lmdb_testdata"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/contrib/data/python/ops:readers",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:session",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "map_dataset_op_test",
size = "medium",
srcs = ["map_dataset_op_test.py"],
@@ -206,6 +246,25 @@ py_test(
)
py_test(
+ name = "filter_dataset_op_test",
+ size = "medium",
+ srcs = ["filter_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "map_defun_op_test",
size = "small",
srcs = ["map_defun_op_test.py"],
@@ -220,26 +279,30 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:function",
+ "//tensorflow/python:functional_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:session",
],
)
py_test(
- name = "optimize_dataset_op_test",
+ name = "parsing_ops_test",
size = "small",
- srcs = ["optimize_dataset_op_test.py"],
+ srcs = ["parsing_ops_test.py"],
srcs_version = "PY2AND3",
deps = [
- ":stats_dataset_test_base",
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/contrib/data/python/ops:stats_ops",
+ "//tensorflow/contrib/data/python/ops:parsing_ops",
+ "//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:math_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
- "@absl_py//absl/testing:parameterized",
+ "//tensorflow/python/data/util:nest",
+ "//third_party/py/numpy",
],
)
@@ -329,6 +392,7 @@ py_test(
"//tensorflow/python:parsing_ops",
"//tensorflow/python:string_ops",
"//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:nest",
"//third_party/py/numpy",
],
)
@@ -549,3 +613,13 @@ py_test(
"//tensorflow/python/data/ops:readers",
],
)
+
+py_library(
+ name = "test_utils",
+ srcs = ["test_utils.py"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 42adfd17f0..8e368bf2bc 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -57,7 +57,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for start in range(0, len(components), 4):
@@ -85,7 +85,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for start in range(0, len(components), 4):
@@ -123,7 +123,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initialize with an input tensor of incompatible rank.
sess.run(init_op, feed_dict={input_tensor: [[1]]})
with self.assertRaisesRegexp(errors.InvalidArgumentError,
@@ -148,7 +148,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual((i,) * 3, sess.run(op))
@@ -168,7 +168,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
@@ -187,7 +187,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
st_row = sess.run(next_element)
self.assertEqual([i], st_row.indices)
@@ -208,7 +208,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
dense_elem, st_row = sess.run(next_element)
self.assertEqual(i, dense_elem)
@@ -230,7 +230,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual(((i,),) * 3, sess.run(op))
@@ -250,7 +250,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
sess.run(op))
@@ -266,7 +266,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@@ -284,7 +284,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Mismatch in the 0th dimension.
sess.run(
iterator.initializer,
@@ -319,7 +319,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for test_batch_size in [1, 3, 7, 10]:
sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
num_batches = 7 // test_batch_size
@@ -343,7 +343,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(2):
actual = sess.run(get_next)
@@ -374,7 +374,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for test_batch_size in [1, 3, 7, 10]:
sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
num_batches = 7 // test_batch_size
@@ -428,10 +428,10 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list())
@parameterized.named_parameters(
- ("default", None, None),
- ("sequential_calls", 1, None),
- ("parallel_calls", 2, None),
- ("parallel_batches", None, 10),
+ ("Default", None, None),
+ ("SequentialCalls", 1, None),
+ ("ParallelCalls", 2, None),
+ ("ParallelBatches", None, 10),
)
def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
"""Test a dataset that maps a TF function across its input elements."""
@@ -461,7 +461,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([[None] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Batch of a finite input, where the batch_size divides the
# total number of elements.
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
@@ -505,8 +505,8 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
sess.run(init_op, feed_dict={count: 14, batch_size: 0})
@parameterized.named_parameters(
- ("even", False),
- ("uneven", True),
+ ("Even", False),
+ ("Uneven", True),
)
def testMapAndBatchPartialBatch(self, drop_remainder):
iterator = (
@@ -520,7 +520,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
else:
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
if not drop_remainder:
@@ -535,7 +535,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
.make_one_shot_iterator())
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
self.assertAllEqual([[64], [81]], sess.run(next_element))
@@ -549,7 +549,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
elements = []
for _ in range(100):
elements.append(iterator.get_next())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(5):
got = sess.run(elements)
got.sort(key=lambda x: x[0])
@@ -569,7 +569,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
elements = []
for _ in range(100):
elements.append(iterator.get_next())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(4):
got = sess.run(elements)
got.sort(key=lambda x: x[0])
@@ -591,7 +591,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(2):
actual = sess.run(get_next)
@@ -614,7 +614,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
.make_initializable_iterator())
init_op = iterator.initializer
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
sess.run(init_op, feed_dict={batch_size: 14})
@@ -635,7 +635,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"number of elements does not match"):
@@ -659,11 +659,18 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(3):
sess.run(get_next)
- @parameterized.parameters(0, 5, 10, 90, 95, 99)
+ @parameterized.named_parameters(
+ ("1", 0),
+ ("2", 5),
+ ("3", 10),
+ ("4", 90),
+ ("5", 95),
+ ("6", 99),
+ )
def testMapAndBatchOutOfRangeError(self, threshold):
def raising_py_fn(i):
@@ -679,7 +686,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
batch_size=10)).make_one_shot_iterator())
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(threshold // 10):
self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
if threshold % 10 != 0:
@@ -689,18 +696,18 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- @parameterized.parameters(
- (False, dtypes.bool),
- (-42, dtypes.int8),
- (-42, dtypes.int16),
- (-42, dtypes.int32),
- (-42, dtypes.int64),
- (42, dtypes.uint8),
- (42, dtypes.uint16),
- (42.0, dtypes.float16),
- (42.0, dtypes.float32),
- (42.0, dtypes.float64),
- (b"hello", dtypes.string),
+ @parameterized.named_parameters(
+ ("1", False, dtypes.bool),
+ ("2", -42, dtypes.int8),
+ ("3", -42, dtypes.int16),
+ ("4", -42, dtypes.int32),
+ ("5", -42, dtypes.int64),
+ ("6", 42, dtypes.uint8),
+ ("7", 42, dtypes.uint16),
+ ("8", 42.0, dtypes.float16),
+ ("9", 42.0, dtypes.float32),
+ ("10", 42.0, dtypes.float64),
+ ("11", b"hello", dtypes.string),
)
def testMapAndBatchTypes(self, element, dtype):
def gen():
@@ -711,7 +718,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(10):
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
@@ -720,6 +727,42 @@ class RestructuredDatasetTest(test.TestCase):
def test_assert_element_shape(self):
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(5).map(create_dataset)
+ expected_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 4)))
+ self.assertEqual(expected_shapes, dataset.output_shapes)
+
+ result = dataset.apply(batching.assert_element_shape(expected_shapes))
+ self.assertEqual(expected_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.test_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(3).map(create_dataset)
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 10)))
+ with self.assertRaises(ValueError):
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+
+ def test_assert_element_shape_on_unknown_shape_dataset(self):
+
def create_unknown_shape_dataset(x):
return script_ops.py_func(
lambda _: ( # pylint: disable=g-long-lambda
@@ -741,6 +784,59 @@ class RestructuredDatasetTest(test.TestCase):
iterator = result.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 10)))
+ iterator = (
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.test_session() as sess:
+ sess.run(init_op)
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+
+ def test_assert_partial_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(5).map(create_dataset)
+ partial_expected_shape = (tensor_shape.TensorShape(None), # Unknown shape
+ tensor_shape.TensorShape((None, 4))) # Partial shape
+ result = dataset.apply(
+ batching.assert_element_shape(partial_expected_shape))
+ # Partial shapes are merged with actual shapes:
+ actual_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 4)))
+ self.assertEqual(actual_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
for _ in range(5):
@@ -748,7 +844,7 @@ class RestructuredDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def test_assert_wrong_element_shape(self):
+ def test_assert_wrong_partial_element_shape(self):
def create_dataset(_):
return (array_ops.ones(2, dtype=dtypes.float32),
@@ -756,11 +852,41 @@ class RestructuredDatasetTest(test.TestCase):
dataset = dataset_ops.Dataset.range(3).map(create_dataset)
wrong_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 10)))
+ tensor_shape.TensorShape((None, 10)))
with self.assertRaises(ValueError):
dataset.apply(batching.assert_element_shape(wrong_shapes))
- def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):
+ def test_assert_partial_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ expected_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((None, 4)))
+ result = dataset.apply(batching.assert_element_shape(expected_shapes))
+ self.assertEqual(expected_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.test_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_partial_element_shape_on_unknown_shape_dataset(self):
def create_unknown_shape_dataset(x):
return script_ops.py_func(
@@ -776,13 +902,13 @@ class RestructuredDatasetTest(test.TestCase):
self.assertEqual(unknown_shapes, dataset.output_shapes)
wrong_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 10)))
+ tensor_shape.TensorShape((None, 10)))
iterator = (
dataset.apply(batching.assert_element_shape(wrong_shapes))
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
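The partial-shape tests above rely on assert_element_shape merging a partially specified expected shape with the dataset's statically known shape. The merge semantics assumed here are those of TensorShape.merge_with; a sketch:

    from tensorflow.python.framework import tensor_shape

    actual = tensor_shape.TensorShape((3, 4))
    # Unknown dims are filled in from the actual shape...
    print(tensor_shape.TensorShape((None, 4)).merge_with(actual))  # (3, 4)
    # ...while mismatched known dims raise ValueError.
    tensor_shape.TensorShape((None, 10)).merge_with(actual)  # ValueError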
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
index 2022c1f2bd..293be2bd06 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
@@ -40,7 +40,7 @@ class GroupByReducerTest(test.TestCase):
def checkResults(self, dataset, shapes, values):
self.assertEqual(shapes, dataset.output_shapes)
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for expected in values:
got = sess.run(get_next)
self.assertEqual(got, expected)
@@ -129,7 +129,7 @@ class GroupByReducerTest(test.TestCase):
self.assertIs(None, dataset.output_shapes[1].ndims)
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x, y = sess.run(get_next)
self.assertAllEqual([0] * (2**i), x)
self.assertAllEqual(np.array(1, ndmin=i), y)
@@ -192,7 +192,7 @@ class GroupByReducerTest(test.TestCase):
(dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply(
grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x, y = sess.run(get_next)
self.assertAllEqual(x, np.asarray([x for x in range(10)]))
self.assertEqual(y, 45)
@@ -210,7 +210,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
counts = []
with self.assertRaises(errors.OutOfRangeError):
@@ -237,7 +237,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
# The input is infinite, so this test demonstrates that:
# 1. We produce output without having to consume the entire input,
@@ -258,7 +258,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
@@ -275,7 +275,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -301,7 +301,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
@@ -329,7 +329,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
counts = []
with self.assertRaises(errors.OutOfRangeError):
@@ -376,7 +376,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
which_bucket, bucketed_values = sess.run(get_next)
@@ -411,7 +411,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
# Get two minibatches (one containing even values, one containing odds)
@@ -482,7 +482,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
# Get two minibatches ([0, 2, ...] and [64, 66, ...])
@@ -515,7 +515,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.OutOfRangeError):
batches = 0
@@ -556,7 +556,7 @@ class BucketBySequenceLength(test.TestCase):
element_len, boundaries, batch_sizes))
batch, = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batches = []
for _ in range(4):
batches.append(sess.run(batch))
@@ -600,7 +600,7 @@ class BucketBySequenceLength(test.TestCase):
pad_to_bucket_boundary=True))
batch, = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batches = []
for _ in range(3):
batches.append(sess.run(batch))
@@ -637,7 +637,7 @@ class BucketBySequenceLength(test.TestCase):
pad_to_bucket_boundary=True))
batch, = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batches = []
for _ in range(5):
batches.append(sess.run(batch))
diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
index 2a0e64caeb..63bffd023f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
@@ -51,7 +51,7 @@ class CsvDatasetOpTest(test.TestCase):
assert ds1.output_classes == ds2.output_classes
next1 = ds1.make_one_shot_iterator().get_next()
next2 = ds2.make_one_shot_iterator().get_next()
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
# Run through datasets and check that outputs match, or errors match.
while True:
try:
@@ -138,7 +138,7 @@ class CsvDatasetOpTest(test.TestCase):
filenames = self._setup_files(inputs, linebreak, compression_type)
kwargs['compression_type'] = compression_type
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
dataset = readers.CsvDataset(filenames, **kwargs)
self._verify_output_or_err(sess, dataset, expected_output,
expected_err_re)
@@ -192,7 +192,7 @@ class CsvDatasetOpTest(test.TestCase):
inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']]
filenames = self._setup_files(inputs)
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
dataset = dataset.apply(error_ops.ignore_errors())
self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']])
@@ -202,7 +202,7 @@ class CsvDatasetOpTest(test.TestCase):
inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']]
filenames = self._setup_files(inputs)
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
dataset = dataset.apply(error_ops.ignore_errors())
self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']])
@@ -378,7 +378,7 @@ class CsvDatasetOpTest(test.TestCase):
file_path, batch_size=1, shuffle=False, num_epochs=1)
next_batch = ds.make_one_shot_iterator().get_next()
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
result = list(sess.run(next_batch).values())
self.assertEqual(result, sorted(result))
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
index 9b1857de1a..eb110324d1 100644
--- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
@@ -38,7 +38,7 @@ class DirectedInterleaveDatasetTest(test.TestCase):
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for _ in range(100):
for i in range(10):
@@ -67,7 +67,7 @@ class DirectedInterleaveDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
freqs = np.zeros([num_datasets])
for _ in range(num_samples):
freqs[sess.run(next_element)] += 1
@@ -84,7 +84,7 @@ class DirectedInterleaveDatasetTest(test.TestCase):
# Use chi-squared test to assert that the observed distribution matches the
# expected distribution. Based on the implementation in
# "tensorflow/python/kernel_tests/multinomial_op_test.py".
- for probs in [[.85, .05, .1], rand_probs]:
+ for probs in [[.85, .05, .1], rand_probs, [1.]]:
probs = np.asarray(probs)
classes = len(probs)
freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples)
@@ -104,7 +104,7 @@ class DirectedInterleaveDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in choice_array:
self.assertEqual(words[i], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py
new file mode 100644
index 0000000000..6d01bf585c
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py
@@ -0,0 +1,76 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmarks FilterDataset input pipeline op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.client import session
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class FilterBenchmark(test.Benchmark):
+
+  # This benchmark compares the performance of a pipeline with multiple
+  # chained filters, with and without filter fusion.
+ def benchmarkFilters(self):
+ chain_lengths = [0, 1, 2, 5, 10, 20, 50]
+ for chain_length in chain_lengths:
+ self._benchmarkFilters(chain_length, False)
+ self._benchmarkFilters(chain_length, True)
+
+ def _benchmarkFilters(self, chain_length, optimize_dataset):
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors(5).repeat(None)
+ for _ in range(chain_length):
+ dataset = dataset.filter(lambda x: math_ops.greater_equal(x - 5, 0))
+ if optimize_dataset:
+ dataset = dataset.apply(optimization.optimize(["filter_fusion"]))
+
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for _ in range(10):
+ sess.run(next_element.op)
+ deltas = []
+ for _ in range(100):
+ start = time.time()
+ for _ in range(100):
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ median_wall_time = np.median(deltas) / 100
+ opt_mark = "opt" if optimize_dataset else "no-opt"
+ print("Filter dataset {} chain length: {} Median wall time: {}".format(
+ opt_mark, chain_length, median_wall_time))
+ self.report_benchmark(
+ iters=1000,
+ wall_time=median_wall_time,
+ name="benchmark_filter_dataset_chain_latency_{}_{}".format(
+ opt_mark, chain_length))
+
+
+if __name__ == "__main__":
+ test.main()
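For context on what the benchmark above measures: the filter_fusion optimization is expected to collapse chained Filter nodes into one whose predicate is the conjunction of the originals. A hand-written sketch of the fused equivalent (illustration only; the optimizer performs this rewrite on the graph):

    from tensorflow.python.data.ops import dataset_ops
    from tensorflow.python.ops import math_ops

    # Two chained filters...
    chained = (dataset_ops.Dataset.from_tensors(5).repeat(None)
               .filter(lambda x: math_ops.greater_equal(x - 5, 0))
               .filter(lambda x: math_ops.less(x, 100)))

    # ...behave like a single filter over the combined predicate.
    fused = dataset_ops.Dataset.from_tensors(5).repeat(None).filter(
        lambda x: math_ops.logical_and(
            math_ops.greater_equal(x - 5, 0), math_ops.less(x, 100)))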
diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
index e6883d53e0..f3968cdc15 100644
--- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
@@ -53,7 +53,7 @@ class GetSingleElementTest(test.TestCase, parameterized.TestCase):
lambda x: (x * x, make_sparse(x))).take(take_t)
element = get_single_element.get_single_element(dataset)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if error is None:
dense_val, sparse_val = sess.run(
element, feed_dict={
@@ -90,7 +90,7 @@ class GetSingleElementTest(test.TestCase, parameterized.TestCase):
dataset = dataset_ops.Dataset.range(stop_t)
element = get_single_element.reduce_dataset(dataset, sum_reducer)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
value = sess.run(element, feed_dict={stop_t: stop})
self.assertEqual(stop * (stop - 1) / 2, value)
diff --git a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
new file mode 100644
index 0000000000..9c508d686d
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
@@ -0,0 +1,78 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for experimental indexed dataset ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+
+from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
+from tensorflow.contrib.data.python.ops import gen_dataset_ops
+from tensorflow.contrib.data.python.ops import indexed_dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class IndexedDatasetOpsTest(test.TestCase):
+
+ def testLowLevelIndexedDatasetOps(self):
+ identity = gen_dataset_ops.identity_indexed_dataset(
+ ops.convert_to_tensor(16, dtype=dtypes.uint64))
+ handle = gen_dataset_ops.materialized_index_dataset_handle(
+ container="",
+ shared_name="",
+ output_types=[dtypes.uint64],
+ output_shapes=[[]])
+ materialize = gen_dataset_ops.indexed_dataset_materialize(identity, handle)
+ index = array_ops.placeholder(dtypes.uint64)
+ get_op = gen_dataset_ops.indexed_dataset_get(
+ handle, index, output_types=[dtypes.uint64], output_shapes=[[]])
+
+ with self.cached_session() as sess:
+ sess.run(materialize)
+ self.assertEqual([3], sess.run(get_op, feed_dict={index: 3}))
+
+ def testIdentityIndexedDataset(self):
+ ds = indexed_dataset_ops.IdentityIndexedDataset(16)
+ materialized = ds.materialize()
+ with self.cached_session() as sess:
+ sess.run(materialized.initializer)
+ placeholder = array_ops.placeholder(dtypes.uint64, shape=[])
+ for i in range(16):
+ output = sess.run(
+ materialized.get(placeholder), feed_dict={placeholder: i})
+ self.assertEqual([i], output)
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(materialized.get(placeholder), feed_dict={placeholder: 16})
+
+ @unittest.skip("Requisite functionality currently unimplemented.")
+ def testIdentityIndexedDatasetIterator(self):
+ ds = indexed_dataset_ops.IdentityIndexedDataset(16)
+ itr = ds.make_initializable_iterator()
+ n = itr.get_next()
+ with self.cached_session() as sess:
+ sess.run(itr.initializer)
+ for i in range(16):
+ output = sess.run(n)
+ self.assertEqual(i, output)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(n)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
index 44c3325a3d..b9e74dfddb 100644
--- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
@@ -177,7 +177,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0):
# cycle_length=1,block_length=1 acts like `Dataset.interleave()` and
# `Dataset.flat_map()` and is single-threaded. No synchronization required.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -212,7 +212,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def testSingleThreadedRagged(self):
# Tests a sequence with wildly different elements per iterator.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -242,7 +242,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testTwoThreadsNoContention(self, sloppy=False):
# num_threads > 1.
# Explicit coordination should result in `Dataset.interleave()` behavior
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -286,7 +286,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
Args:
sloppy: Whether to be sloppy or not.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -328,7 +328,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testTwoThreadsNoContentionBlockLength(self, sloppy=False):
# num_threads > 1.
# Explicit coordination should result in `Dataset.interleave()` behavior
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -373,7 +373,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
Args:
sloppy: Whether to be sloppy or not.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -413,7 +413,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True)
def _testEmptyInput(self, sloppy=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Empty input.
self._clear_coordination_events()
sess.run(
@@ -437,7 +437,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False):
# Non-empty input leading to empty output.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -461,7 +461,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1):
race_indices = {2, 8, 14} # Sequence points when sloppy mode has race conds
# Mixture of non-empty and empty interleaved datasets.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -500,7 +500,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def testDelayedOutputSloppy(self):
# Explicitly control the sequence of events to ensure we correctly avoid
# head-of-line blocking.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -525,7 +525,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
sess.run(self.next_element)
def testBlockLengthWithContentionSloppy(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -560,7 +560,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testEarlyExit(self, sloppy=False):
# Exiting without consuming all input should not block
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -604,7 +604,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy))
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output_values = []
for _ in range(30):
output_values.append(sess.run(iterator.get_next()))
@@ -635,7 +635,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
for j in range(2):
@@ -645,7 +645,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
sess.run(get_next)
def testErrorsInOutputFn(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -704,7 +704,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.init_op = self.iterator.initializer
self.next_element = self.iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.init_op,
feed_dict={
@@ -753,7 +753,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.init_op = self.iterator.initializer
self.next_element = self.iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.init_op,
feed_dict={
@@ -777,6 +777,34 @@ class ParallelInterleaveDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
+ def testShutdownRace(self):
+ dataset = dataset_ops.Dataset.range(20)
+ map_fn = lambda x: dataset_ops.Dataset.range(20 * x, 20 * (x + 1))
+ dataset = dataset.apply(
+ interleave_ops.parallel_interleave(
+ map_fn,
+ cycle_length=3,
+ sloppy=False,
+ buffer_output_elements=1,
+ prefetch_input_elements=0))
+ dataset = dataset.batch(32)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ results = []
+ with self.cached_session() as sess:
+ for _ in range(2):
+ elements = []
+ sess.run(iterator.initializer)
+ try:
+ while True:
+ elements.extend(sess.run(next_element))
+ except errors.OutOfRangeError:
+ pass
+ results.append(elements)
+
+ self.assertAllEqual(results[0], results[1])
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
index 77148aceec..704c0d1eb2 100644
--- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
@@ -60,7 +60,7 @@ class CheckpointInputPipelineHookTest(test.TestCase):
meta_filename = ckpt_path + '.meta'
saver_lib.import_meta_graph(meta_filename)
saver = saver_lib.Saver()
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
saver.restore(sess, ckpt_path)
return sess.run(ops.get_collection('my_vars'))
diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
new file mode 100644
index 0000000000..1cc5ddc9a2
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
@@ -0,0 +1,66 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for LMDBDatasetOp."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+
+from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+prefix_path = "tensorflow/core/lib"
+
+
+class LMDBDatasetTest(test.TestCase):
+
+ def setUp(self):
+ super(LMDBDatasetTest, self).setUp()
+    # Copy the database out because the path must be writable to take locks.
+ path = os.path.join(prefix_path, "lmdb", "testdata", "data.mdb")
+ self.db_path = os.path.join(self.get_temp_dir(), "data.mdb")
+ shutil.copy(path, self.db_path)
+
+ def testReadFromFile(self):
+ filename = self.db_path
+
+ filenames = constant_op.constant([filename], dtypes.string)
+ num_repeats = 2
+
+ dataset = readers.LMDBDataset(filenames).repeat(num_repeats)
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for _ in range(num_repeats): # Dataset is repeated.
+ for i in range(10): # 10 records.
+ k = compat.as_bytes(str(i))
+ v = compat.as_bytes(str(chr(ord("a") + i)))
+ self.assertEqual((k, v), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
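The test above assumes a fixture of ten records with keys "0" through "9" and values "a" through "j". A sketch of how an equivalent database could be generated with the third-party py-lmdb package (an assumption for illustration; the checked-in lmdb testdata is the actual fixture):

    import lmdb  # assumption: standalone py-lmdb package, not part of this change

    env = lmdb.open("/tmp/lmdb_testdata")  # hypothetical output path
    with env.begin(write=True) as txn:
        for i in range(10):
            txn.put(str(i).encode("ascii"), chr(ord("a") + i).encode("ascii"))
    env.close()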
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
index 009e21a34c..e8519381d6 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
@@ -54,7 +54,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for x in [1., 2., 3., 5.]:
self.assertEqual(x, sess.run(get_next))
@@ -72,7 +72,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for x in [1., 2., 3., 5.]:
self.assertEqual(x, sess.run(get_next))
@@ -99,7 +99,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# All of the files are present.
sess.run(init_op)
for filename in filenames:
@@ -139,7 +139,7 @@ class MapDatasetTest(test.TestCase):
with ops.Graph().as_default() as g:
captured_init_op, init_op, get_next = _build_graph()
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(captured_init_op)
sess.run(init_op)
for i in range(10):
@@ -209,7 +209,7 @@ class MapDatasetBenchmark(test.Benchmark):
end = time.time()
chained_deltas.append(end - start)
- fused_dataset = dataset = dataset.apply(
+ fused_dataset = dataset.apply(
batching.map_and_batch(
math_ops.matmul,
num_parallel_calls=num_calls,
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
index a711325dae..61567bc8d7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
@@ -17,7 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import time
+
from tensorflow.contrib.data.python.ops import map_defun
+from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -25,53 +28,63 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-
class MapDefunTest(test.TestCase):
- def testMapDefun_Simple(self):
+ def testMapDefunSimple(self):
@function.Defun(dtypes.int32)
def simple_fn(x):
return x * 2 + 3
- with self.test_session():
- nums = [[1, 2], [3, 4], [5, 6]]
- elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
- r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(2,)])[0]
- expected = elems * 2 + 3
- self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(2,)])[0]
+ expected = elems * 2 + 3
+ self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
- def testMapDefun_MismatchedTypes(self):
+ def testMapDefunMismatchedTypes(self):
@function.Defun(dtypes.int32)
def fn(x):
return math_ops.cast(x, dtypes.float64)
- with self.test_session():
- nums = [1, 2, 3, 4, 5, 6]
- elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
- r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
- with self.assertRaises(errors.InvalidArgumentError):
- self.evaluate(r)
+ nums = [1, 2, 3, 4, 5, 6]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
+ with self.assertRaises(errors.InvalidArgumentError):
+ self.evaluate(r)
+
+ def testMapDefunReduceDim(self):
+    # Tests a case where the output has a different rank from the input.
+
+ @function.Defun(dtypes.int32)
+ def fn(x):
+ return array_ops.gather(x, 0)
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
+ expected = constant_op.constant([1, 3, 5])
+ self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
- def testMapDefun_MultipleOutputs(self):
+ def testMapDefunMultipleOutputs(self):
@function.Defun(dtypes.int32)
def fn(x):
return (x, math_ops.cast(x * 2 + 3, dtypes.float64))
- with self.test_session():
- nums = [[1, 2], [3, 4], [5, 6]]
- elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
- r = map_defun.map_defun(fn, [elems], [dtypes.int32, dtypes.float64],
- [(2,), (2,)])
- expected = [elems, elems * 2 + 3]
- self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(fn, [elems], [dtypes.int32, dtypes.float64], [(2,),
+ (2,)])
+ expected = [elems, elems * 2 + 3]
+ self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
- def testMapDefun_ShapeInference(self):
+ def testMapDefunShapeInference(self):
@function.Defun(dtypes.int32)
def fn(x):
@@ -82,7 +95,7 @@ class MapDefunTest(test.TestCase):
result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])[0]
self.assertEqual(result.get_shape(), (3, 2))
- def testMapDefun_PartialShapeInference(self):
+ def testMapDefunPartialShapeInference(self):
@function.Defun(dtypes.int32)
def fn(x):
@@ -92,7 +105,7 @@ class MapDefunTest(test.TestCase):
result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])
self.assertEqual(result[0].get_shape().as_list(), [None, 2])
- def testMapDefun_RaisesErrorOnRuntimeShapeMismatch(self):
+ def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self):
@function.Defun(dtypes.int32, dtypes.int32)
def fn(x, y):
@@ -108,7 +121,7 @@ class MapDefunTest(test.TestCase):
"All inputs must have the same dimension 0."):
sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]})
- def testMapDefun_RaisesDefunError(self):
+ def testMapDefunRaisesDefunError(self):
@function.Defun(dtypes.int32)
def fn(x):
@@ -117,10 +130,124 @@ class MapDefunTest(test.TestCase):
elems = constant_op.constant([0, 0, 0, 37, 0])
result = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])
- with self.test_session():
+ with self.assertRaises(errors.InvalidArgumentError):
+ self.evaluate(result)
+
+ def testMapDefunCancelledCorrectly(self):
+
+ @function.Defun(dtypes.int64)
+ def defun(x):
+    # x has leading dimension 5, so this will raise an error.
+ return array_ops.gather(x, 10)
+
+ c = array_ops.tile(
+ array_ops.expand_dims(
+ constant_op.constant([1, 2, 3, 4, 5], dtype=dtypes.int64), 0),
+ [100, 1])
+ map_defun_op = map_defun.map_defun(defun, [c], [dtypes.int64], [()])[0]
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ r"indices = 10 is not in \[0, 5\)"):
+ self.evaluate(map_defun_op)
+
+ def testMapDefunWithUnspecifiedOutputShape(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ res = x * 2 + 3
+ return (res, res + 1, res + 2)
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems],
+ [dtypes.int32, dtypes.int32, dtypes.int32],
+ [None, (None,), (2,)])
+ expected = elems * 2 + 3
+ self.assertAllEqual(self.evaluate(r[0]), self.evaluate(expected))
+ self.assertAllEqual(self.evaluate(r[1]), self.evaluate(expected + 1))
+ self.assertAllEqual(self.evaluate(r[2]), self.evaluate(expected + 2))
+
+ def testMapDefunWithDifferentOutputShapeEachRun(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2 + 3
+
+ elems = array_ops.placeholder(dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [None])[0]
+ with session.Session() as sess:
+ self.assertAllEqual(sess.run(r, feed_dict={elems: [0]}), [3])
+ self.assertAllEqual(
+ sess.run(r, feed_dict={elems: [[0], [1]]}), [[3], [5]])
+
+ def testMapDefunWithWrongOutputShape(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2 + 3
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(1,)])[0]
+ with self.assertRaises(errors.InvalidArgumentError):
+ self.evaluate(r)
+
+ def testMapDefunWithInvalidInput(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2
+
+ c = constant_op.constant(2)
+ with self.assertRaises(ValueError):
+ # Fails at graph construction time for inputs with known shapes.
+ r = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [None])[0]
+ p = array_ops.placeholder(dtypes.int32)
+ r = map_defun.map_defun(simple_fn, [p], [dtypes.int32], [None])[0]
+ with session.Session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
- self.evaluate(result)
+ sess.run(r, feed_dict={p: 0})
+
+
+class MapDefunBenchmark(test.Benchmark):
+
+ def _run(self, op, name=None, num_iters=3000):
+ with session.Session() as sess:
+ # Warm up the session
+ for _ in range(5):
+ sess.run(op)
+ start = time.time()
+ for _ in range(num_iters):
+ sess.run(op)
+ end = time.time()
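+      # Mean wall time per iteration, converted from seconds to microseconds.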
+ mean_us = (end - start) * 1e6 / num_iters
+ self.report_benchmark(
+ name=name,
+ iters=num_iters,
+ wall_time=mean_us,
+ extras={"examples_per_sec": num_iters / (end - start)})
+
+ def benchmarkDefunVsMapFn(self):
+ """Benchmarks to compare the performance of MapDefun vs tf.map_fn."""
+ @function.Defun(dtypes.int32)
+ def defun(x):
+ return array_ops.identity(x)
+
+ def map_fn(x):
+ return array_ops.identity(x)
+
+ base = math_ops.range(100)
+ for input_size in [10, 100, 1000, 10000]:
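+      # Scale the iteration count inversely with input size so each benchmark
+      # does a comparable amount of total work.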
+ num_iters = 100000 // input_size
+ map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()])
+ map_fn_op = functional_ops.map_fn(map_fn, base)
+
+ self._run(
+ map_defun_op,
+ "benchmarkMapDefun_size_%d" % input_size,
+ num_iters=num_iters)
+ self._run(
+ map_fn_op, "benchmarkMapFn_size_%d" % input_size, num_iters=num_iters)
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
new file mode 100644
index 0000000000..459bdf66f3
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
@@ -0,0 +1,88 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_test(
+ name = "assert_next_dataset_op_test",
+ size = "medium",
+ srcs = ["assert_next_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "latency_all_edges_test",
+ size = "small",
+ srcs = ["latency_all_edges_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/contrib/data/python/ops:stats_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "map_vectorization_test",
+ size = "small",
+ srcs = ["map_vectorization_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/kernel_tests:test_utils",
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "map_and_filter_fusion_test",
+ size = "medium",
+ srcs = ["map_and_filter_fusion_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "optimize_dataset_op_test",
+ size = "small",
+ srcs = ["optimize_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
new file mode 100644
index 0000000000..bd7b50b902
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
@@ -0,0 +1,64 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class AssertNextDatasetTest(test.TestCase):
+
+ def testAssertNext(self):
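+    # assert_next verifies the names of the transformations that follow it.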
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Map"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ self.assertEqual(0, sess.run(get_next))
+
+ def testAssertNextInvalid(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Whoops"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Asserted Whoops transformation at offset 0 but encountered "
+ "Map transformation instead."):
+ sess.run(get_next)
+
+ def testAssertNextShort(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Map", "Whoops"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Asserted next 2 transformations but encountered only 1."):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py
new file mode 100644
index 0000000000..db380c02a9
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py
@@ -0,0 +1,58 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the LatencyAllEdges optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.contrib.data.python.ops import stats_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
+
+ def testLatencyStatsOptimization(self):
+
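+    # The optimization should interleave a LatencyStats node after each
+    # transformation, as asserted below.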
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.from_tensors(1).apply(
+ optimization.assert_next(
+ ["LatencyStats", "Map", "LatencyStats", "Prefetch",
+ "LatencyStats"])).map(lambda x: x * x).prefetch(1).apply(
+ optimization.optimize(["latency_all_edges"])).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ get_next = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertEqual(1 * 1, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasCount(summary_str,
+ "record_latency_TensorDataset/_1", 1)
+ self._assertSummaryHasCount(summary_str, "record_latency_MapDataset/_4",
+ 1)
+ self._assertSummaryHasCount(summary_str,
+ "record_latency_PrefetchDataset/_6", 1)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
index ae147b4fa7..dde115925e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
@@ -12,16 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for the MapAndFilterFusion optimization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
-from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -30,91 +28,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
-
- def testAssertSuffix(self):
- dataset = dataset_ops.Dataset.from_tensors(0).apply(
- optimization.assert_next(["Map"])).map(lambda x: x)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- self.assertEqual(0, sess.run(get_next))
-
- def testAssertSuffixInvalid(self):
- dataset = dataset_ops.Dataset.from_tensors(0).apply(
- optimization.assert_next(["Whoops"])).map(lambda x: x)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "Asserted Whoops transformation at offset 0 but encountered "
- "Map transformation instead."):
- sess.run(get_next)
-
- def testAssertSuffixShort(self):
- dataset = dataset_ops.Dataset.from_tensors(0).apply(
- optimization.assert_next(["Map", "Whoops"])).map(lambda x: x)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "Asserted next 2 transformations but encountered only 1."):
- sess.run(get_next)
-
- def testDefaultOptimizations(self):
- dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next(
- ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
- optimization.optimize())
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testEmptyOptimizations(self):
- dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next(
- ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
- optimization.optimize([]))
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testOptimization(self):
- dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next(
- ["MapAndBatch"])).map(lambda x: x * x).batch(10).apply(
- optimization.optimize(["map_and_batch_fusion"]))
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testFunctionLibraryDefinitionModification(self):
- dataset = dataset_ops.Dataset.from_tensors(0).map(lambda x: x).apply(
- optimization.optimize(["_test_only_function_rename"]))
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- with self.assertRaisesRegexp(errors.NotFoundError,
- "Function .* is not defined."):
- sess.run(get_next)
+class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
@staticmethod
def map_functions():
@@ -130,22 +44,22 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
for i, fun1 in enumerate(functions):
for j, fun2 in enumerate(functions):
tests.append((
- "test_{}_{}".format(i, j),
+ "Test{}{}".format(i, j),
[fun1, fun2],
))
for k, fun3 in enumerate(functions):
tests.append((
- "test_{}_{}_{}".format(i, j, k),
+ "Test{}{}{}".format(i, j, k),
[fun1, fun2, fun3],
))
swap = lambda x, n: (n, x)
tests.append((
- "swap1",
+ "Swap1",
[lambda x: (x, 42), swap],
))
tests.append((
- "swap2",
+ "Swap2",
[lambda x: (x, 42), swap, swap],
))
return tuple(tests)
@@ -160,7 +74,7 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
dataset = dataset.prefetch(0).apply(optimization.optimize(["map_fusion"]))
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for x in range(5):
result = sess.run(get_next)
r = x
@@ -195,13 +109,13 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
for x, fun in enumerate(functions):
for y, predicate in enumerate(filters):
- tests.append(("mixed_{}_{}".format(x, y), fun, predicate))
+ tests.append(("Mixed{}{}".format(x, y), fun, predicate))
# Multi output
- tests.append(("multiOne", lambda x: (x, x),
+ tests.append(("Multi1", lambda x: (x, x),
lambda x, y: constant_op.constant(True)))
tests.append(
- ("multiTwo", lambda x: (x, 2),
+ ("Multi2", lambda x: (x, 2),
lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)))
return tuple(tests)
@@ -217,7 +131,7 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
def _testMapAndFilter(self, dataset, function, predicate):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for x in range(10):
r = function(x)
if isinstance(r, tuple):
@@ -247,34 +161,63 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
self._testMapAndFilter(dataset, function, predicate)
+ @staticmethod
+ def filter_functions():
+ take_all = lambda x: constant_op.constant(True)
+ is_zero = lambda x: math_ops.equal(x, 0)
+ greater = lambda x: math_ops.greater(x + 5, 0)
-class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
+ tests = []
+ filters = [take_all, is_zero, greater]
+ identity = lambda x: x
+ for x, predicate_1 in enumerate(filters):
+ for y, predicate_2 in enumerate(filters):
+ tests.append(("Mixed{}{}".format(x, y), identity,
+ [predicate_1, predicate_2]))
+ for z, predicate_3 in enumerate(filters):
+ tests.append(("Mixed{}{}{}".format(x, y, z), identity,
+ [predicate_1, predicate_2, predicate_3]))
+
+ take_all_multiple = lambda x, y: constant_op.constant(True)
+ # Multi output
+ tests.append(("Multi1", lambda x: (x, x),
+ [take_all_multiple, take_all_multiple]))
+ tests.append(("Multi2", lambda x: (x, 2), [
+ take_all_multiple,
+ lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
+ ]))
+ return tuple(tests)
- def testLatencyStatsOptimization(self):
+ @parameterized.named_parameters(*filter_functions.__func__())
+ def testFilterFusion(self, map_function, predicates):
+ dataset = dataset_ops.Dataset.range(5).apply(
+ optimization.assert_next(["Map", "Filter",
+ "Prefetch"])).map(map_function)
+ for predicate in predicates:
+ dataset = dataset.filter(predicate)
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = dataset_ops.Dataset.from_tensors(1).apply(
- optimization.assert_next(
- ["LatencyStats", "Map", "LatencyStats", "Prefetch",
- "LatencyStats"])).map(lambda x: x * x).prefetch(1).apply(
- optimization.optimize(["latency_all_edges"])).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_initializable_iterator()
+ dataset = dataset.prefetch(0).apply(
+ optimization.optimize(["filter_fusion"]))
+ iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
with self.test_session() as sess:
- sess.run(iterator.initializer)
- self.assertEqual(1 * 1, sess.run(get_next))
+ for x in range(5):
+ r = map_function(x)
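+        # Re-apply the predicates in Python to determine whether x should
+        # survive the fused filters.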
+ filtered = False
+ for predicate in predicates:
+ if isinstance(r, tuple):
+ b = predicate(*r) # Pass tuple as multiple arguments.
+ else:
+ b = predicate(r)
+ if not sess.run(b):
+ filtered = True
+ break
+
+ if not filtered:
+ result = sess.run(get_next)
+ self.assertAllEqual(r, result)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- summary_str = sess.run(summary_t)
- self._assertSummaryHasCount(summary_str,
- "record_latency_TensorDataset/_1", 1)
- self._assertSummaryHasCount(summary_str, "record_latency_MapDataset/_4",
- 1)
- self._assertSummaryHasCount(summary_str,
- "record_latency_PrefetchDataset/_6", 1)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
new file mode 100644
index 0000000000..e2c9bc82df
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
@@ -0,0 +1,219 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapVectorization optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.contrib.data.python.kernel_tests import test_utils
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.client import session
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
+
+ def _get_test_datasets(self,
+ base_dataset,
+ map_fn,
+ num_parallel_calls=None,
+ expect_optimized=True):
+ """Given base dataset and map fn, creates test datasets.
+
+    Returns a tuple of (unoptimized dataset, optimized dataset). The
+    unoptimized dataset asserts that Batch follows Map. The optimized dataset
+    asserts that Map follows Batch, and has the "map_vectorization"
+    optimization applied.
+
+ Args:
+ base_dataset: Input dataset to map->batch
+ map_fn: Map function to use
+ num_parallel_calls: (Optional.) num_parallel_calls argument for map
+ expect_optimized: (Optional.) Whether we expect the optimization to take
+ place, in which case we will assert that Batch is followed by Map,
+ otherwise Map followed by Batch. Defaults to True.
+
+ Returns:
+ Tuple of (unoptimized dataset, optimized dataset).
+ """
+ map_node_name = "Map" if num_parallel_calls is None else "ParallelMap"
+ batch_size = 100
+
+ def _make_dataset(node_names):
+ return base_dataset.apply(optimization.assert_next(node_names)).map(
+ map_fn, num_parallel_calls=num_parallel_calls).batch(batch_size)
+
+ unoptimized = _make_dataset([map_node_name, "Batch"])
+ optimized = _make_dataset(["Batch", map_node_name] if expect_optimized else
+ [map_node_name, "Batch"]).apply(
+ optimization.optimize(["map_vectorization"]))
+
+ return unoptimized, optimized
+
+ @parameterized.named_parameters(
+ ("Basic", lambda x: (x, x + 1), None),
+ ("Parallel", lambda x: (x, x + 1), 12),
+ ("Gather", lambda x: array_ops.gather(x, 0), 12),
+ )
+ def testOptimization(self, map_fn, num_parallel_calls):
+ base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
+ [3, 4]]).repeat(5)
+ unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn,
+ num_parallel_calls)
+ self._assert_datasets_equal(unoptimized, optimized)
+
+ def testOptimizationBadMapFn(self):
+    # Test a map function that raises an error
+ def map_fn(x):
+      # x has leading dimension 5, so gathering index 10 will raise an error
+ return array_ops.gather(x, 10)
+
+ base_dataset = dataset_ops.Dataset.range(5).repeat(5).batch(
+ 5, drop_remainder=True)
+ _, optimized = self._get_test_datasets(base_dataset, map_fn)
+ nxt = optimized.make_one_shot_iterator().get_next()
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ r"indices = 10 is not in \[0, 5\)"):
+ self.evaluate(nxt)
+
+ def testOptimizationWithCapturedInputs(self):
+ # Tests that vectorization works with captured inputs
+ def map_fn(x):
+ return x + y
+
+ y = constant_op.constant(1, shape=(2,))
+ base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
+ [3, 4]]).repeat(5)
+ # TODO(rachelim): when this optimization works, turn on expect_optimized
+ unoptimized, optimized = self._get_test_datasets(
+ base_dataset, map_fn, expect_optimized=False)
+ self._assert_datasets_equal(optimized, unoptimized)
+
+ def testOptimizationIgnoreStateful(self):
+
+ def map_fn(x):
+ with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
+ return array_ops.identity(x)
+
+ base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
+ [3, 4]]).repeat(5)
+ unoptimized, optimized = self._get_test_datasets(
+ base_dataset, map_fn, expect_optimized=False)
+ self._assert_datasets_raise_same_error(
+ unoptimized, optimized, errors.InvalidArgumentError,
+ [("OneShotIterator", "OneShotIterator_1", 1),
+ ("IteratorGetNext", "IteratorGetNext_1", 1)])
+
+ def testOptimizationIgnoreRagged(self):
+ # Make sure we ignore inputs that might not be uniformly sized
+ def map_fn(x):
+ return array_ops.gather(x, 0)
+
+ # output_shape = (?,)
+ base_dataset = dataset_ops.Dataset.range(20).batch(3, drop_remainder=False)
+ unoptimized, optimized = self._get_test_datasets(
+ base_dataset, map_fn, expect_optimized=False)
+ self._assert_datasets_equal(unoptimized, optimized)
+
+ def testOptimizationIgnoreRaggedMap(self):
+    # Don't optimize when the output shapes of the map fn are unknown.
+ def map_fn(x):
+ return array_ops.tile(x, x)
+
+ base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True)
+ unoptimized, optimized = self._get_test_datasets(
+ base_dataset, map_fn, expect_optimized=False)
+ self._assert_datasets_raise_same_error(
+ unoptimized, optimized, errors.InvalidArgumentError,
+ [("OneShotIterator", "OneShotIterator_1", 1),
+ ("IteratorGetNext", "IteratorGetNext_1", 1)])
+
+
+class MapVectorizationBenchmark(test.Benchmark):
+ # TODO(rachelim): Add a benchmark for more expensive transformations, such as
+ # vgg_preprocessing.
+
+ def _run(self, x, num_iters=100, name=None):
+ deltas = []
+ with session.Session() as sess:
+ for _ in range(5):
+ # Warm up session...
+ sess.run(x)
+ for _ in range(num_iters):
+ start = time.time()
+ sess.run(x)
+ end = time.time()
+ deltas.append(end - start)
+ median_time = np.median(deltas)
+ self.report_benchmark(iters=num_iters, wall_time=median_time, name=name)
+ return median_time
+
+ def benchmark_CheapFns(self):
+
+ input_sizes = [(10, 10, 3), (10, 100, 300)]
+ batch_size = 1000
+ for input_size in input_sizes:
+ input_dataset = dataset_ops.Dataset.from_tensor_slices(
+ (np.random.rand(*input_size), np.random.rand(*input_size))).repeat()
+ for map_fn, str_id in self._get_known_cheap_fns():
+ self._compare(input_dataset, map_fn, batch_size, input_size, str_id)
+
+ def _compare(self, input_dataset, map_fn, batch_size, input_size, str_id):
+ num_elems = np.prod(input_size)
+ name_template = "{}__batch_size_{}_input_size_{}_{}"
+ unoptimized = input_dataset.map(map_fn).batch(batch_size)
+ unoptimized_op = unoptimized.make_one_shot_iterator().get_next()
+
+ optimized = unoptimized.apply(optimization.optimize(["map_vectorization"]))
+ optimized_op = optimized.make_one_shot_iterator().get_next()
+
+ unoptimized_time = self._run(
+ unoptimized_op,
+ name=name_template.format(str_id, batch_size, num_elems, "unoptimized"))
+ optimized_time = self._run(
+ optimized_op,
+ name=name_template.format(str_id, batch_size, num_elems, "optimized"))
+
+ print("Batch size: {}\n"
+ "Input size: {}\n"
+ "Transformation: {}\n"
+ "Speedup: {}\n".format(batch_size, input_size, str_id,
+ (unoptimized_time / optimized_time)))
+
+ def _get_known_cheap_fns(self):
+ return [
+ (lambda *args: [array_ops.identity(x) for x in args], "identity"),
+ (lambda *args: [x + 1 for x in args], "add_const"),
+ (lambda *args: args[0], "select"),
+ (lambda *args: [math_ops.cast(x, dtypes.float64) for x in args],
+ "cast"),
+ ]
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
new file mode 100644
index 0000000000..909da5aee0
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
@@ -0,0 +1,108 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+
+
+class OptimizeDatasetTest(test.TestCase):
+
+ def testOptimizationDefault(self):
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(
+ ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
+ optimization.optimize())
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testOptimizationEmpty(self):
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(
+ ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
+ optimization.optimize([]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testOptimizationFusion(self):
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(
+ ["MapAndBatch"])).map(lambda x: x * x).batch(10).apply(
+ optimization.optimize(["map_and_batch_fusion"]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testOptimizationStatefulFunction(self):
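+    # The optimization must still yield a runnable pipeline when the map
+    # function is stateful.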
+ dataset = dataset_ops.Dataset.range(10).map(
+ lambda _: random_ops.random_uniform([])).batch(10).apply(
+ optimization.optimize(["map_and_batch_fusion"]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(get_next)
+
+ def testOptimizationLargeInputFromTensor(self):
+ input_t = array_ops.placeholder(dtypes.int32, (None, None, None))
+ dataset = dataset_ops.Dataset.from_tensors(input_t).apply(
+ optimization.optimize())
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
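+      # ~2.15 GB of int32 data, just over the 2 GB protocol buffer limit.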
+ sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
+ sess.run(get_next)
+
+ def testOptimizationLargeInputFromTensorSlices(self):
+ input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None))
+ dataset = dataset_ops.Dataset.from_tensor_slices(input_t).apply(
+ optimization.optimize())
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
new file mode 100644
index 0000000000..c4623bca73
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
@@ -0,0 +1,850 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.parsing_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import parsing_ops as contrib_parsing_ops
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+# Helpers for creating Example objects
+example = example_pb2.Example
+feature = feature_pb2.Feature
+features = lambda d: feature_pb2.Features(feature=d)
+bytes_feature = lambda v: feature(bytes_list=feature_pb2.BytesList(value=v))
+int64_feature = lambda v: feature(int64_list=feature_pb2.Int64List(value=v))
+float_feature = lambda v: feature(float_list=feature_pb2.FloatList(value=v))
+# Helpers for creating SequenceExample objects
+feature_list = lambda l: feature_pb2.FeatureList(feature=l)
+feature_lists = lambda d: feature_pb2.FeatureLists(feature_list=d)
+sequence_example = example_pb2.SequenceExample
+
+
+def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
+ flat_output):
+ tester.assertEqual(set(dict_tensors.keys()), set(expected_tensors.keys()))
+
+ i = 0 # Index into the flattened output of session.run()
+ for k, v in sorted(dict_tensors.items()):
+    # TODO(shivaniagrawal): flat_output is the same as v.
+ expected_v = expected_tensors[k]
+ tf_logging.info("Comparing key: %s", k)
+ print("i", i, "flat_output", flat_output[i], "expected_v", expected_v)
+ if sparse_tensor.is_sparse(v):
+      # Three outputs for SparseTensor: indices, values, shape.
+ tester.assertEqual([k, len(expected_v)], [k, 3])
+ print("i", i, "flat_output", flat_output[i].indices, "expected_v",
+ expected_v[0])
+ tester.assertAllEqual(expected_v[0], flat_output[i].indices)
+ tester.assertAllEqual(expected_v[1], flat_output[i].values)
+ tester.assertAllEqual(expected_v[2], flat_output[i].dense_shape)
+ else:
+ # One output for standard Tensor.
+ tester.assertAllEqual(expected_v, flat_output[i])
+ i += 1
+
+
+class ParseExampleTest(test.TestCase):
+
+ def _test(self,
+ input_tensor,
+ feature_val,
+ expected_values=None,
+ expected_err=None):
+
+ with self.cached_session() as sess:
+ if expected_err:
+ with self.assertRaisesWithPredicateMatch(expected_err[0],
+ expected_err[1]):
+ dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply(
+ contrib_parsing_ops.parse_example_dataset(feature_val))
+ get_next = dataset.make_one_shot_iterator().get_next()
+ sess.run(get_next)
+ return
+ else:
+ # Returns dict w/ Tensors and SparseTensors.
+ # Check values.
+ dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply(
+ contrib_parsing_ops.parse_example_dataset(feature_val))
+ get_next = dataset.make_one_shot_iterator().get_next()
+ result = sess.run(get_next)
+ flattened = nest.flatten(result)
+ print("result", result, "expected_values", expected_values)
+ _compare_output_to_expected(self, result, expected_values, flattened)
+
+ # Check shapes; if serialized is a Tensor we need its size to
+ # properly check.
+ batch_size = (
+ input_tensor.eval().size if isinstance(input_tensor, ops.Tensor) else
+ np.asarray(input_tensor).size)
+ for k, f in feature_val.items():
+ print("output_shapes as list ",
+ tuple(dataset.output_shapes[k].as_list()))
+ if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None:
+ self.assertEqual(dataset.output_shapes[k].as_list()[0], batch_size)
+ elif isinstance(f, parsing_ops.VarLenFeature):
+ self.assertEqual(dataset.output_shapes[k].as_list()[1], None)
+
+ def testEmptySerializedWithAllDefaults(self):
+ sparse_name = "st_a"
+ a_name = "a"
+ b_name = "b"
+ c_name = "c:has_a_tricky_name"
+ a_default = [0, 42, 0]
+ b_default = np.random.rand(3, 3).astype(bytes)
+ c_default = np.random.rand(2).astype(np.float32)
+
+ expected_st_a = ( # indices, values, shape
+ np.empty(
+ (0, 2), dtype=np.int64), # indices
+ np.empty(
+ (0,), dtype=np.int64), # sp_a is DT_INT64
+ np.array(
+ [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
+
+ expected_output = {
+ sparse_name: expected_st_a,
+ a_name: np.array(2 * [[a_default]]),
+ b_name: np.array(2 * [b_default]),
+ c_name: np.array(2 * [c_default]),
+ }
+
+ self._test(
+ ops.convert_to_tensor(["", ""]), {
+ sparse_name:
+ parsing_ops.VarLenFeature(dtypes.int64),
+ a_name:
+ parsing_ops.FixedLenFeature(
+ (1, 3), dtypes.int64, default_value=a_default),
+ b_name:
+ parsing_ops.FixedLenFeature(
+ (3, 3), dtypes.string, default_value=b_default),
+ c_name:
+ parsing_ops.FixedLenFeature(
+ (2,), dtypes.float32, default_value=c_default),
+ },
+ expected_values=expected_output)
+
+ def testEmptySerializedWithoutDefaultsShouldFail(self):
+ input_features = {
+ "st_a":
+ parsing_ops.VarLenFeature(dtypes.int64),
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1, 3), dtypes.int64, default_value=[0, 42, 0]),
+ "b":
+ parsing_ops.FixedLenFeature(
+ (3, 3),
+ dtypes.string,
+ default_value=np.random.rand(3, 3).astype(bytes)),
+ # Feature "c" is missing a default, this gap will cause failure.
+ "c":
+ parsing_ops.FixedLenFeature(
+ (2,), dtype=dtypes.float32),
+ }
+
+ # Edge case where the key is there but the feature value is empty
+ original = example(features=features({"c": feature()}))
+ self._test(
+ [original.SerializeToString()],
+ input_features,
+ expected_err=(errors_impl.InvalidArgumentError,
+ "Feature: c \\(data type: float\\) is required"))
+
+ # Standard case of missing key and value.
+ self._test(
+ ["", ""],
+ input_features,
+ expected_err=(errors_impl.InvalidArgumentError,
+ "Feature: c \\(data type: float\\) is required"))
+
+ def testDenseNotMatchingShapeShouldFail(self):
+ original = [
+ example(features=features({
+ "a": float_feature([1, 1, 3]),
+ })), example(features=features({
+ "a": float_feature([-1, -1]),
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ self._test(
+ ops.convert_to_tensor(serialized),
+ {"a": parsing_ops.FixedLenFeature((1, 3), dtypes.float32)},
+ expected_err=(errors_impl.InvalidArgumentError,
+ "Key: a, Index: 1. Number of float values"))
+
+ def testDenseDefaultNoShapeShouldFail(self):
+ original = [example(features=features({"a": float_feature([1, 1, 3]),})),]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ self._test(
+ ops.convert_to_tensor(serialized),
+ {"a": parsing_ops.FixedLenFeature(None, dtypes.float32)},
+ expected_err=(ValueError, "Missing shape for feature a"))
+
+ def testSerializedContainingSparse(self):
+ original = [
+ example(features=features({
+ "st_c": float_feature([3, 4])
+ })),
+ example(features=features({
+ "st_c": float_feature([]), # empty float list
+ })),
+ example(features=features({
+ "st_d": feature(), # feature with nothing in it
+ })),
+ example(features=features({
+ "st_c": float_feature([1, 2, -1]),
+ "st_d": bytes_feature([b"hi"])
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_st_c = ( # indices, values, shape
+ np.array(
+ [[0, 0], [0, 1], [3, 0], [3, 1], [3, 2]], dtype=np.int64), np.array(
+ [3.0, 4.0, 1.0, 2.0, -1.0], dtype=np.float32), np.array(
+                    [4, 3], dtype=np.int64)) # batch == 4, max_elems = 3
+
+ expected_st_d = ( # indices, values, shape
+ np.array(
+ [[3, 0]], dtype=np.int64), np.array(
+ ["hi"], dtype=bytes), np.array(
+                    [4, 1], dtype=np.int64)) # batch == 4, max_elems = 1
+
+ expected_output = {
+ "st_c": expected_st_c,
+ "st_d": expected_st_d,
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "st_c": parsing_ops.VarLenFeature(dtypes.float32),
+ "st_d": parsing_ops.VarLenFeature(dtypes.string)
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingSparseFeature(self):
+ original = [
+ example(features=features({
+ "val": float_feature([3, 4]),
+ "idx": int64_feature([5, 10])
+ })),
+ example(features=features({
+ "val": float_feature([]), # empty float list
+ "idx": int64_feature([])
+ })),
+ example(features=features({
+ "val": feature(), # feature with nothing in it
+ # missing idx feature
+ })),
+ example(features=features({
+ "val": float_feature([1, 2, -1]),
+ "idx":
+ int64_feature([0, 9, 3]) # unsorted
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
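+    # SparseFeature assembles the "idx" and "val" features into a single
+    # SparseTensor with dense shape [13].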
+ expected_sp = ( # indices, values, shape
+ np.array(
+ [[0, 5], [0, 10], [3, 0], [3, 3], [3, 9]], dtype=np.int64),
+ np.array(
+ [3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32), np.array(
+ [4, 13], dtype=np.int64)) # batch == 4, max_elems = 13
+
+ expected_output = {"sp": expected_sp,}
+
+ self._test(
+ ops.convert_to_tensor(serialized),
+ {"sp": parsing_ops.SparseFeature(["idx"], "val", dtypes.float32, [13])},
+ expected_values=expected_output)
+
+ def testSerializedContainingSparseFeatureReuse(self):
+ original = [
+ example(features=features({
+ "val1": float_feature([3, 4]),
+ "val2": float_feature([5, 6]),
+ "idx": int64_feature([5, 10])
+ })),
+ example(features=features({
+ "val1": float_feature([]), # empty float list
+ "idx": int64_feature([])
+ })),
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_sp1 = ( # indices, values, shape
+ np.array(
+ [[0, 5], [0, 10]], dtype=np.int64), np.array(
+ [3.0, 4.0], dtype=np.float32), np.array(
+ [2, 13], dtype=np.int64)) # batch == 2, max_elems = 13
+
+ expected_sp2 = ( # indices, values, shape
+ np.array(
+ [[0, 5], [0, 10]], dtype=np.int64), np.array(
+ [5.0, 6.0], dtype=np.float32), np.array(
+                    [2, 7], dtype=np.int64)) # batch == 2, max_elems = 7
+
+ expected_output = {
+ "sp1": expected_sp1,
+ "sp2": expected_sp2,
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "sp1":
+ parsing_ops.SparseFeature("idx", "val1", dtypes.float32, 13),
+ "sp2":
+ parsing_ops.SparseFeature(
+ "idx", "val2", dtypes.float32, size=7, already_sorted=True)
+ },
+ expected_values=expected_output)
+
+ def testSerializedContaining3DSparseFeature(self):
+ original = [
+ example(features=features({
+ "val": float_feature([3, 4]),
+ "idx0": int64_feature([5, 10]),
+ "idx1": int64_feature([0, 2]),
+ })),
+ example(features=features({
+ "val": float_feature([]), # empty float list
+ "idx0": int64_feature([]),
+ "idx1": int64_feature([]),
+ })),
+ example(features=features({
+ "val": feature(), # feature with nothing in it
+ # missing idx feature
+ })),
+ example(features=features({
+ "val": float_feature([1, 2, -1]),
+ "idx0": int64_feature([0, 9, 3]), # unsorted
+ "idx1": int64_feature([1, 0, 2]),
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_sp = (
+ # indices
+ np.array(
+ [[0, 5, 0], [0, 10, 2], [3, 0, 1], [3, 3, 2], [3, 9, 0]],
+ dtype=np.int64),
+ # values
+ np.array([3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32),
+ # shape batch == 4, max_elems = 13
+ np.array([4, 13, 3], dtype=np.int64))
+
+ expected_output = {"sp": expected_sp,}
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "sp":
+ parsing_ops.SparseFeature(["idx0", "idx1"], "val",
+ dtypes.float32, [13, 3])
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingDense(self):
+ aname = "a"
+ bname = "b*has+a:tricky_name"
+ original = [
+ example(features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str"]),
+ })), example(features=features({
+ aname: float_feature([-1, -1]),
+ bname: bytes_feature([b""]),
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ aname:
+ np.array(
+ [[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
+ bname:
+ np.array(
+ ["b0_str", ""], dtype=bytes).reshape(2, 1, 1, 1, 1),
+ }
+
+ # No defaults, values required
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32),
+ bname:
+ parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string),
+ },
+ expected_values=expected_output)
+
+  # This test is identical to the previous one except
+ # for the creation of 'serialized'.
+ def testSerializedContainingDenseWithConcat(self):
+ aname = "a"
+ bname = "b*has+a:tricky_name"
+ # TODO(lew): Feature appearing twice should be an error in future.
+ original = [
+ (example(features=features({
+ aname: float_feature([10, 10]),
+ })), example(features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str"]),
+ }))),
+ (
+ example(features=features({
+ bname: bytes_feature([b"b100"]),
+ })),
+ example(features=features({
+ aname: float_feature([-1, -1]),
+ bname: bytes_feature([b"b1"]),
+ })),),
+ ]
+
+ serialized = [
+ m.SerializeToString() + n.SerializeToString() for (m, n) in original
+ ]
+
+ expected_output = {
+ aname:
+ np.array(
+ [[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
+ bname:
+ np.array(
+ ["b0_str", "b1"], dtype=bytes).reshape(2, 1, 1, 1, 1),
+ }
+
+ # No defaults, values required
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32),
+ bname:
+ parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingDenseScalar(self):
+ original = [
+ example(features=features({
+ "a": float_feature([1]),
+ })), example(features=features({}))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ "a":
+ np.array(
+ [[1], [-1]], dtype=np.float32) # 2x1 (column vector)
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1,), dtype=dtypes.float32, default_value=-1),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingDenseWithDefaults(self):
+ original = [
+ example(features=features({
+ "a": float_feature([1, 1]),
+ })),
+ example(features=features({
+ "b": bytes_feature([b"b1"]),
+ })),
+ example(features=features({
+ "b": feature()
+ })),
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ "a":
+ np.array(
+ [[1, 1], [3, -3], [3, -3]], dtype=np.float32).reshape(3, 1, 2,
+ 1),
+ "b":
+ np.array(
+ ["tmp_str", "b1", "tmp_str"], dtype=bytes).reshape(3, 1, 1, 1,
+ 1),
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1, 2, 1), dtype=dtypes.float32, default_value=[3.0, -3.0]),
+ "b":
+ parsing_ops.FixedLenFeature(
+ (1, 1, 1, 1), dtype=dtypes.string, default_value="tmp_str"),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingSparseAndSparseFeatureAndDenseWithNoDefault(self):
+ expected_st_a = ( # indices, values, shape
+ np.empty(
+ (0, 2), dtype=np.int64), # indices
+ np.empty(
+ (0,), dtype=np.int64), # sp_a is DT_INT64
+ np.array(
+ [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
+ expected_sp = ( # indices, values, shape
+ np.array(
+ [[0, 0], [0, 3], [1, 7]], dtype=np.int64), np.array(
+ ["a", "b", "c"], dtype="|S"), np.array(
+                    [2, 13], dtype=np.int64)) # batch == 2, max_elems = 13
+
+ original = [
+ example(features=features({
+ "c": float_feature([3, 4]),
+ "val": bytes_feature([b"a", b"b"]),
+ "idx": int64_feature([0, 3])
+ })), example(features=features({
+ "c": float_feature([1, 2]),
+ "val": bytes_feature([b"c"]),
+ "idx": int64_feature([7])
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ a_default = [1, 2, 3]
+ b_default = np.random.rand(3, 3).astype(bytes)
+ expected_output = {
+ "st_a": expected_st_a,
+ "sp": expected_sp,
+ "a": np.array(2 * [[a_default]]),
+ "b": np.array(2 * [b_default]),
+ "c": np.array(
+ [[3, 4], [1, 2]], dtype=np.float32),
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized),
+ {
+ "st_a":
+ parsing_ops.VarLenFeature(dtypes.int64),
+ "sp":
+ parsing_ops.SparseFeature("idx", "val", dtypes.string, 13),
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1, 3), dtypes.int64, default_value=a_default),
+ "b":
+ parsing_ops.FixedLenFeature(
+ (3, 3), dtypes.string, default_value=b_default),
+ # Feature "c" must be provided, since it has no default_value.
+ "c":
+ parsing_ops.FixedLenFeature((2,), dtypes.float32),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingSparseAndSparseFeatureWithReuse(self):
+ expected_idx = ( # indices, values, shape
+ np.array(
+ [[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64),
+ np.array([0, 3, 7, 1]), np.array(
+            [2, 2], dtype=np.int64)) # batch == 2, max_elems = 2
+
+ expected_sp = ( # indices, values, shape
+ np.array(
+ [[0, 0], [0, 3], [1, 1], [1, 7]], dtype=np.int64), np.array(
+ ["a", "b", "d", "c"], dtype="|S"), np.array(
+                    [2, 13], dtype=np.int64)) # batch == 2, max_elems = 13
+
+ original = [
+ example(features=features({
+ "val": bytes_feature([b"a", b"b"]),
+ "idx": int64_feature([0, 3])
+ })), example(features=features({
+ "val": bytes_feature([b"c", b"d"]),
+ "idx": int64_feature([7, 1])
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ "idx": expected_idx,
+ "sp": expected_sp,
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "idx":
+ parsing_ops.VarLenFeature(dtypes.int64),
+ "sp":
+ parsing_ops.SparseFeature(["idx"], "val", dtypes.string, [13]),
+ },
+ expected_values=expected_output)
+
+ def _testSerializedContainingVarLenDenseLargerBatch(self, batch_size):
+ # During parsing, data read from the serialized proto is stored in buffers.
+ # For small batch sizes, a buffer will contain one minibatch entry.
+ # For larger batch sizes, a buffer may contain several minibatch
+ # entries. This test identified a bug where the code that copied
+ # data out of the buffers and into the output tensors assumed each
+ # buffer only contained one minibatch entry. The bug has since been fixed.
+ truth_int = [i for i in range(batch_size)]
+ truth_str = [[("foo%d" % i).encode(), ("bar%d" % i).encode()]
+ for i in range(batch_size)]
+
+ expected_str = copy.deepcopy(truth_str)
+
+ # Delete some intermediate entries
+ for i in range(batch_size):
+ col = 1
+ if np.random.rand() < 0.25:
+ # w.p. 25%, drop out the second entry
+ expected_str[i][col] = b"default"
+ col -= 1
+ truth_str[i].pop()
+ if np.random.rand() < 0.25:
+        # w.p. 25%, drop the last remaining entry (the first one, if the
+        # second was already dropped)
+ expected_str[i][col] = b"default"
+ truth_str[i].pop()
+
+ expected_output = {
+ # Batch size batch_size, 1 time step.
+ "a": np.array(truth_int, dtype=np.int64).reshape(batch_size, 1),
+ # Batch size batch_size, 2 time steps.
+ "b": np.array(expected_str, dtype="|S").reshape(batch_size, 2),
+ }
+
+ original = [
+ example(features=features(
+ {"a": int64_feature([truth_int[i]]),
+ "b": bytes_feature(truth_str[i])}))
+ for i in range(batch_size)
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ self._test(
+ ops.convert_to_tensor(serialized, dtype=dtypes.string), {
+ "a":
+ parsing_ops.FixedLenSequenceFeature(
+ shape=(),
+ dtype=dtypes.int64,
+ allow_missing=True,
+ default_value=-1),
+ "b":
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[],
+ dtype=dtypes.string,
+ allow_missing=True,
+ default_value="default"),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingVarLenDenseLargerBatch(self):
+ np.random.seed(3456)
+ for batch_size in (1, 10, 20, 100, 256):
+ self._testSerializedContainingVarLenDenseLargerBatch(batch_size)
+
+ def testSerializedContainingVarLenDense(self):
+ aname = "a"
+ bname = "b"
+ cname = "c"
+ dname = "d"
+ original = [
+ example(features=features({
+ cname: int64_feature([2]),
+ })),
+ example(features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str", b"b1_str"]),
+ })),
+ example(features=features({
+ aname: float_feature([-1, -1, 2, 2]),
+ bname: bytes_feature([b"b1"]),
+ })),
+ example(features=features({
+ aname: float_feature([]),
+ cname: int64_feature([3]),
+ })),
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ aname:
+ np.array(
+ [
+ [0, 0, 0, 0],
+ [1, 1, 0, 0],
+ [-1, -1, 2, 2],
+ [0, 0, 0, 0],
+ ],
+ dtype=np.float32).reshape(4, 2, 2, 1),
+ bname:
+ np.array(
+ [["", ""], ["b0_str", "b1_str"], ["b1", ""], ["", ""]],
+ dtype=bytes).reshape(4, 2, 1, 1, 1),
+ cname:
+ np.array([2, 0, 0, 3], dtype=np.int64).reshape(4, 1),
+ dname:
+ np.empty(shape=(4, 0), dtype=bytes),
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1), dtype=dtypes.float32, allow_missing=True),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (1, 1, 1), dtype=dtypes.string, allow_missing=True),
+ cname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.int64, allow_missing=True),
+ dname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.string, allow_missing=True),
+ },
+ expected_values=expected_output)
+
+ # Test with padding values.
+ expected_output_custom_padding = dict(expected_output)
+ expected_output_custom_padding[aname] = np.array(
+ [
+ [-2, -2, -2, -2],
+ [1, 1, -2, -2],
+ [-1, -1, 2, 2],
+ [-2, -2, -2, -2],
+ ],
+ dtype=np.float32).reshape(4, 2, 2, 1)
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1),
+ dtype=dtypes.float32,
+ allow_missing=True,
+ default_value=-2.0),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (1, 1, 1), dtype=dtypes.string, allow_missing=True),
+ cname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.int64, allow_missing=True),
+ dname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.string, allow_missing=True),
+ }, expected_output_custom_padding)
+
+ # Change number of required values so the inputs are not a
+ # multiple of this size.
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1), dtype=dtypes.float32, allow_missing=True),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1, 1), dtype=dtypes.string, allow_missing=True),
+ },
+ expected_err=(
+ errors_impl.OpError, "Key: b, Index: 2. "
+ "Number of bytes values is not a multiple of stride length."))
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1),
+ dtype=dtypes.float32,
+ allow_missing=True,
+ default_value=[]),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1, 1), dtype=dtypes.string, allow_missing=True),
+ },
+ expected_err=(ValueError,
+ "Cannot reshape a tensor with 0 elements to shape"))
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenFeature((None, 2, 1), dtype=dtypes.float32),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1, 1), dtype=dtypes.string, allow_missing=True),
+ },
+ expected_err=(ValueError,
+ "First dimension of shape for feature a unknown. "
+ "Consider using FixedLenSequenceFeature."))
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ cname:
+ parsing_ops.FixedLenFeature(
+ (1, None), dtype=dtypes.int64, default_value=[[1]]),
+ },
+ expected_err=(ValueError,
+ "All dimensions of shape for feature c need to be known "
+ r"but received \(1, None\)."))
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1), dtype=dtypes.float32, allow_missing=True),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (1, 1, 1), dtype=dtypes.string, allow_missing=True),
+ cname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.int64, allow_missing=False),
+ dname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.string, allow_missing=True),
+ },
+ expected_err=(ValueError,
+ "Unsupported: FixedLenSequenceFeature requires "
+ "allow_missing to be True."))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index 361fe0dd39..0166ba0d44 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -235,7 +235,7 @@ class PrefetchingKernelsOpsTest(test.TestCase):
destroy_op = resource_variable_ops.destroy_resource_op(
buffer_resource_handle, ignore_lookup_error=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual([b"a"], sess.run(prefetch_op))
self.assertEqual([b"b"], sess.run(prefetch_op))
self.assertEqual([b"c"], sess.run(prefetch_op))
@@ -301,7 +301,7 @@ class PrefetchToDeviceTest(test.TestCase):
self.assertEqual(dtypes.int64, next_element.dtype)
self.assertEqual([], next_element.shape)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -384,7 +384,7 @@ class PrefetchToDeviceTest(test.TestCase):
iterator = device_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -435,7 +435,7 @@ class PrefetchToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
@@ -683,7 +683,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
@@ -702,7 +702,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
@@ -721,7 +721,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -739,7 +739,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -757,7 +757,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -775,7 +775,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -796,7 +796,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = back_to_cpu_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
@@ -875,7 +875,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
@@ -897,7 +897,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
@@ -920,7 +920,7 @@ class CopyToDeviceTest(test.TestCase):
elem_has_value_t = next_elem.has_value()
elem_value_t = next_elem.get_value()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Before initializing the iterator, evaluating the optional fails with
# a FailedPreconditionError.
with self.assertRaises(errors.FailedPreconditionError):
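
The repeated `test_session()` to `cached_session()` substitutions in this file (and in most files below) are a mechanical migration: `cached_session()` makes the caching behavior explicit, creating one session per test case and returning that same session on every call, whereas the old `test_session()` name suggested a fresh session and was deprecated for that reason. The usage pattern is unchanged; a minimal sketch, assuming the usual `dataset_ops` and `test` imports used throughout these files:

    class ExampleDatasetTest(test.TestCase):

      def testGetNext(self):
        iterator = dataset_ops.Dataset.range(3).make_one_shot_iterator()
        next_element = iterator.get_next()
        # Same context-manager protocol as test_session(); the session is
        # created lazily and reused for the lifetime of this test case.
        with self.cached_session() as sess:
          for i in range(3):
            self.assertEqual(i, sess.run(next_element))
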
diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
index 592642da0c..db8fe6aa1b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
@@ -43,7 +43,7 @@ class RangeDatasetTest(test.TestCase):
self.assertEqual([tensor_shape.TensorShape([])] * 3,
[t.shape for t in get_next[1]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next))
self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next))
@@ -63,7 +63,7 @@ class RangeDatasetTest(test.TestCase):
.make_one_shot_iterator())
negative_get_next = negative_iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(3, sess.run(get_next))
self.assertEqual(3 + 4, sess.run(get_next))
self.assertEqual(3 + 2 * 4, sess.run(get_next))
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index 15b342d30f..ed75b27a44 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -43,29 +43,57 @@ class ReadBatchFeaturesTest(
for batch_size in [1, 2]:
for num_epochs in [1, 10]:
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
# Basic test: read from file 0.
self.outputs = self.make_batch_feature(
filenames=self.test_filenames[0],
+ label_key="label",
num_epochs=num_epochs,
batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(sess, batch_size, 0, num_epochs=num_epochs)
+ self.verify_records(
+ sess,
+ batch_size,
+ 0,
+ num_epochs=num_epochs,
+ label_key_provided=True)
with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess)
+ self._next_actual_batch(sess, label_key_provided=True)
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
# Basic test: read from file 1.
self.outputs = self.make_batch_feature(
filenames=self.test_filenames[1],
+ label_key="label",
num_epochs=num_epochs,
batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(sess, batch_size, 1, num_epochs=num_epochs)
+ self.verify_records(
+ sess,
+ batch_size,
+ 1,
+ num_epochs=num_epochs,
+ label_key_provided=True)
with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess)
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Basic test: read from both files.
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ num_epochs=num_epochs,
+ label_key_provided=True)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
# Basic test: read from both files.
self.outputs = self.make_batch_feature(
filenames=self.test_filenames,
@@ -88,9 +116,9 @@ class ReadBatchFeaturesTest(
init_op = iterator.initializer
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
- for file_batch, _, _, _, record_batch in self._next_expected_batch(
+ for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
range(self._num_files), 2, 10):
actual_batch = sess.run(next_element)
self.assertAllEqual(file_batch, actual_batch["file"])
@@ -104,7 +132,7 @@ class ReadBatchFeaturesTest(
for batch_size in [1, 2]:
# Test that shuffling with same seed produces the same result.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
outputs1 = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
@@ -125,7 +153,7 @@ class ReadBatchFeaturesTest(
# Test that shuffling with different seeds produces a different order.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
outputs1 = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
@@ -152,7 +180,26 @@ class ReadBatchFeaturesTest(
for reader_num_threads in [2, 4]:
for parser_num_threads in [2, 4]:
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ reader_num_threads=reader_num_threads,
+ parser_num_threads=parser_num_threads).make_one_shot_iterator(
+ ).get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ num_epochs=num_epochs,
+ label_key_provided=True,
+ interleave_cycle_length=reader_num_threads)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
self.outputs = self.make_batch_feature(
filenames=self.test_filenames,
num_epochs=num_epochs,
@@ -175,16 +222,20 @@ class ReadBatchFeaturesTest(
# Basic test: read from file 0.
outputs = self.make_batch_feature(
filenames=self.test_filenames[0],
+ label_key="label",
num_epochs=num_epochs,
batch_size=batch_size,
drop_final_batch=True).make_one_shot_iterator().get_next()
- for _, tensor in outputs.items():
+ for tensor in nest.flatten(outputs):
if isinstance(tensor, ops.Tensor): # Guard against SparseTensor.
self.assertEqual(tensor.shape[0], batch_size)
def testIndefiniteRepeatShapeInference(self):
dataset = self.make_batch_feature(
- filenames=self.test_filenames[0], num_epochs=None, batch_size=32)
+ filenames=self.test_filenames[0],
+ label_key="label",
+ num_epochs=None,
+ batch_size=32)
for shape, clazz in zip(nest.flatten(dataset.output_shapes),
nest.flatten(dataset.output_classes)):
if issubclass(clazz, ops.Tensor):
@@ -275,7 +326,7 @@ class MakeCsvDatasetTest(test.TestCase):
filenames = self._setup_files(
inputs, compression_type=kwargs.get("compression_type", None))
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
dataset = self._make_csv_dataset(
filenames,
batch_size=batch_size,
@@ -740,7 +791,7 @@ class MakeCsvDatasetTest(test.TestCase):
total_records = 20
for batch_size in [1, 2]:
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
# Test that shuffling with the same seed produces the same result
dataset1 = self._make_csv_dataset(
filenames,
@@ -771,7 +822,7 @@ class MakeCsvDatasetTest(test.TestCase):
self.assertAllEqual(batch1[i], batch2[i])
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
# Test that shuffling with a different seed produces different results
dataset1 = self._make_csv_dataset(
filenames,
@@ -909,7 +960,7 @@ class MakeTFRecordDatasetTest(
fn = None
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
outputs = readers.make_tf_record_dataset(
file_pattern=file_pattern,
num_epochs=num_epochs,
@@ -965,7 +1016,7 @@ class MakeTFRecordDatasetTest(
def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1,
seed=None):
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
dataset = readers.make_tf_record_dataset(
file_pattern=self.test_filenames,
num_epochs=num_epochs,
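
These tests extend `make_batch_feature` with a `label_key`, exercising the corresponding argument of the contrib `make_batched_features_dataset` reader that the helper wraps: when `label_key` names one of the parsed features, that feature is split out and the dataset yields `(features, label)` tuples instead of a single feature dict. A hedged sketch of the underlying call (file names and feature spec are illustrative):

    dataset = readers.make_batched_features_dataset(
        file_pattern=filenames,   # TFRecord files of serialized Examples
        batch_size=2,
        features={
            "record": parsing_ops.FixedLenFeature([], dtypes.int64),
            "label": parsing_ops.FixedLenFeature([], dtypes.string),
        },
        label_key="label")
    # Each element is now a (features_dict, label) pair.
    features, label = dataset.make_one_shot_iterator().get_next()
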
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
index e63bc4c720..08b9f03816 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
@@ -76,6 +76,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
filenames,
num_epochs,
batch_size,
+ label_key=None,
reader_num_threads=1,
parser_num_threads=1,
shuffle=False,
@@ -91,8 +92,10 @@ class ReadBatchFeaturesTestBase(test.TestCase):
features={
"file": parsing_ops.FixedLenFeature([], dtypes.int64),
"record": parsing_ops.FixedLenFeature([], dtypes.int64),
- "keywords": parsing_ops.VarLenFeature(dtypes.string)
+ "keywords": parsing_ops.VarLenFeature(dtypes.string),
+ "label": parsing_ops.FixedLenFeature([], dtypes.string),
},
+ label_key=label_key,
reader=core_readers.TFRecordDataset,
num_epochs=self.num_epochs,
shuffle=shuffle,
@@ -101,7 +104,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
parser_num_threads=parser_num_threads,
drop_final_batch=drop_final_batch)
- def _record(self, f, r):
+ def _record(self, f, r, l):
example = example_pb2.Example(
features=feature_pb2.Features(
feature={
@@ -114,7 +117,11 @@ class ReadBatchFeaturesTestBase(test.TestCase):
"keywords":
feature_pb2.Feature(
bytes_list=feature_pb2.BytesList(
- value=self._get_keywords(f, r)))
+ value=self._get_keywords(f, r))),
+ "label":
+ feature_pb2.Feature(
+ bytes_list=feature_pb2.BytesList(
+ value=[compat.as_bytes(l)]))
}))
return example.SerializeToString()
@@ -139,23 +146,30 @@ class ReadBatchFeaturesTestBase(test.TestCase):
filenames.append(fn)
writer = python_io.TFRecordWriter(fn)
for j in range(self._num_records):
- writer.write(self._record(i, j))
+ writer.write(self._record(i, j, "fake-label"))
writer.close()
return filenames
- def _run_actual_batch(self, outputs, sess):
- file_op = outputs["file"]
- keywords_indices_op = outputs["keywords"].indices
- keywords_values_op = outputs["keywords"].values
- keywords_dense_shape_op = outputs["keywords"].dense_shape
- record_op = outputs["record"]
+ def _run_actual_batch(self, outputs, sess, label_key_provided=False):
+ if label_key_provided:
+      # outputs is a tuple of (features dict, label).
+ label_op = outputs[1]
+ features_op = outputs[0]
+ else:
+ features_op = outputs
+ label_op = features_op["label"]
+ file_op = features_op["file"]
+ keywords_indices_op = features_op["keywords"].indices
+ keywords_values_op = features_op["keywords"].values
+ keywords_dense_shape_op = features_op["keywords"].dense_shape
+ record_op = features_op["record"]
return sess.run([
file_op, keywords_indices_op, keywords_values_op,
- keywords_dense_shape_op, record_op
+ keywords_dense_shape_op, record_op, label_op
])
- def _next_actual_batch(self, sess):
- return self._run_actual_batch(self.outputs, sess)
+ def _next_actual_batch(self, sess, label_key_provided=False):
+ return self._run_actual_batch(self.outputs, sess, label_key_provided)
def _interleave(self, iterators, cycle_length):
pending_iterators = iterators
@@ -188,7 +202,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
def _next_record(file_indices):
for j in file_indices:
for i in range(self._num_records):
- yield j, i
+ yield j, i, compat.as_bytes("fake-label")
def _next_record_interleaved(file_indices, cycle_length):
return self._interleave([_next_record([i]) for i in file_indices],
@@ -200,6 +214,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
keywords_batch_max_len = 0
record_batch = []
batch_index = 0
+ label_batch = []
for _ in range(num_epochs):
if cycle_length == 1:
next_records = _next_record(file_indices)
@@ -208,6 +223,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
for record in next_records:
f = record[0]
r = record[1]
+ label_batch.append(record[2])
file_batch.append(f)
record_batch.append(r)
keywords = self._get_keywords(f, r)
@@ -219,7 +235,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
if len(file_batch) == batch_size:
yield [
file_batch, keywords_batch_indices, keywords_batch_values,
- [batch_size, keywords_batch_max_len], record_batch
+ [batch_size, keywords_batch_max_len], record_batch, label_batch
]
file_batch = []
keywords_batch_indices = []
@@ -227,10 +243,11 @@ class ReadBatchFeaturesTestBase(test.TestCase):
keywords_batch_max_len = 0
record_batch = []
batch_index = 0
+ label_batch = []
if file_batch:
yield [
file_batch, keywords_batch_indices, keywords_batch_values,
- [len(file_batch), keywords_batch_max_len], record_batch
+ [len(file_batch), keywords_batch_max_len], record_batch, label_batch
]
def verify_records(self,
@@ -238,6 +255,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
batch_size,
file_index=None,
num_epochs=1,
+ label_key_provided=False,
interleave_cycle_length=1):
if file_index is not None:
file_indices = [file_index]
@@ -245,8 +263,12 @@ class ReadBatchFeaturesTestBase(test.TestCase):
file_indices = range(self._num_files)
for expected_batch in self._next_expected_batch(
- file_indices, batch_size, num_epochs, interleave_cycle_length):
- actual_batch = self._next_actual_batch(sess)
+ file_indices,
+ batch_size,
+ num_epochs,
+ cycle_length=interleave_cycle_length):
+ actual_batch = self._next_actual_batch(
+ sess, label_key_provided=label_key_provided)
for i in range(len(expected_batch)):
self.assertAllEqual(expected_batch[i], actual_batch[i])
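
To support the new `label_key` path, the base class above threads a label through record construction: every record now carries a `"label"` bytes feature alongside `"file"`, `"record"`, and `"keywords"`. For reference, the proto each `_record(f, r, l)` call serializes looks roughly like this (values illustrative):

    example = example_pb2.Example(
        features=feature_pb2.Features(
            feature={
                "file": feature_pb2.Feature(
                    int64_list=feature_pb2.Int64List(value=[0])),
                "record": feature_pb2.Feature(
                    int64_list=feature_pb2.Int64List(value=[0])),
                "label": feature_pb2.Feature(
                    bytes_list=feature_pb2.BytesList(
                        value=[compat.as_bytes("fake-label")])),
            }))
    serialized = example.SerializeToString()
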
diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
index c5cfddb72b..16b1441baa 100644
--- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
@@ -77,7 +77,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase):
class_func=lambda c, _: c,
seed=27)).make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
returned = []
while len(returned) < 4000:
returned.append(sess.run(get_next))
@@ -115,7 +115,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase):
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
returned = []
with self.assertRaises(errors.OutOfRangeError):
while True:
@@ -146,7 +146,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase):
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
returned = []
with self.assertRaises(errors.OutOfRangeError):
while True:
diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
index 42cada0b97..dde678bd54 100644
--- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
@@ -50,7 +50,7 @@ class ScanDatasetTest(test.TestCase):
start, make_scan_fn(step)).take(take).make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
(10, 2, 10), (10, -1, 10),
@@ -100,7 +100,7 @@ class ScanDatasetTest(test.TestCase):
make_scan_fn(step)).take(take).make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
(10, 2, 10), (10, -1, 10),
@@ -133,7 +133,7 @@ class ScanDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(5):
(longer_vector_val, larger_rank_val), _ = sess.run(next_element)
self.assertAllEqual([0] * (2**i), longer_vector_val)
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
index 7b9ea191a4..aa89674c6e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
@@ -210,6 +210,7 @@ py_test(
"//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
@@ -318,6 +319,19 @@ py_test(
)
py_test(
+ name = "parse_example_dataset_serialization_test",
+ size = "medium",
+ srcs = ["parse_example_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
name = "prefetch_dataset_serialization_test",
size = "small",
srcs = ["prefetch_dataset_serialization_test.py"],
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
index 3ed4dfb729..595cecef4d 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
@@ -252,7 +252,7 @@ class DatasetSerializationTestBase(test.TestCase):
init_op, get_next_op = self._get_iterator_ops_from_collection(
ds_fn, sparse_tensors=sparse_tensors)
get_next_op = remove_variants(get_next_op)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
self._restore(saver, sess)
self._initialize(init_op, sess)
for _ in range(num_outputs):
@@ -315,7 +315,7 @@ class DatasetSerializationTestBase(test.TestCase):
_, get_next_op, saver = self._build_graph(
ds_fn2, sparse_tensors=sparse_tensors)
get_next_op = remove_variants(get_next_op)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
self._restore(saver, sess)
for _ in range(num_outputs - break_point):
actual.append(sess.run(get_next_op))
@@ -376,7 +376,7 @@ class DatasetSerializationTestBase(test.TestCase):
get_next_op, saver = self._build_empty_graph(
ds_fn, sparse_tensors=sparse_tensors)
get_next_op = remove_variants(get_next_op)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
self._restore(saver, sess)
for _ in range(num_outputs - break_point):
actual.append(sess.run(get_next_op))
@@ -410,7 +410,7 @@ class DatasetSerializationTestBase(test.TestCase):
init_op, get_next_op, saver = self._build_graph(
ds_fn, sparse_tensors=sparse_tensors)
get_next_op = remove_variants(get_next_op)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
self._initialize(init_op, sess)
for _ in range(break_point):
sess.run(get_next_op)
@@ -510,14 +510,13 @@ class DatasetSerializationTestBase(test.TestCase):
else:
init_op, get_next_op, saver = self._build_graph(
ds_fn, sparse_tensors=sparse_tensors)
- get_next_op = remove_variants(get_next_op)
return init_op, get_next_op, saver
for i in range(len(break_points) + 1):
with ops.Graph().as_default() as g:
init_op, get_next_op, saver = get_ops()
get_next_op = remove_variants(get_next_op)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
if ckpt_saved:
if init_before_restore:
self._initialize(init_op, sess)
@@ -616,29 +615,40 @@ class DatasetSerializationTestBase(test.TestCase):
# `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections
# do not support tuples we flatten the tensors and restore the shape in
# `_get_iterator_ops_from_collection`.
-
- # TODO(shivaniagrwal): `output_classes` is a nested structure of classes,
- # this base class is specific to current test cases. Update when tests are
- # added with `output_classes` as a nested structure with at least one of the
- # component being `tf.SparseTensor`.
- if (sparse_tensors or
- self._get_output_classes(ds_fn) is sparse_tensor.SparseTensor):
+    if sparse_tensors:  # specific to deprecated `from_sparse_tensor_slices`.
ops.add_to_collection("iterator_ops", get_next.indices)
ops.add_to_collection("iterator_ops", get_next.values)
ops.add_to_collection("iterator_ops", get_next.dense_shape)
- else:
- for el in nest.flatten(get_next):
- ops.add_to_collection("iterator_ops", el)
+ return
+
+ get_next_list = nest.flatten(get_next)
+ for i, output_class in enumerate(
+ nest.flatten(self._get_output_classes(ds_fn))):
+ if output_class is sparse_tensor.SparseTensor:
+ ops.add_to_collection("iterator_ops", get_next_list[i].indices)
+ ops.add_to_collection("iterator_ops", get_next_list[i].values)
+ ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
+ else:
+ ops.add_to_collection("iterator_ops", get_next_list[i])
def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
all_ops = ops.get_collection("iterator_ops")
- if (sparse_tensors or
- self._get_output_classes(ds_fn) is sparse_tensor.SparseTensor):
+    if sparse_tensors:  # specific to deprecated `from_sparse_tensor_slices`.
init_op, indices, values, dense_shape = all_ops
return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
- else:
- return all_ops[0], nest.pack_sequence_as(
- self._get_output_types(ds_fn), all_ops[1:])
+ get_next_list = []
+ i = 1
+ for output_class in nest.flatten(self._get_output_classes(ds_fn)):
+ if output_class is sparse_tensor.SparseTensor:
+ indices, values, dense_shape = all_ops[i:i + 3]
+ i += 3
+ get_next_list.append(
+ sparse_tensor.SparseTensor(indices, values, dense_shape))
+ else:
+ get_next_list.append(all_ops[i])
+ i += 1
+ return all_ops[0], nest.pack_sequence_as(
+ self._get_output_types(ds_fn), get_next_list)
def _get_output_types(self, ds_fn):
with ops.Graph().as_default():
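
The rewrite above replaces the all-or-nothing SparseTensor check with per-component handling: the structure is flattened with `nest.flatten`, each `SparseTensor` component is stored as its three constituent tensors (graph collections hold plain tensors, not composite values), and `_get_iterator_ops_from_collection` consumes three ops at a time for sparse components before `nest.pack_sequence_as` restores the original nesting. A minimal sketch of the reassembly step, with hypothetical names:

    # Structure template: a dict with one dense and one sparse component.
    output_types = {"dense": dtypes.int64, "sparse": dtypes.string}
    flat = [dense_op,
            sparse_tensor.SparseTensor(indices_op, values_op, shape_op)]
    # pack_sequence_as restores the dict structure from the flat list.
    get_next = nest.pack_sequence_as(output_types, flat)
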
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
index ac3892fe81..243f6405a1 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
@@ -27,42 +28,38 @@ from tensorflow.python.platform import test
class InterleaveDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
+ dataset_serialization_test_base.DatasetSerializationTestBase,
+ parameterized.TestCase):
- def _build_iterator_graph(self, input_values, cycle_length, block_length):
+ def _build_iterator_graph(self, input_values, cycle_length, block_length,
+ num_parallel_calls):
repeat_count = 2
return dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
repeat_count).interleave(
lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
- cycle_length, block_length)
+ cycle_length, block_length, num_parallel_calls)
- def testSerializationCore(self):
+ @parameterized.named_parameters(
+ ("1", 2, 3, None),
+ ("2", 2, 3, 1),
+ ("3", 2, 3, 2),
+ ("4", 1, 3, None),
+ ("5", 1, 3, 1),
+ ("6", 2, 1, None),
+ ("7", 2, 1, 1),
+ ("8", 2, 1, 2),
+ )
+ def testSerializationCore(self, cycle_length, block_length,
+ num_parallel_calls):
input_values = np.array([4, 5, 6], dtype=np.int64)
num_outputs = np.sum(input_values) * 2
- # cycle_length > 1, block_length > 1
- cycle_length = 2
- block_length = 3
# pylint: disable=g-long-lambda
self.run_core_tests(
lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length),
+ input_values, cycle_length, block_length, num_parallel_calls),
lambda: self._build_iterator_graph(
- input_values, cycle_length * 2, block_length * 1),
+ input_values, cycle_length * 2, block_length, num_parallel_calls),
num_outputs)
- # cycle_length = 1
- cycle_length = 1
- block_length = 3
- self.run_core_tests(
- lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length),
- None, num_outputs)
- # block_length = 1
- cycle_length = 2
- block_length = 1
- self.run_core_tests(
- lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length),
- None, num_outputs)
# pylint: enable=g-long-lambda
def testSparseCore(self):
@@ -82,5 +79,5 @@ class InterleaveDatasetSerializationTest(
self.run_core_tests(_build_dataset, None, 20)
-if __name__ == '__main__':
+if __name__ == "__main__":
test.main()
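
This rewrite folds three hand-rolled configurations into one test parameterized over `(cycle_length, block_length, num_parallel_calls)`, and in passing covers the parallel variant of `interleave`. With `parameterized.named_parameters`, the first element of each tuple is appended to the test method name and the rest are passed as arguments; a minimal sketch of the pattern:

    from absl.testing import parameterized
    from tensorflow.python.platform import test

    class ExampleTest(parameterized.TestCase):

      @parameterized.named_parameters(
          ("1", 2, 3, None),   # runs as testConfig1
          ("2", 2, 3, 1),      # runs as testConfig2
      )
      def testConfig(self, cycle_length, block_length, num_parallel_calls):
        # num_parallel_calls=None selects the sequential interleave path.
        self.assertGreater(cycle_length, 0)

    if __name__ == "__main__":
      test.main()
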
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py
new file mode 100644
index 0000000000..d3fa84e74c
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py
@@ -0,0 +1,50 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ParseExampleDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.platform import test
+
+
+class ParseExampleDatasetSerializationTest(
+ reader_dataset_ops_test_base.ReadBatchFeaturesTestBase,
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def ParseExampleDataset(self, num_repeat, batch_size):
+ return self.make_batch_feature(
+ filenames=self.test_filenames,
+ num_epochs=num_repeat,
+ batch_size=batch_size,
+ reader_num_threads=5,
+ parser_num_threads=10)
+
+ def testSerializationCore(self):
+ num_repeat = 5
+ batch_size = 2
+ num_outputs = self._num_records * self._num_files * num_repeat // batch_size
+ # pylint: disable=g-long-lambda
+ self.run_core_tests(
+ lambda: self.ParseExampleDataset(
+ num_repeat=num_repeat, batch_size=batch_size),
+ lambda: self.ParseExampleDataset(num_repeat=10, batch_size=4),
+ num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
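
For concreteness, the `num_outputs` arithmetic above: each epoch yields `_num_records * _num_files` examples, so with hypothetical counts chosen only to illustrate the shape of the computation:

    #   _num_records = 7, _num_files = 2, num_repeat = 5, batch_size = 2
    #   num_outputs = 7 * 2 * 5 // 2 = 35 complete batches
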
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py
index e4f5b6cf5d..6341190847 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py
@@ -70,7 +70,7 @@ class RangeDatasetSerializationTest(
break_point = 5
with ops.Graph().as_default() as g:
init_op, get_next, save_op, _ = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(variables.global_variables_initializer())
sess.run(init_op)
for i in range(start, break_point):
@@ -79,7 +79,7 @@ class RangeDatasetSerializationTest(
with ops.Graph().as_default() as g:
init_op, get_next, _, restore_op = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(init_op)
sess.run(restore_op)
for i in range(break_point, stop):
@@ -90,7 +90,7 @@ class RangeDatasetSerializationTest(
# Saving and restoring in same session.
with ops.Graph().as_default() as g:
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(variables.global_variables_initializer())
sess.run(init_op)
for i in range(start, break_point):
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py
index 992d996a48..6aac50ecd9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py
@@ -59,7 +59,7 @@ class SerializationIntegrationTest(test.TestCase):
with ops.Graph().as_default() as g:
init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
num_outputs)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(init_ops)
for _ in range(break_point):
output = sess.run(get_next_ops)
@@ -70,7 +70,7 @@ class SerializationIntegrationTest(test.TestCase):
with ops.Graph().as_default() as g:
init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
num_outputs)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
saver.restore(sess, self._ckpt_path())
for _ in range(num_outputs - break_point):
output = sess.run(get_next_ops)
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py
index d46c762aaa..a59fa94d66 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py
@@ -136,7 +136,7 @@ class ShuffleDatasetSerializationTest(
for saveable in saveables:
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = saver_lib.Saver(allow_empty=True)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
self._save(sess, saver)
expected = [sess.run(get_next_ops) for _ in range(num_outputs)]
self._restore(saver, sess)
diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
index 3c11d7a97f..440e48db30 100644
--- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
@@ -35,7 +35,7 @@ class ShuffleAndRepeatTest(test.TestCase):
def _gen_outputs(self, ds_fn, num_outputs, verify_exhausted=True):
get_next = ds_fn().make_one_shot_iterator().get_next()
outputs = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(num_outputs):
outputs.append(sess.run(get_next))
if verify_exhausted:
@@ -106,7 +106,7 @@ class ShuffleAndRepeatTest(test.TestCase):
ds = dataset_ops.Dataset.range(20).apply(
shuffle_ops.shuffle_and_repeat(buffer_size=21))
get_next_op = ds.make_one_shot_iterator().get_next()
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(get_next_op)
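
The test above passes a `buffer_size` (21) larger than the 20-element input, which is legal: `shuffle_and_repeat` fuses `.shuffle(buffer_size).repeat(count)` into a single op so the shuffle buffer stays filled across epoch boundaries. A minimal usage sketch (the count and seed here are illustrative):

    ds = dataset_ops.Dataset.range(20).apply(
        shuffle_ops.shuffle_and_repeat(buffer_size=21, count=2, seed=42))
    # Produces the same element stream as .shuffle(21, seed=42).repeat(2),
    # but shuffling is not interrupted at the epoch boundary.
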
diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
index 8b2f846494..90d18dca2a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
@@ -32,18 +32,18 @@ from tensorflow.python.platform import test
class SlideDatasetTest(test.TestCase, parameterized.TestCase):
- @parameterized.parameters(
- (20, 14, 7, 1),
- (20, 17, 9, 1),
- (20, 14, 14, 1),
- (20, 10, 14, 1),
- (20, 14, 19, 1),
- (20, 4, 1, 2),
- (20, 2, 1, 6),
- (20, 4, 7, 2),
- (20, 2, 7, 6),
- (1, 10, 4, 1),
- (0, 10, 4, 1),
+ @parameterized.named_parameters(
+ ("1", 20, 14, 7, 1),
+ ("2", 20, 17, 9, 1),
+ ("3", 20, 14, 14, 1),
+ ("4", 20, 10, 14, 1),
+ ("5", 20, 14, 19, 1),
+ ("6", 20, 4, 1, 2),
+ ("7", 20, 2, 1, 6),
+ ("8", 20, 4, 7, 2),
+ ("9", 20, 2, 7, 6),
+ ("10", 1, 10, 4, 1),
+ ("11", 0, 10, 4, 1),
)
def testSlideDataset(self, count, window_size, window_shift, window_stride):
"""Tests a dataset that slides a window its input elements."""
@@ -75,7 +75,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([[None] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -96,18 +96,18 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- @parameterized.parameters(
- (20, 14, 7, 1),
- (20, 17, 9, 1),
- (20, 14, 14, 1),
- (20, 10, 14, 1),
- (20, 14, 19, 1),
- (20, 4, 1, 2),
- (20, 2, 1, 6),
- (20, 4, 7, 2),
- (20, 2, 7, 6),
- (1, 10, 4, 1),
- (0, 10, 4, 1),
+ @parameterized.named_parameters(
+ ("1", 20, 14, 7, 1),
+ ("2", 20, 17, 9, 1),
+ ("3", 20, 14, 14, 1),
+ ("4", 20, 10, 14, 1),
+ ("5", 20, 14, 19, 1),
+ ("6", 20, 4, 1, 2),
+ ("7", 20, 2, 1, 6),
+ ("8", 20, 4, 7, 2),
+ ("9", 20, 2, 7, 6),
+ ("10", 1, 10, 4, 1),
+ ("11", 0, 10, 4, 1),
)
def testSlideDatasetDeprecated(self, count, window_size, stride,
window_stride):
@@ -139,7 +139,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([[None] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -160,10 +160,10 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- @parameterized.parameters(
- (14, 0, 3, 1),
- (14, 3, 0, 1),
- (14, 3, 3, 0),
+ @parameterized.named_parameters(
+ ("1", 14, 0, 3, 1),
+ ("2", 14, 3, 0, 1),
+ ("3", 14, 3, 3, 0),
)
def testSlideDatasetInvalid(self, count, window_size, window_shift,
window_stride):
@@ -180,7 +180,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
window_stride=window_stride_t)).make_initializable_iterator())
init_op = iterator.initializer
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(
init_op,
@@ -214,7 +214,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
num_batches = (10 - 5) // 3 + 1
for i in range(num_batches):
@@ -243,7 +243,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
num_batches = (10 - 5) // 3 + 1
for i in range(num_batches):
@@ -277,7 +277,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
# Slide: 1st batch.
actual = sess.run(get_next)
@@ -316,7 +316,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
.make_initializable_iterator())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
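
The `num_batches` expressions in these tests instantiate the general sliding-window count: for N input elements, window size W, and window shift S (with N >= W), the number of complete windows is (N - W) // S + 1. Worked through for the values used above:

    #   N = 10, W = 5, S = 3  ->  (10 - 5) // 3 + 1 = 2 windows
    #   windows cover elements [0..4] and [3..7]; the window starting at 6
    #   would be incomplete (only 4 elements) and is dropped.
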
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
index 2c2cfbebff..52823d3fca 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
@@ -30,7 +30,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSet(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string), 2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(2): # Run twice to verify statelessness of db operations.
sess.run(
init_op,
@@ -48,7 +48,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetJoinQuery(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -67,7 +67,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetNullTerminator(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -86,7 +86,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetReuseSqlDataset(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -114,7 +114,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadEmptyResultSet(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -128,7 +128,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetWithInvalidDriverName(self):
init_op = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))[0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(
init_op,
@@ -142,7 +142,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetWithInvalidColumnName(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -157,7 +157,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetOfQueryWithSyntaxError(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -173,7 +173,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -190,7 +190,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetOfInsertQuery(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -205,7 +205,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# place it in an `int8` tensor.
def testReadResultSetInt8(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -222,7 +222,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetInt8NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8,
dtypes.int8))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -238,7 +238,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# a SQLite database table and place it in an `int8` tensor.
def testReadResultSetInt8MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.int8, dtypes.int8))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -256,7 +256,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# place it in an `int16` tensor.
def testReadResultSetInt16(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -273,7 +273,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetInt16NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16,
dtypes.int16))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -289,7 +289,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# a SQLite database table and place it in an `int16` tensor.
def testReadResultSetInt16MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -307,7 +307,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# place it in an `int32` tensor.
def testReadResultSetInt32(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -321,7 +321,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# SQLite database table and place it in an `int32` tensor.
def testReadResultSetInt32NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -337,7 +337,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# a SQLite database table and place it in an `int32` tensor.
def testReadResultSetInt32MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -355,7 +355,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# table and place it in an `int32` tensor.
def testReadResultSetInt32VarCharColumnAsInt(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -371,7 +371,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# and place it in an `int64` tensor.
def testReadResultSetInt64(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -387,7 +387,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# SQLite database table and place it in an `int64` tensor.
def testReadResultSetInt64NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -403,7 +403,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# a SQLite database table and place it in an `int64` tensor.
def testReadResultSetInt64MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -422,7 +422,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# place it in a `uint8` tensor.
def testReadResultSetUInt8(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -438,7 +438,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# SQLite database table and place them in `uint8` tensors.
def testReadResultSetUInt8MinAndMaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -456,7 +456,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# and place it in a `uint16` tensor.
def testReadResultSetUInt16(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -472,7 +472,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# SQLite database table and place them in `uint16` tensors.
def testReadResultSetUInt16MinAndMaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -491,7 +491,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# in `bool` tensors.
def testReadResultSetBool(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -508,7 +508,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# from a SQLite database table and place it as `True` in a `bool` tensor.
def testReadResultSetBoolNotZeroOrOne(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -525,7 +525,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetFloat64(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.float64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -544,7 +544,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetFloat64OverlyPrecise(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.float64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -570,7 +570,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.float64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
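
Throughout these tests the driver name, data source, and query are fed through placeholders; stripped of that indirection, the dataset under test is constructed roughly as follows (the database path, table, and columns here are made up):

    dataset = readers.SqlDataset(
        driver_name="sqlite",
        data_source_name="/tmp/example.db",
        query="SELECT first_name, last_name FROM students",
        output_types=(dtypes.string, dtypes.string))
    # Each element is one result row, as a tuple of the declared types.
    get_next = dataset.make_one_shot_iterator().get_next()
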
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
index a41d21f8c1..e25570c5ad 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
@@ -19,7 +19,6 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
@@ -76,6 +75,31 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
sess.run(next_element)
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
+ def testPrefetchBufferUtilization(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(100).map(
+ lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(
+ -1).apply(stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(100):
+ self.assertAllEqual(
+ np.array([i] * i, dtype=np.int64), sess.run(next_element))
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
+ float(i + 1))
+ self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization",
+ 0, 1)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
+ 100)
+
def testReinitialize(self):
stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
@@ -175,44 +199,5 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
-class FeatureStatsDatasetTest(
- stats_dataset_test_base.StatsDatasetTestBase,
- reader_dataset_ops_test_base.ReadBatchFeaturesTestBase):
-
- def testFeaturesStats(self):
- num_epochs = 5
- total_records = num_epochs * self._num_records
- batch_size = 2
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = self.make_batch_feature(
- filenames=self.test_filenames[0],
- num_epochs=num_epochs,
- batch_size=batch_size,
- shuffle=True,
- shuffle_seed=5,
- drop_final_batch=True).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.test_session() as sess:
- sess.run(iterator.initializer)
- for _ in range(total_records // batch_size):
- sess.run(next_element)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
- self._assertSummaryHasCount(
- sess.run(summary_t), "record_stats:features", total_records)
- self._assertSummaryHasCount(
- sess.run(summary_t), "record_stats:feature-values", total_records)
- self._assertSummaryHasSum(
- sess.run(summary_t), "record_stats:features", total_records * 3)
- self._assertSummaryHasSum(
- sess.run(summary_t), "record_stats:feature-values",
- self._sum_keywords(1) * num_epochs + 2 * total_records)
-
-
if __name__ == "__main__":
test.main()
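
The new `testPrefetchBufferUtilization` above follows the standard stats-collection pattern: attach a `StatsAggregator` to the pipeline with `set_stats_aggregator`, then pull serialized `Summary` protos from `get_summary()`. A minimal sketch (the pipeline is illustrative):

    stats_aggregator = stats_ops.StatsAggregator()
    dataset = dataset_ops.Dataset.range(100).prefetch(10).apply(
        stats_ops.set_stats_aggregator(stats_aggregator))
    # get_summary() returns a scalar string tensor holding a serialized
    # Summary proto; histograms such as Prefetch::buffer_utilization appear
    # among its values once the iterator has run.
    summary_t = stats_aggregator.get_summary()
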
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
index 9a13acf8f0..2f5a44408f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
@@ -34,6 +34,16 @@ class StatsDatasetTestBase(test.TestCase):
return
self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+ def _assertSummaryHasRange(self, summary_str, tag, min_value, max_value):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ self.assertLessEqual(min_value, value.histo.min)
+ self.assertGreaterEqual(max_value, value.histo.max)
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
def _assertSummaryHasSum(self, summary_str, tag, expected_value):
summary_proto = summary_pb2.Summary()
summary_proto.ParseFromString(summary_str)
diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
new file mode 100644
index 0000000000..4c3353fe40
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
@@ -0,0 +1,73 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test utilities for tf.data functionality."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class DatasetTestBase(test.TestCase):
+ """Base class for dataset tests."""
+
+ def _assert_datasets_equal(self, dataset1, dataset2):
+ # TODO(rachelim): support sparse tensor outputs
+ next1 = dataset1.make_one_shot_iterator().get_next()
+ next2 = dataset2.make_one_shot_iterator().get_next()
+ with self.cached_session() as sess:
+ while True:
+ try:
+ op1 = sess.run(next1)
+ except errors.OutOfRangeError:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next2)
+ break
+ op2 = sess.run(next2)
+
+ op1 = nest.flatten(op1)
+ op2 = nest.flatten(op2)
+ assert len(op1) == len(op2)
+ for i in range(len(op1)):
+ self.assertAllEqual(op1[i], op2[i])
+
+ def _assert_datasets_raise_same_error(self,
+ dataset1,
+ dataset2,
+ exception_class,
+ replacements=None):
+ # We are defining next1 and next2 in the same line so that we get identical
+ # file:line_number in the error messages
+ # pylint: disable=line-too-long
+ next1, next2 = dataset1.make_one_shot_iterator().get_next(), dataset2.make_one_shot_iterator().get_next()
+ # pylint: enable=line-too-long
+ with self.cached_session() as sess:
+ try:
+ sess.run(next1)
+ raise ValueError(
+ "Expected dataset to raise an error of type %s, but it did not." %
+ repr(exception_class))
+ except exception_class as e:
+ expected_message = e.message
+ if replacements:
+   for old, new, count in replacements:
+     expected_message = expected_message.replace(old, new, count)
+ # Check that the error messages match, up to the given replacements.
+ with self.assertRaisesRegexp(exception_class,
+ re.escape(expected_message)):
+ sess.run(next2)
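A hypothetical test built on the new base class, to show how the helpers are meant to be used:

```python
from tensorflow.contrib.data.python.kernel_tests import test_utils
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test


class MapIdentityTest(test_utils.DatasetTestBase):

  def testIdentityMapIsEquivalent(self):
    # An identity map must leave the dataset's elements unchanged.
    dataset = dataset_ops.Dataset.range(10)
    self._assert_datasets_equal(dataset, dataset.map(lambda x: x))


if __name__ == "__main__":
  test.main()
```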
diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
index 0486e2bce2..8d335e87d5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
@@ -33,8 +33,17 @@ from tensorflow.python.platform import test
class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase):
- @parameterized.parameters((1, None), (2, None), (4, None), (8, None),
- (16, None), (4, -1), (4, 0), (4, 1), (4, 4))
+ @parameterized.named_parameters(
+ ("1", 1, None),
+ ("2", 2, None),
+ ("3", 4, None),
+ ("4", 8, None),
+ ("5", 16, None),
+ ("6", 4, -1),
+ ("7", 4, 0),
+ ("8", 4, 1),
+ ("9", 4, 4),
+ )
def testNumThreads(self, num_threads, max_intra_op_parallelism):
def get_thread_id(_):
@@ -60,7 +69,7 @@ class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
thread_ids = []
try:
diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
index d79a842e7a..f994c8563f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
@@ -45,7 +45,7 @@ class UniqueDatasetTest(test.TestCase):
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for test_case, expected in test_cases:
current_test_case = test_case
sess.run(iterator.initializer)
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
index 33d95d6754..6eaa0b1959 100644
--- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
@@ -64,15 +64,15 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
else:
self.assertEqual(xs, ys)
- @parameterized.parameters(
- (None, np.int32([]), dtypes.bool),
- (None, np.int32([]), dtypes.int32),
- (None, np.int32([]), dtypes.float32),
- (None, np.int32([]), dtypes.string),
- (None, np.int32([2]), dtypes.int32),
- (None, np.int32([2, 2]), dtypes.int32),
- ((None, None, None), np.int32([]), dtypes.int32),
- ((None, (None, None)), np.int32([]), dtypes.int32),
+ @parameterized.named_parameters(
+ ("1", None, np.int32([]), dtypes.bool),
+ ("2", None, np.int32([]), dtypes.int32),
+ ("3", None, np.int32([]), dtypes.float32),
+ ("4", None, np.int32([]), dtypes.string),
+ ("5", None, np.int32([2]), dtypes.int32),
+ ("6", None, np.int32([2, 2]), dtypes.int32),
+ ("7", (None, None, None), np.int32([]), dtypes.int32),
+ ("8", (None, (None, None)), np.int32([]), dtypes.int32),
)
def testWindowDatasetFlatMap(self, structure, shape, dtype):
"""Tests windowing by chaining it with flat map.
@@ -92,20 +92,20 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
dataset = self._structuredDataset(structure, shape, dtype).apply(
grouping.window_dataset(5)).flat_map(fn)
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
expected = sess.run(self._structuredElement(structure, shape, dtype))
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (None, np.int32([]), dtypes.bool),
- (None, np.int32([]), dtypes.int32),
- (None, np.int32([]), dtypes.float32),
- (None, np.int32([]), dtypes.string),
- (None, np.int32([2]), dtypes.int32),
- (None, np.int32([2, 2]), dtypes.int32),
- ((None, None, None), np.int32([]), dtypes.int32),
- ((None, (None, None)), np.int32([]), dtypes.int32),
+ @parameterized.named_parameters(
+ ("1", None, np.int32([]), dtypes.bool),
+ ("2", None, np.int32([]), dtypes.int32),
+ ("3", None, np.int32([]), dtypes.float32),
+ ("4", None, np.int32([]), dtypes.string),
+ ("5", None, np.int32([2]), dtypes.int32),
+ ("6", None, np.int32([2, 2]), dtypes.int32),
+ ("7", (None, None, None), np.int32([]), dtypes.int32),
+ ("8", (None, (None, None)), np.int32([]), dtypes.int32),
)
def testWindowDatasetBatchDense(self, structure, shape, dtype):
"""Tests batching of dense tensor windows.
@@ -128,17 +128,17 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply(
grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
expected = sess.run(
self._structuredElement(structure, np.concatenate(
([5], shape), axis=0), dtype))
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int32([]),),
- (np.int32([1]),),
- (np.int32([1, 2, 3]),),
+ @parameterized.named_parameters(
+ ("1", np.int32([])),
+ ("2", np.int32([1])),
+ ("3", np.int32([1, 2, 3])),
)
def testWindowDatasetBatchDenseDynamicShape(self, shape):
"""Tests batching of dynamically shaped dense tensor windows.
@@ -155,7 +155,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, {shape_t: shape})
expected = sess.run(
self._structuredElement(None, np.concatenate(([5], shape), axis=0),
@@ -203,15 +203,15 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
for substructure in structure
])
- @parameterized.parameters(
- (None, np.int32([]), dtypes.bool),
- (None, np.int32([]), dtypes.int32),
- (None, np.int32([]), dtypes.float32),
- (None, np.int32([]), dtypes.string),
- (None, np.int32([2]), dtypes.int32),
- (None, np.int32([2, 2]), dtypes.int32),
- ((None, None, None), np.int32([]), dtypes.int32),
- ((None, (None, None)), np.int32([]), dtypes.int32),
+ @parameterized.named_parameters(
+ ("1", None, np.int32([]), dtypes.bool),
+ ("2", None, np.int32([]), dtypes.int32),
+ ("3", None, np.int32([]), dtypes.float32),
+ ("4", None, np.int32([]), dtypes.string),
+ ("5", None, np.int32([2]), dtypes.int32),
+ ("6", None, np.int32([2, 2]), dtypes.int32),
+ ("7", (None, None, None), np.int32([]), dtypes.int32),
+ ("8", (None, (None, None)), np.int32([]), dtypes.int32),
)
def testWindowDatasetBatchSparse(self, structure, shape, dtype):
"""Tests batching of sparse tensor windows.
@@ -235,7 +235,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
structure, shape, dtype).repeat(5).apply(
grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
expected = sess.run(
self._structuredSparseElement(structure,
np.concatenate(([5], shape), axis=0),
@@ -243,10 +243,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int32([]),),
- (np.int32([1]),),
- (np.int32([1, 2, 3]),),
+ @parameterized.named_parameters(
+ ("1", np.int32([])),
+ ("2", np.int32([1])),
+ ("3", np.int32([1, 2, 3])),
)
def testWindowDatasetBatchSparseDynamicShape(self, shape):
"""Tests batching of dynamically shaped sparse tensor windows.
@@ -263,7 +263,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, {shape_t: shape})
expected = sess.run(
self._structuredSparseElement(None,
@@ -284,17 +284,18 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
for substructure in structure
]))
- @parameterized.parameters(
- (None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.string, [-1]),
- (None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
- (None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]),
- ((None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- ((None, (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])),
+ @parameterized.named_parameters(
+ ("1", None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]),
+ ("2", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("3", None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]),
+ ("4", None, np.int32([[1], [2], [3]]), dtypes.string, [-1]),
+ ("5", None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
+ ("6", None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]),
+ ("7", (None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("8", (None,
+ (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("9", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("10", None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])),
)
def testWindowDatasetPaddedBatchDense(self, structure, shapes, dtype,
padded_shape):
@@ -320,7 +321,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
grouping.window_dataset(len(shapes))).apply(
grouping._map_x_dataset(fn))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
expected = sess.run(
self._structuredElement(
@@ -329,10 +330,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int32([[1], [2], [3]]), [-1]),
- (np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
- (np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
+ @parameterized.named_parameters(
+ ("1", np.int32([[1], [2], [3]]), [-1]),
+ ("2", np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
+ ("3", np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
)
def testWindowDatasetPaddedBatchDenseDynamicShape(self, shapes, padded_shape):
"""Tests padded batching of dynamically shaped dense tensor windows.
@@ -351,7 +352,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, {shapes_t: shapes})
expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
expected = sess.run(
@@ -361,9 +362,9 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int32([[1]]), np.int32([0])),
- (np.int32([[10], [20]]), np.int32([15])),
+ @parameterized.named_parameters(
+ ("1", np.int32([[1]]), np.int32([0])),
+ ("2", np.int32([[10], [20]]), np.int32([15])),
)
def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape):
"""Tests invalid padded batching of dense tensor windows.
@@ -379,7 +380,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
grouping._map_x_dataset(
lambda x: batching.padded_batch_window(x, padded_shape)))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
@@ -420,17 +421,18 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
for substructure in structure
])
- @parameterized.parameters(
- (None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.string, [-1]),
- (None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
- (None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]),
- ((None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- ((None, (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])),
+ @parameterized.named_parameters(
+ ("1", None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]),
+ ("2", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("3", None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]),
+ ("4", None, np.int64([[1], [2], [3]]), dtypes.string, [-1]),
+ ("5", None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
+ ("6", None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]),
+ ("7", (None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("8", (None,
+ (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("9", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("10", None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])),
)
def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype,
padded_shape):
@@ -456,17 +458,17 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
structure, shapes, dtype).apply(grouping.window_dataset(
len(shapes))).apply(grouping._map_x_dataset(fn))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
expected = sess.run(
self._structuredRaggedSparseElement(structure, shapes, dtype,
padded_shape))
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int64([[1], [2], [3]]), [-1]),
- (np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
- (np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
+ @parameterized.named_parameters(
+ ("1", np.int64([[1], [2], [3]]), [-1]),
+ ("2", np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
+ ("3", np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
)
def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes,
padded_shape):
@@ -487,7 +489,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, {shapes_t: shapes})
expected = sess.run(
self._structuredRaggedSparseElement(None, shapes, dtypes.int32,
@@ -495,9 +497,9 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int64([[1]]), [0]),
- (np.int64([[10], [20]]), [15]),
+ @parameterized.named_parameters(
+ ("1", np.int64([[1]]), [0]),
+ ("2", np.int64([[10], [20]]), [15]),
)
def testWindowDatasetPaddedBatchSparseInvalid(self, shapes, padded_shape):
"""Tests invalid padded batching of sparse tensor windows.
@@ -514,7 +516,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
grouping._map_x_dataset(
lambda x: batching.padded_batch_window(x, padded_shape)))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
diff --git a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
index c603ecc5ab..867ee2ba37 100644
--- a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
@@ -61,7 +61,7 @@ class TFRecordWriterTest(test.TestCase):
return os.path.join(self.get_temp_dir(), "tf_record.out.txt")
def testWrite(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.writer, feed_dict={
self.filename: self._createFile(),
@@ -71,7 +71,7 @@ class TFRecordWriterTest(test.TestCase):
def testWriteZLIB(self):
options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.writer,
feed_dict={
@@ -84,7 +84,7 @@ class TFRecordWriterTest(test.TestCase):
def testWriteGZIP(self):
options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.writer,
feed_dict={
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index ad9378dfb9..4b45cc7e36 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -80,17 +80,14 @@ py_library(
":batching",
":gen_dataset_ops",
":interleave_ops",
+ ":parsing_ops",
":shuffle_ops",
- ":stats_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:lib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:parsing_ops",
"//tensorflow/python:platform",
- "//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
@@ -211,6 +208,22 @@ py_library(
)
py_library(
+ name = "parsing_ops",
+ srcs = ["parsing_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+py_library(
name = "map_defun",
srcs = ["map_defun.py"],
srcs_version = "PY2AND3",
@@ -331,7 +344,10 @@ py_library(
tf_gen_op_wrapper_py(
name = "gen_dataset_ops",
out = "gen_dataset_ops.py",
- deps = ["//tensorflow/contrib/data:dataset_ops_op_lib"],
+ deps = [
+ "//tensorflow/contrib/data:dataset_ops_op_lib",
+ "//tensorflow/contrib/data:indexed_dataset_ops_op_lib",
+ ],
)
tf_kernel_library(
@@ -349,6 +365,7 @@ tf_custom_op_py_library(
dso = ["//tensorflow/contrib/data:_dataset_ops.so"],
kernels = [
":dataset_ops_kernels",
+ "//tensorflow/contrib/data:indexed_dataset_ops_op_lib",
"//tensorflow/contrib/data:dataset_ops_op_lib",
],
srcs_version = "PY2AND3",
@@ -360,6 +377,19 @@ tf_custom_op_py_library(
)
py_library(
+ name = "indexed_dataset_ops",
+ srcs = ["indexed_dataset_ops.py"],
+ deps = [
+ ":contrib_op_loader",
+ ":gen_dataset_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
name = "prefetching_ops",
srcs = ["prefetching_ops.py"],
deps = [
@@ -380,6 +410,7 @@ py_library(
":error_ops",
":get_single_element",
":grouping",
+ ":indexed_dataset_ops",
":interleave_ops",
":map_defun",
":optimization",
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 9f059942a6..367c159dc5 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -272,9 +272,9 @@ def _padded_batch_dense_window(dataset, padded_shape, padding_value=None):
padding_value = 0
def batch_init_fn(_):
- return array_ops.fill(
- array_ops.concat([np.array([0], dtype=np.int32), padded_shape], 0),
- constant_op.constant(padding_value, dtype=dataset.output_types))
+ batch_shape = array_ops.concat(
+ [np.array([0], dtype=np.int32), padded_shape], 0)
+ return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)
def batch_reduce_fn(state, value):
return array_ops.concat([state, [value]], 0)
@@ -647,15 +647,17 @@ def assert_element_shape(expected_shapes):
"""Assert the shape of this `Dataset`.
```python
- shapes = [tf.TensorShape([16, 256]), tf.TensorShape(None)]
+ shapes = [tf.TensorShape([16, 256]), tf.TensorShape([None, 2])]
result = dataset.apply(tf.contrib.data.assert_element_shape(shapes))
- print(result.output_shapes) # ==> "((16, 256), <unknown>)"
+ print(result.output_shapes) # ==> "((16, 256), (<unknown>, 2))"
```
If dataset shapes and expected_shapes are fully defined, assert they match.
Otherwise, add an assert op that will validate the shapes when tensors are
evaluated, and set the shapes on the tensors, respectively.
+ Note that unknown dimensions in `expected_shapes` will be ignored.
+
Args:
expected_shapes: A nested structure of `tf.TensorShape` objects.
@@ -664,20 +666,31 @@ def assert_element_shape(expected_shapes):
`tf.data.Dataset.apply`
"""
+ def _merge_output_shapes(original_shapes, expected_shapes):
+ flat_original_shapes = nest.flatten(original_shapes)
+ flat_new_shapes = nest.flatten_up_to(original_shapes, expected_shapes)
+ flat_merged_output_shapes = [
+ original_shape.merge_with(new_shape)
+ for original_shape, new_shape in zip(flat_original_shapes,
+ flat_new_shapes)]
+ return nest.pack_sequence_as(original_shapes, flat_merged_output_shapes)
+
def _check_shape(*elements):
flatten_tensors = nest.flatten(elements)
flatten_shapes = nest.flatten(expected_shapes)
checked_tensors = [
- with_shape(shape, tensor)
+ with_shape(shape, tensor) if shape else tensor # Ignore unknown shape
for shape, tensor in zip(flatten_shapes, flatten_tensors)
]
return nest.pack_sequence_as(elements, checked_tensors)
def _apply_fn(dataset):
+ output_shapes = _merge_output_shapes(dataset.output_shapes,
+ expected_shapes)
return _RestructuredDataset(
dataset.map(_check_shape),
dataset.output_types,
- output_shapes=expected_shapes,
+ output_shapes=output_shapes,
output_classes=dataset.output_classes)
return _apply_fn
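A quick sketch of the `merge_with` semantics that `_merge_output_shapes` relies on (illustrative shapes):

```python
import tensorflow as tf

# Known dimensions in the expected shape refine the dataset's inferred shape;
# unknown (None) dimensions are ignored.
original = tf.TensorShape([None, 2])
expected = tf.TensorShape([16, None])
print(original.merge_with(expected))  # ==> (16, 2)
```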
diff --git a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
new file mode 100644
index 0000000000..a0932b4081
--- /dev/null
+++ b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
@@ -0,0 +1,173 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrappers for indexed datasets."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
+from tensorflow.contrib.data.python.ops import gen_dataset_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+
+
+class MaterializedIndexedDataset(object):
+ """MaterializedIndexedDataset is highly experimental!
+ """
+
+ def __init__(self, materialized_resource, materializer, output_classes,
+ output_types, output_shapes):
+ self._materialized_resource = materialized_resource
+ self._materializer = materializer
+ self._output_classes = output_classes
+ self._output_types = output_types
+ self._output_shapes = output_shapes
+
+ @property
+ def initializer(self):
+ if self._materializer is not None:
+ return self._materializer
+ raise ValueError("MaterializedDataset does not have a materializer")
+
+ def get(self, index):
+ """Get retrieves a value (or set of values) from the IndexedDataset.
+
+ Args:
+ index: A uint64 scalar or vector tensor with the indices to retrieve.
+
+ Returns:
+ A tensor containing the values corresponding to `index`.
+ """
+ # TODO(saeta): nest.pack_sequence_as(...)
+ return gen_dataset_ops.indexed_dataset_get(
+ self._materialized_resource,
+ index,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self._output_types, self._output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(self._output_shapes, self._output_classes)))
+
+
+class IndexedDataset(dataset_ops.Dataset):
+ """IndexedDataset is highly experimental!
+ """
+
+ def __init__(self):
+ pass
+
+ def materialize(self, shared_name=None, container=None):
+ """Materialize creates a MaterializedIndexedDataset.
+
+ IndexedDatasets can be combined through operations such as TBD. Therefore,
+ they are only materialized when absolutely required.
+
+ Args:
+ shared_name: a string for the shared name to use for the resource.
+ container: a string for the container to store the resource.
+
+ Returns:
+ A MaterializedIndexedDataset.
+ """
+ if container is None:
+ container = ""
+ if shared_name is None:
+ shared_name = ""
+ materialized_resource = gen_dataset_ops.materialized_index_dataset_handle(
+ container=container,
+ shared_name=shared_name,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+
+ with ops.colocate_with(materialized_resource):
+ materializer = gen_dataset_ops.indexed_dataset_materialize(
+ self._as_variant_tensor(), materialized_resource)
+ return MaterializedIndexedDataset(materialized_resource, materializer,
+ self.output_classes, self.output_types,
+ self.output_shapes)
+
+ @abc.abstractproperty
+ def output_types(self):
+ """Returns the type of each component of an element of this IndexedDataset.
+
+ Returns:
+ A nested structure of `tf.DType` objects corresponding to each component
+ of an element of this IndexedDataset.
+ """
+ raise NotImplementedError("IndexedDataset.output_types")
+
+ @abc.abstractproperty
+ def output_classes(self):
+ """Returns the class of each component of an element of this IndexedDataset.
+
+ The expected values are `tf.Tensor` and `tf.SparseTensor`.
+
+ Returns:
+ A nested structure of Python `type` objects corresponding to each
+ component of an element of this IndexedDataset.
+ """
+ raise NotImplementedError("IndexedDataset.output_classes")
+
+ @abc.abstractproperty
+ def output_shapes(self):
+ """Returns the shape of each component of an element of this IndexedDataset.
+
+ Returns:
+ A nested structure of `tf.TensorShape` objects corresponding to each
+ component of an element of this IndexedDataset.
+ """
+ raise NotImplementedError("IndexedDataset.output_shapes")
+
+ @abc.abstractmethod
+ def _as_variant_tensor(self):
+ """Creates a `tf.variant` `tf.Tensor` representing this IndexedDataset.
+
+ Returns:
+ A scalar `tf.Tensor` of `tf.variant` type, which represents this
+ IndexedDataset.
+ """
+ raise NotImplementedError("IndexedDataset._as_variant_tensor")
+
+
+class IdentityIndexedDataset(IndexedDataset):
+ """IdentityIndexedDataset is a trivial indexed dataset used for testing.
+ """
+
+ def __init__(self, size):
+ super(IdentityIndexedDataset, self).__init__()
+ # TODO(saeta): Verify _size is a scalar!
+ self._size = ops.convert_to_tensor(size, dtype=dtypes.uint64, name="size")
+
+ @property
+ def output_types(self):
+ return dtypes.uint64
+
+ @property
+ def output_classes(self):
+ return ops.Tensor
+
+ @property
+ def output_shapes(self):
+ return tensor_shape.scalar()
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.identity_indexed_dataset(self._size)
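A hypothetical end-to-end sketch of the experimental API above, assuming the contrib op library loads successfully:

```python
import tensorflow as tf
from tensorflow.contrib.data.python.ops import indexed_dataset_ops

ds = indexed_dataset_ops.IdentityIndexedDataset(16)
materialized = ds.materialize()
get_op = materialized.get(tf.constant(3, dtype=tf.uint64))

with tf.Session() as sess:
  sess.run(materialized.initializer)  # runs the materializer op
  print(sess.run(get_op))             # ==> 3
```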
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index 5a1a35199a..92d4251a86 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -163,7 +163,7 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
for data_input in data_inputs[1:]:
if (data_input.output_types != data_inputs[0].output_types or
data_input.output_classes != data_inputs[0].output_classes):
- raise TypeError("All datasets must have the same type.")
+ raise TypeError("All datasets must have the same type and class.")
def _as_variant_tensor(self):
# pylint: disable=protected-access
@@ -216,25 +216,59 @@ def sample_from_datasets(datasets, weights=None, seed=None):
length of the `datasets` element.
"""
num_datasets = len(datasets)
- if weights is None:
- weights = dataset_ops.Dataset.from_tensors([1.0] * num_datasets).repeat()
- elif not isinstance(weights, dataset_ops.Dataset):
- weights = ops.convert_to_tensor(weights, name="weights")
- if weights.dtype not in (dtypes.float32, dtypes.float64):
- raise TypeError("`weights` must be convertible to a tensor of "
- "`tf.float32` or `tf.float64` elements.")
- if not weights.shape.is_compatible_with([num_datasets]):
- raise ValueError("`weights` must be a vector of length `len(datasets)`.")
- weights = dataset_ops.Dataset.from_tensors(weights).repeat()
-
- # The `stateless_multinomial()` op expects log-probabilities, as opposed to
- # weights.
- logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))
- def select_dataset(logits, seed):
- return array_ops.squeeze(
- stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
- selector_input = dataset_ops.Dataset.zip(
- (logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset)
+ if not isinstance(weights, dataset_ops.Dataset):
+ if weights is None:
+ # Select inputs with uniform probability.
+ logits = [[1.0] * num_datasets]
+
+ else:
+ # Use the given `weights` as the probability of choosing the respective
+ # input.
+ weights = ops.convert_to_tensor(weights, name="weights")
+ if weights.dtype not in (dtypes.float32, dtypes.float64):
+ raise TypeError("`weights` must be convertible to a tensor of "
+ "`tf.float32` or `tf.float64` elements.")
+ if not weights.shape.is_compatible_with([num_datasets]):
+ raise ValueError(
+ "`weights` must be a vector of length `len(datasets)`.")
+
+ # The `stateless_multinomial()` op expects log-probabilities, as opposed
+ # to weights.
+ logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0)
+
+ # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When it
+ # is a `Dataset`, it is possible that evaluating it has a side effect the
+ # user depends on.
+ if len(datasets) == 1:
+ return datasets[0]
+
+ def select_dataset_constant_logits(seed):
+ return array_ops.squeeze(
+ stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
+
+ selector_input = dataset_ops.MapDataset(
+ random_ops.RandomDataset(seed).batch(2),
+ select_dataset_constant_logits,
+ use_inter_op_parallelism=False)
+
+ else:
+ # Use each element of the given `weights` dataset as the probability of
+ # choosing the respective input.
+
+ # The `stateless_multinomial()` op expects log-probabilities, as opposed to
+ # weights.
+ logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))
+
+ def select_dataset_varying_logits(logits, seed):
+ return array_ops.squeeze(
+ stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
+
+ logits_and_seeds = dataset_ops.Dataset.zip(
+ (logits_ds, random_ops.RandomDataset(seed).batch(2)))
+ selector_input = dataset_ops.MapDataset(
+ logits_and_seeds,
+ select_dataset_varying_logits,
+ use_inter_op_parallelism=False)
return _DirectedInterleaveDataset(selector_input, datasets)
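A usage sketch for the constant-weights fast path introduced above (dataset contents illustrative):

```python
import tensorflow as tf

# Sample from two infinite datasets with 90/10 odds; with constant weights
# the logits are computed once rather than via a zipped weights dataset.
heads = tf.data.Dataset.from_tensors(0).repeat()
tails = tf.data.Dataset.from_tensors(1).repeat()
mixed = tf.contrib.data.sample_from_datasets([heads, tails],
                                             weights=[0.9, 0.1])
```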
diff --git a/tensorflow/contrib/data/python/ops/map_defun.py b/tensorflow/contrib/data/python/ops/map_defun.py
index 54d5cd6da0..3d0d0993c9 100644
--- a/tensorflow/contrib/data/python/ops/map_defun.py
+++ b/tensorflow/contrib/data/python/ops/map_defun.py
@@ -53,6 +53,4 @@ def map_defun(fn, elems, output_dtypes, output_shapes):
elems = [ops.convert_to_tensor(e) for e in elems]
output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes]
- if not all(s.is_fully_defined() for s in output_shapes):
- raise ValueError("All fn output shapes must be fully defined.")
return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn)
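With the fully-defined-shape check removed, a call like the following sketch (mirroring the `function.Defun` pattern used in the map_defun unit tests) should now accept unknown output shapes:

```python
from tensorflow.contrib.data.python.ops import map_defun
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function


@function.Defun(dtypes.int64)
def double(x):
  return x * 2

elems = constant_op.constant([1, 2, 3], dtype=dtypes.int64)
# An output shape of None (unknown) was previously rejected.
result = map_defun.map_defun(double, [elems], [dtypes.int64], [None])[0]
```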
diff --git a/tensorflow/contrib/data/python/ops/parsing_ops.py b/tensorflow/contrib/data/python/ops/parsing_ops.py
new file mode 100644
index 0000000000..2701605e64
--- /dev/null
+++ b/tensorflow/contrib/data/python/ops/parsing_ops.py
@@ -0,0 +1,150 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental `dataset` API for parsing example."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import parsing_ops
+
+
+class _ParseExampleDataset(dataset_ops.Dataset):
+ """A `Dataset` that parses `example` dataset into a `dict` dataset."""
+
+ def __init__(self, input_dataset, features, num_parallel_calls):
+ super(_ParseExampleDataset, self).__init__()
+ self._input_dataset = input_dataset
+ if not all(types == dtypes.string
+ for types in nest.flatten(input_dataset.output_types)):
+ raise TypeError("Input dataset should be a dataset of vectors of strings")
+ self._num_parallel_calls = num_parallel_calls
+ # pylint: disable=protected-access
+ self._features = parsing_ops._prepend_none_dimension(features)
+ # sparse_keys and dense_keys come back sorted here.
+ (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
+ dense_shapes) = parsing_ops._features_to_raw_params(
+ self._features, [
+ parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
+ parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature
+ ])
+ # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature.
+ (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes,
+ dense_shape_as_shape) = parsing_ops._process_raw_parameters(
+ None, dense_defaults, sparse_keys, sparse_types, dense_keys,
+ dense_types, dense_shapes)
+ # pylint: enable=protected-access
+ self._sparse_keys = sparse_keys
+ self._sparse_types = sparse_types
+ self._dense_keys = dense_keys
+ self._dense_defaults = dense_defaults_vec
+ self._dense_shapes = dense_shapes
+ self._dense_types = dense_types
+ dense_output_shapes = [
+ self._input_dataset.output_shapes.concatenate(shape)
+ for shape in dense_shape_as_shape
+ ]
+ sparse_output_shapes = [
+ self._input_dataset.output_shapes.concatenate([None])
+ for _ in range(len(sparse_keys))
+ ]
+
+ self._output_shapes = dict(
+ zip(self._dense_keys + self._sparse_keys,
+ dense_output_shapes + sparse_output_shapes))
+ self._output_types = dict(
+ zip(self._dense_keys + self._sparse_keys,
+ self._dense_types + self._sparse_types))
+ self._output_classes = dict(
+ zip(self._dense_keys + self._sparse_keys,
+ [ops.Tensor for _ in range(len(self._dense_defaults))] +
+ [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
+ ]))
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.parse_example_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._num_parallel_calls,
+ self._dense_defaults,
+ self._sparse_keys,
+ self._dense_keys,
+ self._sparse_types,
+ self._dense_shapes,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+
+# TODO(b/111553342): add arguments names and example names as well.
+def parse_example_dataset(features, num_parallel_calls=1):
+ """A transformation that parses `Example` protos into a `dict` of tensors.
+
+ Parses a number of serialized `Example` protos from the input dataset. Each
+ input element is treated as a batch with `batch_size` many entries of
+ individual `Example` protos.
+
+ This op parses serialized examples into a dictionary mapping keys to `Tensor`
+ and `SparseTensor` objects. `features` is a dict from keys to `VarLenFeature`,
+ `SparseFeature`, and `FixedLenFeature` objects. Each `VarLenFeature`
+ and `SparseFeature` is mapped to a `SparseTensor`, and each
+ `FixedLenFeature` is mapped to a `Tensor`. See `tf.parse_example` for more
+ details about feature dictionaries.
+
+ Args:
+ features: A `dict` mapping feature keys to `FixedLenFeature`,
+ `VarLenFeature`, and `SparseFeature` values.
+ num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
+ representing the number of parsing processes to call in parallel.
+
+ Returns:
+ A dataset transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+
+ Raises:
+ ValueError: if features argument is None.
+ """
+ if features is None:
+ raise ValueError("Missing: features was %s." % features)
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls)
+ if any(
+     isinstance(feature, parsing_ops.SparseFeature)
+     for feature in features.values()):
+ # pylint: disable=protected-access
+ # pylint: disable=g-long-lambda
+ out_dataset = out_dataset.map(
+ lambda x: parsing_ops._construct_sparse_tensors_for_sparse_features(
+ features, x), num_parallel_calls=num_parallel_calls)
+ return out_dataset
+
+ return _apply_fn
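A hypothetical usage of the new transformation, batching serialized `tf.Example` protos before parsing (file name and feature spec are illustrative):

```python
import tensorflow as tf
from tensorflow.contrib.data.python.ops import parsing_ops as contrib_parsing

features = {
    "label": tf.FixedLenFeature([], tf.int64),
    "tokens": tf.VarLenFeature(tf.string),
}
dataset = (tf.data.TFRecordDataset("examples.tfrecord")
           .batch(32)
           .apply(contrib_parsing.parse_example_dataset(
               features, num_parallel_calls=4)))
```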
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 3882d4bfdb..4c466781f7 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -25,8 +25,8 @@ import numpy as np
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.contrib.data.python.ops import parsing_ops
from tensorflow.contrib.data.python.ops import shuffle_ops
-from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.data.util import convert
@@ -37,7 +37,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util import deprecation
@@ -326,7 +325,6 @@ def make_csv_dataset(
shuffle_seed=None,
prefetch_buffer_size=1,
num_parallel_reads=1,
- num_parallel_parser_calls=2,
sloppy=False,
num_rows_for_inference=100,
compression_type=None,
@@ -393,8 +391,6 @@ def make_csv_dataset(
batches consumed per training step.
num_parallel_reads: Number of threads used to read CSV records from files.
If >1, the results will be interleaved.
- num_parallel_parser_calls: Number of parallel invocations of the CSV parsing
- function on CSV records.
sloppy: If `True`, reading performance will be improved at
the cost of non-deterministic ordering. If `False`, the order of elements
produced is deterministic prior to shuffling (elements are still
@@ -503,7 +499,8 @@ def make_csv_dataset(
# indefinitely, and all batches will be full-sized.
dataset = dataset.batch(batch_size=batch_size,
drop_remainder=num_epochs is None)
- dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_parser_calls)
+ dataset = dataset_ops.MapDataset(
+ dataset, map_fn, use_inter_op_parallelism=False)
dataset = dataset.prefetch(prefetch_buffer_size)
return dataset
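For reference, a sketch of calling `make_csv_dataset` after the removal of `num_parallel_parser_calls` (file pattern and label column are illustrative):

```python
import tensorflow as tf

dataset = tf.contrib.data.make_csv_dataset(
    "train-*.csv",
    batch_size=32,
    label_name="label",
    num_parallel_reads=2)
```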
@@ -663,6 +660,7 @@ def make_batched_features_dataset(file_pattern,
batch_size,
features,
reader=core_readers.TFRecordDataset,
+ label_key=None,
reader_args=None,
num_epochs=None,
shuffle=True,
@@ -675,6 +673,9 @@ def make_batched_features_dataset(file_pattern,
drop_final_batch=False):
"""Returns a `Dataset` of feature dictionaries from `Example` protos.
+ If the `label_key` argument is provided, returns a `Dataset` of tuples,
+ each comprising a feature dictionary and a label.
+
Example:
```
@@ -725,6 +726,9 @@ def make_batched_features_dataset(file_pattern,
reader: A function or class that can be
called with a `filenames` tensor and (optional) `reader_args` and returns
a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`.
+ label_key: (Optional) A string corresponding to the key under which labels
+   are stored in the `tf.Example` protos. If provided, it must be one of the
+   `features` keys; otherwise a `ValueError` is raised.
reader_args: Additional arguments to pass to the reader class.
num_epochs: Integer specifying the number of times to read through the
dataset. If None, cycles through the dataset forever. Defaults to `None`.
@@ -750,8 +754,11 @@ def make_batched_features_dataset(file_pattern,
`False`.
Returns:
- A dataset of `dict` elements. Each `dict` maps feature keys to
- `Tensor` or `SparseTensor` objects.
+ A dataset of `dict` elements, or a tuple of `dict` elements and a label.
+ Each `dict` maps feature keys to `Tensor` or `SparseTensor` objects.
+
+ Raises:
+ ValueError: If `label_key` is not one of the `features` keys.
"""
# Create dataset of all matching filenames
filenames = _get_file_names(file_pattern, False)
@@ -772,14 +779,13 @@ def make_batched_features_dataset(file_pattern,
# Extract values if the `Example` tensors are stored as key-value tuples.
if dataset.output_types == (dtypes.string, dtypes.string):
- dataset = dataset.map(lambda _, v: v)
+ dataset = dataset_ops.MapDataset(
+ dataset, lambda _, v: v, use_inter_op_parallelism=False)
# Apply dataset repeat and shuffle transformations.
dataset = _maybe_shuffle_and_repeat(
dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
- dataset = dataset.apply(stats_ops.feature_stats("record_stats"))
-
# NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
# improve the shape inference, because it makes the batch dimension static.
# It is safe to do this because in that case we are repeating the input
@@ -788,13 +794,17 @@ def make_batched_features_dataset(file_pattern,
batch_size, drop_remainder=drop_final_batch or num_epochs is None)
# Parse `Example` tensors to a dictionary of `Feature` tensors.
- dataset = dataset.map(
- lambda x: parsing_ops.parse_example(x, features),
- num_parallel_calls=parser_num_threads)
+ dataset = dataset.apply(
+ parsing_ops.parse_example_dataset(
+ features, num_parallel_calls=parser_num_threads))
+
+ if label_key:
+ if label_key not in features:
+ raise ValueError(
+ "The `label_key` provided (%r) must be one of the `features` keys." %
+ label_key)
+ dataset = dataset.map(lambda x: (x, x.pop(label_key)))
- # TODO(rachelim): Add an optional label_name argument for extracting the label
- # from the features dictionary, to comply with the type expected by the
- # input_fn to a `tf.Estimator.train` or `tf.Estimator.evaluate` function.
dataset = dataset.prefetch(prefetch_buffer_size)
return dataset
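A sketch of the new `label_key` path (paths and feature spec are illustrative): the parsed feature dict is split into `(features, label)` tuples.

```python
import tensorflow as tf

dataset = tf.contrib.data.make_batched_features_dataset(
    file_pattern="/tmp/data-*.tfrecord",
    batch_size=32,
    features={
        "age": tf.FixedLenFeature([], tf.int64),
        "label": tf.FixedLenFeature([], tf.int64),
    },
    label_key="label")
features, label = dataset.make_one_shot_iterator().get_next()
```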
@@ -974,3 +984,49 @@ class SqlDataset(dataset_ops.Dataset):
@property
def output_types(self):
return self._output_types
+
+
+class LMDBDataset(dataset_ops.Dataset):
+ """A LMDB Dataset that reads the lmdb file."""
+
+ def __init__(self, filenames):
+ """Create a `LMDBDataset`.
+
+ `LMDBDataset` allows a user to read data from an mdb file as
+ (key, value) pairs sequentially.
+ For example:
+ ```python
+ dataset = tf.contrib.lmdb.LMDBDataset("/foo/bar.mdb")
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+ # Prints the (key, value) pairs inside an lmdb file.
+ while True:
+ try:
+ print(sess.run(next_element))
+ except tf.errors.OutOfRangeError:
+ break
+ ```
+ Args:
+ filenames: A `tf.string` tensor containing one or more filenames.
+ """
+ super(LMDBDataset, self).__init__()
+ self._filenames = ops.convert_to_tensor(
+ filenames, dtype=dtypes.string, name="filenames")
+
+ def _as_variant_tensor(self):
+ return contrib_gen_dataset_ops.lmdb_dataset(
+ self._filenames,
+ output_types=nest.flatten(self.output_types),
+ output_shapes=nest.flatten(self.output_shapes))
+
+ @property
+ def output_classes(self):
+ return ops.Tensor, ops.Tensor
+
+ @property
+ def output_shapes(self):
+ return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([]))
+
+ @property
+ def output_types(self):
+ return dtypes.string, dtypes.string
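A complete version of the docstring example above, with an explicit session (the `.mdb` path is illustrative):

```python
import tensorflow as tf
from tensorflow.contrib.data.python.ops import readers

dataset = readers.LMDBDataset("/foo/bar.mdb")
next_element = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
  while True:
    try:
      print(sess.run(next_element))  # one (key, value) string pair
    except tf.errors.OutOfRangeError:
      break
```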
diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py
index 3b4e981402..8426228992 100644
--- a/tensorflow/contrib/data/python/ops/stats_ops.py
+++ b/tensorflow/contrib/data/python/ops/stats_ops.py
@@ -178,29 +178,6 @@ def latency_stats(tag):
return _apply_fn
-# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
-def feature_stats(tag):
- """Records the features stats from `Example` records of the input dataset.
-
- To consume the statistics, associate a `StatsAggregator` with the output
- dataset.
-
- Args:
- tag: String. All statistics recorded by the returned transformation will be
- associated with the given `tag`.
-
- Returns:
- A `Dataset` transformation function, which can be passed to
- `tf.data.Dataset.apply`.
- """
-
- def _apply_fn(dataset):
- return _StatsDataset(dataset, gen_dataset_ops.feature_stats_dataset, tag)
-
- return _apply_fn
-
-
class _StatsDataset(dataset_ops.Dataset):
"""A `Dataset` that acts as an identity, and also records statistics."""
diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD
index d3628d480d..a87a5624c8 100644
--- a/tensorflow/contrib/distribute/BUILD
+++ b/tensorflow/contrib/distribute/BUILD
@@ -29,12 +29,13 @@ py_library(
"//tensorflow/contrib/distribute/python:cross_tower_ops",
"//tensorflow/contrib/distribute/python:mirrored_strategy",
"//tensorflow/contrib/distribute/python:monitor",
- "//tensorflow/contrib/distribute/python:multi_worker_strategy",
"//tensorflow/contrib/distribute/python:one_device_strategy",
"//tensorflow/contrib/distribute/python:parameter_server_strategy",
"//tensorflow/contrib/distribute/python:step_fn",
"//tensorflow/contrib/distribute/python:tpu_strategy",
"//tensorflow/python:training",
"//tensorflow/python:util",
+ "//tensorflow/python/distribute:distribute_config",
+ "//tensorflow/python/distribute:distribute_coordinator",
],
)
diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md
index 2f5dd10550..30e1992c01 100644
--- a/tensorflow/contrib/distribute/README.md
+++ b/tensorflow/contrib/distribute/README.md
@@ -1,6 +1,6 @@
# Distribution Strategy
-> *NOTE*: This is a experimental feature. The API and performance
+> *NOTE*: This is an experimental feature. The API and performance
> characteristics are subject to change.
## Overview
@@ -9,29 +9,111 @@
API is an easy way to distribute your training
across multiple devices/machines. Our goal is to allow users to use existing
models and training code with minimal changes to enable distributed training.
-Moreover, we've design the API in such a way that it works with both eager and
+Moreover, we've designed the API in such a way that it works with both eager and
graph execution.
-Currently we support one type of strategy, called
-[`MirroredStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/MirroredStrategy).
-It does in-graph replication with synchronous training
+Currently we support several types of strategies:
+
+* [`MirroredStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/MirroredStrategy):
+This does in-graph replication with synchronous training
on many GPUs on one machine. Essentially, we create copies of all variables in
the model's layers on each device. We then use all-reduce to combine gradients
across the devices before applying them to the variables to keep them in sync.
-In the future, we intend to support other kinds of training configurations such
-as multi-node, synchronous,
-[asynchronous](https://www.tensorflow.org/deploy/distributed#putting_it_all_together_example_trainer_program),
-parameter servers and model parallelism.
+* [`CollectiveAllReduceStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/CollectiveAllReduceStrategy):
+This is a version of `MirroredStrategy` for multi-worker training. It uses
+a collective op to do all-reduce. This supports between-graph communication and
+synchronization, and delegates the specifics of the all-reduce implementation to
+the runtime (as opposed to encoding it in the graph). This allows it to perform
+optimizations like batching and to switch between plugins that support different
+hardware or algorithms. In the future, this strategy will implement
+fault-tolerance to allow training to continue when there is a worker failure.
+
+* [`ParameterServerStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/ParameterServerStrategy):
+This strategy supports using parameter servers either for multi-GPU local
+training or asynchronous multi-machine training. When used to train locally,
+variables are not mirrored; instead, they are placed on the CPU, and operations
+are replicated across all local GPUs. In a multi-machine setting, some machines
+are designated as workers and some as parameter servers. Each variable is
+placed on
+one parameter server. Computation operations are replicated across all GPUs of
+the workers.
+
+## Multi-GPU Training
+
+## Example with Keras API
+
+Let's see how to scale to multiple GPUs on one machine using `MirroredStrategy` with [tf.keras](https://www.tensorflow.org/guide/keras).
+
+Take a very simple model consisting of a single layer:
+
+```python
+inputs = tf.keras.layers.Input(shape=(1,))
+predictions = tf.keras.layers.Dense(1)(inputs)
+model = tf.keras.models.Model(inputs=inputs, outputs=predictions)
+```
-## Example
+Let's also define a simple input dataset for training this model. Note that currently we require using
+[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset)
+with `DistributionStrategy`.
+
+```python
+features = tf.data.Dataset.from_tensors([1.]).repeat(10000).batch(10)
+labels = tf.data.Dataset.from_tensors([1.]).repeat(10000).batch(10)
+train_dataset = tf.data.Dataset.zip((features, labels))
+```
-Let's demonstrate how to use this API with a simple example. We will use the
-[`Estimator`](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator)
-approach, and show you how to scale your model to run on multiple GPUs on one
-machine using `MirroredStrategy`.
-Let's consider a very simple model function which tries to learn a simple
-function.
+To distribute this Keras model on multiple GPUs using `MirroredStrategy` we
+first instantiate a `MirroredStrategy` object.
+
+```python
+distribution = tf.contrib.distribute.MirroredStrategy()
+```
+
+We then compile the Keras model and pass the `MirroredStrategy` object in the
+`distribute` argument (apart from other usual arguments like `loss` and
+`optimizer`).
+
+```python
+model.compile(loss='mean_squared_error',
+ optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.2),
+ distribute=distribution)
+```
+
+To train the model we call the Keras `fit` API using the input dataset that we
+created earlier, just as we would in the non-distributed case.
+
+```python
+model.fit(train_dataset, epochs=5, steps_per_epoch=10)
+```
+
+Similarly, we can also call `evaluate` and `predict` as before using appropriate
+datasets.
+
+```python
+model.evaluate(eval_dataset)
+model.predict(predict_dataset)
+```
+
+That's all you need to train your model with Keras on multiple GPUs with
+`MirroredStrategy`. It will take care of splitting up
+the input dataset, replicating layers and variables on each device, and
+combining and applying gradients.
+
+The model and input code do not have to change because we have changed the
+underlying components of TensorFlow (such as
+optimizer, batch norm and summaries) to become distribution-aware.
+That means those components know how to
+combine their state across devices. Further, saving and checkpointing works
+seamlessly, so you can save with one or no distribution strategy and resume with
+another.
+
+
+## Example with Estimator API
+
+You can also use the Distribution Strategy API with [`Estimator`](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator). Let's see a simple example of its usage with `MirroredStrategy`.
+
+
+Consider a very simple model function which tries to learn a simple function.
```python
def model_fn(features, labels, mode):
@@ -53,17 +135,14 @@ def model_fn(features, labels, mode):
return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
```
-Let's also define a simple input function to feed data for training this model.
-Note that we require using
-[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset)
-with `DistributionStrategy`.
+Again, let's define a simple input function to feed data for training this model.
```python
def input_fn():
features = tf.data.Dataset.from_tensors([[1.]]).repeat(100)
labels = tf.data.Dataset.from_tensors(1.).repeat(100)
- return dataset_ops.Dataset.zip((features, labels))
+ return tf.data.Dataset.zip((features, labels))
```
Now that we have a model function and input function defined, we can define the
@@ -80,20 +159,14 @@ distribution = tf.contrib.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(train_distribute=distribution)
classifier = tf.estimator.Estimator(model_fn=model_fn, config=config)
classifier.train(input_fn=input_fn)
+classifier.evaluate(input_fn=input_fn)
```
That's it! This change will now configure estimator to run on all GPUs on your
-machine, with the `MirroredStrategy` approach. It will take care of distributing
-the input dataset, replicating layers and variables on each device, and
-combining and applying gradients.
+machine.
-The model and input functions do not have to change because we have changed the
-underlying components of TensorFlow (such as
-optimizer, batch norm and summaries) to become distribution-aware.
-That means those components know how to
-combine their state across devices. Further, saving and checkpointing works
-seamlessly, so you can save with one or no distribution strategy and resume with
-another.
+
+## Customization and Performance Tips
Above, we showed the easiest way to use [`MirroredStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/MirroredStrategy#__init__).
There are few things you can customize in practice:
@@ -103,8 +176,6 @@ of GPUs (using param `num_gpus`), in case you don't want auto detection.
* You can specify various parameters for all reduce with the `cross_tower_ops`
param, such as the all reduce algorithm to use, and gradient repacking.
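+
+For example, a minimal sketch combining these options (the two GPUs and the
+"nccl" algorithm choice here are illustrative, not defaults):
+
+```python
+distribution = tf.contrib.distribute.MirroredStrategy(
+    num_gpus=2,
+    cross_tower_ops=tf.contrib.distribute.AllReduceCrossTowerOps(
+        all_reduce_alg="nccl", num_packs=1))
+```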
-## Performance Tips
-
We've tried to make it such that you get the best performance for your existing
model. We also recommend you follow the tips from
[Input Pipeline Performance Guide](https://www.tensorflow.org/performance/datasets_performance).
@@ -113,15 +184,177 @@ and [`dataset.prefetch`](https://www.tensorflow.org/performance/datasets_perform
in the input function gives a solid boost in performance. When using
`dataset.prefetch`, use `buffer_size=None` to let it detect optimal buffer size.
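+
+For example, a sketch of an input function applying these tips (the toy
+dataset and batch size are illustrative):
+
+```python
+def input_fn():
+  dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(10000)
+  # Fuse the map and batch transformations, then let prefetch pick its buffer
+  # size automatically.
+  dataset = dataset.apply(tf.contrib.data.map_and_batch(
+      lambda features, labels: (features, labels), batch_size=10))
+  return dataset.prefetch(buffer_size=None)
+```
+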
+## Multi-worker Training
+
+### Overview
+
+For multi-worker training, no code change is required in your `Estimator`
+code. You can run the same model code for all tasks in your cluster, including
+parameter servers and the evaluator. But you need to use
+`tf.estimator.train_and_evaluate`, explicitly specify `num_gpus_per_worker`
+for your strategy object, and set the "TF\_CONFIG" environment variable for
+each binary running in your cluster. We'll provide a Kubernetes template in the
+[tensorflow/ecosystem](https://github.com/tensorflow/ecosystem) repo which sets
+"TF\_CONFIG" for your training tasks.
+
+### TF\_CONFIG environment variable
+
+The "TF\_CONFIG" environment variables is a JSON string which specifies what
+tasks constitute a cluster, their addresses and each task's role in the cluster.
+One example of "TF\_CONFIG" is:
+
+```python
+TF_CONFIG='{
+    "cluster": {
+        "worker": ["host1:port", "host2:port", "host3:port"],
+        "ps": ["host4:port", "host5:port"]
+    },
+    "task": {"type": "worker", "index": 1}
+}'
+```
+
+This "TF\_CONFIG" specifies that there are three workers and two ps tasks in the
+cluster along with their hosts and ports. The "task" part specifies that the
+role of the current task in the cluster, worker 1. Valid roles in a cluster is
+"chief", "worker", "ps" and "evaluator". There should be no "ps" job for
+`CollectiveAllReduceStrategy` and `MirroredStrategy`. The "evaluator" job is
+optional and can have at most one task. It does single machine evaluation and if
+you don't want to do evaluation, you can pass in a dummy `input_fn` to the
+`tf.estimator.EvalSpec` of `tf.estimator.train_and_evaluate`.
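+
+A minimal sketch of such a dummy `input_fn` (its name and shapes are
+illustrative):
+
+```python
+def dummy_eval_input_fn():
+  # A tiny dataset; it only needs to be structurally valid for the model.
+  return tf.data.Dataset.from_tensors(([1.], [1.])).repeat(1)
+
+eval_spec = tf.estimator.EvalSpec(input_fn=dummy_eval_input_fn, steps=1)
+```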
+
+### Dataset
+
+The `input_fn` you provide to the estimator is for one worker. So remember to
+scale up your batch size if you have multiple GPUs on each worker.
+
+The same `input_fn` will be used for all workers if you use
+`CollectiveAllReduceStrategy` and `ParameterServerStrategy`. Therefore it is
+important to shuffle your dataset in your `input_fn`.
+
+`MirroredStrategy` will insert a `tf.data.Dataset.shard` call in your
+`input_fn`. As a result, each worker gets a fraction of your input data.
+
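+A sketch of a multi-worker `input_fn` applying these points (the shuffle
+buffer, batch size, and `num_gpus_per_worker` variable are illustrative):
+
+```python
+def input_fn():
+  features = tf.data.Dataset.from_tensors([1.]).repeat(10000)
+  labels = tf.data.Dataset.from_tensors([1.]).repeat(10000)
+  dataset = tf.data.Dataset.zip((features, labels))
+  # Shuffle so workers sharing the same input_fn see different orders, and
+  # scale the per-worker batch by the number of local GPUs.
+  return dataset.shuffle(buffer_size=1000).batch(10 * num_gpus_per_worker)
+```
+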
+### Performance Tips
+
+We have been actively working on multi-worker performance. Currently, we
+recommend `CollectiveAllReduceStrategy` for synchronous multi-worker training.
+
+### Example
+
+Let's use the same example for multi-worker training. We'll start a cluster
+with 3 workers doing synchronous all-reduce training. In the following code
+snippet, we start multi-worker training using
+`tf.estimator.train_and_evaluate`:
+
+
+```python
+def model_main():
+  estimator = ...
+  distribution = tf.contrib.distribute.CollectiveAllReduceStrategy(
+      num_gpus_per_worker=2)
+  config = tf.estimator.RunConfig(train_distribute=distribution)
+  train_spec = tf.estimator.TrainSpec(input_fn=input_fn)
+  eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
+  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
+```
+
+
+**Note**: You don't have to set "TF\_CONFIG" manually if you use our provided
+Kubernetes template.
+
+You'll then need 3 machines. Find out their host addresses and one available
+port on each machine, then set "TF\_CONFIG" in each binary and run the above
+model code.
+
+In your worker 0, run:
+
+```python
+os.environ["TF_CONFIG"] = json.dumps({
+ "cluster": {
+ "worker": ["host1:port", "host2:port", "host3:port"]
+ },
+ "task": {"type": "worker", "index": 0}
+})
+
+# Call the model_main function defined above.
+model_main()
+```
+
+In your worker 1, run:
+
+```python
+os.environ["TF_CONFIG"] = json.dumps({
+ "cluster": {
+ "worker": ["host1:port", "host2:port", "host3:port"]
+ },
+ "task": {"type": "worker", "index": 1}
+})
+
+# Call the model_main function defined above.
+model_main()
+```
+
+In your worker 2, run:
+
+```python
+os.environ["TF_CONFIG"] = json.dumps({
+ "cluster": {
+ "worker": ["host1:port", "host2:port", "host3:port"]
+ },
+ "task": {"type": "worker", "index": 2}
+})
+
+# Call the model_main function defined above.
+model_main()
+```
+
+Then you'll find your cluster has started training! You can inspect the logs
+of the workers or start TensorBoard.
+
+### Standalone client mode
+
+We have a new way to run distributed training. You can bring up standard
+TensorFlow servers in your cluster and run your model code anywhere, such as
+on your laptop.
+
+In the above example, instead of calling `model_main`, you can call
+`tf.contrib.distribute.run_standard_tensorflow_server().join()`. This will
+bring up a cluster running standard TensorFlow servers which wait for your
+request to start training.
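+
+On each machine in the cluster, the entry point then reduces to the following
+(a sketch; it assumes "TF\_CONFIG" is already set for that machine, as shown
+earlier):
+
+```python
+# Bring up a standard TensorFlow server and block, waiting for the client's
+# request to start training.
+tf.contrib.distribute.run_standard_tensorflow_server().join()
+```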
+
+On your laptop, you can run:
+
+```python
+estimator = ...
+distribution = tf.contrib.distribute.CollectiveAllReduceStrategy(
+    num_gpus_per_worker=2)
+config = tf.estimator.RunConfig(
+    experimental_distribute=tf.contrib.distribute.DistributeConfig(
+        train_distribute=distribution,
+        remote_cluster={"worker": ["host1:port", "host2:port", "host3:port"]}))
+train_spec = tf.estimator.TrainSpec(input_fn=input_fn)
+eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
+tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
+```
+
+Then you will see the training logs on your laptop. You can terminate training
+by terminating the process on your laptop. You can also modify your code and
+run a new model against the same cluster.
+
+We've been optimizing the performance of standalone client mode. If you notice
+high latency between your laptop and your cluster, you can reduce it by
+running your model binary in the cluster.
+
## Caveats
+
This feature is in early stages and there are a lot of improvements forthcoming:
* Summaries are only computed in the first tower in `MirroredStrategy`.
-* Evaluation is not yet distributed.
* Eager support is in the works; performance can be more challenging with eager
execution.
-* As mentioned earlier, multi-node and other distributed strategies will be
-introduced in the future.
+* We currently support the following predefined Keras callbacks:
+`ModelCheckpointCallback` and `TensorBoardCallback`. Support for some of the
+other callbacks, such as `EarlyStopping` and `ReduceLROnPlateau`, is coming
+soon. If you create your own callback, you will not have access to all model
+properties and validation data.
* If you are [`batching`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch)
your input data, we will place one batch on each GPU in each step. So your
effective batch size will be `num_gpus * batch_size`. Therefore, consider
diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py
index 2c93ce92ce..350f81f60f 100644
--- a/tensorflow/contrib/distribute/__init__.py
+++ b/tensorflow/contrib/distribute/__init__.py
@@ -23,11 +23,12 @@ from tensorflow.contrib.distribute.python.collective_all_reduce_strategy import
from tensorflow.contrib.distribute.python.cross_tower_ops import *
from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy
from tensorflow.contrib.distribute.python.monitor import Monitor
-from tensorflow.contrib.distribute.python.multi_worker_strategy import MultiWorkerMirroredStrategy
from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy
from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy
from tensorflow.contrib.distribute.python.step_fn import *
from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy
+from tensorflow.python.distribute.distribute_config import DistributeConfig
+from tensorflow.python.distribute.distribute_coordinator import run_standard_tensorflow_server
from tensorflow.python.training.distribute import *
from tensorflow.python.training.distribution_strategy_context import *
@@ -38,9 +39,9 @@ _allowed_symbols = [
'AllReduceCrossTowerOps',
'CollectiveAllReduceStrategy',
'CrossTowerOps',
+ 'DistributeConfig',
'DistributionStrategy',
'MirroredStrategy',
- 'MultiWorkerMirroredStrategy',
'Monitor',
'OneDeviceStrategy',
'ParameterServerStrategy',
@@ -56,6 +57,7 @@ _allowed_symbols = [
'get_tower_context',
'has_distribution_strategy',
'require_tower_context',
+ 'run_standard_tensorflow_server',
'UpdateContext',
]
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 40a1c1707c..87f76eaa94 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -23,8 +23,6 @@ py_library(
deps = [
":input_ops",
":prefetching_ops_v2",
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/contrib/eager/python:datasets",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:device_util",
@@ -72,48 +70,72 @@ py_library(
":cross_tower_ops",
":shared_variable_creator",
":values",
+ "//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
"//tensorflow/python:device",
"//tensorflow/python:device_util",
"//tensorflow/python:distribute",
"//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:training",
+ "//tensorflow/python:util",
"//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/distribute:multi_worker_util",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:tape",
- "@six_archive//:six",
],
)
py_library(
- name = "multi_worker_strategy",
- srcs = ["multi_worker_strategy.py"],
+ name = "parameter_server_strategy",
+ srcs = ["parameter_server_strategy.py"],
visibility = ["//tensorflow:internal"],
deps = [
+ ":cross_tower_ops",
":mirrored_strategy",
":values",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
+ "//tensorflow/python/distribute:multi_worker_util",
+ "//tensorflow/python/eager:context",
],
)
-py_library(
- name = "parameter_server_strategy",
- srcs = ["parameter_server_strategy.py"],
- visibility = ["//tensorflow:internal"],
- deps = [
- ":cross_tower_ops",
- ":mirrored_strategy",
+cuda_py_test(
+ name = "parameter_server_strategy_test",
+ srcs = ["parameter_server_strategy_test.py"],
+ additional_deps = [
+ ":combinations",
+ ":multi_worker_test_base",
+ ":parameter_server_strategy",
":values",
+ "@absl_py//absl/testing:parameterized",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
- "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:session",
"//tensorflow/python:training",
- "//tensorflow/python:util",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/distribute:multi_worker_util",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/estimator:estimator_py",
+ ],
+ tags = [
+ "multi_and_single_gpu",
+ "no_pip",
],
)
@@ -147,6 +169,7 @@ py_library(
"//tensorflow/python:collective_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
+ "//tensorflow/python/distribute:multi_worker_util",
"//tensorflow/python/eager:context",
],
)
@@ -184,7 +207,6 @@ py_library(
],
deps = [
":mirrored_strategy",
- ":multi_worker_strategy",
":one_device_strategy",
":tpu_strategy",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
@@ -219,9 +241,13 @@ py_test(
],
deps = [
":mirrored_strategy",
+ ":multi_worker_test_base",
":strategy_test_lib",
+ "//tensorflow/python:constant_op",
"//tensorflow/python:distribute",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
@@ -243,48 +269,21 @@ py_test(
],
)
-py_test(
- name = "parameter_server_strategy_test",
- srcs = ["parameter_server_strategy_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- ],
- deps = [
- ":combinations",
- ":multi_worker_test_base",
- ":parameter_server_strategy",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:gradients",
- "//tensorflow/python:layers",
- "//tensorflow/python:session",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//tensorflow/python/eager:context",
- "//tensorflow/python/estimator:estimator_py",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
cuda_py_test(
name = "mirrored_strategy_multigpu_test",
srcs = ["mirrored_strategy_multigpu_test.py"],
additional_deps = [
":mirrored_strategy",
+ ":multi_worker_test_base",
":values",
":strategy_test_lib",
"//tensorflow/python:distribute",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:layers",
+ "//tensorflow/python:state_ops",
"//tensorflow/python:variable_scope",
- "//tensorflow/python:array_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
@@ -345,19 +344,17 @@ py_library(
],
)
-py_test(
+cuda_py_test(
name = "collective_all_reduce_strategy_test",
srcs = ["collective_all_reduce_strategy_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- ],
- deps = [
+ additional_deps = [
":collective_all_reduce_strategy",
":combinations",
":cross_tower_utils",
":multi_worker_test_base",
":strategy_test_lib",
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -371,8 +368,10 @@ py_test(
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
"//tensorflow/python/estimator:estimator_py",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
+ ],
+ tags = [
+ "multi_and_single_gpu",
+ "no_pip",
],
)
@@ -452,6 +451,35 @@ cuda_py_test(
],
)
+cuda_py_test(
+ name = "estimator_training_test",
+ size = "large",
+ srcs = ["estimator_training_test.py"],
+ additional_deps = [
+ ":combinations",
+ ":mirrored_strategy",
+ ":multi_worker_test_base",
+ ":parameter_server_strategy",
+ "//third_party/py/numpy",
+ "//tensorflow/contrib/optimizer_v2:training",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/distribute",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python/estimator:estimator_py",
+ "//tensorflow/python/feature_column",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:summary",
+ ],
+ tags = [
+ "manual",
+ "multi_and_single_gpu",
+ "no_pip",
+ "nogpu",
+ "notap",
+ ],
+)
+
py_library(
name = "single_loss_example",
srcs = ["single_loss_example.py"],
@@ -607,6 +635,7 @@ cuda_py_test(
":combinations",
":cross_tower_ops",
":multi_worker_test_base",
+ ":mirrored_strategy",
":values",
"@absl_py//absl/testing:parameterized",
"//tensorflow/python:array_ops",
@@ -679,19 +708,32 @@ cuda_py_test(
],
)
-cuda_py_test(
- name = "keras_test",
+py_library(
+ name = "keras_test_lib",
+ testonly = 1,
srcs = ["keras_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
+ deps = [
+ ":combinations",
"//tensorflow/contrib/distribute/python:mirrored_strategy",
+ "//tensorflow/contrib/distribute/python:tpu_strategy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:training",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/keras",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+cuda_py_test(
+ name = "keras_test",
+ srcs = ["keras_test.py"],
+ additional_deps = [
+ ":keras_test_lib",
],
tags = [
"multi_and_single_gpu",
+ "no_pip",
"no_windows_gpu",
"notsan",
],
diff --git a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py
index bcb977f640..865dba803f 100644
--- a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py
+++ b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py
@@ -48,7 +48,7 @@ class CheckpointUtilsWithDistributionStrategyTest(
mode=["graph"]))
def testInitFromCheckpoint(self, distribution, in_tower_mode):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1_value, v2_value, _, _ = checkpoint_utils_test._create_checkpoints(
session, checkpoint_dir)
@@ -62,7 +62,7 @@ class CheckpointUtilsWithDistributionStrategyTest(
"var1": "new_var1",
"var2": "new_var2"
})
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
session.run(variables.global_variables_initializer())
self.assertAllEqual(v1_value, self.evaluate(v1))
self.assertAllEqual(v2_value, self.evaluate(v2))
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
index 9afcaecf78..77079d0df9 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
@@ -18,96 +18,96 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import json
-import os
-
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import values
-from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
-from tensorflow.python.training import server_lib
-
+from tensorflow.python.platform import tf_logging as logging
-# TODO(yuefengz): move this function to a common util file.
-def _normalize_cluster_spec(cluster_spec):
- if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
- return server_lib.ClusterSpec(cluster_spec)
- elif not isinstance(cluster_spec, server_lib.ClusterSpec):
- raise ValueError(
- "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
- "`tf.train.ClusterDef` object")
- return cluster_spec
-
-# TODO(yuefengz): shard the dataset.
# TODO(yuefengz): support in-graph replication.
-# TODO(yuefengz): it only works with a cluster without a chief node, maybe
-# support chief node?
class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
"""Distribution strategy that uses collective ops for all-reduce.
It is similar to the MirroredStrategy but it uses collective ops for
- reduction. It currently only works for between-graph replication and its
- reduction will reduce across all workers.
+ reduction.
+
+ When `cluster_spec` is given by the `configure` method, it turns into the
+ multi-worker version that works on multiple workers with between-graph
+ replication.
+
+ Note: `configure` will be called by higher-level APIs if running in
+ distributed environment.
"""
- def __init__(self,
- num_gpus_per_worker=0,
- cluster_spec=None,
- task_type="worker",
- task_id=0):
+ def __init__(self, num_gpus_per_worker=0):
"""Initializes the object.
Args:
- num_gpus_per_worker: number of local GPUs or GPUs per worker.
- cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
- cluster configurations.
- task_type: the current task type, such as "worker".
- task_id: the current task id.
-
- Raises:
- ValueError: if `task_type` is not in the `cluster_spec`.
+ num_gpus_per_worker: number of local GPUs or GPUs per worker, the default
+ is 0 meaning CPU only.
"""
self._num_gpus_per_worker = num_gpus_per_worker
- self._initialize(cluster_spec, task_type, task_id)
+ self._initialize_local_worker(num_gpus_per_worker)
+
+ def _initialize_local_worker(self, num_gpus_per_worker):
+ """Initializes the object for local training."""
+ self._is_chief = True
+ self._num_workers = 1
- def _initialize(self, cluster_spec, task_type, task_id):
+ if num_gpus_per_worker:
+ local_devices = [
+ "/device:GPU:%d" % i for i in range(num_gpus_per_worker)
+ ]
+ else:
+ local_devices = ["/device:CPU:0"]
+
+ self._collective_keys = cross_tower_utils.CollectiveKeys()
+ super(CollectiveAllReduceStrategy, self).__init__(
+ devices=local_devices,
+ cross_tower_ops=cross_tower_ops_lib.CollectiveAllReduce(
+ num_workers=1,
+ num_gpus_per_worker=num_gpus_per_worker,
+ collective_keys=self._collective_keys))
+
+ self._cluster_spec = None
+ self._task_type = None
+ self._task_id = None
+
+ logging.info("CollectiveAllReduceStrategy with local_devices = %r",
+ local_devices)
+
+ def _initialize_multi_worker(self, num_gpus_per_worker, cluster_spec,
+ task_type, task_id):
+ """Initializes the object for multi-worker training."""
+ if task_type is None or task_id is None:
+ raise ValueError("When `cluster_spec` is given, you must also specify "
+ "`task_type` and `task_id`")
if task_type not in ["chief", "worker"]:
raise ValueError(
"Unrecognized task_type: %r, valid task types are: \"chief\", "
"\"worker\"." % task_type)
- if cluster_spec:
- self._cluster_spec = _normalize_cluster_spec(cluster_spec)
- worker_device = "/job:%s/task:%d" % (task_type, task_id)
- num_workers = len(self._cluster_spec.as_dict().get(task_type, []))
- if "chief" in self._cluster_spec.as_dict():
- num_workers += 1
- if not num_workers:
- raise ValueError("`task_type` shoud be in `cluster_spec`.")
-
- # TODO(yuefengz): create a utility to infer chief.
- if "chief" in self._cluster_spec.as_dict() and task_type == "chief":
- assert task_id == 0
- self._is_chief = True
- else:
- assert task_type == "worker"
- self._is_chief = task_id == 0
- else:
- self._cluster_spec = None
- self._is_chief = True
- worker_device = ""
- num_workers = 1
- self._num_workers = num_workers
+ cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
+ self._num_workers = len(cluster_spec.as_dict().get("worker", [])) + len(
+ cluster_spec.as_dict().get("chief", []))
+ if not self._num_workers:
+ raise ValueError("No `worker` or `chief` tasks can be found in "
+ "`cluster_spec`.")
+
+ self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
+ task_id)
- if self._num_gpus_per_worker:
+ worker_device = "/job:%s/task:%d" % (task_type, task_id)
+ if num_gpus_per_worker:
local_devices = [
"%s/device:GPU:%d" % (worker_device, i)
- for i in range(self._num_gpus_per_worker)
+ for i in range(num_gpus_per_worker)
]
else:
local_devices = [worker_device]
@@ -116,14 +116,23 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
super(CollectiveAllReduceStrategy, self).__init__(
devices=local_devices,
cross_tower_ops=cross_tower_ops_lib.CollectiveAllReduce(
- num_workers=num_workers,
- num_gpus_per_worker=self._num_gpus_per_worker,
+ num_workers=self._num_workers,
+ num_gpus_per_worker=num_gpus_per_worker,
collective_keys=self._collective_keys))
# Add a default device so that ops without specified devices will not end up
# on other workers.
- if cluster_spec:
- self._default_device = "/job:%s/replica:0/task:%d" % (task_type, task_id)
+ self._default_device = "/job:%s/task:%d" % (task_type, task_id)
+
+ self._cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
+ self._task_type = task_type
+ self._task_id = task_id
+
+ logging.info(
+ "Multi-worker CollectiveAllReduceStrategy with "
+ "cluster_spec = %r, task_type = %r, task_id = %r, "
+ "num_workers = %r, local_devices = %r", cluster_spec.as_dict(),
+ task_type, task_id, self._num_workers, local_devices)
def _create_variable(self, next_creator, *args, **kwargs):
colocate_with = kwargs.pop("colocate_with", None)
@@ -187,19 +196,81 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
return mirrored_strategy._create_mirrored_variable(
devices, _real_mirrored_creator, *args, **kwargs)
- def configure(self, session_config=None):
- # Use TF_CONFIG to get the cluster spec and the current job.
- if not self._cluster_spec:
- tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
- cluster_spec = _normalize_cluster_spec(tf_config.get("cluster", {}))
+ def distribute_dataset(self, dataset_fn):
+ """Distributes the dataset to each local GPU."""
+ # TODO(yuefengz): shard the dataset.
+ return values.PerDeviceDataset(
+ self._call_dataset_fn(dataset_fn), self._devices, True)
- task_env = tf_config.get("task", {})
- if task_env:
- task_type = task_env.get("type", "worker")
- task_id = int(task_env.get("index", "0"))
- else:
- task_type = "worker"
- task_id = 0
+ def configure(self,
+ session_config=None,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None):
+ """Configures the object.
+
+ Args:
+ session_config: a @{tf.ConfigProto}
+ cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
+ cluster configurations.
+ task_type: the current task type, such as "worker".
+ task_id: the current task id.
+
+ Raises:
+ ValueError: if `task_type` is not in the `cluster_spec`.
+ """
+ if not self._cluster_spec and cluster_spec:
+ # If a `cluster_spec` is already passed in, do nothing here.
+ # TODO(yuefengz): check `cluster_spec` is the same if this object has
+ # already been initialized with a `cluster_spec`.
+ self._initialize_multi_worker(self._num_gpus_per_worker, cluster_spec,
+ task_type, task_id)
+
+ if not session_config or not self._cluster_spec:
+ return
+
+ session_config.isolate_session_state = True
+
+ assert self._task_type
+ assert self._task_id is not None
+
+ # Collective group leader is needed for collective ops to coordinate
+ # workers.
+ if "chief" in self._cluster_spec.jobs:
+ session_config.experimental.collective_group_leader = (
+ "/job:chief/replica:0/task:0")
+ else:
+ if "worker" not in self._cluster_spec.jobs:
+ raise ValueError(
+ "You must have `chief` or `worker` jobs in the `cluster_spec`.")
+ session_config.experimental.collective_group_leader = (
+ "/job:worker/replica:0/task:0")
+
+ # The device filters prevent communication between workers.
+ del session_config.device_filters[:]
+ session_config.device_filters.append(
+ "/job:%s/task:%d" % (self._task_type, self._task_id))
+
+ # The scoped_allocator_optimization is to optimize graphs for collective
+ # ops.
+ rewrite_options = session_config.graph_options.rewrite_options
+ rewrite_options.scoped_allocator_optimization = (
+ rewriter_config_pb2.RewriterConfig.ON)
+ del rewrite_options.scoped_allocator_opts.enable_op[:]
+ rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce")
+
+ @property
+ def between_graph(self):
+ return True
+
+ @property
+ def should_init(self):
+ return True
+
+ @property
+ def should_checkpoint(self):
+ return self._is_chief
- if cluster_spec:
- self._initialize(cluster_spec, task_type, task_id)
+ @property
+ def should_save_summary(self):
+ return self._is_chief
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
index b5e54e3b7d..36e9761073 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
@@ -25,10 +25,8 @@ from tensorflow.contrib.distribute.python import collective_all_reduce_strategy
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.contrib.distribute.python import multi_worker_test_base
-from tensorflow.contrib.distribute.python import strategy_test_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
-from tensorflow.python.estimator import run_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -41,53 +39,46 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-class DistributedCollectiveAllReduceStrategyTest(
- multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase):
+class CollectiveAllReduceStrategyTestBase(
+ multi_worker_test_base.MultiWorkerTestBase):
collective_key_base = 0
- @classmethod
- def setUpClass(cls):
- """Create a local cluster with 2 workers."""
- cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster(
- num_workers=3, num_ps=0)
- cls._cluster_spec = {
- run_config.TaskType.WORKER: [
- 'fake_worker_0', 'fake_worker_1', 'fake_worker_2'
- ]
- }
-
def setUp(self):
self._run_options = config_pb2.RunOptions()
self._run_options.experimental.collective_graph_key = 6
self._sess_config = config_pb2.ConfigProto()
- self._sess_config.experimental.collective_group_leader = (
- '/job:worker/replica:0/task:0')
# We use a different key_base for each test so that collective keys won't be
# reused.
# TODO(yuefengz, tucker): enable it to reuse collective keys in different
# tests.
- DistributedCollectiveAllReduceStrategyTest.collective_key_base += 100000
- super(DistributedCollectiveAllReduceStrategyTest, self).setUp()
+ CollectiveAllReduceStrategyTestBase.collective_key_base += 100000
+ super(CollectiveAllReduceStrategyTestBase, self).setUp()
def _get_test_object(self, task_type, task_id, num_gpus=0):
distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
- num_gpus_per_worker=num_gpus,
- cluster_spec=self._cluster_spec,
- task_type=task_type,
- task_id=task_id)
+ num_gpus_per_worker=num_gpus)
+ if task_type and task_id is not None:
+ distribution.configure(
+ session_config=self._sess_config,
+ cluster_spec=self._cluster_spec,
+ task_type=task_type,
+ task_id=task_id)
collective_keys = cross_tower_utils.CollectiveKeys(
group_key_start=10 * num_gpus +
- DistributedCollectiveAllReduceStrategyTest.collective_key_base,
+ CollectiveAllReduceStrategyTestBase.collective_key_base,
instance_key_start=num_gpus * 100 +
- DistributedCollectiveAllReduceStrategyTest.collective_key_base,
+ CollectiveAllReduceStrategyTestBase.collective_key_base,
instance_key_with_id_start=num_gpus * 10000 +
- DistributedCollectiveAllReduceStrategyTest.collective_key_base)
+ CollectiveAllReduceStrategyTestBase.collective_key_base)
distribution._collective_keys = collective_keys
distribution._cross_tower_ops._collective_keys = collective_keys
- return distribution, self._workers[task_id].target
+ if task_type and task_id is not None:
+ return distribution, 'grpc://' + self._cluster_spec[task_type][task_id]
+ else:
+ return distribution, ''
def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):
d, master_target = self._get_test_object(task_type, task_id, num_gpus)
@@ -155,12 +146,6 @@ class DistributedCollectiveAllReduceStrategyTest(
self.assertLess(error_after, error_before)
return error_after < error_before
- @combinations.generate(
- combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
- def testMinimizeLossGraph(self, num_gpus):
- self._run_between_graph_clients(self._test_minimize_loss_graph,
- self._cluster_spec, num_gpus)
-
def _test_variable_initialization(self, task_type, task_id, num_gpus):
distribution, master_target = self._get_test_object(task_type, task_id,
num_gpus)
@@ -182,16 +167,37 @@ class DistributedCollectiveAllReduceStrategyTest(
distribution.reduce(
variable_scope.VariableAggregation.MEAN, x,
destinations='/cpu:0'))[0]
+ x = distribution.unwrap(x)[0]
sess.run(
variables.global_variables_initializer(), options=self._run_options)
+
x_value, reduced_x_value = sess.run(
[x, reduced_x], options=self._run_options)
- self.assertTrue(np.array_equal(x_value, reduced_x_value))
- return np.array_equal(x_value, reduced_x_value)
+ self.assertTrue(
+ np.allclose(x_value, reduced_x_value, atol=1e-5),
+ msg=('x_value = %r, reduced_x_value = %r' % (x_value,
+ reduced_x_value)))
+ return np.allclose(x_value, reduced_x_value, atol=1e-5)
+
+
+class DistributedCollectiveAllReduceStrategyTest(
+ CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ """Create a local cluster with 3 workers."""
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
+ num_workers=3, num_ps=0)
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
+ def testMinimizeLossGraph(self, num_gpus):
+ self._run_between_graph_clients(self._test_minimize_loss_graph,
+ self._cluster_spec, num_gpus)
@combinations.generate(
- combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
def testVariableInitialization(self, num_gpus):
if context.num_gpus() < num_gpus:
return
@@ -201,16 +207,44 @@ class DistributedCollectiveAllReduceStrategyTest(
num_gpus=num_gpus)
-class LocalCollectiveAllReduceStrategy(strategy_test_lib.DistributionTestBase,
- parameterized.TestCase):
+class DistributedCollectiveAllReduceStrategyTestWithChief(
+ CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ """Create a local cluster with 3 workers and 1 chief."""
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
+ num_workers=3, num_ps=0, has_chief=True)
+
+ def setUp(self):
+ super(DistributedCollectiveAllReduceStrategyTestWithChief, self).setUp()
+ self._run_options.experimental.collective_graph_key = 7
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
+ def testMinimizeLossGraph(self, num_gpus):
+ self._run_between_graph_clients(self._test_minimize_loss_graph,
+ self._cluster_spec, num_gpus)
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
+ def testVariableInitialization(self, num_gpus):
+ if context.num_gpus() < num_gpus:
+ return
+ self._run_between_graph_clients(
+ self._test_variable_initialization,
+ self._cluster_spec,
+ num_gpus=num_gpus)
+
+
+class LocalCollectiveAllReduceStrategy(
+ CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
def testMinimizeLossGraph(self, num_gpus=2):
# Collective ops doesn't support strategy with one device.
if context.num_gpus() < num_gpus:
return
- distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
- num_gpus_per_worker=num_gpus)
- self._test_minimize_loss_graph(distribution)
+ self._test_minimize_loss_graph(None, None, num_gpus)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index aeec9c44d7..1133be6d0b 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -48,7 +48,6 @@ import six
from tensorflow.contrib.cluster_resolver import TPUClusterResolver
from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib
-from tensorflow.contrib.distribute.python import multi_worker_strategy
from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib
from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib
from tensorflow.contrib.optimizer_v2 import adam as adam_v2
@@ -329,6 +328,10 @@ tpu_strategy = NamedDistribution(
"TPU", lambda: tpu_lib.TPUStrategy(
TPUClusterResolver(""), steps_per_run=5),
required_tpu=True)
+tpu_strategy_one_step = NamedDistribution(
+ "TPU", lambda: tpu_lib.TPUStrategy(
+ TPUClusterResolver(""), steps_per_run=1),
+ required_tpu=True)
# Note that we disable prefetching for testing since prefetching makes
# the input non-deterministic.
mirrored_strategy_with_gpu_and_cpu = NamedDistribution(
@@ -342,33 +345,6 @@ mirrored_strategy_with_two_gpus = NamedDistribution(
["/gpu:0", "/gpu:1"], prefetch_on_device=False),
required_gpus=2)
-multi_worker_strategy_with_cpu = NamedDistribution(
- "MultiWorkerCPU",
- lambda: multi_worker_strategy.MultiWorkerMirroredStrategy(
- cluster={
- "worker": [
- "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
- ]
- },
- num_gpus_per_worker=0), 0)
-multi_worker_strategy_with_one_gpu = NamedDistribution(
- "MultiWorker1GPU",
- lambda: multi_worker_strategy.MultiWorkerMirroredStrategy(
- cluster={
- "worker": [
- "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
- ]
- },
- num_gpus_per_worker=1), 1)
-multi_worker_strategy_with_two_gpus = NamedDistribution(
- "MultiWorker2GPUs",
- lambda: multi_worker_strategy.MultiWorkerMirroredStrategy(
- cluster={
- "worker": [
- "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
- ]
- },
- num_gpus_per_worker=2), 2)
adam_optimizer_v1_fn = NamedObject(
"AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1))
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py
index 3a7addf221..e08ba9c2a6 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py
@@ -35,13 +35,13 @@ from tensorflow.python.training import device_util
def check_destinations(destinations):
- """Checks whether `destinations` is not None and not empty.
+ """Checks whether `destinations` is not empty.
Args:
destinations: a DistributedValues, Variable, string or a list of strings.
Returns:
- Boolean indicating whether `destinations` is not None and not empty.
+ Boolean which is True if `destinations` is not empty.
"""
# Calling bool() on a ResourceVariable is not allowed.
if isinstance(destinations, resource_variable_ops.ResourceVariable):
@@ -53,16 +53,53 @@ def validate_destinations(destinations):
if not isinstance(
destinations,
(value_lib.DistributedValues, resource_variable_ops.ResourceVariable,
- six.string_types, list)):
+ value_lib.AggregatingVariable, six.string_types, list)):
raise ValueError("destinations must be one of a `DistributedValues` object,"
" a tf.Variable object, a device string, a list of device "
- "strings or None")
+ "strings")
if not check_destinations(destinations):
raise ValueError("destinations can not be empty")
+def _make_tensor_into_per_device(input_tensor):
+ """Converts a single tensor into a PerDevice object."""
+ if isinstance(input_tensor, (tuple, list)):
+ raise ValueError("Cannot convert `input_tensor` to a `PerDevice` object, "
+ "got %r but expected a object that is not a tuple or list."
+ % (input_tensor,))
+ if isinstance(input_tensor, value_lib.PerDevice):
+ return input_tensor
+
+ try:
+ device = input_tensor.device
+ except AttributeError:
+ raise ValueError("Cannot convert `input_tensor` to a `PerDevice` object "
+ "because it doesn't have device set.")
+
+ return value_lib.PerDevice({device: input_tensor})
+
+
+def _normalize_value_destination_pairs(value_destination_pairs):
+ """Converts each tensor into a PerDevice object in the input list."""
+ result = []
+ if not isinstance(value_destination_pairs, (list, tuple)):
+ raise ValueError("`value_destination_pairs` should be a list or tuple")
+ for pair in value_destination_pairs:
+ if not isinstance(pair, tuple):
+ raise ValueError(
+ "Each element of `value_destination_pairs` should be a tuple.")
+ if len(pair) != 2:
+ raise ValueError("Each element of `value_destination_pairs` should be a "
+ "tuple of size 2.")
+
+ per_device = _make_tensor_into_per_device(pair[0])
+ result.append((per_device, pair[1]))
+ return result
+
+
def _validate_value_destination_pairs(value_destination_pairs):
+ # TODO(yuefengz): raise exceptions instead of returning False.
# pylint: disable=g-missing-docstring
if not value_destination_pairs: return False
if not isinstance(value_destination_pairs, (list, tuple)): return False
@@ -78,12 +115,15 @@ def _validate_value_destination_pairs(value_destination_pairs):
def get_devices_from(destinations):
if isinstance(destinations, value_lib.DistributedValues):
return list(destinations.devices)
- elif isinstance(destinations, resource_variable_ops.ResourceVariable):
+ elif isinstance(destinations, (resource_variable_ops.ResourceVariable,
+ value_lib.AggregatingVariable)):
return [destinations.device]
elif isinstance(destinations, six.string_types):
return [device_util.resolve(destinations)]
- else:
+ elif isinstance(destinations, (list, tuple)):
return [device_util.resolve(destination) for destination in destinations]
+ else:
+ return [destinations.device]
def _devices_match(left, right):
@@ -91,8 +131,7 @@ def _devices_match(left, right):
def _all_devices_match(value_destination_pairs):
- if not all([d is None or _devices_match(v, d)
- for v, d in value_destination_pairs]):
+ if not all([_devices_match(v, d) for v, d in value_destination_pairs]):
return False
if not all([_devices_match(v, value_destination_pairs[0][0])
for v, _ in value_destination_pairs[1:]]):
@@ -149,7 +188,7 @@ class CrossTowerOps(object):
def __init__(self):
pass
- def reduce(self, aggregation, per_device_value, destinations=None):
+ def reduce(self, aggregation, per_device_value, destinations):
"""Reduce `per_device_value` to `destinations`.
It runs the reduction operation defined by `aggregation` and put the
@@ -158,7 +197,7 @@ class CrossTowerOps(object):
Args:
aggregation: Indicates how a variable will be aggregated. Accepted values
are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`.
- per_device_value: a PerDevice object.
+ per_device_value: a PerDevice object or a tensor with device set.
destinations: the reduction destinations.
Returns:
@@ -168,9 +207,9 @@ class CrossTowerOps(object):
ValueError: if per_device_value is not a PerDevice object.
"""
if not isinstance(per_device_value, value_lib.PerDevice):
- raise ValueError("`per_device_value` must be a `PerDevice` object.")
- if destinations is not None:
- validate_destinations(destinations)
+ per_device_value = _make_tensor_into_per_device(per_device_value)
+
+ validate_destinations(destinations)
return self._reduce(aggregation, per_device_value, destinations)
def batch_reduce(self, aggregation, value_destination_pairs):
@@ -183,8 +222,7 @@ class CrossTowerOps(object):
aggregation: Indicates how a variable will be aggregated. Accepted values
are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`.
value_destination_pairs: a list or a tuple of tuples of PerDevice objects
- and destinations. If a destination is None, then the destinations
- are set to match the devices of the input PerDevice object.
+ (or tensors with device set if there is one tower) and destinations.
Returns:
a list of Mirrored objects.
@@ -194,11 +232,13 @@ class CrossTowerOps(object):
tuples of PerDevice objects and destinations
"""
if not _validate_value_destination_pairs(value_destination_pairs):
- raise ValueError("`value_destination_pairs` must be a list or a tuple of "
- "tuples of PerDevice objects and destinations")
+ # If the first element of each pair is a tensor, we try to turn it into a
+ # PerDevice object.
+ value_destination_pairs = _normalize_value_destination_pairs(
+ value_destination_pairs)
+
for _, d in value_destination_pairs:
- if d is not None:
- validate_destinations(d)
+ validate_destinations(d)
return self._batch_reduce(aggregation, value_destination_pairs)
@@ -528,7 +568,7 @@ class AllReduceCrossTowerOps(CrossTowerOps):
def _reduce(self, aggregation, per_device_value, destinations):
contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
per_device_value)
- if ((destinations is None or _devices_match(per_device_value, destinations))
+ if (_devices_match(per_device_value, destinations)
and not context.executing_eagerly()
and not contains_indexed_slices):
return self._batch_all_reduce(aggregation, [per_device_value])[0]
@@ -557,8 +597,10 @@ class AllReduceCrossTowerOps(CrossTowerOps):
[v[0] for v in value_destination_pairs])
else:
if not all_devices_match:
- logging.warning("Efficient batch_reduce is not supported if "
- "destinations are different.")
+ logging.log_first_n(logging.WARN,
+ "Efficient batch_reduce is not supported if "
+ "destinations are different.",
+ 10)
return [
self._reduce(aggregation, t, destinations=v)
@@ -737,7 +779,7 @@ class CollectiveAllReduce(CrossTowerOps):
def __init__(self,
num_workers=1,
num_gpus_per_worker=0,
- all_reduce_merge_scope=1,
+ all_reduce_merge_scope=32,
collective_keys=None):
"""Initializes the object.
@@ -756,10 +798,17 @@ class CollectiveAllReduce(CrossTowerOps):
)
super(CollectiveAllReduce, self).__init__()
- # TODO(yuefengz, tucker): is index slices supported by collective ops?
+ # TODO(yuefengz, tucker): is indexed slices supported by collective ops?
def _reduce(self, aggregation, per_device_value, destinations):
+ if cross_tower_utils.contains_indexed_slices(per_device_value):
+ raise ValueError(
+ "`IndexSlices` is not supported for Collective All-Reduce.")
+ if context.executing_eagerly():
+ raise ValueError(
+ "Eager execution is not supported for Collective All-Reduce")
+
all_reduced = self._batch_all_reduce(aggregation, [per_device_value])[0]
- if destinations is None or _devices_match(per_device_value, destinations):
+ if _devices_match(per_device_value, destinations):
return all_reduced
else:
index = {}
@@ -768,20 +817,40 @@ class CollectiveAllReduce(CrossTowerOps):
if d in all_reduced._index:
index[d] = all_reduced._index[d]
else:
- with ops.device(d):
+ with ops.control_dependencies(list(
+ all_reduced._index.values())), ops.device(d):
index[d] = array_ops.identity(list(all_reduced._index.values())[0])
+
return value_lib.Mirrored(index)
def _batch_reduce(self, aggregation, value_destination_pairs):
- return [
- self._reduce(aggregation, t, destinations=v)
- for t, v in value_destination_pairs
- ]
+ if cross_tower_utils.contains_indexed_slices(value_destination_pairs):
+ raise ValueError(
+ "`IndexSlices` is not supported for Collective All-Reduce.")
+ if context.executing_eagerly():
+ raise ValueError(
+ "Eager execution is not supported for Collective All-Reduce")
+
+ all_devices_match = _all_devices_match(value_destination_pairs)
+ if all_devices_match:
+ return self._batch_all_reduce(aggregation,
+ [v[0] for v in value_destination_pairs])
+ else:
+ if not all_devices_match:
+ logging.log_first_n(
+ logging.WARN, "Efficient batch_reduce is not supported if "
+ "destinations are different.", 10)
+
+ return [
+ self._reduce(aggregation, t, destinations=v)
+ for t, v in value_destination_pairs
+ ]
def _batch_all_reduce(self, aggregation, per_device_values):
"""All-reduce across all workers in a batch."""
if context.executing_eagerly():
- raise ValueError("Eager mode with collective ops is not supported yet.")
+ raise ValueError(
+ "Eager execution with collective ops is not supported yet.")
logging.log_first_n(
logging.INFO, "Collective All-reduce invoked with batches size = %d, "
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
index aec53b01d7..490371477a 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
@@ -26,12 +26,12 @@ import numpy as np
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import cross_tower_utils
+from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import values as value_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import test
-from tensorflow.python.estimator import run_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -40,9 +40,17 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import device_util
-def _make_per_device(values, devices):
+def _make_per_device(values, devices, regroup=False):
devices = cross_tower_ops_lib.get_devices_from(devices)
assert len(values) == len(devices)
+
+ # We simulate the result of regroup called on PerDevice which strips the
+ # PerDevice wrapper if it has only one value.
+ if len(values) == 1 and regroup:
+ with ops.device(devices[0]):
+ placed_v = array_ops.identity(values[0])
+ return placed_v
+
index = {}
for d, v in zip(devices, values):
with ops.device(d):
@@ -127,7 +135,7 @@ class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase):
destination_list = devices
all_destinations = [
- None, destination_mirrored, destination_different, destination_str,
+ destination_mirrored, destination_different, destination_str,
destination_list
]
@@ -138,24 +146,24 @@ class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase):
vs.VariableAggregation.MEAN,
per_device,
destinations=destinations),
- _fake_mirrored(mean, destinations or per_device))
+ _fake_mirrored(mean, destinations))
self._assert_values_equal(
cross_tower_ops.reduce(
vs.VariableAggregation.MEAN,
per_device_2,
destinations=destinations),
- _fake_mirrored(mean_2, destinations or per_device))
+ _fake_mirrored(mean_2, destinations))
self._assert_values_equal(
cross_tower_ops.reduce(
vs.VariableAggregation.SUM, per_device,
destinations=destinations),
- _fake_mirrored(mean * len(devices), destinations or per_device))
+ _fake_mirrored(mean * len(devices), destinations))
self._assert_values_equal(
cross_tower_ops.reduce(
vs.VariableAggregation.SUM,
per_device_2,
destinations=destinations),
- _fake_mirrored(mean_2 * len(devices), destinations or per_device))
+ _fake_mirrored(mean_2 * len(devices), destinations))
# test batch_reduce()
for d1, d2 in itertools.product(all_destinations, all_destinations):
@@ -163,25 +171,22 @@ class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase):
cross_tower_ops.batch_reduce(vs.VariableAggregation.MEAN,
[(per_device, d1), (per_device_2, d2)]),
[
- _fake_mirrored(mean, d1 or per_device),
- _fake_mirrored(mean_2, d2 or per_device_2)
+ _fake_mirrored(mean, d1),
+ _fake_mirrored(mean_2, d2)
])
self._assert_values_equal(
cross_tower_ops.batch_reduce(vs.VariableAggregation.SUM,
[(per_device, d1), (per_device_2, d2)]),
[
- _fake_mirrored(mean * len(devices), d1 or per_device),
- _fake_mirrored(mean_2 * len(devices), d2 or per_device_2)
+ _fake_mirrored(mean * len(devices), d1),
+ _fake_mirrored(mean_2 * len(devices), d2)
])
# test broadcast()
for destinations in all_destinations:
- if destinations is None:
- continue
- else:
- self._assert_values_equal(
- cross_tower_ops.broadcast(constant_op.constant(1.), destinations),
- _fake_mirrored(1., destinations))
+ self._assert_values_equal(
+ cross_tower_ops.broadcast(constant_op.constant(1.), destinations),
+ _fake_mirrored(1., destinations))
class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase):
@@ -368,14 +373,27 @@ class MultiWorkerCrossTowerOpsTest(multi_worker_test_base.MultiWorkerTestBase,
("xring", 2, -1)], 0, 0, 0)),
],
distribution=[
- combinations.multi_worker_strategy_with_cpu,
- combinations.multi_worker_strategy_with_one_gpu,
- combinations.multi_worker_strategy_with_two_gpus
+ combinations.NamedDistribution(
+ "MirroredCPU",
+ lambda: mirrored_strategy.MirroredStrategy(num_gpus=0),
+ required_gpus=0),
+ combinations.NamedDistribution(
+ "Mirrored1GPU",
+ lambda: mirrored_strategy.MirroredStrategy(num_gpus=1),
+ required_gpus=1),
+ combinations.NamedDistribution(
+ "Mirrored2GPUs",
+ lambda: mirrored_strategy.MirroredStrategy(num_gpus=2),
+ required_gpus=2),
],
mode=["graph"])
@combinations.generate(multi_worker_allreduce_combinations)
def testReductionAndBroadcast(self, cross_tower_ops, distribution):
+ distribution.configure(cluster_spec={
+ "worker":
+ ["/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"]
+ })
with distribution.scope():
self._testReductionAndBroadcast(cross_tower_ops, distribution)
@@ -388,13 +406,8 @@ class MultiWorkerCollectiveAllReduceTest(
@classmethod
def setUpClass(cls):
"""Create a local cluster with 2 workers."""
- cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster(
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
num_workers=3, num_ps=0)
- cls._cluster_spec = {
- run_config.TaskType.WORKER: [
- "fake_worker_0", "fake_worker_1", "fake_worker_2"
- ]
- }
def setUp(self):
super(MultiWorkerCollectiveAllReduceTest, self).setUp()
@@ -417,7 +430,7 @@ class MultiWorkerCollectiveAllReduceTest(
devices = ["/device:GPU:%d" % i for i in range(num_gpus)]
else:
devices = ["/device:CPU:0"]
- return collective_all_reduce_ops, devices, "local"
+ return collective_all_reduce_ops, devices, ""
else:
collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce(
3, num_gpus, collective_keys=collective_keys)
@@ -428,7 +441,8 @@ class MultiWorkerCollectiveAllReduceTest(
]
else:
devices = ["/job:%s/task:%d" % (task_type, task_id)]
- return collective_all_reduce_ops, devices, self._workers[task_id].target
+ return (collective_all_reduce_ops, devices,
+ "grpc://" + self._cluster_spec[task_type][task_id])
def _assert_values_equal(self, left, right, sess):
if isinstance(left, list):
@@ -455,7 +469,8 @@ class MultiWorkerCollectiveAllReduceTest(
num_workers = 1
worker_device = None
else:
- num_workers = len(self._workers)
+ num_workers = len(self._cluster_spec.get("chief", [])) + len(
+ self._cluster_spec.get("worker", []))
worker_device = "/job:%s/task:%d" % (task_type, task_id)
with ops.Graph().as_default(), \
ops.device(worker_device), \
@@ -463,7 +478,7 @@ class MultiWorkerCollectiveAllReduceTest(
# Collective ops doesn't support scalar tensors, so we have to construct
# 1-d tensors.
values = [constant_op.constant([float(d)]) for d in range(len(devices))]
- per_device = _make_per_device(values, devices)
+ per_device = _make_per_device(values, devices, regroup=True)
mean = np.array([(len(devices) - 1.) / 2.])
values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))]
@@ -476,7 +491,7 @@ class MultiWorkerCollectiveAllReduceTest(
destination_list = devices
all_destinations = [
- None, destination_mirrored, destination_different, destination_str,
+ destination_different, destination_mirrored, destination_str,
destination_list
]
@@ -487,27 +502,27 @@ class MultiWorkerCollectiveAllReduceTest(
vs.VariableAggregation.MEAN,
per_device,
destinations=destinations),
- _fake_mirrored(mean, destinations or per_device), sess)
+ _fake_mirrored(mean, destinations), sess)
self._assert_values_equal(
collective_all_reduce.reduce(
vs.VariableAggregation.MEAN,
per_device_2,
destinations=destinations),
- _fake_mirrored(mean_2, destinations or per_device), sess)
+ _fake_mirrored(mean_2, destinations), sess)
self._assert_values_equal(
collective_all_reduce.reduce(
vs.VariableAggregation.SUM,
per_device,
destinations=destinations),
- _fake_mirrored(mean * len(devices) * num_workers, destinations or
- per_device), sess)
+ _fake_mirrored(mean * len(devices) * num_workers, destinations),
+ sess)
self._assert_values_equal(
collective_all_reduce.reduce(
vs.VariableAggregation.SUM,
per_device_2,
destinations=destinations),
- _fake_mirrored(mean_2 * len(devices) * num_workers, destinations or
- per_device), sess)
+ _fake_mirrored(mean_2 * len(devices) * num_workers, destinations),
+ sess)
# test batch_reduce()
for d1, d2 in itertools.product(all_destinations, all_destinations):
@@ -516,30 +531,34 @@ class MultiWorkerCollectiveAllReduceTest(
[(per_device, d1),
(per_device_2, d2)]),
[
- _fake_mirrored(mean, d1 or per_device),
- _fake_mirrored(mean_2, d2 or per_device_2)
+ _fake_mirrored(mean, d1),
+ _fake_mirrored(mean_2, d2)
], sess)
self._assert_values_equal(
collective_all_reduce.batch_reduce(vs.VariableAggregation.SUM,
[(per_device, d1),
(per_device_2, d2)]),
[
- _fake_mirrored(mean * len(devices) * num_workers, d1 or
- per_device),
- _fake_mirrored(mean_2 * len(devices) * num_workers, d2 or
- per_device_2)
+ _fake_mirrored(mean * len(devices) * num_workers, d1),
+ _fake_mirrored(mean_2 * len(devices) * num_workers, d2)
], sess)
return True
@combinations.generate(
- combinations.combine(mode=["graph"], num_gpus=[0, 1, 2]))
+ combinations.combine(mode=["graph"], num_gpus=[0, 1, 2], required_gpus=1))
def testReductionDistributed(self, num_gpus):
if context.num_gpus() < num_gpus:
return
self._run_between_graph_clients(self._test_reduction, self._cluster_spec,
num_gpus)
+ # Collective ops don't support a strategy with a single device.
+ def testReductionLocal(self, num_gpus=2):
+ if context.num_gpus() < num_gpus:
+ return
+ self._test_reduction(None, None, num_gpus, local_mode=True)
+
if __name__ == "__main__":
test.main()
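The hunk above swaps the prebuilt multi_worker_strategy combinations for inline `combinations.NamedDistribution` entries. As a rough sketch of that pattern, based only on the constructor arguments visible in this diff (the skip-on-missing-GPU behavior is an assumption), a test module can define and use its own named distribution like so:

    from tensorflow.contrib.distribute.python import combinations
    from tensorflow.contrib.distribute.python import mirrored_strategy

    # Pair a test-case label with a factory that builds the strategy lazily,
    # plus the number of GPUs the case needs.
    mirrored_two_gpus = combinations.NamedDistribution(
        "Mirrored2GPUs",
        lambda: mirrored_strategy.MirroredStrategy(num_gpus=2),
        required_gpus=2)

    # List it in combine(); cases whose required_gpus exceeds the GPUs on the
    # test machine are presumably skipped rather than failed.
    my_combinations = combinations.combine(
        distribution=[mirrored_two_gpus], mode=["graph"])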
diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py
new file mode 100644
index 0000000000..5348512016
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/estimator_training_test.py
@@ -0,0 +1,659 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests that show Distribute Coordinator works with Estimator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import glob
+import json
+import os
+import sys
+import tempfile
+import threading
+from absl.testing import parameterized
+import numpy as np
+import six
+
+_portpicker_import_error = None
+try:
+ import portpicker # pylint: disable=g-import-not-at-top
+except ImportError as _error: # pylint: disable=invalid-name
+ _portpicker_import_error = _error
+ portpicker = None
+
+# pylint: disable=g-import-not-at-top
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import parameter_server_strategy
+from tensorflow.contrib.optimizer_v2 import adagrad
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import distribute_coordinator as dc
+from tensorflow.python.distribute import estimator_training as dc_training
+from tensorflow.python.distribute.distribute_config import DistributeConfig
+from tensorflow.python.eager import context
+from tensorflow.python.estimator import exporter as exporter_lib
+from tensorflow.python.estimator import run_config as run_config_lib
+from tensorflow.python.estimator import training as estimator_training
+from tensorflow.python.estimator.canned import dnn_linear_combined
+from tensorflow.python.estimator.canned import prediction_keys
+from tensorflow.python.estimator.export import export as export_lib
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.summary import summary_iterator
+from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import server_lib
+
+BATCH_SIZE = 10
+LABEL_DIMENSION = 2
+DATA = np.linspace(
+ 0., 2., BATCH_SIZE * LABEL_DIMENSION, dtype=np.float32).reshape(
+ BATCH_SIZE, LABEL_DIMENSION)
+EVAL_NAME = "foo"
+EXPORTER_NAME = "saved_model_exporter"
+MAX_STEPS = 10
+
+CHIEF = dc._TaskType.CHIEF
+EVALUATOR = dc._TaskType.EVALUATOR
+WORKER = dc._TaskType.WORKER
+PS = dc._TaskType.PS
+
+original_run_distribute_coordinator = dc.run_distribute_coordinator
+
+
+# TODO(yuefengz): merge this method back into test_util.
+def _create_local_cluster(num_workers,
+ num_ps,
+ has_eval=False,
+ protocol="grpc",
+ worker_config=None,
+ ps_config=None):
+ if _portpicker_import_error:
+ raise _portpicker_import_error # pylint: disable=raising-bad-type
+ worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
+ ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+
+ cluster_dict = {
+ "worker": ["localhost:%s" % port for port in worker_ports],
+ "ps": ["localhost:%s" % port for port in ps_ports]
+ }
+ if has_eval:
+ cluster_dict["evaluator"] = ["localhost:%s" % portpicker.pick_unused_port()]
+
+ cs = server_lib.ClusterSpec(cluster_dict)
+
+ workers = [
+ server_lib.Server(
+ cs,
+ job_name="worker",
+ protocol=protocol,
+ task_index=ix,
+ config=worker_config,
+ start=True) for ix in range(num_workers)
+ ]
+ ps_servers = [
+ server_lib.Server(
+ cs,
+ job_name="ps",
+ protocol=protocol,
+ task_index=ix,
+ config=ps_config,
+ start=True) for ix in range(num_ps)
+ ]
+ if has_eval:
+ evals = [
+ server_lib.Server(
+ cs,
+ job_name="evaluator",
+ protocol=protocol,
+ task_index=0,
+ config=worker_config,
+ start=True)
+ ]
+ else:
+ evals = []
+
+ return workers, ps_servers, evals
+
+
+def _create_in_process_cluster(num_workers, num_ps, has_eval=False):
+ """Create an in-process cluster that consists of only standard server."""
+ # Leave some memory for cuda runtime.
+ if has_eval:
+ gpu_mem_frac = 0.7 / (num_workers + 1)
+ else:
+ gpu_mem_frac = 0.7 / num_workers
+
+ worker_config = config_pb2.ConfigProto()
+ worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
+
+ # Enable collective ops, which have no impact on non-collective ops.
+ # TODO(yuefengz, tucker): remove this after we move the initialization of
+ # collective mgr to the session level.
+ worker_config.experimental.collective_group_leader = (
+ "/job:worker/replica:0/task:0")
+
+ ps_config = config_pb2.ConfigProto()
+ ps_config.device_count["GPU"] = 0
+
+ return _create_local_cluster(
+ num_workers,
+ num_ps=num_ps,
+ has_eval=has_eval,
+ worker_config=worker_config,
+ ps_config=ps_config,
+ protocol="grpc")
+
+
+def _create_cluster_spec(has_chief=False,
+ num_workers=1,
+ num_ps=0,
+ has_eval=False):
+ if _portpicker_import_error:
+ raise _portpicker_import_error # pylint: disable=raising-bad-type
+
+ cluster_spec = {}
+ if has_chief:
+ cluster_spec[CHIEF] = ["localhost:%s" % portpicker.pick_unused_port()]
+ if num_workers:
+ cluster_spec[WORKER] = [
+ "localhost:%s" % portpicker.pick_unused_port()
+ for _ in range(num_workers)
+ ]
+ if num_ps:
+ cluster_spec[PS] = [
+ "localhost:%s" % portpicker.pick_unused_port() for _ in range(num_ps)
+ ]
+ if has_eval:
+ cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()]
+ return cluster_spec
+
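For reference, a call like `_create_cluster_spec(has_chief=True, num_workers=2, num_ps=1, has_eval=True)` returns a dict shaped roughly as follows; the ports are picked at random by portpicker, so the numbers here are purely illustrative:

    {
        "chief": ["localhost:12345"],
        "worker": ["localhost:23456", "localhost:34567"],
        "ps": ["localhost:45678"],
        "evaluator": ["localhost:56789"],
    }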
+
+def _bytes_to_str(maybe_bytes):
+ if isinstance(maybe_bytes, six.string_types):
+ return maybe_bytes
+ else:
+ return str(maybe_bytes, "utf-8")
+
+
+def _strip_protocol(target):
+ # cluster_spec expects "host:port" strings.
+ if "//" in target:
+ return target.split("//")[1]
+ else:
+ return target
+
+
+class DistributeCoordinatorIntegrationTest(test.TestCase,
+ parameterized.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ """Create a local cluster with 2 workers."""
+ cls._workers, cls._ps, cls._evals = _create_in_process_cluster(
+ num_workers=3, num_ps=2, has_eval=True)
+ cls._cluster_spec = {
+ "worker": [
+ _strip_protocol(_bytes_to_str(w.target)) for w in cls._workers
+ ],
+ "ps": [_strip_protocol(_bytes_to_str(ps.target)) for ps in cls._ps],
+ "evaluator": [
+ _strip_protocol(_bytes_to_str(e.target)) for e in cls._evals
+ ]
+ }
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+ self._event = threading.Event()
+ super(DistributeCoordinatorIntegrationTest, self).setUp()
+
+ def dataset_input_fn(self, x, y, batch_size, shuffle):
+
+ def input_fn():
+ dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
+ if shuffle:
+ dataset = dataset.shuffle(batch_size)
+ dataset = dataset.repeat(100).batch(batch_size)
+ return dataset
+
+ return input_fn
+
+ def _get_exporter(self, name, fc):
+ feature_spec = feature_column.make_parse_example_spec(fc)
+ serving_input_receiver_fn = (
+ export_lib.build_parsing_serving_input_receiver_fn(feature_spec))
+ return exporter_lib.LatestExporter(
+ name, serving_input_receiver_fn=serving_input_receiver_fn)
+
+ def _extract_loss_and_global_step(self, event_folder):
+ """Returns the loss and global step in last event."""
+ event_paths = glob.glob(os.path.join(event_folder, "events*"))
+
+ loss = None
+ global_step_count = None
+
+ for e in summary_iterator.summary_iterator(event_paths[-1]):
+ current_loss = None
+ for v in e.summary.value:
+ if v.tag == "loss":
+ current_loss = v.simple_value
+
+ # If loss is not found, global step is meaningless.
+ if current_loss is None:
+ continue
+
+ current_global_step = e.step
+ if global_step_count is None or current_global_step > global_step_count:
+ global_step_count = current_global_step
+ loss = current_loss
+
+ return (loss, global_step_count)
+
+ def _get_estimator(self,
+ train_distribute,
+ eval_distribute,
+ remote_cluster=None):
+ input_dimension = LABEL_DIMENSION
+ linear_feature_columns = [
+ feature_column.numeric_column("x", shape=(input_dimension,))
+ ]
+ dnn_feature_columns = [
+ feature_column.numeric_column("x", shape=(input_dimension,))
+ ]
+
+ return dnn_linear_combined.DNNLinearCombinedRegressor(
+ linear_feature_columns=linear_feature_columns,
+ dnn_hidden_units=(2, 2),
+ dnn_feature_columns=dnn_feature_columns,
+ label_dimension=LABEL_DIMENSION,
+ model_dir=self._model_dir,
+ dnn_optimizer=adagrad.AdagradOptimizer(0.001),
+ linear_optimizer=adagrad.AdagradOptimizer(0.001),
+ config=run_config_lib.RunConfig(
+ experimental_distribute=DistributeConfig(
+ train_distribute=train_distribute,
+ eval_distribute=eval_distribute,
+ remote_cluster=remote_cluster)))
+
+ def _complete_flow(self,
+ train_distribute,
+ eval_distribute,
+ remote_cluster=None):
+ estimator = self._get_estimator(train_distribute, eval_distribute,
+ remote_cluster)
+
+ input_dimension = LABEL_DIMENSION
+ train_input_fn = self.dataset_input_fn(
+ x={"x": DATA},
+ y=DATA,
+ batch_size=BATCH_SIZE // len(train_distribute.worker_devices),
+ shuffle=True)
+ if eval_distribute:
+ eval_batch_size = BATCH_SIZE // len(eval_distribute.worker_devices)
+ else:
+ eval_batch_size = BATCH_SIZE
+ eval_input_fn = self.dataset_input_fn(
+ x={"x": DATA}, y=DATA, batch_size=eval_batch_size, shuffle=False)
+
+ linear_feature_columns = [
+ feature_column.numeric_column("x", shape=(input_dimension,))
+ ]
+ dnn_feature_columns = [
+ feature_column.numeric_column("x", shape=(input_dimension,))
+ ]
+ feature_columns = linear_feature_columns + dnn_feature_columns
+
+ estimator_training.train_and_evaluate(
+ estimator,
+ estimator_training.TrainSpec(train_input_fn, max_steps=MAX_STEPS),
+ estimator_training.EvalSpec(
+ name=EVAL_NAME,
+ input_fn=eval_input_fn,
+ steps=None,
+ exporters=self._get_exporter(EXPORTER_NAME, feature_columns),
+ start_delay_secs=0,
+ throttle_secs=1))
+ return estimator
+
+ def _inspect_train_and_eval_events(self, estimator):
+ # Make sure nothing is stuck in limbo.
+ writer_cache.FileWriterCache.clear()
+
+ # Examine the training events. Use a range to check the global step to
+ # avoid flakiness due to a global step race condition.
+ training_loss, _ = self._extract_loss_and_global_step(self._model_dir)
+ self.assertIsNotNone(training_loss)
+
+ # Examine the eval events. The global step should be accurate.
+ eval_dir = os.path.join(self._model_dir, "eval_" + EVAL_NAME)
+ eval_loss, eval_global_step = self._extract_loss_and_global_step(
+ event_folder=eval_dir)
+ self.assertIsNotNone(eval_loss)
+ self.assertGreaterEqual(eval_global_step, MAX_STEPS)
+
+ # Examine the export folder.
+ export_dir = os.path.join(
+ os.path.join(self._model_dir, "export"), EXPORTER_NAME)
+ self.assertTrue(gfile.Exists(export_dir))
+
+ # Examine the ckpt for predict.
+ def predict_input_fn():
+ return dataset_ops.Dataset.from_tensor_slices({
+ "x": DATA
+ }).batch(BATCH_SIZE)
+
+ predicted_proba = np.array([
+ x[prediction_keys.PredictionKeys.PREDICTIONS]
+ for x in estimator.predict(predict_input_fn)
+ ])
+ self.assertAllEqual((BATCH_SIZE, LABEL_DIMENSION), predicted_proba.shape)
+
+ @combinations.generate(
+ combinations.combine(
+ mode=["graph"],
+ train_distribute_cls=[
+ mirrored_strategy.MirroredStrategy,
+ parameter_server_strategy.ParameterServerStrategy
+ ],
+ eval_distribute_cls=[
+ None, mirrored_strategy.MirroredStrategy,
+ parameter_server_strategy.ParameterServerStrategy
+ ],
+ required_gpus=1))
+ def test_complete_flow_standalone_client(self, train_distribute_cls,
+ eval_distribute_cls):
+ try:
+ train_distribute = train_distribute_cls(num_gpus=context.num_gpus())
+ except TypeError:
+ train_distribute = train_distribute_cls(num_gpus_per_worker=2)
+
+ if eval_distribute_cls:
+ eval_distribute = eval_distribute_cls()
+ else:
+ eval_distribute = None
+
+ estimator = self._complete_flow(
+ train_distribute, eval_distribute, remote_cluster=self._cluster_spec)
+ self._inspect_train_and_eval_events(estimator)
+
+ def _mock_run_distribute_coordinator(
+ self,
+ worker_fn,
+ strategy,
+ eval_fn,
+ eval_strategy,
+ mode=dc.CoordinatorMode.STANDALONE_CLIENT,
+ cluster_spec=None,
+ session_config=None):
+ # Calls the original `run_distribute_coordinator` method, but gets the task
+ # config from environment variables and then signals the caller.
+ task_type = None
+ task_id = None
+ if not cluster_spec:
+ cluster_spec = None
+ tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
+ if not cluster_spec:
+ cluster_spec = tf_config.get("cluster", {})
+ task_env = tf_config.get("task", {})
+ if task_env:
+ task_type = task_env.get("type", task_type)
+ task_id = int(task_env.get("index", task_id))
+ self._event.set()
+ original_run_distribute_coordinator(
+ worker_fn,
+ strategy,
+ eval_fn,
+ eval_strategy,
+ mode=mode,
+ cluster_spec=cluster_spec,
+ task_type=task_type,
+ task_id=task_id,
+ session_config=session_config)
+
+ def _task_thread(self, train_distribute, eval_distribute):
+ with test.mock.patch.object(dc, "run_distribute_coordinator",
+ self._mock_run_distribute_coordinator):
+ self._complete_flow(train_distribute, eval_distribute)
+
+ def _run_task_in_thread(self, cluster_spec, task_type, task_id,
+ train_distribute, eval_distribute):
+ tf_config = {
+ "cluster": cluster_spec,
+ "task": {
+ "type": task_type,
+ "index": task_id
+ }
+ }
+ self._event.clear()
+ t = threading.Thread(
+ target=self._task_thread, args=(train_distribute, eval_distribute))
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(tf_config)}):
+ t.start()
+ self._event.wait()
+ return t
+
+ def _run_multiple_tasks_in_threads(self, cluster_spec, train_distribute,
+ eval_distribute):
+ threads = {}
+ for task_type in cluster_spec.keys():
+ threads[task_type] = []
+ for task_id in range(len(cluster_spec[task_type])):
+ t = self._run_task_in_thread(cluster_spec, task_type, task_id,
+ train_distribute, eval_distribute)
+ threads[task_type].append(t)
+ return threads
+
+ @combinations.generate(
+ combinations.combine(
+ mode=["graph"],
+ train_distribute_cls=[
+ parameter_server_strategy.ParameterServerStrategy,
+ ],
+ eval_distribute_cls=[
+ None, mirrored_strategy.MirroredStrategy,
+ parameter_server_strategy.ParameterServerStrategy
+ ],
+ required_gpus=1))
+ def test_complete_flow_indepedent_worker_between_graph(
+ self, train_distribute_cls, eval_distribute_cls):
+ train_distribute = train_distribute_cls(
+ num_gpus_per_worker=context.num_gpus())
+
+ if eval_distribute_cls:
+ eval_distribute = eval_distribute_cls()
+ else:
+ eval_distribute = None
+
+ cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True)
+ threads = self._run_multiple_tasks_in_threads(
+ cluster_spec, train_distribute, eval_distribute)
+ for task_type, ts in threads.items():
+ if task_type == PS:
+ continue
+ for t in ts:
+ t.join()
+
+ estimator = self._get_estimator(train_distribute, eval_distribute)
+ self._inspect_train_and_eval_events(estimator)
+
+ @combinations.generate(
+ combinations.combine(
+ mode=["graph"],
+ train_distribute_cls=[mirrored_strategy.MirroredStrategy],
+ eval_distribute_cls=[None, mirrored_strategy.MirroredStrategy],
+ required_gpus=1))
+ def test_complete_flow_indepedent_worker_in_graph(self, train_distribute_cls,
+ eval_distribute_cls):
+ train_distribute = train_distribute_cls(num_gpus=context.num_gpus())
+
+ if eval_distribute_cls:
+ eval_distribute = eval_distribute_cls()
+ else:
+ eval_distribute = None
+
+ cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True)
+ threads = self._run_multiple_tasks_in_threads(
+ cluster_spec, train_distribute, eval_distribute)
+ threads[WORKER][0].join()
+ threads[EVALUATOR][0].join()
+
+ estimator = self._get_estimator(train_distribute, eval_distribute)
+ self._inspect_train_and_eval_events(estimator)
+
+
+TF_CONFIG_WITH_CHIEF = {
+ "cluster": {
+ "chief": ["fake_chief"],
+ },
+ "task": {
+ "type": "chief",
+ "index": 0
+ }
+}
+
+TF_CONFIG_WITH_MASTER = {
+ "cluster": {
+ "master": ["fake_master"],
+ },
+ "task": {
+ "type": "master",
+ "index": 0
+ }
+}
+
+TF_CONFIG_WITHOUT_TASK = {"cluster": {"chief": ["fake_worker"]}}
+
+
+class RunConfigTest(test.TestCase):
+
+ def test_previously_unexpected_cluster_spec(self):
+ with test.mock.patch.dict(
+ "os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITHOUT_TASK)}):
+ run_config_lib.RunConfig(
+ experimental_distribute=DistributeConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
+
+ def test_should_run_distribute_coordinator(self):
+ """Tests that should_run_distribute_coordinator return a correct value."""
+ # We don't use distribute coordinator for local training.
+ self.assertFalse(
+ dc_training.should_run_distribute_coordinator(
+ run_config_lib.RunConfig()))
+
+ # When `train_distribute` is not specified, don't use distribute
+ # coordinator.
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
+ self.assertFalse(
+ dc_training.should_run_distribute_coordinator(
+ run_config_lib.RunConfig()))
+
+ # When `train_distribute` is specified and TF_CONFIG is detected, use
+ # distribute coordinator.
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
+ config_with_train_distribute = run_config_lib.RunConfig(
+ experimental_distribute=DistributeConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
+ config_with_eval_distribute = run_config_lib.RunConfig(
+ experimental_distribute=DistributeConfig(
+ eval_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
+ self.assertTrue(
+ dc_training.should_run_distribute_coordinator(
+ config_with_train_distribute))
+ self.assertFalse(
+ dc_training.should_run_distribute_coordinator(
+ config_with_eval_distribute))
+
+ # With a master in the cluster, don't run distribute coordinator.
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}):
+ config = run_config_lib.RunConfig(
+ experimental_distribute=DistributeConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
+ self.assertFalse(dc_training.should_run_distribute_coordinator(config))
+
+ def test_init_run_config_duplicate_distribute(self):
+ with self.assertRaises(ValueError):
+ run_config_lib.RunConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy(),
+ experimental_distribute=DistributeConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy()))
+
+ with self.assertRaises(ValueError):
+ run_config_lib.RunConfig(
+ eval_distribute=mirrored_strategy.MirroredStrategy(),
+ experimental_distribute=DistributeConfig(
+ eval_distribute=mirrored_strategy.MirroredStrategy()))
+
+ def test_init_run_config_none_distribute_coordinator_mode(self):
+ # We don't use distribute coordinator for local training.
+ config = run_config_lib.RunConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy())
+ dc_training.init_run_config(config, {})
+ self.assertIsNone(config._distribute_coordinator_mode)
+
+ # With a master in the cluster, don't run distribute coordinator.
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}):
+ config = run_config_lib.RunConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy())
+ self.assertIsNone(config._distribute_coordinator_mode)
+
+ # When `train_distribute` is not specified, don't use distribute
+ # coordinator.
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
+ config = run_config_lib.RunConfig()
+ self.assertFalse(hasattr(config, "_distribute_coordinator_mode"))
+
+ def test_init_run_config_independent_worker(self):
+ # When `train_distribute` is specified and TF_CONFIG is detected, use
+ # distribute coordinator with INDEPENDENT_WORKER mode.
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
+ config = run_config_lib.RunConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy())
+ self.assertEqual(config._distribute_coordinator_mode,
+ dc.CoordinatorMode.INDEPENDENT_WORKER)
+
+ def test_init_run_config_standalone_client(self):
+ # When `train_distribute` is specified, TF_CONFIG is detected and
+ # `experimental.remote_cluster` is set, use distribute coordinator with
+ # STANDALONE_CLIENT mode.
+ config = run_config_lib.RunConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy(),
+ experimental_distribute=DistributeConfig(
+ remote_cluster={"chief": ["fake_worker"]}))
+ self.assertEqual(config._distribute_coordinator_mode,
+ dc.CoordinatorMode.STANDALONE_CLIENT)
+
+
+if __name__ == "__main__":
+ with test.mock.patch.object(sys, "exit", os._exit):
+ test.main()
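The independent-worker tests above drive everything through the TF_CONFIG environment variable, which `_run_task_in_thread` serializes into the mocked environment. A minimal sketch of what one worker process would see (addresses illustrative):

    import json
    import os

    tf_config = {
        "cluster": {
            "chief": ["localhost:2222"],
            "worker": ["localhost:2223", "localhost:2224"],
            "ps": ["localhost:2225"],
        },
        # Every task gets the same "cluster" dict but its own "task" entry;
        # this process plays worker 0.
        "task": {"type": "worker", "index": 0},
    }
    os.environ["TF_CONFIG"] = json.dumps(tf_config)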
diff --git a/tensorflow/contrib/distribute/python/examples/BUILD b/tensorflow/contrib/distribute/python/examples/BUILD
index cbfd178502..84b106545e 100644
--- a/tensorflow/contrib/distribute/python/examples/BUILD
+++ b/tensorflow/contrib/distribute/python/examples/BUILD
@@ -19,9 +19,20 @@ py_binary(
)
py_binary(
- name = "simple_tfkeras_example",
+ name = "keras_model_with_estimator",
srcs = [
- "simple_tfkeras_example.py",
+ "keras_model_with_estimator.py",
+ ],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_binary(
+ name = "keras_mnist",
+ srcs = [
+ "keras_mnist.py",
],
deps = [
"//tensorflow:tensorflow_py",
diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py
new file mode 100644
index 0000000000..a84ef04196
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py
@@ -0,0 +1,125 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""An example training a Keras Model using MirroredStrategy and native APIs."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+NUM_CLASSES = 10
+
+
+def get_input_datasets():
+ """Downloads the MNIST dataset and creates train and eval dataset objects.
+
+ Returns:
+ Train dataset, eval dataset and input shape.
+
+ """
+ # input image dimensions
+ img_rows, img_cols = 28, 28
+
+ # the data, split between train and test sets
+ (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
+
+ if tf.keras.backend.image_data_format() == 'channels_first':
+ x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
+ x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
+ input_shape = (1, img_rows, img_cols)
+ else:
+ x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
+ x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
+ input_shape = (img_rows, img_cols, 1)
+
+ x_train = x_train.astype('float32')
+ x_test = x_test.astype('float32')
+ x_train /= 255
+ x_test /= 255
+
+ # convert class vectors to binary class matrices
+ y_train = tf.keras.utils.to_categorical(y_train, NUM_CLASSES)
+ y_test = tf.keras.utils.to_categorical(y_test, NUM_CLASSES)
+
+ # train dataset
+ train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+ train_ds = train_ds.repeat()
+ train_ds = train_ds.shuffle(100)
+ train_ds = train_ds.batch(64, drop_remainder=True)
+
+ # eval dataset
+ eval_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
+ eval_ds = eval_ds.repeat()
+ eval_ds = eval_ds.batch(64, drop_remainder=True)
+
+ return train_ds, eval_ds, input_shape
+
+
+def get_model(input_shape):
+ """Builds a Sequential CNN model to recognize MNIST digits.
+
+ Args:
+ input_shape: Shape of the input depending on the `image_data_format`.
+
+ Returns:
+ a Keras model
+
+ """
+ # Define a CNN model to recognize MNIST digits.
+ model = tf.keras.models.Sequential()
+ model.add(tf.keras.layers.Conv2D(32, kernel_size=(3, 3),
+ activation='relu',
+ input_shape=input_shape))
+ model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
+ model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
+ model.add(tf.keras.layers.Dropout(0.25))
+ model.add(tf.keras.layers.Flatten())
+ model.add(tf.keras.layers.Dense(128, activation='relu'))
+ model.add(tf.keras.layers.Dropout(0.5))
+ model.add(tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'))
+ return model
+
+
+def main(_):
+ # Build the train and eval datasets from the MNIST data. Also return the
+ # input shape, which is constructed based on the `image_data_format`,
+ # i.e. channels_first or channels_last.
+ train_ds, eval_ds, input_shape = get_input_datasets()
+ model = get_model(input_shape)
+
+ # Instantiate the MirroredStrategy object. If we don't specify `num_gpus` or
+ # the `devices` argument, then all GPUs available on the machine are used.
+ strategy = tf.contrib.distribute.MirroredStrategy()
+
+ # Compile the model by passing the distribution strategy object to the
+ # `distribute` argument. `fit`, `evaluate` and `predict` will be distributed
+ # based on the strategy instantiated.
+ model.compile(loss=tf.keras.losses.categorical_crossentropy,
+ optimizer=tf.train.RMSPropOptimizer(learning_rate=0.001),
+ metrics=['accuracy'],
+ distribute=strategy)
+
+ # Train the model with the train dataset.
+ model.fit(x=train_ds, epochs=20, steps_per_epoch=310)
+
+ # Evaluate the model with the eval dataset.
+ score = model.evaluate(eval_ds, steps=10, verbose=0)
+ print('Test loss:', score[0])
+ print('Test accuracy:', score[1])
+
+
+if __name__ == '__main__':
+ tf.app.run()
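If you want this example to mirror across specific devices instead of every available GPU, `MirroredStrategy` also accepts an explicit device list, as the companion `keras_model_with_estimator` example below does; a one-line sketch:

    strategy = tf.contrib.distribute.MirroredStrategy(
        ['/device:GPU:0', '/device:GPU:1'])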
diff --git a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py b/tensorflow/contrib/distribute/python/examples/keras_model_with_estimator.py
index 518ec9c423..8d117eb7e8 100644
--- a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py
+++ b/tensorflow/contrib/distribute/python/examples/keras_model_with_estimator.py
@@ -42,19 +42,19 @@ def main(args):
model_dir = args[1]
print('Using %s to store checkpoints.' % model_dir)
- # Define tf.keras Model.
+ # Define a Keras Model.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(16, activation='relu', input_shape=(10,)))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
- # Compile tf.keras Model.
+ # Compile the model.
optimizer = tf.train.GradientDescentOptimizer(0.2)
model.compile(loss='binary_crossentropy', optimizer=optimizer)
model.summary()
tf.keras.backend.set_learning_phase(True)
- # Define a DistributionStrategy and convert the tf.keras Model to a
- # tf.Estimator that utilizes the DistributionStrategy.
+ # Define a DistributionStrategy and convert the Keras Model to an
+ # Estimator that utilizes the DistributionStrategy.
strategy = tf.contrib.distribute.MirroredStrategy(
['/device:GPU:0', '/device:GPU:1'])
config = tf.estimator.RunConfig(
@@ -62,7 +62,7 @@ def main(args):
keras_estimator = tf.keras.estimator.model_to_estimator(
keras_model=model, config=config, model_dir=model_dir)
- # Train and evaluate the tf.Estimator.
+ # Train and evaluate the model.
keras_estimator.train(input_fn=input_fn, steps=10)
eval_result = keras_estimator.evaluate(input_fn=input_fn)
print('Eval result: {}'.format(eval_result))
diff --git a/tensorflow/contrib/distribute/python/input_ops.py b/tensorflow/contrib/distribute/python/input_ops.py
index 1f24f62947..f07ec8234d 100644
--- a/tensorflow/contrib/distribute/python/input_ops.py
+++ b/tensorflow/contrib/distribute/python/input_ops.py
@@ -47,11 +47,8 @@ def auto_shard_dataset(dataset, num_shards, index):
Returns:
A modified `Dataset` obtained by updating the pipeline sharded by the
- files.
-
- Raises:
- NotImplementedError: If we cannot automatically determine a good way to
- shard the input dataset.
+ files. The input dataset will be returned if we cannot automatically
+ determine a good way to shard the input dataset.
"""
# TODO(priyag): Clone datasets instead of updating in place, similar to the
@@ -127,8 +124,10 @@ def auto_shard_dataset(dataset, num_shards, index):
tf_logging.warn(
"Could not find a standard reader in the input pipeline"
"(one of TextLineDataset, TFRecordDataset, FixedLengthRecordDataset)."
- "Falling back to sharding the dataset anyway. Please verify"
- "correctness of auto-sharding for your input.")
+ "So auto-sharding is not done. Please verify correctness of "
+ "auto-sharding for your input.")
+ # TODO(yuefengz): maybe still shard it?
+ return dataset
# TODO(priyag): What do we want to do if the number of filenames is
# uneven in the number of shards? By default, this will just return as
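The file-level sharding that `auto_shard_dataset` tries to inject corresponds, in spirit, to calling `shard()` on the filename dataset before any reader consumes it. A hand-rolled sketch of that idea using standard `tf.data` calls (the file pattern and counts are illustrative; this is not the module's actual implementation):

    import tensorflow as tf

    def sharded_input(file_pattern, num_shards, shard_index, batch_size=32):
      # Shard at the file level so each worker reads a disjoint set of files,
      # then read the records in each file and batch them.
      files = tf.data.Dataset.list_files(file_pattern, shuffle=False)
      files = files.shard(num_shards, shard_index)
      dataset = files.flat_map(tf.data.TFRecordDataset)
      return dataset.batch(batch_size)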
diff --git a/tensorflow/contrib/distribute/python/input_ops_test.py b/tensorflow/contrib/distribute/python/input_ops_test.py
index 16179c3a49..c5acb7ced4 100644
--- a/tensorflow/contrib/distribute/python/input_ops_test.py
+++ b/tensorflow/contrib/distribute/python/input_ops_test.py
@@ -91,7 +91,7 @@ class AutoShardDatasetTest(test.TestCase):
def _verifySimpleShardingOutput(self, dataset, record_fn):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for f in range(self._shard_index, self._num_files, self._num_shards):
for r in range(self._num_records):
self.assertAllEqual(record_fn(r, f), sess.run(next_element))
@@ -150,7 +150,7 @@ class AutoShardDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual, expected = [], []
for f in range(self._shard_index, self._num_files, self._num_shards):
for r in range(self._num_records):
@@ -182,7 +182,7 @@ class AutoShardDatasetTest(test.TestCase):
# Verify output.
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual = []
num_iterations = (self._num_files * self._num_records * num_epochs) // (
self._num_shards * batch_size)
@@ -218,7 +218,7 @@ class AutoShardDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for f in range(self._shard_index, self._num_files, self._num_shards):
for r in range(self._num_records):
self.assertAllEqual(self._record(r, f), sess.run(next_element))
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index 4facd72d12..9e1762d92c 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -18,9 +18,12 @@ from __future__ import division
from __future__ import print_function
import os
+from absl.testing import parameterized
import numpy as np
+from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import tpu_strategy
from tensorflow.contrib.distribute.python import values
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
@@ -91,6 +94,25 @@ def get_ds_test_input_fn():
return dataset
+def batch_wrapper(dataset, batch_size, distribution):
+ # TPUs currently require fully defined input shapes; drop_remainder ensures
+ # the input will have fully defined shapes (see the sketch after these
+ # helpers).
+ if isinstance(distribution, tpu_strategy.TPUStrategy):
+ return dataset.batch(batch_size, drop_remainder=True)
+ else:
+ return dataset.batch(batch_size)
+
+
+def all_combinations():
+ return combinations.combine(
+ distribution=[combinations.default_strategy,
+ combinations.one_device_strategy,
+ combinations.mirrored_strategy_with_gpu_and_cpu,
+ combinations.mirrored_strategy_with_two_gpus,
+ combinations.tpu_strategy_one_step],
+ mode=['graph'])
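To see why `batch_wrapper` matters for TPUs, consider ten examples batched by three (a sketch; sizes illustrative):

    # Without drop_remainder the final batch holds a single element, so the
    # batch dimension is statically unknown. With drop_remainder=True the
    # short batch is discarded and every batch has a fully defined shape,
    # which is what TPUStrategy requires.
    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.batch(3, drop_remainder=True)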
+
+
class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
def setUp(self):
@@ -116,7 +138,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
model_dir=self._base_dir,
train_distribute=dist,
eval_distribute=dist)
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, config=config)
before_eval_results = est_keras.evaluate(
@@ -139,7 +161,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
model_dir=self._base_dir,
train_distribute=dist)
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, config=config)
before_eval_results = est_keras.evaluate(
@@ -163,7 +185,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
model_dir=self._base_dir,
train_distribute=dist)
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(keras_model=keras_model,
config=config)
with self.assertRaisesRegexp(ValueError,
@@ -175,10 +197,10 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
gfile.DeleteRecursively(self._config.model_dir)
-class TestWithDistributionStrategy(test.TestCase):
+class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
def test_validating_dataset_input_tensors_with_shape_mismatch(self):
- with self.test_session():
+ with self.cached_session():
strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
'/device:CPU:0'])
a = constant_op.constant([1, 2], shape=(1, 2))
@@ -197,7 +219,7 @@ class TestWithDistributionStrategy(test.TestCase):
strategy, x, y)
def test_validating_dataset_input_tensors_with_dtype_mismatch(self):
- with self.test_session():
+ with self.cached_session():
strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
'/device:CPU:0'])
a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32)
@@ -215,8 +237,8 @@ class TestWithDistributionStrategy(test.TestCase):
distributed_training_utils.validate_distributed_dataset_inputs(
strategy, x, y)
- def test_calling_model_on_same_dataset(self):
- with self.test_session():
+ def test_calling_model_with_numpy_arrays(self):
+ with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
@@ -228,11 +250,44 @@ class TestWithDistributionStrategy(test.TestCase):
'/device:GPU:0'])
model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+ inputs = np.zeros((64, 3), dtype=np.float32)
+ targets = np.zeros((64, 4), dtype=np.float32)
+
+ # Call fit with validation data
+ model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0,
+ validation_data=(inputs, targets))
+
+ # TODO(anjalisridhar): We need tests for when the batch size and steps are
+ # smaller and results in a 0 batch_size and steps value.
+ model.evaluate(inputs, targets)
+ # with steps
+ model.evaluate(inputs, targets, steps=2)
+ # with batch_size
+ model.evaluate(inputs, targets, batch_size=8)
+
+ model.predict(inputs)
+ # with steps
+ model.predict(inputs, steps=2)
+ # with batch_size
+ model.predict(inputs, batch_size=8)
+
+ @combinations.generate(all_combinations())
+ def test_calling_model_on_same_dataset(self, distribution):
+ with self.cached_session():
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+
+ optimizer = gradient_descent.GradientDescentOptimizer(0.001)
+ loss = 'mse'
+ metrics = ['mae']
+ model.compile(optimizer, loss, metrics=metrics, distribute=distribution)
+
inputs = np.zeros((10, 3), dtype=np.float32)
targets = np.zeros((10, 4), dtype=np.float32)
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
+ dataset = batch_wrapper(dataset, 10, distribution)
# Call fit with validation data
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
@@ -241,8 +296,11 @@ class TestWithDistributionStrategy(test.TestCase):
validation_data=dataset, validation_steps=2)
model.predict(dataset, steps=2)
+ # TODO(priyag): Enable this test for TPU. Currently tuples/dicts don't work
+ # because clone_model's input_tensors argument only seems to accept lists,
+ # not tuples or dicts.
def test_fit_with_tuple_and_dict_dataset_inputs(self):
- with self.test_session():
+ with self.cached_session():
a = keras.layers.Input(shape=(3,), name='input_a')
b = keras.layers.Input(shape=(3,), name='input_b')
@@ -282,8 +340,9 @@ class TestWithDistributionStrategy(test.TestCase):
model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1)
- def test_fit_eval_and_predict_methods_on_dataset(self):
- with self.test_session():
+ @combinations.generate(all_combinations())
+ def test_fit_eval_and_predict_methods_on_dataset(self, distribution):
+ with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
@@ -291,16 +350,13 @@ class TestWithDistributionStrategy(test.TestCase):
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
metrics = ['mae']
- strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
- '/device:CPU:0'])
-
- model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+ model.compile(optimizer, loss, metrics=metrics, distribute=distribution)
inputs = np.zeros((10, 3), dtype=np.float32)
targets = np.zeros((10, 4), dtype=np.float32)
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
+ dataset = batch_wrapper(dataset, 10, distribution)
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
model.evaluate(dataset, steps=2, verbose=1)
@@ -320,7 +376,7 @@ class TestWithDistributionStrategy(test.TestCase):
def __call__(self, y_true, y_pred):
return y_pred - y_true
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
@@ -336,7 +392,7 @@ class TestWithDistributionStrategy(test.TestCase):
model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
def test_unsupported_features(self):
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
@@ -367,8 +423,8 @@ class TestWithDistributionStrategy(test.TestCase):
# Test with sample weight.
sample_weight = np.random.random((10,))
with self.assertRaisesRegexp(
- NotImplementedError, 'sample_weight is currently not supported when '
- 'using DistributionStrategy.'):
+ NotImplementedError, '`sample_weight` is currently not supported '
+ 'when using DistributionStrategy.'):
model.fit(
dataset,
epochs=1,
@@ -389,7 +445,7 @@ class TestWithDistributionStrategy(test.TestCase):
model.predict(dataset, verbose=0)
def test_calling_with_unsupported_predefined_callbacks(self):
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
@@ -428,7 +484,7 @@ class TestWithDistributionStrategy(test.TestCase):
callbacks=[keras.callbacks.TensorBoard(histogram_freq=10)])
def test_dataset_input_shape_validation(self):
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
@@ -446,8 +502,7 @@ class TestWithDistributionStrategy(test.TestCase):
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
- with self.assertRaisesRegexp(ValueError,
- 'expected input to have 2 dimensions'):
+ with self.assertRaisesRegexp(ValueError, 'expected input to have shape'):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
# Wrong input shape
@@ -465,7 +520,7 @@ class TestWithDistributionStrategy(test.TestCase):
# TODO(anjalisridhar): Modify this test to use Lambdas since we can compare
# meaningful values. Currently we don't pass the learning phase if the
# Lambda layer uses the learning phase.
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(16,), name='input')
y = keras.layers.Dense(16)(x)
z = keras.layers.Dropout(0.9999)(y)
@@ -497,8 +552,10 @@ class TestWithDistributionStrategy(test.TestCase):
class LossMaskingWithDistributionStrategyTest(test.TestCase):
+ # TODO(priyag): Enable all strategies for this test. Currently it does not
+ # work for TPU due to some invalid datatype.
def test_masking(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
x = np.array([[[1], [1]], [[0], [0]]])
model = keras.models.Sequential()
@@ -520,24 +577,25 @@ class LossMaskingWithDistributionStrategyTest(test.TestCase):
self.assertEqual(hist.history['loss'][0], 0)
-class NormalizationLayerWithDistributionStrategyTest(test.TestCase):
+class NormalizationLayerWithDistributionStrategyTest(
+ test.TestCase, parameterized.TestCase):
- def test_batchnorm_correctness(self):
- with self.test_session():
+ @combinations.generate(all_combinations())
+ def test_batchnorm_correctness(self, distribution):
+ with self.cached_session():
model = keras.models.Sequential()
norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8)
model.add(norm)
- strategy = mirrored_strategy.MirroredStrategy(['/device:CPU:0',
- '/device:GPU:0'])
model.compile(loss='mse',
optimizer=gradient_descent.GradientDescentOptimizer(0.01),
- distribute=strategy)
+ distribute=distribution)
# centered on 5.0, variance 10.0
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10))
+ x = x.astype('float32')
dataset = dataset_ops.Dataset.from_tensor_slices((x, x))
dataset = dataset.repeat(100)
- dataset = dataset.batch(32)
+ dataset = batch_wrapper(dataset, 32, distribution)
model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10)
out = model.predict(dataset, steps=2)
@@ -547,10 +605,12 @@ class NormalizationLayerWithDistributionStrategyTest(test.TestCase):
np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
-class CorrectnessWithDistributionStrategyTest(test.TestCase):
+class CorrectnessWithDistributionStrategyTest(test.TestCase,
+ parameterized.TestCase):
- def test_correctness(self):
- with self.test_session():
+ @combinations.generate(all_combinations())
+ def test_correctness(self, distribution):
+ with self.cached_session():
keras.backend.set_image_data_format('channels_last')
num_samples = 10000
x_train = np.random.rand(num_samples, 1)
@@ -558,44 +618,43 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase):
x_train = x_train.astype('float32')
y_train = y_train.astype('float32')
- model = keras.Sequential()
- model.add(keras.layers.Dense(1, input_shape=(1,)))
-
- # With DistributionStrategy
- dataset_with = dataset_ops.Dataset.from_tensor_slices((x_train, y_train))
- dataset_with = dataset_with.batch(32)
- strategy = mirrored_strategy.MirroredStrategy(devices=['/device:CPU:0',
- '/device:GPU:0'],
- prefetch_on_device=False)
-
- model.compile(loss=keras.losses.mean_squared_error,
- optimizer=gradient_descent.GradientDescentOptimizer(0.5),
- distribute=strategy)
- model.fit(x=dataset_with, epochs=1, steps_per_epoch=310)
- wts_with_ds = model.get_weights()
-
- x_predict = [[1], [2], [3], [4]]
- predict_dataset_with = dataset_ops.Dataset.from_tensor_slices((x_predict,
- x_predict))
- predict_dataset_with = predict_dataset_with.batch(2)
- predict_with_ds = model.predict(predict_dataset_with, steps=1)
- predict_with_ds = np.reshape(predict_with_ds, (4, 1))
-
- # Without DistributionStrategy
- dataset_without = dataset_ops.Dataset.from_tensor_slices((x_train,
+ def fit_and_predict(with_distribution=None):
+ model = keras.Sequential()
+ model.add(keras.layers.Dense(1, input_shape=(1,)))
+ model.compile(
+ loss=keras.losses.mean_squared_error,
+ optimizer=gradient_descent.GradientDescentOptimizer(0.5),
+ distribute=with_distribution)
+
+ batch_size = 64
+ if with_distribution:
+ batch_size //= with_distribution.num_towers
+ train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train,
y_train))
- dataset_without = dataset_without.batch(64)
-
- model.compile(loss=keras.losses.mean_squared_error,
- optimizer=gradient_descent.GradientDescentOptimizer(0.5))
- model.fit(x=dataset_without, epochs=1, steps_per_epoch=310)
- wts_without_ds = model.get_weights()
-
- x_predict = [[1], [2], [3], [4]]
- predict_dataset_without = dataset_ops.Dataset.from_tensor_slices((
- x_predict, x_predict))
- predict_dataset_without = predict_dataset_without.batch(4)
- predict_without_ds = model.predict(predict_dataset_without, steps=1)
+ train_dataset = batch_wrapper(train_dataset, batch_size, distribution)
+ # Running only 100 steps instead of the full dataset to keep test
+ # duration small.
+ model.fit(x=train_dataset, epochs=1, steps_per_epoch=100)
+
+ weights = model.get_weights()
+
+ x_predict = [[1.], [2.], [3.], [4.]]
+ predict_batch_size = 4
+ if with_distribution:
+ predict_batch_size //= with_distribution.num_towers
+ predict_dataset = dataset_ops.Dataset.from_tensor_slices((x_predict,
+ x_predict))
+ predict_dataset = batch_wrapper(predict_dataset,
+ predict_batch_size, distribution)
+ predict_result = model.predict(predict_dataset, steps=1)
+ predict_result = np.reshape(predict_result, (4, 1))
+
+ return weights, predict_result
+
+ wts_with_ds, predict_with_ds = fit_and_predict(
+ with_distribution=distribution)
+ wts_without_ds, predict_without_ds = fit_and_predict(
+ with_distribution=None)
# Verify that the weights are the same within some limits of tolerance.
np.testing.assert_allclose(wts_with_ds[0], wts_without_ds[0], rtol=1e-3)
@@ -604,5 +663,8 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase):
np.testing.assert_allclose(predict_with_ds, predict_without_ds, rtol=1e-3)
+# TODO(priyag): Add a test for TPUStrategy with steps_per_run > 1.
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py
index 2f3d6bdd3f..8163494c8e 100644
--- a/tensorflow/contrib/distribute/python/metrics_v1_test.py
+++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py
@@ -68,6 +68,8 @@ def _regression_dataset_fn():
"predictions": [1., .75, .25, 0.]}).repeat()
+# TODO(priyag): Add TPU Strategy to this once metrics aggregate correctly using
+# TowerLocalVariables on TPUs. Submit http://cl/208914352.
def all_combinations():
return combinations.combine(
distribution=[combinations.default_strategy,
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index 516ede7ade..bdac4fb58c 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -71,7 +71,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
self.evaluate(distribution.initialize())
if not context.executing_eagerly():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
@@ -108,7 +108,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
model_fn, iterator.get_next(), run_concurrently=layer.built))
if not context.executing_eagerly():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
@@ -168,7 +168,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
self.evaluate(distribution.initialize())
if not context.executing_eagerly():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
@@ -249,7 +249,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
self.evaluate(distribution.initialize())
if not context.executing_eagerly():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
@@ -343,7 +343,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
self.evaluate(distribution.initialize())
if not context.executing_eagerly():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
@@ -466,7 +466,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
self.evaluate(distribution.initialize())
if not context.executing_eagerly():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 72d1c6b7dd..0c6805d682 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -19,12 +19,14 @@ from __future__ import division
from __future__ import print_function
import contextlib
+from functools import partial
import threading
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import shared_variable_creator
from tensorflow.contrib.distribute.python import values
from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
@@ -63,7 +65,7 @@ class _RequestedStop(Exception):
pass
-# Make _call_for_each_tower and _reduce_non_distributed_value not members of
+# _call_for_each_tower and _reduce_non_distributed_value are not members of
# MirroredStrategy so that they are generally not allowed to use anything
# specific to MirroredStrategy and thus can be shared with other distribution
# strategies.
@@ -195,10 +197,12 @@ def _reduce_non_distributed_value(distribution, aggregation, value,
# and equal to 0.
if value == 0:
return 0
- # If the aggregation type is MEAN, then this essentially means that the same
- # value should be on all destinations.
- if aggregation == variable_scope.VariableAggregation.MEAN:
- return distribution.broadcast(value, destinations)
+ # If the aggregation type is MEAN or ONLY_FIRST_TOWER, then this
+ # essentially means that the same value should be on all destinations.
+ if aggregation in (
+ variable_scope.VariableAggregation.MEAN,
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER):
+ return value
cross_tower_ops_lib.validate_destinations(destinations)
# We do not support an aggregation type of SUM if the value is the same across
@@ -206,8 +210,8 @@ def _reduce_non_distributed_value(distribution, aggregation, value,
# and summing up identical values across towers is not clearly defined.
if (len(distribution.worker_devices) != 1 or
not cross_tower_ops_lib.check_destinations(destinations)):
- raise ValueError("A non-DistributedValues value cannot be reduced with the "
- "given aggregation.")
+ raise ValueError("A non-DistributedValues value %s cannot be reduced with "
+ "the given aggregation %s." % (value, aggregation))
# TODO(anjalisridhar): Moves these methods to a device utility file?
devices = cross_tower_ops_lib.get_devices_from(destinations)
if len(devices) == 1:
@@ -252,11 +256,12 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
# Get aggregation value
aggregation = kwargs.pop("aggregation",
variable_scope.VariableAggregation.NONE)
- if aggregation not in [
+ if aggregation not in (
variable_scope.VariableAggregation.NONE,
variable_scope.VariableAggregation.SUM,
- variable_scope.VariableAggregation.MEAN
- ]:
+ variable_scope.VariableAggregation.MEAN,
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER
+ ):
raise ValueError("Invalid variable aggregation mode: " + aggregation +
" for variable: " + kwargs["name"])
@@ -274,6 +279,9 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
else:
result = values.MirroredVariable(index, index[devices[0]], aggregation)
+ # Add the wrapped variable to the requested collections.
+ # The handling of eager mode and the global step matches
+ # ResourceVariable._init_from_args().
if not context.executing_eagerly():
g = ops.get_default_graph()
# If "trainable" is True, next_creator() will add the member variables
@@ -287,28 +295,140 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
for v in index.values():
l.remove(v)
g.add_to_collections(collections, result)
+ elif ops.GraphKeys.GLOBAL_STEP in collections:
+ ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
+
return result
class MirroredStrategy(distribute_lib.DistributionStrategy):
- """Mirrors vars to distribute across multiple devices on a single machine.
+ """Mirrors vars to distribute across multiple devices and machines.
+
+ This strategy uses one tower per device and sync replication for its multi-GPU
+ version.
+
+ When `cluster_spec` is given by the `configure` method, it turns into the
+ multi-worker version that works on multiple workers with in-graph replication.
+ Note: `configure` will be called by higher-level APIs if running in a
+ distributed environment.
+
+ There are several important concepts for distributed TensorFlow, e.g.
+ `client`, `job`, `task`, `cluster`, `in-graph replication` and
+ `synchronous training`, which are defined in
+ [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed).
+ This distribution strategy inherits these concepts and also clarifies
+ several more:
+ * **In-graph replication**: the `client` creates a single `tf.Graph` that
+ specifies tasks for devices on all workers. The `client` then creates a
+ client session which will talk to the `master` service of a `worker`. Then
+ the `master` will partition the graph and distribute the work to all
+ participating workers.
+ * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one
+ physical machine. There will be multiple `worker`s with different `task`
+ indices. They all do similar work, except that one worker also checkpoints
+ model variables, writes summaries, etc. in addition to its ordinary work.
+
+ The multi-worker version of this class maps one tower to one device on a
+ worker. It mirrors all model variables on all towers. For example, if you have
+ two `worker`s and each `worker` has 4 GPUs, it will create 8 copies of the
+ model variables on these 8 GPUs. Then, as in MirroredStrategy, each tower
+ performs its computation with its own copy of the variables unless in
+ cross-tower mode where variable or tensor reduction happens.
- This strategy uses one tower per device and sync replication.
+ Args:
+ devices: a list of device strings.
+ num_gpus: number of GPUs. For local training, either specify `devices` or
+ `num_gpus`. In distributed training, this must be specified as the number
+ of GPUs on each worker.
+ num_gpus_per_worker: number of GPUs per worker. This is the same as
+ `num_gpus` and only one of `num_gpus` and `num_gpus_per_worker` can be
+ specified.
+ cross_tower_ops: optional, a descendant of `CrossTowerOps`. If this is not
+ set, the `configure` method will try to find the best one.
+ prefetch_on_device: optional boolean to specify whether to prefetch input
+ data to devices.
"""
def __init__(self,
devices=None,
num_gpus=None,
+ num_gpus_per_worker=None,
cross_tower_ops=None,
prefetch_on_device=None):
super(MirroredStrategy, self).__init__()
+
+ self._cross_tower_ops = cross_tower_ops
+ self._prefetch_on_device = prefetch_on_device
+ # Remember the number of GPUs, which might be needed by the `configure` method.
+ if num_gpus is not None and num_gpus_per_worker is not None:
+ raise ValueError(
+ "You cannot specify both `num_gpus` and `num_gpus_per_worker`.")
+ if num_gpus is not None:
+ self._num_gpus = num_gpus
+ else:
+ self._num_gpus = num_gpus_per_worker
+
+ self._initialize_local(self._num_gpus, devices)
+
+ def _initialize_local(self, num_gpus, devices):
+ """Initializes the object for local training."""
+ self._cluster_spec = None
# Convert `num_gpus` into `devices`, shouldn't specify both.
if devices is None:
if num_gpus is None:
num_gpus = context.num_gpus()
- devices = ["/device:GPU:%d" % d for d in range(num_gpus)]
+ if num_gpus == 0:
+ devices = ["/device:CPU:0"]
+ else:
+ devices = ["/device:GPU:%d" % d for d in range(num_gpus)]
elif num_gpus is not None:
raise ValueError("Must only specify one of `devices` and `num_gpus`.")
+ self._num_gpus = num_gpus
+ # TODO(yuefengz): consider setting the default device.
+
+ assert devices, "Must specify at least one device."
+ assert len(set(devices)) == len(devices), (
+ "No duplicates allowed in `devices` argument.")
+ # TODO(josh11b): Require at least 2 devices?
+ self._devices = [device_util.resolve(d) for d in devices]
+ self._canonical_device_set = set(self._devices)
+ self._device_index = values.PerDevice({d: i for i, d in enumerate(devices)})
+
+ def _initialize_multi_worker(self, num_gpus, cluster_spec):
+ """Initializes the object for multi-worker training."""
+ cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
+ self._cluster_spec = cluster_spec
+
+ self._workers = []
+ for job in ["chief", "worker"]:
+ for task in range(len(cluster_spec.as_dict().get(job, []))):
+ self._workers.append("/job:%s/task:%d" % (job, task))
+
+ if num_gpus is None:
+ raise ValueError("`num_gpus` is required if `cluster_spec` is given.")
+ if num_gpus > 0:
+ self._worker_device_map = {
+ worker: [
+ device_util.canonicalize(worker + "/device:GPU:%d" % gpu)
+ for gpu in range(num_gpus)
+ ] for worker in self._workers
+ }
+ else:
+ self._worker_device_map = {
+ worker: [device_util.canonicalize(worker, "/device:CPU:0")]
+ for worker in self._workers
+ }
+
+ devices = nest.flatten(self._worker_device_map)
+
+ # Setting `_default_device` will add a device scope in the
+ # distribution.scope. We set the default device to the first worker. When
+ # users specify a device under distribution.scope by
+ # with tf.device("/cpu:0"):
+ # ...
+ # their ops will end up on the CPU device of the first worker, e.g.
+ # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode.
+ self._default_device = self._workers[0]
assert devices, "Must specify at least one device."
assert len(set(devices)) == len(devices), (
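
For reference, the multi-worker path above is driven through `configure`; a minimal usage sketch based on the tests later in this diff (the worker addresses are placeholders):

    from tensorflow.contrib.distribute.python import mirrored_strategy
    from tensorflow.python.training import server_lib

    cluster_spec = server_lib.ClusterSpec(
        {"worker": ["localhost:2222", "localhost:2223"]})

    strategy = mirrored_strategy.MirroredStrategy(num_gpus_per_worker=2)
    # `configure` switches the strategy into its in-graph multi-worker mode;
    # higher-level APIs normally call it for you.
    strategy.configure(cluster_spec=cluster_spec)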
@@ -318,9 +438,6 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
self._canonical_device_set = set(self._devices)
self._device_index = values.PerDevice(
{d: i for i, d in enumerate(devices)})
- self._cross_tower_ops = cross_tower_ops
- self._prefetch_on_device = prefetch_on_device
- # TODO(yuefengz): consider setting the default device.
def _create_variable(self, next_creator, *args, **kwargs):
"""Create a mirrored variable. See `DistributionStrategy.scope`."""
@@ -357,9 +474,14 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
**kwargs)
def distribute_dataset(self, dataset_fn):
- return values.PerDeviceDataset(
- self._call_dataset_fn(dataset_fn), self._devices,
- self._prefetch_on_device)
+ if self._cluster_spec:
+ return values.MultiWorkerDataset(
+ partial(self._call_dataset_fn, dataset_fn), self._worker_device_map,
+ self._prefetch_on_device)
+ else:
+ return values.PerDeviceDataset(
+ self._call_dataset_fn(dataset_fn), self._devices,
+ self._prefetch_on_device)
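
The multi-worker branch above hands `MultiWorkerDataset` a zero-argument callable via `functools.partial`, presumably so the dataset function can be invoked once per worker. A minimal standalone illustration of that binding (the list stands in for a real `tf.data.Dataset`):

    from functools import partial

    def call_dataset_fn(dataset_fn):
      return dataset_fn()

    def dataset_fn():
      return list(range(4))  # stand-in for a tf.data.Dataset

    make_dataset = partial(call_dataset_fn, dataset_fn)
    assert make_dataset() == [0, 1, 2, 3]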
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
def _run_steps_on_dataset(self, fn, iterator, iterations,
@@ -383,12 +505,21 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
with ops.control_dependencies([fn_result]):
return [i + 1] + flat_last_step_outputs
+ # We capture the control_flow_context at this point, before we run `fn`
+ # inside a while_loop. This is useful in cases where we might need to exit
+ # these contexts and get back to the outer context to do some things, e.g.
+ # to create an op which should be evaluated only once at the end of the
+ # loop on the host. One such usage is in creating a metric's value op.
+ self._outer_control_flow_context = (
+ ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access
+
cond = lambda i, *args: i < iterations
i = constant_op.constant(0)
loop_result = control_flow_ops.while_loop(
cond, body, [i] + initial_loop_values, name="",
parallel_iterations=1, back_prop=False, swap_memory=False,
return_same_structure=True)
+ del self._outer_control_flow_context
ctx.run_op = control_flow_ops.group(loop_result)
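
The loop above is an ordinary `tf.while_loop` with back-prop disabled and serialized iterations; a minimal standalone example of the same shape, assuming TensorFlow 1.x graph mode:

    import tensorflow as tf

    iterations = 10
    cond = lambda i, acc: i < iterations
    body = lambda i, acc: (i + 1, acc + 2.0)

    _, total = tf.while_loop(
        cond, body, [tf.constant(0), tf.constant(0.0)],
        parallel_iterations=1, back_prop=False, swap_memory=False)

    with tf.Session() as sess:
      print(sess.run(total))  # 20.0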
@@ -435,10 +566,33 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
# in addition to PerDevice data.
return values.PerDevice({k: values.MapOutput(v) for k, v in index.items()})
- def configure(self, session_config=None):
+ def configure(self,
+ session_config=None,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None):
+ del task_type, task_id
+
+ if session_config:
+ session_config.isolate_session_state = True
+
+ if cluster_spec:
+ self._initialize_multi_worker(self._num_gpus, cluster_spec)
+
if self._cross_tower_ops is None:
- self._cross_tower_ops = cross_tower_ops_lib.choose_the_best(
- self._devices, session_config=session_config)
+ if self._cluster_spec:
+ # It currently cannot detect the topology of remote workers, so we
+ # hard-code the multi-worker all-reduce algorithm for now.
+ if len(self._workers) == 1:
+ # The default is "nccl".
+ self._cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps()
+ else:
+ # The default is hierarchical reduce and broadcast.
+ self._cross_tower_ops = cross_tower_ops_lib.MultiWorkerAllReduce(
+ self._workers, self._num_gpus)
+ else:
+ self._cross_tower_ops = cross_tower_ops_lib.choose_the_best(
+ self._devices, session_config=session_config)
def _get_cross_tower_ops(self):
if self._cross_tower_ops is None:
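
A compact restatement of the selection logic above; the class names are the real ones from `cross_tower_ops_lib`, but the helper itself is a hypothetical sketch:

    def pick_cross_tower_ops(cluster_spec, workers, explicit=None):
      # Mirrors the branch structure of `configure` above.
      if explicit is not None:
        return explicit
      if cluster_spec:
        if len(workers) == 1:
          return "AllReduceCrossTowerOps"   # single worker: defaults to nccl
        return "MultiWorkerAllReduce"       # hierarchical reduce and broadcast
      return "choose_the_best"              # probe local devices

    assert pick_cross_tower_ops(None, []) == "choose_the_best"
    assert pick_cross_tower_ops({"worker": ["w0"]}, ["w0"]) == (
        "AllReduceCrossTowerOps")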
@@ -454,10 +608,18 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
# which case `value` would be a single value or value could be 0.
return _reduce_non_distributed_value(self, aggregation, value,
destinations)
+ if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_TOWER:
+ value = value.get(self._devices[0])
+ if isinstance(value, (int, float)):
+ return value
+ return self.broadcast(value, destinations)
return self._get_cross_tower_ops().reduce(
aggregation, value, destinations=destinations)
def _batch_reduce(self, aggregation, value_destination_pairs):
+ if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_TOWER:
+ return [self.broadcast(v.get(self._devices[0]), d)
+ for v, d in value_destination_pairs]
return self._get_cross_tower_ops().batch_reduce(aggregation,
value_destination_pairs)
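
In both `_reduce` and `_batch_reduce` above, ONLY_FIRST_TOWER simply discards every copy except the one on the first device before broadcasting; a pure-Python sketch:

    devices = ["/device:GPU:0", "/device:GPU:1"]
    per_device = {"/device:GPU:0": 3.0, "/device:GPU:1": 8.0}

    def reduce_only_first_tower(per_device_values, devices):
      # No arithmetic: the first device's value wins and is then broadcast.
      return per_device_values[devices[0]]

    assert reduce_only_first_tower(per_device, devices) == 3.0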
@@ -523,6 +685,22 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def parameter_devices(self):
return list(self._devices)
+ @property
+ def between_graph(self):
+ return False
+
+ @property
+ def should_init(self):
+ return True
+
+ @property
+ def should_checkpoint(self):
+ return True
+
+ @property
+ def should_save_summary(self):
+ return True
+
def non_slot_devices(self, var_list):
del var_list
return list(self._devices)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 9a4cc0a897..c6894e9013 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import sys
from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import strategy_test_lib
from tensorflow.contrib.distribute.python import values
from tensorflow.core.protobuf import config_pb2
@@ -37,10 +38,12 @@ from tensorflow.python.layers import core
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import server_lib
GPU_TEST = "test_gpu" in sys.argv[0]
@@ -126,6 +129,25 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
expected = sum(range(len(dist.worker_devices)))
self.assertEqual(expected, self.evaluate(unwrapped[0]))
+ @test_util.run_in_graph_and_eager_modes
+ def testReduceOnlyFirstTowerUpdates(self):
+ if not GPU_TEST:
+ self.skipTest("Not GPU test")
+
+ def run_fn(device_id):
+ return constant_op.constant(3 + 5 * device_id)
+
+ dist = self._get_distribution_strategy()
+ with dist.scope():
+ result = dist.call_for_each_tower(run_fn, dist.worker_device_index)
+ reduced = dist.reduce(
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER,
+ result,
+ destinations="/device:CPU:0")
+ unwrapped = dist.unwrap(reduced)
+ self.assertEqual(1, len(unwrapped))
+ self.assertEqual(3, self.evaluate(unwrapped[0]))
+
@test_util.run_in_graph_and_eager_modes()
def testReduceToMultipleDestinations(self):
if not GPU_TEST:
@@ -382,6 +404,84 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
v3.aggregation)
@test_util.run_in_graph_and_eager_modes(config=config)
+ def testOnlyFirstTowerUpdatesVariables(self):
+ self._skip_eager_if_gpus_less_than(1)
+
+ def create_fn():
+ aggregation = variable_scope.VariableAggregation.ONLY_FIRST_TOWER
+ v0 = variable_scope.variable(
+ 2.0,
+ name="on_read",
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=aggregation)
+ v1 = variable_scope.variable(
+ 3.0,
+ name="on_write",
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation=aggregation)
+ return v0, v1
+
+ devices = ["/device:GPU:0", "/device:CPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ v0, v1 = dist.call_for_each_tower(create_fn, run_concurrently=False)
+ self.evaluate(v0.initializer)
+ self.assertEqual(2.0, self.evaluate(v0.get(devices[0])))
+ self.assertEqual(2.0, self.evaluate(v0.get(devices[1])))
+ self.assertEqual(2.0, self.evaluate(dist.read_var(v0)))
+ self.evaluate(v1.initializer)
+ self.assertEqual(3.0, self.evaluate(v1.get(devices[0])))
+ self.assertEqual(3.0, self.evaluate(v1.get(devices[1])))
+ self.assertEqual(3.0, self.evaluate(dist.read_var(v1)))
+
+ # Update using the assign_add member function.
+ def update_member_fn(device_id):
+ update0 = v0.assign_add(5.0 * (device_id + 1))
+ update1 = v1.assign_add(7.0 * (device_id + 1))
+ return update0, update1
+
+ update0a, update1a = dist.call_for_each_tower(
+ update_member_fn, dist.worker_device_index, run_concurrently=False)
+
+ # Update "sync on read" variable.
+ self.evaluate(dist.group(update0a))
+ self.assertEqual(2.0 + 5.0, self.evaluate(v0.get(devices[0])))
+ # Writes are not synchronized for "sync on read" variables,
+ # so device[1] can end up with a different value.
+ self.assertEqual(2.0 + 2*5.0, self.evaluate(v0.get(devices[1])))
+ # Always reads from device 0.
+ self.assertEqual(2.0 + 5.0, self.evaluate(dist.read_var(v0)))
+
+ # Update "sync on write" variable.
+ self.evaluate(dist.group(update1a))
+ self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[0])))
+ # Writes are synchronized for v1, only the argument to assign_add on
+ # device[0] is used.
+ self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[1])))
+ self.assertEqual(3.0 + 7.0, self.evaluate(dist.read_var(v1)))
+
+ # Update using state_ops.assign_add global function.
+ def update_state_ops_fn(device_id):
+ update0 = state_ops.assign_add(v0, 11.0 * (device_id + 1))
+ update1 = state_ops.assign_add(v1, 13.0 * (device_id + 1))
+ return update0, update1
+
+ update0b, update1b = dist.call_for_each_tower(
+ update_state_ops_fn, dist.worker_device_index, run_concurrently=False)
+ self.evaluate(dist.group(update0b))
+
+ # Update "sync on read" variable.
+ self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.get(devices[0])))
+ self.assertEqual(2.0 + 2*5.0 + 2*11.0, self.evaluate(v0.get(devices[1])))
+ self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(dist.read_var(v0)))
+
+ # Update "sync on write" variable.
+ self.evaluate(dist.group(update1b))
+ self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[0])))
+ self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[1])))
+ self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(dist.read_var(v1)))
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
def testNoneSynchronizationWithGetVariable(self):
self._skip_eager_if_gpus_less_than(1)
devices = ["/device:CPU:0", "/device:GPU:0"]
@@ -802,8 +902,8 @@ class MirroredVariableUpdateTest(test.TestCase):
return mirrored_var.assign(5.0)
with self.assertRaisesRegexp(
- ValueError, "A non-DistributedValues value cannot be reduced with "
- "the given aggregation."):
+ ValueError, "A non-DistributedValues value 5.0 cannot be reduced "
+ "with the given aggregation VariableAggregation.SUM."):
self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn)))
@test_util.run_in_graph_and_eager_modes(config=config)
@@ -886,8 +986,18 @@ class MirroredVariableUpdateTest(test.TestCase):
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEquals(1.0, self.evaluate(mirrored_var))
- mirrored_var_result = self.evaluate(mirrored_var.assign_add(6.0))
+
+ # read_value == True
+ mirrored_var_result = self.evaluate(
+ mirrored_var.assign_add(6.0, read_value=True))
self.assertEquals(7.0, mirrored_var_result)
+ self.assertEquals(7.0, self.evaluate(mirrored_var.get("/device:CPU:0")))
+ self.assertEquals(7.0, self.evaluate(mirrored_var.get("/device:GPU:0")))
+
+ # read_value == False
+ self.evaluate(mirrored_var.assign_add(2.0, read_value=False))
+ self.assertEquals(9.0, self.evaluate(mirrored_var.get("/device:CPU:0")))
+ self.assertEquals(9.0, self.evaluate(mirrored_var.get("/device:GPU:0")))
@test_util.run_in_graph_and_eager_modes(config=config)
def testAssignAddMirroredVarTowerContext(self):
@@ -954,6 +1064,8 @@ class MirroredVariableUpdateTest(test.TestCase):
self.assertEquals(5.0, self.evaluate(mirrored_var))
mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0))
self.assertEquals(3.0, mirrored_var_result)
+ self.assertEquals(3.0, self.evaluate(mirrored_var.get("/device:GPU:0")))
+ self.assertEquals(3.0, self.evaluate(mirrored_var.get("/device:CPU:0")))
@test_util.run_in_graph_and_eager_modes(config=config)
def testAssignSubMirroredVarTowerContext(self):
@@ -1244,5 +1356,40 @@ class MirroredStrategyDefunTest(test.TestCase):
self._call_and_check(fn1, [factors], expected_result, [fn1])
+class MultiWorkerMirroredStrategyTest(
+ multi_worker_test_base.MultiWorkerTestBase,
+ strategy_test_lib.DistributionTestBase):
+
+ def _get_distribution_strategy(self):
+ cluster_spec = server_lib.ClusterSpec({
+ "worker": ["/job:worker/task:0", "/job:worker/task:1"]
+ })
+ strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus())
+ strategy.configure(cluster_spec=cluster_spec)
+ return strategy
+
+ def testMinimizeLossGraph(self):
+ self._test_minimize_loss_graph(self._get_distribution_strategy(),
+ learning_rate=0.05)
+
+
+class MultiWorkerMirroredStrategyTestWithChief(
+ multi_worker_test_base.MultiWorkerTestBase,
+ strategy_test_lib.DistributionTestBase):
+
+ @classmethod
+ def setUpClass(cls):
+ """Create a local cluster with 2 workers and 1 chief."""
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
+ num_workers=2, num_ps=0, has_chief=True)
+ cls._default_target = "grpc://" + cls._cluster_spec["chief"][0]
+
+ def testMinimizeLossGraph(self):
+ strategy = mirrored_strategy.MirroredStrategy(
+ num_gpus_per_worker=context.num_gpus())
+ strategy.configure(cluster_spec=self._cluster_spec)
+ self._test_minimize_loss_graph(strategy, learning_rate=0.05)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
index 5db2fff239..969e126956 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
@@ -22,6 +22,8 @@ from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import strategy_test_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import distribution_strategy_context
@@ -60,6 +62,7 @@ class VariableCreatorStackTest(test.TestCase):
def model_fn(device_id):
assert isinstance(device_id, int)
+
def thread_creator_fn(next_creator, *args, **kwargs):
return next_creator(*args, **kwargs) + ":thread_" + str(device_id)
@@ -86,5 +89,21 @@ class VariableCreatorStackTest(test.TestCase):
self.assertEquals(expected, result)
+class MultiWorkerMirroredStrategyTest(test.TestCase):
+
+ def testDeviceScope(self):
+ """Test the device scope of multi-worker MirroredStrategy."""
+ with context.graph_mode():
+ strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus())
+ strategy.configure(
+ cluster_spec={"worker": ["/job:worker/task:0", "/job:worker/task:1"]})
+ with strategy.scope():
+ a = constant_op.constant(1.)
+ with ops.device("/cpu:0"):
+ b = constant_op.constant(1.)
+ self.assertEqual(a.device, "/job:worker/task:0")
+ self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0")
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distribute/python/monitor_test.py b/tensorflow/contrib/distribute/python/monitor_test.py
index 2892ce4394..16be839e1d 100644
--- a/tensorflow/contrib/distribute/python/monitor_test.py
+++ b/tensorflow/contrib/distribute/python/monitor_test.py
@@ -45,7 +45,7 @@ class MonitorTest(test.TestCase, parameterized.TestCase):
if context.executing_eagerly():
monitor = monitor_lib.Monitor(single_loss_step, None)
else:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
monitor = monitor_lib.Monitor(single_loss_step, sess)
monitor.run_steps(1)
diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy.py b/tensorflow/contrib/distribute/python/multi_worker_strategy.py
deleted file mode 100644
index cbfe5df61d..0000000000
--- a/tensorflow/contrib/distribute/python/multi_worker_strategy.py
+++ /dev/null
@@ -1,141 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Classes implementing a mirrored DistributionStrategy for multiple workers."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from functools import partial
-
-from tensorflow.contrib.distribute.python import values
-from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy
-from tensorflow.core.protobuf import cluster_pb2
-from tensorflow.python.training import device_util
-from tensorflow.python.training import server_lib
-from tensorflow.python.util import nest
-
-
-# TODO(yuefengz): support between-graph replication.
-# TODO(yuefengz): merge this class into its base class.
-# TODO(yuefengz): in some cases, we probably want to use configure method to
-# configure this class.
-# TODO(yuefengz): MirroredStrategy.worker_devices may be confusing after the
-# class is introduced.
-class MultiWorkerMirroredStrategy(MirroredStrategy):
- """Mirrored strategy that works on multiple workers with in-graph replication.
-
- There are several important concepts for distributed TensorFlow, e.g.
- `client`, `job`, 'task', `cluster`, `in-graph replication` and
- 'synchronous training' and they have already been defined in the
- [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed).
- The distribution strategy inherits these concepts as well and in addition to
- that we also clarify several more concepts:
- * **In-graph replication**: the `client` creates a single `tf.Graph` that
- specifies tasks for devices on all workers. The `client` then creates a
- client session which will talk to the `master` service of a `worker`. Then
- the `master` will partition the graph and distribute the work to all
- participating workers.
- * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one
- physical machine. We will have multiple `worker`s with different `task`
- index. They all do similar things except for one worker checkpointing model
- variables, writing summaries, etc. in addition to its ordinary work.
-
- This class maps one tower to one device on a worker. It mirrors all model
- variables on all towers. For example, if you have two `worker`s and each
- `worker` has 4 GPUs, it will create 8 copies of the model variables on these 8
- GPUs. Then like in MirroredStrategy, each tower performs their computation
- with their own copy of variables unless in cross-tower model where variable or
- tensor reduction happens.
- """
-
- def __init__(self,
- num_gpus_per_worker=1,
- worker_job_name=None,
- num_workers=None,
- cluster=None,
- cross_tower_ops=None,
- prefetch_on_device=None):
- """Initialize the strategy object.
-
- Args:
- num_gpus_per_worker: number of GPUs per work. If it is zero, the local
- CPU will be used.
- worker_job_name: the job name for `worker`, typically just 'worker'.
- num_workers: the number of workers. If it is 0, it regenerates to
- single-worker MirroredStrategy.
- cluster: a `tf.train.ClusterSpec` object or a dict that can be used to
- construct a `tf.train.ClusterSpec` object or a `tf.train.ClusterDef`
- proto buffer. It is an alternative way to initialize this object.
- cross_tower_ops: the cross tower ops to use. If None, a default one will
- be used. If configure method is called, a best one for the configuration
- will be chosen.
- prefetch_on_device: a boolean to specify whether to prefetech input to
- each worker's devices.
-
- Raises:
- ValueError: if got an unexpected `cluster`.
- """
- if cluster is None:
- self._workers = [
- '/job:%s/task:%d' % (worker_job_name, task_index)
- for task_index in range(num_workers)
- ]
- else:
- if isinstance(cluster, (dict, cluster_pb2.ClusterDef)):
- cluster_spec = server_lib.ClusterSpec(cluster)
- elif isinstance(cluster, server_lib.ClusterSpec):
- cluster_spec = cluster
- else:
- raise ValueError(
- "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
- '`tf.train.ClusterDef` object')
-
- self._workers = []
- for job in sorted(cluster_spec.jobs):
- for task in range(cluster_spec.num_tasks(job)):
- self._workers.append('/job:%s/task:%d' % (job, task))
-
- self._num_gpus_per_worker = num_gpus_per_worker
- if num_gpus_per_worker > 0:
- self._worker_device_map = {
- worker: [
- device_util.canonicalize(worker + '/device:GPU:%d' % gpu)
- for gpu in range(num_gpus_per_worker)
- ] for worker in self._workers
- }
- else:
- self._worker_device_map = {
- worker: [device_util.canonicalize(worker, '/device:CPU:0')]
- for worker in self._workers
- }
- self._devices = nest.flatten(self._worker_device_map)
-
- super(MultiWorkerMirroredStrategy, self).__init__(
- devices=self._devices, prefetch_on_device=prefetch_on_device)
-
- # Setting `_default_device` will add a device scope in the
- # distribution.scope. We set the default device to the first worker. When
- # users specify device under distribution.scope by
- # with tf.device("/cpu:0"):
- # ...
- # their ops will end up on the cpu device of its first worker, e.g.
- # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode.
- self._default_device = self._workers[0]
-
- def distribute_dataset(self, dataset_fn):
- return values.MultiWorkerDataset(
- partial(self._call_dataset_fn, dataset_fn), self._worker_device_map,
- self._prefetch_on_device)
diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py b/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py
deleted file mode 100644
index 09c859b32a..0000000000
--- a/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for MultiWorkerMirroredStrategy."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.distribute.python import multi_worker_strategy
-from tensorflow.contrib.distribute.python import multi_worker_test_base
-from tensorflow.contrib.distribute.python import strategy_test_lib
-from tensorflow.python.eager import context
-from tensorflow.python.eager import test
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import ops
-from tensorflow.python.training import server_lib
-
-
-class MultiWorkerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
- strategy_test_lib.DistributionTestBase):
-
- def _get_distribution_strategy(self):
- return multi_worker_strategy.MultiWorkerMirroredStrategy(
- cluster=server_lib.ClusterSpec({
- 'worker': ['/job:worker/task:0', '/job:worker/task:1']
- }),
- num_gpus_per_worker=context.num_gpus())
-
- def testMinimizeLossGraph(self):
- self._test_minimize_loss_graph(self._get_distribution_strategy())
-
-
-class DeviceScopeTest(test.TestCase):
- """Test the device scope of MultiWorkerMirroredStrategy."""
-
- def testDeviceScope(self):
- with context.graph_mode():
- strategy = multi_worker_strategy.MultiWorkerMirroredStrategy(
- cluster={'worker': ['/job:worker/task:0', '/job:worker/task:1']},
- num_gpus_per_worker=context.num_gpus())
- with strategy.scope():
- a = constant_op.constant(1.)
- with ops.device('/cpu:0'):
- b = constant_op.constant(1.)
- self.assertEqual(a.device, '/job:worker/task:0')
- self.assertEqual(b.device, '/job:worker/task:0/device:CPU:0')
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
index 249de01f08..18b4503eff 100644
--- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py
+++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
@@ -23,26 +23,105 @@ import copy
import threading
import numpy as np
+_portpicker_import_error = None
+try:
+ import portpicker # pylint: disable=g-import-not-at-top
+except ImportError as _error: # pylint: disable=invalid-name
+ _portpicker_import_error = _error
+ portpicker = None
+
+# pylint: disable=g-import-not-at-top
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.estimator import run_config
from tensorflow.python.platform import test
-from tensorflow.python.framework import test_util
-
-
-def create_in_process_cluster(num_workers, num_ps):
+from tensorflow.python.training import server_lib
+
+
+def _create_cluster(num_workers,
+ num_ps,
+ has_chief=False,
+ has_eval=False,
+ protocol='grpc',
+ worker_config=None,
+ ps_config=None):
+ """Creates and starts local servers and returns the cluster_spec dict."""
+ if _portpicker_import_error:
+ raise _portpicker_import_error # pylint: disable=raising-bad-type
+ worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
+ ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+
+ cluster_dict = {}
+ if num_workers > 0:
+ cluster_dict['worker'] = ['localhost:%s' % port for port in worker_ports]
+ if num_ps > 0:
+ cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports]
+ if has_eval:
+ cluster_dict['evaluator'] = ['localhost:%s' % portpicker.pick_unused_port()]
+ if has_chief:
+ cluster_dict['chief'] = ['localhost:%s' % portpicker.pick_unused_port()]
+
+ cs = server_lib.ClusterSpec(cluster_dict)
+
+ for i in range(num_workers):
+ server_lib.Server(
+ cs,
+ job_name='worker',
+ protocol=protocol,
+ task_index=i,
+ config=worker_config,
+ start=True)
+
+ for i in range(num_ps):
+ server_lib.Server(
+ cs,
+ job_name='ps',
+ protocol=protocol,
+ task_index=i,
+ config=ps_config,
+ start=True)
+
+ if has_chief:
+ server_lib.Server(
+ cs,
+ job_name='chief',
+ protocol=protocol,
+ task_index=0,
+ config=worker_config,
+ start=True)
+
+ if has_eval:
+ server_lib.Server(
+ cs,
+ job_name='evaluator',
+ protocol=protocol,
+ task_index=0,
+ config=worker_config,
+ start=True)
+
+ return cluster_dict
+
+
+def create_in_process_cluster(num_workers,
+ num_ps,
+ has_chief=False,
+ has_eval=False):
"""Create an in-process cluster that consists of only standard server."""
# Leave some memory for cuda runtime.
- gpu_mem_frac = 0.7 / num_workers
+ gpu_mem_frac = 0.7 / (num_workers + int(has_chief) + int(has_eval))
worker_config = config_pb2.ConfigProto()
worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
# Enable collective ops which has no impact on non-collective ops.
# TODO(yuefengz, tucker): removing this after we move the initialization of
# collective mgr to the session level.
- worker_config.experimental.collective_group_leader = (
- '/job:worker/replica:0/task:0')
+ if has_chief:
+ worker_config.experimental.collective_group_leader = (
+ '/job:chief/replica:0/task:0')
+ else:
+ worker_config.experimental.collective_group_leader = (
+ '/job:worker/replica:0/task:0')
ps_config = config_pb2.ConfigProto()
ps_config.device_count['GPU'] = 0
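
Typical use of the rewritten helper, mirroring the chief-ful test earlier in this diff:

    from tensorflow.contrib.distribute.python import multi_worker_test_base

    # Starts 2 in-process workers plus a chief; returns the cluster_spec dict.
    cluster_spec = multi_worker_test_base.create_in_process_cluster(
        num_workers=2, num_ps=0, has_chief=True)
    default_target = "grpc://" + cluster_spec["chief"][0]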
@@ -56,9 +135,10 @@ def create_in_process_cluster(num_workers, num_ps):
# 2) there is something global in CUDA such that if we initialize CUDA in the
# parent process, the child process cannot initialize it again and thus cannot
# use GPUs (https://stackoverflow.com/questions/22950047).
- return test_util.create_local_cluster(
+ return _create_cluster(
num_workers,
num_ps=num_ps,
+ has_chief=has_chief,
worker_config=worker_config,
ps_config=ps_config,
protocol='grpc')
@@ -70,7 +150,8 @@ class MultiWorkerTestBase(test.TestCase):
@classmethod
def setUpClass(cls):
"""Create a local cluster with 2 workers."""
- cls._workers, cls._ps = create_in_process_cluster(num_workers=2, num_ps=0)
+ cls._cluster_spec = create_in_process_cluster(num_workers=2, num_ps=0)
+ cls._default_target = 'grpc://' + cls._cluster_spec['worker'][0]
def setUp(self):
# We only cache the session in one test because another test may have a
@@ -111,17 +192,17 @@ class MultiWorkerTestBase(test.TestCase):
config.graph_options.rewrite_options.constant_folding = (
rewriter_config_pb2.RewriterConfig.OFF)
+ if target is None:
+ target = self._default_target
if graph is None:
if getattr(self._thread_local, 'cached_session', None) is None:
self._thread_local.cached_session = session.Session(
- graph=None, config=config, target=target or self._workers[0].target)
+ graph=None, config=config, target=target)
sess = self._thread_local.cached_session
with sess.graph.as_default(), sess.as_default():
yield sess
else:
- with session.Session(
- graph=graph, config=config, target=target or
- self._workers[0].target) as sess:
+ with session.Session(graph=graph, config=config, target=target) as sess:
yield sess
def _run_client(self, client_fn, task_type, task_id, num_gpus, *args,
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 86833ad851..23b220f64b 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -67,6 +67,7 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
self._prefetch_on_device)
def _broadcast(self, tensor, destinations):
+ del destinations
return tensor
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
@@ -88,13 +89,22 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
with ops.control_dependencies([fn_result]):
return [i + 1] + flat_last_step_outputs
+ # We capture the control_flow_context at this point, before we run `fn`
+ # inside a while_loop. This is useful in cases where we might need to exit
+ # these contexts and get back to the outer context to do some things, e.g.
+ # to create an op which should be evaluated only once at the end of the
+ # loop on the host. One such usage is in creating a metric's value op.
+ self._outer_control_flow_context = (
+ ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access
+
+ # TODO(priyag): Use max_iterations instead of an explicit counter.
cond = lambda i, *args: i < iterations
i = constant_op.constant(0)
- # TODO(priyag): Use max_iterations instead of an explicit counter.
loop_result = control_flow_ops.while_loop(
cond, body, [i] + initial_loop_values, name="",
parallel_iterations=1, back_prop=False, swap_memory=False,
return_same_structure=True)
+ del self._outer_control_flow_context
ctx.run_op = control_flow_ops.group(loop_result)
@@ -118,6 +128,7 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
return values.MapOutput([fn(m, *args, **kwargs) for m in map_over])
def _reduce(self, aggregation, value, destinations):
+ del destinations
if not isinstance(value, values.MapOutput):
return value
l = value.get()
diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
index a2d736e422..6e9ba37a19 100644
--- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py
+++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
@@ -51,7 +51,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
model_fn, iterator.get_next(), run_concurrently=layer.built)))
if not context.executing_eagerly():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(run_step())
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
index 407c78df95..1125d027f6 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
@@ -18,38 +18,26 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import json
-import os
-
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import values
-from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.distribute import multi_worker_util
+from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import device_setter
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
-from tensorflow.python.training import server_lib
from tensorflow.python.util import nest
_LOCAL_CPU = "/device:CPU:0"
_LOCAL_GPU_0 = "/device:GPU:0"
-def _normalize_cluster_spec(cluster_spec):
- """Makes `cluster_spec` into a `ClusterSpec` object."""
- if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
- return server_lib.ClusterSpec(cluster_spec)
- elif not isinstance(cluster_spec, server_lib.ClusterSpec):
- raise ValueError(
- "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
- "`tf.train.ClusterDef` object")
- return cluster_spec
-
-
# TODO(yuefengz): maybe cache variables on local CPU.
# TODO(yuefengz): we may want to set session options to disallow communication
# between workers.
@@ -70,7 +58,11 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
assigned to.
This class assumes between-graph replication will be used and works on a graph
- for a particular worker.
+ for a particular worker. Note that each graph and worker is independent.
+ This means that while each worker will synchronously compute a single gradient
+ update across all GPUs, updates between workers proceed asynchronously.
+ Operations that occur only on the first tower (such as incrementing the
+ global step) will occur on the first tower *of every worker*.
It is expected to call `call_for_each_tower(fn, *args, **kwargs)` for any
operations which potentially can be replicated across towers (i.e. multiple
@@ -88,40 +80,32 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
3) It is also not recommended to open a colocation scope (i.e. calling
`tf.colocate_with`) under the strategy's scope. For colocating variables,
use `distribution.colocate_vars_with` instead. Colocation of ops will possibly
- create conflicts of device assignement.
+ create conflicts of device assignment.
"""
- def __init__(self,
- num_gpus_per_worker=0,
- cluster_spec=None,
- task_type=None,
- task_id=None):
- """Initiailizes this strategy.
+ def __init__(self, num_gpus_per_worker=0):
+ """Initializes this strategy.
Args:
- num_gpus_per_worker: number of local GPUs or GPUs per worker.
- cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
- cluster configurations.
- task_type: the current task type.
- task_id: the current task id.
+ num_gpus_per_worker: number of local GPUs or GPUs per worker; the default
+ is 0, meaning CPU only.
+
+ Raises:
+ ValueError: if `cluster_spec` is given but `task_type` or `task_id` is
+ not.
"""
super(ParameterServerStrategy, self).__init__()
self._num_gpus_per_worker = num_gpus_per_worker
- if cluster_spec:
- cluster_spec = _normalize_cluster_spec(cluster_spec)
- self._cluster_spec = cluster_spec
+ self._initialize_local(num_gpus_per_worker)
# We typically don't need to do all-reduce in this strategy.
self._cross_tower_ops = (
cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps(
reduce_to_device=_LOCAL_CPU))
- self._initialize_devices(num_gpus_per_worker, cluster_spec, task_type,
- task_id)
-
- def _initialize_devices(self, num_gpus_per_worker, cluster_spec, task_type,
- task_id):
- """Initialize internal devices.
+ def _initialize_multi_worker(self, num_gpus_per_worker, cluster_spec,
+ task_type, task_id):
+ """Initialize devices for multiple workers.
It creates variable devices and compute devices. Variables and operations
will be assigned to them respectively. We have one compute device per tower.
@@ -139,82 +123,103 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
Raises:
ValueError: if the cluster_spec doesn't have ps jobs.
"""
- self._task_type = task_type or "worker"
- self._task_id = task_id or 0
- self._worker_device = "/job:%s/task:%d" % (self._task_type, self._task_id)
+ assert cluster_spec
+ if not task_type or task_id is None:
+ raise ValueError("When `cluster_spec` is given, you must also specify "
+ "`task_type` and `task_id`")
+ cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
- # TODO(yuefengz): maybe clearer to split it into two classes, one for
- # the distribuetd case and one for the local case, once we have the factory
- # class/method.
+ self._worker_device = "/job:%s/task:%d" % (self._task_type, self._task_id)
# Define compute devices which is a list of device strings and one for each
# tower. When there are GPUs, replicate operations on these GPUs. Otherwise,
# place operations on CPU.
- if cluster_spec is None:
- # Local mode.
- if num_gpus_per_worker > 0:
- self._compute_devices = list(
- map("/device:GPU:{}".format, range(num_gpus_per_worker)))
- else:
- self._compute_devices = [_LOCAL_CPU]
+ if num_gpus_per_worker > 0:
+ self._compute_devices = [
+ "%s/device:GPU:%d" % (self._worker_device, i)
+ for i in range(num_gpus_per_worker)
+ ]
else:
- # Distributed mode.
- if num_gpus_per_worker > 0:
- self._compute_devices = [
- "%s/device:GPU:%d" % (self._worker_device, i)
- for i in range(num_gpus_per_worker)
- ]
- else:
- self._compute_devices = [self._worker_device]
+ self._compute_devices = [self._worker_device]
self._compute_devices = list(
map(device_util.resolve, self._compute_devices))
self._canonical_compute_device_set = set(self._compute_devices)
- # Define variable device which is a device string in the local case and a
- # device function in the distributed case. It is used to open a device scope
- # where varibles are defined.
+ # In distributed mode, place variables on ps jobs in a round-robin fashion.
+ # Note that devices returned from `replica_device_setter` are not
+ # canonical and therefore we don't canonicalize all variable devices to
+ # make them consistent.
+ # TODO(yuefengz): support passing a strategy object to control variable
+ # assignment.
+ # TODO(yuefengz): merge the logic of replica_device_setter into this
+ # class.
+ num_ps_replicas = len(cluster_spec.as_dict().get("ps", []))
+ if num_ps_replicas == 0:
+ raise ValueError("The cluster spec needs to have `ps` jobs.")
+ self._variable_device = device_setter.replica_device_setter(
+ ps_tasks=num_ps_replicas,
+ worker_device=self._worker_device,
+ merge_devices=True,
+ cluster=cluster_spec)
+
# The `_parameter_devices` is needed for the `parameter_devices` property
- # and is a list of all variable devices.
- if cluster_spec is None:
- # Local mode. If there is only one GPU, put everything on that GPU.
- # Otherwise, place variables on CPU.
- if num_gpus_per_worker == 1:
- assert len(list(self._compute_devices)) == 1
- self._variable_device = _LOCAL_GPU_0
- self._parameter_devices = [_LOCAL_GPU_0]
- else:
- self._variable_device = _LOCAL_CPU
- self._parameter_devices = [_LOCAL_CPU]
+ # and is a list of all variable devices. Here parameter devices are all
+ # tasks of the "ps" job.
+ self._parameter_devices = map("/job:ps/task:{}".format,
+ range(num_ps_replicas))
+
+ # Add a default device so that ops without specified devices will not end up
+ # on other workers.
+ self._default_device = self._worker_device
+
+ self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
+ task_id)
+ self._cluster_spec = cluster_spec
+ self._task_type = task_type
+ self._task_id = task_id
+
+ logging.info(
+ "Multi-worker ParameterServerStrategy with "
+ "cluster_spec = %r, task_type = %r, task_id = %r, "
+ "num_ps_replicas = %r, is_chief = %r, compute_devices = %r, "
+ "variable_device = %r", cluster_spec.as_dict(), task_type, task_id,
+ num_ps_replicas, self._is_chief, self._compute_devices,
+ self._variable_device)
+
+ def _initialize_local(self, num_gpus_per_worker):
+ """Initialize internal devices for local training."""
+ # Define compute devices which is a list of device strings and one for each
+ # tower. When there are GPUs, replicate operations on these GPUs. Otherwise,
+ # place operations on CPU.
+ if num_gpus_per_worker > 0:
+ self._compute_devices = list(
+ map("/device:GPU:{}".format, range(num_gpus_per_worker)))
else:
- # Distributed mode. Place variables on ps jobs in a round-robin fashion.
- # Note that devices returned from `replica_device_setter` are not
- # canonical and therefore we don't canonicalize all variable devices to
- # make them consistent.
- # TODO(yuefengz): support passing a strategy object to control variable
- # assignment.
- # TODO(yuefengz): merge the logic of replica_device_setter into this
- # class.
- num_ps_replicas = len(cluster_spec.as_dict().get("ps", []))
- if num_ps_replicas == 0:
- raise ValueError("The cluster spec needs to have `ps` jobs.")
- self._variable_device = device_setter.replica_device_setter(
- ps_tasks=num_ps_replicas,
- worker_device=self._worker_device,
- merge_devices=True,
- cluster=cluster_spec)
-
- # Parameter devices are all tasks of the "ps" job.
- self._parameter_devices = map("/job:ps/task:{}".format,
- range(num_ps_replicas))
-
- # Define the default device in cross-tower mode. In the distributed case, we
- # set the default device to the corresponding worker to prevent these ops
- # from being placed on other workers.
- if cluster_spec is None:
- self._default_device = None
+ self._compute_devices = [_LOCAL_CPU]
+
+ self._compute_devices = list(
+ map(device_util.resolve, self._compute_devices))
+ self._canonical_compute_device_set = set(self._compute_devices)
+
+ # If there is only one GPU, put everything on that GPU. Otherwise, place
+ # variables on CPU.
+ if num_gpus_per_worker == 1:
+ assert len(list(self._compute_devices)) == 1
+ self._variable_device = _LOCAL_GPU_0
+ self._parameter_devices = [_LOCAL_GPU_0]
else:
- self._default_device = self._worker_device
+ self._variable_device = _LOCAL_CPU
+ self._parameter_devices = [_LOCAL_CPU]
+
+ self._is_chief = True
+ self._cluster_spec = None
+ self._task_type = None
+ self._task_id = None
+
+ logging.info(
+ "ParameterServerStrategy with compute_devices = %r, "
+ "variable_device = %r", self._compute_devices, self._variable_device)
def distribute_dataset(self, dataset_fn):
"""Distributes the dataset to each local GPU."""
@@ -229,14 +234,58 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
# TODO(yuefengz): not all ops in device_setter.STANDARD_PS_OPS will go through
# this creator, such as "MutableHashTable".
def _create_variable(self, next_creator, *args, **kwargs):
+ if self.num_towers > 1:
+ aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
+ if aggregation not in (
+ vs.VariableAggregation.NONE,
+ vs.VariableAggregation.SUM,
+ vs.VariableAggregation.MEAN,
+ vs.VariableAggregation.ONLY_FIRST_TOWER
+ ):
+ raise ValueError("Invalid variable aggregation mode: " + aggregation +
+ " for variable: " + kwargs["name"])
+
+ def var_creator(*args, **kwargs):
+ # Record what collections this variable should be added to.
+ collections = kwargs.pop("collections", None)
+ if collections is None:
+ collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+ kwargs["collections"] = []
+
+ # Create and wrap the variable.
+ v = next_creator(*args, **kwargs)
+ wrapped = values.AggregatingVariable(v, aggregation)
+
+ # Add the wrapped variable to the requested collections.
+ # The handling of eager mode and the global step matches
+ # ResourceVariable._init_from_args().
+ if not context.executing_eagerly():
+ g = ops.get_default_graph()
+ # If "trainable" is True, next_creator() will add the contained
+ # variable to the TRAINABLE_VARIABLES collection, so we manually
+ # remove it and replace with the wrapper. We can't set "trainable"
+ # to False for next_creator() since that causes functions like
+ # implicit_gradients to skip those variables.
+ if kwargs.get("trainable", True):
+ collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+ l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
+ l.remove(v)
+ g.add_to_collections(collections, wrapped)
+ elif ops.GraphKeys.GLOBAL_STEP in collections:
+ ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped)
+
+ return wrapped
+ else:
+ var_creator = next_creator
+
if "colocate_with" in kwargs:
with ops.device(None):
with ops.colocate_with(kwargs["colocate_with"]):
- return next_creator(*args, **kwargs)
+ return var_creator(*args, **kwargs)
with ops.colocate_with(None, ignore_existing=True):
with ops.device(self._variable_device):
- return next_creator(*args, **kwargs)
+ return var_creator(*args, **kwargs)
def _call_for_each_tower(self, fn, *args, **kwargs):
# pylint: disable=protected-access
@@ -258,11 +307,15 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
# pylint: disable=protected-access
return mirrored_strategy._reduce_non_distributed_value(
self, aggregation, value, destinations)
-
+ if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
+ return self.broadcast(value.get(self._compute_devices[0]), destinations)
return self._cross_tower_ops.reduce(
aggregation, value, destinations=destinations)
def _batch_reduce(self, aggregation, value_destination_pairs):
+ if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
+ return [self.broadcast(v.get(self._compute_devices[0]), d)
+ for v, d in value_destination_pairs]
for _, destinations in value_destination_pairs:
self._verify_destinations_not_different_worker(destinations)
return self._cross_tower_ops.batch_reduce(aggregation,
@@ -291,6 +344,8 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
return nest.map_structure(_select_fn, structured)
def _update(self, var, fn, *args, **kwargs):
+ if isinstance(var, values.AggregatingVariable):
+ var = var.get()
if not isinstance(var, resource_variable_ops.ResourceVariable):
raise ValueError(
"You can not update `var` %r. It must be a Variable." % var)
@@ -319,27 +374,56 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
# No need to distinguish between normal variables and tower-local variables.
return array_ops.identity(var)
- def configure(self, session_config=None):
- del session_config
+ def configure(self,
+ session_config=None,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None):
+ """Configures the strategy class.
- # Use TF_CONFIG to get the cluster spec and the current job.
- tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
- cluster_spec = _normalize_cluster_spec(tf_config.get("cluster", {}))
+ The strategy object will be re-initialized if `cluster_spec` is given but
+ was not passed in the constructor.
- task_env = tf_config.get("task", {})
- if task_env:
- task_type = task_env.get("type", "worker")
- task_id = int(task_env.get("index", "0"))
- else:
- task_type = "worker"
- task_id = None
+ Args:
+ session_config: not used currently.
+ cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
+ cluster configurations.
+ task_type: the current task type.
+ task_id: the current task id.
- # Set the devices if cluster_spec is defined in TF_CONFIG but not passed in
- # the constructor.
+ Raises:
+ ValueError: if `cluster_spec` is given but `task_type` or `task_id` is
+ not.
+ """
if not self._cluster_spec and cluster_spec:
- self._cluster_spec = cluster_spec
- self._initialize_devices(self._num_gpus_per_worker, cluster_spec,
- task_type, task_id)
+ # If a `cluster_spec` is already passed in, do nothing here.
+ # TODO(yuefengz): check `cluster_spec` is the same if this object has
+ # already been initialized with a `cluster_spec`.
+ if task_type is None or task_id is None:
+ raise ValueError("When `cluster_spec` is given, must also specify "
+ "`task_type` and `task_id`.")
+ self._cluster_spec = multi_worker_util.normalize_cluster_spec(
+ cluster_spec)
+ self._task_type = task_type
+ self._task_id = task_id
+ self._initialize_multi_worker(self._num_gpus_per_worker,
+ self._cluster_spec, task_type, task_id)
+
+ if not session_config or not self._cluster_spec:
+ return
+
+ session_config.isolate_session_state = False
+
+ assert self._cluster_spec
+ assert self._task_type
+ assert self._task_id is not None
+
+ # The device filters prevent communication between workers.
+ if self._task_type not in ["chief", "worker"]:
+ return
+ del session_config.device_filters[:]
+ session_config.device_filters.extend(
+ ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"])
@property
def num_towers(self):
@@ -356,3 +440,19 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
def non_slot_devices(self, var_list):
return min(var_list, key=lambda x: x.name)
+
+ @property
+ def between_graph(self):
+ return True
+
+ @property
+ def should_init(self):
+ return self._is_chief
+
+ @property
+ def should_checkpoint(self):
+ return self._is_chief
+
+ @property
+ def should_save_summary(self):
+ return self._is_chief
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
index 02eb68227d..12789e0bc9 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
@@ -18,13 +18,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import json
+import copy
import threading
from absl.testing import parameterized
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import parameter_server_strategy
+from tensorflow.contrib.distribute.python import values
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.eager import context
from tensorflow.python.estimator import run_config
from tensorflow.python.framework import constant_op
@@ -38,21 +41,15 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import training_util
+CHIEF = run_config.TaskType.CHIEF
+WORKER = run_config.TaskType.WORKER
+PS = run_config.TaskType.PS
-class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
- parameterized.TestCase):
- @classmethod
- def setUpClass(cls):
- cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster(
- num_workers=3, num_ps=2)
- cls._cluster_spec = {
- run_config.TaskType.WORKER: [
- 'fake_worker_0', 'fake_worker_1', 'fake_worker_2'
- ],
- run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1']
- }
+class ParameterServerStrategyTestBase(
+ multi_worker_test_base.MultiWorkerTestBase):
def setUp(self):
self._result = 0
@@ -61,34 +58,30 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
self._init_reached = 0
self._finish_condition = threading.Condition()
self._finish_reached = 0
- super(ParameterServerStrategyTest, self).setUp()
+ self._sess_config = config_pb2.ConfigProto(allow_soft_placement=True)
+ super(ParameterServerStrategyTestBase, self).setUp()
def _get_test_objects(self, task_type, task_id, num_gpus):
distribution = parameter_server_strategy.ParameterServerStrategy(
num_gpus_per_worker=num_gpus)
if not task_type:
- return distribution, ''
-
- tf_config = {
- 'cluster': self._cluster_spec,
- 'task': {
- 'type': task_type,
- 'index': task_id
- }
- }
- with self._lock:
- # Accessing environment variables should be protected by locks because
- # environment variables are shared by all threads.
- with test.mock.patch.dict('os.environ',
- {'TF_CONFIG': json.dumps(tf_config)}):
- distribution.configure()
- return distribution, self._workers[task_id].target
+ return distribution, '', self._sess_config
+
+ sess_config = copy.deepcopy(self._sess_config)
+ distribution.configure(
+ session_config=sess_config,
+ cluster_spec=self._cluster_spec,
+ task_type=task_type,
+ task_id=task_id)
+ return (distribution, 'grpc://' + self._cluster_spec[WORKER][task_id],
+ sess_config)
def _test_device_assignment_distributed(self, task_type, task_id, num_gpus):
worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id)
- d, _ = self._get_test_objects(task_type, task_id, num_gpus)
+ d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus)
with ops.Graph().as_default(), \
- self.test_session(target=self._workers[0].target) as sess, \
+ self.test_session(target=self._default_target,
+ config=sess_config) as sess, \
d.scope():
# Define a variable outside the call_for_each_tower scope. This is not
@@ -113,7 +106,9 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
# The device scope is ignored for variables but not for normal ops.
with ops.device('/job:worker/task:0'):
- x = variable_scope.get_variable('x', initializer=10.0)
+ x = variable_scope.get_variable(
+ 'x', initializer=10.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
x_add = x.assign_add(c)
e = a + c
# The variable x is on the task 1 since the device_function has been
@@ -125,18 +120,26 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
# The colocate_vars_with can override the distribution's device.
with d.colocate_vars_with(x):
- y = variable_scope.get_variable('y', initializer=20.0)
- y_add = y.assign_add(x_add)
+ y = variable_scope.get_variable(
+ 'y', initializer=20.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ # We add an identity here to avoid complaints about summing
+ # non-distributed values.
+ y_add = y.assign_add(array_ops.identity(x_add))
self.assertEqual(y.device, '/job:ps/task:1')
self.assertEqual(y_add.device, y.device)
self.assertEqual(y.device, x.device)
- z = variable_scope.get_variable('z', initializer=10.0)
+ z = variable_scope.get_variable(
+ 'z', initializer=10.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
self.assertEqual(z.device, '/job:ps/task:0')
self.assertNotEqual(z.device, x.device)
with ops.control_dependencies([y_add]):
- z_add = z.assign_add(y)
+ # We add an identity here to avoid complaints about summing
+ # non-distributed values.
+ z_add = z.assign_add(array_ops.identity(y))
with ops.control_dependencies([z_add]):
f = z + c
self.assertEqual(f.device, worker_device + '/' + last_part_device)
@@ -174,18 +177,14 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
self.assertEqual(z_val, 43.0)
self.assertEqual(f_val, 46.0)
- @combinations.generate(
- combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
- def testDeviceAssignmentDistributed(self, num_gpus):
- self._test_device_assignment_distributed('worker', 1, num_gpus)
-
def _test_device_assignment_local(self,
d,
compute_device='CPU',
variable_device='CPU',
num_gpus=0):
with ops.Graph().as_default(), \
- self.test_session(target=self._workers[0].target) as sess, \
+ self.test_session(target=self._default_target,
+ config=self._sess_config) as sess, \
d.scope():
def model_fn():
@@ -214,7 +213,9 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
# The device scope is ignored for variables but not for normal ops.
with ops.device('/device:GPU:2'):
- x = variable_scope.get_variable('x', initializer=10.0)
+ x = variable_scope.get_variable(
+ 'x', initializer=10.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
x_add = x.assign_add(c)
e = a + c
self.assertEqual(
@@ -224,19 +225,27 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
# The colocate_vars_with can override the distribution's device.
with d.colocate_vars_with(x):
- y = variable_scope.get_variable('y', initializer=20.0)
- y_add = y.assign_add(x_add)
+ y = variable_scope.get_variable(
+ 'y', initializer=20.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ # We add an identity here to avoid complaints about summing
+ # non-distributed values.
+ y_add = y.assign_add(array_ops.identity(x_add))
self.assertEqual(
device_util.canonicalize(y.device), tower_variable_device)
self.assertEqual(y_add.device, y.device)
self.assertEqual(y.device, x.device)
- z = variable_scope.get_variable('z', initializer=10.0)
+ z = variable_scope.get_variable(
+ 'z', initializer=10.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
self.assertEqual(
device_util.canonicalize(z.device), tower_variable_device)
with ops.control_dependencies([y_add]):
- z_add = z.assign_add(y)
+ # We add an identity here to avoid complaints about summing
+ # non-distributed values.
+ z_add = z.assign_add(array_ops.identity(y))
with ops.control_dependencies([z_add]):
f = z + c
self.assertEqual(f.device, tower_compute_device)
@@ -268,47 +277,43 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
self.assertEqual(z_val, 43.0)
self.assertEqual(f_val, 46.0)
- def testDeviceAssignmentLocalCPU(self):
- distribution = parameter_server_strategy.ParameterServerStrategy(
- num_gpus_per_worker=0)
- self._test_device_assignment_local(
- distribution, compute_device='CPU', variable_device='CPU', num_gpus=0)
-
- def testDeviceAssignmentLocalOneGPU(self):
- distribution = parameter_server_strategy.ParameterServerStrategy(
- num_gpus_per_worker=1)
- self._test_device_assignment_local(
- distribution, compute_device='GPU', variable_device='GPU', num_gpus=1)
-
- def testDeviceAssignmentLocalTwoGPUs(self):
- distribution = parameter_server_strategy.ParameterServerStrategy(
- num_gpus_per_worker=2)
- self._test_device_assignment_local(
- distribution, compute_device='GPU', variable_device='CPU', num_gpus=2)
-
def _test_simple_increment(self, task_type, task_id, num_gpus):
- d, master_target = self._get_test_objects(task_type, task_id, num_gpus)
+ d, master_target, sess_config = self._get_test_objects(
+ task_type, task_id, num_gpus)
if hasattr(d, '_cluster_spec') and d._cluster_spec:
- num_workers = len(d._cluster_spec.as_dict().get('worker',
- ['dummy_worker']))
+ num_workers = len(d._cluster_spec.as_dict().get(WORKER))
+ if 'chief' in d._cluster_spec.as_dict():
+ num_workers += 1
else:
num_workers = 1
with ops.Graph().as_default(), \
- self.test_session(target=master_target) as sess, \
+ self.test_session(target=master_target,
+ config=sess_config) as sess, \
d.scope():
def model_fn():
- x = variable_scope.get_variable('x', initializer=10.0)
- y = variable_scope.get_variable('y', initializer=20.0)
-
- x_add = x.assign_add(1.0, use_locking=True)
- y_add = y.assign_add(1.0, use_locking=True)
-
- train_op = control_flow_ops.group([x_add, y_add])
- return x, y, train_op
-
- x, y, train_op = d.call_for_each_tower(model_fn)
- train_op = d.group(d.unwrap(train_op))
+ x = variable_scope.get_variable(
+ 'x', initializer=10.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ y = variable_scope.get_variable(
+ 'y', initializer=20.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ z = variable_scope.get_variable(
+ 'z', initializer=30.0,
+ aggregation=variable_scope.VariableAggregation.ONLY_FIRST_TOWER)
+
+ # We explicitly make a constant tensor here to avoid complaints about
+ # summing non-distributed values.
+ one = constant_op.constant(1.0)
+ x_add = x.assign_add(one, use_locking=True)
+ y_add = y.assign_add(one, use_locking=True)
+ z_add = z.assign_add(one, use_locking=True)
+
+ train_op = control_flow_ops.group(x_add, y_add, z_add)
+ return x, y, z, train_op
+
+ x, y, z, train_op = d.call_for_each_tower(model_fn)
+ train_op = d.group(train_op)
if context.num_gpus() < d._num_gpus_per_worker:
return True
@@ -334,16 +339,25 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
self._finish_condition.notify_all()
self._finish_condition.release()
- x_val, y_val = sess.run([x, y])
+ x_val, y_val, z_val = sess.run([x, y, z])
self.assertEqual(x_val, 10.0 + 1.0 * num_workers * d.num_towers)
self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_towers)
+ self.assertEqual(z_val, 30.0 + 1.0 * num_workers)
return (x_val == 10.0 + 1.0 * num_workers * d.num_towers and
- y_val == 20.0 + 1.0 * num_workers * d.num_towers)
+ y_val == 20.0 + 1.0 * num_workers * d.num_towers and
+ z_val == 30.0 + 1.0 * num_workers)
def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):
- d, master_target = self._get_test_objects(task_type, task_id, num_gpus)
+ d, master_target, sess_config = self._get_test_objects(
+ task_type, task_id, num_gpus)
+ assert hasattr(d, '_cluster_spec') and d._cluster_spec
+ num_workers = len(d._cluster_spec.as_dict().get(WORKER))
+ if CHIEF in d._cluster_spec.as_dict():
+ num_workers += 1
+
with ops.Graph().as_default(), \
- self.test_session(target=master_target) as sess, \
+ self.test_session(target=master_target,
+ config=sess_config) as sess, \
d.scope():
l = core.Dense(1, use_bias=False)
@@ -390,13 +404,13 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
if context.num_gpus() < d._num_gpus_per_worker:
return True
- if task_id == 0:
+ if multi_worker_util.is_chief(d._cluster_spec, task_type, task_id):
variables.global_variables_initializer().run()
# Workers waiting for chief worker's initializing variables.
self._init_condition.acquire()
self._init_reached += 1
- while self._init_reached != 3:
+ while self._init_reached != num_workers:
self._init_condition.wait()
self._init_condition.notify_all()
self._init_condition.release()
@@ -413,9 +427,42 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
self.assertLess(error_after, error_before)
return error_after < error_before
+
+class ParameterServerStrategyTest(ParameterServerStrategyTestBase,
+ parameterized.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
+ num_workers=3, num_ps=2)
+ cls._default_target = 'grpc://' + cls._cluster_spec[WORKER][0]
+
+ def testDeviceAssignmentLocalCPU(self):
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=0)
+ self._test_device_assignment_local(
+ distribution, compute_device='CPU', variable_device='CPU', num_gpus=0)
+
+ def testDeviceAssignmentLocalOneGPU(self):
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=1)
+ self._test_device_assignment_local(
+ distribution, compute_device='GPU', variable_device='GPU', num_gpus=1)
+
+ def testDeviceAssignmentLocalTwoGPUs(self):
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=2)
+ self._test_device_assignment_local(
+ distribution, compute_device='GPU', variable_device='CPU', num_gpus=2)
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
+ def testDeviceAssignmentDistributed(self, num_gpus):
+ self._test_device_assignment_distributed('worker', 1, num_gpus)
+
def testSimpleBetweenGraph(self):
self._run_between_graph_clients(self._test_simple_increment,
- self._cluster_spec, 0)
+ self._cluster_spec, context.num_gpus())
@combinations.generate(
combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
@@ -429,5 +476,38 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
self._cluster_spec, num_gpus)
+class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
+ parameterized.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
+ num_workers=3, num_ps=2, has_chief=True)
+ cls._default_target = 'grpc://' + cls._cluster_spec[CHIEF][0]
+
+ def testSimpleBetweenGraph(self):
+ self._run_between_graph_clients(self._test_simple_increment,
+ self._cluster_spec, context.num_gpus())
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
+ def testMinimizeLossGraph(self, num_gpus):
+ self._run_between_graph_clients(self._test_minimize_loss_graph,
+ self._cluster_spec, num_gpus)
+
+ def testGlobalStepIsWrapped(self):
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=2)
+ with ops.Graph().as_default(), distribution.scope():
+ created_step = training_util.create_global_step()
+ get_step = training_util.get_global_step()
+ self.assertEqual(created_step, get_step,
+ msg=('created_step %s type %s vs. get_step %s type %s' %
+ (id(created_step), created_step.__class__.__name__,
+ id(get_step), get_step.__class__.__name__)))
+ self.assertIs(values.AggregatingVariable, type(created_step))
+ self.assertIs(values.AggregatingVariable, type(get_step))
+
+
if __name__ == '__main__':
test.main()
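
As a sanity check on the assertions in `_test_simple_increment` above, here is the expected-value arithmetic spelled out in plain Python (the worker and tower counts are illustrative): SUM-aggregated variables receive one increment per tower on every worker, while ONLY_FIRST_TOWER receives one increment per worker.

# Plain-Python arithmetic behind the assertions; counts are illustrative.
num_workers, num_towers = 3, 2

x_expected = 10.0 + 1.0 * num_workers * num_towers  # SUM aggregation: 16.0
y_expected = 20.0 + 1.0 * num_workers * num_towers  # SUM aggregation: 26.0
z_expected = 30.0 + 1.0 * num_workers               # ONLY_FIRST_TOWER: 33.0

assert (x_expected, y_expected, z_expected) == (16.0, 26.0, 33.0)
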
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
index a68dbce6c7..bb10b546a1 100644
--- a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
+++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
@@ -37,7 +37,7 @@ class PrefetchingOpsV2Test(test.TestCase):
iterator = device_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -55,7 +55,7 @@ class PrefetchingOpsV2Test(test.TestCase):
next_element = iterator.get_next()
output = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(5):
result = sess.run(next_element)
self.assertEqual(2, len(result))
@@ -75,7 +75,7 @@ class PrefetchingOpsV2Test(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for _ in range(5):
sess.run(next_element)
diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py
index 8605ab1f7d..f1ada49fa3 100644
--- a/tensorflow/contrib/distribute/python/step_fn_test.py
+++ b/tensorflow/contrib/distribute/python/step_fn_test.py
@@ -49,7 +49,7 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase):
if context.executing_eagerly():
run_step = single_loss_step
else:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(single_loss_step())
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py
index 371b97ba96..5d498fb629 100644
--- a/tensorflow/contrib/distribute/python/strategy_test_lib.py
+++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py
@@ -130,7 +130,8 @@ class DistributionTestBase(test.TestCase):
# Error should go down
self.assertLess(error_after, error_before)
- def _test_minimize_loss_graph(self, d, soft_placement=False):
+ def _test_minimize_loss_graph(self, d, soft_placement=False,
+ learning_rate=0.2):
config = config_pb2.ConfigProto()
config.allow_soft_placement = soft_placement
config.gpu_options.per_process_gpu_memory_fraction = 0.3
@@ -150,7 +151,7 @@ class DistributionTestBase(test.TestCase):
grad_fn = backprop.implicit_grad(loss)
def update(v, g):
- return v.assign_sub(0.2 * g)
+ return v.assign_sub(learning_rate * g)
one = d.broadcast(constant_op.constant([[1.]]))
@@ -189,7 +190,8 @@ class DistributionTestBase(test.TestCase):
with d.scope():
map_in = [constant_op.constant(i) for i in range(10)]
map_out = d.map(map_in, lambda x, y: x * y, 2)
- observed = d.reduce(variable_scope.VariableAggregation.SUM, map_out)
+ observed = d.reduce(variable_scope.VariableAggregation.SUM, map_out,
+ "/device:CPU:0")
expected = 90 # 2 * (0 + 1 + ... + 9)
self.assertEqual(expected, observed.numpy())
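
The expected value of 90 in the map/reduce test above follows from multiplying each of 0..9 by 2 and summing; a plain-Python restatement:

# Plain-Python restatement of the d.map/d.reduce expectation above.
map_in = list(range(10))
map_out = [x * 2 for x in map_in]  # d.map(map_in, lambda x, y: x * y, 2)
expected = sum(map_out)            # SUM aggregation onto /device:CPU:0
assert expected == 90              # 2 * (0 + 1 + ... + 9)
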
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 3f8a0922de..6ba83976fc 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -37,7 +37,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training import device_util
-from tensorflow.python.training import server_lib
from tensorflow.python.util import nest
@@ -46,13 +45,13 @@ def get_tpu_system_metadata(tpu_cluster_resolver):
master = tpu_cluster_resolver.master()
# pylint: disable=protected-access
- cluster_def = (tpu_cluster_resolver.cluster_spec()
- or server_lib.ClusterSpec({})).as_cluster_def()
+ cluster_spec = tpu_cluster_resolver.cluster_spec()
+ cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
tpu_system_metadata = (
tpu_system_metadata_lib._query_tpu_system_metadata(
master,
cluster_def=cluster_def,
- query_topology=True))
+ query_topology=False))
return tpu_system_metadata
@@ -60,7 +59,7 @@ def get_tpu_system_metadata(tpu_cluster_resolver):
class TPUStrategy(one_device_strategy.OneDeviceStrategy):
"""Experimental TPU distribution strategy implementation."""
- def __init__(self, tpu_cluster_resolver, steps_per_run):
+ def __init__(self, tpu_cluster_resolver, steps_per_run, num_cores=None):
"""Initializes the TPUStrategy object.
Args:
@@ -71,68 +70,101 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
metrics, summaries etc.
This parameter is only used when Distribution Strategy is used with
estimator or keras.
+      num_cores: Number of cores to use on the TPU. If None, the cores and
+        topology of the TPU system are auto-detected.
"""
- # TODO(isaprykin): Generalize the defaults. They are currently tailored for
- # the unit test.
+ # TODO(sourabhbajaj): OneDeviceStrategy should be initialized with the
+ # master node fetched from the cluster resolver.
super(TPUStrategy, self).__init__('/device:CPU:0')
self._tpu_cluster_resolver = tpu_cluster_resolver
self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver)
+ # TODO(sourabhbajaj): Change this from num_cores to metadata_override
+ self._num_cores_override = num_cores
- # TODO(priyag): This should not be hardcoded here.
- self._host = '/device:CPU:0'
# TODO(sourabhbajaj): Remove this once performance of running one step
# at a time is comparable to multiple steps.
self.steps_per_run = steps_per_run
- def distribute_dataset(self, dataset_fn):
- # TODO(priyag): Perhaps distribute across cores here.
- return self._call_dataset_fn(dataset_fn)
+ def _get_enqueue_op_per_host(self, host_id, iterator, input_shapes,
+ iterations):
+ """Create an enqueue op for a single host identified using host_id.
- # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
- # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
- # a mechanism to infer the outputs of `fn`. Pending b/110550782.
- def _run_steps_on_dataset(self, fn, iterator, iterations,
- initial_loop_values=None):
+    The returned while_loop op runs `iterations` times; each run enqueues one
+    batch per device shard.
- shapes = nest.flatten(iterator.output_shapes)
- if any([not s.is_fully_defined() for s in shapes]):
- raise ValueError(
- 'TPU currently requires fully defined shapes. Either use '
- 'set_shape() on the input tensors or use '
- 'dataset.apply(map_and_batch(..., drop_remainder=True)).')
- types = nest.flatten(iterator.output_types)
+ Args:
+ host_id: integer, id of the host to run the enqueue ops on.
+ iterator: `tf.data` iterator to read the input data.
+      input_shapes: shapes of the inputs to be enqueued on the queue. This is
+        the same as the value of `nest.flatten(iterator.output_shapes)`.
+ iterations: integer, number of iterations to be run; determines the
+ number of batches to be enqueued.
+
+ Returns:
+      A while_loop op that runs `iterations` times; each run enqueues a batch
+      on the infeed queue from the host with id `host_id` for each device shard.
+ """
+ host = self.get_host_cpu_device(host_id)
- def enqueue_ops_fn():
+ def _infeed_enqueue_ops_fn():
"""Enqueue ops for one iteration."""
control_deps = []
sharded_inputs = []
- with ops.device(self._host):
- for _ in range(self.num_towers):
+ enqueue_ops = []
+
+ with ops.device(host):
+ for _ in range(self.num_towers_per_host):
# Use control dependencies to ensure a deterministic ordering.
with ops.control_dependencies(control_deps):
inputs = nest.flatten(iterator.get_next())
control_deps.extend(inputs)
sharded_inputs.append(inputs)
- enqueue_ops = []
for core_id, shard_input in enumerate(sharded_inputs):
enqueue_ops.append(
tpu_ops.infeed_enqueue_tuple(
- inputs=shard_input, shapes=shapes, device_ordinal=core_id))
+ inputs=shard_input,
+ shapes=input_shapes,
+ device_ordinal=core_id))
return enqueue_ops
def enqueue_ops_loop_body(i):
- with ops.control_dependencies(enqueue_ops_fn()):
+ """Callable for the loop body of the while_loop instantiated below."""
+ with ops.control_dependencies(_infeed_enqueue_ops_fn()):
return i + 1
- with ops.device(self._host):
- enqueue_ops = control_flow_ops.while_loop(
+ with ops.device(host):
+ enqueue_op_per_host = control_flow_ops.while_loop(
lambda i: i < iterations,
enqueue_ops_loop_body,
[constant_op.constant(0)],
parallel_iterations=1)
+ return enqueue_op_per_host
+
+ def distribute_dataset(self, dataset_fn):
+ # TODO(priyag): Perhaps distribute across cores here.
+ return self._call_dataset_fn(dataset_fn)
+
+ # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
+ # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
+ # a mechanism to infer the outputs of `fn`. Pending b/110550782.
+ def _run_steps_on_dataset(self, fn, iterator, iterations,
+ initial_loop_values=None):
+
+ shapes = nest.flatten(iterator.output_shapes)
+ if any([not s.is_fully_defined() for s in shapes]):
+ raise ValueError(
+ 'TPU currently requires fully defined shapes. Either use '
+ 'set_shape() on the input tensors or use '
+ 'dataset.apply(map_and_batch(..., drop_remainder=True)).')
+ types = nest.flatten(iterator.output_types)
+
+ enqueue_ops = [
+ self._get_enqueue_op_per_host(host_id, iterator, shapes, iterations)
+ for host_id in range(self.num_hosts)]
+
def dequeue_fn():
dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
return nest.pack_sequence_as(iterator.output_shapes, dequeued)
@@ -143,6 +175,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
initial_loop_values = nest.flatten(initial_loop_values)
ctx = values.MultiStepContext()
def run_fn(*args, **kwargs):
+ """Single step on the TPU device."""
del args, kwargs
fn_inputs = dequeue_fn()
if not isinstance(fn_inputs, tuple):
@@ -160,8 +193,18 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
def iterate_on_tpu():
return training_loop.repeat(iterations, run_fn, initial_loop_values)
+ # We capture the control_flow_context at this point, before we run `fn`
+ # inside a while_loop and TPU replicate context. This is useful in cases
+ # where we might need to exit these contexts and get back to the outer
+    # context to do some things, e.g. to create an op which should be
+ # evaluated only once at the end of the loop on the host. One such usage
+ # is in creating metrics' value op.
+ self._outer_control_flow_context = (
+ ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access
+
replicate_inputs = [[]] * self.num_towers
replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
+ del self._outer_control_flow_context
ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)
# Filter out any ops from the outputs, typically this would be the case
@@ -224,6 +267,9 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
if aggregation == vs.VariableAggregation.MEAN:
# TODO(jhseu): Revisit once we support model-parallelism.
value *= (1. / self.num_towers)
+ elif aggregation != vs.VariableAggregation.SUM:
+ raise NotImplementedError(
+ 'Currently only support sum & mean in TPUStrategy.')
return tpu_ops.cross_replica_sum(value)
cf_context = cf_context.outer_context
@@ -233,10 +279,12 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
devices = cross_tower_ops_lib.get_devices_from(destinations)
if len(devices) == 1:
assert device_util.canonicalize(devices[0]) == device_util.canonicalize(
- self._host)
+ self.get_host_cpu_device(0))
else:
raise ValueError('Multiple devices are not supported for TPUStrategy')
+ if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
+ return value[0]
output = math_ops.add_n(value)
if aggregation == vs.VariableAggregation.MEAN:
return output * (1. / len(value))
@@ -249,4 +297,31 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
@property
def num_towers(self):
+ return self._num_cores_override or self._tpu_metadata.num_cores
+
+ @property
+ def num_hosts(self):
+ return self._tpu_metadata.num_hosts
+
+ @property
+ def num_towers_per_host(self):
return self._tpu_metadata.num_of_cores_per_host
+
+ def get_host_cpu_device(self, host_id):
+ if self._tpu_cluster_resolver.get_master() in ('', 'local'):
+ return '/replica:0/task:0/device:CPU:0'
+ job_name = self._tpu_cluster_resolver.get_job_name() or 'tpu_worker'
+ return '/job:%s/task:%d/device:CPU:0' % (job_name, host_id)
+
+ def configure(self,
+ session_config=None,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None):
+ del cluster_spec, task_type, task_id
+ if session_config:
+ session_config.isolate_session_state = True
+ cluster_spec = self._tpu_cluster_resolver.cluster_spec()
+ if cluster_spec:
+ session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
+
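
A self-contained sketch of the host-device naming rule that `get_host_cpu_device()` introduces above; the resolver inputs here are illustrative assumptions standing in for a real TPUClusterResolver.

# Sketch of the naming rule in get_host_cpu_device(); inputs are illustrative.
def host_cpu_device(master, job_name, host_id):
  if master in ('', 'local'):
    # Local or default master: a single host, addressed without a job name.
    return '/replica:0/task:0/device:CPU:0'
  return '/job:%s/task:%d/device:CPU:0' % (job_name or 'tpu_worker', host_id)

assert host_cpu_device('local', None, 0) == '/replica:0/task:0/device:CPU:0'
assert (host_cpu_device('grpc://10.0.0.2:8470', None, 1) ==
        '/job:tpu_worker/task:1/device:CPU:0')
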
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 8548a86421..fafa6384a1 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -183,6 +183,14 @@ class Mirrored(DistributedDelegate):
return self._index[device]
return list(self._index.values())[0]
+ def _as_graph_element(self):
+ obj = self.get()
+ # pylint: disable=protected-access
+ conv_fn = getattr(obj, "_as_graph_element", None)
+ if conv_fn and callable(conv_fn):
+ return conv_fn()
+ return obj
+
def _assign_on_device(device, variable, tensor):
with ops.device(device):
@@ -296,6 +304,10 @@ class DistributedVariable(DistributedDelegate):
self._primary_var.op.type)
return self.get().op
+ @property
+ def _in_graph_mode(self):
+ return self._primary_var._in_graph_mode # pylint: disable=protected-access
+
def read_value(self):
return distribution_strategy_context.get_distribution_strategy().read_var(
self)
@@ -308,26 +320,6 @@ class DistributedVariable(DistributedDelegate):
ops.register_dense_tensor_like_type(DistributedVariable)
-def _get_update_device():
- """Validate we are in update/update_non_slot() and return current device.
-
- This is used in MirroredVariable.assign* members, to make sure they
- are only called via an update method, to make sure all components of the
- variable are being updated in a consistent way.
-
- Returns:
- A string device.
-
- Raises:
- RuntimeError: If not in distribution.update()/.update_non_slot().
- """
- device = distribute_lib.get_update_device()
- if device is None:
- raise RuntimeError(
- "Use DistributionStrategy.update() to modify a MirroredVariable.")
- return device
-
-
class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
"""Class for defining how to restore a MirroredVariable."""
@@ -348,10 +340,6 @@ class MirroredVariable(DistributedVariable, Mirrored,
"""Holds a map from device to variables whose values are kept in sync."""
def __init__(self, index, primary_var, aggregation):
- # Use a weakref to make it easy to map from the contained values
- # to the container without introducing a reference cycle.
- for v in six.itervalues(index):
- v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access
self._primary_var = primary_var
self._aggregation = aggregation
super(MirroredVariable, self).__init__(index)
@@ -366,15 +354,27 @@ class MirroredVariable(DistributedVariable, Mirrored,
f = kwargs.pop("f")
if distribution_strategy_context.get_cross_tower_context():
update_device = distribute_lib.get_update_device()
- # We are calling update on the mirrored variable in cross tower context.
if update_device is not None:
- # We are calling an assign function on the mirrored variable in cross
- # tower context.
+ # We are calling an assign function on the mirrored variable in an
+ # update context.
v = self.get(device=update_device)
return f(v, *args, **kwargs)
- return distribution_strategy_context.get_distribution_strategy().update(
- self, f, *args, **kwargs)
+ # We are calling assign on the mirrored variable in cross tower context,
+ # use update to update the variable.
+ strategy = distribution_strategy_context.get_distribution_strategy()
+ updates = strategy.update(self, f, *args, **kwargs)
+ grouped = strategy.group(updates)
+ if isinstance(updates, DistributedValues) and updates.is_tensor_like:
+ # Make sure we run all updates. Without this, something like
+ # session.run(mirrored_var.assign*(...)) may only update one tower.
+ index = {}
+ for d in updates.devices:
+ with ops.device(d), ops.control_dependencies([grouped]):
+ index[d] = array_ops.identity(updates.get(d))
+ return Mirrored(index)
+ else:
+ return grouped
else:
_assert_tower_context()
# We are calling an assign function on the mirrored variable in tower
@@ -519,6 +519,8 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
return self._aggregation
def _get_cross_tower(self):
+ if self._aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
+ return self._primary_var
all_components = tuple(self._index.values())
# TODO(josh11b): Use a strategy-specific method.
total = math_ops.add_n(all_components)
@@ -1057,3 +1059,160 @@ def value_container(val):
if container is not None:
return container
return val
+
+
+# TODO(josh11b): Descend from Variable.
+class AggregatingVariable(checkpointable.CheckpointableBase):
+ """A wrapper around a variable that aggregates updates across towers."""
+
+ def __init__(self, v, aggregation):
+ self._v = v
+ # TODO(josh11b): Set v._distributed_container?
+ # v._distributed_container = weakref.ref(self) # pylint: disable=protected-access
+ self._aggregation = aggregation
+
+ def get(self):
+ return self._v
+
+ def __getattr__(self, name):
+ return getattr(self._v, name)
+
+ def _assign_func(self, *args, **kwargs):
+ f = kwargs.pop("f")
+ if distribution_strategy_context.get_cross_tower_context():
+ update_device = distribute_lib.get_update_device()
+ if update_device is not None:
+ # We are calling an assign function in an update context.
+ return f(self._v, *args, **kwargs)
+
+ # We are calling an assign function in cross tower context, wrap it in an
+ # update call.
+ return distribution_strategy_context.get_distribution_strategy().update(
+ self, f, *args, **kwargs)
+ else:
+ assert distribution_strategy_context.get_tower_context()
+ # We are calling an assign function in tower context.
+ # We reduce the value we want to assign/add/sub. More details about how we
+ # handle the different use cases can be found in the _reduce method.
+ # We call the function with the reduced value.
+ if self._aggregation == vs.VariableAggregation.NONE:
+ raise ValueError("You must specify an aggregation method to update a "
+ "a variable in Tower Context.")
+
+ def merge_fn(strategy, value, *other_args, **other_kwargs):
+ return strategy.update(
+ self, f,
+ strategy.reduce(
+ aggregation=self._aggregation, value=value, destinations=self),
+ *other_args, **other_kwargs)
+
+ return distribution_strategy_context.get_tower_context().merge_call(
+ merge_fn, *args, **kwargs)
+
+ def assign_sub(self, *args, **kwargs):
+ assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
+ return self._assign_func(f=assign_sub_fn, *args, **kwargs)
+
+ def assign_add(self, *args, **kwargs):
+ assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
+ return self._assign_func(f=assign_add_fn, *args, **kwargs)
+
+ def assign(self, *args, **kwargs):
+ assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
+ return self._assign_func(f=assign_fn, *args, **kwargs)
+
+ @property
+ def aggregation(self):
+ return self._aggregation
+
+ @property
+ def name(self):
+ return self._v.name
+
+ @property
+ def dtype(self):
+ return self._v.dtype
+
+ # TODO(josh11b): Test saving & restoring.
+ def _gather_saveables_for_checkpoint(self):
+ return {checkpointable.VARIABLE_VALUE_KEY: self._v}
+
+ # pylint: disable=multiple-statements
+ def __add__(self, o): return self._v + o
+ def __radd__(self, o): return o + self._v
+ def __sub__(self, o): return self._v - o
+ def __rsub__(self, o): return o - self._v
+ def __mul__(self, o): return self._v * o
+ def __rmul__(self, o): return o * self._v
+ def __truediv__(self, o): return self._v / o
+ def __rtruediv__(self, o): return o / self._v
+ def __floordiv__(self, o): return self._v // o
+ def __rfloordiv__(self, o): return o // self._v
+ def __mod__(self, o): return self._v % o
+ def __rmod__(self, o): return o % self._v
+ def __lt__(self, o): return self._v < o
+ def __le__(self, o): return self._v <= o
+ def __gt__(self, o): return self._v > o
+ def __ge__(self, o): return self._v >= o
+ def __and__(self, o): return self._v & o
+ def __rand__(self, o): return o & self._v
+ def __or__(self, o): return self._v | o
+ def __ror__(self, o): return o | self._v
+ def __xor__(self, o): return self._v ^ o
+ def __rxor__(self, o): return o ^ self._v
+ def __getitem__(self, o): return self._v[o]
+ def __pow__(self, o, modulo=None): return pow(self._v, o, modulo)
+ def __rpow__(self, o): return pow(o, self._v)
+ def __invert__(self): return ~self._v
+ def __neg__(self): return -self._v
+ def __abs__(self): return abs(self._v)
+
+ def __div__(self, o):
+ try:
+ return self._v.__div__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __rdiv__(self, o):
+ try:
+ return self._v.__rdiv__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __matmul__(self, o):
+ try:
+ return self._v.__matmul__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __rmatmul__(self, o):
+ try:
+ return self._v.__rmatmul__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __str__(self):
+ return str(self._v)
+
+ def __repr__(self):
+ return repr(self._v)
+
+ def _should_act_as_resource_variable(self):
+ """Pass resource_variable_ops.is_resource_variable check."""
+ pass
+
+
+# Register a conversion function which reads the value of the variable,
+# allowing instances of the class to be used as tensors.
+def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False):
+ return ops.internal_convert_to_tensor(
+ var.get(), dtype=dtype, name=name, as_ref=as_ref)
+
+
+ops.register_tensor_conversion_function(
+ AggregatingVariable, _tensor_conversion_aggregate)
+ops.register_dense_tensor_like_type(AggregatingVariable)
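
The conversion function registered for AggregatingVariable above is what lets the wrapper be used directly in ops. Below is a toy, stand-alone sketch of the same pattern; the Wrapper class and values are assumptions for illustration, not part of this patch.

# Toy sketch of register_tensor_conversion_function; Wrapper is hypothetical.
import tensorflow as tf

class Wrapper(object):

  def __init__(self, v):
    self._v = v

  def get(self):
    return self._v

def _convert(value, dtype=None, name=None, as_ref=False):
  del as_ref  # Wrappers are always converted by value in this sketch.
  return tf.convert_to_tensor(value.get(), dtype=dtype, name=name)

tf.register_tensor_conversion_function(Wrapper, _convert)

w = Wrapper(tf.constant(2.0))
y = tf.add(w, 3.0)  # The registered converter unwraps w automatically.
with tf.Session() as sess:
  assert sess.run(y) == 5.0
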
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index 91a43d4999..15a85a28f5 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -521,6 +521,7 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
return worker_device_map, devices
def testDataDistributionOneDevicePerWorker(self):
+ self.skipTest("Temporarily disabled.")
worker_device_map, devices = self._cpu_devices()
with context.graph_mode():
dataset_fn = lambda: dataset_ops.Dataset.range(8)
@@ -528,6 +529,7 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
[[0, 1], [2, 3], [4, 5], [6, 7]])
def testDataDistributionTwoDevicePerWorker(self):
+ self.skipTest("Temporarily disabled.")
if context.num_gpus() < 1:
self.skipTest("A GPU is not available for this test.")
worker_device_map, devices = self._cpu_and_one_gpu_devices()
@@ -537,6 +539,7 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
[[0, 2, 1, 3], [4, 6, 5, 7]])
def testTupleDataset(self):
+ self.skipTest("Temporarily disabled.")
worker_device_map, devices = self._cpu_devices()
with context.graph_mode():
@@ -553,6 +556,7 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
expected_values)
def testInitializableIterator(self):
+ self.skipTest("Temporarily disabled.")
worker_device_map, devices = self._cpu_devices()
with context.graph_mode():
dataset_fn = lambda: dataset_ops.Dataset.range(8)
@@ -570,6 +574,7 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
[[0, 1], [2, 3], [4, 5], [6, 7]])
def testValueErrorForIterator(self):
+ self.skipTest("Temporarily disabled.")
    # Incompatible arguments.
with self.assertRaises(ValueError):
values.MultiWorkerDataIterator({"w1": None}, {"w1": "d1", "w2": "d2"})
@@ -653,7 +658,7 @@ class MirroredVariableTest(test.TestCase):
def _save_mirrored(self):
"""Save variables with mirroring, returns save_path."""
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
v, devices, mirrored = _make_mirrored()
# Overwrite the initial values.
@@ -668,7 +673,7 @@ class MirroredVariableTest(test.TestCase):
def _save_normal(self):
"""Save variables without mirroring, returns save_path."""
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
var = variable_scope.get_variable(
name="v", initializer=1., use_resource=True)
@@ -684,7 +689,7 @@ class MirroredVariableTest(test.TestCase):
def _restore_normal(self, save_path):
"""Restore to variables without mirroring in a fresh graph."""
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
var = variable_scope.get_variable(
name="v", initializer=7., use_resource=True)
@@ -698,7 +703,7 @@ class MirroredVariableTest(test.TestCase):
def _restore_mirrored(self, save_path):
"""Restore to variables with mirroring in a fresh graph."""
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
v, devices, mirrored = _make_mirrored()
# Overwrite the initial values.
@@ -864,7 +869,7 @@ class TowerLocalVariableTest(test.TestCase):
def _save_tower_local_mean(self):
"""Save variables with mirroring, returns save_path."""
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
v, tower_local = _make_tower_local(
variable_scope.VariableAggregation.MEAN)
@@ -881,7 +886,7 @@ class TowerLocalVariableTest(test.TestCase):
def _save_tower_local_sum(self):
"""Save variables with mirroring, returns save_path."""
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
v, tower_local = _make_tower_local("sum")
# Overwrite the initial values.
@@ -897,7 +902,7 @@ class TowerLocalVariableTest(test.TestCase):
def _save_normal(self):
"""Save variables without mirroring, returns save_path."""
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
var = variable_scope.get_variable(
name="v", initializer=1., use_resource=True)
@@ -913,7 +918,7 @@ class TowerLocalVariableTest(test.TestCase):
def _restore_normal(self, save_path):
"""Restore to variables without mirroring in a fresh graph."""
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
var = variable_scope.get_variable(
name="v", initializer=7., use_resource=True)
@@ -927,7 +932,7 @@ class TowerLocalVariableTest(test.TestCase):
def _restore_tower_local_mean(self, save_path):
"""Restore to variables with mirroring in a fresh graph."""
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
v, tower_local = _make_tower_local(
variable_scope.VariableAggregation.MEAN)
@@ -942,7 +947,7 @@ class TowerLocalVariableTest(test.TestCase):
def _restore_tower_local_sum(self, save_path):
"""Restore to variables with mirroring in a fresh graph."""
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM)
# Overwrite the initial values.
diff --git a/tensorflow/contrib/distribute/python/warm_starting_util_test.py b/tensorflow/contrib/distribute/python/warm_starting_util_test.py
index d8bacdb338..5d57d144c1 100644
--- a/tensorflow/contrib/distribute/python/warm_starting_util_test.py
+++ b/tensorflow/contrib/distribute/python/warm_starting_util_test.py
@@ -56,7 +56,7 @@ class WarmStartingUtilWithDistributionStrategyTest(
# Create variable and save checkpoint from which to warm-start.
def create_var(g):
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
var = variable_scope.get_variable(var_name, initializer=original_value)
sess.run(variables.global_variables_initializer())
saver = saver_lib.Saver()
@@ -75,7 +75,7 @@ class WarmStartingUtilWithDistributionStrategyTest(
self.assertAllEqual(original_value, prev_init_val)
def warm_start(g):
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
# Initialize with zeros.
var = variable_scope.get_variable(
var_name, initializer=[[0., 0.], [0., 0.]])
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index a8d0d493ab..97c53ae2b9 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -445,7 +445,7 @@ cuda_py_test(
cuda_py_test(
name = "sinh_arcsinh_test",
- size = "small",
+ size = "medium",
srcs = ["python/kernel_tests/sinh_arcsinh_test.py"],
additional_deps = [
":distributions_py",
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/autoregressive_test.py b/tensorflow/contrib/distributions/python/kernel_tests/autoregressive_test.py
index 0928dc3f35..a22d4d825b 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/autoregressive_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/autoregressive_test.py
@@ -53,7 +53,7 @@ class AutogressiveTest(test_util.VectorDistributionTestHelpers, test.TestCase):
def testSampleAndLogProbConsistency(self):
batch_shape = []
event_size = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_event_shape = np.concatenate([batch_shape, [event_size]], axis=0)
sample0 = array_ops.zeros(batch_event_shape)
affine = Affine(scale_tril=self._random_scale_tril(event_size))
@@ -67,7 +67,7 @@ class AutogressiveTest(test_util.VectorDistributionTestHelpers, test.TestCase):
sample_shape = np.int32([4, 5])
batch_shape = np.int32([])
event_size = np.int32(2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_event_shape = np.concatenate([batch_shape, [event_size]], axis=0)
sample0 = array_ops.zeros(batch_event_shape)
affine = Affine(scale_tril=self._random_scale_tril(event_size))
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
index f2bb2d3325..62623deccd 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
@@ -76,7 +76,7 @@ class _BatchReshapeTest(object):
wishart.log_prob(x), expected_log_prob_shape)
actual_log_prob = reshape_wishart.log_prob(expected_sample)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
[
batch_shape_,
event_shape_,
@@ -132,7 +132,7 @@ class _BatchReshapeTest(object):
wishart.variance(), expected_matrix_stat_shape)
actual_variance = reshape_wishart.variance()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
[
expected_entropy_, actual_entropy_,
expected_mean_, actual_mean_,
@@ -202,7 +202,7 @@ class _BatchReshapeTest(object):
normal.log_prob(x), expected_log_prob_shape)
actual_log_prob = reshape_normal.log_prob(expected_sample)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
[
batch_shape_,
event_shape_,
@@ -255,7 +255,7 @@ class _BatchReshapeTest(object):
normal.variance(), expected_scalar_stat_shape)
actual_variance = reshape_normal.variance()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
[
expected_entropy_, actual_entropy_,
expected_mean_, actual_mean_,
@@ -323,7 +323,7 @@ class _BatchReshapeTest(object):
mvn.log_prob(x), expected_log_prob_shape)
actual_log_prob = reshape_mvn.log_prob(expected_sample)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
[
batch_shape_,
event_shape_,
@@ -385,7 +385,7 @@ class _BatchReshapeTest(object):
mvn.covariance(), expected_matrix_stat_shape)
actual_covariance = reshape_mvn.covariance()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
[
expected_entropy_, actual_entropy_,
expected_mean_, actual_mean_,
@@ -447,7 +447,7 @@ class _BatchReshapeTest(object):
validate_args=True)
else:
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError(r"Shape sizes do not match."):
batch_reshape_lib.BatchReshape(
distribution=mvn,
@@ -482,7 +482,7 @@ class _BatchReshapeTest(object):
validate_args=True)
else:
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError(r".*must be >=-1.*"):
batch_reshape_lib.BatchReshape(
distribution=mvn,
@@ -512,7 +512,7 @@ class _BatchReshapeTest(object):
validate_args=True)
else:
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError(r".*must be a vector.*"):
batch_reshape_lib.BatchReshape(
distribution=mvn,
@@ -548,11 +548,11 @@ class _BatchReshapeTest(object):
return
with self.assertRaisesOpError("too few batch and event dims"):
- with self.test_session():
+ with self.cached_session():
poisson_141_reshaped.log_prob(x_4).eval()
with self.assertRaisesOpError("unexpected batch and event shape"):
- with self.test_session():
+ with self.cached_session():
poisson_141_reshaped.log_prob(x_114).eval()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py
index 042c8ebd51..372b7e37b7 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py
@@ -31,7 +31,7 @@ class AbsoluteValueTest(test.TestCase):
"""Tests correctness of the absolute value bijector."""
def testBijectorVersusNumpyRewriteOfBasicFunctionsEventNdims0(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
bijector = AbsoluteValue(validate_args=True)
self.assertEqual("absolute_value", bijector.name)
x = array_ops.constant([[0., 1., -1], [0., -5., 3.]]) # Shape [2, 3]
@@ -54,13 +54,13 @@ class AbsoluteValueTest(test.TestCase):
y, event_ndims=0)))
def testNegativeYRaisesForInverseIfValidateArgs(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
bijector = AbsoluteValue(validate_args=True)
with self.assertRaisesOpError("y was negative"):
sess.run(bijector.inverse(-1.))
def testNegativeYRaisesForILDJIfValidateArgs(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
bijector = AbsoluteValue(validate_args=True)
with self.assertRaisesOpError("y was negative"):
sess.run(bijector.inverse_log_det_jacobian(-1., event_ndims=0))
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py
index 1e4ad724d0..a7bd51430e 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py
@@ -28,7 +28,7 @@ from tensorflow.python.platform import test
class AffineLinearOperatorTest(test.TestCase):
def testIdentity(self):
- with self.test_session():
+ with self.cached_session():
affine = AffineLinearOperator(
validate_args=True)
x = np.array([[1, 0, -1], [2, 3, 4]], dtype=np.float32)
@@ -45,7 +45,7 @@ class AffineLinearOperatorTest(test.TestCase):
affine.forward_log_det_jacobian(x, event_ndims=2).eval())
def testDiag(self):
- with self.test_session():
+ with self.cached_session():
shift = np.array([-1, 0, 1], dtype=np.float32)
diag = np.array([[1, 2, 3],
[2, 5, 6]], dtype=np.float32)
@@ -67,7 +67,7 @@ class AffineLinearOperatorTest(test.TestCase):
affine.forward_log_det_jacobian(x, event_ndims=1).eval())
def testTriL(self):
- with self.test_session():
+ with self.cached_session():
shift = np.array([-1, 0, 1], dtype=np.float32)
tril = np.array([[[3, 0, 0],
[2, -1, 0],
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py
index d2533620be..bc6752a69d 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py
@@ -31,14 +31,14 @@ class AffineScalarBijectorTest(test.TestCase):
"""Tests correctness of the Y = scale @ x + shift transformation."""
def testProperties(self):
- with self.test_session():
+ with self.cached_session():
mu = -1.
# scale corresponds to 1.
bijector = AffineScalar(shift=mu)
self.assertEqual("affine_scalar", bijector.name)
def testNoBatchScalar(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
def static_run(fun, x, **kwargs):
return fun(x, **kwargs).eval()
@@ -60,7 +60,7 @@ class AffineScalarBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=0))
def testOneBatchScalarViaIdentityIn64BitUserProvidesShiftOnly(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
def static_run(fun, x, **kwargs):
return fun(x, **kwargs).eval()
@@ -83,7 +83,7 @@ class AffineScalarBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=0))
def testOneBatchScalarViaIdentityIn64BitUserProvidesScaleOnly(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
def static_run(fun, x, **kwargs):
return fun(x, **kwargs).eval()
@@ -106,7 +106,7 @@ class AffineScalarBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=0))
def testTwoBatchScalarIdentityViaIdentity(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
def static_run(fun, x, **kwargs):
return fun(x, **kwargs).eval()
@@ -129,7 +129,7 @@ class AffineScalarBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=0))
def testTwoBatchScalarIdentityViaScale(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
def static_run(fun, x, **kwargs):
return fun(x, **kwargs).eval()
@@ -152,7 +152,7 @@ class AffineScalarBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=0))
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
bijector = AffineScalar(shift=3.6, scale=0.42)
assert_scalar_congruency(bijector, lower_x=-2., upper_x=2.)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py
index 9e14b9a53e..dc18eb3df6 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py
@@ -32,14 +32,14 @@ class AffineBijectorTest(test.TestCase):
"""Tests correctness of the Y = scale @ x + shift transformation."""
def testProperties(self):
- with self.test_session():
+ with self.cached_session():
mu = -1.
# scale corresponds to 1.
bijector = Affine(shift=mu)
self.assertEqual("affine", bijector.name)
def testNoBatchMultivariateIdentity(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -71,7 +71,7 @@ class AffineBijectorTest(test.TestCase):
0., run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
def testNoBatchMultivariateDiag(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -114,7 +114,7 @@ class AffineBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
def testNoBatchMultivariateFullDynamic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32, name="x")
mu = array_ops.placeholder(dtypes.float32, name="mu")
scale_diag = array_ops.placeholder(dtypes.float32, name="scale_diag")
@@ -137,7 +137,7 @@ class AffineBijectorTest(test.TestCase):
feed_dict))
def testBatchMultivariateIdentity(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -161,7 +161,7 @@ class AffineBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
def testBatchMultivariateDiag(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -185,7 +185,7 @@ class AffineBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
def testBatchMultivariateFullDynamic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32, name="x")
mu = array_ops.placeholder(dtypes.float32, name="mu")
scale_diag = array_ops.placeholder(dtypes.float32, name="scale_diag")
@@ -209,7 +209,7 @@ class AffineBijectorTest(test.TestCase):
x, event_ndims=1), feed_dict))
def testIdentityWithDiagUpdate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -235,7 +235,7 @@ class AffineBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
def testIdentityWithTriL(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -261,7 +261,7 @@ class AffineBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
def testDiagWithTriL(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -285,7 +285,7 @@ class AffineBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
def testIdentityAndDiagWithTriL(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -312,7 +312,7 @@ class AffineBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
def testIdentityWithVDVTUpdate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -349,7 +349,7 @@ class AffineBijectorTest(test.TestCase):
run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1))
def testDiagWithVDVTUpdate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -385,7 +385,7 @@ class AffineBijectorTest(test.TestCase):
run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1))
def testTriLWithVDVTUpdate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -422,7 +422,7 @@ class AffineBijectorTest(test.TestCase):
run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1))
def testTriLWithVDVTUpdateNoDiagonal(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -459,7 +459,7 @@ class AffineBijectorTest(test.TestCase):
run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1))
def testNoBatchMultivariateRaisesWhenSingular(self):
- with self.test_session():
+ with self.cached_session():
mu = [1., -1]
bijector = Affine(
shift=mu,
@@ -531,7 +531,7 @@ class AffineBijectorTest(test.TestCase):
itertools.combinations(s, r) for r in range(len(s) + 1))
for args in _powerset(scale_params.items()):
- with self.test_session():
+ with self.cached_session():
args = dict(args)
scale_args = dict({"x": x}, **args)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py
index c832fcaa68..bf61e9f2fe 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py
@@ -69,7 +69,7 @@ class BatchNormTest(test_util.VectorDistributionTestHelpers,
]
for input_shape, event_dims, training in params:
x_ = np.arange(5 * 4 * 2).astype(np.float32).reshape(input_shape)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = constant_op.constant(x_)
# When training, memorize the exact mean of the last
# minibatch that it normalized (instead of moving average assignment).
@@ -145,7 +145,7 @@ class BatchNormTest(test_util.VectorDistributionTestHelpers,
def testMaximumLikelihoodTraining(self):
# Test Maximum Likelihood training with default bijector.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.])
batch_norm = BatchNormalization(training=True)
dist = transformed_distribution_lib.TransformedDistribution(
@@ -176,7 +176,7 @@ class BatchNormTest(test_util.VectorDistributionTestHelpers,
self.assertAllClose([1., 1.], moving_var_, atol=5e-2)
def testLogProb(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
layer = normalization.BatchNormalization(epsilon=0.)
batch_norm = BatchNormalization(batchnorm_layer=layer, training=False)
base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.])
@@ -196,7 +196,7 @@ class BatchNormTest(test_util.VectorDistributionTestHelpers,
def testMutuallyConsistent(self):
# BatchNorm bijector is only mutually consistent when training=False.
dims = 4
- with self.test_session() as sess:
+ with self.cached_session() as sess:
layer = normalization.BatchNormalization(epsilon=0.)
batch_norm = BatchNormalization(batchnorm_layer=layer, training=False)
dist = transformed_distribution_lib.TransformedDistribution(
@@ -215,7 +215,7 @@ class BatchNormTest(test_util.VectorDistributionTestHelpers,
def testInvertMutuallyConsistent(self):
# BatchNorm bijector is only mutually consistent when training=False.
dims = 4
- with self.test_session() as sess:
+ with self.cached_session() as sess:
layer = normalization.BatchNormalization(epsilon=0.)
batch_norm = Invert(
BatchNormalization(batchnorm_layer=layer, training=False))
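
A hedged numpy sketch of why these tests pin training=False: with frozen inference-mode moments, batch normalization is a fixed affine map that can be undone exactly, which is what the mutual-consistency checks rely on. All constants below are illustrative; epsilon=0 mirrors the layer configured above.

import numpy as np

mean, var, gamma, beta, eps = 1.0, 4.0, 2.0, 0.5, 0.0
x = np.array([0., 1., 3.])
y = gamma * (x - mean) / np.sqrt(var + eps) + beta       # normalize with fixed moments
x_back = (y - beta) / gamma * np.sqrt(var + eps) + mean  # exact affine inverse
assert np.allclose(x, x_back)
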
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
index dc45114b1c..ada99ec9c6 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
@@ -46,7 +46,7 @@ class ChainBijectorTest(test.TestCase):
"""Tests the correctness of the Y = Chain(bij1, bij2, bij3) transformation."""
def testBijector(self):
- with self.test_session():
+ with self.cached_session():
chain = Chain((Exp(), Softplus()))
self.assertEqual("chain_of_exp_of_softplus", chain.name)
x = np.asarray([[[1., 2.],
@@ -61,7 +61,7 @@ class ChainBijectorTest(test.TestCase):
chain.forward_log_det_jacobian(x, event_ndims=1).eval())
def testBijectorIdentity(self):
- with self.test_session():
+ with self.cached_session():
chain = Chain()
self.assertEqual("identity", chain.name)
x = np.asarray([[[1., 2.],
@@ -74,13 +74,13 @@ class ChainBijectorTest(test.TestCase):
0., chain.forward_log_det_jacobian(x, event_ndims=1).eval())
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
chain = Chain((Exp(), Softplus()))
assert_scalar_congruency(
chain, lower_x=1e-3, upper_x=1.5, rtol=0.05)
def testShapeGetters(self):
- with self.test_session():
+ with self.cached_session():
chain = Chain([
SoftmaxCentered(validate_args=True),
SoftmaxCentered(validate_args=True),
@@ -195,7 +195,7 @@ class ChainBijectorTest(test.TestCase):
dtype=np.float32, shape=[None, 10], name="samples")
ildj = chain.inverse_log_det_jacobian(samples, event_ndims=0)
self.assertTrue(ildj is not None)
- with self.test_session():
+ with self.cached_session():
ildj.eval({samples: np.zeros([2, 10], np.float32)})
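
A numpy sketch of the composition tested above: Chain((Exp(), Softplus())) applies its bijectors right to left, so forward(x) = exp(softplus(x)), and the inverse unwinds them in the opposite order.

import numpy as np

def softplus(x):
  return np.log1p(np.exp(x))

x = np.asarray([[1., 2.], [3., 4.]])
y = np.exp(softplus(x))                # forward: exp applied after softplus
x_back = np.log(np.expm1(np.log(y)))   # inverse: softplus_inverse(log(y))
assert np.allclose(x, x_back)
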
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py
index d1ce273499..9681b64ced 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py
@@ -30,7 +30,7 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
"""Tests the correctness of the Y = X @ X.T transformation."""
def testBijectorMatrix(self):
- with self.test_session():
+ with self.cached_session():
bijector = bijectors.CholeskyOuterProduct(validate_args=True)
self.assertEqual("cholesky_outer_product", bijector.name)
x = [[[1., 0], [2, 1]], [[np.sqrt(2.), 0], [np.sqrt(8.), 1]]]
@@ -75,7 +75,7 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
bijector = bijectors.CholeskyOuterProduct()
x_pl = array_ops.placeholder(dtypes.float32)
- with self.test_session():
+ with self.cached_session():
log_det_jacobian = bijector.forward_log_det_jacobian(x_pl, event_ndims=2)
# The Jacobian matrix is 2 * tf.eye(2), which has jacobian determinant 4.
@@ -86,7 +86,7 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
def testNoBatchStatic(self):
x = np.array([[1., 0], [2, 1]]) # np.linalg.cholesky(y)
y = np.array([[1., 2], [2, 5]]) # np.matmul(x, x.T)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_actual = bijectors.CholeskyOuterProduct().forward(x=x)
x_actual = bijectors.CholeskyOuterProduct().inverse(y=y)
[y_actual_, x_actual_] = sess.run([y_actual, x_actual])
@@ -98,7 +98,7 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
def testNoBatchDeferred(self):
x = np.array([[1., 0], [2, 1]]) # np.linalg.cholesky(y)
y = np.array([[1., 2], [2, 5]]) # np.matmul(x, x.T)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x_pl = array_ops.placeholder(dtypes.float32)
y_pl = array_ops.placeholder(dtypes.float32)
y_actual = bijectors.CholeskyOuterProduct().forward(x=x_pl)
@@ -119,7 +119,7 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
[2, 5]],
[[9., 3],
[3, 5]]]) # np.matmul(x, x.T)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_actual = bijectors.CholeskyOuterProduct().forward(x=x)
x_actual = bijectors.CholeskyOuterProduct().inverse(y=y)
[y_actual_, x_actual_] = sess.run([y_actual, x_actual])
@@ -137,7 +137,7 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
[2, 5]],
[[9., 3],
[3, 5]]]) # np.matmul(x, x.T)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x_pl = array_ops.placeholder(dtypes.float32)
y_pl = array_ops.placeholder(dtypes.float32)
y_actual = bijectors.CholeskyOuterProduct().forward(x=x_pl)
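
A small numpy check of the forward/inverse pair exercised here, matching the test's own comments: forward(x) = x @ x.T for lower-triangular x, and the inverse is the Cholesky factorization.

import numpy as np

x = np.array([[1., 0.], [2., 1.]])
y = x @ x.T                                   # forward: [[1, 2], [2, 5]]
assert np.allclose(np.linalg.cholesky(y), x)  # inverse recovers x
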
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py
index 7be939cd27..d2c00865e7 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py
@@ -30,7 +30,7 @@ class ExpBijectorTest(test.TestCase):
"""Tests correctness of the Y = g(X) = exp(X) transformation."""
def testBijector(self):
- with self.test_session():
+ with self.cached_session():
bijector = Exp()
self.assertEqual("exp", bijector.name)
x = [[[1.], [2.]]]
@@ -48,13 +48,13 @@ class ExpBijectorTest(test.TestCase):
x, event_ndims=1).eval())
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
bijector = Exp()
assert_scalar_congruency(
bijector, lower_x=-2., upper_x=1.5, rtol=0.05)
def testBijectiveAndFinite(self):
- with self.test_session():
+ with self.cached_session():
bijector = Exp()
x = np.linspace(-10, 10, num=10).astype(np.float32)
y = np.logspace(-10, 10, num=10).astype(np.float32)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py
index 54e54c3296..b9cdbfb823 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py
@@ -31,7 +31,7 @@ class GumbelBijectorTest(test.TestCase):
"""Tests correctness of the Gumbel bijector."""
def testBijector(self):
- with self.test_session():
+ with self.cached_session():
loc = 0.3
scale = 5.
bijector = Gumbel(loc=loc, scale=scale, validate_args=True)
@@ -52,12 +52,12 @@ class GumbelBijectorTest(test.TestCase):
atol=0.)
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
assert_scalar_congruency(
Gumbel(loc=0.3, scale=20.), lower_x=1., upper_x=100., rtol=0.02)
def testBijectiveAndFinite(self):
- with self.test_session():
+ with self.cached_session():
bijector = Gumbel(loc=0., scale=3.0, validate_args=True)
x = np.linspace(-10., 10., num=10).astype(np.float32)
y = np.linspace(0.01, 0.99, num=10).astype(np.float32)
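
A hedged numpy sketch of the Gumbel bijector's map, assuming the forward is the Gumbel CDF y = exp(-exp(-(x - loc) / scale)); that is consistent with the unit-interval y values these tests feed the inverse.

import numpy as np

loc, scale = 0.3, 5.
x = np.linspace(-10., 10., 10)
y = np.exp(-np.exp(-(x - loc) / scale))    # forward: Gumbel CDF
x_back = loc - scale * np.log(-np.log(y))  # closed-form inverse
assert np.allclose(x, x_back)
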
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py
index 7d3bd758cd..c9bccb36fc 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py
@@ -32,7 +32,7 @@ class InlineBijectorTest(test.TestCase):
"""Tests correctness of the inline constructed bijector."""
def testBijector(self):
- with self.test_session():
+ with self.cached_session():
exp = Exp()
inline = Inline(
forward_fn=math_ops.exp,
@@ -55,7 +55,7 @@ class InlineBijectorTest(test.TestCase):
inline.forward_log_det_jacobian(x, event_ndims=1).eval())
def testShapeGetters(self):
- with self.test_session():
+ with self.cached_session():
bijector = Inline(
forward_event_shape_tensor_fn=lambda x: array_ops.concat((x, [1]), 0),
forward_event_shape_fn=lambda x: x.as_list() + [1],
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py
index 8b14c8327f..7e3340aeb0 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py
@@ -31,7 +31,7 @@ class InvertBijectorTest(test.TestCase):
"""Tests the correctness of the Y = Invert(bij) transformation."""
def testBijector(self):
- with self.test_session():
+ with self.cached_session():
for fwd in [
bijectors.Identity(),
bijectors.Exp(),
@@ -53,13 +53,13 @@ class InvertBijectorTest(test.TestCase):
rev.forward_log_det_jacobian(x, event_ndims=1).eval())
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
bijector = bijectors.Invert(bijectors.Exp())
assert_scalar_congruency(
bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05)
def testShapeGetters(self):
- with self.test_session():
+ with self.cached_session():
bijector = bijectors.Invert(bijectors.SoftmaxCentered(validate_args=True))
x = tensor_shape.TensorShape([2])
y = tensor_shape.TensorShape([1])
@@ -73,7 +73,7 @@ class InvertBijectorTest(test.TestCase):
bijector.inverse_event_shape_tensor(y.as_list()).eval())
def testDocstringExample(self):
- with self.test_session():
+ with self.cached_session():
exp_gamma_distribution = (
transformed_distribution_lib.TransformedDistribution(
distribution=gamma_lib.Gamma(concentration=1., rate=2.),
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py
index a8089881f6..b3fb50005e 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py
@@ -30,7 +30,7 @@ class KumaraswamyBijectorTest(test.TestCase):
"""Tests correctness of the Kumaraswamy bijector."""
def testBijector(self):
- with self.test_session():
+ with self.cached_session():
a = 2.
b = 0.3
bijector = Kumaraswamy(
@@ -54,13 +54,13 @@ class KumaraswamyBijectorTest(test.TestCase):
atol=0.)
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
assert_scalar_congruency(
Kumaraswamy(concentration1=0.5, concentration0=1.1),
lower_x=0., upper_x=1., n=int(10e3), rtol=0.02)
def testBijectiveAndFinite(self):
- with self.test_session():
+ with self.cached_session():
concentration1 = 1.2
concentration0 = 2.
bijector = Kumaraswamy(
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py
index 5ba5a2083b..ad4329d425 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py
@@ -71,7 +71,7 @@ class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers,
def testBijector(self):
x_ = np.arange(3 * 4 * 2).astype(np.float32).reshape(3, 4, 2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ma = MaskedAutoregressiveFlow(
validate_args=True,
**self._autoregressive_flow_kwargs)
@@ -102,7 +102,7 @@ class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers,
def testMutuallyConsistent(self):
dims = 4
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ma = MaskedAutoregressiveFlow(
validate_args=True,
**self._autoregressive_flow_kwargs)
@@ -121,7 +121,7 @@ class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers,
def testInvertMutuallyConsistent(self):
dims = 4
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ma = Invert(MaskedAutoregressiveFlow(
validate_args=True,
**self._autoregressive_flow_kwargs))
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py
index 49a9afe3f6..31ee36f024 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
+@test_util.run_all_in_graph_and_eager_modes
class MatrixInverseTriLBijectorTest(test.TestCase):
"""Tests the correctness of the Y = inv(tril) transformation."""
@@ -40,7 +41,6 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
y[idx][np.triu_indices(y[idx].shape[-1], 1)] = 0
return y
- @test_util.run_in_graph_and_eager_modes
def testComputesCorrectValues(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
self.assertEqual("matrix_inverse_tril", inv.name)
@@ -62,7 +62,6 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
self.assertNear(expected_fldj_, fldj_, err=1e-3)
self.assertNear(-expected_fldj_, ildj_, err=1e-3)
- @test_util.run_in_graph_and_eager_modes
def testOneByOneMatrix(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
x_ = np.array([[5.]], dtype=np.float32)
@@ -81,7 +80,6 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
self.assertNear(expected_fldj_, fldj_, err=1e-3)
self.assertNear(-expected_fldj_, ildj_, err=1e-3)
- @test_util.run_in_graph_and_eager_modes
def testZeroByZeroMatrix(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
x_ = np.eye(0, dtype=np.float32)
@@ -100,7 +98,6 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
self.assertNear(expected_fldj_, fldj_, err=1e-3)
self.assertNear(-expected_fldj_, ildj_, err=1e-3)
- @test_util.run_in_graph_and_eager_modes
def testBatch(self):
# Test batch computation with input shape (2, 1, 2, 2), i.e. batch shape
# (2, 1).
@@ -125,20 +122,18 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
self.assertAllClose(expected_fldj_, fldj_, atol=0., rtol=1e-3)
self.assertAllClose(-expected_fldj_, ildj_, atol=0., rtol=1e-3)
- @test_util.run_in_graph_and_eager_modes
def testErrorOnInputRankTooLow(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
x_ = np.array([0.1], dtype=np.float32)
rank_error_msg = "must have rank at least 2"
- with self.test_session():
- with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
- inv.forward(x_).eval()
- with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
- inv.inverse(x_).eval()
- with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
- inv.forward_log_det_jacobian(x_, event_ndims=2).eval()
- with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
- inv.inverse_log_det_jacobian(x_, event_ndims=2).eval()
+ with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
+ self.evaluate(inv.forward(x_))
+ with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
+ self.evaluate(inv.inverse(x_))
+ with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
+ self.evaluate(inv.forward_log_det_jacobian(x_, event_ndims=2))
+ with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
+ self.evaluate(inv.inverse_log_det_jacobian(x_, event_ndims=2))
# TODO(b/80481923): Figure out why these assertions fail, and fix them.
## def testErrorOnInputNonSquare(self):
@@ -146,55 +141,50 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
## x_ = np.array([[1., 2., 3.],
## [4., 5., 6.]], dtype=np.float32)
## square_error_msg = "must be a square matrix"
- ## with self.test_session():
- ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- ## square_error_msg):
- ## inv.forward(x_).eval()
- ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- ## square_error_msg):
- ## inv.inverse(x_).eval()
- ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- ## square_error_msg):
- ## inv.forward_log_det_jacobian(x_, event_ndims=2).eval()
- ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- ## square_error_msg):
- ## inv.inverse_log_det_jacobian(x_, event_ndims=2).eval()
-
- @test_util.run_in_graph_and_eager_modes
+ ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ ## square_error_msg):
+ ## self.evaluate(inv.forward(x_))
+ ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ ## square_error_msg):
+ ## self.evaluate(inv.inverse(x_))
+ ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ ## square_error_msg):
+ ## self.evaluate(inv.forward_log_det_jacobian(x_, event_ndims=2))
+ ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ ## square_error_msg):
+ ## self.evaluate(inv.inverse_log_det_jacobian(x_, event_ndims=2))
+
def testErrorOnInputNotLowerTriangular(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
x_ = np.array([[1., 2.],
[3., 4.]], dtype=np.float32)
triangular_error_msg = "must be lower triangular"
- with self.test_session():
- with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- triangular_error_msg):
- inv.forward(x_).eval()
- with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- triangular_error_msg):
- inv.inverse(x_).eval()
- with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- triangular_error_msg):
- inv.forward_log_det_jacobian(x_, event_ndims=2).eval()
- with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- triangular_error_msg):
- inv.inverse_log_det_jacobian(x_, event_ndims=2).eval()
-
- @test_util.run_in_graph_and_eager_modes
+ with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ triangular_error_msg):
+ self.evaluate(inv.forward(x_))
+ with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ triangular_error_msg):
+ self.evaluate(inv.inverse(x_))
+ with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ triangular_error_msg):
+ self.evaluate(inv.forward_log_det_jacobian(x_, event_ndims=2))
+ with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ triangular_error_msg):
+ self.evaluate(inv.inverse_log_det_jacobian(x_, event_ndims=2))
+
def testErrorOnInputSingular(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
x_ = np.array([[1., 0.],
[0., 0.]], dtype=np.float32)
nonsingular_error_msg = "must have all diagonal entries nonzero"
- with self.test_session():
- with self.assertRaisesOpError(nonsingular_error_msg):
- inv.forward(x_).eval()
- with self.assertRaisesOpError(nonsingular_error_msg):
- inv.inverse(x_).eval()
- with self.assertRaisesOpError(nonsingular_error_msg):
- inv.forward_log_det_jacobian(x_, event_ndims=2).eval()
- with self.assertRaisesOpError(nonsingular_error_msg):
- inv.inverse_log_det_jacobian(x_, event_ndims=2).eval()
+ with self.assertRaisesOpError(nonsingular_error_msg):
+ self.evaluate(inv.forward(x_))
+ with self.assertRaisesOpError(nonsingular_error_msg):
+ self.evaluate(inv.inverse(x_))
+ with self.assertRaisesOpError(nonsingular_error_msg):
+ self.evaluate(inv.forward_log_det_jacobian(x_, event_ndims=2))
+ with self.assertRaisesOpError(nonsingular_error_msg):
+ self.evaluate(inv.inverse_log_det_jacobian(x_, event_ndims=2))
if __name__ == "__main__":
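
A minimal sketch of the refactor applied in this file, assuming the TF 1.x test utilities: the per-method run_in_graph_and_eager_modes decorators move to a single class-level run_all_in_graph_and_eager_modes, and .eval() calls inside a session block become self.evaluate(...), which works in both execution modes. The class and test names are illustrative.

import tensorflow as tf
from tensorflow.python.framework import test_util

@test_util.run_all_in_graph_and_eager_modes
class EvaluatePatternTest(tf.test.TestCase):

  def testAbs(self):
    # self.evaluate() runs the tensor in graph mode and unwraps the
    # eager result in eager mode, so no explicit session is needed.
    self.assertAllClose(2., self.evaluate(tf.abs(-2.)))

if __name__ == "__main__":
  tf.test.main()
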
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py
index cb42331a21..9a88f8f1bc 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py
@@ -38,26 +38,25 @@ class OrderedBijectorTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testBijectorVector(self):
- with self.test_session():
- ordered = Ordered()
- self.assertEqual("ordered", ordered.name)
- x = np.asarray([[2., 3, 4], [4., 8, 13]])
- y = [[2., 0, 0], [4., np.log(4.), np.log(5.)]]
- self.assertAllClose(y, self.evaluate(ordered.forward(x)))
- self.assertAllClose(x, self.evaluate(ordered.inverse(y)))
- self.assertAllClose(
- np.sum(np.asarray(y)[..., 1:], axis=-1),
- self.evaluate(ordered.inverse_log_det_jacobian(y, event_ndims=1)),
- atol=0.,
- rtol=1e-7)
- self.assertAllClose(
- self.evaluate(-ordered.inverse_log_det_jacobian(y, event_ndims=1)),
- self.evaluate(ordered.forward_log_det_jacobian(x, event_ndims=1)),
- atol=0.,
- rtol=1e-7)
+ ordered = Ordered()
+ self.assertEqual("ordered", ordered.name)
+ x = np.asarray([[2., 3, 4], [4., 8, 13]])
+ y = [[2., 0, 0], [4., np.log(4.), np.log(5.)]]
+ self.assertAllClose(y, self.evaluate(ordered.forward(x)))
+ self.assertAllClose(x, self.evaluate(ordered.inverse(y)))
+ self.assertAllClose(
+ np.sum(np.asarray(y)[..., 1:], axis=-1),
+ self.evaluate(ordered.inverse_log_det_jacobian(y, event_ndims=1)),
+ atol=0.,
+ rtol=1e-7)
+ self.assertAllClose(
+ self.evaluate(-ordered.inverse_log_det_jacobian(y, event_ndims=1)),
+ self.evaluate(ordered.forward_log_det_jacobian(x, event_ndims=1)),
+ atol=0.,
+ rtol=1e-7)
def testBijectorUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
ordered = Ordered()
self.assertEqual("ordered", ordered.name)
x = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32)
@@ -84,21 +83,20 @@ class OrderedBijectorTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testShapeGetters(self):
- with self.test_session():
- x = tensor_shape.TensorShape([4])
- y = tensor_shape.TensorShape([4])
- bijector = Ordered(validate_args=True)
- self.assertAllEqual(y, bijector.forward_event_shape(x))
- self.assertAllEqual(y.as_list(),
- self.evaluate(bijector.forward_event_shape_tensor(
- x.as_list())))
- self.assertAllEqual(x, bijector.inverse_event_shape(y))
- self.assertAllEqual(x.as_list(),
- self.evaluate(bijector.inverse_event_shape_tensor(
- y.as_list())))
+ x = tensor_shape.TensorShape([4])
+ y = tensor_shape.TensorShape([4])
+ bijector = Ordered(validate_args=True)
+ self.assertAllEqual(y, bijector.forward_event_shape(x))
+ self.assertAllEqual(y.as_list(),
+ self.evaluate(bijector.forward_event_shape_tensor(
+ x.as_list())))
+ self.assertAllEqual(x, bijector.inverse_event_shape(y))
+ self.assertAllEqual(x.as_list(),
+ self.evaluate(bijector.inverse_event_shape_tensor(
+ y.as_list())))
def testBijectiveAndFinite(self):
- with self.test_session():
+ with self.cached_session():
ordered = Ordered()
x = np.sort(self._rng.randn(3, 10), axis=-1).astype(np.float32)
y = (self._rng.randn(3, 10)).astype(np.float32)
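
A numpy sketch of the Ordered forward map checked above: the first coordinate passes through and each later one becomes the log of the gap to its predecessor, taking ascending vectors to unconstrained space.

import numpy as np

x = np.array([4., 8., 13.])
y = np.concatenate([x[:1], np.log(np.diff(x))])  # [4, log 4, log 5]
assert np.allclose(y, [4., np.log(4.), np.log(5.)])
# Inverse: accumulate the exponentiated gaps back onto the first entry.
x_back = np.concatenate([y[:1], y[:1] + np.cumsum(np.exp(y[1:]))])
assert np.allclose(x, x_back)
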
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py
index 7eef4ab599..e2062ed55d 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py
@@ -38,7 +38,7 @@ class PermuteBijectorTest(test.TestCase):
expected_x = np.random.randn(4, 2, 3)
expected_y = expected_x[..., expected_permutation]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
permutation_ph = array_ops.placeholder(dtype=dtypes.int32)
bijector = Permute(
permutation=permutation_ph,
@@ -64,7 +64,7 @@ class PermuteBijectorTest(test.TestCase):
self.assertAllClose(0., ildj, rtol=1e-6, atol=0)
def testRaisesOpError(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesOpError("Permutation over `d` must contain"):
permutation_ph = array_ops.placeholder(dtype=dtypes.int32)
bijector = Permute(
@@ -77,7 +77,7 @@ class PermuteBijectorTest(test.TestCase):
permutation = np.int32([2, 0, 1])
x = np.random.randn(4, 2, 3)
y = x[..., permutation]
- with self.test_session():
+ with self.cached_session():
bijector = Permute(permutation=permutation, validate_args=True)
assert_bijective_and_finite(
bijector, x, y, event_ndims=1, rtol=1e-6, atol=0)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py
index 85d2283013..ef303ab664 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py
@@ -30,7 +30,7 @@ class PowerTransformBijectorTest(test.TestCase):
"""Tests correctness of the power transformation."""
def testBijector(self):
- with self.test_session():
+ with self.cached_session():
c = 0.2
bijector = PowerTransform(power=c, validate_args=True)
self.assertEqual("power_transform", bijector.name)
@@ -48,13 +48,13 @@ class PowerTransformBijectorTest(test.TestCase):
atol=0.)
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
bijector = PowerTransform(power=0.2, validate_args=True)
assert_scalar_congruency(
bijector, lower_x=-2., upper_x=1.5, rtol=0.05)
def testBijectiveAndFinite(self):
- with self.test_session():
+ with self.cached_session():
bijector = PowerTransform(power=0.2, validate_args=True)
x = np.linspace(-4.999, 10, num=10).astype(np.float32)
y = np.logspace(0.001, 10, num=10).astype(np.float32)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py
index 2d52895fbe..b3b7b8535e 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py
@@ -43,7 +43,7 @@ class RealNVPTest(test_util.VectorDistributionTestHelpers, test.TestCase):
def testBijector(self):
x_ = np.arange(3 * 4 * 2).astype(np.float32).reshape(3, 4 * 2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
nvp = RealNVP(
num_masked=4,
validate_args=True,
@@ -78,7 +78,7 @@ class RealNVPTest(test_util.VectorDistributionTestHelpers, test.TestCase):
def testMutuallyConsistent(self):
dims = 4
- with self.test_session() as sess:
+ with self.cached_session() as sess:
nvp = RealNVP(
num_masked=3,
validate_args=True,
@@ -98,7 +98,7 @@ class RealNVPTest(test_util.VectorDistributionTestHelpers, test.TestCase):
def testInvertMutuallyConsistent(self):
dims = 4
- with self.test_session() as sess:
+ with self.cached_session() as sess:
nvp = Invert(RealNVP(
num_masked=3,
validate_args=True,
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py
index d44e49b487..79eadf524b 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py
@@ -50,7 +50,7 @@ class _ReshapeBijectorTest(object):
expected_x = np.random.randn(4, 3, 2)
expected_y = np.reshape(expected_x, [4, 6])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape_in, shape_out, feed_dict = self.build_shapes([3, 2], [6,])
bijector = Reshape(
event_shape_out=shape_out,
@@ -84,7 +84,7 @@ class _ReshapeBijectorTest(object):
# using the _tensor methods, we should always get a fully-specified
# result since these are evaluated at graph runtime.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
(shape_out_,
shape_in_) = sess.run((
bijector.forward_event_shape_tensor(shape_in),
@@ -103,7 +103,7 @@ class _ReshapeBijectorTest(object):
expected_y_scalar = expected_x_scalar[0]
shape_in, shape_out, feed_dict = self.build_shapes([], [1,])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
bijector = Reshape(
event_shape_out=shape_in,
event_shape_in=shape_out, validate_args=True)
@@ -124,7 +124,7 @@ class _ReshapeBijectorTest(object):
def testMultipleUnspecifiedDimensionsOpError(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [4, -1, -1,])
bijector = Reshape(
event_shape_out=shape_out,
@@ -139,7 +139,7 @@ class _ReshapeBijectorTest(object):
# pylint: disable=invalid-name
def _testInvalidDimensionsOpError(self, expected_error_message):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 2, -2,])
bijector = Reshape(
@@ -155,7 +155,7 @@ class _ReshapeBijectorTest(object):
def testValidButNonMatchingInputOpError(self):
x = np.random.randn(4, 3, 2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 6, 1,])
bijector = Reshape(
event_shape_out=shape_out,
@@ -173,7 +173,7 @@ class _ReshapeBijectorTest(object):
def testValidButNonMatchingInputPartiallySpecifiedOpError(self):
x = np.random.randn(4, 3, 2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape_in, shape_out, feed_dict = self.build_shapes([2, -1], [1, 6, 1,])
bijector = Reshape(
event_shape_out=shape_out,
@@ -190,7 +190,7 @@ class _ReshapeBijectorTest(object):
x1 = np.random.randn(4, 2, 3)
x2 = np.random.randn(4, 1, 1, 5)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape_in, shape_out, fd_mismatched = self.build_shapes([2, 3],
[1, 1, 5])
bijector = Reshape(
@@ -208,7 +208,7 @@ class _ReshapeBijectorTest(object):
expected_x = np.random.randn(4, 6)
expected_y = np.reshape(expected_x, [4, 2, 3])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# one of input/output shapes is partially specified
shape_in, shape_out, feed_dict = self.build_shapes([-1,], [2, 3])
bijector = Reshape(
@@ -227,7 +227,7 @@ class _ReshapeBijectorTest(object):
def testBothShapesPartiallySpecified(self):
expected_x = np.random.randn(4, 2, 3)
expected_y = np.reshape(expected_x, [4, 3, 2])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape_in, shape_out, feed_dict = self.build_shapes([-1, 3], [-1, 2])
bijector = Reshape(
event_shape_out=shape_out,
@@ -245,7 +245,7 @@ class _ReshapeBijectorTest(object):
def testDefaultVectorShape(self):
expected_x = np.random.randn(4, 4)
expected_y = np.reshape(expected_x, [4, 2, 2])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_, shape_out, feed_dict = self.build_shapes([-1,], [-1, 2])
bijector = Reshape(shape_out,
validate_args=True)
@@ -292,7 +292,7 @@ class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest):
def testBijectiveAndFinite(self):
x = np.random.randn(4, 2, 3)
y = np.reshape(x, [4, 1, 2, 3])
- with self.test_session():
+ with self.cached_session():
bijector = Reshape(
event_shape_in=[2, 3],
event_shape_out=[1, 2, 3],
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py
index cea4a62c22..a6d432753d 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py
@@ -31,7 +31,7 @@ class SigmoidBijectorTest(test.TestCase):
"""Tests correctness of the Y = g(X) = (1 + exp(-X))^-1 transformation."""
def testBijector(self):
- with self.test_session():
+ with self.cached_session():
self.assertEqual("sigmoid", Sigmoid().name)
x = np.linspace(-10., 10., 100).reshape([2, 5, 10]).astype(np.float32)
y = special.expit(x)
@@ -45,11 +45,11 @@ class SigmoidBijectorTest(test.TestCase):
x, event_ndims=0).eval(), atol=0., rtol=1e-4)
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
assert_scalar_congruency(Sigmoid(), lower_x=-7., upper_x=7.)
def testBijectiveAndFinite(self):
- with self.test_session():
+ with self.cached_session():
x = np.linspace(-7., 7., 100).astype(np.float32)
eps = 1e-3
y = np.linspace(eps, 1. - eps, 100).astype(np.float32)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py
index 795f1993ba..282619a73b 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py
@@ -33,7 +33,7 @@ class SinhArcsinhBijectorTest(test.TestCase):
"""Tests correctness of the power transformation."""
def testBijectorVersusNumpyRewriteOfBasicFunctions(self):
- with self.test_session():
+ with self.cached_session():
skewness = 0.2
tailweight = 2.0
bijector = SinhArcsinh(
@@ -58,7 +58,7 @@ class SinhArcsinhBijectorTest(test.TestCase):
atol=0.)
def testLargerTailWeightPutsMoreWeightInTails(self):
- with self.test_session():
+ with self.cached_session():
# Will broadcast together to shape [3, 2].
x = [-1., 1.]
tailweight = [[0.5], [1.0], [2.0]]
@@ -75,7 +75,7 @@ class SinhArcsinhBijectorTest(test.TestCase):
self.assertLess(forward_1[1], forward_1[2])
def testSkew(self):
- with self.test_session():
+ with self.cached_session():
# Will broadcast together to shape [3, 2].
x = [-1., 1.]
skewness = [[-1.], [0.], [1.]]
@@ -92,24 +92,24 @@ class SinhArcsinhBijectorTest(test.TestCase):
self.assertLess(np.abs(y[2, 0]), np.abs(y[2, 1]))
def testScalarCongruencySkewness1Tailweight0p5(self):
- with self.test_session():
+ with self.cached_session():
bijector = SinhArcsinh(skewness=1.0, tailweight=0.5, validate_args=True)
assert_scalar_congruency(bijector, lower_x=-2., upper_x=2.0, rtol=0.05)
def testScalarCongruencySkewnessNeg1Tailweight1p5(self):
- with self.test_session():
+ with self.cached_session():
bijector = SinhArcsinh(skewness=-1.0, tailweight=1.5, validate_args=True)
assert_scalar_congruency(bijector, lower_x=-2., upper_x=2.0, rtol=0.05)
def testBijectiveAndFiniteSkewnessNeg1Tailweight0p5(self):
- with self.test_session():
+ with self.cached_session():
bijector = SinhArcsinh(skewness=-1., tailweight=0.5, validate_args=True)
x = np.concatenate((-np.logspace(-2, 10, 1000), [0], np.logspace(
-2, 10, 1000))).astype(np.float32)
assert_bijective_and_finite(bijector, x, x, event_ndims=0, rtol=1e-3)
def testBijectiveAndFiniteSkewness1Tailweight3(self):
- with self.test_session():
+ with self.cached_session():
bijector = SinhArcsinh(skewness=1., tailweight=3., validate_args=True)
x = np.concatenate((-np.logspace(-2, 5, 1000), [0], np.logspace(
-2, 5, 1000))).astype(np.float32)
@@ -117,7 +117,7 @@ class SinhArcsinhBijectorTest(test.TestCase):
bijector, x, x, event_ndims=0, rtol=1e-3)
def testBijectorEndpoints(self):
- with self.test_session():
+ with self.cached_session():
for dtype in (np.float32, np.float64):
bijector = SinhArcsinh(
skewness=dtype(0.), tailweight=dtype(1.), validate_args=True)
@@ -129,7 +129,7 @@ class SinhArcsinhBijectorTest(test.TestCase):
bijector, bounds, bounds, event_ndims=0, atol=2e-6)
def testBijectorOverRange(self):
- with self.test_session():
+ with self.cached_session():
for dtype in (np.float32, np.float64):
skewness = np.array([1.2, 5.], dtype=dtype)
tailweight = np.array([2., 10.], dtype=dtype)
@@ -176,12 +176,12 @@ class SinhArcsinhBijectorTest(test.TestCase):
atol=0.)
def testZeroTailweightRaises(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("not positive"):
SinhArcsinh(tailweight=0., validate_args=True).forward(1.0).eval()
def testDefaultDtypeIsFloat32(self):
- with self.test_session():
+ with self.cached_session():
bijector = SinhArcsinh()
self.assertEqual(bijector.tailweight.dtype, np.float32)
self.assertEqual(bijector.skewness.dtype, np.float32)
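
A hedged numpy sketch of the transform under test, assuming the documented form y = sinh((arcsinh(x) + skewness) * tailweight): skewness shifts mass left or right and tailweight > 1 fattens the tails, matching the assertions above.

import numpy as np

skewness, tailweight = 0.2, 2.0
x = np.linspace(-2., 2., 5)
y = np.sinh((np.arcsinh(x) + skewness) * tailweight)     # forward
x_back = np.sinh(np.arcsinh(y) / tailweight - skewness)  # exact inverse
assert np.allclose(x, x_back)
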
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py
index 0f0a2fa531..8d18400487 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py
@@ -35,7 +35,7 @@ class SoftmaxCenteredBijectorTest(test.TestCase):
"""Tests correctness of the Y = g(X) = exp(X) / sum(exp(X)) transformation."""
def testBijectorVector(self):
- with self.test_session():
+ with self.cached_session():
softmax = SoftmaxCentered()
self.assertEqual("softmax_centered", softmax.name)
x = np.log([[2., 3, 4], [4., 8, 12]])
@@ -54,7 +54,7 @@ class SoftmaxCenteredBijectorTest(test.TestCase):
rtol=1e-7)
def testBijectorUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
softmax = SoftmaxCentered()
self.assertEqual("softmax_centered", softmax.name)
x = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32)
@@ -80,7 +80,7 @@ class SoftmaxCenteredBijectorTest(test.TestCase):
rtol=1e-7)
def testShapeGetters(self):
- with self.test_session():
+ with self.cached_session():
x = tensor_shape.TensorShape([4])
y = tensor_shape.TensorShape([5])
bijector = SoftmaxCentered(validate_args=True)
@@ -94,7 +94,7 @@ class SoftmaxCenteredBijectorTest(test.TestCase):
y.as_list()).eval())
def testBijectiveAndFinite(self):
- with self.test_session():
+ with self.cached_session():
softmax = SoftmaxCentered()
x = np.linspace(-50, 50, num=10).reshape(5, 2).astype(np.float32)
# Make y values on the simplex with a wide range.
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py
index 3d8a0a32bb..e805619041 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py
@@ -42,13 +42,13 @@ class SoftplusBijectorTest(test.TestCase):
return -np.log(1 - np.exp(-y))
def testHingeSoftnessZeroRaises(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus(hinge_softness=0., validate_args=True)
with self.assertRaisesOpError("must be non-zero"):
bijector.forward([1., 1.]).eval()
def testBijectorForwardInverseEventDimsZero(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus()
self.assertEqual("softplus", bijector.name)
x = 2 * rng.randn(2, 10)
@@ -58,7 +58,7 @@ class SoftplusBijectorTest(test.TestCase):
self.assertAllClose(x, bijector.inverse(y).eval())
def testBijectorForwardInverseWithHingeSoftnessEventDimsZero(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus(hinge_softness=1.5)
x = 2 * rng.randn(2, 10)
y = 1.5 * self._softplus(x / 1.5)
@@ -67,7 +67,7 @@ class SoftplusBijectorTest(test.TestCase):
self.assertAllClose(x, bijector.inverse(y).eval())
def testBijectorLogDetJacobianEventDimsZero(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus()
y = 2 * rng.rand(2, 10)
# No reduction needed if event_dims = 0.
@@ -77,7 +77,7 @@ class SoftplusBijectorTest(test.TestCase):
y, event_ndims=0).eval())
def testBijectorForwardInverseEventDimsOne(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus()
self.assertEqual("softplus", bijector.name)
x = 2 * rng.randn(2, 10)
@@ -87,7 +87,7 @@ class SoftplusBijectorTest(test.TestCase):
self.assertAllClose(x, bijector.inverse(y).eval())
def testBijectorLogDetJacobianEventDimsOne(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus()
y = 2 * rng.rand(2, 10)
ildj_before = self._softplus_ildj_before_reduction(y)
@@ -97,25 +97,25 @@ class SoftplusBijectorTest(test.TestCase):
y, event_ndims=1).eval())
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus()
assert_scalar_congruency(
bijector, lower_x=-2., upper_x=2.)
def testScalarCongruencyWithPositiveHingeSoftness(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus(hinge_softness=1.3)
assert_scalar_congruency(
bijector, lower_x=-2., upper_x=2.)
def testScalarCongruencyWithNegativeHingeSoftness(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus(hinge_softness=-1.3)
assert_scalar_congruency(
bijector, lower_x=-2., upper_x=2.)
def testBijectiveAndFinite32bit(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus()
x = np.linspace(-20., 20., 100).astype(np.float32)
y = np.logspace(-10, 10, 100).astype(np.float32)
@@ -123,7 +123,7 @@ class SoftplusBijectorTest(test.TestCase):
bijector, x, y, event_ndims=0, rtol=1e-2, atol=1e-2)
def testBijectiveAndFiniteWithPositiveHingeSoftness32Bit(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus(hinge_softness=1.23)
x = np.linspace(-20., 20., 100).astype(np.float32)
y = np.logspace(-10, 10, 100).astype(np.float32)
@@ -131,7 +131,7 @@ class SoftplusBijectorTest(test.TestCase):
bijector, x, y, event_ndims=0, rtol=1e-2, atol=1e-2)
def testBijectiveAndFiniteWithNegativeHingeSoftness32Bit(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus(hinge_softness=-0.7)
x = np.linspace(-20., 20., 100).astype(np.float32)
y = -np.logspace(-10, 10, 100).astype(np.float32)
@@ -139,7 +139,7 @@ class SoftplusBijectorTest(test.TestCase):
bijector, x, y, event_ndims=0, rtol=1e-2, atol=1e-2)
def testBijectiveAndFinite16bit(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus()
# softplus(-20) is zero, so we can't use such a large range as in 32bit.
x = np.linspace(-10., 20., 100).astype(np.float16)
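
A numpy sketch of the softplus pair these tests check, including the hinge_softness scaling used above: y = s * softplus(x / s), with inverse x = s * softplus_inverse(y / s).

import numpy as np

def softplus(x):
  return np.log1p(np.exp(x))

def softplus_inverse(y):
  return np.log(np.expm1(y))

x = 2. * np.random.randn(2, 10)
s = 1.5  # hinge_softness
y = s * softplus(x / s)
assert np.allclose(x, s * softplus_inverse(y / s))
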
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
index d0098c3c10..8dad80aa64 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
@@ -43,16 +43,15 @@ class SoftsignBijectorTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testBijectorBounds(self):
bijector = Softsign(validate_args=True)
- with self.test_session():
- with self.assertRaisesOpError("greater than -1"):
- bijector.inverse(-3.).eval()
- with self.assertRaisesOpError("greater than -1"):
- bijector.inverse_log_det_jacobian(-3., event_ndims=0).eval()
-
- with self.assertRaisesOpError("less than 1"):
- bijector.inverse(3.).eval()
- with self.assertRaisesOpError("less than 1"):
- bijector.inverse_log_det_jacobian(3., event_ndims=0).eval()
+ with self.assertRaisesOpError("greater than -1"):
+ self.evaluate(bijector.inverse(-3.))
+ with self.assertRaisesOpError("greater than -1"):
+ self.evaluate(bijector.inverse_log_det_jacobian(-3., event_ndims=0))
+
+ with self.assertRaisesOpError("less than 1"):
+ self.evaluate(bijector.inverse(3.))
+ with self.assertRaisesOpError("less than 1"):
+ self.evaluate(bijector.inverse_log_det_jacobian(3., event_ndims=0))
@test_util.run_in_graph_and_eager_modes
def testBijectorForwardInverse(self):
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py
index 30c7a738c3..e5550cc830 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py
@@ -29,7 +29,7 @@ class SquareBijectorTest(test.TestCase):
"""Tests the correctness of the Y = X ** 2 transformation."""
def testBijectorScalar(self):
- with self.test_session():
+ with self.cached_session():
bijector = bijectors.Square(validate_args=True)
self.assertEqual("square", bijector.name)
x = [[[1., 5],
@@ -50,7 +50,7 @@ class SquareBijectorTest(test.TestCase):
rtol=1e-7)
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
bijector = bijectors.Square(validate_args=True)
assert_scalar_congruency(bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py
index f57adcda89..424eb58fa0 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py
@@ -31,7 +31,7 @@ class WeibullBijectorTest(test.TestCase):
"""Tests correctness of the weibull bijector."""
def testBijector(self):
- with self.test_session():
+ with self.cached_session():
scale = 5.
concentration = 0.3
bijector = Weibull(
@@ -54,13 +54,13 @@ class WeibullBijectorTest(test.TestCase):
atol=0.)
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
assert_scalar_congruency(
Weibull(scale=20., concentration=0.3),
lower_x=1., upper_x=100., rtol=0.02)
def testBijectiveAndFinite(self):
- with self.test_session():
+ with self.cached_session():
bijector = Weibull(
scale=20., concentration=2., validate_args=True)
x = np.linspace(1., 8., num=10).astype(np.float32)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py
index d30f6e418d..c317393fbc 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py
@@ -28,7 +28,7 @@ from tensorflow.python.platform import test
class BinomialTest(test.TestCase):
def testSimpleShapes(self):
- with self.test_session():
+ with self.cached_session():
p = np.float32(np.random.beta(1, 1))
binom = binomial.Binomial(total_count=1., probs=p)
self.assertAllEqual([], binom.event_shape_tensor().eval())
@@ -37,7 +37,7 @@ class BinomialTest(test.TestCase):
self.assertEqual(tensor_shape.TensorShape([]), binom.batch_shape)
def testComplexShapes(self):
- with self.test_session():
+ with self.cached_session():
p = np.random.beta(1, 1, size=(3, 2)).astype(np.float32)
n = [[3., 2], [4, 5], [6, 7]]
binom = binomial.Binomial(total_count=n, probs=p)
@@ -50,14 +50,14 @@ class BinomialTest(test.TestCase):
def testNProperty(self):
p = [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]]
n = [[3.], [4]]
- with self.test_session():
+ with self.cached_session():
binom = binomial.Binomial(total_count=n, probs=p)
self.assertEqual((2, 1), binom.total_count.get_shape())
self.assertAllClose(n, binom.total_count.eval())
def testPProperty(self):
p = [[0.1, 0.2, 0.7]]
- with self.test_session():
+ with self.cached_session():
binom = binomial.Binomial(total_count=3., probs=p)
self.assertEqual((1, 3), binom.probs.get_shape())
self.assertEqual((1, 3), binom.logits.get_shape())
@@ -65,7 +65,7 @@ class BinomialTest(test.TestCase):
def testLogitsProperty(self):
logits = [[0., 9., -0.5]]
- with self.test_session():
+ with self.cached_session():
binom = binomial.Binomial(total_count=3., logits=logits)
self.assertEqual((1, 3), binom.probs.get_shape())
self.assertEqual((1, 3), binom.logits.get_shape())
@@ -74,7 +74,7 @@ class BinomialTest(test.TestCase):
def testPmfAndCdfNandCountsAgree(self):
p = [[0.1, 0.2, 0.7]]
n = [[5.]]
- with self.test_session():
+ with self.cached_session():
binom = binomial.Binomial(total_count=n, probs=p, validate_args=True)
binom.prob([2., 3, 2]).eval()
binom.prob([3., 1, 2]).eval()
@@ -92,7 +92,7 @@ class BinomialTest(test.TestCase):
def testPmfAndCdfNonIntegerCounts(self):
p = [[0.1, 0.2, 0.7]]
n = [[5.]]
- with self.test_session():
+ with self.cached_session():
# No errors with integer n.
binom = binomial.Binomial(total_count=n, probs=p, validate_args=True)
binom.prob([2., 3, 2]).eval()
@@ -116,7 +116,7 @@ class BinomialTest(test.TestCase):
binom.cdf([1.0, 2.5, 1.5]).eval()
def testPmfAndCdfBothZeroBatches(self):
- with self.test_session():
+ with self.cached_session():
# Both zero-batches. No broadcast
p = 0.5
counts = 1.
@@ -129,7 +129,7 @@ class BinomialTest(test.TestCase):
self.assertEqual((), cdf.get_shape())
def testPmfAndCdfBothZeroBatchesNontrivialN(self):
- with self.test_session():
+ with self.cached_session():
# Both zero-batches. No broadcast
p = 0.1
counts = 3.
@@ -142,7 +142,7 @@ class BinomialTest(test.TestCase):
self.assertEqual((), cdf.get_shape())
def testPmfAndCdfPStretchedInBroadcastWhenSameRank(self):
- with self.test_session():
+ with self.cached_session():
p = [[0.1, 0.9]]
counts = [[1., 2.]]
binom = binomial.Binomial(total_count=3., probs=p)
@@ -154,7 +154,7 @@ class BinomialTest(test.TestCase):
self.assertEqual((1, 2), cdf.get_shape())
def testPmfAndCdfPStretchedInBroadcastWhenLowerRank(self):
- with self.test_session():
+ with self.cached_session():
p = [0.1, 0.4]
counts = [[1.], [0.]]
binom = binomial.Binomial(total_count=1., probs=p)
@@ -166,7 +166,7 @@ class BinomialTest(test.TestCase):
self.assertEqual((2, 2), cdf.get_shape())
def testBinomialMean(self):
- with self.test_session():
+ with self.cached_session():
n = 5.
p = [0.1, 0.2, 0.7]
binom = binomial.Binomial(total_count=n, probs=p)
@@ -175,7 +175,7 @@ class BinomialTest(test.TestCase):
self.assertAllClose(expected_means, binom.mean().eval())
def testBinomialVariance(self):
- with self.test_session():
+ with self.cached_session():
n = 5.
p = [0.1, 0.2, 0.7]
binom = binomial.Binomial(total_count=n, probs=p)
@@ -184,7 +184,7 @@ class BinomialTest(test.TestCase):
self.assertAllClose(expected_variances, binom.variance().eval())
def testBinomialMode(self):
- with self.test_session():
+ with self.cached_session():
n = 5.
p = [0.1, 0.2, 0.7]
binom = binomial.Binomial(total_count=n, probs=p)
@@ -193,7 +193,7 @@ class BinomialTest(test.TestCase):
self.assertAllClose(expected_modes, binom.mode().eval())
def testBinomialMultipleMode(self):
- with self.test_session():
+ with self.cached_session():
n = 9.
p = [0.1, 0.2, 0.7]
binom = binomial.Binomial(total_count=n, probs=p)
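
A numpy sketch of the closed-form binomial moments the tests above compare against (standard formulas, stated here as an assumption since the expected values are computed outside the shown hunks): mean = n p, variance = n p (1 - p), mode = floor((n + 1) p).

import numpy as np

n, p = 5., np.array([0.1, 0.2, 0.7])
assert np.allclose(n * p, [0.5, 1., 3.5])                # mean
assert np.allclose(n * p * (1. - p), [0.45, 0.8, 1.05])  # variance
assert np.allclose(np.floor((n + 1.) * p), [0., 1., 4.]) # mode
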
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py b/tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py
index 73747db31c..4411d6f461 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py
@@ -56,7 +56,7 @@ class CauchyTest(test.TestCase):
self.assertAllEqual(all_true, is_finite)
def _testParamShapes(self, sample_shape, expected):
- with self.test_session():
+ with self.cached_session():
param_shapes = cauchy_lib.Cauchy.param_shapes(sample_shape)
loc_shape, scale_shape = param_shapes["loc"], param_shapes["scale"]
self.assertAllEqual(expected, loc_shape.eval())
@@ -85,7 +85,7 @@ class CauchyTest(test.TestCase):
tensor_shape.TensorShape(sample_shape), sample_shape)
def testCauchyLogPDF(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
loc = constant_op.constant([3.0] * batch_size)
scale = constant_op.constant([np.sqrt(10.0)] * batch_size)
@@ -112,7 +112,7 @@ class CauchyTest(test.TestCase):
self.assertAllClose(np.exp(expected_log_pdf), pdf.eval())
def testCauchyLogPDFMultidimensional(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
loc = constant_op.constant([[3.0, -3.0]] * batch_size)
scale = constant_op.constant(
@@ -144,7 +144,7 @@ class CauchyTest(test.TestCase):
self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
def testCauchyCDF(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 50
loc = self._rng.randn(batch_size)
scale = self._rng.rand(batch_size) + 1.0
@@ -162,7 +162,7 @@ class CauchyTest(test.TestCase):
self.assertAllClose(expected_cdf, cdf.eval(), atol=0)
def testCauchySurvivalFunction(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 50
loc = self._rng.randn(batch_size)
scale = self._rng.rand(batch_size) + 1.0
@@ -181,7 +181,7 @@ class CauchyTest(test.TestCase):
self.assertAllClose(expected_sf, sf.eval(), atol=0)
def testCauchyLogCDF(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 50
loc = self._rng.randn(batch_size)
scale = self._rng.rand(batch_size) + 1.0
@@ -214,14 +214,14 @@ class CauchyTest(test.TestCase):
]:
value = func(x)
grads = gradients_impl.gradients(value, [loc, scale])
- with self.test_session(graph=g):
+ with self.session(graph=g):
variables.global_variables_initializer().run()
self.assertAllFinite(value)
self.assertAllFinite(grads[0])
self.assertAllFinite(grads[1])
def testCauchyLogSurvivalFunction(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 50
loc = self._rng.randn(batch_size)
scale = self._rng.rand(batch_size) + 1.0
@@ -241,7 +241,7 @@ class CauchyTest(test.TestCase):
self.assertAllClose(expected_sf, sf.eval(), atol=0, rtol=1e-5)
def testCauchyEntropy(self):
- with self.test_session():
+ with self.cached_session():
loc = np.array([1.0, 1.0, 1.0])
scale = np.array([[1.0, 2.0, 3.0]])
cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale)
@@ -259,7 +259,7 @@ class CauchyTest(test.TestCase):
self.assertAllClose(expected_entropy, entropy.eval())
def testCauchyMode(self):
- with self.test_session():
+ with self.cached_session():
# Mu will be broadcast to [7, 7, 7].
loc = [7.]
scale = [11., 12., 13.]
@@ -270,7 +270,7 @@ class CauchyTest(test.TestCase):
self.assertAllEqual([7., 7, 7], cauchy.mode().eval())
def testCauchyMean(self):
- with self.test_session():
+ with self.cached_session():
loc = [1., 2., 3.]
scale = [7.]
cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale)
@@ -279,7 +279,7 @@ class CauchyTest(test.TestCase):
self.assertAllEqual([np.nan] * 3, cauchy.mean().eval())
def testCauchyNanMean(self):
- with self.test_session():
+ with self.cached_session():
loc = [1., 2., 3.]
scale = [7.]
cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale, allow_nan_stats=False)
@@ -288,7 +288,7 @@ class CauchyTest(test.TestCase):
cauchy.mean().eval()
def testCauchyQuantile(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 50
loc = self._rng.randn(batch_size)
scale = self._rng.rand(batch_size) + 1.0
@@ -308,7 +308,7 @@ class CauchyTest(test.TestCase):
self.assertAllClose(expected_x, x.eval(), atol=0.)
def testCauchyVariance(self):
- with self.test_session():
+ with self.cached_session():
# scale will be broadcast to [7, 7, 7]
loc = [1., 2., 3.]
scale = [7.]
@@ -318,7 +318,7 @@ class CauchyTest(test.TestCase):
self.assertAllEqual([np.nan] * 3, cauchy.variance().eval())
def testCauchyNanVariance(self):
- with self.test_session():
+ with self.cached_session():
# scale will be broadcast to [7, 7, 7]
loc = [1., 2., 3.]
scale = [7.]
@@ -328,7 +328,7 @@ class CauchyTest(test.TestCase):
cauchy.variance().eval()
def testCauchyStandardDeviation(self):
- with self.test_session():
+ with self.cached_session():
# scale will be broadcast to [7, 7, 7]
loc = [1., 2., 3.]
scale = [7.]
@@ -338,7 +338,7 @@ class CauchyTest(test.TestCase):
self.assertAllEqual([np.nan] * 3, cauchy.stddev().eval())
def testCauchyNanStandardDeviation(self):
- with self.test_session():
+ with self.cached_session():
# scale will be broadcast to [7, 7, 7]
loc = [1., 2., 3.]
scale = [7.]
@@ -348,7 +348,7 @@ class CauchyTest(test.TestCase):
cauchy.stddev().eval()
def testCauchySample(self):
- with self.test_session():
+ with self.cached_session():
loc = constant_op.constant(3.0)
scale = constant_op.constant(1.0)
loc_v = 3.0
@@ -373,7 +373,7 @@ class CauchyTest(test.TestCase):
self.assertAllEqual(expected_shape, sample_values.shape)
def testCauchySampleMultiDimensional(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 2
loc = constant_op.constant([[3.0, -3.0]] * batch_size)
scale = constant_op.constant([[0.5, 1.0]] * batch_size)
@@ -399,13 +399,13 @@ class CauchyTest(test.TestCase):
self.assertAllEqual(expected_shape, sample_values.shape)
def testCauchyNegativeLocFails(self):
- with self.test_session():
+ with self.cached_session():
cauchy = cauchy_lib.Cauchy(loc=[1.], scale=[-5.], validate_args=True)
with self.assertRaisesOpError("Condition x > 0 did not hold"):
cauchy.mode().eval()
def testCauchyShape(self):
- with self.test_session():
+ with self.cached_session():
loc = constant_op.constant([-3.0] * 5)
scale = constant_op.constant(11.0)
cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale)
@@ -420,7 +420,7 @@ class CauchyTest(test.TestCase):
scale = array_ops.placeholder(dtype=dtypes.float32)
cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# get_batch_shape should return an "<unknown>" tensor.
self.assertEqual(cauchy.batch_shape, tensor_shape.TensorShape(None))
self.assertEqual(cauchy.event_shape, ())
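The hunk at @@ -214,14 +214,14 @@ above is the one departure from that pattern: its ops are constructed inside an explicit graph g, so the test switches to self.session(graph=g) instead, since a session can only run ops from the graph it was created against and a cached session is bound to the test's default graph. A short sketch of that constraint, assuming the TF 1.x graph/session API these tests use:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
  value = tf.constant(1.0)  # op created inside the explicit graph g

# A session only runs ops from the graph it was created against, so a
# session cached for the default graph could not evaluate `value`.
with tf.Session(graph=g) as sess:
  print(sess.run(value))  # 1.0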
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py b/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py
index 75d48791ec..3b5a6aa90c 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
class Chi2Test(test.TestCase):
def testChi2LogPDF(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
df = constant_op.constant([2.0] * batch_size, dtype=np.float64)
df_v = 2.0
@@ -46,7 +46,7 @@ class Chi2Test(test.TestCase):
self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf))
def testChi2CDF(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
df = constant_op.constant([2.0] * batch_size, dtype=np.float64)
df_v = 2.0
@@ -60,7 +60,7 @@ class Chi2Test(test.TestCase):
self.assertAllClose(cdf.eval(), expected_cdf)
def testChi2Mean(self):
- with self.test_session():
+ with self.cached_session():
df_v = np.array([1., 3, 5], dtype=np.float64)
expected_mean = stats.chi2.mean(df_v)
chi2 = chi2_lib.Chi2(df=df_v)
@@ -68,7 +68,7 @@ class Chi2Test(test.TestCase):
self.assertAllClose(chi2.mean().eval(), expected_mean)
def testChi2Variance(self):
- with self.test_session():
+ with self.cached_session():
df_v = np.array([1., 3, 5], np.float64)
expected_variances = stats.chi2.var(df_v)
chi2 = chi2_lib.Chi2(df=df_v)
@@ -76,7 +76,7 @@ class Chi2Test(test.TestCase):
self.assertAllClose(chi2.variance().eval(), expected_variances)
def testChi2Entropy(self):
- with self.test_session():
+ with self.cached_session():
df_v = np.array([1., 3, 5], dtype=np.float64)
expected_entropy = stats.chi2.entropy(df_v)
chi2 = chi2_lib.Chi2(df=df_v)
@@ -84,7 +84,7 @@ class Chi2Test(test.TestCase):
self.assertAllClose(chi2.entropy().eval(), expected_entropy)
def testChi2WithAbsDf(self):
- with self.test_session():
+ with self.cached_session():
df_v = np.array([-1.3, -3.2, 5], dtype=np.float64)
chi2 = chi2_lib.Chi2WithAbsDf(df=df_v)
self.assertAllClose(
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py
index 4e8989b6c2..7e63b5ca5f 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py
@@ -69,7 +69,7 @@ class ConditionalTransformedDistributionTest(
return ds.ConditionalTransformedDistribution
def testConditioning(self):
- with self.test_session():
+ with self.cached_session():
conditional_normal = ds.ConditionalTransformedDistribution(
distribution=ds.Normal(loc=0., scale=1.),
bijector=_ChooseLocation(loc=[-100., 100.]))
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py b/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py
index 200310bc41..36fc7a70c8 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py
@@ -29,7 +29,7 @@ rng = np.random.RandomState(0)
class DeterministicTest(test.TestCase):
def testShape(self):
- with self.test_session():
+ with self.cached_session():
loc = rng.rand(2, 3, 4)
deterministic = deterministic_lib.Deterministic(loc)
@@ -42,20 +42,20 @@ class DeterministicTest(test.TestCase):
loc = rng.rand(2, 3, 4).astype(np.float32)
deterministic = deterministic_lib.Deterministic(
loc, atol=-1, validate_args=True)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Condition x >= 0"):
deterministic.prob(0.).eval()
def testProbWithNoBatchDimsIntegerType(self):
deterministic = deterministic_lib.Deterministic(0)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(1, deterministic.prob(0).eval())
self.assertAllClose(0, deterministic.prob(2).eval())
self.assertAllClose([1, 0], deterministic.prob([0, 2]).eval())
def testProbWithNoBatchDims(self):
deterministic = deterministic_lib.Deterministic(0.)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(1., deterministic.prob(0.).eval())
self.assertAllClose(0., deterministic.prob(2.).eval())
self.assertAllClose([1., 0.], deterministic.prob([0., 2.]).eval())
@@ -65,7 +65,7 @@ class DeterministicTest(test.TestCase):
x = [[0., 1.1], [1.99, 3.]]
deterministic = deterministic_lib.Deterministic(loc)
expected_prob = [[1., 0.], [0., 1.]]
- with self.test_session():
+ with self.cached_session():
prob = deterministic.prob(x)
self.assertAllEqual((2, 2), prob.get_shape())
self.assertAllEqual(expected_prob, prob.eval())
@@ -75,7 +75,7 @@ class DeterministicTest(test.TestCase):
x = [[0., 1.1], [1.99, 3.]]
deterministic = deterministic_lib.Deterministic(loc, atol=0.05)
expected_prob = [[1., 0.], [1., 1.]]
- with self.test_session():
+ with self.cached_session():
prob = deterministic.prob(x)
self.assertAllEqual((2, 2), prob.get_shape())
self.assertAllEqual(expected_prob, prob.eval())
@@ -85,7 +85,7 @@ class DeterministicTest(test.TestCase):
x = [[0, 2], [4, 2]]
deterministic = deterministic_lib.Deterministic(loc, atol=1)
expected_prob = [[1, 1], [0, 1]]
- with self.test_session():
+ with self.cached_session():
prob = deterministic.prob(x)
self.assertAllEqual((2, 2), prob.get_shape())
self.assertAllEqual(expected_prob, prob.eval())
@@ -95,7 +95,7 @@ class DeterministicTest(test.TestCase):
x = [[0., 1.1], [100.1, 103.]]
deterministic = deterministic_lib.Deterministic(loc, rtol=0.01)
expected_prob = [[1., 0.], [1., 0.]]
- with self.test_session():
+ with self.cached_session():
prob = deterministic.prob(x)
self.assertAllEqual((2, 2), prob.get_shape())
self.assertAllEqual(expected_prob, prob.eval())
@@ -107,7 +107,7 @@ class DeterministicTest(test.TestCase):
# Batch 1 will have rtol = 1 (100% slack allowed)
deterministic = deterministic_lib.Deterministic(loc, rtol=[[0], [1]])
expected_prob = [[1, 0, 0], [1, 1, 0]]
- with self.test_session():
+ with self.cached_session():
prob = deterministic.prob(x)
self.assertAllEqual((2, 3), prob.get_shape())
self.assertAllEqual(expected_prob, prob.eval())
@@ -117,7 +117,7 @@ class DeterministicTest(test.TestCase):
x = [[-1., -0.1], [-0.01, 1.000001]]
deterministic = deterministic_lib.Deterministic(loc)
expected_cdf = [[0., 0.], [0., 1.]]
- with self.test_session():
+ with self.cached_session():
cdf = deterministic.cdf(x)
self.assertAllEqual((2, 2), cdf.get_shape())
self.assertAllEqual(expected_cdf, cdf.eval())
@@ -127,7 +127,7 @@ class DeterministicTest(test.TestCase):
x = [[-1., -0.1], [-0.01, 1.000001]]
deterministic = deterministic_lib.Deterministic(loc, atol=0.05)
expected_cdf = [[0., 0.], [1., 1.]]
- with self.test_session():
+ with self.cached_session():
cdf = deterministic.cdf(x)
self.assertAllEqual((2, 2), cdf.get_shape())
self.assertAllEqual(expected_cdf, cdf.eval())
@@ -137,7 +137,7 @@ class DeterministicTest(test.TestCase):
x = [[0.9, 1.], [99.9, 97]]
deterministic = deterministic_lib.Deterministic(loc, rtol=0.01)
expected_cdf = [[0., 1.], [1., 0.]]
- with self.test_session():
+ with self.cached_session():
cdf = deterministic.cdf(x)
self.assertAllEqual((2, 2), cdf.get_shape())
self.assertAllEqual(expected_cdf, cdf.eval())
@@ -145,7 +145,7 @@ class DeterministicTest(test.TestCase):
def testSampleNoBatchDims(self):
deterministic = deterministic_lib.Deterministic(0.)
for sample_shape in [(), (4,)]:
- with self.test_session():
+ with self.cached_session():
sample = deterministic.sample(sample_shape)
self.assertAllEqual(sample_shape, sample.get_shape())
self.assertAllClose(
@@ -154,7 +154,7 @@ class DeterministicTest(test.TestCase):
def testSampleWithBatchDims(self):
deterministic = deterministic_lib.Deterministic([0., 0.])
for sample_shape in [(), (4,)]:
- with self.test_session():
+ with self.cached_session():
sample = deterministic.sample(sample_shape)
self.assertAllEqual(sample_shape + (2,), sample.get_shape())
self.assertAllClose(
@@ -166,7 +166,7 @@ class DeterministicTest(test.TestCase):
deterministic = deterministic_lib.Deterministic(loc)
for sample_shape_ in [(), (4,)]:
- with self.test_session():
+ with self.cached_session():
sample_ = deterministic.sample(sample_shape).eval(
feed_dict={loc: [0., 0.],
sample_shape: sample_shape_})
@@ -176,7 +176,7 @@ class DeterministicTest(test.TestCase):
def testEntropy(self):
loc = np.array([-0.1, -3.2, 7.])
deterministic = deterministic_lib.Deterministic(loc=loc)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
entropy_ = sess.run(deterministic.entropy())
self.assertAllEqual(np.zeros(3), entropy_)
@@ -184,7 +184,7 @@ class DeterministicTest(test.TestCase):
class VectorDeterministicTest(test.TestCase):
def testShape(self):
- with self.test_session():
+ with self.cached_session():
loc = rng.rand(2, 3, 4)
deterministic = deterministic_lib.VectorDeterministic(loc)
@@ -197,7 +197,7 @@ class VectorDeterministicTest(test.TestCase):
loc = rng.rand(2, 3, 4).astype(np.float32)
deterministic = deterministic_lib.VectorDeterministic(
loc, atol=-1, validate_args=True)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Condition x >= 0"):
deterministic.prob(loc).eval()
@@ -205,14 +205,14 @@ class VectorDeterministicTest(test.TestCase):
loc = rng.rand(2, 3, 4).astype(np.float32)
deterministic = deterministic_lib.VectorDeterministic(
loc, atol=-1, validate_args=True)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, "must have rank at least 1"):
deterministic.prob(0.).eval()
def testProbVectorDeterministicWithNoBatchDims(self):
# 0 batch of deterministics on R^1.
deterministic = deterministic_lib.VectorDeterministic([0.])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(1., deterministic.prob([0.]).eval())
self.assertAllClose(0., deterministic.prob([2.]).eval())
self.assertAllClose([1., 0.], deterministic.prob([[0.], [2.]]).eval())
@@ -223,7 +223,7 @@ class VectorDeterministicTest(test.TestCase):
x = [[0., 1.], [1.9, 3.], [3.99, 5.]]
deterministic = deterministic_lib.VectorDeterministic(loc)
expected_prob = [1., 0., 0.]
- with self.test_session():
+ with self.cached_session():
prob = deterministic.prob(x)
self.assertAllEqual((3,), prob.get_shape())
self.assertAllEqual(expected_prob, prob.eval())
@@ -234,7 +234,7 @@ class VectorDeterministicTest(test.TestCase):
x = [[0., 1.], [1.9, 3.], [3.99, 5.]]
deterministic = deterministic_lib.VectorDeterministic(loc, atol=0.05)
expected_prob = [1., 0., 1.]
- with self.test_session():
+ with self.cached_session():
prob = deterministic.prob(x)
self.assertAllEqual((3,), prob.get_shape())
self.assertAllEqual(expected_prob, prob.eval())
@@ -245,7 +245,7 @@ class VectorDeterministicTest(test.TestCase):
x = [[0., 1.], [0.9, 1.], [99.9, 100.1]]
deterministic = deterministic_lib.VectorDeterministic(loc, rtol=0.01)
expected_prob = [1., 0., 1.]
- with self.test_session():
+ with self.cached_session():
prob = deterministic.prob(x)
self.assertAllEqual((3,), prob.get_shape())
self.assertAllEqual(expected_prob, prob.eval())
@@ -254,7 +254,7 @@ class VectorDeterministicTest(test.TestCase):
# 0 batch of deterministics on R^0.
deterministic = deterministic_lib.VectorDeterministic(
[], validate_args=True)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(1., deterministic.prob([]).eval())
def testProbVectorDeterministicWithNoBatchDimsOnRZeroRaisesIfXNotInSameRk(
@@ -262,14 +262,14 @@ class VectorDeterministicTest(test.TestCase):
# 0 batch of deterministics on R^0.
deterministic = deterministic_lib.VectorDeterministic(
[], validate_args=True)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("not defined in the same space"):
deterministic.prob([1.]).eval()
def testSampleNoBatchDims(self):
deterministic = deterministic_lib.VectorDeterministic([0.])
for sample_shape in [(), (4,)]:
- with self.test_session():
+ with self.cached_session():
sample = deterministic.sample(sample_shape)
self.assertAllEqual(sample_shape + (1,), sample.get_shape())
self.assertAllClose(
@@ -278,7 +278,7 @@ class VectorDeterministicTest(test.TestCase):
def testSampleWithBatchDims(self):
deterministic = deterministic_lib.VectorDeterministic([[0.], [0.]])
for sample_shape in [(), (4,)]:
- with self.test_session():
+ with self.cached_session():
sample = deterministic.sample(sample_shape)
self.assertAllEqual(sample_shape + (2, 1), sample.get_shape())
self.assertAllClose(
@@ -290,7 +290,7 @@ class VectorDeterministicTest(test.TestCase):
deterministic = deterministic_lib.VectorDeterministic(loc)
for sample_shape_ in [(), (4,)]:
- with self.test_session():
+ with self.cached_session():
sample_ = deterministic.sample(sample_shape).eval(
feed_dict={loc: [[0.], [0.]],
sample_shape: sample_shape_})
@@ -300,7 +300,7 @@ class VectorDeterministicTest(test.TestCase):
def testEntropy(self):
loc = np.array([[8.3, 1.2, 3.3], [-0.1, -3.2, 7.]])
deterministic = deterministic_lib.VectorDeterministic(loc=loc)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
entropy_ = sess.run(deterministic.entropy())
self.assertAllEqual(np.zeros(2), entropy_)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
index f42feae25d..f073f51a69 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
@@ -47,7 +47,7 @@ class DistributionTest(test.TestCase):
]
sample_shapes = [(), (10,), (10, 20, 30)]
- with self.test_session():
+ with self.cached_session():
for cls in classes:
for sample_shape in sample_shapes:
param_shapes = cls.param_shapes(sample_shape)
@@ -62,7 +62,7 @@ class DistributionTest(test.TestCase):
self.assertEqual(dist.parameters, dist_copy.parameters)
def testCopyExtraArgs(self):
- with self.test_session():
+ with self.cached_session():
# Note: we cannot easily test all distributions since each requires
# different initialization arguments. We therefore spot test a few.
normal = tfd.Normal(loc=1., scale=2., validate_args=True)
@@ -72,7 +72,7 @@ class DistributionTest(test.TestCase):
self.assertEqual(wishart.parameters, wishart.copy().parameters)
def testCopyOverride(self):
- with self.test_session():
+ with self.cached_session():
normal = tfd.Normal(loc=1., scale=2., validate_args=True)
unused_normal_copy = normal.copy(validate_args=False)
base_params = normal.parameters.copy()
@@ -82,7 +82,7 @@ class DistributionTest(test.TestCase):
self.assertEqual(base_params, copy_params)
def testIsScalar(self):
- with self.test_session():
+ with self.cached_session():
mu = 1.
sigma = 2.
@@ -152,7 +152,7 @@ class DistributionTest(test.TestCase):
def testSampleShapeHints(self):
fake_distribution = self._GetFakeDistribution()
- with self.test_session():
+ with self.cached_session():
# Make a new session since we're playing with static shapes. [And below.]
x = array_ops.placeholder(dtype=dtypes.float32)
dist = fake_distribution(batch_shape=[2, 3], event_shape=[5])
@@ -162,28 +162,28 @@ class DistributionTest(test.TestCase):
# unknown values, ie, Dimension(None).
self.assertAllEqual([6, 7, 2, 3, 5], y.get_shape().as_list())
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(dtype=dtypes.float32)
dist = fake_distribution(batch_shape=[None, 3], event_shape=[5])
sample_shape = ops.convert_to_tensor([6, 7], dtype=dtypes.int32)
y = dist._set_sample_static_shape(x, sample_shape)
self.assertAllEqual([6, 7, None, 3, 5], y.get_shape().as_list())
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(dtype=dtypes.float32)
dist = fake_distribution(batch_shape=[None, 3], event_shape=[None])
sample_shape = ops.convert_to_tensor([6, 7], dtype=dtypes.int32)
y = dist._set_sample_static_shape(x, sample_shape)
self.assertAllEqual([6, 7, None, 3, None], y.get_shape().as_list())
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(dtype=dtypes.float32)
dist = fake_distribution(batch_shape=None, event_shape=None)
sample_shape = ops.convert_to_tensor([6, 7], dtype=dtypes.int32)
y = dist._set_sample_static_shape(x, sample_shape)
self.assertTrue(y.get_shape().ndims is None)
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(dtype=dtypes.float32)
dist = fake_distribution(batch_shape=[None, 3], event_shape=None)
sample_shape = ops.convert_to_tensor([6, 7], dtype=dtypes.int32)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
index 181c46d2e5..05f5d30666 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
@@ -100,7 +100,7 @@ class MakeTrilScaleTest(test.TestCase):
def _testLegalInputs(
self, loc=None, shape_hint=None, scale_params=None):
for args in _powerset(scale_params.items()):
- with self.test_session():
+ with self.cached_session():
args = dict(args)
scale_args = dict({
@@ -143,19 +143,19 @@ class MakeTrilScaleTest(test.TestCase):
})
def testZeroTriU(self):
- with self.test_session():
+ with self.cached_session():
scale = distribution_util.make_tril_scale(scale_tril=[[1., 1], [1., 1.]])
self.assertAllClose([[1., 0], [1., 1.]], scale.to_dense().eval())
def testValidateArgs(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("diagonal part must be non-zero"):
scale = distribution_util.make_tril_scale(
scale_tril=[[0., 1], [1., 1.]], validate_args=True)
scale.to_dense().eval()
def testAssertPositive(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("diagonal part must be positive"):
scale = distribution_util.make_tril_scale(
scale_tril=[[-1., 1], [1., 1.]],
@@ -169,7 +169,7 @@ class MakeDiagScaleTest(test.TestCase):
def _testLegalInputs(
self, loc=None, shape_hint=None, scale_params=None):
for args in _powerset(scale_params.items()):
- with self.test_session():
+ with self.cached_session():
args = dict(args)
scale_args = dict({
@@ -204,14 +204,14 @@ class MakeDiagScaleTest(test.TestCase):
})
def testValidateArgs(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("diagonal part must be non-zero"):
scale = distribution_util.make_diag_scale(
scale_diag=[[0., 1], [1., 1.]], validate_args=True)
scale.to_dense().eval()
def testAssertPositive(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("diagonal part must be positive"):
scale = distribution_util.make_diag_scale(
scale_diag=[[-1., 1], [1., 1.]],
@@ -241,7 +241,7 @@ class ShapesFromLocAndScaleTest(test.TestCase):
loc = constant_op.constant(np.zeros((2, 3)))
diag = array_ops.placeholder(dtypes.float64)
scale = linear_operator_diag.LinearOperatorDiag(diag)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_shape, event_shape = sess.run(
distribution_util.shapes_from_loc_and_scale(loc, scale),
feed_dict={diag: np.ones((5, 1, 3))})
@@ -252,7 +252,7 @@ class ShapesFromLocAndScaleTest(test.TestCase):
loc = array_ops.placeholder(dtypes.float64)
diag = constant_op.constant(np.ones((5, 2, 3)))
scale = linear_operator_diag.LinearOperatorDiag(diag)
- with self.test_session():
+ with self.cached_session():
batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
loc, scale)
# batch_shape depends on both args, and so is dynamic. Since loc did not
@@ -266,7 +266,7 @@ class ShapesFromLocAndScaleTest(test.TestCase):
loc = array_ops.placeholder(dtypes.float64)
diag = array_ops.placeholder(dtypes.float64)
scale = linear_operator_diag.LinearOperatorDiag(diag)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_shape, event_shape = sess.run(
distribution_util.shapes_from_loc_and_scale(loc, scale),
feed_dict={diag: np.ones((5, 2, 3)), loc: np.zeros((2, 3))})
@@ -286,7 +286,7 @@ class ShapesFromLocAndScaleTest(test.TestCase):
loc = None
diag = array_ops.placeholder(dtypes.float64)
scale = linear_operator_diag.LinearOperatorDiag(diag)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_shape, event_shape = sess.run(
distribution_util.shapes_from_loc_and_scale(loc, scale),
feed_dict={diag: np.ones((5, 1, 3))})
@@ -307,7 +307,7 @@ class GetBroadcastShapeTest(test.TestCase):
x = array_ops.ones((2, 1, 3))
y = array_ops.placeholder(x.dtype)
z = array_ops.ones(())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
bcast_shape = sess.run(
distribution_util.get_broadcast_shape(x, y, z),
feed_dict={y: np.ones((1, 5, 3)).astype(np.float32)})
@@ -317,7 +317,7 @@ class GetBroadcastShapeTest(test.TestCase):
class TridiagTest(test.TestCase):
def testWorksCorrectlyNoBatches(self):
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
[[4., 8., 0., 0.],
[1., 5., 9., 0.],
@@ -329,7 +329,7 @@ class TridiagTest(test.TestCase):
[8., 9., 10.]).eval())
def testWorksCorrectlyBatches(self):
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(
[[[4., 8., 0., 0.],
[1., 5., 9., 0.],
@@ -349,7 +349,7 @@ class TridiagTest(test.TestCase):
rtol=1e-5, atol=0.)
def testHandlesNone(self):
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(
[[[4., 0., 0., 0.],
[0., 5., 0., 0.],
@@ -396,7 +396,7 @@ class MixtureStddevTest(test.TestCase):
means_tf,
sigmas_tf)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual_devs = sess.run(mix_dev)
self.assertAllClose(actual_devs, expected_devs)
@@ -405,7 +405,7 @@ class MixtureStddevTest(test.TestCase):
class PadMixtureDimensionsTest(test.TestCase):
def test_pad_mixture_dimensions_mixture(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gm = mixture.Mixture(
cat=categorical.Categorical(probs=[[0.3, 0.7]]),
components=[
@@ -422,7 +422,7 @@ class PadMixtureDimensionsTest(test.TestCase):
self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1]))
def test_pad_mixture_dimensions_mixture_same_family(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gm = mixture_same_family.MixtureSameFamily(
mixture_distribution=categorical.Categorical(probs=[0.3, 0.7]),
components_distribution=mvn_diag.MultivariateNormalDiag(
@@ -444,7 +444,7 @@ class _PadTest(object):
[4, 5, 6]])
value_ = np.float32(0.25)
count_ = np.int32(2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder_with_default(
x_, shape=x_.shape if self.is_static_shape else None)
value = (constant_op.constant(value_) if self.is_static_shape
@@ -491,7 +491,7 @@ class _PadTest(object):
[4, 5, 6]])
value_ = np.float32(0.25)
count_ = np.int32(2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder_with_default(
x_, shape=x_.shape if self.is_static_shape else None)
value = (constant_op.constant(value_) if self.is_static_shape
@@ -542,9 +542,9 @@ class PadDynamicTest(_PadTest, test.TestCase):
return False
+@test_util.run_all_in_graph_and_eager_modes
class TestMoveDimension(test.TestCase):
- @test_util.run_in_graph_and_eager_modes
def test_move_dimension_static_shape(self):
x = random_ops.random_normal(shape=[200, 30, 4, 1, 6])
@@ -561,7 +561,6 @@ class TestMoveDimension(test.TestCase):
x_perm = distribution_util.move_dimension(x, 4, 2)
self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 6, 4, 1])
- @test_util.run_in_graph_and_eager_modes
def test_move_dimension_dynamic_shape(self):
x_ = random_ops.random_normal(shape=[200, 30, 4, 1, 6])
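The final hunks in this file change decorator placement rather than session handling: the per-method @test_util.run_in_graph_and_eager_modes lines are dropped in favor of one class-level @test_util.run_all_in_graph_and_eager_modes, which applies the same wrapping to every test method. The sketch below shows the general mechanics of such a class decorator in plain unittest terms; run_in_both_modes and run_all_in_both_modes are hypothetical stand-ins, not the test_util implementations:

import unittest

def run_in_both_modes(fn):
  # Hypothetical per-method wrapper standing in for run_in_graph_and_eager_modes.
  def wrapper(self):
    for mode in ("graph", "eager"):
      # A real wrapper would switch the execution mode here before each call.
      fn(self)
  return wrapper

def run_all_in_both_modes(cls):
  # Class decorator: apply the per-method wrapper to every test method, so
  # individual methods no longer need their own decorator.
  for name, attr in list(vars(cls).items()):
    if name.startswith("test") and callable(attr):
      setattr(cls, name, run_in_both_modes(attr))
  return cls

@run_all_in_both_modes
class ExampleTest(unittest.TestCase):
  def test_something(self):
    self.assertEqual(2 + 2, 4)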
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py b/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py
index 87cdd0485a..a627d85229 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py
@@ -34,7 +34,7 @@ from tensorflow.python.platform import test
class GeometricTest(test.TestCase):
def testGeometricShape(self):
- with self.test_session():
+ with self.cached_session():
probs = constant_op.constant([.1] * 5)
geom = geometric.Geometric(probs=probs)
@@ -45,19 +45,19 @@ class GeometricTest(test.TestCase):
def testInvalidP(self):
invalid_ps = [-.01, -0.01, -2.]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Condition x >= 0"):
geom = geometric.Geometric(probs=invalid_ps, validate_args=True)
geom.probs.eval()
invalid_ps = [1.1, 3., 5.]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Condition x <= y"):
geom = geometric.Geometric(probs=invalid_ps, validate_args=True)
geom.probs.eval()
def testGeomLogPmf(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
probs = constant_op.constant([.2] * batch_size)
probs_v = .2
@@ -73,7 +73,7 @@ class GeometricTest(test.TestCase):
self.assertAllClose(np.exp(expected_log_prob), pmf.eval())
def testGeometricLogPmf_validate_args(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
probs = constant_op.constant([.9] * batch_size)
x = array_ops.placeholder(dtypes.float32, shape=[6])
@@ -95,7 +95,7 @@ class GeometricTest(test.TestCase):
self.assertEqual([6,], pmf.get_shape())
def testGeometricLogPmfMultidimensional(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
probs = constant_op.constant([[.2, .3, .5]] * batch_size)
probs_v = np.array([.2, .3, .5])
@@ -113,7 +113,7 @@ class GeometricTest(test.TestCase):
self.assertAllClose(np.exp(expected_log_prob), pmf_values)
def testGeometricCDF(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
probs = constant_op.constant([[.2, .4, .5]] * batch_size)
probs_v = np.array([.2, .4, .5])
@@ -127,7 +127,7 @@ class GeometricTest(test.TestCase):
self.assertAllClose(expected_cdf, cdf.eval())
def testGeometricEntropy(self):
- with self.test_session():
+ with self.cached_session():
probs_v = np.array([.1, .3, .25], dtype=np.float32)
geom = geometric.Geometric(probs=probs_v)
expected_entropy = stats.geom.entropy(probs_v, loc=-1)
@@ -135,7 +135,7 @@ class GeometricTest(test.TestCase):
self.assertAllClose(expected_entropy, geom.entropy().eval())
def testGeometricMean(self):
- with self.test_session():
+ with self.cached_session():
probs_v = np.array([.1, .3, .25])
geom = geometric.Geometric(probs=probs_v)
expected_means = stats.geom.mean(probs_v, loc=-1)
@@ -143,7 +143,7 @@ class GeometricTest(test.TestCase):
self.assertAllClose(expected_means, geom.mean().eval())
def testGeometricVariance(self):
- with self.test_session():
+ with self.cached_session():
probs_v = np.array([.1, .3, .25])
geom = geometric.Geometric(probs=probs_v)
expected_vars = stats.geom.var(probs_v, loc=-1)
@@ -151,7 +151,7 @@ class GeometricTest(test.TestCase):
self.assertAllClose(expected_vars, geom.variance().eval())
def testGeometricStddev(self):
- with self.test_session():
+ with self.cached_session():
probs_v = np.array([.1, .3, .25])
geom = geometric.Geometric(probs=probs_v)
expected_stddevs = stats.geom.std(probs_v, loc=-1)
@@ -159,14 +159,14 @@ class GeometricTest(test.TestCase):
self.assertAllClose(geom.stddev().eval(), expected_stddevs)
def testGeometricMode(self):
- with self.test_session():
+ with self.cached_session():
probs_v = np.array([.1, .3, .25])
geom = geometric.Geometric(probs=probs_v)
self.assertEqual([3,], geom.mode().get_shape())
self.assertAllClose([0.] * 3, geom.mode().eval())
def testGeometricSample(self):
- with self.test_session():
+ with self.cached_session():
probs_v = [.3, .9]
probs = constant_op.constant(probs_v)
n = constant_op.constant(100000)
@@ -186,7 +186,7 @@ class GeometricTest(test.TestCase):
rtol=.02)
def testGeometricSampleMultiDimensional(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 2
probs_v = [.3, .9]
probs = constant_op.constant([probs_v] * batch_size)
@@ -215,7 +215,7 @@ class GeometricTest(test.TestCase):
rtol=.02)
def testGeometricAtBoundary(self):
- with self.test_session():
+ with self.cached_session():
geom = geometric.Geometric(probs=1., validate_args=True)
x = np.array([0., 2., 3., 4., 5., 6., 7.], dtype=np.float32)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py
index a4e7566008..686de9d246 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py
@@ -55,7 +55,7 @@ class HalfNormalTest(test.TestCase):
self.assertAllEqual(all_true, is_finite)
def _testParamShapes(self, sample_shape, expected):
- with self.test_session():
+ with self.cached_session():
param_shapes = hn_lib.HalfNormal.param_shapes(sample_shape)
scale_shape = param_shapes["scale"]
self.assertAllEqual(expected, scale_shape.eval())
@@ -87,7 +87,7 @@ class HalfNormalTest(test.TestCase):
tensor_shape.TensorShape(sample_shape), sample_shape)
def testHalfNormalLogPDF(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
scale = constant_op.constant([3.0] * batch_size)
x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
@@ -106,7 +106,7 @@ class HalfNormalTest(test.TestCase):
self.assertAllClose(np.exp(expected_log_pdf), pdf.eval())
def testHalfNormalLogPDFMultidimensional(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
scale = constant_op.constant([[3.0, 1.0]] * batch_size)
x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
@@ -125,7 +125,7 @@ class HalfNormalTest(test.TestCase):
self.assertAllClose(np.exp(expected_log_pdf), pdf.eval())
def testHalfNormalCDF(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 50
scale = self._rng.rand(batch_size) + 1.0
x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
@@ -144,7 +144,7 @@ class HalfNormalTest(test.TestCase):
self.assertAllClose(np.exp(expected_logcdf), cdf.eval(), atol=0)
def testHalfNormalSurvivalFunction(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 50
scale = self._rng.rand(batch_size) + 1.0
x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
@@ -163,7 +163,7 @@ class HalfNormalTest(test.TestCase):
self.assertAllClose(np.exp(expected_logsf), sf.eval(), atol=0)
def testHalfNormalQuantile(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 50
scale = self._rng.rand(batch_size) + 1.0
p = np.linspace(0., 1.0, batch_size).astype(np.float64)
@@ -191,13 +191,13 @@ class HalfNormalTest(test.TestCase):
print(func.__name__)
value = func(x)
grads = gradients_impl.gradients(value, [scale])
- with self.test_session(graph=g):
+ with self.session(graph=g):
variables.global_variables_initializer().run()
self.assertAllFinite(value)
self.assertAllFinite(grads[0])
def testHalfNormalEntropy(self):
- with self.test_session():
+ with self.cached_session():
scale = np.array([[1.0, 2.0, 3.0]])
halfnorm = hn_lib.HalfNormal(scale=scale)
@@ -210,7 +210,7 @@ class HalfNormalTest(test.TestCase):
self.assertAllClose(expected_entropy, entropy.eval())
def testHalfNormalMeanAndMode(self):
- with self.test_session():
+ with self.cached_session():
scale = np.array([11., 12., 13.])
halfnorm = hn_lib.HalfNormal(scale=scale)
@@ -223,7 +223,7 @@ class HalfNormalTest(test.TestCase):
self.assertAllEqual([0., 0., 0.], halfnorm.mode().eval())
def testHalfNormalVariance(self):
- with self.test_session():
+ with self.cached_session():
scale = np.array([7., 7., 7.])
halfnorm = hn_lib.HalfNormal(scale=scale)
expected_variance = scale ** 2.0 * (1.0 - 2.0 / np.pi)
@@ -232,7 +232,7 @@ class HalfNormalTest(test.TestCase):
self.assertAllEqual(expected_variance, halfnorm.variance().eval())
def testHalfNormalStandardDeviation(self):
- with self.test_session():
+ with self.cached_session():
scale = np.array([7., 7., 7.])
halfnorm = hn_lib.HalfNormal(scale=scale)
expected_variance = scale ** 2.0 * (1.0 - 2.0 / np.pi)
@@ -241,7 +241,7 @@ class HalfNormalTest(test.TestCase):
self.assertAllEqual(np.sqrt(expected_variance), halfnorm.stddev().eval())
def testHalfNormalSample(self):
- with self.test_session():
+ with self.cached_session():
scale = constant_op.constant(3.0)
n = constant_op.constant(100000)
halfnorm = hn_lib.HalfNormal(scale=scale)
@@ -263,7 +263,7 @@ class HalfNormalTest(test.TestCase):
self.assertAllEqual(expected_shape_static, sample.eval().shape)
def testHalfNormalSampleMultiDimensional(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 2
scale = constant_op.constant([[2.0, 3.0]] * batch_size)
n = constant_op.constant(100000)
@@ -287,13 +287,13 @@ class HalfNormalTest(test.TestCase):
self.assertAllEqual(expected_shape_static, sample.eval().shape)
def testNegativeSigmaFails(self):
- with self.test_session():
+ with self.cached_session():
halfnorm = hn_lib.HalfNormal(scale=[-5.], validate_args=True, name="G")
with self.assertRaisesOpError("Condition x > 0 did not hold"):
halfnorm.mean().eval()
def testHalfNormalShape(self):
- with self.test_session():
+ with self.cached_session():
scale = constant_op.constant([6.0] * 5)
halfnorm = hn_lib.HalfNormal(scale=scale)
@@ -306,7 +306,7 @@ class HalfNormalTest(test.TestCase):
scale = array_ops.placeholder(dtype=dtypes.float32)
halfnorm = hn_lib.HalfNormal(scale=scale)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# get_batch_shape should return an "<unknown>" tensor.
self.assertEqual(halfnorm.batch_shape, tensor_shape.TensorShape(None))
self.assertEqual(halfnorm.event_shape, ())
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py b/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py
index 6a69f9e60b..ecf27289d7 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py
@@ -52,7 +52,7 @@ class ProductDistributionTest(test.TestCase):
def testSampleAndLogProbUnivariate(self):
loc = np.float32([-1., 1])
scale = np.float32([0.1, 0.5])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ind = independent_lib.Independent(
distribution=normal_lib.Normal(loc=loc, scale=scale),
reinterpreted_batch_ndims=1)
@@ -73,7 +73,7 @@ class ProductDistributionTest(test.TestCase):
def testSampleAndLogProbMultivariate(self):
loc = np.float32([[-1., 1], [1, -1]])
scale = np.float32([1., 0.5])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ind = independent_lib.Independent(
distribution=mvn_diag_lib.MultivariateNormalDiag(
loc=loc,
@@ -98,7 +98,7 @@ class ProductDistributionTest(test.TestCase):
loc = np.float32([[-1., 1], [1, -1]])
scale = np.float32([1., 0.5])
n_samp = 1e4
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ind = independent_lib.Independent(
distribution=mvn_diag_lib.MultivariateNormalDiag(
loc=loc,
@@ -231,7 +231,7 @@ class ProductDistributionTest(test.TestCase):
def expected_log_prob(x, logits):
return (x * logits - np.log1p(np.exp(logits))).sum(-1).sum(-1).sum(-1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
logits_ph = array_ops.placeholder(
dtypes.float32, shape=logits.shape if static_shape else None)
ind = independent_lib.Independent(
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py b/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py
index 6eb96ea9ff..70551d89d9 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py
@@ -30,7 +30,7 @@ from tensorflow.python.platform import test
class InverseGammaTest(test.TestCase):
def testInverseGammaShape(self):
- with self.test_session():
+ with self.cached_session():
alpha = constant_op.constant([3.0] * 5)
beta = constant_op.constant(11.0)
inv_gamma = inverse_gamma.InverseGamma(concentration=alpha, rate=beta)
@@ -43,7 +43,7 @@ class InverseGammaTest(test.TestCase):
[]))
def testInverseGammaLogPDF(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
alpha = constant_op.constant([2.0] * batch_size)
beta = constant_op.constant([3.0] * batch_size)
@@ -61,7 +61,7 @@ class InverseGammaTest(test.TestCase):
self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf))
def testInverseGammaLogPDFMultidimensional(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
beta = constant_op.constant([[3.0, 4.0]] * batch_size)
@@ -81,7 +81,7 @@ class InverseGammaTest(test.TestCase):
self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
def testInverseGammaLogPDFMultidimensionalBroadcasting(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
beta = constant_op.constant(3.0)
@@ -101,7 +101,7 @@ class InverseGammaTest(test.TestCase):
self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
def testInverseGammaCDF(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
alpha_v = 2.0
beta_v = 3.0
@@ -117,7 +117,7 @@ class InverseGammaTest(test.TestCase):
self.assertAllClose(cdf.eval(), expected_cdf)
def testInverseGammaMode(self):
- with self.test_session():
+ with self.cached_session():
alpha_v = np.array([5.5, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
inv_gamma = inverse_gamma.InverseGamma(concentration=alpha_v, rate=beta_v)
@@ -126,7 +126,7 @@ class InverseGammaTest(test.TestCase):
self.assertAllClose(inv_gamma.mode().eval(), expected_modes)
def testInverseGammaMeanAllDefined(self):
- with self.test_session():
+ with self.cached_session():
alpha_v = np.array([5.5, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
inv_gamma = inverse_gamma.InverseGamma(concentration=alpha_v, rate=beta_v)
@@ -135,7 +135,7 @@ class InverseGammaTest(test.TestCase):
self.assertAllClose(inv_gamma.mean().eval(), expected_means)
def testInverseGammaMeanAllowNanStats(self):
- with self.test_session():
+ with self.cached_session():
# Mean will not be defined for the first entry.
alpha_v = np.array([1.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
@@ -145,7 +145,7 @@ class InverseGammaTest(test.TestCase):
inv_gamma.mean().eval()
def testInverseGammaMeanNanStats(self):
- with self.test_session():
+ with self.cached_session():
# Mode will not be defined for the first two entries.
alpha_v = np.array([0.5, 1.0, 3.0, 2.5])
beta_v = np.array([1.0, 2.0, 4.0, 5.0])
@@ -158,7 +158,7 @@ class InverseGammaTest(test.TestCase):
self.assertAllClose(inv_gamma.mean().eval(), expected_means)
def testInverseGammaVarianceAllDefined(self):
- with self.test_session():
+ with self.cached_session():
alpha_v = np.array([7.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
inv_gamma = inverse_gamma.InverseGamma(concentration=alpha_v, rate=beta_v)
@@ -167,7 +167,7 @@ class InverseGammaTest(test.TestCase):
self.assertAllClose(inv_gamma.variance().eval(), expected_variances)
def testInverseGammaVarianceAllowNanStats(self):
- with self.test_session():
+ with self.cached_session():
alpha_v = np.array([1.5, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
inv_gamma = inverse_gamma.InverseGamma(
@@ -176,7 +176,7 @@ class InverseGammaTest(test.TestCase):
inv_gamma.variance().eval()
def testInverseGammaVarianceNanStats(self):
- with self.test_session():
+ with self.cached_session():
alpha_v = np.array([1.5, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
inv_gamma = inverse_gamma.InverseGamma(
@@ -187,7 +187,7 @@ class InverseGammaTest(test.TestCase):
self.assertAllClose(inv_gamma.variance().eval(), expected_variances)
def testInverseGammaEntropy(self):
- with self.test_session():
+ with self.cached_session():
alpha_v = np.array([1.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
expected_entropy = stats.invgamma.entropy(alpha_v, scale=beta_v)
@@ -292,7 +292,7 @@ class InverseGammaTest(test.TestCase):
self.assertNear(1., total, err=err)
def testInverseGammaNonPositiveInitializationParamsRaises(self):
- with self.test_session():
+ with self.cached_session():
alpha_v = constant_op.constant(0.0, name="alpha")
beta_v = constant_op.constant(1.0, name="beta")
inv_gamma = inverse_gamma.InverseGamma(
@@ -307,7 +307,7 @@ class InverseGammaTest(test.TestCase):
inv_gamma.mean().eval()
def testInverseGammaWithSoftplusConcentrationRate(self):
- with self.test_session():
+ with self.cached_session():
alpha = constant_op.constant([-0.1, -2.9], name="alpha")
beta = constant_op.constant([1.0, -4.8], name="beta")
inv_gamma = inverse_gamma.InverseGammaWithSoftplusConcentrationRate(
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py b/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py
index 2980e2bfe9..e39db51728 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py
@@ -77,7 +77,7 @@ def _kumaraswamy_pdf(a, b, x):
class KumaraswamyTest(test.TestCase):
def testSimpleShapes(self):
- with self.test_session():
+ with self.cached_session():
a = np.random.rand(3)
b = np.random.rand(3)
dist = kumaraswamy_lib.Kumaraswamy(a, b)
@@ -87,7 +87,7 @@ class KumaraswamyTest(test.TestCase):
self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape)
def testComplexShapes(self):
- with self.test_session():
+ with self.cached_session():
a = np.random.rand(3, 2, 2)
b = np.random.rand(3, 2, 2)
dist = kumaraswamy_lib.Kumaraswamy(a, b)
@@ -97,7 +97,7 @@ class KumaraswamyTest(test.TestCase):
self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
def testComplexShapesBroadcast(self):
- with self.test_session():
+ with self.cached_session():
a = np.random.rand(3, 2, 2)
b = np.random.rand(2, 2)
dist = kumaraswamy_lib.Kumaraswamy(a, b)
@@ -109,7 +109,7 @@ class KumaraswamyTest(test.TestCase):
def testAProperty(self):
a = [[1., 2, 3]]
b = [[2., 4, 3]]
- with self.test_session():
+ with self.cached_session():
dist = kumaraswamy_lib.Kumaraswamy(a, b)
self.assertEqual([1, 3], dist.concentration1.get_shape())
self.assertAllClose(a, dist.concentration1.eval())
@@ -117,7 +117,7 @@ class KumaraswamyTest(test.TestCase):
def testBProperty(self):
a = [[1., 2, 3]]
b = [[2., 4, 3]]
- with self.test_session():
+ with self.cached_session():
dist = kumaraswamy_lib.Kumaraswamy(a, b)
self.assertEqual([1, 3], dist.concentration0.get_shape())
self.assertAllClose(b, dist.concentration0.eval())
@@ -125,7 +125,7 @@ class KumaraswamyTest(test.TestCase):
def testPdfXProper(self):
a = [[1., 2, 3]]
b = [[2., 4, 3]]
- with self.test_session():
+ with self.cached_session():
dist = kumaraswamy_lib.Kumaraswamy(a, b, validate_args=True)
dist.prob([.1, .3, .6]).eval()
dist.prob([.2, .3, .5]).eval()
@@ -136,7 +136,7 @@ class KumaraswamyTest(test.TestCase):
dist.prob([.1, .2, 1.2]).eval()
def testPdfTwoBatches(self):
- with self.test_session():
+ with self.cached_session():
a = [1., 2]
b = [1., 2]
x = [.5, .5]
@@ -147,7 +147,7 @@ class KumaraswamyTest(test.TestCase):
self.assertEqual((2,), pdf.get_shape())
def testPdfTwoBatchesNontrivialX(self):
- with self.test_session():
+ with self.cached_session():
a = [1., 2]
b = [1., 2]
x = [.3, .7]
@@ -158,7 +158,7 @@ class KumaraswamyTest(test.TestCase):
self.assertEqual((2,), pdf.get_shape())
def testPdfUniformZeroBatch(self):
- with self.test_session():
+ with self.cached_session():
# This is equivalent to a uniform distribution
a = 1.
b = 1.
@@ -170,7 +170,7 @@ class KumaraswamyTest(test.TestCase):
self.assertEqual((5,), pdf.get_shape())
def testPdfAStretchedInBroadcastWhenSameRank(self):
- with self.test_session():
+ with self.cached_session():
a = [[1., 2]]
b = [[1., 2]]
x = [[.5, .5], [.3, .7]]
@@ -181,7 +181,7 @@ class KumaraswamyTest(test.TestCase):
self.assertEqual((2, 2), pdf.get_shape())
def testPdfAStretchedInBroadcastWhenLowerRank(self):
- with self.test_session():
+ with self.cached_session():
a = [1., 2]
b = [1., 2]
x = [[.5, .5], [.2, .8]]
@@ -191,7 +191,7 @@ class KumaraswamyTest(test.TestCase):
self.assertEqual((2, 2), pdf.get_shape())
def testPdfXStretchedInBroadcastWhenSameRank(self):
- with self.test_session():
+ with self.cached_session():
a = [[1., 2], [2., 3]]
b = [[1., 2], [2., 3]]
x = [[.5, .5]]
@@ -201,7 +201,7 @@ class KumaraswamyTest(test.TestCase):
self.assertEqual((2, 2), pdf.get_shape())
def testPdfXStretchedInBroadcastWhenLowerRank(self):
- with self.test_session():
+ with self.cached_session():
a = [[1., 2], [2., 3]]
b = [[1., 2], [2., 3]]
x = [.5, .5]
@@ -289,7 +289,7 @@ class KumaraswamyTest(test.TestCase):
self.assertAllClose(expected_entropy, dist.entropy().eval())
def testKumaraswamySample(self):
- with self.test_session():
+ with self.cached_session():
a = 1.
b = 2.
kumaraswamy = kumaraswamy_lib.Kumaraswamy(a, b)
@@ -316,7 +316,7 @@ class KumaraswamyTest(test.TestCase):
# Test that sampling with the same seed twice gives the same results.
def testKumaraswamySampleMultipleTimes(self):
- with self.test_session():
+ with self.cached_session():
a_val = 1.
b_val = 2.
n_val = 100
@@ -334,7 +334,7 @@ class KumaraswamyTest(test.TestCase):
self.assertAllClose(samples1, samples2)
def testKumaraswamySampleMultidimensional(self):
- with self.test_session():
+ with self.cached_session():
a = np.random.rand(3, 2, 2).astype(np.float32)
b = np.random.rand(3, 2, 2).astype(np.float32)
kumaraswamy = kumaraswamy_lib.Kumaraswamy(a, b)
@@ -351,7 +351,7 @@ class KumaraswamyTest(test.TestCase):
atol=1e-1)
def testKumaraswamyCdf(self):
- with self.test_session():
+ with self.cached_session():
shape = (30, 40, 50)
for dt in (np.float32, np.float64):
a = 10. * np.random.random(shape).astype(dt)
@@ -366,7 +366,7 @@ class KumaraswamyTest(test.TestCase):
_kumaraswamy_cdf(a, b, x), actual, rtol=1e-4, atol=0)
def testKumaraswamyLogCdf(self):
- with self.test_session():
+ with self.cached_session():
shape = (30, 40, 50)
for dt in (np.float32, np.float64):
a = 10. * np.random.random(shape).astype(dt)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/logistic_test.py b/tensorflow/contrib/distributions/python/kernel_tests/logistic_test.py
index 251be9ed4f..12a2d4f8ec 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/logistic_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/logistic_test.py
@@ -39,7 +39,7 @@ class LogisticTest(test.TestCase):
dist.reparameterization_type == distribution.FULLY_REPARAMETERIZED)
def testLogisticLogProb(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
np_loc = np.array([2.0] * batch_size, dtype=np.float32)
loc = constant_op.constant(np_loc)
@@ -57,7 +57,7 @@ class LogisticTest(test.TestCase):
self.assertAllClose(prob.eval(), np.exp(expected_log_prob))
def testLogisticCDF(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
np_loc = np.array([2.0] * batch_size, dtype=np.float32)
loc = constant_op.constant(np_loc)
@@ -72,7 +72,7 @@ class LogisticTest(test.TestCase):
self.assertAllClose(cdf.eval(), expected_cdf)
def testLogisticLogCDF(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
np_loc = np.array([2.0] * batch_size, dtype=np.float32)
loc = constant_op.constant(np_loc)
@@ -87,7 +87,7 @@ class LogisticTest(test.TestCase):
self.assertAllClose(logcdf.eval(), expected_logcdf)
def testLogisticSurvivalFunction(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
np_loc = np.array([2.0] * batch_size, dtype=np.float32)
loc = constant_op.constant(np_loc)
@@ -102,7 +102,7 @@ class LogisticTest(test.TestCase):
self.assertAllClose(survival_function.eval(), expected_survival_function)
def testLogisticLogSurvivalFunction(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
np_loc = np.array([2.0] * batch_size, dtype=np.float32)
loc = constant_op.constant(np_loc)
@@ -118,7 +118,7 @@ class LogisticTest(test.TestCase):
expected_logsurvival_function)
def testLogisticMean(self):
- with self.test_session():
+ with self.cached_session():
loc = [2.0, 1.5, 1.0]
scale = 1.5
expected_mean = stats.logistic.mean(loc, scale)
@@ -126,7 +126,7 @@ class LogisticTest(test.TestCase):
self.assertAllClose(dist.mean().eval(), expected_mean)
def testLogisticVariance(self):
- with self.test_session():
+ with self.cached_session():
loc = [2.0, 1.5, 1.0]
scale = 1.5
expected_variance = stats.logistic.var(loc, scale)
@@ -134,7 +134,7 @@ class LogisticTest(test.TestCase):
self.assertAllClose(dist.variance().eval(), expected_variance)
def testLogisticEntropy(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 3
np_loc = np.array([2.0] * batch_size, dtype=np.float32)
loc = constant_op.constant(np_loc)
@@ -144,7 +144,7 @@ class LogisticTest(test.TestCase):
self.assertAllClose(dist.entropy().eval(), expected_entropy)
def testLogisticSample(self):
- with self.test_session():
+ with self.cached_session():
loc = [3.0, 4.0, 2.0]
scale = 1.0
dist = logistic.Logistic(loc, scale)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py
index ff6092fc26..faff42d243 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py
@@ -35,7 +35,7 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
test.TestCase):
def testSampleAndLogProbUnivariateShapes(self):
- with self.test_session():
+ with self.cached_session():
gm = mixture_same_family_lib.MixtureSameFamily(
mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
components_distribution=normal_lib.Normal(
@@ -46,7 +46,7 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
self.assertEqual([4, 5], log_prob_x.shape)
def testSampleAndLogProbBatch(self):
- with self.test_session():
+ with self.cached_session():
gm = mixture_same_family_lib.MixtureSameFamily(
mixture_distribution=categorical_lib.Categorical(probs=[[0.3, 0.7]]),
components_distribution=normal_lib.Normal(
@@ -59,7 +59,7 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
def testSampleAndLogProbShapesBroadcastMix(self):
mix_probs = np.float32([.3, .7])
bern_probs = np.float32([[.4, .6], [.25, .75]])
- with self.test_session():
+ with self.cached_session():
bm = mixture_same_family_lib.MixtureSameFamily(
mixture_distribution=categorical_lib.Categorical(probs=mix_probs),
components_distribution=bernoulli_lib.Bernoulli(probs=bern_probs))
@@ -72,7 +72,7 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
np.ones_like(x_, dtype=np.bool), np.logical_or(x_ == 0., x_ == 1.))
def testSampleAndLogProbMultivariateShapes(self):
- with self.test_session():
+ with self.cached_session():
gm = mixture_same_family_lib.MixtureSameFamily(
mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
components_distribution=mvn_diag_lib.MultivariateNormalDiag(
@@ -83,7 +83,7 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
self.assertEqual([4, 5], log_prob_x.shape)
def testSampleAndLogProbBatchMultivariateShapes(self):
- with self.test_session():
+ with self.cached_session():
gm = mixture_same_family_lib.MixtureSameFamily(
mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
components_distribution=mvn_diag_lib.MultivariateNormalDiag(
@@ -98,7 +98,7 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
self.assertEqual([4, 5, 2], log_prob_x.shape)
def testSampleConsistentLogProb(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gm = mixture_same_family_lib.MixtureSameFamily(
mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
components_distribution=mvn_diag_lib.MultivariateNormalDiag(
@@ -111,7 +111,7 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
sess.run, gm, radius=1., center=[1., -1], rtol=0.02)
def testLogCdf(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gm = mixture_same_family_lib.MixtureSameFamily(
mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
components_distribution=normal_lib.Normal(
@@ -128,7 +128,7 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
rtol=1e-6, atol=0.0)
def testSampleConsistentMeanCovariance(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gm = mixture_same_family_lib.MixtureSameFamily(
mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
components_distribution=mvn_diag_lib.MultivariateNormalDiag(
@@ -136,7 +136,7 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
self.run_test_sample_consistent_mean_covariance(sess.run, gm)
def testVarianceConsistentCovariance(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gm = mixture_same_family_lib.MixtureSameFamily(
mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
components_distribution=mvn_diag_lib.MultivariateNormalDiag(
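
The sample/log-prob consistency checks above rest on ordinary mixture algebra: for MixtureSameFamily, log p(x) = logsumexp_k(log pi_k + log p_k(x)). A minimal NumPy/SciPy sketch of that identity (the [0.3, 0.7] weights mirror the tests; the component means are illustrative):

import numpy as np
from scipy import stats
from scipy.special import logsumexp

pi = np.array([0.3, 0.7])       # mixture weights, as in the tests above
locs = np.array([-1.0, 1.0])    # illustrative component means, unit scale
x = 0.5
# Marginalize the component index: log p(x) = logsumexp(log pi + log p_k(x)).
log_prob = logsumexp(np.log(pi) + stats.norm.logpdf(x, loc=locs, scale=1.0))
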
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
index 0206489175..f8dbd34d02 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py
@@ -152,7 +152,7 @@ class MixtureTest(test.TestCase):
use_static_graph = False
def testShapes(self):
- with self.test_session():
+ with self.cached_session():
for batch_shape in ([], [1], [2, 3, 4]):
dist = make_univariate_mixture(batch_shape, num_components=10,
use_static_graph=self.use_static_graph)
@@ -200,7 +200,7 @@ class MixtureTest(test.TestCase):
use_static_graph=self.use_static_graph)
def testBrokenShapesDynamic(self):
- with self.test_session():
+ with self.cached_session():
d0_param = array_ops.placeholder(dtype=dtypes.float32)
d1_param = array_ops.placeholder(dtype=dtypes.float32)
d = ds.Mixture(
@@ -246,7 +246,7 @@ class MixtureTest(test.TestCase):
# mixture are checked for equivalence.
def testMeanUnivariate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for batch_shape in ((), (2,), (2, 3)):
dist = make_univariate_mixture(
batch_shape=batch_shape, num_components=2,
@@ -268,7 +268,7 @@ class MixtureTest(test.TestCase):
self.assertAllClose(true_mean, mean_value)
def testMeanMultivariate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for batch_shape in ((), (2,), (2, 3)):
dist = make_multivariate_mixture(
batch_shape=batch_shape, num_components=2, event_shape=(4,),
@@ -296,7 +296,7 @@ class MixtureTest(test.TestCase):
def testStddevShapeUnivariate(self):
num_components = 2
# This is the same shape test which is done in 'testMeanUnivariate'.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for batch_shape in ((), (2,), (2, 3)):
dist = make_univariate_mixture(
batch_shape=batch_shape, num_components=num_components,
@@ -337,7 +337,7 @@ class MixtureTest(test.TestCase):
num_components = 2
# This is the same shape test which is done in 'testMeanMultivariate'.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for batch_shape in ((), (2,), (2, 3)):
dist = make_multivariate_mixture(
batch_shape=batch_shape,
@@ -392,12 +392,12 @@ class MixtureTest(test.TestCase):
],
use_static_graph=self.use_static_graph)
mix_dev = mixture_dist.stddev()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual_stddev = sess.run(mix_dev)
self.assertAllClose(actual_stddev, ground_truth_stddev)
def testProbScalarUnivariate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dist = make_univariate_mixture(batch_shape=[], num_components=2,
use_static_graph=self.use_static_graph)
for x in [
@@ -423,7 +423,7 @@ class MixtureTest(test.TestCase):
self.assertAllClose(total_prob, p_x_value)
def testProbScalarMultivariate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dist = make_multivariate_mixture(
batch_shape=[], num_components=2, event_shape=[3],
use_static_graph=self.use_static_graph)
@@ -452,7 +452,7 @@ class MixtureTest(test.TestCase):
self.assertAllClose(total_prob, p_x_value)
def testProbBatchUnivariate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dist = make_univariate_mixture(batch_shape=[2, 3], num_components=2,
use_static_graph=self.use_static_graph)
@@ -479,7 +479,7 @@ class MixtureTest(test.TestCase):
self.assertAllClose(total_prob, p_x_value)
def testProbBatchMultivariate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dist = make_multivariate_mixture(
batch_shape=[2, 3], num_components=2, event_shape=[4],
use_static_graph=self.use_static_graph)
@@ -506,7 +506,7 @@ class MixtureTest(test.TestCase):
self.assertAllClose(total_prob, p_x_value)
def testSampleScalarBatchUnivariate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_components = 3
batch_shape = []
dist = make_univariate_mixture(
@@ -539,7 +539,7 @@ class MixtureTest(test.TestCase):
mus = [-5.0, 0.0, 5.0, 4.0, 20.0]
sigmas = [0.1, 5.0, 3.0, 0.2, 4.0]
- with self.test_session():
+ with self.cached_session():
n = 100
random_seed.set_random_seed(654321)
@@ -567,7 +567,7 @@ class MixtureTest(test.TestCase):
self.assertAllClose(samples1, samples2)
def testSampleScalarBatchMultivariate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_components = 3
dist = make_multivariate_mixture(
batch_shape=[], num_components=num_components, event_shape=[2],
@@ -592,7 +592,7 @@ class MixtureTest(test.TestCase):
self.assertAllClose(which_dist_samples, sample_values[which_c, :])
def testSampleBatchUnivariate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_components = 3
dist = make_univariate_mixture(
batch_shape=[2, 3], num_components=num_components,
@@ -620,7 +620,7 @@ class MixtureTest(test.TestCase):
sample_values[which_c_s, which_c_b0, which_c_b1])
def _testSampleBatchMultivariate(self, fully_known_batch_shape):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_components = 3
if fully_known_batch_shape:
batch_shape = [2, 3]
@@ -672,7 +672,7 @@ class MixtureTest(test.TestCase):
self._testSampleBatchMultivariate(fully_known_batch_shape=False)
def testEntropyLowerBoundMultivariate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for batch_shape in ((), (2,), (2, 3)):
dist = make_multivariate_mixture(
batch_shape=batch_shape, num_components=2, event_shape=(4,),
@@ -732,7 +732,7 @@ class MixtureTest(test.TestCase):
x_cdf_tf = mixture_tf.cdf(x_tensor)
x_log_cdf_tf = mixture_tf.log_cdf(x_tensor)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for x_feed in xs_to_check:
x_cdf_tf_result, x_log_cdf_tf_result = sess.run(
[x_cdf_tf, x_log_cdf_tf], feed_dict={x_tensor: x_feed})
@@ -778,7 +778,7 @@ class MixtureTest(test.TestCase):
x_cdf_tf = mixture_tf.cdf(x_tensor)
x_log_cdf_tf = mixture_tf.log_cdf(x_tensor)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for x_feed in xs_to_check:
x_cdf_tf_result, x_log_cdf_tf_result = sess.run(
[x_cdf_tf, x_log_cdf_tf],
@@ -802,7 +802,7 @@ class MixtureTest(test.TestCase):
Mixture's use of dynamic partition requires that `random_gamma` correctly
return an empty `Tensor`.
"""
- with self.test_session():
+ with self.cached_session():
gm = ds.Mixture(
cat=ds.Categorical(probs=[.3, .7]),
components=[ds.Gamma(1., 2.),
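
The docstring in the final hunk above leans on one property: asking `random_gamma` for zero samples must yield an empty `Tensor`, which the dynamic-partition sampling path then handles. A standalone sketch of that property, assuming TensorFlow 1.x graph mode:

import tensorflow as tf

# A leading sample dimension of 0 produces an empty Tensor of shape [0].
empty = tf.random_gamma(shape=[0], alpha=1., beta=2.)

with tf.Session() as sess:
  assert sess.run(empty).shape == (0,)
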
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/moving_stats_test.py b/tensorflow/contrib/distributions/python/kernel_tests/moving_stats_test.py
index 509fc66c05..3c988dad8a 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/moving_stats_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/moving_stats_test.py
@@ -36,7 +36,7 @@ class MovingReduceMeanVarianceTest(test.TestCase):
shape = [1, 2]
true_mean = np.array([[0., 3.]])
true_stddev = np.array([[1.1, 0.5]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Start "x" out with this mean.
mean_var = variables.Variable(array_ops.zeros_like(true_mean))
variance_var = variables.Variable(array_ops.ones_like(true_stddev))
@@ -84,7 +84,7 @@ class MovingReduceMeanVarianceTest(test.TestCase):
shape = [1, 2]
true_mean = np.array([[0., 3.]])
true_stddev = np.array([[1.1, 0.5]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Start "x" out with this mean.
x = random_ops.random_normal(shape, dtype=np.float64, seed=0)
x = true_stddev * x + true_mean
@@ -111,7 +111,7 @@ class MovingLogExponentialMovingMeanExpTest(test.TestCase):
true_mean = np.array([[0., 3.]])
true_stddev = np.array([[1.1, 0.5]])
decay = 0.99
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Start "x" out with this mean.
x = random_ops.random_normal(shape, dtype=np.float64, seed=0)
x = true_stddev * x + true_mean
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_plus_low_rank_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_plus_low_rank_test.py
index a924d2e383..88d0d346a4 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_plus_low_rank_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_plus_low_rank_test.py
@@ -39,7 +39,7 @@ class MultivariateNormalDiagPlusLowRankTest(test.TestCase):
diag = np.array([[1., 2], [3, 4], [5, 6]])
# batch_shape: [1], event_shape: []
identity_multiplier = np.array([5.])
- with self.test_session():
+ with self.cached_session():
dist = ds.MultivariateNormalDiagPlusLowRank(
scale_diag=diag,
scale_identity_multiplier=identity_multiplier,
@@ -61,7 +61,7 @@ class MultivariateNormalDiagPlusLowRankTest(test.TestCase):
diag = np.array([[1., 2], [3, 4], [5, 6]])
# batch_shape: [3, 1], event_shape: []
identity_multiplier = np.array([[5.], [4], [3]])
- with self.test_session():
+ with self.cached_session():
dist = ds.MultivariateNormalDiagPlusLowRank(
scale_diag=diag,
scale_identity_multiplier=identity_multiplier,
@@ -75,7 +75,7 @@ class MultivariateNormalDiagPlusLowRankTest(test.TestCase):
diag = np.array([[1., 2], [3, 4], [5, 6]])
# batch_shape: [3], event_shape: []
identity_multiplier = np.array([5., 4, 3])
- with self.test_session():
+ with self.cached_session():
dist = ds.MultivariateNormalDiagPlusLowRank(
scale_diag=diag,
scale_identity_multiplier=identity_multiplier,
@@ -94,7 +94,7 @@ class MultivariateNormalDiagPlusLowRankTest(test.TestCase):
loc = np.array([1., 0, -1])
# batch_shape: [3], event_shape: []
identity_multiplier = np.array([5., 4, 3])
- with self.test_session():
+ with self.cached_session():
dist = ds.MultivariateNormalDiagPlusLowRank(
loc=loc,
scale_identity_multiplier=identity_multiplier,
@@ -116,7 +116,7 @@ class MultivariateNormalDiagPlusLowRankTest(test.TestCase):
diag_large = [1.0, 5.0]
v = [[2.0], [3.0]]
diag_small = [3.0]
- with self.test_session():
+ with self.cached_session():
dist = ds.MultivariateNormalDiagPlusLowRank(
loc=mu,
scale_diag=diag_large,
@@ -146,7 +146,7 @@ class MultivariateNormalDiagPlusLowRankTest(test.TestCase):
true_variance = np.diag(true_covariance)
true_stddev = np.sqrt(true_variance)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dist = ds.MultivariateNormalDiagPlusLowRank(
loc=mu,
scale_diag=diag_large,
@@ -380,7 +380,7 @@ class MultivariateNormalDiagPlusLowRankTest(test.TestCase):
cov = np.stack([np.matmul(scale[0], scale[0].T),
np.matmul(scale[1], scale[1].T)])
logging.vlog(2, "expected_cov:\n{}".format(cov))
- with self.test_session():
+ with self.cached_session():
mvn = ds.MultivariateNormalDiagPlusLowRank(
loc=mu,
scale_perturb_factor=u,
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py
index 9635134b08..6a3d171f6c 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py
@@ -45,14 +45,14 @@ class MultivariateNormalDiagTest(test.TestCase):
def testScalarParams(self):
mu = -1.
diag = -5.
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, "at least 1 dimension"):
ds.MultivariateNormalDiag(mu, diag)
def testVectorParams(self):
mu = [-1.]
diag = [-5.]
- with self.test_session():
+ with self.cached_session():
dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
self.assertAllEqual([3, 1], dist.sample(3).get_shape())
@@ -63,7 +63,7 @@ class MultivariateNormalDiagTest(test.TestCase):
# Batch shape = [1], event shape = [3]
mu = array_ops.zeros((1, 3))
diag = array_ops.ones((1, 3))
- with self.test_session():
+ with self.cached_session():
base_dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
dist = ds.TransformedDistribution(
base_dist,
@@ -75,14 +75,14 @@ class MultivariateNormalDiagTest(test.TestCase):
def testMean(self):
mu = [-1., 1]
diag = [1., -5]
- with self.test_session():
+ with self.cached_session():
dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
self.assertAllEqual(mu, dist.mean().eval())
def testMeanWithBroadcastLoc(self):
mu = [-1.]
diag = [1., -5]
- with self.test_session():
+ with self.cached_session():
dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
self.assertAllEqual([-1., -1.], dist.mean().eval())
@@ -91,14 +91,14 @@ class MultivariateNormalDiagTest(test.TestCase):
diag = [-1., 5]
diag_mat = np.diag(diag)
scipy_mvn = stats.multivariate_normal(mean=mu, cov=diag_mat**2)
- with self.test_session():
+ with self.cached_session():
dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
self.assertAllClose(scipy_mvn.entropy(), dist.entropy().eval(), atol=1e-4)
def testSample(self):
mu = [-1., 1]
diag = [1., -2]
- with self.test_session():
+ with self.cached_session():
dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
samps = dist.sample(int(1e3), seed=0).eval()
cov_mat = array_ops.matrix_diag(diag).eval()**2
@@ -111,7 +111,7 @@ class MultivariateNormalDiagTest(test.TestCase):
def testSingularScaleRaises(self):
mu = [-1., 1]
diag = [1., 0]
- with self.test_session():
+ with self.cached_session():
dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
with self.assertRaisesOpError("Singular"):
dist.sample().eval()
@@ -123,7 +123,7 @@ class MultivariateNormalDiagTest(test.TestCase):
# diag corresponds to no batches of 3-variate normals
diag = np.ones([3])
- with self.test_session():
+ with self.cached_session():
dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
mean = dist.mean()
@@ -142,7 +142,7 @@ class MultivariateNormalDiagTest(test.TestCase):
atol=0.10, rtol=0.05)
def testCovariance(self):
- with self.test_session():
+ with self.cached_session():
mvn = ds.MultivariateNormalDiag(
loc=array_ops.zeros([2, 3], dtype=dtypes.float32))
self.assertAllClose(
@@ -178,7 +178,7 @@ class MultivariateNormalDiagTest(test.TestCase):
mvn.covariance().eval())
def testVariance(self):
- with self.test_session():
+ with self.cached_session():
mvn = ds.MultivariateNormalDiag(
loc=array_ops.zeros([2, 3], dtype=dtypes.float32))
self.assertAllClose(
@@ -203,7 +203,7 @@ class MultivariateNormalDiagTest(test.TestCase):
mvn.variance().eval())
def testStddev(self):
- with self.test_session():
+ with self.cached_session():
mvn = ds.MultivariateNormalDiag(
loc=array_ops.zeros([2, 3], dtype=dtypes.float32))
self.assertAllClose(
@@ -229,7 +229,7 @@ class MultivariateNormalDiagTest(test.TestCase):
def testMultivariateNormalDiagWithSoftplusScale(self):
mu = [-1.0, 1.0]
diag = [-1.0, -2.0]
- with self.test_session():
+ with self.cached_session():
dist = ds.MultivariateNormalDiagWithSoftplusScale(
mu, diag, validate_args=True)
samps = dist.sample(1000, seed=0).eval()
@@ -241,7 +241,7 @@ class MultivariateNormalDiagTest(test.TestCase):
def testMultivariateNormalDiagNegLogLikelihood(self):
num_draws = 50
dims = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x_pl = array_ops.placeholder(dtype=dtypes.float32,
shape=[None, dims],
name="x")
@@ -291,7 +291,7 @@ class MultivariateNormalDiagTest(test.TestCase):
def testKLDivIdenticalGradientDefined(self):
dims = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loc = array_ops.zeros([dims], dtype=dtypes.float32)
mvn = ds.MultivariateNormalDiag(
loc=loc,
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py
index b003526392..bbf803f045 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py
@@ -40,7 +40,7 @@ class MultivariateNormalFullCovarianceTest(test.TestCase):
return math_ops.matmul(chol, chol, adjoint_b=True).eval()
def testRaisesIfInitializedWithNonSymmetricMatrix(self):
- with self.test_session():
+ with self.cached_session():
mu = [1., 2.]
sigma = [[1., 0.], [1., 1.]] # Nonsingular, but not symmetric
mvn = ds.MultivariateNormalFullCovariance(mu, sigma, validate_args=True)
@@ -48,14 +48,14 @@ class MultivariateNormalFullCovarianceTest(test.TestCase):
mvn.covariance().eval()
def testNamePropertyIsSetByInitArg(self):
- with self.test_session():
+ with self.cached_session():
mu = [1., 2.]
sigma = [[1., 0.], [0., 1.]]
mvn = ds.MultivariateNormalFullCovariance(mu, sigma, name="Billy")
self.assertEqual(mvn.name, "Billy/")
def testDoesNotRaiseIfInitializedWithSymmetricMatrix(self):
- with self.test_session():
+ with self.cached_session():
mu = rng.rand(10)
sigma = self._random_pd_matrix(10, 10)
mvn = ds.MultivariateNormalFullCovariance(mu, sigma, validate_args=True)
@@ -63,7 +63,7 @@ class MultivariateNormalFullCovarianceTest(test.TestCase):
mvn.covariance().eval()
def testLogPDFScalarBatch(self):
- with self.test_session():
+ with self.cached_session():
mu = rng.rand(2)
sigma = self._random_pd_matrix(2, 2)
mvn = ds.MultivariateNormalFullCovariance(mu, sigma, validate_args=True)
@@ -82,7 +82,7 @@ class MultivariateNormalFullCovarianceTest(test.TestCase):
self.assertAllClose(expected_pdf, pdf.eval())
def testLogPDFScalarBatchCovarianceNotProvided(self):
- with self.test_session():
+ with self.cached_session():
mu = rng.rand(2)
mvn = ds.MultivariateNormalFullCovariance(
mu, covariance_matrix=None, validate_args=True)
@@ -102,7 +102,7 @@ class MultivariateNormalFullCovarianceTest(test.TestCase):
self.assertAllClose(expected_pdf, pdf.eval())
def testShapes(self):
- with self.test_session():
+ with self.cached_session():
mu = rng.rand(3, 5, 2)
covariance = self._random_pd_matrix(3, 5, 2, 2)
@@ -133,7 +133,7 @@ class MultivariateNormalFullCovarianceTest(test.TestCase):
def testKLBatch(self):
batch_shape = [2]
event_shape = [3]
- with self.test_session():
+ with self.cached_session():
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
mvn_a = ds.MultivariateNormalFullCovariance(
@@ -159,7 +159,7 @@ class MultivariateNormalFullCovarianceTest(test.TestCase):
def testKLBatchBroadcast(self):
batch_shape = [2]
event_shape = [3]
- with self.test_session():
+ with self.cached_session():
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
# No batch shape.
mu_b, sigma_b = self._random_mu_and_sigma([], event_shape)
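
The KL tests in this file (and in the MultivariateNormalTriL tests that follow) check against the standard closed form for Gaussians: KL(N_a || N_b) = 1/2 [tr(S_b^{-1} S_a) + (m_b - m_a)^T S_b^{-1} (m_b - m_a) - k + ln(det S_b / det S_a)]. A NumPy reference sketch of that formula; the helper name is illustrative, not from the diff:

import numpy as np

def mvn_kl(mu_a, sigma_a, mu_b, sigma_b):
  # KL(N(mu_a, sigma_a) || N(mu_b, sigma_b)) for full covariance matrices.
  k = mu_a.shape[0]
  sigma_b_inv = np.linalg.inv(sigma_b)
  diff = mu_b - mu_a
  return 0.5 * (np.trace(sigma_b_inv @ sigma_a)
                + diff @ sigma_b_inv @ diff
                - k
                + np.log(np.linalg.det(sigma_b) / np.linalg.det(sigma_a)))
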
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py
index b556d06123..776fc2ca9d 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py
@@ -45,7 +45,7 @@ class MultivariateNormalTriLTest(test.TestCase):
return chol.eval(), sigma.eval()
def testLogPDFScalarBatch(self):
- with self.test_session():
+ with self.cached_session():
mu = self._rng.rand(2)
chol, sigma = self._random_chol(2, 2)
chol[1, 1] = -chol[1, 1]
@@ -65,7 +65,7 @@ class MultivariateNormalTriLTest(test.TestCase):
self.assertAllClose(expected_pdf, pdf.eval())
def testLogPDFXIsHigherRank(self):
- with self.test_session():
+ with self.cached_session():
mu = self._rng.rand(2)
chol, sigma = self._random_chol(2, 2)
chol[0, 0] = -chol[0, 0]
@@ -85,7 +85,7 @@ class MultivariateNormalTriLTest(test.TestCase):
self.assertAllClose(expected_pdf, pdf.eval(), atol=0., rtol=0.03)
def testLogPDFXLowerDimension(self):
- with self.test_session():
+ with self.cached_session():
mu = self._rng.rand(3, 2)
chol, sigma = self._random_chol(3, 2, 2)
chol[0, 0, 0] = -chol[0, 0, 0]
@@ -108,7 +108,7 @@ class MultivariateNormalTriLTest(test.TestCase):
self.assertAllClose(expected_pdf, pdf.eval()[1])
def testEntropy(self):
- with self.test_session():
+ with self.cached_session():
mu = self._rng.rand(2)
chol, sigma = self._random_chol(2, 2)
chol[0, 0] = -chol[0, 0]
@@ -121,7 +121,7 @@ class MultivariateNormalTriLTest(test.TestCase):
self.assertAllClose(expected_entropy, entropy.eval())
def testEntropyMultidimensional(self):
- with self.test_session():
+ with self.cached_session():
mu = self._rng.rand(3, 5, 2)
chol, sigma = self._random_chol(3, 5, 2, 2)
chol[1, 0, 0, 0] = -chol[1, 0, 0, 0]
@@ -136,7 +136,7 @@ class MultivariateNormalTriLTest(test.TestCase):
self.assertAllClose(expected_entropy, entropy.eval()[1, 1])
def testSample(self):
- with self.test_session():
+ with self.cached_session():
mu = self._rng.rand(2)
chol, sigma = self._random_chol(2, 2)
chol[0, 0] = -chol[0, 0]
@@ -152,7 +152,7 @@ class MultivariateNormalTriLTest(test.TestCase):
self.assertAllClose(np.cov(sample_values, rowvar=0), sigma, atol=0.06)
def testSingularScaleRaises(self):
- with self.test_session():
+ with self.cached_session():
mu = None
chol = [[1., 0.], [0., 0.]]
mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
@@ -160,7 +160,7 @@ class MultivariateNormalTriLTest(test.TestCase):
mvn.sample().eval()
def testSampleWithSampleShape(self):
- with self.test_session():
+ with self.cached_session():
mu = self._rng.rand(3, 5, 2)
chol, sigma = self._random_chol(3, 5, 2, 2)
chol[1, 0, 0, 0] = -chol[1, 0, 0, 0]
@@ -185,7 +185,7 @@ class MultivariateNormalTriLTest(test.TestCase):
self.assertAllClose(expected_log_pdf, x_log_pdf)
def testSampleMultiDimensional(self):
- with self.test_session():
+ with self.cached_session():
mu = self._rng.rand(3, 5, 2)
chol, sigma = self._random_chol(3, 5, 2, 2)
chol[1, 0, 0, 0] = -chol[1, 0, 0, 0]
@@ -205,7 +205,7 @@ class MultivariateNormalTriLTest(test.TestCase):
atol=1e-1)
def testShapes(self):
- with self.test_session():
+ with self.cached_session():
mu = self._rng.rand(3, 5, 2)
chol, _ = self._random_chol(3, 5, 2, 2)
chol[1, 0, 0, 0] = -chol[1, 0, 0, 0]
@@ -237,7 +237,7 @@ class MultivariateNormalTriLTest(test.TestCase):
def testKLNonBatch(self):
batch_shape = []
event_shape = [2]
- with self.test_session():
+ with self.cached_session():
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
mvn_a = ds.MultivariateNormalTriL(
@@ -259,7 +259,7 @@ class MultivariateNormalTriLTest(test.TestCase):
def testKLBatch(self):
batch_shape = [2]
event_shape = [3]
- with self.test_session():
+ with self.cached_session():
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
mvn_a = ds.MultivariateNormalTriL(
@@ -285,7 +285,7 @@ class MultivariateNormalTriLTest(test.TestCase):
def testKLBatchBroadcast(self):
batch_shape = [2]
event_shape = [3]
- with self.test_session():
+ with self.cached_session():
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
# No batch shape.
mu_b, sigma_b = self._random_mu_and_sigma([], event_shape)
@@ -312,7 +312,7 @@ class MultivariateNormalTriLTest(test.TestCase):
def testKLTwoIdenticalDistributionsIsZero(self):
batch_shape = [2]
event_shape = [3]
- with self.test_session():
+ with self.cached_session():
mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
mvn_a = ds.MultivariateNormalTriL(
loc=mu_a,
@@ -336,7 +336,7 @@ class MultivariateNormalTriLTest(test.TestCase):
true_variance = np.diag(true_covariance)
true_stddev = np.sqrt(true_variance)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dist = ds.MultivariateNormalTriL(
loc=mu,
scale_tril=scale_tril,
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py
index 37edaa42cd..a46b81af35 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py
@@ -34,7 +34,7 @@ from tensorflow.python.platform import test
class NegativeBinomialTest(test.TestCase):
def testNegativeBinomialShape(self):
- with self.test_session():
+ with self.cached_session():
probs = [.1] * 5
total_count = [2.0] * 5
negbinom = negative_binomial.NegativeBinomial(
@@ -46,7 +46,7 @@ class NegativeBinomialTest(test.TestCase):
self.assertEqual(tensor_shape.TensorShape([]), negbinom.event_shape)
def testNegativeBinomialShapeBroadcast(self):
- with self.test_session():
+ with self.cached_session():
probs = [[.1, .2, .3]] * 5
total_count = [[2.]] * 5
negbinom = negative_binomial.NegativeBinomial(
@@ -60,7 +60,7 @@ class NegativeBinomialTest(test.TestCase):
def testLogits(self):
logits = [[0., 9., -0.5]]
- with self.test_session():
+ with self.cached_session():
negbinom = negative_binomial.NegativeBinomial(
total_count=3., logits=logits)
self.assertEqual([1, 3], negbinom.probs.get_shape())
@@ -69,14 +69,14 @@ class NegativeBinomialTest(test.TestCase):
def testInvalidP(self):
invalid_ps = [-.01, 0., -2.,]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Condition x >= 0"):
negbinom = negative_binomial.NegativeBinomial(
5., probs=invalid_ps, validate_args=True)
negbinom.probs.eval()
invalid_ps = [1.01, 2., 1.001,]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("probs has components greater than 1."):
negbinom = negative_binomial.NegativeBinomial(
5., probs=invalid_ps, validate_args=True)
@@ -84,14 +84,14 @@ class NegativeBinomialTest(test.TestCase):
def testInvalidNegativeCount(self):
invalid_rs = [-.01, 0., -2.,]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Condition x > 0"):
negbinom = negative_binomial.NegativeBinomial(
total_count=invalid_rs, probs=0.1, validate_args=True)
negbinom.total_count.eval()
def testNegativeBinomialLogCdf(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
probs = [.2] * batch_size
probs_v = .2
@@ -109,7 +109,7 @@ class NegativeBinomialTest(test.TestCase):
self.assertAllClose(np.exp(expected_log_cdf), cdf.eval())
def testNegativeBinomialLogCdfValidateArgs(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
probs = [.9] * batch_size
total_count = 5.
@@ -119,7 +119,7 @@ class NegativeBinomialTest(test.TestCase):
negbinom.log_cdf(-1.).eval()
def testNegativeBinomialLogPmf(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
probs = [.2] * batch_size
probs_v = .2
@@ -137,7 +137,7 @@ class NegativeBinomialTest(test.TestCase):
self.assertAllClose(np.exp(expected_log_pmf), pmf.eval())
def testNegativeBinomialLogPmfValidateArgs(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
probs = [.9] * batch_size
total_count = 5.
@@ -162,7 +162,7 @@ class NegativeBinomialTest(test.TestCase):
self.assertEqual([6], pmf.get_shape())
def testNegativeBinomialLogPmfMultidimensional(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
probs = constant_op.constant([[.2, .3, .5]] * batch_size)
probs_v = np.array([.2, .3, .5])
@@ -183,7 +183,7 @@ class NegativeBinomialTest(test.TestCase):
self.assertAllClose(np.exp(expected_log_pmf), pmf_values)
def testNegativeBinomialMean(self):
- with self.test_session():
+ with self.cached_session():
total_count = 5.
probs = np.array([.1, .3, .25], dtype=np.float32)
negbinom = negative_binomial.NegativeBinomial(
@@ -193,7 +193,7 @@ class NegativeBinomialTest(test.TestCase):
self.assertAllClose(expected_means, negbinom.mean().eval())
def testNegativeBinomialVariance(self):
- with self.test_session():
+ with self.cached_session():
total_count = 5.
probs = np.array([.1, .3, .25], dtype=np.float32)
negbinom = negative_binomial.NegativeBinomial(
@@ -203,7 +203,7 @@ class NegativeBinomialTest(test.TestCase):
self.assertAllClose(expected_vars, negbinom.variance().eval())
def testNegativeBinomialStddev(self):
- with self.test_session():
+ with self.cached_session():
total_count = 5.
probs = np.array([.1, .3, .25], dtype=np.float32)
negbinom = negative_binomial.NegativeBinomial(
@@ -213,7 +213,7 @@ class NegativeBinomialTest(test.TestCase):
self.assertAllClose(expected_stds, negbinom.stddev().eval())
def testNegativeBinomialSample(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
probs = [.3, .9]
total_count = [4., 11.]
n = int(100e3)
@@ -242,7 +242,7 @@ class NegativeBinomialTest(test.TestCase):
rtol=.02)
def testLogProbOverflow(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
logits = np.float32([20., 30., 40.])
total_count = np.float32(1.)
x = np.float32(0.)
@@ -253,7 +253,7 @@ class NegativeBinomialTest(test.TestCase):
np.isfinite(log_prob_))
def testLogProbUnderflow(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
logits = np.float32([-90, -100, -110])
total_count = np.float32(1.)
x = np.float32(0.)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/onehot_categorical_test.py b/tensorflow/contrib/distributions/python/kernel_tests/onehot_categorical_test.py
index 111f88eeb5..84ee19123c 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/onehot_categorical_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/onehot_categorical_test.py
@@ -44,7 +44,7 @@ class OneHotCategoricalTest(test.TestCase):
def testP(self):
p = [0.2, 0.8]
dist = onehot_categorical.OneHotCategorical(probs=p)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(p, dist.probs.eval())
self.assertAllEqual([2], dist.logits.get_shape())
@@ -52,14 +52,14 @@ class OneHotCategoricalTest(test.TestCase):
p = np.array([0.2, 0.8], dtype=np.float32)
logits = np.log(p) - 50.
dist = onehot_categorical.OneHotCategorical(logits=logits)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([2], dist.probs.get_shape())
self.assertAllEqual([2], dist.logits.get_shape())
self.assertAllClose(dist.probs.eval(), p)
self.assertAllClose(dist.logits.eval(), logits)
def testShapes(self):
- with self.test_session():
+ with self.cached_session():
for batch_shape in ([], [1], [2, 3, 4]):
dist = make_onehot_categorical(batch_shape, 10)
self.assertAllEqual(batch_shape, dist.batch_shape.as_list())
@@ -97,7 +97,7 @@ class OneHotCategoricalTest(test.TestCase):
np.array([1]+[0]*4, dtype=np.int64)).dtype)
def testUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
logits = array_ops.placeholder(dtype=dtypes.float32)
dist = onehot_categorical.OneHotCategorical(logits)
sample = dist.sample()
@@ -112,7 +112,7 @@ class OneHotCategoricalTest(test.TestCase):
def testEntropyNoBatch(self):
logits = np.log([0.2, 0.8]) - 50.
dist = onehot_categorical.OneHotCategorical(logits)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(
dist.entropy().eval(),
-(0.2 * np.log(0.2) + 0.8 * np.log(0.8)))
@@ -120,7 +120,7 @@ class OneHotCategoricalTest(test.TestCase):
def testEntropyWithBatch(self):
logits = np.log([[0.2, 0.8], [0.6, 0.4]]) - 50.
dist = onehot_categorical.OneHotCategorical(logits)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(dist.entropy().eval(), [
-(0.2 * np.log(0.2) + 0.8 * np.log(0.8)),
-(0.6 * np.log(0.6) + 0.4 * np.log(0.4))
@@ -128,7 +128,7 @@ class OneHotCategoricalTest(test.TestCase):
def testPmf(self):
# check that probability of samples correspond to their class probabilities
- with self.test_session():
+ with self.cached_session():
logits = self._rng.random_sample(size=(8, 2, 10))
prob = np.exp(logits)/np.sum(np.exp(logits), axis=-1, keepdims=True)
dist = onehot_categorical.OneHotCategorical(logits=logits)
@@ -138,7 +138,7 @@ class OneHotCategoricalTest(test.TestCase):
self.assertAllClose(expected_prob, np_prob.flatten())
def testSample(self):
- with self.test_session():
+ with self.cached_session():
probs = [[[0.2, 0.8], [0.4, 0.6]]]
dist = onehot_categorical.OneHotCategorical(math_ops.log(probs) - 50.)
n = 100
@@ -150,7 +150,7 @@ class OneHotCategoricalTest(test.TestCase):
self.assertFalse(np.any(sample_values > 1))
def testSampleWithSampleShape(self):
- with self.test_session():
+ with self.cached_session():
probs = [[[0.2, 0.8], [0.4, 0.6]]]
dist = onehot_categorical.OneHotCategorical(math_ops.log(probs) - 50.)
samples = dist.sample((100, 100), seed=123)
@@ -166,7 +166,7 @@ class OneHotCategoricalTest(test.TestCase):
exp_logits = np.exp(logits)
return exp_logits / exp_logits.sum(axis=-1, keepdims=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for categories in [2, 10]:
for batch_size in [1, 2]:
p_logits = self._rng.random_sample((batch_size, categories))
@@ -193,7 +193,7 @@ class OneHotCategoricalTest(test.TestCase):
self.assertAllClose(kl_sample_, kl_expected, atol=1e-2, rtol=0.)
def testSampleUnbiasedNonScalarBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
logits = self._rng.rand(4, 3, 2).astype(np.float32)
dist = onehot_categorical.OneHotCategorical(logits=logits)
n = int(3e3)
@@ -221,7 +221,7 @@ class OneHotCategoricalTest(test.TestCase):
actual_covariance_, sample_covariance_, atol=0., rtol=0.10)
def testSampleUnbiasedScalarBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
logits = self._rng.rand(3).astype(np.float32)
dist = onehot_categorical.OneHotCategorical(logits=logits)
n = int(1e4)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py
index 1035cb00f7..e2d04c9c27 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py
@@ -29,7 +29,7 @@ class _PoissonLogNormalQuadratureCompoundTest(
"""Tests the PoissonLogNormalQuadratureCompoundTest distribution."""
def testSampleProbConsistent(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=array_ops.placeholder_with_default(
-2.,
@@ -43,7 +43,7 @@ class _PoissonLogNormalQuadratureCompoundTest(
sess.run, pln, batch_size=1, rtol=0.1)
def testMeanVariance(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=array_ops.placeholder_with_default(
0.,
@@ -57,7 +57,7 @@ class _PoissonLogNormalQuadratureCompoundTest(
sess.run, pln, rtol=0.02)
def testSampleProbConsistentBroadcastScalar(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=array_ops.placeholder_with_default(
[0., -0.5],
@@ -71,7 +71,7 @@ class _PoissonLogNormalQuadratureCompoundTest(
sess.run, pln, batch_size=2, rtol=0.1, atol=0.01)
def testMeanVarianceBroadcastScalar(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=array_ops.placeholder_with_default(
[0., -0.5],
@@ -85,7 +85,7 @@ class _PoissonLogNormalQuadratureCompoundTest(
sess.run, pln, rtol=0.1, atol=0.01)
def testSampleProbConsistentBroadcastBoth(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=array_ops.placeholder_with_default(
[[0.], [-0.5]],
@@ -99,7 +99,7 @@ class _PoissonLogNormalQuadratureCompoundTest(
sess.run, pln, batch_size=4, rtol=0.1, atol=0.08)
def testMeanVarianceBroadcastBoth(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=array_ops.placeholder_with_default(
[[0.], [-0.5]],
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py b/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py
index 19a7472d91..29eba5afca 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py
@@ -35,7 +35,7 @@ class PoissonTest(test.TestCase):
return poisson_lib.Poisson(rate=rate, validate_args=validate_args)
def testPoissonShape(self):
- with self.test_session():
+ with self.cached_session():
lam = constant_op.constant([3.0] * 5)
poisson = self._make_poisson(rate=lam)
@@ -47,13 +47,13 @@ class PoissonTest(test.TestCase):
def testInvalidLam(self):
invalid_lams = [-.01, 0., -2.]
for lam in invalid_lams:
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Condition x > 0"):
poisson = self._make_poisson(rate=lam, validate_args=True)
poisson.rate.eval()
def testPoissonLogPmf(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
lam = constant_op.constant([3.0] * batch_size)
lam_v = 3.0
@@ -68,7 +68,7 @@ class PoissonTest(test.TestCase):
self.assertAllClose(pmf.eval(), stats.poisson.pmf(x, lam_v))
def testPoissonLogPmfValidateArgs(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
lam = constant_op.constant([3.0] * batch_size)
x = array_ops.placeholder(dtypes.float32, shape=[6])
@@ -91,7 +91,7 @@ class PoissonTest(test.TestCase):
self.assertEqual(pmf.get_shape(), (6,))
def testPoissonLogPmfMultidimensional(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
lam = constant_op.constant([[2.0, 4.0, 5.0]] * batch_size)
lam_v = [2.0, 4.0, 5.0]
@@ -107,7 +107,7 @@ class PoissonTest(test.TestCase):
self.assertAllClose(pmf.eval(), stats.poisson.pmf(x, lam_v))
def testPoissonCDF(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
lam = constant_op.constant([3.0] * batch_size)
lam_v = 3.0
@@ -123,7 +123,7 @@ class PoissonTest(test.TestCase):
self.assertAllClose(cdf.eval(), stats.poisson.cdf(x, lam_v))
def testPoissonCDFNonIntegerValues(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
lam = constant_op.constant([3.0] * batch_size)
lam_v = 3.0
@@ -142,7 +142,7 @@ class PoissonTest(test.TestCase):
poisson_validate.cdf(x).eval()
def testPoissonCdfMultidimensional(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 6
lam = constant_op.constant([[2.0, 4.0, 5.0]] * batch_size)
lam_v = [2.0, 4.0, 5.0]
@@ -158,7 +158,7 @@ class PoissonTest(test.TestCase):
self.assertAllClose(cdf.eval(), stats.poisson.cdf(x, lam_v))
def testPoissonMean(self):
- with self.test_session():
+ with self.cached_session():
lam_v = [1.0, 3.0, 2.5]
poisson = self._make_poisson(rate=lam_v)
self.assertEqual(poisson.mean().get_shape(), (3,))
@@ -166,7 +166,7 @@ class PoissonTest(test.TestCase):
self.assertAllClose(poisson.mean().eval(), lam_v)
def testPoissonVariance(self):
- with self.test_session():
+ with self.cached_session():
lam_v = [1.0, 3.0, 2.5]
poisson = self._make_poisson(rate=lam_v)
self.assertEqual(poisson.variance().get_shape(), (3,))
@@ -174,7 +174,7 @@ class PoissonTest(test.TestCase):
self.assertAllClose(poisson.variance().eval(), lam_v)
def testPoissonStd(self):
- with self.test_session():
+ with self.cached_session():
lam_v = [1.0, 3.0, 2.5]
poisson = self._make_poisson(rate=lam_v)
self.assertEqual(poisson.stddev().get_shape(), (3,))
@@ -182,14 +182,14 @@ class PoissonTest(test.TestCase):
self.assertAllClose(poisson.stddev().eval(), np.sqrt(lam_v))
def testPoissonMode(self):
- with self.test_session():
+ with self.cached_session():
lam_v = [1.0, 3.0, 2.5, 3.2, 1.1, 0.05]
poisson = self._make_poisson(rate=lam_v)
self.assertEqual(poisson.mode().get_shape(), (6,))
self.assertAllClose(poisson.mode().eval(), np.floor(lam_v))
def testPoissonMultipleMode(self):
- with self.test_session():
+ with self.cached_session():
lam_v = [1.0, 3.0, 2.0, 4.0, 5.0, 10.0]
poisson = self._make_poisson(rate=lam_v)
# For the case where lam is an integer, the modes are: lam and lam - 1.
@@ -198,7 +198,7 @@ class PoissonTest(test.TestCase):
self.assertAllClose(lam_v, poisson.mode().eval())
def testPoissonSample(self):
- with self.test_session():
+ with self.cached_session():
lam_v = 4.0
lam = constant_op.constant(lam_v)
# Choosing `n >= (k/rtol)**2`, roughly ensures our sample mean should be
@@ -215,7 +215,7 @@ class PoissonTest(test.TestCase):
sample_values.var(), stats.poisson.var(lam_v), rtol=.01)
def testPoissonSampleMultidimensionalMean(self):
- with self.test_session():
+ with self.cached_session():
lam_v = np.array([np.arange(1, 51, dtype=np.float32)]) # 1 x 50
poisson = self._make_poisson(rate=lam_v)
# Choosing `n >= (k/rtol)**2`, roughly ensures our sample mean should be
@@ -232,7 +232,7 @@ class PoissonTest(test.TestCase):
atol=0)
def testPoissonSampleMultidimensionalVariance(self):
- with self.test_session():
+ with self.cached_session():
lam_v = np.array([np.arange(5, 15, dtype=np.float32)]) # 1 x 10
poisson = self._make_poisson(rate=lam_v)
# Choosing `n >= 2 * lam * (k/rtol)**2`, roughly ensures our sample
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py
index 6a7ee3a8bf..07528cafaf 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py
@@ -38,7 +38,7 @@ class QuantizedDistributionTest(test.TestCase):
self.assertTrue(np.isfinite(array).all())
def testQuantizationOfUniformWithCutoffsHavingNoEffect(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The Quantized uniform with cutoffs == None divides the real line into:
# R = ...(-1, 0](0, 1](1, 2](2, 3](3, 4]...
# j = ... 0 1 2 3 4 ...
@@ -93,7 +93,7 @@ class QuantizedDistributionTest(test.TestCase):
self.assertAllClose(3 / 3, cdf_5)
def testQuantizationOfUniformWithCutoffsInTheMiddle(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The uniform is supported on [-3, 3]
# Consider partitions the real line in intervals
# ...(-3, -2](-2, -1](-1, 0](0, 1](1, 2](2, 3] ...
@@ -131,7 +131,7 @@ class QuantizedDistributionTest(test.TestCase):
def testQuantizationOfBatchOfUniforms(self):
batch_shape = (5, 5)
- with self.test_session():
+ with self.cached_session():
# The uniforms are supported on [0, 10]. The qdist considers the
# intervals
# ... (0, 1](1, 2]...(9, 10]...
@@ -165,7 +165,7 @@ class QuantizedDistributionTest(test.TestCase):
def testSamplingFromBatchOfNormals(self):
batch_shape = (2,)
- with self.test_session():
+ with self.cached_session():
normal = distributions.Normal(
loc=array_ops.zeros(
batch_shape, dtype=dtypes.float32),
@@ -199,7 +199,7 @@ class QuantizedDistributionTest(test.TestCase):
# pretend that the cdf F is a bijection, and hence F(X) is uniform.
# Note that F cannot be a bijection since it is constant between the
# integers. Hence, F(X) (see below) will not be uniform exactly.
- with self.test_session():
+ with self.cached_session():
qdist = distributions.QuantizedDistribution(
distribution=distributions.Exponential(rate=0.01))
# X ~ QuantizedExponential
@@ -222,7 +222,7 @@ class QuantizedDistributionTest(test.TestCase):
# it makes sure the bin edges are consistent.
# Make an exponential with mean 5.
- with self.test_session():
+ with self.cached_session():
qdist = distributions.QuantizedDistribution(
distribution=distributions.Exponential(rate=0.2))
# Standard error should be less than 1 / (2 * sqrt(n_samples))
@@ -243,7 +243,7 @@ class QuantizedDistributionTest(test.TestCase):
batch_shape = (3, 3)
mu = rng.randn(*batch_shape)
sigma = rng.rand(*batch_shape) + 1.0
- with self.test_session():
+ with self.cached_session():
qdist = distributions.QuantizedDistribution(
distribution=distributions.Normal(
loc=mu, scale=sigma))
@@ -260,7 +260,7 @@ class QuantizedDistributionTest(test.TestCase):
batch_shape = (3, 3)
mu = rng.randn(*batch_shape)
sigma = rng.rand(*batch_shape) + 1.0
- with self.test_session():
+ with self.cached_session():
qdist = distributions.QuantizedDistribution(
distribution=distributions.Normal(
loc=mu, scale=sigma))
@@ -275,7 +275,7 @@ class QuantizedDistributionTest(test.TestCase):
def testNormalProbWithCutoffs(self):
# At integer values, the result should be the same as the standard normal.
- with self.test_session():
+ with self.cached_session():
qdist = distributions.QuantizedDistribution(
distribution=distributions.Normal(loc=0., scale=1.),
low=-2.,
@@ -297,7 +297,7 @@ class QuantizedDistributionTest(test.TestCase):
def testNormalLogProbWithCutoffs(self):
# At integer values, the result should be the same as the standard normal.
- with self.test_session():
+ with self.cached_session():
qdist = distributions.QuantizedDistribution(
distribution=distributions.Normal(loc=0., scale=1.),
low=-2.,
@@ -335,14 +335,14 @@ class QuantizedDistributionTest(test.TestCase):
x = np.arange(-100, 100, 2).astype(dtype)
proba = qdist.log_prob(x)
grads = gradients_impl.gradients(proba, [mu, sigma])
- with self.test_session(graph=g):
+ with self.session(graph=g):
variables.global_variables_initializer().run()
self._assert_all_finite(proba.eval())
self._assert_all_finite(grads[0].eval())
self._assert_all_finite(grads[1].eval())
def testProbAndGradGivesFiniteResultsForCommonEvents(self):
- with self.test_session():
+ with self.cached_session():
mu = variables.Variable(0.0, name="mu")
sigma = variables.Variable(1.0, name="sigma")
qdist = distributions.QuantizedDistribution(
@@ -360,7 +360,7 @@ class QuantizedDistributionTest(test.TestCase):
self._assert_all_finite(grads[1].eval())
def testLowerCutoffMustBeBelowUpperCutoffOrWeRaise(self):
- with self.test_session():
+ with self.cached_session():
qdist = distributions.QuantizedDistribution(
distribution=distributions.Normal(loc=0., scale=1.),
low=1., # not strictly less than high.
@@ -372,7 +372,7 @@ class QuantizedDistributionTest(test.TestCase):
qdist.sample().eval()
def testCutoffsMustBeIntegerValuedIfValidateArgsTrue(self):
- with self.test_session():
+ with self.cached_session():
low = array_ops.placeholder(dtypes.float32)
qdist = distributions.QuantizedDistribution(
distribution=distributions.Normal(loc=0., scale=1.),
@@ -385,7 +385,7 @@ class QuantizedDistributionTest(test.TestCase):
qdist.sample().eval(feed_dict={low: 1.5})
def testCutoffsCanBeFloatValuedIfValidateArgsFalse(self):
- with self.test_session():
+ with self.cached_session():
qdist = distributions.QuantizedDistribution(
distribution=distributions.Normal(
loc=0., scale=1., validate_args=False),
@@ -399,7 +399,7 @@ class QuantizedDistributionTest(test.TestCase):
def testDtypeAndShapeInheritedFromBaseDist(self):
batch_shape = (2, 3)
- with self.test_session():
+ with self.cached_session():
qdist = distributions.QuantizedDistribution(
distribution=distributions.Normal(
loc=array_ops.zeros(batch_shape),
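
Away from the cutoffs, every probability check above reduces to the defining relation of QuantizedDistribution on integer points: P(Y = j) = F(j) - F(j - 1), with F the base distribution's cdf. A one-line SciPy sketch against the standard normal used in testNormalProbWithCutoffs:

from scipy import stats

j = 1
# Mass the quantized standard normal places on the integer j: F(j) - F(j-1).
p_quantized = stats.norm.cdf(j) - stats.norm.cdf(j - 1)
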
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/relaxed_bernoulli_test.py b/tensorflow/contrib/distributions/python/kernel_tests/relaxed_bernoulli_test.py
index 2cf12bbe50..fec2374928 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/relaxed_bernoulli_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/relaxed_bernoulli_test.py
@@ -34,29 +34,29 @@ class RelaxedBernoulliTest(test.TestCase):
temperature = 1.0
p = [0.1, 0.4]
dist = relaxed_bernoulli.RelaxedBernoulli(temperature, probs=p)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(p, dist.probs.eval())
def testLogits(self):
temperature = 2.0
logits = [-42., 42.]
dist = relaxed_bernoulli.RelaxedBernoulli(temperature, logits=logits)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(logits, dist.logits.eval())
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(scipy.special.expit(logits), dist.probs.eval())
p = [0.01, 0.99, 0.42]
dist = relaxed_bernoulli.RelaxedBernoulli(temperature, probs=p)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(scipy.special.logit(p), dist.logits.eval())
def testInvalidP(self):
temperature = 1.0
invalid_ps = [1.01, 2.]
for p in invalid_ps:
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("probs has components greater than 1"):
dist = relaxed_bernoulli.RelaxedBernoulli(temperature,
probs=p,
@@ -65,7 +65,7 @@ class RelaxedBernoulliTest(test.TestCase):
invalid_ps = [-0.01, -3.]
for p in invalid_ps:
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Condition x >= 0"):
dist = relaxed_bernoulli.RelaxedBernoulli(temperature,
probs=p,
@@ -74,13 +74,13 @@ class RelaxedBernoulliTest(test.TestCase):
valid_ps = [0.0, 0.5, 1.0]
for p in valid_ps:
- with self.test_session():
+ with self.cached_session():
dist = relaxed_bernoulli.RelaxedBernoulli(temperature,
probs=p)
self.assertEqual(p, dist.probs.eval())
def testShapes(self):
- with self.test_session():
+ with self.cached_session():
for batch_shape in ([], [1], [2, 3, 4]):
temperature = 1.0
p = np.random.random(batch_shape).astype(np.float32)
@@ -96,7 +96,7 @@ class RelaxedBernoulliTest(test.TestCase):
p = constant_op.constant([0.1, 0.4])
dist = relaxed_bernoulli.RelaxedBernoulli(temperature, probs=p,
validate_args=True)
- with self.test_session():
+ with self.cached_session():
sample = dist.sample()
with self.assertRaises(errors_impl.InvalidArgumentError):
sample.eval()
@@ -117,7 +117,7 @@ class RelaxedBernoulliTest(test.TestCase):
self.assertEqual(dist64.dtype, dist64.sample(5).dtype)
def testLogProb(self):
- with self.test_session():
+ with self.cached_session():
t = np.array(1.0, dtype=np.float64)
p = np.array(0.1, dtype=np.float64) # P(x=1)
dist = relaxed_bernoulli.RelaxedBernoulli(t, probs=p)
@@ -131,7 +131,7 @@ class RelaxedBernoulliTest(test.TestCase):
self.assertAllClose(expected_log_pdf, log_pdf)
def testBoundaryConditions(self):
- with self.test_session():
+ with self.cached_session():
temperature = 1e-2
dist = relaxed_bernoulli.RelaxedBernoulli(temperature, probs=1.0)
self.assertAllClose(np.nan, dist.log_prob(0.0).eval())
@@ -139,7 +139,7 @@ class RelaxedBernoulliTest(test.TestCase):
def testSampleN(self):
"""mean of quantized samples still approximates the Bernoulli mean."""
- with self.test_session():
+ with self.cached_session():
temperature = 1e-2
p = [0.2, 0.6, 0.5]
dist = relaxed_bernoulli.RelaxedBernoulli(temperature, probs=p)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/relaxed_onehot_categorical_test.py b/tensorflow/contrib/distributions/python/kernel_tests/relaxed_onehot_categorical_test.py
index faae9da6ad..ff13c2decc 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/relaxed_onehot_categorical_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/relaxed_onehot_categorical_test.py
@@ -46,7 +46,7 @@ class ExpRelaxedOneHotCategoricalTest(test.TestCase):
dist = relaxed_onehot_categorical.ExpRelaxedOneHotCategorical(temperature,
logits)
expected_p = np.exp(logits)/np.sum(np.exp(logits))
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(expected_p, dist.probs.eval())
self.assertAllEqual([3], dist.probs.get_shape())
@@ -57,7 +57,7 @@ class ExpRelaxedOneHotCategoricalTest(test.TestCase):
p = np.exp(logits)/np.sum(np.exp(logits))
dist = relaxed_onehot_categorical.ExpRelaxedOneHotCategorical(temperature,
logits)
- with self.test_session():
+ with self.cached_session():
x = dist.sample().eval()
# analytical ExpConcrete density presented in Maddison et al. 2016
prod_term = p*np.exp(-temperature * x)
@@ -74,14 +74,14 @@ class RelaxedOneHotCategoricalTest(test.TestCase):
logits = [2.0, 3.0, -4.0]
dist = relaxed_onehot_categorical.RelaxedOneHotCategorical(temperature,
logits)
- with self.test_session():
+ with self.cached_session():
# check p for ExpRelaxed base distribution
self.assertAllClose(logits, dist._distribution.logits.eval())
self.assertAllEqual([3], dist._distribution.logits.get_shape())
def testSample(self):
temperature = 1.4
- with self.test_session():
+ with self.cached_session():
# single logit
logits = [.3, .1, .4]
dist = relaxed_onehot_categorical.RelaxedOneHotCategorical(temperature,
@@ -115,7 +115,7 @@ class RelaxedOneHotCategoricalTest(test.TestCase):
expected_pdf = term1*np.power(term2, -k)*term3
return expected_pdf
- with self.test_session():
+ with self.cached_session():
temperature = .4
logits = np.array([[.3, .1, .4]]).astype(np.float32)
dist = relaxed_onehot_categorical.RelaxedOneHotCategorical(temperature,
@@ -136,7 +136,7 @@ class RelaxedOneHotCategoricalTest(test.TestCase):
self.assertAllClose(expected_pdf.flatten(), pdf, rtol=1e-4)
def testShapes(self):
- with self.test_session():
+ with self.cached_session():
for batch_shape in ([], [1], [2, 3, 4]):
dist = make_relaxed_categorical(batch_shape, 10)
self.assertAllEqual(batch_shape, dist.batch_shape.as_list())
@@ -153,12 +153,12 @@ class RelaxedOneHotCategoricalTest(test.TestCase):
self.assertAllEqual([10], dist.event_shape_tensor().eval())
def testUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
logits_pl = array_ops.placeholder(dtypes.float32)
temperature = 1.0
dist = relaxed_onehot_categorical.ExpRelaxedOneHotCategorical(temperature,
logits_pl)
- with self.test_session():
+ with self.cached_session():
feed_dict = {logits_pl: [.3, .1, .4]}
self.assertAllEqual([3], dist.sample().eval(feed_dict=feed_dict).shape)
self.assertAllEqual([5, 3],
@@ -166,7 +166,7 @@ class RelaxedOneHotCategoricalTest(test.TestCase):
def testDTypes(self):
# check that sampling and log_prob work for a range of dtypes
- with self.test_session():
+ with self.cached_session():
for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
logits = random_ops.random_uniform(shape=[3, 3], dtype=dtype)
dist = relaxed_onehot_categorical.RelaxedOneHotCategorical(
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py b/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py
index ea04e8c29a..d6020e7866 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py
@@ -47,7 +47,7 @@ class _AutoCorrelationTest(object):
input=x_,
shape=x_.shape if self.use_static_shape else None)
with spectral_ops_test_util.fft_kernel_label_map():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Setting normalize = True means we divide by zero.
auto_corr = sample_stats.auto_correlation(
x_ph, axis=1, center=False, normalize=False)
@@ -65,7 +65,7 @@ class _AutoCorrelationTest(object):
input=x_,
shape=x_.shape if self.use_static_shape else None)
with spectral_ops_test_util.fft_kernel_label_map():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Setting normalize = True means we divide by zero.
auto_corr = sample_stats.auto_correlation(
x_ph, axis=1, normalize=False, center=True)
@@ -100,7 +100,7 @@ class _AutoCorrelationTest(object):
x_ph = array_ops.placeholder_with_default(
x, shape=x.shape if self.use_static_shape else None)
with spectral_ops_test_util.fft_kernel_label_map():
- with self.test_session():
+ with self.cached_session():
auto_corr = sample_stats.auto_correlation(
x_ph, axis=axis, max_lags=max_lags, center=center,
normalize=normalize)
@@ -167,7 +167,7 @@ class _AutoCorrelationTest(object):
x_ph = array_ops.placeholder_with_default(
x, shape=(l,) if self.use_static_shape else None)
with spectral_ops_test_util.fft_kernel_label_map():
- with self.test_session():
+ with self.cached_session():
rxx = sample_stats.auto_correlation(
x_ph, max_lags=l // 2, center=True, normalize=False)
if self.use_static_shape:
@@ -188,7 +188,7 @@ class _AutoCorrelationTest(object):
x_ph = array_ops.placeholder_with_default(
x, shape=(1000 * 10,) if self.use_static_shape else None)
with spectral_ops_test_util.fft_kernel_label_map():
- with self.test_session():
+ with self.cached_session():
rxx = sample_stats.auto_correlation(
x_ph, max_lags=1000 * 10 // 2, center=True, normalize=False)
if self.use_static_shape:
@@ -209,7 +209,7 @@ class _AutoCorrelationTest(object):
x_ph = array_ops.placeholder_with_default(
x, shape=(l,) if self.use_static_shape else None)
with spectral_ops_test_util.fft_kernel_label_map():
- with self.test_session():
+ with self.cached_session():
rxx = sample_stats.auto_correlation(
x_ph, max_lags=l // 2, center=True, normalize=True)
if self.use_static_shape:
@@ -271,7 +271,7 @@ class PercentileTestWithLowerInterpolation(test.TestCase):
for q in [0, 10, 25, 49.9, 50, 50.01, 90, 95, 100]:
expected_percentile = np.percentile(
x, q=q, interpolation=self._interpolation, axis=0)
- with self.test_session():
+ with self.cached_session():
pct = sample_stats.percentile(
x, q=q, interpolation=self._interpolation, axis=[0])
self.assertAllEqual((), pct.get_shape())
@@ -282,7 +282,7 @@ class PercentileTestWithLowerInterpolation(test.TestCase):
for q in [0, 10, 25, 49.9, 50, 50.01, 90, 95, 100]:
expected_percentile = np.percentile(
x, q=q, interpolation=self._interpolation)
- with self.test_session():
+ with self.cached_session():
pct = sample_stats.percentile(x, q=q, interpolation=self._interpolation)
self.assertAllEqual((), pct.get_shape())
self.assertAllClose(expected_percentile, pct.eval())
@@ -292,7 +292,7 @@ class PercentileTestWithLowerInterpolation(test.TestCase):
for q in [0, 10, 25, 49.9, 50, 50.01, 90, 95, 100]:
expected_percentile = np.percentile(
x, q=q, interpolation=self._interpolation, axis=0)
- with self.test_session():
+ with self.cached_session():
# Get dim 1 with negative and positive indices.
pct_neg_index = sample_stats.percentile(
x, q=q, interpolation=self._interpolation, axis=[0])
@@ -308,7 +308,7 @@ class PercentileTestWithLowerInterpolation(test.TestCase):
for q in [0, 10, 25, 49.9, 50, 50.01, 90, 95, 100]:
expected_percentile = np.percentile(
x, q=q, interpolation=self._interpolation, axis=0)
- with self.test_session():
+ with self.cached_session():
pct = sample_stats.percentile(
x, q=q, interpolation=self._interpolation, axis=[0])
self.assertAllEqual((2,), pct.get_shape())
@@ -319,7 +319,7 @@ class PercentileTestWithLowerInterpolation(test.TestCase):
for q in [0, 10, 25, 49.9, 50, 50.01, 90, 95, 100]:
expected_percentile = np.percentile(
x, q=q, interpolation=self._interpolation, keepdims=True, axis=0)
- with self.test_session():
+ with self.cached_session():
pct = sample_stats.percentile(
x,
q=q,
@@ -334,7 +334,7 @@ class PercentileTestWithLowerInterpolation(test.TestCase):
for axis in [None, 0, 1, -2, (0,), (-1,), (-1, 1), (3, 1), (-3, 0)]:
expected_percentile = np.percentile(
x, q=0.77, interpolation=self._interpolation, axis=axis)
- with self.test_session():
+ with self.cached_session():
pct = sample_stats.percentile(
x,
q=0.77,
@@ -352,7 +352,7 @@ class PercentileTestWithLowerInterpolation(test.TestCase):
interpolation=self._interpolation,
axis=axis,
keepdims=True)
- with self.test_session():
+ with self.cached_session():
pct = sample_stats.percentile(
x,
q=0.77,
@@ -368,7 +368,7 @@ class PercentileTestWithLowerInterpolation(test.TestCase):
for axis in [None, 0, 1, -2, (0,), (-1,), (-1, 1), (3, 1), (-3, 0)]:
expected_percentile = np.percentile(
x, q=0.77, interpolation=self._interpolation, axis=axis)
- with self.test_session():
+ with self.cached_session():
pct = sample_stats.percentile(
x_ph,
q=0.77,
@@ -386,7 +386,7 @@ class PercentileTestWithLowerInterpolation(test.TestCase):
interpolation=self._interpolation,
axis=axis,
keepdims=True)
- with self.test_session():
+ with self.cached_session():
pct = sample_stats.percentile(
x_ph,
q=0.77,
@@ -400,7 +400,7 @@ class PercentileTestWithLowerInterpolation(test.TestCase):
for q in [0, 10, 25, 49.9, 50, 50.01, 90, 95, 100]:
expected_percentile = np.percentile(
x, q=q, interpolation=self._interpolation)
- with self.test_session():
+ with self.cached_session():
pct = sample_stats.percentile(x, q=q, interpolation=self._interpolation)
self.assertEqual(dtypes.int32, pct.dtype)
self.assertAllEqual((), pct.get_shape())
@@ -423,7 +423,7 @@ class PercentileTestWithNearestInterpolation(test.TestCase):
for q in [0, 10.1, 25.1, 49.9, 50.1, 50.01, 89, 100]:
expected_percentile = np.percentile(
x, q=q, interpolation=self._interpolation)
- with self.test_session():
+ with self.cached_session():
pct = sample_stats.percentile(x, q=q, interpolation=self._interpolation)
self.assertAllEqual((), pct.get_shape())
self.assertAllClose(expected_percentile, pct.eval())
@@ -433,7 +433,7 @@ class PercentileTestWithNearestInterpolation(test.TestCase):
for q in [0, 10.1, 25.1, 49.9, 50.1, 50.01, 89, 100]:
expected_percentile = np.percentile(
x, q=q, interpolation=self._interpolation)
- with self.test_session():
+ with self.cached_session():
pct = sample_stats.percentile(x, q=q, interpolation=self._interpolation)
self.assertAllEqual((), pct.get_shape())
self.assertAllClose(expected_percentile, pct.eval())
@@ -452,7 +452,7 @@ class PercentileTestWithNearestInterpolation(test.TestCase):
x = [1., 5., 3., 2., 4.]
q_ph = array_ops.placeholder(dtypes.float32)
pct = sample_stats.percentile(x, q=q_ph, validate_args=True)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("rank"):
pct.eval(feed_dict={q_ph: [0.5]})
@@ -462,7 +462,7 @@ class PercentileTestWithNearestInterpolation(test.TestCase):
# If float is used, it fails with InvalidArgumentError about an index out of
# bounds.
x = math_ops.linspace(0., 3e7, num=int(3e7))
- with self.test_session():
+ with self.cached_session():
minval = sample_stats.percentile(x, q=0, validate_args=True)
self.assertAllEqual(0, minval.eval())
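Note: a hedged sketch of the contract the percentile tests above check, namely that sample_stats.percentile tracks np.percentile for a given interpolation mode. The import path matches the test file; the input values are illustrative:

import numpy as np
import tensorflow as tf
from tensorflow.contrib.distributions.python.ops import sample_stats

x = [1., 5., 3., 2., 4.]
with tf.Session():
  pct = sample_stats.percentile(x, q=50, interpolation="lower")
  # Matches NumPy with the same interpolation mode: both yield 3.0.
  assert pct.eval() == np.percentile(x, q=50, interpolation="lower")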
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py
index 243b5a0348..a4d2aa381c 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py
@@ -73,7 +73,7 @@ class MakeBatchReadyTest(test.TestCase):
return y, sample_shape, should_be_x_value
def _test_dynamic(self, x, batch_ndims, event_ndims, expand_batch_dim=True):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x_pl = array_ops.placeholder(x.dtype)
batch_ndims_pl = array_ops.placeholder(dtypes.int32)
event_ndims_pl = array_ops.placeholder(dtypes.int32)
@@ -91,7 +91,7 @@ class MakeBatchReadyTest(test.TestCase):
self.assertAllEqual(x, should_be_x_value_)
def _test_static(self, x, batch_ndims, event_ndims, expand_batch_dim):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
[y_, sample_shape_, should_be_x_value_] = sess.run(
self._build_graph(x, batch_ndims, event_ndims, expand_batch_dim))
expected_y, expected_sample_shape = self._get_expected(
@@ -544,7 +544,7 @@ class DistributionShapeTest(test.TestCase):
self.assertAllEqual(expected_item, next(actual_item))
def testDistributionShapeGetNdimsStatic(self):
- with self.test_session():
+ with self.cached_session():
shaper = _DistributionShape(batch_ndims=0, event_ndims=0)
x = 1
self.assertEqual(0, shaper.get_sample_ndims(x).eval())
@@ -572,7 +572,7 @@ class DistributionShapeTest(test.TestCase):
self.assertEqual(1, shaper.event_ndims.eval())
def testDistributionShapeGetNdimsDynamic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_ndims = array_ops.placeholder(dtypes.int32)
event_ndims = array_ops.placeholder(dtypes.int32)
shaper = _DistributionShape(
@@ -583,7 +583,7 @@ class DistributionShapeTest(test.TestCase):
self.assertEqual(2, sess.run(shaper.get_ndims(y), feed_dict=feed_dict))
def testDistributionShapeGetDimsStatic(self):
- with self.test_session():
+ with self.cached_session():
shaper = _DistributionShape(batch_ndims=0, event_ndims=0)
x = 1
self.assertAllEqual((_empty_shape, _empty_shape, _empty_shape),
@@ -597,7 +597,7 @@ class DistributionShapeTest(test.TestCase):
_constant(shaper.get_dims(x)))
def testDistributionShapeGetDimsDynamic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Works for static {batch,event}_ndims despite unfed input.
shaper = _DistributionShape(batch_ndims=1, event_ndims=2)
y = array_ops.placeholder(dtypes.float32, shape=(10, None, 5, 5))
@@ -615,7 +615,7 @@ class DistributionShapeTest(test.TestCase):
([0], [1], [2, 3]), sess.run(shaper.get_dims(y), feed_dict=feed_dict))
def testDistributionShapeGetShapeStatic(self):
- with self.test_session():
+ with self.cached_session():
shaper = _DistributionShape(batch_ndims=0, event_ndims=0)
self.assertAllEqual((_empty_shape, _empty_shape, _empty_shape),
_constant(shaper.get_shape(1.)))
@@ -657,7 +657,7 @@ class DistributionShapeTest(test.TestCase):
_constant(shaper.get_shape(np.ones((3, 2, 1)))))
def testDistributionShapeGetShapeDynamic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Works for static ndims despite unknown static shape.
shaper = _DistributionShape(batch_ndims=1, event_ndims=1)
y = array_ops.placeholder(dtypes.int32, shape=(None, None, 2))
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/sinh_arcsinh_test.py b/tensorflow/contrib/distributions/python/kernel_tests/sinh_arcsinh_test.py
index 88b48736dd..1811d85b7e 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/sinh_arcsinh_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/sinh_arcsinh_test.py
@@ -34,7 +34,7 @@ class SinhArcsinhTest(test.TestCase):
b = 10
scale = rng.rand(b) + 0.5
loc = rng.randn(b)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
norm = ds.Normal(
loc=loc,
scale=scale,
@@ -58,7 +58,7 @@ class SinhArcsinhTest(test.TestCase):
norm_samps.std(axis=0), sasnorm_samps.std(axis=0), atol=0.1)
def test_broadcast_params_dynamic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loc = array_ops.placeholder(dtypes.float64)
scale = array_ops.placeholder(dtypes.float64)
skewness = array_ops.placeholder(dtypes.float64)
@@ -78,7 +78,7 @@ class SinhArcsinhTest(test.TestCase):
b = 10
scale = rng.rand(b) + 0.5
loc = rng.randn(b)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
lap = ds.Laplace(
loc=loc,
scale=scale,
@@ -106,7 +106,7 @@ class SinhArcsinhTest(test.TestCase):
batch_size = 10
scale = rng.rand(batch_size) + 0.5
loc = 0.1 * rng.randn(batch_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
norm = ds.Normal(
loc=loc,
scale=scale,
@@ -148,7 +148,7 @@ class SinhArcsinhTest(test.TestCase):
batch_size = 10
scale = rng.rand(batch_size) + 0.5
loc = np.float64(0.)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
norm = ds.Normal(
loc=loc,
scale=scale,
@@ -190,7 +190,7 @@ class SinhArcsinhTest(test.TestCase):
batch_size = 10
scale = rng.rand(batch_size) + 0.5
loc = rng.randn(batch_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sasnorm = ds.SinhArcsinh(
loc=loc,
scale=scale,
@@ -201,7 +201,7 @@ class SinhArcsinhTest(test.TestCase):
np.testing.assert_array_less(loc, sasnorm_samps.mean(axis=0))
def test_pdf_reflected_for_negative_skewness(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sas_pos_skew = ds.SinhArcsinh(
loc=0.,
scale=1.,
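Note: a minimal sketch of the degenerate case these SinhArcsinh tests build on, assuming the contrib.distributions API: with skewness=0. and tailweight=1. the sinh-arcsinh transform is the identity, so the distribution coincides with the underlying Normal:

import tensorflow as tf
ds = tf.contrib.distributions

sas = ds.SinhArcsinh(loc=0., scale=1., skewness=0., tailweight=1.)
norm = ds.Normal(loc=0., scale=1.)
with tf.Session() as sess:
  x = [-1., 0., 1.]
  # The two density evaluations should agree elementwise.
  print(sess.run([sas.prob(x), norm.prob(x)]))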
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
index 5fe1331d2c..196cc41335 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
@@ -91,7 +91,7 @@ class TransformedDistributionTest(test.TestCase):
# sample
sample = log_normal.sample(100000, seed=235)
self.assertAllEqual([], log_normal.event_shape)
- with self.test_session(graph=g):
+ with self.session(graph=g):
self.assertAllEqual([], log_normal.event_shape_tensor().eval())
self.assertAllClose(
sp_dist.mean(), np.mean(sample.eval()), atol=0.0, rtol=0.05)
@@ -107,7 +107,7 @@ class TransformedDistributionTest(test.TestCase):
[log_normal.log_survival_function, sp_dist.logsf]]:
actual = func[0](test_vals)
expected = func[1](test_vals)
- with self.test_session(graph=g):
+ with self.session(graph=g):
self.assertAllClose(expected, actual.eval(), atol=0, rtol=0.01)
def testNonInjectiveTransformedDistribution(self):
@@ -123,7 +123,7 @@ class TransformedDistributionTest(test.TestCase):
# sample
sample = abs_normal.sample(100000, seed=235)
self.assertAllEqual([], abs_normal.event_shape)
- with self.test_session(graph=g):
+ with self.session(graph=g):
sample_ = sample.eval()
self.assertAllEqual([], abs_normal.event_shape_tensor().eval())
@@ -147,7 +147,7 @@ class TransformedDistributionTest(test.TestCase):
abs_normal.log_prob(2.13).eval())
def testQuantile(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
logit_normal = self._cls()(
distribution=ds.Normal(loc=0., scale=1.),
bijector=bs.Sigmoid(),
@@ -169,7 +169,7 @@ class TransformedDistributionTest(test.TestCase):
exp_forward_only._inverse_log_det_jacobian = self._make_unimplemented(
"inverse_log_det_jacobian ")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
mu = 3.0
sigma = 0.02
log_normal = self._cls()(
@@ -195,7 +195,7 @@ class TransformedDistributionTest(test.TestCase):
log_forward_only = bs.Invert(exp_inverse_only)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The log bijector isn't defined over the whole real line, so we make
# sigma sufficiently small so that the draws are positive.
mu = 2.
@@ -211,7 +211,7 @@ class TransformedDistributionTest(test.TestCase):
self.assertAllClose(expected_log_pdf, log_pdf_val, atol=0.)
def testShapeChangingBijector(self):
- with self.test_session():
+ with self.cached_session():
softmax = bs.SoftmaxCentered()
standard_normal = ds.Normal(loc=0., scale=1.)
multi_logit_normal = self._cls()(
@@ -235,7 +235,7 @@ class TransformedDistributionTest(test.TestCase):
def testCastLogDetJacobian(self):
"""Test log_prob when Jacobian and log_prob dtypes do not match."""
- with self.test_session():
+ with self.cached_session():
# Create an identity bijector whose jacobians have dtype int32
int_identity = bs.Inline(
forward_fn=array_ops.identity,
@@ -257,7 +257,7 @@ class TransformedDistributionTest(test.TestCase):
normal.entropy().eval()
def testEntropy(self):
- with self.test_session():
+ with self.cached_session():
shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32)
diag = np.array([[1, 2, 3], [2, 3, 2]], dtype=np.float32)
actual_mvn_entropy = np.concatenate([
@@ -277,7 +277,7 @@ class TransformedDistributionTest(test.TestCase):
fake_mvn.entropy().eval())
def testScalarBatchScalarEventIdentityScale(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
exp2 = self._cls()(
ds.Exponential(rate=0.25),
bijector=ds.bijectors.AffineScalar(scale=2.)
@@ -310,7 +310,7 @@ class ScalarToMultiTest(test.TestCase):
batch_shape=(),
event_shape=(),
not_implemented_message=None):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Overriding shapes must be compatible w/bijector; most bijectors are
# batch_shape agnostic and only care about event_ndims.
# In the case of `Affine`, if we got it wrong then it would fire an
@@ -428,7 +428,7 @@ class ScalarToMultiTest(test.TestCase):
batch_shape=[2],
not_implemented_message="not implemented")
- with self.test_session():
+ with self.cached_session():
# Can't override event_shape for scalar batch, non-scalar event.
with self.assertRaisesRegexp(ValueError, "base distribution not scalar"):
self._cls()(
@@ -445,7 +445,7 @@ class ScalarToMultiTest(test.TestCase):
event_shape=[3],
not_implemented_message="not implemented when overriding event_shape")
- with self.test_session():
+ with self.cached_session():
# Can't override batch_shape for non-scalar batch, scalar event.
with self.assertRaisesRegexp(ValueError, "base distribution not scalar"):
self._cls()(
@@ -456,7 +456,7 @@ class ScalarToMultiTest(test.TestCase):
validate_args=True)
def testNonScalarBatchNonScalarEvent(self):
- with self.test_session():
+ with self.cached_session():
# Can't override event_shape and/or batch_shape for non_scalar batch,
# non-scalar event.
with self.assertRaisesRegexp(ValueError, "base distribution not scalar"):
@@ -469,7 +469,7 @@ class ScalarToMultiTest(test.TestCase):
validate_args=True)
def testMatrixEvent(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_shape = [2]
event_shape = [2, 3, 3]
batch_shape_pl = array_ops.placeholder(
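Note: a hedged sketch of the basic construction the TransformedDistribution tests above exercise: a log-normal assembled from a Normal base distribution and an Exp bijector, assuming the contrib API of this era:

import tensorflow as tf
ds = tf.contrib.distributions

log_normal = ds.TransformedDistribution(
    distribution=ds.Normal(loc=0., scale=1.),
    bijector=ds.bijectors.Exp())
with tf.Session() as sess:
  # Draws are strictly positive because Exp maps R onto (0, inf).
  print(sess.run(log_normal.sample(5, seed=42)))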
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py
index 04f047aa0c..856579da32 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py
@@ -35,7 +35,7 @@ class VectorDiffeomixtureTest(
"""Tests the VectorDiffeomixture distribution."""
def testSampleProbConsistentBroadcastMixNoBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dims = 4
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [1.]],
@@ -64,7 +64,7 @@ class VectorDiffeomixtureTest(
sess.run, vdm, radius=4., center=2., rtol=0.015)
def testSampleProbConsistentBroadcastMixNonStandardBase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dims = 4
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [1.]],
@@ -93,7 +93,7 @@ class VectorDiffeomixtureTest(
sess.run, vdm, radius=4., center=3., rtol=0.01)
def testSampleProbConsistentBroadcastMixBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dims = 4
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [1.]],
@@ -128,7 +128,7 @@ class VectorDiffeomixtureTest(
dims = 4
loc_1 = rng.randn(2, 3, dims).astype(np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=(rng.rand(2, 3, 1) - 0.5).astype(np.float32),
temperature=[1.],
@@ -152,7 +152,7 @@ class VectorDiffeomixtureTest(
sess.run, vdm, radius=3., center=loc_1, rtol=0.02)
def testMeanCovarianceNoBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dims = 3
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [4.]],
@@ -179,7 +179,7 @@ class VectorDiffeomixtureTest(
def testTemperatureControlsHowMuchThisLooksLikeDiscreteMixture(self):
# As temperature decreases, this should approach a mixture of normals, with
# components at -2, 2.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dims = 1
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[0.],
@@ -216,7 +216,7 @@ class VectorDiffeomixtureTest(
sess.run, vdm, rtol=0.02, cov_rtol=0.08)
def testConcentrationLocControlsHowMuchWeightIsOnEachComponent(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dims = 1
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[-1.], [0.], [1.]],
@@ -259,7 +259,7 @@ class VectorDiffeomixtureTest(
sess.run, vdm, rtol=0.02, cov_rtol=0.08)
def testMeanCovarianceNoBatchUncenteredNonStandardBase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dims = 3
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [4.]],
@@ -284,7 +284,7 @@ class VectorDiffeomixtureTest(
sess.run, vdm, num_samples=int(1e6), rtol=0.01, cov_atol=0.025)
def testMeanCovarianceBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dims = 3
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [4.]],
@@ -312,7 +312,7 @@ class VectorDiffeomixtureTest(
sess.run, vdm, rtol=0.02, cov_rtol=0.07)
def testSampleProbConsistentQuadrature(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dims = 4
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[0.],
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_exponential_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_exponential_diag_test.py
index fd05bd207f..db8186b79a 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/vector_exponential_diag_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_exponential_diag_test.py
@@ -37,42 +37,42 @@ class VectorExponentialDiagTest(test.TestCase):
def testScalarParams(self):
mu = -1.
diag = -5.
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, "at least 1 dimension"):
ds.VectorExponentialDiag(mu, diag)
def testVectorParams(self):
mu = [-1.]
diag = [-5.]
- with self.test_session():
+ with self.cached_session():
dist = ds.VectorExponentialDiag(mu, diag, validate_args=True)
self.assertAllEqual([3, 1], dist.sample(3).get_shape())
def testMean(self):
mu = [-1., 1]
diag = [1., -5]
- with self.test_session():
+ with self.cached_session():
dist = ds.VectorExponentialDiag(mu, diag, validate_args=True)
self.assertAllEqual([-1. + 1., 1. - 5.], dist.mean().eval())
def testMode(self):
mu = [-1.]
diag = [1., -5]
- with self.test_session():
+ with self.cached_session():
dist = ds.VectorExponentialDiag(mu, diag, validate_args=True)
self.assertAllEqual([-1., -1.], dist.mode().eval())
def testMeanWithBroadcastLoc(self):
mu = [-1.]
diag = [1., -5]
- with self.test_session():
+ with self.cached_session():
dist = ds.VectorExponentialDiag(mu, diag, validate_args=True)
self.assertAllEqual([-1. + 1, -1. - 5], dist.mean().eval())
def testSample(self):
mu = [-2., 1]
diag = [1., -2]
- with self.test_session():
+ with self.cached_session():
dist = ds.VectorExponentialDiag(mu, diag, validate_args=True)
samps = dist.sample(int(1e4), seed=0).eval()
cov_mat = array_ops.matrix_diag(diag).eval()**2
@@ -85,7 +85,7 @@ class VectorExponentialDiagTest(test.TestCase):
def testSingularScaleRaises(self):
mu = [-1., 1]
diag = [1., 0]
- with self.test_session():
+ with self.cached_session():
dist = ds.VectorExponentialDiag(mu, diag, validate_args=True)
with self.assertRaisesOpError("Singular"):
dist.sample().eval()
@@ -97,7 +97,7 @@ class VectorExponentialDiagTest(test.TestCase):
# diag corresponds to no batches of 3-variate normals
diag = np.ones([3])
- with self.test_session():
+ with self.cached_session():
dist = ds.VectorExponentialDiag(mu, diag, validate_args=True)
mean = dist.mean()
@@ -117,7 +117,7 @@ class VectorExponentialDiagTest(test.TestCase):
atol=0.10, rtol=0.05)
def testCovariance(self):
- with self.test_session():
+ with self.cached_session():
vex = ds.VectorExponentialDiag(
loc=array_ops.ones([2, 3], dtype=dtypes.float32))
self.assertAllClose(
@@ -153,7 +153,7 @@ class VectorExponentialDiagTest(test.TestCase):
vex.covariance().eval())
def testVariance(self):
- with self.test_session():
+ with self.cached_session():
vex = ds.VectorExponentialDiag(
loc=array_ops.zeros([2, 3], dtype=dtypes.float32))
self.assertAllClose(
@@ -178,7 +178,7 @@ class VectorExponentialDiagTest(test.TestCase):
vex.variance().eval())
def testStddev(self):
- with self.test_session():
+ with self.cached_session():
vex = ds.VectorExponentialDiag(
loc=array_ops.zeros([2, 3], dtype=dtypes.float32))
self.assertAllClose(
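Note: a minimal sketch of the moment identity asserted by testMean above, assuming the contrib API: the standard Exponential base has mean 1, so VectorExponentialDiag(loc, scale_diag) has mean loc + scale_diag:

import tensorflow as tf
ds = tf.contrib.distributions

dist = ds.VectorExponentialDiag([-1., 1.], [1., -5.], validate_args=True)
with tf.Session() as sess:
  print(sess.run(dist.mean()))  # [-1. + 1., 1. - 5.] == [0., -4.]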
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py
index 1226c66113..9ee19b7e93 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py
@@ -38,14 +38,14 @@ class VectorLaplaceDiagTest(test.TestCase):
def testScalarParams(self):
mu = -1.
diag = -5.
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, "at least 1 dimension"):
ds.VectorLaplaceDiag(mu, diag)
def testVectorParams(self):
mu = [-1.]
diag = [-5.]
- with self.test_session():
+ with self.cached_session():
dist = ds.VectorLaplaceDiag(mu, diag, validate_args=True)
self.assertAllEqual([3, 1], dist.sample(3).get_shape())
@@ -56,7 +56,7 @@ class VectorLaplaceDiagTest(test.TestCase):
# Batch shape = [1], event shape = [3]
mu = array_ops.zeros((1, 3))
diag = array_ops.ones((1, 3))
- with self.test_session():
+ with self.cached_session():
base_dist = ds.VectorLaplaceDiag(mu, diag, validate_args=True)
dist = ds.TransformedDistribution(
base_dist,
@@ -68,21 +68,21 @@ class VectorLaplaceDiagTest(test.TestCase):
def testMean(self):
mu = [-1., 1]
diag = [1., -5]
- with self.test_session():
+ with self.cached_session():
dist = ds.VectorLaplaceDiag(mu, diag, validate_args=True)
self.assertAllEqual(mu, dist.mean().eval())
def testMeanWithBroadcastLoc(self):
mu = [-1.]
diag = [1., -5]
- with self.test_session():
+ with self.cached_session():
dist = ds.VectorLaplaceDiag(mu, diag, validate_args=True)
self.assertAllEqual([-1., -1.], dist.mean().eval())
def testSample(self):
mu = [-1., 1]
diag = [1., -2]
- with self.test_session():
+ with self.cached_session():
dist = ds.VectorLaplaceDiag(mu, diag, validate_args=True)
samps = dist.sample(int(1e4), seed=0).eval()
cov_mat = 2. * array_ops.matrix_diag(diag).eval()**2
@@ -95,7 +95,7 @@ class VectorLaplaceDiagTest(test.TestCase):
def testSingularScaleRaises(self):
mu = [-1., 1]
diag = [1., 0]
- with self.test_session():
+ with self.cached_session():
dist = ds.VectorLaplaceDiag(mu, diag, validate_args=True)
with self.assertRaisesOpError("Singular"):
dist.sample().eval()
@@ -107,7 +107,7 @@ class VectorLaplaceDiagTest(test.TestCase):
# diag corresponds to no batches of 3-variate normals
diag = np.ones([3])
- with self.test_session():
+ with self.cached_session():
dist = ds.VectorLaplaceDiag(mu, diag, validate_args=True)
mean = dist.mean()
@@ -126,7 +126,7 @@ class VectorLaplaceDiagTest(test.TestCase):
atol=0.10, rtol=0.05)
def testCovariance(self):
- with self.test_session():
+ with self.cached_session():
vla = ds.VectorLaplaceDiag(
loc=array_ops.zeros([2, 3], dtype=dtypes.float32))
self.assertAllClose(
@@ -162,7 +162,7 @@ class VectorLaplaceDiagTest(test.TestCase):
vla.covariance().eval())
def testVariance(self):
- with self.test_session():
+ with self.cached_session():
vla = ds.VectorLaplaceDiag(
loc=array_ops.zeros([2, 3], dtype=dtypes.float32))
self.assertAllClose(
@@ -187,7 +187,7 @@ class VectorLaplaceDiagTest(test.TestCase):
vla.variance().eval())
def testStddev(self):
- with self.test_session():
+ with self.cached_session():
vla = ds.VectorLaplaceDiag(
loc=array_ops.zeros([2, 3], dtype=dtypes.float32))
self.assertAllClose(
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_sinh_arcsinh_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_sinh_arcsinh_diag_test.py
index 2bc6a926dd..0dd7d23eb0 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/vector_sinh_arcsinh_diag_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_sinh_arcsinh_diag_test.py
@@ -35,7 +35,7 @@ class VectorSinhArcsinhDiagTest(test_util.VectorDistributionTestHelpers,
scale_diag = rng.rand(d)
scale_identity_multiplier = np.float64(1.0)
loc = rng.randn(d)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
norm = ds.MultivariateNormalDiag(
loc=loc,
scale_diag=scale_diag,
@@ -65,7 +65,7 @@ class VectorSinhArcsinhDiagTest(test_util.VectorDistributionTestHelpers,
scale_diag = rng.rand(d)
scale_identity_multiplier = np.float64(1.2)
loc = rng.randn(d)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
vlap = ds.VectorLaplaceDiag(
loc=loc,
scale_diag=scale_diag,
@@ -96,7 +96,7 @@ class VectorSinhArcsinhDiagTest(test_util.VectorDistributionTestHelpers,
scale_diag = rng.rand(d)
scale_identity_multiplier = np.float64(0.9)
loc = rng.randn(d)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
norm = ds.MultivariateNormalDiag(
loc=loc,
scale_diag=scale_diag,
@@ -141,7 +141,7 @@ class VectorSinhArcsinhDiagTest(test_util.VectorDistributionTestHelpers,
scale_diag = rng.rand(d)
scale_identity_multiplier = np.float64(1.0)
loc = rng.randn(d)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
norm = ds.MultivariateNormalDiag(
loc=loc,
scale_diag=scale_diag,
@@ -186,7 +186,7 @@ class VectorSinhArcsinhDiagTest(test_util.VectorDistributionTestHelpers,
scale_diag = rng.rand(d)
scale_identity_multiplier = np.float64(1.0)
loc = rng.randn(d)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sasnorm = ds.VectorSinhArcsinhDiag(
loc=loc,
scale_diag=scale_diag,
@@ -201,7 +201,7 @@ class VectorSinhArcsinhDiagTest(test_util.VectorDistributionTestHelpers,
b, d = 5, 2
scale_diag = rng.rand(b, d)
scale_identity_multiplier = np.float64(1.1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sasnorm = ds.VectorSinhArcsinhDiag(
scale_diag=scale_diag,
scale_identity_multiplier=scale_identity_multiplier,
@@ -228,7 +228,7 @@ class VectorSinhArcsinhDiagTest(test_util.VectorDistributionTestHelpers,
d = 3
scale_diag = rng.rand(d)
scale_identity_multiplier = np.float64(1.1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sasnorm = ds.VectorSinhArcsinhDiag(
scale_diag=scale_diag,
scale_identity_multiplier=scale_identity_multiplier,
@@ -252,7 +252,7 @@ class VectorSinhArcsinhDiagTest(test_util.VectorDistributionTestHelpers,
rtol=0.1)
def test_pdf_reflected_for_negative_skewness(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sas_pos_skew = ds.VectorSinhArcsinhDiag(
loc=[0.],
scale_identity_multiplier=1.,
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py
index b8a3a262ce..aaec1f09d9 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py
@@ -75,7 +75,7 @@ class VectorStudentTTest(test.TestCase):
self._rng = np.random.RandomState(42)
def testProbStaticScalar(self):
- with self.test_session():
+ with self.cached_session():
# Scalar batch_shape.
df = np.asarray(3., dtype=np.float32)
# Scalar batch_shape.
@@ -116,7 +116,7 @@ class VectorStudentTTest(test.TestCase):
expected_mst = _FakeVectorStudentT(
df=df, loc=loc, scale_tril=scale_tril)
- with self.test_session():
+ with self.cached_session():
actual_mst = _VectorStudentT(df=df, loc=loc, scale_diag=scale_diag,
validate_args=True)
self.assertAllClose(expected_mst.log_prob(x),
@@ -145,7 +145,7 @@ class VectorStudentTTest(test.TestCase):
expected_mst = _FakeVectorStudentT(
df=df, loc=loc, scale_tril=scale_tril)
- with self.test_session():
+ with self.cached_session():
df_pl = array_ops.placeholder(dtypes.float32, name="df")
loc_pl = array_ops.placeholder(dtypes.float32, name="loc")
scale_diag_pl = array_ops.placeholder(dtypes.float32, name="scale_diag")
@@ -180,7 +180,7 @@ class VectorStudentTTest(test.TestCase):
loc=loc,
scale_tril=scale_tril)
- with self.test_session():
+ with self.cached_session():
actual_mst = _VectorStudentT(df=df, loc=loc, scale_diag=scale_diag,
validate_args=True)
self.assertAllClose(expected_mst.log_prob(x),
@@ -211,7 +211,7 @@ class VectorStudentTTest(test.TestCase):
loc=loc,
scale_tril=scale_tril)
- with self.test_session():
+ with self.cached_session():
df_pl = array_ops.placeholder(dtypes.float32, name="df")
loc_pl = array_ops.placeholder(dtypes.float32, name="loc")
scale_diag_pl = array_ops.placeholder(dtypes.float32, name="scale_diag")
@@ -240,7 +240,7 @@ class VectorStudentTTest(test.TestCase):
scale_tril=np.tile(scale_tril[array_ops.newaxis, :, :],
reps=[len(df), 1, 1]))
- with self.test_session():
+ with self.cached_session():
actual_mst = _VectorStudentT(df=df, loc=loc, scale_diag=scale_diag,
validate_args=True)
self.assertAllClose(expected_mst.log_prob(x),
@@ -266,7 +266,7 @@ class VectorStudentTTest(test.TestCase):
scale_tril=np.tile(scale_tril[array_ops.newaxis, :, :],
reps=[len(df), 1, 1]))
- with self.test_session():
+ with self.cached_session():
df_pl = array_ops.placeholder(dtypes.float32, name="df")
loc_pl = array_ops.placeholder(dtypes.float32, name="loc")
scale_diag_pl = array_ops.placeholder(dtypes.float32, name="scale_diag")
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py
index dcecce981f..a60056c444 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py
@@ -52,7 +52,7 @@ def wishart_var(df, x):
class WishartCholeskyTest(test.TestCase):
def testEntropy(self):
- with self.test_session():
+ with self.cached_session():
scale = make_pd(1., 2)
df = 4
w = distributions.WishartCholesky(df, chol(scale))
@@ -64,7 +64,7 @@ class WishartCholeskyTest(test.TestCase):
self.assertAllClose(0.78375711047393404, w.entropy().eval())
def testMeanLogDetAndLogNormalizingConstant(self):
- with self.test_session():
+ with self.cached_session():
def entropy_alt(w):
return (
@@ -80,35 +80,35 @@ class WishartCholeskyTest(test.TestCase):
self.assertAllClose(w.entropy().eval(), entropy_alt(w))
def testMean(self):
- with self.test_session():
+ with self.cached_session():
scale = make_pd(1., 2)
df = 4
w = distributions.WishartCholesky(df, chol(scale))
self.assertAllEqual(df * scale, w.mean().eval())
def testMode(self):
- with self.test_session():
+ with self.cached_session():
scale = make_pd(1., 2)
df = 4
w = distributions.WishartCholesky(df, chol(scale))
self.assertAllEqual((df - 2. - 1.) * scale, w.mode().eval())
def testStd(self):
- with self.test_session():
+ with self.cached_session():
scale = make_pd(1., 2)
df = 4
w = distributions.WishartCholesky(df, chol(scale))
self.assertAllEqual(chol(wishart_var(df, scale)), w.stddev().eval())
def testVariance(self):
- with self.test_session():
+ with self.cached_session():
scale = make_pd(1., 2)
df = 4
w = distributions.WishartCholesky(df, chol(scale))
self.assertAllEqual(wishart_var(df, scale), w.variance().eval())
def testSample(self):
- with self.test_session():
+ with self.cached_session():
scale = make_pd(1., 2)
df = 4
@@ -161,7 +161,7 @@ class WishartCholeskyTest(test.TestCase):
# Test that sampling with the same seed twice gives the same results.
def testSampleMultipleTimes(self):
- with self.test_session():
+ with self.cached_session():
df = 4.
n_val = 100
@@ -184,7 +184,7 @@ class WishartCholeskyTest(test.TestCase):
self.assertAllClose(samples1, samples2)
def testProb(self):
- with self.test_session():
+ with self.cached_session():
# Generate some positive definite (pd) matrices and their Cholesky
# factorizations.
x = np.array(
@@ -271,7 +271,7 @@ class WishartCholeskyTest(test.TestCase):
w.log_prob(np.reshape(x, (2, 2, 2, 2))).get_shape())
def testBatchShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
scale = make_pd(1., 2)
chol_scale = chol(scale)
@@ -295,7 +295,7 @@ class WishartCholeskyTest(test.TestCase):
feed_dict={scale_deferred: [chol_scale, chol_scale]}))
def testEventShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
scale = make_pd(1., 2)
chol_scale = chol(scale)
@@ -320,7 +320,7 @@ class WishartCholeskyTest(test.TestCase):
feed_dict={scale_deferred: [chol_scale, chol_scale]}))
def testValidateArgs(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
df_deferred = array_ops.placeholder(dtypes.float32)
chol_scale_deferred = array_ops.placeholder(dtypes.float32)
x = make_pd(1., 3)
@@ -374,7 +374,7 @@ class WishartCholeskyTest(test.TestCase):
chol_scale_deferred: np.ones((3, 3))})
def testStaticAsserts(self):
- with self.test_session():
+ with self.cached_session():
x = make_pd(1., 3)
chol_scale = chol(x)
@@ -404,7 +404,7 @@ class WishartCholeskyTest(test.TestCase):
batch_shape + [dims, dims])
wishart = distributions.WishartFull(df=5, scale=scale)
x = wishart.sample(sample_shape, seed=42)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x_ = sess.run(x)
expected_shape = sample_shape + batch_shape + [dims, dims]
self.assertAllEqual(expected_shape, x.shape)
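Note: a hedged sketch of the first-moment identity that testMean checks above: WishartCholesky(df, chol(scale)) has mean df * scale. The test helpers make_pd/chol are replaced here with NumPy equivalents:

import numpy as np
import tensorflow as tf
distributions = tf.contrib.distributions

scale = np.array([[2., 1.], [1., 2.]])  # a positive-definite matrix
df = 4
w = distributions.WishartCholesky(df, np.linalg.cholesky(scale))
with tf.Session() as sess:
  print(sess.run(w.mean()))  # equals df * scale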
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD
index f7933639a0..84517b57c7 100644
--- a/tensorflow/contrib/eager/python/BUILD
+++ b/tensorflow/contrib/eager/python/BUILD
@@ -14,6 +14,7 @@ py_library(
":datasets",
":metrics",
":network",
+ ":remote",
":saver",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
@@ -104,7 +105,6 @@ cuda_py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
- "//tensorflow/python/eager:graph_callable",
"//tensorflow/python/eager:test",
"//tensorflow/python:variables",
],
@@ -224,11 +224,24 @@ py_test(
],
)
+py_library(
+ name = "remote",
+ srcs = ["remote.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:platform",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
py_test(
name = "remote_test",
srcs = ["remote_test.py"],
srcs_version = "PY2AND3",
deps = [
+ ":remote",
"//tensorflow/contrib/eager/python:tfe",
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
diff --git a/tensorflow/contrib/eager/python/evaluator_test.py b/tensorflow/contrib/eager/python/evaluator_test.py
index 7d2274db9b..48d093e075 100644
--- a/tensorflow/contrib/eager/python/evaluator_test.py
+++ b/tensorflow/contrib/eager/python/evaluator_test.py
@@ -117,7 +117,7 @@ class EvaluatorTest(test.TestCase):
self.assertEqual(6.0, results["mean"].numpy())
def testDatasetGraph(self):
- with context.graph_mode(), ops.Graph().as_default(), self.test_session():
+ with context.graph_mode(), ops.Graph().as_default(), self.cached_session():
e = SimpleEvaluator(IdentityModel())
ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0])
init_op, call_op, results_op = e.evaluate_on_dataset(ds)
@@ -126,7 +126,7 @@ class EvaluatorTest(test.TestCase):
self.assertEqual(6.0, results["mean"])
def testWriteSummariesGraph(self):
- with context.graph_mode(), ops.Graph().as_default(), self.test_session():
+ with context.graph_mode(), ops.Graph().as_default(), self.cached_session():
e = SimpleEvaluator(IdentityModel())
ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0])
training_util.get_or_create_global_step()
diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py
index 0736ed02b7..e5058bfd94 100644
--- a/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py
+++ b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py
@@ -218,7 +218,7 @@ class DensenetBenchmark(tf.test.Benchmark):
tf.constant(1.).cpu()
def _benchmark_eager_apply(self, label, device_and_format, defun=False,
- execution_mode=None, compiled=False):
+ execution_mode=None):
with tfe.execution_mode(execution_mode):
device, data_format = device_and_format
model = densenet.DenseNet(self.depth, self.growth_rate, self.num_blocks,
@@ -228,7 +228,7 @@ class DensenetBenchmark(tf.test.Benchmark):
weight_decay=1e-4, dropout_rate=0,
pool_initial=True, include_top=True)
if defun:
- model.call = tfe.defun(model.call, compiled=compiled)
+ model.call = tfe.defun(model.call)
batch_size = 64
num_burn = 5
num_iters = 30
@@ -264,8 +264,7 @@ class DensenetBenchmark(tf.test.Benchmark):
make_iterator,
device_and_format,
defun=False,
- execution_mode=None,
- compiled=False):
+ execution_mode=None):
with tfe.execution_mode(execution_mode):
device, data_format = device_and_format
for batch_size in self._train_batch_sizes():
@@ -279,8 +278,8 @@ class DensenetBenchmark(tf.test.Benchmark):
optimizer = tf.train.GradientDescentOptimizer(0.1)
apply_grads = apply_gradients
if defun:
- model.call = tfe.defun(model.call, compiled=compiled)
- apply_grads = tfe.defun(apply_gradients, compiled=compiled)
+ model.call = tfe.defun(model.call)
+ apply_grads = tfe.defun(apply_gradients)
num_burn = 3
num_iters = 10
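Note: the hunks above drop the experimental compiled= keyword, which was removed from tfe.defun; call sites now pass only the Python function. A minimal sketch, assuming the contrib.eager API of this era:

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

def square_sum(x):
  return tf.reduce_sum(x * x)

fast = tfe.defun(square_sum)  # previously: tfe.defun(square_sum, compiled=...)
print(fast(tf.ones([2, 2])))  # 4.0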
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
index 1a5a186e7a..529c99b37c 100644
--- a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
@@ -66,7 +66,7 @@
"\n",
"[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n",
"\n",
- "Our goal is generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n",
+ "Our goal is to generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention-based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n",
"\n",
"![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n",
"\n",
@@ -128,7 +128,7 @@
"source": [
"## Download and prepare the MS-COCO dataset\n",
"\n",
- "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code code below will download and extract the dataset automatically. \n",
+ "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code below will download and extract the dataset automatically. \n",
"\n",
"**Caution: large download ahead**. We'll use the training set, it's a 13GB file."
]
@@ -1056,7 +1056,7 @@
"\n",
" attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n",
"\n",
- " predicted_id = tf.multinomial(tf.exp(predictions), num_samples=1)[0][0].numpy()\n",
+ " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n",
" result.append(index_word[predicted_id])\n",
"\n",
" if index_word[predicted_id] == '<end>':\n",
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
index 027097908f..40bc098724 100644
--- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
@@ -610,7 +610,7 @@
"\n",
" # using a multinomial distribution to predict the word returned by the model\n",
" predictions = predictions / temperature\n",
- " predicted_id = tf.multinomial(tf.exp(predictions), num_samples=1)[0][0].numpy()\n",
+ " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n",
" \n",
" # We pass the predicted word as the next input to the model\n",
" # along with the previous hidden state\n",
diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
index 08d8364978..f1e1f99c57 100644
--- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
+++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
@@ -466,10 +466,10 @@
" # passing the concatenated vector to the GRU\n",
" output, state = self.gru(x)\n",
" \n",
- " # output shape == (batch_size * max_length, hidden_size)\n",
+ " # output shape == (batch_size * 1, hidden_size)\n",
" output = tf.reshape(output, (-1, output.shape[2]))\n",
" \n",
- " # output shape == (batch_size * max_length, vocab)\n",
+ " # output shape == (batch_size * 1, vocab)\n",
" x = self.fc(output)\n",
" \n",
" return x, state, attention_weights\n",
@@ -677,7 +677,7 @@
" attention_weights = tf.reshape(attention_weights, (-1, ))\n",
" attention_plot[t] = attention_weights.numpy()\n",
"\n",
- " predicted_id = tf.multinomial(tf.exp(predictions), num_samples=1)[0][0].numpy()\n",
+ " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n",
"\n",
" result += targ_lang.idx2word[predicted_id] + ' '\n",
"\n",
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/README.md b/tensorflow/contrib/eager/python/examples/notebooks/README.md
index 0d5ed84894..2778b228e9 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/README.md
+++ b/tensorflow/contrib/eager/python/examples/notebooks/README.md
@@ -1,11 +1,3 @@
-## Research and experimentation
-
-Eager execution provides an imperative, define-by-run interface for advanced
-operations. Write custom layers, forward passes, and training loops with auto
-differentiation. Start with these notebooks, then read the
-[eager execution guide](https://www.tensorflow.org/guide/eager).
-
-1. [Eager execution basics](./eager_basics.ipynb)
-2. [Automatic differentiation and gradient tapes](./automatic_differentiation.ipynb)
-3. [Custom training: basics](./custom_training.ipynb)
-4. [Custom layers](./custom_layers.ipynb)
+The notebooks have been moved to the
+[tensorflow/docs](https://github.com/tensorflow/docs/tree/master/site/en/tutorials/eager)
+repository.
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb
index 51b7ffc4de..8fae622e12 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb
@@ -15,12 +15,7 @@
"execution_count": 0,
"metadata": {
"cellView": "form",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "GCCk8_dHpuNf"
},
@@ -53,308 +48,35 @@
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
- "id": "idv0bPeCp325"
- },
- "source": [
- "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n",
- "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb\"\u003e\n",
- " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
- "\u003c/td\u003e\u003ctd\u003e\n",
- "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text",
- "id": "vDJ4XzMqodTy"
- },
- "source": [
- "In the previous tutorial we introduced `Tensor`s and operations on them. In this tutorial we will cover [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation), a key technique for optimizing machine learning models."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text",
- "id": "GQJysDM__Qb0"
- },
- "source": [
- "## Setup\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "colab_type": "code",
- "id": "OiMPZStlibBv"
- },
- "outputs": [],
- "source": [
- "import tensorflow as tf\n",
- "tf.enable_eager_execution()\n",
- "\n",
- "tfe = tf.contrib.eager # Shorthand for some symbols"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text",
- "id": "1CLWJl0QliB0"
- },
- "source": [
- "## Derivatives of a function\n",
- "\n",
- "TensorFlow provides APIs for automatic differentiation - computing the derivative of a function. The way that more closely mimics the math is to encapsulate the computation in a Python function, say `f`, and use `tfe.gradients_function` to create a function that computes the derivatives of `f` with respect to its arguments. If you're familiar with [autograd](https://github.com/HIPS/autograd) for differentiating numpy functions, this will be familiar. For example: "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "colab_type": "code",
- "id": "9FViq92UX7P8"
- },
- "outputs": [],
- "source": [
- "from math import pi\n",
- "\n",
- "def f(x):\n",
- " return tf.square(tf.sin(x))\n",
- "\n",
- "assert f(pi/2).numpy() == 1.0\n",
- "\n",
- "\n",
- "# grad_f will return a list of derivatives of f\n",
- "# with respect to its arguments. Since f() has a single argument,\n",
- "# grad_f will return a list with a single element.\n",
- "grad_f = tfe.gradients_function(f)\n",
- "assert tf.abs(grad_f(pi/2)[0]).numpy() \u003c 1e-7"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text",
- "id": "v9fPs8RyopCf"
- },
- "source": [
- "### Higher-order gradients\n",
- "\n",
- "The same API can be used to differentiate as many times as you like:\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "colab_type": "code",
- "id": "3D0ZvnGYo0rW"
- },
- "outputs": [],
- "source": [
- "def f(x):\n",
- " return tf.square(tf.sin(x))\n",
- "\n",
- "def grad(f):\n",
- " return lambda x: tfe.gradients_function(f)(x)[0]\n",
- "\n",
- "x = tf.lin_space(-2*pi, 2*pi, 100) # 100 points between -2π and +2π\n",
- "\n",
- "import matplotlib.pyplot as plt\n",
- "\n",
- "plt.plot(x, f(x), label=\"f\")\n",
- "plt.plot(x, grad(f)(x), label=\"first derivative\")\n",
- "plt.plot(x, grad(grad(f))(x), label=\"second derivative\")\n",
- "plt.plot(x, grad(grad(grad(f)))(x), label=\"third derivative\")\n",
- "plt.legend()\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text",
- "id": "-39gouo7mtgu"
- },
- "source": [
- "## Gradient tapes\n",
- "\n",
- "Every differentiable TensorFlow operation has an associated gradient function. For example, the gradient function of `tf.square(x)` would be a function that returns `2.0 * x`. To compute the gradient of a user-defined function (like `f(x)` in the example above), TensorFlow first \"records\" all the operations applied to compute the output of the function. We call this record a \"tape\". It then uses that tape and the gradients functions associated with each primitive operation to compute the gradients of the user-defined function using [reverse mode differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation).\n",
- "\n",
- "Since operations are recorded as they are executed, Python control flow (using `if`s and `while`s for example) is naturally handled:\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "colab_type": "code",
- "id": "MH0UfjympWf7"
- },
- "outputs": [],
- "source": [
- "def f(x, y):\n",
- " output = 1\n",
- " # Must use range(int(y)) instead of range(y) in Python 3 when\n",
- " # using TensorFlow 1.10 and earlier. Can use range(y) in 1.11+\n",
- " for i in range(int(y)):\n",
- " output = tf.multiply(output, x)\n",
- " return output\n",
- "\n",
- "def g(x, y):\n",
- " # Return the gradient of `f` with respect to it's first parameter\n",
- " return tfe.gradients_function(f)(x, y)[0]\n",
- "\n",
- "assert f(3.0, 2).numpy() == 9.0 # f(x, 2) is essentially x * x\n",
- "assert g(3.0, 2).numpy() == 6.0 # And its gradient will be 2 * x\n",
- "assert f(4.0, 3).numpy() == 64.0 # f(x, 3) is essentially x * x * x\n",
- "assert g(4.0, 3).numpy() == 48.0 # And its gradient will be 3 * x * x"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text",
- "id": "aNmR5-jhpX2t"
- },
- "source": [
- "At times it may be inconvenient to encapsulate computation of interest into a function. For example, if you want the gradient of the output with respect to intermediate values computed in the function. In such cases, the slightly more verbose but explicit [tf.GradientTape](https://www.tensorflow.org/api_docs/python/tf/GradientTape) context is useful. All computation inside the context of a `tf.GradientTape` is \"recorded\".\n",
- "\n",
- "For example:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "colab_type": "code",
- "id": "bAFeIE8EuVIq"
+ "id": "clNGnJ3u8Rl6"
},
- "outputs": [],
"source": [
- "x = tf.ones((2, 2))\n",
- " \n",
- "# TODO(b/78880779): Remove the 'persistent=True' argument and use\n",
- "# a single t.gradient() call when the bug is resolved.\n",
- "with tf.GradientTape(persistent=True) as t:\n",
- " # TODO(ashankar): Explain with \"watch\" argument better?\n",
- " t.watch(x)\n",
- " y = tf.reduce_sum(x)\n",
- " z = tf.multiply(y, y)\n",
- "\n",
- "# Use the same tape to compute the derivative of z with respect to the\n",
- "# intermediate value y.\n",
- "dz_dy = t.gradient(z, y)\n",
- "assert dz_dy.numpy() == 8.0\n",
- "\n",
- "# Derivative of z with respect to the original input tensor x\n",
- "dz_dx = t.gradient(z, x)\n",
- "for i in [0, 1]:\n",
- " for j in [0, 1]:\n",
- " assert dz_dx[i][j].numpy() == 8.0"
+ "This file has moved."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
- "id": "DK05KXrAAld3"
- },
- "source": [
- "### Higher-order gradients\n",
- "\n",
- "Operations inside of the `GradientTape` context manager are recorded for automatic differentiation. If gradients are computed in that context, then the gradient computation is recorded as well. As a result, the exact same API works for higher-order gradients as well. For example:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "colab_type": "code",
- "id": "cPQgthZ7ugRJ"
- },
- "outputs": [],
- "source": [
- "# TODO(ashankar): Should we use the persistent tape here instead? Follow up on Tom and Alex's discussion\n",
- "\n",
- "x = tf.constant(1.0) # Convert the Python 1.0 to a Tensor object\n",
- "\n",
- "with tf.GradientTape() as t:\n",
- " with tf.GradientTape() as t2:\n",
- " t2.watch(x)\n",
- " y = x * x * x\n",
- " # Compute the gradient inside the 't' context manager\n",
- " # which means the gradient computation is differentiable as well.\n",
- " dy_dx = t2.gradient(y, x)\n",
- "d2y_dx2 = t.gradient(dy_dx, x)\n",
- "\n",
- "assert dy_dx.numpy() == 3.0\n",
- "assert d2y_dx2.numpy() == 6.0"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text",
- "id": "4U1KKzUpNl58"
+ "id": "idv0bPeCp325"
},
"source": [
- "## Next Steps\n",
- "\n",
-        "In this tutorial we covered gradient computation in TensorFlow. With that we have enough of the primitives required to build and train neural networks, which we will cover in the [next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/3_neural_networks.ipynb)."
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n",
+ "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb\"\u003e\n",
+ " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
+ "\u003c/td\u003e\u003ctd\u003e\n",
+ "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
- "default_view": {},
"name": "automatic_differentiation.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true,
- "version": "0.3.2",
- "views": {}
+ "version": "0.3.2"
},
"kernelspec": {
"display_name": "Python 3",
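The deleted automatic_differentiation cells above lean on `persistent=True` only to work around the bug tracked in their TODO (b/78880779). A minimal sketch of the single-call form that TODO anticipates, assuming the TF 1.x eager APIs used throughout these notebooks:

import tensorflow as tf

tf.enable_eager_execution()

x = tf.ones((2, 2))
with tf.GradientTape() as t:
  t.watch(x)              # x is a constant, so the tape must be told to watch it
  y = tf.reduce_sum(x)    # y = 4.0
  z = tf.multiply(y, y)   # z = y * y = 16.0

# A non-persistent tape supports exactly one gradient() call.
dz_dx = t.gradient(z, x)  # dz/dx = 2 * y = 8.0 for every element of x
print(dz_dx)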
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb
index a0bbbb6123..d89774c45e 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb
@@ -1,46 +1,25 @@
{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "name": "custom_layers.ipynb",
- "version": "0.3.2",
- "views": {},
- "default_view": {},
- "provenance": [],
- "private_outputs": true,
- "collapsed_sections": [],
- "toc_visible": true
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- }
- },
"cells": [
{
+ "cell_type": "markdown",
"metadata": {
- "id": "tDnwEv8FtJm7",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "tDnwEv8FtJm7"
},
- "cell_type": "markdown",
"source": [
"##### Copyright 2018 The TensorFlow Authors."
]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "JlknJBWQtKkI",
+ "cellView": "form",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "cellView": "form"
+ "id": "JlknJBWQtKkI"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
@@ -53,347 +32,57 @@
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "60RdWsg1tETW",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "# Custom layers"
- ]
- },
- {
- "metadata": {
- "id": "BcJg7Enms86w",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "<table class=\"tfo-notebook-buttons\" align=\"left\"><td>\n",
- "<a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb\">\n",
- " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
- "</td><td>\n",
- "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
- ]
- },
- {
- "metadata": {
- "id": "UEu3q4jmpKVT",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "We recommend using `tf.keras` as a high-level API for building neural networks. That said, most TensorFlow APIs are usable with eager execution.\n"
]
},
{
- "metadata": {
- "id": "pwX7Fii1rwsJ",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "import tensorflow as tf\n",
- "tfe = tf.contrib.eager\n",
- "\n",
- "tf.enable_eager_execution()"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "zSFfVVjkrrsI",
- "colab_type": "text"
- },
"cell_type": "markdown",
- "source": [
- "## Layers: common sets of useful operations\n",
- "\n",
- "Most of the time when writing code for machine learning models you want to operate at a higher level of abstraction than individual operations and manipulation of individual variables.\n",
- "\n",
-        "Many machine learning models are expressible as the composition and stacking of relatively simple layers, and TensorFlow provides both a set of many common layers as well as easy ways for you to write your own application-specific layers, either from scratch or as the composition of existing layers.\n",
- "\n",
- "TensorFlow includes the full [Keras](https://keras.io) API in the tf.keras package, and the Keras layers are very useful when building your own models.\n"
- ]
- },
- {
"metadata": {
- "id": "8PyXlPl-4TzQ",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "colab_type": "text",
+ "id": "60RdWsg1tETW"
},
- "cell_type": "code",
- "source": [
- "# In the tf.keras.layers package, layers are objects. To construct a layer,\n",
- "# simply construct the object. Most layers take as a first argument the number\n",
- "# of output dimensions / channels.\n",
- "layer = tf.keras.layers.Dense(100)\n",
- "# The number of input dimensions is often unnecessary, as it can be inferred\n",
- "# the first time the layer is used, but it can be provided if you want to \n",
- "# specify it manually, which is useful in some complex models.\n",
- "layer = tf.keras.layers.Dense(10, input_shape=(None, 5))"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "Fn69xxPO5Psr",
- "colab_type": "text"
- },
- "cell_type": "markdown",
"source": [
- "The full list of pre-existing layers can be seen in [the documentation](https://www.tensorflow.org/api_docs/python/tf/keras/layers). It includes Dense (a fully-connected layer),\n",
- "Conv2D, LSTM, BatchNormalization, Dropout, and many others."
+ "# Custom layers"
]
},
{
- "metadata": {
- "id": "E3XKNknP5Mhb",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "# To use a layer, simply call it.\n",
- "layer(tf.zeros([10, 5]))"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "Wt_Nsv-L5t2s",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "# Layers have many useful methods. For example, you can inspect all variables\n",
- "# in a layer by calling layer.variables. In this case a fully-connected layer\n",
- "# will have variables for weights and biases.\n",
- "layer.variables"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "6ilvKjz8_4MQ",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "# The variables are also accessible through nice accessors\n",
- "layer.kernel, layer.bias"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "O0kDbE54-5VS",
- "colab_type": "text"
- },
"cell_type": "markdown",
- "source": [
- "## Implementing custom layers\n",
-        "The best way to implement your own layer is to extend the tf.keras.layers.Layer class and implement:\n",
-        "  * `__init__`, where you can do all input-independent initialization\n",
- " * `build`, where you know the shapes of the input tensors and can do the rest of the initialization\n",
- " * `call`, where you do the forward computation\n",
- "\n",
-        "Note that you don't have to wait until `build` is called to create your variables; you can also create them in `__init__`. However, the advantage of creating them in `build` is that it enables late variable creation based on the shape of the inputs the layer will operate on. On the other hand, creating variables in `__init__` means that the shapes required to create the variables must be specified explicitly."
- ]
- },
- {
- "metadata": {
- "id": "5Byl3n1k5kIy",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "class MyDenseLayer(tf.keras.layers.Layer):\n",
- " def __init__(self, num_outputs):\n",
- " super(MyDenseLayer, self).__init__()\n",
- " self.num_outputs = num_outputs\n",
- " \n",
- " def build(self, input_shape):\n",
- " self.kernel = self.add_variable(\"kernel\", \n",
- " shape=[input_shape[-1].value, \n",
- " self.num_outputs])\n",
- " \n",
- " def call(self, input):\n",
- " return tf.matmul(input, self.kernel)\n",
- " \n",
- "layer = MyDenseLayer(10)\n",
- "print(layer(tf.zeros([10, 5])))\n",
- "print(layer.variables)"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
"metadata": {
- "id": "tk8E2vY0-z4Z",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "9sFn_RV_8zM-"
},
- "cell_type": "markdown",
"source": [
-        "Note that you don't have to wait until `build` is called to create your variables; you can also create them in `__init__`.\n",
-        "\n",
-        "Overall, code is easier to read and maintain if it uses standard layers whenever possible, as other readers will be familiar with the behavior of standard layers. If you want to use a layer which is not present in tf.keras.layers or tf.contrib.layers, consider filing a [GitHub issue](http://github.com/tensorflow/tensorflow/issues/new) or, even better, sending us a pull request!"
+ "This file has moved."
]
},
{
- "metadata": {
- "id": "Qhg4KlbKrs3G",
- "colab_type": "text"
- },
"cell_type": "markdown",
- "source": [
- "## Models: composing layers\n",
- "\n",
-        "Many interesting layer-like things in machine learning models are implemented by composing existing layers. For example, each residual block in a ResNet is a composition of convolutions, batch normalizations, and a shortcut.\n",
-        "\n",
-        "The main class used when creating a layer-like thing which contains other layers is tf.keras.Model; implement one by inheriting from it."
- ]
- },
- {
- "metadata": {
- "id": "N30DTXiRASlb",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "class ResnetIdentityBlock(tf.keras.Model):\n",
- " def __init__(self, kernel_size, filters):\n",
- " super(ResnetIdentityBlock, self).__init__(name='')\n",
- " filters1, filters2, filters3 = filters\n",
- "\n",
- " self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))\n",
- " self.bn2a = tf.keras.layers.BatchNormalization()\n",
- "\n",
- " self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')\n",
- " self.bn2b = tf.keras.layers.BatchNormalization()\n",
- "\n",
- " self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))\n",
- " self.bn2c = tf.keras.layers.BatchNormalization()\n",
- "\n",
- " def call(self, input_tensor, training=False):\n",
- " x = self.conv2a(input_tensor)\n",
- " x = self.bn2a(x, training=training)\n",
- " x = tf.nn.relu(x)\n",
- "\n",
- " x = self.conv2b(x)\n",
- " x = self.bn2b(x, training=training)\n",
- " x = tf.nn.relu(x)\n",
- "\n",
- " x = self.conv2c(x)\n",
- " x = self.bn2c(x, training=training)\n",
- "\n",
- " x += input_tensor\n",
- " return tf.nn.relu(x)\n",
- "\n",
- " \n",
- "block = ResnetIdentityBlock(1, [1, 2, 3])\n",
- "print(block(tf.zeros([1, 2, 3, 3])))\n",
- "print([x.name for x in block.variables])"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
"metadata": {
- "id": "wYfucVw65PMj",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "BcJg7Enms86w"
},
- "cell_type": "markdown",
"source": [
-        "Much of the time, however, models which compose many layers simply call one layer after the other. This can be done in very little code using tf.keras.Sequential."
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n",
+ "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/eager/custom_layers.ipynb\"\u003e\n",
+ " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
+ "\u003c/td\u003e\u003ctd\u003e\n",
+ "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/custom_layers.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e"
]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "name": "custom_layers.ipynb",
+ "private_outputs": true,
+ "provenance": [],
+ "toc_visible": true,
+ "version": "0.3.2"
},
- {
- "metadata": {
- "id": "L9frk7Ur4uvJ",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
-        "my_seq = tf.keras.Sequential([tf.keras.layers.Conv2D(1, (1, 1)),\n",
-        "                              tf.keras.layers.BatchNormalization(),\n",
-        "                              tf.keras.layers.Conv2D(2, 1,\n",
-        "                                                     padding='same'),\n",
-        "                              tf.keras.layers.BatchNormalization(),\n",
-        "                              tf.keras.layers.Conv2D(3, (1, 1)),\n",
-        "                              tf.keras.layers.BatchNormalization()])\n",
-        "my_seq(tf.zeros([1, 2, 3, 3]))"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "c5YwYcnuK-wc",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "# Next steps\n",
- "\n",
- "Now you can go back to the previous notebook and adapt the linear regression example to use layers and models to be better structured."
- ]
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
}
- ]
-} \ No newline at end of file
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
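The deleted custom_layers cells contrast creating variables in `build` (shapes inferred from the first input) with creating them in `__init__` (shapes specified up front). A sketch of the `__init__` variant the prose describes, under the same TF 1.x eager assumptions as the deleted MyDenseLayer cell:

import tensorflow as tf

tf.enable_eager_execution()

class FixedDenseLayer(tf.keras.layers.Layer):
  """Like the deleted MyDenseLayer, but with the input size fixed up front."""

  def __init__(self, num_inputs, num_outputs):
    super(FixedDenseLayer, self).__init__()
    # No build() needed: the kernel shape is given explicitly here.
    self.kernel = self.add_variable("kernel",
                                    shape=[num_inputs, num_outputs])

  def call(self, inputs):
    return tf.matmul(inputs, self.kernel)

layer = FixedDenseLayer(5, 10)
print(layer(tf.zeros([10, 5])))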
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb
index 5f1b48fa0d..86dca0b423 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb
@@ -1,46 +1,25 @@
{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "name": "Custom training: basics",
- "version": "0.3.2",
- "views": {},
- "default_view": {},
- "provenance": [],
- "private_outputs": true,
- "collapsed_sections": [],
- "toc_visible": true
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- }
- },
"cells": [
{
+ "cell_type": "markdown",
"metadata": {
- "id": "5rmpybwysXGV",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "5rmpybwysXGV"
},
- "cell_type": "markdown",
"source": [
"##### Copyright 2018 The TensorFlow Authors."
]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "m8y3rGtQsYP2",
+ "cellView": "form",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "cellView": "form"
+ "id": "m8y3rGtQsYP2"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
@@ -53,425 +32,57 @@
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "hrXv0rU9sIma",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "# Custom training: basics"
- ]
- },
- {
- "metadata": {
- "id": "7S0BwJ_8sLu7",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "<table class=\"tfo-notebook-buttons\" align=\"left\"><td>\n",
- "<a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb\">\n",
- " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
- "</td><td>\n",
- "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
- ]
- },
- {
- "metadata": {
- "id": "k2o3TTG4TFpt",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "In the previous tutorial we covered the TensorFlow APIs for automatic differentiation, a basic building block for machine learning.\n",
- "In this tutorial we will use the TensorFlow primitives introduced in the prior tutorials to do some simple machine learning.\n",
- "\n",
- "TensorFlow also includes a higher-level neural networks API (`tf.keras`) which provides useful abstractions to reduce boilerplate. We strongly recommend those higher level APIs for people working with neural networks. However, in this short tutorial we cover neural network training from first principles to establish a strong foundation."
- ]
- },
- {
- "metadata": {
- "id": "3LXMVuV0VhDr",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "## Setup"
- ]
- },
- {
- "metadata": {
- "id": "PJ64L90aVir3",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "import tensorflow as tf\n",
- "\n",
- "tf.enable_eager_execution()"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "eMAWbDJFVmMk",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "## Variables\n",
- "\n",
- "Tensors in TensorFlow are immutable stateless objects. Machine learning models, however, need to have changing state: as your model trains, the same code to compute predictions should behave differently over time (hopefully with a lower loss!). To represent this state which needs to change over the course of your computation, you can choose to rely on the fact that Python is a stateful programming language:\n"
- ]
- },
- {
- "metadata": {
- "id": "VkJwtLS_Jbn8",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "# Using python state\n",
- "x = tf.zeros([10, 10])\n",
- "x += 2 # This is equivalent to x = x + 2, which does not mutate the original\n",
- " # value of x\n",
- "print(x)"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "wfneTXy7JcUz",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "TensorFlow, however, has stateful operations built in, and these are often more pleasant to use than low-level Python representations of your state. To represent weights in a model, for example, it's often convenient and efficient to use TensorFlow variables.\n",
- "\n",
-        "A Variable is an object which stores a value and, when used in a TensorFlow computation, will implicitly read from this stored value. There are operations (`tf.assign_sub`, `tf.scatter_update`, etc.) which manipulate the value stored in a TensorFlow variable."
]
},
{
- "metadata": {
- "id": "itxmrMil6DQi",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "v = tf.Variable(1.0)\n",
- "assert v.numpy() == 1.0\n",
- "\n",
- "# Re-assign the value\n",
- "v.assign(3.0)\n",
- "assert v.numpy() == 3.0\n",
- "\n",
- "# Use `v` in a TensorFlow operation like tf.square() and reassign\n",
- "v.assign(tf.square(v))\n",
- "assert v.numpy() == 9.0"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "-paSaeq1JzwC",
- "colab_type": "text"
- },
"cell_type": "markdown",
- "source": [
-        "Computations using Variables are automatically traced when computing gradients. For Variables representing embeddings, TensorFlow will do sparse updates by default, which are more computation- and memory-efficient.\n",
- "\n",
- "Using Variables is also a way to quickly let a reader of your code know that this piece of state is mutable."
- ]
- },
- {
"metadata": {
- "id": "BMiFcDzE7Qu3",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "hrXv0rU9sIma"
},
- "cell_type": "markdown",
"source": [
- "## Example: Fitting a linear model\n",
- "\n",
-        "Let's now use the concepts we have so far (`Tensor`, `GradientTape`, `Variable`) to build and train a simple model. This typically involves a few steps:\n",
- "\n",
- "1. Define the model.\n",
- "2. Define a loss function.\n",
- "3. Obtain training data.\n",
- "4. Run through the training data and use an \"optimizer\" to adjust the variables to fit the data.\n",
- "\n",
-        "In this tutorial, we'll walk through a trivial example of a simple linear model: `f(x) = x * W + b`, which has two variables, `W` and `b`. Furthermore, we'll synthesize data such that a well-trained model would have `W = 3.0` and `b = 2.0`."
- ]
- },
- {
- "metadata": {
- "id": "gFzH64Jn9PIm",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "### Define the model\n",
- "\n",
- "Let's define a simple class to encapsulate the variables and the computation."
+ "# Custom training: basics"
]
},
{
- "metadata": {
- "id": "_WRu7Pze7wk8",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "class Model(object):\n",
- " def __init__(self):\n",
- " # Initialize variable to (5.0, 0.0)\n",
- " # In practice, these should be initialized to random values.\n",
- " self.W = tf.Variable(5.0)\n",
- " self.b = tf.Variable(0.0)\n",
- " \n",
- " def __call__(self, x):\n",
- " return self.W * x + self.b\n",
- " \n",
- "model = Model()\n",
- "\n",
- "assert model(3.0).numpy() == 15.0"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "xa6j_yXa-j79",
- "colab_type": "text"
- },
"cell_type": "markdown",
- "source": [
- "### Define a loss function\n",
- "\n",
- "A loss function measures how well the output of a model for a given input matches the desired output. Let's use the standard L2 loss."
- ]
- },
- {
- "metadata": {
- "id": "Y0ysUFGY924U",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "def loss(predicted_y, desired_y):\n",
- " return tf.reduce_mean(tf.square(predicted_y - desired_y))"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
"metadata": {
- "id": "qutT_fkl_CBc",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "IGPZTmwn9IT4"
},
- "cell_type": "markdown",
"source": [
- "### Obtain training data\n",
- "\n",
- "Let's synthesize the training data with some noise."
+ "This file has moved."
]
},
{
- "metadata": {
- "id": "gxPTb-kt_N5m",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "TRUE_W = 3.0\n",
- "TRUE_b = 2.0\n",
- "NUM_EXAMPLES = 1000\n",
- "\n",
- "inputs = tf.random_normal(shape=[NUM_EXAMPLES])\n",
- "noise = tf.random_normal(shape=[NUM_EXAMPLES])\n",
- "outputs = inputs * TRUE_W + TRUE_b + noise"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "-50nq-wPBsAW",
- "colab_type": "text"
- },
"cell_type": "markdown",
- "source": [
-        "Before we train the model, let's visualize where the model stands right now. We'll plot the model's predictions in red and the training data in blue."
- ]
- },
- {
"metadata": {
- "id": "_eb83LtrB4nt",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "colab_type": "text",
+ "id": "7S0BwJ_8sLu7"
},
- "cell_type": "code",
"source": [
- "import matplotlib.pyplot as plt\n",
- "\n",
- "plt.scatter(inputs, outputs, c='b')\n",
- "plt.scatter(inputs, model(inputs), c='r')\n",
- "plt.show()\n",
- "\n",
- "print('Current loss: '),\n",
- "print(loss(model(inputs), outputs).numpy())"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "sSDP-yeq_4jE",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "### Define a training loop\n",
- "\n",
-        "We now have our network and our training data. Let's train it, i.e., use the training data to update the model's variables (`W` and `b`) via [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent) so that the loss goes down. There are many variants of the gradient descent scheme that are captured in `tf.train.Optimizer` implementations. We highly recommend using those implementations, but in the spirit of building from first principles, in this particular example we will implement the basic math ourselves."
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n",
+ "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/eager/custom_training.ipynb\"\u003e\n",
+ " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
+ "\u003c/td\u003e\u003ctd\u003e\n",
+ "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/custom_training.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e"
]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "name": "Custom training: basics",
+ "private_outputs": true,
+ "provenance": [],
+ "toc_visible": true,
+ "version": "0.3.2"
},
- {
- "metadata": {
- "id": "MBIACgdnA55X",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "def train(model, inputs, outputs, learning_rate):\n",
- " with tf.GradientTape() as t:\n",
- " current_loss = loss(model(inputs), outputs)\n",
- " dW, db = t.gradient(current_loss, [model.W, model.b])\n",
- " model.W.assign_sub(learning_rate * dW)\n",
- " model.b.assign_sub(learning_rate * db)"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "RwWPaJryD2aN",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "Finally, let's repeatedly run through the training data and see how `W` and `b` evolve."
- ]
- },
- {
- "metadata": {
- "id": "XdfkR223D9dW",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "model = Model()\n",
- "\n",
- "# Collect the history of W-values and b-values to plot later\n",
- "Ws, bs = [], []\n",
- "epochs = range(10)\n",
- "for epoch in epochs:\n",
- " Ws.append(model.W.numpy())\n",
- " bs.append(model.b.numpy())\n",
- " current_loss = loss(model(inputs), outputs)\n",
- "\n",
- " train(model, inputs, outputs, learning_rate=0.1)\n",
- " print('Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f' %\n",
- " (epoch, Ws[-1], bs[-1], current_loss))\n",
- "\n",
- "# Let's plot it all\n",
- "plt.plot(epochs, Ws, 'r',\n",
- " epochs, bs, 'b')\n",
- "plt.plot([TRUE_W] * len(epochs), 'r--',\n",
- " [TRUE_b] * len(epochs), 'b--')\n",
-        "plt.legend(['W', 'b', 'true W', 'true b'])\n",
- "plt.show()\n",
- " "
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "vPnIVuaSJwWz",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "## Next Steps\n",
- "\n",
- "In this tutorial we covered `Variable`s and built and trained a simple linear model using the TensorFlow primitives discussed so far.\n",
- "\n",
- "In theory, this is pretty much all you need to use TensorFlow for your machine learning research.\n",
-        "In practice, particularly for neural networks, the higher level APIs like `tf.keras` will be much more convenient since they provide higher level building blocks (called \"layers\"), utilities to save and restore state, a suite of loss functions, a suite of optimization strategies, etc.\n",
- "\n",
- "The [next tutorial](TODO) will cover these higher level APIs."
- ]
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
}
- ]
-} \ No newline at end of file
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
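The deleted custom_training cells implement the gradient-descent update by hand while recommending the `tf.train.Optimizer` implementations for real use. A sketch of the same training step written with an optimizer, assuming the `Model` class and `loss` function defined in those deleted cells:

import tensorflow as tf

tf.enable_eager_execution()

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)

def train(model, inputs, outputs):
  # Record the forward pass so gradients can be taken.
  with tf.GradientTape() as t:
    current_loss = loss(model(inputs), outputs)
  grads = t.gradient(current_loss, [model.W, model.b])
  # apply_gradients replaces the manual assign_sub updates.
  optimizer.apply_gradients(zip(grads, [model.W, model.b]))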
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb
index f1e13de5de..c6d1a56604 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb
@@ -1,46 +1,25 @@
{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "name": "eager_basics.ipynb",
- "version": "0.3.2",
- "views": {},
- "default_view": {},
- "provenance": [],
- "private_outputs": true,
- "collapsed_sections": [],
- "toc_visible": true
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- }
- },
"cells": [
{
+ "cell_type": "markdown",
"metadata": {
- "id": "iPpI7RaYoZuE",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "iPpI7RaYoZuE"
},
- "cell_type": "markdown",
"source": [
"##### Copyright 2018 The TensorFlow Authors."
]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "hro2InpHobKk",
+ "cellView": "form",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "cellView": "form"
+ "id": "hro2InpHobKk"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
@@ -53,439 +32,47 @@
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "U9i2Dsh-ziXr",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "# Eager execution basics"
- ]
- },
- {
- "metadata": {
- "id": "Hndw-YcxoOJK",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "<table class=\"tfo-notebook-buttons\" align=\"left\"><td>\n",
- "<a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb\">\n",
- " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
- "</td><td>\n",
- "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
- ]
- },
- {
- "metadata": {
- "id": "6sILUVbHoSgH",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "This is an introductory tutorial for using TensorFlow. It will cover:\n",
- "\n",
- "* Importing required packages\n",
- "* Creating and using Tensors\n",
- "* Using GPU acceleration\n",
- "* Datasets"
- ]
- },
- {
- "metadata": {
- "id": "z1JcS5iBXMRO",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "## Import TensorFlow\n",
- "\n",
- "To get started, import the `tensorflow` module and enable eager execution.\n",
- "Eager execution enables a more interactive frontend to TensorFlow, the details of which we will discuss much later."
- ]
- },
- {
- "metadata": {
- "id": "RlIWhyeLoYnG",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "cellView": "code"
- },
- "cell_type": "code",
- "source": [
- "import tensorflow as tf\n",
- "\n",
- "tf.enable_eager_execution()"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "H9UySOPLXdaw",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "## Tensors\n",
- "\n",
-        "A Tensor is a multi-dimensional array. Similar to NumPy `ndarray` objects, `Tensor` objects have a data type and a shape. Additionally, Tensors can reside in accelerator (like GPU) memory. TensorFlow offers a rich library of operations ([tf.add](https://www.tensorflow.org/api_docs/python/tf/add), [tf.matmul](https://www.tensorflow.org/api_docs/python/tf/matmul), [tf.linalg.inv](https://www.tensorflow.org/api_docs/python/tf/linalg/inv), etc.) that consume and produce Tensors. These operations automatically convert native Python types to Tensors. For example:\n"
- ]
- },
- {
- "metadata": {
- "id": "ngUe237Wt48W",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "cellView": "code"
- },
- "cell_type": "code",
- "source": [
- "print(tf.add(1, 2))\n",
- "print(tf.add([1, 2], [3, 4]))\n",
- "print(tf.square(5))\n",
- "print(tf.reduce_sum([1, 2, 3]))\n",
- "print(tf.encode_base64(\"hello world\"))\n",
- "\n",
- "# Operator overloading is also supported\n",
- "print(tf.square(2) + tf.square(3))"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "IDY4WsYRhP81",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "Each Tensor has a shape and a datatype"
- ]
- },
- {
- "metadata": {
- "id": "srYWH1MdJNG7",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "x = tf.matmul([[1]], [[2, 3]])\n",
- "print(x.shape)\n",
- "print(x.dtype)"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "eBPw8e8vrsom",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "The most obvious differences between NumPy arrays and TensorFlow Tensors are:\n",
- "\n",
- "1. Tensors can be backed by accelerator memory (like GPU, TPU).\n",
- "2. Tensors are immutable."
- ]
- },
- {
- "metadata": {
- "id": "Dwi1tdW3JBw6",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "### NumPy Compatibility\n",
- "\n",
-        "Conversion between TensorFlow Tensors and NumPy ndarrays is quite simple:\n",
- "* TensorFlow operations automatically convert NumPy ndarrays to Tensors.\n",
- "* NumPy operations automatically convert Tensors to NumPy ndarrays.\n",
- "\n",
- "Tensors can be explicitly converted to NumPy ndarrays by invoking the `.numpy()` method on them.\n",
- "These conversions are typically cheap as the array and Tensor share the underlying memory representation if possible. However, sharing the underlying representation isn't always possible since the Tensor may be hosted in GPU memory while NumPy arrays are always backed by host memory, and the conversion will thus involve a copy from GPU to host memory."
- ]
- },
- {
- "metadata": {
- "id": "lCUWzso6mbqR",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "import numpy as np\n",
- "\n",
- "ndarray = np.ones([3, 3])\n",
- "\n",
- "print(\"TensorFlow operations convert numpy arrays to Tensors automatically\")\n",
- "tensor = tf.multiply(ndarray, 42)\n",
- "print(tensor)\n",
- "\n",
- "\n",
- "print(\"And NumPy operations convert Tensors to numpy arrays automatically\")\n",
- "print(np.add(tensor, 1))\n",
- "\n",
- "print(\"The .numpy() method explicitly converts a Tensor to a numpy array\")\n",
- "print(tensor.numpy())"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "PBNP8yTRfu_X",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "## GPU acceleration\n",
- "\n",
- "Many TensorFlow operations can be accelerated by using the GPU for computation. Without any annotations, TensorFlow automatically decides whether to use the GPU or CPU for an operation (and copies the tensor between CPU and GPU memory if necessary). Tensors produced by an operation are typically backed by the memory of the device on which the operation executed. For example:"
- ]
- },
- {
- "metadata": {
- "id": "3Twf_Rw-gQFM",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "cellView": "code"
- },
- "cell_type": "code",
- "source": [
- "x = tf.random_uniform([3, 3])\n",
- "\n",
- "print(\"Is there a GPU available: \"),\n",
- "print(tf.test.is_gpu_available())\n",
- "\n",
- "print(\"Is the Tensor on GPU #0: \"),\n",
- "print(x.device.endswith('GPU:0'))"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "vpgYzgVXW2Ud",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "### Device Names\n",
- "\n",
-        "The `Tensor.device` property provides a fully qualified string name of the device hosting the contents of the Tensor. This name encodes a bunch of details, such as an identifier of the network address of the host on which this program is executing and the device within that host. This is required for distributed execution of TensorFlow programs, but we'll skip that for now. The string will end with `GPU:<N>` if the tensor is placed on the `N`-th GPU of the host."
]
},
{
- "metadata": {
- "id": "ZWZQCimzuqyP",
- "colab_type": "text"
- },
"cell_type": "markdown",
- "source": [
- "\n",
- "\n",
- "### Explicit Device Placement\n",
- "\n",
-        "The term \"placement\" in TensorFlow refers to how individual operations are assigned to (placed on) a device for execution. As mentioned above, when there is no explicit guidance provided, TensorFlow automatically decides which device to execute an operation on, and copies Tensors to that device if needed. However, TensorFlow operations can be explicitly placed on specific devices using the `tf.device` context manager. For example:"
- ]
- },
- {
- "metadata": {
- "id": "RjkNZTuauy-Q",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "def time_matmul(x):\n",
- " %timeit tf.matmul(x, x)\n",
- "\n",
- "# Force execution on CPU\n",
- "print(\"On CPU:\")\n",
- "with tf.device(\"CPU:0\"):\n",
- " x = tf.random_uniform([1000, 1000])\n",
- " assert x.device.endswith(\"CPU:0\")\n",
- " time_matmul(x)\n",
- "\n",
- "# Force execution on GPU #0 if available\n",
- "if tf.test.is_gpu_available():\n",
- " with tf.device(\"GPU:0\"): # Or GPU:1 for the 2nd GPU, GPU:2 for the 3rd etc.\n",
- " x = tf.random_uniform([1000, 1000])\n",
- " assert x.device.endswith(\"GPU:0\")\n",
- " time_matmul(x)"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
"metadata": {
- "id": "o1K4dlhhHtQj",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "U9i2Dsh-ziXr"
},
- "cell_type": "markdown",
"source": [
- "## Datasets\n",
- "\n",
- "This section demonstrates the use of the [`tf.data.Dataset` API](https://www.tensorflow.org/guide/datasets) to build pipelines to feed data to your model. It covers:\n",
- "\n",
- "* Creating a `Dataset`.\n",
- "* Iteration over a `Dataset` with eager execution enabled.\n",
- "\n",
-        "We recommend using the `Dataset` API for building performant, complex input pipelines from simple, re-usable pieces that will feed your model's training or evaluation loops.\n",
- "\n",
- "If you're familiar with TensorFlow graphs, the API for constructing the `Dataset` object remains exactly the same when eager execution is enabled, but the process of iterating over elements of the dataset is slightly simpler.\n",
-        "You can use Python iteration over the `tf.data.Dataset` object and do not need to explicitly create a `tf.data.Iterator` object.\n",
- "As a result, the discussion on iterators in the [TensorFlow Guide](https://www.tensorflow.org/guide/datasets) is not relevant when eager execution is enabled."
- ]
- },
- {
- "metadata": {
- "id": "zI0fmOynH-Ne",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "### Create a source `Dataset`\n",
- "\n",
- "Create a _source_ dataset using one of the factory functions like [`Dataset.from_tensors`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensors), [`Dataset.from_tensor_slices`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensor_slices) or using objects that read from files like [`TextLineDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TextLineDataset) or [`TFRecordDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset). See the [TensorFlow Guide](https://www.tensorflow.org/guide/datasets#reading_input_data) for more information."
+ "# Eager execution basics"
]
},
{
- "metadata": {
- "id": "F04fVOHQIBiG",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "ds_tensors = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])\n",
- "\n",
- "# Create a CSV file\n",
- "import tempfile\n",
- "_, filename = tempfile.mkstemp()\n",
- "\n",
- "with open(filename, 'w') as f:\n",
- " f.write(\"\"\"Line 1\n",
- "Line 2\n",
- "Line 3\n",
- " \"\"\")\n",
- "\n",
- "ds_file = tf.data.TextLineDataset(filename)"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "vbxIhC-5IPdf",
- "colab_type": "text"
- },
"cell_type": "markdown",
- "source": [
- "### Apply transformations\n",
- "\n",
-        "Use transformation functions like [`map`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map), [`batch`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch), [`shuffle`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle), etc., to apply transformations to the records of the dataset. See the [API documentation for `tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) for details."
- ]
- },
- {
"metadata": {
- "id": "uXSDZWE-ISsd",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "colab_type": "text",
+ "id": "Hndw-YcxoOJK"
},
- "cell_type": "code",
"source": [
- "ds_tensors = ds_tensors.map(tf.square).shuffle(2).batch(2)\n",
- "\n",
- "ds_file = ds_file.batch(2)"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "A8X1GNfoIZKJ",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "### Iterate\n",
- "\n",
-        "When eager execution is enabled, `Dataset` objects support iteration.\n",
-        "If you're familiar with the use of `Dataset`s in TensorFlow graphs, note that there is no need to call `Dataset.make_one_shot_iterator()` or `get_next()`."
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n",
+ "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/eager/eager_basics.ipynb\"\u003e\n",
+ " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
+ "\u003c/td\u003e\u003ctd\u003e\n",
+ "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/eager_basics.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e"
]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "name": "eager_basics.ipynb",
+ "private_outputs": true,
+ "provenance": [],
+ "toc_visible": true,
+ "version": "0.3.2"
},
- {
- "metadata": {
- "id": "ws-WKRk5Ic6-",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "print('Elements of ds_tensors:')\n",
- "for x in ds_tensors:\n",
- " print(x)\n",
- "\n",
- "print('\\nElements in ds_file:')\n",
- "for x in ds_file:\n",
- " print(x)"
- ],
- "execution_count": 0,
- "outputs": []
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
}
- ]
-} \ No newline at end of file
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
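The deleted eager_basics cells point out that the explicit iterator dance is a graph-mode requirement only. A small sketch of the contrast, assuming TF 1.x:

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])

# Graph mode: the make_one_shot_iterator()/get_next() pattern the deleted
# prose says eager code can skip.
iterator = ds.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
  for _ in range(3):
    print(sess.run(next_element))

# With tf.enable_eager_execution(), the same dataset is simply:
#   for x in ds:
#     print(x)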
diff --git a/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb b/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb
index ee25d25b52..d60ee18586 100644
--- a/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb
+++ b/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb
@@ -147,11 +147,12 @@
" # random jittering\n",
" \n",
" # resizing to 286 x 286 x 3\n",
- " # method = 2 indicates using \"ResizeMethod.NEAREST_NEIGHBOR\"\n",
" input_image = tf.image.resize_images(input_image, [286, 286], \n",
- " align_corners=True, method=2)\n",
+ " align_corners=True, \n",
+ " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
" real_image = tf.image.resize_images(real_image, [286, 286], \n",
- " align_corners=True, method=2)\n",
+ " align_corners=True, \n",
+ " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
" \n",
" # randomly cropping to 256 x 256 x 3\n",
" stacked_image = tf.stack([input_image, real_image], axis=0)\n",
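The pix2pix hunk above replaces the magic number `method=2` (and the comment explaining it) with the named `tf.image.ResizeMethod` constant, making the requested resize kernel explicit rather than dependent on the enum's numeric values. A sketch of the resulting call shape, with a hypothetical input batch:

import tensorflow as tf

images = tf.zeros([1, 128, 128, 3])  # hypothetical 128x128 RGB batch
resized = tf.image.resize_images(
    images, [286, 286],
    align_corners=True,
    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)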
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
index a28bc8a43d..9d090e8429 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
@@ -195,12 +195,12 @@ class ResNet50(tf.keras.Model):
def __init__(self,
data_format,
- name=None,
+ name='',
trainable=True,
include_top=True,
pooling=None,
classes=1000):
- super(ResNet50, self).__init__(name='')
+ super(ResNet50, self).__init__(name=name)
valid_channel_values = ('channels_first', 'channels_last')
if data_format not in valid_channel_values:
@@ -272,8 +272,8 @@ class ResNet50(tf.keras.Model):
else:
self.global_pooling = None
- def call(self, input_tensor, training):
- x = self.conv1(input_tensor)
+ def call(self, inputs, training=True):
+ x = self.conv1(inputs)
x = self.bn_conv1(x, training=training)
x = tf.nn.relu(x)
x = self.max_pool(x)
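With the resnet50.py hunks above, a caller-supplied `name` now reaches the underlying `tf.keras.Model` instead of being discarded, and `call` gains a defaulted `training` flag. A usage sketch under those assumptions:

import tensorflow as tf
from tensorflow.contrib.eager.python.examples.resnet50 import resnet50

tf.enable_eager_execution()

# `name` is forwarded to tf.keras.Model; the default '' keeps the old behavior.
model = resnet50.ResNet50('channels_last', name='resnet50_demo')

images = tf.zeros([2, 224, 224, 3])
logits = model(images, training=False)  # `training` now defaults to True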
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
index 07d8788882..d265169b5e 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
@@ -216,12 +216,12 @@ class ResNet50Benchmarks(tf.test.Benchmark):
tf.constant(1.).cpu()
def _benchmark_eager_apply(self, label, device_and_format, defun=False,
- execution_mode=None, compiled=False):
+ execution_mode=None):
with tfe.execution_mode(execution_mode):
device, data_format = device_and_format
model = resnet50.ResNet50(data_format)
if defun:
- model.call = tfe.defun(model.call, compiled=compiled)
+ model.call = tfe.defun(model.call)
batch_size = 64
num_burn = 5
num_iters = 30
@@ -257,8 +257,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
make_iterator,
device_and_format,
defun=False,
- execution_mode=None,
- compiled=False):
+ execution_mode=None):
with tfe.execution_mode(execution_mode):
device, data_format = device_and_format
for batch_size in self._train_batch_sizes():
@@ -267,8 +266,8 @@ class ResNet50Benchmarks(tf.test.Benchmark):
optimizer = tf.train.GradientDescentOptimizer(0.1)
apply_grads = apply_gradients
if defun:
- model.call = tfe.defun(model.call, compiled=compiled)
- apply_grads = tfe.defun(apply_gradients, compiled=compiled)
+ model.call = tfe.defun(model.call)
+ apply_grads = tfe.defun(apply_gradients)
num_burn = 3
num_iters = 10
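The benchmark hunks above drop the `compiled` keyword from `tfe.defun`, which now simply wraps the function. A sketch of the updated wrapping, assuming the ResNet50 model from this directory:

import tensorflow as tf
from tensorflow.contrib.eager.python.examples.resnet50 import resnet50

tfe = tf.contrib.eager
tf.enable_eager_execution()

model = resnet50.ResNet50('channels_last')
# Previously tfe.defun(model.call, compiled=...); the flag is gone.
model.call = tfe.defun(model.call)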
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
index 84b2ddf0de..6a921e1997 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
@@ -226,14 +226,13 @@ class RevNetBenchmark(tf.test.Benchmark):
label,
device_and_format,
defun=False,
- execution_mode=None,
- compiled=False):
+ execution_mode=None):
config = config_.get_hparams_imagenet_56()
with tfe.execution_mode(execution_mode):
device, data_format = device_and_format
model = revnet.RevNet(config=config)
if defun:
- model.call = tfe.defun(model.call, compiled=compiled)
+ model.call = tfe.defun(model.call)
batch_size = 64
num_burn = 5
num_iters = 10
@@ -271,8 +270,7 @@ class RevNetBenchmark(tf.test.Benchmark):
make_iterator,
device_and_format,
defun=False,
- execution_mode=None,
- compiled=False):
+ execution_mode=None):
config = config_.get_hparams_imagenet_56()
with tfe.execution_mode(execution_mode):
device, data_format = device_and_format
diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py
index 5ee2176154..74ebb1ec77 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py
+++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py
@@ -243,8 +243,8 @@ def train_one_epoch(model, optimizer, train_data, log_interval=10):
print("train/batch #%d\tloss: %.6f" % (batch, batch_model_loss()))
-SOURCE_TRAIN_URL = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv"
-SOURCE_TEST_URL = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv"
+SOURCE_TRAIN_URL = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/train.csv"
+SOURCE_TEST_URL = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/test.csv"
def main(_):
diff --git a/tensorflow/contrib/eager/python/examples/scan/BUILD b/tensorflow/contrib/eager/python/examples/scan/BUILD
deleted file mode 100644
index 638c57d1c9..0000000000
--- a/tensorflow/contrib/eager/python/examples/scan/BUILD
+++ /dev/null
@@ -1,25 +0,0 @@
-licenses(["notice"]) # Apache 2.0
-
-package(default_visibility = ["//tensorflow:internal"])
-
-load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-
-cuda_py_test(
- name = "scan_test",
- size = "small",
- srcs = ["scan_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-cuda_py_test(
- name = "scan_graph_test",
- size = "small",
- srcs = ["scan_graph_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
- "//tensorflow:tensorflow_py",
- ],
-)
diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py
deleted file mode 100644
index d4b8c8941e..0000000000
--- a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Unit test for tf.scan under graph mode execution."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-
-import numpy as np
-import tensorflow as tf
-
-
-class ScanBenchmark(tf.test.Benchmark):
-
- def runScan(self, n):
- elems = np.arange(n)
- start_time = time.time()
- sum_op = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1)
- with tf.Session() as sess:
- sess.run(sum_op)
- wall_time = time.time() - start_time
-
- self.report_benchmark(
- name='scan',
- iters=n,
- wall_time=wall_time)
-
- def benchmarkScan16000(self):
- self.runScan(16000)
-
- def benchmarkScan32000(self):
- self.runScan(32000)
-
- def benchmarkScan64000(self):
- self.runScan(64000)
-
- def benchmarkScan128000(self):
- self.runScan(128000)
-
-if __name__ == '__main__':
- tf.test.main()
diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py
index 6efafccd6b..930e62b680 100644
--- a/tensorflow/contrib/eager/python/metrics_impl.py
+++ b/tensorflow/contrib/eager/python/metrics_impl.py
@@ -336,9 +336,27 @@ class Mean(Metric):
return values
return values, weights
- def result(self):
+ def result(self, write_summary=True):
+ """Returns the result of the Metric.
+
+ Args:
+ write_summary: bool indicating whether to feed the result to the summary
+ before returning.
+ Returns:
+ aggregated metric as float.
+ Raises:
+ ValueError: if the optional argument is not bool
+ """
+    # Convert the boolean to a tensor for tf.cond, if it is not one already.
+ if not isinstance(write_summary, ops.Tensor):
+ write_summary = ops.convert_to_tensor(write_summary)
t = self.numer / self.denom
- summary_ops.scalar(name=self.name, tensor=t)
+ def write_summary_f():
+ summary_ops.scalar(name=self.name, tensor=t)
+ return t
+ control_flow_ops.cond(write_summary,
+ write_summary_f,
+ lambda: t)
return t
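The metrics_impl.py hunk above adds a `write_summary` flag to `Mean.result` so callers can read the value without triggering the scalar-summary side effect. A usage sketch, assuming eager execution:

import tensorflow as tf
from tensorflow.contrib.eager.python import metrics

tf.enable_eager_execution()

m = metrics.Mean()
m([1.0, 2.0, 3.0])

print(m.result().numpy())                     # default: also feeds the summary machinery
print(m.result(write_summary=False).numpy())  # value only, no summary written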
diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py
index 20d938d492..9d2d172752 100644
--- a/tensorflow/contrib/eager/python/metrics_test.py
+++ b/tensorflow/contrib/eager/python/metrics_test.py
@@ -25,11 +25,14 @@ from tensorflow.contrib.eager.python import metrics
from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.eager import context
from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import util as checkpointable_utils
@@ -46,6 +49,18 @@ class MetricsTest(test.TestCase):
self.assertEqual(dtypes.float64, m.dtype)
self.assertEqual(dtypes.float64, m.result().dtype)
+ def testSummaryArg(self):
+ m = metrics.Mean()
+ m([1, 10, 100])
+ m(1000)
+ m([10000.0, 100000.0])
+ self.assertEqual(111111.0/6, m.result(write_summary=True).numpy())
+ self.assertEqual(111111.0/6, m.result(write_summary=False).numpy())
+ with self.assertRaises(ValueError):
+ m.result(write_summary=5)
+ with self.assertRaises(ValueError):
+ m.result(write_summary=[True])
+
def testVariableCollections(self):
with context.graph_mode(), ops.Graph().as_default():
m = metrics.Mean()
@@ -93,6 +108,16 @@ class MetricsTest(test.TestCase):
self.assertEqual(len(events), 2)
self.assertEqual(events[1].summary.value[0].simple_value, 37.0)
+ # Get result without saving the summary.
+ logdir = tempfile.mkdtemp()
+ with summary_ops.create_file_writer(
+ logdir, max_queue=0,
+ name="t0").as_default(), summary_ops.always_record_summaries():
+      m.result(write_summary=False)  # Should not write any summaries.
+ # events_from_logdir(_) asserts the directory exists.
+ events = summary_test_util.events_from_logdir(logdir)
+ self.assertEqual(len(events), 1)
+
def testWeightedMean(self):
m = metrics.Mean()
m([1, 100, 100000], weights=[1, 0.2, 0.3])
@@ -191,7 +216,7 @@ class MetricsTest(test.TestCase):
self.assertEqual(m1.numer.name, "has_space/numer:0")
def testGraphWithPlaceholder(self):
- with context.graph_mode(), self.test_session() as sess:
+ with context.graph_mode(), self.cached_session() as sess:
m = metrics.Mean()
p = array_ops.placeholder(dtypes.float32)
accumulate = m(p)
@@ -222,6 +247,48 @@ class MetricsTest(test.TestCase):
value = m.value()
self.assertEqual(self.evaluate(value), 2.5)
+ @test_util.run_in_graph_and_eager_modes
+ def testGraphAndEagerTensorGlobalVariables(self):
+ m = metrics.Mean(use_global_variables=True)
+ inputs = ops.convert_to_tensor([1.0, 2.0])
+ accumulate = m(inputs)
+ result = m.result()
+ self.evaluate(m.init_variables())
+ self.evaluate(accumulate)
+ self.assertEqual(self.evaluate(result), 1.5)
+ # Second init resets all the variables.
+ self.evaluate(m.init_variables())
+ inputs = ops.convert_to_tensor([2.0, 3.0])
+ self.evaluate(m(inputs))
+ value = m.value()
+ self.assertEqual(self.evaluate(value), 2.5)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testGraphAndEagerTensorWhileLoopDoubleCall(self):
+ m = metrics.Mean()
+ init_value = constant_op.constant(1)
+ cond = lambda i: math_ops.less(i, 3)
+ def body(x):
+ with ops.control_dependencies([m(x)]):
+ return math_ops.add(x, 1)
+ accumulate = control_flow_ops.while_loop(cond, body, [init_value])
+
+ result = m.result()
+ self.evaluate(m.init_variables())
+ self.evaluate(accumulate)
+ self.assertEqual(self.evaluate(result), 1.5)
+ # Second init resets all the variables.
+ self.evaluate(m.init_variables())
+ inputs = ops.convert_to_tensor([2.0, 3.0])
+ self.evaluate(m(inputs))
+ if context.executing_eagerly():
+ self.evaluate(control_flow_ops.while_loop(cond, body, [init_value]))
+ else:
+ # Reuse the loop operators in graph mode
+ self.evaluate(accumulate)
+ value = m.value()
+ self.assertEqual(self.evaluate(value), 2.0)
+
def testTwoMeansGraph(self):
# Verify two metrics with the same name in the same graph raises a
# ValueError.
@@ -242,7 +309,7 @@ class MetricsTest(test.TestCase):
self.assertTrue(old_numer is m.numer)
def testMetricsChain(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
m1 = metrics.Mean()
m2 = metrics.Mean(name="m2")
update_m2 = m2(3.0)
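
The `test_session()` to `cached_session()` swaps above (and throughout the rest of this change) are behavior-preserving: both return a session that is created once and reused within a single test, but `cached_session` names that caching explicitly. A minimal sketch of the idiom:

```python
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class ExampleTest(test.TestCase):

  def testAdd(self):
    # cached_session() returns the test's shared session; calling it again
    # inside the same test method reuses the same underlying session.
    with self.cached_session() as sess:
      self.assertEqual(3, sess.run(math_ops.add(1, 2)))


if __name__ == '__main__':
  test.main()
```
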
diff --git a/tensorflow/contrib/eager/python/remote.py b/tensorflow/contrib/eager/python/remote.py
new file mode 100644
index 0000000000..b74cf394f6
--- /dev/null
+++ b/tensorflow/contrib/eager/python/remote.py
@@ -0,0 +1,73 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Helpers to connect to remote servers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.core.protobuf.cluster_pb2 import ClusterDef
+from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
+from tensorflow.python.eager import context
+
+
+def connect_to_remote_host(remote_host=None, job_name="worker"):
+ """Connects to a single machine to enable remote execution on it.
+
+ Makes the devices on the remote host available to use. Calling this more
+ than once works, but invalidates any tensor handles on the previously
+ connected remote devices.
+
+ Using the default job_name of worker, you can schedule ops to run remotely as
+ follows:
+ ```python
+ # Enable eager execution, and connect to the remote host.
+ tf.enable_eager_execution()
+ tf.contrib.eager.connect_to_remote_host("exampleaddr.com:9876")
+
+ with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
+ # The following tensors should be resident on the remote device, and the op
+ # will also execute remotely.
+ x1 = array_ops.ones([2, 2])
+ x2 = array_ops.ones([2, 2])
+ y = math_ops.matmul(x1, x2)
+ ```
+
+ Args:
+ remote_host: The address of the remote server, in host:port format.
+ job_name: The job name under which the new server will be accessible.
+
+ Raises:
+ ValueError: if remote_host is None.
+ """
+ if remote_host is None:
+ raise ValueError("Must provide an remote_host")
+ cluster_def = ClusterDef()
+ job_def = cluster_def.job.add()
+ job_def.name = job_name
+ job_def.tasks[0] = "127.0.0.1:0"
+ job_def.tasks[1] = remote_host
+
+ server_def = ServerDef(
+ cluster=cluster_def,
+ job_name=job_name,
+ task_index=0,
+ protocol="grpc")
+
+ # TODO(nareshmodi): Make this default since it works in more situations.
+ os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"
+ context.set_server_def(server_def)
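
Note the task layout this builds: the local client registers itself as task 0 (`127.0.0.1:0`) and the remote host as task 1, which is why device strings target `task:1`. A minimal end-to-end sketch (the worker address is hypothetical):

```python
import tensorflow as tf

tf.enable_eager_execution()
tf.contrib.eager.connect_to_remote_host("worker.example.com:8470")

# Ops placed on task 1 execute on the remote host; results are fetched back.
with tf.device("/job:worker/replica:0/task:1/device:CPU:0"):
  y = tf.matmul(tf.ones([2, 2]), tf.ones([2, 2]))
print(y.numpy())  # [[2. 2.], [2. 2.]]
```
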
diff --git a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py
index 76f48eeb1c..13029db975 100644
--- a/tensorflow/contrib/eager/python/remote_test.py
+++ b/tensorflow/contrib/eager/python/remote_test.py
@@ -23,6 +23,7 @@ import os
import numpy as np
+from tensorflow.contrib.eager.python import remote
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.eager import backprop
@@ -85,6 +86,7 @@ class RemoteExecutionTest(test.TestCase):
self._cached_server1_target = self._cached_server1.target[len("grpc://"):]
self._cached_server2_target = self._cached_server2.target[len("grpc://"):]
+ def setUp(self):
# Start the local server.
context.set_server_def(
server_def=get_server_def(
@@ -172,6 +174,17 @@ class RemoteExecutionTest(test.TestCase):
y = math_ops.matmul(x1, x1)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
+ @run_sync_and_async
+ def testConnectToRemoteServer(self):
+ """Basic server connection."""
+ remote.connect_to_remote_host(self._cached_server1_target)
+
+ with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
+ x1 = array_ops.ones([2, 2])
+ x2 = array_ops.ones([2, 2])
+ y = math_ops.matmul(x1, x2)
+ np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
+
if __name__ == "__main__":
ops.enable_eager_execution()
diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py
index 90a3711475..91bc75213c 100644
--- a/tensorflow/contrib/eager/python/saver_test.py
+++ b/tensorflow/contrib/eager/python/saver_test.py
@@ -21,15 +21,11 @@ import os
from tensorflow.contrib.eager.python import saver as _saver
from tensorflow.python.eager import context
-from tensorflow.python.eager import graph_callable
from tensorflow.python.eager import test
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import variable_scope
from tensorflow.python.training import adam
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import momentum
@@ -142,53 +138,6 @@ class SaverTest(test.TestCase):
with _saver.restore_variables_on_create(ckpt_prefix):
_ = model(resource_variable_ops.ResourceVariable(1.0, name='v2'))
- def testSaveRestoreGraphCallable(self):
- with ops.device(self._dev()):
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def model(x):
- v = variable_scope.get_variable(
- 'v', initializer=init_ops.zeros_initializer(), shape=())
- return v + x
-
- # Default 2 + 0 = 2
- self.assertEqual(
- 2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
-
- # Save the variable value 0.
- ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
- _saver.Saver(model.variables).save(ckpt_prefix)
-
- # update variable to 1, so that 2 + 1 = 3
- model.variables[0].assign(1.)
- self.assertEqual(
- 3, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
-
- # load the variable value 0, so that 2 + 0 = 2
- _saver.Saver(model.variables).restore(ckpt_prefix)
- self.assertEqual(
- 2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
-
- # update checkpoint variable to 1 and memory value to 2.
- model.variables[0].assign(1.)
- _saver.Saver(model.variables).save(ckpt_prefix)
- model.variables[0].assign(2.)
- self.assertEqual(
- 4, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
-
- # reset the graph and reload on create, so that 1 + 2 = 3
- ops.reset_default_graph()
- with _saver.restore_variables_on_create(ckpt_prefix):
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def model2(x):
- v = variable_scope.get_variable(
- 'v', initializer=init_ops.zeros_initializer(), shape=())
- return v + x
-
- self.assertEqual(
- 3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy())
-
class GetOptimizerTests(test.TestCase):
diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py
index de11d00a1a..f5b8d95e4f 100644
--- a/tensorflow/contrib/eager/python/tfe.py
+++ b/tensorflow/contrib/eager/python/tfe.py
@@ -16,7 +16,7 @@
EXPERIMENTAL: APIs here are unstable and likely to change without notice.
-To use, at program startup, call `tfe.enable_eager_execution()`.
+To use, at program startup, call `tf.enable_eager_execution()`.
@@metrics
@@ -67,12 +67,15 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
@@execution_mode
@@async_wait
@@async_clear_error
+@@set_server_def
@@run_test_in_graph_and_eager_modes
@@run_all_tests_in_graph_and_eager_modes
@@TensorSpec
+@@connect_to_remote_host
+
@@DEVICE_PLACEMENT_EXPLICIT
@@DEVICE_PLACEMENT_WARN
@@DEVICE_PLACEMENT_SILENT
@@ -93,6 +96,7 @@ from tensorflow.contrib.eager.python.network import Network
from tensorflow.contrib.eager.python.network import Sequential
from tensorflow.contrib.eager.python.network import save_network_checkpoint
from tensorflow.contrib.eager.python.network import restore_network_checkpoint
+from tensorflow.contrib.eager.python.remote import connect_to_remote_host
from tensorflow.contrib.eager.python.saver import get_optimizer_variables
from tensorflow.contrib.eager.python.saver import restore_variables_on_create
from tensorflow.contrib.eager.python.saver import Saver
@@ -110,6 +114,7 @@ from tensorflow.python.eager.context import async_clear_error
from tensorflow.python.eager.context import SYNC
from tensorflow.python.eager.context import ASYNC
from tensorflow.python.eager.context import num_gpus
+from tensorflow.python.eager.context import set_server_def
from tensorflow.python.eager.execution_callbacks import add_execution_callback
from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks
from tensorflow.python.eager.execution_callbacks import inf_callback
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 77f62df99d..437b3d965d 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -446,6 +446,7 @@ py_library(
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:head",
"//tensorflow/python/estimator:optimizers",
+ "//tensorflow/python/ops/losses",
"@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index 6ad3a4a604..258860f263 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -45,7 +45,7 @@ _allowed_symbols = [
'clip_gradients_by_norm',
'forward_features',
'InMemoryEvaluatorHook',
- 'StopAtCheckpointStepHook',
+ 'make_stop_at_checkpoint_step_hook',
'logistic_regression_head',
'multi_class_head',
'multi_head',
diff --git a/tensorflow/contrib/estimator/python/estimator/baseline_test.py b/tensorflow/contrib/estimator/python/estimator/baseline_test.py
index 505c94e971..513feb03b6 100644
--- a/tensorflow/contrib/estimator/python/estimator/baseline_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/baseline_test.py
@@ -37,13 +37,13 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import checkpoint_utils
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import optimizer
from tensorflow.python.training import saver
@@ -339,7 +339,7 @@ class BaselineEstimatorTrainingTest(test.TestCase):
self.assertEquals(0, loss.shape.ndims)
if expected_loss is None:
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
assert_loss = assert_close(
math_ops.to_float(expected_loss, name='expected'),
@@ -347,7 +347,7 @@ class BaselineEstimatorTrainingTest(test.TestCase):
name='assert_loss')
with ops.control_dependencies((assert_loss,)):
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
mock_optimizer = test.mock.NonCallableMock(
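
The removed `distribute_lib.increment_var` helper is replaced by the plain assign-add idiom. A minimal sketch of the equivalent pattern:

```python
from tensorflow.python.ops import state_ops
from tensorflow.python.training import training_util

global_step = training_util.get_or_create_global_step()
# Increment the global step by one; taking .op discards the output value,
# matching the train-op contract used in the tests above.
increment_global_step = state_ops.assign_add(global_step, 1).op
```
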
diff --git a/tensorflow/contrib/estimator/python/estimator/export.py b/tensorflow/contrib/estimator/python/estimator/export.py
index 03cf6f107c..b0deb9b494 100644
--- a/tensorflow/contrib/estimator/python/estimator/export.py
+++ b/tensorflow/contrib/estimator/python/estimator/export.py
@@ -31,8 +31,8 @@ def export_saved_model_for_mode(
# pylint: disable=line-too-long
"""Exports a single train/eval/predict graph as a SavedModel.
- For a detailed guide, see
- @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}.
+ For a detailed guide, see [Using SavedModel with Estimators](
+ https://tensorflow.org/guide/saved_model#using_savedmodel_with_estimators).
Sample usage:
```python
diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py
index 26449b4651..e3c44bea66 100644
--- a/tensorflow/contrib/estimator/python/estimator/extenders.py
+++ b/tensorflow/contrib/estimator/python/estimator/extenders.py
@@ -26,6 +26,7 @@ from tensorflow.python.estimator.export.export_output import PredictOutput
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import sparse_ops
from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.util import function_utils
@@ -140,7 +141,7 @@ def clip_gradients_by_norm(optimizer, clip_norm):
name='ClipByNorm' + optimizer.get_name())
-def forward_features(estimator, keys=None):
+def forward_features(estimator, keys=None, sparse_default_values=None):
"""Forward features to predictions dictionary.
In some cases, a user wants to see some of the features in the estimator's prediction
@@ -148,39 +149,36 @@ def forward_features(estimator, keys=None):
runs inference on the user's graph and returns the results. Keys are essential
because there is no order guarantee on the outputs so they need to be rejoined
to the inputs via keys or transclusion of the inputs in the outputs.
-
Example:
-
```python
def input_fn():
features, labels = ...
features['unique_example_id'] = ...
features, labels
-
estimator = tf.estimator.LinearClassifier(...)
estimator = tf.contrib.estimator.forward_features(
estimator, 'unique_example_id')
estimator.train(...)
assert 'unique_example_id' in estimator.predict(...)
```
-
Args:
estimator: A `tf.estimator.Estimator` object.
- keys: a `string` or a `list` of `string`. If it is `None`, all of the
+ keys: A `string` or a `list` of `string`. If it is `None`, all of the
`features` in `dict` is forwarded to the `predictions`. If it is a
`string`, only given key is forwarded. If it is a `list` of strings, all
the given `keys` are forwarded.
+ sparse_default_values: A dict mapping the `str` name of each sparse
+ feature to be converted to dense, to the default value to use for the
+ conversion. Only sparse features listed in the dictionary are converted
+ to dense, using the provided default value.
Returns:
A new `tf.estimator.Estimator` which forwards features to predictions.
-
Raises:
ValueError:
* if `keys` is already part of `predictions`. We don't allow
override.
* if 'keys' does not exist in `features`.
- * if feature key refers to a `SparseTensor`, since we don't support
- `SparseTensor` in `predictions`. `SparseTensor` is common in `features`.
TypeError: if `keys` type is not one of `string` or list/tuple of `string`.
"""
@@ -231,11 +229,18 @@ def forward_features(estimator, keys=None):
for key in get_keys(features):
feature = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
features[key])
+ if sparse_default_values and (key in sparse_default_values):
+ if not isinstance(feature, sparse_tensor_lib.SparseTensor):
+ raise ValueError(
+ 'Feature ({}) is expected to be a `SparseTensor`.'.format(key))
+ feature = sparse_ops.sparse_tensor_to_dense(
+ feature, default_value=sparse_default_values[key])
if not isinstance(feature, ops.Tensor):
raise ValueError(
- 'Forwarded feature ({}) should be a Tensor. Please use keys '
- 'argument of forward_features to filter unwanted features. Type of '
- 'features[{}] is {}.'.format(key, key, type(feature)))
+ 'Feature ({}) should be a Tensor. Please use `keys` '
+ 'argument of forward_features to filter unwanted features, or '
+ 'add the key to the `sparse_default_values` argument. '
+ 'Type of features[{}] is {}.'.format(key, key, type(feature)))
predictions[key] = feature
spec = spec._replace(predictions=predictions)
if spec.export_outputs:
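
With `sparse_default_values`, a sparse feature can now be forwarded by densifying it first. A minimal sketch (the feature names are hypothetical):

```python
import tensorflow as tf

estimator = tf.estimator.LinearRegressor(
    feature_columns=[tf.feature_column.numeric_column('x')])
# 'sparse_id' arrives from the input_fn as a SparseTensor; it is converted
# to a dense Tensor padded with 0 before being copied into predictions.
estimator = tf.contrib.estimator.forward_features(
    estimator, keys=['x', 'sparse_id'],
    sparse_default_values={'sparse_id': 0})
```
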
diff --git a/tensorflow/contrib/estimator/python/estimator/extenders_test.py b/tensorflow/contrib/estimator/python/estimator/extenders_test.py
index 407af2deaf..c8fdaa8791 100644
--- a/tensorflow/contrib/estimator/python/estimator/extenders_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/extenders_test.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""extenders tests."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -23,6 +24,7 @@ import tempfile
import numpy as np
from tensorflow.contrib.estimator.python.estimator import extenders
+from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.predictor import from_saved_model
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator_lib
@@ -170,19 +172,53 @@ class ClipGradientsByNormTest(test.TestCase):
class ForwardFeaturesTest(test.TestCase):
"""Tests forward_features."""
- def test_forward_single_key(self):
-
- def input_fn():
- return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]]
+ def _export_estimator(self, estimator, serving_input_fn):
+ tmpdir = tempfile.mkdtemp()
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('export'))
+ export_dir = estimator.export_savedmodel(export_dir_base, serving_input_fn)
+ self.assertTrue(gfile.Exists(export_dir))
+ return export_dir, tmpdir
+ def make_dummy_input_fn(self):
+ def _input_fn():
+ dataset = dataset_ops.Dataset.from_tensors({
+ 'x': [[3.], [5.]],
+ 'id': [[101], [102]],
+ 'sparse_id': sparse_tensor.SparseTensor(
+ values=[1, 2, 3],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2]),
+ 'labels': [[1.], [2.]]
+ })
+ def _split(x):
+ labels = x.pop('labels')
+ return x, labels
+ dataset = dataset.map(_split)
+ return dataset
+ return _input_fn
+
+ def test_forward_keys(self):
+
+ input_fn = self.make_dummy_input_fn()
estimator = linear.LinearRegressor([fc.numeric_column('x')])
estimator.train(input_fn=input_fn, steps=1)
- self.assertNotIn('id', next(estimator.predict(input_fn=input_fn)))
- estimator = extenders.forward_features(estimator, 'id')
- predictions = next(estimator.predict(input_fn=input_fn))
- self.assertIn('id', predictions)
- self.assertEqual(101, predictions['id'])
+ forwarded_keys = ['id', 'sparse_id']
+
+ for key in forwarded_keys:
+ self.assertNotIn(key, next(estimator.predict(input_fn=input_fn)))
+
+ estimator = extenders.forward_features(
+ estimator, forwarded_keys, sparse_default_values={'sparse_id': 1})
+
+ expected_results = [101, 2, 102, 5]
+ predictions = estimator.predict(input_fn=input_fn)
+ for _ in range(2):
+ prediction = next(predictions)
+ for key in forwarded_keys:
+ self.assertIn(key, prediction)
+ self.assertEqual(expected_results.pop(0), sum(prediction[key]))
def test_forward_in_exported(self):
@@ -205,11 +241,7 @@ class ForwardFeaturesTest(test.TestCase):
estimator = extenders.forward_features(estimator, 'id')
# export saved model
- tmpdir = tempfile.mkdtemp()
- export_dir_base = os.path.join(
- compat.as_bytes(tmpdir), compat.as_bytes('export'))
- export_dir = estimator.export_savedmodel(export_dir_base, serving_input_fn)
- self.assertTrue(gfile.Exists(export_dir))
+ export_dir, tmpdir = self._export_estimator(estimator, serving_input_fn)
# restore model
predict_fn = from_saved_model(export_dir, signature_def_key='predict')
@@ -222,6 +254,47 @@ class ForwardFeaturesTest(test.TestCase):
# Clean up.
gfile.DeleteRecursively(tmpdir)
+ def test_forward_in_exported_sparse(self):
+ features_columns = [fc.indicator_column(
+ fc.categorical_column_with_vocabulary_list('x', range(10)))]
+
+ classifier = linear.LinearClassifier(feature_columns=features_columns)
+
+ def train_input_fn():
+ dataset = dataset_ops.Dataset.from_tensors({
+ 'x': sparse_tensor.SparseTensor(
+ values=[1, 2, 3],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2]),
+ 'labels': [[0], [1]]
+ })
+ def _split(x):
+ labels = x.pop('labels')
+ return x, labels
+ dataset = dataset.map(_split)
+ return dataset
+
+ classifier.train(train_input_fn, max_steps=1)
+
+ classifier = extenders.forward_features(
+ classifier, keys=['x'], sparse_default_values={'x': 0})
+
+ def serving_input_fn():
+ features_ph = array_ops.placeholder(dtype=dtypes.int32, name='x',
+ shape=[None])
+ features = {'x': layers.dense_to_sparse(features_ph)}
+ return estimator_lib.export.ServingInputReceiver(features,
+ {'x': features_ph})
+ export_dir, tmpdir = self._export_estimator(classifier, serving_input_fn)
+ prediction_fn = from_saved_model(export_dir, signature_def_key='predict')
+
+ features = (0, 2)
+ prediction = prediction_fn({'x': features})
+
+ self.assertIn('x', prediction)
+ self.assertEqual(features, tuple(prediction['x']))
+ gfile.DeleteRecursively(tmpdir)
+
def test_forward_list(self):
def input_fn():
@@ -266,7 +339,6 @@ class ForwardFeaturesTest(test.TestCase):
extenders.forward_features(estimator, ['x', estimator])
def test_key_should_be_in_features(self):
-
def input_fn():
return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]]
@@ -279,27 +351,36 @@ class ForwardFeaturesTest(test.TestCase):
next(estimator.predict(input_fn=input_fn))
def test_forwarded_feature_should_not_be_a_sparse_tensor(self):
-
def input_fn():
return {
'x': [[3.], [5.]],
- 'id':
- sparse_tensor.SparseTensor(
- values=['1', '2'],
- indices=[[0, 0], [1, 0]],
- dense_shape=[2, 1])
- }, [[1.], [2.]]
+ 'id': sparse_tensor.SparseTensor(
+ values=['1', '2'],
+ indices=[[0, 0], [1, 0]],
+ dense_shape=[2, 1])
+ }, [[1.], [2.]]
estimator = linear.LinearRegressor([fc.numeric_column('x')])
estimator.train(input_fn=input_fn, steps=1)
estimator = extenders.forward_features(estimator)
with self.assertRaisesRegexp(ValueError,
- 'Forwarded feature.* should be a Tensor.'):
+ 'Feature .* should be a Tensor.*'):
next(estimator.predict(input_fn=input_fn))
- def test_predictions_should_be_dict(self):
+ def test_forwarded_feature_should_be_a_sparse_tensor(self):
+ input_fn = self.make_dummy_input_fn()
+
+ estimator = linear.LinearRegressor([fc.numeric_column('x')])
+ estimator.train(input_fn=input_fn, steps=1)
+ estimator = extenders.forward_features(
+ estimator, sparse_default_values={'id': 0, 'sparse_id': 0})
+ with self.assertRaisesRegexp(
+ ValueError, 'Feature .* is expected to be a `SparseTensor`.'):
+ next(estimator.predict(input_fn=input_fn))
+
+ def test_predictions_should_be_dict(self):
def input_fn():
return {'x': [[3.], [5.]], 'id': [[101], [102]]}
diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py
index 2d367adb47..c6e75f8d46 100644
--- a/tensorflow/contrib/estimator/python/estimator/head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/head_test.py
@@ -215,7 +215,7 @@ class MultiLabelHead(test.TestCase):
spec.export_outputs.keys())
# Assert predictions and export_outputs.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
predictions = sess.run(spec.predictions)
@@ -246,7 +246,7 @@ class MultiLabelHead(test.TestCase):
mode=model_fn.ModeKeys.PREDICT,
logits=logits)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertAllEqual(
expected_export_classes,
@@ -271,7 +271,7 @@ class MultiLabelHead(test.TestCase):
logits=logits)
# Assert predictions and export_outputs.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
predictions = sess.run(spec.predictions)
@@ -297,7 +297,7 @@ class MultiLabelHead(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(expected_training_loss,
actual_training_loss.eval())
@@ -321,7 +321,7 @@ class MultiLabelHead(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, actual_training_loss.eval(), atol=1e-4)
@@ -338,7 +338,7 @@ class MultiLabelHead(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels_placeholder)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -375,7 +375,7 @@ class MultiLabelHead(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits_input,
labels=labels_input)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(np.sum(loss) / 2., actual_training_loss.eval())
@@ -394,7 +394,7 @@ class MultiLabelHead(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -433,7 +433,7 @@ class MultiLabelHead(test.TestCase):
# Assert predictions, loss, and metrics.
tol = 1e-3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -753,7 +753,7 @@ class MultiLabelHead(test.TestCase):
# Assert predictions, loss, and metrics.
tol = 1e-3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -791,7 +791,7 @@ class MultiLabelHead(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), atol=1e-4)
@@ -825,7 +825,7 @@ class MultiLabelHead(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), atol=1e-4)
@@ -864,7 +864,7 @@ class MultiLabelHead(test.TestCase):
logits=logits,
labels=labels,
train_op_fn=_train_op_fn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -890,7 +890,7 @@ class MultiLabelHead(test.TestCase):
logits=logits,
labels=labels,
train_op_fn=_train_op_fn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -919,7 +919,7 @@ class MultiLabelHead(test.TestCase):
# Assert predictions, loss, train_op, and summaries.
tol = 1e-3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -1011,7 +1011,7 @@ class MultiLabelHead(test.TestCase):
optimizer=_Optimizer())
tol = 1e-3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss, train_result = sess.run((spec.loss, spec.train_op))
self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
@@ -1040,7 +1040,7 @@ class MultiLabelHead(test.TestCase):
labels=np.array([[1, 0], [1, 1]], dtype=np.int64),
train_op_fn=_train_op_fn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
sess.run(spec.train_op)
w_value, t_value = sess.run([w, t])
@@ -1079,7 +1079,7 @@ class MultiLabelHead(test.TestCase):
# Assert predictions, loss, train_op, and summaries.
tol = 1e-3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -1127,7 +1127,7 @@ class MultiLabelHead(test.TestCase):
# Assert predictions, loss, train_op, and summaries.
tol = 1e-3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -1162,7 +1162,7 @@ class MultiLabelHead(test.TestCase):
logits=logits,
labels=labels)
atol = 1.e-3
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), atol=atol)
@@ -1197,7 +1197,7 @@ class MultiLabelHead(test.TestCase):
train_op_fn=_train_op_fn)
atol = 1.e-3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, monitored_session.Scaffold())
loss, train_result = sess.run((spec.loss, spec.train_op))
self.assertAllClose(expected_loss, loss, atol=atol)
@@ -1224,7 +1224,7 @@ class MultiLabelHead(test.TestCase):
logits=logits,
labels=labels,
train_op_fn=_train_op_fn)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -1252,7 +1252,7 @@ class MultiLabelHead(test.TestCase):
logits=logits,
labels=labels,
train_op_fn=_train_op_fn)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -1327,7 +1327,7 @@ class PoissonRegressionHead(test.TestCase):
labels=labels,
train_op_fn=_train_op_fn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss, train_result = sess.run([spec.loss, spec.train_op])
self.assertAlmostEqual(expected_loss, loss, delta=atol)
@@ -1352,7 +1352,7 @@ class PoissonRegressionHead(test.TestCase):
self.assertEqual(dtypes.float32, spec.predictions[keys.LOGITS].dtype)
# Assert predictions.
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, spec.scaffold)
self.assertAllClose(
expected_predictions, spec.predictions[keys.PREDICTIONS].eval())
@@ -1395,7 +1395,7 @@ class LogisticRegressionHead(test.TestCase):
labels=labels,
train_op_fn=_train_op_fn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss, train_result = sess.run([spec.loss, spec.train_op])
self.assertAlmostEqual(expected_loss, loss, delta=atol)
@@ -1419,7 +1419,7 @@ class LogisticRegressionHead(test.TestCase):
labels=labels,
train_op_fn=_train_op_fn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -1444,7 +1444,7 @@ class LogisticRegressionHead(test.TestCase):
labels=labels,
train_op_fn=_train_op_fn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -1471,7 +1471,7 @@ class LogisticRegressionHead(test.TestCase):
self.assertEqual(dtypes.float32, spec.predictions[keys.LOGITS].dtype)
# Assert predictions.
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, spec.scaffold)
self.assertAllClose(
expected_predictions, spec.predictions[keys.PREDICTIONS].eval())
diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py
index faefda7c48..66c46e66b7 100644
--- a/tensorflow/contrib/estimator/python/estimator/hooks.py
+++ b/tensorflow/contrib/estimator/python/estimator/hooks.py
@@ -74,8 +74,9 @@ class InMemoryEvaluatorHook(training.SessionRunHook):
estimator: A `tf.estimator.Estimator` instance to call evaluate.
input_fn: Equivalent to the `input_fn` arg to `estimator.evaluate`. A
function that constructs the input data for evaluation.
- See @{$premade_estimators#create_input_functions} for more
- information. The function should construct and return one of
+ See [Creating input functions](
+ https://tensorflow.org/guide/premade_estimators#create_input_functions)
+ for more information. The function should construct and return one of
the following:
* A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
@@ -212,8 +213,12 @@ class InMemoryEvaluatorHook(training.SessionRunHook):
self._evaluate(session)
-class StopAtCheckpointStepHook(training.SessionRunHook):
- """Hook that requests stop at a specified step based on checkpoint."""
+class _StopAtCheckpointStepHook(training.SessionRunHook):
+ """Hook that requests stop at a specified step based on checkpoint.
+
+ Note: We recommend using `make_stop_at_checkpoint_step_hook` to get the
+ proper hook.
+ """
def __init__(self, model_dir, last_step,
wait_after_file_check_secs=30):
@@ -263,4 +268,17 @@ class StopAtCheckpointStepHook(training.SessionRunHook):
else:
time.sleep(self._wait_after_file_check_secs)
+
+def make_stop_at_checkpoint_step_hook(estimator,
+ last_step,
+ wait_after_file_check_secs=30):
+ """Creates a proper StopAtCheckpointStepHook based on chief status."""
+
+ if estimator.config.is_chief:
+ return training.StopAtStepHook(last_step=last_step)
+ return _StopAtCheckpointStepHook(
+ model_dir=estimator.model_dir,
+ last_step=last_step,
+ wait_after_file_check_secs=wait_after_file_check_secs)
+
# pylint: enable=protected-access
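
Callers are expected to use the factory rather than the now-private hook: the chief gets a plain `StopAtStepHook`, while other workers poll for checkpoints. A minimal sketch (`my_input_fn` stands in for a user-supplied input function):

```python
import tensorflow as tf

estimator = tf.estimator.DNNClassifier(
    feature_columns=[tf.feature_column.numeric_column('x')],
    hidden_units=[8])
hook = tf.contrib.estimator.make_stop_at_checkpoint_step_hook(
    estimator, last_step=1000)
estimator.train(input_fn=my_input_fn, hooks=[hook])
```
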
diff --git a/tensorflow/contrib/estimator/python/estimator/hooks_test.py b/tensorflow/contrib/estimator/python/estimator/hooks_test.py
index 42352aa3ff..c6c6cad95a 100644
--- a/tensorflow/contrib/estimator/python/estimator/hooks_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/hooks_test.py
@@ -326,7 +326,7 @@ class StopAtCheckpointStepHookTest(test.TestCase):
step = training.create_global_step()
assign_ten = step.assign(10)
no_op = control_flow_ops.no_op()
- hook = hooks_lib.StopAtCheckpointStepHook(
+ hook = hooks_lib._StopAtCheckpointStepHook(
model_dir=tempfile.mkdtemp(), last_step=10)
with training.SingularMonitoredSession(hooks=[hook]) as mon_sess:
mon_sess.raw_session().run(assign_ten)
@@ -342,7 +342,7 @@ class StopAtCheckpointStepHookTest(test.TestCase):
assign_nine = step.assign(9)
assign_ten = step.assign(10)
no_op = control_flow_ops.no_op()
- hook = hooks_lib.StopAtCheckpointStepHook(
+ hook = hooks_lib._StopAtCheckpointStepHook(
model_dir=model_dir, last_step=10)
with tf_session.Session() as sess:
sess.run(assign_nine)
@@ -360,7 +360,7 @@ class StopAtCheckpointStepHookTest(test.TestCase):
step = training.create_global_step()
assign_ten = step.assign(10)
no_op = control_flow_ops.no_op()
- hook = hooks_lib.StopAtCheckpointStepHook(
+ hook = hooks_lib._StopAtCheckpointStepHook(
model_dir=model_dir, last_step=10)
with tf_session.Session() as sess:
sess.run(assign_ten)
@@ -372,6 +372,32 @@ class StopAtCheckpointStepHookTest(test.TestCase):
self.assertFalse(mock_sleep.called)
self.assertTrue(mon_sess.should_stop())
+ def test_creates_regular_stop_at_step_hook_for_chief(self):
+ # By default, an estimator is in chief mode.
+ dnn = estimator_lib.DNNClassifier(
+ feature_columns=[feature_column_lib.numeric_column('x')],
+ hidden_units=[3, 1])
+ hook = hooks_lib.make_stop_at_checkpoint_step_hook(dnn, 300)
+ self.assertIsInstance(hook, training.StopAtStepHook)
+ self.assertEqual(300, hook._last_step)
+
+ def test_creates_checkpoint_hook_for_workers(self):
+
+ class FakeWorkerConfig(estimator_lib.RunConfig):
+
+ @property
+ def is_chief(self):
+ return False
+
+ dnn = estimator_lib.DNNClassifier(
+ feature_columns=[feature_column_lib.numeric_column('x')],
+ hidden_units=[3, 1],
+ config=FakeWorkerConfig())
+ hook = hooks_lib.make_stop_at_checkpoint_step_hook(dnn, 300)
+ self.assertIsInstance(hook, hooks_lib._StopAtCheckpointStepHook)
+ self.assertEqual(300, hook._last_step)
+ self.assertEqual(dnn.model_dir, hook._model_dir)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
index 3d6fccb118..2b4d5f5261 100644
--- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
@@ -132,7 +132,7 @@ class MultiHeadTest(test.TestCase):
spec.export_outputs.keys())
# Assert predictions and export_outputs.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
predictions = sess.run(spec.predictions)
@@ -202,7 +202,7 @@ class MultiHeadTest(test.TestCase):
spec.export_outputs.keys())
# Assert predictions and export_outputs.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
predictions = sess.run(spec.predictions)
@@ -259,7 +259,7 @@ class MultiHeadTest(test.TestCase):
spec.export_outputs.keys())
# Assert predictions and export_outputs.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
predictions = sess.run(spec.predictions)
@@ -336,7 +336,7 @@ class MultiHeadTest(test.TestCase):
# Assert predictions, loss, and metrics.
tol = 1e-3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -362,7 +362,7 @@ class MultiHeadTest(test.TestCase):
logits=logits,
labels=labels)[0]
tol = 1e-3
- with self.test_session():
+ with self.cached_session():
# Unreduced loss of the head is [[(10 + 10) / 2], (15 + 0) / 2]
# (averaged over classes, averaged over examples).
self.assertAllClose(8.75, loss.eval(), rtol=tol, atol=tol)
@@ -397,7 +397,7 @@ class MultiHeadTest(test.TestCase):
logits=logits,
labels=labels)
tol = 1e-3
- with self.test_session():
+ with self.cached_session():
# loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]]
# = [10, 7.5]
# training_loss = (1 * 10 + 2 * 7.5) / 2 = 12.5
@@ -445,7 +445,7 @@ class MultiHeadTest(test.TestCase):
logits=logits,
labels=labels)
tol = 1e-3
- with self.test_session():
+ with self.cached_session():
# loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]]
# = [10, 7.5]
# training_loss = (1 * 10 + 2 * 7.5) / 2 = 12.5
@@ -498,7 +498,7 @@ class MultiHeadTest(test.TestCase):
logits=logits,
labels=labels)[0]
tol = 1e-3
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)
@@ -535,7 +535,7 @@ class MultiHeadTest(test.TestCase):
# Assert predictions, loss, train_op, and summaries.
tol = 1e-3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -579,7 +579,7 @@ class MultiHeadTest(test.TestCase):
optimizer=_Optimizer())
tol = 1e-3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss, train_result = sess.run((spec.loss, spec.train_op))
self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
@@ -634,7 +634,7 @@ class MultiHeadTest(test.TestCase):
# Assert predictions, loss, train_op, and summaries.
tol = 1e-3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
index dd8a3a95f1..65229d67bb 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
@@ -209,7 +209,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
features = np.array([[1.0], [2.0]])
labels = np.array([[1.0], [2.0]])
- with self.test_session() as session:
+ with self.cached_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn,
loss_reduction=losses.Reduction.SUM,
@@ -233,7 +233,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
features = np.array([[1.0], [2.0]])
labels = np.array([[1.0], [2.0]])
- with self.test_session() as session:
+ with self.cached_session() as session:
# Add another trainable variable that doesn't produce a gradient to
# verify that None gradients are supported.
_ = variable_scope.get_variable(
@@ -275,7 +275,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
# for the second.
expected_c = 10.0 - 3.0, 7.0 - 4.0
- with self.test_session() as session, variable_scope.variable_scope(
+ with self.cached_session() as session, variable_scope.variable_scope(
'', reuse=variable_scope.AUTO_REUSE):
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn,
@@ -299,7 +299,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
features = np.array([[0.01], [0.002]])
labels = np.array([[0.01], [0.02]])
- with self.test_session() as session:
+ with self.cached_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn,
loss_reduction=losses.Reduction.SUM,
@@ -330,7 +330,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
features = np.array([[0.01], [0.002]])
labels = np.array([[0.01], [0.02]])
- with self.test_session() as session:
+ with self.cached_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, losses.Reduction.MEAN, devices=['/gpu:0', '/gpu:1'])
estimator_spec = replicated_model_fn(
@@ -359,7 +359,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
features = np.array([[0.01], [0.002]])
labels = np.array([[0.01], [0.02]])
- with self.test_session() as session:
+ with self.cached_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, devices=['/gpu:0', '/gpu:1'])
estimator_spec = replicated_model_fn(
@@ -374,7 +374,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
features = np.array([[1.0], [2.0]])
labels = np.array([[1.0], [2.0]])
- with self.test_session() as session:
+ with self.cached_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, devices=['/gpu:0'])
estimator_spec = replicated_model_fn(
@@ -396,7 +396,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
features = np.array([[0.01], [0.002]])
labels = np.array([[0.01], [0.02]])
- with self.test_session() as session:
+ with self.cached_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, devices=['/gpu:0'])
estimator_spec = replicated_model_fn(
@@ -424,7 +424,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
features = np.array([[0.01], [0.002]])
labels = np.array([[0.01], [0.02]])
- with self.test_session() as session:
+ with self.cached_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, devices=['/gpu:0'])
estimator_spec = replicated_model_fn(
@@ -456,7 +456,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
features = np.array([[0.01], [0.002]])
labels = np.array([[0.01], [0.02]])
- with self.test_session():
+ with self.cached_session():
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, devices=['/GPU:0'])
_ = replicated_model_fn(
@@ -470,7 +470,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
features = np.array([[0.01], [0.002]])
labels = np.array([[0.01], [0.02]])
- with self.test_session():
+ with self.cached_session():
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, devices=['/gpu:0'])
_ = replicated_model_fn(
@@ -521,7 +521,7 @@ class ReplicateAcrossASingleDeviceWithoutTowerOptimizer(
features = np.array([[1.0], [2.0]])
labels = np.array([[1.0], [2.0]])
- with self.test_session() as session:
+ with self.cached_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, devices=['/gpu:0'])
estimator_spec = replicated_model_fn(
@@ -649,7 +649,7 @@ class ReplicateWithTwoOptimizersTest(test_util.TensorFlowTestCase):
features = np.array([[1.0], [2.0]])
labels = np.array([[1.0], [2.0]])
- with self.test_session() as session:
+ with self.cached_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn,
loss_reduction=losses.Reduction.SUM,
@@ -746,7 +746,7 @@ class ReplicateWithTwoLossesAndOneOptimizer(test_util.TensorFlowTestCase):
features = np.array([[1.0], [2.0]])
labels = np.array([[1.0], [2.0]])
- with self.test_session() as session:
+ with self.cached_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn,
loss_reduction=losses.Reduction.SUM,
@@ -777,7 +777,7 @@ class ReplicateWithTwoLossesAndOneOptimizer(test_util.TensorFlowTestCase):
features = np.array([[1.0], [2.0]])
labels = np.array([[1.0], [2.0]])
- with self.test_session(), ops_lib.Graph().as_default():
+ with self.cached_session(), ops_lib.Graph().as_default():
with self.assertRaisesRegexp(
ValueError, '.+was.+supposed.+to.+make.+same.+optimizer.+calls.+'):
replicated_model_fn = replicate_model_fn.replicate_model_fn(
@@ -819,7 +819,7 @@ class FailToWrapOptimizerInTheModelFn(test_util.TensorFlowTestCase):
features = np.array([[1.0], [2.0]])
labels = np.array([[1.0], [2.0]])
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError,
'Please.+wrap.+with.+TowerOptimizer'):
replicated_model_fn = replicate_model_fn.replicate_model_fn(
@@ -845,7 +845,7 @@ class GetLossTowersTest(test_util.TensorFlowTestCase):
return model_fn_lib.EstimatorSpec(mode=mode, loss=math_ops.reduce_sum(loss))
def test_gradients_are_computed(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
tower_specs = replicate_model_fn._get_loss_towers(
self.model_fn,
mode=None,
@@ -879,7 +879,7 @@ class GetLossTowersTest(test_util.TensorFlowTestCase):
self.assertEqual(0.25, session.run(c))
def test_gradients_are_computed_with_mean_reduction(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
tower_specs = replicate_model_fn._get_loss_towers(
self.model_fn,
mode=model_fn_lib.ModeKeys.EVAL,
@@ -932,7 +932,7 @@ class GetLossTowersTest(test_util.TensorFlowTestCase):
return model_fn_lib.EstimatorSpec(
mode=mode, loss=math_ops.reduce_sum(loss))
- with self.test_session() as session:
+ with self.cached_session() as session:
tower_specs = replicate_model_fn._get_loss_towers(
model_fn,
mode=None,
@@ -975,7 +975,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual(a.dense_shape, b.dense_shape)
def test_simple_half_split(self):
- with self.test_session():
+ with self.cached_session():
features = [0.0, 1.0, 2.0, 3.0]
labels = [10.0, 11.0, 12.0, 13.0]
feature_shards, label_shards = replicate_model_fn._split_batch(
@@ -988,7 +988,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[10.0, 11.0], [12.0, 13.0]], label_shards)
def test_to_each_their_own(self):
- with self.test_session():
+ with self.cached_session():
features = [0.0, 1.0, 2.0, 3.0]
labels = [10.0, 11.0, 12.0, 13.0]
feature_shards, label_shards = replicate_model_fn._split_batch(
@@ -1001,7 +1001,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[10.0], [11.0], [12.0], [13.0]], label_shards)
def test_one_batch(self):
- with self.test_session():
+ with self.cached_session():
features = [0.0, 1.0, 2.0, 3.0]
labels = [10.0, 11.0, 12.0, 13.0]
feature_shards, label_shards = replicate_model_fn._split_batch(
@@ -1014,7 +1014,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[10.0, 11.0, 12.0, 13.0]], label_shards)
def test_half_split_in_dictionary(self):
- with self.test_session():
+ with self.cached_session():
features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
labels = [10.0, 11.0, 12.0, 13.0]
@@ -1029,7 +1029,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual([12.0, 13.0], label_shards[1].eval())
def test_sparse_tensor_can_be_split_unevenly(self):
- with self.test_session():
+ with self.cached_session():
features = {
'x':
sparse_tensor.SparseTensor(
@@ -1054,7 +1054,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[2.0]], label_shards[1].eval())
def test_sparse_tensor_can_be_split_unevenly_repeated_row(self):
- with self.test_session():
+ with self.cached_session():
features = {
'x':
sparse_tensor.SparseTensor(
@@ -1081,7 +1081,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[2.0]], label_shards[1].eval())
def test_one_batch_in_dictionary(self):
- with self.test_session() as session: # pylint: disable=unused-variable
+ with self.cached_session() as session: # pylint: disable=unused-variable
features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
labels = [10.0, 11.0, 12.0, 13.0]
@@ -1095,7 +1095,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual([10.0, 11.0, 12.0, 13.0], label_shards[0].eval())
def test_feature_and_label_dictionaries(self):
- with self.test_session() as session: # pylint: disable=unused-variable
+ with self.cached_session() as session: # pylint: disable=unused-variable
features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
labels = {'first': [10.0, 11.0], 'second': [12.0, 13.0]}
@@ -1127,7 +1127,7 @@ class TrainSpecTest(test_util.TensorFlowTestCase):
return constant_op.constant(loss_value, dtype=dtypes.float64)
def test_example(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
tower_losses = list(map(self.create_constant_loss, [2, 4, 6]))
tower_specs = list(map(self.create_estimator_spec, tower_losses))
@@ -1161,7 +1161,7 @@ class EvalSpecTest(test_util.TensorFlowTestCase):
return metrics
def test_example(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
tower_losses = map(self.create_constant_loss, [2, 4, 6])
tower_metrics = map(self.create_eval_metrics, [0, 0.2, 0.3])
tower_specs = [
@@ -1187,7 +1187,7 @@ class EvalSpecTest(test_util.TensorFlowTestCase):
self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss))
def test_handles_single_tower(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
tower_losses = map(self.create_constant_loss, [5])
tower_metrics = map(self.create_eval_metrics, [0.2])
tower_specs = [
@@ -1231,7 +1231,7 @@ class PredictSpecTest(test_util.TensorFlowTestCase):
})
def test_example(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
tower_specs = replicate_model_fn._get_loss_towers(
self.model_fn,
mode=None,
@@ -1273,7 +1273,7 @@ class ReduceMetricVariablesTest(test_util.TensorFlowTestCase):
np.array([3.3, 3.5, 3.7]) * (tower_id + 1), 'total')
def test_example(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
for tower_id in range(3):
self.create_tower_metrics(tower_id)
@@ -1303,7 +1303,7 @@ class ReduceMetricVariablesTest(test_util.TensorFlowTestCase):
self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
def test_reduce_is_idempotent(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
for tower_id in range(3):
self.create_tower_metrics(tower_id)
@@ -1329,7 +1329,7 @@ class ReduceMetricVariablesTest(test_util.TensorFlowTestCase):
self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
def test_handles_single_tower(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
self.create_tower_metrics(0)
session.run(
variables.variables_initializer(
@@ -1346,7 +1346,7 @@ class ReduceMetricVariablesTest(test_util.TensorFlowTestCase):
self.assertAllClose([3.3, 3.5, 3.7], local_metrics[2], 0.01)
def test_doesnt_accept_uneven_number_of_variables(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
for tower_id in range(3):
self.create_tower_metrics(tower_id)
self.create_metric_variable(-1.0, 'oddball')
@@ -1418,7 +1418,7 @@ class MergeExportOutputsTest(test_util.TensorFlowTestCase):
return estimator_spec
def test_merge_predict_output(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
estimator_spec = self.replicate_estimator_spec(session)
self.assertAllClose(
{
@@ -1428,7 +1428,7 @@ class MergeExportOutputsTest(test_util.TensorFlowTestCase):
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs))
def test_merge_classification_output_scores_classes(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
estimator_spec = self.replicate_estimator_spec(session)
self.assertAllClose(
[0.1, 0.02],
@@ -1440,7 +1440,7 @@ class MergeExportOutputsTest(test_util.TensorFlowTestCase):
estimator_spec.export_outputs['classification_output'].classes))
def test_merge_classification_output_scores(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
estimator_spec = self.replicate_estimator_spec(session)
self.assertAllClose(
[0.1, 0.02],
@@ -1450,7 +1450,7 @@ class MergeExportOutputsTest(test_util.TensorFlowTestCase):
None, estimator_spec.export_outputs['classification_scores'].classes)
def test_merge_classification_output_classes(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
estimator_spec = self.replicate_estimator_spec(session)
self.assertAllEqual(
[b'split_inputs/split:0', b'split_inputs/split:1'],
@@ -1460,7 +1460,7 @@ class MergeExportOutputsTest(test_util.TensorFlowTestCase):
None, estimator_spec.export_outputs['classification_classes'].scores)
def test_merge_regression_output(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
estimator_spec = self.replicate_estimator_spec(session)
self.assertAllClose(
[0.1, 0.02],
@@ -1548,7 +1548,7 @@ class LocalDeviceSetterTest(test_util.TensorFlowTestCase):
class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase):
def test_vectors(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
total = replicate_model_fn._compute_sum_on_device(
[1.0, 2.0, 3.0, 4.0], device='/device:GPU:0', name='test_sum')
@@ -1557,7 +1557,7 @@ class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase):
self.assertEqual(10.0, session.run(total))
def test_tensors(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
total = replicate_model_fn._compute_sum_on_device(
[[1.0, 2.0], [3.0, 4.0]], device='/device:GPU:0', name='test_sum')
@@ -1566,7 +1566,7 @@ class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase):
self.assertAllEqual([4.0, 6.0], session.run(total))
def test_indexedslices(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
a = ops_lib.IndexedSlices(
constant_op.constant([1.0, 2.0]), [0, 1],
dense_shape=constant_op.constant([2]))
@@ -1580,7 +1580,7 @@ class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase):
session.run(ops_lib.convert_to_tensor(total)))
def test_indexedslices_higher_dimensions(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
a = ops_lib.IndexedSlices(
constant_op.constant([[1.0, 5.0], [2.0, 6.0]]), [0, 1],
dense_shape=constant_op.constant([2, 4]))
@@ -1595,7 +1595,7 @@ class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase):
session.run(ops_lib.convert_to_tensor(total)))
def test_indexedslices_some_dont_overlap(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
a = ops_lib.IndexedSlices(
constant_op.constant([1.0, 2.0]), [0, 3],
dense_shape=constant_op.constant([4]))
@@ -1637,7 +1637,7 @@ class ConcatTensorDictsTest(test_util.TensorFlowTestCase):
},
]
- with self.test_session() as session:
+ with self.cached_session() as session:
self.assertAllClose({
'a': np.array([1.0, 2.0, 3.0]),
'b': np.array([11.0, 12.0, 13.0, 14.0]),
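The IndexedSlices cases above all funnel through `replicate_model_fn._compute_sum_on_device`, which accumulates sparse (indices, values) updates against a dense shape. A minimal NumPy sketch of the accumulation semantics the tests check — the helper is illustrative, not the actual implementation, and the second operand below is invented since the hunk only shows `a`:

import numpy as np

def indexed_slices_sum(slices, dense_shape):
    # Dense view of summing IndexedSlices-style (indices, values) pairs;
    # np.add.at accumulates where indices overlap, as the tests exercise.
    total = np.zeros(dense_shape, dtype=np.float32)
    for indices, values in slices:
        np.add.at(total, indices, values)
    return total

# Mirrors test_indexedslices_some_dont_overlap: indices [0, 3] plus a
# hypothetical second operand at [0, 1].
print(indexed_slices_sum([([0, 3], [1.0, 2.0]),
                          ([0, 1], [3.0, 4.0])], dense_shape=(4,)))
# -> [4. 4. 0. 2.]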
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py
index 7c49cd00d1..98660bb731 100644
--- a/tensorflow/contrib/estimator/python/estimator/rnn.py
+++ b/tensorflow/contrib/estimator/python/estimator/rnn.py
@@ -37,6 +37,7 @@ from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import training_util
@@ -405,6 +406,7 @@ class RNNClassifier(estimator.Estimator):
weight_column=None,
label_vocabulary=None,
optimizer='Adagrad',
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
input_layer_partitioner=None,
config=None):
"""Initializes a `RNNClassifier` instance.
@@ -454,6 +456,8 @@ class RNNClassifier(estimator.Estimator):
string.
optimizer: An instance of `tf.Optimizer` or string specifying optimizer
type. Defaults to Adagrad optimizer.
+ loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
+ to reduce the training loss over the batch. Defaults to `SUM_OVER_BATCH_SIZE`.
input_layer_partitioner: Optional. Partitioner for input layer. Defaults
to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
config: `RunConfig` object to configure the runtime settings.
@@ -467,11 +471,15 @@ class RNNClassifier(estimator.Estimator):
if n_classes == 2:
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access
weight_column=weight_column,
- label_vocabulary=label_vocabulary)
+ label_vocabulary=label_vocabulary,
+ loss_reduction=loss_reduction)
else:
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
- n_classes, weight_column=weight_column,
- label_vocabulary=label_vocabulary)
+ n_classes,
+ weight_column=weight_column,
+ label_vocabulary=label_vocabulary,
+ loss_reduction=loss_reduction)
+
def _model_fn(features, labels, mode, config):
return _rnn_model_fn(
features=features,
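With the patch applied, callers can opt into a different reduction at construction time. A usage sketch, assuming the `tf.contrib.estimator.RNNClassifier` export on this branch; the `'price'` column mirrors the test file below, while `num_units=[32]` is an illustrative value not taken from this hunk:

import tensorflow as tf

classifier = tf.contrib.estimator.RNNClassifier(
    sequence_feature_columns=[
        tf.contrib.feature_column.sequence_numeric_column('price')],
    num_units=[32],  # illustrative; not shown in this diff
    # New in this change; the default moves from SUM to SUM_OVER_BATCH_SIZE.
    loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE)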
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn_test.py b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
index 959b40371a..1aebed348d 100644
--- a/tensorflow/contrib/estimator/python/estimator/rnn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
@@ -713,7 +713,7 @@ class RNNClassifierTrainingTest(test.TestCase):
# Uses same checkpoint and examples as testBinaryClassEvaluationMetrics.
# See that test for loss calculation.
- mock_optimizer = self._mock_optimizer(expected_loss=1.119661)
+ mock_optimizer = self._mock_optimizer(expected_loss=0.559831)
sequence_feature_columns = [
seq_fc.sequence_numeric_column('price', shape=(1,))]
@@ -748,7 +748,7 @@ class RNNClassifierTrainingTest(test.TestCase):
# Uses same checkpoint and examples as testMultiClassEvaluationMetrics.
# See that test for loss calculation.
- mock_optimizer = self._mock_optimizer(expected_loss=2.662932)
+ mock_optimizer = self._mock_optimizer(expected_loss=1.331465)
sequence_feature_columns = [
seq_fc.sequence_numeric_column('price', shape=(1,))]
@@ -812,20 +812,32 @@ class RNNClassifierEvaluationTest(test.TestCase):
# probability = exp(logits) / (1 + exp(logits)) = [[0.353593], [0.504930]]
# loss = -label * ln(p) - (1 - label) * ln(1 - p)
# = [[0.436326], [0.683335]]
+ # sum_over_batch_size = (0.436326 + 0.683335)/2
expected_metrics = {
- ops.GraphKeys.GLOBAL_STEP: global_step,
- metric_keys.MetricKeys.LOSS: 1.119661,
- metric_keys.MetricKeys.LOSS_MEAN: 0.559831,
- metric_keys.MetricKeys.ACCURACY: 1.0,
- metric_keys.MetricKeys.PREDICTION_MEAN: 0.429262,
- metric_keys.MetricKeys.LABEL_MEAN: 0.5,
- metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,
+ ops.GraphKeys.GLOBAL_STEP:
+ global_step,
+ metric_keys.MetricKeys.LOSS:
+ 0.559831,
+ metric_keys.MetricKeys.LOSS_MEAN:
+ 0.559831,
+ metric_keys.MetricKeys.ACCURACY:
+ 1.0,
+ metric_keys.MetricKeys.PREDICTION_MEAN:
+ 0.429262,
+ metric_keys.MetricKeys.LABEL_MEAN:
+ 0.5,
+ metric_keys.MetricKeys.ACCURACY_BASELINE:
+ 0.5,
# With default threshold of 0.5, the model is a perfect classifier.
- metric_keys.MetricKeys.RECALL: 1.0,
- metric_keys.MetricKeys.PRECISION: 1.0,
+ metric_keys.MetricKeys.RECALL:
+ 1.0,
+ metric_keys.MetricKeys.PRECISION:
+ 1.0,
# Positive example is scored above negative, so AUC = 1.0.
- metric_keys.MetricKeys.AUC: 1.0,
- metric_keys.MetricKeys.AUC_PR: 1.0,
+ metric_keys.MetricKeys.AUC:
+ 1.0,
+ metric_keys.MetricKeys.AUC_PR:
+ 1.0,
}
self.assertAllClose(
sorted_key_dict(expected_metrics), sorted_key_dict(eval_metrics))
@@ -871,9 +883,10 @@ class RNNClassifierEvaluationTest(test.TestCase):
# [0.059494, 0.572639, 0.367866]]
# loss = -1. * log(softmax[label])
# = [[2.105432], [0.557500]]
+ # sum_over_batch_size = (2.105432 + 0.557500)/2
expected_metrics = {
ops.GraphKeys.GLOBAL_STEP: global_step,
- metric_keys.MetricKeys.LOSS: 2.662932,
+ metric_keys.MetricKeys.LOSS: 1.331465,
metric_keys.MetricKeys.LOSS_MEAN: 1.331466,
metric_keys.MetricKeys.ACCURACY: 0.5,
}
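The halved expectations follow directly from the per-example losses quoted in the comments, since `SUM_OVER_BATCH_SIZE` divides the summed loss by the batch size of 2. A quick plain-Python check:

# Binary head: per-example losses 0.436326 and 0.683335 from the comment.
assert abs((0.436326 + 0.683335) / 2 - 0.559831) < 1e-6
# Multi-class head: per-example losses 2.105432 and 0.557500; matches the
# new expected LOSS of ~1.331465 up to rounding.
assert abs((2.105432 + 0.557500) / 2 - 1.331466) < 1e-6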
diff --git a/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py b/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py
index 1322f7ce5f..db47073fcc 100644
--- a/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py
+++ b/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py
@@ -41,7 +41,7 @@ class KmeansPlusPlusInitializationTest(test.TestCase):
[-1., -1.]]).astype(np.float32)
def runTestWithSeed(self, seed):
- with self.test_session():
+ with self.cached_session():
sampled_points = clustering_ops.kmeans_plus_plus_initialization(
self._points, 3, seed, (seed % 5) - 1)
self.assertAllClose(
@@ -58,7 +58,7 @@ class KmeansPlusPlusInitializationTest(test.TestCase):
class KMC2InitializationTest(test.TestCase):
def runTestWithSeed(self, seed):
- with self.test_session():
+ with self.cached_session():
distances = np.zeros(1000).astype(np.float32)
distances[6] = 10e7
distances[4] = 10e3
@@ -82,7 +82,7 @@ class KMC2InitializationLargeTest(test.TestCase):
self._distances[1000] = 50.0
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
counts = {}
seed = 0
for i in range(50):
@@ -102,7 +102,7 @@ class KMC2InitializationCornercaseTest(test.TestCase):
self._distances = np.zeros(10)
def runTestWithSeed(self, seed):
- with self.test_session():
+ with self.cached_session():
sampled_point = clustering_ops.kmc2_chain_initialization(
self._distances, seed)
self.assertEquals(sampled_point.eval(), 0)
@@ -128,14 +128,14 @@ class NearestCentersTest(test.TestCase):
[1., 1.]]).astype(np.float32)
def testNearest1(self):
- with self.test_session():
+ with self.cached_session():
[indices, distances] = clustering_ops.nearest_neighbors(self._points,
self._centers, 1)
self.assertAllClose(indices.eval(), [[0], [0], [1], [4]])
self.assertAllClose(distances.eval(), [[0.], [5.], [1.], [0.]])
def testNearest2(self):
- with self.test_session():
+ with self.cached_session():
[indices, distances] = clustering_ops.nearest_neighbors(self._points,
self._centers, 2)
self.assertAllClose(indices.eval(), [[0, 1], [0, 1], [1, 0], [4, 3]])
@@ -180,7 +180,7 @@ class NearestCentersLargeTest(test.TestCase):
expected_nearest_neighbor_squared_distances))
def testNearest1(self):
- with self.test_session():
+ with self.cached_session():
[indices, distances] = clustering_ops.nearest_neighbors(self._points,
self._centers, 1)
self.assertAllClose(indices.eval(),
@@ -190,7 +190,7 @@ class NearestCentersLargeTest(test.TestCase):
self._expected_nearest_neighbor_squared_distances[:, [0]])
def testNearest5(self):
- with self.test_session():
+ with self.cached_session():
[indices, distances] = clustering_ops.nearest_neighbors(self._points,
self._centers, 5)
self.assertAllClose(indices.eval(),
diff --git a/tensorflow/contrib/factorization/python/kernel_tests/masked_matmul_ops_test.py b/tensorflow/contrib/factorization/python/kernel_tests/masked_matmul_ops_test.py
index 3a909e2373..dd115735d0 100644
--- a/tensorflow/contrib/factorization/python/kernel_tests/masked_matmul_ops_test.py
+++ b/tensorflow/contrib/factorization/python/kernel_tests/masked_matmul_ops_test.py
@@ -58,7 +58,7 @@ class MaskedProductOpsTest(test.TestCase):
self._mask_ind, self._mask_shape = MakeMask()
def _runTestMaskedProduct(self, transpose_a, transpose_b):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
a = self._a if not transpose_a else array_ops.transpose(self._a)
b = self._b if not transpose_b else array_ops.transpose(self._b)
@@ -78,7 +78,7 @@ class MaskedProductOpsTest(test.TestCase):
AssertClose(result, true_result)
def _runTestEmptyMaskedProduct(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
empty_mask = constant_op.constant(0, shape=[0, 2], dtype=dtypes.int64)
values = gen_factorization_ops.masked_matmul(
self._a, self._b, empty_mask, False, False)
diff --git a/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py b/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py
index 6c2f1d4608..8a16e22663 100644
--- a/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py
+++ b/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py
@@ -50,7 +50,7 @@ class WalsSolverOpsTest(test.TestCase):
def testWalsSolverLhs(self):
sparse_block = SparseBlock3x3()
- with self.test_session():
+ with self.cached_session():
[lhs_tensor,
rhs_matrix] = gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
self._column_factors, self._column_weights, self._unobserved_weights,
@@ -82,7 +82,7 @@ class WalsSolverOpsTest(test.TestCase):
def testWalsSolverLhsEntryWeights(self):
sparse_block = SparseBlock3x3()
- with self.test_session():
+ with self.cached_session():
[lhs_tensor,
rhs_matrix] = gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
self._column_factors, [], self._unobserved_weights,
diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
index bb5140aeb3..6aa62fb82e 100644
--- a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
+++ b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
@@ -126,7 +126,7 @@ class WalsModelTest(test.TestCase):
observed *= num_rows / 3. if test_rows else num_cols / 2.
want_weight_sum = unobserved + observed
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
wals_model = factorization_ops.WALSModel(
input_rows=num_rows,
input_cols=num_cols,
@@ -161,7 +161,7 @@ class WalsModelTest(test.TestCase):
def _run_test_process_input(self,
use_factors_weights_cache,
compute_loss=False):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
self._wals_inputs = self.sparse_input()
sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
num_rows = 5
@@ -330,7 +330,7 @@ class WalsModelTest(test.TestCase):
def _run_test_process_input_transposed(self,
use_factors_weights_cache,
compute_loss=False):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
self._wals_inputs = self.sparse_input()
sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
num_rows = 5
@@ -505,7 +505,7 @@ class WalsModelTest(test.TestCase):
# trigger the more efficient ALS updates.
# Here we test that those two give identical results.
def _run_test_als(self, use_factors_weights_cache):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
self._wals_inputs = self.sparse_input()
col_init = np.random.rand(7, 3)
als_model = factorization_ops.WALSModel(
@@ -583,7 +583,7 @@ class WalsModelTest(test.TestCase):
atol=1e-2)
def _run_test_als_transposed(self, use_factors_weights_cache):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
self._wals_inputs = self.sparse_input()
col_init = np.random.rand(7, 3)
als_model = factorization_ops.WALSModel(
@@ -673,7 +673,7 @@ class WalsModelTest(test.TestCase):
rows = 15
cols = 11
dims = 3
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
data = np.dot(np.random.rand(rows, 3), np.random.rand(
3, cols)).astype(np.float32) / 3.0
indices = [[i, j] for i in xrange(rows) for j in xrange(cols)]
@@ -703,7 +703,7 @@ class WalsModelTest(test.TestCase):
cols = 11
dims = 3
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
data = np.dot(np.random.rand(rows, 3), np.random.rand(
3, cols)).astype(np.float32) / 3.0
indices = [[i, j] for i in xrange(rows) for j in xrange(cols)]
@@ -736,7 +736,7 @@ class WalsModelTest(test.TestCase):
def keep_index(x):
return not (x[0] + x[1]) % 4
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
row_wts = 0.1 + np.random.rand(rows)
col_wts = 0.1 + np.random.rand(cols)
data = np.dot(np.random.rand(rows, 3), np.random.rand(
diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py b/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py
index 888c3c238c..112e4d289b 100644
--- a/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py
+++ b/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py
@@ -99,7 +99,7 @@ class GmmOpsTest(test.TestCase):
logging.info('Numpy took %f', time.time() - start_time)
start_time = time.time()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
op = gmm_ops._covariance(
constant_op.constant(
data.T, dtype=dtypes.float32), False)
@@ -120,7 +120,7 @@ class GmmOpsTest(test.TestCase):
graph = ops.Graph()
with graph.as_default() as g:
g.seed = 5
- with self.test_session() as sess:
+ with self.cached_session() as sess:
data = constant_op.constant(self.data, dtype=dtypes.float32)
loss_op, scores, assignments, training_op, init_op, _ = gmm_ops.gmm(
data, 'random', num_classes, random_seed=self.seed)
@@ -144,7 +144,7 @@ class GmmOpsTest(test.TestCase):
def testParams(self):
"""Tests that the params work as intended."""
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Experiment 1. Update weights only.
data = constant_op.constant(self.data, dtype=dtypes.float32)
gmm_tool = gmm_ops.GmmAlgorithm([data], num_classes,
diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py
index 88eb9cf692..1ab5418fe4 100644
--- a/tensorflow/contrib/factorization/python/ops/kmeans_test.py
+++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py
@@ -232,7 +232,7 @@ class KMeansTest(KMeansTestBase):
self.assertEqual(features.shape, parsed_feature_dict.shape)
self.assertEqual(features.dtype, parsed_feature_dict.dtype)
# Then check that running the tensor yields the original list of points.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
parsed_points = sess.run(parsed_feature_dict)
self.assertAllEqual(self.points, parsed_points)
diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py
index ca46c39baa..b82bf1188f 100644
--- a/tensorflow/contrib/factorization/python/ops/wals.py
+++ b/tensorflow/contrib/factorization/python/ops/wals.py
@@ -377,64 +377,68 @@ class WALSMatrixFactorization(estimator.Estimator):
WALS (Weighted Alternating Least Squares) is an algorithm for weighted matrix
factorization. It computes a low-rank approximation of a given sparse (n x m)
- matrix A, by a product of two matrices, U * V^T, where U is a (n x k) matrix
- and V is a (m x k) matrix. Here k is the rank of the approximation, also
- called the embedding dimension. We refer to U as the row factors, and V as the
- column factors.
+ matrix `A`, by a product of two matrices, `U * V^T`, where `U` is a (n x k)
+ matrix and `V` is a (m x k) matrix. Here k is the rank of the approximation,
+ also called the embedding dimension. We refer to `U` as the row factors, and
+ `V` as the column factors.
See tensorflow/contrib/factorization/g3doc/wals.md for the precise problem
formulation.
- The training proceeds in sweeps: during a row_sweep, we fix V and solve for U.
- During a column sweep, we fix U and solve for V. Each one of these problems is
- an unconstrained quadratic minimization problem and can be solved exactly (it
- can also be solved in mini-batches, since the solution decouples nicely).
+ The training proceeds in sweeps: during a row_sweep, we fix `V` and solve for
+ `U`. During a column sweep, we fix `U` and solve for `V`. Each one of these
+ problems is an unconstrained quadratic minimization problem and can be solved
+ exactly (it can also be solved in mini-batches, since the solution decouples
+ across rows of each matrix).
The alternating between sweeps is achieved by using a hook during training,
which is responsible for keeping track of the sweeps and running preparation
ops at the beginning of each sweep. It also updates the global_step variable,
which keeps track of the number of batches processed since the beginning of
training.
The current implementation assumes that the training is run on a single
- machine, and will fail if config.num_worker_replicas is not equal to one.
- Training is done by calling self.fit(input_fn=input_fn), where input_fn
+ machine, and will fail if `config.num_worker_replicas` is not equal to one.
+ Training is done by calling `self.fit(input_fn=input_fn)`, where `input_fn`
provides two tensors: one for rows of the input matrix, and one for rows of
the transposed input matrix (i.e. columns of the original matrix). Note that
during a row sweep, only row batches are processed (ignoring column batches)
and vice-versa.
Also note that every row (respectively every column) of the input matrix
must be processed at least once for the sweep to be considered complete. In
- particular, training will not make progress if input_fn does not generate some
- rows.
-
- For prediction, given a new set of input rows A' (e.g. new rows of the A
- matrix), we compute a corresponding set of row factors U', such that U' * V^T
- is a good approximation of A'. We call this operation a row projection. A
- similar operation is defined for columns.
- Projection is done by calling self.get_projections(input_fn=input_fn), where
- input_fn satisfies the constraints given below.
-
- The input functions must satisfy the following constraints: Calling input_fn
- must return a tuple (features, labels) where labels is None, and features is
- a dict containing the following keys:
+ particular, training will not make progress if some rows are not generated by
+ the `input_fn`.
+
+ For prediction, given a new set of input rows `A'`, we compute a corresponding
+ set of row factors `U'`, such that `U' * V^T` is a good approximation of `A'`.
+ We call this operation a row projection. A similar operation is defined for
+ columns. Projection is done by calling
+ `self.get_projections(input_fn=input_fn)`, where `input_fn` satisfies the
+ constraints given below.
+
+ The input functions must satisfy the following constraints: Calling `input_fn`
+ must return a tuple `(features, labels)` where `labels` is None, and
+ `features` is a dict containing the following keys:
+
TRAIN:
- - WALSMatrixFactorization.INPUT_ROWS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix).
Rows of the input matrix to process (or to project).
- - WALSMatrixFactorization.INPUT_COLS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix).
Columns of the input matrix to process (or to project), transposed.
+
INFER:
- - WALSMatrixFactorization.INPUT_ROWS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix).
Rows to project.
- - WALSMatrixFactorization.INPUT_COLS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix).
Columns to project.
- - WALSMatrixFactorization.PROJECT_ROW: Boolean Tensor. Whether to project
+ * `WALSMatrixFactorization.PROJECT_ROW`: Boolean Tensor. Whether to project
the rows or columns.
- - WALSMatrixFactorization.PROJECTION_WEIGHTS (Optional): float32 Tensor
+ * `WALSMatrixFactorization.PROJECTION_WEIGHTS` (Optional): float32 Tensor
(vector). The weights to use in the projection.
+
EVAL:
- - WALSMatrixFactorization.INPUT_ROWS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix).
Rows to project.
- - WALSMatrixFactorization.INPUT_COLS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix).
Columns to project.
- - WALSMatrixFactorization.PROJECT_ROW: Boolean Tensor. Whether to project
+ * `WALSMatrixFactorization.PROJECT_ROW`: Boolean Tensor. Whether to project
the rows or columns.
"""
# Keys to be used in model_fn
@@ -469,7 +473,7 @@ class WALSMatrixFactorization(estimator.Estimator):
max_sweeps=None,
model_dir=None,
config=None):
- """Creates a model for matrix factorization using the WALS method.
+ r"""Creates a model for matrix factorization using the WALS method.
Args:
num_rows: Total number of rows for input matrix.
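For intuition, the sweep structure the docstring describes reduces, in the dense unweighted case, to alternating exact least-squares solves. A self-contained NumPy sketch (illustrative only; the estimator additionally handles weights, sparsity, mini-batching, and the sweep hook):

import numpy as np

def als_sketch(A, k, sweeps=20, reg=1e-3):
    # Alternating least squares for A ~= U V^T: each sweep fixes one
    # factor and solves the other exactly (ridge-regularized for stability).
    n, m = A.shape
    U = np.random.rand(n, k)
    V = np.random.rand(m, k)
    eye = reg * np.eye(k)
    for _ in range(sweeps):
        U = A.dot(V).dot(np.linalg.inv(V.T.dot(V) + eye))    # row sweep
        V = A.T.dot(U).dot(np.linalg.inv(U.T.dot(U) + eye))  # column sweep
    return U, V  # U.dot(V.T) approximates A

A row projection for new rows A' is the same row-sweep solve with the trained V held fixed.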
diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py
index 36b483c6d7..9bdbd05015 100644
--- a/tensorflow/contrib/factorization/python/ops/wals_test.py
+++ b/tensorflow/contrib/factorization/python/ops/wals_test.py
@@ -125,11 +125,13 @@ class WALSMatrixFactorizationTest(test.TestCase):
nz_row_ids = np.arange(np.shape(np_matrix)[0])
nz_col_ids = np.arange(np.shape(np_matrix)[1])
- def extract_features(row_batch, col_batch, shape):
+ def extract_features(row_batch, col_batch, num_rows, num_cols):
row_ids = row_batch[0]
col_ids = col_batch[0]
- rows = self.remap_sparse_tensor_rows(row_batch[1], row_ids, shape)
- cols = self.remap_sparse_tensor_rows(col_batch[1], col_ids, shape)
+ rows = self.remap_sparse_tensor_rows(
+ row_batch[1], row_ids, shape=[num_rows, num_cols])
+ cols = self.remap_sparse_tensor_rows(
+ col_batch[1], col_ids, shape=[num_cols, num_rows])
features = {
wals_lib.WALSMatrixFactorization.INPUT_ROWS: rows,
wals_lib.WALSMatrixFactorization.INPUT_COLS: cols,
@@ -154,7 +156,7 @@ class WALSMatrixFactorizationTest(test.TestCase):
capacity=10,
enqueue_many=True)
- features = extract_features(row_batch, col_batch, sp_mat.dense_shape)
+ features = extract_features(row_batch, col_batch, num_rows, num_cols)
if mode == model_fn.ModeKeys.INFER or mode == model_fn.ModeKeys.EVAL:
self.assertTrue(
@@ -334,7 +336,7 @@ class WALSMatrixFactorizationTest(test.TestCase):
loss = self._model.evaluate(
input_fn=eval_input_fn_row, steps=self._num_rows)['loss']
- with self.test_session():
+ with self.cached_session():
true_loss = self.calculate_loss()
self.assertNear(
@@ -352,7 +354,7 @@ class WALSMatrixFactorizationTest(test.TestCase):
loss = self._model.evaluate(
input_fn=eval_input_fn_col, steps=self._num_cols)['loss']
- with self.test_session():
+ with self.cached_session():
true_loss = self.calculate_loss()
self.assertNear(
@@ -438,7 +440,7 @@ class SweepHookTest(test.TestCase):
math_ops.logical_not(is_row_sweep_var)))
mark_sweep_done = state_ops.assign(is_sweep_done_var, True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sweep_hook = wals_lib._SweepHook(
is_row_sweep_var,
is_sweep_done_var,
@@ -489,7 +491,7 @@ class StopAtSweepHookTest(test.TestCase):
train_op = state_ops.assign_add(completed_sweeps, 1)
hook.begin()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([variables.global_variables_initializer()])
mon_sess = monitored_session._HookedSession(sess, [hook])
mon_sess.run(train_op)
diff --git a/tensorflow/contrib/ffmpeg/__init__.py b/tensorflow/contrib/ffmpeg/__init__.py
index 484ffee3e7..3a756da932 100644
--- a/tensorflow/contrib/ffmpeg/__init__.py
+++ b/tensorflow/contrib/ffmpeg/__init__.py
@@ -15,7 +15,7 @@
# pylint: disable=g-short-docstring-punctuation
"""Working with audio using FFmpeg.
-See the @{$python/contrib.ffmpeg} guide.
+See the [FFMPEG](https://tensorflow.org/api_guides/python/contrib.ffmpeg) guide.
@@decode_audio
@@encode_audio
diff --git a/tensorflow/contrib/ffmpeg/decode_audio_op_test.py b/tensorflow/contrib/ffmpeg/decode_audio_op_test.py
index 3dc663bb6f..784da1c432 100644
--- a/tensorflow/contrib/ffmpeg/decode_audio_op_test.py
+++ b/tensorflow/contrib/ffmpeg/decode_audio_op_test.py
@@ -56,7 +56,7 @@ class DecodeAudioOpTest(test.TestCase):
"""
if samples_per_second_tensor is None:
samples_per_second_tensor = samples_per_second
- with self.test_session():
+ with self.cached_session():
path = os.path.join(resource_loader.get_data_files_path(), 'testdata',
filename)
with open(path, 'rb') as f:
@@ -123,7 +123,7 @@ class DecodeAudioOpTest(test.TestCase):
self._loadFileAndTest('mono_10khz.ogg', 'ogg', 0.57, 10000, 1)
def testInvalidFile(self):
- with self.test_session():
+ with self.cached_session():
contents = 'invalid file'
audio_op = ffmpeg.decode_audio(
contents,
@@ -168,7 +168,7 @@ class DecodeAudioOpTest(test.TestCase):
self._loadFileAndTest('mono_16khz.mp3', 'docx', 0.57, 20000, 1)
def testStaticShapeInference_ConstantChannelCount(self):
- with self.test_session():
+ with self.cached_session():
audio_op = ffmpeg.decode_audio(b'~~~ wave ~~~',
file_format='wav',
samples_per_second=44100,
@@ -176,7 +176,7 @@ class DecodeAudioOpTest(test.TestCase):
self.assertEqual([None, 2], audio_op.shape.as_list())
def testStaticShapeInference_NonConstantChannelCount(self):
- with self.test_session():
+ with self.cached_session():
channel_count = array_ops.placeholder(dtypes.int32)
audio_op = ffmpeg.decode_audio(b'~~~ wave ~~~',
file_format='wav',
@@ -185,7 +185,7 @@ class DecodeAudioOpTest(test.TestCase):
self.assertEqual([None, None], audio_op.shape.as_list())
def testStaticShapeInference_ZeroChannelCountInvalid(self):
- with self.test_session():
+ with self.cached_session():
with six.assertRaisesRegex(self, Exception,
r'channel_count must be positive'):
ffmpeg.decode_audio(b'~~~ wave ~~~',
@@ -194,7 +194,7 @@ class DecodeAudioOpTest(test.TestCase):
channel_count=0)
def testStaticShapeInference_NegativeChannelCountInvalid(self):
- with self.test_session():
+ with self.cached_session():
with six.assertRaisesRegex(self, Exception,
r'channel_count must be positive'):
ffmpeg.decode_audio(b'~~~ wave ~~~',
diff --git a/tensorflow/contrib/ffmpeg/decode_video_op_test.py b/tensorflow/contrib/ffmpeg/decode_video_op_test.py
index b43b6b8919..b734690756 100644
--- a/tensorflow/contrib/ffmpeg/decode_video_op_test.py
+++ b/tensorflow/contrib/ffmpeg/decode_video_op_test.py
@@ -42,7 +42,7 @@ class DecodeVideoOpTest(test.TestCase):
bmp_filename: The filename for the bmp file.
index: Index location inside the video.
"""
- with self.test_session():
+ with self.cached_session():
path = os.path.join(resource_loader.get_data_files_path(), 'testdata',
filename)
with open(path, 'rb') as f:
diff --git a/tensorflow/contrib/ffmpeg/encode_audio_op_test.py b/tensorflow/contrib/ffmpeg/encode_audio_op_test.py
index 870290dc10..eb4325da82 100644
--- a/tensorflow/contrib/ffmpeg/encode_audio_op_test.py
+++ b/tensorflow/contrib/ffmpeg/encode_audio_op_test.py
@@ -61,7 +61,7 @@ class EncodeAudioOpTest(test.TestCase):
def testRoundTrip(self):
"""Reads a wav file, writes it, and compares them."""
- with self.test_session():
+ with self.cached_session():
audio_op = ffmpeg.decode_audio(
self._contents,
file_format='wav',
@@ -73,7 +73,7 @@ class EncodeAudioOpTest(test.TestCase):
self._compareWavFiles(self._contents, encoded_contents)
def testRoundTripWithPlaceholderSampleRate(self):
- with self.test_session():
+ with self.cached_session():
placeholder = array_ops.placeholder(dtypes.int32)
audio_op = ffmpeg.decode_audio(
self._contents,
@@ -86,7 +86,7 @@ class EncodeAudioOpTest(test.TestCase):
self._compareWavFiles(self._contents, encoded_contents)
def testFloatingPointSampleRateInvalid(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
ffmpeg.encode_audio(
[[0.0], [1.0]],
@@ -94,7 +94,7 @@ class EncodeAudioOpTest(test.TestCase):
samples_per_second=12345.678)
def testZeroSampleRateInvalid(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
encode_op = ffmpeg.encode_audio(
[[0.0], [1.0]],
file_format='wav',
@@ -103,7 +103,7 @@ class EncodeAudioOpTest(test.TestCase):
sess.run(encode_op)
def testNegativeSampleRateInvalid(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
encode_op = ffmpeg.encode_audio(
[[0.0], [1.0]],
file_format='wav',
diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
index b1b5126d9e..45a67acb5b 100644
--- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
+++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
@@ -24,11 +24,13 @@ from tensorflow.contrib.ffmpeg.ops import gen_encode_audio_op_py
from tensorflow.contrib.util import loader
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader
+from tensorflow.python.util.deprecation import deprecated
_ffmpeg_so = loader.load_op_library(
resource_loader.get_path_to_datafile('ffmpeg.so'))
+@deprecated('2018-09-04', 'This will be deleted and should not be used.')
def decode_audio(contents, file_format=None, samples_per_second=None,
channel_count=None, stream=None):
"""Create an op that decodes the contents of an audio file.
@@ -69,6 +71,7 @@ def decode_audio(contents, file_format=None, samples_per_second=None,
ops.NotDifferentiable('DecodeAudio')
+@deprecated('2018-09-04', 'This will be deleted and should not be used.')
def encode_audio(audio, file_format=None, samples_per_second=None):
"""Creates an op that encodes an audio file using sampled audio from a tensor.
@@ -95,6 +98,7 @@ def encode_audio(audio, file_format=None, samples_per_second=None):
ops.NotDifferentiable('EncodeAudio')
+@deprecated('2018-09-04', 'This will be deleted and should not be used.')
def decode_video(contents):
"""Create an op that decodes the contents of a video file.
diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py
index 20d099fe5d..95f5ba90ab 100644
--- a/tensorflow/contrib/framework/__init__.py
+++ b/tensorflow/contrib/framework/__init__.py
@@ -15,7 +15,9 @@
"""Framework utilities.
-See the @{$python/contrib.framework} guide.
+See the
+[Contrib Framework](https://tensorflow.org/api_guides/python/contrib.framework)
+guide.
@@assert_same_float_dtype
@@assert_scalar
diff --git a/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py b/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py
index 9396f027d3..77a424145a 100644
--- a/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py
+++ b/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py
@@ -82,7 +82,7 @@ class CheckpointsTest(test.TestCase):
def testNoTensor(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_, _, _, _ = _create_checkpoints(session, checkpoint_dir)
with self.assertRaises(errors_impl.OpError):
self.assertAllEqual(
@@ -90,7 +90,7 @@ class CheckpointsTest(test.TestCase):
def testGetTensor(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
self.assertAllEqual(
checkpoint_utils.load_variable(checkpoint_dir, "var1"), v1)
@@ -103,7 +103,7 @@ class CheckpointsTest(test.TestCase):
def testGetAllVariables(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_create_checkpoints(session, checkpoint_dir)
self.assertEqual(
checkpoint_utils.list_variables(checkpoint_dir),
@@ -112,12 +112,12 @@ class CheckpointsTest(test.TestCase):
def testInitFromCheckpoint(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
with variable_scope.variable_scope("some_scope"):
my1 = variable_scope.get_variable("my1", [1, 10])
with variable_scope.variable_scope("some_other_scope"):
@@ -146,7 +146,7 @@ class CheckpointsTest(test.TestCase):
def testInitWithScopeDoesNotCaptureSuffixes(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_, _, _, v4 = _create_checkpoints(session, checkpoint_dir)
with ops.Graph().as_default() as g:
@@ -158,19 +158,19 @@ class CheckpointsTest(test.TestCase):
checkpoint_utils.init_from_checkpoint(checkpoint_dir,
{"useful_scope/": "useful_scope/"})
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
session.run(variables.global_variables_initializer())
self.assertAllEqual(my4.eval(session), v4)
self.assertAllEqual(my5.eval(session), my5_init)
def testInitFromRootCheckpoint(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
with variable_scope.variable_scope("some_scope"):
my1 = variable_scope.get_variable("var1", [1, 10])
my2 = variable_scope.get_variable("var2", [10, 10])
@@ -189,12 +189,12 @@ class CheckpointsTest(test.TestCase):
def testInitToRootCheckpoint(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
my1 = variable_scope.get_variable("var1", [1, 10])
my2 = variable_scope.get_variable("var2", [10, 10])
my3 = variable_scope.get_variable("var3", [100, 100])
@@ -212,12 +212,12 @@ class CheckpointsTest(test.TestCase):
def testInitFromPartitionVar(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1 = _create_partition_checkpoints(session, checkpoint_dir)
# New graph and session.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
with variable_scope.variable_scope("some_scope"):
my1 = variable_scope.get_variable(
name="my1",
@@ -247,7 +247,7 @@ class CheckpointsTest(test.TestCase):
# New graph and session.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
with variable_scope.variable_scope("some_scope"):
my1 = variable_scope.get_variable(
name="my1",
@@ -266,12 +266,12 @@ class CheckpointsTest(test.TestCase):
def testInitFromCheckpointMissing(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_, _, _, _ = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
with variable_scope.variable_scope("some_scope"):
_ = variable_scope.get_variable("my1", [10, 10])
_ = variable_scope.get_variable(
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util.py b/tensorflow/contrib/framework/python/framework/tensor_util.py
index 4e6eea8884..bdf8aeb2b8 100644
--- a/tensorflow/contrib/framework/python/framework/tensor_util.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util.py
@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
@@ -129,10 +130,25 @@ def remove_squeezable_dimensions(predictions, labels, name=None):
return predictions, labels
-def _all_equal(tensor0, tensor1):
- with ops.name_scope('all_equal', values=[tensor0, tensor1]) as scope:
+def _shape_tensor_compatible(expected_shape, actual_shape):
+ """Returns whether actual_shape is compatible with expected_shape.
+
+ Note that -1 in `expected_shape` is recognized as unknown dimension.
+
+ Args:
+ expected_shape: Integer list defining the expected shape, or tensor of same.
+ actual_shape: Shape of the tensor to test.
+ Returns:
+ A boolean scalar tensor that is True iff `actual_shape` is compatible
+ with `expected_shape`.
+ """
+ with ops.name_scope('shape_tensor_equal',
+ values=[expected_shape, actual_shape]) as scope:
return math_ops.reduce_all(
- math_ops.equal(tensor0, tensor1, name='equal'), name=scope)
+ math_ops.logical_or(
+ math_ops.equal(expected_shape, -1),
+ math_ops.equal(expected_shape, actual_shape, 'equal'),
+ name='exclude_partial_shape'),
+ name=scope)
def _is_rank(expected_rank, actual_tensor):
@@ -153,6 +169,8 @@ def _is_rank(expected_rank, actual_tensor):
def _is_shape(expected_shape, actual_tensor, actual_shape=None):
"""Returns whether actual_tensor's shape is expected_shape.
+ Note that -1 in `expected_shape` is recognized as unknown dimension.
+
Args:
expected_shape: Integer list defining the expected shape, or tensor of same.
actual_tensor: Tensor to test.
@@ -164,15 +182,15 @@ def _is_shape(expected_shape, actual_tensor, actual_shape=None):
is_rank = _is_rank(array_ops.size(expected_shape), actual_tensor)
if actual_shape is None:
actual_shape = array_ops.shape(actual_tensor, name='actual')
- shape_equal = _all_equal(
- ops.convert_to_tensor(expected_shape, name='expected'),
- actual_shape)
+ shape_equal = _shape_tensor_compatible(expected_shape, actual_shape)
return math_ops.logical_and(is_rank, shape_equal, name=scope)
def _assert_shape_op(expected_shape, actual_tensor):
"""Asserts actual_tensor's shape is expected_shape.
+ Note that unknown dimension in `expected_shape` will be ignored.
+
Args:
expected_shape: List of integers defining the expected shape, or tensor of
same.
@@ -182,6 +200,9 @@ def _assert_shape_op(expected_shape, actual_tensor):
"""
with ops.name_scope('assert_shape', values=[actual_tensor]) as scope:
actual_shape = array_ops.shape(actual_tensor, name='actual')
+ if (isinstance(expected_shape, tensor_shape.TensorShape)
+ and not expected_shape.is_fully_defined()):
+ expected_shape = [-1 if d is None else d for d in expected_shape.as_list()]
is_shape = _is_shape(expected_shape, actual_tensor, actual_shape)
return control_flow_ops.Assert(
is_shape, [
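The new `-1`-as-wildcard semantics can be read as the following pure-Python predicate (a hypothetical helper, shown for clarity; the real change builds the equivalent graph with `logical_or`/`reduce_all`):

def shapes_compatible(expected, actual):
    # -1 in the expected shape matches any actual dimension; everything
    # else must match exactly, and the ranks must agree.
    return len(expected) == len(actual) and all(
        e == -1 or e == a for e, a in zip(expected, actual))

assert shapes_compatible([-1, 2], [5, 2])      # partial shape accepted
assert not shapes_compatible([-1, 2], [5, 3])  # known dimension mismatch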
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
index af1b404cb5..b1820c10c8 100644
--- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables as variables_lib
@@ -39,7 +39,7 @@ from tensorflow.python.platform import test
class LocalVariabletest(test.TestCase):
def test_local_variable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEquals([], variables_lib.local_variables())
value0 = 42
variables_lib2.local_variable(value0)
@@ -55,7 +55,7 @@ class LocalVariabletest(test.TestCase):
class ReduceSumNTest(test.TestCase):
def test_reduce_sum_n(self):
- with self.test_session():
+ with self.cached_session():
a = constant_op.constant(1)
b = constant_op.constant([2])
c = constant_op.constant([[3, 4], [5, 6]])
@@ -119,13 +119,13 @@ class WithShapeTest(test.TestCase):
}))
def test_with_shape_invalid_expected_shape(self):
- with self.test_session():
+ with self.cached_session():
self.assertRaisesRegexp(ValueError, "Invalid rank",
tensor_util.with_shape, [[1], [2]],
constant_op.constant(1.0))
def test_with_shape_invalid_type(self):
- with self.test_session():
+ with self.cached_session():
self.assertRaisesRegexp(ValueError, "Invalid dtype",
tensor_util.with_shape, [1.1],
constant_op.constant([1.0]))
@@ -138,7 +138,7 @@ class WithShapeTest(test.TestCase):
constant_op.constant(1.0))
def test_with_shape_0(self):
- with self.test_session():
+ with self.cached_session():
value = 42
shape = [0]
unexpected_shapes = [[1], [2], [1, 1]]
@@ -150,7 +150,7 @@ class WithShapeTest(test.TestCase):
unexpected_shapes)
def test_with_shape_1(self):
- with self.test_session():
+ with self.cached_session():
value = [42]
shape = [1]
unexpected_shapes = [[0], [2], [1, 1]]
@@ -162,7 +162,7 @@ class WithShapeTest(test.TestCase):
unexpected_shapes)
def test_with_shape_2(self):
- with self.test_session():
+ with self.cached_session():
value = [42, 43]
shape = [2]
unexpected_shapes = [[0], [1], [2, 1]]
@@ -174,7 +174,7 @@ class WithShapeTest(test.TestCase):
unexpected_shapes)
def test_with_shape_2x2(self):
- with self.test_session():
+ with self.cached_session():
value = [[42, 43], [44, 45]]
shape = [2, 2]
unexpected_shapes = [[0], [1], [2, 1]]
@@ -185,8 +185,18 @@ class WithShapeTest(test.TestCase):
shape,
unexpected_shapes)
- def test_with_shape_none(self):
+ def test_with_shape_2x2_with_partial_expected_shape(self):
with self.test_session():
+ value = [[42, 43], [44, 45]]
+ actual_shape = [2, 2]
+ tensor = constant_op.constant(value, shape=actual_shape)
+ partial_expected_shape = tensor_shape.TensorShape([None, 2])
+ # Won't raise any exception here:
+ tensor_with_shape = tensor_util.with_shape(partial_expected_shape, tensor)
+ np.testing.assert_array_equal(value, tensor_with_shape.eval())
+
+ def test_with_shape_none(self):
+ with self.cached_session():
tensor_no_shape = array_ops.placeholder(dtypes.float32)
compatible_shape = [2, 2]
@@ -210,7 +220,7 @@ class WithShapeTest(test.TestCase):
@test_util.enable_c_shapes
def test_with_shape_partial(self):
- with self.test_session():
+ with self.cached_session():
tensor_partial_shape = array_ops.placeholder(dtypes.float32)
tensor_partial_shape.set_shape([None, 2])
@@ -366,7 +376,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
squeezed_predictions, squeezed_labels = (
tensor_util.remove_squeezable_dimensions(predictions, labels))
- with self.test_session(g):
+ with self.session(g):
variables_lib.local_variables_initializer().run()
self.assertAllClose(
predictions_value, squeezed_predictions.eval(feed_dict=feed_dict))
diff --git a/tensorflow/contrib/framework/python/ops/arg_scope_test.py b/tensorflow/contrib/framework/python/ops/arg_scope_test.py
index bcafc1a328..0e6c6f0e2f 100644
--- a/tensorflow/contrib/framework/python/ops/arg_scope_test.py
+++ b/tensorflow/contrib/framework/python/ops/arg_scope_test.py
@@ -52,7 +52,7 @@ def _key_op(op):
class ArgScopeTest(test.TestCase):
def testEmptyArgScope(self):
- with self.test_session():
+ with self.cached_session():
with arg_scope([]) as sc:
self.assertEqual(sc, {})
@@ -60,7 +60,7 @@ class ArgScopeTest(test.TestCase):
func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
key_op = _key_op(func1)
func1_scope = {key_op: func1_kwargs.copy()}
- with self.test_session():
+ with self.cached_session():
with arg_scope([func1], a=1, b=None, c=[1]) as sc1:
self.assertEqual(sc1, func1_scope)
with arg_scope({}) as sc2:
@@ -86,7 +86,7 @@ class ArgScopeTest(test.TestCase):
func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
key_op = _key_op(func1)
current_scope = {key_op: func1_kwargs.copy()}
- with self.test_session():
+ with self.cached_session():
with arg_scope([func1], a=1, b=None, c=[1]) as scope:
self.assertDictEqual(scope, current_scope)
@@ -102,7 +102,7 @@ class ArgScopeTest(test.TestCase):
key(func1): func1_kwargs.copy(),
key(func2): func2_kwargs.copy()
}
- with self.test_session():
+ with self.cached_session():
with arg_scope([func1], a=1, b=None, c=[1]):
with arg_scope([func2], b=2, d=[2]) as scope:
self.assertDictEqual(scope, current_scope)
@@ -111,7 +111,7 @@ class ArgScopeTest(test.TestCase):
func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
key_op = _key_op(func1)
current_scope = {key_op: func1_kwargs.copy()}
- with self.test_session():
+ with self.cached_session():
with arg_scope([func1], a=1, b=None, c=[1]) as scope1:
pass
with arg_scope(scope1) as scope:
@@ -126,7 +126,7 @@ class ArgScopeTest(test.TestCase):
key(func1): func1_kwargs.copy(),
key(func2): func2_kwargs.copy()
}
- with self.test_session():
+ with self.cached_session():
with arg_scope([func1], a=1, b=None, c=[1]) as scope1:
with arg_scope([func2], b=2, d=[2]) as scope2:
pass
@@ -140,7 +140,7 @@ class ArgScopeTest(test.TestCase):
def testSimpleArgScope(self):
func1_args = (0,)
func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
- with self.test_session():
+ with self.cached_session():
with arg_scope([func1], a=1, b=None, c=[1]):
args, kwargs = func1(0)
self.assertTupleEqual(args, func1_args)
@@ -149,7 +149,7 @@ class ArgScopeTest(test.TestCase):
def testSimpleArgScopeWithTuple(self):
func1_args = (0,)
func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
- with self.test_session():
+ with self.cached_session():
with arg_scope((func1,), a=1, b=None, c=[1]):
args, kwargs = func1(0)
self.assertTupleEqual(args, func1_args)
@@ -240,7 +240,7 @@ class ArgScopeTest(test.TestCase):
def testAddArgScopeRaceCondition(self):
func4_kwargs = ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h')
for i in range(4):
- # redefine the function with different args
+ # redefine the function with different args
@add_arg_scope
def func4(a=1, b=2, c=3, d=4, e=5, f=6, g=7, h=8):
pass
diff --git a/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py b/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py
index b7b9f5c59e..4036c87b6d 100644
--- a/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py
+++ b/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py
@@ -50,7 +50,7 @@ class LoadMulticlassBiasTest(test.TestCase):
bias = variables.Variable(
array_ops.reshape(flat_data, (num, dim)), name='bias')
save = saver.Saver([bias])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
self.bundle_file = os.path.join(test.get_temp_dir(), 'bias_checkpoint')
save.save(sess, self.bundle_file)
@@ -90,7 +90,7 @@ class LoadMulticlassBiasTest(test.TestCase):
initializer=bias_loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(3))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_bias_vector,
remapped_bias_vector.as_tensor().eval())
@@ -109,7 +109,7 @@ class LoadVariableSlotTest(test.TestCase):
accum = variables.Variable(
array_ops.reshape(flat_data, (num, dim)), name='accum')
save = saver.Saver([accum])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
self.bundle_file = os.path.join(test.get_temp_dir(), 'accum_checkpoint')
save.save(sess, self.bundle_file)
@@ -179,7 +179,7 @@ class LoadVariableSlotTest(test.TestCase):
shape=[2, 1],
initializer=variable_slot_initializer_part_1)
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_accum_vector_part_0,
remapped_accum_vector_part_0.eval())
diff --git a/tensorflow/contrib/framework/python/ops/critical_section_ops.py b/tensorflow/contrib/framework/python/ops/critical_section_ops.py
index 72835c3ad8..71ab755aa2 100644
--- a/tensorflow/contrib/framework/python/ops/critical_section_ops.py
+++ b/tensorflow/contrib/framework/python/ops/critical_section_ops.py
@@ -325,6 +325,8 @@ class CriticalSection(object):
def _is_self_handle(self, x):
"""Check if the tensor `x` is the same Mutex as `self._handle`."""
+ if isinstance(x, ops.EagerTensor):
+ return x is self._handle
return (x.op.type == "MutexV2"
# blank shared_name means the op will create a unique one.
and x.op.get_attr("shared_name")
@@ -365,8 +367,7 @@ class CriticalSection(object):
"(CriticalSection: %s) requested exclusive resource access "
"of this resource. Did you mean to call execute with keyword "
"argument exclusive_resource_access=False?" %
- (list(resource_intersection), self._handle.name,
- sg.op.name, sg.handle.name))
+ (list(resource_intersection), self._handle, sg, sg.handle))
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
diff --git a/tensorflow/contrib/framework/python/ops/prettyprint_ops_test.py b/tensorflow/contrib/framework/python/ops/prettyprint_ops_test.py
index 50bcbe625d..c104c51fef 100644
--- a/tensorflow/contrib/framework/python/ops/prettyprint_ops_test.py
+++ b/tensorflow/contrib/framework/python/ops/prettyprint_ops_test.py
@@ -34,7 +34,7 @@ class PrettyPrintOpsTest(test.TestCase):
def testPrintTensorPassthrough(self):
a = constant_op.constant([1])
a = prettyprint_ops.print_op(a)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(a.eval(), constant_op.constant([1]).eval())
def testPrintSparseTensorPassthrough(self):
@@ -43,7 +43,7 @@ class PrettyPrintOpsTest(test.TestCase):
b = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
a = prettyprint_ops.print_op(a)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
sparse_ops.sparse_tensor_to_dense(a).eval(),
sparse_ops.sparse_tensor_to_dense(b).eval())
@@ -54,13 +54,13 @@ class PrettyPrintOpsTest(test.TestCase):
a = a.write(1, 1)
a = a.write(0, 0)
a = prettyprint_ops.print_op(a)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(a.stack().eval(), constant_op.constant([0, 1]).eval())
def testPrintVariable(self):
a = variables.Variable(1.0)
a = prettyprint_ops.print_op(a)
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
a.eval()
diff --git a/tensorflow/contrib/framework/python/ops/script_ops.py b/tensorflow/contrib/framework/python/ops/script_ops.py
index 5d269fefdc..d5cb679e2c 100644
--- a/tensorflow/contrib/framework/python/ops/script_ops.py
+++ b/tensorflow/contrib/framework/python/ops/script_ops.py
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
-"""Script Language Operators. See the @{$python/script_ops} guide.
+"""Script Language Operators.
@@py_func
"""
diff --git a/tensorflow/contrib/framework/python/ops/sort_ops_test.py b/tensorflow/contrib/framework/python/ops/sort_ops_test.py
index a8fb94b245..791b32cd1e 100644
--- a/tensorflow/contrib/framework/python/ops/sort_ops_test.py
+++ b/tensorflow/contrib/framework/python/ops/sort_ops_test.py
@@ -48,7 +48,7 @@ class SortTest(test.TestCase):
sort_axis = np.random.choice(rank)
if negative_axis:
sort_axis = -1 - sort_axis
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
np.sort(arr, axis=sort_axis),
sort_ops.sort(constant_op.constant(arr), axis=sort_axis).eval())
@@ -60,7 +60,7 @@ class SortTest(test.TestCase):
shape = [np.random.randint(1, 4) for _ in range(rank)]
arr = np.random.random(shape)
sort_axis = np.random.choice(rank)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
np.sort(arr, axis=sort_axis),
sort_ops.sort(constant_op.constant(arr), axis=sort_axis).eval())
@@ -73,7 +73,7 @@ class SortTest(test.TestCase):
scalar = array_ops.zeros(zeros_length_1)
sort = sort_ops.sort(scalar)
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(errors.InvalidArgumentError):
sort.eval()
@@ -84,7 +84,7 @@ class SortTest(test.TestCase):
def testDescending(self):
arr = np.random.random((10, 5, 5))
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
np.sort(arr, axis=0)[::-1],
sort_ops.sort(
@@ -111,7 +111,7 @@ class SortTest(test.TestCase):
def testArgsort_1d(self):
arr = np.random.random(42)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
np.sort(arr),
array_ops.gather(arr, sort_ops.argsort(arr)).eval())
@@ -119,7 +119,7 @@ class SortTest(test.TestCase):
def testArgsort(self):
arr = np.random.random((5, 6, 7, 8))
for axis in range(4):
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
np.argsort(arr, axis=axis),
sort_ops.argsort(arr, axis=axis).eval())
diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py
index 3c44630a51..f9b0efd1da 100644
--- a/tensorflow/contrib/framework/python/ops/variables_test.py
+++ b/tensorflow/contrib/framework/python/ops/variables_test.py
@@ -45,7 +45,7 @@ from tensorflow.python.training import saver as saver_lib
class LocalVariableTest(test.TestCase):
def test_local_variable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEquals([], variables_lib.local_variables())
value0 = 42
variables_lib2.local_variable(value0)
@@ -58,7 +58,7 @@ class LocalVariableTest(test.TestCase):
self.assertAllEqual(set([value0, value1]), set(sess.run(variables)))
def testLocalVariableNameAndShape(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.local_variable([1, 1, 1, 1, 1], name='a')
self.assertEquals(a.op.name, 'A/a')
@@ -66,21 +66,21 @@ class LocalVariableTest(test.TestCase):
self.assertListEqual([a], variables_lib2.get_local_variables())
def testLocalVariableNotInAllVariables(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.local_variable(0)
self.assertFalse(a in variables_lib.global_variables())
self.assertTrue(a in variables_lib.local_variables())
def testLocalVariableNotInVariablesToRestore(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.local_variable(0)
self.assertFalse(a in variables_lib2.get_variables_to_restore())
self.assertTrue(a in variables_lib.local_variables())
def testGetVariablesDontReturnsTransients(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
variables_lib2.local_variable(0)
with variable_scope.variable_scope('B'):
@@ -89,7 +89,7 @@ class LocalVariableTest(test.TestCase):
self.assertEquals([], variables_lib2.get_variables('B'))
def testGetLocalVariablesReturnsTransients(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.local_variable(0)
with variable_scope.variable_scope('B'):
@@ -98,7 +98,7 @@ class LocalVariableTest(test.TestCase):
self.assertEquals([b], variables_lib2.get_local_variables('B'))
def testInitializedVariableValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = variables_lib2.local_variable([0, 0, 0, 0, 0], name='a')
sess.run(variables_lib.local_variables_initializer())
self.assertAllEqual(a.eval(), [0] * 5)
@@ -114,7 +114,7 @@ class LocalVariableTest(test.TestCase):
class GlobalVariableTest(test.TestCase):
def test_global_variable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEquals([], variables_lib.global_variables())
value0 = 42
variables_lib2.global_variable(value0)
@@ -129,7 +129,7 @@ class GlobalVariableTest(test.TestCase):
self.assertAllEqual(set([value0, value1]), set(sess.run(variables)))
def testVariableNameAndShape(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.global_variable([1, 1, 1, 1, 1], name='a')
self.assertEquals(a.op.name, 'A/a')
@@ -137,21 +137,21 @@ class GlobalVariableTest(test.TestCase):
self.assertListEqual([a], variables_lib.global_variables())
def testGlobalVariableNotInLocalVariables(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.global_variable(0)
self.assertFalse(a in variables_lib.local_variables())
self.assertTrue(a in variables_lib.global_variables())
def testGlobalVariableInVariablesToRestore(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.global_variable(0)
self.assertFalse(a in variables_lib.local_variables())
self.assertTrue(a in variables_lib2.get_variables_to_restore())
def testGetVariablesReturnsThem(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.global_variable(0)
with variable_scope.variable_scope('B'):
@@ -160,7 +160,7 @@ class GlobalVariableTest(test.TestCase):
self.assertEquals([b], variables_lib2.get_variables('B'))
def testGetLocalVariablesDontReturnsThem(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
variables_lib2.global_variable(0)
with variable_scope.variable_scope('B'):
@@ -169,7 +169,7 @@ class GlobalVariableTest(test.TestCase):
self.assertEquals([], variables_lib2.get_local_variables('B'))
def testInitializedVariableValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = variables_lib2.global_variable([0, 0, 0, 0, 0], name='a')
sess.run(variables_lib.global_variables_initializer())
self.assertAllEqual(a.eval(), [0] * 5)
@@ -249,7 +249,7 @@ class GlobalStepTest(test.TestCase):
class VariablesTest(test.TestCase):
def testCreateVariable(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
self.assertEquals(a.op.name, 'A/a')
@@ -259,7 +259,7 @@ class VariablesTest(test.TestCase):
self.assertFalse(a in variables_lib.local_variables())
def testGetVariables(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
with variable_scope.variable_scope('B'):
@@ -269,7 +269,7 @@ class VariablesTest(test.TestCase):
self.assertEquals([b], variables_lib2.get_variables('B'))
def testGetVariablesWithScope(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A') as var_scope:
a = variables_lib2.variable('a', [5])
b = variables_lib2.variable('b', [5])
@@ -277,7 +277,7 @@ class VariablesTest(test.TestCase):
set([a, b]), set(variables_lib2.get_variables(var_scope)))
def testGetVariablesSuffix(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
with variable_scope.variable_scope('A'):
@@ -286,13 +286,13 @@ class VariablesTest(test.TestCase):
self.assertEquals([b], variables_lib2.get_variables(suffix='b'))
def testGetVariableWithSingleVar(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('parent'):
a = variables_lib2.variable('child', [5])
self.assertEquals(a, variables_lib2.get_unique_variable('parent/child'))
def testGetVariableWithDistractors(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('parent'):
a = variables_lib2.variable('child', [5])
with variable_scope.variable_scope('child'):
@@ -302,13 +302,13 @@ class VariablesTest(test.TestCase):
def testGetVariableThrowsExceptionWithNoMatch(self):
var_name = 'cant_find_me'
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
variables_lib2.get_unique_variable(var_name)
def testGetThrowsExceptionWithChildrenButNoMatch(self):
var_name = 'parent/child'
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope(var_name):
variables_lib2.variable('grandchild1', [7])
variables_lib2.variable('grandchild2', [9])
@@ -316,7 +316,7 @@ class VariablesTest(test.TestCase):
variables_lib2.get_unique_variable(var_name)
def testGetVariablesToRestore(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
with variable_scope.variable_scope('B'):
@@ -324,7 +324,7 @@ class VariablesTest(test.TestCase):
self.assertEquals([a, b], variables_lib2.get_variables_to_restore())
def testIncludeGetVariablesToRestore(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
with variable_scope.variable_scope('B'):
@@ -333,7 +333,7 @@ class VariablesTest(test.TestCase):
self.assertEquals([a], variables_lib2.get_variables_to_restore(['A']))
def testExcludeGetVariablesToRestore(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
with variable_scope.variable_scope('B'):
@@ -343,7 +343,7 @@ class VariablesTest(test.TestCase):
[a], variables_lib2.get_variables_to_restore(exclude=['B']))
def testWrongIncludeGetVariablesToRestore(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
with variable_scope.variable_scope('B'):
@@ -352,7 +352,7 @@ class VariablesTest(test.TestCase):
self.assertEquals([], variables_lib2.get_variables_to_restore(['a']))
def testGetMixedVariablesToRestore(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
b = variables_lib2.variable('b', [5])
@@ -365,7 +365,7 @@ class VariablesTest(test.TestCase):
variables_lib2.get_variables_to_restore(include=['A/a', 'B/c']))
def testExcludeGetMixedVariablesToRestore(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
b = variables_lib2.variable('b', [5])
@@ -378,7 +378,7 @@ class VariablesTest(test.TestCase):
variables_lib2.get_variables_to_restore(exclude=['A/a', 'B/c']))
def testReuseVariable(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [])
with variable_scope.variable_scope('A', reuse=True):
@@ -387,14 +387,14 @@ class VariablesTest(test.TestCase):
self.assertListEqual([a], variables_lib2.get_variables())
def testVariableWithRegularizer(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [], regularizer=nn_ops.l2_loss)
loss = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[0]
self.assertDeviceEqual(loss.device, a.device)
def testVariableWithRegularizerColocate(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable(
'a', [], device='gpu:0', regularizer=nn_ops.l2_loss)
@@ -402,7 +402,7 @@ class VariablesTest(test.TestCase):
self.assertDeviceEqual(loss.device, a.device)
def testVariableWithDevice(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [], device='cpu:0')
b = variables_lib2.variable('b', [], device='cpu:1')
@@ -410,7 +410,7 @@ class VariablesTest(test.TestCase):
self.assertDeviceEqual(b.device, 'cpu:1')
def testVariableWithDeviceFromScope(self):
- with self.test_session():
+ with self.cached_session():
with ops.device('/cpu:0'):
a = variables_lib2.variable('a', [])
b = variables_lib2.variable('b', [], device='cpu:1')
@@ -428,7 +428,7 @@ class VariablesTest(test.TestCase):
self.counter += 1
return 'cpu:%d' % self.counter
- with self.test_session():
+ with self.cached_session():
with arg_scope([variables_lib2.variable], device=DevFn()):
a = variables_lib2.variable('a', [])
b = variables_lib2.variable('b', [])
@@ -453,7 +453,7 @@ class VariablesTest(test.TestCase):
self.assertDeviceEqual(e.initial_value.device, 'cpu:99')
def testVariableWithReplicaDeviceSetter(self):
- with self.test_session():
+ with self.cached_session():
with ops.device(device_setter.replica_device_setter(ps_tasks=2)):
a = variables_lib2.variable('a', [])
b = variables_lib2.variable('b', [])
@@ -570,7 +570,7 @@ class VariablesTest(test.TestCase):
class ModelVariablesTest(test.TestCase):
def testNameAndShape(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.model_variable('a', [5])
self.assertEquals(a.op.name, 'A/a')
@@ -578,7 +578,7 @@ class ModelVariablesTest(test.TestCase):
self.assertListEqual([a], variables_lib2.get_model_variables('A'))
def testNotInLocalVariables(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.model_variable('a', [5])
self.assertTrue(a in variables_lib.global_variables())
@@ -586,7 +586,7 @@ class ModelVariablesTest(test.TestCase):
self.assertFalse(a in variables_lib.local_variables())
def testGetVariablesReturns(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.model_variable('a', [5])
with variable_scope.variable_scope('B'):
@@ -595,7 +595,7 @@ class ModelVariablesTest(test.TestCase):
self.assertEquals([b], variables_lib2.get_variables('B'))
def testGetModelVariables(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.model_variable('a', [5])
with variable_scope.variable_scope('B'):
@@ -604,7 +604,7 @@ class ModelVariablesTest(test.TestCase):
self.assertEquals([b], variables_lib2.get_model_variables('B'))
def testGetTrainableVariables(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
variables_lib2.local_variable([5])
a = variables_lib.Variable([5])
@@ -615,7 +615,7 @@ class ModelVariablesTest(test.TestCase):
self.assertEquals([b], variables_lib2.get_trainable_variables('B'))
def testGetLocalVariables(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
_ = variables_lib2.model_variable('a', [5])
with variable_scope.variable_scope('B'):
@@ -624,7 +624,7 @@ class ModelVariablesTest(test.TestCase):
self.assertEquals([], variables_lib2.get_local_variables('B'))
def testInitializedVariableValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = variables_lib2.model_variable(
'a', [5], initializer=init_ops.ones_initializer())
sess.run(variables_lib.global_variables_initializer())
@@ -670,14 +670,14 @@ class ModelVariablesTest(test.TestCase):
class GetVariablesCollections(test.TestCase):
def testVariableCollection(self):
- with self.test_session():
+ with self.cached_session():
a = variables_lib2.variable('a', [], collections='A')
b = variables_lib2.variable('b', [], collections='B')
self.assertEquals(a, ops.get_collection('A')[0])
self.assertEquals(b, ops.get_collection('B')[0])
def testVariableCollections(self):
- with self.test_session():
+ with self.cached_session():
a = variables_lib2.variable('a', [], collections=['A', 'C'])
b = variables_lib2.variable('b', [], collections=['B', 'C'])
self.assertEquals(a, ops.get_collection('A')[0])
@@ -685,14 +685,14 @@ class GetVariablesCollections(test.TestCase):
self.assertListEqual([a, b], ops.get_collection('C'))
def testVariableCollectionsWithArgScope(self):
- with self.test_session():
+ with self.cached_session():
with arg_scope([variables_lib2.variable], collections='A'):
a = variables_lib2.variable('a', [])
b = variables_lib2.variable('b', [])
self.assertListEqual([a, b], ops.get_collection('A'))
def testVariableCollectionsWithArgScopeNested(self):
- with self.test_session():
+ with self.cached_session():
with arg_scope([variables_lib2.variable], collections='A'):
a = variables_lib2.variable('a', [])
with arg_scope([variables_lib2.variable], collections='B'):
@@ -701,7 +701,7 @@ class GetVariablesCollections(test.TestCase):
self.assertEquals(b, ops.get_collection('B')[0])
def testVariableCollectionsWithArgScopeNonNested(self):
- with self.test_session():
+ with self.cached_session():
with arg_scope([variables_lib2.variable], collections='A'):
a = variables_lib2.variable('a', [])
with arg_scope([variables_lib2.variable], collections='B'):
@@ -711,7 +711,7 @@ class GetVariablesCollections(test.TestCase):
self.assertListEqual([b], ops.get_collection('B'))
def testVariableRestoreWithArgScopeNested(self):
- with self.test_session():
+ with self.cached_session():
a = variables_lib2.variable('a', [])
with arg_scope(
[variables_lib2.variable], trainable=False, collections=['A', 'B']):
@@ -726,7 +726,7 @@ class GetVariablesCollections(test.TestCase):
class GetVariablesBySuffixTest(test.TestCase):
def testGetVariableGivenNameScoped(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
b = variables_lib2.variable('b', [5])
@@ -734,7 +734,7 @@ class GetVariablesBySuffixTest(test.TestCase):
self.assertEquals([b], variables_lib2.get_variables_by_suffix('b'))
def testGetVariableWithScope(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
fooa = variables_lib2.variable('fooa', [5])
@@ -748,7 +748,7 @@ class GetVariablesBySuffixTest(test.TestCase):
self.assertEquals([a, fooa], matched_variables)
def testGetVariableWithoutScope(self):
- with self.test_session():
+ with self.cached_session():
a = variables_lib2.variable('a', [5])
fooa = variables_lib2.variable('fooa', [5])
b_a = variables_lib2.variable('B/a', [5])
@@ -761,7 +761,7 @@ class GetVariablesBySuffixTest(test.TestCase):
class GetVariablesByNameTest(test.TestCase):
def testGetVariableGivenNameScoped(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
b = variables_lib2.variable('b', [5])
@@ -769,7 +769,7 @@ class GetVariablesByNameTest(test.TestCase):
self.assertEquals([b], variables_lib2.get_variables_by_name('b'))
def testGetVariableWithScope(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('A'):
a = variables_lib2.variable('a', [5])
fooa = variables_lib2.variable('fooa', [5])
@@ -785,7 +785,7 @@ class GetVariablesByNameTest(test.TestCase):
self.assertEquals([a], matched_variables)
def testGetVariableWithoutScope(self):
- with self.test_session():
+ with self.cached_session():
a = variables_lib2.variable('a', [5])
fooa = variables_lib2.variable('fooa', [5])
b_a = variables_lib2.variable('B/a', [5])
@@ -818,7 +818,7 @@ class AssignFromValuesTest(test.TestCase):
init_value0 = np.asarray([1.0, 3.0, 9.0]).reshape((1, 3, 1))
init_value1 = np.asarray([2.0, 4.0, 6.0, 8.0]).reshape((2, 1, 2))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initializer = init_ops.truncated_normal_initializer(stddev=.1)
var0 = variables_lib2.variable(
'my_var0', shape=[1, 3, 1], initializer=initializer)
@@ -844,7 +844,7 @@ class AssignFromValuesTest(test.TestCase):
init_value0 = np.asarray([1.0, 3.0, 9.0]).reshape((1, 3, 1))
init_value1 = np.asarray([2.0, 4.0, 6.0, 8.0]).reshape((2, 1, 2))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initializer = init_ops.truncated_normal_initializer(stddev=.1)
with variable_scope.variable_scope('my_model/my_layer0'):
@@ -879,7 +879,7 @@ class AssignFromValuesFnTest(test.TestCase):
init_value0 = np.asarray([1.0, 3.0, 9.0]).reshape((1, 3, 1))
init_value1 = np.asarray([2.0, 4.0, 6.0, 8.0]).reshape((2, 1, 2))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initializer = init_ops.truncated_normal_initializer(stddev=.1)
var0 = variables_lib2.variable(
'my_var0', shape=[1, 3, 1], initializer=initializer)
@@ -904,7 +904,7 @@ class AssignFromValuesFnTest(test.TestCase):
init_value0 = np.asarray([1.0, 3.0, 9.0]).reshape((1, 3, 1))
init_value1 = np.asarray([2.0, 4.0, 6.0, 8.0]).reshape((2, 1, 2))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initializer = init_ops.truncated_normal_initializer(stddev=.1)
with variable_scope.variable_scope('my_model/my_layer0'):
@@ -968,7 +968,7 @@ class AssignFromCheckpointTest(test.TestCase):
init_value1 = 20.0
var_names_to_values = {'v0': init_value0, 'v1': init_value1}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_path = self.create_checkpoint_from_values(var_names_to_values,
model_dir)
var0 = variables_lib2.variable('my_var0', shape=[])
@@ -998,7 +998,7 @@ class AssignFromCheckpointTest(test.TestCase):
init_value1 = np.array([20.0]) # Partitioned into 1 part, edge case.
var_names_to_values = {'var0': init_value0, 'var1': init_value1}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_path = self.create_checkpoint_from_values(var_names_to_values,
model_dir)
# var0 and var1 are PartitionedVariables.
@@ -1039,7 +1039,7 @@ class AssignFromCheckpointTest(test.TestCase):
init_value1 = 20.0
var_names_to_values = {'v0': init_value0, 'v1': init_value1}
- with self.test_session():
+ with self.cached_session():
model_path = self.create_checkpoint_from_values(var_names_to_values,
model_dir)
var0 = variables_lib2.variable('my_var0', shape=[])
@@ -1062,7 +1062,7 @@ class AssignFromCheckpointTest(test.TestCase):
var_names_to_values = {'layer0/v0': init_value0, 'layer1/v1': init_value1}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_path = self.create_checkpoint_from_values(var_names_to_values,
model_dir)
with variable_scope.variable_scope('my_model/my_layer0'):
@@ -1123,7 +1123,7 @@ class AssignFromCheckpointFnTest(test.TestCase):
init_value1 = 20.0
var_names_to_values = {'v0': init_value0, 'v1': init_value1}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_path = self.create_checkpoint_from_values(var_names_to_values,
model_dir)
var0 = variables_lib2.variable('my_var0', shape=[])
@@ -1154,7 +1154,7 @@ class AssignFromCheckpointFnTest(test.TestCase):
init_value1 = 20.0
var_names_to_values = {'v0': init_value0, 'v1': init_value1}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_path = self.create_checkpoint_from_values(var_names_to_values,
model_dir)
var0 = variables_lib2.variable('my_var0', shape=[2, 1])
@@ -1183,7 +1183,7 @@ class AssignFromCheckpointFnTest(test.TestCase):
init_value1 = 20.0
var_names_to_values = {'v0': init_value0, 'v1': init_value1}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_path = self.create_checkpoint_from_values(var_names_to_values,
model_dir)
var0 = variables_lib2.variable('my_var0', shape=[2, 1])
@@ -1213,7 +1213,7 @@ class AssignFromCheckpointFnTest(test.TestCase):
init_value1 = 20.0
var_names_to_values = {'v0': init_value0, 'v1': init_value1}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_path = self.create_checkpoint_from_values(var_names_to_values,
model_dir)
var0 = variables_lib2.variable('my_var0', shape=[])
@@ -1241,7 +1241,7 @@ class AssignFromCheckpointFnTest(test.TestCase):
init_value1 = 20.0
var_names_to_values = {'v0': init_value0, 'v1': init_value1}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_path = self.create_checkpoint_from_values(var_names_to_values,
model_dir)
var0 = variables_lib2.variable('v0', shape=[])
@@ -1272,7 +1272,7 @@ class AssignFromCheckpointFnTest(test.TestCase):
init_value1 = 20.0
var_names_to_values = {'v0': init_value0, 'v1': init_value1}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_path = self.create_checkpoint_from_values(var_names_to_values,
model_dir)
var0 = variables_lib2.variable('my_var0', shape=[])
@@ -1299,7 +1299,7 @@ class ZeroInitializerOpTest(test.TestCase):
def _testZeroInitializer(self, shape, initializer, use_init):
var = variables_lib.Variable(initializer)
var_zero = variables_lib2.zero_initializer(var)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesOpError('Attempting to use uninitialized value'):
var.eval()
if use_init:
@@ -1324,7 +1324,7 @@ class ZeroVarInitializerOpTest(test.TestCase):
var = resource_variable_ops.ResourceVariable(initializer)
var_zero = variables_lib2.zero_initializer(var)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesOpError('Error while reading resource variable'):
var.eval()
if use_init:
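For context, a minimal sketch of the `zero_initializer` behavior these tests exercise, assuming the contrib API referenced above:

```python
from tensorflow.contrib.framework.python.ops import variables as variables_lib2
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test


class ZeroInitSketchTest(test.TestCase):

  def testZeroInit(self):
    var = variables_lib.Variable([1.0, 2.0])
    # zero_initializer(var) fills the (still uninitialized) variable with
    # zeros instead of running its own initializer.
    var_zero = variables_lib2.zero_initializer(var)
    with self.cached_session() as sess:
      sess.run(var_zero)
      self.assertAllEqual([0.0, 0.0], var.eval())
```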
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
index 0ccb4583ab..716bb87e38 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -174,7 +174,7 @@ class FusedConv2DBiasActivationOp : public OpKernel {
// Input bias is a 1-D tensor, with size matching output depth.
const Tensor& bias = context->input(kBias);
- OP_REQUIRES_OK(context, CheckShape(bias, "conv_input"));
+ OP_REQUIRES_OK(context, CheckShape(bias, "bias"));
const Tensor& conv_input_scale_tensor = context->input(kConvInputScale);
const Tensor& side_input_scale_tensor = context->input(kSideInputScale);
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h
index 7534f5797c..869e899ac8 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRDPARTY_TENSORFLOW_CONTRIB_KERNELS_FUSED_CONV2D_BIAS_ACTIVATION_OP_H_
-#define THIRDPARTY_TENSORFLOW_CONTRIB_KERNELS_FUSED_CONV2D_BIAS_ACTIVATION_OP_H_
+#ifndef TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV2D_BIAS_ACTIVATION_OP_H_
+#define TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV2D_BIAS_ACTIVATION_OP_H_
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -62,4 +62,4 @@ class LaunchFusedConv2DBiasActivationOp<Eigen::GpuDevice, T, BiasType,
} // namespace tensorflow
-#endif
+#endif // TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV2D_BIAS_ACTIVATION_OP_H_
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index 9866fccfba..9d0e6e1335 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -105,6 +105,7 @@ py_library(
deps = [
":gan_estimator",
":head",
+ ":stargan_estimator",
"//tensorflow/python:util",
],
)
@@ -534,6 +535,57 @@ py_test(
)
py_library(
+ name = "stargan_estimator",
+ srcs = [
+ "python/estimator/python/stargan_estimator.py",
+ "python/estimator/python/stargan_estimator_impl.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":namedtuples",
+ ":summaries",
+ ":train",
+ "//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:metrics",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/estimator:estimator_py",
+ ],
+)
+
+py_test(
+ name = "stargan_estimator_test",
+ srcs = ["python/estimator/python/stargan_estimator_test.py"],
+ shard_count = 1,
+ srcs_version = "PY2AND3",
+ tags = ["notsan"],
+ deps = [
+ ":namedtuples",
+ ":stargan_estimator",
+ ":tuple_losses",
+ "//tensorflow/contrib/layers:layers_py",
+ "//tensorflow/contrib/learn",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:metrics",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:training",
+ "//tensorflow/python:training_util",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/estimator:estimator_py",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
name = "sliced_wasserstein",
srcs = [
"python/eval/python/sliced_wasserstein.py",
diff --git a/tensorflow/contrib/gan/python/estimator/__init__.py b/tensorflow/contrib/gan/python/estimator/__init__.py
index c9f7bc61b2..99d38011ba 100644
--- a/tensorflow/contrib/gan/python/estimator/__init__.py
+++ b/tensorflow/contrib/gan/python/estimator/__init__.py
@@ -26,15 +26,18 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.gan.python.estimator.python import gan_estimator
from tensorflow.contrib.gan.python.estimator.python import head
+from tensorflow.contrib.gan.python.estimator.python import stargan_estimator
from tensorflow.contrib.gan.python.estimator.python.gan_estimator import *
from tensorflow.contrib.gan.python.estimator.python.head import *
+from tensorflow.contrib.gan.python.estimator.python.stargan_estimator import *
# pylint: enable=unused-import,wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'gan_estimator',
+ 'stargan_estimator',
'head',
-] + gan_estimator.__all__ + head.__all__
+] + gan_estimator.__all__ + stargan_estimator.__all__ + head.__all__
remove_undocumented(__name__, _allowed_symbols)
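With these re-exports in place, the new symbols are reachable from the estimator package in both forms. A usage sketch, assuming the TF 1.x contrib layout shown above:

```python
from tensorflow.contrib.gan.python import estimator as tfgan_estimator

# Both the submodule and the wildcard re-exported class are now public:
print(tfgan_estimator.stargan_estimator)  # the module
print(tfgan_estimator.StarGANEstimator)   # the class, re-exported via __all__
```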
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
index ab9886580d..7243f150ce 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
@@ -184,7 +184,7 @@ class GANEstimator(estimator.Estimator):
return _get_estimator_spec(
mode, gan_model, generator_loss_fn, discriminator_loss_fn,
get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
- get_hooks_fn)
+ get_hooks_fn, use_loss_summaries)
super(GANEstimator, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config)
@@ -211,15 +211,17 @@ def _get_gan_model(
def _get_estimator_spec(
mode, gan_model, generator_loss_fn, discriminator_loss_fn,
get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
- get_hooks_fn=None):
+ get_hooks_fn=None, use_loss_summaries=True):
"""Get the EstimatorSpec for the current mode."""
if mode == model_fn_lib.ModeKeys.PREDICT:
estimator_spec = model_fn_lib.EstimatorSpec(
mode=mode, predictions=gan_model.generated_data)
else:
gan_loss = tfgan_tuples.GANLoss(
- generator_loss=generator_loss_fn(gan_model),
- discriminator_loss=discriminator_loss_fn(gan_model))
+ generator_loss=generator_loss_fn(
+ gan_model, add_summaries=use_loss_summaries),
+ discriminator_loss=discriminator_loss_fn(
+ gan_model, add_summaries=use_loss_summaries))
if mode == model_fn_lib.ModeKeys.EVAL:
estimator_spec = _get_eval_estimator_spec(
gan_model, gan_loss, get_eval_metric_ops_fn)
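After this hunk, `use_loss_summaries` is forwarded to both loss callables as `add_summaries`, so user-supplied loss functions need the extra keyword (the test change below updates `dummy_loss_fn` the same way). A minimal sketch; only the signature is dictated by the change, the loss body here is illustrative:

```python
from tensorflow.python.ops import math_ops
from tensorflow.python.summary import summary


def my_generator_loss_fn(gan_model, add_summaries=True):
  # Only the (gan_model, add_summaries) signature is dictated by the change;
  # the loss itself is a placeholder.
  loss = -math_ops.reduce_mean(gan_model.discriminator_gen_outputs)
  if add_summaries:
    summary.scalar('generator_loss_sketch', loss)
  return loss
```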
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
index 9ac9c6ca9c..83f8dd641f 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
@@ -116,7 +116,7 @@ def get_dummy_gan_model():
discriminator_fn=None)
-def dummy_loss_fn(gan_model):
+def dummy_loss_fn(gan_model, add_summaries=True):
return math_ops.reduce_sum(gan_model.discriminator_real_outputs -
gan_model.discriminator_gen_outputs)
diff --git a/tensorflow/contrib/kfac/python/ops/optimizer_lib.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator.py
index 87d1866e06..341bdf9fbb 100644
--- a/tensorflow/contrib/kfac/python/ops/optimizer_lib.py
+++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator.py
@@ -12,19 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""The KFAC optimizer."""
+"""`tf.Learn` components for `GANEstimator`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.optimizer import *
+from tensorflow.contrib.gan.python.estimator.python import stargan_estimator_impl
+# pylint: disable=wildcard-import
+from tensorflow.contrib.gan.python.estimator.python.stargan_estimator_impl import *
+# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-_allowed_symbols = [
- "KfacOptimizer",
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
+__all__ = stargan_estimator_impl.__all__
+remove_undocumented(__name__, __all__)
diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py
new file mode 100644
index 0000000000..f60e16bc04
--- /dev/null
+++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py
@@ -0,0 +1,363 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A TFGAN-backed StarGAN Estimator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import enum
+
+from tensorflow.contrib.framework.python.ops import variables as variable_lib
+from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
+from tensorflow.contrib.gan.python import train as tfgan_train
+from tensorflow.contrib.gan.python.eval.python import summaries as tfgan_summaries
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import metrics as metrics_lib
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.util import tf_inspect as inspect
+
+__all__ = ['StarGANEstimator', 'SummaryType']
+
+
+class SummaryType(enum.IntEnum):
+ NONE = 0
+ VARIABLES = 1
+ IMAGES = 2
+ IMAGE_COMPARISON = 3
+
+
+_summary_type_map = {
+ SummaryType.VARIABLES: tfgan_summaries.add_gan_model_summaries,
+ SummaryType.IMAGES: tfgan_summaries.add_stargan_image_summaries,
+}
+
+
+class StarGANEstimator(estimator.Estimator):
+ """An estimator for Generative Adversarial Networks (GANs).
+
+  This Estimator is backed by TFGAN. The network functions follow the TFGAN API
+  with one exception: if either `generator_fn` or `discriminator_fn` has an
+  argument called `mode`, then the tf.Estimator mode is passed in for that
+  argument. This helps with operations like batch normalization, which have
+  different train and evaluation behavior.
+
+ Example:
+
+ ```python
+ import tensorflow as tf
+ tfgan = tf.contrib.gan
+
+ # See TFGAN's `train.py` for a description of the generator and
+ # discriminator API.
+ def generator_fn(generator_inputs):
+ ...
+ return generated_data
+
+ def discriminator_fn(data, conditioning):
+ ...
+ return logits
+
+  # Create StarGAN estimator.
+ stargan_estimator = tfgan.estimator.StarGANEstimator(
+ model_dir,
+ generator_fn=generator_fn,
+ discriminator_fn=discriminator_fn,
+ loss_fn=loss_fn,
+ generator_optimizer=tf.train.AdamOptimizer(0.1, 0.5),
+ discriminator_optimizer=tf.train.AdamOptimizer(0.1, 0.5))
+
+ # Train estimator.
+ stargan_estimator.train(train_input_fn, steps)
+
+ # Evaluate resulting estimator.
+ stargan_estimator.evaluate(eval_input_fn)
+
+ # Generate samples from generator.
+  predictions = np.array([
+      x for x in stargan_estimator.predict(predict_input_fn)])
+ ```
+ """
+
+ def __init__(self,
+ model_dir=None,
+ generator_fn=None,
+ discriminator_fn=None,
+ loss_fn=None,
+ generator_optimizer=None,
+ discriminator_optimizer=None,
+ get_hooks_fn=None,
+ get_eval_metric_ops_fn=None,
+ add_summaries=None,
+ use_loss_summaries=True,
+ config=None):
+ """Initializes a StarGANEstimator instance.
+
+ Args:
+      model_dir: Directory to save model parameters, graph, etc. This can
+        also be used to load checkpoints from the directory into an estimator
+        to continue training a previously saved model.
+ generator_fn: A python function that takes a Tensor, Tensor list, or
+ Tensor dictionary as inputs and returns the outputs of the GAN
+ generator. See `TFGAN` for more details and examples. Additionally, if
+        it has an argument called `mode`, the Estimator's `mode` will be passed
+        in (e.g. TRAIN, EVAL, PREDICT). This is useful for things like batch
+ normalization.
+ discriminator_fn: A python function that takes the output of
+ `generator_fn` or real data in the GAN setup, and `input_data`. Outputs
+ a Tensor in the range [-inf, inf]. See `TFGAN` for more details and
+ examples.
+      loss_fn: The loss function on the generator. Takes a `StarGANModel`
+        namedtuple and returns a `GANLoss` namedtuple.
+ generator_optimizer: The optimizer for generator updates, or a function
+ that takes no arguments and returns an optimizer. This function will be
+ called when the default graph is the `StarGANEstimator`'s graph, so
+ utilities like `tf.contrib.framework.get_or_create_global_step` will
+ work.
+ discriminator_optimizer: Same as `generator_optimizer`, but for the
+ discriminator updates.
+ get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
+ list of hooks. These hooks are run on the generator and discriminator
+ train ops, and can be used to implement the GAN training scheme.
+ Defaults to `train.get_sequential_train_hooks()`.
+ get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
+ dict of metric results keyed by name. The output of this function is
+ passed into `tf.estimator.EstimatorSpec` during evaluation.
+ add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
+      use_loss_summaries: If `True`, adds loss summaries. If `False`, does not.
+        If `None`, uses defaults.
+ config: `RunConfig` object to configure the runtime settings.
+
+ Raises:
+ ValueError: If loss functions aren't callable.
+ ValueError: If `use_loss_summaries` isn't boolean or `None`.
+      TypeError: If `get_hooks_fn` isn't callable or `None`.
+ """
+ if not callable(loss_fn):
+ raise ValueError('loss_fn must be callable.')
+ if use_loss_summaries not in [True, False, None]:
+ raise ValueError('use_loss_summaries must be True, False or None.')
+ if get_hooks_fn is not None and not callable(get_hooks_fn):
+ raise TypeError('get_hooks_fn must be callable.')
+
+ def _model_fn(features, labels, mode):
+ """StarGANEstimator model function."""
+ if mode not in [
+ model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL,
+ model_fn_lib.ModeKeys.PREDICT
+ ]:
+ raise ValueError('Mode not recognized: %s' % mode)
+
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ input_data = features[0]
+ input_data_domain_label = features[1]
+ else:
+ input_data = features # rename inputs for clarity
+ input_data_domain_label = labels # rename inputs for clarity
+
+ # Make StarGANModel, which encapsulates the GAN model architectures.
+ gan_model = _get_gan_model(mode, generator_fn, discriminator_fn,
+ input_data, input_data_domain_label,
+ add_summaries)
+
+ # Make the EstimatorSpec, which incorporates the StarGANModel, losses,
+ # eval, metrics, and optimizers (if required).
+ return _get_estimator_spec(mode, gan_model, loss_fn,
+ get_eval_metric_ops_fn, generator_optimizer,
+ discriminator_optimizer, get_hooks_fn)
+
+ super(StarGANEstimator, self).__init__(
+ model_fn=_model_fn, model_dir=model_dir, config=config)
+
+
+def _get_gan_model(mode,
+ generator_fn,
+ discriminator_fn,
+ input_data,
+ input_data_domain_label,
+ add_summaries,
+ generator_scope='Generator'):
+ """Makes the StarGANModel tuple."""
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ gan_model = _make_prediction_gan_model(input_data, input_data_domain_label,
+ generator_fn, generator_scope)
+ else: # model_fn_lib.ModeKeys.TRAIN or model_fn_lib.ModeKeys.EVAL
+ gan_model = _make_gan_model(generator_fn, discriminator_fn, input_data,
+ input_data_domain_label, generator_scope,
+ add_summaries, mode)
+
+ return gan_model
+
+
+def _get_estimator_spec(mode,
+ gan_model,
+ loss_fn,
+ get_eval_metric_ops_fn,
+ generator_optimizer,
+ discriminator_optimizer,
+ get_hooks_fn=None):
+ """Get the EstimatorSpec for the current mode."""
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ estimator_spec = model_fn_lib.EstimatorSpec(
+ mode=mode, predictions=gan_model.generated_data)
+ else:
+ gan_loss = loss_fn(gan_model)
+ if mode == model_fn_lib.ModeKeys.EVAL:
+ estimator_spec = _get_eval_estimator_spec(gan_model, gan_loss,
+ get_eval_metric_ops_fn)
+ else: # model_fn_lib.ModeKeys.TRAIN:
+ gopt = (
+ generator_optimizer()
+ if callable(generator_optimizer) else generator_optimizer)
+ dopt = (
+ discriminator_optimizer()
+ if callable(discriminator_optimizer) else discriminator_optimizer)
+ get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks()
+ estimator_spec = _get_train_estimator_spec(gan_model, gan_loss, gopt,
+ dopt, get_hooks_fn)
+
+ return estimator_spec
+
+
+def _make_gan_model(generator_fn, discriminator_fn, input_data,
+ input_data_domain_label, generator_scope, add_summaries,
+ mode):
+ """Construct a `StarGANModel`, and optionally pass in `mode`."""
+ # If network functions have an argument `mode`, pass mode to it.
+ if 'mode' in inspect.getargspec(generator_fn).args:
+ generator_fn = functools.partial(generator_fn, mode=mode)
+ if 'mode' in inspect.getargspec(discriminator_fn).args:
+ discriminator_fn = functools.partial(discriminator_fn, mode=mode)
+ gan_model = tfgan_train.stargan_model(
+ generator_fn,
+ discriminator_fn,
+ input_data,
+ input_data_domain_label,
+ generator_scope=generator_scope)
+ if add_summaries:
+ if not isinstance(add_summaries, (tuple, list)):
+ add_summaries = [add_summaries]
+ with ops.name_scope(None):
+ for summary_type in add_summaries:
+ _summary_type_map[summary_type](gan_model)
+
+ return gan_model
+
+
+def _make_prediction_gan_model(input_data, input_data_domain_label,
+ generator_fn, generator_scope):
+ """Make a `StarGANModel` from just the generator."""
+ # If `generator_fn` has an argument `mode`, pass mode to it.
+ if 'mode' in inspect.getargspec(generator_fn).args:
+ generator_fn = functools.partial(
+ generator_fn, mode=model_fn_lib.ModeKeys.PREDICT)
+ with variable_scope.variable_scope(generator_scope) as gen_scope:
+ # pylint:disable=protected-access
+ input_data = tfgan_train._convert_tensor_or_l_or_d(input_data)
+ input_data_domain_label = tfgan_train._convert_tensor_or_l_or_d(
+ input_data_domain_label)
+ # pylint:enable=protected-access
+ generated_data = generator_fn(input_data, input_data_domain_label)
+ generator_variables = variable_lib.get_trainable_variables(gen_scope)
+
+ return tfgan_tuples.StarGANModel(
+ input_data=input_data,
+ input_data_domain_label=None,
+ generated_data=generated_data,
+ generated_data_domain_target=input_data_domain_label,
+ reconstructed_data=None,
+ discriminator_input_data_source_predication=None,
+ discriminator_generated_data_source_predication=None,
+ discriminator_input_data_domain_predication=None,
+ discriminator_generated_data_domain_predication=None,
+ generator_variables=generator_variables,
+ generator_scope=generator_scope,
+ generator_fn=generator_fn,
+ discriminator_variables=None,
+ discriminator_scope=None,
+ discriminator_fn=None)
+
+
+def _get_eval_estimator_spec(gan_model,
+ gan_loss,
+ get_eval_metric_ops_fn=None,
+ name=None):
+ """Return an EstimatorSpec for the eval case."""
+ scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
+ with ops.name_scope(None, 'metrics',
+ [gan_loss.generator_loss, gan_loss.discriminator_loss]):
+
+ def _summary_key(head_name, val):
+ return '%s/%s' % (val, head_name) if head_name else val
+
+ eval_metric_ops = {
+ _summary_key(name, 'generator_loss'):
+ metrics_lib.mean(gan_loss.generator_loss),
+ _summary_key(name, 'discriminator_loss'):
+ metrics_lib.mean(gan_loss.discriminator_loss)
+ }
+ if get_eval_metric_ops_fn is not None:
+ custom_eval_metric_ops = get_eval_metric_ops_fn(gan_model)
+ if not isinstance(custom_eval_metric_ops, dict):
+ raise TypeError('get_eval_metric_ops_fn must return a dict, '
+ 'received: {}'.format(custom_eval_metric_ops))
+ eval_metric_ops.update(custom_eval_metric_ops)
+ return model_fn_lib.EstimatorSpec(
+ mode=model_fn_lib.ModeKeys.EVAL,
+ predictions=gan_model.generated_data,
+ loss=scalar_loss,
+ eval_metric_ops=eval_metric_ops)
+
+
+def _get_train_estimator_spec(gan_model,
+ gan_loss,
+ generator_optimizer,
+ discriminator_optimizer,
+ get_hooks_fn,
+ train_op_fn=tfgan_train.gan_train_ops):
+ """Return an EstimatorSpec for the train case."""
+ scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
+ train_ops = train_op_fn(gan_model, gan_loss, generator_optimizer,
+ discriminator_optimizer)
+ training_hooks = get_hooks_fn(train_ops)
+ return model_fn_lib.EstimatorSpec(
+ loss=scalar_loss,
+ mode=model_fn_lib.ModeKeys.TRAIN,
+ train_op=train_ops.global_step_inc_op,
+ training_hooks=training_hooks)
+
+
+def stargan_prediction_input_fn_wrapper(fn):
+ """StarGAN Estimator prediction input_fn wrapper.
+
+  Since the estimator disregards the "label" variable passed to the model, we
+  use a wrapper to pack the (feature, label) tuple as the feature passed to
+  the model.
+
+ Args:
+ fn: input_fn for the prediction.
+
+ Returns:
+ A tuple ((feature, label), None) where the second element is the dummy label
+ to be disregarded and the first element is the true input to the estimator.
+ """
+
+ def new_fn():
+ return fn(), None
+
+ return new_fn
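A usage sketch for the wrapper, mirroring the integration test in the next file; `est` (a constructed `StarGANEstimator`) and `image_data` (a float32 numpy array) are assumed, and the batch/domain sizes are hypothetical:

```python
from tensorflow.contrib.gan.python.estimator.python import stargan_estimator_impl
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.ops import array_ops

raw_fn = numpy_io.numpy_input_fn(x={'x': image_data}, shuffle=False)

def features_fn():
  # Strip numpy_input_fn's feature dict and attach a one-hot domain label,
  # since the model consumes (input_data, input_data_domain_label).
  return raw_fn()['x'], array_ops.one_hot([0] * 5, 3)

predictions = list(est.predict(
    stargan_estimator_impl.stargan_prediction_input_fn_wrapper(features_fn)))
```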
diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py
new file mode 100644
index 0000000000..2ec7938c7c
--- /dev/null
+++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py
@@ -0,0 +1,306 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for TFGAN's stargan_estimator.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import shutil
+import tempfile
+
+from absl.testing import parameterized
+import numpy as np
+import six
+
+from tensorflow.contrib import layers
+from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
+from tensorflow.contrib.gan.python.estimator.python import stargan_estimator_impl as estimator
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics as metrics_lib
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import test
+from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import learning_rate_decay
+from tensorflow.python.training import training
+from tensorflow.python.training import training_util
+
+
+def dummy_generator_fn(input_data, input_data_domain_label, mode):
+ del input_data_domain_label, mode
+
+ return variable_scope.get_variable('dummy_g', initializer=0.5) * input_data
+
+
+def dummy_discriminator_fn(input_data, num_domains, mode):
+ del mode
+
+ hidden = layers.flatten(input_data)
+ output_src = math_ops.reduce_mean(hidden, axis=1)
+ output_cls = layers.fully_connected(
+ inputs=hidden, num_outputs=num_domains, scope='debug')
+
+ return output_src, output_cls
+
+
+class StarGetGANModelTest(test.TestCase, parameterized.TestCase):
+ """Tests that `StarGetGANModel` produces the correct model."""
+
+ @parameterized.named_parameters(('train', model_fn_lib.ModeKeys.TRAIN),
+ ('eval', model_fn_lib.ModeKeys.EVAL),
+ ('predict', model_fn_lib.ModeKeys.PREDICT))
+ def test_get_gan_model(self, mode):
+ with ops.Graph().as_default():
+ input_data = array_ops.ones([6, 4, 4, 3])
+ input_data_domain_label = array_ops.one_hot([0] * 6, 5)
+ gan_model = estimator._get_gan_model(
+ mode,
+ dummy_generator_fn,
+ dummy_discriminator_fn,
+ input_data,
+ input_data_domain_label,
+ add_summaries=False)
+
+ self.assertEqual(input_data, gan_model.input_data)
+ self.assertIsNotNone(gan_model.generated_data)
+ self.assertIsNotNone(gan_model.generated_data_domain_target)
+ self.assertEqual(1, len(gan_model.generator_variables))
+ self.assertIsNotNone(gan_model.generator_scope)
+ self.assertIsNotNone(gan_model.generator_fn)
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ self.assertIsNone(gan_model.input_data_domain_label)
+ self.assertEqual(input_data_domain_label,
+ gan_model.generated_data_domain_target)
+ self.assertIsNone(gan_model.reconstructed_data)
+ self.assertIsNone(gan_model.discriminator_input_data_source_predication)
+ self.assertIsNone(
+ gan_model.discriminator_generated_data_source_predication)
+ self.assertIsNone(gan_model.discriminator_input_data_domain_predication)
+ self.assertIsNone(
+ gan_model.discriminator_generated_data_domain_predication)
+ self.assertIsNone(gan_model.discriminator_variables)
+ self.assertIsNone(gan_model.discriminator_scope)
+ self.assertIsNone(gan_model.discriminator_fn)
+ else:
+ self.assertEqual(input_data_domain_label,
+ gan_model.input_data_domain_label)
+ self.assertIsNotNone(gan_model.reconstructed_data.shape)
+ self.assertIsNotNone(
+ gan_model.discriminator_input_data_source_predication)
+ self.assertIsNotNone(
+ gan_model.discriminator_generated_data_source_predication)
+ self.assertIsNotNone(
+ gan_model.discriminator_input_data_domain_predication)
+ self.assertIsNotNone(
+ gan_model.discriminator_generated_data_domain_predication)
+ self.assertEqual(2, len(gan_model.discriminator_variables)) # 1 FC layer
+ self.assertIsNotNone(gan_model.discriminator_scope)
+ self.assertIsNotNone(gan_model.discriminator_fn)
+
+
+def get_dummy_gan_model():
+ """Similar to get_gan_model()."""
+ # TODO(joelshor): Find a better way of creating a variable scope.
+ with variable_scope.variable_scope('generator') as gen_scope:
+ gen_var = variable_scope.get_variable('dummy_var', initializer=0.0)
+ with variable_scope.variable_scope('discriminator') as dis_scope:
+ dis_var = variable_scope.get_variable('dummy_var', initializer=0.0)
+ return tfgan_tuples.StarGANModel(
+ input_data=array_ops.ones([1, 2, 2, 3]),
+ input_data_domain_label=array_ops.ones([1, 2]),
+ generated_data=array_ops.ones([1, 2, 2, 3]),
+ generated_data_domain_target=array_ops.ones([1, 2]),
+ reconstructed_data=array_ops.ones([1, 2, 2, 3]),
+ discriminator_input_data_source_predication=array_ops.ones([1]) * dis_var,
+ discriminator_generated_data_source_predication=array_ops.ones(
+ [1]) * gen_var * dis_var,
+ discriminator_input_data_domain_predication=array_ops.ones([1, 2
+ ]) * dis_var,
+ discriminator_generated_data_domain_predication=array_ops.ones([1, 2]) *
+ gen_var * dis_var,
+ generator_variables=[gen_var],
+ generator_scope=gen_scope,
+ generator_fn=None,
+ discriminator_variables=[dis_var],
+ discriminator_scope=dis_scope,
+ discriminator_fn=None)
+
+
+def dummy_loss_fn(gan_model):
+ loss = math_ops.reduce_sum(
+ gan_model.discriminator_input_data_domain_predication -
+ gan_model.discriminator_generated_data_domain_predication)
+ loss += math_ops.reduce_sum(gan_model.input_data - gan_model.generated_data)
+ return tfgan_tuples.GANLoss(loss, loss)
+
+
+def get_metrics(gan_model):
+ return {
+ 'mse_custom_metric':
+ metrics_lib.mean_squared_error(gan_model.input_data,
+ gan_model.generated_data)
+ }
+
+
+class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase):
+ """Tests that the EstimatorSpec is constructed appropriately."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls._generator_optimizer = training.GradientDescentOptimizer(1.0)
+ cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0)
+
+ @parameterized.named_parameters(('train', model_fn_lib.ModeKeys.TRAIN),
+ ('eval', model_fn_lib.ModeKeys.EVAL),
+ ('predict', model_fn_lib.ModeKeys.PREDICT))
+ def test_get_estimator_spec(self, mode):
+ with ops.Graph().as_default():
+ self._gan_model = get_dummy_gan_model()
+ spec = estimator._get_estimator_spec(
+ mode,
+ self._gan_model,
+ loss_fn=dummy_loss_fn,
+ get_eval_metric_ops_fn=get_metrics,
+ generator_optimizer=self._generator_optimizer,
+ discriminator_optimizer=self._discriminator_optimizer)
+
+ self.assertEqual(mode, spec.mode)
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ self.assertEqual(self._gan_model.generated_data, spec.predictions)
+ elif mode == model_fn_lib.ModeKeys.TRAIN:
+ self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar
+ self.assertIsNotNone(spec.train_op)
+ self.assertIsNotNone(spec.training_hooks)
+ elif mode == model_fn_lib.ModeKeys.EVAL:
+ self.assertEqual(self._gan_model.generated_data, spec.predictions)
+ self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar
+ self.assertIsNotNone(spec.eval_metric_ops)
+
+
+# TODO(joelshor): Add pandas test.
+class StarGANEstimatorIntegrationTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def _test_complete_flow(self,
+ train_input_fn,
+ eval_input_fn,
+ predict_input_fn,
+ prediction_size,
+ lr_decay=False):
+
+ def make_opt():
+ gstep = training_util.get_or_create_global_step()
+ lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9)
+ return training.GradientDescentOptimizer(lr)
+
+ gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
+ dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
+ est = estimator.StarGANEstimator(
+ generator_fn=dummy_generator_fn,
+ discriminator_fn=dummy_discriminator_fn,
+ loss_fn=dummy_loss_fn,
+ generator_optimizer=gopt,
+ discriminator_optimizer=dopt,
+ get_eval_metric_ops_fn=get_metrics,
+ model_dir=self._model_dir)
+
+ # TRAIN
+ num_steps = 10
+ est.train(train_input_fn, steps=num_steps)
+
+    # EVALUATE
+ scores = est.evaluate(eval_input_fn)
+ self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
+ self.assertIn('loss', six.iterkeys(scores))
+ self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'],
+ scores['loss'])
+ self.assertIn('mse_custom_metric', six.iterkeys(scores))
+
+ # PREDICT
+ predictions = np.array([x for x in est.predict(predict_input_fn)])
+
+ self.assertAllEqual(prediction_size, predictions.shape)
+
+ @staticmethod
+ def _numpy_input_fn_wrapper(numpy_input_fn, batch_size, label_size):
+ """Wrapper to remove the dictionary in numpy_input_fn.
+
+    NOTE:
+      We create the domain_label here because the model expects a fully
+      defined batch_size from the input.
+
+ Args:
+ numpy_input_fn: input_fn created from numpy_io
+ batch_size: (int) number of items for each batch
+ label_size: (int) number of domains
+
+ Returns:
+ a new input_fn
+ """
+
+ def new_input_fn():
+ features = numpy_input_fn()
+ return features['x'], array_ops.one_hot([0] * batch_size, label_size)
+
+ return new_input_fn
+
+ def test_numpy_input_fn(self):
+ """Tests complete flow with numpy_input_fn."""
+ batch_size = 5
+ img_size = 8
+ channel_size = 3
+ label_size = 3
+ image_data = np.zeros(
+ [batch_size, img_size, img_size, channel_size], dtype=np.float32)
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'x': image_data},
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ eval_input_fn = numpy_io.numpy_input_fn(
+ x={'x': image_data}, batch_size=batch_size, shuffle=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': image_data}, shuffle=False)
+
+ train_input_fn = self._numpy_input_fn_wrapper(train_input_fn, batch_size,
+ label_size)
+ eval_input_fn = self._numpy_input_fn_wrapper(eval_input_fn, batch_size,
+ label_size)
+ predict_input_fn = self._numpy_input_fn_wrapper(predict_input_fn,
+ batch_size, label_size)
+
+ predict_input_fn = estimator.stargan_prediction_input_fn_wrapper(
+ predict_input_fn)
+
+ self._test_complete_flow(
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ predict_input_fn=predict_input_fn,
+ prediction_size=[batch_size, img_size, img_size, channel_size])
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py
index 4fb8d58bc9..d64dfd1576 100644
--- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py
+++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py
@@ -335,7 +335,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase):
mofid_op = classifier_metrics.mean_only_frechet_classifier_distance_from_activations( # pylint: disable=line-too-long
tf_pool_real_a, tf_pool_gen_a)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual_mofid = sess.run(mofid_op)
expected_mofid = _expected_mean_only_fid(pool_real_a, pool_gen_a)
@@ -355,7 +355,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase):
dofid_op = classifier_metrics.diagonal_only_frechet_classifier_distance_from_activations( # pylint: disable=line-too-long
tf_pool_real_a, tf_pool_gen_a)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual_dofid = sess.run(dofid_op)
expected_dofid = _expected_diagonal_only_fid(pool_real_a, pool_gen_a)
@@ -377,7 +377,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase):
test_pool_gen_a,
classifier_fn=lambda x: x)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual_fid = sess.run(fid_op)
expected_fid = _expected_fid(test_pool_real_a, test_pool_gen_a)
@@ -404,7 +404,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase):
classifier_fn=lambda x: x))
fids = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for fid_op in fid_ops:
fids.append(sess.run(fid_op))
@@ -426,7 +426,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase):
trace_sqrt_prod_op = _run_with_mock(classifier_metrics.trace_sqrt_product,
cov_real, cov_gen)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# trace_sqrt_product: tsp
actual_tsp = sess.run(trace_sqrt_prod_op)
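
The `test_session` to `cached_session` substitution repeated throughout the test diffs below has the same shape everywhere: `cached_session()` hands back a session that is cached and reused within a single test method, replacing the deprecated `test_session()`. A minimal sketch, assuming a TF 1.x `tf.test.TestCase` subclass; `ExampleTest` is illustrative:

```python
import tensorflow as tf

class ExampleTest(tf.test.TestCase):

  def test_constant(self):
    loss = tf.constant(3.0)
    # cached_session() may be entered several times in one test and will
    # return the same underlying session each time.
    with self.cached_session() as sess:
      self.assertAllClose(3.0, sess.run(loss))

if __name__ == '__main__':
  tf.test.main()
```
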
diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py
index 871f1ad54e..ab909feae3 100644
--- a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py
+++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py
@@ -65,7 +65,7 @@ class ClassifierMetricsTest(test.TestCase):
pyramid = np_laplacian_pyramid(data, 3)
data_tf = array_ops.placeholder(dtypes.float32, [256, 32, 32, 3])
pyramid_tf = swd._laplacian_pyramid(data_tf, 3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
pyramid_tf = sess.run(
pyramid_tf, feed_dict={
data_tf: data.transpose(0, 2, 3, 1)
@@ -79,7 +79,7 @@ class ClassifierMetricsTest(test.TestCase):
d1 = random_ops.random_uniform([256, 32, 32, 3])
d2 = random_ops.random_normal([256, 32, 32, 3])
wfunc = swd.sliced_wasserstein_distance(d1, d2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
wscores = [sess.run(x) for x in wfunc]
self.assertAllClose(
np.array([0.014, 0.014], 'f'),
@@ -95,7 +95,7 @@ class ClassifierMetricsTest(test.TestCase):
d1 = random_ops.random_uniform([256, 32, 32, 3])
d2 = random_ops.random_normal([256, 32, 32, 3])
wfunc = swd.sliced_wasserstein_distance(d1, d2, use_svd=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
wscores = [sess.run(x) for x in wfunc]
self.assertAllClose(
np.array([0.013, 0.013], 'f'),
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
index 9f5fee4542..e3c780ac1a 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
@@ -51,7 +51,7 @@ class _LossesTest(object):
loss = self._g_loss_fn(self._discriminator_gen_outputs)
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
self.assertEqual(self._generator_loss_name, loss.op.name)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
def test_discriminator_all_correct(self):
@@ -59,7 +59,7 @@ class _LossesTest(object):
self._discriminator_real_outputs, self._discriminator_gen_outputs)
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
self.assertEqual(self._discriminator_loss_name, loss.op.name)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
def test_generator_loss_collection(self):
@@ -90,7 +90,7 @@ class _LossesTest(object):
loss = self._g_loss_fn(
array_ops.reshape(self._discriminator_gen_outputs, [2, 2]))
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
def test_discriminator_patch(self):
@@ -98,7 +98,7 @@ class _LossesTest(object):
array_ops.reshape(self._discriminator_real_outputs, [2, 2]),
array_ops.reshape(self._discriminator_gen_outputs, [2, 2]))
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
def test_generator_loss_with_placeholder_for_logits(self):
@@ -108,7 +108,7 @@ class _LossesTest(object):
loss = self._g_loss_fn(logits, weights=weights)
self.assertEqual(logits.dtype, loss.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
logits: [[10.0, 4.4, -5.5, 3.6]],
@@ -125,7 +125,7 @@ class _LossesTest(object):
logits, logits2, real_weights=real_weights,
generated_weights=generated_weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
logits: [self._discriminator_real_outputs_np],
@@ -136,7 +136,7 @@ class _LossesTest(object):
def test_generator_with_python_scalar_weight(self):
loss = self._g_loss_fn(
self._discriminator_gen_outputs, weights=self._weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss * self._weights,
loss.eval(), 4)
@@ -144,14 +144,14 @@ class _LossesTest(object):
loss = self._d_loss_fn(
self._discriminator_real_outputs, self._discriminator_gen_outputs,
real_weights=self._weights, generated_weights=self._weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss * self._weights,
loss.eval(), 4)
def test_generator_with_scalar_tensor_weight(self):
loss = self._g_loss_fn(self._discriminator_gen_outputs,
weights=constant_op.constant(self._weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss * self._weights,
loss.eval(), 4)
@@ -160,7 +160,7 @@ class _LossesTest(object):
loss = self._d_loss_fn(
self._discriminator_real_outputs, self._discriminator_gen_outputs,
real_weights=weights, generated_weights=weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss * self._weights,
loss.eval(), 4)
@@ -284,7 +284,7 @@ class ACGANLossTest(test.TestCase):
self.assertEqual(
self._discriminator_gen_classification_logits.dtype, loss.dtype)
self.assertEqual(self._generator_loss_name, loss.op.name)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
def test_discriminator_all_correct(self):
@@ -292,7 +292,7 @@ class ACGANLossTest(test.TestCase):
self.assertEqual(
self._discriminator_gen_classification_logits.dtype, loss.dtype)
self.assertEqual(self._discriminator_loss_name, loss.op.name)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
def test_generator_loss_collection(self):
@@ -319,14 +319,14 @@ class ACGANLossTest(test.TestCase):
patch_args = {x: array_ops.reshape(y, [2, 2, 4]) for x, y in
self._generator_kwargs.items()}
loss = self._g_loss_fn(**patch_args)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
def test_discriminator_patch(self):
patch_args = {x: array_ops.reshape(y, [2, 2, 4]) for x, y in
self._discriminator_kwargs.items()}
loss = self._d_loss_fn(**patch_args)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
def test_generator_loss_with_placeholder_for_logits(self):
@@ -334,7 +334,7 @@ class ACGANLossTest(test.TestCase):
one_hot_labels = array_ops.placeholder(dtypes.int32, shape=(None, 4))
loss = self._g_loss_fn(gen_logits, one_hot_labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(
loss, feed_dict={
gen_logits: self._discriminator_gen_classification_logits_np,
@@ -349,7 +349,7 @@ class ACGANLossTest(test.TestCase):
loss = self._d_loss_fn(gen_logits, real_logits, one_hot_labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(
loss, feed_dict={
gen_logits: self._discriminator_gen_classification_logits_np,
@@ -360,7 +360,7 @@ class ACGANLossTest(test.TestCase):
def test_generator_with_python_scalar_weight(self):
loss = self._g_loss_fn(weights=self._weights, **self._generator_kwargs)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss * self._weights,
loss.eval(), 4)
@@ -368,14 +368,14 @@ class ACGANLossTest(test.TestCase):
loss = self._d_loss_fn(
real_weights=self._weights, generated_weights=self._weights,
**self._discriminator_kwargs)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss * self._weights,
loss.eval(), 4)
def test_generator_with_scalar_tensor_weight(self):
loss = self._g_loss_fn(
weights=constant_op.constant(self._weights), **self._generator_kwargs)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss * self._weights,
loss.eval(), 4)
@@ -383,7 +383,7 @@ class ACGANLossTest(test.TestCase):
weights = constant_op.constant(self._weights)
loss = self._d_loss_fn(real_weights=weights, generated_weights=weights,
**self._discriminator_kwargs)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss * self._weights,
loss.eval(), 4)
@@ -404,7 +404,7 @@ class _PenaltyTest(object):
loss = self._penalty_fn(**self._kwargs)
self.assertEqual(self._expected_dtype, loss.dtype)
self.assertEqual(self._expected_op_name, loss.op.name)
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAlmostEqual(self._expected_loss, loss.eval(), 6)
@@ -419,13 +419,13 @@ class _PenaltyTest(object):
def test_python_scalar_weight(self):
loss = self._penalty_fn(weights=2.3, **self._kwargs)
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAlmostEqual(self._expected_loss * 2.3, loss.eval(), 3)
def test_scalar_tensor_weight(self):
loss = self._penalty_fn(weights=constant_op.constant(2.3), **self._kwargs)
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAlmostEqual(self._expected_loss * 2.3, loss.eval(), 3)
@@ -472,7 +472,7 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest):
self._kwargs['discriminator_scope'])
self.assertEqual(generated_data.dtype, loss.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
loss = sess.run(loss,
feed_dict={
@@ -494,7 +494,7 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest):
one_sided=True)
self.assertEqual(generated_data.dtype, loss.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
loss = sess.run(loss,
feed_dict={
@@ -516,7 +516,7 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest):
self._kwargs['discriminator_scope'],
target=2.0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
loss = sess.run(
loss,
diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
index a559bbfa11..25d74a8c23 100644
--- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
+++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
@@ -118,7 +118,7 @@ def add_loss_consistency_test(test_class, loss_name_str, loss_args):
def consistency_test(self):
self.assertEqual(arg_loss.__name__, tuple_loss.__name__)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(arg_loss(**loss_args).eval(),
tuple_loss(_tuple_from_dict(loss_args)).eval())
@@ -241,7 +241,7 @@ class StarGANLossWrapperTest(test.TestCase):
self.discriminator_generated_data_source_predication)
wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
loss_result, wrapped_loss_result = sess.run(
[loss_result_tensor, wrapped_loss_result_tensor])
@@ -257,7 +257,7 @@ class StarGANLossWrapperTest(test.TestCase):
self.discriminator_generated_data_source_predication)
wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
loss_result, wrapped_loss_result = sess.run(
[loss_result_tensor, wrapped_loss_result_tensor])
@@ -282,7 +282,7 @@ class StarGANLossWrapperTest(test.TestCase):
discriminator_scope=self.discriminator_scope)
wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
loss_result, wrapped_loss_result = sess.run(
[loss_result_tensor, wrapped_loss_result_tensor])
diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc
index 7e6a0f14f6..726f74c7b7 100644
--- a/tensorflow/contrib/gdr/gdr_memory_manager.cc
+++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc
@@ -186,22 +186,22 @@ class GdrMemoryManager : public RemoteMemoryManager {
// TODO(byronyi): remove this class and its registration when the default
// cpu_allocator() returns visitable allocator, or cpu_allocator() is no
// longer in use.
-class BFCRdmaAllocator : public BFCAllocator {
+class BFCGdrAllocator : public BFCAllocator {
public:
- BFCRdmaAllocator()
+ BFCGdrAllocator()
: BFCAllocator(new BasicCPUAllocator(port::kNUMANoAffinity), 1LL << 36,
- true, "cpu_rdma_bfc") {}
+ true, "cpu_gdr_bfc") {}
};
-class BFCRdmaAllocatorFactory : public AllocatorFactory {
+class BFCGdrAllocatorFactory : public AllocatorFactory {
public:
- Allocator* CreateAllocator() override { return new BFCRdmaAllocator; }
+ Allocator* CreateAllocator() override { return new BFCGdrAllocator; }
virtual SubAllocator* CreateSubAllocator(int numa_node) {
return new BasicCPUAllocator(numa_node);
}
};
-REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocatorFactory);
+REGISTER_MEM_ALLOCATOR("BFCGdrAllocator", 102, BFCGdrAllocatorFactory);
GdrMemoryManager::GdrMemoryManager(const string& host, const string& port)
: host_(host),
diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.h b/tensorflow/contrib/gdr/gdr_memory_manager.h
index 9ac1aa96c4..c85886863e 100644
--- a/tensorflow/contrib/gdr/gdr_memory_manager.h
+++ b/tensorflow/contrib/gdr/gdr_memory_manager.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef GDR_MEMORY_MANAGER_H_
-#define GDR_MEMORY_MANAGER_H_
+#ifndef TENSORFLOW_CONTRIB_GDR_GDR_MEMORY_MANAGER_H_
+#define TENSORFLOW_CONTRIB_GDR_GDR_MEMORY_MANAGER_H_
#include "google/protobuf/any.pb.h"
#include "tensorflow/core/lib/core/status.h"
@@ -57,4 +57,4 @@ RemoteMemoryManager* CreateRemoteMemoryManager(const string& host,
} // namespace tensorflow
-#endif // GDR_MEMORY_MANAGER_H_
+#endif // TENSORFLOW_CONTRIB_GDR_GDR_MEMORY_MANAGER_H_
diff --git a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.h b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.h
index 7fedd04f54..47a36efdb7 100644
--- a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.h
+++ b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef GDR_RENDEZVOUS_MGR_H_
-#define GDR_RENDEZVOUS_MGR_H_
+#ifndef TENSORFLOW_CONTRIB_GDR_GDR_RENDEZVOUS_MGR_H_
+#define TENSORFLOW_CONTRIB_GDR_GDR_RENDEZVOUS_MGR_H_
#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
@@ -39,4 +39,4 @@ class GdrRendezvousMgr : public BaseRendezvousMgr {
} // end namespace tensorflow
-#endif // GDR_RENDEZVOUS_MGR_H_
+#endif // TENSORFLOW_CONTRIB_GDR_GDR_RENDEZVOUS_MGR_H_
diff --git a/tensorflow/contrib/gdr/gdr_server_lib.h b/tensorflow/contrib/gdr/gdr_server_lib.h
index d6c40d429e..efa2390d33 100644
--- a/tensorflow/contrib/gdr/gdr_server_lib.h
+++ b/tensorflow/contrib/gdr/gdr_server_lib.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef GDR_SERVER_LIB_H_
-#define GDR_SERVER_LIB_H_
+#ifndef TENSORFLOW_CONTRIB_GDR_GDR_SERVER_LIB_H_
+#define TENSORFLOW_CONTRIB_GDR_GDR_SERVER_LIB_H_
#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
@@ -49,4 +49,4 @@ class GdrServer : public GrpcServer {
} // namespace tensorflow
-#endif // GDR_SERVER_LIB_H_
+#endif // TENSORFLOW_CONTRIB_GDR_GDR_SERVER_LIB_H_
diff --git a/tensorflow/contrib/gdr/gdr_worker.h b/tensorflow/contrib/gdr/gdr_worker.h
index 54081f655e..65105ed997 100644
--- a/tensorflow/contrib/gdr/gdr_worker.h
+++ b/tensorflow/contrib/gdr/gdr_worker.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef GDR_WORKER_H_
-#define GDR_WORKER_H_
+#ifndef TENSORFLOW_CONTRIB_GDR_GDR_WORKER_H_
+#define TENSORFLOW_CONTRIB_GDR_GDR_WORKER_H_
#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
@@ -44,4 +44,4 @@ class GdrWorker : public GrpcWorker {
} // namespace tensorflow
-#endif // GDR_WORKER_H_
+#endif // TENSORFLOW_CONTRIB_GDR_GDR_WORKER_H_
diff --git a/tensorflow/contrib/graph_editor/__init__.py b/tensorflow/contrib/graph_editor/__init__.py
index 51b7f45274..b2de2b9a69 100644
--- a/tensorflow/contrib/graph_editor/__init__.py
+++ b/tensorflow/contrib/graph_editor/__init__.py
@@ -14,7 +14,9 @@
# ==============================================================================
"""TensorFlow Graph Editor.
-See the @{$python/contrib.graph_editor} guide.
+See the
+[Graph Editor](https://tensorflow.org/api_guides/python/contrib.graph_editor)
+guide.
"""
from __future__ import absolute_import
diff --git a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc
index 80b2d3e08b..2bf6097d01 100644
--- a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc
+++ b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/platform/file_system.h"
namespace tensorflow {
+namespace data {
namespace {
static const size_t kSyncMarkerSize = 16;
@@ -332,9 +333,10 @@ class SequenceFileDatasetOp : public DatasetOpKernel {
};
DataTypeVector output_types_;
};
-} // namespace
REGISTER_KERNEL_BUILDER(Name("SequenceFileDataset").Device(DEVICE_CPU),
SequenceFileDatasetOp);
+} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc
index 693724b457..370a8caf6a 100644
--- a/tensorflow/contrib/image/kernels/image_ops.cc
+++ b/tensorflow/contrib/image/kernels/image_ops.cc
@@ -71,7 +71,6 @@ class ImageProjectiveTransform : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor& images_t = ctx->input(0);
const Tensor& transform_t = ctx->input(1);
- const Tensor& shape_t = ctx->input(2);
OP_REQUIRES(ctx, images_t.shape().dims() == 4,
errors::InvalidArgument("Input images must have rank 4"));
OP_REQUIRES(ctx,
@@ -82,17 +81,28 @@ class ImageProjectiveTransform : public OpKernel {
ProjectiveGenerator<Device, T>::kNumParameters),
errors::InvalidArgument(
"Input transform should be num_images x 8 or 1 x 8"));
- OP_REQUIRES(ctx, shape_t.dims() == 1,
- errors::InvalidArgument("output shape must be 1-dimensional",
- shape_t.shape().DebugString()));
- OP_REQUIRES(ctx, shape_t.NumElements() == 2,
- errors::InvalidArgument("output shape must have two elements",
- shape_t.shape().DebugString()));
- auto shape_vec = shape_t.vec<int32>();
- int32 out_height = shape_vec(0);
- int32 out_width = shape_vec(1);
- OP_REQUIRES(ctx, out_height > 0 && out_width > 0,
- errors::InvalidArgument("output dimensions must be positive"));
+
+ int32 out_height, out_width;
+    // This kernel is shared with the legacy "ImageProjectiveTransform" op,
+    // which takes only 2 inputs (no output_shape).
+ if (ctx->num_inputs() >= 3) {
+ const Tensor& shape_t = ctx->input(2);
+ OP_REQUIRES(ctx, shape_t.dims() == 1,
+ errors::InvalidArgument("output shape must be 1-dimensional",
+ shape_t.shape().DebugString()));
+ OP_REQUIRES(ctx, shape_t.NumElements() == 2,
+ errors::InvalidArgument("output shape must have two elements",
+ shape_t.shape().DebugString()));
+ auto shape_vec = shape_t.vec<int32>();
+ out_height = shape_vec(0);
+ out_width = shape_vec(1);
+ OP_REQUIRES(
+ ctx, out_height > 0 && out_width > 0,
+ errors::InvalidArgument("output dimensions must be positive"));
+ } else {
+ // Shape is N (batch size), H (height), W (width), C (channels).
+ out_height = images_t.shape().dim_size(1);
+ out_width = images_t.shape().dim_size(2);
+ }
Tensor* output_t;
OP_REQUIRES_OK(ctx, ctx->allocate_output(
@@ -109,10 +119,14 @@ class ImageProjectiveTransform : public OpKernel {
}
};
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<TYPE>("dtype"), \
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TYPE>("dtype"), \
+ ImageProjectiveTransform<CPUDevice, TYPE>); \
+ REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TYPE>("dtype"), \
ImageProjectiveTransform<CPUDevice, TYPE>)
TF_CALL_uint8(REGISTER);
@@ -147,11 +161,15 @@ TF_CALL_double(DECLARE_FUNCTOR);
} // end namespace functor
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \
- .Device(DEVICE_GPU) \
- .TypeConstraint<TYPE>("dtype") \
- .HostMemory("output_shape"), \
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<TYPE>("dtype"), \
+ ImageProjectiveTransform<GPUDevice, TYPE>); \
+ REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<TYPE>("dtype") \
+ .HostMemory("output_shape"), \
ImageProjectiveTransform<GPUDevice, TYPE>)
TF_CALL_uint8(REGISTER);
diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc
index 4969ac58f9..6f7c9bb520 100644
--- a/tensorflow/contrib/image/ops/image_ops.cc
+++ b/tensorflow/contrib/image/ops/image_ops.cc
@@ -67,19 +67,7 @@ Status ResizeShapeFn(InferenceContext* c) {
c->Dim(input, 3));
}
-} // namespace
-
-// TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc.
-// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
-REGISTER_OP("ImageProjectiveTransform")
- .Input("images: dtype")
- .Input("transforms: float32")
- .Input("output_shape: int32")
- .Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
- .Attr("interpolation: string")
- .Output("transformed_images: dtype")
- .SetShapeFn(ResizeShapeFn)
- .Doc(R"doc(
+static const char kImageProjectiveTransformDoc[] = R"doc(
Applies the given transform to each of the images.
Input `image` is a `Tensor` in NHWC format (where the axes are image in batch,
@@ -99,7 +87,35 @@ transforms: 2D `Tensor`, projective transform(s) to apply to the image(s).
transformed_images: 4D `Tensor`, image(s) in NHWC format, generated by applying
the `transforms` to the `images`. Satisfies the description above.
-)doc");
+)doc";
+
+} // namespace
+
+// TODO(ringwalt): Add a "fill_mode" attr with "constant", "mirror", etc.
+// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
+REGISTER_OP("ImageProjectiveTransform")
+ .Input("images: dtype")
+ .Input("transforms: float32")
+ .Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
+ .Attr("interpolation: string")
+ .Output("transformed_images: dtype")
+ // Output shape is identical to input images.
+ .SetShapeFn([](InferenceContext* c) {
+ c->set_output(0, c->input(0));
+ return Status::OK();
+ })
+ .Doc(kImageProjectiveTransformDoc);
+
+// V2 op supports output_shape.
+REGISTER_OP("ImageProjectiveTransformV2")
+ .Input("images: dtype")
+ .Input("transforms: float32")
+ .Input("output_shape: int32")
+ .Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
+ .Attr("interpolation: string")
+ .Output("transformed_images: dtype")
+ .SetShapeFn(ResizeShapeFn)
+ .Doc(kImageProjectiveTransformDoc);
REGISTER_OP("BipartiteMatch")
.Input("distance_mat: float")
diff --git a/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py b/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py
index a58b6a247e..24b790977d 100644
--- a/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py
@@ -50,7 +50,7 @@ class DenseImageWarpTest(test_util.TensorFlowTestCase):
interp = dense_image_warp._interpolate_bilinear(grid, query_points)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predicted = sess.run(interp)
self.assertAllClose(expected_results, predicted)
@@ -64,7 +64,7 @@ class DenseImageWarpTest(test_util.TensorFlowTestCase):
interp = dense_image_warp._interpolate_bilinear(
grid, query_points, indexing='xy')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predicted = sess.run(interp)
self.assertAllClose(expected_results, predicted)
@@ -78,7 +78,7 @@ class DenseImageWarpTest(test_util.TensorFlowTestCase):
interp = dense_image_warp._interpolate_bilinear(grid, query_points)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predicted = sess.run(interp)
self.assertAllClose(expected_results, predicted)
@@ -160,7 +160,7 @@ class DenseImageWarpTest(test_util.TensorFlowTestCase):
flow_type)
interp = dense_image_warp.dense_image_warp(image, flows)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
rand_image, rand_flows = self.get_random_image_and_flows(
shape, image_type, flow_type)
rand_flows *= 0
@@ -191,7 +191,7 @@ class DenseImageWarpTest(test_util.TensorFlowTestCase):
flow_type)
interp = dense_image_warp.dense_image_warp(image, flows)
low_precision = image_type == 'float16' or flow_type == 'float16'
- with self.test_session() as sess:
+ with self.cached_session() as sess:
rand_image, rand_flows = self.get_random_image_and_flows(
shape, image_type, flow_type)
@@ -249,7 +249,7 @@ class DenseImageWarpTest(test_util.TensorFlowTestCase):
opt_func = optimizer.apply_gradients(zip(grad, [flows]))
init_op = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(10):
sess.run(opt_func)
diff --git a/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py
index a495b58b7f..ac8573445c 100644
--- a/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py
@@ -217,7 +217,7 @@ class AdjustSaturationInYiqTest(test_util.TensorFlowTestCase):
'gb_same',
'rgb_same',
]
- with self.test_session():
+ with self.cached_session():
for x_shape in x_shapes:
for test_style in test_styles:
x_np = np.random.rand(*x_shape) * 255.
diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
index f588eae923..376c0751ee 100644
--- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.contrib.image.ops import gen_image_ops
from tensorflow.contrib.image.python.ops import image_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -39,7 +40,7 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
def test_zeros(self):
for dtype in _DTYPES:
- with self.test_session():
+ with self.cached_session():
for shape in [(5, 5), (24, 24), (2, 24, 24, 3)]:
for angle in [0, 1, np.pi / 2.0]:
image = array_ops.zeros(shape, dtype)
@@ -49,7 +50,7 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
def test_rotate_even(self):
for dtype in _DTYPES:
- with self.test_session():
+ with self.cached_session():
image = array_ops.reshape(
math_ops.cast(math_ops.range(36), dtype), (6, 6))
image_rep = array_ops.tile(image[None, :, :, None], [3, 1, 1, 1])
@@ -71,7 +72,7 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
def test_rotate_odd(self):
for dtype in _DTYPES:
- with self.test_session():
+ with self.cached_session():
image = array_ops.reshape(
math_ops.cast(math_ops.range(25), dtype), (5, 5))
image_rep = array_ops.tile(image[None, :, :, None], [3, 1, 1, 1])
@@ -91,7 +92,7 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
def test_translate(self):
for dtype in _DTYPES:
- with self.test_session():
+ with self.cached_session():
image = constant_op.constant(
[[1, 0, 1, 0],
[0, 1, 0, 1],
@@ -107,7 +108,7 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
def test_compose(self):
for dtype in _DTYPES:
- with self.test_session():
+ with self.cached_session():
image = constant_op.constant(
[[1, 1, 1, 0],
[1, 0, 0, 0],
@@ -131,7 +132,7 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
def test_extreme_projective_transform(self):
for dtype in _DTYPES:
- with self.test_session():
+ with self.cached_session():
image = constant_op.constant(
[[1, 0, 1, 0],
[0, 1, 0, 1],
@@ -147,7 +148,7 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
[0, 0, 0, 0]])
def test_bilinear(self):
- with self.test_session():
+ with self.cached_session():
image = constant_op.constant(
[[0, 0, 0, 0, 0],
[0, 1, 1, 1, 0],
@@ -176,7 +177,7 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
[0, 0, 1, 0, 0]])
def test_bilinear_uint8(self):
- with self.test_session():
+ with self.cached_session():
image = constant_op.constant(
np.asarray(
[[0.0, 0.0, 0.0, 0.0, 0.0],
@@ -209,7 +210,7 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual([3, 5], result.get_shape())
def _test_grad(self, shape_to_test):
- with self.test_session():
+ with self.cached_session():
test_image_shape = shape_to_test
test_image = np.random.randn(*test_image_shape)
test_image_tensor = constant_op.constant(
@@ -228,7 +229,7 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
self.assertLess(left_err, 1e-10)
def _test_grad_different_shape(self, input_shape, output_shape):
- with self.test_session():
+ with self.cached_session():
test_image_shape = input_shape
test_image = np.random.randn(*test_image_shape)
test_image_tensor = constant_op.constant(
@@ -262,6 +263,15 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
self._test_grad_different_shape([4, 12, 3], [8, 24, 3])
self._test_grad_different_shape([3, 4, 12, 3], [3, 8, 24, 3])
+ def test_projective_transform_v1(self):
+ """The original ImageProjectiveTransform op should take 2 arguments."""
+ image = constant_op.constant([[[[1], [0]], [[0], [1]]]])
+ transform = constant_op.constant([[1., 0., 0., 0., 1., 0., 0., 0.]])
+ result = gen_image_ops.image_projective_transform(
+ image, transform, interpolation="NEAREST")
+ with self.cached_session():
+ self.assertAllEqual([[[[1], [0]], [[0], [1]]]], result.eval())
+
class BipartiteMatchTest(test_util.TensorFlowTestCase):
@@ -276,7 +286,7 @@ class BipartiteMatchTest(test_util.TensorFlowTestCase):
expected_col_to_row_match_np = np.array(expected_col_to_row_match,
dtype=np.int32)
- with self.test_session():
+ with self.cached_session():
distance_mat_tf = constant_op.constant(distance_mat_np,
shape=distance_mat_shape)
location_to_prior, prior_to_location = image_ops.bipartite_match(
diff --git a/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py b/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py
index 1939caaa2d..d58a654292 100644
--- a/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
@@ -164,7 +165,7 @@ class InterpolateSplineTest(test_util.TensorFlowTestCase):
with ops.name_scope('interpolator'):
interpolator = interpolate_spline.interpolate_spline(
train_points, train_values, query_points, interpolation_order)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
fetches = [query_points, train_points, train_values, interpolator]
query_points_, train_points_, train_values_, interp_ = sess.run(fetches)
@@ -204,7 +205,7 @@ class InterpolateSplineTest(test_util.TensorFlowTestCase):
target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)]
target_interpolation = np.array(target_interpolation)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
interp_val = sess.run(interpolator)
self.assertAllClose(interp_val[0, :, 0], target_interpolation)
@@ -222,10 +223,85 @@ class InterpolateSplineTest(test_util.TensorFlowTestCase):
target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)]
target_interpolation = np.array(target_interpolation)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
interp_val = sess.run(interpolator)
self.assertAllClose(interp_val[0, :, 0], target_interpolation)
+ def test_nd_linear_interpolation_unspecified_shape(self):
+ """Ensure that interpolation supports dynamic batch_size and num_points."""
+
+ tp = _QuadraticPlusSinProblemND()
+ (query_points, _, train_points,
+ train_values) = tp.get_problem(dtype='float64')
+
+ # Construct placeholders such that the batch size, number of train points,
+ # and number of query points are not known at graph construction time.
+ feature_dim = query_points.shape[-1]
+ value_dim = train_values.shape[-1]
+ train_points_ph = array_ops.placeholder(
+ dtype=train_points.dtype, shape=[None, None, feature_dim])
+ train_values_ph = array_ops.placeholder(
+ dtype=train_values.dtype, shape=[None, None, value_dim])
+ query_points_ph = array_ops.placeholder(
+ dtype=query_points.dtype, shape=[None, None, feature_dim])
+
+ order = 1
+ reg_weight = 0.01
+
+ interpolator = interpolate_spline.interpolate_spline(
+ train_points_ph, train_values_ph, query_points_ph, order, reg_weight)
+
+ target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)]
+ target_interpolation = np.array(target_interpolation)
+ with self.cached_session() as sess:
+
+ (train_points_value, train_values_value, query_points_value) = sess.run(
+ [train_points, train_values, query_points])
+
+ interp_val = sess.run(
+ interpolator,
+ feed_dict={
+ train_points_ph: train_points_value,
+ train_values_ph: train_values_value,
+ query_points_ph: query_points_value
+ })
+ self.assertAllClose(interp_val[0, :, 0], target_interpolation)
+
+ def test_fully_unspecified_shape(self):
+ """Ensure that erreor is thrown when input/output dim unspecified."""
+
+ tp = _QuadraticPlusSinProblemND()
+ (query_points, _, train_points,
+ train_values) = tp.get_problem(dtype='float64')
+
+ # Construct placeholders such that the batch size, number of train points,
+ # and number of query points are not known at graph construction time.
+ feature_dim = query_points.shape[-1]
+ value_dim = train_values.shape[-1]
+ train_points_ph = array_ops.placeholder(
+ dtype=train_points.dtype, shape=[None, None, feature_dim])
+ train_points_ph_invalid = array_ops.placeholder(
+ dtype=train_points.dtype, shape=[None, None, None])
+ train_values_ph = array_ops.placeholder(
+ dtype=train_values.dtype, shape=[None, None, value_dim])
+ train_values_ph_invalid = array_ops.placeholder(
+ dtype=train_values.dtype, shape=[None, None, None])
+ query_points_ph = array_ops.placeholder(
+ dtype=query_points.dtype, shape=[None, None, feature_dim])
+
+ order = 1
+ reg_weight = 0.01
+
+ with self.assertRaises(ValueError):
+ _ = interpolate_spline.interpolate_spline(
+ train_points_ph_invalid, train_values_ph, query_points_ph, order,
+ reg_weight)
+
+ with self.assertRaises(ValueError):
+ _ = interpolate_spline.interpolate_spline(
+ train_points_ph, train_values_ph_invalid, query_points_ph, order,
+ reg_weight)
+
def test_interpolation_gradient(self):
"""Make sure that backprop can run. Correctness of gradients is assumed.
@@ -254,7 +330,7 @@ class InterpolateSplineTest(test_util.TensorFlowTestCase):
opt_func = optimizer.apply_gradients(zip(grad, [train_points]))
init_op = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(100):
sess.run([loss, opt_func])
diff --git a/tensorflow/contrib/image/python/kernel_tests/segmentation_test.py b/tensorflow/contrib/image/python/kernel_tests/segmentation_test.py
index 48066cbace..3d39165ede 100644
--- a/tensorflow/contrib/image/python/kernel_tests/segmentation_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/segmentation_test.py
@@ -59,19 +59,19 @@ class SegmentationTest(test_util.TensorFlowTestCase):
[7, 0, 8, 0, 0, 0, 9, 0, 0],
[0, 0, 0, 0, 10, 0, 0, 0, 0],
[0, 0, 11, 0, 0, 0, 0, 0, 0]]) # pyformat: disable
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(image_ops.connected_components(arr).eval(), expected)
def testSimple(self):
arr = [[0, 1, 0], [1, 1, 1], [0, 1, 0]]
- with self.test_session():
+ with self.cached_session():
# Single component with id 1.
self.assertAllEqual(
image_ops.connected_components(math_ops.cast(
arr, dtypes.bool)).eval(), arr)
def testSnake(self):
- with self.test_session():
+ with self.cached_session():
# Single component with id 1.
self.assertAllEqual(
image_ops.connected_components(math_ops.cast(
@@ -80,7 +80,7 @@ class SegmentationTest(test_util.TensorFlowTestCase):
def testSnake_disconnected(self):
for i in range(SNAKE.shape[0]):
for j in range(SNAKE.shape[1]):
- with self.test_session():
+ with self.cached_session():
# If we disconnect any part of the snake except for the endpoints,
# there will be 2 components.
if SNAKE[i, j] and (i, j) not in [(1, 1), (6, 3)]:
@@ -121,27 +121,27 @@ class SegmentationTest(test_util.TensorFlowTestCase):
[0, 6, 6, 0],
[8, 0, 6, 0],
[0, 0, 6, 6]]] # pyformat: disable
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
image_ops.connected_components(math_ops.cast(
images, dtypes.bool)).eval(), expected)
def testZeros(self):
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
image_ops.connected_components(
array_ops.zeros((100, 20, 50), dtypes.bool)).eval(),
np.zeros((100, 20, 50)))
def testOnes(self):
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
image_ops.connected_components(
array_ops.ones((100, 20, 50), dtypes.bool)).eval(),
np.tile(np.arange(100)[:, None, None] + 1, [1, 20, 50]))
def testOnes_small(self):
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
image_ops.connected_components(array_ops.ones((3, 5),
dtypes.bool)).eval(),
@@ -153,7 +153,7 @@ class SegmentationTest(test_util.TensorFlowTestCase):
expected = connected_components_reference_implementation(images)
if expected is None:
return
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
image_ops.connected_components(images).eval(), expected)
diff --git a/tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py
index 3f4029e558..e5980c53b2 100644
--- a/tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py
@@ -47,7 +47,7 @@ class SingleImageRandomDotStereogramsTest(test_util.TensorFlowTestCase):
normalize=True)
shape_1 = sirds_1.get_shape().as_list()
self.assertEqual(shape_1, [768, 1024, 1])
- with self.test_session():
+ with self.cached_session():
r_tf_1 = sirds_1.eval()
self.assertAllEqual(shape_1, r_tf_1.shape)
@@ -59,7 +59,7 @@ class SingleImageRandomDotStereogramsTest(test_util.TensorFlowTestCase):
normalize=True)
shape_2 = sirds_2.get_shape().as_list()
self.assertEqual(shape_2, [768, 1024, 3])
- with self.test_session():
+ with self.cached_session():
r_tf_2 = sirds_2.eval()
self.assertAllEqual(shape_2, r_tf_2.shape)
@@ -73,7 +73,7 @@ class SingleImageRandomDotStereogramsTest(test_util.TensorFlowTestCase):
output_image_shape=[1200, 800, 1])
shape_3 = sirds_3.get_shape().as_list()
self.assertEqual(shape_3, [800, 1200, 1])
- with self.test_session():
+ with self.cached_session():
r_tf_3 = sirds_3.eval()
self.assertAllEqual(shape_3, r_tf_3.shape)
diff --git a/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py b/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py
index 0135c66e29..ce9e34df73 100644
--- a/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py
@@ -107,7 +107,7 @@ class SparseImageWarpTest(test_util.TensorFlowTestCase):
regularization_weight=regularization,
num_boundary_points=num_boundary_points)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
warped_image, input_image, _ = sess.run(
[warped_image_op, input_image_op, flow_field])
@@ -149,7 +149,7 @@ class SparseImageWarpTest(test_util.TensorFlowTestCase):
interpolation_order=order,
num_boundary_points=num_boundary_points)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
warped_image, input_image, flow = sess.run(
[warped_image_op, input_image_op, flow_field])
# Check that it moved the pixel correctly.
@@ -176,7 +176,7 @@ class SparseImageWarpTest(test_util.TensorFlowTestCase):
test_data_dir = test.test_src_dir_path('contrib/image/python/'
'kernel_tests/test_data/')
input_file = test_data_dir + 'Yellow_Smiley_Face.png'
- with self.test_session() as sess:
+ with self.cached_session() as sess:
input_image = self.load_image(input_file, sess)
control_points = np.asarray([[64, 59], [180 - 64, 59], [39, 111],
[180 - 39, 111], [90, 143], [58, 134],
@@ -199,7 +199,7 @@ class SparseImageWarpTest(test_util.TensorFlowTestCase):
control_points_op + control_point_displacements_op,
interpolation_order=interpolation_order,
num_boundary_points=num_boundary_points)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
warped_image = sess.run(warp_op)
out_image = np.uint8(warped_image[0, :, :, :] * 255)
target_file = (
@@ -244,7 +244,7 @@ class SparseImageWarpTest(test_util.TensorFlowTestCase):
opt_func = optimizer.apply_gradients(zip(grad, [image]))
init_op = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(5):
sess.run([loss, opt_func])
diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py
index e7a09041ad..d4fb99a017 100644
--- a/tensorflow/contrib/image/python/ops/image_ops.py
+++ b/tensorflow/contrib/image/python/ops/image_ops.py
@@ -39,6 +39,7 @@ _IMAGE_DTYPES = set(
ops.RegisterShape("ImageConnectedComponents")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ImageProjectiveTransformV2")(common_shapes.call_cpp_shape_fn)
# TODO(ringwalt): Support a "reshape" (name used by SciPy) or "expand" (name
@@ -290,7 +291,7 @@ def transform(images,
else:
raise TypeError("Transforms should have rank 1 or 2.")
- output = gen_image_ops.image_projective_transform(
+ output = gen_image_ops.image_projective_transform_v2(
images,
output_shape=output_shape,
transforms=transforms,
@@ -391,7 +392,7 @@ def matrices_to_flat_transforms(transform_matrices):
return transforms[:, :8]
-@ops.RegisterGradient("ImageProjectiveTransform")
+@ops.RegisterGradient("ImageProjectiveTransformV2")
def _image_projective_transform_grad(op, grad):
"""Computes the gradient for ImageProjectiveTransform."""
images = op.inputs[0]
@@ -415,7 +416,7 @@ def _image_projective_transform_grad(op, grad):
transforms = flat_transforms_to_matrices(transforms=transforms)
inverse = linalg_ops.matrix_inverse(transforms)
transforms = matrices_to_flat_transforms(inverse)
- output = gen_image_ops.image_projective_transform(
+ output = gen_image_ops.image_projective_transform_v2(
images=grad,
transforms=transforms,
output_shape=array_ops.shape(image_or_images)[1:3],
diff --git a/tensorflow/contrib/image/python/ops/interpolate_spline.py b/tensorflow/contrib/image/python/ops/interpolate_spline.py
index daf8c56456..f0b408faa3 100644
--- a/tensorflow/contrib/image/python/ops/interpolate_spline.py
+++ b/tensorflow/contrib/image/python/ops/interpolate_spline.py
@@ -17,9 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
@@ -95,10 +92,22 @@ def _solve_interpolation(train_points, train_values, order,
Returns:
w: `[b, n, k]` weights on each interpolation center
v: `[b, d, k]` weights on each input dimension
+ Raises:
+ ValueError: if d or k is not fully specified.
"""
- b, n, d = train_points.get_shape().as_list()
- _, _, k = train_values.get_shape().as_list()
+ # These dimensions are set dynamically at runtime.
+ b, n, _ = array_ops.unstack(array_ops.shape(train_points), num=3)
+
+ d = train_points.shape[-1]
+ if d.value is None:
+ raise ValueError('The dimensionality of the input points (d) must be '
+                     'statically inferrable.')
+
+ k = train_values.shape[-1]
+ if k.value is None:
+ raise ValueError('The dimensionality of the output values (k) must be '
+                     'statically inferrable.')
# First, rename variables so that the notation (c, f, w, v, A, B, etc.)
# follows https://en.wikipedia.org/wiki/Polyharmonic_spline.
@@ -113,14 +122,12 @@ def _solve_interpolation(train_points, train_values, order,
matrix_a = _phi(_pairwise_squared_distance_matrix(c), order) # [b, n, n]
if regularization_weight > 0:
- batch_identity_matrix = np.expand_dims(np.eye(n), 0)
- batch_identity_matrix = constant_op.constant(
- batch_identity_matrix, dtype=train_points.dtype)
-
+ batch_identity_matrix = array_ops.expand_dims(
+ linalg_ops.eye(n, dtype=c.dtype), 0)
matrix_a += regularization_weight * batch_identity_matrix
# Append ones to the feature values for the bias term in the linear model.
- ones = array_ops.ones([b, n, 1], train_points.dtype)
+ ones = array_ops.ones_like(c[..., :1], dtype=c.dtype)
matrix_b = array_ops.concat([c, ones], 2) # [b, n, d + 1]
# [b, n + d + 1, n]
@@ -164,9 +171,6 @@ def _apply_interpolation(query_points, train_points, w, v, order):
Polyharmonic interpolation evaluated at points defined in query_points.
"""
- batch_size = train_points.get_shape()[0].value
- num_query_points = query_points.get_shape()[1].value
-
# First, compute the contribution from the rbf term.
pairwise_dists = _cross_squared_distance_matrix(query_points, train_points)
phi_pairwise_dists = _phi(pairwise_dists, order)
@@ -177,7 +181,7 @@ def _apply_interpolation(query_points, train_points, w, v, order):
# Pad query_points with ones, for the bias term in the linear model.
query_points_pad = array_ops.concat([
query_points,
- array_ops.ones([batch_size, num_query_points, 1], train_points.dtype)
+ array_ops.ones_like(query_points[..., :1], train_points.dtype)
], 2)
linear_term = math_ops.matmul(query_points_pad, v)
@@ -251,6 +255,9 @@ def interpolate_spline(train_points,
Note the interpolation procedure is differentiable with respect to all inputs
besides the order parameter.
+ We support dynamically-shaped inputs, where batch_size, n, and m are None
+ at graph construction time. However, d and k must be known.
+
Args:
train_points: `[batch_size, n, d]` float `Tensor` of n d-dimensional
locations. These do not need to be regularly-spaced.
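
The padding rewrite above is the crux of the dynamic-shape support: sizes that used to come from `get_shape().as_list()` are replaced by `ones_like` on a `[..., :1]` slice, so the bias column tracks whatever batch and point counts show up at run time, while `d` and `k` stay static. A minimal sketch, assuming TF 1.x placeholders with an illustrative `d` of 2:

```python
import tensorflow as tf

# [b, n, d] with b and n unknown until runtime; d must remain static.
points = tf.placeholder(tf.float64, shape=[None, None, 2])
ones = tf.ones_like(points[..., :1])          # [b, n, 1], shape-polymorphic
padded = tf.concat([points, ones], axis=2)    # [b, n, d + 1]
```
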
diff --git a/tensorflow/contrib/integrate/__init__.py b/tensorflow/contrib/integrate/__init__.py
index 694f0c14bd..3c37f152e5 100644
--- a/tensorflow/contrib/integrate/__init__.py
+++ b/tensorflow/contrib/integrate/__init__.py
@@ -15,7 +15,9 @@
"""Integration and ODE solvers.
-See the @{$python/contrib.integrate} guide.
+See the
+[Contrib Integrate](https://tensorflow.org/api_guides/python/contrib.integrate)
+guide.
@@odeint
@@odeint_fixed
diff --git a/tensorflow/contrib/integrate/python/ops/odes_test.py b/tensorflow/contrib/integrate/python/ops/odes_test.py
index c7b4e2faa8..be915ef96f 100644
--- a/tensorflow/contrib/integrate/python/ops/odes_test.py
+++ b/tensorflow/contrib/integrate/python/ops/odes_test.py
@@ -49,7 +49,7 @@ class OdeIntTest(test.TestCase):
y_solved = odes.odeint(func, y0, t)
self.assertIn('odeint', y_solved.name)
self.assertEqual(y_solved.get_shape(), tensor_shape.TensorShape([11]))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_solved = sess.run(y_solved)
y_true = np.exp(t)
self.assertAllClose(y_true, y_solved)
@@ -62,7 +62,7 @@ class OdeIntTest(test.TestCase):
func = lambda y, t: k * y
t = np.linspace(0.0, 1.0, 11)
y_solved = odes.odeint(func, 1.0 + 0.0j, t)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_solved = sess.run(y_solved)
y_true = np.exp(k * t)
self.assertAllClose(y_true, y_solved)
@@ -74,7 +74,7 @@ class OdeIntTest(test.TestCase):
func = lambda t, y: (y - t)**2 + 1.0
t = np.linspace(0.0, 1.0, 11)
y_solved = odes.odeint(func, np.float64(0.5), t)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_solved = sess.run(y_solved)
y_true = 1.0 / (2.0 - t) + t
self.assertAllClose(y_true, y_solved)
@@ -96,7 +96,7 @@ class OdeIntTest(test.TestCase):
t = np.linspace(0.0, 1.0, 11)
y_solved = odes.odeint(func, y0, t)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_solved = sess.run(y_solved)
y_true = np.zeros((len(t), 2, 1))
@@ -113,7 +113,7 @@ class OdeIntTest(test.TestCase):
y_solved = odes.odeint(func, array_ops.reshape(y0, shape), t)
self.assertEqual(y_solved.get_shape(),
tensor_shape.TensorShape(expected_shape))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_solved = sess.run(y_solved)
self.assertEquals(y_solved.shape, expected_shape)
@@ -126,7 +126,7 @@ class OdeIntTest(test.TestCase):
for t_dtype in [dtypes.float32, dtypes.float64]:
y0 = math_ops.cast(1.0, y0_dtype)
y_solved = odes.odeint(func, y0, math_ops.cast(t, t_dtype))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_solved = sess.run(y_solved)
expected = np.asarray(np.exp(t))
self.assertAllClose(y_solved, expected, rtol=1e-5)
@@ -148,13 +148,13 @@ class OdeIntTest(test.TestCase):
self.y0, [0, 1],
method='dopri5',
options={'max_num_steps': 0})
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
'max_num_steps'):
sess.run(y)
y = odes.odeint(self.func, self.y0, [1, 0])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
'monotonic increasing'):
sess.run(y)
@@ -164,7 +164,7 @@ class OdeIntTest(test.TestCase):
times0 = np.linspace(0, 10, num=11, dtype=float)
times1 = np.linspace(0, 10, num=101, dtype=float)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_solved_0, info_0 = sess.run(
odes.odeint(self.func, self.y0, times0, full_output=True))
y_solved_1, info_1 = sess.run(
@@ -179,7 +179,7 @@ class OdeIntTest(test.TestCase):
t = [0, 20]
kwargs = dict(
full_output=True, method='dopri5', options=dict(max_num_steps=2000))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_, info_0 = sess.run(
odes.odeint(self.func, self.y0, t, rtol=0, atol=1e-6, **kwargs))
_, info_1 = sess.run(
@@ -196,7 +196,7 @@ class StepSizeTest(test.TestCase):
new_step = odes._optimal_step_size(
last_step=constant_op.constant(1.0),
error_ratio=constant_op.constant(1.0))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
new_step = sess.run(new_step)
self.assertAllClose(new_step, 0.9)
@@ -204,7 +204,7 @@ class StepSizeTest(test.TestCase):
new_step = odes._optimal_step_size(
last_step=constant_op.constant(1.0),
error_ratio=constant_op.constant(0.0))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
new_step = sess.run(new_step)
self.assertAllClose(new_step, 10.0)
@@ -212,7 +212,7 @@ class StepSizeTest(test.TestCase):
new_step = odes._optimal_step_size(
last_step=constant_op.constant(1.0),
error_ratio=constant_op.constant(1e6))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
new_step = sess.run(new_step)
self.assertAllClose(new_step, 0.2)
@@ -229,13 +229,13 @@ class InterpolationTest(test.TestCase):
y_fit = array_ops.stack(
[odes._interp_evaluate(coeffs, 0.0, 10.0, t) for t in times])
y_expected = f(times)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_actual = sess.run(y_fit)
self.assertAllClose(y_expected, y_actual)
# attempt interpolation outside bounds
y_invalid = odes._interp_evaluate(coeffs, 0.0, 10.0, 100.0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors_impl.InvalidArgumentError):
sess.run(y_invalid)
@@ -251,7 +251,7 @@ class OdeIntFixedTest(test.TestCase):
y0 = [0., 1.]
y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_grid_array = sess.run(y_grid)
np.testing.assert_allclose(
@@ -265,7 +265,7 @@ class OdeIntFixedTest(test.TestCase):
y0 = [1.]
y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_grid_array = sess.run(y_grid)
np.testing.assert_allclose(
diff --git a/tensorflow/contrib/kfac/BUILD b/tensorflow/contrib/kfac/BUILD
deleted file mode 100644
index b719046b37..0000000000
--- a/tensorflow/contrib/kfac/BUILD
+++ /dev/null
@@ -1,26 +0,0 @@
-# Description:
-# Contains KfacOptimizer, an implementation of the K-FAC optimization
-# algorithm in TensorFlow.
-package(default_visibility = ["//visibility:public"])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-py_library(
- name = "kfac",
- srcs = ["__init__.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:curvature_matrix_vector_products_lib",
- "//tensorflow/contrib/kfac/python/ops:fisher_blocks_lib",
- "//tensorflow/contrib/kfac/python/ops:fisher_estimator_lib",
- "//tensorflow/contrib/kfac/python/ops:fisher_factors_lib",
- "//tensorflow/contrib/kfac/python/ops:kfac_optimizer_lib",
- "//tensorflow/contrib/kfac/python/ops:layer_collection_lib",
- "//tensorflow/contrib/kfac/python/ops:loss_functions_lib",
- "//tensorflow/contrib/kfac/python/ops:op_queue_lib",
- "//tensorflow/contrib/kfac/python/ops:utils_lib",
- "//tensorflow/python:util",
- ],
-)
diff --git a/tensorflow/contrib/kfac/README.md b/tensorflow/contrib/kfac/README.md
index 102626925d..42b91d0313 100644
--- a/tensorflow/contrib/kfac/README.md
+++ b/tensorflow/contrib/kfac/README.md
@@ -1,94 +1,3 @@
# K-FAC: Kronecker-Factored Approximate Curvature
-# <font color="red" size="10"><u>WARNING:</u></font>
-# ==third_party/tensorflow/contrib/kfac is deprecated. This will be==
-# ==removed on 15-07-2018. Please check https://github.com/tensorflow/kfac.==
-
-**K-FAC in TensorFlow** is an implementation of [K-FAC][kfac-paper], an
-approximate second-order optimization method, in TensorFlow. When applied to
-feedforward and convolutional neural networks, K-FAC can converge `>3.5x`
-faster and in `>14x` fewer iterations than SGD with Momentum.
-
-[kfac-paper]: https://arxiv.org/abs/1503.05671
-
-## What is K-FAC?
-
-K-FAC, short for "Kronecker-factored Approximate Curvature", is an approximation
-to the [Natural Gradient][natural_gradient] algorithm designed specifically for
-neural networks. It maintains a block-diagonal approximation to the [Fisher
-Information matrix][fisher_information], whose inverse preconditions the
-gradient.
-
-K-FAC can be used in place of SGD, Adam, and other `Optimizer` implementations.
-Experimentally, K-FAC converges `>3.5x` faster than well-tuned SGD.
-
-Unlike most optimizers, K-FAC exploits structure in the model itself (e.g. "What
-are the weights for layer i?"). As such, you must add some additional code while
-constructing your model to use K-FAC.
-
-[natural_gradient]: http://www.mitpressjournals.org/doi/abs/10.1162/089976698300017746
-[fisher_information]: https://en.wikipedia.org/wiki/Fisher_information#Matrix_form
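-
-To make the block structure concrete, here is a minimal NumPy sketch
-(illustrative only, not part of the K-FAC library): the Kronecker identity
-`inv(A ⊗ G) = inv(A) ⊗ inv(G)` is what makes these Fisher blocks cheap to
-invert, since two small inverses replace one large one.
-
-```python
-import numpy as np
-
-rng = np.random.RandomState(0)
-# Stand-ins for one layer's Kronecker factors: input second moments (A) and
-# output-gradient second moments (G), made symmetric positive definite.
-A = rng.randn(3, 3)
-A = A.dot(A.T) + 3 * np.eye(3)
-G = rng.randn(2, 2)
-G = G.dot(G.T) + 2 * np.eye(2)
-
-fisher_block = np.kron(A, G)  # One 6x6 block of the approximate Fisher.
-direct_inverse = np.linalg.inv(fisher_block)
-factored_inverse = np.kron(np.linalg.inv(A), np.linalg.inv(G))
-assert np.allclose(direct_inverse, factored_inverse)
-```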
-
-## Why should I use K-FAC?
-
-K-FAC can take advantage of the curvature of the optimization problem, resulting
-in **faster training**. For an 8-layer Autoencoder, K-FAC converges to the same
-loss as SGD with Momentum in 3.8x fewer seconds and 14.7x fewer updates. See how
-training loss changes as a function of number of epochs, steps, and seconds:
-
-![autoencoder](g3doc/autoencoder.png)
-
-## Is K-FAC for me?
-
-If you have a feedforward or convolutional model for classification that is
-converging too slowly, K-FAC is for you. K-FAC can be used in your model if:
-
-* Your model defines a posterior distribution.
-* Your model uses only fully-connected or convolutional layers (residual
- connections OK).
-* You are training on CPU or GPU.
-* You can modify model code to register layers with K-FAC.
-
-## How do I use K-FAC?
-
-Using K-FAC requires three steps:
-
-1. Registering layer inputs, weights, and pre-activations with a
- `LayerCollection`.
-1. Minimizing the loss with a `KfacOptimizer`.
-1. Keeping K-FAC's preconditioner updated.
-
-```python
-# Build model.
-w = tf.get_variable("w", ...)
-b = tf.get_variable("b", ...)
-logits = tf.matmul(x, w) + b
-loss = tf.reduce_mean(
- tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
-
-# Register layers.
-layer_collection = LayerCollection()
-layer_collection.register_fully_connected((w, b), x, logits)
-layer_collection.register_categorical_predictive_distribution(logits)
-
-# Construct training ops.
-optimizer = KfacOptimizer(..., layer_collection=layer_collection)
-train_op = optimizer.minimize(loss)
-
-# Minimize loss.
-with tf.Session() as sess:
- ...
- sess.run([train_op, optimizer.cov_update_op, optimizer.inv_update_op])
-```
-
-See [`examples/`](https://www.tensorflow.org/code/tensorflow/contrib/kfac/examples/) for runnable, end-to-end illustrations.
-
-## Authors
-
-- Alok Aggarwal
-- Daniel Duckworth
-- James Martens
-- Matthew Johnson
-- Olga Wichrowska
-- Roger Grosse
+## K-FAC has moved to third_party/tensorflow_kfac.
diff --git a/tensorflow/contrib/kfac/__init__.py b/tensorflow/contrib/kfac/__init__.py
deleted file mode 100644
index 1ea354e6cd..0000000000
--- a/tensorflow/contrib/kfac/__init__.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Kronecker-factored Approximate Curvature Optimizer."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long
-from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products_lib as curvature_matrix_vector_products
-from tensorflow.contrib.kfac.python.ops import estimator_lib as estimator
-from tensorflow.contrib.kfac.python.ops import fisher_blocks_lib as fisher_blocks
-from tensorflow.contrib.kfac.python.ops import fisher_factors_lib as fisher_factors
-from tensorflow.contrib.kfac.python.ops import layer_collection_lib as layer_collection
-from tensorflow.contrib.kfac.python.ops import loss_functions_lib as loss_functions
-from tensorflow.contrib.kfac.python.ops import op_queue_lib as op_queue
-from tensorflow.contrib.kfac.python.ops import optimizer_lib as optimizer
-from tensorflow.contrib.kfac.python.ops import utils_lib as utils
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long
-
-_allowed_symbols = [
- "curvature_matrix_vector_products",
- "estimator",
- "fisher_blocks",
- "fisher_factors",
- "layer_collection",
- "loss_functions",
- "op_queue",
- "optimizer",
- "utils",
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/examples/BUILD b/tensorflow/contrib/kfac/examples/BUILD
deleted file mode 100644
index 8186fa1c62..0000000000
--- a/tensorflow/contrib/kfac/examples/BUILD
+++ /dev/null
@@ -1,80 +0,0 @@
-package(default_visibility = [
- "//learning/brain/contrib/kfac/examples:__subpackages__",
- "//tensorflow/contrib/kfac/examples:__subpackages__",
-])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-py_binary(
- name = "mlp_mnist_main",
- srcs = ["mlp_mnist_main.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":mlp",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-py_library(
- name = "mlp",
- srcs = ["mlp.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":mnist",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-py_binary(
- name = "convnet_mnist_single_main",
- srcs = ["convnet_mnist_single_main.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":convnet",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-py_binary(
- name = "convnet_mnist_multi_tower_main",
- srcs = ["convnet_mnist_multi_tower_main.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":convnet",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-py_binary(
- name = "convnet_mnist_distributed_main",
- srcs = ["convnet_mnist_distributed_main.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":convnet",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-py_library(
- name = "convnet",
- srcs = ["convnet.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":mlp",
- ":mnist",
- "//tensorflow:tensorflow_py",
- "//third_party/py/numpy",
- ],
-)
-
-py_library(
- name = "mnist",
- srcs = ["mnist.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow:tensorflow_py",
- "//third_party/py/numpy",
- ],
-)
diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py
deleted file mode 100644
index 44e01e1aeb..0000000000
--- a/tensorflow/contrib/kfac/examples/convnet.py
+++ /dev/null
@@ -1,667 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train a ConvNet on MNIST using K-FAC.
-
-This library fits a 5-layer ConvNet on MNIST using K-FAC. The model has the
-following structure:
-
-- Conv Layer: 5x5 kernel, 16 output channels.
-- Max Pool: 3x3 kernel, stride 2.
-- Conv Layer: 5x5 kernel, 16 output channels.
-- Max Pool: 3x3 kernel, stride 2.
-- Linear: 10 output dims.
-
-After 3k~6k steps, this should reach perfect accuracy on the training set.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-import numpy as np
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import mlp
-from tensorflow.contrib.kfac.examples import mnist
-from tensorflow.contrib.kfac.python.ops import optimizer as opt
-
-
-lc = tf.contrib.kfac.layer_collection
-oq = tf.contrib.kfac.op_queue
-opt = tf.contrib.kfac.optimizer
-
-__all__ = [
- "conv_layer",
- "max_pool_layer",
- "linear_layer",
- "build_model",
- "minimize_loss_single_machine",
- "distributed_grads_only_and_ops_chief_worker",
- "distributed_grads_and_ops_dedicated_workers",
- "train_mnist_single_machine",
- "train_mnist_distributed_sync_replicas",
- "train_mnist_multitower"
-]
-
-
-# Inverse update ops will be run every _INVERT_EVERY iterations.
-_INVERT_EVERY = 10
-
-
-def conv_layer(layer_id, inputs, kernel_size, out_channels):
- """Builds a convolutional layer with ReLU non-linearity.
-
- Args:
- layer_id: int. Integer ID for this layer's variables.
- inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row
- corresponds to a single example.
- kernel_size: int. Width and height of the convolution kernel. The kernel is
- assumed to be square.
- out_channels: int. Number of output features per pixel.
-
- Returns:
- preactivations: Tensor of shape [num_examples, width, height, out_channels].
- Values of the layer immediately before the activation function.
- activations: Tensor of shape [num_examples, width, height, out_channels].
- Values of the layer immediately after the activation function.
- params: Tuple of (kernel, bias), parameters for this layer.
- """
- # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
- layer = tf.layers.Conv2D(
- out_channels,
- kernel_size=[kernel_size, kernel_size],
- kernel_initializer=tf.random_normal_initializer(stddev=0.01),
- padding="SAME",
- name="conv_%d" % layer_id)
- preactivations = layer(inputs)
- activations = tf.nn.relu(preactivations)
-
-  # layer.weights is a list. This converts it to a (hashable) tuple.
- return preactivations, activations, (layer.kernel, layer.bias)
-
-
-def max_pool_layer(layer_id, inputs, kernel_size, stride):
- """Build a max-pooling layer.
-
- Args:
- layer_id: int. Integer ID for this layer's variables.
- inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row
- corresponds to a single example.
- kernel_size: int. Width and height to pool over per input channel. The
- kernel is assumed to be square.
- stride: int. Step size between pooling operations.
-
- Returns:
- Tensor of shape [num_examples, width/stride, height/stride, out_channels].
- Result of applying max pooling to 'inputs'.
- """
- # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
- with tf.variable_scope("pool_%d" % layer_id):
- return tf.nn.max_pool(
- inputs, [1, kernel_size, kernel_size, 1], [1, stride, stride, 1],
- padding="SAME",
- name="pool")
-
-
-def linear_layer(layer_id, inputs, output_size):
- """Builds the final linear layer for an MNIST classification problem.
-
- Args:
- layer_id: int. Integer ID for this layer's variables.
- inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row
- corresponds to a single example.
- output_size: int. Number of output dims per example.
-
- Returns:
-    preactivations: Tensor of shape [num_examples, output_size]. Values of
-      the layer immediately before the activation function.
- params: Tuple of (weights, bias), parameters for this layer.
- """
- # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
- pre, _, params = mlp.fc_layer(layer_id, inputs, output_size)
- return pre, params
-
-
-def build_model(examples, labels, num_labels, layer_collection):
- """Builds a ConvNet classification model.
-
- Args:
- examples: Tensor of shape [num_examples, num_features]. Represents inputs of
- model.
- labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted
- by softmax for each example.
- num_labels: int. Number of distinct values 'labels' can take on.
- layer_collection: LayerCollection instance. Layers will be registered here.
-
- Returns:
- loss: 0-D Tensor representing loss to be minimized.
- accuracy: 0-D Tensor representing model's accuracy.
- """
- # Build a ConvNet. For each layer with parameters, we'll keep track of the
- # preactivations, activations, weights, and bias.
- tf.logging.info("Building model.")
- pre0, act0, params0 = conv_layer(
- layer_id=0, inputs=examples, kernel_size=5, out_channels=16)
- act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)
- pre2, act2, params2 = conv_layer(
- layer_id=2, inputs=act1, kernel_size=5, out_channels=16)
- act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2)
- flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))])
- logits, params4 = linear_layer(
- layer_id=4, inputs=flat_act3, output_size=num_labels)
- loss = tf.reduce_mean(
- tf.nn.sparse_softmax_cross_entropy_with_logits(
- labels=labels, logits=logits))
- accuracy = tf.reduce_mean(
- tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))
-
- with tf.device("/cpu:0"):
- tf.summary.scalar("loss", loss)
- tf.summary.scalar("accuracy", accuracy)
-
- # Register parameters. K-FAC needs to know about the inputs, outputs, and
- # parameters of each conv/fully connected layer and the logits powering the
- # posterior probability over classes.
- tf.logging.info("Building LayerCollection.")
- layer_collection.register_conv2d(params0, (1, 1, 1, 1), "SAME", examples,
- pre0)
- layer_collection.register_conv2d(params2, (1, 1, 1, 1), "SAME", act1, pre2)
- layer_collection.register_fully_connected(params4, flat_act3, logits)
- layer_collection.register_categorical_predictive_distribution(
- logits, name="logits")
-
- return loss, accuracy
-
-
-def minimize_loss_single_machine(loss,
- accuracy,
- layer_collection,
- device="/gpu:0",
- session_config=None):
- """Minimize loss with K-FAC on a single machine.
-
- A single Session is responsible for running all of K-FAC's ops. The covariance
- and inverse update ops are placed on `device`. All model variables are on CPU.
-
- Args:
- loss: 0-D Tensor. Loss to be minimized.
- accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
- layer_collection: LayerCollection instance describing model architecture.
- Used by K-FAC to construct preconditioner.
- device: string, Either '/cpu:0' or '/gpu:0'. The covariance and inverse
- update ops are run on this device.
- session_config: None or tf.ConfigProto. Configuration for tf.Session().
-
- Returns:
- final value for 'accuracy'.
- """
- # Train with K-FAC.
- g_step = tf.train.get_or_create_global_step()
- optimizer = opt.KfacOptimizer(
- learning_rate=0.0001,
- cov_ema_decay=0.95,
- damping=0.001,
- layer_collection=layer_collection,
- placement_strategy="round_robin",
- cov_devices=[device],
- inv_devices=[device],
- momentum=0.9)
- (cov_update_thunks,
- inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
-
- def make_update_op(update_thunks):
- update_ops = [thunk() for thunk in update_thunks]
- return tf.group(*update_ops)
-
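-  # Chain the ops: every training step first refreshes the covariance
-  # estimates and, every _INVERT_EVERY steps, also recomputes the inverses
-  # before the parameter update is applied.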
- cov_update_op = make_update_op(cov_update_thunks)
- with tf.control_dependencies([cov_update_op]):
- inverse_op = tf.cond(
- tf.equal(tf.mod(g_step, _INVERT_EVERY), 0),
- lambda: make_update_op(inv_update_thunks), tf.no_op)
- with tf.control_dependencies([inverse_op]):
- with tf.device(device):
- train_op = optimizer.minimize(loss, global_step=g_step)
-
- tf.logging.info("Starting training.")
- with tf.train.MonitoredTrainingSession(config=session_config) as sess:
- while not sess.should_stop():
- global_step_, loss_, accuracy_, _ = sess.run(
- [g_step, loss, accuracy, train_op])
-
- if global_step_ % _INVERT_EVERY == 0:
- tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
- global_step_, loss_, accuracy_)
-
- return accuracy_
-
-
-def _is_gradient_task(task_id, num_tasks):
- """Returns True if this task should update the weights."""
- if num_tasks < 3:
- return True
- return 0 <= task_id < 0.6 * num_tasks
-
-
-def _is_cov_update_task(task_id, num_tasks):
- """Returns True if this task should update K-FAC's covariance matrices."""
- if num_tasks < 3:
- return False
- return 0.6 * num_tasks <= task_id < num_tasks - 1
-
-
-def _is_inv_update_task(task_id, num_tasks):
- """Returns True if this task should update K-FAC's preconditioner."""
- if num_tasks < 3:
- return False
- return task_id == num_tasks - 1
-
-
-def _num_gradient_tasks(num_tasks):
- """Number of tasks that will update weights."""
- if num_tasks < 3:
- return num_tasks
- return int(np.ceil(0.6 * num_tasks))
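-
-
-# Example (hypothetical numbers): with num_tasks = 5, tasks 0-2 compute
-# gradients (ceil(0.6 * 5) = 3), task 3 accumulates covariance statistics,
-# and task 4 computes inverses.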
-
-
-def _make_distributed_train_op(
- task_id,
- num_worker_tasks,
- num_ps_tasks,
- layer_collection
-):
- """Creates optimizer and distributed training op.
-
-  Constructs a KfacOptimizer and wraps it in a `SyncReplicasOptimizer`. The
-  caller builds the train op from the returned optimizers.
-
- Args:
- task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
- num_worker_tasks: int. Number of workers in this distributed training setup.
- num_ps_tasks: int. Number of parameter servers holding variables. If 0,
- parameter servers are not used.
- layer_collection: LayerCollection instance describing model architecture.
- Used by K-FAC to construct preconditioner.
-
- Returns:
- sync_optimizer: `tf.train.SyncReplicasOptimizer` instance which wraps KFAC
- optimizer.
- optimizer: Instance of `opt.KfacOptimizer`.
- global_step: `tensor`, Global step.
- """
- tf.logging.info("Task id : %d", task_id)
- with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
- global_step = tf.train.get_or_create_global_step()
- optimizer = opt.KfacOptimizer(
- learning_rate=0.0001,
- cov_ema_decay=0.95,
- damping=0.001,
- layer_collection=layer_collection,
- momentum=0.9)
- sync_optimizer = tf.train.SyncReplicasOptimizer(
- opt=optimizer,
- replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks),
- total_num_replicas=num_worker_tasks)
- return sync_optimizer, optimizer, global_step
-
-
-def distributed_grads_only_and_ops_chief_worker(
- task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,
- loss, accuracy, layer_collection, invert_every=10):
- """Minimize loss with a synchronous implementation of K-FAC.
-
- All workers perform gradient computation. Chief worker applies gradient after
- averaging the gradients obtained from all the workers. All workers block
- execution until the update is applied. Chief worker runs covariance and
- inverse update ops. Covariance and inverse matrices are placed on parameter
- servers in a round robin manner. For further details on synchronous
- distributed optimization check `tf.train.SyncReplicasOptimizer`.
-
- Args:
- task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
- is_chief: `boolean`, `True` if the worker is chief worker.
- num_worker_tasks: int. Number of workers in this distributed training setup.
- num_ps_tasks: int. Number of parameter servers holding variables. If 0,
- parameter servers are not used.
- master: string. IP and port of TensorFlow runtime process. Set to empty
- string to run locally.
- checkpoint_dir: string or None. Path to store checkpoints under.
- loss: 0-D Tensor. Loss to be minimized.
-    accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
- layer_collection: LayerCollection instance describing model architecture.
- Used by K-FAC to construct preconditioner.
-    invert_every: `int`, Number of steps between inverse updates.
-
- Returns:
- final value for 'accuracy'.
-
- Raises:
- ValueError: if task_id >= num_worker_tasks.
- """
-
- sync_optimizer, optimizer, global_step = _make_distributed_train_op(
- task_id, num_worker_tasks, num_ps_tasks, layer_collection)
- (cov_update_thunks,
- inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
-
- tf.logging.info("Starting training.")
- hooks = [sync_optimizer.make_session_run_hook(is_chief)]
-
- def make_update_op(update_thunks):
- update_ops = [thunk() for thunk in update_thunks]
- return tf.group(*update_ops)
-
- if is_chief:
- cov_update_op = make_update_op(cov_update_thunks)
- with tf.control_dependencies([cov_update_op]):
- inverse_op = tf.cond(
- tf.equal(tf.mod(global_step, invert_every), 0),
- lambda: make_update_op(inv_update_thunks),
- tf.no_op)
- with tf.control_dependencies([inverse_op]):
- train_op = sync_optimizer.minimize(loss, global_step=global_step)
- else:
- train_op = sync_optimizer.minimize(loss, global_step=global_step)
-
- with tf.train.MonitoredTrainingSession(
- master=master,
- is_chief=is_chief,
- checkpoint_dir=checkpoint_dir,
- hooks=hooks,
- stop_grace_period_secs=0) as sess:
- while not sess.should_stop():
- global_step_, loss_, accuracy_, _ = sess.run(
- [global_step, loss, accuracy, train_op])
- tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_,
- loss_, accuracy_)
- return accuracy_
-
-
-def distributed_grads_and_ops_dedicated_workers(
- task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,
- loss, accuracy, layer_collection):
- """Minimize loss with a synchronous implementation of K-FAC.
-
- Different workers are responsible for different parts of K-FAC's Ops. The
- first 60% of tasks compute gradients; the next 20% accumulate covariance
- statistics; the last 20% invert the matrices used to precondition gradients.
-  The chief worker applies the gradient.
-
- Args:
- task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
- is_chief: `boolean`, `True` if the worker is chief worker.
- num_worker_tasks: int. Number of workers in this distributed training setup.
- num_ps_tasks: int. Number of parameter servers holding variables. If 0,
- parameter servers are not used.
- master: string. IP and port of TensorFlow runtime process. Set to empty
- string to run locally.
- checkpoint_dir: string or None. Path to store checkpoints under.
- loss: 0-D Tensor. Loss to be minimized.
-    accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
- layer_collection: LayerCollection instance describing model architecture.
- Used by K-FAC to construct preconditioner.
-
- Returns:
- final value for 'accuracy'.
-
- Raises:
- ValueError: if task_id >= num_worker_tasks.
- """
- sync_optimizer, optimizer, global_step = _make_distributed_train_op(
- task_id, num_worker_tasks, num_ps_tasks, layer_collection)
- _, cov_update_op, inv_update_ops, _, _, _ = optimizer.make_ops_and_vars()
- train_op = sync_optimizer.minimize(loss, global_step=global_step)
- inv_update_queue = oq.OpQueue(inv_update_ops)
-
- tf.logging.info("Starting training.")
- is_chief = (task_id == 0)
- hooks = [sync_optimizer.make_session_run_hook(is_chief)]
- with tf.train.MonitoredTrainingSession(
- master=master,
- is_chief=is_chief,
- checkpoint_dir=checkpoint_dir,
- hooks=hooks,
- stop_grace_period_secs=0) as sess:
- while not sess.should_stop():
- # Choose which op this task is responsible for running.
- if _is_gradient_task(task_id, num_worker_tasks):
- learning_op = train_op
- elif _is_cov_update_task(task_id, num_worker_tasks):
- learning_op = cov_update_op
- elif _is_inv_update_task(task_id, num_worker_tasks):
- # TODO(duckworthd): Running this op before cov_update_op has been run a
- # few times can result in "InvalidArgumentError: Cholesky decomposition
- # was not successful." Delay running this op until cov_update_op has
- # been run a few times.
- learning_op = inv_update_queue.next_op(sess)
- else:
- raise ValueError("Which op should task %d do?" % task_id)
-
- global_step_, loss_, accuracy_, _ = sess.run(
- [global_step, loss, accuracy, learning_op])
- tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_,
- loss_, accuracy_)
-
- return accuracy_
-
-
-def train_mnist_single_machine(data_dir,
- num_epochs,
- use_fake_data=False,
- device="/gpu:0"):
- """Train a ConvNet on MNIST.
-
- Args:
- data_dir: string. Directory to read MNIST examples from.
- num_epochs: int. Number of passes to make over the training set.
- use_fake_data: bool. If True, generate a synthetic dataset.
- device: string, Either '/cpu:0' or '/gpu:0'. The covariance and inverse
- update ops are run on this device.
-
- Returns:
- accuracy of model on the final minibatch of training data.
- """
- # Load a dataset.
- tf.logging.info("Loading MNIST into memory.")
- examples, labels = mnist.load_mnist(
- data_dir,
- num_epochs=num_epochs,
- batch_size=128,
- use_fake_data=use_fake_data,
- flatten_images=False)
-
- # Build a ConvNet.
- layer_collection = lc.LayerCollection()
- loss, accuracy = build_model(
- examples, labels, num_labels=10, layer_collection=layer_collection)
-
- # Fit model.
- return minimize_loss_single_machine(
- loss, accuracy, layer_collection, device=device)
-
-
-def train_mnist_multitower(data_dir, num_epochs, num_towers,
- use_fake_data=True, devices=None):
- """Train a ConvNet on MNIST.
-
- Training data is split equally among the towers. Each tower computes loss on
- its own batch of data and the loss is aggregated on the CPU. The model
- variables are placed on first tower. The covariance and inverse update ops
- and variables are placed on GPUs in a round robin manner.
-
- Args:
- data_dir: string. Directory to read MNIST examples from.
- num_epochs: int. Number of passes to make over the training set.
-    num_towers: int. Number of towers to split inference across.
-    use_fake_data: bool. If True, generate a synthetic dataset.
-    devices: list of strings or None. CPU or GPU devices on which the
-      covariance and inverse update ops are run.
-
- Returns:
- accuracy of model on the final minibatch of training data.
- """
- if devices:
- device_count = {"GPU": num_towers}
- else:
- device_count = {"CPU": num_towers}
-
- devices = devices or [
- "/cpu:{}".format(tower_id) for tower_id in range(num_towers)
- ]
- # Load a dataset.
- tf.logging.info("Loading MNIST into memory.")
- tower_batch_size = 128
- batch_size = tower_batch_size * num_towers
- tf.logging.info(
- ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
- "tower batch size.") % (batch_size, num_towers, tower_batch_size))
- examples, labels = mnist.load_mnist(
- data_dir,
- num_epochs=num_epochs,
- batch_size=batch_size,
- use_fake_data=use_fake_data,
- flatten_images=False)
-
- # Split minibatch across towers.
- examples = tf.split(examples, num_towers)
- labels = tf.split(labels, num_towers)
-
- # Build an MLP. Each tower's layers will be added to the LayerCollection.
- layer_collection = lc.LayerCollection()
- tower_results = []
- for tower_id in range(num_towers):
- with tf.device(devices[tower_id]):
- with tf.name_scope("tower%d" % tower_id):
- with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
- tf.logging.info("Building tower %d." % tower_id)
- tower_results.append(
- build_model(examples[tower_id], labels[tower_id], 10,
- layer_collection))
- losses, accuracies = zip(*tower_results)
-
- # Average across towers.
- loss = tf.reduce_mean(losses)
- accuracy = tf.reduce_mean(accuracies)
-
- # Fit model.
-
- session_config = tf.ConfigProto(
- allow_soft_placement=False,
- device_count=device_count,
- )
-
- g_step = tf.train.get_or_create_global_step()
- optimizer = opt.KfacOptimizer(
- learning_rate=0.0001,
- cov_ema_decay=0.95,
- damping=0.001,
- layer_collection=layer_collection,
- placement_strategy="round_robin",
- cov_devices=devices,
- inv_devices=devices,
- momentum=0.9)
- (cov_update_thunks,
- inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
-
- def make_update_op(update_thunks):
- update_ops = [thunk() for thunk in update_thunks]
- return tf.group(*update_ops)
-
- cov_update_op = make_update_op(cov_update_thunks)
- with tf.control_dependencies([cov_update_op]):
- inverse_op = tf.cond(
- tf.equal(tf.mod(g_step, _INVERT_EVERY), 0),
- lambda: make_update_op(inv_update_thunks), tf.no_op)
- with tf.control_dependencies([inverse_op]):
- train_op = optimizer.minimize(loss, global_step=g_step)
-
- tf.logging.info("Starting training.")
- with tf.train.MonitoredTrainingSession(config=session_config) as sess:
- while not sess.should_stop():
- global_step_, loss_, accuracy_, _ = sess.run(
- [g_step, loss, accuracy, train_op])
-
- if global_step_ % _INVERT_EVERY == 0:
- tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
- global_step_, loss_, accuracy_)
-
-
-def train_mnist_distributed_sync_replicas(task_id,
- is_chief,
- num_worker_tasks,
- num_ps_tasks,
- master,
- data_dir,
- num_epochs,
- op_strategy,
- use_fake_data=False):
- """Train a ConvNet on MNIST using Sync replicas optimizer.
-
- Args:
- task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
- is_chief: `boolean`, `True` if the worker is chief worker.
- num_worker_tasks: int. Number of workers in this distributed training setup.
- num_ps_tasks: int. Number of parameter servers holding variables.
- master: string. IP and port of TensorFlow runtime process.
- data_dir: string. Directory to read MNIST examples from.
- num_epochs: int. Number of passes to make over the training set.
-    op_strategy: `string`, Strategy for running the covariance and inverse
-      ops. If op_strategy == `chief_worker`, covariance and inverse update
-      ops are run on the chief worker; otherwise they are run on dedicated
-      workers.
-    use_fake_data: bool. If True, generate a synthetic dataset.
-
- Returns:
- accuracy of model on the final minibatch of training data.
-
- Raises:
- ValueError: If `op_strategy` not in ["chief_worker", "dedicated_workers"].
- """
- # Load a dataset.
- tf.logging.info("Loading MNIST into memory.")
- examples, labels = mnist.load_mnist(
- data_dir,
- num_epochs=num_epochs,
- batch_size=128,
- use_fake_data=use_fake_data,
- flatten_images=False)
-
- # Build a ConvNet.
- layer_collection = lc.LayerCollection()
- with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
- loss, accuracy = build_model(
- examples, labels, num_labels=10, layer_collection=layer_collection)
-
- # Fit model.
- checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac")
- if op_strategy == "chief_worker":
- return distributed_grads_only_and_ops_chief_worker(
- task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
- checkpoint_dir, loss, accuracy, layer_collection)
- elif op_strategy == "dedicated_workers":
- return distributed_grads_and_ops_dedicated_workers(
- task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
- checkpoint_dir, loss, accuracy, layer_collection)
- else:
- raise ValueError("Only supported op strategies are : {}, {}".format(
- "chief_worker", "dedicated_workers"))
-
-
-if __name__ == "__main__":
- tf.app.run()
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py
deleted file mode 100644
index b4c2d4a9e9..0000000000
--- a/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train a ConvNet on MNIST using K-FAC.
-
-Distributed training with sync replicas optimizer. See
-`convnet.train_mnist_distributed_sync_replicas` for details.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-
-from absl import flags
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import convnet
-
-FLAGS = flags.FLAGS
-flags.DEFINE_integer("task", -1, "Task identifier")
-flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir")
-flags.DEFINE_string(
- "cov_inv_op_strategy", "chief_worker",
- "In dist training mode run the cov, inv ops on chief or dedicated workers."
-)
-flags.DEFINE_string("master", "local", "Session master.")
-flags.DEFINE_integer("ps_tasks", 2,
- "Number of tasks in the parameter server job.")
-flags.DEFINE_integer("replicas_to_aggregate", 5,
- "Number of replicas to aggregate.")
-flags.DEFINE_integer("worker_replicas", 5, "Number of replicas in worker job.")
-flags.DEFINE_integer("num_epochs", None, "Number of epochs.")
-
-
-def _is_chief():
- """Determines whether a job is the chief worker."""
- if "chief_worker" in FLAGS.brain_jobs:
- return FLAGS.brain_job_name == "chief_worker"
- else:
- return FLAGS.task == 0
-
-
-def main(unused_argv):
- _ = unused_argv
- convnet.train_mnist_distributed_sync_replicas(
- FLAGS.task, _is_chief(), FLAGS.worker_replicas, FLAGS.ps_tasks,
- FLAGS.master, FLAGS.data_dir, FLAGS.num_epochs, FLAGS.cov_inv_op_strategy)
-
-if __name__ == "__main__":
- tf.app.run(main=main)
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py
deleted file mode 100644
index 4249bf8a8d..0000000000
--- a/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train a ConvNet on MNIST using K-FAC.
-
-Multi tower training mode. See `convnet.train_mnist_multitower` for details.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-
-from absl import flags
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import convnet
-
-FLAGS = flags.FLAGS
-flags.DEFINE_string("data_dir", "/tmp/multitower_1/mnist", "local mnist dir")
-flags.DEFINE_integer("num_towers", 2,
- "Number of towers for multi tower training.")
-
-
-def main(unused_argv):
- _ = unused_argv
- assert FLAGS.num_towers > 1
- devices = ["/gpu:{}".format(tower_id) for tower_id in range(FLAGS.num_towers)]
- convnet.train_mnist_multitower(
- FLAGS.data_dir,
- num_epochs=200,
- num_towers=FLAGS.num_towers,
- devices=devices)
-
-
-if __name__ == "__main__":
- tf.app.run(main=main)
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py
deleted file mode 100644
index 2c1f099360..0000000000
--- a/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train a ConvNet on MNIST using K-FAC.
-
-Train on single machine. See `convnet.train_mnist_single_machine` for details.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-
-from absl import flags
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import convnet
-
-FLAGS = flags.FLAGS
-flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir")
-
-
-def main(unused_argv):
- convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200)
-
-
-if __name__ == "__main__":
- tf.app.run(main=main)
diff --git a/tensorflow/contrib/kfac/examples/mlp.py b/tensorflow/contrib/kfac/examples/mlp.py
deleted file mode 100644
index ea2b252a05..0000000000
--- a/tensorflow/contrib/kfac/examples/mlp.py
+++ /dev/null
@@ -1,354 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train an MLP on MNIST using K-FAC.
-
-This library fits a 3-layer, tanh-activated MLP on MNIST using K-FAC. After
-~25k steps, this should reach perfect accuracy on the training set.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import mnist
-
-lc = tf.contrib.kfac.layer_collection
-opt = tf.contrib.kfac.optimizer
-
-__all__ = [
- "fc_layer",
- "train_mnist",
- "train_mnist_multitower",
-]
-
-
-def fc_layer(layer_id, inputs, output_size):
- """Builds a fully connected layer.
-
- Args:
- layer_id: int. Integer ID for this layer's variables.
- inputs: Tensor of shape [num_examples, input_size]. Each row corresponds
- to a single example.
- output_size: int. Number of output dimensions after fully connected layer.
-
- Returns:
- preactivations: Tensor of shape [num_examples, output_size]. Values of the
- layer immediately before the activation function.
- activations: Tensor of shape [num_examples, output_size]. Values of the
- layer immediately after the activation function.
- params: Tuple of (weights, bias), parameters for this layer.
- """
- # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
- layer = tf.layers.Dense(
- output_size,
- kernel_initializer=tf.random_normal_initializer(),
- name="fc_%d" % layer_id)
- preactivations = layer(inputs)
- activations = tf.nn.tanh(preactivations)
-
-  # layer.weights is a list. This converts it to a (hashable) tuple.
- return preactivations, activations, (layer.kernel, layer.bias)
-
-
-def build_model(examples, labels, num_labels, layer_collection):
- """Builds an MLP classification model.
-
- Args:
- examples: Tensor of shape [num_examples, num_features]. Represents inputs of
- model.
- labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted
- by softmax for each example.
- num_labels: int. Number of distinct values 'labels' can take on.
- layer_collection: LayerCollection instance describing model architecture.
-
- Returns:
- loss: 0-D Tensor representing loss to be minimized.
- accuracy: 0-D Tensor representing model's accuracy.
- """
- # Build an MLP. For each layer, we'll keep track of the preactivations,
- # activations, weights, and bias.
- pre0, act0, params0 = fc_layer(layer_id=0, inputs=examples, output_size=128)
- pre1, act1, params1 = fc_layer(layer_id=1, inputs=act0, output_size=64)
- pre2, act2, params2 = fc_layer(layer_id=2, inputs=act1, output_size=32)
- logits, _, params3 = fc_layer(layer_id=3, inputs=act2, output_size=num_labels)
- loss = tf.reduce_mean(
- tf.nn.sparse_softmax_cross_entropy_with_logits(
- labels=labels, logits=logits))
- accuracy = tf.reduce_mean(
- tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))
-
- # Register parameters. K-FAC needs to know about the inputs, outputs, and
- # parameters of each layer and the logits powering the posterior probability
- # over classes.
- tf.logging.info("Building LayerCollection.")
- layer_collection.register_fully_connected(params0, examples, pre0)
- layer_collection.register_fully_connected(params1, act0, pre1)
- layer_collection.register_fully_connected(params2, act1, pre2)
- layer_collection.register_fully_connected(params3, act2, logits)
- layer_collection.register_categorical_predictive_distribution(
- logits, name="logits")
-
- return loss, accuracy
-
-
-def minimize(loss, accuracy, layer_collection, num_towers, session_config=None):
- """Minimize 'loss' with KfacOptimizer.
-
- Args:
- loss: 0-D Tensor. Loss to be minimized.
- accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
- layer_collection: LayerCollection instance. Describes layers in model.
- num_towers: int. Number of CPUs to split minibatch across.
- session_config: tf.ConfigProto. Configuration for tf.Session().
-
- Returns:
- accuracy of classifier on final minibatch.
- """
- devices = tuple("/cpu:%d" % tower_id for tower_id in range(num_towers))
-
- # Train with K-FAC. We'll use a decreasing learning rate that's cut in 1/2
- # every 10k iterations.
- tf.logging.info("Building KFAC Optimizer.")
- global_step = tf.train.get_or_create_global_step()
- optimizer = opt.KfacOptimizer(
- learning_rate=tf.train.exponential_decay(
- 0.00002, global_step, 10000, 0.5, staircase=True),
- cov_ema_decay=0.95,
- damping=0.0005,
- layer_collection=layer_collection,
- momentum=0.99,
- placement_strategy="round_robin",
- cov_devices=devices,
- inv_devices=devices)
-
- (cov_update_thunks,
- inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
-
- def make_update_op(update_thunks):
- update_ops = [thunk() for thunk in update_thunks]
- return tf.group(*update_ops)
-
- # TODO(b/78537047): change (some) examples to use PeriodicInvCovUpdateKfacOpt
- # once that gets moved over? Could still leave more advanced examples as they
- # are (e.g. train_mnist_estimator in this file)
-
- cov_update_op = make_update_op(cov_update_thunks)
- with tf.control_dependencies([cov_update_op]):
-    # We update the inverses only every 100 iterations.
- inverse_op = tf.cond(
- tf.equal(tf.mod(global_step, 100), 0),
- lambda: make_update_op(inv_update_thunks), tf.no_op)
- with tf.control_dependencies([inverse_op]):
- train_op = optimizer.minimize(loss, global_step=global_step)
-
- tf.logging.info("Starting training.")
- with tf.train.MonitoredTrainingSession(config=session_config) as sess:
- while not sess.should_stop():
- global_step_, loss_, accuracy_, _ = sess.run(
- [global_step, loss, accuracy, train_op])
-
- if global_step_ % 100 == 0:
- tf.logging.info("global_step: %d | loss: %f | accuracy: %f",
- global_step_, loss_, accuracy_)
-
- return accuracy_
-
-
-def train_mnist(data_dir, num_epochs, use_fake_data=False):
- """Train an MLP on MNIST.
-
- Args:
- data_dir: string. Directory to read MNIST examples from.
- num_epochs: int. Number of passes to make over the training set.
- use_fake_data: bool. If True, generate a synthetic dataset.
-
- Returns:
- accuracy of model on the final minibatch of training data.
- """
- # Load a dataset.
- tf.logging.info("Loading MNIST into memory.")
- examples, labels = mnist.load_mnist(
- data_dir,
- num_epochs=num_epochs,
- batch_size=64,
- flatten_images=True,
- use_fake_data=use_fake_data)
-
- # Build an MLP. The model's layers will be added to the LayerCollection.
- tf.logging.info("Building model.")
- layer_collection = lc.LayerCollection()
- loss, accuracy = build_model(examples, labels, 10, layer_collection)
-
- # Fit model.
- minimize(loss, accuracy, layer_collection, 1)
-
-
-def train_mnist_multitower(data_dir,
- num_epochs,
- num_towers,
- use_fake_data=False):
- """Train an MLP on MNIST, splitting the minibatch across multiple towers.
-
- Args:
- data_dir: string. Directory to read MNIST examples from.
- num_epochs: int. Number of passes to make over the training set.
- num_towers: int. Number of CPUs to split minibatch across.
- use_fake_data: bool. If True, generate a synthetic dataset.
-
- Returns:
- accuracy of model on the final minibatch of training data.
- """
- # Load a dataset.
- tower_batch_size = 64
- batch_size = tower_batch_size * num_towers
- tf.logging.info(
- ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
- "tower batch size.") % (batch_size, num_towers, tower_batch_size))
- examples, labels = mnist.load_mnist(
- data_dir,
- num_epochs=num_epochs,
- batch_size=batch_size,
- flatten_images=True,
- use_fake_data=use_fake_data)
-
- # Split minibatch across towers.
- examples = tf.split(examples, num_towers)
- labels = tf.split(labels, num_towers)
-
- # Build an MLP. Each tower's layers will be added to the LayerCollection.
- layer_collection = lc.LayerCollection()
- tower_results = []
- for tower_id in range(num_towers):
- with tf.device("/cpu:%d" % tower_id):
- with tf.name_scope("tower%d" % tower_id):
- with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
- tf.logging.info("Building tower %d." % tower_id)
- tower_results.append(
- build_model(examples[tower_id], labels[tower_id], 10,
- layer_collection))
- losses, accuracies = zip(*tower_results)
-
- # Average across towers.
- loss = tf.reduce_mean(losses)
- accuracy = tf.reduce_mean(accuracies)
-
- # Fit model.
- session_config = tf.ConfigProto(
- allow_soft_placement=False, device_count={
- "CPU": num_towers
- })
- return minimize(
- loss, accuracy, layer_collection, num_towers,
- session_config=session_config)
-
-
-def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False):
- """Train an MLP on MNIST using tf.estimator.
-
- Args:
- data_dir: string. Directory to read MNIST examples from.
- num_epochs: int. Number of passes to make over the training set.
- use_fake_data: bool. If True, generate a synthetic dataset.
-
- Returns:
- accuracy of model on the final minibatch of training data.
- """
-
- # Load a dataset.
- def input_fn():
- tf.logging.info("Loading MNIST into memory.")
- return mnist.load_mnist(
- data_dir,
- num_epochs=num_epochs,
- batch_size=64,
- flatten_images=True,
- use_fake_data=use_fake_data)
-
- def model_fn(features, labels, mode, params):
- """Model function for MLP trained with K-FAC.
-
- Args:
- features: Tensor of shape [batch_size, input_size]. Input features.
- labels: Tensor of shape [batch_size]. Target labels for training.
-      mode: tf.estimator.ModeKeys. Must be TRAIN.
- params: ignored.
-
- Returns:
- EstimatorSpec for training.
-
- Raises:
- ValueError: If 'mode' is anything other than TRAIN.
- """
- del params
-
- if mode != tf.estimator.ModeKeys.TRAIN:
-      raise ValueError("Only training is supported with this API.")
-
- # Build a ConvNet.
- layer_collection = lc.LayerCollection()
- loss, accuracy = build_model(
- features, labels, num_labels=10, layer_collection=layer_collection)
-
- # Train with K-FAC.
- global_step = tf.train.get_or_create_global_step()
- optimizer = opt.KfacOptimizer(
- learning_rate=tf.train.exponential_decay(
- 0.00002, global_step, 10000, 0.5, staircase=True),
- cov_ema_decay=0.95,
- damping=0.0001,
- layer_collection=layer_collection,
- momentum=0.99)
-
- (cov_update_thunks,
- inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
-
- def make_update_op(update_thunks):
- update_ops = [thunk() for thunk in update_thunks]
- return tf.group(*update_ops)
-
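-    # (Note: utils.batch_execute runs a rotating batch of `batch_size` of the
-    # given thunks per step, cycling through them as global_step advances.)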
- def make_batch_executed_op(update_thunks, batch_size=1):
- return tf.group(*tf.contrib.kfac.utils.batch_execute(
- global_step, update_thunks, batch_size=batch_size))
-
-    # Run cov_update_op every step. Run one inv update op per step.
- cov_update_op = make_update_op(cov_update_thunks)
- with tf.control_dependencies([cov_update_op]):
- # But make sure to execute all the inverse ops on the first step
- inverse_op = tf.cond(tf.equal(global_step, 0),
- lambda: make_update_op(inv_update_thunks),
- lambda: make_batch_executed_op(inv_update_thunks))
- with tf.control_dependencies([inverse_op]):
- train_op = optimizer.minimize(loss, global_step=global_step)
-
- # Print metrics every 5 sec.
- hooks = [
- tf.train.LoggingTensorHook(
- {
- "loss": loss,
- "accuracy": accuracy
- }, every_n_secs=5),
- ]
- return tf.estimator.EstimatorSpec(
- mode=mode, loss=loss, train_op=train_op, training_hooks=hooks)
-
- run_config = tf.estimator.RunConfig(
- model_dir="/tmp/mnist", save_checkpoints_steps=1, keep_checkpoint_max=100)
-
-  # Train with Estimator until input_fn() is exhausted. This is a prerequisite
-  # for TPU compatibility.
- estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
- estimator.train(input_fn=input_fn)
diff --git a/tensorflow/contrib/kfac/examples/mlp_mnist_main.py b/tensorflow/contrib/kfac/examples/mlp_mnist_main.py
deleted file mode 100644
index 9c34ade1d2..0000000000
--- a/tensorflow/contrib/kfac/examples/mlp_mnist_main.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train an MLP on MNIST using K-FAC.
-
-See mlp.py for details.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import argparse
-import sys
-
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import mlp
-
-FLAGS = None
-
-
-def main(argv):
- _ = argv
- if FLAGS.use_estimator:
- if FLAGS.num_towers != 1:
- raise ValueError("Only 1 device supported in tf.estimator example.")
- mlp.train_mnist_estimator(FLAGS.data_dir, num_epochs=200)
- elif FLAGS.num_towers > 1:
- mlp.train_mnist_multitower(
- FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers)
- else:
- mlp.train_mnist(FLAGS.data_dir, num_epochs=200)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--data_dir",
- type=str,
- default="/tmp/mnist",
- help="Directory to store dataset in.")
- parser.add_argument(
- "--num_towers",
- type=int,
- default=1,
- help="Number of CPUs to split minibatch across.")
- parser.add_argument(
- "--use_estimator",
- action="store_true",
- help="Use tf.estimator API to train.")
- FLAGS, unparsed = parser.parse_known_args()
- tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/kfac/examples/mnist.py b/tensorflow/contrib/kfac/examples/mnist.py
deleted file mode 100644
index 547c4ab25d..0000000000
--- a/tensorflow/contrib/kfac/examples/mnist.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Utilities for loading MNIST into TensorFlow."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-
-__all__ = [
- 'load_mnist',
-]
-
-
-def load_mnist(data_dir,
- num_epochs,
- batch_size,
- flatten_images=True,
- use_fake_data=False):
- """Loads MNIST dataset into memory.
-
- Args:
- data_dir: string. Directory to read MNIST examples from.
- num_epochs: int. Number of passes to make over the dataset.
- batch_size: int. Number of examples per minibatch.
- flatten_images: bool. If True, [28, 28, 1]-shaped images are flattened into
- [784]-shaped vectors.
- use_fake_data: bool. If True, generate a synthetic dataset rather than
- reading MNIST in.
-
- Returns:
- examples: Tensor of shape [batch_size, 784] if 'flatten_images' is
- True, else [batch_size, 28, 28, 1]. Each row is one example.
- Values in [0, 1].
-    labels: Tensor of shape [batch_size]. Integer class index corresponding
-      to each example. Values in {0...9}.
- """
- if use_fake_data:
- rng = np.random.RandomState(42)
- num_examples = batch_size * 4
- images = rng.rand(num_examples, 28 * 28)
- if not flatten_images:
- images = np.reshape(images, [num_examples, 28, 28, 1])
- labels = rng.randint(10, size=num_examples)
- else:
- mnist_data = tf.contrib.learn.datasets.mnist.read_data_sets(
- data_dir, reshape=flatten_images)
- num_examples = len(mnist_data.train.labels)
- images = mnist_data.train.images
- labels = mnist_data.train.labels
-
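-  # Build the input pipeline: repeat for num_epochs, shuffle across the whole
-  # dataset, batch, and expose batches through a one-shot iterator.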
- dataset = tf.data.Dataset.from_tensor_slices((np.asarray(
- images, dtype=np.float32), np.asarray(labels, dtype=np.int64)))
- return (dataset.repeat(num_epochs).shuffle(num_examples).batch(batch_size)
- .make_one_shot_iterator().get_next())
diff --git a/tensorflow/contrib/kfac/examples/tests/BUILD b/tensorflow/contrib/kfac/examples/tests/BUILD
deleted file mode 100644
index ede7f183fe..0000000000
--- a/tensorflow/contrib/kfac/examples/tests/BUILD
+++ /dev/null
@@ -1,52 +0,0 @@
-package(default_visibility = ["//visibility:private"])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-load("//tensorflow:tensorflow.bzl", "py_test")
-
-py_test(
- name = "mlp_test",
- size = "large",
- srcs = ["mlp_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- "notsan",
- ],
- deps = [
- "//tensorflow:tensorflow_py",
- "//tensorflow/contrib/kfac/examples:mlp",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "convnet_test",
- size = "large",
- srcs = ["convnet_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- "notsan",
- ],
- deps = [
- "//tensorflow:tensorflow_py",
- "//tensorflow/contrib/kfac",
- "//tensorflow/contrib/kfac/examples:convnet",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "mnist_test",
- srcs = ["mnist_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow:tensorflow_py",
- "//tensorflow/contrib/kfac/examples:mnist",
- "//third_party/py/numpy",
- ],
-)
diff --git a/tensorflow/contrib/kfac/examples/tests/convnet_test.py b/tensorflow/contrib/kfac/examples/tests/convnet_test.py
deleted file mode 100644
index adecda7166..0000000000
--- a/tensorflow/contrib/kfac/examples/tests/convnet_test.py
+++ /dev/null
@@ -1,166 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for convnet.py."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-
-from tensorflow.contrib.kfac import layer_collection as lc
-from tensorflow.contrib.kfac.examples import convnet
-
-
-class ConvNetTest(tf.test.TestCase):
-
- def testConvLayer(self):
- with tf.Graph().as_default():
- pre, act, (w, b) = convnet.conv_layer(
- layer_id=1,
- inputs=tf.zeros([5, 3, 3, 2]),
- kernel_size=3,
- out_channels=5)
- self.assertShapeEqual(np.zeros([5, 3, 3, 5]), pre)
- self.assertShapeEqual(np.zeros([5, 3, 3, 5]), act)
- self.assertShapeEqual(np.zeros([3, 3, 2, 5]), tf.convert_to_tensor(w))
- self.assertShapeEqual(np.zeros([5]), tf.convert_to_tensor(b))
- self.assertIsInstance(w, tf.Variable)
- self.assertIsInstance(b, tf.Variable)
- self.assertIn("conv_1", w.op.name)
- self.assertIn("conv_1", b.op.name)
-
- def testMaxPoolLayer(self):
- with tf.Graph().as_default():
- act = convnet.max_pool_layer(
- layer_id=1, inputs=tf.zeros([5, 6, 6, 2]), kernel_size=5, stride=3)
- self.assertShapeEqual(np.zeros([5, 2, 2, 2]), act)
- self.assertEqual(act.op.name, "pool_1/pool")
-
- def testLinearLayer(self):
- with tf.Graph().as_default():
- act, (w, b) = convnet.linear_layer(
- layer_id=1, inputs=tf.zeros([5, 20]), output_size=5)
- self.assertShapeEqual(np.zeros([5, 5]), act)
- self.assertShapeEqual(np.zeros([20, 5]), tf.convert_to_tensor(w))
- self.assertShapeEqual(np.zeros([5]), tf.convert_to_tensor(b))
- self.assertIsInstance(w, tf.Variable)
- self.assertIsInstance(b, tf.Variable)
- self.assertIn("fc_1", w.op.name)
- self.assertIn("fc_1", b.op.name)
-
- def testBuildModel(self):
- with tf.Graph().as_default():
- x = tf.placeholder(tf.float32, [None, 6, 6, 3])
- y = tf.placeholder(tf.int64, [None])
- layer_collection = lc.LayerCollection()
- loss, accuracy = convnet.build_model(
- x, y, num_labels=5, layer_collection=layer_collection)
-
- # Ensure layers and logits were registered.
- self.assertEqual(len(layer_collection.fisher_blocks), 3)
- self.assertEqual(len(layer_collection.losses), 1)
-
- # Ensure inference doesn't crash.
- with self.test_session() as sess:
- sess.run(tf.global_variables_initializer())
- feed_dict = {
- x: np.random.randn(10, 6, 6, 3).astype(np.float32),
- y: np.random.randint(5, size=10).astype(np.int64),
- }
- sess.run([loss, accuracy], feed_dict=feed_dict)
-
- def _build_toy_problem(self):
- """Construct a toy linear regression problem.
-
-    Initial loss should be
-      2.5 = 0.5 * (1^2 + 2^2).
-
- Returns:
- loss: 0-D Tensor representing loss to be minimized.
-      accuracy: 0-D Tensor representing model accuracy.
- layer_collection: LayerCollection instance describing model architecture.
- """
- x = np.asarray([[1.], [2.]]).astype(np.float32)
- y = np.asarray([1., 2.]).astype(np.float32)
- x, y = (tf.data.Dataset.from_tensor_slices((x, y))
- .repeat(100).batch(2).make_one_shot_iterator().get_next())
- w = tf.get_variable("w", shape=[1, 1], initializer=tf.zeros_initializer())
- y_hat = tf.matmul(x, w)
- loss = tf.reduce_mean(0.5 * tf.square(y_hat - y))
- accuracy = loss
-
- layer_collection = lc.LayerCollection()
- layer_collection.register_fully_connected(params=w, inputs=x, outputs=y_hat)
- layer_collection.register_normal_predictive_distribution(y_hat)
-
- return loss, accuracy, layer_collection
-
- def testMinimizeLossSingleMachine(self):
- with tf.Graph().as_default():
- loss, accuracy, layer_collection = self._build_toy_problem()
- accuracy_ = convnet.minimize_loss_single_machine(
- loss, accuracy, layer_collection, device="/cpu:0")
- self.assertLess(accuracy_, 2.0)
-
- def testMinimizeLossDistributed(self):
- with tf.Graph().as_default():
- loss, accuracy, layer_collection = self._build_toy_problem()
- accuracy_ = convnet.distributed_grads_only_and_ops_chief_worker(
- task_id=0,
- is_chief=True,
- num_worker_tasks=1,
- num_ps_tasks=0,
- master="",
- checkpoint_dir=None,
- loss=loss,
- accuracy=accuracy,
- layer_collection=layer_collection)
- self.assertLess(accuracy_, 2.0)
-
- def testTrainMnistSingleMachine(self):
- with tf.Graph().as_default():
- # Ensure model training doesn't crash.
- #
- # Ideally, we should check that accuracy increases as the model converges,
- # but there are too few parameters for the model to effectively memorize
- # the training set the way an MLP can.
- convnet.train_mnist_single_machine(
- data_dir=None, num_epochs=1, use_fake_data=True, device="/cpu:0")
-
- def testTrainMnistMultitower(self):
- with tf.Graph().as_default():
- # Ensure model training doesn't crash.
- convnet.train_mnist_multitower(
- data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True)
-
- def testTrainMnistDistributed(self):
- with tf.Graph().as_default():
- # Ensure model training doesn't crash.
- convnet.train_mnist_distributed_sync_replicas(
- task_id=0,
- is_chief=True,
- num_worker_tasks=1,
- num_ps_tasks=0,
- master="",
- data_dir=None,
- num_epochs=2,
- op_strategy="chief_worker",
- use_fake_data=True)
-
-
-if __name__ == "__main__":
- tf.test.main()
diff --git a/tensorflow/contrib/kfac/examples/tests/mlp_test.py b/tensorflow/contrib/kfac/examples/tests/mlp_test.py
deleted file mode 100644
index 22da6c29f1..0000000000
--- a/tensorflow/contrib/kfac/examples/tests/mlp_test.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for mlp.py."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import mlp
-
-
-class MlpTest(tf.test.TestCase):
-
- def testFcLayer(self):
- with tf.Graph().as_default():
- pre, act, (w, b) = mlp.fc_layer(
- layer_id=1, inputs=tf.zeros([5, 3]), output_size=10)
- self.assertShapeEqual(np.zeros([5, 10]), pre)
- self.assertShapeEqual(np.zeros([5, 10]), act)
- self.assertShapeEqual(np.zeros([3, 10]), tf.convert_to_tensor(w))
- self.assertShapeEqual(np.zeros([10]), tf.convert_to_tensor(b))
- self.assertIsInstance(w, tf.Variable)
- self.assertIsInstance(b, tf.Variable)
- self.assertIn("fc_1/", w.op.name)
- self.assertIn("fc_1/", b.op.name)
-
- def testTrainMnist(self):
- with tf.Graph().as_default():
- # Ensure model training doesn't crash.
- #
- # Ideally, we should check that accuracy increases as the model converges,
- # but that takes a non-trivial amount of compute.
- mlp.train_mnist(data_dir=None, num_epochs=1, use_fake_data=True)
-
- def testTrainMnistMultitower(self):
- with tf.Graph().as_default():
- # Ensure model training doesn't crash.
- mlp.train_mnist_multitower(
- data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True)
-
- def testTrainMnistEstimator(self):
- with tf.Graph().as_default():
- # Ensure model training doesn't crash.
- mlp.train_mnist_estimator(data_dir=None, num_epochs=1, use_fake_data=True)
-
-
-if __name__ == "__main__":
- tf.test.main()
diff --git a/tensorflow/contrib/kfac/examples/tests/mnist_test.py b/tensorflow/contrib/kfac/examples/tests/mnist_test.py
deleted file mode 100644
index 92f8462357..0000000000
--- a/tensorflow/contrib/kfac/examples/tests/mnist_test.py
+++ /dev/null
@@ -1,72 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for mnist.py."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import mnist
-
-
-class MnistTest(tf.test.TestCase):
-
- def testValues(self):
- """Ensure values are in their expected range."""
- with tf.Graph().as_default():
- examples, labels = mnist.load_mnist(
- data_dir=None, num_epochs=1, batch_size=64, use_fake_data=True)
-
- with self.test_session() as sess:
- examples_, labels_ = sess.run([examples, labels])
- self.assertTrue(np.all((0 <= examples_) & (examples_ < 1)))
- self.assertTrue(np.all((0 <= labels_) & (labels_ < 10)))
-
- def testFlattenedShapes(self):
- """Ensure images are flattened into their appropriate shape."""
- with tf.Graph().as_default():
- examples, labels = mnist.load_mnist(
- data_dir=None,
- num_epochs=1,
- batch_size=64,
- flatten_images=True,
- use_fake_data=True)
-
- with self.test_session() as sess:
- examples_, labels_ = sess.run([examples, labels])
- self.assertEqual(examples_.shape, (64, 784))
- self.assertEqual(labels_.shape, (64,))
-
- def testNotFlattenedShapes(self):
- """Ensure non-flattened images are their appropriate shape."""
- with tf.Graph().as_default():
- examples, labels = mnist.load_mnist(
- data_dir=None,
- num_epochs=1,
- batch_size=64,
- flatten_images=False,
- use_fake_data=True)
-
- with self.test_session() as sess:
- examples_, labels_ = sess.run([examples, labels])
- self.assertEqual(examples_.shape, (64, 28, 28, 1))
- self.assertEqual(labels_.shape, (64,))
-
-
-if __name__ == '__main__':
- tf.test.main()
diff --git a/tensorflow/contrib/kfac/g3doc/autoencoder.png b/tensorflow/contrib/kfac/g3doc/autoencoder.png
deleted file mode 100644
index 20f93c7703..0000000000
--- a/tensorflow/contrib/kfac/g3doc/autoencoder.png
+++ /dev/null
Binary files differ
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
deleted file mode 100644
index 6e4a8d71ba..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD
+++ /dev/null
@@ -1,160 +0,0 @@
-package(default_visibility = ["//visibility:private"])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-load("//tensorflow:tensorflow.bzl", "py_test")
-
-py_test(
- name = "estimator_test",
- srcs = ["estimator_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:fisher_estimator",
- "//tensorflow/contrib/kfac/python/ops:layer_collection",
- "//tensorflow/contrib/kfac/python/ops:utils",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:linalg_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "fisher_factors_test",
- srcs = ["fisher_factors_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:fisher_blocks",
- "//tensorflow/contrib/kfac/python/ops:fisher_factors",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:gradients",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_seed",
- "//tensorflow/python:variables",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "fisher_blocks_test",
- srcs = ["fisher_blocks_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:fisher_blocks",
- "//tensorflow/contrib/kfac/python/ops:layer_collection",
- "//tensorflow/contrib/kfac/python/ops:linear_operator",
- "//tensorflow/contrib/kfac/python/ops:utils",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:random_seed",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:variables",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "layer_collection_test",
- srcs = ["layer_collection_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:fisher_blocks",
- "//tensorflow/contrib/kfac/python/ops:fisher_factors",
- "//tensorflow/contrib/kfac/python/ops:layer_collection",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:linalg_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:random_seed",
- "//tensorflow/python:variable_scope",
- ],
-)
-
-py_test(
- name = "optimizer_test",
- srcs = ["optimizer_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:fisher_factors",
- "//tensorflow/contrib/kfac/python/ops:kfac_optimizer",
- "//tensorflow/contrib/kfac/python/ops:layer_collection",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:nn",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "utils_test",
- srcs = ["utils_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_windows"], # TODO: needs investigation on Windows
- deps = [
- "//tensorflow/contrib/kfac/python/ops:utils",
- "//tensorflow/contrib/tpu",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:linalg_ops",
- "//tensorflow/python:random_seed",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "op_queue_test",
- srcs = ["op_queue_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:op_queue",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- ],
-)
-
-py_test(
- name = "loss_functions_test",
- srcs = ["loss_functions_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:loss_functions",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:random_ops",
- "//third_party/py/numpy",
- ],
-)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
deleted file mode 100644
index 0e65d419a3..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
+++ /dev/null
@@ -1,310 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.estimator."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.kfac.python.ops import estimator
-from tensorflow.contrib.kfac.python.ops import layer_collection as lc
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import test
-from tensorflow.python.training import training_util
-
-_ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"]
-
-
-class EstimatorTest(test.TestCase):
-
- def setUp(self):
- self._graph = ops.Graph()
- with self._graph.as_default():
- self.layer_collection = lc.LayerCollection()
-
- self.inputs = random_ops.random_normal((2, 2), dtype=dtypes.float32)
- self.weights = variable_scope.get_variable(
- "w", shape=(2, 2), dtype=dtypes.float32)
- self.bias = variable_scope.get_variable(
- "b", initializer=init_ops.zeros_initializer(), shape=(2, 1))
- self.output = math_ops.matmul(self.inputs, self.weights) + self.bias
-
- # Only register the weights.
- self.layer_collection.register_fully_connected(
- params=(self.weights,), inputs=self.inputs, outputs=self.output)
-
- self.outputs = math_ops.tanh(self.output)
- self.targets = array_ops.zeros_like(self.outputs)
- self.layer_collection.register_categorical_predictive_distribution(
- logits=self.outputs, targets=self.targets)
-
- def testEstimatorInitManualRegistration(self):
- with self._graph.as_default():
- # We should be able to build an estimator for only the registered vars.
- estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection
- )
-
- # Check that we throw an error if we try to build an estimator for vars
- # that were not manually registered.
- with self.assertRaises(ValueError):
- est = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights, self.bias],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection
- )
- est.make_vars_and_create_op_thunks()
-
- # Check that we throw an error if we don't include registered variables,
- # i.e. self.weights
- with self.assertRaises(ValueError):
- est = estimator.FisherEstimatorRoundRobin(
- variables=[],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection)
- est.make_vars_and_create_op_thunks()
-
- @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42)
- def testVariableWrongNumberOfUses(self, mock_uses):
- with self.assertRaises(ValueError):
- est = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection)
- est.make_vars_and_create_op_thunks()
-
- def testInvalidEstimationMode(self):
- with self.assertRaises(ValueError):
- est = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection,
- estimation_mode="not_a_real_mode")
- est.make_vars_and_create_op_thunks()
-
- def testGradientsModeBuild(self):
- with self._graph.as_default():
- est = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection,
- estimation_mode="gradients")
- est.make_vars_and_create_op_thunks()
-
- def testEmpiricalModeBuild(self):
- with self._graph.as_default():
- est = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection,
- estimation_mode="empirical")
- est.make_vars_and_create_op_thunks()
-
- def testCurvaturePropModeBuild(self):
- with self._graph.as_default():
- est = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection,
- estimation_mode="curvature_prop")
- est.make_vars_and_create_op_thunks()
-
- def testExactModeBuild(self):
- with self._graph.as_default():
- est = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection,
- estimation_mode="exact")
- est.make_vars_and_create_op_thunks()
-
- def test_cov_update_thunks(self):
- """Ensures covariance update ops run once per global_step."""
- with self._graph.as_default(), self.test_session() as sess:
- fisher_estimator = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- layer_collection=self.layer_collection,
- damping=0.2,
- cov_ema_decay=0.0)
-
- # Construct an op that executes one covariance update per step.
- global_step = training_util.get_or_create_global_step()
- (cov_variable_thunks, cov_update_op_thunks, _,
- _) = fisher_estimator.create_ops_and_vars_thunks()
- for thunk in cov_variable_thunks:
- thunk()
- cov_matrices = [
- fisher_factor.get_cov()
- for fisher_factor in self.layer_collection.get_factors()
- ]
- cov_update_op = control_flow_ops.case(
- [(math_ops.equal(global_step, i), thunk)
- for i, thunk in enumerate(cov_update_op_thunks)])
- increment_global_step = global_step.assign_add(1)
-
- sess.run(variables.global_variables_initializer())
- initial_cov_values = sess.run(cov_matrices)
-
- # Ensure there's one update per covariance matrix.
- self.assertEqual(len(cov_matrices), len(cov_update_op_thunks))
-
-      # Test is a no-op if there is only 1 covariance matrix.
- assert len(cov_matrices) > 1
-
- for i in range(len(cov_matrices)):
- # Compare new and old covariance values
- new_cov_values = sess.run(cov_matrices)
- is_cov_equal = [
- np.allclose(initial_cov_value, new_cov_value)
- for (initial_cov_value,
- new_cov_value) in zip(initial_cov_values, new_cov_values)
- ]
- num_cov_equal = sum(is_cov_equal)
-
- # Ensure exactly one covariance matrix changes per step.
- self.assertEqual(num_cov_equal, len(cov_matrices) - i)
-
- # Run all covariance update ops.
- sess.run(cov_update_op)
- sess.run(increment_global_step)
-
- def test_round_robin_placement(self):
- """Check if the ops and variables are placed on devices correctly."""
- with self._graph.as_default():
- fisher_estimator = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- layer_collection=self.layer_collection,
- damping=0.2,
- cov_ema_decay=0.0,
- cov_devices=["/cpu:{}".format(i) for i in range(2)],
- inv_devices=["/cpu:{}".format(i) for i in range(2)])
-
- # Construct an op that executes one covariance update per step.
- (cov_update_thunks,
- inv_update_thunks) = fisher_estimator.make_vars_and_create_op_thunks(
- scope="test")
- cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)
- inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)
- self.assertEqual(cov_update_ops[0].device, "/device:CPU:0")
- self.assertEqual(cov_update_ops[1].device, "/device:CPU:1")
- self.assertEqual(inv_update_ops[0].device, "/device:CPU:0")
- self.assertEqual(inv_update_ops[1].device, "/device:CPU:1")
- cov_matrices = [
- fisher_factor.get_cov()
- for fisher_factor in self.layer_collection.get_factors()
- ]
- inv_matrices = [
- matrix
- for fisher_factor in self.layer_collection.get_factors()
- for matrix in fisher_factor._matpower_by_exp_and_damping.values()
- ]
- self.assertEqual(cov_matrices[0].device, "/device:CPU:0")
- self.assertEqual(cov_matrices[1].device, "/device:CPU:1")
-      # Inverse matrices need explicit placement; left unplaced, their device is empty.
- self.assertEqual(inv_matrices[0].device, "")
- self.assertEqual(inv_matrices[1].device, "")
-
- def test_inv_update_thunks(self):
- """Ensures inverse update ops run once per global_step."""
- with self._graph.as_default(), self.test_session() as sess:
- fisher_estimator = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- layer_collection=self.layer_collection,
- damping=0.2,
- cov_ema_decay=0.0)
-
- # Construct op that updates one inverse per global step.
- global_step = training_util.get_or_create_global_step()
- (cov_variable_thunks, _, inv_variable_thunks,
- inv_update_op_thunks) = fisher_estimator.create_ops_and_vars_thunks()
- for thunk in cov_variable_thunks:
- thunk()
- for thunk in inv_variable_thunks:
- thunk()
- inv_matrices = [
- matrix
- for fisher_factor in self.layer_collection.get_factors()
- for matrix in fisher_factor._matpower_by_exp_and_damping.values()
- ]
- inv_update_op = control_flow_ops.case(
- [(math_ops.equal(global_step, i), thunk)
- for i, thunk in enumerate(inv_update_op_thunks)])
- increment_global_step = global_step.assign_add(1)
-
- sess.run(variables.global_variables_initializer())
- initial_inv_values = sess.run(inv_matrices)
-
- # Ensure there's one update per inverse matrix. This is true as long as
- # there's no fan-in/fan-out or parameter re-use.
- self.assertEqual(len(inv_matrices), len(inv_update_op_thunks))
-
-      # Test is a no-op if there is only 1 inverse matrix.
- assert len(inv_matrices) > 1
-
- # Assign each covariance matrix a value other than the identity. This
- # ensures that the inverse matrices are updated to something different as
- # well.
- cov_matrices = [
- fisher_factor.get_cov()
- for fisher_factor in self.layer_collection.get_factors()
- ]
- sess.run([
- cov_matrix.assign(2 * linalg_ops.eye(int(cov_matrix.shape[0])))
- for cov_matrix in cov_matrices
- ])
-
- for i in range(len(inv_matrices)):
- # Compare new and old inverse values
- new_inv_values = sess.run(inv_matrices)
- is_inv_equal = [
- np.allclose(initial_inv_value, new_inv_value)
- for (initial_inv_value,
- new_inv_value) in zip(initial_inv_values, new_inv_values)
- ]
- num_inv_equal = sum(is_inv_equal)
-
- # Ensure exactly one inverse matrix changes per step.
- self.assertEqual(num_inv_equal, len(inv_matrices) - i)
-
- # Run all inverse update ops.
- sess.run(inv_update_op)
- sess.run(increment_global_step)
-
-
-if __name__ == "__main__":
- test.main()
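# The cov/inv update tests above lean on one dispatch pattern: tf.case keyed
# on global_step, so exactly one update thunk runs per session step. A
# self-contained sketch of just that pattern (TF 1.x; toy counters stand in
# for the real update thunks):
import tensorflow as tf

with tf.Graph().as_default():
    global_step = tf.train.get_or_create_global_step()
    counters = [tf.get_variable("c%d" % i, initializer=0.0) for i in range(3)]
    # Exactly one branch fires per run, selected by global_step modulo 3.
    update_one = tf.case(
        [(tf.equal(global_step % 3, i), lambda i=i: counters[i].assign_add(1.))
         for i in range(3)],
        exclusive=True)
    increment = global_step.assign_add(1)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(3):
            sess.run(update_one)
            sess.run(increment)
        print(sess.run(counters))  # [1.0, 1.0, 1.0]: one update per step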
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
deleted file mode 100644
index 86ec7a095a..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
+++ /dev/null
@@ -1,1018 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.fisher_blocks."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
-from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
-from tensorflow.contrib.kfac.python.ops import layer_collection as lc
-from tensorflow.contrib.kfac.python.ops import linear_operator as lo
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import random_seed
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variables as tf_variables
-from tensorflow.python.platform import test
-
-
-# We need to set these constants since the numerical values used in the tests
-# were chosen when these used to be the defaults.
-ff.set_global_constants(init_covariances_at_zero=False,
- zero_debias=False,
- init_inverses_at_zero=False)
-
-# TODO(b/78538100): As far as I can tell, all the tests that say "Make sure our
-# inverse is something other than the identity" are actually broken. They never
-# run the covariance update ops and so the inverse actually is the identity
-# (possibly plus the damping term, which would still make it a multiple of the
-# identity).
-
-
-def _make_psd(dim):
- """Constructs a PSD matrix of the given dimension."""
- mat = np.ones((dim, dim), dtype=np.float32)
- mat[np.arange(dim), np.arange(dim)] = 2. + np.arange(dim)
- return array_ops.constant(mat)
-
-
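# _make_psd() above returns the all-ones matrix plus a positive diagonal (the
# ones matrix is rank-one PSD), so the result is in fact positive definite.
# A quick numpy check of the same construction (a sketch, no TF):
import numpy as np

dim = 3
mat = np.ones((dim, dim), dtype=np.float32)
mat[np.arange(dim), np.arange(dim)] = 2. + np.arange(dim)
assert np.all(np.linalg.eigvalsh(mat) > 0)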
-class UtilsTest(test.TestCase):
-
- def testComputePiTracenorm(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- diag = ops.convert_to_tensor([1., 2., 0., 1.])
- left_factor = lo.LinearOperatorDiag(diag)
- right_factor = lo.LinearOperatorFullMatrix(array_ops.ones([2, 2]))
-
-      # pi is the sqrt of the ratio of the normalized left and right trace norms.
- pi = fb.compute_pi_tracenorm(left_factor, right_factor)
-
- pi_val = sess.run(pi)
- self.assertEqual(1., pi_val)
-
-
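# Worked arithmetic for testComputePiTracenorm above, assuming pi is the
# square root of the ratio of dimension-normalized traces (the form that is
# consistent with the asserted value of 1):
#   left:  trace(diag(1, 2, 0, 1)) / 4 = 4 / 4 = 1
#   right: trace(ones((2, 2)))     / 2 = 2 / 2 = 1
#   pi = sqrt(1 / 1) = 1
import numpy as np

left = np.diag([1., 2., 0., 1.])
right = np.ones((2, 2))
pi = np.sqrt((np.trace(left) / left.shape[0]) /
             (np.trace(right) / right.shape[0]))
assert pi == 1.0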
-class FullFBTest(test.TestCase):
-
- def testFullFBInitSingleTensor(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
-
- self.assertAllEqual(params, block.tensors_to_compute_grads())
-
- def testFullFBInitTensorTuple(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
-
- self.assertAllEqual(params, block.tensors_to_compute_grads())
-
- def testInstantiateFactors(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
-
- grads = (params[0]**2, math_ops.sqrt(params[1]))
- block.instantiate_factors(grads, 0.5)
-
- def testMultiplyInverseTuple(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
- grads = (params[0]**2, math_ops.sqrt(params[1]))
- block.instantiate_factors((grads,), 0.5)
- block._factor.instantiate_cov_variables()
- block.register_inverse()
- block._factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._factor.make_inverse_update_ops())
-
- vector = array_ops.ones(3,) * 2
- output = block.multiply_inverse(vector)
-
- self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
-
- def testMultiplyInverseNotTuple(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- params = array_ops.constant([[1.], [2.]])
- block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
- grads = params**2
- block.instantiate_factors((grads,), 0.5)
- block._factor.instantiate_cov_variables()
- block.register_inverse()
- block._factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._factor.make_inverse_update_ops())
-
- vector = array_ops.ones(2,) * 2
- output = block.multiply_inverse(vector)
-
- self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
-
- def testMultiplyInverseAgainstExplicit(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
- grads = (array_ops.constant([2., 3.]), array_ops.constant(4.))
- damping = 0.5
- block.instantiate_factors((grads,), damping)
- block._factor.instantiate_cov_variables()
- block.register_inverse()
- block._factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(state_ops.assign(block._factor._cov, _make_psd(3)))
- sess.run(block._factor.make_inverse_update_ops())
-
- v_flat = np.array([4., 5., 6.], dtype=np.float32)
- vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
- output = block.multiply_inverse(vector)
- output_flat = sess.run(utils.tensors_to_column(output)).ravel()
-
- full = sess.run(block.full_fisher_block())
- explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat)
-
- self.assertAllClose(output_flat, explicit)
-
-
-class NaiveDiagonalFBTest(test.TestCase):
-
- def testNaiveDiagonalFBInitSingleTensor(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
-
- self.assertAllEqual(params, block.tensors_to_compute_grads())
-
- def testNaiveDiagonalFBInitTensorTuple(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
-
- self.assertAllEqual(params, block.tensors_to_compute_grads())
-
- def testInstantiateFactors(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
-
- grads = (params[0]**2, math_ops.sqrt(params[1]))
- block.instantiate_factors(grads, 0.5)
-
- def testMultiplyInverseTuple(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
- grads = (params[0]**2, math_ops.sqrt(params[1]))
- block.instantiate_factors((grads,), 0.5)
- block._factor.instantiate_cov_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._factor.make_inverse_update_ops())
-
- vector = array_ops.ones(3,) * 2
- output = block.multiply_inverse(vector)
-
- self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
-
- def testMultiplyInverseNotTuple(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- params = array_ops.constant([[1.], [2.]])
- block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
- grads = params**2
- block.instantiate_factors((grads,), 0.5)
- block._factor.instantiate_cov_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._factor.make_inverse_update_ops())
- vector = array_ops.ones(2,) * 2
- output = block.multiply_inverse(vector)
-
- self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
-
- def testMultiplyInverseAgainstExplicit(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
- grads = (params[0]**2, math_ops.sqrt(params[1]))
- damping = 0.5
- block.instantiate_factors((grads,), damping)
- block._factor.instantiate_cov_variables()
-
- cov = array_ops.reshape(array_ops.constant([2., 3., 4.]), [-1, 1])
- sess.run(state_ops.assign(block._factor._cov, cov))
- sess.run(block._factor.make_inverse_update_ops())
-
- v_flat = np.array([4., 5., 6.], dtype=np.float32)
- vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
- output = block.multiply_inverse(vector)
- output_flat = sess.run(utils.tensors_to_column(output)).ravel()
-
- full = sess.run(block.full_fisher_block())
- explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat)
- self.assertAllClose(output_flat, explicit)
-
-
-class FullyConnectedDiagonalFBTest(test.TestCase):
-
- def setUp(self):
- super(FullyConnectedDiagonalFBTest, self).setUp()
-
- self.batch_size = 4
- self.input_size = 6
- self.output_size = 3
-
- self.inputs = np.random.randn(self.batch_size, self.input_size).astype(
- np.float32)
- self.outputs = np.zeros([self.batch_size, self.output_size]).astype(
- np.float32)
- self.output_grads = np.random.randn(self.batch_size,
- self.output_size).astype(np.float32)
- self.w = np.random.randn(self.input_size, self.output_size).astype(
- np.float32)
- self.b = np.random.randn(self.output_size).astype(np.float32)
-
- def fisherApprox(self, has_bias=False):
- """Fisher approximation using default inputs."""
- if has_bias:
- inputs = np.concatenate(
- [self.inputs, np.ones([self.batch_size, 1])], axis=1)
- else:
- inputs = self.inputs
- return self.buildDiagonalFisherApproximation(inputs, self.output_grads)
-
- def buildDiagonalFisherApproximation(self, inputs, output_grads):
- """Builds explicit diagonal Fisher approximation.
-
-    Fisher's diagonal is the elementwise square of (d loss / d w), where
-      d loss / d w = E[outer(input, output_grad)]
-
- where the expectation is taken over examples.
-
- Args:
- inputs: np.array of shape [batch_size, input_size].
- output_grads: np.array of shape [batch_size, output_size].
-
- Returns:
- Diagonal np.array of shape [num_params, num_params] for num_params =
- input_size * output_size.
- """
- batch_size = inputs.shape[0]
- assert output_grads.shape[0] == batch_size
- input_size = inputs.shape[1]
- output_size = output_grads.shape[1]
- fisher_diag = np.zeros((input_size, output_size))
- for i in range(batch_size):
- fisher_diag += np.square(np.outer(inputs[i], output_grads[i]))
- return np.diag(fisher_diag.flatten()) / batch_size
-
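# The diagonal approximation built above can be reproduced standalone; a
# small numpy sketch of the same computation, independent of the TF code:
import numpy as np

rng = np.random.RandomState(0)
batch_size, input_size, output_size = 4, 6, 3
x = rng.randn(batch_size, input_size)   # layer inputs
g = rng.randn(batch_size, output_size)  # gradients w.r.t. layer outputs

fisher_diag = np.zeros((input_size, output_size))
for i in range(batch_size):
    # Square the per-example outer product, then average over the batch.
    fisher_diag += np.square(np.outer(x[i], g[i]))
fisher = np.diag(fisher_diag.flatten()) / batch_size
assert fisher.shape == (input_size * output_size, input_size * output_size)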
- def testMultiply(self):
- result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
- [self.output_grads])
-
- # Construct Fisher-vector product.
- expected_result = self.fisherApprox().dot(self.w.flatten())
- expected_result = expected_result.reshape(
- [self.input_size, self.output_size])
-
- self.assertAllClose(expected_result, result)
-
- def testMultiplyInverse(self):
- _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
- [self.output_grads])
-
- # Construct inverse Fisher-vector product.
- expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten())
- expected_result = expected_result.reshape(
- [self.input_size, self.output_size])
-
- self.assertAllClose(expected_result, result)
-
- def testRegisterAdditionalTower(self):
- """Ensure 1 big tower and 2 small towers are equivalent."""
- multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps(
- self.w, [self.inputs], [self.outputs], [self.output_grads])
- multiply_result_small, multiply_inverse_result_small = (
- self.runFisherBlockOps(self.w, np.split(self.inputs, 2),
- np.split(self.outputs, 2),
- np.split(self.output_grads, 2)))
-
- self.assertAllClose(multiply_result_big, multiply_result_small)
- self.assertAllClose(multiply_inverse_result_big,
- multiply_inverse_result_small)
-
- def testMultiplyHasBias(self):
- result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs],
- [self.outputs], [self.output_grads])
- expected_result = self.fisherApprox(True).dot(
- np.concatenate([self.w.flatten(), self.b.flatten()]))
- expected_result = expected_result.reshape(
- [self.input_size + 1, self.output_size])
- expected_result = (expected_result[:-1], expected_result[-1])
-
- self.assertEqual(len(result), 2)
- self.assertAllClose(expected_result[0], result[0])
- self.assertAllClose(expected_result[1], result[1])
-
- def runFisherBlockOps(self, params, inputs, outputs, output_grads):
- """Run Ops guaranteed by FisherBlock interface.
-
- Args:
- params: Tensor or 2-tuple of Tensors. Represents weights or weights and
- bias of this layer.
- inputs: list of Tensors of shape [batch_size, input_size]. Inputs to
- layer.
- outputs: list of Tensors of shape [batch_size, output_size].
- Preactivations produced by layer.
- output_grads: list of Tensors of shape [batch_size, output_size].
- Gradient of loss with respect to 'outputs'.
-
- Returns:
- multiply_result: Result of FisherBlock.multiply(params)
- multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
- """
- with ops.Graph().as_default(), self.test_session() as sess:
- inputs = as_tensors(inputs)
- outputs = as_tensors(outputs)
- output_grads = as_tensors(output_grads)
- params = as_tensors(params)
-
- block = fb.FullyConnectedDiagonalFB(
- lc.LayerCollection(), has_bias=isinstance(params, (tuple, list)))
- for (i, o) in zip(inputs, outputs):
- block.register_additional_tower(i, o)
-
- block.instantiate_factors((output_grads,), damping=0.0)
- block._factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._factor.make_covariance_update_op(0.0))
- multiply_result = sess.run(block.multiply(params))
- multiply_inverse_result = sess.run(block.multiply_inverse(params))
-
- return multiply_result, multiply_inverse_result
-
-
-class EmbeddingKFACFBTest(test.TestCase):
-
- def testInstantiateFactors(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
-
- # Create a Fisher Block.
- vocab_size = 5
- block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)
-
- # Add some examples.
- inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
- outputs = array_ops.constant([[0.], [1.], [2.]])
- block.register_additional_tower(inputs, outputs)
-
- # Instantiate factor's variables. Ensure it doesn't fail.
- grads = outputs**2.
- damping = array_ops.constant(0.)
- block.instantiate_factors(((grads,),), damping)
-
- def testMultiplyInverse(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
-
- # Create a Fisher Block.
- vocab_size = 5
- block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)
-
- # Add some examples.
- inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
- outputs = array_ops.constant([[0.], [1.], [2.]])
- block.register_additional_tower(inputs, outputs)
-
- # Instantiate factor's variables. Ensure it doesn't fail.
- grads = outputs**2.
- damping = array_ops.constant(0.)
- block.instantiate_factors(((grads,),), damping)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- # Create a sparse update.
- indices = array_ops.constant([1, 3, 4])
- values = array_ops.constant([[1.], [1.], [1.]])
- sparse_vector = ops.IndexedSlices(
- values, indices, dense_shape=[vocab_size, 1])
- dense_vector = array_ops.reshape([0., 1., 0., 1., 1.], [vocab_size, 1])
-
- # Compare Fisher-vector product against explicit result.
- result = block.multiply_inverse(sparse_vector)
- expected_result = linalg_ops.matrix_solve(block.full_fisher_block(),
- dense_vector)
-
- sess.run(tf_variables.global_variables_initializer())
- self.assertAlmostEqual(
- sess.run(expected_result[1]), sess.run(result.values[0]))
- self.assertAlmostEqual(
- sess.run(expected_result[3]), sess.run(result.values[1]))
- self.assertAlmostEqual(
- sess.run(expected_result[4]), sess.run(result.values[2]))
-
-
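# The sparse/dense pair in testMultiplyInverse above: an IndexedSlices with
# indices [1, 3, 4] and unit values is the sparse form of the dense vector
# obtained by scattering those values into a zero vector (numpy sketch):
import numpy as np

vocab_size = 5
indices = np.array([1, 3, 4])
values = np.ones((3, 1))
dense = np.zeros((vocab_size, 1))
dense[indices] = values
assert np.array_equal(dense.ravel(), [0., 1., 0., 1., 1.])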
-class FullyConnectedKFACBasicFBTest(test.TestCase):
-
- def testFullyConnectedKFACBasicFBInit(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([1., 2.])
- outputs = array_ops.constant([3., 4.])
- block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection())
- block.register_additional_tower(inputs, outputs)
-
- self.assertAllEqual([outputs], block.tensors_to_compute_grads())
-
- def testInstantiateFactorsHasBias(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([[1., 2.], [3., 4.]])
- outputs = array_ops.constant([[3., 4.], [5., 6.]])
- block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True)
- block.register_additional_tower(inputs, outputs)
-
- grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
-
- def testInstantiateFactorsNoBias(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([[1., 2.], [3., 4.]])
- outputs = array_ops.constant([[3., 4.], [5., 6.]])
- block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
- block.register_additional_tower(inputs, outputs)
-
- grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
-
- def testMultiplyInverseTuple(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]])
- outputs = array_ops.constant([[3., 4.], [5., 6.]])
- block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
- block.register_additional_tower(inputs, outputs)
- grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
-
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._input_factor.make_inverse_update_ops())
- sess.run(block._output_factor.make_inverse_update_ops())
-
- vector = (
- np.arange(2, 6).reshape(2, 2).astype(np.float32), #
- np.arange(1, 3).reshape(2, 1).astype(np.float32))
- output = block.multiply_inverse((array_ops.constant(vector[0]),
- array_ops.constant(vector[1])))
-
- output = sess.run(output)
- self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]],
- output[0])
- self.assertAllClose([0.343146, 0.686291], output[1])
-
- def testMultiplyInverseNotTuple(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([[1., 2.], [3., 4.]])
- outputs = array_ops.constant([[3., 4.], [5., 6.]])
- block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
- block.register_additional_tower(inputs, outputs)
- grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._input_factor.make_inverse_update_ops())
- sess.run(block._output_factor.make_inverse_update_ops())
-
- vector = np.arange(2, 6).reshape(2, 2).astype(np.float32)
- output = block.multiply_inverse(array_ops.constant(vector))
-
- self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]],
- sess.run(output))
-
- def testMultiplyInverseAgainstExplicit(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- input_dim, output_dim = 3, 2
- inputs = array_ops.zeros([32, input_dim])
- outputs = array_ops.zeros([32, output_dim])
- params = array_ops.zeros([input_dim, output_dim])
- block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
- block.register_additional_tower(inputs, outputs)
- grads = outputs**2
- damping = 0. # This test is only valid without damping.
- block.instantiate_factors(((grads,),), damping)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
-
- sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3)))
- sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
-
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- sess.run(block._input_factor.make_inverse_update_ops())
- sess.run(block._output_factor.make_inverse_update_ops())
-
- v_flat = np.arange(6, dtype=np.float32)
- vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
- output = block.multiply_inverse(vector)
- output_flat = sess.run(utils.tensors_to_column(output)).ravel()
-
- full = sess.run(block.full_fisher_block())
- explicit = np.dot(np.linalg.inv(full + damping * np.eye(6)), v_flat)
-
- self.assertAllClose(output_flat, explicit)
-
-
-class ConvDiagonalFBTest(test.TestCase):
-
- def setUp(self):
- super(ConvDiagonalFBTest, self).setUp()
-
- self.batch_size = 2
- self.height = 8
- self.width = 4
- self.input_channels = 6
- self.output_channels = 3
- self.kernel_size = 1
-
- self.inputs = np.random.randn(self.batch_size, self.height, self.width,
- self.input_channels).astype(np.float32)
- self.outputs = np.zeros(
- [self.batch_size, self.height, self.width,
- self.output_channels]).astype(np.float32)
- self.output_grads = np.random.randn(
- self.batch_size, self.height, self.width, self.output_channels).astype(
- np.float32)
- self.w = np.random.randn(self.kernel_size, self.kernel_size,
- self.input_channels, self.output_channels).astype(
- np.float32)
- self.b = np.random.randn(self.output_channels).astype(np.float32)
-
- def fisherApprox(self, has_bias=False):
- """Fisher approximation using default inputs."""
- if has_bias:
- inputs = np.concatenate(
- [self.inputs,
- np.ones([self.batch_size, self.height, self.width, 1])],
- axis=-1)
- else:
- inputs = self.inputs
- return self.buildDiagonalFisherApproximation(inputs, self.output_grads,
- self.kernel_size)
-
- def buildDiagonalFisherApproximation(self, inputs, output_grads, kernel_size):
- r"""Builds explicit diagonal Fisher approximation.
-
-    Fisher's diagonal is the elementwise square of (d loss / d w), where
-      d loss / d w = E[\sum_{loc} outer(input_{loc}, output_grad_{loc})]
-
- where the expectation is taken over examples and the sum over (x, y)
- locations upon which the convolution is applied.
-
- Args:
- inputs: np.array of shape [batch_size, height, width, input_channels].
- output_grads: np.array of shape [batch_size, height, width,
- output_channels].
-      kernel_size: int. Height and width of the kernel.
-
- Returns:
- Diagonal np.array of shape [num_params, num_params] for num_params =
- kernel_size^2 * input_channels * output_channels.
- """
- batch_size, height, width, input_channels = inputs.shape
- assert output_grads.shape[0] == batch_size
- assert output_grads.shape[1] == height
- assert output_grads.shape[2] == width
- output_channels = output_grads.shape[3]
-
- # If kernel_size == 1, then we don't need to worry about capturing context
- # around the pixel upon which a convolution is applied. This makes testing
- # easier.
- assert kernel_size == 1, "kernel_size != 1 isn't supported."
- num_locations = height * width
- inputs = np.reshape(inputs, [batch_size, num_locations, input_channels])
- output_grads = np.reshape(output_grads,
- [batch_size, num_locations, output_channels])
-
- fisher_diag = np.zeros((input_channels, output_channels))
- for i in range(batch_size):
- # Each example's approximation is a square(sum-of-outer-products).
- example_fisher_diag = np.zeros((input_channels, output_channels))
- for j in range(num_locations):
- example_fisher_diag += np.outer(inputs[i, j], output_grads[i, j])
- fisher_diag += np.square(example_fisher_diag)
-
- # Normalize by batch_size (not num_locations).
- return np.diag(fisher_diag.flatten()) / batch_size
-
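# Same construction as the fully connected case earlier, except the outer
# products are summed over spatial locations before squaring (numpy sketch
# mirroring the method above, for the kernel_size == 1 case it supports):
import numpy as np

rng = np.random.RandomState(0)
batch_size, num_locations, in_channels, out_channels = 2, 32, 6, 3
x = rng.randn(batch_size, num_locations, in_channels)
g = rng.randn(batch_size, num_locations, out_channels)

fisher_diag = np.zeros((in_channels, out_channels))
for i in range(batch_size):
    per_example = np.zeros((in_channels, out_channels))
    for j in range(num_locations):
        per_example += np.outer(x[i, j], g[i, j])
    fisher_diag += np.square(per_example)  # square after summing locations
# Normalized by batch_size only, not num_locations.
fisher = np.diag(fisher_diag.flatten()) / batch_size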
- def testMultiply(self):
- result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
- [self.output_grads])
-
- # Construct Fisher-vector product.
- expected_result = self.fisherApprox().dot(self.w.flatten())
- expected_result = expected_result.reshape([
- self.kernel_size, self.kernel_size, self.input_channels,
- self.output_channels
- ])
-
- self.assertAllClose(expected_result, result)
-
- def testMultiplyInverse(self):
- _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
- [self.output_grads])
-
- # Construct inverse Fisher-vector product.
- expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten())
- expected_result = expected_result.reshape([
- self.kernel_size, self.kernel_size, self.input_channels,
- self.output_channels
- ])
-
- self.assertAllClose(expected_result, result, atol=1e-3)
-
- def testRegisterAdditionalTower(self):
- """Ensure 1 big tower and 2 small towers are equivalent."""
- multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps(
- self.w, [self.inputs], [self.outputs], [self.output_grads])
- multiply_result_small, multiply_inverse_result_small = (
- self.runFisherBlockOps(self.w, np.split(self.inputs, 2),
- np.split(self.outputs, 2),
- np.split(self.output_grads, 2)))
-
- self.assertAllClose(multiply_result_big, multiply_result_small)
- self.assertAllClose(multiply_inverse_result_big,
- multiply_inverse_result_small)
-
- def testMultiplyHasBias(self):
- result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs],
- [self.outputs], [self.output_grads])
- # Clone 'b' along 'input_channels' dimension.
- b_filter = np.tile(
- np.reshape(self.b, [1, 1, 1, self.output_channels]),
- [self.kernel_size, self.kernel_size, 1, 1])
- params = np.concatenate([self.w, b_filter], axis=2)
- expected_result = self.fisherApprox(True).dot(params.flatten())
-
- # Extract 'b' from concatenated parameters.
- expected_result = expected_result.reshape([
- self.kernel_size, self.kernel_size, self.input_channels + 1,
- self.output_channels
- ])
- expected_result = (expected_result[:, :, 0:-1, :],
- np.reshape(expected_result[:, :, -1, :],
- [self.output_channels]))
-
- self.assertEqual(len(result), 2)
- self.assertAllClose(expected_result[0], result[0])
- self.assertAllClose(expected_result[1], result[1])
-
- def runFisherBlockOps(self, params, inputs, outputs, output_grads):
- """Run Ops guaranteed by FisherBlock interface.
-
- Args:
- params: Tensor or 2-tuple of Tensors. Represents weights or weights and
- bias of this layer.
- inputs: list of Tensors of shape [batch_size, input_size]. Inputs to
- layer.
- outputs: list of Tensors of shape [batch_size, output_size].
- Preactivations produced by layer.
- output_grads: list of Tensors of shape [batch_size, output_size].
- Gradient of loss with respect to 'outputs'.
-
- Returns:
- multiply_result: Result of FisherBlock.multiply(params)
- multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
- """
- with ops.Graph().as_default(), self.test_session() as sess:
- inputs = as_tensors(inputs)
- outputs = as_tensors(outputs)
- output_grads = as_tensors(output_grads)
- params = as_tensors(params)
-
- block = fb.ConvDiagonalFB(
- lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME')
- for (i, o) in zip(inputs, outputs):
- block.register_additional_tower(i, o)
-
- block.instantiate_factors((output_grads,), damping=0.0)
- block._factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._factor.make_covariance_update_op(0.0))
- multiply_result = sess.run(block.multiply(params))
- multiply_inverse_result = sess.run(block.multiply_inverse(params))
-
- return multiply_result, multiply_inverse_result
-
-
-class DepthwiseConvKFCBasicFBTest(test.TestCase):
-
- def testInstantiateFactors(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- params = random_ops.random_normal((3, 3, 8, 2))
- inputs = random_ops.random_normal((32, 5, 5, 8))
- outputs = random_ops.random_normal((32, 5, 5, 16))
- layer_collection = lc.LayerCollection()
- block = fb.DepthwiseConvKFCBasicFB(
- layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME')
- block.register_additional_tower(inputs, outputs)
- grads = outputs**2
- block.instantiate_factors(([grads],), 0.5)
-
- def testMultiplyInverse(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- params = random_ops.random_normal((3, 3, 8, 2))
- inputs = random_ops.random_normal((32, 5, 5, 8))
- outputs = random_ops.random_normal((32, 5, 5, 16))
- layer_collection = lc.LayerCollection()
- block = fb.DepthwiseConvKFCBasicFB(
- layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME')
- block.register_additional_tower(inputs, outputs)
- grads = outputs**2
- block.instantiate_factors(([grads],), 0.5)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- # Ensure inverse update op doesn't crash.
- sess.run(tf_variables.global_variables_initializer())
- sess.run([
- factor.make_inverse_update_ops()
- for factor in layer_collection.get_factors()
- ])
-
- # Ensure inverse-vector multiply doesn't crash.
- output = block.multiply_inverse(params)
- sess.run(output)
-
- # Ensure the output has the same shape as the params.
- self.assertAllEqual(output.shape, params.shape)
-
-
-class ConvKFCBasicFBTest(test.TestCase):
-
- def _testConvKFCBasicFBInitParams(self, params):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- if isinstance(params, (list, tuple)):
- params = [array_ops.constant(param) for param in params]
- else:
- params = array_ops.constant(params)
- inputs = random_ops.random_normal((2, 2, 2))
- outputs = random_ops.random_normal((2, 2, 2))
- block = fb.ConvKFCBasicFB(
- lc.LayerCollection(), params=params, padding='SAME')
- block.register_additional_tower(inputs, outputs)
-
- self.assertAllEqual([outputs], block.tensors_to_compute_grads())
-
- def testConvKFCBasicFBInitParamsParamsTuple(self):
- self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2]), np.ones([2])])
-
- def testConvKFCBasicFBInitParamsParamsSingle(self):
- self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2])])
-
- def testMultiplyInverseTuple(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- params = random_ops.random_normal((2, 2, 2, 2))
- inputs = random_ops.random_normal((2, 2, 2, 2))
- outputs = random_ops.random_normal((2, 2, 2, 2))
- block = fb.ConvKFCBasicFB(
- lc.LayerCollection(), params=params, padding='SAME')
- block.register_additional_tower(inputs, outputs)
- grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._input_factor.make_inverse_update_ops())
- sess.run(block._output_factor.make_inverse_update_ops())
-
- vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32),
- np.arange(2, 4).reshape(2, 1).astype(np.float32))
- output = block.multiply_inverse((array_ops.constant(vector[0]),
- array_ops.constant(vector[1])))
-
- output = sess.run(output)
- self.assertAllClose([0.136455, 0.27291], output[0][0])
- self.assertAllClose([0.27291, 0.409365], output[1])
-
- def testMultiplyInverseNotTuple(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- params = random_ops.random_normal((2, 2, 2, 2))
- inputs = random_ops.random_normal((2, 2, 2, 2))
- outputs = random_ops.random_normal((2, 2, 2, 2))
- block = fb.ConvKFCBasicFB(
- lc.LayerCollection(), params=params, padding='SAME')
- block.register_additional_tower(inputs, outputs)
- self.assertFalse(block._has_bias)
- grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._input_factor.make_inverse_update_ops())
- sess.run(block._output_factor.make_inverse_update_ops())
-
- vector = np.arange(1, 17).reshape(8, 2).astype(np.float32)
- output = block.multiply_inverse(array_ops.constant(vector))
-
- self.assertAllClose([0.136455, 0.27291], sess.run(output)[0])
-
- def testMultiplyInverseNotTupleWithBias(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- params = [random_ops.random_normal((2, 2, 2, 2))]
- inputs = random_ops.random_normal((2, 2, 2, 2))
- outputs = random_ops.random_normal((2, 2, 2, 2))
- block = fb.ConvKFCBasicFB(
- lc.LayerCollection(), params=params, padding='SAME')
- block.register_additional_tower(inputs, outputs)
- self.assertTrue(block._has_bias)
- grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._input_factor.make_inverse_update_ops())
- sess.run(block._output_factor.make_inverse_update_ops())
-
- vector = np.arange(1, 19).reshape(9, 2).astype(np.float32)
- output = block.multiply_inverse(array_ops.constant(vector))
-
- self.assertAllClose([0.136455, 0.27291], sess.run(output)[0])
-
- def testMultiplyInverseAgainstExplicit(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- params = array_ops.zeros((2, 2, 2, 2))
- inputs = array_ops.zeros((2, 2, 2, 2))
- outputs = array_ops.zeros((2, 2, 2, 2))
- block = fb.ConvKFCBasicFB(
- lc.LayerCollection(), params=params, padding='SAME')
- block.register_additional_tower(inputs, outputs)
- grads = outputs**2
- damping = 0. # This test is only valid without damping.
- block.instantiate_factors(((grads,),), damping)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8)))
- sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
- sess.run(block._input_factor.make_inverse_update_ops())
- sess.run(block._output_factor.make_inverse_update_ops())
-
- v_flat = np.arange(16, dtype=np.float32)
- vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
- output = block.multiply_inverse(vector)
- output_flat = sess.run(utils.tensors_to_column(output)).ravel()
-
- full = sess.run(block.full_fisher_block())
- explicit = np.dot(np.linalg.inv(full + damping * np.eye(16)), v_flat)
-
- self.assertAllClose(output_flat, explicit)
-
-
-class FullyConnectedSeriesFBTest(test.TestCase):
-
- def testFullyConnectedSeriesFBInit(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([1., 2.])
- outputs = array_ops.constant([3., 4.])
- block = fb.FullyConnectedSeriesFB(lc.LayerCollection())
- block.register_additional_tower([inputs], [outputs])
- self.assertAllEqual([[outputs]], block.tensors_to_compute_grads())
-
- def testInstantiateFactorsHasBias(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([[1., 2.], [3., 4.]])
- outputs = array_ops.constant([[3., 4.], [5., 6.]])
- block = fb.FullyConnectedSeriesFB(
- lc.LayerCollection(),
- has_bias=True)
- block.register_additional_tower([inputs], [outputs])
- grads = outputs**2
- block.instantiate_factors((((grads,),),), 0.5)
-
- def testInstantiateFactorsNoBias(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([[1., 2.], [3., 4.]])
- outputs = array_ops.constant([[3., 4.], [5., 6.]])
- block = fb.FullyConnectedSeriesFB(
- lc.LayerCollection(),
- has_bias=False)
- block.register_additional_tower([inputs], [outputs])
- grads = outputs**2
- block.instantiate_factors((((grads,),),), 0.5)
-
-
-def as_tensors(tensor_or_tuple):
- """Converts a potentially nested tuple of np.array to Tensors."""
- if isinstance(tensor_or_tuple, (tuple, list)):
- return tuple(as_tensors(t) for t in tensor_or_tuple)
- return ops.convert_to_tensor(tensor_or_tuple)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
deleted file mode 100644
index fad47cd02f..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
+++ /dev/null
@@ -1,955 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.fisher_factors."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import numpy.random as npr
-
-from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
-from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops as tf_ops
-from tensorflow.python.framework import random_seed
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import variables as tf_variables
-from tensorflow.python.platform import test
-
-
-# We need to set these constants since the numerical values used in the tests
-# were chosen when these used to be the defaults.
-ff.set_global_constants(init_covariances_at_zero=False,
- zero_debias=False,
- init_inverses_at_zero=False)
-
-
-def make_damping_func(damping):
- return fb._package_func(lambda: damping, damping)
-
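- # Illustrative note: fb._package_func keys the returned wrapper on its
- # second argument, which is why two funcs built with the same damping
- # value are interchangeable in testRegisterDampedInverse below. A
- # minimal sketch, assuming the wrapper simply forwards the call:
- #
- # damping_func = make_damping_func(0.1)
- # assert damping_func() == 0.1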
-
-class FisherFactorTestingDummy(ff.FisherFactor):
- """Dummy class to test the non-abstract methods on ff.FisherFactor."""
-
- @property
- def _var_scope(self):
- return 'dummy/a_b_c'
-
- @property
- def _cov_shape(self):
- raise NotImplementedError
-
- @property
- def _num_sources(self):
- return 1
-
- @property
- def _dtype(self):
- return dtypes.float32
-
- def _compute_new_cov(self):
- raise NotImplementedError
-
- def instantiate_covariance(self):
- pass
-
- def make_inverse_update_ops(self):
- return []
-
- def get_cov(self):
- raise NotImplementedError
-
- def instantiate_inv_variables(self):
- raise NotImplementedError
-
- def _num_towers(self):
- raise NotImplementedError
-
- def _get_data_device(self):
- raise NotImplementedError
-
- def register_matpower(self, exp, damping_func):
- raise NotImplementedError
-
- def register_cholesky(self, damping_func):
- raise NotImplementedError
-
- def register_cholesky_inverse(self, damping_func):
- raise NotImplementedError
-
- def get_matpower(self, exp, damping_func):
- raise NotImplementedError
-
- def get_cholesky(self, damping_func):
- raise NotImplementedError
-
- def get_cholesky_inverse(self, damping_func):
- raise NotImplementedError
-
- def get_cov_as_linear_operator(self):
- raise NotImplementedError
-
-
-class DenseSquareMatrixFactorTestingDummy(ff.DenseSquareMatrixFactor):
- """Dummy class to test the non-abstract methods on ff.DenseSquareMatrixFactor.
- """
-
- def __init__(self, shape):
- self._shape = shape
- super(DenseSquareMatrixFactorTestingDummy, self).__init__()
-
- @property
- def _var_scope(self):
- return 'dummy/a_b_c'
-
- @property
- def _cov_shape(self):
- return self._shape
-
- @property
- def _num_sources(self):
- return 1
-
- @property
- def _dtype(self):
- return dtypes.float32
-
- def _compute_new_cov(self):
- raise NotImplementedError
-
- def instantiate_covariance(self):
- pass
-
- def _num_towers(self):
- raise NotImplementedError
-
- def _get_data_device(self):
- raise NotImplementedError
-
-
-class NumericalUtilsTest(test.TestCase):
-
- def testComputeCovAgainstNumpy(self):
- with tf_ops.Graph().as_default(), self.test_session() as sess:
- npr.seed(0)
- random_seed.set_random_seed(200)
-
- x = npr.randn(100, 3)
- cov = ff.compute_cov(array_ops.constant(x))
- np_cov = np.dot(x.T, x) / x.shape[0]
-
- self.assertAllClose(sess.run(cov), np_cov)
-
- def testComputeCovAgainstNumpyWithAlternativeNormalizer(self):
- with tf_ops.Graph().as_default(), self.test_session() as sess:
- npr.seed(0)
- random_seed.set_random_seed(200)
-
- normalizer = 10.
- x = npr.randn(100, 3)
- cov = ff.compute_cov(array_ops.constant(x), normalizer=normalizer)
- np_cov = np.dot(x.T, x) / normalizer
-
- self.assertAllClose(sess.run(cov), np_cov)
-
- def testAppendHomog(self):
- with tf_ops.Graph().as_default(), self.test_session() as sess:
- npr.seed(0)
-
- m, n = 3, 4
- a = npr.randn(m, n)
- a_homog = ff.append_homog(array_ops.constant(a))
- np_result = np.hstack([a, np.ones((m, 1))])
-
- self.assertAllClose(sess.run(a_homog), np_result)
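- # For reference, a short NumPy sketch (illustrative, not part of the
- # test) of why appending a homogeneous column supports bias handling:
- #
- # x = np.array([[1., 2.]])
- # w = np.array([[3.], [4.]])
- # b = np.array([5.])
- # xh = np.hstack([x, np.ones((1, 1))]) # what append_homog computes
- # wb = np.vstack([w, b[None, :]]) # bias stacked under the weights
- # assert np.allclose(xh.dot(wb), x.dot(w) + b)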
-
-
-class NameStringUtilFunctionTest(test.TestCase):
-
- def _make_tensor(self):
- x = array_ops.placeholder(dtypes.float64, (3, 1))
- w = array_ops.constant(npr.RandomState(0).randn(3, 3))
- y = math_ops.matmul(w, x)
- g = gradients_impl.gradients(y, x)[0]
- return g
-
- def testScopeStringFromParamsSingleTensor(self):
- with tf_ops.Graph().as_default():
- g = self._make_tensor()
- scope_string = ff.scope_string_from_params(g)
- self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string)
-
- def testScopeStringFromParamsMultipleTensors(self):
- with tf_ops.Graph().as_default():
- x = array_ops.constant(1,)
- y = array_ops.constant(2,)
- scope_string = ff.scope_string_from_params((x, y))
- self.assertEqual('Const_Const_1', scope_string)
-
- def testScopeStringFromParamsMultipleTypes(self):
- with tf_ops.Graph().as_default():
- x = array_ops.constant(1,)
- y = array_ops.constant(2,)
- scope_string = ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4,
- (x, y)])
- self.assertEqual('1-2-3_foo_True_4_Const__Const_1', scope_string)
-
- def testScopeStringFromParamsUnsupportedType(self):
- with tf_ops.Graph().as_default():
- x = array_ops.constant(1,)
- y = array_ops.constant(2,)
- unsupported = 1.2 # Floats are not supported.
- with self.assertRaises(ValueError):
- ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4, (x, y),
- unsupported])
-
- def testScopeStringFromName(self):
- with tf_ops.Graph().as_default():
- g = self._make_tensor()
- scope_string = ff.scope_string_from_name(g)
- self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string)
-
- def testScalarOrTensorToString(self):
- with tf_ops.Graph().as_default():
- self.assertEqual(ff.scalar_or_tensor_to_string(5.), repr(5.))
-
- g = self._make_tensor()
- scope_string = ff.scope_string_from_name(g)
- self.assertEqual(ff.scalar_or_tensor_to_string(g), scope_string)
-
-
-class FisherFactorTest(test.TestCase):
-
- def testMakeInverseUpdateOps(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- factor = FisherFactorTestingDummy()
-
- self.assertEqual(0, len(factor.make_inverse_update_ops()))
-
-
-class DenseSquareMatrixFactorTest(test.TestCase):
-
- def testRegisterDampedInverse(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- shape = [2, 2]
- factor = DenseSquareMatrixFactorTestingDummy(shape)
- factor_var_scope = 'dummy/a_b_c'
-
- damping_funcs = [make_damping_func(0.1),
- make_damping_func(0.1),
- make_damping_func(1e-5),
- make_damping_func(1e-5)]
- for damping_func in damping_funcs:
- factor.register_inverse(damping_func)
-
- factor.instantiate_inv_variables()
-
- inv = factor.get_inverse(damping_funcs[0]).to_dense()
- self.assertEqual(inv, factor.get_inverse(damping_funcs[1]).to_dense())
- self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]).to_dense())
- self.assertEqual(factor.get_inverse(damping_funcs[2]).to_dense(),
- factor.get_inverse(damping_funcs[3]).to_dense())
- factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
- factor_var_scope)
- factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
-
- self.assertEqual(set([inv,
- factor.get_inverse(damping_funcs[2]).to_dense()]),
- set(factor_tensors))
- self.assertEqual(shape, inv.get_shape())
-
- def testRegisterMatpower(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- shape = [3, 3]
- factor = DenseSquareMatrixFactorTestingDummy(shape)
- factor_var_scope = 'dummy/a_b_c'
-
- # TODO(b/74201126): Change to using the same func for both once
- # Topohash is in place.
- damping_func_1 = make_damping_func(0.5)
- damping_func_2 = make_damping_func(0.5)
-
- factor.register_matpower(-0.5, damping_func_1)
- factor.register_matpower(2, damping_func_2)
-
- factor.instantiate_inv_variables()
-
- factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
- factor_var_scope)
-
- factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
-
- matpower1 = factor.get_matpower(-0.5, damping_func_1).to_dense()
- matpower2 = factor.get_matpower(2, damping_func_2).to_dense()
-
- self.assertEqual(set([matpower1, matpower2]), set(factor_tensors))
-
- self.assertEqual(shape, matpower1.get_shape())
- self.assertEqual(shape, matpower2.get_shape())
-
- def testMakeInverseUpdateOps(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- factor = FisherFactorTestingDummy()
-
- self.assertEqual(0, len(factor.make_inverse_update_ops()))
-
- def testMakeInverseUpdateOpsManyInversesEigenDecomp(self):
- with tf_ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- cov = np.array([[1., 2.], [3., 4.]])
- factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
- factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
-
- damping_funcs = []
- for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1):
- damping_funcs.append(make_damping_func(1./i))
-
- for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
- factor.register_inverse(damping_funcs[i])
-
- factor.instantiate_inv_variables()
- ops = factor.make_inverse_update_ops()
- self.assertEqual(1, len(ops))
-
- sess.run(tf_variables.global_variables_initializer())
- new_invs = []
- sess.run(ops)
- for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
- # The inverse op will assign the damped inverse of cov to the inv var.
- new_invs.append(
- sess.run(factor.get_inverse(damping_funcs[i]).to_dense()))
-
- # We want to see that the new invs are all different from each other.
- for i in range(len(new_invs)):
- for j in range(i + 1, len(new_invs)):
- # Just check the first element.
- self.assertNotEqual(new_invs[i][0][0], new_invs[j][0][0])
-
- def testMakeInverseUpdateOpsMatPowerEigenDecomp(self):
- with tf_ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- cov = np.array([[6., 2.], [2., 4.]])
- factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
- factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
- exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power
- damping = 0.5
- damping_func = make_damping_func(damping)
-
- factor.register_matpower(exp, damping_func)
- factor.instantiate_inv_variables()
- ops = factor.make_inverse_update_ops()
- self.assertEqual(1, len(ops))
-
- sess.run(tf_variables.global_variables_initializer())
- sess.run(ops[0])
- matpower = sess.run(factor.get_matpower(exp, damping_func).to_dense())
- matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp)
- self.assertAllClose(matpower, matpower_np)
-
- def testMakeInverseUpdateOpsNoEigenDecomp(self):
- with tf_ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- cov = np.array([[5., 2.], [2., 4.]]) # NOTE(mattjj): must be symmetric
- factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
- factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
-
- damping_func = make_damping_func(0)
-
- factor.register_inverse(damping_func)
- factor.instantiate_inv_variables()
- ops = factor.make_inverse_update_ops()
- self.assertEqual(1, len(ops))
-
- sess.run(tf_variables.global_variables_initializer())
- # The inverse op will assign the damped inverse of cov to the inv var.
- old_inv = sess.run(factor.get_inverse(damping_func).to_dense())
- self.assertAllClose(
- sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv)
-
- sess.run(ops)
- new_inv = sess.run(factor.get_inverse(damping_func).to_dense())
- self.assertAllClose(new_inv, np.linalg.inv(cov))
-
-
-class FullFactorTest(test.TestCase):
-
- def testFullFactorInit(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), name='a/b/c')
- factor = ff.FullFactor((tensor,), 32)
- factor.instantiate_cov_variables()
- self.assertEqual([6, 6], factor.get_cov().get_shape().as_list())
-
- def testFullFactorInitFloat64(self):
- with tf_ops.Graph().as_default():
- dtype = dtypes.float64_ref
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
- factor = ff.FullFactor((tensor,), 32)
- factor.instantiate_cov_variables()
- cov = factor.get_cov()
- self.assertEqual(cov.dtype, dtype)
- self.assertEqual([6, 6], cov.get_shape().as_list())
-
- def testMakeCovarianceUpdateOp(self):
- with tf_ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- tensor = array_ops.constant([1., 2.], name='a/b/c')
- factor = ff.FullFactor((tensor,), 2)
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[0.75, 0.5], [0.5, 1.5]], new_cov)
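- # The expected value can be derived by hand: assuming the covariance is
- # initialized to the identity (init_covariances_at_zero=False above),
- # the decay-0.5 moving average gives
- # 0.5 * np.eye(2) + 0.5 * np.outer([1., 2.], [1., 2.]) / 2
- # == [[0.75, 0.5], [0.5, 1.5]].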
-
-
-class NaiveDiagonalFactorTest(test.TestCase):
-
- def testNaiveDiagonalFactorInit(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), name='a/b/c')
- factor = ff.NaiveDiagonalFactor((tensor,), 32)
- factor.instantiate_cov_variables()
- self.assertEqual([6, 1], factor.get_cov().get_shape().as_list())
-
- def testNaiveDiagonalFactorInitFloat64(self):
- with tf_ops.Graph().as_default():
- dtype = dtypes.float64_ref
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
- factor = ff.NaiveDiagonalFactor((tensor,), 32)
- factor.instantiate_cov_variables()
- cov = factor.get_cov()
- self.assertEqual(cov.dtype, dtype)
- self.assertEqual([6, 1], cov.get_shape().as_list())
-
- def testMakeCovarianceUpdateOp(self):
- with tf_ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- tensor = array_ops.constant([1., 2.], name='a/b/c')
- factor = ff.NaiveDiagonalFactor((tensor,), 2)
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[0.75], [1.5]], new_cov)
-
-
-class EmbeddingInputKroneckerFactorTest(test.TestCase):
-
- def testInitialization(self):
- with tf_ops.Graph().as_default():
- input_ids = array_ops.constant([[0], [1], [4]])
- vocab_size = 5
- factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
- factor.instantiate_cov_variables()
- cov = factor.get_cov()
- self.assertEqual(cov.shape.as_list(), [vocab_size])
-
- def testCovarianceUpdateOp(self):
- with tf_ops.Graph().as_default():
- input_ids = array_ops.constant([[0], [1], [4]])
- vocab_size = 5
- factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
- factor.instantiate_cov_variables()
- cov_update_op = factor.make_covariance_update_op(0.0)
-
- with self.test_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(cov_update_op)
- self.assertAllClose(np.array([1., 1., 0., 0., 1.]) / 3., new_cov)
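- # The ids [0, 1, 4] each occur once in a batch of 3, so the diagonal
- # covariance is the normalized count vector [1., 1., 0., 0., 1.] / 3.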
-
-
-class ConvDiagonalFactorTest(test.TestCase):
-
- def setUp(self):
- self.batch_size = 10
- self.height = self.width = 32
- self.in_channels = 3
- self.out_channels = 1
- self.kernel_height = self.kernel_width = 3
- self.strides = [1, 2, 2, 1]
- self.data_format = 'NHWC'
- self.padding = 'SAME'
- self.kernel_shape = [
- self.kernel_height, self.kernel_width, self.in_channels,
- self.out_channels
- ]
-
- def testInit(self):
- with tf_ops.Graph().as_default():
- inputs = random_ops.random_uniform(
- [self.batch_size, self.height, self.width, self.in_channels])
- outputs_grads = [
- random_ops.random_uniform([
- self.batch_size, self.height // self.strides[1],
- self.width // self.strides[2], self.out_channels
- ]) for _ in range(3)
- ]
-
- factor = ff.ConvDiagonalFactor(
- (inputs,),
- (outputs_grads,),
- self.kernel_shape,
- self.strides,
- self.padding,
- data_format=self.data_format)
- factor.instantiate_cov_variables()
-
- # Ensure covariance matrix's shape makes sense.
- self.assertEqual([
- self.kernel_height * self.kernel_width * self.in_channels,
- self.out_channels
- ],
- factor.get_cov().shape.as_list())
-
- def testMakeCovarianceUpdateOp(self):
- with tf_ops.Graph().as_default():
- # Construct all arguments such that the convolution kernel is applied in
- # exactly one spatial location.
- inputs = np.random.randn(
- 1, # batch_size
- self.kernel_height,
- self.kernel_width,
- self.in_channels)
- outputs_grad = np.random.randn(
- 1, # batch_size
- 1, # output_height
- 1, # output_width
- self.out_channels)
-
- factor = ff.ConvDiagonalFactor(
- (constant_op.constant(inputs),),
- ((constant_op.constant(outputs_grad),),),
- self.kernel_shape,
- strides=[1, 1, 1, 1],
- padding='VALID')
- factor.instantiate_cov_variables()
-
- # Completely forget initial value on first update.
- cov_update_op = factor.make_covariance_update_op(0.0)
-
- # Ensure the new covariance equals the elementwise square of the outer
- # product of the vectorized inputs and output gradients.
- with self.test_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- cov = sess.run(cov_update_op)
- expected_cov = np.outer(inputs.flatten(), outputs_grad.flatten())**2
- self.assertAllClose(expected_cov, cov)
-
- def testHasBias(self):
- with tf_ops.Graph().as_default():
- inputs = random_ops.random_uniform(
- [self.batch_size, self.height, self.width, self.in_channels])
- outputs_grads = [
- random_ops.random_uniform([
- self.batch_size, self.height // self.strides[1],
- self.width // self.strides[2], self.out_channels
- ]) for _ in range(3)
- ]
-
- factor = ff.ConvDiagonalFactor(
- (inputs,),
- (outputs_grads,),
- self.kernel_shape,
- self.strides,
- self.padding,
- data_format=self.data_format,
- has_bias=True)
- factor.instantiate_cov_variables()
-
- # Ensure shape accounts for bias.
- self.assertEqual([
- self.kernel_height * self.kernel_width * self.in_channels + 1,
- self.out_channels
- ],
- factor.get_cov().shape.as_list())
-
- # Ensure update op doesn't crash.
- cov_update_op = factor.make_covariance_update_op(0.0)
- with self.test_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- sess.run(cov_update_op)
-
-
-class FullyConnectedKroneckerFactorTest(test.TestCase):
-
- def _testFullyConnectedKroneckerFactorInit(self,
- has_bias,
- final_shape,
- dtype=dtypes.float32_ref):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
- factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=has_bias)
- factor.instantiate_cov_variables()
- cov = factor.get_cov()
- self.assertEqual(cov.dtype, dtype)
- self.assertEqual(final_shape, cov.get_shape().as_list())
-
- def testFullyConnectedKroneckerFactorInitNoBias(self):
- for dtype in (dtypes.float32_ref, dtypes.float64_ref):
- self._testFullyConnectedKroneckerFactorInit(False, [3, 3], dtype=dtype)
-
- def testFullyConnectedKroneckerFactorInitWithBias(self):
- for dtype in (dtypes.float32_ref, dtypes.float64_ref):
- self._testFullyConnectedKroneckerFactorInit(True, [4, 4], dtype=dtype)
-
- def testMakeCovarianceUpdateOpWithBias(self):
- with tf_ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
- factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=True)
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov)
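- # Derivation sketch (assuming identity initialization, see the
- # set_global_constants call above): with has_bias=True the inputs are
- # first extended with a homogeneous column, X = [[1, 2, 1], [3, 4, 1]],
- # and the update is 0.5 * np.eye(3) + 0.5 * X.T.dot(X) / 2.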
-
- def testMakeCovarianceUpdateOpNoBias(self):
- with tf_ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
- factor = ff.FullyConnectedKroneckerFactor(((tensor,),))
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov)
-
-
-class ConvFactorTestCase(test.TestCase):
-
- def assertMatrixRank(self, rank, matrix, atol=1e-5):
- assert rank <= matrix.shape[0], 'Rank cannot be larger than matrix size.'
- eigvals = np.linalg.eigvals(matrix)
- nnz_eigvals = np.sum(eigvals > atol)
- self.assertEqual(
- rank,
- nnz_eigvals,
- msg=('Found %d of %d expected non-zero eigenvalues: %s.' %
- (nnz_eigvals, rank, eigvals)))
-
-
-class ConvInputKroneckerFactorTest(ConvFactorTestCase):
-
- def test3DConvolution(self):
- with tf_ops.Graph().as_default():
- batch_size = 1
- width = 3
- in_channels = 3**3
- out_channels = 4
-
- factor = ff.ConvInputKroneckerFactor(
- inputs=(random_ops.random_uniform(
- (batch_size, width, width, width, in_channels), seed=0),),
- filter_shape=(width, width, width, in_channels, out_channels),
- padding='SAME',
- strides=(2, 2, 2),
- extract_patches_fn='extract_convolution_patches',
- has_bias=False)
- factor.instantiate_cov_variables()
-
- # Ensure shape of covariance matches input size of filter.
- input_size = in_channels * (width**3)
- self.assertEqual([input_size, input_size],
- factor.get_cov().shape.as_list())
-
- # Ensure cov_update_op doesn't crash.
- with self.test_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- sess.run(factor.make_covariance_update_op(0.0))
- cov = sess.run(factor.get_cov())
-
- # Cov should be rank-8, as the filter will be applied at each of the
- # eight corners of the 3-D input volume.
- self.assertMatrixRank(8, cov)
-
- def testPointwiseConv2d(self):
- with tf_ops.Graph().as_default():
- batch_size = 1
- width = 3
- in_channels = 3**2
- out_channels = 4
-
- factor = ff.ConvInputKroneckerFactor(
- inputs=(random_ops.random_uniform(
- (batch_size, width, width, in_channels), seed=0),),
- filter_shape=(1, 1, in_channels, out_channels),
- padding='SAME',
- strides=(1, 1, 1, 1),
- extract_patches_fn='extract_pointwise_conv2d_patches',
- has_bias=False)
- factor.instantiate_cov_variables()
-
- # Ensure shape of covariance matches input size of filter.
- self.assertEqual([in_channels, in_channels],
- factor.get_cov().shape.as_list())
-
- # Ensure cov_update_op doesn't crash.
- with self.test_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- sess.run(factor.make_covariance_update_op(0.0))
- cov = sess.run(factor.get_cov())
-
- # Cov should be rank-9, as the filter will be applied at each location.
- self.assertMatrixRank(9, cov)
-
- def testStrides(self):
- with tf_ops.Graph().as_default():
- batch_size = 1
- width = 3
- in_channels = 3**2
- out_channels = 4
-
- factor = ff.ConvInputKroneckerFactor(
- inputs=(random_ops.random_uniform(
- (batch_size, width, width, in_channels), seed=0),),
- filter_shape=(1, 1, in_channels, out_channels),
- padding='SAME',
- strides=(1, 2, 1, 1),
- extract_patches_fn='extract_image_patches',
- has_bias=False)
- factor.instantiate_cov_variables()
-
- with self.test_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- sess.run(factor.make_covariance_update_op(0.0))
- cov = sess.run(factor.get_cov())
-
- # Cov should be the sum of 3 * 2 = 6 outer products.
- self.assertMatrixRank(6, cov)
-
- def testDilationRate(self):
- with tf_ops.Graph().as_default():
- batch_size = 1
- width = 3
- in_channels = 2
- out_channels = 4
-
- factor = ff.ConvInputKroneckerFactor(
- inputs=(random_ops.random_uniform(
- (batch_size, width, width, in_channels), seed=0),),
- filter_shape=(3, 3, in_channels, out_channels),
- padding='SAME',
- extract_patches_fn='extract_image_patches',
- strides=(1, 1, 1, 1),
- dilation_rate=(1, width, width, 1),
- has_bias=False)
- factor.instantiate_cov_variables()
-
- with self.test_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- sess.run(factor.make_covariance_update_op(0.0))
- cov = sess.run(factor.get_cov())
-
- # Cov should be rank = in_channels, as only the center of the filter
- # receives non-zero input for each input channel.
- self.assertMatrixRank(in_channels, cov)
-
- def testConvInputKroneckerFactorInitNoBias(self):
- with tf_ops.Graph().as_default():
- tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
- factor = ff.ConvInputKroneckerFactor(
- inputs=(tensor,),
- filter_shape=(1, 2, 3, 4),
- padding='SAME',
- has_bias=False)
- factor.instantiate_cov_variables()
- self.assertEqual([1 * 2 * 3, 1 * 2 * 3],
- factor.get_cov().get_shape().as_list())
-
- def testConvInputKroneckerFactorInit(self):
- with tf_ops.Graph().as_default():
- tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
- factor = ff.ConvInputKroneckerFactor(
- (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
- factor.instantiate_cov_variables()
- self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
- factor.get_cov().get_shape().as_list())
-
- def testConvInputKroneckerFactorInitFloat64(self):
- with tf_ops.Graph().as_default():
- dtype = dtypes.float64_ref
- tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c', dtype=dtypes.float64)
- factor = ff.ConvInputKroneckerFactor(
- (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
- factor.instantiate_cov_variables()
- cov = factor.get_cov()
- self.assertEqual(cov.dtype, dtype)
- self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
- cov.get_shape().as_list())
-
- def testMakeCovarianceUpdateOpWithBias(self):
- with tf_ops.Graph().as_default(), self.test_session() as sess:
- input_shape = (2, 1, 1, 1)
- tensor = array_ops.constant(
- np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype(
- np.float32))
- factor = ff.ConvInputKroneckerFactor(
- (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME', has_bias=True)
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(0.))
- self.assertAllClose(
- [
- [(1. + 4.) / 2., (1. + 2.) / 2.], #
- [(1. + 2.) / 2., (1. + 1.) / 2.]
- ], #
- new_cov)
-
- def testMakeCovarianceUpdateOpNoBias(self):
- with tf_ops.Graph().as_default(), self.test_session() as sess:
- input_shape = (2, 1, 1, 1)
- tensor = array_ops.constant(
- np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype(
- np.float32))
- factor = ff.ConvInputKroneckerFactor(
- (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME')
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(0.))
- self.assertAllClose([[(1. + 4.) / 2.]], new_cov)
-
- def testSubSample(self):
- with tf_ops.Graph().as_default():
- patches_1 = array_ops.constant(1, shape=(10, 2))
- patches_2 = array_ops.constant(1, shape=(10, 8))
- patches_3 = array_ops.constant(1, shape=(3, 3))
- patches_1_sub = ff._subsample_for_cov_computation(patches_1)
- patches_2_sub = ff._subsample_for_cov_computation(patches_2)
- patches_3_sub = ff._subsample_for_cov_computation(patches_3)
- patches_1_sub_batch_size = patches_1_sub.shape.as_list()[0]
- patches_2_sub_batch_size = patches_2_sub.shape.as_list()[0]
- patches_3_sub_batch_size = patches_3_sub.shape.as_list()[0]
- self.assertEqual(2, patches_1_sub_batch_size)
- self.assertEqual(8, patches_2_sub_batch_size)
- self.assertEqual(3, patches_3_sub_batch_size)
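- # The cap appears to be num_cols * ff._MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW
- # rows (an assumption about that constant's default of 1), which yields
- # min(10, 2) == 2, min(10, 8) == 8 and min(3, 3) == 3 above.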
-
-
-class ConvOutputKroneckerFactorTest(ConvFactorTestCase):
-
- def test3DConvolution(self):
- with tf_ops.Graph().as_default():
- batch_size = 1
- width = 3
- out_channels = width**3
-
- factor = ff.ConvOutputKroneckerFactor(outputs_grads=([
- random_ops.random_uniform(
- (batch_size, width, width, width, out_channels), seed=0)
- ],))
- factor.instantiate_cov_variables()
-
- with self.test_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- sess.run(factor.make_covariance_update_op(0.0))
- cov = sess.run(factor.get_cov())
-
- # Cov should be rank 3^3, as each spatial position contributes a rank-1
- # update.
- self.assertMatrixRank(width**3, cov)
-
- def testConvOutputKroneckerFactorInit(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3, 4, 5), name='a/b/c')
- factor = ff.ConvOutputKroneckerFactor(((tensor,),))
- factor.instantiate_cov_variables()
- self.assertEqual([5, 5], factor.get_cov().get_shape().as_list())
-
- def testConvOutputKroneckerFactorInitFloat64(self):
- with tf_ops.Graph().as_default():
- dtype = dtypes.float64_ref
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c')
- factor = ff.ConvOutputKroneckerFactor(((tensor,),))
- factor.instantiate_cov_variables()
- cov = factor.get_cov()
- self.assertEqual(cov.dtype, dtype)
- self.assertEqual([5, 5], cov.get_shape().as_list())
-
- def testMakeCovarianceUpdateOp(self):
- with tf_ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- tensor = np.arange(1, 17).reshape(2, 2, 2, 2).astype(np.float32)
- factor = ff.ConvOutputKroneckerFactor(((array_ops.constant(tensor),),))
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[43, 46.5], [46.5, 51.5]], new_cov)
-
-
-class FullyConnectedMultiKFTest(test.TestCase):
-
- def testFullyConnectedMultiKFInit(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), name='a/b/c')
- factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False)
- factor.instantiate_cov_variables()
- self.assertEqual([3, 3], factor.get_cov().get_shape().as_list())
-
- def testFullyConnectedMultiKFInitFloat64(self):
- with tf_ops.Graph().as_default():
- dtype = dtypes.float64_ref
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
- factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False)
- factor.instantiate_cov_variables()
- cov = factor.get_cov()
- self.assertEqual(cov.dtype, dtype)
- self.assertEqual([3, 3], cov.get_shape().as_list())
-
- def testMakeCovarianceUpdateOpWithBias(self):
- with tf_ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
- factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=True)
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov)
-
- def testMakeCovarianceUpdateOpNoBias(self):
- with tf_ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
- factor = ff.FullyConnectedMultiKF(((tensor,),))
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
deleted file mode 100644
index cb80fca370..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
+++ /dev/null
@@ -1,597 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.layer_collection."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.kfac.python.ops import fisher_blocks
-from tensorflow.contrib.kfac.python.ops import fisher_factors
-from tensorflow.contrib.kfac.python.ops import layer_collection
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import random_seed
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.platform import test
-
-
-class MockFisherBlock(object):
- """A fake FisherBlock."""
-
- num_registered_towers = 2
-
- def __init__(self, name='MockFisherBlock'):
- self.name = name
-
- def __eq__(self, other):
- return isinstance(other, MockFisherBlock) and other.name == self.name
-
- def __hash__(self):
- return hash(self.name)
-
-
-class LayerParametersDictTest(test.TestCase):
-
- def testSetItem(self):
- """Ensure insertion, contains, retrieval works for supported key types."""
- with ops.Graph().as_default():
- lp_dict = layer_collection.LayerParametersDict()
-
- x = array_ops.constant(0)
- y0 = array_ops.constant(0)
- y1 = array_ops.constant(0)
- z0 = array_ops.constant(0)
- z1 = array_ops.constant(0)
- keys = [x, (y0, y1), [z0, z1]]
- for key in keys:
- lp_dict[key] = key
-
- for key in keys:
- self.assertTrue(key in lp_dict)
- self.assertEqual(lp_dict[key], key)
-
- def testSetItemOverlap(self):
- """Ensure insertion fails if key overlaps with existing key."""
- with ops.Graph().as_default():
- lp_dict = layer_collection.LayerParametersDict()
-
- x = array_ops.constant(0)
- y = array_ops.constant(0)
- lp_dict[x] = 'value'
-
- with self.assertRaises(ValueError):
- lp_dict[(x, y)] = 'value'
-
- # Ensure 'y' wasn't inserted.
- self.assertTrue(x in lp_dict)
- self.assertFalse(y in lp_dict)
-
-
-class LayerCollectionTest(test.TestCase):
-
- def testLayerCollectionInit(self):
- lc = layer_collection.LayerCollection()
- self.assertEqual(0, len(lc.get_blocks()))
- self.assertEqual(0, len(lc.get_factors()))
- self.assertFalse(lc.losses)
-
- def testRegisterBlocks(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- lc = layer_collection.LayerCollection()
- lc.register_fully_connected(
- array_ops.constant(1), array_ops.constant(2), array_ops.constant(3))
- lc.register_fully_connected(
- array_ops.constant(1),
- array_ops.constant(2),
- array_ops.constant(3),
- approx=layer_collection.APPROX_DIAGONAL_NAME)
- lc.register_conv2d(
- params=array_ops.ones((2, 3, 4, 5)),
- strides=[1, 1, 1, 1],
- padding='SAME',
- inputs=array_ops.ones((1, 2, 3, 4)),
- outputs=array_ops.ones((1, 1, 1, 5)))
- lc.register_conv2d(
- params=array_ops.ones((2, 3, 4, 5)),
- strides=[1, 1, 1, 1],
- padding='SAME',
- inputs=array_ops.ones((1, 2, 3, 4)),
- outputs=array_ops.ones((1, 1, 1, 5)),
- approx=layer_collection.APPROX_DIAGONAL_NAME)
- lc.register_separable_conv2d(
- depthwise_params=array_ops.ones((3, 3, 1, 2)),
- pointwise_params=array_ops.ones((1, 1, 2, 4)),
- inputs=array_ops.ones((32, 5, 5, 1)),
- depthwise_outputs=array_ops.ones((32, 5, 5, 2)),
- pointwise_outputs=array_ops.ones((32, 5, 5, 4)),
- strides=[1, 1, 1, 1],
- padding='SAME')
- lc.register_convolution(
- params=array_ops.ones((3, 3, 1, 8)),
- inputs=array_ops.ones((32, 5, 5, 1)),
- outputs=array_ops.ones((32, 5, 5, 8)),
- padding='SAME')
- lc.register_generic(
- array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME)
- lc.register_generic(
- array_ops.constant(6),
- 16,
- approx=layer_collection.APPROX_DIAGONAL_NAME)
- lc.register_fully_connected_multi(
- array_ops.constant(1),
- (array_ops.constant(2), array_ops.constant(3)),
- (array_ops.constant(4), array_ops.constant(5)))
- lc.register_conv2d_multi(
- params=array_ops.ones((2, 3, 4, 5)),
- strides=[1, 1, 1, 1],
- padding='SAME',
- inputs=(array_ops.ones((1, 2, 3, 4)), array_ops.ones((5, 6, 7, 8))),
- outputs=(array_ops.ones((1, 1, 1, 5)), array_ops.ones((2, 2, 2, 10))))
- lc.register_embedding_multi(
- array_ops.constant((1,)),
- (array_ops.constant(2), array_ops.constant(3)),
- (array_ops.constant(4), array_ops.constant(5)))
-
- self.assertEqual(12, len(lc.get_blocks()))
-
- def testRegisterBlocksMultipleRegistrations(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- lc = layer_collection.LayerCollection()
- key = array_ops.constant(1)
- lc.register_fully_connected(key, array_ops.constant(2),
- array_ops.constant(3))
- with self.assertRaises(ValueError) as cm:
- lc.register_generic(key, 16)
- self.assertIn('already in LayerCollection', str(cm.exception))
-
- def testRegisterSingleParamNotRegistered(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {
- variable_scope.get_variable('y', initializer=array_ops.constant(1,)):
- '1'
- }
- lc.register_block(x, 'foo')
-
- def testShouldRegisterSingleParamRegistered(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {x: '1'}
- with self.assertRaises(ValueError) as cm:
- lc.register_block(x, 'foo')
- self.assertIn('already in LayerCollection', str(cm.exception))
-
- def testRegisterSingleParamRegisteredInTuple(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {(x, y): '1'}
- with self.assertRaises(ValueError) as cm:
- lc.register_block(x, 'foo')
- self.assertIn('was already registered', str(cm.exception))
-
- def testRegisterTupleParamNotRegistered(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {
- variable_scope.get_variable('z', initializer=array_ops.constant(1,)):
- '1'
- }
-
- lc.register_block((x, y), 'foo')
- self.assertEqual(set(['1', 'foo']), set(lc.get_blocks()))
-
- def testRegisterTupleParamRegistered(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {(x, y): '1'}
-
- with self.assertRaises(ValueError) as cm:
- lc.register_block((x, y), 'foo')
- self.assertIn('already in LayerCollection', str(cm.exception))
-
- def testRegisterTupleParamRegisteredInSuperset(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
- z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {(x, y, z): '1'}
-
- with self.assertRaises(ValueError) as cm:
- lc.register_block((x, y), 'foo')
- self.assertIn('was already registered', str(cm.exception))
-
- def testRegisterTupleParamSomeRegistered(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
- z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {x: MockFisherBlock('1'), z: MockFisherBlock('2')}
-
- with self.assertRaises(ValueError) as cm:
- lc.register_block((x, y), MockFisherBlock('foo'))
- self.assertIn('was already registered', str(cm.exception))
-
- def testRegisterTupleVarSomeRegisteredInOtherTuples(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
- z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
- w = variable_scope.get_variable('w', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {(x, z): '1', (z, w): '2'}
-
- with self.assertRaises(ValueError) as cm:
- lc.register_block((x, y), 'foo')
- self.assertIn('was already registered', str(cm.exception))
-
- def testRegisterCategoricalPredictiveDistribution(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- logits = linalg_ops.eye(2)
-
- lc = layer_collection.LayerCollection()
- lc.register_categorical_predictive_distribution(logits, seed=200)
- single_loss = sess.run(lc.total_sampled_loss())
-
- lc2 = layer_collection.LayerCollection()
- lc2.register_categorical_predictive_distribution(logits, seed=200)
- lc2.register_categorical_predictive_distribution(logits, seed=200)
- double_loss = sess.run(lc2.total_sampled_loss())
- self.assertAlmostEqual(2 * single_loss, double_loss)
-
- def testLossFunctionByName(self):
- """Ensure loss functions can be identified by name."""
- with ops.Graph().as_default():
- logits = linalg_ops.eye(2)
- lc = layer_collection.LayerCollection()
-
- # Create a new loss function by name.
- lc.register_categorical_predictive_distribution(logits, name='loss1')
- self.assertEqual(1, len(lc.towers_by_loss))
-
- # Add logits to same loss function.
- lc.register_categorical_predictive_distribution(
- logits, name='loss1', reuse=True)
- self.assertEqual(1, len(lc.towers_by_loss))
-
- # Add another new loss function.
- lc.register_categorical_predictive_distribution(logits, name='loss2')
- self.assertEqual(2, len(lc.towers_by_loss))
-
- def testLossFunctionWithoutName(self):
- """Ensure loss functions get unique names if 'name' not specified."""
- with ops.Graph().as_default():
- logits = linalg_ops.eye(2)
- lc = layer_collection.LayerCollection()
-
- # Create a new loss function with default names.
- lc.register_categorical_predictive_distribution(logits)
- lc.register_categorical_predictive_distribution(logits)
- self.assertEqual(2, len(lc.losses))
-
- def testCategoricalPredictiveDistributionMultipleMinibatches(self):
- """Ensure multiple minibatches are registered."""
- with ops.Graph().as_default():
- batch_size = 3
- output_size = 2
- logits = array_ops.zeros([batch_size, output_size])
- targets = array_ops.ones([batch_size], dtype=dtypes.int32)
- lc = layer_collection.LayerCollection()
-
- # Create a new loss function.
- lc.register_categorical_predictive_distribution(
- logits, targets=targets, name='loss1')
-
- # Can add when reuse=True.
- lc.register_categorical_predictive_distribution(
- logits, targets=targets, name='loss1', reuse=True)
-
- # Can add when reuse=VARIABLE_SCOPE and reuse=True there.
- with variable_scope.variable_scope(
- variable_scope.get_variable_scope(), reuse=True):
- lc.register_categorical_predictive_distribution(
- logits,
- targets=targets,
- name='loss1',
- reuse=layer_collection.VARIABLE_SCOPE)
-
- # Can't add when reuse=False.
- with self.assertRaises(KeyError):
- lc.register_categorical_predictive_distribution(
- logits, targets=targets, name='loss1', reuse=False)
-
- # Can't add when reuse=VARIABLE_SCOPE and reuse=False there.
- with self.assertRaises(KeyError):
- lc.register_categorical_predictive_distribution(
- logits,
- targets=targets,
- name='loss1',
- reuse=layer_collection.VARIABLE_SCOPE)
-
- self.assertEqual(len(lc.towers_by_loss), 1)
- # Three successful registrations.
- self.assertEqual(len(lc.towers_by_loss[0]), 3)
-
- def testRegisterCategoricalPredictiveDistributionBatchSize1(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- logits = random_ops.random_normal((1, 2))
- lc = layer_collection.LayerCollection()
-
- lc.register_categorical_predictive_distribution(logits, seed=200)
-
- def testRegisterCategoricalPredictiveDistributionSpecifiedTargets(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- logits = array_ops.constant([[1., 2.], [3., 4.]], dtype=dtypes.float32)
- lc = layer_collection.LayerCollection()
- targets = array_ops.constant([0, 1], dtype=dtypes.int32)
-
- lc.register_categorical_predictive_distribution(logits, targets=targets)
- single_loss = sess.run(lc.total_loss())
- self.assertAlmostEqual(1.6265233, single_loss)
-
- def testRegisterNormalPredictiveDistribution(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- predictions = array_ops.constant(
- [[1., 2.], [3., 4]], dtype=dtypes.float32)
-
- lc = layer_collection.LayerCollection()
- lc.register_normal_predictive_distribution(predictions, 1., seed=200)
- single_loss = sess.run(lc.total_sampled_loss())
-
- lc2 = layer_collection.LayerCollection()
- lc2.register_normal_predictive_distribution(predictions, 1., seed=200)
- lc2.register_normal_predictive_distribution(predictions, 1., seed=200)
- double_loss = sess.run(lc2.total_sampled_loss())
-
- self.assertAlmostEqual(2 * single_loss, double_loss)
-
- def testRegisterNormalPredictiveDistributionSpecifiedTargets(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- predictions = array_ops.constant(
- [[1., 2.], [3., 4.]], dtype=dtypes.float32)
- lc = layer_collection.LayerCollection()
- targets = array_ops.constant([[3., 1.], [4., 2.]], dtype=dtypes.float32)
-
- lc.register_normal_predictive_distribution(
- predictions, 2.**2, targets=targets)
- single_loss = sess.run(lc.total_loss())
- self.assertAlmostEqual(7.6983433, single_loss)
-
- def ensureLayerReuseWorks(self, register_fn):
- """Ensure the 'reuse' keyword argument function as intended.
-
- Args:
- register_fn: function for registering a layer. Arguments are
- layer_collection, reuse, and approx.
- """
- # Fails on second if reuse=False.
- lc = layer_collection.LayerCollection()
- register_fn(lc)
- with self.assertRaises(ValueError):
- register_fn(lc, reuse=False)
-
- # Succeeds on second if reuse=True.
- lc = layer_collection.LayerCollection()
- register_fn(lc)
- register_fn(lc, reuse=True)
-
- # Fails on second if reuse=VARIABLE_SCOPE and no variable reuse.
- lc = layer_collection.LayerCollection()
- register_fn(lc)
- with self.assertRaises(ValueError):
- register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE)
-
- # Succeeds on second if reuse=VARIABLE_SCOPE and variable reuse.
- lc = layer_collection.LayerCollection()
- register_fn(lc)
- with variable_scope.variable_scope(
- variable_scope.get_variable_scope(), reuse=True):
- register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE)
-
- # Fails if block type changes.
- lc = layer_collection.LayerCollection()
- register_fn(lc, approx=layer_collection.APPROX_KRONECKER_NAME)
- with self.assertRaises(ValueError):
- register_fn(lc, approx=layer_collection.APPROX_DIAGONAL_NAME, reuse=True)
-
- # Fails if reuse requested but no FisherBlock exists.
- lc = layer_collection.LayerCollection()
- with self.assertRaises(KeyError):
- register_fn(lc, reuse=True)
-
- def testRegisterFullyConnectedReuse(self):
- """Ensure the 'reuse' works with register_fully_connected."""
- with ops.Graph().as_default():
- inputs = array_ops.ones([2, 10])
- outputs = array_ops.zeros([2, 5])
- params = (
- variable_scope.get_variable('w', [10, 5]), #
- variable_scope.get_variable('b', [5]))
-
- def register_fn(lc, **kwargs):
- lc.register_fully_connected(
- params=params, inputs=inputs, outputs=outputs, **kwargs)
-
- self.ensureLayerReuseWorks(register_fn)
-
- def testRegisterConv2dReuse(self):
- """Ensure the 'reuse' works with register_conv2d."""
- with ops.Graph().as_default():
- inputs = array_ops.ones([2, 5, 5, 10])
- outputs = array_ops.zeros([2, 5, 5, 3])
- params = (
- variable_scope.get_variable('w', [1, 1, 10, 3]), #
- variable_scope.get_variable('b', [3]))
-
- def register_fn(lc, **kwargs):
- lc.register_conv2d(
- params=params,
- strides=[1, 1, 1, 1],
- padding='SAME',
- inputs=inputs,
- outputs=outputs,
- **kwargs)
-
- self.ensureLayerReuseWorks(register_fn)
-
- def testReuseWithInvalidRegistration(self):
- """Invalid registrations shouldn't overwrite existing blocks."""
- with ops.Graph().as_default():
- inputs = array_ops.ones([2, 5, 5, 10])
- outputs = array_ops.zeros([2, 5, 5, 3])
- w = variable_scope.get_variable('w', [1, 1, 10, 3])
- b = variable_scope.get_variable('b', [3])
- lc = layer_collection.LayerCollection()
- lc.register_fully_connected(w, inputs, outputs)
- self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1)
- with self.assertRaises(KeyError):
- lc.register_fully_connected((w, b), inputs, outputs, reuse=True)
- self.assertNotIn((w, b), lc.fisher_blocks)
- self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1)
- lc.register_fully_connected(w, inputs, outputs, reuse=True)
- self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 2)
-
- def testMakeOrGetFactor(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- lc = layer_collection.LayerCollection()
- key = array_ops.constant(1)
- lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
- lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
- lc.make_or_get_factor(fisher_factors.FullFactor,
- ((array_ops.constant(2),), 16))
-
- self.assertEqual(2, len(lc.get_factors()))
- variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertTrue(
- all([var.name.startswith('LayerCollection') for var in variables]))
-
- def testMakeOrGetFactorCustomScope(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- scope = 'Foo'
- lc = layer_collection.LayerCollection(name=scope)
- key = array_ops.constant(1)
- lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
- lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
- lc.make_or_get_factor(fisher_factors.FullFactor,
- ((array_ops.constant(2),), 16))
-
- self.assertEqual(2, len(lc.get_factors()))
- variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertTrue(all([var.name.startswith(scope) for var in variables]))
-
- def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self):
- x = variable_scope.get_variable('x', shape=())
- y = variable_scope.get_variable('y', shape=())
- z = variable_scope.get_variable('z', shape=())
- lc = layer_collection.LayerCollection()
- lc.define_linked_parameters((x, y))
-
- with self.assertRaises(ValueError):
- lc.define_linked_parameters((x, z))
-
- def testIdentifySubsetPreviouslyRegisteredTensor(self):
- x = variable_scope.get_variable('x', shape=())
- y = variable_scope.get_variable('y', shape=())
- lc = layer_collection.LayerCollection()
- lc.define_linked_parameters((x, y))
-
- with self.assertRaises(ValueError):
- lc.define_linked_parameters(x)
-
- def testSpecifyApproximation(self):
- w_0 = variable_scope.get_variable('w_0', [10, 10])
- w_1 = variable_scope.get_variable('w_1', [10, 10])
-
- b_0 = variable_scope.get_variable('b_0', [10])
- b_1 = variable_scope.get_variable('b_1', [10])
-
- x_0 = array_ops.placeholder(dtypes.float32, shape=(32, 10))
- x_1 = array_ops.placeholder(dtypes.float32, shape=(32, 10))
-
- pre_bias_0 = math_ops.matmul(x_0, w_0)
- pre_bias_1 = math_ops.matmul(x_1, w_1)
-
- # Build the fully connected layers in the graph.
- pre_bias_0 + b_0 # pylint: disable=pointless-statement
- pre_bias_1 + b_1 # pylint: disable=pointless-statement
-
- lc = layer_collection.LayerCollection()
- lc.define_linked_parameters(
- w_0, approximation=layer_collection.APPROX_DIAGONAL_NAME)
- lc.define_linked_parameters(
- w_1, approximation=layer_collection.APPROX_DIAGONAL_NAME)
- lc.define_linked_parameters(
- b_0, approximation=layer_collection.APPROX_FULL_NAME)
- lc.define_linked_parameters(
- b_1, approximation=layer_collection.APPROX_FULL_NAME)
-
- lc.register_fully_connected(w_0, x_0, pre_bias_0)
- lc.register_fully_connected(
- w_1, x_1, pre_bias_1, approx=layer_collection.APPROX_KRONECKER_NAME)
- self.assertIsInstance(lc.fisher_blocks[w_0],
- fisher_blocks.FullyConnectedDiagonalFB)
- self.assertIsInstance(lc.fisher_blocks[w_1],
- fisher_blocks.FullyConnectedKFACBasicFB)
-
- lc.register_generic(b_0, batch_size=1)
- lc.register_generic(
- b_1, batch_size=1, approx=layer_collection.APPROX_DIAGONAL_NAME)
- self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB)
- self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB)
-
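
The approximation selection exercised by testSpecifyApproximation above amounts to a name-to-FisherBlock lookup: for fully connected layers, the diagonal approximation maps to FullyConnectedDiagonalFB and the Kronecker one to FullyConnectedKFACBasicFB, while generic registrations map 'full' to FullFB and 'diagonal' to NaiveDiagonalFB. A minimal sketch of that dispatch, with hypothetical lookup tables mirroring only the assertions above (the real mapping lives in layer_collection.py):

    # Hypothetical tables; keys are illustrative approximation names.
    _FC_APPROX_TO_BLOCK = {
        'kron': 'FullyConnectedKFACBasicFB',
        'diagonal': 'FullyConnectedDiagonalFB',
    }
    _GENERIC_APPROX_TO_BLOCK = {
        'full': 'FullFB',
        'diagonal': 'NaiveDiagonalFB',
    }

    def block_class_for(registration, approx):
        """Return the FisherBlock class name for a registration/approx pair."""
        table = (_FC_APPROX_TO_BLOCK if registration == 'fully_connected'
                 else _GENERIC_APPROX_TO_BLOCK)
        return table[approx]

    assert block_class_for('fully_connected', 'diagonal') == 'FullyConnectedDiagonalFB'
    assert block_class_for('generic', 'full') == 'FullFB'
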
- def testDefaultLayerCollection(self):
- with ops.Graph().as_default():
- # Can't get default if there isn't one set.
- with self.assertRaises(ValueError):
- layer_collection.get_default_layer_collection()
-
- # Can't set default twice.
- lc = layer_collection.LayerCollection()
- layer_collection.set_default_layer_collection(lc)
- with self.assertRaises(ValueError):
- layer_collection.set_default_layer_collection(lc)
-
- # Same as one set.
- self.assertTrue(lc is layer_collection.get_default_layer_collection())
-
- # Can set to None.
- layer_collection.set_default_layer_collection(None)
- with self.assertRaises(ValueError):
- layer_collection.get_default_layer_collection()
-
- # as_default() is the same as setting/clearing.
- with lc.as_default():
- self.assertTrue(lc is layer_collection.get_default_layer_collection())
- with self.assertRaises(ValueError):
- layer_collection.get_default_layer_collection()
-
-
-if __name__ == '__main__':
- test.main()
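
Taken together, these deleted layer_collection tests pin down the 'reuse' contract: re-registering an existing name fails unless reuse=True, reuse=True fails if nothing was registered yet, and reuse=VARIABLE_SCOPE defers to the enclosing variable scope's reuse flag. A minimal pure-Python sketch of that decision logic (names and error types here are illustrative; the tests above show KeyError for losses and ValueError for layers):

    def resolve_registration(existing, reuse, scope_reuse=False):
        """Decide whether a registration creates a new block or reuses one.

        Args:
          existing: whether a block is already registered under this key.
          reuse: True, False, or 'VARIABLE_SCOPE' (defer to scope_reuse).
          scope_reuse: the enclosing variable scope's reuse flag.
        """
        if reuse == 'VARIABLE_SCOPE':
            reuse = scope_reuse
        if reuse and not existing:
            raise KeyError('reuse requested but nothing registered yet')
        if not reuse and existing:
            raise ValueError('already registered; pass reuse=True to add a tower')
        return 'reuse' if existing else 'create'

    assert resolve_registration(False, False) == 'create'
    assert resolve_registration(True, 'VARIABLE_SCOPE', scope_reuse=True) == 'reuse'
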
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
deleted file mode 100644
index c00af5593f..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
+++ /dev/null
@@ -1,190 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.loss_functions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.kfac.python.ops import loss_functions
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class InsertSliceInZerosTest(test.TestCase):
-
- def testBadShape(self):
- bad_shaped_ones = array_ops.ones(shape=[1, 3]) # n.b. shape[1] != 1
- with self.assertRaises(ValueError):
- loss_functions.insert_slice_in_zeros(bad_shaped_ones, 1, 42, 17)
-
- def test3d(self):
- input_tensor = constant_op.constant([[[1, 2]], [[3, 4]]])
- expected_output_array = [[[1, 2], [0, 0]], [[3, 4], [0, 0]]]
- op = loss_functions.insert_slice_in_zeros(input_tensor, 1, 2, 0)
- with self.test_session() as sess:
- actual_output_array = sess.run(op)
- self.assertAllEqual(expected_output_array, actual_output_array)
-
-
-class CategoricalLogitsNegativeLogProbLossTest(test.TestCase):
-
- def testSample(self):
- """Ensure samples can be drawn."""
- with ops.Graph().as_default(), self.test_session() as sess:
- logits = np.asarray([
- [0., 0., 0.], #
- [1., -1., 0.]
- ]).astype(np.float32)
- loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
- array_ops.constant(logits))
- sample = loss.sample(42)
- sample = sess.run(sample)
- self.assertEqual(sample.shape, (2,))
-
- def testEvaluateOnTargets(self):
- """Ensure log probability can be evaluated correctly."""
- with ops.Graph().as_default(), self.test_session() as sess:
- logits = np.asarray([
- [0., 0., 0.], #
- [1., -1., 0.]
- ]).astype(np.float32)
- targets = np.asarray([2, 1]).astype(np.int32)
- loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
- array_ops.constant(logits), targets=array_ops.constant(targets))
- neg_log_prob = loss.evaluate()
- neg_log_prob = sess.run(neg_log_prob)
-
- # Calculate explicit log probability of targets.
- probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
- log_probs = np.log([
- probs[0, targets[0]], #
- probs[1, targets[1]]
- ])
- expected_log_prob = np.sum(log_probs)
-
- self.assertAllClose(neg_log_prob, -expected_log_prob)
-
- def testEvaluateOnSample(self):
- """Ensure log probability of a sample can be drawn."""
- with ops.Graph().as_default(), self.test_session() as sess:
- logits = np.asarray([
- [0., 0., 0.], #
- [1., -1., 0.]
- ]).astype(np.float32)
- loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
- array_ops.constant(logits))
- neg_log_prob = loss.evaluate_on_sample(42)
-
- # Simply ensure this doesn't crash. As the output is random, it's
- # difficult to say if the output is correct or not...
- neg_log_prob = sess.run(neg_log_prob)
-
- def testMultiplyFisherSingleVector(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- logits = np.array([1., 2., 3.])
- loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
-
- # The LossFunction.multiply_fisher docstring only says it supports the
- # case where the vector is the same shape as the input natural parameters
- # (i.e. the logits here), but here we also test leading dimensions.
- vector = np.array([1., 2., 3.])
- vectors = [vector, vector.reshape(1, -1), np.stack([vector] * 4)]
-
- probs = np.exp(logits - np.logaddexp.reduce(logits))
- fisher = np.diag(probs) - np.outer(probs, probs)
-
- for vector in vectors:
- result = loss.multiply_fisher(vector)
- expected_result = np.dot(vector, fisher)
- self.assertAllClose(expected_result, sess.run(result))
-
- def testMultiplyFisherBatch(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- logits = np.array([[1., 2., 3.], [4., 6., 8.]])
- loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
-
- vector = np.array([[1., 2., 3.], [5., 3., 1.]])
-
- na = np.newaxis
- probs = np.exp(logits - np.logaddexp.reduce(logits, axis=-1,
- keepdims=True))
- fishers = probs[..., na] * np.eye(3) - probs[..., na] * probs[..., na, :]
-
- result = loss.multiply_fisher(vector)
- expected_result = np.matmul(vector[..., na, :], fishers)[..., 0, :]
- self.assertEqual(sess.run(result).shape, logits.shape)
- self.assertAllClose(expected_result, sess.run(result))
-
-
-class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase):
-
- def testSample(self):
- """Ensure samples can be drawn."""
- with ops.Graph().as_default(), self.test_session() as sess:
- logits = np.asarray([
- [0., 0., 0.], #
- [1., -1., 0.]
- ]).astype(np.float32)
- loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
- array_ops.constant(logits))
- sample = loss.sample(42)
- sample = sess.run(sample)
- self.assertEqual(sample.shape, (2, 3))
-
- def testEvaluateOnTargets(self):
- """Ensure log probability can be evaluated correctly."""
- with ops.Graph().as_default(), self.test_session() as sess:
- logits = np.asarray([
- [0., 0., 0.], #
- [1., -1., 0.]
- ]).astype(np.float32)
- targets = np.asarray([2, 1]).astype(np.int32)
- loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
- array_ops.constant(logits), targets=array_ops.one_hot(targets, 3))
- neg_log_prob = loss.evaluate()
- neg_log_prob = sess.run(neg_log_prob)
-
- # Calculate explicit log probability of targets.
- probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
- log_probs = np.log([
- probs[0, targets[0]], #
- probs[1, targets[1]]
- ])
- expected_log_prob = np.sum(log_probs)
-
- self.assertAllClose(neg_log_prob, -expected_log_prob)
-
- def testEvaluateOnSample(self):
- """Ensure log probability of a sample can be drawn."""
- with ops.Graph().as_default(), self.test_session() as sess:
- logits = np.asarray([
- [0., 0., 0.], #
- [1., -1., 0.]
- ]).astype(np.float32)
- loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
- array_ops.constant(logits))
- neg_log_prob = loss.evaluate_on_sample(42)
-
- # Simply ensure this doesn't crash. As the output is random, it's
- # difficult to say if the output is correct or not...
- neg_log_prob = sess.run(neg_log_prob)
-
-if __name__ == "__main__":
- test.main()
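
The constants asserted in these loss tests follow directly from the softmax formulas the tests themselves spell out. For instance, the 1.6265233 expected by the layer_collection test earlier, and the Fisher product checked in testMultiplyFisherSingleVector above, can be reproduced in a few lines of NumPy:

    import numpy as np

    logits = np.array([[1., 2.], [3., 4.]])
    targets = np.array([0, 1])

    # Negative log-probability of the targets under softmax(logits).
    log_z = np.log(np.exp(logits).sum(axis=1))   # per-row log partition
    nll = (log_z - logits[np.arange(len(targets)), targets]).sum()
    print(nll)  # ~1.6265233, the constant asserted above

    # Fisher of a categorical in logit coordinates: diag(p) - p p^T.
    p = np.exp(logits[0] - log_z[0])
    fisher = np.diag(p) - np.outer(p, p)
    vec = np.array([1., 2.])
    print(fisher.dot(vec))  # what multiply_fisher computes for one vector
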
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py b/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py
deleted file mode 100644
index b20a70e4ca..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.op_queue."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.kfac.python.ops import op_queue
-from tensorflow.python.framework import ops as tf_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class OpQueueTest(test.TestCase):
-
- def testNextOp(self):
- """Ensures all ops get selected eventually."""
- with tf_ops.Graph().as_default():
- ops = [
- math_ops.add(1, 2),
- math_ops.subtract(1, 2),
- math_ops.reduce_mean([1, 2]),
- ]
- queue = op_queue.OpQueue(ops, seed=0)
-
- with self.test_session() as sess:
- # Ensure every inv update op gets selected.
- selected_ops = set([queue.next_op(sess) for _ in ops])
- self.assertEqual(set(ops), set(selected_ops))
-
- # Ensure additional calls don't create any new ops.
- selected_ops.add(queue.next_op(sess))
- self.assertEqual(set(ops), set(selected_ops))
-
-
-if __name__ == "__main__":
- test.main()
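
testNextOp above only checks that every op is eventually selected and that no new ops appear; the behavior under test is a seeded cyclic selection over a fixed op list. A self-contained sketch of that pattern (pure Python and unshuffled for clarity; the real OpQueue is dataset-backed, seeded, and takes a session argument):

    import itertools

    class RoundRobinQueue(object):
        """Cycles through a fixed set of items, like OpQueue.next_op."""

        def __init__(self, items):
            self._cycle = itertools.cycle(items)

        def next_op(self):
            return next(self._cycle)

    queue = RoundRobinQueue(['cov_update_0', 'cov_update_1', 'cov_update_2'])
    selected = set(queue.next_op() for _ in range(3))
    assert selected == {'cov_update_0', 'cov_update_1', 'cov_update_2'}
    # Additional calls revisit the same ops rather than creating new ones.
    assert queue.next_op() in selected
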
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py
deleted file mode 100644
index 560a9b0b42..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py
+++ /dev/null
@@ -1,219 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.optimizer."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
-from tensorflow.contrib.kfac.python.ops import layer_collection as lc
-from tensorflow.contrib.kfac.python.ops import optimizer
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables as tf_variables
-from tensorflow.python.platform import test
-
-
-# We need to set these constants since the numerical values used in the tests
-# were chosen when these used to be the defaults.
-ff.set_global_constants(init_covariances_at_zero=False,
- zero_debias=False,
- init_inverses_at_zero=False)
-
-
-def dummy_layer_collection():
- lcoll = lc.LayerCollection()
- dummy = array_ops.constant([1., 2.])
- lcoll.register_categorical_predictive_distribution(logits=dummy)
- return lcoll
-
-
-class OptimizerTest(test.TestCase):
-
- def testOptimizerInitInvalidMomentumRegistration(self):
- with self.assertRaises(ValueError):
- optimizer.KfacOptimizer(
- 0.1, 0.2, 0.3, lc.LayerCollection(), momentum_type='foo')
-
- def testOptimizerInit(self):
- with ops.Graph().as_default():
- layer_collection = lc.LayerCollection()
-
- inputs = array_ops.ones((2, 1)) * 2
- weights_val = np.ones((1, 1), dtype=np.float32) * 3.
- weights = variable_scope.get_variable(
- 'w', initializer=array_ops.constant(weights_val))
- bias = variable_scope.get_variable(
- 'b', initializer=init_ops.zeros_initializer(), shape=(1, 1))
- output = math_ops.matmul(inputs, weights) + bias
-
- layer_collection.register_fully_connected((weights, bias), inputs, output)
-
- logits = math_ops.tanh(output)
- targets = array_ops.constant([[0.], [1.]])
- output = math_ops.reduce_mean(
- nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets))
-
- layer_collection.register_categorical_predictive_distribution(logits)
-
- optimizer.KfacOptimizer(
- 0.1,
- 0.2,
- 0.3,
- layer_collection,
- momentum=0.5,
- momentum_type='regular')
-
- def testSquaredFisherNorm(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None),
- (array_ops.constant([[2., 3.], [4., 5.]]), None)]
- pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None),
- (array_ops.constant([[7., 8.], [9., 10.]]), None)]
- opt = optimizer.KfacOptimizer(0.1, 0.2, 0.3, dummy_layer_collection())
- sq_norm = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars)
- self.assertAlmostEqual(174., sess.run(sq_norm), places=5)
-
- def testUpdateClipCoeff(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None),
- (array_ops.constant([[2., 3.], [4., 5.]]), None)]
- pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None),
- (array_ops.constant([[7., 8.], [9., 10.]]), None)]
- lrate = 0.1
-
- # Note: without rescaling, the squared Fisher norm of the update
- # is 1.74
-
- # If the update already satisfies the norm constraint, there should
- # be no rescaling.
- opt = optimizer.KfacOptimizer(
- lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=10.)
- coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars)
- self.assertAlmostEqual(1., sess.run(coeff), places=5)
-
- # If the update violates the constraint, it should be rescaled to
- # be on the constraint boundary.
- opt = optimizer.KfacOptimizer(
- lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=0.5)
- coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars)
- sq_norm_pgrad = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars)
- sq_norm_update = lrate**2 * coeff**2 * sq_norm_pgrad
- self.assertAlmostEqual(0.5, sess.run(sq_norm_update), places=5)
-
- def testComputeUpdateStepsRegular(self):
- # TODO(olganw): implement this.
- pass
-
- def testComputeUpdateStepsAdam(self):
- # TODO(olganw): implement this.
- pass
-
- def testUpdateVelocities(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- layers = lc.LayerCollection()
- layers.register_categorical_predictive_distribution(
- array_ops.constant([1.0]))
- opt = optimizer.KfacOptimizer(
- 0.1, 0.2, 0.3, layers, momentum=0.5, momentum_type='regular')
- x = variable_scope.get_variable('x', initializer=array_ops.ones((2, 2)))
- y = variable_scope.get_variable(
- 'y', initializer=array_ops.ones((2, 2)) * 2)
- vec1 = array_ops.ones((2, 2)) * 3
- vec2 = array_ops.ones((2, 2)) * 4
-
- model_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- update_op = opt._update_velocities([(vec1, x), (vec2, y)], 0.5)
- opt_vars = [
- v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- if v not in model_vars
- ]
-
- sess.run(tf_variables.global_variables_initializer())
- old_opt_vars = sess.run(opt_vars)
-
- # Optimizer vars start out at 0.
- for opt_var in old_opt_vars:
- self.assertAllEqual(sess.run(array_ops.zeros_like(opt_var)), opt_var)
-
- sess.run(update_op)
- new_opt_vars = sess.run(opt_vars)
- # After one update, the velocities are equal to the vectors.
- for vec, opt_var in zip([vec1, vec2], new_opt_vars):
- self.assertAllEqual(sess.run(vec), opt_var)
-
- sess.run(update_op)
- final_opt_vars = sess.run(opt_vars)
- for first, second in zip(new_opt_vars, final_opt_vars):
- self.assertFalse(np.equal(first, second).all())
-
- def testApplyGradients(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- layer_collection = lc.LayerCollection()
-
- inputs = array_ops.ones((2, 1)) * 2
- weights_val = np.ones((1, 1), dtype=np.float32) * 3.
- weights = variable_scope.get_variable(
- 'w', initializer=array_ops.constant(weights_val))
- bias = variable_scope.get_variable(
- 'b', initializer=init_ops.zeros_initializer(), shape=(1, 1))
- output = math_ops.matmul(inputs, weights) + bias
-
- layer_collection.register_fully_connected((weights, bias), inputs, output)
-
- logits = math_ops.tanh(output)
- targets = array_ops.constant([[0.], [1.]])
- output = math_ops.reduce_mean(
- nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets))
-
- layer_collection.register_categorical_predictive_distribution(logits)
-
- opt = optimizer.KfacOptimizer(
- 0.1,
- 0.2,
- 0.3,
- layer_collection,
- momentum=0.5,
- momentum_type='regular')
- (cov_update_thunks,
- inv_update_thunks) = opt.make_vars_and_create_op_thunks()
- cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)
- inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)
-
- grads_and_vars = opt.compute_gradients(output, [weights, bias])
- all_vars = [grad_and_var[1] for grad_and_var in grads_and_vars]
-
- op = opt.apply_gradients(grads_and_vars)
-
- sess.run(tf_variables.global_variables_initializer())
- old_vars = sess.run(all_vars)
- sess.run(cov_update_ops)
- sess.run(inv_update_ops)
- sess.run(op)
- new_vars = sess.run(all_vars)
-
- for old_var, new_var in zip(old_vars, new_vars):
- self.assertNotEqual(old_var, new_var)
-
-
-if __name__ == '__main__':
- test.main()
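
The arithmetic behind testSquaredFisherNorm and testUpdateClipCoeff above is easy to check by hand: the squared Fisher norm is the summed inner product of the gradients with the preconditioned gradients, and the clip coefficient rescales an over-length update onto the constraint boundary. A NumPy check using the constants from the deleted tests (the min-with-1 form of the coefficient is an illustrative reconstruction, not copied from optimizer.py):

    import numpy as np

    grads = [np.array([[1., 2.], [3., 4.]]), np.array([[2., 3.], [4., 5.]])]
    pgrads = [np.array([[3., 4.], [5., 6.]]), np.array([[7., 8.], [9., 10.]])]

    # Squared Fisher norm of the update direction.
    sq_norm = sum((g * pg).sum() for g, pg in zip(grads, pgrads))
    print(sq_norm)  # 174.0, as asserted in testSquaredFisherNorm

    # Clip coefficient: scale so lrate^2 * coeff^2 * sq_norm hits the bound.
    lrate, norm_constraint = 0.1, 0.5
    coeff = min(1.0, np.sqrt(norm_constraint / (lrate ** 2 * sq_norm)))
    print(lrate ** 2 * coeff ** 2 * sq_norm)  # 0.5, the constraint boundary
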
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py
deleted file mode 100644
index 2cee01212a..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py
+++ /dev/null
@@ -1,410 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.utils."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import numpy.random as npr
-
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.contrib.tpu.python.tpu import tpu_function
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import random_seed
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import test
-
-
-class SequenceDictTest(test.TestCase):
-
- def testSequenceDictInit(self):
- seq_dict = utils.SequenceDict()
- self.assertFalse(seq_dict._dict)
-
- def testSequenceDictInitWithIterable(self):
- reg_dict = {'a': 'foo', 'b': 'bar'}
- itr = zip(reg_dict.keys(), reg_dict.values())
- seq_dict = utils.SequenceDict(itr)
- self.assertEqual(reg_dict, seq_dict._dict)
-
- def testGetItemSingleKey(self):
- seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'})
- self.assertEqual('foo', seq_dict['a'])
-
- def testGetItemMultipleKeys(self):
- seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'})
- self.assertEqual(['foo', 'bar'], seq_dict[('a', 'b')])
-
- def testSetItemSingleKey(self):
- seq_dict = utils.SequenceDict()
- seq_dict['a'] = 'foo'
- self.assertEqual([('a', 'foo')], seq_dict.items())
-
- def testSetItemMultipleKeys(self):
- seq_dict = utils.SequenceDict()
- keys = ('a', 'b', 'c')
- values = ('foo', 'bar', 'baz')
- seq_dict[keys] = values
- self.assertItemsEqual(list(zip(keys, values)), seq_dict.items())
-
-
-class SubGraphTest(test.TestCase):
-
- def testBasicGraph(self):
- a = array_ops.constant([[1., 2.], [3., 4.]])
- b = array_ops.constant([[5., 6.], [7., 8.]])
- c = a + b
- d = a * b
- sub_graph = utils.SubGraph((c,))
- self.assertTrue(sub_graph.is_member(a))
- self.assertTrue(sub_graph.is_member(b))
- self.assertTrue(sub_graph.is_member(c))
- self.assertFalse(sub_graph.is_member(d))
-
- def testRepeatedAdds(self):
- a = array_ops.constant([[1., 2.], [3., 4.]])
- b = array_ops.constant([[5., 6.], [7., 8.]])
- c = a + b + a # note that a appears twice in this graph
- sub_graph = utils.SubGraph((c,))
- self.assertTrue(sub_graph.is_member(a))
- self.assertTrue(sub_graph.is_member(b))
- self.assertTrue(sub_graph.is_member(c))
-
- def testFilterList(self):
- a = array_ops.constant([[1., 2.], [3., 4.]])
- b = array_ops.constant([[5., 6.], [7., 8.]])
- c = a + b
- d = a * b
- sub_graph = utils.SubGraph((c,))
- input_list = [b, d]
- filtered_list = sub_graph.filter_list(input_list)
- self.assertEqual(filtered_list, [b])
-
- def testVariableUses(self):
- with ops.Graph().as_default():
- var = variable_scope.get_variable('var', shape=[10, 10])
- resource_var = variable_scope.get_variable(
- 'resource_var', shape=[10, 10], use_resource=True)
- x = array_ops.zeros([3, 10])
- z0 = math_ops.matmul(x, var) + math_ops.matmul(x, var)
- z1 = math_ops.matmul(x, resource_var)
- sub_graph = utils.SubGraph((z0, z1))
- self.assertEqual(2, sub_graph.variable_uses(var))
- self.assertEqual(1, sub_graph.variable_uses(resource_var))
-
-
-class UtilsTest(test.TestCase):
-
- def _fully_connected_layer_params(self):
- weights_part = array_ops.constant([[1., 2.], [4., 3.]])
- bias_part = array_ops.constant([1., 2.])
- return (weights_part, bias_part)
-
- def _conv_layer_params(self):
- weights_shape = 2, 2, 3, 4
- biases_shape = weights_shape[-1:]
- weights = array_ops.constant(npr.RandomState(0).randn(*weights_shape))
- biases = array_ops.constant(npr.RandomState(1).randn(*biases_shape))
- return (weights, biases)
-
- def testFullyConnectedLayerParamsTupleToMat2d(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- layer_params = self._fully_connected_layer_params()
- output = utils.layer_params_to_mat2d(layer_params)
- self.assertListEqual([3, 2], output.get_shape().as_list())
- self.assertAllClose(
- sess.run(output), np.array([[1., 2.], [4., 3.], [1., 2.]]))
-
- def testFullyConnectedLayerParamsTensorToMat2d(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- layer_params = self._fully_connected_layer_params()
- output = utils.layer_params_to_mat2d(layer_params[0])
- self.assertListEqual([2, 2], output.get_shape().as_list())
- self.assertAllClose(sess.run(output), np.array([[1., 2.], [4., 3.]]))
-
- def testConvLayerParamsTupleToMat2d(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- layer_params = self._conv_layer_params()
- output = utils.layer_params_to_mat2d(layer_params)
- self.assertListEqual([2 * 2 * 3 + 1, 4], output.get_shape().as_list())
-
- def testKron(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- mat1 = np.array([[1., 2.], [3., 4.]])
- mat2 = np.array([[5., 6.], [7., 8.]])
- mat1_tf = array_ops.constant(mat1)
- mat2_tf = array_ops.constant(mat2)
- ans_tf = sess.run(utils.kronecker_product(mat1_tf, mat2_tf))
- ans_np = np.kron(mat1, mat2)
- self.assertAllClose(ans_tf, ans_np)
-
- def testMat2dToFullyConnectedLayerParamsTuple(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- vector_template = self._fully_connected_layer_params()
- mat2d = array_ops.constant([[5., 4.], [3., 2.], [1., 0.]])
-
- output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d))
-
- self.assertIsInstance(output, tuple)
- self.assertEqual(len(output), 2)
- a, b = output
- self.assertAllClose(a, np.array([[5., 4.], [3., 2.]]))
- self.assertAllClose(b, np.array([1., 0.]))
-
- def testMat2dToFullyConnectedLayerParamsTensor(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- vector_template = self._fully_connected_layer_params()[0]
- mat2d = array_ops.constant([[5., 4.], [3., 2.]])
-
- output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d))
-
- self.assertAllClose(output, np.array([[5., 4.], [3., 2.]]))
-
- def testTensorsToColumn(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
-
- vector = array_ops.constant(np.array([[0., 1.], [2., 3.]]))
- output = utils.tensors_to_column(vector)
- self.assertListEqual([4, 1], output.get_shape().as_list())
- self.assertAllClose(sess.run(output), np.array([0., 1., 2., 3.])[:, None])
-
- vector = self._fully_connected_layer_params()
- output = utils.tensors_to_column(vector)
- self.assertListEqual([6, 1], output.get_shape().as_list())
- self.assertAllClose(
- sess.run(output), np.array([1., 2., 4., 3., 1., 2.])[:, None])
-
- vector = list(vector)
- vector.append(array_ops.constant([[6.], [7.], [8.], [9.]]))
-
- output = utils.tensors_to_column(vector)
- self.assertListEqual([10, 1], output.get_shape().as_list())
- self.assertAllClose(
- sess.run(output),
- np.array([1., 2., 4., 3., 1., 2., 6., 7., 8., 9.])[:, None])
-
- def testColumnToTensors(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
-
- vector_template = array_ops.constant(np.array([[0., 1.], [2., 3.]]))
- colvec = array_ops.constant(np.arange(4.)[:, None])
- output = sess.run(utils.column_to_tensors(vector_template, colvec))
- self.assertAllClose(output, np.array([[0., 1.], [2., 3.]]))
-
- vector_template = self._fully_connected_layer_params()
- colvec = array_ops.constant(np.arange(6.)[:, None])
- output = sess.run(utils.column_to_tensors(vector_template, colvec))
-
- self.assertIsInstance(output, tuple)
- self.assertEqual(len(output), 2)
- a, b = output
- self.assertAllClose(a, np.array([[0., 1.], [2., 3.]]))
- self.assertAllClose(b, np.array([4., 5.]))
-
- vector_template = list(vector_template)
- vector_template.append(array_ops.constant([[6.], [7.], [8.], [9.]]))
- colvec = array_ops.constant(np.arange(10.)[:, None])
- output = sess.run(utils.column_to_tensors(vector_template, colvec))
- self.assertIsInstance(output, tuple)
- self.assertEqual(len(output), 3)
- a, b, c = output
- self.assertAllClose(a, np.array([[0., 1.], [2., 3.]]))
- self.assertAllClose(b, np.array([4., 5.]))
- self.assertAllClose(c, np.array([[6.], [7.], [8.], [9.]]))
-
- def testPosDefInvCholesky(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- npr.seed(0)
- square = lambda x: np.dot(x, x.T)
-
- size = 3
- x = square(npr.randn(size, size))
- damp = 0.1
- identity = linalg_ops.eye(size, dtype=dtypes.float64)
-
- tf_inv = utils.posdef_inv_cholesky(array_ops.constant(x), identity, damp)
- np_inv = np.linalg.inv(x + damp * np.eye(size))
- self.assertAllClose(sess.run(tf_inv), np_inv)
-
- def testPosDefInvMatrixInverse(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- random_seed.set_random_seed(200)
- npr.seed(0)
- square = lambda x: np.dot(x, x.T)
-
- size = 3
- x = square(npr.randn(size, size))
- damp = 0.1
- identity = linalg_ops.eye(size, dtype=dtypes.float64)
-
- tf_inv = utils.posdef_inv_matrix_inverse(
- array_ops.constant(x), identity, damp)
- np_inv = np.linalg.inv(x + damp * np.eye(size))
- self.assertAllClose(sess.run(tf_inv), np_inv)
-
- def testCrossReplicaMean(self):
- """Ensures that cross_replica_mean() executes only when num_shards > 1."""
- with ops.Graph().as_default():
- with tpu_function.tpu_shard_context(4):
- tensor = array_ops.zeros([], dtype=dtypes.float32)
- mean = utils.cross_replica_mean(tensor)
- self.assertNotEqual(mean, tensor)
-
- with ops.Graph().as_default():
- with tpu_function.tpu_shard_context(1):
- tensor = array_ops.zeros([], dtype=dtypes.float32)
- mean = utils.cross_replica_mean(tensor)
- self.assertEqual(mean, tensor)
-
- with ops.Graph().as_default():
- with self.assertRaises(ValueError): # Outside of TPU context.
- tensor = array_ops.zeros([], dtype=dtypes.float32)
- mean = utils.cross_replica_mean(tensor)
-
- def testBatchExecute(self):
- """Ensure batch_execute runs in a round-robin fashion."""
-
- def increment_var(var):
- return lambda: var.assign_add(1)
-
- with ops.Graph().as_default(), self.test_session() as sess:
- i = variable_scope.get_variable('i', initializer=0)
- accumulators = [
- variable_scope.get_variable('var%d' % j, initializer=0)
- for j in range(3)
- ]
- thunks = [increment_var(var) for var in accumulators]
- increment_accumulators = utils.batch_execute(i, thunks, 2)
- increment_i = i.assign_add(1)
-
- sess.run(variables.global_variables_initializer())
-
- # Ensure one op per thunk.
- self.assertEqual(3, len(increment_accumulators))
-
- # Ensure round-robin execution.
- values = []
- for _ in range(5):
- sess.run(increment_accumulators)
- sess.run(increment_i)
- values.append(sess.run(accumulators))
- self.assertAllClose(
- [
- [1, 1, 0], #
- [2, 1, 1], #
- [2, 2, 2], #
- [3, 3, 2], #
- [4, 3, 3]
- ],
- values)
-
- def testExtractConvolutionPatches(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- batch_size = 10
- image_spatial_shape = [9, 10, 11]
- in_channels = out_channels = 32
- kernel_spatial_shape = [5, 3, 3]
- spatial_strides = [1, 2, 1]
- spatial_dilation = [1, 1, 1]
- padding = 'SAME'
-
- images = random_ops.random_uniform(
- [batch_size] + image_spatial_shape + [in_channels], seed=0)
- kernel_shape = kernel_spatial_shape + [in_channels, out_channels]
- kernel = random_ops.random_uniform(kernel_shape, seed=1)
-
- # Ensure shape matches expectation.
- patches = utils.extract_convolution_patches(
- images,
- kernel_shape,
- padding,
- strides=spatial_strides,
- dilation_rate=spatial_dilation)
- result_spatial_shape = (
- patches.shape.as_list()[1:1 + len(image_spatial_shape)])
- self.assertEqual(patches.shape.as_list(),
- [batch_size] + result_spatial_shape +
- kernel_spatial_shape + [in_channels])
-
- # Ensure the extract...patches() + matmul() and convolution()
- # implementations give the same answer.
- outputs = nn_ops.convolution(
- images,
- kernel,
- padding,
- strides=spatial_strides,
- dilation_rate=spatial_dilation)
-
- patches_flat = array_ops.reshape(
- patches, [-1, np.prod(kernel_spatial_shape) * in_channels])
- kernel_flat = array_ops.reshape(kernel, [-1, out_channels])
- outputs_flat = math_ops.matmul(patches_flat, kernel_flat)
-
- outputs_, outputs_flat_ = sess.run([outputs, outputs_flat])
- self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten())
-
- def testExtractPointwiseConv2dPatches(self):
- with ops.Graph().as_default(), self.test_session() as sess:
- batch_size = 10
- image_height = image_width = 8
- in_channels = out_channels = 3
- kernel_height = kernel_width = 1
- strides = [1, 1, 1, 1]
- padding = 'VALID'
-
- images = random_ops.random_uniform(
- [batch_size, image_height, image_width, in_channels], seed=0)
- kernel_shape = [kernel_height, kernel_width, in_channels, out_channels]
- kernel = random_ops.random_uniform(kernel_shape, seed=1)
-
- # Ensure shape matches expectation.
- patches = utils.extract_pointwise_conv2d_patches(images, kernel_shape)
- self.assertEqual(patches.shape.as_list(), [
- batch_size, image_height, image_width, kernel_height, kernel_width,
- in_channels
- ])
-
- # Ensure the extract...patches() + matmul() and conv2d()
- # implementations give the same answer.
- outputs = nn_ops.conv2d(images, kernel, strides, padding)
-
- patches_flat = array_ops.reshape(
- patches, [-1, kernel_height * kernel_width * in_channels])
- kernel_flat = array_ops.reshape(kernel, [-1, out_channels])
- outputs_flat = math_ops.matmul(patches_flat, kernel_flat)
-
- outputs_, outputs_flat_ = sess.run([outputs, outputs_flat])
- self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten())
-
-
-if __name__ == '__main__':
- test.main()
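
Several of the utils tests above reduce to simple reshaping identities: layer_params_to_mat2d stacks a (weights, bias) tuple into one 2-D matrix with the bias as an extra input row, and tensors_to_column flattens a parameter structure into a single column vector. The NumPy equivalents, matching the expected arrays in the tests:

    import numpy as np

    weights = np.array([[1., 2.], [4., 3.]])
    bias = np.array([1., 2.])

    # layer_params_to_mat2d: append the bias as one extra input row.
    mat2d = np.vstack([weights, bias[None, :]])
    print(mat2d)  # [[1. 2.] [4. 3.] [1. 2.]], shape [3, 2] as in the test

    # tensors_to_column: flatten everything into a single [n, 1] column.
    column = np.concatenate([weights.reshape(-1), bias])[:, None]
    print(column.shape)  # (6, 1), holding [1, 2, 4, 3, 1, 2]
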
diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD
deleted file mode 100644
index 3c01eb65e7..0000000000
--- a/tensorflow/contrib/kfac/python/ops/BUILD
+++ /dev/null
@@ -1,263 +0,0 @@
-package(default_visibility = [
- "//tensorflow/contrib/kfac:__pkg__",
- "//tensorflow/contrib/kfac/python/kernel_tests:__pkg__",
-])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-py_library(
- name = "fisher_blocks",
- srcs = ["fisher_blocks.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":fisher_factors",
- ":utils",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:math_ops",
- "@six_archive//:six",
- ],
-)
-
-py_library(
- name = "fisher_blocks_lib",
- srcs = ["fisher_blocks_lib.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":fisher_blocks",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "fisher_factors",
- srcs = ["fisher_factors.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":linear_operator",
- ":utils",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:linalg_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:special_math_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//third_party/py/numpy",
- "@six_archive//:six",
- ],
-)
-
-py_library(
- name = "fisher_factors_lib",
- srcs = ["fisher_factors_lib.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":fisher_factors",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "linear_operator",
- srcs = ["linear_operator.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":utils",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python/ops/linalg",
- "@six_archive//:six",
- ],
-)
-
-py_library(
- name = "loss_functions",
- srcs = ["loss_functions.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/distributions:distributions_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/ops/distributions",
- "@six_archive//:six",
- ],
-)
-
-py_library(
- name = "loss_functions_lib",
- srcs = ["loss_functions_lib.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":loss_functions",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "curvature_matrix_vector_products",
- srcs = ["curvature_matrix_vector_products.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":utils",
- "//tensorflow/python:gradients",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "curvature_matrix_vector_products_lib",
- srcs = ["curvature_matrix_vector_products_lib.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":curvature_matrix_vector_products",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "layer_collection",
- srcs = ["layer_collection.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":fisher_blocks",
- ":loss_functions",
- ":utils",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:util",
- "//tensorflow/python:variable_scope",
- "@six_archive//:six",
- ],
-)
-
-py_library(
- name = "layer_collection_lib",
- srcs = ["layer_collection_lib.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":layer_collection",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "kfac_optimizer",
- srcs = [
- "optimizer.py",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":curvature_matrix_vector_products",
- ":fisher_estimator",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:linalg_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- ],
-)
-
-py_library(
- name = "kfac_optimizer_lib",
- srcs = [
- "optimizer_lib.py",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":kfac_optimizer",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "fisher_estimator",
- srcs = [
- "estimator.py",
- "placement.py",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":utils",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:gradients",
- "//tensorflow/python:util",
- "//third_party/py/numpy",
- "@six_archive//:six",
- ],
-)
-
-py_library(
- name = "fisher_estimator_lib",
- srcs = [
- "estimator_lib.py",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":fisher_estimator",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "utils",
- srcs = ["utils.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/tpu",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:gradients",
- "//tensorflow/python:linalg_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_library(
- name = "utils_lib",
- srcs = ["utils_lib.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":utils",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "op_queue",
- srcs = ["op_queue.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:dataset_ops",
- "//tensorflow/python:framework_ops",
- ],
-)
-
-py_library(
- name = "op_queue_lib",
- srcs = ["op_queue_lib.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":op_queue",
- "//tensorflow/python:util",
- ],
-)
diff --git a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py
deleted file mode 100644
index 21b5cde9b9..0000000000
--- a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py
+++ /dev/null
@@ -1,183 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Curvature matrix-vector multiplication."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import math_ops
-from tensorflow.python.util import nest
-
-
-class CurvatureMatrixVectorProductComputer(object):
- """Class for computing matrix-vector products for Fishers, GGNs and Hessians.
-
- In other words, we compute M*v where M is the matrix, v is the vector, and
- * refers to standard matrix/vector multiplication (not element-wise
- multiplication).
-
- The matrices are defined in terms of some differential quantity of the total
- loss function with respect to a provided list of tensors ("wrt_tensors").
- For example, the Fisher associated with a log-prob loss w.r.t. the
- parameters.
-
- The 'vecs' arguments to each method are lists of tensors that must be the
- same size as the corresponding ones from "wrt_tensors". They represent
- the vector being multiplied.
-
- "factors" of the matrix M are defined as matrices B such that B*B^T = M.
- Methods that multiply by the factor B take a 'loss_inner_vecs' argument
- instead of 'vecs', which must be a list of tensors with shapes given by the
- corresponding XXX_inner_shapes property.
-
- Note that matrix-vector products are not normalized by the batch size, nor
- are any damping terms added to the results. These things can be easily
- applied externally, if desired.
-
- See for example: www.cs.utoronto.ca/~jmartens/docs/HF_book_chapter.pdf
- and https://arxiv.org/abs/1412.1193 for more information about the
- generalized Gauss-Newton, Fisher, etc., and how to compute matrix-vector
- products.
- """
-
- def __init__(self, losses, wrt_tensors):
- """Create a CurvatureMatrixVectorProductComputer object.
-
- Args:
- losses: A list of LossFunction instances whose sum defines the total loss.
- wrt_tensors: A list of Tensors to compute the differential quantities
- (defining the matrices) with respect to. See class description for more
- info.
- """
- self._losses = losses
- self._inputs_to_losses = list(loss.inputs for loss in losses)
- self._inputs_to_losses_flat = nest.flatten(self._inputs_to_losses)
- self._wrt_tensors = wrt_tensors
-
- @property
- def _total_loss(self):
- return math_ops.add_n(tuple(loss.evaluate() for loss in self._losses))
-
- # Jacobian multiplication functions:
- def _multiply_jacobian(self, vecs):
- """Multiply vecs by the Jacobian of losses."""
- # We stop gradients at wrt_tensors to produce partial derivatives (which is
- # what we want for Jacobians).
- jacobian_vecs_flat = utils.fwd_gradients(
- self._inputs_to_losses_flat, self._wrt_tensors, grad_xs=vecs,
- stop_gradients=self._wrt_tensors)
- return nest.pack_sequence_as(self._inputs_to_losses, jacobian_vecs_flat)
-
- def _multiply_jacobian_transpose(self, loss_vecs):
- """Multiply vecs by the transpose Jacobian of losses."""
- loss_vecs_flat = nest.flatten(loss_vecs)
- # We stop gradients at wrt_tensors to produce partial derivatives (which is
- # what we want for Jacobians).
- return gradients_impl.gradients(
- self._inputs_to_losses_flat, self._wrt_tensors, grad_ys=loss_vecs_flat,
- stop_gradients=self._wrt_tensors)
-
- # Losses Fisher/Hessian multiplication functions:
- def _multiply_loss_fisher(self, loss_vecs):
- """Multiply loss_vecs by Fisher of total loss."""
- return tuple(
- loss.multiply_fisher(loss_vec)
- for loss, loss_vec in zip(self._losses, loss_vecs))
-
- def _multiply_loss_fisher_factor(self, loss_inner_vecs):
- """Multiply loss_inner_vecs by factor of Fisher of total loss."""
- return tuple(
- loss.multiply_fisher_factor(loss_vec)
- for loss, loss_vec in zip(self._losses, loss_inner_vecs))
-
- def _multiply_loss_fisher_factor_transpose(self, loss_vecs):
- """Multiply loss_vecs by transpose factor of Fisher of total loss."""
- return tuple(
- loss.multiply_fisher_factor_transpose(loss_vec)
- for loss, loss_vec in zip(self._losses, loss_vecs))
-
- def _multiply_loss_hessian(self, loss_vecs):
- """Multiply loss_vecs by Hessian of total loss."""
- return tuple(
- loss.multiply_hessian(loss_vec)
- for loss, loss_vec in zip(self._losses, loss_vecs))
-
- def _multiply_loss_hessian_factor(self, loss_inner_vecs):
- """Multiply loss_inner_vecs by factor of Hessian of total loss."""
- return tuple(
- loss.multiply_hessian_factor(loss_vec)
- for loss, loss_vec in zip(self._losses, loss_inner_vecs))
-
- def _multiply_loss_hessian_factor_transpose(self, loss_vecs):
- """Multiply loss_vecs by transpose factor of Hessian of total loss."""
- return tuple(
- loss.multiply_hessian_factor_transpose(loss_vec)
- for loss, loss_vec in zip(self._losses, loss_vecs))
-
- # Matrix-vector product functions:
- def multiply_fisher(self, vecs):
- """Multiply vecs by Fisher of total loss."""
- jacobian_vecs = self._multiply_jacobian(vecs)
- loss_fisher_jacobian_vecs = self._multiply_loss_fisher(jacobian_vecs)
- return self._multiply_jacobian_transpose(loss_fisher_jacobian_vecs)
-
- def multiply_fisher_factor_transpose(self, vecs):
- """Multiply vecs by transpose of factor of Fisher of total loss."""
- jacobian_vecs = self._multiply_jacobian(vecs)
- return self._multiply_loss_fisher_factor_transpose(jacobian_vecs)
-
- def multiply_fisher_factor(self, loss_inner_vecs):
- """Multiply loss_inner_vecs by factor of Fisher of total loss."""
- fisher_factor_transpose_vecs = self._multiply_loss_fisher_factor_transpose(
- loss_inner_vecs)
- return self._multiply_jacobian_transpose(fisher_factor_transpose_vecs)
-
- def multiply_hessian(self, vecs):
- """Multiply vecs by Hessian of total loss."""
- return gradients_impl.gradients(
- gradients_impl.gradients(self._total_loss, self._wrt_tensors),
- self._wrt_tensors,
- grad_ys=vecs)
-
- def multiply_generalized_gauss_newton(self, vecs):
- """Multiply vecs by generalized Gauss-Newton of total loss."""
- jacobian_vecs = self._multiply_jacobian(vecs)
- loss_hessian_jacobian_vecs = self._multiply_loss_hessian(jacobian_vecs)
- return self._multiply_jacobian_transpose(loss_hessian_jacobian_vecs)
-
- def multiply_generalized_gauss_newton_factor_transpose(self, vecs):
- """Multiply vecs by transpose of factor of GGN of total loss."""
- jacobian_vecs = self._multiply_jacobian(vecs)
- return self._multiply_loss_hessian_factor_transpose(jacobian_vecs)
-
- def multiply_generalized_gauss_newton_factor(self, loss_inner_vecs):
- """Multiply loss_inner_vecs by factor of GGN of total loss."""
- hessian_factor_transpose_vecs = (
- self._multiply_loss_hessian_factor_transpose(loss_inner_vecs))
- return self._multiply_jacobian_transpose(hessian_factor_transpose_vecs)
-
- # Shape properties for multiply_XXX_factor methods:
- @property
- def fisher_factor_inner_shapes(self):
- """Shapes required by multiply_fisher_factor."""
- return tuple(loss.fisher_factor_inner_shape for loss in self._losses)
-
- @property
- def generalized_gauss_newton_factor_inner_shapes(self):
- """Shapes required by multiply_generalized_gauss_newton_factor."""
- return tuple(loss.hessian_factor_inner_shape for loss in self._losses)
diff --git a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py
deleted file mode 100644
index 6e8c6404dc..0000000000
--- a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Curvature matrix-vector multiplication."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.curvature_matrix_vector_products import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- 'CurvatureMatrixVectorProductComputer',
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py
deleted file mode 100644
index 323234c403..0000000000
--- a/tensorflow/contrib/kfac/python/ops/estimator.py
+++ /dev/null
@@ -1,516 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Defines the high-level Fisher estimator class."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-import numpy as np
-import six
-
-from tensorflow.contrib.kfac.python.ops import placement
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import ops as tf_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.util import nest
-
-
-# The linter is confused.
-# pylint: disable=abstract-class-instantiated
-def make_fisher_estimator(placement_strategy=None, **kwargs):
- """Creates Fisher estimator instances based on the placement strategy.
-
- For example, if the `placement_strategy` is 'round_robin' then a
- `FisherEstimatorRoundRobin` instance is returned.
-
- Args:
- placement_strategy: `string`, Strategy to be used for placing covariance
- variables, covariance ops and inverse ops. Check
- `placement.FisherEstimatorRoundRobin` for a concrete example.
- **kwargs: Arguments to be passed into `FisherEstimator` class initializer.
-
- Returns:
- An instance of a class which inherits from `FisherEstimator` and the
- mixin which implements the specific placement strategy. See
- `FisherEstimatorRoundRobin`, which inherits from `FisherEstimator` and
- `RoundRobinPlacementMixin`.
-
- Raises:
- ValueError: If the `placement_strategy` is not equal to 'round_robin'.
- """
- if placement_strategy in [None, "round_robin"]:
- return FisherEstimatorRoundRobin(**kwargs)
- else:
- raise ValueError("Unimplemented vars and ops "
- "placement strategy : {}".format(placement_strategy))
-# pylint: enable=abstract-class-instantiated
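# A minimal usage sketch of the factory above. The success path needs a real
# `layer_collection` and `variables` (hypothetical names here), so it is only
# sketched in the comment; the error path is self-contained:
#
#   estimator = make_fisher_estimator(
#       placement_strategy="round_robin",
#       variables=variables,
#       cov_ema_decay=0.95,
#       damping=1e-3,
#       layer_collection=layer_collection)
#
try:
  make_fisher_estimator(placement_strategy="no_such_strategy")
except ValueError:
  pass  # Anything other than None or "round_robin" is rejected.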
-
-
-@six.add_metaclass(abc.ABCMeta)
-class FisherEstimator(object):
- """Fisher estimator class supporting various approximations of the Fisher.
-
- This is an abstract base class which does not implement a strategy for
- placing covariance variables, covariance update ops and inverse update ops.
- The placement strategies are implemented in `placement.py`. See
- `FisherEstimatorRoundRobin` for example of a concrete subclass with
- a round-robin placement strategy.
- """
-
- def __init__(self,
- variables,
- cov_ema_decay,
- damping,
- layer_collection,
- exps=(-1,),
- estimation_mode="gradients",
- colocate_gradients_with_ops=True,
- name="FisherEstimator",
- compute_cholesky=False,
- compute_cholesky_inverse=False):
- """Create a FisherEstimator object.
-
- Args:
- variables: A `list` of variables or `callable` which returns the variables
- for which to estimate the Fisher. This must match the variables
- registered in layer_collection (if it is not None).
- cov_ema_decay: The decay factor used when calculating the covariance
- estimate moving averages.
- damping: float. The damping factor used to stabilize training due to
- errors in the local approximation with the Fisher information matrix,
- and to regularize the update direction by making it closer to the
- gradient. (Higher damping means the update looks more like a standard
- gradient update - see Tikhonov regularization.)
- layer_collection: The layer collection object, which holds the Fisher
- blocks, Kronecker factors, and losses associated with the
- graph.
- exps: List of floats or ints. These represent the different matrix
- powers of the approximate Fisher that the FisherEstimator will be able
- to multiply vectors by. If the user asks for a matrix power other than
- one of these (or 1, which is always supported), there will be a
- failure. (Default: (-1,))
- estimation_mode: The type of estimator to use for the Fishers. Can be
- 'gradients', 'empirical', 'curvature_prop', or 'exact'.
- (Default: 'gradients'). 'gradients' is the basic estimation approach
- from the original K-FAC paper. 'empirical' computes the 'empirical'
- Fisher information matrix (which uses the data's distribution for the
- targets, as opposed to the true Fisher which uses the model's
- distribution) and requires that each registered loss have specified
- targets. 'curvature_prop' is a method which estimates the
- Fisher using self-products of random 1/-1 vectors times "half-factors"
- of the Fisher, as described here: https://arxiv.org/abs/1206.6464 .
- Finally, 'exact' is the obvious generalization of Curvature
- Propagation to compute the exact Fisher (modulo any additional
- diagonal or Kronecker approximations) by looping over one-hot vectors
- for each coordinate of the output instead of using 1/-1 vectors. It
- is more expensive to compute than the other three options by a factor
- equal to the output dimension, roughly speaking.
- colocate_gradients_with_ops: Whether we should request gradients be
- colocated with their respective ops. (Default: True)
- name: A string. A name given to this estimator, which is added to the
- variable scope when constructing variables and ops.
- (Default: "FisherEstimator")
- compute_cholesky: Bool. Whether or not the FisherEstimator will be
- able to multiply vectors by the Cholesky factor.
- (Default: False)
- compute_cholesky_inverse: Bool. Whether or not the FisherEstimator
- will be able to multiply vectors by the Cholesky factor inverse.
- (Default: False)
- Raises:
- ValueError: If no losses have been registered with layer_collection.
- """
- self._variables = variables
- self._cov_ema_decay = cov_ema_decay
- self._damping = damping
- self._estimation_mode = estimation_mode
- self._layers = layer_collection
- self._gradient_fns = {
- "gradients": self._get_grads_lists_gradients,
- "empirical": self._get_grads_lists_empirical,
- "curvature_prop": self._get_grads_lists_curvature_prop,
- "exact": self._get_grads_lists_exact
- }
- self._colocate_gradients_with_ops = colocate_gradients_with_ops
-
- self._made_vars = False
- self._exps = exps
- self._compute_cholesky = compute_cholesky
- self._compute_cholesky_inverse = compute_cholesky_inverse
-
- self._name = name
-
- @property
- def variables(self):
- if callable(self._variables):
- return self._variables()
- else:
- return self._variables
-
- @property
- def damping(self):
- return self._damping
-
- @property
- def blocks(self):
- """All registered FisherBlocks."""
- return self._layers.get_blocks()
-
- @property
- def factors(self):
- """All registered FisherFactors."""
- return self._layers.get_factors()
-
- @property
- def name(self):
- return self._name
-
- @abc.abstractmethod
- def make_vars_and_create_op_thunks(self, scope=None):
- """Make vars and create op thunks with a specific placement strategy.
-
- For each factor, all of that factor's cov variables and their associated
- update ops will be placed on a particular device. A new device is chosen
- for each factor by cycling through the list of devices in the cov_devices
- argument. If cov_devices is None then no explicit device placement occurs.
-
- An analogous strategy is followed for inverse update ops, with the list of
- devices being given by the inv_devices argument.
-
- Inverse variables, on the other hand, are not placed on any specific device
- (they will just use the current device placement context, whatever
- that happens to be). The idea is that inverse variables belong where
- they will be accessed most often, which is the device that actually applies
- the preconditioner to the gradient. The user will be responsible for setting
- the device context for this.
-
- Args:
- scope: A string or None. If None it will be set to the name of this
- estimator (given by the name property). All variables will be created,
- and all thunks will execute, inside of a variable scope of the given
- name. (Default: None)
-
- Returns:
- cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
- the list of factors given by the "factors" property.
- inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
- the list of factors given by the "factors" property.
- """
- pass
-
- def _apply_transformation(self, vecs_and_vars, transform):
- """Applies an block-wise transformation to the corresponding vectors.
-
- Args:
- vecs_and_vars: List of (vector, variable) pairs.
- transform: A function of the form f(fb, vec), where vec is the vector
- to transform and fb is its corresponding block in the matrix, that
- returns the transformed vector.
-
- Returns:
- A list of (transformed vector, var) pairs in the same order as
- vecs_and_vars.
- """
-
- vecs = utils.SequenceDict((var, vec) for vec, var in vecs_and_vars)
-
- trans_vecs = utils.SequenceDict()
-
- for params, fb in self._layers.fisher_blocks.items():
- trans_vecs[params] = transform(fb, vecs[params])
-
- return [(trans_vecs[var], var) for _, var in vecs_and_vars]
-
- def multiply_inverse(self, vecs_and_vars):
- """Multiplies the vecs by the corresponding (damped) inverses of the blocks.
-
- Args:
- vecs_and_vars: List of (vector, variable) pairs.
-
- Returns:
- A list of (transformed vector, var) pairs in the same order as
- vecs_and_vars.
- """
- return self.multiply_matpower(-1, vecs_and_vars)
-
- def multiply(self, vecs_and_vars):
- """Multiplies the vectors by the corresponding (damped) blocks.
-
- Args:
- vecs_and_vars: List of (vector, variable) pairs.
-
- Returns:
- A list of (transformed vector, var) pairs in the same order as
- vecs_and_vars.
- """
- return self.multiply_matpower(1, vecs_and_vars)
-
- def multiply_matpower(self, exp, vecs_and_vars):
- """Multiplies the vecs by the corresponding matrix powers of the blocks.
-
- Args:
- exp: A float representing the power to raise the blocks by before
- multiplying it by the vector.
- vecs_and_vars: List of (vector, variable) pairs.
-
- Returns:
- A list of (transformed vector, var) pairs in the same order as
- vecs_and_vars.
- """
- assert exp in self._exps
-
- fcn = lambda fb, vec: fb.multiply_matpower(vec, exp)
- return self._apply_transformation(vecs_and_vars, fcn)
-
- def multiply_cholesky(self, vecs_and_vars, transpose=False):
- """Multiplies the vecs by the corresponding Cholesky factors.
-
- Args:
- vecs_and_vars: List of (vector, variable) pairs.
- transpose: Bool. If true the Cholesky factors are transposed before
- multiplying the vecs. (Default: False)
-
- Returns:
- A list of (transformed vector, var) pairs in the same order as
- vecs_and_vars.
- """
- assert self._compute_cholesky
-
- fcn = lambda fb, vec: fb.multiply_cholesky(vec, transpose=transpose)
- return self._apply_transformation(vecs_and_vars, fcn)
-
- def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False):
- """Mults the vecs by the inverses of the corresponding Cholesky factors.
-
- Note: if you are using Cholesky inverse multiplication to sample from
- a matrix-variate Gaussian you will want to multiply by the transpose.
- Let L be the Cholesky factor of F and observe that
-
- L^-T * L^-1 = (L * L^T)^-1 = F^-1 .
-
- Thus we want to multiply by L^-T in order to sample from a Gaussian with
- covariance F^-1.
-
- Args:
- vecs_and_vars: List of (vector, variable) pairs.
- transpose: Bool. If true the Cholesky factor inverses are transposed
- before multiplying the vecs. (Default: False)
-
- Returns:
- A list of (transformed vector, var) pairs in the same order as
- vecs_and_vars.
- """
- assert self._compute_cholesky_inverse
-
- fcn = lambda fb, vec: fb.multiply_cholesky_inverse(vec, transpose=transpose)
- return self._apply_transformation(vecs_and_vars, fcn)
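# A self-contained numpy check (an illustrative sketch, independent of the
# classes in this file) of the identity quoted above: for a Cholesky factor L
# of F, multiplying by L^-1 and then by L^-T applies F^-1.
import numpy as np

_rng = np.random.RandomState(0)
_a = _rng.randn(4, 4)
_f = _a.dot(_a.T) + 4.0 * np.eye(4)  # A symmetric positive-definite "Fisher".
_l = np.linalg.cholesky(_f)          # F = L L^T.
_lhs = np.linalg.inv(_l).T.dot(np.linalg.inv(_l))
assert np.allclose(_lhs, np.linalg.inv(_f))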
-
- def _instantiate_factors(self):
- """Instantiates FisherFactors' variables.
-
- Raises:
- ValueError: If estimation_mode was improperly specified at construction.
- """
- blocks = self.blocks
- tensors_to_compute_grads = [
- block.tensors_to_compute_grads() for block in blocks
- ]
-
- try:
- grads_lists = self._gradient_fns[self._estimation_mode](
- tensors_to_compute_grads)
- except KeyError:
- raise ValueError("Unrecognized value {} for estimation_mode.".format(
- self._estimation_mode))
-
- for grads_list, block in zip(grads_lists, blocks):
- block.instantiate_factors(grads_list, self.damping)
-
- def _check_vars_unmade_and_set_made_flag(self):
- if self._made_vars:
- raise Exception("Already made variables.")
- self._made_vars = True
-
- def made_vars(self):
- return self._made_vars
-
- def _register_matrix_functions(self):
- for block in self.blocks:
- for exp in self._exps:
- block.register_matpower(exp)
- if self._compute_cholesky:
- block.register_cholesky()
- if self._compute_cholesky_inverse:
- block.register_cholesky_inverse()
-
- def _finalize_layer_collection(self):
- self._layers.create_subgraph()
- self._layers.check_registration(self.variables)
- self._instantiate_factors()
- self._register_matrix_functions()
-
- def create_ops_and_vars_thunks(self, scope=None):
- """Create thunks that make the ops and vars on demand.
-
- This function returns 4 lists of thunks: cov_variable_thunks,
- cov_update_thunks, inv_variable_thunks, and inv_update_thunks.
-
- The length of each list is the number of factors and the i-th element of
- each list corresponds to the i-th factor (given by the "factors" property).
-
- Note that the execution of these thunks must happen in a certain
- partial order. The i-th element of cov_variable_thunks must execute
- before the i-th element of cov_update_thunks (and also the i-th element
- of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks
- must execute before the i-th element of inv_update_thunks.
-
- TL;DR (oversimplified): Execute the thunks according to the order that
- they are returned.
-
- Args:
- scope: A string or None. If None it will be set to the name of this
- estimator (given by the name property). All thunks will execute inside
- of a variable scope of the given name. (Default: None)
- Returns:
- cov_variable_thunks: A list of thunks that make the cov variables.
- cov_update_thunks: A list of thunks that make the cov update ops.
- inv_variable_thunks: A list of thunks that make the inv variables.
- inv_update_thunks: A list of thunks that make the inv update ops.
- """
- self._check_vars_unmade_and_set_made_flag()
-
- self._finalize_layer_collection()
-
- scope = self.name if scope is None else scope
-
- cov_variable_thunks = [
- self._create_cov_variable_thunk(factor, scope)
- for factor in self.factors
- ]
- cov_update_thunks = [
- self._create_cov_update_thunk(factor, scope) for factor in self.factors
- ]
- inv_variable_thunks = [
- self._create_inv_variable_thunk(factor, scope)
- for factor in self.factors
- ]
- inv_update_thunks = [
- self._create_inv_update_thunk(factor, scope) for factor in self.factors
- ]
-
- return (cov_variable_thunks, cov_update_thunks,
- inv_variable_thunks, inv_update_thunks)
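# A sketch of the partial order described in the docstring above, assuming a
# hypothetical concrete `estimator` instance. Running all variable thunks
# before any update thunks satisfies the required ordering:
#
#   (cov_var_thunks, cov_update_thunks,
#    inv_var_thunks, inv_update_thunks) = estimator.create_ops_and_vars_thunks()
#   for thunk in cov_var_thunks + inv_var_thunks:
#     thunk()  # Make all cov and inv variables first.
#   cov_update_ops = [thunk() for thunk in cov_update_thunks]
#   inv_update_ops = [thunk() for thunk in inv_update_thunks]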
-
- def _create_cov_variable_thunk(self, factor, scope):
- """Constructs a covariance variable thunk for a single FisherFactor."""
-
- def thunk():
- with variable_scope.variable_scope(scope):
- return factor.instantiate_cov_variables()
-
- return thunk
-
- def _create_cov_update_thunk(self, factor, scope):
- """Constructs a covariance update thunk for a single FisherFactor."""
-
- def thunk():
- with variable_scope.variable_scope(scope):
- return factor.make_covariance_update_op(self._cov_ema_decay)
-
- return thunk
-
- def _create_inv_variable_thunk(self, factor, scope):
- """Constructs a inverse variable thunk for a single FisherFactor."""
-
- def thunk():
- with variable_scope.variable_scope(scope):
- return factor.instantiate_inv_variables()
-
- return thunk
-
- def _create_inv_update_thunk(self, factor, scope):
- """Constructs an inverse update thunk for a single FisherFactor."""
-
- def thunk():
- with variable_scope.variable_scope(scope):
- return control_flow_ops.group(factor.make_inverse_update_ops())
-
- return thunk
-
- def _get_grads_lists_gradients(self, tensors):
- # Passing in a list of loss values is better than passing in the sum as
- # the latter creates unnecessary ops on the default device
- grads_flat = gradients_impl.gradients(
- self._layers.eval_losses_on_samples(),
- nest.flatten(tensors),
- colocate_gradients_with_ops=self._colocate_gradients_with_ops)
- grads_all = nest.pack_sequence_as(tensors, grads_flat)
- return tuple((grad,) for grad in grads_all)
-
- def _get_grads_lists_empirical(self, tensors):
- # Passing in a list of loss values is better than passing in the sum as
- # the latter creates unnecessary ops on the default device
- grads_flat = gradients_impl.gradients(
- self._layers.eval_losses(),
- nest.flatten(tensors),
- colocate_gradients_with_ops=self._colocate_gradients_with_ops)
- grads_all = nest.pack_sequence_as(tensors, grads_flat)
- return tuple((grad,) for grad in grads_all)
-
- def _get_transformed_random_signs(self):
- transformed_random_signs = []
- for loss in self._layers.losses:
- with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]):
- transformed_random_signs.append(
- loss.multiply_fisher_factor(
- utils.generate_random_signs(loss.fisher_factor_inner_shape)))
- return transformed_random_signs
-
- def _get_grads_lists_curvature_prop(self, tensors):
- loss_inputs = list(loss.inputs for loss in self._layers.losses)
- transformed_random_signs = self._get_transformed_random_signs()
- grads_flat = gradients_impl.gradients(
- nest.flatten(loss_inputs),
- nest.flatten(tensors),
- grad_ys=nest.flatten(transformed_random_signs),
- colocate_gradients_with_ops=self._colocate_gradients_with_ops)
- grads_all = nest.pack_sequence_as(tensors, grads_flat)
- return tuple((grad,) for grad in grads_all)
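# A self-contained numpy illustration (a sketch, not the production code path)
# of why multiplying "half factors" by random +/-1 vectors estimates the
# Fisher: if F = S S^T and r has independent uniform +/-1 entries, then
# E[(S r)(S r)^T] = S E[r r^T] S^T = S S^T = F. Enumerating all sign vectors
# computes the expectation exactly for a small example.
import itertools

import numpy as np

_s = np.random.RandomState(2).randn(3, 3)  # A hypothetical half factor.
_f = _s.dot(_s.T)
_signs = list(itertools.product([-1.0, 1.0], repeat=3))
_est = np.zeros((3, 3))
for _r in _signs:
  _v = _s.dot(np.asarray(_r).reshape(3, 1))
  _est += _v.dot(_v.T) / len(_signs)
assert np.allclose(_est, _f)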
-
- def _get_grads_lists_exact(self, tensors):
- """No docstring required."""
- # Loop over all coordinates of all losses.
- grads_all = []
- for loss in self._layers.losses:
- with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]):
- for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]):
- transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot(
- index)
- grads_flat = gradients_impl.gradients(
- loss.inputs,
- nest.flatten(tensors),
- grad_ys=transformed_one_hot,
- colocate_gradients_with_ops=self._colocate_gradients_with_ops)
- grads_all.append(nest.pack_sequence_as(tensors, grads_flat))
- return zip(*grads_all)
-
-
-class FisherEstimatorRoundRobin(placement.RoundRobinPlacementMixin,
- FisherEstimator):
- """Fisher estimator which provides round robin device placement strategy."""
- pass
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
deleted file mode 100644
index 9fa6eb7dcd..0000000000
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ /dev/null
@@ -1,1752 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""FisherBlock definitions.
-
-This library contains classes for estimating blocks in a model's Fisher
-Information matrix. Suppose one has a model that parameterizes a posterior
-distribution over 'y' given 'x' with parameters 'params', p(y | x, params). Its
-Fisher Information matrix is given by,
-
- $$F(params) = E[ v(x, y, params) v(x, y, params)^T ]$$
-
-where,
-
- $$v(x, y, params) = (d / d params) log p(y | x, params)$$
-
-and the expectation is taken with respect to the data's distribution for 'x' and
-the model's posterior distribution for 'y',
-
- x ~ p(x)
- y ~ p(y | x, params)
-
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-import enum # pylint: disable=g-bad-import-order
-
-import numpy as np
-import six
-
-from tensorflow.contrib.kfac.python.ops import fisher_factors
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.util import nest
-
-# For blocks corresponding to convolutional layers, or any type of block where
-# the parameters can be thought of as being replicated in time or space,
-# we want to adjust the scale of the damping by
-# damping /= num_replications ** NORMALIZE_DAMPING_POWER
-NORMALIZE_DAMPING_POWER = 1.0
-
-# Methods for adjusting damping for FisherBlocks. See
-# compute_pi_adjusted_damping() for details.
-PI_OFF_NAME = "off"
-PI_TRACENORM_NAME = "tracenorm"
-PI_TYPE = PI_TRACENORM_NAME
-
-
-def set_global_constants(normalize_damping_power=None, pi_type=None):
- """Sets various global constants used by the classes in this module."""
- global NORMALIZE_DAMPING_POWER
- global PI_TYPE
-
- if normalize_damping_power is not None:
- NORMALIZE_DAMPING_POWER = normalize_damping_power
-
- if pi_type is not None:
- PI_TYPE = pi_type
-
-
-def normalize_damping(damping, num_replications):
- """Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER."""
- if NORMALIZE_DAMPING_POWER:
- return damping / (num_replications ** NORMALIZE_DAMPING_POWER)
- return damping
-
-
-def compute_pi_tracenorm(left_cov, right_cov):
- r"""Computes the scalar constant pi for Tikhonov regularization/damping.
-
- $$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$
- See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details.
-
- Args:
- left_cov: A LinearOperator object. The left Kronecker factor "covariance".
- right_cov: A LinearOperator object. The right Kronecker factor "covariance".
-
- Returns:
- The computed scalar constant pi for these Kronecker Factors (as a Tensor).
- """
- # Instead of dividing by the dim of the norm, we multiply by the dim of the
- # other norm. This works out the same in the ratio.
- left_norm = left_cov.trace() * int(right_cov.domain_dimension)
- right_norm = right_cov.trace() * int(left_cov.domain_dimension)
- return math_ops.sqrt(left_norm / right_norm)
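# A small numeric check of compute_pi_tracenorm()'s formula using plain numpy
# arrays in place of LinearOperators (a sketch of the math only): with
# A = 2*I_3 and B = 8*I_6, pi = sqrt((tr(A)/dim(A)) / (tr(B)/dim(B)))
# = sqrt(2 / 8) = 0.5.
import numpy as np

_left = 2.0 * np.eye(3)
_right = 8.0 * np.eye(6)
_pi = np.sqrt((np.trace(_left) / 3) / (np.trace(_right) / 6))
assert _pi == 0.5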
-
-
-def compute_pi_adjusted_damping(left_cov, right_cov, damping):
- """Returns (left, right) damping values, adjusted by pi if enabled."""
-
- if PI_TYPE == PI_TRACENORM_NAME:
- pi = compute_pi_tracenorm(left_cov, right_cov)
- return (damping * pi, damping / pi)
-
- elif PI_TYPE == PI_OFF_NAME:
- return (damping, damping)
-
- else:
- raise ValueError("Unrecognized PI_TYPE: {}".format(PI_TYPE))
-
-
-class PackagedFunc(object):
- """A Python thunk with a stable ID.
-
- Enables stable names for lambdas.
- """
-
- def __init__(self, func, func_id):
- """Initializes PackagedFunc.
-
- Args:
- func: a zero-arg Python function.
- func_id: a hashable, a function that produces a hashable, or a list/tuple
- thereof.
- """
- self._func = func
- func_id = func_id if isinstance(func_id, (tuple, list)) else (func_id,)
- self._func_id = func_id
-
- def __call__(self):
- return self._func()
-
- @property
- def func_id(self):
- """A hashable identifier for this function."""
- return tuple(elt() if callable(elt) else elt for elt in self._func_id)
-
-
-def _package_func(func, func_id):
- return PackagedFunc(func, func_id)
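# A short sketch of why PackagedFunc exists: two distinct lambdas never
# compare equal, but wrappers built with the same func_id do, which lets the
# factor-caching machinery deduplicate them. (The 0.001 below just stands in
# for a damping value.)
_f1 = _package_func(lambda: 0.001, (0.001,))
_f2 = _package_func(lambda: 0.001, (0.001,))
assert _f1.func_id == _f2.func_id
assert _f1() == _f2() == 0.001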
-
-
-@six.add_metaclass(abc.ABCMeta)
-class FisherBlock(object):
- """Abstract base class for objects modeling approximate Fisher matrix blocks.
-
- Subclasses must implement register_matpower, multiply_matpower,
- instantiate_factors, tensors_to_compute_grads, and num_registered_towers
- methods.
- """
-
- def __init__(self, layer_collection):
- self._layer_collection = layer_collection
-
- @abc.abstractmethod
- def instantiate_factors(self, grads_list, damping):
- """Creates and registers the component factors of this Fisher block.
-
- Args:
- grads_list: A list of gradients (each a Tensor or tuple of Tensors) with
- respect to the tensors returned by tensors_to_compute_grads() that
- are to be used to estimate the block.
- damping: The damping factor (float or Tensor).
- """
- pass
-
- @abc.abstractmethod
- def register_matpower(self, exp):
- """Registers a matrix power to be computed by the block.
-
- Args:
- exp: A float representing the power to raise the block by.
- """
- pass
-
- @abc.abstractmethod
- def register_cholesky(self):
- """Registers a Cholesky factor to be computed by the block."""
- pass
-
- @abc.abstractmethod
- def register_cholesky_inverse(self):
- """Registers an inverse Cholesky factor to be computed by the block."""
- pass
-
- def register_inverse(self):
- """Registers a matrix inverse to be computed by the block."""
- self.register_matpower(-1)
-
- @abc.abstractmethod
- def multiply_matpower(self, vector, exp):
- """Multiplies the vector by the (damped) matrix-power of the block.
-
- Args:
- vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
- exp: A float representing the power to raise the block by before
- multiplying it by the vector.
-
- Returns:
- The vector left-multiplied by the (damped) matrix-power of the block.
- """
- pass
-
- def multiply_inverse(self, vector):
- """Multiplies the vector by the (damped) inverse of the block.
-
- Args:
- vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
-
- Returns:
- The vector left-multiplied by the (damped) inverse of the block.
- """
- return self.multiply_matpower(vector, -1)
-
- def multiply(self, vector):
- """Multiplies the vector by the (damped) block.
-
- Args:
- vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
-
- Returns:
- The vector left-multiplied by the (damped) block.
- """
- return self.multiply_matpower(vector, 1)
-
- @abc.abstractmethod
- def multiply_cholesky(self, vector, transpose=False):
- """Multiplies the vector by the (damped) Cholesky-factor of the block.
-
- Args:
- vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
- transpose: Bool. If true the Cholesky factor is transposed before
- multiplying the vector. (Default: False)
-
- Returns:
- The vector left-multiplied by the (damped) Cholesky-factor of the block.
- """
- pass
-
- @abc.abstractmethod
- def multiply_cholesky_inverse(self, vector, transpose=False):
- """Multiplies vector by the (damped) inverse Cholesky-factor of the block.
-
- Args:
- vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
- transpose: Bool. If true the Cholesky factor inverse is transposed
- before multiplying the vector. (Default: False)
- Returns:
- Vector left-multiplied by (damped) inverse Cholesky-factor of the block.
- """
- pass
-
- @abc.abstractmethod
- def tensors_to_compute_grads(self):
- """Returns the Tensor(s) with respect to which this FisherBlock needs grads.
- """
- pass
-
- @abc.abstractproperty
- def num_registered_towers(self):
- """Number of towers registered for this FisherBlock.
-
- Typically equal to the number of towers in a multi-tower setup.
- """
- pass
-
-
-class FullFB(FisherBlock):
- """FisherBlock using a full matrix estimate (no approximations).
-
- FullFB uses a full matrix estimate (no approximations), and should only ever
- be used for very low dimensional parameters.
-
- Note that this uses the naive "square the sum" estimator, and so is applicable
- to any type of parameter in principle, but has very high variance.
- """
-
- def __init__(self, layer_collection, params):
- """Creates a FullFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- params: The parameters of this layer (Tensor or tuple of Tensors).
- """
- self._batch_sizes = []
- self._params = params
-
- super(FullFB, self).__init__(layer_collection)
-
- def instantiate_factors(self, grads_list, damping):
- self._damping_func = _package_func(lambda: damping, (damping,))
-
- self._factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullFactor, (grads_list, self._batch_size))
-
- def register_matpower(self, exp):
- self._factor.register_matpower(exp, self._damping_func)
-
- def register_cholesky(self):
- self._factor.register_cholesky(self._damping_func)
-
- def register_cholesky_inverse(self):
- self._factor.register_cholesky_inverse(self._damping_func)
-
- def _multiply_matrix(self, matrix, vector, transpose=False):
- vector_flat = utils.tensors_to_column(vector)
- out_flat = matrix.matmul(vector_flat, adjoint=transpose)
- return utils.column_to_tensors(vector, out_flat)
-
- def multiply_matpower(self, vector, exp):
- matrix = self._factor.get_matpower(exp, self._damping_func)
- return self._multiply_matrix(matrix, vector)
-
- def multiply_cholesky(self, vector, transpose=False):
- matrix = self._factor.get_cholesky(self._damping_func)
- return self._multiply_matrix(matrix, vector, transpose=transpose)
-
- def multiply_cholesky_inverse(self, vector, transpose=False):
- matrix = self._factor.get_cholesky_inverse(self._damping_func)
- return self._multiply_matrix(matrix, vector, transpose=transpose)
-
- def full_fisher_block(self):
- """Explicitly constructs the full Fisher block."""
- return self._factor.get_cov_as_linear_operator().to_dense()
-
- def tensors_to_compute_grads(self):
- return self._params
-
- def register_additional_tower(self, batch_size):
- """Register an additional tower.
-
- Args:
- batch_size: The batch size, used in the covariance estimator.
- """
- self._batch_sizes.append(batch_size)
-
- @property
- def num_registered_towers(self):
- return len(self._batch_sizes)
-
- @property
- def _batch_size(self):
- return math_ops.reduce_sum(self._batch_sizes)
-
-
-@six.add_metaclass(abc.ABCMeta)
-class DiagonalFB(FisherBlock):
- """A base class for FisherBlocks that use diagonal approximations."""
-
- def register_matpower(self, exp):
- # Not needed for this. Matrix powers are computed on demand in the
- # diagonal case
- pass
-
- def register_cholesky(self):
- # Not needed for this. Cholesky factors are computed on demand in the
- # diagonal case
- pass
-
- def register_cholesky_inverse(self):
- # Not needed for this. Cholesky inverses are computed on demand in the
- # diagonal case
- pass
-
- def _multiply_matrix(self, matrix, vector):
- vector_flat = utils.tensors_to_column(vector)
- out_flat = matrix.matmul(vector_flat)
- return utils.column_to_tensors(vector, out_flat)
-
- def multiply_matpower(self, vector, exp):
- matrix = self._factor.get_matpower(exp, self._damping_func)
- return self._multiply_matrix(matrix, vector)
-
- def multiply_cholesky(self, vector, transpose=False):
- matrix = self._factor.get_cholesky(self._damping_func)
- return self._multiply_matrix(matrix, vector)
-
- def multiply_cholesky_inverse(self, vector, transpose=False):
- matrix = self._factor.get_cholesky_inverse(self._damping_func)
- return self._multiply_matrix(matrix, vector)
-
- def full_fisher_block(self):
- return self._factor.get_cov_as_linear_operator().to_dense()
-
-
-class NaiveDiagonalFB(DiagonalFB):
- """FisherBlock using a diagonal matrix approximation.
-
- This type of approximation is generically applicable but quite primitive.
-
- Note that this uses the naive "square the sum" estimator, and so is applicable
- to any type of parameter in principle, but has very high variance.
- """
-
- def __init__(self, layer_collection, params):
- """Creates a NaiveDiagonalFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- params: The parameters of this layer (Tensor or tuple of Tensors).
- """
- self._params = params
- self._batch_sizes = []
-
- super(NaiveDiagonalFB, self).__init__(layer_collection)
-
- def instantiate_factors(self, grads_list, damping):
- self._damping_func = _package_func(lambda: damping, (damping,))
-
- self._factor = self._layer_collection.make_or_get_factor(
- fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size))
-
- def tensors_to_compute_grads(self):
- return self._params
-
- def register_additional_tower(self, batch_size):
- """Register an additional tower.
-
- Args:
- batch_size: The batch size, used in the covariance estimator.
- """
- self._batch_sizes.append(batch_size)
-
- @property
- def num_registered_towers(self):
- return len(self._batch_sizes)
-
- @property
- def _batch_size(self):
- return math_ops.reduce_sum(self._batch_sizes)
-
-
-class InputOutputMultiTower(object):
- """Mix-in class for blocks with inputs & outputs and multiple mini-batches."""
-
- def __init__(self, *args, **kwargs):
- self.__inputs = []
- self.__outputs = []
- super(InputOutputMultiTower, self).__init__(*args, **kwargs)
-
- def _process_data(self, grads_list):
- """Process data into the format used by the factors.
-
- This function takes inputs and grads_list data and processes it into
- one of the formats expected by the FisherFactor classes (depending on
- the value of the global configuration variable TOWER_STRATEGY).
-
- The initial format of self._inputs is expected to be a list of Tensors
- over towers. Similarly grads_lists is expected to be a list over sources
- of such lists.
-
- If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing a single
- tensor (represented as a PartitionedTensor object) equal to the
- concatenation (across towers) of all of the elements of self._inputs. And
- similarly grads_list is formatted into a tuple (over sources) of such
- tensors (also represented as PartitionedTensors).
-
- If TOWER_STRATEGY is "separate", formatting of inputs and grads_list
- remains unchanged from the initial format (although possibly converting
- from lists into tuples).
-
- Args:
- grads_list: grads_list in its initial format (see above).
-
- Returns:
- inputs: self._inputs transformed into the appropriate format (see
- above).
- grads_list: grads_list transformed into the appropriate format (see
- above).
-
- Raises:
- ValueError: if TOWER_STRATEGY is not one of "separate" or "concat".
- """
- inputs = self._inputs
- # inputs is a list over towers of Tensors
- # grads_list is a list of lists with the first index being sources and the
- # second being towers.
- if fisher_factors.TOWER_STRATEGY == "concat":
- # Merge towers together into a PartitionedTensor. We package it in
- # a singleton tuple since the factors will expect a list over towers
- inputs = (utils.PartitionedTensor(inputs),)
- # Do the same for grads_list but preserve leading sources dimension
- grads_list = tuple((utils.PartitionedTensor(grads),)
- for grads in grads_list)
- elif fisher_factors.TOWER_STRATEGY == "separate":
- inputs = tuple(inputs)
- grads_list = tuple(grads_list)
-
- else:
- raise ValueError("Global config variable TOWER_STRATEGY must be one of "
- "'concat' or 'separate'.")
-
- return inputs, grads_list
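# A concrete illustration of the two output formats above (a sketch with
# hypothetical per-tower Tensors t0, t1 and per-source/tower grads g00, g01):
# given self._inputs == [t0, t1] and grads_list == [[g00, g01]] (one source,
# two towers),
#
#   "concat"   gives inputs == (PartitionedTensor([t0, t1]),)
#              and grads_list == ((PartitionedTensor([g00, g01]),),)
#   "separate" gives inputs == (t0, t1)
#              and grads_list == ((g00, g01),)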
-
- def tensors_to_compute_grads(self):
- """Tensors to compute derivative of loss with respect to."""
- return tuple(self._outputs)
-
- def register_additional_tower(self, inputs, outputs):
- self._inputs.append(inputs)
- self._outputs.append(outputs)
-
- @property
- def num_registered_towers(self):
- result = len(self._inputs)
- assert result == len(self._outputs)
- return result
-
- @property
- def _inputs(self):
- return self.__inputs
-
- @property
- def _outputs(self):
- return self.__outputs
-
-
-class FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB):
- """FisherBlock for fully-connected (dense) layers using a diagonal approx.
-
- Estimates the Fisher Information matrix's diagonal entries for a fully
- connected layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of
- squares" estimator.
-
- Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
- into it. We are interested in Fisher(params)[i, i]. This is,
-
- $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
- = E[ v(x, y, params)[i] ^ 2 ]$$
-
- Consider fully connected layer in this model with (unshared) weight matrix
- 'w'. For an example 'x' that produces layer inputs 'a' and output
- preactivations 's',
-
- $$v(x, y, w) = vec( a (d loss / d s)^T )$$
-
- This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
- to the layer's parameters 'w'.
- """
-
- def __init__(self, layer_collection, has_bias=False):
- """Creates a FullyConnectedDiagonalFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- has_bias: Whether the component Kronecker factors have an additive bias.
- (Default: False)
- """
- self._has_bias = has_bias
-
- super(FullyConnectedDiagonalFB, self).__init__(layer_collection)
-
- def instantiate_factors(self, grads_list, damping):
- inputs, grads_list = self._process_data(grads_list)
-
- self._factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedDiagonalFactor,
- (inputs, grads_list, self._has_bias))
-
- self._damping_func = _package_func(lambda: damping, (damping,))
-
-
-class ConvDiagonalFB(InputOutputMultiTower, DiagonalFB):
- """FisherBlock for 2-D convolutional layers using a diagonal approx.
-
- Estimates the Fisher Information matrix's diagonal entries for a convolutional
- layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares"
- estimator.
-
- Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
- into it. We are interested in Fisher(params)[i, i]. This is,
-
- $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
- = E[ v(x, y, params)[i] ^ 2 ]$$
-
- Consider a convolutional layer in this model with (unshared) filter matrix
- 'w'. For an example image 'x' that produces layer inputs 'a' and output
- preactivations 's',
-
- $$v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )$$
-
- where 'loc' is a single (x, y) location in an image.
-
- This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
- to the layer's parameters 'w'.
- """
-
- def __init__(self,
- layer_collection,
- params,
- strides,
- padding,
- data_format=None,
- dilations=None):
- """Creates a ConvDiagonalFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- params: The parameters (Tensor or tuple of Tensors) of this layer. If
- kernel alone, a Tensor of shape [kernel_height, kernel_width,
- in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
- containing the previous and a Tensor of shape [out_channels].
- strides: The stride size in this layer (1-D Tensor of length 4).
- padding: The padding in this layer (e.g. "SAME").
- data_format: str or None. Format of input data.
- dilations: List of 4 ints or None. Rate for dilation along all dimensions.
-
- Raises:
- ValueError: if strides is not length-4.
- ValueError: if dilations is not length-4.
- ValueError: if channel is not last dimension.
- """
- if len(strides) != 4:
- raise ValueError("strides must contain 4 numbers.")
-
- if dilations is None:
- dilations = [1, 1, 1, 1]
-
- if len(dilations) != 4:
- raise ValueError("dilations must contain 4 numbers.")
-
- if not utils.is_data_format_channel_last(data_format):
- raise ValueError("data_format must be channels-last.")
-
- self._strides = maybe_tuple(strides)
- self._padding = padding
- self._data_format = data_format
- self._dilations = maybe_tuple(dilations)
- self._has_bias = isinstance(params, (tuple, list))
-
- fltr = params[0] if self._has_bias else params
- self._filter_shape = tuple(fltr.shape.as_list())
-
- if len(self._filter_shape) != 4:
- raise ValueError(
- "Convolution filter must be of shape"
- " [filter_height, filter_width, in_channels, out_channels].")
-
- super(ConvDiagonalFB, self).__init__(layer_collection)
-
- def instantiate_factors(self, grads_list, damping):
- inputs, grads_list = self._process_data(grads_list)
-
- # Infer number of locations upon which convolution is applied.
- self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
- self._strides)
-
- self._factor = self._layer_collection.make_or_get_factor(
- fisher_factors.ConvDiagonalFactor,
- (inputs, grads_list, self._filter_shape, self._strides, self._padding,
- self._data_format, self._dilations, self._has_bias))
-
- def damping_func():
- return self._num_locations * normalize_damping(damping,
- self._num_locations)
-
- damping_id = (self._num_locations, "mult", "normalize_damping", damping,
- self._num_locations)
- self._damping_func = _package_func(damping_func, damping_id)
-
-
-class KroneckerProductFB(FisherBlock):
- """A base class for blocks with separate input and output Kronecker factors.
-
- The Fisher block is approximated as a Kronecker product of the input and
- output factors.
- """
-
- def _setup_damping(self, damping, normalization=None):
- """Makes functions that compute the damping values for both factors."""
- def compute_damping():
- if normalization is not None:
- maybe_normalized_damping = normalize_damping(damping, normalization)
- else:
- maybe_normalized_damping = damping
-
- return compute_pi_adjusted_damping(
- self._input_factor.get_cov_as_linear_operator(),
- self._output_factor.get_cov_as_linear_operator(),
- maybe_normalized_damping**0.5)
-
- if normalization is not None:
- damping_id = ("compute_pi_adjusted_damping",
- "cov", self._input_factor.name,
- "cov", self._output_factor.name,
- "normalize_damping", damping, normalization, "power", 0.5)
- else:
- damping_id = ("compute_pi_adjusted_damping",
- "cov", self._input_factor.name,
- "cov", self._output_factor.name,
- damping, "power", 0.5)
-
- self._input_damping_func = _package_func(lambda: compute_damping()[0],
- damping_id + ("ref", 0))
- self._output_damping_func = _package_func(lambda: compute_damping()[1],
- damping_id + ("ref", 1))
-
- def register_matpower(self, exp):
- self._input_factor.register_matpower(exp, self._input_damping_func)
- self._output_factor.register_matpower(exp, self._output_damping_func)
-
- def register_cholesky(self):
- self._input_factor.register_cholesky(self._input_damping_func)
- self._output_factor.register_cholesky(self._output_damping_func)
-
- def register_cholesky_inverse(self):
- self._input_factor.register_cholesky_inverse(self._input_damping_func)
- self._output_factor.register_cholesky_inverse(self._output_damping_func)
-
- @property
- def _renorm_coeff(self):
- """Kronecker factor multiplier coefficient.
-
- If this FisherBlock is represented as 'FB = c * kron(left, right)', then
- this is 'c'.
-
- Returns:
- 0-D Tensor.
- """
- return 1.0
-
- def _multiply_factored_matrix(self, left_factor, right_factor, vector,
- extra_scale=1.0, transpose_left=False,
- transpose_right=False):
- reshaped_vector = utils.layer_params_to_mat2d(vector)
- reshaped_out = right_factor.matmul_right(reshaped_vector,
- adjoint=transpose_right)
- reshaped_out = left_factor.matmul(reshaped_out,
- adjoint=transpose_left)
- if extra_scale != 1.0:
- reshaped_out *= math_ops.cast(extra_scale, dtype=reshaped_out.dtype)
- return utils.mat2d_to_layer_params(vector, reshaped_out)
-
- def multiply_matpower(self, vector, exp):
- left_factor = self._input_factor.get_matpower(
- exp, self._input_damping_func)
- right_factor = self._output_factor.get_matpower(
- exp, self._output_damping_func)
- extra_scale = float(self._renorm_coeff)**exp
- return self._multiply_factored_matrix(left_factor, right_factor, vector,
- extra_scale=extra_scale)
-
- def multiply_cholesky(self, vector, transpose=False):
- left_factor = self._input_factor.get_cholesky(self._input_damping_func)
- right_factor = self._output_factor.get_cholesky(self._output_damping_func)
- extra_scale = float(self._renorm_coeff)**0.5
- return self._multiply_factored_matrix(left_factor, right_factor, vector,
- extra_scale=extra_scale,
- transpose_left=transpose,
- transpose_right=not transpose)
-
- def multiply_cholesky_inverse(self, vector, transpose=False):
- left_factor = self._input_factor.get_cholesky_inverse(
- self._input_damping_func)
- right_factor = self._output_factor.get_cholesky_inverse(
- self._output_damping_func)
- extra_scale = float(self._renorm_coeff)**-0.5
- return self._multiply_factored_matrix(left_factor, right_factor, vector,
- extra_scale=extra_scale,
- transpose_left=transpose,
- transpose_right=not transpose)
-
- def full_fisher_block(self):
- """Explicitly constructs the full Fisher block.
-
- Used for testing purposes. (In general, the result may be very large.)
-
- Returns:
- The full Fisher block.
- """
- left_factor = self._input_factor.get_cov_as_linear_operator().to_dense()
- right_factor = self._output_factor.get_cov_as_linear_operator().to_dense()
- return self._renorm_coeff * utils.kronecker_product(left_factor,
- right_factor)
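# A numpy sketch of the identity that lets _multiply_factored_matrix() avoid
# forming the Kronecker product explicitly: with row-major flattening,
# kron(A, B) @ vec(X) == vec(A @ X @ B^T). For the symmetric covariance
# factors used here, B^T == B, matching the left @ (X @ right) computation
# above.
import numpy as np

_rng = np.random.RandomState(1)
_a = _rng.randn(3, 3)  # Stands in for the input-side factor.
_b = _rng.randn(2, 2)  # Stands in for the output-side factor.
_x = _rng.randn(3, 2)  # A parameter matrix reshaped to 2-D.
_lhs = np.kron(_a, _b).dot(_x.reshape(-1))
_rhs = _a.dot(_x).dot(_b.T).reshape(-1)
assert np.allclose(_lhs, _rhs)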
-
-
-class EmbeddingKFACFB(InputOutputMultiTower, KroneckerProductFB):
- """K-FAC FisherBlock for embedding layers.
-
- This FisherBlock is similar to FullyConnectedKFACBasicFB, except that its
- input factor is approximated by a diagonal matrix. In the case that each
- example references exactly one embedding, this approximation is exact.
-
- Does not support bias parameters.
- """
-
- def __init__(self, layer_collection, vocab_size):
- """Creates a EmbeddingKFACFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- vocab_size: int. Size of vocabulary for this embedding layer.
- """
- self._vocab_size = vocab_size
-
- super(EmbeddingKFACFB, self).__init__(layer_collection)
-
- def instantiate_factors(self, grads_list, damping):
- """Instantiate Kronecker Factors for this FisherBlock.
-
- Args:
- grads_list: List of list of Tensors. grads_list[i][j] is the
- gradient of the loss with respect to 'outputs' from source 'i' and
- tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size].
- damping: 0-D Tensor or float. 'damping' * identity is approximately added
- to this FisherBlock's Fisher approximation.
- """
- inputs, grads_list = self._process_data(grads_list)
-
- self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.EmbeddingInputKroneckerFactor,
- (inputs, self._vocab_size))
- self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedKroneckerFactor, (grads_list,))
- self._setup_damping(damping)
-
-
-class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB):
- """K-FAC FisherBlock for fully-connected (dense) layers.
-
- This uses the Kronecker-factorized approximation from the original
- K-FAC paper (https://arxiv.org/abs/1503.05671)
- """
-
- def __init__(self, layer_collection, has_bias=False):
- """Creates a FullyConnectedKFACBasicFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- has_bias: Whether the component Kronecker factors have an additive bias.
- (Default: False)
- """
- self._has_bias = has_bias
-
- super(FullyConnectedKFACBasicFB, self).__init__(layer_collection)
-
- def instantiate_factors(self, grads_list, damping):
- """Instantiate Kronecker Factors for this FisherBlock.
-
- Args:
- grads_list: List of list of Tensors. grads_list[i][j] is the
- gradient of the loss with respect to 'outputs' from source 'i' and
- tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size].
- damping: 0-D Tensor or float. 'damping' * identity is approximately added
- to this FisherBlock's Fisher approximation.
- """
- inputs, grads_list = self._process_data(grads_list)
-
- self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedKroneckerFactor,
- ((inputs,), self._has_bias))
- self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedKroneckerFactor,
- (grads_list,))
- self._setup_damping(damping)
-
-
-class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):
- r"""FisherBlock for convolutional layers using the basic KFC approx.
-
- Estimates the Fisher Information matrix's block for a convolutional
- layer.
-
- Consider a convolutional layer in this model with (unshared) filter matrix
- 'w'. For a minibatch that produces inputs 'a' and output preactivations 's',
- this FisherBlock estimates,
-
- $$F(w) = \#locations * kronecker(E[flat(a) flat(a)^T],
- E[flat(ds) flat(ds)^T])$$
-
- where
-
- $$ds = (d / ds) log p(y | x, w)$$
- #locations = number of (x, y) locations where 'w' is applied.
-
- where the expectation is taken over all examples and locations and flat()
- concatenates an array's leading dimensions.
-
- See equation 23 in https://arxiv.org/abs/1602.01407 for details.
- """
-
- def __init__(self,
- layer_collection,
- params,
- padding,
- strides=None,
- dilation_rate=None,
- data_format=None,
- extract_patches_fn=None):
- """Creates a ConvKFCBasicFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- params: The parameters (Tensor or tuple of Tensors) of this layer. If
- kernel alone, a Tensor of shape [..spatial_filter_shape..,
- in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
- containing the previous and a Tensor of shape [out_channels].
- padding: str. Padding method.
- strides: List of ints or None. Contains [..spatial_filter_strides..] if
- 'extract_patches_fn' is compatible with tf.nn.convolution(), else
- [1, ..spatial_filter_strides, 1].
- dilation_rate: List of ints or None. Rate for dilation along each spatial
- dimension if 'extract_patches_fn' is compatible with
- tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
- data_format: str or None. Format of input data.
- extract_patches_fn: str or None. Name of function that extracts image
- patches. One of "extract_convolution_patches", "extract_image_patches",
- "extract_pointwise_conv2d_patches".
- """
- self._padding = padding
- self._strides = maybe_tuple(strides)
- self._dilation_rate = maybe_tuple(dilation_rate)
- self._data_format = data_format
- self._extract_patches_fn = extract_patches_fn
- self._has_bias = isinstance(params, (tuple, list))
-
- fltr = params[0] if self._has_bias else params
- self._filter_shape = tuple(fltr.shape.as_list())
-
- super(ConvKFCBasicFB, self).__init__(layer_collection)
-
- def instantiate_factors(self, grads_list, damping):
- inputs, grads_list = self._process_data(grads_list)
-
- # Infer number of locations upon which convolution is applied.
- self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
- self._strides)
-
- self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.ConvInputKroneckerFactor,
- (inputs, self._filter_shape, self._padding, self._strides,
- self._dilation_rate, self._data_format, self._extract_patches_fn,
- self._has_bias))
- self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
-
- self._setup_damping(damping, normalization=self._num_locations)
-
- @property
- def _renorm_coeff(self):
- return self._num_locations
-
-
-class DepthwiseConvDiagonalFB(ConvDiagonalFB):
- """FisherBlock for depthwise_conv2d().
-
- Equivalent to ConvDiagonalFB applied to each input channel in isolation.
- """
-
- def __init__(self,
- layer_collection,
- params,
- strides,
- padding,
- rate=None,
- data_format=None):
- """Creates a DepthwiseConvKFCBasicFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- params: Tensor of shape [filter_height, filter_width, in_channels,
- channel_multiplier].
- strides: List of 4 ints. Strides along all dimensions.
- padding: str. Padding method.
- rate: List of 4 ints or None. Rate for dilation along all dimensions.
- data_format: str or None. Format of input data.
-
- Raises:
- NotImplementedError: If parameters contains bias.
- ValueError: If filter is not 4-D.
- ValueError: If strides is not length-4.
- ValueError: If rate is not length-2.
- ValueError: If channels are not last dimension.
- """
- if isinstance(params, (tuple, list)):
- raise NotImplementedError("Bias not yet supported.")
-
- if params.shape.ndims != 4:
- raise ValueError("Filter must be 4-D.")
-
- if len(strides) != 4:
- raise ValueError("strides must account for 4 dimensions.")
-
- if rate is not None:
- if len(rate) != 2:
- raise ValueError("rate must only account for spatial dimensions.")
- rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate.
-
- if not utils.is_data_format_channel_last(data_format):
- raise ValueError("data_format must be channels-last.")
-
- super(DepthwiseConvDiagonalFB, self).__init__(
- layer_collection=layer_collection,
- params=params,
- strides=strides,
- padding=padding,
- dilations=rate,
- data_format=data_format)
-
- # This is a hack to overwrite the same setting in ConvDiagonalFB.__init__().
- filter_height, filter_width, in_channels, channel_multiplier = (
- params.shape.as_list())
- self._filter_shape = (filter_height, filter_width, in_channels,
- in_channels * channel_multiplier)
-
- def _multiply_matrix(self, matrix, vector):
- conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
- conv2d_result = super(
- DepthwiseConvDiagonalFB, self)._multiply_matrix(matrix, conv2d_vector)
- return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)
-
-
-class DepthwiseConvKFCBasicFB(ConvKFCBasicFB):
- """FisherBlock for depthwise_conv2d().
-
- Equivalent to ConvKFCBasicFB applied to each input channel in isolation.
- """
-
- def __init__(self,
- layer_collection,
- params,
- strides,
- padding,
- rate=None,
- data_format=None):
- """Creates a DepthwiseConvKFCBasicFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- params: Tensor of shape [filter_height, filter_width, in_channels,
- channel_multiplier].
- strides: List of 4 ints. Strides along all dimensions.
- padding: str. Padding method.
- rate: List of 4 ints or None. Rate for dilation along all dimensions.
- data_format: str or None. Format of input data.
-
- Raises:
- NotImplementedError: If parameters contains bias.
- ValueError: If filter is not 4-D.
- ValueError: If strides is not length-4.
- ValueError: If rate is not length-2.
- ValueError: If channels are not last dimension.
- """
- if isinstance(params, (tuple, list)):
- raise NotImplementedError("Bias not yet supported.")
-
- if params.shape.ndims != 4:
- raise ValueError("Filter must be 4-D.")
-
- if len(strides) != 4:
- raise ValueError("strides must account for 4 dimensions.")
-
- if rate is not None:
- if len(rate) != 2:
- raise ValueError("rate must only account for spatial dimensions.")
- rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate.
-
- if not utils.is_data_format_channel_last(data_format):
- raise ValueError("data_format must be channels-last.")
-
- super(DepthwiseConvKFCBasicFB, self).__init__(
- layer_collection=layer_collection,
- params=params,
- padding=padding,
- strides=strides,
- dilation_rate=rate,
- data_format=data_format,
- extract_patches_fn="extract_image_patches")
-
- # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__().
- filter_height, filter_width, in_channels, channel_multiplier = (
- params.shape.as_list())
- self._filter_shape = (filter_height, filter_width, in_channels,
- in_channels * channel_multiplier)
-
- def _multiply_factored_matrix(self, left_factor, right_factor, vector,
- extra_scale=1.0, transpose_left=False,
- transpose_right=False):
- conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
- conv2d_result = super(
- DepthwiseConvKFCBasicFB, self)._multiply_factored_matrix(
- left_factor, right_factor, conv2d_vector, extra_scale=extra_scale,
- transpose_left=transpose_left, transpose_right=transpose_right)
- return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)
-
-
-def depthwise_conv2d_filter_to_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin
- """Converts a convolution filter for use with conv2d.
-
- Transforms a filter for use with tf.nn.depthwise_conv2d() to one that's
- compatible with tf.nn.conv2d().
-
- Args:
- filter: Tensor of shape [height, width, in_channels, channel_multiplier].
- name: None or str. Name of Op.
-
- Returns:
- Tensor of shape [height, width, in_channels, out_channels].
-
- """
- with ops.name_scope(name, "depthwise_conv2d_filter_to_conv2d_filter",
- [filter]):
- filter = ops.convert_to_tensor(filter)
- filter_height, filter_width, in_channels, channel_multiplier = (
- filter.shape.as_list())
-
- results = []
- for i in range(in_channels):
- # Slice out one in_channel's filter. Insert zeros around it to force it
- # to affect that channel and that channel alone.
- elements = []
- if i > 0:
- elements.append(
- array_ops.zeros(
- [filter_height, filter_width, i, channel_multiplier]))
- elements.append(filter[:, :, i:(i + 1), :])
- if i + 1 < in_channels:
- elements.append(
- array_ops.zeros([
- filter_height, filter_width, in_channels - (i + 1),
- channel_multiplier
- ]))
-
- # Concat along in_channel.
- results.append(
- array_ops.concat(elements, axis=-2, name="in_channel_%d" % i))
-
- # Concat along out_channel.
- return array_ops.concat(results, axis=-1, name="out_channel")
-
-
-def conv2d_filter_to_depthwise_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin
- """Converts a convolution filter for use with depthwise_conv2d.
-
- Transforms a filter for use with tf.nn.conv2d() to one that's
- compatible with tf.nn.depthwise_conv2d(). Ignores all filters but those along
- the diagonal.
-
- Args:
- filter: Tensor of shape [height, width, in_channels, out_channels].
- name: None or str. Name of Op.
-
- Returns:
-    Tensor of shape [height, width, in_channels, channel_multiplier].
-
- Raises:
- ValueError: if out_channels is not evenly divisible by in_channels.
- """
- with ops.name_scope(name, "conv2d_filter_to_depthwise_conv2d_filter",
- [filter]):
- filter = ops.convert_to_tensor(filter)
- filter_height, filter_width, in_channels, out_channels = (
- filter.shape.as_list())
-
- if out_channels % in_channels != 0:
- raise ValueError("out_channels must be evenly divisible by in_channels.")
- channel_multiplier = out_channels // in_channels
-
- results = []
- filter = array_ops.reshape(filter, [
- filter_height, filter_width, in_channels, in_channels,
- channel_multiplier
- ])
- for i in range(in_channels):
- # Slice out output corresponding to the correct filter.
- filter_slice = array_ops.reshape(
- filter[:, :, i, i, :],
- [filter_height, filter_width, 1, channel_multiplier])
- results.append(filter_slice)
-
-    # Concat along in_channel.
- return array_ops.concat(results, axis=-2, name="in_channels")
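-
-# Illustrative round trip (hypothetical values, not part of the original
-# module): converting a depthwise filter to conv2d format and back recovers
-# the original filter, since the forward conversion only populates the
-# "diagonal" in_channel blocks that the reverse conversion reads:
-#
-#   dw_filter = random_ops.random_normal([3, 3, 2, 4])
-#   recovered = conv2d_filter_to_depthwise_conv2d_filter(
-#       depthwise_conv2d_filter_to_conv2d_filter(dw_filter))
-#   # recovered has shape [3, 3, 2, 4] and equals dw_filter.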
-
-
-def maybe_tuple(obj):
- if not isinstance(obj, list):
- return obj
- return tuple(obj)
-
-
-def num_conv_locations(input_shape, strides):
- """Returns the number of spatial locations a 2D Conv kernel is applied to.
-
- Args:
- input_shape: List of ints representing shape of inputs to
- tf.nn.convolution().
- strides: List of ints representing strides along spatial dimensions as
- passed in to tf.nn.convolution().
-
- Returns:
- A scalar |T| denoting the number of spatial locations for the Conv layer.
- """
- spatial_input_locations = np.prod(input_shape[1:-1])
-
- if strides is None:
- spatial_strides_divisor = 1
- else:
- spatial_strides_divisor = np.prod(strides)
-
- return spatial_input_locations // spatial_strides_divisor
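-
-# A worked example of num_conv_locations() (illustrative values): for inputs
-# of shape [32, 16, 16, 3] there are 16 * 16 = 256 spatial input locations;
-# with strides [2, 2], as passed to tf.nn.convolution(), the divisor is
-# 2 * 2 = 4, so the function returns 256 // 4 = 64.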
-
-
-class InputOutputMultiTowerMultiUse(InputOutputMultiTower):
- """Adds methods for multi-use/time-step case to InputOutputMultiTower."""
-
- def __init__(self, num_uses=None, *args, **kwargs):
- self._num_uses = num_uses
- super(InputOutputMultiTowerMultiUse, self).__init__(*args, **kwargs)
-
- def _process_data(self, grads_list):
- """Process temporal/multi-use data into the format used by the factors.
-
-    This function takes the inputs and grads_list data and processes it into
- one of the formats expected by the FisherFactor classes (depending on
- the value of the global configuration variable TOWER_STRATEGY).
-
- It accepts the data in one of two initial formats. The first possible
- format is where self._inputs is a list of list of Tensors. The first index
- is tower, the second is use/time-step. grads_list, meanwhile, is a list
- over sources of such lists of lists.
-
- The second possible data format is where self._inputs is a Tensor with
- uses/times-steps folded into the batch dimension. i.e. it is a Tensor
- of shape [num_uses * size_batch, ...] which represents a reshape of a
- Tensor of shape [num_uses, size_batch, ...]. And similarly grads_list is
- a list over sources of such Tensors.
-
-    There are two possible formats into which inputs and grads_list are
-    transformed.
-
- If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing
- a single tensor (represented as a PartitionedTensor object) with all of
- the data from the towers, as well as the uses/time-steps, concatenated
- together. In this tensor the leading dimension is the batch and
- use/time-step dimensions folded together (with 'use' being the major of
- these two, so that the tensors can be thought of as reshapes of ones of
- shape [num_uses, batch_size, ...]). grads_list is similarly formatted as a
- tuple over sources of such tensors.
-
- If TOWER_STRATEGY is "separate" the inputs are formatted into lists of
- tensors over towers. Each of these tensors has a similar format to
- the tensor produced by the "concat" option, except that each contains
- only the data from a single tower. grads_list is similarly formatted
- into a tuple over sources of such tuples.
-
- Args:
- grads_list: grads_list in its initial format (see above).
-
- Returns:
- inputs: self._inputs transformed into the appropriate format (see
- above).
- grads_list: grads_list transformed into the appropriate format (see
- above).
-
- Raises:
- ValueError: If TOWER_STRATEGY is not one of "separate" or "concat".
- ValueError: If the given/initial format of self._inputs and grads_list
- isn't recognized, or doesn't agree with self._num_uses.
- """
-
- inputs = self._inputs
-
- if isinstance(inputs[0], (list, tuple)):
- num_uses = len(inputs[0])
- if self._num_uses is not None and self._num_uses != num_uses:
- raise ValueError("num_uses argument doesn't match length of inputs.")
- else:
- self._num_uses = num_uses
-
- # Check that all mini-batches/towers have the same number of uses
- if not all(len(input_) == num_uses for input_ in inputs):
- raise ValueError("Length of inputs argument is inconsistent across "
- "towers.")
-
- if fisher_factors.TOWER_STRATEGY == "concat":
- # Reverse the tower and use/time-step indices, so that use is now first,
- # and towers is second
- inputs = tuple(zip(*inputs))
-
- # Flatten the two dimensions
- inputs = nest.flatten(inputs)
-
- # Merge everything together into a PartitionedTensor. We package it in
- # a singleton tuple since the factors will expect a list over towers
- inputs = (utils.PartitionedTensor(inputs),)
-
- elif fisher_factors.TOWER_STRATEGY == "separate":
- # Merge together the uses/time-step dimension into PartitionedTensors,
- # but keep the leading dimension (towers) intact for the factors to
- # process individually.
- inputs = tuple(utils.PartitionedTensor(input_) for input_ in inputs)
-
- else:
- raise ValueError("Global config variable TOWER_STRATEGY must be one of "
- "'concat' or 'separate'.")
- else:
- inputs = tuple(inputs)
-
- # Now we perform the analogous processing for grads_list
- if isinstance(grads_list[0][0], (list, tuple)):
- num_uses = len(grads_list[0][0])
- if self._num_uses is not None and self._num_uses != num_uses:
- raise ValueError("num_uses argument doesn't match length of outputs, "
- "or length of outputs is inconsistent with length of "
- "inputs.")
- else:
- self._num_uses = num_uses
-
- if not all(len(grad) == num_uses for grads in grads_list
- for grad in grads):
- raise ValueError("Length of outputs argument is inconsistent across "
- "towers.")
-
- if fisher_factors.TOWER_STRATEGY == "concat":
- # Reverse the tower and use/time-step indices, so that use is now first,
- # and towers is second
- grads_list = tuple(tuple(zip(*grads)) for grads in grads_list)
-
- # Flatten the two dimensions, leaving the leading dimension (source)
- # intact
- grads_list = tuple(nest.flatten(grads) for grads in grads_list)
-
- # Merge inner dimensions together into PartitionedTensors. We package
- # them in a singleton tuple since the factors will expect a list over
- # towers
- grads_list = tuple((utils.PartitionedTensor(grads),)
- for grads in grads_list)
-
- elif fisher_factors.TOWER_STRATEGY == "separate":
- # Merge together the uses/time-step dimension into PartitionedTensors,
- # but keep the leading dimension (towers) intact for the factors to
- # process individually.
- grads_list = tuple(tuple(utils.PartitionedTensor(grad)
- for grad in grads)
- for grads in grads_list)
-
- else:
- raise ValueError("Global config variable TOWER_STRATEGY must be one of "
- "'concat' or 'separate'.")
- else:
- grads_list = tuple(tuple(grads) for grads in grads_list)
-
- if self._num_uses is None:
- raise ValueError("You must supply a value for the num_uses argument if "
- "the number of uses cannot be inferred from inputs or "
- "outputs arguments (e.g. if they are both given in the "
- "single Tensor format, instead of as lists of Tensors.")
-
- return inputs, grads_list
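-
-  # A shape sketch of the "concat" path above (illustrative sizes, not part
-  # of the original module): with 2 towers and 3 uses, self._inputs given as
-  # [[x00, x01, x02], [x10, x11, x12]] (tower-major) is transposed to
-  # use-major order, flattened to [x00, x10, x01, x11, x02, x12], and wrapped
-  # in a single PartitionedTensor whose leading dimension behaves like
-  # num_uses * num_towers * batch_size.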
-
-
-class FullyConnectedMultiIndepFB(InputOutputMultiTowerMultiUse,
- KroneckerProductFB):
- """FisherBlock for fully-connected layers that share parameters.
-
- This class implements the "independence across time" approximation from the
- following paper:
- https://openreview.net/pdf?id=HyMTkQZAb
- """
-
- def __init__(self, layer_collection, has_bias=False, num_uses=None):
- """Creates a FullyConnectedMultiIndepFB block.
-
- Args:
- layer_collection: LayerCollection instance.
- has_bias: bool. If True, estimates Fisher with respect to a bias
- parameter as well as the layer's parameters.
- num_uses: int or None. Number of uses of the layer in the model's graph.
- Only required if the data is formatted with uses/time folded into the
- batch dimension (instead of uses/time being a list dimension).
- (Default: None)
- """
- self._has_bias = has_bias
-
- super(FullyConnectedMultiIndepFB, self).__init__(
- layer_collection=layer_collection,
- num_uses=num_uses)
-
- def instantiate_factors(self, grads_list, damping):
- inputs, grads_list = self._process_data(grads_list)
-
- self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedMultiKF,
- ((inputs,), self._num_uses, self._has_bias))
-
- self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses))
-
- self._setup_damping(damping, normalization=self._num_uses)
-
- @property
- def _renorm_coeff(self):
- return float(self._num_uses)
-
-
-class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse,
- KroneckerProductFB):
- """FisherBlock for 2D convolutional layers using the basic KFC approx.
-
- Similar to ConvKFCBasicFB except that this version supports multiple
- uses/time-steps via a standard independence approximation. Similar to the
- "independence across time" used in FullyConnectedMultiIndepFB but generalized
- in the obvious way to conv layers.
- """
-
- def __init__(self,
- layer_collection,
- params,
- padding,
- strides=None,
- dilation_rate=None,
- data_format=None,
- extract_patches_fn=None,
- num_uses=None):
- """Creates a ConvKFCBasicMultiIndepFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
-      params: The parameters (Tensor or tuple of Tensors) of this layer. If
-        kernel alone, a Tensor of shape [..spatial_filter_shape..,
-        in_channels, out_channels]. If kernel and bias, a tuple of 2
-        elements: the kernel Tensor described above and a bias Tensor of
-        shape [out_channels].
- padding: str. Padding method.
- strides: List of ints or None. Contains [..spatial_filter_strides..] if
- 'extract_patches_fn' is compatible with tf.nn.convolution(), else
- [1, ..spatial_filter_strides, 1].
- dilation_rate: List of ints or None. Rate for dilation along each spatial
- dimension if 'extract_patches_fn' is compatible with
- tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
- data_format: str or None. Format of input data.
- extract_patches_fn: str or None. Name of function that extracts image
- patches. One of "extract_convolution_patches", "extract_image_patches",
- "extract_pointwise_conv2d_patches".
- num_uses: int or None. Number of uses of the layer in the model's graph.
- Only required if the data is formatted with uses/time folded into the
- batch dimension (instead of uses/time being a list dimension).
- (Default: None)
- """
- self._padding = padding
- self._strides = maybe_tuple(strides)
- self._dilation_rate = maybe_tuple(dilation_rate)
- self._data_format = data_format
- self._extract_patches_fn = extract_patches_fn
- self._has_bias = isinstance(params, (tuple, list))
-
- fltr = params[0] if self._has_bias else params
- self._filter_shape = tuple(fltr.shape.as_list())
-
- super(ConvKFCBasicMultiIndepFB, self).__init__(
- layer_collection=layer_collection,
- num_uses=num_uses)
-
- def instantiate_factors(self, grads_list, damping):
- inputs, grads_list = self._process_data(grads_list)
-
- # Infer number of locations upon which convolution is applied.
- self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
- self._strides)
-
- self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.ConvInputKroneckerFactor,
- (inputs, self._filter_shape, self._padding, self._strides,
- self._dilation_rate, self._data_format, self._extract_patches_fn,
- self._has_bias))
- self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
-
- self._setup_damping(damping, normalization=
- (self._num_locations * self._num_uses))
-
- @property
- def _renorm_coeff(self):
- return self._num_locations * self._num_uses
-
-
-class EmbeddingKFACMultiIndepFB(InputOutputMultiTowerMultiUse,
- KroneckerProductFB):
- """K-FAC FisherBlock for embedding layers used multiple times in the graph.
-
- Similar to EmbeddingKFACFB except that this version supports multiple uses
- of the parameter within a single model. These uses could correspond to time
- steps in an RNN architecture, but they don't have to.
-
- Does not support bias parameters.
- """
-
- def __init__(self, layer_collection, vocab_size, num_uses=None):
- """Creates a EmbeddingKFACMultiIndepFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- vocab_size: int. Size of vocabulary for this embedding layer.
- num_uses: int or None. Number of uses of the layer in the model's graph.
- Only required if the data is formatted with time folded into the batch
- dimension (instead of time being a list dimension). (Default: None)
- """
- self._vocab_size = vocab_size
-
- super(EmbeddingKFACMultiIndepFB, self).__init__(
- layer_collection=layer_collection,
- num_uses=num_uses)
-
- def instantiate_factors(self, grads_list, damping):
- """Instantiate Kronecker Factors for this FisherBlock.
-
- Args:
- grads_list: List of list of list of Tensors. grads_list[i][j][k] is the
- gradient of the loss with respect to 'outputs' from source 'i',
- tower/mini-batch 'j', and use/time-step 'k'. Each Tensor has shape
- [tower_minibatch_size, output_size].
- damping: 0-D Tensor or float. 'damping' * identity is approximately added
- to this FisherBlock's Fisher approximation.
- """
- inputs, grads_list = self._process_data(grads_list)
-
- self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.EmbeddingInputKroneckerFactor,
- (inputs, self._vocab_size))
- self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses))
- self._setup_damping(damping, normalization=self._num_uses)
-
- @property
- def _renorm_coeff(self):
- return float(self._num_uses)
-
-
-class SeriesFBApproximation(enum.IntEnum):
- """See FullyConnectedSeriesFB.__init__ for description and usage."""
- option1 = 1
- option2 = 2
-
-
-class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,
- KroneckerProductFB):
- """FisherBlock for fully-connected layers that share parameters across time.
-
- This class implements the "Option 1" and "Option 2" approximation from the
- following paper:
- https://openreview.net/pdf?id=HyMTkQZAb
-
-  See the end of the appendix of the paper for pseudo-code of the
-  algorithm implemented by multiply_matpower here. Note that we are
- using pre-computed versions of certain matrix-matrix products to speed
- things up. This is explicitly explained wherever it is done.
- """
-
- def __init__(self,
- layer_collection,
- has_bias=False,
- num_uses=None,
- option=SeriesFBApproximation.option2):
- """Constructs a new `FullyConnectedSeriesFB`.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- has_bias: Whether the layer includes a bias parameter.
- num_uses: int or None. Number of time-steps over which the layer
- is used. Only required if the data is formatted with time folded into
- the batch dimension (instead of time being a list dimension).
- (Default: None)
- option: A `SeriesFBApproximation` specifying the simplifying assumption
- to be used in this block. `option1` approximates the cross-covariance
- over time as a symmetric matrix, while `option2` makes
- the assumption that training sequences are infinitely long. See section
- 3.5 of the paper for more details.
- """
-
- self._has_bias = has_bias
- self._option = option
-
- super(FullyConnectedSeriesFB, self).__init__(
- layer_collection=layer_collection,
- num_uses=num_uses)
-
- @property
- def _num_timesteps(self):
- return self._num_uses
-
- @property
- def _renorm_coeff(self):
-    # This should no longer be used since the multiply_X functions from the
-    # base class have been overridden.
- assert False
-
- def instantiate_factors(self, grads_list, damping):
- inputs, grads_list = self._process_data(grads_list)
-
- self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedMultiKF,
- ((inputs,), self._num_uses, self._has_bias))
- self._input_factor.register_cov_dt1()
-
- self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses))
- self._output_factor.register_cov_dt1()
-
- self._setup_damping(damping, normalization=self._num_uses)
-
- def register_matpower(self, exp):
- if exp != -1:
- raise NotImplementedError("FullyConnectedSeriesFB only supports inverse"
- "multiplications.")
-
- if self._option == SeriesFBApproximation.option1:
- self._input_factor.register_option1quants(self._input_damping_func)
- self._output_factor.register_option1quants(self._output_damping_func)
- elif self._option == SeriesFBApproximation.option2:
- self._input_factor.register_option2quants(self._input_damping_func)
- self._output_factor.register_option2quants(self._output_damping_func)
- else:
- raise ValueError(
- "Unrecognized FullyConnectedSeriesFB approximation: {}".format(
- self._option))
-
- def multiply_matpower(self, vector, exp):
- if exp != -1:
- raise NotImplementedError("FullyConnectedSeriesFB only supports inverse"
- "multiplications.")
-
- # pylint: disable=invalid-name
-
- Z = utils.layer_params_to_mat2d(vector)
-
- # Derivations were done for "batch_dim==1" case so we need to convert to
- # that orientation:
- Z = array_ops.transpose(Z)
-
- if self._option == SeriesFBApproximation.option1:
-
-      # Note that \\(L_A = A0^{-1/2} * U_A\\) and \\(L_G = G0^{-1/2} * U_G\\).
- L_A, psi_A = self._input_factor.get_option1quants(
- self._input_damping_func)
- L_G, psi_G = self._output_factor.get_option1quants(
- self._output_damping_func)
-
- def gamma(x):
- # We are assuming that each case has the same number of time-steps.
- # If this stops being the case one shouldn't simply replace this T
- # with its average value. Instead, one needs to go back to the
- # definition of the gamma function from the paper.
- T = self._num_timesteps
- return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T))
-
- # \\(Y = \gamma( psi_G*psi_A^T )\\) (computed element-wise)
-      # Even though Y is Z-independent we recompute it from the psi's each
-      # time, since Y depends on both A and G quantities and it is relatively
-      # cheap to compute.
- Y = gamma(array_ops.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A)
-
- # \\(Z = L_G^T * Z * L_A\\)
- # This is equivalent to the following computation from the original
- # pseudo-code:
- # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
- # \\(Z = U_G^T * Z * U_A\\)
- Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A), transpose_a=True)
-
- # \\(Z = Z .* Y\\)
- Z *= Y
-
- # \\(Z = L_G * Z * L_A^T\\)
- # This is equivalent to the following computation from the original
- # pseudo-code:
- # \\(Z = U_G * Z * U_A^T\\)
- # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
- Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A, transpose_b=True))
-
- elif self._option == SeriesFBApproximation.option2:
-
-      # Note that \\(P_A = A_1^T * A_0^{-1}\\) and \\(P_G = G_1^T * G_0^{-1}\\),
-      # and \\(K_A = A_0^{-1/2} * E_A\\) and \\(K_G = G_0^{-1/2} * E_G\\).
- P_A, K_A, mu_A = self._input_factor.get_option2quants(
- self._input_damping_func)
- P_G, K_G, mu_G = self._output_factor.get_option2quants(
- self._output_damping_func)
-
- # Our approach differs superficially from the pseudo-code in the paper
- # in order to reduce the total number of matrix-matrix multiplies.
- # In particular, the first three computations in the pseudo code are
- # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
- # \\(Z = Z - hPsi_G^T * Z * hPsi_A\\)
- # \\(Z = E_G^T * Z * E_A\\)
-      # Noting that \\(hPsi = C0^{-1/2} * C1 * C0^{-1/2}\\), so that
- # \\(C0^{-1/2} * hPsi = C0^{-1} * C1 * C0^{-1/2} = P^T * C0^{-1/2}\\)
- # the entire computation can be written as
- # \\(Z = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\)
- # \\( - hPsi_G^T * G0^{-1/2} * Z * A0^{-1/2} * hPsi_A) * E_A\\)
- # \\( = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\)
- # \\( - G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2}) * E_A\\)
- # \\( = E_G^T * G0^{-1/2} * Z * A0^{-1/2} * E_A\\)
- # \\( - E_G^T* G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2} * E_A\\)
- # \\( = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A\\)
- # This final expression is computed by the following two lines:
- # \\(Z = Z - P_G * Z * P_A^T\\)
- Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A, transpose_b=True))
- # \\(Z = K_G^T * Z * K_A\\)
- Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A), transpose_a=True)
-
- # \\(Z = Z ./ (1*1^T - mu_G*mu_A^T)\\)
- # Be careful with the outer product. We don't want to accidentally
- # make it an inner-product instead.
- tmp = 1.0 - array_ops.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A
- # Prevent some numerical issues by setting any 0.0 eigs to 1.0
- tmp += 1.0 * math_ops.cast(math_ops.equal(tmp, 0.0), dtype=tmp.dtype)
- Z /= tmp
-
- # We now perform the transpose/reverse version of the operations
- # derived above, whose derivation from the original pseudo-code is
-      # analogous.
- # \\(Z = K_G * Z * K_A^T\\)
- Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A, transpose_b=True))
-
- # \\(Z = Z - P_G^T * Z * P_A\\)
- Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A), transpose_a=True)
-
-      # Normalize: \\(Z = (1/T) * Z\\)
- # Note that this normalization is done because we compute the statistics
- # by averaging, not summing, over time. (And the gradient is presumably
- # summed over time, not averaged, and thus their scales are different.)
- Z /= math_ops.cast(self._num_timesteps, Z.dtype)
-
- # Convert back to the "batch_dim==0" orientation.
- Z = array_ops.transpose(Z)
-
- return utils.mat2d_to_layer_params(vector, Z)
-
- # pylint: enable=invalid-name
-
- def multiply_cholesky(self, vector):
- raise NotImplementedError("FullyConnectedSeriesFB does not support "
- "Cholesky computations.")
-
- def multiply_cholesky_inverse(self, vector):
- raise NotImplementedError("FullyConnectedSeriesFB does not support "
- "Cholesky computations.")
-
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
deleted file mode 100644
index c04cf727fa..0000000000
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""FisherBlock definitions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.fisher_blocks import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- 'FisherBlock',
- 'FullFB',
- 'NaiveDiagonalFB',
- 'FullyConnectedDiagonalFB',
- 'KroneckerProductFB',
- 'EmbeddingKFACFB',
- 'FullyConnectedKFACBasicFB',
- 'ConvKFCBasicFB',
- 'ConvDiagonalFB',
- 'set_global_constants',
- 'compute_pi_tracenorm',
- 'compute_pi_adjusted_damping',
- 'num_conv_locations',
- 'normalize_damping',
- 'LEFT_MULTIPLY',
- 'RIGHT_MULTIPLY',
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
deleted file mode 100644
index afa2fd1ca7..0000000000
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ /dev/null
@@ -1,1830 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""FisherFactor definitions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-import contextlib
-
-import numpy as np
-import six
-
-from tensorflow.contrib.kfac.python.ops import linear_operator as lo
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops as tf_ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import special_math_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables
-from tensorflow.python.training import moving_averages
-from tensorflow.python.util import nest
-
-
-# Whether to initialize covariance estimators at a zero matrix (or the identity
-# matrix).
-INIT_COVARIANCES_AT_ZERO = True
-
-# Whether to zero-debias the moving averages.
-ZERO_DEBIAS = True
-
-# Whether to initialize inverse (and other such matrices computed from the cov
-# matrices) to the zero matrix (or the identity matrix).
-INIT_INVERSES_AT_ZERO = True
-
-# When the number of inverses requested from a FisherFactor exceeds this value,
-# the inverses are computed using an eigenvalue decomposition.
-EIGENVALUE_DECOMPOSITION_THRESHOLD = 2
-
-# Numerical eigenvalues computed from covariance matrix estimates are clipped to
-# be at least as large as this value before they are used to compute inverses or
-# matrix powers. Must be nonnegative.
-EIGENVALUE_CLIPPING_THRESHOLD = 0.0
-
-# Used to subsample the flattened extracted image patches. The number of
-# outer products per row of the covariance matrix should not exceed this
-# value. This parameter is used only if `_SUB_SAMPLE_OUTER_PRODUCTS` is True.
-_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = 1
-
-# Used to subsample the inputs passed to extract_image_patches. The batch
-# size of the inputs to extract_image_patches is multiplied by this factor.
-# This parameter is used only if `_SUB_SAMPLE_INPUTS` is True.
-_INPUTS_TO_EXTRACT_PATCHES_FACTOR = 0.5
-
-# If True, then subsamples the tensor passed to compute the covariance matrix.
-_SUB_SAMPLE_OUTER_PRODUCTS = False
-
-# If True, then subsamples the tensor passed to compute the covariance matrix.
-_SUB_SAMPLE_INPUTS = False
-
-# TOWER_STRATEGY can be one of "concat" or "separate". If "concat", the data
-# passed to the factors from the blocks will be concatenated across towers
-# (lazily via PartitionedTensor objects). Otherwise a tuple of tensors over
-# towers will be passed in, and the factors will iterate over this and do the
-# cov computations separately for each one, averaging the results together.
-TOWER_STRATEGY = "concat"
-
-
-def set_global_constants(init_covariances_at_zero=None,
- zero_debias=None,
- init_inverses_at_zero=None,
- eigenvalue_decomposition_threshold=None,
- eigenvalue_clipping_threshold=None,
- max_num_outer_products_per_cov_row=None,
- sub_sample_outer_products=None,
- inputs_to_extract_patches_factor=None,
- sub_sample_inputs=None,
- tower_strategy=None):
- """Sets various global constants used by the classes in this module."""
- global INIT_COVARIANCES_AT_ZERO
- global ZERO_DEBIAS
- global INIT_INVERSES_AT_ZERO
- global EIGENVALUE_DECOMPOSITION_THRESHOLD
- global EIGENVALUE_CLIPPING_THRESHOLD
- global _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW
- global _SUB_SAMPLE_OUTER_PRODUCTS
- global _INPUTS_TO_EXTRACT_PATCHES_FACTOR
- global _SUB_SAMPLE_INPUTS
- global TOWER_STRATEGY
-
- if init_covariances_at_zero is not None:
- INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero
- if zero_debias is not None:
- ZERO_DEBIAS = zero_debias
- if init_inverses_at_zero is not None:
- INIT_INVERSES_AT_ZERO = init_inverses_at_zero
- if eigenvalue_decomposition_threshold is not None:
- EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold
- if eigenvalue_clipping_threshold is not None:
- EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold
- if max_num_outer_products_per_cov_row is not None:
- _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = max_num_outer_products_per_cov_row
- if sub_sample_outer_products is not None:
- _SUB_SAMPLE_OUTER_PRODUCTS = sub_sample_outer_products
- if inputs_to_extract_patches_factor is not None:
- _INPUTS_TO_EXTRACT_PATCHES_FACTOR = inputs_to_extract_patches_factor
- if sub_sample_inputs is not None:
- _SUB_SAMPLE_INPUTS = sub_sample_inputs
- if tower_strategy is not None:
- TOWER_STRATEGY = tower_strategy
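-
-# A minimal usage sketch (hypothetical values, not part of the original
-# module): callers are expected to override these module-level constants at
-# most once, before any factors are instantiated, e.g.
-#
-#   from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
-#   ff.set_global_constants(init_covariances_at_zero=False,
-#                           eigenvalue_decomposition_threshold=3,
-#                           tower_strategy="separate")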
-
-
-def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument
- if INIT_INVERSES_AT_ZERO:
- return array_ops.zeros(shape, dtype=dtype)
- return linalg_ops.eye(num_rows=shape[0], dtype=dtype)
-
-
-def covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument
- if INIT_COVARIANCES_AT_ZERO:
- return array_ops.zeros(shape, dtype=dtype)
- return linalg_ops.eye(num_rows=shape[0], dtype=dtype)
-
-
-def diagonal_covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument
- if INIT_COVARIANCES_AT_ZERO:
- return array_ops.zeros(shape, dtype=dtype)
- return array_ops.ones(shape, dtype=dtype)
-
-
-@contextlib.contextmanager
-def place_on_device(device):
- if device is not None and len(device):
- with tf_ops.device(device):
- yield
- else:
- yield
-
-
-def compute_cov(tensor, tensor_right=None, normalizer=None):
- """Compute the empirical second moment of the rows of a 2D Tensor.
-
- This function is meant to be applied to random matrices for which the true row
- mean is zero, so that the true second moment equals the true covariance.
-
- Args:
- tensor: A 2D Tensor.
- tensor_right: An optional 2D Tensor. If provided, this function computes
- the matrix product tensor^T * tensor_right instead of tensor^T * tensor.
- normalizer: optional scalar for the estimator (by default, the normalizer is
- the number of rows of tensor).
-
- Returns:
- A square 2D Tensor with as many rows/cols as the number of input columns.
- """
- if normalizer is None:
- normalizer = array_ops.shape(tensor)[0]
- if tensor_right is None:
- cov = (
- math_ops.matmul(tensor, tensor, transpose_a=True) / math_ops.cast(
- normalizer, tensor.dtype))
- return (cov + array_ops.transpose(cov)) / math_ops.cast(2.0, cov.dtype)
- else:
- return (math_ops.matmul(tensor, tensor_right, transpose_a=True) /
- math_ops.cast(normalizer, tensor.dtype))
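-
-# A small numerical sketch of compute_cov() (illustrative values, not part of
-# the original module): for tensor = [[1., 0.], [0., 2.]] the normalizer
-# defaults to 2 rows, tensor^T * tensor / 2 == [[0.5, 0.], [0., 2.]], and the
-# symmetrizing step returns this matrix unchanged since it is already
-# symmetric.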
-
-
-def append_homog(tensor):
- """Appends a homogeneous coordinate to the last dimension of a Tensor.
-
- Args:
- tensor: A Tensor.
-
- Returns:
- A Tensor identical to the input but one larger in the last dimension. The
- new entries are filled with ones.
- """
- rank = len(tensor.shape.as_list())
- shape = array_ops.concat([array_ops.shape(tensor)[:-1], [1]], axis=0)
- ones = array_ops.ones(shape, dtype=tensor.dtype)
- return array_ops.concat([tensor, ones], axis=rank - 1)
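-
-# e.g. (illustrative): append_homog applied to a Tensor of shape [batch, 3]
-# returns a Tensor of shape [batch, 4] whose new trailing entry is 1 in every
-# row; this is how bias parameters get folded into input covariances when
-# has_bias=True.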
-
-
-def scope_string_from_params(params):
- """Builds a variable scope string name from the given parameters.
-
- Supported parameters are:
- * tensors
- * booleans
- * ints
- * strings
- * depth-1 tuples/lists of ints
- * any depth tuples/lists of tensors
- Other parameter types will throw an error.
-
- Args:
- params: A parameter or list of parameters.
-
- Returns:
- A string to use for the variable scope.
-
- Raises:
- ValueError: if params includes an unsupported type.
- """
- params = params if isinstance(params, (tuple, list)) else (params,)
-
- name_parts = []
- for param in params:
- if param is None:
- name_parts.append("None")
- elif isinstance(param, (tuple, list)):
- if all([isinstance(p, int) for p in param]):
- name_parts.append("-".join([str(p) for p in param]))
- else:
- name_parts.append(scope_string_from_name(param))
- elif isinstance(param, (str, int, bool)):
- name_parts.append(str(param))
- elif isinstance(param, (tf_ops.Tensor, variables.Variable)):
- name_parts.append(scope_string_from_name(param))
- elif isinstance(param, utils.PartitionedTensor):
- name_parts.append(scope_string_from_name(param.tensors))
- else:
- raise ValueError("Encountered an unsupported param type {}".format(
- type(param)))
- return "_".join(name_parts)
-
-
-def scope_string_from_name(tensor):
- if isinstance(tensor, (tuple, list)):
- return "__".join([scope_string_from_name(t) for t in tensor])
- # "gradients/add_4_grad/Reshape:0" -> "gradients_add_4_grad_Reshape"
- return tensor.name.split(":")[0].replace("/", "_")
-
-
-def scalar_or_tensor_to_string(val):
- return repr(val) if np.isscalar(val) else scope_string_from_name(val)
-
-
-def list_to_string(lst):
- return "_".join(val if isinstance(val, six.string_types)
- else scalar_or_tensor_to_string(val) for val in lst)
-
-
-def graph_func_to_id(func):
- """Returns a hashable object that represents func's computation."""
- # TODO(b/74201126): replace with Topohash of func's output
- return func.func_id
-
-
-def graph_func_to_string(func):
- # TODO(b/74201126): replace with Topohash of func's output
- return list_to_string(func.func_id)
-
-
-def _subsample_for_cov_computation(array, name=None):
- """Subsamples the first dimension of the array.
-
-  `array` (call it A) is a tensor of shape `[batch_size, dim_2]`, so the
-  covariance matrix (A^T A) has shape `[dim_2, dim_2]`. Subsampling happens
-  only if the number of outer products per row of the covariance matrix is
-  greater than `_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW`.
-
- Args:
- array: Tensor, of shape `[batch_size, dim_2]`.
-    name: `string` or None. Name of Op. (Default: None)
-
- Returns:
- A tensor of shape `[max_samples, dim_2]`.
-
- Raises:
-    ValueError: If array is not matrix-shaped.
- ValueError: If array's batch_size cannot be inferred.
-
- """
- with tf_ops.name_scope(name, "subsample", [array]):
- array = tf_ops.convert_to_tensor(array)
- if len(array.shape) != 2:
- raise ValueError("Input param array must be a matrix.")
-
- batch_size = array.shape.as_list()[0]
- if batch_size is None:
- raise ValueError("Unable to get batch_size from input param array.")
-
- num_cov_rows = array.shape.as_list()[-1]
- max_batch_size = int(_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW * num_cov_rows)
- if batch_size <= max_batch_size:
- return array
-
- return _random_tensor_gather(array, max_batch_size)
-
-
-def _random_tensor_gather(array, max_size):
- """Generates a random set of indices and gathers the value at the indices.
-
- Args:
- array: Tensor, of shape `[batch_size, dim_2]`.
- max_size: int, Number of indices to sample.
-
- Returns:
-    A tensor of shape `[max_size, dim_2]`.
- """
- batch_size = array.shape.as_list()[0]
- indices = random_ops.random_shuffle(math_ops.range(0, batch_size))[:max_size]
- return array_ops.gather(array, indices)
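-
-# A worked example of the subsampling rule above (illustrative values): with
-# _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW == 1 and array of shape [4096, 1024],
-# max_batch_size is 1 * 1024 == 1024 < 4096, so a random subset of 1024 rows
-# is gathered; an array of shape [512, 1024] would be returned unchanged.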
-
-
-@six.add_metaclass(abc.ABCMeta)
-class FisherFactor(object):
- """Base class for objects modeling factors of approximate Fisher blocks.
-
- A FisherFactor represents part of an approximate Fisher Information matrix.
- For example, one approximation to the Fisher uses the Kronecker product of two
- FisherFactors A and B, F = kron(A, B). FisherFactors are composed with
- FisherBlocks to construct a block-diagonal approximation to the full Fisher.
-
- FisherFactors are backed by a single, non-trainable variable that is updated
- by running FisherFactor.make_covariance_update_op(). The shape and type of
-  this variable are implementation specific.
-
- Note that for blocks that aren't based on approximations, a 'factor' can
- be the entire block itself, as is the case for the diagonal and full
- representations.
- """
-
- def __init__(self):
- self._cov = None
-
- @abc.abstractproperty
- def _var_scope(self):
- """Variable scope for this FisherFactor instance.
-
- Returns:
-      A string that uniquely identifies this FisherFactor instance.
- """
- pass
-
- @property
- def name(self):
- return self._var_scope
-
- @abc.abstractproperty
- def _cov_shape(self):
- """The shape of the variable backing this FisherFactor."""
- pass
-
- @abc.abstractproperty
- def _num_sources(self):
- """The number of things to sum over when updating covariance variable.
-
- The default make_covariance_update_op function will call _compute_new_cov
- with indices ranging from 0 to _num_sources-1. The typical situation is
- where the factor wants to sum the statistics it computes over multiple
- backpropped "gradients" (typically passed in via "tensors" or
- "outputs_grads" arguments).
- """
- pass
-
- @abc.abstractproperty
- def _num_towers(self):
- pass
-
- @abc.abstractproperty
- def _dtype(self):
- """dtype for variable backing this factor."""
- pass
-
- @property
- def _cov_initializer(self):
- """Function for initializing covariance variable."""
- return covariance_initializer
-
- def instantiate_cov_variables(self):
- """Makes the internal cov variable(s)."""
- assert self._cov is None
- with variable_scope.variable_scope(self._var_scope):
- self._cov = variable_scope.get_variable(
- "cov",
- initializer=self._cov_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
-
- @abc.abstractmethod
- def _compute_new_cov(self, source, tower):
- """Computes minibatch-estimated covariance for a single source.
-
- Args:
- source: int in [0, self._num_sources). Which source to use when computing
- the cov update.
- tower: int in [0, self._num_towers). Which tower to use when computing
- the cov update.
-
- Returns:
- Tensor of same shape as self.get_cov().
- """
- pass
-
- def make_covariance_update_op(self, ema_decay):
- """Constructs and returns the covariance update Op.
-
- Args:
- ema_decay: The exponential moving average decay (float or Tensor).
-
-    Returns:
- An Op for updating the covariance Variable referenced by _cov.
- """
- new_cov_contribs = []
- for source in range(self._num_sources):
- for tower in range(self._num_towers):
- device = (self._get_data_device(tower)
- if TOWER_STRATEGY == "separate" else None)
- with place_on_device(device):
- new_cov_contribs.append(self._compute_new_cov(source, tower))
-
- new_cov = math_ops.add_n(new_cov_contribs) / float(self._num_towers)
-
- # Compute average of 'new_cov' across all TPU cores. On a TPU, each
- # instance of 'new_cov' will be based on a different minibatch. This ensures
- # that by the end of assign_moving_average(), all TPU cores see the same
- # value for self._cov.
- #
- # Other implementations of make_covariance_update_op() that accumulate
- # statistics in other variables should mimic this behavior.
- if utils.on_tpu():
- new_cov = utils.cross_replica_mean(new_cov)
-
- return moving_averages.assign_moving_average(
- self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS)
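-
-  # In effect the op returned above performs, ignoring zero-debiasing,
-  #   cov <- ema_decay * cov + (1 - ema_decay) * new_cov
-  # where new_cov averages the per-(source, tower) contributions.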
-
- @abc.abstractmethod
- def _get_data_device(self, tower):
- pass
-
- @abc.abstractmethod
- def instantiate_inv_variables(self):
- """Makes the internal "inverse" variable(s)."""
- pass
-
- @abc.abstractmethod
- def make_inverse_update_ops(self):
- """Create and return update ops corresponding to registered computations."""
- pass
-
- def get_cov(self):
- return self._cov
-
- @abc.abstractmethod
- def get_cov_as_linear_operator(self):
- pass
-
- @abc.abstractmethod
- def register_matpower(self, exp, damping_func):
- pass
-
- @abc.abstractmethod
- def register_cholesky(self, damping_func):
- pass
-
- @abc.abstractmethod
- def register_cholesky_inverse(self, damping_func):
- pass
-
- @abc.abstractmethod
- def get_matpower(self, exp, damping_func):
- pass
-
- @abc.abstractmethod
- def get_cholesky(self, damping_func):
- pass
-
- @abc.abstractmethod
- def get_cholesky_inverse(self, damping_func):
- pass
-
-
-class DenseSquareMatrixFactor(FisherFactor):
- """Base class for FisherFactors that are stored as dense square matrices.
-
- This class explicitly calculates and stores inverses of their `cov` matrices,
- which must be square dense matrices.
-
- Subclasses must implement the _compute_new_cov method, and the _var_scope and
- _cov_shape properties.
- """
-
- # TODO(b/69108481): This class (and its subclasses) should be refactored to
- # serve the matrix quantities it computes as both (potentially stale)
- # variables, updated by the inverse update ops, and fresh values stored in
-  # tensors that are recomputed once every session.run() call. Currently
-  # matpower and damp_inverse have the former behavior, while
-  # eigendecomposition has the latter.
-
- def __init__(self):
- self._matpower_by_exp_and_damping = {} # { (float, hashable): variable }
- self._matpower_registrations = set() # { (float, hashable) }
- self._eigendecomp = None
- self._damping_funcs_by_id = {} # {hashable: lambda}
-
- self._cholesky_registrations = set() # { hashable }
- self._cholesky_inverse_registrations = set() # { hashable }
-
- self._cholesky_by_damping = {} # { hashable: variable }
- self._cholesky_inverse_by_damping = {} # { hashable: variable }
-
- super(DenseSquareMatrixFactor, self).__init__()
-
- def get_cov_as_linear_operator(self):
- assert self.get_cov().shape.ndims == 2
- return lo.LinearOperatorFullMatrix(self.get_cov(),
- is_self_adjoint=True,
- is_square=True)
-
- def _register_damping(self, damping_func):
- damping_id = graph_func_to_id(damping_func)
- if damping_id not in self._damping_funcs_by_id:
- self._damping_funcs_by_id[damping_id] = damping_func
- return damping_id
-
- def register_inverse(self, damping_func):
- # Just for backwards compatibility of some old code and tests
- self.register_matpower(-1, damping_func)
-
- def register_matpower(self, exp, damping_func):
- """Registers a matrix power to be maintained and served on demand.
-
- This creates a variable and signals make_inverse_update_ops to make the
- corresponding update op. The variable can be read via the method
- get_matpower.
-
- Args:
- exp: float. The exponent to use in the matrix power.
- damping_func: A function that computes a 0-D Tensor or a float which will
- be the damping value used. i.e. damping = damping_func().
- """
- if exp == 1.0:
- return
-
- damping_id = self._register_damping(damping_func)
-
- if (exp, damping_id) not in self._matpower_registrations:
- self._matpower_registrations.add((exp, damping_id))
-
- def register_cholesky(self, damping_func):
- """Registers a Cholesky factor to be maintained and served on demand.
-
- This creates a variable and signals make_inverse_update_ops to make the
- corresponding update op. The variable can be read via the method
- get_cholesky.
-
- Args:
- damping_func: A function that computes a 0-D Tensor or a float which will
- be the damping value used. i.e. damping = damping_func().
- """
- damping_id = self._register_damping(damping_func)
-
- if damping_id not in self._cholesky_registrations:
- self._cholesky_registrations.add(damping_id)
-
- def register_cholesky_inverse(self, damping_func):
- """Registers an inverse Cholesky factor to be maintained/served on demand.
-
- This creates a variable and signals make_inverse_update_ops to make the
- corresponding update op. The variable can be read via the method
- get_cholesky_inverse.
-
- Args:
- damping_func: A function that computes a 0-D Tensor or a float which will
- be the damping value used. i.e. damping = damping_func().
- """
- damping_id = self._register_damping(damping_func)
-
- if damping_id not in self._cholesky_inverse_registrations:
- self._cholesky_inverse_registrations.add(damping_id)
-
- def instantiate_inv_variables(self):
- """Makes the internal "inverse" variable(s)."""
-
- for (exp, damping_id) in self._matpower_registrations:
- exp_string = scalar_or_tensor_to_string(exp)
- damping_func = self._damping_funcs_by_id[damping_id]
- damping_string = graph_func_to_string(damping_func)
- with variable_scope.variable_scope(self._var_scope):
- matpower = variable_scope.get_variable(
- "matpower_exp{}_damp{}".format(exp_string, damping_string),
- initializer=inverse_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
- assert (exp, damping_id) not in self._matpower_by_exp_and_damping
- self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower
-
- for damping_id in self._cholesky_registrations:
- damping_func = self._damping_funcs_by_id[damping_id]
- damping_string = graph_func_to_string(damping_func)
- with variable_scope.variable_scope(self._var_scope):
- chol = variable_scope.get_variable(
- "cholesky_damp{}".format(damping_string),
- initializer=inverse_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
- assert damping_id not in self._cholesky_by_damping
- self._cholesky_by_damping[damping_id] = chol
-
- for damping_id in self._cholesky_inverse_registrations:
- damping_func = self._damping_funcs_by_id[damping_id]
- damping_string = graph_func_to_string(damping_func)
- with variable_scope.variable_scope(self._var_scope):
- cholinv = variable_scope.get_variable(
- "cholesky_inverse_damp{}".format(damping_string),
- initializer=inverse_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
- assert damping_id not in self._cholesky_inverse_by_damping
- self._cholesky_inverse_by_damping[damping_id] = cholinv
-
- def make_inverse_update_ops(self):
- """Create and return update ops corresponding to registered computations."""
- ops = []
-
- num_inverses = sum(1 for (exp, _) in self._matpower_by_exp_and_damping
- if exp == -1)
-
- num_other_matpower = len(self._matpower_by_exp_and_damping) - num_inverses
-
- other_matrix_power_registered = num_other_matpower >= 1
-
- use_eig = (
- self._eigendecomp or other_matrix_power_registered or
- num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD)
-
- # We precompute these so we don't need to evaluate them multiple times (for
- # each matrix power that uses them)
- damping_value_by_id = {damping_id: math_ops.cast(
- self._damping_funcs_by_id[damping_id](), self._dtype)
- for damping_id in self._damping_funcs_by_id}
-
- if use_eig:
- eigenvalues, eigenvectors = self.get_eigendecomp() # pylint: disable=unpacking-non-sequence
-
- for (exp, damping_id), matpower in (
- self._matpower_by_exp_and_damping.items()):
- damping = damping_value_by_id[damping_id]
- ops.append(
- matpower.assign(
- math_ops.matmul(eigenvectors *
- (eigenvalues + damping)**exp,
- array_ops.transpose(eigenvectors))))
- # These ops share computation and should be run on a single device.
- ops = [control_flow_ops.group(*ops)]
- else:
- for (exp, damping_id), matpower in (
- self._matpower_by_exp_and_damping.items()):
- assert exp == -1
- damping = damping_value_by_id[damping_id]
- ops.append(matpower.assign(utils.posdef_inv(self.get_cov(), damping)))
-
- # TODO(b/77902055): If inverses are being computed with Cholesky's
- # we can share the work. Instead this code currently just computes the
- # Cholesky a second time. It does at least share work between requests for
- # Cholesky's and Cholesky inverses with the same damping id.
- for damping_id, cholesky_inv in self._cholesky_inverse_by_damping.items():
- cholesky_ops = []
-
- damping = damping_value_by_id[damping_id]
- cholesky_value = utils.cholesky(self.get_cov(), damping)
-
- if damping_id in self._cholesky_by_damping:
- cholesky = self._cholesky_by_damping[damping_id]
- cholesky_ops.append(cholesky.assign(cholesky_value))
-
- identity = linalg_ops.eye(cholesky_value.shape.as_list()[0],
- dtype=cholesky_value.dtype)
- cholesky_inv_value = linalg_ops.matrix_triangular_solve(cholesky_value,
- identity)
- cholesky_ops.append(cholesky_inv.assign(cholesky_inv_value))
-
- ops.append(control_flow_ops.group(*cholesky_ops))
-
- for damping_id, cholesky in self._cholesky_by_damping.items():
- if damping_id not in self._cholesky_inverse_by_damping:
- damping = damping_value_by_id[damping_id]
- cholesky_value = utils.cholesky(self.get_cov(), damping)
- ops.append(cholesky.assign(cholesky_value))
-
- self._eigendecomp = False
- return ops
-
- def get_inverse(self, damping_func):
- # Just for backwards compatibility of some old code and tests
- return self.get_matpower(-1, damping_func)
-
- def get_matpower(self, exp, damping_func):
- # Note that this function returns a variable which gets updated by the
- # inverse ops. It may be stale / inconsistent with the latest value of
- # get_cov().
- if exp != 1:
- damping_id = graph_func_to_id(damping_func)
- matpower = self._matpower_by_exp_and_damping[(exp, damping_id)]
- else:
- matpower = self.get_cov()
- identity = linalg_ops.eye(matpower.shape.as_list()[0],
- dtype=matpower.dtype)
- matpower += math_ops.cast(damping_func(), dtype=matpower.dtype)*identity
-
- assert matpower.shape.ndims == 2
- return lo.LinearOperatorFullMatrix(matpower,
- is_non_singular=True,
- is_self_adjoint=True,
- is_positive_definite=True,
- is_square=True)
-
- def get_cholesky(self, damping_func):
- # Note that this function returns a variable which gets updated by the
- # inverse ops. It may be stale / inconsistent with the latest value of
- # get_cov().
- damping_id = graph_func_to_id(damping_func)
- cholesky = self._cholesky_by_damping[damping_id]
- assert cholesky.shape.ndims == 2
- return lo.LinearOperatorFullMatrix(cholesky,
- is_non_singular=True,
- is_square=True)
-
- def get_cholesky_inverse(self, damping_func):
- # Note that this function returns a variable which gets updated by the
- # inverse ops. It may be stale / inconsistent with the latest value of
- # get_cov().
- damping_id = graph_func_to_id(damping_func)
- cholesky_inv = self._cholesky_inverse_by_damping[damping_id]
- assert cholesky_inv.shape.ndims == 2
- return lo.LinearOperatorFullMatrix(cholesky_inv,
- is_non_singular=True,
- is_square=True)
-
- def get_eigendecomp(self):
- """Creates or retrieves eigendecomposition of self._cov."""
- # Unlike get_matpower this doesn't retrieve a stored variable, but instead
- # always computes a fresh version from the current value of get_cov().
- if not self._eigendecomp:
- eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self.get_cov())
-
- # The matrix self._cov is positive semidefinite by construction, but the
-      # numerical eigenvalues could be negative due to numerical errors, so
-      # here we clip them to be at least EIGENVALUE_CLIPPING_THRESHOLD.
- clipped_eigenvalues = math_ops.maximum(eigenvalues,
- EIGENVALUE_CLIPPING_THRESHOLD)
- self._eigendecomp = (clipped_eigenvalues, eigenvectors)
-
- return self._eigendecomp
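-
-  # e.g. (illustrative): with EIGENVALUE_CLIPPING_THRESHOLD == 0.0, a
-  # numerically computed eigenvalue of -1e-7 is clipped to 0.0 here before
-  # any damping is added by the matrix-power update ops above.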
-
-
-class FullFactor(DenseSquareMatrixFactor):
- """FisherFactor for a full matrix representation of the Fisher of a parameter.
-
- Note that this uses the naive "square the sum estimator", and so is applicable
- to any type of parameter in principle, but has very high variance.
- """
-
- def __init__(self,
- params_grads,
- batch_size):
- self._batch_size = batch_size
- self._params_grads = tuple(utils.ensure_sequence(params_grad)
- for params_grad in params_grads)
- super(FullFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_full_" + scope_string_from_params(
- [self._params_grads, self._batch_size])
-
- @property
- def _cov_shape(self):
- size = sum(param_grad.shape.num_elements()
- for param_grad in self._params_grads[0])
- return (size, size)
-
- @property
- def _num_sources(self):
- return len(self._params_grads)
-
- @property
- def _num_towers(self):
- return 1
-
- @property
- def _dtype(self):
- return self._params_grads[0][0].dtype
-
- def _compute_new_cov(self, source, tower):
- assert tower == 0
-
-    # This will be a very basic rank-1 estimate.
- params_grads_flat = utils.tensors_to_column(self._params_grads[source])
- return ((params_grads_flat * array_ops.transpose(
- params_grads_flat)) / math_ops.cast(self._batch_size,
- params_grads_flat.dtype))
-
- def _get_data_device(self, tower):
- return None
-
-
-class DiagonalFactor(FisherFactor):
- """A base class for FisherFactors that use diagonal approximations.
-
- A DiagonalFactor's covariance variable can be of any shape, but must contain
- exactly one entry per parameter.
- """
-
- def __init__(self):
- super(DiagonalFactor, self).__init__()
-
- def get_cov_as_linear_operator(self):
- assert self._matrix_diagonal.shape.ndims == 1
- return lo.LinearOperatorDiag(self._matrix_diagonal,
- is_self_adjoint=True,
- is_square=True)
-
- @property
- def _cov_initializer(self):
- return diagonal_covariance_initializer
-
- @property
- def _matrix_diagonal(self):
- return array_ops.reshape(self.get_cov(), [-1])
-
- def make_inverse_update_ops(self):
- return []
-
- def instantiate_inv_variables(self):
- pass
-
- def register_matpower(self, exp, damping_func):
- pass
-
- def register_cholesky(self, damping_func):
- pass
-
- def register_cholesky_inverse(self, damping_func):
- pass
-
- def get_matpower(self, exp, damping_func):
- matpower_diagonal = (self._matrix_diagonal
- + math_ops.cast(damping_func(), self._dtype))**exp
- return lo.LinearOperatorDiag(matpower_diagonal,
- is_non_singular=True,
- is_self_adjoint=True,
- is_positive_definite=True,
- is_square=True)
-
- def get_cholesky(self, damping_func):
- return self.get_matpower(0.5, damping_func)
-
- def get_cholesky_inverse(self, damping_func):
- return self.get_matpower(-0.5, damping_func)
-
-
-class NaiveDiagonalFactor(DiagonalFactor):
- """FisherFactor for a diagonal approximation of any type of param's Fisher.
-
- Note that this uses the naive "square the sum estimator", and so is applicable
- to any type of parameter in principle, but has very high variance.
- """
-
- def __init__(self,
- params_grads,
- batch_size):
- """Initializes NaiveDiagonalFactor instance.
-
- Args:
-      params_grads: Sequence of Tensors, each with the same shape as the
-        parameters this FisherFactor corresponds to. For example, the
-        gradient of the loss with respect to the parameters.
-      batch_size: int or 0-D Tensor. Size of the minibatch.
- """
- self._params_grads = tuple(utils.ensure_sequence(params_grad)
- for params_grad in params_grads)
- self._batch_size = batch_size
- super(NaiveDiagonalFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_naivediag_" + scope_string_from_params(
- [self._params_grads, self._batch_size])
-
- @property
- def _cov_shape(self):
- size = sum(param_grad.shape.num_elements()
- for param_grad in self._params_grads[0])
- return [size, 1]
-
- @property
- def _num_sources(self):
- return len(self._params_grads)
-
- @property
- def _num_towers(self):
- return 1
-
- @property
- def _dtype(self):
- return self._params_grads[0][0].dtype
-
- def _compute_new_cov(self, source, tower):
- assert tower == 0
-
- params_grads_flat = utils.tensors_to_column(self._params_grads[source])
- return (math_ops.square(params_grads_flat) / math_ops.cast(
- self._batch_size, params_grads_flat.dtype))
-
- def _get_data_device(self, tower):
- return None
-
-
-class EmbeddingInputKroneckerFactor(DiagonalFactor):
- r"""FisherFactor for input to an embedding layer.
-
-  Given input_ids of shape [batch_size, input_size] representing indices into
-  a [vocab_size, embedding_size] embedding matrix, the input covariance is
-  approximated by a diagonal matrix,
-
- Cov(input_ids, input_ids) =
- (1/batch_size) sum_{i} diag(n_hot(input[i]) ** 2).
-
- where n_hot() constructs an n-hot binary vector and diag() constructs a
- diagonal matrix of size [vocab_size, vocab_size].
- """
-
- def __init__(self, input_ids, vocab_size, dtype=None):
- """Instantiate EmbeddingInputKroneckerFactor.
-
- Args:
- input_ids: List of Tensors of shape [batch_size, input_size] and dtype
- int32. Indices into embedding matrix. List index is tower.
-      vocab_size: int or 0-D Tensor. Size of the vocabulary; entries of
-          'input_ids' must lie in [0, vocab_size).
- dtype: dtype for covariance statistics. Must be a floating point type.
- Defaults to float32.
- """
- self._input_ids = input_ids
- self._vocab_size = vocab_size
- self._cov_dtype = dtype or dtypes.float32
-
- super(EmbeddingInputKroneckerFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_diag_embedding_" + scope_string_from_params(self._input_ids)
-
- @property
- def _cov_shape(self):
- return [self._vocab_size]
-
- @property
- def _num_sources(self):
- return 1
-
- @property
- def _num_towers(self):
- return len(self._input_ids)
-
- @property
- def _dtype(self):
- return self._cov_dtype
-
- def _compute_new_cov(self, source, tower):
- assert source == 0
-
- input_ids = self._input_ids[tower]
-
- if len(input_ids.shape) > 2:
- raise ValueError(
- "Input to embeddings must have rank <= 2. Found rank %d." % len(
- input_ids.shape))
-
- batch_size = array_ops.shape(input_ids)[0]
-
- # Transform indices into one-hot vectors.
- #
- # TODO(b/72714822): There must be a faster way to construct the diagonal
- # covariance matrix! This operation is O(batch_size * vocab_size), where
- # it should be O(batch_size * input_size).
- flat_input_ids = array_ops.reshape(input_ids, [-1])
- one_hots = array_ops.one_hot(flat_input_ids,
- self._vocab_size) # [?, vocab_size]
-
- # Take average across examples. Note that, because all entries have
- # magnitude zero or one, there's no need to square the entries.
- #
- # TODO(b/72714822): Support for SparseTensor, other kinds of aggregation
- # within an example such as average.
- #
- # TODO(b/72714822): Support for partitioned embeddings.
- new_cov = math_ops.reduce_sum(one_hots, axis=0) # [vocab_size]
- new_cov /= math_ops.cast(batch_size, new_cov.dtype)
-
- return new_cov
-
- def _get_data_device(self, tower):
- return self._input_ids[tower].device
-
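A NumPy sketch of the covariance update above, with a hypothetical id batch; since one-hot entries are 0 or 1, summing them already equals summing their squares:

import numpy as np

vocab_size = 6
input_ids = np.array([[0, 3], [3, 5], [1, 3]])    # hypothetical [batch, input]
batch_size = input_ids.shape[0]

# One-hot every id, sum over all ids, and average over examples, mirroring
# _compute_new_cov: the result is the [vocab_size] covariance diagonal.
one_hots = np.eye(vocab_size)[input_ids.reshape(-1)]   # [batch*input, vocab]
new_cov = one_hots.sum(axis=0) / batch_size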
-
-class FullyConnectedDiagonalFactor(DiagonalFactor):
- r"""FisherFactor for a diagonal approx of a fully-connected layer's Fisher.
-
- Given in = [batch_size, input_size] and out_grad = [batch_size, output_size],
- approximates the covariance as,
-
- Cov(in, out) = (1/batch_size) sum_{i} outer(in[i], out_grad[i]) ** 2.0
-
- where the square is taken element-wise.
- """
-
- def __init__(self,
- inputs,
- outputs_grads,
- has_bias=False):
- """Instantiate FullyConnectedDiagonalFactor.
-
- Args:
- inputs: List of Tensors of shape [batch_size, input_size]. Inputs to this
- layer. List index is towers.
- outputs_grads: List of Tensors, each of shape [batch_size, output_size],
- which are the gradients of the loss with respect to the layer's
- outputs. First index is source, second is tower.
-      has_bias: bool. If True, append '1' to each input.
- """
- self._inputs = inputs
- self._has_bias = has_bias
- self._outputs_grads = outputs_grads
- self._squared_inputs = None
-
- super(FullyConnectedDiagonalFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_diagfc_" + scope_string_from_params(
- tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads)))
-
- @property
- def _cov_shape(self):
- input_size = self._inputs[0].shape[1] + self._has_bias
- output_size = self._outputs_grads[0][0].shape[1]
- return [input_size, output_size]
-
- @property
- def _num_sources(self):
- return len(self._outputs_grads)
-
- @property
- def _num_towers(self):
- return len(self._inputs)
-
- @property
- def _dtype(self):
- return self._outputs_grads[0][0].dtype
-
- def make_covariance_update_op(self, ema_decay):
-
- self._squared_inputs = []
- for tower in range(self._num_towers):
- inputs = self._inputs[tower]
-
- with place_on_device(self._get_data_device(tower)):
- if self._has_bias:
- inputs = append_homog(inputs)
- self._squared_inputs.append(math_ops.square(inputs))
-
- return super(FullyConnectedDiagonalFactor, self).make_covariance_update_op(
- ema_decay)
-
- def _compute_new_cov(self, source, tower):
- batch_size = array_ops.shape(self._squared_inputs[tower])[0]
- outputs_grad = self._outputs_grads[source][tower]
-
- # The well-known special formula that uses the fact that the entry-wise
- # square of an outer product is the outer-product of the entry-wise squares.
- # The gradient is the outer product of the input and the output gradients,
- # so we just square both and then take their outer-product.
- new_cov = math_ops.matmul(
- self._squared_inputs[tower],
- math_ops.square(outputs_grad),
- transpose_a=True)
- new_cov /= math_ops.cast(batch_size, new_cov.dtype)
- return new_cov
-
- def _get_data_device(self, tower):
- return self._inputs[tower].device
-
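The identity invoked in _compute_new_cov above is easy to check numerically; a tiny NumPy sketch with hypothetical vectors:

import numpy as np

a = np.array([1.0, 2.0, 3.0])   # one row of (possibly bias-augmented) inputs
g = np.array([0.5, -1.5])       # the matching row of output gradients

# The entry-wise square of an outer product equals the outer product of the
# entry-wise squares, so squared inputs can simply be matmul'd with squared
# gradients across the batch.
assert np.allclose(np.outer(a, g) ** 2, np.outer(a ** 2, g ** 2))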
-
-class ConvDiagonalFactor(DiagonalFactor):
- """FisherFactor for a diagonal approx of a convolutional layer's Fisher."""
-
- def __init__(self,
- inputs,
- outputs_grads,
- filter_shape,
- strides,
- padding,
- data_format=None,
- dilations=None,
- has_bias=False):
- """Creates a ConvDiagonalFactor object.
-
- Args:
- inputs: List of Tensors of shape [batch_size, height, width, in_channels].
- Input activations to this layer. List index is towers.
- outputs_grads: List of Tensors, each of shape [batch_size,
- height, width, out_channels], which are the gradients of the loss
- with respect to the layer's outputs. First index is source, second
- index is tower.
- filter_shape: Tuple of 4 ints: (kernel_height, kernel_width, in_channels,
- out_channels). Represents shape of kernel used in this layer.
- strides: The stride size in this layer (1-D Tensor of length 4).
-      padding: str. The padding method used in this layer ("SAME" or "VALID").
- data_format: None or str. Format of conv2d inputs.
- dilations: None or tuple of 4 ints.
- has_bias: Python bool. If True, the layer is assumed to have a bias
- parameter in addition to its filter parameter.
-
- Raises:
-      ValueError: If inputs, outputs_grads, and filter_shape do not agree on
- in_channels or out_channels.
- ValueError: If strides, dilations are not length-4 lists of ints.
- ValueError: If data_format does not put channel last.
- """
- if not utils.is_data_format_channel_last(data_format):
- raise ValueError("Channel must be last.")
- if any(input_.shape.ndims != 4 for input_ in inputs):
- raise ValueError("inputs must be a list of 4-D Tensors.")
- if any(input_.shape.as_list()[-1] != filter_shape[-2] for input_ in inputs):
- raise ValueError("inputs and filter_shape must agree on in_channels.")
- for i, outputs_grad in enumerate(outputs_grads):
- if any(output_grad.shape.ndims != 4 for output_grad in outputs_grad):
- raise ValueError("outputs[%d] must be 4-D Tensor." % i)
- if any(output_grad.shape.as_list()[-1] != filter_shape[-1]
- for output_grad in outputs_grad):
- raise ValueError(
- "outputs[%d] and filter_shape must agree on out_channels." % i)
- if len(strides) != 4:
- raise ValueError("strides must be length-4 list of ints.")
- if dilations is not None and len(dilations) != 4:
- raise ValueError("dilations must be length-4 list of ints.")
-
- self._inputs = inputs
- self._outputs_grads = outputs_grads
- self._filter_shape = filter_shape
- self._strides = strides
- self._padding = padding
- self._data_format = data_format
- self._dilations = dilations
- self._has_bias = has_bias
- self._patches = None
-
- super(ConvDiagonalFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_convdiag_" + scope_string_from_params(
- tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads)))
-
- @property
- def _cov_shape(self):
- filter_height, filter_width, in_channels, out_channels = self._filter_shape
- return [
- filter_height * filter_width * in_channels + self._has_bias,
- out_channels
- ]
-
- @property
- def _num_sources(self):
- return len(self._outputs_grads)
-
- @property
- def _num_towers(self):
- return len(self._inputs)
-
- @property
- def _dtype(self):
- return self._inputs[0].dtype
-
- def make_covariance_update_op(self, ema_decay):
- filter_height, filter_width, _, _ = self._filter_shape
-
- # TODO(b/64144716): there is potential here for a big savings in terms
- # of memory use.
- if self._dilations is None:
- rates = (1, 1, 1, 1)
- else:
- rates = tuple(self._dilations)
-
- self._patches = []
- for tower in range(self._num_towers):
- with place_on_device(self._get_data_device(tower)):
- patches = array_ops.extract_image_patches(
- self._inputs[tower],
- ksizes=[1, filter_height, filter_width, 1],
- strides=self._strides,
- rates=rates,
- padding=self._padding)
-
- if self._has_bias:
- patches = append_homog(patches)
-
- self._patches.append(patches)
-
- return super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay)
-
- def _compute_new_cov(self, source, tower):
- patches = self._patches[tower]
- batch_size = array_ops.shape(patches)[0]
- outputs_grad = self._outputs_grads[source][tower]
-
- new_cov = self._convdiag_sum_of_squares(patches, outputs_grad)
- new_cov /= math_ops.cast(batch_size, new_cov.dtype)
-
- return new_cov
-
- def _convdiag_sum_of_squares(self, patches, outputs_grad):
- # This computes the sum of the squares of the per-training-case "gradients".
- # It does this simply by computing a giant tensor containing all of these,
-    # doing an entry-wise square, and then summing along the batch dimension.
- case_wise_gradients = special_math_ops.einsum("bijk,bijl->bkl", patches,
- outputs_grad)
- return math_ops.reduce_sum(math_ops.square(case_wise_gradients), axis=0)
-
- def _get_data_device(self, tower):
- return self._inputs[tower].device
-
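A NumPy sketch of _convdiag_sum_of_squares with hypothetical shapes, showing that the einsum forms one flattened per-example gradient per batch element before squaring and summing:

import numpy as np

rng = np.random.default_rng(1)
patches = rng.normal(size=(2, 3, 3, 4))     # [batch, h, w, patch_size]
out_grads = rng.normal(size=(2, 3, 3, 5))   # [batch, h, w, out_channels]

# "bijk,bijl->bkl" sums outer products over the spatial grid, yielding the
# per-example parameter gradient; square and sum over the batch afterwards.
case_grads = np.einsum("bijk,bijl->bkl", patches, out_grads)
sum_sq = (case_grads ** 2).sum(axis=0)      # [patch_size, out_channels]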
-
-class FullyConnectedKroneckerFactor(DenseSquareMatrixFactor):
- """Kronecker factor for the input or output side of a fully-connected layer.
- """
-
- def __init__(self,
- tensors,
- has_bias=False):
- """Instantiate FullyConnectedKroneckerFactor.
-
- Args:
- tensors: List of list of Tensors, each of shape [batch_size, n]. The
- Tensors are typically either a layer's inputs or its output's gradients.
- The first list index is source, the second is tower.
- has_bias: bool. If True, append '1' to each row.
- """
- # The tensor argument is either a tensor of input activations or a tensor of
- # output pre-activation gradients.
- self._has_bias = has_bias
- self._tensors = tensors
- super(FullyConnectedKroneckerFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_fckron_" + scope_string_from_params(
- tuple(nest.flatten(self._tensors)) + (self._has_bias,))
-
- @property
- def _cov_shape(self):
- size = self._tensors[0][0].shape[1] + self._has_bias
- return [size, size]
-
- @property
- def _num_sources(self):
- return len(self._tensors)
-
- @property
- def _num_towers(self):
- return len(self._tensors[0])
-
- @property
- def _dtype(self):
- return self._tensors[0][0].dtype
-
- def _compute_new_cov(self, source, tower):
- tensor = self._tensors[source][tower]
- if self._has_bias:
- tensor = append_homog(tensor)
- return compute_cov(tensor)
-
- def _get_data_device(self, tower):
- return self._tensors[0][tower].device
-
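A NumPy sketch of what this factor accumulates for one tower (hypothetical activations), assuming compute_cov normalizes by its input's first dimension as the comments in ConvInputKroneckerFactor below describe:

import numpy as np

x = np.array([[1.0, 2.0], [3.0, 4.0]])   # hypothetical [batch, n] activations

# With has_bias=True, append_homog adds a column of ones, so the factor also
# tracks first moments E[a] alongside E[a a^T].
x_h = np.concatenate([x, np.ones((x.shape[0], 1))], axis=1)
cov = x_h.T @ x_h / x.shape[0]           # what compute_cov(tensor) returns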
-
-class ConvInputKroneckerFactor(DenseSquareMatrixFactor):
- r"""Kronecker factor for the input side of a convolutional layer.
-
-  Estimates E[ a a^T ] where a is the input to a convolutional layer given
- example x. Expectation is taken over all examples and locations.
-
-  Equivalent to Omega in https://arxiv.org/abs/1602.01407. See Section 3.1,
-  "Estimating the factors", for details.
- """
-
- def __init__(self,
- inputs,
- filter_shape,
- padding,
- strides=None,
- dilation_rate=None,
- data_format=None,
- extract_patches_fn=None,
- has_bias=False,
- sub_sample_inputs=None,
- sub_sample_patches=None):
- """Initializes ConvInputKroneckerFactor.
-
- Args:
- inputs: List of Tensors of shape [batch_size, ..spatial_input_size..,
- in_channels]. Inputs to layer. List index is tower.
- filter_shape: List of ints. Contains [..spatial_filter_size..,
- in_channels, out_channels]. Shape of convolution kernel.
- padding: str. Padding method for layer. "SAME" or "VALID".
- strides: List of ints or None. Contains [..spatial_filter_strides..] if
- 'extract_patches_fn' is compatible with tf.nn.convolution(), else
- [1, ..spatial_filter_strides, 1].
- dilation_rate: List of ints or None. Rate for dilation along each spatial
- dimension if 'extract_patches_fn' is compatible with
- tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
- data_format: str or None. Format of input data.
- extract_patches_fn: str or None. Name of function that extracts image
- patches. One of "extract_convolution_patches", "extract_image_patches",
- "extract_pointwise_conv2d_patches".
- has_bias: bool. If True, append 1 to in_channel.
- sub_sample_inputs: `bool`. If True, then subsample the inputs from which
- the image patches are extracted. (Default: None)
-      sub_sample_patches: `bool`. If `True`, subsample the extracted
-          patches. (Default: None)
- """
- self._inputs = inputs
- self._filter_shape = filter_shape
- self._strides = strides
- self._padding = padding
- self._dilation_rate = dilation_rate
- self._data_format = data_format
- self._extract_patches_fn = extract_patches_fn
- self._has_bias = has_bias
- if sub_sample_inputs is None:
- self._sub_sample_inputs = _SUB_SAMPLE_INPUTS
- else:
- self._sub_sample_inputs = sub_sample_inputs
-
- if sub_sample_patches is None:
- self._sub_sample_patches = _SUB_SAMPLE_OUTER_PRODUCTS
- else:
- self._sub_sample_patches = sub_sample_patches
- super(ConvInputKroneckerFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_convinkron_" + scope_string_from_params(
- tuple(self._inputs) +
- tuple((self._filter_shape, self._strides, self._padding,
- self._dilation_rate, self._data_format, self._has_bias)))
-
- @property
- def _cov_shape(self):
- spatial_filter_shape = self._filter_shape[0:-2]
- in_channels = self._filter_shape[-2]
- size = np.prod(spatial_filter_shape) * in_channels + self._has_bias
- return [size, size]
-
- @property
- def _num_sources(self):
- return 1
-
- @property
- def _num_towers(self):
- return len(self._inputs)
-
- @property
- def _dtype(self):
- return self._inputs[0].dtype
-
- def _compute_new_cov(self, source, tower):
- assert source == 0
-
- inputs = self._inputs[tower]
- if self._sub_sample_inputs:
- batch_size = inputs.shape.as_list()[0]
- max_size = int(batch_size * _INPUTS_TO_EXTRACT_PATCHES_FACTOR)
- inputs = _random_tensor_gather(inputs, max_size)
-
- # TODO(b/64144716): there is potential here for a big savings in terms of
- # memory use.
- if self._extract_patches_fn in [None, "extract_convolution_patches"]:
- patches = utils.extract_convolution_patches(
- inputs,
- self._filter_shape,
- padding=self._padding,
- strides=self._strides,
- dilation_rate=self._dilation_rate,
- data_format=self._data_format)
-
- elif self._extract_patches_fn == "extract_image_patches":
- assert inputs.shape.ndims == 4
- assert len(self._filter_shape) == 4
- assert len(self._strides) == 4, self._strides
- if self._dilation_rate is None:
- rates = [1, 1, 1, 1]
- else:
- rates = self._dilation_rate
- assert len(rates) == 4
- assert rates[0] == rates[-1] == 1
- patches = array_ops.extract_image_patches(
- inputs,
- ksizes=[1] + list(self._filter_shape[0:-2]) + [1],
- strides=self._strides,
- rates=rates,
- padding=self._padding)
-
- elif self._extract_patches_fn == "extract_pointwise_conv2d_patches":
- assert self._strides in [None, [1, 1, 1, 1], (1, 1, 1, 1)]
- assert self._filter_shape[0] == self._filter_shape[1] == 1
- patches = utils.extract_pointwise_conv2d_patches(
- inputs, self._filter_shape, data_format=None)
-
- else:
- raise NotImplementedError(self._extract_patches_fn)
-
- flatten_size = np.prod(self._filter_shape[0:-1])
- # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde
- # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14),
- # where M = minibatch size, |T| = number of spatial locations,
- # |Delta| = number of spatial offsets, and J = number of input maps
- # for convolutional layer l.
- patches_flat = array_ops.reshape(patches, [-1, flatten_size])
-
-    if self._sub_sample_patches:
-      patches_flat = _subsample_for_cov_computation(patches_flat)
-
-    # We append a homogeneous coordinate to patches_flat if the layer has
-    # bias parameters. This gives us [[A_l]]_H from the paper.
- if self._has_bias:
- patches_flat = append_homog(patches_flat)
- # We call compute_cov without passing in a normalizer. compute_cov uses
- # the first dimension of patches_flat i.e. M|T| as the normalizer by
- # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with
- # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from
- # the paper but has a different scale here for consistency with
- # ConvOutputKroneckerFactor.
- # (Tilde omitted over A for clarity.)
- return compute_cov(patches_flat)
-
- def _get_data_device(self, tower):
- return self._inputs[tower].device
-
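Shape bookkeeping for the patches_flat matrix ([[A_l]] in the KFC paper), with hypothetical sizes: M examples, |T| spatial locations, J input maps, and |Delta| = kernel offsets:

import numpy as np

M, T, kh, kw, J = 8, 25, 3, 3, 16                  # hypothetical sizes
patches = np.zeros((M, 5, 5, kh * kw * J))         # extract_image_patches-like
patches_flat = patches.reshape(-1, kh * kw * J)    # [M*|T|, J*|Delta|]
assert patches_flat.shape == (M * T, kh * kw * J)
# compute_cov then yields a [J*|Delta|, J*|Delta|] matrix scaled by 1/(M*|T|).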
-
-class ConvOutputKroneckerFactor(DenseSquareMatrixFactor):
- r"""Kronecker factor for the output side of a convolutional layer.
-
-  Estimates E[ ds ds^T ] where s denotes the preactivations of a convolutional
-  layer given example x and ds = (d / d s) log(p(y|x, w)). Expectation is
-  taken over all examples and locations.
-
-  Equivalent to Gamma in https://arxiv.org/abs/1602.01407. See Section 3.1,
-  "Estimating the factors", for details.
- """
-
- def __init__(self, outputs_grads, data_format=None):
- """Initializes ConvOutputKroneckerFactor.
-
- Args:
- outputs_grads: List of list of Tensors. Each Tensor is of shape
- [batch_size, ..spatial_input_size.., out_channels]. First list index
- is source, the second is tower.
- data_format: None or str. Format of outputs_grads.
-
- Raises:
- ValueError: If channels are not final dimension.
- """
- if not utils.is_data_format_channel_last(data_format):
- raise ValueError("Channel must be last.")
- self._out_channels = outputs_grads[0][0].shape.as_list()[-1]
- self._outputs_grads = outputs_grads
- super(ConvOutputKroneckerFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_convoutkron_" + scope_string_from_params(
- nest.flatten(self._outputs_grads))
-
- @property
- def _cov_shape(self):
- size = self._out_channels
- return [size, size]
-
- @property
- def _num_sources(self):
- return len(self._outputs_grads)
-
- @property
- def _num_towers(self):
- return len(self._outputs_grads[0])
-
- @property
- def _dtype(self):
- return self._outputs_grads[0][0].dtype
-
- def _compute_new_cov(self, source, tower):
- outputs_grad = self._outputs_grads[source][tower]
-
- # reshaped_tensor below is the matrix DS_l defined in the KFC paper
- # (tilde omitted over S for clarity). It has shape M|T| x I, where
- # M = minibatch size, |T| = number of spatial locations, and
- # I = number of output maps for convolutional layer l.
- reshaped_tensor = array_ops.reshape(outputs_grad, [-1, self._out_channels])
- # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov,
- # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l
- # as defined in the paper, with shape I x I.
- # (Tilde omitted over S for clarity.)
- return compute_cov(reshaped_tensor)
-
- def _get_data_device(self, tower):
- return self._outputs_grads[0][tower].device
-
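The two conv factors are combined elsewhere (in fisher_blocks.py) into a Kronecker-factored approximation; a NumPy sketch of the shape arithmetic with hypothetical factor sizes:

import numpy as np

omega = np.eye(3 * 3 * 4)    # hypothetical input factor, [J|Delta|, J|Delta|]
gamma = np.eye(8)            # hypothetical output factor, [I, I]

# The layer's Fisher is approximated as omega (x) gamma (Grosse & Martens,
# 2016), which is far cheaper to store and invert than the full matrix.
fisher_approx = np.kron(omega, gamma)
assert fisher_approx.shape == (omega.shape[0] * gamma.shape[0],) * 2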
-
-class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
- """Kronecker factor for a fully connected layer used multiple times."""
-
- def __init__(self,
- tensors,
- num_uses=None,
- has_bias=False):
- """Constructs a new `FullyConnectedMultiKF`.
-
- Args:
-      tensors: List of lists of Tensors, each of shape
-          [num_uses * batch_size, n], which is a reshaped version of a Tensor
-          of shape [num_uses, batch_size, n]. Each of these tensors is usually
-          a layer's inputs or its output's gradients. The first list index is
-          sources, the second is towers.
- num_uses: int. The number of time-steps / uses.
- has_bias: bool. If True, '1' is appended to each row.
- """
-
- self._num_uses = num_uses
-
- self._cov_dt1 = None
- self._make_cov_dt1 = False
- self._option1quants_by_damping = {}
- self._option2quants_by_damping = {}
- self._option1quants_registrations = set()
- self._option2quants_registrations = set()
-
- super(FullyConnectedMultiKF, self).__init__(tensors=tensors,
- has_bias=has_bias)
-
- @property
- def _num_timesteps(self):
- return self._num_uses
-
- @property
- def _var_scope(self):
- return "ff_fc_multi_" + scope_string_from_params(
- tuple(nest.flatten(self._tensors))
- + (self._num_timesteps, self._has_bias,))
-
- def make_covariance_update_op(self, ema_decay):
-
- op = super(FullyConnectedMultiKF, self).make_covariance_update_op(ema_decay)
-
- if self._cov_dt1 is not None:
- new_cov_dt1_contribs = []
- for source in range(self._num_sources):
- for tower in range(self._num_towers):
- with place_on_device(self._get_data_device(tower)):
- new_cov_dt1_contribs.append(self._compute_new_cov_dt1(source,
- tower))
-
- new_cov_dt1 = (math_ops.add_n(new_cov_dt1_contribs)
- / float(self._num_towers))
-
- # See comments in FisherFactor.make_covariance_update_op() for details.
- if utils.on_tpu():
- new_cov_dt1 = utils.cross_replica_mean(new_cov_dt1)
-
- op2 = moving_averages.assign_moving_average(
- self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS)
-
- # TODO(b/69112164):
- # It's important that _cov and _cov_dt1 remain consistent with each
- # other while the inverse ops are happening. How can we ensure this?
- # We will need to add explicit synchronization for this to
- # work with asynchronous training.
- op = control_flow_ops.group(op, op2)
-
- return op
-
- def _compute_new_cov_dt1(self, source, tower): # pylint: disable=missing-docstring
- tensor = self._tensors[source][tower]
- if self._has_bias:
- # This appending is technically done twice (the other time is for
- # _compute_new_cov())
- tensor = append_homog(tensor)
-
- total_len = array_ops.shape(tensor)[0]
- batch_size = total_len // self._num_timesteps
-
- tensor_present = tensor[:-batch_size, :]
- tensor_future = tensor[batch_size:, :]
-
- # We specify a normalizer for this computation to ensure a PSD Fisher
- # block estimate. This is equivalent to padding with zeros, as was done
- # in Section B.2 of the appendix.
- return compute_cov(
- tensor_future, tensor_right=tensor_present, normalizer=total_len)
-
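A NumPy sketch of the present/future slicing in _compute_new_cov_dt1, with hypothetical sizes; rows are grouped by use/time-step, so shifting by one batch pairs each use with its successor:

import numpy as np

num_uses, batch_size, n = 3, 2, 4
tensor = np.arange(num_uses * batch_size * n, dtype=float).reshape(-1, n)

present = tensor[:-batch_size]   # uses 0 .. num_uses-2
future = tensor[batch_size:]     # uses 1 .. num_uses-1

# Cross-covariance between consecutive uses, normalized by the total row
# count (not the number of pairs) to keep the block estimate PSD.
cov_dt1 = future.T @ present / tensor.shape[0]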
- def _get_data_device(self, tower):
- return self._tensors[0][tower].device
-
- @property
- def _vec_shape(self):
- size = self._tensors[0][0].shape[1] + self._has_bias
- return [size]
-
- def get_option1quants(self, damping_func):
- damping_id = graph_func_to_id(damping_func)
- return self._option1quants_by_damping[damping_id]
-
- def get_option2quants(self, damping_func):
- damping_id = graph_func_to_id(damping_func)
- return self._option2quants_by_damping[damping_id]
-
- def get_cov_dt1(self):
- assert self._cov_dt1 is not None
- return self._cov_dt1
-
- def register_cov_dt1(self):
- self._make_cov_dt1 = True
-
- def instantiate_cov_variables(self):
- super(FullyConnectedMultiKF, self).instantiate_cov_variables()
- assert self._cov_dt1 is None
- if self._make_cov_dt1:
- with variable_scope.variable_scope(self._var_scope):
- self._cov_dt1 = variable_scope.get_variable(
- "cov_dt1",
- initializer=init_ops.zeros_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
-
- def register_option1quants(self, damping_func):
- damping_id = self._register_damping(damping_func)
- if damping_id not in self._option1quants_registrations:
- self._option1quants_registrations.add(damping_id)
-
- def register_option2quants(self, damping_func):
- damping_id = self._register_damping(damping_func)
- if damping_id not in self._option2quants_registrations:
- self._option2quants_registrations.add(damping_id)
-
- def instantiate_inv_variables(self):
- super(FullyConnectedMultiKF, self).instantiate_inv_variables()
-
- for damping_id in self._option1quants_registrations:
- damping_func = self._damping_funcs_by_id[damping_id]
- damping_string = graph_func_to_string(damping_func)
-      # It's questionable whether we should initialize with stuff like
- # this at all. Ideally these values should never be used until they are
- # updated at least once.
- with variable_scope.variable_scope(self._var_scope):
- Lmat = variable_scope.get_variable( # pylint: disable=invalid-name
- "Lmat_damp{}".format(damping_string),
- initializer=inverse_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
- psi = variable_scope.get_variable(
- "psi_damp{}".format(damping_string),
- initializer=init_ops.ones_initializer,
- shape=self._vec_shape,
- trainable=False,
- dtype=self._dtype)
-
- assert damping_id not in self._option1quants_by_damping
- self._option1quants_by_damping[damping_id] = (Lmat, psi)
-
- for damping_id in self._option2quants_registrations:
- damping_func = self._damping_funcs_by_id[damping_id]
- damping_string = graph_func_to_string(damping_func)
-      # It's questionable whether we should initialize with stuff like
- # this at all. Ideally these values should never be used until they are
- # updated at least once.
- with variable_scope.variable_scope(self._var_scope):
- Pmat = variable_scope.get_variable( # pylint: disable=invalid-name
- "Lmat_damp{}".format(damping_string),
- initializer=inverse_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
- Kmat = variable_scope.get_variable( # pylint: disable=invalid-name
- "Kmat_damp{}".format(damping_string),
- initializer=inverse_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
- mu = variable_scope.get_variable(
- "mu_damp{}".format(damping_string),
- initializer=init_ops.ones_initializer,
- shape=self._vec_shape,
- trainable=False,
- dtype=self._dtype)
-
- assert damping_id not in self._option2quants_by_damping
- self._option2quants_by_damping[damping_id] = (Pmat, Kmat, mu)
-
- def make_inverse_update_ops(self):
- """Create and return update ops corresponding to registered computations."""
- # TODO(b/69918258): Add correctness tests for this method.
- # pylint: disable=invalid-name
-
- ops = []
-
- if (len(self._option1quants_by_damping) +
- len(self._option2quants_by_damping)):
-
- # Note that C0 and C1 are stand-ins for A0 and A1, or G0 and G1, from
- # the pseudo-code in the original paper. Because the computations for
- # the A and G case are essentially the same they can both be performed by
- # the same class (this one).
-
- C1 = self.get_cov_dt1()
-
- # Get the eigendecomposition of C0 (= self.get_cov())
- eigen_e, eigen_V = self.get_eigendecomp()
-
- # TODO(b/69678661): Note, there is an implicit assumption here that C1
- # and C0 (as represented here by its eigen-decomp) are consistent. This
- # could fail to be the case if self._cov and self._cov_dt1 are not updated
- # consistently, or are somehow read between or during the cov updates.
- # Can this possibly happen? Is there a way to prevent it?
-
- for damping_id, (Lmat_var,
- psi_var) in self._option1quants_by_damping.items():
-
- damping = self._damping_funcs_by_id[damping_id]()
- damping = math_ops.cast(damping, self._dtype)
-
- invsqrtC0 = math_ops.matmul(
- eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True)
-
- # Might need to enforce symmetry lost due to numerical issues.
- invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0
-
- # The following line imposes the symmetry assumed by "Option 1" on C1.
- # Strangely the code can work okay with this line commented out,
- # depending on how psd_eig is defined. I'm not sure why.
- C1 = (C1 + array_ops.transpose(C1)) / 2.0
-
- # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi})
- hPsi = math_ops.matmul(math_ops.matmul(invsqrtC0, C1), invsqrtC0)
-
- # Compute the decomposition U*diag(psi)*U^T = hPsi
- psi, U = utils.posdef_eig(hPsi)
-
- # L = C0^(-1/2) * U
- Lmat = math_ops.matmul(invsqrtC0, U)
-
- ops.append(Lmat_var.assign(Lmat))
- ops.append(psi_var.assign(psi))
-
- for damping_id, (Pmat_var, Kmat_var,
- mu_var) in self._option2quants_by_damping.items():
-
- damping = self._damping_funcs_by_id[damping_id]()
- damping = math_ops.cast(damping, self._dtype)
-
- # compute C0^(-1/2)
- invsqrtC0 = math_ops.matmul(
- eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True)
-
- # Might need to enforce symmetry lost due to numerical issues.
- invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0
-
- # Compute the product C0^(-1/2) * C1
- invsqrtC0C1 = math_ops.matmul(invsqrtC0, C1)
-
- # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi})
- hPsi = math_ops.matmul(invsqrtC0C1, invsqrtC0)
-
- # Compute the decomposition E*diag(mu)*E^T = hPsi^T * hPsi
-        # Note that we use the notation mu instead of "m" for the eigenvalues.
- # Instead of computing the product hPsi^T * hPsi and then doing an
- # eigen-decomposition of this we just compute the SVD of hPsi and then
- # square the singular values to get the eigenvalues. For a justification
- # of this approach, see:
- # https://en.wikipedia.org/wiki/Singular-value_decomposition#Relation_to_eigenvalue_decomposition
- sqrtmu, _, E = linalg_ops.svd(hPsi)
- mu = math_ops.square(sqrtmu)
-
-        # Mathematically, the eigenvalues should not exceed 1.0, but due to
-        # numerical issues, or possible issues with inconsistent values of C1
-        # and (the eigen-decomposition of) C0, they might. So we enforce this
-        # condition.
- mu = math_ops.minimum(mu, 1.0)
-
- # P = (C0^(-1/2) * C1)^T * C0^(-1/2) = C_1^T * C_0^(-1)
- Pmat = math_ops.matmul(invsqrtC0C1, invsqrtC0, transpose_a=True)
-
- # K = C_0^(-1/2) * E
- Kmat = math_ops.matmul(invsqrtC0, E)
-
- ops.append(Pmat_var.assign(Pmat))
- ops.append(Kmat_var.assign(Kmat))
- ops.append(mu_var.assign(mu))
-
- ops += super(FullyConnectedMultiKF, self).make_inverse_update_ops()
- return [control_flow_ops.group(*ops)]
-
- # pylint: enable=invalid-name
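For reference, a NumPy sketch of the "Option 1" quantities with hypothetical C0 and C1 (np.linalg.eigh standing in for utils.posdef_eig): Lmat simultaneously diagonalizes the damped C0 and C1, which is the property the inverse update relies on.

import numpy as np

rng = np.random.default_rng(2)
n, damping = 4, 0.1
A = rng.normal(size=(n, n))
C0 = A @ A.T + n * np.eye(n)              # hypothetical cov (C0)
B = rng.normal(size=(n, n))
C1 = (B + B.T) / 2.0                      # hypothetical cov_dt1 (C1), symmetric

e, V = np.linalg.eigh(C0)
invsqrtC0 = (V * (e + damping) ** -0.5) @ V.T

hPsi = invsqrtC0 @ C1 @ invsqrtC0         # hat{Psi}
psi, U = np.linalg.eigh(hPsi)             # stand-in for utils.posdef_eig
Lmat = invsqrtC0 @ U

# L diagonalizes the damped C0 to the identity and C1 to diag(psi).
assert np.allclose(Lmat.T @ (C0 + damping * np.eye(n)) @ Lmat, np.eye(n))
assert np.allclose(Lmat.T @ C1 @ Lmat, np.diag(psi))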
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
deleted file mode 100644
index 2d8e378a93..0000000000
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""FisherFactor definitions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.fisher_factors import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- "inverse_initializer", "covariance_initializer",
- "diagonal_covariance_initializer", "scope_string_from_params",
- "scope_string_from_name", "scalar_or_tensor_to_string", "FisherFactor",
- "InverseProvidingFactor", "FullFactor", "DiagonalFactor",
- "NaiveDiagonalFactor", "EmbeddingInputKroneckerFactor",
- "FullyConnectedDiagonalFactor", "FullyConnectedKroneckerFactor",
- "ConvInputKroneckerFactor", "ConvOutputKroneckerFactor",
- "ConvDiagonalFactor", "set_global_constants", "maybe_colocate_with",
- "compute_cov", "append_homog"
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
deleted file mode 100644
index 43aa713edc..0000000000
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ /dev/null
@@ -1,1269 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Registry for layers and their parameters/variables.
-
-This represents the collection of all layers in the approximate Fisher
-information matrix to which a particular FisherBlock may belong. That is, we
-might have several layer collections for one TF graph (if we have multiple K-FAC
-optimizers being used, for example).
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from collections import defaultdict
-from collections import OrderedDict
-from contextlib import contextmanager
-from functools import partial
-import math
-import warnings
-
-import six
-
-from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
-from tensorflow.contrib.kfac.python.ops import loss_functions as lf
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.util import nest
-
-# Names for various approximations that can be requested for Fisher blocks.
-APPROX_KRONECKER_NAME = "kron"
-APPROX_DIAGONAL_NAME = "diagonal"
-APPROX_FULL_NAME = "full"
-
-_GENERIC_APPROX_TO_BLOCK_TYPES = {
- APPROX_FULL_NAME: fb.FullFB,
- APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB,
-}
-
-_FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES = {
- APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB,
- APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB,
-}
-
-_CONV2D_APPROX_TO_BLOCK_TYPES = {
- APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB,
- APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB,
-}
-
-_EMBEDDING_APPROX_TO_BLOCK_TYPES = {
- APPROX_KRONECKER_NAME: fb.EmbeddingKFACFB
-}
-
-APPROX_KRONECKER_INDEP_NAME = "kron_indep"
-APPROX_KRONECKER_SERIES_1_NAME = "kron_series_1"
-APPROX_KRONECKER_SERIES_2_NAME = "kron_series_2"
-
-_FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = {
- APPROX_KRONECKER_INDEP_NAME: fb.FullyConnectedMultiIndepFB,
- APPROX_KRONECKER_SERIES_1_NAME: partial(fb.FullyConnectedSeriesFB,
- option=1),
- APPROX_KRONECKER_SERIES_2_NAME: partial(fb.FullyConnectedSeriesFB,
- option=2)
-}
-
-_CONV2D_MULTI_APPROX_TO_BLOCK_TYPES = {
- APPROX_KRONECKER_INDEP_NAME: fb.ConvKFCBasicMultiIndepFB
-}
-
-_EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES = {
- APPROX_KRONECKER_INDEP_NAME: fb.EmbeddingKFACMultiIndepFB
-}
-
-# Possible value for `reuse` keyword argument. Sets `reuse` to
-# tf.get_variable_scope().reuse.
-VARIABLE_SCOPE = "VARIABLE_SCOPE"
-
-_DEFAULT_LAYER_COLLECTION = None
-
-
-def get_default_layer_collection():
- """Get default LayerCollection."""
- if _DEFAULT_LAYER_COLLECTION is None:
- raise ValueError(
- "Attempted to retrieve default LayerCollection when none is set. Use "
- "LayerCollection.as_default().")
-
- return _DEFAULT_LAYER_COLLECTION
-
-
-def set_default_layer_collection(layer_collection):
- global _DEFAULT_LAYER_COLLECTION
-
- if _DEFAULT_LAYER_COLLECTION is not None and layer_collection is not None:
- raise ValueError("Default LayerCollection is already set.")
-
- _DEFAULT_LAYER_COLLECTION = layer_collection
-
-
-class LayerParametersDict(OrderedDict):
- """An OrderedDict where keys are Tensors or tuples of Tensors.
-
- Ensures that no Tensor is associated with two different keys.
- """
-
- def __init__(self, *args, **kwargs):
- self._tensors = set()
- super(LayerParametersDict, self).__init__(*args, **kwargs)
-
- def __setitem__(self, key, value):
- key = self._canonicalize_key(key)
- tensors = key if isinstance(key, (tuple, list)) else (key,)
- key_collisions = self._tensors.intersection(tensors)
- if key_collisions:
- raise ValueError("Key(s) already present: {}".format(key_collisions))
- self._tensors.update(tensors)
- super(LayerParametersDict, self).__setitem__(key, value)
-
- def __delitem__(self, key):
- key = self._canonicalize_key(key)
- self._tensors.remove(key)
- super(LayerParametersDict, self).__delitem__(key)
-
- def __getitem__(self, key):
- key = self._canonicalize_key(key)
- return super(LayerParametersDict, self).__getitem__(key)
-
- def __contains__(self, key):
- key = self._canonicalize_key(key)
- return super(LayerParametersDict, self).__contains__(key)
-
- def _canonicalize_key(self, key):
- if isinstance(key, (list, tuple)):
- return tuple(key)
- return key
-
-
-# TODO(b/68034464): add capability for LayerCollection to be "finalized"
-# and do this when it gets used by FisherEstimator / KfacOptimizer.
-
-
-class LayerCollection(object):
- """Registry of information about layers and losses.
-
- Note that you need to create a new one of these for each MatrixEstimator or
- KfacOptimizer.
-
- Attributes:
- fisher_blocks: a LayersParamsDict (subclass of OrderedDict) mapping layer
- parameters (Tensors or tuples of Tensors) to FisherBlock instances.
- fisher_factors: an OrderedDict mapping tuples to FisherFactor instances.
- losses: a list of LossFunction objects. The loss to be optimized is their
- sum.
- loss_colocation_ops: ops to colocate loss function evaluations with. These
- will typically be the inputs to the losses.
- """
-
- def __init__(self,
- graph=None,
- name="LayerCollection"):
- warnings.warn(
- "tf.contrib.kfac is deprecated and will be removed by 2018-11-01. "
- "Use https://pypi.python.org/pypi/kfac instead.")
- self.fisher_blocks = LayerParametersDict()
- self.fisher_factors = OrderedDict()
-    # Dict mapping sets of variables to optionally specified approximations.
-    self._linked_parameters = dict()
- self._graph = graph or ops.get_default_graph()
- self._loss_dict = {} # {str: LossFunction}
- self._subgraph = None
- self._default_generic_approximation = APPROX_DIAGONAL_NAME
- self._default_embedding_approximation = APPROX_KRONECKER_NAME
- self._default_fully_connected_approximation = APPROX_KRONECKER_NAME
- self._default_conv2d_approximation = APPROX_KRONECKER_NAME
- self._default_fully_connected_multi_approximation = (
- APPROX_KRONECKER_INDEP_NAME)
- self._default_conv2d_multi_approximation = (
- APPROX_KRONECKER_INDEP_NAME)
- self._default_embedding_multi_approximation = APPROX_KRONECKER_INDEP_NAME
- self.loss_colocation_ops = {}
- self._vars_to_uses = defaultdict(lambda: 0)
-
- with variable_scope.variable_scope(None, default_name=name) as scope:
- self._var_scope = scope.name
-
- @property
- def losses(self):
- """Tuple of LossFunction objects registered with this LayerCollection."""
- return nest.flatten(self.towers_by_loss)
-
- @property
- def towers_by_loss(self):
- """Tuple across losses of LossFunction objects registered to each tower."""
- return tuple(tuple(lst) for lst in self._loss_dict.values())
-
- @property
- def registered_variables(self):
- """A tuple of all of the variables currently registered."""
- tuple_of_tuples = (utils.ensure_sequence(key) for key, block
- in six.iteritems(self.fisher_blocks))
- flat_tuple = tuple(item for tuple_ in tuple_of_tuples for item in tuple_)
- return flat_tuple
-
- @property
- def linked_parameters(self):
- """Groups of parameters with an optionally specified approximation.
-
- Linked parameters can be added using `define_linked_parameters`.
- If an approximation is specified, then this approximation will be used
- when registering a layer with exactly these parameters, unless an
- approximation is specified when calling the registration function.
-
- Returns:
- A `dict` mapping tuples of parameters to an optional string.
- """
- return self._linked_parameters
-
- @property
- def default_embedding_approximation(self):
- return self._default_embedding_approximation
-
- def set_default_embedding_approximation(self, value):
- if value != APPROX_KRONECKER_NAME:
- raise ValueError(
- "{} is not a valid approximation for embedding variables.".format(
- value))
- self._default_embedding_approximation = value
-
- @property
- def default_generic_approximation(self):
- return self._default_generic_approximation
-
- def set_default_generic_approximation(self, value):
- if value not in _GENERIC_APPROX_TO_BLOCK_TYPES:
- raise ValueError(
- "{} is not a valid approximation for generic variables.".format(
- value))
- self._default_generic_approximation = value
-
- @property
- def default_fully_connected_approximation(self):
- return self._default_fully_connected_approximation
-
- def set_default_fully_connected_approximation(self, value):
- if value not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES:
- raise ValueError(
- "{} is not a valid approximation for fully connected layers.".format(
- value))
- self._default_fully_connected_approximation = value
-
- @property
- def default_conv2d_approximation(self):
- return self._default_conv2d_approximation
-
- def set_default_conv2d_approximation(self, value):
- if value not in _CONV2D_APPROX_TO_BLOCK_TYPES:
- raise ValueError(
- "{} is not a valid approximation for 2d convolutional layers.".format(
- value))
- self._default_conv2d_approximation = value
-
- @property
- def default_fully_connected_multi_approximation(self):
- return self._default_fully_connected_multi_approximation
-
- def set_default_fully_connected_multi_approximation(self, value):
- if value not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES:
- raise ValueError("{} is not a valid approximation for a fully-connected "
- "multi layer.".format(value))
- self._default_fully_connected_multi_approximation = value
-
- @property
- def default_conv2d_multi_approximation(self):
- return self._default_conv2d_multi_approximation
-
- @property
- def default_embedding_multi_approximation(self):
- return self._default_embedding_multi_approximation
-
- def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE):
- """Validates and registers the layer_key associated with the fisher_block.
-
- Args:
- layer_key: A variable or tuple of variables. The key to check for in
- existing registrations and to register if valid.
- fisher_block: The associated `FisherBlock`.
- reuse: Method to use for inserting new `FisherBlock's. One of True, False,
- or `VARIABLE_SCOPE`.
-
- Raises:
- ValueError: If `layer_key` was already registered and reuse is `False`,
- if `layer_key` was registered with a different block type, or if
- `layer_key` shares any variables with but is not equal to a previously
- registered key.
- KeyError: If `reuse` is `True` but `layer_key` was not previously
- registered.
-
- Returns:
- The `FisherBlock` registered under `layer_key`. If `layer_key` was already
- registered, this will be the previously registered `FisherBlock`.
- """
- if reuse is VARIABLE_SCOPE:
- reuse = variable_scope.get_variable_scope().reuse
-
- if reuse is True or (reuse is variable_scope.AUTO_REUSE and
- layer_key in self.fisher_blocks):
- result = self.fisher_blocks[layer_key]
- if type(result) != type(fisher_block): # pylint: disable=unidiomatic-typecheck
- raise ValueError(
- "Attempted to register FisherBlock of type %s when existing "
- "FisherBlock has type %s." % (type(fisher_block), type(result)))
- return result
- if reuse is False and layer_key in self.fisher_blocks:
- raise ValueError("FisherBlock for %s is already in LayerCollection." %
- (layer_key,))
-
- # Insert fisher_block into self.fisher_blocks.
- if layer_key in self.fisher_blocks:
- raise ValueError("Duplicate registration: {}".format(layer_key))
- # Raise an error if any variable in layer_key has been registered in any
- # other blocks.
- variable_to_block = {
- var: (params, block)
- for (params, block) in self.fisher_blocks.items()
- for var in utils.ensure_sequence(params)
- }
- for variable in utils.ensure_sequence(layer_key):
- if variable in variable_to_block:
- prev_key, prev_block = variable_to_block[variable]
- raise ValueError(
- "Attempted to register layer_key {} with block {}, but variable {}"
- " was already registered in key {} with block {}.".format(
- layer_key, fisher_block, variable, prev_key, prev_block))
- self.fisher_blocks[layer_key] = fisher_block
- return fisher_block
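A hedged usage sketch of the reuse semantics, with a hypothetical variable and block; the FisherBlock constructor signature is assumed from the registration helpers below:

import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb

lc = LayerCollection()
w = tf.get_variable("w", shape=[3, 2])

block = lc.register_block(w, fb.FullyConnectedDiagonalFB(lc, has_bias=False),
                          reuse=False)
# Re-registering with reuse=True returns the existing block (same type
# required) instead of inserting the new instance.
same = lc.register_block(w, fb.FullyConnectedDiagonalFB(lc, has_bias=False),
                         reuse=True)
assert same is block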
-
- def register_loss_function(self,
- loss,
- colocation_op,
- base_name,
- name=None,
- reuse=VARIABLE_SCOPE):
- """Registers a LossFunction object.
-
- Args:
- loss: The LossFunction object.
- colocation_op: The op to colocate the loss function's computations with.
-      base_name: The name to derive a new unique name from if the name
-          argument is None.
- name: (OPTIONAL) str or None. Unique name for this loss function. If None,
- a new name is generated. (Default: None)
- reuse: (OPTIONAL) bool or str. If True, adds `loss` as an additional
- tower for the existing loss function.
-
- Raises:
-      ValueError: If reuse == True and name == None.
- KeyError: If reuse == True and no existing LossFunction with `name` found.
- KeyError: If reuse == False and existing LossFunction with `name` found.
- """
-
- name = name or self._graph.unique_name(base_name)
-
- if reuse == VARIABLE_SCOPE:
- reuse = variable_scope.get_variable_scope().reuse
-
- if reuse:
- if name is None:
- raise ValueError(
- "If reuse is enabled, loss function's name must be set.")
-
- loss_list = self._loss_dict.get(name, None)
-
- if loss_list is None:
- raise KeyError(
- "Unable to find loss function named {}. Register a new loss "
- "function with reuse=False.".format(name))
- else:
- if name in self._loss_dict:
- raise KeyError(
- "Loss function named {} already exists. Set reuse=True to append "
- "another tower.".format(name))
-
- loss_list = []
- self._loss_dict[name] = loss_list
-
- loss_list.append(loss)
- self.loss_colocation_ops[loss] = colocation_op
-
- def _get_use_count_map(self):
- """Returns a dict mapping variables to their number of registrations."""
- return self._vars_to_uses
-
- def _add_uses(self, params, uses):
- """Register additional uses by params in the graph.
-
- Args:
- params: Variable or tuple of Variables. Parameters for a layer.
- uses: int or float. Number of additional uses for these parameters.
- """
- params = params if isinstance(params, (tuple, list)) else (params,)
- for var in params:
- self._vars_to_uses[var] += uses
-
- def check_registration(self, variables):
- """Checks that all variable uses have been registered properly.
-
- Args:
- variables: List of variables.
-
- Raises:
- ValueError: If any registered variables are not included in the list.
- ValueError: If any variable in the list is not registered.
- ValueError: If any variable in the list is registered with the wrong
- number of "uses" in the subgraph recorded (vs the number of times that
- variable is actually used in the subgraph).
- """
- # Note that overlapping parameters (i.e. those that share variables) will
- # be caught by layer_collection.LayerParametersDict during registration.
-
- reg_use_map = self._get_use_count_map()
-
- error_messages = []
-
- for var in variables:
- total_uses = self.subgraph.variable_uses(var)
- reg_uses = reg_use_map[var]
-
- if reg_uses == 0:
- error_messages.append("Variable {} not registered.".format(var))
- elif (not math.isinf(reg_uses)) and reg_uses != total_uses:
- error_messages.append(
- "Variable {} registered with wrong number of uses ({} "
- "registrations vs {} uses).".format(var, reg_uses, total_uses))
-
- num_get_vars = len(reg_use_map)
-
- if num_get_vars > len(variables):
- error_messages.append("{} registered variables were not included in list."
- .format(num_get_vars - len(variables)))
-
- if error_messages:
- error_messages = [
- "Found the following errors with variable registration:"
- ] + error_messages
- raise ValueError("\n\t".join(error_messages))
-
- def get_blocks(self):
- return self.fisher_blocks.values()
-
- def get_factors(self):
- return self.fisher_factors.values()
-
- @property
- def graph(self):
- return self._graph
-
- @property
- def subgraph(self):
- return self._subgraph
-
- def define_linked_parameters(self, params, approximation=None):
- """Identify a set of parameters that should be grouped together.
-
- During automatic graph scanning, any matches containing variables that have
- been identified as part of a linked group will be filtered out unless
- the match parameters are exactly equal to the ones specified in the linked
- group.
-
- Args:
- params: A variable, or a tuple or list of variables. The variables
- to be linked.
- approximation: Optional string specifying the type of approximation to use
- for these variables. If unspecified, this layer collection's default
- approximation for the layer type will be used.
-
- Raises:
- ValueError: If the parameters were already registered in a layer or
- identified as part of an incompatible group.
- """
- params = frozenset(utils.ensure_sequence(params))
-
- # Check if any of the variables in `params` is already in
- # 'self.fisher_blocks.keys()`.
- for registered_params, fisher_block in self.fisher_blocks.items():
- registered_params_set = set(utils.ensure_sequence(registered_params))
- for variable in params:
- if (variable in registered_params_set and
- params != registered_params_set):
- raise ValueError(
- "Can`t link parameters {}, variable {} was already registered in "
- "group {} with layer {}".format(params, variable,
- registered_params, fisher_block))
-
- # Check if any of the variables in `params` is already in
- # 'self.linked_parameters`.
- for variable in params:
- for other_linked_params in self.linked_parameters:
- if variable in other_linked_params:
- raise ValueError("Can`t link parameters {}, variable {} was already "
- "linked in group {}.".format(params, variable,
- other_linked_params))
- self._linked_parameters[params] = approximation
-
- def create_subgraph(self):
- if not self.losses:
- raise ValueError("Must have at least one registered loss.")
- inputs_to_losses = nest.flatten(tuple(loss.inputs for loss in self.losses))
- self._subgraph = utils.SubGraph(inputs_to_losses)
-
- def eval_losses(self):
- """Return evaluated losses (colocated with inputs to losses)."""
- evals = []
- for loss in self.losses:
- with ops.colocate_with(self.loss_colocation_ops[loss]):
- evals.append(loss.evaluate())
- return evals
-
- def eval_losses_on_samples(self):
- """Return losses evaluated on samples (colocated with inputs to losses)."""
- evals = []
- for loss in self.losses:
- with ops.colocate_with(self.loss_colocation_ops[loss]):
- evals.append(loss.evaluate_on_sample())
- return evals
-
- def total_loss(self):
- return math_ops.add_n(self.eval_losses())
-
- def total_sampled_loss(self):
- return math_ops.add_n(self.eval_losses_on_samples())
-
- def _get_linked_approx(self, params):
- """If params were linked, return their specified approximation."""
- params_set = frozenset(utils.ensure_sequence(params))
- if params_set in self.linked_parameters:
- return self.linked_parameters[params_set]
- else:
- return None
-
- def _get_block_type(self, params, approx, default, approx_to_type):
- if approx is None:
- approx = self._get_linked_approx(params)
- if approx is None:
- approx = default
-
- if approx not in approx_to_type:
- raise ValueError("Bad value {} for approx.".format(approx))
-
- return approx_to_type[approx], approx
-
- def register_embedding(self,
- params,
- inputs,
- outputs,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Registers an embedding layer.
-
- Args:
- params: Embedding matrix of shape [vocab_size, embedding_size].
- inputs: Tensor of shape [batch_size, input_size] and dtype int32. Indices
- into embedding matrix.
- outputs: Tensor of shape [batch_size, embedding_size]. Outputs
- produced by layer.
- approx: str or None. If not None must be "kron". The Fisher
- approximation to use. If None the default value is used. (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- block_type, approx = self._get_block_type(
- params, approx, self.default_embedding_approximation,
- _EMBEDDING_APPROX_TO_BLOCK_TYPES)
-
- if isinstance(params, (tuple, list)):
- raise ValueError("Bias not supported.")
- vocab_size = int(params.shape[0])
- block = self.register_block(
- params, block_type(self, vocab_size), reuse=reuse)
- block.register_additional_tower(inputs, outputs)
-
- self._add_uses(params, 1)
-
- def register_fully_connected(self,
- params,
- inputs,
- outputs,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Registers a fully connected layer.
-
- Args:
- params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
- this layer. Weight matrix should have shape [input_size, output_size].
- Bias should have shape [output_size].
- inputs: Tensor of shape [batch_size, input_size]. Inputs to layer.
- outputs: Tensor of shape [batch_size, output_size]. Outputs
- produced by layer.
- approx: str or None. If not None must be one of "kron" or "diagonal".
- The Fisher approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
-
- block_type, approx = self._get_block_type(
- params, approx, self.default_fully_connected_approximation,
- _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES)
-
- has_bias = isinstance(params, (tuple, list))
- block = self.register_block(params, block_type(self, has_bias=has_bias),
- reuse=reuse)
- block.register_additional_tower(inputs, outputs)
-
- self._add_uses(params, 1)
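A hedged end-to-end usage sketch for this method, on a hypothetical TF1-style graph:

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 784])   # hypothetical inputs
w = tf.get_variable("w", shape=[784, 10])
b = tf.get_variable("b", shape=[10])
logits = tf.matmul(x, w) + b

lc = LayerCollection()
# Registers a Kronecker-factored block for the (weight, bias) pair.
lc.register_fully_connected((w, b), x, logits, approx=APPROX_KRONECKER_NAME)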
-
- def register_conv2d(self,
- params,
- strides,
- padding,
- inputs,
- outputs,
- data_format=None,
- dilations=None,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Registers a call to tf.nn.conv2d().
-
- Args:
- params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
- this layer. Weight matrix should have shape [kernel_height,
- kernel_width, in_channels, out_channels]. Bias should have shape
- [out_channels].
- strides: List of 4 ints. Strides for convolution kernel.
- padding: string. see tf.nn.conv2d for valid values.
- inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs
- to layer.
- outputs: Tensor of shape [batch_size, height, width, out_channels].
- Output produced by layer.
- data_format: str or None. Format of data.
- dilations: List of 4 ints. Dilations along each dimension.
- approx: str or None. If not None must be one of "kron" or "diagonal".
- The Fisher approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
-
- block_type, approx = self._get_block_type(
- params, approx, self.default_conv2d_approximation,
- _CONV2D_APPROX_TO_BLOCK_TYPES)
-
- # It feels bad to pass in configuration that has to do with the internal
-    # implementation. And then we can't use the same constructor for both
- # anymore and are thus forced to use this ugly if-statement.
- # TODO(b/74793309): Clean this up?
- if approx == APPROX_KRONECKER_NAME:
- block = self.register_block(
- params,
- block_type(
- layer_collection=self,
- params=params,
- padding=padding,
- strides=strides,
- data_format=data_format,
- dilation_rate=dilations,
- extract_patches_fn="extract_image_patches"),
- reuse=reuse)
- elif approx == APPROX_DIAGONAL_NAME:
- assert strides[0] == strides[-1] == 1
- block = self.register_block(
- params,
- block_type(
- layer_collection=self,
- params=params,
- padding=padding,
- strides=strides,
- dilations=dilations,
- data_format=data_format),
- reuse=reuse)
- else:
- raise NotImplementedError(approx)
-
- block.register_additional_tower(inputs, outputs)
-
- self._add_uses(params, 1)
-
- def register_convolution(self,
- params,
- inputs,
- outputs,
- padding,
- strides=None,
- dilation_rate=None,
- data_format=None,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Register a call to tf.nn.convolution().
-
- Args:
- params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
- this layer. Weight matrix should have shape [..filter_spatial_size..,
- in_channels, out_channels]. Bias should have shape [out_channels].
- inputs: Tensor of shape [batch_size, ..input_spatial_size.., in_channels].
- Inputs to layer.
- outputs: Tensor of shape [batch_size, ..output_spatial_size..,
- out_channels]. Output produced by layer.
-      padding: string. See tf.nn.conv2d for valid values.
- strides: List of ints of length len(..input_spatial_size..). Strides for
- convolution kernel in spatial dimensions.
- dilation_rate: List of ints of length len(..input_spatial_size..).
- Dilations along spatial dimension.
- data_format: str or None. Format of data.
-      approx: str or None. If not None, must be one of "kron" or "diagonal".
- The Fisher approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- # TODO(b/74793309): Have this use _get_block_type like the other
- # registration functions?
- assert approx is None or approx == APPROX_KRONECKER_NAME
-
- block = self.register_block(
- params,
- fb.ConvKFCBasicFB(
- layer_collection=self,
- params=params,
- padding=padding,
- strides=strides,
- dilation_rate=dilation_rate,
- data_format=data_format),
- reuse=reuse)
- block.register_additional_tower(inputs, outputs)
-
- self._add_uses(params, 1)
-
- def register_depthwise_conv2d(self,
- params,
- inputs,
- outputs,
- strides,
- padding,
- rate=None,
- data_format=None,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Register a call to tf.nn.depthwise_conv2d().
-
- Args:
- params: 4-D Tensor of shape [filter_height, filter_width,
- in_channels, channel_multiplier]. Convolutional filter.
- inputs: Tensor of shape [batch_size, input_height, input_width,
- in_channels]. Inputs to layer.
- outputs: Tensor of shape [batch_size, output_height, output_width,
- in_channels * channel_multiplier]. Output produced by depthwise conv2d.
- strides: List of ints of length 4. Strides along all dimensions.
-      padding: string. See tf.nn.conv2d for valid values.
- rate: None or List of ints of length 2. Dilation rates in spatial
- dimensions.
- data_format: str or None. Format of data.
-      approx: str or None. If not None, must be "diagonal". The Fisher
-        approximation to use. If None the default value is used. (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- # TODO(b/74793309): Have this use _get_block_type like the other
- # registration functions?
- assert approx is None or approx == APPROX_DIAGONAL_NAME
- assert data_format in [None, "NHWC"]
-
- block = self.register_block(
- params,
- fb.DepthwiseConvDiagonalFB(
- layer_collection=self,
- params=params,
- strides=strides,
- padding=padding,
- rate=rate,
- data_format=data_format),
- reuse=reuse)
- block.register_additional_tower(inputs, outputs)
-
- self._add_uses(params, 1)
-
- def register_separable_conv2d(self,
- depthwise_params,
- pointwise_params,
- inputs,
- depthwise_outputs,
- pointwise_outputs,
- strides,
- padding,
- rate=None,
- data_format=None,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Register a call to tf.nn.separable_conv2d().
-
- Note: This requires access to intermediate outputs between depthwise and
- pointwise convolutions.
-
- Args:
- depthwise_params: 4-D Tensor of shape [filter_height, filter_width,
- in_channels, channel_multiplier]. Filter for depthwise conv2d.
- pointwise_params: 4-D Tensor of shape [1, 1, in_channels *
- channel_multiplier, out_channels]. Filter for pointwise conv2d.
- inputs: Tensor of shape [batch_size, input_height, input_width,
- in_channels]. Inputs to layer.
- depthwise_outputs: Tensor of shape [batch_size, output_height,
- output_width, in_channels * channel_multiplier]. Output produced by
- depthwise conv2d.
- pointwise_outputs: Tensor of shape [batch_size, output_height,
- output_width, out_channels]. Output produced by pointwise conv2d.
- strides: List of ints of length 4. Strides for depthwise conv2d kernel in
- all dimensions.
-      padding: string. See tf.nn.conv2d for valid values.
- rate: None or List of ints of length 2. Dilation rate of depthwise conv2d
- kernel in spatial dimensions.
- data_format: str or None. Format of data.
-      approx: str or None. If not None, must be one of "kron" or "diagonal".
- The Fisher approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- self.register_depthwise_conv2d(
- params=depthwise_params,
- inputs=inputs,
- outputs=depthwise_outputs,
- strides=strides,
- padding=padding,
- rate=rate,
- data_format=data_format,
- approx=APPROX_DIAGONAL_NAME,
- reuse=reuse)
-
- self.register_conv2d(
- params=pointwise_params,
- inputs=depthwise_outputs,
- outputs=pointwise_outputs,
- strides=[1, 1, 1, 1],
- padding="VALID",
- data_format=data_format,
- approx=approx,
- reuse=reuse)
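Since register_separable_conv2d simply composes the two registrations above,
here is a sketch of the forward pass it expects (hypothetical shapes, with
channel_multiplier=2 and a LayerCollection instance `lc`):

    import tensorflow as tf

    inputs = tf.placeholder(tf.float32, [None, 32, 32, 3])
    depthwise_w = tf.get_variable("dw", [3, 3, 3, 2])
    pointwise_w = tf.get_variable("pw", [1, 1, 6, 16])

    depthwise_out = tf.nn.depthwise_conv2d(
        inputs, depthwise_w, strides=[1, 1, 1, 1], padding="SAME")
    pointwise_out = tf.nn.conv2d(
        depthwise_out, pointwise_w, strides=[1, 1, 1, 1], padding="VALID")

    lc.register_separable_conv2d(
        depthwise_params=depthwise_w,
        pointwise_params=pointwise_w,
        inputs=inputs,
        depthwise_outputs=depthwise_out,
        pointwise_outputs=pointwise_out,
        strides=[1, 1, 1, 1],
        padding="SAME")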
-
- def register_generic(self,
- params,
- batch_size,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Registers a generic layer.
-
- Args:
- params: Tensor or tuple of Tensors corresponding to the parameters.
- batch_size: 0-D Tensor. Size of the minibatch (for this tower).
-      approx: str or None. If not None, must be one of "full" or "diagonal".
- The Fisher approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `batch_size` to the total
- mini-batch size use when estimating the Fisher block for this layer
- (which must have already been registered). If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- block_type, approx = self._get_block_type(
- params, approx, self.default_generic_approximation,
- _GENERIC_APPROX_TO_BLOCK_TYPES)
-
- block = self.register_block(params, block_type(self, params), reuse=reuse)
- block.register_additional_tower(batch_size)
-
- self._add_uses(params, float("inf"))
-
- def register_fully_connected_multi(self, params, inputs, outputs,
- num_uses=None, approx=None,
- reuse=VARIABLE_SCOPE):
- """Register fully connected layers with shared parameters.
-
- This can handle general fully-connected layers with shared parameters, but
- has specialized approximations to deal with the case where there is a
- meaningful linear order to the share instances (such as in an RNN).
-
- Args:
- params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
- this layer. Weight matrix should have shape [input_size, output_size].
- Bias should have shape [output_size].
- inputs: A list of Tensors, each of shape [batch_size, input_size]. Inputs
- to layer. The list indexes each use in the graph (which might
- correspond to a "time-step" in an RNN). OR, can be single Tensor, of
- shape [num_uses * batch_size , input_size], which is a reshaped version
- of a Tensor of shape [num_uses, batch_size, input_size].
- outputs: A list of Tensors, the same length as `inputs`, each of shape
- [batch_size, output_size]. Outputs produced by layer. The list indexes
- each use in the graph (which might correspond to a "time-step" in an
- RNN). Needs to correspond with the order used in `inputs`. OR, can be
- a single Tensor of shape [num_uses * batch_size, output_size], which is
- a reshaped version of a Tensor of shape [num_uses, batch_size,
- output_size].
-      num_uses: int or None. The number of uses/time-steps in the graph where
-        the layer appears. Only needed if both inputs and outputs are given in
-        the single Tensor format. (Default: None)
-      approx: str or None. If not None, must be one of "kron_indep",
-        "kron_series_1", or "kron_series_2". The Fisher approximation to use.
-        If None the default value is used. (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
- word `use` here has a completely different meaning to "use in the graph"
- as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.)
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- """
- block_type, approx = self._get_block_type(
- params, approx, self.default_fully_connected_multi_approximation,
- _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES)
-
- # TODO(b/70283649): something along the lines of find_canonical_output
- # should be added back in here (and for the other block types, arguably).
-
- has_bias = isinstance(params, (tuple, list))
- block = self.register_block(params, block_type(self, has_bias=has_bias,
- num_uses=num_uses),
- reuse=reuse)
- block.register_additional_tower(inputs, outputs)
- if isinstance(inputs, (tuple, list)):
- assert len(inputs) == len(outputs)
- self._add_uses(params, len(inputs))
- else:
- self._add_uses(params, 1)
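The list and single-Tensor input formats accepted above are equivalent up to a
reshape; a sketch (hypothetical shapes and names) of producing the
single-Tensor format from per-use Tensors:

    import tensorflow as tf

    num_uses, batch_size, input_size = 5, 32, 64
    inputs_list = [tf.placeholder(tf.float32, [batch_size, input_size])
                   for _ in range(num_uses)]
    # Stack along a leading "use" dimension, then fold it into the batch.
    inputs_single = tf.reshape(
        tf.stack(inputs_list, axis=0), [num_uses * batch_size, input_size])
    # With the single-Tensor format, num_uses must be passed explicitly, e.g.
    # lc.register_fully_connected_multi(params, inputs_single, outputs_single,
    #                                   num_uses=num_uses)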
-
- def register_conv2d_multi(self,
- params,
- strides,
- padding,
- inputs,
- outputs,
- num_uses=None,
- data_format=None,
- dilations=None,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Registers convolutional layers with shared parameters.
-
- Args:
- params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
- this layer. Weight matrix should have shape [kernel_height,
- kernel_width, in_channels, out_channels]. Bias should have shape
- [out_channels].
- strides: 1-D Tensor of length 4. Strides for convolution kernel.
-      padding: string. See tf.nn.conv2d for valid values.
- inputs: A list of Tensors, each of shape [batch_size, height, width,
- in_channels]. Inputs to layer. The list indexes each use in the graph
- (which might correspond to a "time-step" in an RNN). OR, can be single
- Tensor, of shape [num_uses * batch_size, height, width, in_channels],
- which is a reshaped version of a Tensor of shape [num_uses, batch_size,
- height, width, in_channels].
- outputs: A list of Tensors, each of shape [batch_size, height, width,
- out_channels]. Output produced by layer. The list indexes each use
- in the graph (which might correspond to a "time-step" in an RNN).
- Needs to correspond with the order used in `inputs`. OR, can be a
- single Tensor, of shape [num_uses * batch_size, height, width,
- out_channels], which is a reshaped version of a Tensor of shape
- [num_uses, batch_size, height, width, out_channels].
-      num_uses: int or None. The number of uses/time-steps in the graph where
-        the layer appears. Only needed if both inputs and outputs are given in
-        the single Tensor format. (Default: None)
- data_format: str or None. Format of data.
- dilations: List of 4 ints. Dilations along each dimension.
-      approx: str or None. If not None, must be "kron_indep". The Fisher
- approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
- word `use` here has a completely different meaning to "use in the graph"
- as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.)
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- block_type, approx = self._get_block_type(
- params, approx, self.default_conv2d_multi_approximation,
- _CONV2D_MULTI_APPROX_TO_BLOCK_TYPES)
-
- block = self.register_block(
- params,
- block_type(
- layer_collection=self,
- params=params,
- padding=padding,
- strides=strides,
- data_format=data_format,
- dilation_rate=dilations,
- extract_patches_fn="extract_image_patches",
- num_uses=num_uses),
- reuse=reuse)
-
- block.register_additional_tower(inputs, outputs)
- if isinstance(inputs, (tuple, list)):
- assert len(inputs) == len(outputs)
- self._add_uses(params, len(inputs))
- else:
- self._add_uses(params, 1)
-
-  # TODO(b/74108452): change the loss registration function names to refer
-  # to "loss functions" instead of distributions, following the naming
-  # convention of the loss function classes themselves.
-
- def register_embedding_multi(self,
- params,
- inputs,
- outputs,
- num_uses=None,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Registers embedding layers with shared parameters.
-
- Args:
- params: Embedding matrix of shape [vocab_size, embedding_size].
- inputs: A list of Tensors, each of shape [batch_size, input_size] and
- dtype int32. Indices into embedding matrix. The list indexes each use
- in the graph (which might correspond to a "time-step" in an RNN).
-        OR, can be a single Tensor of shape [num_uses * batch_size, input_size],
- which is a reshaped version of a Tensor of shape [num_uses, batch_size,
- input_size].
- outputs: A list of Tensors, each of shape [batch_size, embedding_size].
- Outputs produced by layer. The list indexes each use in the graph
- (which might correspond to a "time-step" in an RNN). Needs to
- correspond with the order used in `inputs`. OR, can be a
- single Tensor, of shape [num_uses * batch_size, embedding_size], which
- is a reshaped version of a Tensor of shape [num_uses, batch_size,
- embedding_size].
-      num_uses: int or None. The number of uses/time-steps in the graph where
-        the layer appears. Only needed if both inputs and outputs are given in
-        the single Tensor format. (Default: None)
-      approx: str or None. If not None, must be "kron_indep". The Fisher
- approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
- word `use` here has a completely different meaning to "use in the graph"
- as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.)
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- block_type, approx = self._get_block_type(
- params, approx, self.default_embedding_multi_approximation,
- _EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES)
-
- if isinstance(params, (tuple, list)):
- raise ValueError("Bias not supported.")
- vocab_size = int(params.shape[0])
-
- block = self.register_block(
- params, block_type(self, vocab_size, num_uses=num_uses), reuse=reuse)
- block.register_additional_tower(inputs, outputs)
-
- if isinstance(inputs, (tuple, list)):
- self._add_uses(params, len(inputs))
- else:
- self._add_uses(params, 1)
-
- def register_categorical_predictive_distribution(self,
- logits,
- seed=None,
- targets=None,
- name=None,
- reuse=VARIABLE_SCOPE):
- """Registers a categorical predictive distribution.
-
- Args:
- logits: The logits of the distribution (i.e. its parameters).
-      seed: The seed for the RNG (for debugging). (Default: None)
- targets: (OPTIONAL) The targets for the loss function. Only required if
- one wants to call total_loss() instead of total_sampled_loss().
- total_loss() is required, for example, to estimate the
- "empirical Fisher" (instead of the true Fisher).
- (Default: None)
- name: (OPTIONAL) str or None. Unique name for this loss function. If None,
- a new name is generated. (Default: None)
- reuse: bool or str. If True, this adds `logits` as an additional
- mini-batch/tower of inputs to the loss-function/predictive distribution
- (which must have already been registered). If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
- """
- loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets,
- seed=seed)
- self.register_loss_function(loss, logits,
- "categorical_predictive_distribution",
- name=name, reuse=reuse)
-
- def register_normal_predictive_distribution(self,
- mean,
- var=0.5,
- seed=None,
- targets=None,
- name=None,
- reuse=VARIABLE_SCOPE):
- """Registers a normal predictive distribution.
-
- Args:
- mean: The mean vector defining the distribution.
- var: The variance (must be a scalar). Note that the default value of
- 0.5 corresponds to a standard squared error loss (target -
- prediction)**2. If your squared error loss is of the form
- 0.5*(target - prediction)**2 you should use var=1.0. (Default: 0.5)
-      seed: The seed for the RNG (for debugging). (Default: None)
- targets: (OPTIONAL) The targets for the loss function. Only required if
- one wants to call total_loss() instead of total_sampled_loss().
- total_loss() is required, for example, to estimate the
- "empirical Fisher" (instead of the true Fisher).
- (Default: None)
- name: (OPTIONAL) str or None. Unique name for this loss function. If None,
- a new name is generated. (Default: None)
- reuse: bool or str. If True, this adds `mean` and `var` as an additional
- mini-batch/tower of inputs to the loss-function/predictive distribution
- (which must have already been registered). If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
- """
- loss = lf.NormalMeanNegativeLogProbLoss(mean, var, targets=targets,
- seed=seed)
- self.register_loss_function(loss, mean,
- "normal_predictive_distribution",
- name=name, reuse=reuse)
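The var=0.5 convention documented above can be checked directly from the
Normal negative log prob, -log p(t) = (t - m)**2 / (2*var) +
0.5*log(2*pi*var); a small numpy sketch:

    import numpy as np

    def neg_log_prob(t, m, var):
        return (t - m) ** 2 / (2.0 * var) + 0.5 * np.log(2.0 * np.pi * var)

    t, m = 1.3, 0.8
    # var=0.5: the data-dependent term is the plain squared error.
    assert np.isclose(neg_log_prob(t, m, 0.5) - 0.5 * np.log(np.pi),
                      (t - m) ** 2)
    # var=1.0: the data-dependent term is 0.5 * squared error.
    assert np.isclose(neg_log_prob(t, m, 1.0) - 0.5 * np.log(2.0 * np.pi),
                      0.5 * (t - m) ** 2)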
-
- def register_multi_bernoulli_predictive_distribution(self,
- logits,
- seed=None,
- targets=None,
- name=None,
- reuse=VARIABLE_SCOPE):
- """Registers a multi-Bernoulli predictive distribution.
-
- Args:
- logits: The logits of the distribution (i.e. its parameters).
-      seed: The seed for the RNG (for debugging). (Default: None)
- targets: (OPTIONAL) The targets for the loss function. Only required if
- one wants to call total_loss() instead of total_sampled_loss().
- total_loss() is required, for example, to estimate the
- "empirical Fisher" (instead of the true Fisher).
- (Default: None)
- name: (OPTIONAL) str or None. Unique name for this loss function. If None,
- a new name is generated. (Default: None)
- reuse: bool or str. If True, this adds `logits` as an additional
- mini-batch/tower of inputs to the loss-function/predictive distribution
- (which must have already been registered). If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
- """
- loss = lf.MultiBernoulliNegativeLogProbLoss(logits, targets=targets,
- seed=seed)
- self.register_loss_function(loss, logits,
- "multi_bernoulli_predictive_distribution",
- name=name, reuse=reuse)
-
- def make_or_get_factor(self, cls, args):
- """Insert `cls(args)` into 'self.fisher_factors` if not already present.
-
- Wraps constructor in `tf.variable_scope()` to ensure variables constructed
- in `cls.__init__` are placed under this LayerCollection's scope.
-
- Args:
- cls: Class that implements FisherFactor.
-      args: Tuple of arguments to pass into `cls`'s constructor. Must be
- hashable.
-
- Returns:
- Instance of `cls` found in self.fisher_factors.
- """
- try:
- hash(args)
- except TypeError:
- raise TypeError(
- ("Unable to use (cls, args) = ({}, {}) as a key in "
- "LayerCollection.fisher_factors. The pair cannot be hashed.").format(
- cls, args))
-
- key = cls, args
- if key not in self.fisher_factors:
- with variable_scope.variable_scope(self._var_scope):
- self.fisher_factors[key] = cls(*args)
- return self.fisher_factors[key]
-
- @contextmanager
- def as_default(self):
- """Sets this LayerCollection as the default."""
- set_default_layer_collection(self)
- yield
- set_default_layer_collection(None)
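A usage sketch for the context manager above (assumes registration code that
looks up the default collection via get_default_layer_collection()):

    lc = LayerCollection()
    with lc.as_default():
      # Inside the block, code that queries the default collection sees `lc`.
      assert get_default_layer_collection() is lc
    # On exit the default is reset to None.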
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
deleted file mode 100644
index 9f46853807..0000000000
--- a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Registry for layers and their parameters/variables.
-
-This represents the collection of all layers in the approximate Fisher
-information matrix to which a particular FisherBlock may belong. That is, we
-might have several layer collections for one TF graph (if we have multiple K-FAC
-optimizers being used, for example.)
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.layer_collection import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- "get_default_layer_collection",
- "set_default_layer_collection",
- "LayerParametersDict",
- "LayerCollection",
- "APPROX_KRONECKER_NAME",
- "APPROX_DIAGONAL_NAME",
- "APPROX_FULL_NAME",
- "VARIABLE_SCOPE",
- "APPROX_KRONECKER_INDEP_NAME",
- "APPROX_KRONECKER_SERIES_1_NAME",
- "APPROX_KRONECKER_SERIES_2_NAME"
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/linear_operator.py b/tensorflow/contrib/kfac/python/ops/linear_operator.py
deleted file mode 100644
index 61cb955ae8..0000000000
--- a/tensorflow/contrib/kfac/python/ops/linear_operator.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""SmartMatrices definitions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops.linalg import linalg
-from tensorflow.python.ops.linalg import linalg_impl
-from tensorflow.python.ops.linalg import linear_operator_util as lou
-
-
-class LinearOperatorExtras(object): # pylint: disable=missing-docstring
-
- def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
-
- with self._name_scope(name, values=[x]):
- if isinstance(x, ops.IndexedSlices):
- return self._matmul_sparse(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
-
- x = ops.convert_to_tensor(x, name="x")
- self._check_input_dtype(x)
-
- self_dim = -2 if adjoint else -1
- arg_dim = -1 if adjoint_arg else -2
- self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim])
-
- return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
-
- def matmul_right(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
-
- with self._name_scope(name, values=[x]):
-
- if isinstance(x, ops.IndexedSlices):
- return self._matmul_right_sparse(
- x, adjoint=adjoint, adjoint_arg=adjoint_arg)
-
- x = ops.convert_to_tensor(x, name="x")
- self._check_input_dtype(x)
-
- self_dim = -1 if adjoint else -2
- arg_dim = -2 if adjoint_arg else -1
- self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim])
-
- return self._matmul_right(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
-
-
-class LinearOperatorFullMatrix(LinearOperatorExtras,
- linalg.LinearOperatorFullMatrix):
-
- # TODO(b/78117889) Remove this definition once core LinearOperator
- # has _matmul_right.
- def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
- return lou.matmul_with_broadcast(
- x, self._matrix, adjoint_a=adjoint_arg, adjoint_b=adjoint)
-
- def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
- raise NotImplementedError
-
- def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
- assert not adjoint and not adjoint_arg
- return utils.matmul_sparse_dense(x, self._matrix)
-
-
-class LinearOperatorDiag(LinearOperatorExtras, # pylint: disable=missing-docstring
- linalg.LinearOperatorDiag):
-
- def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
- diag_mat = math_ops.conj(self._diag) if adjoint else self._diag
- x = linalg_impl.adjoint(x) if adjoint_arg else x
- return diag_mat * x
-
- def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
- diag_mat = math_ops.conj(self._diag) if adjoint else self._diag
- assert not adjoint_arg
- return utils.matmul_diag_sparse(diag_mat, x)
-
- def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
- raise NotImplementedError
diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py
deleted file mode 100644
index c8cebc42cb..0000000000
--- a/tensorflow/contrib/kfac/python/ops/loss_functions.py
+++ /dev/null
@@ -1,754 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Loss functions to be used by LayerCollection."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-
-import six
-
-from tensorflow.contrib.distributions.python.ops import onehot_categorical
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops.distributions import bernoulli
-from tensorflow.python.ops.distributions import categorical
-from tensorflow.python.ops.distributions import normal
-
-
-@six.add_metaclass(abc.ABCMeta)
-class LossFunction(object):
- """Abstract base class for loss functions.
-
- Note that unlike typical loss functions used in neural networks these are
- summed and not averaged across cases in the batch, since this is what the
- users of this class (FisherEstimator and MatrixVectorProductComputer) will
-  be expecting. The implication of this is that you may want to
- normalize things like Fisher-vector products by the batch size when you
- use this class. It depends on the use case.
- """
-
- @abc.abstractproperty
- def targets(self):
- """The targets being predicted by the model.
-
- Returns:
- None or Tensor of appropriate shape for calling self._evaluate() on.
- """
- pass
-
- @abc.abstractproperty
- def inputs(self):
- """The inputs to the loss function (excluding the targets)."""
- pass
-
- def evaluate(self):
- """Evaluate the loss function on the targets."""
- if self.targets is not None:
- # We treat the targets as "constant". It's only the inputs that get
- # "back-propped" through.
- return self._evaluate(array_ops.stop_gradient(self.targets))
- else:
- raise Exception("Cannot evaluate losses with unspecified targets.")
-
- @abc.abstractmethod
- def _evaluate(self, targets):
- """Evaluates the negative log probability of the targets.
-
- Args:
- targets: Tensor that distribution can calculate log_prob() of.
-
- Returns:
- negative log probability of each target, summed across all targets.
- """
- pass
-
- @abc.abstractmethod
- def multiply_hessian(self, vector):
- """Right-multiply a vector by the Hessian.
-
- Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
- of the loss function with respect to its inputs.
-
- Args:
- vector: The vector to multiply. Must be the same shape(s) as the
- 'inputs' property.
-
- Returns:
- The vector right-multiplied by the Hessian. Will be of the same shape(s)
- as the 'inputs' property.
- """
- pass
-
- @abc.abstractmethod
- def multiply_hessian_factor(self, vector):
- """Right-multiply a vector by a factor B of the Hessian.
-
- Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
- of the loss function with respect to its inputs. Typically this will be
- block-diagonal across different cases in the batch, since the loss function
- is typically summed across cases.
-
- Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
- but will agree with the one used in the other methods of this class.
-
- Args:
- vector: The vector to multiply. Must be of the shape given by the
- 'hessian_factor_inner_shape' property.
-
- Returns:
- The vector right-multiplied by B. Will be of the same shape(s) as the
- 'inputs' property.
- """
- pass
-
- @abc.abstractmethod
- def multiply_hessian_factor_transpose(self, vector):
- """Right-multiply a vector by the transpose of a factor B of the Hessian.
-
- Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
- of the loss function with respect to its inputs. Typically this will be
- block-diagonal across different cases in the batch, since the loss function
- is typically summed across cases.
-
- Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
- but will agree with the one used in the other methods of this class.
-
- Args:
- vector: The vector to multiply. Must be the same shape(s) as the
- 'inputs' property.
-
- Returns:
- The vector right-multiplied by B^T. Will be of the shape given by the
- 'hessian_factor_inner_shape' property.
- """
- pass
-
- @abc.abstractmethod
- def multiply_hessian_factor_replicated_one_hot(self, index):
- """Right-multiply a replicated-one-hot vector by a factor B of the Hessian.
-
- Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
- of the loss function with respect to its inputs. Typically this will be
- block-diagonal across different cases in the batch, since the loss function
- is typically summed across cases.
-
- A 'replicated-one-hot' vector means a tensor which, for each slice along the
- batch dimension (assumed to be dimension 0), is 1.0 in the entry
- corresponding to the given index and 0 elsewhere.
-
- Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
- but will agree with the one used in the other methods of this class.
-
- Args:
-      index: A tuple representing the index of the entry in each slice that
- is 1.0. Note that len(index) must be equal to the number of elements
- of the 'hessian_factor_inner_shape' tensor minus one.
-
- Returns:
-      The vector right-multiplied by B. Will be of the same shape(s) as the
- 'inputs' property.
- """
- pass
-
- @abc.abstractproperty
- def hessian_factor_inner_shape(self):
- """The shape of the tensor returned by multiply_hessian_factor."""
- pass
-
- @abc.abstractproperty
- def hessian_factor_inner_static_shape(self):
- """Static version of hessian_factor_inner_shape."""
- pass
-
-
-@six.add_metaclass(abc.ABCMeta)
-class NegativeLogProbLoss(LossFunction):
- """Abstract base class for loss functions that are negative log probs."""
-
- def __init__(self, seed=None):
- self._default_seed = seed
- super(NegativeLogProbLoss, self).__init__()
-
- @property
- def inputs(self):
- return self.params
-
- @abc.abstractproperty
- def params(self):
- """Parameters to the underlying distribution."""
- pass
-
- @abc.abstractmethod
- def multiply_fisher(self, vector):
- """Right-multiply a vector by the Fisher.
-
- Args:
- vector: The vector to multiply. Must be the same shape(s) as the
- 'inputs' property.
-
- Returns:
- The vector right-multiplied by the Fisher. Will be of the same shape(s)
- as the 'inputs' property.
- """
- pass
-
- @abc.abstractmethod
- def multiply_fisher_factor(self, vector):
- """Right-multiply a vector by a factor B of the Fisher.
-
- Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
- product of gradients) with respect to the parameters of the underlying
- probability distribution (whose log-prob defines the loss). Typically this
- will be block-diagonal across different cases in the batch, since the
- distribution is usually (but not always) conditionally iid across different
- cases.
-
- Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
- but will agree with the one used in the other methods of this class.
-
- Args:
- vector: The vector to multiply. Must be of the shape given by the
- 'fisher_factor_inner_shape' property.
-
- Returns:
- The vector right-multiplied by B. Will be of the same shape(s) as the
- 'inputs' property.
- """
- pass
-
- @abc.abstractmethod
- def multiply_fisher_factor_transpose(self, vector):
- """Right-multiply a vector by the transpose of a factor B of the Fisher.
-
- Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
- product of gradients) with respect to the parameters of the underlying
- probability distribution (whose log-prob defines the loss). Typically this
- will be block-diagonal across different cases in the batch, since the
- distribution is usually (but not always) conditionally iid across different
- cases.
-
- Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
- but will agree with the one used in the other methods of this class.
-
- Args:
- vector: The vector to multiply. Must be the same shape(s) as the
- 'inputs' property.
-
- Returns:
- The vector right-multiplied by B^T. Will be of the shape given by the
- 'fisher_factor_inner_shape' property.
- """
- pass
-
- @abc.abstractmethod
- def multiply_fisher_factor_replicated_one_hot(self, index):
- """Right-multiply a replicated-one-hot vector by a factor B of the Fisher.
-
- Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
- product of gradients) with respect to the parameters of the underlying
- probability distribution (whose log-prob defines the loss). Typically this
- will be block-diagonal across different cases in the batch, since the
- distribution is usually (but not always) conditionally iid across different
- cases.
-
- A 'replicated-one-hot' vector means a tensor which, for each slice along the
- batch dimension (assumed to be dimension 0), is 1.0 in the entry
- corresponding to the given index and 0 elsewhere.
-
-    Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
- but will agree with the one used in the other methods of this class.
-
- Args:
-      index: A tuple representing the index of the entry in each slice that
- is 1.0. Note that len(index) must be equal to the number of elements
- of the 'fisher_factor_inner_shape' tensor minus one.
-
- Returns:
- The vector right-multiplied by B. Will be of the same shape(s) as the
- 'inputs' property.
- """
- pass
-
- @abc.abstractproperty
- def fisher_factor_inner_shape(self):
- """The shape of the tensor returned by multiply_fisher_factor."""
- pass
-
- @abc.abstractproperty
- def fisher_factor_inner_static_shape(self):
- """Static version of fisher_factor_inner_shape."""
- pass
-
- @abc.abstractmethod
- def sample(self, seed):
- """Sample 'targets' from the underlying distribution."""
- pass
-
- def evaluate_on_sample(self, seed=None):
- """Evaluates the log probability on a random sample.
-
- Args:
- seed: int or None. Random seed for this draw from the distribution.
-
- Returns:
- Log probability of sampled targets, summed across examples.
- """
- if seed is None:
- seed = self._default_seed
- # We treat the targets as "constant". It's only the inputs that get
- # "back-propped" through.
- return self._evaluate(array_ops.stop_gradient(self.sample(seed)))
-
-
-# TODO(jamesmartens): should this just inherit from object to avoid "diamond"
-# inheritance, or is there a better way?
-class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss):
- """Base class for neg log prob losses whose inputs are 'natural' parameters.
-
- Note that the Hessian and Fisher for natural parameters of exponential-
- family models are the same, hence the purpose of this class.
- See here: https://arxiv.org/abs/1412.1193
-
- 'Natural parameters' are defined for exponential-family models. See for
- example: https://en.wikipedia.org/wiki/Exponential_family
- """
-
- def multiply_hessian(self, vector):
- return self.multiply_fisher(vector)
-
- def multiply_hessian_factor(self, vector):
- return self.multiply_fisher_factor(vector)
-
- def multiply_hessian_factor_transpose(self, vector):
- return self.multiply_fisher_factor_transpose(vector)
-
- def multiply_hessian_factor_replicated_one_hot(self, index):
- return self.multiply_fisher_factor_replicated_one_hot(index)
-
- @property
- def hessian_factor_inner_shape(self):
- return self.fisher_factor_inner_shape
-
- @property
- def hessian_factor_inner_static_shape(self):
-    return self.fisher_factor_inner_static_shape
-
-
-class DistributionNegativeLogProbLoss(NegativeLogProbLoss):
- """Base class for neg log prob losses that use the TF Distribution classes."""
-
- def __init__(self, seed=None):
- super(DistributionNegativeLogProbLoss, self).__init__(seed=seed)
-
- @abc.abstractproperty
- def dist(self):
- """The underlying tf.distributions.Distribution."""
- pass
-
- def _evaluate(self, targets):
- return -math_ops.reduce_sum(self.dist.log_prob(targets))
-
- def sample(self, seed):
- return self.dist.sample(seed=seed)
-
-
-class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss,
- NaturalParamsNegativeLogProbLoss):
- """Neg log prob loss for a normal distribution parameterized by a mean vector.
-
- Note that the covariance is treated as a constant 'var' times the identity.
-  Also note that the Fisher for such a normal distribution with respect to
-  the mean parameter is given by:
-
- F = (1/var) * I
-
- See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf.
- """
-
- def __init__(self, mean, var=0.5, targets=None, seed=None):
- self._mean = mean
- self._var = var
- self._targets = targets
- super(NormalMeanNegativeLogProbLoss, self).__init__(seed=seed)
-
- @property
- def targets(self):
- return self._targets
-
- @property
- def dist(self):
- return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._var))
-
- @property
- def params(self):
- return self._mean
-
- def multiply_fisher(self, vector):
- return (1. / self._var) * vector
-
- def multiply_fisher_factor(self, vector):
- return self._var**-0.5 * vector
-
- def multiply_fisher_factor_transpose(self, vector):
- return self.multiply_fisher_factor(vector) # it's symmetric in this case
-
- def multiply_fisher_factor_replicated_one_hot(self, index):
- assert len(index) == 1, "Length of index was {}".format(len(index))
- ones_slice = array_ops.expand_dims(
- array_ops.ones(array_ops.shape(self._mean)[:1], dtype=self._mean.dtype),
- axis=-1)
- output_slice = self._var**-0.5 * ones_slice
- return insert_slice_in_zeros(output_slice, 1, int(self._mean.shape[1]),
- index[0])
-
- @property
- def fisher_factor_inner_shape(self):
- return array_ops.shape(self._mean)
-
- @property
- def fisher_factor_inner_static_shape(self):
- return self._mean.shape
-
-
-class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
- """Negative log prob loss for a normal distribution with mean and variance.
-
- This class parameterizes a multivariate normal distribution with n independent
- dimensions. Unlike `NormalMeanNegativeLogProbLoss`, this class does not
- assume the variance is held constant. The Fisher Information for n = 1
- is given by,
-
- F = [[1 / variance, 0],
- [ 0, 0.5 / variance^2]]
-
- where the parameters of the distribution are concatenated into a single
- vector as [mean, variance]. For n > 1, the mean parameter vector is
- concatenated with the variance parameter vector.
-
- See https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf for derivation.
- """
-
- def __init__(self, mean, variance, targets=None, seed=None):
- assert len(mean.shape) == 2, "Expect 2D mean tensor."
- assert len(variance.shape) == 2, "Expect 2D variance tensor."
- self._mean = mean
- self._variance = variance
- self._targets = targets
- super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed)
-
- @property
- def targets(self):
- return self._targets
-
- @property
- def dist(self):
- return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._variance))
-
- @property
- def params(self):
- return self._mean, self._variance
-
- def _concat(self, mean, variance):
- return array_ops.concat([mean, variance], axis=-1)
-
- def _split(self, params):
- return array_ops.split(params, 2, axis=-1)
-
- @property
- def _fisher_mean(self):
- return 1. / self._variance
-
- @property
- def _fisher_mean_factor(self):
- return 1. / math_ops.sqrt(self._variance)
-
- @property
- def _fisher_var(self):
- return 1. / (2 * math_ops.square(self._variance))
-
- @property
- def _fisher_var_factor(self):
- return 1. / (math_ops.sqrt(2.) * self._variance)
-
- def multiply_fisher(self, vecs):
- mean_vec, var_vec = vecs
- return (self._fisher_mean * mean_vec, self._fisher_var * var_vec)
-
- def multiply_fisher_factor(self, vecs):
- mean_vec, var_vec = self._split(vecs)
- return (self._fisher_mean_factor * mean_vec,
- self._fisher_var_factor * var_vec)
-
- def multiply_fisher_factor_transpose(self, vecs):
- mean_vec, var_vec = vecs
- return self._concat(self._fisher_mean_factor * mean_vec,
- self._fisher_var_factor * var_vec)
-
- def multiply_fisher_factor_replicated_one_hot(self, index):
- assert len(index) == 1, "Length of index was {}".format(len(index))
- index = index[0]
-
- if index < int(self._mean.shape[-1]):
- # Index corresponds to mean parameter.
- mean_slice = self._fisher_mean_factor[:, index]
- mean_slice = array_ops.expand_dims(mean_slice, axis=-1)
- mean_output = insert_slice_in_zeros(mean_slice, 1, int(
- self._mean.shape[1]), index)
- var_output = array_ops.zeros_like(mean_output)
- else:
- index -= int(self._mean.shape[-1])
- # Index corresponds to variance parameter.
- var_slice = self._fisher_var_factor[:, index]
- var_slice = array_ops.expand_dims(var_slice, axis=-1)
- var_output = insert_slice_in_zeros(var_slice, 1,
- int(self._variance.shape[1]), index)
- mean_output = array_ops.zeros_like(var_output)
-
- return mean_output, var_output
-
- @property
- def fisher_factor_inner_shape(self):
- return array_ops.concat(
- [
- array_ops.shape(self._mean)[:-1],
- 2 * array_ops.shape(self._mean)[-1:]
- ],
- axis=0)
-
- @property
- def fisher_factor_inner_static_shape(self):
- shape = self._mean.shape.as_list()
-    # Static analogue of fisher_factor_inner_shape: [..leading dims.., 2 * n].
-    return tensor_shape.TensorShape(shape[:-1] + [2 * shape[-1]])
-
- def multiply_hessian(self, vector):
- raise NotImplementedError()
-
- def multiply_hessian_factor(self, vector):
- raise NotImplementedError()
-
- def multiply_hessian_factor_transpose(self, vector):
- raise NotImplementedError()
-
- def multiply_hessian_factor_replicated_one_hot(self, index):
- raise NotImplementedError()
-
- @property
- def hessian_factor_inner_shape(self):
- raise NotImplementedError()
-
- @property
- def hessian_factor_inner_static_shape(self):
- raise NotImplementedError()
-
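As a sanity check on the class above, the factor entries squared recover the
Fisher blocks from its docstring; a numpy sketch for a single dimension (the
comments name the corresponding properties):

    import numpy as np

    var = 0.7
    fisher_mean, fisher_var = 1.0 / var, 0.5 / var ** 2
    factor_mean = 1.0 / np.sqrt(var)           # _fisher_mean_factor
    factor_var = 1.0 / (np.sqrt(2.0) * var)    # _fisher_var_factor

    assert np.isclose(factor_mean ** 2, fisher_mean)
    assert np.isclose(factor_var ** 2, fisher_var)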
-
-class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss,
- NaturalParamsNegativeLogProbLoss):
- """Neg log prob loss for a categorical distribution parameterized by logits.
-
- Note that the Fisher (for a single case) of a categorical distribution, with
- respect to the natural parameters (i.e. the logits), is given by:
-
- F = diag(p) - p*p^T
-
- where p = softmax(logits). F can be factorized as F = B * B^T where
-
- B = diag(q) - p*q^T
-
- where q is the entry-wise square root of p. This is easy to verify using the
- fact that q^T*q = 1.
- """
-
- def __init__(self, logits, targets=None, seed=None):
- """Instantiates a CategoricalLogitsNegativeLogProbLoss.
-
- Args:
- logits: Tensor of shape [batch_size, output_size]. Parameters for
- underlying distribution.
-      targets: None or Tensor of shape [batch_size]. Each element contains an
-        index in [0, output_size).
- seed: int or None. Default random seed when sampling.
- """
- self._logits = logits
- self._targets = targets
- super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed)
-
- @property
- def targets(self):
- return self._targets
-
- @property
- def dist(self):
- return categorical.Categorical(logits=self._logits)
-
- @property
- def _probs(self):
- return self.dist.probs
-
- @property
- def _sqrt_probs(self):
- return math_ops.sqrt(self._probs)
-
- @property
- def params(self):
- return self._logits
-
- def multiply_fisher(self, vector):
- probs = self._probs
- return vector * probs - probs * math_ops.reduce_sum(
- vector * probs, axis=-1, keepdims=True)
-
- def multiply_fisher_factor(self, vector):
- probs = self._probs
- sqrt_probs = self._sqrt_probs
- return sqrt_probs * vector - probs * math_ops.reduce_sum(
- sqrt_probs * vector, axis=-1, keepdims=True)
-
- def multiply_fisher_factor_transpose(self, vector):
- probs = self._probs
- sqrt_probs = self._sqrt_probs
- return sqrt_probs * vector - sqrt_probs * math_ops.reduce_sum(
- probs * vector, axis=-1, keepdims=True)
-
- def multiply_fisher_factor_replicated_one_hot(self, index):
- assert len(index) == 1, "Length of index was {}".format(len(index))
- probs = self._probs
- sqrt_probs = self._sqrt_probs
- sqrt_probs_slice = array_ops.expand_dims(sqrt_probs[:, index[0]], -1)
- padded_slice = insert_slice_in_zeros(sqrt_probs_slice, 1,
- int(sqrt_probs.shape[1]), index[0])
- return padded_slice - probs * sqrt_probs_slice
-
- @property
- def fisher_factor_inner_shape(self):
- return array_ops.shape(self._logits)
-
- @property
- def fisher_factor_inner_static_shape(self):
- return self._logits.shape
-
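The factorization F = B * B^T claimed in the CategoricalLogitsNegativeLogProbLoss
docstring is easy to verify numerically; a numpy sketch for a single case:

    import numpy as np

    logits = np.array([1.0, -0.5, 2.0])
    p = np.exp(logits) / np.exp(logits).sum()   # softmax
    q = np.sqrt(p)

    F = np.diag(p) - np.outer(p, p)
    B = np.diag(q) - np.outer(p, q)
    assert np.allclose(B.dot(B.T), F)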
-
-class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss,
- NaturalParamsNegativeLogProbLoss):
- """Neg log prob loss for multiple Bernoulli distributions param'd by logits.
-
- Represents N independent Bernoulli distributions where N = len(logits). Its
- Fisher Information matrix is given by,
-
- F = diag(p * (1-p))
- p = sigmoid(logits)
-
- As F is diagonal with positive entries, its factor B is,
-
- B = diag(sqrt(p * (1-p)))
- """
-
- def __init__(self, logits, targets=None, seed=None):
- self._logits = logits
- self._targets = targets
- super(MultiBernoulliNegativeLogProbLoss, self).__init__(seed=seed)
-
- @property
- def targets(self):
- return self._targets
-
- @property
- def dist(self):
- return bernoulli.Bernoulli(logits=self._logits)
-
- @property
- def _probs(self):
- return self.dist.probs
-
- @property
- def params(self):
- return self._logits
-
- def multiply_fisher(self, vector):
- return self._probs * (1 - self._probs) * vector
-
- def multiply_fisher_factor(self, vector):
- return math_ops.sqrt(self._probs * (1 - self._probs)) * vector
-
- def multiply_fisher_factor_transpose(self, vector):
- return self.multiply_fisher_factor(vector) # it's symmetric in this case
-
- def multiply_fisher_factor_replicated_one_hot(self, index):
- assert len(index) == 1, "Length of index was {}".format(len(index))
- probs_slice = array_ops.expand_dims(self._probs[:, index[0]], -1)
- output_slice = math_ops.sqrt(probs_slice * (1 - probs_slice))
- return insert_slice_in_zeros(output_slice, 1, int(self._logits.shape[1]),
- index[0])
-
- @property
- def fisher_factor_inner_shape(self):
- return array_ops.shape(self._logits)
-
- @property
- def fisher_factor_inner_static_shape(self):
- return self._logits.shape
-
-
-def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position):
- """Inserts slice into a larger tensor of zeros.
-
- Forms a new tensor which is the same shape as slice_to_insert, except that
- the dimension given by 'dim' is expanded to the size given by 'dim_size'.
- 'position' determines the position (index) at which to insert the slice within
- that dimension.
-
- Assumes slice_to_insert.shape[dim] = 1.
-
- Args:
- slice_to_insert: The slice to insert.
- dim: The dimension which to expand with zeros.
- dim_size: The new size of the 'dim' dimension.
- position: The position of 'slice_to_insert' in the new tensor.
-
- Returns:
- The new tensor.
-
- Raises:
- ValueError: If the slice's shape at the given dim is not 1.
- """
- slice_shape = slice_to_insert.shape
- if slice_shape[dim] != 1:
- raise ValueError("Expected slice_to_insert.shape to have {} dim of 1, but "
- "was {}".format(dim, slice_to_insert.shape[dim]))
-
- before = [0] * int(len(slice_shape))
- after = before[:]
- before[dim] = position
- after[dim] = dim_size - position - 1
-
- return array_ops.pad(slice_to_insert, list(zip(before, after)))
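A concrete example of the padding performed by insert_slice_in_zeros
(hypothetical values):

    import tensorflow as tf

    slice_to_insert = tf.constant([[7.], [8.]])  # shape [2, 1]
    padded = insert_slice_in_zeros(slice_to_insert, dim=1, dim_size=4,
                                   position=2)
    # padded == [[0., 0., 7., 0.],
    #            [0., 0., 8., 0.]]
    # i.e. paddings of [[0, 0], [2, 1]] passed to array_ops.pad.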
-
-
-class OnehotCategoricalLogitsNegativeLogProbLoss(
- CategoricalLogitsNegativeLogProbLoss):
- """Neg log prob loss for a categorical distribution with onehot targets.
-
- Identical to CategoricalLogitsNegativeLogProbLoss except that the underlying
- distribution is OneHotCategorical as opposed to Categorical.
- """
-
- @property
- def dist(self):
- return onehot_categorical.OneHotCategorical(logits=self._logits)
diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py
deleted file mode 100644
index 4279cb2792..0000000000
--- a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Loss functions to be used by LayerCollection."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.loss_functions import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- "LossFunction",
- "NegativeLogProbLoss",
- "NaturalParamsNegativeLogProbLoss",
- "DistributionNegativeLogProbLoss",
- "NormalMeanNegativeLogProbLoss",
- "NormalMeanVarianceNegativeLogProbLoss",
- "CategoricalLogitsNegativeLogProbLoss",
- "OnehotCategoricalLogitsNegativeLogProbLoss",
- "MultiBernoulliNegativeLogProbLoss",
- "insert_slice_in_zeros",
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/op_queue.py b/tensorflow/contrib/kfac/python/ops/op_queue.py
deleted file mode 100644
index b6d9d37a31..0000000000
--- a/tensorflow/contrib/kfac/python/ops/op_queue.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Helper for choosing which op to run next in a distributed setting."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import ops as tf_ops
-
-
-class OpQueue(object):
- """Class for choosing which Op to run next.
-
- Constructs an infinitely repeating sequence of Ops in shuffled order.
-
- In K-FAC, this can be used to distribute inverse update operations among
- workers.
- """
-
- def __init__(self, ops, seed=None):
- """Initializes an OpQueue.
-
- Args:
- ops: list of TensorFlow Ops. Ops to be selected from. All workers must
- initialize with the same set of ops.
- seed: int or None. Random seed used when shuffling order of ops.
- """
- self._ops_by_name = {op.name: op for op in ops}
-
- # Construct a (shuffled) Dataset with Op names.
- op_names = tf_ops.convert_to_tensor(list(sorted(op.name for op in ops)))
- op_names_dataset = (dataset_ops.Dataset.from_tensor_slices(op_names)
- .shuffle(len(ops), seed=seed).repeat())
- self._next_op_name = op_names_dataset.make_one_shot_iterator().get_next()
-
- @property
- def ops(self):
- """Ops this OpQueue can return in next_op()."""
- return self._ops_by_name.values()
-
- def next_op(self, sess):
- """Chooses which op to run next.
-
- Note: This call will make a call to sess.run().
-
- Args:
- sess: tf.Session.
-
- Returns:
- Next Op chosen from 'ops'.
- """
- # In Python 3, type(next_op_name) == bytes. Calling bytes.decode('ascii')
- # returns a str.
- next_op_name = sess.run(self._next_op_name).decode('ascii')
- return self._ops_by_name[next_op_name]
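
For context, a minimal usage sketch (hypothetical, not from the deleted
sources; `inv_update_ops`, `num_steps`, and the session setup are assumed):
every worker builds the queue from the same ops and seed, then repeatedly
runs whichever op the shared shuffled sequence selects next.

    # Hypothetical sketch of driving OpQueue from a training loop.
    queue = OpQueue(inv_update_ops, seed=0)  # same ops and seed on all workers
    with tf.Session() as sess:
      for _ in range(num_steps):
        next_op = queue.next_op(sess)  # picks the next name via sess.run()
        sess.run(next_op)              # runs the selected update op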
diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py
deleted file mode 100644
index 38605259b5..0000000000
--- a/tensorflow/contrib/kfac/python/ops/optimizer.py
+++ /dev/null
@@ -1,727 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""The KFAC optimizer."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import warnings
-
-# pylint: disable=line-too-long
-from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp
-from tensorflow.contrib.kfac.python.ops import estimator as est
-# pylint: enable=line-too-long
-
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables as tf_variables
-from tensorflow.python.training import gradient_descent
-
-
-class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
- """The KFAC Optimizer (https://arxiv.org/abs/1503.05671)."""
-
- def __init__(self,
- learning_rate,
- cov_ema_decay,
- damping,
- layer_collection,
- var_list=None,
- momentum=0.9,
- momentum_type="regular",
- norm_constraint=None,
- name="KFAC",
- estimation_mode="gradients",
- colocate_gradients_with_ops=True,
- batch_size=None,
- placement_strategy=None,
- **kwargs):
- """Initializes the KFAC optimizer with the given settings.
-
- Args:
-      learning_rate: The base learning rate for the optimizer. Should probably
-        be set to 1.0 when using momentum_type = 'qmodel', but can still be
-        lowered if desired (effectively lowering the trust in the quadratic
-        model).
- cov_ema_decay: The decay factor used when calculating the covariance
- estimate moving averages.
- damping: The damping factor used to stabilize training due to errors in
- the local approximation with the Fisher information matrix, and to
- regularize the update direction by making it closer to the gradient.
- If damping is adapted during training then this value is used for
-        initializing the damping variable.
- (Higher damping means the update looks more like a standard gradient
- update - see Tikhonov regularization.)
-      layer_collection: The layer collection object, which holds the Fisher
- blocks, Kronecker factors, and losses associated with the
- graph. The layer_collection cannot be modified after KfacOptimizer's
- initialization.
- var_list: Optional list or tuple of variables to train. Defaults to the
- list of variables collected in the graph under the key
- `GraphKeys.TRAINABLE_VARIABLES`.
- momentum: The momentum decay constant to use. Only applies when
- momentum_type is 'regular' or 'adam'. (Default: 0.9)
- momentum_type: The type of momentum to use in this optimizer, one of
- 'regular', 'adam', or 'qmodel'. (Default: 'regular')
- norm_constraint: float or Tensor. If specified, the update is scaled down
- so that its approximate squared Fisher norm v^T F v is at most the
- specified value. May only be used with momentum type 'regular'.
- (Default: None)
- name: The name for this optimizer. (Default: 'KFAC')
- estimation_mode: The type of estimator to use for the Fishers. Can be
- 'gradients', 'empirical', 'curvature_propagation', or 'exact'.
-        (Default: 'gradients'). See the doc-string for FisherEstimator for a
-        more detailed description of these options.
- colocate_gradients_with_ops: Whether we should request gradients we
- compute in the estimator be colocated with their respective ops.
- (Default: True)
- batch_size: The size of the mini-batch. Only needed when momentum_type
-        == 'qmodel' or when automatic damping adjustment is used.
-        (Default: None)
- placement_strategy: string, Device placement strategy used when creating
- covariance variables, covariance ops, and inverse ops.
- (Default: `None`)
- **kwargs: Arguments to be passed to specific placement
- strategy mixin. Check `placement.RoundRobinPlacementMixin` for example.
-
- Raises:
- ValueError: If the momentum type is unsupported.
- ValueError: If clipping is used with momentum type other than 'regular'.
- ValueError: If no losses have been registered with layer_collection.
- ValueError: If momentum is non-zero and momentum_type is not 'regular'
- or 'adam'.
- """
- warnings.warn(
- "third_party.tensorflow.contrib.kfac is deprecated."
- "This will be removed on 15-07-2018. Check README for further details.",
- DeprecationWarning)
- # Parameters to be passed to the Fisher estimator:
- self._variables = var_list or tf_variables.trainable_variables
- self._cov_ema_decay = cov_ema_decay
- self._layers = layer_collection
- self._estimation_mode = estimation_mode
- self._colocate_gradients_with_ops = colocate_gradients_with_ops
-
- # The below parameters are required only if damping needs to be adapted.
- # These parameters can be set by calling
- # set_damping_adaptation_params() explicitly.
- self._damping_adaptation_decay = 0.95
- self._damping_adaptation_interval = 5
-    # See Section 6.5 of the KFAC paper: omega(1) = pow(damping decay, interval)
- self._omega = (
- self._damping_adaptation_decay**self._damping_adaptation_interval)
- self._adapt_damping = False
- self._min_damping = 1e-5
- self._prev_train_batch = None
- self._is_chief = False
- self._loss_fn = None
- self._damping_constant = damping
- self._damping = None
- self._rho = None
- self._prev_loss = None
- self._q_model_change = None
- self._update_damping_op = None
-
- momentum_type = momentum_type.lower()
- legal_momentum_types = ["regular", "adam", "qmodel"]
-
- if momentum_type not in legal_momentum_types:
- raise ValueError("Unsupported momentum type {}. Must be one of {}."
- .format(momentum_type, legal_momentum_types))
- if momentum_type != "regular" and norm_constraint is not None:
- raise ValueError("Update clipping is only supported with momentum "
- "type 'regular'.")
- if momentum_type not in ["regular", "adam"] and momentum != 0:
- raise ValueError("Momentum must be unspecified if using a momentum_type "
- "other than 'regular' or 'adam'.")
-
- # Extra parameters of the optimizer
- self._momentum = momentum
- self._momentum_type = momentum_type
- self._norm_constraint = norm_constraint
- self._batch_size = batch_size
- self._placement_strategy = placement_strategy
-
- with variable_scope.variable_scope(name):
- self._fisher_est = est.make_fisher_estimator(
- placement_strategy=placement_strategy,
- variables=self._variables,
- cov_ema_decay=self._cov_ema_decay,
- damping=self.damping,
- layer_collection=self._layers,
- exps=(-1,),
- estimation_mode=self._estimation_mode,
- colocate_gradients_with_ops=self._colocate_gradients_with_ops,
- **kwargs)
-
- super(KfacOptimizer, self).__init__(learning_rate, name=name)
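
A minimal construction sketch (hypothetical; `layer_collection`, `loss`, and
`global_step` are assumed to exist, with the model's layers and losses
already registered on the layer collection, and `tf` is the public
TensorFlow namespace):

    opt = KfacOptimizer(
        learning_rate=0.01,
        cov_ema_decay=0.95,
        damping=1e-3,
        layer_collection=layer_collection)
    cov_thunks, inv_thunks = opt.make_vars_and_create_op_thunks()
    cov_update_op = tf.group(*(thunk() for thunk in cov_thunks))
    inv_update_op = tf.group(*(thunk() for thunk in inv_thunks))
    train_op = opt.minimize(loss, global_step=global_step)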
-
- def set_damping_adaptation_params(self,
- is_chief,
- prev_train_batch,
- loss_fn,
- min_damping=1e-5,
- damping_adaptation_decay=0.99,
- damping_adaptation_interval=5):
- """Sets parameters required to adapt damping during training.
-
- When called, enables damping adaptation according to the Levenberg-Marquardt
- style rule described in Section 6.5 of "Optimizing Neural Networks with
- Kronecker-factored Approximate Curvature".
-
-    Note that this function creates TensorFlow variables which store a few
- scalars and are accessed by the ops which update the damping (as part
- of the training op returned by the minimize() method).
-
- Args:
- is_chief: `Boolean`, `True` if the worker is chief.
- prev_train_batch: Training data used to minimize loss in the previous
- step. This will be used to evaluate loss by calling
- `loss_fn(prev_train_batch)`.
- loss_fn: `function` that takes as input training data tensor and returns
- a scalar loss.
- min_damping: `float`(Optional), Minimum value the damping parameter
- can take. Default value 1e-5.
- damping_adaptation_decay: `float`(Optional), The `damping` parameter is
- multiplied by the `damping_adaptation_decay` every
- `damping_adaptation_interval` number of iterations. Default value 0.99.
- damping_adaptation_interval: `int`(Optional), Number of steps in between
- updating the `damping` parameter. Default value 5.
-
- Raises:
-      ValueError: If `set_damping_adaptation_params` has already been called
-        (i.e. `adapt_damping` is already `True`).
- """
- if self._adapt_damping:
- raise ValueError("Damping adaptation parameters already set.")
-
- with variable_scope.variable_scope(self.get_name()):
- self._adapt_damping = True
- self._is_chief = is_chief
- self._prev_train_batch = prev_train_batch
- self._loss_fn = loss_fn
- self._damping_adaptation_decay = damping_adaptation_decay
- self._damping_adaptation_interval = damping_adaptation_interval
- self._omega = (
- self._damping_adaptation_decay**self._damping_adaptation_interval)
- self._min_damping = min_damping
-
- self._rho = variable_scope.get_variable(
- "rho", shape=(), dtype=dtypes.float32, trainable=False) # LM ratio.
- self._prev_loss = variable_scope.get_variable(
- "prev_loss", shape=(), dtype=dtypes.float32, trainable=False)
- self._q_model_change = variable_scope.get_variable(
- "q_model_change", shape=(), dtype=dtypes.float32, trainable=False)
- self._damping = variable_scope.get_variable(
- "damping", initializer=self._damping_constant, trainable=False)
-
- @property
- def variables(self):
- return self._fisher_est.variables
-
- @property
- def damping(self):
-    if self._damping is not None:
- return self._damping
- else:
- return self._damping_constant
-
- @property
- def damping_adaptation_interval(self):
- return self._damping_adaptation_interval
-
- def make_vars_and_create_op_thunks(self):
- """Make vars and create op thunks.
-
- Returns:
- cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
- the list of factors given by the "factors" property.
- inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
- the list of factors given by the "factors" property.
- """
- scope = self.get_name() + "/" + self._fisher_est.name
- return self._fisher_est.make_vars_and_create_op_thunks(scope=scope)
-
- def create_ops_and_vars_thunks(self):
- """Create thunks that make the ops and vars on demand.
-
- This function returns 4 lists of thunks: cov_variable_thunks,
- cov_update_thunks, inv_variable_thunks, and inv_update_thunks.
-
- The length of each list is the number of factors and the i-th element of
- each list corresponds to the i-th factor (given by the "factors" property).
-
- Note that the execution of these thunks must happen in a certain
- partial order. The i-th element of cov_variable_thunks must execute
- before the i-th element of cov_update_thunks (and also the i-th element
- of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks
- must execute before the i-th element of inv_update_thunks.
-
- TL;DR (oversimplified): Execute the thunks according to the order that
- they are returned.
-
- Returns:
- cov_variable_thunks: A list of thunks that make the cov variables.
- cov_update_thunks: A list of thunks that make the cov update ops.
- inv_variable_thunks: A list of thunks that make the inv variables.
- inv_update_thunks: A list of thunks that make the inv update ops.
- """
- scope = self.get_name() + "/" + self._fisher_est.name
- return self._fisher_est.create_ops_and_vars_thunks(scope=scope)
-
- def minimize(self, *args, **kwargs):
- # Should this variable scope encompass everything below? Or will the super-
- # class make another copy of the same name scope?
- with variable_scope.variable_scope(self.get_name()):
- kwargs["var_list"] = kwargs.get("var_list") or self.variables
- if set(kwargs["var_list"]) != set(self.variables):
- raise ValueError("var_list doesn't match with set of Fisher-estimating "
- "variables.")
- if self._adapt_damping and self._is_chief:
- global_step = kwargs.get("global_step", None)
-      if global_step is None:
- raise KeyError("global_step needs to be passed to optimizer.minimize "
- "if damping parameter is adapted.")
- update_damping_op = self._update_damping(self._prev_train_batch,
- global_step)
- with ops.control_dependencies([update_damping_op]):
- loss = args[0]
- loss_assign_op = state_ops.assign(self._prev_loss, loss)
- train_op = super(KfacOptimizer, self).minimize(*args, **kwargs)
- return control_flow_ops.group(loss_assign_op, train_op)
- else:
- return super(KfacOptimizer, self).minimize(*args, **kwargs)
-
- def compute_gradients(self, *args, **kwargs):
- # args[1] could be our var_list
- if len(args) > 1:
- var_list = args[1]
- else:
- kwargs["var_list"] = kwargs.get("var_list") or self.variables
- var_list = kwargs["var_list"]
-
- if set(var_list) != set(self.variables):
- raise ValueError("var_list doesn't match with set of Fisher-estimating "
- "variables.")
- return super(KfacOptimizer, self).compute_gradients(*args, **kwargs)
-
- def apply_gradients(self, grads_and_vars, *args, **kwargs):
- """Applies gradients to variables.
-
- Args:
- grads_and_vars: List of (gradient, variable) pairs.
- *args: Additional arguments for super.apply_gradients.
- **kwargs: Additional keyword arguments for super.apply_gradients.
-
- Returns:
- An `Operation` that applies the specified gradients.
- """
- # In Python 3, grads_and_vars can be a zip() object which can only be
- # iterated over once. By converting it to a list, we ensure that it can be
- # iterated over more than once.
- grads_and_vars = list(grads_and_vars)
-
- # Compute step.
- steps_and_vars = self._compute_update_steps(grads_and_vars)
-
- # Update trainable variables with this step.
- return super(KfacOptimizer, self).apply_gradients(steps_and_vars, *args,
- **kwargs)
-
- def _squared_fisher_norm(self, grads_and_vars, precon_grads_and_vars):
- """Computes the squared (approximate) Fisher norm of the updates.
-
- This is defined as v^T F v, where F is the approximate Fisher matrix
- as computed by the estimator, and v = F^{-1} g, where g is the gradient.
- This is computed efficiently as v^T g.
-
- Args:
- grads_and_vars: List of (gradient, variable) pairs.
- precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
- Must be the result of calling `self._fisher_est.multiply_inverse`
- on `grads_and_vars`.
-
- Returns:
- Scalar representing the squared norm.
-
- Raises:
- ValueError: if the two list arguments do not contain the same variables,
- in the same order.
- """
- for (_, gvar), (_, pgvar) in zip(grads_and_vars, precon_grads_and_vars):
- if gvar is not pgvar:
- raise ValueError("The variables referenced by the two arguments "
- "must match.")
- terms = [
- math_ops.reduce_sum(grad * pgrad)
- for (grad, _), (pgrad, _) in zip(grads_and_vars, precon_grads_and_vars)
- ]
- return math_ops.reduce_sum(terms)
-
- def _update_clip_coeff(self, grads_and_vars, precon_grads_and_vars):
- """Computes the scale factor for the update to satisfy the norm constraint.
-
- Defined as min(1, sqrt(c / r^T F r)), where c is the norm constraint,
- F is the approximate Fisher matrix, and r is the update vector, i.e.
- -alpha * v, where alpha is the learning rate, and v is the preconditioned
- gradient.
-
- This is based on Section 5 of Ba et al., Distributed Second-Order
- Optimization using Kronecker-Factored Approximations. Note that they
- absorb the learning rate alpha (which they denote eta_max) into the formula
- for the coefficient, while in our implementation, the rescaling is done
- before multiplying by alpha. Hence, our formula differs from theirs by a
- factor of alpha.
-
- Args:
- grads_and_vars: List of (gradient, variable) pairs.
- precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
- Must be the result of calling `self._fisher_est.multiply_inverse`
- on `grads_and_vars`.
-
- Returns:
- Scalar representing the coefficient which should be applied to the
- preconditioned gradients to satisfy the norm constraint.
- """
- sq_norm_grad = self._squared_fisher_norm(grads_and_vars,
- precon_grads_and_vars)
- sq_norm_up = sq_norm_grad * self._learning_rate**2
- return math_ops.minimum(1.,
- math_ops.sqrt(self._norm_constraint / sq_norm_up))
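
A quick worked example of this coefficient (illustrative numbers): with
norm constraint c = 1e-4, squared Fisher norm v^T F v = 0.04, and learning
rate 0.1, sq_norm_up = 0.04 * 0.1^2 = 4e-4, so the update is rescaled by
min(1, sqrt(1e-4 / 4e-4)) = 0.5.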
-
- def _clip_updates(self, grads_and_vars, precon_grads_and_vars):
- """Rescales the preconditioned gradients to satisfy the norm constraint.
-
- Rescales the preconditioned gradients such that the resulting update r
- (after multiplying by the learning rate) will satisfy the norm constraint.
- This constraint is that r^T F r <= C, where F is the approximate Fisher
- matrix, and C is the norm_constraint attribute. See Section 5 of
- Ba et al., Distributed Second-Order Optimization using Kronecker-Factored
- Approximations.
-
- Args:
- grads_and_vars: List of (gradient, variable) pairs.
- precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
- Must be the result of calling `self._fisher_est.multiply_inverse`
- on `grads_and_vars`.
-
- Returns:
- List of (rescaled preconditioned gradient, variable) pairs.
- """
- coeff = self._update_clip_coeff(grads_and_vars, precon_grads_and_vars)
- return [(pgrad * coeff, var) for pgrad, var in precon_grads_and_vars]
-
- def _compute_prev_updates(self, variables):
- """Computes previous updates as negative velocities scaled by learning rate.
-
- Args:
- variables: List of variables in the graph that the update will be
- applied to.
-
- Returns:
- List of previous updates applied to the `variables`.
- """
- return list(
- -1 * self._learning_rate * self._zeros_slot(var, "velocity", self._name)
- for var in variables)
-
- def _compute_qmodel_hyperparams(self, precon_grads, prev_updates, grads,
- variables):
- """Compute optimal update hyperparameters from the quadratic model.
-
-    More specifically, if L is the loss, we minimize the quadratic
-    approximation qmodel(d) of L(theta + d) with respect to alpha and mu,
-    where d = alpha*precon_grad + mu*prev_update and
-
-    qmodel(d) = (1/2) * d^T * C * d + grad^T * d + L(theta) .
-
- Unlike in the KL clipping approach we use the non-approximated quadratic
- model where the curvature matrix C is the true Fisher on the current
- mini-batch (computed without any approximations beyond mini-batch sampling),
- with the usual Tikhonov damping/regularization applied,
-
- C = F + damping * I
-
- See Section 7 of https://arxiv.org/abs/1503.05671 for a derivation of
- the formula. See Appendix C for a discussion of the trick of using
- a factorized Fisher matrix to more efficiently compute the required
- vector-matrix-vector products.
-
- Note that the elements of all 4 lists passed to this function must
- be in correspondence with each other.
-
- Args:
- precon_grads: List of preconditioned gradients.
- prev_updates: List of updates computed at the previous iteration.
- grads: List of gradients.
- variables: List of variables in the graph that the update will be
- applied to. (Note that this function doesn't actually apply the
- update.)
-
- Returns:
- (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize the
- quadratic model, and
- qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0)
- = qmodel(alpha*precon_grad + mu*prev_update) - L(theta).
- """
-
- cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._layers.losses,
- variables)
-
-    # Compute the matrix-vector products with the transposed Fisher factor.
- fft_precon_grads = cmvpc.multiply_fisher_factor_transpose(precon_grads)
- fft_prev_updates = cmvpc.multiply_fisher_factor_transpose(prev_updates)
- batch_size = math_ops.cast(
- self._batch_size, dtype=fft_precon_grads[0].dtype)
-
-    # Compute the entries of the 2x2 matrix.
- m_11 = (
- _inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size +
- self.damping * _inner_product_list(precon_grads, precon_grads))
-
- m_21 = (
- _inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size +
- self.damping * _inner_product_list(prev_updates, precon_grads))
-
- m_22 = (
- _inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size +
- self.damping * _inner_product_list(prev_updates, prev_updates))
-
- def non_zero_prevupd_case():
- r"""Computes optimal (alpha, mu) given non-zero previous update.
-
- We solve the full 2x2 linear system. See Martens & Grosse (2015),
- Section 7, definition of $\alpha^*$ and $\mu^*$.
-
- Returns:
- (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize
- the quadratic model, and
- qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0).
- """
- m = ops.convert_to_tensor([[m_11, m_21], [m_21, m_22]])
-
- c = ops.convert_to_tensor([[_inner_product_list(grads, precon_grads)],
- [_inner_product_list(grads, prev_updates)]])
-
- sol = -1. * _two_by_two_solve(m, c)
- alpha = sol[0]
- mu = sol[1]
- qmodel_change = 0.5 * math_ops.reduce_sum(sol * c)
-
- return alpha, mu, qmodel_change
-
- def zero_prevupd_case():
- r"""Computes optimal (alpha, mu) given all-zero previous update.
-
- The linear system reduces to 1x1. See Martens & Grosse (2015),
- Section 6.4, definition of $\alpha^*$.
-
- Returns:
- (alpha, 0.0, qmodel_change), where alpha is chosen to optimize the
- quadratic model, and
- qmodel_change = qmodel(alpha*precon_grad) - qmodel(0)
- """
- m = m_11
- c = _inner_product_list(grads, precon_grads)
-
- alpha = -c / m
- mu = 0.0
- qmodel_change = 0.5 * alpha * c
-
- return alpha, mu, qmodel_change
-
- return control_flow_ops.cond(
- math_ops.equal(m_22, 0.0), zero_prevupd_case, non_zero_prevupd_case)
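
A hedged NumPy illustration of the 2x2 solve above (the values are made up):

    import numpy as np

    m = np.array([[2.0, 0.5],    # [[m_11, m_21],
                  [0.5, 1.0]])   #  [m_21, m_22]]
    c = np.array([[-1.0],        # [[<grads, precon_grads>],
                  [0.2]])        #  [<grads, prev_updates>]]
    sol = -np.linalg.solve(m, c)           # [[alpha], [mu]]
    qmodel_change = 0.5 * np.sum(sol * c)  # predicted change in the loss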
-
- def _assign_q_model_change(self, q_model_change):
- """Assigns `q_model_change` to `self._q_model_change` if damping is adapted.
-
- Note only the chief worker does the assignment.
-
- Args:
- q_model_change: Scalar tensor of type `float32`.
-
- Returns:
-      If `adapt_damping` is `True`, an assign op; otherwise a no_op().
- """
- if self._adapt_damping and self._is_chief:
- q_model_assign_op = state_ops.assign(self._q_model_change, q_model_change)
- else:
- q_model_assign_op = control_flow_ops.no_op()
- return q_model_assign_op
-
- def _compute_qmodel_hyperparams_wrapper(self, grads_and_vars,
- precon_grads_and_vars):
- """Wrapper function for `self._compute_qmodel_hyperparams`.
-
-    Constructs a list of preconditioned gradients and variables. Also creates
-    an op to assign the computed q model change to `self._q_model_change`.
-
- Args:
- grads_and_vars: List of (gradient, variable) pairs.
- precon_grads_and_vars: List of (preconditioned gradients, variable)
- pairs.
-
- Returns:
- (alpha, mu, q_model_assign_op), where alpha and mu are chosen to optimize
- the quadratic model, `q_model_assign_op` assigns the computed q model
- change to `self._q_model_change`.
- """
- precon_grads = list(
- precon_grad for (precon_grad, _) in precon_grads_and_vars)
- grads = list(grad for (grad, _) in grads_and_vars)
- variables = list(var for (_, var) in grads_and_vars)
- prev_updates = self._compute_prev_updates(variables)
- # Compute optimal velocity update parameters according to quadratic model
- alpha, mu, q_model_change = self._compute_qmodel_hyperparams(
- precon_grads, prev_updates, grads, variables)
-
- return alpha, mu, self._assign_q_model_change(q_model_change)
-
- def _compute_update_steps(self, grads_and_vars):
- """Computes the update steps for the variables given the gradients.
-
- Args:
- grads_and_vars: List of (gradient, variable) pairs.
-
- Returns:
-      A list of tuples (assign_op, var) where `assign_op` assigns the update
- steps to `var`.
- """
-
- if self._momentum_type == "regular":
- # Compute "preconditioned" gradient.
- precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars)
-
- # Apply "KL clipping" if asked for.
- if self._norm_constraint is not None:
- precon_grads_and_vars = self._clip_updates(grads_and_vars,
- precon_grads_and_vars)
-
- # Update the velocity with this and return it as the step.
- if self._adapt_damping and self._is_chief:
- _, _, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper(
- grads_and_vars, precon_grads_and_vars)
- with ops.control_dependencies([q_model_assign_op]):
- return self._update_velocities(precon_grads_and_vars, self._momentum)
- else:
- return self._update_velocities(precon_grads_and_vars, self._momentum)
- elif self._momentum_type == "adam":
- # Update velocity.
- velocities_and_vars = self._update_velocities(grads_and_vars,
- self._momentum)
- # Return "preconditioned" velocity vector as the step.
- return self._fisher_est.multiply_inverse(velocities_and_vars)
-
- elif self._momentum_type == "qmodel":
- # Compute "preconditioned" gradient.
- precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars)
-
- # Compute optimal velocity update parameters according to quadratic model
- alpha, mu, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper(
- grads_and_vars, precon_grads_and_vars)
-
- with ops.control_dependencies([q_model_assign_op]):
- return self._update_velocities(
- precon_grads_and_vars, mu, vec_coeff=-alpha)
-
- def _update_velocities(self, vecs_and_vars, decay, vec_coeff=1.0):
- """Updates the velocities of the variables with the given vectors.
-
- Args:
- vecs_and_vars: List of (vector, variable) pairs.
- decay: How much to decay the old velocity by. This is often referred to
- as the 'momentum constant'.
- vec_coeff: Coefficient to apply to the vectors before adding them to the
- velocity.
-
- Returns:
- A list of (velocity, var) indicating the new velocity for each var.
- """
-
- def _update_velocity(vec, var):
- velocity = self._zeros_slot(var, "velocity", self._name)
- with ops.colocate_with(velocity):
- # NOTE(mattjj): read/modify/write race condition not suitable for async.
-
- # Compute the new velocity for this variable.
- new_velocity = decay * velocity + vec_coeff * vec
-
- # Save the updated velocity.
- return (array_ops.identity(velocity.assign(new_velocity)), var)
-
-    # Go through each variable and update its part of the velocity vector.
- return [_update_velocity(vec, var) for vec, var in vecs_and_vars]
-
- def _update_damping(self, prev_batch, global_step):
- """Adapts damping parameter. Check KFAC (Section 6.5) for the details.
-
- The damping parameter is updated according to the Levenberg-Marquardt rule
- every `self._damping_adaptation_interval` iterations.
-
- Args:
- prev_batch: Tensor or tuple of tensors which can be passed to
- `self._loss_fn` to evaluate loss.
- global_step: `Variable` which keeps track of number of times the training
- variables have been updated.
- Returns:
- A `tf.cond` op which updates the damping parameter.
- """
- def compute_damping():
- """"Adapts damping parameter based on "reduction ratio".
-
- Reduction ratio captures how closely the quadratic approximation to the
- loss function approximates the actual loss within a trust region. The
- damping update tries to make the damping as small as possible while
- maintaining the property that the quadratic model remains a good local
- approximation to the loss function.
-
- Returns:
- An Op to assign newly computed damping value to `self._damping`.
- """
- prev_batch_loss = self._loss_fn(prev_batch)
- with ops.control_dependencies([prev_batch_loss]):
- rho_assign = self._rho.assign(
- (prev_batch_loss - self._prev_loss) / self._q_model_change)
- with ops.control_dependencies([rho_assign]):
- new_damping = control_flow_ops.case(
- [(self._rho < 0.25, lambda: self.damping / self._omega),
- (self._rho > 0.75, lambda: self.damping * self._omega)],
- lambda: self.damping)
- with ops.control_dependencies([new_damping]):
- new_damping_min = math_ops.maximum(new_damping, self._min_damping)
- return control_flow_ops.group(self._damping.assign(new_damping_min))
-
- return control_flow_ops.cond(
- math_ops.equal(
- math_ops.mod(global_step + 1, self._damping_adaptation_interval),
- 0), compute_damping, control_flow_ops.no_op)
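
The same Levenberg-Marquardt rule, restated as a hedged pure-Python sketch
(hypothetical helper; recall omega = decay**interval < 1, so dividing by
omega increases the damping):

    def adapt_damping(damping, rho, omega, min_damping):
      if rho < 0.25:    # Quadratic model overestimates progress: more damping.
        damping /= omega
      elif rho > 0.75:  # Quadratic model is accurate: less damping.
        damping *= omega
      return max(damping, min_damping)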
-
-
-def _inner_product_list(list1, list2):
- return math_ops.add_n(
- [math_ops.reduce_sum(elt1 * elt2) for elt1, elt2 in zip(list1, list2)])
-
-
-def _two_by_two_solve(m, c):
- # it might be better just to crank out the exact formula for 2x2 inverses
- return math_ops.matmul(linalg_ops.matrix_inverse(m), c)
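
The exact-formula variant that the comment above alludes to might look like
this sketch (hypothetical name; `ops` and `math_ops` as imported by this
module):

    def _two_by_two_solve_direct(m, c):
      # inv([[a, b], [c, d]]) = [[d, -b], [-c, a]] / (a*d - b*c)
      det = m[0, 0] * m[1, 1] - m[0, 1] * m[1, 0]
      inv = ops.convert_to_tensor([[m[1, 1], -m[0, 1]],
                                   [-m[1, 0], m[0, 0]]]) / det
      return math_ops.matmul(inv, c)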
diff --git a/tensorflow/contrib/kfac/python/ops/placement.py b/tensorflow/contrib/kfac/python/ops/placement.py
deleted file mode 100644
index c4454325ae..0000000000
--- a/tensorflow/contrib/kfac/python/ops/placement.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Implements placement strategies for cov and inv ops, cov variables."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import itertools
-
-from tensorflow.python.framework import ops as tf_ops
-
-
-def _make_thunk_on_device(func, device):
- def thunk():
- with tf_ops.device(device):
- return func()
- return thunk
-
-
-class RoundRobinPlacementMixin(object):
- """Implements round robin placement strategy for ops and variables."""
-
- def __init__(self, cov_devices=None, inv_devices=None, **kwargs):
- """Initializes the RoundRobinPlacementMixin class.
-
- Args:
- cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
- computations will be placed on these devices in a round-robin fashion.
- Can be None, which means that no devices are specified.
- inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
- computations will be placed on these devices in a round-robin fashion.
- Can be None, which means that no devices are specified.
-      **kwargs: Passed on to the superclass constructor.
-
- """
- super(RoundRobinPlacementMixin, self).__init__(**kwargs)
- self._cov_devices = cov_devices
- self._inv_devices = inv_devices
-
- def make_vars_and_create_op_thunks(self, scope=None):
- """Make vars and create op thunks w/ a round-robin device placement start.
-
- For each factor, all of that factor's cov variables and their associated
- update ops will be placed on a particular device. A new device is chosen
- for each factor by cycling through list of devices in the
- `self._cov_devices` attribute. If `self._cov_devices` is `Non`e then no
- explicit device placement occurs.
-
- An analogous strategy is followed for inverse update ops, with the list of
- devices being given by the `self._inv_devices` attribute.
-
-    Inverse variables, on the other hand, are not placed on any specific
-    device (they will just use the current device placement context, whatever
-    that happens to be). The idea is that the inverse variables belong where
-    they will be accessed most often, which is the device that actually
-    applies the preconditioner to the gradient. The user is responsible for
-    setting the device context for this.
-
- Args:
- scope: A string or None. If None it will be set to the name of this
- estimator (given by the name property). All variables will be created,
- and all thunks will execute, inside of a variable scope of the given
- name. (Default: None)
-
- Returns:
- cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
- the list of factors given by the "factors" property.
- inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
- the list of factors given by the "factors" property.
- """
- # Note: `create_ops_and_vars_thunks` is implemented in `FisherEstimator`.
- (cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw,
- inv_update_thunks_raw) = self.create_ops_and_vars_thunks(scope=scope)
-
- if self._cov_devices:
- cov_update_thunks = []
- for cov_variable_thunk, cov_update_thunk, device in zip(
- cov_variable_thunks_raw, cov_update_thunks_raw,
- itertools.cycle(self._cov_devices)):
- with tf_ops.device(device):
- cov_variable_thunk()
- cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk,
- device))
- else:
- for cov_variable_thunk in cov_variable_thunks_raw:
- cov_variable_thunk()
- cov_update_thunks = cov_update_thunks_raw
-
- for inv_variable_thunk in inv_variable_thunks_raw:
- inv_variable_thunk()
-
- if self._inv_devices:
- inv_update_thunks = []
- for inv_update_thunk, device in zip(inv_update_thunks_raw,
- itertools.cycle(self._inv_devices)):
- inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk,
- device))
- else:
- inv_update_thunks = inv_update_thunks_raw
-
- return cov_update_thunks, inv_update_thunks
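
A hedged sketch of the cycling performed above (`update_thunks` is a
hypothetical stand-in for one of the raw thunk lists): with two devices and
three factors, the thunks are pinned to /gpu:0, /gpu:1, /gpu:0.

    import itertools

    devices = itertools.cycle(["/gpu:0", "/gpu:1"])
    pinned_thunks = [_make_thunk_on_device(thunk, next(devices))
                     for thunk in update_thunks]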
diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py
deleted file mode 100644
index 144295f4c7..0000000000
--- a/tensorflow/contrib/kfac/python/ops/utils.py
+++ /dev/null
@@ -1,709 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Utility functions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.tpu.python.ops import tpu_ops
-from tensorflow.contrib.tpu.python.tpu import tpu_function
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import variables
-
-# Method used for inverting matrices.
-POSDEF_INV_METHOD = "cholesky"
-POSDEF_EIG_METHOD = "self_adjoint"
-
-
-def set_global_constants(posdef_inv_method=None):
- """Sets various global constants used by the classes in this module."""
- global POSDEF_INV_METHOD
-
- if posdef_inv_method is not None:
- POSDEF_INV_METHOD = posdef_inv_method
-
-
-class SequenceDict(object):
- """A dict convenience wrapper that allows getting/setting with sequences."""
-
- def __init__(self, iterable=None):
- self._dict = dict(iterable or [])
-
- def __getitem__(self, key_or_keys):
- if isinstance(key_or_keys, (tuple, list)):
- return list(map(self.__getitem__, key_or_keys))
- else:
- return self._dict[key_or_keys]
-
- def __setitem__(self, key_or_keys, val_or_vals):
- if isinstance(key_or_keys, (tuple, list)):
- for key, value in zip(key_or_keys, val_or_vals):
- self[key] = value
- else:
- self._dict[key_or_keys] = val_or_vals
-
- def items(self):
- return list(self._dict.items())
-
-
-def tensors_to_column(tensors):
- """Converts a tensor or list of tensors to a column vector.
-
- Args:
- tensors: A tensor or list of tensors.
-
- Returns:
- The tensors reshaped into vectors and stacked on top of each other.
- """
- if isinstance(tensors, (tuple, list)):
- return array_ops.concat(
- tuple(array_ops.reshape(tensor, [-1, 1]) for tensor in tensors), axis=0)
- else:
- return array_ops.reshape(tensors, [-1, 1])
-
-
-def column_to_tensors(tensors_template, colvec):
- """Converts a column vector back to the shape of the given template.
-
- Args:
- tensors_template: A tensor or list of tensors.
- colvec: A 2d column vector with the same shape as the value of
- tensors_to_column(tensors_template).
-
- Returns:
-    X, where X is a tensor or list of tensors with the properties:
- 1) tensors_to_column(X) = colvec
- 2) X (or its elements) have the same shape as tensors_template (or its
- elements)
- """
- if isinstance(tensors_template, (tuple, list)):
- offset = 0
- tensors = []
- for tensor_template in tensors_template:
- sz = np.prod(tensor_template.shape.as_list(), dtype=np.int32)
- tensor = array_ops.reshape(colvec[offset:(offset + sz)],
- tensor_template.shape)
- tensors.append(tensor)
- offset += sz
-
- tensors = tuple(tensors)
- else:
- tensors = array_ops.reshape(colvec, tensors_template.shape)
-
- return tensors
-
-
-def kronecker_product(mat1, mat2):
- """Computes the Kronecker product two matrices."""
- m1, n1 = mat1.get_shape().as_list()
- mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1])
- m2, n2 = mat2.get_shape().as_list()
- mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2])
- return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])
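
For shape intuition, a hedged NumPy analogue: the Kronecker product of an
m1 x n1 and an m2 x n2 matrix is (m1*m2) x (n1*n2), with each entry of mat1
scaling a full copy of mat2.

    import numpy as np

    np.kron([[1, 2]], [[0, 5], [6, 7]])
    # -> [[ 0,  5,  0, 10],
    #     [ 6,  7, 12, 14]]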
-
-
-def layer_params_to_mat2d(vector):
- """Converts a vector shaped like layer parameters to a 2D matrix.
-
- In particular, we reshape the weights/filter component of the vector to be
- 2D, flattening all leading (input) dimensions. If there is a bias component,
- we concatenate it to the reshaped weights/filter component.
-
- Args:
- vector: A Tensor or pair of Tensors shaped like layer parameters.
-
- Returns:
- A 2D Tensor with the same coefficients and the same output dimension.
- """
- if isinstance(vector, (tuple, list)):
- w_part, b_part = vector
- w_part_reshaped = array_ops.reshape(w_part,
- [-1, w_part.shape.as_list()[-1]])
- return array_ops.concat(
- (w_part_reshaped, array_ops.reshape(b_part, [1, -1])), axis=0)
- elif isinstance(vector, ops.IndexedSlices):
- return vector
- else: # Tensor or Tensor-like.
- return array_ops.reshape(vector, [-1, vector.shape.as_list()[-1]])
-
-
-def mat2d_to_layer_params(vector_template, mat2d):
- """Converts a canonical 2D matrix representation back to a vector.
-
- Args:
- vector_template: A Tensor or pair of Tensors shaped like layer parameters.
- mat2d: A 2D Tensor with the same shape as the value of
- layer_params_to_mat2d(vector_template).
-
- Returns:
- A Tensor or pair of Tensors with the same coefficients as mat2d and the same
- shape as vector_template.
- """
- if isinstance(vector_template, (tuple, list)):
- w_part, b_part = mat2d[:-1], mat2d[-1]
- return array_ops.reshape(w_part, vector_template[0].shape), b_part
- elif isinstance(vector_template, ops.IndexedSlices):
- if not isinstance(mat2d, ops.IndexedSlices):
- raise TypeError(
- "If vector_template is an IndexedSlices, so should mat2d.")
- return mat2d
- else:
- return array_ops.reshape(mat2d, vector_template.shape)
-
-
-def posdef_inv(tensor, damping):
- """Computes the inverse of tensor + damping * identity."""
- identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)
- damping = math_ops.cast(damping, dtype=tensor.dtype)
- return posdef_inv_functions[POSDEF_INV_METHOD](tensor, identity, damping)
-
-
-def posdef_inv_matrix_inverse(tensor, identity, damping):
- """Computes inverse(tensor + damping * identity) directly."""
- return linalg_ops.matrix_inverse(tensor + damping * identity)
-
-
-def posdef_inv_cholesky(tensor, identity, damping):
- """Computes inverse(tensor + damping * identity) with Cholesky."""
- chol = linalg_ops.cholesky(tensor + damping * identity)
- return linalg_ops.cholesky_solve(chol, identity)
-
-
-def posdef_inv_eig(tensor, identity, damping):
- """Computes inverse(tensor + damping * identity) with eigendecomposition."""
- eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(
- tensor + damping * identity)
- return math_ops.matmul(
- eigenvectors / eigenvalues, eigenvectors, transpose_b=True)
-
-
-posdef_inv_functions = {
- "matrix_inverse": posdef_inv_matrix_inverse,
- "cholesky": posdef_inv_cholesky,
- "eig": posdef_inv_eig,
-}
-
-
-def posdef_eig(mat):
- """Computes the eigendecomposition of a positive semidefinite matrix."""
- return posdef_eig_functions[POSDEF_EIG_METHOD](mat)
-
-
-def posdef_eig_svd(mat):
- """Computes the singular values and left singular vectors of a matrix."""
- evals, evecs, _ = linalg_ops.svd(mat)
-
- return evals, evecs
-
-
-def posdef_eig_self_adjoint(mat):
- """Computes eigendecomposition using self_adjoint_eig."""
- evals, evecs = linalg_ops.self_adjoint_eig(mat)
- evals = math_ops.abs(evals) # Should be equivalent to svd approach.
-
- return evals, evecs
-
-
-posdef_eig_functions = {
- "self_adjoint": posdef_eig_self_adjoint,
- "svd": posdef_eig_svd,
-}
-
-
-def cholesky(tensor, damping):
- """Computes the inverse of tensor + damping * identity."""
- identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)
- damping = math_ops.cast(damping, dtype=tensor.dtype)
- return linalg_ops.cholesky(tensor + damping * identity)
-
-
-class SubGraph(object):
- """Defines a subgraph given by all the dependencies of a given set of outputs.
- """
-
- def __init__(self, outputs):
-    # Set of all Tensors and Ops that are ancestors of 'outputs'.
- self._members = set()
-
- self._iter_add(outputs)
-
- def _iter_add(self, root):
- """Iteratively adds all of nodes' ancestors using depth first search."""
- stack = [root]
- while stack:
- nodes = stack.pop()
- for node in nodes:
- if node in self._members:
- continue
- self._members.add(node)
-
- if isinstance(node, ops.Tensor):
- stack.append((node.op,))
- elif isinstance(node, ops.Operation):
- stack.append(node.inputs)
-
- def is_member(self, node):
- """Check if 'node' is in this subgraph."""
- return node in self._members
-
- def variable_uses(self, var):
- """Computes number of times a variable is used.
-
- Args:
- var: Variable or ResourceVariable instance.
-
- Returns:
- Number of times a variable is used within this subgraph.
-
- Raises:
- ValueError: If 'var' is not a variable type.
- """
- if isinstance(var, resource_variable_ops.ResourceVariable):
- var = var.handle
- elif isinstance(var, variables.Variable):
- var = var.value()
- else:
- raise ValueError("%s does not appear to be a variable." % str(var))
-
- return len(self._members.intersection(set(var.consumers())))
-
- def filter_list(self, node_list):
- """Filters 'node_list' to nodes in this subgraph."""
- filtered_list = []
- for node in node_list:
- if self.is_member(node):
- filtered_list.append(node)
- return filtered_list
-
-
-def generate_random_signs(shape, dtype=dtypes.float32):
- """Generate a random tensor with {-1, +1} entries."""
- ints = random_ops.random_uniform(shape, maxval=2, dtype=dtypes.int32)
- return 2 * math_ops.cast(ints, dtype=dtype) - 1
-
-
-def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None):
- """Compute forward-mode gradients."""
- # See b/37888268.
-
- # This version of forward-mode autodiff is based on code by Tim Cooijmans
- # and handles list arguments and certain special cases such as when the
- # ys doesn't depend on one or more of the xs, and when ops.IndexedSlices are
- # generated by the first gradients_impl.gradients call.
-
- us = [array_ops.zeros_like(y) + float("nan") for y in ys]
- dydxs = gradients_impl.gradients(
- ys, xs, grad_ys=us, stop_gradients=stop_gradients)
-
-  # Convert the strange types that gradients_impl.gradients returns but
-  # cannot itself accept as inputs (e.g. IndexedSlices, None).
- dydxs = [
- ops.convert_to_tensor(dydx)
- if isinstance(dydx, ops.IndexedSlices) else dydx for dydx in dydxs
- ]
- dydxs = [
- array_ops.zeros_like(x) if dydx is None else dydx
- for x, dydx in zip(xs, dydxs)
- ]
-
- dysdx = gradients_impl.gradients(dydxs, us, grad_ys=grad_xs)
-
- return dysdx
-
-
-def on_tpu():
- """Returns True when building a TPU computation."""
- return tpu_function.get_tpu_context().number_of_shards is not None
-
-
-def cross_replica_mean(tensor, name=None):
- """Takes mean value of a Tensor across all TPU cores.
-
- Args:
- tensor: Tensor to be synchronized.
- name: None or string. Name of Op.
-
- Returns:
- Average of Tensor across all TPU cores.
-
- Raises:
- ValueError: If called outside of TPU context.
- """
- with ops.name_scope(name, "cross_replica_mean", [tensor]):
- num_shards = tpu_function.get_tpu_context().number_of_shards
- if num_shards is None:
- raise ValueError(
- "Cannot take cross_replica_mean() outside of TPU Context.")
- if num_shards == 1:
- return tensor
- return tpu_ops.cross_replica_sum(tensor / num_shards)
-
-
-def ensure_sequence(obj):
- """If `obj` isn't a tuple or list, return a tuple containing `obj`."""
- if isinstance(obj, (tuple, list)):
- return obj
- else:
- return (obj,)
-
-
-def batch_execute(global_step, thunks, batch_size, name=None):
- """Executes a subset of ops per global step.
-
- Given a list of thunks, each of which produces a single stateful op,
- ensures that exactly 'batch_size' ops are run per global step. Ops are
- scheduled in a round-robin fashion. For example, with 3 ops
-
- global_step | op0 | op1 | op2
- ------------+-----+-----+-----
- 0 | x | x |
- ------------+-----+-----+-----
- 1 | x | | x
- ------------+-----+-----+-----
- 2 | | x | x
- ------------+-----+-----+-----
- 3 | x | x |
- ------------+-----+-----+-----
- 4 | x | | x
-
- Does not guarantee order of op execution within a single global step.
-
- Args:
- global_step: Tensor indicating time. Determines which ops run.
- thunks: List of thunks. Each thunk encapsulates one op. Return values are
- ignored.
- batch_size: int. Number of ops to execute per global_step.
- name: string or None. Name scope for newly added ops.
-
- Returns:
- List of ops. Exactly 'batch_size' ops are guaranteed to have an effect
- every global step.
- """
-
- def true_fn(thunk):
- """Ensures thunk is executed and returns an Op (not a Tensor)."""
-
- def result():
- with ops.control_dependencies([thunk()]):
- return control_flow_ops.no_op()
-
- return result
-
- def false_fn(_):
- """Executes a no-op."""
-
- def result():
- return control_flow_ops.no_op()
-
- return result
-
- with ops.name_scope(name, "batch_execute"):
- true_fns = [true_fn(thunk) for thunk in thunks]
- false_fns = [false_fn(thunk) for thunk in thunks]
- num_thunks = len(thunks)
- conditions = [
- math_ops.less(
- math_ops.mod(batch_size - 1 + global_step * batch_size - j,
- num_thunks), batch_size) for j in range(num_thunks)
- ]
- result = [
- control_flow_ops.cond(condition, true_fn, false_fn)
- for (condition, true_fn,
- false_fn) in zip(conditions, true_fns, false_fns)
- ]
- return result
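
A hedged pure-Python check of the schedule table above (hypothetical
helper):

    def ops_run_at(step, num_thunks=3, batch_size=2):
      return [j for j in range(num_thunks)
              if (batch_size - 1 + step * batch_size - j) % num_thunks
              < batch_size]

    # ops_run_at(0) == [0, 1]; ops_run_at(1) == [0, 2]; ops_run_at(2) == [1, 2]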
-
-
-def extract_convolution_patches(inputs,
- filter_shape,
- padding,
- strides=None,
- dilation_rate=None,
- name=None,
- data_format=None):
- """Extracts inputs to each output coordinate in tf.nn.convolution.
-
- This is a generalization of tf.extract_image_patches() to tf.nn.convolution(),
- where the number of spatial dimensions may be something other than 2.
-
-  Assumes:
-  - First dimension of inputs is batch_size.
-  - Convolution filter is applied to all input channels.
-
- Args:
- inputs: Tensor of shape [batch_size, ..spatial_image_shape..,
- ..spatial_filter_shape.., in_channels]. Inputs to tf.nn.convolution().
- filter_shape: List of ints. Shape of filter passed to tf.nn.convolution().
- padding: string. Padding method. One of "VALID", "SAME".
- strides: None or list of ints. Strides along spatial dimensions.
- dilation_rate: None or list of ints. Dilation along spatial dimensions.
- name: None or str. Name of Op.
- data_format: None or str. Format of data.
-
- Returns:
- Tensor of shape [batch_size, ..spatial_image_shape..,
- ..spatial_filter_shape.., in_channels]
-
- Raises:
- ValueError: If data_format does not put channel last.
- ValueError: If inputs and filter disagree on in_channels.
- """
- if not is_data_format_channel_last(data_format):
- raise ValueError("Channel must be last dimension.")
- with ops.name_scope(name, "extract_convolution_patches",
- [inputs, filter_shape, padding, strides, dilation_rate]):
- batch_size = inputs.shape.as_list()[0]
- in_channels = inputs.shape.as_list()[-1]
-
- # filter_shape = spatial_filter_shape + [in_channels, out_channels]
- spatial_filter_shape = filter_shape[:-2]
- if in_channels != filter_shape[-2]:
- raise ValueError("inputs and filter_shape must agree on in_channels.")
-
- # Map each input feature to a location in the output.
- out_channels = np.prod(spatial_filter_shape) * in_channels
- filters = linalg_ops.eye(out_channels)
- filters = array_ops.reshape(
- filters,
- list(spatial_filter_shape) + [in_channels, out_channels])
-
- result = nn_ops.convolution(
- inputs,
- filters,
- padding=padding,
- strides=strides,
- dilation_rate=dilation_rate)
- spatial_output_shape = result.shape.as_list()[1:-1]
- result = array_ops.reshape(result,
- [batch_size or -1] + spatial_output_shape +
- list(spatial_filter_shape) + [in_channels])
-
- return result
-
-
-def extract_pointwise_conv2d_patches(inputs,
- filter_shape,
- name=None,
- data_format=None):
- """Extract patches for a 1x1 conv2d.
-
- Args:
- inputs: 4-D Tensor of shape [batch_size, height, width, in_channels].
- filter_shape: List of 4 ints. Shape of filter to apply with conv2d()
- name: None or str. Name for Op.
- data_format: None or str. Format for data. See 'data_format' in
- tf.nn.conv2d() for details.
-
- Returns:
- Tensor of shape [batch_size, ..spatial_input_shape..,
- ..spatial_filter_shape.., in_channels]
-
- Raises:
- ValueError: if inputs is not 4-D.
- ValueError: if filter_shape is not [1, 1, ?, ?]
- ValueError: if data_format is not channels-last.
- """
- if inputs.shape.ndims != 4:
- raise ValueError("inputs must have 4 dims.")
- if len(filter_shape) != 4:
- raise ValueError("filter_shape must have 4 dims.")
- if filter_shape[0] != 1 or filter_shape[1] != 1:
- raise ValueError("filter_shape must have shape 1 along spatial dimensions.")
- if not is_data_format_channel_last(data_format):
- raise ValueError("data_format must be channels last.")
- with ops.name_scope(name, "extract_pointwise_conv2d_patches",
- [inputs, filter_shape]):
- ksizes = [1, 1, 1, 1] # Spatial shape is 1x1.
- strides = [1, 1, 1, 1] # Operate on all pixels.
- rates = [1, 1, 1, 1] # Dilation has no meaning with spatial shape = 1.
- padding = "VALID" # Doesn't matter.
- result = array_ops.extract_image_patches(inputs, ksizes, strides, rates,
- padding)
-
- batch_size, input_height, input_width, in_channels = inputs.shape.as_list()
- filter_height, filter_width, in_channels, _ = filter_shape
- return array_ops.reshape(result, [
- batch_size, input_height, input_width, filter_height, filter_width,
- in_channels
- ])
-
-
-def is_data_format_channel_last(data_format):
- """True if data_format puts channel last."""
- if data_format is None:
- return True
- return data_format.endswith("C")
-
-
-def matmul_sparse_dense(A, B, name=None, transpose_a=False, transpose_b=False): # pylint: disable=invalid-name
- """Computes matmul(A, B) where A is sparse, B is dense.
-
- Args:
- A: tf.IndexedSlices with dense shape [m, n].
- B: tf.Tensor with shape [n, k].
- name: str. Name of op.
- transpose_a: Bool. If true we transpose A before multiplying it by B.
- (Default: False)
- transpose_b: Bool. If true we transpose B before multiplying it by A.
- (Default: False)
-
- Returns:
- tf.IndexedSlices resulting from matmul(A, B).
-
- Raises:
- ValueError: If A doesn't represent a matrix.
- ValueError: If B is not rank-2.
- """
- with ops.name_scope(name, "matmul_sparse_dense", [A, B]):
- if A.indices.shape.ndims != 1 or A.values.shape.ndims != 2:
- raise ValueError("A must represent a matrix. Found: %s." % A)
- if B.shape.ndims != 2:
- raise ValueError("B must be a matrix.")
- new_values = math_ops.matmul(
- A.values, B, transpose_a=transpose_a, transpose_b=transpose_b)
- return ops.IndexedSlices(
- new_values,
- A.indices,
- dense_shape=array_ops.stack([A.dense_shape[0], new_values.shape[1]]))
-
-
-def matmul_diag_sparse(A_diag, B, name=None): # pylint: disable=invalid-name
- """Computes matmul(A, B) where A is a diagonal matrix, B is sparse.
-
- Args:
- A_diag: diagonal entries of matrix A of shape [m, m].
- B: tf.IndexedSlices. Represents matrix of shape [m, n].
- name: str. Name of op.
-
- Returns:
- tf.IndexedSlices resulting from matmul(A, B).
-
- Raises:
- ValueError: If A_diag is not rank-1.
- ValueError: If B doesn't represent a matrix.
- """
- with ops.name_scope(name, "matmul_diag_sparse", [A_diag, B]):
- A_diag = ops.convert_to_tensor(A_diag)
- if A_diag.shape.ndims != 1:
- raise ValueError("A_diag must be a rank-1 Tensor.")
- if B.indices.shape.ndims != 1 or B.values.shape.ndims != 2:
- raise ValueError("B must represent a matrix. Found: %s." % B)
- a = array_ops.gather(A_diag, B.indices)
- a = array_ops.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1))
- return ops.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape)
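
A hedged dense analogue of the sparse multiply above: each present row of B
is scaled by the matching diagonal entry of A.

    import numpy as np

    a_diag = np.array([2.0, 3.0, 4.0])     # diagonal of A (m = 3)
    b_values = np.array([[1.0, 1.0],       # the rows of B actually present
                         [5.0, 6.0]])
    b_indices = np.array([0, 2])           # which rows of B they are
    out_values = a_diag[b_indices][:, None] * b_values
    # rows scaled by 2.0 and 4.0 respectively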
-
-
-class PartitionedTensor(object):
- """A Tensor partitioned across its 0-th dimension."""
-
- def __init__(self, tensors):
- """Initializes PartitionedTensor.
-
- Args:
- tensors: List of Tensors. All Tensors must agree on shape (excepting
- batch dimension) and dtype.
-
- Raises:
- ValueError: If 'tensors' has length zero.
-      ValueError: If contents of 'tensors' don't agree on shape or dtype.
- """
- if not tensors:
- raise ValueError("tensors must be a list of 1+ Tensors.")
-
- dtype = tensors[0].dtype
- if not all(tensor.dtype == dtype for tensor in tensors):
- raise ValueError("all tensors must have dtype = %s." % dtype)
-
- shape = tensors[0].shape[1:]
- if not all(tensor.shape[1:] == shape for tensor in tensors):
- raise ValueError("All tensors must have shape = %s (excluding batch "
- "dimension)." % shape)
-
- self.tensors = tensors
- self._concats = {} # {device: Tensor}
-
- @property
- def shape(self):
- feature_shape = self.tensors[0].shape[1:]
- batch_size = sum([tensor.shape[0] for tensor in self.tensors],
- tensor_shape.Dimension(0))
- return tensor_shape.TensorShape([batch_size]).concatenate(feature_shape)
-
- def get_shape(self):
- return self.shape
-
- @property
- def dtype(self):
- return self.tensors[0].dtype
-
- def __str__(self):
- return "PartitionedTensor([%s, ...], dtype=%s, shape=%s)" % (
- self.tensors[0].name, self.dtype.name, tuple(self.shape.as_list()))
-
- def __hash__(self):
- return hash(tuple(self.tensors))
-
- def __eq__(self, other):
- if not isinstance(other, PartitionedTensor):
- return False
- return self.tensors == other.tensors
-
- def __ne__(self, other):
- return not self == other # pylint: disable=g-comparison-negation
-
- def __getitem__(self, key):
- return self.as_tensor()[key]
-
- def as_tensor(self, dtype=None, name=None, as_ref=False):
- with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors):
- assert not as_ref
- assert dtype in [None, self.dtype]
- result = array_ops.concat(self.tensors, axis=0)
-
- # Cache 'result' if we haven't already cached a value for this device.
- if result.device not in self._concats:
- self._concats[result.device] = result
- return self._concats[result.device]
-
- @property
- def device(self):
- # PartitionedTensors in general do not live on a single device. If the
- # device cannot be determined unambiguously this property will return None.
- device = self.tensors[0].device
- if all(tensor.device == device for tensor in self.tensors):
- return device
- return None
-
-
-ops.register_tensor_conversion_function(
- PartitionedTensor,
- lambda val, dtype, name, as_ref: val.as_tensor(dtype, name, as_ref))
-
-
-# TODO(b/69623235): Add a function for finding tensors that share gradients
-# to eliminate redundant fisher factor computations.
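
A hedged NumPy analogue of posdef_inv_cholesky above (illustrative values):

    import numpy as np

    t = np.array([[2.0, 0.3],
                  [0.3, 1.0]])
    damping = 1e-3
    chol = np.linalg.cholesky(t + damping * np.eye(2))
    # Two triangular solves reproduce cholesky_solve(chol, identity):
    inv = np.linalg.solve(chol.T, np.linalg.solve(chol, np.eye(2)))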
diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py
deleted file mode 100644
index 330d222dbf..0000000000
--- a/tensorflow/contrib/kfac/python/ops/utils_lib.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Utility functions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.utils import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- "set_global_constants",
- "SequenceDict",
- "tensors_to_column",
- "column_to_tensors",
- "kronecker_product",
- "layer_params_to_mat2d",
- "mat2d_to_layer_params",
- "posdef_inv",
- "posdef_inv_matrix_inverse",
- "posdef_inv_cholesky",
- "posdef_inv_funcs",
- "SubGraph",
- "generate_random_signs",
- "fwd_gradients",
- "ensure_sequence",
- "batch_execute",
- "extract_convolution_patches",
- "extract_pointwise_conv2d_patches",
- "is_data_format_channel_last",
- "matmul_sparse_dense",
- "matmul_diag_sparse",
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py b/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py
index 39e9d65407..9a402d888c 100644
--- a/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py
+++ b/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py
@@ -270,7 +270,7 @@ class ReshapeTest(Base):
array_ops.placeholder(dtypes.float32, [None]), ['x'])
reshape_lt = ops.reshape(orig_lt, ['x'], ['y', ('z', 1)])
self.assertEqual(reshape_lt.axes, core.Axes([('y', None), ('z', 1)]))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = sess.run(reshape_lt, feed_dict={orig_lt.tensor: [1, 2]})
np.testing.assert_array_equal(result, [[1], [2]])
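The test_session() -> cached_session() substitution above recurs throughout the rest of this diff. cached_session() returns a session that is created once and reused across calls within the same test method, so tests that open many `with` blocks do not pay for repeated session construction; test_session() was deprecated in its favor around this time. A minimal usage sketch, assuming a TF 1.x test environment:

    import tensorflow as tf

    class ConcatTest(tf.test.TestCase):

      def testConcat(self):
        # Both blocks share one underlying session; only evaluation runs
        # twice, not session construction.
        with self.cached_session() as sess:
          self.assertAllEqual(sess.run(tf.constant([1, 2])), [1, 2])
        with self.cached_session() as sess:
          self.assertAllEqual(sess.run(tf.constant([3])), [3])

    if __name__ == "__main__":
      tf.test.main()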
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/test_util.py b/tensorflow/contrib/labeled_tensor/python/ops/test_util.py
index 8f0416030f..900c9217c3 100644
--- a/tensorflow/contrib/labeled_tensor/python/ops/test_util.py
+++ b/tensorflow/contrib/labeled_tensor/python/ops/test_util.py
@@ -27,7 +27,7 @@ class Base(test.TestCase):
"""A class with some useful methods for testing."""
def eval(self, tensors):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD
index 7355a403ae..b4fe8cac74 100644
--- a/tensorflow/contrib/layers/BUILD
+++ b/tensorflow/contrib/layers/BUILD
@@ -185,7 +185,7 @@ py_test(
py_test(
name = "normalization_test",
- size = "small",
+ size = "medium",
srcs = ["python/layers/normalization_test.py"],
srcs_version = "PY2AND3",
tags = ["no_windows"], # TODO: needs investigation on Windows
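For context on the size bump above: in Bazel, a test's size attribute sets its resource class and default timeout (small = 60s, medium = 300s, large = 900s, enormous = 3600s), so moving normalization_test from small to medium gives it a 300-second budget instead of timing out on slower machines.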
diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py
index a7b41b714f..af8e673f59 100644
--- a/tensorflow/contrib/layers/__init__.py
+++ b/tensorflow/contrib/layers/__init__.py
@@ -14,7 +14,9 @@
# ==============================================================================
"""Ops for building neural network layers, regularizers, summaries, etc.
-See the @{$python/contrib.layers} guide.
+See the
+[Contrib Layers](https://tensorflow.org/api_guides/python/contrib.layers)
+guide.
@@avg_pool2d
@@avg_pool3d
diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
index 7ede193029..124515e5a6 100644
--- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
@@ -109,7 +109,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
return sparse_ids, sparse_weights
def test_safe_embedding_lookup_sparse_return_zero_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_2d()
@@ -122,7 +122,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
3.0, [0] * 4, [0] * 4, embedding_weights[0][2], [0] * 4])
def test_safe_embedding_lookup_sparse_return_special_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_2d()
@@ -136,7 +136,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_weights[0][2], embedding_weights[0][3]])
def test_safe_embedding_lookup_sparse_no_weights(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, _ = self._ids_and_weights_2d()
@@ -150,7 +150,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_weights[0][0] + embedding_weights[0][1]) / 2.0])
def test_safe_embedding_lookup_sparse_partitioned(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, _ = self._ids_and_weights_2d()
@@ -164,7 +164,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
(embedding_weights[0] + embedding_weights[1]) / 2.0])
def test_safe_embedding_lookup_sparse_partitioned_inconsistent_weights(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, sparse_weights = self._ids_and_weights_2d()
@@ -179,7 +179,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_weights, sparse_ids, sparse_weights)
def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_3d()
@@ -192,7 +192,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
], [embedding_weights[0][2], [0] * 4, [0] * 4]])
def test_safe_embedding_lookup_sparse_3d_return_special_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_3d()
@@ -208,7 +208,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
]])
def test_safe_embedding_lookup_sparse_3d_no_weights(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, _ = self._ids_and_weights_3d()
@@ -224,7 +224,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
]])
def test_safe_embedding_lookup_sparse_3d_partitioned(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, _ = self._ids_and_weights_3d()
@@ -241,7 +241,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
def test_safe_embedding_lookup_sparse_3d_partitioned_inconsistent_weights(
self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, sparse_weights = self._ids_and_weights_3d()
@@ -276,7 +276,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
return embedding_weights
def test_scattered_embedding_consistency(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
values = constant_op.constant(["foo", "foo"])
@@ -288,7 +288,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
embedding_lookup_result[1])
def test_scattered_embedding_multiple_partition(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=7)
values = constant_op.constant([4, 4, 5])
@@ -304,7 +304,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
self.assertGreater(embedding_diff, 0)
def test_scattered_embedding_coverage(self):
- with self.test_session():
+ with self.cached_session():
size = 8
embedding_weights = self._random_weights(size=size, num_shards=3)
values = constant_op.constant(["foo"])
@@ -316,7 +316,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
self.assertEqual(len(np.unique(embedding_lookup_result[0])), size)
def test_scattered_embedding_multi_dimension(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
values = constant_op.constant([["foo", "bar", "bar"],
["bar", "bar", "foo"]])
@@ -329,7 +329,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
embedding_lookup_result[1][2])
def test_scattered_embedding_lookup_sparse(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_tensor = sparse_tensor_lib.SparseTensor(
values=["foo", "bar", "foo", "bar"],
@@ -358,7 +358,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
embeds = np.random.randn(n_embed, d_embed)
idx = np.random.randint(0, n_embed, idx_shape)
- with self.test_session():
+ with self.cached_session():
embedded_np = embeds[idx]
embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval()
@@ -370,7 +370,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
idx = np.random.randint(0, 5, 10)
idx2d = np.random.randint(0, 5, (10, 2))
- with self.test_session():
+ with self.cached_session():
embedded_np = embeds[idx]
embedded_np2d = embeds[idx2d]
embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval()
@@ -408,7 +408,7 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase):
return embedding_weights
def test_hashed_embedding_consistency(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
values = constant_op.constant(["foo", "foo"])
# The first three sampled_candidates are equal, so the first three
@@ -429,7 +429,7 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase):
embedding_lookup_result[1][3])
def test_hashed_embedding_multi_dimension(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
values = constant_op.constant([["foo", "bar", "bar"],
["bar", "bar", "foo"]])
@@ -467,7 +467,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
def test_output_shape(self):
"""Verifies the shape of the output tensor."""
- with self.test_session():
+ with self.cached_session():
sp_values = sparse_tensor_lib.SparseTensor(
values=["a", "a", "b", "c", "d", "e", "f"],
indices=[[1, 0], [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5]],
@@ -481,7 +481,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
def test_output_values(self):
"""Verifies the values in a trivial case."""
- with self.test_session():
+ with self.cached_session():
sp_values = sparse_tensor_lib.SparseTensor(
values=["a"], indices=[[1, 0]], dense_shape=[3, 1])
params = constant_op.constant([.1, .2, .3])
@@ -495,7 +495,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
def test_output_values_with_sampled_candidates(self):
"""Verifies the values for given sampled_candidates."""
- with self.test_session():
+ with self.cached_session():
sp_values = sparse_tensor_lib.SparseTensor(
values=["a", "a", "b", "c", "d", "e", "f"],
indices=[[1, 0], [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5]],
@@ -520,7 +520,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
def test_output_values_with_sign_hash(self):
"""Verifies the values in a trivial case with hash_signs=True."""
- with self.test_session():
+ with self.cached_session():
sp_values = sparse_tensor_lib.SparseTensor(
values=["a"], indices=[[1, 0]], dense_shape=[3, 1])
params = constant_op.constant([.1, .1, .1])
@@ -537,7 +537,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
def test_distributive_property(self):
"""Verifies the distributive property of matrix multiplication."""
- with self.test_session():
+ with self.cached_session():
params = constant_op.constant([.1, .2, .3])
sp_values_a = sparse_tensor_lib.SparseTensor(
values=["a"], indices=[[0, 0]], dense_shape=[3, 1])
@@ -710,7 +710,7 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):
[1, 5], ["sum", "mean", "sqrtn"], [dtypes.float32,
dtypes.float64], [True, False]):
- with self.test_session():
+ with self.cached_session():
p, params, feed_dict = _EmbeddingParams(
num_shards, vocab_size, shape=param_shape, dtype=dtype)
embedding_sum = \
@@ -749,7 +749,7 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):
for num_shards, combiner, dtype, ignore_weights in itertools.product(
[1, 3], ["sum", "mean", "sqrtn"], [dtypes.float32,
dtypes.float64], [True, False]):
- with self.test_session():
+ with self.cached_session():
x, params, _ = _EmbeddingParams(
num_shards, vocab_size, shape=param_shape, dtype=dtype)
@@ -767,7 +767,7 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):
self.assertLess(err, 1e-5 if dtype == dtypes.float64 else 2e-3)
def testIncompatibleShapes(self):
- with self.test_session():
+ with self.cached_session():
x, _, _ = _EmbeddingParams(1, 10, dtype=dtypes.float32)
sp_ids = sparse_tensor_lib.SparseTensor(
constant_op.constant([[0, 0], [0, 1], [1, 0]], dtypes.int64),
diff --git a/tensorflow/contrib/layers/python/layers/encoders_test.py b/tensorflow/contrib/layers/python/layers/encoders_test.py
index e8528e9890..1a2aa710d5 100644
--- a/tensorflow/contrib/layers/python/layers/encoders_test.py
+++ b/tensorflow/contrib/layers/python/layers/encoders_test.py
@@ -34,14 +34,14 @@ def _get_const_var(name, shape, value):
class EncodersTest(test.TestCase):
def testBowEncoderSparse(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[0, 1], [2, 3]]
enc = encoders.bow_encoder(docs, 4, 3)
sess.run(variables.global_variables_initializer())
self.assertAllEqual([2, 3], enc.eval().shape)
def testBowEncoderSparseTensor(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[0, 1], [2, 3]]
sparse_docs = sparse_ops.dense_to_sparse_tensor(docs)
enc = encoders.bow_encoder(sparse_docs, 4, 3)
@@ -49,28 +49,28 @@ class EncodersTest(test.TestCase):
self.assertAllEqual([2, 3], enc.eval().shape)
def testBowEncoderSparseEmptyRow(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[0, 1], [2, 3], [0, 0]]
enc = encoders.bow_encoder(docs, 4, 5)
sess.run(variables.global_variables_initializer())
self.assertAllEqual([3, 5], enc.eval().shape)
def testBowEncoderDense(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[0, 1], [2, 3], [0, 0], [0, 0]]
enc = encoders.bow_encoder(docs, 4, 3, sparse_lookup=False)
sess.run(variables.global_variables_initializer())
self.assertAllEqual([4, 3], enc.eval().shape)
def testBowEncoderSparseTensorDenseLookup(self):
- with self.test_session():
+ with self.cached_session():
docs = [[0, 1]]
sparse_docs = sparse_ops.dense_to_sparse_tensor(docs)
with self.assertRaises(TypeError):
encoders.bow_encoder(sparse_docs, 4, 3, sparse_lookup=False)
def testBowEncodersSharingEmbeddings(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[0, 1], [2, 3]]
enc_1 = encoders.bow_encoder(docs, 4, 3, scope='test')
enc_2 = encoders.bow_encoder(docs, 4, 3, scope='test', reuse=True)
@@ -79,7 +79,7 @@ class EncodersTest(test.TestCase):
self.assertAllEqual(avg_1, avg_2)
def testBowEncodersSharingEmbeddingsInheritedScopes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[0, 1], [2, 3]]
with variable_scope.variable_scope('test'):
enc_1 = encoders.bow_encoder(docs, 4, 3)
@@ -90,7 +90,7 @@ class EncodersTest(test.TestCase):
self.assertAllEqual(avg_1, avg_2)
def testBowEncodersSharingEmbeddingsSharedScope(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[0, 1], [2, 3]]
enc_1 = encoders.bow_encoder(docs, 4, 3, scope='bow')
variable_scope.get_variable_scope().reuse_variables()
@@ -100,7 +100,7 @@ class EncodersTest(test.TestCase):
self.assertAllEqual(avg_1, avg_2)
def testBowEncoderReuseEmbeddingsVariable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[1, 1], [2, 3]]
with variable_scope.variable_scope('test'):
v = _get_const_var('embeddings', (4, 3),
@@ -111,7 +111,7 @@ class EncodersTest(test.TestCase):
self.assertAllClose([[3., 4., 5.], [7.5, 8.5, 9.5]], enc.eval())
def testEmbedSequence(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[1, 1], [2, 3]]
with variable_scope.variable_scope('test'):
v = _get_const_var('embeddings', (4, 3),
diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py
index 3ae07cedab..53c8ae5d08 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column.py
@@ -997,9 +997,14 @@ class _OneHotColumn(
# Remove (?, -1) index
weighted_column = sparse_ops.sparse_slice(
weighted_column,
- [0, 0],
+ array_ops.zeros_like(weighted_column.dense_shape),
weighted_column.dense_shape)
- return sparse_ops.sparse_tensor_to_dense(weighted_column)
+ dense_tensor = sparse_ops.sparse_tensor_to_dense(weighted_column)
+ batch_shape = array_ops.shape(dense_tensor)[:-1]
+ dense_tensor_shape = array_ops.concat(
+ [batch_shape, [self.length]], axis=0)
+ dense_tensor = array_ops.reshape(dense_tensor, dense_tensor_shape)
+ return dense_tensor
dense_id_tensor = sparse_ops.sparse_tensor_to_dense(sparse_id_column,
default_value=-1)
@@ -1095,9 +1100,9 @@ class _EmbeddingColumn(
raise ValueError("Must specify both `ckpt_to_load_from` and "
"`tensor_name_in_ckpt` or none of them.")
if initializer is None:
- logging.warn("The default stddev value of initializer will change from "
- "\"1/sqrt(vocab_size)\" to \"1/sqrt(dimension)\" after "
- "2017/02/25.")
+ logging.warn("The default stddev value of initializer was changed from "
+                   "\"1/sqrt(vocab_size)\" to \"1/sqrt(dimension)\" in the "
+                   "core implementation (tf.feature_column.embedding_column).")
stddev = 1 / math.sqrt(sparse_id_column.length)
initializer = init_ops.truncated_normal_initializer(
mean=0.0, stddev=stddev)
@@ -1496,8 +1501,6 @@ class _ScatteredEmbeddingColumn(
raise ValueError("initializer must be callable if specified. "
"column_name: {}".format(column_name))
if initializer is None:
- logging.warn("The default stddev value of initializer will change from "
- "\"0.1\" to \"1/sqrt(dimension)\" after 2017/02/25.")
stddev = 0.1
initializer = init_ops.truncated_normal_initializer(
mean=0.0, stddev=stddev)
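The _OneHotColumn change above does two things: it slices the weighted column from an all-zeros start index (so the start's rank matches the tensor's rank instead of being hard-coded to 2), and it reshapes the densified result so the trailing dimension is exactly self.length. A minimal sketch of that trailing-dimension fix, assuming the element count already matches at runtime (fix_last_dim is a hypothetical helper, not a TF API):

    import tensorflow as tf

    def fix_last_dim(dense_tensor, length):
      # Keep every leading (batch) dimension and pin the trailing dimension
      # to `length`; valid when the element count already agrees at runtime.
      batch_shape = tf.shape(dense_tensor)[:-1]
      new_shape = tf.concat([batch_shape, [length]], axis=0)
      return tf.reshape(dense_tensor, new_shape)

    x = tf.zeros([3, 4])    # e.g. a densified weighted column
    y = fix_last_dim(x, 4)  # shape [3, 4], last dimension now explicit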
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
index e6bbd86ab7..6fb4b9ff35 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
@@ -49,7 +49,7 @@ class TransformerTest(test.TestCase):
real_valued = feature_column.real_valued_column("price")
features = {"price": constant_op.constant([[20.], [110], [-3]])}
output = feature_column_ops._Transformer(features).transform(real_valued)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(output.eval(), [[20.], [110], [-3]])
def testSparseRealValuedColumnIdentityTransformation(self):
@@ -60,7 +60,7 @@ class TransformerTest(test.TestCase):
features = {"rating": rating_tensor}
output = feature_column_ops._Transformer(features).transform(
sparse_real_valued)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(output.values.eval(), rating_tensor.values.eval())
self.assertAllEqual(output.indices.eval(), rating_tensor.indices.eval())
self.assertAllEqual(output.dense_shape.eval(),
@@ -80,7 +80,7 @@ class TransformerTest(test.TestCase):
[sparse_real_valued])
self.assertTrue(sparse_real_valued in output_dict)
output = output_dict[sparse_real_valued]
- with self.test_session():
+ with self.cached_session():
self.assertArrayNear(output.values.eval(), [4.0, 25.0], 1e-5)
self.assertAllEqual(output.indices.eval(), rating_tensor.indices.eval())
self.assertAllEqual(output.dense_shape.eval(),
@@ -97,7 +97,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[bucket])
self.assertEqual(len(output), 1)
self.assertIn(bucket, output)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(output[bucket].eval(), [[2], [3], [0]])
def testBucketizedColumnWithMultiDimensions(self):
@@ -109,7 +109,7 @@ class TransformerTest(test.TestCase):
"price": constant_op.constant([[20., 110], [110., 20], [-3, -3]])
}
output = feature_column_ops._Transformer(features).transform(bucket)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(output.eval(), [[2, 3], [3, 2], [0, 0]])
def testCachedTransformation(self):
@@ -118,7 +118,7 @@ class TransformerTest(test.TestCase):
# buckets 2, 3, 0
features = {"price": constant_op.constant([[20.], [110], [-3]])}
transformer = feature_column_ops._Transformer(features)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
transformer.transform(bucket)
num_of_ops = len(sess.graph.get_operations())
# Verify that the second call to transform the same feature
@@ -138,7 +138,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[hashed_sparse])
self.assertEqual(len(output), 1)
self.assertIn(hashed_sparse, output)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(output[hashed_sparse].values.dtype, dtypes.int64)
self.assertTrue(
all(x < 10 and x >= 0 for x in output[hashed_sparse].values.eval()))
@@ -161,7 +161,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[hashed_sparse])
self.assertEqual(len(output), 1)
self.assertIn(hashed_sparse, output)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(output[hashed_sparse].values.dtype, dtypes.int64)
self.assertTrue(
all(x < 10 and x >= 0 for x in output[hashed_sparse].values.eval()))
@@ -177,7 +177,7 @@ class TransformerTest(test.TestCase):
features = {"wire": wire_tensor}
output = feature_column_ops._Transformer(features).transform(hashed_sparse)
- with self.test_session():
+ with self.cached_session():
# While the input is a dense Tensor, the output should be a SparseTensor.
self.assertIsInstance(output, sparse_tensor.SparseTensor)
self.assertEqual(output.values.dtype, dtypes.int64)
@@ -203,7 +203,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 2)
self.assertIn(hashed_sparse, output)
self.assertIn(wire_embedding, output)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(output[wire_embedding].indices.eval(),
wire_tensor.indices.eval())
self.assertAllEqual(output[wire_embedding].dense_shape.eval(), [2, 2])
@@ -223,7 +223,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[keys_sparse])
self.assertEqual(len(output), 1)
self.assertIn(keys_sparse, output)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertEqual(output[keys_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[keys_sparse].values.eval(), [1, 2, 0])
@@ -241,7 +241,7 @@ class TransformerTest(test.TestCase):
features = {"wire": wire_tensor}
output = feature_column_ops._Transformer(features).transform(keys_sparse)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
# While the input is a dense Tensor, the output should be a SparseTensor.
self.assertIsInstance(output, sparse_tensor.SparseTensor)
@@ -264,7 +264,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[hashed_sparse])
self.assertEqual(len(output), 1)
self.assertIn(hashed_sparse, output)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(output[hashed_sparse].values.dtype, dtypes.int32)
self.assertTrue(
all(x < 10 and x >= 0 for x in output[hashed_sparse].values.eval()))
@@ -282,7 +282,7 @@ class TransformerTest(test.TestCase):
wire_tensor = constant_op.constant([[100, 0], [1, 25]])
features = {"wire": wire_tensor}
output = feature_column_ops._Transformer(features).transform(hashed_sparse)
- with self.test_session():
+ with self.cached_session():
# While the input is a dense Tensor, the output should be a SparseTensor.
self.assertIsInstance(output, sparse_tensor.SparseTensor)
self.assertEqual(output.values.dtype, dtypes.int32)
@@ -310,7 +310,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1)
self.assertIn(weighted_ids, output)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertAllEqual(output[weighted_ids][0].dense_shape.eval(),
ids_tensor.dense_shape.eval())
@@ -340,7 +340,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[vocab_sparse])
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0])
@@ -362,7 +362,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[vocab_sparse])
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1])
@@ -386,7 +386,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[vocab_sparse])
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0])
@@ -408,7 +408,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[vocab_sparse])
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1])
@@ -440,7 +440,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[country_language])
self.assertEqual(len(output), 1)
self.assertIn(country_language, output)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(output[country_language].values.dtype, dtypes.int64)
self.assertTrue(
all(x < 15 and x >= 0 for x in output[country_language].values.eval(
@@ -467,7 +467,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[country_price])
self.assertEqual(len(output), 1)
self.assertIn(country_price, output)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(output[country_price].values.dtype, dtypes.int64)
self.assertTrue(
all(x < 15 and x >= 0 for x in output[country_price].values.eval()))
@@ -498,7 +498,7 @@ class TransformerTest(test.TestCase):
weights = column_to_variable[country_price][0]
grad = array_ops.squeeze(
gradients_impl.gradients(output, weights)[0].values)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertEqual(len(grad.eval()), 6)
@@ -537,7 +537,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[wire_country_price])
self.assertEqual(len(output), 1)
self.assertIn(wire_country_price, output)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(output[wire_country_price].values.dtype, dtypes.int64)
self.assertTrue(
all(x < 15 and x >= 0 for x in output[wire_country_price].values.eval(
@@ -600,7 +600,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
columns = [one_hot_column, embedding_column, real_valued_column]
output = feature_column_ops.input_from_feature_columns(features, columns)
output_core = fc_core.input_layer(features, columns)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [3, 2 + 4 + 10])
@@ -626,7 +626,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
cols_to_outs = {}
feature_column_ops.input_from_feature_columns(
features, columns, cols_to_outs=cols_to_outs)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
for column in columns:
@@ -637,7 +637,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
features = {"price": constant_op.constant([[20.], [110], [-3]])}
output = feature_column_ops.input_from_feature_columns(features,
[real_valued])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(output.eval(), features["price"].eval())
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllClose(output.eval(),
@@ -650,7 +650,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
}
output = feature_column_ops.input_from_feature_columns(features,
[real_valued])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(output.eval(), features["price"].eval())
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllClose(output.eval(),
@@ -662,7 +662,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
rating = np.array([[0., 1., 2., -1.],
[3., 4., 5., 6.]])
features = {"rating": constant_op.constant(rating)}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output = sess.run(feature_column_ops.input_from_feature_columns(
features, [var_len_real_valued]))
self.assertAllClose(rating, output)
@@ -673,7 +673,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
rating = np.array([[0, 1, 2, -1],
[3, 4, 5, 6]])
features = {"rating": constant_op.constant(rating, dtype=dtypes.int64)}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output = sess.run(feature_column_ops.input_from_feature_columns(
features, [var_len_real_valued]))
self.assertAllClose(rating.astype(np.float32), output)
@@ -684,7 +684,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
features = {"price": constant_op.constant([[20.], [110], [-3]])}
output = feature_column_ops.input_from_feature_columns(features,
[real_valued])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(output.eval(), features["price"].eval() - 2)
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllClose(output.eval(),
@@ -698,7 +698,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
}
output = feature_column_ops.input_from_feature_columns(features,
[real_valued])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(output.eval(), features["price"].eval() - 2)
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllClose(output.eval(),
@@ -713,7 +713,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
features = {"price": constant_op.constant([[20.], [110], [-3]])}
output = feature_column_ops.input_from_feature_columns(features, [bucket])
expected = [[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0]]
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(output.eval(), expected)
self.assertAllClose(output.eval(),
fc_core.input_layer(features, [bucket]).eval())
@@ -729,7 +729,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
output = feature_column_ops.input_from_feature_columns(features, [bucket])
expected = [[0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 0, 1, 0, 0, 1, 0],
[1, 0, 0, 0, 1, 0, 0, 0]]
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(output.eval(), expected)
self.assertAllClose(output.eval(),
fc_core.input_layer(features, [bucket]).eval())
@@ -752,7 +752,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
output = feature_column_ops.input_from_feature_columns(features,
[one_hot_column])
output_core = fc_core.input_layer(features, [one_hot_column])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]],
@@ -773,7 +773,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[one_hot_sparse])
output_core = fc_core.input_layer(features, [one_hot_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]],
@@ -794,7 +794,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[one_hot_sparse])
output_core = fc_core.input_layer(features, [one_hot_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
@@ -816,7 +816,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
output = feature_column_ops.input_from_feature_columns(features,
[one_hot_sparse])
output_core = fc_core.input_layer(features, [one_hot_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
output.eval())
@@ -834,7 +834,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
output = feature_column_ops.input_from_feature_columns(features,
[one_hot_sparse])
output_core = fc_core.input_layer(features, [one_hot_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual([3, 10], output.eval().shape)
@@ -852,7 +852,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
output = feature_column_ops.input_from_feature_columns(features,
[embeded_sparse])
output_core = fc_core.input_layer(features, [embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(output.eval().shape, [4, 10])
# Verify cross compatibility: Core builder output should equal to contrib.
@@ -878,7 +878,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
features, [embedded_sparse], weight_collections=["my_collection_core"])
weights_core = ops.get_collection("my_collection_core")
grad_core = gradients_impl.gradients(output_core, weights_core)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
gradient_values = []
gradient_values_core = []
@@ -907,7 +907,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[embeded_sparse])
output_core = fc_core.input_layer(features, [embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
output_eval = output.eval()
self.assertAllEqual(output_eval.shape, [2, 10])
@@ -935,7 +935,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
# Makes sure that trying to use different initializers with the same
# embedding column explicitly fails.
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
ValueError,
"Duplicate feature column key found for column: wire_embedding"):
@@ -961,7 +961,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[embeded_sparse])
output_core = fc_core.input_layer(features, [embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [2, 10])
@@ -986,7 +986,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
embeded_sparse = feature_column.embedding_column(weighted_ids, 10)
output = feature_column_ops.input_from_feature_columns(features,
[embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [2, 10])
@@ -1005,7 +1005,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
embeded_sparse = feature_column.embedding_column(crossed, 10)
output = feature_column_ops.input_from_feature_columns(features,
[embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(output.eval().shape, [2, 10])
@@ -1016,7 +1016,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {"wire": wire_tensor}
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
ValueError, "Error creating input layer for column: wire"):
variables_lib.global_variables_initializer().run()
@@ -1035,7 +1035,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {"ids": ids_tensor, "weights": weights_tensor}
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
ValueError,
"Error creating input layer for column: ids_weighted_by_weights"):
@@ -1053,7 +1053,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {"aaa": wire_tensor, "bbb": wire_tensor}
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
ValueError, "Error creating input layer for column: aaa_X_bbb"):
variables_lib.global_variables_initializer().run()
@@ -1080,7 +1080,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
hashed_sparse, 10, initializer=init_ops.constant_initializer(133.7))
output = feature_column_ops.input_from_feature_columns(
features, [real_valued, bucket, embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
# size of output = 3 (real_valued) + 2 * 4 (bucket) + 10 (embedding) = 21
self.assertAllEqual(output.eval().shape, [3, 21])
@@ -1099,7 +1099,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
initializer=init_ops.ones_initializer())
output = feature_column_ops.input_from_feature_columns(features,
[embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
# score: (number of values)
self.assertAllEqual(output.eval(), [[1.], [2.], [0.]])
@@ -1119,7 +1119,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
max_norm=0.5)
output = feature_column_ops.input_from_feature_columns(features,
[embedded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
# score: (number of values * 0.5)
self.assertAllClose(output.eval(), [[0.5], [1.], [0.]])
@@ -1144,7 +1144,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
initializer=init_ops.ones_initializer())
output = feature_column_ops.input_from_feature_columns(features,
[embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
# score: (sum of weights)
@@ -1236,7 +1236,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
# There should be one trainable variables for sparse_2
self.assertEqual(1, len(variables_lib.trainable_variables()))
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
output_1_eval = output_1.eval()
output_2_eval = output_2.eval()
@@ -1295,7 +1295,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
columns_to_tensors, [measurement_column])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_inputs = sess.run(model_input_tensor)
self.assertAllClose(measurement_input, model_inputs)
@@ -1305,7 +1305,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
rating = np.array([[0., 1., 2., -1.],
[3., 4., 5., 6.]])
features = {"rating": constant_op.constant(rating)}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output = sess.run(
feature_column_ops.sequence_input_from_feature_columns(
features, [var_len_real_valued]))
@@ -1329,7 +1329,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
expected_shape = [batch_size, sequence_length, np.prod(dimensions)]
reshaped_measurements = np.reshape(measurement_input, expected_shape)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_inputs = sess.run(model_input_tensor)
self.assertAllClose(reshaped_measurements, model_inputs)
@@ -1350,7 +1350,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
columns_to_tensors, [measurement_column])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_inputs = sess.run(model_input_tensor)
self.assertAllClose(normalizer(measurement_input), model_inputs)
@@ -1373,7 +1373,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
expected_shape = [batch_size, sequence_length, np.prod(dimensions)]
reshaped_measurements = np.reshape(measurement_input, expected_shape)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_inputs = sess.run(model_input_tensor)
self.assertAllClose(normalizer(reshaped_measurements), model_inputs)
@@ -1395,7 +1395,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
columns_to_tensors, [one_hot_column])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
@@ -1429,7 +1429,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
columns_to_tensors, [one_hot_column])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
@@ -1459,7 +1459,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
columns_to_tensors, [embedded_column])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
@@ -1488,7 +1488,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
columns_to_tensors, [embedded_column])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
@@ -1518,7 +1518,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
embedding_weights = ops.get_collection("my_collection")
gradient_tensor = gradients_impl.gradients(model_input_tensor,
embedding_weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
model_input, gradients = sess.run([model_input_tensor, gradient_tensor])
@@ -1585,7 +1585,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
columns_to_tensors, model_input_columns)
self.assertEqual(dtypes.float32, model_input_tensor.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
@@ -1622,7 +1622,7 @@ class WeightedSumTest(test.TestCase):
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [hashed_sparse], num_outputs=5)
logits_core = fc_core.linear_model(features, [hashed_sparse], units=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
# Verify cross compatibility: Core builder output should equal to contrib.
@@ -1640,7 +1640,7 @@ class WeightedSumTest(test.TestCase):
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [hashed_sparse], num_outputs=5)
logits_core = fc_core.linear_model(features, [hashed_sparse], units=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
# Verify cross compatibility: Core builder output should equal to contrib.
@@ -1654,7 +1654,7 @@ class WeightedSumTest(test.TestCase):
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [hashed_sparse], num_outputs=5)
logits_core = fc_core.linear_model(features, [hashed_sparse], units=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
# Verify cross compatibility: Core builder output should equal to contrib.
@@ -1676,7 +1676,7 @@ class WeightedSumTest(test.TestCase):
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [weighted_ids], num_outputs=5)
logits_core = fc_core.linear_model(features, [weighted_ids], units=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
@@ -1695,7 +1695,7 @@ class WeightedSumTest(test.TestCase):
features, [weighted_ids], num_outputs=5)
logits_core = fc_core.linear_model(features, [weighted_ids], units=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
@@ -1716,7 +1716,7 @@ class WeightedSumTest(test.TestCase):
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [crossed], num_outputs=5)
logits_core = fc_core.linear_model(features, [crossed], units=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
# Verify cross compatibility: Core builder output should equal to contrib.
@@ -1730,7 +1730,7 @@ class WeightedSumTest(test.TestCase):
dense_shape=[2, 2])
features = {"wire": wire_tensor}
embeded_sparse = feature_column.embedding_column(hashed_sparse, 10)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
ValueError, "Error creating weighted sum for column: wire_embedding"):
variables_lib.global_variables_initializer().run()
@@ -1756,7 +1756,7 @@ class WeightedSumTest(test.TestCase):
features, [movies], num_outputs=1))
logits_core = fc_core.linear_model(features, [movies])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.initialize_all_variables().run()
lookup_ops.tables_initializer().run()
@@ -1776,7 +1776,7 @@ class WeightedSumTest(test.TestCase):
}
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [real_valued], num_outputs=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [3, 5])
@@ -1789,7 +1789,7 @@ class WeightedSumTest(test.TestCase):
}
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [bucket], num_outputs=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [3, 5])
@@ -1814,7 +1814,7 @@ class WeightedSumTest(test.TestCase):
features, [real_valued, bucket, hashed_sparse, crossed], num_outputs=5)
output_core = fc_core.linear_model(
features, [real_valued, bucket, hashed_sparse, crossed], units=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(output.eval().shape, [3, 5])
# Verify cross compatibility: Core builder output should equal to contrib.
@@ -1837,7 +1837,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, bias = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [age, language], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -1877,7 +1877,7 @@ class WeightedSumTest(test.TestCase):
features, [country, language], num_outputs=1))
# Assert that only a single weight is created.
self.assertEqual(len(variables), 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -1941,7 +1941,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, bias = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [weighted_language], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -1969,7 +1969,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, bias = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [language], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -1992,7 +1992,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [movies], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2026,7 +2026,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [country_language], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2050,7 +2050,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [language_language], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2083,7 +2083,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [country_language], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2124,7 +2124,7 @@ class WeightedSumTest(test.TestCase):
features, [country, language, country_language],
num_outputs=1,
scope=scope))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2161,7 +2161,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [country, age, incomes], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2197,7 +2197,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [country, age, height, incomes], num_outputs=5))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2228,7 +2228,7 @@ class WeightedSumTest(test.TestCase):
feature_column_ops.weighted_sum_from_feature_columns(
features, [bucket], num_outputs=1))
output_core = fc_core.linear_model(features, [bucket])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
# Cross compatibility: Core builder output should equal to contrib.
@@ -2259,7 +2259,7 @@ class WeightedSumTest(test.TestCase):
feature_column_ops.weighted_sum_from_feature_columns(
features, [bucket, country], num_outputs=1))
output_core = fc_core.linear_model(features, [bucket, country])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
# Cross compatibility: Core builder output should equal to contrib.
@@ -2290,7 +2290,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [bucket, country], num_outputs=5))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2326,7 +2326,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [country_price], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2365,7 +2365,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [country_language_price], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2389,7 +2389,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [product], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
@@ -2404,7 +2404,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [product], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
@@ -2419,7 +2419,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [product], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
@@ -2440,7 +2440,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [product], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
@@ -2452,7 +2452,7 @@ class WeightedSumTest(test.TestCase):
features = {"age": constant_op.constant([[10.], [20.], [30.], [40.]])}
output, _, bias = feature_column_ops.weighted_sum_from_feature_columns(
features, [feature_column.real_valued_column("age")], num_outputs=3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
sess.run(bias.assign([0.1, 0.2, 0.3]))
@@ -2466,7 +2466,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [column], num_outputs=3))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0]
@@ -2490,7 +2490,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [column], num_outputs=3))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0]
@@ -2516,7 +2516,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [column], num_outputs=3))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2556,7 +2556,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [column], num_outputs=3))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2585,7 +2585,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [column], num_outputs=3))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2651,7 +2651,7 @@ class ParseExampleTest(test.TestCase):
feature_columns=[bucket, wire_cast])
self.assertIn(bucket, output)
self.assertIn(wire_cast, output)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertAllEqual(output[bucket].eval(), [[2, 3, 0]])
self.assertAllEqual(output[wire_cast].indices.eval(), [[0, 0], [0, 1]])
@@ -2713,7 +2713,7 @@ class ParseExampleTest(test.TestCase):
self.assertIn("measurements", seq)
self.assertIsInstance(seq["measurements"], ops.Tensor)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
location_val, wire_cast_val, measurement_val = sess.run(
[ctx["location"], seq["wire_cast"], seq["measurements"]])
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_test.py b/tensorflow/contrib/layers/python/layers/feature_column_test.py
index 1de9ab7056..d90d6ecf7f 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_test.py
@@ -57,6 +57,29 @@ def _sparse_id_tensor(shape, vocab_size, seed=112123):
indices=indices, values=values, dense_shape=shape)
+def _sparse_id_tensor_with_weights(shape, vocab_size, seed=112123):
+ # Returns arbitrary id and weight `SparseTensor`s of given shape and vocab size.
+ assert vocab_size >= shape[-1]
+ np.random.seed(seed)
+ indices = np.array(list(itertools.product(*[range(s) for s in shape])))
+
+ # Values are distinct within each row, drawn from the vocab.
+ values = np.ndarray.flatten(np.array([
+ np.random.choice(vocab_size, size=shape[-1], replace=False)
+ for _ in range(np.prod(shape[:-1]))]))
+ weights = np.sort(np.random.rand(*shape), axis=len(shape)-1)
+
+ # Keep only entries with weight < 0.5, for sparsity.
+ keep = np.ndarray.flatten(weights < 0.5) # Keeps roughly half of them
+ indices = indices[keep]
+ values = values[keep]
+ weights = np.ndarray.flatten(weights)[keep]
+ return (sparse_tensor_lib.SparseTensor(
+ indices=indices, values=values, dense_shape=shape),
+ sparse_tensor_lib.SparseTensor(
+ indices=indices, values=weights, dense_shape=shape))
+
+
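The helper added above returns a pair of `SparseTensor`s that share indices and dense shape: one carries ids drawn from the vocab, the other the surviving weights. A hedged usage sketch, assuming TensorFlow 1.x graph mode with the helper in scope (the shape and vocab size are illustrative):

    import tensorflow as tf

    ids, weights = _sparse_id_tensor_with_weights(shape=[4, 3], vocab_size=5)
    with tf.Session() as sess:
      id_vals, weight_vals = sess.run([ids.values, weights.values])
      # Indices line up one-to-one, so every id carries exactly one weight.
      assert len(id_vals) == len(weight_vals)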
class FeatureColumnTest(test.TestCase):
def testImmutability(self):
@@ -178,7 +201,7 @@ class FeatureColumnTest(test.TestCase):
b2 = feature_column_ops.input_from_feature_columns({
b[1]: input_tensor_c2
}, [b[1]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
b1_value = b1.eval()
b2_value = b2.eval()
@@ -207,7 +230,7 @@ class FeatureColumnTest(test.TestCase):
e1 = feature_column_ops.input_from_feature_columns({
e[0]: input_tensor_c1
}, [e[0]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
d1_value = d1.eval()
e1_value = e1.eval()
@@ -317,7 +340,7 @@ class FeatureColumnTest(test.TestCase):
with variable_scope.variable_scope("output_rank_{}".format(output_rank)):
one_hot_output = one_hot._to_dnn_input_layer(
id_tensor, output_rank=output_rank)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
one_hot_value = sess.run(one_hot_output)
expected_shape = (id_tensor_shape[:output_rank - 1] + [vocab_size])
self.assertEquals(expected_shape, list(one_hot_value.shape))
@@ -329,6 +352,34 @@ class FeatureColumnTest(test.TestCase):
self.assertEqual(one_hot.sparse_id_column.name, "ids_weighted_by_weights")
self.assertEqual(one_hot.length, 3)
+ def testIntegerizedOneHotColumnForWeightedSparseColumn(self):
+ vocab_size = 5
+ ids = fc.sparse_column_with_integerized_feature("ids", vocab_size)
+ weighted_ids = fc.weighted_sparse_column(ids, "weights")
+ one_hot = fc.one_hot_column(weighted_ids)
+ self.assertEqual(one_hot.sparse_id_column.name, "ids_weighted_by_weights")
+ self.assertEqual(one_hot.length, vocab_size)
+
+ def testIntegerizedOneHotWeightedSparseColumnShape(self):
+ vocab_size = 5
+ for id_tensor_shape in [[4, 3], [2, 4], [3, 3, 3]]:
+ output_rank = len(id_tensor_shape)
+ a = fc.sparse_column_with_integerized_feature("a", vocab_size)
+ weighted = fc.weighted_sparse_column(a, "weights")
+ one_hot = fc.one_hot_column(weighted)
+ id_tensor, weight_tensor = _sparse_id_tensor_with_weights(
+ id_tensor_shape, vocab_size)
+
+ one_hot_output = one_hot._to_dnn_input_layer(
+ (id_tensor, weight_tensor),
+ output_rank=output_rank)
+ one_hot_output_shape = one_hot_output.get_shape().as_list()
+ expected_shape = id_tensor_shape[:-1] + [vocab_size]
+ self.assertEqual(expected_shape, one_hot_output_shape)
+ with self.cached_session() as sess:
+ one_hot_value = sess.run(one_hot_output)
+ self.assertEqual(expected_shape, list(one_hot_value.shape))
+
def testOneHotColumnWithSparseColumnWithHashKeys(self):
input_values = ["marlo", "unknown", "omar"]
inputs = constant_op.constant(input_values)
@@ -348,7 +399,7 @@ class FeatureColumnTest(test.TestCase):
expected = np.array([[0., 1., 0., 0., 0., 0., 0., 1., 0.,
0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 1.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
one_hot_value = sess.run(one_hot_output)
self.assertTrue(np.array_equal(one_hot_value, expected))
@@ -389,7 +440,7 @@ class FeatureColumnTest(test.TestCase):
}
one_hot_tensor = feature_column_ops.input_from_feature_columns(
features, [one_hot])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(lookup_ops.tables_initializer())
self.assertAllEqual([[2., 6., 0.]], one_hot_tensor.eval())
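The `[[2., 6., 0.]]` assertion above pins down the semantics under test: with a weighted sparse column, the "one-hot" output accumulates each id's weight into its vocab slot rather than writing a bare 1. A NumPy sketch of that accumulation (the ids and weights are illustrative, not the test's actual inputs):

    import numpy as np

    vocab_size = 3
    ids = np.array([0, 1, 1])            # illustrative ids for one example
    weights = np.array([2.0, 4.0, 2.0])  # illustrative matching weights
    dense = np.zeros(vocab_size)
    np.add.at(dense, ids, weights)       # unbuffered scatter-add
    print(dense)                         # -> [2. 6. 0.]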
@@ -400,7 +451,7 @@ class FeatureColumnTest(test.TestCase):
features = {"ids": constant_op.constant([["marlo", "unknown", "omar"]])}
one_hot_tensor = feature_column_ops.input_from_feature_columns(
features, [one_hot])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(lookup_ops.tables_initializer())
self.assertAllEqual([[1., 1., 0.]], one_hot_tensor.eval())
@@ -552,7 +603,7 @@ class FeatureColumnTest(test.TestCase):
real_valued_output = real_valued_column._to_dnn_input_layer(
constant_op.constant(real_valued_input, dtype=dtypes.float32),
output_rank=output_rank)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
real_valued_eval = sess.run(real_valued_output)
expected_shape = (
input_shape[:output_rank - 1] +
@@ -746,7 +797,7 @@ class FeatureColumnTest(test.TestCase):
sparse_column.insert_transformed_feature(features)
sparse_output = features[sparse_column]
expected_shape = [batch_size, 1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sparse_result = sess.run(sparse_output)
self.assertEquals(expected_shape, list(sparse_result.dense_shape))
@@ -1059,7 +1110,7 @@ class FeatureColumnTest(test.TestCase):
ckpt_dir = tempfile.mkdtemp(prefix=ckpt_dir_prefix)
checkpoint_path = os.path.join(ckpt_dir, "model.ckpt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
saved_embedding = embeddings.eval()
save.save(sess, checkpoint_path)
@@ -1080,7 +1131,7 @@ class FeatureColumnTest(test.TestCase):
embedding_col_initialized: input_tensor
}, [embedding_col_initialized])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
loaded_embedding = pretrained_embeddings.eval()
@@ -1125,7 +1176,7 @@ class FeatureColumnTest(test.TestCase):
ckpt_dir = tempfile.mkdtemp(prefix=ckpt_dir_prefix)
checkpoint_path = os.path.join(ckpt_dir, "model.ckpt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(assign_op)
saved_col_weights = col_weights[crossed_col][0].eval()
@@ -1150,7 +1201,7 @@ class FeatureColumnTest(test.TestCase):
}, [crossed_col_initialized], 1))
col_weights_from_ckpt = col_weights[crossed_col_initialized][0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
loaded_col_weights = col_weights_from_ckpt.eval()
diff --git a/tensorflow/contrib/layers/python/layers/initializers_test.py b/tensorflow/contrib/layers/python/layers/initializers_test.py
index b7fe878893..bd3692b258 100644
--- a/tensorflow/contrib/layers/python/layers/initializers_test.py
+++ b/tensorflow/contrib/layers/python/layers/initializers_test.py
@@ -85,7 +85,7 @@ class VarianceScalingInitializerTest(test.TestCase):
def _test_variance(self, initializer, shape, variance, factor, mode, uniform):
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
var = variable_scope.get_variable(
name='test',
shape=shape,
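Where a test constructs its own `Graph`, the migration uses `self.session(graph=g)` rather than `cached_session()`, creating a session bound to that explicit graph. A minimal sketch, again assuming the 1.x `tf.test.TestCase` API (names are illustrative):

    import tensorflow as tf

    class GraphScopedTest(tf.test.TestCase):

      def testOwnGraph(self):
        with tf.Graph().as_default() as g:
          x = tf.constant(3.0)
          # session(graph=g) targets the explicitly built graph rather
          # than the test case's default graph.
          with self.session(graph=g) as sess:
            self.assertAllClose(sess.run(x), 3.0)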
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 04668f112d..a82d4c1951 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -3109,7 +3109,7 @@ def maxout(inputs, num_units, axis=-1, scope=None):
inputs: Tensor input
num_units: Specifies how many features will remain after maxout
in the `axis` dimension (usually channel).
- This must be multiple of number of `axis`.
+ This must be a factor of the number of features.
axis: The dimension where max pooling will be performed. Default is the
last dimension.
scope: Optional scope for variable_scope.
@@ -3128,7 +3128,7 @@ def maxout(inputs, num_units, axis=-1, scope=None):
raise ValueError('number of features({}) is not '
'a multiple of num_units({})'.format(
num_channels, num_units))
- shape[axis] = -1
+ shape[axis] = num_units
shape += [num_channels // num_units]
# Dealing with batches with arbitrary sizes
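Rather than letting the reshape infer the pooled axis with `-1`, the fix writes `num_units` explicitly, which keeps the static shape of the output intact when other dimensions (such as the batch) are dynamic. A hedged NumPy sketch of the maxout computation after the fix: the channel axis is split into `(num_units, num_channels // num_units)` groups and max-pooled over the trailing group dimension.

    import numpy as np

    x = np.random.rand(5, 7, 9, 12)   # illustrative NHWC input
    num_units = 4                     # must evenly divide the 12 channels
    n, h, w, c = x.shape
    y = x.reshape(n, h, w, num_units, c // num_units).max(axis=-1)
    assert y.shape == (5, 7, 9, 4)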
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 51c7abb105..85af9de4e4 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -281,7 +281,7 @@ class BiasAddTest(test.TestCase):
def testCreate(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height, width, 3))
output = _layers.bias_add(images)
self.assertEqual(output.op.name, 'BiasAdd/BiasAdd')
@@ -289,7 +289,7 @@ class BiasAddTest(test.TestCase):
def testCreateWithActivation(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = _layers.bias_add(images, activation_fn=nn_ops.relu)
self.assertEqual(output.op.name, 'BiasAdd/Relu')
@@ -298,7 +298,7 @@ class BiasAddTest(test.TestCase):
def testCreateDimensions(self):
dims = (2, 3, 4)
shape = [5, 2, 3, 4]
- with self.test_session():
+ with self.cached_session():
for d in dims:
input_shape = shape[:d]
inputs = random_ops.random_uniform(input_shape, seed=1)
@@ -311,7 +311,7 @@ class BiasAddTest(test.TestCase):
class ConvolutionTest(test.TestCase):
def testInvalidShape(self):
- with self.test_session():
+ with self.cached_session():
images_2d = random_ops.random_uniform((5, 7, 9, 3), seed=1)
with self.assertRaisesRegexp(
ValueError, 'Convolution expects input with rank 5, got 4'):
@@ -323,14 +323,14 @@ class ConvolutionTest(test.TestCase):
def testInvalidDataFormat(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
with self.assertRaisesRegexp(ValueError, 'data_format'):
layers_lib.convolution2d(images, 32, 3, data_format='CHWN')
def testCreateConv(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height, width, 4)).astype(np.float32)
output = layers_lib.convolution2d(images, 32, [3, 3])
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -342,7 +342,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvNCHW(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, 4, height, width)).astype(np.float32)
output = layers_lib.convolution2d(images, 32, [3, 3], data_format='NCHW')
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -354,7 +354,7 @@ class ConvolutionTest(test.TestCase):
def testCreateSquareConv(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.convolution2d(images, 32, 3)
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -362,7 +362,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvWithTensorShape(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.convolution2d(images, 32, images.get_shape()[1:3])
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -370,7 +370,7 @@ class ConvolutionTest(test.TestCase):
def testCreateFullyConv(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 32), seed=1)
output = layers_lib.convolution2d(
images, 64, images.get_shape()[1:3], padding='VALID')
@@ -381,7 +381,7 @@ class ConvolutionTest(test.TestCase):
def testFullyConvWithCustomGetter(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
called = [0]
def custom_getter(getter, *args, **kwargs):
@@ -395,7 +395,7 @@ class ConvolutionTest(test.TestCase):
def testCreateVerticalConv(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 4), seed=1)
output = layers_lib.convolution2d(images, 32, [3, 1])
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -407,7 +407,7 @@ class ConvolutionTest(test.TestCase):
def testCreateHorizontalConv(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 4), seed=1)
output = layers_lib.convolution2d(images, 32, [1, 3])
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -417,7 +417,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvWithStride(self):
height, width = 6, 8
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.convolution2d(images, 32, [3, 3], stride=2)
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -427,7 +427,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvCreatesWeightsAndBiasesVars(self):
height, width = 7, 9
images = random_ops.random_uniform((5, height, width, 3), seed=1)
- with self.test_session():
+ with self.cached_session():
self.assertFalse(variables.get_variables('conv1/weights'))
self.assertFalse(variables.get_variables('conv1/biases'))
layers_lib.convolution2d(images, 32, [3, 3], scope='conv1')
@@ -436,7 +436,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvWithScope(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.convolution2d(images, 32, [3, 3], scope='conv1')
self.assertEqual(output.op.name, 'conv1/Relu')
@@ -453,14 +453,14 @@ class ConvolutionTest(test.TestCase):
def testCreateConvWithoutActivation(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.convolution2d(images, 32, [3, 3], activation_fn=None)
self.assertEqual(output.op.name, 'Conv/BiasAdd')
def testCreateConvValid(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.convolution2d(images, 32, [3, 3], padding='VALID')
self.assertListEqual(output.get_shape().as_list(), [5, 5, 7, 32])
@@ -468,7 +468,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvWithWD(self):
height, width = 7, 9
weight_decay = 0.01
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform((5, height, width, 3), seed=1)
regularizer = regularizers.l2_regularizer(weight_decay)
layers_lib.convolution2d(
@@ -481,7 +481,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvNoRegularizers(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
layers_lib.convolution2d(images, 32, [3, 3])
self.assertEqual(
@@ -489,7 +489,7 @@ class ConvolutionTest(test.TestCase):
def testReuseVars(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
layers_lib.convolution2d(images, 32, [3, 3], scope='conv1')
self.assertEqual(len(variables.get_variables()), 2)
@@ -498,7 +498,7 @@ class ConvolutionTest(test.TestCase):
def testNonReuseVars(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
layers_lib.convolution2d(images, 32, [3, 3])
self.assertEqual(len(variables.get_variables()), 2)
@@ -507,7 +507,7 @@ class ConvolutionTest(test.TestCase):
def testReuseConvWithWD(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
weight_decay = regularizers.l2_regularizer(0.01)
with arg_scope(
@@ -523,7 +523,7 @@ class ConvolutionTest(test.TestCase):
def testConvWithBatchNorm(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 32), seed=1)
with arg_scope(
[layers_lib.convolution2d],
@@ -539,7 +539,7 @@ class ConvolutionTest(test.TestCase):
def testReuseConvWithBatchNorm(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 32), seed=1)
with arg_scope(
[layers_lib.convolution2d],
@@ -557,7 +557,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvCreatesWeightsAndBiasesVarsWithRateTwo(self):
height, width = 7, 9
images = random_ops.random_uniform((5, height, width, 3), seed=1)
- with self.test_session():
+ with self.cached_session():
self.assertFalse(variables.get_variables('conv1/weights'))
self.assertFalse(variables.get_variables('conv1/biases'))
layers_lib.convolution2d(images, 32, [3, 3], rate=2, scope='conv1')
@@ -573,7 +573,7 @@ class ConvolutionTest(test.TestCase):
output = layers_lib.convolution2d(
images, num_filters, [3, 3], rate=2, padding='SAME')
self.assertListEqual(list(output.get_shape().as_list()), expected_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -587,7 +587,7 @@ class ConvolutionTest(test.TestCase):
output = layers_lib.convolution2d(
images, num_filters, [3, 3], rate=2, padding='VALID')
self.assertListEqual(list(output.get_shape().as_list()), expected_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -601,7 +601,7 @@ class ConvolutionTest(test.TestCase):
output = layers_lib.convolution2d(
images, num_filters, [3, 3], rate=[2, 3], padding='VALID')
self.assertListEqual(list(output.get_shape().as_list()), expected_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEquals(output.op.name, 'Conv/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -612,7 +612,7 @@ class ConvolutionTest(test.TestCase):
expected_size = [None, None, None, num_filters]
expected_size_dynamic = [5, 7, 9, num_filters]
- with self.test_session():
+ with self.cached_session():
images = array_ops.placeholder(np.float32,
[None, None, None, input_size[3]])
output = layers_lib.convolution2d(
@@ -651,7 +651,7 @@ class ConvolutionTest(test.TestCase):
expected_size = [None, None, None, num_filters]
expected_size_dynamic = [5, 5, 7, num_filters]
- with self.test_session():
+ with self.cached_session():
images = array_ops.placeholder(np.float32,
[None, None, None, input_size[3]])
output = layers_lib.convolution2d(
@@ -670,7 +670,7 @@ class ConvolutionTest(test.TestCase):
images = random_ops.random_uniform(input_size, seed=1)
output = layers_lib.convolution2d(
images, num_filters, [3, 3], rate=2, padding='VALID', scope='conv7')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'conv7/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -688,7 +688,7 @@ class ConvolutionTest(test.TestCase):
padding='VALID',
activation_fn=None,
scope='conv7')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'conv7/BiasAdd')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -712,7 +712,7 @@ class Convolution2dTransposeTests(test.TestCase):
def testInvalidDataFormat(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
with self.assertRaisesRegexp(
ValueError, 'data_format has to be either NCHW or NHWC.'):
@@ -915,7 +915,7 @@ class Convolution2dTransposeTests(test.TestCase):
images, num_filters, [3, 3], stride=1, padding='SAME')
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -929,7 +929,7 @@ class Convolution2dTransposeTests(test.TestCase):
images, num_filters, [3, 3], stride=1, padding='VALID')
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -944,7 +944,7 @@ class Convolution2dTransposeTests(test.TestCase):
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.get_shape().as_list()), expected_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -958,7 +958,7 @@ class Convolution2dTransposeTests(test.TestCase):
images, num_filters, [2, 2], stride=[2, 2], padding='SAME')
self.assertListEqual(list(output.get_shape().as_list()), expected_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -971,7 +971,7 @@ class Convolution2dTransposeTests(test.TestCase):
images = random_ops.random_uniform(input_size, seed=1)
output = layers_lib.conv2d_transpose(
images, num_filters, [2, 2], stride=[2, 2], padding='VALID')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -984,7 +984,7 @@ class Convolution2dTransposeTests(test.TestCase):
images = random_ops.random_uniform(input_size, seed=1)
output = layers_lib.conv2d_transpose(
images, num_filters, [2, 2], stride=[2, 2], padding='SAME')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -997,7 +997,7 @@ class Convolution2dTransposeTests(test.TestCase):
images = random_ops.random_uniform(input_size, seed=1)
output = layers_lib.conv2d_transpose(
images, num_filters, [2, 2], stride=[2, 2], padding='VALID')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -1010,7 +1010,7 @@ class Convolution2dTransposeTests(test.TestCase):
images = random_ops.random_uniform(input_size, seed=1)
output = layers_lib.conv2d_transpose(
images, num_filters, [2, 4], stride=[2, 1], padding='VALID')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -1023,7 +1023,7 @@ class Convolution2dTransposeTests(test.TestCase):
images = random_ops.random_uniform(input_size, seed=1)
output = layers_lib.conv2d_transpose(
images, num_filters, [2, 4], stride=[2, 4], padding='VALID')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -1036,7 +1036,7 @@ class Convolution2dTransposeTests(test.TestCase):
images = random_ops.random_uniform(input_size, seed=1)
output = layers_lib.conv2d_transpose(
images, num_filters, [2, 4], stride=[2, 5], padding='VALID')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -1067,7 +1067,7 @@ class Convolution2dTransposeTests(test.TestCase):
conv = layers_lib.conv2d(
transpose, num_filters, filter_size, stride=stride, padding='VALID')
- with self.test_session(graph=graph) as sess:
+ with self.session(graph=graph) as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertListEqual(list(conv.eval().shape), input_size)
@@ -1083,7 +1083,7 @@ class Convolution2dTransposeTests(test.TestCase):
images, num_filters, [3, 3], stride=[2, 2], padding='VALID')
self.assertListEqual(output.get_shape().as_list(), expected_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
eval_output = output.eval({images: np.zeros(input_size, np.float32)})
@@ -1095,7 +1095,7 @@ class Convolution2dTransposeTests(test.TestCase):
expected_size = [None, None, None, num_filters]
expected_size_dynamic = [5, 18, 22, num_filters]
- with self.test_session():
+ with self.cached_session():
images = array_ops.placeholder(np.float32,
[None, None, None, input_size[3]])
output = layers_lib.conv2d_transpose(
@@ -1116,7 +1116,7 @@ class Convolution2dTransposeTests(test.TestCase):
images, num_filters, [3, 3], stride=2, padding='VALID', scope='conv7')
self.assertEqual(output.op.name, 'conv7/Relu')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -1135,7 +1135,7 @@ class Convolution2dTransposeTests(test.TestCase):
scope='conv7')
self.assertEqual(output.op.name, 'conv7/BiasAdd')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -1146,7 +1146,7 @@ class Convolution2dTransposeTests(test.TestCase):
stride = 2
padding = 'VALID'
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(input_size, seed=1)
output_deconv = layers_lib.conv2d_transpose(
images,
@@ -1184,7 +1184,7 @@ class ConvolutionInPlaneTest(test.TestCase):
activation_fn=None)
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
result = sess.run(horz_gradients)
expected = np.zeros((1, 10, 9, 1))
@@ -1201,7 +1201,7 @@ class ConvolutionInPlaneTest(test.TestCase):
activation_fn=None)
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
result = sess.run(
horz_gradients, feed_dict={
@@ -1225,7 +1225,7 @@ class ConvolutionInPlaneTest(test.TestCase):
activation_fn=None)
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
result = sess.run(horz_gradients)
@@ -1245,7 +1245,7 @@ class ConvolutionInPlaneTest(test.TestCase):
activation_fn=None)
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
result = sess.run(horz_gradients)
@@ -1267,7 +1267,7 @@ class ConvolutionInPlaneTest(test.TestCase):
activation_fn=None)
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
result = sess.run(horz_gradients)
@@ -1283,12 +1283,12 @@ class ConvolutionInPlaneTest(test.TestCase):
activation_fn=None)
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
result = sess.run(vert_gradients)
expected = np.zeros((1, 9, 10, 1))
- self.assertAllEqual(result, expected)
+ self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5)
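This assertion is also loosened from exact equality to a toleranced comparison, which suits floating-point convolution output: bitwise equality against a zeros array can fail on harmless rounding. A sketch of the equivalent check in NumPy (the tiny residual value is illustrative):

    import numpy as np

    result = np.full((1, 9, 10, 1), 1e-6, dtype=np.float32)  # near-zero output
    expected = np.zeros((1, 9, 10, 1), dtype=np.float32)
    np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5)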
def testVertConvWithVaryingImage(self):
image = np.asmatrix(('1.0 2.0 3.0;' '1.1 2.0 4.0;' '-4.3 0.0 8.9'))
@@ -1306,7 +1306,7 @@ class ConvolutionInPlaneTest(test.TestCase):
activation_fn=None)
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
result = sess.run(vert_gradients)
@@ -1314,7 +1314,7 @@ class ConvolutionInPlaneTest(test.TestCase):
def testConv1dShape(self):
width = 7
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, width, 3), seed=1)
output = layers_lib.convolution1d(images, 32, 3)
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -1322,7 +1322,7 @@ class ConvolutionInPlaneTest(test.TestCase):
def testConvInferSpatialDims(self):
depth, height, width = 7, 9, 11
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, width, 4)).astype(np.float32)
output = layers_lib.convolution(images, 32, [3])
self.assertListEqual(output.get_shape().as_list(), [5, width, 32])
@@ -1344,7 +1344,7 @@ class DenseToSparseTest(test.TestCase):
sparse = _layers.dense_to_sparse(tensor)
dense = sparse_ops.sparse_to_dense(sparse.indices, sparse.dense_shape,
sparse.values)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
constant = sess.run(dense)
self.assertAllEqual(expected_constant, constant)
@@ -1353,7 +1353,7 @@ class DropoutTest(test.TestCase):
def testCreateDropout(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height, width, 3))
output = _layers.dropout(images)
self.assertEqual(output.op.name, 'Dropout/dropout_1/mul')
@@ -1362,7 +1362,7 @@ class DropoutTest(test.TestCase):
def testCreateDropoutWithConstantTrue(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
is_training = constant_op.constant(True)
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = _layers.dropout(images, is_training=is_training)
@@ -1370,7 +1370,7 @@ class DropoutTest(test.TestCase):
def testCreateDropoutWithConstantFalse(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
is_training = constant_op.constant(False)
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = _layers.dropout(images, is_training=is_training)
@@ -1378,7 +1378,7 @@ class DropoutTest(test.TestCase):
def testCreateDropoutWithPlaceholder(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
is_training = array_ops.placeholder(dtype=dtypes.bool, shape=[])
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = _layers.dropout(images, is_training=is_training)
@@ -1387,7 +1387,7 @@ class DropoutTest(test.TestCase):
def testCollectOutputs(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = _layers.dropout(images, outputs_collections='outputs')
c_output = ops.get_collection('outputs')[0]
@@ -1396,7 +1396,7 @@ class DropoutTest(test.TestCase):
def testDropout(self):
height, width = 10, 10
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0))
@@ -1409,7 +1409,7 @@ class DropoutTest(test.TestCase):
def testDropoutSeed(self):
"""Test that providing the same seed produces the same result."""
height, width = 10, 10
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
output1 = _layers.dropout(images, seed=1)
@@ -1418,7 +1418,7 @@ class DropoutTest(test.TestCase):
def testCreateDropoutNoTraining(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0))
@@ -1431,7 +1431,7 @@ class DropoutTest(test.TestCase):
def testCreateFCFollowByDropout(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
output = _layers.fully_connected(images, 50)
@@ -1445,7 +1445,7 @@ class DropoutTest(test.TestCase):
def testCreateFCWithDropout(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
output = _layers.fully_connected(
@@ -1460,14 +1460,14 @@ class DropoutTest(test.TestCase):
class FlattenTest(test.TestCase):
def testInvalidRank(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
inputs = array_ops.placeholder(dtype=dtypes.float32)
inputs.set_shape(tensor_shape.TensorShape((5,)))
with self.assertRaisesRegexp(ValueError, 'incompatible with the layer'):
_layers.flatten(inputs)
def testUnknownLastDim(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
inputs = array_ops.placeholder(dtype=dtypes.float32)
inputs.set_shape(tensor_shape.TensorShape((5, None)))
output = _layers.flatten(inputs)
@@ -1475,7 +1475,7 @@ class FlattenTest(test.TestCase):
def testCollectOutputs(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height, width, 3))
output = _layers.flatten(images, outputs_collections='outputs')
c_output = ops.get_collection('outputs')[0]
@@ -1484,7 +1484,7 @@ class FlattenTest(test.TestCase):
def testFlatten4D(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
output = _layers.flatten(images)
@@ -1494,7 +1494,7 @@ class FlattenTest(test.TestCase):
def testFlatten3D(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width), seed=1, name='images')
output = _layers.flatten(images)
@@ -1504,7 +1504,7 @@ class FlattenTest(test.TestCase):
def testFlattenBatchSize(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
inputs = array_ops.placeholder(dtypes.int32, (None, height, width, 3))
@@ -1516,7 +1516,7 @@ class FlattenTest(test.TestCase):
def testUnknownDims(self):
height = width = depth = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(
(5, height, width, depth), seed=1, name='images')
inputs = array_ops.placeholder(dtypes.int32, (None, None, None, None))
@@ -1551,7 +1551,7 @@ class PartialFlattenTest(test.TestCase):
flattened_t = _layers._inner_flatten(inputs, new_rank)
static_shape = flattened_t.get_shape().as_list()
self.assertEqual(static_shape, expected_new_shape)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
flattened = sess.run(flattened_t)
np.testing.assert_array_equal(expected_flattened, flattened)
@@ -1571,7 +1571,7 @@ class PartialFlattenTest(test.TestCase):
flattened_t = _layers._inner_flatten(inputs_t, new_rank)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
flattened = sess.run(flattened_t)
np.testing.assert_array_equal(expected_indices, flattened.indices)
@@ -1629,7 +1629,7 @@ class FCTest(test.TestCase):
def testCreateFC(self):
height, width = 3, 3
for layer_fn in (_layers.fully_connected, layers_lib.relu):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
inputs = np.random.uniform(size=(5, height * width * 3))
output = layer_fn(inputs, 32)
self.assertEqual(output.op.name, 'fully_connected/Relu')
@@ -1641,7 +1641,7 @@ class FCTest(test.TestCase):
def testCreateFCWithScope(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
output = _layers.fully_connected(inputs, 32, scope='fc1')
self.assertEqual(output.op.name, 'fc1/Relu')
@@ -1659,7 +1659,7 @@ class FCTest(test.TestCase):
def testCreateFcCreatesWeightsAndBiasesVars(self):
height, width = 3, 3
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
- with self.test_session():
+ with self.cached_session():
self.assertFalse(variables.get_variables('fc1/weights'))
self.assertFalse(variables.get_variables('fc1/biases'))
_layers.fully_connected(inputs, 32, scope='fc1')
@@ -1669,7 +1669,7 @@ class FCTest(test.TestCase):
def testReuseVars(self):
height, width = 3, 3
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
- with self.test_session():
+ with self.cached_session():
_layers.fully_connected(inputs, 32, scope='fc1')
self.assertEqual(len(variables.get_variables('fc1')), 2)
_layers.fully_connected(inputs, 32, scope='fc1', reuse=True)
@@ -1678,7 +1678,7 @@ class FCTest(test.TestCase):
def testNonReuseVars(self):
height, width = 3, 3
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
- with self.test_session():
+ with self.cached_session():
_layers.fully_connected(inputs, 32)
self.assertEqual(len(variables.get_variables('fully_connected')), 2)
_layers.fully_connected(inputs, 32)
@@ -1713,14 +1713,14 @@ class FCTest(test.TestCase):
def testCreateFCWithoutActivation(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
output = _layers.fully_connected(inputs, 32, activation_fn=None)
self.assertEqual(output.op.name, 'fully_connected/BiasAdd')
def testCreateFCWithWD(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
weight_decay = regularizers.l2_regularizer(0.01)
_layers.fully_connected(inputs, 32, weights_regularizer=weight_decay)
@@ -1732,7 +1732,7 @@ class FCTest(test.TestCase):
def testCreateFCWithBD(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
bias_decay = regularizers.l2_regularizer(0.01)
_layers.fully_connected(inputs, 32, biases_regularizer=bias_decay)
@@ -1744,7 +1744,7 @@ class FCTest(test.TestCase):
def testCreateNoRegularizers(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
_layers.fully_connected(inputs, 32)
self.assertEqual(
@@ -1752,7 +1752,7 @@ class FCTest(test.TestCase):
def testReuseFCWithWD(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
weight_decay = regularizers.l2_regularizer(0.01)
_layers.fully_connected(
@@ -1768,7 +1768,7 @@ class FCTest(test.TestCase):
def testFCWithBatchNorm(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height * width * 3), seed=1)
with arg_scope(
[_layers.fully_connected],
@@ -1786,7 +1786,7 @@ class FCTest(test.TestCase):
def testReuseFCWithBatchNorm(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height * width * 3), seed=1)
with arg_scope(
[_layers.fully_connected],
@@ -1814,27 +1814,27 @@ class BatchNormTest(test.TestCase):
a, center=False, data_format='NCHW', zero_debias_moving_mean=True)
def testUnknownShape(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
inputs = array_ops.placeholder(dtype=dtypes.float32)
with self.assertRaisesRegexp(ValueError, 'undefined rank'):
_layers.batch_norm(inputs)
def testInvalidDataFormat(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
inputs = array_ops.placeholder(dtype=dtypes.float32)
with self.assertRaisesRegexp(
ValueError, 'data_format has to be either NCHW or NHWC.'):
_layers.batch_norm(inputs, data_format='CHWN')
def testUnknownChannelsDimNHWC(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
inputs = array_ops.placeholder(dtype=dtypes.float32)
inputs.set_shape(tensor_shape.TensorShape((5, 3, 3, None)))
with self.assertRaisesRegexp(ValueError, 'undefined'):
_layers.batch_norm(inputs, data_format='NHWC')
def testUnknownChannelsDimNCHW(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
inputs = array_ops.placeholder(dtype=dtypes.float32)
inputs.set_shape(tensor_shape.TensorShape((5, None, 3, 3)))
with self.assertRaisesRegexp(ValueError, 'undefined'):
@@ -1844,7 +1844,7 @@ class BatchNormTest(test.TestCase):
if dtype is None:
dtype = dtypes.float32
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height, width, 3)).astype(
dtype.as_numpy_dtype)
output = _layers.batch_norm(images, fused=fused)
@@ -1866,7 +1866,7 @@ class BatchNormTest(test.TestCase):
def _testCreateOpBetaRegularizer(self, fused=True):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
reg = lambda x: 0.1 * math_ops.reduce_sum(x)
images = np.random.uniform(size=(5, height, width, 3)).astype('f')
_layers.batch_norm(images, param_regularizers={'beta': reg}, fused=fused)
@@ -1883,7 +1883,7 @@ class BatchNormTest(test.TestCase):
def _testCreateOpGammaRegularizer(self, fused=True):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
reg = lambda x: 0.1 * math_ops.reduce_sum(x)
images = np.random.uniform(size=(5, height, width, 3)).astype('f')
_layers.batch_norm(
@@ -1901,7 +1901,7 @@ class BatchNormTest(test.TestCase):
def testCreateVariables(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_layers.batch_norm(images, scale=True)
beta = variables.get_variables_by_name('beta')[0]
@@ -1915,7 +1915,7 @@ class BatchNormTest(test.TestCase):
def testMovingAverageVariables(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_layers.batch_norm(images, scale=True)
self.assertEqual(len(variables.get_model_variables()), 4)
@@ -1926,7 +1926,7 @@ class BatchNormTest(test.TestCase):
def testMovingAverageVariablesZeroDebias(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_layers.batch_norm(
images, scale=True, zero_debias_moving_mean=True, fused=False)
@@ -1943,7 +1943,7 @@ class BatchNormTest(test.TestCase):
def testUpdatesCollection(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_layers.batch_norm(images, updates_collections='my_update_ops')
update_layers = ops.get_collection('my_update_ops')
@@ -1971,7 +1971,7 @@ class BatchNormTest(test.TestCase):
def testReuseVariables(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_layers.batch_norm(images, scale=True, scope='bn')
_layers.batch_norm(images, scale=True, scope='bn', reuse=True)
@@ -1986,7 +1986,7 @@ class BatchNormTest(test.TestCase):
def testReuseUpdateOps(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
with arg_scope([_layers.batch_norm], updates_collections='update_ops'):
_layers.batch_norm(images, scope='bn')
@@ -1996,7 +1996,7 @@ class BatchNormTest(test.TestCase):
def testCreateMovingVars(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_ = _layers.batch_norm(images)
moving_mean = variables.get_variables('BatchNorm/moving_mean')
@@ -2029,7 +2029,7 @@ class BatchNormTest(test.TestCase):
moving_variance = variables.get_variables_by_name('moving_variance')[0]
biased = variables.get_variables_by_name('biased')[0]
local_step = variables.get_variables_by_name('local_step')[0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertAllClose(local_step.eval(), 0)
self.assertAllClose(moving_mean.eval(), [0] * channels)
@@ -2213,7 +2213,7 @@ class BatchNormTest(test.TestCase):
def _testEvalMovingVars(self, zero_debias_moving_mean=False):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
image_shape = (10, height, width, 3)
image_values = np.random.rand(*image_shape)
expected_mean = np.mean(image_values, axis=(0, 1, 2))
@@ -2264,7 +2264,7 @@ class BatchNormTest(test.TestCase):
height, width = 3, 3
batch_size = 10
channels = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
image_shape = (batch_size, height, width, channels)
image_values = np.random.rand(*image_shape)
expected_mean = np.mean(image_values, axis=(0, 1, 2))
@@ -2435,7 +2435,7 @@ class BatchNormTest(test.TestCase):
def testNoUpdatesWhenIsTrainingFalse(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
image_shape = (10, height, width, 3)
image_values = np.random.rand(*image_shape)
images = constant_op.constant(
@@ -2460,7 +2460,7 @@ class BatchNormTest(test.TestCase):
def testNoneUpdatesCollectionNoTraining(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
image_shape = (10, height, width, 3)
image_values = np.random.rand(*image_shape)
images = constant_op.constant(
@@ -2647,7 +2647,7 @@ class BatchNormTest(test.TestCase):
def testCustomInitializer(self):
height, width = 3, 3
channels = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = (np.ones((5, height, width, channels)) * 9.0).astype('f')
beta = init_ops.constant_initializer(
(np.ones(channels) * 5.0).astype('f'))
@@ -2728,7 +2728,7 @@ class BatchNormTest(test.TestCase):
def testBatchNormBeta(self):
# Test case for 11673
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a_32 = array_ops.placeholder(dtypes.float32, shape=(10, 10, 10, 10))
_layers.batch_norm(
a_32, center=False, data_format='NCHW', zero_debias_moving_mean=True)
@@ -2739,7 +2739,7 @@ class BatchNormTest(test.TestCase):
def testVariablesAreFloat32(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, dtype=dtypes.float16)
_layers.batch_norm(images, scale=True)
@@ -2810,13 +2810,13 @@ class BatchNormTest(test.TestCase):
class LayerNormTest(test.TestCase):
def testUnknownShape(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
inputs = array_ops.placeholder(dtype=dtypes.float32)
with self.assertRaisesRegexp(ValueError, 'undefined rank'):
_layers.layer_norm(inputs)
def testParamsDimsNotFullyDefined(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
inputs = array_ops.placeholder(dtype=dtypes.float32)
inputs.set_shape(tensor_shape.TensorShape((5, 3, 3, None)))
with self.assertRaisesRegexp(ValueError, 'is not fully defined'):
@@ -2824,7 +2824,7 @@ class LayerNormTest(test.TestCase):
def testCreateOp(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height, width, 3))
output = _layers.layer_norm(images)
self.assertTrue(output.op.name.startswith('LayerNorm/batchnorm'))
@@ -2832,7 +2832,7 @@ class LayerNormTest(test.TestCase):
def testCreateVariables(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_layers.layer_norm(images)
beta = variables.get_variables_by_name('beta')[0]
@@ -2842,7 +2842,7 @@ class LayerNormTest(test.TestCase):
def testReuseVariables(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_layers.layer_norm(images, scope='ln')
_layers.layer_norm(images, scope='ln', reuse=True)
@@ -2853,7 +2853,7 @@ class LayerNormTest(test.TestCase):
def testReuseVars(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
image_shape = (10, height, width, 3)
image_values = np.random.rand(*image_shape)
images = constant_op.constant(
@@ -2876,7 +2876,7 @@ class LayerNormTest(test.TestCase):
for sigma in [1.0, 0.1]:
input_values = np.random.randn(*input_shape) * sigma + mu
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
inputs = constant_op.constant(
input_values, shape=input_shape, dtype=dtype)
output_t = _layers.layer_norm(
@@ -2940,7 +2940,7 @@ class GDNTest(test.TestCase):
def _runGDN(self, x, shape, inverse, data_format):
inputs = array_ops.placeholder(dtypes.float32, shape)
outputs = _layers.gdn(inputs, inverse=inverse, data_format=data_format)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
y, = sess.run([outputs], {inputs: x})
return y
@@ -3152,14 +3152,14 @@ class MaxPool3DTest(test.TestCase):
class OneHotEncodingTest(test.TestCase):
def testOneHotEncodingCreate(self):
- with self.test_session():
+ with self.cached_session():
labels = np.array([0, 1, 2])
output = _layers.one_hot_encoding(labels, num_classes=3)
self.assertEqual(output.op.name, 'OneHotEncoding/one_hot')
self.assertListEqual(output.get_shape().as_list(), [3, 3])
def testCollectOutputs(self):
- with self.test_session():
+ with self.cached_session():
labels = constant_op.constant([0, 1, 2])
output = _layers.one_hot_encoding(
labels, num_classes=3, outputs_collections='outputs')
@@ -3168,14 +3168,14 @@ class OneHotEncodingTest(test.TestCase):
self.assertEqual(c_output, output)
def testOneHotEncoding(self):
- with self.test_session():
+ with self.cached_session():
labels = constant_op.constant([0, 1, 2])
one_hot_labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
output = _layers.one_hot_encoding(labels, num_classes=3)
self.assertAllClose(output.eval(), one_hot_labels.eval())
def testOneHotEncodingInt32(self):
- with self.test_session():
+ with self.cached_session():
labels = constant_op.constant([0, 1, 2], dtype=dtypes.int32)
one_hot_labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
output = _layers.one_hot_encoding(labels, num_classes=3)
@@ -3186,7 +3186,7 @@ class RepeatTests(test.TestCase):
def testRepeat(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height, width, 3)).astype(np.float32)
output = _layers.repeat(images, 3, layers_lib.conv2d, 32, [3, 3])
self.assertEqual(output.op.name, 'Repeat/convolution2d_3/Relu')
@@ -3194,7 +3194,7 @@ class RepeatTests(test.TestCase):
def testRepeatWithScope(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
output = _layers.repeat(
@@ -3207,7 +3207,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateConvInt32(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, dtype=dtypes.int32, maxval=12345)
with self.assertRaisesRegexp(TypeError, 'non-floating point type'):
@@ -3215,7 +3215,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateConvFloat32(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, dtype=dtypes.float32)
output = layers_lib.separable_conv2d(images, 32, [3, 3], 2)
@@ -3224,7 +3224,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateDepthwiseConv(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.separable_conv2d(images, None, [3, 3], 2)
self.assertEqual(output.op.name, 'SeparableConv2d/Relu')
@@ -3233,7 +3233,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateConvCreatesWeightsAndBiasesVars(self):
height, width = 3, 3
images = random_ops.random_uniform((5, height, width, 3), seed=1)
- with self.test_session():
+ with self.cached_session():
self.assertFalse(variables.get_variables('conv1/depthwise_weights'))
self.assertFalse(variables.get_variables('conv1/pointwise_weights'))
self.assertFalse(variables.get_variables('conv1/biases'))
@@ -3245,7 +3245,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateAtrousConvCreatesWeightsAndBiasesVars(self):
height, width = 3, 3
images = random_ops.random_uniform((5, height, width, 3), seed=1)
- with self.test_session():
+ with self.cached_session():
self.assertFalse(variables.get_variables('conv1/depthwise_weights'))
self.assertFalse(variables.get_variables('conv1/pointwise_weights'))
self.assertFalse(variables.get_variables('conv1/biases'))
@@ -3257,7 +3257,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateDepthwiseConvCreatesWeightsAndBiasesVars(self):
height, width = 3, 3
images = random_ops.random_uniform((5, height, width, 3), seed=1)
- with self.test_session():
+ with self.cached_session():
self.assertFalse(variables.get_variables('conv1/depthwise_weights'))
self.assertFalse(variables.get_variables('conv1/pointwise_weights'))
self.assertFalse(variables.get_variables('conv1/biases'))
@@ -3268,14 +3268,14 @@ class SeparableConv2dTest(test.TestCase):
def testCreateConvWithScope(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.separable_conv2d(images, 32, [3, 3], 6, scope='conv1')
self.assertEqual(output.op.name, 'conv1/Relu')
def testCreateConvWithoutActivation(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.separable_conv2d(
images, 32, [3, 3], 8, activation_fn=None)
@@ -3283,7 +3283,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateConvValid(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.separable_conv2d(
images, 32, [3, 3], 2, padding='VALID')
@@ -3291,7 +3291,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateAtrousConvValid(self):
height, width = 5, 5
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.separable_conv2d(
images, 32, [3, 3], 2, padding='VALID', rate=2)
@@ -3299,7 +3299,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateDepthwiseConvValid(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.separable_conv2d(
images, None, [3, 3], 2, padding='VALID')
@@ -3307,7 +3307,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateAtrousDepthwiseConvValid(self):
height, width = 5, 5
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.separable_conv2d(
images, None, [3, 3], 2, padding='VALID', rate=2)
@@ -3316,7 +3316,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateConvWithWeightDecay(self):
random_seed.set_random_seed(0)
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform((5, height, width, 3), seed=1)
regularizer = regularizers.l2_regularizer(0.01)
layers_lib.separable_conv2d(
@@ -3360,7 +3360,7 @@ class SeparableConv2dTest(test.TestCase):
def testReuseConvWithWeightDecay(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
regularizer = regularizers.l2_regularizer(0.01)
layers_lib.separable_conv2d(
@@ -3419,7 +3419,7 @@ class SeparableConv2dTest(test.TestCase):
normalizer_params={},
scope='conv1')
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = np.random.rand(5, height, width, 3)
sess.run(init_op)
sess.run(net, feed_dict={images_placeholder: images})
@@ -3440,7 +3440,7 @@ class SeparableConv2dTest(test.TestCase):
def testSepConvNCHW(self):
for num_filters, correct_output_filters in zip((None, 5), (6, 5)):
- with self.test_session():
+ with self.cached_session():
batch, height, width = 4, 10, 12
kernel_dim, stride = 3, 2
images = random_ops.random_uniform((batch, 3, height, width), seed=1)
@@ -3462,7 +3462,7 @@ class ScaleGradientTests(test.TestCase):
"""Simple tests of the scale_gradient function."""
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
x = np.array([42], np.float32)
gradient_scale = np.array([2], np.float32)
@@ -3513,7 +3513,7 @@ class SoftmaxTests(test.TestCase):
exp_prediction = np.array([[self.low, self.high], [0.5, 0.5],
[self.high, self.low]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
prediction = sess.run(prediction)
self.assertAllClose(exp_prediction, prediction)
@@ -3529,7 +3529,7 @@ class SoftmaxTests(test.TestCase):
exp_prediction[1, 1, 1] = self.low
prediction = _layers.softmax(logits)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
prediction = sess.run(prediction)
self.assertAllClose(exp_prediction, prediction)
@@ -3547,7 +3547,7 @@ class SoftmaxTests(test.TestCase):
exp_prediction[1, 1, 1] = self.low
prediction = _layers.softmax(logit_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
prediction = sess.run(prediction, feed_dict=feed_dict)
self.assertAllClose(exp_prediction, prediction)
@@ -3575,7 +3575,7 @@ class SpatialSoftmaxTests(test.TestCase):
features = array_ops.placeholder(dtypes.float32, shape=batch_shape)
np_features = np.zeros(batch_shape, dtype=np.float32)
spatial_softmax = _layers.spatial_softmax(features)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features}
keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3586,7 +3586,7 @@ class SpatialSoftmaxTests(test.TestCase):
features = array_ops.placeholder(dtypes.float32, shape=batch_shape)
np_features = np.zeros(batch_shape, dtype=np.float32)
spatial_softmax = _layers.spatial_softmax(features, data_format='NCHW')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features}
keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3613,7 +3613,7 @@ class SpatialSoftmaxTests(test.TestCase):
nchannels)
# Make sure expected location keypoints match actual location keypoints.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features}
keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3637,7 +3637,7 @@ class SpatialSoftmaxTests(test.TestCase):
nchannels)
# Make sure expected location keypoints match actual location keypoints.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features}
keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3669,7 +3669,7 @@ class SpatialSoftmaxTests(test.TestCase):
batch_size, nchannels)
# Make sure expected location keypoints match actual location keypoints.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features1}
tf_keypoints1 = sess.run(spatial_softmax, feed_dict)
@@ -3696,7 +3696,7 @@ class SpatialSoftmaxTests(test.TestCase):
nchannels)
# Make sure expected location keypoints match actual location keypoints.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features}
keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3719,7 +3719,7 @@ class SpatialSoftmaxTests(test.TestCase):
nchannels)
# Make sure expected location keypoints match actual location keypoints.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features}
keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3731,7 +3731,7 @@ class SpatialSoftmaxTests(test.TestCase):
spatial_softmax = _layers.spatial_softmax(features)
net = _layers.fully_connected(spatial_softmax, 10)
np_features = np.zeros(batch_shape, dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features}
sess.run(net, feed_dict)
@@ -3741,7 +3741,7 @@ class StackTests(test.TestCase):
def testStackFullyConnected(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height * width * 3))
output = _layers.stack(images, _layers.fully_connected, [10, 20, 30])
self.assertEqual(output.op.name, 'Stack/fully_connected_3/Relu')
@@ -3749,7 +3749,7 @@ class StackTests(test.TestCase):
def testStackFullyConnectedFailOnReuse(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('test', reuse=True):
images = np.random.uniform(size=(5, height * width * 3))
with self.assertRaises(ValueError):
@@ -3757,7 +3757,7 @@ class StackTests(test.TestCase):
def testStackRelu(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height * width * 3), seed=1, name='images')
output = _layers.stack(images, layers_lib.relu, [10, 20, 30])
@@ -3766,7 +3766,7 @@ class StackTests(test.TestCase):
def testStackElu(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height * width * 3), seed=1, name='images')
output = _layers.stack(images, layers_lib.elu, [10, 20, 30])
@@ -3775,7 +3775,7 @@ class StackTests(test.TestCase):
def testStackConvolution2d(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
output = _layers.stack(
@@ -3788,7 +3788,7 @@ class StackTests(test.TestCase):
def testStackWithScope(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
output = _layers.stack(
@@ -3817,7 +3817,7 @@ class UnitNormTests(test.TestCase):
del shape[dim]
expected = np.ones(shape)
- with self.test_session():
+ with self.cached_session():
actual = norms.eval()
self.assertAllClose(expected, actual, 1e-4, 1e-4)
@@ -3849,7 +3849,7 @@ class UnitNormTests(test.TestCase):
norms = math_ops.sqrt(
math_ops.reduce_sum(math_ops.square(output), reduction_indices=dim))
- with self.test_session():
+ with self.cached_session():
actual = norms.eval({image: placeholder_value})
self.assertAllClose(expected, actual, 1e-4, 1e-4)
@@ -3875,7 +3875,7 @@ class PoincareNormalizeTest(test.TestCase):
x_np = np.random.random_sample(x_shape).astype(np.float32)
for dim in range(len(x_shape)):
y_np = self._PoincareNormalize(x_np, dim, epsilon)
- with self.test_session():
+ with self.cached_session():
x_tf = constant_op.constant(x_np, name='x')
y_tf = _layers.poincare_normalize(x_tf, dim, epsilon)
y_tf_eval = y_tf.eval()
@@ -3893,7 +3893,7 @@ class PoincareNormalizeTest(test.TestCase):
x_np = np.random.random_sample(x_shape).astype(np.float32)
dim = [1, 2]
y_np = self._PoincareNormalize(x_np, dim, epsilon)
- with self.test_session():
+ with self.cached_session():
x_tf = constant_op.constant(x_np, name='x')
y_tf = _layers.poincare_normalize(x_tf, dim, epsilon)
y_tf_eval = y_tf.eval()
@@ -3908,7 +3908,7 @@ class PoincareNormalizeTest(test.TestCase):
np.random.seed(1)
x_np = np.random.random_sample(x_shape).astype(np.float64)
for dim in range(len(x_shape)):
- with self.test_session():
+ with self.cached_session():
x_tf = constant_op.constant(x_np, name='x')
y_tf = _layers.poincare_normalize(x_tf, dim)
err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf,
@@ -4117,7 +4117,7 @@ class LegacyFullyConnectedTest(test.TestCase):
# Empty x is common if someone masks their input with tf.boolean_mask in
# order to drop missing entries, and in a particular batch all entries are
# missing.
- with self.test_session():
+ with self.cached_session():
x = np.array([]).reshape(0, 3)
self.assertEqual(0, array_ops.size(x).eval())
y = _layers.legacy_fully_connected(x, 2, activation_fn=nn_ops.softmax)
@@ -4131,7 +4131,7 @@ class LegacyFullyConnectedTest(test.TestCase):
y = _layers.legacy_fully_connected(x, 1)
# in the output we still only know the 2nd and 3rd dimensions statically.
self.assertEqual(y.get_shape().as_list(), [None, 4, 1])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
# we can feed in input with first dimension 2
shape_value = sess.run(
@@ -4162,7 +4162,7 @@ class LegacyFullyConnectedTest(test.TestCase):
self._unknown_dim_invalid_input(last_dim=None)
def test_1d_invalid_input(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError,
'rank of x must be at least 2 not: 1'):
x = constant_op.constant([[]], shape=[0])
diff --git a/tensorflow/contrib/layers/python/layers/normalization.py b/tensorflow/contrib/layers/python/layers/normalization.py
index c807ab0f2e..11033a2e9c 100644
--- a/tensorflow/contrib/layers/python/layers/normalization.py
+++ b/tensorflow/contrib/layers/python/layers/normalization.py
@@ -176,7 +176,8 @@ def group_norm(inputs,
variables_collections=None,
outputs_collections=None,
trainable=True,
- scope=None):
+ scope=None,
+ mean_close_to_zero=False):
"""Functional interface for the group normalization layer.
Reference: https://arxiv.org/abs/1803.08494.
@@ -222,6 +223,19 @@ def group_norm(inputs,
trainable: If `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
scope: Optional scope for `variable_scope`.
+ mean_close_to_zero: The mean of `input` before ReLU will be close to zero
+ when batch size >= 4k for ResNet-50 on TPU. If `True`, use
+ `nn.sufficient_statistics` and `nn.normalize_moments` to calculate the
+ variance. This is the same behavior as `fused=True` in batch
+ normalization. If `False`, use `nn.moments` to calculate the variance.
+ When `mean` is close to zero (e.g. 1e-4), computing the variance from
+ `mean` can give poor results due to repeated roundoff error and
+ denormal values in `mean`. When `mean` is large (e.g. 1e2), sum(`input`^2)
+ is so large that only the high-order digits of the elements are
+ accumulated. Thus, computing the variance as sum((`input` - `mean`)^2)/n
+ is more accurate than (sum(`input`^2)/n - `mean`^2) when `mean` is
+ large.
+
Returns:
A `Tensor` representing the output of the operation.
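As an aside to the docstring above, the float32 cancellation it describes is easy to reproduce; the following numpy sketch (illustrative, not part of the patch) contrasts the one-pass formula against the centered form when the mean is large:

import numpy as np

rng = np.random.RandomState(0)
x = (1e2 + 1e-2 * rng.randn(100000)).astype(np.float32)  # mean ~1e2, var ~1e-4

one_pass = np.mean(x * x) - np.mean(x) ** 2      # sum(x^2)/n - mean^2
two_pass = np.mean((x - np.mean(x)) ** 2)        # sum((x - mean)^2)/n
exact = np.var(x.astype(np.float64))
print(one_pass, two_pass, exact)  # one_pass loses all digits; two_pass ~ exact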
@@ -333,7 +347,14 @@ def group_norm(inputs,
gamma = array_ops.reshape(gamma, params_shape_broadcast)
# Calculate the moments.
- mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)
+ if mean_close_to_zero:
+ # One pass algorithm returns better result when mean is close to zero.
+ counts, means_ss, variance_ss, _ = nn.sufficient_statistics(
+ inputs, moments_axes, keep_dims=True)
+ mean, variance = nn.normalize_moments(
+ counts, means_ss, variance_ss, shift=None)
+ else:
+ mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)
# Compute normalization.
# TODO(shlens): Fix nn.batch_normalization to handle the 5-D Tensor
diff --git a/tensorflow/contrib/layers/python/layers/normalization_test.py b/tensorflow/contrib/layers/python/layers/normalization_test.py
index b6e96350db..c8d3c91b10 100644
--- a/tensorflow/contrib/layers/python/layers/normalization_test.py
+++ b/tensorflow/contrib/layers/python/layers/normalization_test.py
@@ -106,7 +106,7 @@ class InstanceNormTest(test.TestCase):
images = random_ops.random_uniform(image_shape, seed=1)
output_train = normalization.instance_norm(images, scope='IN')
output_eval = normalization.instance_norm(images, scope='IN', reuse=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
# output_train and output_eval should be the same.
train_np, eval_np = sess.run([output_train, output_eval])
@@ -130,7 +130,7 @@ class InstanceNormTest(test.TestCase):
inputs = random_ops.random_uniform(input_shape, seed=0) * sigma + mu
output_op = normalization.instance_norm(
inputs, center=False, scale=False, data_format=data_format)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
outputs = sess.run(output_op)
# Make sure that there are no NaNs
@@ -287,14 +287,19 @@ class GroupNormTest(test.TestCase):
output_train = normalization.group_norm(images, groups=2, scope='IN')
output_eval = normalization.group_norm(images, groups=2, scope='IN',
reuse=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
# output_train and output_eval should be the same.
train_np, eval_np = sess.run([output_train, output_eval])
self.assertAllClose(train_np, eval_np)
- def doOutputTest(self, input_shape, channels_axis=None, reduction_axes=None,
- groups=2, tol=1e-2):
+ def doOutputTest(self,
+ input_shape,
+ channels_axis=None,
+ reduction_axes=None,
+ mean_close_to_zero=False,
+ groups=2,
+ tol=1e-2):
# Select the axis for the channel and the dimensions along which statistics
# are accumulated.
if channels_axis < 0:
@@ -322,18 +327,29 @@ class GroupNormTest(test.TestCase):
if i not in reduced_axes:
reduced_shape.append(a)
- for mu in (0.0, 1e2):
- for sigma in (1.0, 0.1):
+ if mean_close_to_zero:
+ mu_tuple = (1e-4, 1e-2, 1.0)
+ sigma_tuple = (1e-2, 0.1, 1.0)
+ else:
+ mu_tuple = (1.0, 1e2)
+ sigma_tuple = (1.0, 0.1)
+
+ for mu in mu_tuple:
+ for sigma in sigma_tuple:
# Determine shape of Tensor after normalization.
expected_mean = np.zeros(reduced_shape)
expected_var = np.ones(reduced_shape)
- inputs = random_ops.random_uniform(input_shape, seed=0) * sigma + mu
+ inputs = random_ops.random_normal(input_shape, seed=0) * sigma + mu
output_op = normalization.group_norm(
- inputs, groups=groups, center=False, scale=False,
+ inputs,
+ groups=groups,
+ center=False,
+ scale=False,
channels_axis=channels_axis,
- reduction_axes=reduction_axes)
- with self.test_session() as sess:
+ reduction_axes=reduction_axes,
+ mean_close_to_zero=mean_close_to_zero)
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
outputs = sess.run(output_op)
# Make sure that there are no NaNs
@@ -347,12 +363,32 @@ class GroupNormTest(test.TestCase):
self.assertAllClose(expected_mean, mean, rtol=tol, atol=tol)
self.assertAllClose(expected_var, var, rtol=tol, atol=tol)
+ def doOutputTestForMeanCloseToZero(self,
+ input_shape,
+ channels_axis=None,
+ reduction_axes=None,
+ groups=2,
+ tol=5e-2):
+ self.doOutputTest(
+ input_shape,
+ channels_axis=channels_axis,
+ reduction_axes=reduction_axes,
+ groups=groups,
+ tol=tol,
+ mean_close_to_zero=True)
+
def testOutputSmallInput4D_NHWC(self):
input_shape = [10, 10, 10, 30]
# Specify axes with positive values.
self.doOutputTest(input_shape, channels_axis=3, reduction_axes=[1, 2])
# Specify axes with negative values.
self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2])
+ # Specify axes with positive values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=3, reduction_axes=[1, 2])
+ # Specify axes with negative values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=-1, reduction_axes=[-3, -2])
def testOutputSmallInput3D_NHWC(self):
input_shape = [10, 10, 30]
@@ -360,6 +396,12 @@ class GroupNormTest(test.TestCase):
self.doOutputTest(input_shape, channels_axis=2, reduction_axes=[0, 1])
# Specify axes with negative values.
self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2])
+ # Specify axes with positive values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=2, reduction_axes=[0, 1])
+ # Specify axes with negative values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=-1, reduction_axes=[-3, -2])
def testOutputSmallInput4D_NCHW(self):
input_shape = [10, 10, 10, 30]
@@ -367,6 +409,12 @@ class GroupNormTest(test.TestCase):
self.doOutputTest(input_shape, channels_axis=1, reduction_axes=[2, 3])
# Specify axes with negative values.
self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1])
+ # Specify axes with positive values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=1, reduction_axes=[2, 3])
+ # Specify axes with negative values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=-3, reduction_axes=[-2, -1])
def testOutputSmallInput3D_NCHW(self):
input_shape = [10, 10, 30]
@@ -374,23 +422,43 @@ class GroupNormTest(test.TestCase):
self.doOutputTest(input_shape, channels_axis=0, reduction_axes=[1, 2])
# Specify axes with negative values.
self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1])
+ # Specify axes with positive values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=0, reduction_axes=[1, 2])
+ # Specify axes with negative values.
+ self.doOutputTestForMeanCloseToZero(
+ input_shape, channels_axis=-3, reduction_axes=[-2, -1])
def testOutputBigInput4D_NHWC(self):
- self.doOutputTest([5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2],
- groups=1)
+ self.doOutputTest(
+ [5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2], groups=1)
+ self.doOutputTestForMeanCloseToZero(
+ [5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2], groups=1)
def testOutputBigInput4D_NCHW(self):
- self.doOutputTest([1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3],
- groups=4)
+ self.doOutputTest(
+ [1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3], groups=4)
+ self.doOutputTestForMeanCloseToZero(
+ [1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3], groups=4)
def testOutputSmallInput2D_NC(self):
- self.doOutputTest([10, 7*100], channels_axis=1, reduction_axes=[], groups=7)
+ self.doOutputTest(
+ [10, 7 * 100], channels_axis=1, reduction_axes=[], groups=7)
+ self.doOutputTestForMeanCloseToZero(
+ [10, 7 * 100], channels_axis=1, reduction_axes=[], groups=7)
def testOutputSmallInput5D_NCXXX(self):
- self.doOutputTest([10, 10, 20, 40, 5],
- channels_axis=1,
- reduction_axes=[2, 3, 4],
- groups=5)
+ self.doOutputTest(
+ [10, 10, 20, 40, 5],
+ channels_axis=1,
+ reduction_axes=[2, 3, 4],
+ groups=5)
+ self.doOutputTestForMeanCloseToZero(
+ [10, 10, 20, 40, 5],
+ channels_axis=1,
+ reduction_axes=[2, 3, 4],
+ groups=5)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py
index a4461a20e5..29dede2a49 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers_test.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py
@@ -66,7 +66,7 @@ class OptimizersTest(test.TestCase):
]
for optimizer in optimizers:
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
x, var, loss, global_step = _setup_model()
train = optimizers_lib.optimize_loss(
loss, global_step, learning_rate=0.1, optimizer=optimizer)
@@ -82,7 +82,7 @@ class OptimizersTest(test.TestCase):
return gradient_descent.GradientDescentOptimizer(learning_rate=0.1)
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
x, var, loss, global_step = _setup_model()
train = optimizers_lib.optimize_loss(
loss, global_step, learning_rate=None, optimizer=optimizer_fn)
@@ -96,14 +96,14 @@ class OptimizersTest(test.TestCase):
optimizers = ["blah", variables.Variable, object(), lambda x: None]
for optimizer in optimizers:
with ops.Graph().as_default() as g:
- with self.test_session(graph=g):
+ with self.session(graph=g):
_, _, loss, global_step = _setup_model()
with self.assertRaises(ValueError):
optimizers_lib.optimize_loss(
loss, global_step, learning_rate=0.1, optimizer=optimizer)
def testBadSummaries(self):
- with ops.Graph().as_default() as g, self.test_session(graph=g):
+ with ops.Graph().as_default() as g, self.session(graph=g):
_, _, loss, global_step = _setup_model()
with self.assertRaises(ValueError):
optimizers_lib.optimize_loss(
@@ -111,7 +111,7 @@ class OptimizersTest(test.TestCase):
summaries=["loss", "bad_summary"])
def testInvalidLoss(self):
- with ops.Graph().as_default() as g, self.test_session(graph=g):
+ with ops.Graph().as_default() as g, self.session(graph=g):
_, _, _, global_step = _setup_model()
with self.assertRaises(ValueError):
optimizers_lib.optimize_loss(
@@ -121,7 +121,7 @@ class OptimizersTest(test.TestCase):
[[1.0]], global_step, learning_rate=0.1, optimizer="SGD")
def testInvalidGlobalStep(self):
- with ops.Graph().as_default() as g, self.test_session(graph=g):
+ with ops.Graph().as_default() as g, self.session(graph=g):
x = array_ops.placeholder(dtypes.float32, [])
var = variable_scope.get_variable(
"test", [], initializer=init_ops.constant_initializer(10))
@@ -157,7 +157,7 @@ class OptimizersTest(test.TestCase):
optimizer="SGD")
def testInvalidLearningRate(self):
- with ops.Graph().as_default() as g, self.test_session(graph=g):
+ with ops.Graph().as_default() as g, self.session(graph=g):
_, _, loss, global_step = _setup_model()
with self.assertRaises(ValueError):
optimizers_lib.optimize_loss(
@@ -165,7 +165,7 @@ class OptimizersTest(test.TestCase):
def testGradientNoise(self):
random_seed.set_random_seed(42)
- with self.test_session() as session:
+ with self.cached_session() as session:
x, var, loss, global_step = _setup_model()
train = optimizers_lib.optimize_loss(
loss,
@@ -182,7 +182,7 @@ class OptimizersTest(test.TestCase):
def testGradientNoiseWithClipping(self):
random_seed.set_random_seed(42)
- with self.test_session() as session:
+ with self.cached_session() as session:
x, var, loss, global_step = _setup_model()
train = optimizers_lib.optimize_loss(
loss,
@@ -198,7 +198,7 @@ class OptimizersTest(test.TestCase):
self.assertEqual(global_step_value, 1)
def testGradientClip(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
x, var, loss, global_step = _setup_model()
train = optimizers_lib.optimize_loss(
loss,
@@ -213,7 +213,7 @@ class OptimizersTest(test.TestCase):
self.assertEqual(global_step_value, 1)
def testAdaptiveGradientClip(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
x, var, loss, global_step = _setup_model()
clip_gradients = optimizers_lib.adaptive_clipping_fn()
train = optimizers_lib.optimize_loss(
@@ -234,7 +234,7 @@ class OptimizersTest(test.TestCase):
self.assertEqual(2, var_count)
def testGradientMultiply(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
x, var, loss, global_step = _setup_model()
train = optimizers_lib.optimize_loss(
loss,
@@ -270,7 +270,7 @@ class OptimizersTest(test.TestCase):
gradient_descent.GradientDescentOptimizer(learning_rate=0.1)
]
for optimizer in optimizers:
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
x = array_ops.placeholder(dtypes.float32, [])
var = variable_scope.get_variable(
"test", [], initializer=init_ops.constant_initializer(10))
@@ -295,7 +295,7 @@ class OptimizersTest(test.TestCase):
gradient_descent.GradientDescentOptimizer(learning_rate=0.1)
]
for optimizer in optimizers:
- with ops.Graph().as_default() as g, self.test_session(graph=g):
+ with ops.Graph().as_default() as g, self.session(graph=g):
x = array_ops.placeholder(dtypes.float32, [])
var = variable_scope.get_variable(
"test", [], initializer=init_ops.constant_initializer(10))
@@ -319,7 +319,7 @@ class OptimizersTest(test.TestCase):
gradient_descent.GradientDescentOptimizer(learning_rate=0.1)
]
for optimizer in optimizers:
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
x, var, loss, global_step = _setup_model()
update_var = variable_scope.get_variable(
"update", [], initializer=init_ops.constant_initializer(10))
@@ -342,7 +342,7 @@ class OptimizersTest(test.TestCase):
gradient_descent.GradientDescentOptimizer(learning_rate=0.1)
]
for optimizer in optimizers:
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
x, var, loss, global_step = _setup_model()
update_var = variable_scope.get_variable(
"update", [], initializer=init_ops.constant_initializer(10))
@@ -365,7 +365,7 @@ class OptimizersTest(test.TestCase):
gradient_descent.GradientDescentOptimizer(learning_rate=0.1)
]
for optimizer in optimizers:
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
x, var, loss, global_step = _setup_model()
update_var = variable_scope.get_variable(
"update", [], initializer=init_ops.constant_initializer(10))
@@ -389,7 +389,7 @@ class OptimizersTest(test.TestCase):
gradient_descent.GradientDescentOptimizer(learning_rate=0.1)
]
for optimizer in optimizers:
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
x, var, loss, global_step = _setup_model()
update_var = variable_scope.get_variable(
"update", [], initializer=init_ops.constant_initializer(10))
@@ -413,7 +413,7 @@ class OptimizersTest(test.TestCase):
gradient_descent.GradientDescentOptimizer(learning_rate=0.1)
]
for optimizer in optimizers:
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
x, var, loss, global_step = _setup_model()
update_var = variable_scope.get_variable(
"update", [], initializer=init_ops.constant_initializer(10))
@@ -433,7 +433,7 @@ class OptimizersTest(test.TestCase):
class AdaptiveClipping(test.TestCase):
def testAverages(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
scale = 2.
grad = array_ops.ones([3, 4]) * scale
log_norm = np.log(np.sqrt(scale**2 * grad.get_shape().num_elements()))
@@ -463,7 +463,7 @@ class AdaptiveClipping(test.TestCase):
self.assertAlmostEqual(float(sq_mean), log_norm**2, places=4)
def testClip(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
spike = 1000.
multiplier = array_ops.placeholder(dtypes.float32, [], "multiplier")
step = array_ops.placeholder(dtypes.int32, [], "step")
diff --git a/tensorflow/contrib/layers/python/layers/regularizers_test.py b/tensorflow/contrib/layers/python/layers/regularizers_test.py
index 07191eeda7..51faba30c7 100644
--- a/tensorflow/contrib/layers/python/layers/regularizers_test.py
+++ b/tensorflow/contrib/layers/python/layers/regularizers_test.py
@@ -71,7 +71,7 @@ class RegularizerTest(test.TestCase):
with self.assertRaises(ValueError):
regularizers.l1_l2_regularizer(0.5, 0)
- with self.test_session():
+ with self.cached_session():
shape = [5, 5, 5]
num_elem = 5 * 5 * 5
tensor = constant_op.constant(1.0, shape=shape)
@@ -84,7 +84,7 @@ class RegularizerTest(test.TestCase):
num_elem = 5 * 5 * 5
tensor = constant_op.constant(1.0, shape=shape)
loss = regularizers.l1_l2_regularizer(0.0, 1.0)(tensor)
- with self.test_session():
+ with self.cached_session():
self.assertEquals(loss.op.name, 'l1_l2_regularizer')
self.assertAlmostEqual(loss.eval(), num_elem / 2, 5)
@@ -93,7 +93,7 @@ class RegularizerTest(test.TestCase):
num_elem = 5 * 5 * 5
tensor = constant_op.constant(1.0, shape=shape)
loss = regularizers.l1_l2_regularizer(1.0, 0.0)(tensor)
- with self.test_session():
+ with self.cached_session():
self.assertEquals(loss.op.name, 'l1_l2_regularizer')
self.assertAlmostEqual(loss.eval(), num_elem, 5)
@@ -104,7 +104,7 @@ class RegularizerTest(test.TestCase):
self.assertEquals(loss, None)
def testL1L2RegularizerWithScope(self):
- with self.test_session():
+ with self.cached_session():
shape = [5, 5, 5]
num_elem = 5 * 5 * 5
tensor = constant_op.constant(1.0, shape=shape)
@@ -142,7 +142,7 @@ class RegularizerTest(test.TestCase):
array_weights_list = [[1.5], [2, 3, 4.2], [10, 42, 666.6]]
tensor_weights_list = [constant_op.constant(x) for x in array_weights_list]
expected = sum([2 * x for l in array_weights_list for x in l])
- with self.test_session():
+ with self.cached_session():
result = regularizers.apply_regularization(dummy_regularizer,
tensor_weights_list)
self.assertAllClose(expected, result.eval())
@@ -151,7 +151,7 @@ class RegularizerTest(test.TestCase):
regularizer = regularizers.l2_regularizer(0.0)
array_weights_list = [[1.5], [2, 3, 4.2], [10, 42, 666.6]]
tensor_weights_list = [constant_op.constant(x) for x in array_weights_list]
- with self.test_session():
+ with self.cached_session():
result = regularizers.apply_regularization(regularizer,
tensor_weights_list)
self.assertAllClose(0.0, result.eval())
@@ -161,7 +161,7 @@ class RegularizerTest(test.TestCase):
tensor_weights_list = [
constant_op.constant(x) for x in [[1.5], [2, 3, 4.2], [10, 42, 666.6]]
]
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
regularizers.apply_regularization(non_scalar_regularizer,
tensor_weights_list)
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
index dad3da3748..06da32072f 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
@@ -30,6 +30,7 @@ import functools
import re
import numpy as np
+import six
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.framework.python import ops as contrib_framework_ops
@@ -44,6 +45,7 @@ from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
@@ -151,9 +153,19 @@ def _rev_block_forward(x1,
return y1, y2
+def _safe_wraps(fn):
+ if isinstance(fn, functools.partial):
+ # functools.partial objects cannot be wrapped as they are missing the
+ # necessary properties (__name__, __module__, __doc__).
+ def passthrough(f):
+ return f
+ return passthrough
+ return functools.wraps(fn)
+
+
def _scope_wrap(fn, scope):
- @functools.wraps(fn)
+ @_safe_wraps(fn)
def wrap(*args, **kwargs):
with variable_scope.variable_scope(scope, use_resource=True):
return fn(*args, **kwargs)
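To see why _safe_wraps above needs the passthrough: functools.partial objects carry none of the attributes functools.wraps copies, and on Python 2 that raises AttributeError at decoration time (Python 3.2+ silently skips missing attributes). A small illustrative check, not part of the patch:

import functools


def add(a, b):
  return a + b


p = functools.partial(add, 1)
print(hasattr(p, '__name__'))  # False: partials define no __name__
# On Python 2, functools.wraps(p) would raise AttributeError when the
# returned decorator is applied, hence the passthrough in _safe_wraps.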
@@ -430,7 +442,7 @@ def rev_block(x1,
def enable_with_args(dec):
"""A decorator for decorators to enable their usage with or without args."""
- @functools.wraps(dec)
+ @_safe_wraps(dec)
def new_dec(*args, **kwargs):
if len(args) == 1 and not kwargs and callable(args[0]):
# Used as decorator without args
@@ -461,7 +473,8 @@ def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
Args:
fn: a function that takes Tensors (all as positional arguments) and returns
- a tuple of Tensors.
+ a tuple of Tensors. Note that `fn` should not close over any other
+ Tensors or Variables.
use_data_dep: `bool`, if `True` will use a dummy data dependency to force
the recompute to happen. If `False` will use a control dependency. By
default will be `True` if in an XLA context and `False` otherwise. XLA
@@ -475,9 +488,24 @@ def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
A wrapped fn that is identical to fn when called, but its activations will
be discarded and recomputed on the backwards pass (i.e. on a call to
tf.gradients).
- """
- @functools.wraps(fn)
+ Raises:
+ ValueError: if `fn` closes over any Tensors or Variables.
+ """
+ # Check for closed-over Tensors/Variables
+ if fn.__code__.co_freevars:
+ closed_over_vars = dict(zip(fn.__code__.co_freevars,
+ [c.cell_contents for c in fn.__closure__]))
+ for var_name, value in six.iteritems(closed_over_vars):
+ if isinstance(value, (framework_ops.Tensor, variables_lib.Variable)):
+ raise ValueError(
+ "fn decorated with @recompute_grad closes over Tensor %s "
+ "(local variable name: %s). The decorated fn must not close over "
+ "Tensors or Variables because gradients will NOT be computed for "
+ "them through fn. To ensure correct gradients, make the "
+ "Tensor an input to fn." % (value.name, var_name))
+
+ @_safe_wraps(fn)
def wrapped(*args):
return _recompute_grad(
fn, args, use_data_dep=use_data_dep, tupleize_grads=tupleize_grads)
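A hedged usage sketch of the decorator this hunk documents (names are illustrative, not from the patch): every Tensor the function reads must arrive as an argument so the gradient fn can rerun the forward pass.

import tensorflow as tf
from tensorflow.contrib.layers.python.layers import rev_block_lib


@rev_block_lib.recompute_grad
def block(x):
  # No closed-over Tensors: everything the block reads arrives as an
  # argument, so activations can be discarded and rebuilt during backprop.
  return tf.tanh(tf.tanh(x))


x = tf.random_uniform((4, 8))
y = block(x)
dx = tf.gradients(y, x)  # recomputes block(x) before differentiating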
@@ -490,6 +518,62 @@ def _is_on_tpu():
return control_flow_util.GetContainingXLAContext(ctxt) is not None
+def _recomputing_grad_fn(compute_fn,
+ original_args,
+ original_vars,
+ output_grads,
+ grad_fn_variables,
+ use_data_dep,
+ tupleize_grads,
+ arg_scope,
+ var_scope,
+ has_is_recompute_kwarg):
+ """Grad fn for recompute_grad."""
+ variables = grad_fn_variables or []
+
+ # Identity ops around the inputs ensure correct gradient graph-walking.
+ inputs = [array_ops.identity(x) for x in list(original_args)]
+
+ # Recompute outputs
+ # Use a control dependency to ensure that the recompute is not eliminated by
+ # CSE and that it happens on the backwards pass.
+ ctrl_dep_grads = [g for g in output_grads if g is not None]
+ with framework_ops.control_dependencies(ctrl_dep_grads):
+ if use_data_dep:
+ inputs = _force_data_dependency(output_grads, inputs)
+ # Re-enter scopes
+ with contrib_framework_ops.arg_scope(arg_scope):
+ with variable_scope.variable_scope(var_scope, reuse=True):
+ # Re-call the function and ensure that the touched variables are the
+ # same as in the first call.
+ with backprop.GradientTape() as tape:
+ fn_kwargs = {}
+ if has_is_recompute_kwarg:
+ fn_kwargs["is_recomputing"] = True
+ outputs = compute_fn(*inputs, **fn_kwargs)
+ recompute_vars = set(tape.watched_variables())
+ if original_vars != recompute_vars:
+ raise ValueError(_WRONG_VARS_ERR)
+
+ if not isinstance(outputs, (list, tuple)):
+ outputs = [outputs]
+ outputs = list(outputs)
+
+ # Compute gradients
+ grads = gradients_impl.gradients(outputs, inputs + variables,
+ output_grads)
+
+ if tupleize_grads:
+ if use_data_dep:
+ grads = _tuple_with_data_dep(grads)
+ else:
+ grads = control_flow_ops.tuple(grads)
+
+ grad_inputs = grads[:len(inputs)]
+ grad_vars = grads[len(inputs):]
+ return grad_inputs, grad_vars
+
+
def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
"""See recompute_grad."""
has_is_recompute_kwarg = "is_recomputing" in tf_inspect.getargspec(fn).args
@@ -500,12 +584,16 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
if use_data_dep_ == _USE_DEFAULT:
use_data_dep_ = _is_on_tpu()
+ # Use custom_gradient and return a grad_fn that recomputes on the backwards
+ # pass.
@custom_gradient.custom_gradient
def fn_with_recompute(*args):
"""Wrapper for fn."""
- # Forward pass
+ # Capture the variable and arg scopes so we can re-enter them when
+ # recomputing.
vs = variable_scope.get_variable_scope()
arg_scope = contrib_framework_ops.current_arg_scope()
+ # Track all variables touched in the function.
with backprop.GradientTape() as tape:
fn_kwargs = {}
if has_is_recompute_kwarg:
@@ -513,46 +601,25 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
outputs = fn(*args, **fn_kwargs)
original_vars = set(tape.watched_variables())
- # Backward pass
def _grad_fn(output_grads, variables=None):
- """Recompute outputs for gradient computation."""
- variables = variables or []
+ # Validate that custom_gradient passes the right variables into grad_fn.
if original_vars:
assert variables, ("Fn created variables but the variables were not "
"passed to the gradient fn.")
if set(variables) != original_vars:
raise ValueError(_WRONG_VARS_ERR)
- inputs = [array_ops.identity(x) for x in list(args)]
- # Recompute outputs
- with framework_ops.control_dependencies(output_grads):
- if use_data_dep_:
- inputs = _force_data_dependency(output_grads, inputs)
- with contrib_framework_ops.arg_scope(arg_scope):
- with variable_scope.variable_scope(vs, reuse=True):
- with backprop.GradientTape() as tape:
- fn_kwargs = {}
- if has_is_recompute_kwarg:
- fn_kwargs["is_recomputing"] = True
- outputs = fn(*inputs, **fn_kwargs)
- recompute_vars = set(tape.watched_variables())
- if original_vars != recompute_vars:
- raise ValueError(_WRONG_VARS_ERR)
-
- if not isinstance(outputs, (list, tuple)):
- outputs = [outputs]
- outputs = list(outputs)
- grads = gradients_impl.gradients(outputs, inputs + variables,
- output_grads)
-
- if tupleize_grads:
- if use_data_dep_:
- grads = _tuple_with_data_dep(grads)
- else:
- grads = control_flow_ops.tuple(grads)
- grad_inputs = grads[:len(inputs)]
- grad_vars = grads[len(inputs):]
- return grad_inputs, grad_vars
+ return _recomputing_grad_fn(
+ compute_fn=fn,
+ original_args=args,
+ original_vars=original_vars,
+ output_grads=output_grads,
+ grad_fn_variables=variables,
+ use_data_dep=use_data_dep_,
+ tupleize_grads=tupleize_grads,
+ arg_scope=arg_scope,
+ var_scope=vs,
+ has_is_recompute_kwarg=has_is_recompute_kwarg)
# custom_gradient inspects the signature of the function to determine
# whether the user expects variables passed in the grad_fn. If the function
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
index d5971fb9d8..2c7463acc0 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
@@ -58,7 +58,7 @@ class RevBlockTest(test.TestCase):
y1, y2 = block.forward(x1, x2)
x1_inv, x2_inv = block.backward(y1, y2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
x1, x2, x1_inv, x2_inv = sess.run([x1, x2, x1_inv, x2_inv])
@@ -81,7 +81,7 @@ class RevBlockTest(test.TestCase):
x1, x2 = block.backward(y1, y2)
y1_inv, y2_inv = block.forward(x1, x2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
y1, y2, y1_inv, y2_inv = sess.run([y1, y2, y1_inv, y2_inv])
@@ -151,7 +151,7 @@ class RevBlockTest(test.TestCase):
grads_rev = gradients_impl.gradients(loss_rev, wrt)
grads = gradients_impl.gradients(loss, wrt)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
y_val, yd_val, gd_val, g_val = sess.run([y, y_rev, grads_rev, grads])
self.assertAllClose(y_val, yd_val)
@@ -286,7 +286,7 @@ class RecomputeTest(test.TestCase):
for out, scope_vars in outputs_and_vars:
all_grads.append(gradients_impl.gradients(out, scope_vars))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
outputs = list(zip(*outputs_and_vars))[0]
outs, all_grads_val = sess.run([outputs, all_grads])
@@ -389,9 +389,19 @@ class RecomputeTest(test.TestCase):
layer_list.append(math_ops.sqrt(concat_n_wrap(*layer_list)))
grads = gradients_impl.gradients(layer_list[-1], layer_list[0])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(grads)
+ def testErrorOnClosedOverTensor(self):
+ x = random_ops.random_uniform((4, 8))
+ y = random_ops.random_uniform((4, 8))
+ z = x * y
+
+ with self.assertRaisesWithPredicateMatch(ValueError, "closes over"):
+ @rev_block_lib.recompute_grad
+ def fn_with_capture(a): # pylint: disable=unused-variable
+ return a * z
+
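As a companion to the new test above (hypothetical code, not in the patch), the fix the error message suggests is to pass the captured tensor in explicitly:

@rev_block_lib.recompute_grad
def fn_without_capture(a, z):  # z is now an explicit input
  return a * z

out = fn_without_capture(x, z)  # gradients now flow into z as well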
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/layers/python/layers/summaries_test.py b/tensorflow/contrib/layers/python/layers/summaries_test.py
index a1ef06feec..2ec2af9d44 100644
--- a/tensorflow/contrib/layers/python/layers/summaries_test.py
+++ b/tensorflow/contrib/layers/python/layers/summaries_test.py
@@ -29,19 +29,19 @@ from tensorflow.python.platform import test
class SummariesTest(test.TestCase):
def test_summarize_scalar_tensor(self):
- with self.test_session():
+ with self.cached_session():
scalar_var = variables.Variable(1)
summary_op = summaries_lib.summarize_tensor(scalar_var)
self.assertEquals(summary_op.op.type, 'ScalarSummary')
def test_summarize_multidim_tensor(self):
- with self.test_session():
+ with self.cached_session():
tensor_var = variables.Variable([1, 2, 3])
summary_op = summaries_lib.summarize_tensor(tensor_var)
self.assertEquals(summary_op.op.type, 'HistogramSummary')
def test_summarize_activation(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(1)
op = array_ops.identity(var, name='SummaryTest')
summary_op = summaries_lib.summarize_activation(op)
@@ -52,7 +52,7 @@ class SummariesTest(test.TestCase):
self.assertIn(u'SummaryTest/activation', names)
def test_summarize_activation_relu(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(1)
op = nn_ops.relu(var, name='SummaryTest')
summary_op = summaries_lib.summarize_activation(op)
@@ -64,7 +64,7 @@ class SummariesTest(test.TestCase):
self.assertIn(u'SummaryTest/activation', names)
def test_summarize_activation_relu6(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(1)
op = nn_ops.relu6(var, name='SummaryTest')
summary_op = summaries_lib.summarize_activation(op)
@@ -77,7 +77,7 @@ class SummariesTest(test.TestCase):
self.assertIn(u'SummaryTest/activation', names)
def test_summarize_collection_regex(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(1)
array_ops.identity(var, name='Test1')
ops.add_to_collection('foo', array_ops.identity(var, name='Test2'))
diff --git a/tensorflow/contrib/layers/python/layers/utils_test.py b/tensorflow/contrib/layers/python/layers/utils_test.py
index 645dc1291e..34f63f5d86 100644
--- a/tensorflow/contrib/layers/python/layers/utils_test.py
+++ b/tensorflow/contrib/layers/python/layers/utils_test.py
@@ -42,12 +42,12 @@ class ConstantValueTest(test.TestCase):
c = constant_op.constant(v)
value = utils.constant_value(c)
self.assertEqual(value, v)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(c.eval(), v)
def test_variable(self):
for v in [True, False, 1, 0, 1.0]:
- with ops.Graph().as_default() as g, self.test_session(g) as sess:
+ with ops.Graph().as_default() as g, self.session(g) as sess:
x = variables.Variable(v)
value = utils.constant_value(x)
self.assertEqual(value, None)
@@ -60,7 +60,7 @@ class ConstantValueTest(test.TestCase):
x = array_ops.identity(p)
value = utils.constant_value(p)
self.assertEqual(value, None)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(x.eval(feed_dict={p: v}), v)
@@ -80,7 +80,7 @@ class StaticCondTest(test.TestCase):
expected = lambda v: b'fn1' if v else b'fn2'
for v in [True, False, 1, 0]:
o = utils.static_cond(v, fn1, fn2)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(o.eval(), expected(v))
def test_variable(self):
@@ -89,7 +89,7 @@ class StaticCondTest(test.TestCase):
expected = lambda v: b'fn1' if v else b'fn2'
for v in [True, False, 1, 0]:
o = utils.static_cond(v, fn1, fn2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
self.assertEqual(o.eval(), expected(v))
@@ -99,7 +99,7 @@ class StaticCondTest(test.TestCase):
expected = lambda v: -1 if v else -2
for v in [True, False, 1, 0]:
o = utils.static_cond(v, fn1, fn2)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(o.eval(), expected(v))
@@ -119,7 +119,7 @@ class SmartCondStaticTest(test.TestCase):
expected = lambda v: b'fn1' if v else b'fn2'
for v in [True, False, 1, 0]:
o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(o.eval(), expected(v))
def test_variable(self):
@@ -128,7 +128,7 @@ class SmartCondStaticTest(test.TestCase):
expected = lambda v: b'fn1' if v else b'fn2'
for v in [True, False, 1, 0]:
o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
self.assertEqual(o.eval(), expected(v))
@@ -138,7 +138,7 @@ class SmartCondStaticTest(test.TestCase):
expected = lambda v: -1 if v else -2
for v in [True, False, 1, 0]:
o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(o.eval(), expected(v))
@@ -151,7 +151,7 @@ class SmartCondDynamicTest(test.TestCase):
p = array_ops.placeholder(dtypes.bool, [])
for v in [True, False, 1, 0]:
o = utils.smart_cond(p, fn1, fn2)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
def test_constant(self):
@@ -161,7 +161,7 @@ class SmartCondDynamicTest(test.TestCase):
p = array_ops.placeholder(dtypes.bool, [])
for v in [True, False, 1, 0]:
o = utils.smart_cond(p, fn1, fn2)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
def test_variable(self):
@@ -171,7 +171,7 @@ class SmartCondDynamicTest(test.TestCase):
p = array_ops.placeholder(dtypes.bool, [])
for v in [True, False, 1, 0]:
o = utils.smart_cond(p, fn1, fn2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
@@ -182,7 +182,7 @@ class SmartCondDynamicTest(test.TestCase):
p = array_ops.placeholder(dtypes.bool, [])
for v in [True, False, 1, 0]:
o = utils.smart_cond(p, fn1, fn2)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
diff --git a/tensorflow/contrib/layers/python/ops/sparse_ops_test.py b/tensorflow/contrib/layers/python/ops/sparse_ops_test.py
index d50750001e..b6c2cab64a 100644
--- a/tensorflow/contrib/layers/python/ops/sparse_ops_test.py
+++ b/tensorflow/contrib/layers/python/ops/sparse_ops_test.py
@@ -42,7 +42,7 @@ def _assert_sparse_tensor_value(test_case, expected, actual):
class DenseToSparseTensorTest(test.TestCase):
def test_dense_to_sparse_tensor_1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor([1, 0, 2, 0])
result = sess.run(st)
self.assertEqual(result.indices.dtype, np.int64)
@@ -53,7 +53,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_1d_float(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor([1.5, 0.0, 2.3, 0.0])
result = sess.run(st)
self.assertEqual(result.indices.dtype, np.int64)
@@ -64,7 +64,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_1d_bool(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor([True, False, True, False])
result = sess.run(st)
self.assertEqual(result.indices.dtype, np.int64)
@@ -75,7 +75,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_1d_str(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor([b'qwe', b'', b'ewq', b''])
result = sess.run(st)
self.assertEqual(result.indices.dtype, np.int64)
@@ -86,7 +86,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_1d_str_special_ignore(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor(
[b'qwe', b'', b'ewq', b''], ignore_value=b'qwe')
result = sess.run(st)
@@ -98,7 +98,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor([[1, 2, 0, 0], [3, 4, 5, 0]])
result = sess.run(st)
self.assertAllEqual([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]],
@@ -107,7 +107,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([2, 4], result.dense_shape)
def test_dense_to_sparse_tensor_3d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor([[[1, 2, 0, 0], [3, 4, 5, 0]],
[[7, 8, 0, 0], [9, 0, 0, 0]]])
result = sess.run(st)
@@ -117,7 +117,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([2, 2, 4], result.dense_shape)
def test_dense_to_sparse_tensor_unknown_1d_shape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tensor = array_ops.placeholder(shape=[None], dtype=dtypes.int32)
st = sparse_ops.dense_to_sparse_tensor(tensor)
result = sess.run(st, feed_dict={tensor: [0, 100, 0, 3]})
@@ -126,7 +126,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_unknown_3d_shape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tensor = array_ops.placeholder(
shape=[None, None, None], dtype=dtypes.int32)
st = sparse_ops.dense_to_sparse_tensor(tensor)
@@ -142,7 +142,7 @@ class DenseToSparseTensorTest(test.TestCase):
def test_dense_to_sparse_unknown_rank(self):
ph = array_ops.placeholder(dtype=dtypes.int32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor(ph)
result = sess.run(st, feed_dict={ph: [[1, 2, 0, 0], [3, 4, 5, 0]]})
self.assertAllEqual([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]],
@@ -155,7 +155,7 @@ class SparseRowEnvelopeTest(test.TestCase):
def test_sparse_row_envelope(self):
expected_sparse_row_envelope = [1, 0, 3]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sparse_input = sparse_tensor.SparseTensor(
indices=[[0, 0], [2, 0], [2, 1], [2, 2]],
values=[0, 1, 2, 3],
@@ -167,7 +167,7 @@ class SparseRowEnvelopeTest(test.TestCase):
def test_sparse_row_envelope_unsorted_indices(self):
expected_sparse_row_envelope = [1, 0, 3]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sparse_input = sparse_tensor.SparseTensor(
indices=[[2, 0], [2, 2], [2, 1], [0, 0]],
values=[0, 1, 2, 3],
@@ -179,7 +179,7 @@ class SparseRowEnvelopeTest(test.TestCase):
def test_sparse_row_envelope_empty_in_the_end(self):
expected_sparse_row_envelope = [1, 0, 3, 0, 0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sparse_input = sparse_tensor.SparseTensor(
indices=[[0, 0], [2, 0], [2, 1], [2, 2]],
values=[0, 1, 2, 3],
@@ -191,7 +191,7 @@ class SparseRowEnvelopeTest(test.TestCase):
def test_sparse_row_envelope_empty_3d(self):
expected_sparse_row_envelope = [1, 0, 3, 0, 0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sparse_input = sparse_tensor.SparseTensor(
indices=[[0, 0, 0], [0, 2, 0], [0, 2, 1], [0, 2, 2]],
values=[0, 1, 2, 3],
@@ -207,7 +207,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
def test_indicators_to_sparse_ids_1d(self):
indicators = (0, 0, 1, 0)
sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0,),),
values=(2,),
@@ -220,7 +220,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
(1, 0, 0, 1),
)
sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=(2, 0, 3),
@@ -235,7 +235,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
((1, 0, 0, 1, 1), (0, 0, 1, 0, 0)),
)
sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=(
(0, 0, 0),
@@ -255,7 +255,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
)
sparse_ids = sparse_ops.indicators_to_sparse_ids(
indicators, dtype=dtypes.int16)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=np.array((2, 0, 3), dtype=np.int16),
@@ -269,7 +269,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
)
sparse_ids = sparse_ops.indicators_to_sparse_ids(
indicators, ignore_value=-1)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
values=(2, 0, 3, 2),
@@ -282,7 +282,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
(('B', '', '', 'C'), ('', '', 'D', '')),
)
sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
values=(2, 0, 3, 2),
@@ -296,7 +296,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
)
sparse_ids = sparse_ops.indicators_to_sparse_ids(
indicators, ignore_value='x')
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
values=(2, 0, 3, 2),
@@ -311,7 +311,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
indicators = array_ops.placeholder(
dtype=dtypes.int32, shape=(None, None, None))
sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
values=(2, 0, 3, 2),
@@ -325,7 +325,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
)
indicators = array_ops.placeholder(dtype=dtypes.int32)
sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
values=(2, 0, 3, 2),
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 418b0cf392..61185f65a9 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -403,6 +403,7 @@ py_test(
srcs = ["python/learn/estimators/dnn_test.py"],
shard_count = 4,
srcs_version = "PY2AND3",
+ tags = ["notap"],
deps = [
":learn",
"//tensorflow/contrib/layers:layers_py",
diff --git a/tensorflow/contrib/learn/__init__.py b/tensorflow/contrib/learn/__init__.py
index 79bd73faaf..28a6f5aed9 100644
--- a/tensorflow/contrib/learn/__init__.py
+++ b/tensorflow/contrib/learn/__init__.py
@@ -19,7 +19,8 @@ This module and all its submodules are deprecated. See
[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
for migration instructions.
-See the @{$python/contrib.learn} guide.
+See the [Contrib Learn](https://tensorflow.org/api_guides/python/contrib.learn)
+guide.
@@BaseEstimator
@@Estimator
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
index c9a11f27f1..1d8a59281a 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
@@ -155,7 +155,7 @@ class DynamicRnnEstimatorTest(test.TestCase):
sequence_input = dynamic_rnn_estimator.build_sequence_input(
self.GetColumnsToTensors(), self.sequence_feature_columns,
self.context_feature_columns)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(lookup_ops.tables_initializer())
sequence_input_val = sess.run(sequence_input)
@@ -330,7 +330,7 @@ class DynamicRnnEstimatorTest(test.TestCase):
actual_state = dynamic_rnn_estimator.dict_to_state_tuple(state_dict, cell)
flattened_state = dynamic_rnn_estimator.state_tuple_to_dict(actual_state)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
(state_dict_val, actual_state_val, flattened_state_val) = sess.run(
[state_dict, actual_state, flattened_state])
diff --git a/tensorflow/contrib/learn/python/learn/estimators/rnn_common_test.py b/tensorflow/contrib/learn/python/learn/estimators/rnn_common_test.py
index 82563141cc..ebf5f5617d 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/rnn_common_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/rnn_common_test.py
@@ -44,7 +44,7 @@ class RnnCommonTest(test.TestCase):
constant_op.constant(labels, dtype=dtypes.int32),
constant_op.constant(sequence_length, dtype=dtypes.int32))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
activations_masked, labels_masked = sess.run(
[activations_masked_t, labels_masked_t])
diff --git a/tensorflow/contrib/learn/python/learn/estimators/stability_test.py b/tensorflow/contrib/learn/python/learn/estimators/stability_test.py
index 6d04543819..81376c0e2a 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/stability_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/stability_test.py
@@ -68,12 +68,12 @@ class StabilityTest(test.TestCase):
minval = -0.3333
maxval = 0.3333
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
g.seed = my_seed
x = random_ops.random_uniform([10, 10], minval=minval, maxval=maxval)
val1 = session.run(x)
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
g.seed = my_seed
x = random_ops.random_uniform([10, 10], minval=minval, maxval=maxval)
val2 = session.run(x)
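The two replacement patterns in this patch differ deliberately: `cached_session()` returns a session that is cached and reused for the test's default graph, while tests like the stability test above build several `ops.Graph()` objects and therefore need `self.session(graph=g)` to obtain a session bound to each specific graph. A minimal sketch of the distinction, assuming the TF 1.x `tf.test.TestCase` helpers (hypothetical test, not part of this patch):

```python
import tensorflow as tf

class SessionHelpersSketch(tf.test.TestCase):

  def test_cached_vs_graph_session(self):
    # cached_session() reuses one session for the test's default graph,
    # so repeated calls within a test method are cheap.
    with self.cached_session() as sess:
      x = tf.constant(1.0)
      self.assertEqual(1.0, sess.run(x))

    # A test that constructs its own Graph needs a session bound to that
    # graph, hence self.session(graph=g) instead of cached_session().
    with tf.Graph().as_default() as g:
      y = tf.constant(2.0)
      with self.session(graph=g) as sess:
        self.assertEqual(2.0, sess.run(y))
```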
diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
index 442247409d..06c61554fa 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
@@ -53,7 +53,7 @@ class PrepareInputsForRnnTest(test.TestCase):
sequence_feature_columns,
num_unroll)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(lookup_ops.tables_initializer())
features_val = sess.run(features_by_time)
@@ -314,7 +314,7 @@ class StateSavingRnnEstimatorTest(test.TestCase):
else:
self.assertAllEqual(v, got[k])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(lookup_ops.tables_initializer())
actual_sequence, actual_context = sess.run(
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/graph_actions_test.py
index df156da3f4..d5c02124ac 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions_test.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions_test.py
@@ -175,7 +175,7 @@ class GraphActionsTest(test.TestCase):
return in0, in1, out
def test_infer(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._assert_ckpt(self._output_dir, False)
in0, in1, out = self._build_inference_graph()
self.assertEqual({
@@ -193,7 +193,7 @@ class GraphActionsTest(test.TestCase):
side_effect=learn.graph_actions.coordinator.Coordinator.request_stop,
autospec=True)
def test_coordinator_request_stop_called(self, request_stop):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
in0, in1, out = self._build_inference_graph()
learn.graph_actions.infer(None, {'a': in0, 'b': in1, 'c': out})
self.assertTrue(request_stop.called)
@@ -204,7 +204,7 @@ class GraphActionsTest(test.TestCase):
side_effect=learn.graph_actions.coordinator.Coordinator.request_stop,
autospec=True)
def test_run_feeds_iter_cleanup_with_exceptions(self, request_stop):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
in0, in1, out = self._build_inference_graph()
try:
for _ in learn.graph_actions.run_feeds_iter({
@@ -249,7 +249,7 @@ class GraphActionsTest(test.TestCase):
self._assert_ckpt(self._output_dir, False)
def test_infer_invalid_feed(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._assert_ckpt(self._output_dir, False)
in0, _, _ = self._build_inference_graph()
with self.assertRaisesRegexp(TypeError, 'Can not convert a NoneType'):
@@ -257,7 +257,7 @@ class GraphActionsTest(test.TestCase):
self._assert_ckpt(self._output_dir, False)
def test_infer_feed(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._assert_ckpt(self._output_dir, False)
in0, _, out = self._build_inference_graph()
self.assertEqual(
@@ -271,7 +271,7 @@ class GraphActionsTest(test.TestCase):
# TODO(ptucker): Test eval for 1 epoch.
def test_evaluate_invalid_args(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._assert_ckpt(self._output_dir, False)
with self.assertRaisesRegexp(ValueError, 'utput directory'):
learn.graph_actions.evaluate(
@@ -288,7 +288,7 @@ class GraphActionsTest(test.TestCase):
self._assert_ckpt(self._output_dir, False)
def test_evaluate(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
_, _, out = self._build_inference_graph()
writer = learn.graph_actions.get_summary_writer(self._output_dir)
self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
@@ -310,7 +310,7 @@ class GraphActionsTest(test.TestCase):
self._assert_ckpt(self._output_dir, False)
def test_evaluate_ready_for_local_init(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
variables_lib.create_global_step()
v = variables.Variable(1.0)
variables.Variable(
@@ -327,7 +327,7 @@ class GraphActionsTest(test.TestCase):
max_steps=1)
def test_evaluate_feed_fn(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
in0, _, out = self._build_inference_graph()
writer = learn.graph_actions.get_summary_writer(self._output_dir)
self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
@@ -352,7 +352,7 @@ class GraphActionsTest(test.TestCase):
self._assert_ckpt(self._output_dir, False)
def test_evaluate_feed_fn_with_exhaustion(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
in0, _, out = self._build_inference_graph()
writer = learn.graph_actions.get_summary_writer(self._output_dir)
self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
@@ -375,7 +375,7 @@ class GraphActionsTest(test.TestCase):
expected_session_logs=[])
def test_evaluate_with_saver(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
_, _, out = self._build_inference_graph()
ops.add_to_collection(ops.GraphKeys.SAVERS, saver_lib.Saver())
writer = learn.graph_actions.get_summary_writer(self._output_dir)
@@ -469,7 +469,7 @@ class GraphActionsTrainTest(test.TestCase):
return in0, in1, out
def test_train_invalid_args(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
train_op = constant_op.constant(1.0)
loss_op = constant_op.constant(2.0)
with self.assertRaisesRegexp(ValueError, 'utput directory'):
@@ -503,7 +503,7 @@ class GraphActionsTrainTest(test.TestCase):
# TODO(ptucker): Mock supervisor, and assert all interactions.
def test_train(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
self._assert_summaries(self._output_dir)
@@ -522,7 +522,7 @@ class GraphActionsTrainTest(test.TestCase):
self._assert_ckpt(self._output_dir, True)
def test_train_steps_is_incremental(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
learn.graph_actions.train(
@@ -535,7 +535,7 @@ class GraphActionsTrainTest(test.TestCase):
self._output_dir, variables_lib.get_global_step().name)
self.assertEqual(10, step)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
learn.graph_actions.train(
@@ -549,7 +549,7 @@ class GraphActionsTrainTest(test.TestCase):
self.assertEqual(25, step)
def test_train_max_steps_is_not_incremental(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
learn.graph_actions.train(
@@ -562,7 +562,7 @@ class GraphActionsTrainTest(test.TestCase):
self._output_dir, variables_lib.get_global_step().name)
self.assertEqual(10, step)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
learn.graph_actions.train(
@@ -576,7 +576,7 @@ class GraphActionsTrainTest(test.TestCase):
self.assertEqual(15, step)
def test_train_loss(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
variables_lib.create_global_step()
loss_var = variables_lib.local_variable(10.0)
train_op = control_flow_ops.group(
@@ -598,7 +598,7 @@ class GraphActionsTrainTest(test.TestCase):
self._assert_ckpt(self._output_dir, True)
def test_train_summaries(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
loss_op = constant_op.constant(2.0)
@@ -624,7 +624,7 @@ class GraphActionsTrainTest(test.TestCase):
self._assert_ckpt(self._output_dir, True)
def test_train_chief_monitor(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
loss_op = constant_op.constant(2.0)
@@ -663,7 +663,7 @@ class GraphActionsTrainTest(test.TestCase):
# and the other chief exclusive.
chief_exclusive_monitor = _BaseMonitorWrapper(False)
all_workers_monitor = _BaseMonitorWrapper(True)
- with self.test_session(g):
+ with self.session(g):
loss = learn.graph_actions.train(
g,
output_dir=self._output_dir,
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
index 1f439965da..284a4f45f6 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
@@ -58,7 +58,7 @@ class DataFeederTest(test.TestCase):
self.assertEqual(expected_np_dtype, v)
else:
self.assertEqual(expected_np_dtype, feeder.input_dtype)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
inp, _ = feeder.input_builder()
if isinstance(inp, dict):
for v in list(inp.values()):
@@ -147,7 +147,7 @@ class DataFeederTest(test.TestCase):
def test_unsupervised(self):
def func(feeder):
- with self.test_session():
+ with self.cached_session():
inp, _ = feeder.input_builder()
feed_dict_fn = feeder.get_feed_dict_fn()
feed_dict = feed_dict_fn()
@@ -181,7 +181,7 @@ class DataFeederTest(test.TestCase):
def test_epoch(self):
def func(feeder):
- with self.test_session():
+ with self.cached_session():
feeder.input_builder()
epoch = feeder.make_epoch_variable()
feed_dict_fn = feeder.get_feed_dict_fn()
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py
index 7e81f2b7d9..5e90d1fa20 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py
@@ -38,7 +38,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones(1) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator,
target_key='label',
@@ -68,7 +68,7 @@ class GeneratorIoTest(test.TestCase):
for index in range(2):
yield {'a': np.ones(1) * index}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
features = input_fn()
@@ -97,7 +97,7 @@ class GeneratorIoTest(test.TestCase):
'label2': np.ones(1) * index - 64,
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator,
target_key=['label', 'label2'],
@@ -134,7 +134,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones((3, 3)) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator,
target_key='label',
@@ -162,7 +162,7 @@ class GeneratorIoTest(test.TestCase):
def testGeneratorInputFnWithXAsNonGeneratorFunction(self):
x = np.arange(32, 36)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'x must be generator function'):
failing_input_fn = generator_io.generator_input_fn(
x, batch_size=2, shuffle=False, num_epochs=1)
@@ -173,7 +173,7 @@ class GeneratorIoTest(test.TestCase):
def generator():
return np.arange(32, 36)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'x\(\) must be generator'):
failing_input_fn = generator_io.generator_input_fn(
generator, batch_size=2, shuffle=False, num_epochs=1)
@@ -184,7 +184,7 @@ class GeneratorIoTest(test.TestCase):
def generator():
yield np.arange(32, 36)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'x\(\) must yield dict'):
failing_input_fn = generator_io.generator_input_fn(
generator, batch_size=2, shuffle=False, num_epochs=1)
@@ -201,7 +201,7 @@ class GeneratorIoTest(test.TestCase):
}
y = np.arange(32, 36)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'target_key must be str or'
' Container of str'):
failing_input_fn = generator_io.generator_input_fn(
@@ -219,7 +219,7 @@ class GeneratorIoTest(test.TestCase):
}
y = ['label', np.arange(10)]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'target_key must be str or'
' Container of str'):
failing_input_fn = generator_io.generator_input_fn(
@@ -237,7 +237,7 @@ class GeneratorIoTest(test.TestCase):
}
y = ['label', 'target']
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(KeyError, 'target_key not in yielded dict'):
failing_input_fn = generator_io.generator_input_fn(
generator, target_key=y, batch_size=2, shuffle=False, num_epochs=1)
@@ -253,7 +253,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones(1) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
features = input_fn()
@@ -283,7 +283,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones(1) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=4, shuffle=False, num_epochs=1)
features = input_fn()
@@ -319,7 +319,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones(1) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
features = input_fn()
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py
index e11e8b698a..8e68a17e47 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py
@@ -207,7 +207,7 @@ class GraphIOTest(test.TestCase):
parsing_ops.FixedLenFeature(shape=shape, dtype=dtypes_lib.float32)
}
- with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
+ with ops.Graph().as_default() as g, self.session(graph=g) as sess:
features = graph_io.read_batch_record_features(
_VALID_FILE_PATTERN,
batch_size,
@@ -242,7 +242,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 1234
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
+ with ops.Graph().as_default() as g, self.session(graph=g) as sess:
inputs = graph_io.read_batch_examples(
_VALID_FILE_PATTERN,
batch_size,
@@ -276,7 +276,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 1234
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
+ with ops.Graph().as_default() as g, self.session(graph=g) as sess:
inputs = graph_io.read_batch_examples(
[_VALID_FILE_PATTERN, _VALID_FILE_PATTERN_2],
batch_size,
@@ -325,7 +325,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 5
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
inputs = graph_io.read_batch_examples(
filename,
batch_size,
@@ -374,7 +374,7 @@ class GraphIOTest(test.TestCase):
features = {"sequence": parsing_ops.FixedLenFeature([], dtypes_lib.string)}
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
keys, result = graph_io.read_keyed_batch_features(
filename,
batch_size,
@@ -429,7 +429,7 @@ class GraphIOTest(test.TestCase):
features = {"sequence": parsing_ops.FixedLenFeature([], dtypes_lib.string)}
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
result = graph_io.read_batch_features(
filename,
batch_size,
@@ -475,7 +475,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 5
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
inputs = graph_io.read_batch_examples(
filenames,
batch_size,
@@ -519,7 +519,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 5
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
keys, inputs = graph_io.read_keyed_batch_examples_shared_queue(
filenames,
batch_size,
@@ -640,7 +640,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 10
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
inputs = graph_io.read_batch_examples(
[filename],
batch_size,
@@ -672,7 +672,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 5
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
keys, inputs = graph_io.read_keyed_batch_examples(
filename,
batch_size,
@@ -714,7 +714,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 5
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
dtypes = {"age": parsing_ops.FixedLenFeature([1], dtypes_lib.int64)}
parse_fn = lambda example: parsing_ops.parse_single_example( # pylint: disable=g-long-lambda
parsing_ops.decode_json_example(example), dtypes)
@@ -773,7 +773,7 @@ class GraphIOTest(test.TestCase):
examples = parsing_ops.parse_example(serialized, features)
return math_ops.less(examples["age"], 2)
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
keys, inputs = graph_io._read_keyed_batch_examples_helper(
filename,
batch_size,
@@ -812,7 +812,7 @@ class GraphIOTest(test.TestCase):
coord.join(threads)
def test_queue_parsed_features_single_tensor(self):
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
features = {"test": constant_op.constant([1, 2, 3])}
_, queued_features = graph_io.queue_parsed_features(features)
coord = coordinator.Coordinator()
@@ -833,7 +833,7 @@ class GraphIOTest(test.TestCase):
_, queued_feature = graph_io.read_keyed_batch_features_shared_queue(
_VALID_FILE_PATTERN, batch_size, feature, reader)
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
features_result = graph_io.read_batch_features(
_VALID_FILE_PATTERN, batch_size, feature, reader)
session.run(variables.local_variables_initializer())
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py
index c738f0e8f3..396539a76a 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py
@@ -65,7 +65,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ProducesExpectedOutputs(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -79,7 +79,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
index = np.arange(100, 102)
a = np.arange(2)
b = np.arange(32, 34)
@@ -107,7 +107,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ProducesOutputsWhenDataSizeNotDividedByBatchSize(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
index = np.arange(100, 105)
a = np.arange(5)
b = np.arange(32, 37)
@@ -146,7 +146,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_OnlyX(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, _ = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y=None, batch_size=2, shuffle=False, num_epochs=1)
@@ -159,7 +159,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ExcludesIndex(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -182,7 +182,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_RespectsEpoch_NoShuffle(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=4, shuffle=False, num_epochs=1)
@@ -192,7 +192,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_RespectsEpoch_WithShuffle(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=4, shuffle=True, num_epochs=1)
@@ -202,7 +202,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_RespectsEpoch_WithShuffleAutosize(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=True, queue_capacity=None, num_epochs=2)
@@ -213,7 +213,7 @@ class PandasIoTest(test.TestCase):
if not HAS_PANDAS:
return
x, y = self.makeTestDataFrame()
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=3, shuffle=False, num_epochs=1)
diff --git a/tensorflow/contrib/learn/python/learn/monitors_test.py b/tensorflow/contrib/learn/python/learn/monitors_test.py
index ff1da32c21..83e48a36e7 100644
--- a/tensorflow/contrib/learn/python/learn/monitors_test.py
+++ b/tensorflow/contrib/learn/python/learn/monitors_test.py
@@ -127,12 +127,12 @@ class MonitorsTest(test.TestCase):
monitor.end()
def test_base_monitor(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(learn.monitors.BaseMonitor())
def test_every_0(self):
monitor = _MyEveryN(every_n_steps=0, first_n_steps=-1)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10)
expected_steps = list(range(30))
self.assertAllEqual(expected_steps, monitor.steps_begun)
@@ -141,7 +141,7 @@ class MonitorsTest(test.TestCase):
def test_every_1(self):
monitor = _MyEveryN(every_n_steps=1, first_n_steps=-1)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10)
expected_steps = list(range(1, 30))
self.assertEqual(expected_steps, monitor.steps_begun)
@@ -150,7 +150,7 @@ class MonitorsTest(test.TestCase):
def test_every_2(self):
monitor = _MyEveryN(every_n_steps=2, first_n_steps=-1)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10)
expected_steps = list(range(2, 29, 2)) + [29]
self.assertEqual(expected_steps, monitor.steps_begun)
@@ -159,7 +159,7 @@ class MonitorsTest(test.TestCase):
def test_every_8(self):
monitor = _MyEveryN(every_n_steps=8, first_n_steps=2)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10)
expected_steps = [0, 1, 2, 10, 18, 26, 29]
self.assertEqual(expected_steps, monitor.steps_begun)
@@ -168,7 +168,7 @@ class MonitorsTest(test.TestCase):
def test_every_8_no_max_steps(self):
monitor = _MyEveryN(every_n_steps=8, first_n_steps=2)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(
monitor, num_epochs=3, num_steps_per_epoch=10, pass_max_steps=False)
begin_end_steps = [0, 1, 2, 10, 18, 26]
@@ -179,7 +179,7 @@ class MonitorsTest(test.TestCase):
def test_every_8_recovered_after_step_begin(self):
monitor = _MyEveryN(every_n_steps=8)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
for step in [8, 16]:
monitor.step_begin(step)
monitor.step_begin(step)
@@ -192,7 +192,7 @@ class MonitorsTest(test.TestCase):
def test_every_8_recovered_after_step_end(self):
monitor = _MyEveryN(every_n_steps=8)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
for step in [8, 16]:
monitor.step_begin(step)
monitor.step_end(step, output=None)
@@ -207,7 +207,7 @@ class MonitorsTest(test.TestCase):
def test_every_8_call_post_step_at_the_end(self):
monitor = _MyEveryN(every_n_steps=8)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
monitor.begin()
for step in [8, 16]:
monitor.step_begin(step)
@@ -224,7 +224,7 @@ class MonitorsTest(test.TestCase):
def test_every_8_call_post_step_should_not_be_called_twice(self):
monitor = _MyEveryN(every_n_steps=8)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
monitor.begin()
for step in [8, 16]:
monitor.step_begin(step)
@@ -240,13 +240,13 @@ class MonitorsTest(test.TestCase):
self.assertEqual([8, 16], monitor.post_steps)
def test_print(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
t = constant_op.constant(42.0, name='foo')
self._run_monitor(learn.monitors.PrintTensor(tensor_names=[t.name]))
self.assertRegexpMatches(str(self.logged_message), t.name)
def test_logging_trainable(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
var = variables.Variable(constant_op.constant(42.0), name='foo')
var.initializer.run()
cof = constant_op.constant(1.0)
@@ -258,7 +258,7 @@ class MonitorsTest(test.TestCase):
self.assertRegexpMatches(str(self.logged_message), var.name)
def test_summary_saver(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
log_dir = 'log/dir'
summary_writer = testing.FakeSummaryWriter(log_dir, g)
var = variables.Variable(0.0)
@@ -312,7 +312,7 @@ class MonitorsTest(test.TestCase):
monitor = learn.monitors.ValidationMonitor(
x=constant_op.constant(2.0), every_n_steps=0)
self._assert_validation_monitor(monitor)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with self.assertRaisesRegexp(ValueError, 'set_estimator'):
self._run_monitor(monitor)
@@ -330,7 +330,7 @@ class MonitorsTest(test.TestCase):
x=constant_op.constant(2.0), every_n_steps=0)
self._assert_validation_monitor(monitor)
monitor.set_estimator(estimator)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(monitor)
self._assert_validation_monitor(monitor)
mock_latest_checkpoint.assert_called_with(model_dir)
@@ -351,7 +351,7 @@ class MonitorsTest(test.TestCase):
x=constant_op.constant(2.0), every_n_steps=0)
self._assert_validation_monitor(monitor)
monitor.set_estimator(estimator)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(monitor)
self._assert_validation_monitor(monitor)
@@ -370,7 +370,7 @@ class MonitorsTest(test.TestCase):
x=constant_op.constant(2.0), every_n_steps=0, early_stopping_rounds=1)
self._assert_validation_monitor(monitor)
monitor.set_estimator(estimator)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with self.assertRaisesRegexp(ValueError, 'missing from outputs'):
self._run_monitor(monitor, num_epochs=1, num_steps_per_epoch=1)
@@ -392,7 +392,7 @@ class MonitorsTest(test.TestCase):
self._assert_validation_monitor(monitor)
monitor.set_estimator(estimator)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
monitor.begin(max_steps=100)
monitor.epoch_begin(epoch=0)
self.assertEqual(0, estimator.evaluate.call_count)
@@ -477,7 +477,7 @@ class MonitorsTest(test.TestCase):
every_n_steps=0, early_stopping_rounds=2)
self._assert_validation_monitor(monitor)
monitor.set_estimator(estimator)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
monitor.begin(max_steps=100)
monitor.epoch_begin(epoch=0)
self.assertEqual(0, estimator.evaluate.call_count)
@@ -509,7 +509,7 @@ class MonitorsTest(test.TestCase):
metrics=constant_op.constant(2.0),
every_n_steps=0, early_stopping_rounds=2)
monitor.set_estimator(estimator)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
monitor.begin(max_steps=100)
monitor.epoch_begin(epoch=0)
@@ -525,7 +525,7 @@ class MonitorsTest(test.TestCase):
def test_graph_dump(self):
monitor0 = learn.monitors.GraphDump()
monitor1 = learn.monitors.GraphDump()
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
const_var = variables.Variable(42.0, name='my_const')
counter_var = variables.Variable(0.0, name='my_counter')
assign_add = state_ops.assign_add(counter_var, 1.0, name='my_assign_add')
@@ -568,7 +568,7 @@ class MonitorsTest(test.TestCase):
def test_capture_variable(self):
monitor = learn.monitors.CaptureVariable(
var_name='my_assign_add:0', every_n=8, first_n=2)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
var = variables.Variable(0.0, name='my_var')
var.initializer.run()
state_ops.assign_add(var, 1.0, name='my_assign_add')
diff --git a/tensorflow/contrib/learn/python/learn/ops/ops_test.py b/tensorflow/contrib/learn/python/learn/ops/ops_test.py
index 80d4923db3..ff190110c1 100644
--- a/tensorflow/contrib/learn/python/learn/ops/ops_test.py
+++ b/tensorflow/contrib/learn/python/learn/ops/ops_test.py
@@ -33,7 +33,7 @@ class OpsTest(test.TestCase):
"""Ops tests."""
def test_softmax_classifier(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
features = array_ops.placeholder(dtypes.float32, [None, 3])
labels = array_ops.placeholder(dtypes.float32, [None, 2])
weights = constant_op.constant([[0.1, 0.1], [0.1, 0.1], [0.1, 0.1]])
@@ -52,7 +52,7 @@ class OpsTest(test.TestCase):
ids_shape = (2, 3, 4)
embeds = np.random.randn(n_embed, d_embed)
ids = np.random.randint(0, n_embed, ids_shape)
- with self.test_session():
+ with self.cached_session():
embed_np = embeds[ids]
embed_tf = ops.embedding_lookup(embeds, ids).eval()
self.assertEqual(embed_np.shape, embed_tf.shape)
@@ -60,7 +60,7 @@ class OpsTest(test.TestCase):
def test_categorical_variable(self):
random_seed.set_random_seed(42)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
cat_var_idx = array_ops.placeholder(dtypes.int64, [2, 2])
embeddings = ops.categorical_variable(
cat_var_idx, n_classes=5, embedding_size=10, name="my_cat_var")
diff --git a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py
index 95aec61955..5a7e4ebfea 100644
--- a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py
+++ b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py
@@ -31,7 +31,7 @@ class Seq2SeqOpsTest(test.TestCase):
"""Sequence-to-sequence tests."""
def test_sequence_classifier(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
decoding = [
array_ops.placeholder(dtypes.float32, [2, 2]) for _ in range(3)
]
@@ -60,7 +60,7 @@ class Seq2SeqOpsTest(test.TestCase):
def test_seq2seq_inputs(self):
inp = np.array([[[1, 0], [0, 1], [1, 0]], [[0, 1], [1, 0], [0, 1]]])
out = np.array([[[0, 1, 0], [1, 0, 0]], [[1, 0, 0], [0, 1, 0]]])
- with self.test_session() as session:
+ with self.cached_session() as session:
x = array_ops.placeholder(dtypes.float32, [2, 3, 2])
y = array_ops.placeholder(dtypes.float32, [2, 2, 3])
in_x, in_y, out_y = ops.seq2seq_inputs(x, y, 3, 2)
@@ -77,7 +77,7 @@ class Seq2SeqOpsTest(test.TestCase):
[[0, 0, 0], [0, 0, 0]]])
def test_rnn_decoder(self):
- with self.test_session():
+ with self.cached_session():
decoder_inputs = [
array_ops.placeholder(dtypes.float32, [2, 2]) for _ in range(3)
]
diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
index 7ce5fb2da6..2f33a2b74d 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
@@ -950,7 +950,7 @@ class Seq2SeqTest(test.TestCase):
num_dec_timesteps = 3
def TestModel(seq2seq):
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
random_seed.set_random_seed(111)
random.seed(111)
np.random.seed(111)
diff --git a/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py b/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py
index 423dcce8de..8390ddda90 100644
--- a/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py
+++ b/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
class DecodeLibsvmOpTest(test.TestCase):
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
content = [
"1 1:3.4 2:0.5 4:0.231", "1 2:2.5 3:inf 5:0.503",
"2 3:2.5 2:nan 1:0.105"
@@ -48,7 +48,7 @@ class DecodeLibsvmOpTest(test.TestCase):
[0, 0.105, np.nan, 2.5, 0, 0]])
def testNDimension(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
content = [["1 1:3.4 2:0.5 4:0.231", "1 1:3.4 2:0.5 4:0.231"],
["1 2:2.5 3:inf 5:0.503", "1 2:2.5 3:inf 5:0.503"],
["2 3:2.5 2:nan 1:0.105", "2 3:2.5 2:nan 1:0.105"]]
diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py
index a262a099cf..cbe4c03e4d 100644
--- a/tensorflow/contrib/linalg/__init__.py
+++ b/tensorflow/contrib/linalg/__init__.py
@@ -14,7 +14,8 @@
# ==============================================================================
"""Linear algebra libraries.
-See the @{$python/contrib.linalg} guide.
+See the [Contrib Linalg](https://tensorflow.org/api_guides/python/contrib.linalg)
+guide.
@@LinearOperator
@@LinearOperatorBlockDiag
diff --git a/tensorflow/contrib/linear_optimizer/kernels/g3doc/readme.md b/tensorflow/contrib/linear_optimizer/kernels/g3doc/readme.md
index a4f5086dde..5fe883d647 100644
--- a/tensorflow/contrib/linear_optimizer/kernels/g3doc/readme.md
+++ b/tensorflow/contrib/linear_optimizer/kernels/g3doc/readme.md
@@ -199,6 +199,46 @@ does.
However, in practice, convergence with $$x_0 = 0$$ always happens (tested for a
sample of generic values for the parameters).
+### Poisson log loss
+
+Poisson log loss is defined as $$ \l(u) = e^u - uy $$ for label $$y \geq 0.$$
+Its dual is
+
+$$ \l^\star(v) = (y+v) (\log(y+v) - 1) $$
+
+and is only defined for $$ y+v > 0 $$. We then have the constraint
+
+$$ y > \a+\d. $$
+
+The dual objective is
+
+$$ D(\d) = -(y-\a-\d) (\log(y-\a-\d) - 1) - \bar{y} \d - \frac{A}{2} \d^2 $$
+
+and its derivative is
+
+$$ D'(\d) = \log(y-\a-\d) - \bar{y} - A\d $$
+
+Similar to the logistic loss, we perform a change of variable to handle the
+constraint on $$ \d $$
+
+$$ y - (\a+\d) = e^x $$
+
+After this change of variable, the goal is to find the zero of this function
+
+$$ H(x) = x - \bar{y} - A(y-\a-e^x) $$
+
+whose first derivative is
+
+$$ H'(x) = 1+Ae^x $$
+
+Since this derivative is always positive, $$H$$ is increasing and has a unique
+zero.
+
+We can start the Newton algorithm at $$\d=0$$, which corresponds to $$ x =
+\log(y-\a)$$. As before, the Newton step is given by
+
+$$x_{k+1} = x_k - \frac{H(x_k)}{H'(x_k)}. $$
+
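The Newton recursion above is easy to check numerically. A standalone sketch in plain Python, with hypothetical values for y, y-bar, A and alpha (not taken from this patch):

```python
import math

def newton_poisson_dual(y, y_bar, a_coef, alpha, num_steps=10):
  """Solves H(x) = x - y_bar - A*(y - alpha - exp(x)) = 0 by Newton's method.

  Starts at x = log(y - alpha), i.e. delta = 0, as described above;
  requires y > alpha.
  """
  x = math.log(y - alpha)
  for _ in range(num_steps):
    h = x - y_bar - a_coef * (y - alpha - math.exp(x))
    h_prime = 1.0 + a_coef * math.exp(x)  # always positive, so H is increasing
    x -= h / h_prime
  return y - alpha - math.exp(x)          # recover delta from x

# Hypothetical example: y=2, y_bar=0.5, A=1, alpha=0.
print(newton_poisson_dual(2.0, 0.5, 1.0, 0.0))
```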
### References
[1] C. Ma et al., Adding vs. Averaging in Distributed Primal-Dual Optimization,
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index ef0e08a777..1d2db1cec8 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -1192,6 +1192,57 @@ class SdcaWithSmoothHingeLossTest(SdcaModelTest):
self.assertAllClose(0.33, unregularized_loss.eval(), atol=0.02)
self.assertAllClose(0.44, regularized_loss.eval(), atol=0.02)
+class SdcaWithPoissonLossTest(SdcaModelTest):
+ """SDCA optimizer test class for poisson loss."""
+
+ def testSimple(self):
+ # Setup test data
+ example_protos = [
+ make_example_proto({
+ 'age': [0],
+ 'gender': [0]
+ }, 0),
+ make_example_proto({
+ 'age': [1],
+ 'gender': [1]
+ }, 2),
+ ]
+ example_weights = [100.0, 100.0]
+ with self._single_threaded_test_session():
+ examples = make_example_dict(example_protos, example_weights)
+ variables = make_variable_dict(1, 1)
+ options = dict(
+ symmetric_l2_regularization=1.0,
+ symmetric_l1_regularization=0,
+ loss_type='poisson_loss')
+ model = SdcaModel(examples, variables, options)
+ variables_lib.global_variables_initializer().run()
+
+      # Before minimization, the weights default to zero. There is no loss due
+      # to regularization, only the unregularized loss, which is 1 per example.
+ predictions = model.predictions(examples)
+ self.assertAllClose([1.0, 1.0], predictions.eval())
+ unregularized_loss = model.unregularized_loss(examples)
+ regularized_loss = model.regularized_loss(examples)
+ approximate_duality_gap = model.approximate_duality_gap()
+ self.assertAllClose(1.0, unregularized_loss.eval())
+ self.assertAllClose(1.0, regularized_loss.eval())
+
+ # There are 4 sparse weights: 2 for age (say w1, w2) and 2 for gender
+ # (say w3 and w4). The minimization leads to:
+ # w1=w3=-1.96487, argmin of 100*(exp(2*w)-2*w*0)+w**2.
+ # w2=w4=0.345708, argmin of 100*(exp(2*w)-2*w*2)+w**2.
+      # This gives an unregularized loss of 0.3167 and a regularized loss of 0.3366.
+ train_op = model.minimize()
+ for _ in range(_MAX_ITERATIONS):
+ train_op.run()
+ model.update_weights(train_op).run()
+
+ self.assertAllClose([0.0196, 1.9965], predictions.eval(), atol=1e-4)
+ self.assertAllClose(0.3167, unregularized_loss.eval(), atol=1e-4)
+ self.assertAllClose(0.3366, regularized_loss.eval(), atol=1e-4)
+ self.assertAllClose(0., approximate_duality_gap.eval(), atol=1e-6)
+
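The closed-form minimizers quoted in the test comment can be sanity-checked by minimizing the per-weight objective 100*(exp(2w) - 2wy) + w**2 directly. A standalone Newton sketch (not part of the patch):

```python
import math

def argmin_weight(y, num_steps=50):
  # Minimizes f(w) = 100*(exp(2w) - 2*w*y) + w**2 by Newton's method.
  w = 0.0
  for _ in range(num_steps):
    grad = 200.0 * math.exp(2.0 * w) - 200.0 * y + 2.0 * w
    hess = 400.0 * math.exp(2.0 * w) + 2.0
    w -= grad / hess
  return w

print(argmin_weight(0.0))  # ~ -1.96487, matching w1 = w3 in the comment
print(argmin_weight(2.0))  # ~  0.345708, matching w2 = w4
```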
class SdcaFprintTest(SdcaModelTest):
"""Tests for the SdcaFprint op.
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index 0047d5753a..14f59a3f64 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as var_ops
+from tensorflow.python.ops.nn import log_poisson_loss
from tensorflow.python.ops.nn import sigmoid_cross_entropy_with_logits
from tensorflow.python.summary import summary
@@ -51,6 +52,7 @@ class SdcaModel(object):
* Squared loss
* Hinge loss
* Smooth hinge loss
+ * Poisson log loss
This class defines an optimizer API to train a linear model.
@@ -112,7 +114,7 @@ class SdcaModel(object):
raise ValueError('examples, variables and options must all be specified.')
supported_losses = ('logistic_loss', 'squared_loss', 'hinge_loss',
- 'smooth_hinge_loss')
+ 'smooth_hinge_loss', 'poisson_loss')
if options['loss_type'] not in supported_losses:
raise ValueError('Unsupported loss_type: ', options['loss_type'])
@@ -315,6 +317,7 @@ class SdcaModel(object):
"""Add operations to compute predictions by the model.
If logistic_loss is being used, predicted probabilities are returned.
+ If poisson_loss is being used, predictions are exponentiated.
Otherwise, (raw) linear predictions (w*x) are returned.
Args:
@@ -335,6 +338,10 @@ class SdcaModel(object):
# Convert logits to probability for logistic loss predictions.
with name_scope('sdca/logistic_prediction'):
result = math_ops.sigmoid(result)
+ elif self._options['loss_type'] == 'poisson_loss':
+      # Exponentiate the prediction for poisson loss predictions.
+ with name_scope('sdca/poisson_prediction'):
+ result = math_ops.exp(result)
return result
def _get_partitioned_update_ops(self,
@@ -624,6 +631,11 @@ class SdcaModel(object):
logits=predictions),
weights)) / math_ops.reduce_sum(weights)
+ if self._options['loss_type'] == 'poisson_loss':
+ return math_ops.reduce_sum(math_ops.multiply(
+ log_poisson_loss(targets=labels, log_input=predictions),
+ weights)) / math_ops.reduce_sum(weights)
+
if self._options['loss_type'] in ['hinge_loss', 'smooth_hinge_loss']:
# hinge_loss = max{0, 1 - y_i w*x} where y_i \in {-1, 1}. So, we need to
# first convert 0/1 labels into -1/1 labels.
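The branch added above delegates to `tf.nn.log_poisson_loss`, which with the default `compute_full_loss=False` computes exactly the primal loss from the g3doc section, exp(u) - u*y. A small graph-mode sanity check, assuming TF 1.x:

```python
import numpy as np
import tensorflow as tf

# log_poisson_loss (default compute_full_loss=False) matches exp(u) - u*y.
u = tf.constant([0.0, 0.5, -1.0])  # linear predictions (log rates)
y = tf.constant([0.0, 2.0, 1.0])   # non-negative labels
loss = tf.nn.log_poisson_loss(targets=y, log_input=u)

with tf.Session() as sess:
  loss_val, u_val, y_val = sess.run([loss, u, y])
  np.testing.assert_allclose(loss_val, np.exp(u_val) - u_val * y_val,
                             rtol=1e-5)
```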
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
index a2d82cf800..553b116a3b 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
@@ -30,7 +30,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
def testShardedMutableHashTable(self):
for num_shards in [1, 3, 10]:
- with self.test_session():
+ with self.cached_session():
default_val = -1
empty_key = 0
keys = constant_op.constant([11, 12, 13], dtypes.int64)
@@ -53,7 +53,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
def testShardedMutableHashTableVectors(self):
for num_shards in [1, 3, 10]:
- with self.test_session():
+ with self.cached_session():
default_val = [-0.1, 0.2]
empty_key = [0, 1]
keys = constant_op.constant([[11, 12], [13, 14], [15, 16]],
@@ -79,7 +79,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
output.eval())
def testExportSharded(self):
- with self.test_session():
+ with self.cached_session():
empty_key = -2
default_val = -1
num_shards = 2
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py
index 237a6812b7..51c4f68543 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py
@@ -36,13 +36,13 @@ class SparseFeatureColumnTest(TensorFlowTestCase):
self.assertTrue(isinstance(sfc.example_indices, ops.Tensor))
self.assertTrue(isinstance(sfc.feature_indices, ops.Tensor))
self.assertEqual(sfc.feature_values, None)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_example_indices, sfc.example_indices.eval())
self.assertAllEqual(expected_feature_indices, sfc.feature_indices.eval())
expected_feature_values = [1.0, 2.0, 3.0, 4.0]
sfc = SparseFeatureColumn([1, 1, 1, 2], [0, 1, 2, 0],
expected_feature_values)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_feature_values, sfc.feature_values.eval())
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index 1e6f1e7da2..f320b53d94 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -36,10 +36,10 @@ cc_library(
srcs = ["arena_planner.cc"],
hdrs = ["arena_planner.h"],
deps = [
- ":context",
":graph_info",
":memory_planner",
":simple_memory_arena",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -54,6 +54,7 @@ cc_test(
deps = [
":arena_planner",
"//tensorflow/contrib/lite/testing:util",
+ "//tensorflow/core:framework",
"//tensorflow/core:lib",
"@com_google_googletest//:gtest",
],
@@ -63,27 +64,27 @@ cc_test(
# TODO(aselle): Resolve problems preventing C99 usage.
cc_library(
name = "context",
- srcs = ["context.c"],
hdrs = ["context.h"],
+ deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
)
cc_library(
name = "graph_info",
hdrs = ["graph_info.h"],
- deps = [":context"],
+ deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
)
cc_library(
name = "memory_planner",
hdrs = ["memory_planner.h"],
- deps = [":context"],
+ deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
)
cc_library(
name = "simple_memory_arena",
srcs = ["simple_memory_arena.cc"],
hdrs = ["simple_memory_arena.h"],
- deps = [":context"],
+ deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
)
cc_library(
@@ -91,7 +92,7 @@ cc_library(
hdrs = [
"builtin_op_data.h",
],
- deps = [":context"],
+ deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
)
cc_library(
@@ -121,12 +122,12 @@ cc_library(
name = "framework",
srcs = [
"allocation.cc",
- "error_reporter.cc",
"graph_info.cc",
"interpreter.cc",
"model.cc",
- "op_resolver.cc",
+ "mutable_op_resolver.cc",
"optional_debug_tools.cc",
+ "stderr_reporter.cc",
] + select({
"//tensorflow:android": [
"nnapi_delegate.cc",
@@ -149,21 +150,31 @@ cc_library(
"graph_info.h",
"interpreter.h",
"model.h",
+ "mutable_op_resolver.h",
"nnapi_delegate.h",
"op_resolver.h",
"optional_debug_tools.h",
+ "stderr_reporter.h",
],
copts = tflite_copts(),
+ linkopts = [
+ ] + select({
+ "//tensorflow:android": [
+ "-llog",
+ ],
+ "//conditions:default": [
+ ],
+ }),
deps = [
":arena_planner",
- ":builtin_op_data",
- ":context",
":graph_info",
":memory_planner",
":schema_fbs_version",
":simple_memory_arena",
":string",
":util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/core/api",
"//tensorflow/contrib/lite/kernels:eigen_support",
"//tensorflow/contrib/lite/kernels:gemm_support",
"//tensorflow/contrib/lite/nnapi:nnapi_lib",
@@ -202,6 +213,8 @@ cc_test(
deps = [
":framework",
":string_util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/core/api",
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/kernels:kernel_util",
"//tensorflow/contrib/lite/kernels/internal:tensor_utils",
@@ -251,6 +264,8 @@ cc_test(
],
deps = [
":framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/core/api",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
],
@@ -258,9 +273,9 @@ cc_test(
# Test OpResolver.
cc_test(
- name = "op_resolver_test",
+ name = "mutable_op_resolver_test",
size = "small",
- srcs = ["op_resolver_test.cc"],
+ srcs = ["mutable_op_resolver_test.cc"],
tags = ["no_oss"],
deps = [
":framework",
@@ -269,24 +284,12 @@ cc_test(
],
)
-# Test the C extension API code.
-cc_test(
- name = "context_test",
- size = "small",
- srcs = ["context_test.cc"],
- deps = [
- ":framework",
- "//tensorflow/contrib/lite/testing:util",
- "@com_google_googletest//:gtest",
- ],
-)
-
cc_library(
name = "util",
srcs = ["util.cc"],
hdrs = ["util.h"],
deps = [
- ":context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -296,7 +299,6 @@ cc_test(
srcs = ["util_test.cc"],
tags = ["no_oss"],
deps = [
- ":context",
":util",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
diff --git a/tensorflow/contrib/lite/RELEASE.md b/tensorflow/contrib/lite/RELEASE.md
deleted file mode 100644
index 8fd63d5cee..0000000000
--- a/tensorflow/contrib/lite/RELEASE.md
+++ /dev/null
@@ -1,8 +0,0 @@
-# Release 0.1.7
-
-* TensorFlow Lite 0.1.7 is based on tag `tflite-v0.1.7` (git commit
- fa1db5eb0da85b5baccc2a46d534fdeb3bb473d0).
-* To reproduce the iOS library, it's required to cherry pick git commit
- f1f1d5172fe5bfeaeb2cf657ffc43ba744187bee to fix a dependency issue.
-* The code is based on TensorFlow 1.8.0 release candidate and it's very close
- to TensorFlow 1.8.0 release.
diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc
index 8946261814..21cb1832a7 100644
--- a/tensorflow/contrib/lite/allocation.cc
+++ b/tensorflow/contrib/lite/allocation.cc
@@ -23,8 +23,8 @@ limitations under the License.
#include <cstring>
#include <utility>
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/allocation.h b/tensorflow/contrib/lite/allocation.h
index 121f3d2646..182bc0977f 100644
--- a/tensorflow/contrib/lite/allocation.h
+++ b/tensorflow/contrib/lite/allocation.h
@@ -20,8 +20,8 @@ limitations under the License.
#include <cstdio>
#include <cstdlib>
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/simple_memory_arena.h"
#include "tensorflow/contrib/lite/string.h"
diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h
index 55003cf4e9..382577045b 100644
--- a/tensorflow/contrib/lite/arena_planner.h
+++ b/tensorflow/contrib/lite/arena_planner.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/graph_info.h"
#include "tensorflow/contrib/lite/memory_planner.h"
#include "tensorflow/contrib/lite/simple_memory_arena.h"
@@ -37,8 +37,8 @@ struct AllocationInfo;
// each tensor needs to be allocated and deallocated, and preallocates all the
// necessary memory (the PlanAllocations phase). It then assigns portions of
// this memory buffer to each tensor (the ExecuteAllocations phase). Tensors may
-// share some of the buffer if a tensor B is to be allocated after another tensor
-// A has been deallocated.
+// share some of the buffer if a tensor B is to be allocated after another
+// tensor A has been deallocated.
//
// If dynamic tensors are used the planning steps can be repeated during model
// execution. Since dynamic tensors don't have sizes until after the
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 05d0b453ab..9317e2bb6e 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -49,6 +49,9 @@ def tflite_linkopts_unstripped():
Returns:
a select object with proper linkopts
"""
+
+ # In case you wonder why there's no --icf here: the gains were
+ # negligible, and it created potential compatibility problems.
return select({
"//tensorflow:android": [
"-Wl,--no-export-dynamic", # Only inc syms referenced by dynamic obj.
@@ -56,12 +59,7 @@ def tflite_linkopts_unstripped():
"-Wl,--gc-sections", # Eliminate unused code and data.
"-Wl,--as-needed", # Don't link unused libs.
],
- "//tensorflow:darwin": [],
- "//tensorflow/contrib/lite:mips": [],
- "//tensorflow/contrib/lite:mips64": [],
- "//conditions:default": [
- "-Wl,--icf=all", # Identical code folding.
- ],
+ "//conditions:default": [],
})
def tflite_jni_linkopts_unstripped():
@@ -73,17 +71,15 @@ def tflite_jni_linkopts_unstripped():
Returns:
a select object with proper linkopts
"""
+
+ # In case you wonder why there's no --icf here: the gains were
+ # negligible, and it created potential compatibility problems.
return select({
"//tensorflow:android": [
"-Wl,--gc-sections", # Eliminate unused code and data.
"-Wl,--as-needed", # Don't link unused libs.
],
- "//tensorflow:darwin": [],
- "//tensorflow/contrib/lite:mips": [],
- "//tensorflow/contrib/lite:mips64": [],
- "//conditions:default": [
- "-Wl,--icf=all", # Identical code folding.
- ],
+ "//conditions:default": [],
})
def tflite_linkopts():
@@ -235,6 +231,7 @@ def generated_test_models():
"exp",
"expand_dims",
"floor",
+ "floor_div",
"fully_connected",
"fused_batch_norm",
"gather",
@@ -266,7 +263,9 @@ def generated_test_models():
"padv2",
"prelu",
"pow",
+ "reduce_any",
"reduce_max",
+ "reduce_min",
"reduce_prod",
"relu",
"relu1",
@@ -292,6 +291,7 @@ def generated_test_models():
"topk",
"transpose",
#"transpose_conv", # disabled due to b/111213074
+ "unpack",
"where",
]
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index 70178b2faa..30901bd0fa 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -12,282 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// Compatibility shim for new location of interface definitions.
+
#ifndef TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
-#include <stdint.h>
-
-#include "tensorflow/contrib/lite/context.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif // __cplusplus
-
-// TODO(aselle): Consider using "if this then that" for testing.
-
-// Possible padding types (for convolutions)
-typedef enum {
- kTfLitePaddingUnknown = 0,
- kTfLitePaddingSame,
- kTfLitePaddingValid,
-} TfLitePadding;
-
-typedef struct {
- int width;
- int height;
-} TfLitePaddingValues;
-
-// Possible fused activation functions.
-// TODO(aselle): rename to TfLiteActivation
-typedef enum {
- kTfLiteActNone = 0,
- kTfLiteActRelu,
- kTfLiteActRelu1,
- kTfLiteActRelu6,
- kTfLiteActTanh,
- kTfLiteActSignBit,
- kTfLiteActSigmoid,
-} TfLiteFusedActivation;
-
-typedef struct {
- TfLitePadding padding;
- int stride_width;
- int stride_height;
- int dilation_width_factor;
- int dilation_height_factor;
- TfLiteFusedActivation activation;
-} TfLiteConvParams;
-
-typedef struct {
- TfLitePadding padding;
- int stride_width;
- int stride_height;
- int filter_width;
- int filter_height;
- TfLiteFusedActivation activation;
- struct {
- TfLitePaddingValues padding;
- } computed;
-} TfLitePoolParams;
-
-typedef struct {
- TfLitePadding padding;
- int stride_width;
- int stride_height;
- int depth_multiplier;
- TfLiteFusedActivation activation;
-} TfLiteDepthwiseConvParams;
-
-typedef struct {
- int rank;
- TfLiteFusedActivation activation;
-} TfLiteSVDFParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteRNNParams;
-
-typedef struct {
- bool time_major;
- TfLiteFusedActivation activation;
-} TfLiteSequenceRNNParams;
-
-typedef enum {
- kTfLiteFullyConnectedWeightsFormatDefault = 0,
- kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1,
-} TfLiteFullyConnectedWeightsFormat;
-
-typedef struct {
- // Parameters for FullyConnected version 1 or above.
- TfLiteFusedActivation activation;
-
- // Parameters for FullyConnected version 2 or above.
- TfLiteFullyConnectedWeightsFormat weights_format;
-} TfLiteFullyConnectedParams;
-
-typedef enum {
- kTfLiteLshProjectionUnknown = 0,
- kTfLiteLshProjectionSparse = 1,
- kTfLiteLshProjectionDense = 2,
-} TfLiteLSHProjectionType;
-
-typedef struct {
- TfLiteLSHProjectionType type;
-} TfLiteLSHProjectionParams;
-
-typedef struct {
- float beta;
-} TfLiteSoftmaxParams;
-
-typedef struct {
- int axis;
- TfLiteFusedActivation activation;
-} TfLiteConcatenationParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteAddParams;
-
-typedef struct {
-} TfLiteSpaceToBatchNDParams;
-
-typedef struct {
-} TfLiteBatchToSpaceNDParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteMulParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteSubParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteDivParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteL2NormParams;
-
-typedef struct {
- int radius;
- float bias;
- float alpha;
- float beta;
-} TfLiteLocalResponseNormParams;
-
-typedef enum {
- kTfLiteLSTMFullKernel = 0,
- kTfLiteLSTMBasicKernel
-} TfLiteLSTMKernelType;
-
-typedef struct {
- // Parameters for LSTM version 1.
- TfLiteFusedActivation activation;
- float cell_clip;
- float proj_clip;
-
- // Parameters for LSTM version 2.
- // kTfLiteLSTMBasicKernel is only supported in version 2 or above.
- TfLiteLSTMKernelType kernel_type;
-} TfLiteLSTMParams;
-
-typedef struct {
- bool align_corners;
-} TfLiteResizeBilinearParams;
-
-typedef struct {
-} TfLitePadParams;
-
-typedef struct {
-} TfLitePadV2Params;
-
-typedef struct {
- // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
- // For now we will fix the maximum possible number of dimensions.
- int shape[8];
- int num_dimensions;
-} TfLiteReshapeParams;
-
-typedef struct {
- int ngram_size;
- int max_skip_size;
- bool include_all_ngrams;
-} TfLiteSkipGramParams;
-
-typedef struct {
- int block_size;
-} TfLiteSpaceToDepthParams;
-
-typedef struct {
- TfLiteType in_data_type;
- TfLiteType out_data_type;
-} TfLiteCastParams;
-
-typedef enum {
- kTfLiteCombinerTypeSum = 0,
- kTfLiteCombinerTypeMean = 1,
- kTfLiteCombinerTypeSqrtn = 2,
-} TfLiteCombinerType;
-
-typedef struct {
- TfLiteCombinerType combiner;
-} TfLiteEmbeddingLookupSparseParams;
-
-typedef struct {
- int axis;
-} TfLiteGatherParams;
-
-typedef struct {
-} TfLiteTransposeParams;
-
-typedef struct {
- bool keep_dims;
-} TfLiteReducerParams;
-
-typedef struct {
- int num_splits;
-} TfLiteSplitParams;
-
-typedef struct {
- // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
- // For now we will fix the maximum possible number of dimensions.
- int squeeze_dims[8];
- int num_squeeze_dims;
-} TfLiteSqueezeParams;
-
-typedef struct {
- int begin_mask;
- int end_mask;
- int ellipsis_mask;
- int new_axis_mask;
- int shrink_axis_mask;
-} TfLiteStridedSliceParams;
-
-typedef struct {
- TfLiteType output_type;
-} TfLiteArgMaxParams;
-
-typedef struct {
- TfLiteType output_type;
-} TfLiteArgMinParams;
-
-typedef struct {
- TfLitePadding padding;
- int stride_width;
- int stride_height;
-} TfLiteTransposeConvParams;
-
-typedef struct {
- bool validate_indices;
-} TfLiteSparseToDenseParams;
-
-typedef struct {
- TfLiteType out_type;
-} TfLiteShapeParams;
-
-typedef struct {
- // Parameters supported by version 1:
- float min;
- float max;
- int num_bits;
-
- // Parameters supported by version 2:
- bool narrow_range;
-} TfLiteFakeQuantParams;
-
-typedef struct {
- int values_count;
- int axis;
-} TfLitePackParams;
-
-typedef struct {
- int axis;
-} TfLiteOneHotParams;
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
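Since the moved declarations are re-exported through this shim, callers that still include the old path keep compiling. A hypothetical standalone snippet (not part of this change) illustrating that:

// Pre-existing code that includes the old header path still builds;
// the shim simply forwards to tensorflow/contrib/lite/c/builtin_op_data.h.
#include "tensorflow/contrib/lite/builtin_op_data.h"

static TfLitePadding pad = kTfLitePaddingValid;  // resolved via the shim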
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index 8a8eb98568..9cf4bea73e 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -113,6 +113,10 @@ typedef enum {
kTfLiteBuiltinOneHot = 85,
kTfLiteBuiltinLogicalAnd = 86,
kTfLiteBuiltinLogicalNot = 87,
+ kTfLiteBuiltinUnpack = 88,
+ kTfLiteBuiltinReduceMin = 89,
+ kTfLiteBuiltinFloorDiv = 90,
+ kTfLiteBuiltinReduceAny = 91,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
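The four new builtin codes extend the operator enum without renumbering existing entries. As a rough sketch (hypothetical helper, assuming only the enum above), code dispatching on an operator code can pick them up with a plain switch:

#include "tensorflow/contrib/lite/builtin_ops.h"

// Hypothetical helper: returns 1 for the operator codes added here.
static int IsOpAddedInThisChange(TfLiteBuiltinOperator op) {
  switch (op) {
    case kTfLiteBuiltinUnpack:
    case kTfLiteBuiltinReduceMin:
    case kTfLiteBuiltinFloorDiv:
    case kTfLiteBuiltinReduceAny:
      return 1;
    default:
      return 0;
  }
}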
diff --git a/tensorflow/contrib/lite/c/BUILD b/tensorflow/contrib/lite/c/BUILD
new file mode 100644
index 0000000000..663eb63cad
--- /dev/null
+++ b/tensorflow/contrib/lite/c/BUILD
@@ -0,0 +1,39 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+cc_library(
+ name = "c_api_internal",
+ srcs = ["c_api_internal.c"],
+ hdrs = [
+ "builtin_op_data.h",
+ "c_api_internal.h",
+ ],
+ visibility = [
+ "//tensorflow/contrib/lite:__subpackages__",
+ ],
+)
+
+# Test the C extension API code.
+cc_test(
+ name = "c_api_internal_test",
+ size = "small",
+ srcs = ["c_api_internal_test.cc"],
+ deps = [
+ ":c_api_internal",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_test(
+ name = "builtin_op_data_test",
+ size = "small",
+ srcs = ["builtin_op_data_test.cc"],
+ copts = ["-Wno-unused-variable"],
+ deps = [
+ ":c_api_internal",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h
new file mode 100644
index 0000000000..fa43e6a024
--- /dev/null
+++ b/tensorflow/contrib/lite/c/builtin_op_data.h
@@ -0,0 +1,298 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_
+#define TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_
+
+#include <stdint.h>
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// TODO(aselle): Consider using "if this then that" for testing.
+
+// Possible padding types (for convolutions)
+typedef enum {
+ kTfLitePaddingUnknown = 0,
+ kTfLitePaddingSame,
+ kTfLitePaddingValid,
+} TfLitePadding;
+
+typedef struct {
+ int width;
+ int height;
+} TfLitePaddingValues;
+
+// Possible fused activation functions.
+// TODO(aselle): rename to TfLiteActivation
+typedef enum {
+ kTfLiteActNone = 0,
+ kTfLiteActRelu,
+ kTfLiteActRelu1,
+ kTfLiteActRelu6,
+ kTfLiteActTanh,
+ kTfLiteActSignBit,
+ kTfLiteActSigmoid,
+} TfLiteFusedActivation;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ int dilation_width_factor;
+ int dilation_height_factor;
+ TfLiteFusedActivation activation;
+} TfLiteConvParams;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ int filter_width;
+ int filter_height;
+ TfLiteFusedActivation activation;
+ struct {
+ TfLitePaddingValues padding;
+ } computed;
+} TfLitePoolParams;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ int depth_multiplier;
+ TfLiteFusedActivation activation;
+} TfLiteDepthwiseConvParams;
+
+typedef struct {
+ int rank;
+ TfLiteFusedActivation activation;
+} TfLiteSVDFParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteRNNParams;
+
+typedef struct {
+ bool time_major;
+ TfLiteFusedActivation activation;
+} TfLiteSequenceRNNParams;
+
+typedef enum {
+ kTfLiteFullyConnectedWeightsFormatDefault = 0,
+ kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1,
+} TfLiteFullyConnectedWeightsFormat;
+
+typedef struct {
+ // Parameters for FullyConnected version 1 or above.
+ TfLiteFusedActivation activation;
+
+ // Parameters for FullyConnected version 2 or above.
+ TfLiteFullyConnectedWeightsFormat weights_format;
+} TfLiteFullyConnectedParams;
+
+typedef enum {
+ kTfLiteLshProjectionUnknown = 0,
+ kTfLiteLshProjectionSparse = 1,
+ kTfLiteLshProjectionDense = 2,
+} TfLiteLSHProjectionType;
+
+typedef struct {
+ TfLiteLSHProjectionType type;
+} TfLiteLSHProjectionParams;
+
+typedef struct {
+ float beta;
+} TfLiteSoftmaxParams;
+
+typedef struct {
+ int axis;
+ TfLiteFusedActivation activation;
+} TfLiteConcatenationParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteAddParams;
+
+typedef struct {
+} TfLiteSpaceToBatchNDParams;
+
+typedef struct {
+} TfLiteBatchToSpaceNDParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteMulParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteSubParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteDivParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteL2NormParams;
+
+typedef struct {
+ int radius;
+ float bias;
+ float alpha;
+ float beta;
+} TfLiteLocalResponseNormParams;
+
+typedef enum {
+ kTfLiteLSTMFullKernel = 0,
+ kTfLiteLSTMBasicKernel
+} TfLiteLSTMKernelType;
+
+typedef struct {
+ // Parameters for LSTM version 1.
+ TfLiteFusedActivation activation;
+ float cell_clip;
+ float proj_clip;
+
+ // Parameters for LSTM version 2.
+ // kTfLiteLSTMBasicKernel is only supported in version 2 or above.
+ TfLiteLSTMKernelType kernel_type;
+} TfLiteLSTMParams;
+
+typedef struct {
+ bool align_corners;
+} TfLiteResizeBilinearParams;
+
+typedef struct {
+} TfLitePadParams;
+
+typedef struct {
+} TfLitePadV2Params;
+
+typedef struct {
+ // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
+ // For now we will fix the maximum possible number of dimensions.
+ int shape[8];
+ int num_dimensions;
+} TfLiteReshapeParams;
+
+typedef struct {
+ int ngram_size;
+ int max_skip_size;
+ bool include_all_ngrams;
+} TfLiteSkipGramParams;
+
+typedef struct {
+ int block_size;
+} TfLiteSpaceToDepthParams;
+
+typedef struct {
+ TfLiteType in_data_type;
+ TfLiteType out_data_type;
+} TfLiteCastParams;
+
+typedef enum {
+ kTfLiteCombinerTypeSum = 0,
+ kTfLiteCombinerTypeMean = 1,
+ kTfLiteCombinerTypeSqrtn = 2,
+} TfLiteCombinerType;
+
+typedef struct {
+ TfLiteCombinerType combiner;
+} TfLiteEmbeddingLookupSparseParams;
+
+typedef struct {
+ int axis;
+} TfLiteGatherParams;
+
+typedef struct {
+} TfLiteTransposeParams;
+
+typedef struct {
+ bool keep_dims;
+} TfLiteReducerParams;
+
+typedef struct {
+ int num_splits;
+} TfLiteSplitParams;
+
+typedef struct {
+ // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
+ // For now we will fix the maximum possible number of dimensions.
+ int squeeze_dims[8];
+ int num_squeeze_dims;
+} TfLiteSqueezeParams;
+
+typedef struct {
+ int begin_mask;
+ int end_mask;
+ int ellipsis_mask;
+ int new_axis_mask;
+ int shrink_axis_mask;
+} TfLiteStridedSliceParams;
+
+typedef struct {
+ TfLiteType output_type;
+} TfLiteArgMaxParams;
+
+typedef struct {
+ TfLiteType output_type;
+} TfLiteArgMinParams;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+} TfLiteTransposeConvParams;
+
+typedef struct {
+ bool validate_indices;
+} TfLiteSparseToDenseParams;
+
+typedef struct {
+ TfLiteType out_type;
+} TfLiteShapeParams;
+
+typedef struct {
+ // Parameters supported by version 1:
+ float min;
+ float max;
+ int num_bits;
+
+ // Parameters supported by version 2:
+ bool narrow_range;
+} TfLiteFakeQuantParams;
+
+typedef struct {
+ int values_count;
+ int axis;
+} TfLitePackParams;
+
+typedef struct {
+ int axis;
+} TfLiteOneHotParams;
+
+typedef struct {
+ int num;
+ int axis;
+} TfLiteUnpackParams;
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_
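These param structs are plain C data that an importer fills in before handing them to a kernel through `builtin_data`. A minimal sketch (hypothetical values, not from this patch) of populating one of them:

#include "tensorflow/contrib/lite/c/builtin_op_data.h"

// Hypothetical: parameters for a stride-2 convolution with fused ReLU.
static TfLiteConvParams MakeExampleConvParams(void) {
  TfLiteConvParams params;
  params.padding = kTfLitePaddingSame;
  params.stride_width = 2;
  params.stride_height = 2;
  params.dilation_width_factor = 1;
  params.dilation_height_factor = 1;
  params.activation = kTfLiteActRelu;
  return params;
}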
diff --git a/tensorflow/contrib/lite/c/builtin_op_data_test.cc b/tensorflow/contrib/lite/c/builtin_op_data_test.cc
new file mode 100644
index 0000000000..4d0ba75e68
--- /dev/null
+++ b/tensorflow/contrib/lite/c/builtin_op_data_test.cc
@@ -0,0 +1,83 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include <gtest/gtest.h>
+
+namespace tflite {
+
+// Builtin op data is just a set of data definitions, so the only meaningful
+// test we can run is whether we can create the structs we expect to find.
+// Testing each struct's members might be possible, but it seems unnecessary
+// until we've locked down the API. The build rule has copts set to ignore the
+// unused variable warning, since this is just a compilation test.
+TEST(IntArray, CanCompileStructs) {
+ TfLitePadding padding = kTfLitePaddingSame;
+ TfLitePaddingValues padding_values;
+ TfLiteFusedActivation fused_activation = kTfLiteActRelu;
+ TfLiteConvParams conv_params;
+ TfLitePoolParams pool_params;
+ TfLiteDepthwiseConvParams depthwise_conv_params;
+ TfLiteSVDFParams svdf_params;
+ TfLiteRNNParams rnn_params;
+ TfLiteSequenceRNNParams sequence_rnn_params;
+ TfLiteFullyConnectedWeightsFormat fully_connected_weights_format =
+ kTfLiteFullyConnectedWeightsFormatDefault;
+ TfLiteFullyConnectedParams fully_connected_params;
+ TfLiteLSHProjectionType projection_type = kTfLiteLshProjectionDense;
+ TfLiteLSHProjectionParams projection_params;
+ TfLiteSoftmaxParams softmax_params;
+ TfLiteConcatenationParams concatenation_params;
+ TfLiteAddParams add_params;
+ TfLiteSpaceToBatchNDParams space_to_batch_nd_params;
+ TfLiteBatchToSpaceNDParams batch_to_space_nd_params;
+ TfLiteMulParams mul_params;
+ TfLiteSubParams sub_params;
+ TfLiteDivParams div_params;
+ TfLiteL2NormParams l2_norm_params;
+ TfLiteLocalResponseNormParams local_response_norm_params;
+ TfLiteLSTMKernelType lstm_kernel_type = kTfLiteLSTMBasicKernel;
+ TfLiteLSTMParams lstm_params;
+ TfLiteResizeBilinearParams resize_bilinear_params;
+ TfLitePadParams pad_params;
+ TfLitePadV2Params pad_v2_params;
+ TfLiteReshapeParams reshape_params;
+ TfLiteSkipGramParams skip_gram_params;
+ TfLiteSpaceToDepthParams space_to_depth_params;
+ TfLiteCastParams cast_params;
+ TfLiteCombinerType combiner_type = kTfLiteCombinerTypeSqrtn;
+ TfLiteEmbeddingLookupSparseParams lookup_sparse_params;
+ TfLiteGatherParams gather_params;
+ TfLiteTransposeParams transpose_params;
+ TfLiteReducerParams reducer_params;
+ TfLiteSplitParams split_params;
+ TfLiteSqueezeParams squeeze_params;
+ TfLiteStridedSliceParams strided_slice_params;
+ TfLiteArgMaxParams arg_max_params;
+ TfLiteArgMinParams arg_min_params;
+ TfLiteTransposeConvParams transpose_conv_params;
+ TfLiteSparseToDenseParams sparse_to_dense_params;
+ TfLiteShapeParams shape_params;
+ TfLiteFakeQuantParams fake_quant_params;
+ TfLitePackParams pack_params;
+ TfLiteOneHotParams one_hot_params;
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/context.c b/tensorflow/contrib/lite/c/c_api_internal.c
index 7f2aa316f4..1846bad4b7 100644
--- a/tensorflow/contrib/lite/context.c
+++ b/tensorflow/contrib/lite/c/c_api_internal.c
@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include <stdio.h>
+#include <stdlib.h>
#include <string.h>
int TfLiteIntArrayGetSizeInBytes(int size) {
@@ -76,7 +77,8 @@ void TfLiteTensorFree(TfLiteTensor* t) {
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
TfLiteQuantizationParams quantization, char* buffer,
size_t size, TfLiteAllocationType allocation_type,
- const void* allocation, bool is_variable, TfLiteTensor* tensor) {
+ const void* allocation, bool is_variable,
+ TfLiteTensor* tensor) {
TfLiteTensorFree(tensor);
tensor->type = type;
tensor->name = name;
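The relocated helpers keep their C semantics: arrays from TfLiteIntArrayCreate/Copy are heap allocated and must be released with TfLiteIntArrayFree. A small standalone usage sketch (hypothetical, assuming only the declarations in c_api_internal.h):

#include "tensorflow/contrib/lite/c/c_api_internal.h"

static void IntArrayExample(void) {
  TfLiteIntArray* dims = TfLiteIntArrayCreate(2);  // entries start uninitialized
  dims->data[0] = 3;
  dims->data[1] = 2;
  TfLiteIntArray* copy = TfLiteIntArrayCopy(dims);
  int equal = TfLiteIntArrayEqual(dims, copy);  // 1: same size and contents
  (void)equal;
  TfLiteIntArrayFree(copy);  // each array must be freed individually
  TfLiteIntArrayFree(dims);
}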
diff --git a/tensorflow/contrib/lite/c/c_api_internal.h b/tensorflow/contrib/lite/c/c_api_internal.h
new file mode 100644
index 0000000000..48df68a654
--- /dev/null
+++ b/tensorflow/contrib/lite/c/c_api_internal.h
@@ -0,0 +1,491 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// This file defines a C API for implementing operations in tflite.
+// These operations can be defined using C++, but the interface between
+// the interpreter and the operations is C.
+//
+// Summary of abstractions
+// TF_LITE_ENSURE - Self-sufficient error checking
+// TfLiteStatus - Status reporting
+// TfLiteIntArray - stores tensor shapes (dims),
+// TfLiteContext - allows an op to access the tensors
+// TfLiteTensor - tensor (a multidimensional array)
+// TfLiteNode - a single node or operation
+// TfLiteRegistration - the implementation of a conceptual operation.
+//
+// Some abstractions in this file are created and managed by Interpreter.
+#ifndef TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_
+#define TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_
+
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus;
+
+// The list of external context types known to TF Lite. This list exists solely
+// to avoid conflicts and to ensure ops can share the external contexts they
+// need. Access to the external contexts is controlled by one of the
+// corresponding support files.
+typedef enum {
+ kTfLiteEigenContext = 0, // include eigen_support.h to use.
+ kTfLiteGemmLowpContext = 1, // include gemm_support.h to use.
+ kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support.
+ kTfLiteMaxExternalContexts = 3
+} TfLiteExternalContextType;
+
+// An external context is a collection of information unrelated to the TF Lite
+// framework, but useful to a subset of the ops. TF Lite knows very little
+// about the actual contexts, but it keeps a list of them, and is able to
+// refresh them if configurations like the number of recommended threads
+// change.
+typedef struct {
+ TfLiteExternalContextType type;
+ TfLiteStatus (*Refresh)(struct TfLiteContext* context);
+} TfLiteExternalContext;
+
+// Forward declare so GetNode can use this in Context.
+typedef struct _TfLiteRegistration TfLiteRegistration;
+typedef struct _TfLiteDelegate TfLiteDelegate;
+
+#define kOptionalTensor (-1)
+
+// Fixed size list of integers. Used for dimensions and inputs/outputs tensor
+// indices
+typedef struct {
+ int size;
+// gcc 6.1+ have a bug where flexible members aren't properly handled
+// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
+#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
+ __GNUC_MINOR__ >= 1
+ int data[0];
+#else
+ int data[];
+#endif
+} TfLiteIntArray;
+
+// Given the size (number of elements) in a TfLiteIntArray, calculate its size
+// in bytes.
+int TfLiteIntArrayGetSizeInBytes(int size);
+
+// Create an array of a given `size` (uninitialized entries).
+// This returns a pointer that you must free using TfLiteIntArrayFree().
+TfLiteIntArray* TfLiteIntArrayCreate(int size);
+
+// Check if two int arrays are equal. Returns 1 if they are equal, 0 otherwise.
+int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b);
+
+// Create a copy of an array passed as `src`.
+// You are expected to free the memory with TfLiteIntArrayFree().
+TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src);
+
+// Free memory of array `v`.
+void TfLiteIntArrayFree(TfLiteIntArray* v);
+
+// Since we must not depend on any libraries, define a minimal subset of
+// error macros while avoiding names that have pre-conceived meanings like
+// assert and check.
+
+// Check whether value is true, and if not return kTfLiteError from
+// the current function (and report the error string msg).
+#define TF_LITE_ENSURE_MSG(context, value, msg) \
+ do { \
+ if (!(value)) { \
+ (context)->ReportError((context), __FILE__ " " msg); \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+// Check whether the value `a` is true, and if not return kTfLiteError from
+// the current function, while also reporting the location of the error.
+#define TF_LITE_ENSURE(context, a) \
+ do { \
+ if (!(a)) { \
+ (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \
+ __LINE__, #a); \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+#define TF_LITE_ENSURE_STATUS(a) \
+ do { \
+ if ((a) != kTfLiteOk) { \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+// Check whether the value `a == b` is true, and if not return kTfLiteError from
+// the current function, while also reporting the location of the error.
+// `a` and `b` may be evaluated more than once, so no side effects or
+// extremely expensive computations should be done.
+#define TF_LITE_ENSURE_EQ(context, a, b) \
+ do { \
+ if ((a) != (b)) { \
+ (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \
+ __LINE__, #a, #b, (a), (b)); \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+#define TF_LITE_ENSURE_OK(context, status) \
+ do { \
+ if ((status) != kTfLiteOk) { \
+ return status; \
+ } \
+ } while (0)
+
+// Single-precision complex data type compatible with the C99 definition.
+typedef struct {
+ float re, im; // real and imaginary parts, respectively.
+} TfLiteComplex64;
+
+// Types supported by tensor
+typedef enum {
+ kTfLiteNoType = 0,
+ kTfLiteFloat32 = 1,
+ kTfLiteInt32 = 2,
+ kTfLiteUInt8 = 3,
+ kTfLiteInt64 = 4,
+ kTfLiteString = 5,
+ kTfLiteBool = 6,
+ kTfLiteInt16 = 7,
+ kTfLiteComplex64 = 8,
+} TfLiteType;
+
+// Parameters for asymmetric quantization. Quantized values can be converted
+// back to float using:
+// real_value = scale * (quantized_value - zero_point);
+typedef struct {
+ float scale;
+ int32_t zero_point;
+} TfLiteQuantizationParams;
+
+// A union of pointers that points to memory for a given tensor.
+typedef union {
+ int* i32;
+ int64_t* i64;
+ float* f;
+ char* raw;
+ const char* raw_const;
+ uint8_t* uint8;
+ bool* b;
+ int16_t* i16;
+ TfLiteComplex64* c64;
+} TfLitePtrUnion;
+
+// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
+// data (or data externally allocated). kTfLiteArenaRw is arena allocated
+// data. kTfLiteDynamic is for tensors that are allocated during evaluation.
+typedef enum {
+ kTfLiteMemNone = 0,
+ kTfLiteMmapRo,
+ kTfLiteArenaRw,
+ kTfLiteArenaRwPersistent,
+ kTfLiteDynamic,
+} TfLiteAllocationType;
+
+// The delegates should use zero or positive integers to represent handles.
+// -1 is reserved for unallocated status.
+typedef int TfLiteBufferHandle;
+const TfLiteBufferHandle kTfLiteNullBufferHandle = -1;
+
+// A tensor in the interpreter system which is a wrapper around a buffer of
+// data including a dimensionality (or NULL if not currently defined).
+typedef struct {
+ // The data type specification for data stored in `data`. This affects
+ // what member of `data` union should be used.
+ TfLiteType type;
+ // A union of data pointers. The appropriate type should be used for a typed
+ // tensor based on `type`.
+ TfLitePtrUnion data;
+ // A pointer to a structure representing the dimensionality interpretation
+ // that the buffer should have. NOTE: the product of elements of `dims`
+ // and the element datatype size should be equal to `bytes` below.
+ TfLiteIntArray* dims;
+ // Quantization information.
+ TfLiteQuantizationParams params;
+ // How memory is mapped
+ // kTfLiteMmapRo: Memory mapped read only.
+ // i.e. weights
+ // kTfLiteArenaRw: Arena allocated read write memory
+ // (i.e. temporaries, outputs).
+ TfLiteAllocationType allocation_type;
+ // The number of bytes required to store the data of this Tensor. I.e.
+ // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
+ // type is kTfLiteFloat32 and dims = {3, 2} then
+ // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
+ size_t bytes;
+
+ // An opaque pointer to a tflite::MMapAllocation
+ const void* allocation;
+
+ // Null-terminated name of this tensor.
+ const char* name;
+
+ // The delegate which knows how to handle `buffer_handle`.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteDelegate* delegate;
+
+ // An integer buffer handle that can be handled by `delegate`.
+ // The value is valid only when delegate is not null.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteBufferHandle buffer_handle;
+
+ // If the delegate uses its own buffer (e.g. GPU memory), the delegate is
+  // responsible for setting data_is_stale to true.
+ // `delegate->CopyFromBufferHandle` can be called to copy the data from
+ // delegate buffer.
+  // WARNING: This is an experimental interface that is subject to change.
+ bool data_is_stale;
+
+ // True if the tensor is a variable.
+ bool is_variable;
+} TfLiteTensor;
+
+// Free data memory of tensor `t`;
+void TfLiteTensorDataFree(TfLiteTensor* t);
+
+// Free memory of tensor `t`;
+void TfLiteTensorFree(TfLiteTensor* t);
+
+// Set all of a tensor's fields (and free any previously allocated data).
+void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
+ TfLiteQuantizationParams quantization, char* buffer,
+ size_t size, TfLiteAllocationType allocation_type,
+ const void* allocation, bool is_variable,
+ TfLiteTensor* tensor);
+
+// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
+// types other than kTfLiteDynamic will be ignored.
+void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
+
+// A structure representing an instance of a node.
+// This structure only exhibits the inputs, outputs and user defined data, not
+// other features like the type.
+typedef struct {
+  // Inputs to this node expressed as indices into the interpreter's tensors.
+  TfLiteIntArray* inputs;
+
+  // Outputs to this node expressed as indices into the interpreter's tensors.
+ TfLiteIntArray* outputs;
+
+  // Temporary tensors used during computation. This usually contains no
+ // tensors, but ops are allowed to change that if they need scratch space of
+ // any sort.
+ TfLiteIntArray* temporaries;
+
+ // Opaque data provided by the node implementer through `Registration.init`.
+ void* user_data;
+
+ // Opaque data provided to the node if the node is a builtin. This is usually
+ // a structure defined in builtin_op_data.h
+ void* builtin_data;
+
+ // Custom initial data. This is the opaque data provided in the flatbuffer.
+ // WARNING: This is an experimental interface that is subject to change.
+ const void* custom_initial_data;
+ int custom_initial_data_size;
+
+ // The pointer to the delegate. This is non-null only when the node is
+ // created by calling `interpreter.ModifyGraphWithDelegate`.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteDelegate* delegate;
+} TfLiteNode;
+
+typedef struct TfLiteContext {
+ // Number of tensors in the context.
+ size_t tensors_size;
+
+ // The execution plan contains a list of the node indices in execution
+ // order. execution_plan->size is the current number of nodes. And,
+ // execution_plan->data[0] is the first node that needs to be run.
+ // TfLiteDelegates can traverse the current execution plan by iterating
+ // through each member of this array and using GetNodeAndRegistration() to
+ // access details about a node. i.e.
+ // TfLiteIntArray* execution_plan;
+ // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan));
+ // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) {
+ // int node_index = execution_plan->data[exec_index];
+ // TfLiteNode* node;
+ // TfLiteRegistration* reg;
+ // context->GetNodeAndRegistration(context, node_index, &node, &reg);
+ // }
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context,
+ TfLiteIntArray** execution_plan);
+
+ // An array of tensors in the interpreter context (of length `tensors_size`)
+ TfLiteTensor* tensors;
+
+ // opaque full context ptr (an opaque c++ data structure)
+ void* impl_;
+
+ // Request memory pointer be resized. Updates dimensions on the tensor.
+ // NOTE: ResizeTensor takes ownership of newSize.
+ TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor,
+ TfLiteIntArray* new_size);
+  // Request that an error be reported with format string msg.
+ void (*ReportError)(struct TfLiteContext*, const char* msg, ...);
+
+ // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. If
+ // non-null, the value pointed to by `first_new_tensor_index` will be set to
+ // the index of the first new tensor.
+ TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add,
+ int* first_new_tensor_index);
+
+ // Get a Tensor node by node_index.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteStatus (*GetNodeAndRegistration)(struct TfLiteContext*, int node_index,
+ TfLiteNode** node,
+ TfLiteRegistration** registration);
+
+ // Replace ops with one or more stub delegate operations. This function
+ // does not take ownership of `nodes_to_replace`.
+ TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)(
+ struct TfLiteContext*, TfLiteRegistration registration,
+ const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate);
+
+ // Number of threads that are recommended to subsystems like gemmlowp and
+ // eigen.
+ int recommended_num_threads;
+
+ // Access external contexts by type.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*,
+ TfLiteExternalContextType);
+  // Set the value of an external context. Does not take ownership of the
+ // pointer.
+ // WARNING: This is an experimental interface that is subject to change.
+ void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType,
+ TfLiteExternalContext*);
+} TfLiteContext;
+
+typedef struct _TfLiteRegistration {
+ // Initializes the op from serialized data.
+ // If a built-in op:
+ // `buffer` is the op's params data (TfLiteLSTMParams*).
+ // `length` is zero.
+ // If custom op:
+ // `buffer` is the op's `custom_options`.
+ // `length` is the size of the buffer.
+ //
+ // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer
+ // or an instance of a struct).
+ //
+ // The returned pointer will be stored with the node in the `user_data` field,
+ // accessible within prepare and invoke functions below.
+ // NOTE: if the data is already in the desired format, simply implement this
+ // function to return `nullptr` and implement the free function to be a no-op.
+ void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
+
+ // The pointer `buffer` is the data previously returned by an init invocation.
+ void (*free)(TfLiteContext* context, void* buffer);
+
+ // prepare is called when the inputs this node depends on have been resized.
+ // context->ResizeTensor() can be called to request output tensors to be
+ // resized.
+ //
+ // Returns kTfLiteOk on success.
+ TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
+
+ // Execute the node (should read node->inputs and output to node->outputs).
+ // Returns kTfLiteOk on success.
+ TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
+
+ // profiling_string is called during summarization of profiling information
+ // in order to group executions together. Providing a value here will cause a
+  // given op to appear multiple times in the profiling report. This is
+  // particularly useful for custom ops that can perform significantly
+  // different calculations depending on their `user_data`.
+ const char* (*profiling_string)(const TfLiteContext* context,
+ const TfLiteNode* node);
+
+ // Builtin codes. If this kernel refers to a builtin this is the code
+ // of the builtin. This is so we can do marshaling to other frameworks like
+ // NN API.
+ // Note: It is the responsibility of the registration binder to set this
+ // properly.
+ int32_t builtin_code;
+
+ // Custom op name. If the op is a builtin, this will be null.
+ // Note: It is the responsibility of the registration binder to set this
+ // properly.
+ // WARNING: This is an experimental interface that is subject to change.
+ const char* custom_name;
+
+ // The version of the op.
+ // Note: It is the responsibility of the registration binder to set this
+ // properly.
+ int version;
+} TfLiteRegistration;
+
+// WARNING: This is an experimental interface that is subject to change.
+typedef struct _TfLiteDelegate {
+ // Data that delegate needs to identify itself. This data is owned by the
+  // delegate. The delegate is owned by the user code, so the delegate is
+  // responsible for freeing this data when it is destroyed.
+ void* data_;
+
+ // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the
+ // delegate a view of the current graph through TfLiteContext*. It typically
+ // will look at the nodes and call ReplaceSubgraphsWithDelegateKernels()
+ // to ask the TensorFlow lite runtime to create macro-nodes to represent
+ // delegated subgraphs of the original graph.
+ TfLiteStatus (*Prepare)(TfLiteContext* context, TfLiteDelegate* delegate);
+
+ // Copy the data from delegate buffer handle to raw memory.
+ // This can be null if the delegate doesn't use its own buffer.
+ TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context,
+ TfLiteDelegate* delegate,
+ TfLiteBufferHandle buffer_handle,
+ void* data, size_t size);
+
+ // Copy the data from raw memory to delegate buffer handle.
+ // This can be null if the delegate doesn't use its own buffer.
+ TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context,
+ TfLiteDelegate* delegate,
+ TfLiteBufferHandle buffer_handle,
+ void* data, size_t size);
+
+ // Free the Delegate Buffer Handle. Note: This only frees the handle, but
+ // this doesn't release the underlying resource (e.g. textures). The
+ // resources are either owned by application layer or the delegate.
+ // This can be null if the delegate doesn't use its own buffer.
+ void (*FreeBufferHandle)(TfLiteContext* context, TfLiteDelegate* delegate,
+ TfLiteBufferHandle* handle);
+} TfLiteDelegate;
+
+// WARNING: This is an experimental interface that is subject to change.
+//
+// Currently, TfLiteDelegateParams has to be allocated in a way that it's
+// trivially destructible. It will be stored as the `builtin_data` field in
+// `TfLiteNode` of the delegate node.
+//
+// See also the `CreateDelegateParams` function in `interpreter.cc` for details.
+typedef struct {
+ TfLiteDelegate* delegate;
+ TfLiteIntArray* nodes_to_replace;
+ TfLiteIntArray* input_tensors;
+ TfLiteIntArray* output_tensors;
+} TfLiteDelegateParams;
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif // TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_
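Taken together, the abstractions above are what an op implementation touches: it reads tensors through TfLiteContext, validates with the TF_LITE_ENSURE macros, and exposes itself via a TfLiteRegistration. A minimal sketch of a hypothetical custom identity op (the op itself is illustrative and not part of this change; only the types and macros come from this header):

#include <string.h>

#include "tensorflow/contrib/lite/c/c_api_internal.h"

// Hypothetical identity op: the single output mirrors the single input.
static TfLiteStatus IdentityPrepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 1);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
  TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
  TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
  // ResizeTensor takes ownership of the dims array handed to it.
  return context->ResizeTensor(context, output,
                               TfLiteIntArrayCopy(input->dims));
}

static TfLiteStatus IdentityInvoke(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
  TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
  TF_LITE_ENSURE(context, input->bytes == output->bytes);
  memcpy(output->data.raw, input->data.raw, input->bytes);
  return kTfLiteOk;
}

// init/free may be left null when the op keeps no per-node state.
static TfLiteRegistration identity_registration = {
    NULL,             // init
    NULL,             // free
    IdentityPrepare,  // prepare
    IdentityInvoke,   // invoke
    NULL,             // profiling_string
    0,                // builtin_code (not meaningful for a custom op)
    "Identity",       // custom_name (hypothetical)
    1,                // version
};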
diff --git a/tensorflow/contrib/lite/context_test.cc b/tensorflow/contrib/lite/c/c_api_internal_test.cc
index 20d6f69a25..af398f3207 100644
--- a/tensorflow/contrib/lite/context_test.cc
+++ b/tensorflow/contrib/lite/c/c_api_internal_test.cc
@@ -13,16 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/testing/util.h"
namespace tflite {
// NOTE: this tests only the TfLiteIntArray part of context.
-// most of context.h is provided in the context of using it with interpreter.h
-// and interpreter.cc, so interpreter_test.cc tests context structures more
-// thoroughly.
+// most of c_api_internal.h is provided in the context of using it with
+// interpreter.h and interpreter.cc, so interpreter_test.cc tests context
+// structures more thoroughly.
TEST(IntArray, TestIntArrayCreate) {
TfLiteIntArray* a = TfLiteIntArrayCreate(0);
@@ -69,7 +68,6 @@ TEST(IntArray, TestIntArrayEqual) {
} // namespace tflite
int main(int argc, char** argv) {
- ::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h
index c920f6a508..b86c2819b8 100644
--- a/tensorflow/contrib/lite/context.h
+++ b/tensorflow/contrib/lite/context.h
@@ -12,481 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// This file defines a C API for implementing operations in tflite.
-// These operations can be defined using c++ but the interface between
-// the interpreter and the operations are C.
-//
-// Summary of abstractions
-// TF_LITE_ENSURE - Self-sufficient error checking
-// TfLiteStatus - Status reporting
-// TfLiteIntArray - stores tensor shapes (dims),
-// TfLiteContext - allows an op to access the tensors
-// TfLiteTensor - tensor (a multidimensional array)
-// TfLiteNode - a single node or operation
-// TfLiteRegistration - the implementation of a conceptual operation.
-//
-// Some abstractions in this file are created and managed by Interpreter.
+// Compatibility shim for moved header location.
#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
#define TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
-#if defined(_MSC_VER)
-#include <complex.h>
-#endif
-#include <stdbool.h>
-#include <stdint.h>
-#include <stdlib.h>
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
-#ifdef __cplusplus
-extern "C" {
-#endif // __cplusplus
-
-typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus;
-
-// The list of external context types known to TF Lite. This list exists solely
-// to avoid conflicts and to ensure ops can share the external contexts they
-// need. Access to the external contexts is controled by one of the
-// corresponding support files.
-typedef enum {
- kTfLiteEigenContext = 0, // include eigen_support.h to use.
- kTfLiteGemmLowpContext = 1, // include gemm_support.h to use.
- kTfLiteMaxExternalContexts = 2
-} TfLiteExternalContextType;
-
-// An external context is a collection of information unrelated to the TF Lite
-// framework, but useful to a subset of the ops. TF Lite knows very little
-// about about the actual contexts, but it keeps a list of them, and is able to
-// refresh them if configurations like the number of recommended threads
-// change.
-typedef struct {
- TfLiteExternalContextType type;
- TfLiteStatus (*Refresh)(struct TfLiteContext* context);
-} TfLiteExternalContext;
-
-// Forward declare so GetNode can use this is in Context.
-typedef struct _TfLiteRegistration TfLiteRegistration;
-typedef struct _TfLiteDelegate TfLiteDelegate;
-
-#define kOptionalTensor (-1)
-
-// Fixed size list of integers. Used for dimensions and inputs/outputs tensor
-// indices
-typedef struct {
- int size;
-// gcc 6.1+ have a bug where flexible members aren't properly handled
-// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
-#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
- __GNUC_MINOR__ >= 1
- int data[0];
-#else
- int data[];
-#endif
-} TfLiteIntArray;
-
-// Given the size (number of elements) in a TfLiteIntArray, calculate its size
-// in bytes.
-int TfLiteIntArrayGetSizeInBytes(int size);
-
-// Create a array of a given `size` (uninitialized entries).
-// This returns a pointer, that you must free using TfLiteIntArrayFree().
-TfLiteIntArray* TfLiteIntArrayCreate(int size);
-
-// Check if two tensors are equal. Returns 1 if they are equal, 0 otherwise.
-int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b);
-
-// Create a copy of an array passed as `src`.
-// You are expected to free memory with TfLiteIntArrayFree
-TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src);
-
-// Free memory of array `v`.
-void TfLiteIntArrayFree(TfLiteIntArray* v);
-
-// Since we must not depend on any libraries, define a minimal subset of
-// error macros while avoiding names that have pre-conceived meanings like
-// assert and check.
-
-// Check whether value is true, and if not return kTfLiteError from
-// the current function (and report the error string msg).
-#define TF_LITE_ENSURE_MSG(context, value, msg) \
- do { \
- if (!(value)) { \
- (context)->ReportError((context), __FILE__ " " msg); \
- return kTfLiteError; \
- } \
- } while (0)
-
-// Check whether the value `a` is true, and if not return kTfLiteError from
-// the current function, while also reporting the location of the error.
-#define TF_LITE_ENSURE(context, a) \
- do { \
- if (!(a)) { \
- (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \
- __LINE__, #a); \
- return kTfLiteError; \
- } \
- } while (0)
-
-#define TF_LITE_ENSURE_STATUS(a) \
- do { \
- if ((a) != kTfLiteOk) { \
- return kTfLiteError; \
- } \
- } while (0)
-
-// Check whether the value `a == b` is true, and if not return kTfLiteError from
-// the current function, while also reporting the location of the error.
-// `a` and `b` may be evaluated more than once, so no side effects or
-// extremely expensive computations should be done.
-#define TF_LITE_ENSURE_EQ(context, a, b) \
- do { \
- if ((a) != (b)) { \
- (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \
- __LINE__, #a, #b, (a), (b)); \
- return kTfLiteError; \
- } \
- } while (0)
-
-#define TF_LITE_ENSURE_OK(context, status) \
- do { \
- if ((status) != kTfLiteOk) { \
- return status; \
- } \
- } while (0)
-
-// Types supported by tensor
-typedef enum {
- kTfLiteNoType = 0,
- kTfLiteFloat32 = 1,
- kTfLiteInt32 = 2,
- kTfLiteUInt8 = 3,
- kTfLiteInt64 = 4,
- kTfLiteString = 5,
- kTfLiteBool = 6,
- kTfLiteInt16 = 7,
- kTfLiteComplex64 = 8,
-} TfLiteType;
-
-// Parameters for asymmetric quantization. Quantized values can be converted
-// back to float using:
-// real_value = scale * (quantized_value - zero_point);
-typedef struct {
- float scale;
- int32_t zero_point;
-} TfLiteQuantizationParams;
-
-// A union of pointers that points to memory for a given tensor.
-typedef union {
- int* i32;
- int64_t* i64;
- float* f;
- char* raw;
- const char* raw_const;
- uint8_t* uint8;
- bool* b;
- int16_t* i16;
-#if defined(_MSC_VER)
- _Fcomplex* c64;
-#else
- _Complex float* c64;
-#endif
-} TfLitePtrUnion;
-
-// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
-// data (or data externally allocated). kTfLiteArenaRw is arena allocated
-// data. kTfLiteDynamic is for tensors that are allocated during evaluation.
-typedef enum {
- kTfLiteMemNone = 0,
- kTfLiteMmapRo,
- kTfLiteArenaRw,
- kTfLiteArenaRwPersistent,
- kTfLiteDynamic,
-} TfLiteAllocationType;
-
-// The delegates should use zero or positive integers to represent handles.
-// -1 is reserved from unallocated status.
-typedef int TfLiteBufferHandle;
-const TfLiteBufferHandle kTfLiteNullBufferHandle = -1;
-
-// An tensor in the interpreter system which is a wrapper around a buffer of
-// data including a dimensionality (or NULL if not currently defined).
-typedef struct {
- // The data type specification for data stored in `data`. This affects
- // what member of `data` union should be used.
- TfLiteType type;
- // A union of data pointers. The appropriate type should be used for a typed
- // tensor based on `type`.
- TfLitePtrUnion data;
- // A pointer to a structure representing the dimensionality interpretation
- // that the buffer should have. NOTE: the product of elements of `dims`
- // and the element datatype size should be equal to `bytes` below.
- TfLiteIntArray* dims;
- // Quantization information.
- TfLiteQuantizationParams params;
- // How memory is mapped
- // kTfLiteMmapRo: Memory mapped read only.
- // i.e. weights
- // kTfLiteArenaRw: Arena allocated read write memory
- // (i.e. temporaries, outputs).
- TfLiteAllocationType allocation_type;
- // The number of bytes required to store the data of this Tensor. I.e.
- // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
- // type is kTfLiteFloat32 and dims = {3, 2} then
- // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
- size_t bytes;
-
- // An opaque pointer to a tflite::MMapAllocation
- const void* allocation;
-
- // Null-terminated name of this tensor.
- const char* name;
-
- // The delegate which knows how to handle `buffer_handle`.
- // WARNING: This is an experimental interface that is subject to change.
- TfLiteDelegate* delegate;
-
- // An integer buffer handle that can be handled by `delegate`.
- // The value is valid only when delegate is not null.
- // WARNING: This is an experimental interface that is subject to change.
- TfLiteBufferHandle buffer_handle;
-
- // If the delegate uses its own buffer (e.g. GPU memory), the delegate is
- // responsible to set data_is_stale to true.
- // `delegate->CopyFromBufferHandle` can be called to copy the data from
- // delegate buffer.
- // WARNING: This is an // experimental interface that is subject to change.
- bool data_is_stale;
-
- // True if the tensor is a variable.
- bool is_variable;
-} TfLiteTensor;
-
-// Free data memory of tensor `t`;
-void TfLiteTensorDataFree(TfLiteTensor* t);
-
-// Free memory of tensor `t`;
-void TfLiteTensorFree(TfLiteTensor* t);
-
-// Set all of a tensor's fields (and free any previously allocated data).
-void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
- TfLiteQuantizationParams quantization, char* buffer,
- size_t size, TfLiteAllocationType allocation_type,
- const void* allocation, bool is_variable,
- TfLiteTensor* tensor);
-
-// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
-// types other than kTfLiteDynamic will be ignored.
-void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
-
-// A structure representing an instance of a node.
-// This structure only exhibits the inputs, outputs and user defined data, not
-// other features like the type.
-typedef struct {
- // Inputs to this node expressed as indices into the simulator's tensors.
- TfLiteIntArray* inputs;
-
- // Outputs to this node expressed as indices into the simulator's tensors.
- TfLiteIntArray* outputs;
-
- // Temporary tensors uses during the computations. This usually contains no
- // tensors, but ops are allowed to change that if they need scratch space of
- // any sort.
- TfLiteIntArray* temporaries;
-
- // Opaque data provided by the node implementer through `Registration.init`.
- void* user_data;
-
- // Opaque data provided to the node if the node is a builtin. This is usually
- // a structure defined in builtin_op_data.h
- void* builtin_data;
-
- // Custom initial data. This is the opaque data provided in the flatbuffer.
- // WARNING: This is an experimental interface that is subject to change.
- const void* custom_initial_data;
- int custom_initial_data_size;
-
- // The pointer to the delegate. This is non-null only when the node is
- // created by calling `interpreter.ModifyGraphWithDelegate`.
- // WARNING: This is an experimental interface that is subject to change.
- TfLiteDelegate* delegate;
-} TfLiteNode;
-
-typedef struct TfLiteContext {
- // Number of tensors in the context.
- size_t tensors_size;
-
- // The execution plan contains a list of the node indices in execution
- // order. execution_plan->size is the current number of nodes. And,
- // execution_plan->data[0] is the first node that needs to be run.
- // TfLiteDelegates can traverse the current execution plan by iterating
- // through each member of this array and using GetNodeAndRegistration() to
- // access details about a node. i.e.
- // TfLiteIntArray* execution_plan;
- // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan));
- // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) {
- // int node_index = execution_plan->data[exec_index];
- // TfLiteNode* node;
- // TfLiteRegistration* reg;
- // context->GetNodeAndRegistration(context, node_index, &node, &reg);
- // }
- // WARNING: This is an experimental interface that is subject to change.
- TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context,
- TfLiteIntArray** execution_plan);
-
- // An array of tensors in the interpreter context (of length `tensors_size`)
- TfLiteTensor* tensors;
-
- // Opaque full context pointer (an opaque C++ data structure).
- void* impl_;
-
- // Request that the tensor's memory be resized. Updates dimensions on the
- // tensor. NOTE: ResizeTensor takes ownership of `new_size`.
- TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor,
- TfLiteIntArray* new_size);
- // Request that an error be reported with format string `msg`.
- void (*ReportError)(struct TfLiteContext*, const char* msg, ...);
-
- // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. If
- // non-null, the value pointed to by `first_new_tensor_index` will be set to
- // the index of the first new tensor.
- TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add,
- int* first_new_tensor_index);
-
- // Get a node and its registration by node_index.
- // WARNING: This is an experimental interface that is subject to change.
- TfLiteStatus (*GetNodeAndRegistration)(struct TfLiteContext*, int node_index,
- TfLiteNode** node,
- TfLiteRegistration** registration);
-
- // Replace ops with one or more stub delegate operations. This function
- // does not take ownership of `nodes_to_replace`.
- TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)(
- struct TfLiteContext*, TfLiteRegistration registration,
- const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate);
-
- // Number of threads that are recommended for use by subsystems like
- // gemmlowp and eigen.
- int recommended_num_threads;
-
- // Access external contexts by type.
- // WARNING: This is an experimental interface that is subject to change.
- TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*,
- TfLiteExternalContextType);
- // Set the value of an external context. Does not take ownership of the
- // pointer.
- // WARNING: This is an experimental interface that is subject to change.
- void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType,
- TfLiteExternalContext*);
-} TfLiteContext;
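// Illustrative `prepare` sketch for the callbacks above, assuming the
// TfLiteIntArrayCopy helper declared earlier in this header: copy the first
// input's shape onto the output via ResizeTensor, reporting failure through
// ReportError.
static TfLiteStatus SketchPrepare(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
  TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
  if (output_size == nullptr) {
    context->ReportError(context, "Failed to copy input shape");
    return kTfLiteError;
  }
  // ResizeTensor takes ownership of output_size.
  return context->ResizeTensor(context, output, output_size);
}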
-
-typedef struct _TfLiteRegistration {
- // Initializes the op from serialized data.
- // If a built-in op:
- // `buffer` is the op's params data (TfLiteLSTMParams*).
- // `length` is zero.
- // If a custom op:
- // `buffer` is the op's `custom_options`.
- // `length` is the size of the buffer.
- //
- // Returns type-punned (i.e. void*) opaque data (e.g. a primitive pointer
- // or an instance of a struct).
- //
- // The returned pointer will be stored with the node in the `user_data` field,
- // accessible within prepare and invoke functions below.
- // NOTE: if the data is already in the desired format, simply implement this
- // function to return `nullptr` and implement the free function to be a no-op.
- void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
-
- // The pointer `buffer` is the data previously returned by an init invocation.
- void (*free)(TfLiteContext* context, void* buffer);
-
- // prepare is called when the inputs this node depends on have been resized.
- // context->ResizeTensor() can be called to request output tensors to be
- // resized.
- //
- // Returns kTfLiteOk on success.
- TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
-
- // Execute the node (should read node->inputs and output to node->outputs).
- // Returns kTfLiteOk on success.
- TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
-
- // profiling_string is called during summarization of profiling information
- // in order to group executions together. Providing a value here will cause a
- // given op to appear multiple times in the profiling report. This is
- // particularly useful for custom ops that can perform significantly
- // different calculations depending on their `user_data`.
- const char* (*profiling_string)(const TfLiteContext* context,
- const TfLiteNode* node);
-
- // Builtin code. If this kernel refers to a builtin, this is the code
- // of that builtin. This is so we can do marshaling to other frameworks like
- // NN API.
- // Note: It is the responsibility of the registration binder to set this
- // properly.
- int32_t builtin_code;
-
- // Custom op name. If the op is a builtin, this will be null.
- // Note: It is the responsibility of the registration binder to set this
- // properly.
- // WARNING: This is an experimental interface that is subject to change.
- const char* custom_name;
-
- // The version of the op.
- // Note: It is the responsibility of the registration binder to set this
- // properly.
- int version;
-} TfLiteRegistration;
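// Sketch of wiring callbacks into a registration for a hypothetical custom
// op, reusing the SketchPrepare/SketchInvoke examples above; unset members
// of the aggregate are zero-initialized.
static void* SketchInit(TfLiteContext* context, const char* buffer,
                        size_t length) {
  return nullptr;  // no per-node state needed
}
static void SketchFree(TfLiteContext* context, void* buffer) {}

static TfLiteRegistration* Register_SKETCH_OP() {
  static TfLiteRegistration r = {SketchInit, SketchFree, SketchPrepare,
                                 SketchInvoke};
  r.custom_name = "SketchOp";  // set by the registration binder
  r.version = 1;
  return &r;
}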
-
-// WARNING: This is an experimental interface that is subject to change.
-typedef struct _TfLiteDelegate {
- // Data that the delegate needs to identify itself. This data is owned by the
- // delegate; since the delegate is owned by user code, the delegate is
- // responsible for freeing this data when it is destroyed.
- void* data_;
-
- // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the
- // delegate a view of the current graph through TfLiteContext*. It typically
- // will look at the nodes and call ReplaceSubgraphsWithDelegateKernels()
- // to ask the TensorFlow Lite runtime to create macro-nodes to represent
- // delegated subgraphs of the original graph.
- TfLiteStatus (*Prepare)(TfLiteContext* context, TfLiteDelegate* delegate);
-
- // Copy the data from delegate buffer handle to raw memory.
- // This can be null if the delegate doesn't use its own buffer.
- TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context,
- TfLiteDelegate* delegate,
- TfLiteBufferHandle buffer_handle,
- void* data, size_t size);
-
- // Copy the data from raw memory to delegate buffer handle.
- // This can be null if the delegate doesn't use its own buffer.
- TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context,
- TfLiteDelegate* delegate,
- TfLiteBufferHandle buffer_handle,
- void* data, size_t size);
-
- // Free the delegate buffer handle. Note: this only frees the handle; it
- // does not release the underlying resource (e.g. textures). The resources
- // are either owned by the application layer or the delegate.
- // This can be null if the delegate doesn't use its own buffer.
- void (*FreeBufferHandle)(TfLiteContext* context, TfLiteDelegate* delegate,
- TfLiteBufferHandle* handle);
-} TfLiteDelegate;
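// Hedged sketch of a delegate's Prepare following the comments above: fetch
// the execution plan and hand every node to a single delegate kernel.
// GetSketchKernelRegistration is a hypothetical helper returning a
// TfLiteRegistration whose callbacks run the delegated subgraph.
static TfLiteStatus SketchDelegatePrepare(TfLiteContext* context,
                                          TfLiteDelegate* delegate) {
  TfLiteIntArray* plan = nullptr;
  if (context->GetExecutionPlan(context, &plan) != kTfLiteOk) {
    return kTfLiteError;
  }
  // A real delegate would inspect each node and claim only what it supports;
  // ReplaceSubgraphsWithDelegateKernels does not take ownership of `plan`.
  return context->ReplaceSubgraphsWithDelegateKernels(
      context, GetSketchKernelRegistration(), plan, delegate);
}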
-
-// WARNING: This is an experimental interface that is subject to change.
-//
-// Currently, TfLiteDelegateParams has to be allocated in a way that it's
-// trivially destructible. It will be stored as the `builtin_data` field in
-// the `TfLiteNode` of the delegate node.
-//
-// See also the `CreateDelegateParams` function in `interpreter.cc` for
-// details.
-typedef struct {
- TfLiteDelegate* delegate;
- TfLiteIntArray* nodes_to_replace;
- TfLiteIntArray* input_tensors;
- TfLiteIntArray* output_tensors;
-} TfLiteDelegateParams;
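// Sketch, following the comment above: the runtime stores these params as the
// delegate node's `builtin_data`, so the delegate kernel can recover which
// nodes and tensors it was handed.
static TfLiteStatus SketchDelegateKernelPrepare(TfLiteContext* context,
                                                TfLiteNode* node) {
  const TfLiteDelegateParams* params =
      static_cast<const TfLiteDelegateParams*>(node->builtin_data);
  for (int i = 0; i < params->nodes_to_replace->size; ++i) {
    int node_index = params->nodes_to_replace->data[i];
    // ... set up the delegated computation for node_index ...
    (void)node_index;
  }
  return kTfLiteOk;
}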
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
#endif // TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
diff --git a/tensorflow/contrib/lite/context_util.h b/tensorflow/contrib/lite/context_util.h
index abe802e342..ccda4c7393 100644
--- a/tensorflow/contrib/lite/context_util.h
+++ b/tensorflow/contrib/lite/context_util.h
@@ -17,7 +17,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_
#define TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/core/api/BUILD b/tensorflow/contrib/lite/core/api/BUILD
new file mode 100644
index 0000000000..e4500534f3
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/BUILD
@@ -0,0 +1,57 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
+
+cc_library(
+ name = "api",
+ srcs = [
+ "error_reporter.cc",
+ "flatbuffer_conversions.cc",
+ "op_resolver.cc",
+ ],
+ hdrs = [
+ "error_reporter.h",
+ "flatbuffer_conversions.h",
+ "op_resolver.h",
+ ],
+ copts = tflite_copts(),
+ deps = [
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ ],
+)
+
+cc_test(
+ name = "error_reporter_test",
+ size = "small",
+ srcs = ["error_reporter_test.cc"],
+ deps = [
+ ":api",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_test(
+ name = "op_resolver_test",
+ size = "small",
+ srcs = ["op_resolver_test.cc"],
+ deps = [
+ ":api",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_test(
+ name = "flatbuffer_conversions_test",
+ size = "small",
+ srcs = ["flatbuffer_conversions_test.cc"],
+ deps = [
+ ":api",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/tensorflow/compiler/xla/ptr_util.h b/tensorflow/contrib/lite/core/api/error_reporter.cc
index bfcdfc62f9..423f83b1a9 100644
--- a/tensorflow/compiler/xla/ptr_util.h
+++ b/tensorflow/contrib/lite/core/api/error_reporter.cc
@@ -12,24 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_
-#define TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_
-
-// As this was moved to tensorflow/core/util, provide indirections here to
-// maintain current functionality of the library.
-
-#include <stddef.h>
-
-#include <memory>
-#include <type_traits>
-#include <utility>
-
-#include "tensorflow/core/util/ptr_util.h"
-
-namespace xla {
-using tensorflow::MakeUnique;
-using tensorflow::WrapUnique;
-} // namespace xla
-
-#endif // TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include <cstdarg>
+
+namespace tflite {
+
+int ErrorReporter::Report(const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ int code = Report(format, args);
+ va_end(args);
+ return code;
+}
+
+// TODO(aselle): Make the name of ReportError on context the same, so
+// we can use the ensure functions w/o a context and w/ a reporter.
+int ErrorReporter::ReportError(void*, const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ int code = Report(format, args);
+ va_end(args);
+ return code;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/core/api/error_reporter.h b/tensorflow/contrib/lite/core/api/error_reporter.h
new file mode 100644
index 0000000000..a2f780b003
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/error_reporter.h
@@ -0,0 +1,45 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_
+#define TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_
+
+#include <cstdarg>
+
+namespace tflite {
+
+// A functor that reports errors to the supporting system. Invoked similar to
+// printf.
+//
+// Usage:
+// ErrorReporter foo;
+// foo.Report("test %d", 5);
+// or
+// va_list args;
+// foo.Report("test %d", args); // where args is va_list
+//
+// Subclass ErrorReporter to provide another reporting destination.
+// For example, if you have a GUI program, you might redirect to a buffer
+// that drives a GUI error log box.
+class ErrorReporter {
+ public:
+ virtual ~ErrorReporter() {}
+ virtual int Report(const char* format, va_list args) = 0;
+ int Report(const char* format, ...);
+ int ReportError(void*, const char* format, ...);
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_
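// A minimal concrete reporter, sketched here for illustration (not part of
// this patch): override the pure-virtual va_list overload and forward to
// vfprintf.
#include <cstdio>

namespace tflite {
class StderrReporter : public ErrorReporter {
 public:
  int Report(const char* format, va_list args) override {
    return vfprintf(stderr, format, args);
  }
};
}  // namespace tflite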
diff --git a/tensorflow/contrib/lite/core/api/error_reporter_test.cc b/tensorflow/contrib/lite/core/api/error_reporter_test.cc
new file mode 100644
index 0000000000..0463eee6be
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/error_reporter_test.cc
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+
+#include <cstdio>
+
+#include <gtest/gtest.h>
+
+namespace tflite {
+
+class MockErrorReporter : public ErrorReporter {
+ public:
+ int Report(const char* format, va_list args) override {
+ vsnprintf(buffer_, kBufferSize, format, args);
+ return 0;
+ }
+ char* GetBuffer() { return buffer_; }
+
+ private:
+ static constexpr int kBufferSize = 256;
+ char buffer_[kBufferSize];
+};
+
+TEST(ErrorReporter, TestReport) {
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+ reporter->Report("Error: %d", 23);
+ EXPECT_EQ(0, strcmp(mock_reporter.GetBuffer(), "Error: 23"));
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
new file mode 100644
index 0000000000..1420fbcdc6
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
@@ -0,0 +1,622 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"
+
+#include <cstdlib>
+#include <type_traits>  // for std::is_pod used by MallocPOD below
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+
+namespace tflite {
+
+namespace {
+
+// Copies the contents of the flatbuffer int vector `flat_vector` into the
+// int array `buffer`. The two arrays represent the same configuration data
+// for a given operation.
+void FlatBufferIntVectorToArray(int max_size_of_buffer,
+ const flatbuffers::Vector<int32_t>* flat_vector,
+ int* buffer, ErrorReporter* error_reporter) {
+ if (!flat_vector) {
+ error_reporter->Report("Input array not provided for operation.\n");
+ } else {
+ int num_dimensions = flat_vector->Length();
+ if (num_dimensions > max_size_of_buffer / sizeof(int)) {
+ error_reporter->Report(
+ "Found too many dimensions in the operation's input array.\n");
+ } else {
+ for (int i = 0; i < num_dimensions; ++i) {
+ buffer[i] = flat_vector->Get(i);
+ }
+ }
+ }
+}
+
+// Allocate a structure using malloc, but make sure the structure is a POD
+// structure that doesn't require constructors to run. The reason we do this
+// is that the Interpreter's C extension part will take ownership, so
+// destructors will not be run during deallocation.
+template <class T>
+T* MallocPOD() {
+ static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
+ return static_cast<T*>(malloc(sizeof(T)));
+}
+
+} // namespace
+
+TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
+ ErrorReporter* error_reporter) {
+ switch (tensor_type) {
+ case TensorType_FLOAT32:
+ *type = kTfLiteFloat32;
+ break;
+ case TensorType_INT16:
+ *type = kTfLiteInt16;
+ break;
+ case TensorType_INT32:
+ *type = kTfLiteInt32;
+ break;
+ case TensorType_UINT8:
+ *type = kTfLiteUInt8;
+ break;
+ case TensorType_INT64:
+ *type = kTfLiteInt64;
+ break;
+ case TensorType_STRING:
+ *type = kTfLiteString;
+ break;
+ case TensorType_BOOL:
+ *type = kTfLiteBool;
+ break;
+ case TensorType_COMPLEX64:
+ *type = kTfLiteComplex64;
+ break;
+ default:
+ error_reporter->Report("Unimplemented data type %s (%d) in tensor\n",
+ EnumNameTensorType(tensor_type), tensor_type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+// Parse the appropriate data out of the op.
+//
+// This handles builtin data explicitly as there are flatbuffer schemas.
+// If it returns kTfLiteOk, it passes the data out with `builtin_data`, which
+// must be released by calling `free`.
+// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
+TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
+ ErrorReporter* error_reporter, void** builtin_data) {
+ auto parse_padding = [](Padding padding) {
+ switch (padding) {
+ case Padding_SAME:
+ return kTfLitePaddingSame;
+ case Padding_VALID:
+ return kTfLitePaddingValid;
+ }
+ return kTfLitePaddingUnknown;
+ };
+ auto parse_activation = [](ActivationFunctionType activation) {
+ switch (activation) {
+ case ActivationFunctionType_NONE:
+ return kTfLiteActNone;
+ case ActivationFunctionType_RELU:
+ return kTfLiteActRelu;
+ case ActivationFunctionType_RELU_N1_TO_1:
+ return kTfLiteActRelu1;
+ case ActivationFunctionType_RELU6:
+ return kTfLiteActRelu6;
+ case ActivationFunctionType_TANH:
+ return kTfLiteActTanh;
+ case ActivationFunctionType_SIGN_BIT:
+ return kTfLiteActSignBit;
+ }
+ return kTfLiteActNone;
+ };
+ auto parseLSHProjectionType = [](LSHProjectionType type) {
+ switch (type) {
+ case LSHProjectionType_SPARSE:
+ return kTfLiteLshProjectionSparse;
+ case LSHProjectionType_DENSE:
+ return kTfLiteLshProjectionDense;
+ default:
+ return kTfLiteLshProjectionUnknown;
+ }
+ };
+ auto parseCombinerType = [](CombinerType type) {
+ switch (type) {
+ case CombinerType_MEAN:
+ return kTfLiteCombinerTypeMean;
+ case CombinerType_SQRTN:
+ return kTfLiteCombinerTypeSqrtn;
+ case CombinerType_SUM:
+ default:
+ return kTfLiteCombinerTypeSum;
+ }
+ };
+
+ *builtin_data = nullptr;
+ switch (op_type) {
+ case BuiltinOperator_CONV_2D: {
+ TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
+ if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
+ params->padding = parse_padding(conv_params->padding());
+ params->stride_width = conv_params->stride_w();
+ params->stride_height = conv_params->stride_h();
+ params->activation =
+ parse_activation(conv_params->fused_activation_function());
+
+ params->dilation_width_factor = conv_params->dilation_w_factor();
+ params->dilation_height_factor = conv_params->dilation_h_factor();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_CAST: {
+ TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
+ if (auto* schema_params = op->builtin_options_as_CastOptions()) {
+ auto in_status =
+ ConvertTensorType(schema_params->in_data_type(),
+ &params->in_data_type, error_reporter);
+ auto out_status =
+ ConvertTensorType(schema_params->out_data_type(),
+ &params->out_data_type, error_reporter);
+ if (in_status != kTfLiteOk || out_status != kTfLiteOk) {
+ free(params);
+ return kTfLiteError;
+ }
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_LSH_PROJECTION: {
+ TfLiteLSHProjectionParams* params =
+ MallocPOD<TfLiteLSHProjectionParams>();
+ if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) {
+ params->type = parseLSHProjectionType(lshParams->type());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_AVERAGE_POOL_2D:
+ case BuiltinOperator_MAX_POOL_2D:
+ case BuiltinOperator_L2_POOL_2D: {
+ TfLitePoolParams* params = MallocPOD<TfLitePoolParams>();
+ if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
+ params->padding = parse_padding(pool_params->padding());
+ params->stride_width = pool_params->stride_w();
+ params->stride_height = pool_params->stride_h();
+ params->filter_width = pool_params->filter_width();
+ params->filter_height = pool_params->filter_height();
+ params->activation =
+ parse_activation(pool_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_DEPTHWISE_CONV_2D: {
+ TfLiteDepthwiseConvParams* params =
+ MallocPOD<TfLiteDepthwiseConvParams>();
+ if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) {
+ params->padding = parse_padding(conv_params->padding());
+ params->stride_width = conv_params->stride_w();
+ params->stride_height = conv_params->stride_h();
+ params->depth_multiplier = conv_params->depth_multiplier();
+ params->activation =
+ parse_activation(conv_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SVDF: {
+ TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>();
+ if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) {
+ params->rank = svdf_params->rank();
+ params->activation =
+ parse_activation(svdf_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
+ case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
+ TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>();
+ if (auto* sequence_rnn_params =
+ op->builtin_options_as_SequenceRNNOptions()) {
+ params->activation =
+ parse_activation(sequence_rnn_params->fused_activation_function());
+ params->time_major = sequence_rnn_params->time_major();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_RNN: {
+ TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>();
+ if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
+ params->activation =
+ parse_activation(rnn_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
+ TfLiteEmbeddingLookupSparseParams* params =
+ MallocPOD<TfLiteEmbeddingLookupSparseParams>();
+ if (auto* embedding_params =
+ op->builtin_options_as_EmbeddingLookupSparseOptions()) {
+ params->combiner = parseCombinerType(embedding_params->combiner());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_FULLY_CONNECTED: {
+ TfLiteFullyConnectedParams* params =
+ MallocPOD<TfLiteFullyConnectedParams>();
+ if (auto* fully_connected_params =
+ op->builtin_options_as_FullyConnectedOptions()) {
+ params->activation = parse_activation(
+ fully_connected_params->fused_activation_function());
+ switch (fully_connected_params->weights_format()) {
+ case FullyConnectedOptionsWeightsFormat_DEFAULT:
+ params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault;
+ break;
+ case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
+ params->weights_format =
+ kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8;
+ break;
+ default:
+ error_reporter->Report("Unhandled fully-connected weights format.");
+ return kTfLiteError;
+ }
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_HASHTABLE_LOOKUP:
+ // no-op.
+ break;
+ case BuiltinOperator_SOFTMAX: {
+ TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>();
+ if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) {
+ params->beta = softmax_params->beta();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_CONCATENATION: {
+ TfLiteConcatenationParams* params =
+ MallocPOD<TfLiteConcatenationParams>();
+ if (auto* concatenation_params =
+ op->builtin_options_as_ConcatenationOptions()) {
+ params->activation =
+ parse_activation(concatenation_params->fused_activation_function());
+ params->axis = concatenation_params->axis();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_MUL: {
+ auto* params = MallocPOD<TfLiteMulParams>();
+ if (auto* schema_params = op->builtin_options_as_MulOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_ADD: {
+ auto* params = MallocPOD<TfLiteAddParams>();
+ if (auto* schema_params = op->builtin_options_as_AddOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_DIV: {
+ auto* params = MallocPOD<TfLiteDivParams>();
+ if (auto* schema_params = op->builtin_options_as_DivOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SUB: {
+ auto* params = MallocPOD<TfLiteSubParams>();
+ if (auto* schema_params = op->builtin_options_as_SubOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_L2_NORMALIZATION: {
+ auto* params = MallocPOD<TfLiteL2NormParams>();
+ if (auto* schema_params = op->builtin_options_as_L2NormOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
+ auto* params = MallocPOD<TfLiteLocalResponseNormParams>();
+ if (auto* schema_params =
+ op->builtin_options_as_LocalResponseNormalizationOptions()) {
+ params->radius = schema_params->radius();
+ params->bias = schema_params->bias();
+ params->alpha = schema_params->alpha();
+ params->beta = schema_params->beta();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
+ case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
+ case BuiltinOperator_LSTM: {
+ TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>();
+ if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
+ params->activation =
+ parse_activation(lstm_params->fused_activation_function());
+ params->cell_clip = lstm_params->cell_clip();
+ params->proj_clip = lstm_params->proj_clip();
+ switch (lstm_params->kernel_type()) {
+ case LSTMKernelType_FULL:
+ params->kernel_type = kTfLiteLSTMFullKernel;
+ break;
+ case LSTMKernelType_BASIC:
+ params->kernel_type = kTfLiteLSTMBasicKernel;
+ break;
+ }
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_RESIZE_BILINEAR: {
+ auto* params = MallocPOD<TfLiteResizeBilinearParams>();
+ if (auto* schema_params =
+ op->builtin_options_as_ResizeBilinearOptions()) {
+ params->align_corners = schema_params->align_corners();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_RESHAPE: {
+ auto* params = MallocPOD<TfLiteReshapeParams>();
+ if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
+ auto* new_shape = schema_params->new_shape();
+ FlatBufferIntVectorToArray(sizeof(params->shape), new_shape,
+ params->shape, error_reporter);
+ params->num_dimensions = new_shape->Length();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SKIP_GRAM: {
+ TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>();
+ if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) {
+ params->ngram_size = skip_gram_params->ngram_size();
+ params->max_skip_size = skip_gram_params->max_skip_size();
+ params->include_all_ngrams = skip_gram_params->include_all_ngrams();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SPACE_TO_DEPTH: {
+ auto* params = MallocPOD<TfLiteSpaceToDepthParams>();
+ if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) {
+ params->block_size = schema_params->block_size();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_GATHER: {
+ TfLiteGatherParams* params = MallocPOD<TfLiteGatherParams>();
+ params->axis = 0;
+ if (auto* gather_params = op->builtin_options_as_GatherOptions()) {
+ params->axis = gather_params->axis();
+ }
+
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_MEAN:
+ case BuiltinOperator_REDUCE_MAX:
+ case BuiltinOperator_REDUCE_MIN:
+ case BuiltinOperator_REDUCE_PROD:
+ case BuiltinOperator_REDUCE_ANY:
+ case BuiltinOperator_SUM: {
+ auto* params = MallocPOD<TfLiteReducerParams>();
+ if (auto* schema_params = op->builtin_options_as_ReducerOptions()) {
+ params->keep_dims = schema_params->keep_dims();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SPLIT: {
+ auto* params = MallocPOD<TfLiteSplitParams>();
+ if (auto* schema_params = op->builtin_options_as_SplitOptions()) {
+ params->num_splits = schema_params->num_splits();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SQUEEZE: {
+ auto* params = MallocPOD<TfLiteSqueezeParams>();
+ if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) {
+ const auto& squeeze_dims = schema_params->squeeze_dims();
+ FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims,
+ params->squeeze_dims, error_reporter);
+ params->num_squeeze_dims = squeeze_dims->Length();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_STRIDED_SLICE: {
+ auto* params = MallocPOD<TfLiteStridedSliceParams>();
+ if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) {
+ params->begin_mask = schema_params->begin_mask();
+ params->end_mask = schema_params->end_mask();
+ params->ellipsis_mask = schema_params->ellipsis_mask();
+ params->new_axis_mask = schema_params->new_axis_mask();
+ params->shrink_axis_mask = schema_params->shrink_axis_mask();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_ARG_MAX: {
+ auto* params = MallocPOD<TfLiteArgMaxParams>();
+ if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) {
+ ConvertTensorType(schema_params->output_type(), &params->output_type,
+ error_reporter);
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_ARG_MIN: {
+ auto* params = MallocPOD<TfLiteArgMinParams>();
+ if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) {
+ ConvertTensorType(schema_params->output_type(), &params->output_type,
+ error_reporter);
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_TRANSPOSE_CONV: {
+ TfLiteTransposeConvParams* params =
+ MallocPOD<TfLiteTransposeConvParams>();
+ if (auto* transpose_conv_params =
+ op->builtin_options_as_TransposeConvOptions()) {
+ params->padding = parse_padding(transpose_conv_params->padding());
+ params->stride_width = transpose_conv_params->stride_w();
+ params->stride_height = transpose_conv_params->stride_h();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SPARSE_TO_DENSE: {
+ TfLiteSparseToDenseParams* params =
+ MallocPOD<TfLiteSparseToDenseParams>();
+ if (auto* sparse_to_dense_params =
+ op->builtin_options_as_SparseToDenseOptions()) {
+ params->validate_indices = sparse_to_dense_params->validate_indices();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SHAPE: {
+ auto* params = MallocPOD<TfLiteShapeParams>();
+ if (auto* schema_params = op->builtin_options_as_ShapeOptions()) {
+ ConvertTensorType(schema_params->out_type(), &params->out_type,
+ error_reporter);
+ }
+ *builtin_data = static_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_PACK: {
+ TfLitePackParams* params = MallocPOD<TfLitePackParams>();
+ if (auto* pack_params = op->builtin_options_as_PackOptions()) {
+ params->values_count = pack_params->values_count();
+ params->axis = pack_params->axis();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_DELEGATE: {
+ // TODO(ycling): Revisit when supporting saving delegated models.
+ error_reporter->Report("DELEGATE op shouldn't exist in model.");
+ return kTfLiteError;
+ }
+ case BuiltinOperator_FAKE_QUANT: {
+ auto* params = MallocPOD<TfLiteFakeQuantParams>();
+ if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) {
+ params->min = schema_params->min();
+ params->max = schema_params->max();
+ params->num_bits = schema_params->num_bits();
+ params->narrow_range = schema_params->narrow_range();
+ }
+ *builtin_data = static_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_ONE_HOT: {
+ auto* params = MallocPOD<TfLiteOneHotParams>();
+ if (auto* schema_params = op->builtin_options_as_OneHotOptions()) {
+ params->axis = schema_params->axis();
+ }
+ *builtin_data = static_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_UNPACK: {
+ TfLiteUnpackParams* params = MallocPOD<TfLiteUnpackParams>();
+ if (auto* unpack_params = op->builtin_options_as_UnpackOptions()) {
+ params->num = unpack_params->num();
+ params->axis = unpack_params->axis();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+
+ // Below are the ops with no builtin_data structure.
+ case BuiltinOperator_BATCH_TO_SPACE_ND:
+ // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
+ // ok for now, since there is no call implementation either.
+ case BuiltinOperator_CALL:
+ case BuiltinOperator_CONCAT_EMBEDDINGS:
+ case BuiltinOperator_CUSTOM:
+ case BuiltinOperator_DEQUANTIZE:
+ case BuiltinOperator_EMBEDDING_LOOKUP:
+ case BuiltinOperator_EQUAL:
+ case BuiltinOperator_EXP:
+ case BuiltinOperator_EXPAND_DIMS:
+ case BuiltinOperator_FLOOR:
+ case BuiltinOperator_GREATER:
+ case BuiltinOperator_GREATER_EQUAL:
+ case BuiltinOperator_LESS:
+ case BuiltinOperator_LESS_EQUAL:
+ case BuiltinOperator_LOG:
+ case BuiltinOperator_LOGISTIC:
+ case BuiltinOperator_LOG_SOFTMAX:
+ case BuiltinOperator_MAXIMUM:
+ case BuiltinOperator_MINIMUM:
+ case BuiltinOperator_NEG:
+ case BuiltinOperator_NOT_EQUAL:
+ case BuiltinOperator_PAD:
+ case BuiltinOperator_PADV2:
+ case BuiltinOperator_PRELU:
+ case BuiltinOperator_RELU:
+ case BuiltinOperator_RELU6:
+ case BuiltinOperator_RELU_N1_TO_1:
+ case BuiltinOperator_RSQRT:
+ case BuiltinOperator_SELECT:
+ case BuiltinOperator_SIN:
+ case BuiltinOperator_SLICE:
+ case BuiltinOperator_SPACE_TO_BATCH_ND:
+ case BuiltinOperator_SQRT:
+ case BuiltinOperator_TANH:
+ case BuiltinOperator_TILE:
+ case BuiltinOperator_TOPK_V2:
+ case BuiltinOperator_TRANSPOSE:
+ case BuiltinOperator_POW:
+ case BuiltinOperator_LOGICAL_OR:
+ case BuiltinOperator_LOGICAL_AND:
+ case BuiltinOperator_LOGICAL_NOT:
+ case BuiltinOperator_FLOOR_DIV:
+ break;
+ }
+ return kTfLiteOk;
+} // NOLINT[readability/fn_size]
+
+} // namespace tflite
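// Usage sketch for ParseOpData above, assuming `op` and `op_type` were read
// out of a flatbuffer model: on success the returned builtin_data was
// malloc'd by ParseOpData and must be released with free().
static void ParseOpDataSketch(const tflite::Operator* op,
                              tflite::BuiltinOperator op_type,
                              tflite::ErrorReporter* reporter) {
  void* builtin_data = nullptr;
  if (tflite::ParseOpData(op, op_type, reporter, &builtin_data) == kTfLiteOk &&
      builtin_data != nullptr) {
    // ... hand builtin_data to the node, or inspect it ...
    free(builtin_data);
  }
}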
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h
new file mode 100644
index 0000000000..4dec6f9cfc
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h
@@ -0,0 +1,48 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
+#define TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
+
+// These functions transform codes and data structures that are defined in the
+// flatbuffer serialization format into in-memory values that are used by the
+// runtime API and interpreter.
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+
+// Parse the appropriate data out of the op.
+//
+// This handles builtin data explicitly as there are flatbuffer schemas.
+// If it returns kTfLiteOk, it passes the data out with `builtin_data`. The
+// data is allocated on the heap with `malloc`, so it is the calling
+// function's responsibility to release it by calling `free`.
+// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
+TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
+ ErrorReporter* error_reporter, void** builtin_data);
+
+// Converts the tensor data type used in the flat buffer to the representation
+// used by the runtime.
+TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
+ ErrorReporter* error_reporter);
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc
new file mode 100644
index 0000000000..b12bdf43b2
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc
@@ -0,0 +1,104 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"
+
+#include <cstring>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+
+namespace tflite {
+namespace {
+
+class MockErrorReporter : public ErrorReporter {
+ public:
+ MockErrorReporter() : buffer_size_(0) {}
+ int Report(const char* format, va_list args) override {
+ buffer_size_ = vsnprintf(buffer_, kBufferSize, format, args);
+ return buffer_size_;
+ }
+ char* GetBuffer() { return buffer_; }
+ int GetBufferSize() { return buffer_size_; }
+
+ private:
+ static constexpr int kBufferSize = 256;
+ char buffer_[kBufferSize];
+ int buffer_size_;
+};
+
+} // namespace
+
+TEST(FlatbufferConversions, TestParseOpDataConv) {
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<void> conv_options =
+ CreateConv2DOptions(builder, Padding_SAME, 1, 2,
+ ActivationFunctionType_RELU, 3, 4)
+ .Union();
+ flatbuffers::Offset<Operator> conv_offset = CreateOperatorDirect(
+ builder, 0, nullptr, nullptr, BuiltinOptions_Conv2DOptions, conv_options,
+ nullptr, CustomOptionsFormat_FLEXBUFFERS, nullptr);
+ builder.Finish(conv_offset);
+ void* conv_pointer = builder.GetBufferPointer();
+ const Operator* conv_op = flatbuffers::GetRoot<Operator>(conv_pointer);
+ void* output_data = nullptr;
+ EXPECT_EQ(kTfLiteOk, ParseOpData(conv_op, BuiltinOperator_CONV_2D, reporter,
+ &output_data));
+ EXPECT_NE(nullptr, output_data);
+ TfLiteConvParams* params = reinterpret_cast<TfLiteConvParams*>(output_data);
+ EXPECT_EQ(kTfLitePaddingSame, params->padding);
+ EXPECT_EQ(1, params->stride_width);
+ EXPECT_EQ(2, params->stride_height);
+ EXPECT_EQ(kTfLiteActRelu, params->activation);
+ EXPECT_EQ(3, params->dilation_width_factor);
+ EXPECT_EQ(4, params->dilation_height_factor);
+ free(output_data);
+}
+
+TEST(FlatbufferConversions, TestParseOpDataCustom) {
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<void> null_options;
+ flatbuffers::Offset<Operator> custom_offset = CreateOperatorDirect(
+ builder, 0, nullptr, nullptr, BuiltinOptions_NONE, null_options, nullptr,
+ CustomOptionsFormat_FLEXBUFFERS, nullptr);
+ builder.Finish(custom_offset);
+ void* custom_pointer = builder.GetBufferPointer();
+ const Operator* custom_op = flatbuffers::GetRoot<Operator>(custom_pointer);
+ void* output_data = nullptr;
+ EXPECT_EQ(kTfLiteOk, ParseOpData(custom_op, BuiltinOperator_CUSTOM, reporter,
+ &output_data));
+ EXPECT_EQ(nullptr, output_data);
+}
+
+TEST(FlatbufferConversions, TestConvertTensorType) {
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+ TfLiteType type;
+ EXPECT_EQ(kTfLiteOk, ConvertTensorType(TensorType_FLOAT32, &type, reporter));
+ EXPECT_EQ(kTfLiteFloat32, type);
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/core/api/op_resolver.cc b/tensorflow/contrib/lite/core/api/op_resolver.cc
new file mode 100644
index 0000000000..55ee924843
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/op_resolver.cc
@@ -0,0 +1,60 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+
+namespace tflite {
+
+TfLiteStatus GetRegistrationFromOpCode(
+ const OperatorCode* opcode, const OpResolver& op_resolver,
+ ErrorReporter* error_reporter, const TfLiteRegistration** registration) {
+ TfLiteStatus status = kTfLiteOk;
+ *registration = nullptr;
+ auto builtin_code = opcode->builtin_code();
+ int version = opcode->version();
+
+ if (builtin_code > BuiltinOperator_MAX ||
+ builtin_code < BuiltinOperator_MIN) {
+ error_reporter->Report(
+ "Op builtin_code out of range: %d. Are you using old TFLite binary "
+ "with newer model?",
+ builtin_code);
+ status = kTfLiteError;
+ } else if (builtin_code != BuiltinOperator_CUSTOM) {
+ *registration = op_resolver.FindOp(builtin_code, version);
+ if (*registration == nullptr) {
+ error_reporter->Report(
+ "Didn't find op for builtin opcode '%s' version '%d'\n",
+ EnumNameBuiltinOperator(builtin_code), version);
+ status = kTfLiteError;
+ }
+ } else if (!opcode->custom_code()) {
+ error_reporter->Report(
+ "Operator with CUSTOM builtin_code has no custom_code.\n");
+ status = kTfLiteError;
+ } else {
+ const char* name = opcode->custom_code()->c_str();
+ *registration = op_resolver.FindOp(name, version);
+ if (*registration == nullptr) {
+ error_reporter->Report(
+ "Didn't find custom op for name '%s' with version %d\n", name,
+ version);
+ status = kTfLiteError;
+ }
+ }
+ return status;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/core/api/op_resolver.h b/tensorflow/contrib/lite/core/api/op_resolver.h
new file mode 100644
index 0000000000..5f5e6b2736
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/op_resolver.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_
+#define TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+
+// Abstract interface that returns TfLiteRegistrations given op codes or custom
+// op names. This is the mechanism by which ops referenced in the flatbuffer
+// model are mapped to executable function pointers (TfLiteRegistrations).
+class OpResolver {
+ public:
+ // Finds the op registration for a builtin operator by enum code.
+ virtual const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
+ int version) const = 0;
+ // Finds the op registration of a custom operator by op name.
+ virtual const TfLiteRegistration* FindOp(const char* op,
+ int version) const = 0;
+ virtual ~OpResolver() {}
+};
+
+// Handles the logic for converting between an OperatorCode structure extracted
+// from a flatbuffer and information about a registered operator implementation.
+TfLiteStatus GetRegistrationFromOpCode(const OperatorCode* opcode,
+ const OpResolver& op_resolver,
+ ErrorReporter* error_reporter,
+ const TfLiteRegistration** registration);
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_
diff --git a/tensorflow/contrib/lite/core/api/op_resolver_test.cc b/tensorflow/contrib/lite/core/api/op_resolver_test.cc
new file mode 100644
index 0000000000..167463110e
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/op_resolver_test.cc
@@ -0,0 +1,197 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+
+#include <cstring>
+
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace {
+void* MockInit(TfLiteContext* context, const char* buffer, size_t length) {
+ // Do nothing.
+ return nullptr;
+}
+
+void MockFree(TfLiteContext* context, void* buffer) {
+ // Do nothing.
+}
+
+TfLiteStatus MockPrepare(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
+TfLiteStatus MockInvoke(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
+class MockOpResolver : public OpResolver {
+ public:
+ const TfLiteRegistration* FindOp(BuiltinOperator op,
+ int version) const override {
+ if (op == BuiltinOperator_CONV_2D) {
+ static TfLiteRegistration r = {MockInit, MockFree, MockPrepare,
+ MockInvoke};
+ return &r;
+ } else {
+ return nullptr;
+ }
+ }
+ const TfLiteRegistration* FindOp(const char* op, int version) const override {
+ if (strcmp(op, "mock_custom") == 0) {
+ static TfLiteRegistration r = {MockInit, MockFree, MockPrepare,
+ MockInvoke};
+ return &r;
+ } else {
+ return nullptr;
+ }
+ }
+};
+
+class MockErrorReporter : public ErrorReporter {
+ public:
+ MockErrorReporter() : buffer_size_(0) {}
+ int Report(const char* format, va_list args) override {
+ buffer_size_ = vsnprintf(buffer_, kBufferSize, format, args);
+ return buffer_size_;
+ }
+ char* GetBuffer() { return buffer_; }
+ int GetBufferSize() { return buffer_size_; }
+
+ private:
+ static constexpr int kBufferSize = 256;
+ char buffer_[kBufferSize];
+ int buffer_size_;
+};
+
+} // namespace
+
+TEST(OpResolver, TestResolver) {
+ MockOpResolver mock_resolver;
+ OpResolver* resolver = &mock_resolver;
+
+ const TfLiteRegistration* registration =
+ resolver->FindOp(BuiltinOperator_CONV_2D, 0);
+ EXPECT_NE(nullptr, registration);
+ EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+ EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+ EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+
+ registration = resolver->FindOp(BuiltinOperator_CAST, 0);
+ EXPECT_EQ(nullptr, registration);
+
+ registration = resolver->FindOp("mock_custom", 0);
+ EXPECT_NE(nullptr, registration);
+ EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+ EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+ EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+
+ registration = resolver->FindOp("nonexistent_custom", 0);
+ EXPECT_EQ(nullptr, registration);
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeConv) {
+ MockOpResolver mock_resolver;
+ OpResolver* resolver = &mock_resolver;
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<OperatorCode> conv_offset =
+ CreateOperatorCodeDirect(builder, BuiltinOperator_CONV_2D, nullptr, 0);
+ builder.Finish(conv_offset);
+ void* conv_pointer = builder.GetBufferPointer();
+ const OperatorCode* conv_code =
+ flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+ const TfLiteRegistration* registration = nullptr;
+ EXPECT_EQ(kTfLiteOk, GetRegistrationFromOpCode(conv_code, *resolver, reporter,
+ &registration));
+ EXPECT_NE(nullptr, registration);
+ EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+ EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+ EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+ EXPECT_EQ(0, mock_reporter.GetBufferSize());
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeCast) {
+ MockOpResolver mock_resolver;
+ OpResolver* resolver = &mock_resolver;
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<OperatorCode> conv_offset =
+ CreateOperatorCodeDirect(builder, BuiltinOperator_CAST, nullptr, 0);
+ builder.Finish(conv_offset);
+ void* conv_pointer = builder.GetBufferPointer();
+ const OperatorCode* conv_code =
+ flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+ const TfLiteRegistration* registration = nullptr;
+ EXPECT_EQ(kTfLiteError, GetRegistrationFromOpCode(conv_code, *resolver,
+ reporter, &registration));
+ EXPECT_EQ(nullptr, registration);
+ EXPECT_NE(0, mock_reporter.GetBufferSize());
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeCustom) {
+ MockOpResolver mock_resolver;
+ OpResolver* resolver = &mock_resolver;
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<OperatorCode> conv_offset = CreateOperatorCodeDirect(
+ builder, BuiltinOperator_CUSTOM, "mock_custom", 0);
+ builder.Finish(conv_offset);
+ void* conv_pointer = builder.GetBufferPointer();
+ const OperatorCode* conv_code =
+ flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+ const TfLiteRegistration* registration = nullptr;
+ EXPECT_EQ(kTfLiteOk, GetRegistrationFromOpCode(conv_code, *resolver, reporter,
+ &registration));
+ EXPECT_NE(nullptr, registration);
+ EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+ EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+ EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+ EXPECT_EQ(0, mock_reporter.GetBufferSize());
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeNonexistentCustom) {
+ MockOpResolver mock_resolver;
+ OpResolver* resolver = &mock_resolver;
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<OperatorCode> conv_offset = CreateOperatorCodeDirect(
+ builder, BuiltinOperator_CUSTOM, "nonexistent_custom", 0);
+ builder.Finish(conv_offset);
+ void* conv_pointer = builder.GetBufferPointer();
+ const OperatorCode* conv_code =
+ flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+ const TfLiteRegistration* registration = nullptr;
+ EXPECT_EQ(kTfLiteError, GetRegistrationFromOpCode(conv_code, *resolver,
+ reporter, &registration));
+ EXPECT_EQ(nullptr, registration);
+ EXPECT_NE(0, mock_reporter.GetBufferSize());
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD
index 87486e8814..bf5d91899c 100644
--- a/tensorflow/contrib/lite/delegates/eager/BUILD
+++ b/tensorflow/contrib/lite/delegates/eager/BUILD
@@ -16,10 +16,11 @@ cc_library(
deps = [
":util",
"//tensorflow/c:c_api_internal",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite:kernel_api",
] + select({
"//tensorflow:android": [
- "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_lib_lite_no_runtime",
],
"//conditions:default": [
"//tensorflow/core:framework",
@@ -54,11 +55,12 @@ cc_library(
":delegate_data",
":kernel",
":util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite:kernel_api",
"//tensorflow/contrib/lite:util",
] + select({
"//tensorflow:android": [
- "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_lib_lite_no_runtime",
],
"//conditions:default": [
"//tensorflow/core:lib",
@@ -87,7 +89,7 @@ cc_library(
"//tensorflow/core/common_runtime/eager:context",
] + select({
"//tensorflow:android": [
- "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:core_cpu",
@@ -104,6 +106,7 @@ tf_cc_test(
":delegate_data",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
],
@@ -117,6 +120,7 @@ cc_library(
":delegate_data",
":util",
"@flatbuffers",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite:kernel_api",
"//tensorflow/contrib/lite:string",
"//tensorflow/contrib/lite/kernels:kernel_util",
@@ -124,11 +128,16 @@ cc_library(
"//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:tensor_handle",
] + select({
+ # TODO(b/111881878): The android_tensorflow_lib target pulls in the full
+ # set of core TensorFlow kernels. We may want to revisit this dependency
+ # to allow selective registration via build targets.
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib",
],
"//conditions:default": [
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:tensorflow",
],
}),
)
@@ -164,12 +173,12 @@ cc_library(
srcs = ["util.cc"],
hdrs = ["util.h"],
deps = [
- ":constants",
"//tensorflow/c:c_api_internal",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite:kernel_api",
] + select({
"//tensorflow:android": [
- "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_lib_lite_no_runtime",
],
"//conditions:default": [
"//tensorflow/core:lib",
@@ -189,8 +198,3 @@ tf_cc_test(
"@com_google_googletest//:gtest",
],
)
-
-cc_library(
- name = "constants",
- hdrs = ["constants.h"],
-)
diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.h b/tensorflow/contrib/lite/delegates/eager/buffer_map.h
index a28329ae7d..aaaa045840 100644
--- a/tensorflow/contrib/lite/delegates/eager/buffer_map.h
+++ b/tensorflow/contrib/lite/delegates/eager/buffer_map.h
@@ -17,7 +17,7 @@ limitations under the License.
#include <map>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/core/framework/tensor.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.h b/tensorflow/contrib/lite/delegates/eager/delegate.h
index 6d15ba47dc..70f3c15af4 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate.h
+++ b/tensorflow/contrib/lite/delegates/eager/delegate.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_
#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
index b3a0ffcec1..def063309f 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/testing/util.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
index eb47f46c0b..984f8bbc98 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
@@ -72,6 +72,26 @@ TEST_F(DelegateTest, FullGraph) {
ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
+ ASSERT_EQ(GetType(8), kTfLiteFloat32);
+}
+
+TEST_F(DelegateTest, NonFloatTypeInference) {
+ AddTensors(3, {0, 1}, {2}, kTfLiteInt32, {2});
+
+ AddTfOp(testing::kAdd, {0, 1}, {2});
+
+ ConfigureDelegate();
+
+ SetShape(0, {2, 2});
+ SetTypedValues<int>(0, {1, 2, 3, 4});
+ SetShape(1, {2, 2});
+ SetTypedValues<int>(1, {4, 3, 2, 1});
+
+ ASSERT_TRUE(Invoke());
+
+ ASSERT_THAT(GetShape(2), ElementsAre(2, 2));
+ ASSERT_THAT(GetTypedValues<int>(2), ElementsAre(5, 5, 5, 5));
+ ASSERT_EQ(GetType(2), kTfLiteInt32);
}
TEST_F(DelegateTest, MixedGraph) {
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.cc b/tensorflow/contrib/lite/delegates/eager/kernel.cc
index 1082b78725..274c3c082a 100644
--- a/tensorflow/contrib/lite/delegates/eager/kernel.cc
+++ b/tensorflow/contrib/lite/delegates/eager/kernel.cc
@@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/delegates/eager/kernel.h"
-#include "flatbuffers/flexbuffers.h"
+#include "flatbuffers/flexbuffers.h" // flatbuffers
#include "tensorflow/contrib/lite/builtin_ops.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/context_util.h"
#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
#include "tensorflow/contrib/lite/delegates/eager/util.h"
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
// Note: this is part of TF Lite's Eager delegation code which is to be
// completed soon.
@@ -189,6 +190,14 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
}
}
+ // Fill NodeDef with defaults if it's a valid op.
+ const tensorflow::OpRegistrationData* op_reg_data;
+ auto tf_status = tensorflow::OpRegistry::Global()->LookUp(
+ node_data.nodedef.op(), &op_reg_data);
+ if (tf_status.ok()) {
+ AddDefaultsToNodeDef(op_reg_data->op_def, &node_data.nodedef);
+ }
+
for (auto input_index : TfLiteIntArrayView(node->inputs)) {
node_data.inputs.push_back(input_index);
}
@@ -269,7 +278,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* tensor = &context->tensors[tensor_index];
TF_LITE_ENSURE_OK(
context,
- CopyShape(context, buffer_map->GetTensor(tensor_index), tensor));
+ CopyShapeAndType(context, buffer_map->GetTensor(tensor_index), tensor));
tensor->buffer_handle = tensor_index;
tensor->data_is_stale = true;
}
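
The default-filling step added to Init() above matters because the NodeDef reconstructed from the flexbuffer custom data may omit attributes that have defaults in the OpDef, which downstream eager execution may reject. A standalone sketch of the same lookup-and-fill pattern (illustrative only, using the TensorFlow headers this patch already includes):

    #include "tensorflow/core/framework/node_def.pb.h"
    #include "tensorflow/core/framework/node_def_util.h"
    #include "tensorflow/core/framework/op.h"

    // Fills in any attrs the serialized NodeDef omitted, using the defaults
    // declared in the registered OpDef. A no-op for unregistered ops.
    void FillDefaults(tensorflow::NodeDef* nodedef) {
      const tensorflow::OpRegistrationData* op_reg_data;
      auto status =
          tensorflow::OpRegistry::Global()->LookUp(nodedef->op(), &op_reg_data);
      if (status.ok()) {
        tensorflow::AddDefaultsToNodeDef(op_reg_data->op_def, nodedef);
      }
    }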
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.h b/tensorflow/contrib/lite/delegates/eager/kernel.h
index 100672c82d..2478abccaa 100644
--- a/tensorflow/contrib/lite/delegates/eager/kernel.h
+++ b/tensorflow/contrib/lite/delegates/eager/kernel.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_
#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
namespace eager {
diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.cc b/tensorflow/contrib/lite/delegates/eager/test_util.cc
index 26d96acc82..8584999ace 100644
--- a/tensorflow/contrib/lite/delegates/eager/test_util.cc
+++ b/tensorflow/contrib/lite/delegates/eager/test_util.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/delegates/eager/test_util.h"
#include "absl/memory/memory.h"
-#include "flatbuffers/flexbuffers.h"
+#include "flatbuffers/flexbuffers.h" // flatbuffers
#include "tensorflow/contrib/lite/string.h"
namespace tflite {
@@ -25,19 +25,6 @@ namespace testing {
bool EagerModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; }
-void EagerModelTest::SetValues(int tensor_index,
- const std::vector<float>& values) {
- float* v = interpreter_->typed_tensor<float>(tensor_index);
- for (float f : values) {
- *v++ = f;
- }
-}
-
-std::vector<float> EagerModelTest::GetValues(int tensor_index) {
- TfLiteTensor* o = interpreter_->tensor(tensor_index);
- return std::vector<float>(o->data.f, o->data.f + o->bytes / sizeof(float));
-}
-
void EagerModelTest::SetShape(int tensor_index,
const std::vector<int>& values) {
ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk);
@@ -54,13 +41,21 @@ std::vector<int> EagerModelTest::GetShape(int tensor_index) {
return result;
}
+TfLiteType EagerModelTest::GetType(int tensor_index) {
+ return interpreter_->tensor(tensor_index)->type;
+}
+
void EagerModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
const std::vector<int>& outputs,
- const TfLiteType& type,
- const std::vector<int>& dims) {
+ TfLiteType type, const std::vector<int>& dims) {
interpreter_->AddTensors(num_tensors);
for (int i = 0; i < num_tensors; ++i) {
TfLiteQuantizationParams quant;
+ // Force output tensors to float32 so these tests exercise the delegate's
+ // type inference rather than relying on an explicitly specified type.
+ if (std::find(outputs.begin(), outputs.end(), i) != outputs.end()) {
+ type = kTfLiteFloat32;
+ }
CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, type,
/*name=*/"",
/*dims=*/dims, quant),
@@ -101,18 +96,26 @@ void EagerModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
return " attr{ key: '" + key + "' value {" + value + "}}";
};
+ // Crude type attribution, will need fleshing out as more tests are added.
+ // TODO(b/113613439): Use nodedef string utilities to properly handle
+ // all types.
+ string type_attribute = attr("T", "type: DT_FLOAT");
+ if (interpreter_->tensor(inputs[0])->type == kTfLiteInt32) {
+ type_attribute = attr("T", "type: DT_INT32");
+ }
+
if (op == kUnpack) {
- string attributes = attr("T", "type: DT_FLOAT") + attr("num", "i: 2") +
- attr("axis", "i: 0");
+ string attributes =
+ type_attribute + attr("num", "i: 2") + attr("axis", "i: 0");
AddTfOp("EagerUnpack", "Unpack", attributes, inputs, outputs);
} else if (op == kIdentity) {
- string attributes = attr("T", "type: DT_FLOAT");
+ string attributes = type_attribute;
AddTfOp("EagerIdentity", "Identity", attributes, inputs, outputs);
} else if (op == kAdd) {
- string attributes = attr("T", "type: DT_FLOAT");
+ string attributes = type_attribute;
AddTfOp("EagerAdd", "Add", attributes, inputs, outputs);
} else if (op == kMul) {
- string attributes = attr("T", "type: DT_FLOAT");
+ string attributes = type_attribute;
AddTfOp("EagerMul", "Mul", attributes, inputs, outputs);
} else if (op == kNonExistent) {
AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs);
diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.h b/tensorflow/contrib/lite/delegates/eager/test_util.h
index 0eab9e1135..816db41931 100644
--- a/tensorflow/contrib/lite/delegates/eager/test_util.h
+++ b/tensorflow/contrib/lite/delegates/eager/test_util.h
@@ -44,11 +44,30 @@ class EagerModelTest : public ::testing::Test {
bool Invoke();
+ // Sets the (typed) tensor's values at the given index.
+ template <typename T>
+ void SetTypedValues(int tensor_index, const std::vector<T>& values) {
+ memcpy(interpreter_->typed_tensor<T>(tensor_index), values.data(),
+ values.size() * sizeof(T));
+ }
+
+ // Returns the (typed) tensor's values at the given index.
+ template <typename T>
+ std::vector<T> GetTypedValues(int tensor_index) {
+ const TfLiteTensor* t = interpreter_->tensor(tensor_index);
+ const T* tdata = interpreter_->typed_tensor<T>(tensor_index);
+ return std::vector<T>(tdata, tdata + t->bytes / sizeof(T));
+ }
+
// Sets the tensor's values at the given index.
- void SetValues(int tensor_index, const std::vector<float>& values);
+ void SetValues(int tensor_index, const std::vector<float>& values) {
+ SetTypedValues<float>(tensor_index, values);
+ }
// Returns the tensor's values at the given index.
- std::vector<float> GetValues(int tensor_index);
+ std::vector<float> GetValues(int tensor_index) {
+ return GetTypedValues<float>(tensor_index);
+ }
// Sets the tensor's shape at the given index.
void SetShape(int tensor_index, const std::vector<int>& values);
@@ -56,13 +75,16 @@ class EagerModelTest : public ::testing::Test {
// Returns the tensor's shape at the given index.
std::vector<int> GetShape(int tensor_index);
+ // Returns the tensor's type at the given index.
+ TfLiteType GetType(int tensor_index);
+
const TestErrorReporter& error_reporter() const { return error_reporter_; }
// Adds `num_tensors` tensors to the model. `inputs` contains the indices of
// the input tensors and `outputs` contains the indices of the output
// tensors. All tensors are set to have `type` and `dims`.
void AddTensors(int num_tensors, const std::vector<int>& inputs,
- const std::vector<int>& outputs, const TfLiteType& type,
+ const std::vector<int>& outputs, TfLiteType type,
const std::vector<int>& dims);
// Adds a TFLite Mul op. `inputs` contains the indices of the input tensors
diff --git a/tensorflow/contrib/lite/delegates/eager/util.cc b/tensorflow/contrib/lite/delegates/eager/util.cc
index c8aa0b7f69..051246bf86 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.cc
+++ b/tensorflow/contrib/lite/delegates/eager/util.cc
@@ -13,16 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/delegates/eager/util.h"
-#include "tensorflow/contrib/lite/delegates/eager/constants.h"
namespace tflite {
namespace eager {
-bool IsEagerOp(const char* custom_name) {
- return custom_name && strncmp(custom_name, kCustomCodePrefix,
- strlen(kCustomCodePrefix)) == 0;
-}
-
TfLiteStatus ConvertStatus(TfLiteContext* context,
const tensorflow::Status& status) {
if (!status.ok()) {
@@ -32,8 +26,17 @@ TfLiteStatus ConvertStatus(TfLiteContext* context,
return kTfLiteOk;
}
-TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src,
- TfLiteTensor* tensor) {
+TfLiteStatus CopyShapeAndType(TfLiteContext* context,
+ const tensorflow::Tensor& src,
+ TfLiteTensor* tensor) {
+ tensor->type = GetTensorFlowLiteType(static_cast<TF_DataType>(src.dtype()));
+ if (tensor->type == kTfLiteNoType) {
+ context->ReportError(context,
+ "TF Lite does not support TensorFlow data type: %s",
+ DataTypeString(src.dtype()).c_str());
+ return kTfLiteError;
+ }
+
int num_dims = src.dims();
TfLiteIntArray* shape = TfLiteIntArrayCreate(num_dims);
for (int j = 0; j < num_dims; ++j) {
@@ -74,5 +77,28 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) {
}
}
+TfLiteType GetTensorFlowLiteType(TF_DataType type) {
+ switch (type) {
+ case TF_FLOAT:
+ return kTfLiteFloat32;
+ case TF_INT16:
+ return kTfLiteInt16;
+ case TF_INT32:
+ return kTfLiteInt32;
+ case TF_UINT8:
+ return kTfLiteUInt8;
+ case TF_INT64:
+ return kTfLiteInt64;
+ case TF_COMPLEX64:
+ return kTfLiteComplex64;
+ case TF_STRING:
+ return kTfLiteString;
+ case TF_BOOL:
+ return kTfLiteBool;
+ default:
+ return kTfLiteNoType;
+ }
+}
+
} // namespace eager
} // namespace tflite
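
One asymmetry in the two conversion helpers above is worth noting: GetTensorFlowDataType() maps kTfLiteNoType to TF_FLOAT, while GetTensorFlowLiteType() maps unsupported TF types (TF_RESOURCE, TF_VARIANT, TF_HALF, ...) to kTfLiteNoType, so the pair is not an exact inverse. A minimal sketch of a defensive round-trip check (hypothetical helper, not part of the patch):

    #include "tensorflow/contrib/lite/delegates/eager/util.h"

    namespace tflite {
    namespace eager {

    // Returns true if `type` survives a TFLite -> TF -> TFLite round trip.
    // kTfLiteNoType does not: GetTensorFlowDataType() defaults it to TF_FLOAT,
    // which maps back to kTfLiteFloat32.
    bool IsRoundTrippable(TfLiteType type) {
      return GetTensorFlowLiteType(GetTensorFlowDataType(type)) == type;
    }

    }  // namespace eager
    }  // namespace tflite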
diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/eager/util.h
index b7363361be..930cb99cb9 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.h
+++ b/tensorflow/contrib/lite/delegates/eager/util.h
@@ -16,30 +16,31 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_
#include "tensorflow/c/c_api_internal.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
namespace tflite {
namespace eager {
-// Checks whether the prefix of the custom name indicates the operation is an
-// Eager operation.
-bool IsEagerOp(const char* custom_name);
-
// Converts a tensorflow::Status into a TfLiteStatus. If the original status
// represented an error, reports it using the given 'context'.
TfLiteStatus ConvertStatus(TfLiteContext* context,
const tensorflow::Status& status);
-// Copies the given shape of the given 'src' into a TF Lite 'tensor'. Logs an
-// error and returns kTfLiteError if the shape can't be converted.
-TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src,
- TfLiteTensor* tensor);
+// Copies the given shape and type of the TensorFlow 'src' tensor into a TF Lite
+// 'tensor'. Logs an error and returns kTfLiteError if the shape or type can't
+// be converted.
+TfLiteStatus CopyShapeAndType(TfLiteContext* context,
+ const tensorflow::Tensor& src,
+ TfLiteTensor* tensor);
// Returns the TF C API Data type that corresponds to the given TfLiteType.
TF_DataType GetTensorFlowDataType(TfLiteType type);
+// Returns the TfLiteType that corresponds to the given TF C API Data type.
+TfLiteType GetTensorFlowLiteType(TF_DataType);
+
} // namespace eager
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/util_test.cc b/tensorflow/contrib/lite/delegates/eager/util_test.cc
index 541d0b1701..aebc91149c 100644
--- a/tensorflow/contrib/lite/delegates/eager/util_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/util_test.cc
@@ -26,6 +26,7 @@ namespace eager {
namespace {
using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT32;
using tensorflow::Tensor;
using ::testing::ElementsAre;
@@ -71,27 +72,41 @@ TEST(UtilTest, ConvertStatus) {
EXPECT_TRUE(context.error.empty());
}
-TEST(UtilTest, CopyShape) {
+TEST(UtilTest, CopyShapeAndType) {
TestContext context;
context.ReportError = ReportError;
context.ResizeTensor = ResizeTensor;
TfLiteTensor dst;
- EXPECT_EQ(CopyShape(&context, Tensor(), &dst), kTfLiteOk);
+ EXPECT_EQ(CopyShapeAndType(&context, Tensor(), &dst), kTfLiteOk);
EXPECT_THAT(context.new_size, ElementsAre(0));
+ EXPECT_EQ(dst.type, kTfLiteFloat32);
- EXPECT_EQ(CopyShape(&context, Tensor(DT_FLOAT, {1, 2}), &dst), kTfLiteOk);
+ EXPECT_EQ(CopyShapeAndType(&context, Tensor(DT_FLOAT, {1, 2}), &dst),
+ kTfLiteOk);
EXPECT_THAT(context.new_size, ElementsAre(1, 2));
+ EXPECT_EQ(dst.type, kTfLiteFloat32);
- EXPECT_EQ(CopyShape(&context, Tensor(DT_FLOAT, {1LL << 44, 2}), &dst),
+ EXPECT_EQ(CopyShapeAndType(&context, Tensor(DT_INT32, {1, 2}), &dst),
+ kTfLiteOk);
+ EXPECT_THAT(context.new_size, ElementsAre(1, 2));
+ EXPECT_EQ(dst.type, kTfLiteInt32);
+
+ EXPECT_EQ(CopyShapeAndType(&context, Tensor(DT_FLOAT, {1LL << 44, 2}), &dst),
kTfLiteError);
EXPECT_EQ(context.error,
"Dimension value in TensorFlow shape is larger than supported by "
"TF Lite");
+
+ EXPECT_EQ(
+ CopyShapeAndType(&context, Tensor(tensorflow::DT_HALF, {1, 2}), &dst),
+ kTfLiteError);
+ EXPECT_EQ(context.error,
+ "TF Lite does not support TensorFlow data type: half");
}
-TEST(UtilTest, TypeConversions) {
+TEST(UtilTest, TypeConversionsFromTFLite) {
EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteNoType));
EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteFloat32));
EXPECT_EQ(TF_INT16, GetTensorFlowDataType(kTfLiteInt16));
@@ -103,14 +118,17 @@ TEST(UtilTest, TypeConversions) {
EXPECT_EQ(TF_BOOL, GetTensorFlowDataType(kTfLiteBool));
}
-TEST(UtilTest, IsEagerOp) {
- EXPECT_TRUE(IsEagerOp("Eager"));
- EXPECT_TRUE(IsEagerOp("EagerOp"));
- EXPECT_FALSE(IsEagerOp("eager"));
- EXPECT_FALSE(IsEagerOp("Eage"));
- EXPECT_FALSE(IsEagerOp("OpEager"));
- EXPECT_FALSE(IsEagerOp(nullptr));
- EXPECT_FALSE(IsEagerOp(""));
+TEST(UtilTest, TypeConversionsFromTensorFlow) {
+ EXPECT_EQ(kTfLiteFloat32, GetTensorFlowLiteType(TF_FLOAT));
+ EXPECT_EQ(kTfLiteInt16, GetTensorFlowLiteType(TF_INT16));
+ EXPECT_EQ(kTfLiteInt32, GetTensorFlowLiteType(TF_INT32));
+ EXPECT_EQ(kTfLiteUInt8, GetTensorFlowLiteType(TF_UINT8));
+ EXPECT_EQ(kTfLiteInt64, GetTensorFlowLiteType(TF_INT64));
+ EXPECT_EQ(kTfLiteComplex64, GetTensorFlowLiteType(TF_COMPLEX64));
+ EXPECT_EQ(kTfLiteString, GetTensorFlowLiteType(TF_STRING));
+ EXPECT_EQ(kTfLiteBool, GetTensorFlowLiteType(TF_BOOL));
+ EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_RESOURCE));
+ EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_VARIANT));
}
} // namespace
diff --git a/tensorflow/contrib/lite/delegates/nnapi/BUILD b/tensorflow/contrib/lite/delegates/nnapi/BUILD
index 954955f24b..4e7b2948fb 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/BUILD
+++ b/tensorflow/contrib/lite/delegates/nnapi/BUILD
@@ -13,6 +13,7 @@ cc_library(
deps = [
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:kernel_api",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:kernel_util",
"//tensorflow/contrib/lite/nnapi:nnapi_lib",
],
@@ -29,6 +30,7 @@ tf_cc_test(
deps = [
":nnapi_delegate",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
index e6cc3dd99c..e3eebac4da 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/allocation.h"
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/builtin_ops.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/context_util.h"
#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -238,7 +238,7 @@ class NNAPIOpBuilder {
tensor->params.zero_point};
CHECK_NN(context_,
ANeuralNetworksModel_addOperand(nn_model_, &operand_type));
- augmented_inputs_.push_back(ann_index);
+ augmented_outputs_.push_back(ann_index);
*ann_tensor_index_out = ann_index;
return kTfLiteOk;
@@ -370,8 +370,8 @@ struct NNAPIOpMappingArgs {
TfLiteContext* context;
NNAPIOpBuilder* builder;
TfLiteNode* node;
- std::vector<int>* model_state_inputs;
- std::vector<int>* model_state_tfl_outputs;
+ std::vector<int>* model_state_outputs;
+ std::vector<int>* model_state_tfl_inputs;
};
// The kernel that represents the subgraph of TF Lite being run on NN API.
@@ -781,8 +781,7 @@ class NNAPIDelegateKernel {
break;
case kTfLiteBuiltinRnn:
// NNAPI only supports float32 weights.
- // TODO(miaowang): check the number of inputs before accessing it.
- if (version == 1 &&
+ if (version == 1 && node->inputs->size == 5 &&
context->tensors[node->inputs->data[/*kWeightsTensor*/ 1]].type ==
kTfLiteFloat32) {
return [](const NNAPIOpMappingArgs& mapping_args)
@@ -790,11 +789,11 @@ class NNAPIDelegateKernel {
// NNAPI needs both state_in and state_out.
int ann_index;
mapping_args.builder->AddStateFloat32Tensor(
- mapping_args.node->outputs->data[/*kHiddenStateTensor*/ 0],
+ mapping_args.node->inputs->data[/*kHiddenStateTensor*/ 4],
&ann_index);
- mapping_args.model_state_inputs->push_back(ann_index);
- mapping_args.model_state_tfl_outputs->push_back(
- mapping_args.node->outputs->data[/*kHiddenStateTensor*/ 0]);
+ mapping_args.model_state_outputs->push_back(ann_index);
+ mapping_args.model_state_tfl_inputs->push_back(
+ mapping_args.node->inputs->data[/*kHiddenStateTensor*/ 4]);
auto builtin = reinterpret_cast<TfLiteRNNParams*>(
mapping_args.node->builtin_data);
mapping_args.builder->AddScalarInt32Operand(builtin->activation);
@@ -806,7 +805,7 @@ class NNAPIDelegateKernel {
break;
case kTfLiteBuiltinSvdf:
// NNAPI only supports float32 weights.
- if (version == 1 &&
+ if (version == 1 && node->inputs->size == 5 &&
context->tensors[node->inputs->data[/*kWeightsFeatureTensor*/ 1]]
.type == kTfLiteFloat32) {
return [](const NNAPIOpMappingArgs& mapping_args)
@@ -814,11 +813,13 @@ class NNAPIDelegateKernel {
// NNAPI needs both state_in and state_out.
int ann_index;
mapping_args.builder->AddStateFloat32Tensor(
- mapping_args.node->outputs->data[/*kStateTensor*/ 0],
+ mapping_args.node->inputs
+ ->data[/*kInputActivationStateTensor*/ 4],
&ann_index);
- mapping_args.model_state_inputs->push_back(ann_index);
- mapping_args.model_state_tfl_outputs->push_back(
- mapping_args.node->outputs->data[/*kStateTensor*/ 0]);
+ mapping_args.model_state_outputs->push_back(ann_index);
+ mapping_args.model_state_tfl_inputs->push_back(
+ mapping_args.node->inputs
+ ->data[/*kInputActivationStateTensor*/ 4]);
auto builtin = reinterpret_cast<TfLiteSVDFParams*>(
mapping_args.node->builtin_data);
@@ -833,28 +834,12 @@ class NNAPIDelegateKernel {
case kTfLiteBuiltinLstm:
// NNAPI only supports float32 weights.
// TODO(miaowang): add logging to indicate why the op is rejected.
- if (version == 1 && node->inputs->size == 18 &&
+ if (version == 1 && node->inputs->size == 20 &&
context->tensors[node->inputs
->data[/*kInputToOutputWeightsTensor*/ 4]]
.type == kTfLiteFloat32) {
return [](const NNAPIOpMappingArgs& mapping_args)
-> ANeuralNetworksOperationType {
- // NNAPI need both state_in and state_out for cell_state and
- // output_state.
- int ann_index;
- mapping_args.builder->AddStateFloat32Tensor(
- mapping_args.node->outputs->data[/*kOutputStateTensor*/ 0],
- &ann_index);
- mapping_args.model_state_inputs->push_back(ann_index);
- mapping_args.model_state_tfl_outputs->push_back(
- mapping_args.node->outputs->data[/*kOutputStateTensor*/ 0]);
- mapping_args.builder->AddStateFloat32Tensor(
- mapping_args.node->outputs->data[/*kCellStateTensor*/ 1],
- &ann_index);
- mapping_args.model_state_inputs->push_back(ann_index);
- mapping_args.model_state_tfl_outputs->push_back(
- mapping_args.node->outputs->data[/*kCellStateTensor*/ 1]);
-
auto builtin = reinterpret_cast<TfLiteLSTMParams*>(
mapping_args.node->builtin_data);
mapping_args.builder->AddScalarInt32Operand(builtin->activation);
@@ -864,6 +849,25 @@ class NNAPIDelegateKernel {
// Current NNAPI implementation requires the scratch_buffer as
// output.
mapping_args.builder->AddAdditionalFloat32OutputTensor(2);
+
+ // NNAPI needs both state_in and state_out for cell_state and
+ // output_state.
+ int ann_index;
+ mapping_args.builder->AddStateFloat32Tensor(
+ mapping_args.node->inputs
+ ->data[/*kInputActivationStateTensor*/ 18],
+ &ann_index);
+ mapping_args.model_state_outputs->push_back(ann_index);
+ mapping_args.model_state_tfl_inputs->push_back(
+ mapping_args.node->inputs
+ ->data[/*kInputActivationStateTensor*/ 18]);
+ mapping_args.builder->AddStateFloat32Tensor(
+ mapping_args.node->inputs->data[/*kInputCellStateTensor*/ 19],
+ &ann_index);
+ mapping_args.model_state_outputs->push_back(ann_index);
+ mapping_args.model_state_tfl_inputs->push_back(
+ mapping_args.node->inputs->data[/*kInputCellStateTensor*/ 19]);
+
return ANEURALNETWORKS_LSTM;
};
} else {
@@ -950,12 +954,10 @@ class NNAPIDelegateKernel {
// Set the input tensor buffers. Note: we access TFLite tensors using
// absolute indices, but the NN API indexes inputs by relative indices.
int relative_input_index = 0;
- int num_optional_tensors = 0;
size_t input_offset = 0;
for (auto absolute_input_index : TfLiteIntArrayView(node->inputs)) {
if (absolute_input_index == kOptionalTensor) {
- num_optional_tensors++;
continue;
}
TfLiteTensor* tensor = &context->tensors[absolute_input_index];
@@ -989,16 +991,16 @@ class NNAPIDelegateKernel {
// The state_out of the previous invocation needs to be mapped to the
// state_in of the current invocation.
- for (size_t i = 0; i < model_state_tfl_outputs_.size(); i++) {
- int state_tensor_idx = model_state_tfl_outputs_[i];
+ for (size_t i = 0; i < model_state_tfl_inputs_.size(); i++) {
+ int state_tensor_idx = model_state_tfl_inputs_[i];
TfLiteTensor* tensor = &context->tensors[state_tensor_idx];
// Here we are using a deep copy for state_in tensors so that we are not
// reading and writing into the same buffer during an invocation.
// TODO(110369471): using double shared buffer to minimize the copies.
- CHECK_NN(context,
- ANeuralNetworksExecution_setInput(
- execution, i + node->inputs->size - num_optional_tensors,
- nullptr, tensor->data.raw, tensor->bytes));
+ CHECK_NN(context, ANeuralNetworksExecution_setOutput(
+ execution, relative_output_index, nullptr,
+ tensor->data.raw, tensor->bytes));
+ relative_output_index++;
}
// Invoke ANN in blocking fashion.
ANeuralNetworksEvent* event = nullptr;
@@ -1030,8 +1032,8 @@ class NNAPIDelegateKernel {
// Track indices we use
OperandMapping operand_mapping_;
- std::vector<int> model_state_inputs_;
- std::vector<int> model_state_tfl_outputs_;
+ std::vector<int> model_state_outputs_;
+ std::vector<int> model_state_tfl_inputs_;
std::unique_ptr<NNMemory> nn_input_memory_;
std::unique_ptr<NNMemory> nn_output_memory_;
@@ -1063,9 +1065,9 @@ class NNAPIDelegateKernel {
}
}
// Get op type and operands
- int nn_op_type = Map(context, reg->builtin_code, reg->version,
- node)({context, &builder, node, &model_state_inputs_,
- &model_state_tfl_outputs_});
+ int nn_op_type = Map(context, reg->builtin_code, reg->version, node)(
+ {context, &builder, node, &model_state_outputs_,
+ &model_state_tfl_inputs_});
// Map outputs to NN API tensor indices.
for (auto output_index : TfLiteIntArrayView(node->outputs)) {
TF_LITE_ENSURE_STATUS(builder.AddTensorOutput(output_index));
@@ -1098,17 +1100,17 @@ class NNAPIDelegateKernel {
}
}
- // Add state input tensors as model inputs
- for (int i : model_state_inputs_) {
- inputs.push_back(i);
- }
-
size_t total_output_byte_size = 0;
for (int i : TfLiteIntArrayView(output_tensors)) {
outputs.push_back(operand_mapping_.lite_index_to_ann(i));
total_output_byte_size += context->tensors[i].bytes;
}
+ // Add state output tensors as model outputs
+ for (int i : model_state_outputs_) {
+ outputs.push_back(i);
+ }
+
// Tell ANN to declare inputs/outputs
CHECK_NN(context, ANeuralNetworksModel_identifyInputsAndOutputs(
nn_model_.get(), inputs.size(), inputs.data(),
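
Net effect of the nnapi_delegate.cc changes above: recurrent state moves from the TFLite node's outputs to variable input tensors, while on the NNAPI side the corresponding state_out operands are appended to the model's outputs and, at execution time, written straight back into the TFLite state tensors' buffers. A condensed sketch of that execution-time loop-back (names taken from the patch; not runnable without a built NNAPI model):

    // For each recurrent state tensor, bind NNAPI's state_out directly onto
    // the TFLite state-input buffer so the next Invoke() reads updated state.
    for (size_t i = 0; i < model_state_tfl_inputs_.size(); ++i) {
      TfLiteTensor* tensor = &context->tensors[model_state_tfl_inputs_[i]];
      CHECK_NN(context, ANeuralNetworksExecution_setOutput(
                            execution, relative_output_index, nullptr,
                            tensor->data.raw, tensor->bytes));
      ++relative_output_index;
    }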
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h
index 44cca2fd28..4852b76974 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_
#define TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
index 3224b23a0c..4b01aefd6a 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
@@ -1773,15 +1773,16 @@ class RNNOpModel : public SingleOpModelWithNNAPI {
weights_ = AddInput(weights);
recurrent_weights_ = AddInput(recurrent_weights);
bias_ = AddInput(TensorType_FLOAT32);
- hidden_state_ = AddOutput(TensorType_FLOAT32);
+ hidden_state_ = AddInput(TensorType_FLOAT32, true);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(
BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
- BuildInterpreter({{batches_, input_size_},
- {units_, input_size_},
- {units_, units_},
- {units_}});
+ BuildInterpreter({{batches_, input_size_}, // input tensor
+ {units_, input_size_}, // weights tensor
+ {units_, units_}, // recurrent weights tensor
+ {units_}, // bias tensor
+ {batches_, units_}}); // hidden state tensor
}
void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
@@ -1802,14 +1803,6 @@ class RNNOpModel : public SingleOpModelWithNNAPI {
PopulateTensor(input_, offset, begin, end);
}
- void ResetHiddenState() {
- const int zero_buffer_size = units_ * batches_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(hidden_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
int input_size() { return input_size_; }
@@ -1835,7 +1828,6 @@ TEST(NNAPIDelegate, RnnBlackBoxTest) {
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
(rnn.input_size() * rnn.num_batches());
@@ -1968,16 +1960,20 @@ class BaseSVDFOpModel : public SingleOpModelWithNNAPI {
weights_feature_ = AddInput(weights_feature_type);
weights_time_ = AddInput(weights_time_type);
bias_ = AddNullInput();
- state_ = AddOutput(TensorType_FLOAT32);
+ const int num_filters = units * rank;
+ activation_state_ = AddInput(
+ TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}},
+ /*is_variable=*/true);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(
BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union());
BuildInterpreter({
- {batches_, input_size_}, // Input tensor
- {units_ * rank, input_size_}, // weights_feature tensor
- {units_ * rank, memory_size_}, // weights_time tensor
- {units_} // bias tensor
+ {batches_, input_size_}, // input tensor
+ {units_ * rank, input_size_}, // weights_feature tensor
+ {units_ * rank, memory_size_}, // weights_time tensor
+ {units_}, // bias tensor
+ {batches, memory_size * num_filters} // activation_state tensor
});
}
@@ -1996,15 +1992,6 @@ class BaseSVDFOpModel : public SingleOpModelWithNNAPI {
PopulateTensor(input_, offset, begin, end);
}
- // Resets the state of SVDF op by filling it with 0's.
- void ResetState() {
- const int zero_buffer_size = rank_ * units_ * batches_ * memory_size_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
// Extracts the output tensor from the SVDF op.
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
@@ -2017,7 +2004,7 @@ class BaseSVDFOpModel : public SingleOpModelWithNNAPI {
int weights_feature_;
int weights_time_;
int bias_;
- int state_;
+ int activation_state_;
int output_;
int batches_;
@@ -2081,7 +2068,6 @@ TEST(NNAPIDelegate, SVDFBlackBoxTestRank1) {
-0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
-0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657});
- svdf.ResetState();
svdf.VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input));
}
@@ -2120,7 +2106,6 @@ TEST(NNAPIDelegate, SVDFBlackBoxTestRank2) {
0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326,
0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763});
- svdf.ResetState();
svdf.VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input));
}
@@ -2192,8 +2177,12 @@ class LSTMOpModel : public SingleOpModelWithNNAPI {
projection_bias_ = AddNullInput();
}
- output_state_ = AddOutput(TensorType_FLOAT32);
- cell_state_ = AddOutput(TensorType_FLOAT32);
+ // Adding the 2 input state tensors.
+ input_activation_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_output_}}, true);
+ input_cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_cell_}}, true);
+
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
@@ -2271,22 +2260,6 @@ class LSTMOpModel : public SingleOpModelWithNNAPI {
PopulateTensor(projection_bias_, f);
}
- void ResetOutputState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
- void ResetCellState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
void SetInput(int offset, const float* begin, const float* end) {
PopulateTensor(input_, offset, const_cast<float*>(begin),
const_cast<float*>(end));
@@ -2495,10 +2468,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -2602,10 +2571,6 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
lstm.SetCellToOutputWeights(cell_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -3266,10 +3231,6 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
lstm.SetProjectionWeights(projection_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
diff --git a/tensorflow/contrib/lite/error_reporter.h b/tensorflow/contrib/lite/error_reporter.h
index 3c5f805f12..5c20eedc25 100644
--- a/tensorflow/contrib/lite/error_reporter.h
+++ b/tensorflow/contrib/lite/error_reporter.h
@@ -12,43 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// Compatibility shim for moved header location.
#ifndef TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
#define TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
-#include <cstdarg>
-#include "tensorflow/contrib/lite/context.h"
-
-namespace tflite {
-
-// A functor that reports error to supporting system. Invoked similar to
-// printf.
-//
-// Usage:
-// ErrorReporter foo;
-// foo.Report("test %d", 5);
-// or
-// va_list args;
-// foo.Report("test %d", args); // where args is va_list
-//
-// Subclass ErrorReporter to provide another reporting destination.
-// For example, if you have a GUI program, you might redirect to a buffer
-// that drives a GUI error log box.
-class ErrorReporter {
- public:
- virtual ~ErrorReporter();
- virtual int Report(const char* format, va_list args) = 0;
- int Report(const char* format, ...);
- int ReportError(void*, const char* format, ...);
-};
-
-// An error reporter that simplify writes the message to stderr.
-struct StderrReporter : public ErrorReporter {
- int Report(const char* format, va_list args) override;
-};
-
-// Return the default error reporter (output to stderr).
-ErrorReporter* DefaultErrorReporter();
-
-} // namespace tflite
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/stderr_reporter.h"
#endif // TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
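
Includers of the old header keep working through this shim, and the ErrorReporter interface itself is unchanged by the move. A minimal sketch of a custom reporter of the kind the removed comment described (buffering messages instead of printing them; illustrative only, assuming the moved header keeps the same Report() signature):

    #include <cstdarg>
    #include <cstdio>
    #include <string>

    #include "tensorflow/contrib/lite/error_reporter.h"

    // Collects formatted error messages into a string, e.g. to drive a GUI
    // error log instead of stderr.
    class BufferedErrorReporter : public tflite::ErrorReporter {
     public:
      int Report(const char* format, va_list args) override {
        char buf[1024];
        const int written = vsnprintf(buf, sizeof(buf), format, args);
        if (written > 0) log_.append(buf);  // buf is null-terminated.
        return written;
      }
      const std::string& log() const { return log_; }

     private:
      std::string log_;
    };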
diff --git a/tensorflow/contrib/lite/examples/android/app/build.gradle b/tensorflow/contrib/lite/examples/android/app/build.gradle
index eb7fd705e1..35e7887852 100644
--- a/tensorflow/contrib/lite/examples/android/app/build.gradle
+++ b/tensorflow/contrib/lite/examples/android/app/build.gradle
@@ -9,7 +9,6 @@ android {
targetSdkVersion 26
versionCode 1
versionName "1.0"
- testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
// Remove this block.
jackOptions {
@@ -51,10 +50,5 @@ apply from: "download-models.gradle"
dependencies {
compile fileTree(dir: 'libs', include: ['*.jar'])
- androidTestCompile('androidx.test.espresso:espresso-core:3.1.0-alpha3', {
- exclude group: 'com.android.support', module: 'support-annotations'
- })
compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
-
- testCompile 'junit:junit:4.12'
}
diff --git a/tensorflow/contrib/lite/examples/android/build.gradle b/tensorflow/contrib/lite/examples/android/build.gradle
index a47fa4bbf6..66a62a921a 100644
--- a/tensorflow/contrib/lite/examples/android/build.gradle
+++ b/tensorflow/contrib/lite/examples/android/build.gradle
@@ -14,6 +14,7 @@ buildscript {
allprojects {
repositories {
+ google()
jcenter()
}
}
diff --git a/tensorflow/contrib/lite/examples/ios/camera/Podfile b/tensorflow/contrib/lite/examples/ios/camera/Podfile
index 8084307ac7..f460693122 100644
--- a/tensorflow/contrib/lite/examples/ios/camera/Podfile
+++ b/tensorflow/contrib/lite/examples/ios/camera/Podfile
@@ -2,4 +2,4 @@ platform :ios, '8.0'
inhibit_all_warnings!
target 'tflite_camera_example'
- pod 'TensorFlowLite', '1.10.0'
+ pod 'TensorFlowLite', '1.10.1'
diff --git a/tensorflow/contrib/lite/examples/ios/simple/Podfile b/tensorflow/contrib/lite/examples/ios/simple/Podfile
index eea7ecb759..ddb77088d9 100644
--- a/tensorflow/contrib/lite/examples/ios/simple/Podfile
+++ b/tensorflow/contrib/lite/examples/ios/simple/Podfile
@@ -2,4 +2,4 @@ platform :ios, '8.0'
inhibit_all_warnings!
target 'tflite_simple_example'
- pod 'TensorFlowLite', '1.10.0'
+ pod 'TensorFlowLite', '1.10.1'
diff --git a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h
index 98934ce41d..96d2810937 100644
--- a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h
+++ b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h
@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_
-#define TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_IOS_SIMPLE_IOS_IMAGE_LOAD_H_
+#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_IOS_SIMPLE_IOS_IMAGE_LOAD_H_
#include <vector>
std::vector<uint8_t> LoadImageFromFile(const char* file_name, int* out_width,
int* out_height, int* out_channels);
-#endif // TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_
+#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_IOS_SIMPLE_IOS_IMAGE_LOAD_H_
diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h
index 5fc75b1f72..7881ee80ca 100644
--- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h
+++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h
@@ -39,4 +39,4 @@ template void resize<float>(float*, unsigned char*, int, int, int, int, int,
} // namespace label_image
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H
+#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H_
diff --git a/tensorflow/contrib/lite/examples/label_image/get_top_n.h b/tensorflow/contrib/lite/examples/label_image/get_top_n.h
index 70a7586fe6..adef434c00 100644
--- a/tensorflow/contrib/lite/examples/label_image/get_top_n.h
+++ b/tensorflow/contrib/lite/examples/label_image/get_top_n.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_H
-#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_H
+#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_H_
+#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_H_
#include "tensorflow/contrib/lite/examples/label_image/get_top_n_impl.h"
@@ -35,4 +35,4 @@ template void get_top_n<float>(float*, int, size_t, float,
} // namespace label_image
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_H
+#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_H_
diff --git a/tensorflow/contrib/lite/examples/label_image/get_top_n_impl.h b/tensorflow/contrib/lite/examples/label_image/get_top_n_impl.h
index e416fbd39b..708cf2f2b1 100644
--- a/tensorflow/contrib/lite/examples/label_image/get_top_n_impl.h
+++ b/tensorflow/contrib/lite/examples/label_image/get_top_n_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_IMPL_H
-#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_IMPL_H
+#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_IMPL_H_
+#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_IMPL_H_
#include <algorithm>
#include <queue>
@@ -67,4 +67,4 @@ void get_top_n(T* prediction, int prediction_size, size_t num_results,
} // namespace label_image
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_IMPL_H
+#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_IMPL_H_
diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.h b/tensorflow/contrib/lite/examples/label_image/label_image.h
index 34c223f713..f0be881b58 100644
--- a/tensorflow/contrib/lite/examples/label_image/label_image.h
+++ b/tensorflow/contrib/lite/examples/label_image/label_image.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H
-#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H
+#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H_
+#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H_
#include "tensorflow/contrib/lite/string.h"
@@ -40,4 +40,4 @@ struct Settings {
} // namespace label_image
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H
+#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H_
diff --git a/tensorflow/contrib/lite/experimental/c/BUILD b/tensorflow/contrib/lite/experimental/c/BUILD
index 8fc07e8eb7..ea4a543252 100644
--- a/tensorflow/contrib/lite/experimental/c/BUILD
+++ b/tensorflow/contrib/lite/experimental/c/BUILD
@@ -78,6 +78,7 @@ cc_test(
data = ["//tensorflow/contrib/lite:testdata/add.bin"],
deps = [
":c_api",
+ "//tensorflow/contrib/lite:context",
"//tensorflow/contrib/lite:kernel_api",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
diff --git a/tensorflow/contrib/lite/experimental/c/c_api.cc b/tensorflow/contrib/lite/experimental/c/c_api.cc
index a4ab0e8c30..c589cf71ea 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/experimental/c/c_api.h"
+#include <memory>
+
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/experimental/c/c_api_internal.h"
#include "tensorflow/contrib/lite/interpreter.h"
@@ -29,12 +31,14 @@ extern "C" {
TFL_Model* TFL_NewModel(const void* model_data, size_t model_size) {
auto model = tflite::FlatBufferModel::BuildFromBuffer(
static_cast<const char*>(model_data), model_size);
- return model ? new TFL_Model{std::move(model)} : nullptr;
+ std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release());
+ return shared_model ? new TFL_Model{std::move(shared_model)} : nullptr;
}
TFL_Model* TFL_NewModelFromFile(const char* model_path) {
auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
- return model ? new TFL_Model{std::move(model)} : nullptr;
+ std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release());
+ return shared_model ? new TFL_Model{std::move(shared_model)} : nullptr;
}
void TFL_DeleteModel(TFL_Model* model) { delete model; }
@@ -72,7 +76,7 @@ TFL_Interpreter* TFL_NewInterpreter(
}
}
- return new TFL_Interpreter{std::move(interpreter)};
+ return new TFL_Interpreter{model->impl, std::move(interpreter)};
}
void TFL_DeleteInterpreter(TFL_Interpreter* interpreter) { delete interpreter; }
@@ -129,6 +133,8 @@ void* TFL_TensorData(const TFL_Tensor* tensor) {
return static_cast<void*>(tensor->data.raw);
}
+const char* TFL_TensorName(const TFL_Tensor* tensor) { return tensor->name; }
+
TFL_Status TFL_TensorCopyFromBuffer(TFL_Tensor* tensor, const void* input_data,
size_t input_data_size) {
if (tensor->bytes != input_data_size) {
diff --git a/tensorflow/contrib/lite/experimental/c/c_api.h b/tensorflow/contrib/lite/experimental/c/c_api.h
index 3757349b55..b429e76870 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api.h
@@ -93,7 +93,8 @@ typedef struct TFL_Interpreter TFL_Interpreter;
// failure.
//
// * `model` must be a valid model instance. The caller retains ownership of the
-// object, and can destroy it immediately after creating the interpreter.
+// object, and can destroy it immediately after creating the interpreter; the
+// interpreter will maintain its own reference to the underlying model data.
// * `optional_options` may be null. The caller retains ownership of the object,
// and can safely destroy it immediately after creating the interpreter.
//
@@ -145,6 +146,11 @@ TFL_CAPI_EXPORT extern int32_t TFL_InterpreterGetOutputTensorCount(
// Returns the tensor associated with the output index.
// REQUIRES: 0 <= output_index < TFL_InterpreterGetOutputTensorCount(interpreter)
+//
+// NOTE: The shape and underlying data buffer for output tensors may not be
+// available until after the output tensor has been both sized and allocated.
+// In general, best practice is to interact with the output tensor *after*
+// calling TFL_InterpreterInvoke().
TFL_CAPI_EXPORT extern const TFL_Tensor* TFL_InterpreterGetOutputTensor(
const TFL_Interpreter* interpreter, int32_t output_index);
@@ -172,12 +178,15 @@ TFL_CAPI_EXPORT extern size_t TFL_TensorByteSize(const TFL_Tensor* tensor);
// Returns a pointer to the underlying data buffer.
//
-// Note: The result may be null if tensors have not yet been allocated, e.g.,
+// NOTE: The result may be null if tensors have not yet been allocated, e.g.,
// if the Tensor has just been created or resized and `TFL_AllocateTensors()`
// has yet to be called, or if the output tensor is dynamically sized and the
// interpreter hasn't been invoked.
TFL_CAPI_EXPORT extern void* TFL_TensorData(const TFL_Tensor* tensor);
+// Returns the (null-terminated) name of the tensor.
+TFL_CAPI_EXPORT extern const char* TFL_TensorName(const TFL_Tensor* tensor);
+
// Copies from the provided input buffer into the tensor's buffer.
// REQUIRES: input_data_size == TFL_TensorByteSize(tensor)
TFL_CAPI_EXPORT extern TFL_Status TFL_TensorCopyFromBuffer(
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_internal.h b/tensorflow/contrib/lite/experimental/c/c_api_internal.h
index c5c612a4c6..60c2e4e2cd 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_internal.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api_internal.h
@@ -24,7 +24,8 @@ limitations under the License.
// not be depended on.
struct TFL_Model {
- std::unique_ptr<tflite::FlatBufferModel> impl;
+ // Sharing is safe as FlatBufferModel is const.
+ std::shared_ptr<const tflite::FlatBufferModel> impl;
};
struct TFL_InterpreterOptions {
@@ -35,6 +36,9 @@ struct TFL_InterpreterOptions {
};
struct TFL_Interpreter {
+ // Holding a shared reference to the (const) model data keeps it alive for
+ // the interpreter's lifetime, independent of the TFL_Model's existence.
+ std::shared_ptr<const tflite::FlatBufferModel> model;
std::unique_ptr<tflite::Interpreter> impl;
};
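
The practical consequence of the shared_ptr plumbing above: callers may destroy the TFL_Model as soon as the interpreter exists, since the interpreter now co-owns the underlying FlatBufferModel. A usage sketch (assumes the TFL_* functions declared in c_api.h above; model_data/model_size stand in for a real flatbuffer):

    #include "tensorflow/contrib/lite/experimental/c/c_api.h"

    void RunOnce(const void* model_data, size_t model_size) {
      TFL_Model* model = TFL_NewModel(model_data, model_size);
      if (!model) return;
      TFL_Interpreter* interpreter =
          TFL_NewInterpreter(model, /*optional_options=*/nullptr);
      // Safe after this change: the interpreter holds its own shared
      // reference to the model data.
      TFL_DeleteModel(model);
      if (!interpreter) return;
      // Error handling elided; both calls return a TFL_Status.
      TFL_InterpreterAllocateTensors(interpreter);
      TFL_InterpreterInvoke(interpreter);
      TFL_DeleteInterpreter(interpreter);
    }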
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_test.cc
index a631dae890..649dac8d1a 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_test.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api_test.cc
@@ -55,6 +55,8 @@ TEST(CApiSimple, Smoke) {
EXPECT_EQ(TFL_TensorNumDims(input_tensor), 1);
EXPECT_EQ(TFL_TensorDim(input_tensor, 0), 2);
EXPECT_EQ(TFL_TensorByteSize(input_tensor), sizeof(float) * 2);
+ EXPECT_NE(TFL_TensorData(input_tensor), nullptr);
+ EXPECT_STREQ(TFL_TensorName(input_tensor), "input");
std::array<float, 2> input = {1.f, 3.f};
ASSERT_EQ(TFL_TensorCopyFromBuffer(input_tensor, input.data(),
@@ -70,6 +72,8 @@ TEST(CApiSimple, Smoke) {
EXPECT_EQ(TFL_TensorNumDims(output_tensor), 1);
EXPECT_EQ(TFL_TensorDim(output_tensor, 0), 2);
EXPECT_EQ(TFL_TensorByteSize(output_tensor), sizeof(float) * 2);
+ EXPECT_NE(TFL_TensorData(output_tensor), nullptr);
+ EXPECT_STREQ(TFL_TensorName(output_tensor), "output");
std::array<float, 2> output;
ASSERT_EQ(TFL_TensorCopyToBuffer(output_tensor, output.data(),
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs
index b6905b5fbf..676783063d 100644
--- a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs
@@ -29,15 +29,16 @@ namespace TensorFlowLite
{
private const string TensorFlowLibrary = "tensorflowlite_c";
- private TFL_Interpreter handle;
+ private TFL_Model model;
+ private TFL_Interpreter interpreter;
public Interpreter(byte[] modelData) {
GCHandle modelDataHandle = GCHandle.Alloc(modelData, GCHandleType.Pinned);
IntPtr modelDataPtr = modelDataHandle.AddrOfPinnedObject();
- TFL_Model model = TFL_NewModel(modelDataPtr, modelData.Length);
- handle = TFL_NewInterpreter(model, /*options=*/IntPtr.Zero);
- TFL_DeleteModel(model);
- if (handle == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Interpreter");
+ model = TFL_NewModel(modelDataPtr, modelData.Length);
+ if (model == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Model");
+ interpreter = TFL_NewInterpreter(model, /*options=*/IntPtr.Zero);
+ if (interpreter == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Interpreter");
}
~Interpreter() {
@@ -45,43 +46,45 @@ namespace TensorFlowLite
}
public void Dispose() {
- if (handle != IntPtr.Zero) TFL_DeleteInterpreter(handle);
- handle = IntPtr.Zero;
+ if (interpreter != IntPtr.Zero) TFL_DeleteInterpreter(interpreter);
+ interpreter = IntPtr.Zero;
+ if (model != IntPtr.Zero) TFL_DeleteModel(model);
+ model = IntPtr.Zero;
}
public void Invoke() {
- ThrowIfError(TFL_InterpreterInvoke(handle));
+ ThrowIfError(TFL_InterpreterInvoke(interpreter));
}
public int GetInputTensorCount() {
- return TFL_InterpreterGetInputTensorCount(handle);
+ return TFL_InterpreterGetInputTensorCount(interpreter);
}
public void SetInputTensorData(int inputTensorIndex, Array inputTensorData) {
GCHandle tensorDataHandle = GCHandle.Alloc(inputTensorData, GCHandleType.Pinned);
IntPtr tensorDataPtr = tensorDataHandle.AddrOfPinnedObject();
- TFL_Tensor tensor = TFL_InterpreterGetInputTensor(handle, inputTensorIndex);
+ TFL_Tensor tensor = TFL_InterpreterGetInputTensor(interpreter, inputTensorIndex);
ThrowIfError(TFL_TensorCopyFromBuffer(
tensor, tensorDataPtr, Buffer.ByteLength(inputTensorData)));
}
public void ResizeInputTensor(int inputTensorIndex, int[] inputTensorShape) {
ThrowIfError(TFL_InterpreterResizeInputTensor(
- handle, inputTensorIndex, inputTensorShape, inputTensorShape.Length));
+ interpreter, inputTensorIndex, inputTensorShape, inputTensorShape.Length));
}
public void AllocateTensors() {
- ThrowIfError(TFL_InterpreterAllocateTensors(handle));
+ ThrowIfError(TFL_InterpreterAllocateTensors(interpreter));
}
public int GetOutputTensorCount() {
- return TFL_InterpreterGetOutputTensorCount(handle);
+ return TFL_InterpreterGetOutputTensorCount(interpreter);
}
public void GetOutputTensorData(int outputTensorIndex, Array outputTensorData) {
GCHandle tensorDataHandle = GCHandle.Alloc(outputTensorData, GCHandleType.Pinned);
IntPtr tensorDataPtr = tensorDataHandle.AddrOfPinnedObject();
- TFL_Tensor tensor = TFL_InterpreterGetOutputTensor(handle, outputTensorIndex);
+ TFL_Tensor tensor = TFL_InterpreterGetOutputTensor(interpreter, outputTensorIndex);
ThrowIfError(TFL_TensorCopyToBuffer(
tensor, tensorDataPtr, Buffer.ByteLength(outputTensorData)));
}
diff --git a/tensorflow/contrib/lite/experimental/kernels/BUILD b/tensorflow/contrib/lite/experimental/kernels/BUILD
index 9c06c4ebd9..4786cc62f9 100644
--- a/tensorflow/contrib/lite/experimental/kernels/BUILD
+++ b/tensorflow/contrib/lite/experimental/kernels/BUILD
@@ -53,6 +53,7 @@ cc_library(
"//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/kernels:gemm_support",
"//tensorflow/contrib/lite/kernels:kernel_util",
@@ -61,8 +62,8 @@ cc_library(
"//tensorflow/contrib/lite/kernels/internal:optimized",
"//tensorflow/contrib/lite/kernels/internal:optimized_base",
"//tensorflow/contrib/lite/kernels/internal:quantization_util",
- "//tensorflow/contrib/lite/kernels/internal:reference",
"//tensorflow/contrib/lite/kernels/internal:reference_base",
+ "//tensorflow/contrib/lite/kernels/internal:tensor",
"//tensorflow/contrib/lite/kernels/internal:tensor_utils",
"@flatbuffers",
],
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h
index c658e43092..7c5099235a 100644
--- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h
@@ -257,6 +257,16 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
} else {
max_coeff = raw_input.maxCoeff();
}
+
+ // Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))).
+ float logsumexp = 0.0;
+ for (int j = 0; j < raw_input.size(); ++j) {
+ logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff);
+ }
+ logsumexp = Eigen::numext::log(logsumexp);
+ // Final normalization offset to get correct log probabilities.
+ float norm_offset = max_coeff + logsumexp;
+
const float label_selection_input_min =
(label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_)
: -std::numeric_limits<float>::infinity();
@@ -288,10 +298,10 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
beam_scorer_->GetStateExpansionScore(b->state, previous));
}
// Plabel(l=abc @ t=6) *= P(c @ 6)
- b->newp.label += raw_input(b->label) - max_coeff;
+ b->newp.label += raw_input(b->label) - norm_offset;
}
// Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
- b->newp.blank = b->oldp.total + raw_input(blank_index_) - max_coeff;
+ b->newp.blank = b->oldp.total + raw_input(blank_index_) - norm_offset;
// P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
b->newp.total = LogSumExp(b->newp.blank, b->newp.label);
@@ -326,6 +336,8 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
const float logit = top_k ? top_k_logits[ind] : raw_input(ind);
// Perform label selection: if input for this label looks very
// unpromising, never evaluate it with a scorer.
+      // Comparing raw logits here is equivalent to comparing normalized
+      // log probabilities, since both sides differ by the same norm_offset.
if (logit < label_selection_input_min) {
continue;
}
@@ -339,7 +351,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
// Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
- c.newp.label = logit - max_coeff +
+ c.newp.label = logit - norm_offset +
beam_scorer_->GetStateExpansionScore(c.state, previous);
// P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
c.newp.total = c.newp.label;
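
The hunk above replaces the bare `max_coeff` subtraction with a full log-sum-exp normalization, so beam scores become genuine log probabilities; this is why the expected values in the updated tests below change from positive numbers to negatives. A minimal standalone sketch of the same normalization (the `LogSoftmax` helper is illustrative, not part of the patch):

```c++
#include <algorithm>
#include <cmath>
#include <vector>

// log P(i) = logit[i] - norm_offset, where
// norm_offset = max_coeff + log(sum_j exp(logit[j] - max_coeff)).
std::vector<float> LogSoftmax(const std::vector<float>& logits) {
  const float max_coeff = *std::max_element(logits.begin(), logits.end());
  float logsumexp = 0.0f;
  for (float v : logits) logsumexp += std::exp(v - max_coeff);
  const float norm_offset = max_coeff + std::log(logsumexp);
  std::vector<float> log_probs;
  log_probs.reserve(logits.size());
  for (float v : logits) log_probs.push_back(v - norm_offset);  // Always <= 0.
  return log_probs;
}
```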
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
index 834d1ebd66..8442c4d46c 100644
--- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
-#include "flatbuffers/flexbuffers.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
index 9d1e6a562f..aa42b495bd 100644
--- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h"
+#include "flatbuffers/flexbuffers.h" // flatbuffers
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
@@ -117,7 +117,7 @@ TEST(CTCBeamSearchTest, SimpleTest) {
EXPECT_THAT(decoded_outputs[2], ElementsAre(1, 1));
// Check log probabilities output.
EXPECT_THAT(m.GetLogProbabilitiesOutput(),
- ElementsAreArray(ArrayFloatNear({0.32134813})));
+ ElementsAreArray(ArrayFloatNear({-0.357094})));
}
TEST(CTCBeamSearchTest, MultiBatchTest) {
@@ -148,9 +148,8 @@ TEST(CTCBeamSearchTest, MultiBatchTest) {
EXPECT_THAT(decoded_outputs[1], ElementsAre(1, 0, 0, 0));
EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 2));
// Check log probabilities output.
- EXPECT_THAT(
- m.GetLogProbabilitiesOutput(),
- ElementsAreArray(ArrayFloatNear({0.46403232, 0.49500442, 0.40443572})));
+ EXPECT_THAT(m.GetLogProbabilitiesOutput(),
+ ElementsAreArray(ArrayFloatNear({-1.88343, -1.41188, -1.20958})));
}
TEST(CTCBeamSearchTest, MultiPathsTest) {
@@ -188,8 +187,8 @@ TEST(CTCBeamSearchTest, MultiPathsTest) {
EXPECT_THAT(decoded_outputs[5], ElementsAre(2, 2));
// Check log probabilities output.
EXPECT_THAT(m.GetLogProbabilitiesOutput(),
- ElementsAreArray(ArrayFloatNear(
- {0.91318405, 0.9060272, 1.0780245, 0.64358956})));
+ ElementsAreArray(
+ ArrayFloatNear({-2.65148, -2.65864, -2.17914, -2.61357})));
}
TEST(CTCBeamSearchTest, NonEqualSequencesTest) {
@@ -223,7 +222,7 @@ TEST(CTCBeamSearchTest, NonEqualSequencesTest) {
EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 1));
// Check log probabilities output.
EXPECT_THAT(m.GetLogProbabilitiesOutput(),
- ElementsAreArray(ArrayFloatNear({0., 1.0347567, 0.7833005})));
+ ElementsAreArray(ArrayFloatNear({-0.97322, -1.16334, -2.15553})));
}
} // namespace
diff --git a/tensorflow/contrib/lite/experimental/writer/BUILD b/tensorflow/contrib/lite/experimental/writer/BUILD
new file mode 100644
index 0000000000..82d39c00ab
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/BUILD
@@ -0,0 +1,66 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+cc_binary(
+ name = "option_writer_generator",
+ srcs = ["option_writer_generator.cc"],
+ deps = [
+ "//tensorflow/contrib/lite/schema:schema_fbs_with_reflection",
+ "@flatbuffers",
+ ],
+)
+
+cc_library(
+ name = "writer_lib",
+ srcs = [
+ "enum_mapping.h",
+ "writer_lib.cc",
+ ],
+ hdrs = [
+ "writer_lib.h",
+ ],
+ data = [
+ ":option_writer_gen",
+ ],
+ textual_hdrs = ["option_writer_generated.h"],
+ deps = [
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/schema:schema_fbs_with_reflection",
+ ],
+)
+
+cc_binary(
+ name = "writer",
+ srcs = ["writer.cc"],
+ deps = [
+ ":writer_lib",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ ],
+)
+
+cc_test(
+ name = "writer_lib_test",
+ size = "small",
+ srcs = ["writer_lib_test.cc"],
+ deps = [
+ ":writer_lib",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/testing:util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+genrule(
+ name = "option_writer_gen",
+ outs = ["option_writer_generated.h"],
+ cmd = "$(location :option_writer_generator) $(@)",
+ tools = [":option_writer_generator"],
+)
diff --git a/tensorflow/contrib/lite/experimental/writer/enum_mapping.h b/tensorflow/contrib/lite/experimental/writer/enum_mapping.h
new file mode 100644
index 0000000000..8bc464fd71
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/enum_mapping.h
@@ -0,0 +1,116 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
+
+// TODO(aselle): Ideally extract this from the schema.
+
+namespace tflite {
+
+inline ActivationFunctionType TfLiteActivationToSchemaActivation(
+ TfLiteFusedActivation act) {
+ switch (act) {
+ case kTfLiteActNone:
+ return ActivationFunctionType_NONE;
+ case kTfLiteActRelu:
+ return ActivationFunctionType_RELU;
+ case kTfLiteActRelu1:
+ return ActivationFunctionType_RELU_N1_TO_1;
+ case kTfLiteActRelu6:
+ return ActivationFunctionType_RELU6;
+ case kTfLiteActTanh:
+ return ActivationFunctionType_TANH;
+ case kTfLiteActSignBit:
+ return ActivationFunctionType_SIGN_BIT;
+ case kTfLiteActSigmoid:
+ return ActivationFunctionType_NONE; // TODO(aselle): Add to schema
+ }
+ return ActivationFunctionType_NONE;
+}
+
+inline Padding TfLitePaddingToSchemaPadding(TfLitePadding padding) {
+ switch (padding) {
+ case kTfLitePaddingUnknown:
+ return Padding_SAME; // TODO(aselle): Consider an error.
+ case kTfLitePaddingSame:
+ return Padding_SAME;
+ case kTfLitePaddingValid:
+ return Padding_VALID;
+ }
+ return Padding_SAME; // TODO(aselle): Consider an error.
+}
+
+inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
+ switch (type) {
+ // case kTfLiteNoType: return TensorType_NONE;
+ case kTfLiteNoType:
+ return TensorType_FLOAT32; // TODO(aselle): Consider an error.
+ case kTfLiteFloat32:
+ return TensorType_FLOAT32;
+ case kTfLiteInt32:
+ return TensorType_INT32;
+ case kTfLiteUInt8:
+ return TensorType_UINT8;
+ case kTfLiteInt64:
+ return TensorType_INT64;
+ case kTfLiteString:
+ return TensorType_STRING;
+ case kTfLiteBool:
+ return TensorType_BOOL;
+ case kTfLiteInt16:
+ return TensorType_INT16;
+ case kTfLiteComplex64:
+ return TensorType_COMPLEX64;
+ }
+ // TODO(aselle): consider an error
+}
+
+inline FullyConnectedOptionsWeightsFormat
+FullyConnectedOptionsWeightsFormatToSchema(
+ TfLiteFullyConnectedWeightsFormat format) {
+ switch (format) {
+ case kTfLiteFullyConnectedWeightsFormatDefault:
+ return FullyConnectedOptionsWeightsFormat_DEFAULT;
+ case kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8:
+ return FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
+ }
+}
+
+inline LSTMKernelType LSTMKernelTypeToSchema(TfLiteLSTMKernelType type) {
+ switch (type) {
+ case kTfLiteLSTMFullKernel:
+ return LSTMKernelType_FULL;
+ case kTfLiteLSTMBasicKernel:
+ return LSTMKernelType_BASIC;
+ }
+}
+
+inline LSHProjectionType LSHProjectionTypeToSchema(
+ TfLiteLSHProjectionType type) {
+ switch (type) {
+ case kTfLiteLshProjectionUnknown:
+ return LSHProjectionType_UNKNOWN;
+ case kTfLiteLshProjectionSparse:
+ return LSHProjectionType_SPARSE;
+ case kTfLiteLshProjectionDense:
+ return LSHProjectionType_DENSE;
+ }
+}
+
+} // namespace tflite
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
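
The helpers above translate TFLite runtime enums into their schema counterparts one value at a time. A small hypothetical usage sketch (the `params` variable and its values are illustrative):

```c++
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h"

void Example() {
  TfLiteConvParams params;
  params.activation = kTfLiteActRelu6;
  params.padding = kTfLitePaddingSame;
  // Translate runtime enums to schema enums before serialization.
  tflite::ActivationFunctionType act =
      tflite::TfLiteActivationToSchemaActivation(params.activation);  // RELU6
  tflite::Padding pad =
      tflite::TfLitePaddingToSchemaPadding(params.padding);  // SAME
  (void)act;
  (void)pad;
}
```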
diff --git a/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc
new file mode 100644
index 0000000000..e6d5a776b3
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc
@@ -0,0 +1,370 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <ctype.h>
+#include <iostream>
+#include <unordered_map>
+#include <unordered_set>
+#include "flatbuffers/minireflect.h" // flatbuffers
+#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
+
+namespace tflite {
+namespace {
+// This list was generated by grepping builtin_op_data.h:
+//   cat third_party/tensorflow/contrib/lite/builtin_op_data.h \
+//     | grep "^} TfLite" | sed 's/^} TfLite\(.*\)Params;/\1Params/g' | grep -v "^}"
+static const char* param_structs[] = {"TfLiteConvParams",
+ "TfLitePoolParams",
+ "TfLiteDepthwiseConvParams",
+ "TfLiteSVDFParams",
+ "TfLiteRNNParams",
+ "TfLiteSequenceRNNParams",
+ "TfLiteFullyConnectedParams",
+ "TfLiteLSHProjectionParams",
+ "TfLiteSoftmaxParams",
+ "TfLiteConcatenationParams",
+ "TfLiteAddParams",
+ "TfLiteSpaceToBatchNDParams",
+ "TfLiteBatchToSpaceNDParams",
+ "TfLiteMulParams",
+ "TfLiteSubParams",
+ "TfLiteDivParams",
+ "TfLiteL2NormParams",
+ "TfLiteLocalResponseNormParams",
+ "TfLiteLSTMParams",
+ "TfLiteResizeBilinearParams",
+ "TfLitePadParams",
+ "TfLitePadV2Params",
+ "TfLiteReshapeParams",
+ "TfLiteSkipGramParams",
+ "TfLiteSpaceToDepthParams",
+ "TfLiteCastParams",
+ "TfLiteEmbeddingLookupSparseParams",
+ "TfLiteGatherParams",
+ "TfLiteTransposeParams",
+ "TfLiteReducerParams",
+ "TfLiteSplitParams",
+ "TfLiteSqueezeParams",
+ "TfLiteStridedSliceParams",
+ "TfLiteArgMaxParams",
+ "TfLiteArgMinParams",
+ "TfLiteTransposeConvParams",
+ "TfLiteSparseToDenseParams",
+ "TfLiteShapeParams",
+ "TfLiteFakeQuantParams",
+ "TfLitePackParams",
+ "TfLiteOneHotParams",
+ nullptr};
+} // namespace
+
+// Get rid of all underscores and make everything lower case to make name
+// matching work for stuff like 3D vs 3d or RNN vs Rnn.
+std::string ToCollapsed(const std::string& in) {
+ const char* s = in.c_str();
+ bool first = true;
+ std::string out;
+ while (*s != '\0') {
+ if (*s == '_') {
+ first = true;
+ } else if (first) {
+ out.push_back(tolower(*s));
+ first = false;
+ } else {
+ out.push_back(tolower(*s));
+ }
+ s++;
+ }
+ return out;
+}
+
+// A collection of information about builtin ops.
+class OpOptionData {
+ public:
+ OpOptionData() {
+ BuildOpList();
+ BuildOptionToTypeFunctionMap();
+ BuildOpToOptionMap();
+ }
+
+ // A list of builtin operations
+ const std::vector<std::string>& ops() const { return ops_; }
+  // Maps from operation name to option name (e.g. 'ADD' to 'AddOptions')
+ const std::unordered_map<std::string, std::string>& op_to_option() {
+ return op_to_option_;
+ }
+  // Maps from option name to C struct name, e.g. 'AddOptions' -> 'TfLiteAddParams'
+ const std::unordered_map<std::string, std::string>& option_to_struct() {
+ return option_to_struct_;
+ }
+ // Maps from option to a flatbuffer type function that describes that option.
+ const std::unordered_map<std::string, flatbuffers::TypeFunction>&
+ option_to_type_function() {
+ return option_to_type_function_;
+ }
+
+ private:
+ void BuildOpList() {
+ for (const char* const* curr = EnumNamesBuiltinOperator(); *curr != nullptr;
+ ++curr) {
+ if (strlen(*curr) != 0) ops_.push_back(*curr);
+ }
+ }
+
+ void BuildOptionToTypeFunctionMap() {
+ auto d = tflite::BuiltinOptionsTypeTable();
+ for (int i = 0; i < d->num_elems; i++) {
+ flatbuffers::TypeCode code = d->type_codes[i];
+ if (code.sequence_ref != -1) {
+ option_to_type_function_.insert(
+ std::make_pair(d->names[i], d->type_refs[code.sequence_ref]));
+ }
+ }
+ }
+
+ void BuildOpToOptionMap() {
+ // Manually specified mappings between ops and options
+ op_to_option_["REDUCE_MAX"] = "ReducerOptions";
+ op_to_option_["REDUCE_MIN"] = "ReducerOptions";
+ op_to_option_["REDUCE_ANY"] = "ReducerOptions";
+ op_to_option_["UNPACK"] = "";
+ op_to_option_["SUM"] = "ReducerOptions";
+ op_to_option_["REDUCE_MAX"] = "ReducerOptions";
+ op_to_option_["REDUCE_PROD"] = "ReducerOptions";
+ op_to_option_["MEAN"] = "ReducerOptions";
+ op_to_option_["L2_POOL_2D"] = "Pool2DOptions";
+ op_to_option_["AVERAGE_POOL_2D"] = "Pool2DOptions";
+ op_to_option_["MAX_POOL_2D"] = "Pool2DOptions";
+ op_to_option_["L2_NORMALIZATION"] = "L2NormOptions";
+ op_to_option_["BIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions";
+ op_to_option_["UNIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions";
+ op_to_option_["BIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
+ op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
+ op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
+ // Manually specified mappings between ops and options (none)
+ op_to_option_["EMBEDDING_LOOKUP"] =
+ ""; // TODO(aselle): maybe something else.
+ op_to_option_["FLOOR"] = "";
+ op_to_option_["HASHTABLE_LOOKUP"] =
+ ""; // TODO(aselle): maybe something else.
+ op_to_option_["LOGISTIC"] = "";
+ op_to_option_["RELU"] = "";
+ op_to_option_["RELU_N1_TO_1"] = "";
+ op_to_option_["RELU6"] = "";
+ op_to_option_["TANH"] = "";
+ op_to_option_["CUSTOM"] = ""; // TODO(aselle): maybe something else.
+ op_to_option_["DELEGATE"] = ""; // TODO(aselle): maybe something else.
+ op_to_option_["PRELU"] = "";
+ op_to_option_["MAXIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions
+ op_to_option_["MINIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions
+ op_to_option_["SIN"] = "";
+ op_to_option_["LOG"] = "";
+ op_to_option_["SQRT"] = "";
+ op_to_option_["RSQRT"] = "";
+
+    // TODO(aselle): These are undesirable hacks. Consider changing the C structs.
+ option_to_struct_["Pool2DOptions"] = "TfLitePoolParams";
+ option_to_struct_["Conv2DOptions"] = "TfLiteConvParams";
+ option_to_struct_["DepthwiseConv2DOptions"] = "TfLiteDepthwiseConvParams";
+ option_to_struct_["LocalResponseNormalizationOptions"] =
+ "TfLiteLocalResponseNormParams";
+ // Now for every op, try to find an option.
+ bool fatal = false;
+ for (auto op_name : ops_) {
+ bool found_option = false;
+ auto d = tflite::BuiltinOptionsTypeTable();
+ std::string collapsed_option_name_guess =
+ ToCollapsed(op_name) + "options";
+      // O(n^2), but n is small here.
+ for (int i = 0; i < d->num_elems; i++) {
+ std::string option_name = d->names[i];
+ std::string collapsed_option_name = ToCollapsed(option_name);
+ if (collapsed_option_name_guess == collapsed_option_name) {
+ op_to_option_.insert(std::make_pair(op_name, option_name));
+ found_option = true;
+ break;
+ }
+ }
+ auto it = op_to_option_.find(op_name);
+ if (it == op_to_option_.end()) {
+ std::cerr << "Didn't find option for " << op_name << std::endl;
+ fatal = true;
+ } else if (!it->second.empty()) {
+ std::string option_name = it->second;
+
+ if (option_to_struct_.find(option_name) == option_to_struct_.end()) {
+ bool param_struct_found = false;
+ std::string params_guess = std::string("TfLite") + option_name;
+ size_t start = params_guess.find("Options");
+ size_t len = strlen("Options");
+ params_guess.replace(start, len, "Params");
+ for (auto* param = param_structs; *param != nullptr; param++) {
+ if (*param == params_guess) {
+ param_struct_found = true;
+ break;
+ }
+ }
+ if (!param_struct_found) {
+ std::cerr << "Failed to get param struct for option " << option_name
+ << std::endl;
+ fatal = true;
+ } else {
+ option_to_struct_.insert(std::make_pair(option_name, params_guess));
+ }
+ }
+ }
+ }
+ }
+
+ private:
+ std::vector<std::string> ops_;
+ std::unordered_map<std::string, std::string> op_to_option_;
+ std::unordered_map<std::string, std::string> option_to_struct_;
+ std::unordered_map<std::string, flatbuffers::TypeFunction>
+ option_to_type_function_;
+};
+
+void GenerateImportForOp(FILE* fp, const std::string& op_name,
+ const std::string& option_name,
+ const std::string& option_type,
+ const flatbuffers::TypeTable* options,
+ const std::string& struct_name) {
+ // Skip tricky ones for now
+ if (struct_name == "TfLiteResizeBilinearParams") return;
+ if (struct_name == "TfLiteSqueezeParams") return;
+ if (struct_name == "TfLiteEmbeddingLookupSparseParams") return;
+ if (struct_name == "TfLiteReshapeParams") return;
+
+ fprintf(fp, " case BuiltinOperator_%s: {\n", op_name.c_str());
+ fprintf(fp,
+ " const auto* params = reinterpret_cast<const "
+ "%s*>(builtin_op_data);\n",
+ struct_name.c_str());
+
+ for (size_t i = 0; i < options->num_elems; i++) {
+ std::string elem_name = options->names[i];
+ // TODO(aselle): Irregular naming in builtins
+ if (elem_name == "fused_activation_function")
+ elem_name = "activation";
+ else if (elem_name == "stride_w")
+ elem_name = "stride_width";
+ else if (elem_name == "stride_h")
+ elem_name = "stride_height";
+ else if (elem_name == "dilation_h_factor")
+ elem_name = "dilation_height_factor";
+ else if (elem_name == "dilation_w_factor")
+ elem_name = "dilation_width_factor";
+ else if (elem_name == "new_shape")
+ elem_name = "shape";
+
+ flatbuffers::TypeCode code = options->type_codes[i];
+ auto contained_type = code.sequence_ref != -1
+ ? options->type_refs[code.sequence_ref]
+ : nullptr;
+ std::string mapper = "";
+ if (contained_type == TensorTypeTypeTable) {
+ mapper = "TfLiteTypeToSchemaType";
+ } else if (contained_type == ActivationFunctionTypeTypeTable) {
+ mapper = "TfLiteActivationToSchemaActivation";
+ } else if (contained_type == PaddingTypeTable) {
+ mapper = "TfLitePaddingToSchemaPadding";
+ } else if (contained_type == FullyConnectedOptionsWeightsFormatTypeTable) {
+ mapper = "FullyConnectedOptionsWeightsFormatToSchema";
+ } else if (contained_type == LSTMKernelTypeTypeTable) {
+ mapper = "LSTMKernelTypeToSchema";
+ } else if (contained_type == LSHProjectionTypeTypeTable) {
+ mapper = "LSHProjectionTypeToSchema";
+ }
+
+ fprintf(fp,
+ " auto val%zu = "
+ "%s(params->%s);\n",
+ i, mapper.c_str(), elem_name.c_str());
+ }
+ fprintf(fp, " auto union_type = Create%s(*fbb", option_name.c_str());
+ for (size_t i = 0; i < options->num_elems; i++) {
+ fprintf(fp, ", val%zu", i);
+ }
+ fprintf(fp, ").Union();\n");
+ fprintf(fp, " return std::make_pair(%s, union_type);\n",
+ option_type.c_str());
+ fprintf(fp, " }\n break;\n");
+}
+
+void GenerateImport(OpOptionData* option, FILE* fp) {
+ std::unordered_set<std::string> ignores;
+ ignores.insert("CONCAT_EMBEDDINGS");
+ ignores.insert("CALL");
+
+  // Group all ops that have no options struct (plus the explicitly ignored
+  // ops) into one shared block of case labels that returns empty options.
+ for (const auto& op_name : option->ops()) {
+ auto option_it = option->op_to_option().find(op_name);
+ if (!option_it->second.empty() && ignores.find(op_name) == ignores.end())
+ continue;
+ fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str());
+ }
+ fprintf(fp,
+ " return std::make_pair(BuiltinOptions_NONE, "
+ "flatbuffers::Offset<void>());\n break;\n");
+
+  // Iterate over each op.
+ for (const auto& op_name : option->ops()) {
+ if (ignores.find(op_name) != ignores.end()) continue;
+    // Look up the option and struct names, skipping ops that have none.
+ auto option_it = option->op_to_option().find(op_name);
+ if (option_it->second.empty()) continue;
+ std::string option_name = option_it->second;
+ std::string option_type = "BuiltinOptions_" + option_name;
+ auto option_func_it = option->option_to_type_function().find(option_name);
+ if (option_func_it == option->option_to_type_function().end()) continue;
+ auto struct_name_it = option->option_to_struct().find(option_name);
+ if (struct_name_it == option->option_to_struct().end()) {
+ // If no C struct, then it better have no arguments.
+ auto type_info = option_func_it->second();
+ if (type_info->num_elems != 0) {
+        // The schema lists a non-zero number of arguments, which means a
+        // corresponding C struct should exist.
+ fprintf(stderr,
+ "Op %s uses option struct %s which has no builtin struct\n",
+ op_name.c_str(), option_name.c_str());
+ exit(1);
+ }
+ fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str());
+ fprintf(fp, " return std::make_pair(%s, Create%s(*fbb).Union());",
+ option_type.c_str(), option_name.c_str());
+ } else {
+ // If C struct, then we need to assign all properties
+ auto struct_name = struct_name_it->second;
+ GenerateImportForOp(fp, op_name, option_name, option_type,
+ option_func_it->second(), struct_name);
+ }
+ }
+ // TODO(aselle): Handle unhandled cases more gracefully.
+ fprintf(fp,
+ "default: return std::make_pair(BuiltinOptions_NONE, "
+ "flatbuffers::Offset<void>());\n break;\n");
+}
+
+} // namespace tflite
+
+int main(int argc, char* argv[]) {
+ tflite::OpOptionData option;
+ if (argc != 2) {
+ fprintf(stderr, "Usage: %s <fname out>\n", argv[0]);
+ return 1;
+ }
+  FILE* fp = fopen(argv[1], "w");
+  if (!fp) {
+    fprintf(stderr, "Failed to open '%s' for writing\n", argv[1]);
+    return 1;
+  }
+ tflite::GenerateImport(&option, fp);
+ fclose(fp);
+}
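
Reconstructed from the `fprintf` calls in `GenerateImportForOp` above, the generated `option_writer_generated.h` consists of `switch` cases like the following excerpt for ADD (an approximation; exact formatting may differ):

```c++
    case BuiltinOperator_ADD: {
      const auto* params =
          reinterpret_cast<const TfLiteAddParams*>(builtin_op_data);
      auto val0 = TfLiteActivationToSchemaActivation(params->activation);
      auto union_type = CreateAddOptions(*fbb, val0).Union();
      return std::make_pair(BuiltinOptions_AddOptions, union_type);
    }
    break;
```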
diff --git a/tensorflow/contrib/lite/experimental/writer/writer.cc b/tensorflow/contrib/lite/experimental/writer/writer.cc
new file mode 100644
index 0000000000..20ede214fb
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/writer.cc
@@ -0,0 +1,41 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Just does a read/write loop of tflite file format using the interpreter as
+// an intermediate.
+//
+// Usage:
+// writer <input tflite> <output tflite>
+
+#include <iostream>
+
+#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+
+int main(int argc, char* argv[]) {
+ if (argc != 3) {
+ fprintf(stderr, "Usage: %s input_file output_file\n", argv[0]);
+ return 1;
+ }
+  std::unique_ptr<tflite::FlatBufferModel> model =
+      tflite::FlatBufferModel::BuildFromFile(argv[1]);
+  if (!model) {
+    fprintf(stderr, "Failed to load model %s\n", argv[1]);
+    return 1;
+  }
+ std::unique_ptr<tflite::Interpreter> interpreter;
+ tflite::ops::builtin::BuiltinOpResolver builtin_op_resolver;
+ tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter);
+ tflite::InterpreterWriter writer(interpreter.get());
+  if (writer.Write(argv[2]) != kTfLiteOk) {
+    fprintf(stderr, "Failed to write %s\n", argv[2]);
+    return 1;
+  }
+
+ return 0;
+}
diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.cc b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc
new file mode 100644
index 0000000000..52b17faf82
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc
@@ -0,0 +1,281 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"
+#include <cstdlib>
+#include <cstring>
+#include <unordered_map>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context_util.h"
+#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
+#include "tensorflow/contrib/lite/version.h"
+
+namespace tflite {
+template <class T>
+using Offset = flatbuffers::Offset<T>;
+template <class T>
+using Vector = flatbuffers::Vector<T>;
+using FlatBufferBuilder = flatbuffers::FlatBufferBuilder;
+
+std::pair<BuiltinOptions, Offset<void>> CreateBuiltinUnion(
+ FlatBufferBuilder* fbb, enum BuiltinOperator op, void* builtin_op_data) {
+ switch (op) {
+#include "tensorflow/contrib/lite/experimental/writer/option_writer_generated.h"
+ }
+ return std::make_pair(BuiltinOptions_NONE, Offset<void>());
+}
+
+template <class T_OUTPUT, class T_INPUT>
+Offset<Vector<T_OUTPUT>> InterpreterWriter::ExportVector(FlatBufferBuilder* fbb,
+ const T_INPUT& v) {
+ std::vector<T_OUTPUT> inputs(v.begin(), v.end());
+ return fbb->template CreateVector<T_OUTPUT>(inputs);
+}
+
+Offset<Vector<Offset<Operator>>> InterpreterWriter::ExportOperators(
+ FlatBufferBuilder* fbb) {
+ std::vector<Offset<Operator>> operators;
+
+ std::vector<int> operator_to_opcode;
+  // TODO(aselle): Augment this once we put the execution plan in the schema.
+ operator_to_opcode.resize(interpreter_->nodes_size(), -1);
+ for (int op_index : interpreter_->execution_plan()) {
+ const auto* node_and_registration =
+ interpreter_->node_and_registration(op_index);
+ const TfLiteRegistration* registration = &node_and_registration->second;
+ if (!registration->custom_name) {
+ operator_to_opcode[op_index] =
+ GetOpCodeForBuiltin(registration->builtin_code);
+ } else {
+ operator_to_opcode[op_index] =
+ GetOpCodeForCustom(registration->custom_name);
+ }
+ }
+  // Second pass: serialize the operators.
+ for (int op_index : interpreter_->execution_plan()) {
+ const auto* node_and_registration =
+ interpreter_->node_and_registration(op_index);
+ const TfLiteNode& node = node_and_registration->first;
+ const TfLiteRegistration& registration = node_and_registration->second;
+ Offset<void> builtin_options;
+ BuiltinOptions builtin_options_type = BuiltinOptions_NONE;
+ // Custom data
+ // TODO(aselle): Custom options format is not known by default. Just assume
+ // for now.
+ auto custom_options_format = CustomOptionsFormat_FLEXBUFFERS;
+ Offset<Vector<uint8_t>> custom_options = 0;
+
+ if (!registration.custom_name) {
+ // builtin
+ auto builtin_options_and_type = CreateBuiltinUnion(
+ fbb, static_cast<enum BuiltinOperator>(registration.builtin_code),
+ node.builtin_data);
+ builtin_options = builtin_options_and_type.second;
+ builtin_options_type = builtin_options_and_type.first;
+ } else {
+ auto custom_writer = custom_op_to_writer_.find(registration.custom_name);
+ if (custom_writer != custom_op_to_writer_.end() &&
+ custom_writer->second) {
+ // delegate to custom writer if it exists
+ custom_writer->second(fbb, interpreter_, op_index, &custom_options,
+ &custom_options_format);
+ } else {
+        // Otherwise serialize the node's custom initial data verbatim.
+ custom_options = fbb->CreateVector(
+ reinterpret_cast<const uint8_t*>(node.custom_initial_data),
+ node.custom_initial_data_size);
+ }
+ }
+
+ int opcode_index = operator_to_opcode[op_index];
+ std::vector<int> written_inputs =
+ RemapTensorIndicesToWritten(TfLiteIntArrayView(node.inputs));
+ std::vector<int> written_outputs =
+ RemapTensorIndicesToWritten(TfLiteIntArrayView(node.outputs));
+ auto inputs = ExportVector<int32_t>(fbb, written_inputs);
+ auto outputs = ExportVector<int32_t>(fbb, written_outputs);
+ operators.push_back(CreateOperator(*fbb, opcode_index, inputs, outputs,
+ builtin_options_type, builtin_options,
+ custom_options, custom_options_format));
+ }
+
+ return fbb->template CreateVector<Offset<Operator>>(operators);
+}
+
+Offset<Vector<Offset<Tensor>>> InterpreterWriter::ExportTensors(
+ FlatBufferBuilder* fbb) {
+ tensor_to_written_tensor_.resize(interpreter_->tensors_size(), -1);
+
+ std::vector<Offset<Tensor>> tensors;
+
+ // Make a map from tensor index to whether the tensor is a temporary.
+ std::vector<bool> tensor_is_temporary(interpreter_->tensors_size(), false);
+ for (int op_index = 0; op_index < interpreter_->nodes_size(); ++op_index) {
+ const auto* node_and_registration =
+ interpreter_->node_and_registration(op_index);
+ for (auto tensor_index :
+ TfLiteIntArrayView(node_and_registration->first.temporaries))
+ tensor_is_temporary[tensor_index] = true;
+ }
+
+ // Now we need to remap all used tensor indices
+ int curr_output_index = 0;
+ for (int tensor_index = 0; tensor_index < interpreter_->tensors_size();
+ tensor_index++) {
+ if (!tensor_is_temporary[tensor_index]) {
+ tensor_to_written_tensor_[tensor_index] = curr_output_index++;
+ }
+ }
+
+ for (int tensor_index = 0; tensor_index < interpreter_->tensors_size();
+ ++tensor_index) {
+ // Skip temporaries.
+ if (tensor_is_temporary[tensor_index]) continue;
+
+ if (TfLiteTensor* tensor = interpreter_->tensor(tensor_index)) {
+      // Only serialize tensors with recognized allocation types.
+ if (tensor->allocation_type != kTfLiteArenaRw &&
+ tensor->allocation_type != kTfLiteMmapRo &&
+ tensor->allocation_type != kTfLiteArenaRwPersistent)
+ continue;
+ // Allocate a buffer index
+      int buffer_index = 0;  // Index 0 is the reserved empty buffer.
+ if (tensor->allocation_type == kTfLiteMmapRo) {
+ buffer_index = buffers_.size();
+ buffers_.push_back(std::make_pair(
+ reinterpret_cast<const uint8_t*>(tensor->data.raw), tensor->bytes));
+ }
+ // Primitive type.
+ TensorType type = TfLiteTypeToSchemaType(tensor->type);
+ // Handle quantization
+ const Offset<Vector<float>> null_array;
+ Offset<Vector<float>> scale_array;
+ Offset<Vector<int64_t>> zero_point_array;
+ if (tensor->params.scale != 0.f) {
+        // The tensor is quantized, so make single-element scale/zero-point
+        // arrays (multi-channel quantization would need updating here).
+ scale_array = fbb->CreateVector<float>({tensor->params.scale});
+ zero_point_array =
+ fbb->CreateVector<int64_t>({tensor->params.zero_point});
+ }
+ Offset<QuantizationParameters> quantization_params =
+ CreateQuantizationParameters(*fbb, null_array, null_array,
+ scale_array, zero_point_array);
+ // Shape
+ TfLiteIntArrayView shape_view(tensor->dims);
+ std::vector<int> shape =
+ std::vector<int>(shape_view.begin(), shape_view.end());
+
+ tensors.push_back(CreateTensor(*fbb, ExportVector<int32_t>(fbb, shape),
+ type, buffer_index,
+ fbb->CreateString(tensor->name),
+ quantization_params, tensor->is_variable));
+ }
+ }
+ return fbb->template CreateVector<Offset<Tensor>>(tensors);
+}
+
+Offset<Vector<Offset<Buffer>>> InterpreterWriter::ExportBuffers(
+ FlatBufferBuilder* fbb) {
+ std::vector<Offset<Buffer>> buffer_vector;
+ for (auto buffer : buffers_) {
+ auto data_offset = fbb->CreateVector(buffer.first, buffer.second);
+ buffer_vector.push_back(CreateBuffer(*fbb, data_offset));
+ }
+ return fbb->template CreateVector<Offset<Buffer>>(buffer_vector);
+}
+
+Offset<Vector<Offset<OperatorCode>>> InterpreterWriter::CreateOpCodeTable(
+ FlatBufferBuilder* fbb) {
+ std::vector<Offset<OperatorCode>> codes;
+ for (auto it : opcodes_) {
+ const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str();
+ codes.push_back(CreateOperatorCodeDirect(
+ *fbb, static_cast<BuiltinOperator>(it.builtin), custom_name));
+ }
+ return fbb->template CreateVector<Offset<OperatorCode>>(codes);
+}
+
+template <class T>
+std::vector<int> InterpreterWriter::RemapTensorIndicesToWritten(
+ const T& input) {
+ std::vector<int> output;
+ output.reserve(input.size());
+ for (int x : input) {
+ output.push_back(tensor_to_written_tensor_[x]);
+ }
+ return output;
+}
+
+TfLiteStatus InterpreterWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
+ size_t* size) {
+ if (!out || !size) return kTfLiteError;
+ FlatBufferBuilder builder(/*initial_size=*/10240);
+
+ std::vector<Offset<SubGraph>> subgraphs_as_vector;
+  {  // Subgraph-specific serialization.
+ auto tensors = ExportTensors(&builder);
+ std::vector<int> written_inputs =
+ RemapTensorIndicesToWritten(interpreter_->inputs());
+ std::vector<int> written_outputs =
+ RemapTensorIndicesToWritten(interpreter_->outputs());
+ auto inputs = ExportVector<int32_t>(&builder, written_inputs);
+ auto outputs = ExportVector<int32_t>(&builder, written_outputs);
+
+ auto ops = ExportOperators(&builder);
+ subgraphs_as_vector.push_back(
+ CreateSubGraph(builder, tensors, inputs, outputs, ops, /* name */ 0));
+ }
+ Offset<Vector<Offset<Buffer>>> buffers = ExportBuffers(&builder);
+
+ auto description = builder.CreateString("Exported from Interpreter.");
+
+ auto op_codes = CreateOpCodeTable(&builder);
+ auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
+ builder.CreateVector(subgraphs_as_vector),
+ description, buffers);
+ ::tflite::FinishModelBuffer(builder, model);
+ const uint8_t* buffer = builder.GetBufferPointer();
+ *size = builder.GetSize();
+ (*out).reset(new uint8_t[*size]);
+ memcpy(out->get(), buffer, *size);
+ return kTfLiteOk;
+}
+
+TfLiteStatus InterpreterWriter::Write(const std::string& filename) {
+ std::unique_ptr<uint8_t[]> buffer;
+ size_t size;
+ TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
+
+ FILE* fp = fopen(filename.c_str(), "wb");
+ if (!fp) return kTfLiteError;
+
+ if (fwrite(buffer.get(), 1, size, fp) != size) return kTfLiteError;
+ if (fclose(fp)) return kTfLiteError;
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus InterpreterWriter::RegisterCustomWriter(
+ const std::string& custom_name, CustomWriter custom_writer) {
+ if (custom_op_to_writer_.find(custom_name) != custom_op_to_writer_.end()) {
+ return kTfLiteError;
+ }
+ custom_op_to_writer_.insert(std::make_pair(custom_name, custom_writer));
+ return kTfLiteOk;
+}
+
+} // namespace tflite
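
A minimal usage sketch of the in-memory serialization path (assuming `interpreter` points to an already-built `tflite::Interpreter`; the `SerializeToMemory` wrapper is illustrative):

```c++
#include <memory>

#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"

TfLiteStatus SerializeToMemory(tflite::Interpreter* interpreter) {
  tflite::InterpreterWriter writer(interpreter);
  std::unique_ptr<uint8_t[]> buffer;
  size_t size = 0;
  // On success, `buffer` holds a complete .tflite flatbuffer of `size` bytes.
  // Write(filename) is the file-based equivalent.
  TF_LITE_ENSURE_STATUS(writer.GetBuffer(&buffer, &size));
  return kTfLiteOk;
}
```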
diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.h b/tensorflow/contrib/lite/experimental/writer/writer_lib.h
new file mode 100644
index 0000000000..a98108b496
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.h
@@ -0,0 +1,126 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Writes a flatbuffer of a currently loaded TensorFlow Lite interpreter.
+//
+// Usage:
+// From command line:
+// bazel run third_party/tensorflow/contrib/lite/experimental/writer:writer
+// -- foo.tflite foo.out.tflite
+//
+// From C++
+// std::unique_ptr<Interpreter> interpreter;
+// // Build Interpreter however
+// // ... <omitted>
+// InterpreterWriter(interpreter.get()).Write("output.tflite");
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
+#include <iostream>
+#include <unordered_map>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context_util.h"
+#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
+#include "tensorflow/contrib/lite/version.h"
+
+namespace tflite {
+
+// Writes a running TensorFlow Lite interpreter out to the serialized
+// TensorFlow Lite flatbuffer file format.
+class InterpreterWriter {
+ public:
+ typedef flatbuffers::Offset<Operator> (*CustomWriter)(
+ flatbuffers::FlatBufferBuilder* fbb, Interpreter* interpreter,
+ int node_index,
+ flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
+ CustomOptionsFormat* custom_options_format);
+
+  // Construct an interpreter writer for the specified `interpreter`. Then
+  // use .Write() or .GetBuffer(...) to extract the data.
+ explicit InterpreterWriter(Interpreter* interpreter)
+ : interpreter_(interpreter) {
+ buffers_.push_back(std::make_pair(nullptr, 0));
+ }
+
+ // Get a buffer and size of a serialized flatbuffer.
+ TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size);
+ // Write the serialized flatbuffer to the prescribed `filename`.
+ TfLiteStatus Write(const std::string& filename);
+ // Registers a custom writer for a custom op. The customization allows the
+ // caller to change the custom data.
+ TfLiteStatus RegisterCustomWriter(const std::string& custom_name,
+ CustomWriter custom_writer);
+
+ private:
+ template <class T>
+ using Offset = flatbuffers::Offset<T>;
+ template <class T_OUTPUT, class T_INPUT>
+ Offset<flatbuffers::Vector<T_OUTPUT>> ExportVector(
+ flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v);
+ Offset<flatbuffers::Vector<Offset<Tensor>>> ExportTensors(
+ flatbuffers::FlatBufferBuilder* fbb);
+ Offset<flatbuffers::Vector<Offset<Operator>>> ExportOperators(
+ flatbuffers::FlatBufferBuilder* fbb);
+ Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable(
+ flatbuffers::FlatBufferBuilder* fbb);
+ Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers(
+ flatbuffers::FlatBufferBuilder* fbb);
+
+ template <class T>
+ std::vector<int> RemapTensorIndicesToWritten(const T& input);
+
+ int GetOpCodeForBuiltin(int builtin_op_index) {
+ std::pair<decltype(builtin_op_to_opcode_)::iterator, bool> result =
+ builtin_op_to_opcode_.insert(
+ std::make_pair(builtin_op_index, opcodes_.size()));
+ if (result.second) {
+ opcodes_.push_back({builtin_op_index, ""});
+ }
+ return result.first->second;
+ }
+
+ int GetOpCodeForCustom(const std::string& custom_name) {
+ std::pair<decltype(custom_op_to_opcode_)::iterator, bool> result =
+ custom_op_to_opcode_.insert(
+ std::make_pair(custom_name, opcodes_.size()));
+ if (result.second) {
+ opcodes_.push_back({BuiltinOperator_CUSTOM, custom_name});
+ }
+ return result.first->second;
+ }
+
+ // The interpreter we are writing
+ Interpreter* interpreter_;
+ // Keep track of byte buffers
+ std::vector<std::pair<const uint8_t*, size_t>> buffers_;
+ // List of op codes and mappings from builtin or custom op to opcode
+ struct OpCode {
+ int builtin;
+ std::string custom;
+ };
+  // For every tensor index in the interpreter, the corresponding index in the
+  // written file. These differ because temporary tensors are not written.
+ std::vector<int> tensor_to_written_tensor_;
+ // List of used opcodes
+ std::vector<OpCode> opcodes_;
+ std::unordered_map<int, int> builtin_op_to_opcode_;
+ std::unordered_map<std::string, int> custom_op_to_opcode_;
+ std::unordered_map<std::string, CustomWriter> custom_op_to_writer_;
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
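
For custom ops, `RegisterCustomWriter` lets a caller control how the custom options are serialized. A hypothetical writer matching the `CustomWriter` typedef above (the op name and empty payload are illustrative):

```c++
flatbuffers::Offset<tflite::Operator> WriteMyCustomOp(
    flatbuffers::FlatBufferBuilder* fbb, tflite::Interpreter* interpreter,
    int node_index,
    flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
    tflite::CustomOptionsFormat* custom_options_format) {
  // A real writer would serialize the node's parameters here; this one
  // just emits an empty flexbuffer payload.
  *custom_options_format = tflite::CustomOptionsFormat_FLEXBUFFERS;
  *output_options = fbb->CreateVector(std::vector<uint8_t>());
  return 0;  // The return value is currently unused by ExportOperators.
}

// Registration: writer.RegisterCustomWriter("MY_CUSTOM_OP", WriteMyCustomOp);
```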
diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc b/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc
new file mode 100644
index 0000000000..49194a76c8
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/testing/util.h"
+
+namespace tflite {
+// Round-trips a trivial interpreter (three tensors, one ADD node) through the writer.
+// TODO(b/113731921): add more tests.
+TEST(Writer, BasicTest) {
+ Interpreter interpreter;
+ interpreter.AddTensors(3);
+ float foo[] = {1, 2, 3};
+ interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
+ TfLiteQuantizationParams());
+ interpreter.SetTensorParametersReadOnly(
+ 1, kTfLiteFloat32, "b", {3}, TfLiteQuantizationParams(),
+ reinterpret_cast<char*>(foo), sizeof(foo));
+ interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
+ TfLiteQuantizationParams());
+ interpreter.SetInputs({0, 1});
+ interpreter.SetOutputs({2});
+ const char* initial_data = "";
+ tflite::ops::builtin::BuiltinOpResolver resolver;
+ TfLiteAddParams* builtin_data =
+ reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
+ builtin_data->activation = kTfLiteActNone;
+ const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
+ interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
+ reinterpret_cast<void*>(builtin_data), reg);
+
+ InterpreterWriter writer(&interpreter);
+ writer.Write("/tmp/test.tflite");
+ std::unique_ptr<FlatBufferModel> model =
+ FlatBufferModel::BuildFromFile("/tmp/test.tflite");
+ InterpreterBuilder builder(*model, resolver);
+ std::unique_ptr<Interpreter> new_interpreter;
+ builder(&new_interpreter);
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/g3doc/README.md b/tensorflow/contrib/lite/g3doc/README.md
deleted file mode 100644
index e3db478481..0000000000
--- a/tensorflow/contrib/lite/g3doc/README.md
+++ /dev/null
@@ -1,4 +0,0 @@
-This is a *work-in-progress* TF Lite subsite for:
-https://www.tensorflow.org/mobile
-
-DO NOT PUBLISH
diff --git a/tensorflow/contrib/lite/g3doc/_book.yaml b/tensorflow/contrib/lite/g3doc/_book.yaml
index 98abd5743b..1dffe30790 100644
--- a/tensorflow/contrib/lite/g3doc/_book.yaml
+++ b/tensorflow/contrib/lite/g3doc/_book.yaml
@@ -1,6 +1,7 @@
upper_tabs:
# Tabs left of dropdown menu
- include: /_upper_tabs_left.yaml
+- include: /versions/_upper_tabs_versions.yaml
# Dropdown menu
- name: Ecosystem
path: /ecosystem
diff --git a/tensorflow/contrib/lite/g3doc/api_docs/python/index.md b/tensorflow/contrib/lite/g3doc/api_docs/python/index.md
deleted file mode 100644
index 70031a3c3d..0000000000
--- a/tensorflow/contrib/lite/g3doc/api_docs/python/index.md
+++ /dev/null
@@ -1,10 +0,0 @@
-Project: /mobile/_project.yaml
-Book: /mobile/_book.yaml
-page_type: reference
-<style> table img { max-width: 100%; } </style>
-<script src="/_static/js/managed/mathjax/MathJax.js?config=TeX-AMS-MML_SVG"></script>
-
-<!-- DO NOT EDIT! Automatically generated file. -->
-# All symbols in TensorFlow Lite
-
-TEMP PAGE
diff --git a/tensorflow/contrib/lite/g3doc/apis.md b/tensorflow/contrib/lite/g3doc/apis.md
index 776803da8c..69616c7b8a 100644
--- a/tensorflow/contrib/lite/g3doc/apis.md
+++ b/tensorflow/contrib/lite/g3doc/apis.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# TensorFlow Lite APIs
@@ -39,7 +37,7 @@ float* output = interpreter->typed_output_tensor<float>(0);
```
### Data Alignment
-TensorFlow Lite data is usually aligned to 32-bit boundaries. It is recommended
+TensorFlow Lite data is usually aligned to 16-byte boundaries. It is recommended
that all data provided to TensorFlow Lite be aligned that way.
### Error Reporting
@@ -114,7 +112,7 @@ below. It should be noted that:
* Tensors are represented by integers, in order to avoid string comparisons
(and any fixed dependency on string libraries).
- * An interpreter must not be accessed from concurrent threads
+ * An interpreter must not be accessed from concurrent threads.
* Memory allocation for input and output tensors must be triggered
by calling AllocateTensors() right after resizing tensors.
@@ -171,7 +169,7 @@ former provides error reporting facilities and access to global objects,
including all the tensors. The latter allows implementations to access their
inputs and outputs.
-When the interpreter loads a model, it calls init() once for each node in the
+When the interpreter loads a model, it calls `init()` once for each node in the
graph. A given `init()` will be called more than once if the op is used
multiple times in the graph. For custom ops a configuration buffer will be
provided, containing a flexbuffer that maps parameter names to their values.
@@ -212,8 +210,9 @@ namespace custom {
Note that registration is not automatic and an explicit call to
`Register_MY_CUSTOM_OP` should be made somewhere. While the standard
-`:builtin_ops` takes care of the registration of builtins, custom ops will have
-to be collected in separated custom libraries.
+`BuiltinOpResolver` (available from the `:builtin_ops` target) takes care of the
+registration of builtins, custom ops will have to be collected in separate
+custom libraries.
### Customizing the kernel library
@@ -234,7 +233,7 @@ class OpResolver {
};
```
-The regular usage will require the developer to use the `BuiltinOpResolver` and
+Regular usage will require the developer to use the `BuiltinOpResolver` and
write:
```c++
@@ -310,18 +309,25 @@ an `IllegalArgumentException` will be thrown.
#### Inputs
-Each input should be an array, a multi-dimensional array, or a `ByteBuffer` of
-the supported primitive types.
+Each input should be an array or multi-dimensional array of the supported
+primitive types, or a raw `ByteBuffer` of the appropriate size. If the input is
+an array or multi-dimensional array, the associated input tensor will be
+implicitly resized to the array's dimensions at inference time. If the input is
+a `ByteBuffer`, the caller should first manually resize the associated input
+tensor (via `Interpreter.resizeInput()`) before running inference.
-The use of `ByteBuffer` is preferred since it allows the `Interpreter` to avoid
-unnecessary copies. Each `ByteBuffer` needs to be a direct byte buffer, and its
-order must be `ByteOrder.nativeOrder()`. After it is used for a model inference,
-it must remain unchanged until the model inference is finished.
+When using `ByteBuffer`, prefer using direct byte buffers, as this allows the
+`Interpreter` to avoid unnecessary copies. If the `ByteBuffer` is a direct byte
+buffer, its order must be `ByteOrder.nativeOrder()`. After it is used for a
+model inference, it must remain unchanged until the model inference is finished.
#### Outputs
-Each output should be an array, or a multi-dimensional array of the supported
-primitive types.
+Each output should be an array or multi-dimensional array of the supported
+primitive types, or a `ByteBuffer` of the appropriate size. Note that some models
+have dynamic outputs, where the shape of output tensors can vary depending on
+the input. There's no straightforward way of handling this with the existing
+Java inference API, but planned extensions will make this possible.
#### Running Model Inference
@@ -341,9 +347,10 @@ interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
where each entry in `inputs` corresponds to an input tensor and
`map_of_indices_to_outputs` maps indices of output tensors to the
corresponding output data. In both cases the tensor indices should correspond to
-the values given to the `TensorFlow Lite Optimized Converter` when the model was
-created. Be aware that the order of tensors in `input` must match the order
-given to the `TensorFlow Lite Optimized Converter`.
+the values given to the [TensorFlow Lite Optimized Converter](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md)
+when the model was created. Be aware that the order of tensors in `input` must
+match the order given to the `TensorFlow Lite Optimized Converter`.
+
The Java API also provides convenient functions for app developers to get the
index of any model input or output using a tensor name:
diff --git a/tensorflow/contrib/lite/g3doc/custom_operators.md b/tensorflow/contrib/lite/g3doc/custom_operators.md
index d979353bb3..ee6150b60e 100644
--- a/tensorflow/contrib/lite/g3doc/custom_operators.md
+++ b/tensorflow/contrib/lite/g3doc/custom_operators.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# How to use custom operators
diff --git a/tensorflow/contrib/lite/g3doc/demo_android.md b/tensorflow/contrib/lite/g3doc/demo_android.md
index d79a2696b4..c38b928684 100644
--- a/tensorflow/contrib/lite/g3doc/demo_android.md
+++ b/tensorflow/contrib/lite/g3doc/demo_android.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Android Demo App
diff --git a/tensorflow/contrib/lite/g3doc/demo_ios.md b/tensorflow/contrib/lite/g3doc/demo_ios.md
index a554898899..7579ad84a0 100644
--- a/tensorflow/contrib/lite/g3doc/demo_ios.md
+++ b/tensorflow/contrib/lite/g3doc/demo_ios.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# iOS Demo App
diff --git a/tensorflow/contrib/lite/g3doc/devguide.md b/tensorflow/contrib/lite/g3doc/devguide.md
index dc9cc98c08..90e7915c52 100644
--- a/tensorflow/contrib/lite/g3doc/devguide.md
+++ b/tensorflow/contrib/lite/g3doc/devguide.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Developer Guide
diff --git a/tensorflow/contrib/lite/g3doc/ios.md b/tensorflow/contrib/lite/g3doc/ios.md
index d78d373ccf..a83d2c8fec 100644
--- a/tensorflow/contrib/lite/g3doc/ios.md
+++ b/tensorflow/contrib/lite/g3doc/ios.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# TensorFlow Lite for iOS
@@ -38,7 +36,7 @@ brew link libtool
Then you need to run a shell script to download the dependencies you need:
```bash
-tensorflow/contrib/lite/download_dependencies.sh
+tensorflow/contrib/lite/tools/make/download_dependencies.sh
```
This will fetch copies of libraries and data from the web and install them in
@@ -48,14 +46,14 @@ With all of the dependencies set up, you can now build the library for all five
supported architectures on iOS:
```bash
-tensorflow/contrib/lite/build_ios_universal_lib.sh
+tensorflow/contrib/lite/tools/make/build_ios_universal_lib.sh
```
Under the hood this uses a makefile in `tensorflow/contrib/lite` to build the
different versions of the library, followed by a call to `lipo` to bundle them
into a universal file containing armv7, armv7s, arm64, i386, and x86_64
architectures. The resulting library is in
-`tensorflow/contrib/lite/gen/lib/libtensorflow-lite.a`.
+`tensorflow/contrib/lite/tools/make/gen/lib/libtensorflow-lite.a`.
If you get an error such as `no such file or directory: 'x86_64'` when running
`build_ios_universal_lib.sh`: open Xcode > Preferences > Locations, and ensure
diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md
index 4ceb9a53dc..a4267eee4c 100644
--- a/tensorflow/contrib/lite/g3doc/models.md
+++ b/tensorflow/contrib/lite/g3doc/models.md
@@ -1,66 +1,70 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# List of Hosted Models
## Image classification (Float Models)
-Model Name | Paper_Model_Files^ | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^^ | Tensorflow Performance
-------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | --------------------: | ---------------------:
-DenseNet | [paper](https://arxiv.org/abs/1608.06993), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/densenet_2018_04_27.tgz) | 43.6 Mb | 64.2% | 85.6% | 894 ms | 1262 ms
-SqueezeNet | [paper](https://arxiv.org/abs/1602.07360), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz) | 5.0 Mb | 49.0% | 72.9% | 224 ms | 255 ms
-NASNet mobile | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz) | 21.4 Mb | 72.2% | 90.6% | 261 ms | 389 ms
-NASNet large | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_large_2018_04_27.tgz) | 355.3 Mb | 82.1% | 95.8% | 6697 ms | 7940 ms
-ResNet_V2_50 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/resnet_v2_50_2018_04_27.tgz) | 102.3 Mb | 68.1% | 88.4% | 942 ms | 1008 ms
-ResNet_V2_101 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/resnet_v2_101_2018_04_27.tgz) | 178.3 Mb | 70.4% | 89.6% | 1880 ms | 1970 ms
-Inception_V3 | [paper](http://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz) | 95.3 Mb | 76.9% | 93.5% | 1433 ms | 1522 ms
-Inception_V4 | [paper](http://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz) | 170.7 Mb | 79.6% | 94.6% | 2986 ms | 3139 ms
-Inception_ResNet_V2 | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz) | 121.0 Mb | 76.8% | 93.5% | 2731 ms | 2926 ms
-Mobilenet_0.25_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz) | 1.9 Mb | 41.5% | 66.3% | 6.2 ms | 13.0 ms
-Mobilenet_0.25_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz) | 1.9 Mb | 45.5% | 70.3% | 8.6 ms | 19.5 ms
-Mobilenet_0.25_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz) | 1.9 Mb | 47.7% | 72.3% | 12.1 ms | 27.8 ms
-Mobilenet_0.25_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz) | 1.9 Mb | 49.8% | 74.2% | 16.2 ms | 37.3 ms
-Mobilenet_0.50_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz) | 5.3 Mb | 56.3% | 79.4% | 18.1 ms | 29.9 ms
-Mobilenet_0.50_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz) | 5.3 Mb | 59.1% | 81.9% | 26.8 ms | 45.9 ms
-Mobilenet_0.50_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz) | 5.3 Mb | 61.7% | 83.6% | 35.6 ms | 65.3 ms
-Mobilenet_0.50_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz) | 5.3 Mb | 63.3% | 84.9% | 47.6 ms | 164.2 ms
-Mobilenet_0.75_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz) | 10.3 Mb | 62.1% | 83.9% | 34.6 ms | 48.7 ms
-Mobilenet_0.75_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz) | 10.3 Mb | 65.3% | 86.0% | 51.3 ms | 75.2 ms
-Mobilenet_0.75_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz) | 10.3 Mb | 67.2% | 87.3% | 71.7 ms | 107.0 ms
-Mobilenet_0.75_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz) | 10.3 Mb | 68.4% | 88.2% | 95.7 ms | 143.4 ms
-Mobilenet_1.0_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz) | 16.9 Mb | 65.2% | 85.8% | 57.4 ms | 76.8 ms
-Mobilenet_1.0_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz) | 16.9 Mb | 68.0% | 87.7% | 86.0 ms | 117.7 ms
-Mobilenet_1.0_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz) | 16.9 Mb | 70.0% | 89.2% | 118.6 ms | 167.3 ms
-Mobilenet_1.0_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz) | 16.9 Mb | 70.9% | 89.9% | 160.1 ms | 224.3 ms
+Model Name | Paper_Model_Files^ | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^^ | Tensorflow Performance
+--------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | --------------------: | ---------------------:
+DenseNet | [paper](https://arxiv.org/abs/1608.06993), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/densenet_2018_04_27.tgz) | 43.6 Mb | 64.2% | 85.6% | 894 ms | 1262 ms
+SqueezeNet | [paper](https://arxiv.org/abs/1602.07360), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz) | 5.0 Mb | 49.0% | 72.9% | 224 ms | 255 ms
+NASNet mobile | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz) | 21.4 Mb | 73.9% | 91.5% | 261 ms | 389 ms
+NASNet large | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_large_2018_04_27.tgz) | 355.3 Mb | 82.6% | 96.1% | 6697 ms | 7940 ms
+ResNet_V2_101 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz) | 178.3 Mb | 76.8% | 93.6% | 1880 ms | 1970 ms
+Inception_V3 | [paper](http://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz) | 95.3 Mb | 77.9% | 93.8% | 1433 ms | 1522 ms
+Inception_V4 | [paper](http://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz) | 170.7 Mb | 80.1% | 95.1% | 2986 ms | 3139 ms
+Inception_ResNet_V2 | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz) | 121.0 Mb | 77.5% | 94.0% | 2731 ms | 2926 ms
+Mobilenet_V1_0.25_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz) | 1.9 Mb | 41.4% | 66.2% | 6.2 ms | 13.0 ms
+Mobilenet_V1_0.25_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz) | 1.9 Mb | 45.4% | 70.2% | 8.6 ms | 19.5 ms
+Mobilenet_V1_0.25_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz) | 1.9 Mb | 47.1% | 72.0% | 12.1 ms | 27.8 ms
+Mobilenet_V1_0.25_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz) | 1.9 Mb | 49.7% | 74.1% | 16.2 ms | 37.3 ms
+Mobilenet_V1_0.50_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz) | 5.3 Mb | 56.2% | 79.3% | 18.1 ms | 29.9 ms
+Mobilenet_V1_0.50_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz) | 5.3 Mb | 59.0% | 81.8% | 26.8 ms | 45.9 ms
+Mobilenet_V1_0.50_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz) | 5.3 Mb | 61.7% | 83.5% | 35.6 ms | 65.3 ms
+Mobilenet_V1_0.50_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz) | 5.3 Mb | 63.2% | 84.9% | 47.6 ms | 164.2 ms
+Mobilenet_V1_0.75_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz) | 10.3 Mb | 62.0% | 83.8% | 34.6 ms | 48.7 ms
+Mobilenet_V1_0.75_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz) | 10.3 Mb | 65.2% | 85.9% | 51.3 ms | 75.2 ms
+Mobilenet_V1_0.75_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz) | 10.3 Mb | 67.1% | 87.2% | 71.7 ms | 107.0 ms
+Mobilenet_V1_0.75_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz) | 10.3 Mb | 68.3% | 88.1% | 95.7 ms | 143.4 ms
+Mobilenet_V1_1.0_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz) | 16.9 Mb | 65.2% | 85.7% | 57.4 ms | 76.8 ms
+Mobilenet_V1_1.0_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz) | 16.9 Mb | 68.0% | 87.7% | 86.0 ms | 117.7 ms
+Mobilenet_V1_1.0_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz) | 16.9 Mb | 69.9% | 89.1% | 118.6 ms | 167.3 ms
+Mobilenet_V1_1.0_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz) | 16.9 Mb | 71.0% | 89.9% | 160.1 ms | 224.3 ms
+Mobilenet_V2_1.0_224 | [paper](https://arxiv.org/pdf/1801.04381.pdf), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz) | 14.0 Mb | 71.8% | 90.6% | 117 ms |
^ The model files include both the TF Lite FlatBuffer and the TensorFlow frozen graph.
^^ The performance numbers were generated by the benchmark on a Pixel 2 using
a single thread on a large core.
+^^ Accuracy numbers were computed using the
+[TFLite accuracy tool](../tools/accuracy/ilsvrc).
+
## Image classification (Quantized Models)
-Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance
------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------:
-Mobilenet_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.5% | 64.4% | 3.7 ms
-Mobilenet_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 43.4% | 68.5% | 5.5 ms
-Mobilenet_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 46.0% | 71.2% | 7.9 ms
-Mobilenet_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 48.0% | 72.8% | 10.4 ms
-Mobilenet_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 54.5% | 77.7% | 8.8 ms
-Mobilenet_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.7% | 80.4% | 13.0 ms
-Mobilenet_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 60.0% | 82.2% | 18.3 ms
-Mobilenet_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 60.7% | 83.2% | 24.7 ms
-Mobilenet_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 55.8% | 78.8% | 16.2 ms
-Mobilenet_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 62.3% | 83.8% | 24.3 ms
-Mobilenet_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 66.1% | 86.4% | 33.8 ms
-Mobilenet_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 66.8% | 87.0% | 45.4 ms
-Mobilenet_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 63.4% | 84.2% | 24.9 ms
-Mobilenet_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 67.2% | 86.7% | 37.4 ms
-Mobilenet_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.2% | 88.3% | 51.9 ms
-Mobilenet_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 70.1% | 88.9% | 70.2 ms
+Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance
+--------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------:
+Mobilenet_V1_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.5% | 64.4% | 3.7 ms
+Mobilenet_V1_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 42.8% | 68.1% | 5.5 ms
+Mobilenet_V1_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 45.7% | 70.8% | 7.9 ms
+Mobilenet_V1_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 48.2% | 72.8% | 10.4 ms
+Mobilenet_V1_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 54.9% | 78.1% | 8.8 ms
+Mobilenet_V1_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.2% | 80.5% | 13.0 ms
+Mobilenet_V1_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 59.9% | 82.1% | 18.3 ms
+Mobilenet_V1_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 61.2% | 83.2% | 24.7 ms
+Mobilenet_V1_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 55.9% | 79.1% | 16.2 ms
+Mobilenet_V1_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 62.4% | 83.7% | 24.3 ms
+Mobilenet_V1_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 66.1% | 86.2% | 33.8 ms
+Mobilenet_V1_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 66.9% | 86.9% | 45.4 ms
+Mobilenet_V1_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 63.3% | 84.1% | 24.9 ms
+Mobilenet_V1_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 66.9% | 86.7% | 37.4 ms
+Mobilenet_V1_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.1% | 88.1% | 51.9 ms
+Mobilenet_V1_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 70.0% | 89.0% | 70.2 ms
+Mobilenet_V2_1.0_224_quant  | [paper](https://arxiv.org/abs/1806.08342), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224_quant.tgz)                            | 3.4 Mb     | 70.8%          | 89.9%          | 80.3 ms
+Inception_V3_quant          | [paper](https://arxiv.org/abs/1806.08342), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/inception_v3_quant.tgz)                                     | 23 Mb      | 77.5%          | 93.7%          | 637 ms
## Other models
Model | TF Lite FlatBuffer
----------------------- | :----------------:
-Smart Reply 1.0 Android | [reference](https://research.googleblog.com/2017/11/on-device-conversational-modeling-with.html), [tflite](https://storage.googleapis.com/download.tensorflow.org/models/smartreply_1.0_2017_11_01.zip)
+Smart Reply 1.0 Android | [reference](https://research.googleblog.com/2017/11/on-device-conversational-modeling-with.html),
+[tflite](https://storage.googleapis.com/download.tensorflow.org/models/smartreply_1.0_2017_11_01.zip)
diff --git a/tensorflow/contrib/lite/g3doc/ops_versioning.md b/tensorflow/contrib/lite/g3doc/ops_versioning.md
index b06f4fd3b8..0d571ce547 100644
--- a/tensorflow/contrib/lite/g3doc/ops_versioning.md
+++ b/tensorflow/contrib/lite/g3doc/ops_versioning.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# TensorFlow Lite Ops Versioning
diff --git a/tensorflow/contrib/lite/g3doc/overview.md b/tensorflow/contrib/lite/g3doc/overview.md
index be60d7941a..8cf43496df 100644
--- a/tensorflow/contrib/lite/g3doc/overview.md
+++ b/tensorflow/contrib/lite/g3doc/overview.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Introduction to TensorFlow Lite
diff --git a/tensorflow/contrib/lite/g3doc/performance.md b/tensorflow/contrib/lite/g3doc/performance.md
index 5cd0aab44f..28cb6aba6e 100644
--- a/tensorflow/contrib/lite/g3doc/performance.md
+++ b/tensorflow/contrib/lite/g3doc/performance.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Performance
diff --git a/tensorflow/contrib/lite/g3doc/rpi.md b/tensorflow/contrib/lite/g3doc/rpi.md
index 9fcf79ba00..41a1892b6f 100644
--- a/tensorflow/contrib/lite/g3doc/rpi.md
+++ b/tensorflow/contrib/lite/g3doc/rpi.md
@@ -1,30 +1,36 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
-
# TensorFlow Lite for Raspberry Pi
## Cross compiling
-### Installing toolchian
-This has been tested on Ubuntu 16.04.3 64bit and Tensorflow devel docker image [tensorflow/tensorflow:nightly-devel](https://hub.docker.com/r/tensorflow/tensorflow/tags/).
-To cross compiling TensorFlow Lite. First you should install the toolchain and libs.
+### Installing the toolchain
+
+This has been tested on Ubuntu 16.04.3 64-bit and the TensorFlow devel Docker image
+[tensorflow/tensorflow:nightly-devel](https://hub.docker.com/r/tensorflow/tensorflow/tags/).
+
+To cross compile TensorFlow Lite, first install the toolchain and libs.
+
```bash
sudo apt-get update
sudo apt-get install crossbuild-essential-armhf
```
-> If you are using docker, you may not use `sudo`
+
+> If you are using Docker, you may not need `sudo`.
### Building
+
Clone the TensorFlow repository, then run this script at the root of the repository to download all the dependencies:
+
> The TensorFlow repository is already at `/tensorflow` if you are using the `tensorflow/tensorflow:nightly-devel` Docker image.
+
```bash
-./tensorflow/contrib/lite/download_dependencies.sh
+./tensorflow/contrib/lite/tools/make/download_dependencies.sh
```
Note that you only need to do this once.
You should then be able to compile:
+
```bash
-./tensorflow/contrib/lite/build_rpi_lib.sh
+./tensorflow/contrib/lite/tools/make/build_rpi_lib.sh
```
This should compile a static library in:
@@ -33,21 +39,23 @@ This should compile a static library in:
## Native compiling
This has been tested on Raspberry Pi 3b, Raspbian GNU/Linux 9.1 (stretch), gcc version 6.3.0 20170516 (Raspbian 6.3.0-18+rpi1).
-Log in to you RPI, install the toolchain.
+Log in to your Raspberry Pi and install the toolchain.
+
```bash
sudo apt-get install build-essential
```
-First, clone this TensorFlow repository. Run this at the root of the repository:
+First, clone the TensorFlow repository. Run this at the root of the repository:
+
```bash
-./tensorflow/contrib/lite/download_dependencies.sh
+./tensorflow/contrib/lite/tools/make/download_dependencies.sh
```
Note that you only need to do this once.
You should then be able to compile:
```bash
-./tensorflow/contrib/lite/build_rpi_lib.sh
+./tensorflow/contrib/lite/tools/make/build_rpi_lib.sh
```
This should compile a static library in:
-`tensorflow/contrib/lite/gen/lib/rpi_armv7/libtensorflow-lite.a`.
+`tensorflow/contrib/lite/tools/make/gen/lib/rpi_armv7/libtensorflow-lite.a`.
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
index aa65ec9988..8660d29855 100644
--- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# TensorFlow Lite & TensorFlow Compatibility Guide
@@ -843,6 +841,31 @@ Outputs {
}
```
+**UNPACK**
+
+```
+Inputs {
+ 0: a tensor.
+  1: an integer (the number of tensors to unpack).
+  2: an integer (the axis to unpack along).
+}
+Outputs {
+  0-N: the tensors unpacked from the input tensor.
+}
+```
+
+**FLOOR_DIV**
+
+```
+Inputs {
+  0: a tensor.
+  1: a tensor.
+}
+Outputs {
+  0: a tensor computed as the element-wise floor division of the inputs.
+}
+```
+
And these are TensorFlow Lite operations that are present but not ready for
custom models yet:
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
index 76e16fc9db..c7cdee07de 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Building TensorFlow on Android
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/index.md b/tensorflow/contrib/lite/g3doc/tfmobile/index.md
index bd047bfcec..d003bb2f38 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/index.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/index.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Overview
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md
index 6223707892..be8b4100c8 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Building TensorFlow on iOS
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md
index 4c2071ed05..4d4bb3bc08 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Integrating TensorFlow libraries
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md
index a0192c3541..7436594fd8 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Optimizing for mobile
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md
index 6b4e4a92bd..d1c67d4c61 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Preparing models for mobile deployment
diff --git a/tensorflow/contrib/lite/graph_info.h b/tensorflow/contrib/lite/graph_info.h
index 77268d7aeb..8ee83827bb 100644
--- a/tensorflow/contrib/lite/graph_info.h
+++ b/tensorflow/contrib/lite/graph_info.h
@@ -17,7 +17,7 @@ limitations under the License.
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 362e588725..3f8f4d198f 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -21,9 +21,9 @@ limitations under the License.
#include <cstring>
#include "tensorflow/contrib/lite/arena_planner.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/context_util.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/graph_info.h"
#include "tensorflow/contrib/lite/memory_planner.h"
#include "tensorflow/contrib/lite/nnapi_delegate.h"
@@ -476,6 +476,10 @@ TfLiteStatus Interpreter::ResetVariableTensorsToZero() {
return kTfLiteOk;
}
+void Interpreter::ReserveNodes(int count) {
+ nodes_and_registration_.reserve(count);
+}
+
TfLiteStatus Interpreter::AddNodeWithParameters(
const std::vector<int>& inputs, const std::vector<int>& outputs,
const char* init_data, size_t init_data_size, void* builtin_data,
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index 135732917e..f0cd178c19 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -23,10 +23,11 @@ limitations under the License.
#include <vector>
#include "tensorflow/contrib/lite/allocation.h"
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/memory_planner.h"
#include "tensorflow/contrib/lite/profiling/profiler.h"
+#include "tensorflow/contrib/lite/stderr_reporter.h"
namespace tflite {
@@ -136,6 +137,11 @@ class Interpreter {
// interpreter.
TfLiteStatus SetVariables(std::vector<int> variables);
+  // Ensures the internal node storage has capacity for at least `count`
+  // nodes. NOTE: this doesn't actually add operators. This is an
+  // efficiency optimization that is subject to change.
+ void ReserveNodes(int count);
+
// Adds a node with the given parameters and returns the index of the new
// node in `node_index` (optionally). Interpreter will take ownership of
// `builtin_data` and destroy it with `free`. Ownership of 'init_data'
@@ -413,6 +419,10 @@ class Interpreter {
return op_reg.profiling_string(&context_, node);
}
+ // Set the value of an external context.
+ void SetExternalContext(TfLiteExternalContextType type,
+ TfLiteExternalContext* ctx);
+
private:
friend class InterpreterBuilder;
friend class InterpreterTest;
@@ -544,8 +554,6 @@ class Interpreter {
struct TfLiteContext* context, TfLiteExternalContextType type);
// Set the value of an external context.
- void SetExternalContext(TfLiteExternalContextType type,
- TfLiteExternalContext* ctx);
static void SetExternalContext(struct TfLiteContext* context,
TfLiteExternalContextType type,
TfLiteExternalContext* ctx);
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index 5bcf0927d8..cdede430e2 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/interpreter.h"
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md
index e3cea19e16..6a3f0651d0 100644
--- a/tensorflow/contrib/lite/java/demo/README.md
+++ b/tensorflow/contrib/lite/java/demo/README.md
@@ -20,9 +20,6 @@ code to merge.
- Make sure to install the latest version of Bazel. Some distributions
ship with Bazel 0.5.4, which is too old.
- Bazel requires Android Build Tools `26.0.1` or higher.
- - **Bazel is incompatible with NDK revisions 15 and above,** with revision
- 16 being a compile-breaking change. [Download an older version manually
- instead of using the SDK Manager.](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-bazel-and-android-prerequisites)
- You also need to install the Android Support Repository, available
through Android Studio under `Android SDK Manager -> SDK Tools ->
Android Support Repository`.
@@ -37,8 +34,7 @@ code to merge.
- Make sure the `api_level` in `WORKSPACE` is set to an SDK version that
you have installed.
- By default, Android Studio will install the SDK to `~/Android/Sdk` and
- the NDK to `~/Android/Sdk/ndk-bundle` (but the NDK should be a manual
- download until Bazel supports NDK 16. See bullet points under (1)).
+ the NDK to `~/Android/Sdk/ndk-bundle`.
2. Build the app with Bazel. The demo needs C++11:
diff --git a/tensorflow/contrib/lite/java/demo/app/build.gradle b/tensorflow/contrib/lite/java/demo/app/build.gradle
index 92f04c651c..05301ebf88 100644
--- a/tensorflow/contrib/lite/java/demo/app/build.gradle
+++ b/tensorflow/contrib/lite/java/demo/app/build.gradle
@@ -10,7 +10,6 @@ android {
targetSdkVersion 26
versionCode 1
versionName "1.0"
- testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
// Remove this block.
jackOptions {
@@ -44,9 +43,6 @@ repositories {
dependencies {
compile fileTree(dir: 'libs', include: ['*.jar'])
- androidTestCompile('androidx.test.espresso:espresso-core:3.1.0-alpha3', {
- exclude group: 'com.android.support', module: 'support-annotations'
- })
compile 'com.android.support:appcompat-v7:25.2.0'
compile 'com.android.support.constraint:constraint-layout:1.0.2'
compile 'com.android.support:design:25.2.0'
@@ -54,8 +50,6 @@ dependencies {
compile 'com.android.support:support-v13:25.2.0'
compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
-
- testCompile 'junit:junit:4.12'
}
def modelDownloadUrl = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip"
diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD
index 06f46fb923..781289ceb2 100644
--- a/tensorflow/contrib/lite/java/ovic/BUILD
+++ b/tensorflow/contrib/lite/java/ovic/BUILD
@@ -35,6 +35,7 @@ java_binary(
"//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt",
],
main_class = "org.tensorflow.ovic.OvicValidator",
+ tags = ["no_oss"],
deps = [
"//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib_java",
],
@@ -47,6 +48,7 @@ android_library(
"src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java",
],
manifest = "//tensorflow/contrib/lite/java:AndroidManifest.xml",
+ tags = ["no_oss"],
deps = [
"//tensorflow/contrib/lite/java:tensorflowlite",
"//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper",
@@ -61,6 +63,7 @@ java_library(
"src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java",
],
javacopts = JAVACOPTS,
+ tags = ["no_oss"],
deps = [
"//tensorflow/contrib/lite/java:libtensorflowlite_jni.so",
"//tensorflow/contrib/lite/java:tensorflowlite_java",
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle
index 2a08608bbb..4f3a6cdb2f 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle
@@ -9,7 +9,6 @@ android {
targetSdkVersion 26
versionCode 1
versionName "1.0"
- testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
// Remove this block.
jackOptions {
@@ -43,9 +42,6 @@ repositories {
dependencies {
compile fileTree(dir: 'libs', include: ['*.jar'])
- androidTestCompile('androidx.test.espresso:espresso-core:3.1.0-alpha3', {
- exclude group: 'com.android.support', module: 'support-annotations'
- })
compile 'com.android.support:appcompat-v7:25.2.0'
compile 'com.android.support.constraint:constraint-layout:1.0.2'
compile 'com.android.support:design:25.2.0'
@@ -53,6 +49,4 @@ dependencies {
compile 'com.android.support:support-v13:25.2.0'
compile 'org.tensorflow:tensorflow-lite:+'
-
- testCompile 'junit:junit:4.12'
}
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
index 94a1ec65d6..41093e8ffe 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
@@ -15,8 +15,8 @@ limitations under the License.
package org.tensorflow.lite;
-/** Type of elements in a {@link TfLiteTensor}. */
-enum DataType {
+/** Represents the type of elements in a TensorFlow Lite {@link Tensor} as an enum. */
+public enum DataType {
/** 32-bit single precision floating point. */
FLOAT32(1),
@@ -35,13 +35,29 @@ enum DataType {
this.value = value;
}
- /** Corresponding value of the kTfLite* enum in the TensorFlow Lite CC API. */
- int getNumber() {
+ /** Returns the size of an element of this type, in bytes, or -1 if element size is variable. */
+ public int byteSize() {
+ switch (this) {
+ case FLOAT32:
+ return 4;
+ case INT32:
+ return 4;
+ case UINT8:
+ return 1;
+ case INT64:
+ return 8;
+ }
+ throw new IllegalArgumentException(
+ "DataType error: DataType " + this + " is not supported yet");
+ }
+
+ /** Corresponding value of the TfLiteType enum in the TensorFlow Lite C API. */
+ int c() {
return value;
}
- /** Converts an integer to the corresponding type. */
- static DataType fromNumber(int c) {
+ /** Converts a C TfLiteType enum value to the corresponding type. */
+ static DataType fromC(int c) {
for (DataType t : values) {
if (t.value == c) {
return t;
@@ -55,22 +71,6 @@ enum DataType {
+ ")");
}
- /** Returns byte size of the type. */
- int elemByteSize() {
- switch (this) {
- case FLOAT32:
- return 4;
- case INT32:
- return 4;
- case UINT8:
- return 1;
- case INT64:
- return 8;
- }
- throw new IllegalArgumentException(
- "DataType error: DataType " + this + " is not supported yet");
- }
-
/** Gets string names of the data type. */
String toStringName() {
switch (this) {
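
Since `DataType` and its `byteSize()` accessor are now public, callers can size direct buffers by hand instead of hard-coding element widths. A minimal sketch of that pattern, assuming a hypothetical FLOAT32 input of shape [1, 224, 224, 3] (the shape and class name here are illustrative, not part of this change):

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import org.tensorflow.lite.DataType;

public class BufferSizing {
  public static void main(String[] args) {
    // Hypothetical input geometry; a real model would report this via Tensor.shape().
    int numElements = 1 * 224 * 224 * 3;
    // byteSize() returns 4 for FLOAT32, so capacity = element count * element width.
    int capacity = numElements * DataType.FLOAT32.byteSize();
    ByteBuffer input = ByteBuffer.allocateDirect(capacity).order(ByteOrder.nativeOrder());
    System.out.println("Allocated " + input.capacity() + " bytes");
  }
}
```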
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index 7002f82677..b84720ae8e 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -162,9 +162,7 @@ public final class Interpreter implements AutoCloseable {
*/
public void runForMultipleInputsOutputs(
@NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) {
- if (wrapper == null) {
- throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
- }
+ checkNotClosed();
wrapper.run(inputs, outputs);
}
@@ -174,12 +172,16 @@ public final class Interpreter implements AutoCloseable {
* <p>IllegalArgumentException will be thrown if it fails to resize.
*/
public void resizeInput(int idx, @NonNull int[] dims) {
- if (wrapper == null) {
- throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
- }
+ checkNotClosed();
wrapper.resizeInput(idx, dims);
}
+ /** Gets the number of input tensors. */
+ public int getInputTensorCount() {
+ checkNotClosed();
+ return wrapper.getInputTensorCount();
+ }
+
/**
* Gets index of an input given the op name of the input.
*
@@ -187,51 +189,65 @@ public final class Interpreter implements AutoCloseable {
* to initialize the {@link Interpreter}.
*/
public int getInputIndex(String opName) {
- if (wrapper == null) {
- throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
- }
+ checkNotClosed();
return wrapper.getInputIndex(opName);
}
/**
+   * Gets the Tensor associated with the provided input index.
+ *
+ * <p>IllegalArgumentException will be thrown if the provided index is invalid.
+ */
+ public Tensor getInputTensor(int inputIndex) {
+ checkNotClosed();
+ return wrapper.getInputTensor(inputIndex);
+ }
+
+ /** Gets the number of output Tensors. */
+ public int getOutputTensorCount() {
+ checkNotClosed();
+ return wrapper.getOutputTensorCount();
+ }
+
+ /**
* Gets index of an output given the op name of the output.
*
* <p>IllegalArgumentException will be thrown if the op name does not exist in the model file used
* to initialize the {@link Interpreter}.
*/
public int getOutputIndex(String opName) {
- if (wrapper == null) {
- throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
- }
+ checkNotClosed();
return wrapper.getOutputIndex(opName);
}
/**
+   * Gets the Tensor associated with the provided output index.
+ *
+ * <p>IllegalArgumentException will be thrown if the provided index is invalid.
+ */
+ public Tensor getOutputTensor(int outputIndex) {
+ checkNotClosed();
+ return wrapper.getOutputTensor(outputIndex);
+ }
+
+ /**
* Returns native inference timing.
* <p>IllegalArgumentException will be thrown if the model is not initialized by the
* {@link Interpreter}.
*/
public Long getLastNativeInferenceDurationNanoseconds() {
- if (wrapper == null) {
- throw new IllegalStateException("Internal error: The interpreter has already been closed.");
- }
+ checkNotClosed();
return wrapper.getLastNativeInferenceDurationNanoseconds();
}
/** Turns on/off Android NNAPI for hardware acceleration when it is available. */
public void setUseNNAPI(boolean useNNAPI) {
- if (wrapper != null) {
- wrapper.setUseNNAPI(useNNAPI);
- } else {
- throw new IllegalStateException(
- "Internal error: NativeInterpreterWrapper has already been closed.");
- }
+ checkNotClosed();
+ wrapper.setUseNNAPI(useNNAPI);
}
public void setNumThreads(int numThreads) {
- if (wrapper == null) {
- throw new IllegalStateException("The interpreter has already been closed.");
- }
+ checkNotClosed();
wrapper.setNumThreads(numThreads);
}
@@ -253,5 +269,11 @@ public final class Interpreter implements AutoCloseable {
}
}
+ private void checkNotClosed() {
+ if (wrapper == null) {
+ throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
+ }
+ }
+
NativeInterpreterWrapper wrapper;
}
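
The new `getInputTensorCount()`/`getInputTensor(int)` methods and their output counterparts make it possible to inspect a model's I/O signature before running inference. A short usage sketch, assuming a placeholder model file `model.tflite` on disk (not a file shipped with this change):

```java
import java.io.File;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.Tensor;

public class InspectModel {
  public static void main(String[] args) {
    // Interpreter implements AutoCloseable, so try-with-resources releases it.
    try (Interpreter interpreter = new Interpreter(new File("model.tflite"))) {
      for (int i = 0; i < interpreter.getInputTensorCount(); ++i) {
        Tensor t = interpreter.getInputTensor(i);
        System.out.printf("input %d: type=%s shape=%s%n",
            i, t.dataType(), java.util.Arrays.toString(t.shape()));
      }
      for (int i = 0; i < interpreter.getOutputTensorCount(); ++i) {
        Tensor t = interpreter.getOutputTensor(i);
        System.out.printf("output %d: type=%s shape=%s%n",
            i, t.dataType(), java.util.Arrays.toString(t.shape()));
      }
    }
  }
}
```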
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index 767a220f8c..fa25082304 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -114,12 +114,10 @@ final class NativeInterpreterWrapper implements AutoCloseable {
}
}
- if (!isMemoryAllocated) {
+ boolean needsAllocation = !isMemoryAllocated;
+ if (needsAllocation) {
allocateTensors(interpreterHandle, errorHandle);
isMemoryAllocated = true;
- // Allocation can trigger dynamic resizing of output tensors, so clear the
- // output tensor cache.
- Arrays.fill(outputTensors, null);
}
for (int i = 0; i < inputs.length; ++i) {
@@ -130,6 +128,14 @@ final class NativeInterpreterWrapper implements AutoCloseable {
run(interpreterHandle, errorHandle);
long inferenceDurationNanoseconds = System.nanoTime() - inferenceStartNanos;
+ // Allocation can trigger dynamic resizing of output tensors, so refresh all output shapes.
+ if (needsAllocation) {
+ for (int i = 0; i < outputTensors.length; ++i) {
+ if (outputTensors[i] != null) {
+ outputTensors[i].refreshShape();
+ }
+ }
+ }
for (Map.Entry<Integer, Object> output : outputs.entrySet()) {
getOutputTensor(output.getKey()).copyTo(output.getValue());
}
@@ -144,8 +150,9 @@ final class NativeInterpreterWrapper implements AutoCloseable {
void resizeInput(int idx, int[] dims) {
if (resizeInput(interpreterHandle, errorHandle, idx, dims)) {
isMemoryAllocated = false;
- // Resizing will invalidate the Tensor's shape, so invalidate the Tensor handle.
- inputTensors[idx] = null;
+ if (inputTensors[idx] != null) {
+ inputTensors[idx].refreshShape();
+ }
}
}
@@ -230,6 +237,11 @@ final class NativeInterpreterWrapper implements AutoCloseable {
return getOutputQuantizationScale(interpreterHandle, index);
}
+ /** Gets the number of input tensors. */
+ int getInputTensorCount() {
+ return inputTensors.length;
+ }
+
/**
* Gets the input {@link Tensor} for the provided input index.
*
@@ -247,6 +259,11 @@ final class NativeInterpreterWrapper implements AutoCloseable {
return inputTensor;
}
+ /** Gets the number of output tensors. */
+ int getOutputTensorCount() {
+    return outputTensors.length;
+ }
+
/**
* Gets the output {@link Tensor} for the provided output index.
*
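
Because the wrapper now refreshes cached `Tensor` objects in place instead of nulling them out, a handle obtained before `resizeInput()` (or before the first `run()` triggers allocation) remains valid and reports the updated shape. A hedged sketch of that caller-visible behavior, again with a placeholder model path:

```java
import java.io.File;
import java.util.Arrays;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.Tensor;

public class ShapeRefresh {
  public static void main(String[] args) {
    try (Interpreter interpreter = new Interpreter(new File("model.tflite"))) {
      Tensor input = interpreter.getInputTensor(0); // handle cached by the wrapper
      interpreter.resizeInput(0, new int[] {1});
      // The same object now reports the new shape; previously the cached Tensor
      // was invalidated and had to be fetched again.
      System.out.println(Arrays.toString(input.shape())); // expected: [1]
    }
  }
}
```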
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
index 2403570c52..f174178d98 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
@@ -26,7 +26,7 @@ import java.util.Arrays;
* <p>The native handle of a {@code Tensor} belongs to {@code NativeInterpreterWrapper}, thus not
* needed to be closed here.
*/
-final class Tensor {
+public final class Tensor {
static Tensor fromHandle(long nativeHandle) {
return new Tensor(nativeHandle);
@@ -37,11 +37,26 @@ final class Tensor {
return dtype;
}
+ /**
+ * Returns the number of dimensions (sometimes referred to as <a
+ * href="https://www.tensorflow.org/resources/dims_types.html#rank">rank</a>) of the Tensor.
+ *
+   * <p>Will be 0 for a scalar, 1 for a vector, 2 for a matrix, 3 for a 3-dimensional tensor, etc.
+ */
+ public int numDimensions() {
+ return shapeCopy.length;
+ }
+
/** Returns the size, in bytes, of the tensor data. */
public int numBytes() {
return numBytes(nativeHandle);
}
+ /** Returns the number of elements in a flattened (1-D) view of the tensor. */
+ public int numElements() {
+ return computeNumElements(shapeCopy);
+ }
+
/**
* Returns the <a href="https://www.tensorflow.org/resources/dims_types.html#shape">shape</a> of
* the Tensor, i.e., the sizes of each dimension.
@@ -103,13 +118,22 @@ final class Tensor {
if (isByteBuffer(input)) {
return null;
}
- int[] inputShape = shapeOf(input);
+ int[] inputShape = computeShapeOf(input);
if (Arrays.equals(shapeCopy, inputShape)) {
return null;
}
return inputShape;
}
+ /**
+ * Forces a refresh of the tensor's cached shape.
+ *
+ * <p>This is useful if the tensor is resized or has a dynamic shape.
+ */
+ void refreshShape() {
+ this.shapeCopy = shape(nativeHandle);
+ }
+
/** Returns the type of the data. */
static DataType dataTypeOf(Object o) {
if (o != null) {
@@ -132,22 +156,31 @@ final class Tensor {
}
/** Returns the shape of an object as an int array. */
- static int[] shapeOf(Object o) {
- int size = numDimensions(o);
+ static int[] computeShapeOf(Object o) {
+ int size = computeNumDimensions(o);
int[] dimensions = new int[size];
fillShape(o, 0, dimensions);
return dimensions;
}
+ /** Returns the number of elements in a flattened (1-D) view of the tensor's shape. */
+ static int computeNumElements(int[] shape) {
+ int n = 1;
+ for (int i = 0; i < shape.length; ++i) {
+ n *= shape[i];
+ }
+ return n;
+ }
+
/** Returns the number of dimensions of a multi-dimensional array, otherwise 0. */
- static int numDimensions(Object o) {
+ static int computeNumDimensions(Object o) {
if (o == null || !o.getClass().isArray()) {
return 0;
}
if (Array.getLength(o) == 0) {
throw new IllegalArgumentException("Array lengths cannot be 0.");
}
- return 1 + numDimensions(Array.get(o, 0));
+ return 1 + computeNumDimensions(Array.get(o, 0));
}
/** Recursively populates the shape dimensions for a given (multi-dimensional) array. */
@@ -188,7 +221,7 @@ final class Tensor {
dtype, o.getClass().getName(), oType));
}
- int[] oShape = shapeOf(o);
+ int[] oShape = computeShapeOf(o);
if (!Arrays.equals(oShape, shapeCopy)) {
throw new IllegalArgumentException(
String.format(
@@ -204,11 +237,11 @@ final class Tensor {
private final long nativeHandle;
private final DataType dtype;
- private final int[] shapeCopy;
+ private int[] shapeCopy;
private Tensor(long nativeHandle) {
this.nativeHandle = nativeHandle;
- this.dtype = DataType.fromNumber(dtype(nativeHandle));
+ this.dtype = DataType.fromC(dtype(nativeHandle));
this.shapeCopy = shape(nativeHandle);
}
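
The new `numDimensions()` and `numElements()` accessors factor a tensor's geometry out of `numBytes()`. A small helper sketch showing the invariant they expose for fixed-width element types (the `describe` helper itself is illustrative, not part of the API):

```java
import org.tensorflow.lite.Tensor;

public class TensorStats {
  /** Prints a tensor's geometry; `tensor` is assumed to come from a live Interpreter. */
  static void describe(Tensor tensor) {
    int rank = tensor.numDimensions();   // equals tensor.shape().length
    int elements = tensor.numElements(); // product of all dimension sizes
    // For fixed-width element types the byte count factors cleanly:
    int expectedBytes = elements * tensor.dataType().byteSize();
    System.out.printf("rank=%d, elements=%d, bytes=%d (expected %d)%n",
        rank, elements, tensor.numBytes(), expectedBytes);
  }
}
```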
diff --git a/tensorflow/contrib/lite/java/src/main/native/exception_jni.h b/tensorflow/contrib/lite/java/src/main/native/exception_jni.h
index 3ffff052df..2a4bbdbead 100644
--- a/tensorflow/contrib/lite/java/src/main/native/exception_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/exception_jni.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_EXCEPTION_JNI_H_
-#define TENSORFLOW_CONTRIB_LITE_JAVA_EXCEPTION_JNI_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_EXCEPTION_JNI_H_
+#define TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_EXCEPTION_JNI_H_
#include <jni.h>
#include "tensorflow/contrib/lite/error_reporter.h"
@@ -47,4 +47,4 @@ class BufferErrorReporter : public tflite::ErrorReporter {
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
-#endif // TENSORFLOW_CONTRIB_LITE_JAVA_EXCEPTION_JNI_H_
+#endif // TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_EXCEPTION_JNI_H_
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
index 618fba480e..06b35d77c8 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_NATIVEINTERPRETERWRAPPER_JNI_H_
-#define TENSORFLOW_CONTRIB_LITE_JAVA_NATIVEINTERPRETERWRAPPER_JNI_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_NATIVEINTERPRETERWRAPPER_JNI_H_
+#define TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_NATIVEINTERPRETERWRAPPER_JNI_H_
#include <jni.h>
#include <stdio.h>
#include <time.h>
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h"
#include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h"
@@ -124,9 +124,9 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
*/
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
- jclass clazz,
- jlong handle,
- jint num_threads);
+ jclass clazz,
+ jlong handle,
+ jint num_threads);
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
* Method:
@@ -230,4 +230,4 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_delete(
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
-#endif // TENSORFLOW_CONTRIB_LITE_JAVA_NATIVEINTERPRETERWRAPPER_JNI_H_
+#endif // TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_NATIVEINTERPRETERWRAPPER_JNI_H_
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
index 06e2546af8..2f73128bdf 100644
--- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
@@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_TENSOR_JNI_H_
-#define TENSORFLOW_CONTRIB_LITE_JAVA_TENSOR_JNI_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_TENSOR_JNI_H_
+#define TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_TENSOR_JNI_H_
#include <jni.h>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#ifdef __cplusplus
extern "C" {
@@ -92,4 +92,4 @@ Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env,
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
-#endif // TENSORFLOW_CONTRIB_LITE_JAVA_TENSOR_JNI_H_
+#endif // TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_TENSOR_JNI_H_
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h
index 65f8341149..5e2a7ded1b 100644
--- a/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_TENSORFLOW_LITE_JNI_H_
-#define TENSORFLOW_CONTRIB_LITE_JAVA_TENSORFLOW_LITE_JNI_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_TENSORFLOW_LITE_JNI_H_
+#define TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_TENSORFLOW_LITE_JNI_H_
#include <jni.h>
@@ -33,4 +33,4 @@ Java_org_tensorflow_lite_TensorFlowLite_version(JNIEnv*, jclass);
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
-#endif // TENSORFLOW_CONTRIB_LITE_JAVA_TENSORFLOW_LITE_JNI_H_
+#endif // TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_TENSORFLOW_LITE_JNI_H_
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java
index cebc944200..6d6417f895 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java
@@ -26,9 +26,16 @@ public final class DataTypeTest {
@Test
public void testElemByteSize() {
- assertThat(DataType.FLOAT32.elemByteSize()).isEqualTo(4);
- assertThat(DataType.INT32.elemByteSize()).isEqualTo(4);
- assertThat(DataType.UINT8.elemByteSize()).isEqualTo(1);
- assertThat(DataType.INT64.elemByteSize()).isEqualTo(8);
+ assertThat(DataType.FLOAT32.byteSize()).isEqualTo(4);
+ assertThat(DataType.INT32.byteSize()).isEqualTo(4);
+ assertThat(DataType.UINT8.byteSize()).isEqualTo(1);
+ assertThat(DataType.INT64.byteSize()).isEqualTo(8);
+ }
+
+ @Test
+ public void testConversion() {
+ for (DataType dataType : DataType.values()) {
+ assertThat(DataType.fromC(dataType.c())).isEqualTo(dataType);
+ }
}
}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index d66a73db94..9070b788b6 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -47,6 +47,10 @@ public final class InterpreterTest {
public void testInterpreter() throws Exception {
Interpreter interpreter = new Interpreter(MODEL_FILE);
assertThat(interpreter).isNotNull();
+ assertThat(interpreter.getInputTensorCount()).isEqualTo(1);
+ assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
+ assertThat(interpreter.getOutputTensorCount()).isEqualTo(1);
+ assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
interpreter.close();
}
@@ -183,6 +187,19 @@ public final class InterpreterTest {
}
@Test
+ public void testResizeInput() {
+ try (Interpreter interpreter = new Interpreter(MODEL_FILE)) {
+ int[] inputDims = {1};
+ interpreter.resizeInput(0, inputDims);
+ assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(inputDims);
+ ByteBuffer input = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder());
+ ByteBuffer output = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder());
+ interpreter.run(input, output);
+ assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims);
+ }
+ }
+
+ @Test
public void testMobilenetRun() {
// Create a gray image.
float[][][][] img = new float[1][224][224][3];
@@ -199,6 +216,8 @@ public final class InterpreterTest {
Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE);
interpreter.run(img, labels);
+ assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(new int[] {1, 224, 224, 3});
+ assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(new int[] {1, 1001});
interpreter.close();
assertThat(labels[0])
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
index 71ef044943..85ad393d89 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
@@ -64,6 +64,8 @@ public final class TensorTest {
assertThat(tensor.shape()).isEqualTo(expectedShape);
assertThat(tensor.dataType()).isEqualTo(DataType.FLOAT32);
assertThat(tensor.numBytes()).isEqualTo(2 * 8 * 8 * 3 * 4);
+ assertThat(tensor.numElements()).isEqualTo(2 * 8 * 8 * 3);
+ assertThat(tensor.numDimensions()).isEqualTo(4);
}
@Test
@@ -201,12 +203,12 @@ public final class TensorTest {
@Test
public void testNumDimensions() {
int scalar = 1;
- assertThat(Tensor.numDimensions(scalar)).isEqualTo(0);
+ assertThat(Tensor.computeNumDimensions(scalar)).isEqualTo(0);
int[][] array = {{2, 4}, {1, 9}};
- assertThat(Tensor.numDimensions(array)).isEqualTo(2);
+ assertThat(Tensor.computeNumDimensions(array)).isEqualTo(2);
try {
int[] emptyArray = {};
- Tensor.numDimensions(emptyArray);
+ Tensor.computeNumDimensions(emptyArray);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("Array lengths cannot be 0.");
@@ -214,9 +216,21 @@ public final class TensorTest {
}
@Test
+ public void testNumElements() {
+ int[] scalarShape = {};
+ assertThat(Tensor.computeNumElements(scalarShape)).isEqualTo(1);
+ int[] vectorShape = {3};
+ assertThat(Tensor.computeNumElements(vectorShape)).isEqualTo(3);
+ int[] matrixShape = {3, 4};
+ assertThat(Tensor.computeNumElements(matrixShape)).isEqualTo(12);
+ int[] degenerateShape = {3, 4, 0};
+ assertThat(Tensor.computeNumElements(degenerateShape)).isEqualTo(0);
+ }
+
+ @Test
public void testFillShape() {
int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}};
- int num = Tensor.numDimensions(array);
+ int num = Tensor.computeNumDimensions(array);
int[] shape = new int[num];
Tensor.fillShape(array, 0, shape);
assertThat(num).isEqualTo(3);
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index c5586475ec..40f28aeab4 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -6,7 +6,7 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
-load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_opts_nortti_if_android")
# Suppress warnings that are introduced by Eigen Tensor.
EXTRA_EIGEN_COPTS = select({
@@ -66,7 +66,7 @@ cc_library(
deps = [
":op_macros",
"//tensorflow/contrib/lite:arena_planner",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels/internal:optimized",
],
)
@@ -82,7 +82,7 @@ cc_library(
copts = tflite_copts(),
deps = [
":op_macros",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"@gemmlowp",
],
)
@@ -93,7 +93,7 @@ cc_library(
"activation_functor.h",
],
deps = [
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -113,9 +113,9 @@ cc_library(
"kernel_util.h",
],
deps = [
- "//tensorflow/contrib/lite:builtin_op_data",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels/internal:round",
+ "//tensorflow/contrib/lite/kernels/internal:types",
],
)
@@ -147,7 +147,16 @@ tf_cc_test(
)
cc_library(
- name = "builtin_ops",
+ name = "padding",
+ srcs = [],
+ hdrs = ["padding.h"],
+ deps = [
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ ],
+)
+
+cc_library(
+ name = "builtin_op_kernels",
srcs = [
"activations.cc",
"add.cc",
@@ -172,10 +181,12 @@ cc_library(
"expand_dims.cc",
"fake_quant.cc",
"floor.cc",
+ "floor_div.cc",
"fully_connected.cc",
"gather.cc",
"hashtable_lookup.cc",
"l2norm.cc",
+ "layer_norm_lstm.cc",
"local_response_norm.cc",
"logical.cc",
"lsh_projection.cc",
@@ -190,7 +201,7 @@ cc_library(
"pooling.cc",
"pow.cc",
"reduce.cc",
- "register.cc",
+ "relu1.cc",
"reshape.cc",
"resize_bilinear.cc",
"select.cc",
@@ -211,34 +222,48 @@ cc_library(
"transpose_conv.cc",
"unidirectional_sequence_lstm.cc",
"unidirectional_sequence_rnn.cc",
+ "unpack.cc",
],
hdrs = [
- "padding.h",
- "register.h",
],
- copts = tflite_copts() + EXTRA_EIGEN_COPTS,
+ copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
+ visibility = ["//visibility:private"],
deps = [
":activation_functor",
":eigen_support",
":kernel_util",
":op_macros",
- "//tensorflow/contrib/lite:builtin_op_data",
+ ":padding",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite:util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:gemm_support",
"//tensorflow/contrib/lite/kernels/internal:audio_utils",
"//tensorflow/contrib/lite/kernels/internal:kernel_utils",
"//tensorflow/contrib/lite/kernels/internal:optimized",
"//tensorflow/contrib/lite/kernels/internal:optimized_base",
"//tensorflow/contrib/lite/kernels/internal:quantization_util",
- "//tensorflow/contrib/lite/kernels/internal:reference",
"//tensorflow/contrib/lite/kernels/internal:reference_base",
+ "//tensorflow/contrib/lite/kernels/internal:tensor",
"//tensorflow/contrib/lite/kernels/internal:tensor_utils",
"@farmhash_archive//:farmhash",
"@flatbuffers",
],
)
+cc_library(
+ name = "builtin_ops",
+ srcs = ["register.cc"],
+ hdrs = ["register.h"],
+ deps = [
+ ":builtin_op_kernels",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ ],
+)
+
tf_cc_test(
name = "audio_spectrogram_test",
size = "small",
@@ -291,6 +316,23 @@ tf_cc_test(
)
tf_cc_test(
+ name = "relu1_test",
+ size = "small",
+ srcs = ["relu1_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers",
+ ],
+)
+
+tf_cc_test(
name = "activations_test",
size = "small",
srcs = ["activations_test.cc"],
@@ -725,8 +767,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -742,8 +784,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -901,6 +943,20 @@ tf_cc_test(
)
tf_cc_test(
+ name = "layer_norm_lstm_test",
+ size = "small",
+ srcs = ["layer_norm_lstm_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers",
+ ],
+)
+
+tf_cc_test(
name = "lstm_test",
size = "small",
srcs = ["lstm_test.cc"],
@@ -998,8 +1054,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1101,8 +1157,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1118,8 +1174,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1135,8 +1191,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1152,8 +1208,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1166,8 +1222,8 @@ tf_cc_test(
tags = ["tflite_not_portable_ios"],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1193,6 +1249,34 @@ tf_cc_test(
tags = ["tflite_not_portable_ios"],
deps = [
":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "unpack_test",
+ size = "small",
+ srcs = ["unpack_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "floor_div_test",
+ size = "small",
+ srcs = ["floor_div_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
"//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:test_util",
diff --git a/tensorflow/contrib/lite/kernels/activation_functor.h b/tensorflow/contrib/lite/kernels/activation_functor.h
index 41ec3cca33..e075dc7054 100644
--- a/tensorflow/contrib/lite/kernels/activation_functor.h
+++ b/tensorflow/contrib/lite/kernels/activation_functor.h
@@ -19,7 +19,7 @@ limitations under the License.
#include <cmath>
#include <cstdlib>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index d6d62580e2..b2d9b84979 100644
--- a/tensorflow/contrib/lite/kernels/activations.cc
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
@@ -200,7 +200,7 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, input->type, output->type);
const int num_dims = NumDimensions(input);
- TF_LITE_ENSURE(context, num_dims == 1 || num_dims == 2 || num_dims == 4);
+ TF_LITE_ENSURE(context, num_dims >= 1 && num_dims <= 4);
if (input->type == kTfLiteUInt8) {
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
@@ -453,6 +453,19 @@ void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output,
Softmax(input->data.f, input_size, batch_size, params->beta, output->data.f);
}
+// Takes a 3D tensor and performs softmax along the last dimension.
+void Softmax3DFloat(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params) {
+ const int batch_size = input->dims->data[0];
+ const int intermediate_size = input->dims->data[1];
+ const int input_size = input->dims->data[2];
+ optimized_ops::Softmax(
+ GetTensorData<float>(input),
+ GetTensorShape({batch_size, intermediate_size, 1, input_size}),
+ params->beta, GetTensorData<float>(output),
+ GetTensorShape({batch_size, intermediate_size, 1, input_size}));
+}
+
void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params, OpData* data) {
// TODO(ahentz): this is arguably a dirty trick. Since the implementation
@@ -480,6 +493,19 @@ void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
GetTensorShape({batch_size, 1, 1, input_size}));
}
+void Softmax3DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params, OpData* data) {
+ const int batch_size = input->dims->data[0];
+ const int intermediate_size = input->dims->data[1];
+ const int input_size = input->dims->data[2];
+ optimized_ops::Softmax(
+ GetTensorData<uint8_t>(input),
+ GetTensorShape({batch_size, intermediate_size, 1, input_size}),
+ data->input_multiplier, data->input_left_shift, data->diff_min,
+ GetTensorData<uint8_t>(output),
+ GetTensorShape({batch_size, intermediate_size, 1, input_size}));
+}
+
// Takes a 4D tensor and performs softmax along the fourth dimension.
void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params) {
@@ -515,6 +541,10 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
Softmax2DFloat(input, output, params);
return kTfLiteOk;
}
+ if (NumDimensions(input) == 3) {
+ Softmax3DFloat(input, output, params);
+ return kTfLiteOk;
+ }
if (NumDimensions(input) == 4) {
Softmax4DFloat(input, output, params);
return kTfLiteOk;
@@ -533,6 +563,10 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
Softmax2DQuantized(input, output, params, data);
return kTfLiteOk;
}
+ if (NumDimensions(input) == 3) {
+ Softmax3DQuantized(input, output, params, data);
+ return kTfLiteOk;
+ }
if (NumDimensions(input) == 4) {
Softmax4DQuantized(input, output, params, data);
return kTfLiteOk;
@@ -590,10 +624,10 @@ TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
input->type);
return kTfLiteError;
}
- reference_ops::BroadcastBinaryFunction<float, float, float>(
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(alpha), GetTensorDims(alpha),
- GetTensorData<float>(output), GetTensorDims(output), ApplyPrelu<float>);
+ reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
+ GetTensorShape(input), GetTensorData<float>(input), GetTensorShape(alpha),
+ GetTensorData<float>(alpha), GetTensorShape(output),
+ GetTensorData<float>(output), ApplyPrelu<float>);
return kTfLiteOk;
}
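
Softmax3DFloat and Softmax3DQuantized above reuse the 4D optimized kernel by viewing a {batch, intermediate, depth} tensor as {batch, intermediate, 1, depth}, which is value-preserving because softmax only normalizes along the last axis. A minimal reference sketch of that per-row computation (a free function for illustration, not the optimized_ops kernel; here outer = batch * intermediate):

#include <algorithm>
#include <cmath>

// Softmax along the last axis of a flattened [outer, depth] buffer, with the
// same beta scaling carried by TfLiteSoftmaxParams.
void SoftmaxLastAxis(const float* input, int outer, int depth, float beta,
                     float* output) {
  for (int i = 0; i < outer; ++i) {
    const float* in = input + i * depth;
    float* out = output + i * depth;
    float max_val = in[0];
    for (int d = 1; d < depth; ++d) max_val = std::max(max_val, in[d]);
    float sum = 0.f;
    for (int d = 0; d < depth; ++d) {
      out[d] = std::exp(beta * (in[d] - max_val));  // subtract max for stability
      sum += out[d];
    }
    for (int d = 0; d < depth; ++d) out[d] /= sum;
  }
}
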
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc
index e577e3a762..9fa47e190a 100644
--- a/tensorflow/contrib/lite/kernels/activations_test.cc
+++ b/tensorflow/contrib/lite/kernels/activations_test.cc
@@ -339,6 +339,76 @@ TEST(QuantizedActivationsOpTest, Softmax4D) {
kQuantizedTolerance)));
}
+TEST(FloatActivationsOpTest, Softmax3D) {
+ FloatActivationsOpModel m(0.1,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4}});
+ m.SetInput({
+ 0, -6, 2, 4, // depth = 0
+ 3, -2, 10, 1, // depth = 1
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ .23463, .12877, .28658, .35003, //
+ .22528, .13664, .45365, .18443, //
+ })));
+
+ // Same input, but a different shape.
+ FloatActivationsOpModel m2(0.1,
+ /*input=*/{TensorType_FLOAT32, {4, 1, 2}});
+ m2.SetInput({
+ 0, -6, //
+ 2, 4, //
+ 3, -2, //
+ 10, 1, //
+ });
+ m2.Invoke();
+ EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ })));
+}
+
+TEST(QuantizedActivationsOpTest, Softmax3D) {
+ QuantizedActivationsOpModel m(
+ 0.1,
+ /*input=*/{TensorType_UINT8, {1, 2, 4}, -10, 10});
+ m.SetInput<uint8_t>({
+ 0, -6, 2, 4, // depth = 0
+ 3, -2, 10, 1, // depth = 1
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ .23463, .12877, .28658, .35003, //
+ .22528, .13664, .45365, .18443, //
+ },
+ kQuantizedTolerance)));
+
+ // Same input, but a different shape.
+ QuantizedActivationsOpModel m2(
+ 0.1,
+ /*input=*/{TensorType_UINT8, {4, 1, 2}, -10, 10});
+ m2.SetInput<uint8_t>({
+ 0, -6, //
+ 2, 4, //
+ 3, -2, //
+ 10, 1, //
+ });
+ m2.Invoke();
+ EXPECT_THAT(m2.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ },
+ kQuantizedTolerance)));
+}
+
TEST(FloatActivationsOpTest, Softmax1D) {
FloatActivationsOpModel m(0.1,
/*input=*/{TensorType_FLOAT32, {8}});
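
The constants in the Softmax3D expectations above follow directly from softmax with beta = 0.1. A quick sketch verifying the first row:

#include <cmath>
#include <cstdio>

int main() {
  const float beta = 0.1f;
  const float row[4] = {0, -6, 2, 4};
  float e[4], sum = 0.f;
  for (int i = 0; i < 4; ++i) {
    e[i] = std::exp(beta * row[i]);
    sum += e[i];
  }
  // Prints 0.23463 0.12877 0.28658 0.35003, matching the test above.
  for (int i = 0; i < 4; ++i) std::printf("%.5f ", e[i] / sum);
}
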
diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc
index af9b5c7013..b4393e8097 100644
--- a/tensorflow/contrib/lite/kernels/add.cc
+++ b/tensorflow/contrib/lite/kernels/add.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/arg_min_max.cc b/tensorflow/contrib/lite/kernels/arg_min_max.cc
index 4f30d09030..b91e348c27 100644
--- a/tensorflow/contrib/lite/kernels/arg_min_max.cc
+++ b/tensorflow/contrib/lite/kernels/arg_min_max.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -96,11 +96,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
const TfLiteTensor* axis = GetInput(context, node, kAxis);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-#define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type) \
- optimized_ops::ArgMinMax( \
- GetTensorData<axis_type>(axis), GetTensorData<data_type>(input), \
- GetTensorDims(input), GetTensorData<output_type>(output), \
- GetTensorDims(output), GetComparefunction<data_type>(is_arg_max))
+#define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type) \
+ optimized_ops::ArgMinMax( \
+ GetTensorShape(input), GetTensorData<data_type>(input), \
+ GetTensorData<axis_type>(axis), GetTensorShape(output), \
+ GetTensorData<output_type>(output), \
+ GetComparefunction<data_type>(is_arg_max))
if (axis->type == kTfLiteInt32) {
switch (output->type) {
case kTfLiteInt32: {
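
The macro rewrite above is part of a wider migration in these kernels from the legacy GetTensorDims-based signatures to GetTensorShape/RuntimeShape ones, reordered so each shape immediately precedes its data pointer. Condensed from the two versions of the macro (signatures abbreviated, illustrative only):

// Before: legacy convention, each data pointer followed by its Dims:
//   optimized_ops::ArgMinMax(axis_data, input_data, input_dims,
//                            output_data, output_dims, cmp);
// After: RuntimeShape convention, each shape precedes its data pointer:
//   optimized_ops::ArgMinMax(input_shape, input_data, axis_data,
//                            output_shape, output_data, cmp);
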
diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
index 91d8dd3fa7..44ef587244 100644
--- a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
+++ b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/spectrogram.h"
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
-#include "flatbuffers/flexbuffers.h"
+#include "flatbuffers/flexbuffers.h" // flatbuffers
namespace tflite {
namespace ops {
diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc b/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc
index 8d460fdfc6..7346b9fd80 100644
--- a/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc
+++ b/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h"
+#include "flatbuffers/flexbuffers.h" // flatbuffers
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc
index c09b15b3d2..1aa27602e5 100644
--- a/tensorflow/contrib/lite/kernels/basic_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include <stddef.h>
#include <stdint.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -31,8 +31,10 @@ constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kRecurrentWeightsTensor = 2;
constexpr int kBiasTensor = 3;
-constexpr int kHiddenStateTensor = 0;
-constexpr int kOutputTensor = 1;
+constexpr int kHiddenStateTensor = 4;
+
+// Output tensor.
+constexpr int kOutputTensor = 0;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* scratch_tensor_index = new int;
@@ -46,14 +48,16 @@ void Free(TfLiteContext* context, void* buffer) {
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* recurrent_weights =
GetInput(context, node, kRecurrentWeightsTensor);
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
+ const TfLiteTensor* hidden_state =
+ GetInput(context, node, kHiddenStateTensor);
// Check all the parameters of tensor match within themselves and match the
// input configuration.
@@ -65,20 +69,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, input_weights->type, recurrent_weights->type);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(hidden_state), 2);
+ TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size);
+ TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units);
- TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- // Resize state.
- TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2);
- hidden_state_size_array->data[0] = batch_size;
- hidden_state_size_array->data[1] = num_units;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state,
- hidden_state_size_array));
-
- // Mark hidden state as a persistent tensor.
- hidden_state->allocation_type = kTfLiteArenaRwPersistent;
-
// Resize output.
TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
output_size_array->data[0] = batch_size;
@@ -205,7 +201,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* recurrent_weights =
GetInput(context, node, kRecurrentWeightsTensor);
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
- TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
+ TfLiteTensor* hidden_state =
+ &context->tensors[node->inputs->data[kHiddenStateTensor]];
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// We already checked that weight types are consistent, so branch on one.
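
With this change the RNN's hidden state is input 4: a stateful variable tensor the op updates in place, rather than an output it resizes. Eval obtains the writable tensor straight from the input list, as shown above; a condensed sketch of that access pattern (GetMutableInput is an illustrative helper, not a function in this file):

// A variable state tensor is an *input*, but the kernel needs it mutably,
// so it is fetched directly from the context's tensor array.
TfLiteTensor* GetMutableInput(TfLiteContext* context, TfLiteNode* node,
                              int index) {
  return &context->tensors[node->inputs->data[index]];
}
// Usage inside Eval():
//   TfLiteTensor* hidden_state =
//       GetMutableInput(context, node, kHiddenStateTensor);
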
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc
index 96465fcaf0..d179735404 100644
--- a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc
@@ -181,15 +181,16 @@ class RNNOpModel : public SingleOpModel {
weights_ = AddInput(weights);
recurrent_weights_ = AddInput(recurrent_weights);
bias_ = AddInput(TensorType_FLOAT32);
- hidden_state_ = AddOutput(TensorType_FLOAT32);
+ hidden_state_ = AddInput(TensorType_FLOAT32, true);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(
BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
- BuildInterpreter({{batches_, input_size_},
- {units_, input_size_},
- {units_, units_},
- {units_}});
+ BuildInterpreter({{batches_, input_size_}, // input tensor
+ {units_, input_size_}, // weights tensor
+ {units_, units_}, // recurrent weights tensor
+ {units_}, // bias tensor
+ {batches_, units_}}); // hidden state tensor
}
void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
@@ -210,14 +211,6 @@ class RNNOpModel : public SingleOpModel {
PopulateTensor(input_, offset, begin, end);
}
- void ResetHiddenState() {
- const int zero_buffer_size = units_ * batches_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(hidden_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
int input_size() { return input_size_; }
@@ -258,7 +251,6 @@ TEST(RnnOpTest, BlackBoxTest) {
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
(rnn.input_size() * rnn.num_batches());
@@ -286,7 +278,6 @@ TEST(HybridRnnOpTest, BlackBoxTest) {
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
(rnn.input_size() * rnn.num_batches());
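
Because the hidden state is now a stateful input rather than an output, the explicit ResetHiddenState() plumbing disappears from these tests. If a caller does need to clear recurrent state between independent sequences, a minimal sketch at the interpreter level, using only the public tensor() accessor (the tensor index is illustrative):

#include <cstring>

#include "tensorflow/contrib/lite/interpreter.h"

// Zero one variable state tensor in place between independent sequences.
void ZeroStateTensor(tflite::Interpreter* interpreter, int tensor_index) {
  TfLiteTensor* t = interpreter->tensor(tensor_index);
  std::memset(t->data.raw, 0, t->bytes);
}
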
diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
index c8cee88edf..fe2865dfb9 100644
--- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
+++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -125,14 +125,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar) \
- type::BatchToSpaceND(GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), \
+ type::BatchToSpaceND(GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), \
+ GetTensorShape(op_context.block_shape), \
GetTensorData<int32_t>(op_context.block_shape), \
- GetTensorDims(op_context.block_shape), \
+ GetTensorShape(op_context.crops), \
GetTensorData<int32_t>(op_context.crops), \
- GetTensorDims(op_context.crops), \
- GetTensorData<scalar>(op_context.output), \
- GetTensorDims(op_context.output))
+ GetTensorShape(op_context.output), \
+ GetTensorData<scalar>(op_context.output))
switch (op_context.input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
if (kernel_type == kReference) {
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index a11a59aa05..541f320138 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
@@ -94,18 +94,54 @@ constexpr int kBwProjectionWeightsTensor = 33; // Optional
// Projection bias tensor of size {n_output}
constexpr int kBwProjectionBiasTensor = 34; // Optional
-// Output tensors.
-constexpr int kFwOutputStateTensor = 0;
-constexpr int kFwCellStateTensor = 1;
-constexpr int kFwOutputTensor = 2;
+// Stateful input tensors that are variables and will be modified by the Op.
+// Activation state tensors of size {n_batch, n_output}
+constexpr int kFwInputActivationStateTensor = 35;
+// Cell state tensors of size {n_batch, n_cell}
+constexpr int kFwInputCellStateTensor = 36;
+// Activation state tensors of size {n_batch, n_output}
+constexpr int kBwInputActivationStateTensor = 37;
+// Cell state tensors of size {n_batch, n_cell}
+constexpr int kBwInputCellStateTensor = 38;
+
+// Auxiliary input and weights when stacking.
+constexpr int kAuxInputTensor = 39; // Optional
+// Forward weights.
+constexpr int kFwAuxInputToInputWeightsTensor = 40; // Optional
+constexpr int kFwAuxInputToForgetWeightsTensor = 41; // Optional
+constexpr int kFwAuxInputToCellWeightsTensor = 42; // Optional
+constexpr int kFwAuxInputToOutputWeightsTensor = 43; // Optional
+// Backward weights.
+constexpr int kBwAuxInputToInputWeightsTensor = 44; // Optional
+constexpr int kBwAuxInputToForgetWeightsTensor = 45; // Optional
+constexpr int kBwAuxInputToCellWeightsTensor = 46; // Optional
+constexpr int kBwAuxInputToOutputWeightsTensor = 47; // Optional
-constexpr int kBwOutputStateTensor = 3;
-constexpr int kBwCellStateTensor = 4;
-constexpr int kBwOutputTensor = 5;
+// Output tensors.
+constexpr int kFwOutputTensor = 0;
+constexpr int kBwOutputTensor = 1;
+
+// Temporary tensors.
+enum TemporaryTensor {
+ // Scratch buffers for input, forget, etc. gates
+ kFwScratchBuffer = 0,
+ kBwScratchBuffer = 1,
+ // Quantized tensors needed for the hybrid kernel.
+ kInputQuantized = 2,
+ kAuxInputQuantized = 3, // Quantized tensor needed for auxiliary input.
+ kFwActivationStateQuantized = 4,
+ kBwActivationStateQuantized = 5,
+ kFwCellStateQuantized = 6,
+ kBwCellStateQuantized = 7,
+ kScalingFactors = 8,
+ kProductScalingFactors = 9,
+ kRecoveredCellWeights = 10,
+ kNumTemporaryTensors = 11
+};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* scratch_tensor_index = new int;
- context->AddTensors(context, 2, scratch_tensor_index);
+ context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
return scratch_tensor_index;
}
@@ -126,7 +162,7 @@ TfLiteStatus CheckLstmTensorDimensions(
int input_gate_bias_tensor, int forget_gate_bias_tensor,
int cell_gate_bias_tensor, int output_gate_bias_tensor,
int projection_weights_tensor, int projection_bias_tensor) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
// Making sure clipping parameters have valid values.
// == 0 means no clipping
@@ -307,19 +343,20 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
return kTfLiteOk;
}
-// Resize the output, state and scratch tensors based on the sizes of the input
+// Resize the output and scratch tensors based on the sizes of the input
// tensors. Also check that the size of the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 35);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 6);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 48);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TF_LITE_ENSURE(context, input->dims->size > 1);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
const int max_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
const int n_input = input->dims->data[2];
@@ -343,13 +380,63 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context, CheckInputTensorDimensions(context, node, n_input, n_fw_output,
n_fw_cell));
- // Get the pointer to output, state and scratch buffer tensors.
- TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
- TfLiteTensor* fw_output_state =
- GetOutput(context, node, kFwOutputStateTensor);
- TfLiteTensor* fw_cell_state = GetOutput(context, node, kFwCellStateTensor);
+ // Get (optional) auxiliary inputs and weights.
+ const TfLiteTensor* aux_input =
+ GetOptionalInputTensor(context, node, kAuxInputTensor);
+ const TfLiteTensor* fw_aux_input_to_input_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor);
+ const TfLiteTensor* fw_aux_input_to_forget_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor);
+ const TfLiteTensor* fw_aux_input_to_cell_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor);
+ const TfLiteTensor* fw_aux_input_to_output_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_input_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_forget_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_cell_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_output_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);
+
+ const bool aux_inputs_all_or_none =
+ ((aux_input != nullptr) && (fw_aux_input_to_cell_weights != nullptr) &&
+ (fw_aux_input_to_forget_weights != nullptr) &&
+ (fw_aux_input_to_output_weights != nullptr) &&
+ (bw_aux_input_to_cell_weights != nullptr) &&
+ (bw_aux_input_to_forget_weights != nullptr) &&
+ (bw_aux_input_to_output_weights != nullptr)) ||
+ ((fw_aux_input_to_cell_weights == nullptr) &&
+ (fw_aux_input_to_forget_weights == nullptr) &&
+ (fw_aux_input_to_output_weights == nullptr) &&
+ (bw_aux_input_to_cell_weights == nullptr) &&
+ (bw_aux_input_to_forget_weights == nullptr) &&
+ (bw_aux_input_to_output_weights == nullptr));
+ TF_LITE_ENSURE(context, aux_inputs_all_or_none);
+ const bool has_aux_input = (aux_input != nullptr);
+
+ if (has_aux_input) {
+ // Check that aux_input has the same dimensions (except last) as the input.
+ TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]);
+ TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]);
+ }
- // Resize the output, output_state and cell_state tensors.
+ // Get the pointer to output, activation_state and cell_state buffer tensors.
+ TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
+ TfLiteTensor* fw_activation_state =
+ GetVariableInput(context, node, kFwInputActivationStateTensor);
+ TfLiteTensor* fw_cell_state =
+ GetVariableInput(context, node, kFwInputCellStateTensor);
+
+ // Check the shape of input state tensors.
+ // These tensors may be 1D or 2D. It's fine as long as the total size is
+ // correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(fw_activation_state),
+ n_batch * n_fw_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(fw_cell_state), n_batch * n_fw_cell);
+
+ // Resize the output tensors.
TfLiteIntArray* fw_output_size = TfLiteIntArrayCreate(3);
fw_output_size->data[0] = max_time;
fw_output_size->data[1] = n_batch;
@@ -357,32 +444,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, fw_output, fw_output_size));
- TfLiteIntArray* fw_output_state_size = TfLiteIntArrayCreate(2);
- fw_output_state_size->data[0] = n_batch;
- fw_output_state_size->data[1] = n_fw_output;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_output_state,
- fw_output_state_size));
-
- TfLiteIntArray* fw_cell_size = TfLiteIntArrayCreate(2);
- fw_cell_size->data[0] = n_batch;
- fw_cell_size->data[1] = n_fw_cell;
- TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, fw_cell_state, fw_cell_size));
+ // The weights are of consistent type, so it suffices to check one.
+ const bool is_hybrid_op = (fw_input_to_output_weights->type == kTfLiteUInt8);
- // Create a scratch buffer tensor.
TfLiteIntArrayFree(node->temporaries);
- node->temporaries = TfLiteIntArrayCreate(2);
- node->temporaries->data[0] = *scratch_tensor_index;
- TfLiteTensor* fw_scratch_buffer = GetTemporary(context, node, /*index=*/0);
+ if (is_hybrid_op) {
+ node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
+ } else {
+ node->temporaries = TfLiteIntArrayCreate(2); // the two scratch buffers.
+ }
+ // Create a scratch buffer tensor.
+ node->temporaries->data[kFwScratchBuffer] = *scratch_tensor_index;
+ TfLiteTensor* fw_scratch_buffer =
+ GetTemporary(context, node, kFwScratchBuffer);
fw_scratch_buffer->type = input->type;
fw_scratch_buffer->allocation_type = kTfLiteArenaRw;
- // Mark state tensors as persistent tensors.
- fw_output_state->allocation_type = kTfLiteArenaRwPersistent;
- fw_cell_state->allocation_type = kTfLiteArenaRwPersistent;
-
const TfLiteTensor* fw_input_to_input_weights =
GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor);
+ if (has_aux_input) {
+ TF_LITE_ENSURE_EQ(context, fw_aux_input_to_input_weights->dims->data[0],
+ fw_input_to_input_weights->dims->data[0]);
+ }
const bool fw_use_cifg = (fw_input_to_input_weights == nullptr);
TfLiteIntArray* fw_scratch_buffer_size = TfLiteIntArrayCreate(2);
fw_scratch_buffer_size->data[0] = n_batch;
@@ -415,13 +498,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context, CheckInputTensorDimensions(context, node, n_input, n_bw_output,
n_bw_cell));
- // Get the pointer to output, output_state and cell_state buffer tensors.
+ // Get the pointer to output, activation_state and cell_state buffer tensors.
TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
- TfLiteTensor* bw_output_state =
- GetOutput(context, node, kBwOutputStateTensor);
- TfLiteTensor* bw_cell_state = GetOutput(context, node, kBwCellStateTensor);
+ TfLiteTensor* bw_activation_state =
+ GetVariableInput(context, node, kBwInputActivationStateTensor);
+ TfLiteTensor* bw_cell_state =
+ GetVariableInput(context, node, kBwInputCellStateTensor);
- // Resize the output, output_state and cell_state tensors.
+ // Resize the output tensors.
TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
bw_output_size->data[0] = max_time;
bw_output_size->data[1] = n_batch;
@@ -429,30 +513,27 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, bw_output, bw_output_size));
- TfLiteIntArray* bw_output_state_size = TfLiteIntArrayCreate(2);
- bw_output_state_size->data[0] = n_batch;
- bw_output_state_size->data[1] = n_bw_output;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output_state,
- bw_output_state_size));
-
- TfLiteIntArray* bw_cell_size = TfLiteIntArrayCreate(2);
- bw_cell_size->data[0] = n_batch;
- bw_cell_size->data[1] = n_bw_cell;
- TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, bw_cell_state, bw_cell_size));
+ // Check the shape of input state tensors.
+ // These tensors may be 1D or 2D. It's fine as long as the total size is
+ // correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(bw_activation_state),
+ n_batch * n_bw_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(bw_cell_state), n_batch * n_bw_cell);
// Create a scratch buffer tensor.
- node->temporaries->data[1] = *(scratch_tensor_index) + 1;
- TfLiteTensor* bw_scratch_buffer = GetTemporary(context, node, /*index=*/1);
+ node->temporaries->data[kBwScratchBuffer] =
+ *(scratch_tensor_index) + kBwScratchBuffer;
+ TfLiteTensor* bw_scratch_buffer =
+ GetTemporary(context, node, kBwScratchBuffer);
bw_scratch_buffer->type = input->type;
bw_scratch_buffer->allocation_type = kTfLiteArenaRw;
- // Mark state tensors as persistent tensors.
- bw_output_state->allocation_type = kTfLiteArenaRwPersistent;
- bw_cell_state->allocation_type = kTfLiteArenaRwPersistent;
-
const TfLiteTensor* bw_input_to_input_weights =
GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor);
+ if (has_aux_input) {
+ TF_LITE_ENSURE_EQ(context, bw_aux_input_to_input_weights->dims->data[0],
+ bw_input_to_input_weights->dims->data[0]);
+ }
const bool bw_use_cifg = (bw_input_to_input_weights == nullptr);
TfLiteIntArray* bw_scratch_buffer_size = TfLiteIntArrayCreate(2);
bw_scratch_buffer_size->data[0] = n_batch;
@@ -465,18 +546,528 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer,
bw_scratch_buffer_size));
+ if (is_hybrid_op) {
+ // Allocate temporary tensors to store quantized values of input, aux_input
+ // (if present), activation_state and cell_state tensors.
+ node->temporaries->data[kInputQuantized] =
+ *scratch_tensor_index + kInputQuantized;
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, kInputQuantized);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+
+ if (has_aux_input) {
+ node->temporaries->data[kAuxInputQuantized] =
+ *scratch_tensor_index + kAuxInputQuantized;
+ TfLiteTensor* aux_input_quantized =
+ GetTemporary(context, node, kAuxInputQuantized);
+ aux_input_quantized->type = kTfLiteUInt8;
+ aux_input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
+ TfLiteIntArray* aux_input_quantized_size =
+ TfLiteIntArrayCopy(aux_input->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, aux_input_quantized,
+ aux_input_quantized_size));
+ }
+ }
+
+ node->temporaries->data[kFwActivationStateQuantized] =
+ *scratch_tensor_index + kFwActivationStateQuantized;
+ TfLiteTensor* fw_activation_state_quantized =
+ GetTemporary(context, node, kFwActivationStateQuantized);
+ fw_activation_state_quantized->type = kTfLiteUInt8;
+ fw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(fw_activation_state_quantized->dims,
+ fw_activation_state->dims)) {
+ TfLiteIntArray* fw_activation_state_quantized_size =
+ TfLiteIntArrayCopy(fw_activation_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, fw_activation_state_quantized,
+ fw_activation_state_quantized_size));
+ }
+ node->temporaries->data[kBwActivationStateQuantized] =
+ *scratch_tensor_index + kBwActivationStateQuantized;
+ TfLiteTensor* bw_activation_state_quantized =
+ GetTemporary(context, node, kBwActivationStateQuantized);
+ bw_activation_state_quantized->type = kTfLiteUInt8;
+ bw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(bw_activation_state_quantized->dims,
+ bw_activation_state->dims)) {
+ TfLiteIntArray* bw_activation_state_quantized_size =
+ TfLiteIntArrayCopy(bw_activation_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, bw_activation_state_quantized,
+ bw_activation_state_quantized_size));
+ }
+ node->temporaries->data[kFwCellStateQuantized] =
+ *scratch_tensor_index + kFwCellStateQuantized;
+ TfLiteTensor* fw_cell_state_quantized =
+ GetTemporary(context, node, kFwCellStateQuantized);
+ fw_cell_state_quantized->type = kTfLiteUInt8;
+ fw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(fw_cell_state_quantized->dims,
+ fw_cell_state->dims)) {
+ TfLiteIntArray* fw_cell_state_quantized_size =
+ TfLiteIntArrayCopy(fw_cell_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, fw_cell_state_quantized,
+ fw_cell_state_quantized_size));
+ }
+ node->temporaries->data[kBwCellStateQuantized] =
+ *scratch_tensor_index + kBwCellStateQuantized;
+ TfLiteTensor* bw_cell_state_quantized =
+ GetTemporary(context, node, kBwCellStateQuantized);
+ bw_cell_state_quantized->type = kTfLiteUInt8;
+ bw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(bw_cell_state_quantized->dims,
+ bw_cell_state->dims)) {
+ TfLiteIntArray* bw_cell_state_quantized_size =
+ TfLiteIntArrayCopy(bw_cell_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, bw_cell_state_quantized,
+ bw_cell_state_quantized_size));
+ }
+
+ // Allocate temporary tensors to store scaling factors and product scaling
+ // factors. The latter is convenient storage that lets us quantize
+ // a vector once (which produces the scaling factors) and multiply it with
+ // different matrices (which requires multiplying the scaling factors with
+ // the scaling factor of the matrix).
+ node->temporaries->data[kScalingFactors] =
+ *scratch_tensor_index + kScalingFactors;
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, kScalingFactors);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+ node->temporaries->data[kProductScalingFactors] =
+ *scratch_tensor_index + kProductScalingFactors;
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, kProductScalingFactors);
+ prod_scaling_factors->type = kTfLiteFloat32;
+ prod_scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
+ prod_scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(prod_scaling_factors->dims,
+ prod_scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, prod_scaling_factors,
+ prod_scaling_factors_size));
+ }
+
+ // Allocate a temporary tensor to store the recovered cell weights. Since
+ // this is used for diagonal matrices, we only need to store n_cell values.
+ node->temporaries->data[kRecoveredCellWeights] =
+ *scratch_tensor_index + kRecoveredCellWeights;
+ TfLiteTensor* recovered_cell_weights =
+ GetTemporary(context, node, kRecoveredCellWeights);
+ recovered_cell_weights->type = kTfLiteFloat32;
+ recovered_cell_weights->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
+ recovered_cell_weights_size->data[0] = n_fw_cell;
+ if (!TfLiteIntArrayEqual(recovered_cell_weights->dims,
+ recovered_cell_weights_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, recovered_cell_weights,
+ recovered_cell_weights_size));
+ }
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalFloat(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
+ const TfLiteTensor* aux_input_to_input_weights,
+ const TfLiteTensor* aux_input_to_forget_weights,
+ const TfLiteTensor* aux_input_to_cell_weights,
+ const TfLiteTensor* aux_input_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, bool forward_sequence,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output) {
+ const int max_time = input->dims->data[0];
+ const int n_batch = input->dims->data[1];
+ const int n_input = input->dims->data[2];
+ const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
+
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ // Index the scratch buffer pointers into the global scratch buffer.
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+ // Check optional tensors, the respective pointers can be null.
+ const float* input_to_input_weights_ptr =
+ (use_cifg) ? nullptr : input_to_input_weights->data.f;
+ const float* recurrent_to_input_weights_ptr =
+ (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
+ const float* input_gate_bias_ptr =
+ (use_cifg) ? nullptr : input_gate_bias->data.f;
+ const float* cell_to_input_weights_ptr =
+ (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
+ const float* cell_to_forget_weights_ptr =
+ (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
+ const float* cell_to_output_weights_ptr =
+ (use_peephole) ? cell_to_output_weights->data.f : nullptr;
+ const float* projection_weights_ptr =
+ (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+ float* aux_input_ptr = nullptr;
+ float* aux_input_to_input_weights_ptr = nullptr;
+ float* aux_input_to_forget_weights_ptr = nullptr;
+ float* aux_input_to_cell_weights_ptr = nullptr;
+ float* aux_input_to_output_weights_ptr = nullptr;
+ if (aux_input_size > 0) {
+ aux_input_ptr = aux_input->data.f;
+ aux_input_to_input_weights_ptr = aux_input_to_input_weights->data.f;
+ aux_input_to_forget_weights_ptr = aux_input_to_forget_weights->data.f;
+ aux_input_to_cell_weights_ptr = aux_input_to_cell_weights->data.f;
+ aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f;
+ }
+
+ // Loop through the sequence.
+ if (forward_sequence) {
+ for (int t = 0; t < max_time; t++) {
+ const float* input_ptr = input->data.f + t * n_batch * n_input;
+ float* output_ptr_time = output->data.f + t * n_batch * n_output;
+
+ kernel_utils::LstmStepWithAuxInput(
+ input_ptr, input_to_input_weights_ptr,
+ input_to_forget_weights->data.f, input_to_cell_weights->data.f,
+ input_to_output_weights->data.f, aux_input_ptr,
+ aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
+ aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
+ recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
+ recurrent_to_cell_weights->data.f,
+ recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+ cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+ input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+ output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
+ params, n_batch, n_cell, n_input, aux_input_size, n_output,
+ activation_state->data.f, cell_state->data.f, input_gate_scratch,
+ forget_gate_scratch, cell_scratch, output_gate_scratch,
+ output_ptr_time);
+ }
+ } else {
+ // Loop through the sequence backwards.
+ for (int t = max_time - 1; t >= 0; t--) {
+ const float* input_ptr = input->data.f + t * n_batch * n_input;
+ float* output_ptr_time = output->data.f + t * n_batch * n_output;
+
+ kernel_utils::LstmStepWithAuxInput(
+ input_ptr, input_to_input_weights_ptr,
+ input_to_forget_weights->data.f, input_to_cell_weights->data.f,
+ input_to_output_weights->data.f, aux_input_ptr,
+ aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
+ aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
+ recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
+ recurrent_to_cell_weights->data.f,
+ recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+ cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+ input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+ output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
+ params, n_batch, n_cell, n_input, aux_input_size, n_output,
+ activation_state->data.f, cell_state->data.f, input_gate_scratch,
+ forget_gate_scratch, cell_scratch, output_gate_scratch,
+ output_ptr_time);
+ }
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalHybrid(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
+ const TfLiteTensor* aux_input_to_input_weights,
+ const TfLiteTensor* aux_input_to_forget_weights,
+ const TfLiteTensor* aux_input_to_cell_weights,
+ const TfLiteTensor* aux_input_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, bool forward_sequence,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
+ TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
+ TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
+ TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
+ TfLiteTensor* output_state, TfLiteTensor* cell_state,
+ TfLiteTensor* output) {
+ const int max_time = input->dims->data[0];
+ const int n_batch = input->dims->data[1];
+ const int n_input = input->dims->data[2];
+ const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+ // Check optional tensors, the respective pointers can be null.
+ int8_t* input_to_input_weights_ptr = nullptr;
+ float input_to_input_weights_scale = 1.0f;
+ int8_t* recurrent_to_input_weights_ptr = nullptr;
+ float recurrent_to_input_weights_scale = 1.0f;
+ float* input_gate_bias_ptr = nullptr;
+ if (!use_cifg) {
+ input_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
+ recurrent_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
+ input_gate_bias_ptr = input_gate_bias->data.f;
+ input_to_input_weights_scale = input_to_input_weights->params.scale;
+ recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
+ }
+
+ int8_t* cell_to_input_weights_ptr = nullptr;
+ int8_t* cell_to_forget_weights_ptr = nullptr;
+ int8_t* cell_to_output_weights_ptr = nullptr;
+ float cell_to_input_weights_scale = 1.0f;
+ float cell_to_forget_weights_scale = 1.0f;
+ float cell_to_output_weights_scale = 1.0f;
+ if (use_peephole) {
+ if (!use_cifg) {
+ cell_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
+ cell_to_input_weights_scale = cell_to_input_weights->params.scale;
+ }
+ cell_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
+ cell_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
+ cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
+ cell_to_output_weights_scale = cell_to_output_weights->params.scale;
+ }
+
+ const int8_t* projection_weights_ptr =
+ (projection_weights == nullptr)
+ ? nullptr
+ : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
+ const float projection_weights_scale =
+ (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
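
The same "optional quantized tensor to pointer plus scale" idiom repeats for every weight group in this function. A hedged sketch of a helper that captures it, assuming only TFLite's convention of carrying int8 weight data in the uint8 union member (GetQuantizedOrNull is hypothetical, not part of this file):

#include <cstdint>

struct QuantizedWeights {
  const int8_t* ptr;  // nullptr when the optional tensor is absent
  float scale;        // neutral 1.0f when absent
};

QuantizedWeights GetQuantizedOrNull(const TfLiteTensor* t) {
  if (t == nullptr) return {nullptr, 1.0f};
  return {reinterpret_cast<const int8_t*>(t->data.uint8), t->params.scale};
}
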
+ // Required tensors; these pointers are non-null.
+ const int8_t* input_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
+ const float input_to_forget_weights_scale =
+ input_to_forget_weights->params.scale;
+ const int8_t* input_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
+ const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
+ const int8_t* input_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
+ const float input_to_output_weights_scale =
+ input_to_output_weights->params.scale;
+ const int8_t* recurrent_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
+ const float recurrent_to_forget_weights_scale =
+ recurrent_to_forget_weights->params.scale;
+ const int8_t* recurrent_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
+ const float recurrent_to_cell_weights_scale =
+ recurrent_to_cell_weights->params.scale;
+ const int8_t* recurrent_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
+ const float recurrent_to_output_weights_scale =
+ recurrent_to_output_weights->params.scale;
+ const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+ const float* cell_bias_ptr = cell_bias->data.f;
+ const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+ float* output_state_ptr = output_state->data.f;
+ float* cell_state_ptr = cell_state->data.f;
+
+ // Temporary storage for quantized values and scaling factors.
+ int8_t* quantized_input_ptr =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+ int8_t* quantized_aux_input_ptr =
+ (aux_input_quantized == nullptr)
+ ? nullptr
+ : reinterpret_cast<int8_t*>(aux_input_quantized->data.uint8);
+ int8_t* quantized_output_state_ptr =
+ reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
+ int8_t* quantized_cell_state_ptr =
+ reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
+ float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
+ float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
+
+ // Auxiliary input and weights.
+ float* aux_input_ptr = nullptr;
+ int8_t* aux_input_to_input_weights_ptr = nullptr;
+ int8_t* aux_input_to_forget_weights_ptr = nullptr;
+ int8_t* aux_input_to_cell_weights_ptr = nullptr;
+ int8_t* aux_input_to_output_weights_ptr = nullptr;
+ float aux_input_to_input_weights_scale = 0.0f;
+ float aux_input_to_forget_weights_scale = 0.0f;
+ float aux_input_to_cell_weights_scale = 0.0f;
+ float aux_input_to_output_weights_scale = 0.0f;
+ if (aux_input_size > 0) {
+ aux_input_ptr = aux_input->data.f;
+ aux_input_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_input_to_input_weights->data.uint8);
+ aux_input_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_input_to_forget_weights->data.uint8);
+ aux_input_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_input_to_cell_weights->data.uint8);
+ aux_input_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_input_to_output_weights->data.uint8);
+ aux_input_to_input_weights_scale = aux_input_to_input_weights->params.scale;
+ aux_input_to_forget_weights_scale =
+ aux_input_to_forget_weights->params.scale;
+ aux_input_to_cell_weights_scale = aux_input_to_cell_weights->params.scale;
+ aux_input_to_output_weights_scale =
+ aux_input_to_output_weights->params.scale;
+ }
+ if (forward_sequence) {
+ // Feed the sequence into the LSTM step-by-step.
+ for (int t = 0; t < max_time; t++) {
+ const float* input_ptr = input->data.f + t * n_batch * n_input;
+ float* output_ptr = output->data.f + t * n_batch * n_output;
+
+ kernel_utils::LstmStepWithAuxInput(
+ input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+ input_to_forget_weights_ptr, input_to_forget_weights_scale,
+ input_to_cell_weights_ptr, input_to_cell_weights_scale,
+ input_to_output_weights_ptr, input_to_output_weights_scale,
+ aux_input_ptr, aux_input_to_input_weights_ptr,
+ aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
+ aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
+ aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
+ aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
+ recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
+ recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
+ recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
+ recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
+ cell_to_input_weights_scale, cell_to_forget_weights_ptr,
+ cell_to_forget_weights_scale, cell_to_output_weights_ptr,
+ cell_to_output_weights_scale, input_gate_bias_ptr,
+ forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
+ projection_weights_ptr, projection_weights_scale, projection_bias_ptr,
+ params, n_batch, n_cell, n_input, aux_input_size, n_output,
+ input_gate_scratch, forget_gate_scratch, cell_scratch,
+ output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
+ recovered_cell_weights_ptr, quantized_input_ptr,
+ quantized_aux_input_ptr, quantized_output_state_ptr,
+ quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
+ output_ptr);
+ }
+ } else {
+ // Loop through the sequence backwards.
+ for (int t = max_time - 1; t >= 0; t--) {
+ const float* input_ptr = input->data.f + t * n_batch * n_input;
+ float* output_ptr = output->data.f + t * n_batch * n_output;
+
+ kernel_utils::LstmStepWithAuxInput(
+ input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+ input_to_forget_weights_ptr, input_to_forget_weights_scale,
+ input_to_cell_weights_ptr, input_to_cell_weights_scale,
+ input_to_output_weights_ptr, input_to_output_weights_scale,
+ aux_input_ptr, aux_input_to_input_weights_ptr,
+ aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
+ aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
+ aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
+ aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
+ recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
+ recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
+ recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
+ recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
+ cell_to_input_weights_scale, cell_to_forget_weights_ptr,
+ cell_to_forget_weights_scale, cell_to_output_weights_ptr,
+ cell_to_output_weights_scale, input_gate_bias_ptr,
+ forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
+ projection_weights_ptr, projection_weights_scale, projection_bias_ptr,
+ params, n_batch, n_cell, n_input, aux_input_size, n_output,
+ input_gate_scratch, forget_gate_scratch, cell_scratch,
+ output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
+ recovered_cell_weights_ptr, quantized_input_ptr,
+ quantized_aux_input_ptr, quantized_output_state_ptr,
+ quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
+ output_ptr);
+ }
+ }
+
return kTfLiteOk;
}
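
The forward_sequence flag only changes the order in which timesteps are visited; the per-step kernel call is identical in both branches. A minimal sketch of that dispatch, with RunSequence as a hypothetical name:

#include <functional>

void RunSequence(int max_time, bool forward_sequence,
                 const std::function<void(int)>& step) {
  if (forward_sequence) {
    for (int t = 0; t < max_time; ++t) step(t);  // t = 0 .. max_time - 1
  } else {
    for (int t = max_time - 1; t >= 0; --t) step(t);  // reversed
  }
}

// Usage sketch: RunSequence(max_time, forward, [&](int t) {
//   const float* input_ptr = input->data.f + t * n_batch * n_input;
//   ...
// });
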
// The LSTM Op engine.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
// Input tensor.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- const int max_time = input->dims->data[0];
- const int n_batch = input->dims->data[1];
- const int n_input = input->dims->data[2];
// Tensors for the forward cell.
const TfLiteTensor* fw_input_to_input_weights =
@@ -518,9 +1109,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* fw_projection_bias =
GetOptionalInputTensor(context, node, kFwProjectionBiasTensor);
- TfLiteTensor* fw_output_state =
- GetOutput(context, node, kFwOutputStateTensor);
- TfLiteTensor* fw_cell_state = GetOutput(context, node, kFwCellStateTensor);
+ TfLiteTensor* fw_activation_state =
+ GetVariableInput(context, node, kFwInputActivationStateTensor);
+ TfLiteTensor* fw_cell_state =
+ GetVariableInput(context, node, kFwInputCellStateTensor);
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
// Tensors for the backward cell.
@@ -563,154 +1155,134 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* bw_projection_bias =
GetOptionalInputTensor(context, node, kBwProjectionBiasTensor);
- TfLiteTensor* bw_output_state =
- GetOutput(context, node, kBwOutputStateTensor);
- TfLiteTensor* bw_cell_state = GetOutput(context, node, kBwCellStateTensor);
+ // State tensors.
+ TfLiteTensor* bw_activation_state =
+ GetVariableInput(context, node, kBwInputActivationStateTensor);
+ TfLiteTensor* bw_cell_state =
+ GetVariableInput(context, node, kBwInputCellStateTensor);
TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
- // n_cell and n_output will be the same size when there is no projection.
- const int n_fw_cell = fw_input_to_output_weights->dims->data[0];
- const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool fw_use_cifg = (fw_input_to_input_weights == nullptr);
- const bool fw_use_peephole = (fw_cell_to_output_weights != nullptr);
-
- // Index the scratch buffer pointers to the global scratch buffer.
+ // Temporary tensors.
TfLiteTensor* fw_scratch_buffer =
- &context->tensors[node->temporaries->data[0]];
- float* fw_input_gate_scratch = nullptr;
- float* fw_cell_scratch = nullptr;
- float* fw_forget_gate_scratch = nullptr;
- float* fw_output_gate_scratch = nullptr;
- if (fw_use_cifg) {
- fw_cell_scratch = fw_scratch_buffer->data.f;
- fw_forget_gate_scratch = fw_scratch_buffer->data.f + n_fw_cell * n_batch;
- fw_output_gate_scratch =
- fw_scratch_buffer->data.f + 2 * n_fw_cell * n_batch;
- } else {
- fw_input_gate_scratch = fw_scratch_buffer->data.f;
- fw_cell_scratch = fw_scratch_buffer->data.f + n_fw_cell * n_batch;
- fw_forget_gate_scratch =
- fw_scratch_buffer->data.f + 2 * n_fw_cell * n_batch;
- fw_output_gate_scratch =
- fw_scratch_buffer->data.f + 3 * n_fw_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- const float* fw_input_to_input_weights_ptr =
- (fw_use_cifg) ? nullptr : fw_input_to_input_weights->data.f;
- const float* fw_recurrent_to_input_weights_ptr =
- (fw_use_cifg) ? nullptr : fw_recurrent_to_input_weights->data.f;
- const float* fw_input_gate_bias_ptr =
- (fw_use_cifg) ? nullptr : fw_input_gate_bias->data.f;
- const float* fw_cell_to_input_weights_ptr =
- (fw_use_peephole && !fw_use_cifg) ? fw_cell_to_input_weights->data.f
- : nullptr;
- const float* fw_cell_to_forget_weights_ptr =
- (fw_use_peephole) ? fw_cell_to_forget_weights->data.f : nullptr;
- const float* fw_cell_to_output_weights_ptr =
- (fw_use_peephole) ? fw_cell_to_output_weights->data.f : nullptr;
- const float* fw_projection_weights_ptr = (fw_projection_weights == nullptr)
- ? nullptr
- : fw_projection_weights->data.f;
- const float* fw_projection_bias_ptr =
- (fw_projection_bias == nullptr) ? nullptr : fw_projection_bias->data.f;
-
- // Loop through the sequence.
- for (int t = 0; t < max_time; t++) {
- const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
- float* output_ptr_time = fw_output->data.f + t * n_batch * n_fw_output;
-
- kernel_utils::LstmStep(
- input_ptr_batch, fw_input_to_input_weights_ptr,
- fw_input_to_forget_weights->data.f, fw_input_to_cell_weights->data.f,
- fw_input_to_output_weights->data.f, fw_recurrent_to_input_weights_ptr,
- fw_recurrent_to_forget_weights->data.f,
- fw_recurrent_to_cell_weights->data.f,
- fw_recurrent_to_output_weights->data.f, fw_cell_to_input_weights_ptr,
- fw_cell_to_forget_weights_ptr, fw_cell_to_output_weights_ptr,
- fw_input_gate_bias_ptr, fw_forget_gate_bias->data.f,
- fw_cell_bias->data.f, fw_output_gate_bias->data.f,
- fw_projection_weights_ptr, fw_projection_bias_ptr, params, n_batch,
- n_fw_cell, n_input, n_fw_output, fw_output_state->data.f,
- fw_cell_state->data.f, fw_input_gate_scratch, fw_forget_gate_scratch,
- fw_cell_scratch, fw_output_gate_scratch, output_ptr_time);
- }
-
- // n_cell and n_output will be the same size when there is no projection.
- const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
- const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool bw_use_cifg = (bw_input_to_input_weights == nullptr);
- const bool bw_use_peephole = (bw_cell_to_output_weights != nullptr);
-
- // Index the scratch buffer pointers to the global scratch buffer.
+ GetTemporary(context, node, kFwScratchBuffer);
TfLiteTensor* bw_scratch_buffer =
- &context->tensors[node->temporaries->data[1]];
- float* bw_input_gate_scratch = nullptr;
- float* bw_cell_scratch = nullptr;
- float* bw_forget_gate_scratch = nullptr;
- float* bw_output_gate_scratch = nullptr;
- if (bw_use_cifg) {
- bw_cell_scratch = bw_scratch_buffer->data.f;
- bw_forget_gate_scratch = bw_scratch_buffer->data.f + n_bw_cell * n_batch;
- bw_output_gate_scratch =
- bw_scratch_buffer->data.f + 2 * n_bw_cell * n_batch;
- } else {
- bw_input_gate_scratch = bw_scratch_buffer->data.f;
- bw_cell_scratch = bw_scratch_buffer->data.f + n_bw_cell * n_batch;
- bw_forget_gate_scratch =
- bw_scratch_buffer->data.f + 2 * n_bw_cell * n_batch;
- bw_output_gate_scratch =
- bw_scratch_buffer->data.f + 3 * n_bw_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- const float* bw_input_to_input_weights_ptr =
- (bw_use_cifg) ? nullptr : bw_input_to_input_weights->data.f;
- const float* bw_recurrent_to_input_weights_ptr =
- (bw_use_cifg) ? nullptr : bw_recurrent_to_input_weights->data.f;
- const float* bw_input_gate_bias_ptr =
- (bw_use_cifg) ? nullptr : bw_input_gate_bias->data.f;
- const float* bw_cell_to_input_weights_ptr =
- (bw_use_peephole && !bw_use_cifg) ? bw_cell_to_input_weights->data.f
- : nullptr;
- const float* bw_cell_to_forget_weights_ptr =
- (bw_use_peephole) ? bw_cell_to_forget_weights->data.f : nullptr;
- const float* bw_cell_to_output_weights_ptr =
- (bw_use_peephole) ? bw_cell_to_output_weights->data.f : nullptr;
- const float* bw_projection_weights_ptr = (bw_projection_weights == nullptr)
- ? nullptr
- : bw_projection_weights->data.f;
- const float* bw_projection_bias_ptr =
- (bw_projection_bias == nullptr) ? nullptr : bw_projection_bias->data.f;
-
- // Loop through the sequence backwards.
- for (int t = max_time - 1; t >= 0; t--) {
- const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
- float* output_ptr_time = bw_output->data.f + t * n_batch * n_bw_output;
-
- kernel_utils::LstmStep(
- input_ptr_batch, bw_input_to_input_weights_ptr,
- bw_input_to_forget_weights->data.f, bw_input_to_cell_weights->data.f,
- bw_input_to_output_weights->data.f, bw_recurrent_to_input_weights_ptr,
- bw_recurrent_to_forget_weights->data.f,
- bw_recurrent_to_cell_weights->data.f,
- bw_recurrent_to_output_weights->data.f, bw_cell_to_input_weights_ptr,
- bw_cell_to_forget_weights_ptr, bw_cell_to_output_weights_ptr,
- bw_input_gate_bias_ptr, bw_forget_gate_bias->data.f,
- bw_cell_bias->data.f, bw_output_gate_bias->data.f,
- bw_projection_weights_ptr, bw_projection_bias_ptr, params, n_batch,
- n_bw_cell, n_input, n_bw_output, bw_output_state->data.f,
- bw_cell_state->data.f, bw_input_gate_scratch, bw_forget_gate_scratch,
- bw_cell_scratch, bw_output_gate_scratch, output_ptr_time);
+ GetTemporary(context, node, kBwScratchBuffer);
+
+ // (Optional) auxiliary inputs.
+ const TfLiteTensor* aux_input =
+ GetOptionalInputTensor(context, node, kAuxInputTensor);
+ const TfLiteTensor* fw_aux_input_to_input_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor);
+ const TfLiteTensor* fw_aux_input_to_forget_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor);
+ const TfLiteTensor* fw_aux_input_to_cell_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor);
+ const TfLiteTensor* fw_aux_input_to_output_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_input_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_forget_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_cell_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_output_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);
+
+ switch (fw_input_to_output_weights->type) {
+ case kTfLiteFloat32: {
+ TfLiteStatus fw_pass_status = EvalFloat(
+ input, fw_input_to_input_weights, fw_input_to_forget_weights,
+ fw_input_to_cell_weights, fw_input_to_output_weights,
+ fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
+ fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
+ fw_cell_to_input_weights, fw_cell_to_forget_weights,
+ fw_cell_to_output_weights, aux_input, fw_aux_input_to_input_weights,
+ fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
+ fw_aux_input_to_output_weights, fw_input_gate_bias,
+ fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
+ fw_projection_weights, fw_projection_bias, params,
+ /*forward_sequence=*/true, fw_scratch_buffer, fw_activation_state,
+ fw_cell_state, fw_output);
+ TF_LITE_ENSURE_OK(context, fw_pass_status);
+
+ TfLiteStatus bw_pass_status = EvalFloat(
+ input, bw_input_to_input_weights, bw_input_to_forget_weights,
+ bw_input_to_cell_weights, bw_input_to_output_weights,
+ bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
+ bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
+ bw_cell_to_input_weights, bw_cell_to_forget_weights,
+ bw_cell_to_output_weights, aux_input, bw_aux_input_to_input_weights,
+ bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights,
+ bw_aux_input_to_output_weights, bw_input_gate_bias,
+ bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
+ bw_projection_weights, bw_projection_bias, params,
+ /*forward_sequence=*/false, bw_scratch_buffer, bw_activation_state,
+ bw_cell_state, bw_output);
+ TF_LITE_ENSURE_OK(context, bw_pass_status);
+ return kTfLiteOk;
+ }
+ case kTfLiteUInt8: {
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, kInputQuantized);
+ TfLiteTensor* aux_input_quantized =
+ GetTemporary(context, node, kAuxInputQuantized);
+ TfLiteTensor* fw_activation_state_quantized =
+ GetTemporary(context, node, kFwActivationStateQuantized);
+ TfLiteTensor* bw_activation_state_quantized =
+ GetTemporary(context, node, kBwActivationStateQuantized);
+ TfLiteTensor* fw_cell_state_quantized =
+ GetTemporary(context, node, kFwCellStateQuantized);
+ TfLiteTensor* bw_cell_state_quantized =
+ GetTemporary(context, node, kBwCellStateQuantized);
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, kScalingFactors);
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, kProductScalingFactors);
+ TfLiteTensor* recovered_cell_weights =
+ GetTemporary(context, node, kRecoveredCellWeights);
+
+ TfLiteStatus fw_pass_status = EvalHybrid(
+ input, fw_input_to_input_weights, fw_input_to_forget_weights,
+ fw_input_to_cell_weights, fw_input_to_output_weights,
+ fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
+ fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
+ fw_cell_to_input_weights, fw_cell_to_forget_weights,
+ fw_cell_to_output_weights, aux_input, fw_aux_input_to_input_weights,
+ fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
+ fw_aux_input_to_output_weights, fw_input_gate_bias,
+ fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
+ fw_projection_weights, fw_projection_bias, params,
+ /*forward_sequence=*/true, fw_scratch_buffer, scaling_factors,
+ prod_scaling_factors, recovered_cell_weights, input_quantized,
+ aux_input_quantized, fw_activation_state_quantized,
+ fw_cell_state_quantized, fw_activation_state, fw_cell_state,
+ fw_output);
+ TF_LITE_ENSURE_OK(context, fw_pass_status);
+
+ TfLiteStatus bw_pass_status = EvalHybrid(
+ input, bw_input_to_input_weights, bw_input_to_forget_weights,
+ bw_input_to_cell_weights, bw_input_to_output_weights,
+ bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
+ bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
+ bw_cell_to_input_weights, bw_cell_to_forget_weights,
+ bw_cell_to_output_weights, aux_input, bw_aux_input_to_input_weights,
+ bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights,
+ bw_aux_input_to_output_weights, bw_input_gate_bias,
+ bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
+ bw_projection_weights, bw_projection_bias, params,
+ /*forward_sequence=*/false, bw_scratch_buffer, scaling_factors,
+ prod_scaling_factors, recovered_cell_weights, input_quantized,
+ aux_input_quantized, bw_activation_state_quantized,
+ bw_cell_state_quantized, bw_activation_state, bw_cell_state,
+ bw_output);
+ TF_LITE_ENSURE_OK(context, bw_pass_status);
+ return kTfLiteOk;
+ }
+ default:
+ context->ReportError(context, "Type %d is not currently supported.",
+ fw_input_to_output_weights->type);
+ return kTfLiteError;
}
-
- // Backward step.
return kTfLiteOk;
}
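
EvalHybrid keeps the weights in int8 and the activations in float, which is why the quantization temporaries above exist: each batch of activations is symmetrically quantized with a per-batch scaling factor, the integer products are accumulated, and the result is rescaled by the product of the weight and input scales. A simplified illustration of that idea, not the optimized tensor_utils implementation the kernel actually uses:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Quantizes v symmetrically to int8 and returns the scaling factor.
float QuantizeSymmetric(const std::vector<float>& v, std::vector<int8_t>* q) {
  float max_abs = 0.0f;
  for (float x : v) max_abs = std::max(max_abs, std::fabs(x));
  const float scale = (max_abs > 0.0f) ? max_abs / 127.0f : 1.0f;
  q->resize(v.size());
  for (std::size_t i = 0; i < v.size(); ++i) {
    (*q)[i] = static_cast<int8_t>(std::lround(v[i] / scale));
  }
  return scale;
}

// Hybrid dot product: integer accumulation, float rescaling.
float HybridDot(const std::vector<int8_t>& weights, float weight_scale,
                const std::vector<float>& input) {
  std::vector<int8_t> input_quantized;
  const float input_scale = QuantizeSymmetric(input, &input_quantized);
  int32_t acc = 0;
  for (std::size_t i = 0; i < weights.size(); ++i) {
    acc += weights[i] * input_quantized[i];
  }
  return acc * weight_scale * input_scale;
}
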
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
index a18e1bce34..74ba8021c2 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
@@ -102,10 +102,6 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
fw_projection_bias_ = AddNullInput();
}
- fw_output_state_ = AddOutput(TensorType_FLOAT32);
- fw_cell_state_ = AddOutput(TensorType_FLOAT32);
- fw_output_ = AddOutput(TensorType_FLOAT32);
-
if (use_cifg) {
bw_input_to_input_weights_ = AddNullInput();
} else {
@@ -161,10 +157,36 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
bw_projection_bias_ = AddNullInput();
}
- bw_output_state_ = AddOutput(TensorType_FLOAT32);
- bw_cell_state_ = AddOutput(TensorType_FLOAT32);
+ // Add the two forward input state tensors.
+ fw_input_activation_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_fw_output_ * n_batch_}},
+ /*is_variable=*/true);
+ fw_input_cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_fw_cell_ * n_batch_}},
+ /*is_variable=*/true);
+
+ // Add the two backward input state tensors.
+ bw_input_activation_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_bw_output_ * n_batch_}},
+ /*is_variable=*/true);
+ bw_input_cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_bw_cell_ * n_batch_}},
+ /*is_variable=*/true);
+
+ fw_output_ = AddOutput(TensorType_FLOAT32);
+
bw_output_ = AddOutput(TensorType_FLOAT32);
+ aux_input_ = AddNullInput();
+ fw_aux_input_to_input_weights_ = AddNullInput();
+ fw_aux_input_to_forget_weights_ = AddNullInput();
+ fw_aux_input_to_cell_weights_ = AddNullInput();
+ fw_aux_input_to_output_weights_ = AddNullInput();
+ bw_aux_input_to_input_weights_ = AddNullInput();
+ bw_aux_input_to_forget_weights_ = AddNullInput();
+ bw_aux_input_to_cell_weights_ = AddNullInput();
+ bw_aux_input_to_output_weights_ = AddNullInput();
+
SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
BuiltinOptions_LSTMOptions,
CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
@@ -259,26 +281,6 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
PopulateTensor(bw_projection_bias_, f);
}
- void ResetFwOutputAndCellStates() {
- const int zero_buffer_size = n_fw_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(fw_output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- PopulateTensor(fw_cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
- void ResetBwOutputAndCellStates() {
- const int zero_buffer_size = n_bw_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(bw_output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- PopulateTensor(bw_cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
void SetInput(int offset, float* begin, float* end) {
PopulateTensor(input_, offset, begin, end);
}
@@ -340,13 +342,23 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
int bw_projection_weights_;
int bw_projection_bias_;
- int fw_output_;
- int fw_output_state_;
- int fw_cell_state_;
+ int fw_input_activation_state_;
+ int fw_input_cell_state_;
+ int bw_input_activation_state_;
+ int bw_input_cell_state_;
+ int fw_output_;
int bw_output_;
- int bw_output_state_;
- int bw_cell_state_;
+
+ int aux_input_;
+ int fw_aux_input_to_input_weights_;
+ int fw_aux_input_to_forget_weights_;
+ int fw_aux_input_to_cell_weights_;
+ int fw_aux_input_to_output_weights_;
+ int bw_aux_input_to_input_weights_;
+ int bw_aux_input_to_forget_weights_;
+ int bw_aux_input_to_cell_weights_;
+ int bw_aux_input_to_output_weights_;
int n_batch_;
int n_input_;
@@ -417,6 +429,22 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
});
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
@@ -474,10 +502,6 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
-0.0332076, 0.123838, 0.309777, -0.17621,
-0.0490733, 0.0739237, 0.067706, -0.0208124};
- // Resetting cell_state and output_state
- lstm.ResetFwOutputAndCellStates();
- lstm.ResetBwOutputAndCellStates();
-
float* batch0_start = lstm_input;
float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
@@ -500,34 +524,161 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end);
EXPECT_THAT(lstm.GetBwOutput(),
ElementsAreArray(ArrayFloatNear(bw_expected)));
+}
+
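
The reversed-input tests below feed the time-reversed sequence and compare the forward goldens against the backward output, and vice versa. Building the expected vectors by inserting each timestep's golden slice at the front reverses the time order while leaving each slice intact; a standalone sketch of that construction (ReverseTimeMajor is a hypothetical name):

#include <vector>

std::vector<float> ReverseTimeMajor(const float* golden, int num_steps,
                                    int step_width) {
  std::vector<float> expected;
  for (int s = 0; s < num_steps; ++s) {
    // Prepending reverses the order of the slices, not their contents.
    expected.insert(expected.begin(), golden + s * step_width,
                    golden + (s + 1) * step_width);
  }
  return expected;
}
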
+TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+ const int sequence_length = 3;
+
+ BidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
+ /*use_peephole=*/false, /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
+
+ // Forward cell
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ // Backward cell
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
+ });
+
+ lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
+ -0.34550029, 0.04266912, -0.15680569,
+ -0.34856534, 0.43890524});
+
+ lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
+ -0.20583314, 0.44344562, 0.22077113,
+ -0.29909778});
+
+ lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
+ -0.31343272, -0.40032279, 0.44781327,
+ 0.01387155, -0.35593212});
+
+ lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
+ 0.40525138, 0.44272184, 0.03897077, -0.1556896,
+ 0.19487578});
+
+ lstm.SetInputGateBias({0., 0., 0., 0.});
+
+ lstm.SetCellBias({0., 0., 0., 0.});
+
+ lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+ lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+ lstm.SetRecurrentToInputWeights(
+ {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
+ -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
+ -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
+ lstm.SetRecurrentToCellWeights(
+ {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
+ -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
+ -0.46367589, 0.26016325, -0.03894562, -0.16368064});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
+ -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
+ 0.28053468, 0.01560611, -0.20127171, -0.01140004});
+
+ lstm.SetRecurrentToOutputWeights(
+ {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
+ 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
+ -0.51818722, -0.15390486, 0.0468148, 0.39922136});
+
+ // The input should have n_input * sequence_length values.
// Check reversed inputs.
static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.};
+ static float lstm_fw_golden_output[] = {
+ -0.02973187, 0.1229473, 0.20885126, -0.15358765,
+ -0.03716109, 0.12507336, 0.41193449, -0.20860538,
+ -0.15053082, 0.09120187, 0.24278517, -0.12222792};
+ static float lstm_bw_golden_output[] = {
+ -0.0806187, 0.139077, 0.400476, -0.197842, -0.0332076, 0.123838,
+ 0.309777, -0.17621, -0.0490733, 0.0739237, 0.067706, -0.0208124};
- // Resetting cell_state and output_state
- lstm.ResetFwOutputAndCellStates();
- lstm.ResetBwOutputAndCellStates();
-
- batch0_start = lstm_input_reversed;
- batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
+ float* batch0_start = lstm_input_reversed;
+ float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
lstm.SetInput(0, batch0_start, batch0_end);
lstm.Invoke();
- fw_expected.clear();
+ std::vector<float> fw_expected;
for (int s = 0; s < lstm.sequence_length(); s++) {
- fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs();
- fw_golden_end = fw_golden_start + lstm.num_fw_outputs();
+ float* fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs();
+ float* fw_golden_end = fw_golden_start + lstm.num_fw_outputs();
fw_expected.insert(fw_expected.begin(), fw_golden_start, fw_golden_end);
}
EXPECT_THAT(lstm.GetBwOutput(),
ElementsAreArray(ArrayFloatNear(fw_expected)));
- bw_expected.clear();
+ std::vector<float> bw_expected;
for (int s = 0; s < lstm.sequence_length(); s++) {
- bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs();
- bw_golden_end = bw_golden_start + lstm.num_bw_outputs();
+ float* bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs();
+ float* bw_golden_end = bw_golden_start + lstm.num_bw_outputs();
bw_expected.insert(bw_expected.begin(), bw_golden_start, bw_golden_end);
}
EXPECT_THAT(lstm.GetFwOutput(),
@@ -592,6 +743,22 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
});
lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
@@ -642,10 +809,6 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
-0.401685, -0.0232794, 0.288642, -0.123074, -0.42915, -0.00871577,
0.20912, -0.103567, -0.166398, -0.00486649, 0.0697471, -0.0537578};
- // Resetting cell_state and output_state
- lstm.ResetFwOutputAndCellStates();
- lstm.ResetBwOutputAndCellStates();
-
float* batch0_start = lstm_input;
float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
@@ -668,34 +831,153 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end);
EXPECT_THAT(lstm.GetBwOutput(),
ElementsAreArray(ArrayFloatNear(bw_expected)));
+}
- // Check reversed inputs.
- static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.};
+TEST(LSTMOpTest,
+ BlackBoxTestWithCifgWithPeepholeNoProjectionNoClippingReversed) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+ const int sequence_length = 3;
+
+ BidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true,
+ /*use_peephole=*/true, /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
+
+ {0, 0}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {0, 0}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {0}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ {0, 0}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {0, 0}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
- // Resetting cell_state and output_state
- lstm.ResetFwOutputAndCellStates();
- lstm.ResetBwOutputAndCellStates();
+ {0}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
+ });
+
+ lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
+ 0.04717243, 0.48944736, -0.38535351,
+ -0.17212132});
+
+ lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988,
+ -0.3633365, -0.22755712, 0.28253698, 0.24407166,
+ 0.33826375});
+
+ lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593,
+ -0.09426838, -0.44257352, 0.54939759,
+ 0.01533556, 0.42751634});
+
+ lstm.SetCellBias({0., 0., 0., 0.});
+
+ lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+ lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+ lstm.SetRecurrentToCellWeights(
+ {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
+ 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
+ 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
+ 0.21193194});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
+ 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
+ -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349});
+
+ lstm.SetRecurrentToOutputWeights(
+ {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908,
+ -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835,
+ 0.50248802, 0.26114327, -0.43736315, 0.33149987});
- batch0_start = lstm_input_reversed;
- batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
+ lstm.SetCellToForgetWeights(
+ {0.47485286, -0.51955009, -0.24458408, 0.31544167});
+ lstm.SetCellToOutputWeights(
+ {-0.17135078, 0.82760304, 0.85573703, -0.77109635});
+
+ static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.};
+ static float lstm_fw_golden_output[] = {
+ -0.36444446, -0.00352185, 0.12886585, -0.05163646,
+ -0.42312205, -0.01218222, 0.24201041, -0.08124574,
+ -0.358325, -0.04621704, 0.21641694, -0.06471302};
+ static float lstm_bw_golden_output[] = {
+ -0.401685, -0.0232794, 0.288642, -0.123074, -0.42915, -0.00871577,
+ 0.20912, -0.103567, -0.166398, -0.00486649, 0.0697471, -0.0537578};
+
+ float* batch0_start = lstm_input_reversed;
+ float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
lstm.SetInput(0, batch0_start, batch0_end);
lstm.Invoke();
- fw_expected.clear();
+ std::vector<float> fw_expected;
for (int s = 0; s < lstm.sequence_length(); s++) {
- fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs();
- fw_golden_end = fw_golden_start + lstm.num_fw_outputs();
+ float* fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs();
+ float* fw_golden_end = fw_golden_start + lstm.num_fw_outputs();
fw_expected.insert(fw_expected.begin(), fw_golden_start, fw_golden_end);
}
EXPECT_THAT(lstm.GetBwOutput(),
ElementsAreArray(ArrayFloatNear(fw_expected)));
- bw_expected.clear();
+ std::vector<float> bw_expected;
for (int s = 0; s < lstm.sequence_length(); s++) {
- bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs();
- bw_golden_end = bw_golden_start + lstm.num_bw_outputs();
+ float* bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs();
+ float* bw_golden_end = bw_golden_start + lstm.num_bw_outputs();
bw_expected.insert(bw_expected.begin(), bw_golden_start, bw_golden_end);
}
EXPECT_THAT(lstm.GetFwOutput(),
@@ -759,6 +1041,22 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
});
lstm.SetInputToInputWeights(
@@ -1343,10 +1641,6 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
0.065133, 0.024321, 0.038473, 0.062438
}};
- // Resetting cell_state and output_state
- lstm.ResetFwOutputAndCellStates();
- lstm.ResetBwOutputAndCellStates();
-
for (int i = 0; i < lstm.sequence_length(); i++) {
float* batch0_start = lstm_input[0] + i * lstm.num_inputs();
float* batch0_end = batch0_start + lstm.num_inputs();
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
index 517309a226..2f896c5289 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
@@ -19,10 +19,11 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -35,34 +36,79 @@ constexpr int kInputTensor = 0;
constexpr int kFwWeightsTensor = 1;
constexpr int kFwRecurrentWeightsTensor = 2;
constexpr int kFwBiasTensor = 3;
-constexpr int kBwWeightsTensor = 4;
-constexpr int kBwRecurrentWeightsTensor = 5;
-constexpr int kBwBiasTensor = 6;
-// State and output tensors.
-constexpr int kFwHiddenStateTensor = 0;
-constexpr int kFwOutputTensor = 1;
-constexpr int kBwHiddenStateTensor = 2;
-constexpr int kBwOutputTensor = 3;
+constexpr int kFwHiddenStateTensor = 4;
+constexpr int kBwWeightsTensor = 5;
+constexpr int kBwRecurrentWeightsTensor = 6;
+constexpr int kBwBiasTensor = 7;
+constexpr int kBwHiddenStateTensor = 8;
+// Auxiliary inputs.
+constexpr int kAuxInputTensor = 9; // Optional.
+constexpr int kFwAuxWeightsTensor = 10; // Optional.
+constexpr int kBwAuxWeightsTensor = 11; // Optional.
+// Output tensors.
+constexpr int kFwOutputTensor = 0;
+constexpr int kBwOutputTensor = 1;
+
+// Temporary tensors.
+enum TemporaryTensor {
+ kInputQuantized = 0,
+ kFwHiddenStateQuantized = 1,
+ kBwHiddenStateQuantized = 2,
+ kScalingFactors = 3,
+ kAuxInputQuantized = 4,
+ kNumTemporaryTensors = 5
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* scratch_tensor_index = new int;
+ context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
+ return scratch_tensor_index;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<int*>(buffer);
+}
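
Init and Free bracket the node's lifetime: the runtime stores Init's return value in node->user_data (here the base index of the tensors added via context->AddTensors) and hands it back to Free for cleanup. A sketch of how these hooks are typically wired up, assuming the usual TFLite Register_* accessor and the Prepare/Eval functions defined in this file:

TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN() {
  // Field order follows TfLiteRegistration: init, free, prepare, invoke.
  static TfLiteRegistration r = {Init, Free, Prepare, Eval};
  return &r;
}
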
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 7);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 4);
-
- TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
- TfLiteTensor* fw_input_weights =
- &context->tensors[node->inputs->data[kFwWeightsTensor]];
- TfLiteTensor* fw_recurrent_weights =
- &context->tensors[node->inputs->data[kFwRecurrentWeightsTensor]];
- TfLiteTensor* fw_bias = &context->tensors[node->inputs->data[kFwBiasTensor]];
- TfLiteTensor* bw_input_weights =
- &context->tensors[node->inputs->data[kBwWeightsTensor]];
- TfLiteTensor* bw_recurrent_weights =
- &context->tensors[node->inputs->data[kBwRecurrentWeightsTensor]];
- TfLiteTensor* bw_bias = &context->tensors[node->inputs->data[kBwBiasTensor]];
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 12);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* fw_input_weights =
+ GetInput(context, node, kFwWeightsTensor);
+ const TfLiteTensor* fw_recurrent_weights =
+ GetInput(context, node, kFwRecurrentWeightsTensor);
+ const TfLiteTensor* fw_bias = GetInput(context, node, kFwBiasTensor);
+ const TfLiteTensor* fw_hidden_state =
+ GetInput(context, node, kFwHiddenStateTensor);
+ const TfLiteTensor* bw_input_weights =
+ GetInput(context, node, kBwWeightsTensor);
+ const TfLiteTensor* bw_recurrent_weights =
+ GetInput(context, node, kBwRecurrentWeightsTensor);
+ const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor);
+ const TfLiteTensor* bw_hidden_state =
+ GetInput(context, node, kBwHiddenStateTensor);
+
+ const TfLiteTensor* aux_input =
+ GetOptionalInputTensor(context, node, kAuxInputTensor);
+ const TfLiteTensor* fw_aux_input_weights =
+ GetOptionalInputTensor(context, node, kFwAuxWeightsTensor);
+ const TfLiteTensor* bw_aux_input_weights =
+ GetOptionalInputTensor(context, node, kBwAuxWeightsTensor);
+
+ const bool aux_inputs_all_or_none =
+ ((aux_input != nullptr) && (fw_aux_input_weights != nullptr) &&
+ (bw_aux_input_weights != nullptr)) ||
+ ((aux_input == nullptr) && (fw_aux_input_weights == nullptr) &&
+ (bw_aux_input_weights == nullptr));
+ TF_LITE_ENSURE(context, aux_inputs_all_or_none);
+ const bool has_aux_input = (aux_input != nullptr);
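
The predicate above enforces that the auxiliary inputs are either all present or all absent. A hypothetical helper making the rule explicit for an arbitrary group of optional tensors (AllPresentOrAllAbsent is not part of this file):

#include <initializer_list>

bool AllPresentOrAllAbsent(std::initializer_list<const TfLiteTensor*> group) {
  bool any = false;
  bool all = true;
  for (const TfLiteTensor* t : group) {
    any = any || (t != nullptr);
    all = all && (t != nullptr);
  }
  return all || !any;  // true when the group is complete or empty
}

// Usage sketch:
//   TF_LITE_ENSURE(context, AllPresentOrAllAbsent(
//       {aux_input, fw_aux_input_weights, bw_aux_input_weights}));
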
// Check that all tensor parameters are consistent with one another and with
// the input configuration.
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+
const int batch_size = input->dims->data[0];
const int max_time = input->dims->data[1];
const int fw_num_units = fw_input_weights->dims->data[0];
@@ -75,32 +121,116 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
fw_bias->dims->data[0]);
TF_LITE_ASSERT_EQ(bw_recurrent_weights->dims->data[1],
bw_bias->dims->data[0]);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(fw_hidden_state), 2);
+ TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[0], batch_size);
+ TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[1], fw_num_units);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(bw_hidden_state), 2);
+ TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[0], batch_size);
+ TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[1], bw_num_units);
- TfLiteTensor* fw_output =
- &context->tensors[node->outputs->data[kFwOutputTensor]];
- TfLiteTensor* bw_output =
- &context->tensors[node->outputs->data[kBwOutputTensor]];
+ if (has_aux_input) {
+ // Check that aux_input has the same dimensions (except last) as the input.
+ TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]);
+ TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]);
+ // Check that aux_input_weights has the same dimensions (except last) as
+ // the input_weights.
+ TF_LITE_ASSERT_EQ(fw_aux_input_weights->dims->data[0], fw_num_units);
+ TF_LITE_ASSERT_EQ(bw_aux_input_weights->dims->data[0], bw_num_units);
+ TF_LITE_ASSERT_EQ(aux_input->dims->data[2],
+ fw_aux_input_weights->dims->data[1]);
+ TF_LITE_ASSERT_EQ(aux_input->dims->data[2],
+ bw_aux_input_weights->dims->data[1]);
+ }
- // Resize hidden states.
- TfLiteIntArray* fw_hidden_state_size_array = TfLiteIntArrayCreate(2);
- fw_hidden_state_size_array->data[0] = batch_size;
- fw_hidden_state_size_array->data[1] = fw_num_units;
- TfLiteTensor* fw_hidden_state =
- &context->tensors[node->outputs->data[kFwHiddenStateTensor]];
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_hidden_state,
- fw_hidden_state_size_array));
+ TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
+ TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
- TfLiteIntArray* bw_hidden_state_size_array = TfLiteIntArrayCreate(2);
- bw_hidden_state_size_array->data[0] = batch_size;
- bw_hidden_state_size_array->data[1] = fw_num_units;
- TfLiteTensor* bw_hidden_state =
- &context->tensors[node->outputs->data[kBwHiddenStateTensor]];
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_hidden_state,
- bw_hidden_state_size_array));
+ const bool is_hybrid_op =
+ (fw_input_weights->type == kTfLiteUInt8 && input->type == kTfLiteFloat32);
+
+ if (is_hybrid_op) {
+ int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+
+ TfLiteIntArrayFree(node->temporaries);
+ if (has_aux_input) {
+ node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
+ } else {
+ // No need to create a temporary tensor for the non-existent aux_input.
+ node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors - 1);
+ }
+
+ node->temporaries->data[kInputQuantized] =
+ *scratch_tensor_index + kInputQuantized;
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, kInputQuantized);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+
+ node->temporaries->data[kFwHiddenStateQuantized] =
+ *scratch_tensor_index + kFwHiddenStateQuantized;
+ TfLiteTensor* fw_hidden_state_quantized =
+ GetTemporary(context, node, kFwHiddenStateQuantized);
+ fw_hidden_state_quantized->type = kTfLiteUInt8;
+ fw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(fw_hidden_state_quantized->dims,
+ fw_hidden_state->dims)) {
+ TfLiteIntArray* fw_hidden_state_quantized_size =
+ TfLiteIntArrayCopy(fw_hidden_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, fw_hidden_state_quantized,
+ fw_hidden_state_quantized_size));
+ }
+
+ node->temporaries->data[kBwHiddenStateQuantized] =
+ *scratch_tensor_index + kBwHiddenStateQuantized;
+ TfLiteTensor* bw_hidden_state_quantized =
+ GetTemporary(context, node, kBwHiddenStateQuantized);
+ bw_hidden_state_quantized->type = kTfLiteUInt8;
+ bw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(bw_hidden_state_quantized->dims,
+ bw_hidden_state->dims)) {
+ TfLiteIntArray* bw_hidden_state_quantized_size =
+ TfLiteIntArrayCopy(bw_hidden_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, bw_hidden_state_quantized,
+ bw_hidden_state_quantized_size));
+ }
- // Mark hidden states as a persistent tensor.
- fw_hidden_state->allocation_type = kTfLiteArenaRwPersistent;
- bw_hidden_state->allocation_type = kTfLiteArenaRwPersistent;
+ // Allocate a temporary tensor to store the quantization scaling factors.
+ node->temporaries->data[kScalingFactors] =
+ *scratch_tensor_index + kScalingFactors;
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, kScalingFactors);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = batch_size;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+
+ if (has_aux_input) {
+ node->temporaries->data[kAuxInputQuantized] =
+ *scratch_tensor_index + kAuxInputQuantized;
+ TfLiteTensor* aux_input_quantized =
+ GetTemporary(context, node, kAuxInputQuantized);
+ aux_input_quantized->type = kTfLiteUInt8;
+ aux_input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
+ TfLiteIntArray* aux_input_quantized_size =
+ TfLiteIntArrayCopy(aux_input->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, aux_input_quantized,
+ aux_input_quantized_size));
+ }
+ }
+ }
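
Each temporary above follows the same idiom: set the type and allocation_type, then resize only if the shape actually changed, since ResizeTensor reallocates in the arena. A hedged sketch of that idiom as a standalone helper (EnsureShape is hypothetical):

TfLiteStatus EnsureShape(TfLiteContext* context, TfLiteTensor* tensor,
                         const TfLiteIntArray* dims) {
  if (TfLiteIntArrayEqual(tensor->dims, dims)) return kTfLiteOk;
  // ResizeTensor takes ownership of the array it is given, hence the copy.
  return context->ResizeTensor(context, tensor, TfLiteIntArrayCopy(dims));
}
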
// Resize outputs.
TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3);
@@ -119,33 +249,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
-
- TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
- TfLiteTensor* fw_input_weights =
- &context->tensors[node->inputs->data[kFwWeightsTensor]];
- TfLiteTensor* fw_recurrent_weights =
- &context->tensors[node->inputs->data[kFwRecurrentWeightsTensor]];
- TfLiteTensor* fw_bias = &context->tensors[node->inputs->data[kFwBiasTensor]];
- TfLiteTensor* fw_hidden_state =
- &context->tensors[node->outputs->data[kFwHiddenStateTensor]];
- TfLiteTensor* fw_output =
- &context->tensors[node->outputs->data[kFwOutputTensor]];
-
- TfLiteTensor* bw_input_weights =
- &context->tensors[node->inputs->data[kBwWeightsTensor]];
- TfLiteTensor* bw_recurrent_weights =
- &context->tensors[node->inputs->data[kBwRecurrentWeightsTensor]];
- TfLiteTensor* bw_bias = &context->tensors[node->inputs->data[kBwBiasTensor]];
- TfLiteTensor* bw_hidden_state =
- &context->tensors[node->outputs->data[kBwHiddenStateTensor]];
- TfLiteTensor* bw_output =
- &context->tensors[node->outputs->data[kBwOutputTensor]];
-
+TfLiteStatus EvalFloat(
+ const TfLiteTensor* input, const TfLiteTensor* fw_input_weights,
+ const TfLiteTensor* fw_recurrent_weights, const TfLiteTensor* fw_bias,
+ const TfLiteTensor* bw_input_weights,
+ const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
+ const TfLiteTensor* aux_input, const TfLiteTensor* fw_aux_input_weights,
+ const TfLiteTensor* bw_aux_input_weights,
+ const TfLiteSequenceRNNParams* params, TfLiteTensor* fw_hidden_state,
+ TfLiteTensor* fw_output, TfLiteTensor* bw_hidden_state,
+ TfLiteTensor* bw_output) {
const int batch_size = input->dims->data[0];
const int max_time = input->dims->data[1];
const int input_size = input->dims->data[2];
+ const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
const int fw_num_units = fw_input_weights->dims->data[0];
const float* fw_bias_ptr = fw_bias->data.f;
@@ -157,6 +274,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const float* bw_input_weights_ptr = bw_input_weights->data.f;
const float* bw_recurrent_weights_ptr = bw_recurrent_weights->data.f;
+ const float* fw_aux_input_weights_ptr = (fw_aux_input_weights != nullptr)
+ ? fw_aux_input_weights->data.f
+ : nullptr;
+ const float* bw_aux_input_weights_ptr = (bw_aux_input_weights != nullptr)
+ ? bw_aux_input_weights->data.f
+ : nullptr;
+
for (int b = 0; b < batch_size; b++) {
// Forward cell.
float* fw_hidden_state_ptr_batch =
@@ -164,12 +288,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
for (int s = 0; s < max_time; s++) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + b * aux_input_size * max_time + s * aux_input_size
+ : nullptr;
float* output_ptr_batch =
fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units;
kernel_utils::RnnBatchStep(
- input_ptr_batch, fw_input_weights_ptr, fw_recurrent_weights_ptr,
- fw_bias_ptr, input_size, fw_num_units, /*batch_size=*/1,
+ input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
+ fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
+ input_size, aux_input_size, fw_num_units, /*batch_size=*/1,
params->activation, fw_hidden_state_ptr_batch, output_ptr_batch);
}
// Backward cell.
@@ -178,24 +307,208 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
for (int s = max_time - 1; s >= 0; s--) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + b * aux_input_size * max_time + s * aux_input_size
+ : nullptr;
float* output_ptr_batch =
bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units;
kernel_utils::RnnBatchStep(
- input_ptr_batch, bw_input_weights_ptr, bw_recurrent_weights_ptr,
- bw_bias_ptr, input_size, bw_num_units, /*batch_size=*/1,
+ input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
+ bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
+ input_size, aux_input_size, bw_num_units, /*batch_size=*/1,
params->activation, bw_hidden_state_ptr_batch, output_ptr_batch);
}
}
return kTfLiteOk;
}
+TfLiteStatus EvalHybrid(
+ const TfLiteTensor* input, const TfLiteTensor* fw_input_weights,
+ const TfLiteTensor* fw_recurrent_weights, const TfLiteTensor* fw_bias,
+ const TfLiteTensor* bw_input_weights,
+ const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
+ const TfLiteTensor* aux_input, const TfLiteTensor* aux_fw_input_weights,
+ const TfLiteTensor* aux_bw_input_weights,
+ const TfLiteSequenceRNNParams* params, TfLiteTensor* scaling_factors,
+ TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
+ TfLiteTensor* fw_hidden_state_quantized, TfLiteTensor* fw_hidden_state,
+ TfLiteTensor* fw_output, TfLiteTensor* bw_hidden_state_quantized,
+ TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) {
+ const int batch_size = input->dims->data[0];
+ const int max_time = input->dims->data[1];
+ const int input_size = input->dims->data[2];
+ const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
+
+ const int fw_num_units = fw_input_weights->dims->data[0];
+ const float* fw_bias_ptr = fw_bias->data.f;
+ const int8_t* fw_input_weights_ptr =
+ reinterpret_cast<const int8_t*>(fw_input_weights->data.uint8);
+ float fw_input_weights_scale = fw_input_weights->params.scale;
+ const int8_t* fw_recurrent_weights_ptr =
+ reinterpret_cast<const int8_t*>(fw_recurrent_weights->data.uint8);
+ float fw_recurrent_weights_scale = fw_recurrent_weights->params.scale;
+
+ const int bw_num_units = bw_input_weights->dims->data[0];
+ const float* bw_bias_ptr = bw_bias->data.f;
+ const int8_t* bw_input_weights_ptr =
+ reinterpret_cast<const int8_t*>(bw_input_weights->data.uint8);
+ float bw_input_weights_scale = bw_input_weights->params.scale;
+ const int8_t* bw_recurrent_weights_ptr =
+ reinterpret_cast<const int8_t*>(bw_recurrent_weights->data.uint8);
+ float bw_recurrent_weights_scale = bw_recurrent_weights->params.scale;
+
+ // Set the auxiliary pointers and scales if needed.
+ int8_t* aux_fw_input_weights_ptr = nullptr;
+ float aux_fw_input_weights_scale = 0.0f;
+ int8_t* aux_bw_input_weights_ptr = nullptr;
+ float aux_bw_input_weights_scale = 0.0f;
+ int8_t* aux_quantized_input_ptr = nullptr;
+ if (aux_input_size > 0) {
+ aux_fw_input_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_fw_input_weights->data.uint8);
+ aux_fw_input_weights_scale = aux_fw_input_weights->params.scale;
+ aux_bw_input_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_bw_input_weights->data.uint8);
+ aux_bw_input_weights_scale = aux_bw_input_weights->params.scale;
+ aux_quantized_input_ptr = reinterpret_cast<int8_t*>(aux_input_quantized->data.uint8);
+ }
+
+ // Initialize temporary storage for quantized values.
+ int8_t* quantized_input_ptr =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+ int8_t* fw_quantized_hidden_state_ptr =
+ reinterpret_cast<int8_t*>(fw_hidden_state_quantized->data.uint8);
+ int8_t* bw_quantized_hidden_state_ptr =
+ reinterpret_cast<int8_t*>(bw_hidden_state_quantized->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
+
+ for (int b = 0; b < batch_size; b++) {
+ // Forward cell.
+ float* fw_hidden_state_ptr_batch =
+ fw_hidden_state->data.f + b * fw_num_units;
+ for (int s = 0; s < max_time; s++) {
+ const float* input_ptr_batch =
+ input->data.f + b * input_size * max_time + s * input_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + b * aux_input_size * max_time + s * aux_input_size
+ : nullptr;
+ float* output_ptr_batch =
+ fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units;
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
+ aux_input_ptr_batch, aux_fw_input_weights_ptr,
+ aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
+ fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
+ fw_num_units, /*batch_size=*/1, params->activation,
+ quantized_input_ptr, aux_quantized_input_ptr,
+ fw_quantized_hidden_state_ptr, scaling_factors_ptr,
+ fw_hidden_state_ptr_batch, output_ptr_batch);
+ }
+ // Backward cell.
+ float* bw_hidden_state_ptr_batch =
+ bw_hidden_state->data.f + b * bw_num_units;
+ for (int s = max_time - 1; s >= 0; s--) {
+ const float* input_ptr_batch =
+ input->data.f + b * input_size * max_time + s * input_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + b * aux_input_size * max_time + s * aux_input_size
+ : nullptr;
+ float* output_ptr_batch =
+ bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units;
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
+ aux_input_ptr_batch, aux_bw_input_weights_ptr,
+ aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
+ bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
+ bw_num_units, /*batch_size=*/1, params->activation,
+ quantized_input_ptr, aux_quantized_input_ptr,
+ bw_quantized_hidden_state_ptr, scaling_factors_ptr,
+ bw_hidden_state_ptr_batch, output_ptr_batch);
+ }
+ }
+ return kTfLiteOk;
+}
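Each RnnBatchStep call above consumes the quantized scratch buffers together with the per-batch scaling factor. The arithmetic it rests on is the usual hybrid identity: an int8 dot product, rescaled by the product of the two symmetric scales, approximates the float dot product. A hypothetical helper to illustrate (RnnBatchStep itself fuses this with the recurrent term and the activation):

#include <cstdint>

// float_dot ~= (input_scale * weight_scale) * int8_dot, up to rounding error.
float HybridDot(const int8_t* q_input, float input_scale,
                const int8_t* q_weights, float weight_scale, int size) {
  int32_t acc = 0;
  for (int i = 0; i < size; ++i) {
    acc += static_cast<int32_t>(q_input[i]) * q_weights[i];
  }
  return static_cast<float>(acc) * input_scale * weight_scale;
}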
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const auto* params =
+ reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* fw_input_weights =
+ GetInput(context, node, kFwWeightsTensor);
+ const TfLiteTensor* fw_recurrent_weights =
+ GetInput(context, node, kFwRecurrentWeightsTensor);
+ const TfLiteTensor* fw_bias = GetInput(context, node, kFwBiasTensor);
+ const TfLiteTensor* bw_input_weights =
+ GetInput(context, node, kBwWeightsTensor);
+ const TfLiteTensor* bw_recurrent_weights =
+ GetInput(context, node, kBwRecurrentWeightsTensor);
+ const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor);
+
+ // Get auxiliary inputs.
+ const TfLiteTensor* aux_input =
+ GetOptionalInputTensor(context, node, kAuxInputTensor);
+ const TfLiteTensor* fw_aux_input_weights =
+ GetOptionalInputTensor(context, node, kFwAuxWeightsTensor);
+ const TfLiteTensor* bw_aux_input_weights =
+ GetOptionalInputTensor(context, node, kBwAuxWeightsTensor);
+
+ TfLiteTensor* fw_hidden_state =
+ GetVariableInput(context, node, kFwHiddenStateTensor);
+ TfLiteTensor* bw_hidden_state =
+ GetVariableInput(context, node, kBwHiddenStateTensor);
+
+ TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
+ TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+
+ switch (fw_input_weights->type) {
+ case kTfLiteFloat32:
+ return EvalFloat(input, fw_input_weights, fw_recurrent_weights, fw_bias,
+ bw_input_weights, bw_recurrent_weights, bw_bias,
+ aux_input, fw_aux_input_weights, bw_aux_input_weights,
+ params, fw_hidden_state, fw_output, bw_hidden_state,
+ bw_output);
+ case kTfLiteUInt8: {
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, kInputQuantized);
+ TfLiteTensor* fw_hidden_state_quantized =
+ GetTemporary(context, node, kFwHiddenStateQuantized);
+ TfLiteTensor* bw_hidden_state_quantized =
+ GetTemporary(context, node, kBwHiddenStateQuantized);
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, kScalingFactors);
+ TfLiteTensor* aux_input_quantized =
+ (aux_input != nullptr)
+ ? GetTemporary(context, node, kAuxInputQuantized)
+ : nullptr;
+
+ return EvalHybrid(input, fw_input_weights, fw_recurrent_weights, fw_bias,
+ bw_input_weights, bw_recurrent_weights, bw_bias,
+ aux_input, fw_aux_input_weights, bw_aux_input_weights,
+ params, scaling_factors, input_quantized,
+ aux_input_quantized, fw_hidden_state_quantized,
+ fw_hidden_state, fw_output, bw_hidden_state_quantized,
+ bw_hidden_state, bw_output);
+ }
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
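The switch above keys off the weight type: float weights select the float kernel, while uint8 weights select the hybrid kernel with the activations kept in float. conv.cc below spells out the same rule as an is_hybrid flag; as a standalone predicate it amounts to:

// Hybrid execution: float activations, 8-bit symmetrically quantized weights.
bool IsHybridOp(const TfLiteTensor* input, const TfLiteTensor* weights) {
  return input->type == kTfLiteFloat32 && weights->type == kTfLiteUInt8;
}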
+
} // namespace bidirectional_sequence_rnn
TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN() {
- static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
- bidirectional_sequence_rnn::Prepare,
- bidirectional_sequence_rnn::Eval};
+ static TfLiteRegistration r = {
+ bidirectional_sequence_rnn::Init, bidirectional_sequence_rnn::Free,
+ bidirectional_sequence_rnn::Prepare, bidirectional_sequence_rnn::Eval};
return &r;
}
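The registration now supplies Init and Free, which the *scratch_tensor_index dereference in Prepare depends on. Their definitions sit in an earlier hunk of this file; a plausible sketch, with kNumTemporaryTensors standing in for however many temporaries Prepare wires up, is:

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* scratch_tensor_index = new int;
  // Reserve a contiguous block of temporary tensor slots up front.
  context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
  return scratch_tensor_index;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<int*>(buffer);
}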
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
index 911b108eaa..3e34ba6196 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
@@ -664,13 +664,19 @@ class BidirectionalRNNOpModel : public SingleOpModel {
fw_weights_ = AddInput(TensorType_FLOAT32);
fw_recurrent_weights_ = AddInput(TensorType_FLOAT32);
fw_bias_ = AddInput(TensorType_FLOAT32);
- fw_hidden_state_ = AddOutput(TensorType_FLOAT32);
- fw_output_ = AddOutput(TensorType_FLOAT32);
+ fw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
bw_weights_ = AddInput(TensorType_FLOAT32);
bw_recurrent_weights_ = AddInput(TensorType_FLOAT32);
bw_bias_ = AddInput(TensorType_FLOAT32);
- bw_hidden_state_ = AddOutput(TensorType_FLOAT32);
+ bw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
+
+ aux_input_ = AddNullInput();
+ aux_fw_weights_ = AddNullInput();
+ aux_bw_weights_ = AddNullInput();
+
+ fw_output_ = AddOutput(TensorType_FLOAT32);
bw_output_ = AddOutput(TensorType_FLOAT32);
+
SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
BuiltinOptions_SequenceRNNOptions,
CreateSequenceRNNOptions(builder_, /*time_major=*/false,
@@ -681,9 +687,14 @@ class BidirectionalRNNOpModel : public SingleOpModel {
{fw_units_, input_size_}, // fw_weights
{fw_units_, fw_units_}, // fw_recurrent_weights
{fw_units_}, // fw_bias
+ {batches_, fw_units_}, // fw_hidden_state
{bw_units_, input_size_}, // bw_weights
{bw_units_, bw_units_}, // bw_recurrent_weights
- {bw_units_} // bw_bias
+ {bw_units_}, // bw_bias
+ {batches_, bw_units_}, // bw_hidden_state
+ {batches_, sequence_len_, 0}, // aux_input
+ {fw_units_, 0}, // aux_fw_weights
+ {bw_units_, 0}, // aux_bw_weights
});
}
@@ -719,19 +730,6 @@ class BidirectionalRNNOpModel : public SingleOpModel {
PopulateTensor(input_, offset, begin, end);
}
- void ResetHiddenStates() {
- const int fw_zero_buffer_size = fw_units_ * batches_;
- std::unique_ptr<float[]> fw_zero_buffer(new float[fw_zero_buffer_size]);
- memset(fw_zero_buffer.get(), 0, fw_zero_buffer_size * sizeof(float));
- PopulateTensor(fw_hidden_state_, 0, fw_zero_buffer.get(),
- fw_zero_buffer.get() + fw_zero_buffer_size);
- const int bw_zero_buffer_size = bw_units_ * batches_;
- std::unique_ptr<float[]> bw_zero_buffer(new float[bw_zero_buffer_size]);
- memset(bw_zero_buffer.get(), 0, bw_zero_buffer_size * sizeof(float));
- PopulateTensor(bw_hidden_state_, 0, bw_zero_buffer.get(),
- bw_zero_buffer.get() + bw_zero_buffer_size);
- }
-
std::vector<float> GetFwOutput() { return ExtractVector<float>(fw_output_); }
std::vector<float> GetBwOutput() { return ExtractVector<float>(bw_output_); }
@@ -753,6 +751,9 @@ class BidirectionalRNNOpModel : public SingleOpModel {
int bw_bias_;
int bw_hidden_state_;
int bw_output_;
+ int aux_input_;
+ int aux_fw_weights_;
+ int aux_bw_weights_;
int batches_;
int sequence_len_;
@@ -774,7 +775,6 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) {
rnn.SetFwRecurrentWeights(recurrent_weights);
rnn.SetBwRecurrentWeights(recurrent_weights);
- rnn.ResetHiddenStates();
const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
float* batch_start = rnn_input;
float* batch_end = batch_start + input_sequence_size;
@@ -813,8 +813,6 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
rnn.SetFwRecurrentWeights(recurrent_weights);
rnn.SetBwRecurrentWeights(recurrent_weights);
- rnn.ResetHiddenStates();
-
// Reverse inputs in each batch: in_1, in_2,..., in_k is inserted in the
// following order: [in_k,..., in_2, in_1, in_k,...,in_2, in_1].
for (int i = 0; i < rnn.sequence_len(); i++) {
@@ -880,8 +878,6 @@ TEST(BidirectionalRNNOpTest, EndToEndTest) {
rnn.SetFwRecurrentWeights(recurrent_weights);
rnn.SetBwRecurrentWeights(recurrent_weights);
- rnn.ResetHiddenStates();
-
const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
const int output_sequence_size = output_size * rnn.sequence_len();
const int num_examples = 64;
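The explicit ResetHiddenStates() calls could be deleted because the hidden states are now variable inputs (the second argument to AddInput above) and the interpreter zero-fills variable tensors as part of allocation. If a test ever needed a mid-run reset, the same effect is available through the existing PopulateTensor hook; a sketch, inside BidirectionalRNNOpModel, with sizes mirroring the shapes passed to BuildInterpreter:

// Zero out the forward hidden state between otherwise independent runs.
std::vector<float> zeros(batches_ * fw_units_, 0.0f);
PopulateTensor(fw_hidden_state_, 0, zeros.data(), zeros.data() + zeros.size());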
diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/contrib/lite/kernels/cast.cc
index 8dd48af57f..a7972140ac 100644
--- a/tensorflow/contrib/lite/kernels/cast.cc
+++ b/tensorflow/contrib/lite/kernels/cast.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include <string.h>
#include <algorithm>
#include <complex>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc
index 8b4d778332..4cd96348a2 100644
--- a/tensorflow/contrib/lite/kernels/comparisons.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc
index 605a20ac3e..25ea556d5a 100644
--- a/tensorflow/contrib/lite/kernels/concatenation.cc
+++ b/tensorflow/contrib/lite/kernels/concatenation.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index 50fe5c2e04..ab6bdaecaa 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/eigen_support.h"
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h"
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
#include "tensorflow/contrib/lite/kernels/padding.h"
@@ -60,6 +61,8 @@ struct OpData {
// memory buffers.
int im2col_id = kTensorNotAllocated;
int hwcn_weights_id = kTensorNotAllocated;
+ int input_quantized_id = kTensorNotAllocated;
+ int scaling_factors_id = kTensorNotAllocated;
TfLitePaddingValues padding;
// The scaling factor from input to output (aka the 'real multiplier') can
@@ -74,6 +77,8 @@ struct OpData {
// of the allocated temporaries.
int32_t im2col_index;
int32_t hwcn_weights_index;
+ int32_t input_quantized_index;
+ int32_t scaling_factors_index;
bool need_hwcn_weights;
bool have_weights_been_transposed;
bool need_im2col;
@@ -125,6 +130,9 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
TfLiteTensor* filter = &context->tensors[node->inputs->data[1]];
+ const bool is_hybrid =
+ (input->type == kTfLiteFloat32 && filter->type == kTfLiteUInt8);
+
int filter_width = filter->dims->data[2];
int filter_height = filter->dims->data[1];
@@ -145,8 +153,8 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
// buffer to store the results.
// This path is only used for float processing, so only create the buffer if
// we're running with that data type.
- data->need_hwcn_weights =
- (input->type == kTfLiteFloat32 && data->run_multithreaded_kernel);
+ data->need_hwcn_weights = (input->type == kTfLiteFloat32 &&
+ data->run_multithreaded_kernel && !is_hybrid);
int temporaries_count = 0;
if (data->need_im2col) {
@@ -164,6 +172,25 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
++temporaries_count;
}
+ if (is_hybrid) {
+ // Allocate tensor to store the on-the-fly quantized inputs.
+ data->input_quantized_index = temporaries_count;
+ if (data->input_quantized_id == kTensorNotAllocated) {
+ TF_LITE_ENSURE_OK(
+ context, context->AddTensors(context, 1, &data->input_quantized_id));
+ }
+ ++temporaries_count;
+
+ // Allocate tensor to store the quantization params computed during
+ // on-the-fly input quantization.
+ data->scaling_factors_index = temporaries_count;
+ if (data->scaling_factors_id == kTensorNotAllocated) {
+ TF_LITE_ENSURE_OK(
+ context, context->AddTensors(context, 1, &data->scaling_factors_id));
+ }
+ ++temporaries_count;
+ }
+
TfLiteIntArrayFree(node->temporaries);
node->temporaries = TfLiteIntArrayCreate(temporaries_count);
@@ -174,10 +201,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
- data->run_multithreaded_kernel = context->recommended_num_threads != 1;
-
- TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired(context, node));
-
bool has_bias = node->inputs->size == 3;
// Check number of inputs/outputs
TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
@@ -193,11 +216,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, input->dims->data[3], filter->dims->data[3]);
// Check types. (We assume that UINT8 refers to quantized tensors)
- TfLiteType data_type = input->type;
+ TfLiteType input_type = input->type;
TF_LITE_ENSURE(context,
- data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8);
- TF_LITE_ENSURE_EQ(context, output->type, data_type);
- TF_LITE_ENSURE_EQ(context, filter->type, data_type);
+ input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8);
+ TF_LITE_ENSURE_EQ(context, output->type, input_type);
TfLiteTensor* bias = nullptr;
@@ -207,15 +229,27 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (has_bias) {
bias = &context->tensors[node->inputs->data[2]];
- if (data_type == kTfLiteUInt8) {
+ if (input_type == kTfLiteUInt8) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
} else {
- TF_LITE_ENSURE_EQ(context, bias->type, data_type);
+ TF_LITE_ENSURE_EQ(context, bias->type, input_type);
}
TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0));
}
+ const bool is_hybrid =
+ (input->type == kTfLiteFloat32 && filter->type == kTfLiteUInt8);
+
+ data->run_multithreaded_kernel = context->recommended_num_threads != 1;
+ // Hybrid kernels don't support multithreading yet.
+ if (is_hybrid) {
+ data->run_multithreaded_kernel = false;
+ }
+
+ TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired(context, node));
+
+ int channels_in = filter->dims->data[3];
int channels_out = filter->dims->data[0];
int width = input->dims->data[2];
int height = input->dims->data[1];
@@ -250,9 +284,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, has_bias);
- // Note that quantized inference requires that all tensors have their
+ // Note that full fixed-point inference requires that all tensors have their
// parameters set. This is usually done during quantized training.
- if (data_type != kTfLiteFloat32) {
+ if (input_type != kTfLiteFloat32) {
double real_multiplier = 0.0;
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
context, input, filter, bias, output, &real_multiplier));
@@ -287,7 +321,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* im2col =
&context->tensors[node->temporaries->data[data->im2col_index]];
- im2col->type = data_type;
+ im2col->type = input->type;
+ if (is_hybrid) {
+ im2col->type = kTfLiteUInt8;
+ }
im2col->allocation_type = kTfLiteArenaRw;
auto im2col_status = context->ResizeTensor(context, im2col, im2col_size);
if (im2col_status != kTfLiteOk) return im2col_status;
@@ -307,7 +344,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* hwcn_weights =
&context->tensors[node->temporaries->data[data->hwcn_weights_index]];
- hwcn_weights->type = data_type;
+ hwcn_weights->type = input_type;
hwcn_weights->allocation_type = kTfLiteArenaRwPersistent;
auto hwcn_weights_status =
@@ -319,6 +356,36 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
data->have_weights_been_transposed = false;
}
+ if (is_hybrid) {
+ node->temporaries->data[data->input_quantized_index] =
+ data->input_quantized_id;
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, data->input_quantized_index);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+
+ node->temporaries->data[data->scaling_factors_index] =
+ data->scaling_factors_id;
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, data->scaling_factors_index);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ // Only one scale factor per batch is typically necessary. See the
+ // optimized implementation for why we allocate one factor per row of the
+ // input flattened to 2-D.
+ scaling_factors_size->data[0] = NumElements(input) / channels_in;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+ }
+
return kTfLiteOk;
}
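On the scaling-factor count above: NumElements(input) / channels_in is the number of rows of the input viewed as a 2-D matrix with channels_in columns, which for an NHWC tensor is N * H * W. A sketch of the equivalence, assuming the usual 4-D input layout:

// For input dims [N, H, W, C] with C == channels_in:
//   NumElements(input) / channels_in == N * H * W  (one scale per row)
const int n = input->dims->data[0];
const int h = input->dims->data[1];
const int w = input->dims->data[2];
const int rows = n * h * w;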
@@ -456,6 +523,60 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
}
template <KernelType kernel_type>
+void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
+ TfLiteConvParams* params, OpData* data, TfLiteTensor* input,
+ TfLiteTensor* filter, TfLiteTensor* bias, TfLiteTensor* im2col,
+ TfLiteTensor* hwcn_weights, TfLiteTensor* output) {
+ float output_activation_min, output_activation_max;
+ CalculateActivationRange(params->activation, &output_activation_min,
+ &output_activation_max);
+
+ const int input_size = NumElements(input) / SizeOfDimension(input, 0);
+ const int batch_size = SizeOfDimension(input, 0);
+
+ const TfLiteTensor* input_quantized =
+ GetTemporary(context, node, data->input_quantized_index);
+ int8_t* quantized_input_ptr_batch =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+ float* scaling_factors_ptr =
+ GetTemporary(context, node, data->scaling_factors_index)->data.f;
+
+ // Per-batch input quantization for higher accuracy.
+ for (int b = 0; b < batch_size; ++b) {
+ float unused_min, unused_max;
+ const int offset = b * input_size;
+ tensor_utils::SymmetricQuantizeFloats(
+ input->data.f + offset, input_size, quantized_input_ptr_batch + offset,
+ &unused_min, &unused_max, &scaling_factors_ptr[b]);
+ scaling_factors_ptr[b] *= filter->params.scale;
+ }
+
+ int8_t* im2col_ptr = nullptr;
+ if (im2col != nullptr) {
+ im2col_ptr = reinterpret_cast<int8_t*>(im2col->data.uint8);
+ }
+ int8_t* filter_ptr = reinterpret_cast<int8_t*>(filter->data.uint8);
+
+ switch (kernel_type) {
+ case kReference:
+ case kGenericOptimized:
+ case kMultithreadOptimized:
+ case kCblasOptimized:
+ // There is only one implementation for the hybrid kernel. Note that it
+ // does not use gemmlowp and does not support multithreading.
+ optimized_ops::HybridConv(
+ quantized_input_ptr_batch, GetTensorDims(input), filter_ptr,
+ GetTensorDims(filter), GetTensorData<float>(bias),
+ GetTensorDims(bias), params->stride_width, params->stride_height,
+ data->padding.width, data->padding.height, scaling_factors_ptr,
+ output_activation_min, output_activation_max,
+ GetTensorData<float>(output), GetTensorDims(output), im2col_ptr,
+ GetTensorDims(im2col));
+ break;
+ }
+}
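Folding filter->params.scale into scaling_factors_ptr[b] during the quantization loop is a small precomputation: it leaves HybridConv with a single multiplier per row, so each int32 accumulator comes back to float in one step. The rescale it enables, as a sketch:

// combined_scale = input_scale_for_row * filter_scale, precomputed above.
float Rescale(int32_t acc, float combined_scale, float bias) {
  return combined_scale * static_cast<float>(acc) + bias;
}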
+
+template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
@@ -484,7 +605,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// separate ops to avoid dispatch overhead here.
switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
- if (data->run_multithreaded_kernel) {
+ if (filter->type == kTfLiteUInt8) {
+ EvalHybrid<kernel_type>(context, node, params, data, input, filter,
+ bias, im2col, hwcn_weights, output);
+ } else if (data->run_multithreaded_kernel) {
EvalFloat<kernel_type>(context, node, params, data, input, filter, bias,
im2col, hwcn_weights, output);
} else {
diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc
index 98152043c9..411615aa62 100644
--- a/tensorflow/contrib/lite/kernels/conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/conv_test.cc
@@ -142,6 +142,104 @@ TEST_P(ConvolutionOpTest, SimpleTestFloat32) {
}));
}
+// This test's output is equivalent to that of SimpleTestFloat32 because we
+// split each input value across two channels, each carrying half of the
+// value, while keeping the filters for both channels identical:
+//
+// 2 * (A/2) * B = A * B, where the left-hand side corresponds to this test.
+TEST_P(ConvolutionOpTest, SimpleTestFloat32WithChannels) {
+ ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_FLOAT32, {3, 2, 2, 2}},
+ {TensorType_FLOAT32, {}});
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+ m.SetFilter({
+ 1, 1, 2, 2, 3, 3, 4, 4, // first 2x2 filter
+ -1, -1, 1, 1, -1, -1, 1, 1, // second 2x2 filter
+ -1, -1, -1, -1, 1, 1, 1, 1 // third 2x2 filter
+ });
+ m.SetBias({1, 2, 3});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 18, 2, 5, // first batch, left
+ 18, 2, 5, // first batch, right
+ 17, 4, 3, // second batch, left
+ 37, 4, 3, // second batch, right
+ }));
+}
+
+TEST_P(ConvolutionOpTest, PointwiseFloat32) {
+ ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_FLOAT32, {1, 1, 1, 2}},
+ {TensorType_FLOAT32, {}}, 1, 1);
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+
+ m.SetFilter({
+ 1, 2, // first filter
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ // First batch
+ 1.5, 1.5, 1.5, 1.5, // row = 1
+ 3., 3., 3., 3., // row = 2
+ // Second batch
+ 1.5, 3., 4.5, 6., // row = 1
+ 1.5, 3., 4.5, 6., // row = 2
+ }));
+}
+
+// TODO(alanchiao): this passes locally, but fails on the continuous build
+// system. Re-enable when the root cause is found.
+TEST_P(ConvolutionOpTest, DISABLED_PointwiseMultifilterFloat32) {
+ ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_FLOAT32, {2, 1, 1, 2}},
+ {TensorType_FLOAT32, {}}, 1, 1);
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+
+ m.SetFilter({
+ 1, 2, // first filter
+ 2, 3, // second filter
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({
+ 1.5, 2.5, 1.5, 2.5, 1.5, 2.5, 1.5, 2.5, 3., 5., 3.,
+ 5., 3., 5., 3., 5., 1.5, 2.5, 3., 5., 4.5, 7.5,
+ 6., 10., 1.5, 2.5, 3., 5., 4.5, 7.5, 6., 10.,
+ }));
+}
+
TEST_P(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) {
ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 6, 1}},
{TensorType_FLOAT32, {1, 2, 2, 1}},
@@ -624,6 +722,192 @@ TEST_P(ConvolutionOpTest, SimpleTestQuantizedWithDilation) {
ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
}
+class HybridConvolutionOpModel : public BaseConvolutionOpModel {
+ public:
+ using BaseConvolutionOpModel::BaseConvolutionOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ void SetFilter(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(filter_, f);
+ }
+
+ void SetBias(std::initializer_list<float> data) {
+ PopulateTensor(bias_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+TEST_P(ConvolutionOpTest, SimpleTestHybrid) {
+ HybridConvolutionOpModel m(
+ GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
+ {TensorType_UINT8, {3, 2, 2, 1}}, {TensorType_FLOAT32, {}});
+
+ m.SetInput({
+ // First batch
+ 1, 1, 1, 1, // row = 1
+ 2, 2, 2, 2, // row = 2
+ // Second batch
+ 1, 2, 3, 4, // row = 1
+ 1, 2, 3, 4, // row = 2
+ });
+ m.SetFilter({
+ 1, 2, 3, 4, // first 2x2 filter
+ -1, 1, -1, 1, // second 2x2 filter
+ -1, -1, 1, 1, // third 2x2 filter
+ });
+ m.SetBias({1, 2, 3});
+
+ m.Invoke();
+
+ // Example: we get 17.1577 instead of 17.
+ //
+ // Second batch:
+ // 1 2 3 4 -> 32 64 95 127 with scale factor 127/4.
+ // 1 2 3 4 32 64 95 127
+ //
+ // First filter:
+ // 1 2 -> 32 64 with scale factor of 127/4.
+ // 3 4 95 127
+ //
+ // The left half of the input gives us 16288. Multiply by (4/127)^2 for
+ // dequantization and adding 1 for the bias gives us the result.
+ //
+ // The optimized kernel converts the input into this matrix via Im2Col
+ //
+ // 1 1 2 2
+ // 1 1 2 2
+ // 1 2 1 2
+ // 3 4 3 4
+ //
+ // and multiplies it with the filter directly.
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
+ {
+ 18, 2, 5, // first batch, left
+ 18, 2, 5, // first batch, right
+ 17, 4, 3, // second batch, left
+ 37, 4, 3, // second batch, right
+ },
+ 0.16)));
+}
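The 0.16 tolerance absorbs exactly the quantization error the comment walks through; the "17.1577 instead of 17" cell can be checked with the values quoted there (patch and filter quantized at scale 4/127, bias 1):

#include <cstdint>
#include <cstdio>

int main() {
  const int32_t q_input[4] = {32, 64, 32, 64};    // patch {1, 2, 1, 2}
  const int32_t q_filter[4] = {32, 64, 95, 127};  // filter {1, 2, 3, 4}
  int32_t acc = 0;
  for (int i = 0; i < 4; ++i) acc += q_input[i] * q_filter[i];  // 16288
  const double scale = 4.0 / 127.0;
  std::printf("%f\n", acc * scale * scale + 1.0);  // prints ~17.1577
  return 0;
}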
+
+// This test's output is equivalent to that of SimpleTestHybrid because we
+// split each input value across two channels, each carrying half of the
+// value, while keeping the filters for both channels identical:
+//
+// 2 * (A/2) * B = A * B, where the left-hand side corresponds to this test.
+TEST_P(ConvolutionOpTest, SimpleTestHybridWithChannels) {
+ HybridConvolutionOpModel m(
+ GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_UINT8, {3, 2, 2, 2}}, {TensorType_FLOAT32, {}});
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+ m.SetFilter({
+ 1, 1, 2, 2, 3, 3, 4, 4, // first 2x2 filter
+ -1, -1, 1, 1, -1, -1, 1, 1, // second 2x2 filter
+ -1, -1, -1, -1, 1, 1, 1, 1 // third 2x2 filter
+ });
+ m.SetBias({1, 2, 3});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
+ {
+ 18, 2, 5, // first batch, left
+ 18, 2, 5, // first batch, right
+ 17, 4, 3, // second batch, left
+ 37, 4, 3, // second batch, right
+ },
+ 0.16)));
+}
+
+TEST_P(ConvolutionOpTest, PointwiseHybrid) {
+ HybridConvolutionOpModel m(
+ GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_UINT8, {1, 1, 1, 2}}, {TensorType_FLOAT32, {}}, 1, 1);
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+
+ m.SetFilter({
+ 1, 2, // first filter
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ // Example: we get 3.03156 instead of 3.
+ //
+ // Second batch:
+ // 0.5 0.5 1 1 1.5 1.5 2 2 -> 32 32 64 64 95 95 127 127 with scale factor
+ // 127/2. We care about the two 64's.
+ //
+ // Filter:
+ // 64 127 with scale factor of 127/2.
+ //
+ // (64 * 64 + 64 * 127) * (2/127)^2 gives us the expected result.
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.5, 1.5, 1.5, 1.5, // first batch, row = 1
+ 3., 3., 3., 3., // first batch, row = 2
+ 1.5, 3., 4.5, 6., // second batch, row = 1
+ 1.5, 3., 4.5, 6., // second batch, row = 2
+ },
+ 0.0316)));
+}
+
+// TODO(alanchiao): this passes locally, but fails on the continuous build
+// system. Re-enable when the root cause is found.
+TEST_P(ConvolutionOpTest, DISABLED_PointwiseMultifilterHybrid) {
+ HybridConvolutionOpModel m(
+ GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_UINT8, {2, 1, 1, 2}}, {TensorType_FLOAT32, {}}, 1, 1);
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+
+ m.SetFilter({
+ 1, 2, // first filter
+ 2, 3, // second filter
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.5, 2.5, 1.5, 2.5, 1.5, 2.5, 1.5, 2.5, 3., 5., 3.,
+ 5., 3., 5., 3., 5., 1.5, 2.5, 3., 5., 4.5, 7.5,
+ 6., 10., 1.5, 2.5, 3., 5., 4.5, 7.5, 6., 10.,
+ },
+ 0.0474)));
+}
+
INSTANTIATE_TEST_CASE_P(
ConvolutionOpTest, ConvolutionOpTest,
::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
index 21518156b8..347515f289 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
diff --git a/tensorflow/contrib/lite/kernels/dequantize.cc b/tensorflow/contrib/lite/kernels/dequantize.cc
index 2b0f04489a..3a08f48b00 100644
--- a/tensorflow/contrib/lite/kernels/dequantize.cc
+++ b/tensorflow/contrib/lite/kernels/dequantize.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess.cc b/tensorflow/contrib/lite/kernels/detection_postprocess.cc
index d7bde0ff79..d2906632d7 100644
--- a/tensorflow/contrib/lite/kernels/detection_postprocess.cc
+++ b/tensorflow/contrib/lite/kernels/detection_postprocess.cc
@@ -15,9 +15,9 @@ limitations under the License.
#include <string.h>
#include <numeric>
#include <vector>
-#include "flatbuffers/flexbuffers.h"
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc
index 4e0f8484a3..94c91a6bd6 100644
--- a/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc
+++ b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h"
+#include "flatbuffers/flexbuffers.h" // flatbuffers
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc
index d7420ddd8e..7945c095b1 100644
--- a/tensorflow/contrib/lite/kernels/div.cc
+++ b/tensorflow/contrib/lite/kernels/div.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.h b/tensorflow/contrib/lite/kernels/eigen_support.h
index ec77856b10..feb1543f7b 100644
--- a/tensorflow/contrib/lite/kernels/eigen_support.h
+++ b/tensorflow/contrib/lite/kernels/eigen_support.h
@@ -15,10 +15,10 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace EigenForTFLite {
-class ThreadPoolDevice;
+struct ThreadPoolDevice;
}
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc
index e19779ea59..04995d70dd 100644
--- a/tensorflow/contrib/lite/kernels/elementwise.cc
+++ b/tensorflow/contrib/lite/kernels/elementwise.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include <cmath>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
index b2dff87e62..fe33f98eb0 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
@@ -37,8 +37,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
index d3be36993c..aa75b03990 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
@@ -65,8 +65,8 @@ limitations under the License.
#include <algorithm>
#include <cmath>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/exp.cc b/tensorflow/contrib/lite/kernels/exp.cc
index ce03cdfe26..673e7be90a 100644
--- a/tensorflow/contrib/lite/kernels/exp.cc
+++ b/tensorflow/contrib/lite/kernels/exp.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/expand_dims.cc b/tensorflow/contrib/lite/kernels/expand_dims.cc
index ed33012864..fa1140b19c 100644
--- a/tensorflow/contrib/lite/kernels/expand_dims.cc
+++ b/tensorflow/contrib/lite/kernels/expand_dims.cc
@@ -15,8 +15,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/expand_dims_test.cc b/tensorflow/contrib/lite/kernels/expand_dims_test.cc
index 50dc860e5a..a3bc1813db 100644
--- a/tensorflow/contrib/lite/kernels/expand_dims_test.cc
+++ b/tensorflow/contrib/lite/kernels/expand_dims_test.cc
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/fake_quant.cc b/tensorflow/contrib/lite/kernels/fake_quant.cc
index 0ef1a50b30..f9bc3747cb 100644
--- a/tensorflow/contrib/lite/kernels/fake_quant.cc
+++ b/tensorflow/contrib/lite/kernels/fake_quant.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/floor.cc b/tensorflow/contrib/lite/kernels/floor.cc
index 697b777693..59ff77f35b 100644
--- a/tensorflow/contrib/lite/kernels/floor.cc
+++ b/tensorflow/contrib/lite/kernels/floor.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -41,8 +41,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- optimized_ops::Floor(GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(output), GetTensorDims(output));
+ optimized_ops::Floor(GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(output), GetTensorData<float>(output));
+
return kTfLiteOk;
}
} // namespace floor
diff --git a/tensorflow/contrib/lite/kernels/floor_div.cc b/tensorflow/contrib/lite/kernels/floor_div.cc
new file mode 100644
index 0000000000..5d62cd2755
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/floor_div.cc
@@ -0,0 +1,149 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace floor_div {
+namespace {
+
+// Input/output tensor index.
+constexpr int kInputTensor1 = 0;
+constexpr int kInputTensor2 = 1;
+constexpr int kOutputTensor = 0;
+
+// Op data for floor_div op.
+struct OpData {
+ bool requires_broadcast;
+};
+
+template <typename T>
+T FloorDiv(T input1, T input2) {
+ return std::floor(std::divides<double>()(static_cast<double>(input1),
+ static_cast<double>(input2)));
+}
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+ data->requires_broadcast = false;
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ // Reinterpret the opaque data provided by the user.
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
+
+ const TfLiteType type = input1->type;
+ if (type != kTfLiteInt32) {
+ context->ReportError(context, "Currently floor_div only supports int32.");
+ return kTfLiteError;
+ }
+ output->type = type;
+
+ data->requires_broadcast = !HaveSameShapes(input1, input2);
+
+ TfLiteIntArray* output_size = nullptr;
+ if (data->requires_broadcast) {
+ TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
+ context, input1, input2, &output_size));
+ } else {
+ output_size = TfLiteIntArrayCopy(input1->dims);
+ }
+
+ return context->ResizeTensor(context, output, output_size);
+}
+
+template <typename T>
+TfLiteStatus EvalImpl(TfLiteContext* context, bool requires_broadcast,
+ const TfLiteTensor* input1, const TfLiteTensor* input2,
+ TfLiteTensor* output) {
+ const T* denominator_data = GetTensorData<T>(input2);
+
+ // Validate the denominator.
+ for (int i = 0; i < NumElements(input2); ++i) {
+ if (std::equal_to<T>()(denominator_data[i], 0)) {
+ context->ReportError(context, "Division by 0");
+ return kTfLiteError;
+ }
+ }
+ if (requires_broadcast) {
+ reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
+ GetTensorShape(input1), GetTensorData<T>(input1),
+ GetTensorShape(input2), denominator_data, GetTensorShape(output),
+ GetTensorData<T>(output), FloorDiv<T>);
+ } else {
+ reference_ops::BinaryFunction<T, T, T>(
+ GetTensorShape(input1), GetTensorData<T>(input1),
+ GetTensorShape(input2), GetTensorData<T>(input2),
+ GetTensorShape(output), GetTensorData<T>(output), FloorDiv<T>);
+ }
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ switch (input1->type) {
+ case kTfLiteInt32: {
+ return EvalImpl<int32_t>(context, data->requires_broadcast, input1,
+ input2, output);
+ }
+ default: {
+ context->ReportError(context, "Currently floor_div only supports int32.");
+ return kTfLiteError;
+ }
+ }
+}
+
+} // namespace
+} // namespace floor_div
+
+TfLiteRegistration* Register_FLOOR_DIV() {
+ // Init, Free, Prepare, and Eval satisfy the interface required by
+ // TfLiteRegistration.
+ static TfLiteRegistration r = {floor_div::Init, floor_div::Free,
+ floor_div::Prepare, floor_div::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
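FloorDiv rounds toward negative infinity, which is why the kernel above routes through std::floor on doubles instead of relying on C++'s built-in integer division, which truncates toward zero. A quick standalone check of the same formula (the values mirror the tests that follow):

#include <cassert>
#include <cmath>

int FloorDivInt(int a, int b) {
  return static_cast<int>(std::floor(static_cast<double>(a) / b));
}

int main() {
  assert(FloorDivInt(10, 2) == 5);    // exact: matches truncation
  assert(FloorDivInt(-9, 2) == -5);   // truncation would give -4
  assert(FloorDivInt(-11, -3) == 3);  // -11 / -3 == 3.67, floors to 3
  assert(FloorDivInt(7, -4) == -2);   // truncation would give -1
  return 0;
}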
diff --git a/tensorflow/contrib/lite/kernels/floor_div_test.cc b/tensorflow/contrib/lite/kernels/floor_div_test.cc
new file mode 100644
index 0000000000..eea69b61ac
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/floor_div_test.cc
@@ -0,0 +1,90 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+
+template <typename T>
+class FloorDivModel : public SingleOpModel {
+ public:
+ FloorDivModel(const TensorData& input1, const TensorData& input2,
+ const TensorData& output) {
+ input1_ = AddInput(input1);
+ input2_ = AddInput(input2);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_FLOOR_DIV, BuiltinOptions_FloorDivOptions,
+ CreateFloorDivOptions(builder_).Union());
+ BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+TEST(FloorDivOpModel, Simple) {
+ FloorDivModel<int32_t> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {}});
+ model.PopulateTensor<int32_t>(model.input1(), {10, 9, 11, 3});
+ model.PopulateTensor<int32_t>(model.input2(), {2, 2, 3, 4});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(5, 4, 3, 0));
+}
+
+TEST(FloorDivOpModel, NegativeValue) {
+ FloorDivModel<int32_t> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {}});
+ model.PopulateTensor<int32_t>(model.input1(), {10, -9, -11, 7});
+ model.PopulateTensor<int32_t>(model.input2(), {2, 2, -3, -4});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(5, -5, 3, -2));
+}
+
+TEST(FloorDivOpModel, BroadcastFloorDiv) {
+ FloorDivModel<int32_t> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1}}, {TensorType_INT32, {}});
+ model.PopulateTensor<int32_t>(model.input1(), {10, -9, -11, 7});
+ model.PopulateTensor<int32_t>(model.input2(), {-3});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(-4, 3, 3, -3));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
index eaf5a67d67..7a71fcc219 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc
index 2b2a9e6620..badd2de11a 100644
--- a/tensorflow/contrib/lite/kernels/gather.cc
+++ b/tensorflow/contrib/lite/kernels/gather.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc
index 1d4292955c..1b48884e09 100644
--- a/tensorflow/contrib/lite/kernels/gather_test.cc
+++ b/tensorflow/contrib/lite/kernels/gather_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h
index 37af772c68..43cd2b3055 100644
--- a/tensorflow/contrib/lite/kernels/gemm_support.h
+++ b/tensorflow/contrib/lite/kernels/gemm_support.h
@@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_
#include "public/gemmlowp.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
namespace gemm_support {
diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
index f37c66acb3..c0b3c3c0c5 100644
--- a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
+++ b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
@@ -39,8 +39,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
#include "tensorflow/contrib/lite/string_util.h"
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index a97db6c6b2..a6fd4ac2dd 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -160,9 +160,10 @@ cc_library(
":types",
":reference_base",
":round",
+ ":tensor_utils",
"//third_party/eigen3",
"@gemmlowp",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@@ -191,12 +192,13 @@ cc_library(
deps = [
":quantization_util",
":strided_slice_logic",
+ ":tensor_utils",
":types",
":legacy_reference_base",
":round",
"//third_party/eigen3",
"@gemmlowp",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@@ -218,13 +220,15 @@ cc_library(
"optimized/eigen_spatial_convolutions.h",
"optimized/eigen_tensor_reduced_instantiations_oss.h",
"optimized/multithreaded_conv.h",
+ # FIXME(petewarden) - This should be removed, since it's a header from the
+ # :tensor dependency below.
"tensor.h",
],
deps = [
":optimized_base",
+ ":tensor",
":types",
- "//tensorflow/contrib/lite:builtin_op_data",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//third_party/eigen3",
],
)
@@ -234,7 +238,7 @@ cc_test(
srcs = ["tensor_test.cc"],
tags = ["no_oss"],
deps = [
- ":reference",
+ ":tensor",
"@com_google_googletest//:gtest",
],
)
@@ -293,9 +297,8 @@ cc_library(
":round",
":strided_slice_logic",
":types",
- "//third_party/eigen3",
"@gemmlowp",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@@ -324,9 +327,8 @@ cc_library(
":round",
":strided_slice_logic",
":types",
- "//third_party/eigen3",
"@gemmlowp",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@@ -341,11 +343,27 @@ cc_library(
)
cc_library(
+ name = "tensor",
+ hdrs = [
+ "tensor.h",
+ "tensor_ctypes.h",
+ ],
+ deps = [
+ ":types",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ ],
+)
+
+# Deprecated version of :tensor, kept for backwards compatibility.
+cc_library(
name = "reference",
- hdrs = ["tensor.h"],
+ hdrs = [
+ "tensor.h",
+ "tensor_ctypes.h",
+ ],
deps = [
":types",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -359,7 +377,7 @@ cc_library(
],
deps = [
":round",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:activation_functor",
"//tensorflow/contrib/lite/kernels:op_macros",
],
@@ -384,7 +402,7 @@ cc_library(
":cpu_check",
":round",
":types",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:activation_functor",
"//tensorflow/contrib/lite/kernels:op_macros",
"@arm_neon_2_x86_sse",
@@ -398,7 +416,7 @@ cc_library(
hdrs = ["kernel_utils.h"],
deps = [
":tensor_utils",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -441,7 +459,7 @@ cc_library(
copts = NEON_FLAGS_IF_APPLICABLE,
deps = [
"//tensorflow/contrib/lite/kernels:activation_functor",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"@arm_neon_2_x86_sse",
"@gemmlowp",
] + select({
@@ -517,7 +535,7 @@ cc_test(
],
deps = [
":tensor_utils",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest_main",
],
diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h
index eb4d0108bd..e67fee11b8 100644
--- a/tensorflow/contrib/lite/kernels/internal/common.h
+++ b/tensorflow/contrib/lite/kernels/internal/common.h
@@ -45,7 +45,7 @@ limitations under the License.
#endif
#endif
-#include "public/gemmlowp.h"
+#include "fixedpoint/fixedpoint.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
index 200f2f1515..56e9367878 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
@@ -14,8 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
-#include <algorithm>
-
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
namespace tflite {
@@ -26,6 +24,21 @@ void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
int input_size, int num_units, int batch_size,
TfLiteFusedActivation activation,
float* hidden_state_ptr_batch, float* output_ptr_batch) {
+ RnnBatchStep(input_ptr_batch, input_weights_ptr,
+ /*aux_input_ptr_batch=*/nullptr,
+ /*aux_input_weights_ptr=*/nullptr, recurrent_weights_ptr,
+ bias_ptr, input_size, /*aux_input_size=*/0, num_units,
+ batch_size, activation, hidden_state_ptr_batch,
+ output_ptr_batch);
+}
+
+void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
+ const float* aux_input_ptr_batch,
+ const float* aux_input_weights_ptr,
+ const float* recurrent_weights_ptr, const float* bias_ptr,
+ int input_size, int aux_input_size, int num_units,
+ int batch_size, TfLiteFusedActivation activation,
+ float* hidden_state_ptr_batch, float* output_ptr_batch) {
// Output = bias
tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
output_ptr_batch);
@@ -33,6 +46,12 @@ void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_weights_ptr, num_units, input_size, input_ptr_batch, batch_size,
output_ptr_batch, /*result_stride=*/1);
+ // Output += aux_input * aux_input_weights (if they are not empty).
+ if (aux_input_size > 0) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_weights_ptr, num_units, aux_input_size, aux_input_ptr_batch,
+ batch_size, output_ptr_batch, /*result_stride=*/1);
+ }
// Output += recurrent_weights * hidden_state
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_weights_ptr, num_units, num_units, hidden_state_ptr_batch,
@@ -54,6 +73,28 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
int8_t* quantized_hidden_state_ptr_batch,
float* scaling_factors, float* hidden_state_ptr_batch,
float* output_ptr_batch) {
+ RnnBatchStep(input_ptr_batch, input_weights_ptr, input_weights_scale,
+ /*aux_input_ptr_batch=*/nullptr,
+ /*aux_input_weights_ptr=*/nullptr,
+ /*aux_input_weights_scale=*/0.0f, recurrent_weights_ptr,
+ recurrent_weights_scale, bias_ptr, input_size,
+ /*aux_input_size=*/0, num_units, batch_size, activation,
+ quantized_input_ptr_batch,
+ /*aux_quantized_input_ptr_batch=*/nullptr,
+ quantized_hidden_state_ptr_batch, scaling_factors,
+ hidden_state_ptr_batch, output_ptr_batch);
+}
+
+void RnnBatchStep(
+ const float* input_ptr_batch, const int8_t* input_weights_ptr,
+ float input_weights_scale, const float* aux_input_ptr_batch,
+ const int8_t* aux_input_weights_ptr, float aux_input_weights_scale,
+ const int8_t* recurrent_weights_ptr, float recurrent_weights_scale,
+ const float* bias_ptr, int input_size, int aux_input_size, int num_units,
+ int batch_size, TfLiteFusedActivation activation,
+ int8_t* quantized_input_ptr_batch, int8_t* aux_quantized_input_ptr_batch,
+ int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
+ float* hidden_state_ptr_batch, float* output_ptr_batch) {
// Output = bias
tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
output_ptr_batch);
@@ -80,6 +121,26 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
scaling_factors, batch_size, output_ptr_batch, /*result_stride=*/1);
}
+ if (aux_input_ptr_batch &&
+ !tensor_utils::IsZeroVector(aux_input_ptr_batch,
+ batch_size * aux_input_size)) {
+ float unused_min, unused_max;
+ for (int b = 0; b < batch_size; ++b) {
+ const int offset = b * aux_input_size;
+ tensor_utils::SymmetricQuantizeFloats(
+ aux_input_ptr_batch + offset, aux_input_size,
+ aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ scaling_factors[b] *= aux_input_weights_scale;
+ }
+
+ // Output += aux_input * aux_input_weights
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_weights_ptr, num_units, aux_input_size,
+ aux_quantized_input_ptr_batch, scaling_factors, batch_size,
+ output_ptr_batch, /*result_stride=*/1);
+ }
+
// Save quantization and matmul computation for all zero input.
if (!tensor_utils::IsZeroVector(hidden_state_ptr_batch,
batch_size * num_units)) {
@@ -127,6 +188,47 @@ void LstmStep(
float* cell_state_ptr, float* input_gate_scratch,
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
float* output_ptr_batch) {
+ LstmStepWithAuxInput(
+ input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr,
+ input_to_cell_weights_ptr, input_to_output_weights_ptr,
+ /*aux_input_ptr_batch=*/nullptr,
+ /*aux_input_to_input_weights_ptr=*/nullptr,
+ /*aux_input_to_forget_weights_ptr=*/nullptr,
+ /*aux_input_to_cell_weights_ptr=*/nullptr,
+ /*aux_input_to_output_weights_ptr=*/nullptr,
+ recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
+ recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
+ cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
+ cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
+ cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
+ projection_bias_ptr, params, n_batch, n_cell, n_input, /*n_aux_input=*/0,
+ n_output, output_state_ptr, cell_state_ptr, input_gate_scratch,
+ forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch);
+}
+
+void LstmStepWithAuxInput(
+ const float* input_ptr_batch, const float* input_to_input_weights_ptr,
+ const float* input_to_forget_weights_ptr,
+ const float* input_to_cell_weights_ptr,
+ const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
+ const float* aux_input_to_input_weights_ptr,
+ const float* aux_input_to_forget_weights_ptr,
+ const float* aux_input_to_cell_weights_ptr,
+ const float* aux_input_to_output_weights_ptr,
+ const float* recurrent_to_input_weights_ptr,
+ const float* recurrent_to_forget_weights_ptr,
+ const float* recurrent_to_cell_weights_ptr,
+ const float* recurrent_to_output_weights_ptr,
+ const float* cell_to_input_weights_ptr,
+ const float* cell_to_forget_weights_ptr,
+ const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const float* projection_weights_ptr,
+ const float* projection_bias_ptr, const TfLiteLSTMParams* params,
+ int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
+ float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
+ float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
+ float* output_ptr_batch) {
// Since we have already checked that weights are all there or none, we can
// check the existence of only one to get the condition.
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
@@ -160,6 +262,26 @@ void LstmStep(
input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
output_gate_scratch, /*result_stride=*/1);
+ // If auxiliary input is available, then compute aux_input_weight * aux_input.
+ if (aux_input_ptr_batch != nullptr) {
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_input_weights_ptr, n_cell, n_aux_input,
+ aux_input_ptr_batch, n_batch, input_gate_scratch,
+ /*result_stride=*/1);
+ }
+
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
+ aux_input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr_batch,
+ n_batch, cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_output_weights_ptr, n_cell, n_aux_input,
+ aux_input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1);
+ }
+
// For each batch and cell: compute recurrent_weight * output_state.
if (!use_cifg) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
@@ -286,227 +408,364 @@ void LstmStep(
int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
int8_t* quantized_cell_state_ptr, float* output_state_ptr,
float* cell_state_ptr, float* output_ptr_batch) {
- // Since we have already checked that weights are all there or none, we can
- // check the existense of only one to the get the condition.
- const bool use_cifg = (input_to_input_weights_ptr == nullptr);
- const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
- // Initialize scratch buffers with bias.
- if (!use_cifg) {
- tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
- input_gate_scratch);
- }
- tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
- forget_gate_scratch);
- tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
- cell_scratch);
- tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
- output_gate_scratch);
-
- if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
- // Save quantization and matmul computation for all zero input.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_input;
- tensor_utils::SymmetricQuantizeFloats(
- input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
- &unused_min, &unused_max, &scaling_factors[b]);
- }
- // For each batch and cell: compute input_weight * input.
- if (!use_cifg) {
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_input_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_input_weights_ptr, n_cell, n_input,
- quantized_input_ptr_batch, product_scaling_factors, n_batch,
- input_gate_scratch, /*result_stride=*/1);
+ LstmStepWithAuxInput(
+ input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale,
+ input_to_forget_weights_ptr, input_to_forget_weights_scale,
+ input_to_cell_weights_ptr, input_to_cell_weights_scale,
+ input_to_output_weights_ptr, input_to_output_weights_scale,
+ /*aux_input_ptr_batch=*/nullptr,
+ /*aux_input_to_input_weights_ptr=*/nullptr,
+ /*aux_input_to_input_weights_scale=*/0.0f,
+ /*aux_input_to_forget_weights_ptr=*/nullptr,
+ /*aux_input_to_forget_weights_scale=*/0.0f,
+ /*aux_input_to_cell_weights_ptr=*/nullptr,
+ /*aux_input_to_cell_weights_scale=*/0.0f,
+ /*aux_input_to_output_weights_ptr=*/nullptr,
+ /*aux_input_to_output_weights_scale=*/0.0f,
+ recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
+ recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
+ recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
+ recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
+ cell_to_input_weights_ptr, cell_to_input_weights_scale,
+ cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
+ cell_to_output_weights_ptr, cell_to_output_weights_scale,
+ input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
+ output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
+ projection_bias_ptr, params, n_batch, n_cell, n_input,
+ /*n_aux_input=*/0, n_output, input_gate_scratch, forget_gate_scratch,
+ cell_scratch, output_gate_scratch, scaling_factors,
+ product_scaling_factors, recovered_cell_weights,
+ quantized_input_ptr_batch,
+ /*quantized_aux_input_ptr_batch=*/nullptr, quantized_output_state_ptr,
+ quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
+ output_ptr_batch);
}
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_forget_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
- product_scaling_factors, n_batch, forget_gate_scratch,
- /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_cell_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
- product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_output_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
- product_scaling_factors, n_batch, output_gate_scratch,
- /*result_stride=*/1);
- }
-
- if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
- // Save quantization and matmul computation for all zero input.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_output;
- tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output,
- quantized_output_state_ptr + offset,
- &unused_min, &unused_max,
- &scaling_factors[b]);
- }
- // For each batch and cell: compute recurrent_weight * output_state.
- if (!use_cifg) {
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_input_weights_scale;
+ void LstmStepWithAuxInput(
+ const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
+ float input_to_input_weights_scale,
+ const int8_t* input_to_forget_weights_ptr,
+ float input_to_forget_weights_scale,
+ const int8_t* input_to_cell_weights_ptr,
+ float input_to_cell_weights_scale,
+ const int8_t* input_to_output_weights_ptr,
+ float input_to_output_weights_scale, const float* aux_input_ptr_batch,
+ const int8_t* aux_input_to_input_weights_ptr,
+ float aux_input_to_input_weights_scale,
+ const int8_t* aux_input_to_forget_weights_ptr,
+ float aux_input_to_forget_weights_scale,
+ const int8_t* aux_input_to_cell_weights_ptr,
+ float aux_input_to_cell_weights_scale,
+ const int8_t* aux_input_to_output_weights_ptr,
+ float aux_input_to_output_weights_scale,
+ const int8_t* recurrent_to_input_weights_ptr,
+ float recurrent_to_input_weights_scale,
+ const int8_t* recurrent_to_forget_weights_ptr,
+ float recurrent_to_forget_weights_scale,
+ const int8_t* recurrent_to_cell_weights_ptr,
+ float recurrent_to_cell_weights_scale,
+ const int8_t* recurrent_to_output_weights_ptr,
+ float recurrent_to_output_weights_scale,
+ const int8_t* cell_to_input_weights_ptr,
+ float cell_to_input_weights_scale,
+ const int8_t* cell_to_forget_weights_ptr,
+ float cell_to_forget_weights_scale,
+ const int8_t* cell_to_output_weights_ptr,
+ float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
+ float projection_weights_scale, const float* projection_bias_ptr,
+ const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
+ int n_aux_input, int n_output, float* input_gate_scratch,
+ float* forget_gate_scratch, float* cell_scratch,
+ float* output_gate_scratch, float* scaling_factors,
+ float* product_scaling_factors, float* recovered_cell_weights,
+ int8_t* quantized_input_ptr_batch,
+ int8_t* quantized_aux_input_ptr_batch,
+ int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr,
+ float* output_state_ptr, float* cell_state_ptr,
+ float* output_ptr_batch) {
+ // Since we have already checked that weights are all there or none, we
+ // can check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+ const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
+ // Initialize scratch buffers with bias.
+ if (!use_cifg) {
+ tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
+ n_batch, input_gate_scratch);
+ }
+ tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell,
+ n_batch, forget_gate_scratch);
+ tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
+ cell_scratch);
+ tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell,
+ n_batch, output_gate_scratch);
+
+ if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_input;
+ tensor_utils::SymmetricQuantizeFloats(
+ input_ptr_batch + offset, n_input,
+ quantized_input_ptr_batch + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_input_weights_ptr, n_cell, n_input,
+ quantized_input_ptr_batch, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_forget_weights_ptr, n_cell, n_input,
+ quantized_input_ptr_batch, product_scaling_factors, n_batch,
+ forget_gate_scratch,
+ /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_cell_weights_ptr, n_cell, n_input,
+ quantized_input_ptr_batch, product_scaling_factors, n_batch,
+ cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_output_weights_ptr, n_cell, n_input,
+ quantized_input_ptr_batch, product_scaling_factors, n_batch,
+ output_gate_scratch,
+ /*result_stride=*/1);
}
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_input_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_forget_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_forget_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- forget_gate_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_cell_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_cell_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- cell_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_output_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_output_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- output_gate_scratch, /*result_stride=*/1);
- }
-
- // Save quantization and matmul computation for all zero input.
- bool is_cell_state_all_zeros =
- tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
- // For each batch and cell: update input gate.
- if (!use_cifg) {
- if (use_peephole && !is_cell_state_all_zeros) {
- tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
- cell_to_input_weights_scale,
- recovered_cell_weights);
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
- input_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
- input_gate_scratch);
- }
+ if (aux_input_ptr_batch != nullptr &&
+ !tensor_utils::IsZeroVector(aux_input_ptr_batch, n_batch * n_aux_input)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_aux_input;
+ tensor_utils::SymmetricQuantizeFloats(
+ aux_input_ptr_batch + offset, n_aux_input,
+ quantized_aux_input_ptr_batch + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ // For each batch and cell: compute aux_input_weight * aux_input.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_input_weights_ptr, n_cell, n_aux_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_cell_weights_ptr, n_cell, n_aux_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_output_weights_ptr, n_cell, n_aux_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+ }
- // For each batch and cell: update forget gate.
- if (use_peephole && !is_cell_state_all_zeros) {
- tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
- cell_to_forget_weights_scale,
- recovered_cell_weights);
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
- forget_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
- forget_gate_scratch);
+ if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_output;
+ tensor_utils::SymmetricQuantizeFloats(
+ output_state_ptr + offset, n_output,
+ quantized_output_state_ptr + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_input_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_forget_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_cell_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_output_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+ }
- // For each batch and cell: update the cell.
- tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
- n_batch * n_cell, cell_state_ptr);
- tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
- params->activation, cell_scratch);
- if (use_cifg) {
- tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
- forget_gate_scratch);
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
- } else {
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
- }
- if (params->cell_clip > 0.0) {
- tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
- params->cell_clip, cell_state_ptr);
- }
+ // Save quantization and matmul computation for all zero input.
+ bool is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+
+ // For each batch and cell: update input gate.
+ if (!use_cifg) {
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
+ cell_to_input_weights_scale,
+ recovered_cell_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+ input_gate_scratch);
+ }
- is_cell_state_all_zeros =
- tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
- // For each batch and cell: update the output gate.
- if (use_peephole && !is_cell_state_all_zeros) {
- tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
- cell_to_output_weights_scale,
- recovered_cell_weights);
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
- output_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
- output_gate_scratch);
- tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
- params->activation, cell_scratch);
- tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
- n_batch * n_cell, output_gate_scratch);
+ // For each batch and cell: update forget gate.
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
+ cell_to_forget_weights_scale,
+ recovered_cell_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+ forget_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+ forget_gate_scratch);
+
+ // For each batch and cell: update the cell.
+ tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch,
+ cell_state_ptr, n_batch * n_cell,
+ cell_state_ptr);
+ tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+ params->activation, cell_scratch);
+ if (use_cifg) {
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ forget_gate_scratch);
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, forget_gate_scratch, n_batch * n_cell,
+ cell_state_ptr);
+ } else {
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ }
+ if (params->cell_clip > 0.0) {
+ tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
+ params->cell_clip, cell_state_ptr);
+ }
- // For each batch: update the projection and output_state.
- const bool use_projection_weight = (projection_weights_ptr != nullptr);
- const bool use_projection_bias = (projection_bias_ptr != nullptr);
- if (use_projection_weight) {
- if (use_projection_bias) {
- tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
- n_batch, output_ptr_batch);
- } else {
- tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
- }
- if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
- // Save quantization and matmul computation for all zero input.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_cell;
- tensor_utils::SymmetricQuantizeFloats(
- output_gate_scratch + offset, n_cell,
- quantized_cell_state_ptr + offset, &unused_min, &unused_max,
- &scaling_factors[b]);
+ is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+ // For each batch and cell: update the output gate.
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
+ cell_to_output_weights_scale,
+ recovered_cell_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+ output_gate_scratch);
}
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * projection_weights_scale;
+ tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+ output_gate_scratch);
+ tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
+ params->activation, cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ n_batch * n_cell,
+ output_gate_scratch);
+
+ // For each batch: update the projection and output_state.
+ const bool use_projection_weight = (projection_weights_ptr != nullptr);
+ const bool use_projection_bias = (projection_bias_ptr != nullptr);
+ if (use_projection_weight) {
+ if (use_projection_bias) {
+ tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+ n_batch, output_ptr_batch);
+ } else {
+ tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
+ }
+ if (!tensor_utils::IsZeroVector(output_gate_scratch,
+ n_batch * n_cell)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_cell;
+ tensor_utils::SymmetricQuantizeFloats(
+ output_gate_scratch + offset, n_cell,
+ quantized_cell_state_ptr + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * projection_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ projection_weights_ptr, n_output, n_cell,
+ quantized_cell_state_ptr, product_scaling_factors, n_batch,
+ output_ptr_batch,
+ /*result_stride=*/1);
+ }
+ if (params->proj_clip > 0.0) {
+ tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
+ params->proj_clip, output_ptr_batch);
+ }
+ } else {
+ tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+ output_ptr_batch);
}
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr,
- product_scaling_factors, n_batch, output_ptr_batch,
- /*result_stride=*/1);
+ tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
+ output_state_ptr);
}
- if (params->proj_clip > 0.0) {
- tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
- params->proj_clip, output_ptr_batch);
- }
- } else {
- tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
- output_ptr_batch);
- }
- tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
- output_state_ptr);
-}
} // namespace kernel_utils
} // namespace tflite
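
The aux-input overloads added above fold a second input stream into the same
recurrence: per batch b, the float step computes
output = activation(bias + input_weights * x_b + aux_weights * aux_b +
recurrent_weights * h_b), then copies the activated output into the hidden
state. A minimal scalar reference follows; it is a readability sketch only,
assuming row-major [rows, cols] weight layout and using ReLU as a stand-in
for the TfLiteFusedActivation dispatch in the real, fused tensor_utils code:

    #include <algorithm>
    #include <vector>

    // Naive single RNN batch step with an auxiliary input; the tensor_utils
    // calls in RnnBatchStep compute the same thing in fused form.
    void RnnBatchStepRef(const std::vector<float>& input,        // [batch, input_size]
                         const std::vector<float>& input_w,      // [num_units, input_size]
                         const std::vector<float>& aux_input,    // [batch, aux_size]
                         const std::vector<float>& aux_w,        // [num_units, aux_size]
                         const std::vector<float>& recurrent_w,  // [num_units, num_units]
                         const std::vector<float>& bias,         // [num_units]
                         int input_size, int aux_size, int num_units, int batch,
                         std::vector<float>* hidden_state,       // [batch, num_units]
                         std::vector<float>* output) {           // [batch, num_units]
      for (int b = 0; b < batch; ++b) {
        for (int u = 0; u < num_units; ++u) {
          float acc = bias[u];  // Output = bias.
          for (int i = 0; i < input_size; ++i)  // Output += input * input_weights.
            acc += input_w[u * input_size + i] * input[b * input_size + i];
          for (int i = 0; i < aux_size; ++i)  // Output += aux_input * aux_weights.
            acc += aux_w[u * aux_size + i] * aux_input[b * aux_size + i];
          for (int i = 0; i < num_units; ++i)  // Output += recurrent * hidden.
            acc += recurrent_w[u * num_units + i] * (*hidden_state)[b * num_units + i];
          (*output)[b * num_units + u] = std::max(0.0f, acc);  // ReLU stand-in.
        }
      }
      *hidden_state = *output;  // New hidden state is the activated output.
    }
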
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
index 2a11b37a60..b5558cce55 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
namespace tflite {
namespace kernel_utils {
@@ -35,6 +35,15 @@ void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
TfLiteFusedActivation activation,
float* hidden_state_ptr_batch, float* output_ptr_batch);
+// Same as above but includes an auxiliary input with the corresponding weights.
+void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
+ const float* aux_input_ptr_batch,
+ const float* aux_input_weights_ptr,
+ const float* recurrent_weights_ptr, const float* bias_ptr,
+ int input_size, int aux_input_size, int num_units,
+ int batch_size, TfLiteFusedActivation activation,
+ float* hidden_state_ptr_batch, float* output_ptr_batch);
+
// Performs a quantized RNN batch inference step. Same as above, but for
// quantization purposes, we also pass in quantized_hidden_state_ptr_batch and
// quantized_input_ptr_batch pointers for temporary storage of the quantized
@@ -56,6 +65,17 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
float* scaling_factors, float* hidden_state_ptr_batch,
float* output_ptr_batch);
+void RnnBatchStep(
+ const float* input_ptr_batch, const int8_t* input_weights_ptr,
+ float input_weights_scale, const float* aux_input_ptr_batch,
+ const int8_t* aux_input_weights_ptr, float aux_input_weights_scale,
+ const int8_t* recurrent_weights_ptr, float recurrent_weights_scale,
+ const float* bias_ptr, int input_size, int aux_input_size, int num_units,
+ int batch_size, TfLiteFusedActivation activation,
+ int8_t* quantized_input_ptr_batch, int8_t* aux_quantized_input_ptr_batch,
+ int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
+ float* hidden_state_ptr_batch, float* output_ptr_batch);
+
// Performs an LSTM batch inference step for input specified by input_ptr_batch.
// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
// biases (*_bias_ptr), and buffers (*_scratch), along with additional
@@ -66,8 +86,7 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
// - n_input: the input size,
// - n_output: the output size.
//
-// The pointers to the cell and output state and the output are updated. Unless
-// projection is specified output and output state contain the same data.
+// The pointers to the cell and output state and the output are updated.
//
// The pointers with the suffix "_batch" point to data aligned in batch_major
// order, and each step processes batch_size many inputs from input_ptr_batch,
@@ -92,6 +111,31 @@ void LstmStep(
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
float* output_ptr_batch);
+// Same as above but includes an auxiliary input with the corresponding weights.
+void LstmStepWithAuxInput(
+ const float* input_ptr_batch, const float* input_to_input_weights_ptr,
+ const float* input_to_forget_weights_ptr,
+ const float* input_to_cell_weights_ptr,
+ const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
+ const float* aux_input_to_input_weights_ptr,
+ const float* aux_input_to_forget_weights_ptr,
+ const float* aux_input_to_cell_weights_ptr,
+ const float* aux_input_to_output_weights_ptr,
+ const float* recurrent_to_input_weights_ptr,
+ const float* recurrent_to_forget_weights_ptr,
+ const float* recurrent_to_cell_weights_ptr,
+ const float* recurrent_to_output_weights_ptr,
+ const float* cell_to_input_weights_ptr,
+ const float* cell_to_forget_weights_ptr,
+ const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const float* projection_weights_ptr,
+ const float* projection_bias_ptr, const TfLiteLSTMParams* params,
+ int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
+ float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
+ float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
+ float* output_ptr_batch);
+
// Same as above but with quantized weight matrices. In detail:
// Input of size 'n_batch * n_input':
// input_ptr_batch
@@ -175,6 +219,47 @@ void LstmStep(
int8_t* quantized_cell_state_ptr, float* output_state_ptr,
float* cell_state_ptr, float* output_ptr_batch);
+void LstmStepWithAuxInput(
+ const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
+ float input_to_input_weights_scale,
+ const int8_t* input_to_forget_weights_ptr,
+ float input_to_forget_weights_scale,
+ const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
+ const int8_t* input_to_output_weights_ptr,
+ float input_to_output_weights_scale, const float* aux_input_ptr_batch,
+ const int8_t* aux_input_to_input_weights_ptr,
+ float aux_input_to_input_weights_scale,
+ const int8_t* aux_input_to_forget_weights_ptr,
+ float aux_input_to_forget_weights_scale,
+ const int8_t* aux_input_to_cell_weights_ptr,
+ float aux_input_to_cell_weights_scale,
+ const int8_t* aux_input_to_output_weights_ptr,
+ float aux_input_to_output_weights_scale,
+ const int8_t* recurrent_to_input_weights_ptr,
+ float recurrent_to_input_weights_scale,
+ const int8_t* recurrent_to_forget_weights_ptr,
+ float recurrent_to_forget_weights_scale,
+ const int8_t* recurrent_to_cell_weights_ptr,
+ float recurrent_to_cell_weights_scale,
+ const int8_t* recurrent_to_output_weights_ptr,
+ float recurrent_to_output_weights_scale,
+ const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
+ const int8_t* cell_to_forget_weights_ptr,
+ float cell_to_forget_weights_scale,
+ const int8_t* cell_to_output_weights_ptr,
+ float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
+ float projection_weights_scale, const float* projection_bias_ptr,
+ const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
+ int n_aux_input, int n_output, float* input_gate_scratch,
+ float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
+ float* scaling_factors, float* product_scaling_factors,
+ float* recovered_cell_weights, int8_t* quantized_input_ptr_batch,
+ int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr,
+ int8_t* quantized_cell_state_ptr, float* output_state_ptr,
+ float* cell_state_ptr, float* output_ptr_batch);
+
} // namespace kernel_utils
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
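
For callers, the original entry points declared in this header are unchanged
and forward to the new overloads with null aux pointers and a zero aux size,
so the auxiliary path is a no-op for existing kernels. Hypothetical call
sites (buffer names and sizes are placeholders, not from the patch):

    // Legacy call, untouched by this change:
    kernel_utils::RnnBatchStep(input, input_weights, recurrent_weights, bias,
                               /*input_size=*/32, /*num_units=*/64,
                               /*batch_size=*/1, kTfLiteActTanh,
                               hidden_state, output);

    // New overload with an auxiliary input stream and its weights:
    kernel_utils::RnnBatchStep(input, input_weights, aux_input,
                               aux_input_weights, recurrent_weights, bias,
                               /*input_size=*/32, /*aux_input_size=*/16,
                               /*num_units=*/64, /*batch_size=*/1,
                               kTfLiteActTanh, hidden_state, output);
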
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h
index 3a53d3ab07..934308ef29 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_
-#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_H_
namespace tflite {
@@ -58,4 +58,4 @@ inline bool TestCPUFeatureNeon() { return false; }
: Portable##funcname(__VA_ARGS__)
#endif
-#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h
index 250872c422..6443f425b7 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h
@@ -140,4 +140,4 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h"
#include "Eigen/src/Core/util/ReenableStupidWarnings.h"
-#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_H
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
index 7f0676be27..b6151c40b3 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
@@ -27,8 +27,33 @@ namespace tflite {
namespace optimized_ops {
// Unoptimized reference ops:
+using reference_ops::ArgMax;
using reference_ops::Relu1;
using reference_ops::Relu6;
+using reference_ops::SpaceToBatchND;
+
+template <FusedActivationFunctionType Ac>
+void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ static_assert(Ac == FusedActivationFunctionType::kNone, "");
+ tflite::L2NormalizationParams op_params;
+ // No params need to be set for float; the params struct is reserved in the
+ // signature for future activations.
+
+ L2Normalization(op_params, input_shape, input_data, output_shape,
+ output_data);
+}
+
+inline void L2Normalization(const uint8* input_data,
+ const RuntimeShape& input_shape,
+ int32 input_zero_point, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ tflite::L2NormalizationParams op_params;
+ op_params.input_zero_point = input_zero_point;
+
+ L2Normalization(op_params, input_shape, input_data, output_shape,
+ output_data);
+}
template <FusedActivationFunctionType Ac>
void L2Normalization(const float* input_data, const Dims<4>& input_dims,
@@ -46,8 +71,8 @@ inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
inline void Relu(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Relu(input_data, DimsToShape(input_dims), output_data,
- DimsToShape(output_dims));
+ Relu(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
// legacy, for compatibility with old checked-in code
@@ -296,13 +321,17 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
int output_shift, int32 output_activation_min,
int32 output_activation_max, uint8* output_data,
const Dims<4>& output_dims) {
- BroadcastMul4DSlow(
- input1_data, input1_dims, input1_offset, input2_data, input2_dims,
- input2_offset, output_offset, output_multiplier,
- // This legacy version switches the sign of the output shift.
- kReverseShift * output_shift,
- // (Break to highlight preceding line.)
- output_activation_min, output_activation_max, output_data, output_dims);
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+ op_params.input1_offset = input1_offset;
+ op_params.input2_offset = input2_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
}
// legacy, for compatibility with old checked-in code
@@ -580,8 +609,8 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
inline void Logistic(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Logistic(input_data, DimsToShape(input_dims), output_data,
- DimsToShape(output_dims));
+ Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
@@ -601,8 +630,8 @@ inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
inline void Tanh(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Tanh(input_data, DimsToShape(input_dims), output_data,
- DimsToShape(output_dims));
+ Tanh(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
@@ -621,6 +650,294 @@ inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims));
}
+template <typename T>
+inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthToSpaceParams op_params;
+ op_params.block_size = block_size;
+
+ DepthToSpace(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::SpaceToDepthParams op_params;
+ op_params.block_size = block_size;
+
+ SpaceToDepth(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Mul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void Mul(const int32* input1_data, const Dims<4>& input1_dims,
+ const int32* input2_data, const Dims<4>& input2_dims,
+ int32 output_activation_min, int32 output_activation_max,
+ int32* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Mul(const int32* input1_data, const Dims<4>& input1_dims,
+ const int32* input2_data, const Dims<4>& input2_dims,
+ int32* output_data, const Dims<4>& output_dims) {
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+ tflite::ArithmeticParams op_params;
+ // No parameters needed.
+
+ MulNoActivation(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
+ const int16* input2_data, const Dims<4>& input2_dims,
+ int16* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ // No parameters needed.
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
+ const int16* input2_data, const Dims<4>& input2_dims,
+ int32 output_offset, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.output_offset = output_offset;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// For compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+inline void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ float float_activation_min;
+ float float_activation_max;
+ GetActivationMinMax(Ac, &float_activation_min, &float_activation_max);
+ SetActivationParams(float_activation_min, float_activation_max, &op_params);
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
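+// Usage sketch (hypothetical shapes, illustration only): a legacy call such
+// as
+//   BroadcastMul<FusedActivationFunctionType::kRelu>(in1, in1_dims, in2,
+//                                                    in2_dims, out, out_dims);
+// resolves its activation bounds via GetActivationMinMax and forwards to
+// BroadcastMul4DSlow with DimsToShape-converted RuntimeShapes, so new code
+// can call BroadcastMul4DSlow directly.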
+
+// Legacy Dims<4>.
+inline void LocalResponseNormalization(const float* input_data,
+ const Dims<4>& input_dims, int range,
+ float bias, float alpha, float beta,
+ float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::LocalResponseNormalizationParams op_params;
+ op_params.range = range;
+ op_params.bias = bias;
+ op_params.alpha = alpha;
+ op_params.beta = beta;
+
+ LocalResponseNormalization(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// Legacy Dims<4> version.
+template <typename SrcT, typename DstT>
+void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
+ const Dims<4>& output_dims) {
+ Cast(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// Legacy Dims<4> version.
+inline void Floor(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// Legacy Dims<4>.
+inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, float* output_data,
+ const Dims<4>& output_dims, bool align_corners) {
+ tflite::ResizeBilinearParams op_params;
+ op_params.align_corners = align_corners;
+ ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_size_dims), output_size_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// Legacy Dims<4>.
+inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, uint8* output_data,
+ const Dims<4>& output_dims, bool align_corners) {
+ tflite::ResizeBilinearParams op_params;
+ op_params.align_corners = align_corners;
+ ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_size_dims), output_size_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, float* output_data,
+ const Dims<4>& output_dims) {
+ ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
+ output_data, output_dims, /*align_corners=*/false);
+}
+
+// legacy, for compatibility with old checked-in code
+inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, uint8* output_data,
+ const Dims<4>& output_dims) {
+ ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
+ output_data, output_dims, /*align_corners=*/false);
+}
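+// Note: both overloads above keep the historical default of
+// align_corners == false; e.g. (sketch)
+//   ResizeBilinear(in, in_dims, size, size_dims, out, out_dims);
+// is equivalent to the RuntimeShape version invoked with
+// op_params.align_corners = false.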
+
+// Legacy Dims<4>.
+template <typename T>
+inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* crops_data, const Dims<4>& crops_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ BatchToSpaceND(DimsToShape(input_dims), input_data,
+ DimsToShape(block_shape_dims), block_shape_data,
+ DimsToShape(crops_dims), crops_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// Legacy signature; this function covers both Pad and PadV2.
+template <typename T>
+inline void PadV2(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const T pad_value) {
+ TFLITE_DCHECK_EQ(left_paddings.size(), 4);
+ TFLITE_DCHECK_EQ(right_paddings.size(), 4);
+ tflite::PadParams op_params;
+ op_params.left_padding_count = 4;
+ op_params.right_padding_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ op_params.left_padding[i] = left_paddings[3 - i];
+ op_params.right_padding[i] = right_paddings[3 - i];
+ }
+ const T pad_value_copy = pad_value;
+
+ Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
+ DimsToShape(output_dims), output_data);
+}
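+// Sketch of the reversal (illustrative): Dims<4> stores the fastest-moving
+// dimension first, so the loop above writes op_params.left_padding =
+// {left_paddings[3], ..., left_paddings[0]}, putting the slowest-moving
+// (batch) dimension first as the RuntimeShape-based Pad expects.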
+
+// Old Pad that calls legacy PadV2.
+template <typename T>
+inline void Pad(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const int32_t pad_value) {
+ const T converted_pad_value = static_cast<T>(pad_value);
+ PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
+ output_dims, converted_pad_value);
+}
+
+// Old Pad that only padded with 0.
+template <typename T>
+inline void Pad(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims) {
+ const T pad_value = static_cast<T>(0);
+ PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
+ output_dims, pad_value);
+}
+
+template <typename T>
+inline void Slice(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& begin, const std::vector<int>& size,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::SliceParams op_params;
+ op_params.begin_count = 4;
+ op_params.size_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ op_params.begin[i] = begin[3 - i];
+ op_params.size[i] = size[3 - i];
+ }
+
+ Slice(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ Minimum(DimsToShape(input1_dims), input1_data, input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ Maximum(DimsToShape(input1_dims), input1_data, input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
} // namespace optimized_ops
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
index 4a3545d47a..5fb31889fe 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV
-#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_MULTITHREADED_CONV_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_MULTITHREADED_CONV_H_
#include <assert.h>
#include <stdint.h>
@@ -26,7 +26,7 @@ limitations under the License.
#include <tuple>
#include <type_traits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
@@ -164,4 +164,4 @@ inline void Conv(const Eigen::ThreadPoolDevice& device, const float* input_data,
} // namespace multithreaded_ops
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_MULTITHREADED_CONV_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 420bc68b43..27418178fd 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include <stdlib.h>
#include <string.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
@@ -236,6 +236,35 @@ void NeonVectorVectorCwiseProductAccumulate(const float* vector1,
}
}
+void NeonVectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector, int n_batch,
+ float* result) {
+  // If v_size is not divisible by kFloatWeightsPerNeonLane, we cannot use
+  // the main vectorized loop, and we need to process sequentially.
+  // postamble_start shows the start index where this should happen.
+ const int postamble_start =
+ v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
+
+ for (int b = 0; b < n_batch; b++) {
+ for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
+ // Load from memory to vectors.
+ float32x4_t batch_vector_f32x4 = vld1q_f32(batch_vector + v);
+ float32x4_t vector_f32x4 = vld1q_f32(vector + v);
+ // Multiply.
+ float32x4_t result_f32x4 = vmulq_f32(batch_vector_f32x4, vector_f32x4);
+ // Store.
+ vst1q_f32(result + v, result_f32x4);
+ }
+ // Postamble loop
+ for (int v = postamble_start; v < v_size; v++) {
+ result[v] = vector[v] * batch_vector[v];
+ }
+ // Update the pointers.
+ result += v_size;
+ batch_vector += v_size;
+ }
+}
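+// Worked example (a sketch, assuming kFloatWeightsPerNeonLane == 4): for
+// v_size == 10, postamble_start == 10 - (10 & 3) == 8, so elements 0..7 are
+// handled by two 4-wide NEON iterations and elements 8..9 by the scalar
+// postamble. The mask trick relies on the lane count being a power of two.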
+
void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector,
int v_size,
const float* batch_vector,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
index 63c89d1eee..630a6bbf29 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
@@ -17,7 +17,7 @@ limitations under the License.
 // TODO(ghodrat): Remove this header file and the dependency on internal data
 // structures.
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h"
@@ -52,6 +52,13 @@ void VectorVectorCwiseProductAccumulate(const float* vector1,
result);
}
+void VectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector, int n_batch,
+ float* result) {
+ NEON_OR_PORTABLE(VectorBatchVectorCwiseProduct, vector, v_size, batch_vector,
+ n_batch, result);
+}
+
void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size,
const float* batch_vector,
int n_batch, float* result) {
@@ -72,6 +79,11 @@ void BatchVectorBatchVectorDotProduct(const float* vector1,
n_batch, result, result_stride);
}
+void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector) {
+ PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
+}
+
void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
float* batch_vector) {
PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector);
@@ -131,6 +143,13 @@ void ReductionSumVector(const float* input_vector, float* output_vector,
reduction_size);
}
+void MeanStddevNormalization(const float* input_vector, float* output_vector,
+ int v_size, int n_batch,
+ float normalization_epsilon) {
+ PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch,
+ normalization_epsilon);
+}
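+// Note: this dispatches straight to the portable implementation; unlike the
+// NEON_OR_PORTABLE wrappers above, no NEON specialization exists yet.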
+
} // namespace tensor_utils
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index da585e5550..2c8e8f90e3 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_
-#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
#include <assert.h>
#include <stdint.h>
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/round.h"
#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
@@ -42,6 +43,14 @@ namespace optimized_ops {
// Unoptimized reference ops:
using reference_ops::ArgMax;
using reference_ops::ArgMinMax;
+using reference_ops::Broadcast4DSlowGreater;
+using reference_ops::Broadcast4DSlowGreaterEqual;
+using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
+using reference_ops::Broadcast4DSlowGreaterWithScaling;
+using reference_ops::Broadcast4DSlowLess;
+using reference_ops::Broadcast4DSlowLessEqual;
+using reference_ops::Broadcast4DSlowLessEqualWithScaling;
+using reference_ops::Broadcast4DSlowLessWithScaling;
using reference_ops::BroadcastAdd4DSlow;
using reference_ops::BroadcastGreater;
using reference_ops::BroadcastGreaterEqual;
@@ -57,8 +66,12 @@ using reference_ops::FakeQuant;
using reference_ops::Gather;
using reference_ops::Greater;
using reference_ops::GreaterEqual;
+using reference_ops::GreaterEqualWithScaling;
+using reference_ops::GreaterWithScaling;
using reference_ops::Less;
using reference_ops::LessEqual;
+using reference_ops::LessEqualWithScaling;
+using reference_ops::LessWithScaling;
using reference_ops::Mean;
using reference_ops::RankOneSelect;
using reference_ops::Relu1;
@@ -66,6 +79,7 @@ using reference_ops::Relu6;
using reference_ops::ReluX;
using reference_ops::Select;
using reference_ops::SpaceToBatchND;
+using reference_ops::Split;
using reference_ops::StridedSlice;
using reference_ops::Transpose;
@@ -319,6 +333,7 @@ inline void AddBiasAndEvalActivationFunction(const float* bias_data,
#endif
}
+// Note: This is to be converted to RuntimeShapes along with Conv.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void AddBiasAndEvalActivationFunction(const float* bias_data,
@@ -1934,6 +1949,85 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
output_activation_max);
}
+inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
+ const int8_t* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float* scaling_factors_ptr,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims,
+ int8_t* im2col_data, const Dims<4>& im2col_dims) {
+ const int batch_size = input_dims.sizes[3];
+ const int filter_width = ArraySize(filter_dims, 1);
+ const int filter_height = ArraySize(filter_dims, 2);
+
+ const int8_t* gemm_input_data = nullptr;
+ int num_input;
+ const bool need_im2col = stride_width != 1 || stride_height != 1 ||
+ filter_width != 1 || filter_height != 1;
+
+ if (need_im2col) {
+ TFLITE_DCHECK(im2col_data);
+    // Symmetric quantization assumes a zero point of 0.
+ const int input_zero_point = 0;
+ Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_height, filter_width, input_zero_point,
+ im2col_data, im2col_dims);
+ gemm_input_data = im2col_data;
+ num_input = im2col_dims.sizes[0] * im2col_dims.sizes[1] *
+ im2col_dims.sizes[2] * im2col_dims.sizes[3];
+ } else {
+ TFLITE_DCHECK(!im2col_data);
+ gemm_input_data = input_data;
+ num_input = input_dims.sizes[0] * input_dims.sizes[1] *
+ input_dims.sizes[2] * input_dims.sizes[3];
+ }
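+  // A 1x1 filter with unit stride is already in GEMM layout, so the im2col
+  // buffer is only materialized in the general case above.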
+
+ // Flatten 4D matrices into 2D matrices for matrix multiplication.
+
+ // Flatten so that each filter has its own row.
+ const int filter_rows = filter_dims.sizes[3];
+ const int filter_cols =
+ filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2];
+
+ // In MatrixBatchVectorMultiplyAccumulate, each output value is the
+ // dot product of one row of the first matrix with one row of the second
+  // matrix. Therefore, the number of cols in each matrix must be equal.
+ //
+ // After Im2Col, each input patch becomes a row.
+ const int gemm_input_cols = filter_cols;
+ const int gemm_input_rows = num_input / gemm_input_cols;
+
+ const int output_cols = output_dims.sizes[0];
+ const int output_rows =
+ output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3];
+ TFLITE_DCHECK_EQ(output_cols, filter_rows);
+ TFLITE_DCHECK_EQ(output_rows, gemm_input_rows);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_cols);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+
+ // MatrixBatchVectorMultiplyAccumulate assumes that each row of the second
+ // input matrix has its own scale factor. This code duplicates the scale
+ // factors for each row in the same batch.
+ const int rows_per_batch = gemm_input_rows / batch_size;
+ for (int i = gemm_input_rows - 1; i >= 0; --i) {
+ scaling_factors_ptr[i] = scaling_factors_ptr[i / rows_per_batch];
+ }
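+  // Sketch: with batch_size == 2 and rows_per_batch == 3, {s0, s1, ...}
+  // expands in place to {s0, s0, s0, s1, s1, s1}; iterating backwards keeps
+  // source entries from being overwritten before they are read.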
+
+ tensor_utils::ZeroVector(output_data, output_rows * output_cols);
+
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ filter_data, filter_rows, filter_cols, gemm_input_data,
+ scaling_factors_ptr, /*n_batch=*/gemm_input_rows, output_data,
+ /*result_stride=*/1);
+
+ AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data,
+ output_dims, output_activation_min,
+ output_activation_max);
+}
+
template <FusedActivationFunctionType Ac>
void Conv(const float* input_data, const Dims<4>& input_dims,
const float* filter_data, const Dims<4>& filter_dims,
@@ -2142,38 +2236,6 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims,
im2col_data, im2col_dims, gemm_context);
}
-template <typename T>
-inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
- int block_size, T* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("DepthToSpace");
-
- const int input_depth = ArraySize(input_dims, 0);
- const int input_width = ArraySize(input_dims, 1);
- const int input_height = ArraySize(input_dims, 2);
-
- const int output_depth = ArraySize(output_dims, 0);
- const int batch_size = ArraySize(output_dims, 3);
-
- // Number of continuous values that we can copy in one interation.
- const int stride = block_size * output_depth;
-
- for (int batch = 0; batch < batch_size; ++batch) {
- for (int in_h = 0; in_h < input_height; ++in_h) {
- const T* input_ptr = input_data + Offset(input_dims, 0, 0, in_h, batch);
- for (int offset_h = 0; offset_h < block_size; ++offset_h) {
- const T* src = input_ptr;
- for (int in_w = 0; in_w < input_width; ++in_w) {
- memcpy(output_data, src, stride * sizeof(T));
- output_data += stride;
- src += input_depth;
- }
- input_ptr += stride;
- }
- }
- }
-}
-
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac, typename T>
void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
@@ -2249,25 +2311,75 @@ void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
}
template <typename T>
-inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
- int block_size, T* output_data,
- const Dims<4>& output_dims) {
+inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ gemmlowp::ScopedProfilingLabel label("DepthToSpace");
+
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int input_depth = input_shape.Dims(3);
+ const int input_width = input_shape.Dims(2);
+ const int input_height = input_shape.Dims(1);
+
+ const int output_depth = output_shape.Dims(3);
+ const int batch_size = output_shape.Dims(0);
+
+  // Number of contiguous values that we can copy in one iteration.
+ const int stride = op_params.block_size * output_depth;
+
+ for (int batch = 0; batch < batch_size; ++batch) {
+ for (int in_h = 0; in_h < input_height; ++in_h) {
+ const T* input_ptr = input_data + Offset(input_shape, batch, in_h, 0, 0);
+ for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
+ const T* src = input_ptr;
+ for (int in_w = 0; in_w < input_width; ++in_w) {
+ memcpy(output_data, src, stride * sizeof(T));
+ output_data += stride;
+ src += input_depth;
+ }
+ input_ptr += stride;
+ }
+ }
+ }
+}
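+// Copy-pattern sketch (illustrative shapes): with block_size == 2 and
+// output_depth == 3, stride == 6, so each memcpy above moves one contiguous
+// run of 2 * 3 values while the loops interleave input rows and block
+// offsets into the output.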
+
+template <typename T>
+inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
gemmlowp::ScopedProfilingLabel label("SpaceToDepth");
- const int output_depth = ArraySize(output_dims, 0);
- const int output_width = ArraySize(output_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
- const int input_depth = ArraySize(input_dims, 0);
- const int batch_size = ArraySize(input_dims, 3);
+ const int output_depth = output_shape.Dims(3);
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+
+ const int input_depth = input_shape.Dims(3);
+ const int batch_size = input_shape.Dims(0);
   // Number of contiguous values that we can copy in one iteration.
- const int stride = block_size * input_depth;
+ const int stride = op_params.block_size * input_depth;
for (int batch = 0; batch < batch_size; ++batch) {
for (int out_h = 0; out_h < output_height; ++out_h) {
- T* output_ptr = output_data + Offset(output_dims, 0, 0, out_h, batch);
- for (int offset_h = 0; offset_h < block_size; ++offset_h) {
+ T* output_ptr = output_data + Offset(output_shape, batch, out_h, 0, 0);
+ for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
T* dst = output_ptr;
for (int out_w = 0; out_w < output_width; ++out_w) {
memcpy(dst, input_data, stride * sizeof(T));
@@ -2280,55 +2392,8 @@ inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
}
}
-template <FusedActivationFunctionType Ac>
-void NonGlobalBatchNormalization(
- const float* input_data, const Dims<4>& input_dims, const float* mean_data,
- const Dims<4>& mean_dims, const float* multiplier_data,
- const Dims<4>& multiplier_dims, const float* offset_data,
- const Dims<4>& offset_dims, float* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("NonGlobalBatchNormalization");
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int inner_size = MatchingFlatSizeSkipDim(
- input_dims, 3, mean_dims, multiplier_dims, offset_dims, output_dims);
-
- for (int b = 0; b < batches; ++b) {
- for (int i = 0; i < inner_size; ++i) {
- *output_data = ActivationFunction<Ac>(
- (*input_data - mean_data[i]) * multiplier_data[i] + offset_data[i]);
- ++output_data;
- ++input_data;
- }
- }
-}
-
-template <FusedActivationFunctionType Ac>
-void GlobalBatchNormalization(const float* input_data,
- const Dims<4>& input_dims, const float* mean_data,
- const Dims<4>& mean_dims,
- const float* multiplier_data,
- const Dims<4>& multiplier_dims,
- const float* offset_data,
- const Dims<4>& offset_dims, float* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("GlobalBatchNormalization");
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth =
- MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0,
- offset_dims, 0, output_dims, 0);
-
- for (int i = 0; i < outer_size; ++i) {
- for (int c = 0; c < depth; ++c) {
- *output_data = ActivationFunction<Ac>(
- (*input_data - mean_data[c]) * multiplier_data[c] + offset_data[c]);
- ++output_data;
- ++input_data;
- }
- }
-}
-
-inline void Relu(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
+inline void Relu(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Relu (not fused)");
const auto input = MapAsVector(input_data, input_shape);
@@ -2336,11 +2401,12 @@ inline void Relu(const float* input_data, const RuntimeShape& input_shape,
output = input.cwiseMax(0.0f);
}
-template <FusedActivationFunctionType Ac>
-void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
+inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
+ const RuntimeShape& input_shape,
+ const float* input_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
gemmlowp::ScopedProfilingLabel label("L2Normalization");
- static_assert(Ac == FusedActivationFunctionType::kNone, "");
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int outer_size =
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
@@ -2409,16 +2475,18 @@ inline void GetInvSqrtQuantizedMultiplierExp(int32 input,
*output_shift *= kReverseShift;
}
-inline void L2Normalization(const uint8* input_data,
+inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
const RuntimeShape& input_shape,
- int32 input_zero_point, uint8* output_data,
- const RuntimeShape& output_shape) {
+ const uint8* input_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("L2Normalization/8bit");
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int depth =
MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
const int outer_size =
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int32 input_zero_point = op_params.input_zero_point;
for (int i = 0; i < outer_size; ++i) {
int32 square_l2_norm = 0;
for (int c = 0; c < depth; c++) {
@@ -2725,17 +2793,16 @@ inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
}
}
-inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const float* input1_data,
+ const RuntimeShape& input2_shape, const float* input2_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Mul");
- TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
int i = 0;
- const int size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
#ifdef USE_NEON
const auto activation_min = vdupq_n_f32(output_activation_min);
const auto activation_max = vdupq_n_f32(output_activation_max);
@@ -2786,25 +2853,16 @@ inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Mul(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-inline void Mul(const int32* input1_data, const Dims<4>& input1_dims,
- const int32* input2_data, const Dims<4>& input2_dims,
- int32 output_activation_min, int32 output_activation_max,
- int32* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Mul/int32");
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int32* input1_data,
+ const RuntimeShape& input2_shape, const int32* input2_data,
+ const RuntimeShape& output_shape, int32* output_data) {
+ gemmlowp::ScopedProfilingLabel label("Mul/int32/activation");
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] * input2_data[i], output_activation_min,
@@ -2812,22 +2870,24 @@ inline void Mul(const int32* input1_data, const Dims<4>& input1_dims,
}
}
-template <FusedActivationFunctionType Ac>
-void Mul(const int32* input1_data, const Dims<4>& input1_dims,
- const int32* input2_data, const Dims<4>& input2_dims,
- int32* output_data, const Dims<4>& output_dims) {
+inline void MulNoActivation(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const int32* input1_data,
+ const RuntimeShape& input2_shape,
+ const int32* input2_data,
+ const RuntimeShape& output_shape,
+ int32* output_data) {
gemmlowp::ScopedProfilingLabel label("Mul/int32");
- TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
- auto input1_map = MapAsVector(input1_data, input1_dims);
- auto input2_map = MapAsVector(input2_data, input2_dims);
- auto output_map = MapAsVector(output_data, output_dims);
- if (AreSameDims(input1_dims, input2_dims)) {
+ auto input1_map = MapAsVector(input1_data, input1_shape);
+ auto input2_map = MapAsVector(input2_data, input2_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
+ if (input1_shape == input2_shape) {
output_map.array() = input1_map.array() * input2_map.array();
- } else if (FlatSize(input2_dims) == 1) {
+ } else if (input2_shape.FlatSize() == 1) {
auto scalar = input2_data[0];
output_map.array() = input1_map.array() * scalar;
- } else if (FlatSize(input1_dims) == 1) {
+ } else if (input1_shape.FlatSize() == 1) {
auto scalar = input1_data[0];
output_map.array() = scalar * input2_map.array();
} else {
@@ -2836,14 +2896,16 @@ void Mul(const int32* input1_data, const Dims<4>& input1_dims,
}
}
-inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
- const int16* input2_data, const Dims<4>& input2_dims,
- int16* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Mul/Int16");
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int16* input1_data,
+ const RuntimeShape& input2_shape, const int16* input2_data,
+ const RuntimeShape& output_shape, int16* output_data) {
+ gemmlowp::ScopedProfilingLabel label("Mul/Int16/NoActivation");
// This is a copy of the reference implementation. We do not currently have a
// properly optimized version.
- const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
@@ -2855,17 +2917,20 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
}
}
-inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
- const int16* input2_data, const Dims<4>& input2_dims,
- int32 output_offset, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int16* input1_data,
+ const RuntimeShape& input2_shape, const int16* input2_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("Mul/Int16Uint8");
// This is a copy of the reference implementation. We do not currently have a
// properly optimized version.
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ const int32 output_offset = params.output_offset;
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
@@ -2883,64 +2948,6 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
}
}
-// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
-// dimensionality if the runtime code does a single loop over one dimension
-// that handles broadcasting as the base case. The code generator would then
-// generate max(D1, D2) nested for loops.
-// TODO(benoitjacob): BroadcastMul is intentionally duplicated from
-// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
-// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
-// reference_ops.h.
-template <typename T>
-void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastMul");
-
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] *
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
- output_activation_min, output_activation_max);
- }
- }
- }
- }
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac, typename T>
-void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- T output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- BroadcastMul(input1_data, input1_dims, input2_data, input2_dims,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
// Element-wise mul that can often be used for inner loop of broadcast Mul as
// well as the non-broadcast Mul.
inline void MulElementwise(int size, const ArithmeticParams& params,
@@ -3016,9 +3023,53 @@ inline void MulElementwise(int size, const ArithmeticParams& params,
inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
const uint8 broadcast_value,
const uint8* input2_data, uint8* output_data) {
- const int32 input1_val = params.input1_offset + broadcast_value;
+ const int16 input1_val = params.input1_offset + broadcast_value;
+
+ int i = 0;
+ TFLITE_DCHECK_GT(params.input1_offset, -256);
+ TFLITE_DCHECK_LT(params.input1_offset, 256);
+ TFLITE_DCHECK_GT(params.input2_offset, -256);
+ TFLITE_DCHECK_LT(params.input2_offset, 256);
+ TFLITE_DCHECK_GT(params.output_offset, -256);
+ TFLITE_DCHECK_LT(params.output_offset, 256);
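+  // Given the offset bounds checked above and broadcast_value in [0, 255],
+  // input1_val lies in (-256, 511), so it fits in int16 for the
+  // vmull_n_s16-based NEON path below.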
+#ifdef USE_NEON
+ const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
+ const auto output_offset_vector = vdupq_n_s16(params.output_offset);
+ const auto output_activation_min_vector =
+ vdup_n_u8(params.quantized_activation_min);
+ const auto output_activation_max_vector =
+ vdup_n_u8(params.quantized_activation_max);
+ for (; i <= size - 8; i += 8) {
+ // We load / store 8 at a time, multiplying as two sets of 4 int32s.
+ const auto input2_val_original = vld1_u8(input2_data + i);
+ const auto input2_val_s16 =
+ vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
+ const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector);
+
+ const auto input2_val_low = vget_low_s16(input2_val);
+ const auto input2_val_high = vget_high_s16(input2_val);
+
+ auto p1 = vmull_n_s16(input2_val_low, input1_val);
+ auto p2 = vmull_n_s16(input2_val_high, input1_val);
+
+ p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
+ p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
+ using gemmlowp::RoundingDivideByPOT;
+ p1 = RoundingDivideByPOT(p1, -params.output_shift);
+ p2 = RoundingDivideByPOT(p2, -params.output_shift);
+
+ const auto p1_narrowed = vmovn_s32(p1);
+ const auto p2_narrowed = vmovn_s32(p2);
+ const auto p =
+ vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector);
+ const auto clamped =
+ vmax_u8(output_activation_min_vector,
+ vmin_u8(output_activation_max_vector, vqmovun_s16(p)));
+ vst1_u8(output_data + i, clamped);
+ }
+#endif  // USE_NEON
- for (int i = 0; i < size; ++i) {
+ for (; i < size; ++i) {
const int32 input2_val = params.input2_offset + input2_data[i];
const int32 unclamped_result =
params.output_offset +
@@ -3125,15 +3176,28 @@ inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
// reference_ops.h.
template <typename T>
-void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastDiv");
+void BroadcastDiv4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& unextended_input1_shape,
+ const T* input1_data,
+ const RuntimeShape& unextended_input2_shape,
+ const T* input2_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastDiv4DSlow");
+ T output_activation_min;
+ T output_activation_max;
+ GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
   // In TensorFlow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
@@ -3146,14 +3210,14 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
   // We name our variables by their TensorFlow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for the
// best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ output_data[Offset(output_shape, b, y, x, c)] =
ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] /
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] /
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
output_activation_min, output_activation_max);
}
}
@@ -3161,6 +3225,21 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+template <typename T>
+void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastDiv4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
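+// E.g. (illustrative): BroadcastDiv(in1, d1, in2, d2, 0.0f, 6.0f, out, od)
+// sets float activation bounds via SetActivationParams and clamps each
+// broadcast quotient to [0, 6].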
+
// TODO(aselle): This is not actually optimized yet.
inline void SubNonBroadcast(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
@@ -3990,29 +4069,28 @@ inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape,
}
}
-inline void LocalResponseNormalization(const float* input_data,
- const Dims<4>& input_dims, int range,
- float bias, float alpha, float beta,
- float* output_data,
- const Dims<4>& output_dims) {
+inline void LocalResponseNormalization(
+ const tflite::LocalResponseNormalizationParams& op_params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("LocalResponseNormalization");
- MatchingFlatSize(input_dims, output_dims);
+ MatchingFlatSize(input_shape, output_shape);
- const auto data_in = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
- auto data_out = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ const auto data_in = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
+ auto data_out = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
// Carry out local response normalization, vector by vector.
// Since the data are stored column major, making row-wise operation
// probably not memory efficient anyway, we do an explicit for loop over
// the columns.
- const int double_range = range * 2;
+ const int double_range = op_params.range * 2;
Eigen::VectorXf padded_square(data_in.rows() + double_range);
padded_square.setZero();
for (int r = 0; r < data_in.cols(); ++r) {
// Do local response normalization for data_in(:, r)
// first, compute the square and store them in buffer for repeated use
- padded_square.block(range, 0, data_in.rows(), 1) =
- data_in.col(r).cwiseProduct(data_in.col(r)) * alpha;
+ padded_square.block(op_params.range, 0, data_in.rows(), 1) =
+ data_in.col(r).cwiseProduct(data_in.col(r)) * op_params.alpha;
   // Then, compute the scales and write them to data_out
float accumulated_scale = 0;
for (int i = 0; i < double_range; ++i) {
@@ -4020,18 +4098,18 @@ inline void LocalResponseNormalization(const float* input_data,
}
for (int i = 0; i < data_in.rows(); ++i) {
accumulated_scale += padded_square(i + double_range);
- data_out(i, r) = bias + accumulated_scale;
+ data_out(i, r) = op_params.bias + accumulated_scale;
accumulated_scale -= padded_square(i);
}
}
// In a few cases, the pow computation could benefit from speedups.
- if (beta == 1) {
+ if (op_params.beta == 1) {
data_out.array() = data_in.array() * data_out.array().inverse();
- } else if (beta == 0.5) {
+ } else if (op_params.beta == 0.5) {
data_out.array() = data_in.array() * data_out.array().sqrt().inverse();
} else {
- data_out.array() = data_in.array() * data_out.array().pow(-beta);
+ data_out.array() = data_in.array() * data_out.array().pow(-op_params.beta);
}
}
@@ -4500,8 +4578,8 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void Logistic(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
+inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Logistic");
auto input_map = MapAsVector(input_data, input_shape);
auto output_map = MapAsVector(output_data, output_shape);
@@ -4646,8 +4724,8 @@ inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
- int16* output_data, const RuntimeShape& output_shape) {
+inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+ const RuntimeShape& output_shape, int16* output_data) {
gemmlowp::ScopedProfilingLabel label("Logistic/Int16");
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -4706,8 +4784,14 @@ inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
}
}
-inline void Tanh(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
+// Legacy version.
+inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
+ int16* output_data, const RuntimeShape& output_shape) {
+ Logistic(input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Tanh");
auto input_map = MapAsVector(input_data, input_shape);
auto output_map = MapAsVector(output_data, output_shape);
@@ -4962,19 +5046,19 @@ inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
}
template <typename SrcT, typename DstT>
-inline void Cast(const SrcT* input_data, const Dims<4>& input_dims,
- DstT* output_data, const Dims<4>& output_dims) {
+inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
+ const RuntimeShape& output_shape, DstT* output_data) {
gemmlowp::ScopedProfilingLabel label("Cast");
- auto input_map = MapAsVector(input_data, input_dims);
- auto output_map = MapAsVector(output_data, output_dims);
+ auto input_map = MapAsVector(input_data, input_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
output_map.array() = input_map.array().template cast<DstT>();
}
-inline void Floor(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
+inline void Floor(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Floor");
- auto input_map = MapAsVector(input_data, input_dims);
- auto output_map = MapAsVector(output_data, output_dims);
+ auto input_map = MapAsVector(input_data, input_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
output_map.array() = Eigen::floor(input_map.array());
}
@@ -5077,12 +5161,14 @@ inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
int32 x, int32 y, int32 depth, int32 batch,
+ const RuntimeShape& input_shape,
const float* input_data,
- const Dims<4>& input_dims,
- float* output_data,
- const Dims<4>& output_dims) {
- const int32 input_width = ArraySize(input_dims, 1);
- const int32 output_width = ArraySize(output_dims, 1);
+ const RuntimeShape& output_shape,
+ float* output_data) {
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int32 input_width = input_shape.Dims(2);
+ const int32 output_width = output_shape.Dims(2);
const int32 input_x_offset = (x1 - x0) * depth;
const int32 input_y_offset = (y1 - y0) * depth * input_width;
@@ -5090,7 +5176,6 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
const int32 output_y_offset = depth * output_width;
#ifdef USE_NEON
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
TFLITE_DCHECK(x1 >= x0);
TFLITE_DCHECK(y1 >= y0);
@@ -5100,7 +5185,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
const float* input_ptr = nullptr;
float32x4x2_t x0y0;
- input_ptr = &input_data[Offset(input_dims, ic, x0, y0, batch)];
+ input_ptr = &input_data[Offset(input_shape, batch, y0, x0, ic)];
x0y0.val[0] = vld1q_f32(input_ptr);
x0y0.val[1] = vld1q_f32(input_ptr + 4);
@@ -5120,7 +5205,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
x1y1.val[1] = vld1q_f32(input_ptr + 4);
// Top left corner.
- float* output_ptr = &output_data[Offset(output_dims, ic, x, y, batch)];
+ float* output_ptr = &output_data[Offset(output_shape, batch, y, x, ic)];
vst1q_f32(output_ptr, x0y0.val[0]);
vst1q_f32(output_ptr + 4, x0y0.val[1]);
@@ -5159,14 +5244,15 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
}
// Handle 4 input channels at a time.
for (; ic <= depth - 4; ic += 4) {
- const float* input_ptr = &input_data[Offset(input_dims, ic, x0, y0, batch)];
+ const float* input_ptr =
+ &input_data[Offset(input_shape, batch, y0, x0, ic)];
float32x4_t x0y0 = vld1q_f32(input_ptr);
float32x4_t x1y0 = vld1q_f32(input_ptr + input_x_offset);
float32x4_t x0y1 = vld1q_f32(input_ptr + input_y_offset);
float32x4_t x1y1 = vld1q_f32(input_ptr + input_x_offset + input_y_offset);
// Top left corner.
- float* output_ptr = &output_data[Offset(output_dims, ic, x, y, batch)];
+ float* output_ptr = &output_data[Offset(output_shape, batch, y, x, ic)];
vst1q_f32(output_ptr, x0y0);
// Top right corner.
@@ -5190,7 +5276,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
}
// Handle one input channel at a time.
for (; ic < depth; ic++) {
- const int32 input_offset = Offset(input_dims, ic, x0, y0, batch);
+ const int32 input_offset = Offset(input_shape, batch, y0, x0, ic);
float x0y0 = input_data[input_offset];
float x1y0 = input_data[input_offset + input_x_offset];
@@ -5198,7 +5284,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
// Top left corner.
- const int32 output_offset = Offset(output_dims, ic, x, y, batch);
+ const int32 output_offset = Offset(output_shape, batch, y, x, ic);
output_data[output_offset] = x0y0;
// Top right corner.
@@ -5214,7 +5300,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
}
#else
for (int ch = 0; ch < depth; ch++) {
- const int32 input_offset = Offset(input_dims, ch, x0, y0, batch);
+ const int32 input_offset = Offset(input_shape, batch, y0, x0, ch);
float x0y0 = input_data[input_offset];
float x1y0 = input_data[input_offset + input_x_offset];
@@ -5222,7 +5308,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
// Top left corner.
- const int32 output_offset = Offset(output_dims, ch, x, y, batch);
+ const int32 output_offset = Offset(output_shape, batch, y, x, ch);
output_data[output_offset] = x0y0;
// Top right corner.
@@ -5239,31 +5325,30 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
#endif
}
-inline void ResizeBilinear2x2(const float* input_data,
- const Dims<4>& input_dims, float* output_data,
- const Dims<4>& output_dims, int32 batches,
- int32 input_height, int32 input_width,
- int32 depth, int32 output_height,
- int32 output_width) {
+inline void ResizeBilinear2x2(int32 batches, int32 input_height,
+ int32 input_width, int32 depth,
+ int32 output_height, int32 output_width,
+ const RuntimeShape& input_shape,
+ const float* input_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
for (int b = 0; b < batches; b++) {
for (int y0 = 0, y = 0; y <= output_height - 2; y += 2, y0++) {
for (int x0 = 0, x = 0; x <= output_width - 2; x += 2, x0++) {
int32 x1 = std::min(x0 + 1, input_width - 1);
int32 y1 = std::min(y0 + 1, input_height - 1);
- ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_data,
- input_dims, output_data, output_dims);
+ ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_shape,
+ input_data, output_shape, output_data);
}
}
}
}
-inline void ResizeBilinearGeneric(const float* input_data,
- const Dims<4>& input_dims, float* output_data,
- const Dims<4>& output_dims, int32 batches,
- int32 input_height, int32 input_width,
- int32 depth, int32 output_height,
- int32 output_width, float height_scale,
- float width_scale) {
+inline void ResizeBilinearGeneric(
+ int32 batches, int32 input_height, int32 input_width, int32 depth,
+ int32 output_height, int32 output_width, float height_scale,
+ float width_scale, const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
memset(output_data, 0,
batches * output_height * output_width * depth * sizeof(float));
@@ -5280,22 +5365,22 @@ inline void ResizeBilinearGeneric(const float* input_data,
float* output_ptr = &output_data[output_offset];
// Run kernel on the 4 corners of the bilinear resize algorithm.
- int32 input_offset = Offset(input_dims, 0, x0, y0, b);
+ int32 input_offset = Offset(input_shape, b, y0, x0, 0);
float scale = (1 - (input_y - y0)) * (1 - (input_x - x0));
const float* input_ptr = &input_data[input_offset];
ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
- input_offset = Offset(input_dims, 0, x1, y0, b);
+ input_offset = Offset(input_shape, b, y0, x1, 0);
scale = (1 - (input_y - y0)) * (input_x - x0);
input_ptr = &input_data[input_offset];
ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
- input_offset = Offset(input_dims, 0, x0, y1, b);
+ input_offset = Offset(input_shape, b, y1, x0, 0);
scale = (input_y - y0) * (1 - (input_x - x0));
input_ptr = &input_data[input_offset];
ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
- input_offset = Offset(input_dims, 0, x1, y1, b);
+ input_offset = Offset(input_shape, b, y1, x1, 0);
scale = (input_y - y0) * (input_x - x0);
input_ptr = &input_data[input_offset];
ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
@@ -5308,10 +5393,10 @@ inline void ResizeBilinearGeneric(const float* input_data,
template <typename T>
inline void ResizeBilinearGenericSmallChannel(
- const T* input_data, const Dims<4>& input_dims, T* output_data,
- const Dims<4>& output_dims, int32 batches, int32 input_height,
- int32 input_width, int32 depth, int32 output_height, int32 output_width,
- float height_scale, float width_scale) {
+ int32 batches, int32 input_height, int32 input_width, int32 depth,
+ int32 output_height, int32 output_width, float height_scale,
+ float width_scale, const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& output_shape, T* output_data) {
memset(output_data, 0,
batches * output_height * output_width * depth * sizeof(T));
@@ -5326,9 +5411,10 @@ inline void ResizeBilinearGenericSmallChannel(
int32 x0 = static_cast<int32>(input_x);
int32 x1 = std::min(x0 + 1, input_width - 1);
- int32 input_offset[4] = {
- Offset(input_dims, 0, x0, y0, b), Offset(input_dims, 0, x1, y0, b),
- Offset(input_dims, 0, x0, y1, b), Offset(input_dims, 0, x1, y1, b)};
+ int32 input_offset[4] = {Offset(input_shape, b, y0, x0, 0),
+ Offset(input_shape, b, y0, x1, 0),
+ Offset(input_shape, b, y1, x0, 0),
+ Offset(input_shape, b, y1, x1, 0)};
float scale[4] = {(1 - (input_y - y0)) * (1 - (input_x - x0)),
(1 - (input_y - y0)) * (input_x - x0),
(input_y - y0) * (1 - (input_x - x0)),
@@ -5346,97 +5432,93 @@ inline void ResizeBilinearGenericSmallChannel(
}
}
-inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
+inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const float* input_data,
+ const RuntimeShape& output_size_shape,
const int32* output_size_data,
- const Dims<4>& output_size_dims, float* output_data,
- const Dims<4>& output_dims, bool align_corners) {
+ const RuntimeShape& unextended_output_shape,
+ float* output_data) {
gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
- int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- int32 input_height = ArraySize(input_dims, 2);
- int32 input_width = ArraySize(input_dims, 1);
- int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0);
-
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2);
- int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)];
- int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)];
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
+ int32 input_height = input_shape.Dims(1);
+ int32 input_width = input_shape.Dims(2);
+ int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
+
+ TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
+ int32 output_height = output_size_data[0];
+ int32 output_width = output_size_data[1];
// Specialize for 2x2 upsample.
- if (!align_corners && output_height == 2 * input_height &&
+ if (!op_params.align_corners && output_height == 2 * input_height &&
output_width == 2 * input_width) {
- ResizeBilinear2x2(input_data, input_dims, output_data, output_dims, batches,
- input_height, input_width, depth, output_height,
- output_width);
+ ResizeBilinear2x2(batches, input_height, input_width, depth, output_height,
+ output_width, input_shape, input_data, output_shape,
+ output_data);
} else {
float height_scale = static_cast<float>(input_height) / output_height;
float width_scale = static_cast<float>(input_width) / output_width;
- if (align_corners && output_height > 1) {
+ if (op_params.align_corners && output_height > 1) {
height_scale = static_cast<float>(input_height - 1) / (output_height - 1);
}
- if (align_corners && output_width > 1) {
+ if (op_params.align_corners && output_width > 1) {
width_scale = static_cast<float>(input_width - 1) / (output_width - 1);
}
- ResizeBilinearGeneric(input_data, input_dims, output_data, output_dims,
- batches, input_height, input_width, depth,
+ ResizeBilinearGeneric(batches, input_height, input_width, depth,
output_height, output_width, height_scale,
- width_scale);
+ width_scale, input_shape, input_data, output_shape,
+ output_data);
}
}
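
As a concrete check of the scale selection above (numbers chosen only for illustration): upsampling height 4 to height 8 gives height_scale = 4/8 = 0.5 by default, but (4-1)/(8-1) = 3/7 ≈ 0.4286 with align_corners, which maps the last output row exactly onto the last input row.

    // Illustrative only: the two formulas used above.
    const int in_h = 4, out_h = 8;
    const float default_scale = static_cast<float>(in_h) / out_h;            // 0.5
    const float corners_scale = static_cast<float>(in_h - 1) / (out_h - 1);  // ~0.4286
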
// TODO(prabhumk): This is not a real quantized bilinear. It does not use int8
// or int16 arithmetic.
-inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
+inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const uint8* input_data,
+ const RuntimeShape& output_size_shape,
const int32* output_size_data,
- const Dims<4>& output_size_dims, uint8* output_data,
- const Dims<4>& output_dims, bool align_corners) {
+ const RuntimeShape& unextended_output_shape,
+ uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
- int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- int32 input_height = ArraySize(input_dims, 2);
- int32 input_width = ArraySize(input_dims, 1);
- int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0);
-
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2);
- int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)];
- int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)];
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
+ int32 input_height = input_shape.Dims(1);
+ int32 input_width = input_shape.Dims(2);
+ int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
+
+ TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
+ int32 output_height = output_size_data[0];
+ int32 output_width = output_size_data[1];
float height_scale =
- (align_corners && output_height > 1)
+ (op_params.align_corners && output_height > 1)
? (static_cast<float>(input_height - 1) / (output_height - 1))
: (static_cast<float>(input_height) / output_height);
float width_scale =
- (align_corners && output_width > 1)
+ (op_params.align_corners && output_width > 1)
? (static_cast<float>(input_width - 1) / (output_width - 1))
: (static_cast<float>(input_width) / output_width);
ResizeBilinearGenericSmallChannel<uint8>(
- input_data, input_dims, output_data, output_dims, batches, input_height,
- input_width, depth, output_height, output_width, height_scale,
- width_scale);
-}
-
-// legacy, for compatibility with old checked-in code
-inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
- const int32* output_size_data,
- const Dims<4>& output_size_dims, float* output_data,
- const Dims<4>& output_dims) {
- ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
- output_data, output_dims, /*align_corners=*/false);
-}
-
-// legacy, for compatibility with old checked-in code
-inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
- const int32* output_size_data,
- const Dims<4>& output_size_dims, uint8* output_data,
- const Dims<4>& output_dims) {
- ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
- output_data, output_dims, /*align_corners=*/false);
+ batches, input_height, input_width, depth, output_height, output_width,
+ height_scale, width_scale, input_shape, input_data, output_shape,
+ output_data);
}
// Helper methods for BatchToSpaceND.
@@ -5461,20 +5543,29 @@ inline void GetIndexRange(int spatial_index_dim, int block_shape_dim,
}
template <typename T>
-inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
- const int32* block_shape_data,
- const Dims<4>& block_shape_dims,
- const int32* crops_data, const Dims<4>& crops_dims,
- T* output_data, const Dims<4>& output_dims) {
+inline void BatchToSpaceND(
+ const RuntimeShape& unextended_input1_shape, const T* input1_data,
+ const RuntimeShape& unextended_input2_shape, const int32* block_shape_data,
+ const RuntimeShape& unextended_input3_shape, const int32* crops_data,
+ const RuntimeShape& unextended_output_shape, T* output_data) {
gemmlowp::ScopedProfilingLabel label("BatchToSpaceND");
- const int output_batch_size = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int input_batch_size = ArraySize(input_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int depth = ArraySize(input_dims, 0);
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input1_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input1_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_batch_size = output_shape.Dims(0);
+
+ const int depth = input1_shape.Dims(3);
+ const int input_width = input1_shape.Dims(2);
+ const int input_height = input1_shape.Dims(1);
+ const int input_batch_size = input1_shape.Dims(0);
+
const int block_shape_width = block_shape_data[1];
const int block_shape_height = block_shape_data[0];
const int crops_top = crops_data[0];
@@ -5509,8 +5600,9 @@ inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
spatial_offset % block_shape_width - crops_left;
TFLITE_DCHECK_GE(out_w, 0);
TFLITE_DCHECK_LT(out_w, output_width);
- T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch);
- const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch);
+ T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0);
+ const T* in =
+ input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0);
memcpy(out, in, depth * sizeof(T));
}
}
@@ -5554,12 +5646,14 @@ inline void PadImpl(const tflite::PadParams& op_params,
// Runtime calls are currently fixed at 4 dimensions. Copy inputs so
// we can pad them to 4 dims (yes, we are "padding the padding").
std::vector<int> left_padding_copy(4, 0);
+ const int left_padding_extend = 4 - op_params.left_padding_count;
for (int i = 0; i < op_params.left_padding_count; ++i) {
- left_padding_copy[i] = op_params.left_padding[i];
+ left_padding_copy[left_padding_extend + i] = op_params.left_padding[i];
}
std::vector<int> right_padding_copy(4, 0);
+ const int right_padding_extend = 4 - op_params.right_padding_count;
for (int i = 0; i < op_params.right_padding_count; ++i) {
- right_padding_copy[i] = op_params.right_padding[i];
+ right_padding_copy[right_padding_extend + i] = op_params.right_padding[i];
}
const int output_batch = ext_output_shape.Dims(0);
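
The right-alignment above matters because ExtendedShape prepends 1s, so a lower-rank input's real dimensions end up in the trailing positions of the 4-D shape, and its padding has to land there too. A worked instance (values invented for illustration):

    std::vector<int> left_padding_copy(4, 0);
    const int count = 2;             // a rank-2 padding spec
    const int pads[2] = {1, 2};
    const int extend = 4 - count;    // = 2
    for (int i = 0; i < count; ++i) left_padding_copy[extend + i] = pads[i];
    // left_padding_copy == {0, 0, 1, 2}: the synthetic leading dims stay
    // unpadded; the real dims, which ExtendedShape placed last, get 1 and 2.
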
@@ -5578,7 +5672,6 @@ inline void PadImpl(const tflite::PadParams& op_params,
const int right_d_padding = right_padding_copy[3];
const int input_depth = ext_input_shape.Dims(3);
- // const T pad_value = ExtractFloatOrInt<T>(op_params.pad_value);
const T pad_value = *pad_value_ptr;
if (left_b_padding != 0) {
@@ -5673,50 +5766,6 @@ inline void Pad(const tflite::PadParams& op_params,
output_data);
}
-// Legacy signature, function covered both Pad and PadV2.
-template <typename T>
-inline void PadV2(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims, const T pad_value) {
- TFLITE_DCHECK_EQ(left_paddings.size(), 4);
- TFLITE_DCHECK_EQ(right_paddings.size(), 4);
- tflite::PadParams op_params;
- op_params.left_padding_count = 4;
- op_params.right_padding_count = 4;
- for (int i = 0; i < 4; ++i) {
- op_params.left_padding[i] = left_paddings[3 - i];
- op_params.right_padding[i] = right_paddings[3 - i];
- }
- // SetFloatOrInt(pad_value, &op_params.pad_value);
- const T pad_value_copy = pad_value;
-
- Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
- DimsToShape(output_dims), output_data);
-}
-
-// Old Pad that calls legacy PadV2.
-template <typename T>
-inline void Pad(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims, const int32_t pad_value) {
- const T converted_pad_value = static_cast<T>(pad_value);
- PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
- output_dims, converted_pad_value);
-}
-
-// Old Pad that only padded with 0.
-template <typename T>
-inline void Pad(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims) {
- const T pad_value = static_cast<T>(0);
- PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
- output_dims, pad_value);
-}
-
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
@@ -5761,22 +5810,6 @@ inline void Slice(const tflite::SliceParams& op_params,
}
template <typename T>
-inline void Slice(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& begin, const std::vector<int>& size,
- T* output_data, const Dims<4>& output_dims) {
- tflite::SliceParams op_params;
- op_params.begin_count = 4;
- op_params.size_count = 4;
- for (int i = 0; i < 4; ++i) {
- op_params.begin[i] = begin[3 - i];
- op_params.size[i] = size[3 - i];
- }
-
- Slice(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(output_dims), output_data);
-}
-
-template <typename T>
void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
const T* input2_data, const RuntimeShape& output_shape,
T* output_data) {
@@ -5799,22 +5832,6 @@ void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
}
template <typename T>
-void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, T* output_data,
- const Dims<4>& output_dims) {
- Minimum(DimsToShape(input1_dims), input1_data, input2_data,
- DimsToShape(output_dims), output_data);
-}
-
-template <typename T>
-void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, T* output_data,
- const Dims<4>& output_dims) {
- Maximum(DimsToShape(input1_dims), input1_data, input2_data,
- DimsToShape(output_dims), output_data);
-}
-
-template <typename T>
void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
const Dims<4>& filter_dims, int stride_width,
int stride_height, int pad_width, int pad_height,
@@ -5934,4 +5951,4 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
#pragma GCC diagnostic pop
#endif
-#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPTIMIZED_OPS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
index 010b40b901..f87760a6c3 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
@@ -17,7 +17,7 @@ limitations under the License.
// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#if defined(_MSC_VER)
#define __restrict__ __restrict
@@ -86,6 +86,14 @@ void NeonBatchVectorBatchVectorDotProduct(const float* vector1,
int n_batch, float* result,
int result_stride);
+// Cwise product of a vector and a batch-vector.
+void PortableVectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector,
+ int n_batch, float* result);
+void NeonVectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector, int n_batch,
+ float* result);
+
// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC
// operation, the assumption here is that result array is initialized to valid
// values.
@@ -109,6 +117,10 @@ void PortableClipVector(const float* vector, int v_size, float abs_limit,
void NeonClipVector(const float* vector, int v_size, float abs_limit,
float* result);
+// Add the given vector to each batch row of the batch vector, in place.
+void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector);
+
// Batch vector initialization with another vector.
void PortableVectorBatchVectorAssign(const float* vector, int v_size,
int n_batch, float* batch_vector);
@@ -164,6 +176,10 @@ void PortableReductionSumVector(const float* input_vector, float* output_vector,
void NeonReductionSumVector(const float* input_vector, float* output_vector,
int output_size, int reduction_size);
+void PortableMeanStddevNormalization(const float* input_vector,
+ float* output_vector, int v_size,
+ int n_batch, float normalization_epsilon);
+
} // namespace tensor_utils
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc
index f882f9910e..544ef16ce1 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc
@@ -23,6 +23,32 @@ limitations under the License.
namespace tflite {
+namespace {
+// These constants are used to manipulate the binary representation of doubles.
+// Double-precision binary64 floating point format is:
+// Bit | 63 | 62-52 | 51-0 |
+// | Sign | Exponent | Fraction |
+// To avoid 64-bit integers as much as possible, I break this into high and
+// low 32-bit chunks. High is:
+// Bit | 31 | 30-20 | 19-0 |
+// | Sign | Exponent | High Fraction |
+// Low is:
+// Bit | 31-0 |
+// | Low Fraction |
+// We then access the components through logical bit-wise operations to
+// extract the parts needed, with the positions and masks derived from the
+// layout shown above.
+constexpr uint64_t kSignMask = 0x8000000000000000LL;
+constexpr uint64_t kExponentMask = 0x7ff0000000000000LL;
+constexpr int32_t kExponentShift = 52;
+constexpr int32_t kExponentBias = 1023;
+constexpr uint32_t kExponentIsBadNum = 0x7ff;
+constexpr uint64_t kFractionMask = 0x000fffffffc00000LL;
+constexpr uint32_t kFractionShift = 22;
+constexpr uint32_t kFractionRoundingMask = 0x003fffff;
+constexpr uint32_t kFractionRoundingThreshold = 0x00200000;
+} // namespace
+
void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier,
int* shift) {
if (double_multiplier == 0.) {
@@ -30,8 +56,16 @@ void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier,
*shift = 0;
return;
}
+#ifdef TFLITE_EMULATE_FLOAT
+ // If we're trying to avoid the use of floating-point instructions (for
+ // example on microcontrollers) then use an alternative implementation
+ // that only requires integer and bitwise operations. To enable this, you
+ // need to set the define during the build process for your platform.
+ int64_t q_fixed = IntegerFrExp(double_multiplier, shift);
+#else // TFLITE_EMULATE_FLOAT
const double q = std::frexp(double_multiplier, shift);
auto q_fixed = static_cast<int64_t>(TfLiteRound(q * (1ll << 31)));
+#endif // TFLITE_EMULATE_FLOAT
TFLITE_CHECK(q_fixed <= (1ll << 31));
if (q_fixed == (1ll << 31)) {
q_fixed /= 2;
@@ -60,6 +94,163 @@ void QuantizeMultiplierSmallerThanOneExp(double double_multiplier,
*left_shift = shift;
}
+int64_t IntegerFrExp(double input, int* shift) {
+ // Make sure our assumptions about the double layout hold.
+ TFLITE_CHECK_EQ(8, sizeof(double));
+
+ // We want to access the bits of the input double value directly, which is
+ // tricky to do safely, so use a union to handle the casting.
+ union {
+ double double_value;
+ uint64_t double_as_uint;
+ } cast_union;
+ cast_union.double_value = input;
+ const uint64_t u = cast_union.double_as_uint;
+
+ // If the bitfield is all zeros apart from the sign bit, this is a normalized
+ // zero value, so return standard values for this special case.
+ if ((u & ~kSignMask) == 0) {
+ *shift = 0;
+ return 0;
+ }
+
+  // Deal with NaNs and Infs, which are always indicated with a fixed pattern
+  // in the exponent, and distinguished by whether the fraction bits are zero
+  // or non-zero.
+ const uint32_t exponent_part = ((u & kExponentMask) >> kExponentShift);
+ if (exponent_part == kExponentIsBadNum) {
+ *shift = std::numeric_limits<int>::max();
+ if (u & kFractionMask) {
+ // NaN, so just return zero (with the exponent set to INT_MAX).
+ return 0;
+ } else {
+ // Infinity, so return +/- INT_MAX.
+ if (u & kSignMask) {
+ return std::numeric_limits<int64_t>::min();
+ } else {
+ return std::numeric_limits<int64_t>::max();
+ }
+ }
+ }
+
+ // The shift is fairly easy to extract from the high bits of the double value,
+ // just by masking it out and applying a bias. The std::frexp() implementation
+ // always returns values between 0.5 and 1.0 though, whereas the exponent
+ // assumes 1.0 to 2.0 is the standard range, so I add on one to match that
+ // interface.
+ *shift = (exponent_part - kExponentBias) + 1;
+
+ // There's an implicit high bit in the double format definition, so make sure
+ // we include that at the top, and then reconstruct the rest of the fractional
+ // value from the remaining fragments.
+ int64_t fraction = 0x40000000 + ((u & kFractionMask) >> kFractionShift);
+
+ // We're cutting off some bits at the bottom, so to exactly match the standard
+ // frexp implementation here we'll apply rounding by adding one to the least
+ // significant bit of the result if the discarded portion is over half of the
+ // maximum.
+ if ((u & kFractionRoundingMask) > kFractionRoundingThreshold) {
+ fraction += 1;
+ }
+ // Negate the fraction if the sign bit was set.
+ if (u & kSignMask) {
+ fraction *= -1;
+ }
+
+ return fraction;
+}
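+
+// A usage sketch of the replacement (the constants follow from the convention
+// documented in the header and exercised by the tests below):
+//
+//   int shift;
+//   const int64_t fraction = IntegerFrExp(1.5, &shift);
+//   // std::frexp(1.5) yields 0.75 with exponent 1; here 0.75 is expressed as
+//   // a proportion of 1 << 31, so fraction == 0x60000000 and shift == 1.
+//   const double roundtrip = DoubleFromFractionAndShift(fraction, shift);
+//   // roundtrip == 1.5 (exact here; inputs needing more than the ~31 kept
+//   // significant bits round-trip only approximately).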
+
+double DoubleFromFractionAndShift(int64_t fraction, int shift) {
+ union {
+ double double_value;
+ uint64_t double_as_uint;
+ } result;
+
+ // Detect NaNs and infinities.
+ if (shift == std::numeric_limits<int>::max()) {
+ if (fraction == 0) {
+ return NAN;
+ } else if (fraction > 0) {
+ return INFINITY;
+ } else {
+ return -INFINITY;
+ }
+ }
+
+ // Return a normalized zero for a zero fraction.
+ if (fraction == 0) {
+ result.double_as_uint = 0;
+ return result.double_value;
+ }
+
+ bool is_negative = (fraction < 0);
+ int64_t encoded_fraction = is_negative ? -fraction : fraction;
+ int64_t encoded_shift = (shift - 1);
+ while (encoded_fraction < 0x40000000) {
+ encoded_fraction *= 2;
+ encoded_shift -= 1;
+ }
+ while (encoded_fraction > 0x80000000) {
+ encoded_fraction /= 2;
+ encoded_shift += 1;
+ }
+ encoded_fraction -= 0x40000000;
+ if (encoded_shift < -1022) {
+ encoded_shift = -1023;
+ } else if (encoded_shift > 1022) {
+ encoded_shift = 1023;
+ }
+ encoded_shift += kExponentBias;
+ uint64_t encoded_sign = is_negative ? kSignMask : 0;
+ result.double_as_uint = encoded_sign | (encoded_shift << kExponentShift) |
+ (encoded_fraction << kFractionShift);
+ return result.double_value;
+}
+
+double IntegerDoubleMultiply(double a, double b) {
+ int a_shift;
+ const int64_t a_fraction = IntegerFrExp(a, &a_shift);
+ int b_shift;
+ const int64_t b_fraction = IntegerFrExp(b, &b_shift);
+ // Detect NaNs and infinities.
+ if (a_shift == std::numeric_limits<int>::max() ||
+ (b_shift == std::numeric_limits<int>::max())) {
+ return NAN;
+ }
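+  // Each fraction is its real value scaled by 2^31, so the 64-bit product
+  // below is scaled by 2^62; shifting right by 32 leaves a 2^30 scaling,
+  // half of the 2^31 convention, which the extra +1 on the shift restores.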
+ const int result_shift = a_shift + b_shift + 1;
+ const int64_t result_fraction = (a_fraction * b_fraction) >> 32;
+ return DoubleFromFractionAndShift(result_fraction, result_shift);
+}
+
+int IntegerDoubleCompare(double a, double b) {
+ int a_shift;
+ const int64_t a_fraction = IntegerFrExp(a, &a_shift);
+ int b_shift;
+ const int64_t b_fraction = IntegerFrExp(b, &b_shift);
+
+ // Detect NaNs and infinities.
+ if (a_shift == std::numeric_limits<int>::max() ||
+ (b_shift == std::numeric_limits<int>::max())) {
+ return 1;
+ }
+
+ if ((a_fraction == 0) && (b_fraction < 0)) {
+ return 1;
+ } else if ((a_fraction < 0) && (b_fraction == 0)) {
+ return -1;
+ } else if (a_shift < b_shift) {
+ return -1;
+ } else if (a_shift > b_shift) {
+ return 1;
+ } else if (a_fraction < b_fraction) {
+ return -1;
+ } else if (a_fraction > b_fraction) {
+ return 1;
+ } else {
+ return 0;
+ }
+}
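+
+// One consequence of the early-out above that the header comment does not
+// spell out: any NaN or infinity on either side makes the function return 1,
+// so non-finite values do not get an IEEE-style ordering here. For example
+// (implied by the code above, illustrative):
+//
+//   IntegerDoubleCompare(-INFINITY, 0.0);  // returns 1, i.e. "greater"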
+
void PreprocessSoftmaxScaling(double beta, double input_scale,
int input_integer_bits,
int32_t* quantized_multiplier, int* left_shift) {
@@ -72,8 +263,20 @@ void PreprocessSoftmaxScaling(double beta, double input_scale,
// result is double equivalent of Q0.31 (actually with more precision). Thus
// this generates a Q(input_integer_bits).(31-input_integer_bits)
// representation.
+#ifdef TFLITE_EMULATE_FLOAT
+ const double input_beta = IntegerDoubleMultiply(beta, input_scale);
+ int shift;
+ int64_t fraction = IntegerFrExp(input_beta, &shift);
+ shift += (31 - input_integer_bits);
+ double input_beta_real_multiplier =
+ DoubleFromFractionAndShift(fraction, shift);
+ if (IntegerDoubleCompare(input_beta_real_multiplier, (1ll << 31) - 1.0) > 0) {
+ input_beta_real_multiplier = (1ll << 31) - 1.0;
+ }
+#else // TFLITE_EMULATE_FLOAT
const double input_beta_real_multiplier = std::min(
beta * input_scale * (1 << (31 - input_integer_bits)), (1ll << 31) - 1.0);
+#endif // TFLITE_EMULATE_FLOAT
QuantizeMultiplierGreaterThanOne(input_beta_real_multiplier,
quantized_multiplier, left_shift);
@@ -97,6 +300,12 @@ void PreprocessLogSoftmaxScalingExp(double beta, double input_scale,
}
int CalculateInputRadius(int input_integer_bits, int input_left_shift) {
+#ifdef TFLITE_EMULATE_FLOAT
+ int64_t result = (1 << input_integer_bits) - 1;
+ result <<= (31 - input_integer_bits);
+ result >>= input_left_shift;
+ return result;
+#else // TFLITE_EMULATE_FLOAT
const double max_input_rescaled = 1.0 * ((1 << input_integer_bits) - 1) *
(1ll << (31 - input_integer_bits)) /
(1ll << input_left_shift);
@@ -104,6 +313,7 @@ int CalculateInputRadius(int input_integer_bits, int input_left_shift) {
// After scaling the difference, the result would be at the maximum. Thus we
// must ensure that our value has lower magnitude.
return static_cast<int>(std::floor(max_input_rescaled));
+#endif // TFLITE_EMULATE_FLOAT
}
void NudgeQuantizationRange(const float min, const float max,
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h
index 9ee4a47fbb..d74a1bac97 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h
@@ -195,6 +195,44 @@ void QuantizeMultiplierGreaterThanOne(double double_multiplier,
void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier,
int* shift);
+// Splits a double input value into a returned fraction, and a shift value from
+// the exponent, using only bitwise and integer operations to support
+// microcontrollers and other environments without floating-point support.
+//
+// This is designed to be a replacement for how std::frexp() is used within the
+// QuantizeMultiplier() function, and so has a different signature than the
+// standard version, returning a 64-bit integer rather than a double. This
+// result has a maximum value of 1<<31, with the fraction expressed as a
+// proportion of that maximum.
+//
+// std::frexp() returns NaNs and infinities unmodified, but since we're
+// returning integers that can't represent those values, instead we return
+// a shift of std::numeric_limits<int>::max() for all bad numbers, with an int64
+// result of 0 for NaNs, std::numeric_limits<int64_t>::max() for +INFINITY, and
+// std::numeric_limits<int64_t>::min() for -INFINITY. Denormalized inputs will
+// result in return values that end up truncating some bits at the end,
+// reflecting the loss of precision inherent in denormalization.
+int64_t IntegerFrExp(double input, int* shift);
+
+// Converts an integer fraction in the format produced by IntegerFrExp (where
+// 0x40000000 is 1.0) and an exponent shift (between -1022 and +1022) into an
+// IEEE binary64 double format result. The implementation uses only integer and
+// bitwise operators, so no floating point hardware support or emulation is
+// needed. This is here so quantized operations can run non-time-critical
+// preparation calculations on microcontrollers and other platforms without
+// float support.
+double DoubleFromFractionAndShift(int64_t fraction, int shift);
+
+// Performs a multiplication of two numbers in double format, using only integer
+// and bitwise instructions. This is aimed at supporting housekeeping functions
+// for quantized operations on microcontrollers without floating-point hardware.
+double IntegerDoubleMultiply(double a, double b);
+
+// Returns -1 if a is less than b, 0 if a and b are equal, and +1 if a is
+// greater than b. It is implemented using only integer and logical instructions
+// so that it can be easily run on microcontrollers for quantized operations.
+int IntegerDoubleCompare(double a, double b);
+
// This first creates a multiplier in a double equivalent of
// Q(input_integer_bits).(31-input_integer_bits) representation, with extra
// precision in the double's fractional bits. It then splits the result into
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
index 00fc3e91dc..14281f25c6 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
@@ -191,6 +191,139 @@ TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMaxBoundary) {
EXPECT_EQ(qp.zero_point, 255);
}
+TEST(QuantizationUtilTest, IntegerFrExp) {
+ int shift;
+ int64_t result = IntegerFrExp(0.0, &shift);
+ EXPECT_EQ(0, result);
+ EXPECT_EQ(0, shift);
+
+ result = IntegerFrExp(1.0, &shift);
+ EXPECT_NEAR(0x40000000, result, 1);
+ EXPECT_EQ(1, shift);
+
+ result = IntegerFrExp(0.25, &shift);
+ EXPECT_NEAR(0x40000000, result, 1);
+ EXPECT_EQ(-1, shift);
+
+ result = IntegerFrExp(-1.0, &shift);
+ EXPECT_NEAR(-(1 << 30), result, 1);
+ EXPECT_EQ(1, shift);
+
+ result = IntegerFrExp(123.45, &shift);
+ EXPECT_NEAR(2071147315, result, 1);
+ EXPECT_EQ(7, shift);
+
+ result = IntegerFrExp(NAN, &shift);
+ EXPECT_NEAR(0, result, 1);
+ EXPECT_EQ(0x7fffffff, shift);
+
+ result = IntegerFrExp(INFINITY, &shift);
+ EXPECT_NEAR(std::numeric_limits<int64_t>::max(), result, 1);
+ EXPECT_EQ(0x7fffffff, shift);
+
+ result = IntegerFrExp(-INFINITY, &shift);
+ EXPECT_NEAR(std::numeric_limits<int64_t>::min(), result, 1);
+ EXPECT_EQ(0x7fffffff, shift);
+}
+
+TEST(QuantizationUtilTest, IntegerFrExpVersusDouble) {
+ int shift;
+  int64_t result = IntegerFrExp(0.0, &shift);
+ EXPECT_EQ(result, 0);
+ EXPECT_EQ(shift, 0);
+
+ int double_shift;
+ double double_result = std::frexp(0.0, &double_shift);
+ EXPECT_EQ(double_result, 0);
+ EXPECT_EQ(double_shift, 0);
+
+ result = IntegerFrExp(1.0, &shift);
+ EXPECT_NEAR(result, 0x40000000, 1);
+ EXPECT_EQ(shift, 1);
+ double_result = std::frexp(1.0, &double_shift);
+ EXPECT_NEAR(double_result, 0.5, 1e-5);
+ EXPECT_EQ(double_shift, 1);
+
+ result = IntegerFrExp(0.25, &shift);
+ EXPECT_NEAR(result, 0x40000000, 1);
+ EXPECT_EQ(shift, -1);
+ double_result = std::frexp(0.25, &double_shift);
+ EXPECT_NEAR(double_result, 0.5, 1e-5);
+ EXPECT_EQ(double_shift, -1);
+
+ result = IntegerFrExp(-1.0, &shift);
+ EXPECT_NEAR(result, -(1 << 30), 1);
+ EXPECT_EQ(shift, 1);
+ double_result = std::frexp(-1.0, &double_shift);
+ EXPECT_NEAR(double_result, -0.5, 1e-5);
+ EXPECT_EQ(double_shift, 1);
+
+ result = IntegerFrExp(123.45, &shift);
+ EXPECT_NEAR(result, (0.964453 * (1L << 31)), 1000);
+ EXPECT_EQ(shift, 7);
+ double_result = std::frexp(123.45, &double_shift);
+ EXPECT_NEAR(double_result, 0.964453, 1e-5);
+ EXPECT_EQ(double_shift, 7);
+}
+
+TEST(QuantizationUtilTest, DoubleFromFractionAndShift) {
+ double result = DoubleFromFractionAndShift(0, 0);
+ EXPECT_EQ(0, result);
+
+ result = DoubleFromFractionAndShift(0x40000000, 1);
+ EXPECT_NEAR(1.0, result, 1e-5);
+
+ result = DoubleFromFractionAndShift(0x40000000, 2);
+ EXPECT_NEAR(2.0, result, 1e-5);
+
+ int shift;
+ int64_t fraction = IntegerFrExp(3.0, &shift);
+ result = DoubleFromFractionAndShift(fraction, shift);
+ EXPECT_NEAR(3.0, result, 1e-5);
+
+ fraction = IntegerFrExp(123.45, &shift);
+ result = DoubleFromFractionAndShift(fraction, shift);
+ EXPECT_NEAR(123.45, result, 1e-5);
+
+ fraction = IntegerFrExp(-23.232323, &shift);
+ result = DoubleFromFractionAndShift(fraction, shift);
+ EXPECT_NEAR(-23.232323, result, 1e-5);
+
+ fraction = IntegerFrExp(NAN, &shift);
+ result = DoubleFromFractionAndShift(fraction, shift);
+ EXPECT_TRUE(std::isnan(result));
+
+ fraction = IntegerFrExp(INFINITY, &shift);
+ result = DoubleFromFractionAndShift(fraction, shift);
+ EXPECT_FALSE(std::isfinite(result));
+}
+
+TEST(QuantizationUtilTest, IntegerDoubleMultiply) {
+ EXPECT_NEAR(1.0, IntegerDoubleMultiply(1.0, 1.0), 1e-5);
+ EXPECT_NEAR(2.0, IntegerDoubleMultiply(1.0, 2.0), 1e-5);
+ EXPECT_NEAR(2.0, IntegerDoubleMultiply(2.0, 1.0), 1e-5);
+ EXPECT_NEAR(4.0, IntegerDoubleMultiply(2.0, 2.0), 1e-5);
+ EXPECT_NEAR(0.5, IntegerDoubleMultiply(1.0, 0.5), 1e-5);
+ EXPECT_NEAR(0.25, IntegerDoubleMultiply(0.5, 0.5), 1e-5);
+ EXPECT_NEAR(-1.0, IntegerDoubleMultiply(1.0, -1.0), 1e-5);
+ EXPECT_NEAR(-1.0, IntegerDoubleMultiply(-1.0, 1.0), 1e-5);
+ EXPECT_NEAR(1.0, IntegerDoubleMultiply(-1.0, -1.0), 1e-5);
+ EXPECT_NEAR(15000000.0, IntegerDoubleMultiply(3000.0, 5000.0), 1e-5);
+ EXPECT_TRUE(std::isnan(IntegerDoubleMultiply(NAN, 5000.0)));
+ EXPECT_TRUE(std::isnan(IntegerDoubleMultiply(3000.0, NAN)));
+}
+
+TEST(QuantizationUtilTest, IntegerDoubleCompare) {
+ EXPECT_EQ(-1, IntegerDoubleCompare(0.0, 1.0));
+ EXPECT_EQ(1, IntegerDoubleCompare(1.0, 0.0));
+ EXPECT_EQ(0, IntegerDoubleCompare(1.0, 1.0));
+ EXPECT_EQ(0, IntegerDoubleCompare(0.0, 0.0));
+ EXPECT_EQ(-1, IntegerDoubleCompare(-10.0, 10.0));
+ EXPECT_EQ(1, IntegerDoubleCompare(123.45, 10.0));
+ EXPECT_EQ(1, IntegerDoubleCompare(NAN, INFINITY));
+ EXPECT_EQ(1, IntegerDoubleCompare(INFINITY, NAN));
+}
+
#ifdef GTEST_HAS_DEATH_TEST
TEST(QuantizationUtilTest, ChooseQuantizationParamsInvalidRange) {
EXPECT_DEATH(ChooseQuantizationParams<uint8>(10.0, -30.0), "");
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
index b862ae38c7..683ccdc74d 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
@@ -27,6 +27,28 @@ namespace tflite {
namespace reference_ops {
template <FusedActivationFunctionType Ac>
+void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ static_assert(Ac == FusedActivationFunctionType::kNone, "");
+ tflite::L2NormalizationParams op_params;
+ // No params need to be set for float.
+
+ L2Normalization(op_params, input_shape, input_data, output_shape,
+ output_data);
+}
+
+inline void L2Normalization(const uint8* input_data,
+ const RuntimeShape& input_shape,
+ int32 input_zero_point, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ tflite::L2NormalizationParams op_params;
+ op_params.input_zero_point = input_zero_point;
+
+ L2Normalization(op_params, input_shape, input_data, output_shape,
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
void L2Normalization(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
@@ -42,20 +64,29 @@ inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
inline void Relu(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Relu(input_data, DimsToShape(input_dims), output_data,
- DimsToShape(output_dims));
+ Relu(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
inline void Relu1(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Relu1(input_data, DimsToShape(input_dims), output_data,
- DimsToShape(output_dims));
+ Relu1(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
inline void Relu6(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Relu6(input_data, DimsToShape(input_dims), output_data,
- DimsToShape(output_dims));
+ Relu6(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data,
+ const RuntimeShape& input_shape, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ tflite::ActivationParams params;
+ params.quantized_activation_max = max_value;
+ params.quantized_activation_min = min_value;
+ ReluX(params, input_shape, input_data, output_shape, output_data);
}
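
All of these shims lean on the same conversion: Dims<4> stores sizes innermost-first, while RuntimeShape is outermost-first (NHWC), so building the shape reverses the order. A minimal sketch of that reversal (an illustrative stand-in for the DimsToShape helper used throughout, not its definition):

    inline tflite::RuntimeShape DimsToShapeSketch(const tflite::Dims<4>& dims) {
      // Reverse innermost-first sizes into batch-height-width-channel order.
      return tflite::RuntimeShape(
          {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
    }
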
template <FusedActivationFunctionType Ac>
@@ -311,6 +342,30 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims), output_data);
}
+// Legacy.
+// Transitional version that will be moved shortly to legacy_reference_ops, as
+// part of RuntimeShape revisions.
+inline void BroadcastMul4DSlow(const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+ op_params.input1_offset = input1_offset;
+ op_params.input2_offset = input2_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = output_shift;
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
int32 input1_offset, const uint8* input2_data,
const Dims<4>& input2_dims, int32 input2_offset,
@@ -583,8 +638,8 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
inline void Logistic(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Logistic(input_data, DimsToShape(input_dims), output_data,
- DimsToShape(output_dims));
+ Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
@@ -598,14 +653,14 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
int16* output_data, const Dims<4>& output_dims) {
- Logistic(input_data, DimsToShape(input_dims), output_data,
- DimsToShape(output_dims));
+ Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
inline void Tanh(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- Tanh(input_data, DimsToShape(input_dims), output_data,
- DimsToShape(output_dims));
+ Tanh(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
@@ -624,6 +679,377 @@ inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims));
}
+template <typename T>
+inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthToSpaceParams op_params;
+ op_params.block_size = block_size;
+
+ DepthToSpace(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::SpaceToDepthParams op_params;
+ op_params.block_size = block_size;
+
+ SpaceToDepth(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void Mul(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Mul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac, typename T>
+void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ T output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
+ const int16* input2_data, const Dims<4>& input2_dims,
+ int16* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ // No params in this version.
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
+ const int16* input2_data, const Dims<4>& input2_dims,
+ int32 output_offset, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ op_params.output_offset = output_offset;
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void LocalResponseNormalization(const float* input_data,
+ const Dims<4>& input_dims, int range,
+ float bias, float alpha, float beta,
+ float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::LocalResponseNormalizationParams op_params;
+ op_params.range = range;
+ op_params.bias = bias;
+ op_params.alpha = alpha;
+ op_params.beta = beta;
+
+ LocalResponseNormalization(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename SrcT, typename DstT>
+void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
+ const Dims<4>& output_dims) {
+ Cast(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Floor(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, T* output_data,
+ const Dims<4>& output_dims, bool align_corners) {
+ tflite::ResizeBilinearParams op_params;
+ op_params.align_corners = align_corners;
+ ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_size_dims), output_size_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, float* output_data,
+ const Dims<4>& output_dims) {
+ ResizeBilinear<float>(input_data, input_dims, output_size_data,
+ output_size_dims, output_data, output_dims,
+ /*align_corners=*/false);
+}
+
+inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, uint8* output_data,
+ const Dims<4>& output_dims) {
+ ResizeBilinear<uint8>(input_data, input_dims, output_size_data,
+ output_size_dims, output_data, output_dims,
+ /*align_corners=*/false);
+}
+
+template <typename T>
+inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* paddings_data,
+ const Dims<4>& paddings_dims, T* output_data,
+ const Dims<4>& output_dims,
+ const int32_t pad_value) {
+ tflite::SpaceToBatchParams op_params;
+ op_params.output_offset = pad_value;
+
+ SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(block_shape_dims), block_shape_data,
+ DimsToShape(paddings_dims), paddings_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* paddings_data,
+ const Dims<4>& paddings_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::SpaceToBatchParams op_params;
+ op_params.output_offset = 0;
+
+ SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(block_shape_dims), block_shape_data,
+ DimsToShape(paddings_dims), paddings_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* crops_data, const Dims<4>& crops_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ BatchToSpaceND(DimsToShape(input_dims), input_data,
+ DimsToShape(block_shape_dims), block_shape_data,
+ DimsToShape(crops_dims), crops_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// Legacy signature, function covered both Pad and PadV2.
+template <typename T>
+inline void PadV2(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const T pad_value) {
+ TFLITE_DCHECK_EQ(left_paddings.size(), 4);
+ TFLITE_DCHECK_EQ(right_paddings.size(), 4);
+ tflite::PadParams op_params;
+ op_params.left_padding_count = 4;
+ op_params.right_padding_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ op_params.left_padding[i] = left_paddings[3 - i];
+ op_params.right_padding[i] = right_paddings[3 - i];
+ }
+ // SetFloatOrInt(pad_value, &op_params.pad_value);
+ const T pad_value_copy = pad_value;
+
+ Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
+ DimsToShape(output_dims), output_data);
+}
+
+// Old Pad that calls legacy PadV2.
+template <typename T>
+inline void Pad(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const int32_t pad_value) {
+ const T converted_pad_value = static_cast<T>(pad_value);
+ PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
+ output_dims, converted_pad_value);
+}
+
+// Old Pad that only padded with 0.
+template <typename T>
+inline void Pad(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims) {
+ const T pad_value = static_cast<T>(0);
+ PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
+ output_dims, pad_value);
+}
+
+template <typename T>
+void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ Minimum(DimsToShape(input1_dims), input1_data, input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ Maximum(DimsToShape(input1_dims), input1_data, input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T, typename Op>
+void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims,
+ Op op) {
+ MaximumMinimumBroadcast4DSlow(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data, op);
+}
+
+template <typename T1, typename T2, typename T3>
+void ArgMax(const T3* axis, const T1* input_data,
+ const tflite::Dims<4>& input_dims, T2* output_data,
+ const tflite::Dims<4>& output_dims) {
+ ArgMinMax(DimsToShape(input_dims), input_data, axis, DimsToShape(output_dims),
+ output_data, std::greater<T1>());
+}
+
+template <typename T1, typename T2, typename T3, typename Cmp>
+void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
+ T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) {
+ ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data, cmp);
+}
+
+template <typename T>
+inline void Pow(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ Pow(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
+ input2_data, DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ BroadcastPow4DSlow(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void Logical(const bool* input1_data, const Dims<4>& input1_dims,
+ const bool* input2_data, const Dims<4>& input2_dims,
+ bool* output_data, const Dims<4>& output_dims,
+ const std::function<bool(bool, bool)>& func) {
+ Logical(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
+ input2_data, DimsToShape(output_dims), output_data, func);
+}
+
+inline void BroadcastLogical(const bool* input1_data,
+ const Dims<4>& input1_dims,
+ const bool* input2_data,
+ const Dims<4>& input2_dims, bool* output_data,
+ const Dims<4>& output_dims,
+ const std::function<bool(bool, bool)>& func) {
+ BroadcastLogical4DSlow(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data, func);
+}
+
+// R: Result type. T1: Input 1 type. T2: Input 2 type.
+template <typename R, typename T1, typename T2>
+inline void BroadcastBinaryFunction(const T1* input1_data,
+ const Dims<4>& input1_dims,
+ const T2* input2_data,
+ const Dims<4>& input2_dims, R* output_data,
+ const Dims<4>& output_dims,
+ R (*func)(T1, T2)) {
+ BroadcastBinaryFunction(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data, func);
+}
+
+// R: Result type. T1: Input 1 type. T2: Input 2 type.
+template <typename R, typename T1, typename T2>
+inline void BinaryFunction(const T1* input1_data, const Dims<4>& input1_dims,
+ const T2* input2_data, const Dims<4>& input2_dims,
+ R* output_data, const Dims<4>& output_dims,
+ R (*func)(T1, T2)) {
+ BinaryFunction(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data, func);
+}
+
+template <typename T>
+inline void Slice(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& begin, const std::vector<int>& size,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::SliceParams op_params;
+ op_params.begin_count = 4;
+ op_params.size_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ op_params.begin[i] = begin[3 - i];
+ op_params.size[i] = size[3 - i];
+ }
+
+ Slice(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
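+
+// The 3 - i reversal above follows the same convention as the shape
+// conversion: the legacy begin/size vectors are indexed innermost-first,
+// while SliceParams expects outermost-first. For example (illustrative
+// values):
+//
+//   // Legacy order {depth, width, height, batch}:
+//   const std::vector<int> begin = {0, 2, 1, 0};
+//   // After the loop above:
+//   // op_params.begin == {0, 1, 2, 0}  (batch, height, width, depth)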
+
} // namespace reference_ops
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
index aa93e857d7..77e60adc18 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include <string.h>
#include <algorithm>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/round.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
@@ -151,6 +151,16 @@ void PortableVectorVectorCwiseProductAccumulate(const float* vector1,
}
}
+void PortableVectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector,
+ int n_batch, float* result) {
+ for (int b = 0; b < n_batch; b++) {
+ for (int v = 0; v < v_size; v++) {
+ *result++ = vector[v] * *batch_vector++;
+ }
+ }
+}
+
void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector,
int v_size,
const float* batch_vector,
@@ -163,6 +173,16 @@ void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector,
}
}
+void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector) {
+ for (int b = 0; b < n_batch; b++) {
+ for (int i = 0; i < v_size; ++i) {
+ batch_vector[i] += vector[i];
+ }
+ batch_vector += v_size;
+ }
+}
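+
+// A quick illustration of the in-place semantics (invented values):
+//
+//   float vector[3] = {1, 2, 3};
+//   float batch_vector[6] = {10, 10, 10, 20, 20, 20};
+//   PortableVectorBatchVectorAdd(vector, /*v_size=*/3, /*n_batch=*/2,
+//                                batch_vector);
+//   // batch_vector == {11, 12, 13, 21, 22, 23}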
+
void PortableVectorBatchVectorAssign(const float* vector, int v_size,
int n_batch, float* batch_vector) {
for (int b = 0; b < n_batch; b++) {
@@ -233,5 +253,31 @@ void PortableReductionSumVector(const float* input_vector, float* output_vector,
}
}
+void PortableMeanStddevNormalization(const float* input_vector,
+ float* output_vector, int v_size,
+ int n_batch, float normalization_epsilon) {
+ for (int batch = 0; batch < n_batch; ++batch) {
+ float sum = 0.0f;
+ float sum_sq = 0.0f;
+ for (int i = 0; i < v_size; ++i) {
+ sum += input_vector[i];
+ sum_sq += input_vector[i] * input_vector[i];
+ }
+ const float mean = sum / v_size;
+ float stddev_inv = 0.0f;
+ const float variance = sum_sq / v_size - mean * mean;
+ if (variance == 0) {
+ stddev_inv = 1.0f / sqrt(normalization_epsilon);
+ } else {
+ stddev_inv = 1.0f / sqrt(variance);
+ }
+ for (int i = 0; i < v_size; ++i) {
+ output_vector[i] = (input_vector[i] - mean) * stddev_inv;
+ }
+ input_vector += v_size;
+ output_vector += v_size;
+ }
+}
+
} // namespace tensor_utils
} // namespace tflite
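For reference, a self-contained restatement of what PortableMeanStddevNormalization computes per batch: mean and variance in one pass, with normalization_epsilon substituted when the variance is exactly zero. The toy input and printed values below are illustrative only:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Per-batch normalization, matching the portable implementation above:
// out[i] = (in[i] - mean) / stddev, falling back to 1/sqrt(epsilon)
// when the variance is exactly zero.
void MeanStddevNormalizeBatch(const float* in, float* out, int v_size,
                              float epsilon) {
  float sum = 0.0f, sum_sq = 0.0f;
  for (int i = 0; i < v_size; ++i) {
    sum += in[i];
    sum_sq += in[i] * in[i];
  }
  const float mean = sum / v_size;
  const float variance = sum_sq / v_size - mean * mean;
  const float stddev_inv = (variance == 0.0f) ? 1.0f / std::sqrt(epsilon)
                                              : 1.0f / std::sqrt(variance);
  for (int i = 0; i < v_size; ++i) out[i] = (in[i] - mean) * stddev_inv;
}

int main() {
  const std::vector<float> in = {1.f, 2.f, 3.f, 4.f};  // mean 2.5, var 1.25
  std::vector<float> out(in.size());
  MeanStddevNormalizeBatch(in.data(), out.data(), in.size(), 1e-8f);
  for (float v : out) std::printf("%.4f ", v);
  std::printf("\n");  // ~ -1.3416 -0.4472 0.4472 1.3416
}
```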
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
index a375aaffa6..714b1164ee 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -17,7 +17,7 @@ limitations under the License.
// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#if defined(_MSC_VER)
#define __restrict__ __restrict
@@ -69,6 +69,11 @@ void PortableBatchVectorBatchVectorDotProduct(const float* vector1,
int n_batch, float* result,
int result_stride);
+// Cwise product of a vector and a batch-vector.
+void PortableVectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector,
+ int n_batch, float* result);
+
// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC
// operation, the assumption here is that result array is initialized to valid
// values.
@@ -82,6 +87,10 @@ void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector,
void PortableVectorBatchVectorAssign(const float* vector, int v_size,
int n_batch, float* batch_vector);
+// Add a vector to each batch slice of the batch vector.
+void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector);
+
// Apply sigmoid to elements of a vector.
void PortableApplySigmoidToVector(const float* vector, int v_size,
float* result);
@@ -120,6 +129,12 @@ void PortableVectorShiftLeft(float* vector, int v_size, float shift_value);
void PortableReductionSumVector(const float* input_vector, float* output_vector,
int output_size, int reduction_size);
+// Layer norm for each batch.
+// normalization_epsilon guards against division by zero when the variance is zero.
+void PortableMeanStddevNormalization(const float* input_vector,
+ float* output_vector, int v_size,
+ int n_batch, float normalization_epsilon);
+
float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); }
bool IsZeroVector(const float* vector, int v_size) {
@@ -161,6 +176,13 @@ void VectorVectorCwiseProductAccumulate(const float* vector1,
PortableVectorVectorCwiseProductAccumulate(vector1, vector2, v_size, result);
}
+void VectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector, int n_batch,
+ float* result) {
+ PortableVectorBatchVectorCwiseProduct(vector, v_size, batch_vector, n_batch,
+ result);
+}
+
void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size,
const float* batch_vector,
int n_batch, float* result) {
@@ -181,6 +203,11 @@ void BatchVectorBatchVectorDotProduct(const float* vector1,
result, result_stride);
}
+void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector) {
+ PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
+}
+
void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
float* batch_vector) {
PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector);
@@ -228,6 +255,13 @@ void ReductionSumVector(const float* input_vector, float* output_vector,
reduction_size);
}
+void MeanStddevNormalization(const float* input_vector, float* output_vector,
+ int v_size, int n_batch,
+ float normalization_epsilon) {
+ PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch,
+ normalization_epsilon);
+}
+
} // namespace tensor_utils
} // namespace tflite
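The header-level wrappers simply forward to the Portable* implementations, so the contract of the new batch helpers can be stated in isolation. A minimal sketch of the VectorBatchVectorAdd semantics (one shared vector added to every contiguous batch slice), independent of the tflite build:

```cpp
#include <cassert>

// Same contract as PortableVectorBatchVectorAdd: batch_vector holds n_batch
// contiguous slices of length v_size, and the shared vector is added to each.
void VectorBatchVectorAddSketch(const float* vector, int v_size, int n_batch,
                                float* batch_vector) {
  for (int b = 0; b < n_batch; ++b) {
    for (int i = 0; i < v_size; ++i) batch_vector[i] += vector[i];
    batch_vector += v_size;
  }
}

int main() {
  const float bias[2] = {10.f, 20.f};
  float batch[4] = {1.f, 2.f, 3.f, 4.f};  // two batches of v_size == 2
  VectorBatchVectorAddSketch(bias, 2, 2, batch);
  assert(batch[0] == 11.f && batch[1] == 22.f);
  assert(batch[2] == 13.f && batch[3] == 24.f);
}
```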
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 5634b8384a..0abacf85e1 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -19,11 +19,11 @@ limitations under the License.
#include <sys/types.h>
#include <algorithm>
#include <cmath>
+#include <functional>
#include <limits>
#include <memory>
#include <type_traits>
-#include "third_party/eigen3/Eigen/Core"
#include "fixedpoint/fixedpoint.h"
#include "public/gemmlowp.h"
#include "tensorflow/contrib/lite/kernels/internal/common.h"
@@ -110,6 +110,11 @@ inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
{dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
}
+inline void ShapeFromDims(const tflite::Dims<4>& dims, RuntimeShape* shape) {
+ shape->BuildFrom(
+ {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
+}
+
template <typename T>
int CountLeadingZeros(T integer_input) {
static_assert(std::is_unsigned<T>::value,
@@ -407,18 +412,29 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims,
}
template <typename T>
-inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
- int block_size, T* output_data,
- const Dims<4>& output_dims) {
- const int input_depth = ArraySize(input_dims, 0);
- const int input_width = ArraySize(input_dims, 1);
- const int input_height = ArraySize(input_dims, 2);
- const int input_batch = ArraySize(input_dims, 3);
+inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int input_depth = input_shape.Dims(3);
+ const int input_width = input_shape.Dims(2);
+ const int input_height = input_shape.Dims(1);
+ const int input_batch = input_shape.Dims(0);
- const int output_depth = ArraySize(output_dims, 0);
- const int output_width = ArraySize(output_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_batch = ArraySize(output_dims, 3);
+ const int output_depth = output_shape.Dims(3);
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_batch = output_shape.Dims(0);
+
+ const int32 block_size = op_params.block_size;
TFLITE_DCHECK_EQ(input_width * block_size, output_width);
TFLITE_DCHECK_EQ(input_height * block_size, output_height);
@@ -437,9 +453,9 @@ inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
const int in_h = out_h / block_size;
const int in_b = out_b;
+ const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d);
const int output_index =
- Offset(output_dims, out_d, out_w, out_h, out_b);
- const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b);
+ Offset(output_shape, out_b, out_h, out_w, out_d);
output_data[output_index] = input_data[input_index];
}
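The rewritten kernels index through RuntimeShape with Offset(shape, b, h, w, d) instead of the reversed Offset(dims, d, w, h, b). Assuming the usual dense NHWC layout (trailing channel dimension fastest, as the comments in this file state), the linearization works as in this hypothetical stand-in for the library's Offset:

```cpp
#include <cassert>

// Hypothetical stand-in for Offset(shape, b, h, w, d) on a dense NHWC
// RuntimeShape: the trailing (channel) dimension has stride 1.
int OffsetNHWC(int H, int W, int D, int b, int h, int w, int d) {
  return ((b * H + h) * W + w) * D + d;
}

int main() {
  // For shape {N=2, H=3, W=4, D=5}, element (1, 2, 3, 4) is the last one.
  assert(OffsetNHWC(3, 4, 5, 1, 2, 3, 4) == 2 * 3 * 4 * 5 - 1);
  // Stepping d by 1 moves by one element; stepping w by 1 moves by D.
  assert(OffsetNHWC(3, 4, 5, 0, 0, 1, 0) - OffsetNHWC(3, 4, 5, 0, 0, 0, 0) == 5);
}
```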
@@ -449,18 +465,29 @@ inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
}
template <typename T>
-inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
- int block_size, T* output_data,
- const Dims<4>& output_dims) {
- const int input_depth = ArraySize(input_dims, 0);
- const int input_width = ArraySize(input_dims, 1);
- const int input_height = ArraySize(input_dims, 2);
- const int input_batch = ArraySize(input_dims, 3);
+inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int input_depth = input_shape.Dims(3);
+ const int input_width = input_shape.Dims(2);
+ const int input_height = input_shape.Dims(1);
+ const int input_batch = input_shape.Dims(0);
- const int output_depth = ArraySize(output_dims, 0);
- const int output_width = ArraySize(output_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_batch = ArraySize(output_dims, 3);
+ const int output_depth = output_shape.Dims(3);
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_batch = output_shape.Dims(0);
+
+ const int32 block_size = op_params.block_size;
TFLITE_DCHECK_EQ(input_width, output_width * block_size);
TFLITE_DCHECK_EQ(input_height, output_height * block_size);
@@ -478,9 +505,9 @@ inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
const int out_h = in_h / block_size;
const int out_b = in_b;
+ const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d);
const int output_index =
- Offset(output_dims, out_d, out_w, out_h, out_b);
- const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b);
+ Offset(output_shape, out_b, out_h, out_w, out_d);
output_data[output_index] = input_data[input_index];
}
@@ -803,51 +830,8 @@ void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
output_activation_max, output_data, output_dims, gemm_context);
}
-template <FusedActivationFunctionType Ac>
-void NonGlobalBatchNormalization(
- const float* input_data, const Dims<4>& input_dims, const float* mean_data,
- const Dims<4>& mean_dims, const float* multiplier_data,
- const Dims<4>& multiplier_dims, const float* offset_data,
- const Dims<4>& offset_dims, float* output_data,
- const Dims<4>& output_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int inner_size = MatchingFlatSizeSkipDim(
- input_dims, 3, mean_dims, multiplier_dims, offset_dims, output_dims);
-
- for (int b = 0; b < batches; ++b) {
- for (int i = 0; i < inner_size; ++i) {
- output_data[b * inner_size + i] = ActivationFunction<Ac>(
- (input_data[b * inner_size + i] - mean_data[i]) * multiplier_data[i] +
- offset_data[i]);
- }
- }
-}
-
-template <FusedActivationFunctionType Ac>
-void GlobalBatchNormalization(const float* input_data,
- const Dims<4>& input_dims, const float* mean_data,
- const Dims<4>& mean_dims,
- const float* multiplier_data,
- const Dims<4>& multiplier_dims,
- const float* offset_data,
- const Dims<4>& offset_dims, float* output_data,
- const Dims<4>& output_dims) {
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth =
- MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0,
- offset_dims, 0, output_dims, 0);
-
- for (int i = 0; i < outer_size; ++i) {
- for (int c = 0; c < depth; ++c) {
- output_data[depth * i + c] = ActivationFunction<Ac>(
- (input_data[depth * i + c] - mean_data[c]) * multiplier_data[c] +
- offset_data[c]);
- }
- }
-}
-
-inline void Relu(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
+inline void Relu(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
const float val = input_data[i];
@@ -857,8 +841,8 @@ inline void Relu(const float* input_data, const RuntimeShape& input_shape,
}
}
-inline void Relu1(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
+inline void Relu1(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)");
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
@@ -870,8 +854,8 @@ inline void Relu1(const float* input_data, const RuntimeShape& input_shape,
}
}
-inline void Relu6(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
+inline void Relu6(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)");
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
@@ -883,11 +867,13 @@ inline void Relu6(const float* input_data, const RuntimeShape& input_shape,
}
}
-inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data,
- const RuntimeShape& input_shape, uint8* output_data,
- const RuntimeShape& output_shape) {
+inline void ReluX(const tflite::ActivationParams& params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("Quantized ReluX (not fused)");
const int flat_size = MatchingFlatSize(input_shape, output_shape);
+ const uint8 max_value = params.quantized_activation_max;
+ const uint8 min_value = params.quantized_activation_min;
for (int i = 0; i < flat_size; ++i) {
const uint8 val = input_data[i];
const uint8 clamped =
@@ -896,10 +882,11 @@ inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data,
}
}
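Callers of the new ReluX signature supply the clamp bounds through tflite::ActivationParams rather than as loose arguments. A hedged sketch of the resulting clamp, with a local struct mirroring only the two fields ReluX reads (the real type lives in the tflite headers):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Local stand-in mirroring the two fields the ReluX body above reads.
struct ActivationParamsSketch {
  uint8_t quantized_activation_min;
  uint8_t quantized_activation_max;
};

// Same per-element clamp as ReluX, applied to a flat buffer.
void ReluXSketch(const ActivationParamsSketch& params, const uint8_t* in,
                 uint8_t* out, int flat_size) {
  for (int i = 0; i < flat_size; ++i) {
    out[i] = std::min(params.quantized_activation_max,
                      std::max(params.quantized_activation_min, in[i]));
  }
}

int main() {
  const ActivationParamsSketch params = {16, 200};
  const uint8_t in[3] = {0, 100, 255};
  uint8_t out[3];
  ReluXSketch(params, in, out, 3);
  assert(out[0] == 16 && out[1] == 100 && out[2] == 200);
}
```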
-template <FusedActivationFunctionType Ac>
-void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
- static_assert(Ac == FusedActivationFunctionType::kNone, "");
+inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
+ const RuntimeShape& input_shape,
+ const float* input_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int outer_size =
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
@@ -966,15 +953,17 @@ inline void GetInvSqrtQuantizedMultiplierExp(int32 input,
*output_shift *= kReverseShift;
}
-inline void L2Normalization(const uint8* input_data,
+inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
const RuntimeShape& input_shape,
- int32 input_zero_point, uint8* output_data,
- const RuntimeShape& output_shape) {
+ const uint8* input_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int depth =
MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
const int outer_size =
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int32 input_zero_point = op_params.input_zero_point;
for (int i = 0; i < outer_size; ++i) {
int32 square_l2_norm = 0;
for (int c = 0; c < depth; c++) {
@@ -1320,11 +1309,16 @@ inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
}
template <typename T>
-inline void Mul(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape& input2_shape, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ T output_activation_min;
+ T output_activation_max;
+ GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] * input2_data[i], output_activation_min,
@@ -1332,52 +1326,57 @@ inline void Mul(const T* input1_data, const Dims<4>& input1_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Mul(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
// generate max(D1, D2) nested for loops.
+// TODO(benoitjacob): BroadcastMul is intentionally duplicated from
+// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
+// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
+// reference_ops.h.
template <typename T>
-void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastMul");
+void BroadcastMul4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& unextended_input1_shape,
+ const T* input1_data,
+ const RuntimeShape& unextended_input2_shape,
+ const T* input2_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastMul4DSlow");
+ T output_activation_min;
+ T output_activation_max;
+ GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest
- // stride, typically 1 element).
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
//
// In generated C code, we store arrays with the dimensions reversed. The
// first dimension has smallest stride.
//
// We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for
- // the best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ output_data[Offset(output_shape, b, y, x, c)] =
ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] *
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] *
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
output_activation_min, output_activation_max);
}
}
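The migrated arithmetic kernels all read their fused-activation bounds through GetActivationParams after the caller sets them with SetActivationParams, so the min/max travel inside ArithmeticParams rather than as extra arguments. A minimal sketch of that round trip with stand-in types (the real ones and their exact field names live in the tflite headers):

```cpp
#include <algorithm>
#include <cassert>

// Stand-ins for the relevant slice of tflite::ArithmeticParams and the
// Set/GetActivationParams helpers used by the new Mul/Div/BroadcastDiv4DSlow.
struct ArithmeticParamsSketch {
  float float_activation_min;
  float float_activation_max;
};

void SetActivationParamsSketch(float min, float max,
                               ArithmeticParamsSketch* p) {
  p->float_activation_min = min;
  p->float_activation_max = max;
}

void GetActivationParamsSketch(const ArithmeticParamsSketch& p, float* min,
                               float* max) {
  *min = p.float_activation_min;
  *max = p.float_activation_max;
}

int main() {
  ArithmeticParamsSketch params;
  SetActivationParamsSketch(0.0f, 6.0f, &params);  // fused ReLU6 bounds
  float lo, hi;
  GetActivationParamsSketch(params, &lo, &hi);
  // The kernel then clamps each product: min(max(x, lo), hi).
  assert(std::min(std::max(7.5f, lo), hi) == 6.0f);
}
```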
@@ -1385,19 +1384,6 @@ void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac, typename T>
-void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- T output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- BroadcastMul(input1_data, input1_dims, input2_data, input2_dims,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
// Element-wise mul that can often be used for inner loop of broadcast Mul as
// well as the non-broadcast Mul.
inline void MulElementwise(int size, const ArithmeticParams& params,
@@ -1526,62 +1512,14 @@ inline void BroadcastMul4DSlow(const ArithmeticParams& params,
}
}
-// Transitional version that will be moved shortly to legacy_reference_ops, as
-// part of RuntimeShape revisions.
-inline void BroadcastMul4DSlow(const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit");
-
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest
- // stride, typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for
- // the best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- const int32 input1_val =
- input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
- const int32 input2_val =
- input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
- const int32 unclamped_result =
- output_offset +
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- input1_val * input2_val, output_multiplier, output_shift);
- const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, unclamped_result));
- output_data[Offset(output_dims, c, x, y, b)] =
- static_cast<uint8>(clamped_output);
- }
- }
- }
- }
-}
-
-inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
- const int16* input2_data, const Dims<4>& input2_dims,
- int16* output_data, const Dims<4>& output_dims) {
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int16* input1_data,
+ const RuntimeShape& input2_shape, const int16* input2_data,
+ const RuntimeShape& output_shape, int16* output_data) {
gemmlowp::ScopedProfilingLabel label("Mul/Int16");
- const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
@@ -1593,15 +1531,18 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
}
}
-inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
- const int16* input2_data, const Dims<4>& input2_dims,
- int32 output_offset, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int16* input1_data,
+ const RuntimeShape& input2_shape, const int16* input2_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("Mul/Int16Uint8");
+ int32 output_offset = params.output_offset;
+ int32 output_activation_min = params.quantized_activation_min;
+ int32 output_activation_max = params.quantized_activation_max;
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
@@ -1624,15 +1565,27 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
// that handles broadcasting as the base case. The code generator would then
// generate max(D1, D2) nested for loops.
template <typename T>
-void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastDiv");
+void BroadcastDiv4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& unextended_input1_shape,
+ const T* input1_data,
+ const RuntimeShape& unextended_input2_shape,
+ const T* input2_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ T output_activation_min;
+ T output_activation_max;
+ GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
@@ -1645,14 +1598,14 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for
// the best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ output_data[Offset(output_shape, b, y, x, c)] =
ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] /
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] /
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
output_activation_min, output_activation_max);
}
}
@@ -1660,12 +1613,32 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
template <typename T>
-inline void Div(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastDiv4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void Div(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape& input2_shape, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ T output_activation_min;
+ T output_activation_max;
+ GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] / input2_data[i], output_activation_min,
@@ -1673,6 +1646,21 @@ inline void Div(const T* input1_data, const Dims<4>& input1_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+template <typename T>
+inline void Div(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ Div(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
inline void SubNonBroadcast(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const float* input1_data,
@@ -1968,32 +1956,43 @@ inline void SubWithActivation(const ArithmeticParams& params,
}
}
-template <FusedActivationFunctionType Ac, typename Scalar>
-void Concatenation(int concat_dim, const Scalar* const* input_data,
- const Dims<4>* const* input_dims, int inputs_count,
- Scalar* output_data, const Dims<4>& output_dims) {
- int concat_size = 0;
+template <typename Scalar>
+inline void Concatenation(const ConcatenationParams& params,
+ const RuntimeShape* const* input_shapes,
+ const Scalar* const* input_data,
+ const RuntimeShape& output_shape,
+ Scalar* output_data) {
+ int axis = params.axis;
+ int inputs_count = params.inputs_count;
+ const int concat_dimensions = output_shape.DimensionsCount();
+ TFLITE_DCHECK_LT(axis, concat_dimensions);
+
+ int64_t concat_size = 0;
for (int i = 0; i < inputs_count; i++) {
- for (int j = 0; j < 4; j++) {
- if (j != concat_dim) {
- MatchingArraySize(*input_dims[i], j, output_dims, j);
+ TFLITE_DCHECK_EQ(input_shapes[i]->DimensionsCount(), concat_dimensions);
+ for (int j = 0; j < concat_dimensions; j++) {
+ if (j != axis) {
+ MatchingDim(*input_shapes[i], j, output_shape, j);
}
}
- concat_size += ArraySize(*input_dims[i], concat_dim);
+ concat_size += input_shapes[i]->Dims(axis);
}
- TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
- // For now we don't have a model with a Concatenation with fused activation.
- TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
- int outer_size = 1;
- for (int i = concat_dim + 1; i < 4; i++) {
- outer_size *= output_dims.sizes[i];
+ TFLITE_DCHECK_EQ(concat_size, output_shape.Dims(axis));
+ int64_t outer_size = 1;
+ for (int i = 0; i < axis; ++i) {
+ outer_size *= output_shape.Dims(i);
+ }
+ // For all input arrays,
+ // FlatSize() = outer_size * Dims(axis) * base_inner_size;
+ int64_t base_inner_size = 1;
+ for (int i = axis + 1; i < concat_dimensions; ++i) {
+ base_inner_size *= output_shape.Dims(i);
}
+
Scalar* output_ptr = output_data;
for (int k = 0; k < outer_size; k++) {
for (int i = 0; i < inputs_count; ++i) {
- const int copy_size =
- input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim];
+ const int copy_size = input_shapes[i]->Dims(axis) * base_inner_size;
memcpy(output_ptr, input_data[i] + k * copy_size,
copy_size * sizeof(Scalar));
output_ptr += copy_size;
@@ -2001,6 +2000,125 @@ void Concatenation(int concat_dim, const Scalar* const* input_data,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+template <FusedActivationFunctionType Ac, typename Scalar>
+inline void Concatenation(int concat_dim, const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ // For now we don't have a model with a Concatenation with fused activation.
+ TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
+
+ std::vector<RuntimeShape> input_shapes(inputs_count);
+ std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
+ for (int i = 0; i < inputs_count; ++i) {
+ ShapeFromDims(*input_dims[i], &input_shapes[i]);
+ input_shapes_indirect[i] = &input_shapes[i];
+ }
+ tflite::ConcatenationParams op_params;
+ op_params.axis = 3 - concat_dim;
+ op_params.inputs_count = inputs_count;
+
+ Concatenation(op_params, input_shapes_indirect.data(), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// TODO(prabhumk): This is the same as the optimized implementation.
+// TODO(prabhumk): The quantized implementation of concatenation isn't fully
+// quantized as it takes scale as a floating point value. This should be fixed
+// when optimizing this routine further.
+
+inline void ConcatenationWithScaling(const ConcatenationParams& params,
+ const RuntimeShape* const* input_shapes,
+ const uint8* const* input_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
+ int axis = params.axis;
+ const int32* input_zeropoint = params.input_zeropoint;
+ const float* input_scale = params.input_scale;
+ int inputs_count = params.inputs_count;
+ const int32 output_zeropoint = params.output_zeropoint;
+ const float output_scale = params.output_scale;
+
+  // The arguments input_zeropoint and input_scale are expected to be arrays
+  // that contain the quantization parameters for all the inputs to the concat
+  // operator.
+ TFLITE_DCHECK_GT(inputs_count, 1);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ int64_t concat_size = 0;
+ for (int i = 0; i < inputs_count; i++) {
+ TFLITE_DCHECK_EQ(input_shapes[i]->DimensionsCount(), 4);
+ for (int j = 0; j < 4; j++) {
+ if (j != axis) {
+ MatchingDim(*input_shapes[i], j, output_shape, j);
+ }
+ }
+ concat_size += input_shapes[i]->Dims(axis);
+ }
+ TFLITE_DCHECK_EQ(concat_size, output_shape.Dims(axis));
+ int64_t outer_size = 1;
+ for (int i = 0; i < axis; ++i) {
+ outer_size *= output_shape.Dims(i);
+ }
+ // For all input arrays,
+ // FlatSize() = outer_size * Dims(axis) * base_inner_size;
+ int64_t base_inner_size = 1;
+ for (int i = axis + 1; i < 4; ++i) {
+ base_inner_size *= output_shape.Dims(i);
+ }
+ const float inverse_output_scale = 1.f / output_scale;
+ uint8* output_ptr = output_data;
+ for (int k = 0; k < outer_size; k++) {
+ for (int i = 0; i < inputs_count; ++i) {
+ const int copy_size = input_shapes[i]->Dims(axis) * base_inner_size;
+ const uint8* input_ptr = input_data[i] + k * copy_size;
+ if (input_zeropoint[i] == output_zeropoint &&
+ input_scale[i] == output_scale) {
+ memcpy(output_ptr, input_ptr, copy_size);
+ } else {
+ const float scale = input_scale[i] * inverse_output_scale;
+ const float bias = -input_zeropoint[i] * scale;
+ for (int j = 0; j < copy_size; ++j) {
+ const int32_t value =
+ static_cast<int32_t>(round(input_ptr[j] * scale + bias)) +
+ output_zeropoint;
+ output_ptr[j] =
+ static_cast<uint8_t>(std::max(std::min(255, value), 0));
+ }
+ }
+ output_ptr += copy_size;
+ }
+ }
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+inline void Concatenation(int concat_dim, const uint8* const* input_data,
+ const Dims<4>* const* input_dims,
+ const int32* input_zeropoint,
+ const float* input_scale, int inputs_count,
+ uint8* output_data, const Dims<4>& output_dims,
+ const int32 output_zeropoint,
+ const float output_scale) {
+ std::vector<RuntimeShape> input_shapes(inputs_count);
+ std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
+ for (int i = 0; i < inputs_count; ++i) {
+ ShapeFromDims(*input_dims[i], &input_shapes[i]);
+ input_shapes_indirect[i] = &input_shapes[i];
+ }
+ tflite::ConcatenationParams op_params;
+ op_params.axis = 3 - concat_dim;
+ op_params.input_zeropoint = input_zeropoint;
+ op_params.input_scale = input_scale;
+ op_params.inputs_count = inputs_count;
+ op_params.output_zeropoint = output_zeropoint;
+ op_params.output_scale = output_scale;
+
+ ConcatenationWithScaling(op_params, input_shapes_indirect.data(), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
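When an input's quantization differs from the output's, ConcatenationWithScaling requantizes each element as round(x * scale + bias) + output_zeropoint, with scale = input_scale / output_scale and bias = -input_zeropoint * scale, then clamps to [0, 255]. A standalone sketch of that per-element step (the example quantization parameters are illustrative):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Per-element requantization used by ConcatenationWithScaling when input and
// output quantization parameters differ.
uint8_t Requantize(uint8_t in, int32_t in_zp, float in_scale, int32_t out_zp,
                   float out_scale) {
  const float scale = in_scale / out_scale;
  const float bias = -in_zp * scale;
  const int32_t value =
      static_cast<int32_t>(std::round(in * scale + bias)) + out_zp;
  return static_cast<uint8_t>(std::max(std::min(255, value), 0));
}

int main() {
  // Input quantized as real = 0.5 * (q - 10); output as real = 0.25 * q.
  // q_in = 14 represents 2.0, which the output encodes as q_out = 8.
  assert(Requantize(14, /*in_zp=*/10, /*in_scale=*/0.5f,
                    /*out_zp=*/0, /*out_scale=*/0.25f) == 8);
}
```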
template <typename Scalar>
void Pack(int dim, const Scalar* const* input_data,
const Dims<4>* const* input_dims, int inputs_count,
@@ -2021,48 +2139,50 @@ void Pack(int dim, const Scalar* const* input_data,
}
}
-// TODO(prabhumk): This is the same as the optimized implementation.
-// TODO(prabhumk): The quantized implementation of concatentation isn't fully
-// quantized as it takes scale as a floating point value. This should be fixed
-// when optimizng this routine further.
-inline void Concatenation(int concat_dim, const uint8* const* input_data,
- const Dims<4>* const* input_dims,
- const int32* input_zeropoint,
- const float* input_scale, int inputs_count,
- uint8* output_data, const Dims<4>& output_dims,
- const int32 output_zeropoint,
- const float output_scale) {
- // The arguments input_zeropoint and input_scale are expected to be an array
- // that have the quantization parameters for all the inputs to the concat
- // operator.
- TFLITE_DCHECK_GT(inputs_count, 1);
- int64_t concat_size = 0;
- for (int i = 0; i < inputs_count; i++) {
- for (int j = 0; j < 4; j++) {
- if (j != concat_dim) {
- MatchingArraySize(*input_dims[i], j, output_dims, j);
- }
+template <typename Scalar>
+void Unpack(int axis, const Scalar* input_data, const Dims<4>& input_dims,
+ int dimensions, int outputs_count, Scalar* const* output_datas,
+ const Dims<4>& output_dims) {
+ int outer_size = 1;
+ for (int i = dimensions - axis; i < 4; i++) {
+ outer_size *= input_dims.sizes[i];
+ }
+
+ const int copy_size = FlatSize(input_dims) / outer_size / outputs_count;
+ for (int k = 0; k < outer_size; k++) {
+ for (int i = 0; i < outputs_count; ++i) {
+ Scalar* output_ptr = output_datas[i] + copy_size * k;
+ int loc = k * outputs_count * copy_size + i * copy_size;
+ memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
}
- concat_size += ArraySize(*input_dims[i], concat_dim);
}
- TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim));
- int64_t outer_size = 1;
- for (int i = concat_dim + 1; i < 4; i++) {
+}
+
+template <typename Scalar>
+void Pack(int dim, const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, const int32* input_zeropoint,
+ const float* input_scale, int inputs_count, Scalar* output_data,
+ const Dims<4>& output_dims, const int32 output_zeropoint,
+ const float output_scale) {
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ int outer_size = 1;
+ for (int i = dim + 1; i < 4; i++) {
outer_size *= output_dims.sizes[i];
}
+ Scalar* output_ptr = output_data;
+ const int copy_size = FlatSize(**input_dims) / outer_size;
const float inverse_output_scale = 1.f / output_scale;
- uint8* output_ptr = output_data;
for (int k = 0; k < outer_size; k++) {
for (int i = 0; i < inputs_count; ++i) {
- const int copy_size =
- input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim];
- const uint8* input_ptr = input_data[i] + k * copy_size;
if (input_zeropoint[i] == output_zeropoint &&
input_scale[i] == output_scale) {
- memcpy(output_ptr, input_ptr, copy_size);
+ memcpy(output_ptr, input_data[i] + k * copy_size,
+ copy_size * sizeof(Scalar));
} else {
+ assert(false);
const float scale = input_scale[i] * inverse_output_scale;
const float bias = -input_zeropoint[i] * scale;
+ auto input_ptr = input_data[i];
for (int j = 0; j < copy_size; ++j) {
const int32_t value =
static_cast<int32_t>(round(input_ptr[j] * scale + bias)) +
@@ -2404,32 +2524,69 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
}
template <typename Scalar>
+void Split(const SplitParams& params, const RuntimeShape& input_shape,
+ const Scalar* input_data, const RuntimeShape* const* output_shapes,
+ Scalar* const* output_data) {
+ const int concat_dimensions = input_shape.DimensionsCount();
+ int axis = params.axis < 0 ? params.axis + concat_dimensions : params.axis;
+ int outputs_count = params.num_split;
+ TFLITE_DCHECK_LT(axis, concat_dimensions);
+
+ int64_t concat_size = 0;
+ for (int i = 0; i < outputs_count; i++) {
+ TFLITE_DCHECK_EQ(output_shapes[i]->DimensionsCount(), concat_dimensions);
+ for (int j = 0; j < concat_dimensions; j++) {
+ if (j != axis) {
+ MatchingDim(*output_shapes[i], j, input_shape, j);
+ }
+ }
+ concat_size += output_shapes[i]->Dims(axis);
+ }
+ TFLITE_DCHECK_EQ(concat_size, input_shape.Dims(axis));
+ int64_t outer_size = 1;
+ for (int i = 0; i < axis; ++i) {
+ outer_size *= input_shape.Dims(i);
+ }
+ // For all output arrays,
+ // FlatSize() = outer_size * Dims(axis) * base_inner_size;
+ int64_t base_inner_size = 1;
+ for (int i = axis + 1; i < concat_dimensions; ++i) {
+ base_inner_size *= input_shape.Dims(i);
+ }
+
+ const Scalar* input_ptr = input_data;
+ for (int k = 0; k < outer_size; k++) {
+ for (int i = 0; i < outputs_count; ++i) {
+ const int copy_size = output_shapes[i]->Dims(axis) * base_inner_size;
+ memcpy(output_data[i] + k * copy_size, input_ptr,
+ copy_size * sizeof(Scalar));
+ input_ptr += copy_size;
+ }
+ }
+}
+
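Split is the mirror image of Concatenation: for each of outer_size steps it hands each output a contiguous run of Dims(axis) * base_inner_size elements. A toy sketch of that copy pattern, splitting a 2x4 row-major array into two 2x2 halves along axis 1 (sizes hard-coded for illustration):

```cpp
#include <cassert>
#include <cstring>

int main() {
  // Input shape {2, 4}, split along axis 1 into two outputs of shape {2, 2}.
  // outer_size = 2 (dims before the axis), base_inner_size = 1 (dims after),
  // so each output receives copy_size = 2 * 1 elements per outer step.
  const float input[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  float out0[4], out1[4];
  float* outputs[2] = {out0, out1};
  const int outer_size = 2, copy_size = 2;
  const float* input_ptr = input;
  for (int k = 0; k < outer_size; ++k) {
    for (int i = 0; i < 2; ++i) {
      std::memcpy(outputs[i] + k * copy_size, input_ptr,
                  copy_size * sizeof(float));
      input_ptr += copy_size;
    }
  }
  assert(out0[0] == 0 && out0[1] == 1 && out0[2] == 4 && out0[3] == 5);
  assert(out1[0] == 2 && out1[1] == 3 && out1[2] == 6 && out1[3] == 7);
}
```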
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+template <typename Scalar>
void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
int axis, int outputs_count, Scalar* const* output_data,
const Dims<4>* const* output_dims) {
- const int batches = ArraySize(*output_dims[0], 3);
- const int height = ArraySize(*output_dims[0], 2);
- const int width = ArraySize(*output_dims[0], 1);
- const int depth = ArraySize(*output_dims[0], 0);
-
- const int slice_size = ArraySize(*output_dims[0], axis);
-
+ std::vector<RuntimeShape> output_shapes(outputs_count);
+ std::vector<const RuntimeShape*> output_shapes_indirect(outputs_count);
for (int i = 0; i < outputs_count; ++i) {
- int offset = i * slice_size * input_dims.strides[axis];
- for (int b = 0; b < batches; ++b) {
- for (int y = 0; y < height; ++y) {
- for (int x = 0; x < width; ++x) {
- for (int c = 0; c < depth; ++c) {
- auto out = Offset(*output_dims[i], c, x, y, b);
- auto in = Offset(input_dims, c, x, y, b);
- output_data[i][out] = input_data[offset + in];
- }
- }
- }
- }
+ ShapeFromDims(*output_dims[i], &output_shapes[i]);
+ output_shapes_indirect[i] = &output_shapes[i];
}
+ tflite::SplitParams op_params;
+ op_params.axis = 3 - axis;
+ op_params.num_split = outputs_count;
+
+ Split(op_params, DimsToShape(input_dims), input_data,
+ output_shapes_indirect.data(), output_data);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
template <FusedActivationFunctionType Ac, typename Scalar>
void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
int outputs_count, Scalar* const* output_data,
@@ -2440,44 +2597,13 @@ void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
/* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2);
/* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1);
}
- // for now we dont have a model with a TensorFlowSplit
- // with fused activation function.
- TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+ // For now we don't have a model with a Split with fused activation.
+ TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
TensorFlowSplit(input_data, input_dims, /*axis=*/0, outputs_count,
output_data, output_dims);
}
-// TODO(benoitjacob) make this a proper reference impl without Eigen!
-template <typename Scalar>
-using MatrixMap = typename std::conditional<
- std::is_const<Scalar>::value,
- Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
- Eigen::Dynamic, Eigen::Dynamic>>,
- Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
-
-template <typename Scalar, int N>
-MatrixMap<Scalar> MapAsMatrixWithFirstDimAsRows(Scalar* data,
- const Dims<N>& dims) {
- const int rows = dims.sizes[0];
- int cols = 1;
- for (int d = 1; d < N; d++) {
- cols *= dims.sizes[d];
- }
- return MatrixMap<Scalar>(data, rows, cols);
-}
-
-template <typename Scalar, int N>
-MatrixMap<Scalar> MapAsMatrixWithLastDimAsCols(Scalar* data,
- const Dims<N>& dims) {
- const int cols = dims.sizes[N - 1];
- int rows = 1;
- for (int d = 0; d < N - 1; d++) {
- rows *= dims.sizes[d];
- }
- return MatrixMap<Scalar>(data, rows, cols);
-}
-
inline int NodeOffset(int b, int h, int w, int height, int width) {
return (b * height + h) * width + w;
}
@@ -2750,24 +2876,27 @@ inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
}
}
-inline void LocalResponseNormalization(const float* input_data,
- const Dims<4>& input_dims, int range,
- float bias, float alpha, float beta,
- float* output_data,
- const Dims<4>& output_dims) {
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+inline void LocalResponseNormalization(
+ const tflite::LocalResponseNormalizationParams& op_params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
for (int c = 0; c < depth; ++c) {
- const int begin_input_c = std::max(0, c - range);
- const int end_input_c = std::min(depth, c + range);
+ const int begin_input_c = std::max(0, c - op_params.range);
+ const int end_input_c = std::min(depth, c + op_params.range);
float accum = 0.f;
for (int input_c = begin_input_c; input_c < end_input_c; ++input_c) {
const float input_val = input_data[i * depth + input_c];
accum += input_val * input_val;
}
- const float multiplier = std::pow(bias + alpha * accum, -beta);
+ const float multiplier =
+ std::pow(op_params.bias + op_params.alpha * accum, -op_params.beta);
output_data[i * depth + c] = input_data[i * depth + c] * multiplier;
}
}
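Each output element is the input scaled by multiplier = (bias + alpha * accum)^(-beta), where accum sums the squares of the neighboring channels in [c - range, c + range). A small numeric check of that formula (the parameter values are illustrative only):

```cpp
#include <cassert>
#include <cmath>

int main() {
  // Suppose op_params has bias = 1, alpha = 1, beta = 0.5, and the squared
  // neighborhood sum accum is 3. Then multiplier = (1 + 1*3)^(-0.5) = 0.5,
  // so an input of 2.0 normalizes to 1.0.
  const float bias = 1.f, alpha = 1.f, beta = 0.5f, accum = 3.f;
  const float multiplier = std::pow(bias + alpha * accum, -beta);
  assert(std::fabs(2.0f * multiplier - 1.0f) < 1e-6f);
}
```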
@@ -3118,8 +3247,8 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void Logistic(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
+inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
@@ -3167,8 +3296,8 @@ inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
- int16* output_data, const RuntimeShape& output_shape) {
+inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+ const RuntimeShape& output_shape, int16* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
@@ -3185,8 +3314,8 @@ inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
}
}
-inline void Tanh(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
+inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
@@ -3269,10 +3398,12 @@ inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
}
}
-inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
- int32 zero_point, double scale, float* output_data,
- const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+inline void Dequantize(const tflite::DequantizationParams& op_params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ int32 zero_point = op_params.zero_point;
+ double scale = op_params.scale;
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
int32 val = input_data[i];
@@ -3281,9 +3412,25 @@ inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
- float rmin, float rmax, int num_bits, float* output_data,
- const Dims<4>& output_dims) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
+ int32 zero_point, double scale, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DequantizationParams op_params;
+ op_params.zero_point = zero_point;
+ op_params.scale = scale;
+
+ Dequantize(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
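The params-based Dequantize keeps the affine mapping real_value = scale * (quantized - zero_point); the loop body is elided in the hunk above, so this is stated here as the standard TFLite affine dequantization rather than quoted from it. A tiny sketch:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Standard affine dequantization: real = scale * (q - zero_point).
float DequantizeOne(uint8_t q, int32_t zero_point, double scale) {
  return static_cast<float>(scale * (static_cast<int32_t>(q) - zero_point));
}

int main() {
  // With zero_point = 128 and scale = 0.02, q = 178 decodes to 1.0.
  assert(std::fabs(DequantizeOne(178, 128, 0.02) - 1.0f) < 1e-6f);
}
```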
+inline void FakeQuant(const tflite::FakeQuantParams& op_params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ float rmin = op_params.minmax.min;
+ float rmax = op_params.minmax.max;
+ int num_bits = op_params.num_bits;
// 0 should always be a representable value. Let's assume that the initial
// min,max range contains 0.
TFLITE_DCHECK_LE(rmin, 0.0f);
@@ -3296,15 +3443,29 @@ inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
float nudged_min, nudged_max, nudged_scale;
NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min,
&nudged_max, &nudged_scale);
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
FakeQuantizeArray(nudged_scale, nudged_min, nudged_max, input_data,
output_data, flat_size);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
+ float rmin, float rmax, int num_bits, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::FakeQuantParams op_params;
+ op_params.num_bits = num_bits;
+ op_params.minmax.min = rmin;
+ op_params.minmax.max = rmax;
+
+ FakeQuant(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
template <typename SrcT, typename DstT>
-inline void Cast(const SrcT* input_data, const Dims<4>& input_dims,
- DstT* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
+ const RuntimeShape& output_shape, DstT* output_data) {
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
int offset = i;
@@ -3312,9 +3473,9 @@ inline void Cast(const SrcT* input_data, const Dims<4>& input_dims,
}
}
-inline void Floor(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+inline void Floor(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
int offset = i;
@@ -3323,45 +3484,90 @@ inline void Floor(const float* input_data, const Dims<4>& input_dims,
}
template <typename T>
-inline void Gather(const T* input_data, const Dims<4>& input_dims,
- int input_rank, const int32* coords_data,
- const Dims<4>& coords_dims, T* output_data,
- const Dims<4>& output_dims) {
- TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]);
- int stride = input_dims.strides[input_rank - 1];
+inline void Gather(const tflite::GatherParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& coords_shape, const int32* coords_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Enable these checks when moving legacy ops to legacy_reference_ops.
+ //
+ // TFLITE_DCHECK_EQ(coords_shape.DimensionsCount(), 1);
+ const int input_rank = op_params.input_rank;
+ const int gather_dimensions = output_shape.DimensionsCount();
+ TFLITE_DCHECK_LE(input_shape.DimensionsCount(), gather_dimensions);
+ const int axis = gather_dimensions - input_rank;
+ TFLITE_DCHECK_LT(axis, gather_dimensions);
+ TFLITE_DCHECK_GE(axis, 0);
+ const int coords_count = coords_shape.FlatSize();
+ TFLITE_DCHECK_EQ(coords_count, output_shape.Dims(axis));
+
+ int64_t stride = 1;
+ for (int i = axis + 1; i < gather_dimensions; ++i) {
+ stride *= input_shape.Dims(i);
+ }
T* out = output_data;
- for (int i = 0; i < coords_dims.sizes[0]; i++) {
+ for (int i = 0; i < coords_count; ++i) {
TFLITE_DCHECK_GE(coords_data[i], 0);
- TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]);
+ TFLITE_DCHECK_LT(coords_data[i], input_shape.Dims(axis));
const T* in = input_data + coords_data[i] * stride;
memcpy(out, in, sizeof(T) * stride);
out += stride;
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4> version.
+// When moving legacy ops to legacy_reference_ops, replace content with a
+// looser implementation.
template <typename T>
-inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims,
+inline void Gather(const T* input_data, const Dims<4>& input_dims,
+ int input_rank, const int32* coords_data,
+ const Dims<4>& coords_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::GatherParams op_params;
+ op_params.input_rank = input_rank;
+
+ Gather(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(coords_dims), coords_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_size_shape,
const int32* output_size_data,
- const Dims<4>& output_size_dims, T* output_data,
- const Dims<4>& output_dims, bool align_corners) {
- int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- int32 input_height = ArraySize(input_dims, 2);
- int32 input_width = ArraySize(input_dims, 1);
- int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0);
-
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2);
- int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)];
- int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)];
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_size_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_size_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_size_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
+ int32 input_height = input_shape.Dims(1);
+ int32 input_width = input_shape.Dims(2);
+ int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
+
+ TFLITE_DCHECK_EQ(output_size_shape.Dims(0), 1);
+ TFLITE_DCHECK_EQ(output_size_shape.Dims(1), 1);
+ TFLITE_DCHECK_EQ(output_size_shape.Dims(2), 1);
+ TFLITE_DCHECK_EQ(output_size_shape.Dims(3), 2);
+ int32 output_height = output_size_data[Offset(output_size_shape, 0, 0, 0, 0)];
+ int32 output_width = output_size_data[Offset(output_size_shape, 0, 0, 0, 1)];
+
float height_scale = static_cast<float>(input_height) / output_height;
float width_scale = static_cast<float>(input_width) / output_width;
- if (align_corners && output_height > 1) {
+ if (op_params.align_corners && output_height > 1) {
height_scale = static_cast<float>(input_height - 1) / (output_height - 1);
}
- if (align_corners && output_width > 1) {
+ if (op_params.align_corners && output_width > 1) {
width_scale = static_cast<float>(input_width - 1) / (output_width - 1);
}
@@ -3376,80 +3582,72 @@ inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims,
int32 x1 = std::min(x0 + 1, input_width - 1);
for (int c = 0; c < depth; ++c) {
T interpolation =
- static_cast<T>(input_data[Offset(input_dims, c, x0, y0, b)] *
+ static_cast<T>(input_data[Offset(input_shape, b, y0, x0, c)] *
(1 - (input_y - y0)) * (1 - (input_x - x0)) +
- input_data[Offset(input_dims, c, x0, y1, b)] *
+ input_data[Offset(input_shape, b, y1, x0, c)] *
(input_y - y0) * (1 - (input_x - x0)) +
- input_data[Offset(input_dims, c, x1, y0, b)] *
+ input_data[Offset(input_shape, b, y0, x1, c)] *
(1 - (input_y - y0)) * (input_x - x0) +
- input_data[Offset(input_dims, c, x1, y1, b)] *
+ input_data[Offset(input_shape, b, y1, x1, c)] *
(input_y - y0) * (input_x - x0));
- output_data[Offset(output_dims, c, x, y, b)] = interpolation;
+ output_data[Offset(output_shape, b, y, x, c)] = interpolation;
}
}
}
}
}
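// A minimal usage sketch (illustrative only; the helper name is
// hypothetical): upscale a 1x2x2x1 image to 1x4x4x1. With
// align_corners == false the scales are input/output = 2/4 = 0.5, so
// output pixel (y, x) samples the input at (0.5 * y, 0.5 * x).
inline void ResizeBilinearUsageSketch() {
  const float input[4] = {1, 2, 3, 4};
  const int32 output_size[2] = {4, 4};  // {output_height, output_width}.
  float output[16];
  tflite::ResizeBilinearParams params;
  params.align_corners = false;
  ResizeBilinear(params, RuntimeShape({1, 2, 2, 1}), input,
                 RuntimeShape({1, 1, 1, 2}), output_size,
                 RuntimeShape({1, 4, 4, 1}), output);
}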
-// legacy, for compatibility with old checked-in code
-inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
- const int32* output_size_data,
- const Dims<4>& output_size_dims, float* output_data,
- const Dims<4>& output_dims) {
- ResizeBilinear<float>(input_data, input_dims, output_size_data,
- output_size_dims, output_data, output_dims,
- /*align_corners=*/false);
-}
+template <typename T>
+inline void SpaceToBatchND(
+ const SpaceToBatchParams& params,
+ const RuntimeShape& unextended_input1_shape, const T* input1_data,
+ const RuntimeShape& unextended_input2_shape, const int32* block_shape_data,
+ const RuntimeShape& unextended_input3_shape, const int32* paddings_data,
+ const RuntimeShape& unextended_output_shape, T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input1_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input1_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int depth = input1_shape.Dims(3);
+ const int input_width = input1_shape.Dims(2);
+ const int input_height = input1_shape.Dims(1);
+ const int input_batch_size = input1_shape.Dims(0);
-inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
- const int32* output_size_data,
- const Dims<4>& output_size_dims, uint8* output_data,
- const Dims<4>& output_dims) {
- ResizeBilinear<uint8>(input_data, input_dims, output_size_data,
- output_size_dims, output_data, output_dims,
- /*align_corners=*/false);
-}
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_batch_size = output_shape.Dims(0);
-template <typename T>
-inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
- const int32* block_shape_data,
- const Dims<4>& block_shape_dims,
- const int32* paddings_data,
- const Dims<4>& paddings_dims, T* output_data,
- const Dims<4>& output_dims,
- const int32_t pad_value) {
- const int output_batch_size = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int input_batch_size = ArraySize(input_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int depth = ArraySize(input_dims, 0);
const int block_shape_height = block_shape_data[0];
const int block_shape_width = block_shape_data[1];
const int padding_top = paddings_data[0];
const int padding_left = paddings_data[2];
+ // For uint8 quantized, the correct padding "zero value" is the output offset.
+ const int32_t pad_value = params.output_offset;
+
for (int out_b = 0; out_b < output_batch_size; ++out_b) {
int input_batch = out_b % input_batch_size;
int shift_w = (out_b / input_batch_size) % block_shape_width;
int shift_h = (out_b / input_batch_size) / block_shape_width;
for (int out_h = 0; out_h < output_height; ++out_h) {
for (int out_w = 0; out_w < output_width; ++out_w) {
- T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b);
+ T* out = output_data + Offset(output_shape, out_b, out_h, out_w, 0);
if (out_h * block_shape_height + shift_h < padding_top ||
out_h * block_shape_height + shift_h >=
padding_top + input_height ||
out_w * block_shape_width + shift_w < padding_left ||
out_w * block_shape_width + shift_w >= padding_left + input_width) {
+ // This may not execute correctly when pad_value != 0 and T != uint8.
memset(out, pad_value, depth * sizeof(T));
} else {
const T* in =
- input_data +
- Offset(input_dims, 0,
- (out_w * block_shape_width + shift_w) - padding_left,
+ input1_data +
+ Offset(input1_shape, input_batch,
(out_h * block_shape_height + shift_h) - padding_top,
- input_batch);
+ (out_w * block_shape_width + shift_w) - padding_left, 0);
memcpy(out, in, depth * sizeof(T));
}
}
@@ -3458,29 +3656,27 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
}
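// A minimal usage sketch (illustrative only; the helper name is
// hypothetical): block_shape {2, 2} with no padding maps a 1x2x2x1 input to
// a 4x1x1x1 output; out_b encodes (shift_h, shift_w) as in the loop above.
inline void SpaceToBatchNDUsageSketch() {
  const float input[4] = {1, 2, 3, 4};
  const int32 block_shape[2] = {2, 2};
  const int32 paddings[4] = {0, 0, 0, 0};  // {top, bottom, left, right}.
  float output[4];
  SpaceToBatchParams params;
  params.output_offset = 0;  // Only used as the pad value for quantized data.
  SpaceToBatchND(params, RuntimeShape({1, 2, 2, 1}), input, RuntimeShape({2}),
                 block_shape, RuntimeShape({2, 2}), paddings,
                 RuntimeShape({4, 1, 1, 1}), output);
  // output is now {1, 2, 3, 4}: one spatial position per output batch.
}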
template <typename T>
-inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
- const int32* block_shape_data,
- const Dims<4>& block_shape_dims,
- const int32* paddings_data,
- const Dims<4>& paddings_dims, T* output_data,
- const Dims<4>& output_dims) {
- SpaceToBatchND(input_data, input_dims, block_shape_data, block_shape_dims,
- paddings_data, paddings_dims, output_data, output_dims, 0);
-}
+inline void BatchToSpaceND(
+ const RuntimeShape& unextended_input1_shape, const T* input1_data,
+ const RuntimeShape& unextended_input2_shape, const int32* block_shape_data,
+ const RuntimeShape& unextended_input3_shape, const int32* crops_data,
+ const RuntimeShape& unextended_output_shape, T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input1_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input1_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_batch_size = output_shape.Dims(0);
+
+ const int depth = input1_shape.Dims(3);
+ const int input_width = input1_shape.Dims(2);
+ const int input_height = input1_shape.Dims(1);
+ const int input_batch_size = input1_shape.Dims(0);
-template <typename T>
-inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
- const int32* block_shape_data,
- const Dims<4>& block_shape_dims,
- const int32* crops_data, const Dims<4>& crops_dims,
- T* output_data, const Dims<4>& output_dims) {
- const int output_batch_size = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int input_batch_size = ArraySize(input_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int depth = ArraySize(input_dims, 0);
const int block_shape_width = block_shape_data[1];
const int block_shape_height = block_shape_data[0];
const int crops_top = crops_data[0];
@@ -3502,8 +3698,9 @@ inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
if (out_w < 0 || out_w >= output_width) {
continue;
}
- T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch);
- const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch);
+ T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0);
+ const T* in =
+ input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0);
memcpy(out, in, depth * sizeof(T));
}
}
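// A minimal usage sketch (illustrative only; the helper name is
// hypothetical): the inverse of the SpaceToBatchND example above, mapping a
// 4x1x1x1 input back to 1x2x2x1 with no cropping.
inline void BatchToSpaceNDUsageSketch() {
  const float input[4] = {1, 2, 3, 4};
  const int32 block_shape[2] = {2, 2};
  const int32 crops[4] = {0, 0, 0, 0};
  float output[4];
  BatchToSpaceND(RuntimeShape({4, 1, 1, 1}), input, RuntimeShape({2}),
                 block_shape, RuntimeShape({2, 2}), crops,
                 RuntimeShape({1, 2, 2, 1}), output);
  // output is now {1, 2, 3, 4} again.
}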
@@ -3609,103 +3806,111 @@ inline void Pad(const tflite::PadParams& op_params,
output_data);
}
-// Legacy signature, function covered both Pad and PadV2.
-template <typename T>
-inline void PadV2(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims, const T pad_value) {
- TFLITE_DCHECK_EQ(left_paddings.size(), 4);
- TFLITE_DCHECK_EQ(right_paddings.size(), 4);
- tflite::PadParams op_params;
- op_params.left_padding_count = 4;
- op_params.right_padding_count = 4;
- for (int i = 0; i < 4; ++i) {
- op_params.left_padding[i] = left_paddings[3 - i];
- op_params.right_padding[i] = right_paddings[3 - i];
- }
- // SetFloatOrInt(pad_value, &op_params.pad_value);
- const T pad_value_copy = pad_value;
-
- Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
- DimsToShape(output_dims), output_data);
-}
-
-// Old Pad that calls legacy PadV2.
-template <typename T>
-inline void Pad(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims, const int32_t pad_value) {
- const T converted_pad_value = static_cast<T>(pad_value);
- PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
- output_dims, converted_pad_value);
-}
-
-// Old Pad that only padded with 0.
-template <typename T>
-inline void Pad(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims) {
- const T pad_value = static_cast<T>(0);
- PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
- output_dims, pad_value);
-}
-
template <typename T>
-inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
- int begin_mask, int end_mask, int shrink_axis_mask,
- const std::vector<int>& start_indices,
- const std::vector<int>& stop_indices,
- const std::vector<int>& strides, T* output_data,
- const Dims<4>& output_dims) {
- // Note that the axis orders are reversed for runtime ops, so the indices,
- // strides and masks must be as well too.
- TFLITE_DCHECK_EQ(start_indices.size(), 4);
- TFLITE_DCHECK_EQ(stop_indices.size(), 4);
- TFLITE_DCHECK_EQ(strides.size(), 4);
- const int start_b = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 3);
+inline void StridedSlice(const tflite::StridedSliceParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ // Note that the output_shape is not used within this function.
+ tflite::StridedSliceParams params_copy = op_params;
+
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ // Reverse and pad to 4 dimensions because that is what the runtime code
+ // requires (i.e., all shapes must be 4D and are given backwards).
+ strided_slice::StridedSlicePadIndices(&params_copy, 4);
+
+ const int start_b = strided_slice::StartForAxis(params_copy, input_shape, 0);
const int stop_b =
- strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
- strides, input_dims.sizes, 3, start_b);
- const int start_h = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 2);
+ strided_slice::StopForAxis(params_copy, input_shape, 0, start_b);
+ const int start_h = strided_slice::StartForAxis(params_copy, input_shape, 1);
const int stop_h =
- strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
- strides, input_dims.sizes, 2, start_h);
- const int start_w = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 1);
+ strided_slice::StopForAxis(params_copy, input_shape, 1, start_h);
+ const int start_w = strided_slice::StartForAxis(params_copy, input_shape, 2);
const int stop_w =
- strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
- strides, input_dims.sizes, 1, start_w);
- const int start_d = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 0);
+ strided_slice::StopForAxis(params_copy, input_shape, 2, start_w);
+ const int start_d = strided_slice::StartForAxis(params_copy, input_shape, 3);
const int stop_d =
- strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
- strides, input_dims.sizes, 0, start_d);
+ strided_slice::StopForAxis(params_copy, input_shape, 3, start_d);
T* out_ptr = output_data;
for (int in_b = start_b;
- !strided_slice::LoopCondition(in_b, stop_b, strides[3]);
- in_b += strides[3]) {
+ !strided_slice::LoopCondition(in_b, stop_b, params_copy.strides[0]);
+ in_b += params_copy.strides[0]) {
for (int in_h = start_h;
- !strided_slice::LoopCondition(in_h, stop_h, strides[2]);
- in_h += strides[2]) {
+ !strided_slice::LoopCondition(in_h, stop_h, params_copy.strides[1]);
+ in_h += params_copy.strides[1]) {
for (int in_w = start_w;
- !strided_slice::LoopCondition(in_w, stop_w, strides[1]);
- in_w += strides[1]) {
- for (int in_d = start_d;
- !strided_slice::LoopCondition(in_d, stop_d, strides[0]);
- in_d += strides[0]) {
- *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
+ !strided_slice::LoopCondition(in_w, stop_w, params_copy.strides[2]);
+ in_w += params_copy.strides[2]) {
+ for (int in_d = start_d; !strided_slice::LoopCondition(
+ in_d, stop_d, params_copy.strides[3]);
+ in_d += params_copy.strides[3]) {
+ *out_ptr++ = input_data[Offset(input_shape, in_b, in_h, in_w, in_d)];
}
}
}
}
}
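// A minimal usage sketch (illustrative only; the helper name is
// hypothetical): slice elements [1, 3) of a 1-D tensor. Lower-rank params
// are padded to 4-D by StridedSlicePadIndices() before slicing.
inline void StridedSliceUsageSketch() {
  const float input[4] = {10, 20, 30, 40};
  float output[2];
  tflite::StridedSliceParams params;
  params.start_indices_count = 1;
  params.stop_indices_count = 1;
  params.strides_count = 1;
  params.start_indices[0] = 1;
  params.stop_indices[0] = 3;
  params.strides[0] = 1;
  params.begin_mask = 0;
  params.end_mask = 0;
  params.ellipsis_mask = 0;
  params.new_axis_mask = 0;
  params.shrink_axis_mask = 0;
  StridedSlice(params, RuntimeShape({4}), input, RuntimeShape({2}), output);
  // output is now {20, 30}.
}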
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline uint32 LegacyReverseBits32(uint32 n) {
+ n = ((n >> 1) & 0x55555555) | ((n & 0x55555555) << 1);
+ n = ((n >> 2) & 0x33333333) | ((n & 0x33333333) << 2);
+ n = ((n >> 4) & 0x0F0F0F0F) | ((n & 0x0F0F0F0F) << 4);
+ return (((n & 0xFF) << 24) | ((n & 0xFF00) << 8) | ((n & 0xFF0000) >> 8) |
+ ((n & 0xFF000000) >> 24));
+}
+
+inline void StridedSliceReverseIndices(tflite::StridedSliceParams* p) {
+ TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
+ TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
+
+ std::reverse(p->start_indices, p->start_indices + p->start_indices_count);
+ std::reverse(p->stop_indices, p->stop_indices + p->stop_indices_count);
+ std::reverse(p->strides, p->strides + p->strides_count);
+
+ p->begin_mask = LegacyReverseBits32(static_cast<uint32>(p->begin_mask)) >>
+ (32 - p->start_indices_count);
+ p->ellipsis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->ellipsis_mask)) >>
+ (32 - p->start_indices_count);
+ p->end_mask = LegacyReverseBits32(static_cast<uint32>(p->end_mask)) >>
+ (32 - p->start_indices_count);
+ p->new_axis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->new_axis_mask)) >>
+ (32 - p->start_indices_count);
+ p->shrink_axis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->shrink_axis_mask)) >>
+ (32 - p->start_indices_count);
+}
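// For example: with start_indices_count == 4, a begin_mask of 0b0001
// (axis 0 in the old reversed ordering) becomes
// LegacyReverseBits32(0x1) >> 28 == 0b1000, i.e. axis 3 in the new
// batch-major ordering.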
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T>
+inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
+ int begin_mask, int end_mask, int shrink_axis_mask,
+ const std::vector<int>& start_indices,
+ const std::vector<int>& stop_indices,
+ const std::vector<int>& strides, T* output_data,
+ const Dims<4>& output_dims) {
+ TFLITE_DCHECK_EQ(start_indices.size(), 4);
+ auto op_params = strided_slice::BuildStridedSliceParams(
+ begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
+ strides);
+ StridedSliceReverseIndices(&op_params);
+
+ StridedSlice(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
@@ -3747,22 +3952,6 @@ inline void Slice(const tflite::SliceParams& op_params,
}
template <typename T>
-inline void Slice(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& begin, const std::vector<int>& size,
- T* output_data, const Dims<4>& output_dims) {
- tflite::SliceParams op_params;
- op_params.begin_count = 4;
- op_params.size_count = 4;
- for (int i = 0; i < 4; ++i) {
- op_params.begin[i] = begin[3 - i];
- op_params.size[i] = size[3 - i];
- }
-
- Slice(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(output_dims), output_data);
-}
-
-template <typename T>
inline void Exp(const T* input_data, const size_t num_elements,
T* output_data) {
for (size_t idx = 0; idx < num_elements; ++idx) {
@@ -3858,39 +4047,16 @@ inline bool InitTensorDataForReduce(const int* dims, const int num_dims,
return true;
}
-// Computes the sum of elements across dimensions given in axis.
+// Computes the generic value (i.e., sum/max/min/prod) of elements across
+// dimensions given in axis. The caller supplies init_value and the reducer.
template <typename T>
-inline bool Sum(const T* input_data, const int* input_dims,
- const int input_num_dims, T* output_data,
- const int* output_dims, const int output_num_dims,
- const int* axis, const int num_axis_dimensions, bool keep_dims,
- int* temp_index, int* resolved_axis) {
- // Reset output data.
- if (!InitTensorDataForReduce(output_dims, output_num_dims, static_cast<T>(0),
- output_data)) {
- return false;
- }
-
- // Resolve axis.
- int num_resolved_axis = 0;
- if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis,
- &num_resolved_axis)) {
- return false;
- }
-
- return ReduceSumImpl<T, T>(input_data, input_dims, output_dims,
- input_num_dims, output_num_dims, resolved_axis,
- num_resolved_axis, temp_index, output_data);
-}
-
-// Computes the max of elements across dimensions given in axis.
-template <typename T>
-inline bool ReduceMax(const T* input_data, const int* input_dims,
- const int input_num_dims, T* output_data,
- const int* output_dims, const int output_num_dims,
- const int* axis, const int64_t num_axis_dimensions,
- bool keep_dims, int* temp_index, int* resolved_axis) {
- T init_value = std::numeric_limits<T>::lowest();
+inline bool ReduceGeneric(const T* input_data, const int* input_dims,
+ const int input_num_dims, T* output_data,
+ const int* output_dims, const int output_num_dims,
+ const int* axis, const int64_t num_axis_dimensions,
+ bool keep_dims, int* temp_index, int* resolved_axis,
+ T init_value,
+ T reducer(const T current, const T in)) {
// Reset output data.
if (!InitTensorDataForReduce(output_dims, output_num_dims, init_value,
output_data)) {
@@ -3904,35 +4070,6 @@ inline bool ReduceMax(const T* input_data, const int* input_dims,
return false;
}
- auto reducer = [](const T current, const T in) -> T {
- return (in > current) ? in : current;
- };
- return Reduce<T, T>(input_data, input_dims, output_dims, input_num_dims,
- output_num_dims, resolved_axis, num_resolved_axis,
- temp_index, reducer, output_data);
-}
-
-// Computes the prod of elements across dimensions given in axis.
-template <typename T>
-inline bool ReduceProd(const T* input_data, const int* input_dims,
- const int input_num_dims, T* output_data,
- const int* output_dims, const int output_num_dims,
- const int* axis, const int64_t num_axis_dimensions,
- bool keep_dims, int* temp_index, int* resolved_axis) {
- // Reset output data.
- if (!InitTensorDataForReduce(output_dims, output_num_dims, static_cast<T>(1),
- output_data)) {
- return false;
- }
-
- // Resolve axis.
- int num_resolved_axis = 0;
- if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis,
- &num_resolved_axis)) {
- return false;
- }
-
- auto reducer = [](const T current, const T in) -> T { return in * current; };
return Reduce<T, T>(input_data, input_dims, output_dims, input_num_dims,
output_num_dims, resolved_axis, num_resolved_axis,
temp_index, reducer, output_data);
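// A minimal usage sketch (illustrative only; the helper name is
// hypothetical): a max-reduction over axis 1 of a 2x3 input, expressed via
// ReduceGeneric with the reducer that ReduceMax previously hard-coded.
inline void ReduceGenericUsageSketch() {
  const float input[6] = {1, 5, 3, 4, 2, 6};
  const int input_dims[2] = {2, 3};
  float output[2];
  const int output_dims[2] = {2, 1};
  const int axis[1] = {1};
  int temp_index[2];
  int resolved_axis[1];
  ReduceGeneric<float>(
      input, input_dims, 2, output, output_dims, 2, axis, 1,
      /*keep_dims=*/true, temp_index, resolved_axis,
      std::numeric_limits<float>::lowest(),
      [](const float current, const float in) -> float {
        return in > current ? in : current;
      });
  // output is now {5, 6}.
}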
@@ -3996,22 +4133,32 @@ inline bool Mean(const T* input_data, const int* input_dims,
}
template <typename T>
-inline void Mean(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& reduction_indices, T* output_data,
- const Dims<4>& output_dims) {
- const int output_batch = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int output_depth = ArraySize(output_dims, 0);
+inline void Mean(const tflite::MeanParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape, T* output_data) {
+ gemmlowp::ScopedProfilingLabel label("Mean");
+
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int output_batch = output_shape.Dims(0);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int output_depth = output_shape.Dims(3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
// The current implementation only supports simultaneous reduction over
// width and height.
- TFLITE_DCHECK_EQ(reduction_indices.size(), 2);
- TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) ||
- (reduction_indices[0] == 2 && reduction_indices[1] == 1));
+ TFLITE_DCHECK_EQ(op_params.axis_count, 2);
+ TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
+ (op_params.axis[0] == 2 && op_params.axis[1] == 1));
TFLITE_DCHECK_EQ(output_height, 1);
TFLITE_DCHECK_EQ(output_width, 1);
@@ -4020,15 +4167,95 @@ inline void Mean(const T* input_data, const Dims<4>& input_dims,
float value = 0;
for (int in_h = 0; in_h < input_height; ++in_h) {
for (int in_w = 0; in_w < input_width; ++in_w) {
- value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)];
+ value += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)];
}
}
- output_data[Offset(output_dims, out_d, 0, 0, out_b)] =
+ output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
value / (input_width * input_height);
}
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+template <typename T>
+inline void Mean(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& reduction_indices, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::MeanParams op_params;
+ op_params.axis_count = reduction_indices.size();
+ for (int i = 0; i < op_params.axis_count; ++i) {
+ op_params.axis[i] = reduction_indices[op_params.axis_count - 1 - i];
+ }
+
+ Mean(op_params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// Computes the mean of elements across dimensions given in axis.
+// For quantized values it works in two stages: first it computes the sum of
+// the elements along the axis, then divides the sum by the number of elements.
+template <typename T, typename U>
+inline bool Mean(const T* input_data, int32 input_zero_point, float input_scale,
+ const int* input_dims, const int input_num_dims,
+ T* output_data, int32 output_zero_point, float output_scale,
+ const int* output_dims, const int output_num_dims,
+ const int* axis, const int num_axis_dimensions, bool keep_dims,
+ int* temp_index, int* resolved_axis, U* temp_sum) {
+ // Reset output data.
+ size_t num_outputs = 1;
+ for (int idx = 0; idx < output_num_dims; ++idx) {
+ size_t current = static_cast<size_t>(output_dims[idx]);
+ // Overflow prevention.
+ if (num_outputs > std::numeric_limits<size_t>::max() / current) {
+ return false;
+ }
+ num_outputs *= current;
+ }
+ for (size_t idx = 0; idx < num_outputs; ++idx) {
+ output_data[idx] = T();
+ temp_sum[idx] = U();
+ }
+
+ // Resolve axis.
+ int num_resolved_axis = 0;
+ if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis,
+ &num_resolved_axis)) {
+ return false;
+ }
+
+ if (!ReduceSumImpl<T, U>(input_data, input_dims, output_dims, input_num_dims,
+ output_num_dims, resolved_axis, num_resolved_axis,
+ temp_index, temp_sum)) {
+ return false;
+ }
+
+ // Calculate the mean by dividing the sum by the number of aggregated
+ // elements.
+ U num_elements_in_axis = 1;
+ for (int idx = 0; idx < num_resolved_axis; ++idx) {
+ size_t current = static_cast<size_t>(input_dims[resolved_axis[idx]]);
+ // Overflow prevention.
+ if (current > (std::numeric_limits<U>::max() / num_elements_in_axis)) {
+ return false;
+ }
+ num_elements_in_axis *= current;
+ }
+
+ if (num_elements_in_axis > 0) {
+ const float scale = input_scale / output_scale;
+ const float bias = -input_zero_point * scale;
+ for (size_t idx = 0; idx < num_outputs; ++idx) {
+ float float_mean = static_cast<float>(temp_sum[idx]) /
+ static_cast<float>(num_elements_in_axis);
+
+ // Requantize the float mean into the output type.
+ output_data[idx] =
+ static_cast<T>(round(float_mean * scale + bias)) + output_zero_point;
+ }
+ }
+ return true;
+}
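// A worked example of the requantization above: with input_scale = 0.5,
// output_scale = 0.25 and input_zero_point = 128, scale = 2.0 and
// bias = -256.0, so a raw mean of 130.0 becomes
// round(130.0 * 2.0 - 256.0) + output_zero_point = 4 + output_zero_point.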
+
template <typename T>
void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
const T* input2_data, const RuntimeShape& output_shape,
@@ -4053,38 +4280,31 @@ void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
}
}
-template <typename T>
-void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, T* output_data,
- const Dims<4>& output_dims) {
- Minimum(DimsToShape(input1_dims), input1_data, input2_data,
- DimsToShape(output_dims), output_data);
-}
-
-template <typename T>
-void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, T* output_data,
- const Dims<4>& output_dims) {
- Maximum(DimsToShape(input1_dims), input1_data, input2_data,
- DimsToShape(output_dims), output_data);
-}
-
template <typename T, typename Op>
-void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims,
- Op op) {
+void MaximumMinimumBroadcast4DSlow(const RuntimeShape& unextended_input1_shape,
+ const T* input1_data,
+ const RuntimeShape& unextended_input2_shape,
+ const T* input2_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data, Op op) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- auto out_idx = Offset(output_dims, c, x, y, b);
- auto in1_idx = SubscriptToIndex(desc1, c, x, y, b);
- auto in2_idx = SubscriptToIndex(desc2, c, x, y, b);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
+
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ auto out_idx = Offset(output_shape, b, y, x, c);
+ auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
+ auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
auto in1_val = input1_data[in1_idx];
auto in2_val = input2_data[in2_idx];
output_data[out_idx] = op(in1_val, in2_val);
@@ -4095,8 +4315,9 @@ void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
}
template <typename T1, typename T2, typename T3, typename Cmp>
-void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
- T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) {
+void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
+ const T3* input2_data, const RuntimeShape& output_shape,
+ T2* output_data, const Cmp& cmp) {
// The current ArgMax implementation can only determine the index of the maximum
// value in the last dimension. So the axis argument is ignored.
@@ -4104,15 +4325,19 @@ void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
// 1). For the sake of simplicity, the output dimensions are equal to the
// input dimensions here. We enforce the constraint that the last dimension
// must always be 1.
- TFLITE_DCHECK_EQ(ArraySize(output_dims, 0), 1);
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = ArraySize(input_dims, 0);
+ const int trailing_dim = output_shape.DimensionsCount() - 1;
+ TFLITE_DCHECK_EQ(input1_shape.DimensionsCount(),
+ output_shape.DimensionsCount());
+ TFLITE_DCHECK_EQ(output_shape.Dims(trailing_dim), 1);
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input1_shape, trailing_dim, output_shape);
+ const int depth = input1_shape.Dims(trailing_dim);
for (int i = 0; i < outer_size; ++i) {
- auto min_max_value = input_data[i * depth];
+ auto min_max_value = input1_data[i * depth];
int min_max_index = 0;
for (int d = 1; d < depth; ++d) {
- const auto& curr_value = input_data[i * depth + d];
+ const auto& curr_value = input1_data[i * depth + d];
if (cmp(curr_value, min_max_value)) {
min_max_value = curr_value;
min_max_index = d;
@@ -4122,12 +4347,11 @@ void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
}
}
-// TODO(renjieliu): Remove this one.
template <typename T1, typename T2, typename T3>
-void ArgMax(const T3* axis, const T1* input_data,
- const tflite::Dims<4>& input_dims, T2* output_data,
- const tflite::Dims<4>& output_dims) {
- ArgMinMax(axis, input_data, input_dims, output_data, output_dims,
+void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
+ const T3* input2_data, const RuntimeShape& output_shape,
+ T2* output_data) {
+ ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data,
std::greater<T1>());
}
@@ -4254,26 +4478,56 @@ template <typename T>
using ComparisonFn = bool (*)(T, T);
template <typename T, ComparisonFn<T> F>
-inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- bool* output_data, const Dims<4>& output_dims) {
+inline void ComparisonImpl(
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape,
+ const T* input1_data, const RuntimeShape& input2_shape,
+ const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
const int64_t flatsize =
- MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int64_t i = 0; i < flatsize; ++i) {
output_data[i] = F(input1_data[i], input2_data[i]);
}
}
-template <typename T, ComparisonFn<int32> F>
-inline void Comparison(int left_shift, const T* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
+template <ComparisonFn<float> F>
+inline void Comparison(const ComparisonParams& op_params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape, bool* output_data) {
+ ComparisonImpl<float, F>(op_params, input1_shape, input1_data, input2_shape,
+ input2_data, output_shape, output_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T, ComparisonFn<T> F>
+inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
const T* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, bool* output_data,
- const Dims<4>& output_dims) {
+ bool* output_data, const Dims<4>& output_dims) {
+ ComparisonParams op_params;
+ // No parameters needed.
+ ComparisonImpl<T, F>(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T, ComparisonFn<int32> F>
+inline void ComparisonWithScaling(
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape,
+ const T* input1_data, const RuntimeShape& input2_shape,
+ const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
+ int left_shift = op_params.left_shift;
+ int32 input1_offset = op_params.input1_offset;
+ int32 input1_multiplier = op_params.input1_multiplier;
+ int input1_shift = op_params.input1_shift;
+ int32 input2_offset = op_params.input2_offset;
+ int32 input2_multiplier = op_params.input2_multiplier;
+ int input2_shift = op_params.input2_shift;
+
const int64_t flatsize =
- MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int64_t i = 0; i < flatsize; ++i) {
const int32 input1_val = input1_offset + input1_data[i];
const int32 input2_val = input2_offset + input2_data[i];
@@ -4281,68 +4535,140 @@ inline void Comparison(int left_shift, const T* input1_data,
const int32 shifted_input2_val = input2_val * (1 << left_shift);
const int32 scaled_input1_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input1_val, input1_multiplier,
- kReverseShift * input1_shift);
+ shifted_input1_val, input1_multiplier, input1_shift);
const int32 scaled_input2_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input2_val, input2_multiplier,
- kReverseShift * input2_shift);
+ shifted_input2_val, input2_multiplier, input2_shift);
output_data[i] = F(scaled_input1_val, scaled_input2_val);
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T, ComparisonFn<int32> F>
+inline void Comparison(int left_shift, const T* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const T* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, bool* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ComparisonParams op_params;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+
+ ComparisonWithScaling<T, F>(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
template <typename T, ComparisonFn<T> F>
-inline void BroadcastComparison(const T* input1_data,
- const Dims<4>& input1_dims,
- const T* input2_data,
- const Dims<4>& input2_dims, bool* output_data,
- const Dims<4>& output_dims) {
+inline void BroadcastComparison4DSlowImpl(
+ const ComparisonParams& op_params,
+ const RuntimeShape& unextended_input1_shape, const T* input1_data,
+ const RuntimeShape& unextended_input2_shape, const T* input2_data,
+ const RuntimeShape& unextended_output_shape, bool* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlow");
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- F(input1_data[SubscriptToIndex(desc1, c, x, y, b)],
- input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
+
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ output_data[Offset(output_shape, b, y, x, c)] =
+ F(input1_data[SubscriptToIndex(desc1, b, y, x, c)],
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)]);
}
}
}
}
}
+template <ComparisonFn<float> F>
+inline void BroadcastComparison4DSlow(const ComparisonParams& op_params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape,
+ bool* output_data) {
+ BroadcastComparison4DSlowImpl<float, F>(op_params, input1_shape, input1_data,
+ input2_shape, input2_data,
+ output_shape, output_data);
+}
-template <typename T, ComparisonFn<int32> F>
-inline void BroadcastComparison(int left_shift, const T* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T, ComparisonFn<T> F>
+inline void BroadcastComparison(const T* input1_data,
+ const Dims<4>& input1_dims,
const T* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset,
- int32 input2_multiplier, int input2_shift,
- bool* output_data, const Dims<4>& output_dims) {
+ const Dims<4>& input2_dims, bool* output_data,
+ const Dims<4>& output_dims) {
+ ComparisonParams op_params;
+ // No parameters needed.
+ BroadcastComparison4DSlowImpl<T, F>(op_params, DimsToShape(input1_dims),
+ input1_data, DimsToShape(input2_dims),
+ input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T, ComparisonFn<int32> F>
+inline void BroadcastComparison4DSlowWithScaling(
+ const ComparisonParams& op_params,
+ const RuntimeShape& unextended_input1_shape, const T* input1_data,
+ const RuntimeShape& unextended_input2_shape, const T* input2_data,
+ const RuntimeShape& unextended_output_shape, bool* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlowWithScaling");
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
+
+ int left_shift = op_params.left_shift;
+ int32 input1_offset = op_params.input1_offset;
+ int32 input1_multiplier = op_params.input1_multiplier;
+ int input1_shift = op_params.input1_shift;
+ int32 input2_offset = op_params.input2_offset;
+ int32 input2_multiplier = op_params.input2_multiplier;
+ int input2_shift = op_params.input2_shift;
+
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
const int32 input1_val =
- input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
+ input1_offset + input1_data[SubscriptToIndex(desc1, b, y, x, c)];
const int32 input2_val =
- input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+ input2_offset + input2_data[SubscriptToIndex(desc2, b, y, x, c)];
const int32 shifted_input1_val = input1_val * (1 << left_shift);
const int32 shifted_input2_val = input2_val * (1 << left_shift);
const int32 scaled_input1_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input1_val, input1_multiplier,
- kReverseShift * input1_shift);
+ shifted_input1_val, input1_multiplier, input1_shift);
const int32 scaled_input2_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input2_val, input2_multiplier,
- kReverseShift * input2_shift);
- output_data[Offset(output_dims, c, x, y, b)] =
+ shifted_input2_val, input2_multiplier, input2_shift);
+ output_data[Offset(output_shape, b, y, x, c)] =
F(scaled_input1_val, scaled_input2_val);
}
}
@@ -4350,51 +4676,117 @@ inline void BroadcastComparison(int left_shift, const T* input1_data,
}
}
-#define TFLITE_COMPARISON_OP(name) \
- template <typename T> \
- inline void name(const T* input1_data, const Dims<4>& input1_dims, \
- const T* input2_data, const Dims<4>& input2_dims, \
- bool* output_data, const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label(#name); \
- Comparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
- input2_dims, output_data, output_dims); \
- } \
- template <typename T> \
- inline void name( \
- int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
- int32 input1_offset, int32 input1_multiplier, int input1_shift, \
- const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
- int32 input2_multiplier, int input2_shift, bool* output_data, \
- const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \
- Comparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
- input1_offset, input1_multiplier, input1_shift, \
- input2_data, input2_dims, input2_offset, \
- input2_multiplier, input2_shift, output_data, \
- output_dims); \
- } \
- template <typename T> \
- inline void Broadcast##name( \
- const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \
- const Dims<4>& input2_dims, bool* output_data, \
- const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \
- BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
- input2_dims, output_data, output_dims); \
- } \
- template <typename T> \
- inline void Broadcast##name( \
- int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
- int32 input1_offset, int32 input1_multiplier, int input1_shift, \
- const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
- int32 input2_multiplier, int input2_shift, bool* output_data, \
- const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \
- BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
- input1_offset, input1_multiplier, \
- input1_shift, input2_data, input2_dims, \
- input2_offset, input2_multiplier, \
- input2_shift, output_data, output_dims); \
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T, ComparisonFn<int32> F>
+inline void BroadcastComparison(int left_shift, const T* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const T* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 input2_multiplier, int input2_shift,
+ bool* output_data, const Dims<4>& output_dims) {
+ ComparisonParams op_params;
+
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+
+ BroadcastComparison4DSlowWithScaling<T, F>(
+ op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+#define TFLITE_COMPARISON_OP(name) \
+ template <typename T> \
+ inline void name(const T* input1_data, const Dims<4>& input1_dims, \
+ const T* input2_data, const Dims<4>& input2_dims, \
+ bool* output_data, const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label(#name); \
+ Comparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
+ input2_dims, output_data, output_dims); \
+ } \
+ template <typename T> \
+ inline void name( \
+ int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
+ int32 input1_offset, int32 input1_multiplier, int input1_shift, \
+ const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
+ int32 input2_multiplier, int input2_shift, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \
+ Comparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
+ input1_offset, input1_multiplier, input1_shift, \
+ input2_data, input2_dims, input2_offset, \
+ input2_multiplier, input2_shift, output_data, \
+ output_dims); \
+ } \
+ template <typename T> \
+ inline void Broadcast##name( \
+ const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \
+ const Dims<4>& input2_dims, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \
+ BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
+ input2_dims, output_data, output_dims); \
+ } \
+ template <typename T> \
+ inline void Broadcast##name( \
+ int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
+ int32 input1_offset, int32 input1_multiplier, int input1_shift, \
+ const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
+ int32 input2_multiplier, int input2_shift, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \
+ BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
+ input1_offset, input1_multiplier, \
+ input1_shift, input2_data, input2_dims, \
+ input2_offset, input2_multiplier, \
+ input2_shift, output_data, output_dims); \
+ } \
+ inline void name(const ComparisonParams& op_params, \
+ const RuntimeShape& input1_shape, const float* input1_data, \
+ const RuntimeShape& input2_shape, const float* input2_data, \
+ const RuntimeShape& output_shape, bool* output_data) { \
+ gemmlowp::ScopedProfilingLabel label(#name); \
+ Comparison<name##Fn>(op_params, input1_shape, input1_data, input2_shape, \
+ input2_data, output_shape, output_data); \
+ } \
+ template <typename T> \
+ inline void name##WithScaling( \
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
+ const T* input1_data, const RuntimeShape& input2_shape, \
+ const T* input2_data, const RuntimeShape& output_shape, \
+ bool* output_data) { \
+ gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \
+ ComparisonWithScaling<T, name##Fn>(op_params, input1_shape, input1_data, \
+ input2_shape, input2_data, \
+ output_shape, output_data); \
+ } \
+ inline void Broadcast4DSlow##name( \
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
+ const float* input1_data, const RuntimeShape& input2_shape, \
+ const float* input2_data, const RuntimeShape& output_shape, \
+ bool* output_data) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \
+ BroadcastComparison4DSlow<name##Fn>(op_params, input1_shape, input1_data, \
+ input2_shape, input2_data, \
+ output_shape, output_data); \
+ } \
+ template <typename T> \
+ inline void Broadcast4DSlow##name##WithScaling( \
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
+ const T* input1_data, const RuntimeShape& input2_shape, \
+ const T* input2_data, const RuntimeShape& output_shape, \
+ bool* output_data) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \
+ BroadcastComparison4DSlowWithScaling<T, name##Fn>( \
+ op_params, input1_shape, input1_data, input2_shape, input2_data, \
+ output_shape, output_data); \
}
TFLITE_COMPARISON_OP(Equal);
TFLITE_COMPARISON_OP(NotEqual);
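// For instance, TFLITE_COMPARISON_OP(Equal) expands to the legacy Dims<4>
// overloads plus Equal(), EqualWithScaling(), Broadcast4DSlowEqual() and
// Broadcast4DSlowEqualWithScaling(). A minimal float usage sketch
// (illustrative only; the helper name is hypothetical):
inline void EqualUsageSketch() {
  const float a[4] = {1, 2, 3, 4};
  const float b[4] = {1, 0, 3, 0};
  bool output[4];
  ComparisonParams params;  // No fields are read on the float path.
  Equal(params, RuntimeShape({4}), a, RuntimeShape({4}), b, RuntimeShape({4}),
        output);
  // output is now {true, false, true, false}.
}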
@@ -4474,61 +4866,81 @@ inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
}
template <typename T>
-inline void Pow(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+inline void Pow(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape& input2_shape, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = std::pow(input1_data[i], input2_data[i]);
}
}
template <typename T>
-inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
+inline void BroadcastPow4DSlow(const RuntimeShape& input1_shape,
+ const T* input1_data,
+ const RuntimeShape& input2_shape,
+ const T* input2_data,
+ const RuntimeShape& output_shape,
+ T* output_data) {
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- std::pow(input1_data[SubscriptToIndex(desc1, c, x, y, b)],
- input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
+ &desc2);
+
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ auto out_idx = Offset(output_shape, b, y, x, c);
+ auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
+ auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
+ auto in1_val = input1_data[in1_idx];
+ auto in2_val = input2_data[in2_idx];
+ output_data[out_idx] = std::pow(in1_val, in2_val);
}
}
}
}
}
-inline void Logical(const bool* input1_data, const Dims<4>& input1_dims,
- const bool* input2_data, const Dims<4>& input2_dims,
- bool* output_data, const Dims<4>& output_dims,
+inline void Logical(const RuntimeShape& input1_shape, const bool* input1_data,
+ const RuntimeShape& input2_shape, const bool* input2_data,
+ const RuntimeShape& output_shape, bool* output_data,
const std::function<bool(bool, bool)>& func) {
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = func(input1_data[i], input2_data[i]);
}
}
-inline void BroadcastLogical(const bool* input1_data,
- const Dims<4>& input1_dims,
- const bool* input2_data,
- const Dims<4>& input2_dims, bool* output_data,
- const Dims<4>& output_dims,
- const std::function<bool(bool, bool)>& func) {
+inline void BroadcastLogical4DSlow(
+ const RuntimeShape& unextended_input1_shape, const bool* input1_data,
+ const RuntimeShape& unextended_input2_shape, const bool* input2_data,
+ const RuntimeShape& unextended_output_shape, bool* output_data,
+ const std::function<bool(bool, bool)>& func) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- func(input1_data[SubscriptToIndex(desc1, c, x, y, b)],
- input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
+
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ auto out_idx = Offset(output_shape, b, y, x, c);
+ auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
+ auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
+ auto in1_val = input1_data[in1_idx];
+ auto in2_val = input2_data[in2_idx];
+ output_data[out_idx] = func(in1_val, in2_val);
}
}
}
@@ -4538,30 +4950,58 @@ inline void BroadcastLogical(const bool* input1_data,
// TODO(ycling): Refactoring. Remove BroadcastLogical and use the more
// generalized and efficient BroadcastBinaryFunction.
//
+// Also appears to duplicate MinimumMaximum.
+//
// R: Result type. T1: Input 1 type. T2: Input 2 type.
template <typename R, typename T1, typename T2>
-inline void BroadcastBinaryFunction(const T1* input1_data,
- const Dims<4>& input1_dims,
- const T2* input2_data,
- const Dims<4>& input2_dims, R* output_data,
- const Dims<4>& output_dims,
- R (*func)(T1, T2)) {
+inline void BroadcastBinaryFunction4DSlow(
+ const RuntimeShape& unextended_input1_shape, const T1* input1_data,
+ const RuntimeShape& unextended_input2_shape, const T2* input2_data,
+ const RuntimeShape& unextended_output_shape, R* output_data,
+ R (*func)(T1, T2)) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- func(input1_data[SubscriptToIndex(desc1, c, x, y, b)],
- input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
+
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ auto out_idx = Offset(output_shape, b, y, x, c);
+ auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
+ auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
+ auto in1_val = input1_data[in1_idx];
+ auto in2_val = input2_data[in2_idx];
+ output_data[out_idx] = func(in1_val, in2_val);
}
}
}
}
}
+// R: Result type. T1: Input 1 type. T2: Input 2 type.
+// TODO(renjieliu): Refactor other binary functions to use this one.
+template <typename R, typename T1, typename T2>
+inline void BinaryFunction(const RuntimeShape& input1_shape,
+ const T1* input1_data,
+ const RuntimeShape& input2_shape,
+ const T2* input2_data,
+ const RuntimeShape& output_shape, R* output_data,
+ R (*func)(T1, T2)) {
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = func(input1_data[i], input2_data[i]);
+ }
+}
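// A minimal usage sketch (illustrative only; the helper name is
// hypothetical): elementwise logical AND expressed through the generic
// BinaryFunction.
inline void BinaryFunctionUsageSketch() {
  const bool a[3] = {true, true, false};
  const bool b[3] = {true, false, false};
  bool output[3];
  BinaryFunction<bool, bool, bool>(
      RuntimeShape({3}), a, RuntimeShape({3}), b, RuntimeShape({3}), output,
      [](bool x, bool y) { return x && y; });
  // output is now {true, false, false}.
}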
+
} // namespace reference_ops
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc
index 3d8765f11b..15df31f75a 100644
--- a/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc
@@ -28,14 +28,12 @@ template <typename T>
void TestOneResizeBilinear(int batch, int depth, int input_width,
int input_height, int output_width,
int output_height, float error_threshold) {
- Dims<4> input_dims_inference =
- MakeDimsForInference(depth, input_width, input_height, batch);
- Dims<4> output_dims_inference =
- MakeDimsForInference(depth, output_width, output_height, batch);
+ RuntimeShape input_dims_inference({batch, input_height, input_width, depth});
+ RuntimeShape output_dims_inference(
+ {batch, output_height, output_width, depth});
- const int input_buffer_size = RequiredBufferSizeForDims(input_dims_inference);
- const int output_buffer_size =
- RequiredBufferSizeForDims(output_dims_inference);
+ const int input_buffer_size = input_dims_inference.FlatSize();
+ const int output_buffer_size = output_dims_inference.FlatSize();
std::vector<T> input_data(input_buffer_size, 0);
std::vector<T> reference_output_data(output_buffer_size, 0);
@@ -47,15 +45,19 @@ void TestOneResizeBilinear(int batch, int depth, int input_width,
const T max_amplitude = static_cast<T>(255);
FillRandom(&input_data, min_amplitude, max_amplitude);
- Dims<4> output_size_dims = MakeDimsForInference(2, 1, 1, 1);
+ RuntimeShape output_size_dims({1, 1, 1, 2});
std::vector<int32> output_size_data = {output_height, output_width};
- reference_ops::ResizeBilinear(
- input_data.data(), input_dims_inference, output_size_data.data(),
- output_size_dims, reference_output_data.data(), output_dims_inference);
- optimized_ops::ResizeBilinear(input_data.data(), input_dims_inference,
- output_size_data.data(), output_size_dims,
- output_data.data(), output_dims_inference);
+ tflite::ResizeBilinearParams op_params;
+ op_params.align_corners = false;
+
+ reference_ops::ResizeBilinear(op_params, input_dims_inference,
+ input_data.data(), output_size_dims,
+ output_size_data.data(), output_dims_inference,
+ reference_output_data.data());
+ optimized_ops::ResizeBilinear(
+ op_params, input_dims_inference, input_data.data(), output_size_dims,
+ output_size_data.data(), output_dims_inference, output_data.data());
double sum_diff = 0;
float max_abs_val = 0;
diff --git a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h
index 5994fad5c7..af5db1064c 100644
--- a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h
+++ b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h
@@ -19,9 +19,9 @@ limitations under the License.
#include <limits>
#include <vector>
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
-
namespace strided_slice {
// Use until std::clamp() is available from C++17.
@@ -32,15 +32,51 @@ inline int Clamp(const int v, const int lo, const int hi) {
return v;
}
+inline void StridedSlicePadIndices(tflite::StridedSliceParams* p,
+ int dim_count) {
+ // Add indices and mask bits to fully include extra dimensions
+ TFLITE_CHECK_LE(dim_count, 4);
+ TFLITE_CHECK_GE(dim_count, p->start_indices_count);
+ TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
+ TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
+
+ const int pad_count = dim_count - p->start_indices_count;
+
+ // Pad indices at start, so move arrays by pad_count.
+ for (int i = p->start_indices_count - 1; i >= 0; --i) {
+ p->strides[i + pad_count] = p->strides[i];
+ p->start_indices[i + pad_count] = p->start_indices[i];
+ p->stop_indices[i + pad_count] = p->stop_indices[i];
+ }
+ for (int i = 0; i < pad_count; ++i) {
+ p->start_indices[i] = 0;
+ p->stop_indices[i] = 0;
+ p->strides[i] = 1;
+ }
+
+ // Pad masks with 0s or 1s as required.
+ p->shrink_axis_mask <<= pad_count;
+ p->ellipsis_mask <<= pad_count;
+ p->new_axis_mask <<= pad_count;
+ p->begin_mask <<= pad_count;
+ p->end_mask <<= pad_count;
+ p->begin_mask |= (1 << pad_count) - 1;
+ p->end_mask |= (1 << pad_count) - 1;
+
+ p->start_indices_count = dim_count;
+ p->stop_indices_count = dim_count;
+ p->strides_count = dim_count;
+}
+
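+// Worked example of the padding above (hand-traced): a 2-D spec with
+//   start={1,0}, stop={3,4}, strides={1,2}, begin_mask=0b10, end_mask=0b00
+// padded to dim_count=4 (so pad_count=2) becomes
+//   start={0,0,1,0}, stop={0,0,3,4}, strides={1,1,1,2},
+//   begin_mask=0b1011, end_mask=0b0011.
+// The forced-on low mask bits make the two padded leading axes span their
+// full [0, dim) range regardless of the zeroed indices.
+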
// Return the index for the first element along that axis. This index will be a
// non-negative integer in [0, axis_size - 1] that can be used to index
// directly into the data.
-template <typename IntType>
-inline int StartForAxis(int begin_mask,
- std::vector<IntType> const& start_indices,
- std::vector<IntType> const& strides,
- int const* input_shape, int axis) {
- // Begin with the specified index
+inline int StartForAxis(const tflite::StridedSliceParams& params,
+ const RuntimeShape& input_shape, int axis) {
+ const auto begin_mask = params.begin_mask;
+ const auto* start_indices = params.start_indices;
+ const auto* strides = params.strides;
+ // Begin with the specified index.
int start = start_indices[axis];
// begin_mask override
@@ -57,7 +93,7 @@ inline int StartForAxis(int begin_mask,
}
// Handle negative indices
- int axis_size = input_shape[axis];
+ int axis_size = input_shape.Dims(axis);
if (start < 0) {
start += axis_size;
}
@@ -73,11 +109,14 @@ inline int StartForAxis(int begin_mask,
// element, i.e. one past the last valid index. So if you were iterating
// through all elements of a 1D array of size 4, this function would return 4
// as the stop, because it is one past the "real" indices of 0, 1, 2 & 3.
-template <typename IntType>
-inline int StopForAxis(int end_mask, int shrink_axis_mask,
- std::vector<IntType> const& stop_indices,
- std::vector<IntType> const& strides,
- int const* input_shape, int axis, int start_for_axis) {
+inline int StopForAxis(const tflite::StridedSliceParams& params,
+ const RuntimeShape& input_shape, int axis,
+ int start_for_axis) {
+ const auto end_mask = params.end_mask;
+ const auto shrink_axis_mask = params.shrink_axis_mask;
+ const auto* stop_indices = params.stop_indices;
+ const auto* strides = params.strides;
+
// Begin with the specified index
const bool shrink_axis = shrink_axis_mask & (1 << axis);
int stop = stop_indices[axis];
@@ -103,7 +142,7 @@ inline int StopForAxis(int end_mask, int shrink_axis_mask,
}
// Handle negative indices
- const int axis_size = input_shape[axis];
+ const int axis_size = input_shape.Dims(axis);
if (stop < 0) {
stop += axis_size;
}
@@ -127,6 +166,31 @@ inline bool LoopCondition(int index, int stop, int stride) {
return stride > 0 ? index >= stop : index <= stop;
}
+inline tflite::StridedSliceParams BuildStridedSliceParams(
+ int begin_mask, int end_mask, int shrink_axis_mask,
+ const std::vector<int>& start_indices, const std::vector<int>& stop_indices,
+ const std::vector<int>& strides) {
+ tflite::StridedSliceParams op_params;
+ const int dims_count = start_indices.size();
+
+ op_params.start_indices_count = dims_count;
+ op_params.stop_indices_count = dims_count;
+ op_params.strides_count = dims_count;
+ for (int i = 0; i < dims_count; ++i) {
+ op_params.start_indices[i] = start_indices[i];
+ op_params.stop_indices[i] = stop_indices[i];
+ op_params.strides[i] = strides[i];
+ }
+
+ op_params.begin_mask = begin_mask;
+ op_params.ellipsis_mask = 0;
+ op_params.end_mask = end_mask;
+ op_params.new_axis_mask = 0;
+ op_params.shrink_axis_mask = shrink_axis_mask;
+
+ return op_params;
+}
+
} // namespace strided_slice
} // namespace tflite
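
Taken together, the new helpers give kernels a uniform way to canonicalize a slice spec; a usage sketch, assuming this header is on the include path (the values are illustrative):

#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h"

// Compute the effective [start, stop) range for axis 0 of a 1-D tensor of
// size 4 sliced as [1:3].
void Example() {
  auto params = tflite::strided_slice::BuildStridedSliceParams(
      /*begin_mask=*/0, /*end_mask=*/0, /*shrink_axis_mask=*/0,
      /*start_indices=*/{1}, /*stop_indices=*/{3}, /*strides=*/{1});
  tflite::RuntimeShape input_shape({4});
  int start = tflite::strided_slice::StartForAxis(params, input_shape, 0);
  int stop =
      tflite::strided_slice::StopForAxis(params, input_shape, 0, start);
  // start == 1, stop == 3: the slice visits indices 1 and 2.
}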
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h
index ee2af5b460..13106456df 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor.h
@@ -17,44 +17,12 @@ limitations under the License.
#include <complex>
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
-template <typename T>
-inline T* GetTensorData(TfLiteTensor* tensor);
-
-template <>
-inline float* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.f : nullptr;
-}
-
-template <>
-inline uint8_t* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.uint8 : nullptr;
-}
-
-template <>
-inline int16_t* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i16 : nullptr;
-}
-
-template <>
-inline int32_t* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i32 : nullptr;
-}
-
-template <>
-inline int64_t* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i64 : nullptr;
-}
-
-template <>
-inline bool* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.b : nullptr;
-}
-
template <>
inline std::complex<float>* GetTensorData(TfLiteTensor* tensor) {
return tensor != nullptr
@@ -62,39 +30,6 @@ inline std::complex<float>* GetTensorData(TfLiteTensor* tensor) {
: nullptr;
}
-template <typename T>
-inline const T* GetTensorData(const TfLiteTensor* tensor);
-
-template <>
-inline const float* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.f : nullptr;
-}
-
-template <>
-inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.uint8 : nullptr;
-}
-
-template <>
-inline const int16_t* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i16 : nullptr;
-}
-
-template <>
-inline const int32_t* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i32 : nullptr;
-}
-
-template <>
-inline const int64_t* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i64 : nullptr;
-}
-
-template <>
-inline const bool* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.b : nullptr;
-}
-
template <>
inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) {
return tensor != nullptr
@@ -102,56 +37,14 @@ inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) {
: nullptr;
}
-inline int RemapDim(int max_dimensions, int d) {
- return max_dimensions - d - 1;
-}
-
-// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object
-// even if the original tensors were not 4D. We should consider rewriting them
-// to take a more generic 'shape' object.
-inline Dims<4> GetTensorDims(const int data[], const int size) {
- Dims<4> d;
- for (int i = 0; i < 4; ++i) {
- int src = size - i - 1;
- if (src >= 0) {
- d.sizes[i] = data[src];
- } else {
- d.sizes[i] = 1;
- }
- }
- d.strides[0] = 1;
- for (int i = 1; i < 4; i++) {
- d.strides[i] = d.strides[i - 1] * d.sizes[i - 1];
- }
- return d;
-}
-
inline Dims<4> GetTensorDims(std::vector<int32_t> data) {
return GetTensorDims(data.data(), data.size());
}
-inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) {
- if (tensor == nullptr) {
- return Dims<4>();
- }
-
- auto* dims = tensor->dims;
- return GetTensorDims(dims->data, dims->size);
-}
-
inline RuntimeShape GetTensorShape(std::vector<int32_t> data) {
return RuntimeShape(data.size(), data.data());
}
-inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) {
- if (tensor == nullptr) {
- return RuntimeShape();
- }
-
- auto* dims = tensor->dims;
- return RuntimeShape(dims->size, dims->data);
-}
-
// A list of tensors in a format that can be used by kernels like split and
// concatenation.
template <typename T>
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
new file mode 100644
index 0000000000..77e22a08b4
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
@@ -0,0 +1,135 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+template <typename T>
+inline T* GetTensorData(TfLiteTensor* tensor);
+
+template <>
+inline float* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.f : nullptr;
+}
+
+template <>
+inline uint8_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.uint8 : nullptr;
+}
+
+template <>
+inline int16_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i16 : nullptr;
+}
+
+template <>
+inline int32_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i32 : nullptr;
+}
+
+template <>
+inline int64_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i64 : nullptr;
+}
+
+template <>
+inline bool* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.b : nullptr;
+}
+
+template <typename T>
+inline const T* GetTensorData(const TfLiteTensor* tensor);
+
+template <>
+inline const float* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.f : nullptr;
+}
+
+template <>
+inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.uint8 : nullptr;
+}
+
+template <>
+inline const int16_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i16 : nullptr;
+}
+
+template <>
+inline const int32_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i32 : nullptr;
+}
+
+template <>
+inline const int64_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i64 : nullptr;
+}
+
+template <>
+inline const bool* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.b : nullptr;
+}
+
+inline int RemapDim(int max_dimensions, int d) {
+ return max_dimensions - d - 1;
+}
+
+// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object
+// even if the original tensors were not 4D. We should consider rewriting them
+// to take a more generic 'shape' object.
+inline Dims<4> GetTensorDims(const int data[], const int size) {
+ Dims<4> d;
+ for (int i = 0; i < 4; ++i) {
+ int src = size - i - 1;
+ if (src >= 0) {
+ d.sizes[i] = data[src];
+ } else {
+ d.sizes[i] = 1;
+ }
+ }
+ d.strides[0] = 1;
+ for (int i = 1; i < 4; i++) {
+ d.strides[i] = d.strides[i - 1] * d.sizes[i - 1];
+ }
+ return d;
+}
+
+inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) {
+ if (tensor == nullptr) {
+ return Dims<4>();
+ }
+
+ auto* dims = tensor->dims;
+ return GetTensorDims(dims->data, dims->size);
+}
+
+inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) {
+ if (tensor == nullptr) {
+ return RuntimeShape();
+ }
+
+ TfLiteIntArray* dims = tensor->dims;
+ const int dims_size = dims->size;
+ const int32_t* dims_data = dims->data;
+ return RuntimeShape(dims_size, dims_data);
+}
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_
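
A kernel-side usage sketch of the accessors collected in this new header (assumes a prepared TfLiteTensor; the function name is illustrative):

#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h"

// Sums a float tensor using the typed accessors above.
float SumTensor(const TfLiteTensor* tensor) {
  const float* data = tflite::GetTensorData<float>(tensor);  // nullptr-safe
  const tflite::RuntimeShape shape = tflite::GetTensorShape(tensor);
  float sum = 0.f;
  if (data != nullptr) {
    for (int i = 0; i < shape.FlatSize(); ++i) sum += data[i];
  }
  return sum;
}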
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
index 1ff8cfe39c..b0fe5adf65 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#if defined(_MSC_VER)
#define __restrict__ __restrict
@@ -101,6 +101,11 @@ void BatchVectorBatchVectorDotProduct(const float* vector1,
int n_batch, float* result,
int result_stride);
+// Cwise product of a vector and a batch-vector.
+void VectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector, int n_batch,
+ float* result);
+
// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC
// operation, the assumption here is that result array is initialized to valid
// values.
@@ -108,6 +113,10 @@ void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size,
const float* batch_vector,
int n_batch, float* result);
+// Add another vector for each batch in the batch vector.
+void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector);
+
// Batch vector initialization with another vector.
void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
float* batch_vector);
@@ -147,6 +156,12 @@ void VectorShiftLeft(float* vector, int v_size, float shift_value);
// added to get one element of output.
void ReductionSumVector(const float* input_vector, float* output_vector,
int output_size, int reduction_size);
+
+// Layer norm for each batch.
+// normalization_epsilon is added to avoid division by zero when the variance
+// of a batch row is (near) zero.
+void MeanStddevNormalization(const float* input_vector, float* output_vector,
+ int v_size, int n_batch,
+ float normalization_epsilon);
} // namespace tensor_utils
} // namespace tflite
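
The new MeanStddevNormalization declaration above leaves the arithmetic to the implementation files; a plausible portable version, consistent with the unit tests later in this diff (a sketch under that assumption, not the shipped kernel):

#include <cmath>

// Per-batch mean/stddev normalization: each of the n_batch rows of length
// v_size is shifted to zero mean and scaled to unit variance; epsilon keeps
// all-zero rows finite (they normalize to zero).
void MeanStddevNormalizationSketch(const float* input, float* output,
                                   int v_size, int n_batch, float epsilon) {
  for (int b = 0; b < n_batch; ++b) {
    const float* in = input + b * v_size;
    float* out = output + b * v_size;
    float mean = 0.f;
    for (int i = 0; i < v_size; ++i) mean += in[i];
    mean /= v_size;
    float variance = 0.f;
    for (int i = 0; i < v_size; ++i) {
      const float d = in[i] - mean;
      variance += d * d;
    }
    variance /= v_size;
    const float stddev_inv = 1.0f / std::sqrt(variance + epsilon);
    for (int i = 0; i < v_size; ++i) out[i] = (in[i] - mean) * stddev_inv;
  }
}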
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
index e8343f1223..6458af714b 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include <gmock/gmock.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
namespace tflite {
@@ -496,6 +496,16 @@ TEST(uKernels, VectorVectorCwiseProductAccumulateTest) {
{1.0, 1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4, 1.45})));
}
+TEST(uKernels, VectorBatchVectorAddTest) {
+ constexpr int kVectorSize = 3;
+ constexpr int kBatchSize = 2;
+ static float input[kVectorSize] = {0.0, -0.5, 1.0};
+ std::vector<float> output = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
+ VectorBatchVectorAdd(input, kVectorSize, kBatchSize, output.data());
+ EXPECT_THAT(output,
+ testing::ElementsAreArray({1.0, 1.5, 4.0, 4.0, 4.5, 7.0}));
+}
+
TEST(uKernels, VectorBatchVectorAssignTest) {
constexpr int kVectorSize = 5;
constexpr int kBatchSize = 3;
@@ -555,6 +565,120 @@ TEST(uKernels, ZeroVectorTest) {
ElementsAreArray(ArrayFloatNear({0.0, 0.0, 0.0, 0.0, 0.0})));
}
+TEST(uKernels, VectorBatchVectorCwiseProductAccumulate) {
+ constexpr int kVectorSize = 29;
+ constexpr int kBatchSize = 4;
+ static float input[kVectorSize] = {
+ 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1,
+ 11.11, 12.12, 13.13, 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2,
+ 21.21, 22.22, 23.23, 24.24, 25.25, 26.26, 27.27, 28.28, 0};
+ std::vector<float> output = {
+ /* batch 0 */
+ 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
+ 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2, 21.21, 22.22, 23.23,
+ 24.24, 25.25, 26.26, 27.27, 28.28, 0,
+ /* batch 1 */
+ -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
+ -12.12, -13.13, -14.14, -15.15, -16.16, -17.17, -18.18, -19.19, -20.2,
+ -21.21, -22.22, -23.23, -24.24, -25.25, -26.26, -27.27, -28.28, 0,
+ /* batch 2 */
+ 1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11, -12.12,
+ 13.13, -14.14, 15.15, -16.16, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22,
+ 23.23, -24.24, 25.25, -26.26, 27.27, -28.28, 0,
+ /* batch 3 */
+ -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
+ -13.13, 14.14, -15.15, 16.16, -17.17, 18.18, -19.19, 20.2, -21.21, 22.22,
+ -23.23, 24.24, -25.25, 26.26, -27.27, 28.28, 0};
+ VectorBatchVectorCwiseProductAccumulate(input, kVectorSize, output.data(),
+ kBatchSize, output.data());
+
+ // Expect output = input * output + output.
+ const std::vector<float> expected_output = {
+ /* batch 0 */
+ 2.310000, 7.040000, 14.190000, 23.760000, 35.750000, 50.159996, 66.989998,
+ 86.240005, 107.909996, 112.110008, 134.542084, 159.014389, 185.526901,
+ 214.079605, 244.672485, 277.305603, 311.978912, 348.692413, 387.446136,
+ 428.240051, 471.074066, 515.948364, 562.862854, 611.817566, 662.812500,
+ 715.847595, 770.922974, 828.038452, 0.000000,
+ /* batch 1 */
+ -2.310000, -7.040000, -14.190000, -23.760000, -35.750000, -50.159996,
+ -66.989998, -86.240005, -107.909996, -112.110008, -134.542084,
+ -159.014389, -185.526901, -214.079605, -244.672485, -277.305603,
+ -311.978912, -348.692413, -387.446136, -428.240051, -471.074066,
+ -515.948364, -562.862854, -611.817566, -662.812500, -715.847595,
+ -770.922974, -828.038452, 0.000000,
+ /* batch 2 */
+ 2.310000, -7.040000, 14.190000, -23.760000, 35.750000, -50.159996,
+ 66.989998, -86.240005, 107.909996, -112.110008, 134.542084, -159.014389,
+ 185.526901, -214.079605, 244.672485, -277.305603, 311.978912, -348.692413,
+ 387.446136, -428.240051, 471.074066, -515.948364, 562.862854, -611.817566,
+ 662.812500, -715.847595, 770.922974, -828.038452, 0.000000,
+ /* batch 3 */
+ -2.310000, 7.040000, -14.190000, 23.760000, -35.750000, 50.159996,
+ -66.989998, 86.240005, -107.909996, 112.110008, -134.542084, 159.014389,
+ -185.526901, 214.079605, -244.672485, 277.305603, -311.978912, 348.692413,
+ -387.446136, 428.240051, -471.074066, 515.948364, -562.862854, 611.817566,
+ -662.812500, 715.847595, -770.922974, 828.038452, 0.000000};
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
+TEST(uKernels, VectorBatchVectorCwiseProductNoAccumulate) {
+ constexpr int kVectorSize = 29;
+ constexpr int kBatchSize = 4;
+ static float input[kVectorSize] = {
+ 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1,
+ 11.11, 12.12, 13.13, 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2,
+ 21.21, 22.22, 23.23, 24.24, 25.25, 26.26, 27.27, 28.28, 0};
+ std::vector<float> output = {
+ /* batch 0 */
+ 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
+ 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2, 21.21, 22.22, 23.23,
+ 24.24, 25.25, 26.26, 27.27, 28.28, 0,
+ /* batch 1 */
+ -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
+ -12.12, -13.13, -14.14, -15.15, -16.16, -17.17, -18.18, -19.19, -20.2,
+ -21.21, -22.22, -23.23, -24.24, -25.25, -26.26, -27.27, -28.28, 0,
+ /* batch 2 */
+ 1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11, -12.12,
+ 13.13, -14.14, 15.15, -16.16, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22,
+ 23.23, -24.24, 25.25, -26.26, 27.27, -28.28, 0,
+ /* batch 3 */
+ -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
+ -13.13, 14.14, -15.15, 16.16, -17.17, 18.18, -19.19, 20.2, -21.21, 22.22,
+ -23.23, 24.24, -25.25, 26.26, -27.27, 28.28, 0};
+ VectorBatchVectorCwiseProduct(input, kVectorSize, output.data(), kBatchSize,
+ output.data());
+
+ // Expect output = input * output.
+ const std::vector<float> expected_output = {
+ /* batch 0 */
+ 1.210000, 4.840000, 10.889999, 19.360001, 30.250000, 43.559998, 59.289997,
+ 77.440002, 98.009995, 102.010010, 123.432091, 146.894394, 172.396896,
+ 199.939606, 229.522491, 261.145599, 294.808899, 330.512421, 368.256134,
+ 408.040039, 449.864075, 493.728363, 539.632874, 587.577576, 637.562500,
+ 689.587585, 743.652954, 799.758423, 0.000000,
+ /* batch 1 */
+ -1.210000, -4.840000, -10.889999, -19.360001, -30.250000, -43.559998,
+ -59.289997, -77.440002, -98.009995, -102.010010, -123.432091, -146.894394,
+ -172.396896, -199.939606, -229.522491, -261.145599, -294.808899,
+ -330.512421, -368.256134, -408.040039, -449.864075, -493.728363,
+ -539.632874, -587.577576, -637.562500, -689.587585, -743.652954,
+ -799.758423, 0.000000,
+ /* batch 2 */
+ 1.210000, -4.840000, 10.889999, -19.360001, 30.250000, -43.559998,
+ 59.289997, -77.440002, 98.009995, -102.010010, 123.432091, -146.894394,
+ 172.396896, -199.939606, 229.522491, -261.145599, 294.808899, -330.512421,
+ 368.256134, -408.040039, 449.864075, -493.728363, 539.632874, -587.577576,
+ 637.562500, -689.587585, 743.652954, -799.758423, 0.000000,
+ /* batch 3 */
+ -1.210000, 4.840000, -10.889999, 19.360001, -30.250000, 43.559998,
+ -59.289997, 77.440002, -98.009995, 102.010010, -123.432091, 146.894394,
+ -172.396896, 199.939606, -229.522491, 261.145599, -294.808899, 330.512421,
+ -368.256134, 408.040039, -449.864075, 493.728363, -539.632874, 587.577576,
+ -637.562500, 689.587585, -743.652954, 799.758423, 0.000000};
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
TEST(uKernels, BatchVectorBatchVectorDotProductTest) {
constexpr int kVectorSize = 5;
constexpr int kBatch = 2;
@@ -598,5 +722,85 @@ TEST(uKernels, ReductionSumVectorTest) {
EXPECT_THAT(result2, ElementsAreArray(ArrayFloatNear({1.0, 3.5})));
}
+TEST(uKernels, MeanStddevNormalizationNoneZeroInput) {
+ constexpr int kVectorSize = 4;
+ constexpr int kBatchSize = 2;
+ constexpr float kNormalizationEpsilon = 1e-8;
+
+ // Non-zero input.
+ static float input[kVectorSize * kBatchSize] = {
+ 0.1, 0.2, 0.3, 0.4, // batch 0
+ 0.9, 1.0, 1.1, 1.2, // batch 1
+ };
+ std::vector<float> output(kVectorSize * kBatchSize);
+ MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize,
+ kNormalizationEpsilon);
+ const std::vector<float> expected_output = {
+ -1.34164071, -0.447213531, 0.44721365, 1.34164071, // batch 0
+ -1.34163153, -0.447210163, 0.447211236, 1.3416326, // batch 1
+ };
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
+TEST(uKernels, MeanStddevNormalizationAllZeroInput) {
+ constexpr int kVectorSize = 4;
+ constexpr int kBatchSize = 2;
+ constexpr float kNormalizationEpsilon = 1e-8;
+
+ // Zero input.
+ static float input[kVectorSize * kBatchSize] = {
+ 0.0, 0.0, 0.0, 0.0, // batch 0
+ 0.0, 0.0, 0.0, 0.0, // batch 1
+ };
+ std::vector<float> output(kVectorSize * kBatchSize);
+ MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize,
+ kNormalizationEpsilon);
+ const std::vector<float> expected_output = {
+ 0.0, 0.0, 0.0, 0.0, // batch 0
+ 0.0, 0.0, 0.0, 0.0, // batch 1
+ };
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
+TEST(uKernels, MeanStddevNormalizationMixed) {
+ constexpr int kVectorSize = 4;
+ constexpr int kBatchSize = 2;
+ constexpr float kNormalizationEpsilon = 1e-8;
+
+ // Mix of zero and non-zero input.
+ static float input[kVectorSize * kBatchSize] = {
+ 0.0, 0.0, 0.0, 0.0, // batch 0
+ 0.1, 0.2, 0.3, 0.4, // batch 1
+ };
+ std::vector<float> output(kVectorSize * kBatchSize);
+ MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize,
+ kNormalizationEpsilon);
+ const std::vector<float> expected_output = {
+ 0.0, 0.0, 0.0, 0.0, // batch 0
+ -1.34164071, -0.447213531, 0.44721365, 1.34164071, // batch 1
+ };
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
+TEST(uKernels, MeanStddevNormalizationSmallValue) {
+ constexpr int kVectorSize = 4;
+ constexpr int kBatchSize = 2;
+ constexpr float kNormalizationEpsilon = 1e-8;
+
+ // Small-magnitude inputs, including one zero.
+ static float input[kVectorSize * kBatchSize] = {
+ 3e-5, -7e-6, -9e-5, 1e-6, // batch 0
+ 4e-5, 9e-6, 2e-4, 0.0, // batch 1
+ };
+ std::vector<float> output(kVectorSize * kBatchSize);
+ MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize,
+ kNormalizationEpsilon);
+ const std::vector<float> expected_output = {
+ 1.04231524, 0.212946132, -1.64753067, 0.392269224, // batch 0
+ -0.275023013, -0.658201098, 1.70267045, -0.769446373, // batch 1
+ };
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
} // namespace tensor_utils
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index 204df9ab19..c4c7cf3842 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -668,9 +668,9 @@ static_assert(sizeof(MinMax) == 8, "");
struct ActivationParams {
FusedActivationFunctionType activation_type;
- // Quantized inference params.
- int32 activation_min;
- int32 activation_max;
+ // uint8, etc., activation params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
};
// For Add, Sub, Mul ops.
@@ -710,17 +710,22 @@ struct ArithmeticParams {
struct ConcatenationParams {
int8 axis;
+ const int32* input_zeropoint;
+ const float* input_scale;
+ uint16 inputs_count;
+ int32 output_zeropoint;
+ float output_scale;
};
struct ComparisonParams {
// uint8 inference params.
int left_shift;
- int32 input0_offset;
- int32 input0_multiplier;
- int input0_shift;
int32 input1_offset;
int32 input1_multiplier;
int input1_shift;
+ int32 input2_offset;
+ int32 input2_multiplier;
+ int input2_shift;
// Shape dependent / common to inference types.
bool is_broadcast;
};
@@ -745,7 +750,7 @@ struct ConvParams {
};
struct DepthToSpaceParams {
- int16 block_size;
+ int32 block_size;
};
struct DepthwiseParams {
@@ -764,6 +769,11 @@ struct DepthwiseParams {
int32 output_activation_max;
};
+struct DequantizationParams {
+ double scale;
+ int32 zero_point;
+};
+
struct FakeQuantParams {
MinMax minmax;
int32 num_bits;
@@ -871,14 +881,20 @@ struct SoftmaxParams {
int diff_min;
};
+struct SpaceToBatchParams {
+ // "Zero" padding for uint8 means padding with the output offset.
+ int32 output_offset;
+};
+
struct SpaceToDepthParams {
- int16 block_size;
+ int32 block_size;
};
struct SplitParams {
// Graphs that split into, say, 2000 nodes are encountered. The indices in
// OperatorEdges are of type uint16.
uint16 num_split;
+ int16 axis;
};
struct SqueezeParams {
@@ -908,23 +924,30 @@ struct TanhParams {
int input_left_shift;
};
-template <typename T>
-inline void SetActivationParams(T min, T max, ArithmeticParams* params);
-
-template <>
-inline void SetActivationParams(float min, float max,
- ArithmeticParams* params) {
+template <typename P>
+inline void SetActivationParams(float min, float max, P* params) {
params->float_activation_min = min;
params->float_activation_max = max;
}
-template <>
-inline void SetActivationParams(int32 min, int32 max,
- ArithmeticParams* params) {
+template <typename P>
+inline void SetActivationParams(int32 min, int32 max, P* params) {
params->quantized_activation_min = min;
params->quantized_activation_max = max;
}
+template <typename P>
+inline void GetActivationParams(const P& params, int32* min, int32* max) {
+ *min = params.quantized_activation_min;
+ *max = params.quantized_activation_max;
+}
+
+template <typename P>
+inline void GetActivationParams(const P& params, float* min, float* max) {
+ *min = params.float_activation_min;
+ *max = params.float_activation_max;
+}
+
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
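
Replacing the explicit specializations with overloaded templates means SetActivationParams now dispatches on the min/max argument types, so any params struct with the right members serves both the float and quantized paths. A self-contained sketch (ParamsSketch is a stand-in, not a library type):

#include <cstdint>
#include <cstdio>

// Stand-in for an op's params struct carrying both activation ranges,
// mirroring ArithmeticParams.
struct ParamsSketch {
  int32_t quantized_activation_min, quantized_activation_max;
  float float_activation_min, float_activation_max;
};

template <typename P>
void SetActivationParams(float min, float max, P* p) {
  p->float_activation_min = min;
  p->float_activation_max = max;
}
template <typename P>
void SetActivationParams(int32_t min, int32_t max, P* p) {
  p->quantized_activation_min = min;
  p->quantized_activation_max = max;
}

int main() {
  ParamsSketch p;
  SetActivationParams(0.0f, 6.0f, &p);                // float path (ReLU6)
  SetActivationParams(int32_t{0}, int32_t{255}, &p);  // quantized path
  std::printf("%g %g %d %d\n", p.float_activation_min, p.float_activation_max,
              p.quantized_activation_min, p.quantized_activation_max);
}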
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h
index c8ce3c917d..e9a5fd7a40 100644
--- a/tensorflow/contrib/lite/kernels/kernel_util.h
+++ b/tensorflow/contrib/lite/kernels/kernel_util.h
@@ -16,9 +16,10 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_
#include <algorithm>
+#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
@@ -30,6 +31,11 @@ inline const TfLiteTensor* GetInput(TfLiteContext* context, TfLiteNode* node,
int index) {
return &context->tensors[node->inputs->data[index]];
}
+inline TfLiteTensor* GetVariableInput(TfLiteContext* context, TfLiteNode* node,
+ int index) {
+ TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]];
+ return (tensor->is_variable) ? tensor : nullptr;
+}
inline TfLiteTensor* GetOutput(TfLiteContext* context, TfLiteNode* node,
int index) {
return &context->tensors[node->outputs->data[index]];
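
Since GetVariableInput returns nullptr when the tensor at `index` is not marked is_variable, callers are expected to guard the result; a sketch (kStateTensor is a hypothetical index constant):

// Inside a kernel's Prepare or Eval:
TfLiteTensor* state = GetVariableInput(context, node, kStateTensor);
if (state == nullptr) return kTfLiteError;  // not a variable tensor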
diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc
index a7b54c6b84..e02d7df9ef 100644
--- a/tensorflow/contrib/lite/kernels/l2norm.cc
+++ b/tensorflow/contrib/lite/kernels/l2norm.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -68,10 +68,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
-#define TF_LITE_L2NORM(type) \
- type::L2Normalization<FusedActivationFunctionType::kNone>( \
- GetTensorData<float>(input), GetTensorShape(input), \
- GetTensorData<float>(output), GetTensorShape(output))
+#define TF_LITE_L2NORM(type) \
+ tflite::L2NormalizationParams op_params; \
+ op_params.input_zero_point = 0; \
+ type::L2Normalization(op_params, GetTensorShape(input), \
+ GetTensorData<float>(input), GetTensorShape(output), \
+ GetTensorData<float>(output))
if (kernel_type == kReference) {
TF_LITE_L2NORM(reference_ops);
@@ -81,10 +83,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
#undef TF_LITE_L2NORM
} else if (output->type == kTfLiteUInt8) {
-#define TF_LITE_L2NORM(type) \
- type::L2Normalization(GetTensorData<uint8>(input), GetTensorShape(input), \
- input->params.zero_point, \
- GetTensorData<uint8>(output), GetTensorShape(output))
+#define TF_LITE_L2NORM(type) \
+ tflite::L2NormalizationParams op_params; \
+ op_params.input_zero_point = input->params.zero_point; \
+ type::L2Normalization(op_params, GetTensorShape(input), \
+ GetTensorData<uint8>(input), GetTensorShape(output), \
+ GetTensorData<uint8>(output))
if (kernel_type == kReference) {
TF_LITE_L2NORM(reference_ops);
diff --git a/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc b/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc
new file mode 100644
index 0000000000..1bbea67b93
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc
@@ -0,0 +1,1316 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Layer Normalization LSTM op that applies normalization by mean and standard
+// deviation to the activation of the LSTM layers. Please see
+// https://arxiv.org/abs/1607.06450 for details.
+#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace layer_norm_lstm {
+
+// Struct to hold Layer Norm LSTM option data.
+struct OpData {
+ TfLiteFusedActivation activation;
+ float cell_clip;
+ float proj_clip;
+ int scratch_tensor_index;
+};
+
+// Input Tensors of size {n_batch, n_input}
+constexpr int kInputTensor = 0;
+
+// Input weight tensors of size: {n_cell, n_input}
+constexpr int kInputToInputWeightsTensor = 1; // Optional
+constexpr int kInputToForgetWeightsTensor = 2;
+constexpr int kInputToCellWeightsTensor = 3;
+constexpr int kInputToOutputWeightsTensor = 4;
+
+// Recurrent weight tensors of size {n_cell, n_output}
+constexpr int kRecurrentToInputWeightsTensor = 5; // Optional
+constexpr int kRecurrentToForgetWeightsTensor = 6;
+constexpr int kRecurrentToCellWeightsTensor = 7;
+constexpr int kRecurrentToOutputWeightsTensor = 8;
+
+// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
+constexpr int kCellToInputWeightsTensor = 9; // Optional
+constexpr int kCellToForgetWeightsTensor = 10; // Optional
+constexpr int kCellToOutputWeightsTensor = 11; // Optional
+
+// Layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
+constexpr int kInputLayerNormWeightsTensor = 12;
+constexpr int kForgetLayerNormWeightsTensor = 13;
+constexpr int kCellLayerNormWeightsTensor = 14;
+constexpr int kOutputLayerNormWeightsTensor = 15;
+
+// Gates bias tensors of size {n_cell}
+constexpr int kInputGateBiasTensor = 16; // Optional
+constexpr int kForgetGateBiasTensor = 17;
+constexpr int kCellGateBiasTensor = 18;
+constexpr int kOutputGateBiasTensor = 19;
+
+// Projection weight tensor of size {n_output, n_cell}
+constexpr int kProjectionWeightsTensor = 20; // Optional
+// Projection bias tensor of size {n_output}
+constexpr int kProjectionBiasTensor = 21; // Optional
+
+// State tensors.
+constexpr int kInputActivationStateTensor = 22;
+constexpr int kInputCellStateTensor = 23;
+
+// Output tensor.
+constexpr int kOutputTensor = 0;
+
+// Total number of scratch tensors for hybrid Op.
+constexpr int kTensorsToAdd = 7;
+
+// Small float to avoid division by zero when computing the standard
+// deviation.
+const float kLayerNormEpsilon = 1e-8;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+
+ // Turn custom option data into flexbuffer map format.
+ const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
+ const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
+
+ // Get activation function, cell_clip and proj_clip from the flexbuffer.
+ // TODO(b/113824099): make activation more generic.
+ assert(m["fused_activation_function"].ToString() == "TANH");
+ data->activation = kTfLiteActTanh;
+ data->cell_clip = m["cell_clip"].AsFloat();
+ data->proj_clip = m["proj_clip"].AsFloat();
+
+ // Populate scratch_tensor_index.
+ context->AddTensors(context, /*tensors_to_add=*/kTensorsToAdd,
+ &data->scratch_tensor_index);
+ return data;
+}
+
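+// For reference, the custom-option buffer consumed by Init() could be
+// produced on the converter side roughly as follows (a sketch against the
+// flexbuffers API; the clip values are placeholders):
+//
+//   flexbuffers::Builder fbb;
+//   fbb.Map([&]() {
+//     fbb.String("fused_activation_function", "TANH");
+//     fbb.Float("cell_clip", 0.0f);
+//     fbb.Float("proj_clip", 0.0f);
+//   });
+//   fbb.Finish();
+//   // fbb.GetBuffer() holds the bytes handed to Init() as `buffer`.
+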
+// Check that the input tensor dimensions match each other.
+TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
+ TfLiteNode* node, int n_input,
+ int n_output, int n_cell) {
+ const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+ // Making sure clipping parameters have valid values.
+ // == 0 means no clipping
+ // > 0 means clipping
+ TF_LITE_ENSURE(context, op_data->cell_clip >= 0);
+ TF_LITE_ENSURE(context, op_data->proj_clip >= 0);
+
+ const TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ if (input_to_input_weights != nullptr) {
+ TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
+ }
+
+ const TfLiteTensor* input_to_forget_weights =
+ GetInput(context, node, kInputToForgetWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
+
+ const TfLiteTensor* input_to_cell_weights =
+ GetInput(context, node, kInputToCellWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
+
+ const TfLiteTensor* recurrent_to_input_weights =
+ GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+ if (recurrent_to_input_weights != nullptr) {
+ TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
+ n_cell);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
+ n_output);
+ }
+
+ const TfLiteTensor* recurrent_to_forget_weights =
+ GetInput(context, node, kRecurrentToForgetWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
+ n_cell);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
+ n_output);
+
+ const TfLiteTensor* recurrent_to_cell_weights =
+ GetInput(context, node, kRecurrentToCellWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
+ n_output);
+
+ // We make sure the input-gate's parameters are either both present (regular
+ // LSTM) or both absent (CIFG-LSTM, where the input gate is coupled to the
+ // forget gate).
+ const bool cifg_weights_all_or_none =
+ ((input_to_input_weights != nullptr) &&
+ (recurrent_to_input_weights != nullptr)) ||
+ ((input_to_input_weights == nullptr) &&
+ (recurrent_to_input_weights == nullptr));
+ TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);
+
+ const TfLiteTensor* cell_to_input_weights =
+ GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+ if (cell_to_input_weights) {
+ TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
+ }
+
+ const TfLiteTensor* cell_to_forget_weights =
+ GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+ if (cell_to_forget_weights) {
+ TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
+ }
+
+ const TfLiteTensor* cell_to_output_weights =
+ GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+ if (cell_to_output_weights) {
+ TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
+ }
+
+ // Making sure the peephole weights are there all or none.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool peephole_weights_all_or_none =
+ ((cell_to_input_weights != nullptr || use_cifg) &&
+ (cell_to_forget_weights != nullptr) &&
+ (cell_to_output_weights != nullptr)) ||
+ ((cell_to_input_weights == nullptr) &&
+ (cell_to_forget_weights == nullptr) &&
+ (cell_to_output_weights == nullptr));
+ TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);
+
+ // Making sure layer norm weights are not null and have the right dimension.
+ const TfLiteTensor* input_layer_norm_weights =
+ GetInput(context, node, kInputLayerNormWeightsTensor);
+ TF_LITE_ENSURE(context, input_layer_norm_weights != nullptr);
+ TF_LITE_ENSURE_EQ(context, input_layer_norm_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, input_layer_norm_weights->dims->data[0], n_cell);
+
+ const TfLiteTensor* forget_layer_norm_weights =
+ GetInput(context, node, kForgetLayerNormWeightsTensor);
+ TF_LITE_ENSURE(context, forget_layer_norm_weights != nullptr);
+ TF_LITE_ENSURE_EQ(context, forget_layer_norm_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, forget_layer_norm_weights->dims->data[0], n_cell);
+
+ const TfLiteTensor* cell_layer_norm_weights =
+ GetInput(context, node, kCellLayerNormWeightsTensor);
+ TF_LITE_ENSURE(context, cell_layer_norm_weights != nullptr);
+ TF_LITE_ENSURE_EQ(context, cell_layer_norm_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_layer_norm_weights->dims->data[0], n_cell);
+
+ const TfLiteTensor* output_layer_norm_weights =
+ GetInput(context, node, kOutputLayerNormWeightsTensor);
+ TF_LITE_ENSURE(context, output_layer_norm_weights != nullptr);
+ TF_LITE_ENSURE_EQ(context, output_layer_norm_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, output_layer_norm_weights->dims->data[0], n_cell);
+
+ // Make sure the input gate bias is present only when not a CIFG-LSTM.
+ const TfLiteTensor* input_gate_bias =
+ GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+ if (use_cifg) {
+ TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
+ } else {
+ TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
+ }
+
+ const TfLiteTensor* forget_gate_bias =
+ GetInput(context, node, kForgetGateBiasTensor);
+ TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
+
+ const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);
+
+ const TfLiteTensor* output_gate_bias =
+ GetInput(context, node, kOutputGateBiasTensor);
+ TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
+
+ const TfLiteTensor* projection_weights =
+ GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+ if (projection_weights != nullptr) {
+ TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
+ TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
+ }
+
+ const TfLiteTensor* projection_bias =
+ GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+ if (projection_bias != nullptr) {
+ TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
+ }
+
+ // Making sure the projection tensors are consistent:
+ // 1) If projection weight is not present, then projection bias should not be
+ // present.
+ // 2) If projection weight is present, then projection bias is optional.
+ const bool projection_tensors_consistent =
+ ((projection_weights != nullptr) || (projection_bias == nullptr));
+ TF_LITE_ENSURE(context, projection_tensors_consistent == true);
+
+ return kTfLiteOk;
+}
+
+// Resize the output and state tensors based on the sizes of the input
+// tensors. Allocate a temporary scratch tensor. Also check that the sizes of
+// the input tensors match each other.
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 24);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+
+ // Inferring batch size, number of outputs and number of cells from the
+ // input tensors.
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TF_LITE_ENSURE(context, input->dims->size > 1);
+ const int n_batch = input->dims->data[0];
+ const int n_input = input->dims->data[1];
+
+ const TfLiteTensor* input_to_output_weights =
+ GetInput(context, node, kInputToOutputWeightsTensor);
+ const int n_cell = input_to_output_weights->dims->data[0];
+ TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
+
+ const TfLiteTensor* recurrent_to_output_weights =
+ GetInput(context, node, kRecurrentToOutputWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
+ n_cell);
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Check that the input tensor dimensions match each other.
+ TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
+ n_output, n_cell));
+
+ // Get the pointer to output, activation_state and cell_state tensors.
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ const TfLiteTensor* activation_state =
+ GetInput(context, node, kInputActivationStateTensor);
+ const TfLiteTensor* cell_state =
+ GetInput(context, node, kInputCellStateTensor);
+
+ // Check the shape of input state tensors.
+ // These tensors may be 1D or 2D. It's fine as long as the total size is
+ // correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
+ // Resize the output tensors.
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
+ output_size->data[0] = n_batch;
+ output_size->data[1] = n_output;
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, output, output_size));
+
+ // The weights are of consistent type, so it suffices to check one.
+ const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 &&
+ input->type == kTfLiteFloat32);
+
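+ // Temporary tensor layout, summarizing the allocations below:
+ //   0: gate scratch buffer (the float path uses only this one)
+ //   1: quantized input, 2: quantized activation state,
+ //   3: quantized cell state, 4: scaling factors,
+ //   5: product scaling factors, 6: recovered diagonal weights.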
+ TfLiteIntArrayFree(node->temporaries);
+ if (is_hybrid_op) {
+ node->temporaries = TfLiteIntArrayCreate(7);
+ } else {
+ node->temporaries = TfLiteIntArrayCreate(1);
+ }
+ node->temporaries->data[0] = op_data->scratch_tensor_index;
+
+ // Create a scratch buffer tensor.
+ TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+ scratch_buffer->type = input->type;
+ scratch_buffer->allocation_type = kTfLiteArenaRw;
+
+ const TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
+ scratch_buffer_size->data[0] = n_batch;
+ if (use_cifg) {
+ // Reserving space for Cell, Forget, Output gates
+ scratch_buffer_size->data[1] = n_cell * 3;
+ } else {
+ // Reserving space for Input, Cell, Forget, Output gates
+ scratch_buffer_size->data[1] = n_cell * 4;
+ }
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
+ scratch_buffer_size));
+
+ if (is_hybrid_op) {
+ // Allocate temporary tensors to store quantized values of input,
+ // activation_state and cell_state tensors.
+ node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+ node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
+ TfLiteTensor* activation_state_quantized =
+ GetTemporary(context, node, /*index=*/2);
+ activation_state_quantized->type = kTfLiteUInt8;
+ activation_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
+ activation_state->dims)) {
+ TfLiteIntArray* activation_state_quantized_size =
+ TfLiteIntArrayCopy(activation_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, activation_state_quantized,
+ activation_state_quantized_size));
+ }
+ node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
+ TfLiteTensor* cell_state_quantized =
+ GetTemporary(context, node, /*index=*/3);
+ cell_state_quantized->type = kTfLiteUInt8;
+ cell_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
+ TfLiteIntArray* cell_state_quantized_size =
+ TfLiteIntArrayCopy(cell_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, cell_state_quantized,
+ cell_state_quantized_size));
+ }
+
+ // Allocate temporary tensors to store scaling factors and product scaling
+ // factors. The latter is a convenience storage which allows to quantize
+ // a vector once (which produces the scaling factors) and multiply it with
+ // different matrices (which requires multiplying the scaling factors with
+ // the scaling factor of the matrix).
+ node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+ node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, /*index=*/5);
+ prod_scaling_factors->type = kTfLiteFloat32;
+ prod_scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
+ prod_scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(prod_scaling_factors->dims,
+ prod_scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, prod_scaling_factors,
+ prod_scaling_factors_size));
+ }
+
+ // Allocate a temporary tensor to store the recovered weights. Since
+ // this is used for diagonal matrices, only need to store n_cell values.
+ node->temporaries->data[6] = op_data->scratch_tensor_index + 6;
+ TfLiteTensor* recovered_weights = GetTemporary(context, node, /*index=*/6);
+ recovered_weights->type = kTfLiteFloat32;
+ recovered_weights->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* recovered_weights_size = TfLiteIntArrayCreate(1);
+ recovered_weights_size->data[0] = n_cell;
+ if (!TfLiteIntArrayEqual(recovered_weights->dims, recovered_weights_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, recovered_weights,
+ recovered_weights_size));
+ }
+ }
+ return kTfLiteOk;
+}
+
+void LayerNormLstmStep(
+ const float* input_ptr_batch, const float* input_to_input_weights_ptr,
+ const float* input_to_forget_weights_ptr,
+ const float* input_to_cell_weights_ptr,
+ const float* input_to_output_weights_ptr,
+ const float* recurrent_to_input_weights_ptr,
+ const float* recurrent_to_forget_weights_ptr,
+ const float* recurrent_to_cell_weights_ptr,
+ const float* recurrent_to_output_weights_ptr,
+ const float* cell_to_input_weights_ptr,
+ const float* cell_to_forget_weights_ptr,
+ const float* cell_to_output_weights_ptr,
+ const float* input_layer_norm_weight_ptr,
+ const float* forget_layer_norm_weight_ptr,
+ const float* cell_layer_norm_weight_ptr,
+ const float* output_layer_norm_weight_ptr, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const float* projection_weights_ptr,
+ const float* projection_bias_ptr, float cell_clip, float proj_clip,
+ const TfLiteFusedActivation& activation, int n_batch, int n_cell,
+ int n_input, int n_output, float* output_state_ptr, float* cell_state_ptr,
+ float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch,
+ float* output_gate_scratch, float* output_ptr_batch) {
+ // Since we have already checked that the weights are all there or none, we
+ // can check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+ const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
+
+ // Initialize scratch buffers with 0.
+ if (!use_cifg) {
+ tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
+ }
+ tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
+ tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
+ tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
+
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, input_gate_scratch, /*result_stride=*/1);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, forget_gate_scratch,
+ /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, output_gate_scratch,
+ /*result_stride=*/1);
+
+ // For each batch and cell: update input gate.
+ if (!use_cifg) {
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(input_gate_scratch,
+ input_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weight_ptr,
+ n_cell, input_gate_scratch,
+ n_batch, input_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
+ input_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+ input_gate_scratch);
+ }
+
+ // For each batch and cell: update forget gate.
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ forget_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(forget_gate_scratch,
+ forget_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weight_ptr,
+ n_cell, forget_gate_scratch,
+ n_batch, forget_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
+ forget_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+ forget_gate_scratch);
+
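+  // Cell state update: c_t = f_t .* c_{t-1} + i_t .* g_cell, where g_cell is
+  // the normalized, activated cell input and i_t is replaced by (1 - f_t)
+  // when CIFG is used, followed by optional clipping to [-cell_clip,
+  // cell_clip].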
+ // For each batch and cell: update the cell.
+ tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
+ n_batch, kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(
+ cell_layer_norm_weight_ptr, n_cell, cell_scratch, n_batch, cell_scratch);
+ tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
+ cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
+ n_batch * n_cell, cell_state_ptr);
+ tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+ activation, cell_scratch);
+ if (use_cifg) {
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ forget_gate_scratch);
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ } else {
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ }
+ if (cell_clip > 0.0) {
+ tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, cell_clip,
+ cell_state_ptr);
+ }
+
+ // For each batch and cell: update the output gate.
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ output_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(output_gate_scratch,
+ output_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weight_ptr,
+ n_cell, output_gate_scratch,
+ n_batch, output_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
+ output_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+ output_gate_scratch);
+ tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
+ activation, cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ n_batch * n_cell, output_gate_scratch);
+
+ // For each batch: update the projection and output_state.
+ const bool use_projection_weight = (projection_weights_ptr != nullptr);
+ const bool use_projection_bias = (projection_bias_ptr != nullptr);
+ if (use_projection_weight) {
+ if (use_projection_bias) {
+ tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+ n_batch, output_ptr_batch);
+ } else {
+ tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
+ output_ptr_batch, /*result_stride=*/1);
+ if (proj_clip > 0.0) {
+ tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, proj_clip,
+ output_ptr_batch);
+ }
+ } else {
+ tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+ output_ptr_batch);
+ }
+ tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
+ output_state_ptr);
+}
+
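+// Hybrid variant of the step function above: weights are symmetrically
+// quantized to int8 with a per-tensor scale, activations are quantized per
+// batch row on the fly, and each int8 matmul accumulates
+//   q_weight * q_input * (weight_scale * input_scale[b])
+// back into the float scratch buffers. All-zero vectors skip the
+// quantize-and-multiply work entirely.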
+void LayerNormLstmStep(
+ const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
+ float input_to_input_weights_scale,
+ const int8_t* input_to_forget_weights_ptr,
+ float input_to_forget_weights_scale,
+ const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
+ const int8_t* input_to_output_weights_ptr,
+ float input_to_output_weights_scale,
+ const int8_t* recurrent_to_input_weights_ptr,
+ float recurrent_to_input_weights_scale,
+ const int8_t* recurrent_to_forget_weights_ptr,
+ float recurrent_to_forget_weights_scale,
+ const int8_t* recurrent_to_cell_weights_ptr,
+ float recurrent_to_cell_weights_scale,
+ const int8_t* recurrent_to_output_weights_ptr,
+ float recurrent_to_output_weights_scale,
+ const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
+ const int8_t* cell_to_forget_weights_ptr,
+ float cell_to_forget_weights_scale,
+ const int8_t* cell_to_output_weights_ptr,
+ float cell_to_output_weights_scale,
+ const float* input_layer_norm_weight_ptr,
+ const float* forget_layer_norm_weight_ptr,
+ const float* cell_layer_norm_weight_ptr,
+ const float* output_layer_norm_weight_ptr, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
+ float projection_weights_scale, const float* projection_bias_ptr,
+ float cell_clip, float proj_clip, const TfLiteFusedActivation& activation,
+ int n_batch, int n_cell, int n_input, int n_output,
+ float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch,
+ float* output_gate_scratch, float* scaling_factors,
+ float* product_scaling_factors, float* recovered_weights,
+ int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
+ int8_t* quantized_cell_state_ptr, float* output_state_ptr,
+ float* cell_state_ptr, float* output_ptr_batch) {
+  // Since we have already checked that the weights are either all present or
+  // all absent, we can check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+ const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
+
+ // Initialize scratch buffers with 0.
+ if (!use_cifg) {
+ tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
+ }
+ tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
+ tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
+ tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
+
+ if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
+    // Skip the quantization and matmul work when the input is all zeros.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_input;
+ tensor_utils::SymmetricQuantizeFloats(
+ input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
+ &unused_min, &unused_max, &scaling_factors[b]);
+ }
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_input_weights_ptr, n_cell, n_input,
+ quantized_input_ptr_batch, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, forget_gate_scratch,
+ /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, output_gate_scratch,
+ /*result_stride=*/1);
+ }
+
+ if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
+    // Skip the quantization and matmul work when the output state is all
+    // zeros.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_output;
+ tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output,
+ quantized_output_state_ptr + offset,
+ &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_input_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_forget_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_cell_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_output_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+ }
+
+  // Skip the peephole contribution below when the cell state is all zeros.
+ bool is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+
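+  // The peephole weights are stored as int8, so they are dequantized into
+  // recovered_weights (w = q * scale) before the cwise products below.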
+ // For each batch and cell: update input gate.
+ if (!use_cifg) {
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
+ cell_to_input_weights_scale,
+ recovered_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_weights, n_cell, cell_state_ptr, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(input_gate_scratch,
+ input_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weight_ptr,
+ n_cell, input_gate_scratch,
+ n_batch, input_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
+ input_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+ input_gate_scratch);
+ }
+
+ // For each batch and cell: update forget gate.
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
+ cell_to_forget_weights_scale,
+ recovered_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_weights, n_cell, cell_state_ptr, n_batch,
+ forget_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(forget_gate_scratch,
+ forget_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weight_ptr,
+ n_cell, forget_gate_scratch,
+ n_batch, forget_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
+ forget_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+ forget_gate_scratch);
+
+ // For each batch and cell: update the cell.
+ tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
+ n_batch, kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(
+ cell_layer_norm_weight_ptr, n_cell, cell_scratch, n_batch, cell_scratch);
+ tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
+ cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
+ n_batch * n_cell, cell_state_ptr);
+ tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+ activation, cell_scratch);
+ if (use_cifg) {
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ forget_gate_scratch);
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ } else {
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ }
+ if (cell_clip > 0.0) {
+ tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, cell_clip,
+ cell_state_ptr);
+ }
+
+ is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+ // For each batch and cell: update the output gate.
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
+ cell_to_output_weights_scale,
+ recovered_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_weights, n_cell, cell_state_ptr, n_batch,
+ output_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(output_gate_scratch,
+ output_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weight_ptr,
+ n_cell, output_gate_scratch,
+ n_batch, output_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
+ output_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+ output_gate_scratch);
+ tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
+ activation, cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ n_batch * n_cell, output_gate_scratch);
+
+ // For each batch: update the projection and output_state.
+ const bool use_projection_weight = (projection_weights_ptr != nullptr);
+ const bool use_projection_bias = (projection_bias_ptr != nullptr);
+ if (use_projection_weight) {
+ if (use_projection_bias) {
+ tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+ n_batch, output_ptr_batch);
+ } else {
+ tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
+ }
+ if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
+      // Skip quantization and matmul work when the gated output is all zeros.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_cell;
+ tensor_utils::SymmetricQuantizeFloats(
+ output_gate_scratch + offset, n_cell,
+ quantized_cell_state_ptr + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * projection_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr,
+ product_scaling_factors, n_batch, output_ptr_batch,
+ /*result_stride=*/1);
+ }
+ if (proj_clip > 0.0) {
+ tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, proj_clip,
+ output_ptr_batch);
+ }
+ } else {
+ tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+ output_ptr_batch);
+ }
+ tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
+ output_state_ptr);
+}
+
+// The LayerNormLSTM Op engine.
+TfLiteStatus EvalFloat(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights,
+ const TfLiteTensor* input_layer_norm_weights,
+ const TfLiteTensor* forget_layer_norm_weights,
+ const TfLiteTensor* cell_layer_norm_weights,
+ const TfLiteTensor* output_layer_norm_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ float cell_clip, float proj_clip, const TfLiteFusedActivation& activation,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output) {
+ const int n_batch = input->dims->data[0];
+ const int n_input = input->dims->data[1];
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
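+  // The scratch buffer is a single arena tensor carved into 3 (CIFG) or 4
+  // contiguous n_cell * n_batch segments, one per gate.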
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+  // Check optional tensors; their pointers may be null.
+ const float* input_to_input_weights_ptr =
+ (use_cifg) ? nullptr : input_to_input_weights->data.f;
+ const float* recurrent_to_input_weights_ptr =
+ (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
+ const float* input_gate_bias_ptr =
+ (use_cifg) ? nullptr : input_gate_bias->data.f;
+ const float* cell_to_input_weights_ptr =
+ (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
+ const float* cell_to_forget_weights_ptr =
+ (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
+ const float* cell_to_output_weights_ptr =
+ (use_peephole) ? cell_to_output_weights->data.f : nullptr;
+ const float* projection_weights_ptr =
+ (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+ // Required tensors, pointers are non-null.
+ const float* input_ptr_batch = input->data.f;
+ const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f;
+ const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f;
+ const float* input_to_output_weights_ptr = input_to_output_weights->data.f;
+ const float* recurrent_to_forget_weights_ptr =
+ recurrent_to_forget_weights->data.f;
+ const float* recurrent_to_cell_weights_ptr =
+ recurrent_to_cell_weights->data.f;
+ const float* recurrent_to_output_weights_ptr =
+ recurrent_to_output_weights->data.f;
+ const float* input_layer_norm_weight_ptr = input_layer_norm_weights->data.f;
+ const float* forget_layer_norm_weight_ptr = forget_layer_norm_weights->data.f;
+ const float* cell_layer_norm_weight_ptr = cell_layer_norm_weights->data.f;
+ const float* output_layer_norm_weight_ptr = output_layer_norm_weights->data.f;
+ const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+ const float* cell_bias_ptr = cell_bias->data.f;
+ const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+ float* activation_state_ptr = activation_state->data.f;
+ float* cell_state_ptr = cell_state->data.f;
+ float* output_ptr_batch = output->data.f;
+
+ LayerNormLstmStep(
+ input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr,
+ input_to_cell_weights_ptr, input_to_output_weights_ptr,
+ recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
+ recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
+ cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
+ cell_to_output_weights_ptr, input_layer_norm_weight_ptr,
+ forget_layer_norm_weight_ptr, cell_layer_norm_weight_ptr,
+ output_layer_norm_weight_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
+ cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
+ projection_bias_ptr, cell_clip, proj_clip, activation, n_batch, n_cell,
+ n_input, n_output, activation_state_ptr, cell_state_ptr,
+ input_gate_scratch, forget_gate_scratch, cell_scratch,
+ output_gate_scratch, output_ptr_batch);
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalHybrid(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights,
+ const TfLiteTensor* input_layer_norm_weights,
+ const TfLiteTensor* forget_layer_norm_weights,
+ const TfLiteTensor* cell_layer_norm_weights,
+ const TfLiteTensor* output_layer_norm_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ float cell_clip, float proj_clip, const TfLiteFusedActivation& activation,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
+ TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_weights,
+ TfLiteTensor* input_quantized, TfLiteTensor* activation_state_quantized,
+ TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output) {
+ const int n_batch = input->dims->data[0];
+ const int n_input = input->dims->data[1];
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+  // Check optional tensors; their pointers may be null.
+ int8_t* input_to_input_weights_ptr = nullptr;
+ float input_to_input_weights_scale = 1.0f;
+ int8_t* recurrent_to_input_weights_ptr = nullptr;
+ float recurrent_to_input_weights_scale = 1.0f;
+ float* input_gate_bias_ptr = nullptr;
+ if (!use_cifg) {
+ input_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
+ recurrent_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
+ input_gate_bias_ptr = input_gate_bias->data.f;
+ input_to_input_weights_scale = input_to_input_weights->params.scale;
+ recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
+ }
+
+ int8_t* cell_to_input_weights_ptr = nullptr;
+ int8_t* cell_to_forget_weights_ptr = nullptr;
+ int8_t* cell_to_output_weights_ptr = nullptr;
+ float cell_to_input_weights_scale = 1.0f;
+ float cell_to_forget_weights_scale = 1.0f;
+ float cell_to_output_weights_scale = 1.0f;
+ if (use_peephole) {
+ if (!use_cifg) {
+ cell_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
+ cell_to_input_weights_scale = cell_to_input_weights->params.scale;
+ }
+ cell_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
+ cell_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
+ cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
+ cell_to_output_weights_scale = cell_to_output_weights->params.scale;
+ }
+
+ const int8_t* projection_weights_ptr =
+ (projection_weights == nullptr)
+ ? nullptr
+ : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
+ const float projection_weights_scale =
+ (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+ // Required tensors, pointers are non-null.
+ const float* input_ptr_batch = input->data.f;
+ const int8_t* input_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
+ const float input_to_forget_weights_scale =
+ input_to_forget_weights->params.scale;
+ const int8_t* input_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
+ const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
+ const int8_t* input_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
+ const float input_to_output_weights_scale =
+ input_to_output_weights->params.scale;
+ const int8_t* recurrent_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
+ const float recurrent_to_forget_weights_scale =
+ recurrent_to_forget_weights->params.scale;
+ const int8_t* recurrent_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
+ const float recurrent_to_cell_weights_scale =
+ recurrent_to_cell_weights->params.scale;
+ const int8_t* recurrent_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
+ const float recurrent_to_output_weights_scale =
+ recurrent_to_output_weights->params.scale;
+ const float* input_layer_norm_weight_ptr = input_layer_norm_weights->data.f;
+ const float* forget_layer_norm_weight_ptr = forget_layer_norm_weights->data.f;
+ const float* cell_layer_norm_weight_ptr = cell_layer_norm_weights->data.f;
+ const float* output_layer_norm_weight_ptr = output_layer_norm_weights->data.f;
+ const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+ const float* cell_bias_ptr = cell_bias->data.f;
+ const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+ float* activation_state_ptr = activation_state->data.f;
+ float* cell_state_ptr = cell_state->data.f;
+ float* output_ptr_batch = output->data.f;
+
+ // Temporary storage for quantized values and scaling factors.
+ int8_t* quantized_input_ptr =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+ int8_t* quantized_activation_state_ptr =
+ reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
+ int8_t* quantized_cell_state_ptr =
+ reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
+ float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
+ float* recovered_weights_ptr = recovered_weights->data.f;
+
+ LayerNormLstmStep(
+ input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale,
+ input_to_forget_weights_ptr, input_to_forget_weights_scale,
+ input_to_cell_weights_ptr, input_to_cell_weights_scale,
+ input_to_output_weights_ptr, input_to_output_weights_scale,
+ recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
+ recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
+ recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
+ recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
+ cell_to_input_weights_ptr, cell_to_input_weights_scale,
+ cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
+ cell_to_output_weights_ptr, cell_to_output_weights_scale,
+ input_layer_norm_weight_ptr, forget_layer_norm_weight_ptr,
+ cell_layer_norm_weight_ptr, output_layer_norm_weight_ptr,
+ input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
+ output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
+ projection_bias_ptr, cell_clip, proj_clip, activation, n_batch, n_cell,
+ n_input, n_output, input_gate_scratch, forget_gate_scratch, cell_scratch,
+ output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
+ recovered_weights_ptr, quantized_input_ptr,
+ quantized_activation_state_ptr, quantized_cell_state_ptr,
+ activation_state_ptr, cell_state_ptr, output_ptr_batch);
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+
+ const TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ const TfLiteTensor* input_to_forget_weights =
+ GetInput(context, node, kInputToForgetWeightsTensor);
+ const TfLiteTensor* input_to_cell_weights =
+ GetInput(context, node, kInputToCellWeightsTensor);
+ const TfLiteTensor* input_to_output_weights =
+ GetInput(context, node, kInputToOutputWeightsTensor);
+
+ const TfLiteTensor* recurrent_to_input_weights =
+ GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+ const TfLiteTensor* recurrent_to_forget_weights =
+ GetInput(context, node, kRecurrentToForgetWeightsTensor);
+ const TfLiteTensor* recurrent_to_cell_weights =
+ GetInput(context, node, kRecurrentToCellWeightsTensor);
+ const TfLiteTensor* recurrent_to_output_weights =
+ GetInput(context, node, kRecurrentToOutputWeightsTensor);
+
+ const TfLiteTensor* cell_to_input_weights =
+ GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+ const TfLiteTensor* cell_to_forget_weights =
+ GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+ const TfLiteTensor* cell_to_output_weights =
+ GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+
+ const TfLiteTensor* input_layer_norm_weights =
+ GetInput(context, node, kInputLayerNormWeightsTensor);
+ const TfLiteTensor* forget_layer_norm_weights =
+ GetInput(context, node, kForgetLayerNormWeightsTensor);
+ const TfLiteTensor* cell_layer_norm_weights =
+ GetInput(context, node, kCellLayerNormWeightsTensor);
+ const TfLiteTensor* output_layer_norm_weights =
+ GetInput(context, node, kOutputLayerNormWeightsTensor);
+
+ const TfLiteTensor* input_gate_bias =
+ GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+ const TfLiteTensor* forget_gate_bias =
+ GetInput(context, node, kForgetGateBiasTensor);
+ const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ const TfLiteTensor* output_gate_bias =
+ GetInput(context, node, kOutputGateBiasTensor);
+
+ const TfLiteTensor* projection_weights =
+ GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+ const TfLiteTensor* projection_bias =
+ GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+
+  // Index the scratch buffer pointers into the global scratch buffer.
+ TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+
+ TfLiteTensor* activation_state =
+ &context->tensors[node->inputs->data[kInputActivationStateTensor]];
+ TfLiteTensor* cell_state =
+ &context->tensors[node->inputs->data[kInputCellStateTensor]];
+
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ switch (input_to_output_weights->type) {
+ case kTfLiteFloat32: {
+ return EvalFloat(input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights,
+ cell_to_output_weights, input_layer_norm_weights,
+ forget_layer_norm_weights, cell_layer_norm_weights,
+ output_layer_norm_weights, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias,
+ projection_weights, projection_bias, op_data->cell_clip,
+ op_data->proj_clip, op_data->activation, scratch_buffer,
+ activation_state, cell_state, output);
+ }
+ case kTfLiteUInt8: {
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+ TfLiteTensor* activation_state_quantized =
+ GetTemporary(context, node, /*index=*/2);
+ TfLiteTensor* cell_state_quantized =
+ GetTemporary(context, node, /*index=*/3);
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, /*index=*/5);
+ TfLiteTensor* recovered_weights =
+ GetTemporary(context, node, /*index=*/6);
+ return EvalHybrid(
+ input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
+ input_layer_norm_weights, forget_layer_norm_weights,
+ cell_layer_norm_weights, output_layer_norm_weights, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+ projection_bias, op_data->cell_clip, op_data->proj_clip,
+ op_data->activation, scratch_buffer, scaling_factors,
+ prod_scaling_factors, recovered_weights, input_quantized,
+ activation_state_quantized, cell_state_quantized, activation_state,
+ cell_state, output);
+ }
+ default:
+ context->ReportError(context, "Type %d is not currently supported.",
+ input_to_output_weights->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+} // namespace layer_norm_lstm
+
+TfLiteRegistration* Register_LAYER_NORM_LSTM() {
+ static TfLiteRegistration r = {layer_norm_lstm::Init, layer_norm_lstm::Free,
+ layer_norm_lstm::Prepare,
+ layer_norm_lstm::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc b/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc
new file mode 100644
index 0000000000..abc229f85a
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc
@@ -0,0 +1,664 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite Layer Norm LSTM op.
+
+#include <memory>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_LAYER_NORM_LSTM();
+
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class LayerNormLSTMOpModel : public SingleOpModel {
+ public:
+ LayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
+ bool use_cifg, bool use_peephole,
+ bool use_projection_weights, bool use_projection_bias,
+ float cell_clip, float proj_clip,
+ const std::vector<std::vector<int>>& input_shapes,
+ const TensorType& weight_type = TensorType_FLOAT32)
+ : n_batch_(n_batch),
+ n_input_(n_input),
+ n_cell_(n_cell),
+ n_output_(n_output) {
+ input_ = AddInput(TensorType_FLOAT32);
+
+ if (use_cifg) {
+ input_to_input_weights_ = AddNullInput();
+ } else {
+ input_to_input_weights_ = AddInput(weight_type);
+ }
+
+ input_to_forget_weights_ = AddInput(weight_type);
+ input_to_cell_weights_ = AddInput(weight_type);
+ input_to_output_weights_ = AddInput(weight_type);
+
+ if (use_cifg) {
+ recurrent_to_input_weights_ = AddNullInput();
+ } else {
+ recurrent_to_input_weights_ = AddInput(weight_type);
+ }
+
+ recurrent_to_forget_weights_ = AddInput(weight_type);
+ recurrent_to_cell_weights_ = AddInput(weight_type);
+ recurrent_to_output_weights_ = AddInput(weight_type);
+
+ if (use_peephole) {
+ if (use_cifg) {
+ cell_to_input_weights_ = AddNullInput();
+ } else {
+ cell_to_input_weights_ = AddInput(weight_type);
+ }
+ cell_to_forget_weights_ = AddInput(weight_type);
+ cell_to_output_weights_ = AddInput(weight_type);
+ } else {
+ cell_to_input_weights_ = AddNullInput();
+ cell_to_forget_weights_ = AddNullInput();
+ cell_to_output_weights_ = AddNullInput();
+ }
+
+ input_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
+ forget_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
+ cell_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
+ output_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
+
+ if (use_cifg) {
+ input_gate_bias_ = AddNullInput();
+ } else {
+ input_gate_bias_ = AddInput(TensorType_FLOAT32);
+ }
+ forget_gate_bias_ = AddInput(TensorType_FLOAT32);
+ cell_bias_ = AddInput(TensorType_FLOAT32);
+ output_gate_bias_ = AddInput(TensorType_FLOAT32);
+
+ if (use_projection_weights) {
+ projection_weights_ = AddInput(weight_type);
+ if (use_projection_bias) {
+ projection_bias_ = AddInput(TensorType_FLOAT32);
+ } else {
+ projection_bias_ = AddNullInput();
+ }
+ } else {
+ projection_weights_ = AddNullInput();
+ projection_bias_ = AddNullInput();
+ }
+
+ // Adding the 2 state tensors.
+ output_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true);
+ cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
+
+ output_ = AddOutput(TensorType_FLOAT32);
+
+ // Set up and pass in custom options using flexbuffer.
+ flexbuffers::Builder fbb;
+ fbb.Map([&]() {
+      fbb.Float("cell_clip", cell_clip);
+      fbb.Float("proj_clip", proj_clip);
+ fbb.String("fused_activation_function", "TANH");
+ });
+ fbb.Finish();
+ SetCustomOp("LAYER_NORM_LSTM", fbb.GetBuffer(), Register_LAYER_NORM_LSTM);
+ BuildInterpreter(input_shapes);
+ }
+
+ void SetInputToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_input_weights_, f);
+ }
+
+ void SetInputToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_forget_weights_, f);
+ }
+
+ void SetInputToCellWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_cell_weights_, f);
+ }
+
+ void SetInputToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_output_weights_, f);
+ }
+
+ void SetRecurrentToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_input_weights_, f);
+ }
+
+ void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_forget_weights_, f);
+ }
+
+ void SetRecurrentToCellWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_cell_weights_, f);
+ }
+
+ void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_output_weights_, f);
+ }
+
+ void SetCellToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_input_weights_, f);
+ }
+
+ void SetCellToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_forget_weights_, f);
+ }
+
+ void SetCellToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_output_weights_, f);
+ }
+
+ void SetInputLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_layer_norm_weights_, f);
+ }
+
+ void SetForgetLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(forget_layer_norm_weights_, f);
+ }
+
+ void SetCellLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_layer_norm_weights_, f);
+ }
+
+ void SetOutputLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(output_layer_norm_weights_, f);
+ }
+
+ void SetInputGateBias(std::initializer_list<float> f) {
+ PopulateTensor(input_gate_bias_, f);
+ }
+
+ void SetForgetGateBias(std::initializer_list<float> f) {
+ PopulateTensor(forget_gate_bias_, f);
+ }
+
+ void SetCellBias(std::initializer_list<float> f) {
+ PopulateTensor(cell_bias_, f);
+ }
+
+ void SetOutputGateBias(std::initializer_list<float> f) {
+ PopulateTensor(output_gate_bias_, f);
+ }
+
+ void SetProjectionWeights(std::initializer_list<float> f) {
+ PopulateTensor(projection_weights_, f);
+ }
+
+ void SetProjectionBias(std::initializer_list<float> f) {
+ PopulateTensor(projection_bias_, f);
+ }
+
+ void SetInput(int offset, const float* begin, const float* end) {
+ PopulateTensor(input_, offset, const_cast<float*>(begin),
+ const_cast<float*>(end));
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ int num_inputs() { return n_input_; }
+ int num_outputs() { return n_output_; }
+ int num_cells() { return n_cell_; }
+ int num_batches() { return n_batch_; }
+
+ protected:
+ int input_;
+ int input_to_input_weights_;
+ int input_to_forget_weights_;
+ int input_to_cell_weights_;
+ int input_to_output_weights_;
+
+ int recurrent_to_input_weights_;
+ int recurrent_to_forget_weights_;
+ int recurrent_to_cell_weights_;
+ int recurrent_to_output_weights_;
+
+ int cell_to_input_weights_;
+ int cell_to_forget_weights_;
+ int cell_to_output_weights_;
+
+ int input_layer_norm_weights_;
+ int forget_layer_norm_weights_;
+ int cell_layer_norm_weights_;
+ int output_layer_norm_weights_;
+
+ int input_gate_bias_;
+ int forget_gate_bias_;
+ int cell_bias_;
+ int output_gate_bias_;
+
+ int projection_weights_;
+ int projection_bias_;
+
+ int output_state_;
+ int cell_state_;
+
+ int output_;
+
+ int n_batch_;
+ int n_input_;
+ int n_cell_;
+ int n_output_;
+};
+
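+// Same op model as above, but the weight tensors are created as UINT8 and
+// populated via symmetric quantization, exercising the hybrid kernel path.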
+class HybridLayerNormLSTMOpModel : public LayerNormLSTMOpModel {
+ public:
+ HybridLayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
+ bool use_cifg, bool use_peephole,
+ bool use_projection_weights,
+ bool use_projection_bias, float cell_clip,
+ float proj_clip,
+ const std::vector<std::vector<int>>& input_shapes)
+ : LayerNormLSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg,
+ use_peephole, use_projection_weights,
+ use_projection_bias, cell_clip, proj_clip,
+ input_shapes, TensorType_UINT8) {}
+
+ void SetInputToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_input_weights_, f);
+ }
+
+ void SetInputToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_forget_weights_, f);
+ }
+
+ void SetInputToCellWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_cell_weights_, f);
+ }
+
+ void SetInputToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_output_weights_, f);
+ }
+
+ void SetRecurrentToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f);
+ }
+
+ void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f);
+ }
+
+ void SetRecurrentToCellWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f);
+ }
+
+ void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f);
+ }
+
+ void SetCellToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_input_weights_, f);
+ }
+
+ void SetCellToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f);
+ }
+
+ void SetCellToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_output_weights_, f);
+ }
+
+ void SetInputLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_layer_norm_weights_, f);
+ }
+
+ void SetForgetLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(forget_layer_norm_weights_, f);
+ }
+
+ void SetCellLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_layer_norm_weights_, f);
+ }
+
+ void SetOutputLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(output_layer_norm_weights_, f);
+ }
+
+ void SetProjectionWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(projection_weights_, f);
+ }
+};
+
+class BaseLayerNormLstmTest : public ::testing::Test {
+ protected:
+ // Weights of the Layer Norm LSTM model. Some are optional.
+ std::initializer_list<float> input_to_input_weights_;
+ std::initializer_list<float> input_to_cell_weights_;
+ std::initializer_list<float> input_to_forget_weights_;
+ std::initializer_list<float> input_to_output_weights_;
+ std::initializer_list<float> input_gate_bias_;
+ std::initializer_list<float> cell_gate_bias_;
+ std::initializer_list<float> forget_gate_bias_;
+ std::initializer_list<float> output_gate_bias_;
+ std::initializer_list<float> recurrent_to_input_weights_;
+ std::initializer_list<float> recurrent_to_cell_weights_;
+ std::initializer_list<float> recurrent_to_forget_weights_;
+ std::initializer_list<float> recurrent_to_output_weights_;
+ std::initializer_list<float> cell_to_input_weights_;
+ std::initializer_list<float> cell_to_forget_weights_;
+ std::initializer_list<float> cell_to_output_weights_;
+ std::initializer_list<float> input_layer_norm_weights_;
+ std::initializer_list<float> forget_layer_norm_weights_;
+ std::initializer_list<float> cell_layer_norm_weights_;
+ std::initializer_list<float> output_layer_norm_weights_;
+ std::initializer_list<float> projection_weights_;
+
+  // Layer Norm LSTM input is stored as a num_batch x num_inputs vector.
+ std::vector<std::vector<float>> layer_norm_lstm_input_;
+
+  // Feeds the given input through the layer_norm_lstm model and compares the
+  // resulting output to the expected golden output, up to the tolerance.
+ void VerifyGoldens(const std::vector<std::vector<float>>& input,
+ const std::vector<std::vector<float>>& output,
+ LayerNormLSTMOpModel* layer_norm_lstm,
+ float tolerance = 1e-5) {
+ const int num_batches = input.size();
+ EXPECT_GT(num_batches, 0);
+ const int num_inputs = layer_norm_lstm->num_inputs();
+ EXPECT_GT(num_inputs, 0);
+ const int input_sequence_size = input[0].size() / num_inputs;
+ EXPECT_GT(input_sequence_size, 0);
+ for (int i = 0; i < input_sequence_size; ++i) {
+ for (int b = 0; b < num_batches; ++b) {
+ const float* batch_start = input[b].data() + i * num_inputs;
+ const float* batch_end = batch_start + num_inputs;
+
+ layer_norm_lstm->SetInput(b * layer_norm_lstm->num_inputs(),
+ batch_start, batch_end);
+ }
+
+ layer_norm_lstm->Invoke();
+
+ const int num_outputs = layer_norm_lstm->num_outputs();
+ std::vector<float> expected;
+ for (int b = 0; b < num_batches; ++b) {
+ const float* golden_start_batch = output[b].data() + i * num_outputs;
+ const float* golden_end_batch = golden_start_batch + num_outputs;
+ expected.insert(expected.end(), golden_start_batch, golden_end_batch);
+ }
+ EXPECT_THAT(layer_norm_lstm->GetOutput(),
+ ElementsAreArray(ArrayFloatNear(expected, tolerance)));
+ }
+ }
+};
+
+class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest
+ : public BaseLayerNormLstmTest {
+ void SetUp() override {
+ input_to_input_weights_ = {0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2,
+ 0.3, -0.4, 0.5, -0.8, 0.7, -0.6, 0.5,
+ -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};
+
+ input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2,
+ -0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4,
+ -0.6, 0.3, -0.4, -0.6, -0.5, -0.5};
+
+ input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2,
+ -0.3, -0.2, -0.6, 0.6, -0.1, -0.4, -0.3,
+ -0.7, 0.7, -0.9, -0.5, 0.8, 0.6};
+
+ input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3,
+ -0.3, -0.8, -0.2, 0.6, -0.2, 0.4, -0.7,
+ -0.3, -0.5, 0.1, 0.5, -0.6, -0.4};
+
+ input_gate_bias_ = {0.03, 0.15, 0.22, 0.38};
+
+ forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};
+
+ cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};
+
+ output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};
+
+ recurrent_to_input_weights_ = {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9,
+ -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};
+
+ recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08,
+ -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
+
+ recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4,
+ 0.9, 0.3, -0.1, 0.2, 0.5, 0.2};
+
+ recurrent_to_output_weights_ = {0.3, -0.1, 0.1, -0.2, -0.5, -0.7,
+ -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};
+
+ cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15};
+
+ cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};
+
+ cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};
+
+ input_layer_norm_weights_ = {0.1, 0.2, 0.3, 0.5};
+ forget_layer_norm_weights_ = {0.2, 0.2, 0.4, 0.3};
+ cell_layer_norm_weights_ = {0.7, 0.2, 0.3, 0.8};
+ output_layer_norm_weights_ = {0.6, 0.2, 0.2, 0.5};
+
+ projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5,
+ 0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
+
+ layer_norm_lstm_input_ = {
+ {// Batch0: 3 (input_sequence_size) * 5 (n_input)
+ 0.7, 0.8, 0.1, 0.2, 0.3, // seq 0
+ 0.8, 0.1, 0.2, 0.4, 0.5, // seq 1
+ 0.2, 0.7, 0.7, 0.1, 0.7}, // seq 2
+
+ {// Batch1: 3 (input_sequence_size) * 5 (n_input)
+ 0.3, 0.2, 0.9, 0.8, 0.1, // seq 0
+ 0.1, 0.5, 0.2, 0.4, 0.2, // seq 1
+ 0.6, 0.9, 0.2, 0.5, 0.7}, // seq 2
+ };
+ }
+};
+
+TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
+ LayerNormLstmBlackBoxTest) {
+ const int n_batch = 2;
+ const int n_input = 5;
+ const int n_cell = 4;
+ const int n_output = 3;
+  const float cell_clip = 0.0;
+ const float proj_clip = 0.0;
+
+ LayerNormLSTMOpModel layer_norm_lstm(
+ n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/false, /*use_peephole=*/true,
+ /*use_projection_weights=*/true,
+      /*use_projection_bias=*/false, cell_clip, proj_clip,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {n_cell}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_layer_norm_weight tensor
+ {n_cell}, // forget_layer_norm_weight tensor
+ {n_cell}, // cell_layer_norm_weight tensor
+ {n_cell}, // output_layer_norm_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {n_output, n_cell}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
+ layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
+ layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+ layer_norm_lstm.SetInputGateBias(input_gate_bias_);
+ layer_norm_lstm.SetCellBias(cell_gate_bias_);
+ layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
+ layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
+
+ layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
+ layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+ layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_);
+ layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_);
+ layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_);
+ layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_);
+
+ layer_norm_lstm.SetProjectionWeights(projection_weights_);
+
+ // Verify the final output.
+ const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
+ {
+ // Batch0: 3 (input_sequence_size) * 3 (n_output)
+ 0.0244077, 0.128027, -0.00170918, // seq 0
+ 0.0137642, 0.140751, 0.0395835, // seq 1
+ -0.00459231, 0.155278, 0.0837377, // seq 2
+ },
+ {
+ // Batch1: 3 (input_sequence_size) * 3 (n_output)
+ -0.00692428, 0.0848741, 0.063445, // seq 0
+ -0.00403912, 0.139963, 0.072681, // seq 1
+ 0.00752706, 0.161903, 0.0561371, // seq 2
+ }};
+
+ VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
+ &layer_norm_lstm);
+}
+
+TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
+ HybridLayerNormLstmBlackBoxTest) {
+ const int n_batch = 2;
+ const int n_input = 5;
+ const int n_cell = 4;
+ const int n_output = 3;
+  const float cell_clip = 0.0;
+ const float proj_clip = 0.0;
+
+ HybridLayerNormLSTMOpModel layer_norm_lstm(
+ n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/false, /*use_peephole=*/true,
+ /*use_projection_weights=*/true,
+      /*use_projection_bias=*/false, cell_clip, proj_clip,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {n_cell}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_layer_norm_weight tensor
+ {n_cell}, // forget_layer_norm_weight tensor
+ {n_cell}, // cell_layer_norm_weight tensor
+ {n_cell}, // output_layer_norm_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {n_output, n_cell}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
+ layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
+ layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+ layer_norm_lstm.SetInputGateBias(input_gate_bias_);
+ layer_norm_lstm.SetCellBias(cell_gate_bias_);
+ layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
+ layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
+
+ layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
+ layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+ layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_);
+ layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_);
+ layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_);
+ layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_);
+
+ layer_norm_lstm.SetProjectionWeights(projection_weights_);
+
+ const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
+ {
+ // Batch0: 3 (input_sequence_size) * 3 (n_output)
+ 0.0244576, 0.127847, -0.00181765, // seq 0
+ 0.0137518, 0.140892, 0.0402234, // seq 1
+ -0.0048839, 0.155096, 0.0840309, // seq 2
+ },
+ {
+ // Batch1: 3 (input_sequence_size) * 3 (n_output)
+ -0.00728636, 0.0843957, 0.0634786, // seq 0
+ -0.00448382, 0.139278, 0.0737372, // seq 1
+ 0.00734616, 0.161793, 0.0560238, // seq 2
+ }};
+
+ VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
+ &layer_norm_lstm);
+}
+
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
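The two tests above pin the float and hybrid (quantized-weight) layer-norm LSTM variants to golden outputs, with the hybrid goldens allowed to drift only within quantization tolerance. For orientation, here is a minimal sketch of what layer normalization does to one gate's pre-activations; the function name, signature, and epsilon are local to this sketch, not taken from the kernel.

#include <cmath>
#include <vector>

// Illustrative only: normalize one gate's pre-activations across the cell
// dimension, then rescale by the layer-norm weights and add the gate bias.
void LayerNormGate(std::vector<float>& gate,         // n_cell pre-activations
                   const std::vector<float>& gamma,  // layer-norm weights
                   const std::vector<float>& bias,   // gate bias
                   float epsilon = 1e-8f) {
  const int n = static_cast<int>(gate.size());
  float mean = 0.0f;
  for (float v : gate) mean += v;
  mean /= n;
  float var = 0.0f;
  for (float v : gate) var += (v - mean) * (v - mean);
  var /= n;
  const float inv_std = 1.0f / std::sqrt(var + epsilon);
  for (int i = 0; i < n; ++i) {
    gate[i] = (gate[i] - mean) * inv_std * gamma[i] + bias[i];
  }
}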
diff --git a/tensorflow/contrib/lite/kernels/local_response_norm.cc b/tensorflow/contrib/lite/kernels/local_response_norm.cc
index 36dca299d0..334d2a2788 100644
--- a/tensorflow/contrib/lite/kernels/local_response_norm.cc
+++ b/tensorflow/contrib/lite/kernels/local_response_norm.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -64,11 +64,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
-#define TF_LITE_LOCAL_RESPONSE_NORM(type) \
- type::LocalResponseNormalization( \
- GetTensorData<float>(input), GetTensorDims(input), params->radius, \
- params->bias, params->alpha, params->beta, GetTensorData<float>(output), \
- GetTensorDims(output))
+#define TF_LITE_LOCAL_RESPONSE_NORM(type) \
+ tflite::LocalResponseNormalizationParams op_params; \
+ op_params.range = params->radius; \
+ op_params.bias = params->bias; \
+ op_params.alpha = params->alpha; \
+ op_params.beta = params->beta; \
+ type::LocalResponseNormalization( \
+ op_params, GetTensorShape(input), GetTensorData<float>(input), \
+ GetTensorShape(output), GetTensorData<float>(output))
if (kernel_type == kReference) {
TF_LITE_LOCAL_RESPONSE_NORM(reference_ops);
}
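Beyond the include renames, this hunk moves the kernel from loose scalar arguments to a LocalResponseNormalizationParams struct plus GetTensorShape. As a reminder of the math those parameters feed, here is a hedged one-row reference of local response normalization over the channel axis; the field names mirror the struct filled in above, but the loop is illustrative, not the optimized kernel.

#include <algorithm>
#include <cmath>
#include <vector>

// out[c] = in[c] * (bias + alpha * sum over the radius window of in[k]^2)^-beta
std::vector<float> LrnOneRow(const std::vector<float>& in, int range,
                             float bias, float alpha, float beta) {
  const int channels = static_cast<int>(in.size());
  std::vector<float> out(channels);
  for (int c = 0; c < channels; ++c) {
    const int begin = std::max(0, c - range);
    const int end = std::min(channels - 1, c + range);
    float accum = 0.0f;
    for (int k = begin; k <= end; ++k) accum += in[k] * in[k];
    out[c] = in[c] * std::pow(bias + alpha * accum, -beta);
  }
  return out;
}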
diff --git a/tensorflow/contrib/lite/kernels/logical.cc b/tensorflow/contrib/lite/kernels/logical.cc
index 87c2fee667..f770cb35d1 100644
--- a/tensorflow/contrib/lite/kernels/logical.cc
+++ b/tensorflow/contrib/lite/kernels/logical.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -86,14 +86,14 @@ TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (data->requires_broadcast) {
- reference_ops::BroadcastLogical(
- GetTensorData<bool>(input1), GetTensorDims(input1),
- GetTensorData<bool>(input2), GetTensorDims(input2),
- GetTensorData<bool>(output), GetTensorDims(output), func);
+ reference_ops::BroadcastLogical4DSlow(
+ GetTensorShape(input1), GetTensorData<bool>(input1),
+ GetTensorShape(input2), GetTensorData<bool>(input2),
+ GetTensorShape(output), GetTensorData<bool>(output), func);
} else {
- reference_ops::Logical(GetTensorData<bool>(input1), GetTensorDims(input1),
- GetTensorData<bool>(input2), GetTensorDims(input2),
- GetTensorData<bool>(output), GetTensorDims(output),
+ reference_ops::Logical(GetTensorShape(input1), GetTensorData<bool>(input1),
+ GetTensorShape(input2), GetTensorData<bool>(input2),
+ GetTensorShape(output), GetTensorData<bool>(output),
func);
}
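The rename to BroadcastLogical4DSlow matches the new shape-first signature. The "4D slow" strategy walks the output's full 4D index space and clamps any coordinate whose input dimension has size 1, so the same element is reused along broadcast axes. A self-contained sketch of that idea for logical-or, assuming NHWC-style flattening; names are local to the sketch:

#include <array>
#include <vector>

// Requires every d1[i]/d2[i] to equal out_d[i] or 1 (broadcast-compatible).
std::vector<bool> BroadcastOr4D(const std::vector<bool>& in1,
                                const std::array<int, 4>& d1,
                                const std::vector<bool>& in2,
                                const std::array<int, 4>& d2,
                                const std::array<int, 4>& out_d) {
  auto index = [](const std::array<int, 4>& d, int b, int h, int w, int c) {
    // Clamp coordinates on size-1 dimensions so they broadcast.
    b = d[0] == 1 ? 0 : b;
    h = d[1] == 1 ? 0 : h;
    w = d[2] == 1 ? 0 : w;
    c = d[3] == 1 ? 0 : c;
    return ((b * d[1] + h) * d[2] + w) * d[3] + c;
  };
  std::vector<bool> out(out_d[0] * out_d[1] * out_d[2] * out_d[3]);
  int o = 0;
  for (int b = 0; b < out_d[0]; ++b)
    for (int h = 0; h < out_d[1]; ++h)
      for (int w = 0; w < out_d[2]; ++w)
        for (int c = 0; c < out_d[3]; ++c)
          out[o++] = in1[index(d1, b, h, w, c)] || in2[index(d2, b, h, w, c)];
  return out;
}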
diff --git a/tensorflow/contrib/lite/kernels/lsh_projection.cc b/tensorflow/contrib/lite/kernels/lsh_projection.cc
index 69523b02cc..9fa1c5f100 100644
--- a/tensorflow/contrib/lite/kernels/lsh_projection.cc
+++ b/tensorflow/contrib/lite/kernels/lsh_projection.cc
@@ -59,8 +59,8 @@ limitations under the License.
#include <limits>
#include <memory>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
#include <farmhash.h>
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index ba251c451e..aaa3ce966e 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
@@ -37,7 +37,7 @@ namespace builtin {
namespace lstm {
struct OpData {
- // Which kernel type to use. Full kernel (18 or 20 inputs) or basic kernel
+ // Which kernel type to use. Full kernel (20 inputs) or basic kernel
// (5 inputs).
TfLiteLSTMKernelType kernel_type;
@@ -47,7 +47,7 @@ struct OpData {
int scratch_tensor_index;
};
-// For full inputs kernel (18 or 20 inputs).
+// For the full kernel (20 inputs).
namespace full {
// Input Tensors of size {n_batch, n_input}
@@ -81,19 +81,13 @@ constexpr int kProjectionWeightsTensor = 16; // Optional
// Projection bias tensor of size {n_output}
constexpr int kProjectionBiasTensor = 17; // Optional
-// If the node has 20 inputs, the following 2 tensors are used as state tensors.
-// These are defined as variable tensors, and will be modified by this op.
+// These state tensors are defined as variable tensors, and will be modified by
+// this op.
constexpr int kInputActivationStateTensor = 18;
constexpr int kInputCellStateTensor = 19;
// Output tensors.
-// * If the node has 18 inputs, these 2 tensors are used as state tensors.
-// * If the node has 20 inputs, these 2 tensors are ignored.
-// TODO(ycling): Make the 2 output state tensors optional, and propagate the
-// state to output tensors when the 2 tensors present.
-constexpr int kOutputStateTensor = 0;
-constexpr int kCellStateTensor = 1;
-constexpr int kOutputTensor = 2;
+constexpr int kOutputTensor = 0;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* op_data = new OpData();
@@ -258,30 +252,12 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 3);
-
- // True if the node is using input variable state tensors. It means:
- // * The state tensors are defined as inputs. In this case it would be the
- // 19th and 20th input tensors.
- // * Otherwise, the output tensors are used to store states.
- bool use_input_variable_states;
- if (node->inputs->size == 20) {
- use_input_variable_states = true;
- op_data->activation_state_tensor_index =
- node->inputs->data[kInputActivationStateTensor];
- op_data->cell_state_tensor_index =
- node->inputs->data[kInputCellStateTensor];
- } else if (node->inputs->size == 18) {
- use_input_variable_states = false;
- op_data->activation_state_tensor_index =
- node->outputs->data[kOutputStateTensor];
- op_data->cell_state_tensor_index = node->outputs->data[kCellStateTensor];
- } else {
- context->ReportError(
- context, "The LSTM Full kernel expects 18 or 20 inputs. Got %d inputs",
- node->inputs->size);
- return kTfLiteError;
- }
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 20);
+
+ op_data->activation_state_tensor_index =
+ node->inputs->data[kInputActivationStateTensor];
+ op_data->cell_state_tensor_index = node->inputs->data[kInputCellStateTensor];
// Inferring batch size, number of outputs and number of cells from the
// input tensors.
@@ -316,31 +292,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* cell_state =
&context->tensors[op_data->cell_state_tensor_index];
- if (use_input_variable_states) {
- // Check the shape of input state tensors.
- // These tensor may be 1D or 2D. It's fine as long as the total size is
- // correct.
- TF_LITE_ENSURE_EQ(context, NumElements(activation_state),
- n_batch * n_output);
- TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
- } else {
- // If the state tensors are outputs, this function takes the
- // responsibility to resize the state tensors.
- TfLiteIntArray* activation_state_size = TfLiteIntArrayCreate(2);
- activation_state_size->data[0] = n_batch;
- activation_state_size->data[1] = n_output;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_state,
- activation_state_size));
-
- TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
- cell_size->data[0] = n_batch;
- cell_size->data[1] = n_cell;
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, cell_state, cell_size));
- // Mark state tensors as persistent tensors.
- activation_state->allocation_type = kTfLiteArenaRwPersistent;
- cell_state->allocation_type = kTfLiteArenaRwPersistent;
- }
+ // Check the shape of input state tensors.
+ // These tensors may be 1D or 2D. Either is fine as long as the total size
+ // is correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
// Resize the output tensors.
TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
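With the 18-input form gone, Prepare can assert exactly 20 inputs and a single output, and the state tensors are always caller-provided variables. Their shapes are deliberately validated only by element count, so {n_batch, n_output} and a flattened {n_batch * n_output} are both legal. A hypothetical helper mirroring the NumElements checks above:

#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical: true when dims holds exactly n_batch * n_output elements,
// regardless of rank.
bool StateShapeOk(const std::vector<int>& dims, int n_batch, int n_output) {
  const int64_t elems =
      std::accumulate(dims.begin(), dims.end(), int64_t{1},
                      [](int64_t acc, int d) { return acc * d; });
  return elems == static_cast<int64_t>(n_batch) * n_output;
}
// StateShapeOk({4, 16}, 4, 16) and StateShapeOk({64}, 4, 16) both hold.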
diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc
index 0266f5fe57..e7ddfceb45 100644
--- a/tensorflow/contrib/lite/kernels/lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/lstm_test.cc
@@ -106,14 +106,13 @@ class LSTMOpModel : public SingleOpModel {
input_cell_state_ =
AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
- output_state_ = AddOutput(TensorType_FLOAT32);
- cell_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
cell_clip, proj_clip)
.Union());
+
BuildInterpreter(input_shapes);
}
@@ -185,22 +184,6 @@ class LSTMOpModel : public SingleOpModel {
PopulateTensor(projection_bias_, f);
}
- void ResetOutputState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
- void ResetCellState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
void SetInput(int offset, const float* begin, const float* end) {
PopulateTensor(input_, offset, const_cast<float*>(begin),
const_cast<float*>(end));
@@ -469,10 +452,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -529,10 +508,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
/*tolerance=*/0.0157651);
}
@@ -637,10 +612,6 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
lstm.SetCellToOutputWeights(cell_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -698,14 +669,10 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
lstm.SetCellToOutputWeights(cell_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
}
-class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest {
+class NoCifgPeepholeProjectionNoClippingLstmTest : public BaseLstmTest {
void SetUp() override {
input_to_input_weights_ = {
0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
@@ -1304,7 +1271,7 @@ class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest {
}
};
-TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
+TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, LstmBlackBoxTest) {
const int n_batch = 2;
const int n_input = 5;
const int n_cell = 20;
@@ -1362,14 +1329,10 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
lstm.SetProjectionWeights(projection_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
-TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
+TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
const int n_batch = 2;
const int n_input = 5;
const int n_cell = 20;
@@ -1428,10 +1391,6 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetProjectionWeights(projection_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
}
diff --git a/tensorflow/contrib/lite/kernels/maximum_minimum.cc b/tensorflow/contrib/lite/kernels/maximum_minimum.cc
index 8d676218bd..7cb01465ee 100644
--- a/tensorflow/contrib/lite/kernels/maximum_minimum.cc
+++ b/tensorflow/contrib/lite/kernels/maximum_minimum.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -86,13 +86,14 @@ struct MinimumOp {
template <typename data_type, typename op_type>
void TFLiteOperation(TfLiteContext* context, TfLiteNode* node,
const OpContext& op_context) {
- reference_ops::TensorFlowMaximumMinimum<data_type>(
+ reference_ops::MaximumMinimumBroadcast4DSlow(
+ GetTensorShape(op_context.input1),
GetTensorData<data_type>(op_context.input1),
- GetTensorDims(op_context.input1),
+ GetTensorShape(op_context.input2),
GetTensorData<data_type>(op_context.input2),
- GetTensorDims(op_context.input2),
+ GetTensorShape(op_context.output),
GetTensorData<data_type>(op_context.output),
- GetTensorDims(op_context.output), op_type::template op<data_type>);
+ op_type::template op<data_type>);
}
template <KernelType kernel_type, typename OpType>
diff --git a/tensorflow/contrib/lite/kernels/mfcc.cc b/tensorflow/contrib/lite/kernels/mfcc.cc
index 3f5bc4d68a..66cf147d75 100644
--- a/tensorflow/contrib/lite/kernels/mfcc.cc
+++ b/tensorflow/contrib/lite/kernels/mfcc.cc
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/internal/mfcc.h"
-#include "flatbuffers/flexbuffers.h"
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/mfcc_dct.h"
#include "tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/mfcc_test.cc b/tensorflow/contrib/lite/kernels/mfcc_test.cc
index 0291ca8c1c..c9124adcaf 100644
--- a/tensorflow/contrib/lite/kernels/mfcc_test.cc
+++ b/tensorflow/contrib/lite/kernels/mfcc_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h"
+#include "flatbuffers/flexbuffers.h" // flatbuffers
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc
index 561e39cfc6..e0aac8a842 100644
--- a/tensorflow/contrib/lite/kernels/mul.cc
+++ b/tensorflow/contrib/lite/kernels/mul.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
@@ -102,24 +102,28 @@ template <KernelType kernel_type>
void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
const OpData* data, const TfLiteTensor* input1,
const TfLiteTensor* input2, TfLiteTensor* output) {
-#define TF_LITE_MUL(type, opname, data_type) \
- data_type output_activation_min, output_activation_max; \
- CalculateActivationRange(params->activation, &output_activation_min, \
- &output_activation_max); \
- type::opname(GetTensorData<data_type>(input1), GetTensorDims(input1), \
- GetTensorData<data_type>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<data_type>(output), GetTensorDims(output))
+#define TF_LITE_MUL(type, opname, data_type) \
+ data_type output_activation_min, output_activation_max; \
+ CalculateActivationRange(params->activation, &output_activation_min, \
+ &output_activation_max); \
+ tflite::ArithmeticParams op_params; \
+ SetActivationParams(output_activation_min, output_activation_max, \
+ &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<data_type>(input1), GetTensorShape(input2), \
+ GetTensorData<data_type>(input2), GetTensorShape(output), \
+ GetTensorData<data_type>(output))
+
if (output->type == kTfLiteInt32) {
if (kernel_type == kReference) {
if (data->requires_broadcast) {
- TF_LITE_MUL(reference_ops, BroadcastMul, int32_t);
+ TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, int32_t);
} else {
TF_LITE_MUL(reference_ops, Mul, int32_t);
}
} else {
if (data->requires_broadcast) {
- TF_LITE_MUL(optimized_ops, BroadcastMul, int32_t);
+ TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow, int32_t);
} else {
TF_LITE_MUL(optimized_ops, Mul, int32_t);
}
@@ -127,13 +131,13 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
} else if (output->type == kTfLiteFloat32) {
if (kernel_type == kReference) {
if (data->requires_broadcast) {
- TF_LITE_MUL(reference_ops, BroadcastMul, float);
+ TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, float);
} else {
TF_LITE_MUL(reference_ops, Mul, float);
}
} else {
if (data->requires_broadcast) {
- TF_LITE_MUL(optimized_ops, BroadcastMul, float);
+ TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow, float);
} else {
TF_LITE_MUL(optimized_ops, Mul, float);
}
@@ -149,14 +153,20 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* input2, TfLiteTensor* output) {
if (input1->type == kTfLiteUInt8 && input2->type == kTfLiteUInt8 &&
output->type == kTfLiteUInt8) {
-#define TF_LITE_MUL(type, opname) \
- type::opname(GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
- -input1->params.zero_point, GetTensorData<uint8_t>(input2), \
- GetTensorDims(input2), -input2->params.zero_point, \
- output->params.zero_point, data->output_multiplier, \
- data->output_shift, data->output_activation_min, \
- data->output_activation_max, GetTensorData<uint8_t>(output), \
- GetTensorDims(output));
+#define TF_LITE_MUL(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ SetActivationParams(data->output_activation_min, \
+ data->output_activation_max, &op_params); \
+ op_params.input1_offset = -input1->params.zero_point; \
+ op_params.input2_offset = -input2->params.zero_point; \
+ op_params.output_offset = output->params.zero_point; \
+ op_params.output_multiplier = data->output_multiplier; \
+ op_params.output_shift = data->output_shift; \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<uint8_t>(input1), GetTensorShape(input2), \
+ GetTensorData<uint8_t>(input2), GetTensorShape(output), \
+ GetTensorData<uint8_t>(output))
+
// The quantized version of Mul doesn't support activations, so we
// always use BroadcastMul.
if (kernel_type == kReference) {
@@ -167,10 +177,12 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
#undef TF_LITE_MUL
} else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 &&
output->type == kTfLiteInt16) {
-#define TF_LITE_MUL(type, opname) \
- type::opname(GetTensorData<int16_t>(input1), GetTensorDims(input1), \
- GetTensorData<int16_t>(input2), GetTensorDims(input2), \
- GetTensorData<int16_t>(output), GetTensorDims(output));
+#define TF_LITE_MUL(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<int16_t>(input1), GetTensorShape(input2), \
+ GetTensorData<int16_t>(input2), GetTensorShape(output), \
+ GetTensorData<int16_t>(output))
if (kernel_type == kReference) {
TF_LITE_MUL(reference_ops, Mul);
} else {
@@ -179,12 +191,15 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
#undef TF_LITE_MUL
} else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 &&
output->type == kTfLiteUInt8) {
-#define TF_LITE_MUL(type, opname) \
- type::opname(GetTensorData<int16_t>(input1), GetTensorDims(input1), \
- GetTensorData<int16_t>(input2), GetTensorDims(input2), \
- output->params.zero_point, data->output_activation_min, \
- data->output_activation_max, GetTensorData<uint8_t>(output), \
- GetTensorDims(output));
+#define TF_LITE_MUL(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ SetActivationParams(data->output_activation_min, \
+ data->output_activation_max, &op_params); \
+ op_params.output_offset = output->params.zero_point; \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<int16_t>(input1), GetTensorShape(input2), \
+ GetTensorData<int16_t>(input2), GetTensorShape(output), \
+ GetTensorData<uint8_t>(output))
if (kernel_type == kReference) {
TF_LITE_MUL(reference_ops, Mul);
} else {
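All three quantized paths now route their scaling data through tflite::ArithmeticParams. Per element, the uint8 path re-centers both inputs by their (negated) zero points, multiplies in int32, rescales by (output_multiplier, output_shift), adds the output offset, and clamps to the activation range. A sketch under those assumptions, with a simplified floating-point stand-in for TFLite's fixed-point rescale helper:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Simplified stand-in for the fixed-point rescale (illustrative only).
int32_t RescaleApprox(int32_t x, int32_t multiplier, int shift) {
  const double scale =
      static_cast<double>(multiplier) / (1ll << 31) * std::pow(2.0, shift);
  return static_cast<int32_t>(std::lround(x * scale));
}

uint8_t QuantizedMulOne(uint8_t a, uint8_t b, int32_t input1_offset,
                        int32_t input2_offset, int32_t output_offset,
                        int32_t output_multiplier, int output_shift,
                        int32_t act_min, int32_t act_max) {
  // input*_offset are the negated zero points, as set on op_params above.
  const int32_t raw = (a + input1_offset) * (b + input2_offset);
  int32_t rescaled = RescaleApprox(raw, output_multiplier, output_shift);
  rescaled += output_offset;
  rescaled = std::min(act_max, std::max(act_min, rescaled));
  return static_cast<uint8_t>(rescaled);
}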
diff --git a/tensorflow/contrib/lite/kernels/neg.cc b/tensorflow/contrib/lite/kernels/neg.cc
index 4124c05388..0ddd0644f5 100644
--- a/tensorflow/contrib/lite/kernels/neg.cc
+++ b/tensorflow/contrib/lite/kernels/neg.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/one_hot.cc b/tensorflow/contrib/lite/kernels/one_hot.cc
index 9ff3dca932..910aed6f14 100644
--- a/tensorflow/contrib/lite/kernels/one_hot.cc
+++ b/tensorflow/contrib/lite/kernels/one_hot.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/op_macros.h b/tensorflow/contrib/lite/kernels/op_macros.h
index 7568eaa88e..d66364c4d8 100644
--- a/tensorflow/contrib/lite/kernels/op_macros.h
+++ b/tensorflow/contrib/lite/kernels/op_macros.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_
-#define TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_OP_MACROS_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_OP_MACROS_H_
#include <cstdio>
@@ -31,4 +31,4 @@ limitations under the License.
if ((x) != (y)) TF_LITE_FATAL(#x " didn't equal " #y); \
} while (0)
-#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_OP_MACROS_H_
diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
index 1c728a4733..90a915bb02 100644
--- a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
+++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
@@ -101,8 +101,6 @@ class LSTMOpModel : public SingleOpModel {
input_cell_state_ =
AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
- output_state_ = AddOutput(TensorType_FLOAT32);
- cell_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
@@ -180,22 +178,6 @@ class LSTMOpModel : public SingleOpModel {
PopulateTensor(projection_bias_, f);
}
- void ResetOutputState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
- void ResetCellState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
void SetInput(int offset, float* begin, float* end) {
PopulateTensor(input_, offset, begin, end);
}
@@ -238,8 +220,6 @@ class LSTMOpModel : public SingleOpModel {
int input_cell_state_;
int output_;
- int output_state_;
- int cell_state_;
int n_batch_;
int n_input_;
@@ -324,10 +304,6 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
lstm.SetCellToOutputWeights(
{-0.17135078, 0.82760304, 0.85573703, -0.77109635});
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
// Verify the model by unpacking it.
lstm.Verify();
}
diff --git a/tensorflow/contrib/lite/kernels/pack.cc b/tensorflow/contrib/lite/kernels/pack.cc
index bb3416f6a6..4cb98fdd19 100644
--- a/tensorflow/contrib/lite/kernels/pack.cc
+++ b/tensorflow/contrib/lite/kernels/pack.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -27,24 +27,9 @@ namespace {
constexpr int kOutputTensor = 0;
-// Op data for pack op.
-struct OpData {
- int values_count;
- int axis;
-};
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- auto* data = new OpData;
- data->axis = 0;
- return data;
-}
-
-void Free(TfLiteContext* context, void* buffer) {
- delete reinterpret_cast<OpData*>(buffer);
-}
-
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- const OpData* data = reinterpret_cast<OpData*>(node->builtin_data);
+ const TfLitePackParams* data =
+ reinterpret_cast<TfLitePackParams*>(node->builtin_data);
TF_LITE_ENSURE_EQ(context, NumInputs(node), data->values_count);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -54,9 +39,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, NumDimensions(input0) >= data->axis);
// TODO(renjieliu): Support negative axis.
TF_LITE_ENSURE(context, data->axis >= 0);
- if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32) {
+ if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32 &&
+ input0->type != kTfLiteUInt8 && input0->type != kTfLiteInt16) {
context->ReportError(context,
- "Currently pack only supports int32 and float32.");
+ "Currently pack only supports "
+ "float32/uint8/int16/int32.");
return kTfLiteError;
}
// Make sure all inputs have the same shape and type.
@@ -82,6 +69,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, output->type, input0->type);
+ // Guarantee that the input/output quantization params match, since we do
+ // not support requantization while packing quantized tensors.
+ for (int i = 0; i < data->values_count; i++) {
+ const TfLiteTensor* input = GetInput(context, node, i);
+ TF_LITE_ENSURE_EQ(context, input->params.zero_point,
+ output->params.zero_point);
+ TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale);
+ }
+
return context->ResizeTensor(context, output, output_shape);
}
@@ -95,7 +91,8 @@ void PackImpl(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* output,
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- const OpData* data = reinterpret_cast<OpData*>(node->builtin_data);
+ const TfLitePackParams* data =
+ reinterpret_cast<TfLitePackParams*>(node->builtin_data);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (output->type) {
@@ -103,13 +100,22 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
PackImpl<float>(context, node, output, data->values_count, data->axis);
break;
}
+ case kTfLiteUInt8: {
+ PackImpl<uint8_t>(context, node, output, data->values_count, data->axis);
+ break;
+ }
+ case kTfLiteInt16: {
+ PackImpl<int16_t>(context, node, output, data->values_count, data->axis);
+ break;
+ }
case kTfLiteInt32: {
PackImpl<int32_t>(context, node, output, data->values_count, data->axis);
break;
}
default: {
context->ReportError(context,
- "Currently pack only supports int32 and float32.");
+ "Currently pack only supports "
+ "float32/uint8/int32.");
return kTfLiteError;
}
}
@@ -121,8 +123,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace pack
TfLiteRegistration* Register_PACK() {
- static TfLiteRegistration r = {pack::Init, pack::Free, pack::Prepare,
- pack::Eval};
+ static TfLiteRegistration r = {nullptr, nullptr, pack::Prepare, pack::Eval};
return &r;
}
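Since the interpreter now hands the kernel a fully populated TfLitePackParams through builtin_data, the custom OpData plus Init/Free pair is dead weight, and the registration passes nullptr for both. The copy itself (PackImpl, not shown in this hunk) amounts to interleaving each input's contiguous slices along the chosen axis. A type-agnostic sketch with hypothetical outer/inner factors (the products of the dimensions before and after the axis):

#include <vector>

template <typename T>
std::vector<T> PackAlongAxis(const std::vector<std::vector<T>>& inputs,
                             int outer, int inner) {
  // Each input is viewed as outer x inner; output is outer x N x inner.
  std::vector<T> out;
  out.reserve(inputs.size() * outer * inner);
  for (int o = 0; o < outer; ++o)
    for (const auto& in : inputs)
      out.insert(out.end(), in.begin() + o * inner,
                 in.begin() + (o + 1) * inner);
  return out;
}
// PackAlongAxis({{1,2,3},{4,5,6}}, /*outer=*/1, /*inner=*/3) -> {1,2,3,4,5,6}
// PackAlongAxis({{1,4},{2,5},{3,6}}, /*outer=*/2, /*inner=*/1) -> {1,2,3,4,5,6}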
diff --git a/tensorflow/contrib/lite/kernels/pack_test.cc b/tensorflow/contrib/lite/kernels/pack_test.cc
index 485a50ad3a..c70dbd2764 100644
--- a/tensorflow/contrib/lite/kernels/pack_test.cc
+++ b/tensorflow/contrib/lite/kernels/pack_test.cc
@@ -51,6 +51,7 @@ class PackOpModel : public SingleOpModel {
int output_;
};
+// float32 tests.
TEST(PackOpTest, FloatThreeInputs) {
PackOpModel<float> model({TensorType_FLOAT32, {2}}, 0, 3);
model.SetInput(0, {1, 4});
@@ -81,7 +82,8 @@ TEST(PackOpTest, FloatMultilDimensions) {
ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
}
-TEST(PackOpTest, IntThreeInputs) {
+// int32 tests.
+TEST(PackOpTest, Int32ThreeInputs) {
PackOpModel<int32_t> model({TensorType_INT32, {2}}, 0, 3);
model.SetInput(0, {1, 4});
model.SetInput(1, {2, 5});
@@ -91,7 +93,7 @@ TEST(PackOpTest, IntThreeInputs) {
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6}));
}
-TEST(PackOpTest, IntThreeInputsDifferentAxis) {
+TEST(PackOpTest, Int32ThreeInputsDifferentAxis) {
PackOpModel<int32_t> model({TensorType_INT32, {2}}, 1, 3);
model.SetInput(0, {1, 4});
model.SetInput(1, {2, 5});
@@ -101,7 +103,7 @@ TEST(PackOpTest, IntThreeInputsDifferentAxis) {
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
}
-TEST(PackOpTest, IntMultilDimensions) {
+TEST(PackOpTest, Int32MultilDimensions) {
PackOpModel<int32_t> model({TensorType_INT32, {2, 3}}, 1, 2);
model.SetInput(0, {1, 2, 3, 4, 5, 6});
model.SetInput(1, {7, 8, 9, 10, 11, 12});
@@ -110,6 +112,38 @@ TEST(PackOpTest, IntMultilDimensions) {
EXPECT_THAT(model.GetOutput(),
ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
}
+
+// uint8 tests.
+TEST(PackOpTest, Uint8ThreeInputs) {
+ PackOpModel<uint8_t> model({TensorType_UINT8, {2}}, 0, 3);
+ model.SetInput(0, {1, 4});
+ model.SetInput(1, {2, 5});
+ model.SetInput(2, {3, 6});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 2));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6}));
+}
+
+TEST(PackOpTest, Uint8ThreeInputsDifferentAxis) {
+ PackOpModel<uint8_t> model({TensorType_UINT8, {2}}, 1, 3);
+ model.SetInput(0, {1, 4});
+ model.SetInput(1, {2, 5});
+ model.SetInput(2, {3, 6});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(PackOpTest, Uint8MultilDimensions) {
+ PackOpModel<uint8_t> model({TensorType_UINT8, {2, 3}}, 1, 2);
+ model.SetInput(0, {1, 2, 3, 4, 5, 6});
+ model.SetInput(1, {7, 8, 9, 10, 11, 12});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2, 3));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc
index 4be8c243c1..0d939405f6 100644
--- a/tensorflow/contrib/lite/kernels/pad.cc
+++ b/tensorflow/contrib/lite/kernels/pad.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -92,8 +92,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
op_context.constant_values->type);
}
- // TODO(nupurgarg): Our current implementations rely on the inputs being 4D.
- TF_LITE_ENSURE_EQ(context, op_context.dims, 4);
+ // TODO(nupurgarg): Current implementations rely on the inputs being <= 4D.
+ TF_LITE_ENSURE(context, op_context.dims <= 4);
// Exit early if paddings is a non-const tensor. Set output tensor to
// dynamic so output size can be determined in Eval.
@@ -134,12 +134,22 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
after_padding.push_back(paddings_data[idx * 2 + 1]);
}
-#define TF_LITE_PAD(type, scalar, pad_value) \
- type::PadV2(GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), before_padding, after_padding, \
- GetTensorData<scalar>(op_context.output), \
- GetTensorDims(op_context.output), pad_value)
-
+#define TF_LITE_PAD(type, scalar, pad_value) \
+ TF_LITE_ENSURE(context, before_padding.size() <= 4); \
+ TF_LITE_ENSURE(context, after_padding.size() <= 4); \
+ tflite::PadParams op_params; \
+ op_params.left_padding_count = before_padding.size(); \
+ op_params.right_padding_count = after_padding.size(); \
+ for (int i = 0; i < op_context.dims; ++i) { \
+ op_params.left_padding[i] = before_padding[op_context.dims - 1 - i]; \
+ op_params.right_padding[i] = after_padding[op_context.dims - 1 - i]; \
+ } \
+ const scalar pad_value_copy = pad_value; \
+ \
+ type::Pad(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), &pad_value_copy, \
+ GetTensorShape(op_context.output), \
+ GetTensorData<scalar>(op_context.output))
switch (op_context.input->type) {
case kTfLiteFloat32: {
float pad_value = op_context.constant_values == nullptr
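Two details in the new macro are easy to miss: pad_value is copied into a local so its address can be passed to type::Pad, and the per-dimension paddings are written into PadParams in reverse, on the assumption that the params struct stores them innermost-dimension-first while the paddings tensor lists them outermost-first. A worked example of the reversal:

#include <vector>

int main() {
  const int dims = 3;
  const std::vector<int> before_padding = {1, 0, 2};  // outermost dim first
  std::vector<int> left_padding(dims);
  for (int i = 0; i < dims; ++i) {
    left_padding[i] = before_padding[dims - 1 - i];
  }
  // left_padding == {2, 0, 1}: the innermost dimension's padding now leads.
  return 0;
}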
diff --git a/tensorflow/contrib/lite/kernels/pad_test.cc b/tensorflow/contrib/lite/kernels/pad_test.cc
index f8b9064fbb..f663899713 100644
--- a/tensorflow/contrib/lite/kernels/pad_test.cc
+++ b/tensorflow/contrib/lite/kernels/pad_test.cc
@@ -193,7 +193,7 @@ TEST(PadOpTest, TooManyDimensions) {
PadOpConstModel({TensorType_FLOAT32, {1, 2, 3, 4, 5, 6, 7, 8, 9}}, {9, 2},
{1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9},
{TensorType_FLOAT32}),
- "dims != 4");
+ "dims <= 4");
}
TEST(PadOpTest, UnequalDimensions) {
@@ -221,6 +221,15 @@ TEST(PadOpTest, SimpleConstTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
}
+TEST(PadOpTest, SimpleConst1DTest) {
+ PadOpConstModel m({TensorType_FLOAT32, {2}}, {1, 2}, {1, 2},
+ {TensorType_FLOAT32});
+ m.SetInput({2, 3});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 3, 0, 0}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({5}));
+}
+
TEST(PadOpTest, SimpleDynamicTest) {
PadOpDynamicModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2},
{TensorType_FLOAT32});
@@ -334,7 +343,7 @@ TEST(PadV2OpTest, TooManyDimensions) {
{TensorType_FLOAT32, {1, 2, 3, 4, 5, 6, 7, 8, 9}}, {9, 2},
{1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}, 0.0,
{TensorType_FLOAT32}),
- "dims != 4");
+ "dims <= 4");
}
TEST(PadV2OpTest, UnequalDimensions) {
diff --git a/tensorflow/contrib/lite/kernels/padding.h b/tensorflow/contrib/lite/kernels/padding.h
index 3cb55f19a9..42b6b45d3b 100644
--- a/tensorflow/contrib/lite/kernels/padding.h
+++ b/tensorflow/contrib/lite/kernels/padding.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc
index 29a5be0683..6451142391 100644
--- a/tensorflow/contrib/lite/kernels/pooling.cc
+++ b/tensorflow/contrib/lite/kernels/pooling.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/pow.cc b/tensorflow/contrib/lite/kernels/pow.cc
index 4a539c47a8..1e96cc80b1 100644
--- a/tensorflow/contrib/lite/kernels/pow.cc
+++ b/tensorflow/contrib/lite/kernels/pow.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -80,14 +80,14 @@ template <typename T>
void PowImpl(const TfLiteTensor* input1, const TfLiteTensor* input2,
TfLiteTensor* output, bool requires_broadcast) {
if (requires_broadcast) {
- reference_ops::BroadcastPow(GetTensorData<T>(input1), GetTensorDims(input1),
- GetTensorData<T>(input2), GetTensorDims(input2),
- GetTensorData<T>(output),
- GetTensorDims(output));
+ reference_ops::BroadcastPow4DSlow(
+ GetTensorShape(input1), GetTensorData<T>(input1),
+ GetTensorShape(input2), GetTensorData<T>(input2),
+ GetTensorShape(output), GetTensorData<T>(output));
} else {
- reference_ops::Pow(GetTensorData<T>(input1), GetTensorDims(input1),
- GetTensorData<T>(input2), GetTensorDims(input2),
- GetTensorData<T>(output), GetTensorDims(output));
+ reference_ops::Pow(GetTensorShape(input1), GetTensorData<T>(input1),
+ GetTensorShape(input2), GetTensorData<T>(input2),
+ GetTensorShape(output), GetTensorData<T>(output));
}
}
diff --git a/tensorflow/contrib/lite/kernels/reduce.cc b/tensorflow/contrib/lite/kernels/reduce.cc
index e99f67c725..d94d821e87 100644
--- a/tensorflow/contrib/lite/kernels/reduce.cc
+++ b/tensorflow/contrib/lite/kernels/reduce.cc
@@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
+#include <limits>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -177,6 +178,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
case kTfLiteUInt8:
temp_sum->type = kTfLiteInt32;
break;
+ case kTfLiteBool:
+ temp_sum->type = kTfLiteBool;
+ break;
default:
return kTfLiteError;
}
@@ -204,6 +208,13 @@ TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+TfLiteStatus PrepareAny(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteBool);
+ return PrepareSimple(context, node);
+}
+
TfLiteStatus PrepareMean(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
@@ -256,11 +267,27 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, int64_t, int64_t));
break;
case kTfLiteUInt8:
- TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
- op_context.output->params.scale);
- TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
- op_context.output->params.zero_point);
- TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, uint8_t, int));
+ if (op_context.input->params.zero_point ==
+ op_context.output->params.zero_point &&
+ op_context.input->params.scale == op_context.output->params.scale) {
+ TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, uint8_t, int));
+ } else {
+ TF_LITE_ENSURE(
+ context,
+ reference_ops::Mean<>(
+ GetTensorData<uint8_t>(op_context.input),
+ op_context.input->params.zero_point,
+ op_context.input->params.scale, op_context.input->dims->data,
+ op_context.input->dims->size,
+ GetTensorData<uint8_t>(op_context.output),
+ op_context.output->params.zero_point,
+ op_context.output->params.scale,
+ op_context.output->dims->data, op_context.output->dims->size,
+ GetTensorData<int>(op_context.axis), num_axis,
+ op_context.params->keep_dims, GetTensorData<int>(temp_index),
+ GetTensorData<int>(resolved_axis),
+ GetTensorData<int>(temp_sum)));
+ }
break;
default:
return kTfLiteError;
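Previously the uint8 mean demanded identical input/output quantization; the new branch falls back to a Mean overload that requantizes. In real-value terms, real = scale * (q - zero_point), so the mean is accumulated in real space and mapped back with the output's parameters. A hedged scalar illustration (RequantizedMean is a name local to this sketch, not the kernel's API):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

uint8_t RequantizedMean(const std::vector<uint8_t>& q, float in_scale,
                        int32_t in_zp, float out_scale, int32_t out_zp) {
  double acc = 0.0;
  for (uint8_t v : q) acc += in_scale * (static_cast<int32_t>(v) - in_zp);
  const double mean_real = acc / static_cast<double>(q.size());
  const long qo = std::lround(mean_real / out_scale) + out_zp;
  return static_cast<uint8_t>(std::min(255l, std::max(0l, qo)));
}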
@@ -270,146 +297,123 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
-template <KernelType kernel_type>
-TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
- OpContext op_context(context, node);
- int num_axis = static_cast<int>(NumElements(op_context.axis));
+// The shared reduction logic for Sum/Prod/Max/Min/Any.
+template <typename T>
+TfLiteStatus EvalLogic(TfLiteContext* context, TfLiteNode* node,
+ OpContext* op_context, T init_value,
+ T reducer(const T current, const T in)) {
+ int64_t num_axis = NumElements(op_context->axis);
TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
// Resize the output tensor if the output tensor is dynamic.
- if (IsDynamicTensor(op_context.output)) {
+ if (IsDynamicTensor(op_context->output)) {
TF_LITE_ENSURE_OK(context,
- ResizeTempAxis(context, &op_context, resolved_axis));
- TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ ResizeTempAxis(context, op_context, resolved_axis));
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, op_context));
}
-
-#define TF_LITE_SUM(kernel_type, data_type) \
- kernel_type::Sum<>( \
- GetTensorData<data_type>(op_context.input), \
- op_context.input->dims->data, op_context.input->dims->size, \
- GetTensorData<data_type>(op_context.output), \
- op_context.output->dims->data, op_context.output->dims->size, \
- GetTensorData<int>(op_context.axis), num_axis, \
- op_context.params->keep_dims, GetTensorData<int>(temp_index), \
- GetTensorData<int>(resolved_axis))
-
- if (kernel_type == kReference) {
- switch (op_context.input->type) {
- case kTfLiteFloat32:
- TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, float));
- break;
- case kTfLiteInt32:
- TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, int));
- break;
- case kTfLiteInt64:
- TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, int64_t));
- break;
- case kTfLiteUInt8:
- TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
- op_context.output->params.scale);
- TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
- op_context.output->params.zero_point);
- TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, uint8_t));
- break;
- default:
- return kTfLiteError;
- }
+ if (op_context->input->type == kTfLiteUInt8) {
+ TF_LITE_ENSURE_EQ(context, op_context->input->params.scale,
+ op_context->output->params.scale);
+ TF_LITE_ENSURE_EQ(context, op_context->input->params.zero_point,
+ op_context->output->params.zero_point);
}
-#undef TF_LITE_SUM
+ TF_LITE_ENSURE(
+ context,
+ reference_ops::ReduceGeneric<T>(
+ GetTensorData<T>(op_context->input), op_context->input->dims->data,
+ op_context->input->dims->size, GetTensorData<T>(op_context->output),
+ op_context->output->dims->data, op_context->output->dims->size,
+ GetTensorData<int>(op_context->axis), num_axis,
+ op_context->params->keep_dims, GetTensorData<int>(temp_index),
+ GetTensorData<int>(resolved_axis), init_value, reducer));
return kTfLiteOk;
}
-template <KernelType kernel_type>
-TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) {
- OpContext op_context(context, node);
- int64_t num_axis = NumElements(op_context.axis);
- TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
- TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
- // Resize the output tensor if the output tensor is dynamic.
- if (IsDynamicTensor(op_context.output)) {
- TF_LITE_ENSURE_OK(context,
- ResizeTempAxis(context, &op_context, resolved_axis));
- TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
- }
-
-#define TF_LITE_PROD(kernel_type, data_type) \
- kernel_type::ReduceProd<>( \
- GetTensorData<data_type>(op_context.input), \
- op_context.input->dims->data, op_context.input->dims->size, \
- GetTensorData<data_type>(op_context.output), \
- op_context.output->dims->data, op_context.output->dims->size, \
- GetTensorData<int>(op_context.axis), num_axis, \
- op_context.params->keep_dims, GetTensorData<int>(temp_index), \
- GetTensorData<int>(resolved_axis))
+enum ReduceType {
+ kSum,
+ kProd,
+ kMax,
+ kMin,
+ kAny,
+};
- if (kernel_type == kReference) {
- switch (op_context.input->type) {
- case kTfLiteFloat32:
- TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, float));
- break;
- case kTfLiteInt32:
- TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, int));
- break;
- case kTfLiteInt64:
- TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, int64_t));
- break;
- case kTfLiteUInt8:
- // TODO(wangtz): uint8 reduce_prod is not yet supported.
- default:
- return kTfLiteError;
- }
+// Eval for a specific input type and reduce type.
+template <typename T>
+TfLiteStatus EvalType(TfLiteContext* context, TfLiteNode* node,
+ OpContext* op_context, ReduceType reduce_type) {
+ switch (reduce_type) {
+ case kSum:
+ return EvalLogic<T>(
+ context, node, op_context, static_cast<T>(0),
+ [](const T current, const T in) -> T { return in + current; });
+ break;
+ case kProd:
+ return EvalLogic<T>(
+ context, node, op_context, static_cast<T>(1),
+ [](const T current, const T in) -> T { return in * current; });
+ break;
+ case kMax:
+ return EvalLogic<T>(context, node, op_context,
+ std::numeric_limits<T>::lowest(),
+ [](const T current, const T in) -> T {
+ return (in > current) ? in : current;
+ });
+ break;
+ case kMin:
+ return EvalLogic<T>(context, node, op_context,
+ std::numeric_limits<T>::max(),
+ [](const T current, const T in) -> T {
+ return (in < current) ? in : current;
+ });
+ break;
+ default:
+ return kTfLiteError;
}
-#undef TF_LITE_PROD
- return kTfLiteOk;
}
-template <KernelType kernel_type>
-TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) {
- OpContext op_context(context, node);
- int64_t num_axis = NumElements(op_context.axis);
- TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
- TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
- // Resize the output tensor if the output tensor is dynamic.
- if (IsDynamicTensor(op_context.output)) {
- TF_LITE_ENSURE_OK(context,
- ResizeTempAxis(context, &op_context, resolved_axis));
- TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+// Template specialization for the bool type.
+template <>
+TfLiteStatus EvalType<bool>(TfLiteContext* context, TfLiteNode* node,
+ OpContext* op_context, ReduceType reduce_type) {
+ switch (reduce_type) {
+ case kAny:
+ return EvalLogic<bool>(context, node, op_context, false,
+ [](const bool current, const bool in) -> bool {
+ return in || current;
+ });
+ break;
+ default:
+ return kTfLiteError;
}
+}
-#define TF_LITE_MAX(kernel_type, data_type) \
- kernel_type::ReduceMax<>( \
- GetTensorData<data_type>(op_context.input), \
- op_context.input->dims->data, op_context.input->dims->size, \
- GetTensorData<data_type>(op_context.output), \
- op_context.output->dims->data, op_context.output->dims->size, \
- GetTensorData<int>(op_context.axis), num_axis, \
- op_context.params->keep_dims, GetTensorData<int>(temp_index), \
- GetTensorData<int>(resolved_axis))
-
- if (kernel_type == kReference) {
- switch (op_context.input->type) {
- case kTfLiteFloat32:
- TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, float));
- break;
- case kTfLiteInt32:
- TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, int));
- break;
- case kTfLiteInt64:
- TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, int64_t));
- break;
- case kTfLiteUInt8:
- TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
- op_context.output->params.scale);
- TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
- op_context.output->params.zero_point);
- TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, uint8_t));
- break;
- default:
- return kTfLiteError;
- }
+// The entry point that handles input types and then calls template functions to
+// handle ReduceType.
+template <KernelType kernel_type, ReduceType reduce_type>
+TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node) {
+ if (kernel_type != kReference) {
+ return kTfLiteOk;
+ }
+ OpContext op_context(context, node);
+ switch (op_context.input->type) {
+ case kTfLiteFloat32:
+ return EvalType<float>(context, node, &op_context, reduce_type);
+ break;
+ case kTfLiteInt32:
+ return EvalType<int>(context, node, &op_context, reduce_type);
+ break;
+ case kTfLiteInt64:
+ return EvalType<int64_t>(context, node, &op_context, reduce_type);
+ break;
+ case kTfLiteUInt8:
+ return EvalType<uint8_t>(context, node, &op_context, reduce_type);
+ break;
+ case kTfLiteBool:
+ return EvalType<bool>(context, node, &op_context, reduce_type);
+ break;
+ default:
+ return kTfLiteError;
}
-#undef TF_LITE_MAX
- return kTfLiteOk;
}
} // namespace reduce
@@ -422,23 +426,37 @@ TfLiteRegistration* Register_MEAN_REF() {
}
TfLiteRegistration* Register_SUM_REF() {
- static TfLiteRegistration r = {reduce::Init, reduce::Free,
- reduce::PrepareSimple,
- reduce::EvalSum<reduce::kReference>};
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareSimple,
+ reduce::EvalGeneric<reduce::kReference, reduce::kSum>};
return &r;
}
TfLiteRegistration* Register_REDUCE_PROD_REF() {
- static TfLiteRegistration r = {reduce::Init, reduce::Free,
- reduce::PrepareSimple,
- reduce::EvalProd<reduce::kReference>};
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareSimple,
+ reduce::EvalGeneric<reduce::kReference, reduce::kProd>};
return &r;
}
TfLiteRegistration* Register_REDUCE_MAX_REF() {
- static TfLiteRegistration r = {reduce::Init, reduce::Free,
- reduce::PrepareSimple,
- reduce::EvalMax<reduce::kReference>};
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareSimple,
+ reduce::EvalGeneric<reduce::kReference, reduce::kMax>};
+ return &r;
+}
+
+TfLiteRegistration* Register_REDUCE_MIN_REF() {
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareSimple,
+ reduce::EvalGeneric<reduce::kReference, reduce::kMin>};
+ return &r;
+}
+
+TfLiteRegistration* Register_REDUCE_ANY_REF() {
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareAny,
+ reduce::EvalGeneric<reduce::kReference, reduce::kAny>};
return &r;
}
@@ -449,6 +467,8 @@ TfLiteRegistration* Register_REDUCE_PROD() {
return Register_REDUCE_PROD_REF();
}
TfLiteRegistration* Register_REDUCE_MAX() { return Register_REDUCE_MAX_REF(); }
+TfLiteRegistration* Register_REDUCE_MIN() { return Register_REDUCE_MIN_REF(); }
+TfLiteRegistration* Register_REDUCE_ANY() { return Register_REDUCE_ANY_REF(); }
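Each variant above differs only in which EvalGeneric instantiation fills the invoke slot of the TfLiteRegistration aggregate (init/free/prepare/invoke, in that order), plus PrepareAny for the bool-typed reduction. For illustration only (kAll is hypothetical and not part of this change), one more variant would follow the same shape:

TfLiteRegistration* Register_REDUCE_ALL_REF() {  // hypothetical kAll variant
  static TfLiteRegistration r = {
      reduce::Init, reduce::Free, reduce::PrepareAny,
      reduce::EvalGeneric<reduce::kReference, reduce::kAll>};  // kAll assumed
  return &r;
}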
} // namespace builtin
} // namespace ops
diff --git a/tensorflow/contrib/lite/kernels/reduce_test.cc b/tensorflow/contrib/lite/kernels/reduce_test.cc
index 5d432d34ef..6d289b14d8 100644
--- a/tensorflow/contrib/lite/kernels/reduce_test.cc
+++ b/tensorflow/contrib/lite/kernels/reduce_test.cc
@@ -169,6 +169,64 @@ class MaxOpDynamicModel : public BaseOpModel {
}
};
+// Model for the test case where axis is a const tensor.
+class MinOpConstModel : public BaseOpModel {
+ public:
+ MinOpConstModel(const TensorData& input, const TensorData& output,
+ std::initializer_list<int> axis_shape,
+ std::initializer_list<int> axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddConstInput(TensorType_INT32, axis, axis_shape);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_REDUCE_MIN, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
+// Model for the test case where axis is a dynamic tensor.
+class MinOpDynamicModel : public BaseOpModel {
+ public:
+ MinOpDynamicModel(const TensorData& input, const TensorData& output,
+ const TensorData& axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddInput(axis);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_REDUCE_MIN, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
+// Model for the test case where axis is a const tensor.
+class AnyOpConstModel : public BaseOpModel {
+ public:
+ AnyOpConstModel(const TensorData& input, const TensorData& output,
+ std::initializer_list<int> axis_shape,
+ std::initializer_list<int> axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddConstInput(TensorType_INT32, axis, axis_shape);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_REDUCE_ANY, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
+// Model for the test case where axis is a dynamic tensor.
+class AnyOpDynamicModel : public BaseOpModel {
+ public:
+ AnyOpDynamicModel(const TensorData& input, const TensorData& output,
+ const TensorData& axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddInput(axis);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_REDUCE_ANY, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
// For quantized ops, the error shouldn't exceed one quantization step.
float GetTolerance(int min, int max) { return (max - min) / 255.0; }
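Worked through once: GetTolerance(-10, 12) returns (12 - (-10)) / 255.0 = 22 / 255 ≈ 0.0863, i.e. one uint8 quantization step over the [-10, 12] range, which is the slack the ArrayFloatNear expectations in these tests allow.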
@@ -309,6 +367,33 @@ TEST(DynamicUint8MeanOpTest, KeepDims) {
ElementsAreArray(ArrayFloatNear({9.2815, 0.3695}, kQuantizedTolerance)));
}
+TEST(DynamicUint8MeanOpTest, QuantizedScalar) {
+ float kQuantizedTolerance = GetTolerance(-10.0, 12.0);
+ std::vector<float> data = {0.643};
+ MeanOpDynamicModel m({TensorType_UINT8, {}, 0.0, 1.0},
+ {TensorType_UINT8, {}, -10.0, 12.0},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), IsEmpty());
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({0.643}, kQuantizedTolerance)));
+}
+
+TEST(ConstUint8MeanOpTest, QuantizedKeepDims) {
+ float kQuantizedTolerance = GetTolerance(-5.0, 5.0);
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ MeanOpConstModel m({TensorType_UINT8, {3, 2}, 0.0, 1.0},
+ {TensorType_UINT8, {3}, -5.0, 5.0}, {1}, {1}, true);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1}));
+ EXPECT_THAT(
+ m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({0.3, 0.35, 0.55}, kQuantizedTolerance)));
+}
+
// Tests for reduce_sum
TEST(ConstFloatSumOpTest, NotKeepDims) {
@@ -665,6 +750,209 @@ TEST(DynamicUint8MaxOpTest, Scalar) {
ElementsAreArray(ArrayFloatNear({11.1294}, kQuantizedTolerance)));
}
+// Tests for reduce_min
+
+TEST(ConstFloatMinOpTest, NotKeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ MinOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}},
+ {4}, {1, 0, -3, -3}, false);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({1, 2})));
+}
+
+TEST(ConstFloatMinOpTest, KeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ MinOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}},
+ {2}, {0, 2}, true);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({1, 3, 5})));
+}
+
+TEST(DynamicFloatMinOpTest, NotKeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ MinOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
+ {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}},
+ false);
+ std::vector<int> axis = {1, 0, -3, -3};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({1, 2})));
+}
+
+TEST(DynamicFloatMinOpTest, KeepDims) {
+ std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ MinOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
+ {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}}, true);
+ std::vector<int> axis = {0, 2};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({1, 3, 5})));
+}
+
+TEST(DynamicFloatMinOpTest, Scalar) {
+ std::vector<float> data = {9.527};
+ MinOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({9.527})));
+}
+
+TEST(ConstUint8MinOpTest, NotKeepDims) {
+ float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ MinOpConstModel m({TensorType_UINT8, {1, 3, 2}, -1.0, 1.0},
+ {TensorType_UINT8, {2}, -1.0, 1.0}, {1}, {1}, false);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(
+ m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({0.294117, 0.2}, kQuantizedTolerance)));
+}
+
+TEST(ConstUint8MinOpTest, KeepDims) {
+ float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ MinOpConstModel m({TensorType_UINT8, {3, 2}, -1.0, 1.0},
+ {TensorType_UINT8, {3}, -1.0, 1.0}, {1}, {1}, true);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1}));
+ EXPECT_THAT(
+ m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({0.2, 0.3, 0.5}, kQuantizedTolerance)));
+}
+
+TEST(DynamicUint8MinOpTest, NotKeepDims) {
+ float kQuantizedTolerance = GetTolerance(-5.0, 2.0);
+ std::vector<float> data = {1.3, -4.8, -3.6, 0.24};
+ MinOpDynamicModel m({TensorType_UINT8, {2, 2}, -5.0, 2.0},
+ {TensorType_UINT8, {2}, -5.0, 2.0},
+ {TensorType_INT32, {1}}, false);
+ std::vector<int> axis = {1};
+ m.SetAxis(axis);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(
+ m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({-4.807843, -3.6}, kQuantizedTolerance)));
+}
+
+TEST(DynamicUint8MinOpTest, KeepDims) {
+ float kQuantizedTolerance = GetTolerance(-10.0, 12.0);
+ std::vector<float> data = {11.14, -0.14, 7.423, 0.879};
+ MinOpDynamicModel m({TensorType_UINT8, {2, 2}, -10.0, 12.0},
+ {TensorType_UINT8, {2}, -10.0, 12.0},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.SetAxis(axis);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(
+ ArrayFloatNear({7.427451, -0.164706}, kQuantizedTolerance)));
+}
+
+TEST(DynamicUint8MinOpTest, Scalar) {
+ float kQuantizedTolerance = GetTolerance(-10.0, 12.0);
+ std::vector<float> data = {11.14};
+ MinOpDynamicModel m({TensorType_UINT8, {}, -10.0, 12.0},
+ {TensorType_UINT8, {}, -10.0, 12.0},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), IsEmpty());
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({11.1294}, kQuantizedTolerance)));
+}
+
+// Tests for reduce_any
+
+TEST(ConstAnyOpTest, NotKeepDims) {
+ std::vector<bool> data = {false, false, false, false, false, false,
+ false, true, false, false, false, true};
+ AnyOpConstModel m({TensorType_BOOL, {2, 3, 2}}, {TensorType_BOOL, {2}}, {4},
+ {1, 0, -3, -3}, false);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<bool>(), ElementsAreArray({false, true}));
+}
+
+TEST(ConstAnyOpTest, KeepDims) {
+ std::vector<bool> data = {false, false, false, false, false, false,
+ false, true, false, false, false, true};
+ AnyOpConstModel m({TensorType_BOOL, {2, 3, 2}}, {TensorType_BOOL, {3}}, {2},
+ {0, 2}, true);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<bool>(), ElementsAreArray({true, false, true}));
+}
+
+TEST(DynamicAnyOpTest, NotKeepDims) {
+ std::vector<bool> data = {false, false, false, false, false, false,
+ false, true, false, false, false, true};
+ AnyOpDynamicModel m({TensorType_BOOL, {2, 3, 2}}, {TensorType_BOOL, {2}},
+ {TensorType_INT32, {4}}, false);
+ std::vector<int> axis = {1, 0, -3, -3};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<bool>(), ElementsAreArray({false, true}));
+}
+
+TEST(DynamicAnyOpTest, KeepDims) {
+ std::vector<bool> data = {false, false, false, false, false, false,
+ false, true, false, false, false, true};
+ AnyOpDynamicModel m({TensorType_BOOL, {2, 3, 2}}, {TensorType_BOOL, {3}},
+ {TensorType_INT32, {2}}, true);
+ std::vector<int> axis = {0, 2};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<bool>(), ElementsAreArray({true, false, true}));
+}
+
+TEST(DynamicAnyOpTest, Scalar) {
+ std::vector<bool> data = {false};
+ AnyOpDynamicModel m({TensorType_BOOL, {1}}, {TensorType_BOOL, {1}},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
+ EXPECT_THAT(m.GetOutput<bool>(), ElementsAreArray({false}));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 6159311910..c66959fdf4 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/util.h"
namespace tflite {
namespace ops {
@@ -21,8 +22,10 @@ namespace ops {
namespace custom {
TfLiteRegistration* Register_AUDIO_SPECTROGRAM();
+TfLiteRegistration* Register_LAYER_NORM_LSTM();
TfLiteRegistration* Register_MFCC();
TfLiteRegistration* Register_DETECTION_POSTPROCESS();
+TfLiteRegistration* Register_RELU_1();
} // namespace custom
@@ -93,6 +96,8 @@ TfLiteRegistration* Register_NEG();
TfLiteRegistration* Register_SUM();
TfLiteRegistration* Register_REDUCE_PROD();
TfLiteRegistration* Register_REDUCE_MAX();
+TfLiteRegistration* Register_REDUCE_MIN();
+TfLiteRegistration* Register_REDUCE_ANY();
TfLiteRegistration* Register_SELECT();
TfLiteRegistration* Register_SLICE();
TfLiteRegistration* Register_SIN();
@@ -111,6 +116,8 @@ TfLiteRegistration* Register_ONE_HOT();
TfLiteRegistration* Register_LOGICAL_OR();
TfLiteRegistration* Register_LOGICAL_AND();
TfLiteRegistration* Register_LOGICAL_NOT();
+TfLiteRegistration* Register_UNPACK();
+TfLiteRegistration* Register_FLOOR_DIV();
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
context->ReportError(
@@ -129,9 +136,7 @@ const TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op,
int version) const {
// Return the NULL Op for all ops whose names start with "Eager", allowing
// the interpreter to delegate their execution.
- // TODO(ycling): Refactoring and extract an `IsEagerOp` function into
- // `lite:framework` build target.
- if (string(op).find("Eager") == 0) {
+ if (IsEagerOp(op)) {
static TfLiteRegistration null_op{
nullptr, nullptr, &UnsupportedTensorFlowOp,
nullptr, nullptr, BuiltinOperator_CUSTOM,
@@ -220,6 +225,8 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_SUM, Register_SUM());
AddBuiltin(BuiltinOperator_REDUCE_PROD, Register_REDUCE_PROD());
AddBuiltin(BuiltinOperator_REDUCE_MAX, Register_REDUCE_MAX());
+ AddBuiltin(BuiltinOperator_REDUCE_MIN, Register_REDUCE_MIN());
+ AddBuiltin(BuiltinOperator_REDUCE_ANY, Register_REDUCE_ANY());
AddBuiltin(BuiltinOperator_EXPAND_DIMS, Register_EXPAND_DIMS());
AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE());
AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL());
@@ -234,12 +241,16 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR());
AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND());
AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
+ AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
+ AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
AddCustom("AudioSpectrogram",
tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
+ AddCustom("LayerNormLstm", tflite::ops::custom::Register_LAYER_NORM_LSTM());
+ AddCustom("Relu1", tflite::ops::custom::Register_RELU_1());
AddCustom("TFLite_Detection_PostProcess",
tflite::ops::custom::Register_DETECTION_POSTPROCESS());
}
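The FindOp change above delegates the prefix test to tflite::IsEagerOp from lite/util.h (hence the new include). Its definition is not shown in this diff; a stand-in consistent with the inline check it replaces would be:

#include <cstring>

// Hypothetical stand-in for tflite::IsEagerOp; mirrors the removed
// string(op).find("Eager") == 0 check without constructing a std::string.
inline bool IsEagerOpSketch(const char* op) {
  return op != nullptr && std::strncmp(op, "Eager", 5) == 0;
}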
diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h
index 0296152d68..61856ab9de 100644
--- a/tensorflow/contrib/lite/kernels/register.h
+++ b/tensorflow/contrib/lite/kernels/register.h
@@ -16,8 +16,9 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
#include <unordered_map>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
namespace tflite {
namespace ops {
diff --git a/tensorflow/contrib/lite/kernels/relu1.cc b/tensorflow/contrib/lite/kernels/relu1.cc
new file mode 100644
index 0000000000..abafee2d57
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/relu1.cc
@@ -0,0 +1,59 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace relu1 {
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ output->type = input->type;
+ return context->ResizeTensor(context, output,
+ TfLiteIntArrayCopy(input->dims));
+}
+
+// This is derived from lite/kernels/activations.cc.
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ const int elements = NumElements(input);
+ const float* in = input->data.f;
+ const float* in_end = in + elements;
+ float* out = output->data.f;
+ for (; in < in_end; ++in, ++out) {
+ *out = std::min(std::max(0.f, *in), 1.f);
+ }
+ return kTfLiteOk;
+}
+
+} // namespace relu1
+
+TfLiteRegistration* Register_RELU_1() {
+ static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ relu1::Prepare, relu1::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
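The new kernel clamps each float element to [0, 1]: relu1(-0.6) = 0.0, relu1(0.2) = 0.2, relu1(1.1) = 1.0, matching the expectations in relu1_test.cc below. Hooking it up follows the same AddCustom path register.cc uses; a short usage sketch:

tflite::ops::builtin::BuiltinOpResolver resolver;
resolver.AddCustom("Relu1", tflite::ops::custom::Register_RELU_1());
// Any interpreter built with this resolver can now run models that
// contain the custom "Relu1" op.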
diff --git a/tensorflow/contrib/lite/kernels/relu1_test.cc b/tensorflow/contrib/lite/kernels/relu1_test.cc
new file mode 100644
index 0000000000..c1e0149c20
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/relu1_test.cc
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_RELU_1();
+
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseActivationsOpModel : public SingleOpModel {
+ public:
+ explicit BaseActivationsOpModel(const TensorData& input) {
+ input_ = AddInput(input);
+ output_ = AddOutput({input.type, {}});
+ flexbuffers::Builder fbb;
+ fbb.Map([&]() {});
+ fbb.Finish();
+ SetCustomOp("RELU_1", fbb.GetBuffer(), Register_RELU_1);
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ protected:
+ int input_;
+ int output_;
+};
+
+class FloatActivationsOpModel : public BaseActivationsOpModel {
+ public:
+ using BaseActivationsOpModel::BaseActivationsOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+TEST(FloatActivationsOpTest, Relu1) {
+ FloatActivationsOpModel m(/*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
+ m.SetInput({
+ 0.0, -0.6, 0.2, -0.4, //
+ 0.3, -2.0, 1.1, -0.1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 0.0, 0.0, 0.2, 0.0, //
+ 0.3, 0.0, 1.0, 0.0, //
+ }));
+}
+
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc
index 49ba0571e2..f41147b2d6 100644
--- a/tensorflow/contrib/lite/kernels/reshape.cc
+++ b/tensorflow/contrib/lite/kernels/reshape.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
index 86c4cd3ee8..fb045d15f3 100644
--- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc
+++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -88,11 +88,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
if (output->type == kTfLiteFloat32) {
-#define TF_LITE_RESIZE_BILINEAR(type, datatype) \
- type::ResizeBilinear(GetTensorData<datatype>(input), GetTensorDims(input), \
- GetTensorData<int32>(size), GetTensorDims(size), \
- GetTensorData<datatype>(output), GetTensorDims(output), \
- params->align_corners)
+#define TF_LITE_RESIZE_BILINEAR(type, datatype) \
+ tflite::ResizeBilinearParams op_params; \
+ op_params.align_corners = params->align_corners; \
+ type::ResizeBilinear(op_params, GetTensorShape(input), \
+ GetTensorData<datatype>(input), GetTensorShape(size), \
+ GetTensorData<int32>(size), GetTensorShape(output), \
+ GetTensorData<datatype>(output))
if (kernel_type == kReference) {
TF_LITE_RESIZE_BILINEAR(reference_ops, float);
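This hunk is the first of several in this patch (slice, space_to_batch_nd, and space_to_depth below follow suit) migrating kernel calls from raw GetTensorDims plus trailing scalar arguments to a params struct paired with GetTensorShape. The shape of the change, with a hypothetical FooParams/Foo standing in for the real signatures:

// Old style: raw dims, attributes passed as trailing scalars.
//   type::Foo(GetTensorData<float>(input), GetTensorDims(input),
//             GetTensorData<float>(output), GetTensorDims(output),
//             align_corners);
//
// New style: attributes bundled up front, each shape paired with its data.
//   tflite::FooParams op_params;
//   op_params.align_corners = align_corners;
//   type::Foo(op_params, GetTensorShape(input), GetTensorData<float>(input),
//             GetTensorShape(output), GetTensorData<float>(output));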
diff --git a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/contrib/lite/kernels/select.cc
index 3cdb5db209..3959502d91 100644
--- a/tensorflow/contrib/lite/kernels/select.cc
+++ b/tensorflow/contrib/lite/kernels/select.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/shape.cc b/tensorflow/contrib/lite/kernels/shape.cc
index dbcd2ef004..66d4c9e5c1 100644
--- a/tensorflow/contrib/lite/kernels/shape.cc
+++ b/tensorflow/contrib/lite/kernels/shape.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/skip_gram.cc b/tensorflow/contrib/lite/kernels/skip_gram.cc
index c90a15b3a2..de80a4016e 100644
--- a/tensorflow/contrib/lite/kernels/skip_gram.cc
+++ b/tensorflow/contrib/lite/kernels/skip_gram.cc
@@ -33,8 +33,8 @@ limitations under the License.
#include <string>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
#include "tensorflow/contrib/lite/string_util.h"
diff --git a/tensorflow/contrib/lite/kernels/slice.cc b/tensorflow/contrib/lite/kernels/slice.cc
index 6a20e802a9..ccfee41b9c 100644
--- a/tensorflow/contrib/lite/kernels/slice.cc
+++ b/tensorflow/contrib/lite/kernels/slice.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include <string.h>
#include <cmath>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -159,10 +159,28 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
sizes.push_back(1);
}
-#define TF_LITE_SLICE(data_type) \
- optimized_ops::Slice<data_type>( \
- GetTensorData<data_type>(input), GetTensorDims(input), begins, sizes, \
- GetTensorData<data_type>(output), GetTensorDims(output))
+ // The original Slice op implementation only accepted 4-D sizes. That
+ // constraint is, for now, maintained here.
+ //
+ // The kernel used to store its dimensions in reverse order, and TFLite
+ // arranged the begins and sizes vectors accordingly. This macro incorporates
+ // the needed reversal.
+#define TF_LITE_SLICE(data_type) \
+ { \
+ TF_LITE_ENSURE_EQ(context, begins.size(), 4); \
+ TF_LITE_ENSURE_EQ(context, sizes.size(), 4); \
+ tflite::SliceParams op_params; \
+ op_params.begin_count = 4; \
+ op_params.size_count = 4; \
+ for (int i = 0; i < 4; ++i) { \
+ op_params.begin[i] = begins[3 - i]; \
+ op_params.size[i] = sizes[3 - i]; \
+ } \
+ \
+ optimized_ops::Slice<data_type>( \
+ op_params, GetTensorShape(input), GetTensorData<data_type>(input), \
+ GetTensorShape(output), GetTensorData<data_type>(output)); \
+ }
switch (input->type) {
case kTfLiteFloat32:
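The reversal inside TF_LITE_SLICE is easy to misread. Concretely: with TFLite-order vectors begins = {b0, b1, b2, b3} and sizes = {s0, s1, s2, s3}, the loop fills op_params.begin = {b3, b2, b1, b0} and op_params.size = {s3, s2, s1, s0}; for example, begins = {0, 1, 2, 3} becomes begin[] = {3, 2, 1, 0}.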
diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
index 03079f1c3b..3a10d2e60c 100644
--- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -114,14 +114,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar, pad_value) \
- type::SpaceToBatchND(GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), \
+ tflite::SpaceToBatchParams op_params; \
+ op_params.output_offset = pad_value; \
+ type::SpaceToBatchND(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), \
+ GetTensorShape(op_context.block_shape), \
GetTensorData<int32_t>(op_context.block_shape), \
- GetTensorDims(op_context.block_shape), \
+ GetTensorShape(op_context.paddings), \
GetTensorData<int32_t>(op_context.paddings), \
- GetTensorDims(op_context.paddings), \
- GetTensorData<scalar>(op_context.output), \
- GetTensorDims(op_context.output), pad_value)
+ GetTensorShape(op_context.output), \
+ GetTensorData<scalar>(op_context.output))
switch (op_context.input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
if (kernel_type == kReference) {
diff --git a/tensorflow/contrib/lite/kernels/space_to_depth.cc b/tensorflow/contrib/lite/kernels/space_to_depth.cc
index 9dbe9b9eda..64c56c017b 100644
--- a/tensorflow/contrib/lite/kernels/space_to_depth.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_depth.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -79,10 +79,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-#define TF_LITE_SPACE_TO_DEPTH(type, scalar) \
- type::SpaceToDepth<scalar>( \
- GetTensorData<scalar>(input), GetTensorDims(input), params->block_size, \
- GetTensorData<scalar>(output), GetTensorDims(output))
+#define TF_LITE_SPACE_TO_DEPTH(type, scalar) \
+ tflite::SpaceToDepthParams op_params; \
+ op_params.block_size = params->block_size; \
+ type::SpaceToDepth(op_params, GetTensorShape(input), \
+ GetTensorData<scalar>(input), GetTensorShape(output), \
+ GetTensorData<scalar>(output))
switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
if (kernel_type == kReference) {
diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
index fec2a6f0d9..178568e07c 100644
--- a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
+++ b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/contrib/lite/kernels/split.cc
index b144486041..719e2dc606 100644
--- a/tensorflow/contrib/lite/kernels/split.cc
+++ b/tensorflow/contrib/lite/kernels/split.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/squeeze.cc b/tensorflow/contrib/lite/kernels/squeeze.cc
index 09a5662fd9..080c51cd18 100644
--- a/tensorflow/contrib/lite/kernels/squeeze.cc
+++ b/tensorflow/contrib/lite/kernels/squeeze.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc
index bed2117f9a..87ffcc4110 100644
--- a/tensorflow/contrib/lite/kernels/strided_slice.cc
+++ b/tensorflow/contrib/lite/kernels/strided_slice.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include <string.h>
#include <cmath>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc
index 77a1f59689..1be0c83f17 100644
--- a/tensorflow/contrib/lite/kernels/sub.cc
+++ b/tensorflow/contrib/lite/kernels/sub.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc
index 6d4912ce3a..9903fd5c35 100644
--- a/tensorflow/contrib/lite/kernels/svdf.cc
+++ b/tensorflow/contrib/lite/kernels/svdf.cc
@@ -23,8 +23,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -40,19 +40,22 @@ namespace {
struct OpData {
int scratch_tensor_index;
bool float_weights_time_initialized;
+
+ int activation_state_tensor_index;
};
static inline void ApplyTimeWeightsBiasAndActivation(
int batch_size, int memory_size, int num_filters, int num_units, int rank,
const TfLiteTensor* weights_time, const TfLiteTensor* bias,
- TfLiteFusedActivation activation, TfLiteTensor* state,
+ TfLiteFusedActivation activation, TfLiteTensor* activation_state,
TfLiteTensor* scratch, TfLiteTensor* output) {
// Compute matmul(state, weights_time).
// The rightmost column is used to save temporary output (with the size of
- // num_filters). This is achieved by starting at state->data.f and having the
- // stride equal to memory_size.
+ // num_filters). This is achieved by starting at activation_state->data.f
+ // and having the stride equal to memory_size.
for (int b = 0; b < batch_size; ++b) {
- float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
+ float* state_ptr_batch =
+ activation_state->data.f + b * memory_size * num_filters;
float* scratch_ptr_batch = scratch->data.f + b * num_filters;
tensor_utils::BatchVectorBatchVectorDotProduct(
weights_time->data.f, state_ptr_batch, memory_size, num_filters,
@@ -82,13 +85,14 @@ static inline void ApplyTimeWeightsBiasAndActivation(
activation, output_ptr_batch);
}
- // Left shift the state to make room for next cycle's activation.
+ // Left-shift the activation_state to make room for the next cycle's activation.
// TODO(alanchiao): explore collapsing this into a single loop.
for (int b = 0; b < batch_size; ++b) {
- float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
+ float* state_ptr_batch =
+ activation_state->data.f + b * memory_size * num_filters;
for (int f = 0; f < num_filters; ++f) {
tensor_utils::VectorShiftLeft(state_ptr_batch, memory_size,
- /*shift_value=*/0.0);
+ /*shift_value=*/0.0f);
state_ptr_batch += memory_size;
}
}
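The pointer arithmetic above works out as follows for hypothetical sizes memory_size = 10 and num_filters = 8: batch b starts at activation_state->data.f + b * 10 * 8, so batch 1 begins 80 floats in, and within a batch each filter's window advances by memory_size = 10 floats, which is the stride that both the matmul and the left-shift rely on.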
@@ -96,12 +100,16 @@ static inline void ApplyTimeWeightsBiasAndActivation(
} // namespace
+// Input tensors.
constexpr int kInputTensor = 0;
constexpr int kWeightsFeatureTensor = 1;
constexpr int kWeightsTimeTensor = 2;
constexpr int kBiasTensor = 3;
-constexpr int kStateTensor = 0;
-constexpr int kOutputTensor = 1;
+// This is a variable tensor; it will be modified by this op.
+constexpr int kInputActivationStateTensor = 4;
+
+// Output tensor.
+constexpr int kOutputTensor = 0;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* op_data = new OpData();
@@ -121,8 +129,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int scratch_tensor_index = op_data->scratch_tensor_index;
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
+ op_data->activation_state_tensor_index =
+ node->inputs->data[kInputActivationStateTensor];
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* weights_feature =
@@ -148,22 +158,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ASSERT_EQ(bias->dims->data[0], num_units);
}
- TfLiteTensor* state = GetOutput(context, node, kStateTensor);
+ TfLiteTensor* activation_state =
+ &context->tensors[op_data->activation_state_tensor_index];
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- // Resize state.
- // For each batch, the state is a 2-D tensor: memory_size * num_filters
- // The left most column is used to save current cycle activation.
- // The right most column is used to save temporary output which will be
- // reduced to num_units outputs.
- TfLiteIntArray* state_size_array = TfLiteIntArrayCreate(2);
- state_size_array->data[0] = batch_size;
- state_size_array->data[1] = memory_size * num_filters;
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, state, state_size_array));
-
- // Mark state as a persistent tensor.
- state->allocation_type = kTfLiteArenaRwPersistent;
+ // Check the shape of input state tensors.
+ TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 0), batch_size);
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 1),
+ memory_size * num_filters);
// Resize output.
TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
@@ -220,8 +223,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
scaling_factors_size));
}
- // Used to store dequantized weights_time matrix for hybrid computation
- // of matmul(state, weights_time), which occurs in floating point.
+ // Used to store dequantized weights_time matrix for hybrid computation of
+ // matmul(activation_state, weights_time), which occurs in floating point.
node->temporaries->data[3] = scratch_tensor_index + 3;
TfLiteTensor* float_weights_time = GetTemporary(context, node, /*index=*/3);
float_weights_time->type = kTfLiteFloat32;
@@ -253,13 +256,13 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
const int memory_size = weights_time->dims->data[1];
// Clear the activation (state left most column).
- // TODO(ghodrat): Add a test which initialize state with invalid values in
- // left most column and make sure it passes.
+ // TODO(ghodrat): Add a test which initializes activation_state with invalid
+ // values in the left-most column and make sure it passes.
for (int b = 0; b < batch_size; ++b) {
float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
for (int c = 0; c < num_filters; ++c) {
float* state_ptr = state_ptr_batch + c * memory_size;
- state_ptr[memory_size - 1] = 0.0;
+ state_ptr[memory_size - 1] = 0.0f;
}
}
@@ -307,7 +310,7 @@ TfLiteStatus EvalHybrid(
// Clear the activation (state left most column).
// TODO(ghodrat): Add a test which initializes state with invalid values in
- // left most column and make sure it passes.
+ // the left most column and make sure it passes.
for (int b = 0; b < batch_size; ++b) {
float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
for (int c = 0; c < num_filters; ++c) {
@@ -329,9 +332,10 @@ TfLiteStatus EvalHybrid(
}
// Compute conv1d(inputs, weights_feature).
- // The state right most column is used to save current cycle activation.
- // This is achieved by starting at state->data.f[memory_size - 1] and having
- // the stride equal to memory_size.
+ // The rightmost column of state is used to save the current cycle's
+ // activation. This is achieved by starting at state->data.f[memory_size - 1]
+ // and having the stride equal to memory_size.
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
weights_feature_ptr, num_filters, input_size, quantized_input_ptr_batch,
scaling_factors_ptr, batch_size, &state->data.f[memory_size - 1],
@@ -359,13 +363,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0);
- TfLiteTensor* state = GetOutput(context, node, kStateTensor);
+ TfLiteTensor* activation_state =
+ &context->tensors[op_data->activation_state_tensor_index];
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (weights_feature->type) {
case kTfLiteFloat32: {
return EvalFloat(context, node, input, weights_feature, weights_time,
- bias, params, scratch, state, output);
+ bias, params, scratch, activation_state, output);
break;
}
case kTfLiteUInt8: {
@@ -392,7 +397,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
return EvalHybrid(context, node, input, weights_feature,
float_weights_time, bias, params, scratch,
- scaling_factors, input_quantized, state, output);
+ scaling_factors, input_quantized, activation_state,
+ output);
break;
}
default:
diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc
index 5af3ff8500..6d60dc63f4 100644
--- a/tensorflow/contrib/lite/kernels/svdf_test.cc
+++ b/tensorflow/contrib/lite/kernels/svdf_test.cc
@@ -141,16 +141,20 @@ class BaseSVDFOpModel : public SingleOpModel {
weights_feature_ = AddInput(weights_feature_type);
weights_time_ = AddInput(weights_time_type);
bias_ = AddNullInput();
- state_ = AddOutput(TensorType_FLOAT32);
+ const int num_filters = units * rank;
+ activation_state_ = AddInput(
+ TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}},
+ /*is_variable=*/true);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(
BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union());
BuildInterpreter({
- {batches_, input_size_}, // Input tensor
- {units_ * rank, input_size_}, // weights_feature tensor
- {units_ * rank, memory_size_}, // weights_time tensor
- {units_} // bias tensor
+ {batches_, input_size_}, // input tensor
+ {units_ * rank, input_size_}, // weights_feature tensor
+ {units_ * rank, memory_size_}, // weights_time tensor
+ {units_}, // bias tensor
+ {batches, memory_size * num_filters} // activation_state tensor
});
}
@@ -169,15 +173,6 @@ class BaseSVDFOpModel : public SingleOpModel {
PopulateTensor(input_, offset, begin, end);
}
- // Resets the state of SVDF op by filling it with 0's.
- void ResetState() {
- const int zero_buffer_size = rank_ * units_ * batches_ * memory_size_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
// Extracts the output tensor from the SVDF op.
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
@@ -190,7 +185,7 @@ class BaseSVDFOpModel : public SingleOpModel {
int weights_feature_;
int weights_time_;
int bias_;
- int state_;
+ int activation_state_;
int output_;
int batches_;
@@ -274,7 +269,6 @@ TEST_F(SVDFOpTest, BlackBoxTestRank1) {
-0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
-0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657});
- svdf.ResetState();
VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
&svdf);
}
@@ -314,7 +308,6 @@ TEST_F(SVDFOpTest, BlackBoxTestRank2) {
0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326,
0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763});
- svdf.ResetState();
VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
&svdf);
}
@@ -339,7 +332,6 @@ TEST_F(SVDFOpTest, BlackBoxTestHybridRank1) {
-0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
-0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657});
- svdf.ResetState();
VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
&svdf,
/*tolerance=*/0.002945);
@@ -380,7 +372,6 @@ TEST_F(SVDFOpTest, BlackBoxTestHybridRank2) {
0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326,
0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763});
- svdf.ResetState();
VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
&svdf,
/*tolerance=*/0.00625109);
diff --git a/tensorflow/contrib/lite/kernels/tile.cc b/tensorflow/contrib/lite/kernels/tile.cc
index 5181a8f89a..49421eb870 100644
--- a/tensorflow/contrib/lite/kernels/tile.cc
+++ b/tensorflow/contrib/lite/kernels/tile.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/tile_test.cc b/tensorflow/contrib/lite/kernels/tile_test.cc
index 4f78c224e5..e73ca7b750 100644
--- a/tensorflow/contrib/lite/kernels/tile_test.cc
+++ b/tensorflow/contrib/lite/kernels/tile_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/topk_v2.cc b/tensorflow/contrib/lite/kernels/topk_v2.cc
index 2dd760bbfe..6c38b6739e 100644
--- a/tensorflow/contrib/lite/kernels/topk_v2.cc
+++ b/tensorflow/contrib/lite/kernels/topk_v2.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <algorithm>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/topk_v2_test.cc b/tensorflow/contrib/lite/kernels/topk_v2_test.cc
index 2abb89b617..16106fdafe 100644
--- a/tensorflow/contrib/lite/kernels/topk_v2_test.cc
+++ b/tensorflow/contrib/lite/kernels/topk_v2_test.cc
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/transpose.cc b/tensorflow/contrib/lite/kernels/transpose.cc
index 800b0563d7..95359962e0 100644
--- a/tensorflow/contrib/lite/kernels/transpose.cc
+++ b/tensorflow/contrib/lite/kernels/transpose.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc
index a9baa5c698..6f2d98ede8 100644
--- a/tensorflow/contrib/lite/kernels/transpose_conv.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index 0acd705950..63817bd886 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
@@ -64,10 +64,14 @@ constexpr int kProjectionWeightsTensor = 16; // Optional
// Projection bias tensor of size {n_output}
constexpr int kProjectionBiasTensor = 17; // Optional
+// Stateful input tensors that are variables and will be modified by the Op.
+// Activation state tensor of size {n_batch, n_output}
+constexpr int kInputActivationStateTensor = 18;
+// Cell state tensor of size {n_batch, n_cell}
+constexpr int kInputCellStateTensor = 19;
+
// Output tensors.
-constexpr int kOutputStateTensor = 0;
-constexpr int kCellStateTensor = 1;
-constexpr int kOutputTensor = 2;
+constexpr int kOutputTensor = 0;
// Temporary tensors
enum TemporaryTensor {
@@ -82,7 +86,7 @@ enum TemporaryTensor {
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- auto* scratch_tensor_index = new int;
+ auto* scratch_tensor_index = new int();
context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
return scratch_tensor_index;
}
@@ -247,8 +251,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 3);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 20);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
@@ -276,12 +280,21 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
n_output, n_cell));
- // Get the pointer to output, output_state and cell_state buffer tensors.
+ // Get the pointer to output, activation_state and cell_state buffer tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
- TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
- // Resize the output, output_state and cell_state tensors.
+ TfLiteTensor* activation_state =
+ GetVariableInput(context, node, kInputActivationStateTensor);
+ TfLiteTensor* cell_state =
+ GetVariableInput(context, node, kInputCellStateTensor);
+
+ // Check the shape of input state tensors.
+ // These tensors may be 1D or 2D. Either is fine as long as the total size
+ // is correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
+
+ // Resize the output tensors.
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
output_size->data[0] = max_time;
output_size->data[1] = n_batch;
@@ -289,22 +302,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size));
- TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2);
- output_state_size->data[0] = n_batch;
- output_state_size->data[1] = n_output;
- TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, output_state, output_state_size));
-
- TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
- cell_size->data[0] = n_batch;
- cell_size->data[1] = n_cell;
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, cell_state, cell_size));
-
- // Mark state tensors as persistent tensors.
- output_state->allocation_type = kTfLiteArenaRwPersistent;
- cell_state->allocation_type = kTfLiteArenaRwPersistent;
-
// The weights are of consistent type, so it suffices to check one.
// TODO(mirkov): create a utility/macro for this check, so all Ops can use it.
const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 &&
@@ -340,7 +337,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (is_hybrid_op) {
// Allocate temporary tensors to store quantized values of input,
- // output_state and cell_state tensors.
+ // activation_state and cell_state tensors.
node->temporaries->data[kInputQuantized] =
*scratch_tensor_index + kInputQuantized;
TfLiteTensor* input_quantized =
@@ -354,17 +351,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kOutputStateQuantized] =
*scratch_tensor_index + kOutputStateQuantized;
- TfLiteTensor* output_state_quantized =
+ TfLiteTensor* activation_state_quantized =
GetTemporary(context, node, kOutputStateQuantized);
- output_state_quantized->type = kTfLiteUInt8;
- output_state_quantized->allocation_type = kTfLiteArenaRw;
- if (!TfLiteIntArrayEqual(output_state_quantized->dims,
- output_state->dims)) {
- TfLiteIntArray* output_state_quantized_size =
- TfLiteIntArrayCopy(output_state->dims);
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, output_state_quantized,
- output_state_quantized_size));
+ activation_state_quantized->type = kTfLiteUInt8;
+ activation_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
+ activation_state->dims)) {
+ TfLiteIntArray* activation_state_quantized_size =
+ TfLiteIntArrayCopy(activation_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, activation_state_quantized,
+ activation_state_quantized_size));
}
node->temporaries->data[kCellStateQuantized] =
*scratch_tensor_index + kCellStateQuantized;
@@ -449,7 +446,7 @@ TfLiteStatus EvalFloat(
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
- TfLiteTensor* output_state, TfLiteTensor* cell_state,
+ TfLiteTensor* activation_state, TfLiteTensor* cell_state,
TfLiteTensor* output) {
const int max_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
@@ -510,7 +507,7 @@ TfLiteStatus EvalFloat(
const float* cell_bias_ptr = cell_bias->data.f;
const float* output_gate_bias_ptr = output_gate_bias->data.f;
- float* output_state_ptr = output_state->data.f;
+ float* activation_state_ptr = activation_state->data.f;
float* cell_state_ptr = cell_state->data.f;
// Feed the sequence into the LSTM step-by-step.
@@ -527,7 +524,7 @@ TfLiteStatus EvalFloat(
cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
output_gate_bias_ptr, projection_weights_ptr, projection_bias_ptr,
- params, n_batch, n_cell, n_input, n_output, output_state_ptr,
+ params, n_batch, n_cell, n_input, n_output, activation_state_ptr,
cell_state_ptr, input_gate_scratch, forget_gate_scratch, cell_scratch,
output_gate_scratch, output_ptr_batch);
}
@@ -552,9 +549,9 @@ TfLiteStatus EvalHybrid(
const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
- TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
- TfLiteTensor* output_state, TfLiteTensor* cell_state,
- TfLiteTensor* output) {
+ TfLiteTensor* activation_state_quantized,
+ TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output) {
const int max_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
const int n_input = input->dims->data[2];
@@ -655,14 +652,14 @@ TfLiteStatus EvalHybrid(
const float* cell_bias_ptr = cell_bias->data.f;
const float* output_gate_bias_ptr = output_gate_bias->data.f;
- float* output_state_ptr = output_state->data.f;
+ float* activation_state_ptr = activation_state->data.f;
float* cell_state_ptr = cell_state->data.f;
// Temporary storage for quantized values and scaling factors.
int8_t* quantized_input_ptr =
reinterpret_cast<int8_t*>(input_quantized->data.uint8);
- int8_t* quantized_output_state_ptr =
- reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
+ int8_t* quantized_activation_state_ptr =
+ reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
int8_t* quantized_cell_state_ptr =
reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
float* scaling_factors_ptr = scaling_factors->data.f;
@@ -692,8 +689,8 @@ TfLiteStatus EvalHybrid(
n_input, n_output, input_gate_scratch, forget_gate_scratch,
cell_scratch, output_gate_scratch, scaling_factors_ptr,
prod_scaling_factors_ptr, recovered_cell_weights_ptr,
- quantized_input_ptr, quantized_output_state_ptr,
- quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
+ quantized_input_ptr, quantized_activation_state_ptr,
+ quantized_cell_state_ptr, activation_state_ptr, cell_state_ptr,
output_ptr_batch);
}
return kTfLiteOk;
@@ -744,8 +741,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
- TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
- TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
+ TfLiteTensor* activation_state =
+ GetVariableInput(context, node, kInputActivationStateTensor);
+ TfLiteTensor* cell_state =
+ GetVariableInput(context, node, kInputCellStateTensor);
+
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (input_to_output_weights->type) {
@@ -758,11 +758,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
cell_to_output_weights, input_gate_bias,
forget_gate_bias, cell_bias, output_gate_bias,
projection_weights, projection_bias, params,
- scratch_buffer, output_state, cell_state, output);
+ scratch_buffer, activation_state, cell_state, output);
}
case kTfLiteUInt8: {
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
- TfLiteTensor* output_state_quantized =
+ TfLiteTensor* activation_state_quantized =
GetTemporary(context, node, /*index=*/2);
TfLiteTensor* cell_state_quantized =
GetTemporary(context, node, /*index=*/3);
@@ -780,8 +780,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias,
projection_weights, projection_bias, params, scratch_buffer,
scaling_factors, prod_scaling_factors, recovered_cell_weights,
- input_quantized, output_state_quantized, cell_state_quantized,
- output_state, cell_state, output);
+ input_quantized, activation_state_quantized, cell_state_quantized,
+ activation_state, cell_state, output);
}
default:
context->ReportError(context, "Type %d is not currently supported.",
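With this hunk the kernel's contract becomes 20 inputs (indices 18 and 19
being the variable activation and cell state) and a single output, instead of
18 inputs and 3 outputs. A minimal sketch of declaring the two state tensors
when wiring an interpreter by hand; the shapes are illustrative, and the
is_variable flag is what makes the interpreter persist them across
invocations (sketch assumes the Interpreter API of this era):

    // Hypothetical model with n_batch = 1, n_output = 4, n_cell = 4.
    interpreter->SetTensorParametersReadWrite(
        /*tensor_index=*/18, kTfLiteFloat32, "activation_state", {1, 4},
        TfLiteQuantizationParams(), /*is_variable=*/true);
    interpreter->SetTensorParametersReadWrite(
        /*tensor_index=*/19, kTfLiteFloat32, "cell_state", {1, 4},
        TfLiteQuantizationParams(), /*is_variable=*/true);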
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
index de38bdef6f..cd3aac0532 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
@@ -100,8 +100,14 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
projection_bias_ = AddNullInput();
}
- output_state_ = AddOutput(TensorType_FLOAT32);
- cell_state_ = AddOutput(TensorType_FLOAT32);
+ // Add the two input state tensors.
+ input_activation_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}},
+ /*is_variable=*/true);
+ input_cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}},
+ /*is_variable=*/true);
+
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
@@ -180,22 +186,6 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
PopulateTensor(projection_bias_, f);
}
- void ResetOutputState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
- void ResetCellState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
void SetInput(int offset, const float* begin, const float* end) {
PopulateTensor(input_, offset, const_cast<float*>(begin),
const_cast<float*>(end));
@@ -233,9 +223,10 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
int projection_weights_;
int projection_bias_;
+ int input_activation_state_;
+ int input_cell_state_;
+
int output_;
- int output_state_;
- int cell_state_;
int n_batch_;
int n_input_;
@@ -458,6 +449,9 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToInputWeights(input_to_input_weights_);
@@ -475,10 +469,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -519,6 +509,9 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToInputWeights(input_to_input_weights_);
@@ -536,10 +529,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
/*tolerance=*/0.0157651);
}
@@ -629,6 +618,9 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToCellWeights(input_to_cell_weights_);
@@ -646,10 +638,6 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
lstm.SetCellToOutputWeights(cell_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -691,6 +679,9 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToCellWeights(input_to_cell_weights_);
@@ -708,10 +699,6 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
lstm.SetCellToOutputWeights(cell_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
}
@@ -1351,6 +1338,9 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToInputWeights(input_to_input_weights_);
@@ -1374,10 +1364,6 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
lstm.SetProjectionWeights(projection_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -1418,6 +1404,9 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToInputWeights(input_to_input_weights_);
@@ -1441,10 +1430,6 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetProjectionWeights(projection_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
}
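The ResetOutputState/ResetCellState helpers disappear because the states are
now variable tensors, which the runtime zero-fills when they are allocated,
so the tests no longer seed them by hand. If a test did need to clear
persistent state between sequences, a sketch of the call (assuming the
Interpreter API of this era):

    // Hypothetical: re-zero all variable tensors before a fresh sequence.
    ASSERT_EQ(interpreter->ResetVariableTensorsToZero(), kTfLiteOk);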
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
index 0d6d29a171..744ee7c109 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -31,12 +31,15 @@ namespace ops {
namespace builtin {
namespace unidirectional_sequence_rnn {
+// Input tensors.
constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kRecurrentWeightsTensor = 2;
constexpr int kBiasTensor = 3;
-constexpr int kHiddenStateTensor = 0;
-constexpr int kOutputTensor = 1;
+constexpr int kHiddenStateTensor = 4;
+
+// Output tensor.
+constexpr int kOutputTensor = 0;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* scratch_tensor_index = new int;
@@ -50,14 +53,16 @@ void Free(TfLiteContext* context, void* buffer) {
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* recurrent_weights =
GetInput(context, node, kRecurrentWeightsTensor);
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
+ const TfLiteTensor* hidden_state =
+ GetInput(context, node, kHiddenStateTensor);
// Check all the parameters of tensor match within themselves and match the
// input configuration.
@@ -74,20 +79,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, input_weights->type, recurrent_weights->type);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(hidden_state), 2);
+ TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size);
+ TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units);
- TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- // Resize state.
- TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2);
- hidden_state_size_array->data[0] = batch_size;
- hidden_state_size_array->data[1] = num_units;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state,
- hidden_state_size_array));
-
- // Mark hidden state as a persistent tensor.
- hidden_state->allocation_type = kTfLiteArenaRwPersistent;
-
// Resize output.
TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(3);
output_size_array->data[0] = (time_major) ? max_time : batch_size;
@@ -276,7 +273,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* recurrent_weights =
GetInput(context, node, kRecurrentWeightsTensor);
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
- TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
+ // hidden_state is a variable input tensor; the op updates it in place.
+ TfLiteTensor* hidden_state =
+ const_cast<TfLiteTensor*>(GetInput(context, node, kHiddenStateTensor));
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (input_weights->type) {
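The const_cast above is the pattern used where a kernel must write through a
variable input; the LSTM kernel earlier in this diff reaches the same result
via GetVariableInput. A sketch of what such a helper can look like, assuming
kernel_util's GetInput signature and the is_variable flag on TfLiteTensor;
the helper name here is illustrative:

    // Returns a mutable pointer only for tensors flagged as variables.
    inline TfLiteTensor* GetMutableVariableInput(TfLiteContext* context,
                                                 TfLiteNode* node, int index) {
      TfLiteTensor* tensor =
          const_cast<TfLiteTensor*>(GetInput(context, node, index));
      return (tensor != nullptr && tensor->is_variable) ? tensor : nullptr;
    }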
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
index 0adab837b0..6b48e3fff7 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
@@ -183,7 +183,7 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
weights_ = AddInput(weights);
recurrent_weights_ = AddInput(recurrent_weights);
bias_ = AddInput(TensorType_FLOAT32);
- hidden_state_ = AddOutput(TensorType_FLOAT32);
+ hidden_state_ = AddInput(TensorType_FLOAT32, true);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
BuiltinOptions_SequenceRNNOptions,
@@ -194,12 +194,14 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
BuildInterpreter({{sequence_len_, batches_, input_size_},
{units_, input_size_},
{units_, units_},
- {units_}});
+ {units_},
+ {batches_, units_}});
} else {
BuildInterpreter({{batches_, sequence_len_, input_size_},
{units_, input_size_},
{units_, units_},
- {units_}});
+ {units_},
+ {batches_, units_}});
}
}
@@ -221,14 +223,6 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
PopulateTensor(input_, offset, begin, end);
}
- void ResetHiddenState() {
- const int zero_buffer_size = units_ * batches_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(hidden_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
int input_size() { return input_size_; }
@@ -273,7 +267,6 @@ TEST(UnidirectionalRNNOpTest, BlackBoxTest) {
rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
float* batch_start = rnn_input;
@@ -299,7 +292,6 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTest) {
rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
float* batch_start = rnn_input;
@@ -326,7 +318,6 @@ TEST(UnidirectionalRNNOpTest, TimeMajorBlackBoxTest) {
rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
for (int i = 0; i < rnn.sequence_len(); i++) {
float* batch_start = rnn_input + i * rnn.input_size();
@@ -356,7 +347,6 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTest) {
rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
for (int i = 0; i < rnn.sequence_len(); i++) {
float* batch_start = rnn_input + i * rnn.input_size();
diff --git a/tensorflow/contrib/lite/kernels/unpack.cc b/tensorflow/contrib/lite/kernels/unpack.cc
new file mode 100644
index 0000000000..9ff06f8331
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/unpack.cc
@@ -0,0 +1,130 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace unpack {
+namespace {
+
+constexpr int kInputTensor = 0;
+
+// Op data for unpack op.
+struct OpData {
+ int num;
+ int axis;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+ data->axis = 0;
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const OpData* data = reinterpret_cast<OpData*>(node->builtin_data);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), data->num);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
+ TF_LITE_ENSURE(context, NumDimensions(input) > 1);
+ TF_LITE_ENSURE(context, NumDimensions(input) > data->axis);
+ // TODO(renjieliu): Support negative axis.
+ TF_LITE_ENSURE(context, data->axis >= 0);
+ if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32) {
+ context->ReportError(context,
+ "Currently pack only supports int32 and float32.");
+ return kTfLiteError;
+ }
+
+ const TfLiteIntArray* input_shape = input->dims;
+ // `num` must equal the input's dimension along `axis` (checked below).
+ // Resize the outputs: each output has rank R - 1, dropping `axis`.
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(input) - 1);
+ int o = 0;
+ for (int index = 0; index < NumDimensions(input); ++index) {
+ if (index != data->axis) {
+ output_shape->data[o++] = input_shape->data[index];
+ }
+ }
+
+ TF_LITE_ENSURE_EQ(context, data->num, input_shape->data[data->axis]);
+ for (int i = 0; i < data->num; ++i) {
+ TfLiteIntArray* copied_output_shape = TfLiteIntArrayCopy(output_shape);
+ TfLiteTensor* output = GetOutput(context, node, i);
+ TF_LITE_ENSURE_EQ(context, output->type, input->type);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, output, copied_output_shape));
+ }
+
+ TfLiteIntArrayFree(output_shape);
+ return kTfLiteOk;
+}
+
+template <typename T>
+void UnpackImpl(TfLiteContext* context, TfLiteNode* node,
+ const TfLiteTensor* input, int output_count, int axis) {
+ VectorOfTensors<T> all_outputs(*context, *node->outputs);
+ reference_ops::Unpack<T>(axis, GetTensorData<T>(input), GetTensorDims(input),
+ NumDimensions(input), output_count,
+ all_outputs.data(), **all_outputs.dims());
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const OpData* data = reinterpret_cast<OpData*>(node->builtin_data);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ UnpackImpl<float>(context, node, input, data->num, data->axis);
+ break;
+ }
+ case kTfLiteInt32: {
+ UnpackImpl<int32_t>(context, node, input, data->num, data->axis);
+ break;
+ }
+ default: {
+ context->ReportError(context,
+ "Currently pack only supports int32 and float32.");
+ return kTfLiteError;
+ }
+ }
+
+ return kTfLiteOk;
+}
+} // namespace
+} // namespace unpack
+
+TfLiteRegistration* Register_UNPACK() {
+ static TfLiteRegistration r = {unpack::Init, unpack::Free, unpack::Prepare,
+ unpack::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
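Prepare() above builds each output shape by dropping the unpacked axis from
the input shape. A worked example matching the tests that follow:

    // input dims {2, 2, 2}, axis = 2:
    //   output_shape keeps indices 0 and 1 -> every output is {2, 2}
    //   data->num must equal input_shape->data[2] == 2 outputs
    // input dims {3, 2}, axis = 0:
    //   output_shape -> {2}, with data->num == 3 outputs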
diff --git a/tensorflow/contrib/lite/kernels/unpack_test.cc b/tensorflow/contrib/lite/kernels/unpack_test.cc
new file mode 100644
index 0000000000..4efc92a0fd
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/unpack_test.cc
@@ -0,0 +1,225 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+
+template <typename T>
+class UnpackOpModel : public SingleOpModel {
+ public:
+ UnpackOpModel(const TensorData& input, int axis) {
+ CHECK_LT(axis, input.shape.size());
+ const int num_outputs = input.shape[axis];
+ input_ = AddInput(input);
+ for (int i = 0; i < num_outputs; ++i) {
+ outputs_.push_back(AddOutput(input.type));
+ }
+ SetBuiltinOp(BuiltinOperator_UNPACK, BuiltinOptions_UnpackOptions,
+ CreateUnpackOptions(builder_, num_outputs, axis).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ void SetInput(std::initializer_list<T> data) {
+ PopulateTensor<T>(input_, data);
+ }
+
+ std::vector<std::vector<T>> GetOutputDatas() {
+ std::vector<std::vector<T>> output_datas;
+ for (const int output : outputs_) {
+ output_datas.push_back(ExtractVector<T>(output));
+ }
+ return output_datas;
+ }
+
+ std::vector<std::vector<int>> GetOutputShapes() {
+ std::vector<std::vector<int>> output_shapes;
+ for (const int output : outputs_) {
+ output_shapes.push_back(GetTensorShape(output));
+ }
+ return output_shapes;
+ }
+
+ private:
+ int input_;
+ std::vector<int> outputs_;
+};
+
+// float32 tests.
+TEST(UnpackOpTest, FloatThreeOutputs) {
+ UnpackOpModel<float> model({TensorType_FLOAT32, {3, 2}}, 0);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 3);
+ EXPECT_THAT(output_shapes[0], ElementsAre(2));
+ EXPECT_THAT(output_shapes[1], ElementsAre(2));
+ EXPECT_THAT(output_shapes[2], ElementsAre(2));
+
+ // Check outputs values.
+ const std::vector<std::vector<float>>& output_datas = model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 3);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 2));
+ EXPECT_THAT(output_datas[1], ElementsAre(3, 4));
+ EXPECT_THAT(output_datas[2], ElementsAre(5, 6));
+}
+
+TEST(UnpackOpTest, FloatThreeOutputsAxisOne) {
+ UnpackOpModel<float> model({TensorType_FLOAT32, {3, 2}}, 1);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 2);
+ EXPECT_THAT(output_shapes[0], ElementsAre(3));
+ EXPECT_THAT(output_shapes[1], ElementsAre(3));
+
+ // Check outputs values.
+ const std::vector<std::vector<float>>& output_datas = model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 2);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5));
+ EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6));
+}
+
+TEST(UnpackOpTest, FloatOneOutput) {
+ UnpackOpModel<float> model({TensorType_FLOAT32, {1, 6}}, 0);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 1);
+ EXPECT_THAT(output_shapes[0], ElementsAre(6));
+
+ // Check outputs values.
+ const std::vector<std::vector<float>>& output_datas = model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 1);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 2, 3, 4, 5, 6));
+}
+
+TEST(UnpackOpTest, FloatThreeDimensionsOutputs) {
+ UnpackOpModel<float> model({TensorType_FLOAT32, {2, 2, 2}}, 2);
+ model.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 2);
+ EXPECT_THAT(output_shapes[0], ElementsAre(2, 2));
+ EXPECT_THAT(output_shapes[1], ElementsAre(2, 2));
+
+ // Check outputs values.
+ const std::vector<std::vector<float>>& output_datas = model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 2);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5, 7));
+ EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6, 8));
+}
+
+// int32 tests.
+TEST(UnpackOpTest, IntThreeOutputs) {
+ UnpackOpModel<int32_t> model({TensorType_INT32, {3, 2}}, 0);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 3);
+ EXPECT_THAT(output_shapes[0], ElementsAre(2));
+ EXPECT_THAT(output_shapes[1], ElementsAre(2));
+ EXPECT_THAT(output_shapes[2], ElementsAre(2));
+
+ // Check outputs values.
+ const std::vector<std::vector<int32_t>>& output_datas =
+ model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 3);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 2));
+ EXPECT_THAT(output_datas[1], ElementsAre(3, 4));
+ EXPECT_THAT(output_datas[2], ElementsAre(5, 6));
+}
+
+TEST(UnpackOpTest, IntThreeOutputsAxisOne) {
+ UnpackOpModel<int32_t> model({TensorType_INT32, {3, 2}}, 1);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 2);
+ EXPECT_THAT(output_shapes[0], ElementsAre(3));
+ EXPECT_THAT(output_shapes[1], ElementsAre(3));
+
+ // Check outputs values.
+ const std::vector<std::vector<int32_t>>& output_datas =
+ model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 2);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5));
+ EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6));
+}
+
+TEST(UnpackOpTest, IntOneOutput) {
+ UnpackOpModel<int32_t> model({TensorType_INT32, {1, 6}}, 0);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 1);
+ EXPECT_THAT(output_shapes[0], ElementsAre(6));
+
+ // Check outputs values.
+ const std::vector<std::vector<int32_t>>& output_datas =
+ model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 1);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 2, 3, 4, 5, 6));
+}
+
+TEST(UnpackOpTest, IntThreeDimensionsOutputs) {
+ UnpackOpModel<int32_t> model({TensorType_INT32, {2, 2, 2}}, 2);
+ model.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 2);
+ EXPECT_THAT(output_shapes[0], ElementsAre(2, 2));
+ EXPECT_THAT(output_shapes[1], ElementsAre(2, 2));
+
+ // Check outputs values.
+ const std::vector<std::vector<int32_t>>& output_datas =
+ model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 2);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5, 7));
+ EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6, 8));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh b/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh
index b58ae26601..6195426d6d 100755
--- a/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh
+++ b/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh
@@ -14,6 +14,7 @@
# limitations under the License.
# ==============================================================================
+# TODO(ycling): Refactoring - Move this script into `tools/make`.
set -e
echo "Starting"
@@ -32,7 +33,7 @@ echo "Headers, populating: TensorFlow Lite"
cd $TFLITE_DIR/../../..
find tensorflow/contrib/lite -name '*.h' \
- -not -path 'tensorflow/contrib/lite/downloads/*' \
+ -not -path 'tensorflow/contrib/lite/tools/*' \
-not -path 'tensorflow/contrib/lite/examples/*' \
-not -path 'tensorflow/contrib/lite/gen/*' \
-not -path 'tensorflow/contrib/lite/toco/*' \
@@ -44,7 +45,7 @@ tar xf tmp.tar
rm -f tmp.tar
echo "Headers, populating: Flatbuffer"
-cd $TFLITE_DIR/downloads/flatbuffers/include/
+cd $TFLITE_DIR/tools/make/downloads/flatbuffers/include/
find . -name '*.h' | tar -cf $FW_DIR_TFLITE_HDRS/tmp.tar -T -
cd $FW_DIR_TFLITE_HDRS
tar xf tmp.tar
@@ -57,7 +58,7 @@ cp $TFLITE_DIR/../../../bazel-genfiles/tensorflow/tools/lib_package/include/tens
$FW_DIR_TFLITE
echo "Copying static libraries"
-cp $TFLITE_DIR/gen/lib/libtensorflow-lite.a \
+cp $TFLITE_DIR/tools/make/gen/lib/libtensorflow-lite.a \
$FW_DIR_TFLITE/tensorflow_lite
# This is required, otherwise they interfere with the documentation of the
diff --git a/tensorflow/contrib/lite/memory_planner.h b/tensorflow/contrib/lite/memory_planner.h
index 0294ec815c..2d4707f849 100644
--- a/tensorflow/contrib/lite/memory_planner.h
+++ b/tensorflow/contrib/lite/memory_planner.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_
#define TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/mmap_allocation.cc b/tensorflow/contrib/lite/mmap_allocation.cc
index fa9a3cd1d8..92934d1fd1 100644
--- a/tensorflow/contrib/lite/mmap_allocation.cc
+++ b/tensorflow/contrib/lite/mmap_allocation.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <unistd.h>
#include "tensorflow/contrib/lite/allocation.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 7b9413cd17..241865b3d8 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -20,8 +20,9 @@ limitations under the License.
#include <sys/types.h>
#include "tensorflow/contrib/lite/allocation.h"
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/contrib/lite/model.h"
#ifndef TFLITE_MCU
#include "tensorflow/contrib/lite/nnapi_delegate.h"
@@ -42,41 +43,6 @@ ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
const char* kEmptyTensorName = "";
-TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
- ErrorReporter* error_reporter) {
- switch (tensor_type) {
- case TensorType_FLOAT32:
- *type = kTfLiteFloat32;
- break;
- case TensorType_INT16:
- *type = kTfLiteInt16;
- break;
- case TensorType_INT32:
- *type = kTfLiteInt32;
- break;
- case TensorType_UINT8:
- *type = kTfLiteUInt8;
- break;
- case TensorType_INT64:
- *type = kTfLiteInt64;
- break;
- case TensorType_STRING:
- *type = kTfLiteString;
- break;
- case TensorType_BOOL:
- *type = kTfLiteBool;
- break;
- case TensorType_COMPLEX64:
- *type = kTfLiteComplex64;
- break;
- default:
- error_reporter->Report("Unimplemented data type %s (%d) in tensor\n",
- EnumNameTensorType(tensor_type), tensor_type);
- return kTfLiteError;
- }
- return kTfLiteOk;
-}
-
#ifndef TFLITE_MCU
// Loads a model from `filename`. If `mmap_file` is true then use mmap,
// otherwise make a copy of the model in a buffer.
@@ -198,39 +164,10 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
auto opcodes = model_->operator_codes();
for (const OperatorCode* opcode : *opcodes) {
const TfLiteRegistration* registration = nullptr;
- auto builtin_code = opcode->builtin_code();
- int version = opcode->version();
-
- if (builtin_code > BuiltinOperator_MAX ||
- builtin_code < BuiltinOperator_MIN) {
- error_reporter_->Report(
- "Op builtin_code out or range: %d. Are you using old TFLite binary "
- "with newer model?",
- builtin_code);
- status = kTfLiteError;
- } else if (builtin_code != BuiltinOperator_CUSTOM) {
- registration = op_resolver_.FindOp(builtin_code, version);
- if (registration == nullptr) {
- error_reporter_->Report(
- "Didn't find op for builtin opcode '%s' version '%d'\n",
- EnumNameBuiltinOperator(builtin_code), version);
- status = kTfLiteError;
- }
- } else if (!opcode->custom_code()) {
- error_reporter_->Report(
- "Operator with CUSTOM builtin_code has no custom_code.\n");
- status = kTfLiteError;
- } else {
- const char* name = opcode->custom_code()->c_str();
- registration = op_resolver_.FindOp(name, version);
- flatbuffer_op_index_to_registration_types_.push_back(
- BuiltinOperator_CUSTOM);
- if (registration == nullptr) {
- error_reporter_->Report(
- "Didn't find custom op for name '%s' with version %d\n", name,
- version);
- status = kTfLiteError;
- }
+ status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
+ &registration);
+ if (status != kTfLiteOk) {
+ return status;
}
flatbuffer_op_index_to_registration_.push_back(registration);
}
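The inlined opcode validation moves behind GetRegistrationFromOpCode (pulled
in through core/api/flatbuffer_conversions.h above). Note the behavioral
change: the old loop recorded kTfLiteError and kept scanning the remaining
opcodes, while the new code returns on the first opcode that fails to
resolve. The call shape, as used in this hunk:

    const TfLiteRegistration* registration = nullptr;
    TfLiteStatus status = GetRegistrationFromOpCode(
        opcode, op_resolver_, error_reporter_, &registration);
    if (status != kTfLiteOk) return status;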
@@ -247,559 +184,16 @@ std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {
return ret;
}
-// Copies the contents from the flatbuffer int vector `flatbuffer` into the
-// int array `buffer`. `flat_vector` and `buffer` represent the same
-// configuration operation for a given operation.
-void FlatBufferIntVectorToArray(int max_size_of_buffer,
- const flatbuffers::Vector<int32_t>* flat_vector,
- int* buffer, ErrorReporter* error_reporter) {
- if (!flat_vector) {
- error_reporter->Report("Input array not provided for operation.\n");
- } else {
- int num_dimensions = flat_vector->Length();
- if (num_dimensions > max_size_of_buffer / sizeof(int)) {
- error_reporter->Report(
- "Found too many dimensions in the operation's input array.\n");
- } else {
- for (int i = 0; i < num_dimensions; ++i) {
- buffer[i] = flat_vector->Get(i);
- }
- }
- }
-}
-
-// Allocate a structure using C malloc, but make sure the structure is a
-// POD structure that doesn't require constructors to run. The reason we do
-// this, is that Interpreter's C extension part will take ownership and wants
-// to use malloc() and free().
-template <class T>
-T* MallocPOD() {
- static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
- return static_cast<T*>(malloc(sizeof(T)));
-}
-
-// Parse the appropriate data out of the op.
-//
-// This handles builtin data explicitly as there are flatbuffer schemas.
-// If it returns kTfLiteOk, it passes the data out with `builtin_data`, which
-// need to be released by calling `free`.`
-// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
-TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
- ErrorReporter* error_reporter, void** builtin_data) {
- auto parse_padding = [](Padding padding) {
- switch (padding) {
- case Padding_SAME:
- return kTfLitePaddingSame;
- case Padding_VALID:
- return kTfLitePaddingValid;
- }
- return kTfLitePaddingUnknown;
- };
- auto parse_activation = [](ActivationFunctionType activation) {
- switch (activation) {
- case ActivationFunctionType_NONE:
- return kTfLiteActNone;
- case ActivationFunctionType_RELU:
- return kTfLiteActRelu;
- case ActivationFunctionType_RELU_N1_TO_1:
- return kTfLiteActRelu1;
- case ActivationFunctionType_RELU6:
- return kTfLiteActRelu6;
- case ActivationFunctionType_TANH:
- return kTfLiteActTanh;
- case ActivationFunctionType_SIGN_BIT:
- return kTfLiteActSignBit;
- }
- return kTfLiteActNone;
- };
- auto parseLSHProjectionType = [](LSHProjectionType type) {
- switch (type) {
- case LSHProjectionType_SPARSE:
- return kTfLiteLshProjectionSparse;
- case LSHProjectionType_DENSE:
- return kTfLiteLshProjectionDense;
- default:
- return kTfLiteLshProjectionUnknown;
- }
- };
- auto parseCombinerType = [](CombinerType type) {
- switch (type) {
- case CombinerType_MEAN:
- return kTfLiteCombinerTypeMean;
- case CombinerType_SQRTN:
- return kTfLiteCombinerTypeSqrtn;
- case CombinerType_SUM:
- default:
- return kTfLiteCombinerTypeSum;
- }
- };
-
- *builtin_data = nullptr;
- switch (op_type) {
- case BuiltinOperator_CONV_2D: {
- TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
- if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
- params->padding = parse_padding(conv_params->padding());
- params->stride_width = conv_params->stride_w();
- params->stride_height = conv_params->stride_h();
- params->activation =
- parse_activation(conv_params->fused_activation_function());
-
- params->dilation_width_factor = conv_params->dilation_w_factor();
- params->dilation_height_factor = conv_params->dilation_h_factor();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_CAST: {
- TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
- if (auto* schema_params = op->builtin_options_as_CastOptions()) {
- auto in_status =
- ConvertTensorType(schema_params->in_data_type(),
- &params->in_data_type, error_reporter);
- auto out_status =
- ConvertTensorType(schema_params->out_data_type(),
- &params->out_data_type, error_reporter);
- if (in_status != kTfLiteOk || out_status != kTfLiteOk) {
- free(params);
- return kTfLiteError;
- }
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_LSH_PROJECTION: {
- TfLiteLSHProjectionParams* params =
- MallocPOD<TfLiteLSHProjectionParams>();
- if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) {
- params->type = parseLSHProjectionType(lshParams->type());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_AVERAGE_POOL_2D:
- case BuiltinOperator_MAX_POOL_2D:
- case BuiltinOperator_L2_POOL_2D: {
- TfLitePoolParams* params = MallocPOD<TfLitePoolParams>();
- if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
- params->padding = parse_padding(pool_params->padding());
- params->stride_width = pool_params->stride_w();
- params->stride_height = pool_params->stride_h();
- params->filter_width = pool_params->filter_width();
- params->filter_height = pool_params->filter_height();
- params->activation =
- parse_activation(pool_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_DEPTHWISE_CONV_2D: {
- TfLiteDepthwiseConvParams* params =
- MallocPOD<TfLiteDepthwiseConvParams>();
- if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) {
- params->padding = parse_padding(conv_params->padding());
- params->stride_width = conv_params->stride_w();
- params->stride_height = conv_params->stride_h();
- params->depth_multiplier = conv_params->depth_multiplier();
- params->activation =
- parse_activation(conv_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SVDF: {
- TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>();
- if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) {
- params->rank = svdf_params->rank();
- params->activation =
- parse_activation(svdf_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
- case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
- TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>();
- if (auto* sequence_rnn_params =
- op->builtin_options_as_SequenceRNNOptions()) {
- params->activation =
- parse_activation(sequence_rnn_params->fused_activation_function());
- params->time_major = sequence_rnn_params->time_major();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_RNN: {
- TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>();
- if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
- params->activation =
- parse_activation(rnn_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
- TfLiteEmbeddingLookupSparseParams* params =
- MallocPOD<TfLiteEmbeddingLookupSparseParams>();
- if (auto* embedding_params =
- op->builtin_options_as_EmbeddingLookupSparseOptions()) {
- params->combiner = parseCombinerType(embedding_params->combiner());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_FULLY_CONNECTED: {
- TfLiteFullyConnectedParams* params =
- MallocPOD<TfLiteFullyConnectedParams>();
- if (auto* fully_connected_params =
- op->builtin_options_as_FullyConnectedOptions()) {
- params->activation = parse_activation(
- fully_connected_params->fused_activation_function());
- switch (fully_connected_params->weights_format()) {
- case FullyConnectedOptionsWeightsFormat_DEFAULT:
- params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault;
- break;
- case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
- params->weights_format =
- kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8;
- break;
- default:
- error_reporter->Report("Unhandled fully-connected weights format.");
- return kTfLiteError;
- }
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_HASHTABLE_LOOKUP:
- // no-op.
- break;
- case BuiltinOperator_SOFTMAX: {
- TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>();
- if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) {
- params->beta = softmax_params->beta();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_CONCATENATION: {
- TfLiteConcatenationParams* params =
- MallocPOD<TfLiteConcatenationParams>();
- if (auto* concatenation_params =
- op->builtin_options_as_ConcatenationOptions()) {
- params->activation =
- parse_activation(concatenation_params->fused_activation_function());
- params->axis = concatenation_params->axis();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_MUL: {
- auto* params = MallocPOD<TfLiteMulParams>();
- if (auto* schema_params = op->builtin_options_as_MulOptions()) {
- params->activation =
- parse_activation(schema_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_ADD: {
- auto* params = MallocPOD<TfLiteAddParams>();
- if (auto* schema_params = op->builtin_options_as_AddOptions()) {
- params->activation =
- parse_activation(schema_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_DIV: {
- auto* params = MallocPOD<TfLiteDivParams>();
- if (auto* schema_params = op->builtin_options_as_DivOptions()) {
- params->activation =
- parse_activation(schema_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SUB: {
- auto* params = MallocPOD<TfLiteSubParams>();
- if (auto* schema_params = op->builtin_options_as_SubOptions()) {
- params->activation =
- parse_activation(schema_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_L2_NORMALIZATION: {
- auto* params = MallocPOD<TfLiteL2NormParams>();
- if (auto* schema_params = op->builtin_options_as_L2NormOptions()) {
- params->activation =
- parse_activation(schema_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
- auto* params = MallocPOD<TfLiteLocalResponseNormParams>();
- if (auto* schema_params =
- op->builtin_options_as_LocalResponseNormalizationOptions()) {
- params->radius = schema_params->radius();
- params->bias = schema_params->bias();
- params->alpha = schema_params->alpha();
- params->beta = schema_params->beta();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
- case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
- case BuiltinOperator_LSTM: {
- TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>();
- if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
- params->activation =
- parse_activation(lstm_params->fused_activation_function());
- params->cell_clip = lstm_params->cell_clip();
- params->proj_clip = lstm_params->proj_clip();
- switch (lstm_params->kernel_type()) {
- case LSTMKernelType_FULL:
- params->kernel_type = kTfLiteLSTMFullKernel;
- break;
- case LSTMKernelType_BASIC:
- params->kernel_type = kTfLiteLSTMBasicKernel;
- break;
- }
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_RESIZE_BILINEAR: {
- auto* params = MallocPOD<TfLiteResizeBilinearParams>();
- if (auto* schema_params =
- op->builtin_options_as_ResizeBilinearOptions()) {
- params->align_corners = schema_params->align_corners();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_RESHAPE: {
- auto* params = MallocPOD<TfLiteReshapeParams>();
- if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
- auto* new_shape = schema_params->new_shape();
- FlatBufferIntVectorToArray(sizeof(params->shape), new_shape,
- params->shape, error_reporter);
- params->num_dimensions = new_shape->Length();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SKIP_GRAM: {
- TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>();
- if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) {
- params->ngram_size = skip_gram_params->ngram_size();
- params->max_skip_size = skip_gram_params->max_skip_size();
- params->include_all_ngrams = skip_gram_params->include_all_ngrams();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SPACE_TO_DEPTH: {
- auto* params = MallocPOD<TfLiteSpaceToDepthParams>();
- if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) {
- params->block_size = schema_params->block_size();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_GATHER: {
- TfLiteGatherParams* params = MallocPOD<TfLiteGatherParams>();
- params->axis = 0;
- if (auto* gather_params = op->builtin_options_as_GatherOptions()) {
- params->axis = gather_params->axis();
- }
-
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_MEAN:
- case BuiltinOperator_REDUCE_MAX:
- case BuiltinOperator_REDUCE_PROD:
- case BuiltinOperator_SUM: {
- auto* params = MallocPOD<TfLiteReducerParams>();
- if (auto* schema_params = op->builtin_options_as_ReducerOptions()) {
- params->keep_dims = schema_params->keep_dims();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SPLIT: {
- auto* params = MallocPOD<TfLiteSplitParams>();
- if (auto* schema_params = op->builtin_options_as_SplitOptions()) {
- params->num_splits = schema_params->num_splits();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SQUEEZE: {
- auto* params = MallocPOD<TfLiteSqueezeParams>();
- if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) {
- const auto& squeeze_dims = schema_params->squeeze_dims();
- FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims,
- params->squeeze_dims, error_reporter);
- params->num_squeeze_dims = squeeze_dims->Length();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_STRIDED_SLICE: {
- auto* params = MallocPOD<TfLiteStridedSliceParams>();
- if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) {
- params->begin_mask = schema_params->begin_mask();
- params->end_mask = schema_params->end_mask();
- params->ellipsis_mask = schema_params->ellipsis_mask();
- params->new_axis_mask = schema_params->new_axis_mask();
- params->shrink_axis_mask = schema_params->shrink_axis_mask();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_ARG_MAX: {
- auto* params = MallocPOD<TfLiteArgMaxParams>();
- if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) {
- ConvertTensorType(schema_params->output_type(), &params->output_type,
- error_reporter);
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_ARG_MIN: {
- auto* params = MallocPOD<TfLiteArgMinParams>();
- if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) {
- ConvertTensorType(schema_params->output_type(), &params->output_type,
- error_reporter);
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_TRANSPOSE_CONV: {
- TfLiteTransposeConvParams* params =
- MallocPOD<TfLiteTransposeConvParams>();
- if (auto* transpose_conv_params =
- op->builtin_options_as_TransposeConvOptions()) {
- params->padding = parse_padding(transpose_conv_params->padding());
- params->stride_width = transpose_conv_params->stride_w();
- params->stride_height = transpose_conv_params->stride_h();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SPARSE_TO_DENSE: {
- TfLiteSparseToDenseParams* params =
- MallocPOD<TfLiteSparseToDenseParams>();
- if (auto* sparse_to_dense_params =
- op->builtin_options_as_SparseToDenseOptions()) {
- params->validate_indices = sparse_to_dense_params->validate_indices();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SHAPE: {
- auto* params = MallocPOD<TfLiteShapeParams>();
- if (auto* schema_params = op->builtin_options_as_ShapeOptions()) {
- ConvertTensorType(schema_params->out_type(), &params->out_type,
- error_reporter);
- }
- *builtin_data = static_cast<void*>(params);
- break;
- }
- case BuiltinOperator_PACK: {
- TfLitePackParams* params = MallocPOD<TfLitePackParams>();
- if (auto* pack_params = op->builtin_options_as_PackOptions()) {
- params->values_count = pack_params->values_count();
- params->axis = pack_params->axis();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_DELEGATE: {
- // TODO(ycling): Revisit when supporting saving delegated models.
- error_reporter->Report("DELEGATE op shouldn't exist in model.");
- return kTfLiteError;
- }
- case BuiltinOperator_FAKE_QUANT: {
- auto* params = MallocPOD<TfLiteFakeQuantParams>();
- if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) {
- params->min = schema_params->min();
- params->max = schema_params->max();
- params->num_bits = schema_params->num_bits();
- params->narrow_range = schema_params->narrow_range();
- }
- *builtin_data = static_cast<void*>(params);
- break;
- }
- case BuiltinOperator_ONE_HOT: {
- auto* params = MallocPOD<TfLiteOneHotParams>();
- if (auto* schema_params = op->builtin_options_as_OneHotOptions()) {
- params->axis = schema_params->axis();
- }
- *builtin_data = static_cast<void*>(params);
- break;
- }
-
- // Below are the ops with no builtin_data strcture.
- case BuiltinOperator_BATCH_TO_SPACE_ND:
- // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
- // ok for now, since there is no call implementation either.
- case BuiltinOperator_CALL:
- case BuiltinOperator_CONCAT_EMBEDDINGS:
- case BuiltinOperator_CUSTOM:
- case BuiltinOperator_DEQUANTIZE:
- case BuiltinOperator_EMBEDDING_LOOKUP:
- case BuiltinOperator_EQUAL:
- case BuiltinOperator_EXP:
- case BuiltinOperator_EXPAND_DIMS:
- case BuiltinOperator_FLOOR:
- case BuiltinOperator_GREATER:
- case BuiltinOperator_GREATER_EQUAL:
- case BuiltinOperator_LESS:
- case BuiltinOperator_LESS_EQUAL:
- case BuiltinOperator_LOG:
- case BuiltinOperator_LOGISTIC:
- case BuiltinOperator_LOG_SOFTMAX:
- case BuiltinOperator_MAXIMUM:
- case BuiltinOperator_MINIMUM:
- case BuiltinOperator_NEG:
- case BuiltinOperator_NOT_EQUAL:
- case BuiltinOperator_PAD:
- case BuiltinOperator_PADV2:
- case BuiltinOperator_PRELU:
- case BuiltinOperator_RELU:
- case BuiltinOperator_RELU6:
- case BuiltinOperator_RELU_N1_TO_1:
- case BuiltinOperator_RSQRT:
- case BuiltinOperator_SELECT:
- case BuiltinOperator_SIN:
- case BuiltinOperator_SLICE:
- case BuiltinOperator_SPACE_TO_BATCH_ND:
- case BuiltinOperator_SQRT:
- case BuiltinOperator_TANH:
- case BuiltinOperator_TILE:
- case BuiltinOperator_TOPK_V2:
- case BuiltinOperator_TRANSPOSE:
- case BuiltinOperator_POW:
- case BuiltinOperator_LOGICAL_OR:
- case BuiltinOperator_LOGICAL_AND:
- case BuiltinOperator_LOGICAL_NOT:
- break;
- }
- return kTfLiteOk;
-}
-
} // namespace
TfLiteStatus InterpreterBuilder::ParseNodes(
const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
Interpreter* interpreter) {
TfLiteStatus status = kTfLiteOk;
+
+  // Reduce the number of redundant allocations.
+ interpreter->ReserveNodes(operators->Length());
+
for (int i = 0; i < operators->Length(); ++i) {
const auto* op = operators->Get(i);
int index = op->opcode_index();
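The `ReserveNodes` call added above sizes the interpreter's node storage once from `operators->Length()` rather than growing it per parsed node. A rough, language-agnostic sketch of that reserve-then-fill pattern (illustrative only; not the TFLite API):

```python
# Minimal sketch of the reserve-then-fill pattern: allocate capacity for
# all nodes once, then populate in place, so the container never
# reallocates mid-loop.
num_ops = 1000
nodes = [None] * num_ops    # single up-front allocation
for i in range(num_ops):
    nodes[i] = ("node", i)  # fill without growing the list
assert len(nodes) == num_ops
```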
diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h
index 8bc9ecd7ce..6abdfcd079 100644
--- a/tensorflow/contrib/lite/model.h
+++ b/tensorflow/contrib/lite/model.h
@@ -35,9 +35,10 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_MODEL_H_
#include <memory>
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
#include "tensorflow/contrib/lite/interpreter.h"
-#include "tensorflow/contrib/lite/op_resolver.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc
index df4f60d4ad..ec7d46af7c 100644
--- a/tensorflow/contrib/lite/model_test.cc
+++ b/tensorflow/contrib/lite/model_test.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/model.h"
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/testing/util.h"
// Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object,
diff --git a/tensorflow/contrib/lite/models/speech_test.cc b/tensorflow/contrib/lite/models/speech_test.cc
index 206de1962d..8ecf0b6154 100644
--- a/tensorflow/contrib/lite/models/speech_test.cc
+++ b/tensorflow/contrib/lite/models/speech_test.cc
@@ -102,7 +102,7 @@ class SpeechTest : public ::testing::TestWithParam<int> {
int GetMaxInvocations() { return GetParam(); }
};
-TEST_P(SpeechTest, HotwordOkGoogleRank1Test) {
+TEST_P(SpeechTest, DISABLED_HotwordOkGoogleRank1Test) {
std::stringstream os;
ASSERT_TRUE(ConvertCsvData(
"speech_hotword_model_rank1.tflite", "speech_hotword_model_in.csv",
@@ -114,7 +114,7 @@ TEST_P(SpeechTest, HotwordOkGoogleRank1Test) {
<< test_driver.GetErrorMessage();
}
-TEST_P(SpeechTest, HotwordOkGoogleRank2Test) {
+TEST_P(SpeechTest, DISABLED_HotwordOkGoogleRank2Test) {
std::stringstream os;
ASSERT_TRUE(ConvertCsvData(
"speech_hotword_model_rank2.tflite", "speech_hotword_model_in.csv",
@@ -126,7 +126,7 @@ TEST_P(SpeechTest, HotwordOkGoogleRank2Test) {
<< test_driver.GetErrorMessage();
}
-TEST_P(SpeechTest, SpeakerIdOkGoogleTest) {
+TEST_P(SpeechTest, DISABLED_SpeakerIdOkGoogleTest) {
std::stringstream os;
ASSERT_TRUE(ConvertCsvData(
"speech_speakerid_model.tflite", "speech_speakerid_model_in.csv",
@@ -139,7 +139,7 @@ TEST_P(SpeechTest, SpeakerIdOkGoogleTest) {
<< test_driver.GetErrorMessage();
}
-TEST_P(SpeechTest, AsrAmTest) {
+TEST_P(SpeechTest, DISABLED_AsrAmTest) {
std::stringstream os;
ASSERT_TRUE(
ConvertCsvData("speech_asr_am_model.tflite", "speech_asr_am_model_in.csv",
@@ -156,7 +156,7 @@ TEST_P(SpeechTest, AsrAmTest) {
// through the interpreter and stored the sum of all the output, which was then
// compared for correctness. In this test we are comparing all the intermediate
// results.
-TEST_P(SpeechTest, AsrLmTest) {
+TEST_P(SpeechTest, DISABLED_AsrLmTest) {
std::ifstream in_file;
testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
ASSERT_TRUE(Init("speech_asr_lm_model.test_spec", &test_driver, &in_file));
@@ -165,7 +165,7 @@ TEST_P(SpeechTest, AsrLmTest) {
<< test_driver.GetErrorMessage();
}
-TEST_P(SpeechTest, EndpointerTest) {
+TEST_P(SpeechTest, DISABLED_EndpointerTest) {
std::stringstream os;
ASSERT_TRUE(ConvertCsvData(
"speech_endpointer_model.tflite", "speech_endpointer_model_in.csv",
@@ -178,7 +178,7 @@ TEST_P(SpeechTest, EndpointerTest) {
<< test_driver.GetErrorMessage();
}
-TEST_P(SpeechTest, TtsTest) {
+TEST_P(SpeechTest, DISABLED_TtsTest) {
std::stringstream os;
ASSERT_TRUE(ConvertCsvData("speech_tts_model.tflite",
"speech_tts_model_in.csv",
diff --git a/tensorflow/contrib/lite/op_resolver.cc b/tensorflow/contrib/lite/mutable_op_resolver.cc
index f6e435e982..8ee63d2a02 100644
--- a/tensorflow/contrib/lite/op_resolver.cc
+++ b/tensorflow/contrib/lite/mutable_op_resolver.cc
@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/op_resolver.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/mutable_op_resolver.h b/tensorflow/contrib/lite/mutable_op_resolver.h
new file mode 100644
index 0000000000..c319041e9b
--- /dev/null
+++ b/tensorflow/contrib/lite/mutable_op_resolver.h
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_
+#define TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_
+
+#include <unordered_map>
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+#include "tensorflow/contrib/lite/util.h"
+
+namespace tflite {
+
+// Some versions of gcc don't support partial specialization in class scope,
+// so these are defined at namespace scope.
+namespace op_resolver_hasher {
+template <typename V>
+struct ValueHasher {
+ size_t operator()(const V& v) const { return std::hash<V>()(v); }
+};
+
+template <>
+struct ValueHasher<tflite::BuiltinOperator> {
+ size_t operator()(const tflite::BuiltinOperator& v) const {
+ return std::hash<int>()(static_cast<int>(v));
+ }
+};
+
+template <typename T>
+struct OperatorKeyHasher {
+ size_t operator()(const T& x) const {
+ size_t a = ValueHasher<typename T::first_type>()(x.first);
+ size_t b = ValueHasher<typename T::second_type>()(x.second);
+ return CombineHashes({a, b});
+ }
+};
+} // namespace op_resolver_hasher
+
+// A mutable OpResolver, also used as the op resolver in gen_op_registration.
+// A typical usage:
+// MutableOpResolver resolver;
+// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD());
+// resolver.AddCustom("CustomOp", Register_CUSTOM_OP());
+// InterpreterBuilder(model, resolver)(&interpreter);
+class MutableOpResolver : public OpResolver {
+ public:
+ const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
+ int version) const override;
+ const TfLiteRegistration* FindOp(const char* op, int version) const override;
+ void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
+ int min_version = 1, int max_version = 1);
+ void AddCustom(const char* name, TfLiteRegistration* registration,
+ int min_version = 1, int max_version = 1);
+
+ private:
+ typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey;
+ typedef std::pair<std::string, int> CustomOperatorKey;
+
+ std::unordered_map<BuiltinOperatorKey, TfLiteRegistration,
+ op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey> >
+ builtins_;
+ std::unordered_map<CustomOperatorKey, TfLiteRegistration,
+ op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey> >
+ custom_ops_;
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_
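The pair-keyed tables above are the whole lookup story: registrations are found by (operator, version). A hedged Python sketch of that registry shape; the assumption that `AddBuiltin` registers each version in `[min_version, max_version]` is mine, since the matching .cc changes are not shown here:

```python
# Hypothetical sketch of MutableOpResolver's (operator, version)-keyed
# registry. Python dicts hash tuple keys natively, which is the role
# OperatorKeyHasher plays for the C++ unordered_map above.
builtins = {}

def add_builtin(op, registration, min_version=1, max_version=1):
    # Assumption: one entry per supported version in the range.
    for version in range(min_version, max_version + 1):
        builtins[(op, version)] = registration

def find_op(op, version):
    return builtins.get((op, version))  # None when unregistered

add_builtin("ADD", "Register_ADD", min_version=1, max_version=2)
assert find_op("ADD", 2) == "Register_ADD"
assert find_op("ADD", 3) is None
```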
diff --git a/tensorflow/contrib/lite/op_resolver_test.cc b/tensorflow/contrib/lite/mutable_op_resolver_test.cc
index 10b7e31972..db690eaab9 100644
--- a/tensorflow/contrib/lite/op_resolver_test.cc
+++ b/tensorflow/contrib/lite/mutable_op_resolver_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/op_resolver.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/testing/util.h"
diff --git a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
index 42b8163445..81dd459223 100644
--- a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
+++ b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef NN_API_SHIM_H0
-#define NN_API_SHIM_H0
+#ifndef TENSORFLOW_CONTRIB_LITE_NNAPI_NEURALNETWORKSSHIM_H_
+#define TENSORFLOW_CONTRIB_LITE_NNAPI_NEURALNETWORKSSHIM_H_
#include <dlfcn.h>
#include <stdint.h>
@@ -970,4 +970,4 @@ inline void ANeuralNetworksEvent_free(ANeuralNetworksEvent* event) {
/**/
-#endif // NN_API_SHIM_H0
+#endif // TENSORFLOW_CONTRIB_LITE_NNAPI_NEURALNETWORKSSHIM_H_
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 45c92a8671..817486e898 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -18,8 +18,8 @@ limitations under the License.
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h"
@@ -64,6 +64,14 @@ void logError(const char* format, ...) {
__LINE__); \
}
+#define RETURN_ERROR_IF_TFLITE_FAILED(x) \
+ if (x != kTfLiteOk) { \
+ logError( \
+ "Returning error since TFLite returned failure nnapi_delegate.cc:%d.", \
+ __LINE__); \
+ return kTfLiteError; \
+ }
+
#define RETURN_ERROR_IF_NN_FAILED(x) \
if (x != ANEURALNETWORKS_NO_ERROR) { \
logError( \
@@ -98,7 +106,10 @@ int32_t GetAndroidSdkVersion() {
return 0;
}
-static const int32_t kAndroidSdkVersion = GetAndroidSdkVersion();
+int32_t GetAndroidSdkVersionCached() {
+ static int32_t androidSdkVersion = GetAndroidSdkVersion();
+ return androidSdkVersion;
+}
} // namespace
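Moving `kAndroidSdkVersion` from a namespace-scope static into `GetAndroidSdkVersionCached()` defers the lookup to first use and avoids static-initialization-order hazards. A rough Python 3 analogue of that initialize-on-first-use pattern, with a stubbed reader so the sketch stays self-contained:

```python
import functools

@functools.lru_cache(maxsize=None)
def get_android_sdk_version_cached():
    # Stand-in for GetAndroidSdkVersion(), which presumably reads the
    # device's SDK-version system property; a constant keeps this runnable.
    return 28

# The first call computes and caches the value; later calls reuse it,
# mirroring the C++ function-local static above.
assert get_android_sdk_version_cached() == 28
```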
@@ -296,17 +307,21 @@ TfLiteStatus AddOpsAndParams(
};
auto check_and_add_activation = [&add_scalar_int32](int activation) {
if (activation > kTfLiteActRelu6) {
- FATAL("NNAPI only supports RELU, RELU1 and RELU6 activations");
+ logError("NNAPI only supports RELU, RELU1 and RELU6 activations");
+ return kTfLiteError;
}
add_scalar_int32(activation);
+ return kTfLiteOk;
};
auto add_add_params = [&add_scalar_int32](void* data) {
auto* builtin = reinterpret_cast<TfLiteAddParams*>(data);
if (builtin->activation > kTfLiteActRelu6) {
- FATAL("NNAPI only supports RELU, RELU1 and RELU6 activations");
+ logError("NNAPI only supports RELU, RELU1 and RELU6 activations");
+ return kTfLiteError;
}
add_scalar_int32(builtin->activation);
+ return kTfLiteOk;
};
auto add_pooling_params = [&add_scalar_int32,
@@ -317,7 +332,7 @@ TfLiteStatus AddOpsAndParams(
add_scalar_int32(builtin->stride_height);
add_scalar_int32(builtin->filter_width);
add_scalar_int32(builtin->filter_height);
- check_and_add_activation(builtin->activation);
+ return check_and_add_activation(builtin->activation);
};
auto add_convolution_params = [&add_scalar_int32,
@@ -326,7 +341,7 @@ TfLiteStatus AddOpsAndParams(
add_scalar_int32(builtin->padding);
add_scalar_int32(builtin->stride_width);
add_scalar_int32(builtin->stride_height);
- check_and_add_activation(builtin->activation);
+ return check_and_add_activation(builtin->activation);
};
auto add_depthwise_conv_params = [&add_scalar_int32,
@@ -336,20 +351,22 @@ TfLiteStatus AddOpsAndParams(
add_scalar_int32(builtin->stride_width);
add_scalar_int32(builtin->stride_height);
add_scalar_int32(builtin->depth_multiplier);
- check_and_add_activation(builtin->activation);
+ return check_and_add_activation(builtin->activation);
};
auto add_fully_connected_params = [&check_and_add_activation](void* data) {
auto builtin = reinterpret_cast<TfLiteFullyConnectedParams*>(data);
- check_and_add_activation(builtin->activation);
+ return check_and_add_activation(builtin->activation);
};
auto add_concatenation_params = [&add_scalar_int32](void* data) {
auto builtin = reinterpret_cast<TfLiteConcatenationParams*>(data);
add_scalar_int32(builtin->axis);
if (builtin->activation != kTfLiteActNone) {
- FATAL("Concatenation does not support fused activation in NNAPI");
+ logError("Concatenation does not support fused activation in NNAPI");
+ return kTfLiteError;
}
+ return kTfLiteOk;
};
auto add_softmax_params = [&add_scalar_float32](void* data) {
@@ -430,22 +447,22 @@ TfLiteStatus AddOpsAndParams(
switch (builtin) {
case tflite::BuiltinOperator_ADD:
nn_op_type = ANEURALNETWORKS_ADD;
- add_add_params(node.builtin_data);
+ RETURN_ERROR_IF_TFLITE_FAILED(add_add_params(node.builtin_data));
break;
case tflite::BuiltinOperator_MUL:
nn_op_type = ANEURALNETWORKS_MUL;
- add_add_params(node.builtin_data);
+ RETURN_ERROR_IF_TFLITE_FAILED(add_add_params(node.builtin_data));
break;
case tflite::BuiltinOperator_AVERAGE_POOL_2D:
- add_pooling_params(node.builtin_data);
+ RETURN_ERROR_IF_TFLITE_FAILED(add_pooling_params(node.builtin_data));
nn_op_type = ANEURALNETWORKS_AVERAGE_POOL_2D;
break;
case tflite::BuiltinOperator_MAX_POOL_2D:
- add_pooling_params(node.builtin_data);
+ RETURN_ERROR_IF_TFLITE_FAILED(add_pooling_params(node.builtin_data));
nn_op_type = ANEURALNETWORKS_MAX_POOL_2D;
break;
case tflite::BuiltinOperator_L2_POOL_2D:
- add_pooling_params(node.builtin_data);
+ RETURN_ERROR_IF_TFLITE_FAILED(add_pooling_params(node.builtin_data));
nn_op_type = ANEURALNETWORKS_L2_POOL_2D;
break;
case tflite::BuiltinOperator_CONV_2D: {
@@ -456,7 +473,8 @@ TfLiteStatus AddOpsAndParams(
return kTfLiteError;
}
}
- add_convolution_params(node.builtin_data);
+ RETURN_ERROR_IF_TFLITE_FAILED(
+ add_convolution_params(node.builtin_data));
nn_op_type = ANEURALNETWORKS_CONV_2D;
break;
case tflite::BuiltinOperator_RELU:
@@ -475,11 +493,13 @@ TfLiteStatus AddOpsAndParams(
nn_op_type = ANEURALNETWORKS_LOGISTIC;
break;
case tflite::BuiltinOperator_DEPTHWISE_CONV_2D:
- add_depthwise_conv_params(node.builtin_data);
+ RETURN_ERROR_IF_TFLITE_FAILED(
+ add_depthwise_conv_params(node.builtin_data));
nn_op_type = ANEURALNETWORKS_DEPTHWISE_CONV_2D;
break;
case tflite::BuiltinOperator_CONCATENATION:
- add_concatenation_params(node.builtin_data);
+ RETURN_ERROR_IF_TFLITE_FAILED(
+ add_concatenation_params(node.builtin_data));
nn_op_type = ANEURALNETWORKS_CONCATENATION;
break;
case tflite::BuiltinOperator_SOFTMAX:
@@ -487,7 +507,8 @@ TfLiteStatus AddOpsAndParams(
nn_op_type = ANEURALNETWORKS_SOFTMAX;
break;
case tflite::BuiltinOperator_FULLY_CONNECTED:
- add_fully_connected_params(node.builtin_data);
+ RETURN_ERROR_IF_TFLITE_FAILED(
+ add_fully_connected_params(node.builtin_data));
nn_op_type = ANEURALNETWORKS_FULLY_CONNECTED;
break;
case tflite::BuiltinOperator_RESHAPE:
@@ -541,14 +562,14 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_DIV:
nnapi_version = 11; // require NNAPI 1.1
nn_op_type = ANEURALNETWORKS_DIV;
- check_and_add_activation(
- reinterpret_cast<TfLiteDivParams*>(node.builtin_data)->activation);
+ RETURN_ERROR_IF_TFLITE_FAILED(check_and_add_activation(
+ reinterpret_cast<TfLiteDivParams*>(node.builtin_data)->activation));
break;
case tflite::BuiltinOperator_SUB:
nnapi_version = 11; // require NNAPI 1.1
nn_op_type = ANEURALNETWORKS_SUB;
- check_and_add_activation(
- reinterpret_cast<TfLiteSubParams*>(node.builtin_data)->activation);
+ RETURN_ERROR_IF_TFLITE_FAILED(check_and_add_activation(
+ reinterpret_cast<TfLiteSubParams*>(node.builtin_data)->activation));
break;
case tflite::BuiltinOperator_SQUEEZE:
nnapi_version = 11; // requires NNAPI 1.1
@@ -636,6 +657,7 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_NOT_EQUAL:
case tflite::BuiltinOperator_SUM:
case tflite::BuiltinOperator_REDUCE_MAX:
+ case tflite::BuiltinOperator_REDUCE_MIN:
case tflite::BuiltinOperator_REDUCE_PROD:
case tflite::BuiltinOperator_SQRT:
case tflite::BuiltinOperator_RSQRT:
@@ -647,6 +669,9 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_ONE_HOT:
case tflite::BuiltinOperator_LOGICAL_AND:
case tflite::BuiltinOperator_LOGICAL_NOT:
+ case tflite::BuiltinOperator_UNPACK:
+ case tflite::BuiltinOperator_FLOOR_DIV:
+ case tflite::BuiltinOperator_REDUCE_ANY:
logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError;
break;
@@ -656,8 +681,9 @@ TfLiteStatus AddOpsAndParams(
break;
}
- if (nnapi_version == 11 && kAndroidSdkVersion < 28) {
- FATAL("Op %d needs NNAPI1.1", builtin);
+ if (nnapi_version == 11 && GetAndroidSdkVersionCached() < 28) {
+ logError("Op %d needs NNAPI1.1", builtin);
+ return kTfLiteError;
}
// Add the operation.
@@ -705,9 +731,9 @@ TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) {
interpreter->outputs().size());
uint32_t next_id = 0;
- RETURN_ERROR_IF_NN_FAILED(addTensorOperands(
+ RETURN_ERROR_IF_TFLITE_FAILED(addTensorOperands(
interpreter, nn_model_, &next_id, &tensor_id_to_nnapi_id));
- RETURN_ERROR_IF_NN_FAILED(
+ RETURN_ERROR_IF_TFLITE_FAILED(
AddOpsAndParams(interpreter, nn_model_, next_id, &model_states_inputs_,
&model_states_outputs_, tensor_id_to_nnapi_id));
diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h
index 2bdb2cc5c8..22359d557e 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.h
+++ b/tensorflow/contrib/lite/nnapi_delegate.h
@@ -16,8 +16,8 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_
#include "tensorflow/contrib/lite/allocation.h"
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/interpreter.h"
class ANeuralNetworksModel;
diff --git a/tensorflow/contrib/lite/nnapi_delegate_disabled.cc b/tensorflow/contrib/lite/nnapi_delegate_disabled.cc
index efde72b1a7..e3536d3db6 100644
--- a/tensorflow/contrib/lite/nnapi_delegate_disabled.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate_disabled.cc
@@ -27,7 +27,13 @@ NNAPIAllocation::NNAPIAllocation(const char* filename,
NNAPIAllocation::~NNAPIAllocation() {}
-NNAPIDelegate::~NNAPIDelegate() {}
+NNAPIDelegate::~NNAPIDelegate() {
+#define UNUSED_MEMBER(x) (void)(x)
+ UNUSED_MEMBER(nn_model_);
+ UNUSED_MEMBER(nn_compiled_model_);
+ UNUSED_MEMBER(model_status_);
+#undef UNUSED_MEMBER
+}
TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) {
return kTfLiteError;
diff --git a/tensorflow/contrib/lite/op_resolver.h b/tensorflow/contrib/lite/op_resolver.h
index 9d7e3f2085..e93134cbde 100644
--- a/tensorflow/contrib/lite/op_resolver.h
+++ b/tensorflow/contrib/lite/op_resolver.h
@@ -12,83 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// Compatibility shim for moved header location.
#ifndef TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_
#define TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_
-#include <unordered_map>
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/schema/schema_generated.h"
-#include "tensorflow/contrib/lite/util.h"
-
-namespace tflite {
-
-// Abstract interface that returns TfLiteRegistrations given op codes or custom
-// op names. This is the mechanism that ops being referenced in the flatbuffer
-// model are mapped to executable function pointers (TfLiteRegistrations).
-class OpResolver {
- public:
- // Finds the op registration for a builtin operator by enum code.
- virtual const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
- int version) const = 0;
- // Finds the op registration of a custom operator by op name.
- virtual const TfLiteRegistration* FindOp(const char* op,
- int version) const = 0;
- virtual ~OpResolver() {}
-};
-
-// Some versions of gcc don't support partial specialization in class scope,
-// so these are defined at namespace scope.
-namespace op_resolver_hasher {
-template <typename V>
-struct ValueHasher {
- size_t operator()(const V& v) const { return std::hash<V>()(v); }
-};
-
-template <>
-struct ValueHasher<tflite::BuiltinOperator> {
- size_t operator()(const tflite::BuiltinOperator& v) const {
- return std::hash<int>()(static_cast<int>(v));
- }
-};
-
-template <typename T>
-struct OperatorKeyHasher {
- size_t operator()(const T& x) const {
- size_t a = ValueHasher<typename T::first_type>()(x.first);
- size_t b = ValueHasher<typename T::second_type>()(x.second);
- return CombineHashes({a, b});
- }
-};
-} // namespace op_resolver_hasher
-
-// A mutable OpResolver, also used as the op resolver in gen_op_registration.
-// A typical usage:
-// MutableOpResolver resolver;
-// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD());
-// resolver.AddCustom("CustomOp", Register_CUSTOM_OP());
-// InterpreterBuilder(model, resolver)(&interpreter);
-class MutableOpResolver : public OpResolver {
- public:
- const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
- int version) const override;
- const TfLiteRegistration* FindOp(const char* op, int version) const override;
- void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
- int min_version = 1, int max_version = 1);
- void AddCustom(const char* name, TfLiteRegistration* registration,
- int min_version = 1, int max_version = 1);
-
- private:
- typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey;
- typedef std::pair<std::string, int> CustomOperatorKey;
-
- std::unordered_map<BuiltinOperatorKey, TfLiteRegistration,
- op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey> >
- builtins_;
- std::unordered_map<CustomOperatorKey, TfLiteRegistration,
- op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey> >
- custom_ops_;
-};
-
-} // namespace tflite
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
#endif // TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_
diff --git a/tensorflow/contrib/lite/optional_debug_tools.h b/tensorflow/contrib/lite/optional_debug_tools.h
index 7fb4b8d8b7..82a6e114a6 100644
--- a/tensorflow/contrib/lite/optional_debug_tools.h
+++ b/tensorflow/contrib/lite/optional_debug_tools.h
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
// Optional debugging functionality. For small sized binaries, these are not
// needed.
-#ifndef TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_
-#define TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_OPTIONAL_DEBUG_TOOLS_H_
+#define TENSORFLOW_CONTRIB_LITE_OPTIONAL_DEBUG_TOOLS_H_
#include "tensorflow/contrib/lite/interpreter.h"
@@ -26,4 +26,4 @@ void PrintInterpreterState(Interpreter* interpreter);
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_
+#endif // TENSORFLOW_CONTRIB_LITE_OPTIONAL_DEBUG_TOOLS_H_
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 860aff9e7e..57e1290e07 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -70,7 +70,7 @@ py_library(
py_test(
name = "lite_test",
srcs = ["lite_test.py"],
- data = [":interpreter_test_data"],
+ data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
@@ -112,8 +112,11 @@ py_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python:framework",
"//tensorflow/python:platform",
+ "//tensorflow/python:util",
],
)
@@ -127,6 +130,7 @@ py_test(
],
deps = [
":convert",
+ ":interpreter",
":op_hint",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index 11d4bdbe82..1c5516ae7c 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import os as _os
+import platform as _platform
import subprocess as _subprocess
import tempfile as _tempfile
@@ -26,6 +27,7 @@ from tensorflow.contrib.lite.python import lite_constants
from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2
from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
from tensorflow.python.platform import resource_loader as _resource_loader
+from tensorflow.python.util import deprecation
from tensorflow.python.util.lazy_loader import LazyLoader
@@ -90,12 +92,13 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
fp_output.name
]
cmdline = " ".join(cmd)
+ is_windows = _platform.system() == "Windows"
proc = _subprocess.Popen(
cmdline,
shell=True,
stdout=_subprocess.PIPE,
stderr=_subprocess.STDOUT,
- close_fds=True)
+ close_fds=not is_windows)
stdout, stderr = proc.communicate()
exitcode = proc.returncode
if exitcode == 0:
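The `close_fds` toggle above works around a `subprocess` limitation: under Python 2 on Windows, `Popen` raises `ValueError` when `close_fds=True` is combined with stdout/stderr redirection. A minimal sketch of the guarded pattern:

```python
import platform
import subprocess

# On Windows (under Python 2), Popen cannot redirect stdout/stderr while
# close_fds=True, so the flag is only enabled on other platforms.
is_windows = platform.system() == "Windows"
proc = subprocess.Popen(
    "echo hello",
    shell=True,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    close_fds=not is_windows)
stdout, _ = proc.communicate()
print(stdout)
```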
@@ -123,7 +126,7 @@ def build_toco_convert_protos(input_tensors,
reorder_across_fake_quant=False,
allow_custom_ops=False,
change_concat_input_ranges=False,
- quantize_weights=False,
+ post_training_quantize=False,
dump_graphviz_dir=None,
dump_graphviz_video=False):
"""Builds protocol buffers describing a conversion of a model using TOCO.
@@ -146,9 +149,11 @@ def build_toco_convert_protos(input_tensors,
as `input_tensors`, or None. (default None)
output_format: Output file format. Currently must be `{TFLITE,
GRAPHVIZ_DOT}`. (default TFLITE)
- quantized_input_stats: List of tuples of integers representing the mean and
+ quantized_input_stats: List of tuples of floats representing the mean and
standard deviation. Each tuple maps to the corresponding input tensor.
- Only need if `inference_type` is `QUANTIZED_UINT8`. (default None)
+      Only needed if `inference_input_type` is `QUANTIZED_UINT8`.
+ real_input_value = (quantized_input_value - mean_value) / std_dev_value.
+ (default None)
default_ranges_stats: Tuple of integers representing (min, max) range values
for all arrays without a specified range. Intended for experimenting with
quantization via "dummy quantization". (default None)
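The (mean, std_dev) stats documented above invert the input quantization; a tiny worked sketch of the stated formula:

```python
def dequantize(quantized_input_value, mean_value, std_dev_value):
    # real_input_value = (quantized_input_value - mean_value) / std_dev_value,
    # exactly as documented for quantized_input_stats above.
    return (quantized_input_value - mean_value) / std_dev_value

# With stats of (0., 1.) -- the values used in the tests in this patch --
# the mapping is the identity: uint8 value 127 maps to 127.0.
assert dequantize(127, 0., 1.) == 127.0
# With stats of (128., 127.), the uint8 range maps to roughly [-1, 1].
assert abs(dequantize(255, 128., 127.) - 1.0) < 1e-6
```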
@@ -168,9 +173,9 @@ def build_toco_convert_protos(input_tensors,
change_concat_input_ranges: Boolean to change behavior of min/max ranges for
inputs and outputs of the concat operator for quantized models. Changes
the ranges of concat operator overlap when true. (default False)
- quantize_weights: Boolean indicating whether to store weights as quantized
- weights followed by dequantize operations. Computation is still done in
- float, but reduces model size (at the cost of accuracy and latency).
+ post_training_quantize: Boolean indicating whether to quantize the weights
+ of the converted float model. Model size will be reduced and there will be
+ latency improvements (at the cost of accuracy).
(default False)
dump_graphviz_dir: Full filepath of folder to dump the graphs at various
stages of processing GraphViz .dot files. Preferred over
@@ -194,10 +199,12 @@ def build_toco_convert_protos(input_tensors,
toco.inference_type = inference_type
if inference_input_type:
toco.inference_input_type = inference_input_type
+ else:
+ toco.inference_input_type = toco.inference_type
toco.drop_control_dependency = drop_control_dependency
toco.reorder_across_fake_quant = reorder_across_fake_quant
toco.allow_custom_ops = allow_custom_ops
- toco.quantize_weights = quantize_weights
+ toco.post_training_quantize = post_training_quantize
if default_ranges_stats:
toco.default_ranges_min = default_ranges_stats[0]
toco.default_ranges_max = default_ranges_stats[1]
@@ -209,7 +216,7 @@ def build_toco_convert_protos(input_tensors,
model.change_concat_input_ranges = change_concat_input_ranges
for idx, input_tensor in enumerate(input_tensors):
input_array = model.input_arrays.add()
- if inference_type == lite_constants.QUANTIZED_UINT8:
+ if toco.inference_input_type == lite_constants.QUANTIZED_UINT8:
input_array.mean_value, input_array.std_value = quantized_input_stats[idx]
input_array.name = tensor_name(input_tensor)
if input_shapes is None:
@@ -223,7 +230,56 @@ def build_toco_convert_protos(input_tensors,
return model, toco
-def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
+def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
+ *args, **kwargs):
+  """Convert a model using TOCO.
+
+ This function is used to convert GraphDefs that cannot be loaded into
+  This function converts GraphDefs that cannot be loaded into TensorFlow
+  into TFLite models. Conversion can be customized by providing arguments
+ details).
+
+ Args:
+    input_data: Input data (i.e., often `sess.graph_def`).
+    input_arrays_with_shape: List of tuples pairing input tensor names with
+      input shapes (e.g., [("foo", [1, 16, 16, 3])]). Use only when the graph
+      cannot be loaded into TensorFlow and when `input_tensors` is None.
+      (default None)
+    output_arrays: List of output tensors to freeze the graph with. Use only
+      when the graph cannot be loaded into TensorFlow and when
+      `output_tensors` is None. (default None)
+    *args: See `build_toco_convert_protos`.
+    **kwargs: See `build_toco_convert_protos`.
+
+ Returns:
+    The converted data. For example, if TFLite is the destination format,
+    this will be a TFLite flatbuffer in a bytes array.
+
+ Raises:
+ Defined in `build_toco_convert_protos`.
+ """
+ model_flags, toco_flags = build_toco_convert_protos(
+ input_tensors=[], output_tensors=[], *args, **kwargs)
+
+ for idx, (name, shape) in enumerate(input_arrays_with_shape):
+ input_array = model_flags.input_arrays.add()
+ if kwargs["inference_type"] == lite_constants.QUANTIZED_UINT8:
+ input_array.mean_value, input_array.std_value = kwargs[
+ "quantized_input_stats"][idx]
+ input_array.name = name
+ input_array.shape.dims.extend(map(int, shape))
+
+ for name in output_arrays:
+ model_flags.output_arrays.append(name)
+
+ data = toco_convert_protos(model_flags.SerializeToString(),
+ toco_flags.SerializeToString(),
+ input_data.SerializeToString())
+ return data
+
+
+def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
+ **kwargs):
""""Convert a model using TOCO.
Typically this function is used to convert from TensorFlow GraphDef to TFLite.
@@ -252,3 +308,30 @@ def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
toco_flags.SerializeToString(),
input_data.SerializeToString())
return data
+
+
+@deprecation.deprecated(None, "Use `lite.TocoConverter` instead.")
+def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
+  """Convert a model using TOCO.
+
+ Typically this function is used to convert from TensorFlow GraphDef to TFLite.
+ Conversion can be customized by providing arguments that are forwarded to
+ `build_toco_convert_protos` (see documentation for details).
+
+ Args:
+    input_data: Input data (i.e., often `sess.graph_def`).
+ input_tensors: List of input tensors. Type and shape are computed using
+ `foo.get_shape()` and `foo.dtype`.
+ output_tensors: List of output tensors (only .name is used from this).
+    *args: See `build_toco_convert_protos`.
+ **kwargs: See `build_toco_convert_protos`.
+
+ Returns:
+    The converted data. For example, if TFLite is the destination format,
+    this will be a TFLite flatbuffer in a bytes array.
+
+ Raises:
+ Defined in `build_toco_convert_protos`.
+ """
+ return toco_convert_impl(input_data, input_tensors, output_tensors, *args,
+ **kwargs)
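`toco_convert` survives only as a deprecated wrapper that forwards to `toco_convert_impl`. A generic sketch of that wrap-and-warn pattern; this is a minimal stand-in, not TensorFlow's actual `deprecation.deprecated` decorator (which also rewrites the docstring):

```python
import functools
import warnings

def deprecated(instructions):
    """Minimal stand-in for deprecation.deprecated(None, instructions)."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(
                "%s is deprecated. %s" % (func.__name__, instructions),
                DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)
        return wrapper
    return decorator

def convert_impl(x):
    return x * 2  # stands in for toco_convert_impl

@deprecated("Use `lite.TocoConverter` instead.")
def convert(x):
    return convert_impl(x)  # the old entry point simply forwards

assert convert(2) == 4  # still works, but emits a DeprecationWarning
```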
diff --git a/tensorflow/contrib/lite/python/convert_test.py b/tensorflow/contrib/lite/python/convert_test.py
index dc21a9b669..40a8b5fafb 100644
--- a/tensorflow/contrib/lite/python/convert_test.py
+++ b/tensorflow/contrib/lite/python/convert_test.py
@@ -17,9 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.contrib.lite.python import convert
from tensorflow.contrib.lite.python import lite_constants
from tensorflow.contrib.lite.python import op_hint
+from tensorflow.contrib.lite.python.interpreter import Interpreter
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
@@ -37,9 +40,12 @@ class ConvertTest(test_util.TensorFlowTestCase):
dtype=dtypes.float32)
out_tensor = in_tensor + in_tensor
sess = session.Session()
+
# Try running on valid graph
- result = convert.toco_convert(sess.graph_def, [in_tensor], [out_tensor])
- self.assertTrue(result)
+ tflite_model = convert.toco_convert(sess.graph_def, [in_tensor],
+ [out_tensor])
+ self.assertTrue(tflite_model)
+
# TODO(aselle): remove tests that fail (we must get TOCO to not fatal
# all the time).
# Try running on identity graph (known fail)
@@ -52,11 +58,85 @@ class ConvertTest(test_util.TensorFlowTestCase):
out_tensor = array_ops.fake_quant_with_min_max_args(in_tensor + in_tensor,
min=0., max=1.)
sess = session.Session()
- result = convert.toco_convert(
+
+ tflite_model = convert.toco_convert(
sess.graph_def, [in_tensor], [out_tensor],
inference_type=lite_constants.QUANTIZED_UINT8,
quantized_input_stats=[(0., 1.)])
- self.assertTrue(result)
+ self.assertTrue(tflite_model)
+
+ def testGraphDefBasic(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name="input")
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ tflite_model = convert.toco_convert_graph_def(
+ sess.graph_def, [("input", [1, 16, 16, 3])], ["add"],
+ inference_type=lite_constants.FLOAT)
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual("input", input_details[0]["name"])
+ self.assertEqual(np.float32, input_details[0]["dtype"])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]["shape"]).all())
+ self.assertEqual((0., 0.), input_details[0]["quantization"])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual("add", output_details[0]["name"])
+ self.assertEqual(np.float32, output_details[0]["dtype"])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]["shape"]).all())
+ self.assertEqual((0., 0.), output_details[0]["quantization"])
+
+ def testGraphDefQuantization(self):
+ in_tensor_1 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputA")
+ in_tensor_2 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputB")
+ _ = array_ops.fake_quant_with_min_max_args(
+ in_tensor_1 + in_tensor_2, min=0., max=1., name="output")
+ sess = session.Session()
+
+ input_arrays_map = [("inputA", [1, 16, 16, 3]), ("inputB", [1, 16, 16, 3])]
+ output_arrays = ["output"]
+ tflite_model = convert.toco_convert_graph_def(
+ sess.graph_def,
+ input_arrays_map,
+ output_arrays,
+ inference_type=lite_constants.QUANTIZED_UINT8,
+ quantized_input_stats=[(0., 1.), (0., 1.)])
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(2, len(input_details))
+ self.assertEqual("inputA", input_details[0]["name"])
+ self.assertEqual(np.uint8, input_details[0]["dtype"])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]["shape"]).all())
+ self.assertEqual((1., 0.),
+ input_details[0]["quantization"]) # scale, zero_point
+
+ self.assertEqual("inputB", input_details[1]["name"])
+ self.assertEqual(np.uint8, input_details[1]["dtype"])
+ self.assertTrue(([1, 16, 16, 3] == input_details[1]["shape"]).all())
+ self.assertEqual((1., 0.),
+ input_details[1]["quantization"]) # scale, zero_point
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual("output", output_details[0]["name"])
+ self.assertEqual(np.uint8, output_details[0]["dtype"])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]["shape"]).all())
+ self.assertTrue(output_details[0]["quantization"][0] > 0) # scale
class ConvertTestOpHint(test_util.TensorFlowTestCase):
@@ -108,17 +188,18 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
return output
output = array_ops.identity(_swish(image, swish_scale), name="ModelOutput")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# check if identities have been put into the graph (2 input, 1 output,
# and 1 final output).
self.assertEqual(self._countIdentities(sess.graph_def.node), 4)
- stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess)
+ stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
+ graph_def=sess.graph_def)
self.assertCountEqual(
self._getGraphOpTypes(
stubbed_graphdef,
- output_nodes=[op_hint._tensor_name_base(output)]),
+ output_nodes=[op_hint._tensor_name_base(output.name)]),
["cool_activation", "Const", "Identity"])
def testScaleAndBiasAndIdentity(self):
@@ -134,17 +215,18 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
output = array_ops.identity(_scaled_and_bias_and_identity(a, x, b),
name="ModelOutput")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# make sure one identity for each input (3) and output (2) => 3 + 2 = 5
# +1 for the final output
self.assertEqual(self._countIdentities(sess.graph_def.node), 6)
- stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess)
+ stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
+ graph_def=sess.graph_def)
self.assertCountEqual(
self._getGraphOpTypes(
stubbed_graphdef,
- output_nodes=[op_hint._tensor_name_base(output)]),
+ output_nodes=[op_hint._tensor_name_base(output.name)]),
["scale_and_bias_and_identity", "Const", "Identity", "Pack"])
def testTwoFunctions(self):
@@ -153,24 +235,100 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
b = array_ops.constant([1.])
def _double_values(x):
custom = op_hint.OpHint("add_test")
- x = custom.add_inputs(x)
+ x, = custom.add_inputs(x)
output = math_ops.multiply(x, x)
output, = custom.add_outputs(output)
return output
output = array_ops.identity(
math_ops.add(_double_values(a), _double_values(b)), name="ModelOutput")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# make sure one identity for each input (2) and output (2) => 2 + 2
# +1 for the final output
self.assertEqual(self._countIdentities(sess.graph_def.node), 5)
- stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess)
+ stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
+ graph_def=sess.graph_def)
self.assertCountEqual(
self._getGraphOpTypes(
stubbed_graphdef,
- output_nodes=[op_hint._tensor_name_base(output)]),
+ output_nodes=[op_hint._tensor_name_base(output.name)]),
["add_test", "Const", "Identity", "Add"])
+ def _get_input_index(self, x):
+ return x.op.node_def.attr[op_hint.OpHint.FUNCTION_INPUT_INDEX_ATTR].i
+
+ def _get_output_index(self, x):
+ return x.op.node_def.attr[op_hint.OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i
+
+ def _get_sort_index(self, x):
+ return x.op.node_def.attr[op_hint.OpHint.FUNCTION_SORT_INDEX_ATTR].i
+
+ def testTags(self):
+ """Test if multiple args with the same tag are grouped."""
+ a = array_ops.constant([1.])
+ b = array_ops.constant([2.])
+ c = array_ops.constant([3.])
+ d = array_ops.constant([4.])
+ custom = op_hint.OpHint("test_tag")
+ a = custom.add_input(a, tag="mytag",
+ aggregate=op_hint.OpHint.AGGREGATE_STACK)
+ b, = custom.add_inputs(b)
+ c = custom.add_input(c, tag="mytag",
+ aggregate=op_hint.OpHint.AGGREGATE_STACK)
+ d = custom.add_input(d, tag="mytag2",
+ aggregate=op_hint.OpHint.AGGREGATE_STACK)
+ res = math_ops.add(math_ops.mul(a, b), math_ops.mul(c, b))
+ custom.add_outputs([res])
+ with self.cached_session():
+ self.assertEqual(self._get_input_index(a), 0)
+ self.assertEqual(self._get_sort_index(a), 0)
+ self.assertEqual(self._get_input_index(b), 1)
+ self.assertEqual(self._get_input_index(c), 0)
+ self.assertEqual(self._get_sort_index(c), 1)
+
+ def testOverrideIndex(self):
+ a = array_ops.constant([1.])
+ b = array_ops.constant([2.])
+ c = array_ops.constant([3.])
+ custom = op_hint.OpHint("test_override")
+ b = custom.add_input(b) # should auto assign 0
+ a = custom.add_input(a, index_override=1)
+ c = custom.add_input(c) # should auto assign 2
+ with self.cached_session():
+ self.assertEqual(self._get_input_index(a), 1)
+ self.assertEqual(self._get_input_index(b), 0)
+ self.assertEqual(self._get_input_index(c), 2)
+
+ def testAggregate(self):
+ a = array_ops.constant([3., 4.])
+ b = array_ops.constant([5., 6.])
+ hint = op_hint.OpHint("agg")
+ a0, a1 = array_ops.unstack(a)
+ b0, b1 = array_ops.unstack(b)
+
+ a0 = hint.add_input(a0, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+ b0 = hint.add_input(b0, tag="n", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+ a1 = hint.add_input(a1, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+ b1 = hint.add_input(b1, tag="n", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+
+ c0 = math_ops.add(a0, b0, name="addleft")
+ c1 = math_ops.add(a1, b1, name="addright")
+ c0 = hint.add_output(
+ c0, tag="out", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+ c1 = hint.add_output(
+ c1, tag="out", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+
+ curr = array_ops.stack([c0, c1])
+ output = array_ops.identity(curr, name="FINAL_OUTPUT")
+ with self.cached_session() as sess:
+ stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
+ graph_def=sess.graph_def)
+ self.assertCountEqual(
+ self._getGraphOpTypes(
+ stubbed_graphdef,
+ output_nodes=[op_hint._tensor_name_base(output.name)]),
+ ["agg", "Const", "Identity"])
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 5ec52035ad..44dfb97b84 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -41,7 +41,9 @@ from google.protobuf.message import DecodeError
from tensorflow.contrib.lite.python import lite_constants as constants
from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
from tensorflow.contrib.lite.python.convert import tensor_name as _tensor_name
-from tensorflow.contrib.lite.python.convert import toco_convert
+from tensorflow.contrib.lite.python.convert import toco_convert # pylint: disable=unused-import
+from tensorflow.contrib.lite.python.convert import toco_convert_graph_def as _toco_convert_graph_def
+from tensorflow.contrib.lite.python.convert import toco_convert_impl as _toco_convert_impl
from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import
from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
from tensorflow.contrib.lite.python.convert_saved_model import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
@@ -54,7 +56,9 @@ from tensorflow.python import keras as _keras
from tensorflow.python.client import session as _session
from tensorflow.python.framework import graph_util as _tf_graph_util
from tensorflow.python.framework import ops as _ops
+from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
+from tensorflow.python.lib.io import file_io as _file_io
from tensorflow.python.saved_model import signature_constants as _signature_constants
from tensorflow.python.saved_model import tag_constants as _tag_constants
@@ -75,9 +79,11 @@ class TocoConverter(object):
output_format: Output file format. Currently must be `{TFLITE,
GRAPHVIZ_DOT}`. (default TFLITE)
quantized_input_stats: Dict of strings representing input tensor names
- mapped to tuple of integers representing the mean and standard deviation
+ mapped to tuple of floats representing the mean and standard deviation
of the training data (e.g., {"foo" : (0., 1.)}). Only needed if
- `inference_type` is `QUANTIZED_UINT8`. (default {})
+ `inference_input_type` is `QUANTIZED_UINT8`.
+ real_input_value = (quantized_input_value - mean_value) / std_dev_value.
+ (default {})
default_ranges_stats: Tuple of integers representing (min, max) range values
for all arrays without a specified range. Intended for experimenting with
quantization via "dummy quantization". (default None)
@@ -97,9 +103,9 @@ class TocoConverter(object):
created for any op that is unknown. The developer will need to provide
these to the TensorFlow Lite runtime with a custom resolver.
(default False)
- quantize_weights: Boolean indicating whether to store weights as quantized
- weights followed by dequantize operations. Computation is still done in
- float, but reduces model size (at the cost of accuracy and latency).
+ post_training_quantize: Boolean indicating whether to quantize the weights
+ of the converted float model. Model size will be reduced and there will be
+ latency improvements (at the cost of accuracy).
(default False)
dump_graphviz_dir: Full filepath of folder to dump the graphs at various
stages of processing GraphViz .dot files. Preferred over
@@ -110,6 +116,7 @@ class TocoConverter(object):
Example usage:
+ ```python
# Converting a GraphDef from session.
converter = lite.TocoConverter.from_session(sess, in_tensors, out_tensors)
tflite_model = converter.convert()
@@ -124,9 +131,19 @@ class TocoConverter(object):
# Converting a SavedModel.
converter = lite.TocoConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
+
+ # Converting a tf.keras model.
+ converter = lite.TocoConverter.from_keras_model_file(keras_model)
+ tflite_model = converter.convert()
+ ```
"""
- def __init__(self, graph_def, input_tensors, output_tensors):
+ def __init__(self,
+ graph_def,
+ input_tensors,
+ output_tensors,
+ input_arrays_with_shape=None,
+ output_arrays=None):
"""Constructor for TocoConverter.
Args:
@@ -135,6 +152,17 @@ class TocoConverter(object):
input_tensors: List of input tensors. Type and shape are computed using
`foo.get_shape()` and `foo.dtype`.
output_tensors: List of output tensors (only .name is used from this).
+      input_arrays_with_shape: List of tuples pairing input tensor names with
+        input shapes (e.g., [("foo", [1, 16, 16, 3])]). Use only when the
+        graph cannot be loaded into TensorFlow and when `input_tensors` and
+        `output_tensors` are None. (default None)
+ output_arrays: List of output tensors to freeze graph with. Use only when
+ graph cannot be loaded into TensorFlow and when `input_tensors` and
+ `output_tensors` are None. (default None)
+
+ Raises:
+ ValueError: Invalid arguments.
"""
self._graph_def = graph_def
self._input_tensors = input_tensors
@@ -148,10 +176,19 @@ class TocoConverter(object):
self.reorder_across_fake_quant = False
self.change_concat_input_ranges = False
self.allow_custom_ops = False
- self.quantize_weights = False
+ self.post_training_quantize = False
self.dump_graphviz_dir = None
self.dump_graphviz_video = False
+ # Attributes are used by models that cannot be loaded into TensorFlow.
+ if not self._has_valid_tensors():
+ if not input_arrays_with_shape or not output_arrays:
+ raise ValueError(
+ "If input_tensors and output_tensors are None, both "
+ "input_arrays_with_shape and output_arrays must be defined.")
+ self._input_arrays_with_shape = input_arrays_with_shape
+ self._output_arrays = output_arrays
+
@classmethod
def from_session(cls, sess, input_tensors, output_tensors):
"""Creates a TocoConverter class from a TensorFlow Session.
@@ -189,18 +226,24 @@ class TocoConverter(object):
TocoConverter class.
Raises:
- ValueError:
+ IOError:
+ File not found.
Unable to parse input file.
+ ValueError:
The graph is not frozen.
input_arrays or output_arrays contains an invalid tensor name.
+        input_shapes is not correctly defined when required.
"""
with _ops.Graph().as_default():
with _session.Session() as sess:
# Read GraphDef from file.
- graph_def = _graph_pb2.GraphDef()
- with open(graph_def_file, "rb") as f:
+ if not _file_io.file_exists(graph_def_file):
+ raise IOError("File '{0}' does not exist.".format(graph_def_file))
+ with _file_io.FileIO(graph_def_file, "rb") as f:
file_content = f.read()
+
try:
+ graph_def = _graph_pb2.GraphDef()
graph_def.ParseFromString(file_content)
except (_text_format.ParseError, DecodeError):
try:
@@ -211,24 +254,49 @@ class TocoConverter(object):
file_content = file_content.decode("utf-8")
else:
file_content = file_content.encode("utf-8")
+ graph_def = _graph_pb2.GraphDef()
_text_format.Merge(file_content, graph_def)
except (_text_format.ParseError, DecodeError):
- raise ValueError(
+ raise IOError(
"Unable to parse input file '{}'.".format(graph_def_file))
- _import_graph_def(graph_def, name="")
-
- # Get input and output tensors.
- input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
- output_tensors = _get_tensors_from_tensor_names(sess.graph,
- output_arrays)
- _set_tensor_shapes(input_tensors, input_shapes)
-
- # Check if graph is frozen.
- if not _is_frozen_graph(sess):
- raise ValueError("Please freeze the graph using freeze_graph.py.")
- # Create TocoConverter class.
- return cls(sess.graph_def, input_tensors, output_tensors)
+ # Handles models with custom TFLite ops that cannot be resolved in
+ # TensorFlow.
+ load_model_in_session = True
+ try:
+ _import_graph_def(graph_def, name="")
+ except _NotFoundError:
+ load_model_in_session = False
+
+ if load_model_in_session:
+ # Check if graph is frozen.
+ if not _is_frozen_graph(sess):
+ raise ValueError("Please freeze the graph using freeze_graph.py.")
+
+ # Get input and output tensors.
+ input_tensors = _get_tensors_from_tensor_names(
+ sess.graph, input_arrays)
+ output_tensors = _get_tensors_from_tensor_names(
+ sess.graph, output_arrays)
+ _set_tensor_shapes(input_tensors, input_shapes)
+
+ return cls(sess.graph_def, input_tensors, output_tensors)
+ else:
+ if not input_shapes:
+ raise ValueError("input_shapes must be defined for this model.")
+ if set(input_arrays) != set(input_shapes.keys()):
+          raise ValueError("input_shapes must contain a value for each item "
+                           "in input_arrays.")
+
+ input_arrays_with_shape = [
+ (name, input_shapes[name]) for name in input_arrays
+ ]
+ return cls(
+ graph_def,
+ input_tensors=None,
+ output_tensors=None,
+ input_arrays_with_shape=input_arrays_with_shape,
+ output_arrays=output_arrays)
@classmethod
def from_saved_model(cls,
@@ -323,25 +391,25 @@ class TocoConverter(object):
None value for dimension in input_tensor.
"""
# Checks dimensions in input tensor.
- for tensor in self._input_tensors:
- if not tensor.get_shape():
- raise ValueError("Provide an input shape for input array '{0}'.".format(
- _tensor_name(tensor)))
- shape = tensor.get_shape().as_list()
- if None in shape[1:]:
- raise ValueError(
- "None is only supported in the 1st dimension. Tensor '{0}' has "
- "invalid shape '{1}'.".format(_tensor_name(tensor), shape))
- elif shape[0] is None:
- self._set_batch_size(batch_size=1)
+ if self._has_valid_tensors():
+ for tensor in self._input_tensors:
+ if not tensor.get_shape():
+ raise ValueError("Provide an input shape for input array "
+ "'{0}'.".format(_tensor_name(tensor)))
+ shape = tensor.get_shape().as_list()
+ if None in shape[1:]:
+ raise ValueError(
+ "None is only supported in the 1st dimension. Tensor '{0}' has "
+ "invalid shape '{1}'.".format(_tensor_name(tensor), shape))
+ elif shape[0] is None:
+ self._set_batch_size(batch_size=1)
# Get quantization stats. Ensures there is one stat per name if the stats
# are specified.
if self.quantized_input_stats:
quantized_stats = []
invalid_stats = []
- for tensor in self._input_tensors:
- name = _tensor_name(tensor)
+ for name in self.get_input_arrays():
if name in self.quantized_input_stats:
quantized_stats.append(self.quantized_input_stats[name])
else:
@@ -353,24 +421,35 @@ class TocoConverter(object):
else:
quantized_stats = None
+ converter_kwargs = {
+ "inference_type": self.inference_type,
+ "inference_input_type": self.inference_input_type,
+ "input_format": constants.TENSORFLOW_GRAPHDEF,
+ "output_format": self.output_format,
+ "quantized_input_stats": quantized_stats,
+ "default_ranges_stats": self.default_ranges_stats,
+ "drop_control_dependency": self.drop_control_dependency,
+ "reorder_across_fake_quant": self.reorder_across_fake_quant,
+ "change_concat_input_ranges": self.change_concat_input_ranges,
+ "allow_custom_ops": self.allow_custom_ops,
+ "post_training_quantize": self.post_training_quantize,
+ "dump_graphviz_dir": self.dump_graphviz_dir,
+ "dump_graphviz_video": self.dump_graphviz_video
+ }
+
# Converts model.
- result = toco_convert(
- input_data=self._graph_def,
- input_tensors=self._input_tensors,
- output_tensors=self._output_tensors,
- inference_type=self.inference_type,
- inference_input_type=self.inference_input_type,
- input_format=constants.TENSORFLOW_GRAPHDEF,
- output_format=self.output_format,
- quantized_input_stats=quantized_stats,
- default_ranges_stats=self.default_ranges_stats,
- drop_control_dependency=self.drop_control_dependency,
- reorder_across_fake_quant=self.reorder_across_fake_quant,
- change_concat_input_ranges=self.change_concat_input_ranges,
- allow_custom_ops=self.allow_custom_ops,
- quantize_weights=self.quantize_weights,
- dump_graphviz_dir=self.dump_graphviz_dir,
- dump_graphviz_video=self.dump_graphviz_video)
+ if self._has_valid_tensors():
+ result = _toco_convert_impl(
+ input_data=self._graph_def,
+ input_tensors=self._input_tensors,
+ output_tensors=self._output_tensors,
+ **converter_kwargs)
+ else:
+ result = _toco_convert_graph_def(
+ input_data=self._graph_def,
+ input_arrays_with_shape=self._input_arrays_with_shape,
+ output_arrays=self._output_arrays,
+ **converter_kwargs)
return result
def get_input_arrays(self):
@@ -379,7 +458,18 @@ class TocoConverter(object):
Returns:
List of strings.
"""
- return [_tensor_name(tensor) for tensor in self._input_tensors]
+ if self._has_valid_tensors():
+ return [_tensor_name(tensor) for tensor in self._input_tensors]
+ else:
+ return [name for name, _ in self._input_arrays_with_shape]
+
+ def _has_valid_tensors(self):
+ """Checks if the input and output tensors have been initialized.
+
+ Returns:
+ Bool.
+ """
+ return self._input_tensors and self._output_tensors
def _set_batch_size(self, batch_size):
"""Sets the first dimension of the input tensor to `batch_size`.
@@ -387,7 +477,14 @@ class TocoConverter(object):
Args:
batch_size: Batch size for the model. Replaces the first dimension of an
input size array if undefined. (default 1)
+
+ Raises:
+ ValueError: input_tensor is not defined.
"""
+ if not self._has_valid_tensors():
+ raise ValueError("The batch size cannot be set for this model. Please "
+ "use input_shapes parameter.")
+
for tensor in self._input_tensors:
shape = tensor.get_shape().as_list()
shape[0] = batch_size
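For context, a minimal sketch of the conversion path this hunk enables: a frozen graph that cannot be loaded into a TensorFlow Session (e.g. one containing custom TFLite ops) is described by array names and shapes, and convert() then dispatches to _toco_convert_graph_def. The file and array names below are illustrative placeholders.

# Sketch only; 'frozen_model.pb', 'input' and 'output' are assumed names.
from tensorflow.contrib.lite.python import lite

converter = lite.TocoConverter.from_frozen_graph(
    'frozen_model.pb',
    input_arrays=['input'],
    output_arrays=['output'],
    input_shapes={'input': [1, 224, 224, 3]})
converter.allow_custom_ops = True  # let TOCO pass through unknown ops
tflite_model = converter.convert()  # no tensors, so the graph-def branch runs
with open('model.tflite', 'wb') as f:
  f.write(tflite_model)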
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index 2f13684228..3f8ea433ff 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -35,11 +35,51 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer
from tensorflow.python.platform import gfile
+from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
from tensorflow.python.training.training_util import write_graph
+class FromConstructor(test_util.TensorFlowTestCase):
+
+ # Tests invalid constructors using a dummy value for the GraphDef.
+ def testInvalidConstructor(self):
+ message = ('If input_tensors and output_tensors are None, both '
+ 'input_arrays_with_shape and output_arrays must be defined.')
+
+ # `output_arrays` is not defined.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter(
+ None, None, [], input_arrays_with_shape=[('input', [3, 9])])
+ self.assertEqual(message, str(error.exception))
+
+ # `input_arrays_with_shape` is not defined.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter(None, [], None, output_arrays=['output'])
+ self.assertEqual(message, str(error.exception))
+
+ # Tests valid constructors using a dummy value for the GraphDef.
+ def testValidConstructor(self):
+ converter = lite.TocoConverter(
+ None,
+ None,
+ None,
+ input_arrays_with_shape=[('input', [3, 9])],
+ output_arrays=['output'])
+ self.assertFalse(converter._has_valid_tensors())
+ self.assertEqual(converter.get_input_arrays(), ['input'])
+
+ with self.assertRaises(ValueError) as error:
+ converter._set_batch_size(1)
+ self.assertEqual(
+ 'The batch size cannot be set for this model. Please use '
+ 'input_shapes parameter.', str(error.exception))
+
+ converter = lite.TocoConverter(None, ['input_tensor'], ['output_tensor'])
+ self.assertTrue(converter._has_valid_tensors())
+
+
class FromSessionTest(test_util.TensorFlowTestCase):
def testFloat(self):
@@ -279,6 +319,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
# Convert model and ensure model is not None.
converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
converter.inference_input_type = lite_constants.QUANTIZED_UINT8
+ converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -331,7 +372,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertTrue(output_details[0]['quantization'][0] > 0) # scale
- def testQuantizeWeights(self):
+ def testPostTrainingQuantize(self):
np.random.seed(0)
# We need the tensor to have more than 1024 elements for quantize_weights
# to kick in. Thus, the [33, 33] shape.
@@ -352,14 +393,14 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(float_tflite)
# Convert quantized weights model.
- quantized_weights_converter = lite.TocoConverter.from_session(
+ quantized_converter = lite.TocoConverter.from_session(
sess, [in_tensor_1], [out_tensor])
- quantized_weights_converter.quantize_weights = True
- quantized_weights_tflite = quantized_weights_converter.convert()
- self.assertTrue(quantized_weights_tflite)
+ quantized_converter.post_training_quantize = True
+ quantized_tflite = quantized_converter.convert()
+ self.assertTrue(quantized_tflite)
# Ensure that the quantized weights tflite model is smaller.
- self.assertTrue(len(quantized_weights_tflite) < len(float_tflite))
+ self.assertTrue(len(quantized_tflite) < len(float_tflite))
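As a standalone sketch of the renamed flag (the session and tensors are assumed to exist already):

# Sketch: post_training_quantize replaces the deprecated quantize_weights.
converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
converter.post_training_quantize = True
tflite_model = converter.convert()  # weights stored quantized; inference stays float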
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
@@ -373,6 +414,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Write graph to file.
graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
write_graph(sess.graph_def, '', graph_def_file, False)
+ sess.close()
# Convert model and ensure model is not None.
converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
@@ -407,6 +449,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Write graph to file.
graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
write_graph(sess.graph_def, '', graph_def_file, False)
+ sess.close()
# Convert model and ensure model is not None.
converter = lite.TocoConverter.from_frozen_graph(
@@ -434,6 +477,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Write graph to file.
graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
write_graph(sess.graph_def, '', graph_def_file, False)
+ sess.close()
# Ensure the graph with variables cannot be converted.
with self.assertRaises(ValueError) as error:
@@ -451,6 +495,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Write graph to file.
graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt')
write_graph(sess.graph_def, '', graph_def_file, True)
+ sess.close()
# Convert model and ensure model is not None.
converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
@@ -476,20 +521,104 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertEqual((0., 0.), output_details[0]['quantization'])
- def testInvalidFile(self):
+ def testInvalidFileNotFound(self):
+ with self.assertRaises(IOError) as error:
+ lite.TocoConverter.from_frozen_graph('invalid_file', ['Placeholder'],
+ ['add'])
+ self.assertEqual('File \'invalid_file\' does not exist.',
+ str(error.exception))
+
+ def testInvalidFileBadData(self):
graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file')
with gfile.Open(graph_def_file, 'wb') as temp_file:
temp_file.write('bad data')
temp_file.flush()
# Attempts to convert the invalid model.
- with self.assertRaises(ValueError) as error:
+ with self.assertRaises(IOError) as error:
lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
['add'])
self.assertEqual(
'Unable to parse input file \'{}\'.'.format(graph_def_file),
str(error.exception))
+ # TODO(nupurgarg): Test model loading in open source.
+ def _initObjectDetectionArgs(self):
+ # Initializes the arguments required for the object detection model.
+ self._graph_def_file = resource_loader.get_path_to_datafile(
+ 'testdata/tflite_graph.pb')
+ self._input_arrays = ['normalized_input_image_tensor']
+ self._output_arrays = [
+ 'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1',
+ 'TFLite_Detection_PostProcess:2', 'TFLite_Detection_PostProcess:3'
+ ]
+ self._input_shapes = {'normalized_input_image_tensor': [1, 300, 300, 3]}
+
+ def testTFLiteGraphDef(self):
+ # Tests the object detection model that cannot be loaded in TensorFlow.
+ self._initObjectDetectionArgs()
+
+ converter = lite.TocoConverter.from_frozen_graph(
+ self._graph_def_file, self._input_arrays, self._output_arrays,
+ self._input_shapes)
+ converter.allow_custom_ops = True
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('normalized_input_image_tensor', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 300, 300, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(4, len(output_details))
+ self.assertEqual('TFLite_Detection_PostProcess', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 10, 4] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ self.assertEqual('TFLite_Detection_PostProcess:1',
+ output_details[1]['name'])
+ self.assertTrue(([1, 10] == output_details[1]['shape']).all())
+ self.assertEqual('TFLite_Detection_PostProcess:2',
+ output_details[2]['name'])
+ self.assertTrue(([1, 10] == output_details[2]['shape']).all())
+ self.assertEqual('TFLite_Detection_PostProcess:3',
+ output_details[3]['name'])
+ self.assertTrue(([1] == output_details[3]['shape']).all())
+
+ def testTFLiteGraphDefMissingShape(self):
+ # Tests invalid cases for the model that cannot be loaded in TensorFlow.
+ self._initObjectDetectionArgs()
+
+ # Missing `input_shapes`.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_frozen_graph(
+ self._graph_def_file, self._input_arrays, self._output_arrays)
+ self.assertEqual('input_shapes must be defined for this model.',
+ str(error.exception))
+
+ def testTFLiteGraphDefInvalidShape(self):
+ # Tests invalid cases for the model that cannot be loaded in TensorFlow.
+ self._initObjectDetectionArgs()
+
+ # `input_shapes` does not contain the names in `input_arrays`.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_frozen_graph(
+ self._graph_def_file,
+ self._input_arrays,
+ self._output_arrays,
+ input_shapes={'invalid-value': [1, 19]})
+ self.assertEqual(
+ 'input_shapes must contain a value for each item in input_array.',
+ str(error.exception))
+
class FromSavedModelTest(test_util.TensorFlowTestCase):
@@ -628,26 +757,27 @@ class FromKerasFile(test_util.TensorFlowTestCase):
keras.backend.clear_session()
def _getSequentialModel(self):
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(2, input_shape=(3,)))
- model.add(keras.layers.RepeatVector(3))
- model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
- model.compile(
- loss=keras.losses.MSE,
- optimizer=keras.optimizers.RMSprop(),
- metrics=[keras.metrics.categorical_accuracy],
- sample_weight_mode='temporal')
- x = np.random.random((1, 3))
- y = np.random.random((1, 3, 3))
- model.train_on_batch(x, y)
- model.predict(x)
-
- try:
- fd, keras_file = tempfile.mkstemp('.h5')
- keras.models.save_model(model, keras_file)
- finally:
- os.close(fd)
- return keras_file
+ with session.Session().as_default():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.RepeatVector(3))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.categorical_accuracy],
+ sample_weight_mode='temporal')
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3, 3))
+ model.train_on_batch(x, y)
+ model.predict(x)
+
+ try:
+ fd, keras_file = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, keras_file)
+ finally:
+ os.close(fd)
+ return keras_file
def testSequentialModel(self):
"""Test a Sequential tf.keras model with default inputs."""
@@ -752,25 +882,26 @@ class FromKerasFile(test_util.TensorFlowTestCase):
def testFunctionalModel(self):
"""Test a Functional tf.keras model with default inputs."""
- inputs = keras.layers.Input(shape=(3,), name='input')
- x = keras.layers.Dense(2)(inputs)
- output = keras.layers.Dense(3)(x)
-
- model = keras.models.Model(inputs, output)
- model.compile(
- loss=keras.losses.MSE,
- optimizer=keras.optimizers.RMSprop(),
- metrics=[keras.metrics.categorical_accuracy])
- x = np.random.random((1, 3))
- y = np.random.random((1, 3))
- model.train_on_batch(x, y)
-
- model.predict(x)
- fd, keras_file = tempfile.mkstemp('.h5')
- try:
- keras.models.save_model(model, keras_file)
- finally:
- os.close(fd)
+ with session.Session().as_default():
+ inputs = keras.layers.Input(shape=(3,), name='input')
+ x = keras.layers.Dense(2)(inputs)
+ output = keras.layers.Dense(3)(x)
+
+ model = keras.models.Model(inputs, output)
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.categorical_accuracy])
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3))
+ model.train_on_batch(x, y)
+
+ model.predict(x)
+ fd, keras_file = tempfile.mkstemp('.h5')
+ try:
+ keras.models.save_model(model, keras_file)
+ finally:
+ os.close(fd)
# Convert to TFLite model.
converter = lite.TocoConverter.from_keras_model_file(keras_file)
@@ -809,36 +940,39 @@ class FromKerasFile(test_util.TensorFlowTestCase):
def testFunctionalModelMultipleInputs(self):
"""Test a Functional tf.keras model with multiple inputs and outputs."""
- a = keras.layers.Input(shape=(3,), name='input_a')
- b = keras.layers.Input(shape=(3,), name='input_b')
- dense = keras.layers.Dense(4, name='dense')
- c = dense(a)
- d = dense(b)
- e = keras.layers.Dropout(0.5, name='dropout')(c)
-
- model = keras.models.Model([a, b], [d, e])
- model.compile(
- loss=keras.losses.MSE,
- optimizer=keras.optimizers.RMSprop(),
- metrics=[keras.metrics.mae],
- loss_weights=[1., 0.5])
-
- input_a_np = np.random.random((10, 3))
- input_b_np = np.random.random((10, 3))
- output_d_np = np.random.random((10, 4))
- output_e_np = np.random.random((10, 4))
- model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])
-
- model.predict([input_a_np, input_b_np], batch_size=5)
- fd, keras_file = tempfile.mkstemp('.h5')
- keras.models.save_model(model, keras_file)
+ with session.Session().as_default():
+ a = keras.layers.Input(shape=(3,), name='input_a')
+ b = keras.layers.Input(shape=(3,), name='input_b')
+ dense = keras.layers.Dense(4, name='dense')
+ c = dense(a)
+ d = dense(b)
+ e = keras.layers.Dropout(0.5, name='dropout')(c)
+
+ model = keras.models.Model([a, b], [d, e])
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.mae],
+ loss_weights=[1., 0.5])
+
+ input_a_np = np.random.random((10, 3))
+ input_b_np = np.random.random((10, 3))
+ output_d_np = np.random.random((10, 4))
+ output_e_np = np.random.random((10, 4))
+ model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])
+
+ model.predict([input_a_np, input_b_np], batch_size=5)
+ fd, keras_file = tempfile.mkstemp('.h5')
+ try:
+ keras.models.save_model(model, keras_file)
+ finally:
+ os.close(fd)
# Convert to TFLite model.
converter = lite.TocoConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
- os.close(fd)
os.remove(keras_file)
# Check values from converted model.
@@ -871,28 +1005,29 @@ class FromKerasFile(test_util.TensorFlowTestCase):
def testFunctionalSequentialModel(self):
"""Test a Functional tf.keras model containing a Sequential model."""
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(2, input_shape=(3,)))
- model.add(keras.layers.RepeatVector(3))
- model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
- model = keras.models.Model(model.input, model.output)
-
- model.compile(
- loss=keras.losses.MSE,
- optimizer=keras.optimizers.RMSprop(),
- metrics=[keras.metrics.categorical_accuracy],
- sample_weight_mode='temporal')
- x = np.random.random((1, 3))
- y = np.random.random((1, 3, 3))
- model.train_on_batch(x, y)
- model.predict(x)
-
- model.predict(x)
- fd, keras_file = tempfile.mkstemp('.h5')
- try:
- keras.models.save_model(model, keras_file)
- finally:
- os.close(fd)
+ with session.Session().as_default():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.RepeatVector(3))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
+ model = keras.models.Model(model.input, model.output)
+
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.categorical_accuracy],
+ sample_weight_mode='temporal')
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3, 3))
+ model.train_on_batch(x, y)
+ model.predict(x)
+
+ model.predict(x)
+ fd, keras_file = tempfile.mkstemp('.h5')
+ try:
+ keras.models.save_model(model, keras_file)
+ finally:
+ os.close(fd)
# Convert to TFLite model.
converter = lite.TocoConverter.from_keras_model_file(keras_file)
diff --git a/tensorflow/contrib/lite/python/op_hint.py b/tensorflow/contrib/lite/python/op_hint.py
index 7908689ce4..8c920132e5 100644
--- a/tensorflow/contrib/lite/python/op_hint.py
+++ b/tensorflow/contrib/lite/python/op_hint.py
@@ -25,9 +25,9 @@ Example:
def tflite_cool_activation(input):
# A cool activation function.
custom = tf.contrib.lite.OpHint("cool_activation")
- input = custom.add_inputs(input)
+ input, = custom.add_inputs(input)
output = tf.sigmoid(input) * input
- custom.add_outputs(output)
+ output, = custom.add_outputs(output)
return output
image = tf.placeholder(tf.float32, (1, 16, 16, 1))
@@ -64,18 +64,27 @@ ops don't actually exist in the normal TensorFlow runtime, but will be
understood by toco later.
"""
+# TODO(aselle): Make this use generic graph transformations.
+# TODO(aselle): _tensor_name_base should be called _tensor_name_to_op_name.
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections as _collections
-import itertools as _itertools
+import copy as _copy
import uuid as _uuid
+import six as _six
-from tensorflow.contrib import framework as _framework
from tensorflow.core.framework import attr_value_pb2 as _attr_value_pb2
+from tensorflow.core.framework import graph_pb2 as _graph_pb2
+from tensorflow.core.framework import node_def_pb2 as _node_def_pb2
from tensorflow.python.framework import ops as _ops
+# TODO(aselle): publicize these apis if we continue to use these.
+from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
+from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
from tensorflow.python.ops import array_ops as _array_ops
+from tensorflow.python.util import compat as _compat
from tensorflow.python.util.all_util import remove_undocumented
@@ -97,11 +106,174 @@ class OpHint(object):
constructs, this mechanism can be retired and changed to use python defun's.
"""
- # Attr constants that are used for representation in the GraphDef
+ # Attr constants that are used for representation in the GraphDef. These
+ # will be used on every Identity op that is involved in a total OpHint.
+
+ # Name of the OpHint function (cosmetic).
FUNCTION_NAME_ATTR = "_tflite_function_name"
+ # UUID of the function (each OpHint gets a new uuid).
FUNCTION_UUID_ATTR = "_tflite_function_uuid"
+ # The input index of the input (or nothing if it is an output).
FUNCTION_INPUT_INDEX_ATTR = "_tflite_function_input_index"
+ # The output index of the output (or nothing if it is an input).
FUNCTION_OUTPUT_INDEX_ATTR = "_tflite_function_output_index"
+ # An index that orders aggregate arguments. Aggregate arguments are ones
+ # that are separate but will be fused horizontally. For example, a static
+ # LSTM has an LSTM cell for each time step. Each one has a separate OpHint,
+ # but a fused SequentialLSTM will treat this as a single tensor.
+ FUNCTION_SORT_INDEX_ATTR = "_tflite_function_sort_index"
+ # The way in which multiple parts of the aggregate argument will be joined
+ # into a fused operand. Valid options are OpHint.AGGREGATE_FIRST,
+ # OpHint.AGGREGATE_LAST, OpHint.AGGREGATE_STACK.
+ FUNCTION_AGGREGATE_ATTR = "_tflite_function_aggregate"
+ # On fused OpHint stub, the order of inputs that the final LSTM call will
+ # have. What this means is that the TensorFlow order might be
+ # "foo", "bar", "stuff" and you might want the TF lite op order to be
+ # "stuff", "foo", "bar", -1 (where -1 is unused). So you would set this
+ # attribute to [2, 0, 1, -1].
+ TFLITE_INPUT_INDICES = "_tflite_input_indices"
+
+ # Types of aggregations
+ # stack: stacks all ophints with matching tags. i.e. for a static rnn.
+ # specifically, this is good for an input or output to a static rnn cell.
+ AGGREGATE_STACK = _compat.as_bytes("stack")
+ # first: only takes the first output (one with lowest sort index)
+ # of matching tags. This is good for the input state to an RNN.
+ AGGREGATE_FIRST = _compat.as_bytes("first")
+ # aggregation last takes only the last tag (one with highest sort index).
+ # This is good for an output value on the last stack item of a
+ # static rnn.
+ AGGREGATE_LAST = _compat.as_bytes("last")
+
+ class OpHintArgumentTracker(object):
+ """Conceptually tracks indices of arguments of "OpHint functions".
+
+ The inputs and outputs of these functions both use an instance
+ of the class so they can have independent numbering."""
+
+ def __init__(self, function_name, unique_function_id, node_name_prefix,
+ attr_name):
+ """Initialize ophint argument.
+
+ Args:
+ function_name: Name of the function that this tracks arguments for.
+ unique_function_id: UUID of function that this tracks arguments for.
+ node_name_prefix: How identities that are created are named.
+ attr_name: Name of attribute to use to store the index for this hint.
+ i.e. FUNCTION_INPUT_INDEX or FUNCTION_OUTPUT_INDEX
+ """
+
+ # The global index is the argument index of the op. This is in contrast
+ # to the sort index which is the sequence number of a particular instance
+ # of a given global index. For example, you may have called add hint
+ # twice with the tag "foo". Then the global index will be 0 for both
+ # and the sort index will be 0 for the first added and 1 for the second.
+ self._function_name = function_name
+ self._unique_function_id = unique_function_id
+ self._next_global_index = 0 # The absolute global index
+ self._used_global_indices = set()
+ self._tag_to_global_index = {} # The argument index a given tag maps to
+ self._tag_to_next_sort_index = {} # The current index for each tag
+ self._node_name_prefix = node_name_prefix
+ self._attr_name = attr_name
+
+ def _get_new_global_index(self, index_override):
+ """Return the next unused argument index in order or use an override.
+
+ Args:
+ index_override: An index to use instead of the next available or None
+ to use the next available.
+
+ Returns:
+ A valid global_index to use for the next hint argument.
+
+ Raises:
+ ValueError: If the index_override is already used by another hint.
+ """
+ if index_override is None:
+ global_index = self._next_global_index
+ else:
+ if index_override in self._used_global_indices:
+ raise ValueError("Index %d was already used by another call to add")
+ global_index = index_override
+ # Make next_global_index valid
+ self._used_global_indices.add(global_index)
+ while self._next_global_index in self._used_global_indices:
+ self._next_global_index += 1
+ return global_index
+
+ def add(self, arg, tag=None, name=None, aggregate=None,
+ index_override=None):
+ """Return a wrapped tensor of an input tensor as an argument.
+
+ Args:
+ arg: A TensorFlow tensor that should be considered an argument.
+ tag: String tag to identify arguments that should be packed.
+ name: Name of argument. This is included in the Identity hint op names.
+ aggregate: Strategy to aggregate.
+ Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
+ and OpHint.AGGREGATE_STACK.
+ Note, aggregate is only valid if tag is specified.
+ index_override: Specify what input/output index this should be in the
+ final stub. e.g. add(arg0, index_override=1); add(arg1,
+ index_override=0) will make the final stub be
+ stub_func(inputs=[arg1, arg0], outputs=[]) rather than the default
+ call-order-based ordering.
+
+ Returns:
+ A tensor representing the wrapped argument.
+
+ Raises:
+ ValueError: When indices are not consistent.
+ """
+
+ # Find the appropriate index
+ if tag is None:
+ if aggregate is not None:
+ raise ValueError("You must specify `tag` if using aggregate.")
+ global_index = self._get_new_global_index(index_override)
+ sort_index = None
+ else:
+ if aggregate is None:
+ raise ValueError("You must specify `aggregate` if using tag.")
+ if tag not in self._tag_to_global_index:
+ self._tag_to_global_index[tag] = (
+ self._get_new_global_index(index_override))
+ self._tag_to_next_sort_index[tag] = 0
+ elif (index_override and
+ index_override != self._tag_to_global_index[tag]):
+ raise ValueError(
+ "Tag %r was called with two indices %r and %r" %
+ (tag, index_override, self._tag_to_global_index[tag]))
+ global_index = self._tag_to_global_index[tag]
+ sort_index = self._tag_to_next_sort_index[tag]
+ self._tag_to_next_sort_index[tag] += 1
+
+ uuid = self._unique_function_id
+ name = "%s-%s-%s-%r-%r-%s" % (self._node_name_prefix, self._function_name,
+ uuid, global_index, sort_index, name)
+ identity_op = _array_ops.identity(arg, name=name)
+
+ # pylint: disable=protected-access
+ identity_op.op._set_attr(
+ OpHint.FUNCTION_NAME_ATTR,
+ _attr_value_pb2.AttrValue(
+ s=_compat.as_bytes(self._function_name)))
+ identity_op.op._set_attr(
+ OpHint.FUNCTION_UUID_ATTR,
+ _attr_value_pb2.AttrValue(
+ s=_compat.as_bytes(self._unique_function_id)))
+ identity_op.op._set_attr(
+ self._attr_name, _attr_value_pb2.AttrValue(i=global_index))
+ if sort_index is not None:
+ identity_op.op._set_attr(
+ OpHint.FUNCTION_SORT_INDEX_ATTR,
+ _attr_value_pb2.AttrValue(i=sort_index))
+ if aggregate is not None:
+ identity_op.op._set_attr(
+ OpHint.FUNCTION_AGGREGATE_ATTR,
+ _attr_value_pb2.AttrValue(s=_compat.as_bytes(aggregate)))
+ # pylint: enable=protected-access
+ return identity_op
def __init__(self, function_name, **kwargs):
"""Create a OpHint.
@@ -112,10 +284,14 @@ class OpHint(object):
"""
self._function_name = function_name
self._unique_function_id = _uuid.uuid1().hex # TODO(aselle): Unique enough?
- self._curr_input_index = 0
- self._curr_output_index = 0
self._attrs_to_store_later = kwargs
self._stored_attrs = False
+ self._inputs = OpHint.OpHintArgumentTracker(
+ self._function_name, self._unique_function_id, "InputHint",
+ OpHint.FUNCTION_INPUT_INDEX_ATTR)
+ self._outputs = OpHint.OpHintArgumentTracker(
+ self._function_name, self._unique_function_id, "OutputHint",
+ OpHint.FUNCTION_OUTPUT_INDEX_ATTR)
def _setattr(self, dest_op, name, value):
tensor_value = _ops.convert_to_tensor(value)
@@ -124,68 +300,278 @@ class OpHint(object):
tensor=tensor_value.op.node_def.attr["value"].tensor))
# pylint: enable=protected-access
- def add_inputs(self, *args):
+ def add_input(self, *args, **kwargs):
+ """Add a wrapped input argument to the hint.
+
+ Args:
+ *args: The input tensor.
+ **kwargs:
+ "name" label
+ "tag" a tag to group multiple arguments that will be aggregated. I.e.
+ a string like 'cool_input'. Basically multiple inputs can be added
+ to the same hint for parallel operations that will eventually be
+ combined. An example would be static_rnn which creates multiple copies
+ of state or inputs.
+ "aggregate" aggregation strategy that is valid only for tag non None.
+ Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
+ and OpHint.AGGREGATE_STACK.
+ "index_override" The global index to use. This corresponds to the
+ argument order in the final stub that will be generated.
+ Returns:
+ The wrapped input tensor.
+ """
+ return self._inputs.add(*args, **kwargs)
+
+ def add_output(self, *args, **kwargs):
+ """Add a wrapped output argument to the hint.
+
+ Args:
+ *args: The output tensor.
+ **kwargs:
+ "name" label
+ "tag" a tag to group multiple arguments that will be aggregated. I.e.
+ a string like 'cool_input'. Basically multiple inputs can be added
+ to the same hint for parallel operations that will eventually be
+ combined. An example would be static_rnn which creates multiple copies
+ of state or inputs.
+ "aggregate" aggregation strategy that is valid only for tag non None.
+ Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
+ and OpHint.AGGREGATE_STACK.
+ "index_override" The global index to use. This corresponds to the
+ argument order in the final stub that will be generated.
+ Returns:
+ The wrapped output tensor.
+ """
+ return self._outputs.add(*args, **kwargs)
+
+ def add_inputs(self, *args, **kwargs):
"""Add a sequence of inputs to the function invocation.
Args:
*args: List of inputs to be converted (should be Tf.Tensor).
+ **kwargs: This allows 'names' which should be a list of names.
Returns:
Wrapped inputs (identity standins that have additional metadata). These
are also tf.Tensor's.
"""
-
- def augmented_identity(arg):
- identity_op = _array_ops.identity(arg)
- # pylint: disable=protected-access
- identity_op.op._set_attr(
- OpHint.FUNCTION_NAME_ATTR,
- _attr_value_pb2.AttrValue(s=self._function_name))
- identity_op.op._set_attr(
- OpHint.FUNCTION_UUID_ATTR,
- _attr_value_pb2.AttrValue(s=self._unique_function_id))
- identity_op.op._set_attr(
- OpHint.FUNCTION_INPUT_INDEX_ATTR,
- _attr_value_pb2.AttrValue(i=self._curr_input_index))
- # pylint: enable=protected-access
- self._curr_input_index += 1
- return identity_op
-
- return [augmented_identity(arg) for arg in args]
-
- def add_outputs(self, *args):
+ if "names" in kwargs:
+ return [
+ self._inputs.add(arg, name=name)
+ for arg, name in zip(args, kwargs["names"])
+ ]
+ else:
+ return [self._inputs.add(arg) for arg in args]
+
+ def add_outputs(self, *args, **kwargs):
"""Add a sequence of outputs to the function invocation.
Args:
*args: List of outputs to be converted (should be tf.Tensor).
+ **kwargs: This allows 'names' which should be a list of names.
Returns:
Wrapped outputs (identity standins that have additional metadata). These
are also tf.Tensor's.
"""
+ if "names" in kwargs:
+ return [
+ self._outputs.add(arg, name=name)
+ for arg, name in zip(args, kwargs["names"])
+ ]
+ else:
+ return [self._outputs.add(arg) for arg in args]
+
+
+class _LiteOperand(object):
+ """Abstract operand for a tflite hint function.
+
+ This is a base class that handles representing arguments to an OpHint.
+ It is also able to serialize operands to the stubbed graph_def.
+ Child classes are responsible for being able to
+ store information about the hint identity operators. They are also responsible
+ for knowing how to serialize to output graphdefs.
+
+ Typically this will be implemented by holding one or more identity nodes
+ that were previously discovered as hints.
+ """
+
+ def aggregate_and_return_name_for_input(self, out_graphdef):
+ """This adds the node(s) to out_graphdef and returns the input node name.
+
+ Args:
+ out_graphdef: A graphdef that is ready to have this input added.
+
+ Returns:
+ The name of the output that the stub should use as an input for this operand.
+
+ Raises:
+ RuntimeError: if the method is not implemented.
+ """
+ del out_graphdef
+ raise RuntimeError("Unimplemented abstract method.")
+
+ def aggregate_and_return_name_for_output(self, fused_op_name, output_index,
+ out_graphdef):
+ """Add node(s) to graph representing output operands and returns type.
+
+ Args:
+ fused_op_name: Name of the fused op stub.
+ output_index: Output index that we are currently processing from stub.
+ out_graphdef: The destination graphdef we are currently building up.
+
+ Returns:
+ The datatype of this identity.
+
+ Raises:
+ RuntimeError: if the method is not implemented.
+ """
+ del fused_op_name, output_index, out_graphdef
+ raise RuntimeError("Unimplemented abstract method.")
- def augmented_identity(arg):
- identity_op = _array_ops.identity(arg)
- # pylint: disable=protected-access
- identity_op.op._set_attr(
- OpHint.FUNCTION_NAME_ATTR,
- _attr_value_pb2.AttrValue(s=self._function_name))
- identity_op.op._set_attr(
- OpHint.FUNCTION_UUID_ATTR,
- _attr_value_pb2.AttrValue(s=self._unique_function_id))
- identity_op.op._set_attr(
- OpHint.FUNCTION_OUTPUT_INDEX_ATTR,
- _attr_value_pb2.AttrValue(i=self._curr_output_index))
- # pylint: enable=protected-access
- self._curr_output_index += 1
- return identity_op
- wrapped_outputs = [augmented_identity(arg) for arg in args]
+class _LiteSingleOperand(_LiteOperand):
+ """A simple operand that is non-aggregated (i.e. most hints)."""
- if not self._stored_attrs:
- for key, value in self._attrs_to_store_later.iteritems():
- self._setattr(wrapped_outputs[0], "_tflite_attr_" + key, value)
- self._stored_attrs = True
+ def __init__(self, node):
+ _LiteOperand.__init__(self)
+ self.node = node
+ self.name = _tensor_name_base(node.name)
- return wrapped_outputs
+ def flatten(self):
+ return [self.name]
+
+ def aggregate_and_return_name_for_input(self, out_graphdef):
+ return self.name
+
+ def aggregate_and_return_name_for_output(self, fused_op_name, index,
+ out_graphdef):
+ output_node = _copy.deepcopy(self.node)
+ del output_node.input[:]
+ output_node.input.append(_tensorflow_output_name(fused_op_name, index))
+ out_graphdef.node.extend([output_node])
+ return self.node.attr["type"].i
+
+ def __str__(self):
+ return str(self.name)
+
+
+class _LiteAggregateOperand(_LiteOperand):
+ """An operand for a tflite hint function that is aggregated from many.
+
+ For example, an LSTM is a grid of operators that are all related. Inputs
+ going into them may need to be fused, so they should all be tracked as
+ related arguments.
+ """
+
+ def __init__(self, aggregation):
+ _LiteOperand.__init__(self)
+ self.aggregation = aggregation
+ self.names = {}
+ self.nodes = {}
+ self.flattened = None
+
+ def add(self, sort, node):
+ self.names[sort] = _tensor_name_base(node.name)
+ self.nodes[sort] = node
+
+ def flatten_nodes(self):
+ """Return a list of all the node protos in aggregation sorted order."""
+ if not self.flattened:
+ self.flattened = [None] * len(self.nodes)
+ for idx, node in _six.iteritems(self.nodes):
+ self.flattened[idx] = node
+ for n in self.flattened:
+ if n is None:
+ raise RuntimeError("Aggregate was missing argument.")
+ if self.aggregation == OpHint.AGGREGATE_FIRST:
+ self.flattened = self.flattened[:1]
+ elif self.aggregation == OpHint.AGGREGATE_LAST:
+ self.flattened = self.flattened[-1:]
+ elif self.aggregation == OpHint.AGGREGATE_STACK:
+ pass
+ else:
+ raise ValueError(
+ "Invalid aggregation type %r specified" % self.aggregation)
+ return self.flattened
+
+ def flatten(self):
+ """Return a list of all node names in aggregation sorted sorter."""
+ return [_tensor_name_base(x.name) for x in self.flatten_nodes()]
+
+ def aggregate_and_return_name_for_input(self, out_graphdef):
+ """This adds the nodes to out_graphdef and returns an aggregated output.
+
+ In particular, if you have 4 inputs to a hint stub, this will be the
+ node that you can use as an output. I.e. if you have 4 timesteps from a
+ static rnn, then a fused UnidirectionalLSTM will expect 1 input with
+ all 4 time steps. So here we make a pack and return the output name of
+ that pack.
+
+ Args:
+ out_graphdef: A graphdef that is ready to have this input added.
+
+ Returns:
+ The name of a pack that aggregates this node.
+ """
+ flattened = self.flatten_nodes()
+ if len(flattened) == 1:
+ return _tensor_name_base(flattened[0].name)
+ else:
+ new_node = _node_def_pb2.NodeDef()
+ new_node.op = "Pack"
+ new_node.name = "OpHintStack-%s" % flattened[0].name
+ new_node.attr["N"].i = len(flattened)
+ new_node.attr["T"].type = flattened[0].attr["T"].type
+ for discrete in flattened:
+ new_node.input.append(_tensor_name_base(discrete.name))
+ out_graphdef.node.extend([new_node])
+ return new_node.name
+
+ def aggregate_and_return_name_for_output(self, fused_op_name, output_index,
+ out_graphdef):
+ """This adds to `out_graphdef` all the unaggregated outputs.
+
+ I.e. we are outputting from a fused stub, but we need to make it compatible
+ with the unfused original graph so we insert an unpack. Ideally in a later
+ stage the unpack -> pack sequences will be removed.
+
+ Args:
+ fused_op_name: The name of the stub we are in the process of fusing.
+ output_index: The output index this object represents.
+ out_graphdef: The graphdef we are in the process of building.
+
+ Returns:
+ The type of the aggregated output (so we can finish building the stub
+ op).
+ """
+ flattened = self.flatten_nodes()
+ if len(flattened) == 1:
+ temp_op = _LiteSingleOperand(flattened[0])
+ return temp_op.aggregate_and_return_name_for_output(
+ fused_op_name, output_index, out_graphdef)
+ else:
+ stack_node = _node_def_pb2.NodeDef()
+ stack_node.op = "Unpack"
+ stack_node.name = "OpHintUnstack-%s" % flattened[0].name
+ stack_node.attr["num"].i = len(flattened)
+ output_type = flattened[0].attr["T"].type
+ stack_node.attr["T"].type = output_type
+ stack_node.input.append(_tensorflow_output_name(
+ fused_op_name, output_index))
+ out_graphdef.node.extend([stack_node])
+
+ for idx, discrete in enumerate(flattened):
+ output_node = _copy.deepcopy(discrete)
+ del output_node.input[:]
+ output_node.input.append(_tensorflow_output_name(stack_node.name, idx))
+ out_graphdef.node.extend([output_node])
+
+ return output_type
+
+ def __str__(self):
+ s = "\t\t\tAGGREGATE %s\n" % self.aggregation
+ for sort, val in _six.iteritems(self.names):
+ s += "\t\t\t%d: %s\n" % (sort, val)
+ return s
class _LiteFuncCall(object):
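To make the aggregation concrete, a sketch of driving the (internal) class directly; `hint_nodes` is an assumed list of NodeDef protos previously discovered as hints, with sort indices 0..n-1:

from tensorflow.core.framework import graph_pb2

out = graph_pb2.GraphDef()
operand = _LiteAggregateOperand(OpHint.AGGREGATE_STACK)
for sort_index, node in enumerate(hint_nodes):  # hint_nodes is assumed
  operand.add(sort_index, node)
# Emits one Pack node ("OpHintStack-<first node name>") into `out` and
# returns its name for use as the fused stub's input.
pack_name = operand.aggregate_and_return_name_for_input(out)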
@@ -212,46 +598,87 @@ class _LiteFuncCall(object):
self.uuid = None
self.params = {}
+ def flattened_inputs_and_outputs(self):
+ """Return a list of inputs and outputs in a flattened format.
+
+ Returns:
+ Tuple of (inputs, outputs), where inputs and outputs are lists of names.
+ """
+ def _flatten(input_or_output_dict):
+ flattened_items = []
+ for item in input_or_output_dict.values():
+ flattened_items.extend(item.flatten())
+ return flattened_items
+
+ return _flatten(self.inputs), _flatten(self.outputs)
+
def __str__(self):
- return "tflite function %s call %s\n\tinputs: %r\n\toutputs: %r" % (
- self.function_name, self.uuid, self.inputs, self.outputs)
+ def format_args(items):
+ s = ""
+ for idx, item in _six.iteritems(items):
+ s += ("\t\t%d:\n" % idx) + str(item)
+ return s
+
+ inputs_str = "\tInputs\n" + format_args(self.inputs)
+ outputs_str = "\tOutputs\n" + format_args(self.outputs)
+ return ("tflite function %s call %s\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s"
+ % (self.function_name, self.uuid, inputs_str, outputs_str))
-def _find_all_hints_in_graph_def(session):
+
+def _find_all_hints_in_graph_def(graphdef):
"""Look at the current default graph and return a list of LiteFuncCall objs.
Args:
- session: A TensorFlow session that contains the graph to convert.
+ graphdef: A TensorFlow graph_def to look for LiteFuncCalls.
Returns:
a list of `_LiteFuncCall` objects.
"""
func_calls = _collections.defaultdict(_LiteFuncCall)
- seen_ops = set()
-
- for op in session.graph.get_operations():
- for operand in _itertools.chain(op.inputs, op.outputs):
- if operand in seen_ops:
- continue
- seen_ops.add(operand)
- attr = operand.op.node_def.attr
- uuid = attr[OpHint.FUNCTION_UUID_ATTR].s
- if OpHint.FUNCTION_UUID_ATTR not in attr:
- continue
- call_def = func_calls[uuid]
- call_def.uuid = uuid
- if OpHint.FUNCTION_UUID_ATTR in attr:
- call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s
- if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr:
- call_def.inputs[attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i] = operand
- if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr:
- call_def.outputs[attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i] = operand
-
- for a in attr:
- if a.startswith("_tflite_attr_"):
- # TODO(aselle): Remember the attribute tensors so we can put them
- # in collapse.
- call_def.params[a.replace("_tflite_attr_,", "")] = attr[a].tensor
+
+ for node in graphdef.node:
+ attr = node.attr
+ # This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip
+ uuid = attr[OpHint.FUNCTION_UUID_ATTR].s
+ if (OpHint.FUNCTION_UUID_ATTR not in attr
+ or not attr[OpHint.FUNCTION_UUID_ATTR].s):
+ continue
+
+ # Start building function
+ call_def = func_calls[uuid]
+ call_def.uuid = uuid
+ call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s
+ # Get sorting and aggregation information
+
+ sort = (attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
+ if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None)
+ if sort == -1: sort = None
+ aggregation = None
+ if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
+ aggregation = attr[OpHint.FUNCTION_AGGREGATE_ATTR].s
+
+ # Add the input or output
+ def put_operand(stuff, index, sort, operand, aggregation):
+ """Add a given index into the function structure."""
+ if sort is None:
+ stuff[index] = _LiteSingleOperand(operand)
+ else:
+ if index not in stuff:
+ stuff[index] = _LiteAggregateOperand(aggregation)
+ stuff[index].add(sort, operand)
+
+ if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr:
+ put_operand(call_def.inputs, attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i,
+ sort, node, aggregation)
+ if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr:
+ put_operand(call_def.outputs, attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i,
+ sort, node, aggregation)
+
+ # Remember attributes
+ for a in attr:
+ if a.startswith("_tflite_attr_"):
+ call_def.params[a.replace("_tflite_attr_,", "")] = attr[a].tensor
return func_calls
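A quick sketch of inspecting what this helper finds (internal API; `sess` is assumed to hold a graph built with OpHints):

hints = _find_all_hints_in_graph_def(sess.graph_def)
for call in hints.values():
  input_names, output_names = call.flattened_inputs_and_outputs()
  print(call.function_name, input_names, output_names)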
@@ -267,42 +694,305 @@ def _tensor_name_base(full_tensor_name):
Returns:
A name without any device assignment.
"""
- return full_tensor_name.name.split(":")[0]
+ if full_tensor_name.startswith("^"):
+ return full_tensor_name[1:]
+ return full_tensor_name.split(":")[0]
+
+
+def _tensorflow_output_name(tensor_name, output_index):
+ return tensor_name if output_index == 0 else "%s:%d" % (tensor_name,
+ output_index)
+
+
+# TODO(aselle): This should be converted to grappler in the future.
+def _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
+ name_to_input_name):
+ """Checks to make sure node only connects to predecessor graph through inputs.
+
+ Args:
+ n: Node to check
+ reachable_by_input: Nodes that are reachable by all inputs of subgraph
+ input_nodes_set: The set of nodes that are "inputs".
+ name_to_input_name: Maps from name to the list of inputs.
+
+ Raises:
+ TypeError: If the given node uses items past inputs directly.
+ """
+ next_to_visit = [n]
+ visited = set()
+ while next_to_visit:
+ current_node = next_to_visit.pop()
+ visited.add(current_node)
+ if (current_node in reachable_by_input
+ and current_node not in input_nodes_set):
+ raise TypeError(
+ "Node %s uses input %s not in input_nodes." % (n, current_node))
+ if current_node not in input_nodes_set:
+ next_to_visit += [
+ input_node for input_node in name_to_input_name[current_node]
+ if input_node not in visited
+ ]
+
+
+# TODO(aselle): This should be converted to grappler in the future.
+def _convert_single_op_hint_to_stub(call, graph_def):
+ """Given a graph_def, converts `call` into a stub and returns a new graph_def.
+ Args:
+ call: A single function call to be converted.
+ graph_def: A graph_def to use as input (one that has `call`, obviously).
+ Returns:
+ A new transformed graph-def that has call as a stub (single op).
-def convert_op_hints_to_stubs(session):
+ Note: after this process, the graph_def can no longer be loaded into
+ the tensorflow runtime, so all future manipulations are done at the
+ graph_def level.
+ """
+ name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
+ graph_def)
+ input_names, output_names = call.flattened_inputs_and_outputs()
+
+ reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
+ reachable_by_output = _bfs_for_reachable_nodes(output_names,
+ name_to_input_name)
+ input_nodes_set = set(input_names)
+ output_nodes_set = set(output_names)
+ nodes_after_fuse = []
+ nodes_deleted_by_fuse = set()
+ # Classify each node. We want to keep everything reachable by input;
+ # nodes reachable by output but not by input are internal to the fused
+ # op, and nodes reachable by neither come after the fused op.
+ for node in graph_def.node:
+ n = _tensor_name_base(node.name)
+ if n in reachable_by_output:
+ if n not in reachable_by_input and n not in output_nodes_set:
+ # n is an internal node. Check to make sure it is really internal.
+ # TODO(aselle): this could be done more efficiently by flooding
+ # the graph first.
+ _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
+ name_to_input_name)
+ nodes_deleted_by_fuse.add(n)
+ elif n not in reachable_by_input:
+ # n is a node that comes after all the fusings, so keep it.
+ nodes_after_fuse.append(n)
+ else:
+ # n is a node that is randomly in the graph but not connected to
+ # the chain of dependencies.
+ pass
+
+ # Make a new graphdef with all the pre-input and input nodes
+ out = _graph_pb2.GraphDef()
+ reachable_by_input_sorted = sorted(
+ list(reachable_by_input), key=lambda n: name_to_seq_num[n])
+ for node in reachable_by_input_sorted:
+ out.node.extend([_copy.deepcopy(name_to_node[node])])
+
+ # Create any stacks to aggregate arguments into a single input
+ # i.e. for static_rnn's.
+ # TODO(aselle): Check that the inputs are complete i.e. 0 to n-1
+ sorted_input_indices = list(call.inputs.keys())
+ sorted_input_indices.sort()
+ sorted_output_indices = list(call.outputs.keys())
+ sorted_output_indices.sort()
+ new_node = _node_def_pb2.NodeDef()
+ # Delegate to each operand to produce the proper new input for this stub node.
+ # In particular, an aggregate input will now be a Pack of some previously
+ # non-fused things.
+ for input_index in sorted_input_indices:
+ inputs = call.inputs[input_index]
+ new_node.input.append(inputs.aggregate_and_return_name_for_input(out))
+ new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(sorted_input_indices)
+
+ # Create the function
+ new_node.op = call.function_name
+ new_node.name = call.uuid
+ out.node.extend([new_node])
+
+ # Now call each output argument to give them a chance to make the proper
+ # output type and add it to our new_node.
+ output_dtypes = []
+ for output_index in sorted_output_indices:
+ output = call.outputs[output_index]
+ output_dtype = (
+ output.aggregate_and_return_name_for_output(new_node.name, output_index,
+ out))
+ output_dtypes.append(output_dtype)
+ new_node.attr["_output_types"].list.type[:] = output_dtypes
+ # TODO(aselle): what is right here?
+ new_node.attr["_output_quantized"].b = False
+
+ # Add post output nodes that do not depend on the outputs
+ for n in nodes_after_fuse:
+ should_keep = True
+ for input_name in name_to_input_name[n]:
+ if input_name in nodes_deleted_by_fuse:
+ should_keep = False
+ if should_keep:
+ out.node.extend([_copy.deepcopy(name_to_node[n])])
+
+ # Misc. graph_def data that needs copying.
+ out.library.CopyFrom(graph_def.library)
+ out.versions.CopyFrom(graph_def.versions)
+
+ return out
+
+
+# TODO(aselle): This should be converted to grappler in the future.
+def _remove_one_redundant_stack_unstack(in_graph_def):
+ """Removes a stack->unstack pattern from in_graph_def in a returned graph.
+
+ Args:
+ in_graph_def: Graph def to use as input.
+ Returns:
+ Simplified tuple (graph_def, changed_something) where changed_something
+ is true if anything was done.
+ """
+ name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
+ in_graph_def)
+ del name_to_seq_num
+
+ # TODO(aselle): Make this not hardcoded.
+ do_generic_pack_unpack = True
+
+ out = _graph_pb2.GraphDef()
+ out.library.CopyFrom(in_graph_def.library)
+ out.versions.CopyFrom(in_graph_def.versions)
+ for n in in_graph_def.node:
+ node_name = _tensor_name_base(n.name)
+ if not node_name.startswith("OpHintStack") and not n.op.startswith("Pack"):
+ continue
+ next_to_visit = [node_name]
+ visited = set()
+
+ unpack_nodes = set()
+ pack_node = node_name
+
+ # Find a pattern of unstack connected to a stack (with identities
+ # in between).
+ matches_pattern = True
+ is_hint_created_stack = False
+ while next_to_visit:
+ current_node_name = next_to_visit[0]
+ visited.add(current_node_name)
+ del next_to_visit[0]
+ node = name_to_node[current_node_name]
+ is_op_hint_stack = node.name.startswith("OpHintStack")
+ is_op_hint_unstack = node.name.startswith("OpHintUnstack")
+ if (node.op == "Identity" or is_op_hint_stack
+ or (do_generic_pack_unpack and node.op == "Pack")):
+ is_hint_created_stack |= is_op_hint_stack
+ next_to_visit += [
+ input_node for input_node in name_to_input_name[current_node_name]
+ if input_node not in visited
+ ]
+ elif (is_op_hint_unstack
+ or (do_generic_pack_unpack and node.op == "Unpack")):
+ unpack_nodes.add(node.name)
+ is_hint_created_stack &= is_op_hint_unstack
+ else:
+ matches_pattern = False
+ break
+ visited.add(node.name)
+
+ if matches_pattern and len(unpack_nodes) == 1:
+ pack_node = node_name
+
+ # Check to see if anyone depends on the intermediate identity or the
+ # Unstacked form
+ no_external_dependency = True
+ for other_n in in_graph_def.node:
+ if other_n.name in visited: continue
+ for input_tensor in name_to_input_name[other_n.name]:
+ input_op = _tensor_name_base(input_tensor)
+ if input_op in visited and input_op != pack_node:
+ no_external_dependency = False
+ # Proceed with the substitution if the stack/unstack pair was created
+ # through hints, or if it was not but nobody is consuming anything
+ # between the stack and unstack.
+ if is_hint_created_stack or no_external_dependency:
+ end = unpack_nodes.pop()
+ end_input = name_to_node[end].input[0]
+ # All nodes that depend on the final stack need to be rewired to use
+ # the unstack's original input instead.
+ for other_n in in_graph_def.node:
+ node_name = _tensor_name_base(other_n.name)
+ if node_name not in visited:
+ new_node = _copy.deepcopy(other_n)
+ new_node.input[:] = [
+ (end_input if stripped == pack_node else
+ non_stripped) for stripped, non_stripped in zip(
+ name_to_input_name[node_name], new_node.input[:])
+ ]
+ out.node.extend([new_node])
+ return out, True
+ return in_graph_def, False
+
+
+def _remove_redundant_stack_unstack(graph_def):
+ curr = graph_def
+ del graph_def
+ changed_stuff = True
+ while changed_stuff:
+ curr, changed_stuff = _remove_one_redundant_stack_unstack(curr)
+ return curr
+
+
+def _convert_op_hints_to_stubs_helper(
+ graph_def, write_callback=lambda sess, graph_def: None):
+ """Converts a graph_def to a new graph_def where all op hints are stubbed.
+
+ Args:
+ graph_def: A graph def that we should convert.
+ write_callback: A function pointer that can be used to write intermediate
+ steps of graph transformation (optional).
+ Returns:
+ A new stubbed graph_def.
+ """
+
+ hints = _find_all_hints_in_graph_def(graph_def)
+ curr_graph_def = graph_def
+ del graph_def # prevent using graph_def again (common source of error)
+ for hint in _six.itervalues(hints):
+ curr_graph_def = _convert_single_op_hint_to_stub(
+ hint, curr_graph_def)
+ write_callback(curr_graph_def, "initial")
+ # The stubbing process can create stacks/unstacks in the case of LSTMs
+ # remove them.
+ curr_graph_def = _remove_redundant_stack_unstack(curr_graph_def)
+ return curr_graph_def
+
+
+def convert_op_hints_to_stubs(session=None,
+ graph_def=None,
+ write_callback=lambda graph_def, comments: None):
"""Converts a graphdef with LiteOp hints into stub operations.
This is used to prepare for toco conversion of complex intrinsic usages.
+ Note: only one of session or graph_def should be used, not both.
Args:
session: A TensorFlow session that contains the graph to convert.
+ graph_def: A graph def that we should convert.
+ write_callback: A function pointer that can be used to write intermediate
+ steps of graph transformation (optional).
Returns:
A new graphdef with all ops contained in OpHints being replaced by
a single op call with the right parameters.
+ Raises:
+ ValueError: If both session and graph_def are provided.
"""
- hints = _find_all_hints_in_graph_def(session)
- current_graph_def = session.graph_def
- for call in hints.values():
- input_names = [None] * len(call.inputs)
- output_names = [None] * len(call.outputs)
- output_dtypes = [None] * len(call.outputs)
- output_quantized = False
- for input_index, tensor in call.inputs.items():
- input_names[input_index] = _tensor_name_base(tensor)
- for output_index, tensor in call.outputs.items():
- output_names[output_index] = _tensor_name_base(tensor)
- output_dtypes[output_index] = tensor.dtype.as_datatype_enum
- # TODO(aselle): Support quantized flag properly
- current_graph_def = _framework.fuse_op(
- current_graph_def, input_names, output_names, output_dtypes,
- output_quantized, call.uuid, call.function_name)
- for node in current_graph_def.node:
- if node.name == call.uuid:
- for param, tensor in call.params.items():
- node.attr[param].tensor.CopyFrom(tensor)
- return current_graph_def
-
-
-_allowed_symbols = ["OpHint", "convert_op_hints_to_stubs"]
+
+ if session is not None and graph_def is not None:
+ raise ValueError("Provide only one of session and graph_def.")
+
+ if session is not None:
+ return _convert_op_hints_to_stubs_helper(session.graph_def, write_callback)
+ elif graph_def is not None:
+ return _convert_op_hints_to_stubs_helper(graph_def, write_callback)
+ else:
+ raise ValueError("Must specify session or graph_def as input.")
+
+
+_allowed_symbols = [
+ "OpHint", "convert_op_hints_to_stubs", "convert_op_hints_to_stubs_new"
+]
remove_undocumented(__name__, _allowed_symbols)
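End to end, the reworked entry point accepts either a Session or a GraphDef, never both; a sketch (`sess` is assumed):

# Stub out all hinted subgraphs before handing the graph to TOCO.
stubbed = tf.contrib.lite.convert_op_hints_to_stubs(graph_def=sess.graph_def)
# Equivalently: tf.contrib.lite.convert_op_hints_to_stubs(session=sess)
# Passing both raises ValueError; the stubbed graph_def is for conversion
# only and can no longer be executed by TensorFlow.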
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index a76cc39635..cc08ed3fe9 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -47,6 +47,9 @@ def _get_toco_converter(flags):
Returns:
TocoConverter object.
+
+ Raises:
+ ValueError: Invalid flags.
"""
# Parse input and output arrays.
input_arrays = _parse_array(flags.input_arrays)
@@ -77,6 +80,9 @@ def _get_toco_converter(flags):
elif flags.keras_model_file:
converter_fn = lite.TocoConverter.from_keras_model_file
converter_kwargs["model_file"] = flags.keras_model_file
+ else:
+ raise ValueError("--graph_def_file, --saved_model_dir, or "
+ "--keras_model_file must be specified.")
return converter_fn(**converter_kwargs)
@@ -103,8 +109,14 @@ def _convert_model(flags):
if flags.mean_values and flags.std_dev_values:
input_arrays = converter.get_input_arrays()
- std_dev_values = _parse_array(flags.std_dev_values, type_fn=int)
- mean_values = _parse_array(flags.mean_values, type_fn=int)
+ std_dev_values = _parse_array(flags.std_dev_values, type_fn=float)
+
+ # In quantized inference, mean_value has to be integer so that the real
+ # value 0.0 is exactly representable.
+ if flags.inference_type == lite_constants.QUANTIZED_UINT8:
+ mean_values = _parse_array(flags.mean_values, type_fn=int)
+ else:
+ mean_values = _parse_array(flags.mean_values, type_fn=float)
quant_stats = list(zip(mean_values, std_dev_values))
if ((not flags.input_arrays and len(input_arrays) > 1) or
(len(input_arrays) != len(quant_stats))):
@@ -126,14 +138,18 @@ def _convert_model(flags):
if flags.reorder_across_fake_quant:
converter.reorder_across_fake_quant = flags.reorder_across_fake_quant
if flags.change_concat_input_ranges:
- converter.change_concat_input_ranges = flags.change_concat_input_ranges
+ converter.change_concat_input_ranges = (
+ flags.change_concat_input_ranges == "TRUE")
if flags.allow_custom_ops:
converter.allow_custom_ops = flags.allow_custom_ops
- if flags.quantize_weights:
+
+ if flags.post_training_quantize:
+ converter.post_training_quantize = flags.post_training_quantize
if flags.inference_type == lite_constants.QUANTIZED_UINT8:
- raise ValueError("--quantized_weights is not supported with "
- "--inference_type=QUANTIZED_UINT8")
- converter.quantize_weights = flags.quantize_weights
+ print("--post_training_quantize quantizes a graph of inference_type "
+ "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.")
+ converter.inference_type = lite_constants.FLOAT
+
if flags.dump_graphviz_dir:
converter.dump_graphviz_dir = flags.dump_graphviz_dir
if flags.dump_graphviz_video:
@@ -286,12 +302,13 @@ def run_main(_):
"--std_dev_values",
type=str,
help=("Standard deviation of training data for each input tensor, "
- "comma-separated integers. Used for quantization. (default None)"))
+ "comma-separated floats. Used for quantized input tensors. "
+ "(default None)"))
parser.add_argument(
"--mean_values",
type=str,
help=("Mean of training data for each input tensor, comma-separated "
- "integers. Used for quantization. (default None)"))
+ "floats. Used for quantized input tensors. (default None)"))
parser.add_argument(
"--default_ranges_min",
type=int,
@@ -304,12 +321,20 @@ def run_main(_):
help=("Default value for max bound of min/max range values used for all "
"arrays without a specified range, Intended for experimenting with "
"quantization via \"dummy quantization\". (default None)"))
+ # quantize_weights is DEPRECATED.
parser.add_argument(
"--quantize_weights",
- type=bool,
- help=("Store float weights as quantized weights followed by dequantize "
- "operations. Inference is still done in FLOAT, but reduces model "
- "size (at the cost of accuracy and latency)."))
+ dest="post_training_quantize",
+ action="store_true",
+ help=argparse.SUPPRESS)
+ parser.add_argument(
+ "--post_training_quantize",
+ dest="post_training_quantize",
+ action="store_true",
+ help=(
+ "Boolean indicating whether to quantize the weights of the "
+ "converted float model. Model size will be reduced and there will "
+ "be latency improvements (at the cost of accuracy). (default False)"))
# Graph manipulation flags.
parser.add_argument(
@@ -327,9 +352,14 @@ def run_main(_):
"the graph. Results in a graph that differs from the quantized "
"training graph, potentially causing differing arithmetic "
"behavior. (default False)"))
+ # This flag must be used as --change_concat_input_ranges=true or
+ # --change_concat_input_ranges=false to make the setting explicit. That keeps
+ # the usage consistent with other tools where the flag's default differs.
+ # The default value here is False.
parser.add_argument(
"--change_concat_input_ranges",
- action="store_true",
+ type=str.upper,
+ choices=["TRUE", "FALSE"],
help=("Boolean to change behavior of min/max ranges for inputs and "
"outputs of the concat operator for quantized models. Changes the "
"ranges of concat operator overlap when true. (default False)"))
diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD
index b616e449e6..55bf2c48b9 100644
--- a/tensorflow/contrib/lite/schema/BUILD
+++ b/tensorflow/contrib/lite/schema/BUILD
@@ -48,7 +48,7 @@ exports_files([
"schema_v3.fbs",
])
-load("//third_party/flatbuffers:build_defs.bzl", "flatbuffer_cc_library")
+load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
# Generic schema for inference on device.
flatbuffer_cc_library(
@@ -56,6 +56,20 @@ flatbuffer_cc_library(
srcs = ["schema.fbs"],
)
+# Generic schema for inference on device (with reflection, which makes it bigger).
+flatbuffer_cc_library(
+ name = "schema_fbs_with_reflection",
+ srcs = ["schema.fbs"],
+ flatc_args = [
+ "--reflect-types",
+ "--reflect-names",
+ "--no-union-value-namespacing",
+ "--gen-object-api",
+ ],
+ gen_reflections = True,
+ out_prefix = "reflection/",
+)
+
# Schema test to make sure we don't introduce backward incompatible changes
# to schemas.
cc_test(
diff --git a/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc b/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc
index cd46a06f7d..11057203a8 100644
--- a/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc
+++ b/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include <fstream>
#include <gtest/gtest.h>
-#include "flatbuffers/flatc.h"
+#include "flatbuffers/flatc.h" // flatbuffers
#include "tensorflow/core/platform/platform.h"
#ifdef PLATFORM_GOOGLE
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 14f88b4c00..cf66403ec9 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -169,6 +169,10 @@ enum BuiltinOperator : byte {
ONE_HOT = 85,
LOGICAL_AND = 86,
LOGICAL_NOT = 87,
+ UNPACK = 88,
+ REDUCE_MIN = 89,
+ FLOOR_DIV = 90,
+ REDUCE_ANY = 91,
}
// Options for the builtin operators.
@@ -236,6 +240,8 @@ union BuiltinOptions {
OneHotOptions,
LogicalAndOptions,
LogicalNotOptions,
+ UnpackOptions,
+ FloorDivOptions,
}
enum Padding : byte { SAME, VALID }
@@ -565,6 +571,14 @@ table LogicalAndOptions {
table LogicalNotOptions {
}
+table UnpackOptions {
+ num:int;
+ axis:int;
+}
+
+table FloorDivOptions {
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
@@ -631,9 +645,9 @@ table SubGraph {
}
// Table of raw data buffers (used for constant tensors). Referenced by tensors
-// by index.
+// by index. The generous alignment accommodates mmap-friendly data structures.
table Buffer {
- data:[ubyte];
+ data:[ubyte] (force_align: 16);
}
table Model {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 3efa153e2c..6d9630d75e 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -220,6 +220,12 @@ struct LogicalAndOptionsT;
struct LogicalNotOptions;
struct LogicalNotOptionsT;
+struct UnpackOptions;
+struct UnpackOptionsT;
+
+struct FloorDivOptions;
+struct FloorDivOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -373,11 +379,15 @@ enum BuiltinOperator {
BuiltinOperator_ONE_HOT = 85,
BuiltinOperator_LOGICAL_AND = 86,
BuiltinOperator_LOGICAL_NOT = 87,
+ BuiltinOperator_UNPACK = 88,
+ BuiltinOperator_REDUCE_MIN = 89,
+ BuiltinOperator_FLOOR_DIV = 90,
+ BuiltinOperator_REDUCE_ANY = 91,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_LOGICAL_NOT
+ BuiltinOperator_MAX = BuiltinOperator_REDUCE_ANY
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[87] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[91] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -465,7 +475,11 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[87] {
BuiltinOperator_LOGICAL_OR,
BuiltinOperator_ONE_HOT,
BuiltinOperator_LOGICAL_AND,
- BuiltinOperator_LOGICAL_NOT
+ BuiltinOperator_LOGICAL_NOT,
+ BuiltinOperator_UNPACK,
+ BuiltinOperator_REDUCE_MIN,
+ BuiltinOperator_FLOOR_DIV,
+ BuiltinOperator_REDUCE_ANY
};
return values;
}
@@ -560,6 +574,10 @@ inline const char **EnumNamesBuiltinOperator() {
"ONE_HOT",
"LOGICAL_AND",
"LOGICAL_NOT",
+ "UNPACK",
+ "REDUCE_MIN",
+ "FLOOR_DIV",
+ "REDUCE_ANY",
nullptr
};
return names;
@@ -635,11 +653,13 @@ enum BuiltinOptions {
BuiltinOptions_OneHotOptions = 61,
BuiltinOptions_LogicalAndOptions = 62,
BuiltinOptions_LogicalNotOptions = 63,
+ BuiltinOptions_UnpackOptions = 64,
+ BuiltinOptions_FloorDivOptions = 65,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_LogicalNotOptions
+ BuiltinOptions_MAX = BuiltinOptions_FloorDivOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[64] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[66] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -704,7 +724,9 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[64] {
BuiltinOptions_LogicalOrOptions,
BuiltinOptions_OneHotOptions,
BuiltinOptions_LogicalAndOptions,
- BuiltinOptions_LogicalNotOptions
+ BuiltinOptions_LogicalNotOptions,
+ BuiltinOptions_UnpackOptions,
+ BuiltinOptions_FloorDivOptions
};
return values;
}
@@ -775,6 +797,8 @@ inline const char **EnumNamesBuiltinOptions() {
"OneHotOptions",
"LogicalAndOptions",
"LogicalNotOptions",
+ "UnpackOptions",
+ "FloorDivOptions",
nullptr
};
return names;
@@ -1041,6 +1065,14 @@ template<> struct BuiltinOptionsTraits<LogicalNotOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_LogicalNotOptions;
};
+template<> struct BuiltinOptionsTraits<UnpackOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_UnpackOptions;
+};
+
+template<> struct BuiltinOptionsTraits<FloorDivOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_FloorDivOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1576,6 +1608,22 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_LogicalNotOptions ?
reinterpret_cast<const LogicalNotOptionsT *>(value) : nullptr;
}
+ UnpackOptionsT *AsUnpackOptions() {
+ return type == BuiltinOptions_UnpackOptions ?
+ reinterpret_cast<UnpackOptionsT *>(value) : nullptr;
+ }
+ const UnpackOptionsT *AsUnpackOptions() const {
+ return type == BuiltinOptions_UnpackOptions ?
+ reinterpret_cast<const UnpackOptionsT *>(value) : nullptr;
+ }
+ FloorDivOptionsT *AsFloorDivOptions() {
+ return type == BuiltinOptions_FloorDivOptions ?
+ reinterpret_cast<FloorDivOptionsT *>(value) : nullptr;
+ }
+ const FloorDivOptionsT *AsFloorDivOptions() const {
+ return type == BuiltinOptions_FloorDivOptions ?
+ reinterpret_cast<const FloorDivOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -5649,6 +5697,112 @@ inline flatbuffers::Offset<LogicalNotOptions> CreateLogicalNotOptions(
flatbuffers::Offset<LogicalNotOptions> CreateLogicalNotOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalNotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct UnpackOptionsT : public flatbuffers::NativeTable {
+ typedef UnpackOptions TableType;
+ int32_t num;
+ int32_t axis;
+ UnpackOptionsT()
+ : num(0),
+ axis(0) {
+ }
+};
+
+struct UnpackOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef UnpackOptionsT NativeTableType;
+ enum {
+ VT_NUM = 4,
+ VT_AXIS = 6
+ };
+ int32_t num() const {
+ return GetField<int32_t>(VT_NUM, 0);
+ }
+ int32_t axis() const {
+ return GetField<int32_t>(VT_AXIS, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_NUM) &&
+ VerifyField<int32_t>(verifier, VT_AXIS) &&
+ verifier.EndTable();
+ }
+ UnpackOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(UnpackOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<UnpackOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnpackOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct UnpackOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_num(int32_t num) {
+ fbb_.AddElement<int32_t>(UnpackOptions::VT_NUM, num, 0);
+ }
+ void add_axis(int32_t axis) {
+ fbb_.AddElement<int32_t>(UnpackOptions::VT_AXIS, axis, 0);
+ }
+ explicit UnpackOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ UnpackOptionsBuilder &operator=(const UnpackOptionsBuilder &);
+ flatbuffers::Offset<UnpackOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<UnpackOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<UnpackOptions> CreateUnpackOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t num = 0,
+ int32_t axis = 0) {
+ UnpackOptionsBuilder builder_(_fbb);
+ builder_.add_axis(axis);
+ builder_.add_num(num);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<UnpackOptions> CreateUnpackOptions(flatbuffers::FlatBufferBuilder &_fbb, const UnpackOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct FloorDivOptionsT : public flatbuffers::NativeTable {
+ typedef FloorDivOptions TableType;
+ FloorDivOptionsT() {
+ }
+};
+
+struct FloorDivOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef FloorDivOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ FloorDivOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(FloorDivOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<FloorDivOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const FloorDivOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct FloorDivOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit FloorDivOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ FloorDivOptionsBuilder &operator=(const FloorDivOptionsBuilder &);
+ flatbuffers::Offset<FloorDivOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<FloorDivOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<FloorDivOptions> CreateFloorDivOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ FloorDivOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<FloorDivOptions> CreateFloorDivOptions(flatbuffers::FlatBufferBuilder &_fbb, const FloorDivOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -5971,6 +6125,12 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const LogicalNotOptions *builtin_options_as_LogicalNotOptions() const {
return builtin_options_type() == BuiltinOptions_LogicalNotOptions ? static_cast<const LogicalNotOptions *>(builtin_options()) : nullptr;
}
+ const UnpackOptions *builtin_options_as_UnpackOptions() const {
+ return builtin_options_type() == BuiltinOptions_UnpackOptions ? static_cast<const UnpackOptions *>(builtin_options()) : nullptr;
+ }
+ const FloorDivOptions *builtin_options_as_FloorDivOptions() const {
+ return builtin_options_type() == BuiltinOptions_FloorDivOptions ? static_cast<const FloorDivOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -6254,6 +6414,14 @@ template<> inline const LogicalNotOptions *Operator::builtin_options_as<LogicalN
return builtin_options_as_LogicalNotOptions();
}
+template<> inline const UnpackOptions *Operator::builtin_options_as<UnpackOptions>() const {
+ return builtin_options_as_UnpackOptions();
+}
+
+template<> inline const FloorDivOptions *Operator::builtin_options_as<FloorDivOptions>() const {
+ return builtin_options_as_FloorDivOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -8441,6 +8609,58 @@ inline flatbuffers::Offset<LogicalNotOptions> CreateLogicalNotOptions(flatbuffer
_fbb);
}
+inline UnpackOptionsT *UnpackOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new UnpackOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void UnpackOptions::UnPackTo(UnpackOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = num(); _o->num = _e; };
+ { auto _e = axis(); _o->axis = _e; };
+}
+
+inline flatbuffers::Offset<UnpackOptions> UnpackOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnpackOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateUnpackOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<UnpackOptions> CreateUnpackOptions(flatbuffers::FlatBufferBuilder &_fbb, const UnpackOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const UnpackOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _num = _o->num;
+ auto _axis = _o->axis;
+ return tflite::CreateUnpackOptions(
+ _fbb,
+ _num,
+ _axis);
+}
+
+inline FloorDivOptionsT *FloorDivOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new FloorDivOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void FloorDivOptions::UnPackTo(FloorDivOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<FloorDivOptions> FloorDivOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FloorDivOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateFloorDivOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<FloorDivOptions> CreateFloorDivOptions(flatbuffers::FlatBufferBuilder &_fbb, const FloorDivOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FloorDivOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateFloorDivOptions(
+ _fbb);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -8882,6 +9102,14 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const LogicalNotOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_UnpackOptions: {
+ auto ptr = reinterpret_cast<const UnpackOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_FloorDivOptions: {
+ auto ptr = reinterpret_cast<const FloorDivOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -9152,6 +9380,14 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const LogicalNotOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_UnpackOptions: {
+ auto ptr = reinterpret_cast<const UnpackOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_FloorDivOptions: {
+ auto ptr = reinterpret_cast<const FloorDivOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -9410,6 +9646,14 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const LogicalNotOptionsT *>(value);
return CreateLogicalNotOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_UnpackOptions: {
+ auto ptr = reinterpret_cast<const UnpackOptionsT *>(value);
+ return CreateUnpackOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_FloorDivOptions: {
+ auto ptr = reinterpret_cast<const FloorDivOptionsT *>(value);
+ return CreateFloorDivOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -9668,6 +9912,14 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new LogicalNotOptionsT(*reinterpret_cast<LogicalNotOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_UnpackOptions: {
+ value = new UnpackOptionsT(*reinterpret_cast<UnpackOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_FloorDivOptions: {
+ value = new FloorDivOptionsT(*reinterpret_cast<FloorDivOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -9990,6 +10242,16 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_UnpackOptions: {
+ auto ptr = reinterpret_cast<UnpackOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_FloorDivOptions: {
+ auto ptr = reinterpret_cast<FloorDivOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
diff --git a/tensorflow/contrib/lite/simple_memory_arena.h b/tensorflow/contrib/lite/simple_memory_arena.h
index f738315cf2..45d0d8735e 100644
--- a/tensorflow/contrib/lite/simple_memory_arena.h
+++ b/tensorflow/contrib/lite/simple_memory_arena.h
@@ -17,7 +17,7 @@ limitations under the License.
#include <list>
#include <memory>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/error_reporter.cc b/tensorflow/contrib/lite/stderr_reporter.cc
index 646913c026..e29a6345fd 100644
--- a/tensorflow/contrib/lite/error_reporter.cc
+++ b/tensorflow/contrib/lite/stderr_reporter.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/stderr_reporter.h"
#include <cstdarg>
#include <cstdio>
@@ -22,26 +22,6 @@ limitations under the License.
namespace tflite {
-ErrorReporter::~ErrorReporter() {}
-
-int ErrorReporter::Report(const char* format, ...) {
- va_list args;
- va_start(args, format);
- int code = Report(format, args);
- va_end(args);
- return code;
-}
-
-// TODO(aselle): Make the name of ReportError on context the same, so
-// we can use the ensure functions w/o a context and w/ a reporter.
-int ErrorReporter::ReportError(void*, const char* format, ...) {
- va_list args;
- va_start(args, format);
- int code = Report(format, args);
- va_end(args);
- return code;
-}
-
int StderrReporter::Report(const char* format, va_list args) {
#ifdef __ANDROID__
// On Android stderr is not captured for applications, only for code run from
diff --git a/tensorflow/contrib/lite/stderr_reporter.h b/tensorflow/contrib/lite/stderr_reporter.h
new file mode 100644
index 0000000000..c6f4ffbdff
--- /dev/null
+++ b/tensorflow/contrib/lite/stderr_reporter.h
@@ -0,0 +1,34 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_
+#define TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_
+
+#include <cstdarg>
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+
+namespace tflite {
+
+// An error reporter that simply writes the message to stderr.
+struct StderrReporter : public ErrorReporter {
+ int Report(const char* format, va_list args) override;
+};
+
+// Returns the default error reporter (outputs to stderr).
+ErrorReporter* DefaultErrorReporter();
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_
diff --git a/tensorflow/contrib/lite/string.h b/tensorflow/contrib/lite/string.h
index 7f8f4e851e..af3fadfcb3 100644
--- a/tensorflow/contrib/lite/string.h
+++ b/tensorflow/contrib/lite/string.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Abstract string. We don't want even absl at this level.
-#ifndef _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_H_
-#define _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_STRING_H_
+#define TENSORFLOW_CONTRIB_LITE_STRING_H_
#include <string>
@@ -26,4 +26,4 @@ using std::string;
} // namespace tflite
-#endif // _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_H_
+#endif // TENSORFLOW_CONTRIB_LITE_STRING_H_
diff --git a/tensorflow/contrib/lite/string_util.cc b/tensorflow/contrib/lite/string_util.cc
index a316a40b62..b991e999b6 100644
--- a/tensorflow/contrib/lite/string_util.cc
+++ b/tensorflow/contrib/lite/string_util.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/interpreter.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/string_util.h b/tensorflow/contrib/lite/string_util.h
index 57f129bf5e..d24627b509 100644
--- a/tensorflow/contrib/lite/string_util.h
+++ b/tensorflow/contrib/lite/string_util.h
@@ -42,7 +42,7 @@ limitations under the License.
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/string.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/string_util_test.cc b/tensorflow/contrib/lite/string_util_test.cc
index d53fec7512..a583a9184b 100644
--- a/tensorflow/contrib/lite/string_util_test.cc
+++ b/tensorflow/contrib/lite/string_util_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/string_util.h"
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/testing/util.h"
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index 89912fd116..3a6c16cafc 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -36,7 +36,7 @@ load(
tags = [
"gen_zip_test",
"no_oss",
- "tflite_not_portable",
+ "tflite_not_portable_intentional",
],
test_name = test_name,
deps = [
@@ -214,6 +214,7 @@ cc_library(
deps = [
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string",
+ "//tensorflow/contrib/lite/core/api",
],
)
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 597ee8fb1e..32f02a4f6c 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -780,10 +780,15 @@ def make_binary_op_tests(zip_path, binary_operator):
"input_shape_2": [[5]],
"activation": [False, True]
}, {
- "dtype": [tf.float32],
+ "dtype": [tf.float32, tf.int32],
"input_shape_1": [[1, 3, 4, 3]],
"input_shape_2": [[3]],
- "activation": [True]
+ "activation": [True, False]
+ }, {
+ "dtype": [tf.float32, tf.int32],
+ "input_shape_1": [[3]],
+ "input_shape_2": [[1, 3, 4, 3]],
+ "activation": [True, False]
}, {
"dtype": [tf.float32],
"input_shape_1": [[]],
@@ -821,13 +826,17 @@ def make_binary_op_tests(zip_path, binary_operator):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
-def make_reduce_tests(reduce_op, min_value=-10, max_value=10):
+def make_reduce_tests(reduce_op,
+ min_value=-10,
+ max_value=10,
+ boolean_tensor_only=False):
"""Make a set of tests to do reduce operation.
Args:
reduce_op: TensorFlow reduce operation to test, i.e. `tf.reduce_mean`.
min_value: min value for created tensor data.
max_value: max value for created tensor data.
+ boolean_tensor_only: If true, will only generate tensors with boolean values.
Returns:
a function representing the true generator with `reduce_op_in` curried.
@@ -867,10 +876,11 @@ def make_reduce_tests(reduce_op, min_value=-10, max_value=10):
def build_graph(parameters):
"""Build the mean op testing graph."""
+ dtype = parameters["input_dtype"]
+ if boolean_tensor_only:
+ dtype = tf.bool
input_tensor = tf.placeholder(
- dtype=parameters["input_dtype"],
- name="input",
- shape=parameters["input_shape"])
+ dtype=dtype, name="input", shape=parameters["input_shape"])
# Get axis as either a placeholder or constants.
if parameters["const_axis"]:
@@ -889,9 +899,12 @@ def make_reduce_tests(reduce_op, min_value=-10, max_value=10):
return input_tensors, [out]
def build_inputs(parameters, sess, inputs, outputs):
+ dtype = parameters["input_dtype"]
+ if boolean_tensor_only:
+ dtype = tf.bool
values = [
create_tensor_data(
- parameters["input_dtype"],
+ dtype,
parameters["input_shape"],
min_value=min_value,
max_value=max_value)
@@ -926,6 +939,16 @@ def make_reduce_max_tests(zip_path):
return make_reduce_tests(tf.reduce_max)(zip_path)
+def make_reduce_min_tests(zip_path):
+ """Make a set of tests to do min."""
+ return make_reduce_tests(tf.reduce_min)(zip_path)
+
+
+def make_reduce_any_tests(zip_path):
+ """Make a set of tests to do any."""
+ return make_reduce_tests(tf.reduce_any, boolean_tensor_only=True)(zip_path)
+
+
def make_exp_tests(zip_path):
"""Make a set of tests to do exp."""
@@ -1080,6 +1103,10 @@ def make_pow_tests(zip_path):
make_binary_op_tests(zip_path, tf.pow)
+def make_floor_div_tests(zip_path):
+ make_binary_op_tests(zip_path, tf.floor_div)
+
+
def make_gather_tests(zip_path):
"""Make a set of tests to do gather."""
@@ -1652,6 +1679,7 @@ def make_pad_tests(zip_path):
# TODO(nupurgarg): Add test for tf.uint8.
test_parameters = [
+ # 4D:
{
"dtype": [tf.int32, tf.int64, tf.float32],
"input_shape": [[1, 1, 2, 1], [2, 1, 1, 1]],
@@ -1659,13 +1687,20 @@ def make_pad_tests(zip_path):
[0, 0], [2, 3]]],
"constant_paddings": [True, False],
},
- # Non-4D use case.
+ # 2D:
{
"dtype": [tf.int32, tf.int64, tf.float32],
- "input_shape": [[1, 2], [0, 1, 2]],
+ "input_shape": [[1, 2]],
"paddings": [[[0, 1], [2, 3]]],
"constant_paddings": [True, False],
},
+ # 1D:
+ {
+ "dtype": [tf.int32],
+ "input_shape": [[1]],
+ "paddings": [[[1, 2]]],
+ "constant_paddings": [False],
+ },
]
def build_graph(parameters):
@@ -1703,6 +1738,7 @@ def make_padv2_tests(zip_path):
# TODO(nupurgarg): Add test for tf.uint8.
test_parameters = [
+ # 4D:
{
"dtype": [tf.int32, tf.int64, tf.float32],
"input_shape": [[1, 1, 2, 1], [2, 1, 1, 1]],
@@ -1711,14 +1747,22 @@ def make_padv2_tests(zip_path):
"constant_paddings": [True, False],
"constant_values": [0, 2],
},
- # Non-4D use case.
+ # 2D:
{
"dtype": [tf.int32, tf.int64, tf.float32],
- "input_shape": [[1, 2], [0, 1, 2]],
+ "input_shape": [[1, 2]],
"paddings": [[[0, 1], [2, 3]]],
"constant_paddings": [True, False],
"constant_values": [0, 2],
},
+ # 1D:
+ {
+ "dtype": [tf.int32],
+ "input_shape": [[1]],
+ "paddings": [[[0, 1]]],
+ "constant_paddings": [False],
+ "constant_values": [0, 2],
+ },
]
def build_graph(parameters):
@@ -2373,7 +2417,7 @@ def make_lstm_tests(zip_path):
"time_step_size": [1],
"input_vec_size": [3],
"num_cells": [4],
- "split_tflite_lstm_inputs": [True, False],
+ "split_tflite_lstm_inputs": [False],
},
]
@@ -3144,6 +3188,36 @@ def make_pack_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_unpack_tests(zip_path):
+ """Make a set of tests to do unstack."""
+
+ test_parameters = [{
+ "base_shape": [[3, 4, 3], [3, 4], [5, 6, 7, 8]],
+ "axis": [0, 1, 2, 3],
+ }]
+
+ def get_valid_axis(parameters):
+ """Return a tweaked version of 'axis'."""
+ axis = parameters["axis"]
+ shape = parameters["base_shape"][:]
+ while axis > len(shape) - 1:
+ axis -= 1
+ return axis
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name=("input"), shape=parameters["base_shape"])
+ outs = tf.unstack(input_tensor, axis=get_valid_axis(parameters))
+ return [input_tensor], outs
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value = create_tensor_data(np.float32, shape=parameters["base_shape"])
+ return [input_value], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def _make_logical_tests(op):
"""Make a set of tests to do logical operations."""
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index e67fee2a1c..349aa5a3b4 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -58,12 +58,6 @@ tensorflow::Env* env = tensorflow::Env::Default();
// Key is a substring of the test name and value is a bug number.
// TODO(ahentz): make sure we clean this list up frequently.
std::map<string, string> kBrokenTests = {
- // Pad and PadV2 only supports 4D tensors.
- {R"(^\/pad.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])",
- "70527055"},
- {R"(^\/padv2.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])",
- "70527055"},
-
// L2Norm only supports tensors with 4D or fewer.
{R"(^\/l2norm_dim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"},
@@ -101,6 +95,15 @@ std::map<string, string> kBrokenTests = {
"77546240"},
{R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.\])",
"77546240"},
+
+ // No support for float.
+ {R"(^\/floor_div.*dtype=tf\.float32)", "112859002"},
+
+ // Relu does not support int32.
+ // These test cases append a Relu after the tested ops when
+ // activation=True, and they fail because Relu does not support int32.
+ {R"(^\/div.*activation=True.*dtype=tf\.int32)", "112968789"},
+ {R"(^\/floor_div.*activation=True.*dtype=tf\.int32)", "112968789"},
};
// Allows test data to be unarchived into a temporary directory and makes
diff --git a/tensorflow/contrib/lite/testing/parse_testdata.h b/tensorflow/contrib/lite/testing/parse_testdata.h
index d94361d735..26ee825866 100644
--- a/tensorflow/contrib/lite/testing/parse_testdata.h
+++ b/tensorflow/contrib/lite/testing/parse_testdata.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_
-#define TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_PARSE_TESTDATA_H_
+#define TENSORFLOW_CONTRIB_LITE_TESTING_PARSE_TESTDATA_H_
#include <vector>
#include "tensorflow/contrib/lite/interpreter.h"
@@ -72,4 +72,4 @@ bool ParseAndRunTests(std::istream* input, TestRunner* test_runner,
} // namespace testing
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TESTING_PARSE_TESTDATA_H_
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
index 4dacf9c84b..1836eb53b9 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.cc
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -302,28 +302,6 @@ bool TfLiteDriver::CheckResults() {
void TfLiteDriver::ResetLSTMStateTensors() {
interpreter_->ResetVariableTensorsToZero();
-
- // Below is a workaround for initializing state tensors for LSTM.
- // TODO(ycling): Remove the code below after nobody is using the 18-inputs
- // definition.
- for (auto node_index : interpreter_->execution_plan()) {
- const auto& node_and_reg = interpreter_->node_and_registration(node_index);
- const auto& node = node_and_reg->first;
- const auto& registration = node_and_reg->second;
-
- if (registration.builtin_code == tflite::BuiltinOperator_LSTM) {
- const auto* params =
- reinterpret_cast<const TfLiteLSTMParams*>(node.builtin_data);
- if (params->kernel_type == kTfLiteLSTMFullKernel &&
- node.inputs->size == 18 && node.outputs->size >= 2) {
- // The first 2 outputs of LSTM are state tensors.
- for (int i = 0; i < 2; ++i) {
- int node_index = node.outputs->data[i];
- ResetTensor(node_index);
- }
- }
- }
- }
}
} // namespace testing
diff --git a/tensorflow/contrib/lite/testing/tokenize.h b/tensorflow/contrib/lite/testing/tokenize.h
index 7ed8eb96b7..8195391851 100644
--- a/tensorflow/contrib/lite/testing/tokenize.h
+++ b/tensorflow/contrib/lite/testing/tokenize.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_
-#define TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZE_H_
+#define TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZE_H_
#include <istream>
#include <string>
@@ -39,4 +39,4 @@ void Tokenize(std::istream* input, TokenProcessor* processor);
} // namespace testing
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZE_H_
diff --git a/tensorflow/contrib/lite/testing/util.h b/tensorflow/contrib/lite/testing/util.h
index 8aa639157b..925791d390 100644
--- a/tensorflow/contrib/lite/testing/util.h
+++ b/tensorflow/contrib/lite/testing/util.h
@@ -17,7 +17,7 @@ limitations under the License.
#include <cstdio>
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/string.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 02d0890a7a..bea90f1ce8 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -213,7 +213,6 @@ cc_library(
"graph_transformations/quantization_util.cc",
"graph_transformations/quantization_util.h",
"graph_transformations/quantize.cc",
- "graph_transformations/quantize_weights.cc",
"graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc",
"graph_transformations/remove_final_dequantize_op.cc",
"graph_transformations/remove_tensorflow_assert.cc",
@@ -373,6 +372,7 @@ cc_library(
":toco_graphviz_dump_options",
":toco_port",
":types_proto_cc",
+ "//tensorflow/contrib/lite/kernels/internal:types",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@com_googlesource_code_re2//:re2",
diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
index 1f3ea2e1c7..18c904c6d4 100644
--- a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
+++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
@@ -106,6 +106,17 @@ class Allocator {
// Core allocation routine.
void Allocate(std::size_t size, Alloc* result) {
+ if (size == 0) {
+ // zero-sized arrays get a dummy alloc of (0, 0) that does not
+ // need to be kept in the books (no need to insert that into
+ // live_allocs_).
+ // Note: zero-sized arrays shouldn't exist, but handling that case
+ // here allows such pathological cases to get a cleaner error message
+ // later instead of generating spurious allocator failures.
+ result->start = 0;
+ result->end = 0;
+ return;
+ }
// Naive algorithm: pick the first gap between live allocations,
// that is wide enough for the new array.
std::size_t pos = 0;
@@ -128,6 +139,11 @@ class Allocator {
}
void Deallocate(const Alloc& a) {
+ // Special-case dummy allocs for zero-sized arrays.
+ if (a.start == 0 && a.end == 0) {
+ // Nothing needs to be done, these aren't kept in the books.
+ return;
+ }
auto iter = std::lower_bound(live_allocs_.begin(), live_allocs_.end(), a);
CHECK(iter != live_allocs_.end());
CHECK(*iter == a);
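For context on the "naive algorithm" the comments above refer to: a minimal Python sketch (not the C++ implementation) of first-fit placement over live allocations kept sorted by start address:

```python
# First-fit: scan live allocations in address order and place the new
# block in the first gap wide enough to hold it, else append at the end.
def first_fit(live_allocs, size):
    if size == 0:
        return (0, 0)  # dummy alloc for zero-sized arrays, never tracked
    pos = 0
    for start, end in live_allocs:
        if start - pos >= size:   # the gap before this allocation fits
            break
        pos = max(pos, end)
    return (pos, pos + size)

print(first_fit([(0, 16), (32, 64)], 8))   # (16, 24): fits in the gap
print(first_fit([(0, 16), (24, 64)], 16))  # (64, 80): appended at the end
```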
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index aef35ad490..f14dbc258b 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -236,8 +236,9 @@ struct ParsedTocoFlags {
Arg<bool> drop_fake_quant = Arg<bool>(false);
Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
Arg<bool> allow_custom_ops = Arg<bool>(false);
- Arg<bool> quantize_weights = Arg<bool>(false);
+ Arg<bool> post_training_quantize = Arg<bool>(false);
// Deprecated flags
+ Arg<bool> quantize_weights = Arg<bool>(false);
Arg<string> input_type;
Arg<string> input_types;
Arg<bool> debug_disable_recurrent_cell_fusion = Arg<bool>(false);
@@ -246,6 +247,10 @@ struct ParsedTocoFlags {
Arg<bool> allow_nudging_weights_to_use_fast_gemm_kernel = Arg<bool>(false);
Arg<int64> dedupe_array_min_size_bytes = Arg<int64>(64);
Arg<bool> split_tflite_lstm_inputs = Arg<bool>(true);
+ // WARNING: Experimental interface, subject to change
+ Arg<bool> allow_eager_ops = Arg<bool>(false);
+ // WARNING: Experimental interface, subject to change
+ Arg<bool> force_eager_ops = Arg<bool>(false);
};
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 02671f0408..b52a79282c 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -1701,9 +1701,11 @@ void ConvertReduceOperator(const Model& model, const T& src_op,
*new_op->add_input() = src_op.inputs[0];
*new_op->add_input() = src_op.inputs[1];
- const tensorflow::DataType params_type =
- GetTensorFlowDataType(model, src_op.inputs[0]);
- (*new_op->mutable_attr())["T"].set_type(params_type);
+ if (src_op.type != OperatorType::kAny) {
+ const tensorflow::DataType params_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*new_op->mutable_attr())["T"].set_type(params_type);
+ }
const tensorflow::DataType indices_type =
GetTensorFlowDataType(model, src_op.inputs[1]);
(*new_op->mutable_attr())["Tidx"].set_type(indices_type);
@@ -1900,21 +1902,6 @@ void ConvertPowOperator(const Model& model, const PowOperator& src_op,
(*pow_op->mutable_attr())["T"].set_type(data_type);
}
-void ConvertAnyOperator(const Model& model, const AnyOperator& src_op,
- GraphDef* tensorflow_graph) {
- tensorflow::NodeDef* any_op = tensorflow_graph->add_node();
- any_op->set_op("Any");
- any_op->set_name(src_op.outputs[0]);
- CHECK_EQ(src_op.inputs.size(), 2);
- for (int i = 0; i < 2; ++i) {
- *any_op->add_input() = src_op.inputs[i];
- }
- const tensorflow::DataType data_type =
- GetTensorFlowDataType(model, src_op.inputs[1]);
- (*any_op->mutable_attr())["Tidx"].set_type(data_type);
- (*any_op->mutable_attr())["keep_dims"].set_b(src_op.keep_dims);
-}
-
void ConvertLogicalAndOperator(const Model& model,
const LogicalAndOperator& src_op,
GraphDef* tensorflow_graph) {
@@ -1967,6 +1954,20 @@ void ConvertCTCBeamSearchDecoderOperator(
(*op->mutable_attr())["merge_repeated"].set_b(src_op.merge_repeated);
}
+void ConvertUnpackOperator(const Model& model, const UnpackOperator& src_op,
+ const char* op_name, GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* unpack_op = tensorflow_graph->add_node();
+ unpack_op->set_op(op_name);
+ unpack_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *unpack_op->add_input() = src_op.inputs[0];
+ const tensorflow::DataType data_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*unpack_op->mutable_attr())["T"].set_type(data_type);
+ (*unpack_op->mutable_attr())["num"].set_i(src_op.num);
+ (*unpack_op->mutable_attr())["axis"].set_i(src_op.axis);
+}
+
void ConvertOperator(const Model& model, const Operator& src_op,
GraphDef* tensorflow_graph) {
if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -2118,7 +2119,7 @@ void ConvertOperator(const Model& model, const Operator& src_op,
tensorflow_graph, "Prod");
} else if (src_op.type == OperatorType::kReduceMin) {
ConvertReduceOperator(model,
- static_cast<const TensorFlowMaxOperator&>(src_op),
+ static_cast<const TensorFlowMinOperator&>(src_op),
tensorflow_graph, "Min");
} else if (src_op.type == OperatorType::kReduceMax) {
ConvertReduceOperator(model,
@@ -2207,8 +2208,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertPowOperator(model, static_cast<const PowOperator&>(src_op), "Pow",
tensorflow_graph);
} else if (src_op.type == OperatorType::kAny) {
- ConvertAnyOperator(model, static_cast<const AnyOperator&>(src_op),
- tensorflow_graph);
+ ConvertReduceOperator(model,
+ static_cast<const TensorFlowAnyOperator&>(src_op),
+ tensorflow_graph, "Any");
} else if (src_op.type == OperatorType::kLogicalAnd) {
ConvertLogicalAndOperator(model,
static_cast<const LogicalAndOperator&>(src_op),
@@ -2228,6 +2230,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertCTCBeamSearchDecoderOperator(
model, static_cast<const CTCBeamSearchDecoderOperator&>(src_op),
"CTCBeamSearchDecoder", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kUnpack) {
+ ConvertUnpackOperator(model, static_cast<const UnpackOperator&>(src_op),
+ "Unpack", tensorflow_graph);
} else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
}
diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
index 4bf47aa3c4..84680b968e 100644
--- a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
+++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
@@ -24,8 +24,8 @@ Table of contents:
* [Multiple output arrays](#multiple-output-arrays)
* [Specifying subgraphs](#specifying-subgraphs)
* [Graph visualizations](#graph-visualizations)
- * [Using --output_format=GRAPHVIZ_DOT](#using-output-formatgraphviz-dot)
- * [Using --dump_graphviz](#using-dump-graphviz)
+ * [Using --output_format=GRAPHVIZ_DOT](#using-output-format-graphviz-dot)
+ * [Using --dump_graphviz_dir](#using-dump-graphviz-dir)
* [Graph "video" logging](#graph-video-logging)
* [Legend for the graph visualizations](#graphviz-legend)
@@ -247,17 +247,17 @@ function tends to get fused).
## Graph visualizations
-TOCO can export a graph to the GraphViz Dot format for easy visualization via
+TOCO can export a graph to the Graphviz Dot format for easy visualization via
either the `--output_format` flag or the `--dump_graphviz_dir` flag. The
subsections below outline the use cases for each.
-### Using `--output_format=GRAPHVIZ_DOT`
+### Using `--output_format=GRAPHVIZ_DOT` <a name="using-output-format-graphviz-dot"></a>
-The first way to get a graphviz rendering is to pass `GRAPHVIZ_DOT` into
+The first way to get a Graphviz rendering is to pass `GRAPHVIZ_DOT` into
`--output_format`. This results in a plausible visualization of the graph. This
-reduces the requirements that exist during conversion between other input and
-output formats. This may be useful if conversion from TENSORFLOW_GRAPHDEF to
-TFLITE is failing.
+reduces the requirements that exist during conversion from a TensorFlow GraphDef
+to a TensorFlow Lite FlatBuffer. This may be useful if the conversion to TFLite
+is failing.
```
curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \
@@ -287,10 +287,10 @@ google-chrome /tmp/foo.dot.pdf
Example PDF files are viewable online in the next section.
-### Using `--dump_graphviz`
+### Using `--dump_graphviz_dir`
-The second way to get a graphviz rendering is to pass the `--dump_graphviz_dir`
-flag, specifying a destination directory to dump GraphViz rendering to. Unlike
+The second way to get a Graphviz rendering is to pass the `--dump_graphviz_dir`
+flag, specifying a destination directory to dump Graphviz rendering to. Unlike
the previous approach, this one retains the original output format. This
provides a visualization of the actual graph resulting from a specific
conversion process.
diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
index decc8a45a4..00bc8d4ccb 100644
--- a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
+++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
@@ -38,7 +38,7 @@ files. The flag `--output_file` is always required. Additionally, either
of TFLite specific transformations. Therefore, the resulting
visualization may not reflect the final set of graph
transformations. To get a final visualization with all graph
- transformations use `--dump_graphviz` instead.
+ transformations use `--dump_graphviz_dir` instead.
The following flags specify optional parameters when using SavedModels.
@@ -67,21 +67,22 @@ based on index.
* `--input_shapes`. Type: colon-separated list of comma-separated lists of
integers. Each comma-separated list of integers gives the shape of one of
- the input arrays specified in [TensorFlow
- convention](https://www.tensorflow.org/versions/r1.2/programmers_guide/dims_types#shape).
+ the input arrays specified in
+ [TensorFlow convention](https://www.tensorflow.org/versions/r1.2/programmers_guide/dims_types#shape).
* Example: `--input_shapes=1,60,80,3` for a typical vision model means a
batch size of 1, an input image height of 60, an input image width of
80, and an input image depth of 3 (representing RGB channels).
* Example: `--input_arrays=foo,bar --input_shapes=2,3:4,5,6` means "foo"
has a shape of [2, 3] and "bar" has a shape of [4, 5, 6].
-* `--std_dev_values`, `--mean_values`. Type: comma-separated list of integers.
+* `--std_dev_values`, `--mean_values`. Type: comma-separated list of floats.
These specify the (de-)quantization parameters of the input array, when it
- is quantized.
+ is quantized. This is only needed if `inference_input_type` is
+ `QUANTIZED_UINT8`.
* The meaning of `mean_values` and `std_dev_values` is as follows: each
quantized value in the quantized input array will be interpreted as a
mathematical real number (i.e. as an input activation value) according
to the following formula:
- * `real_value = (quantized_input_value - mean_value) / std_value`.
+ * `real_value = (quantized_input_value - mean_value) / std_dev_value`.
* When performing float inference (`--inference_type=FLOAT`) on a
quantized input, the quantized input would be immediately dequantized by
the inference code according to the above formula, before proceeding
@@ -91,7 +92,8 @@ based on index.
the inference code. However, the quantization parameters of all arrays,
including those of the input arrays as specified by `mean_value` and
`std_dev_value`, determine the fixed-point multipliers used in the
- quantized inference code.
+ quantized inference code. `mean_value` must be an integer when
+ performing quantized inference.
## Transformation flags
@@ -147,10 +149,10 @@ have.
true, custom ops are created for any op that is unknown. The developer will
need to provide these to the TensorFlow Lite runtime with a custom resolver.
-* `--quantize_weights`. Type: boolean. Default: False. Indicates whether to
- store weights as quantized weights followed by dequantize operations.
- Computation is still done in float, but reduces model size (at the cost of
- accuracy and latency).
+* `--post_training_quantize`. Type: boolean. Default: False. Boolean
+ indicating whether to quantize the weights of the converted float model.
+ Model size will be reduced and there will be latency improvements (at the
+ cost of accuracy).
## Logging flags
diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md
index 3799eac0a1..51f808d4f0 100644
--- a/tensorflow/contrib/lite/toco/g3doc/python_api.md
+++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md
@@ -70,6 +70,7 @@ val = img + var
out = tf.identity(val, name="out")
with tf.Session() as sess:
+ sess.run(tf.global_variables_initializer())
converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out])
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
diff --git a/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg b/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg
index 262e13a591..335debde57 100644
--- a/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg
+++ b/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg
@@ -1 +1 @@
-<svg version="1.1" viewBox="0.0 0.0 720.0 540.0" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg"><clipPath id="p.0"><path d="m0 0l720.0 0l0 540.0l-720.0 0l0 -540.0z" clip-rule="nonzero"/></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l720.0 0l0 540.0l-720.0 0z" fill-rule="evenodd"/><path fill="#f3f3f3" d="m19.375328 28.750656l361.6378 0l0 358.01575l-361.6378 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m19.375328 28.750656l361.6378 0l0 358.01575l-361.6378 0z" fill-rule="evenodd"/><path fill="#434343" d="m338.49512 374.66016q-0.609375 0 -1.171875 -0.140625q-0.546875 -0.15625 -0.96875 -0.421875q-0.25 -0.15625 -0.359375 -0.296875q-0.09375 -0.140625 -0.09375 -0.34375q0 -0.171875 0.09375 -0.28125q0.109375 -0.109375 0.265625 -0.109375q0.171875 0 0.46875 0.1875q0.40625 0.25 0.796875 0.390625q0.390625 0.140625 0.984375 0.140625q0.71875 0 1.109375 -0.25q0.40625 -0.265625 0.40625 -0.734375q0 -0.296875 -0.15625 -0.46875q-0.140625 -0.1875 -0.5 -0.328125q-0.359375 -0.140625 -1.046875 -0.296875q-1.171875 -0.25 -1.6875 -0.671875q-0.5 -0.421875 -0.5 -1.15625q0 -0.578125 0.3125 -1.015625q0.328125 -0.4375 0.890625 -0.6875q0.5625 -0.265625 1.28125 -0.265625q0.53125 0 1.015625 0.140625q0.484375 0.140625 0.859375 0.390625q0.453125 0.328125 0.453125 0.671875q0 0.171875 -0.109375 0.296875q-0.109375 0.125 -0.25 0.125q-0.15625 0 -0.484375 -0.234375q-0.375 -0.234375 -0.703125 -0.359375q-0.328125 -0.140625 -0.828125 -0.140625q-0.625 0 -1.015625 0.28125q-0.375 0.265625 -0.375 0.734375q0 0.296875 0.140625 0.484375q0.140625 0.171875 0.46875 0.3125q0.328125 0.140625 0.9375 0.28125q0.90625 0.1875 1.40625 0.4375q0.5 0.234375 0.703125 0.578125q0.21875 0.34375 0.21875 0.890625q0 0.828125 -0.703125 1.34375q-0.703125 0.515625 -1.859375 0.515625zm9.241241 -1.59375q0.140625 0 0.25 0.125q0.109375 0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 -0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.5551147 -0.8125q0.546875 -0.03125 0.546875 0.453125q0 0.21875 -0.125 0.34375q-0.109375 0.125 -0.40625 0.15625l-0.390625 0.03125q-0.890625 0.078125 -1.328125 0.640625q-0.4375 0.546875 -0.4375 1.296875l0 3.234375q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.21875 0 0.359375 0.140625q0.140625 0.140625 0.140625 0.375l0 0.75q0.28125 -0.578125 0.796875 -0.890625q0.515625 -0.3125 1.1875 -0.359375l0.1875 -0.015625zm6.157959 0.328125q0.15625 -0.3125 0.46875 -0.3125q0.203125 0 0.359375 0.140625q0.15625 0.125 0.15625 0.328125q0 0.109375 -0.046875 0.203125l-2.59375 5.609375q-0.078125 0.171875 -0.25 0.28125q-0.15625 0.09375 -0.34375 0.09375q-0.171875 0 -0.328125 -0.09375q-0.15625 -0.109375 
-0.25 -0.28125l-2.59375 -5.609375q-0.046875 -0.09375 -0.046875 -0.1875q0 -0.203125 0.171875 -0.34375q0.1875 -0.15625 0.390625 -0.15625q0.140625 0 0.265625 0.078125q0.125 0.078125 0.1875 0.234375l2.234375 5.0l2.21875 -4.984375zm7.2099915 4.796875q0.140625 0 0.25 0.125q0.109375 0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 -0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.5551453 -0.8125q0.546875 -0.03125 0.546875 0.453125q0 0.21875 -0.125 0.34375q-0.109375 0.125 -0.40625 0.15625l-0.390625 0.03125q-0.890625 0.078125 -1.328125 0.640625q-0.4375 0.546875 -0.4375 1.296875l0 3.234375q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.21875 0 0.359375 0.140625q0.140625 0.140625 0.140625 0.375l0 0.75q0.28125 -0.578125 0.796875 -0.890625q0.515625 -0.3125 1.1875 -0.359375l0.1875 -0.015625z" fill-rule="nonzero"/><path fill="#d9d9d9" d="m25.624672 36.249344l301.88977 0l0 69.98425l-301.88977 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" stroke-dasharray="4.0,3.0" d="m25.624672 36.249344l301.88977 0l0 69.98425l-301.88977 0z" fill-rule="evenodd"/><path fill="#434343" d="m134.36497 56.831844q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm9.004181 -1.421875q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.839676 -0.75q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 
0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm5.84729 6.0625q-0.56248474 0 -1.0624847 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.87498474 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0624847 -0.234375 -1.5156097 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.1562347 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.56248474 0 -0.90623474 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84373474 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125zm6.2131653 0q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm7.1288147 -5.25q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm1.970398 6.03125q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.5434265 0q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125zm4.721527 0.015625q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 
-1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm12.222534 -4.9375q0.125 -0.28125 0.390625 -0.28125q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.078125 -0.03125 0.171875l-1.984375 5.046875q-0.078125 0.15625 -0.21875 0.25q-0.140625 0.078125 -0.296875 0.078125q-0.15625 0 -0.296875 -0.078125q-0.140625 -0.09375 -0.21875 -0.25l-1.65625 -4.21875l-1.640625 4.21875q-0.0625 0.15625 -0.203125 0.25q-0.140625 0.078125 -0.3125 0.078125q-0.15625 0 -0.296875 -0.078125q-0.140625 -0.09375 -0.21875 -0.25l-1.984375 -5.03125q-0.046875 -0.09375 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.171875 -0.140625 0.359375 -0.140625q0.296875 0 0.40625 0.296875l1.65625 4.421875l1.6875 -4.390625q0.078125 -0.15625 0.203125 -0.234375q0.125 -0.09375 0.265625 -0.09375q0.15625 0 0.28125 0.09375q0.125 0.078125 0.1875 0.234375l1.6875 4.375l1.65625 -4.40625zm12.637604 5.09375q0.046875 0.09375 0.046875 0.203125q0 0.171875 -0.140625 0.296875q-0.140625 0.125 -0.328125 0.125q-0.296875 0 -0.421875 -0.296875l-0.84375 -1.9375l-4.53125 0l-0.859375 1.9375q-0.125 0.296875 -0.421875 0.296875q-0.1875 0 -0.34375 -0.125q-0.140625 -0.125 -0.140625 -0.3125q0 -0.09375 0.046875 -0.1875l3.4375 -7.640625q0.078125 -0.15625 0.21875 -0.234375q0.140625 -0.09375 0.3125 -0.09375q0.171875 0 0.3125 0.09375q0.15625 0.078125 0.21875 0.234375l3.4375 7.640625zm-5.859375 -2.421875l3.8125 0l-1.90625 -4.3125l-1.90625 4.3125zm7.78656 3.046875q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm4.9744263 4.34375q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm4.4157715 0.015625q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 
0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125z" fill-rule="nonzero"/><path fill="#f3f3f3" d="m396.75067 183.75066l249.00787 0l0 203.02364l-249.00787 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m396.75067 183.75066l249.00787 0l0 203.02364l-249.00787 0z" fill-rule="evenodd"/><path fill="#434343" d="m409.42255 374.66803q-0.90625 0 -1.609375 -0.40625q-0.6875 -0.421875 -1.078125 -1.171875q-0.375 -0.765625 -0.375 -1.765625q0 -1.0 0.390625 -1.765625q0.40625 -0.78125 1.109375 -1.203125q0.703125 -0.4375 1.625 -0.4375q0.5 0 1.0 0.140625q0.5 0.140625 0.875 0.40625q0.234375 0.171875 0.328125 0.328125q0.109375 0.140625 0.109375 0.328125q0 0.1875 -0.109375 0.3125q-0.09375 0.109375 -0.25 0.109375q-0.09375 0 -0.203125 -0.046875q-0.09375 -0.046875 -0.171875 -0.09375q-0.078125 -0.0625 -0.09375 -0.078125q-0.359375 -0.234375 -0.671875 -0.359375q-0.3125 -0.140625 -0.765625 -0.140625q-0.96875 0 -1.515625 0.671875q-0.53125 0.65625 -0.53125 1.828125q0 1.171875 0.53125 1.8125q0.546875 0.640625 1.515625 0.640625q0.453125 0 0.78125 -0.125q0.328125 -0.140625 0.65625 -0.375q0.15625 -0.09375 0.28125 -0.15625q0.140625 -0.0625 0.234375 -0.0625q0.140625 0 0.234375 0.125q0.109375 0.109375 0.109375 0.296875q0 0.171875 -0.09375 0.3125q-0.09375 0.140625 -0.34375 0.3125q-0.375 0.25 -0.90625 0.40625q-0.515625 0.15625 -1.0625 0.15625zm4.2591553 -0.03125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -8.46875q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.21875 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 8.46875q0 0.25 -0.15625 0.390625q-0.15625 0.140625 -0.375 0.140625zm3.092102 0q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.234375 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 5.625q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125zm0 -8.09375q-0.3125 0 -0.515625 -0.171875q-0.203125 -0.1875 -0.203125 -0.5q0 -0.296875 0.203125 -0.484375q0.203125 -0.1875 0.515625 -0.1875q0.328125 0 0.515625 0.1875q0.203125 0.1875 0.203125 0.484375q0 0.3125 -0.203125 0.5q-0.1875 0.171875 -0.515625 0.171875zm7.5765076 6.53125q0.140625 0 0.25 0.125q0.109375 0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 -0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.6020203 -0.84375q2.328125 0 2.328125 2.578125l0 3.609375q0 0.25 -0.140625 0.390625q-0.140625 0.140625 -0.390625 0.140625q-0.25 0 -0.40625 -0.140625q-0.140625 -0.140625 -0.140625 -0.390625l0 -3.546875q0 -0.90625 -0.359375 -1.3125q-0.34375 -0.421875 -1.125 -0.421875q-0.890625 0 -1.421875 0.546875q-0.53125 0.546875 -0.53125 1.484375l0 3.25q0 0.25 -0.140625 0.390625q-0.140625 
0.140625 -0.390625 0.140625q-0.25 0 -0.40625 -0.140625q-0.140625 -0.140625 -0.140625 -0.390625l0 -5.625q0 -0.234375 0.140625 -0.375q0.15625 -0.15625 0.40625 -0.15625q0.234375 0 0.375 0.15625q0.140625 0.140625 0.140625 0.359375l0 0.6875q0.328125 -0.609375 0.890625 -0.921875q0.578125 -0.3125 1.3125 -0.3125zm7.304718 5.875q0.46875 0.03125 0.46875 0.421875q0 0.21875 -0.171875 0.34375q-0.171875 0.109375 -0.5 0.078125l-0.359375 -0.015625q-1.0625 -0.09375 -1.578125 -0.640625q-0.5 -0.5625 -0.5 -1.703125l0 -3.34375l-0.890625 0q-0.234375 0 -0.359375 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.203125 0.125 -0.3125q0.125 -0.125 0.359375 -0.125l0.890625 0l0 -1.515625q0 -0.25 0.140625 -0.390625q0.15625 -0.140625 0.40625 -0.140625q0.234375 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 1.515625l1.484375 0q0.203125 0 0.328125 0.125q0.140625 0.109375 0.140625 0.3125q0 0.1875 -0.140625 0.296875q-0.125 0.109375 -0.328125 0.109375l-1.484375 0l0 3.40625q0 0.734375 0.296875 1.0625q0.296875 0.3125 0.90625 0.359375l0.359375 0.03125z" fill-rule="nonzero"/><path fill="#f4cccc" d="m206.61942 201.17455l140.47244 0l0 30.992126l-140.47244 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m206.61942 201.17455l140.47244 0l0 30.992126l-140.47244 0z" fill-rule="evenodd"/><path fill="#000000" d="m237.0857 213.5031q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm4.248535 1.71875q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm8.417801 3.875q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 
0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm8.199051 4.46875q0.203125 0 0.296875 0.109375q0.109375 0.09375 0.109375 0.265625q0 0.1875 -0.109375 0.296875q-0.09375 0.09375 -0.296875 0.09375l-4.203125 0q-0.203125 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.3125q0 -0.1875 0.140625 -0.359375l3.546875 -4.28125l-3.28125 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l4.0625 0q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.3125q0 0.1875 -0.140625 0.359375l-3.5625 4.28125l3.421875 0zm6.2547913 -0.59375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm3.3865662 5.875q-0.171875 0 -0.28125 -0.09375q-0.109375 -0.09375 -0.109375 -0.21875q0 -0.140625 0.109375 -0.234375q0.109375 -0.09375 0.28125 -0.09375l5.21875 0q0.171875 0 0.28125 0.09375q0.109375 0.09375 0.109375 0.234375q0 0.125 -0.109375 0.21875q-0.109375 0.09375 -0.28125 0.09375l-5.21875 0zm11.2500305 -6.609375q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 5.09375q0 1.296875 -0.671875 1.96875q-0.671875 0.671875 -1.984375 0.671875q-1.28125 0 -2.140625 -0.515625q-0.421875 -0.234375 -0.421875 -0.546875q0 -0.171875 0.078125 -0.28125q0.09375 -0.109375 0.234375 -0.109375q0.125 0 0.4375 0.171875q0.421875 0.21875 0.828125 0.34375q0.40625 0.140625 0.96875 0.140625q0.859375 0 1.28125 -0.453125q0.4375 -0.453125 0.4375 -1.3125l0 -1.03125q-0.25 0.5625 -0.78125 0.859375q-0.515625 0.296875 -1.21875 0.296875q-0.765625 0 -1.359375 -0.359375q-0.59375 -0.359375 -0.9375 -1.015625q-0.328125 -0.65625 -0.328125 -1.515625q0 -0.875 0.328125 -1.53125q0.34375 -0.65625 0.9375 -1.015625q0.59375 -0.359375 1.359375 -0.359375q0.6875 0 1.203125 0.296875q0.515625 0.296875 0.78125 0.84375l0 -0.640625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625zm-2.28125 4.984375q0.84375 0 1.3125 -0.546875q0.484375 -0.5625 0.484375 -1.546875q0 -0.984375 -0.46875 -1.53125q-0.46875 -0.5625 -1.328125 -0.5625q-0.84375 0 -1.34375 0.5625q-0.484375 0.546875 -0.484375 1.53125q0 0.984375 0.484375 1.546875q0.5 0.546875 1.34375 0.546875zm7.4695435 -4.984375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 
0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.3131714 -5.296875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.20282 -5.265625q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm4.331665 6.046875q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm5.2167664 -6.046875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 
0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.45282 -4.9375q0.140625 -0.296875 0.421875 -0.296875q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.109375 -0.046875 0.1875l-3.375 7.28125q-0.0625 0.125 -0.171875 0.1875q-0.109375 0.078125 -0.234375 0.078125q-0.1875 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.09375 0.046875 -0.1875l0.84375 -1.8125l-2.375 -5.140625q-0.046875 -0.078125 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.15625 -0.140625 0.359375 -0.140625q0.109375 0 0.21875 0.078125q0.125 0.078125 0.1875 0.203125l2.0 4.5l2.0 -4.46875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m132.49081 319.42978l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m132.49081 319.42978l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m163.01448 339.50836q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm8.160431 0.03125q-1.171875 0 -2.046875 -0.515625q-0.859375 -0.53125 -1.328125 -1.5q-0.46875 -0.984375 -0.46875 -2.296875q0 -1.34375 0.453125 -2.3125q0.46875 -0.984375 1.328125 -1.5q0.875 -0.53125 2.0625 -0.53125q1.1875 0 2.0625 0.53125q0.875 0.515625 1.328125 1.5q0.46875 0.96875 0.46875 2.296875q0 1.3125 -0.46875 2.296875q-0.46875 0.984375 -1.34375 1.515625q-0.859375 0.515625 -2.046875 0.515625zm0 -0.84375q1.34375 0 2.09375 -0.90625q0.75 -0.90625 0.75 -2.578125q0 -1.6875 -0.75 -2.578125q-0.734375 -0.90625 -2.09375 -0.90625q-1.34375 0 -2.09375 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.09375 0.90625zm9.214935 0.84375q-1.1875 0 -2.0625 -0.515625q-0.875 -0.53125 -1.359375 -1.5q-0.46875 -0.984375 -0.46875 -2.3125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.359375 -1.5q0.875 -0.53125 2.0625 -0.53125q0.8125 0 1.515625 0.265625q0.71875 0.25 1.25 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.21875 0.125q-0.15625 0 -0.359375 -0.140625q-0.609375 -0.46875 -1.109375 -0.65625q-0.5 -0.203125 -1.140625 -0.203125q-1.390625 0 -2.140625 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.140625 0.90625q0.640625 0 1.140625 -0.1875q0.5 -0.1875 1.109375 -0.671875q0.203125 -0.125 0.359375 -0.125q0.125 0 0.21875 0.125q0.09375 0.109375 0.09375 0.296875q0 0.234375 -0.1875 0.40625q-0.53125 0.484375 -1.25 0.75q-0.703125 0.25 -1.515625 0.25zm8.077179 0q-1.171875 0 -2.046875 -0.515625q-0.859375 -0.53125 -1.328125 -1.5q-0.46875 -0.984375 -0.46875 -2.296875q0 -1.34375 0.453125 -2.3125q0.46875 -0.984375 1.328125 -1.5q0.875 -0.53125 2.0625 
-0.53125q1.1875 0 2.0625 0.53125q0.875 0.515625 1.328125 1.5q0.46875 0.96875 0.46875 2.296875q0 1.3125 -0.46875 2.296875q-0.46875 0.984375 -1.34375 1.515625q-0.859375 0.515625 -2.046875 0.515625zm0 -0.84375q1.34375 0 2.09375 -0.90625q0.75 -0.90625 0.75 -2.578125q0 -1.6875 -0.75 -2.578125q-0.734375 -0.90625 -2.09375 -0.90625q-1.34375 0 -2.09375 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.09375 0.90625z" fill-rule="nonzero"/><path fill="#d9ead3" d="m284.12296 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m284.12296 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m314.7006 332.47687q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm5.113556 0q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.6840515 -0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -7.5625q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.171875l3.875 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-4.375 0zm6.3394165 0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm4.987152 6.515625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 
-2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#000000" d="m303.37402 346.47687q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.5434265 0q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125zm4.674652 -6.046875q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.0631714 -0.015625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm4.3300476 -5.28125q0.765625 0 1.34375 0.375q0.59375 0.359375 0.921875 1.046875q0.328125 0.6875 0.328125 1.59375q0 0.90625 -0.328125 1.59375q-0.328125 0.6875 -0.921875 1.078125q-0.578125 0.375 -1.34375 0.375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 0.640625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.359375 
0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.203125q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.59375q0.46875 -0.59375 0.46875 -1.65625q0 -1.046875 -0.46875 -1.625q-0.46875 -0.578125 -1.328125 -0.578125q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.687164 -5.25q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 5.078125q0 0.203125 -0.125 0.34375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -0.609375q-0.28125 0.53125 -0.78125 0.8125q-0.5 0.265625 -1.125 0.265625q-1.03125 0 -1.5625 -0.578125q-0.53125 -0.578125 -0.53125 -1.71875l0 -3.265625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 3.234375q0 0.78125 0.3125 1.15625q0.3125 0.359375 0.984375 0.359375q0.765625 0 1.234375 -0.5q0.46875 -0.5 0.46875 -1.3125l0 -2.9375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625zm4.8726807 -1.71875q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm3.9360352 0q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm5.873535 6.328125q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 
0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m413.02625 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m413.02625 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m443.6039 332.47687q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm5.113556 0q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.6840515 -0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -7.5625q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.171875l3.875 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-4.375 0zm6.3394165 0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm4.987152 6.515625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.908142 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 
-0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#000000" d="m429.9527 346.47687q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm5.237152 1.234375q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm6.56604 5.28125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 
0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm4.282898 -0.015625q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.14032 -5.25q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.5896606 4.53125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 
-1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m371.61902 334.89435l41.417297 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m371.61902 334.89435l37.990234 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m409.60925 334.89435l-1.1245728 1.1246033l3.0897522 -1.1246033l-3.0897522 -1.1245728z" fill-rule="evenodd"/><path fill="#c9daf8" d="m548.5407 277.52954l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 277.52954l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m587.0588 293.13934q0.1875 0 0.296875 0.109375q0.109375 0.109375 0.109375 0.296875l0 2.984375q0 0.296875 -0.09375 0.4375q-0.078125 0.140625 -0.328125 0.234375q-0.46875 0.203125 -1.15625 0.328125q-0.6875 0.109375 -1.375 0.109375q-1.25 0 -2.171875 -0.515625q-0.90625 -0.515625 -1.390625 -1.484375q-0.484375 -0.96875 -0.484375 -2.328125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.375 -1.5q0.90625 -0.53125 2.125 -0.53125q0.84375 0 1.5625 0.265625q0.71875 0.25 1.203125 0.734375q0.21875 0.203125 0.21875 0.421875q0 0.171875 -0.109375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.140625 0 -0.328125 -0.140625q-0.625 -0.484375 -1.140625 -0.671875q-0.5 -0.1875 -1.15625 -0.1875q-1.4375 0 -2.203125 0.90625q-0.75 0.890625 -0.75 2.578125q0 1.71875 0.765625 2.609375q0.78125 0.890625 2.28125 0.890625q1.109375 0 2.03125 -0.328125l0 -2.578125l-1.75 0q-0.203125 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.265625q0 -0.1875 0.125 -0.28125q0.125 -0.109375 0.328125 -0.109375l2.234375 0zm2.8911743 4.46875q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm7.7869263 4.375q-1.65625 0 -2.515625 -0.859375q-0.84375 -0.859375 -0.84375 -2.546875l0 -4.703125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.78125q0 1.25 0.609375 1.875q0.609375 0.609375 1.78125 0.609375q1.171875 0 1.765625 -0.609375q0.609375 -0.625 
0.609375 -1.875l0 -4.78125q0 -0.234375 0.140625 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.703125q0 1.671875 -0.859375 2.546875q-0.859375 0.859375 -2.5 0.859375z" fill-rule="nonzero"/><path fill="#c9daf8" d="m548.5407 319.3983l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 319.3983l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m584.63763 339.50812q-1.1875 0 -2.0625 -0.515625q-0.875 -0.53125 -1.359375 -1.5q-0.46875 -0.984375 -0.46875 -2.3125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.359375 -1.5q0.875 -0.53125 2.0625 -0.53125q0.8125 0 1.515625 0.265625q0.71875 0.25 1.25 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.21875 0.125q-0.15625 0 -0.359375 -0.140625q-0.609375 -0.46875 -1.109375 -0.65625q-0.5 -0.203125 -1.140625 -0.203125q-1.390625 0 -2.140625 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.140625 0.90625q0.640625 0 1.140625 -0.1875q0.5 -0.1875 1.109375 -0.671875q0.203125 -0.125 0.359375 -0.125q0.125 0 0.21875 0.125q0.09375 0.109375 0.09375 0.296875q0 0.234375 -0.1875 0.40625q-0.53125 0.484375 -1.25 0.75q-0.703125 0.25 -1.515625 0.25zm5.0302734 -0.03125q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm7.7869263 4.375q-1.65625 0 -2.515625 -0.859375q-0.84375 -0.859375 -0.84375 -2.546875l0 -4.703125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.78125q0 1.25 0.609375 1.875q0.609375 0.609375 1.78125 0.609375q1.171875 0 1.765625 -0.609375q0.609375 -0.625 0.609375 -1.875l0 -4.78125q0 -0.234375 0.140625 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.703125q0 1.671875 -0.859375 2.546875q-0.859375 0.859375 -2.5 0.859375z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m219.98688 334.92584l64.12598 -0.03149414" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m219.98688 334.92584l60.698914 -0.029815674" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m280.68576 334.89603l-1.1240234 1.1251526l3.0892334 -1.1260986l-3.090332 -1.1230774z" fill-rule="evenodd"/><path fill="#d9ead3" d="m413.02625 141.28871l20.53543 0l0 20.53543l-20.53543 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m413.02625 141.28871l20.53543 0l0 20.53543l-20.53543 0z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m437.52493 135.68242l73.763794 0l0 31.748032l-73.763794 0z" fill-rule="evenodd"/><path fill="#000000" d="m448.0718 156.20241q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 
-0.375q-0.578125 -0.390625 -0.90625 -1.078125q-0.328125 -0.6875 -0.328125 -1.59375q0 -0.90625 0.328125 -1.59375q0.328125 -0.6875 0.90625 -1.046875q0.59375 -0.375 1.359375 -0.375q0.6875 0 1.1875 0.296875q0.515625 0.296875 0.78125 0.84375l0 -3.203125q0 -0.21875 0.125 -0.34375q0.125 -0.125 0.359375 -0.125zm-2.25 7.796875q0.84375 0 1.296875 -0.578125q0.46875 -0.59375 0.46875 -1.65625q0 -1.0625 -0.46875 -1.640625q-0.453125 -0.578125 -1.296875 -0.578125q-0.859375 0 -1.34375 0.578125q-0.46875 0.578125 -0.46875 1.625q0 1.0625 0.46875 1.65625q0.484375 0.59375 1.34375 0.59375zm12.202805 -7.796875q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.359375l0 7.59375q0 0.21875 -0.125 0.359375q-0.109375 0.125 -0.328125 0.125q-0.21875 0 -0.328125 -0.125q-0.109375 -0.140625 -0.109375 -0.359375l0 -6.125l-2.59375 4.984375q-0.171875 0.34375 -0.5 0.34375q-0.3125 0 -0.484375 -0.34375l-2.625 -4.921875l0 6.0625q0 0.21875 -0.109375 0.359375q-0.109375 0.125 -0.328125 0.125q-0.21875 0 -0.34375 -0.125q-0.109375 -0.140625 -0.109375 -0.359375l0 -7.59375q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.3125 0 0.484375 0.34375l3.046875 5.84375l3.015625 -5.84375q0.09375 -0.1875 0.203125 -0.265625q0.125 -0.078125 0.28125 -0.078125zm4.8576965 8.59375q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm8.925674 -7.796875q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.328125l0 7.625q0 0.21875 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.640625q-0.265625 0.546875 -0.78125 0.84375q-0.5 0.296875 -1.1875 0.296875q-0.765625 0 -1.359375 -0.375q-0.578125 -0.390625 -0.90625 -1.078125q-0.328125 -0.6875 -0.328125 -1.59375q0 -0.90625 0.328125 -1.59375q0.328125 -0.6875 0.90625 -1.046875q0.59375 -0.375 1.359375 -0.375q0.6875 0 1.1875 0.296875q0.515625 0.296875 0.78125 0.84375l0 -3.203125q0 -0.21875 0.125 -0.34375q0.125 -0.125 0.359375 -0.125zm-2.25 7.796875q0.84375 0 1.296875 -0.578125q0.46875 -0.59375 0.46875 -1.65625q0 -1.0625 -0.46875 -1.640625q-0.453125 -0.578125 -1.296875 -0.578125q-0.859375 0 -1.34375 0.578125q-0.46875 0.578125 -0.46875 1.625q0 1.0625 0.46875 1.65625q0.484375 0.59375 1.34375 0.59375zm9.06218 -0.640625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm4.386551 
5.296875q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m176.23885 99.34974l0 153.19684" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m176.23885 99.34974l0 149.76978" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m176.23885 249.1195l-1.124588 -1.124588l1.124588 3.0897675l1.124588 -3.0897675z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m176.23975 283.52823l0 17.950958l0.06298828 0l0 17.954529" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m176.23975 283.52823l0 17.950928l0.06298828 0l0 14.527496" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m176.30273 316.00665l-1.1245728 -1.1246033l1.1245728 3.0897827l1.124588 -3.0897827z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m75.62205 99.34843l0 153.19684" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m75.62205 99.34843l0 149.76978" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m75.62205 249.1182l-1.1245804 -1.124588l1.1245804 3.0897675l1.1245804 -3.0897675z" fill-rule="evenodd"/></g></svg> \ No newline at end of file
+<svg version="1.1" viewBox="0.0 0.0 720.0 540.0" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg"><clipPath id="p.0"><path d="m0 0l720.0 0l0 540.0l-720.0 0l0 -540.0z" clip-rule="nonzero"/></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l720.0 0l0 540.0l-720.0 0z" fill-rule="evenodd"/><path fill="#f3f3f3" d="m19.375328 28.750656l361.6378 0l0 358.01575l-361.6378 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m19.375328 28.750656l361.6378 0l0 358.01575l-361.6378 0z" fill-rule="evenodd"/><path fill="#434343" d="m338.49512 374.66016q-0.609375 0 -1.171875 -0.140625q-0.546875 -0.15625 -0.96875 -0.421875q-0.25 -0.15625 -0.359375 -0.296875q-0.09375 -0.140625 -0.09375 -0.34375q0 -0.171875 0.09375 -0.28125q0.109375 -0.109375 0.265625 -0.109375q0.171875 0 0.46875 0.1875q0.40625 0.25 0.796875 0.390625q0.390625 0.140625 0.984375 0.140625q0.71875 0 1.109375 -0.25q0.40625 -0.265625 0.40625 -0.734375q0 -0.296875 -0.15625 -0.46875q-0.140625 -0.1875 -0.5 -0.328125q-0.359375 -0.140625 -1.046875 -0.296875q-1.171875 -0.25 -1.6875 -0.671875q-0.5 -0.421875 -0.5 -1.15625q0 -0.578125 0.3125 -1.015625q0.328125 -0.4375 0.890625 -0.6875q0.5625 -0.265625 1.28125 -0.265625q0.53125 0 1.015625 0.140625q0.484375 0.140625 0.859375 0.390625q0.453125 0.328125 0.453125 0.671875q0 0.171875 -0.109375 0.296875q-0.109375 0.125 -0.25 0.125q-0.15625 0 -0.484375 -0.234375q-0.375 -0.234375 -0.703125 -0.359375q-0.328125 -0.140625 -0.828125 -0.140625q-0.625 0 -1.015625 0.28125q-0.375 0.265625 -0.375 0.734375q0 0.296875 0.140625 0.484375q0.140625 0.171875 0.46875 0.3125q0.328125 0.140625 0.9375 0.28125q0.90625 0.1875 1.40625 0.4375q0.5 0.234375 0.703125 0.578125q0.21875 0.34375 0.21875 0.890625q0 0.828125 -0.703125 1.34375q-0.703125 0.515625 -1.859375 0.515625zm9.241241 -1.59375q0.140625 0 0.25 0.125q0.109375 0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 -0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.5551147 -0.8125q0.546875 -0.03125 0.546875 0.453125q0 0.21875 -0.125 0.34375q-0.109375 0.125 -0.40625 0.15625l-0.390625 0.03125q-0.890625 0.078125 -1.328125 0.640625q-0.4375 0.546875 -0.4375 1.296875l0 3.234375q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.21875 0 0.359375 0.140625q0.140625 0.140625 0.140625 0.375l0 0.75q0.28125 -0.578125 0.796875 -0.890625q0.515625 -0.3125 1.1875 -0.359375l0.1875 -0.015625zm6.157959 0.328125q0.15625 -0.3125 0.46875 -0.3125q0.203125 0 0.359375 0.140625q0.15625 0.125 0.15625 0.328125q0 0.109375 -0.046875 0.203125l-2.59375 5.609375q-0.078125 0.171875 -0.25 0.28125q-0.15625 0.09375 -0.34375 0.09375q-0.171875 0 -0.328125 -0.09375q-0.15625 -0.109375 
-0.25 -0.28125l-2.59375 -5.609375q-0.046875 -0.09375 -0.046875 -0.1875q0 -0.203125 0.171875 -0.34375q0.1875 -0.15625 0.390625 -0.15625q0.140625 0 0.265625 0.078125q0.125 0.078125 0.1875 0.234375l2.234375 5.0l2.21875 -4.984375zm7.2099915 4.796875q0.140625 0 0.25 0.125q0.109375 0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 -0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.5551453 -0.8125q0.546875 -0.03125 0.546875 0.453125q0 0.21875 -0.125 0.34375q-0.109375 0.125 -0.40625 0.15625l-0.390625 0.03125q-0.890625 0.078125 -1.328125 0.640625q-0.4375 0.546875 -0.4375 1.296875l0 3.234375q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.21875 0 0.359375 0.140625q0.140625 0.140625 0.140625 0.375l0 0.75q0.28125 -0.578125 0.796875 -0.890625q0.515625 -0.3125 1.1875 -0.359375l0.1875 -0.015625z" fill-rule="nonzero"/><path fill="#d9d9d9" d="m25.624672 36.249344l301.88977 0l0 69.98425l-301.88977 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" stroke-dasharray="4.0,3.0" d="m25.624672 36.249344l301.88977 0l0 69.98425l-301.88977 0z" fill-rule="evenodd"/><path fill="#434343" d="m134.36497 56.831844q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm9.004181 -1.421875q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.839676 -0.75q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 
0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm5.84729 6.0625q-0.56248474 0 -1.0624847 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.87498474 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0624847 -0.234375 -1.5156097 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.1562347 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.56248474 0 -0.90623474 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84373474 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125zm6.2131653 0q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm7.1288147 -5.25q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm1.970398 6.03125q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.5434265 0q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125zm4.721527 0.015625q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 
-1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm12.222534 -4.9375q0.125 -0.28125 0.390625 -0.28125q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.078125 -0.03125 0.171875l-1.984375 5.046875q-0.078125 0.15625 -0.21875 0.25q-0.140625 0.078125 -0.296875 0.078125q-0.15625 0 -0.296875 -0.078125q-0.140625 -0.09375 -0.21875 -0.25l-1.65625 -4.21875l-1.640625 4.21875q-0.0625 0.15625 -0.203125 0.25q-0.140625 0.078125 -0.3125 0.078125q-0.15625 0 -0.296875 -0.078125q-0.140625 -0.09375 -0.21875 -0.25l-1.984375 -5.03125q-0.046875 -0.09375 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.171875 -0.140625 0.359375 -0.140625q0.296875 0 0.40625 0.296875l1.65625 4.421875l1.6875 -4.390625q0.078125 -0.15625 0.203125 -0.234375q0.125 -0.09375 0.265625 -0.09375q0.15625 0 0.28125 0.09375q0.125 0.078125 0.1875 0.234375l1.6875 4.375l1.65625 -4.40625zm12.637604 5.09375q0.046875 0.09375 0.046875 0.203125q0 0.171875 -0.140625 0.296875q-0.140625 0.125 -0.328125 0.125q-0.296875 0 -0.421875 -0.296875l-0.84375 -1.9375l-4.53125 0l-0.859375 1.9375q-0.125 0.296875 -0.421875 0.296875q-0.1875 0 -0.34375 -0.125q-0.140625 -0.125 -0.140625 -0.3125q0 -0.09375 0.046875 -0.1875l3.4375 -7.640625q0.078125 -0.15625 0.21875 -0.234375q0.140625 -0.09375 0.3125 -0.09375q0.171875 0 0.3125 0.09375q0.15625 0.078125 0.21875 0.234375l3.4375 7.640625zm-5.859375 -2.421875l3.8125 0l-1.90625 -4.3125l-1.90625 4.3125zm7.78656 3.046875q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm4.9744263 4.34375q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm4.4157715 0.015625q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 
0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125z" fill-rule="nonzero"/><path fill="#f3f3f3" d="m396.75067 183.75066l249.00787 0l0 203.02364l-249.00787 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m396.75067 183.75066l249.00787 0l0 203.02364l-249.00787 0z" fill-rule="evenodd"/><path fill="#434343" d="m409.42255 374.66803q-0.90625 0 -1.609375 -0.40625q-0.6875 -0.421875 -1.078125 -1.171875q-0.375 -0.765625 -0.375 -1.765625q0 -1.0 0.390625 -1.765625q0.40625 -0.78125 1.109375 -1.203125q0.703125 -0.4375 1.625 -0.4375q0.5 0 1.0 0.140625q0.5 0.140625 0.875 0.40625q0.234375 0.171875 0.328125 0.328125q0.109375 0.140625 0.109375 0.328125q0 0.1875 -0.109375 0.3125q-0.09375 0.109375 -0.25 0.109375q-0.09375 0 -0.203125 -0.046875q-0.09375 -0.046875 -0.171875 -0.09375q-0.078125 -0.0625 -0.09375 -0.078125q-0.359375 -0.234375 -0.671875 -0.359375q-0.3125 -0.140625 -0.765625 -0.140625q-0.96875 0 -1.515625 0.671875q-0.53125 0.65625 -0.53125 1.828125q0 1.171875 0.53125 1.8125q0.546875 0.640625 1.515625 0.640625q0.453125 0 0.78125 -0.125q0.328125 -0.140625 0.65625 -0.375q0.15625 -0.09375 0.28125 -0.15625q0.140625 -0.0625 0.234375 -0.0625q0.140625 0 0.234375 0.125q0.109375 0.109375 0.109375 0.296875q0 0.171875 -0.09375 0.3125q-0.09375 0.140625 -0.34375 0.3125q-0.375 0.25 -0.90625 0.40625q-0.515625 0.15625 -1.0625 0.15625zm4.2591553 -0.03125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -8.46875q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.21875 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 8.46875q0 0.25 -0.15625 0.390625q-0.15625 0.140625 -0.375 0.140625zm3.092102 0q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.234375 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 5.625q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125zm0 -8.09375q-0.3125 0 -0.515625 -0.171875q-0.203125 -0.1875 -0.203125 -0.5q0 -0.296875 0.203125 -0.484375q0.203125 -0.1875 0.515625 -0.1875q0.328125 0 0.515625 0.1875q0.203125 0.1875 0.203125 0.484375q0 0.3125 -0.203125 0.5q-0.1875 0.171875 -0.515625 0.171875zm7.5765076 6.53125q0.140625 0 0.25 0.125q0.109375 0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 -0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.6020203 -0.84375q2.328125 0 2.328125 2.578125l0 3.609375q0 0.25 -0.140625 0.390625q-0.140625 0.140625 -0.390625 0.140625q-0.25 0 -0.40625 -0.140625q-0.140625 -0.140625 -0.140625 -0.390625l0 -3.546875q0 -0.90625 -0.359375 -1.3125q-0.34375 -0.421875 -1.125 -0.421875q-0.890625 0 -1.421875 0.546875q-0.53125 0.546875 -0.53125 1.484375l0 3.25q0 0.25 -0.140625 0.390625q-0.140625 
0.140625 -0.390625 0.140625q-0.25 0 -0.40625 -0.140625q-0.140625 -0.140625 -0.140625 -0.390625l0 -5.625q0 -0.234375 0.140625 -0.375q0.15625 -0.15625 0.40625 -0.15625q0.234375 0 0.375 0.15625q0.140625 0.140625 0.140625 0.359375l0 0.6875q0.328125 -0.609375 0.890625 -0.921875q0.578125 -0.3125 1.3125 -0.3125zm7.304718 5.875q0.46875 0.03125 0.46875 0.421875q0 0.21875 -0.171875 0.34375q-0.171875 0.109375 -0.5 0.078125l-0.359375 -0.015625q-1.0625 -0.09375 -1.578125 -0.640625q-0.5 -0.5625 -0.5 -1.703125l0 -3.34375l-0.890625 0q-0.234375 0 -0.359375 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.203125 0.125 -0.3125q0.125 -0.125 0.359375 -0.125l0.890625 0l0 -1.515625q0 -0.25 0.140625 -0.390625q0.15625 -0.140625 0.40625 -0.140625q0.234375 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 1.515625l1.484375 0q0.203125 0 0.328125 0.125q0.140625 0.109375 0.140625 0.3125q0 0.1875 -0.140625 0.296875q-0.125 0.109375 -0.328125 0.109375l-1.484375 0l0 3.40625q0 0.734375 0.296875 1.0625q0.296875 0.3125 0.90625 0.359375l0.359375 0.03125z" fill-rule="nonzero"/><path fill="#f4cccc" d="m206.61942 201.17455l140.47244 0l0 30.992126l-140.47244 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m206.61942 201.17455l140.47244 0l0 30.992126l-140.47244 0z" fill-rule="evenodd"/><path fill="#000000" d="m237.0857 213.5031q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm4.248535 1.71875q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm8.417801 3.875q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 
0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm8.199051 4.46875q0.203125 0 0.296875 0.109375q0.109375 0.09375 0.109375 0.265625q0 0.1875 -0.109375 0.296875q-0.09375 0.09375 -0.296875 0.09375l-4.203125 0q-0.203125 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.3125q0 -0.1875 0.140625 -0.359375l3.546875 -4.28125l-3.28125 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l4.0625 0q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.3125q0 0.1875 -0.140625 0.359375l-3.5625 4.28125l3.421875 0zm6.2547913 -0.59375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm3.3865662 5.875q-0.171875 0 -0.28125 -0.09375q-0.109375 -0.09375 -0.109375 -0.21875q0 -0.140625 0.109375 -0.234375q0.109375 -0.09375 0.28125 -0.09375l5.21875 0q0.171875 0 0.28125 0.09375q0.109375 0.09375 0.109375 0.234375q0 0.125 -0.109375 0.21875q-0.109375 0.09375 -0.28125 0.09375l-5.21875 0zm11.2500305 -6.609375q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 5.09375q0 1.296875 -0.671875 1.96875q-0.671875 0.671875 -1.984375 0.671875q-1.28125 0 -2.140625 -0.515625q-0.421875 -0.234375 -0.421875 -0.546875q0 -0.171875 0.078125 -0.28125q0.09375 -0.109375 0.234375 -0.109375q0.125 0 0.4375 0.171875q0.421875 0.21875 0.828125 0.34375q0.40625 0.140625 0.96875 0.140625q0.859375 0 1.28125 -0.453125q0.4375 -0.453125 0.4375 -1.3125l0 -1.03125q-0.25 0.5625 -0.78125 0.859375q-0.515625 0.296875 -1.21875 0.296875q-0.765625 0 -1.359375 -0.359375q-0.59375 -0.359375 -0.9375 -1.015625q-0.328125 -0.65625 -0.328125 -1.515625q0 -0.875 0.328125 -1.53125q0.34375 -0.65625 0.9375 -1.015625q0.59375 -0.359375 1.359375 -0.359375q0.6875 0 1.203125 0.296875q0.515625 0.296875 0.78125 0.84375l0 -0.640625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625zm-2.28125 4.984375q0.84375 0 1.3125 -0.546875q0.484375 -0.5625 0.484375 -1.546875q0 -0.984375 -0.46875 -1.53125q-0.46875 -0.5625 -1.328125 -0.5625q-0.84375 0 -1.34375 0.5625q-0.484375 0.546875 -0.484375 1.53125q0 0.984375 0.484375 1.546875q0.5 0.546875 1.34375 0.546875zm7.4695435 -4.984375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 
0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.3131714 -5.296875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.20282 -5.265625q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm4.331665 6.046875q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm5.2167664 -6.046875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 
0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.45282 -4.9375q0.140625 -0.296875 0.421875 -0.296875q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.109375 -0.046875 0.1875l-3.375 7.28125q-0.0625 0.125 -0.171875 0.1875q-0.109375 0.078125 -0.234375 0.078125q-0.1875 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.09375 0.046875 -0.1875l0.84375 -1.8125l-2.375 -5.140625q-0.046875 -0.078125 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.15625 -0.140625 0.359375 -0.140625q0.109375 0 0.21875 0.078125q0.125 0.078125 0.1875 0.203125l2.0 4.5l2.0 -4.46875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m132.49081 319.42978l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m132.49081 319.42978l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m163.01448 339.50836q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm8.160431 0.03125q-1.171875 0 -2.046875 -0.515625q-0.859375 -0.53125 -1.328125 -1.5q-0.46875 -0.984375 -0.46875 -2.296875q0 -1.34375 0.453125 -2.3125q0.46875 -0.984375 1.328125 -1.5q0.875 -0.53125 2.0625 -0.53125q1.1875 0 2.0625 0.53125q0.875 0.515625 1.328125 1.5q0.46875 0.96875 0.46875 2.296875q0 1.3125 -0.46875 2.296875q-0.46875 0.984375 -1.34375 1.515625q-0.859375 0.515625 -2.046875 0.515625zm0 -0.84375q1.34375 0 2.09375 -0.90625q0.75 -0.90625 0.75 -2.578125q0 -1.6875 -0.75 -2.578125q-0.734375 -0.90625 -2.09375 -0.90625q-1.34375 0 -2.09375 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.09375 0.90625zm9.214935 0.84375q-1.1875 0 -2.0625 -0.515625q-0.875 -0.53125 -1.359375 -1.5q-0.46875 -0.984375 -0.46875 -2.3125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.359375 -1.5q0.875 -0.53125 2.0625 -0.53125q0.8125 0 1.515625 0.265625q0.71875 0.25 1.25 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.21875 0.125q-0.15625 0 -0.359375 -0.140625q-0.609375 -0.46875 -1.109375 -0.65625q-0.5 -0.203125 -1.140625 -0.203125q-1.390625 0 -2.140625 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.140625 0.90625q0.640625 0 1.140625 -0.1875q0.5 -0.1875 1.109375 -0.671875q0.203125 -0.125 0.359375 -0.125q0.125 0 0.21875 0.125q0.09375 0.109375 0.09375 0.296875q0 0.234375 -0.1875 0.40625q-0.53125 0.484375 -1.25 0.75q-0.703125 0.25 -1.515625 0.25zm8.077179 0q-1.171875 0 -2.046875 -0.515625q-0.859375 -0.53125 -1.328125 -1.5q-0.46875 -0.984375 -0.46875 -2.296875q0 -1.34375 0.453125 -2.3125q0.46875 -0.984375 1.328125 -1.5q0.875 -0.53125 2.0625 
-0.53125q1.1875 0 2.0625 0.53125q0.875 0.515625 1.328125 1.5q0.46875 0.96875 0.46875 2.296875q0 1.3125 -0.46875 2.296875q-0.46875 0.984375 -1.34375 1.515625q-0.859375 0.515625 -2.046875 0.515625zm0 -0.84375q1.34375 0 2.09375 -0.90625q0.75 -0.90625 0.75 -2.578125q0 -1.6875 -0.75 -2.578125q-0.734375 -0.90625 -2.09375 -0.90625q-1.34375 0 -2.09375 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.09375 0.90625z" fill-rule="nonzero"/><path fill="#d9ead3" d="m284.12296 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m284.12296 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m314.7006 332.47687q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm5.113556 0q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.6840515 -0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -7.5625q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.171875l3.875 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-4.375 0zm6.3394165 0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm4.987152 6.515625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 
-2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#000000" d="m303.37402 346.47687q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.5434265 0q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125zm4.674652 -6.046875q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.0631714 -0.015625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm4.3300476 -5.28125q0.765625 0 1.34375 0.375q0.59375 0.359375 0.921875 1.046875q0.328125 0.6875 0.328125 1.59375q0 0.90625 -0.328125 1.59375q-0.328125 0.6875 -0.921875 1.078125q-0.578125 0.375 -1.34375 0.375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 0.640625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.359375 
0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.203125q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.59375q0.46875 -0.59375 0.46875 -1.65625q0 -1.046875 -0.46875 -1.625q-0.46875 -0.578125 -1.328125 -0.578125q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.687164 -5.25q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 5.078125q0 0.203125 -0.125 0.34375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -0.609375q-0.28125 0.53125 -0.78125 0.8125q-0.5 0.265625 -1.125 0.265625q-1.03125 0 -1.5625 -0.578125q-0.53125 -0.578125 -0.53125 -1.71875l0 -3.265625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 3.234375q0 0.78125 0.3125 1.15625q0.3125 0.359375 0.984375 0.359375q0.765625 0 1.234375 -0.5q0.46875 -0.5 0.46875 -1.3125l0 -2.9375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625zm4.8726807 -1.71875q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm3.9360352 0q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm5.873535 6.328125q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 
0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m413.02625 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m413.02625 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m443.6039 332.47687q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm5.113556 0q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.6840515 -0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -7.5625q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.171875l3.875 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-4.375 0zm6.3394165 0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm4.987152 6.515625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.908142 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 
-0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#000000" d="m429.9527 346.47687q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm5.237152 1.234375q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm6.56604 5.28125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 
0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm4.282898 -0.015625q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.14032 -5.25q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.5896606 4.53125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 
-1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m371.61902 334.89435l41.417297 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m371.61902 334.89435l37.990234 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m409.60925 334.89435l-1.1245728 1.1246033l3.0897522 -1.1246033l-3.0897522 -1.1245728z" fill-rule="evenodd"/><path fill="#c9daf8" d="m548.5407 277.52954l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 277.52954l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m587.0588 293.13934q0.1875 0 0.296875 0.109375q0.109375 0.109375 0.109375 0.296875l0 2.984375q0 0.296875 -0.09375 0.4375q-0.078125 0.140625 -0.328125 0.234375q-0.46875 0.203125 -1.15625 0.328125q-0.6875 0.109375 -1.375 0.109375q-1.25 0 -2.171875 -0.515625q-0.90625 -0.515625 -1.390625 -1.484375q-0.484375 -0.96875 -0.484375 -2.328125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.375 -1.5q0.90625 -0.53125 2.125 -0.53125q0.84375 0 1.5625 0.265625q0.71875 0.25 1.203125 0.734375q0.21875 0.203125 0.21875 0.421875q0 0.171875 -0.109375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.140625 0 -0.328125 -0.140625q-0.625 -0.484375 -1.140625 -0.671875q-0.5 -0.1875 -1.15625 -0.1875q-1.4375 0 -2.203125 0.90625q-0.75 0.890625 -0.75 2.578125q0 1.71875 0.765625 2.609375q0.78125 0.890625 2.28125 0.890625q1.109375 0 2.03125 -0.328125l0 -2.578125l-1.75 0q-0.203125 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.265625q0 -0.1875 0.125 -0.28125q0.125 -0.109375 0.328125 -0.109375l2.234375 0zm2.8911743 4.46875q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm7.7869263 4.375q-1.65625 0 -2.515625 -0.859375q-0.84375 -0.859375 -0.84375 -2.546875l0 -4.703125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.78125q0 1.25 0.609375 1.875q0.609375 0.609375 1.78125 0.609375q1.171875 0 1.765625 -0.609375q0.609375 -0.625 
0.609375 -1.875l0 -4.78125q0 -0.234375 0.140625 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.703125q0 1.671875 -0.859375 2.546875q-0.859375 0.859375 -2.5 0.859375z" fill-rule="nonzero"/><path fill="#c9daf8" d="m548.5407 319.3983l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 319.3983l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m584.63763 339.50812q-1.1875 0 -2.0625 -0.515625q-0.875 -0.53125 -1.359375 -1.5q-0.46875 -0.984375 -0.46875 -2.3125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.359375 -1.5q0.875 -0.53125 2.0625 -0.53125q0.8125 0 1.515625 0.265625q0.71875 0.25 1.25 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.21875 0.125q-0.15625 0 -0.359375 -0.140625q-0.609375 -0.46875 -1.109375 -0.65625q-0.5 -0.203125 -1.140625 -0.203125q-1.390625 0 -2.140625 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.140625 0.90625q0.640625 0 1.140625 -0.1875q0.5 -0.1875 1.109375 -0.671875q0.203125 -0.125 0.359375 -0.125q0.125 0 0.21875 0.125q0.09375 0.109375 0.09375 0.296875q0 0.234375 -0.1875 0.40625q-0.53125 0.484375 -1.25 0.75q-0.703125 0.25 -1.515625 0.25zm5.0302734 -0.03125q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm7.7869263 4.375q-1.65625 0 -2.515625 -0.859375q-0.84375 -0.859375 -0.84375 -2.546875l0 -4.703125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.78125q0 1.25 0.609375 1.875q0.609375 0.609375 1.78125 0.609375q1.171875 0 1.765625 -0.609375q0.609375 -0.625 0.609375 -1.875l0 -4.78125q0 -0.234375 0.140625 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.703125q0 1.671875 -0.859375 2.546875q-0.859375 0.859375 -2.5 0.859375z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m219.98688 334.92584l64.12598 -0.03149414" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m219.98688 334.92584l60.698914 -0.029815674" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m280.68576 334.89603l-1.1240234 1.1251526l3.0892334 -1.1260986l-3.090332 -1.1230774z" fill-rule="evenodd"/><path fill="#d9ead3" d="m413.02625 141.28871l20.53543 0l0 20.53543l-20.53543 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m413.02625 141.28871l20.53543 0l0 20.53543l-20.53543 0z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m437.52493 135.68242l73.763794 0l0 31.748032l-73.763794 0z" fill-rule="evenodd"/><path fill="#000000" d="m448.0718 156.20241q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 
1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm8.3211975 -5.140625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.0631714 -0.015625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.767517 -5.28125q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm10.15921 0.75q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm8.691681 -5.71875q0.140625 -0.296875 0.421875 -0.296875q0.1875 0 
0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.109375 -0.046875 0.1875l-3.375 7.28125q-0.0625 0.125 -0.171875 0.1875q-0.109375 0.078125 -0.234375 0.078125q-0.1875 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.09375 0.046875 -0.1875l0.84375 -1.8125l-2.375 -5.140625q-0.046875 -0.078125 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.15625 -0.140625 0.359375 -0.140625q0.109375 0 0.21875 0.078125q0.125 0.078125 0.1875 0.203125l2.0 4.5l2.0 -4.46875zm4.902405 -0.328125q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.76532 -0.640625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#f4cccc" d="m519.9029 141.28871l20.5354 0l0 20.53543l-20.5354 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m519.9029 141.28871l20.5354 0l0 20.53543l-20.5354 0z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m544.40155 135.68242l100.0 0l0 31.748032l-100.0 0z" fill-rule="evenodd"/><path fill="#000000" d="m554.9328 156.26491q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm5.3845215 -6.046875q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm6.456726 -1.703125q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 
1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm4.248535 1.71875q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm6.3444214 0.765625q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125zm6.47876 -0.78125q0.421875 0.03125 0.421875 0.375q0 0.203125 
-0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm4.283142 -5.265625q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.782898 0q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 5.078125q0 0.203125 -0.125 0.34375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -0.609375q-0.28125 0.53125 -0.78125 0.8125q-0.5 0.265625 -1.125 0.265625q-1.03125 0 -1.5625 -0.578125q-0.53125 -0.578125 -0.53125 -1.71875l0 -3.265625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 3.234375q0 0.78125 0.3125 1.15625q0.3125 0.359375 0.984375 0.359375q0.765625 0 1.234375 -0.5q0.46875 -0.5 0.46875 -1.3125l0 -2.9375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625zm4.7008057 6.046875q-0.8125 0 -1.453125 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.6875 -0.34375 -1.578125q0 -0.90625 0.359375 -1.59375q0.359375 -0.703125 0.984375 -1.078125q0.640625 -0.390625 1.46875 -0.390625q0.453125 0 0.90625 0.125q0.453125 0.125 0.78125 0.359375q0.21875 0.140625 0.3125 0.28125q0.09375 0.140625 0.09375 0.3125q0 0.171875 -0.09375 0.28125q-0.09375 0.09375 -0.234375 0.09375q-0.078125 0 -0.1875 -0.046875q-0.09375 -0.046875 -0.15625 -0.09375q-0.0625 -0.046875 -0.09375 -0.0625q-0.3125 -0.203125 -0.59375 -0.3125q-0.28125 -0.125 -0.6875 -0.125q-0.875 0 -1.359375 0.59375q-0.484375 0.59375 -0.484375 1.65625q0 1.046875 0.484375 1.625q0.484375 0.578125 1.359375 0.578125q0.40625 0 0.703125 -0.109375q0.296875 -0.125 0.59375 -0.328125q0.140625 -0.09375 0.25 -0.15625q0.125 -0.0625 0.203125 -0.0625q0.140625 0 0.21875 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.15625 -0.09375 0.28125q-0.078125 0.125 -0.296875 0.28125q-0.34375 0.234375 -0.8125 0.375q-0.46875 0.125 -0.953125 0.125zm6.029297 -0.78125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 
0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.830017 -5.265625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 5.078125q0 0.203125 -0.125 0.34375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -0.609375q-0.28125 0.53125 -0.78125 0.8125q-0.5 0.265625 -1.125 0.265625q-1.03125 0 -1.5625 -0.578125q-0.53125 -0.578125 -0.53125 -1.71875l0 -3.265625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 3.234375q0 0.78125 0.3125 1.15625q0.3125 0.359375 0.984375 0.359375q0.765625 0 1.234375 -0.5q0.46875 -0.5 0.46875 -1.3125l0 -2.9375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625zm5.1851807 0q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#d9ead3" d="m31.874912 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m31.874912 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m67.27695 264.03653q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.359375l0 7.578125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.234375 0 -0.375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -3.4375l-5.062496 0l0 3.4375q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.234375 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 3.296875l5.062496 0l0 -3.296875q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.375 -0.140625zm3.0648193 8.515625q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm6.5711823 0.90625q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 
0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm9.0746765 -5.359375q0.8125 0 1.40625 0.34375q0.609375 0.328125 0.9375 0.9375q0.328125 0.59375 0.328125 1.390625q0 0.78125 -0.359375 1.40625q-0.359375 0.625 -1.0 0.96875q-0.640625 0.328125 -1.484375 0.328125q-0.734375 0 -1.453125 -0.25q-0.703125 -0.265625 -1.1875 -0.734375q-0.203125 -0.171875 -0.203125 -0.40625q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.234375 -0.125q0.171875 0 0.34375 0.140625q0.515625 0.4375 1.046875 0.640625q0.53125 0.203125 1.109375 0.203125q0.890625 0 1.390625 -0.5q0.5 -0.5 0.5 -1.359375q0 -0.84375 -0.5 -1.359375q-0.5 -0.515625 -1.359375 -0.515625q-1.09375 0 -1.78125 0.84375q-0.15625 0.171875 -0.40625 0.171875q-0.15625 0 -0.28125 -0.09375q-0.109375 -0.109375 -0.109375 -0.296875l0 -4.125q0 -0.21875 0.125 -0.34375q0.125 -0.125 0.359375 -0.125l4.21875 0q0.21875 0 0.34375 0.109375q0.125 0.09375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.125 0.109375 -0.34375 0.109375l-3.734375 0l0 3.015625q0.34375 -0.328125 0.78125 -0.5q0.453125 -0.171875 0.984375 -0.171875z" fill-rule="nonzero"/><path fill="#d9ead3" d="m190.14 134.76706l87.49608 0l0 30.992126l-87.49608 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m190.14 134.76706l87.49608 0l0 30.992126l-87.49608 0z" fill-rule="evenodd"/><path fill="#000000" d="m215.10997 150.37688q0.1875 0 0.296875 0.109375q0.109375 0.109375 0.109375 0.296875l0 2.984375q0 0.296875 -0.09375 0.4375q-0.078125 0.140625 -0.328125 0.234375q-0.46875 0.203125 -1.15625 0.328125q-0.6875 0.109375 -1.375 0.109375q-1.25 0 -2.171875 -0.515625q-0.90625 -0.515625 -1.390625 -1.484375q-0.484375 -0.96875 -0.484375 -2.328125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.375 -1.5q0.90625 -0.53125 2.125 -0.53125q0.84375 0 1.5625 0.265625q0.71875 0.25 1.203125 0.734375q0.21875 0.203125 0.21875 0.421875q0 0.171875 -0.109375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.140625 0 -0.328125 -0.140625q-0.625 -0.484375 -1.140625 -0.671875q-0.5 -0.1875 -1.15625 -0.1875q-1.4375 0 -2.203125 0.90625q-0.75 0.890625 -0.75 2.578125q0 1.71875 0.765625 2.609375q0.78125 0.890625 2.28125 0.890625q1.109375 0 2.03125 -0.328125l0 -2.578125l-1.75 0q-0.203125 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.265625q0 -0.1875 0.125 -0.28125q0.125 -0.109375 0.328125 -0.109375l2.234375 0zm5.1568146 -1.5625q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 
-0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.3131714 -5.296875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.2028046 -5.265625q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm4.5035553 5.984375q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm10.461807 -0.515625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.480301 -2.453125q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 
0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125z" fill-rule="nonzero"/><path fill="#d9ead3" d="m233.1085 252.53609l87.49608 0l0 30.992142l-87.49608 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m233.1085 252.53609l87.49608 0l0 30.992142l-87.49608 0z" fill-rule="evenodd"/><path fill="#000000" d="m260.00964 265.61465q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm8.9496765 -6.03125q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.767273 6.046875q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm8.535065 -0.046875q0.203125 0 0.296875 0.109375q0.109375 0.09375 0.109375 0.265625q0 0.1875 -0.109375 0.296875q-0.09375 0.09375 -0.296875 0.09375l-4.203125 0q-0.203125 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.3125q0 -0.1875 0.140625 -0.359375l3.546875 -4.28125l-3.28125 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l4.0625 0q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.3125q0 0.1875 -0.140625 0.359375l-3.5625 4.28125l3.421875 0zm6.2547913 -0.59375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 
0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.8396606 -0.75q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125z" fill-rule="nonzero"/><path fill="#000000" d="m258.07846 275.1459q0.1875 0 0.296875 0.109375q0.109375 0.109375 0.109375 0.296875l0 2.984375q0 0.296875 -0.09375 0.4375q-0.078125 0.140625 -0.328125 0.234375q-0.46875 0.203125 -1.15625 0.328125q-0.6875 0.109375 -1.3749847 0.109375q-1.25 0 -2.171875 -0.515625q-0.90625 -0.515625 -1.390625 -1.484375q-0.484375 -0.96875 -0.484375 -2.328125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.375 -1.5q0.90625 -0.53125 2.125 -0.53125q0.84373474 0 1.5624847 0.265625q0.71875 0.25 1.203125 0.734375q0.21875 0.203125 0.21875 0.421875q0 0.171875 -0.109375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.140625 0 -0.328125 -0.140625q-0.625 -0.484375 -1.140625 -0.671875q-0.5 -0.1875 -1.1562347 -0.1875q-1.4375 0 -2.203125 0.90625q-0.75 0.890625 -0.75 2.578125q0 1.71875 0.765625 2.609375q0.78125 0.890625 2.28125 0.890625q1.1093597 0 2.0312347 -0.328125l0 -2.578125l-1.7499847 0q-0.203125 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.265625q0 -0.1875 0.125 -0.28125q0.125 -0.109375 0.328125 -0.109375l2.2343597 0zm5.15683 -1.5625q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 
0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.3131714 -5.296875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.2027893 -5.265625q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm4.5035706 5.984375q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm10.461792 -0.515625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.480316 -2.453125q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 
0.390625l-0.3125 0.03125z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m276.85565 232.16667l0 20.377945" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85565 232.16667l0 16.950867" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m276.85565 249.11754l-1.1246033 -1.124588l1.1246033 3.0897675l1.1245728 -3.0897675z" fill-rule="evenodd"/><path fill="#f4cccc" d="m31.874016 68.3563l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m31.874016 68.3563l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m58.725647 87.669235q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.9706573 -6.984375q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm1.8266602 7.75q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm8.498016 -0.8125q0.171875 0.15625 0.171875 0.359375q0 0.15625 -0.140625 0.296875q-0.140625 0.140625 -0.3125 0.140625q-0.15625 0 -0.328125 -0.140625l-4.484375 -3.921875l0 3.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 3.4375l4.28125 -3.796875q0.125 -0.140625 0.3125 -0.140625q0.171875 0 0.296875 0.140625q0.140625 0.140625 0.140625 0.3125q0 0.171875 -0.15625 0.328125l-3.875 3.421875l4.09375 3.5625zm5.8329315 -0.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 
[SVG diagram data omitted: the remainder of this file diff is vector path data for a flowchart of colored boxes connected by arrows; no text is recoverable from the glyph paths.]
\ No newline at end of file
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 99f4a7d8f6..fdd0632451 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -142,7 +142,6 @@ DECLARE_GRAPH_TRANSFORMATION(PropagateFakeQuantNumBits);
DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes)
DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax)
DECLARE_GRAPH_TRANSFORMATION(Quantize)
-DECLARE_GRAPH_TRANSFORMATION(QuantizeWeights)
DECLARE_GRAPH_TRANSFORMATION(RemoveFinalDequantizeOp)
DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowAssert)
DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity)
@@ -178,9 +177,10 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveSpaceToBatchNDAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveBatchToSpaceNDAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolvePadAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolvePadV2Attributes)
-DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes)
-DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveReduceAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveReshapeAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveTransposeAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantPack)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRandomUniform)
@@ -217,12 +217,6 @@ class PropagateDefaultMinMax : public GraphTransformation {
std::vector<std::pair<ArrayDataType, MinMax>> type_ranges_;
};
-class ResolveReshapeAttributes : public GraphTransformation {
- public:
- bool Run(Model* model, std::size_t op_index) override;
- const char* Name() const override { return "ResolveReshapeAttributes"; }
-};
-
class RemoveTrivialReshape : public GraphTransformation {
public:
bool Run(Model* model, std::size_t op_index) override;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
index 502de88f7c..3114fa93e8 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -63,6 +63,25 @@ bool HardcodeMinMaxForL2Normalization(Model* model, Operator* op) {
return true;
}
+bool HardcodeInputMinMaxFromOutput(Model* model, Operator* op) {
+ auto& input = model->GetArray(op->inputs[0]);
+ if (input.minmax) {
+ const auto* minmax = input.minmax.get();
+ if (minmax) {
+ return false;
+ }
+ }
+ auto& output = model->GetArray(op->outputs[0]);
+ if (output.minmax) {
+ const auto* minmax = model->GetArray(op->outputs[0]).minmax.get();
+ if (minmax) {
+ input.GetOrCreateMinMax() = *minmax;
+ return true;
+ }
+ }
+ return false;
+}
+
bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) {
// Do not early return if the output already has min/max:
// we may still need to adjust the inputs min/max.
@@ -366,6 +385,16 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
changed = HardcodeMinMaxForL2Normalization(model, op);
break;
+ case OperatorType::kRelu:
+ // For any normalization other than batch norm, the quantization ranges
+ // before and after relu are expected to be known. Having a quantization
+ // op before relu would cut the number of bits of precision for the
+ // activation in half. So we deduce the range before relu from that after
+ // the relu. This would eliminate the need for two fake quantization nodes
+ // and would not reduce the bits of precision available for activation.
+ changed = HardcodeInputMinMaxFromOutput(model, op);
+ break;
+
case OperatorType::kConcatenation:
changed = HardcodeMinMaxForConcatenation(model, op);
break;
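Note: the kRelu hunk above copies the post-activation range onto the input array when the input has none, so a single pair of fake-quant nodes covers both sides of the relu at full precision. A minimal standalone sketch of that rule (std::optional<MinMax> is a simplified stand-in for toco's Array::minmax, not the exact API):

#include <optional>

struct MinMax { double min, max; };

// Copy the output range onto the input only if the input has none;
// returns whether anything changed, like a GraphTransformation::Run.
bool DeduceInputRangeFromOutput(std::optional<MinMax>* input,
                                const std::optional<MinMax>& output) {
  if (input->has_value()) return false;   // range already known
  if (!output.has_value()) return false;  // nothing to copy from yet
  *input = *output;  // pre-relu tensor reuses the post-relu range
  return true;
}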
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index c8310161cb..323eefcd3a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -227,6 +227,15 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
ArrayDataType::kFloat;
break;
}
+ case OperatorType::kUnpack: {
+ CHECK_EQ(op->inputs.size(), 1);
+ const int output_size = op->outputs.size();
+ for (int i = 0; i < output_size; ++i) {
+ model->GetArray(op->outputs[i]).data_type =
+ model->GetArray(op->inputs[0]).data_type;
+ }
+ break;
+ }
default: {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 91e290439a..f103bb94ae 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -539,6 +539,8 @@ bool KeepDims(const Operator& op) {
return static_cast<const TensorFlowProdOperator&>(op).keep_dims;
case OperatorType::kMean:
return static_cast<const MeanOperator&>(op).keep_dims;
+ case OperatorType::kAny:
+ return static_cast<const TensorFlowAnyOperator&>(op).keep_dims;
default:
LOG(FATAL) << "Not a reduction operator!";
return false;
@@ -559,26 +561,38 @@ void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
const bool keep_dims = KeepDims(*op);
if (op->inputs.size() == 2) {
// There is a reduction_indices input.
- const auto& reduction_array = model->GetArray(op->inputs[1]);
- if (!reduction_array.buffer) {
+ const auto& reduction_indices_array = model->GetArray(op->inputs[1]);
+ if (!reduction_indices_array.buffer) {
return;
}
- CHECK(reduction_array.buffer->type == ArrayDataType::kInt32);
- const auto& reduction_array_vals =
- reduction_array.GetBuffer<ArrayDataType::kInt32>().data;
- auto& output_dims = *output_array.mutable_shape()->mutable_dims();
- output_dims.clear();
- for (int i = 0; i < input_shape.dimensions_count(); i++) {
- bool is_reduction_dim = false;
- for (int r : reduction_array_vals) {
- if (i == r) {
- is_reduction_dim = true;
- }
+ CHECK(reduction_indices_array.buffer->type == ArrayDataType::kInt32);
+
+ int input_rank = input_shape.dimensions_count();
+ std::set<int32> true_indices;
+ const auto& reduction_indices =
+ reduction_indices_array.GetBuffer<ArrayDataType::kInt32>().data;
+ for (int i = 0; i < reduction_indices.size(); ++i) {
+ const int32 reduction_index = reduction_indices[i];
+ if (reduction_index < -input_rank || reduction_index >= input_rank) {
+ CHECK(false) << "Invalid reduction dimension " << reduction_index
+ << " for input with " << input_rank << " dimensions";
}
- if (!is_reduction_dim) {
- output_dims.push_back(input_shape.dims(i));
- } else if (keep_dims) {
- output_dims.push_back(1);
+ int32 wrapped_index = reduction_index;
+ if (wrapped_index < 0) {
+ wrapped_index += input_rank;
+ }
+ true_indices.insert(wrapped_index);
+ }
+
+ auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
+ mutable_dims->clear();
+ for (int i = 0; i < input_rank; ++i) {
+ if (true_indices.count(i) > 0) {
+ if (keep_dims) {
+ mutable_dims->emplace_back(1);
+ }
+ } else {
+ mutable_dims->emplace_back(input_shape.dims(i));
}
}
} else {
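Note: the rewritten reduction-shape logic above wraps negative axes Python-style (index -1 on a rank-3 input means axis 2) and collects them in a set so duplicates collapse; kept axes copy their size, and reduced axes emit a 1 only under keep_dims. A standalone sketch of the same shape rule, with plain vectors standing in for toco's Shape (the range CHECK is omitted):

#include <set>
#include <vector>

// Output dims after reducing `axes` (possibly negative) over `input_dims`.
std::vector<int> ReducedShape(const std::vector<int>& input_dims,
                              const std::vector<int>& axes, bool keep_dims) {
  const int rank = static_cast<int>(input_dims.size());
  std::set<int> reduced;
  for (int a : axes) reduced.insert(a < 0 ? a + rank : a);  // wrap negatives
  std::vector<int> out;
  for (int i = 0; i < rank; ++i) {
    if (reduced.count(i)) {
      if (keep_dims) out.push_back(1);  // reduced axis kept as size 1
    } else {
      out.push_back(input_dims[i]);     // non-reduced axis copied through
    }
  }
  return out;
}

For example, input_dims {2, 3, 4} with axes {-1, 0} reduces axes 0 and 2: the result is {3} without keep_dims and {1, 3, 1} with it.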
@@ -1300,12 +1314,16 @@ void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
// Compute output shape
for (int axis = 0; axis < num_input_axes; ++axis) {
+ const auto strided_slice_params =
+ tflite::strided_slice::BuildStridedSliceParams(
+ op->begin_mask, op->end_mask, op->shrink_axis_mask,
+ op->start_indices, op->stop_indices, op->strides);
int start_index = tflite::strided_slice::StartForAxis(
- op->begin_mask, op->start_indices, op->strides,
- input_array.shape().dims().data(), axis);
+ strided_slice_params, ToRuntimeShape(input_array.shape()), axis);
int stop_index = tflite::strided_slice::StopForAxis(
- op->end_mask, op->shrink_axis_mask, op->stop_indices, op->strides,
- input_array.shape().dims().data(), axis, start_index);
+ strided_slice_params, ToRuntimeShape(input_array.shape()), axis,
+ start_index);
+
int dim_size =
ceil(static_cast<float>(stop_index - start_index) / op->strides[axis]);
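Note: the refactor above packs the begin/end/shrink masks into a single StridedSliceParams once and reuses the shared tflite::strided_slice helpers for the per-axis bounds; the axis size itself is the usual ceiling count of strided indices. A sketch of that count for a positive stride (the float ceil in the code above also covers negative strides):

// Number of indices start, start+stride, ... strictly before stop.
int SliceAxisSize(int start, int stop, int stride) {
  return (stop - start + stride - 1) / stride;  // integer ceil, stride > 0
}
// e.g. SliceAxisSize(0, 5, 2) == 3  -> indices 0, 2, 4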
@@ -1515,65 +1533,6 @@ void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) {
}
}
-void ProcessAnyOperator(Model* model, AnyOperator* op) {
- CHECK_EQ(op->inputs.size(), 2);
- CHECK_EQ(op->outputs.size(), 1);
-
- auto& output_array = model->GetArray(op->outputs[0]);
- if (output_array.has_shape()) {
- // We have already run.
- return;
- }
-
- const auto& input_array = model->GetArray(op->inputs[0]);
- if (!input_array.has_shape()) {
- // Yield until input dims have been resolved.
- return;
- }
- const auto& input_shape = input_array.shape();
-
- auto& reduction_indices_array = model->GetArray(op->inputs[1]);
- if (!reduction_indices_array.has_shape()) {
- // Yield until reduction indices shape been resolved.
- return;
- }
- if (!reduction_indices_array.buffer) {
- // Yield until the reduction indices are constant.
- return;
- }
- CHECK(reduction_indices_array.data_type == ArrayDataType::kInt32)
- << "Any reduction input must be int32";
-
- int input_rank = input_shape.dimensions_count();
- std::set<int32> true_indices;
- const auto& reduction_indices =
- reduction_indices_array.GetBuffer<ArrayDataType::kInt32>().data;
- for (int i = 0; i < reduction_indices.size(); ++i) {
- const int32 reduction_index = reduction_indices[i];
- if (reduction_index < -input_rank || reduction_index >= input_rank) {
- CHECK(false) << "Invalid reduction dimension " << reduction_index
- << " for input with " << input_rank << " dimensions";
- }
- int32 wrapped_index = reduction_index;
- if (wrapped_index < 0) {
- wrapped_index += input_rank;
- }
- true_indices.insert(wrapped_index);
- }
-
- auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
- mutable_dims->clear();
- for (int i = 0; i < input_rank; ++i) {
- if (true_indices.count(i) > 0) {
- if (op->keep_dims) {
- mutable_dims->emplace_back(1);
- }
- } else {
- mutable_dims->emplace_back(input_shape.dims(i));
- }
- }
-}
-
void ProcessOneHotOperator(Model* model, OneHotOperator* op) {
CHECK_EQ(op->inputs.size(), 4);
CHECK_EQ(op->outputs.size(), 1);
@@ -1629,6 +1588,32 @@ void ProcessOneHotOperator(Model* model, OneHotOperator* op) {
}
}
+void ProcessUnpackOperator(Model* model, UnpackOperator* op) {
+ CHECK_EQ(op->inputs.size(), 1);
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+
+ const std::vector<int>& input_dims = input_array.shape().dims();
+ std::vector<int> output_dims;
+
+ output_dims.reserve(input_dims.size() - 1);
+ for (int i = 0; i < input_dims.size(); ++i) {
+ if (i != op->axis) {
+ output_dims.push_back(input_dims[i]);
+ }
+ }
+ for (const string& output_name : op->outputs) {
+ auto& output_array = model->GetArray(output_name);
+ if (output_array.has_shape()) {
+ return;
+ }
+ *output_array.mutable_shape()->mutable_dims() = output_dims;
+ }
+}
+
} // namespace
bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
@@ -1743,6 +1728,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kSum:
case OperatorType::kReduceProd:
case OperatorType::kMean:
+ case OperatorType::kAny:
ProcessTensorFlowReductionOperator(model, op);
break;
case OperatorType::kSelect:
@@ -1874,12 +1860,13 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kTile:
ProcessTileOperator(model, static_cast<TensorFlowTileOperator*>(op));
break;
- case OperatorType::kAny:
- ProcessAnyOperator(model, static_cast<AnyOperator*>(op));
break;
case OperatorType::kOneHot:
ProcessOneHotOperator(model, static_cast<OneHotOperator*>(op));
break;
+ case OperatorType::kUnpack:
+ ProcessUnpackOperator(model, static_cast<UnpackOperator*>(op));
+ break;
default:
// Unimplemented, another graph transformation should drop it.
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index 8d22ae2eb1..1bc366f555 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -62,7 +62,8 @@ bool SupportsQuantization(const Operator& op) {
type == OperatorType::kLessEqual || type == OperatorType::kSelect ||
type == OperatorType::kArgMax || type == OperatorType::kRelu ||
type == OperatorType::kRelu1 || type == OperatorType::kRelu6 ||
- type == OperatorType::kShape || type == OperatorType::kExpandDims;
+ type == OperatorType::kShape || type == OperatorType::kExpandDims ||
+ type == OperatorType::kPack || type == OperatorType::kTopK_V2;
}
// The quantized op allows output arrays of type float using
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc
deleted file mode 100644
index 7a8515f6d1..0000000000
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc
+++ /dev/null
@@ -1,106 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <iterator>
-#include <string>
-#include <vector>
-
-#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
-#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h"
-#include "tensorflow/contrib/lite/toco/model.h"
-#include "tensorflow/contrib/lite/toco/tooling_util.h"
-
-namespace toco {
-
-namespace {
-
-// The minimum number of elements a weights array must have to be quantized
-// by this transformation.
-// TODO(suharshs): Make this minimum size configurable.
-const int kWeightsMinSize = 1024;
-
-// Gets the quantization params from the float array.
-void GetQuantizationParamsFromArray(const Array& array,
- QuantizationParams* params) {
- const std::vector<float>& float_vals =
- array.GetBuffer<ArrayDataType::kFloat>().data;
- auto minmax = std::minmax_element(float_vals.begin(), float_vals.end());
- *params = tflite::ChooseQuantizationParams<uint8>(
- *minmax.first, *minmax.second, array.narrow_range);
-}
-
-} // namespace
-
-bool QuantizeWeights::Run(Model* model, std::size_t op_index) {
- const auto op_it = model->operators.begin() + op_index;
- Operator* op = op_it->get();
-
- // Get the weights tensor, if the current operator has one.
- int weights_index;
- if (op->type == OperatorType::kConv ||
- op->type == OperatorType::kDepthwiseConv ||
- op->type == OperatorType::kFullyConnected) {
- weights_index = 1;
- } else if (op->type == OperatorType::kLstmCell) {
- weights_index = LstmCellOperator::WEIGHTS_INPUT;
- } else {
- return false;
- }
-
- // Return early if the array isn't a constant param, this can happen in early
- // transformation passes until transpose operations following the weight array
- // are resolved.
- const string weights = op->inputs[weights_index];
- if (!IsConstantParameterArray(*model, weights)) {
- return false;
- }
-
- // Return early if the weight tensor is not type float.
- Array& weights_array = model->GetArray(weights);
- if (weights_array.data_type != ArrayDataType::kFloat) {
- return false;
- }
-
- // Return early if the tensor is too small. Small tensors don't take up too
- // much space and can result in bad quantization results.
- if (weights_array.GetBuffer<ArrayDataType::kFloat>().data.size() <
- kWeightsMinSize) {
- return false;
- }
-
- // Quantize the weight tensor to type kUint8.
- QuantizationParams params;
- GetQuantizationParamsFromArray(weights_array, &params);
- QuantizeArray(this, model, weights, ArrayDataType::kUint8, params);
-
- // Insert a Dequantize operation after the quantized weights tensor.
- auto* dequantize_op = new DequantizeOperator;
- model->operators.emplace(op_it, dequantize_op);
-
- // Create a new intermediate tensor to connect the Dequantize op to the
- // original op.
- const string dequantized_output =
- AvailableArrayName(*model, weights + "_dequantized");
- Array& dequantized_output_array = model->GetOrCreateArray(dequantized_output);
- dequantized_output_array.data_type = ArrayDataType::kFloat;
-
- // Connect up the new Dequantize op with the weights and original op.
- op->inputs[weights_index] = dequantized_output;
- dequantize_op->inputs = {weights};
- dequantize_op->outputs = {dequantized_output};
-
- return true;
-}
-
-} // namespace toco
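Note: the pass deleted above chose uint8 parameters directly from the observed float range of the weights and then re-inserted a Dequantize op so the graph still computes in float. Setting aside the range nudging that tflite::ChooseQuantizationParams performs, the parameter choice reduces to the standard asymmetric formulas; a simplified sketch (assumes a non-empty buffer):

#include <algorithm>
#include <cmath>
#include <vector>

struct QuantParams { double scale; int zero_point; };

// Asymmetric uint8 params from a float buffer (simplified: no nudging;
// the range is widened to include 0.0 so zero stays exactly representable).
QuantParams ChooseUint8Params(const std::vector<float>& vals) {
  auto mm = std::minmax_element(vals.begin(), vals.end());
  double min = std::min<double>(*mm.first, 0.0);
  double max = std::max<double>(*mm.second, 0.0);
  double scale = (max - min) / 255.0;
  if (scale == 0.0) scale = 1.0;  // degenerate all-zero buffer
  int zero_point = static_cast<int>(std::round(-min / scale));
  return {scale, zero_point};
}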
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
index d395d7a6a0..f5f2f77460 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
@@ -117,6 +117,7 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
&quantized_max);
if (fakequant_op->narrow_range) {
quantized_min++;
+ output_array.narrow_range = true;
}
// It is important for matching accuracy between TF training and TFLite
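Note: with narrow_range set, the lowest quantized bucket is reserved (uint8 uses [1, 255] rather than [0, 255]) so the representable range stays symmetric around zero; the hunk above additionally records that choice on the output array so downstream passes quantize it the same way. A minimal sketch of the bound adjustment (a hypothetical helper, not the toco API):

// Quantized bounds for 8-bit data; narrow_range reserves the lowest
// bucket, matching the quantized_min++ in the hunk above.
void QuantizedBounds(bool narrow_range, int* qmin, int* qmax) {
  *qmin = narrow_range ? 1 : 0;
  *qmax = 255;
}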
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
index 41562ab393..a6f665b5f0 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
@@ -100,13 +100,7 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
AddMessageF("Resolving constant reshape of %s", LogName(*op));
- if (input_array.minmax) {
- output_array.GetOrCreateMinMax() = input_array.GetMinMax();
- }
- if (input_array.quantization_params) {
- output_array.GetOrCreateQuantizationParams() =
- input_array.GetQuantizationParams();
- }
+ CopyMinMaxAndQuantizationRelatedFields(input_array, &output_array);
// Erase input arrays if no longer used.
for (const auto& input : op->inputs) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
index 9d8bd4fc39..8853ed87e6 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
@@ -52,14 +52,18 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
Buffer<Type> const& input_buffer = input_array.GetBuffer<Type>();
std::vector<int> src_coord(num_input_axes);
std::vector<int> stop_for_axis(num_input_axes);
+ const auto strided_slice_params =
+ tflite::strided_slice::BuildStridedSliceParams(
+ op.begin_mask, op.end_mask, op.shrink_axis_mask, op.start_indices,
+ op.stop_indices, op.strides);
+
for (int axis = 0; axis < num_input_axes; axis++) {
- int start = tflite::strided_slice::StartForAxis(
- op.begin_mask, op.start_indices, op.strides, input_shape.dims().data(),
- axis);
- src_coord[axis] = start;
+ int start_index = tflite::strided_slice::StartForAxis(
+ strided_slice_params, ToRuntimeShape(input_array.shape()), axis);
+ src_coord[axis] = start_index;
stop_for_axis[axis] = tflite::strided_slice::StopForAxis(
- op.end_mask, op.shrink_axis_mask, op.stop_indices, op.strides,
- input_shape.dims().data(), axis, start);
+ strided_slice_params, ToRuntimeShape(input_array.shape()), axis,
+ start_index);
}
// In order to handle any number (N) of dimensions, we copy elements one by
@@ -86,8 +90,7 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
if (tflite::strided_slice::LoopCondition(src_coord[axis], stop, stride)) {
// Reset axis and set carry
src_coord[axis] = tflite::strided_slice::StartForAxis(
- op.begin_mask, op.start_indices, op.strides,
- input_shape.dims().data(), axis);
+ strided_slice_params, ToRuntimeShape(input_shape), axis);
carry = true;
} else {
carry = false;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc
index 0b0d070714..5cfa1a5582 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc
@@ -128,15 +128,7 @@ bool ResolveConstantTile::Run(Model* model, std::size_t op_index) {
multiples_array.data_type == ArrayDataType::kInt64)
<< "Only int32/int64 indices are supported";
- // Copy min/max info if present. The ranges of the selected values may be
- // a subset of the original range but we want to ensure the quantization
- // params stay the same.
- if (input_array.minmax) {
- const auto& input_minmax = input_array.GetMinMax();
- auto& output_minmax = output_array.GetOrCreateMinMax();
- output_minmax.min = input_minmax.min;
- output_minmax.max = input_minmax.max;
- }
+ CopyMinMaxAndQuantizationRelatedFields(input_array, &output_array);
CHECK(!output_array.buffer);
switch (output_array.data_type) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc
index 1fd20314b1..fe15dfa06f 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc
@@ -128,13 +128,7 @@ bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) {
}
const Array& input_array = model->GetArray(op->inputs[0]);
- if (input_array.minmax) {
- output_array.GetOrCreateMinMax() = input_array.GetMinMax();
- }
- if (input_array.quantization_params) {
- output_array.GetOrCreateQuantizationParams() =
- input_array.GetQuantizationParams();
- }
+ CopyMinMaxAndQuantizationRelatedFields(input_array, &output_array);
if (op->perm.empty()) {
// Yield until perm has been populated by ResolveTransposeAttributes.
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
index 475415e481..c698a9567a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
@@ -51,6 +51,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
// Test for unary ops of types that we know how to resolve.
switch (unary_op->type) {
case OperatorType::kCast:
+ case OperatorType::kExp:
case OperatorType::kLog:
case OperatorType::kNeg:
case OperatorType::kRsqrt:
@@ -218,7 +219,8 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
max = std::max(max, (*input_float_data)[i]);
}
output_float_data[0] = max;
- } else if (unary_op->type == OperatorType::kNeg ||
+ } else if (unary_op->type == OperatorType::kExp ||
+ unary_op->type == OperatorType::kNeg ||
unary_op->type == OperatorType::kLog ||
unary_op->type == OperatorType::kRsqrt ||
unary_op->type == OperatorType::kSqrt ||
@@ -231,7 +233,9 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
for (int i = 0; i < output_buffer_size; i++) {
const float val = (*input_float_data)[i];
float outval = 0.f;
- if (unary_op->type == OperatorType::kNeg) {
+ if (unary_op->type == OperatorType::kExp) {
+ outval = std::exp(val);
+ } else if (unary_op->type == OperatorType::kNeg) {
outval = -val;
} else if (unary_op->type == OperatorType::kLog) {
outval = std::log(val);
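
In isolation, the kExp case added above is just an element-wise std::exp over the constant input buffer, following the same per-element pattern as the existing unary folds. A free-standing sketch (not part of the patch):

    #include <cmath>
    #include <vector>

    // Element-wise constant fold for exp, mirroring the loop above.
    std::vector<float> FoldExp(const std::vector<float>& input) {
      std::vector<float> output(input.size());
      for (size_t i = 0; i < input.size(); ++i) {
        output[i] = std::exp(input[i]);
      }
      return output;
    }
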
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc
index 7d456af2fb..73198ac7c0 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc
@@ -52,6 +52,8 @@ bool ResolveReduceAttributes::Run(Model* model, std::size_t op_index) {
return ResolveAttributes(model, static_cast<TensorFlowMinOperator*>(op));
case OperatorType::kReduceMax:
return ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op));
+ case OperatorType::kAny:
+ return ResolveAttributes(model, static_cast<TensorFlowAnyOperator*>(op));
default:
return false;
}
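
ResolveAttributes is a template instantiated on the concrete operator type, which is why the kAny case must cast to TensorFlowAnyOperator (as corrected above) rather than reusing the TensorFlowMaxOperator cast it was evidently copied from. The helper itself is not shown in this hunk; a hypothetical sketch, assuming it copies constant int32 reduction indices into op->axis:

    // Hypothetical sketch of the helper used above: fills op->axis from the
    // second (constant) input if it has not been resolved yet.
    template <typename T>
    bool ResolveAttributes(Model* model, T* op) {
      if (!op->axis.empty()) return false;       // Already resolved.
      if (op->inputs.size() != 2) return false;  // No reduction indices yet.
      if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
      const Array& indices_array = model->GetArray(op->inputs[1]);
      if (!indices_array.has_shape()) return false;
      op->axis = indices_array.GetBuffer<ArrayDataType::kInt32>().data;
      return true;
    }
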
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
index e163fc9ae1..acf1e3ede5 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
@@ -20,19 +20,6 @@ tf_cc_test(
)
tf_cc_test(
- name = "quantize_weights_test",
- srcs = ["quantize_weights_test.cc"],
- tags = ["no_oss"],
- deps = [
- "//tensorflow/contrib/lite/toco:graph_transformations",
- "//tensorflow/contrib/lite/toco:model",
- "//tensorflow/contrib/lite/toco:tooling_util",
- "@com_google_absl//absl/memory",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-tf_cc_test(
name = "resolve_constant_concatenation_test",
srcs = ["resolve_constant_concatenation_test.cc"],
tags = ["no_oss"],
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc
deleted file mode 100644
index c05eb0929f..0000000000
--- a/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc
+++ /dev/null
@@ -1,167 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <math.h>
-#include <string>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/memory/memory.h"
-#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
-#include "tensorflow/contrib/lite/toco/model.h"
-#include "tensorflow/contrib/lite/toco/tooling_util.h"
-
-namespace toco {
-
-class QuantizeWeightsTest : public ::testing::Test {
- protected:
- QuantizeWeightsTest() {}
-
- // The name of the weights input array.
- const string kWeightsName = "weights";
- // The zero_point of the values in the input array.
- const int kZeroPoint = 128;
-
- // Prepare a hypothetical TOCO model of a quantizable fully connected float
- // layer.
- void PrepareModel(Model* model, int elements_per_dim) {
- std::vector<string> fc_input_names = {"inputs", kWeightsName};
-
- const int kDim = 4;
- const int buf_size = std::pow(elements_per_dim, static_cast<double>(kDim));
- auto in_buf = absl::make_unique<float[]>(buf_size);
- // Initialize the array with values from -128.0 to 127.0, since these values
- // should be exactly representable by quantization.
- for (int i = 0; i < buf_size; i++) {
- in_buf[i] = static_cast<float>(i % 256 - kZeroPoint);
- }
-
- for (const string& fc_input_name : fc_input_names) {
- Array& in_array = model->GetOrCreateArray(fc_input_name);
- in_array.data_type = ArrayDataType::kFloat;
-
- // Initialize shape for the input array.
- Shape* in_array_shape = in_array.mutable_shape();
- std::vector<int>* in_array_shape_dim = in_array_shape->mutable_dims();
- in_array_shape_dim->resize(kDim, elements_per_dim);
- auto& in_array_buffer =
- in_array.GetMutableBuffer<ArrayDataType::kFloat>();
- in_array_buffer.data.resize(buf_size);
- float* buf_ptr =
- in_array.GetMutableBuffer<ArrayDataType::kFloat>().data.data();
- std::copy(in_buf.get(), in_buf.get() + buf_size, buf_ptr);
- }
-
- auto* fc_op = new FullyConnectedOperator;
- fc_op->inputs = fc_input_names;
- fc_op->outputs = {"fc_op_outputs"};
- Array& out_array = model->GetOrCreateArray(fc_op->outputs[0]);
- out_array.data_type = ArrayDataType::kFloat;
- Shape* out_array_shape = out_array.mutable_shape();
- std::vector<int>* out_array_shape_dim = out_array_shape->mutable_dims();
- out_array_shape_dim->resize(kDim, elements_per_dim);
- model->operators.push_back(std::unique_ptr<Operator>(fc_op));
- }
-};
-
-TEST_F(QuantizeWeightsTest, QuantizedFullyConnected) {
- // Test that weight arrays that are large enough are quantized.
- Model model;
- // 6 elements per dim gives us 1296 elements, which is sufficient to be
- // quantized.
- PrepareModel(&model, 6);
-
- // Check the state of the graph before the transformation.
- const auto& float_array_map = model.GetArrayMap();
- EXPECT_EQ(float_array_map.size(), 3);
- // Before the transformation, all arrays should be type float.
- for (const auto& element : float_array_map) {
- EXPECT_EQ(element.second->data_type, ArrayDataType::kFloat);
- }
- const std::vector<float> float_weight_vals =
- model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kFloat>().data;
-
- // Invoke the transformation.
- GraphTransformationsSet graph_transformation_set;
- graph_transformation_set.Add(new toco::QuantizeWeights);
- (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
-
- // Check the state of the graph after the transformation.
- const auto& quantized_array_map = model.GetArrayMap();
- EXPECT_EQ(quantized_array_map.size(), 4);
- // After the transformation, three arrays should be type float and one array
- // should be uint8.
- int num_float = 0;
- int num_uint8 = 0;
- for (const auto& element : quantized_array_map) {
- if (element.second->data_type == ArrayDataType::kFloat) {
- num_float++;
- } else if (element.second->data_type == ArrayDataType::kUint8) {
- num_uint8++;
- } else {
- FAIL() << "Unexpected array type.";
- }
- }
- EXPECT_EQ(num_float, 3);
- EXPECT_EQ(num_uint8, 1);
- // Ensure that the values were quantized correctly.
- const std::vector<uint8>& quantized_weight_vals =
- model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kUint8>().data;
- for (int i = 0; i < quantized_weight_vals.size(); i++) {
- EXPECT_EQ(quantized_weight_vals[i], float_weight_vals[i] + kZeroPoint);
- }
-
- // Ensure that a Dequantize operator has been inserted before the
- // FullyConnectedLayer.
- EXPECT_EQ(model.operators[0]->type, OperatorType::kDequantize);
-}
-
-TEST_F(QuantizeWeightsTest, NotQuantizedFullyConnected) {
- // Test that weight arrays that are too small are left untouched.
- Model model;
- // 5 elements per dim gives us 625 elements, which is NOT sufficient to be
- // quantized.
- PrepareModel(&model, 5);
-
- // Check the state of the graph before the transformation.
- const auto& float_array_map = model.GetArrayMap();
- EXPECT_EQ(float_array_map.size(), 3);
- // Before the transformation, all arrays should be type float.
- for (auto it = float_array_map.begin(); it != float_array_map.end(); it++) {
- EXPECT_EQ(it->second->data_type, ArrayDataType::kFloat);
- }
- std::vector<float> float_weight_vals =
- model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kFloat>().data;
-
- // Invoke the transformation.
- GraphTransformationsSet graph_transformation_set;
- graph_transformation_set.Add(new toco::QuantizeWeights);
- (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
-
- // Check the state of the graph after the transformation.
- const auto& post_array_map = model.GetArrayMap();
- EXPECT_EQ(post_array_map.size(), 3);
- for (auto it = post_array_map.begin(); it != post_array_map.end(); it++) {
- EXPECT_EQ(it->second->data_type, ArrayDataType::kFloat);
- }
- // Ensure that the values remain unchanged.
- std::vector<float> const& quantized_weight_vals =
- model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kFloat>().data;
- for (int i = 0; i < quantized_weight_vals.size(); i++) {
- EXPECT_EQ(quantized_weight_vals[i], float_weight_vals[i]);
- }
-}
-
-} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc
index 5f0cece67a..fedf4441e2 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc
@@ -154,6 +154,7 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) {
pack_op->inputs = pack_inputs;
pack_op->outputs = {batch_op->outputs[0]};
pack_op->axis = 0;
+ pack_op->values_count = pack_inputs.size();
model->operators.emplace(tail_it, pack_op);
// Remove the old batch matmul now that we've unrolled.
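
The added values_count assignment keeps the synthesized Pack op consistent with its input list. A construction sketch under that invariant (hypothetical names; types from toco/model.h, assuming toco's string aliases std::string):

    #include <memory>
    #include <string>
    #include <vector>
    #include "absl/memory/memory.h"
    #include "tensorflow/contrib/lite/toco/model.h"

    // Builds a Pack op whose values_count matches the number of packed inputs.
    std::unique_ptr<toco::PackOperator> MakePack(std::vector<std::string> inputs,
                                                 std::string output) {
      auto op = absl::make_unique<toco::PackOperator>();
      op->inputs.assign(inputs.begin(), inputs.end());
      op->outputs = {std::move(output)};
      op->axis = 0;
      op->values_count = op->inputs.size();
      return op;
    }
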
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index b7fffbce22..9bc23c4b3c 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1576,6 +1576,26 @@ tensorflow::Status ConvertPackOperator(
return tensorflow::Status::OK();
}
+tensorflow::Status ConvertUnpackOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "Unpack");
+ auto op = absl::make_unique<UnpackOperator>();
+ const int num_inputs = GetInputsCount(node, tf_import_flags);
+ QCHECK_EQ(num_inputs, 1);
+ op->inputs.push_back(node.input(0));
+ op->num = GetIntAttr(node, "num");
+ op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0;
+ op->dtype = ConvertDataType(toco::GetDataTypeAttr(node, "T"));
+
+ op->outputs.push_back(node.name()); // Implicit :0.
+ for (int i = 1; i < op->num; ++i) {
+ op->outputs.push_back(node.name() + ":" + std::to_string(i));
+ }
+ model->operators.emplace_back(std::move(op));
+ return tensorflow::Status::OK();
+}
+
// Some TensorFlow ops only occur in graph cycles, representing
// control flow. We do not currently support control flow, so we wouldn't
// be able to fully support such graphs, including performing inference,
@@ -1618,24 +1638,6 @@ tensorflow::Status ConvertShapeOperator(
return tensorflow::Status::OK();
}
-tensorflow::Status ConvertAnyOperator(
- const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Any");
- TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
- const auto idx_type =
- HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32;
- CHECK(idx_type == DT_INT32);
- auto op = absl::make_unique<AnyOperator>();
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- op->keep_dims =
- HasAttr(node, "keep_dims") ? GetBoolAttr(node, "keep_dims") : false;
- model->operators.push_back(std::move(op));
- return tensorflow::Status::OK();
-}
-
void StripCaretFromArrayNames(Model* model) {
for (auto& op : model->operators) {
for (auto& input : op->inputs) {
@@ -1917,7 +1919,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"Add", ConvertSimpleOperator<AddOperator, 2>},
{"AddN", ConvertSimpleOperator<AddNOperator>},
{"All", ConvertSimpleOperator<TensorFlowAllOperator>},
- {"Any", ConvertAnyOperator},
+ {"Any", ConvertReduceOperator<TensorFlowAnyOperator>},
{"ArgMax", ConvertArgMaxOperator},
{"ArgMin", ConvertArgMinOperator},
{"Assert", ConvertSimpleOperator<TensorFlowAssertOperator>},
@@ -2020,6 +2022,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"TopK", ConvertTopKV2Operator},
{"TopKV2", ConvertTopKV2Operator},
{"Transpose", ConvertSimpleOperator<TransposeOperator, 2>},
+ {"Unpack", ConvertUnpackOperator},
});
}
@@ -2058,8 +2061,14 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
}
Model* model = new Model;
- const internal::ConverterMapType& converter_map =
- internal::GetTensorFlowNodeConverterMap();
+ internal::ConverterMapType converter_map;
+
+ // This is used for the TFLite "Full Eager Mode" conversion. All the ops are
+ // imported as `TensorFlowUnsupportedOperator`, and later all these ops are
+ // converted to TFLite Eager ops.
+ if (!tf_import_flags.import_all_ops_as_unsupported) {
+ converter_map = internal::GetTensorFlowNodeConverterMap();
+ }
for (auto node : inlined_graph.node()) {
StripZeroOutputIndexFromInputs(&node);
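
With the flag set, the converter map stays empty, every lookup misses, and each node is imported through the generic TensorFlowUnsupportedOperator path. A caller-side sketch, following the same call pattern toco_tooling.cc uses later in this change (assuming toco's string aliases std::string):

    #include <memory>
    #include <string>
    #include "tensorflow/contrib/lite/toco/import_tensorflow.h"

    // Imports every op as TensorFlowUnsupportedOperator for later conversion
    // to TFLite Eager ops.
    std::unique_ptr<toco::Model> ImportForEager(
        const toco::ModelFlags& model_flags,
        const std::string& input_file_contents) {
      toco::TensorFlowImportFlags tf_import_flags;
      tf_import_flags.import_all_ops_as_unsupported = true;
      return toco::ImportTensorFlowGraphDef(model_flags, tf_import_flags,
                                            input_file_contents);
    }
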
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.h b/tensorflow/contrib/lite/toco/import_tensorflow.h
index 2177872334..7db23f2d44 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.h
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.h
@@ -27,6 +27,11 @@ struct TensorFlowImportFlags {
// If true, control dependencies will be dropped immediately
// during the import of the TensorFlow GraphDef.
bool drop_control_dependency = false;
+
+ // If true, do not recognize any op and import all ops as
+ // `TensorFlowUnsupportedOperator`. This is populated from the
+ // `force_eager_ops` flag.
+ bool import_all_ops_as_unsupported = false;
};
std::unique_ptr<Model> ImportTensorFlowGraphDef(
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 412e14c4ad..2e100e37f6 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -149,6 +149,7 @@ enum class OperatorType : uint8 {
kLogicalNot,
kLogicalOr,
kCTCBeamSearchDecoder,
+ kUnpack,
};
// Helper to deal with TensorFlow arrays using a different ordering of
@@ -1770,8 +1771,9 @@ struct PowOperator : Operator {
// Inputs[1]: required: reduction_indices.
//
// TensorFlow equivalent: tf.reduce_any.
-struct AnyOperator : Operator {
- AnyOperator() : Operator(OperatorType::kAny) {}
+struct TensorFlowAnyOperator : Operator {
+ TensorFlowAnyOperator() : Operator(OperatorType::kAny) {}
+ std::vector<int> axis;
bool keep_dims = false;
};
@@ -1828,6 +1830,20 @@ struct LogicalOrOperator : Operator {
LogicalOrOperator() : Operator(OperatorType::kLogicalOr) {}
};
+// Unpack operator:
+//
+// Inputs:
+// Inputs[0]: required: The tensor to unpack.
+//
+// TensorFlow equivalent: tf.unstack.
+struct UnpackOperator : Operator {
+ UnpackOperator() : Operator(OperatorType::kUnpack) {}
+ int num;
+ int axis;
+ ArrayDataType dtype = ArrayDataType::kNone;
+};
+
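As declared, an Unpack with num=N produces N outputs, one per slice along axis, named like TensorFlow multi-output tensors. A construction sketch mirroring ConvertUnpackOperator's output naming (hypothetical array names):

    #include <memory>
    #include "absl/memory/memory.h"
    #include "tensorflow/contrib/lite/toco/model.h"

    // Unpacks a [3, 4] tensor along axis 0 into three [4] outputs.
    std::unique_ptr<toco::UnpackOperator> MakeUnpack3() {
      auto op = absl::make_unique<toco::UnpackOperator>();
      op->inputs = {"input"};
      op->num = 3;   // One output per slice along `axis`.
      op->axis = 0;
      op->outputs = {"unpack", "unpack:1", "unpack:2"};  // Implicit :0 first.
      return op;
    }
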
// Alloc's are used for transient arrays only. An Alloc specifies which interval
// of the "transient_data" workspace buffer passed to inference functions, is to
// be used for the transient array at hand. The 'start' and 'end' values are
diff --git a/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py b/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py
index 3761e0095e..75c1c8970c 100644
--- a/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py
+++ b/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py
@@ -50,7 +50,7 @@ class TocoFromProtosTest(googletest.TestCase):
toco_flags.output_format = toco_flags_pb2.TFLITE
toco_flags.inference_input_type = types_pb2.FLOAT
toco_flags.inference_type = types_pb2.FLOAT
- toco_flags.allow_custom_ops = True;
+ toco_flags.allow_custom_ops = True
model_flags = model_flags_pb2.ModelFlags()
input_array = model_flags.input_arrays.add()
input_array.name = TensorName(in_tensor)
diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.h b/tensorflow/contrib/lite/toco/python/toco_python_api.h
index 7e8ad9c1da..ee054bbed9 100644
--- a/tensorflow/contrib/lite/toco/python/toco_python_api.h
+++ b/tensorflow/contrib/lite/toco/python/toco_python_api.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
-#define _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
#include <Python.h>
#include <string>
@@ -33,4 +33,4 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
} // namespace toco
-#endif // _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h
index 18ff73ac39..fda7743a27 100644
--- a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H
-#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H_
#include <string>
#include <vector>
@@ -98,4 +98,4 @@ class ClusterFactoryInterface {
} // end namespace toco
-#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H_
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h
index a15e480e70..b57bded305 100644
--- a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTERUTILS_H
-#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTERUTILS_H
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_UTILS_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_UTILS_H_
#include <string>
@@ -30,4 +30,4 @@ void Transpose2DTensor(const float* tensor, int row, int col,
} // end namespace toco
-#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTERUTILS_H
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_UTILS_H_
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h
index 7d33dd1885..3334552afb 100644
--- a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H
-#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H_
#include <string>
#include <unordered_map>
@@ -60,4 +60,4 @@ std::unique_ptr<tensorflow::GraphDef> MaybeReplaceCompositeSubgraph(
} // end namespace toco
-#endif // CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H_
diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h
index c4c6c34117..383fd99dff 100644
--- a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h
+++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H
-#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H_
#include <string>
#include <vector>
@@ -79,4 +79,4 @@ class SvdfClusterFactory : public ClusterFactoryInterface {
} // end namespace toco
-#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H_
diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD
index 709c53606b..71cdb7703e 100644
--- a/tensorflow/contrib/lite/toco/tflite/BUILD
+++ b/tensorflow/contrib/lite/toco/tflite/BUILD
@@ -91,6 +91,7 @@ cc_library(
"//tensorflow/contrib/lite/schema:schema_fbs",
"//tensorflow/contrib/lite/toco:model",
"//tensorflow/contrib/lite/toco:tooling_util",
+ "//tensorflow/contrib/lite/tools/optimize:quantize_weights",
"@com_google_absl//absl/strings",
"@flatbuffers",
],
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index 5ad307af14..fee10b1dff 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -16,10 +16,12 @@ limitations under the License.
#include "flatbuffers/flexbuffers.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/contrib/lite/toco/tflite/operator.h"
#include "tensorflow/contrib/lite/toco/tflite/types.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/contrib/lite/tools/optimize/quantize_weights.h"
#include "tensorflow/contrib/lite/version.h"
namespace toco {
@@ -47,12 +49,21 @@ namespace {
details::OperatorKey GetOperatorKey(
const ::toco::Operator& op,
- const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ bool allow_eager_ops) {
string custom_code;
if (op.type == OperatorType::kUnsupported) {
const TensorFlowUnsupportedOperator& unsupported_op =
static_cast<const TensorFlowUnsupportedOperator&>(op);
- custom_code = unsupported_op.tensorflow_op;
+
+ // TODO(b/113715895): When `allow_eager_ops` is on, there is currently no
+ // way to populate a regular custom op. We need to find a way to fix this.
+ if (allow_eager_ops) {
+ custom_code = string(::tflite::kEagerCustomCodePrefix) +
+ unsupported_op.tensorflow_op;
+ } else {
+ custom_code = unsupported_op.tensorflow_op;
+ }
}
int version = 1;
if (ops_by_type.count(op.type) != 0) {
@@ -61,6 +72,13 @@ details::OperatorKey GetOperatorKey(
return details::OperatorKey(op.type, custom_code, version);
}
+void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder,
+ string* file_contents) {
+ const uint8_t* buffer = builder.GetBufferPointer();
+ int size = builder.GetSize();
+ *file_contents = string(reinterpret_cast<const char*>(buffer), size);
+}
+
} // Anonymous namespace.
namespace details {
@@ -82,11 +100,12 @@ void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) {
void LoadOperatorsMap(
const Model& model, OperatorsMap* operators_map,
- const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ bool allow_eager_ops) {
// First find a list of unique operator types.
std::set<OperatorKey> keys;
for (const auto& op : model.operators) {
- keys.insert(GetOperatorKey(*op, ops_by_type));
+ keys.insert(GetOperatorKey(*op, ops_by_type, allow_eager_ops));
}
// Now assign indices to them and fill in the map.
int index = 0;
@@ -180,7 +199,7 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
const Model& model,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
const details::OperatorsMap& operators_map, FlatBufferBuilder* builder,
- std::set<string>* error_summary) {
+ std::set<string>* error_summary, const ExportParams& params) {
// Map from operator name to TF Lite enum value, for all builtins.
std::map<string, BuiltinOperator> builtin_ops;
for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) {
@@ -196,7 +215,8 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
std::map<int, Offset<OperatorCode>> ordered_opcodes;
for (const auto& op : model.operators) {
- const details::OperatorKey operator_key = GetOperatorKey(*op, ops_by_type);
+ const details::OperatorKey operator_key =
+ GetOperatorKey(*op, ops_by_type, params.allow_eager_ops);
int op_index = operators_map.at(operator_key);
int op_version = operator_key.version;
@@ -243,7 +263,7 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
const details::OperatorsMap& operators_map,
const details::TensorsMap& tensors_map, FlatBufferBuilder* builder,
- std::set<int32_t>* variable_tensor_indices) {
+ std::set<int32_t>* variable_tensor_indices, const ExportParams& params) {
variable_tensor_indices->clear();
// The operators are in execution order, so we just follow tf.mini order.
@@ -260,7 +280,8 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
outputs.push_back(tensors_map.at(output));
}
- int op_index = operators_map.at(GetOperatorKey(*op, ops_by_type));
+ int op_index = operators_map.at(
+ GetOperatorKey(*op, ops_by_type, params.allow_eager_ops));
auto tflite_op_it = ops_by_type.find(op->type);
BaseOperator* tflite_op = tflite_op_it == ops_by_type.end()
@@ -311,14 +332,15 @@ Offset<Vector<Offset<Buffer>>> ExportBuffers(
return builder->CreateVector(buffer_vector);
}
-void Export(const Model& model, bool allow_custom_ops,
- string* output_file_contents) {
- const auto ops_by_type = BuildOperatorByTypeMap();
- Export(model, allow_custom_ops, output_file_contents, ops_by_type);
+void Export(const Model& model, string* output_file_contents,
+ const ExportParams& params) {
+ const auto ops_by_type = BuildOperatorByTypeMap(params.allow_eager_ops);
+ Export(model, output_file_contents, params, ops_by_type);
}
void Export(
- const Model& model, bool allow_custom_ops, string* output_file_contents,
+ const Model& model, string* output_file_contents,
+ const ExportParams& params,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
@@ -326,7 +348,8 @@ void Export(
details::LoadTensorsMap(model, &tensors_map);
details::OperatorsMap operators_map;
- details::LoadOperatorsMap(model, &operators_map, ops_by_type);
+ details::LoadOperatorsMap(model, &operators_map, ops_by_type,
+ params.allow_eager_ops);
std::vector<const Array*> buffers_to_write;
Array empty_array;
@@ -334,7 +357,7 @@ void Export(
std::set<string> error_summary;
auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map,
- &builder, &error_summary);
+ &builder, &error_summary, params);
for (const auto& op : model.operators) {
if (op->type == OperatorType::kFakeQuant) {
@@ -344,7 +367,7 @@ void Export(
"for --std_values and --mean_values.";
}
}
- if (!allow_custom_ops && !error_summary.empty()) {
+ if (!params.allow_custom_ops && !error_summary.empty()) {
// Remove ExpandDims and ReorderAxes from unimplemented list unless they
// compose the list. Both ops are removed during graph transformations.
// However, if an op is unimplemented earlier in the model, the graph
@@ -365,14 +388,14 @@ void Export(
"the standard TensorFlow Lite runtime. If you have a custom "
"implementation for them you can disable this error with "
"--allow_custom_ops, or by setting allow_custom_ops=True "
- "when calling tf.contrib.lite.toco_convert(). Here is a list "
+ "when calling tf.contrib.lite.TocoConverter(). Here is a list "
"of operators for which you will need custom implementations: "
<< absl::StrJoin(error_summary_final, ", ") << ".";
}
std::set<int32_t> variable_tensor_indices;
auto ops = ExportOperators(model, ops_by_type, operators_map, tensors_map,
- &builder, &variable_tensor_indices);
+ &builder, &variable_tensor_indices, params);
auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write,
variable_tensor_indices);
@@ -390,9 +413,24 @@ void Export(
CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
builder.CreateVector(subgraphs), description, buffers);
::tflite::FinishModelBuffer(builder, new_model_location);
- const uint8_t* buffer = builder.GetBufferPointer();
- int size = builder.GetSize();
- *output_file_contents = string(reinterpret_cast<const char*>(buffer), size);
+
+ if (params.quantize_weights) {
+ // Call the quantize_weights tool.
+ LOG(INFO) << "Quantizing TFLite model after conversion to flatbuffer. "
+ "dump_graphviz will only output the model before this "
+ "transformation. To visualize the output graph use "
+ "lite/tools/optimize.py.";
+ flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);
+ const uint8_t* buffer = builder.GetBufferPointer();
+ const ::tflite::Model* input_model = ::tflite::GetModel(buffer);
+ if (::tflite::optimize::QuantizeWeights(&q_builder, input_model) !=
+ kTfLiteOk) {
+ LOG(QFATAL) << "Quantize weights transformation failed.";
+ }
+ WriteModelToString(q_builder, output_file_contents);
+ } else {
+ WriteModelToString(builder, output_file_contents);
+ }
}
} // namespace tflite
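
The Eager path in GetOperatorKey changes only the custom code under which an unsupported op is registered, so the runtime can route prefixed ops to the Eager delegate at load time. A sketch of that decision, assuming the ::tflite::kEagerCustomCodePrefix constant referenced above:

    #include <string>

    // Custom code under which an unsupported TF op is exported. With eager
    // ops allowed, the prefix marks the op for the Eager delegate.
    std::string CustomCodeFor(const std::string& tensorflow_op,
                              bool allow_eager_ops) {
      if (allow_eager_ops) {
        return std::string(::tflite::kEagerCustomCodePrefix) + tensorflow_op;
      }
      return tensorflow_op;
    }

Separately, note that weight quantization now runs after serialization: QuantizeWeights rewrites the finished flatbuffer into a second builder, which is why dump_graphviz only reflects the pre-quantization graph.
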
diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h
index 58ea5c725c..b070a38768 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.h
+++ b/tensorflow/contrib/lite/toco/tflite/export.h
@@ -23,22 +23,55 @@ namespace toco {
namespace tflite {
+// The parameters for exporting a TFLite model.
+struct ExportParams {
+ bool allow_custom_ops = false;
+ bool allow_eager_ops = false;
+ bool quantize_weights = false;
+};
+
// Transform the given tf.mini model into a TF Lite flatbuffer and deposit the
// result in the given string.
-void Export(const Model& model, bool allow_custom_ops,
- string* output_file_contents);
-
-// This if backward-compatibility.
-// TODO(ycling): Remove the deprecated entry functions.
-inline void Export(const Model& model, string* output_file_contents) {
- Export(model, true, output_file_contents);
-}
+void Export(const Model& model, string* output_file_contents,
+ const ExportParams& params);
// Export API with custom TFLite operator mapping.
void Export(
- const Model& model, bool allow_custom_ops, string* output_file_contents,
+ const Model& model, string* output_file_contents,
+ const ExportParams& params,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
+// This is for backward-compatibility.
+// TODO(ycling): Remove the deprecated entry functions.
+inline void Export(const Model& model, bool allow_custom_ops,
+ bool quantize_weights, string* output_file_contents) {
+ ExportParams params;
+ params.allow_custom_ops = allow_custom_ops;
+ params.quantize_weights = quantize_weights;
+ Export(model, output_file_contents, params);
+}
+
+// This is for backward-compatibility.
+// TODO(ycling): Remove the deprecated entry functions.
+inline void Export(
+ const Model& model, bool allow_custom_ops, bool quantize_weights,
+ string* output_file_contents,
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
+ ExportParams params;
+ params.allow_custom_ops = allow_custom_ops;
+ params.quantize_weights = quantize_weights;
+ Export(model, output_file_contents, params, ops_by_type);
+}
+
+// This is for backward-compatibility.
+// TODO(ycling): Remove the deprecated entry functions.
+inline void Export(const Model& model, string* output_file_contents) {
+ ExportParams params;
+ params.allow_custom_ops = true;
+ Export(model, output_file_contents, params);
+}
+
namespace details {
// A map from tensor name to its final position in the TF Lite buffer.
@@ -87,7 +120,8 @@ using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>;
void LoadTensorsMap(const Model& model, TensorsMap* tensors_map);
void LoadOperatorsMap(
const Model& model, OperatorsMap* operators_map,
- const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ bool allow_eager_ops);
} // namespace details
} // namespace tflite
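
A usage sketch for the new parameterized entry point (illustrative; assumes toco's string aliases std::string and a model that has already been transformed):

    #include <string>
    #include "tensorflow/contrib/lite/toco/tflite/export.h"

    // Exports a model with post-training weight quantization enabled.
    std::string ExportQuantized(const toco::Model& model) {
      toco::tflite::ExportParams params;
      params.allow_custom_ops = true;
      params.quantize_weights = true;  // Post-training weight quantization.
      std::string flatbuffer;
      toco::tflite::Export(model, &flatbuffer, params);
      return flatbuffer;
    }
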
diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc
index a95937ba0f..8d4d197c46 100644
--- a/tensorflow/contrib/lite/toco/tflite/export_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc
@@ -52,6 +52,42 @@ class ExportTest : public ::testing::Test {
input_model_.operators.emplace_back(new SubOperator);
}
+ void BuildQuantizableTestModel() {
+ input_model_.GetOrCreateArray("inputs");
+ Array& weight_array = input_model_.GetOrCreateArray("weights");
+
+ // Make the buffer large enough for QuantizeWeights transformation to take
+ // effect.
+ int buf_size = 1296;
+ auto weight_buf = absl::make_unique<float[]>(buf_size);
+ for (int i = 0; i < buf_size; i++) {
+ // Fill the array with some garbage values.
+ weight_buf[i] = static_cast<float>(i % 128);
+ }
+
+ weight_array.data_type = ArrayDataType::kFloat;
+
+ // Initialize shape for the input array.
+ Shape* weight_array_shape = weight_array.mutable_shape();
+ std::vector<int>* weight_array_shape_dim =
+ weight_array_shape->mutable_dims();
+ weight_array_shape_dim->resize(4, 6);
+ auto& weight_array_buffer =
+ weight_array.GetMutableBuffer<ArrayDataType::kFloat>();
+ weight_array_buffer.data.resize(buf_size);
+ float* buf_ptr =
+ weight_array.GetMutableBuffer<ArrayDataType::kFloat>().data.data();
+ std::copy(weight_buf.get(), weight_buf.get() + buf_size, buf_ptr);
+
+ {
+ auto* op = new ConvOperator;
+ op->padding.type = PaddingType::kSame;
+ op->inputs = {"inputs", "weights"};
+ input_model_.operators.emplace_back(op);
+ }
+ input_model_.operators.emplace_back(new AddOperator);
+ }
+
Model input_model_;
};
@@ -69,7 +105,8 @@ TEST_F(ExportTest, LoadOperatorsMap) {
details::OperatorsMap operators;
const auto ops_by_type = BuildOperatorByTypeMap();
- details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+ // TODO(ycling): Add a test for allow_eager_ops.
+ details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]);
EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]);
EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "", 1)]);
@@ -81,7 +118,7 @@ TEST_F(ExportTest, Export) {
BuildTestModel();
string result;
- Export(input_model_, true, &result);
+ Export(input_model_, true, false, &result);
auto* model = ::tflite::GetModel(result.data());
@@ -108,6 +145,20 @@ TEST_F(ExportTest, Export) {
EXPECT_THAT(indices, ElementsAre(1, 0, 3, 2));
}
+TEST_F(ExportTest, QuantizeWeights) {
+ // Sanity check for quantize_weights parameter.
+ BuildQuantizableTestModel();
+ string unquantized_result;
+ Export(input_model_, true, /*quantize_weights*/ false, &unquantized_result);
+
+ BuildQuantizableTestModel();
+ string quantized_result;
+ Export(input_model_, true, /*quantize_weights*/ true, &quantized_result);
+
+ // The quantized models should be smaller.
+ EXPECT_LT(quantized_result.size(), unquantized_result.size());
+}
+
// This test is based on a hypothetical scenario that dilation is supported
// only in Conv version 2. So Toco populates version=1 when dilation
// parameters are all 1, and version=2 otherwise.
@@ -203,7 +254,7 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV1) {
details::OperatorsMap operators;
const auto ops_by_type = BuildFakeOperatorByTypeMap();
- details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+ details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(1, operators.size());
EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1)));
@@ -214,7 +265,7 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) {
details::OperatorsMap operators;
const auto ops_by_type = BuildFakeOperatorByTypeMap();
- details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+ details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(1, operators.size());
EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 2)));
@@ -226,7 +277,7 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) {
details::OperatorsMap operators;
const auto ops_by_type = BuildFakeOperatorByTypeMap();
- details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+ details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(2, operators.size());
EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1)));
@@ -239,7 +290,7 @@ TEST_F(VersionedOpExportTest, Export) {
string result;
const auto ops_by_type = BuildFakeOperatorByTypeMap();
- Export(input_model_, true, &result, ops_by_type);
+ Export(input_model_, true, false, &result, ops_by_type);
auto* model = ::tflite::GetModel(result.data());
auto operator_codes = model->operator_codes();
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 75808f2b69..eb0f7c443a 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -769,7 +769,26 @@ class Sum
};
class ReduceMax
- : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
+ : public BuiltinOperator<TensorFlowMaxOperator, ::tflite::ReducerOptions,
+ ::tflite::BuiltinOptions_ReducerOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->keep_dims = options.keep_dims();
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+};
+
+class ReduceMin
+ : public BuiltinOperator<TensorFlowMinOperator, ::tflite::ReducerOptions,
::tflite::BuiltinOptions_ReducerOptions> {
public:
using BuiltinOperator::BuiltinOperator;
@@ -788,7 +807,26 @@ class ReduceMax
};
class ReduceProd
- : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
+ : public BuiltinOperator<TensorFlowProdOperator, ::tflite::ReducerOptions,
+ ::tflite::BuiltinOptions_ReducerOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->keep_dims = options.keep_dims();
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+};
+
+class ReduceAny
+ : public BuiltinOperator<TensorFlowAnyOperator, ::tflite::ReducerOptions,
::tflite::BuiltinOptions_ReducerOptions> {
public:
using BuiltinOperator::BuiltinOperator;
@@ -1091,9 +1129,29 @@ class CTCBeamSearchDecoder
int GetVersion(const Operator& op) const override { return 1; }
};
+class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
+ ::tflite::BuiltinOptions_UnpackOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateUnpackOptions(*builder, op.num, op.axis);
+ }
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->num = options.num();
+ op->axis = options.axis();
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+};
+
class TensorFlowUnsupported : public BaseOperator {
public:
- using BaseOperator::BaseOperator;
+ TensorFlowUnsupported(const string& name, OperatorType type,
+ bool allow_eager_ops)
+ : BaseOperator(name, type), allow_eager_ops_(allow_eager_ops) {}
Options Serialize(const Operator& op,
flatbuffers::FlatBufferBuilder* builder) const override {
@@ -1109,6 +1167,9 @@ class TensorFlowUnsupported : public BaseOperator {
std::unique_ptr<Operator> Deserialize(
const BuiltinOptions* builtin_options,
const CustomOptions* custom_options) const override {
+ // Deserializing Eager ops is not supported yet.
+ // TODO(ycling): Revisit and decide if we should fix the flow for importing
+ // TFLite models with Eager ops.
auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
if (custom_options) {
auto flexbuffer_map =
@@ -1129,6 +1190,16 @@ class TensorFlowUnsupported : public BaseOperator {
return std::unique_ptr<flexbuffers::Builder>();
}
+ if (allow_eager_ops_) {
+ fbb->Vector([&]() {
+ fbb->String(node_def.op());
+ fbb->String(op.tensorflow_node_def);
+ });
+ fbb->Finish();
+ LOG(INFO) << "Writing eager op: " << node_def.op();
+ return std::unique_ptr<flexbuffers::Builder>(fbb.release());
+ }
+
bool has_valid_attr = false;
size_t map_start = fbb->StartMap();
for (const auto& pair : node_def.attr()) {
@@ -1229,11 +1300,15 @@ class TensorFlowUnsupported : public BaseOperator {
// custom ops.
return 1;
}
+
+ private:
+ const bool allow_eager_ops_;
};
namespace {
// Build a vector containing all the known operators.
-std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
+std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
+ bool allow_eager_ops = false) {
std::vector<std::unique_ptr<BaseOperator>> ops;
using tensorflow::MakeUnique;
// Builtin Operators.
@@ -1297,6 +1372,10 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
OperatorType::kReduceProd));
ops.push_back(MakeUnique<ReduceMax>(::tflite::BuiltinOperator_REDUCE_MAX,
OperatorType::kReduceMax));
+ ops.push_back(MakeUnique<ReduceMin>(::tflite::BuiltinOperator_REDUCE_MIN,
+ OperatorType::kReduceMin));
+ ops.push_back(MakeUnique<ReduceAny>(::tflite::BuiltinOperator_REDUCE_ANY,
+ OperatorType::kAny));
ops.push_back(
MakeUnique<ResizeBilinear>(::tflite::BuiltinOperator_RESIZE_BILINEAR,
OperatorType::kResizeBilinear));
@@ -1332,14 +1411,16 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT,
OperatorType::kOneHot));
+ ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK,
+ OperatorType::kUnpack));
// Custom Operators.
ops.push_back(
MakeUnique<DepthToSpace>("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
ops.push_back(MakeUnique<CTCBeamSearchDecoder>(
"CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder));
- ops.push_back(MakeUnique<TensorFlowUnsupported>("TENSORFLOW_UNSUPPORTED",
- OperatorType::kUnsupported));
+ ops.push_back(MakeUnique<TensorFlowUnsupported>(
+ "TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported, allow_eager_ops));
// These operators are supported by Toco, but not by TF Lite, and have no
// attributes.
@@ -1396,6 +1477,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
"LOGICAL_AND", OperatorType::kLogicalAnd));
ops.emplace_back(new SimpleOperator<LogicalNotOperator>(
"LOGICAL_NOT", OperatorType::kLogicalNot));
+ ops.emplace_back(new SimpleOperator<FloorDivOperator>(
+ "FLOOR_DIV", OperatorType::kFloorDiv));
// Element-wise operator
ops.push_back(
MakeUnique<SimpleOperator<SinOperator>>("SIN", OperatorType::kSin));
@@ -1410,10 +1493,12 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
}
} // namespace
-std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap() {
+std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
+ bool allow_eager_ops) {
std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
- std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
+ std::vector<std::unique_ptr<BaseOperator>> ops =
+ BuildOperatorList(allow_eager_ops);
for (auto& op : ops) {
result[op->type()] = std::move(op);
}
@@ -1421,10 +1506,12 @@ std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap() {
return result;
}
-std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap() {
+std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
+ bool allow_eager_ops) {
std::map<string, std::unique_ptr<BaseOperator>> result;
- std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
+ std::vector<std::unique_ptr<BaseOperator>> ops =
+ BuildOperatorList(allow_eager_ops);
for (auto& op : ops) {
result[op->name()] = std::move(op);
}
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h
index d9ea23edf2..702fb28ea6 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.h
+++ b/tensorflow/contrib/lite/toco/tflite/operator.h
@@ -26,11 +26,15 @@ namespace tflite {
class BaseOperator;
// Returns a map containing all known TF Lite Operators, keyed by their names.
-std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap();
+// TODO(ycling): The pattern to propagate parameters (e.g. allow_eager_ops)
+// is ugly here. Consider refactoring.
+std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
+ bool allow_eager_ops = false);
// Returns a map containing all known TF Lite Operators, keyed by the type of
// their tf.mini counterparts.
-std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap();
+std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
+ bool allow_eager_ops = false);
// These are the flatbuffer types for custom and builtin options.
using CustomOptions = flatbuffers::Vector<uint8_t>;
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index fc854461b4..519a3a4e01 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -97,6 +97,16 @@ class OperatorTest : public ::testing::Test {
ASSERT_NE(nullptr, output_toco_op.get());
}
+
+ template <typename T>
+ void CheckReducerOperator(const string& name, OperatorType type) {
+ T op;
+
+ op.keep_dims = false;
+
+ auto output_toco_op = SerializeAndDeserialize(GetOperator(name, type), op);
+ EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims);
+ }
};
TEST_F(OperatorTest, SimpleOperators) {
@@ -133,6 +143,7 @@ TEST_F(OperatorTest, SimpleOperators) {
OperatorType::kLogicalAnd);
CheckSimpleOperator<LogicalNotOperator>("LOGICAL_NOT",
OperatorType::kLogicalNot);
+ CheckSimpleOperator<FloorDivOperator>("FLOOR_DIV", OperatorType::kFloorDiv);
}
TEST_F(OperatorTest, BuiltinAdd) {
@@ -144,13 +155,16 @@ TEST_F(OperatorTest, BuiltinAdd) {
output_toco_op->fused_activation_function);
}
-TEST_F(OperatorTest, BuiltinMean) {
- MeanOperator op;
- op.keep_dims = false;
-
- auto output_toco_op =
- SerializeAndDeserialize(GetOperator("MEAN", OperatorType::kMean), op);
- EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims);
+TEST_F(OperatorTest, BuiltinReducerOps) {
+ CheckReducerOperator<MeanOperator>("MEAN", OperatorType::kMean);
+ CheckReducerOperator<TensorFlowSumOperator>("SUM", OperatorType::kSum);
+ CheckReducerOperator<TensorFlowProdOperator>("REDUCE_PROD",
+ OperatorType::kReduceProd);
+ CheckReducerOperator<TensorFlowMaxOperator>("REDUCE_MAX",
+ OperatorType::kReduceMax);
+ CheckReducerOperator<TensorFlowMinOperator>("REDUCE_MIN",
+ OperatorType::kReduceMin);
+ CheckReducerOperator<TensorFlowAnyOperator>("REDUCE_ANY", OperatorType::kAny);
}
TEST_F(OperatorTest, BuiltinCast) {
@@ -476,6 +490,16 @@ TEST_F(OperatorTest, BuiltinOneHot) {
EXPECT_EQ(op.axis, output_toco_op->axis);
}
+TEST_F(OperatorTest, BuiltinUnpack) {
+ UnpackOperator op;
+ op.num = 5;
+ op.axis = 2;
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator("UNPACK", OperatorType::kUnpack), op);
+ EXPECT_EQ(op.num, output_toco_op->num);
+ EXPECT_EQ(op.axis, output_toco_op->axis);
+}
+
TEST_F(OperatorTest, CustomCTCBeamSearchDecoder) {
CTCBeamSearchDecoderOperator op;
op.beam_width = 3;
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
index c6d0a03452..b6aebc0470 100644
--- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -160,10 +160,18 @@ bool ParseTocoFlagsFromCommandLineFlags(
"Ignored if the output format is not TFLite."),
Flag("quantize_weights", parsed_flags.quantize_weights.bind(),
parsed_flags.quantize_weights.default_value(),
- "Store weights as quantized weights followed by dequantize "
- "operations. Computation is still done in float, but reduces model "
- "size (at the cost of accuracy and latency)."),
- };
+ "Deprecated. Please use --post_training_quantize instead."),
+ Flag("post_training_quantize", parsed_flags.post_training_quantize.bind(),
+ parsed_flags.post_training_quantize.default_value(),
+ "Boolean indicating whether to quantize the weights of the "
+ "converted float model. Model size will be reduced and there will "
+ "be latency improvements (at the cost of accuracy)."),
+ // WARNING: Experimental interface, subject to change
+ Flag("allow_eager_ops", parsed_flags.allow_eager_ops.bind(),
+ parsed_flags.allow_eager_ops.default_value(), ""),
+ // WARNING: Experimental interface, subject to change
+ Flag("force_eager_ops", parsed_flags.force_eager_ops.bind(),
+ parsed_flags.force_eager_ops.default_value(), "")};
bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
if (asked_for_help) {
@@ -257,6 +265,17 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
READ_TOCO_FLAG(dedupe_array_min_size_bytes, FlagRequirement::kNone);
READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone);
READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone);
+ READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone);
+ READ_TOCO_FLAG(allow_eager_ops, FlagRequirement::kNone);
+ READ_TOCO_FLAG(force_eager_ops, FlagRequirement::kNone);
+
+ if (parsed_toco_flags.force_eager_ops.value() &&
+ !parsed_toco_flags.allow_eager_ops.value()) {
+ // TODO(ycling): Consider enforcing `allow_eager_ops` when
+ // `force_eager_ops` is true.
+ LOG(WARNING) << "--force_eager_ops should always be used with "
+ "--allow_eager_ops.";
+ }
// Deprecated flag handling.
if (parsed_toco_flags.input_type.specified()) {
@@ -291,9 +310,19 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
toco_flags->set_inference_input_type(input_type);
}
if (parsed_toco_flags.quantize_weights.value()) {
- QCHECK_NE(toco_flags->inference_type(), IODataType::QUANTIZED_UINT8)
- << "quantize_weights is not supported with inference_type "
- "QUANTIZED_UINT8.";
+ LOG(WARNING)
+ << "--quantize_weights is deprecated. Falling back to "
+ "--post_training_quantize. Please switch --post_training_quantize.";
+ toco_flags->set_post_training_quantize(
+ parsed_toco_flags.quantize_weights.value());
+ }
+ if (toco_flags->post_training_quantize()) {
+ if (toco_flags->inference_type() == IODataType::QUANTIZED_UINT8) {
+ LOG(WARNING)
+ << "--post_training_quantize quantizes a graph of inference_type "
+ "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.";
+ toco_flags->set_inference_type(IODataType::FLOAT);
+ }
}
#undef READ_TOCO_FLAG
diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto
index b4a9870d58..53d60fed05 100644
--- a/tensorflow/contrib/lite/toco/toco_flags.proto
+++ b/tensorflow/contrib/lite/toco/toco_flags.proto
@@ -37,7 +37,7 @@ enum FileFormat {
// of as properties of models, instead describing how models are to be
// processed in the context of the present tooling job.
//
-// Next ID to use: 26.
+// Next ID to use: 29.
message TocoFlags {
// Input file format
optional FileFormat input_format = 1;
@@ -173,6 +173,7 @@ message TocoFlags {
// Store weights as quantized weights followed by dequantize operations.
// Computation is still done in float, but reduces model size (at the cost of
// accuracy and latency).
+ // DEPRECATED: Please use post_training_quantize instead.
optional bool quantize_weights = 20 [default = false];
// Full filepath of folder to dump the graphs at various stages of processing
@@ -183,4 +184,22 @@ message TocoFlags {
// Boolean indicating whether to dump the graph after every graph
// transformation.
optional bool dump_graphviz_include_video = 25;
+
+ // Boolean indicating whether to quantize the weights of the converted float
+ // model. Model size will be reduced and there will be latency improvements
+ // (at the cost of accuracy).
+ optional bool post_training_quantize = 26 [default = false];
+
+ // When enabled, unsupported ops will be converted to TFLite Eager ops.
+ // TODO(ycling): Consider renaming the following 2 flags and not calling them
+ // "Eager".
+ // `allow_eager_ops` should always be used with `allow_custom_ops`.
+ // WARNING: Experimental interface, subject to change
+ optional bool allow_eager_ops = 27 [default = false];
+
+ // When enabled, all TensorFlow ops will be converted to TFLite Eager
+ // ops directly. This will force `allow_eager_ops` to true.
+ // `force_eager_ops` should always be used with `allow_eager_ops`.
+ // WARNING: Experimental interface, subject to change
+ optional bool force_eager_ops = 28 [default = false];
}
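
Expressed against the generated proto API, the documented flag relationships look as follows (a sketch; force_eager_ops implies allow_eager_ops, which should in turn be paired with allow_custom_ops):

    #include "tensorflow/contrib/lite/toco/toco_flags.pb.h"

    // Builds flags for the experimental Eager conversion path with
    // post-training weight quantization enabled.
    toco::TocoFlags MakeEagerFlags() {
      toco::TocoFlags flags;
      flags.set_allow_eager_ops(true);
      flags.set_allow_custom_ops(true);        // Required with allow_eager_ops.
      flags.set_post_training_quantize(true);  // Successor of quantize_weights.
      return flags;
    }
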
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 34130a02b0..a7c17156b1 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -197,6 +197,10 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
toco_flags.has_drop_control_dependency()
? toco_flags.drop_control_dependency()
: (toco_flags.output_format() != TENSORFLOW_GRAPHDEF);
+
+ tf_import_flags.import_all_ops_as_unsupported =
+ toco_flags.force_eager_ops();
+
model = ImportTensorFlowGraphDef(model_flags, tf_import_flags,
input_file_contents);
break;
@@ -281,12 +285,6 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
RunGraphTransformations(model, "general graph transformations",
transformations);
- if (toco_flags.quantize_weights()) {
- // Run the quantize weights transformation after batchnorms have been
- // folded into the weights.
- RunGraphTransformations(model, "quantize weights transformation",
- {new QuantizeWeights});
- }
if (quantize_output) {
if (toco_flags.propagate_fake_quant_num_bits()) {
RunGraphTransformations(model,
@@ -403,9 +401,21 @@ void Export(const TocoFlags& toco_flags, const Model& model,
case TENSORFLOW_GRAPHDEF:
ExportTensorFlowGraphDef(model, output_file_contents);
break;
- case TFLITE:
- toco::tflite::Export(model, allow_custom_ops, output_file_contents);
- break;
+ case TFLITE: {
+ toco::tflite::ExportParams params;
+
+ // Always allow custom ops when eager ops are allowed.
+ if (toco_flags.force_eager_ops() || toco_flags.allow_eager_ops()) {
+ params.allow_eager_ops = true;
+ params.allow_custom_ops = true;
+ } else if (allow_custom_ops) {
+ params.allow_custom_ops = true;
+ }
+
+ params.quantize_weights = toco_flags.post_training_quantize();
+
+ toco::tflite::Export(model, output_file_contents, params);
+ } break;
case GRAPHVIZ_DOT:
DumpGraphviz(model, output_file_contents);
break;
diff --git a/tensorflow/contrib/lite/toco/toco_types.h b/tensorflow/contrib/lite/toco/toco_types.h
index d72a3bd1f3..319f1066cd 100644
--- a/tensorflow/contrib/lite/toco/toco_types.h
+++ b/tensorflow/contrib/lite/toco/toco_types.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_
-#define TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TYPES_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TYPES_H_
#include <string>
#include "tensorflow/core/platform/platform.h"
@@ -42,4 +42,4 @@ using tensorflow::uint8;
} // namespace toco
-#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TYPES_H_
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 2ad2719811..6ab93d9316 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -405,6 +405,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(LogicalNot)
HANDLE_OPERATORTYPENAME_CASE(LogicalOr)
HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder)
+ HANDLE_OPERATORTYPENAME_CASE(Unpack)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
@@ -2278,4 +2279,14 @@ void UndoWeightsShuffling(Model* model) {
}
}
+void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst) {
+ if (src.minmax) {
+ dst->GetOrCreateMinMax() = src.GetMinMax();
+ }
+ if (src.quantization_params) {
+ dst->GetOrCreateQuantizationParams() = src.GetQuantizationParams();
+ }
+ dst->narrow_range = src.narrow_range;
+}
+
} // namespace toco
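As a usage sketch for the new helper, a transformation that rewrites an array might preserve its quantization metadata as follows; `src_name` and `dst_name` are hypothetical, and `GetArray`/`GetOrCreateArray` are assumed from toco's `Model` interface.

```C++
// A minimal sketch: carry min/max, quantization params, and narrow_range
// from one array to another.
void PreserveQuantizationMetadata(toco::Model* model, const string& src_name,
                                  const string& dst_name) {
  const toco::Array& src = model->GetArray(src_name);
  toco::Array& dst = model->GetOrCreateArray(dst_name);
  toco::CopyMinMaxAndQuantizationRelatedFields(src, &dst);
}
```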
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index b99e6111fe..5f4b8cb66a 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -28,6 +28,7 @@ limitations under the License.
#if TOCO_SUPPORT_PORTABLE_PROTOS
#include "third_party/protobuf/include/google/protobuf/text_format.h"
#endif // TOCO_SUPPORT_PORTABLE_PROTOS
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/runtime/types.h"
@@ -139,6 +140,10 @@ bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1);
// - For the remaining indices [0..i0), d0[i0] == 1.
bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1);
+inline ::tflite::RuntimeShape ToRuntimeShape(const Shape& shape) {
+ return ::tflite::RuntimeShape(shape.dimensions_count(), shape.dims().data());
+}
+
bool IsArrayFullyConnectedWeights(const Model& model, const string& name);
// If there is a wildcard dimension (-1), this may return a negative value.
@@ -348,6 +353,9 @@ tensorflow::Status NumElements(const std::vector<T>& shape, U* num_elements) {
// so that the rest of toco doesn't need to know about shuffled weights.
void UndoWeightsShuffling(Model* model);
+// Copies minmax, quantization_params, and narrow_range.
+void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst);
+
} // namespace toco
#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_
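A short sketch of the new `ToRuntimeShape` bridge between toco shapes and TFLite kernel shapes; the initializer-list `Shape` constructor is an assumption about toco's `Shape` type.

```C++
// A minimal sketch, assuming Shape can be built from an initializer list.
toco::Shape shape({1, 224, 224, 3});
::tflite::RuntimeShape runtime_shape = toco::ToRuntimeShape(shape);
// runtime_shape now has 4 dimensions matching the toco shape.
```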
diff --git a/tensorflow/contrib/lite/tools/accuracy/BUILD b/tensorflow/contrib/lite/tools/accuracy/BUILD
new file mode 100644
index 0000000000..1b60d6a60d
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/BUILD
@@ -0,0 +1,328 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "tflite_linkopts")
+load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
+
+common_linkopts = tflite_linkopts() + select({
+ "//conditions:default": [],
+ "//tensorflow:android": [
+ "-pie",
+ "-llog",
+ ],
+})
+
+cc_library(
+ name = "utils",
+ srcs = ["utils.cc"],
+ hdrs = ["utils.h"],
+ copts = tflite_copts(),
+ deps = [
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ ],
+ },
+ ),
+)
+
+tf_cc_test(
+ name = "utils_test",
+ srcs = ["utils_test.cc"],
+ args = [
+ "--test_model_file=$(location //tensorflow/contrib/lite:testdata/multi_add.bin)",
+ ],
+ data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"],
+ linkopts = common_linkopts,
+ linkstatic = 1,
+ tags = [
+ "tflite_not_portable_android",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":utils",
+ "@com_google_googletest//:gtest",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_test_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "run_tflite_model_op",
+ srcs = ["run_tflite_model_op.cc"],
+ copts = tflite_copts(),
+ deps = [
+ ":utils",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:ops",
+ ],
+ },
+ ),
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "android_required_build_flags",
+ srcs = ["android_required_build_flags.cc"],
+ copts = tflite_copts(),
+)
+
+tf_cc_test(
+ name = "run_tflite_model_op_test",
+ srcs = ["run_tflite_model_op_test.cc"],
+ args = [
+ "--test_model_file=$(location //tensorflow/contrib/lite:testdata/multi_add.bin)",
+ ],
+ data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"],
+ linkopts = common_linkopts,
+ linkstatic = 1,
+ tags = [
+ "tflite_not_portable_android",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:scope",
+ ":run_tflite_model_op",
+ ":android_required_build_flags",
+ "@com_google_googletest//:gtest",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_test_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "stage",
+ hdrs = ["stage.h"],
+ copts = tflite_copts(),
+ deps = [
+ "//tensorflow/cc:scope",
+ ],
+)
+
+cc_library(
+ name = "file_reader_stage",
+ srcs = ["file_reader_stage.cc"],
+ hdrs = ["file_reader_stage.h"],
+ deps = [
+ ":stage",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:scope",
+ ],
+)
+
+tf_cc_test(
+ name = "file_reader_stage_test",
+ srcs = ["file_reader_stage_test.cc"],
+ linkopts = common_linkopts,
+ linkstatic = 1,
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":file_reader_stage",
+ "@com_google_googletest//:gtest",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core/kernels:android_whole_file_read_ops",
+ "//tensorflow/core:android_tensorflow_test_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:tensorflow",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "run_tflite_model_stage",
+ srcs = ["run_tflite_model_stage.cc"],
+ hdrs = ["run_tflite_model_stage.h"],
+ copts = tflite_copts(),
+ deps = [
+ ":run_tflite_model_op",
+ ":stage",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:scope",
+ ],
+)
+
+cc_library(
+ name = "accuracy_eval_stage",
+ hdrs = ["accuracy_eval_stage.h"],
+ copts = tflite_copts(),
+ deps = [
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "eval_pipeline",
+ srcs = ["eval_pipeline.cc"],
+ hdrs = ["eval_pipeline.h"],
+ copts = tflite_copts(),
+ deps = [
+ ":accuracy_eval_stage",
+ ":stage",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:core_cpu",
+ ],
+ },
+ ),
+)
+
+tf_cc_test(
+ name = "eval_pipeline_test",
+ srcs = ["eval_pipeline_test.cc"],
+ linkopts = common_linkopts,
+ linkstatic = 1,
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":eval_pipeline",
+ "//tensorflow/cc:cc_ops",
+ "@com_google_googletest//:gtest",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_test_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:tensorflow",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "eval_pipeline_builder",
+ srcs = ["eval_pipeline_builder.cc"],
+ hdrs = ["eval_pipeline_builder.h"],
+ copts = tflite_copts(),
+ deps = [
+ ":eval_pipeline",
+ ":accuracy_eval_stage",
+ ":stage",
+ "@com_google_absl//absl/memory",
+ "//tensorflow/cc:cc_ops",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:tensorflow",
+ ],
+ },
+ ),
+)
+
+tf_cc_test(
+ name = "eval_pipeline_builder_test",
+ srcs = ["eval_pipeline_builder_test.cc"],
+ linkopts = common_linkopts,
+ linkstatic = 1,
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":eval_pipeline_builder",
+ "//tensorflow/cc:cc_ops",
+ "@com_google_googletest//:gtest",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_test_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:tensorflow",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "csv_writer",
+ hdrs = ["csv_writer.h"],
+ copts = tflite_copts(),
+ deps = select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:lib",
+ ],
+ },
+ ),
+)
+
+tflite_portable_test_suite()
diff --git a/tensorflow/contrib/lite/tools/accuracy/README.md b/tensorflow/contrib/lite/tools/accuracy/README.md
new file mode 100644
index 0000000000..8100cd1e8c
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/README.md
@@ -0,0 +1,38 @@
+## TFLite accuracy library
+
+This library provides evaluation pipelines that can be used to evaluate
+accuracy and other metrics of a model. The resulting binary can be run on
+a desktop or on a mobile device.
+
+## Usage
+The tool provides an evaluation pipeline with different stages. Each
+stage outputs a TensorFlow graph.
+A sample usage is shown below.
+
+```C++
+// First build the pipeline.
+EvalPipelineBuilder builder;
+std::unique_ptr<EvalPipeline> eval_pipeline;
+auto status = builder.WithInput("pipeline_input", DT_FLOAT)
+ .WithInputStage(&input_stage)
+ .WithRunModelStage(&run_model_stage)
+ .WithPreprocessingStage(&preprocess_stage)
+ .WithAccuracyEval(&eval)
+ .Build(scope, &eval_pipeline);
+TF_CHECK_OK(status);
+
+// Now run the pipeline with inputs and outputs.
+std::unique_ptr<Session> session(NewSession(SessionOptions()));
+TF_CHECK_OK(eval_pipeline->AttachSession(std::move(session)));
+Tensor input = ... read input for the model ...
+Tensor ground_truth = ... read ground truth for the model ...
+TF_CHECK_OK(eval_pipeline->Run(input, ground_truth));
+```
+For further examples, check the usage in the [imagenet accuracy evaluation binary](ilsvrc/imagenet_model_evaluator.cc).
+
+## Measuring accuracy of published models
+
+### ILSVRC (ImageNet Large Scale Visual Recognition Challenge) classification task
+For measuring accuracy for the [ILSVRC 2012 image classification task](http://www.image-net.org/challenges/LSVRC/2012/), the binary can be built
+using these [instructions](ilsvrc/).
diff --git a/tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h b/tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h
new file mode 100644
index 0000000000..9cb843729a
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ACCURACY_EVAL_STAGE_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ACCURACY_EVAL_STAGE_H_
+
+#include <vector>
+
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+namespace metrics {
+
+// Base class for evaluation stage that evaluates the accuracy of the model.
+// This stage calculates the accuracy metrics given the model outputs and
+// expected ground truth.
+class AccuracyEval {
+ public:
+ AccuracyEval() = default;
+ AccuracyEval(const AccuracyEval&) = delete;
+ AccuracyEval& operator=(const AccuracyEval&) = delete;
+
+ AccuracyEval(const AccuracyEval&&) = delete;
+ AccuracyEval& operator=(const AccuracyEval&&) = delete;
+
+ virtual ~AccuracyEval() = default;
+
+ // Evaluates the accuracy of the model for the given `model_outputs` and the
+ // `ground_truth`.
+ // Derived classes can do additional bookkeeping, calculate aggregate
+ // statistics, etc. for the given model.
+ virtual Status ComputeEval(const std::vector<Tensor>& model_outputs,
+ const Tensor& ground_truth) = 0;
+};
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ACCURACY_EVAL_STAGE_H_
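To make the `AccuracyEval` contract concrete, a minimal sketch of a subclass follows; the class name and match-counting logic are illustrative and not part of this change.

```C++
#include "tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h"
#include "tensorflow/core/lib/core/errors.h"

// A minimal sketch: counts exact scalar matches against the ground truth.
class ExactMatchEval : public tensorflow::metrics::AccuracyEval {
 public:
  tensorflow::Status ComputeEval(
      const std::vector<tensorflow::Tensor>& model_outputs,
      const tensorflow::Tensor& ground_truth) override {
    if (model_outputs.empty()) {
      return tensorflow::errors::InvalidArgument("No model outputs.");
    }
    if (model_outputs[0].scalar<float>()() == ground_truth.scalar<float>()()) {
      ++matches_;
    }
    ++total_;
    return tensorflow::Status::OK();
  }

 private:
  int matches_ = 0;
  int total_ = 0;
};
```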
diff --git a/tensorflow/contrib/lite/tools/accuracy/android_required_build_flags.cc b/tensorflow/contrib/lite/tools/accuracy/android_required_build_flags.cc
new file mode 100644
index 0000000000..7fa8986716
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/android_required_build_flags.cc
@@ -0,0 +1,27 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// TensorFlow on Android requires selective registration to be enabled in order
+// for certain types (e.g. DT_UINT8) to work.
+// The checks below ensure that for Android builds, the right flags are passed
+// to the compiler.
+
+#if defined(__ANDROID__) && (!defined(__ANDROID_TYPES_FULL__) || \
+ !defined(SUPPORT_SELECTIVE_REGISTRATION))
+#error \
+ "Binary needs custom kernel support. For enabling custom kernels on " \
+ "Android, please pass -D__ANDROID_TYPES_FULL__ && " \
+ "-DSUPPORT_SELECTIVE_REGISTRATION for including the kernel in the binary."
+#endif
diff --git a/tensorflow/contrib/lite/tools/accuracy/csv_writer.h b/tensorflow/contrib/lite/tools/accuracy/csv_writer.h
new file mode 100644
index 0000000000..806b0d9418
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/csv_writer.h
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_CSV_WRITER_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_CSV_WRITER_H_
+
+#include <fstream>
+#include <vector>
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace metrics {
+// A simple CSV writer that writes values of the same type for a fixed number
+// of columns. It supports only a very limited subset of the CSV spec and does
+// no escaping.
+// Usage:
+// std::ofstream * output_stream = ...
+// CSVWriter writer({"column1", "column2"}, output_stream);
+// writer.WriteRow({4, 5});
+// writer.Flush(); // flush results immediately.
+class CSVWriter {
+ public:
+ CSVWriter(const std::vector<string>& columns, std::ofstream* output_stream)
+ : num_columns_(columns.size()), output_stream_(output_stream) {
+ TF_CHECK_OK(WriteRow(columns, output_stream_));
+ }
+
+ template <typename T>
+ Status WriteRow(const std::vector<T>& values) {
+ if (values.size() != num_columns_) {
+ return errors::InvalidArgument("Invalid size for row:", values.size(),
+ " expected: ", num_columns_);
+ }
+ return WriteRow(values, output_stream_);
+ }
+
+ void Flush() { output_stream_->flush(); }
+
+ ~CSVWriter() { output_stream_->flush(); }
+
+ private:
+ template <typename T>
+ static Status WriteRow(const std::vector<T>& values,
+ std::ofstream* output_stream) {
+ bool first = true;
+ for (const auto& v : values) {
+ if (!first) {
+ (*output_stream) << ", ";
+ } else {
+ first = false;
+ }
+ (*output_stream) << v;
+ }
+ (*output_stream) << "\n";
+ if (!output_stream->good()) {
+ return errors::Internal("Writing to stream failed.");
+ }
+ return Status::OK();
+ }
+ const size_t num_columns_;
+ std::ofstream* output_stream_;
+};
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_CSV_WRITER_H_
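A self-contained usage sketch for `CSVWriter`; the output file name and column values are hypothetical.

```C++
#include <fstream>
#include <vector>

#include "tensorflow/contrib/lite/tools/accuracy/csv_writer.h"

// A minimal sketch: write a header row and one row of accuracy values.
void WriteAccuracies() {
  std::ofstream output_stream("accuracies.csv");  // hypothetical output path
  tensorflow::metrics::CSVWriter writer({"top_1", "top_5"}, &output_stream);
  TF_CHECK_OK(writer.WriteRow(std::vector<float>{0.71f, 0.89f}));
  writer.Flush();  // flush results immediately
}
```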
diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.cc b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.cc
new file mode 100644
index 0000000000..a03aba6a26
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.cc
@@ -0,0 +1,39 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h"
+
+namespace tensorflow {
+namespace metrics {
+
+Status EvalPipeline::AttachSession(std::unique_ptr<Session> session) {
+ session_ = std::move(session);
+ TF_RETURN_IF_ERROR(session_->Create(model_graph_));
+ return Status::OK();
+}
+
+Status EvalPipeline::Run(const Tensor& input, const Tensor& ground_truth) {
+ if (session_ == nullptr) {
+ return errors::Internal("No session is associated with the graph.");
+ }
+ std::vector<Tensor> outputs;
+ TF_RETURN_IF_ERROR(session_->Run({{params_.model_input_node_name, input}},
+ {params_.model_output_node_name}, {},
+ &outputs));
+ TF_RETURN_IF_ERROR(eval_->ComputeEval(outputs, ground_truth));
+ return Status::OK();
+}
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h
new file mode 100644
index 0000000000..c9cfc86613
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h
@@ -0,0 +1,87 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_H_
+
+#include <string>
+
+#include "tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h"
+#include "tensorflow/contrib/lite/tools/accuracy/stage.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace metrics {
+
+// Pipeline for evaluating a model.
+// Runs the graph and passes the output of the graph to
+// the provided instance of AccuracyEval.
+// Example usage:
+// AccuracyEval *eval;
+// GraphDef graph_def;
+// ... populate graph_def...
+//
+// EvalPipeline eval_pipeline(&graph_def,
+// {.model_input_node_name = "model_input",
+// .model_output_node_name = "model_output"},
+// eval);
+// std::unique_ptr<Session> session(NewSession(SessionOptions()));
+// TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session)));
+// Tensor input = ... read input for the model ...
+// Tensor ground_truth = ... read ground truth for the model ...
+// TF_CHECK_OK(eval_pipeline.Run(input, ground_truth));
+//
+class EvalPipeline {
+ public:
+ struct Params {
+ string model_input_node_name;
+ string model_output_node_name;
+ };
+
+ // Creates a new `EvalPipeline` object. Ownership of `accuracy_eval` is
+ // retained by the caller; the `accuracy_eval` instance must outlive this
+ // pipeline instance.
+ EvalPipeline(const GraphDef& graph, const Params& params,
+ AccuracyEval* accuracy_eval)
+ : model_graph_(graph),
+ params_(params),
+ eval_(accuracy_eval),
+ session_(nullptr) {}
+
+ EvalPipeline(const EvalPipeline&) = delete;
+ EvalPipeline& operator=(const EvalPipeline&) = delete;
+
+ EvalPipeline(const EvalPipeline&&) = delete;
+ EvalPipeline& operator=(const EvalPipeline&&) = delete;
+
+ // Attaches the given session to this instance of pipeline.
+ // The provided session object will be reused for subsequent calls to
+ // EvalPipeline::Run.
+ Status AttachSession(std::unique_ptr<Session> session);
+
+ // Runs the model by feeding `input` and then passes the output of the model
+ // along with provided `ground_truth` to the AccuracyEval instance by calling
+ // AccuracyEval::ComputeEval.
+ Status Run(const Tensor& input, const Tensor& ground_truth);
+
+ private:
+ GraphDef model_graph_;
+ Params params_;
+ AccuracyEval* eval_;
+ std::unique_ptr<Session> session_;
+};
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.cc b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.cc
new file mode 100644
index 0000000000..2e16437e15
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.cc
@@ -0,0 +1,100 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h"
+
+#include "absl/memory/memory.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+
+namespace tensorflow {
+namespace metrics {
+
+EvalPipelineBuilder& EvalPipelineBuilder::WithInputStage(Stage* input_stage) {
+ input_stage_ = input_stage;
+ return *this;
+}
+
+EvalPipelineBuilder& EvalPipelineBuilder::WithPreprocessingStage(
+ Stage* preprocessing_stage) {
+ preprocessing_stage_ = preprocessing_stage;
+ return *this;
+}
+
+EvalPipelineBuilder& EvalPipelineBuilder::WithRunModelStage(
+ Stage* run_model_stage) {
+ run_model_stage_ = run_model_stage;
+ return *this;
+}
+
+EvalPipelineBuilder& EvalPipelineBuilder::WithAccuracyEval(
+ AccuracyEval* accuracy_eval) {
+ accuracy_eval_ = accuracy_eval;
+ return *this;
+}
+
+EvalPipelineBuilder& EvalPipelineBuilder::WithInput(const string& input_name,
+ DataType input_type) {
+ input_name_ = input_name;
+ input_type_ = input_type;
+ return *this;
+}
+
+Status EvalPipelineBuilder::Build(
+ const Scope& scope, std::unique_ptr<EvalPipeline>* eval_pipeline) {
+ if (input_stage_ == nullptr) {
+ return errors::InvalidArgument("Input stage is null.");
+ }
+ if (preprocessing_stage_ == nullptr) {
+ return errors::InvalidArgument("Preprocessing stage is null.");
+ }
+ if (run_model_stage_ == nullptr) {
+ return errors::InvalidArgument("Run model stage is null.");
+ }
+ if (accuracy_eval_ == nullptr) {
+ return errors::InvalidArgument("accuracy_eval is null.");
+ }
+ if (input_name_.empty()) {
+ return errors::InvalidArgument("input name is not set.");
+ }
+ if (input_type_ == DT_INVALID) {
+ return errors::InvalidArgument("input type is not set.");
+ }
+
+ auto input_placeholder =
+ ops::Placeholder(scope.WithOpName(input_name_), input_type_);
+ TF_RETURN_IF_ERROR(scope.status());
+
+ input_stage_->AddToGraph(scope, input_placeholder);
+ TF_RETURN_IF_ERROR(scope.status());
+
+ preprocessing_stage_->AddToGraph(scope, input_stage_->Output());
+ TF_RETURN_IF_ERROR(scope.status());
+
+ run_model_stage_->AddToGraph(scope, preprocessing_stage_->Output());
+ TF_RETURN_IF_ERROR(scope.status());
+
+ GraphDef graph_def;
+ TF_RETURN_IF_ERROR(scope.ToGraphDef(&graph_def));
+ EvalPipeline::Params params;
+ params.model_input_node_name = input_name_;
+ params.model_output_node_name = run_model_stage_->output_name();
+ *eval_pipeline =
+ absl::make_unique<EvalPipeline>(graph_def, params, accuracy_eval_);
+
+ return Status::OK();
+}
+
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h
new file mode 100644
index 0000000000..692db022f8
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h
@@ -0,0 +1,99 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_BUILDER_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_BUILDER_H_
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h"
+#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h"
+#include "tensorflow/contrib/lite/tools/accuracy/stage.h"
+
+namespace tensorflow {
+namespace metrics {
+
+// A builder to simplify construction of an `EvalPipeline` instance.
+// The `Build` method creates an |EvalPipeline| with the following structure:
+// |input| -> |input_stage|
+// |--> |preprocessing_stage|
+// |--> |run_model_stage| -> |accuracy_eval_stage|.
+// The stages are chained in the order shown above. Any missing stage results
+// in an error. Ownership of the stage objects is retained by the caller, and
+// the stage objects must exist until the |Build| method is called.
+//
+// Currently only single inputs are supported.
+//
+// Example Usage:
+// EvalPipelineBuilder builder;
+// std::unique_ptr<EvalPipeline> eval_pipeline;
+// auto status = builder.WithInput("pipeline_input", DT_FLOAT)
+// .WithInputStage(&input_stage)
+// .WithRunModelStage(&run_model_stage)
+// .WithPreprocessingStage(&preprocess_stage)
+// .WithAccuracyEval(&eval)
+// .Build(scope, &eval_pipeline);
+// TF_CHECK_OK(status);
+class EvalPipelineBuilder {
+ public:
+ EvalPipelineBuilder() = default;
+ EvalPipelineBuilder(const EvalPipelineBuilder&) = delete;
+ EvalPipeline& operator=(const EvalPipelineBuilder&) = delete;
+
+ EvalPipelineBuilder(const EvalPipelineBuilder&&) = delete;
+ EvalPipeline& operator=(const EvalPipelineBuilder&&) = delete;
+
+ // Sets the input stage for the pipeline.
+ // The input stage converts the input (e.g. a filename) into an appropriate
+ // format that can be consumed by the preprocessing stage.
+ EvalPipelineBuilder& WithInputStage(Stage* input_stage);
+
+ // Sets the preprocessing stage for the pipeline.
+ // Preprocessing stage converts the input into a format that can be used to
+ // run the model.
+ EvalPipelineBuilder& WithPreprocessingStage(Stage* preprocessing_stage);
+
+ // Sets the run model stage for the pipeline.
+ // This stage receives the preprocessed input, and its output is fed to the
+ // accuracy eval stage.
+ EvalPipelineBuilder& WithRunModelStage(Stage* run_model_stage);
+
+ // Sets the accuracy eval for the pipeline.
+ // Results of evaluating the pipeline are fed to the `accuracy_eval` instance.
+ EvalPipelineBuilder& WithAccuracyEval(AccuracyEval* accuracy_eval);
+
+ // Sets the name and type of input for the pipeline.
+ // TODO(shashishekhar): Support multiple inputs for the pipeline, use a vector
+ // here.
+ EvalPipelineBuilder& WithInput(const string& input_name, DataType input_type);
+
+ // Builds the pipeline and assigns the pipeline to `eval_pipeline`.
+ // If the pipeline creation fails `eval_pipeline` is untouched.
+ Status Build(const Scope& scope,
+ std::unique_ptr<EvalPipeline>* eval_pipeline);
+
+ private:
+ Stage* input_stage_ = nullptr;
+ Stage* preprocessing_stage_ = nullptr;
+ Stage* run_model_stage_ = nullptr;
+ AccuracyEval* accuracy_eval_ = nullptr;
+ string input_name_;
+ DataType input_type_ = DT_INVALID;
+};
+
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_BUILDER_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder_test.cc b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder_test.cc
new file mode 100644
index 0000000000..2d41929b79
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder_test.cc
@@ -0,0 +1,229 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h"
+#include <gtest/gtest.h>
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace metrics {
+namespace {
+
+class IdentityStage : public Stage {
+ public:
+ IdentityStage(const string& name, const string& output)
+ : name_(name), output_(output) {}
+
+ void AddToGraph(const Scope& scope, const Input& input) override {
+ called_count_++;
+ inputs_.push_back(input.node()->name());
+ stage_output_ = ops::Identity(scope.WithOpName(output_), input);
+ }
+
+ string name() const override { return name_; }
+ string output_name() const override { return output_; }
+
+ int times_called() const { return called_count_; }
+
+ const std::vector<string> input_params() { return inputs_; }
+
+ private:
+ string name_;
+ string output_;
+ int called_count_ = 0;
+ std::vector<string> inputs_;
+};
+
+class FailingStage : public Stage {
+ public:
+ FailingStage(const string& name, const string& output)
+ : name_(name), output_(output) {}
+
+ void AddToGraph(const Scope& scope, const Input& input) override {
+ called_count_++;
+ scope.UpdateStatus(errors::Internal("Stage failed:", name_));
+ }
+
+ string name() const override { return name_; }
+ string output_name() const override { return output_; }
+
+ int times_called() const { return called_count_; }
+
+ private:
+ string name_;
+ string output_;
+ int called_count_ = 0;
+};
+
+class SimpleAccuracyEval : public AccuracyEval {
+ public:
+ SimpleAccuracyEval() {}
+
+ Status ComputeEval(const std::vector<Tensor>& model_outputs,
+ const Tensor& ground_truth) override {
+ return Status::OK();
+ }
+};
+
+TEST(EvalPipelineBuilder, MissingPipelineStages) {
+ IdentityStage input_stage("input_stage", "input_stage_out");
+ IdentityStage run_model_stage("run_model", "run_model_out");
+ IdentityStage preprocess_stage("preprocess_stage", "preprocess_stage_out");
+ const string pipeline_input = "pipeline_input";
+
+ SimpleAccuracyEval eval;
+
+ Scope scope = Scope::NewRootScope();
+ std::unique_ptr<EvalPipeline> eval_pipeline;
+ EvalPipelineBuilder builder;
+ auto status =
+ builder.WithInputStage(&input_stage).Build(scope, &eval_pipeline);
+ EXPECT_FALSE(status.ok());
+ EXPECT_FALSE(eval_pipeline);
+
+ status =
+ builder.WithRunModelStage(&run_model_stage).Build(scope, &eval_pipeline);
+ EXPECT_FALSE(status.ok());
+ EXPECT_FALSE(eval_pipeline);
+
+ status = builder.WithPreprocessingStage(&preprocess_stage)
+ .Build(scope, &eval_pipeline);
+ EXPECT_FALSE(status.ok());
+ EXPECT_FALSE(eval_pipeline);
+
+ status =
+ builder.WithInput(pipeline_input, DT_FLOAT).Build(scope, &eval_pipeline);
+ EXPECT_FALSE(status.ok());
+ EXPECT_FALSE(eval_pipeline);
+
+ status = builder.WithAccuracyEval(&eval).Build(scope, &eval_pipeline);
+ TF_CHECK_OK(status);
+ EXPECT_TRUE(eval_pipeline);
+}
+
+TEST(EvalPipeline, InputStageFailure) {
+ FailingStage input_stage("input_stage", "input_stage_out");
+ IdentityStage run_model_stage("run_model", "run_model_out");
+ IdentityStage preprocess_stage("preprocess_stage", "preprocess_stage_out");
+ const string pipeline_input = "pipeline_input";
+
+ SimpleAccuracyEval eval;
+
+ Scope scope = Scope::NewRootScope();
+ std::unique_ptr<EvalPipeline> eval_pipeline;
+ EvalPipelineBuilder builder;
+ auto status = builder.WithInputStage(&input_stage)
+ .WithRunModelStage(&run_model_stage)
+ .WithPreprocessingStage(&preprocess_stage)
+ .WithInput(pipeline_input, DT_FLOAT)
+ .WithAccuracyEval(&eval)
+ .Build(scope, &eval_pipeline);
+
+ EXPECT_FALSE(scope.status().ok());
+ // None of the other stages would have been called.
+ EXPECT_EQ(1, input_stage.times_called());
+ EXPECT_EQ(0, preprocess_stage.times_called());
+ EXPECT_EQ(0, run_model_stage.times_called());
+}
+
+TEST(EvalPipeline, PreprocessingFailure) {
+ IdentityStage input_stage("input_stage", "input_stage_out");
+ FailingStage preprocess_stage("preprocess_stage", "preprocess_stage_out");
+ IdentityStage run_model_stage("run_model", "run_model_out");
+ const string pipeline_input = "pipeline_input";
+
+ SimpleAccuracyEval eval;
+
+ Scope scope = Scope::NewRootScope();
+ std::unique_ptr<EvalPipeline> eval_pipeline;
+ EvalPipelineBuilder builder;
+ auto status = builder.WithInputStage(&input_stage)
+ .WithRunModelStage(&run_model_stage)
+ .WithPreprocessingStage(&preprocess_stage)
+ .WithInput(pipeline_input, DT_FLOAT)
+ .WithAccuracyEval(&eval)
+ .Build(scope, &eval_pipeline);
+
+ EXPECT_FALSE(status.ok());
+ // Stages after the failing preprocessing stage are not called.
+ EXPECT_EQ(1, input_stage.times_called());
+ EXPECT_EQ(1, preprocess_stage.times_called());
+ EXPECT_EQ(0, run_model_stage.times_called());
+}
+
+TEST(EvalPipeline, GraphEvalFailure) {
+ IdentityStage input_stage("input_stage", "input_stage_out");
+ IdentityStage preprocess_stage("preprocess_stage", "preprocess_stage_out");
+ FailingStage run_model_stage("run_model", "run_model_out");
+ const string pipeline_input = "pipeline_input";
+
+ SimpleAccuracyEval eval;
+
+ Scope scope = Scope::NewRootScope();
+ std::unique_ptr<EvalPipeline> eval_pipeline;
+ EvalPipelineBuilder builder;
+ auto status = builder.WithInputStage(&input_stage)
+ .WithRunModelStage(&run_model_stage)
+ .WithPreprocessingStage(&preprocess_stage)
+ .WithInput(pipeline_input, DT_FLOAT)
+ .WithAccuracyEval(&eval)
+ .Build(scope, &eval_pipeline);
+
+ EXPECT_FALSE(status.ok());
+ // All stages up to and including the failing run-model stage are called.
+ EXPECT_EQ(1, input_stage.times_called());
+ EXPECT_EQ(1, preprocess_stage.times_called());
+ EXPECT_EQ(1, run_model_stage.times_called());
+}
+
+TEST(EvalPipeline, PipelineHasCorrectSequence) {
+ IdentityStage input_stage("input_stage", "input_stage_out");
+ IdentityStage preprocess_stage("preprocess_stage", "preprocess_stage_out");
+ IdentityStage run_model_stage("run_model", "run_model_out");
+ const string pipeline_input = "pipeline_input";
+
+ SimpleAccuracyEval eval;
+
+ Scope scope = Scope::NewRootScope();
+ std::unique_ptr<EvalPipeline> eval_pipeline;
+ EvalPipelineBuilder builder;
+ auto status = builder.WithInputStage(&input_stage)
+ .WithRunModelStage(&run_model_stage)
+ .WithPreprocessingStage(&preprocess_stage)
+ .WithInput(pipeline_input, DT_FLOAT)
+ .WithAccuracyEval(&eval)
+ .Build(scope, &eval_pipeline);
+ TF_CHECK_OK(status);
+
+ ASSERT_EQ(1, input_stage.times_called());
+ ASSERT_EQ(1, run_model_stage.times_called());
+ ASSERT_EQ(1, preprocess_stage.times_called());
+
+ EXPECT_EQ(pipeline_input, input_stage.input_params()[0]);
+ EXPECT_EQ(input_stage.output_name(), preprocess_stage.input_params()[0]);
+ EXPECT_EQ(preprocess_stage.output_name(), run_model_stage.input_params()[0]);
+}
+
+} // namespace
+
+} // namespace metrics
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_test.cc b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_test.cc
new file mode 100644
index 0000000000..ea0f6e19df
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_test.cc
@@ -0,0 +1,133 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h"
+#include <gtest/gtest.h>
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace metrics {
+namespace {
+
+Tensor CreateFloatTensor(float value) {
+ Tensor tensor(DT_FLOAT, TensorShape({}));
+ tensor.scalar<float>()() = value;
+ return tensor;
+}
+
+class NoOpAccuracyEval : public AccuracyEval {
+ public:
+ explicit NoOpAccuracyEval(const Status& status_to_return)
+ : status_to_return_(status_to_return) {}
+
+ Status ComputeEval(const std::vector<Tensor>& model_outputs,
+ const Tensor& ground_truth) override {
+ model_outputs_ = model_outputs;
+ ground_truth_ = ground_truth;
+ was_called_ = true;
+ return status_to_return_;
+ }
+
+ bool WasCalled() { return was_called_; }
+ std::vector<Tensor> model_outputs() { return model_outputs_; }
+ Tensor ground_truth() { return ground_truth_; }
+
+ private:
+ std::vector<Tensor> model_outputs_;
+ Tensor ground_truth_;
+ Status status_to_return_;
+ bool was_called_ = false;
+};
+
+TEST(EvalPipeline, AccuracyEvalIsCalled) {
+ Scope scope = Scope::NewRootScope();
+ // A graph that adds 1 to input.
+ auto input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT);
+ auto add_node = ops::Add(scope.WithOpName("output"), input, 1.0f);
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ EvalPipeline::Params params;
+ params.model_input_node_name = "input";
+ params.model_output_node_name = "output";
+ NoOpAccuracyEval accuracy_eval(Status::OK());
+
+ EvalPipeline eval_pipeline(graph_def, params, &accuracy_eval);
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session)));
+ TF_CHECK_OK(eval_pipeline.Run(CreateFloatTensor(5), CreateFloatTensor(27)));
+
+ EXPECT_TRUE(accuracy_eval.WasCalled());
+ auto outputs = accuracy_eval.model_outputs();
+ ASSERT_EQ(1, outputs.size());
+ EXPECT_EQ(6.0f, outputs[0].scalar<float>()());
+ // Ground truth is unchanged.
+ EXPECT_EQ(27, accuracy_eval.ground_truth().scalar<float>()());
+}
+
+TEST(EvalPipeline, EvalIsNotCalledOnGraphRunFailure) {
+ Scope scope = Scope::NewRootScope();
+ // A graph that adds 1 to input.
+ auto input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT);
+ auto add_node = ops::Add(scope.WithOpName("output"), input, 1.0f);
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ EvalPipeline::Params params;
+ params.model_input_node_name = "input";
+ params.model_output_node_name = "output";
+ NoOpAccuracyEval accuracy_eval(Status::OK());
+
+ EvalPipeline eval_pipeline(graph_def, params, &accuracy_eval);
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session)));
+
+ // Pass a string tensor instead of a float tensor.
+ Tensor string_tensor(DT_STRING, TensorShape{});
+ auto status = eval_pipeline.Run(string_tensor, CreateFloatTensor(27));
+ EXPECT_FALSE(accuracy_eval.WasCalled());
+ EXPECT_FALSE(status.ok());
+}
+
+TEST(EvalPipeline, AccuracyEvalFailureResultsInFailure) {
+ Scope scope = Scope::NewRootScope();
+ // A graph that adds 1 to input.
+ auto input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT);
+ auto add_node = ops::Add(scope.WithOpName("output"), input, 1.0f);
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ EvalPipeline::Params params;
+ params.model_input_node_name = "input";
+ params.model_output_node_name = "output";
+ NoOpAccuracyEval accuracy_eval(errors::Internal("accuracy_fail"));
+
+ EvalPipeline eval_pipeline(graph_def, params, &accuracy_eval);
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session)));
+ auto status = eval_pipeline.Run(CreateFloatTensor(5), CreateFloatTensor(27));
+
+ EXPECT_TRUE(accuracy_eval.WasCalled());
+ EXPECT_FALSE(status.ok());
+}
+
+} // namespace
+
+} // namespace metrics
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/delegates/eager/constants.h b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.cc
index 7ed6ab7552..61bed369f8 100644
--- a/tensorflow/contrib/lite/delegates/eager/constants.h
+++ b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.cc
@@ -12,18 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_CONSTANTS_H_
-#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_CONSTANTS_H_
-namespace tflite {
-namespace eager {
+#include "tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h"
-// The prefix of Eager op custom code.
-// This will be matched agains the `custom_code` field in `OperatorCode`
-// Flatbuffer Table.
-constexpr char kCustomCodePrefix[] = "Eager";
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
-} // namespace eager
-} // namespace tflite
-
-#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_CONSTANTS_H_
+namespace tensorflow {
+namespace metrics {
+void FileReaderStage::AddToGraph(const Scope& scope, const Input& input) {
+ if (!scope.ok()) return;
+ Scope s = scope.WithOpName(name());
+ this->stage_output_ = ops::ReadFile(s.WithOpName(output_name()), input);
+}
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h
new file mode 100644
index 0000000000..18db5837c1
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_FILE_READER_STAGE_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_FILE_READER_STAGE_H_
+
+#include <string>
+
+#include "tensorflow/contrib/lite/tools/accuracy/stage.h"
+
+namespace tensorflow {
+namespace metrics {
+// A stage for reading a file into |string|.
+// Inputs: a string tensor: |file_name|.
+// Outputs: a string tensor: contents of |file_name|.
+class FileReaderStage : public Stage {
+ public:
+ string name() const override { return "stage_filereader"; }
+ string output_name() const override { return "stage_filereader_output"; }
+
+ void AddToGraph(const Scope& scope, const Input& input) override;
+};
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_FILE_READER_STAGE_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/file_reader_stage_test.cc b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage_test.cc
new file mode 100644
index 0000000000..a75f99187d
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage_test.cc
@@ -0,0 +1,110 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdio>
+#include <fstream>
+#include <memory>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace metrics {
+namespace {
+
+class TempFile {
+ public:
+ TempFile() {
+ string file_path;
+ if (Env::Default()->LocalTempFilename(&file_path)) {
+ file_path_ = file_path;
+ created_ = true;
+ }
+ }
+
+ string filepath() { return file_path_; }
+ bool CreateFileWithContents(const std::string& contents) {
+ if (!created_) {
+ return false;
+ }
+ std::fstream file(file_path_, std::ios_base::out);
+ if (file) {
+ file << contents;
+ }
+ return file.good();
+ }
+
+ ~TempFile() {
+ if (created_) {
+ std::remove(file_path_.c_str());
+ }
+ }
+
+ private:
+ bool created_ = false;
+ string file_path_;
+};
+
+TEST(FileReaderStageTest, FileIsRead) {
+ TempFile file;
+ const string kFileContents = "Hello world.";
+ ASSERT_TRUE(file.CreateFileWithContents(kFileContents));
+ Scope scope = Scope::NewRootScope();
+ FileReaderStage reader_stage;
+ reader_stage.AddToGraph(scope, file.filepath());
+ TF_CHECK_OK(scope.status());
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(session->Create(graph_def));
+ std::vector<Tensor> outputs;
+ auto run_status =
+ session->Run({}, /*inputs*/
+ {reader_stage.output_name()}, {}, /*target node names */
+ &outputs);
+ TF_CHECK_OK(run_status);
+ EXPECT_EQ(1, outputs.size());
+ string contents = outputs[0].scalar<string>()();
+ EXPECT_EQ(kFileContents, contents);
+}
+
+TEST(FileReaderStageTest, InvalidFile) {
+ Scope scope = Scope::NewRootScope();
+ FileReaderStage reader_stage;
+ reader_stage.AddToGraph(scope, string("non_existent_file"));
+ TF_CHECK_OK(scope.status());
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(session->Create(graph_def));
+ std::vector<Tensor> outputs;
+ auto run_status =
+ session->Run({}, /*inputs*/
+ {reader_stage.output_name()}, {}, /*target node names */
+ &outputs);
+ EXPECT_FALSE(run_status.ok());
+}
+
+} // namespace
+
+} // namespace metrics
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD
new file mode 100644
index 0000000000..98e2835b2e
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD
@@ -0,0 +1,182 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "tflite_linkopts")
+load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
+
+common_linkopts = tflite_linkopts() + select({
+ "//conditions:default": [],
+ "//tensorflow:android": [
+ "-pie",
+ "-llog",
+ ],
+})
+
+cc_library(
+ name = "inception_preprocessing",
+ srcs = ["inception_preprocessing.cc"],
+ hdrs = ["inception_preprocessing.h"],
+ copts = tflite_copts(),
+ deps = [
+ "//tensorflow/contrib/lite/tools/accuracy:android_required_build_flags",
+ "//tensorflow/contrib/lite/tools/accuracy:stage",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:scope",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core/kernels:android_tensorflow_image_op",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:ops",
+ ],
+ },
+ ),
+)
+
+tf_cc_test(
+ name = "inception_preprocessing_test",
+ srcs = ["inception_preprocessing_test.cc"],
+ args = [
+ "--test_image=$(location :testdata/grace_hopper.jpg)",
+ ],
+ data = [":testdata/grace_hopper.jpg"],
+ linkopts = common_linkopts,
+ linkstatic = 1,
+ tags = [
+ "no_oss", # b/114307765
+ "tflite_not_portable_android",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":inception_preprocessing",
+ "//tensorflow/contrib/lite/tools/accuracy:android_required_build_flags",
+ "@com_google_googletest//:gtest",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_test_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "imagenet_topk_eval",
+ srcs = ["imagenet_topk_eval.cc"],
+ hdrs = ["imagenet_topk_eval.h"],
+ copts = tflite_copts(),
+ deps = [
+ "//tensorflow/contrib/lite/tools/accuracy:accuracy_eval_stage",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+ },
+ ),
+)
+
+tf_cc_test(
+ name = "imagenet_topk_eval_test",
+ srcs = ["imagenet_topk_eval_test.cc"],
+ linkopts = common_linkopts,
+ linkstatic = 1,
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":imagenet_topk_eval",
+ "@com_google_googletest//:gtest",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_test_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "imagenet_model_evaluator",
+ srcs = ["imagenet_model_evaluator.cc"],
+ hdrs = ["imagenet_model_evaluator.h"],
+ copts = tflite_copts(),
+ deps = [
+ ":imagenet_topk_eval",
+ ":inception_preprocessing",
+ "//tensorflow/contrib/lite/tools/accuracy:android_required_build_flags",
+ "//tensorflow/contrib/lite/tools/accuracy:eval_pipeline",
+ "//tensorflow/contrib/lite/tools/accuracy:eval_pipeline_builder",
+ "//tensorflow/contrib/lite/tools/accuracy:file_reader_stage",
+ "//tensorflow/contrib/lite/tools/accuracy:run_tflite_model_stage",
+ "//tensorflow/contrib/lite/tools/accuracy:utils",
+ "@com_google_absl//absl/memory",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:scope",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core/kernels:android_whole_file_read_ops",
+ "//tensorflow/core/kernels:android_tensorflow_image_op",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:core_cpu",
+ ],
+ },
+ ),
+)
+
+tf_cc_binary(
+ name = "imagenet_accuracy_eval",
+ srcs = ["imagenet_accuracy_eval.cc"],
+ copts = tflite_copts(),
+ linkopts = common_linkopts,
+ deps = [
+ ":imagenet_model_evaluator",
+ ":imagenet_topk_eval",
+ "@com_google_absl//absl/memory",
+ "//tensorflow/contrib/lite/tools/accuracy:android_required_build_flags",
+ "//tensorflow/contrib/lite/tools/accuracy:csv_writer",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:lib",
+ "//tensorflow/core:framework_internal",
+ ],
+ },
+ ),
+)
+
+tflite_portable_test_suite()
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md
new file mode 100644
index 0000000000..362ea3ac34
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md
@@ -0,0 +1,146 @@
+## Accuracy evaluation for ILSVRC 2012 (ImageNet Large Scale Visual Recognition Challenge) image classification task
+
+This binary evaluates the accuracy of TFLite models trained for the [ILSVRC 2012 image classification task](http://www.image-net.org/challenges/LSVRC/2012/).
+It takes the path to validation images and labels as inputs, and outputs the accuracy after running the TFLite model on the validation set.
+
+To run the binary, download the ILSVRC 2012 devkit (see [instructions](#downloading-ilsvrc)) and run the [`generate_validation_labels` script](#ground-truth-label-generation) to generate the ground truth labels.
+
+## Parameters
+The binary takes the following parameters:
+
+* `model_file` : `string` \
+ Path to the TFLite model file.
+
+* `ground_truth_images_path`: `string` \
+ The path to the directory containing ground truth images.
+
+* `ground_truth_labels`: `string` \
+ Path to the ground truth labels file. This file should contain the same number of labels as the number of images in the ground truth directory. The labels are assumed to be in the
+ same order as the sorted filenames of the images. See the [ground truth label generation](#ground-truth-label-generation)
+ section for more information about how to generate labels for images.
+
+* `model_output_labels`: `string` \
+ Path to the file containing the labels used to interpret the output of
+ the model. E.g. in the case of mobilenets, this is the path to
+ `mobilenet_labels.txt`, where each label is in the same order as the
+ model's 1001-dimensional output tensor.
+
+* `output_path`: `string` \
+ Path to the output file. The output is a CSV file where each row contains the cumulative top-1 through top-10 accuracies after processing the images in sorted order: the first line is the accuracy after processing the first image, the second line is the accuracy after processing the first two images, and the last line is the accuracy after processing the entire validation set. A hypothetical excerpt is shown after the parameter list.
+
+and the following optional parameters:
+
+* `blacklist_file_path`: `string` \
+ Path to the blacklist file. This file contains the indices of images that are excluded from evaluation; 1762 images are blacklisted in the ILSVRC dataset. For details, please refer to the readme.txt of the ILSVRC2014 devkit.
+
+* `num_images`: `int` (default=0) \
+ The number of images to process. If 0, all images in the directory are processed; otherwise only `num_images` images are processed.
+
+* `num_threads`: `int` (default=4) \
+ The number of threads to use for evaluation.
+
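+As an illustration, a hypothetical excerpt of the first three rows of the output file, assuming the default of 10 ranks and with made-up numbers, could look like this (the exact formatting depends on the CSV writer):
+
+```
+100.000,100.000,100.000,100.000,100.000,100.000,100.000,100.000,100.000,100.000
+50.000,100.000,100.000,100.000,100.000,100.000,100.000,100.000,100.000,100.000
+66.667,100.000,100.000,100.000,100.000,100.000,100.000,100.000,100.000,100.000
+```
+
+Here the first image was classified correctly at rank 1, while the second image's correct label only appeared at rank 2, so the cumulative top-1 accuracy drops to 50% after two images and recovers to 66.667% once the third image is classified correctly.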
+
+## Downloading ILSVRC
+In order to use this tool to run evaluation on the full 50K ImageNet dataset,
+download the dataset from http://image-net.org/request.
+
+## Ground truth label generation
+The ILSVRC 2012 devkit `validation_ground_truth.txt` contains IDs that correspond to the synset of each image.
+The accuracy binary, however, expects the ground truth labels to contain the actual name of the
+category instead of synset IDs. A conversion script has been provided to convert the validation ground truth to
+category labels. The `validation_ground_truth.txt` can be converted by the following steps:
+
+```
+ILSVRC_2012_DEVKIT_DIR=[set to path to ILSVRC 2012 devkit]
+VALIDATION_LABELS=[set to path to output]
+
+python generate_validation_labels.py -- \
+--ilsvrc_devkit_dir=${ILSVRC_2012_DEVKIT_DIR} \
+--validation_labels_output=${VALIDATION_LABELS}
+```
+
+## Running the binary
+
+### On Android
+
+(0) Refer to https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android for configuring NDK and SDK.
+
+(1) Build using the following command:
+
+```
+bazel build -c opt \
+ --config=android_arm \
+ --config=monolithic \
+ --cxxopt='--std=c++11' \
+ --copt=-D__ANDROID_TYPES_FULL__ \
+ --copt=-DSUPPORT_SELECTIVE_REGISTRATION \
+ //tensorflow/contrib/lite/tools/accuracy/ilsvrc:imagenet_accuracy_eval
+```
+
+(2) Connect your phone. Push the binary to your phone with adb push
+ (make the directory if required):
+
+```
+adb push bazel-bin/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval /data/local/tmp
+```
+
+(3) Make the binary executable.
+
+```
+adb shell chmod +x /data/local/tmp/imagenet_accuracy_eval
+```
+
+(4) Push the TFLite model that you need to test. For example:
+
+```
+adb push mobilenet_quant_v1_224.tflite /data/local/tmp
+```
+
+(5) Push the ImageNet images to the device. Make sure the device has sufficient storage available before pushing the dataset:
+
+```
+adb shell mkdir /data/local/tmp/ilsvrc_images && \
+adb push ${IMAGENET_IMAGES_DIR} /data/local/tmp/ilsvrc_images
+```
+
+(6) Push the generated validation ground truth labels to the device.
+
+```
+adb push ${VALIDATION_LABELS} /data/local/tmp/ilsvrc_validation_labels.txt
+```
+
+(7) Push the model labels text file to the device.
+
+```
+adb push ${MODEL_LABELS_TXT} /data/local/tmp/model_output_labels.txt
+```
+
+(8) Run the binary.
+
+```
+adb shell /data/local/tmp/imagenet_accuracy_eval \
+ --model_file=/data/local/tmp/mobilenet_quant_v1_224.tflite \
+ --ground_truth_images_path=/data/local/tmp/ilsvrc_images \
+ --ground_truth_labels=/data/local/tmp/ilsvrc_validation_labels.txt \
+ --model_output_labels=/data/local/tmp/model_output_labels.txt \
+ --output_file_path=/data/local/tmp/accuracy_output.txt \
+ --num_images=0 # Run on all images.
+```
+
+### On Desktop
+
+(1) Build and run using the following command:
+
+```
+bazel run -c opt \
+ --cxxopt='--std=c++11' \
+ -- \
+ //tensorflow/contrib/lite/tools/accuracy/ilsvrc:imagenet_accuracy_eval \
+ --model_file=mobilenet_quant_v1_224.tflite \
+ --ground_truth_images_path=${IMAGENET_IMAGES_DIR} \
+ --ground_truth_labels=${VALIDATION_LABELS} \
+ --model_output_labels=${MODEL_LABELS_TXT} \
+ --output_file_path=/tmp/accuracy_output.txt \
+ --num_images=0 # Run on all images.
+```
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/clsloc_validation_blacklist.txt b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/clsloc_validation_blacklist.txt
new file mode 100644
index 0000000000..b2f00e034e
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/clsloc_validation_blacklist.txt
@@ -0,0 +1,1762 @@
+36
+50
+56
+103
+127
+195
+199
+226
+230
+235
+251
+254
+288
+397
+485
+543
+556
+601
+605
+652
+653
+663
+666
+697
+699
+705
+745
+774
+815
+816
+845
+848
+951
+977
+1006
+1008
+1018
+1056
+1066
+1079
+1102
+1128
+1133
+1188
+1193
+1194
+1266
+1271
+1372
+1382
+1405
+1426
+1430
+1441
+1477
+1502
+1518
+1606
+1621
+1642
+1658
+1716
+1722
+1734
+1750
+1807
+1880
+1882
+1936
+1951
+1970
+1977
+1983
+2086
+2112
+2146
+2152
+2217
+2304
+2321
+2404
+2526
+2554
+2563
+2647
+2675
+2732
+2733
+2827
+2839
+2854
+2865
+2872
+2880
+2886
+2893
+2915
+2973
+2993
+3019
+3020
+3044
+3047
+3049
+3117
+3167
+3197
+3201
+3282
+3311
+3315
+3344
+3345
+3378
+3425
+3477
+3497
+3514
+3525
+3531
+3587
+3637
+3650
+3657
+3686
+3720
+3732
+3798
+3802
+3823
+3847
+3971
+4007
+4059
+4072
+4087
+4099
+4124
+4126
+4156
+4195
+4197
+4241
+4275
+4321
+4333
+4352
+4356
+4368
+4377
+4428
+4440
+4497
+4509
+4513
+4526
+4528
+4565
+4570
+4596
+4633
+4677
+4696
+4743
+4759
+4778
+4835
+4976
+5032
+5058
+5061
+5066
+5140
+5145
+5177
+5197
+5219
+5226
+5228
+5240
+5289
+5292
+5385
+5433
+5445
+5448
+5465
+5488
+5549
+5553
+5609
+5638
+5666
+5683
+5711
+5729
+5760
+5793
+5819
+5837
+5855
+5858
+5961
+5966
+6048
+6197
+6199
+6201
+6206
+6215
+6220
+6264
+6278
+6280
+6305
+6388
+6411
+6466
+6490
+6509
+6523
+6529
+6625
+6754
+6818
+6886
+6890
+6893
+6902
+6912
+6942
+7067
+7141
+7144
+7214
+7217
+7278
+7312
+7320
+7329
+7342
+7345
+7369
+7408
+7428
+7463
+7556
+7557
+7582
+7613
+7621
+7624
+7647
+7671
+7679
+7734
+7736
+7747
+7750
+7777
+7851
+7854
+7883
+7889
+7902
+7985
+7999
+8070
+8087
+8096
+8100
+8128
+8180
+8195
+8367
+8377
+8465
+8497
+8508
+8528
+8538
+8581
+8657
+8692
+8742
+8784
+8839
+8861
+8912
+8970
+8982
+8987
+9103
+9155
+9180
+9248
+9284
+9300
+9357
+9382
+9414
+9450
+9463
+9493
+9522
+9543
+9563
+9630
+9643
+9653
+9693
+9747
+9787
+9847
+9851
+9892
+9913
+9929
+9965
+10026
+10027
+10055
+10154
+10189
+10243
+10297
+10337
+10346
+10347
+10377
+10403
+10483
+10518
+10540
+10559
+10567
+10568
+10580
+10606
+10615
+10618
+10645
+10685
+10707
+10710
+10807
+10837
+10856
+10873
+10989
+11046
+11054
+11132
+11163
+11218
+11243
+11255
+11265
+11292
+11306
+11307
+11310
+11343
+11349
+11407
+11411
+11422
+11427
+11431
+11439
+11496
+11644
+11662
+11690
+11692
+11725
+11743
+11767
+11812
+11867
+11871
+11897
+11975
+12001
+12046
+12076
+12119
+12158
+12216
+12252
+12261
+12264
+12293
+12296
+12306
+12357
+12358
+12371
+12415
+12422
+12472
+12497
+12499
+12538
+12540
+12544
+12569
+12645
+12647
+12652
+12699
+12727
+12750
+12832
+12849
+12873
+12889
+12902
+12996
+13029
+13065
+13073
+13075
+13079
+13268
+13338
+13372
+13529
+13530
+13537
+13623
+13626
+13637
+13644
+13646
+13681
+13778
+13782
+13805
+13846
+13853
+13881
+13914
+13961
+13975
+13979
+14011
+14135
+14143
+14144
+14161
+14170
+14207
+14212
+14215
+14260
+14311
+14368
+14373
+14400
+14509
+14523
+14566
+14594
+14628
+14629
+14633
+14649
+14652
+14705
+14709
+14732
+14734
+14802
+14834
+14865
+14883
+14933
+14965
+15003
+15100
+15159
+15178
+15272
+15289
+15308
+15319
+15327
+15353
+15357
+15363
+15408
+15429
+15438
+15469
+15485
+15495
+15501
+15524
+15530
+15551
+15598
+15613
+15614
+15631
+15646
+15647
+15661
+15679
+15684
+15758
+15775
+15826
+15838
+15840
+15931
+15940
+15969
+15976
+16003
+16037
+16045
+16116
+16200
+16233
+16247
+16339
+16340
+16345
+16361
+16400
+16408
+16430
+16468
+16474
+16500
+16521
+16565
+16569
+16584
+16613
+16645
+16662
+16671
+16719
+16724
+16760
+16764
+16805
+16849
+16893
+16896
+16954
+16979
+17023
+17026
+17034
+17038
+17049
+17054
+17061
+17073
+17074
+17133
+17163
+17176
+17177
+17217
+17237
+17246
+17298
+17312
+17324
+17337
+17365
+17415
+17442
+17449
+17576
+17578
+17581
+17588
+17589
+17591
+17593
+17605
+17661
+17688
+17689
+17695
+17697
+17703
+17736
+17746
+17758
+17788
+17798
+17828
+17841
+17884
+17898
+17924
+17956
+17960
+18001
+18013
+18025
+18052
+18097
+18106
+18158
+18211
+18223
+18240
+18261
+18266
+18297
+18325
+18329
+18335
+18340
+18351
+18433
+18462
+18466
+18524
+18569
+18581
+18631
+18696
+18748
+18766
+18787
+18793
+18950
+18961
+19001
+19008
+19011
+19154
+19177
+19217
+19255
+19286
+19320
+19333
+19360
+19403
+19407
+19419
+19464
+19499
+19510
+19519
+19555
+19564
+19605
+19610
+19689
+19699
+19705
+19707
+19725
+19732
+19741
+19774
+19799
+19838
+19877
+19903
+19940
+19945
+19952
+19973
+19987
+20024
+20086
+20111
+20114
+20174
+20193
+20201
+20245
+20299
+20329
+20439
+20485
+20534
+20562
+20575
+20578
+20601
+20604
+20605
+20648
+20658
+20665
+20677
+20693
+20697
+20699
+20791
+20794
+20808
+20876
+20890
+20906
+20914
+20990
+21065
+21128
+21144
+21151
+21156
+21175
+21199
+21204
+21207
+21225
+21236
+21241
+21342
+21351
+21429
+21533
+21550
+21622
+21676
+21727
+21764
+21785
+21822
+21830
+21845
+21853
+21867
+21909
+21910
+21923
+21924
+21937
+21948
+21955
+21962
+22008
+22017
+22026
+22037
+22072
+22075
+22135
+22138
+22160
+22167
+22190
+22287
+22375
+22440
+22457
+22460
+22471
+22481
+22484
+22488
+22515
+22553
+22679
+22703
+22714
+22730
+22735
+22752
+22768
+22809
+22813
+22817
+22846
+22902
+22910
+22944
+22986
+23026
+23053
+23065
+23088
+23117
+23124
+23126
+23132
+23142
+23165
+23172
+23223
+23264
+23280
+23322
+23335
+23439
+23453
+23455
+23474
+23501
+23518
+23580
+23589
+23608
+23614
+23641
+23649
+23660
+23698
+23728
+23766
+23809
+23859
+23874
+23902
+23946
+24040
+24105
+24132
+24137
+24151
+24153
+24157
+24171
+24271
+24281
+24296
+24303
+24308
+24328
+24332
+24338
+24402
+24440
+24453
+24466
+24504
+24531
+24543
+24547
+24556
+24562
+24610
+24649
+24660
+24693
+24706
+24745
+24834
+24948
+24963
+25056
+25057
+25083
+25093
+25120
+25150
+25161
+25197
+25219
+25220
+25253
+25257
+25290
+25327
+25332
+25344
+25387
+25390
+25422
+25453
+25481
+25489
+25587
+25599
+25600
+25622
+25681
+25686
+25702
+25708
+25740
+25776
+25870
+25918
+25973
+25978
+25986
+25987
+26033
+26038
+26041
+26087
+26113
+26155
+26162
+26184
+26235
+26299
+26301
+26318
+26364
+26383
+26430
+26511
+26528
+26561
+26618
+26653
+26688
+26697
+26778
+26940
+26951
+27023
+27029
+27037
+27046
+27051
+27118
+27244
+27252
+27258
+27272
+27283
+27303
+27381
+27392
+27403
+27422
+27437
+27440
+27476
+27493
+27494
+27501
+27506
+27550
+27559
+27571
+27581
+27596
+27604
+27612
+27665
+27687
+27701
+27711
+27732
+27759
+27766
+27772
+27797
+27813
+27854
+27864
+27865
+27879
+27894
+27907
+27958
+27963
+27969
+28003
+28027
+28032
+28051
+28058
+28079
+28093
+28120
+28132
+28194
+28227
+28324
+28328
+28331
+28360
+28373
+28419
+28431
+28436
+28451
+28467
+28471
+28527
+28541
+28588
+28640
+28649
+28662
+28670
+28678
+28722
+28768
+28780
+28835
+28863
+28879
+28885
+28928
+28948
+28954
+28963
+28969
+29020
+29065
+29077
+29105
+29117
+29143
+29166
+29172
+29299
+29302
+29342
+29357
+29378
+29410
+29411
+29414
+29415
+29447
+29473
+29488
+29499
+29505
+29533
+29537
+29601
+29637
+29650
+29667
+29671
+29681
+29686
+29708
+29721
+29749
+29755
+29771
+29853
+29886
+29894
+29919
+29928
+29990
+30008
+30064
+30067
+30107
+30150
+30160
+30164
+30186
+30195
+30219
+30243
+30282
+30314
+30324
+30389
+30418
+30497
+30550
+30592
+30615
+30624
+30640
+30650
+30695
+30720
+30741
+30750
+30751
+30767
+30830
+30856
+30885
+30901
+30907
+30953
+30985
+31005
+31027
+31034
+31045
+31057
+31071
+31109
+31119
+31227
+31230
+31250
+31303
+31320
+31371
+31401
+31440
+31447
+31464
+31478
+31487
+31494
+31525
+31553
+31554
+31558
+31572
+31588
+31639
+31641
+31683
+31698
+31704
+31708
+31717
+31722
+31781
+31786
+31788
+31791
+31803
+31850
+31853
+31862
+31886
+31901
+31944
+32020
+32048
+32052
+32073
+32094
+32116
+32147
+32180
+32212
+32218
+32256
+32270
+32305
+32411
+32414
+32430
+32465
+32484
+32534
+32584
+32589
+32608
+32612
+32613
+32615
+32641
+32674
+32697
+32708
+32757
+32763
+32796
+32824
+32861
+32877
+32944
+32945
+32946
+32984
+33004
+33012
+33029
+33050
+33090
+33096
+33097
+33124
+33139
+33161
+33170
+33173
+33179
+33191
+33293
+33367
+33370
+33371
+33373
+33399
+33415
+33436
+33440
+33443
+33488
+33551
+33563
+33564
+33629
+33643
+33664
+33685
+33696
+33714
+33722
+33728
+33764
+33809
+33868
+33883
+33913
+33942
+33956
+33994
+34081
+34089
+34091
+34098
+34178
+34207
+34269
+34287
+34348
+34392
+34445
+34447
+34455
+34529
+34579
+34591
+34643
+34659
+34692
+34729
+34758
+34836
+34857
+34862
+34883
+34930
+34942
+34957
+34963
+35003
+35089
+35180
+35187
+35209
+35220
+35239
+35247
+35253
+35263
+35380
+35393
+35394
+35408
+35452
+35485
+35486
+35557
+35578
+35639
+35663
+35688
+35746
+35832
+35862
+35890
+35903
+35917
+35929
+35946
+35984
+36060
+36084
+36090
+36124
+36135
+36151
+36197
+36249
+36269
+36303
+36364
+36377
+36398
+36402
+36418
+36421
+36435
+36499
+36511
+36521
+36544
+36556
+36601
+36627
+36640
+36660
+36673
+36676
+36787
+36790
+36797
+36821
+36840
+36901
+36921
+36934
+37006
+37041
+37051
+37112
+37160
+37167
+37213
+37231
+37242
+37274
+37313
+37332
+37391
+37416
+37522
+37594
+37621
+37664
+37699
+37731
+37915
+37968
+38030
+38070
+38117
+38128
+38135
+38172
+38184
+38224
+38277
+38295
+38311
+38428
+38464
+38529
+38549
+38599
+38623
+38673
+38681
+38713
+38722
+38726
+38762
+38867
+38872
+38944
+38947
+39015
+39023
+39028
+39043
+39068
+39080
+39097
+39118
+39171
+39197
+39236
+39254
+39271
+39277
+39280
+39336
+39338
+39340
+39341
+39358
+39364
+39497
+39503
+39537
+39541
+39559
+39560
+39562
+39596
+39600
+39613
+39623
+39656
+39670
+39781
+39810
+39832
+39861
+39875
+39892
+39918
+39919
+40008
+40016
+40082
+40091
+40095
+40164
+40213
+40234
+40274
+40279
+40324
+40332
+40341
+40349
+40365
+40438
+40446
+40482
+40501
+40510
+40516
+40541
+40544
+40545
+40574
+40617
+40659
+40668
+40742
+40754
+40758
+40764
+40765
+40795
+40858
+40901
+40985
+40986
+41080
+41112
+41121
+41136
+41196
+41199
+41219
+41233
+41246
+41278
+41376
+41401
+41409
+41434
+41470
+41492
+41502
+41517
+41571
+41572
+41608
+41648
+41699
+41773
+41779
+41801
+41837
+41843
+41849
+41855
+41873
+41881
+41901
+41924
+41926
+41935
+41962
+42008
+42062
+42069
+42072
+42094
+42097
+42104
+42112
+42117
+42137
+42147
+42170
+42185
+42224
+42237
+42250
+42254
+42257
+42276
+42282
+42298
+42321
+42351
+42372
+42378
+42420
+42446
+42453
+42466
+42470
+42502
+42514
+42518
+42527
+42662
+42721
+42727
+42743
+42794
+42840
+42843
+42871
+42872
+42897
+42950
+42956
+42967
+42969
+42975
+42995
+43005
+43008
+43046
+43052
+43091
+43103
+43124
+43198
+43225
+43228
+43385
+43394
+43402
+43405
+43408
+43423
+43503
+43529
+43557
+43647
+43656
+43704
+43706
+43714
+43745
+43748
+43759
+43812
+43927
+43950
+43997
+43998
+44016
+44018
+44025
+44060
+44066
+44099
+44128
+44149
+44150
+44169
+44184
+44198
+44254
+44272
+44293
+44310
+44352
+44389
+44399
+44400
+44442
+44451
+44470
+44474
+44522
+44569
+44590
+44713
+44738
+44787
+44823
+44829
+44845
+44895
+44918
+44975
+45024
+45121
+45148
+45154
+45179
+45208
+45210
+45215
+45218
+45220
+45235
+45265
+45282
+45283
+45285
+45286
+45303
+45351
+45359
+45396
+45407
+45414
+45472
+45519
+45522
+45564
+45621
+45641
+45660
+45678
+45695
+45696
+45710
+45780
+45800
+45823
+45828
+45862
+45947
+45964
+46001
+46050
+46084
+46113
+46132
+46146
+46198
+46221
+46234
+46236
+46256
+46272
+46298
+46325
+46337
+46347
+46374
+46386
+46388
+46437
+46491
+46560
+46561
+46589
+46600
+46656
+46660
+46664
+46673
+46690
+46700
+46808
+46809
+46828
+46918
+46963
+46979
+46984
+47005
+47088
+47097
+47100
+47143
+47147
+47261
+47320
+47369
+47450
+47503
+47533
+47538
+47576
+47601
+47608
+47618
+47621
+47624
+47659
+47681
+47698
+47708
+47745
+47817
+47826
+47879
+47883
+47917
+47937
+47957
+48000
+48023
+48076
+48099
+48130
+48133
+48281
+48298
+48321
+48349
+48351
+48353
+48358
+48371
+48426
+48455
+48522
+48526
+48544
+48573
+48606
+48609
+48646
+48667
+48699
+48701
+48740
+48773
+48777
+48785
+48847
+48886
+48940
+48986
+49029
+49054
+49100
+49121
+49137
+49157
+49191
+49222
+49291
+49315
+49347
+49374
+49376
+49381
+49407
+49427
+49481
+49497
+49624
+49785
+49791
+49835
+49875
+49877
+49981
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/generate_validation_labels.py b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/generate_validation_labels.py
new file mode 100644
index 0000000000..c32a41e50d
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/generate_validation_labels.py
@@ -0,0 +1,105 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tool to convert ILSVRC devkit validation ground truth to synset labels."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+from os import path
+import sys
+import scipy.io
+
+_SYNSET_ARRAYS_RELATIVE_PATH = 'data/meta.mat'
+_VALIDATION_FILE_RELATIVE_PATH = 'data/ILSVRC2012_validation_ground_truth.txt'
+
+
+def _synset_to_word(filepath):
+ """Returns synset to word dictionary by reading sysnset arrays."""
+ mat = scipy.io.loadmat(filepath)
+ entries = mat['synsets']
+ # These fields are listed in devkit readme.txt
+ fields = [
+ 'synset_id', 'WNID', 'words', 'gloss', 'num_children', 'children',
+ 'wordnet_height', 'num_train_images'
+ ]
+ synset_index = fields.index('synset_id')
+ words_index = fields.index('words')
+ synset_to_word = {}
+ for entry in entries:
+ entry = entry[0]
+ synset_id = int(entry[synset_index][0])
+ first_word = entry[words_index][0].split(',')[0]
+ synset_to_word[synset_id] = first_word
+ return synset_to_word
+
+
+def _validation_file_path(ilsvrc_dir):
+ return path.join(ilsvrc_dir, _VALIDATION_FILE_RELATIVE_PATH)
+
+
+def _synset_array_path(ilsvrc_dir):
+ return path.join(ilsvrc_dir, _SYNSET_ARRAYS_RELATIVE_PATH)
+
+
+def _generate_validation_labels(ilsvrc_dir, output_file):
+ synset_to_word = _synset_to_word(_synset_array_path(ilsvrc_dir))
+ with open(_validation_file_path(ilsvrc_dir), 'r') as synset_id_file, open(
+ output_file, 'w') as output:
+ for synset_id in synset_id_file:
+ synset_id = int(synset_id)
+ output.write('%s\n' % synset_to_word[synset_id])
+
+
+def _check_arguments(args):
+ if not args.validation_labels_output:
+ raise ValueError('Invalid path to output file.')
+ ilsvrc_dir = args.ilsvrc_devkit_dir
+ if not ilsvrc_dir or not path.isdir(ilsvrc_dir):
+ raise ValueError('Invalid path to ilsvrc_dir')
+ if not path.exists(_validation_file_path(ilsvrc_dir)):
+ raise ValueError('Invalid path to ilsvrc_dir, cannot find validation file.')
+ if not path.exists(_synset_array_path(ilsvrc_dir)):
+ raise ValueError(
+ 'Invalid path to ilsvrc_dir, cannot find synset arrays file.')
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description='Converts ILSVRC devkit validation_ground_truth.txt to a synset'
+ ' labels file that can be used by the accuracy script.')
+ parser.add_argument(
+ '--validation_labels_output',
+ type=str,
+ help='Full path for outputting validation labels.')
+ parser.add_argument(
+ '--ilsvrc_devkit_dir',
+ type=str,
+ help='Full path to ILSVRC 2012 devkit directory.')
+ args = parser.parse_args()
+ try:
+ _check_arguments(args)
+ except ValueError as e:
+ parser.print_usage()
+ file_name = path.basename(sys.argv[0])
+ sys.stderr.write('{0}: error: {1}\n'.format(file_name, str(e)))
+ sys.exit(1)
+ _generate_validation_labels(args.ilsvrc_devkit_dir,
+ args.validation_labels_output)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc
new file mode 100644
index 0000000000..2a8a2b9b59
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc
@@ -0,0 +1,165 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iomanip>
+#include <memory>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/contrib/lite/tools/accuracy/csv_writer.h"
+#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h"
+#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace metrics {
+
+namespace {
+
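+// Converts cumulative top-K hit counts into cumulative accuracy percentages,
+// one entry per rank.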
+std::vector<double> GetAccuracies(
+ const ImagenetTopKAccuracy::AccuracyStats& accuracy_stats) {
+ std::vector<double> results;
+ results.reserve(accuracy_stats.topk_counts.size());
+ if (accuracy_stats.number_of_images > 0) {
+ for (int n : accuracy_stats.topk_counts) {
+ results.push_back((n * 100.0) / accuracy_stats.number_of_images);
+ }
+ }
+ return results;
+}
+
+} // namespace
+
+// Writes results to a CSV file.
+class ResultsWriter : public ImagenetModelEvaluator::Observer {
+ public:
+ explicit ResultsWriter(std::unique_ptr<CSVWriter> writer)
+ : writer_(std::move(writer)) {}
+
+ void OnEvaluationStart(const std::unordered_map<uint64_t, int>&
+ shard_id_image_count_map) override {}
+
+ void OnSingleImageEvaluationComplete(
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
+ const string& image) override;
+
+ private:
+ std::unique_ptr<CSVWriter> writer_ GUARDED_BY(mu_);
+ mutex mu_;
+};
+
+void ResultsWriter::OnSingleImageEvaluationComplete(
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
+ const string& image) {
+ mutex_lock lock(mu_);
+ TF_CHECK_OK(writer_->WriteRow(GetAccuracies(stats)));
+ writer_->Flush();
+}
+
+// Logs evaluation progress, at most once every `kLogDelayUs` microseconds.
+class ResultsLogger : public ImagenetModelEvaluator::Observer {
+ public:
+ void OnEvaluationStart(const std::unordered_map<uint64_t, int>&
+ shard_id_image_count_map) override;
+
+ void OnSingleImageEvaluationComplete(
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
+ const string& image) override;
+
+ private:
+ uint64_t last_logged_time_us_ GUARDED_BY(mu_) = 0;
+ int total_num_images_ GUARDED_BY(mu_) = 0;
+ static constexpr int kLogDelayUs = 500 * 1000;
+ mutex mu_;
+};
+
+void ResultsLogger::OnEvaluationStart(
+ const std::unordered_map<uint64_t, int>& shard_id_image_count_map) {
+ int total_num_images = 0;
+ for (const auto& kv : shard_id_image_count_map) {
+ total_num_images += kv.second;
+ }
+ LOG(ERROR) << "Starting model evaluation: " << total_num_images;
+ mutex_lock lock(mu_);
+ total_num_images_ = total_num_images;
+}
+
+void ResultsLogger::OnSingleImageEvaluationComplete(
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
+ const string& image) {
+ auto now_us = Env::Default()->NowMicros();
+ int num_evaluated = stats.number_of_images;
+ mutex_lock lock(mu_);
+ if ((now_us - last_logged_time_us_) >= kLogDelayUs) {
+ last_logged_time_us_ = now_us;
+ double current_percent = num_evaluated * 100.0 / total_num_images_;
+ LOG(ERROR) << "Evaluated " << num_evaluated << "/" << total_num_images_
+ << " images, " << std::setprecision(2) << std::fixed
+ << current_percent << "%";
+ }
+}
+
+int Main(int argc, char* argv[]) {
+ // TODO(shashishekhar): Make this binary configurable and model
+ // agnostic.
+ string output_file_path;
+ int num_threads = 4;
+ std::vector<Flag> flag_list = {
+ Flag("output_file_path", &output_file_path, "Path to output file."),
+ Flag("num_threads", &num_threads, "Number of threads."),
+ };
+ Flags::Parse(&argc, argv, flag_list);
+
+ std::unique_ptr<ImagenetModelEvaluator> evaluator;
+ CHECK(!output_file_path.empty()) << "Invalid output file path.";
+
+ CHECK(num_threads > 0) << "Invalid number of threads.";
+
+ TF_CHECK_OK(
+ ImagenetModelEvaluator::Create(argc, argv, num_threads, &evaluator));
+
+ std::ofstream output_stream(output_file_path, std::ios::out);
+ CHECK(output_stream) << "Unable to open output file path: '"
+ << output_file_path << "'";
+
+ output_stream << std::setprecision(3) << std::fixed;
+ std::vector<string> columns;
+ columns.reserve(evaluator->params().num_ranks);
+ for (int i = 0; i < evaluator->params().num_ranks; i++) {
+ string column_name = "Top ";
+ tensorflow::strings::StrAppend(&column_name, i + 1);
+ columns.push_back(column_name);
+ }
+
+ ResultsWriter results_writer(
+ absl::make_unique<CSVWriter>(columns, &output_stream));
+ ResultsLogger logger;
+ evaluator->AddObserver(&results_writer);
+ evaluator->AddObserver(&logger);
+ LOG(ERROR) << "Starting evaluation with: " << num_threads << " threads.";
+ TF_CHECK_OK(evaluator->EvaluateModel());
+ return 0;
+}
+
+} // namespace metrics
+} // namespace tensorflow
+
+int main(int argc, char* argv[]) {
+ return tensorflow::metrics::Main(argc, argv);
+}
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc
new file mode 100644
index 0000000000..63616fc3b4
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc
@@ -0,0 +1,351 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h"
+
+#include <fstream>
+#include <iomanip>
+#include <string>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h"
+#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h"
+#include "tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h"
+#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h"
+#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h"
+#include "tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h"
+#include "tensorflow/contrib/lite/tools/accuracy/utils.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace {
+using tensorflow::string;
+
+string StripTrailingSlashes(const string& path) {
+ int end = path.size();
+ while (end > 0 && path[end - 1] == '/') {
+ end--;
+ }
+ return path.substr(0, end);
+}
+
+tensorflow::Tensor CreateStringTensor(const string& value) {
+ tensorflow::Tensor tensor(tensorflow::DT_STRING, tensorflow::TensorShape({}));
+ tensor.scalar<string>()() = value;
+ return tensor;
+}
+
+template <typename T>
+std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
+ if (n >= v.size()) return v;
+ std::vector<T> result(v.begin(), v.begin() + n);
+ return result;
+}
+
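+// Splits the elements of `v` round-robin into `n` shards of nearly equal
+// size.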
+template <typename T>
+std::vector<std::vector<T>> Split(const std::vector<T>& v, int n) {
+ CHECK_GT(n, 0);
+ std::vector<std::vector<T>> vecs(n);
+ int input_index = 0;
+ int vec_index = 0;
+ while (input_index < v.size()) {
+ vecs[vec_index].push_back(v[input_index]);
+ vec_index = (vec_index + 1) % n;
+ input_index++;
+ }
+ CHECK_EQ(vecs.size(), n);
+ return vecs;
+}
+
+// File pattern for imagenet files.
+const char* const kImagenetFilePattern = "*.[jJ][pP][eE][gG]";
+
+} // namespace
+
+namespace tensorflow {
+namespace metrics {
+
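+// An observer that broadcasts each evaluation event to a list of child
+// observers, serializing the callbacks with a mutex.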
+class CompositeObserver : public ImagenetModelEvaluator::Observer {
+ public:
+ explicit CompositeObserver(const std::vector<Observer*>& observers)
+ : observers_(observers) {}
+
+ void OnEvaluationStart(const std::unordered_map<uint64_t, int>&
+ shard_id_image_count_map) override {
+ mutex_lock lock(mu_);
+ for (auto observer : observers_) {
+ observer->OnEvaluationStart(shard_id_image_count_map);
+ }
+ }
+
+ void OnSingleImageEvaluationComplete(
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
+ const string& image) override {
+ mutex_lock lock(mu_);
+ for (auto observer : observers_) {
+ observer->OnSingleImageEvaluationComplete(shard_id, stats, image);
+ }
+ }
+
+ private:
+ const std::vector<ImagenetModelEvaluator::Observer*>& observers_
+ GUARDED_BY(mu_);
+ mutex mu_;
+};
+
+/*static*/ Status ImagenetModelEvaluator::Create(
+ int argc, char* argv[], int num_threads,
+ std::unique_ptr<ImagenetModelEvaluator>* model_evaluator) {
+ Params params;
+ const std::vector<Flag> flag_list = {
+ Flag("model_output_labels", &params.model_output_labels_path,
+ "Path to labels that correspond to output of model."
+ " E.g. in case of mobilenet, this is the path to label "
+ "file where each label is in the same order as the output"
+ " of the model."),
+ Flag("ground_truth_images_path", &params.ground_truth_images_path,
+ "Path to ground truth images."),
+ Flag("ground_truth_labels", &params.ground_truth_labels_path,
+ "Path to ground truth labels."),
+ Flag("num_images", &params.number_of_images,
+ "Number of examples to evaluate, pass 0 for all "
+ "examples. Default: 100"),
+ Flag("blacklist_file_path", &params.blacklist_file_path,
+ "Path to blacklist file (optional)."
+ "Path to blacklist file where each line is a single integer that is "
+ "equal to number of blacklisted image."),
+ Flag("model_file", &params.model_file_path,
+ "Path to test tflite model file."),
+ };
+ const bool parse_result = Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result)
+ return errors::InvalidArgument("Invalid command line flags");
+ ::tensorflow::port::InitMain(argv[0], &argc, &argv);
+
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ Env::Default()->IsDirectory(params.ground_truth_images_path),
+ "Invalid ground truth data path.");
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ Env::Default()->FileExists(params.ground_truth_labels_path),
+ "Invalid ground truth labels path.");
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ Env::Default()->FileExists(params.model_output_labels_path),
+ "Invalid model output labels path.");
+
+ if (!params.blacklist_file_path.empty()) {
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ Env::Default()->FileExists(params.blacklist_file_path),
+ "Invalid blacklist path.");
+ }
+
+ if (params.number_of_images < 0) {
+ return errors::InvalidArgument("Invalid: num_examples");
+ }
+
+ utils::ModelInfo model_info;
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ utils::GetTFliteModelInfo(params.model_file_path, &model_info),
+ "Invalid TFLite model.");
+
+ *model_evaluator = absl::make_unique<ImagenetModelEvaluator>(
+ model_info, params, num_threads);
+ return Status::OK();
+}
+
+struct ImageLabel {
+ string image;
+ string label;
+};
+
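+// Builds an evaluation pipeline (file reader -> inception preprocessing ->
+// TFLite model -> top-K accuracy eval) for a single shard and runs every
+// image/label pair of the shard through it, notifying `observer` after each
+// image.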
+Status EvaluateModelForShard(const uint64_t shard_id,
+ const std::vector<ImageLabel>& image_labels,
+ const std::vector<string>& model_labels,
+ const utils::ModelInfo& model_info,
+ const ImagenetModelEvaluator::Params& params,
+ ImagenetModelEvaluator::Observer* observer,
+ ImagenetTopKAccuracy* eval) {
+ const TensorShape& input_shape = model_info.input_shapes[0];
+ const int image_height = input_shape.dim_size(1);
+ const int image_width = input_shape.dim_size(2);
+ const bool is_quantized = (model_info.input_types[0] == DT_UINT8);
+
+ RunTFLiteModelStage::Params tfl_model_params;
+ tfl_model_params.model_file_path = params.model_file_path;
+ if (is_quantized) {
+ tfl_model_params.input_type = {DT_UINT8};
+ tfl_model_params.output_type = {DT_UINT8};
+ } else {
+ tfl_model_params.input_type = {DT_FLOAT};
+ tfl_model_params.output_type = {DT_FLOAT};
+ }
+
+ Scope root = Scope::NewRootScope();
+ FileReaderStage reader;
+ InceptionPreprocessingStage inc(image_height, image_width, is_quantized);
+ RunTFLiteModelStage tfl_model_stage(tfl_model_params);
+ EvalPipelineBuilder builder;
+
+ std::unique_ptr<EvalPipeline> eval_pipeline;
+
+ auto build_status = builder.WithInputStage(&reader)
+ .WithPreprocessingStage(&inc)
+ .WithRunModelStage(&tfl_model_stage)
+ .WithAccuracyEval(eval)
+ .WithInput("input_file", DT_STRING)
+ .Build(root, &eval_pipeline);
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(build_status,
+ "Failure while building eval pipeline.");
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+
+ TF_RETURN_IF_ERROR(eval_pipeline->AttachSession(std::move(session)));
+
+ for (const auto& image_label : image_labels) {
+ TF_CHECK_OK(eval_pipeline->Run(CreateStringTensor(image_label.image),
+ CreateStringTensor(image_label.label)));
+ observer->OnSingleImageEvaluationComplete(
+ shard_id, eval->GetTopKAccuracySoFar(), image_label.image);
+ }
+ return Status::OK();
+}
+
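+// Removes blacklisted images from `image_labels`. Each line of the blacklist
+// file is a 1-based image index, which is converted here to a 0-based index
+// into the filename-sorted image list.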
+Status FilterBlackListedImages(const string& blacklist_file_path,
+ std::vector<ImageLabel>* image_labels) {
+ if (!blacklist_file_path.empty()) {
+ std::vector<string> lines;
+ TF_RETURN_IF_ERROR(utils::ReadFileLines(blacklist_file_path, &lines));
+ std::vector<int> blacklist_ids;
+ blacklist_ids.reserve(lines.size());
+ // Populate blacklist_ids with indices of images.
+ std::transform(lines.begin(), lines.end(),
+ std::back_inserter(blacklist_ids),
+ [](const string& val) { return std::stoi(val) - 1; });
+
+ std::vector<ImageLabel> filtered_images;
+ std::sort(blacklist_ids.begin(), blacklist_ids.end());
+ const size_t size_post_filtering =
+ image_labels->size() - blacklist_ids.size();
+ filtered_images.reserve(size_post_filtering);
+ int blacklist_index = 0;
+ for (int image_index = 0; image_index < image_labels->size();
+ image_index++) {
+ if (blacklist_index < blacklist_ids.size() &&
+ blacklist_ids[blacklist_index] == image_index) {
+ blacklist_index++;
+ continue;
+ }
+ filtered_images.push_back((*image_labels)[image_index]);
+ }
+
+ if (filtered_images.size() != size_post_filtering) {
+ return errors::Internal("Invalid number of filtered images");
+ }
+ *image_labels = filtered_images;
+ }
+ return Status::OK();
+}
+
+Status ImagenetModelEvaluator::EvaluateModel() const {
+ if (model_info_.input_shapes.size() != 1) {
+ return errors::InvalidArgument("Invalid input shape");
+ }
+
+ const TensorShape& input_shape = model_info_.input_shapes[0];
+ // Input should be of the shape {1, height, width, 3}
+ if (input_shape.dims() != 4 || input_shape.dim_size(3) != 3) {
+ return errors::InvalidArgument("Invalid input shape for the model.");
+ }
+
+ string data_path =
+ StripTrailingSlashes(params_.ground_truth_images_path) + "/";
+
+ const string imagenet_file_pattern = data_path + kImagenetFilePattern;
+ std::vector<string> image_files;
+ TF_CHECK_OK(
+ Env::Default()->GetMatchingPaths(imagenet_file_pattern, &image_files));
+ std::vector<string> ground_truth_image_labels;
+ TF_CHECK_OK(utils::ReadFileLines(params_.ground_truth_labels_path,
+ &ground_truth_image_labels));
+ CHECK_EQ(image_files.size(), ground_truth_image_labels.size());
+
+ // Process files in filename sorted order.
+ std::sort(image_files.begin(), image_files.end());
+
+ std::vector<ImageLabel> image_labels;
+ image_labels.reserve(image_files.size());
+ for (int i = 0; i < image_files.size(); i++) {
+ image_labels.push_back({image_files[i], ground_truth_image_labels[i]});
+ }
+
+ // Filter any blacklisted images.
+ TF_CHECK_OK(
+ FilterBlackListedImages(params_.blacklist_file_path, &image_labels));
+
+ if (params_.number_of_images > 0) {
+ image_labels = GetFirstN(image_labels, params_.number_of_images);
+ }
+
+ std::vector<string> model_labels;
+ TF_RETURN_IF_ERROR(
+ utils::ReadFileLines(params_.model_output_labels_path, &model_labels));
+ if (model_labels.size() != 1001) {
+ return errors::InvalidArgument("Invalid number of labels: ",
+ model_labels.size());
+ }
+
+ ImagenetTopKAccuracy eval(model_labels, params_.num_ranks);
+
+ auto img_labels = Split(image_labels, num_threads_);
+
+ BlockingCounter counter(num_threads_);
+
+ CompositeObserver observer(observers_);
+
+ ::tensorflow::thread::ThreadPool pool(Env::Default(), "evaluation_pool",
+ num_threads_);
+ std::unordered_map<uint64_t, int> shard_id_image_count_map;
+ std::vector<std::function<void()>> thread_funcs;
+ thread_funcs.reserve(num_threads_);
+ for (int i = 0; i < num_threads_; i++) {
+ const auto& image_label = img_labels[i];
+ const uint64_t shard_id = i + 1;
+ shard_id_image_count_map[shard_id] = image_label.size();
+ auto func = [shard_id, &image_label, &model_labels, this, &observer, &eval,
+ &counter]() {
+ TF_CHECK_OK(EvaluateModelForShard(shard_id, image_label, model_labels,
+ model_info_, params_, &observer,
+ &eval));
+ counter.DecrementCount();
+ };
+ thread_funcs.push_back(func);
+ }
+
+ observer.OnEvaluationStart(shard_id_image_count_map);
+ for (const auto& func : thread_funcs) {
+ pool.Schedule(func);
+ }
+
+ counter.Wait();
+
+ return Status::OK();
+}
+
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h
new file mode 100644
index 0000000000..97e4232b35
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h
@@ -0,0 +1,124 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_MODEL_EVALUATOR_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_MODEL_EVALUATOR_H_
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h"
+#include "tensorflow/contrib/lite/tools/accuracy/utils.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace metrics {
+
+// Evaluates model accuracy on the ILSVRC dataset.
+//
+// Generates the top-1 to top-K accuracy counts, where K is
+// controlled by |num_ranks|.
+// Usage:
+// ModelInfo model_info = ..
+// ImagenetModelEvaluator::Params params;
+// .. set params to image, label, output label and model file path..
+// SomeObserver observer;
+// ImagenetModelEvaluator evaluator(model_info, params);
+// evaluator.AddObserver(&observer);
+// TF_CHECK_OK(evaluator.EvaluateModel());
+class ImagenetModelEvaluator {
+ public:
+ struct Params {
+ // Path to ground truth images.
+ string ground_truth_images_path;
+
+ // Path to labels file for ground truth image.
+ // This file should be generated with the scripts.
+ string ground_truth_labels_path;
+
+ // The word labels used to interpret the output of the model. The category
+ // indices of the output probabilities generated by the model may be
+ // different from the indices in the ImageNet dataset.
+ string model_output_labels_path;
+
+ // Path to the model file.
+ string model_file_path;
+
+ // Path to the blacklist file. 1762 images were blacklisted from the
+ // original ILSVRC dataset. This blacklist file is included in the
+ // ILSVRC2014 devkit; please refer to readme.txt of the ILSVRC2014
+ // devkit for details.
+ // This file is a list of image indices in sorted order.
+ string blacklist_file_path;
+
+ // The maximum number of images used to calculate accuracy.
+ // 0 means all images; a positive number means only that many
+ // images are used.
+ int number_of_images = 0;
+
+ // Number of ranks, top K.
+ int num_ranks = 10;
+ };
+
+ // An evaluation observer.
+ // Observers can be called from multiple threads and need to be thread safe.
+ class Observer {
+ public:
+ Observer() = default;
+ Observer(const Observer&) = delete;
+ Observer& operator=(const Observer&) = delete;
+
+ Observer(Observer&&) = delete;
+ Observer& operator=(Observer&&) = delete;
+
+ // Called on start of evaluation.
+ // `shard_id_image_count_map` maps each shard id to its image count.
+ virtual void OnEvaluationStart(
+ const std::unordered_map<uint64_t, int>& shard_id_image_count_map) = 0;
+
+ // Called when evaluation is complete for `image`.
+ virtual void OnSingleImageEvaluationComplete(
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
+ const string& image) = 0;
+
+ virtual ~Observer() = default;
+ };
+
+ ImagenetModelEvaluator(const utils::ModelInfo& model_info,
+ const Params& params, const int num_threads)
+ : model_info_(model_info), params_(params), num_threads_(num_threads) {}
+
+ // Factory method to create the evaluator by parsing command line arguments.
+ static Status Create(int argc, char* argv[], int num_threads,
+ std::unique_ptr<ImagenetModelEvaluator>* evaluator);
+
+ // Adds an observer that can observe evaluation events.
+ void AddObserver(Observer* observer) { observers_.push_back(observer); }
+
+ const Params& params() const { return params_; }
+
+ // Evaluates the provided model over the dataset.
+ Status EvaluateModel() const;
+
+ private:
+ const utils::ModelInfo model_info_;
+ const Params params_;
+ const int num_threads_;
+ std::vector<Observer*> observers_;
+};
+
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_MODEL_EVALUATOR_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc
new file mode 100644
index 0000000000..c75baa82b1
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc
@@ -0,0 +1,114 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h"
+
+#include <numeric>
+
+namespace {
+constexpr int kNumCategories = 1001;
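+// Returns the indices of the `k` largest values in `values`, ordered by
+// decreasing value.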
+std::vector<int> GetTopK(const std::vector<float>& values, int k) {
+ CHECK_LE(k, values.size());
+ std::vector<int> indices(values.size());
+
+ std::iota(indices.begin(), indices.end(), 0);
+ std::sort(indices.begin(), indices.end(),
+ [&values](int a, int b) { return values[a] > values[b]; });
+
+ indices.resize(k);
+ return indices;
+}
+} // namespace
+
+namespace tensorflow {
+namespace metrics {
+ImagenetTopKAccuracy::ImagenetTopKAccuracy(
+ const std::vector<string>& ground_truth_labels, int k)
+ : ground_truth_labels_(ground_truth_labels),
+ k_(k),
+ accuracy_counts_(k_, 0),
+ num_samples_(0) {
+ CHECK_EQ(kNumCategories, ground_truth_labels.size());
+}
+
+Status ImagenetTopKAccuracy::ComputeEval(
+ const std::vector<Tensor>& model_outputs, const Tensor& ground_truth) {
+ if (model_outputs.size() != 1) {
+ return errors::InvalidArgument("Invalid model output: ",
+ model_outputs.size());
+ }
+ const Tensor& output = model_outputs[0];
+ if (!output.shape().IsSameSize({1, kNumCategories})) {
+ return errors::InvalidArgument("Invalid shape of model output: ",
+ output.shape().DebugString());
+ }
+ // The ground truth must be a string scalar.
+ if (ground_truth.dtype() != DT_STRING || ground_truth.dims() != 0) {
+ return errors::InvalidArgument("Invalid ground truth type: ",
+ ground_truth.DebugString());
+ }
+ string ground_truth_label = ground_truth.scalar<string>()();
+
+ std::vector<float> probabilities;
+ probabilities.reserve(kNumCategories);
+ if (output.dtype() == DT_FLOAT) {
+ auto probs = output.flat<float>();
+ for (size_t i = 0; i < probs.size(); i++) {
+ probabilities.push_back(probs(i));
+ }
+ } else {
+ auto probs = output.flat<uint8>();
+ for (size_t i = 0; i < probs.size(); i++) {
+ probabilities.push_back(probs(i));
+ }
+ }
+
+ CHECK_EQ(kNumCategories, probabilities.size());
+ std::vector<int> topK = GetTopK(probabilities, k_);
+ int ground_truth_index = GroundTruthIndex(ground_truth_label);
+ UpdateSamples(topK, ground_truth_index);
+ return Status::OK();
+}
+
+const ImagenetTopKAccuracy::AccuracyStats
+ImagenetTopKAccuracy::GetTopKAccuracySoFar() const {
+ mutex_lock lock(mu_);
+ AccuracyStats stats;
+ stats.number_of_images = num_samples_;
+ stats.topk_counts = accuracy_counts_;
+ return stats;
+}
+
+void ImagenetTopKAccuracy::UpdateSamples(const std::vector<int>& counts,
+ int ground_truth_index) {
+ mutex_lock lock(mu_);
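+ // If the ground truth label is found at rank `i`, it is by definition also
+ // within the top `j` predictions for every `j >= i`, so every bucket from
+ // `i` onward is incremented.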
+ for (size_t i = 0; i < counts.size(); ++i) {
+ if (ground_truth_index == counts[i]) {
+ for (size_t j = i; j < counts.size(); j++) {
+ accuracy_counts_[j] += 1;
+ }
+ break;
+ }
+ }
+ num_samples_++;
+}
+
+int ImagenetTopKAccuracy::GroundTruthIndex(const string& label) const {
+ auto index = std::find(ground_truth_labels_.cbegin(),
+ ground_truth_labels_.cend(), label);
+ CHECK(index != ground_truth_labels_.cend()) << "Invalid label: " << label;
+ return std::distance(ground_truth_labels_.cbegin(), index);
+}
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h
new file mode 100644
index 0000000000..cad646a30c
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h
@@ -0,0 +1,83 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_TOPK_EVAL_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_TOPK_EVAL_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+namespace metrics {
+// An |AccuracyEval| stage that calculates the top K error rate for model
+// evaluations on ImageNet-like datasets.
+// Inputs: A {1, 1001} shaped tensor that contains the probabilities for objects
+// predicted by the model.
+// Ground truth: A |string| label for the image.
+// From the input object probabilities, the stage computes the predicted labels
+// and finds the top K error rates by comparing the predictions with ground
+// truths.
+class ImagenetTopKAccuracy : public AccuracyEval {
+ public:
+ // Accuracy statistics.
+ struct AccuracyStats {
+ // Number of images evaluated.
+ int number_of_images;
+ // A vector of size |k| that contains the number of images
+ // that have the correct label within the top K predictions.
+ // E.g. topk_counts[0] contains the number of images for which
+ // the model returned the correct label as the first result.
+ // Similarly topk_counts[4] contains the number of images for which
+ // the model returned the correct label within the top 5 results.
+ // This can be used to compute the top K error rate for the model.
+ std::vector<int> topk_counts;
+ };
+
+ // Creates a new instance of |ImagenetTopKAccuracy| with the given
+ // |ground_truth_labels| and |k|.
+ // Args:
+ // |ground_truth_labels| : an ordered vector of labels for images. This is
+ // used to compute the indices of the predicted labels and the ground truth
+ // label.
+ ImagenetTopKAccuracy(const std::vector<string>& ground_truth_labels, int k);
+
+ // Computes accuracy for a given image. The |model_outputs| should
+ // be a vector containing exactly one Tensor of shape {1, 1001}, where each
+ // item is the probability that the corresponding category represents the
+ // image, as output by the model.
+ // Uses |ground_truth_labels| to map |model_outputs| and |ground_truth| to
+ // category indices and updates the top K counts.
+ Status ComputeEval(const std::vector<Tensor>& model_outputs,
+ const Tensor& ground_truth) override;
+
+ // Gets the top-K accuracy for the images that have been evaluated so far.
+ const AccuracyStats GetTopKAccuracySoFar() const;
+
+ private:
+ int GroundTruthIndex(const string& label) const;
+ void UpdateSamples(const std::vector<int>& counts, int ground_truth_index);
+ const std::vector<string> ground_truth_labels_;
+ const int k_;
+ std::vector<int> accuracy_counts_ GUARDED_BY(mu_);
+ int num_samples_ GUARDED_BY(mu_);
+ mutable mutex mu_;
+};
+} // namespace metrics
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_TOPK_EVAL_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval_test.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval_test.cc
new file mode 100644
index 0000000000..ff332af5c5
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval_test.cc
@@ -0,0 +1,151 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace metrics {
+namespace {
+
+const int kNumCategories = 1001;
+
+Tensor CreateStringTensor(const string& value) {
+ Tensor tensor(DT_STRING, TensorShape({}));
+ tensor.scalar<string>()() = value;
+ return tensor;
+}
+
+Tensor CreateOutputTensor() {
+ Tensor tensor(DT_FLOAT, TensorShape({1, kNumCategories}));
+ for (int i = 0; i < kNumCategories; i++) {
+ tensor.flat<float>()(i) = 0;
+ }
+ return tensor;
+}
+
+std::vector<string> CreateGroundTruth() {
+ std::vector<string> ground_truth;
+ ground_truth.reserve(kNumCategories);
+ for (int i = 0; i < kNumCategories; i++) {
+ string category;
+ strings::StrAppend(&category, i);
+ ground_truth.push_back(category);
+ }
+ return ground_truth;
+}
+
+TEST(ImagenetTopKAccuracy, AllCorrect) {
+ ImagenetTopKAccuracy acc_top_5(CreateGroundTruth(), 5);
+ auto accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(0, accuracies.number_of_images);
+ EXPECT_EQ(5, accuracies.topk_counts.size());
+
+ for (int i : accuracies.topk_counts) {
+ EXPECT_EQ(0, i);
+ }
+ // First image was correctly identified as "0".
+ Tensor tensor = CreateOutputTensor();
+ tensor.flat<float>()(0) = 0.8;
+
+ TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("0")));
+ accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(1, accuracies.number_of_images);
+
+ for (int i : accuracies.topk_counts) {
+ EXPECT_EQ(1, i);
+ }
+ tensor.flat<float>()(1) = 0.9;
+ TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("1")));
+ accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(2, accuracies.number_of_images);
+
+ for (int i : accuracies.topk_counts) {
+ EXPECT_EQ(2, i);
+ }
+}
+
+TEST(ImagenetTopKAccuracy, Top5) {
+ ImagenetTopKAccuracy acc_top_5(CreateGroundTruth(), 5);
+ auto accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(0, accuracies.number_of_images);
+ EXPECT_EQ(5, accuracies.topk_counts.size());
+
+  // For the first image, with ground truth "0", the probabilities are:
+  // 0.5 for "0", 0.6 for "1", 0.7 for "2", 0.8 for "3", 0.9 for "4",
+  // and zero for all remaining categories. Ground truth "0" therefore has the
+  // fifth-highest probability, so only the top-5 bucket is incremented.
+ Tensor tensor = CreateOutputTensor();
+ tensor.flat<float>()(0) = 0.5;
+ tensor.flat<float>()(1) = 0.6;
+ tensor.flat<float>()(2) = 0.7;
+ tensor.flat<float>()(3) = 0.8;
+ tensor.flat<float>()(4) = 0.9;
+
+ TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("0")));
+ accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(1, accuracies.number_of_images);
+ EXPECT_EQ(1, accuracies.topk_counts[4]);
+
+ for (int i = 0; i < 4; i++) {
+ EXPECT_EQ(0, accuracies.topk_counts[i]);
+ }
+
+  // For ground truth "1", which has the fourth-highest probability, only the
+  // last two buckets are affected.
+ TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("1")));
+ accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(2, accuracies.number_of_images);
+ EXPECT_EQ(1, accuracies.topk_counts[3]);
+ EXPECT_EQ(2, accuracies.topk_counts[4]);
+ for (int i = 0; i < 3; i++) {
+ EXPECT_EQ(0, accuracies.topk_counts[i]);
+ }
+
+ // All buckets will be affected.
+ TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("4")));
+ accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(3, accuracies.number_of_images);
+ EXPECT_EQ(1, accuracies.topk_counts[0]);
+ EXPECT_EQ(1, accuracies.topk_counts[1]);
+ EXPECT_EQ(1, accuracies.topk_counts[2]);
+ EXPECT_EQ(2, accuracies.topk_counts[3]);
+ EXPECT_EQ(3, accuracies.topk_counts[4]);
+
+  // Ground truth "10" is not among the top 5 predictions, so no buckets are
+  // affected.
+ TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("10")));
+ accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(4, accuracies.number_of_images);
+ EXPECT_EQ(1, accuracies.topk_counts[0]);
+ EXPECT_EQ(1, accuracies.topk_counts[1]);
+ EXPECT_EQ(1, accuracies.topk_counts[2]);
+ EXPECT_EQ(2, accuracies.topk_counts[3]);
+ EXPECT_EQ(3, accuracies.topk_counts[4]);
+}
+
+} // namespace
+
+} // namespace metrics
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.cc
new file mode 100644
index 0000000000..7512b39c32
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.cc
@@ -0,0 +1,80 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h"
+
+#include <memory>
+
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace metrics {
+
+namespace {
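+// Adds ops that crop the central |crop_fraction| of |decoded_image| while
+// keeping all channels. As a worked example (illustrative, not from the
+// original code): for a 300x200 image with crop_fraction 0.875, the crop
+// begins at floor((300, 200) * (1 - 0.875) / 2) = (18, 12) and spans
+// (300, 200) - 2 * (18, 12) = (264, 176).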
+void CentralCropImage(const Scope& s, const tensorflow::Output& decoded_image,
+ double crop_fraction, tensorflow::Output* cropped_image) {
+ auto image_dims = ops::Slice(s, ops::Shape(s, decoded_image), {0}, {2});
+ auto height_width = ops::Cast(s, image_dims, DT_DOUBLE);
+ auto cropped_begin = ops::Div(
+ s, ops::Sub(s, height_width, ops::Mul(s, height_width, crop_fraction)),
+ 2.0);
+ auto bbox_begin = ops::Cast(s, cropped_begin, DT_INT32);
+ auto bbox_size = ops::Sub(s, image_dims, ops::Mul(s, bbox_begin, 2));
+ auto slice_begin = ops::Concat(s, {bbox_begin, Input({0})}, 0);
+ auto slice_size = ops::Concat(s, {bbox_size, {-1}}, 0);
+ *cropped_image = ops::Slice(s, decoded_image, slice_begin, slice_size);
+}
+
+} // namespace
+
+void InceptionPreprocessingStage::AddToGraph(const Scope& scope,
+ const Input& input) {
+ if (!scope.ok()) return;
+ Scope s = scope.WithOpName(name());
+ ops::DecodeJpeg::Attrs attrs;
+ attrs.channels_ = 3;
+ auto decoded_jpeg = ops::DecodeJpeg(s, input, attrs);
+ tensorflow::Output cropped_image;
+ CentralCropImage(s, decoded_jpeg, params_.cropping_fraction, &cropped_image);
+ auto dims_expander = ops::ExpandDims(s, cropped_image, 0);
+ auto resized_image = ops::ResizeBilinear(
+ s, dims_expander,
+ ops::Const(s.WithOpName("size"), {image_height_, image_width_}));
+ if (is_quantized_) {
+ this->stage_output_ =
+ ops::Cast(s.WithOpName(output_name()), resized_image, DT_UINT8);
+ } else {
+ auto squeezed_image = ops::Squeeze(s, resized_image);
+ auto normalized_image =
+ ops::Div(s,
+ ops::Sub(s, squeezed_image,
+ {params_.input_means[0], params_.input_means[1],
+ params_.input_means[2]}),
+ {params_.scale});
+ this->stage_output_ =
+ ops::ExpandDims(s.WithOpName(output_name()), normalized_image, {0});
+ }
+}
+
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h
new file mode 100644
index 0000000000..15df719817
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h
@@ -0,0 +1,75 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ILSVRC_INCEPTION_PREPROCESSING_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ILSVRC_INCEPTION_PREPROCESSING_H_
+
+#include <utility>
+
+#include "tensorflow/contrib/lite/tools/accuracy/stage.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace metrics {
+
+// A stage that performs Inception-style preprocessing.
+// Inputs: A tensor containing the bytes of a JPEG image.
+// Outputs: A tensor containing the rescaled and preprocessed image, with
+// shape {1, image_height, image_width, 3}, where 3 is the number of channels.
+class InceptionPreprocessingStage : public Stage {
+ public:
+ struct Params {
+ std::vector<float> input_means;
+ float scale;
+ double cropping_fraction;
+ };
+
+ static Params DefaultParams() {
+ return {.input_means = {127.5, 127.5, 127.5},
+ .scale = 127.5,
+ .cropping_fraction = 0.875};
+ }
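+
+  // With the defaults above, a float input pixel p in [0, 255] is normalized
+  // to (p - 127.5) / 127.5, i.e. mapped into the [-1, 1] range used by
+  // Inception-style float models.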
+
+  // Creates a new preprocessing stage with |image_width| and |image_height|
+  // as the size of the output image.
+  // If |is_quantized| is true, the mean/scale normalization from |params| is
+  // skipped and the cropped, resized image is cast to uint8.
+ InceptionPreprocessingStage(int image_width, int image_height,
+ bool is_quantized,
+ Params params = DefaultParams())
+ : image_width_(image_width),
+ image_height_(image_height),
+ is_quantized_(is_quantized),
+ params_(std::move(params)) {}
+
+ string name() const override { return "stage_inception_preprocess"; }
+ string output_name() const override {
+ return "stage_inception_preprocess_output";
+ }
+
+ void AddToGraph(const Scope& scope, const Input& input) override;
+
+ private:
+ int image_width_;
+ int image_height_;
+ bool is_quantized_;
+ Params params_;
+};
+
+} // namespace metrics
+} // namespace tensorflow
+
+#endif  // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ILSVRC_INCEPTION_PREPROCESSING_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing_test.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing_test.cc
new file mode 100644
index 0000000000..3587878ba3
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing_test.cc
@@ -0,0 +1,123 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <fstream>
+#include <string>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace {
+tensorflow::string* g_test_image_file = nullptr;
+} // namespace
+
+namespace tensorflow {
+namespace metrics {
+
+namespace {
+
+using tensorflow::Status;
+using tensorflow::Tensor;
+
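+// Reads the entire contents of the file at |filename| into |output| in binary
+// mode.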
+Status GetContents(const string& filename, string* output) {
+ std::ifstream input(filename, std::ios::binary);
+ const int kBufferSize = 2048;
+ char buffer[kBufferSize];
+ while (true) {
+ input.read(buffer, kBufferSize);
+ output->append(buffer, input.gcount());
+ if (!input.good()) {
+ if (input.eof()) return Status::OK();
+ return Status(tensorflow::error::ABORTED, "Failed to read file.");
+ }
+ }
+}
+
+TEST(InceptionPreprocessingTest, TestImagePreprocessQuantized) {
+ ASSERT_TRUE(g_test_image_file != nullptr);
+ string image_contents;
+ string image_path = *g_test_image_file;
+ auto status = GetContents(image_path, &image_contents);
+ ASSERT_TRUE(status.ok()) << status.error_message();
+ const int width = 224;
+ const int height = 224;
+ const bool is_quantized = true;
+ InceptionPreprocessingStage preprocess_stage(width, height, is_quantized);
+ Scope scope = Scope::NewRootScope();
+ preprocess_stage.AddToGraph(scope, image_contents);
+ TF_CHECK_OK(scope.status());
+
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(session->Create(graph_def));
+ std::vector<Tensor> outputs;
+ auto run_status =
+ session->Run({}, /*inputs*/
+ {preprocess_stage.output_name()}, {}, /*target node names */
+ &outputs);
+ TF_CHECK_OK(run_status);
+ EXPECT_EQ(1, outputs.size());
+ EXPECT_EQ(DT_UINT8, outputs[0].dtype());
+ EXPECT_TRUE(outputs[0].shape().IsSameSize({1, 224, 224, 3}));
+}
+
+TEST(InceptionPreprocessingTest, TestImagePreprocessFloat) {
+ ASSERT_TRUE(g_test_image_file != nullptr);
+ string image_contents;
+ string image_path = *g_test_image_file;
+ auto status = GetContents(image_path, &image_contents);
+ ASSERT_TRUE(status.ok()) << status.error_message();
+ const int width = 224;
+ const int height = 224;
+ const bool is_quantized = false;
+ InceptionPreprocessingStage preprocess_stage(width, height, is_quantized);
+ Scope scope = Scope::NewRootScope();
+ preprocess_stage.AddToGraph(scope, image_contents);
+ TF_CHECK_OK(scope.status());
+
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(session->Create(graph_def));
+ std::vector<Tensor> outputs;
+ auto run_status =
+ session->Run({}, /*inputs*/
+ {preprocess_stage.output_name()}, {}, /*target node names */
+ &outputs);
+ TF_CHECK_OK(run_status);
+ EXPECT_EQ(1, outputs.size());
+ EXPECT_EQ(DT_FLOAT, outputs[0].dtype());
+ EXPECT_TRUE(outputs[0].shape().IsSameSize({1, 224, 224, 3}));
+}
+
+} // namespace
+} // namespace metrics
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ g_test_image_file = new tensorflow::string();
+ const std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("test_image", g_test_image_file,
+ "Path to image file for test."),
+ };
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+  CHECK(parse_result) << "Required flag: test_image";
+ ::tensorflow::port::InitMain(argv[0], &argc, &argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/testdata/grace_hopper.jpg b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/testdata/grace_hopper.jpg
new file mode 100644
index 0000000000..d2a427810f
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/testdata/grace_hopper.jpg
Binary files differ
diff --git a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op.cc b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op.cc
new file mode 100644
index 0000000000..da4258f1c1
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op.cc
@@ -0,0 +1,158 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/op_resolver.h"
+#include "tensorflow/contrib/lite/tools/accuracy/utils.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+namespace {
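+// Validates that |input_tensors| match the interpreter's inputs in count,
+// data type and shape.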
+Status ValidateInputsMatch(const OpInputList& input_tensors,
+ const tflite::Interpreter& interpreter) {
+ std::vector<int> tflite_tensor_indices = interpreter.inputs();
+ if (tflite_tensor_indices.size() != input_tensors.size()) {
+    return errors::InvalidArgument(
+        "Input count mismatch, interpreter expects: ",
+        tflite_tensor_indices.size(), " actual: ", input_tensors.size());
+ }
+
+ for (int i = 0; i < input_tensors.size(); i++) {
+ const TfLiteTensor* tflite_tensor =
+ interpreter.tensor(tflite_tensor_indices[i]);
+ if (tflite_tensor == nullptr) {
+ return errors::InvalidArgument("Tensor is null at index: ", i);
+ }
+
+ const Tensor& tensor = input_tensors[i];
+ auto i_type = metrics::utils::GetTFDataType(tflite_tensor->type);
+ auto i_shape = metrics::utils::GetTFLiteTensorShape(*tflite_tensor);
+ if (i_type != tensor.dtype()) {
+ return errors::InvalidArgument("Data types mismatch for tensors: ", i,
+ " expected: ", i_type,
+ " got: ", tensor.dtype());
+ }
+
+ if (i_shape != tensor.shape()) {
+ return errors::InvalidArgument("Data shapes mismatch for tensors: ", i,
+ " expected: ", i_shape,
+ " got: ", tensor.shape());
+ }
+ }
+
+ return Status::OK();
+}
+
+} // namespace
+
+class RunTFLiteModelOp : public OpKernel {
+ public:
+ explicit RunTFLiteModelOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string model_file_path;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("model_file_path", &model_file_path));
+ model_ = tflite::FlatBufferModel::BuildFromFile(model_file_path.data());
+ OP_REQUIRES(ctx, model_,
+ errors::InvalidArgument(
+ "Model loading failed. Invalid model file path: ",
+ model_file_path));
+ tflite::ops::builtin::BuiltinOpResolver resolver;
+
+ tflite::InterpreterBuilder(*model_, resolver)(&interpreter_);
+ OP_REQUIRES(ctx, interpreter_,
+ errors::Internal("Interpreter creation failed."));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ OpInputList input_tensors;
+ OP_REQUIRES_OK(context, context->input_list("model_input", &input_tensors));
+
+ OP_REQUIRES_OK(context, ValidateInputsMatch(input_tensors, *interpreter_));
+ OpOutputList output_tensors;
+ OP_REQUIRES_OK(context,
+ context->output_list("model_output", &output_tensors));
+ auto tfl_outputs = interpreter_->outputs();
+ OP_REQUIRES(context, output_tensors.size() == tfl_outputs.size(),
+ errors::InvalidArgument(
+ "Invalid output size, expected: ", tfl_outputs.size(),
+ " got: ", output_tensors.size()));
+ for (int i = 0; i < output_tensors.size(); i++) {
+ DataType tfl_type = metrics::utils::GetTFDataType(
+ interpreter_->tensor(tfl_outputs[i])->type);
+ DataType otype = output_tensors.expected_output_dtype(i);
+ OP_REQUIRES(
+ context, tfl_type == otype,
+ errors::InvalidArgument("Invalid data type for output at index: ", i,
+ " expected: ", tfl_type, " got: ", otype));
+ }
+
+ auto allocation_status = interpreter_->AllocateTensors();
+ OP_REQUIRES(context, allocation_status == kTfLiteOk,
+ errors::Internal("Unable to allocate tensors."));
+ for (int i = 0; i < input_tensors.size(); i++) {
+ const int tfl_index = interpreter_->inputs()[i];
+ TfLiteTensor* tflite_tensor = interpreter_->tensor(tfl_index);
+ auto tensor_bytes = input_tensors[i].tensor_data();
+ OP_REQUIRES(context, tflite_tensor->bytes == tensor_bytes.size(),
+ errors::InvalidArgument(
+ "Size mismatch, expected: ", tflite_tensor->bytes,
+ " got: ", tensor_bytes.size()));
+ std::memcpy(tflite_tensor->data.raw, tensor_bytes.data(),
+ tensor_bytes.size());
+ }
+ auto invocation_status = interpreter_->Invoke();
+ OP_REQUIRES(context, invocation_status == kTfLiteOk,
+ errors::Internal("Interpreter invocation failed."));
+ for (int i = 0; i < output_tensors.size(); i++) {
+ auto tfl_tensor = interpreter_->tensor(tfl_outputs[i]);
+ TensorShape shape = metrics::utils::GetTFLiteTensorShape(*tfl_tensor);
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, output_tensors.allocate(i, shape, &output));
+ auto tensor_bytes = output->tensor_data();
+ OP_REQUIRES(context, tensor_bytes.size() == tfl_tensor->bytes,
+ errors::Internal("Invalid size"));
+ std::memcpy(const_cast<char*>(tensor_bytes.data()), tfl_tensor->data.raw,
+ tfl_tensor->bytes);
+ }
+ }
+
+ private:
+ std::unique_ptr<tflite::FlatBufferModel> model_;
+ std::unique_ptr<tflite::Interpreter> interpreter_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("RunTFLiteModel").Device(DEVICE_CPU),
+ RunTFLiteModelOp);
+
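+// Op interface: runs the TFLite model at |model_file_path| on a list of input
+// tensors typed by |input_type| and produces outputs typed by |output_type|.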
+REGISTER_OP("RunTFLiteModel")
+ .Input("model_input: input_type")
+ .Output("model_output: output_type")
+ .Attr("model_file_path: string")
+ .Attr("input_type : list(type)")
+ .Attr("output_type: list(type)")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ // TODO(shashishekhar): Infer the correct shape based on output_type and
+ // maybe another attribute.
+ return shape_inference::UnknownShape(c);
+ });
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op_test.cc b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op_test.cc
new file mode 100644
index 0000000000..88175984a0
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op_test.cc
@@ -0,0 +1,200 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace {
+tensorflow::string* g_test_model_file = nullptr;
+}
+
+namespace tensorflow {
+namespace {
+
+TEST(RunTfliteModelOpTest, ModelIsRun) {
+ ASSERT_TRUE(g_test_model_file != nullptr);
+ string test_model_file = *g_test_model_file;
+ ASSERT_FALSE(test_model_file.empty());
+
+ Scope scope = Scope::NewRootScope();
+ TF_CHECK_OK(scope.status());
+  // The graph under test has 4 inputs (a, b, c, d) and 2 outputs (x, y),
+  // where x = a+b+c and y = b+c+d.
+
+ std::vector<Input> graph_inputs = {
+ ops::Const(scope, 1.0f, {1, 8, 8, 3}), // a
+ ops::Const(scope, 2.1f, {1, 8, 8, 3}), // b
+ ops::Const(scope, 3.2f, {1, 8, 8, 3}), // c
+ ops::Const(scope, 4.3f, {1, 8, 8, 3}), // d
+ };
+
+ std::vector<NodeBuilder::NodeOut> input_data;
+ std::transform(graph_inputs.begin(), graph_inputs.end(),
+ std::back_inserter(input_data), [&scope](Input model_input) {
+ return ops::AsNodeOut(scope, model_input);
+ });
+
+ std::vector<DataType> model_input_type = {DT_FLOAT, DT_FLOAT, DT_FLOAT,
+ DT_FLOAT};
+ ::tensorflow::Node* ret;
+ auto builder = ::tensorflow::NodeBuilder("run_model_op", "RunTFLiteModel")
+ .Input(input_data)
+ .Attr("model_file_path", test_model_file)
+ .Attr("input_type", model_input_type)
+ .Attr("output_type", {DT_FLOAT, DT_FLOAT});
+
+ scope.UpdateBuilder(&builder);
+ scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
+ TF_CHECK_OK(scope.status());
+
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(session->Create(graph_def));
+
+ std::vector<Tensor> outputs;
+ TF_CHECK_OK(
+ session->Run({}, {"run_model_op:0", "run_model_op:1"}, {}, &outputs));
+ EXPECT_EQ(2, outputs.size());
+
+ for (const auto& tensor : outputs) {
+ EXPECT_TRUE(tensor.shape().IsSameSize({1, 8, 8, 3}));
+ }
+ auto output_x = outputs[0].flat<float>();
+ auto output_y = outputs[1].flat<float>();
+ EXPECT_EQ(1 * 8 * 8 * 3, output_x.size());
+ EXPECT_EQ(1 * 8 * 8 * 3, output_y.size());
+ for (int i = 0; i < output_x.size(); i++) {
+ EXPECT_NEAR(6.3f, output_x(i), 1e-6f); // a+b+c
+ EXPECT_NEAR(9.6f, output_y(i), 1e-6f); // b+c+d
+ }
+}
+
+TEST(RunTfliteModelOpTest, NumInputsMismatch) {
+ ASSERT_TRUE(g_test_model_file != nullptr);
+ string test_model_file = *g_test_model_file;
+ ASSERT_FALSE(test_model_file.empty());
+
+ Scope scope = Scope::NewRootScope();
+ TF_CHECK_OK(scope.status());
+  // The graph under test has 4 inputs (a, b, c, d) and 2 outputs (x, y),
+  // where x = a+b+c and y = b+c+d. Here input a is omitted.
+
+ std::vector<Input> graph_inputs = {
+ ops::Const(scope, 2.1f, {1, 8, 8, 3}), // b
+ ops::Const(scope, 3.2f, {1, 8, 8, 3}), // c
+ ops::Const(scope, 4.3f, {1, 8, 8, 3}), // d
+ };
+
+ std::vector<NodeBuilder::NodeOut> input_data;
+ std::transform(graph_inputs.begin(), graph_inputs.end(),
+ std::back_inserter(input_data), [&scope](Input model_input) {
+ return ops::AsNodeOut(scope, model_input);
+ });
+
+ std::vector<DataType> model_input_type = {DT_FLOAT, DT_FLOAT, DT_FLOAT};
+
+ ::tensorflow::Node* ret;
+ auto builder = ::tensorflow::NodeBuilder("run_model_op", "RunTFLiteModel")
+ .Input(input_data)
+ .Attr("model_file_path", test_model_file)
+ .Attr("input_type", model_input_type)
+ .Attr("output_type", {DT_FLOAT, DT_FLOAT});
+
+ scope.UpdateBuilder(&builder);
+ scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
+ TF_CHECK_OK(scope.status());
+
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(session->Create(graph_def));
+
+ std::vector<Tensor> outputs;
+ auto status =
+ (session->Run({}, {"run_model_op:0", "run_model_op:1"}, {}, &outputs));
+ EXPECT_FALSE(status.ok());
+}
+
+TEST(RunTfliteModelOpTest, InputSizesMismatch) {
+ ASSERT_TRUE(g_test_model_file != nullptr);
+ string test_model_file = *g_test_model_file;
+ ASSERT_FALSE(test_model_file.empty());
+
+ Scope scope = Scope::NewRootScope();
+ TF_CHECK_OK(scope.status());
+  // The graph under test has 4 inputs (a, b, c, d) and 2 outputs (x, y),
+  // where x = a+b+c and y = b+c+d. Here input a is given an invalid size.
+ std::vector<Input> graph_inputs = {
+      ops::Const(scope, 1.0f, {1, 8, 8, 4}),  // a (invalid size)
+ ops::Const(scope, 2.1f, {1, 8, 8, 3}), // b
+ ops::Const(scope, 3.2f, {1, 8, 8, 3}), // c
+ ops::Const(scope, 4.3f, {1, 8, 8, 3}), // d
+ };
+
+ std::vector<NodeBuilder::NodeOut> input_data;
+ std::transform(graph_inputs.begin(), graph_inputs.end(),
+ std::back_inserter(input_data), [&scope](Input model_input) {
+ return ops::AsNodeOut(scope, model_input);
+ });
+
+ std::vector<DataType> model_input_type = {DT_FLOAT, DT_FLOAT, DT_FLOAT,
+ DT_FLOAT};
+ ::tensorflow::Node* ret;
+ auto builder = ::tensorflow::NodeBuilder("run_model_op", "RunTFLiteModel")
+ .Input(input_data)
+ .Attr("model_file_path", test_model_file)
+ .Attr("input_type", model_input_type)
+ .Attr("output_type", {DT_FLOAT, DT_FLOAT});
+
+ scope.UpdateBuilder(&builder);
+ scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
+ TF_CHECK_OK(scope.status());
+
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(session->Create(graph_def));
+
+ std::vector<Tensor> outputs;
+ auto status =
+ (session->Run({}, {"run_model_op:0", "run_model_op:1"}, {}, &outputs));
+ EXPECT_FALSE(status.ok());
+}
+
+} // namespace
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ g_test_model_file = new tensorflow::string();
+ const std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("test_model_file", g_test_model_file,
+ "Path to test tflite model file."),
+ };
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+  CHECK(parse_result) << "Required flag: test_model_file";
+ ::tensorflow::port::InitMain(argv[0], &argc, &argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.cc b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.cc
new file mode 100644
index 0000000000..c96795d499
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.cc
@@ -0,0 +1,45 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h"
+
+#include <vector>
+
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+
+namespace tensorflow {
+namespace metrics {
+void RunTFLiteModelStage::AddToGraph(const Scope& scope, const Input& input) {
+ if (!scope.ok()) return;
+ Scope s = scope.WithOpName(name());
+
+  std::vector<NodeBuilder::NodeOut> input_data = {ops::AsNodeOut(s, input)};
+  ::tensorflow::Node* ret;
+  auto builder = NodeBuilder(output_name(), "RunTFLiteModel")
+                     .Input(input_data)
+                     .Attr("model_file_path", params_.model_file_path)
+                     .Attr("input_type", params_.input_type)
+                     .Attr("output_type", params_.output_type);
+
+ s.UpdateBuilder(&builder);
+ s.UpdateStatus(builder.Finalize(s.graph(), &ret));
+ if (!s.ok()) return;
+ s.UpdateStatus(s.DoShapeInference(ret));
+ this->stage_output_ = ::tensorflow::Output(ret, 0);
+}
+
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h
new file mode 100644
index 0000000000..90d12d6f42
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h
@@ -0,0 +1,53 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_RUN_TFLITE_MODEL_STAGE_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_RUN_TFLITE_MODEL_STAGE_H_
+
+#include <string>
+
+#include "tensorflow/contrib/lite/tools/accuracy/stage.h"
+
+namespace tensorflow {
+namespace metrics {
+// Stage that loads and runs a TFLite model.
+// Inputs: The input to the TFLite model.
+// Outputs: The output of running the TFLite model.
+class RunTFLiteModelStage : public Stage {
+ public:
+ // The parameters for the stage.
+ struct Params {
+ string model_file_path;
+ std::vector<TensorShape> output_shape;
+ std::vector<DataType> input_type;
+ std::vector<DataType> output_type;
+ };
+
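+  // A minimal configuration sketch (the path and types are illustrative):
+  //   RunTFLiteModelStage::Params params;
+  //   params.model_file_path = "/tmp/model.tflite";
+  //   params.input_type = {DT_FLOAT};
+  //   params.output_type = {DT_FLOAT};
+  //   RunTFLiteModelStage stage(params);
+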
+ explicit RunTFLiteModelStage(const Params& params) : params_(params) {}
+
+ string name() const override { return "stage_run_tfl_model"; }
+ // TODO(shashishekhar): This stage can have multiple inputs and
+ // outputs, perhaps change the definition of stage.
+ string output_name() const override { return "stage_run_tfl_model_output"; }
+
+ void AddToGraph(const Scope& scope, const Input& input) override;
+
+ private:
+ Params params_;
+};
+
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_RUN_TFLITE_MODEL_STAGE_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/stage.h b/tensorflow/contrib/lite/tools/accuracy/stage.h
new file mode 100644
index 0000000000..8292ea2ec7
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/stage.h
@@ -0,0 +1,56 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_STAGE_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_STAGE_H_
+
+#include "tensorflow/cc/framework/scope.h"
+
+namespace tensorflow {
+namespace metrics {
+
+// A stage in an evaluation pipeline.
+// Each stage adds a subgraph to the pipeline. Stages can be chained
+// together.
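+// Stages are chained by passing one stage's Output() as the Input to the next
+// stage's AddToGraph, e.g. (a sketch; the stage objects are illustrative):
+//   preprocess_stage.AddToGraph(scope, image_bytes);
+//   model_stage.AddToGraph(scope, preprocess_stage.Output());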
+class Stage {
+ public:
+ Stage() = default;
+ Stage(const Stage&) = delete;
+ Stage& operator=(const Stage&) = delete;
+
+ Stage(const Stage&&) = delete;
+ Stage& operator=(const Stage&&) = delete;
+
+  // Adds a subgraph to the given scope that takes `input` as a parameter.
+ virtual void AddToGraph(const Scope& scope, const Input& input) = 0;
+ virtual ~Stage() {}
+
+ // The name of the stage.
+ // Can be used by derived classes for naming the subscope for the stage
+ // graph.
+ virtual string name() const = 0;
+
+ // The name of the output for the stage.
+ virtual string output_name() const = 0;
+
+ const ::tensorflow::Output& Output() const { return stage_output_; }
+
+ protected:
+ ::tensorflow::Output stage_output_;
+};
+} // namespace metrics
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_STAGE_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/utils.cc b/tensorflow/contrib/lite/tools/accuracy/utils.cc
new file mode 100644
index 0000000000..f5493301fc
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/utils.cc
@@ -0,0 +1,102 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/utils.h"
+
+#include <sys/stat.h>
+
+#include <cstring>
+#include <fstream>
+#include <memory>
+#include <string>
+
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/op_resolver.h"
+
+namespace tensorflow {
+namespace metrics {
+
+namespace utils {
+
+DataType GetTFDataType(TfLiteType tflite_type) {
+ switch (tflite_type) {
+ case kTfLiteFloat32:
+ return DT_FLOAT;
+ case kTfLiteUInt8:
+ return DT_UINT8;
+ default:
+ return DT_INVALID;
+ }
+}
+
+TensorShape GetTFLiteTensorShape(const TfLiteTensor& tflite_tensor) {
+ TensorShape shape;
+ for (int i = 0; i < tflite_tensor.dims->size; i++) {
+ shape.AddDim(tflite_tensor.dims->data[i]);
+ }
+ return shape;
+}
+
+Status ReadFileLines(const string& file_path,
+ std::vector<string>* lines_output) {
+ if (!lines_output) {
+ return errors::InvalidArgument("Invalid output");
+ }
+ std::ifstream stream(file_path, std::ios_base::in);
+ if (!stream) {
+ return errors::InvalidArgument("Unable to open file: ", file_path);
+ }
+ std::string line;
+ while (std::getline(stream, line)) {
+ lines_output->push_back(line);
+ }
+ return Status::OK();
+}
+
+Status GetTFliteModelInfo(const string& model_file_path,
+ ModelInfo* model_info) {
+ if (model_file_path.empty()) {
+ return errors::InvalidArgument("Invalid model file.");
+ }
+ struct stat stat_buf;
+ if (stat(model_file_path.c_str(), &stat_buf) != 0) {
+ int error_num = errno;
+ return errors::InvalidArgument("Invalid model file: ", model_file_path,
+ std::strerror(error_num));
+ }
+
+  std::unique_ptr<tflite::FlatBufferModel> model =
+      tflite::FlatBufferModel::BuildFromFile(model_file_path.data());
+  if (!model) {
+    return errors::InvalidArgument("Invalid model file: ", model_file_path);
+  }
+  std::unique_ptr<tflite::Interpreter> interpreter;
+  tflite::ops::builtin::BuiltinOpResolver resolver;
+  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
+ if (!interpreter) {
+ return errors::InvalidArgument("Invalid model", model_file_path);
+ }
+ for (int i : interpreter->inputs()) {
+ TfLiteTensor* tensor = interpreter->tensor(i);
+ model_info->input_shapes.push_back(utils::GetTFLiteTensorShape(*tensor));
+ model_info->input_types.push_back(utils::GetTFDataType(tensor->type));
+ }
+ return Status::OK();
+}
+
+} // namespace utils
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/utils.h b/tensorflow/contrib/lite/tools/accuracy/utils.h
new file mode 100644
index 0000000000..37cbad4d51
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/utils.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_UTILS_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_UTILS_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+namespace metrics {
+
+namespace utils {
+
+struct ModelInfo {
+ std::vector<TensorShape> input_shapes;
+ std::vector<DataType> input_types;
+};
+
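+// Fills |model_info| with the input shapes and types of the TFLite model at
+// |model_file_path|.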
+Status GetTFliteModelInfo(const string& model_file_path, ModelInfo* model_info);
+
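+// Maps a TFLite tensor type to the corresponding TF DataType; returns
+// DT_INVALID for unsupported types.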
+DataType GetTFDataType(TfLiteType tflite_type);
+
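+// Converts a TFLite tensor's dimensions into a TensorShape.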
+TensorShape GetTFLiteTensorShape(const TfLiteTensor& tflite_tensor);
+
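+// Reads the file at |file_path| line by line into |lines_output|.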
+Status ReadFileLines(const string& file_path,
+ std::vector<string>* lines_output);
+} // namespace utils
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_UTILS_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/utils_test.cc b/tensorflow/contrib/lite/tools/accuracy/utils_test.cc
new file mode 100644
index 0000000000..727eba21b6
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/utils_test.cc
@@ -0,0 +1,76 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/tools/accuracy/utils.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace {
+tensorflow::string* g_test_model_file = nullptr;
+}
+
+namespace tensorflow {
+namespace metrics {
+namespace utils {
+namespace {
+
+TEST(UtilsTest, GetTFLiteModelInfoReturnsCorrectly) {
+ ASSERT_TRUE(g_test_model_file != nullptr);
+ string test_model_file = *g_test_model_file;
+ ASSERT_FALSE(test_model_file.empty());
+  // The graph under test has 4 inputs (a, b, c, d) and 2 outputs (x, y),
+  // where x = a+b+c and y = b+c+d. Inputs and outputs all have shape
+  // {1, 8, 8, 3}.
+ ModelInfo model_info;
+ auto status = GetTFliteModelInfo(test_model_file, &model_info);
+ TF_CHECK_OK(status);
+ ASSERT_EQ(4, model_info.input_shapes.size());
+ ASSERT_EQ(4, model_info.input_types.size());
+
+ for (int i = 0; i < 4; i++) {
+ const TensorShape& shape = model_info.input_shapes[i];
+ DataType dataType = model_info.input_types[i];
+ EXPECT_TRUE(shape.IsSameSize({1, 8, 8, 3}));
+ EXPECT_EQ(DT_FLOAT, dataType);
+ }
+}
+
+TEST(UtilsTest, GetTFliteModelInfoIncorrectFile) {
+ ModelInfo model_info;
+ auto status = GetTFliteModelInfo("non_existent_file", &model_info);
+ EXPECT_FALSE(status.ok());
+}
+
+} // namespace
+} // namespace utils
+} // namespace metrics
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ g_test_model_file = new tensorflow::string();
+ const std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("test_model_file", g_test_model_file,
+ "Path to test tflite model file."),
+ };
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+  CHECK(parse_result) << "Required flag: test_model_file";
+ ::tensorflow::port::InitMain(argv[0], &argc, &argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tools/benchmark/BUILD b/tensorflow/contrib/lite/tools/benchmark/BUILD
index 2cb07eb6ec..dc97d22401 100644
--- a/tensorflow/contrib/lite/tools/benchmark/BUILD
+++ b/tensorflow/contrib/lite/tools/benchmark/BUILD
@@ -5,8 +5,8 @@ package(default_visibility = [
licenses(["notice"]) # Apache 2.0
load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
-load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts")
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts")
common_copts = ["-Wall"] + tflite_copts()
@@ -35,6 +35,25 @@ cc_binary(
],
)
+cc_binary(
+ name = "benchmark_model_plus_eager",
+ srcs = [
+ "benchmark_main.cc",
+ ],
+ copts = common_copts + ["-DTFLITE_EXTENDED"],
+ linkopts = tflite_linkopts() + select({
+ "//tensorflow:android": [
+ "-pie", # Android 5.0 and later supports only PIE
+ "-lm", # some builtin ops, e.g., tanh, need -lm
+ ],
+ "//conditions:default": [],
+ }),
+ deps = [
+ ":benchmark_tflite_model_plus_eager_lib",
+ ":logging",
+ ],
+)
+
cc_test(
name = "benchmark_test",
srcs = ["benchmark_test.cc"],
@@ -88,7 +107,25 @@ cc_library(
"//tensorflow/contrib/lite:string_util",
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/profiling:profile_summarizer",
- "//tensorflow/contrib/lite/profiling:profiler",
+ ],
+)
+
+cc_library(
+ name = "benchmark_tflite_model_plus_eager_lib",
+ srcs = [
+ "benchmark_tflite_model.cc",
+ "logging.h",
+ ],
+ hdrs = ["benchmark_tflite_model.h"],
+ copts = common_copts + ["-DTFLITE_EXTENDED"],
+ deps = [
+ ":benchmark_model_lib",
+ ":logging",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/delegates/eager:delegate",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/profiling:profile_summarizer",
],
)
diff --git a/tensorflow/contrib/lite/tools/benchmark/README.md b/tensorflow/contrib/lite/tools/benchmark/README.md
index f1e257ad10..8d997639fb 100644
--- a/tensorflow/contrib/lite/tools/benchmark/README.md
+++ b/tensorflow/contrib/lite/tools/benchmark/README.md
@@ -9,7 +9,7 @@ of runs. Aggregrate latency statistics are reported after running the benchmark.
The instructions below are for running the binary on Desktop and Android,
for iOS please use the
-[iOS benchmark app] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios).
+[iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios).
## Parameters
@@ -17,11 +17,6 @@ The binary takes the following required parameters:
* `graph`: `string` \
The path to the TFLite model file.
-* `input_layer`: `string` \
- The name of the input layer, this is typically the first layer of the model.
-* `input_layer_shape`: `string` \
- The shape of the input layer. This is a comma separated string of the shape
- of tensor of input layer.
and the following optional parameters:
@@ -29,11 +24,13 @@ and the following optional parameters:
The number of threads to use for running TFLite interpreter.
* `warmup_runs`: `int` (default=1) \
The number of warmup runs to do before starting the benchmark.
+* `num_runs`: `int` (default=50) \
+ The number of runs. Increase this to reduce variance.
* `run_delay`: `float` (default=-1.0) \
The delay in seconds between subsequent benchmark runs. Non-positive values
mean use no delay.
* `use_nnapi`: `bool` (default=false) \
- Whether to use [Android NNAPI] (https://developer.android.com/ndk/guides/neuralnetworks/).
+ Whether to use [Android NNAPI](https://developer.android.com/ndk/guides/neuralnetworks/).
This API is available on recent Android devices.
## To build/install/run
@@ -75,8 +72,6 @@ adb push mobilenet_quant_v1_224.tflite /data/local/tmp
```
adb shell /data/local/tmp/benchmark_model \
--graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \
- --input_layer="input" \
- --input_layer_shape="1,224,224,3" \
--num_threads=4
```
@@ -93,13 +88,10 @@ For example:
```
bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model \
--graph=mobilenet_quant_v1_224.tflite \
- --input_layer="Placeholder" \
- --input_layer_shape="1,224,224,3" \
--num_threads=4
```
-The MobileNet graph used as an example here may be downloaded from
-https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip
+The MobileNet graph used as an example here may be downloaded from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip).
## Reducing variance between runs on Android.
@@ -117,8 +109,6 @@ can use the following command:
```
adb shell taskset f0 /data/local/tmp/benchmark_model \
--graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \
- --input_layer="input" \
- --input_layer_shape="1,224,224,3" \
--num_threads=1
```
@@ -205,5 +195,3 @@ Memory (bytes): count=0
Average inference timings in us: Warmup: 83235, Init: 38467, no stats: 79760.9
```
-
-
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h
index 677a1ee68c..cc215a7b7f 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_MODEL_H_
-#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_MODEL_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_
#include <cmath>
#include <limits>
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
index 7f97f5d0cd..02039922b4 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -23,6 +23,9 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#ifdef TFLITE_EXTENDED
+#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
+#endif // TFLITE_EXTENDED
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/op_resolver.h"
@@ -261,6 +264,16 @@ void BenchmarkTfLiteModel::Init() {
bool use_nnapi = params_.Get<bool>("use_nnapi");
interpreter->UseNNAPI(use_nnapi);
+
+#ifdef TFLITE_EXTENDED
+ TFLITE_LOG(INFO) << "Instantiating Eager Delegate";
+ delegate_ = EagerDelegate::Create();
+ if (delegate_) {
+ interpreter->ModifyGraphWithDelegate(delegate_.get(),
+ /*allow_dynamic_tensors=*/true);
+ }
+#endif // TFLITE_EXTENDED
+
auto interpreter_inputs = interpreter->inputs();
if (!inputs.empty()) {
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
index 9931dcbafe..4c4320a998 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
@@ -13,13 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_TFLITE_MODEL_H_
-#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_TFLITE_MODEL_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_TFLITE_MODEL_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_TFLITE_MODEL_H_
#include <memory>
#include <string>
#include <vector>
+#ifdef TFLITE_EXTENDED
+#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
+#endif // TFLITE_EXTENDED
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/profiling/profile_summarizer.h"
#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h"
@@ -52,6 +55,7 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
public:
BenchmarkTfLiteModel();
BenchmarkTfLiteModel(BenchmarkParams params);
+ virtual ~BenchmarkTfLiteModel() {}
std::vector<Flag> GetFlags() override;
void LogParams() override;
@@ -59,7 +63,6 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
uint64_t ComputeInputBytes() override;
void Init() override;
void RunImpl() override;
- virtual ~BenchmarkTfLiteModel() {}
struct InputLayerInfo {
std::string name;
@@ -67,6 +70,9 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
};
private:
+#ifdef TFLITE_EXTENDED
+ std::unique_ptr<EagerDelegate> delegate_;
+#endif // TFLITE_EXTENDED
std::unique_ptr<tflite::FlatBufferModel> model;
std::unique_ptr<tflite::Interpreter> interpreter;
std::vector<InputLayerInfo> inputs;
diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h
index 2e514ae3ea..6a0affd834 100644
--- a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h
+++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_COMMAND_LINE_FLAGS_H_
-#define TENSORFLOW_CONTRIB_LITE_TOOLS_COMMAND_LINE_FLAGS_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_COMMAND_LINE_FLAGS_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_COMMAND_LINE_FLAGS_H_
#include <functional>
#include <string>
diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/README.md b/tensorflow/contrib/lite/tools/benchmark/ios/README.md
index c8d3307e29..46144f7bf8 100644
--- a/tensorflow/contrib/lite/tools/benchmark/ios/README.md
+++ b/tensorflow/contrib/lite/tools/benchmark/ios/README.md
@@ -17,8 +17,8 @@ Mobilenet_1.0_224 model
## To build/install/run
-- Follow instructions at [iOS build for TFLite]
-(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/ios.md)
+- Follow instructions at
+[iOS build for TFLite](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/ios.md)
to build TFLite.
Running
diff --git a/tensorflow/contrib/lite/tools/benchmark/logging.h b/tensorflow/contrib/lite/tools/benchmark/logging.h
index 9e9292e2fe..4045d1e731 100644
--- a/tensorflow/contrib/lite/tools/benchmark/logging.h
+++ b/tensorflow/contrib/lite/tools/benchmark/logging.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_LOGGING_H_
-#define TENSORFLOW_CONTRIB_LITE_TOOLS_LOGGING_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_LOGGING_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_LOGGING_H_
// LOG and CHECK macros for benchmarks.
diff --git a/tensorflow/contrib/lite/tools/make/Makefile b/tensorflow/contrib/lite/tools/make/Makefile
index e30cc1d70e..59bdb10811 100644
--- a/tensorflow/contrib/lite/tools/make/Makefile
+++ b/tensorflow/contrib/lite/tools/make/Makefile
@@ -24,6 +24,21 @@ HOST_ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32
TARGET := $(HOST_OS)
TARGET_ARCH := $(HOST_ARCH)
+INCLUDES := \
+-I. \
+-I$(MAKEFILE_DIR)/../../../../../ \
+-I$(MAKEFILE_DIR)/../../../../../../ \
+-I$(MAKEFILE_DIR)/downloads/ \
+-I$(MAKEFILE_DIR)/downloads/eigen \
+-I$(MAKEFILE_DIR)/downloads/gemmlowp \
+-I$(MAKEFILE_DIR)/downloads/neon_2_sse \
+-I$(MAKEFILE_DIR)/downloads/farmhash/src \
+-I$(MAKEFILE_DIR)/downloads/flatbuffers/include \
+-I$(OBJDIR)
+# This is at the end so any globally-installed frameworks like protobuf don't
+# override local versions in the source tree.
+INCLUDES += -I/usr/local/include
+
# These are the default libraries needed, but they can be added to or
# overridden by the platform-specific settings in target makefiles.
LIBS := \
@@ -44,55 +59,17 @@ ARFLAGS := -r
TARGET_TOOLCHAIN_PREFIX :=
CC_PREFIX :=
-# These target-specific makefiles should modify or replace options like
-# CXXFLAGS or LIBS to work for a specific targetted architecture. All logic
-# based on platforms or architectures should happen within these files, to
-# keep this main makefile focused on the sources and dependencies.
-include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc)
-
-# Where compiled objects are stored.
-GENDIR := $(MAKEFILE_DIR)/gen/$(TARGET)_$(TARGET_ARCH)/
-OBJDIR := $(GENDIR)obj/
-BINDIR := $(GENDIR)bin/
-LIBDIR := $(GENDIR)lib/
-
-INCLUDES := \
--I. \
--I$(MAKEFILE_DIR)/../../../../../ \
--I$(MAKEFILE_DIR)/../../../../../../ \
--I$(MAKEFILE_DIR)/downloads/ \
--I$(MAKEFILE_DIR)/downloads/eigen \
--I$(MAKEFILE_DIR)/downloads/gemmlowp \
--I$(MAKEFILE_DIR)/downloads/neon_2_sse \
--I$(MAKEFILE_DIR)/downloads/farmhash/src \
--I$(MAKEFILE_DIR)/downloads/flatbuffers/include \
--I$(OBJDIR)
-# This is at the end so any globally-installed frameworks like protobuf don't
-# override local versions in the source tree.
-INCLUDES += -I/usr/local/include
-
-CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++
-CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc
-AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar
-
# This library is the main target for this makefile. It will contain a minimal
# runtime that can be linked in to other programs.
LIB_NAME := libtensorflow-lite.a
-LIB_PATH := $(LIBDIR)$(LIB_NAME)
-
-# A small example program that shows how to link against the library.
-MINIMAL_PATH := $(BINDIR)minimal
# Benchmark static library and binary
BENCHMARK_LIB_NAME := benchmark-lib.a
BENCHMARK_BINARY_NAME := benchmark_model
-BENCHMARK_LIB := $(LIBDIR)$(BENCHMARK_LIB_NAME)
-BENCHMARK_BINARY := $(BINDIR)$(BENCHMARK_BINARY_NAME)
+# A small example program that shows how to link against the library.
MINIMAL_SRCS := \
tensorflow/contrib/lite/examples/minimal/minimal.cc
-MINIMAL_OBJS := $(addprefix $(OBJDIR), \
-$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS))))
# What sources we want to compile, must be kept in sync with the main Bazel
# build files.
@@ -105,7 +82,9 @@ PROFILE_SUMMARIZER_SRCS := \
CORE_CC_ALL_SRCS := \
$(wildcard tensorflow/contrib/lite/*.cc) \
-$(wildcard tensorflow/contrib/lite/*.c)
+$(wildcard tensorflow/contrib/lite/*.c) \
+$(wildcard tensorflow/contrib/lite/c/*.c) \
+$(wildcard tensorflow/contrib/lite/core/api/*.cc)
ifneq ($(BUILD_TYPE),micro)
CORE_CC_ALL_SRCS += \
$(wildcard tensorflow/contrib/lite/kernels/*.cc) \
@@ -136,10 +115,6 @@ tensorflow/contrib/lite/nnapi_delegate.cc
endif
# Filter out all the excluded files.
TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS))
-# File names of the intermediate files target compilation generates.
-TF_LITE_CC_OBJS := $(addprefix $(OBJDIR), \
-$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS))))
-LIB_OBJS := $(TF_LITE_CC_OBJS)
# Benchmark sources
BENCHMARK_SRCS_DIR := tensorflow/contrib/lite/tools/benchmark
@@ -151,6 +126,40 @@ BENCHMARK_SRCS := $(filter-out \
$(wildcard $(BENCHMARK_SRCS_DIR)/*_test.cc), \
$(BENCHMARK_ALL_SRCS))
+# These target-specific makefiles should modify or replace options like
+# CXXFLAGS or LIBS to work for a specific targeted architecture. All logic
+# based on platforms or architectures should happen within these files, to
+# keep this main makefile focused on the sources and dependencies.
+include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc)
+
+ALL_SRCS := \
+ $(MINIMAL_SRCS) \
+ $(PROFILER_SRCS) \
+ $(PROFILE_SUMMARIZER_SRCS) \
+ $(TF_LITE_CC_SRCS) \
+ $(BENCHMARK_SRCS)
+
+# Where compiled objects are stored.
+GENDIR := $(MAKEFILE_DIR)/gen/$(TARGET)_$(TARGET_ARCH)/
+OBJDIR := $(GENDIR)obj/
+BINDIR := $(GENDIR)bin/
+LIBDIR := $(GENDIR)lib/
+
+LIB_PATH := $(LIBDIR)$(LIB_NAME)
+BENCHMARK_LIB := $(LIBDIR)$(BENCHMARK_LIB_NAME)
+BENCHMARK_BINARY := $(BINDIR)$(BENCHMARK_BINARY_NAME)
+MINIMAL_BINARY := $(BINDIR)minimal
+
+CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++
+CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc
+AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar
+
+MINIMAL_OBJS := $(addprefix $(OBJDIR), \
+$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS))))
+
+LIB_OBJS := $(addprefix $(OBJDIR), \
+$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS))))
+
BENCHMARK_OBJS := $(addprefix $(OBJDIR), \
$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(BENCHMARK_SRCS))))
@@ -164,7 +173,7 @@ $(OBJDIR)%.o: %.c
$(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@
# The target that's compiled if there's no command-line arguments.
-all: $(LIB_PATH) $(MINIMAL_PATH) $(BENCHMARK_BINARY)
+all: $(LIB_PATH) $(MINIMAL_BINARY) $(BENCHMARK_BINARY)
# The target that's compiled for micro-controllers
micro: $(LIB_PATH)
@@ -178,19 +187,18 @@ $(LIB_PATH): tensorflow/contrib/lite/schema/schema_generated.h $(LIB_OBJS)
@mkdir -p $(dir $@)
$(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS)
-$(MINIMAL_PATH): $(MINIMAL_OBJS) $(LIB_PATH)
+$(MINIMAL_BINARY): $(MINIMAL_OBJS) $(LIB_PATH)
@mkdir -p $(dir $@)
$(CXX) $(CXXFLAGS) $(INCLUDES) \
- -o $(MINIMAL_PATH) $(MINIMAL_OBJS) \
+ -o $(MINIMAL_BINARY) $(MINIMAL_OBJS) \
$(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS)
-
$(BENCHMARK_LIB) : $(LIB_PATH) $(BENCHMARK_OBJS)
@mkdir -p $(dir $@)
$(AR) $(ARFLAGS) $(BENCHMARK_LIB) $(LIB_OBJS) $(BENCHMARK_OBJS)
benchmark_lib: $(BENCHMARK_LIB)
-$(info $(BENCHMARK_BINARY))
+
$(BENCHMARK_BINARY) : $(BENCHMARK_LIB)
@mkdir -p $(dir $@)
$(CXX) $(CXXFLAGS) $(INCLUDES) \
@@ -213,4 +221,4 @@ cleantarget:
$(DEPDIR)/%.d: ;
.PRECIOUS: $(DEPDIR)/%.d
--include $(patsubst %,$(DEPDIR)/%.d,$(basename $(TF_CC_SRCS)))
+-include $(patsubst %,$(DEPDIR)/%.d,$(basename $(ALL_SRCS)))
diff --git a/tensorflow/contrib/lite/tools/optimize/BUILD b/tensorflow/contrib/lite/tools/optimize/BUILD
new file mode 100644
index 0000000000..51ccaedc23
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/optimize/BUILD
@@ -0,0 +1,25 @@
+# TODO(suharshs): Write quantize_weights tests that use small exportable files.
+# Then we can remove this file.
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
+
+cc_library(
+ name = "quantize_weights",
+ srcs = ["quantize_weights.cc"],
+ hdrs = ["quantize_weights.h"],
+ deps = [
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels/internal:tensor_utils",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ "//tensorflow/core:tflite_portable_logging",
+ "@com_google_absl//absl/memory",
+ "@flatbuffers",
+ ],
+)
diff --git a/tensorflow/contrib/lite/tools/optimize/g3doc/quantize_weights.md b/tensorflow/contrib/lite/tools/optimize/g3doc/quantize_weights.md
new file mode 100644
index 0000000000..93fe576583
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/optimize/g3doc/quantize_weights.md
@@ -0,0 +1,70 @@
+# TFLite Quantize Weights Tool
+
+## Recommended usage
+
+The Quantize Weights transformation is integrated with
+[tflite_convert](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md#transformation-flags).
+
+The recommended way to invoke this tool is to add the
+`--post_training_quantize` flag to your existing tflite_convert invocation. For
+example,
+
+```
+tflite_convert \
+ --output_file=/tmp/foo.tflite \
+ --saved_model_dir=/tmp/saved_model \
+ --post_training_quantize
+```
+
+## Overview
+
+The Quantize Weights tool provides a simple way to quantize the weights for a
+float TFLite model.
+
+TODO(raghuramank): Add link to weight quantization tutorial.
+
+### Size reduction
+
+float32 weights will be converted to 8 bit integers. This results in a model
+that is around 1/4th the size of the original model.
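+For example, a hypothetical model with 4 million float32 weight parameters
+(about 16 MB of weight data) would shrink to roughly 4 MB of uint8 weights,
+plus a small amount of per-tensor scale and zero-point metadata.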
+
+### Latency reduction
+
+TFLite also has "hybrid" kernels implemented for many operations. These
+"hybrid" kernels take 8 bit integer weights and float inputs, dynamically
+quantize the input tensor (based on its min and max elements), and do the
+computation using 8 bit integer values. This results in a 2-4x latency
+reduction for "hybrid" kernels. In this mode the inference type is still FLOAT
+since the inputs and outputs of each operation are still float.
+
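+Below is a minimal sketch of the dynamic input quantization step, assuming a
+simple per-tensor asymmetric scheme with a hypothetical `DynamicQuantize`
+helper; the actual hybrid kernels may use a different (e.g. symmetric) scheme:
+
+```
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <vector>
+
+// Quantizes a float tensor to uint8 using its min and max elements.
+// Assumes a non-empty input tensor.
+std::vector<uint8_t> DynamicQuantize(const std::vector<float>& input,
+                                     float* scale, int32_t* zero_point) {
+  // Include 0 in the range so that zero is exactly representable.
+  const float min_val =
+      std::min(0.0f, *std::min_element(input.begin(), input.end()));
+  const float max_val =
+      std::max(0.0f, *std::max_element(input.begin(), input.end()));
+  *scale = (max_val - min_val) / 255.0f;
+  if (*scale == 0.0f) *scale = 1.0f;  // All-zero tensor; any scale works.
+  *zero_point = static_cast<int32_t>(std::round(-min_val / *scale));
+  std::vector<uint8_t> output(input.size());
+  for (size_t i = 0; i < input.size(); ++i) {
+    const int32_t q =
+        *zero_point + static_cast<int32_t>(std::round(input[i] / *scale));
+    output[i] = static_cast<uint8_t>(std::min(255, std::max(0, q)));
+  }
+  return output;
+}
+```
+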
+For operations that do not yet have "hybrid" kernels implemented, we introduce
+a Dequantize operation after the 8 bit integer weights. The Dequantize
+operation converts the weights back to float32 during inference so that the
+original float32 kernels can run. Since dequantized results are cached, the
+latency of each dequantized path is on par with the original float model.
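+Conceptually, an operation `op(w, x)` with float weights `w` is rewritten as
+`op(Dequantize(w_uint8), x)`, where `w_uint8` contains the quantized weights.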
+
+TODO(yunluli): Fill in latency results from latency experiments.
+
+### Accuracy
+
+Since this technique quantizes weights after the model has already been
+trained, there can be accuracy drops depending on the model. For common CNN
+architectures, the observed accuracy drops are small and can be seen below.
+
+TODO(yunluli): Fill in accuracy results from accuracy experiments.
+
+## Direct usage
+
+One can also invoke Quantize Weights directly via C++ given a float
+`::tflite::Model` to convert. A `flatbuffers::FlatBufferBuilder` must be
+provided, which owns the underlying buffer of the created model. Here is an
+example invocation:
+
+```
+::tflite::Model* input_model = ...;
+flatbuffers::FlatBufferBuilder builder;
+TfLiteStatus status = ::tflite::optimize::QuantizeWeights(&builder, input_model);
+CHECK_EQ(status, kTfLiteOk);
+const uint8_t* buffer = builder.GetBufferPointer();
+const tflite::Model* output_model = ::tflite::GetModel(buffer);
+```
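+
+Note that `builder` owns the serialized output, so the `buffer` and
+`output_model` pointers are only valid for the builder's lifetime.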
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
new file mode 100644
index 0000000000..b863108aa4
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
@@ -0,0 +1,432 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/tools/optimize/quantize_weights.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "flatbuffers/flexbuffers.h"
+#include "absl/memory/memory.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tflite {
+namespace optimize {
+
+namespace {
+
+typedef struct {
+ TensorT* tensor;
+ // The index of the tensor to quantize in subgraph->tensors.
+ int32_t tensor_idx;
+ // The index of the weight tensor to be quantized in op->inputs.
+ int32_t op_input_idx;
+ // True if the tensor supports hybrid evaluation.
+ bool eval_hybrid;
+} TensorInfo;
+
+// The default minimum number of elements a weights array must have to be
+// quantized by this transformation.
+const int kWeightsMinNumElementsDefault = 1024;
+
+// Nudge min and max so that floating point 0 falls exactly on a quantized
+// value, returning the nudged scale and zero_point.
+//
+// Although this code originates from FakeQuantization in quantized training,
+// we may deviate from that implementation as we please since we do not
+// fine-tune the weights with quantized training.
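+//
+// For example, min = -1.0f and max = 3.0f with a [0, 255] quantized range
+// yield scale = 4.0 / 255 ~= 0.0157 and
+// zero_point = round(0 - (-1.0) / 0.0157) = 64.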
+void GetAsymmetricQuantizationParams(
+ const float min, const float max, const int quant_min, const int quant_max,
+ QuantizationParametersT* quantization_params) {
+ // Adjust the boundaries to guarantee 0 is included.
+ const float quant_min_float = std::min(static_cast<float>(quant_min), 0.0f);
+ const float quant_max_float = std::max(static_cast<float>(quant_max), 0.0f);
+ const float scale = (max - min) / (quant_max_float - quant_min_float);
+ const float zero_point_from_min = quant_min_float - min / scale;
+ int64_t zero_point;
+ if (zero_point_from_min < quant_min_float) {
+ zero_point = static_cast<int64_t>(quant_min);
+ } else if (zero_point_from_min > quant_max_float) {
+ zero_point = static_cast<int64_t>(quant_max);
+ } else {
+ zero_point = static_cast<int64_t>(std::round(zero_point_from_min));
+ }
+ quantization_params->scale = std::vector<float>(1, scale);
+ quantization_params->zero_point = std::vector<int64_t>(1, zero_point);
+}
+
+// Returns the number of elements in tensor.
+uint64_t NumElements(const TensorT* tensor) {
+ if (tensor->shape.empty()) {
+ LOG(FATAL) << "Tensor has no shape information.";
+ }
+ uint64_t num_elements = 1;
+ for (const uint64_t dim : tensor->shape) {
+ num_elements *= dim;
+ }
+ return num_elements;
+}
+
+uint64_t CountTensorConsumers(const ModelT* model, const SubGraphT* subgraph,
+ int32_t tensor_idx) {
+ uint64_t count = 0;
+ for (int op_idx = 0; op_idx < subgraph->operators.size(); ++op_idx) {
+ const OperatorT* op = subgraph->operators[op_idx].get();
+ if (op == nullptr) {
+ continue;
+ }
+ for (int i = 0; i < op->inputs.size(); ++i) {
+ if (op->inputs[i] == tensor_idx) {
+ count++;
+ }
+ }
+ }
+ return count;
+}
+
+// Gets the list of op->inputs indices of the weights inputs to be quantized for
+// the provided op.
+std::vector<int32_t> GetWeightInputIndices(const BuiltinOperator& op_code) {
+ if (op_code == BuiltinOperator_CONV_2D ||
+ op_code == BuiltinOperator_DEPTHWISE_CONV_2D ||
+ op_code == BuiltinOperator_FULLY_CONNECTED ||
+ op_code == BuiltinOperator_EMBEDDING_LOOKUP) {
+ return {1};
+ } else if (op_code == BuiltinOperator_SVDF) {
+ // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/svdf.cc
+ return {1, 2};
+ } else if (op_code == BuiltinOperator_LSTM ||
+ op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM) {
+ // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/lstm.cc
+ // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+ return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16};
+ } else if (op_code == BuiltinOperator_RNN ||
+ op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
+ // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/basic_rnn.cc
+ // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
+ return {1, 2};
+ } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM) {
+ // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+ return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16,
+ 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 33};
+ } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN) {
+ // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
+ return {1, 2, 4, 5};
+ }
+ return {};
+}
+
+// Returns true if the operator supports hybrid evaluation.
+bool IsHybridEvaluationOp(const OperatorT* op, const BuiltinOperator& op_code) {
+ // Operations that support hybrid evaluation.
+ bool eval_hybrid = false;
+ if (op_code == BuiltinOperator_FULLY_CONNECTED ||
+ op_code == BuiltinOperator_CONV_2D || op_code == BuiltinOperator_SVDF ||
+ op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
+ op_code == BuiltinOperator_RNN ||
+ op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
+ op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN ||
+ op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM ||
+ op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
+ eval_hybrid = true;
+ } else if (op_code == BuiltinOperator_LSTM) {
+ const LSTMOptionsT* options = op->builtin_options.AsLSTMOptions();
+ // Only lstm kernel_type full supports hybrid evaluation.
+ if (options->kernel_type == LSTMKernelType_FULL) {
+ eval_hybrid = true;
+ }
+ }
+ return eval_hybrid;
+}
+
+// Returns a vector of TensorInfos for each input tensor of op that should be
+// quantized.
+std::vector<TensorInfo> GetQuantizableTensorsFromOperator(
+ const ModelT* model, const OperatorT* op, uint64_t weights_min_num_elements,
+ bool use_hybrid_evaluation) {
+ SubGraphT* subgraph = model->subgraphs.at(0).get();
+ const BuiltinOperator op_code =
+ model->operator_codes[op->opcode_index]->builtin_code;
+
+ std::vector<TensorInfo> tensor_infos;
+
+ bool eval_hybrid = use_hybrid_evaluation && IsHybridEvaluationOp(op, op_code);
+
+ std::vector<int32_t> op_input_indices = GetWeightInputIndices(op_code);
+ for (const int32_t op_input_idx : op_input_indices) {
+ int32_t tensor_idx = op->inputs[op_input_idx];
+
+ if (tensor_idx == -1) {
+ LOG(INFO) << "Skipping optional tensor input " << op_input_idx
+ << " of operation " << EnumNameBuiltinOperator(op_code);
+ continue;
+ }
+
+ TensorT* tensor = subgraph->tensors[tensor_idx].get();
+ // TODO(suharshs): Support shared weights. If two tensors share the same
+ // weight array, things may break (e.g. SSD object detection).
+ if (!eval_hybrid &&
+ CountTensorConsumers(model, subgraph, tensor_idx) != 1) {
+ LOG(INFO) << "Skipping quantization of tensor " << tensor->name
+ << " that is shared between multiple multiple operations.";
+ continue;
+ }
+
+ if (tensor->type != TensorType_FLOAT32) {
+ LOG(INFO) << "Skipping quantization of tensor " << tensor->name
+ << " that is not type float.";
+ continue;
+ }
+
+ const uint64_t num_elements = NumElements(tensor);
+ if (num_elements < weights_min_num_elements) {
+ LOG(INFO) << "Skipping quantization of tensor " << tensor->name
+ << " because it has fewer than " << weights_min_num_elements
+ << " elements (" << num_elements << ").";
+ // If one of the weights isn't quantized, then we cannot use the hybrid
+ // kernel for this operation, since it expects everything to be quantized.
+ eval_hybrid = false;
+ continue;
+ }
+
+ TensorInfo tensor_info;
+ tensor_info.eval_hybrid = eval_hybrid;
+ tensor_info.op_input_idx = op_input_idx;
+ tensor_info.tensor_idx = tensor_idx;
+ tensor_info.tensor = tensor;
+
+ tensor_infos.push_back(tensor_info);
+ }
+
+ return tensor_infos;
+}
+
+// Quantizes the tensor using asymmetric quantization with its min and max
+// elements. This is used for weights that feed a Dequantize operation.
+TfLiteStatus AsymmetricQuantizeTensor(ModelT* model, TensorT* tensor) {
+ BufferT* buffer = model->buffers[tensor->buffer].get();
+ float* float_data = reinterpret_cast<float*>(buffer->data.data());
+ const uint64_t num_elements = NumElements(tensor);
+ LOG(INFO) << "Quantizing tensor " << tensor->name << " with " << num_elements
+ << " elements for float evaluation.";
+
+ // Compute the quantization params.
+ float min_value = *std::min_element(float_data, float_data + num_elements);
+ float max_value = *std::max_element(float_data, float_data + num_elements);
+
+ if (tensor->quantization == nullptr) {
+ tensor->quantization = absl::make_unique<QuantizationParametersT>();
+ }
+ GetAsymmetricQuantizationParams(min_value, max_value, 0, 255,
+ tensor->quantization.get());
+
+ // Quantize the buffer.
+ std::vector<uint8_t> quantized_buffer;
+ quantized_buffer.resize(num_elements);
+ const double inverse_scale = 1. / tensor->quantization->scale[0];
+ for (std::size_t i = 0; i < num_elements; i++) {
+ const float src_val = float_data[i];
+ double scaled_val;
+ if (tensor->quantization->scale[0] == 0) {
+ scaled_val = tensor->quantization->zero_point[0];
+ } else {
+ scaled_val =
+ tensor->quantization->zero_point[0] + inverse_scale * src_val;
+ }
+ uint8_t integer_val = static_cast<uint8_t>(std::round(scaled_val));
+ quantized_buffer[i] = integer_val;
+ }
+ model->buffers[tensor->buffer]->data = quantized_buffer;
+
+ // Update the tensor type.
+ tensor->type = TensorType_UINT8;
+
+ return kTfLiteOk;
+}
+
+// Quantizes the tensor using symmetric quantization with its min and max
+// elements. This is needed for operations that have hybrid evaluation
+// implemented.
+TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor) {
+ BufferT* buffer = model->buffers[tensor->buffer].get();
+ float* float_data = reinterpret_cast<float*>(buffer->data.data());
+ const uint64_t num_elements = NumElements(tensor);
+ LOG(INFO) << "Quantizing tensor " << tensor->name << " with " << num_elements
+ << " elements for hybrid evaluation.";
+
+ std::vector<int8_t> quantized_buffer;
+ quantized_buffer.resize(num_elements);
+
+ float min_value, max_value, scaling_factor;
+ tensor_utils::SymmetricQuantizeFloats(float_data, num_elements,
+ quantized_buffer.data(), &min_value,
+ &max_value, &scaling_factor);
+
+ if (tensor->quantization == nullptr) {
+ tensor->quantization = absl::make_unique<QuantizationParametersT>();
+ }
+ tensor->quantization->scale = std::vector<float>(1, scaling_factor);
+ tensor->quantization->zero_point = std::vector<int64_t>(1, 0);
+
+ uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(quantized_buffer.data());
+ model->buffers[tensor->buffer]->data.assign(uint8_buffer,
+ uint8_buffer + num_elements);
+
+ // Update the tensor type.
+ tensor->type = TensorType_UINT8;
+
+ return kTfLiteOk;
+}
+
+// Returns the index of the Dequantize op_code.
+// If a Dequantize op_code doesn't exist, adds it and returns its index.
+int32_t GetOrInsertDequantizeOpCodeIndex(ModelT* model) {
+ for (int i = 0; i < model->operator_codes.size(); ++i) {
+ if (model->operator_codes[i]->builtin_code == BuiltinOperator_DEQUANTIZE) {
+ return i;
+ }
+ }
+ model->operator_codes.push_back(absl::make_unique<OperatorCodeT>());
+ int op_code_idx = model->operator_codes.size() - 1;
+ model->operator_codes[op_code_idx]->builtin_code = BuiltinOperator_DEQUANTIZE;
+ // TODO(suharshs): How should the version be set in this op_code?
+
+ // Return the index of the newly placed OperatorCodeT.
+ return op_code_idx;
+}
+
+// Creates a Dequantize OperatorT object.
+void MakeDequantizeOperator(ModelT* model, std::unique_ptr<OperatorT>* op,
+ int32_t input, int32_t output) {
+ OperatorT* op_raw = new OperatorT;
+ op_raw->opcode_index = GetOrInsertDequantizeOpCodeIndex(model);
+ op_raw->inputs = {input};
+ op_raw->outputs = {output};
+
+ op->reset(op_raw);
+}
+
+// Creates a new TensorT object.
+void MakeTensor(const string& name, const std::vector<int32_t>& shape,
+ std::unique_ptr<TensorT>* tensor) {
+ TensorT* tensor_raw = new TensorT;
+ tensor_raw->name = name;
+ tensor_raw->shape = shape;
+
+ tensor->reset(tensor_raw);
+}
+
+TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model,
+ bool use_hybrid_evaluation,
+ uint64_t weights_min_num_elements) {
+ std::unique_ptr<ModelT> model;
+ model.reset(input_model->UnPack());
+
+ // TODO(suharshs): When models support multiple subgraphs, add support.
+ if (model->subgraphs.size() != 1) {
+ LOG(ERROR) << "Quantize weights tool only supports tflite models with one "
+ "subgraph.";
+ return kTfLiteError;
+ }
+
+ SubGraphT* subgraph = model->subgraphs.at(0).get();
+
+ std::vector<std::unique_ptr<OperatorT>> new_operators;
+ for (int i = 0; i < subgraph->operators.size(); ++i) {
+ OperatorT* op = subgraph->operators[i].get();
+
+ std::vector<TensorInfo> tensor_infos = GetQuantizableTensorsFromOperator(
+ model.get(), op, weights_min_num_elements, use_hybrid_evaluation);
+
+ for (const TensorInfo& tensor_info : tensor_infos) {
+ if (tensor_info.eval_hybrid) {
+ // Quantize the tensor.
+ TF_LITE_ENSURE_STATUS(
+ SymmetricQuantizeTensor(model.get(), tensor_info.tensor));
+ } else {
+ // Quantize the tensor.
+ TF_LITE_ENSURE_STATUS(
+ AsymmetricQuantizeTensor(model.get(), tensor_info.tensor));
+
+ // Create a new tensor to be the output of the dequantize op.
+ std::unique_ptr<TensorT> dequantize_output;
+ MakeTensor(tensor_info.tensor->name + "_dequantize",
+ tensor_info.tensor->shape, &dequantize_output);
+ const int32_t dequantize_output_idx = subgraph->tensors.size();
+ subgraph->tensors.push_back(std::move(dequantize_output));
+
+ // Create the Dequantize operation.
+ std::unique_ptr<OperatorT> dequantize_op;
+ MakeDequantizeOperator(model.get(), &dequantize_op,
+ tensor_info.tensor_idx, dequantize_output_idx);
+
+ // Update the op_input of tensor_idx to dequantize_output_idx.
+ op->inputs[tensor_info.op_input_idx] = dequantize_output_idx;
+
+ // Insert the newly created Dequantize operation.
+ new_operators.push_back(std::move(dequantize_op));
+ }
+ }
+ // After (maybe) quantizing inputs, we copy the operator into the new list.
+ new_operators.push_back(std::move(subgraph->operators[i]));
+ }
+
+ // At this point all unique_ptrs in the original operators are invalid, and
+ // we need to replace them with the new_operators vector.
+ subgraph->operators = std::move(new_operators);
+
+ flatbuffers::Offset<Model> output_model_location =
+ Model::Pack(*builder, model.get());
+ FinishModelBuffer(*builder, output_model_location);
+
+ return kTfLiteOk;
+}
+
+} // namespace
+
+namespace internal {
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model,
+ bool use_hybrid_evaluation) {
+ // By default we require that only weights with at least
+ // kWeightsMinNumElementsDefault elements are quantized.
+ return QuantizeWeightsInternal(builder, input_model, use_hybrid_evaluation,
+ kWeightsMinNumElementsDefault);
+}
+} // namespace internal
+
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model,
+ uint64_t weights_min_num_elements) {
+ return QuantizeWeightsInternal(builder, input_model, true,
+ weights_min_num_elements);
+}
+
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model) {
+ // By default we require that only weights with at least
+ // kWeightsMinNumElementsDefault elements are quantized.
+ return QuantizeWeightsInternal(builder, input_model, true,
+ kWeightsMinNumElementsDefault);
+}
+
+} // namespace optimize
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.h b/tensorflow/contrib/lite/tools/optimize/quantize_weights.h
new file mode 100644
index 0000000000..706f10b87b
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.h
@@ -0,0 +1,57 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_OPTIMIZE_QUANTIZE_WEIGHTS_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_OPTIMIZE_QUANTIZE_WEIGHTS_H_
+
+#include <memory>
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace optimize {
+
+// Quantizes input_model and populates the provided builder with the new model.
+// By default, only weight tensors with at least 1024 elements will be
+// quantized.
+//
+// A tflite::Model can be obtained from the builder with:
+// const uint8_t* buffer = builder->GetBufferPointer();
+// tflite::Model* model = GetModel(buffer);
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model);
+
+// Same as above, but only weights with at least weights_min_num_elements
+// elements will be quantized.
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model,
+ uint64_t weights_min_num_elements);
+
+namespace internal {
+// If use_hybrid_evaluation is false, will disable using hybrid eval for
+// operations that support it.
+//
+// We use this internal QuantizeWeights call to test models with hybrid
+// evaluation disabled.
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model,
+ bool use_hybrid_evaluation);
+} // namespace internal
+
+} // namespace optimize
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_OPTIMIZE_QUANTIZE_WEIGHTS_H_
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc
new file mode 100644
index 0000000000..387b3471c2
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc
@@ -0,0 +1,226 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/tools/optimize/quantize_weights.h"
+
+#include <memory>
+
+#include "flatbuffers/flexbuffers.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace optimize {
+namespace {
+
+class QuantizeWeightsTest : public ::testing::Test {
+ protected:
+ int GetElementsNum(const TensorT* tensor) {
+ int tensor_size = 1;
+ for (const int dim : tensor->shape) {
+ tensor_size *= dim;
+ }
+ return tensor_size;
+ }
+
+ const OperatorT* GetOpWithOutput(const SubGraphT* subgraph,
+ int32_t output_tensor_idx) {
+ for (int i = 0; i < subgraph->operators.size(); ++i) {
+ OperatorT* op = subgraph->operators[i].get();
+ if (std::find(op->outputs.begin(), op->outputs.end(),
+ output_tensor_idx) != op->outputs.end()) {
+ return op;
+ }
+ }
+ return nullptr;
+ }
+
+ void SymmetricDequantizeAndCompare(const BufferT* input_buffer,
+ const BufferT* output_buffer,
+ float scale) {
+ const float* input_buffer_data =
+ reinterpret_cast<const float*>(input_buffer->data.data());
+ const int8_t* output_buffer_data =
+ reinterpret_cast<const int8_t*>(output_buffer->data.data());
+ for (int i = 0; i < output_buffer->data.size(); i++) {
+ float diff = input_buffer_data[i] - (output_buffer_data[i] * scale);
+ ASSERT_TRUE(std::abs(diff) <= scale);
+ }
+ }
+
+ void AsymmetricDequantizeAndCompare(const BufferT* input_buffer,
+ const BufferT* output_buffer, float scale,
+ int64_t zero_point) {
+ const float* input_buffer_data =
+ reinterpret_cast<const float*>(input_buffer->data.data());
+ const uint8_t* output_buffer_data = output_buffer->data.data();
+ for (int i = 0; i < output_buffer->data.size(); i++) {
+ float diff =
+ input_buffer_data[i] - ((output_buffer_data[i] - zero_point) * scale);
+ ASSERT_TRUE(std::abs(diff) <= scale);
+ }
+ }
+
+ void CheckWeights(const Model* input_model_packed,
+ const Model* output_model_packed,
+ bool use_hybrid_evaluation,
+ uint64_t weights_min_num_elements = 1024) {
+ std::unique_ptr<ModelT> input_model;
+ input_model.reset(input_model_packed->UnPack());
+
+ std::unique_ptr<ModelT> output_model;
+ output_model.reset(output_model_packed->UnPack());
+
+ SubGraphT* subgraph = output_model->subgraphs.at(0).get();
+
+ for (int i = 0; i < subgraph->operators.size(); ++i) {
+ OperatorT* op = subgraph->operators[i].get();
+ const BuiltinOperator op_code =
+ output_model->operator_codes[op->opcode_index]->builtin_code;
+
+ // These are the operations that should be quantized.
+ // TODO(suharshs): Right now this test only checks the relevant operations
+ // for the mobilenet v1 model used in the tests below.
+ int32_t tensor_idx;
+ if (op_code == BuiltinOperator_CONV_2D ||
+ op_code == BuiltinOperator_DEPTHWISE_CONV_2D ||
+ op_code == BuiltinOperator_FULLY_CONNECTED) {
+ tensor_idx = op->inputs[1];
+ } else {
+ continue;
+ }
+
+ bool eval_hybrid = false;
+ // These are the ops that support hybrid evaluation.
+ if (op_code == BuiltinOperator_FULLY_CONNECTED ||
+ op_code == BuiltinOperator_CONV_2D) {
+ eval_hybrid = true;
+ }
+
+ const TensorT* tensor = subgraph->tensors[tensor_idx].get();
+ int tensor_size = GetElementsNum(tensor);
+ // If tensor_size is less than weights_min_num_elements we expect the
+ // tensor to remain unquantized.
+ if (tensor_size < weights_min_num_elements) {
+ ASSERT_TRUE(tensor->type == TensorType_FLOAT32)
+ << tensor->name << " of type " << tensor->type;
+ const OperatorT* preceding_op = GetOpWithOutput(subgraph, tensor_idx);
+ // The weight tensor should not come from a dequantize op.
+ ASSERT_TRUE(preceding_op == nullptr);
+ } else if (use_hybrid_evaluation && eval_hybrid) {
+ // The input to the op should still be uint8.
+ ASSERT_TRUE(tensor->type == TensorType_UINT8) << tensor->name;
+ // The weight tensor should not come from a dequantize op.
+ const OperatorT* preceding_op = GetOpWithOutput(subgraph, tensor_idx);
+ ASSERT_TRUE(preceding_op == nullptr);
+
+ // Test symmetric quantization.
+ SymmetricDequantizeAndCompare(
+ input_model->buffers[tensor->buffer].get(),
+ output_model->buffers[tensor->buffer].get(),
+ tensor->quantization->scale[0]);
+
+ } else {
+ // The input to the op should still be float.
+ ASSERT_TRUE(tensor->type == TensorType_FLOAT32) << tensor->name;
+ const OperatorT* preceding_op = GetOpWithOutput(subgraph, tensor_idx);
+ ASSERT_TRUE(preceding_op != nullptr);
+ // The float input should be the dequantize output.
+ ASSERT_TRUE(output_model->operator_codes[preceding_op->opcode_index]
+ ->builtin_code == BuiltinOperator_DEQUANTIZE);
+ // Finally, ensure that the input to the dequantize operation is
+ // quantized.
+ const TensorT* quantized_tensor =
+ subgraph->tensors[preceding_op->inputs[0]].get();
+ ASSERT_TRUE(quantized_tensor->type == TensorType_UINT8);
+
+ // Test the asymmetric quantization.
+ AsymmetricDequantizeAndCompare(
+ input_model->buffers[quantized_tensor->buffer].get(),
+ output_model->buffers[quantized_tensor->buffer].get(),
+ quantized_tensor->quantization->scale[0],
+ quantized_tensor->quantization->zero_point[0]);
+ }
+ }
+ }
+};
+
+TEST_F(QuantizeWeightsTest, SimpleTestWithHybrid) {
+ string model_path =
+ "third_party/tensorflow/contrib/lite/tools/optimize/testdata/"
+ "mobilenet_v1_0.25_128.tflite";
+ std::unique_ptr<FlatBufferModel> input_fb =
+ FlatBufferModel::BuildFromFile(model_path.data());
+ const Model* input_model = input_fb->GetModel();
+
+ flatbuffers::FlatBufferBuilder builder;
+ EXPECT_EQ(QuantizeWeights(&builder, input_model), kTfLiteOk);
+
+ const uint8_t* buffer = builder.GetBufferPointer();
+ const Model* output_model = GetModel(buffer);
+
+ CheckWeights(input_model, output_model, true);
+}
+
+TEST_F(QuantizeWeightsTest, SimpleTestWithoutHybrid) {
+ string model_path =
+ "third_party/tensorflow/contrib/lite/tools/optimize/testdata/"
+ "mobilenet_v1_0.25_128.tflite";
+ std::unique_ptr<FlatBufferModel> input_fb =
+ FlatBufferModel::BuildFromFile(model_path.data());
+ const Model* input_model = input_fb->GetModel();
+
+ flatbuffers::FlatBufferBuilder builder;
+ // Disable hybrid evaluation.
+ EXPECT_EQ(internal::QuantizeWeights(&builder, input_model, false), kTfLiteOk);
+
+ const uint8_t* buffer = builder.GetBufferPointer();
+ const Model* output_model = GetModel(buffer);
+
+ CheckWeights(input_model, output_model, false);
+}
+
+TEST_F(QuantizeWeightsTest, SimpleTestWithWeightsMinNumElements) {
+ string model_path =
+ "third_party/tensorflow/contrib/lite/tools/optimize/testdata/"
+ "mobilenet_v1_0.25_128.tflite";
+ std::unique_ptr<FlatBufferModel> input_fb =
+ FlatBufferModel::BuildFromFile(model_path.data());
+ const Model* input_model = input_fb->GetModel();
+
+ flatbuffers::FlatBufferBuilder builder;
+ // Make weights_min_num_elements sufficiently large so that no quantization
+ // happens, i.e. the output model is identical to the input model.
+ const uint64_t kWeightsMinNumElements = 1000000;
+ EXPECT_EQ(QuantizeWeights(&builder, input_model, kWeightsMinNumElements),
+ kTfLiteOk);
+
+ const uint8_t* buffer = builder.GetBufferPointer();
+ const Model* output_model = GetModel(buffer);
+ CheckWeights(input_model, output_model, true, kWeightsMinNumElements);
+}
+
+// TODO(suharshs): Add tests that run the resulting model.
+
+} // namespace
+} // namespace optimize
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: FLAGS_logtostderr = true;
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tutorials/BUILD b/tensorflow/contrib/lite/tutorials/BUILD
new file mode 100644
index 0000000000..67ff1ea124
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/BUILD
@@ -0,0 +1,20 @@
+# Example Estimator model
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_binary(
+ name = "mnist_tflite",
+ srcs = [
+ "dataset.py",
+ "mnist_tflite.py",
+ ],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
diff --git a/tensorflow/contrib/lite/tutorials/dataset.py b/tensorflow/contrib/lite/tutorials/dataset.py
new file mode 100644
index 0000000000..ba49dfcc9b
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/dataset.py
@@ -0,0 +1,122 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""tf.data.Dataset interface to the MNIST dataset.
+
+ This is cloned from
+ https://github.com/tensorflow/models/blob/master/official/mnist/dataset.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import shutil
+import tempfile
+
+import numpy as np
+from six.moves import urllib
+import tensorflow as tf
+
+
+def read32(bytestream):
+ """Read 4 bytes from bytestream as an unsigned 32-bit integer."""
+ dt = np.dtype(np.uint32).newbyteorder('>')
+ return np.frombuffer(bytestream.read(4), dtype=dt)[0]
+
+
+def check_image_file_header(filename):
+ """Validate that filename corresponds to images for the MNIST dataset."""
+ with tf.gfile.Open(filename, 'rb') as f:
+ magic = read32(f)
+ read32(f) # num_images, unused
+ rows = read32(f)
+ cols = read32(f)
+ if magic != 2051:
+ raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
+ f.name))
+ if rows != 28 or cols != 28:
+ raise ValueError(
+ 'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %
+ (f.name, rows, cols))
+
+
+def check_labels_file_header(filename):
+ """Validate that filename corresponds to labels for the MNIST dataset."""
+ with tf.gfile.Open(filename, 'rb') as f:
+ magic = read32(f)
+ read32(f) # num_items, unused
+ if magic != 2049:
+ raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
+ f.name))
+
+
+def download(directory, filename):
+ """Download (and unzip) a file from the MNIST dataset if not already done."""
+ filepath = os.path.join(directory, filename)
+ if tf.gfile.Exists(filepath):
+ return filepath
+ if not tf.gfile.Exists(directory):
+ tf.gfile.MakeDirs(directory)
+ # CVDF mirror of http://yann.lecun.com/exdb/mnist/
+ url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
+ _, zipped_filepath = tempfile.mkstemp(suffix='.gz')
+ print('Downloading %s to %s' % (url, zipped_filepath))
+ urllib.request.urlretrieve(url, zipped_filepath)
+ with gzip.open(zipped_filepath, 'rb') as f_in, \
+ tf.gfile.Open(filepath, 'wb') as f_out:
+ shutil.copyfileobj(f_in, f_out)
+ os.remove(zipped_filepath)
+ return filepath
+
+
+def dataset(directory, images_file, labels_file):
+ """Download and parse MNIST dataset."""
+
+ images_file = download(directory, images_file)
+ labels_file = download(directory, labels_file)
+
+ check_image_file_header(images_file)
+ check_labels_file_header(labels_file)
+
+ def decode_image(image):
+ # Normalize from [0, 255] to [0.0, 1.0]
+ image = tf.decode_raw(image, tf.uint8)
+ image = tf.cast(image, tf.float32)
+ image = tf.reshape(image, [784])
+ return image / 255.0
+
+ def decode_label(label):
+ label = tf.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8]
+ label = tf.reshape(label, []) # label is a scalar
+ return tf.to_int32(label)
+
+ images = tf.data.FixedLengthRecordDataset(
+ images_file, 28 * 28, header_bytes=16).map(decode_image)
+ labels = tf.data.FixedLengthRecordDataset(
+ labels_file, 1, header_bytes=8).map(decode_label)
+ return tf.data.Dataset.zip((images, labels))
+
+
+def train(directory):
+ """tf.data.Dataset object for MNIST training data."""
+ return dataset(directory, 'train-images-idx3-ubyte',
+ 'train-labels-idx1-ubyte')
+
+
+def test(directory):
+ """tf.data.Dataset object for MNIST test data."""
+ return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')
diff --git a/tensorflow/contrib/lite/tutorials/mnist_tflite.py b/tensorflow/contrib/lite/tutorials/mnist_tflite.py
new file mode 100644
index 0000000000..7b8bf5b5db
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/mnist_tflite.py
@@ -0,0 +1,87 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script to evaluate accuracy of TFLite flatbuffer model on mnist dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+import tensorflow as tf # pylint: disable=g-bad-import-order
+from tensorflow.contrib.lite.tutorials import dataset
+flags = tf.app.flags
+
+flags.DEFINE_string('data_dir', '/tmp/data_dir',
+ 'Directory where data is stored.')
+flags.DEFINE_string('model_file', '',
+ 'The path to the TFLite flatbuffer model file.')
+
+
+flags = flags.FLAGS
+
+
+def test_image_generator():
+ """Yields (image, label) pairs from the MNIST test set."""
+ with tf.Session() as sess:
+ input_data = dataset.test(
+ flags.data_dir).make_one_shot_iterator().get_next()
+ try:
+ while True:
+ yield sess.run(input_data)
+ except tf.errors.OutOfRangeError:
+ pass
+
+
+def run_eval(interpreter, input_image):
+ """Performs evaluation for input image over specified model.
+
+ Args:
+ interpreter: TFLite interpreter initialized with model to execute.
+ input_image: Image input to the model.
+
+ Returns:
+ output: output tensor of model being executed.
+ """
+
+ # Get input and output tensors.
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+
+ # Test model on the input images.
+ input_image = np.reshape(input_image, input_details[0]['shape'])
+ interpreter.set_tensor(input_details[0]['index'], input_image)
+
+ interpreter.invoke()
+ output_data = interpreter.get_tensor(output_details[0]['index'])
+ output = np.squeeze(output_data)
+ return output
+
+
+def main(_):
+ interpreter = tf.contrib.lite.Interpreter(model_path=flags.model_file)
+ interpreter.allocate_tensors()
+ num_correct, total = 0, 0
+ for input_data in test_image_generator():
+ output = run_eval(interpreter, input_data[0])
+ total += 1
+ if output == input_data[1]:
+ num_correct += 1
+ if total % 500 == 0:
+ print('Accuracy after %i images: %f' %
+ (total, float(num_correct) / float(total)))
+
+
+if __name__ == '__main__':
+ tf.logging.set_verbosity(tf.logging.INFO)
+ tf.app.run(main)
diff --git a/tensorflow/contrib/lite/util.cc b/tensorflow/contrib/lite/util.cc
index 8ccb65c24f..7950653da9 100644
--- a/tensorflow/contrib/lite/util.cc
+++ b/tensorflow/contrib/lite/util.cc
@@ -14,8 +14,15 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/util.h"
+#include <cstring>
+
namespace tflite {
+bool IsEagerOp(const char* custom_name) {
+ return custom_name && strncmp(custom_name, kEagerCustomCodePrefix,
+ strlen(kEagerCustomCodePrefix)) == 0;
+}
+
TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input) {
return ConvertArrayToTfLiteIntArray(input.size(), input.data());
}
diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/contrib/lite/util.h
index 3c4801183b..6d81f844f8 100644
--- a/tensorflow/contrib/lite/util.h
+++ b/tensorflow/contrib/lite/util.h
@@ -22,10 +22,20 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_UTIL_H_
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
+// The prefix of Eager op custom code.
+// This will be matched against the `custom_code` field in `OperatorCode`
+// Flatbuffer Table.
+// WARNING: This is an experimental API and subject to change.
+constexpr char kEagerCustomCodePrefix[] = "Eager";
+
+// Checks whether the prefix of the custom name indicates the operation is an
+// Eager operation.
+bool IsEagerOp(const char* custom_name);
+
// Converts a `std::vector` to a `TfLiteIntArray`. The caller takes ownership
// of the returned pointer.
TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input);
diff --git a/tensorflow/contrib/lite/util_test.cc b/tensorflow/contrib/lite/util_test.cc
index 04579c53aa..c5c1709f1d 100644
--- a/tensorflow/contrib/lite/util_test.cc
+++ b/tensorflow/contrib/lite/util_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/util.h"
namespace tflite {
@@ -41,6 +41,16 @@ TEST(ConvertVectorToTfLiteIntArray, TestWithEmptyVector) {
TfLiteIntArrayFree(output);
}
+TEST(UtilTest, IsEagerOp) {
+ EXPECT_TRUE(IsEagerOp("Eager"));
+ EXPECT_TRUE(IsEagerOp("EagerOp"));
+ EXPECT_FALSE(IsEagerOp("eager"));
+ EXPECT_FALSE(IsEagerOp("Eage"));
+ EXPECT_FALSE(IsEagerOp("OpEager"));
+ EXPECT_FALSE(IsEagerOp(nullptr));
+ EXPECT_FALSE(IsEagerOp(""));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
index 291972cce3..f83765a48d 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -18,6 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
+
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_lookup_ops
@@ -39,6 +42,7 @@ from tensorflow.python.ops.lookup_ops import TextFileIndex
from tensorflow.python.ops.lookup_ops import TextFileInitializer
from tensorflow.python.ops.lookup_ops import TextFileStringTableInitializer
# pylint: enable=unused-import
+from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.saver import BaseSaverBuilder
from tensorflow.python.util.deprecation import deprecated
@@ -285,7 +289,7 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None):
return table.lookup(tensor)
-class MutableHashTable(LookupInterface):
+class MutableHashTable(LookupInterface, checkpointable.CheckpointableBase):
"""A generic mutable hash table implementation.
Data can be inserted by calling the insert method. It does not support
@@ -336,6 +340,13 @@ class MutableHashTable(LookupInterface):
dtype=value_dtype)
self._value_shape = self._default_value.get_shape()
+ executing_eagerly = context.executing_eagerly()
+ if executing_eagerly and shared_name is None:
+ # TODO(allenl): This will leak memory due to kernel caching by the
+ # shared_name attribute value (but is better than the alternative of
+ # sharing everything by default when executing eagerly; hopefully creating
+ # tables in a loop is uncommon).
+ shared_name = "table_%d" % (ops.uid(),)
# The table must be shared if checkpointing is requested for multi-worker
# training to work correctly. Use the node name if no shared_name has been
# explicitly specified.
@@ -355,9 +366,12 @@ class MutableHashTable(LookupInterface):
value_dtype=value_dtype,
value_shape=self._default_value.get_shape(),
name=name)
+ if executing_eagerly:
+ op_name = None
+ else:
+ op_name = self._table_ref.op.name.split("/")[-1]
super(MutableHashTable, self).__init__(key_dtype, value_dtype,
- self._table_ref.op.name.split(
- "/")[-1])
+ op_name)
if checkpoint:
saveable = MutableHashTable._Saveable(self, name)
@@ -446,6 +460,10 @@ class MutableHashTable(LookupInterface):
self._table_ref, self._key_dtype, self._value_dtype, name=name)
return exported_keys, exported_values
+ def _gather_saveables_for_checkpoint(self):
+ """For object-based checkpointing."""
+ return {"table": functools.partial(MutableHashTable._Saveable, table=self)}
+
class _Saveable(BaseSaverBuilder.SaveableObject):
"""SaveableObject implementation for MutableHashTable."""
@@ -458,14 +476,15 @@ class MutableHashTable(LookupInterface):
# pylint: disable=protected-access
super(MutableHashTable._Saveable, self).__init__(table, specs, name)
- def restore(self, restored_tensors, unused_restored_shapes):
+ def restore(self, restored_tensors, restored_shapes):
+ del restored_shapes # unused
# pylint: disable=protected-access
with ops.colocate_with(self.op._table_ref):
return gen_lookup_ops.lookup_table_import_v2(
self.op._table_ref, restored_tensors[0], restored_tensors[1])
-class MutableDenseHashTable(LookupInterface):
+class MutableDenseHashTable(LookupInterface, checkpointable.CheckpointableBase):
"""A generic mutable hash table implementation using tensors as backing store.
Data can be inserted by calling the insert method. It does not support
@@ -536,6 +555,13 @@ class MutableDenseHashTable(LookupInterface):
use_node_name_sharing = checkpoint and shared_name is None
empty_key = ops.convert_to_tensor(
empty_key, dtype=key_dtype, name="empty_key")
+ executing_eagerly = context.executing_eagerly()
+ if executing_eagerly and shared_name is None:
+ # TODO(allenl): This will leak memory due to kernel caching by the
+ # shared_name attribute value (but is better than the alternative of
+ # sharing everything by default when executing eagerly; hopefully creating
+ # tables in a loop is uncommon).
+ shared_name = "table_%d" % (ops.uid(),)
self._table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(
empty_key=empty_key,
shared_name=shared_name,
@@ -544,8 +570,12 @@ class MutableDenseHashTable(LookupInterface):
value_shape=self._value_shape,
initial_num_buckets=initial_num_buckets,
name=name)
+ if executing_eagerly:
+ op_name = None
+ else:
+ op_name = self._table_ref.op.name.split("/")[-1]
super(MutableDenseHashTable, self).__init__(
- key_dtype, value_dtype, self._table_ref.op.name.split("/")[-1])
+ key_dtype, value_dtype, op_name)
if checkpoint:
saveable = MutableDenseHashTable._Saveable(self, name)
@@ -636,6 +666,11 @@ class MutableDenseHashTable(LookupInterface):
return exported_keys, exported_values
+ def _gather_saveables_for_checkpoint(self):
+ """For object-based checkpointing."""
+ return {"table": functools.partial(
+ MutableDenseHashTable._Saveable, table=self)}
+
class _Saveable(BaseSaverBuilder.SaveableObject):
"""SaveableObject implementation for MutableDenseHashTable."""
@@ -648,7 +683,8 @@ class MutableDenseHashTable(LookupInterface):
# pylint: disable=protected-access
super(MutableDenseHashTable._Saveable, self).__init__(table, specs, name)
- def restore(self, restored_tensors, unused_restored_shapes):
+ def restore(self, restored_tensors, restored_shapes):
+ del restored_shapes # unused
# pylint: disable=protected-access
with ops.colocate_with(self.op._table_ref):
return gen_lookup_ops.lookup_table_import_v2(
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 81257e1de5..89b538d1ba 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -38,12 +38,13 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
+from tensorflow.python.training.checkpointable import util as checkpointable
class HashTableOpTest(test.TestCase):
def testHashTable(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -67,7 +68,7 @@ class HashTableOpTest(test.TestCase):
self.assertItemsEqual([0, 1, 2], exported_values_tensor.eval())
def testHashTableFindHighRank(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -85,7 +86,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([[0, 1], [-1, -1]], result)
def testHashTableInitWithPythonArrays(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = ["brain", "salad", "surgery"]
values = [0, 1, 2]
@@ -104,7 +105,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testHashTableInitWithNumPyArrays(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = np.array(["brain", "salad", "surgery"], dtype=np.str)
values = np.array([0, 1, 2], dtype=np.int64)
@@ -121,7 +122,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testMultipleHashTables(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -149,7 +150,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], out3)
def testHashTableWithTensorDefault(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant(-1, dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -164,7 +165,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testHashTableWithSparseTensorInput(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_val = constant_op.constant(-1, dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -187,7 +188,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual(sp_shape, out_shape)
def testSignatureMismatch(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -209,7 +210,7 @@ class HashTableOpTest(test.TestCase):
lookup.KeyValueTensorInitializer(keys, values), "UNK")
def testDTypes(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
with self.assertRaises(TypeError):
lookup.HashTable(
@@ -217,7 +218,7 @@ class HashTableOpTest(test.TestCase):
dtypes.int64), default_val)
def testNotInitialized(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
table = lookup.HashTable(
lookup.KeyValueTensorInitializer(
@@ -231,7 +232,7 @@ class HashTableOpTest(test.TestCase):
output.eval()
def testInitializeTwice(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -243,7 +244,7 @@ class HashTableOpTest(test.TestCase):
table.init.run()
def testInitializationWithInvalidDimensions(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
@@ -282,7 +283,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual(3, table.size().eval())
def testHashTableInt32String(self):
- with self.test_session():
+ with self.cached_session():
default_val = "n/a"
keys = constant_op.constant([0, 1, 2], dtypes.int32)
values = constant_op.constant(["brain", "salad", "surgery"])
@@ -300,7 +301,7 @@ class HashTableOpTest(test.TestCase):
class MutableHashTableOpTest(test.TestCase):
def testMutableHashTable(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -332,7 +333,7 @@ class MutableHashTableOpTest(test.TestCase):
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
v0 = variables.Variable(10.0, name="v0")
v1 = variables.Variable(20.0, name="v1")
@@ -357,7 +358,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertTrue(isinstance(val, six.string_types))
self.assertEqual(save_path, val)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
v0 = variables.Variable(-1.0, name="v0")
v1 = variables.Variable(-1.0, name="v1")
default_val = -1
@@ -383,6 +384,59 @@ class MutableHashTableOpTest(test.TestCase):
output = table.lookup(input_string)
self.assertAllEqual([-1, 0, 1, 2, -1], output.eval())
+ @test_util.run_in_graph_and_eager_modes
+ def testObjectSaveRestore(self):
+ save_dir = os.path.join(self.get_temp_dir(), "save_restore")
+ save_prefix = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
+
+ v0 = variables.Variable(10.0, name="v0")
+ v1 = variables.Variable(20.0, name="v1")
+
+ default_val = -1
+ keys = constant_op.constant(["b", "c", "d"], dtypes.string)
+ values = constant_op.constant([0, 1, 2], dtypes.int64)
+ table = lookup.MutableHashTable(
+ dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True)
+
+ checkpoint = checkpointable.Checkpoint(table=table, v0=v0, v1=v1)
+ self.evaluate([v0.initializer, v1.initializer])
+
+ # Check that the parameter nodes have been initialized.
+ self.assertEqual(10.0, self.evaluate(v0))
+ self.assertEqual(20.0, self.evaluate(v1))
+
+ self.assertAllEqual(0, self.evaluate(table.size()))
+ self.evaluate(table.insert(keys, values))
+ self.assertAllEqual(3, self.evaluate(table.size()))
+
+ save_path = checkpoint.save(save_prefix)
+ del table, checkpoint, v0, v1
+
+ v0 = variables.Variable(-1.0, name="v0")
+ v1 = variables.Variable(-1.0, name="v1")
+ default_val = -1
+ table = lookup.MutableHashTable(
+ dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True)
+ self.evaluate(table.insert(
+ constant_op.constant(["a", "c"], dtypes.string),
+ constant_op.constant([12, 24], dtypes.int64)))
+ self.assertAllEqual(2, self.evaluate(table.size()))
+
+ checkpoint = checkpointable.Checkpoint(table=table, v0=v0, v1=v1)
+
+ # Restore the saved values in the parameter nodes.
+ checkpoint.restore(save_path).run_restore_ops()
+ # Check that the parameter nodes have been restored.
+ self.assertEqual(10.0, self.evaluate(v0))
+ self.assertEqual(20.0, self.evaluate(v1))
+
+ self.assertAllEqual(3, self.evaluate(table.size()))
+
+ input_string = constant_op.constant(["a", "b", "c", "d", "e"],
+ dtypes.string)
+ output = table.lookup(input_string)
+ self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output))
+
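The new testObjectSaveRestore exercises object-based checkpointing: checkpointable.Checkpoint tracks the table (and any variables) as named dependencies, save() writes them under a prefix, and restore(...).run_restore_ops() repopulates a freshly constructed table, replacing whatever entries it already held (the pre-inserted "a"/"c" pairs above disappear). Condensed to its core, using the same module aliases as this file:

    # Core of the object-based round trip shown above.
    table = lookup.MutableHashTable(
        dtypes.string, dtypes.int64, default_value=-1,
        name="t1", checkpoint=True)
    ckpt = checkpointable.Checkpoint(table=table)
    save_path = ckpt.save(save_prefix)         # writes the table contents
    # A later Checkpoint wrapping a new (or mutated) table restores them:
    ckpt.restore(save_path).run_restore_ops()  # replaces current entries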
def testSharing(self):
# Start a server to store the table state
server = server_lib.Server(
@@ -416,7 +470,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([b"-", b"a", b"b"], output.eval())
def testMutableHashTableOfTensors(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant([-1, -1], dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
@@ -446,7 +500,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([[4, 5], [2, 3], [0, 1]], sorted_values)
def testMutableHashTableExportInsert(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant([-1, -1], dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
@@ -477,7 +531,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual(expected_output, output2.eval())
def testMutableHashTableOfTensorsInvalidShape(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant([-1, -1], dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
@@ -509,7 +563,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual(3, table.size().eval())
def testMutableHashTableInvalidDefaultValue(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant([[-1, -1]], dtypes.int64)
table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
default_val)
@@ -517,7 +571,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual(0, table.size().eval())
def testMutableHashTableDuplicateInsert(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery", "brain"])
values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
@@ -535,7 +589,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([3, 1, -1], result)
def testMutableHashTableFindHighRank(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -554,7 +608,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([[0, 1], [-1, -1]], result)
def testMutableHashTableInsertHighRank(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]])
values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64)
@@ -571,7 +625,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, 3, -1], result)
def testMutableHashTableOfTensorsFindHighRank(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant([-1, -1, -1], dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]],
@@ -592,7 +646,7 @@ class MutableHashTableOpTest(test.TestCase):
[[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result)
def testMultipleMutableHashTables(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -622,7 +676,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], out3)
def testMutableHashTableWithTensorDefault(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant(-1, dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -639,7 +693,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testSignatureMismatch(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -680,7 +734,7 @@ class MutableHashTableOpTest(test.TestCase):
lookup.MutableHashTable(dtypes.string, dtypes.int64, "UNK")
def testMutableHashTableStringFloat(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1.5
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1.1, 2.2], dtypes.float32)
@@ -698,7 +752,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllClose([0, 1.1, default_val], result)
def testMutableHashTableIntFloat(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1.0
keys = constant_op.constant([3, 7, 0], dtypes.int64)
values = constant_op.constant([7.5, -1.2, 9.9], dtypes.float32)
@@ -716,7 +770,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllClose([-1.2, 9.9, default_val], result)
def testMutableHashTableInt64String(self):
- with self.test_session():
+ with self.cached_session():
default_val = "n/a"
keys = constant_op.constant([0, 1, 2], dtypes.int64)
values = constant_op.constant(["brain", "salad", "surgery"])
@@ -737,7 +791,7 @@ class MutableHashTableOpTest(test.TestCase):
class MutableDenseHashTableOpTest(test.TestCase):
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup.MutableDenseHashTable(
@@ -755,7 +809,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testBasicBool(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([True, True, True], dtypes.bool)
table = lookup.MutableDenseHashTable(
@@ -773,7 +827,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([True, True, False], result)
def testLookupUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup.MutableDenseHashTable(
@@ -789,7 +843,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testMapStringToFloat(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant(["a", "b", "c"], dtypes.string)
values = constant_op.constant([0.0, 1.1, 2.2], dtypes.float32)
default_value = constant_op.constant(-1.5, dtypes.float32)
@@ -812,7 +866,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
def testMapInt64ToFloat(self):
for float_dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0.0, 1.1, 2.2], float_dtype)
default_value = constant_op.constant(-1.5, float_dtype)
@@ -831,7 +885,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllClose([0, 1.1, -1.5], result)
def testVectorValues(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]],
dtypes.int64)
@@ -864,7 +918,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
result)
def testVectorKeys(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([[0, 1], [1, 2], [1, 3]], dtypes.int64)
values = constant_op.constant([10, 11, 12], dtypes.int64)
empty_key = constant_op.constant([0, 3], dtypes.int64)
@@ -895,7 +949,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([10, 11, -1], result)
def testResize(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup.MutableDenseHashTable(
@@ -923,7 +977,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([-1, 0, 1, 3, 4, 5, 6, 7, -1], output.eval())
def testExport(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([1, 2, 3], dtypes.int64)
table = lookup.MutableDenseHashTable(
@@ -958,7 +1012,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
save_dir = os.path.join(self.get_temp_dir(), "save_restore")
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
default_value = -1
empty_key = 0
keys = constant_op.constant([11, 12, 13], dtypes.int64)
@@ -983,7 +1037,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertTrue(isinstance(val, six.string_types))
self.assertEqual(save_path, val)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
table = lookup.MutableDenseHashTable(
dtypes.int64,
dtypes.int64,
@@ -1010,11 +1064,65 @@ class MutableDenseHashTableOpTest(test.TestCase):
output = table.lookup(input_string)
self.assertAllEqual([-1, 0, 1, 2, -1], output.eval())
+ @test_util.run_in_graph_and_eager_modes
+ def testObjectSaveRestore(self):
+ save_dir = os.path.join(self.get_temp_dir(), "save_restore")
+ save_prefix = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
+
+ default_value = -1
+ empty_key = 0
+ keys = constant_op.constant([11, 12, 13], dtypes.int64)
+ values = constant_op.constant([0, 1, 2], dtypes.int64)
+ save_table = lookup.MutableDenseHashTable(
+ dtypes.int64,
+ dtypes.int64,
+ default_value=default_value,
+ empty_key=empty_key,
+ name="t1",
+ checkpoint=True,
+ initial_num_buckets=32)
+
+ save_checkpoint = checkpointable.Checkpoint(table=save_table)
+
+ self.assertAllEqual(0, self.evaluate(save_table.size()))
+ self.evaluate(save_table.insert(keys, values))
+ self.assertAllEqual(3, self.evaluate(save_table.size()))
+ self.assertAllEqual(32, len(self.evaluate(save_table.export()[0])))
+
+ save_path = save_checkpoint.save(save_prefix)
+ del save_table, save_checkpoint
+
+ load_table = lookup.MutableDenseHashTable(
+ dtypes.int64,
+ dtypes.int64,
+ default_value=default_value,
+ empty_key=empty_key,
+ name="t1",
+ checkpoint=True,
+ initial_num_buckets=64)
+ self.evaluate(load_table.insert(
+ constant_op.constant([11, 14], dtypes.int64),
+ constant_op.constant([12, 24], dtypes.int64)))
+ self.assertAllEqual(2, self.evaluate(load_table.size()))
+ self.assertAllEqual(64, len(self.evaluate(load_table.export()[0])))
+
+ restore_checkpoint = checkpointable.Checkpoint(table=load_table)
+
+ # Restore the saved values in the parameter nodes.
+ restore_checkpoint.restore(save_path).run_restore_ops()
+
+ self.assertAllEqual(3, self.evaluate(load_table.size()))
+ self.assertAllEqual(32, len(self.evaluate(load_table.export()[0])))
+
+ input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64)
+ output = load_table.lookup(input_string)
+ self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output))
+
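One property worth noting in the assertions above: restoring a MutableDenseHashTable overwrites capacity as well as contents, so a table built with initial_num_buckets=64 shrinks back to the 32 buckets recorded in the checkpoint. The dense table also requires an empty_key sentinel that is never inserted, since it marks free slots in the open-addressed bucket array. A minimal construction sketch, with argument values mirroring these tests:

    # Construction sketch; values mirror the tests above. The bucket
    # count is a power of two in all of these tests.
    table = lookup.MutableDenseHashTable(
        dtypes.int64,               # key dtype
        dtypes.int64,               # value dtype
        default_value=-1,           # returned for missing keys
        empty_key=0,                # sentinel key: must never be inserted
        initial_num_buckets=32)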
def testVectorSaveRestore(self):
save_dir = os.path.join(self.get_temp_dir(), "vector_save_restore")
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
empty_key = constant_op.constant([11, 13], dtypes.int64)
default_value = constant_op.constant([-1, -2], dtypes.int64)
keys = constant_op.constant([[11, 12], [11, 14], [13, 14]], dtypes.int64)
@@ -1039,7 +1147,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertTrue(isinstance(val, six.string_types))
self.assertEqual(save_path, val)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
empty_key = constant_op.constant([11, 13], dtypes.int64)
default_value = constant_op.constant([-1, -2], dtypes.int64)
table = lookup.MutableDenseHashTable(
@@ -1074,7 +1182,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
save_dir = os.path.join(self.get_temp_dir(), "vector_scalar_save_restore")
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
empty_key = constant_op.constant([11, 13], dtypes.int64)
default_value = constant_op.constant(-1, dtypes.int64)
keys = constant_op.constant([[11, 12], [11, 14], [13, 14]], dtypes.int64)
@@ -1099,7 +1207,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertTrue(isinstance(val, six.string_types))
self.assertEqual(save_path, val)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
empty_key = constant_op.constant([11, 13], dtypes.int64)
default_value = constant_op.constant(-1, dtypes.int64)
table = lookup.MutableDenseHashTable(
@@ -1130,7 +1238,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1, 2, -1], output.eval())
def testReprobe(self):
- with self.test_session():
+ with self.cached_session():
# Insert 6 keys into a table with 8 buckets.
# The values are chosen to make sure collisions occur when using GCC STL
keys = constant_op.constant([11, 12, 13, 19, 20, 21], dtypes.int64)
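The key set here leans on GCC's std::hash for integers being the identity function, so with 8 buckets the pairs (11, 19), (12, 20), and (13, 21) land in the same slots and force the reprobe path. A quick check of the intended collision pattern:

    # With identity hashing, slot = key % num_buckets.
    keys = [11, 12, 13, 19, 20, 21]
    print([k % 8 for k in keys])  # [3, 4, 5, 3, 4, 5] -- three collisions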
@@ -1155,7 +1263,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([-1, 51, 52, 53, -1, 54, 55, 56, -1], result)
def testCustomEmptyKey(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 0, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup.MutableDenseHashTable(
@@ -1173,7 +1281,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testErrors(self):
- with self.test_session():
+ with self.cached_session():
table = lookup.MutableDenseHashTable(
dtypes.int64, dtypes.int64, default_value=-1, empty_key=0)
@@ -1220,7 +1328,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_file(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
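The index-table tests in this class all share one shape: build the table from the vocabulary file, look up a batch of strings, confirm the lookup fails before initialization, run tables_initializer, then check the ids (OOV strings fall into the trailing bucket range, hence id 3 for "tarkus" with a three-word vocabulary and one OOV bucket). Condensed:

    # Shared pattern of the index_table_from_file tests in this class.
    table = lookup.index_table_from_file(
        vocabulary_file=vocabulary_file, num_oov_buckets=1)
    ids = table.lookup(
        constant_op.constant(["salad", "surgery", "tarkus"]))
    self.assertRaises(errors_impl.OpError, ids.eval)  # not initialized yet
    lookup_ops.tables_initializer().run()
    self.assertAllEqual((1, 2, 3), ids.eval())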
@@ -1231,7 +1339,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_file_tensor_filename(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
vocabulary_file = constant_op.constant(vocabulary_file)
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1)
@@ -1245,7 +1353,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_file_placeholder_filename(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
vocabulary_placeholder = array_ops.placeholder(dtypes.string, [])
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_placeholder, num_oov_buckets=1)
@@ -1262,7 +1370,7 @@ class IndexTableFromFile(test.TestCase):
def test_int32_index_table_from_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab2.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1,
key_dtype=dtypes.int32)
@@ -1276,7 +1384,7 @@ class IndexTableFromFile(test.TestCase):
def test_int64_index_table_from_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab3.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1,
key_dtype=dtypes.int64)
@@ -1290,7 +1398,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_default_value(self):
default_value = -42
vocabulary_file = self._createVocabFile("f2i_vocab4.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, default_value=default_value)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1301,7 +1409,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_oov_buckets(self):
vocabulary_file = self._createVocabFile("f2i_vocab5.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1000)
ids = table.lookup(
@@ -1331,7 +1439,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_vocab_size_too_small(self):
vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=2)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1343,7 +1451,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_vocab_size_too_large(self):
vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=4)
self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
@@ -1358,7 +1466,7 @@ class IndexTableFromFile(test.TestCase):
vocabulary_file=vocabulary_file,
vocab_size=0)
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=3)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1370,7 +1478,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_invalid_hashers(self):
vocabulary_file = self._createVocabFile("invalid_hasher.txt")
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
lookup.index_table_from_file(
vocabulary_file=vocabulary_file,
@@ -1391,21 +1499,21 @@ class IndexTableFromFile(test.TestCase):
class KeyValueTensorInitializerTest(test.TestCase):
def test_string(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
init = lookup.KeyValueTensorInitializer(
("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64)
table = lookup.HashTable(init, default_value=-1)
table.init.run()
def test_int64(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
init = lookup.KeyValueTensorInitializer(
(42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64)
table = lookup.HashTable(init, default_value=-1)
table.init.run()
def test_int32(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
init = lookup.KeyValueTensorInitializer(
(42, 1, -1000), (0, 1, 2), dtypes.int32, dtypes.int64)
table = lookup.HashTable(init, default_value=-1)
@@ -1434,7 +1542,7 @@ class IndexTableFromTensor(test.TestCase):
self.assertAllEqual((1, 2, 3), self.evaluate(ids))
def test_int32_index_table_from_tensor_with_tensor_init(self):
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_tensor(
mapping=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int32)
ids = table.lookup(
@@ -1445,7 +1553,7 @@ class IndexTableFromTensor(test.TestCase):
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int64_index_table_from_tensor_with_tensor_init(self):
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_tensor(
mapping=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64)
ids = table.lookup(
@@ -1457,7 +1565,7 @@ class IndexTableFromTensor(test.TestCase):
def test_index_table_from_tensor_with_default_value(self):
default_value = -42
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_tensor(
mapping=["brain", "salad", "surgery"], default_value=default_value)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1467,12 +1575,12 @@ class IndexTableFromTensor(test.TestCase):
self.assertAllEqual((1, 2, default_value), ids.eval())
def test_index_table_from_tensor_missing_mapping(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, "mapping must be specified"):
lookup.index_table_from_tensor(mapping=None, num_oov_buckets=1)
def test_index_table_from_tensor_empty_mapping(self):
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_tensor(
mapping=np.array([], dtype=np.str_), num_oov_buckets=1)
ids = table.lookup(constant_op.constant(["salad", "surgery", "brain"]))
@@ -1482,7 +1590,7 @@ class IndexTableFromTensor(test.TestCase):
lookup_ops.tables_initializer().run()
def test_index_table_from_tensor_with_invalid_hashers(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
lookup.index_table_from_tensor(
mapping=["brain", "salad", "surgery"],
@@ -1501,7 +1609,7 @@ class IndexTableFromTensor(test.TestCase):
class StringToIndexTest(test.TestCase):
def test_string_to_index(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
feats = constant_op.constant(["salad", "surgery", "tarkus"])
indices = lookup.string_to_index(feats, mapping=mapping_strings)
@@ -1512,7 +1620,7 @@ class StringToIndexTest(test.TestCase):
self.assertAllEqual((1, 2, -1), indices.eval())
def test_duplicate_entries(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["hello", "hello"])
feats = constant_op.constant(["hello", "hola"])
_ = lookup.string_to_index(feats, mapping=mapping_strings)
@@ -1522,7 +1630,7 @@ class StringToIndexTest(test.TestCase):
def test_string_to_index_with_default_value(self):
default_value = -42
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
feats = constant_op.constant(["salad", "surgery", "tarkus"])
indices = lookup.string_to_index(
@@ -1543,7 +1651,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table(self):
vocabulary_file = self._createVocabFile("i2f_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_to_string_table_from_file(
vocabulary_file=vocabulary_file)
features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64))
@@ -1555,7 +1663,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_default_value(self):
default_value = b"NONE"
vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_to_string_table_from_file(
vocabulary_file=vocabulary_file, default_value=default_value)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -1567,7 +1675,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_vocab_size_too_small(self):
default_value = b"NONE"
vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_to_string_table_from_file(
vocabulary_file=vocabulary_file,
vocab_size=2,
@@ -1580,7 +1688,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_vocab_size_too_large(self):
vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_to_string_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=4)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -1592,7 +1700,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_vocab_size(self):
vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_to_string_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=3)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -1605,7 +1713,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
class IndexToStringTableFromTensorTest(test.TestCase):
def test_index_to_string_table_from_tensor(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
table = lookup.index_to_string_table_from_tensor(
mapping=mapping_strings)
@@ -1619,7 +1727,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
features.eval())
def test_duplicate_entries(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["hello", "hello"])
table = lookup.index_to_string_table_from_tensor(
mapping=mapping_strings)
@@ -1630,7 +1738,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
def test_index_to_string_with_default_value(self):
default_value = b"NONE"
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
table = lookup.index_to_string_table_from_tensor(
mapping=mapping_strings, default_value=default_value)
@@ -1646,7 +1754,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
class IndexToStringTest(test.TestCase):
def test_index_to_string(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
indices = constant_op.constant([0, 1, 2, 3], dtypes.int64)
feats = lookup.index_to_string(indices, mapping=mapping_strings)
@@ -1658,7 +1766,7 @@ class IndexToStringTest(test.TestCase):
feats.eval())
def test_duplicate_entries(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["hello", "hello"])
indices = constant_op.constant([0, 1, 4], dtypes.int64)
feats = lookup.index_to_string(indices, mapping=mapping_strings)
@@ -1670,7 +1778,7 @@ class IndexToStringTest(test.TestCase):
def test_index_to_string_with_default_value(self):
default_value = b"NONE"
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
indices = constant_op.constant([1, 2, 4], dtypes.int64)
feats = lookup.index_to_string(
@@ -1710,7 +1818,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
vocabulary_file = self._createVocabFile(
"one_column_int64.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
table = lookup.HashTable(
lookup.TextFileInitializer(vocabulary_file, dtypes.int64,
@@ -1729,7 +1837,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInitializeIndexTable(self):
vocabulary_file = self._createVocabFile("one_column_2.txt")
- with self.test_session():
+ with self.cached_session():
default_value = "UNK"
key_index = lookup.TextFileIndex.LINE_NUMBER
value_index = lookup.TextFileIndex.WHOLE_LINE
@@ -1750,7 +1858,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
with open(vocabulary_file, "w") as f:
f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")
- with self.test_session():
+ with self.cached_session():
default_value = -1
key_index = 1
value_index = 2
@@ -1772,7 +1880,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
with open(vocabulary_file, "w") as f:
f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")
- with self.test_session():
+ with self.cached_session():
default_value = -1
key_index = 2
value_index = 1
@@ -1786,7 +1894,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInvalidDataType(self):
vocabulary_file = self._createVocabFile("one_column_3.txt")
- with self.test_session():
+ with self.cached_session():
default_value = "UNK"
key_index = lookup.TextFileIndex.WHOLE_LINE
value_index = lookup.TextFileIndex.LINE_NUMBER
@@ -1799,7 +1907,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInvalidIndex(self):
vocabulary_file = self._createVocabFile("one_column_4.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
key_index = 1 # second column of the line
value_index = lookup.TextFileIndex.LINE_NUMBER
@@ -1814,7 +1922,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInitializeSameTableWithMultipleNodes(self):
vocabulary_file = self._createVocabFile("one_column_5.txt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shared_name = "shared-one-columm"
default_value = -1
table1 = lookup.HashTable(
@@ -1853,7 +1961,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], out3)
def testInitializeTableWithNoFilename(self):
- with self.test_session():
+ with self.cached_session():
default_value = -1
with self.assertRaises(ValueError):
lookup.HashTable(
@@ -1863,7 +1971,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
default_value)
def testInitializeWithVocabSize(self):
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
vocabulary_file1 = self._createVocabFile("one_column6.txt")
@@ -1914,7 +2022,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testFeedVocabularyName(self):
vocabulary_file = self._createVocabFile("feed_vocabulary.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
table = lookup.HashTable(
lookup.TextFileInitializer("old_file.txt", dtypes.string,
@@ -1941,7 +2049,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInvalidFilenames(self):
vocabulary_file = self._createVocabFile("filename_shape.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
# Invalid data type
@@ -1964,7 +2072,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testIdToStringTable(self):
vocab_file = self._createVocabFile("feat_to_id_1.txt")
- with self.test_session():
+ with self.cached_session():
default_value = "UNK"
vocab_size = 3
table = lookup.HashTable(
@@ -1982,7 +2090,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testStringToIdTable(self):
vocab_file = self._createVocabFile("feat_to_id_2.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
table = lookup.HashTable(
@@ -2000,7 +2108,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInt64ToIdTable(self):
vocab_file = self._createVocabFile(
"feat_to_id_3.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
table = lookup.HashTable(
@@ -2025,7 +2133,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testStringIdTableWithHashBuckets(self):
vocab_file = self._createVocabFile("feat_to_id_1.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
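IdTableWithHashBuckets layers hashed out-of-vocabulary buckets on top of a vocabulary table: in-vocabulary keys keep their file-assigned ids in [0, vocab_size) and OOV keys hash into [vocab_size, vocab_size + oov_buckets), which is why the size assertions below expect vocab_size + oov_buckets. The composition, condensed from this test:

    # Wrap a file-backed id table with a single OOV hash bucket.
    table = lookup.IdTableWithHashBuckets(
        lookup.HashTable(
            lookup.TextFileIdTableInitializer(
                vocab_file, vocab_size=vocab_size),
            default_value),
        oov_buckets)
    table.init.run()
    ids = table.lookup(
        constant_op.constant(["salad", "surgery", "tarkus"]))
    # In-vocab ids are < vocab_size; "tarkus" hashes to the OOV bucket.
    self.assertAllEqual([1, 2, 3], ids.eval())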
@@ -2046,7 +2154,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt32IdTableWithHashBuckets(self):
vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -2068,7 +2176,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt64IdTableWithHashBuckets(self):
vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -2088,7 +2196,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertEquals(vocab_size + oov_buckets, table.size().eval())
def testStringIdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
oov_buckets = 5
# Set a table that only uses hash buckets, for each input value returns
@@ -2109,7 +2217,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertEquals(oov_buckets, table.size().eval())
def testInt32IdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
oov_buckets = 5
# Set a table that only uses hash buckets, for each input value returns
@@ -2131,20 +2239,20 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertEquals(oov_buckets, table.size().eval())
def testFloat64IdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
lookup.IdTableWithHashBuckets(
None, num_oov_buckets=5, key_dtype=dtypes.float64)
def testBoolIdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
lookup.IdTableWithHashBuckets(
None, num_oov_buckets=5, key_dtype=dtypes.bool)
def testIdTableWithHashBucketsWithMultipleInitializers(self):
vocab_file = self._createVocabFile("feat_to_id_4.txt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_value = -1
vocab_size = 3
oov_buckets = 3
@@ -2186,7 +2294,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testIdTableWithHashBucketsInitializationAcrossSessions(self):
vocab_file = self._createVocabFile("feat_to_id_5.txt")
shared_name = "across-sessions"
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -2208,7 +2316,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertAllEqual([0, 1, 2, 3], out1.eval())
self.assertEquals(vocab_size + oov_buckets, table1.size().eval())
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -2232,7 +2340,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testIdTableWithHashBucketsWithMultipleInitializersDifferentDefault(self):
vocab_file = self._createVocabFile("feat_to_id_6.txt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_value1 = -1
vocab_size = 3
oov_buckets = 0
@@ -2270,7 +2378,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
vocab_file = self._createVocabFile("feat_to_id_7.txt")
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64),
constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"],
@@ -2299,7 +2407,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt32SparseTensor(self):
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64),
constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32),
@@ -2328,7 +2436,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt64SparseTensor(self):
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64),
constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64),
@@ -2356,7 +2464,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testIdTableWithHashBucketsWithInvalidHashers(self):
vocab_file = self._createVocabFile("feat_to_id_4.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
diff --git a/tensorflow/contrib/losses/__init__.py b/tensorflow/contrib/losses/__init__.py
index db58647d48..92b380df53 100644
--- a/tensorflow/contrib/losses/__init__.py
+++ b/tensorflow/contrib/losses/__init__.py
@@ -15,7 +15,7 @@
"""Ops for building neural network losses.
-See @{$python/contrib.losses}.
+See [Contrib Losses](https://tensorflow.org/api_guides/python/contrib.losses).
"""
from __future__ import absolute_import
diff --git a/tensorflow/contrib/losses/python/losses/__init__.py b/tensorflow/contrib/losses/python/losses/__init__.py
index 6e9d1d4a77..1675387227 100644
--- a/tensorflow/contrib/losses/python/losses/__init__.py
+++ b/tensorflow/contrib/losses/python/losses/__init__.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""Ops for building neural network losses.
-See @{$python/contrib.losses}.
+See [Contrib Losses](https://tensorflow.org/api_guides/python/contrib.losses).
"""
from __future__ import absolute_import
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
index 2a442a8fc8..c0aec09778 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops_test.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
@@ -43,68 +43,68 @@ class AbsoluteDifferenceLossTest(test.TestCase):
self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.absolute_difference(
self._predictions, self._predictions, weights=None)
def testAllCorrectNoLossWeight(self):
loss = loss_ops.absolute_difference(self._predictions, self._predictions)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testNonZeroLoss(self):
loss = loss_ops.absolute_difference(self._predictions, self._labels)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.5, loss.eval(), 3)
def testNonZeroLossWithPythonScalarWeight(self):
weights = 2.3
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.5 * weights, loss.eval(), 3)
def testNonZeroLossWithScalarTensorWeight(self):
weights = 2.3
loss = loss_ops.absolute_difference(self._predictions, self._labels,
constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.5 * weights, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeights(self):
weights = constant_op.constant([1.2, 0.0], shape=[2,])
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.6, loss.eval(), 3)
def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
weights = constant_op.constant([1.2, 0.0], shape=[2, 1])
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.6, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeights(self):
weights = constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(16.6, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
weights = constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(6.0, loss.eval(), 3)
def testLossWithSampleSpecificWeightsAllZero(self):
weights = array_ops.zeros((2, 3))
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
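The weighted variants above all follow loss_ops' reduction rule: weighted per-element losses are summed and divided by the count of elements with nonzero weight, not by the batch size. That is why weights [1.2, 0.0] yield 5.6 (16.8 / 3) rather than 16.8 / 6 = 2.8. A numpy check, using hypothetical per-element absolute errors consistent with the assertions above (5.5 unweighted mean, 16.6 with the sample-specific weights):

    import numpy as np

    # Hypothetical |predictions - labels| for a (2, 3) batch; the
    # reduction rule, not these numbers, is the point.
    abs_err = np.array([[3.0, 1.0, 10.0], [13.0, 3.0, 3.0]])
    weights = np.array([[1.2], [0.0]]) * np.ones((2, 3))  # per-row weights
    weighted = abs_err * weights
    print(weighted.sum() / np.count_nonzero(weights))  # 5.6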
@@ -117,12 +117,12 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
labels = constant_op.constant([[1, 0, 0],
[0, 1, 0],
[0, 0, 1]])
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.softmax_cross_entropy(logits, labels, weights=None)
def testAllCorrect(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
@@ -141,7 +141,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -154,7 +154,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -166,7 +166,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels,
constant_op.constant(weights))
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -179,7 +179,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
weights = constant_op.constant([1.2, 3.4, 5.6], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
@@ -191,7 +191,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
weights = constant_op.constant([0, 0, 0], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(0.0, loss.eval(), 3)
@@ -203,12 +203,12 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
weights = constant_op.constant([1.2, 0, 0], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(12.0, loss.eval(), 3)
def testSoftmaxWithMeasurementSpecificWeightsRaisesException(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -223,7 +223,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
loss_ops.softmax_cross_entropy(logits, labels, weights=weights).eval()
def testSoftmaxLabelSmoothing(self):
- with self.test_session():
+ with self.cached_session():
# Softmax Cross Entropy Loss is:
# -\sum_i p_i \log q_i
# where for a softmax activation
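For the smoothing tests, loss_ops.softmax_cross_entropy rewrites the one-hot targets before applying -sum_i p_i log q_i: with smoothing L over n classes, p becomes onehot * (1 - L) + L / n. A standalone numpy check of that arithmetic (illustrative logits, not the exact fixture):

    import numpy as np

    logits = np.array([[100.0, -100.0, -100.0]])
    onehot = np.array([[1.0, 0.0, 0.0]])
    L, n = 0.1, 3
    smoothed = onehot * (1 - L) + L / n
    # Stable log softmax, then the smoothed cross entropy.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_q = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    print(-(smoothed * log_q).sum(axis=1).mean())  # ~13.333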
@@ -253,7 +253,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
weights = [2.3, 2.4, 2.5]
weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None])
loss = loss_ops.softmax_cross_entropy(logits, labels, weights_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, {weights_placeholder: weights})
self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
@@ -268,7 +268,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
weights_placeholder = array_ops.placeholder(
dtypes.float32, shape=[None, None])
loss = loss_ops.softmax_cross_entropy(logits, labels, weights_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, {weights_placeholder: weights})
self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
@@ -280,12 +280,12 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[0], [1], [2]])
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.sparse_softmax_cross_entropy(logits, labels, weights=None)
def testAllCorrectInt32Labels(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
@@ -295,7 +295,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(loss.eval(), 0.0, 3)
def testAllCorrectInt64Labels(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
@@ -305,7 +305,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(loss.eval(), 0.0, 3)
def testAllCorrectNonColumnLabels(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
@@ -320,7 +320,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]], dtype=dtypes.int32)
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -331,7 +331,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]], dtype=dtypes.int64)
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -342,7 +342,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([2, 0, 1])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -353,7 +353,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -363,7 +363,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(
logits, labels, constant_op.constant(weights))
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -374,7 +374,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([1.2, 3.4, 5.6], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
@@ -384,7 +384,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([[1.2], [3.4], [5.6]])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
@@ -394,7 +394,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([0, 0, 0], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(0.0, loss.eval(), 3)
@@ -404,12 +404,12 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([1.2, 0, 0], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(12.0, loss.eval(), 3)
def testMeasurementSpecificWeightsRaisesException(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -422,7 +422,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentWeightSizeRaisesException(self):
"""The weight tensor has incorrect number of elements."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -435,7 +435,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentLabelSizeRaisesException(self):
"""The label tensor has incorrect number of elements."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -448,7 +448,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentWeightShapeRaisesException(self):
"""The weight tensor has incorrect shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0, -100.0],
[-100.0, -100.0, 100.0, -100.0],
@@ -462,7 +462,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentLabelShapeRaisesException(self):
"""The label tensor has incorrect shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0, -100.0],
[-100.0, -100.0, 100.0, -100.0],
@@ -484,7 +484,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
dtypes.float32, shape=[None])
loss = loss_ops.sparse_softmax_cross_entropy(
logits, labels, weights_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, {weights_placeholder: weights})
self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
@@ -498,7 +498,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
dtypes.float32, shape=[None, None])
loss = loss_ops.sparse_softmax_cross_entropy(
logits, labels, weights_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, {weights_placeholder: weights})
self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
@@ -506,7 +506,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
class SigmoidCrossEntropyLossTest(test.TestCase):
def testAllCorrectSigmoid(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -522,7 +522,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
loss = loss_ops.sigmoid_cross_entropy(logits, labels, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
logits: np.ones((32, 1)),
@@ -537,7 +537,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
loss = loss_ops.sigmoid_cross_entropy(logits, labels, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
logits: np.ones((32, 2)),
@@ -546,7 +546,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(0.313, loss, 3)
def testAllWrongSigmoid(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -558,7 +558,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(loss.eval(), 600.0 / 9.0, 3)
def testAllWrongSigmoidWithMeasurementSpecificWeights(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -582,11 +582,11 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
loss = loss_ops.sigmoid_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(loss.eval(), 0.0, 3)
def testSigmoidLabelSmoothingCorrect(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0]])
labels = constant_op.constant([[1, 0, 1]])
# Sigmoid cross entropy loss is:
@@ -608,7 +608,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(loss.eval(), expected_value, 3)
def testSigmoidLabelSmoothingEqualsSoftmaxTwoLabel(self):
- with self.test_session():
+ with self.cached_session():
label_smoothing = 0.1
sigmoid_logits = constant_op.constant([[100.0, -100.0, -100.0]])
sigmoid_labels = constant_op.constant([[1, 0, 1]])
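The sigmoid variant smooths toward 0.5 rather than toward a uniform distribution: loss_ops.sigmoid_cross_entropy maps labels to labels * (1 - L) + 0.5 * L, which is what makes the two-label softmax comparison in this test line up. A numpy sketch of the per-element loss under that convention (values matching the constants above):

    import numpy as np

    logits = np.array([100.0, -100.0, -100.0])
    labels = np.array([1.0, 0.0, 1.0])
    L = 0.1
    z = labels * (1 - L) + 0.5 * L  # [0.95, 0.05, 0.95]
    # Numerically stable sigmoid cross entropy per element.
    per_elem = (np.maximum(logits, 0) - logits * z
                + np.log1p(np.exp(-np.abs(logits))))
    print(per_elem.mean())  # 35.0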
@@ -641,33 +641,33 @@ class LogLossTest(test.TestCase):
self._labels = constant_op.constant(labels)
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.log_loss(self._labels, self._labels, weights=None)
def testAllCorrectNoLossWeight(self):
loss = loss_ops.log_loss(self._labels, self._labels)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testAllCorrectNoLossWeightWithPlaceholder(self):
tf_predictions = array_ops.placeholder(
dtypes.float32, shape=self._np_labels.shape)
loss = loss_ops.log_loss(tf_predictions, self._labels)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(
0.0, loss.eval(feed_dict={tf_predictions: self._np_labels}), 3)
def testNonZeroLoss(self):
loss = loss_ops.log_loss(self._predictions, self._labels)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(self._expected_losses) / 6.0,
loss.eval(), 3)
def testNonZeroLossWithPythonScalarWeight(self):
weights = 2.3
loss = loss_ops.log_loss(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss.eval(), 3)
@@ -675,7 +675,7 @@ class LogLossTest(test.TestCase):
weights = 2.3
loss = loss_ops.log_loss(self._predictions, self._labels,
constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss.eval(), 3)
@@ -685,7 +685,7 @@ class LogLossTest(test.TestCase):
weights = 2.3
loss = loss_ops.log_loss(tf_predictions, self._labels,
constant_op.constant(weights))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss, 3)
@@ -695,7 +695,7 @@ class LogLossTest(test.TestCase):
weights = 2.3
loss = loss_ops.log_loss(tf_predictions, self._labels,
constant_op.constant(weights))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss, 3)
@@ -706,7 +706,7 @@ class LogLossTest(test.TestCase):
self._expected_losses,
np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)))
loss = loss_ops.log_loss(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 6.0, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeightsSomeZero(self):
@@ -715,7 +715,7 @@ class LogLossTest(test.TestCase):
np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape(
(2, 3)))
loss = loss_ops.log_loss(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, loss.eval(), 3)
def testNonZeroLossWithTwoDimBatchSpecificWeightsSomeZero(self):
@@ -724,12 +724,12 @@ class LogLossTest(test.TestCase):
np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape(
(2, 3)))
loss = loss_ops.log_loss(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, loss.eval(), 3)
def testWeightsWithSameNumDimsButWrongShapeThrowsException(self):
weights = constant_op.constant(np.random.normal(size=(2, 4)), shape=[2, 4])
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.log_loss(self._predictions, self._labels, weights)
@@ -742,7 +742,7 @@ class LogLossTest(test.TestCase):
self._labels,
constant_op.constant(
weights, shape=(2, 3)))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss.eval(), 3)
def testNonZeroLossWithMeasurementSpecificWeightsWithPlaceholder(self):
@@ -756,7 +756,7 @@ class LogLossTest(test.TestCase):
constant_op.constant(
weights, shape=(2, 3)))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss, 3)
@@ -769,7 +769,7 @@ class LogLossTest(test.TestCase):
self._labels,
constant_op.constant(
weights, shape=(2, 3)))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses), loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeightsMostZeroWithPlaceholder(self):
@@ -780,35 +780,35 @@ class LogLossTest(test.TestCase):
tf_weights = constant_op.constant(weights, shape=(2, 3))
loss = loss_ops.log_loss(tf_predictions, self._labels, tf_weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(-np.sum(expected_losses), loss, 3)
def testLossWithSampleSpecificWeightsAllZero(self):
tf_weights = array_ops.zeros(shape=(2, 3))
loss = loss_ops.log_loss(self._predictions, self._labels, tf_weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
class HingeLossTest(test.TestCase):
def testIncompatibleShapes(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[-1.0], [2.1]])
labels = constant_op.constant([0.0, 1.0])
with self.assertRaises(ValueError):
_ = loss_ops.hinge_loss(logits, labels).eval()
def testAllOutsideMargin(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([1.2, -1.4, -1.0, 2.1])
labels = constant_op.constant([1.0, 0.0, 0.0, 1.0])
loss = loss_ops.hinge_loss(logits, labels)
self.assertAllClose(loss.eval(), [0.0, 0.0, 0.0, 0.0], atol=1e-3)
def testSomeInsideMargin(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[-0.7], [-1.4], [1.4], [0.6]])
labels = constant_op.constant([[0.0], [0.0], [1.0], [1.0]])
loss = loss_ops.hinge_loss(logits, labels)
@@ -817,7 +817,7 @@ class HingeLossTest(test.TestCase):
self.assertAllClose(loss.eval(), [[0.3], [0.0], [0.0], [0.4]], atol=1e-3)
def testSomeMisclassified(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[[1.2], [0.4], [-1.0], [-1.1]]])
labels = constant_op.constant([[[1.0], [0.0], [0.0], [1.0]]])
loss = loss_ops.hinge_loss(logits, labels)
@@ -834,62 +834,62 @@ class MeanSquaredErrorTest(test.TestCase):
self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.mean_squared_error(
self._predictions, self._predictions, weights=None)
def testAllCorrectNoLossWeight(self):
loss = loss_ops.mean_squared_error(self._predictions, self._predictions)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testNonZeroLoss(self):
loss = loss_ops.mean_squared_error(self._predictions, self._labels)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(49.5, loss.eval(), 3)
def testNonZeroLossWithPythonScalarWeight(self):
weights = 2.3
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(49.5 * weights, loss.eval(), 3)
def testNonZeroLossWithScalarTensorWeight(self):
weights = 2.3
loss = loss_ops.mean_squared_error(self._predictions, self._labels,
constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(49.5 * weights, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeights(self):
weights = constant_op.constant([1.2, 3.4], shape=[2,])
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3)
def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
weights = constant_op.constant([1.2, 3.4], shape=[2, 1])
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeights(self):
weights = constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(587 / 5.0, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
weights = constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(18.0, loss.eval(), 3)
def testLossWithSampleSpecificWeightsAllZero(self):
weights = array_ops.zeros((2, 3))
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
@@ -914,7 +914,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
self._expected_losses = np.divide(total, 9.0)
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.mean_pairwise_squared_error(
predictions=constant_op.constant(self._labels),
@@ -925,14 +925,14 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
loss = loss_ops.mean_pairwise_squared_error(
predictions=constant_op.constant(self._labels),
labels=constant_op.constant(self._labels))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testNonZeroLoss(self):
loss = loss_ops.mean_pairwise_squared_error(
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(np.sum(self._expected_losses), loss.eval(), 3)
def testGradientWithZeroWeight(self):
@@ -954,7 +954,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
init_op = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for grad, _ in gradients_to_variables:
np_grad = sess.run(grad)
@@ -966,7 +966,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels),
weights=weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(weights * np.sum(self._expected_losses),
loss.eval(), 3)
@@ -976,7 +976,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels),
weights=constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(weights * np.sum(self._expected_losses),
loss.eval(), 3)
@@ -986,7 +986,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels),
weights=constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0, loss.eval(), 3)
def testNonZeroLossWithScalarTensorWeightWithPlaceholder(self):
@@ -998,7 +998,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
predictions=tf_predictions,
labels=tf_labels,
weights=constant_op.constant(weights))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
tf_predictions: self._predictions,
@@ -1015,7 +1015,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
labels=constant_op.constant(self._labels),
weights=constant_op.constant(
weights, shape=[2]))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(np.sum(expected_losses), loss.eval(), 3)
def testZeroLossWithOneDimBatchZeroWeights(self):
@@ -1025,7 +1025,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
labels=constant_op.constant(self._labels),
weights=constant_op.constant(
weights, shape=[2]))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeightsAndPlaceholders(self):
@@ -1041,7 +1041,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
weights=constant_op.constant(
weights, shape=[2]))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
tf_predictions: self._predictions,
@@ -1056,7 +1056,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
labels=constant_op.constant(self._labels),
weights=constant_op.constant(
weights, shape=[2]))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testLossIsAssociativeAcrossBatchElements(self):
@@ -1087,7 +1087,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
predictions=array_ops.concat([predictions0, predictions1], 0),
labels=array_ops.concat([labels0, labels1], 0))
- with self.test_session() as session:
+ with self.cached_session() as session:
loss0, loss1, loss0_1 = session.run([loss0, loss1, loss0_1])
self.assertTrue(loss0 > 0)
@@ -1115,7 +1115,7 @@ class CosineDistanceLossTest(test.TestCase):
[0, 1, 0]]).reshape((3, 2, 3))
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.cosine_distance(
predictions=constant_op.constant(self._labels),
@@ -1128,7 +1128,7 @@ class CosineDistanceLossTest(test.TestCase):
predictions=constant_op.constant(self._labels),
labels=constant_op.constant(self._labels),
dim=2)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0, loss.eval(), 5)
def testPartiallyCorrectWithIntegerValues(self):
@@ -1136,7 +1136,7 @@ class CosineDistanceLossTest(test.TestCase):
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels),
dim=2)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(1, loss.eval(), 5)
def testPartiallyCorrectFloatingPointValues(self):
@@ -1154,7 +1154,7 @@ class CosineDistanceLossTest(test.TestCase):
labels, shape=(3, 1, 3), dtype=dtypes.float32)
loss = loss_ops.cosine_distance(tf_preds, tf_labels, dim=2)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(1.0, loss.eval(), 5)
def testSampleSpecificWeights(self):
@@ -1163,7 +1163,7 @@ class CosineDistanceLossTest(test.TestCase):
labels=constant_op.constant(self._labels),
dim=2,
weights=constant_op.constant([1, 0, 0]))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(1.0, loss.eval())
def testMeasurementSpecificWeights(self):
@@ -1173,12 +1173,12 @@ class CosineDistanceLossTest(test.TestCase):
dim=2,
weights=constant_op.constant(
[1, 0, 0, 1, 1, 1], shape=(3, 2)))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(3.0 / 4.0, loss.eval())
def testValueErrorThrownWithShapelessPlaceholder(self):
tf_predictions = array_ops.placeholder(dtypes.float32)
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.cosine_distance(
predictions=tf_predictions,
@@ -1196,7 +1196,7 @@ class CosineDistanceLossTest(test.TestCase):
dim=2,
weights=constant_op.constant(
[1, 0, 0, 1, 1, 1], shape=(3, 2)))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._predictions})
self.assertEqual(3.0 / 4.0, loss)
@@ -1206,7 +1206,7 @@ class CosineDistanceLossTest(test.TestCase):
labels=constant_op.constant(self._labels),
dim=2,
weights=array_ops.zeros((3,)))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(0, loss.eval())
def testZeroLossWhenAllMeasurementSpecificWeightsAreZero(self):
@@ -1215,7 +1215,7 @@ class CosineDistanceLossTest(test.TestCase):
labels=constant_op.constant(self._labels),
dim=2,
weights=array_ops.zeros((3, 2)))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(0, loss.eval())
@@ -1228,7 +1228,7 @@ class ComputeWeightedLossTest(test.TestCase):
self.assertFalse(loss_ops.get_losses())
loss = loss_ops.compute_weighted_loss(losses)
self.assertTrue(loss_ops.get_losses())
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(losses.eval(), [0.0, 1.4, 0.0, 2.1], atol=1e-3)
self.assertAllClose(loss.eval(), 3.5 / 4.0, atol=1e-3)
@@ -1243,7 +1243,7 @@ class AddLossTest(test.TestCase):
loss_ops.add_loss(math_ops.reduce_mean(losses))
self.assertTrue(loss_ops.get_losses())
total_loss = loss_ops.get_total_loss()
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(losses.eval(), [[0.0, 1.4, 0.0, 2.1]], atol=1e-3)
self.assertAllClose(total_loss.eval(), 3.5 / 4.0, atol=1e-3)
@@ -1254,7 +1254,7 @@ class AddLossTest(test.TestCase):
self.assertFalse(loss_ops.get_losses())
loss_ops.add_loss(math_ops.reduce_mean(losses), loss_collection=None)
self.assertFalse(loss_ops.get_losses())
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(losses.eval(), [[0.0, 1.4, 0.0, 2.1]], atol=1e-3)
def testNoCollectLosses(self):
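
Every hunk in this test file applies the same mechanical migration: the deprecated TestCase.test_session() call is replaced with cached_session(), which returns a session that is created once per test method and reused across successive with-blocks. A minimal sketch of the migrated idiom, assuming TF 1.x graph mode (test names hypothetical):

import tensorflow as tf

class CachedSessionSketchTest(tf.test.TestCase):

  def testEvalInsideCachedSession(self):
    losses = tf.constant([0.0, 1.4, 0.0, 2.1])
    # cached_session() may be entered several times within one test;
    # the underlying Session object is created once and reused, which
    # is what makes the blanket test_session() -> cached_session()
    # rewrite in this diff behavior-preserving.
    with self.cached_session():
      self.assertAllClose(losses.eval(), [0.0, 1.4, 0.0, 2.1], atol=1e-3)

if __name__ == "__main__":
  tf.test.main()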
diff --git a/tensorflow/contrib/losses/python/metric_learning/__init__.py b/tensorflow/contrib/losses/python/metric_learning/__init__.py
index 4e551d6aca..3d93a4d0ac 100644
--- a/tensorflow/contrib/losses/python/metric_learning/__init__.py
+++ b/tensorflow/contrib/losses/python/metric_learning/__init__.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""Ops for building neural network losses.
-See @{$python/contrib.losses}.
+See [Contrib Losses](https://tensorflow.org/api_guides/python/contrib.losses).
"""
from __future__ import absolute_import
@@ -35,5 +35,3 @@ _allowed_symbols = [
'triplet_semihard_loss',
]
remove_undocumented(__name__, _allowed_symbols)
-
-
diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile
index 1a1ab54a53..d962a5e12d 100644
--- a/tensorflow/contrib/makefile/Makefile
+++ b/tensorflow/contrib/makefile/Makefile
@@ -90,6 +90,7 @@ HOST_INCLUDES := \
-I$(MAKEFILE_DIR)/downloads/nsync/public \
-I$(MAKEFILE_DIR)/downloads/fft2d \
-I$(MAKEFILE_DIR)/downloads/double_conversion \
+-I$(MAKEFILE_DIR)/downloads/absl \
-I$(HOST_GENDIR)
ifeq ($(HAS_GEN_HOST_PROTOC),true)
HOST_INCLUDES += -I$(MAKEFILE_DIR)/gen/protobuf-host/include
@@ -116,6 +117,25 @@ ifeq ($(HOST_OS),PI)
HOST_LIBS += -ldl -lpthread
endif
+# Abseil sources.
+ABSL_CC_ALL_SRCS := \
+$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*.cc) \
+$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*.cc) \
+$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*/*.cc) \
+$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*/*/*.cc)
+
+ABSL_CC_EXCLUDE_SRCS := \
+$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*test*.cc) \
+$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*test*.cc) \
+$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*/*test*.cc) \
+$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*/*/*test*.cc) \
+$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*benchmark*.cc) \
+$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*benchmark*.cc) \
+$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*/*benchmark*.cc) \
+$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*/*/*benchmark*.cc) \
+tensorflow/contrib/makefile/downloads/absl/absl/synchronization/internal/mutex_nonprod.cc
+
+ABSL_CC_SRCS := $(filter-out $(ABSL_CC_EXCLUDE_SRCS), $(ABSL_CC_ALL_SRCS))
# proto_text is a tool that converts protobufs into a form we can use more
# compactly within TensorFlow. It's a bit like protoc, but is designed to
@@ -125,7 +145,9 @@ endif
PROTO_TEXT := $(HOST_BINDIR)proto_text
# The list of dependencies is derived from the Bazel build file by running
# the gen_file_lists.sh script on a system with a working Bazel setup.
-PROTO_TEXT_CC_FILES := $(shell cat $(MAKEFILE_DIR)/proto_text_cc_files.txt)
+PROTO_TEXT_CC_FILES := \
+ $(ABSL_CC_SRCS) \
+ $(shell cat $(MAKEFILE_DIR)/proto_text_cc_files.txt)
PROTO_TEXT_PB_CC_LIST := \
$(shell cat $(MAKEFILE_DIR)/proto_text_pb_cc_files.txt) \
$(wildcard tensorflow/contrib/makefile/downloads/double_conversion/double-conversion/*.cc)
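
For readers who do not speak GNU Make: the ABSL_CC_* block above globs every vendored Abseil .cc file at four explicit directory depths (Make's $(wildcard) does not recurse), subtracts tests, benchmarks, and the mutex_nonprod.cc shim via $(filter-out), and prepends the survivors to PROTO_TEXT_CC_FILES (and, further down, to CORE_CC_ALL_SRCS). A rough Python equivalent of that file selection, with the pattern matching approximated:

import glob

ABSL_ROOT = "tensorflow/contrib/makefile/downloads/absl/absl"

# One recursive glob stands in for the four fixed-depth $(wildcard) calls.
all_srcs = set(glob.glob(ABSL_ROOT + "/**/*.cc", recursive=True))

def excluded(path):
  # Mirrors ABSL_CC_EXCLUDE_SRCS: *test*.cc, *benchmark*.cc, and the
  # mutex_nonprod.cc shim that the Makefile excludes explicitly.
  name = path.rsplit("/", 1)[-1]
  return ("test" in name
          or "benchmark" in name
          or path.endswith("synchronization/internal/mutex_nonprod.cc"))

absl_cc_srcs = sorted(src for src in all_srcs if not excluded(src))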
@@ -175,6 +197,7 @@ INCLUDES := \
-I$(MAKEFILE_DIR)/downloads/nsync/public \
-I$(MAKEFILE_DIR)/downloads/fft2d \
-I$(MAKEFILE_DIR)/downloads/double_conversion \
+-I$(MAKEFILE_DIR)/downloads/absl \
-I$(PROTOGENDIR) \
-I$(PBTGENDIR)
ifeq ($(HAS_GEN_HOST_PROTOC),true)
@@ -236,7 +259,6 @@ ifeq ($(TARGET),PI)
endif
# Set up Android building
-# LINT.IfChange
ifeq ($(TARGET),ANDROID)
# Override NDK_ROOT on the command line with your own NDK location, e.g.
# make -f tensorflow/contrib/makefile/Makefile TARGET=ANDROID \
@@ -331,6 +353,7 @@ $(MARCH_OPTION) \
-I$(MAKEFILE_DIR)/downloads/nsync/public \
-I$(MAKEFILE_DIR)/downloads/fft2d \
-I$(MAKEFILE_DIR)/downloads/double_conversion \
+-I$(MAKEFILE_DIR)/downloads/absl \
-I$(MAKEFILE_DIR)/gen/protobuf_android/$(ANDROID_ARCH)/include \
-I$(PROTOGENDIR) \
-I$(PBTGENDIR)
@@ -446,7 +469,6 @@ $(MARCH_OPTION) \
DEPDIR := $(DEPDIR)android_$(ANDROID_ARCH)/
endif # ifeq ($(BUILD_FOR_TEGRA),1)
endif # ANDROID
-# LINT.ThenChange(//tensorflow/contrib/android/cmake/CMakeLists.txt)
# Settings for iOS.
ifeq ($(TARGET),IOS)
@@ -596,6 +618,7 @@ BENCHMARK_NAME := $(BINDIR)benchmark
# gen_file_lists.sh script.
CORE_CC_ALL_SRCS := \
+$(ABSL_CC_SRCS) \
$(wildcard tensorflow/core/*.cc) \
$(wildcard tensorflow/core/common_runtime/*.cc) \
$(wildcard tensorflow/core/framework/*.cc) \
diff --git a/tensorflow/contrib/makefile/compile_nsync.sh b/tensorflow/contrib/makefile/compile_nsync.sh
index a28fc3a87f..cb4c94d92f 100755
--- a/tensorflow/contrib/makefile/compile_nsync.sh
+++ b/tensorflow/contrib/makefile/compile_nsync.sh
@@ -256,6 +256,7 @@ for arch in $archs; do
esac
makefile='
+ AR := ${NDK_ROOT}/toolchains/'"$toolchain"'/prebuilt/'"$android_os_arch"'/bin/'"$bin_prefix"'-ar
CC=${CC_PREFIX} \
${NDK_ROOT}/toolchains/'"$toolchain"'/prebuilt/'"$android_os_arch"'/bin/'"$bin_prefix"'-g++
PLATFORM_CPPFLAGS=--sysroot \
diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt
index 7d26429f9c..9ea94c7433 100644
--- a/tensorflow/contrib/makefile/proto_text_cc_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt
@@ -1,62 +1,61 @@
-tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc
-tensorflow/tools/proto_text/gen_proto_text_functions.cc
tensorflow/core/framework/resource_handle.cc
+tensorflow/core/lib/core/arena.cc
+tensorflow/core/lib/core/coding.cc
+tensorflow/core/lib/core/status.cc
+tensorflow/core/lib/core/threadpool.cc
+tensorflow/core/lib/hash/crc32c.cc
+tensorflow/core/lib/hash/crc32c_accelerate.cc
+tensorflow/core/lib/hash/hash.cc
+tensorflow/core/lib/histogram/histogram.cc
+tensorflow/core/lib/io/block.cc
+tensorflow/core/lib/io/block_builder.cc
+tensorflow/core/lib/io/buffered_inputstream.cc
+tensorflow/core/lib/io/compression.cc
+tensorflow/core/lib/io/format.cc
+tensorflow/core/lib/io/inputbuffer.cc
+tensorflow/core/lib/io/inputstream_interface.cc
+tensorflow/core/lib/io/iterator.cc
+tensorflow/core/lib/io/path.cc
+tensorflow/core/lib/io/random_inputstream.cc
+tensorflow/core/lib/io/record_reader.cc
+tensorflow/core/lib/io/record_writer.cc
+tensorflow/core/lib/io/table.cc
+tensorflow/core/lib/io/table_builder.cc
+tensorflow/core/lib/io/two_level_iterator.cc
+tensorflow/core/lib/io/zlib_compression_options.cc
+tensorflow/core/lib/io/zlib_inputstream.cc
+tensorflow/core/lib/io/zlib_outputbuffer.cc
+tensorflow/core/lib/random/distribution_sampler.cc
+tensorflow/core/lib/random/random.cc
+tensorflow/core/lib/random/simple_philox.cc
+tensorflow/core/lib/random/weighted_picker.cc
+tensorflow/core/lib/strings/numbers.cc
+tensorflow/core/lib/strings/ordered_code.cc
+tensorflow/core/lib/strings/proto_text_util.cc
+tensorflow/core/lib/strings/scanner.cc
+tensorflow/core/lib/strings/str_util.cc
+tensorflow/core/lib/strings/strcat.cc
+tensorflow/core/lib/strings/stringprintf.cc
+tensorflow/core/lib/wav/wav_io.cc
+tensorflow/core/platform/cpu_info.cc
+tensorflow/core/platform/default/logging.cc
+tensorflow/core/platform/default/mutex.cc
tensorflow/core/platform/default/protobuf.cc
-tensorflow/core/platform/tracing.cc
-tensorflow/core/platform/tensor_coding.cc
-tensorflow/core/platform/protobuf_util.cc
-tensorflow/core/platform/posix/posix_file_system.cc
-tensorflow/core/platform/posix/port.cc
-tensorflow/core/platform/posix/error.cc
-tensorflow/core/platform/posix/env.cc
-tensorflow/core/platform/posix/load_library.cc
-tensorflow/core/platform/posix/env_time.cc
-tensorflow/core/platform/file_system.cc
-tensorflow/core/platform/file_system_helper.cc
+tensorflow/core/platform/default/tracing.cc
+tensorflow/core/platform/denormal.cc
tensorflow/core/platform/env.cc
tensorflow/core/platform/env_time.cc
+tensorflow/core/platform/file_system.cc
+tensorflow/core/platform/file_system_helper.cc
+tensorflow/core/platform/posix/env.cc
+tensorflow/core/platform/posix/env_time.cc
+tensorflow/core/platform/posix/error.cc
+tensorflow/core/platform/posix/load_library.cc
+tensorflow/core/platform/posix/port.cc
+tensorflow/core/platform/posix/posix_file_system.cc
+tensorflow/core/platform/protobuf_util.cc
tensorflow/core/platform/setround.cc
-tensorflow/core/platform/denormal.cc
-tensorflow/core/platform/default/tracing.cc
-tensorflow/core/platform/default/mutex.cc
-tensorflow/core/platform/default/logging.cc
-tensorflow/core/platform/cpu_info.cc
-tensorflow/core/lib/wav/wav_io.cc
-tensorflow/core/lib/strings/stringprintf.cc
-tensorflow/core/lib/strings/strcat.cc
-tensorflow/core/lib/strings/str_util.cc
-tensorflow/core/lib/strings/scanner.cc
-tensorflow/core/lib/strings/proto_text_util.cc
-tensorflow/core/lib/strings/ordered_code.cc
-tensorflow/core/lib/strings/numbers.cc
-tensorflow/core/lib/random/weighted_picker.cc
-tensorflow/core/lib/random/simple_philox.cc
-tensorflow/core/lib/random/random.cc
-tensorflow/core/lib/random/distribution_sampler.cc
-tensorflow/core/lib/io/zlib_outputbuffer.cc
-tensorflow/core/lib/io/zlib_inputstream.cc
-tensorflow/core/lib/io/zlib_compression_options.cc
-tensorflow/core/lib/io/two_level_iterator.cc
-tensorflow/core/lib/io/table_builder.cc
-tensorflow/core/lib/io/table.cc
-tensorflow/core/lib/io/record_writer.cc
-tensorflow/core/lib/io/record_reader.cc
-tensorflow/core/lib/io/random_inputstream.cc
-tensorflow/core/lib/io/path.cc
-tensorflow/core/lib/io/iterator.cc
-tensorflow/core/lib/io/inputstream_interface.cc
-tensorflow/core/lib/io/inputbuffer.cc
-tensorflow/core/lib/io/format.cc
-tensorflow/core/lib/io/compression.cc
-tensorflow/core/lib/io/buffered_inputstream.cc
-tensorflow/core/lib/io/block_builder.cc
-tensorflow/core/lib/io/block.cc
-tensorflow/core/lib/histogram/histogram.cc
-tensorflow/core/lib/hash/hash.cc
-tensorflow/core/lib/hash/crc32c.cc
-tensorflow/core/lib/hash/crc32c_accelerate.cc
-tensorflow/core/lib/core/threadpool.cc
-tensorflow/core/lib/core/stringpiece.cc
-tensorflow/core/lib/core/status.cc
-tensorflow/core/lib/core/coding.cc
-tensorflow/core/lib/core/arena.cc
+tensorflow/core/platform/tensor_coding.cc
+tensorflow/core/platform/tracing.cc
+tensorflow/tools/proto_text/gen_proto_text_functions.cc
+tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc
diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
index 938c4a53ab..0d8df93d11 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
@@ -1,41 +1,41 @@
-tensorflow/core/util/test_log.pb.cc
-tensorflow/core/util/saved_tensor_slice.pb.cc
-tensorflow/core/util/memmapped_file_system.pb.cc
-tensorflow/core/util/event.pb.cc
-tensorflow/core/protobuf/tensorflow_server.pb.cc
-tensorflow/core/protobuf/saver.pb.cc
-tensorflow/core/protobuf/queue_runner.pb.cc
-tensorflow/core/protobuf/named_tensor.pb.cc
-tensorflow/core/protobuf/meta_graph.pb.cc
+tensorflow/core/example/example.pb.cc
+tensorflow/core/example/feature.pb.cc
+tensorflow/core/framework/allocation_description.pb.cc
+tensorflow/core/framework/api_def.pb.cc
+tensorflow/core/framework/attr_value.pb.cc
+tensorflow/core/framework/cost_graph.pb.cc
+tensorflow/core/framework/device_attributes.pb.cc
+tensorflow/core/framework/function.pb.cc
+tensorflow/core/framework/graph.pb.cc
+tensorflow/core/framework/graph_transfer_info.pb.cc
+tensorflow/core/framework/kernel_def.pb.cc
+tensorflow/core/framework/log_memory.pb.cc
+tensorflow/core/framework/node_def.pb.cc
+tensorflow/core/framework/op_def.pb.cc
+tensorflow/core/framework/remote_fused_graph_execute_info.pb.cc
+tensorflow/core/framework/resource_handle.pb.cc
+tensorflow/core/framework/step_stats.pb.cc
+tensorflow/core/framework/summary.pb.cc
+tensorflow/core/framework/tensor.pb.cc
+tensorflow/core/framework/tensor_description.pb.cc
+tensorflow/core/framework/tensor_shape.pb.cc
+tensorflow/core/framework/tensor_slice.pb.cc
+tensorflow/core/framework/types.pb.cc
+tensorflow/core/framework/variable.pb.cc
+tensorflow/core/framework/versions.pb.cc
+tensorflow/core/grappler/costs/op_performance_data.pb.cc
+tensorflow/core/lib/core/error_codes.pb.cc
tensorflow/core/protobuf/cluster.pb.cc
tensorflow/core/protobuf/config.pb.cc
-tensorflow/core/protobuf/rewriter_config.pb.cc
tensorflow/core/protobuf/debug.pb.cc
tensorflow/core/protobuf/device_properties.pb.cc
-tensorflow/core/lib/core/error_codes.pb.cc
-tensorflow/core/framework/versions.pb.cc
-tensorflow/core/framework/variable.pb.cc
-tensorflow/core/framework/types.pb.cc
-tensorflow/core/framework/tensor_slice.pb.cc
-tensorflow/core/framework/tensor_shape.pb.cc
-tensorflow/core/framework/tensor_description.pb.cc
-tensorflow/core/framework/tensor.pb.cc
-tensorflow/core/framework/summary.pb.cc
-tensorflow/core/framework/step_stats.pb.cc
-tensorflow/core/framework/resource_handle.pb.cc
-tensorflow/core/framework/remote_fused_graph_execute_info.pb.cc
-tensorflow/core/framework/api_def.pb.cc
-tensorflow/core/framework/op_def.pb.cc
-tensorflow/core/framework/node_def.pb.cc
-tensorflow/core/framework/log_memory.pb.cc
-tensorflow/core/framework/kernel_def.pb.cc
-tensorflow/core/framework/graph_transfer_info.pb.cc
-tensorflow/core/framework/graph.pb.cc
-tensorflow/core/framework/function.pb.cc
-tensorflow/core/framework/device_attributes.pb.cc
-tensorflow/core/framework/cost_graph.pb.cc
-tensorflow/core/framework/attr_value.pb.cc
-tensorflow/core/framework/allocation_description.pb.cc
-tensorflow/core/example/feature.pb.cc
-tensorflow/core/example/example.pb.cc
-tensorflow/core/grappler/costs/op_performance_data.pb.cc
+tensorflow/core/protobuf/meta_graph.pb.cc
+tensorflow/core/protobuf/named_tensor.pb.cc
+tensorflow/core/protobuf/queue_runner.pb.cc
+tensorflow/core/protobuf/rewriter_config.pb.cc
+tensorflow/core/protobuf/saver.pb.cc
+tensorflow/core/protobuf/tensorflow_server.pb.cc
+tensorflow/core/util/event.pb.cc
+tensorflow/core/util/memmapped_file_system.pb.cc
+tensorflow/core/util/saved_tensor_slice.pb.cc
+tensorflow/core/util/test_log.pb.cc
diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
index aa91b2f954..d982df9319 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
@@ -1,42 +1,43 @@
-tensorflow/core/util/test_log.pb.h
-tensorflow/core/util/saved_tensor_slice.pb.h
-tensorflow/core/util/memmapped_file_system.pb.h
-tensorflow/core/util/event.pb.h
-tensorflow/core/protobuf/tensorflow_server.pb.h
-tensorflow/core/protobuf/saver.pb.h
-tensorflow/core/protobuf/queue_runner.pb.h
-tensorflow/core/protobuf/named_tensor.pb.h
-tensorflow/core/protobuf/meta_graph.pb.h
+tensorflow/core/example/example.pb.h
+tensorflow/core/example/feature.pb.h
+tensorflow/core/framework/allocation_description.pb.h
+tensorflow/core/framework/api_def.pb.h
+tensorflow/core/framework/attr_value.pb.h
+tensorflow/core/framework/cost_graph.pb.h
+tensorflow/core/framework/device_attributes.pb.h
+tensorflow/core/framework/function.pb.h
+tensorflow/core/framework/graph.pb.h
+tensorflow/core/framework/graph_transfer_info.pb.h
+tensorflow/core/framework/kernel_def.pb.h
+tensorflow/core/framework/log_memory.pb.h
+tensorflow/core/framework/node_def.pb.h
+tensorflow/core/framework/op_def.pb.h
+tensorflow/core/framework/remote_fused_graph_execute_info.pb.h
+tensorflow/core/framework/resource_handle.pb.h
+tensorflow/core/framework/step_stats.pb.h
+tensorflow/core/framework/summary.pb.h
+tensorflow/core/framework/tensor.pb.h
+tensorflow/core/framework/tensor_description.pb.h
+tensorflow/core/framework/tensor_shape.pb.h
+tensorflow/core/framework/tensor_slice.pb.h
+tensorflow/core/framework/types.pb.h
+tensorflow/core/framework/variable.pb.h
+tensorflow/core/framework/versions.pb.h
+tensorflow/core/grappler/costs/op_performance_data.pb.h
+tensorflow/core/lib/core/error_codes.pb.h
tensorflow/core/protobuf/cluster.pb.h
tensorflow/core/protobuf/config.pb.h
tensorflow/core/protobuf/debug.pb.h
tensorflow/core/protobuf/device_properties.pb.h
+tensorflow/core/protobuf/meta_graph.pb.h
+tensorflow/core/protobuf/named_tensor.pb.h
+tensorflow/core/protobuf/queue_runner.pb.h
tensorflow/core/protobuf/rewriter_config.pb.h
+tensorflow/core/protobuf/saver.pb.h
tensorflow/core/protobuf/tensor_bundle.pb.h
-tensorflow/core/lib/core/error_codes.pb.h
-tensorflow/core/framework/versions.pb.h
-tensorflow/core/framework/variable.pb.h
-tensorflow/core/framework/types.pb.h
-tensorflow/core/framework/tensor_slice.pb.h
-tensorflow/core/framework/tensor_shape.pb.h
-tensorflow/core/framework/tensor_description.pb.h
-tensorflow/core/framework/tensor.pb.h
-tensorflow/core/framework/summary.pb.h
-tensorflow/core/framework/step_stats.pb.h
-tensorflow/core/framework/resource_handle.pb.h
-tensorflow/core/framework/remote_fused_graph_execute_info.pb.h
-tensorflow/core/framework/api_def.pb.h
-tensorflow/core/framework/op_def.pb.h
-tensorflow/core/framework/node_def.pb.h
-tensorflow/core/framework/log_memory.pb.h
-tensorflow/core/framework/kernel_def.pb.h
-tensorflow/core/framework/graph_transfer_info.pb.h
-tensorflow/core/framework/graph.pb.h
-tensorflow/core/framework/function.pb.h
-tensorflow/core/framework/device_attributes.pb.h
-tensorflow/core/framework/cost_graph.pb.h
-tensorflow/core/framework/attr_value.pb.h
-tensorflow/core/framework/allocation_description.pb.h
-tensorflow/core/example/feature.pb.h
-tensorflow/core/example/example.pb.h
-tensorflow/core/grappler/costs/op_performance_data.pb.h
+tensorflow/core/protobuf/tensorflow_server.pb.h
+tensorflow/core/util/event.pb.h
+tensorflow/core/util/memmapped_file_system.pb.h
+tensorflow/core/util/saved_tensor_slice.pb.h
+tensorflow/core/util/test_log.pb.h
+
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index ecf2e120df..08de54b8e1 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -4,218 +4,19 @@ tensorflow/contrib/boosted_trees/ops/quantile_ops.cc
tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc
tensorflow/contrib/boosted_trees/ops/training_ops.cc
-tensorflow/core/kernels/xent_op.cc
-tensorflow/core/kernels/where_op.cc
-tensorflow/core/kernels/variable_ops.cc
-tensorflow/core/kernels/unpack_op.cc
-tensorflow/core/kernels/unique_op.cc
-tensorflow/core/kernels/transpose_op.cc
-tensorflow/core/kernels/transpose_functor_cpu.cc
-tensorflow/core/kernels/training_op_helpers.cc
-tensorflow/core/kernels/training_ops.cc
-tensorflow/core/kernels/topk_op.cc
-tensorflow/core/kernels/tile_functor_cpu.cc
-tensorflow/core/kernels/tile_ops.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_1.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_2.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_3.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_4.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_5.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_6.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_7.cc
-tensorflow/core/kernels/tensor_array_ops.cc
-tensorflow/core/kernels/tensor_array.cc
-tensorflow/core/kernels/strided_slice_op_inst_7.cc
-tensorflow/core/kernels/strided_slice_op_inst_6.cc
-tensorflow/core/kernels/strided_slice_op_inst_5.cc
-tensorflow/core/kernels/strided_slice_op_inst_4.cc
-tensorflow/core/kernels/strided_slice_op_inst_3.cc
-tensorflow/core/kernels/strided_slice_op_inst_2.cc
-tensorflow/core/kernels/strided_slice_op_inst_1.cc
-tensorflow/core/kernels/strided_slice_op_inst_0.cc
-tensorflow/core/kernels/strided_slice_op.cc
-tensorflow/core/kernels/stack_ops.cc
-tensorflow/core/kernels/split_op.cc
-tensorflow/core/kernels/split_v_op.cc
-tensorflow/core/kernels/split_lib_cpu.cc
-tensorflow/core/kernels/spectrogram_op.cc
-tensorflow/core/kernels/spectrogram.cc
-tensorflow/core/kernels/sparse_to_dense_op.cc
-tensorflow/core/kernels/sparse_matmul_op.cc
-tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
-tensorflow/core/kernels/sparse_reshape_op.c
-tensorflow/core/kernels/segment_reduction_ops.cc
-tensorflow/core/kernels/softsign_op.cc
-tensorflow/core/kernels/softplus_op.cc
-tensorflow/core/kernels/softmax_op.cc
-tensorflow/core/kernels/slice_op_cpu_impl_1.cc
-tensorflow/core/kernels/slice_op_cpu_impl_2.cc
-tensorflow/core/kernels/slice_op_cpu_impl_3.cc
-tensorflow/core/kernels/slice_op_cpu_impl_4.cc
-tensorflow/core/kernels/slice_op_cpu_impl_5.cc
-tensorflow/core/kernels/slice_op_cpu_impl_6.cc
-tensorflow/core/kernels/slice_op_cpu_impl_7.cc
-tensorflow/core/kernels/slice_op.cc
-tensorflow/core/kernels/shape_ops.cc
-tensorflow/core/kernels/session_ops.cc
-tensorflow/core/kernels/sequence_ops.cc
-tensorflow/core/kernels/sendrecv_ops.cc
-tensorflow/core/kernels/scatter_op.cc
-tensorflow/core/kernels/scatter_functor.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_1.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_2.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_3.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_4.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_5.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_6.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_7.cc
-tensorflow/core/kernels/scatter_nd_op.cc
-tensorflow/core/kernels/save_restore_tensor.cc
-tensorflow/core/kernels/save_restore_v2_ops.cc
-tensorflow/core/kernels/save_op.cc
-tensorflow/core/kernels/string_join_op.cc
-tensorflow/core/kernels/reverse_sequence_op.cc
-tensorflow/core/kernels/reverse_op.cc
-tensorflow/core/kernels/restore_op.cc
-tensorflow/core/kernels/resize_nearest_neighbor_op.cc
-tensorflow/core/kernels/resize_bilinear_op.cc
-tensorflow/core/kernels/reshape_util.cc
-tensorflow/core/kernels/reshape_op.cc
-tensorflow/core/kernels/relu_op.cc
-tensorflow/core/kernels/reduction_ops_sum.cc
-tensorflow/core/kernels/reduction_ops_prod.cc
-tensorflow/core/kernels/reduction_ops_min.cc
-tensorflow/core/kernels/reduction_ops_mean.cc
-tensorflow/core/kernels/reduction_ops_max.cc
-tensorflow/core/kernels/reduction_ops_common.cc
-tensorflow/core/kernels/reduction_ops_any.cc
-tensorflow/core/kernels/reduction_ops_all.cc
-tensorflow/core/kernels/roll_op.cc
-tensorflow/core/kernels/queue_op.cc
-tensorflow/core/kernels/queue_ops.cc
-tensorflow/core/kernels/queue_base.cc
-tensorflow/core/kernels/pooling_ops_common.cc
-tensorflow/core/kernels/padding_fifo_queue_op.cc
-tensorflow/core/kernels/padding_fifo_queue.cc
-tensorflow/core/kernels/pad_op.cc
-tensorflow/core/kernels/pack_op.cc
-tensorflow/core/kernels/ops_util.cc
-tensorflow/core/kernels/one_hot_op.cc
-tensorflow/core/kernels/non_max_suppression_op.cc
-tensorflow/core/kernels/no_op.cc
-tensorflow/core/kernels/mirror_pad_op.cc
-tensorflow/core/kernels/mirror_pad_op_cpu_impl_1.cc
-tensorflow/core/kernels/mirror_pad_op_cpu_impl_2.cc
-tensorflow/core/kernels/mirror_pad_op_cpu_impl_3.cc
-tensorflow/core/kernels/mirror_pad_op_cpu_impl_4.cc
-tensorflow/core/kernels/mirror_pad_op_cpu_impl_5.cc
-tensorflow/core/kernels/mfcc_op.cc
-tensorflow/core/kernels/mfcc_mel_filterbank.cc
-tensorflow/core/kernels/mfcc_dct.cc
-tensorflow/core/kernels/mfcc.cc
-tensorflow/core/kernels/maxpooling_op.cc
-tensorflow/core/kernels/matmul_op.cc
-tensorflow/core/kernels/lrn_op.cc
-tensorflow/core/kernels/logging_ops.cc
-tensorflow/core/kernels/initializable_lookup_table.c
-tensorflow/core/kernels/lookup_table_init_op.cc
-tensorflow/core/kernels/lookup_table_op.cc
-tensorflow/core/kernels/lookup_util.cc
-tensorflow/core/kernels/inplace_ops.cc
-tensorflow/core/kernels/in_topk_op.cc
-tensorflow/core/kernels/immutable_constant_op.cc
-tensorflow/core/kernels/identity_op.cc
-tensorflow/core/kernels/identity_n_op.cc
-tensorflow/core/kernels/gather_op.cc
-tensorflow/core/kernels/gather_functor.cc
-tensorflow/core/kernels/gather_nd_op.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_2.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_3.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_4.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_5.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_6.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_7.cc
-tensorflow/core/kernels/fused_batch_norm_op.cc
-tensorflow/core/kernels/function_ops.cc
-tensorflow/core/kernels/fill_functor.cc
-tensorflow/core/kernels/fifo_queue.cc
-tensorflow/core/kernels/fifo_queue_op.cc
-tensorflow/core/kernels/fake_quant_ops.cc
-tensorflow/core/kernels/example_parsing_ops.cc
-tensorflow/core/kernels/encode_wav_op.cc
-tensorflow/core/kernels/dynamic_stitch_op.cc
-tensorflow/core/kernels/dynamic_partition_op.cc
-tensorflow/core/kernels/decode_bmp_op.cc
-tensorflow/core/kernels/depthtospace_op.cc
-tensorflow/core/kernels/data_format_ops.cc
-tensorflow/core/kernels/spacetodepth_op.cc
-tensorflow/core/kernels/dense_update_functor.cc
-tensorflow/core/kernels/dense_update_ops.cc
-tensorflow/core/kernels/deep_conv2d.cc
-tensorflow/core/kernels/decode_wav_op.cc
-tensorflow/core/kernels/xsmm_conv2d.cc
-tensorflow/core/kernels/cwise_ops_common.cc
-tensorflow/core/kernels/cwise_op_tanh.cc
-tensorflow/core/kernels/cwise_op_pow.cc
-tensorflow/core/kernels/cwise_op_sub.cc
-tensorflow/core/kernels/cwise_op_squared_difference.cc
-tensorflow/core/kernels/cwise_op_square.cc
-tensorflow/core/kernels/cwise_op_sqrt.cc
-tensorflow/core/kernels/cwise_op_sigmoid.cc
-tensorflow/core/kernels/cwise_op_sign.cc
-tensorflow/core/kernels/cwise_op_select.cc
-tensorflow/core/kernels/cwise_op_round.cc
-tensorflow/core/kernels/cwise_op_rsqrt.cc
-tensorflow/core/kernels/cwise_op_reciprocal.cc
-tensorflow/core/kernels/cwise_op_neg.cc
-tensorflow/core/kernels/cwise_op_mul_2.cc
-tensorflow/core/kernels/cwise_op_mul_1.cc
-tensorflow/core/kernels/cwise_op_minimum.cc
-tensorflow/core/kernels/cwise_op_maximum.cc
-tensorflow/core/kernels/cwise_op_logical_not.cc
-tensorflow/core/kernels/cwise_op_logical_and.cc
-tensorflow/core/kernels/cwise_op_logical_or.cc
-tensorflow/core/kernels/cwise_op_log.cc
-tensorflow/core/kernels/cwise_op_less.cc
-tensorflow/core/kernels/cwise_op_less_equal.cc
-tensorflow/core/kernels/cwise_op_isnan.cc
-tensorflow/core/kernels/cwise_op_isfinite.cc
-tensorflow/core/kernels/cwise_op_invert.cc
-tensorflow/core/kernels/cwise_op_greater_equal.cc
-tensorflow/core/kernels/cwise_op_greater.cc
-tensorflow/core/kernels/cwise_op_floor_div.cc
-tensorflow/core/kernels/cwise_op_floor_mod.cc
-tensorflow/core/kernels/cwise_op_floor.cc
-tensorflow/core/kernels/cwise_op_exp.cc
-tensorflow/core/kernels/cwise_op_equal_to_2.cc
-tensorflow/core/kernels/cwise_op_equal_to_1.cc
-tensorflow/core/kernels/cwise_op_not_equal_to_2.cc
-tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
-tensorflow/core/kernels/cwise_op_div.cc
-tensorflow/core/kernels/cwise_op_bitwise_xor.cc
-tensorflow/core/kernels/cwise_op_bitwise_or.cc
-tensorflow/core/kernels/cwise_op_bitwise_and.cc
-tensorflow/core/kernels/cwise_op_left_shift.cc
-tensorflow/core/kernels/cwise_op_right_shift.cc
-tensorflow/core/kernels/cwise_op_add_2.cc
-tensorflow/core/kernels/cwise_op_add_1.cc
-tensorflow/core/kernels/cwise_op_abs.cc
-tensorflow/core/kernels/ctc_decoder_ops.cc
-tensorflow/core/kernels/crop_and_resize_op.cc
-tensorflow/core/kernels/conv_ops_using_gemm.cc
-tensorflow/core/kernels/conv_ops_fused.cc
-tensorflow/core/kernels/conv_ops.cc
-tensorflow/core/kernels/conv_grad_filter_ops.cc
-tensorflow/core/kernels/conv_grad_input_ops.cc
-tensorflow/core/kernels/conv_grad_ops.cc
-tensorflow/core/kernels/control_flow_ops.cc
-tensorflow/core/kernels/constant_op.cc
-tensorflow/core/kernels/concat_op.cc
-tensorflow/core/kernels/concat_lib_cpu.cc
-tensorflow/core/kernels/check_numerics_op.cc
+tensorflow/core/kernels/aggregate_ops.cc
+tensorflow/core/kernels/argmax_op.cc
+tensorflow/core/kernels/avgpooling_op.cc
+tensorflow/core/kernels/batch_matmul_op_real.cc
+tensorflow/core/kernels/batch_norm_op.cc
+tensorflow/core/kernels/batchtospace_op.cc
+tensorflow/core/kernels/bcast_ops.cc
+tensorflow/core/kernels/bias_op.cc
+tensorflow/core/kernels/boosted_trees/prediction_ops.cc
+tensorflow/core/kernels/boosted_trees/resource_ops.cc
+tensorflow/core/kernels/boosted_trees/resources.cc
+tensorflow/core/kernels/boosted_trees/stats_ops.cc
+tensorflow/core/kernels/boosted_trees/training_ops.cc
tensorflow/core/kernels/cast_op.cc
tensorflow/core/kernels/cast_op_impl_bfloat.cc
tensorflow/core/kernels/cast_op_impl_bool.cc
@@ -232,20 +33,131 @@ tensorflow/core/kernels/cast_op_impl_uint16.cc
tensorflow/core/kernels/cast_op_impl_uint32.cc
tensorflow/core/kernels/cast_op_impl_uint64.cc
tensorflow/core/kernels/cast_op_impl_uint8.cc
-tensorflow/core/kernels/boosted_trees/prediction_ops.cc
-tensorflow/core/kernels/boosted_trees/resource_ops.cc
-tensorflow/core/kernels/boosted_trees/resources.cc
-tensorflow/core/kernels/boosted_trees/stats_ops.cc
-tensorflow/core/kernels/boosted_trees/training_ops.cc
-tensorflow/core/kernels/bias_op.cc
-tensorflow/core/kernels/bcast_ops.cc
-tensorflow/core/kernels/batch_norm_op.cc
-tensorflow/core/kernels/avgpooling_op.cc
-tensorflow/core/kernels/argmax_op.cc
-tensorflow/core/kernels/aggregate_ops.cc
+tensorflow/core/kernels/check_numerics_op.cc
+tensorflow/core/kernels/concat_lib_cpu.cc
+tensorflow/core/kernels/concat_op.cc
+tensorflow/core/kernels/constant_op.cc
+tensorflow/core/kernels/control_flow_ops.cc
+tensorflow/core/kernels/conv_grad_filter_ops.cc
+tensorflow/core/kernels/conv_grad_input_ops.cc
+tensorflow/core/kernels/conv_grad_ops.cc
+tensorflow/core/kernels/conv_ops.cc
+tensorflow/core/kernels/conv_ops_fused.cc
+tensorflow/core/kernels/conv_ops_using_gemm.cc
+tensorflow/core/kernels/crop_and_resize_op.cc
+tensorflow/core/kernels/ctc_decoder_ops.cc
+tensorflow/core/kernels/cwise_op_abs.cc
+tensorflow/core/kernels/cwise_op_add_1.cc
+tensorflow/core/kernels/cwise_op_add_2.cc
+tensorflow/core/kernels/cwise_op_bitwise_and.cc
+tensorflow/core/kernels/cwise_op_bitwise_or.cc
+tensorflow/core/kernels/cwise_op_bitwise_xor.cc
+tensorflow/core/kernels/cwise_op_div.cc
+tensorflow/core/kernels/cwise_op_equal_to_1.cc
+tensorflow/core/kernels/cwise_op_equal_to_2.cc
+tensorflow/core/kernels/cwise_op_exp.cc
+tensorflow/core/kernels/cwise_op_floor.cc
+tensorflow/core/kernels/cwise_op_floor_div.cc
+tensorflow/core/kernels/cwise_op_floor_mod.cc
+tensorflow/core/kernels/cwise_op_greater.cc
+tensorflow/core/kernels/cwise_op_greater_equal.cc
+tensorflow/core/kernels/cwise_op_invert.cc
+tensorflow/core/kernels/cwise_op_isfinite.cc
+tensorflow/core/kernels/cwise_op_isnan.cc
+tensorflow/core/kernels/cwise_op_left_shift.cc
+tensorflow/core/kernels/cwise_op_less.cc
+tensorflow/core/kernels/cwise_op_less_equal.cc
+tensorflow/core/kernels/cwise_op_log.cc
+tensorflow/core/kernels/cwise_op_logical_and.cc
+tensorflow/core/kernels/cwise_op_logical_not.cc
+tensorflow/core/kernels/cwise_op_logical_or.cc
+tensorflow/core/kernels/cwise_op_maximum.cc
+tensorflow/core/kernels/cwise_op_minimum.cc
+tensorflow/core/kernels/cwise_op_mul_1.cc
+tensorflow/core/kernels/cwise_op_mul_2.cc
+tensorflow/core/kernels/cwise_op_neg.cc
+tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
+tensorflow/core/kernels/cwise_op_not_equal_to_2.cc
+tensorflow/core/kernels/cwise_op_pow.cc
+tensorflow/core/kernels/cwise_op_reciprocal.cc
+tensorflow/core/kernels/cwise_op_right_shift.cc
+tensorflow/core/kernels/cwise_op_round.cc
+tensorflow/core/kernels/cwise_op_rsqrt.cc
+tensorflow/core/kernels/cwise_op_select.cc
+tensorflow/core/kernels/cwise_op_sigmoid.cc
+tensorflow/core/kernels/cwise_op_sign.cc
+tensorflow/core/kernels/cwise_op_sqrt.cc
+tensorflow/core/kernels/cwise_op_square.cc
+tensorflow/core/kernels/cwise_op_squared_difference.cc
+tensorflow/core/kernels/cwise_op_sub.cc
+tensorflow/core/kernels/cwise_op_tanh.cc
+tensorflow/core/kernels/cwise_ops_common.cc
+tensorflow/core/kernels/data_format_ops.cc
+tensorflow/core/kernels/decode_bmp_op.cc
+tensorflow/core/kernels/decode_proto_op.cc
+tensorflow/core/kernels/decode_wav_op.cc
+tensorflow/core/kernels/deep_conv2d.cc
+tensorflow/core/kernels/dense_update_functor.cc
+tensorflow/core/kernels/dense_update_ops.cc
+tensorflow/core/kernels/depthtospace_op.cc
tensorflow/core/kernels/depthwise_conv_op.cc
tensorflow/core/kernels/dequantize_op.cc
+tensorflow/core/kernels/dynamic_partition_op.cc
+tensorflow/core/kernels/dynamic_stitch_op.cc
+tensorflow/core/kernels/encode_proto_op.cc
+tensorflow/core/kernels/encode_wav_op.cc
+tensorflow/core/kernels/example_parsing_ops.cc
+tensorflow/core/kernels/fake_quant_ops.cc
+tensorflow/core/kernels/fifo_queue.cc
+tensorflow/core/kernels/fifo_queue_op.cc
+tensorflow/core/kernels/fill_functor.cc
+tensorflow/core/kernels/function_ops.cc
+tensorflow/core/kernels/fused_batch_norm_op.cc
+tensorflow/core/kernels/gather_functor.cc
+tensorflow/core/kernels/gather_nd_op.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_2.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_3.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_4.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_5.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_6.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_7.cc
+tensorflow/core/kernels/gather_op.cc
+tensorflow/core/kernels/identity_n_op.cc
+tensorflow/core/kernels/identity_op.cc
+tensorflow/core/kernels/immutable_constant_op.cc
+tensorflow/core/kernels/in_topk_op.cc
+tensorflow/core/kernels/initializable_lookup_table.c
+tensorflow/core/kernels/inplace_ops.cc
+tensorflow/core/kernels/listdiff_op.cc
+tensorflow/core/kernels/logging_ops.cc
+tensorflow/core/kernels/lookup_table_init_op.cc
+tensorflow/core/kernels/lookup_table_op.cc
+tensorflow/core/kernels/lookup_util.cc
+tensorflow/core/kernels/lrn_op.cc
+tensorflow/core/kernels/matmul_op.cc
+tensorflow/core/kernels/maxpooling_op.cc
tensorflow/core/kernels/meta_support.cc
+tensorflow/core/kernels/mfcc.cc
+tensorflow/core/kernels/mfcc_dct.cc
+tensorflow/core/kernels/mfcc_mel_filterbank.cc
+tensorflow/core/kernels/mfcc_op.cc
+tensorflow/core/kernels/mirror_pad_op.cc
+tensorflow/core/kernels/mirror_pad_op_cpu_impl_1.cc
+tensorflow/core/kernels/mirror_pad_op_cpu_impl_2.cc
+tensorflow/core/kernels/mirror_pad_op_cpu_impl_3.cc
+tensorflow/core/kernels/mirror_pad_op_cpu_impl_4.cc
+tensorflow/core/kernels/mirror_pad_op_cpu_impl_5.cc
+tensorflow/core/kernels/no_op.cc
+tensorflow/core/kernels/non_max_suppression_op.cc
+tensorflow/core/kernels/one_hot_op.cc
+tensorflow/core/kernels/ops_util.cc
+tensorflow/core/kernels/pack_op.cc
+tensorflow/core/kernels/pad_op.cc
+tensorflow/core/kernels/padding_fifo_queue.cc
+tensorflow/core/kernels/padding_fifo_queue_op.cc
+tensorflow/core/kernels/pooling_ops_common.cc
tensorflow/core/kernels/population_count_op.cc
tensorflow/core/kernels/quantization_utils.cc
tensorflow/core/kernels/quantize_down_and_shrink_range.cc
@@ -262,47 +174,135 @@ tensorflow/core/kernels/quantized_mul_op.cc
tensorflow/core/kernels/quantized_pooling_ops.cc
tensorflow/core/kernels/quantized_reshape_op.cc
tensorflow/core/kernels/quantized_resize_bilinear_op.cc
-tensorflow/core/kernels/requantization_range_op.cc
-tensorflow/core/kernels/requantize.cc
+tensorflow/core/kernels/queue_base.cc
+tensorflow/core/kernels/queue_op.cc
+tensorflow/core/kernels/queue_ops.cc
+tensorflow/core/kernels/random_op.cc
+tensorflow/core/kernels/reduction_ops_all.cc
+tensorflow/core/kernels/reduction_ops_any.cc
+tensorflow/core/kernels/reduction_ops_common.cc
+tensorflow/core/kernels/reduction_ops_max.cc
+tensorflow/core/kernels/reduction_ops_mean.cc
+tensorflow/core/kernels/reduction_ops_min.cc
+tensorflow/core/kernels/reduction_ops_prod.cc
+tensorflow/core/kernels/reduction_ops_sum.cc
+tensorflow/core/kernels/relu_op.cc
tensorflow/core/kernels/remote_fused_graph_execute_op.cc
tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
-tensorflow/core/kernels/batch_matmul_op_real.cc
-tensorflow/core/kernels/random_op.cc
-tensorflow/core/ops/training_ops.cc
-tensorflow/core/ops/string_ops.cc
-tensorflow/core/ops/state_ops.cc
-tensorflow/core/ops/sparse_ops.cc
-tensorflow/core/ops/sendrecv_ops.cc
-tensorflow/core/ops/script_ops.cc
-tensorflow/core/ops/remote_fused_graph_ops.cc
-tensorflow/core/ops/random_ops.cc
-tensorflow/core/ops/random_grad.cc
-tensorflow/core/ops/parsing_ops.cc
-tensorflow/core/ops/no_op.cc
-tensorflow/core/ops/nn_ops.cc
-tensorflow/core/ops/nn_grad.cc
-tensorflow/core/ops/manip_ops.cc
-tensorflow/core/ops/math_ops.cc
-tensorflow/core/ops/math_grad.cc
-tensorflow/core/ops/logging_ops.cc
-tensorflow/core/ops/linalg_ops.cc
-tensorflow/core/ops/io_ops.cc
-tensorflow/core/ops/image_ops.cc
-tensorflow/core/ops/functional_ops.cc
-tensorflow/core/ops/functional_grad.cc
-tensorflow/core/ops/function_ops.cc
-tensorflow/core/ops/data_flow_ops.cc
-tensorflow/core/ops/ctc_ops.cc
-tensorflow/core/ops/control_flow_ops.cc
-tensorflow/core/ops/candidate_sampling_ops.cc
-tensorflow/core/ops/boosted_trees_ops.cc
-tensorflow/core/ops/array_ops.cc
-tensorflow/core/ops/array_grad.cc
+tensorflow/core/kernels/requantization_range_op.cc
+tensorflow/core/kernels/requantize.cc
+tensorflow/core/kernels/reshape_op.cc
+tensorflow/core/kernels/reshape_util.cc
+tensorflow/core/kernels/resize_bilinear_op.cc
+tensorflow/core/kernels/resize_nearest_neighbor_op.cc
+tensorflow/core/kernels/restore_op.cc
+tensorflow/core/kernels/reverse_op.cc
+tensorflow/core/kernels/reverse_sequence_op.cc
+tensorflow/core/kernels/roll_op.cc
+tensorflow/core/kernels/save_op.cc
+tensorflow/core/kernels/save_restore_tensor.cc
+tensorflow/core/kernels/save_restore_v2_ops.cc
+tensorflow/core/kernels/scatter_functor.cc
+tensorflow/core/kernels/scatter_nd_op.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_1.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_2.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_3.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_4.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_5.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_6.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_7.cc
+tensorflow/core/kernels/scatter_op.cc
+tensorflow/core/kernels/segment_reduction_ops.cc
+tensorflow/core/kernels/segment_reduction_ops.cc
+tensorflow/core/kernels/sendrecv_ops.cc
+tensorflow/core/kernels/sequence_ops.cc
+tensorflow/core/kernels/session_ops.cc
+tensorflow/core/kernels/shape_ops.cc
+tensorflow/core/kernels/slice_op.cc
+tensorflow/core/kernels/slice_op_cpu_impl_1.cc
+tensorflow/core/kernels/slice_op_cpu_impl_2.cc
+tensorflow/core/kernels/slice_op_cpu_impl_3.cc
+tensorflow/core/kernels/slice_op_cpu_impl_4.cc
+tensorflow/core/kernels/slice_op_cpu_impl_5.cc
+tensorflow/core/kernels/slice_op_cpu_impl_6.cc
+tensorflow/core/kernels/slice_op_cpu_impl_7.cc
+tensorflow/core/kernels/softmax_op.cc
+tensorflow/core/kernels/softplus_op.cc
+tensorflow/core/kernels/softsign_op.cc
tensorflow/core/kernels/spacetobatch_functor.cc
tensorflow/core/kernels/spacetobatch_op.cc
-tensorflow/core/kernels/batchtospace_op.cc
-tensorflow/core/kernels/warn_about_ints.cc
-tensorflow/core/kernels/segment_reduction_ops.cc
+tensorflow/core/kernels/spacetodepth_op.cc
+tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
+tensorflow/core/kernels/sparse_matmul_op.cc
+tensorflow/core/kernels/sparse_reshape_op.c
+tensorflow/core/kernels/sparse_to_dense_op.cc
+tensorflow/core/kernels/spectrogram.cc
+tensorflow/core/kernels/spectrogram_op.cc
+tensorflow/core/kernels/split_lib_cpu.cc
+tensorflow/core/kernels/split_op.cc
+tensorflow/core/kernels/split_v_op.cc
+tensorflow/core/kernels/stack_ops.cc
+tensorflow/core/kernels/strided_slice_op.cc
+tensorflow/core/kernels/strided_slice_op_inst_0.cc
+tensorflow/core/kernels/strided_slice_op_inst_1.cc
+tensorflow/core/kernels/strided_slice_op_inst_2.cc
+tensorflow/core/kernels/strided_slice_op_inst_3.cc
+tensorflow/core/kernels/strided_slice_op_inst_4.cc
+tensorflow/core/kernels/strided_slice_op_inst_5.cc
+tensorflow/core/kernels/strided_slice_op_inst_6.cc
+tensorflow/core/kernels/strided_slice_op_inst_7.cc
+tensorflow/core/kernels/string_join_op.cc
+tensorflow/core/kernels/tensor_array.cc
+tensorflow/core/kernels/tensor_array_ops.cc
+tensorflow/core/kernels/tile_functor_cpu.cc
+tensorflow/core/kernels/tile_ops.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_1.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_2.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_3.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_4.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_5.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_6.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_7.cc
+tensorflow/core/kernels/topk_op.cc
+tensorflow/core/kernels/training_op_helpers.cc
+tensorflow/core/kernels/training_ops.cc
+tensorflow/core/kernels/transpose_functor_cpu.cc
+tensorflow/core/kernels/transpose_op.cc
+tensorflow/core/kernels/unique_op.cc
+tensorflow/core/kernels/unpack_op.cc
+tensorflow/core/kernels/variable_ops.cc
+tensorflow/core/kernels/where_op.cc
+tensorflow/core/kernels/xent_op.cc
+tensorflow/core/kernels/xsmm_conv2d.cc
+tensorflow/core/ops/array_grad.cc
+tensorflow/core/ops/array_ops.cc
tensorflow/core/ops/audio_ops.cc
-tensorflow/core/kernels/decode_proto_op.cc
-tensorflow/core/kernels/encode_proto_op.cc
+tensorflow/core/ops/boosted_trees_ops.cc
+tensorflow/core/ops/candidate_sampling_ops.cc
+tensorflow/core/ops/control_flow_ops.cc
+tensorflow/core/ops/ctc_ops.cc
+tensorflow/core/ops/data_flow_ops.cc
+tensorflow/core/ops/function_ops.cc
+tensorflow/core/ops/functional_grad.cc
+tensorflow/core/ops/functional_ops.cc
+tensorflow/core/ops/image_ops.cc
+tensorflow/core/ops/io_ops.cc
+tensorflow/core/ops/linalg_ops.cc
+tensorflow/core/ops/logging_ops.cc
+tensorflow/core/ops/manip_ops.cc
+tensorflow/core/ops/math_grad.cc
+tensorflow/core/ops/math_ops.cc
+tensorflow/core/ops/nn_grad.cc
+tensorflow/core/ops/nn_ops.cc
+tensorflow/core/ops/no_op.cc
+tensorflow/core/ops/parsing_ops.cc
+tensorflow/core/ops/random_grad.cc
+tensorflow/core/ops/random_ops.cc
+tensorflow/core/ops/remote_fused_graph_ops.cc
+tensorflow/core/ops/script_ops.cc
+tensorflow/core/ops/sendrecv_ops.cc
+tensorflow/core/ops/sparse_ops.cc
+tensorflow/core/ops/state_ops.cc
+tensorflow/core/ops/string_ops.cc
+tensorflow/core/ops/training_ops.cc
diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt
index b5431df2eb..f94d70db90 100644
--- a/tensorflow/contrib/makefile/tf_pb_text_files.txt
+++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt
@@ -1,33 +1,33 @@
-tensorflow/core/util/saved_tensor_slice.pb_text.cc
-tensorflow/core/util/memmapped_file_system.pb_text.cc
-tensorflow/core/protobuf/saver.pb_text.cc
+tensorflow/core/example/example.pb_text.cc
+tensorflow/core/example/feature.pb_text.cc
+tensorflow/core/framework/allocation_description.pb_text.cc
+tensorflow/core/framework/api_def.pb_text.cc
+tensorflow/core/framework/attr_value.pb_text.cc
+tensorflow/core/framework/cost_graph.pb_text.cc
+tensorflow/core/framework/device_attributes.pb_text.cc
+tensorflow/core/framework/function.pb_text.cc
+tensorflow/core/framework/graph.pb_text.cc
+tensorflow/core/framework/graph_transfer_info.pb_text.cc
+tensorflow/core/framework/kernel_def.pb_text.cc
+tensorflow/core/framework/log_memory.pb_text.cc
+tensorflow/core/framework/node_def.pb_text.cc
+tensorflow/core/framework/op_def.pb_text.cc
+tensorflow/core/framework/remote_fused_graph_execute_info.pb_text.cc
+tensorflow/core/framework/resource_handle.pb_text.cc
+tensorflow/core/framework/step_stats.pb_text.cc
+tensorflow/core/framework/summary.pb_text.cc
+tensorflow/core/framework/tensor.pb_text.cc
+tensorflow/core/framework/tensor_description.pb_text.cc
+tensorflow/core/framework/tensor_shape.pb_text.cc
+tensorflow/core/framework/tensor_slice.pb_text.cc
+tensorflow/core/framework/types.pb_text.cc
+tensorflow/core/framework/versions.pb_text.cc
+tensorflow/core/lib/core/error_codes.pb_text.cc
tensorflow/core/protobuf/cluster.pb_text.cc
tensorflow/core/protobuf/config.pb_text.cc
tensorflow/core/protobuf/debug.pb_text.cc
tensorflow/core/protobuf/rewriter_config.pb_text.cc
+tensorflow/core/protobuf/saver.pb_text.cc
tensorflow/core/protobuf/tensor_bundle.pb_text.cc
-tensorflow/core/lib/core/error_codes.pb_text.cc
-tensorflow/core/framework/versions.pb_text.cc
-tensorflow/core/framework/types.pb_text.cc
-tensorflow/core/framework/tensor_slice.pb_text.cc
-tensorflow/core/framework/tensor_shape.pb_text.cc
-tensorflow/core/framework/tensor_description.pb_text.cc
-tensorflow/core/framework/tensor.pb_text.cc
-tensorflow/core/framework/summary.pb_text.cc
-tensorflow/core/framework/step_stats.pb_text.cc
-tensorflow/core/framework/resource_handle.pb_text.cc
-tensorflow/core/framework/remote_fused_graph_execute_info.pb_text.cc
-tensorflow/core/framework/api_def.pb_text.cc
-tensorflow/core/framework/op_def.pb_text.cc
-tensorflow/core/framework/node_def.pb_text.cc
-tensorflow/core/framework/log_memory.pb_text.cc
-tensorflow/core/framework/kernel_def.pb_text.cc
-tensorflow/core/framework/graph_transfer_info.pb_text.cc
-tensorflow/core/framework/graph.pb_text.cc
-tensorflow/core/framework/function.pb_text.cc
-tensorflow/core/framework/device_attributes.pb_text.cc
-tensorflow/core/framework/cost_graph.pb_text.cc
-tensorflow/core/framework/attr_value.pb_text.cc
-tensorflow/core/framework/allocation_description.pb_text.cc
-tensorflow/core/example/feature.pb_text.cc
-tensorflow/core/example/example.pb_text.cc
+tensorflow/core/util/memmapped_file_system.pb_text.cc
+tensorflow/core/util/saved_tensor_slice.pb_text.cc
diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt
index 1f254692d7..8bec3e3e01 100644
--- a/tensorflow/contrib/makefile/tf_proto_files.txt
+++ b/tensorflow/contrib/makefile/tf_proto_files.txt
@@ -2,47 +2,47 @@ tensorflow/contrib/boosted_trees/proto/learner.proto
tensorflow/contrib/boosted_trees/proto/quantiles.proto
tensorflow/contrib/boosted_trees/proto/split_info.proto
tensorflow/contrib/boosted_trees/proto/tree_config.proto
-tensorflow/core/util/test_log.proto
-tensorflow/core/util/saved_tensor_slice.proto
-tensorflow/core/util/memmapped_file_system.proto
-tensorflow/core/util/event.proto
-tensorflow/core/protobuf/tensorflow_server.proto
-tensorflow/core/protobuf/saver.proto
-tensorflow/core/protobuf/queue_runner.proto
-tensorflow/core/protobuf/named_tensor.proto
-tensorflow/core/protobuf/meta_graph.proto
+tensorflow/core/example/example.proto
+tensorflow/core/example/feature.proto
+tensorflow/core/framework/allocation_description.proto
+tensorflow/core/framework/api_def.proto
+tensorflow/core/framework/attr_value.proto
+tensorflow/core/framework/cost_graph.proto
+tensorflow/core/framework/device_attributes.proto
+tensorflow/core/framework/function.proto
+tensorflow/core/framework/graph.proto
+tensorflow/core/framework/graph_transfer_info.proto
+tensorflow/core/framework/kernel_def.proto
+tensorflow/core/framework/log_memory.proto
+tensorflow/core/framework/node_def.proto
+tensorflow/core/framework/op_def.proto
+tensorflow/core/framework/reader_base.proto
+tensorflow/core/framework/remote_fused_graph_execute_info.proto
+tensorflow/core/framework/resource_handle.proto
+tensorflow/core/framework/step_stats.proto
+tensorflow/core/framework/summary.proto
+tensorflow/core/framework/tensor.proto
+tensorflow/core/framework/tensor_description.proto
+tensorflow/core/framework/tensor_shape.proto
+tensorflow/core/framework/tensor_slice.proto
+tensorflow/core/framework/types.proto
+tensorflow/core/framework/variable.proto
+tensorflow/core/framework/versions.proto
+tensorflow/core/grappler/costs/op_performance_data.proto
+tensorflow/core/kernels/boosted_trees/boosted_trees.proto
+tensorflow/core/lib/core/error_codes.proto
tensorflow/core/protobuf/cluster.proto
tensorflow/core/protobuf/config.proto
tensorflow/core/protobuf/debug.proto
tensorflow/core/protobuf/device_properties.proto
+tensorflow/core/protobuf/meta_graph.proto
+tensorflow/core/protobuf/named_tensor.proto
+tensorflow/core/protobuf/queue_runner.proto
tensorflow/core/protobuf/rewriter_config.proto
+tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/tensor_bundle.proto
-tensorflow/core/lib/core/error_codes.proto
-tensorflow/core/kernels/boosted_trees/boosted_trees.proto
-tensorflow/core/framework/versions.proto
-tensorflow/core/framework/variable.proto
-tensorflow/core/framework/types.proto
-tensorflow/core/framework/tensor_slice.proto
-tensorflow/core/framework/tensor_shape.proto
-tensorflow/core/framework/tensor_description.proto
-tensorflow/core/framework/tensor.proto
-tensorflow/core/framework/summary.proto
-tensorflow/core/framework/step_stats.proto
-tensorflow/core/framework/resource_handle.proto
-tensorflow/core/framework/remote_fused_graph_execute_info.proto
-tensorflow/core/framework/reader_base.proto
-tensorflow/core/framework/api_def.proto
-tensorflow/core/framework/op_def.proto
-tensorflow/core/framework/node_def.proto
-tensorflow/core/framework/log_memory.proto
-tensorflow/core/framework/kernel_def.proto
-tensorflow/core/framework/graph_transfer_info.proto
-tensorflow/core/framework/graph.proto
-tensorflow/core/framework/function.proto
-tensorflow/core/framework/device_attributes.proto
-tensorflow/core/framework/cost_graph.proto
-tensorflow/core/framework/attr_value.proto
-tensorflow/core/framework/allocation_description.proto
-tensorflow/core/example/feature.proto
-tensorflow/core/example/example.proto
-tensorflow/core/grappler/costs/op_performance_data.proto
+tensorflow/core/protobuf/tensorflow_server.proto
+tensorflow/core/util/event.proto
+tensorflow/core/util/memmapped_file_system.proto
+tensorflow/core/util/saved_tensor_slice.proto
+tensorflow/core/util/test_log.proto
diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py
index 88798d61b7..5645784f8d 100644
--- a/tensorflow/contrib/metrics/__init__.py
+++ b/tensorflow/contrib/metrics/__init__.py
@@ -14,7 +14,9 @@
# ==============================================================================
"""Ops for evaluation metrics and summary statistics.
-See the @{$python/contrib.metrics} guide.
+See the
+[Contrib Metrics](https://tensorflow.org/api_guides/python/contrib.metrics)
+guide.
@@auc_with_confidence_intervals
@@streaming_accuracy
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index a328670526..bbf5d3f30c 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -2532,7 +2532,8 @@ def sparse_recall_at_top_k(labels,
name=name_scope)
-def _compute_recall_at_precision(tp, fp, fn, precision, name):
+def _compute_recall_at_precision(tp, fp, fn, precision, name,
+ strict_mode=False):
"""Helper function to compute recall at a given `precision`.
Args:
@@ -2541,17 +2542,42 @@ def _compute_recall_at_precision(tp, fp, fn, precision, name):
fn: The number of false negatives.
precision: The precision for which the recall will be calculated.
name: An optional variable_scope name.
+    strict_mode: If true and there exists a threshold where the precision is
+      no smaller than the target precision, return the recall at that
+      threshold. Otherwise, return 0. If false, find the threshold where the
+      precision is closest to the target precision and return the recall at
+      that threshold.
Returns:
The recall at a given `precision`.
"""
precisions = math_ops.div(tp, tp + fp + _EPSILON)
- tf_index = math_ops.argmin(
- math_ops.abs(precisions - precision), 0, output_type=dtypes.int32)
+ if not strict_mode:
+ tf_index = math_ops.argmin(
+ math_ops.abs(precisions - precision), 0, output_type=dtypes.int32)
+ # Now, we have the implicit threshold, so compute the recall:
+ return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON,
+ name)
+ else:
+    # We aim to find the threshold whose precision is the smallest one that
+    # is no smaller than the target precision. The rationale:
+    # 1. Compute the difference between the precisions (at the different
+    #    thresholds) and the target precision.
+    # 2. Take the reciprocal of the values from step 1. This makes positive
+    #    values sort before negative ones, and smaller positive values sort
+    #    before larger ones, so argmax picks the threshold with the smallest
+    #    precision that still meets the target.
+ tf_index = math_ops.argmax(
+ math_ops.div(1.0, precisions - precision + _EPSILON),
+ 0,
+ output_type=dtypes.int32)
+
+ def _return_good_recall():
+ return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON,
+ name)
- # Now, we have the implicit threshold, so compute the recall:
- return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON,
- name)
+ return control_flow_ops.cond(precisions[tf_index] >= precision,
+                                 _return_good_recall, lambda: 0.0)
def recall_at_precision(labels,
@@ -2561,7 +2587,8 @@ def recall_at_precision(labels,
num_thresholds=200,
metrics_collections=None,
updates_collections=None,
- name=None):
+ name=None,
+ strict_mode=False):
"""Computes `recall` at `precision`.
The `recall_at_precision` function creates four local variables,
@@ -2593,6 +2620,11 @@ def recall_at_precision(labels,
updates_collections: An optional list of collections that `update_op` should
be added to.
name: An optional variable_scope name.
+    strict_mode: If true and there exists a threshold where the precision is
+      no smaller than the target precision, return the recall at that
+      threshold. Otherwise, return 0. If false, find the threshold where the
+      precision is closest to the target precision and return the recall at
+      that threshold.
Returns:
recall: A scalar `Tensor` representing the recall at the given
@@ -2621,10 +2653,11 @@ def recall_at_precision(labels,
predictions, labels, thresholds, weights)
recall = _compute_recall_at_precision(values['tp'], values['fp'],
- values['fn'], precision, 'value')
+ values['fn'], precision, 'value',
+ strict_mode)
update_op = _compute_recall_at_precision(update_ops['tp'], update_ops['fp'],
update_ops['fn'], precision,
- 'update_op')
+ 'update_op', strict_mode)
if metrics_collections:
ops.add_to_collections(metrics_collections, recall)
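
The strict_mode branch above uses the reciprocal trick described in its
comments to select the threshold with the smallest precision that still meets
the target, falling back to 0 when no threshold qualifies. As a sanity check,
here is a minimal NumPy sketch of that selection logic; the function name and
the tp/fp/fn numbers are invented for the example and are not part of the
patch:

import numpy as np

_EPSILON = 1e-7

def recall_at_precision_sketch(tp, fp, fn, target, strict_mode=False):
  # Illustrative NumPy analogue of the _compute_recall_at_precision logic.
  precisions = tp / (tp + fp + _EPSILON)
  if not strict_mode:
    # Pick the threshold whose precision is closest to the target,
    # whether it falls above or below it.
    idx = np.argmin(np.abs(precisions - target))
    return tp[idx] / (tp[idx] + fn[idx] + _EPSILON)
  # Reciprocal trick: differences >= 0 (precision meets the target) give
  # large positive reciprocals, and the smallest positive difference gives
  # the largest one, so argmax lands on the smallest qualifying precision.
  idx = np.argmax(1.0 / (precisions - target + _EPSILON))
  if precisions[idx] >= target:
    return tp[idx] / (tp[idx] + fn[idx] + _EPSILON)
  return 0.0

# Precisions are 0.8, ~0.857 and 0.4; with a 0.7 target, strict mode picks
# the first threshold (smallest precision that is still >= 0.7).
tp = np.array([8.0, 6.0, 2.0])
fp = np.array([2.0, 1.0, 3.0])
fn = np.array([1.0, 3.0, 7.0])
print(recall_at_precision_sketch(tp, fp, fn, 0.7, strict_mode=True))  # ~0.889
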
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py
index 7acfc383eb..5777e64c29 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py
@@ -47,7 +47,7 @@ class StreamingPrecisionRecallAtEqualThresholdsLargeTest(test.TestCase):
# code used float32 for accumulation.
num_updates = 71
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for _ in xrange(num_updates):
sess.run(update_op)
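
The change above and the hunks that follow are a mechanical migration of the
metrics tests from self.test_session() to self.cached_session(). On
tf.test.TestCase, cached_session() returns a session that is cached and reused
for the duration of the test, which is what these tests depend on when they
run an update_op and then read the metric value back from the same session
state. A minimal, self-contained sketch of the pattern (the test class and
values are illustrative, not taken from the patch):

import tensorflow as tf

class CachedSessionExampleTest(tf.test.TestCase):

  def testRunAndReadBack(self):
    # cached_session() hands back one session that is reused for the whole
    # test, so variable state mutated by earlier run() calls stays visible.
    with self.cached_session() as sess:
      counter = tf.Variable(0.0)
      increment = counter.assign_add(1.0)
      sess.run(tf.global_variables_initializer())
      sess.run(increment)
      self.assertAllClose(1.0, sess.run(counter))

if __name__ == '__main__':
  tf.test.main()
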
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index 401fedcbed..955b83b44d 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -178,7 +178,7 @@ class StreamingMeanTest(test.TestCase):
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -195,7 +195,7 @@ class StreamingMeanTest(test.TestCase):
self.assertAlmostEqual(1.65, sess.run(mean), 5)
def testUpdateOpsReturnsCurrentValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -216,7 +216,7 @@ class StreamingMeanTest(test.TestCase):
self.assertAlmostEqual(1.65, sess.run(mean), 5)
def test1dWeightedValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -243,7 +243,7 @@ class StreamingMeanTest(test.TestCase):
self.assertAlmostEqual((0 + 1 - 3.2 + 4.0) / 4.0, mean.eval(), 5)
def test1dWeightedValues_placeholders(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
values = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -265,7 +265,7 @@ class StreamingMeanTest(test.TestCase):
self.assertAlmostEqual((0 + 1 - 3.2 + 4.0) / 4.0, mean.eval(), 5)
def test2dWeightedValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -292,7 +292,7 @@ class StreamingMeanTest(test.TestCase):
self.assertAlmostEqual((0 + 1 - 4.2 + 0) / 4.0, mean.eval(), 5)
def test2dWeightedValues_placeholders(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
values = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -337,7 +337,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -354,7 +354,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean))
def testMultiDimensional(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(2, 2, 2))
_enqueue_vector(
@@ -375,7 +375,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertAllClose([[[1, 2], [1, 2]], [[2, 3], [5, 6]]], sess.run(mean))
def testUpdateOpsReturnsCurrentValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -396,7 +396,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean), 5)
def testWeighted1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -423,7 +423,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertAllClose([[3.25, 0.5]], sess.run(mean), 5)
def testWeighted2d_1(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -450,7 +450,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertAllClose([[-2.1, 0.5]], sess.run(mean), 5)
def testWeighted2d_2(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -526,7 +526,7 @@ class StreamingAccuracyTest(test.TestCase):
(10, 3), maxval=3, dtype=dtypes_lib.int64, seed=2)
accuracy, update_op = metrics.streaming_accuracy(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -539,7 +539,7 @@ class StreamingAccuracyTest(test.TestCase):
self.assertEqual(initial_accuracy, accuracy.eval())
def testMultipleUpdates(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 1))
@@ -569,7 +569,7 @@ class StreamingAccuracyTest(test.TestCase):
def testEffectivelyEquivalentSizes(self):
predictions = array_ops.ones((40, 1))
labels = array_ops.ones((40,))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accuracy, update_op = metrics.streaming_accuracy(predictions, labels)
sess.run(variables.local_variables_initializer())
@@ -583,7 +583,7 @@ class StreamingAccuracyTest(test.TestCase):
weights = array_ops.expand_dims(ops.convert_to_tensor([100, 1, 1]),
1) # shape 3, 1
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accuracy, update_op = metrics.streaming_accuracy(predictions, labels,
weights)
@@ -604,7 +604,7 @@ class StreamingAccuracyTest(test.TestCase):
dtype=dtypes_lib.int32, name='weights')
feed_dict = {weights_placeholder: weights}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accuracy, update_op = metrics.streaming_accuracy(predictions, labels,
weights_placeholder)
@@ -616,7 +616,7 @@ class StreamingAccuracyTest(test.TestCase):
self.assertGreater(accuracy.eval(feed_dict=feed_dict), .95)
def testMultipleUpdatesWithWeightedValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 1))
@@ -681,7 +681,7 @@ class StreamingTruePositivesTest(test.TestCase):
tp, tp_update_op = metrics.streaming_true_positives(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, tp.eval())
self.assertEqual(1, tp_update_op.eval())
@@ -698,7 +698,7 @@ class StreamingTruePositivesTest(test.TestCase):
tp, tp_update_op = metrics.streaming_true_positives(
predictions, labels, weights=37.0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, tp.eval())
self.assertEqual(37.0, tp_update_op.eval())
@@ -732,7 +732,7 @@ class StreamingFalseNegativesTest(test.TestCase):
fn, fn_update_op = metrics.streaming_false_negatives(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, fn.eval())
self.assertEqual(2, fn_update_op.eval())
@@ -749,7 +749,7 @@ class StreamingFalseNegativesTest(test.TestCase):
fn, fn_update_op = metrics.streaming_false_negatives(
predictions, labels, weights=((3.0,), (5.0,), (7.0,)))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, fn.eval())
self.assertEqual(8.0, fn_update_op.eval())
@@ -783,7 +783,7 @@ class StreamingFalsePositivesTest(test.TestCase):
fp, fp_update_op = metrics.streaming_false_positives(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, fp.eval())
self.assertEqual(4, fp_update_op.eval())
@@ -803,7 +803,7 @@ class StreamingFalsePositivesTest(test.TestCase):
weights=((1.0, 2.0, 3.0, 5.0), (7.0, 11.0, 13.0, 17.0), (19.0, 23.0,
29.0, 31.0)))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, fp.eval())
self.assertEqual(42.0, fp_update_op.eval())
@@ -837,7 +837,7 @@ class StreamingTrueNegativesTest(test.TestCase):
tn, tn_update_op = metrics.streaming_true_negatives(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, tn.eval())
self.assertEqual(5, tn_update_op.eval())
@@ -854,7 +854,7 @@ class StreamingTrueNegativesTest(test.TestCase):
tn, tn_update_op = metrics.streaming_true_negatives(
predictions, labels, weights=((0.0, 2.0, 3.0, 5.0),))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, tn.eval())
self.assertEqual(15.0, tn_update_op.eval())
@@ -879,7 +879,7 @@ class StreamingTruePositivesAtThresholdsTest(test.TestCase):
tp, tp_update_op = metrics.streaming_true_positives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), tp.eval())
self.assertAllEqual((3, 1, 0), tp_update_op.eval())
@@ -892,7 +892,7 @@ class StreamingTruePositivesAtThresholdsTest(test.TestCase):
tp, tp_update_op = metrics.streaming_true_positives_at_thresholds(
predictions, labels, weights=37.0, thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), tp.eval())
self.assertAllEqual((111.0, 37.0, 0.0), tp_update_op.eval())
@@ -921,7 +921,7 @@ class StreamingFalseNegativesAtThresholdsTest(test.TestCase):
fn, fn_update_op = metrics.streaming_false_negatives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), fn.eval())
self.assertAllEqual((0, 2, 3), fn_update_op.eval())
@@ -937,7 +937,7 @@ class StreamingFalseNegativesAtThresholdsTest(test.TestCase):
weights=((3.0,), (5.0,), (7.0,)),
thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), fn.eval())
self.assertAllEqual((0.0, 8.0, 11.0), fn_update_op.eval())
@@ -962,7 +962,7 @@ class StreamingFalsePositivesAtThresholdsTest(test.TestCase):
fp, fp_update_op = metrics.streaming_false_positives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), fp.eval())
self.assertAllEqual((7, 4, 2), fp_update_op.eval())
@@ -979,7 +979,7 @@ class StreamingFalsePositivesAtThresholdsTest(test.TestCase):
29.0, 31.0)),
thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), fp.eval())
self.assertAllEqual((125.0, 42.0, 12.0), fp_update_op.eval())
@@ -1004,7 +1004,7 @@ class StreamingTrueNegativesAtThresholdsTest(test.TestCase):
tn, tn_update_op = metrics.streaming_true_negatives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), tn.eval())
self.assertAllEqual((2, 5, 7), tn_update_op.eval())
@@ -1020,7 +1020,7 @@ class StreamingTrueNegativesAtThresholdsTest(test.TestCase):
weights=((0.0, 2.0, 3.0, 5.0),),
thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), tn.eval())
self.assertAllEqual((5.0, 15.0, 23.0), tn_update_op.eval())
@@ -1062,7 +1062,7 @@ class StreamingPrecisionTest(test.TestCase):
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
precision, update_op = metrics.streaming_precision(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1081,7 +1081,7 @@ class StreamingPrecisionTest(test.TestCase):
labels = constant_op.constant(inputs)
precision, update_op = metrics.streaming_precision(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1, sess.run(update_op))
self.assertAlmostEqual(1, precision.eval())
@@ -1091,7 +1091,7 @@ class StreamingPrecisionTest(test.TestCase):
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
precision, update_op = metrics.streaming_precision(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, precision.eval())
@@ -1102,7 +1102,7 @@ class StreamingPrecisionTest(test.TestCase):
precision, update_op = metrics.streaming_precision(
predictions, labels, weights=constant_op.constant([[2], [5]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 2.0 + 5.0
weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1120,7 +1120,7 @@ class StreamingPrecisionTest(test.TestCase):
precision, update_op = metrics.streaming_precision(
predictions, labels, weights=constant_op.constant([[2], [5]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 2.0 + 5.0
weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1138,7 +1138,7 @@ class StreamingPrecisionTest(test.TestCase):
labels,
weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 3.0 + 4.0
weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
@@ -1158,7 +1158,7 @@ class StreamingPrecisionTest(test.TestCase):
labels,
weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 3.0 + 4.0
weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
@@ -1175,7 +1175,7 @@ class StreamingPrecisionTest(test.TestCase):
labels = constant_op.constant(1 - inputs)
precision, update_op = metrics.streaming_precision(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertAlmostEqual(0, precision.eval())
@@ -1185,7 +1185,7 @@ class StreamingPrecisionTest(test.TestCase):
labels = constant_op.constant([0, 0, 0, 0])
precision, update_op = metrics.streaming_precision(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0.0, precision.eval())
@@ -1227,7 +1227,7 @@ class StreamingRecallTest(test.TestCase):
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
recall, update_op = metrics.streaming_recall(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1246,7 +1246,7 @@ class StreamingRecallTest(test.TestCase):
labels = constant_op.constant(np_inputs)
recall, update_op = metrics.streaming_recall(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(1, recall.eval())
@@ -1256,7 +1256,7 @@ class StreamingRecallTest(test.TestCase):
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
recall, update_op = metrics.streaming_recall(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, recall.eval())
@@ -1268,7 +1268,7 @@ class StreamingRecallTest(test.TestCase):
recall, update_op = metrics.streaming_recall(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_tp = 2.0 + 5.0
weighted_t = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1283,7 +1283,7 @@ class StreamingRecallTest(test.TestCase):
recall, update_op = metrics.streaming_recall(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_tp = 3.0 + 1.0
weighted_t = (2.0 + 3.0) + (4.0 + 1.0)
@@ -1298,7 +1298,7 @@ class StreamingRecallTest(test.TestCase):
labels = constant_op.constant(1 - np_inputs)
recall, update_op = metrics.streaming_recall(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, recall.eval())
@@ -1308,7 +1308,7 @@ class StreamingRecallTest(test.TestCase):
labels = array_ops.zeros((1, 4))
recall, update_op = metrics.streaming_recall(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, recall.eval())
@@ -1350,7 +1350,7 @@ class StreamingFPRTest(test.TestCase):
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1369,7 +1369,7 @@ class StreamingFPRTest(test.TestCase):
labels = constant_op.constant(np_inputs)
fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, fpr.eval())
@@ -1379,7 +1379,7 @@ class StreamingFPRTest(test.TestCase):
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, fpr.eval())
@@ -1391,7 +1391,7 @@ class StreamingFPRTest(test.TestCase):
fpr, update_op = metrics.streaming_false_positive_rate(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_fp = 2.0 + 5.0
weighted_f = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1406,7 +1406,7 @@ class StreamingFPRTest(test.TestCase):
fpr, update_op = metrics.streaming_false_positive_rate(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_fp = 1.0 + 3.0
weighted_f = (1.0 + 4.0) + (2.0 + 3.0)
@@ -1421,7 +1421,7 @@ class StreamingFPRTest(test.TestCase):
labels = constant_op.constant(1 - np_inputs)
fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(1, fpr.eval())
@@ -1431,7 +1431,7 @@ class StreamingFPRTest(test.TestCase):
labels = array_ops.ones((1, 4))
fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, fpr.eval())
@@ -1473,7 +1473,7 @@ class StreamingFNRTest(test.TestCase):
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1492,7 +1492,7 @@ class StreamingFNRTest(test.TestCase):
labels = constant_op.constant(np_inputs)
fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, fnr.eval())
@@ -1502,7 +1502,7 @@ class StreamingFNRTest(test.TestCase):
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, fnr.eval())
@@ -1514,7 +1514,7 @@ class StreamingFNRTest(test.TestCase):
fnr, update_op = metrics.streaming_false_negative_rate(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_fn = 2.0 + 5.0
weighted_t = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1529,7 +1529,7 @@ class StreamingFNRTest(test.TestCase):
fnr, update_op = metrics.streaming_false_negative_rate(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_fn = 2.0 + 4.0
weighted_t = (2.0 + 3.0) + (1.0 + 4.0)
@@ -1544,7 +1544,7 @@ class StreamingFNRTest(test.TestCase):
labels = constant_op.constant(1 - np_inputs)
fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(1, fnr.eval())
@@ -1554,7 +1554,7 @@ class StreamingFNRTest(test.TestCase):
labels = array_ops.zeros((1, 4))
fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, fnr.eval())
@@ -1599,7 +1599,7 @@ class StreamingCurvePointsTest(test.TestCase):
points, update_op = metric_ops.streaming_curve_points(
labels, predictions=predictions, curve=curve)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
@@ -1615,7 +1615,7 @@ class StreamingCurvePointsTest(test.TestCase):
self._testValueTensorIsIdempotent(curve='PR')
def _testCase(self, labels, predictions, curve, expected_points):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions_tensor = constant_op.constant(
predictions, dtype=dtypes_lib.float32)
labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.float32)
@@ -1717,7 +1717,7 @@ class StreamingAUCTest(test.TestCase):
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
auc, update_op = metrics.streaming_auc(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1730,7 +1730,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(initial_auc, auc.eval(), 5)
def testPredictionsOutOfRange(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, -1, 1, -1], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1744,7 +1744,7 @@ class StreamingAUCTest(test.TestCase):
def allCorrectAsExpected(self, curve):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
auc, update_op = metrics.streaming_auc(predictions, labels, curve=curve)
@@ -1755,7 +1755,7 @@ class StreamingAUCTest(test.TestCase):
self.assertEqual(1, auc.eval())
def testSomeCorrect(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1767,7 +1767,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0.5, auc.eval())
def testWeighted1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1781,7 +1781,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0.5, auc.eval(), 5)
def testWeighted2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1795,7 +1795,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0.7, auc.eval(), 5)
def testAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
@@ -1807,7 +1807,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3)
def testAnotherAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
shape=(1, 7),
@@ -1821,7 +1821,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3)
def testThirdAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
shape=(1, 7),
@@ -1837,7 +1837,7 @@ class StreamingAUCTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
auc, update_op = metrics.streaming_auc(predictions, labels)
@@ -1848,7 +1848,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0, auc.eval())
def testZeroTruePositivesAndFalseNegativesGivesOneAUC(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
labels = array_ops.zeros([4])
auc, update_op = metrics.streaming_auc(predictions, labels)
@@ -1859,7 +1859,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(1, auc.eval(), 6)
def testRecallOneAndPrecisionOneGivesOnePRAUC(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.ones([4], dtype=dtypes_lib.float32)
labels = array_ops.ones([4])
auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR')
@@ -1893,7 +1893,7 @@ class StreamingAUCTest(test.TestCase):
np.random.exponential(scale=1.0, size=num_samples)):
expected_auc = _np_auc(predictions, labels, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
enqueue_ops = [[] for i in range(num_batches)]
tf_predictions = _enqueue_as_batches(predictions, enqueue_ops)
tf_labels = _enqueue_as_batches(labels, enqueue_ops)
@@ -1966,7 +1966,7 @@ class StreamingDynamicAUCTest(test.TestCase):
labels = random_ops.random_uniform(
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
for _ in xrange(10):
@@ -1977,7 +1977,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertAlmostEqual(initial_auc, auc.eval(), 5)
def testAllLabelsOnes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1., 1., 1.])
labels = constant_op.constant([1, 1, 1])
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -1986,7 +1986,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertEqual(0, auc.eval())
def testAllLabelsZeros(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1., 1., 1.])
labels = constant_op.constant([0, 0, 0])
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -1995,7 +1995,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertEqual(0, auc.eval())
def testNonZeroOnePredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2.5, -2.5, 2.5, -2.5], dtype=dtypes_lib.float32)
labels = constant_op.constant([1, 0, 1, 0])
@@ -2006,7 +2006,7 @@ class StreamingDynamicAUCTest(test.TestCase):
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs)
labels = constant_op.constant(inputs)
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2015,7 +2015,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertEqual(1, auc.eval())
def testSomeCorrect(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0, 1, 0])
labels = constant_op.constant([0, 1, 1, 0])
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2025,7 +2025,7 @@ class StreamingDynamicAUCTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2034,7 +2034,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertAlmostEqual(0, auc.eval())
def testExceptionOnIncompatibleShapes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.ones([5])
labels = array_ops.zeros([6])
with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'):
@@ -2043,7 +2043,7 @@ class StreamingDynamicAUCTest(test.TestCase):
sess.run(update_op)
def testExceptionOnGreaterThanOneLabel(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32)
labels = constant_op.constant([2, 1, 0])
_, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2054,7 +2054,7 @@ class StreamingDynamicAUCTest(test.TestCase):
sess.run(update_op)
def testExceptionOnNegativeLabel(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32)
labels = constant_op.constant([1, 0, -1])
_, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2078,7 +2078,7 @@ class StreamingDynamicAUCTest(test.TestCase):
collections=[ops.GraphKeys.LOCAL_VARIABLES],
dtype=dtypes_lib.float32)
auc, update_op = metrics.streaming_dynamic_auc(tf_labels, tf_predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for _ in xrange(num_batches):
new_labels = np.random.randint(0, 2, size=batch_size)
@@ -2093,7 +2093,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertAlmostEqual(expected_auc, auc.eval())
def testAUCPRReverseIncreasingPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8], dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 0, 1, 1])
@@ -2104,7 +2104,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-5)
def testAUCPRJumbledPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81], dtypes_lib.float32)
labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1])
@@ -2115,7 +2115,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-6)
def testAUCPRPredictionsLessThanHalf(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
shape=(1, 7),
@@ -2148,7 +2148,7 @@ class StreamingDynamicAUCTest(test.TestCase):
auc, update_op = metrics.streaming_dynamic_auc(tf_labels,
tf_predictions,
weights=tf_weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for _ in xrange(num_batches):
new_labels = np.random.randint(0, 2, size=batch_size)
@@ -2196,7 +2196,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase):
expected_result: The expected result (dict) that maps to tensors.
weights: Optional weights tensor.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions_tensor = constant_op.constant(
predictions, dtype=dtypes_lib.float32)
labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.int64)
@@ -2320,7 +2320,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase):
dtype=dtypes_lib.float32)
auc, update_op = metrics.auc_with_confidence_intervals(tf_labels,
tf_predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for _ in xrange(num_batches):
new_labels = np.random.randint(0, 2, size=batch_size)
@@ -2335,7 +2335,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase):
self.assertAllClose(expected_auc, auc.auc.eval())
def testExceptionOnFloatLabels(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32)
labels = constant_op.constant([0.7, 0, 1, 0, 1])
_, update_op = metrics.auc_with_confidence_intervals(labels, predictions)
@@ -2343,7 +2343,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase):
      self.assertRaises(TypeError, sess.run, update_op)
def testExceptionOnGreaterThanOneLabel(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32)
labels = constant_op.constant([2, 1, 0, 1, 0])
_, update_op = metrics.auc_with_confidence_intervals(labels, predictions)
@@ -2354,7 +2354,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase):
sess.run(update_op)
def testExceptionOnNegativeLabel(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32)
labels = constant_op.constant([1, 0, -1, 1, 0])
_, update_op = metrics.auc_with_confidence_intervals(labels, predictions)
@@ -2415,7 +2415,7 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
result, update_op = metric_ops.precision_recall_at_equal_thresholds(
labels=labels, predictions=predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Run several updates.
sess.run(variables.local_variables_initializer())
for _ in range(3):
@@ -2448,7 +2448,7 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
default from assertAllClose.
weights: Optional weights tensor.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions_tensor = constant_op.constant(predictions, dtype=dtype)
labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool)
weights_tensor = None
@@ -2621,7 +2621,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, sensitivity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -2641,7 +2641,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, sensitivity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, specificity.eval())
@@ -2656,7 +2656,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, sensitivity=0.8)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1.0, sess.run(update_op))
self.assertAlmostEqual(1.0, specificity.eval())
@@ -2671,7 +2671,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, sensitivity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.6, sess.run(update_op))
@@ -2689,7 +2689,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, weights=weights, sensitivity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.6, sess.run(update_op))
@@ -2707,7 +2707,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, weights=weights, sensitivity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(8.0 / 15.0, sess.run(update_op))
@@ -2757,7 +2757,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase):
sensitivity, update_op = metrics.streaming_sensitivity_at_specificity(
predictions, labels, specificity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -2777,7 +2777,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.streaming_sensitivity_at_specificity(
predictions, labels, specificity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, specificity.eval())
@@ -2792,7 +2792,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.streaming_sensitivity_at_specificity(
predictions, labels, specificity=0.8)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.8, sess.run(update_op))
self.assertAlmostEqual(0.8, specificity.eval())
@@ -2807,7 +2807,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.streaming_sensitivity_at_specificity(
predictions, labels, specificity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.6, sess.run(update_op))
self.assertAlmostEqual(0.6, specificity.eval())
@@ -2824,7 +2824,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.streaming_sensitivity_at_specificity(
predictions, labels, weights=weights, specificity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.675, sess.run(update_op))
self.assertAlmostEqual(0.675, specificity.eval())
@@ -2887,7 +2887,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
rec, rec_op = metrics.streaming_recall_at_thresholds(
predictions, labels, thresholds)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -2905,7 +2905,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
thresholds = [0.5]
@@ -2921,7 +2921,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
self.assertEqual(1, rec.eval())
def testSomeCorrect(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -2940,7 +2940,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
thresholds = [0.5]
@@ -2956,7 +2956,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0, rec.eval())
def testWeights1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -2982,7 +2982,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, rec_high.eval(), places=5)
def testWeights2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3008,7 +3008,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, rec_high.eval(), places=5)
def testExtremeThresholds(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
@@ -3032,7 +3032,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, rec_high.eval())
def testZeroLabelsPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
labels = array_ops.zeros([4])
thresholds = [0.5]
@@ -3082,7 +3082,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
labels = labels.astype(np.float32)
predictions = predictions.astype(np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Reshape the data so it's easy to queue up:
predictions_batches = predictions.reshape((batch_size, num_batches))
labels_batches = labels.reshape((batch_size, num_batches))
@@ -3162,7 +3162,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds(
predictions, labels, thresholds)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3177,7 +3177,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
thresholds = [0.5]
@@ -3190,7 +3190,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
self.assertEqual(0, fpr.eval())
def testSomeCorrect(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -3206,7 +3206,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
thresholds = [0.5]
@@ -3219,7 +3219,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
self.assertAlmostEqual(1, fpr.eval())
def testWeights1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3239,7 +3239,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, fpr_high.eval(), places=5)
def testWeights2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3259,7 +3259,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, fpr_high.eval(), places=5)
def testExtremeThresholds(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
@@ -3277,7 +3277,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, fpr_high.eval(), places=5)
def testZeroLabelsPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
labels = array_ops.zeros([4])
thresholds = [0.5]
@@ -3317,7 +3317,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
labels = labels.astype(np.float32)
predictions = predictions.astype(np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Reshape the data so it's easy to queue up:
predictions_batches = predictions.reshape((batch_size, num_batches))
labels_batches = labels.reshape((batch_size, num_batches))
@@ -3393,7 +3393,7 @@ class RecallAtPrecisionTest(test.TestCase):
recall, update_op = metrics.recall_at_precision(
labels, predictions, precision=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3413,7 +3413,7 @@ class RecallAtPrecisionTest(test.TestCase):
recall, update_op = metrics.recall_at_precision(
labels, predictions, precision=1.0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, recall.eval())
@@ -3428,7 +3428,7 @@ class RecallAtPrecisionTest(test.TestCase):
recall, update_op = metrics.recall_at_precision(
labels, predictions, precision=0.8)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.8, sess.run(update_op))
self.assertAlmostEqual(0.8, recall.eval())
@@ -3443,7 +3443,7 @@ class RecallAtPrecisionTest(test.TestCase):
recall, update_op = metrics.recall_at_precision(
labels, predictions, precision=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
target_recall = 2.0 / 3.0
self.assertAlmostEqual(target_recall, sess.run(update_op))
@@ -3461,12 +3461,66 @@ class RecallAtPrecisionTest(test.TestCase):
recall, update_op = metrics.recall_at_precision(
labels, predictions, weights=weights, precision=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
target_recall = 2.0 / 3.0
self.assertAlmostEqual(target_recall, sess.run(update_op))
self.assertAlmostEqual(target_recall, recall.eval())
+ def _test_strict_mode(self, strict_mode, target_precision, expected_recall):
+ num_thresholds = 11
+ predictions_values = [.2, .3, .5, .6, .7, .8, .9, .9, .9, .1]
+ labels_values = [1, 1, 0, 0, 0, 0, 0, 0, 0, 1]
+ # Resulting thresholds and the corresponding precision and recall values at
+ # each threshold:
+ # Thresholds [0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
+ # precisions: [0.3 0.2 0.1 0 0 0 0 0 0]
+ # recalls: [1.0 0.7 0.3 0 0 0 0 0 0]
+ predictions = constant_op.constant(
+ predictions_values, dtype=dtypes_lib.float32)
+ labels = constant_op.constant(labels_values)
+ recall, update_op = metrics.recall_at_precision(
+ labels,
+ predictions,
+ num_thresholds=num_thresholds,
+ precision=target_precision,
+ strict_mode=strict_mode)
+
+ with self.cached_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ self.assertAlmostEqual(expected_recall, sess.run(update_op))
+ self.assertAlmostEqual(expected_recall, recall.eval())
+
+ def testStrictMode_Off(self):
+ # strict_mode is turned off, so we return the recall at the threshold where
+ # the precision (0.3) is closest to the target precision (0.9). The recall
+ # corresponding to that threshold is 1.0.
+ self._test_strict_mode(
+ strict_mode=False, target_precision=0.9, expected_recall=1.0)
+
+ def testStrictMode_OnAndFail(self):
+ # strict_mode is turned on and we fail to reach the target precision at any
+ # threshold.
+ # Target precision: 0.9
+ # Diff: [-0.6 -0.7 -0.8 -0.9 -0.9 -0.9 -0.9 -0.9 -0.9]
+ # Reciprocal: [-1.6 -1.4 -1.3 -1.1 -1.1 -1.1 -1.1 -1.1 -1.1]
+ # Max index: 3, and the corresponding precision is 0, which is smaller than
+ # the target precision 0.9. As a result, the expected recall is 0.
+ self._test_strict_mode(
+ strict_mode=True, target_precision=0.9, expected_recall=.0)
+
+ def testStrictMode_OnAndSucceed(self):
+ # strict_mode is on and we can reach the target precision at a certain
+ # threshold.
+ # Target precision: 0.2
+ # Diff: [0.1 0 -0.1 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2]
+ # Reciprocal: [10 infty -10.0 -5.0 -5.0 -5.0 -5.0 -5.0 -5.0]
+ # Max index: 1, and the corresponding precision is 0.2, which is no smaller
+ # than the target precision 0.2. In this case, we return the recall at
+ # index 1, which is 2.0/3 (listed as 0.7 above).
+ self._test_strict_mode(
+ strict_mode=True, target_precision=0.2, expected_recall=2.0 / 3)
+
class PrecisionAtRecallTest(test.TestCase):
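The strict-mode comments above encode a small selection rule; here is a NumPy sketch of it (the function name and the non-strict tie-breaking are illustrative reconstructions; the real metric derives `precisions`/`recalls` from streamed confusion-matrix counts):

import numpy as np

def recall_at_precision_sketch(precisions, recalls, target, strict_mode):
  if not strict_mode:
    # Non-strict: take the threshold whose precision is closest to target.
    idx = int(np.argmin(np.abs(precisions - target)))
    return float(recalls[idx])
  # Strict: 1/(precision - target) is maximal for the smallest non-negative
  # difference, i.e. the precision closest to the target from above.
  with np.errstate(divide="ignore"):
    idx = int(np.argmax(1.0 / (precisions - target)))
  if precisions[idx] < target:
    return 0.0  # the target precision is never reached
  return float(recalls[idx])

precisions = np.array([0.3, 0.2, 0.1, 0., 0., 0., 0., 0., 0.])
recalls = np.array([1.0, 2.0 / 3, 0.3, 0., 0., 0., 0., 0., 0.])
print(recall_at_precision_sketch(precisions, recalls, 0.9, False))  # 1.0
print(recall_at_precision_sketch(precisions, recalls, 0.9, True))   # 0.0
print(recall_at_precision_sketch(precisions, recalls, 0.2, True))   # ~0.667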
@@ -3511,7 +3565,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3531,7 +3585,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, precision.eval())
@@ -3545,7 +3599,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(sess.run(label_prior), sess.run(update_op))
self.assertEqual(sess.run(label_prior), precision.eval())
@@ -3560,7 +3614,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.8)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.8, sess.run(update_op))
self.assertAlmostEqual(0.8, precision.eval())
@@ -3575,7 +3629,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(2.0/3, sess.run(update_op))
self.assertAlmostEqual(2.0/3, precision.eval())
@@ -3594,7 +3648,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.8, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(34.0/43, sess.run(update_op))
self.assertAlmostEqual(34.0/43, precision.eval())
@@ -3643,7 +3697,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds(
predictions, labels, thresholds)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3658,7 +3712,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
thresholds = [0.5]
@@ -3671,7 +3725,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
self.assertEqual(0, fnr.eval())
def testSomeCorrect(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -3687,7 +3741,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
thresholds = [0.5]
@@ -3700,7 +3754,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
self.assertAlmostEqual(1, fnr.eval())
def testWeights1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3720,7 +3774,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
self.assertAlmostEqual(1.0, fnr_high.eval(), places=5)
def testWeights2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3740,7 +3794,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
self.assertAlmostEqual(1.0, fnr_high.eval(), places=5)
def testExtremeThresholds(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
@@ -3758,7 +3812,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
self.assertAlmostEqual(1.0, fnr_high.eval())
def testZeroLabelsPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
labels = array_ops.zeros([4])
thresholds = [0.5]
@@ -3798,7 +3852,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
labels = labels.astype(np.float32)
predictions = predictions.astype(np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Reshape the data so it's easy to queue up:
predictions_batches = predictions.reshape((batch_size, num_batches))
labels_batches = labels.reshape((batch_size, num_batches))
@@ -3886,7 +3940,7 @@ class StreamingRecallAtKTest(test.TestCase):
sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k(
predictions, array_ops.reshape(labels, (self._batch_size, 1)), k=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0.25, sess.run(update_op))
self.assertEqual(0.25, recall.eval())
@@ -3904,7 +3958,7 @@ class StreamingRecallAtKTest(test.TestCase):
sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k(
predictions, array_ops.reshape(labels, (self._batch_size, 1)), k=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0.5, sess.run(update_op))
self.assertEqual(0.5, recall.eval())
@@ -3922,7 +3976,7 @@ class StreamingRecallAtKTest(test.TestCase):
sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k(
predictions, array_ops.reshape(labels, (self._batch_size, 1)), k=3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1.0, sess.run(update_op))
self.assertEqual(1.0, recall.eval())
@@ -3946,7 +4000,7 @@ class StreamingRecallAtKTest(test.TestCase):
k=2,
weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1.0, sess.run(update_op))
self.assertEqual(1.0, recall.eval())
@@ -3963,7 +4017,7 @@ class StreamingSparsePrecisionTest(test.TestCase):
expected,
class_id=None,
weights=None):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
if weights is not None:
weights = constant_op.constant(weights, dtypes_lib.float32)
metric, update = metrics.streaming_sparse_precision_at_k(
@@ -3992,7 +4046,7 @@ class StreamingSparsePrecisionTest(test.TestCase):
expected,
class_id=None,
weights=None):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
if weights is not None:
weights = constant_op.constant(weights, dtypes_lib.float32)
metric, update = metrics.streaming_sparse_precision_at_top_k(
@@ -4021,7 +4075,7 @@ class StreamingSparsePrecisionTest(test.TestCase):
k,
expected,
weights=None):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
if weights is not None:
weights = constant_op.constant(weights, dtypes_lib.float32)
predictions = constant_op.constant(predictions, dtypes_lib.float32)
@@ -4047,7 +4101,7 @@ class StreamingSparsePrecisionTest(test.TestCase):
labels,
expected,
weights=None):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
if weights is not None:
weights = constant_op.constant(weights, dtypes_lib.float32)
metric, update = metrics.streaming_sparse_average_precision_at_top_k(
@@ -4068,7 +4122,7 @@ class StreamingSparsePrecisionTest(test.TestCase):
self.assertAlmostEqual(expected, metric.eval())
def test_top_k_rank_invalid(self):
- with self.test_session():
+ with self.cached_session():
# top_k_predictions has rank < 2.
top_k_predictions = [9, 4, 6, 2, 0]
sp_labels = sparse_tensor.SparseTensorValue(
@@ -4615,7 +4669,7 @@ class StreamingSparsePrecisionTest(test.TestCase):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
labels = [[0, 0, 0, 1], [0, 0, 1, 0]]
expected_precision = 0.5
- with self.test_session():
+ with self.cached_session():
_, precision = metrics.streaming_sparse_precision_at_k(
predictions=constant_op.constant(predictions, dtypes_lib.float32),
labels=_binary_2d_label_to_sparse_value(labels),
@@ -4635,7 +4689,7 @@ class StreamingSparseRecallTest(test.TestCase):
expected,
class_id=None,
weights=None):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
if weights is not None:
weights = constant_op.constant(weights, dtypes_lib.float32)
metric, update = metrics.streaming_sparse_recall_at_k(
@@ -4664,7 +4718,7 @@ class StreamingSparseRecallTest(test.TestCase):
expected,
class_id=None,
weights=None):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
if weights is not None:
weights = constant_op.constant(weights, dtypes_lib.float32)
metric, update = metric_ops.sparse_recall_at_top_k(
@@ -5320,7 +5374,7 @@ class StreamingSparseRecallTest(test.TestCase):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
labels = [[0, 0, 1, 0], [0, 0, 0, 1]]
expected_recall = 0.5
- with self.test_session():
+ with self.cached_session():
_, recall = metrics.streaming_sparse_recall_at_k(
predictions=constant_op.constant(predictions, dtypes_lib.float32),
labels=_binary_2d_label_to_sparse_value(labels),
@@ -5364,7 +5418,7 @@ class StreamingMeanAbsoluteErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_absolute_error(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5386,7 +5440,7 @@ class StreamingMeanAbsoluteErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_absolute_error(
predictions, labels, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(3, sess.run(update_op))
self.assertEqual(3, error.eval())
@@ -5430,7 +5484,7 @@ class StreamingMeanRelativeErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_relative_error(
predictions, labels, normalizer)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5455,7 +5509,7 @@ class StreamingMeanRelativeErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_relative_error(
predictions, labels, normalizer=labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(expected_error, sess.run(update_op))
self.assertEqual(expected_error, error.eval())
@@ -5471,7 +5525,7 @@ class StreamingMeanRelativeErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_relative_error(
predictions, labels, normalizer=array_ops.zeros_like(labels))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0.0, sess.run(update_op))
self.assertEqual(0.0, error.eval())
@@ -5509,7 +5563,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
labels = random_ops.random_normal((10, 3), seed=2)
error, update_op = metrics.streaming_mean_squared_error(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5527,7 +5581,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_squared_error(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, error.eval())
@@ -5540,7 +5594,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_squared_error(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(6, sess.run(update_op))
self.assertEqual(6, error.eval())
@@ -5555,13 +5609,13 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_squared_error(
predictions, labels, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(13, sess.run(update_op))
self.assertEqual(13, error.eval())
def testMultipleBatchesOfSizeOne(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -5586,7 +5640,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
self.assertAlmostEqual(208.0 / 6, error.eval(), 5)
def testMetricsComputedConcurrently(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates one set of predictions.
preds_queue0 = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -5629,7 +5683,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
self.assertAlmostEqual(79.0 / 6, mse1, 5)
def testMultipleMetricsOnMultipleBatchesOfSizeOne(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -5691,7 +5745,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
error, update_op = metrics.streaming_root_mean_squared_error(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5704,7 +5758,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
self.assertEqual(initial_error, error.eval())
def testSingleUpdateZeroError(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
0.0, shape=(1, 3), dtype=dtypes_lib.float32)
labels = constant_op.constant(0.0, shape=(1, 3), dtype=dtypes_lib.float32)
@@ -5718,7 +5772,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
self.assertEqual(0, rmse.eval())
def testSingleUpdateWithError(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -5732,7 +5786,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
self.assertAlmostEqual(math.sqrt(6), rmse.eval(), 5)
def testSingleUpdateWithErrorAndWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -5788,7 +5842,7 @@ class StreamingCovarianceTest(test.TestCase):
predictions = labels * 0.5 + random_ops.random_normal((10, 3), seed=1) * 0.5
cov, update_op = metrics.streaming_covariance(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5801,7 +5855,7 @@ class StreamingCovarianceTest(test.TestCase):
self.assertEqual(initial_cov, cov.eval())
def testSingleUpdateIdentical(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = math_ops.to_float(math_ops.range(10))
labels = math_ops.to_float(math_ops.range(10))
@@ -5813,7 +5867,7 @@ class StreamingCovarianceTest(test.TestCase):
self.assertAlmostEqual(expected_cov, cov.eval(), 5)
def testSingleUpdateNonIdentical(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -5827,7 +5881,7 @@ class StreamingCovarianceTest(test.TestCase):
self.assertAlmostEqual(expected_cov, cov.eval())
def testSingleUpdateWithErrorAndWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -5845,7 +5899,7 @@ class StreamingCovarianceTest(test.TestCase):
self.assertAlmostEqual(expected_cov, cov.eval())
def testMultiUpdateWithErrorNoWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
np.random.seed(123)
n = 100
predictions = np.random.randn(n)
@@ -5879,7 +5933,7 @@ class StreamingCovarianceTest(test.TestCase):
prev_expected_cov = expected_cov
def testMultiUpdateWithErrorAndWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
np.random.seed(123)
n = 100
predictions = np.random.randn(n)
@@ -5969,7 +6023,7 @@ class StreamingPearsonRTest(test.TestCase):
pearson_r, update_op = metrics.streaming_pearson_correlation(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5982,7 +6036,7 @@ class StreamingPearsonRTest(test.TestCase):
self.assertEqual(initial_r, pearson_r.eval())
def testSingleUpdateIdentical(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = math_ops.to_float(math_ops.range(10))
labels = math_ops.to_float(math_ops.range(10))
@@ -5995,7 +6049,7 @@ class StreamingPearsonRTest(test.TestCase):
self.assertAlmostEqual(expected_r, pearson_r.eval(), 5)
def testSingleUpdateNonIdentical(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -6010,7 +6064,7 @@ class StreamingPearsonRTest(test.TestCase):
self.assertAlmostEqual(expected_r, pearson_r.eval())
def testSingleUpdateWithErrorAndWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = np.array([2, 4, 6, 8])
labels = np.array([1, 3, 2, 7])
weights = np.array([0, 1, 3, 1])
@@ -6031,7 +6085,7 @@ class StreamingPearsonRTest(test.TestCase):
self.assertAlmostEqual(expected_r, pearson_r.eval())
def testMultiUpdateWithErrorNoWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
np.random.seed(123)
n = 100
predictions = np.random.randn(n)
@@ -6066,7 +6120,7 @@ class StreamingPearsonRTest(test.TestCase):
prev_expected_r = expected_r
def testMultiUpdateWithErrorAndWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
np.random.seed(123)
n = 100
predictions = np.random.randn(n)
@@ -6108,7 +6162,7 @@ class StreamingPearsonRTest(test.TestCase):
prev_expected_r = expected_r
def testMultiUpdateWithErrorAndSingletonBatches(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
np.random.seed(123)
n = 100
predictions = np.random.randn(n)
@@ -6189,7 +6243,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -6212,7 +6266,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, error.eval())
@@ -6229,7 +6283,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1, sess.run(update_op), 5)
self.assertAlmostEqual(1, error.eval(), 5)
@@ -6251,7 +6305,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1.0, sess.run(update_op), 5)
self.assertAlmostEqual(1.0, error.eval(), 5)
@@ -6270,7 +6324,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=2, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, error.eval())
@@ -6289,7 +6343,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=2, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1.5, update_op.eval())
self.assertEqual(1.5, error.eval())
@@ -6324,7 +6378,7 @@ class PcntBelowThreshTest(test.TestCase):
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def testOneUpdate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = constant_op.constant(
[2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
@@ -6344,7 +6398,7 @@ class PcntBelowThreshTest(test.TestCase):
self.assertAlmostEqual(0.0, pcnt2, 5)
def testSomePresentOneUpdate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = constant_op.constant(
[2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
weights = constant_op.constant(
@@ -6421,7 +6475,7 @@ class StreamingMeanIOUTest(test.TestCase):
miou, update_op = metrics.streaming_mean_iou(
predictions, labels, num_classes=num_classes)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -6435,7 +6489,7 @@ class StreamingMeanIOUTest(test.TestCase):
def testMultipleUpdates(self):
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
5, dtypes=dtypes_lib.int32, shapes=(1, 1))
@@ -6467,7 +6521,7 @@ class StreamingMeanIOUTest(test.TestCase):
def testMultipleUpdatesWithWeights(self):
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
6, dtypes=dtypes_lib.int32, shapes=(1, 1))
@@ -6515,7 +6569,7 @@ class StreamingMeanIOUTest(test.TestCase):
# one class, and thus there is one row and one column with
# zero entries in the confusion matrix.
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
# There is no prediction for class 2.
preds_queue = data_flow_ops.FIFOQueue(
@@ -6557,7 +6611,7 @@ class StreamingMeanIOUTest(test.TestCase):
constant_op.constant(1, shape=[7])
], 0)
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6570,7 +6624,7 @@ class StreamingMeanIOUTest(test.TestCase):
predictions = array_ops.zeros([40])
labels = array_ops.zeros([40])
num_classes = 1
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6581,7 +6635,7 @@ class StreamingMeanIOUTest(test.TestCase):
predictions = array_ops.zeros([40])
labels = array_ops.ones([40])
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6603,7 +6657,7 @@ class StreamingMeanIOUTest(test.TestCase):
constant_op.constant(1, shape=[8]),
constant_op.constant(0, shape=[1])
], 0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(
predictions, labels, num_classes, weights=weights)
sess.run(variables.local_variables_initializer())
@@ -6618,7 +6672,7 @@ class StreamingMeanIOUTest(test.TestCase):
[[[0, 0, 2, 1, 1, 0], [0, 1, 2, 2, 0, 1]], [[0, 0, 2, 1, 1, 1],
[1, 1, 2, 0, 0, 0]]])
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6630,7 +6684,7 @@ class StreamingMeanIOUTest(test.TestCase):
labels = constant_op.constant([0])
predictions = constant_op.constant([0])
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6644,7 +6698,7 @@ class StreamingMeanIOUTest(test.TestCase):
[[[0, 0, 1, 1, 0, 0], [1, 1, 0, 0, 1, 1]], [[0, 0, 0, 1, 1, 1],
[1, 1, 1, 0, 0, 0]]])
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6679,7 +6733,7 @@ class StreamingConcatTest(test.TestCase):
def testNextArraySize(self):
next_array_size = metric_ops._next_array_size # pylint: disable=protected-access
- with self.test_session():
+ with self.cached_session():
self.assertEqual(next_array_size(2, growth_factor=2).eval(), 2)
self.assertEqual(next_array_size(3, growth_factor=2).eval(), 4)
self.assertEqual(next_array_size(4, growth_factor=2).eval(), 4)
@@ -6687,7 +6741,7 @@ class StreamingConcatTest(test.TestCase):
self.assertEqual(next_array_size(6, growth_factor=2).eval(), 8)
def testStreamingConcat(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = array_ops.placeholder(dtypes_lib.int32, [None])
concatenated, update_op = metrics.streaming_concat(values)
sess.run(variables.local_variables_initializer())
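The assertions above pin down the growth rule of _next_array_size: the buffer grows to the smallest power of growth_factor that holds the required size. An illustrative reconstruction (not the library's implementation, which computes this with TF ops):

import math

def next_array_size_sketch(required_size, growth_factor=2):
  # Smallest power of growth_factor that is >= required_size.
  exponent = math.ceil(math.log(required_size) / math.log(growth_factor))
  return int(growth_factor ** exponent)

assert next_array_size_sketch(2) == 2
assert next_array_size_sketch(3) == 4
assert next_array_size_sketch(4) == 4
assert next_array_size_sketch(6) == 8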
@@ -6704,7 +6758,7 @@ class StreamingConcatTest(test.TestCase):
self.assertAllEqual(np.arange(10), concatenated.eval())
def testStreamingConcatStringValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = array_ops.placeholder(dtypes_lib.string, [None])
concatenated, update_op = metrics.streaming_concat(values)
sess.run(variables.local_variables_initializer())
@@ -6723,7 +6777,7 @@ class StreamingConcatTest(test.TestCase):
concatenated.eval())
def testStreamingConcatMaxSize(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = math_ops.range(3)
concatenated, update_op = metrics.streaming_concat(values, max_size=5)
sess.run(variables.local_variables_initializer())
@@ -6740,7 +6794,7 @@ class StreamingConcatTest(test.TestCase):
self.assertAllEqual([0, 1, 2, 0, 1], concatenated.eval())
def testStreamingConcat2D(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = array_ops.reshape(math_ops.range(3), (3, 1))
concatenated, update_op = metrics.streaming_concat(values, axis=-1)
sess.run(variables.local_variables_initializer())
@@ -6763,7 +6817,7 @@ class StreamingConcatTest(test.TestCase):
array_ops.placeholder(dtypes_lib.float32, [None, None]))
def testStreamingConcatReset(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = array_ops.placeholder(dtypes_lib.int32, [None])
concatenated, update_op = metrics.streaming_concat(values)
sess.run(variables.local_variables_initializer())
@@ -6791,7 +6845,7 @@ class AggregateMetricsTest(test.TestCase):
metrics.streaming_mean(values))
self.assertEqual(len(value_tensors), 1)
self.assertEqual(len(update_ops), 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, update_ops[0].eval())
self.assertEqual(1, value_tensors[0].eval())
@@ -6804,7 +6858,7 @@ class AggregateMetricsTest(test.TestCase):
metrics.streaming_mean_squared_error(predictions, labels))
self.assertEqual(len(value_tensors), 2)
self.assertEqual(len(update_ops), 2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(2, update_ops[0].eval())
self.assertEqual(4, update_ops[1].eval())
@@ -6825,7 +6879,7 @@ class AggregateMetricMapTest(test.TestCase):
self.assertEqual(2, len(names_to_values))
self.assertEqual(2, len(names_to_updates))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(2, names_to_updates['m1'].eval())
self.assertEqual(4, names_to_updates['m2'].eval())
@@ -6860,7 +6914,7 @@ class CountTest(test.TestCase):
self.assertTrue(isinstance(op, ops.Operation) or isinstance(op, ops.Tensor))
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -6877,7 +6931,7 @@ class CountTest(test.TestCase):
self.assertAlmostEqual(8.0, sess.run(result), 5)
def testUpdateOpsReturnsCurrentValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -6898,7 +6952,7 @@ class CountTest(test.TestCase):
self.assertAlmostEqual(8.0, sess.run(result), 5)
def test1dWeightedValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -6925,7 +6979,7 @@ class CountTest(test.TestCase):
self.assertAlmostEqual(3.4, result.eval(), 5)
def test1dWeightedValues_placeholders(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
values = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -6947,7 +7001,7 @@ class CountTest(test.TestCase):
self.assertAlmostEqual(3.4, result.eval(), 5)
def test2dWeightedValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -6974,7 +7028,7 @@ class CountTest(test.TestCase):
self.assertAlmostEqual(4.1, result.eval(), 5)
def test2dWeightedValues_placeholders(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
values = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -7047,7 +7101,7 @@ class CohenKappaTest(test.TestCase):
(10, 1), maxval=3, dtype=dtypes_lib.int64, seed=2)
kappa, update_op = metrics.cohen_kappa(labels, predictions, 3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -7081,7 +7135,7 @@ class CohenKappaTest(test.TestCase):
for dtype in dtypes:
for shape in shapes:
for weight in weights:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions_tensor = constant_op.constant(
np.reshape(predictions, shape), dtype=dtype)
labels_tensor = constant_op.constant(
@@ -7102,7 +7156,7 @@ class CohenKappaTest(test.TestCase):
# Calculated by v0.19: sklearn.metrics.cohen_kappa_score(inputs, inputs)
expect = 1.0
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
kappa, update_op = metrics.cohen_kappa(labels, predictions, 4)
@@ -7121,7 +7175,7 @@ class CohenKappaTest(test.TestCase):
# Calculated by v0.19: sklearn.metrics.cohen_kappa_score(labels, predictions)
expect = -0.333333333333
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32)
labels = constant_op.constant(labels)
kappa, update_op = metrics.cohen_kappa(labels, predictions, 4)
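The `expect` constants in these kappa tests are quoted from scikit-learn, as the comments note; a sketch of how such a constant can be reproduced (the arrays below are illustrative stand-ins, not the tests' actual inputs):

import numpy as np
from sklearn.metrics import cohen_kappa_score

labels = np.array([0, 1, 2, 3, 0, 1])
predictions = np.array([0, 2, 2, 3, 0, 1])
print(cohen_kappa_score(labels, predictions))
# Weighted variant, matching the sample_weight call quoted in the comment:
weights = np.array([1.0, 1.0, 2.0, 1.0, 1.0, 1.0])
print(cohen_kappa_score(labels, predictions, sample_weight=weights))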
@@ -7139,7 +7193,7 @@ class CohenKappaTest(test.TestCase):
# labels, predictions, sample_weight=weights)
expect = 0.453466583385
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32)
labels = constant_op.constant(labels)
kappa, update_op = metrics.cohen_kappa(
@@ -7164,7 +7218,7 @@ class CohenKappaTest(test.TestCase):
weights_t = array_ops.placeholder(dtypes_lib.float32, shape=(batch_size,))
kappa, update_op = metrics.cohen_kappa(
labels_t, predictions_t, num_classes, weights=weights_t)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for idx in range(0, num_samples, batch_size):
@@ -7202,7 +7256,7 @@ class CohenKappaTest(test.TestCase):
def testConditionalPackingOptimization(self):
placeholder = array_ops.placeholder(dtypes_lib.float32, [None])
values, update_op = metric_ops.streaming_concat(placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for feed in range(10):
sess.run(update_op, feed_dict={placeholder: [feed]})
diff --git a/tensorflow/contrib/model_pruning/BUILD b/tensorflow/contrib/model_pruning/BUILD
index 16ddc38f5a..3cffd76a25 100644
--- a/tensorflow/contrib/model_pruning/BUILD
+++ b/tensorflow/contrib/model_pruning/BUILD
@@ -113,12 +113,13 @@ py_library(
py_test(
name = "pruning_utils_test",
- size = "small",
+ size = "medium",
srcs = ["python/pruning_utils_test.py"],
srcs_version = "PY2AND3",
deps = [
":pruning_utils",
"//tensorflow/python:client_testlib",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md
index a5267fd904..15d95896d9 100644
--- a/tensorflow/contrib/model_pruning/README.md
+++ b/tensorflow/contrib/model_pruning/README.md
@@ -53,7 +53,7 @@ The pruning library allows for specification of the following hyper parameters:
| weight_sparsity_map | list of strings | [""] | list of weight variable name (or layer name):target sparsity pairs, e.g. [conv1:0.9,conv2/kernel:0.8]. For layers/weights not in this list, the sparsity specified by the target_sparsity hyperparameter is used. |
| threshold_decay | float | 0.9 | The decay factor to use for exponential decay of the thresholds |
| pruning_frequency | integer | 10 | How often should the masks be updated? (in # of global_steps) |
-| nbins | integer | 256 | Number of bins to use for histogram computation |
+| nbins | integer | 256 | Number of bins to use for histogram computation. Note: When running on TPUs, a large (>1024) value for `nbins` may adversely affect the training time. |
| block_height|integer | 1 | Number of rows in a block for block sparse matrices|
| block_width |integer | 1 | Number of cols in a block for block sparse matrices|
| block_pooling_function| string | AVG | The function to use to pool weight values in a block: average (AVG) or max (MAX)|
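A sketch of overriding nbins (or any other hyperparameter in this table) through the package's hparams string; the call shape assumes the package's get_pruning_hparams()/Pruning API and the values are illustrative:

from tensorflow.contrib.model_pruning.python import pruning

# Parse an hparams override string; keys left out keep the defaults
# listed in the table above.
hparams = pruning.get_pruning_hparams().parse(
    "name=example_pruning,target_sparsity=0.9,pruning_frequency=10,nbins=512")
p = pruning.Pruning(spec=hparams)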
diff --git a/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py b/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py
index e85ae7b22a..586c6c7bfc 100644
--- a/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py
+++ b/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py
@@ -37,7 +37,7 @@ class RnnCellsTest(test.TestCase):
expected_num_masks = 1
expected_num_rows = 2 * self.dim
expected_num_cols = 4 * self.dim
- with self.test_session():
+ with self.cached_session():
inputs = variables.Variable(
random_ops.random_normal([self.batch_size, self.dim]))
c = variables.Variable(
@@ -61,7 +61,7 @@ class RnnCellsTest(test.TestCase):
expected_num_masks = 1
expected_num_rows = 2 * self.dim
expected_num_cols = 4 * self.dim
- with self.test_session():
+ with self.cached_session():
inputs = variables.Variable(
random_ops.random_normal([self.batch_size, self.dim]))
c = variables.Variable(
diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py
index cd58526ed3..a81abac2fa 100644
--- a/tensorflow/contrib/model_pruning/python/pruning.py
+++ b/tensorflow/contrib/model_pruning/python/pruning.py
@@ -476,8 +476,8 @@ class Pruning(object):
smoothed_threshold, new_mask = self._update_mask(pooled_weights,
threshold)
- updated_mask = pruning_utils.kronecker_product(
- new_mask, array_ops.ones(self._block_dim))
+
+ updated_mask = pruning_utils.expand_tensor(new_mask, self._block_dim)
sliced_mask = array_ops.slice(
updated_mask, [0, 0],
[squeezed_weights.get_shape()[0],
diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py
index 33c4ad58bd..cd3d8e76bb 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_test.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_test.py
@@ -61,14 +61,14 @@ class PruningHParamsTest(test.TestCase):
self.assertEqual(p._weight_sparsity_map["conv2/kernel"], 0.8)
def testInitWithExternalSparsity(self):
- with self.test_session():
+ with self.cached_session():
p = pruning.Pruning(spec=self.pruning_hparams, sparsity=self.sparsity)
variables.global_variables_initializer().run()
sparsity = p._sparsity.eval()
self.assertAlmostEqual(sparsity, 0.5)
def testInitWithVariableReuse(self):
- with self.test_session():
+ with self.cached_session():
p = pruning.Pruning(spec=self.pruning_hparams, sparsity=self.sparsity)
p_copy = pruning.Pruning(
spec=self.pruning_hparams, sparsity=self.sparsity)
@@ -87,7 +87,7 @@ class PruningTest(test.TestCase):
def testCreateMask2D(self):
width = 10
height = 20
- with self.test_session():
+ with self.cached_session():
weights = variables.Variable(
random_ops.random_normal([width, height], stddev=1), name="weights")
masked_weights = pruning.apply_mask(weights,
@@ -98,7 +98,7 @@ class PruningTest(test.TestCase):
self.assertAllEqual(weights_val, masked_weights_val)
def testUpdateSingleMask(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
weights = variables.Variable(
math_ops.linspace(1.0, 100.0, 100), name="weights")
masked_weights = pruning.apply_mask(weights)
@@ -122,7 +122,7 @@ class PruningTest(test.TestCase):
# Set up pruning
p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
_, new_mask = p._maybe_update_block_mask(weights, threshold)
# Check if the mask is the same size as the weights
@@ -167,7 +167,7 @@ class PruningTest(test.TestCase):
def testPartitionedVariableMasking(self):
partitioner = partitioned_variables.variable_axis_size_partitioner(40)
- with self.test_session() as session:
+ with self.cached_session() as session:
with variable_scope.variable_scope("", partitioner=partitioner):
sparsity = variables.Variable(0.5, name="Sparsity")
weights = variable_scope.get_variable(
@@ -201,7 +201,7 @@ class PruningTest(test.TestCase):
sparsity_val = math_ops.linspace(0.0, 0.9, 10)
increment_global_step = state_ops.assign_add(self.global_step, 1)
non_zero_count = []
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
for i in range(10):
session.run(state_ops.assign(sparsity, sparsity_val[i]))
@@ -234,7 +234,7 @@ class PruningTest(test.TestCase):
mask_update_op = p.conditional_mask_update_op()
increment_global_step = state_ops.assign_add(self.global_step, 1)
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
for _ in range(110):
session.run(mask_update_op)
diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils.py b/tensorflow/contrib/model_pruning/python/pruning_utils.py
index ef6c6a3f5d..91b0bb7f60 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_utils.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_utils.py
@@ -69,7 +69,7 @@ def weight_threshold_variable(var, scope):
scope: The variable scope of the variable var
Returns:
- a scalar threshold variable initialized to 0.
+ A scalar threshold variable initialized to 0.
"""
with variable_scope.variable_scope(scope):
threshold = variable_scope.get_variable(
@@ -97,6 +97,74 @@ def kronecker_product(mat1, mat2):
return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])
+def expand_tensor(tensor, block_dims):
+ """Expands a 2D tensor by replicating the tensor values.
+
+ This is equivalent to the Kronecker product of the tensor and a matrix of
+ ones of size block_dims.
+
+ Example:
+
+ tensor = [[1,2]
+ [3,4]]
+ block_dims = [2,2]
+
+ result = [[1 1 2 2]
+ [1 1 2 2]
+ [3 3 4 4]
+ [3 3 4 4]]
+
+ Args:
+ tensor: A 2D tensor that needs to be expanded.
+ block_dims: List of integers specifying the expansion factor.
+
+ Returns:
+ The expanded tensor.
+
+ Raises:
+ ValueError: If `tensor` is not rank 2 or `block_dims` does not have 2
+ elements.
+ """
+ if tensor.get_shape().ndims != 2:
+ raise ValueError('Input tensor must be rank 2')
+
+ if len(block_dims) != 2:
+ raise ValueError('block_dims must have 2 elements')
+
+ block_height, block_width = block_dims
+
+ def _tile_rows(tensor, multiple):
+ """Create a new tensor by tiling the tensor along rows."""
+ return array_ops.tile(tensor, [multiple, 1])
+
+ def _generate_indices(num_rows, block_dim):
+ indices = np.zeros(shape=[num_rows * block_dim, 1], dtype=np.int32)
+ for k in range(block_dim):
+ for r in range(num_rows):
+ indices[k * num_rows + r] = r * block_dim + k
+ return indices
+
+ def _replicate_rows(tensor, multiple):
+ tensor_shape = tensor.shape.as_list()
+ expanded_shape = [tensor_shape[0] * multiple, tensor_shape[1]]
+ indices = constant_op.constant(_generate_indices(tensor_shape[0], multiple))
+ return array_ops.scatter_nd(indices, _tile_rows(tensor, multiple),
+ expanded_shape)
+
+ expanded_tensor = tensor
+
+ # Expand rows by factor block_height.
+ if block_height > 1:
+ expanded_tensor = _replicate_rows(tensor, block_height)
+
+ # Transpose and expand by factor block_width. Transpose the result.
+ if block_width > 1:
+ expanded_tensor = array_ops.transpose(
+ _replicate_rows(array_ops.transpose(expanded_tensor), block_width))
+
+ return expanded_tensor
+
+
def _histogram(values, value_range, nbins=100, dtype=dtypes.int32, name=None):
"""Return histogram of values.
@@ -167,19 +235,18 @@ def compute_cdf_from_histogram(values, value_range, **kwargs):
def compute_cdf(values, value_range, **kwargs):
"""Returns the normalized cumulative distribution of the given values tensor.
- Uses tf.while_loop to directly compute the cdf of the values. Number of bins
- for histogram is fixed at _NBINS=255
+ Uses tf.while_loop to directly compute the cdf of the values.
Args:
values: Numeric `Tensor`.
value_range: Shape [2] `Tensor` of same `dtype` as `values`
- **kwargs: keyword arguments: name
+ **kwargs: keyword arguments: nbins, name
Returns:
A 1-D `Tensor` holding normalized cdf of values.
"""
- nbins = _NBINS
+ nbins = kwargs.get('nbins', _NBINS)
name = kwargs.get('name', None)
with ops.name_scope(name, 'cdf', [values, value_range, nbins]):
values = ops.convert_to_tensor(values, name='values')
@@ -213,7 +280,7 @@ def compute_cdf(values, value_range, **kwargs):
cdf = math_ops.add(
cdf,
array_ops.one_hot(
- loop_count, depth=_NBINS, on_value=temp, off_value=0.0))
+ loop_count, depth=nbins, on_value=temp, off_value=0.0))
return [loop_count + 1, cdf]
_, cdf = control_flow_ops.while_loop(
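What compute_cdf produces, sketched with NumPy (the TF version accumulates the histogram inside a tf.while_loop; the normalization shown is an assumption from the docstring's "normalized cdf"):

import numpy as np

def compute_cdf_sketch(values, value_range, nbins=256):
  hist, _ = np.histogram(values, bins=nbins, range=value_range)
  cdf = np.cumsum(hist).astype(np.float64)
  return cdf / cdf[-1]  # normalized so the final bin reaches 1.0

values = np.abs(np.random.randn(1000))
cdf = compute_cdf_sketch(values, (0.0, values.max()))
print(cdf[0], cdf[-1])  # a small fraction, then exactly 1.0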
diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py
index ccde5b4e8a..0aca843497 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.model_pruning.python import pruning_utils
@@ -26,6 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -36,27 +38,13 @@ class PruningUtilsTest(test.TestCase):
def _compare_cdf(self, values):
abs_values = math_ops.abs(values)
max_value = math_ops.reduce_max(abs_values)
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
cdf_from_histogram = pruning_utils.compute_cdf_from_histogram(
abs_values, [0.0, max_value], nbins=pruning_utils._NBINS)
cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value])
self.assertAllEqual(cdf.eval(), cdf_from_histogram.eval())
- def _compare_pooling_methods(self, weights, pooling_kwargs):
- with self.test_session():
- variables.global_variables_initializer().run()
- pooled_weights_tf = array_ops.squeeze(
- nn_ops.pool(
- array_ops.reshape(
- weights,
- [1, weights.get_shape()[0],
- weights.get_shape()[1], 1]), **pooling_kwargs))
- pooled_weights_factorized_pool = pruning_utils.factorized_pool(
- weights, **pooling_kwargs)
- self.assertAllClose(pooled_weights_tf.eval(),
- pooled_weights_factorized_pool.eval())
-
def testHistogram(self):
width = 10
height = 10
@@ -67,7 +55,7 @@ class PruningUtilsTest(test.TestCase):
"weights", [width, height], initializer=init)
histogram = pruning_utils._histogram(
weights, [0, 1.0], nbins, dtype=np.float32)
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
computed_histogram = histogram.eval()
self.assertAllEqual(expected_histogram, computed_histogram)
@@ -79,7 +67,7 @@ class PruningUtilsTest(test.TestCase):
norm_cdf = pruning_utils.compute_cdf_from_histogram(
abs_weights, [0.0, 5.0], nbins=nbins)
expected_cdf = np.array([0.1, 0.4, 0.5, 0.6, 1.0], dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
norm_cdf_val = sess.run(norm_cdf)
self.assertAllEqual(len(norm_cdf_val), nbins)
@@ -95,26 +83,60 @@ class PruningUtilsTest(test.TestCase):
weights = variable_scope.get_variable("weights", shape=[5, 5, 128, 128])
self._compare_cdf(weights)
- def testFactorizedAvgPool(self):
+
+@parameterized.named_parameters(
+ ("1x1", [1, 1]), ("4x4", [4, 4]), ("6x6", [6, 6]), ("1x4", [1, 4]),
+ ("4x1", [4, 1]), ("1x8", [1, 8]), ("8x1", [8, 1]))
+class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase):
+
+ def _compare_pooling_methods(self, weights, pooling_kwargs):
+ with self.cached_session():
+ variables.global_variables_initializer().run()
+ pooled_weights_tf = array_ops.squeeze(
+ nn_ops.pool(
+ array_ops.reshape(
+ weights,
+ [1, weights.get_shape()[0],
+ weights.get_shape()[1], 1]), **pooling_kwargs))
+ pooled_weights_factorized_pool = pruning_utils.factorized_pool(
+ weights, **pooling_kwargs)
+ self.assertAllClose(pooled_weights_tf.eval(),
+ pooled_weights_factorized_pool.eval())
+
+ def _compare_expand_tensor_with_kronecker_product(self, tensor, block_dim):
+ with self.cached_session() as session:
+ variables.global_variables_initializer().run()
+ expanded_tensor = pruning_utils.expand_tensor(tensor, block_dim)
+ kronecker_product = pruning_utils.kronecker_product(
+ tensor, array_ops.ones(block_dim))
+ expanded_tensor_val, kronecker_product_val = session.run(
+ [expanded_tensor, kronecker_product])
+ self.assertAllEqual(expanded_tensor_val, kronecker_product_val)
+
+ def testFactorizedAvgPool(self, window_shape):
weights = variable_scope.get_variable("weights", shape=[1024, 2048])
pooling_kwargs = {
- "window_shape": [2, 4],
+ "window_shape": window_shape,
"pooling_type": "AVG",
- "strides": [2, 4],
+ "strides": window_shape,
"padding": "SAME"
}
self._compare_pooling_methods(weights, pooling_kwargs)
- def testFactorizedMaxPool(self):
+ def testFactorizedMaxPool(self, window_shape):
weights = variable_scope.get_variable("weights", shape=[1024, 2048])
pooling_kwargs = {
- "window_shape": [2, 4],
+ "window_shape": window_shape,
"pooling_type": "MAX",
- "strides": [2, 4],
+ "strides": window_shape,
"padding": "SAME"
}
self._compare_pooling_methods(weights, pooling_kwargs)
+ def testExpandTensor(self, block_dim):
+ weights = random_ops.random_normal(shape=[1024, 512])
+ self._compare_expand_tensor_with_kronecker_product(weights, block_dim)
+
if __name__ == "__main__":
test.main()
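For reference, a self-contained sketch, with assumed names rather than patch
code, of the class-level parameterization used above: each named tuple becomes
the extra argument to every test method in the class, so each test runs once
per shape:

from absl.testing import absltest
from absl.testing import parameterized

@parameterized.named_parameters(("2x2", [2, 2]), ("4x1", [4, 1]))
class ExampleParameterizedTest(parameterized.TestCase):

  def testShapeHasTwoDims(self, shape):
    # Runs twice: once with shape=[2, 2], once with shape=[4, 1].
    self.assertLen(shape, 2)

if __name__ == "__main__":
  absltest.main()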
diff --git a/tensorflow/contrib/model_pruning/python/strip_pruning_vars_test.py b/tensorflow/contrib/model_pruning/python/strip_pruning_vars_test.py
index 255daa0360..237510cb0c 100644
--- a/tensorflow/contrib/model_pruning/python/strip_pruning_vars_test.py
+++ b/tensorflow/contrib/model_pruning/python/strip_pruning_vars_test.py
@@ -144,7 +144,7 @@ class StripPruningVarsTest(test.TestCase):
return outputs
def _get_initial_outputs(self, output_tensor_names_list):
- with self.test_session(graph=self.initial_graph) as sess1:
+ with self.session(graph=self.initial_graph) as sess1:
self._prune_model(sess1)
reference_outputs = self._get_outputs(sess1, self.initial_graph,
output_tensor_names_list)
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.h b/tensorflow/contrib/nccl/kernels/nccl_manager.h
index 09fad35d23..7d158cc980 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_manager.h
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_
-#define TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_
+#ifndef TENSORFLOW_CONTRIB_NCCL_KERNELS_NCCL_MANAGER_H_
+#define TENSORFLOW_CONTRIB_NCCL_KERNELS_NCCL_MANAGER_H_
#ifdef GOOGLE_CUDA
@@ -135,4 +135,4 @@ class NcclManager {
#endif // GOOGLE_CUDA
-#endif // TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_
+#endif // TENSORFLOW_CONTRIB_NCCL_KERNELS_NCCL_MANAGER_H_
diff --git a/tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py b/tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py
index cb69c72970..d0955cbe11 100644
--- a/tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py
+++ b/tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py
@@ -31,7 +31,7 @@ class HyperplaneLshProbesTest(test.TestCase):
# tests in hyperplane_lsh_probes_test.cc already cover most of the LSH
# functionality.
def simple_batch_test(self):
- with self.test_session():
+ with self.cached_session():
hyperplanes = np.eye(4)
points = np.array([[1.2, 0.5, -0.9, -1.0], [2.0, -3.0, 1.0, -1.5]])
product = np.dot(points, hyperplanes)
diff --git a/tensorflow/contrib/nn/python/ops/alpha_dropout_test.py b/tensorflow/contrib/nn/python/ops/alpha_dropout_test.py
index 54a98e6f14..3aec88bcbf 100644
--- a/tensorflow/contrib/nn/python/ops/alpha_dropout_test.py
+++ b/tensorflow/contrib/nn/python/ops/alpha_dropout_test.py
@@ -32,7 +32,7 @@ class AlphaDropoutTest(test.TestCase):
def testAlphaDropout(self):
x_dim, y_dim = 40, 30
for keep_prob in [0.1, 0.5, 0.8]:
- with self.test_session():
+ with self.cached_session():
t = random_ops.random_normal([x_dim, y_dim])
output = alpha_dropout(t, keep_prob)
self.assertEqual([x_dim, y_dim], output.get_shape())
diff --git a/tensorflow/contrib/nn/python/ops/fwd_gradients_test.py b/tensorflow/contrib/nn/python/ops/fwd_gradients_test.py
index 56062c3cab..4cdac6a742 100644
--- a/tensorflow/contrib/nn/python/ops/fwd_gradients_test.py
+++ b/tensorflow/contrib/nn/python/ops/fwd_gradients_test.py
@@ -35,7 +35,7 @@ class ForwardAdTest(test.TestCase):
dydx_tf = fwd_gradients.fwd_gradients([y], [x], [grad_x])[0]
dydx_py = 2. * grad_x
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllClose(sess.run(dydx_tf), dydx_py, 1e-6)
def testGather(self):
@@ -44,7 +44,7 @@ class ForwardAdTest(test.TestCase):
y.set_shape([2])
dydx = fwd_gradients.fwd_gradients([y], [x], assert_unused=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(dydx)
diff --git a/tensorflow/contrib/nn/python/ops/sampling_ops_test.py b/tensorflow/contrib/nn/python/ops/sampling_ops_test.py
index 1d4fe1321b..11738bb215 100644
--- a/tensorflow/contrib/nn/python/ops/sampling_ops_test.py
+++ b/tensorflow/contrib/nn/python/ops/sampling_ops_test.py
@@ -227,7 +227,7 @@ class RankSampledSoftmaxLossTest(test.TestCase):
sampled_values=self._resampled_values,
remove_accidental_hits=self._remove_accidental_hits,
partition_strategy=partition_strategy)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss_val = sess.run(loss)
loss_nn_val = sess.run(loss_nn)
@@ -299,7 +299,7 @@ class RankSampledSoftmaxLossTest(test.TestCase):
sampled_values=resampled_values,
remove_accidental_hits=remove_accidental_hits,
partition_strategy='div')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss_val = sess.run(loss)
loss_nn_val = sess.run(loss_nn)
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index 5319a8b655..2e4d61d931 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -22,6 +22,7 @@ py_library(
"python/training/ggt.py",
"python/training/lars_optimizer.py",
"python/training/lazy_adam_optimizer.py",
+ "python/training/matrix_functions.py",
"python/training/model_average_optimizer.py",
"python/training/moving_average_optimizer.py",
"python/training/multitask_optimizer_wrapper.py",
@@ -158,8 +159,10 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:resource_variable_ops",
"//tensorflow/python:variables",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
@@ -381,3 +384,18 @@ py_test(
"@six_archive//:six",
],
)
+
+py_test(
+ name = "matrix_functions_test",
+ srcs = ["python/training/matrix_functions_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":opt_py",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py
index 781621dba0..ad7d7cfa6e 100644
--- a/tensorflow/contrib/opt/__init__.py
+++ b/tensorflow/contrib/opt/__init__.py
@@ -31,6 +31,7 @@ from tensorflow.contrib.opt.python.training.model_average_optimizer import *
from tensorflow.contrib.opt.python.training.moving_average_optimizer import *
from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import *
from tensorflow.contrib.opt.python.training.nadam_optimizer import *
+from tensorflow.contrib.opt.python.training.reg_adagrad_optimizer import *
from tensorflow.contrib.opt.python.training.shampoo import *
from tensorflow.contrib.opt.python.training.weight_decay_optimizers import *
from tensorflow.contrib.opt.python.training.powersign import *
@@ -65,6 +66,7 @@ _allowed_symbols = [
'ModelAverageCustomGetter',
'GGTOptimizer',
'ShampooOptimizer',
+ 'RegAdagradOptimizer',
]
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/opt/python/training/adamax_test.py b/tensorflow/contrib/opt/python/training/adamax_test.py
index 915e6504e1..61d8b94eca 100644
--- a/tensorflow/contrib/opt/python/training/adamax_test.py
+++ b/tensorflow/contrib/opt/python/training/adamax_test.py
@@ -74,7 +74,7 @@ class AdaMaxOptimizerTest(test.TestCase):
def doTestSparse(self, use_resource=False):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
zero_slots = lambda: np.zeros((3), dtype=dtype.as_numpy_dtype)
m0, v0, m1, v1 = zero_slots(), zero_slots(), zero_slots(), zero_slots()
@@ -142,7 +142,7 @@ class AdaMaxOptimizerTest(test.TestCase):
def testSparseRepeatedIndices(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
repeated_index_update_var = variables.Variable(
[[1.0], [2.0]], dtype=dtype)
aggregated_update_var = variables.Variable(
@@ -172,7 +172,7 @@ class AdaMaxOptimizerTest(test.TestCase):
def doTestBasic(self, use_resource=False):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -233,7 +233,7 @@ class AdaMaxOptimizerTest(test.TestCase):
opt.get_slot(var=var0, name="m").name)
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
self.doTestBasic(use_resource=False)
@test_util.run_in_graph_and_eager_modes(reset_test=True)
@@ -242,7 +242,7 @@ class AdaMaxOptimizerTest(test.TestCase):
def testTensorLearningRate(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -278,7 +278,7 @@ class AdaMaxOptimizerTest(test.TestCase):
def testSharing(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
index 5763593b81..6c203e5519 100644
--- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
@@ -17,22 +17,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
-
-from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gen_nn_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import optimizer
+from tensorflow.python.training import saver
from tensorflow.python.training import session_run_hook
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import data_flow_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import constant_op
LOCAL_VARIABLE_NAME = 'local_center_variable'
GLOBAL_VARIABLE_NAME = 'global_center_variable'
+GLOBAL_STEP = 'global_step'
class ElasticAverageCustomGetter(object):
@@ -52,16 +53,32 @@ class ElasticAverageCustomGetter(object):
with tf.device(
tf.train.replica_device_setter(
worker_device=worker_device,
- ps_device="/job:ps/cpu:0",
+ ps_device="/job:ps",
cluster=cluster)),
tf.variable_scope('',custom_getter=ea_custom_getter):
- hid_w = tf.get_variable(
- initializer=tf.truncated_normal(
- [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
- stddev=1.0 / IMAGE_PIXELS),
- name="hid_w")
- hid_b = tf.get_variable(initializer=tf.zeros([FLAGS.hidden_units]),
- name="hid_b")
+ ...
+ create your model here
+ ...
+ with tf.device(worker_device):
+ opt = tf.train.MomentumOptimizer(...)
+ optimizer = ElasticAverageOptimizer(
+ opt,
+ num_worker=2,
+ moving_rate=0.01, # or use default value
+ communication_period=20,
+ ea_custom_getter=ea_custom_getter)
+ ...
+ train_op = optimizer.apply_gradients(
+ grads_vars,
+ global_step=global_step)
+ ...
+ hooks = [optimizer.make_session_run_hook(is_chief, task_index)]
+ ...
+ with tf.train.MonitoredTrainingSession(master=server.target,
+ is_chief=is_chief,
+ checkpoint_dir=("..."),
+ save_checkpoint_secs=600,
+ hooks=hooks) as mon_sess:
"""
def __init__(self, worker_device):
@@ -83,24 +100,42 @@ class ElasticAverageCustomGetter(object):
collections=[ops.GraphKeys.LOCAL_VARIABLES],
*args,
**kwargs)
- global_center_variable = variable_scope.variable(
+ if kwargs['reuse'] is True:
+ return local_var
+ global_center_variable = getter(
name='%s/%s' % (GLOBAL_VARIABLE_NAME, name),
- initial_value=local_var.initialized_value(),
trainable=False,
- collections=[ops.GraphKeys.GLOBAL_VARIABLES])
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES],
+ *args,
+ **kwargs)
with ops.device(self._worker_device):
- local_center_variable = variable_scope.variable(
+ local_center_variable = getter(
name='%s/%s' % (LOCAL_VARIABLE_NAME, name),
- initial_value=local_var.initialized_value(),
trainable=False,
- collections=[ops.GraphKeys.LOCAL_VARIABLES])
-
- self._local_map[local_var] = local_center_variable
- self._global_map[local_var] = global_center_variable
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ *args,
+ **kwargs)
+ if kwargs['partitioner'] is None:
+ self._local_map[local_var] = local_center_variable
+ self._global_map[local_var] = global_center_variable
+ else:
+ v_list = list(local_var)
+ for i in range(len(v_list)):
+ self._local_map[v_list[i]] = list(local_center_variable)[i]
+ self._global_map[v_list[i]] = list(global_center_variable)[i]
return local_var
else:
- return getter(name, trainable, collections, *args, **kwargs)
+ kwargs['trainable'] = trainable
+ kwargs['collections'] = collections
+ if ops.GraphKeys.LOCAL_VARIABLES in collections:
+ with ops.device(self._worker_device):
+ return getter(name, *args, **kwargs)
+ else:
+ return getter(name, *args, **kwargs)
+
class ElasticAverageOptimizer(optimizer.Optimizer):
@@ -125,6 +160,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
moving_rate=None,
rho=None,
use_locking=True,
+ synchronous=False,
name='ElasticAverageOptimizer'):
"""Construct a new gradient descent optimizer.
@@ -136,9 +172,16 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
communication_period: An int value that controls the frequency of
communication between every worker and the ps.
moving_rate: A floating point value to control the elastic difference.
- rho: the amount of exploration we allow ine the model. The default
+ rho: the amount of exploration we allow in the model. The default
value is moving_rate/learning_rate.
+ rho=0.0 is suggested in async mode.
use_locking: If True use locks for update operations.
+ synchronous: Add_sync_queues_and_barrier or not.
+ True: all workers will wait for each other before starting training.
+ False: each worker can start training once its own initialization is
+ done; no need to wait for everyone else to be ready, so a restarted
+ worker can rejoin and continue training without being blocked.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "ElasticAverageOptimizer".
"""
@@ -148,6 +191,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
self._period = communication_period
self._local_map = ea_custom_getter._local_map
self._global_map = ea_custom_getter._global_map
+ self._synchronous = synchronous
if moving_rate is None:
self._moving_rate = self.BETA / communication_period / num_worker
@@ -241,11 +285,29 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
TypeError: If `grads_and_vars` is malformed.
ValueError: If none of the variables have gradients.
"""
+ global_old = set(n.op.name for n in variables.global_variables())
apply_updates = self._opt.apply_gradients(grads_and_vars)
+ global_new = set(n.op.name for n in variables.global_variables())
with ops.control_dependencies([apply_updates]):
local_update = state_ops.assign_add(
self._local_step, 1, name='local_step_update').op
+ # Move variables created by the wrapped optimizer into the local
+ # collection, e.g. AdamOptimizer creates its beta accumulators as
+ # global variables.
+ def _adjust_optimizer_variable_collection(opt_vars):
+ g = ops.get_default_graph()
+ idx = 0
+ for _ in range(len(g._collections[ops.GraphKeys.GLOBAL_VARIABLES])):
+ var = g.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES)[idx]
+ name = var.op.name
+ if name in opt_vars:
+ ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, var)
+ del g.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES)[idx]
+ else:
+ idx += 1
+
+ _adjust_optimizer_variable_collection(global_new - global_old)
+
# update global variables.
def _Update_global_variables():
local_vars = [v for g, v in grads_and_vars if g is not None]
@@ -290,7 +352,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
variables equal to the global center variables before the training begins"""
def _Add_sync_queues_and_barrier(enqueue_after_list):
- """Adds ops to enqueu on all worker queues"""
+ """Adds ops to enqueue on all worker queues"""
sync_queues = [
data_flow_ops.FIFOQueue(
self._num_worker, [dtypes.bool],
@@ -324,6 +386,9 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
init_ops.append(state_ops.assign(lc_var, gc_var))
init_op = control_flow_ops.group(*(init_ops))
+ if not self._synchronous:
+ return init_op
+
sync_queue_op = _Add_sync_queues_and_barrier([init_op])
return sync_queue_op
@@ -331,6 +396,51 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
"""Creates a hook to handle ElasticAverageOptimizerHook ops such as initialization."""
return _ElasticAverageOptimizerHook(self, is_chief, task_index)
+ def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs):
+ """Create a saver copy global_center_variable to trainable variables
+ Please call this function after all your variables created with
+ ElasticAverageCustomGetter. For evaluations or inference, use this saver
+ during training. It will save the global_center_variable of the trained
+ parameters under the original parameter names.
+ Args:
+ var_list: List of variables to save, as per `Saver()`.
+ If set to None, save all the trainable_variables that have
+ been created before this call.
+ name: The name of the saver.
+ **kwargs: Keyword arguments of `Saver()`.
+ Returns:
+ A `tf.train.Saver` object.
+ Raises:
+ RuntimeError: global_center_variable is empty; make sure this is
+ called after the model is created and that
+ ElasticAverageCustomGetter was used when declaring your model.
+ """
+ if not self._global_map:
+ raise RuntimeError('global_center_variable is empty; make sure this is '
+ 'called after the model is created and that '
+ 'ElasticAverageCustomGetter was used when declaring '
+ 'your model')
+
+ if var_list is None:
+ var_list = variables.trainable_variables()
+ if not isinstance(var_list, dict):
+ var_list = saver.BaseSaverBuilder.OpListToDict(var_list)
+
+ swapped_var_list = {}
+ for key, var in var_list.items():
+ tensor = var
+
+ if not isinstance(var, list):
+ for tvar in variables.trainable_variables():
+ if tvar.op.name == var.op.name:
+ tensor = self._global_map.get(tvar, var)
+ break
+ else: #partitioned variable
+ tensor = [self._global_map.get(lvar, lvar) for lvar in var]
+
+ swapped_var_list[key] = tensor
+
+ return saver.Saver(swapped_var_list, name=name, **kwargs)
class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook):
@@ -351,3 +461,7 @@ class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook):
if self._is_chief:
self._global_init_op = variables.global_variables_initializer()
self._variable_init_op = self._ea_optimizer.get_init_op(self._task_index)
+
+ def after_create_session(self, session, coord):
+ """Run initialization ops"""
+ session.run(self._variable_init_op) \ No newline at end of file
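A hedged usage outline for the new swapping_saver; opt, raw_sess and eval_sess
are assumed names, not patch code. The checkpoint holds the global center
values under the local variables' names, so a plain Saver can restore them for
inference:

# After building the model under ElasticAverageCustomGetter:
swap_saver = opt.swapping_saver()  # opt: an ElasticAverageOptimizer
swap_saver.save(raw_sess, save_path="/tmp/easgd_model")

# Later, in a fresh graph with the same variable names, a plain
# saver.Saver() restores the trained global center values:
plain_saver = saver.Saver()
plain_saver.restore(eval_sess, "/tmp/easgd_model")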
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
index 5ed8057b86..5bf6a08de1 100644
--- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
@@ -17,17 +17,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
import portpicker
+from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.training import device_setter
from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
from tensorflow.python.training import training
from tensorflow.python.training import training_util
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.training import device_setter
from tensorflow.contrib.opt.python.training.elastic_average_optimizer import \
ElasticAverageOptimizer, ElasticAverageCustomGetter, GLOBAL_VARIABLE_NAME
@@ -59,29 +64,49 @@ def create_local_cluster(num_workers, num_ps, protocol="grpc"):
# Creates the workers and return their sessions, graphs, train_ops.
# Chief worker will update at last
-def _get_workers(num_workers, period, workers, moving_rate):
+def _get_workers(num_workers, period, workers, moving_rate, num_ps=1):
sessions = []
graphs = []
train_ops = []
+ savers = []
for worker_id in range(num_workers):
graph = ops.Graph()
is_chief = (worker_id == 0)
with graph.as_default():
worker_device = "/job:worker/task:%d/cpu:0" % (worker_id)
- ea_coustom = ElasticAverageCustomGetter(worker_device=worker_device)
+ ea_custom = ElasticAverageCustomGetter(worker_device=worker_device)
with variable_scope.variable_scope(
- "", custom_getter=ea_coustom), ops.device(
+ "", custom_getter=ea_custom), ops.device(
device_setter.replica_device_setter(
worker_device=worker_device,
ps_device="/job:ps/task:0/cpu:0",
ps_tasks=1)):
- global_step = variables.Variable(0, name="global_step", trainable=False)
+ global_step = training_util.get_or_create_global_step()
var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
var_1 = variable_scope.get_variable(initializer=1.0, name="v1")
+ if num_ps > 1:
+ with variable_scope.variable_scope(
+ "",
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_ps, axis=0),
+ custom_getter=ea_custom), ops.device(
+ device_setter.replica_device_setter(
+ worker_device=worker_device,
+ ps_device="/job:ps/task:0/cpu:0",
+ ps_tasks=num_ps)):
+
+ partition_var = variable_scope.get_variable(
+ 'partition_var',
+ shape=[2, 4],
+ initializer=init_ops.ones_initializer)
+ part_0 = list(partition_var)[0]
+ part_1 = list(partition_var)[1]
with ops.device("/job:worker/task:" + str(worker_id)):
grads_0 = constant_op.constant(-1.0)
grads_1 = constant_op.constant(-1.0)
+ grads_part_0 = constant_op.constant([[-1., -1., -1., -1.]])
+ grads_part_1 = constant_op.constant([[-1., -1., -1., -1.]])
sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
opt = ElasticAverageOptimizer(
@@ -89,12 +114,22 @@ def _get_workers(num_workers, period, workers, moving_rate):
num_worker=num_workers,
moving_rate=moving_rate,
communication_period=period,
- ea_custom_getter=ea_coustom)
- train_op = [
- opt.apply_gradients(([grads_0, var_0], [grads_1, var_1]),
- global_step)
- ]
+ ea_custom_getter=ea_custom)
+ if num_ps == 1:
+ train_op = [
+ opt.apply_gradients(([grads_0, var_0], [grads_1, var_1]),
+ global_step)
+ ]
+ else:
+ train_op = [
+ opt.apply_gradients(([grads_0, var_0],
+ [grads_1, var_1],
+ [grads_part_0, part_0],
+ [grads_part_1, part_1]),
+ global_step)
+ ]
easgd_hook = opt.make_session_run_hook(is_chief, worker_id)
+ saver = opt.swapping_saver()
# Creates MonitoredSession
sess = training.MonitoredTrainingSession(
workers[worker_id].target, hooks=[easgd_hook])
@@ -102,8 +137,9 @@ def _get_workers(num_workers, period, workers, moving_rate):
sessions.append(sess)
graphs.append(graph)
train_ops.append(train_op)
+ savers.append(saver)
- return sessions, graphs, train_ops
+ return sessions, graphs, train_ops, savers
class ElasticAverageOptimizerTest(test.TestCase):
@@ -118,7 +154,7 @@ class ElasticAverageOptimizerTest(test.TestCase):
cluster, workers, _ = create_local_cluster(
num_workers=num_workers, num_ps=num_ps)
- sessions, graphs, train_ops = _get_workers(
+ sessions, graphs, train_ops, savers = _get_workers(
num_workers, communication_period, workers, 1.0)
var_0 = graphs[0].get_tensor_by_name("v0:0")
@@ -158,6 +194,21 @@ class ElasticAverageOptimizerTest(test.TestCase):
self.assertAllEqual(2.0, sessions[0].run(var_0_g))
self.assertAllEqual(3.0, sessions[0].run(var_1_g))
self.assertAllEqual(1, sessions[0].run(global_step))
+ sessions[0].run(train_ops[0])
+
+ # save, data will be global value
+ outfile = os.path.join(test.get_temp_dir(), "model")
+ savers[0].save(sessions[0]._sess._sess._sess._sess,
+ save_path=outfile)
+ ops.reset_default_graph() # restore on a new graph
+ with session.Session() as sess:
+ v0 = variable_scope.get_variable(initializer=0.0, name="v0")
+ v1 = variable_scope.get_variable(initializer=1.0, name="v1")
+ sess.run(variables.local_variables_initializer())
+ saver_opt = saver.Saver(var_list=[v1, v0])
+ saver_opt.restore(sess, outfile)
+ self.assertAllEqual(2.0, sess.run(v0))
+ self.assertAllEqual(3.0, sess.run(v1))
def test2Worker1Period(self):
num_workers = 2
@@ -166,8 +217,8 @@ class ElasticAverageOptimizerTest(test.TestCase):
cluster, workers, _ = create_local_cluster(
num_workers=num_workers, num_ps=num_ps)
- sessions, graphs, train_ops = _get_workers(
- num_workers, communication_period, workers, 0.5)
+ sessions, graphs, train_ops, savers = _get_workers(
+ num_workers, communication_period, workers, 0.5, num_ps=2)
var_0 = graphs[0].get_tensor_by_name("v0:0")
var_1 = graphs[0].get_tensor_by_name("v1:0")
@@ -177,6 +228,9 @@ class ElasticAverageOptimizerTest(test.TestCase):
var_0_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v0:0")
var_1_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v1:0")
+ part_0_g = graphs[0].get_tensor_by_name(
+ GLOBAL_VARIABLE_NAME + "/partition_var/part_0:0")
+
# Verify the initialized value.
self.assertAllEqual(0.0, sessions[0].run(var_0))
self.assertAllEqual(1.0, sessions[0].run(var_1))
@@ -194,22 +248,45 @@ class ElasticAverageOptimizerTest(test.TestCase):
self.assertAllEqual(1.75, sessions[0].run(var_1_g))
self.assertAllEqual(0.75, sessions[1].run(var_0_1))
self.assertAllEqual(1.75, sessions[1].run(var_1_1))
+ # part_0 of global_center copy
+ part_0_g = sessions[0].run(part_0_g)
+
+ outfile = os.path.join(test.get_temp_dir(), "model")
+ savers[0].save(sessions[0]._sess._sess._sess._sess,
+ save_path=outfile)
+
+ # verify restore of partitioned_variables
+ ops.reset_default_graph() # restore on a new graph
+ g = ops.get_default_graph()
+ with session.Session() as sess, g.as_default():
+ with variable_scope.variable_scope(
+ "",
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_ps, axis=0)):
+ partition_var = variable_scope.get_variable(
+ 'partition_var',
+ shape=[2, 4],
+ initializer=init_ops.ones_initializer)
+ s = saver.Saver(var_list=[partition_var])
+ s.restore(sess, outfile)
+ part_0 = g.get_tensor_by_name('partition_var/part_0:0')
+ self.assertAllEqual(part_0_g, sess.run(part_0))
def testPS2TasksWithClusterSpecClass(self):
cluster_spec = server_lib.ClusterSpec({
"ps": ["ps0:2222", "ps1:2222"],
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
})
- ea_coustom = ElasticAverageCustomGetter(worker_device="/job:worker/task:0")
+ ea_custom = ElasticAverageCustomGetter(worker_device="/job:worker/task:0")
from tensorflow.python.training import device_setter
with ops.device(
device_setter.replica_device_setter(cluster=cluster_spec,
worker_device="/job:worker/task:0",
ps_device="/job:ps")), \
- variable_scope.variable_scope("", custom_getter=ea_coustom):
+ variable_scope.variable_scope("", custom_getter=ea_custom):
v = variable_scope.get_variable(initializer=[1, 2], name="v")
w = variable_scope.get_variable(initializer=[2, 1], name="w")
- v_g, w_g = ea_coustom._global_map[v], ea_coustom._global_map[w]
+ v_g, w_g = ea_custom._global_map[v], ea_custom._global_map[w]
self.assertDeviceEqual("/job:worker/task:0", v.device)
self.assertDeviceEqual("job:ps/task:0", v_g.device)
self.assertDeviceEqual("/job:worker/task:0", w.device)
diff --git a/tensorflow/contrib/opt/python/training/external_optimizer_test.py b/tensorflow/contrib/opt/python/training/external_optimizer_test.py
index 953586ee70..9997103016 100644
--- a/tensorflow/contrib/opt/python/training/external_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/external_optimizer_test.py
@@ -85,7 +85,7 @@ class ExternalOptimizerInterfaceTest(TestCase):
optimizer = MockOptimizerInterface(loss)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
optimizer.minimize(sess)
@@ -107,7 +107,7 @@ class ExternalOptimizerInterfaceTest(TestCase):
optimizer = MockOptimizerInterface(loss)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
initial_vector_val = sess.run(vector)
@@ -164,7 +164,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
optimizer = external_optimizer.ScipyOptimizerInterface(
self._objective(x), method=method, options=options)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
optimizer.minimize(sess)
@@ -176,7 +176,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
x = variables.Variable(array_ops.zeros(dimension))
optimizer = external_optimizer.ScipyOptimizerInterface(self._objective(x))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
optimizer.minimize(sess)
@@ -242,7 +242,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
optimizer = external_optimizer.ScipyOptimizerInterface(
loss, equalities=equalities, inequalities=inequalities, method='SLSQP')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
optimizer.minimize(sess)
self.assertAllClose(np.ones(2), sess.run(vector))
@@ -260,7 +260,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
optimizer = external_optimizer.ScipyOptimizerInterface(
loss, var_to_bounds=var_to_bounds)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
optimizer.minimize(sess)
self.assertAllClose(np.ones(2), sess.run(vector))
@@ -277,7 +277,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
optimizer = external_optimizer.ScipyOptimizerInterface(
loss, var_to_bounds=var_to_bounds)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
optimizer.minimize(sess)
self.assertAllClose([0., 2.], sess.run(vector))
@@ -293,7 +293,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
optimizer = external_optimizer.ScipyOptimizerInterface(
loss, method='SLSQP')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
optimizer.minimize(sess)
method = optimizer.optimizer_kwargs.get('method')
@@ -312,7 +312,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
optimizer = external_optimizer.ScipyOptimizerInterface(loss, method='SLSQP')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
initial_vector_val = sess.run(vector)
diff --git a/tensorflow/contrib/opt/python/training/ggt_test.py b/tensorflow/contrib/opt/python/training/ggt_test.py
index 42162960b0..1775edabb3 100644
--- a/tensorflow/contrib/opt/python/training/ggt_test.py
+++ b/tensorflow/contrib/opt/python/training/ggt_test.py
@@ -76,7 +76,7 @@ class GGTOptimizerTest(test.TestCase):
def doTestBasic(self, use_resource=False):
# SVD does not support float16
for i, dtype in enumerate([dtypes.float32, dtypes.float64]):
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
# Initialize variables for numpy implementation.
m0 = 0.0
window = 3
@@ -171,7 +171,7 @@ class GGTOptimizerTest(test.TestCase):
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
self.doTestBasic(use_resource=False)
@test_util.run_in_graph_and_eager_modes(reset_test=True)
diff --git a/tensorflow/contrib/opt/python/training/lars_optimizer_test.py b/tensorflow/contrib/opt/python/training/lars_optimizer_test.py
index d94249b994..b76db763da 100644
--- a/tensorflow/contrib/opt/python/training/lars_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/lars_optimizer_test.py
@@ -31,7 +31,7 @@ class LARSOptimizerTest(test.TestCase):
def testLARSGradientOneStep(self):
for _ in range(10):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape = [3, 3]
var_np = np.ones(shape)
grad_np = np.ones(shape)
@@ -77,7 +77,7 @@ class LARSOptimizerTest(test.TestCase):
def testLARSGradientMultiStep(self):
for _ in range(10):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape = [3, 3]
var_np = np.ones(shape)
grad_np = np.ones(shape)
diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
index 72117c1e81..f55209ec49 100644
--- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
@@ -28,6 +28,7 @@ from __future__ import print_function
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import adam
@@ -78,3 +79,36 @@ class LazyAdamOptimizer(adam.AdamOptimizer):
lr * m_t_slice / denominator_slice,
use_locking=self._use_locking)
return control_flow_ops.group(var_update, m_t, v_t)
+
+ def _resource_apply_sparse(self, grad, var, indices):
+ beta1_power, beta2_power = self._get_beta_accumulators()
+ beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
+ beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
+ lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
+ beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
+ beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
+ epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
+ lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
+
+ # \\(m := beta1 * m + (1 - beta1) * g_t\\)
+ m = self.get_slot(var, "m")
+ m_t_slice = beta1_t * array_ops.gather(m, indices) + (1 - beta1_t) * grad
+ m_update_op = resource_variable_ops.resource_scatter_update(m.handle,
+ indices,
+ m_t_slice)
+
+ # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
+ v = self.get_slot(var, "v")
+ v_t_slice = (beta2_t * array_ops.gather(v, indices) +
+ (1 - beta2_t) * math_ops.square(grad))
+ v_update_op = resource_variable_ops.resource_scatter_update(v.handle,
+ indices,
+ v_t_slice)
+
+ # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
+ var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t)
+ var_update_op = resource_variable_ops.resource_scatter_sub(var.handle,
+ indices,
+ var_slice)
+
+ return control_flow_ops.group(var_update_op, m_update_op, v_update_op)
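A NumPy sketch, assuming unique indices and a 1-D variable, of the sparse
update _resource_apply_sparse performs above: only the rows named by indices
of m, v and var are touched, unlike dense Adam, which updates every row:

import numpy as np

def lazy_adam_sparse_step(var, m, v, grad, indices, t,
                          lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
  lr_t = lr * np.sqrt(1 - b2**t) / (1 - b1**t)
  m[indices] = b1 * m[indices] + (1 - b1) * grad
  v[indices] = b2 * v[indices] + (1 - b2) * grad**2
  var[indices] -= lr_t * m[indices] / (np.sqrt(v[indices]) + eps)
  return var, m, v

var = np.array([1.0, 2.0, 3.0])
m, v = np.zeros(3), np.zeros(3)
var, m, v = lazy_adam_sparse_step(var, m, v, np.array([0.1]),
                                  np.array([0]), t=1)
assert m[1] == 0.0 and v[2] == 0.0  # untouched rows keep zero moments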
diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
index a16857db7d..f08ffaa36f 100644
--- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
@@ -19,14 +19,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.opt.python.training import lazy_adam_optimizer
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -49,11 +53,12 @@ def adam_update_numpy(param,
return param_t, m_t, v_t
-class AdamOptimizerTest(test.TestCase):
+class AdamOptimizerTest(test.TestCase, parameterized.TestCase):
- def testSparse(self):
+ @parameterized.parameters([False, True])
+ def testSparse(self, use_resource):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -61,8 +66,13 @@ class AdamOptimizerTest(test.TestCase):
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
- var0 = variables.Variable(var0_np)
- var1 = variables.Variable(var1_np)
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+
grads0_np_indices = np.array([0, 1], dtype=np.int32)
grads0 = ops.IndexedSlices(
constant_op.constant(grads0_np),
@@ -94,12 +104,17 @@ class AdamOptimizerTest(test.TestCase):
self.assertAllCloseAccordingToType(var0_np, var0.eval())
self.assertAllCloseAccordingToType(var1_np, var1.eval())
- def testSparseDevicePlacement(self):
+ @parameterized.parameters([False, True])
+ def testSparseDevicePlacement(self, use_resource):
for index_dtype in [dtypes.int32, dtypes.int64]:
with self.test_session(force_gpu=test.is_gpu_available()):
# If a GPU is available, tests that all optimizer ops can be placed on
# it (i.e. they have GPU kernels).
- var = variables.Variable([[1.0], [2.0]])
+ if use_resource:
+ var = resource_variable_ops.ResourceVariable([[1.0], [2.0]])
+ else:
+ var = variables.Variable([[1.0], [2.0]])
+
indices = constant_op.constant([0, 1], dtype=index_dtype)
gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices))
optimizer = lazy_adam_optimizer.LazyAdamOptimizer(3.0)
@@ -107,13 +122,21 @@ class AdamOptimizerTest(test.TestCase):
variables.global_variables_initializer().run()
minimize_op.run()
- def testSparseRepeatedIndices(self):
+ @parameterized.parameters([False, True])
+ def testSparseRepeatedIndices(self, use_resource):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
- repeated_index_update_var = variables.Variable(
- [[1.0], [2.0]], dtype=dtype)
- aggregated_update_var = variables.Variable(
- [[1.0], [2.0]], dtype=dtype)
+ with self.cached_session():
+ if use_resource:
+ repeated_index_update_var = resource_variable_ops.ResourceVariable(
+ [[1.0], [2.0]], dtype=dtype)
+ aggregated_update_var = resource_variable_ops.ResourceVariable(
+ [[1.0], [2.0]], dtype=dtype)
+ else:
+ repeated_index_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+ aggregated_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+
grad_repeated_index = ops.IndexedSlices(
constant_op.constant(
[0.1, 0.1], shape=[2, 1], dtype=dtype),
@@ -139,6 +162,204 @@ class AdamOptimizerTest(test.TestCase):
self.assertAllClose(aggregated_update_var.eval(),
repeated_index_update_var.eval())
+ def doTestBasic(self, use_resource=False, use_callable_params=False):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ with self.session(graph=ops.Graph()):
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_np, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_np, name="var1_%d" % i)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ learning_rate = lambda: 0.001
+ beta1 = lambda: 0.9
+ beta2 = lambda: 0.999
+ epsilon = lambda: 1e-8
+ if not use_callable_params:
+ learning_rate = learning_rate()
+ beta1 = beta1()
+ beta2 = beta2()
+ epsilon = epsilon()
+
+ opt = lazy_adam_optimizer.LazyAdamOptimizer(learning_rate=learning_rate)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ opt_variables = opt.variables()
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+ self.assertIsNotNone(beta1_power)
+ self.assertIsNotNone(beta2_power)
+ self.assertIn(beta1_power, opt_variables)
+ self.assertIn(beta2_power, opt_variables)
+
+ if not context.executing_eagerly():
+ with ops.Graph().as_default():
+ # Shouldn't return non-slot variables from other graphs.
+ self.assertEqual(0, len(opt.variables()))
+ self.evaluate(variables.global_variables_initializer())
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ if not context.executing_eagerly():
+ self.evaluate(update)
+ elif t > 1:
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+ self.assertAllCloseAccordingToType(0.9**(t + 1),
+ self.evaluate(beta1_power))
+ self.assertAllCloseAccordingToType(0.999**(t + 1),
+ self.evaluate(beta2_power))
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+ if use_resource:
+ self.assertEqual("var0_%d/Adam:0" % (i,),
+ opt.get_slot(var=var0, name="m").name)
+
+ def testBasic(self):
+ with self.cached_session():
+ self.doTestBasic(use_resource=False)
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testResourceBasic(self):
+ self.doTestBasic(use_resource=True)
+
+ def testBasicCallableParams(self):
+ with context.eager_mode():
+ self.doTestBasic(use_resource=True, use_callable_params=True)
+
+ def testTensorLearningRate(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = lazy_adam_optimizer.LazyAdamOptimizer(constant_op.constant(0.001))
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ update.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testSharing(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = lazy_adam_optimizer.LazyAdamOptimizer()
+ update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 3 steps of intertwined Adam1 and Adam2.
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ if t % 2 == 0:
+ update1.run()
+ else:
+ update2.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testTwoSessions(self):
+ optimizer = lazy_adam_optimizer.LazyAdamOptimizer()
+
+ with context.eager_mode():
+ var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+ grads0 = constant_op.constant(np.array([0.1, 0.1]))
+ optimizer.apply_gradients([(grads0, var0)])
+
+ g = ops.Graph()
+ with g.as_default():
+ with self.session(graph=g):
+ var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+ grads0 = constant_op.constant(np.array([0.1, 0.1]))
+ optimizer.apply_gradients([(grads0, var0)])
+
+ gg = ops.Graph()
+ with gg.as_default():
+ with self.session(graph=gg):
+ var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+ grads0 = constant_op.constant(np.array([0.1, 0.1]))
+
+ # If the optimizer saves any state not keyed by graph the following line
+ # fails.
+ optimizer.apply_gradients([(grads0, var0)])
+
+ def testSlotsUniqueEager(self):
+ with context.eager_mode():
+ v1 = resource_variable_ops.ResourceVariable(1.)
+ v2 = resource_variable_ops.ResourceVariable(1.)
+ opt = lazy_adam_optimizer.LazyAdamOptimizer(1.)
+ opt.minimize(lambda: v1 + v2)
+ # There should be two non-slot variables (the beta accumulators) and
+ # two slot variables (m and v) for each of v1 and v2, six in total.
+ self.assertEqual(6, len(set(opt.variables())))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/opt/python/training/matrix_functions.py b/tensorflow/contrib/opt/python/training/matrix_functions.py
new file mode 100644
index 0000000000..baab577638
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/matrix_functions.py
@@ -0,0 +1,155 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Matrix functions contains iterative methods for M^p."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+
+
+def matrix_square_root(mat_a, mat_a_size, iter_count=100, ridge_epsilon=1e-4):
+ """Iterative method to get matrix square root.
+
+ Stable iterations for the matrix square root, Nicholas J. Higham
+
+ Page 231, Eq 2.6b
+ http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.6.8799&rep=rep1&type=pdf
+
+ Args:
+ mat_a: the symmetric PSD matrix whose matrix square root is to be computed.
+ mat_a_size: size of mat_a.
+ iter_count: Maximum number of iterations.
+ ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
+
+ Returns:
+ mat_a^0.5
+ """
+
+ def _iter_condition(i, unused_mat_y, unused_old_mat_y, unused_mat_z,
+ unused_old_mat_z, err, old_err):
+ # This method requires that we check for divergence at every step.
+ return math_ops.logical_and(i < iter_count, err < old_err)
+
+ def _iter_body(i, mat_y, unused_old_mat_y, mat_z, unused_old_mat_z, err,
+ unused_old_err):
+ current_iterate = 0.5 * (3.0 * identity - math_ops.matmul(mat_z, mat_y))
+ current_mat_y = math_ops.matmul(mat_y, current_iterate)
+ current_mat_z = math_ops.matmul(current_iterate, mat_z)
+ # Compute the error in approximation.
+ mat_sqrt_a = current_mat_y * math_ops.sqrt(norm)
+ mat_a_approx = math_ops.matmul(mat_sqrt_a, mat_sqrt_a)
+ residual = mat_a - mat_a_approx
+ current_err = math_ops.sqrt(math_ops.reduce_sum(residual * residual)) / norm
+ return i + 1, current_mat_y, mat_y, current_mat_z, mat_z, current_err, err
+
+ identity = linalg_ops.eye(math_ops.to_int32(mat_a_size))
+ mat_a = mat_a + ridge_epsilon * identity
+ norm = math_ops.sqrt(math_ops.reduce_sum(mat_a * mat_a))
+ mat_init_y = mat_a / norm
+ mat_init_z = identity
+ init_err = norm
+
+ _, _, prev_mat_y, _, _, _, _ = control_flow_ops.while_loop(
+ _iter_condition, _iter_body, [
+ 0, mat_init_y, mat_init_y, mat_init_z, mat_init_z, init_err,
+ init_err + 1.0
+ ])
+ return prev_mat_y * math_ops.sqrt(norm)
+
+
+def matrix_inverse_pth_root(mat_g,
+ mat_g_size,
+ alpha,
+ iter_count=100,
+ epsilon=1e-6,
+ ridge_epsilon=1e-6):
+ """Computes mat_g^alpha, where alpha = -1/p, p a positive integer.
+
+ We use an iterative Schur-Newton method from equation 3.2 on page 9 of:
+
+ A Schur-Newton Method for the Matrix p-th Root and its Inverse
+ by Chun-Hua Guo and Nicholas J. Higham
+ SIAM Journal on Matrix Analysis and Applications,
+ 2006, Vol. 28, No. 3 : pp. 788-804
+ https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf
+
+ Args:
+ mat_g: the symmetric PSD matrix whose power is to be computed.
+ mat_g_size: size of mat_g.
+ alpha: exponent, must be -1/p for p a positive integer.
+ iter_count: Maximum number of iterations.
+ epsilon: accuracy indicator, useful for early termination.
+ ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
+
+ Returns:
+ mat_g^alpha
+ """
+
+ identity = linalg_ops.eye(math_ops.to_int32(mat_g_size))
+
+ def mat_power(mat_m, p):
+ """Computes mat_m^p, for p a positive integer.
+
+ Power p is known at graph compile time, so no need for loop and cond.
+ Args:
+ mat_m: a square matrix
+ p: a positive integer
+
+ Returns:
+ mat_m^p
+ """
+ assert p == int(p) and p > 0
+ power = None
+ while p > 0:
+ if p % 2 == 1:
+ power = math_ops.matmul(mat_m, power) if power is not None else mat_m
+ p //= 2
+ mat_m = math_ops.matmul(mat_m, mat_m)
+ return power
+
+ def _iter_condition(i, mat_m, _):
+ return math_ops.logical_and(
+ i < iter_count,
+ math_ops.reduce_max(math_ops.abs(mat_m - identity)) > epsilon)
+
+ def _iter_body(i, mat_m, mat_x):
+ mat_m_i = (1 - alpha) * identity + alpha * mat_m
+ return (i + 1, math_ops.matmul(mat_power(mat_m_i, -1.0 / alpha), mat_m),
+ math_ops.matmul(mat_x, mat_m_i))
+
+ if mat_g_size == 1:
+ mat_h = math_ops.pow(mat_g + ridge_epsilon, alpha)
+ else:
+ damped_mat_g = mat_g + ridge_epsilon * identity
+ z = (1 - 1 / alpha) / (2 * linalg_ops.norm(damped_mat_g))
+ # The best value for z is
+ # (1 - 1/alpha) * (c_max^{-alpha} - c_min^{-alpha}) /
+ # (c_max^{1-alpha} - c_min^{1-alpha})
+ # where c_max and c_min are the largest and smallest singular values of
+ # damped_mat_g.
+ # The above estimate assumes that c_max > c_min * 2^p. (p = -1/alpha)
+ # Can replace above line by the one below, but it is less accurate,
+ # hence needs more iterations to converge.
+ # z = (1 - 1/alpha) / math_ops.trace(damped_mat_g)
+ # If we want the method to always converge, use z = 1 / norm(damped_mat_g)
+ # or z = 1 / math_ops.trace(damped_mat_g), but these can result in many
+ # extra iterations.
+ _, _, mat_h = control_flow_ops.while_loop(
+ _iter_condition, _iter_body,
+ [0, damped_mat_g * z, identity * math_ops.pow(z, -alpha)])
+ return mat_h
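A NumPy sketch, with an assumed helper name, of the mat_power
binary-exponentiation trick above; since p is known at graph construction
time, mat_m^p costs only O(log p) matmuls:

import numpy as np

def mat_power_np(mat_m, p):
  assert p == int(p) and p > 0
  power = None
  while p > 0:
    if p % 2 == 1:  # fold the current power of two into the result
      power = mat_m if power is None else mat_m @ power
    p //= 2
    mat_m = mat_m @ mat_m
  return power

m = np.diag([2.0, 3.0])
assert np.allclose(mat_power_np(m, 5), np.diag([32.0, 243.0]))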
diff --git a/tensorflow/contrib/opt/python/training/matrix_functions_test.py b/tensorflow/contrib/opt/python/training/matrix_functions_test.py
new file mode 100644
index 0000000000..518fa38233
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/matrix_functions_test.py
@@ -0,0 +1,63 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for Matrix functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.opt.python.training import matrix_functions
+from tensorflow.python.platform import test
+
+TOLERANCE = 1e-3
+
+
+def np_power(mat_g, alpha):
+ """Computes mat_g^alpha for a square symmetric matrix mat_g."""
+
+ mat_u, diag_d, mat_v = np.linalg.svd(mat_g)
+ diag_d = np.power(diag_d, alpha)
+ return np.dot(np.dot(mat_u, np.diag(diag_d)), mat_v)
+
+
+class MatrixFunctionTests(test.TestCase):
+
+ def testMatrixSquareRootFunction(self):
+ """Tests for matrix square roots."""
+
+ size = 20
+ mat_a = np.random.rand(size, size)
+ mat = np.dot(mat_a, mat_a.T)
+ expected_mat = np_power(mat, 0.5)
+ mat_root = matrix_functions.matrix_square_root(mat, size)
+ self.assertAllCloseAccordingToType(
+ expected_mat, mat_root, atol=TOLERANCE, rtol=TOLERANCE)
+
+ def testMatrixInversePthRootFunction(self):
+ """Tests for matrix inverse pth roots."""
+
+ size = 20
+ mat_a = np.random.rand(size, size)
+ mat = np.dot(mat_a, mat_a.T)
+ expected_mat = np_power(mat, -0.125)
+ mat_root = matrix_functions.matrix_inverse_pth_root(mat, size, -0.125)
+ self.assertAllCloseAccordingToType(
+ expected_mat, mat_root, atol=TOLERANCE, rtol=TOLERANCE)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer.py b/tensorflow/contrib/opt/python/training/model_average_optimizer.py
index b6b10e500b..746df77ba2 100644
--- a/tensorflow/contrib/opt/python/training/model_average_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/model_average_optimizer.py
@@ -89,7 +89,13 @@ class ModelAverageCustomGetter(object):
self._local_2_global[local_var] = global_variable
return local_var
else:
- return getter(name, trainable, collections, *args, **kwargs)
+ kwargs['trainable'] = trainable
+ kwargs['collections'] = collections
+ if ops.GraphKeys.LOCAL_VARIABLES in collections:
+ with ops.device(self._worker_device):
+ return getter(name, *args, **kwargs)
+ else:
+ return getter(name, *args, **kwargs)
class ModelAverageOptimizer(optimizer.Optimizer):
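The rewritten branch forwards `trainable` and `collections` by keyword (their positional slots in the underlying getter are not stable across call sites) and pins local variables to the worker device. A stripped-down sketch of the same custom-getter pattern, with a hypothetical device string (not the ModelAverage machinery itself):

```python
import tensorflow as tf

def worker_pinned_getter(getter, name, trainable=True, collections=None,
                         *args, **kwargs):
  # Forward the intercepted arguments by keyword, not by position.
  kwargs['trainable'] = trainable
  kwargs['collections'] = collections
  if collections and tf.GraphKeys.LOCAL_VARIABLES in collections:
    with tf.device('/job:worker/task:0'):  # hypothetical worker device
      return getter(name, *args, **kwargs)
  return getter(name, *args, **kwargs)

with tf.variable_scope('ma', custom_getter=worker_pinned_getter):
  local_var = tf.get_variable(
      'v', shape=[2], initializer=tf.zeros_initializer(),
      collections=[tf.GraphKeys.LOCAL_VARIABLES])
```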
diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
index 3acd940268..b1fc50a21f 100644
--- a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
@@ -80,28 +80,28 @@ def _get_workers(num_workers, steps, workers):
var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
var_1 = variable_scope.get_variable(initializer=1.0, name="v1")
- with ops.device("/job:worker/task:" + str(worker_id)):
- if worker_id == 0:
- grads_0 = constant_op.constant(-1.0)
- grads_1 = constant_op.constant(-1.0)
- else:
- grads_0 = constant_op.constant(-2.0)
- grads_1 = constant_op.constant(-2.0)
- sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
- opt = model_average_optimizer.ModelAverageOptimizer(
- opt=sgd_opt,
- num_worker=num_workers,
- ma_custom_getter=ma_coustom,
- is_chief=is_chief,
- interval_steps=steps)
- train_op = [
- opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]],
- global_step)
- ]
- easgd_hook = opt.make_session_run_hook()
+ with ops.device("/job:worker/task:" + str(worker_id)):
+ if worker_id == 0:
+ grads_0 = constant_op.constant(-1.0)
+ grads_1 = constant_op.constant(-1.0)
+ else:
+ grads_0 = constant_op.constant(-2.0)
+ grads_1 = constant_op.constant(-2.0)
+ sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
+ opt = model_average_optimizer.ModelAverageOptimizer(
+ opt=sgd_opt,
+ num_worker=num_workers,
+ ma_custom_getter=ma_coustom,
+ is_chief=is_chief,
+ interval_steps=steps)
+ train_op = [
+ opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]],
+ global_step)
+ ]
+ ma_hook = opt.make_session_run_hook()
# Creates MonitoredSession
sess = training.MonitoredTrainingSession(
- workers[worker_id].target, hooks=[easgd_hook])
+ workers[worker_id].target, hooks=[ma_hook])
sessions.append(sess)
graphs.append(graph)
diff --git a/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py
index ac04ad9911..f22e724528 100644
--- a/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py
@@ -46,7 +46,7 @@ class MovingAverageOptimizerTest(test.TestCase):
def _helpTestRun(self, use_resource=False):
for sequential_update in [True, False]:
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
orig_val0 = [1.0, 2.0]
orig_val1 = [3.0, 4.0]
var0 = variable_scope.get_variable(
@@ -165,7 +165,7 @@ class MovingAverageOptimizerTest(test.TestCase):
self.assertLess(avg_val1[i], orig_val1[i])
def testFailWhenSaverCreatedBeforeInitialized(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable([1.0], name='var', dtype=dtypes.float32)
opt = moving_average_optimizer.MovingAverageOptimizer(
gradient_descent.GradientDescentOptimizer(learning_rate=2.0))
@@ -187,7 +187,7 @@ class MovingAverageOptimizerTest(test.TestCase):
self.apply_gradients_called = True
return super(WrapperOptimizer, self).apply_gradients(*args, **kwargs)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var = variables.Variable([1.2], name='var', dtype=dtypes.float32)
loss = var ** 2
wrapper_opt = WrapperOptimizer(learning_rate=2.0)
diff --git a/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py b/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py
index 618d8eb18d..904aa9ab13 100644
--- a/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py
+++ b/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py
@@ -34,7 +34,7 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
"""
def testWrapper(self):
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
var1 = variables.Variable([3.0, 4.0], dtype=dtypes.float32)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtypes.float32)
@@ -92,7 +92,7 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
self.evaluate(slot1))
def testGradientClipping(self):
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
var1 = variables.Variable([3.0, 4.0], dtype=dtypes.float32)
var2 = variables.Variable([3.0, 4.0], dtype=dtypes.float32)
diff --git a/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py
index 825c08a09a..85e05ce71c 100644
--- a/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py
@@ -53,7 +53,7 @@ class NadamOptimizerTest(test.TestCase):
def doTestSparse(self, use_resource=False):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -106,7 +106,7 @@ class NadamOptimizerTest(test.TestCase):
def doTestBasic(self, use_resource=False):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
diff --git a/tensorflow/contrib/opt/python/training/powersign.py b/tensorflow/contrib/opt/python/training/powersign.py
index 828f3c51c9..b4aa19264d 100644
--- a/tensorflow/contrib/opt/python/training/powersign.py
+++ b/tensorflow/contrib/opt/python/training/powersign.py
@@ -65,7 +65,7 @@ class PowerSignOptimizer(optimizer.Optimizer):
Example usage for PowerSign-cd (PowerSign with cosine sign decay)
```
decay_steps = 1000
-    linear_decay_fn = sign_decays.get_linear_decay_fn(decay_steps)
-    opt = PowerSignOptimizer(learning_rate=0.1, sign_decay_fn=linear_decay_fn)
+    cosine_decay_fn = sign_decays.get_cosine_decay_fn(decay_steps)
+    opt = PowerSignOptimizer(learning_rate=0.1, sign_decay_fn=cosine_decay_fn)
```
diff --git a/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py b/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py
index ea56e1646a..c09e2ac76d 100644
--- a/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py
@@ -36,7 +36,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def doTestBasic(self, use_locking=False, use_resource=False):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
if use_resource:
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
@@ -73,7 +73,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable(
[[1.0, 2.0], [3.0, 4.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
@@ -92,7 +92,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testTensorLearningRate(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -116,7 +116,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testSparseBasic(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
grads0 = ops.IndexedSlices(
@@ -144,7 +144,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testSparseRepeatedIndices(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
repeated_index_update_var = variables.Variable(
[[1.0], [2.0]], dtype=dtype)
aggregated_update_var = variables.Variable([[1.0], [2.0]], dtype=dtype)
@@ -170,7 +170,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testSparseRepeatedIndicesResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var_repeated = resource_variable_ops.ResourceVariable(
[1.0, 2.0], dtype=dtype)
loss_repeated = math_ops.reduce_sum(
@@ -194,7 +194,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testSparseStability(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
shape = [1, 6]
var0 = variables.Variable(
[[
@@ -230,7 +230,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testSharing(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -263,7 +263,7 @@ class RegAdagradOptimizerTest(test.TestCase):
np.array([2.715679168701172, 3.715679168701172]), var1.eval())
def testDynamicShapeVariable_Ok(self):
- with self.test_session():
+ with self.cached_session():
v = variable_scope.get_variable(
"v", initializer=constant_op.constant(1.), validate_shape=False)
self.assertFalse(v.shape.is_fully_defined())
@@ -274,7 +274,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testSkipUpdatingSlots(self):
iav = 0.130005 # A value that works with float16
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -306,7 +306,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testSparseSkipUpdatingSlots(self):
iav = 0.130005 # A value that works with float16
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
grads0 = ops.IndexedSlices(
diff --git a/tensorflow/contrib/opt/python/training/shampoo.py b/tensorflow/contrib/opt/python/training/shampoo.py
index 294627f42a..f161521b97 100644
--- a/tensorflow/contrib/opt/python/training/shampoo.py
+++ b/tensorflow/contrib/opt/python/training/shampoo.py
@@ -23,6 +23,7 @@ from __future__ import division
from __future__ import print_function
import numpy as np
+from tensorflow.contrib.opt.python.training import matrix_functions
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -76,7 +77,7 @@ class ShampooOptimizer(optimizer.Optimizer):
learning_rate=1.0,
svd_interval=1,
precond_update_interval=1,
- epsilon=0.1,
+ epsilon=1e-4,
alpha=0.5,
use_iterative_root=False,
use_locking=False,
@@ -255,81 +256,18 @@ class ShampooOptimizer(optimizer.Optimizer):
def _compute_power_iter(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name,
iter_count=100, epsilon=1e-6):
- """Computes mat_g^alpha, where alpha = -1/p, p a positive integer.
+ """Computes mat_g^alpha, where alpha = -1/p, p a positive integer."""
+
+ mat_g_sqrt = matrix_functions.matrix_square_root(mat_g, mat_g_size,
+ iter_count, self._epsilon)
+ mat_h = matrix_functions.matrix_inverse_pth_root(
+ mat_g_sqrt,
+ mat_g_size,
+ 2 * alpha,
+ iter_count,
+ epsilon,
+ ridge_epsilon=0.0)
- We use an iterative Schur-Newton method from equation 3.2 on page 9 of:
-
- A Schur-Newton Method for the Matrix p-th Root and its Inverse
- by Chun-Hua Guo and Nicholas J. Higham
- SIAM Journal on Matrix Analysis and Applications,
- 2006, Vol. 28, No. 3 : pp. 788-804
- https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf
-
- Args:
- var: the variable we are updating.
- mat_g: the symmetric PSD matrix whose power it to be computed
- mat_g_size: size of mat_g.
- alpha: exponent, must be -1/p for p a positive integer.
- mat_h_slot_name: name of slot to store the power, if needed.
- iter_count: Maximum number of iterations.
- epsilon: accuracy indicator, useful for early termination.
-
- Returns:
- mat_g^alpha
- """
-
- identity = linalg_ops.eye(math_ops.to_int32(mat_g_size))
-
- def MatPower(mat_m, p):
- """Computes mat_m^p, for p a positive integer.
-
- Power p is known at graph compile time, so no need for loop and cond.
- Args:
- mat_m: a square matrix
- p: a positive integer
-
- Returns:
- mat_m^p
- """
- assert p == int(p) and p > 0
- power = None
- while p > 0:
- if p % 2 == 1:
- power = math_ops.matmul(mat_m, power) if power is not None else mat_m
- p //= 2
- mat_m = math_ops.matmul(mat_m, mat_m)
- return power
-
- def IterCondition(i, mat_m, _):
- return math_ops.logical_and(
- i < iter_count,
- math_ops.reduce_max(math_ops.abs(mat_m - identity)) > epsilon)
-
- def IterBody(i, mat_m, mat_x):
- mat_m_i = (1 - alpha) * identity + alpha * mat_m
- return (i + 1, math_ops.matmul(MatPower(mat_m_i, -1.0/alpha), mat_m),
- math_ops.matmul(mat_x, mat_m_i))
-
- if mat_g_size == 1:
- mat_h = math_ops.pow(mat_g + self._epsilon, alpha)
- else:
- damped_mat_g = mat_g + self._epsilon * identity
- z = (1 - 1 / alpha) / (2 * linalg_ops.norm(damped_mat_g))
- # The best value for z is
- # (1 - 1/alpha) * (c_max^{-alpha} - c_min^{-alpha}) /
- # (c_max^{1-alpha} - c_min^{1-alpha})
- # where c_max and c_min are the largest and smallest singular values of
- # damped_mat_g.
- # The above estimate assumes that c_max > c_min * 2^p. (p = -1/alpha)
- # Can replace above line by the one below, but it is less accurate,
- # hence needs more iterations to converge.
- # z = (1 - 1/alpha) / math_ops.trace(damped_mat_g)
- # If we want the method to always converge, use z = 1 / norm(damped_mat_g)
- # or z = 1 / math_ops.trace(damped_mat_g), but these can result in many
- # extra iterations.
- _, _, mat_h = control_flow_ops.while_loop(
- IterCondition, IterBody,
- [0, damped_mat_g * z, identity * math_ops.pow(z, -alpha)])
if mat_h_slot_name is not None:
return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h)
return mat_h
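The new body computes the answer in two stages, first mat_g^{1/2} and then the (2 * alpha) power of that square root, relying on the identity (mat_g^{1/2})^{2 * alpha} = mat_g^alpha (the second call passes ridge_epsilon=0.0, presumably because the ridge term is already folded in when the square root is formed). The identity is easy to check in NumPy (illustrative helper, not the optimizer's code):

```python
import numpy as np

def sym_power(mat, exponent):
  """Reference power of a symmetric PSD matrix via eigendecomposition."""
  w, v = np.linalg.eigh(mat)
  return (v * w ** exponent) @ v.T

np.random.seed(0)
a = np.random.rand(6, 6)
mat_g = a @ a.T + 1e-4 * np.eye(6)  # damped, symmetric PSD
alpha = -1.0 / 4                    # target is mat_g^{-1/4}
direct = sym_power(mat_g, alpha)
composed = sym_power(sym_power(mat_g, 0.5), 2 * alpha)
assert np.allclose(direct, composed, atol=1e-6)
```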
@@ -422,6 +360,8 @@ class ShampooOptimizer(optimizer.Optimizer):
mat_gbar_weight_t * precond_update_interval, i),
lambda: mat_g)
+ mat_g_updated = mat_g_updated / float(shape[i].value)
+
if self._svd_interval == 1:
mat_h = self._compute_power(var, mat_g_updated, shape[i], neg_alpha)
else:
@@ -443,7 +383,13 @@ class ShampooOptimizer(optimizer.Optimizer):
name="precond_" + str(i))
else:
# Tensor size is too large -- perform diagonal Shampoo update
- grad_outer = math_ops.reduce_sum(grad * grad, axis=axes)
+ # Only normalize non-vector cases.
+ if axes:
+ normalizer = 1.0 if indices is not None else float(shape[i].value)
+ grad_outer = math_ops.reduce_sum(grad * grad, axis=axes) / normalizer
+ else:
+ grad_outer = grad * grad
+
if i == 0 and indices is not None:
assert self._mat_gbar_decay == 1.0
mat_g_updated = state_ops.scatter_add(mat_g, indices,
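In this diagonal branch the preconditioner for dimension i collapses to the per-slice sums of squared gradients, normalized by the dimension size in the dense non-vector case, just like the full-matrix statistics above. A NumPy sketch of the dense two-dimensional case, mirroring testLargeMatrix in the test diff below (names are illustrative):

```python
import numpy as np

ridge_epsilon = 1e-4
grad = np.random.rand(1000, 3)  # dim 0 too large for a full preconditioner
# Diagonal statistic for dim 0: one entry per row, normalized by its size.
mat_g1 = np.sum(grad * grad, axis=1, keepdims=True) / grad.shape[0]
mat_left = np.power(mat_g1 + ridge_epsilon, -0.25)
# Dim 1 is small, so it keeps a full (normalized) preconditioner.
mat_g2 = np.dot(grad.T, grad) / grad.shape[1]
# The diagonal factor applies by broadcasting across each row.
preconditioned_left = grad * mat_left
```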
diff --git a/tensorflow/contrib/opt/python/training/shampoo_test.py b/tensorflow/contrib/opt/python/training/shampoo_test.py
index 2e0a202ae2..05bcf2cfa3 100644
--- a/tensorflow/contrib/opt/python/training/shampoo_test.py
+++ b/tensorflow/contrib/opt/python/training/shampoo_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
TOLERANCE = 1e-3
+RIDGE_EPSILON = 1e-4
def np_power(mat_g, alpha):
@@ -52,7 +53,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np = np.random.rand(size)
grad_np_2 = np.random.rand(size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(
0, dtype=dtypes.int64, use_resource=use_resource_var)
var = variables.Variable(
@@ -77,8 +78,8 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# let us compute this in numpy
# Update rule is var = var - lr * mat_g^{-0.5} * grad
# lr = 1
- mat_g = np.outer(grad_np, grad_np)
- mat_h = np_power(mat_g + 0.1 * np.eye(size), -0.5)
+ mat_g = np.outer(grad_np, grad_np) / grad_np.shape[0]
+ mat_h = np_power(mat_g + RIDGE_EPSILON * np.eye(size), -0.5)
new_val_np = init_var_np - np.dot(mat_h, grad_np)
self.assertAllCloseAccordingToType(new_val_np, new_val,
@@ -88,8 +89,8 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
update_2.run()
new_val = sess.run(var)
- mat_g += np.outer(grad_np_2, grad_np_2)
- mat_h = np_power(mat_g + 0.1 * np.eye(size), -0.5)
+ mat_g += np.outer(grad_np_2, grad_np_2) / grad_np.shape[0]
+ mat_h = np_power(mat_g + RIDGE_EPSILON * np.eye(size), -0.5)
new_val_np -= np.dot(mat_h, grad_np_2)
self.assertAllCloseAccordingToType(new_val_np, new_val,
@@ -103,7 +104,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np = np.random.rand(size[0], size[1])
grad_np_2 = np.random.rand(size[0], size[1])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(
0, dtype=dtypes.int64, use_resource=use_resource_var)
var = variables.Variable(
@@ -128,10 +129,10 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# let us compute this in numpy
# Update rule is var = var - lr * mat_g1^{-0.25} * grad * mat_g2^{-0.25}
# lr = 1
- mat_g1 = np.dot(grad_np, grad_np.transpose())
- mat_left = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.25)
- mat_g2 = np.dot(grad_np.transpose(), grad_np)
- mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ mat_g1 = np.dot(grad_np, grad_np.transpose()) / grad_np.shape[0]
+ mat_left = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.25)
+ mat_g2 = np.dot(grad_np.transpose(), grad_np) / grad_np.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np = init_var_np - np.dot(np.dot(mat_left, grad_np), mat_right)
self.assertAllCloseAccordingToType(new_val_np, new_val,
@@ -141,10 +142,10 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
update_2.run()
new_val = sess.run(var)
- mat_g1 += np.dot(grad_np_2, grad_np_2.transpose())
- mat_left = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.25)
- mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2)
- mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ mat_g1 += np.dot(grad_np_2, grad_np_2.transpose()) / grad_np_2.shape[0]
+ mat_left = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.25)
+ mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) / grad_np_2.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np -= np.dot(np.dot(mat_left, grad_np_2), mat_right)
self.assertAllCloseAccordingToType(new_val_np, new_val,
@@ -162,7 +163,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np = np.random.rand(size[0], size[1], size[2])
grad_np_2 = np.random.rand(size[0], size[1], size[2])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(
0, dtype=dtypes.int64, use_resource=use_resource_var)
var = variables.Variable(
@@ -188,12 +189,18 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# let us compute this in numpy
# Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad
# lr = 1
- mat_g1 = np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2]))
- mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
- mat_g2 = np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2]))
- mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
- mat_g3 = np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1]))
- mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+ mat_g1 = (
+ np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2])) /
+ grad_np.shape[0])
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0)
+ mat_g2 = (
+ np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2])) /
+ grad_np.shape[1])
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0)
+ mat_g3 = (
+ np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1])) /
+ grad_np.shape[2])
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0)
precond_grad = np.tensordot(grad_np, mat_g1_a, axes=([0], [0]))
precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
@@ -207,12 +214,18 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
update_2.run()
new_val = sess.run(var)
- mat_g1 += np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2]))
- mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
- mat_g2 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2]))
- mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
- mat_g3 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1]))
- mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+ mat_g1 += (
+ np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2])) /
+ grad_np_2.shape[0])
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0)
+ mat_g2 += (
+ np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2])) /
+ grad_np_2.shape[1])
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0)
+ mat_g3 += (
+ np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1])) /
+ grad_np_2.shape[2])
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0)
precond_grad = np.tensordot(grad_np_2, mat_g1_a, axes=([0], [0]))
precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
@@ -240,7 +253,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np = np.random.rand(size)
grad_np_2 = np.random.rand(size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(
0, dtype=dtypes.int64, use_resource=use_resource_var)
var = variables.Variable(
@@ -265,19 +278,21 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# let us compute this in numpy
# Update rule is var = var - lr * gg^{-0.5} * grad
# lr = 1
- mat_g = grad_np * grad_np + 0.1
- new_val_np = init_var_np - np.power(mat_g, -0.5) * grad_np
-
- self.assertAllCloseAccordingToType(new_val_np, new_val)
+ mat_g = (grad_np * grad_np)
+ new_val_np = init_var_np - np.power(mat_g + RIDGE_EPSILON, -0.5) * grad_np
+ self.assertAllCloseAccordingToType(
+ new_val_np, new_val, atol=TOLERANCE, rtol=TOLERANCE)
# Run another step of Shampoo
update_2.run()
new_val = sess.run(var)
- mat_g += grad_np_2 * grad_np_2
- new_val_np -= np.power(mat_g, -0.5) * grad_np_2
+ mat_g += (grad_np_2 * grad_np_2)
+ new_val_np -= np.power(mat_g + RIDGE_EPSILON, -0.5) * grad_np_2
+
+ self.assertAllCloseAccordingToType(
+ new_val_np, new_val, atol=TOLERANCE, rtol=TOLERANCE)
- self.assertAllCloseAccordingToType(new_val_np, new_val)
@parameterized.named_parameters(('Var', False), ('ResourceVar', True))
def testLargeMatrix(self, use_resource_var):
@@ -294,7 +309,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np = np.random.rand(size[0], size[1])
grad_np_2 = np.random.rand(size[0], size[1])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(
0, dtype=dtypes.int64, use_resource=use_resource_var)
var = variables.Variable(
@@ -322,10 +337,11 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# with broadcasting
# lr = 1
- mat_g1 = np.sum(grad_np * grad_np, axis=1, keepdims=True)
- mat_left = np.power(mat_g1 + 0.1, -0.25)
- mat_g2 = np.dot(grad_np.transpose(), grad_np)
- mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ mat_g1 = np.sum(
+ grad_np * grad_np, axis=1, keepdims=True) / grad_np.shape[0]
+ mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25)
+ mat_g2 = np.dot(grad_np.transpose(), grad_np) / grad_np.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np = init_var_np - np.dot(grad_np * mat_left, mat_right)
self.assertAllCloseAccordingToType(new_val_np, new_val,
@@ -335,10 +351,11 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
update_2.run()
new_val = sess.run(var)
- mat_g1 += np.sum(grad_np_2 * grad_np_2, axis=1, keepdims=True)
- mat_left = np.power(mat_g1 + 0.1, -0.25)
- mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2)
- mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ mat_g1 += np.sum(
+ grad_np_2 * grad_np_2, axis=1, keepdims=True) / grad_np_2.shape[0]
+ mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25)
+ mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) / grad_np_2.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np -= np.dot(grad_np_2 * mat_left, mat_right)
self.assertAllCloseAccordingToType(new_val_np, new_val,
@@ -365,7 +382,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
replace=False))
grad_np_2 = np.random.rand(sample_size_2, size[1])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(
0, dtype=dtypes.int64, use_resource=use_resource_var)
var = variables.Variable(
@@ -405,9 +422,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g1 = np.sum(grad_np * grad_np, axis=1, keepdims=True)
mat_g1_acc = np.zeros((size[0], 1))
mat_g1_acc[grad_indices] += mat_g1
- mat_left = np.power(mat_g1 + 0.1, -0.25)
- mat_g2 = np.dot(grad_np.transpose(), grad_np)
- mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25)
+ mat_g2 = np.dot(grad_np.transpose(), grad_np) / grad_np.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np = init_var_np
new_val_np[grad_indices, :] -= np.dot(grad_np * mat_left, mat_right)
@@ -420,9 +437,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g1 = np.sum(grad_np_2 * grad_np_2, axis=1, keepdims=True)
mat_g1_acc[grad_indices_2] += mat_g1
- mat_left = np.power(mat_g1_acc[grad_indices_2] + 0.1, -0.25)
- mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2)
- mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ mat_left = np.power(mat_g1_acc[grad_indices_2] + RIDGE_EPSILON, -0.25)
+ mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) / grad_np_2.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np[grad_indices_2, :] -= np.dot(grad_np_2 * mat_left, mat_right)
self.assertAllCloseAccordingToType(new_val_np, new_val,
@@ -445,7 +462,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
replace=False))
grad_np = np.random.rand(sample_size, size[1], size[2])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(
0, dtype=dtypes.int64, use_resource=use_resource_var)
var = variables.Variable(
@@ -474,12 +491,15 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_dense = np.zeros_like(init_var_np)
grad_dense[grad_indices] = grad_np
- mat_g1 = np.tensordot(grad_dense, grad_dense, axes=([1, 2], [1, 2]))
- mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
- mat_g2 = np.tensordot(grad_dense, grad_dense, axes=([0, 2], [0, 2]))
- mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
- mat_g3 = np.tensordot(grad_dense, grad_dense, axes=([0, 1], [0, 1]))
- mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+ mat_g1 = np.tensordot(
+ grad_dense, grad_dense, axes=([1, 2], [1, 2])) / grad_dense.shape[0]
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0)
+ mat_g2 = np.tensordot(
+ grad_dense, grad_dense, axes=([0, 2], [0, 2])) / grad_dense.shape[1]
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0)
+ mat_g3 = np.tensordot(
+ grad_dense, grad_dense, axes=([0, 1], [0, 1])) / grad_dense.shape[2]
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0)
precond_grad = np.tensordot(grad_dense, mat_g1_a, axes=([0], [0]))
precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
@@ -512,7 +532,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
gbar_decay = 0.9
gbar_weight = 0.1
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(
0, dtype=dtypes.int64, use_resource=use_resource_var)
var = variables.Variable(
@@ -536,12 +556,15 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# let us compute this in numpy
# Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad
# lr = 1
- mat_g1 = np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2]))
- mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
- mat_g2 = np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2]))
- mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
- mat_g3 = np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1]))
- mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+ mat_g1 = np.tensordot(
+ grad_np, grad_np, axes=([1, 2], [1, 2])) / grad_np.shape[0]
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0)
+ mat_g2 = np.tensordot(
+ grad_np, grad_np, axes=([0, 2], [0, 2])) / grad_np.shape[1]
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0)
+ mat_g3 = np.tensordot(
+ grad_np, grad_np, axes=([0, 1], [0, 1])) / grad_np.shape[2]
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0)
gbar_np = gbar_weight * grad_np
precond_grad = np.tensordot(gbar_np, mat_g1_a, axes=([0], [0]))
@@ -556,12 +579,15 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
update_2.run()
new_val = sess.run(var)
- mat_g1 += np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2]))
- mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
- mat_g2 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2]))
- mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
- mat_g3 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1]))
- mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+ mat_g1 += np.tensordot(
+ grad_np_2, grad_np_2, axes=([1, 2], [1, 2])) / grad_np_2.shape[0]
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0)
+ mat_g2 += np.tensordot(
+ grad_np_2, grad_np_2, axes=([0, 2], [0, 2])) / grad_np_2.shape[1]
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0)
+ mat_g3 += np.tensordot(
+ grad_np_2, grad_np_2, axes=([0, 1], [0, 1])) / grad_np_2.shape[2]
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0)
gbar_np_2 = gbar_decay * gbar_np + gbar_weight * grad_np_2
precond_grad = np.tensordot(gbar_np_2, mat_g1_a, axes=([0], [0]))
@@ -601,7 +627,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g3_a = np.eye(size[2])
mat_g3 = np.zeros_like(mat_g3_a)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(
0, dtype=dtypes.int64, use_resource=use_resource_var)
var = variables.Variable(
@@ -626,13 +652,19 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# let us compute this in numpy
# Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad
# lr = 1
- mat_g1 += np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2]))
- mat_g2 += np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2]))
- mat_g3 += np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1]))
+ mat_g1 += np.tensordot(
+ grad_np[i], grad_np[i], axes=([1, 2], [1, 2])) / grad_np[i].shape[0]
+ mat_g2 += np.tensordot(
+ grad_np[i], grad_np[i], axes=([0, 2], [0, 2])) / grad_np[i].shape[1]
+ mat_g3 += np.tensordot(
+ grad_np[i], grad_np[i], axes=([0, 1], [0, 1])) / grad_np[i].shape[2]
if (i + 1) % svd_interval == 0:
- mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
- mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
- mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]),
+ -0.5 / 3.0)
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]),
+ -0.5 / 3.0)
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]),
+ -0.5 / 3.0)
precond_grad = np.tensordot(grad_np[i], mat_g1_a, axes=([0], [0]))
precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
@@ -672,7 +704,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g3_a = np.eye(size[2])
mat_g3 = np.zeros_like(mat_g3_a)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(
0, dtype=dtypes.int64, use_resource=use_resource_var)
var = variables.Variable(
@@ -700,17 +732,23 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad
# lr = 1
if (i + 1) % precond_update_interval == 0:
- mat_g1 += (np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2]))
- * precond_update_interval)
- mat_g2 += (np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2]))
- * precond_update_interval)
- mat_g3 += (np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1]))
- * precond_update_interval)
+ mat_g1 += (
+ np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2])) /
+ grad_np[i].shape[0] * precond_update_interval)
+ mat_g2 += (
+ np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2])) /
+ grad_np[i].shape[1] * precond_update_interval)
+ mat_g3 += (
+ np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1])) /
+ grad_np[i].shape[2] * precond_update_interval)
if (i + 1) % svd_interval == 0:
- mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
- mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
- mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]),
+ -0.5 / 3.0)
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]),
+ -0.5 / 3.0)
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]),
+ -0.5 / 3.0)
precond_grad = np.tensordot(grad_np[i], mat_g1_a, axes=([0], [0]))
precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
diff --git a/tensorflow/contrib/opt/python/training/sign_decay_test.py b/tensorflow/contrib/opt/python/training/sign_decay_test.py
index c31cb924ea..3a84789afd 100644
--- a/tensorflow/contrib/opt/python/training/sign_decay_test.py
+++ b/tensorflow/contrib/opt/python/training/sign_decay_test.py
@@ -66,7 +66,7 @@ class SignDecaysTest(test.TestCase):
linear_decay_fn = sign_decay.get_linear_decay_fn(num_training_steps)
for step in range(0, 1000, 100):
- with self.test_session():
+ with self.cached_session():
tf_decayed = linear_decay_fn(step).eval()
py_decayed = py_linear_decay_fn(num_training_steps)(step)
self.assertAlmostEqual(tf_decayed, py_decayed, places=4)
@@ -78,7 +78,7 @@ class SignDecaysTest(test.TestCase):
num_training_steps, num_periods=5, zero_after=2)
for step in range(0, 1000, 100):
- with self.test_session():
+ with self.cached_session():
tf_decayed = cosine_decay_fn(step).eval()
py_decayed = py_cosine_decay_fn(num_training_steps)(step)
self.assertAlmostEqual(tf_decayed, py_decayed, places=4)
@@ -95,7 +95,7 @@ class SignDecaysTest(test.TestCase):
num_training_steps, num_periods=5, zero_after=2)
for step in range(0, 1000, 100):
- with self.test_session():
+ with self.cached_session():
tf_decayed = restart_decay_fn(step).eval()
py_decayed = py_restart_decay_fn(num_training_steps)(step)
self.assertAlmostEqual(tf_decayed, py_decayed, places=4)
diff --git a/tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py b/tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py
index fdda86b0b5..ff0ea8d766 100644
--- a/tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py
@@ -158,7 +158,7 @@ class VariableClippingOptimizerTest(test.TestCase):
def testDenseLocal(self):
for dtype in [dtypes.float32, dtypes.float64, dtypes.half]:
- with self.test_session():
+ with self.cached_session():
var0, var1, update_op = self._setupDense(False, dtype)
self._assertDenseCorrect(var0, var1, update_op)
@@ -171,7 +171,7 @@ class VariableClippingOptimizerTest(test.TestCase):
def testSparseLocal(self):
for dtype in [dtypes.float64, dtypes.float32, dtypes.half]:
- with self.test_session():
+ with self.cached_session():
var0, var1, update_op = self._setupSparse(False, dtype)
self._assertSparseCorrect(var0, var1, update_op)
diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
index b9cf40eb7b..200b0d2008 100644
--- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
+++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.opt.python.training import shampoo
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
@@ -26,6 +27,7 @@ from tensorflow.python.training import adam
from tensorflow.python.training import momentum as momentum_opt
from tensorflow.python.training import optimizer
from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.ops import array_ops
class DecoupledWeightDecayExtension(object):
@@ -159,8 +161,8 @@ class DecoupledWeightDecayExtension(object):
def _decay_weights_sparse_op(self, var, indices, scatter_add):
if not self._decay_var_list or var in self._decay_var_list:
- return scatter_add(var, indices, -self._weight_decay * var,
- self._use_locking)
+ update = -self._weight_decay * array_ops.gather(var, indices)
+ return scatter_add(var, indices, update, self._use_locking)
return control_flow_ops.no_op()
# Here, we overwrite the apply functions that the base optimizer calls.
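The sparse fix matters because scatter_add expects an update shaped like the gathered rows; passing -weight_decay * var handed it a full-variable-sized update and would decay rows that received no gradient at all. The intended row-wise behavior, in a toy NumPy sketch (names are illustrative):

```python
import numpy as np

weight_decay = 0.1
var = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
indices = np.array([0, 2])  # only these rows received gradients

update = -weight_decay * var[indices]  # like gather(var, indices)
np.add.at(var, indices, update)        # like scatter_add(var, indices, ...)
# Rows 0 and 2 shrink by 10%; row 1 is untouched.
assert np.allclose(var, [[0.9, 0.9], [2.0, 2.0], [2.7, 2.7]])
```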
@@ -360,3 +362,74 @@ class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer):
super(AdamWOptimizer, self).__init__(
weight_decay, learning_rate=learning_rate, beta1=beta1, beta2=beta2,
epsilon=epsilon, use_locking=use_locking, name=name)
+
+
+@tf_export("contrib.opt.ShampooWOptimizer")
+class ShampooWOptimizer(DecoupledWeightDecayExtension,
+ shampoo.ShampooOptimizer):
+ """Optimizer that implements the Shampoo algorithm with weight decay.
+
+ For further information see the documentation of the Shampoo Optimizer.
+ """
+
+ def __init__(self,
+ weight_decay,
+ global_step,
+ max_matrix_size=768,
+ gbar_decay=0.0,
+ gbar_weight=1.0,
+ mat_gbar_decay=1.0,
+ mat_gbar_weight=1.0,
+ learning_rate=1.0,
+ svd_interval=1,
+ precond_update_interval=1,
+ epsilon=1e-4,
+ alpha=0.5,
+ use_iterative_root=False,
+ use_locking=False,
+ name="ShampooW"):
+ """Construct a new ShampooW optimizer.
+
+ For further information see the documentation of the Shampoo Optimizer.
+
+ Args:
+ weight_decay: A `Tensor` or a floating point value. The weight decay.
+ global_step: tensorflow variable indicating the step.
+ max_matrix_size: We do not perform SVD for matrices larger than this.
+ gbar_decay:
+ gbar_weight: Used to update gbar: gbar[t] = gbar_decay[t] * gbar[t-1] +
+ gbar_weight[t] * g[t]
+ mat_gbar_decay:
+ mat_gbar_weight: Used to update mat_gbar: mat_gbar_j[t] =
+ mat_gbar_decay[t] * mat_gbar_j[t-1] + mat_gbar_weight[t] * gg_j[t]
+ learning_rate: Similar to SGD
+ svd_interval: We should do SVD after this many steps. Default = 1, i.e.
+ every step. Usually 20 leads to no loss of accuracy, and 50 or 100 is
+      also OK. You may also want to do it more often early and less often
+      later; set in the caller, for example:
+      "svd_interval = lambda T: tf.cond(
+          T < 2000, lambda: 20.0, lambda: 1000.0)"
+ precond_update_interval: We should update the preconditioners after this
+ many steps. Default = 1. Usually less than svd_interval.
+ epsilon: epsilon * I_n is added to each mat_gbar_j for stability
+ alpha: total power of the preconditioners.
+      use_iterative_root: whether to use the iterative root method (needed on
+        TPU) instead of SVD (faster) for finding the roots of PSD matrices.
+ use_locking: If `True` use locks for update operations.
+ name: name of optimizer.
+ """
+ super(ShampooWOptimizer, self).__init__(
+ weight_decay,
+ global_step=global_step,
+ max_matrix_size=max_matrix_size,
+ gbar_decay=gbar_decay,
+ gbar_weight=gbar_weight,
+        mat_gbar_decay=mat_gbar_decay,
+        mat_gbar_weight=mat_gbar_weight,
+ learning_rate=learning_rate,
+ svd_interval=svd_interval,
+ precond_update_interval=precond_update_interval,
+ epsilon=epsilon,
+ alpha=alpha,
+ use_iterative_root=use_iterative_root,
+ use_locking=use_locking,
+ name=name)
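A minimal usage sketch, assuming the TF 1.x contrib API surface; the variable and loss stand in for a real model:

```python
import tensorflow as tf

var = tf.Variable([[1.0, 2.0], [3.0, 4.0]])
loss = tf.reduce_sum(var * var)
global_step = tf.train.get_or_create_global_step()
opt = tf.contrib.opt.ShampooWOptimizer(
    weight_decay=1e-4, global_step=global_step, learning_rate=1.0)
train_op = opt.minimize(loss, global_step=global_step)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)
```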
diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py
index 76d8a5697a..9c91078301 100644
--- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py
+++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py
@@ -58,7 +58,7 @@ class WeightDecayOptimizerTest(test.TestCase):
def doTest(self, optimizer, update_fn, optimizer_name, slot_name,
use_resource=False, do_sparse=False):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
diff --git a/tensorflow/contrib/optimizer_v2/adadelta_test.py b/tensorflow/contrib/optimizer_v2/adadelta_test.py
index 31cfec0d50..4c94b66679 100644
--- a/tensorflow/contrib/optimizer_v2/adadelta_test.py
+++ b/tensorflow/contrib/optimizer_v2/adadelta_test.py
@@ -37,7 +37,7 @@ class AdadeltaOptimizerTest(test.TestCase):
for dtype in [dtypes.half, dtypes.float32]:
for grad in [0.2, 0.1, 0.01]:
for lr in [1.0, 0.5, 0.1]:
- with self.test_session():
+ with self.cached_session():
var0_init = [1.0, 2.0]
var1_init = [3.0, 4.0]
if use_resource:
@@ -146,7 +146,7 @@ class AdadeltaOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
diff --git a/tensorflow/contrib/optimizer_v2/adagrad_test.py b/tensorflow/contrib/optimizer_v2/adagrad_test.py
index 18191c3ef2..debaaaeeba 100644
--- a/tensorflow/contrib/optimizer_v2/adagrad_test.py
+++ b/tensorflow/contrib/optimizer_v2/adagrad_test.py
@@ -36,7 +36,7 @@ class AdagradOptimizerTest(test.TestCase):
def doTestBasic(self, use_locking=False, use_resource=False):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
if use_resource:
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
@@ -73,7 +73,7 @@ class AdagradOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable(
[[1.0, 2.0], [3.0, 4.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
@@ -92,7 +92,7 @@ class AdagradOptimizerTest(test.TestCase):
def testTensorLearningRate(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -116,7 +116,7 @@ class AdagradOptimizerTest(test.TestCase):
def testSparseBasic(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
grads0 = ops.IndexedSlices(
@@ -147,7 +147,7 @@ class AdagradOptimizerTest(test.TestCase):
def testSparseRepeatedIndices(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
repeated_index_update_var = variables.Variable(
[[1.0], [2.0]], dtype=dtype)
aggregated_update_var = variables.Variable(
@@ -177,7 +177,7 @@ class AdagradOptimizerTest(test.TestCase):
def testSparseRepeatedIndicesResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var_repeated = resource_variable_ops.ResourceVariable(
[1.0, 2.0], dtype=dtype)
loss_repeated = math_ops.reduce_sum(
@@ -201,7 +201,7 @@ class AdagradOptimizerTest(test.TestCase):
def testSparseStability(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
shape = [1, 6]
var0 = variables.Variable(
[[
@@ -237,7 +237,7 @@ class AdagradOptimizerTest(test.TestCase):
def testSharing(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -270,7 +270,7 @@ class AdagradOptimizerTest(test.TestCase):
np.array([2.715679168701172, 3.715679168701172]), var1.eval())
def testDynamicShapeVariable_Ok(self):
- with self.test_session():
+ with self.cached_session():
v = variable_scope.get_variable("v", initializer=constant_op.constant(1.),
validate_shape=False)
self.assertFalse(v.shape.is_fully_defined())
diff --git a/tensorflow/contrib/optimizer_v2/adam.py b/tensorflow/contrib/optimizer_v2/adam.py
index 631d4f44df..04b1552b61 100644
--- a/tensorflow/contrib/optimizer_v2/adam.py
+++ b/tensorflow/contrib/optimizer_v2/adam.py
@@ -40,15 +40,14 @@ class AdamOptimizer(optimizer_v2.OptimizerV2):
Initialization:
- $$m_0 := 0 (Initialize initial 1st moment vector)$$
- $$v_0 := 0 (Initialize initial 2nd moment vector)$$
- $$t := 0 (Initialize timestep)$$
-
+  $$m_0 := 0 \text{ (Initialize 1st moment vector)}$$
+  $$v_0 := 0 \text{ (Initialize 2nd moment vector)}$$
+  $$t := 0 \text{ (Initialize timestep)}$$
The update rule for `variable` with gradient `g` uses an optimization
described at the end of section 2 of the paper:
$$t := t + 1$$
- $$lr_t := \text{learning_rate} * \sqrt{(1 - beta_2^t) / (1 - beta_1^t)}$$
+ $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
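The corrected rule applies the square root only to the beta_2 bias correction, matching the efficient formulation at the end of section 2 of the Adam paper. A scalar NumPy sketch of the bias-corrected step (values are illustrative):

```python
import numpy as np

lr, beta1, beta2, epsilon = 0.001, 0.9, 0.999, 1e-8
m, v, theta, g = 0.0, 0.0, 1.0, 0.5
for t in range(1, 4):
  m = beta1 * m + (1 - beta1) * g
  v = beta2 * v + (1 - beta2) * g * g
  lr_t = lr * np.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
  theta -= lr_t * m / (np.sqrt(v) + epsilon)
```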
diff --git a/tensorflow/contrib/optimizer_v2/adam_test.py b/tensorflow/contrib/optimizer_v2/adam_test.py
index d9ad58b0a6..b1ad0ade42 100644
--- a/tensorflow/contrib/optimizer_v2/adam_test.py
+++ b/tensorflow/contrib/optimizer_v2/adam_test.py
@@ -56,7 +56,7 @@ class AdamOptimizerTest(test.TestCase):
def doTestSparse(self, use_resource=False):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -122,7 +122,7 @@ class AdamOptimizerTest(test.TestCase):
def testSparseRepeatedIndices(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
repeated_index_update_var = variables.Variable(
[[1.0], [2.0]], dtype=dtype)
aggregated_update_var = variables.Variable(
@@ -152,7 +152,7 @@ class AdamOptimizerTest(test.TestCase):
def doTestBasic(self, use_resource=False):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -215,7 +215,7 @@ class AdamOptimizerTest(test.TestCase):
opt.get_slot(var=var0, name="m").name)
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
self.doTestBasic(use_resource=False)
@test_util.run_in_graph_and_eager_modes(reset_test=True)
@@ -224,7 +224,7 @@ class AdamOptimizerTest(test.TestCase):
def testTensorLearningRate(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -261,7 +261,7 @@ class AdamOptimizerTest(test.TestCase):
def testSharing(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index 28a531dfec..e13b82d1d2 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -310,7 +310,7 @@ class CheckpointingTests(test.TestCase):
global_step=root.global_step)
checkpoint_path = checkpoint_management.latest_checkpoint(
checkpoint_directory)
- with self.test_session(graph=ops.get_default_graph()) as session:
+ with self.session(graph=ops.get_default_graph()) as session:
status = root.restore(save_path=checkpoint_path)
status.initialize_or_restore(session=session)
if checkpoint_path is None:
@@ -504,7 +504,7 @@ class CheckpointingTests(test.TestCase):
"""Saves after the first should not modify the graph."""
with context.graph_mode():
graph = ops.Graph()
- with graph.as_default(), self.test_session(graph):
+ with graph.as_default(), self.session(graph):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
obj = tracking.Checkpointable()
@@ -522,7 +522,7 @@ class CheckpointingTests(test.TestCase):
"""Restores after the first should not modify the graph."""
with context.graph_mode():
graph = ops.Graph()
- with graph.as_default(), self.test_session(graph):
+ with graph.as_default(), self.session(graph):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
obj = tracking.Checkpointable()
diff --git a/tensorflow/contrib/optimizer_v2/gradient_descent_test.py b/tensorflow/contrib/optimizer_v2/gradient_descent_test.py
index ad9aef804f..4a77bce478 100644
--- a/tensorflow/contrib/optimizer_v2/gradient_descent_test.py
+++ b/tensorflow/contrib/optimizer_v2/gradient_descent_test.py
@@ -34,7 +34,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testBasic(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -57,7 +57,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testBasicResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -82,7 +82,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testMinimizeResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
@@ -108,7 +108,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
@@ -135,7 +135,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testTensorLearningRate(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -157,7 +157,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testGradWrtRef(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
opt = gradient_descent.GradientDescentOptimizer(3.0)
values = [1.0, 3.0]
vars_ = [variables.Variable([v], dtype=dtype) for v in values]
@@ -168,7 +168,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testWithGlobalStep(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
global_step = variables.Variable(0, trainable=False)
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
@@ -191,7 +191,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testSparseBasic(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
grads0 = ops.IndexedSlices(
diff --git a/tensorflow/contrib/optimizer_v2/momentum_test.py b/tensorflow/contrib/optimizer_v2/momentum_test.py
index 24cdab4626..e69f12839e 100644
--- a/tensorflow/contrib/optimizer_v2/momentum_test.py
+++ b/tensorflow/contrib/optimizer_v2/momentum_test.py
@@ -123,7 +123,7 @@ class MomentumOptimizerTest(test.TestCase):
]), self.evaluate(var1))
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
self.doTestBasic(use_resource=False)
@test_util.run_in_graph_and_eager_modes(reset_test=True)
@@ -162,7 +162,7 @@ class MomentumOptimizerTest(test.TestCase):
def testNesterovMomentum(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -188,7 +188,7 @@ class MomentumOptimizerTest(test.TestCase):
def testSparseNesterovMomentum(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
@@ -282,7 +282,7 @@ class MomentumOptimizerTest(test.TestCase):
def testTensorLearningRateAndMomentum(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -435,7 +435,7 @@ class MomentumOptimizerTest(test.TestCase):
return db_grad, db_out
def testLikeDistBeliefMom01(self):
- with self.test_session():
+ with self.cached_session():
db_grad, db_out = self._dbParamsMom01()
num_samples = len(db_grad)
var0 = variables.Variable([0.0] * num_samples)
@@ -449,7 +449,7 @@ class MomentumOptimizerTest(test.TestCase):
def testSparse(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable(array_ops.zeros([4, 2], dtype=dtype))
var1 = variables.Variable(constant_op.constant(1.0, dtype, [4, 2]))
grads0 = ops.IndexedSlices(
@@ -518,7 +518,7 @@ class MomentumOptimizerTest(test.TestCase):
def testSharing(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index f6ecaba834..6af59dcfbf 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -214,7 +214,8 @@ class _OptimizerV2State(object):
# with that Tensor cast to that dtype.
with ops.init_scope():
self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)}
- for name, (dynamic, value) in hyper.items() if not dynamic}
+ for name, (dynamic, value) in sorted(hyper.items())
+ if not dynamic}
self._slots = {}
self._non_slot_dict = {}
# Extra state to help Optimizers implement Checkpointable. Holds information
@@ -231,7 +232,8 @@ class _OptimizerV2State(object):
ret._deferred_dependencies = self._deferred_dependencies
ret._deferred_slot_restorations = self._deferred_slot_restorations
ret._hyper = {name: {None: _resolve(value, name)}
- for name, (dynamic, value) in hyper.items() if dynamic}
+ for name, (dynamic, value) in sorted(hyper.items())
+ if dynamic}
ret._hyper.update(self._hyper)
ret._non_slot_devices = non_slot_devices
ret._distribution = distribution
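Sorting hyper.items() here makes the construction of _hyper deterministic: on interpreters without guaranteed dict ordering, iterating an unsorted dict could create the convert_to_tensor ops in a different order from run to run, perturbing op names and graph serialization. A small pure-Python sketch of the idea (hypothetical values, not from the optimizer code):

    # Iteration order of a dict is an implementation detail on older
    # Pythons; sorting by key makes the op-creation order reproducible.
    hyper = {'momentum': (False, 0.9), 'learning_rate': (False, 0.1)}

    for name, (dynamic, value) in sorted(hyper.items()):
      if not dynamic:
        print(name, value)  # always learning_rate first, then momentum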
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py
index a44bfd1bfd..dd7f2f4405 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py
@@ -61,7 +61,7 @@ class OptimizerTest(test.TestCase):
def testAggregationMethod(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
cost = 5 * var0 + 3 * var1
@@ -86,7 +86,7 @@ class OptimizerTest(test.TestCase):
def testPrecomputedGradient(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
cost = 5 * var0 + 3 * var1
@@ -212,7 +212,7 @@ class OptimizerTest(test.TestCase):
sgd_op.apply_gradients(grads_and_vars)
def testTrainOp(self):
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0])
var1 = variables.Variable([3.0, 4.0])
cost = 5 * var0 + 3 * var1
@@ -225,7 +225,7 @@ class OptimizerTest(test.TestCase):
def testConstraint(self):
constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0],
constraint=constraint_01)
var1 = variables.Variable([3.0, 4.0],
@@ -247,7 +247,7 @@ class OptimizerTest(test.TestCase):
self.assertAllClose([0., 0.], var1.eval())
def testStopGradients(self):
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], name='var0')
var1 = variables.Variable([3.0, 4.0], name='var1')
var0_id = array_ops.identity(var0)
diff --git a/tensorflow/contrib/optimizer_v2/rmsprop_test.py b/tensorflow/contrib/optimizer_v2/rmsprop_test.py
index 628d0418dd..44301ffe9e 100644
--- a/tensorflow/contrib/optimizer_v2/rmsprop_test.py
+++ b/tensorflow/contrib/optimizer_v2/rmsprop_test.py
@@ -162,7 +162,7 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
@parameterized.parameters([dtypes.float32, dtypes.float64])
def testMinimizeSparseResourceVariable(self, dtype):
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
@@ -184,7 +184,7 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
@parameterized.parameters([dtypes.float32, dtypes.float64])
def testMinimizeSparseResourceVariableCentered(self, dtype):
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h
index 42fba81a5c..85b5a5a3b9 100644
--- a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h
+++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h
@@ -14,8 +14,8 @@
// limitations under the License.
// =============================================================================
-#ifndef TENSORFLOW_KERNELS_PERIODICRESAMPLE_OP_H_
-#define TENSORFLOW_KERNELS_PERIODICRESAMPLE_OP_H_
+#ifndef TENSORFLOW_CONTRIB_PERIODIC_RESAMPLE_KERNELS_PERIODIC_RESAMPLE_OP_H_
+#define TENSORFLOW_CONTRIB_PERIODIC_RESAMPLE_KERNELS_PERIODIC_RESAMPLE_OP_H_
#include <cmath>
#include <type_traits>
@@ -421,4 +421,4 @@ class PeriodicResampleOpGrad : public tensorflow::OpKernel {
tensorflow::PartialTensorShape desired_shape;
};
-#endif // TENSORFLOW_KERNELS_PERIODICRESAMPLE_OP_H_
+#endif // TENSORFLOW_CONTRIB_PERIODIC_RESAMPLE_KERNELS_PERIODIC_RESAMPLE_OP_H_
diff --git a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
index 31a6fe1d94..9a19502276 100644
--- a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
+++ b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
@@ -38,7 +38,7 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
desired_shape = numpy.array([6, None])
output_tensor = input_tensor.reshape((6, 2))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
result = periodic_resample(input_tensor, desired_shape).eval()
self.assertAllEqual(result, output_tensor)
@@ -49,7 +49,7 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
desired_shape = numpy.array([5, None])
output_tensor = input_tensor.reshape((6, 2))[:-1]
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
result = periodic_resample(input_tensor, desired_shape).eval()
self.assertAllEqual(result, output_tensor)
@@ -63,7 +63,7 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
[15]]])
# NOTE: output_tensor != input_tensor.reshape((4, 4, -1))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
result = periodic_resample(input_tensor, desired_shape).eval()
# input_tensor[0, 0, 0] == result[0, 0, 0]
@@ -88,14 +88,14 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
[[49], [53], [57], [61]], [[51], [55], [59], [63]]]])
# NOTE: output_tensor != input_tensor.reshape((4, 4, 4, -1))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
result = periodic_resample(input_tensor, desired_shape).eval()
self.assertAllEqual(result, output_tensor)
def testPeriodicResampleErrors(self):
input_tensor = numpy.zeros(shape=[1, 2, 2, 4])
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError,
'Dimension 3 input tensor has size 4, desired shape has size 1'):
@@ -109,7 +109,7 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
desired_shape = numpy.array([4, 4, None])
result_shape = (4, 4, 1)
input_shape = (2, 2, 4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32, shape=input_shape)
output = periodic_resample(x, desired_shape)
error = gradient_checker.compute_gradient_error(
@@ -117,7 +117,7 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
self.assertLess(error, 1e-4)
def testPeriodicResampleShapeInference(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Case 1: output shape can be fully inferred.
x = array_ops.placeholder(dtypes.float32, shape=(2, 2, 4))
output = periodic_resample(x, [4, 4, None])
diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py
index e3570e38a3..17b69c7b35 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py
+++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py
@@ -170,7 +170,7 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
field_names = [f.name for f in fields]
output_types = [f.dtype for f in fields]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sizes, vtensor = self._decode_module.decode_proto(
batch,
message_type=message_type,
@@ -290,7 +290,7 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
field_names = ['sizes']
field_types = [dtypes.int32]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ctensor, vtensor = self._decode_module.decode_proto(
batch,
message_type=msg_type,
diff --git a/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test_base.py
index 9a1c04af32..7e9b355c69 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test_base.py
+++ b/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test_base.py
@@ -137,7 +137,7 @@ class DescriptorSourceTestBase(test.TestCase):
field_names = ['values', 'shapes', 'sizes', 'fields']
tensor_types = [dtypes.string, dtypes.int32, dtypes.int32, dtypes.string]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sizes, field_tensors = self._decode_module.decode_proto(
in_bufs,
message_type=message_type,
diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py
index 07dfb924d3..01b3ccc7fd 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py
+++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py
@@ -55,7 +55,7 @@ class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
def testBadInputs(self):
# Invalid field name
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError('Unknown field: non_existent_field'):
self._encode_module.encode_proto(
sizes=[[1]],
@@ -64,7 +64,7 @@ class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
field_names=['non_existent_field']).eval()
# Incorrect types.
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError(
'Incompatible type for field double_value.'):
self._encode_module.encode_proto(
@@ -74,7 +74,7 @@ class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
field_names=['double_value']).eval()
# Incorrect shapes of sizes.
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError(
r'sizes should be batch_size \+ \[len\(field_names\)\]'):
sizes = array_ops.placeholder(dtypes.int32)
@@ -89,7 +89,7 @@ class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
})
# Inconsistent shapes of values.
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError(
'Values must match up to the last dimension'):
sizes = array_ops.placeholder(dtypes.int32)
@@ -109,7 +109,7 @@ class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
field_names = [f.name for f in fields]
out_types = [f.dtype for f in fields]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sizes, field_tensors = self._decode_module.decode_proto(
in_bufs,
message_type=message_type,
diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD
index 23363617ed..c59f667f6a 100644
--- a/tensorflow/contrib/quantize/BUILD
+++ b/tensorflow/contrib/quantize/BUILD
@@ -22,6 +22,7 @@ py_test(
":common",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
"//tensorflow/python:variable_scope",
@@ -89,7 +90,6 @@ py_library(
":common",
":graph_matcher",
":input_to_ops",
- "//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
@@ -171,7 +171,6 @@ py_library(
":graph_matcher",
":input_to_ops",
":quant_ops",
- "//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
@@ -244,7 +243,9 @@ py_test(
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:init_ops",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
+ "//tensorflow/python:training",
],
)
diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py
index bf648e158e..b27117dd48 100644
--- a/tensorflow/contrib/quantize/python/common.py
+++ b/tensorflow/contrib/quantize/python/common.py
@@ -131,3 +131,29 @@ def DropStringPrefix(s, prefix):
return s[len(prefix):]
else:
return s
+
+
+def RerouteTensor(t0, t1, can_modify=None):
+ """Reroute the end of the tensor t0 to the ends of the tensor t1.
+
+ Args:
+ t0: a tf.Tensor.
+ t1: a tf.Tensor.
+ can_modify: iterable of operations which can be modified. Any operation
+ outside can_modify will be left untouched by this function.
+
+ Returns:
+ The number of individual modifications made by the function.
+ """
+ nb_update_inputs = 0
+ consumers = t1.consumers()
+ if can_modify is not None:
+ consumers = [c for c in consumers if c in can_modify]
+ consumers_indices = {}
+ for c in consumers:
+ consumers_indices[c] = [i for i, t in enumerate(c.inputs) if t is t1]
+ for c in consumers:
+ for i in consumers_indices[c]:
+ c._update_input(i, t0) # pylint: disable=protected-access
+ nb_update_inputs += 1
+ return nb_update_inputs
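RerouteTensor is a local replacement for the dropped tf.contrib.graph_editor dependency: every consumer of t1 (optionally restricted to the ops in can_modify) has the matching input rewired to t0, and the count of rewired inputs is returned. A minimal usage sketch in graph mode, mirroring the unit test added below:

    from tensorflow.contrib.quantize.python import common
    from tensorflow.python.framework import constant_op
    from tensorflow.python.ops import math_ops

    a = constant_op.constant(1.0)
    b = constant_op.constant(2.0)
    c = constant_op.constant(3.0)

    add_ac = math_ops.add(a, c)
    # add_ac now consumes b in place of a; exactly one input was updated.
    n = common.RerouteTensor(b, a, can_modify=[add_ac.op])
    assert n == 1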
diff --git a/tensorflow/contrib/quantize/python/common_test.py b/tensorflow/contrib/quantize/python/common_test.py
index 06c62f2d26..2b26302f8a 100644
--- a/tensorflow/contrib/quantize/python/common_test.py
+++ b/tensorflow/contrib/quantize/python/common_test.py
@@ -20,8 +20,10 @@ from __future__ import print_function
from tensorflow.contrib.quantize.python import common
from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -62,6 +64,29 @@ class CommonTest(test_util.TensorFlowTestCase):
_, step_val = sess.run([b, quantization_step_tensor])
self.assertEqual(step_val, 2)
+ def testRerouteTensor(self):
+ a = constant_op.constant(1, name='a')
+ b = constant_op.constant(2, name='b')
+ c = constant_op.constant(3, name='c')
+ d = constant_op.constant(4, name='d')
+
+ add_ac = math_ops.add(a, c)
+ add_ad = math_ops.add(a, d)
+
+ # Ensure that before rerouting the inputs are what we think.
+ self._CheckOpHasInputs(add_ac.op, [a, c])
+ self._CheckOpHasInputs(add_ad.op, [a, d])
+
+ # References to tensor a should be replaced with b for all ops in
+ # can_modify. This means add_ac will be changed but add_ad will not.
+ common.RerouteTensor(b, a, can_modify=[add_ac.op])
+ self._CheckOpHasInputs(add_ac.op, [b, c])
+ self._CheckOpHasInputs(add_ad.op, [a, d])
+
+ def _CheckOpHasInputs(self, op, inputs):
+ for i in inputs:
+ self.assertIn(i, op.inputs)
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index d9f179bee4..2971b28f45 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import re
-from tensorflow.contrib import graph_editor
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.contrib.quantize.python import input_to_ops
@@ -134,8 +133,8 @@ def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
bias_add_tensor = math_ops.add(
new_layer_tensor, bias_tensor, name='add_fold')
- nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor,
- match.output_tensor)
+ nodes_modified_count = common.RerouteTensor(bias_add_tensor,
+ match.output_tensor)
if nodes_modified_count == 0:
raise ValueError('Folding batch norms failed, %s had no outputs.' %
match.output_tensor.name)
@@ -370,8 +369,9 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
lambda: match.bn_decay_mean_tensor,
name='freeze_moving_mean')
- graph_editor.reroute_ts(
- [bn_decay_mean_out], [match.bn_decay_mean_tensor],
+ common.RerouteTensor(
+ bn_decay_mean_out,
+ match.bn_decay_mean_tensor,
can_modify=bn_decay_mean_consumers)
bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
@@ -380,8 +380,9 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
lambda: bn_decay_zero,
lambda: match.bn_decay_var_tensor,
name='freeze_moving_var')
- graph_editor.reroute_ts(
- [bn_decay_var_out], [match.bn_decay_var_tensor],
+ common.RerouteTensor(
+ bn_decay_var_out,
+ match.bn_decay_var_tensor,
can_modify=bn_decay_var_consumers)
correction_recip = utils.smart_cond(
@@ -486,9 +487,8 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
activation = common.GetEndpointActivationOp(graph, bn)
if activation:
- nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
- [original_op.outputs[0]],
- can_modify=[activation])
+ nodes_modified_count = common.RerouteTensor(
+ folded_op.outputs[0], original_op.outputs[0], can_modify=[activation])
if nodes_modified_count != 1:
raise ValueError('Unexpected inputs to op: %s' % activation.name)
continue
@@ -497,9 +497,8 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
# operations instead of Relu* above.
add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
- nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
- [original_op.outputs[0]],
- can_modify=[add_bypass])
+ nodes_modified_count = common.RerouteTensor(
+ folded_op.outputs[0], original_op.outputs[0], can_modify=[add_bypass])
if nodes_modified_count != 1:
raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index 2ddbd73ea6..e88db0acd5 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import re
-from tensorflow.contrib import graph_editor
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.contrib.quantize.python import input_to_ops
@@ -592,8 +591,8 @@ def _InsertQuantOp(context,
name=name_prefix + '/delayed_quant')
if consumers:
- tensors_modified_count = graph_editor.reroute_ts(
- [quant], [inputs], can_modify=consumers)
+ tensors_modified_count = common.RerouteTensor(
+ quant, inputs, can_modify=consumers)
# Some operations can have multiple output tensors going to the same
# consumer. Since consumers is a set, we need to ensure that
# tensors_modified_count is greater than or equal to the length of the set
diff --git a/tensorflow/contrib/quantize/python/quantize_graph.py b/tensorflow/contrib/quantize/python/quantize_graph.py
index 2944f964c7..484493f1b2 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph.py
@@ -59,6 +59,10 @@ def _create_graph(input_graph=None,
if input_graph is None:
input_graph = ops.get_default_graph()
+
+ # Check whether the graph already contains training ops; if so, raise an
+ # error and exit.
+ _check_for_training_ops(input_graph)
with input_graph.as_default():
fold_batch_norms.FoldBatchNorms(
input_graph,
@@ -78,6 +82,9 @@ def create_training_graph(input_graph=None, quant_delay=0):
Variables added by the rewrite get added to the global variables collection.
+ This function must be invoked prior to insertion of gradient ops in a graph
+ as quantization should be modeled in both forward and backward passes.
+
The graph has fake quantization ops inserted to simulate the error
introduced by quantization. Since the graph is transformed in place,
the expected behavior of previously held references to nodes and tensors may
@@ -104,7 +111,6 @@ def create_training_graph(input_graph=None, quant_delay=0):
# Currently the values below are hardcoded for mobilenetV1 on imagenet
# Please use the experimental API if you need to tune these values.
freeze_bn_delay = None
-
_create_graph(
input_graph=input_graph,
is_training=True,
@@ -141,6 +147,9 @@ def experimental_create_training_graph(input_graph=None,
scope=None):
"""Rewrites a training input_graph in place for simulated quantization.
+ This function must be invoked prior to insertion of gradient ops in a graph
+ as quantization should be modeled in both forward and backward passes.
+
Variables added by the rewrite get added to the global variables collection.
This function has additional experimental options not (yet) available to
@@ -226,3 +235,45 @@ def experimental_create_eval_graph(input_graph=None,
activation_bits=activation_bits,
quant_delay=quant_delay,
scope=scope)
+
+
+def _check_for_training_ops(g):
+ """Check if training ops are present in the graph.
+
+ Args:
+ g: The tf.Graph on which the check for training ops needs to be
+ performed.
+
+ Raises:
+ ValueError: If a training op is seen in the graph;
+ """
+
+ # The list here is obtained
+ # from https://www.tensorflow.org/api_docs/cc/group/training-ops
+ training_ops = frozenset([
+ 'ApplyAdagrad', 'ApplyAdagradDA', 'ApplyAdam', 'ApplyAddSign',
+ 'ApplyCenteredRMSProp', 'ApplyFtrl', 'ApplyFtrlV2',
+ 'ApplyGradientDescent', 'ApplyMomentum', 'ApplyPowerSign',
+ 'ApplyProximalAdagrad', 'ApplyProximalGradientDescent', 'ApplyRMSProp',
+ 'ResourceApplyAdadelta', 'ResourceApplyAdagrad', 'ResourceApplyAdagradDA',
+ 'ResourceApplyAdam', 'ResourceApplyAddSign',
+ 'ResourceApplyCenteredRMSProp', 'ResourceApplyFtrl',
+ 'ResourceApplyFtrlV2', 'ResourceApplyGradientDescent',
+ 'ResourceApplyMomentum', 'ResourceApplyPowerSign',
+ 'ResourceApplyProximalAdagrad', 'ResourceApplyProximalGradientDescent',
+ 'ResourceApplyRMSProp', 'ResourceSparseApplyAdadelta',
+ 'ResourceSparseApplyAdagrad', 'ResourceSparseApplyAdagradDA',
+ 'ResourceSparseApplyCenteredRMSProp', 'ResourceSparseApplyFtrl',
+ 'ResourceSparseApplyFtrlV2', 'ResourceSparseApplyMomentum',
+ 'ResourceSparseApplyProximalAdagrad',
+ 'ResourceSparseApplyProximalGradientDescent',
+ 'ResourceSparseApplyRMSProp', 'SparseApplyAdadelta', 'SparseApplyAdagrad',
+ 'SparseApplyAdagradDA', 'SparseApplyCenteredRMSProp', 'SparseApplyFtrl',
+ 'SparseApplyFtrlV2', 'SparseApplyMomentum', 'SparseApplyProximalAdagrad',
+ 'SparseApplyProximalGradientDescent', 'SparseApplyRMSProp'
+ ])
+
+ op_types = set([op.type for op in g.get_operations()])
+ train_op_list = op_types.intersection(training_ops)
+ if train_op_list:
+ raise ValueError('Training op found in graph, exiting %s' % train_op_list)
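The effect of the new check is that the rewrite refuses to run on a graph that already contains optimizer apply ops, since fake quantization must be inserted before gradients are built to be modeled in both the forward and backward passes. A sketch of the failure mode, assuming the TF 1.x tf.layers and tf.train APIs (hypothetical model, not from this patch):

    import tensorflow as tf

    g = tf.Graph()
    with g.as_default():
      x = tf.layers.conv2d(tf.zeros([1, 4, 4, 3]), filters=2, kernel_size=3)
      loss = tf.reduce_sum(tf.square(x))
      # Adds ApplyGradientDescent ops to the graph.
      tf.train.GradientDescentOptimizer(0.0001).minimize(loss)

    # Now raises ValueError('Training op found in graph, ...') instead of
    # silently producing a graph quantized only in the forward pass.
    tf.contrib.quantize.create_training_graph(input_graph=g)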
diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py
index 54faf582f1..e80d2183a6 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py
@@ -20,10 +20,12 @@ from __future__ import print_function
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import quantize_graph
+from tensorflow.python import training
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest
@@ -145,6 +147,19 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
self.assertTrue(('int64_val: %i' % quant_delay) in const_value)
self.assertTrue(quant_delay_found)
+ def testTrainingOpsCheck(self):
+ self._RunTestOverTrainingRewrites(self._TestTrainingOpsCheck)
+
+ def _TestTrainingOpsCheck(self, rewrite_fn):
+ with ops.Graph().as_default():
+ output = self._ConvLayer()
+ output_scalar = math_ops.reduce_sum(output)
+ loss = math_ops.square(output_scalar - 1)
+ opt = training.gradient_descent.GradientDescentOptimizer(0.0001)
+ opt.minimize(loss)
+ with self.assertRaisesRegexp(ValueError, 'Training op found in graph'):
+ rewrite_fn()
+
def testWeightBits(self):
self._RunTestOverExperimentalRewrites(self._TestWeightBits)
diff --git a/tensorflow/contrib/rate/BUILD b/tensorflow/contrib/rate/BUILD
new file mode 100644
index 0000000000..c461a7145e
--- /dev/null
+++ b/tensorflow/contrib/rate/BUILD
@@ -0,0 +1,48 @@
+# Description:
+# Contains parts of TensorFlow that are experimental or unstable and not supported.
+
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//visibility:public"])
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_library(
+ name = "rate",
+ srcs = [
+ "rate.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ ],
+)
+
+py_test(
+ name = "rate_test",
+ size = "small",
+ srcs = ["rate_test.py"],
+ deps = [
+ ":rate",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:data_flow_ops",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:test",
+ ],
+)
diff --git a/tensorflow/contrib/rate/rate.py b/tensorflow/contrib/rate/rate.py
new file mode 100644
index 0000000000..24d586479a
--- /dev/null
+++ b/tensorflow/contrib/rate/rate.py
@@ -0,0 +1,151 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implementation of tf.contrib.rate module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+
+_to_replace = re.compile("[^A-Za-z0-9.]")
+
+
+class Rate(object):
+ """Computes the rate of change since the last rate call."""
+
+ def __init__(self, name=None):
+ self._built = False
+ self._vars = []
+ self._initial_values = {}
+ name = name or self.__class__.__name__
+ # Replace things like spaces in name to create a valid scope name.
+ scope_name = _to_replace.sub("_", name)
+ # We create the variable scope now to get the unique name that will
+ # be used as a variable prefix when build() calls _add_variable().
+ with variable_scope.variable_scope(
+ scope_name, use_resource=True, reuse=False) as scope:
+ pos = scope.name.rfind(scope_name)
+ self._name = name + scope.name[pos + len(scope_name):]
+ self._scope = scope
+
+ # Ensures that if the user calls build directly we still set self._built to
+ # True to prevent variables from being recreated.
+ self._build = self.build
+ if context.executing_eagerly():
+ self._construction_scope = context.eager_mode
+ else:
+ # We make self.call() into a graph callable here, so that we can
+ # return a single op that performs all of the variable updates.
+ self._construction_scope = ops.get_default_graph().as_default
+ self.call = function.defun(self.call)
+
+ def build(self, values, denominator):
+ """Method to create variables.
+
+ Called by `__call__()` before `call()` for the first time.
+
+ Args:
+ values: The numerator for the rate.
+ denominator: Value with respect to which the rate is taken.
+ """
+ self.numer = self._add_variable(
+ name="numer", shape=values.get_shape(), dtype=dtypes.float64)
+ self.denom = self._add_variable(
+ name="denom", shape=denominator.get_shape(), dtype=dtypes.float64)
+ self.prev_values = self._add_variable(
+ name="prev_values", shape=values.get_shape(), dtype=dtypes.float64)
+ self.prev_denominator = self._add_variable(
+ name="prev_denominator",
+ shape=denominator.get_shape(),
+ dtype=dtypes.float64)
+ self._built = True
+
+ def __call__(self, *args, **kwargs):
+ """Returns op to execute to update.
+
+ Returns None if eager execution is enabled.
+ Returns a graph-mode function if graph execution is enabled.
+
+ Args:
+ *args: A mini-batch of inputs to Rate, passed on to `call()`.
+ **kwargs: Keyword arguments, passed on to `call()`.
+ """
+ if not self._built:
+ with variable_scope.variable_scope(
+ self._scope), self._construction_scope():
+ self.build(*args, **kwargs)
+ self._built = True
+ return self.call(*args, **kwargs)
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def variables(self):
+ return self._vars
+
+ def _safe_div(self, numerator, denominator, name):
+ t = math_ops.truediv(numerator, denominator)
+ zero = array_ops.zeros_like(t, dtype=denominator.dtype)
+ condition = math_ops.greater(denominator, zero)
+ zero = math_ops.cast(zero, t.dtype)
+ return array_ops.where(condition, t, zero, name=name)
+
+ def _add_variable(self, name, shape=None, dtype=None):
+ """Private method for adding variables to the graph."""
+ if self._built:
+ raise RuntimeError("Can't call add_variable() except in build().")
+ v = resource_variable_ops.ResourceVariable(
+ lambda: array_ops.zeros(shape, dtype),
+ trainable=False,
+ validate_shape=True,
+ name=name,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ return v
+
+ def call(self, values, denominator):
+ """Computes the rate since the last call.
+
+ Args:
+ values: Tensor with the per-example value.
+ denominator: Measure to take the rate with respect to.
+
+ Returns:
+ The rate or 0 if denominator is unchanged since last call.
+ """
+ if denominator.dtype != dtypes.float64:
+ denominator = math_ops.cast(denominator, dtypes.float64)
+ if values.dtype != dtypes.float64:
+ values = math_ops.cast(values, dtypes.float64)
+
+ state_ops.assign(self.numer, math_ops.subtract(values, self.prev_values))
+ state_ops.assign(self.denom,
+ math_ops.subtract(denominator, self.prev_denominator))
+ state_ops.assign(self.prev_values, values)
+ state_ops.assign(self.prev_denominator, denominator)
+
+ return self._safe_div(self.numer, self.denom, name="safe_rate")
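A short usage sketch of the new Rate class in eager mode: each call returns (values - prev_values) / (denominator - prev_denominator), falling back to 0 when the denominator has not advanced (the _safe_div path). Assuming TF 1.x with eager execution enabled:

    import tensorflow as tf
    from tensorflow.contrib.rate import rate

    tf.enable_eager_execution()

    r = rate.Rate()
    print(r(tf.constant([1.]), denominator=tf.constant([1.])))  # [1.]
    print(r(tf.constant([3.]), denominator=tf.constant([2.])))  # (3-1)/(2-1) -> [2.]
    print(r(tf.constant([5.]), denominator=tf.constant([2.])))  # denom unchanged -> [0.]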
diff --git a/tensorflow/contrib/rate/rate_test.py b/tensorflow/contrib/rate/rate_test.py
new file mode 100644
index 0000000000..08908104f4
--- /dev/null
+++ b/tensorflow/contrib/rate/rate_test.py
@@ -0,0 +1,97 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Rate."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.rate import rate
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class RateTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testBuildRate(self):
+ m = rate.Rate()
+ m.build(
+ constant_op.constant([1], dtype=dtypes.float32),
+ constant_op.constant([2], dtype=dtypes.float32))
+ old_numer = m.numer
+ m(
+ constant_op.constant([2], dtype=dtypes.float32),
+ constant_op.constant([2], dtype=dtypes.float32))
+ self.assertTrue(old_numer is m.numer)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testBasic(self):
+ with self.test_session():
+ r_ = rate.Rate()
+ a = r_(array_ops.ones([1]), denominator=array_ops.ones([1]))
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(variables.local_variables_initializer())
+ self.assertEqual([[1]], self.evaluate(a))
+ b = r_(constant_op.constant([2]), denominator=constant_op.constant([2]))
+ self.assertEqual([[1]], self.evaluate(b))
+ c = r_(constant_op.constant([4]), denominator=constant_op.constant([3]))
+ self.assertEqual([[2]], self.evaluate(c))
+ d = r_(constant_op.constant([16]), denominator=constant_op.constant([3]))
+ self.assertEqual([[0]], self.evaluate(d)) # divide by 0
+
+ def testNamesWithSpaces(self):
+ m1 = rate.Rate(name="has space")
+ m1(array_ops.ones([1]), array_ops.ones([1]))
+ self.assertEqual(m1.name, "has space")
+ self.assertEqual(m1.prev_values.name, "has_space_1/prev_values:0")
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testWhileLoop(self):
+ with self.test_session():
+ r_ = rate.Rate()
+
+ def body(value, denom, i, ret_rate):
+ i += 1
+ ret_rate = r_(value, denom)
+ with ops.control_dependencies([ret_rate]):
+ value = math_ops.add(value, 2)
+ denom = math_ops.add(denom, 1)
+ return [value, denom, i, ret_rate]
+
+ def condition(v, d, i, r):
+ del v, d, r # unused vars by condition
+ return math_ops.less(i, 100)
+
+ i = constant_op.constant(0)
+ value = constant_op.constant([1], dtype=dtypes.float64)
+ denom = constant_op.constant([1], dtype=dtypes.float64)
+ ret_rate = r_(value, denom)
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(variables.local_variables_initializer())
+ loop = control_flow_ops.while_loop(condition, body,
+ [value, denom, i, ret_rate])
+ self.assertEqual([[2]], self.evaluate(loop[3]))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py b/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py
index f23194a6f2..1800edc05a 100644
--- a/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py
+++ b/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py
@@ -165,7 +165,7 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase):
fetches = self._CreateRnnGraph(
fn, cell, tf_inputs, tf_slen, is_bidirectional, time_major=time_major)
- with self.test_session(graph=graph) as sess:
+ with self.session(graph=graph) as sess:
sess.run(variables.global_variables_initializer())
# Note that cell.trainable_variables is not always set.
self._MaybeResetVariables(variable_cache, sess,
diff --git a/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py b/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py
index 00fbd4fbb8..aea80a5256 100644
--- a/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py
+++ b/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py
@@ -56,7 +56,7 @@ class RecurrentTest(test_util.TensorFlowTestCase):
x_power=state.x_power * theta.x)
return next_state, []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
theta = _PolyTheta(x=array_ops.constant(2.0))
state = _PolyState(
value=array_ops.constant(0.0),
@@ -142,7 +142,7 @@ class RecurrentTest(test_util.TensorFlowTestCase):
def _ParameterizedTestElman(self, seqlen, use_grad):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
random_seed.set_random_seed(342462)
batch = 3
diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
index 67a8f59c3c..c3db71359c 100644
--- a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
+++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
@@ -178,7 +178,8 @@ def _ApplyLengthsToBatch(sequence_lengths, tf_output):
# TODO(drpng): just use Update so that we don't carry over the gradients?
"""Sets the output to be zero at the end of the sequence."""
# output is batch major.
- batch_size, max_time, vector_size = tf_output.shape
+ shape = array_ops.shape(tf_output)
+ batch_size, max_time, vector_size = shape[0], shape[1], shape[2]
output_time = array_ops.tile(math_ops.range(0, max_time), [batch_size])
output_time = array_ops.reshape(output_time, [batch_size, max_time])
lengths = array_ops.tile(
@@ -278,11 +279,16 @@ def functional_rnn(cell, inputs, sequence_length=None,
if initial_state is None:
initial_state = cell.zero_state(batch_size, dtype)
func_cell = _FunctionalRnnCell(cell, inputs, initial_state)
+ if sequence_length is not None:
+ max_length = math_ops.reduce_max(sequence_length)
+ else:
+ max_length = None
extended_acc_state, extended_final_state = recurrent.Recurrent(
theta=func_cell.theta,
state0=func_cell.extended_initial_state,
inputs=inputs,
cell_fn=func_cell.cell_step,
+ max_input_length=max_length,
use_tpu=use_tpu)
tf_output, tf_state = _PostProcessOutput(
extended_acc_state, extended_final_state, func_cell,
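Two independent fixes land in functional_rnn.py: the static tf_output.shape unpack is replaced by array_ops.shape, because the batch dimension is often unknown (None) at graph-construction time and only resolvable at run time; and Recurrent is now passed the maximum sequence length so it can bound its loop. A sketch of the static-versus-dynamic distinction, assuming TF 1.x graph mode:

    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=(None, 10, 32))  # batch unknown
    print(x.shape)       # (?, 10, 32): unpacking yields None for the batch
    shape = tf.shape(x)  # int32 tensor, evaluated at run time instead
    batch_size, max_time, vector_size = shape[0], shape[1], shape[2]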
diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h
index d8c0a0631d..69ef521c01 100644
--- a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h
+++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_
-#define TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_
+#ifndef TENSORFLOW_CONTRIB_REDUCE_SLICE_OPS_KERNELS_REDUCE_SLICE_OPS_H_
+#define TENSORFLOW_CONTRIB_REDUCE_SLICE_OPS_KERNELS_REDUCE_SLICE_OPS_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
@@ -81,4 +81,4 @@ CALL_ALL_REDUCEOPS(ReduceSliceFunctorReduceop)
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_
+#endif // TENSORFLOW_CONTRIB_REDUCE_SLICE_OPS_KERNELS_REDUCE_SLICE_OPS_H_
diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD
index 5874245d58..4e67d80558 100644
--- a/tensorflow/contrib/rnn/BUILD
+++ b/tensorflow/contrib/rnn/BUILD
@@ -212,6 +212,7 @@ cuda_py_tests(
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
+ tags = ["noasan"],
)
tf_custom_op_library(
@@ -279,7 +280,10 @@ cuda_py_tests(
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
- tags = ["no_oss"],
+ tags = [
+ "no_oss",
+ "noasan",
+ ],
)
tf_cc_test(
@@ -287,6 +291,7 @@ tf_cc_test(
size = "small",
srcs = ["ops/gru_ops_test.cc"],
data = [":python/ops/_gru_ops.so"],
+ tags = ["noasan"],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
@@ -306,6 +311,7 @@ tf_cc_test(
size = "small",
srcs = ["ops/lstm_ops_test.cc"],
data = [":python/ops/_lstm_ops.so"],
+ tags = ["noasan"],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py
index cb437f2a2f..026bf08ced 100644
--- a/tensorflow/contrib/rnn/__init__.py
+++ b/tensorflow/contrib/rnn/__init__.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""RNN Cells and additional RNN operations.
-See @{$python/contrib.rnn} guide.
+See [Contrib RNN](https://tensorflow.org/api_guides/python/contrib.rnn) guide.
<!--From core-->
@@RNNCell
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 85f0f8ced9..be0306cb07 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -48,7 +48,7 @@ Linear = core_rnn_cell._Linear # pylint: disable=invalid-name
class RNNCellTest(test.TestCase):
def testLinear(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(1.0)):
x = array_ops.zeros([1, 2])
@@ -69,7 +69,7 @@ class RNNCellTest(test.TestCase):
self.assertEqual(len(variables_lib.trainable_variables()), 2)
def testBasicRNNCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -89,7 +89,7 @@ class RNNCellTest(test.TestCase):
self.assertEqual(res[0].shape, (1, 2))
def testBasicRNNCellNotTrainable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
def not_trainable_getter(getter, *args, **kwargs):
kwargs["trainable"] = False
@@ -116,7 +116,7 @@ class RNNCellTest(test.TestCase):
self.assertEqual(res[0].shape, (1, 2))
def testIndRNNCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -137,7 +137,7 @@ class RNNCellTest(test.TestCase):
self.assertEqual(res[0].shape, (1, 2))
def testGRUCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -165,7 +165,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.156736, 0.156736]])
def testIndyGRUCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -193,7 +193,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.155127, 0.157328]])
def testSRUCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -208,7 +208,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.509682, 0.509682]])
def testSRUCellWithDiffSize(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -225,7 +225,7 @@ class RNNCellTest(test.TestCase):
def testBasicLSTMCell(self):
for dtype in [dtypes.float16, dtypes.float32]:
np_dtype = dtype.as_numpy_dtype
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2], dtype=dtype)
@@ -288,7 +288,7 @@ class RNNCellTest(test.TestCase):
def testBasicLSTMCellDimension0Error(self):
"""Tests that dimension 0 in both(x and m) shape must be equal."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
num_units = 2
@@ -309,7 +309,7 @@ class RNNCellTest(test.TestCase):
def testBasicLSTMCellStateSizeError(self):
"""Tests that state_size must be num_units * 2."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
num_units = 2
@@ -329,7 +329,7 @@ class RNNCellTest(test.TestCase):
})
def testBasicLSTMCellStateTupleType(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -360,7 +360,7 @@ class RNNCellTest(test.TestCase):
self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple))
def testBasicLSTMCellWithStateTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -395,7 +395,7 @@ class RNNCellTest(test.TestCase):
def testIndyLSTMCell(self):
for dtype in [dtypes.float16, dtypes.float32]:
np_dtype = dtype.as_numpy_dtype
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2], dtype=dtype)
@@ -459,7 +459,7 @@ class RNNCellTest(test.TestCase):
self.assertEqual(len(res), 2)
def testLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 8
num_proj = 6
state_size = num_units + num_proj
@@ -494,7 +494,7 @@ class RNNCellTest(test.TestCase):
float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6)
def testLSTMCellVariables(self):
- with self.test_session():
+ with self.cached_session():
num_units = 8
num_proj = 6
state_size = num_units + num_proj
@@ -517,7 +517,7 @@ class RNNCellTest(test.TestCase):
"root/lstm_cell/projection/kernel")
def testLSTMCellLayerNorm(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
num_proj = 3
batch_size = 1
@@ -562,22 +562,21 @@ class RNNCellTest(test.TestCase):
rnn_cell_impl.DropoutWrapper,
rnn_cell_impl.ResidualWrapper,
lambda cell: rnn_cell_impl.MultiRNNCell([cell])]:
- with self.test_session():
- cell = rnn_cell_impl.BasicRNNCell(1)
- wrapper = wrapper_type(cell)
- wrapper(array_ops.ones([1, 1]),
- state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32))
- self.evaluate([v.initializer for v in cell.variables])
- checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper)
- prefix = os.path.join(self.get_temp_dir(), "ckpt")
- self.evaluate(cell._bias.assign([40.]))
- save_path = checkpoint.save(prefix)
- self.evaluate(cell._bias.assign([0.]))
- checkpoint.restore(save_path).assert_consumed().run_restore_ops()
- self.assertAllEqual([40.], self.evaluate(cell._bias))
+ cell = rnn_cell_impl.BasicRNNCell(1)
+ wrapper = wrapper_type(cell)
+ wrapper(array_ops.ones([1, 1]),
+ state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32))
+ self.evaluate([v.initializer for v in cell.variables])
+ checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper)
+ prefix = os.path.join(self.get_temp_dir(), "ckpt")
+ self.evaluate(cell._bias.assign([40.]))
+ save_path = checkpoint.save(prefix)
+ self.evaluate(cell._bias.assign([0.]))
+ checkpoint.restore(save_path).assert_consumed().run_restore_ops()
+ self.assertAllEqual([40.], self.evaluate(cell._bias))
def testOutputProjectionWrapper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -594,7 +593,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.231907, 0.231907]])
def testInputProjectionWrapper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -612,7 +611,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]])
def testResidualWrapper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -638,7 +637,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[2], res[3])
def testResidualWrapperWithSlice(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 5])
@@ -716,7 +715,7 @@ class RNNCellTest(test.TestCase):
self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name])
def testEmbeddingWrapper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 1], dtype=dtypes.int32)
@@ -735,7 +734,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.17139, 0.17139]])
def testEmbeddingWrapperWithDynamicRnn(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope("root"):
inputs = ops.convert_to_tensor([[[0], [0]]], dtype=dtypes.int64)
input_lengths = ops.convert_to_tensor([2], dtype=dtypes.int64)
@@ -753,7 +752,7 @@ class RNNCellTest(test.TestCase):
sess.run(outputs)
def testMultiRNNCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -770,7 +769,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res, [[0.175991, 0.175991, 0.13248, 0.13248]])
def testMultiRNNCellWithStateTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -809,7 +808,7 @@ class DropoutWrapperTest(test.TestCase):
time_steps=None,
parallel_iterations=None,
**kwargs):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
if batch_size is None and time_steps is None:
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
index d62ec45d18..bf699db3ed 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
@@ -457,7 +457,7 @@ class LSTMTest(test.TestCase):
input_size = 5
batch_size = 2
max_length = 8
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
state_saver = TestStateSaver(batch_size, num_units)
@@ -491,7 +491,7 @@ class LSTMTest(test.TestCase):
input_size = 5
batch_size = 2
max_length = 8
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
state_saver = TestStateSaver(
@@ -588,7 +588,7 @@ class LSTMTest(test.TestCase):
num_proj = 4
max_length = 8
sequence_length = [4, 6]
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
inputs = max_length * [
@@ -834,7 +834,7 @@ class LSTMTest(test.TestCase):
batch_size = 2
num_proj = 4
max_length = 8
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(-1, 1, seed=self._seed)
initializer_d = init_ops.random_uniform_initializer(
-1, 1, seed=self._seed + 1)
@@ -884,7 +884,7 @@ class LSTMTest(test.TestCase):
batch_size = 2
num_proj = 4
max_length = 8
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(-1, 1, seed=self._seed)
inputs = max_length * [
array_ops.placeholder(dtypes.float32, shape=(None, input_size))
@@ -930,7 +930,7 @@ class LSTMTest(test.TestCase):
max_length = 8
sequence_length = [4, 6]
in_graph_mode = not context.executing_eagerly()
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
if in_graph_mode:
@@ -1006,7 +1006,7 @@ class LSTMTest(test.TestCase):
max_length = 8
sequence_length = [4, 6]
in_graph_mode = not context.executing_eagerly()
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
if in_graph_mode:
@@ -1612,7 +1612,7 @@ class MultiDimensionalLSTMTest(test.TestCase):
batch_size = 2
max_length = 8
sequence_length = [4, 6]
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
inputs = max_length * [
array_ops.placeholder(dtypes.float32, shape=(None,) + input_size)
]
@@ -1723,7 +1723,7 @@ class NestedLSTMTest(test.TestCase):
state_size = 6
max_length = 8
sequence_length = [4, 6]
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
state_saver = TestStateSaver(batch_size, state_size)
single_input = (array_ops.placeholder(
dtypes.float32, shape=(None, input_size)),
@@ -1906,7 +1906,7 @@ class StateSaverRNNTest(test.TestCase):
state_saver = TestStateSaverWithCounters(batch_size, 2 * num_units)
out, state, state_saver = self._factory(scope=None, state_saver=state_saver)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
sess.run(variables_lib.local_variables_initializer())
@@ -2017,7 +2017,7 @@ class RawRNNTest(test.TestCase):
np.random.seed(self._seed)
def _testRawRNN(self, max_time):
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
batch_size = 16
input_depth = 4
num_units = 3
@@ -2126,7 +2126,7 @@ class RawRNNTest(test.TestCase):
self._testRawRNN(max_time=10)
def testLoopState(self):
- with self.test_session(graph=ops_lib.Graph()):
+ with self.session(graph=ops_lib.Graph()):
max_time = 10
batch_size = 16
input_depth = 4
@@ -2162,7 +2162,7 @@ class RawRNNTest(test.TestCase):
self.assertEqual([10], loop_state.eval())
def testLoopStateWithTensorArray(self):
- with self.test_session(graph=ops_lib.Graph()):
+ with self.session(graph=ops_lib.Graph()):
max_time = 4
batch_size = 16
input_depth = 4
@@ -2205,7 +2205,7 @@ class RawRNNTest(test.TestCase):
self.assertAllEqual([1, 2, 2 + 2, 4 + 3, 7 + 4], loop_state.eval())
def testEmitDifferentStructureThanCellOutput(self):
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
max_time = 10
batch_size = 16
input_depth = 4
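Where a test builds its own `ops_lib.Graph()`, the substitution is `self.session(graph=...)` rather than `cached_session()`: the cached session is bound to the test's default graph, so a freshly constructed graph needs a session scoped to it. A hedged sketch of that variant:

```python
# Hedged sketch: tests that construct their own Graph use
# self.session(graph=...), since the cached session is tied to the
# test's default graph (TF 1.x test utilities assumed).
import tensorflow as tf


class FreshGraphExampleTest(tf.test.TestCase):

  def testPerGraphSession(self):
    with self.session(graph=tf.Graph()) as sess:
      # Ops created here live in the fresh graph, isolated from any state
      # accumulated in the default graph earlier in the test.
      v = tf.Variable(1.0)
      sess.run(v.initializer)
      self.assertEqual(1.0, sess.run(v))


if __name__ == '__main__':
  tf.test.main()
```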
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py
index f2a032e41e..8d34b9e852 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py
@@ -38,7 +38,7 @@ class FusedRnnCellTest(test.TestCase):
def testBasicRNNFusedWrapper(self):
"""This test checks that using a wrapper for BasicRNN works as expected."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=19890212)
cell = rnn_cell.BasicRNNCell(10)
@@ -106,7 +106,7 @@ class FusedRnnCellTest(test.TestCase):
self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2)
def testTimeReversedFusedRNN(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=19890213)
fw_cell = rnn_cell.BasicRNNCell(10)
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index c7d85862f6..6689664fb9 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -47,7 +47,7 @@ from tensorflow.python.util import nest
class RNNCellTest(test.TestCase):
def testCoupledInputForgetGateLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
state_size = num_units * 2
batch_size = 3
@@ -81,7 +81,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1], expected_state)
def testTimeFreqLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 8
state_size = num_units * 2
batch_size = 3
@@ -120,7 +120,7 @@ class RNNCellTest(test.TestCase):
float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6)
def testGridLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 8
batch_size = 3
input_size = 4
@@ -166,7 +166,7 @@ class RNNCellTest(test.TestCase):
.state_f00_b00_c[i, :]))) > 1e-6)
def testGridLSTMCellWithFrequencyBlocks(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 8
batch_size = 3
feature_size = 2
@@ -248,7 +248,7 @@ class RNNCellTest(test.TestCase):
]],
dtype=np.float32)
for state_is_tuple in [False, True]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"state_is_tuple" + str(state_is_tuple),
initializer=init_ops.constant_initializer(0.5)):
@@ -294,7 +294,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
def testBidirectionGridLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
batch_size = 3
input_size = 4
@@ -374,7 +374,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
def testBidirectionGridLSTMCellWithSliceOffset(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
batch_size = 3
input_size = 4
@@ -487,7 +487,7 @@ class RNNCellTest(test.TestCase):
input_size = 4
for state_is_tuple in [False, True]:
with ops.Graph().as_default():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"state_is_tuple_" + str(state_is_tuple)):
lstm_cell = rnn_cell.BasicLSTMCell(
@@ -538,7 +538,7 @@ class RNNCellTest(test.TestCase):
batch_size = 3
for state_is_tuple in [False, True]:
with ops.Graph().as_default():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"state_is_tuple_" + str(state_is_tuple)):
lstm_cell = rnn_cell.BasicLSTMCell(
@@ -677,7 +677,7 @@ class RNNCellTest(test.TestCase):
0.79457647, 0.79457647, 0.79457647, 0.79457647, 0.79457653, 0.79457653,
0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"nas_test", initializer=init_ops.constant_initializer(0.5)):
cell = contrib_rnn_cell.NASCell(num_units=num_units)
@@ -725,7 +725,7 @@ class RNNCellTest(test.TestCase):
0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997,
1.87398517, 1.87398517, 1.87398517, 1.87398517, 1.87398517
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"nas_proj_test", initializer=init_ops.constant_initializer(0.5)):
cell = contrib_rnn_cell.NASCell(num_units=num_units, num_proj=num_proj)
@@ -765,7 +765,7 @@ class RNNCellTest(test.TestCase):
[[0.13752282, 0.13752282], [0.10545051, 0.10545051],
[0.10074195, 0.10074195]],
dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"ugrnn_cell_test", initializer=init_ops.constant_initializer(0.5)):
cell = contrib_rnn_cell.UGRNNCell(num_units=num_units)
@@ -796,7 +796,7 @@ class RNNCellTest(test.TestCase):
[[2.00431061, 2.00431061], [4.00060606, 4.00060606],
[6.00008249, 6.00008249]],
dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"intersection_rnn_cell_test",
initializer=init_ops.constant_initializer(0.5)):
@@ -837,7 +837,7 @@ class RNNCellTest(test.TestCase):
cell(inputs, init_state)
def testPhasedLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
batch_size = 3
input_size = 4
@@ -874,7 +874,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_state_h)
def testConv1DLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape = [2, 1]
filter_size = [3]
num_features = 1
@@ -907,7 +907,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_state_h)
def testConv2DLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape = [2, 2, 1]
filter_size = [3, 3]
num_features = 1
@@ -948,7 +948,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_state_h)
def testConv3DLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape = [2, 2, 2, 1]
filter_size = [3, 3, 3]
num_features = 1
@@ -999,7 +999,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_state_h)
def testHighwayWrapper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"base_cell", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -1030,7 +1030,7 @@ class RNNCellTest(test.TestCase):
# Try with input dimension equal to num_units or not.
for num_inputs in [num_units, num_units + number_of_groups]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root1_%d" % num_inputs,
initializer=init_ops.constant_initializer(0.5)):
@@ -1059,7 +1059,7 @@ class RNNCellTest(test.TestCase):
# Try with num_inputs equal to or not equal to num_units.
for num_inputs in [num_units, num_units + number_of_groups]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root2_%d" % num_inputs,
initializer=init_ops.constant_initializer(0.5)):
@@ -1092,7 +1092,7 @@ class RNNCellTest(test.TestCase):
batch_size = 2
num_units = 4
number_of_groups = 2
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope(
"glstm_failure", initializer=init_ops.constant_initializer(0.5)):
gcell = contrib_rnn_cell.GLSTMCell(
@@ -1121,7 +1121,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
# NOTE: all the values in the current test case have been calculated.
def testBasicLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1189,7 +1189,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
def testBasicLSTMCellWithoutNorm(self):
"""Tests that BasicLSTMCell with layer_norm=False."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1256,7 +1256,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_h, 1e-5)
def testBasicLSTMCellWithStateTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1294,7 +1294,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
def testBasicLSTMCellWithStateTupleLayerNorm(self):
"""The results of LSTMCell and LayerNormBasicLSTMCell should be the same."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1353,7 +1353,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
num_units = 5
allowed_low = [1, 2, 3]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"other", initializer=init_ops.constant_initializer(1)):
x = array_ops.zeros([1, 5])
@@ -1440,7 +1440,7 @@ class CompiledWrapperTest(test.TestCase):
atol = 1e-5
random_seed.set_random_seed(1234)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
xla_ops = _create_multi_lstm_cell_ops(
batch_size=batch_size,
num_units=num_units,
@@ -1452,7 +1452,7 @@ class CompiledWrapperTest(test.TestCase):
xla_results = sess.run(xla_ops)
random_seed.set_random_seed(1234)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
non_xla_ops = _create_multi_lstm_cell_ops(
batch_size=batch_size,
num_units=num_units,
@@ -1479,7 +1479,7 @@ class CompiledWrapperTest(test.TestCase):
self.assertAllClose(xla_g, non_xla_g, atol=atol)
def testMultiRNNCellWithStateTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1583,7 +1583,7 @@ class WeightNormLSTMCellTest(test.TestCase):
def _cell_output(self, cell):
"""Calculates cell output."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
init = init_ops.constant_initializer(0.5)
with variable_scope.variable_scope("root",
initializer=init):
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index f74c95f962..06c481672c 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -97,10 +97,10 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
The default non-peephole implementation is based on:
- http://www.bioinf.jku.at/publications/older/2604.pdf
+ https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
- S. Hochreiter and J. Schmidhuber.
- "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+ Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
+ "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
The peephole implementation is based on:
@@ -2448,10 +2448,10 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
The default non-peephole implementation is based on:
- http://www.bioinf.jku.at/publications/older/2604.pdf
+ https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
- S. Hochreiter and J. Schmidhuber.
- "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+ Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
+ "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
The peephole implementation is based on:
@@ -2802,9 +2802,11 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
Training of Deep Neural Networks
The default LSTM implementation is based on:
- http://www.bioinf.jku.at/publications/older/2604.pdf
- S. Hochreiter and J. Schmidhuber.
- "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+
+ https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
+
+ Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
+ "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
The class uses optional peephole connections, optional cell clipping
and an optional projection layer.
diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD
index e7eb4ac563..b897224c6d 100644
--- a/tensorflow/contrib/saved_model/BUILD
+++ b/tensorflow/contrib/saved_model/BUILD
@@ -36,6 +36,7 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
+ ":keras_saved_model",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework_ops",
"//tensorflow/python:lib",
@@ -101,23 +102,33 @@ py_library(
tags = ["no_windows"],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python:lib",
+ "//tensorflow/python:metrics",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:saver",
"//tensorflow/python:util",
+ "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:export",
+ "//tensorflow/python/estimator:keras",
+ "//tensorflow/python/estimator:model_fn",
"//tensorflow/python/keras:engine",
- "//tensorflow/python/saved_model:constants",
+ "//tensorflow/python/saved_model",
],
)
py_test(
name = "keras_saved_model_test",
- size = "small",
+ size = "medium",
srcs = ["python/saved_model/keras_saved_model_test.py"],
srcs_version = "PY2AND3",
deps = [
- ":saved_model_py",
+ ":keras_saved_model",
"//tensorflow/python:client_testlib",
"//tensorflow/python:training",
"//tensorflow/python/keras",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/contrib/saved_model/__init__.py b/tensorflow/contrib/saved_model/__init__.py
index 95e1a8967b..074dc655ac 100644
--- a/tensorflow/contrib/saved_model/__init__.py
+++ b/tensorflow/contrib/saved_model/__init__.py
@@ -26,10 +26,13 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long
from tensorflow.contrib.saved_model.python.saved_model.keras_saved_model import *
from tensorflow.contrib.saved_model.python.saved_model.signature_def_utils import *
-# pylint: enable=unused-import,widcard-import,line-too-long
+# pylint: enable=unused-import,wildcard-import,line-too-long
from tensorflow.python.util.all_util import remove_undocumented
-_allowed_symbols = ["get_signature_def_by_key", "load_model", "save_model"]
+_allowed_symbols = [
+ "get_signature_def_by_key",
+ "load_keras_model",
+ "save_keras_model"]
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/BUILD b/tensorflow/contrib/saved_model/cc/saved_model/BUILD
index 3c616c555b..ea4d41d43b 100644
--- a/tensorflow/contrib/saved_model/cc/saved_model/BUILD
+++ b/tensorflow/contrib/saved_model/cc/saved_model/BUILD
@@ -30,6 +30,7 @@ cc_library(
hdrs = ["signature_def_utils.h"],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/cc/saved_model:signature_constants",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_proto_parsing",
@@ -42,6 +43,7 @@ tf_cc_test(
srcs = ["signature_def_utils_test.cc"],
deps = [
":signature_def_utils",
+ "//tensorflow/cc/saved_model:signature_constants",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_proto_parsing",
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc
index a45908d272..e87e497e5f 100644
--- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc
+++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h"
+#include "tensorflow/cc/saved_model/signature_constants.h"
+#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -33,6 +35,79 @@ Status FindInProtobufMap(StringPiece description,
*value = &it->second;
return Status::OK();
}
+
+// Looks up the TensorInfo for the given key in the given map and verifies that
+// its datatype matches the given correct datatype.
+bool VerifyTensorInfoForKeyInMap(const protobuf::Map<string, TensorInfo>& map,
+ const string& key, DataType correct_dtype) {
+ const TensorInfo* tensor_info;
+ const Status& status = FindInProtobufMap("", map, key, &tensor_info);
+ if (!status.ok()) {
+ return false;
+ }
+ if (tensor_info->dtype() != correct_dtype) {
+ return false;
+ }
+ return true;
+}
+
+bool IsValidPredictSignature(const SignatureDef& signature_def) {
+ if (signature_def.method_name() != kPredictMethodName) {
+ return false;
+ }
+ if (signature_def.inputs().empty()) {
+ return false;
+ }
+ if (signature_def.outputs().empty()) {
+ return false;
+ }
+ return true;
+}
+
+bool IsValidRegressionSignature(const SignatureDef& signature_def) {
+ if (signature_def.method_name() != kRegressMethodName) {
+ return false;
+ }
+ if (!VerifyTensorInfoForKeyInMap(signature_def.inputs(), kRegressInputs,
+ DT_STRING)) {
+ return false;
+ }
+ if (!VerifyTensorInfoForKeyInMap(signature_def.outputs(), kRegressOutputs,
+ DT_FLOAT)) {
+ return false;
+ }
+ return true;
+}
+
+bool IsValidClassificationSignature(const SignatureDef& signature_def) {
+ if (signature_def.method_name() != kClassifyMethodName) {
+ return false;
+ }
+ if (!VerifyTensorInfoForKeyInMap(signature_def.inputs(), kClassifyInputs,
+ DT_STRING)) {
+ return false;
+ }
+ if (signature_def.outputs().empty()) {
+ return false;
+ }
+ for (auto const& output : signature_def.outputs()) {
+ const string& key = output.first;
+ const TensorInfo& tensor_info = output.second;
+ if (key == kClassifyOutputClasses) {
+ if (tensor_info.dtype() != DT_STRING) {
+ return false;
+ }
+ } else if (key == kClassifyOutputScores) {
+ if (tensor_info.dtype() != DT_FLOAT) {
+ return false;
+ }
+ } else {
+ return false;
+ }
+ }
+ return true;
+}
+
} // namespace
Status FindSignatureDefByKey(const MetaGraphDef& meta_graph_def,
@@ -74,4 +149,10 @@ Status FindOutputTensorNameByKey(const SignatureDef& signature_def,
return Status::OK();
}
+bool IsValidSignature(const SignatureDef& signature_def) {
+ return IsValidClassificationSignature(signature_def) ||
+ IsValidRegressionSignature(signature_def) ||
+ IsValidPredictSignature(signature_def);
+}
+
} // namespace tensorflow
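The three `IsValid*Signature` helpers above encode TensorFlow Serving's expectations for the canonical signature kinds: predict needs any non-empty inputs and outputs; regress needs string inputs and float outputs under the regression keys; classify needs string inputs and only `classes` (string) and/or `scores` (float) outputs. A rough Python restatement, for illustration only (constant names taken from `tensorflow.python.saved_model.signature_constants`; the shipped helper is the C++ `IsValidSignature`):

```python
# Rough Python restatement of the validation rules above (illustration
# only; the shipped implementation is the C++ IsValidSignature).
from tensorflow.core.framework import types_pb2
from tensorflow.python.saved_model import signature_constants as sig


def _dtype_ok(tensor_map, key, dtype):
  # Mirrors VerifyTensorInfoForKeyInMap: key present, dtype correct.
  return key in tensor_map and tensor_map[key].dtype == dtype


def is_valid_signature(signature_def):
  """A SignatureDef is servable if it is a valid predict, regression,
  or classification signature."""
  if signature_def.method_name == sig.PREDICT_METHOD_NAME:
    return bool(signature_def.inputs) and bool(signature_def.outputs)
  if signature_def.method_name == sig.REGRESS_METHOD_NAME:
    return (_dtype_ok(signature_def.inputs, sig.REGRESS_INPUTS,
                      types_pb2.DT_STRING) and
            _dtype_ok(signature_def.outputs, sig.REGRESS_OUTPUTS,
                      types_pb2.DT_FLOAT))
  if signature_def.method_name == sig.CLASSIFY_METHOD_NAME:
    if not _dtype_ok(signature_def.inputs, sig.CLASSIFY_INPUTS,
                     types_pb2.DT_STRING):
      return False
    if not signature_def.outputs:
      return False
    allowed = {sig.CLASSIFY_OUTPUT_CLASSES: types_pb2.DT_STRING,
               sig.CLASSIFY_OUTPUT_SCORES: types_pb2.DT_FLOAT}
    return all(key in allowed and info.dtype == allowed[key]
               for key, info in signature_def.outputs.items())
  return False
```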
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h
index b732cdd41e..bb24faa989 100644
--- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h
+++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h
@@ -64,6 +64,9 @@ Status FindInputTensorNameByKey(const SignatureDef& signature_def,
Status FindOutputTensorNameByKey(const SignatureDef& signature_def,
const string& tensor_info_key, string* name);
+// Determines whether a SignatureDef can be served by TensorFlow Serving.
+bool IsValidSignature(const SignatureDef& signature_def);
+
} // namespace tensorflow
#endif // TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc
index a063e95696..c743112ce0 100644
--- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc
+++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h"
+#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -22,7 +23,7 @@ limitations under the License.
namespace tensorflow {
-class SignatureDefUtilsTest : public ::testing::Test {
+class FindByKeyTest : public ::testing::Test {
protected:
MetaGraphDef MakeSampleMetaGraphDef() {
MetaGraphDef result;
@@ -32,13 +33,23 @@ class SignatureDefUtilsTest : public ::testing::Test {
return result;
}
+ void SetInputNameForKey(const string& key, const string& name,
+ SignatureDef* signature_def) {
+ (*signature_def->mutable_inputs())[key].set_name(name);
+ }
+
+ void SetOutputNameForKey(const string& key, const string& name,
+ SignatureDef* signature_def) {
+ (*signature_def->mutable_outputs())[key].set_name(name);
+ }
+
SignatureDef MakeSampleSignatureDef() {
SignatureDef result;
result.set_method_name(kMethodName);
- (*result.mutable_inputs())[kInput1Key].set_name(kInput1Name);
- (*result.mutable_inputs())[kInput2Key].set_name(kInput2Name);
- (*result.mutable_outputs())[kOutput1Key].set_name(kOutput1Name);
- (*result.mutable_outputs())[kOutput2Key].set_name(kOutput2Name);
+ SetInputNameForKey(kInput1Key, kInput1Name, &result);
+ SetInputNameForKey(kInput2Key, kInput2Name, &result);
+ SetOutputNameForKey(kOutput1Key, kOutput1Name, &result);
+ SetOutputNameForKey(kOutput2Key, kOutput2Name, &result);
return result;
}
@@ -54,7 +65,7 @@ class SignatureDefUtilsTest : public ::testing::Test {
const string kOutput2Name = "output_two";
};
-TEST_F(SignatureDefUtilsTest, FindSignatureDefByKey) {
+TEST_F(FindByKeyTest, FindSignatureDefByKey) {
const MetaGraphDef meta_graph_def = MakeSampleMetaGraphDef();
const SignatureDef* signature_def;
// Succeeds for an existing signature.
@@ -67,7 +78,7 @@ TEST_F(SignatureDefUtilsTest, FindSignatureDefByKey) {
.ok());
}
-TEST_F(SignatureDefUtilsTest, FindInputTensorNameByKey) {
+TEST_F(FindByKeyTest, FindInputTensorNameByKey) {
const SignatureDef signature_def = MakeSampleSignatureDef();
string name;
// Succeeds for an existing input.
@@ -78,7 +89,7 @@ TEST_F(SignatureDefUtilsTest, FindInputTensorNameByKey) {
FindInputTensorNameByKey(signature_def, "nonexistent", &name).ok());
}
-TEST_F(SignatureDefUtilsTest, FindOutputTensorNameByKey) {
+TEST_F(FindByKeyTest, FindOutputTensorNameByKey) {
const SignatureDef signature_def = MakeSampleSignatureDef();
string name;
// Succeeds for an existing output.
@@ -89,4 +100,100 @@ TEST_F(SignatureDefUtilsTest, FindOutputTensorNameByKey) {
FindOutputTensorNameByKey(signature_def, "nonexistent", &name).ok());
}
+class IsValidSignatureTest : public ::testing::Test {
+ protected:
+ void SetInputDataTypeForKey(const string& key, DataType dtype) {
+ (*signature_def_.mutable_inputs())[key].set_dtype(dtype);
+ }
+
+ void SetOutputDataTypeForKey(const string& key, DataType dtype) {
+ (*signature_def_.mutable_outputs())[key].set_dtype(dtype);
+ }
+
+ void EraseOutputKey(const string& key) {
+ (*signature_def_.mutable_outputs()).erase(key);
+ }
+
+ void ExpectInvalidSignature() {
+ EXPECT_FALSE(IsValidSignature(signature_def_));
+ }
+
+ void ExpectValidSignature() { EXPECT_TRUE(IsValidSignature(signature_def_)); }
+
+ SignatureDef signature_def_;
+};
+
+TEST_F(IsValidSignatureTest, IsValidPredictSignature) {
+ signature_def_.set_method_name("not_kPredictMethodName");
+ // Incorrect method name
+ ExpectInvalidSignature();
+
+ signature_def_.set_method_name(kPredictMethodName);
+ // No inputs
+ ExpectInvalidSignature();
+
+ SetInputDataTypeForKey(kPredictInputs, DT_STRING);
+ // No outputs
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey(kPredictOutputs, DT_STRING);
+ ExpectValidSignature();
+}
+
+TEST_F(IsValidSignatureTest, IsValidRegressionSignature) {
+ signature_def_.set_method_name("not_kRegressMethodName");
+ // Incorrect method name
+ ExpectInvalidSignature();
+
+ signature_def_.set_method_name(kRegressMethodName);
+ // No inputs
+ ExpectInvalidSignature();
+
+ SetInputDataTypeForKey(kRegressInputs, DT_STRING);
+ // No outputs
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey(kRegressOutputs, DT_STRING);
+ // Incorrect data type
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey(kRegressOutputs, DT_FLOAT);
+ ExpectValidSignature();
+}
+
+TEST_F(IsValidSignatureTest, IsValidClassificationSignature) {
+ signature_def_.set_method_name("not_kClassifyMethodName");
+ // Incorrect method name
+ ExpectInvalidSignature();
+
+ signature_def_.set_method_name(kClassifyMethodName);
+ // No inputs
+ ExpectInvalidSignature();
+
+ SetInputDataTypeForKey(kClassifyInputs, DT_STRING);
+ // No outputs
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey("invalidKey", DT_FLOAT);
+ // Invalid key
+ ExpectInvalidSignature();
+
+ EraseOutputKey("invalidKey");
+ SetOutputDataTypeForKey(kClassifyOutputClasses, DT_FLOAT);
+ // Invalid dtype for classes
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey(kClassifyOutputClasses, DT_STRING);
+ // Valid without scores
+ ExpectValidSignature();
+
+ SetOutputDataTypeForKey(kClassifyOutputScores, DT_STRING);
+ // Invalid dtype for scores
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey(kClassifyOutputScores, DT_FLOAT);
+ // Valid with both classes and scores
+ ExpectValidSignature();
+}
+
} // namespace tensorflow
diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py
index e2a969f053..2c5c8c4afd 100644
--- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py
+++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py
@@ -20,28 +20,69 @@ from __future__ import print_function
import os
+from tensorflow.python.client import session
+from tensorflow.python.estimator import keras as estimator_keras_util
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.export import export as export_helpers
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import models as models_lib
+from tensorflow.python.keras import optimizers
from tensorflow.python.keras.models import model_from_json
from tensorflow.python.lib.io import file_io
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import constants
+from tensorflow.python.saved_model import utils_impl as saved_model_utils
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.util import compat
-def save_model(model, saved_model_path):
+def save_keras_model(
+ model, saved_model_path, custom_objects=None, as_text=None):
"""Save a `tf.keras.Model` into Tensorflow SavedModel format.
- `save_model` generates such files/folders under the `saved_model_path` folder:
+ `save_keras_model` generates new files/folders under the `saved_model_path` folder:
1) an asset folder containing the json string of the model's
- configuration(topology).
+ configuration (topology).
2) a checkpoint containing the model weights.
+ 3) a saved_model.pb file containing the model's MetaGraphs. The prediction
+ graph is always exported. The evaluation and training graphs are exported
+ if the following conditions are met:
+ - Evaluation: model loss is defined.
+ - Training: model is compiled with an optimizer defined under `tf.train`.
+ This is because `tf.keras.optimizers.Optimizer` instances cannot be
+ saved to checkpoints.
- Note that subclassed models can not be saved via this function, unless you
- provide an implementation for get_config() and from_config().
- Also note that `tf.keras.optimizers.Optimizer` instances can not currently be
- saved to checkpoints. Use optimizers from `tf.train`.
+ Model Requirements:
+ - Model must be a sequential model or functional model. Subclassed models
+ cannot be saved via this function, unless you provide an implementation
+ for get_config() and from_config().
+ - All variables must be saveable by the model. In general, this condition is
+ met through the use of layers defined in the keras library. However,
+ there is currently a bug with variables created in Lambda layer functions
+ not being saved correctly (see
+ https://github.com/keras-team/keras/issues/9740).
+
+ Note that each mode is exported in a separate graph, so different modes do not
+ share variables. To use the train graph with evaluation or prediction graphs,
+ create a new checkpoint if variable values have been updated.
Args:
model: A `tf.keras.Model` to be saved.
saved_model_path: a string specifying the path to the SavedModel directory.
+ The SavedModel will be saved to a timestamped folder created within this
+ directory.
+ custom_objects: Optional dictionary mapping string names to custom classes
+ or functions (e.g. custom loss functions).
+ as_text: whether to write the `SavedModel` proto in text format.
+
+ Returns:
+ String path to the SavedModel folder, a subdirectory of `saved_model_path`.
Raises:
NotImplementedError: If the passed in model is a subclassed model.
@@ -49,35 +90,200 @@ def save_model(model, saved_model_path):
if not model._is_graph_network:
raise NotImplementedError
- # save model configuration as a json string under assets folder.
- model_json = model.to_json()
- assets_destination_dir = os.path.join(
- compat.as_bytes(saved_model_path),
- compat.as_bytes(constants.ASSETS_DIRECTORY))
+ export_dir = export_helpers.get_timestamped_export_dir(saved_model_path)
+ temp_export_dir = export_helpers.get_temp_export_dir(export_dir)
+
+ builder = saved_model_builder.SavedModelBuilder(temp_export_dir)
+
+ # Manually save variables to export them in an object-based checkpoint. This
+ # skips the `builder.add_meta_graph_and_variables()` step, which saves a
+ # name-based checkpoint.
+ # TODO(b/113134168): Add fn to Builder to save with object-based saver.
+ # TODO(b/113178242): This should only export the model json structure. Only
+ # one save is needed once the weights can be copied from the model to clone.
+ checkpoint_path = _export_model_json_and_variables(model, temp_export_dir)
+
+ # Export each mode. Use ModeKeys enums defined for `Estimator` to ensure that
+ # Keras models and `Estimator`s are exported with the same format.
+ # Every time a mode is exported, the code checks to see if new variables have
+ # been created (e.g. optimizer slot variables). If that is the case, the
+ # checkpoint is re-saved to include the new variables.
+ export_args = {'builder': builder,
+ 'model': model,
+ 'custom_objects': custom_objects,
+ 'checkpoint_path': checkpoint_path}
+
+ has_saved_vars = False
+ if model.optimizer:
+ if isinstance(model.optimizer, optimizers.TFOptimizer):
+ _export_mode(model_fn_lib.ModeKeys.TRAIN, has_saved_vars, **export_args)
+ has_saved_vars = True
+ _export_mode(model_fn_lib.ModeKeys.EVAL, has_saved_vars, **export_args)
+ else:
+ logging.warning(
+ 'Model was compiled with an optimizer, but the optimizer is not from '
+ '`tf.train` (e.g. `tf.train.AdagradOptimizer`). Only the serving '
+ 'graph was exported. The train and evaluate graphs were not added to '
+ 'the SavedModel.')
+ _export_mode(model_fn_lib.ModeKeys.PREDICT, has_saved_vars, **export_args)
+
+ builder.save(as_text)
+
+ gfile.Rename(temp_export_dir, export_dir)
+ return export_dir
- if not file_io.file_exists(assets_destination_dir):
- file_io.recursive_create_dir(assets_destination_dir)
+def _export_model_json_and_variables(model, saved_model_path):
+ """Save model variables and json structure into SavedModel subdirectories."""
+ # Save model configuration as a json string under assets folder.
+ model_json = model.to_json()
model_json_filepath = os.path.join(
- compat.as_bytes(assets_destination_dir),
- compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
+ saved_model_utils.get_or_create_assets_dir(saved_model_path),
+ compat.as_text(constants.SAVED_MODEL_FILENAME_JSON))
file_io.write_string_to_file(model_json_filepath, model_json)
- # save model weights in checkpoint format.
- checkpoint_destination_dir = os.path.join(
- compat.as_bytes(saved_model_path),
- compat.as_bytes(constants.VARIABLES_DIRECTORY))
+ # Save model weights in checkpoint format under variables folder.
+ saved_model_utils.get_or_create_variables_dir(saved_model_path)
+ checkpoint_prefix = saved_model_utils.get_variables_path(saved_model_path)
+ model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True)
+ return checkpoint_prefix
- if not file_io.file_exists(checkpoint_destination_dir):
- file_io.recursive_create_dir(checkpoint_destination_dir)
- checkpoint_prefix = os.path.join(
- compat.as_text(checkpoint_destination_dir),
- compat.as_text(constants.VARIABLES_FILENAME))
- model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True)
+def _get_var_list(model):
+ """Return list of all checkpointed saveable objects in the model."""
+ return checkpointable_utils.named_saveables(model)
+
+
+def _export_mode(
+ mode, has_saved_vars, builder, model, custom_objects, checkpoint_path):
+ """Export a model, and optionally save new vars from the clone model.
+
+ Args:
+ mode: A `tf.estimator.ModeKeys` string.
+ has_saved_vars: A `boolean` indicating whether the SavedModel has already
+ exported variables.
+ builder: A `SavedModelBuilder` object.
+ model: A `tf.keras.Model` object.
+ custom_objects: A dictionary mapping string names to custom classes
+ or functions.
+ checkpoint_path: String path to checkpoint.
+
+ Raises:
+ ValueError: If the train/eval mode is being exported, but the model does
+ not have an optimizer.
+ """
+ compile_clone = (mode != model_fn_lib.ModeKeys.PREDICT)
+ if compile_clone and not model.optimizer:
+ raise ValueError(
+ 'Model does not have an optimizer. Cannot export mode %s' % mode)
+
+ model_graph = ops.get_default_graph()
+ with ops.Graph().as_default() as g:
+
+ K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN)
+
+ # Clone the model into a blank graph. This will create placeholders for inputs
+ # and targets.
+ clone = models_lib.clone_and_build_model(
+ model, custom_objects=custom_objects, compile_clone=compile_clone)
+
+ # Make sure that the iterations variable is added to the global step collection,
+ # to ensure that, when the SavedModel graph is loaded, the iterations
+ # variable is returned by `tf.train.get_global_step()`. This is required for
+ # compatibility with the SavedModelEstimator.
+ if compile_clone:
+ g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations)
+
+ # Extract update and train ops from train/test/predict functions.
+ if mode == model_fn_lib.ModeKeys.TRAIN:
+ clone._make_train_function()
+ builder._add_train_op(clone.train_function.updates_op)
+ elif mode == model_fn_lib.ModeKeys.EVAL:
+ clone._make_test_function()
+ else:
+ clone._make_predict_function()
+ g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates)
+
+ clone_var_list = checkpointable_utils.named_saveables(clone)
+
+ with session.Session().as_default():
+ if has_saved_vars:
+ # Confirm all variables in the clone have an entry in the checkpoint.
+ status = clone.load_weights(checkpoint_path)
+ status.assert_existing_objects_matched()
+ else:
+ # Confirm that variables between the clone and model match up exactly,
+ # not counting optimizer objects. Optimizer objects are ignored because
+ # if the model has not trained, the slot variables will not have been
+ # created yet.
+ # TODO(b/113179535): Replace with checkpointable equivalence.
+ _assert_same_non_optimizer_objects(model, model_graph, clone, g)
+
+ # TODO(b/113178242): Use value transfer for checkpointable objects.
+ clone.load_weights(checkpoint_path)
+
+ # Add graph and variables to SavedModel.
+ # TODO(b/113134168): Switch to add_meta_graph_and_variables.
+ clone.save_weights(checkpoint_path, save_format='tf', overwrite=True)
+ builder._has_saved_variables = True
+
+ # Add graph to the SavedModel builder.
+ builder.add_meta_graph(
+ model_fn_lib.EXPORT_TAG_MAP[mode],
+ signature_def_map=_create_signature_def_map(clone, mode),
+ saver=saver_lib.Saver(clone_var_list),
+ main_op=variables.local_variables_initializer())
+ return None
+
+
+def _create_signature_def_map(model, mode):
+ """Create a SignatureDef map from a Keras model."""
+ inputs_dict = {name: x for name, x in zip(model.input_names, model.inputs)}
+ if model.optimizer:
+ targets_dict = {x.name.split(':')[0]: x
+ for x in model.targets if x is not None}
+ inputs_dict.update(targets_dict)
+ outputs_dict = {name: x
+ for name, x in zip(model.output_names, model.outputs)}
+ export_outputs = model_fn_lib.export_outputs_for_mode(
+ mode,
+ predictions=outputs_dict,
+ loss=model.total_loss if model.optimizer else None,
+ metrics=estimator_keras_util._convert_keras_metrics_to_estimator(model))
+ return export_helpers.build_all_signature_defs(
+ inputs_dict,
+ export_outputs=export_outputs,
+ serving_only=(mode == model_fn_lib.ModeKeys.PREDICT))
+
+
+def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph):
+ """Assert model and clone contain the same checkpointable objects."""
+
+ def get_non_optimizer_objects(m, g):
+ """Gather set of model and optimizer checkpointable objects."""
+ # Set default graph because optimizer.variables() returns optimizer
+ # variables defined in the default graph.
+ with g.as_default():
+ all_objects = set(checkpointable_utils.list_objects(m))
+ optimizer_and_variables = set()
+ for obj in all_objects:
+ if isinstance(obj, optimizers.TFOptimizer):
+ optimizer_and_variables.update(checkpointable_utils.list_objects(obj))
+ optimizer_and_variables.update(set(obj.optimizer.variables()))
+ return all_objects - optimizer_and_variables
+
+ model_objects = get_non_optimizer_objects(model, model_graph)
+ clone_objects = get_non_optimizer_objects(clone, clone_graph)
+
+ if len(model_objects) != len(clone_objects):
+ raise errors.InternalError(
+ None, None,
+ 'Model and clone must use the same variables.'
+ '\n\tModel variables: %s\n\t Clone variables: %s'
+ % (model_objects, clone_objects))
-def load_model(saved_model_path):
+def load_keras_model(saved_model_path):
"""Load a keras.Model from SavedModel.
load_model reinstantiates model state by:
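Taken together, the renamed pair forms a round trip: `save_keras_model` writes a timestamped subdirectory under the given path and returns it, and that returned path is what `load_keras_model` expects. A hedged usage sketch against the TF 1.x contrib layout (the export directory and optimizer choice here are illustrative; a `tf.train` optimizer is needed for the train/eval graphs to be exported):

```python
# Hedged usage sketch for the renamed contrib API (TF 1.x layout assumed).
import numpy as np
from tensorflow import keras
from tensorflow.contrib.saved_model.python.saved_model import keras_saved_model
from tensorflow.python.training import training as training_module

model = keras.models.Sequential([
    keras.layers.Dense(3, input_shape=(3,)),
])
# A tf.train optimizer (not tf.keras.optimizers) so that the train and
# eval graphs can be exported alongside the serving graph.
model.compile(loss='mse', optimizer=training_module.RMSPropOptimizer(0.1))
model.train_on_batch(np.random.random((1, 3)), np.random.random((1, 3)))

# Returns e.g. '/tmp/exports/1536650000', not '/tmp/exports' itself.
export_dir = keras_saved_model.save_keras_model(model, '/tmp/exports')
restored = keras_saved_model.load_keras_model(export_dir)
```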
diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
index 107ae1b07b..12dd72a95b 100644
--- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
+++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
@@ -20,20 +20,37 @@ from __future__ import print_function
import os
import shutil
+
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.saved_model.python.saved_model import keras_saved_model
from tensorflow.python import keras
+from tensorflow.python.client import session
+from tensorflow.python.eager import context
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import training
+from tensorflow.python.keras.utils import tf_utils
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
+from tensorflow.python.saved_model import constants
+from tensorflow.python.saved_model import loader_impl
+from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import training as training_module
class TestModelSavingandLoading(test.TestCase):
+ def _save_model_dir(self, dirname='saved_model'):
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
+ return os.path.join(temp_dir, dirname)
+
def test_saving_sequential_model(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.RepeatVector(3))
@@ -48,19 +65,17 @@ class TestModelSavingandLoading(test.TestCase):
model.train_on_batch(x, y)
ref_y = model.predict(x)
- temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
- temp_saved_model = os.path.join(temp_dir, 'saved_model')
- keras_saved_model.save_model(model, temp_saved_model)
+ temp_saved_model = self._save_model_dir()
+ output_path = keras_saved_model.save_keras_model(model, temp_saved_model)
- loaded_model = keras_saved_model.load_model(temp_saved_model)
+ loaded_model = keras_saved_model.load_keras_model(output_path)
y = loaded_model.predict(x)
self.assertAllClose(ref_y, y, atol=1e-05)
@test_util.run_in_graph_and_eager_modes
def test_saving_sequential_model_without_compile(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.RepeatVector(3))
@@ -69,18 +84,15 @@ class TestModelSavingandLoading(test.TestCase):
x = np.random.random((1, 3))
ref_y = model.predict(x)
- temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
-
- temp_saved_model = os.path.join(temp_dir, 'saved_model')
- keras_saved_model.save_model(model, temp_saved_model)
- loaded_model = keras_saved_model.load_model(temp_saved_model)
+ temp_saved_model = self._save_model_dir()
+ output_path = keras_saved_model.save_keras_model(model, temp_saved_model)
+ loaded_model = keras_saved_model.load_keras_model(output_path)
y = loaded_model.predict(x)
self.assertAllClose(ref_y, y, atol=1e-05)
def test_saving_functional_model(self):
- with self.test_session():
+ with self.cached_session():
inputs = keras.layers.Input(shape=(3,))
x = keras.layers.Dense(2)(inputs)
output = keras.layers.Dense(3)(x)
@@ -95,19 +107,17 @@ class TestModelSavingandLoading(test.TestCase):
model.train_on_batch(x, y)
ref_y = model.predict(x)
- temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
- temp_saved_model = os.path.join(temp_dir, 'saved_model')
- keras_saved_model.save_model(model, temp_saved_model)
- loaded_model = keras_saved_model.load_model(temp_saved_model)
+ temp_saved_model = self._save_model_dir()
+ output_path = keras_saved_model.save_keras_model(model, temp_saved_model)
+ loaded_model = keras_saved_model.load_keras_model(output_path)
y = loaded_model.predict(x)
self.assertAllClose(ref_y, y, atol=1e-05)
@test_util.run_in_graph_and_eager_modes
def test_saving_functional_model_without_compile(self):
- with self.test_session():
+ with self.cached_session():
inputs = keras.layers.Input(shape=(3,))
x = keras.layers.Dense(2)(inputs)
output = keras.layers.Dense(3)(x)
@@ -118,19 +128,17 @@ class TestModelSavingandLoading(test.TestCase):
y = np.random.random((1, 3))
ref_y = model.predict(x)
- temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
- temp_saved_model = os.path.join(temp_dir, 'saved_model')
- keras_saved_model.save_model(model, temp_saved_model)
- loaded_model = keras_saved_model.load_model(temp_saved_model)
+ temp_saved_model = self._save_model_dir()
+ output_path = keras_saved_model.save_keras_model(model, temp_saved_model)
+ loaded_model = keras_saved_model.load_keras_model(output_path)
y = loaded_model.predict(x)
self.assertAllClose(ref_y, y, atol=1e-05)
@test_util.run_in_graph_and_eager_modes
def test_saving_with_tf_optimizer(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.Dense(3))
@@ -142,14 +150,13 @@ class TestModelSavingandLoading(test.TestCase):
x = np.random.random((1, 3))
y = np.random.random((1, 3))
model.train_on_batch(x, y)
+ model.train_on_batch(x, y)
ref_y = model.predict(x)
- temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
- temp_saved_model = os.path.join(temp_dir, 'saved_model')
- keras_saved_model.save_model(model, temp_saved_model)
- loaded_model = keras_saved_model.load_model(temp_saved_model)
+ temp_saved_model = self._save_model_dir()
+ output_path = keras_saved_model.save_keras_model(model, temp_saved_model)
+ loaded_model = keras_saved_model.load_keras_model(output_path)
loaded_model.compile(
loss='mse',
optimizer=training_module.RMSPropOptimizer(0.1),
@@ -170,8 +177,10 @@ class TestModelSavingandLoading(test.TestCase):
self.assertAllClose(ref_y, y, atol=1e-05)
# test saving/loading again
- keras_saved_model.save_model(loaded_model, temp_saved_model)
- loaded_model = keras_saved_model.load_model(temp_saved_model)
+ temp_saved_model2 = self._save_model_dir('saved_model_2')
+ output_path2 = keras_saved_model.save_keras_model(
+ loaded_model, temp_saved_model2)
+ loaded_model = keras_saved_model.load_keras_model(output_path2)
y = loaded_model.predict(x)
self.assertAllClose(ref_y, y, atol=1e-05)
@@ -190,11 +199,231 @@ class TestModelSavingandLoading(test.TestCase):
return self.layer2(self.layer1(inp))
model = SubclassedModel()
- temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
- temp_saved_model = os.path.join(temp_dir, 'saved_model')
+
+ temp_saved_model = self._save_model_dir()
with self.assertRaises(NotImplementedError):
- keras_saved_model.save_model(model, temp_saved_model)
+ keras_saved_model.save_keras_model(model, temp_saved_model)
+
+
+class LayerWithLearningPhase(keras.engine.base_layer.Layer):
+
+ def call(self, x):
+ phase = keras.backend.learning_phase()
+ output = tf_utils.smart_cond(
+ phase, lambda: x * 0, lambda: array_ops.identity(x))
+ if not context.executing_eagerly():
+ output._uses_learning_phase = True # pylint: disable=protected-access
+ return output
+
+ def compute_output_shape(self, input_shape):
+ return input_shape
+
+
+def functional_model(uses_learning_phase):
+ inputs = keras.layers.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ x = keras.layers.Dense(3)(x)
+ if uses_learning_phase:
+ x = LayerWithLearningPhase()(x)
+ return keras.models.Model(inputs, x)
+
+
+def sequential_model(uses_learning_phase):
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.Dense(3))
+ if uses_learning_phase:
+ model.add(LayerWithLearningPhase())
+ return model
+
+
+def load_model(sess, path, mode):
+ tags = model_fn_lib.EXPORT_TAG_MAP[mode]
+ sig_def_key = (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+ if mode == model_fn_lib.ModeKeys.PREDICT else mode)
+ meta_graph_def = loader_impl.load(sess, tags, path)
+ inputs = {
+ k: sess.graph.get_tensor_by_name(v.name)
+ for k, v in meta_graph_def.signature_def[sig_def_key].inputs.items()}
+ outputs = {
+ k: sess.graph.get_tensor_by_name(v.name)
+ for k, v in meta_graph_def.signature_def[sig_def_key].outputs.items()}
+ return inputs, outputs
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
+
+ def _save_model_dir(self, dirname='saved_model'):
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
+ return os.path.join(temp_dir, dirname)
+
+ @parameterized.parameters(
+ (functional_model, True, training_module.AdadeltaOptimizer(), True),
+ (functional_model, True, training_module.AdadeltaOptimizer(), False),
+ (functional_model, False, None, False),
+ (sequential_model, True, training_module.AdadeltaOptimizer(), True),
+ (sequential_model, True, training_module.AdadeltaOptimizer(), False),
+ (sequential_model, False, None, False))
+ def testSaveAndLoadSavedModelExport(
+ self, model_builder, uses_learning_phase, optimizer, train_before_export):
+ saved_model_path = self._save_model_dir()
+ with self.test_session(graph=ops.Graph()):
+ input_arr = np.random.random((1, 3))
+ target_arr = np.random.random((1, 3))
+
+ model = model_builder(uses_learning_phase)
+ if optimizer is not None:
+ model.compile(
+ loss='mse',
+ optimizer=optimizer,
+ metrics=['mae'])
+ if train_before_export:
+ model.train_on_batch(input_arr, target_arr)
+
+ ref_loss, ref_mae = model.evaluate(input_arr, target_arr)
+
+ ref_predict = model.predict(input_arr)
+
+ # Export SavedModel
+ output_path = keras_saved_model.save_keras_model(model, saved_model_path)
+
+ input_name = model.input_names[0]
+ output_name = model.output_names[0]
+ target_name = output_name + '_target'
+
+ # Load predict graph, and test predictions
+ with session.Session(graph=ops.Graph()) as sess:
+ inputs, outputs = load_model(sess, output_path,
+ model_fn_lib.ModeKeys.PREDICT)
+
+ predictions = sess.run(outputs[output_name],
+ {inputs[input_name]: input_arr})
+ self.assertAllClose(ref_predict, predictions, atol=1e-05)
+
+ if optimizer:
+ # Load eval graph, and test predictions, loss and metric values
+ with session.Session(graph=ops.Graph()) as sess:
+ inputs, outputs = load_model(sess, output_path,
+ model_fn_lib.ModeKeys.EVAL)
+
+ eval_results = sess.run(outputs, {inputs[input_name]: input_arr,
+ inputs[target_name]: target_arr})
+
+ self.assertEqual(int(train_before_export),
+ sess.run(training_module.get_global_step()))
+ self.assertAllClose(ref_loss, eval_results['loss'], atol=1e-05)
+ self.assertAllClose(
+ ref_mae, eval_results['metrics/mae/update_op'], atol=1e-05)
+ self.assertAllClose(
+ ref_predict, eval_results['predictions/' + output_name], atol=1e-05)
+
+ # Load train graph, and check for the train op, and prediction values
+ with session.Session(graph=ops.Graph()) as sess:
+ inputs, outputs = load_model(sess, output_path,
+ model_fn_lib.ModeKeys.TRAIN)
+ self.assertEqual(int(train_before_export),
+ sess.run(training_module.get_global_step()))
+ self.assertIn('loss', outputs)
+ self.assertIn('metrics/mae/update_op', outputs)
+ self.assertIn('metrics/mae/value', outputs)
+ self.assertIn('predictions/' + output_name, outputs)
+
+ # Train for a step
+ train_op = ops.get_collection(constants.TRAIN_OP_KEY)
+ train_outputs, _ = sess.run(
+ [outputs, train_op], {inputs[input_name]: input_arr,
+ inputs[target_name]: target_arr})
+ self.assertEqual(int(train_before_export) + 1,
+ sess.run(training_module.get_global_step()))
+
+ if uses_learning_phase:
+ self.assertAllClose(
+ [[0, 0, 0]], train_outputs['predictions/' + output_name],
+ atol=1e-05)
+ else:
+ self.assertNotAllClose(
+ [[0, 0, 0]], train_outputs['predictions/' + output_name],
+ atol=1e-05)
+
+ def testSaveAndLoadSavedModelWithCustomObject(self):
+ saved_model_path = self._save_model_dir()
+ with session.Session(graph=ops.Graph()) as sess:
+ def relu6(x):
+ return keras.backend.relu(x, max_value=6)
+ inputs = keras.layers.Input(shape=(1,))
+ outputs = keras.layers.Activation(relu6)(inputs)
+ model = keras.models.Model(inputs, outputs)
+ output_path = keras_saved_model.save_keras_model(
+ model, saved_model_path, custom_objects={'relu6': relu6})
+ with session.Session(graph=ops.Graph()) as sess:
+ inputs, outputs = load_model(sess, output_path,
+ model_fn_lib.ModeKeys.PREDICT)
+ input_name = model.input_names[0]
+ output_name = model.output_names[0]
+ predictions = sess.run(
+ outputs[output_name], {inputs[input_name]: [[7], [-3], [4]]})
+ self.assertAllEqual([[6], [0], [4]], predictions)
+
+ def testAssertModelCloneSameObjectsIgnoreOptimizer(self):
+ input_arr = np.random.random((1, 3))
+ target_arr = np.random.random((1, 3))
+
+ model_graph = ops.Graph()
+ clone_graph = ops.Graph()
+
+ # Create two models with the same layers but different optimizers.
+ with session.Session(graph=model_graph):
+ inputs = keras.layers.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ x = keras.layers.Dense(3)(x)
+ model = keras.models.Model(inputs, x)
+
+ model.compile(loss='mse', optimizer=training_module.AdadeltaOptimizer())
+ model.train_on_batch(input_arr, target_arr)
+
+ with session.Session(graph=clone_graph):
+ inputs = keras.layers.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ x = keras.layers.Dense(3)(x)
+ clone = keras.models.Model(inputs, x)
+ clone.compile(loss='mse', optimizer=keras.optimizers.RMSprop(lr=0.0001))
+ clone.train_on_batch(input_arr, target_arr)
+
+ keras_saved_model._assert_same_non_optimizer_objects(
+ model, model_graph, clone, clone_graph)
+
+ def testAssertModelCloneSameObjectsThrowError(self):
+ input_arr = np.random.random((1, 3))
+ target_arr = np.random.random((1, 3))
+
+ model_graph = ops.Graph()
+ clone_graph = ops.Graph()
+
+ # Create two models with the same layers but different optimizers.
+ with session.Session(graph=model_graph):
+ inputs = keras.layers.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ x = keras.layers.Dense(3)(x)
+ model = keras.models.Model(inputs, x)
+
+ model.compile(loss='mse', optimizer=training_module.AdadeltaOptimizer())
+ model.train_on_batch(input_arr, target_arr)
+
+ with session.Session(graph=clone_graph):
+ inputs = keras.layers.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ x = keras.layers.Dense(4)(x)
+ x = keras.layers.Dense(3)(x)
+ clone = keras.models.Model(inputs, x)
+ clone.compile(loss='mse', optimizer=keras.optimizers.RMSprop(lr=0.0001))
+ clone.train_on_batch(input_arr, target_arr)
+
+ with self.assertRaisesRegexp(
+ errors.InternalError, 'Model and clone must use the same variables.'):
+ keras_saved_model._assert_same_non_optimizer_objects(
+ model, model_graph, clone, clone_graph)
if __name__ == '__main__':
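The test's `load_model` helper shows the consuming side of the export: choose the tag set for a mode, load the MetaGraph into a fresh session, and resolve the signature's tensor names against the loaded graph. A condensed sketch for the PREDICT graph (assumes `export_dir` came from `save_keras_model` for a single-input, single-output model):

```python
# Condensed sketch of serving-side loading (TF 1.x), mirroring the test's
# load_model helper; export_dir is assumed to come from save_keras_model.
import numpy as np
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

with session.Session(graph=ops.Graph()) as sess:
  # PREDICT mode maps to the SERVING tag in EXPORT_TAG_MAP.
  meta_graph_def = loader_impl.load(sess, [tag_constants.SERVING], export_dir)
  sig = meta_graph_def.signature_def[
      signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
  # Resolve the signature's symbolic names to tensors in the loaded graph
  # (single-entry unpacking assumes one input and one output).
  (input_key, input_info), = sig.inputs.items()
  (output_key, output_info), = sig.outputs.items()
  feed = sess.graph.get_tensor_by_name(input_info.name)
  fetch = sess.graph.get_tensor_by_name(output_info.name)
  print(sess.run(fetch, {feed: np.random.random((1, 3))}))
```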
diff --git a/tensorflow/contrib/saved_model/python/saved_model/reader_test.py b/tensorflow/contrib/saved_model/python/saved_model/reader_test.py
index d10ec9cf0c..3e6ff65c33 100644
--- a/tensorflow/contrib/saved_model/python/saved_model/reader_test.py
+++ b/tensorflow/contrib/saved_model/python/saved_model/reader_test.py
@@ -43,7 +43,7 @@ class ReaderTest(test.TestCase):
def testReadSavedModelValid(self):
saved_model_dir = os.path.join(test.get_temp_dir(), "valid_saved_model")
builder = saved_model_builder.SavedModelBuilder(saved_model_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])
builder.save()
@@ -68,35 +68,35 @@ class ReaderTest(test.TestCase):
# Graph with a single variable. SavedModel invoked to:
# - add with weights.
# - a single tag (from predefined constants).
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])
# Graph that updates the single variable. SavedModel invoked to:
# - simply add the model (weights are not updated).
# - a single tag (from predefined constants).
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 43)
builder.add_meta_graph([tag_constants.SERVING])
# Graph that updates the single variable. SavedModel is invoked:
# - to add the model (weights are not updated).
# - multiple predefined tags.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 44)
builder.add_meta_graph([tag_constants.SERVING, tag_constants.GPU])
# Graph that updates the single variable. SavedModel is invoked:
# - to add the model (weights are not updated).
# - multiple predefined tags for serving on TPU.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 44)
builder.add_meta_graph([tag_constants.SERVING, tag_constants.TPU])
# Graph that updates the single variable. SavedModel is invoked:
# - to add the model (weights are not updated).
# - multiple custom tags.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 45)
builder.add_meta_graph(["foo", "bar"])
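The substitutions in this file, repeated mechanically across the test files below, migrate off the deprecated TestCase.test_session(). A rough sketch of the two replacement idioms, assuming the TF 1.x tf.test.TestCase API; the class and test names are illustrative:

    import tensorflow as tf

    class SessionStyleTest(tf.test.TestCase):

      def testWithExplicitGraph(self):
        # self.session(graph=...) builds a session for the given graph, as in
        # the reader_test hunks above.
        with self.session(graph=tf.Graph()) as sess:
          self.assertEqual(0, sess.run(tf.constant(0)))

      def testWithCachedSession(self):
        # self.cached_session() reuses one session across calls within a test
        # method, which is cheaper when the context is entered repeatedly.
        with self.cached_session() as sess:
          self.assertEqual(1, sess.run(tf.constant(1)))

    if __name__ == '__main__':
      tf.test.main()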
diff --git a/tensorflow/contrib/seq2seq/__init__.py b/tensorflow/contrib/seq2seq/__init__.py
index a7279bc339..674f7cdb22 100644
--- a/tensorflow/contrib/seq2seq/__init__.py
+++ b/tensorflow/contrib/seq2seq/__init__.py
@@ -15,7 +15,9 @@
"""Ops for building neural network seq2seq decoders and losses.
-See the @{$python/contrib.seq2seq} guide.
+See the
+[Contrib Seq2seq](https://tensorflow.org/api_guides/python/contrib.seq2seq)
+guide.
"""
from __future__ import absolute_import
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index cd162bae25..f2c43f30d4 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -512,7 +512,7 @@ class AttentionWrapperTest(test.TestCase):
for axis in [0, 1]:
for exclusive in [True, False]:
- with self.test_session():
+ with self.cached_session():
# Compute cumprod with regular tf.cumprod
cumprod_output = math_ops.cumprod(
test_input, axis=axis, exclusive=exclusive).eval()
@@ -548,7 +548,7 @@ class AttentionWrapperTest(test.TestCase):
for p, a in zip(p_choose_i, previous_attention)])
# Compute output with TensorFlow function, for both calculation types
- with self.test_session():
+ with self.cached_session():
recursive_output = wrapper.monotonic_attention(
p_choose_i, previous_attention, 'recursive').eval()
@@ -569,7 +569,7 @@ class AttentionWrapperTest(test.TestCase):
for p, a in zip(p_choose_i, previous_attention)])
# Compute output with TensorFlow function, for both calculation types
- with self.test_session():
+ with self.cached_session():
parallel_output = wrapper.monotonic_attention(
p_choose_i, previous_attention, 'parallel').eval()
@@ -594,7 +594,7 @@ class AttentionWrapperTest(test.TestCase):
for p, a in zip(p_choose_i, previous_attention)])
# Compute output with TensorFlow function, for both calculation types
- with self.test_session():
+ with self.cached_session():
hard_output = wrapper.monotonic_attention(
# TensorFlow is unhappy when these are not wrapped as tf.constant
constant_op.constant(p_choose_i),
@@ -634,7 +634,7 @@ class AttentionWrapperTest(test.TestCase):
recursive_output = [np.array([1] + [0]*(p_choose_i.shape[1] - 1),
np.float32)]
# Compute output with TensorFlow function, for both calculation types
- with self.test_session():
+ with self.cached_session():
for j in range(p_choose_i.shape[0]):
# Compute attention distribution for this output time step
recursive_output.append(wrapper.monotonic_attention(
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
index 4073b390fc..f5b6b1bde9 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -66,7 +66,7 @@ class TestGatherTree(test.TestCase):
max_sequence_lengths=max_sequence_lengths,
end_token=11)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
res_ = sess.run(res)
self.assertAllEqual(expected_result, res_)
@@ -115,7 +115,7 @@ class TestGatherTree(test.TestCase):
sorted_array = beam_search_decoder.gather_tree_from_array(
array, parent_ids, sequence_length)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sorted_array = sess.run(sorted_array)
expected_array = sess.run(expected_array)
self.assertAllEqual(expected_array, sorted_array)
@@ -170,7 +170,7 @@ class TestGatherTree(test.TestCase):
sorted_array = beam_search_decoder.gather_tree_from_array(
array, parent_ids, sequence_length)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sorted_array, expected_array = sess.run([sorted_array, expected_array])
self.assertAllEqual(expected_array, sorted_array)
@@ -186,7 +186,7 @@ class TestArrayShapeChecks(test.TestCase):
batch_size = array_ops.constant(batch_size)
check_op = beam_search_decoder._check_batch_beam(t, batch_size, beam_width) # pylint: disable=protected-access
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if is_valid:
sess.run(check_op)
else:
@@ -220,7 +220,7 @@ class TestEosMasking(test.TestCase):
masked = beam_search_decoder._mask_probs(probs, eos_token,
previously_finished)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
probs = sess.run(probs)
masked = sess.run(masked)
@@ -283,7 +283,7 @@ class TestBeamStep(test.TestCase):
end_token=self.end_token,
length_penalty_weight=self.length_penalty_weight)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
outputs_, next_state_, state_, log_probs_ = sess.run(
[outputs, next_beam_state, beam_state, log_probs])
@@ -338,7 +338,7 @@ class TestBeamStep(test.TestCase):
end_token=self.end_token,
length_penalty_weight=self.length_penalty_weight)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
outputs_, next_state_, state_, log_probs_ = sess.run(
[outputs, next_beam_state, beam_state, log_probs])
@@ -436,7 +436,7 @@ class TestLargeBeamStep(test.TestCase):
end_token=self.end_token,
length_penalty_weight=self.length_penalty_weight)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
outputs_, next_state_, _, _ = sess.run(
[outputs, next_beam_state, beam_state, log_probs])
@@ -471,7 +471,7 @@ class BeamSearchDecoderTest(test.TestCase):
output_layer = layers_core.Dense(vocab_size, use_bias=True, activation=None)
beam_width = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size_tensor = constant_op.constant(batch_size)
embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
cell = rnn_cell.LSTMCell(cell_depth)
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py
index 277c5b6ef7..9662a5780a 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py
@@ -67,7 +67,7 @@ class GatherTreeTest(test.TestCase):
parent_ids=parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=end_token)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError(
r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"):
_ = beams.eval()
diff --git a/tensorflow/contrib/session_bundle/session_bundle.cc b/tensorflow/contrib/session_bundle/session_bundle.cc
index cf26e3cae7..a690d9b129 100644
--- a/tensorflow/contrib/session_bundle/session_bundle.cc
+++ b/tensorflow/contrib/session_bundle/session_bundle.cc
@@ -138,10 +138,10 @@ Status RunRestoreOp(const RunOptions& run_options, const StringPiece export_dir,
Tensor variables_tensor =
CreateStringTensor(GetVariablesFilename(export_dir));
std::vector<std::pair<string, Tensor>> inputs = {
- {variables_filename_const_op_name.ToString(), variables_tensor}};
+ {string(variables_filename_const_op_name), variables_tensor}};
AddAssetsTensorsToInputs(export_dir, asset_files, &inputs);
RunMetadata run_metadata;
- return session->Run(run_options, inputs, {}, {restore_op_name.ToString()},
+ return session->Run(run_options, inputs, {}, {string(restore_op_name)},
nullptr /* outputs */, &run_metadata);
}
@@ -152,7 +152,7 @@ Status RunInitOp(const RunOptions& run_options, const StringPiece export_dir,
std::vector<std::pair<string, Tensor>> inputs;
AddAssetsTensorsToInputs(export_dir, asset_files, &inputs);
RunMetadata run_metadata;
- return session->Run(run_options, inputs, {}, {init_op_name.ToString()},
+ return session->Run(run_options, inputs, {}, {string(init_op_name)},
nullptr /* outputs */, &run_metadata);
}
@@ -251,15 +251,14 @@ Status LoadSessionBundleFromPathUsingRunOptions(const SessionOptions& options,
auto log_and_count = [&](const string& status_str) {
LOG(INFO) << "Loading SessionBundle: " << status_str << ". Took "
<< load_latency_microsecs << " microseconds.";
- load_attempt_count->GetCell(export_dir.ToString(), status_str)
- ->IncrementBy(1);
+ load_attempt_count->GetCell(string(export_dir), status_str)->IncrementBy(1);
};
if (status.ok()) {
log_and_count(kLoadAttemptSuccess);
} else {
log_and_count(kLoadAttemptFail);
}
- load_latency->GetCell(export_dir.ToString())
+ load_latency->GetCell(string(export_dir))
->IncrementBy(load_latency_microsecs);
return status;
}
diff --git a/tensorflow/contrib/session_bundle/session_bundle_test.py b/tensorflow/contrib/session_bundle/session_bundle_test.py
index a57e8920c5..3c06ec048d 100644
--- a/tensorflow/contrib/session_bundle/session_bundle_test.py
+++ b/tensorflow/contrib/session_bundle/session_bundle_test.py
@@ -167,7 +167,7 @@ class SessionBundleLoadNoVarsTest(test.TestCase):
y = math_ops.subtract(w * x, 7.0, name="y") # pylint: disable=unused-variable
ops.add_to_collection("meta", "this is meta")
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
variables.global_variables_initializer().run()
new_graph_def = graph_util.convert_variables_to_constants(
session, g.as_graph_def(), ["y"])
diff --git a/tensorflow/contrib/signal/__init__.py b/tensorflow/contrib/signal/__init__.py
index 6a2080bcec..d088e74434 100644
--- a/tensorflow/contrib/signal/__init__.py
+++ b/tensorflow/contrib/signal/__init__.py
@@ -14,7 +14,9 @@
# ==============================================================================
"""Signal processing operations.
-See the @{$python/contrib.signal} guide.
+See the
+[Contrib Signal](https://tensorflow.org/api_guides/python/contrib.signal)
+guide.
@@frame
@@hamming_window
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
index d877831fce..a6ce45c203 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
@@ -416,12 +416,17 @@ class Image(ItemHandler):
def decode_image():
"""Decodes a image based on the headers."""
- return image_ops.decode_image(image_buffer, channels=self._channels)
+ return math_ops.cast(
+ image_ops.decode_image(image_buffer, channels=self._channels),
+ self._dtype)
def decode_jpeg():
"""Decodes a jpeg image with specified '_dct_method'."""
- return image_ops.decode_jpeg(
- image_buffer, channels=self._channels, dct_method=self._dct_method)
+ return math_ops.cast(
+ image_ops.decode_jpeg(
+ image_buffer,
+ channels=self._channels,
+ dct_method=self._dct_method), self._dtype)
def check_jpeg():
"""Checks if an image is jpeg."""
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
index d783d4fef4..826242c9d7 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
@@ -37,12 +37,12 @@ from tensorflow.python.platform import test
class TFExampleDecoderTest(test.TestCase):
def _EncodedFloatFeature(self, ndarray):
- return feature_pb2.Feature(float_list=feature_pb2.FloatList(
- value=ndarray.flatten().tolist()))
+ return feature_pb2.Feature(
+ float_list=feature_pb2.FloatList(value=ndarray.flatten().tolist()))
def _EncodedInt64Feature(self, ndarray):
- return feature_pb2.Feature(int64_list=feature_pb2.Int64List(
- value=ndarray.flatten().tolist()))
+ return feature_pb2.Feature(
+ int64_list=feature_pb2.Int64List(value=ndarray.flatten().tolist()))
def _EncodedBytesFeature(self, tf_encoded):
with self.test_session():
@@ -74,12 +74,14 @@ class TFExampleDecoderTest(test.TestCase):
if image_format in ['raw', 'RAW']:
return constant_op.constant(image.tostring(), dtype=dtypes.string)
- def GenerateImage(self, image_format, image_shape):
+ def GenerateImage(self, image_format, image_shape, image_dtype=np.uint8):
"""Generates an image and an example containing the encoded image.
Args:
image_format: the encoding format of the image.
image_shape: the shape of the image to generate.
+ image_dtype: the dtype of values in the image. Only 'raw' images can
+ have a dtype other than uint8.
Returns:
image: the generated image.
@@ -87,14 +89,18 @@ class TFExampleDecoderTest(test.TestCase):
serialized image and a feature key 'image/format' set to the image
encoding format ['jpeg', 'JPEG', 'png', 'PNG', 'raw'].
"""
+ assert image_format in ['raw', 'RAW'] or image_dtype == np.uint8
num_pixels = image_shape[0] * image_shape[1] * image_shape[2]
image = np.linspace(
- 0, num_pixels - 1, num=num_pixels).reshape(image_shape).astype(np.uint8)
+ 0, num_pixels - 1,
+ num=num_pixels).reshape(image_shape).astype(image_dtype)
tf_encoded = self._Encoder(image, image_format)
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/encoded': self._EncodedBytesFeature(tf_encoded),
- 'image/format': self._StringFeature(image_format)
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/encoded': self._EncodedBytesFeature(tf_encoded),
+ 'image/format': self._StringFeature(image_format)
+ }))
return image, example.SerializeToString()
@@ -168,8 +174,7 @@ class TFExampleDecoderTest(test.TestCase):
tf_decoded_image = self.DecodeExample(
serialized_example,
- tfexample_decoder.Image(
- shape=None, channels=channels),
+ tfexample_decoder.Image(shape=None, channels=channels),
image_format='jpeg')
self.assertEqual(tf_decoded_image.get_shape().ndims, 3)
@@ -225,27 +230,38 @@ class TFExampleDecoderTest(test.TestCase):
self.assertAllClose(image, decoded_image, atol=0)
- def testDecodeExampleWithJpegEncodingAt16BitCausesError(self):
+ def testDecodeExampleWithRawEncodingFloatDtype(self):
image_shape = (2, 3, 3)
- unused_image, serialized_example = self.GenerateImage(
+ image, serialized_example = self.GenerateImage(
+ image_format='raw', image_shape=image_shape, image_dtype=np.float32)
+
+ decoded_image = self.RunDecodeExample(
+ serialized_example,
+ tfexample_decoder.Image(shape=image_shape, dtype=dtypes.float32),
+ image_format='raw')
+
+ self.assertAllClose(image, decoded_image, atol=0)
+
+ def testDecodeExampleWithJpegEncodingAt16BitDoesNotCauseError(self):
+ image_shape = (2, 3, 3)
+ # The image has dtype uint8, but decoding it as uint16 should not cause problems.
+ image, serialized_example = self.GenerateImage(
image_format='jpeg', image_shape=image_shape)
- # decode_raw support uint16 now so ValueError will be thrown instead.
- with self.assertRaisesRegexp(
- ValueError,
- 'true_fn and false_fn must have the same type: uint16, uint8'):
- unused_decoded_image = self.RunDecodeExample(
- serialized_example,
- tfexample_decoder.Image(dtype=dtypes.uint16),
- image_format='jpeg')
+ decoded_image = self.RunDecodeExample(
+ serialized_example,
+ tfexample_decoder.Image(dtype=dtypes.uint16),
+ image_format='jpeg')
+ self.assertAllClose(image, decoded_image, atol=1.001)
def testDecodeExampleWithStringTensor(self):
tensor_shape = (2, 3, 1)
np_array = np.array([[['ab'], ['cd'], ['ef']],
[['ghi'], ['jkl'], ['mnop']]])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'labels': self._BytesFeature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'labels': self._BytesFeature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -259,7 +275,9 @@ class TFExampleDecoderTest(test.TestCase):
default_value=constant_op.constant(
'', shape=tensor_shape, dtype=dtypes.string))
}
- items_to_handlers = {'labels': tfexample_decoder.Tensor('labels'),}
+ items_to_handlers = {
+ 'labels': tfexample_decoder.Tensor('labels'),
+ }
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
[tf_labels] = decoder.decode(serialized_example, ['labels'])
@@ -271,9 +289,10 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithFloatTensor(self):
np_array = np.random.rand(2, 3, 1).astype('f')
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'array': self._EncodedFloatFeature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'array': self._EncodedFloatFeature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -282,7 +301,9 @@ class TFExampleDecoderTest(test.TestCase):
keys_to_features = {
'array': parsing_ops.FixedLenFeature(np_array.shape, dtypes.float32)
}
- items_to_handlers = {'array': tfexample_decoder.Tensor('array'),}
+ items_to_handlers = {
+ 'array': tfexample_decoder.Tensor('array'),
+ }
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
[tf_array] = decoder.decode(serialized_example, ['array'])
@@ -291,9 +312,10 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithInt64Tensor(self):
np_array = np.random.randint(1, 10, size=(2, 3, 1))
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'array': self._EncodedInt64Feature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'array': self._EncodedInt64Feature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -302,7 +324,9 @@ class TFExampleDecoderTest(test.TestCase):
keys_to_features = {
'array': parsing_ops.FixedLenFeature(np_array.shape, dtypes.int64)
}
- items_to_handlers = {'array': tfexample_decoder.Tensor('array'),}
+ items_to_handlers = {
+ 'array': tfexample_decoder.Tensor('array'),
+ }
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
[tf_array] = decoder.decode(serialized_example, ['array'])
@@ -311,9 +335,10 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithVarLenTensor(self):
np_array = np.array([[[1], [2], [3]], [[4], [5], [6]]])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'labels': self._EncodedInt64Feature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'labels': self._EncodedInt64Feature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -322,7 +347,9 @@ class TFExampleDecoderTest(test.TestCase):
keys_to_features = {
'labels': parsing_ops.VarLenFeature(dtype=dtypes.int64),
}
- items_to_handlers = {'labels': tfexample_decoder.Tensor('labels'),}
+ items_to_handlers = {
+ 'labels': tfexample_decoder.Tensor('labels'),
+ }
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
[tf_labels] = decoder.decode(serialized_example, ['labels'])
@@ -332,9 +359,10 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithFixLenTensorWithShape(self):
np_array = np.array([[1, 2, 3], [4, 5, 6]])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'labels': self._EncodedInt64Feature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'labels': self._EncodedInt64Feature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -342,12 +370,10 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'labels':
- parsing_ops.FixedLenFeature(
- np_array.shape, dtype=dtypes.int64),
+ parsing_ops.FixedLenFeature(np_array.shape, dtype=dtypes.int64),
}
items_to_handlers = {
- 'labels': tfexample_decoder.Tensor(
- 'labels', shape=np_array.shape),
+ 'labels': tfexample_decoder.Tensor('labels', shape=np_array.shape),
}
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
@@ -357,9 +383,10 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithVarLenTensorToDense(self):
np_array = np.array([[1, 2, 3], [4, 5, 6]])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'labels': self._EncodedInt64Feature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'labels': self._EncodedInt64Feature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -369,8 +396,7 @@ class TFExampleDecoderTest(test.TestCase):
'labels': parsing_ops.VarLenFeature(dtype=dtypes.int64),
}
items_to_handlers = {
- 'labels': tfexample_decoder.Tensor(
- 'labels', shape=np_array.shape),
+ 'labels': tfexample_decoder.Tensor('labels', shape=np_array.shape),
}
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
@@ -382,12 +408,18 @@ class TFExampleDecoderTest(test.TestCase):
np_image = np.random.rand(2, 3, 1).astype('f')
np_labels = np.array([[[1], [2], [3]], [[4], [5], [6]]])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image': self._EncodedFloatFeature(np_image),
- 'image/shape': self._EncodedInt64Feature(np.array(np_image.shape)),
- 'labels': self._EncodedInt64Feature(np_labels),
- 'labels/shape': self._EncodedInt64Feature(np.array(np_labels.shape)),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image':
+ self._EncodedFloatFeature(np_image),
+ 'image/shape':
+ self._EncodedInt64Feature(np.array(np_image.shape)),
+ 'labels':
+ self._EncodedInt64Feature(np_labels),
+ 'labels/shape':
+ self._EncodedInt64Feature(np.array(np_labels.shape)),
+ }))
serialized_example = example.SerializeToString()
@@ -401,11 +433,9 @@ class TFExampleDecoderTest(test.TestCase):
}
items_to_handlers = {
'image':
- tfexample_decoder.Tensor(
- 'image', shape_keys='image/shape'),
+ tfexample_decoder.Tensor('image', shape_keys='image/shape'),
'labels':
- tfexample_decoder.Tensor(
- 'labels', shape_keys='labels/shape'),
+ tfexample_decoder.Tensor('labels', shape_keys='labels/shape'),
}
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
@@ -419,14 +449,22 @@ class TFExampleDecoderTest(test.TestCase):
np_labels = np.array([[[1], [2], [3]], [[4], [5], [6]]])
height, width, depth = np_labels.shape
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image': self._EncodedFloatFeature(np_image),
- 'image/shape': self._EncodedInt64Feature(np.array(np_image.shape)),
- 'labels': self._EncodedInt64Feature(np_labels),
- 'labels/height': self._EncodedInt64Feature(np.array([height])),
- 'labels/width': self._EncodedInt64Feature(np.array([width])),
- 'labels/depth': self._EncodedInt64Feature(np.array([depth])),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image':
+ self._EncodedFloatFeature(np_image),
+ 'image/shape':
+ self._EncodedInt64Feature(np.array(np_image.shape)),
+ 'labels':
+ self._EncodedInt64Feature(np_labels),
+ 'labels/height':
+ self._EncodedInt64Feature(np.array([height])),
+ 'labels/width':
+ self._EncodedInt64Feature(np.array([width])),
+ 'labels/depth':
+ self._EncodedInt64Feature(np.array([depth])),
+ }))
serialized_example = example.SerializeToString()
@@ -442,8 +480,7 @@ class TFExampleDecoderTest(test.TestCase):
}
items_to_handlers = {
'image':
- tfexample_decoder.Tensor(
- 'image', shape_keys='image/shape'),
+ tfexample_decoder.Tensor('image', shape_keys='image/shape'),
'labels':
tfexample_decoder.Tensor(
'labels',
@@ -459,10 +496,12 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithSparseTensor(self):
np_indices = np.array([[1], [2], [5]])
np_values = np.array([0.1, 0.2, 0.6]).astype('f')
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'indices': self._EncodedInt64Feature(np_indices),
- 'values': self._EncodedFloatFeature(np_values),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'indices': self._EncodedInt64Feature(np_indices),
+ 'values': self._EncodedFloatFeature(np_values),
+ }))
serialized_example = example.SerializeToString()
@@ -472,7 +511,9 @@ class TFExampleDecoderTest(test.TestCase):
'indices': parsing_ops.VarLenFeature(dtype=dtypes.int64),
'values': parsing_ops.VarLenFeature(dtype=dtypes.float32),
}
- items_to_handlers = {'labels': tfexample_decoder.SparseTensor(),}
+ items_to_handlers = {
+ 'labels': tfexample_decoder.SparseTensor(),
+ }
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
[tf_labels] = decoder.decode(serialized_example, ['labels'])
@@ -485,11 +526,13 @@ class TFExampleDecoderTest(test.TestCase):
np_indices = np.array([[1], [2], [5]])
np_values = np.array([0.1, 0.2, 0.6]).astype('f')
np_shape = np.array([6])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'indices': self._EncodedInt64Feature(np_indices),
- 'values': self._EncodedFloatFeature(np_values),
- 'shape': self._EncodedInt64Feature(np_shape),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'indices': self._EncodedInt64Feature(np_indices),
+ 'values': self._EncodedFloatFeature(np_values),
+ 'shape': self._EncodedInt64Feature(np_shape),
+ }))
serialized_example = example.SerializeToString()
@@ -515,10 +558,12 @@ class TFExampleDecoderTest(test.TestCase):
np_indices = np.array([[1], [2], [5]])
np_values = np.array([0.1, 0.2, 0.6]).astype('f')
np_shape = np.array([6])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'indices': self._EncodedInt64Feature(np_indices),
- 'values': self._EncodedFloatFeature(np_values),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'indices': self._EncodedInt64Feature(np_indices),
+ 'values': self._EncodedFloatFeature(np_values),
+ }))
serialized_example = example.SerializeToString()
@@ -544,10 +589,12 @@ class TFExampleDecoderTest(test.TestCase):
np_values = np.array([0.1, 0.2, 0.6]).astype('f')
np_shape = np.array([6])
np_dense = np.array([0.0, 0.1, 0.2, 0.0, 0.0, 0.6]).astype('f')
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'indices': self._EncodedInt64Feature(np_indices),
- 'values': self._EncodedFloatFeature(np_values),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'indices': self._EncodedInt64Feature(np_indices),
+ 'values': self._EncodedFloatFeature(np_values),
+ }))
serialized_example = example.SerializeToString()
@@ -559,8 +606,7 @@ class TFExampleDecoderTest(test.TestCase):
}
items_to_handlers = {
'labels':
- tfexample_decoder.SparseTensor(
- shape=np_shape, densify=True),
+ tfexample_decoder.SparseTensor(shape=np_shape, densify=True),
}
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
@@ -572,9 +618,10 @@ class TFExampleDecoderTest(test.TestCase):
tensor_shape = (2, 3, 1)
np_array = np.random.rand(2, 3, 1)
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/depth_map': self._EncodedFloatFeature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'image/depth_map': self._EncodedFloatFeature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -603,9 +650,10 @@ class TFExampleDecoderTest(test.TestCase):
tensor_shape = (2, 3, 1)
np_array = np.random.rand(2, 3, 1)
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/depth_map': self._EncodedFloatFeature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'image/depth_map': self._EncodedFloatFeature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -701,12 +749,14 @@ class TFExampleDecoderTest(test.TestCase):
np_xmax = np.random.rand(num_bboxes, 1)
np_bboxes = np.hstack([np_ymin, np_xmin, np_ymax, np_xmax])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/object/bbox/ymin': self._EncodedFloatFeature(np_ymin),
- 'image/object/bbox/xmin': self._EncodedFloatFeature(np_xmin),
- 'image/object/bbox/ymax': self._EncodedFloatFeature(np_ymax),
- 'image/object/bbox/xmax': self._EncodedFloatFeature(np_xmax),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/object/bbox/ymin': self._EncodedFloatFeature(np_ymin),
+ 'image/object/bbox/xmin': self._EncodedFloatFeature(np_xmin),
+ 'image/object/bbox/ymax': self._EncodedFloatFeature(np_ymax),
+ 'image/object/bbox/xmax': self._EncodedFloatFeature(np_xmax),
+ }))
serialized_example = example.SerializeToString()
with self.test_session():
@@ -740,26 +790,32 @@ class TFExampleDecoderTest(test.TestCase):
np_xmax = np.random.rand(num_bboxes, 1)
np_bboxes = np.hstack([np_ymin, np_xmin, np_ymax, np_xmax])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/object/bbox/ymin': self._EncodedFloatFeature(np_ymin),
- 'image/object/bbox/xmin': self._EncodedFloatFeature(np_xmin),
- 'image/object/bbox/ymax': self._EncodedFloatFeature(np_ymax),
- 'image/object/bbox/xmax': self._EncodedFloatFeature(np_xmax),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/object/bbox/ymin': self._EncodedFloatFeature(np_ymin),
+ 'image/object/bbox/xmin': self._EncodedFloatFeature(np_xmin),
+ 'image/object/bbox/ymax': self._EncodedFloatFeature(np_ymax),
+ 'image/object/bbox/xmax': self._EncodedFloatFeature(np_xmax),
+ }))
serialized_example = example.SerializeToString()
with self.test_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
- 'image/object/bbox/ymin': parsing_ops.FixedLenSequenceFeature(
- [], dtypes.float32, allow_missing=True),
- 'image/object/bbox/xmin': parsing_ops.FixedLenSequenceFeature(
- [], dtypes.float32, allow_missing=True),
- 'image/object/bbox/ymax': parsing_ops.FixedLenSequenceFeature(
- [], dtypes.float32, allow_missing=True),
- 'image/object/bbox/xmax': parsing_ops.FixedLenSequenceFeature(
- [], dtypes.float32, allow_missing=True),
+ 'image/object/bbox/ymin':
+ parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
+ 'image/object/bbox/xmin':
+ parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
+ 'image/object/bbox/ymax':
+ parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
+ 'image/object/bbox/xmax':
+ parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
}
items_to_handlers = {
@@ -784,11 +840,16 @@ class TFExampleDecoderTest(test.TestCase):
with self.test_session():
tf_string = tf_encoded.eval()
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/encoded': feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
- value=[tf_string, tf_string])),
- 'image/format': self._StringFeature(image_format),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/encoded':
+ feature_pb2.Feature(
+ bytes_list=feature_pb2.BytesList(
+ value=[tf_string, tf_string])),
+ 'image/format':
+ self._StringFeature(image_format),
+ }))
serialized_example = example.SerializeToString()
with self.test_session():
@@ -797,8 +858,7 @@ class TFExampleDecoderTest(test.TestCase):
decoder = tfexample_decoder.TFExampleDecoder(
keys_to_features={
'image/encoded':
- parsing_ops.FixedLenFeature(
- (2,), dtypes.string),
+ parsing_ops.FixedLenFeature((2,), dtypes.string),
'image/format':
parsing_ops.FixedLenFeature(
(), dtypes.string, default_value=image_format),
@@ -814,10 +874,12 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithLookup(self):
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/object/class/text': self._BytesFeature(
- np.array(['cat', 'dog', 'guinea pig'])),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/object/class/text':
+ self._BytesFeature(np.array(['cat', 'dog', 'guinea pig'])),
+ }))
serialized_example = example.SerializeToString()
# 'dog' -> 0, 'guinea pig' -> 1, 'cat' -> 2
table = lookup_ops.index_table_from_tensor(
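The new raw-float path exercised by testDecodeExampleWithRawEncodingFloatDtype above can also be driven directly. A minimal sketch using the same slim modules this test imports (parsing_ops, dtypes, tfexample_decoder); `serialized_example` stands in for a serialized tf.train.Example whose 'image/encoded' feature holds raw float32 bytes with shape (2, 3, 3):

    keys_to_features = {
        'image/encoded':
            parsing_ops.FixedLenFeature((), dtypes.string),
        'image/format':
            parsing_ops.FixedLenFeature((), dtypes.string, default_value='raw'),
    }
    items_to_handlers = {
        # dtype must match what was packed into the raw bytes.
        'image': tfexample_decoder.Image(shape=(2, 3, 3), dtype=dtypes.float32),
    }
    decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
                                                 items_to_handlers)
    [image] = decoder.decode(serialized_example, ['image'])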
diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py
index 2c97834523..cbfdaeb45d 100644
--- a/tensorflow/contrib/slim/python/slim/evaluation_test.py
+++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py
@@ -100,7 +100,7 @@ class EvaluationTest(test.TestCase):
# Save initialized variables to a checkpoint directory:
saver = saver_lib.Saver()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
init_op.run()
saver.save(sess, os.path.join(chkpt_dir, 'chkpt'))
@@ -211,7 +211,7 @@ class EvaluationTest(test.TestCase):
# Save initialized variables to a checkpoint directory:
saver = saver_lib.Saver()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
init_op.run()
saver.save(sess, os.path.join(chkpt_dir, 'chkpt'))
@@ -248,7 +248,7 @@ class SingleEvaluationTest(test.TestCase):
init_op = control_flow_ops.group(variables.global_variables_initializer(),
variables.local_variables_initializer())
saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
saver.save(sess, checkpoint_path)
diff --git a/tensorflow/contrib/slim/python/slim/learning_test.py b/tensorflow/contrib/slim/python/slim/learning_test.py
index 831c6e427a..d92a7fbb47 100644
--- a/tensorflow/contrib/slim/python/slim/learning_test.py
+++ b/tensorflow/contrib/slim/python/slim/learning_test.py
@@ -73,7 +73,7 @@ class ClipGradientNormsTest(test.TestCase):
# Ensure the variable passed through.
self.assertEqual(gradients_to_variables[1], variable)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual_gradient = sess.run(gradients_to_variables[0])
np_testing.assert_almost_equal(actual_gradient, self._clipped_grad_vec)
@@ -164,7 +164,7 @@ class MultiplyGradientsTest(test.TestCase):
# Ensure the variable passed through.
self.assertEqual(grad_to_var[1], variable)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual_gradient = sess.run(grad_to_var[0])
np_testing.assert_almost_equal(actual_gradient, self._multiplied_grad_vec,
5)
@@ -188,7 +188,7 @@ class MultiplyGradientsTest(test.TestCase):
self.assertEqual(grad_to_var[0].indices, indices)
self.assertEqual(grad_to_var[0].dense_shape, dense_shape)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual_gradient = sess.run(grad_to_var[0].values)
np_testing.assert_almost_equal(actual_gradient, self._multiplied_grad_vec,
5)
@@ -204,7 +204,7 @@ class MultiplyGradientsTest(test.TestCase):
[grad_to_var] = learning.multiply_gradients([grad_to_var],
gradient_multipliers)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
gradient_true_flag = sess.run(grad_to_var[0])
sess.run(multiplier_flag.assign(False))
diff --git a/tensorflow/contrib/slim/python/slim/nets/alexnet_test.py b/tensorflow/contrib/slim/python/slim/nets/alexnet_test.py
index eb93f753ae..b6d1afd27d 100644
--- a/tensorflow/contrib/slim/python/slim/nets/alexnet_test.py
+++ b/tensorflow/contrib/slim/python/slim/nets/alexnet_test.py
@@ -33,7 +33,7 @@ class AlexnetV2Test(test.TestCase):
batch_size = 5
height, width = 224, 224
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((batch_size, height, width, 3))
logits, _ = alexnet.alexnet_v2(inputs, num_classes)
self.assertEquals(logits.op.name, 'alexnet_v2/fc8/squeezed')
@@ -44,7 +44,7 @@ class AlexnetV2Test(test.TestCase):
batch_size = 1
height, width = 300, 400
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((batch_size, height, width, 3))
logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False)
self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd')
@@ -55,7 +55,7 @@ class AlexnetV2Test(test.TestCase):
batch_size = 5
height, width = 224, 224
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((batch_size, height, width, 3))
_, end_points = alexnet.alexnet_v2(inputs, num_classes)
expected_names = [
@@ -70,7 +70,7 @@ class AlexnetV2Test(test.TestCase):
batch_size = 5
height, width = 224, 224
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((batch_size, height, width, 3))
alexnet.alexnet_v2(inputs, num_classes)
expected_names = [
@@ -98,7 +98,7 @@ class AlexnetV2Test(test.TestCase):
batch_size = 2
height, width = 224, 224
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
eval_inputs = random_ops.random_uniform((batch_size, height, width, 3))
logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False)
self.assertListEqual(logits.get_shape().as_list(),
@@ -112,7 +112,7 @@ class AlexnetV2Test(test.TestCase):
train_height, train_width = 224, 224
eval_height, eval_width = 300, 400
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
train_inputs = random_ops.random_uniform(
(train_batch_size, train_height, train_width, 3))
logits, _ = alexnet.alexnet_v2(train_inputs)
@@ -132,7 +132,7 @@ class AlexnetV2Test(test.TestCase):
def testForward(self):
batch_size = 1
height, width = 224, 224
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = random_ops.random_uniform((batch_size, height, width, 3))
logits, _ = alexnet.alexnet_v2(inputs)
sess.run(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/slim/python/slim/nets/inception_v1_test.py b/tensorflow/contrib/slim/python/slim/nets/inception_v1_test.py
index 7a3d1c9703..34f12d7591 100644
--- a/tensorflow/contrib/slim/python/slim/nets/inception_v1_test.py
+++ b/tensorflow/contrib/slim/python/slim/nets/inception_v1_test.py
@@ -143,7 +143,7 @@ class InceptionV1Test(test.TestCase):
height, width = 224, 224
num_classes = 1000
input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = array_ops.placeholder(
dtypes.float32, shape=(batch_size, None, None, 3))
logits, end_points = inception_v1.inception_v1(inputs, num_classes)
@@ -167,7 +167,7 @@ class InceptionV1Test(test.TestCase):
self.assertListEqual(logits.get_shape().as_list(), [None, num_classes])
images = random_ops.random_uniform((batch_size, height, width, 3))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
output = sess.run(logits, {inputs: images.eval()})
self.assertEquals(output.shape, (batch_size, num_classes))
@@ -182,7 +182,7 @@ class InceptionV1Test(test.TestCase):
eval_inputs, num_classes, is_training=False)
predictions = math_ops.argmax(logits, 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
output = sess.run(predictions)
self.assertEquals(output.shape, (batch_size,))
@@ -200,7 +200,7 @@ class InceptionV1Test(test.TestCase):
logits, _ = inception_v1.inception_v1(eval_inputs, num_classes, reuse=True)
predictions = math_ops.argmax(logits, 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
output = sess.run(predictions)
self.assertEquals(output.shape, (eval_batch_size,))
@@ -211,7 +211,7 @@ class InceptionV1Test(test.TestCase):
logits, _ = inception_v1.inception_v1(
images, num_classes=num_classes, spatial_squeeze=False)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
logits_out = sess.run(logits)
self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes])
diff --git a/tensorflow/contrib/slim/python/slim/nets/inception_v2_test.py b/tensorflow/contrib/slim/python/slim/nets/inception_v2_test.py
index 5fbc9e5aa3..66effba944 100644
--- a/tensorflow/contrib/slim/python/slim/nets/inception_v2_test.py
+++ b/tensorflow/contrib/slim/python/slim/nets/inception_v2_test.py
@@ -196,7 +196,7 @@ class InceptionV2Test(test.TestCase):
height, width = 224, 224
num_classes = 1000
input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = array_ops.placeholder(
dtypes.float32, shape=(batch_size, None, None, 3))
logits, end_points = inception_v2.inception_v2(inputs, num_classes)
@@ -220,7 +220,7 @@ class InceptionV2Test(test.TestCase):
self.assertListEqual(logits.get_shape().as_list(), [None, num_classes])
images = random_ops.random_uniform((batch_size, height, width, 3))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
output = sess.run(logits, {inputs: images.eval()})
self.assertEquals(output.shape, (batch_size, num_classes))
@@ -235,7 +235,7 @@ class InceptionV2Test(test.TestCase):
eval_inputs, num_classes, is_training=False)
predictions = math_ops.argmax(logits, 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
output = sess.run(predictions)
self.assertEquals(output.shape, (batch_size,))
@@ -253,7 +253,7 @@ class InceptionV2Test(test.TestCase):
logits, _ = inception_v2.inception_v2(eval_inputs, num_classes, reuse=True)
predictions = math_ops.argmax(logits, 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
output = sess.run(predictions)
self.assertEquals(output.shape, (eval_batch_size,))
@@ -264,7 +264,7 @@ class InceptionV2Test(test.TestCase):
logits, _ = inception_v2.inception_v2(
images, num_classes=num_classes, spatial_squeeze=False)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
logits_out = sess.run(logits)
self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes])
diff --git a/tensorflow/contrib/slim/python/slim/nets/inception_v3_test.py b/tensorflow/contrib/slim/python/slim/nets/inception_v3_test.py
index 6ba02318ed..0f9cca7bbd 100644
--- a/tensorflow/contrib/slim/python/slim/nets/inception_v3_test.py
+++ b/tensorflow/contrib/slim/python/slim/nets/inception_v3_test.py
@@ -226,7 +226,7 @@ class InceptionV3Test(test.TestCase):
height, width = 299, 299
num_classes = 1000
input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = array_ops.placeholder(
dtypes.float32, shape=(batch_size, None, None, 3))
logits, end_points = inception_v3.inception_v3(inputs, num_classes)
@@ -249,7 +249,7 @@ class InceptionV3Test(test.TestCase):
self.assertListEqual(logits.get_shape().as_list(), [None, num_classes])
images = random_ops.random_uniform((batch_size, height, width, 3))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
output = sess.run(logits, {inputs: images.eval()})
self.assertEquals(output.shape, (batch_size, num_classes))
@@ -264,7 +264,7 @@ class InceptionV3Test(test.TestCase):
eval_inputs, num_classes, is_training=False)
predictions = math_ops.argmax(logits, 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
output = sess.run(predictions)
self.assertEquals(output.shape, (batch_size,))
@@ -283,7 +283,7 @@ class InceptionV3Test(test.TestCase):
eval_inputs, num_classes, is_training=False, reuse=True)
predictions = math_ops.argmax(logits, 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
output = sess.run(predictions)
self.assertEquals(output.shape, (eval_batch_size,))
@@ -294,7 +294,7 @@ class InceptionV3Test(test.TestCase):
logits, _ = inception_v3.inception_v3(
images, num_classes=num_classes, spatial_squeeze=False)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
logits_out = sess.run(logits)
self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes])
diff --git a/tensorflow/contrib/slim/python/slim/nets/overfeat_test.py b/tensorflow/contrib/slim/python/slim/nets/overfeat_test.py
index 317af3cb29..44fa35ad14 100644
--- a/tensorflow/contrib/slim/python/slim/nets/overfeat_test.py
+++ b/tensorflow/contrib/slim/python/slim/nets/overfeat_test.py
@@ -33,7 +33,7 @@ class OverFeatTest(test.TestCase):
batch_size = 5
height, width = 231, 231
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((batch_size, height, width, 3))
logits, _ = overfeat.overfeat(inputs, num_classes)
self.assertEquals(logits.op.name, 'overfeat/fc8/squeezed')
@@ -44,7 +44,7 @@ class OverFeatTest(test.TestCase):
batch_size = 1
height, width = 281, 281
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((batch_size, height, width, 3))
logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False)
self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd')
@@ -55,7 +55,7 @@ class OverFeatTest(test.TestCase):
batch_size = 5
height, width = 231, 231
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((batch_size, height, width, 3))
_, end_points = overfeat.overfeat(inputs, num_classes)
expected_names = [
@@ -70,7 +70,7 @@ class OverFeatTest(test.TestCase):
batch_size = 5
height, width = 231, 231
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((batch_size, height, width, 3))
overfeat.overfeat(inputs, num_classes)
expected_names = [
@@ -98,7 +98,7 @@ class OverFeatTest(test.TestCase):
batch_size = 2
height, width = 231, 231
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
eval_inputs = random_ops.random_uniform((batch_size, height, width, 3))
logits, _ = overfeat.overfeat(eval_inputs, is_training=False)
self.assertListEqual(logits.get_shape().as_list(),
@@ -112,7 +112,7 @@ class OverFeatTest(test.TestCase):
train_height, train_width = 231, 231
eval_height, eval_width = 281, 281
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
train_inputs = random_ops.random_uniform(
(train_batch_size, train_height, train_width, 3))
logits, _ = overfeat.overfeat(train_inputs)
@@ -132,7 +132,7 @@ class OverFeatTest(test.TestCase):
def testForward(self):
batch_size = 1
height, width = 231, 231
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = random_ops.random_uniform((batch_size, height, width, 3))
logits, _ = overfeat.overfeat(inputs)
sess.run(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py b/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py
index 576444214d..8ff44fe4b5 100644
--- a/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py
+++ b/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py
@@ -69,7 +69,7 @@ class ResnetUtilsTest(test.TestCase):
x = resnet_utils.subsample(x, 2)
expected = array_ops.reshape(
constant_op.constant([0, 2, 6, 8]), [1, 2, 2, 1])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(x.eval(), expected.eval())
def testSubsampleFourByFour(self):
@@ -77,7 +77,7 @@ class ResnetUtilsTest(test.TestCase):
x = resnet_utils.subsample(x, 2)
expected = array_ops.reshape(
constant_op.constant([0, 2, 8, 10]), [1, 2, 2, 1])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(x.eval(), expected.eval())
def testConv2DSameEven(self):
@@ -110,7 +110,7 @@ class ResnetUtilsTest(test.TestCase):
y4_expected = math_ops.to_float([[48, 37], [37, 22]])
y4_expected = array_ops.reshape(y4_expected, [1, n2, n2, 1])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
self.assertAllClose(y1.eval(), y1_expected.eval())
self.assertAllClose(y2.eval(), y2_expected.eval())
@@ -148,7 +148,7 @@ class ResnetUtilsTest(test.TestCase):
y4 = layers.conv2d(x, 1, [3, 3], stride=2, scope='Conv')
y4_expected = y2_expected
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
self.assertAllClose(y1.eval(), y1_expected.eval())
self.assertAllClose(y2.eval(), y2_expected.eval())
@@ -223,7 +223,7 @@ class ResnetUtilsTest(test.TestCase):
with arg_scope([layers.batch_norm], is_training=False):
for output_stride in [1, 2, 4, 8, None]:
with ops.Graph().as_default():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
random_seed.set_random_seed(0)
inputs = create_test_input(1, height, width, 3)
# Dense feature extraction followed by subsampling.
@@ -364,7 +364,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
for output_stride in [4, 8, 16, 32, None]:
with arg_scope(resnet_utils.resnet_arg_scope()):
with ops.Graph().as_default():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
random_seed.set_random_seed(0)
inputs = create_test_input(2, 81, 81, 3)
# Dense feature extraction followed by subsampling.
@@ -401,7 +401,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
self.assertListEqual(logits.get_shape().as_list(),
[None, 1, 1, num_classes])
images = create_test_input(batch, height, width, 3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
output = sess.run(logits, {inputs: images.eval()})
self.assertEqual(output.shape, (batch, 1, 1, num_classes))
@@ -415,7 +415,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
output, _ = self._resnet_small(inputs, None, global_pool=global_pool)
self.assertListEqual(output.get_shape().as_list(), [batch, None, None, 32])
images = create_test_input(batch, height, width, 3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
output = sess.run(output, {inputs: images.eval()})
self.assertEqual(output.shape, (batch, 3, 3, 32))
@@ -431,7 +431,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
inputs, None, global_pool=global_pool, output_stride=output_stride)
self.assertListEqual(output.get_shape().as_list(), [batch, None, None, 32])
images = create_test_input(batch, height, width, 3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
output = sess.run(output, {inputs: images.eval()})
self.assertEqual(output.shape, (batch, 9, 9, 32))
diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py b/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py
index 6bdda18c5b..055ecff1c3 100644
--- a/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py
+++ b/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py
@@ -69,7 +69,7 @@ class ResnetUtilsTest(test.TestCase):
x = resnet_utils.subsample(x, 2)
expected = array_ops.reshape(
constant_op.constant([0, 2, 6, 8]), [1, 2, 2, 1])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(x.eval(), expected.eval())
def testSubsampleFourByFour(self):
@@ -77,7 +77,7 @@ class ResnetUtilsTest(test.TestCase):
x = resnet_utils.subsample(x, 2)
expected = array_ops.reshape(
constant_op.constant([0, 2, 8, 10]), [1, 2, 2, 1])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(x.eval(), expected.eval())
def testConv2DSameEven(self):
@@ -110,7 +110,7 @@ class ResnetUtilsTest(test.TestCase):
y4_expected = math_ops.to_float([[48, 37], [37, 22]])
y4_expected = array_ops.reshape(y4_expected, [1, n2, n2, 1])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
self.assertAllClose(y1.eval(), y1_expected.eval())
self.assertAllClose(y2.eval(), y2_expected.eval())
@@ -151,7 +151,7 @@ class ResnetUtilsTest(test.TestCase):
y4 = layers.conv2d(x, 1, [3, 3], stride=2, scope='Conv')
y4_expected = y2_expected
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
self.assertAllClose(y1.eval(), y1_expected.eval())
self.assertAllClose(y2.eval(), y2_expected.eval())
@@ -227,7 +227,7 @@ class ResnetUtilsTest(test.TestCase):
with arg_scope([layers.batch_norm], is_training=False):
for output_stride in [1, 2, 4, 8, None]:
with ops.Graph().as_default():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
random_seed.set_random_seed(0)
inputs = create_test_input(1, height, width, 3)
# Dense feature extraction followed by subsampling.
@@ -368,7 +368,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
for output_stride in [4, 8, 16, 32, None]:
with arg_scope(resnet_utils.resnet_arg_scope()):
with ops.Graph().as_default():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
random_seed.set_random_seed(0)
inputs = create_test_input(2, 81, 81, 3)
# Dense feature extraction followed by subsampling.
@@ -405,7 +405,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
self.assertListEqual(logits.get_shape().as_list(),
[None, 1, 1, num_classes])
images = create_test_input(batch, height, width, 3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
output = sess.run(logits, {inputs: images.eval()})
self.assertEqual(output.shape, (batch, 1, 1, num_classes))
@@ -419,7 +419,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
output, _ = self._resnet_small(inputs, None, global_pool=global_pool)
self.assertListEqual(output.get_shape().as_list(), [batch, None, None, 32])
images = create_test_input(batch, height, width, 3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
output = sess.run(output, {inputs: images.eval()})
self.assertEqual(output.shape, (batch, 3, 3, 32))
@@ -435,7 +435,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
inputs, None, global_pool=global_pool, output_stride=output_stride)
self.assertListEqual(output.get_shape().as_list(), [batch, None, None, 32])
images = create_test_input(batch, height, width, 3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
output = sess.run(output, {inputs: images.eval()})
self.assertEqual(output.shape, (batch, 9, 9, 32))
diff --git a/tensorflow/contrib/slim/python/slim/nets/vgg_test.py b/tensorflow/contrib/slim/python/slim/nets/vgg_test.py
index 36628b32d1..71ce4b89cd 100644
--- a/tensorflow/contrib/slim/python/slim/nets/vgg_test.py
+++ b/tensorflow/contrib/slim/python/slim/nets/vgg_test.py
@@ -34,7 +34,7 @@ class VGGATest(test.TestCase):
batch_size = 5
height, width = 224, 224
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((batch_size, height, width, 3))
logits, _ = vgg.vgg_a(inputs, num_classes)
self.assertEquals(logits.op.name, 'vgg_a/fc8/squeezed')
@@ -45,7 +45,7 @@ class VGGATest(test.TestCase):
batch_size = 1
height, width = 256, 256
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((batch_size, height, width, 3))
logits, _ = vgg.vgg_a(inputs, num_classes, spatial_squeeze=False)
self.assertEquals(logits.op.name, 'vgg_a/fc8/BiasAdd')
@@ -73,7 +73,7 @@ class VGGATest(test.TestCase):
batch_size = 5
height, width = 224, 224
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((batch_size, height, width, 3))
vgg.vgg_a(inputs, num_classes)
expected_names = [
@@ -107,7 +107,7 @@ class VGGATest(test.TestCase):
batch_size = 2
height, width = 224, 224
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
eval_inputs = random_ops.random_uniform((batch_size, height, width, 3))
logits, _ = vgg.vgg_a(eval_inputs, is_training=False)
self.assertListEqual(logits.get_shape().as_list(),
@@ -121,7 +121,7 @@ class VGGATest(test.TestCase):
train_height, train_width = 224, 224
eval_height, eval_width = 256, 256
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
train_inputs = random_ops.random_uniform(
(train_batch_size, train_height, train_width, 3))
logits, _ = vgg.vgg_a(train_inputs)
@@ -141,7 +141,7 @@ class VGGATest(test.TestCase):
def testForward(self):
batch_size = 1
height, width = 224, 224
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = random_ops.random_uniform((batch_size, height, width, 3))
logits, _ = vgg.vgg_a(inputs)
sess.run(variables.global_variables_initializer())
@@ -155,7 +155,7 @@ class VGG16Test(test.TestCase):
batch_size = 5
height, width = 224, 224
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((batch_size, height, width, 3))
logits, _ = vgg.vgg_16(inputs, num_classes)
self.assertEquals(logits.op.name, 'vgg_16/fc8/squeezed')
@@ -166,7 +166,7 @@ class VGG16Test(test.TestCase):
batch_size = 1
height, width = 256, 256
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((batch_size, height, width, 3))
logits, _ = vgg.vgg_16(inputs, num_classes, spatial_squeeze=False)
self.assertEquals(logits.op.name, 'vgg_16/fc8/BiasAdd')
@@ -197,7 +197,7 @@ class VGG16Test(test.TestCase):
batch_size = 5
height, width = 224, 224
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((batch_size, height, width, 3))
vgg.vgg_16(inputs, num_classes)
expected_names = [
@@ -241,7 +241,7 @@ class VGG16Test(test.TestCase):
batch_size = 2
height, width = 224, 224
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
eval_inputs = random_ops.random_uniform((batch_size, height, width, 3))
logits, _ = vgg.vgg_16(eval_inputs, is_training=False)
self.assertListEqual(logits.get_shape().as_list(),
@@ -255,7 +255,7 @@ class VGG16Test(test.TestCase):
train_height, train_width = 224, 224
eval_height, eval_width = 256, 256
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
train_inputs = random_ops.random_uniform(
(train_batch_size, train_height, train_width, 3))
logits, _ = vgg.vgg_16(train_inputs)
@@ -275,7 +275,7 @@ class VGG16Test(test.TestCase):
def testForward(self):
batch_size = 1
height, width = 224, 224
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = random_ops.random_uniform((batch_size, height, width, 3))
logits, _ = vgg.vgg_16(inputs)
sess.run(variables.global_variables_initializer())
@@ -289,7 +289,7 @@ class VGG19Test(test.TestCase):
batch_size = 5
height, width = 224, 224
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((batch_size, height, width, 3))
logits, _ = vgg.vgg_19(inputs, num_classes)
self.assertEquals(logits.op.name, 'vgg_19/fc8/squeezed')
@@ -300,7 +300,7 @@ class VGG19Test(test.TestCase):
batch_size = 1
height, width = 256, 256
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((batch_size, height, width, 3))
logits, _ = vgg.vgg_19(inputs, num_classes, spatial_squeeze=False)
self.assertEquals(logits.op.name, 'vgg_19/fc8/BiasAdd')
@@ -332,7 +332,7 @@ class VGG19Test(test.TestCase):
batch_size = 5
height, width = 224, 224
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((batch_size, height, width, 3))
vgg.vgg_19(inputs, num_classes)
expected_names = [
@@ -382,7 +382,7 @@ class VGG19Test(test.TestCase):
batch_size = 2
height, width = 224, 224
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
eval_inputs = random_ops.random_uniform((batch_size, height, width, 3))
logits, _ = vgg.vgg_19(eval_inputs, is_training=False)
self.assertListEqual(logits.get_shape().as_list(),
@@ -396,7 +396,7 @@ class VGG19Test(test.TestCase):
train_height, train_width = 224, 224
eval_height, eval_width = 256, 256
num_classes = 1000
- with self.test_session():
+ with self.cached_session():
train_inputs = random_ops.random_uniform(
(train_batch_size, train_height, train_width, 3))
logits, _ = vgg.vgg_19(train_inputs)
@@ -416,7 +416,7 @@ class VGG19Test(test.TestCase):
def testForward(self):
batch_size = 1
height, width = 224, 224
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = random_ops.random_uniform((batch_size, height, width, 3))
logits, _ = vgg.vgg_19(inputs)
sess.run(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/slim/python/slim/summaries_test.py b/tensorflow/contrib/slim/python/slim/summaries_test.py
index 873ee78de2..c6017f073e 100644
--- a/tensorflow/contrib/slim/python/slim/summaries_test.py
+++ b/tensorflow/contrib/slim/python/slim/summaries_test.py
@@ -88,7 +88,7 @@ class SummariesTest(test.TestCase):
summary_op = summary.merge_all()
summary_writer = summary.FileWriter(output_dir)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
new_summary = sess.run(summary_op)
summary_writer.add_summary(new_summary, 1)
summary_writer.flush()
diff --git a/tensorflow/contrib/specs/python/specs_test.py b/tensorflow/contrib/specs/python/specs_test.py
index 9a4ad36793..b7ce6aa20a 100644
--- a/tensorflow/contrib/specs/python/specs_test.py
+++ b/tensorflow/contrib/specs/python/specs_test.py
@@ -38,7 +38,7 @@ def _rand(*size):
class SpecsTest(test.TestCase):
def testSimpleConv(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(1, 18, 19, 5))
spec = "net = Cr(64, [5, 5])"
outputs = specs.create_net(spec, inputs)
@@ -53,7 +53,7 @@ class SpecsTest(test.TestCase):
def testUnary(self):
# This is just a quick and dirty check that these ops exist
# and work as unary ops.
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(17, 55))
spec = "net = Do(0.5) | Bn | Unit(1) | Relu | Sig | Tanh | Smax"
outputs = specs.create_net(spec, inputs)
@@ -63,7 +63,7 @@ class SpecsTest(test.TestCase):
self.assertEqual(tuple(result.shape), (17, 55))
def testAdd(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(17, 55))
spec = "net = Fs(10) + Fr(10)"
outputs = specs.create_net(spec, inputs)
@@ -77,7 +77,7 @@ class SpecsTest(test.TestCase):
"<> variablev2 dot variablev2 biasadd relu add")
def testMpPower(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(1, 64, 64, 5))
spec = "M2 = Mp([2, 2]); net = M2**3"
outputs = specs.create_net(spec, inputs)
@@ -90,7 +90,7 @@ class SpecsTest(test.TestCase):
"_ maxpool maxpool maxpool")
def testAbbrevPower(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(1, 64, 64, 5))
spec = "C3 = Cr([3, 3]); M2 = Mp([2, 2]); net = (C3(5) | M2)**3"
outputs = specs.create_net(spec, inputs)
@@ -106,7 +106,7 @@ class SpecsTest(test.TestCase):
" biasadd relu maxpool")
def testAbbrevPower2(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(1, 64, 64, 5))
spec = "C3 = Cr(_1=[3, 3]); M2 = Mp([2, 2]);"
spec += "net = (C3(_0=5) | M2)**3"
@@ -123,7 +123,7 @@ class SpecsTest(test.TestCase):
" maxpool")
def testConc(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(10, 20))
spec = "net = Conc(1, Fs(20), Fs(10))"
outputs = specs.create_net(spec, inputs)
@@ -137,7 +137,7 @@ class SpecsTest(test.TestCase):
"<> variablev2 dot variablev2 biasadd sig _ concatv2")
def testImport(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(10, 20))
spec = ("S = Import('from tensorflow.python.ops" +
" import math_ops; f = math_ops.sigmoid')")
@@ -150,7 +150,7 @@ class SpecsTest(test.TestCase):
self.assertEqual(summaries.tf_spec_structure(spec, inputs), "_ sig sig")
def testKeywordRestriction(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(10, 20))
spec = "import re; net = Conc(1, Fs(20), Fs(10))"
self.assertRaises(ValueError, lambda: specs.create_net(spec, inputs))
@@ -179,7 +179,7 @@ class SpecsTest(test.TestCase):
# TODO: this test is disabled; the original author should fix and re-enable it.
def DISABLED_testVar(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with specs.ops:
# pylint: disable=undefined-variable
v = Var("test_var",
@@ -196,7 +196,7 @@ class SpecsTest(test.TestCase):
# TODO: this test is disabled; the original author should fix and re-enable it.
def DISABLED_testShared(self):
- with self.test_session():
+ with self.cached_session():
with specs.ops:
# pylint: disable=undefined-variable
f = Shared(Fr(100))
diff --git a/tensorflow/contrib/specs/python/summaries_test.py b/tensorflow/contrib/specs/python/summaries_test.py
index 34ff4bc8ca..b82ba06d3f 100644
--- a/tensorflow/contrib/specs/python/summaries_test.py
+++ b/tensorflow/contrib/specs/python/summaries_test.py
@@ -34,7 +34,7 @@ def _rand(*size):
class SummariesTest(test.TestCase):
def testStructure(self):
- with self.test_session():
+ with self.cached_session():
inputs_shape = (1, 18, 19, 5)
inputs = constant_op.constant(_rand(*inputs_shape))
spec = "net = Cr(64, [5, 5])"
@@ -48,7 +48,7 @@ class SummariesTest(test.TestCase):
"_ variablev2 conv variablev2 biasadd relu")
def testStructureFromTensor(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(1, 18, 19, 5))
spec = "net = Cr(64, [5, 5])"
outputs = specs.create_net(spec, inputs)
@@ -60,7 +60,7 @@ class SummariesTest(test.TestCase):
"_ variablev2 conv variablev2 biasadd relu")
def testPrint(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(1, 18, 19, 5))
spec = "net = Cr(64, [5, 5])"
outputs = specs.create_net(spec, inputs)
@@ -70,7 +70,7 @@ class SummariesTest(test.TestCase):
summaries.tf_spec_print(spec, inputs)
def testSummary(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(1, 18, 19, 5))
spec = "net = Cr(64, [5, 5])"
outputs = specs.create_net(spec, inputs)
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD
index 164f3e58e6..00c855daa3 100644
--- a/tensorflow/contrib/tensor_forest/BUILD
+++ b/tensorflow/contrib/tensor_forest/BUILD
@@ -462,7 +462,10 @@ py_test(
size = "small",
srcs = ["python/kernel_tests/scatter_add_ndim_op_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip_gpu"],
+ tags = [
+ "no_gpu",
+ "no_pip_gpu",
+ ],
deps = [
":tensor_forest_ops_py",
"//tensorflow/python:framework_test_lib",
@@ -515,6 +518,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":client_lib",
+ "//tensorflow/contrib/estimator:head",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/learn",
"//tensorflow/python:array_ops",
@@ -533,10 +537,11 @@ py_library(
py_test(
name = "random_forest_test",
- size = "medium",
+ size = "large",
srcs = ["client/random_forest_test.py"],
srcs_version = "PY2AND3",
tags = [
+ "noasan",
"nomac", # b/63258195
"notsan",
],
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index 8fa0b3ada9..0042d37acd 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib import layers
+from tensorflow.contrib.estimator.python.estimator import head as core_head_lib
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
@@ -25,7 +26,6 @@ from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_f
from tensorflow.contrib.tensor_forest.client import eval_metrics
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.estimator import estimator as core_estimator
-from tensorflow.python.estimator.canned import head as core_head_lib
from tensorflow.python.estimator.export.export_output import PredictOutput
from tensorflow.python.feature_column import feature_column as fc_core
from tensorflow.python.framework import ops
@@ -130,17 +130,23 @@ def _get_default_head(params, weights_name, output_type, name=None):
head_name=name)
else:
if params.regression:
- return core_head_lib._regression_head( # pylint:disable=protected-access
+ return core_head_lib.regression_head(
weight_column=weights_name,
label_dimension=params.num_outputs,
name=name,
- loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
else:
- return core_head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint:disable=protected-access
- n_classes=params.num_classes,
- weight_column=weights_name,
- name=name,
- loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ if params.num_classes == 2:
+ return core_head_lib.binary_classification_head(
+ weight_column=weights_name,
+ name=name,
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
+ else:
+ return core_head_lib.multi_class_head(
+ n_classes=params.num_classes,
+ weight_column=weights_name,
+ name=name,
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
def get_model_fn(params,
graph_builder_class,
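Note: besides switching from the private canned heads to the public tf.contrib.estimator constructors, the hunk above changes the loss reduction from SUM_OVER_NONZERO_WEIGHTS to SUM_OVER_BATCH_SIZE. The difference shows up when some example weights are zero: the former divides the weighted loss sum by the count of nonzero weights, the latter by the full batch size. A hedged numeric sketch (the values are made up):

    losses_per_example = [2.0, 4.0, 6.0]
    weights = [1.0, 0.0, 1.0]

    weighted_sum = sum(l * w for l, w in zip(losses_per_example, weights))   # 8.0
    sum_over_nonzero_weights = weighted_sum / sum(w != 0 for w in weights)   # 8 / 2 = 4.0
    sum_over_batch_size = weighted_sum / len(losses_per_example)             # 8 / 3 ~= 2.67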
diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h b/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h
index 69a0143a4e..1ed3d8ca2e 100644
--- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h
+++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
-#ifndef LEARNING_LIB_TENSOR_FOREST_HYBRID_CORE_OPS_UTILS_H_
-#define LEARNING_LIB_TENSOR_FOREST_HYBRID_CORE_OPS_UTILS_H_
+#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_HYBRID_CORE_OPS_UTILS_H_
+#define TENSORFLOW_CONTRIB_TENSOR_FOREST_HYBRID_CORE_OPS_UTILS_H_
#include <vector>
#include "tensorflow/core/framework/tensor.h"
@@ -43,4 +43,4 @@ void GetFeatureSet(int32 tree_num, int32 node_num, int32 random_seed,
} // namespace tensorforest
} // namespace tensorflow
-#endif // LEARNING_LIB_TENSOR_FOREST_HYBRID_CORE_OPS_UTILS_H_
+#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_HYBRID_CORE_OPS_UTILS_H_
diff --git a/tensorflow/contrib/tensor_forest/hybrid/python/kernel_tests/k_feature_routing_function_op_test.py b/tensorflow/contrib/tensor_forest/hybrid/python/kernel_tests/k_feature_routing_function_op_test.py
index 980f53253d..cc053f3b94 100644
--- a/tensorflow/contrib/tensor_forest/hybrid/python/kernel_tests/k_feature_routing_function_op_test.py
+++ b/tensorflow/contrib/tensor_forest/hybrid/python/kernel_tests/k_feature_routing_function_op_test.py
@@ -58,7 +58,7 @@ class KFeatureRoutingFunctionTest(test_util.TensorFlowTestCase):
self.assertEquals(self.params.num_features_per_node, 2)
def testRoutingFunction(self):
- with self.test_session():
+ with self.cached_session():
route_tensor = gen_training_ops.k_feature_routing_function(
self.input_data,
self.tree_weights,
diff --git a/tensorflow/contrib/tensor_forest/hybrid/python/kernel_tests/routing_function_op_test.py b/tensorflow/contrib/tensor_forest/hybrid/python/kernel_tests/routing_function_op_test.py
index a27fd49d32..554f7b0d7a 100644
--- a/tensorflow/contrib/tensor_forest/hybrid/python/kernel_tests/routing_function_op_test.py
+++ b/tensorflow/contrib/tensor_forest/hybrid/python/kernel_tests/routing_function_op_test.py
@@ -36,7 +36,7 @@ class RoutingFunctionTest(test_util.TensorFlowTestCase):
self.ops = training_ops.Load()
def testRoutingFunction(self):
- with self.test_session():
+ with self.cached_session():
route_tensor = gen_training_ops.routing_function(
self.input_data, self.tree_weights, self.tree_thresholds, max_nodes=3)
diff --git a/tensorflow/contrib/tensor_forest/kernels/data_spec.h b/tensorflow/contrib/tensor_forest/kernels/data_spec.h
index bb33400214..336a7a3239 100644
--- a/tensorflow/contrib/tensor_forest/kernels/data_spec.h
+++ b/tensorflow/contrib/tensor_forest/kernels/data_spec.h
@@ -15,8 +15,8 @@
// This is a surrogate for using a proto, since it doesn't seem to be possible
// to use protos in a dynamically-loaded/shared-linkage library, which is
// what is used for custom ops in tensorflow/contrib.
-#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_DATA_SPEC_H_
-#define TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_DATA_SPEC_H_
+#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_DATA_SPEC_H_
+#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_DATA_SPEC_H_
#include <unordered_map>
#include "tensorflow/core/lib/strings/numbers.h"
@@ -139,4 +139,4 @@ class TensorForestDataSpec {
} // namespace tensorforest
} // namespace tensorflow
-#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_DATA_SPEC_H_
+#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_DATA_SPEC_H_
diff --git a/tensorflow/contrib/tensor_forest/kernels/tree_utils.h b/tensorflow/contrib/tensor_forest/kernels/tree_utils.h
index 03aab1b61e..e04eb60f9b 100644
--- a/tensorflow/contrib/tensor_forest/kernels/tree_utils.h
+++ b/tensorflow/contrib/tensor_forest/kernels/tree_utils.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_TREE_UTILS_H_
-#define TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_TREE_UTILS_H_
+#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_TREE_UTILS_H_
+#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_TREE_UTILS_H_
#include <limits>
@@ -302,4 +302,4 @@ void GetParentWeightedMean(float leaf_sum, const float* leaf_data,
} // namespace tensorforest
} // namespace tensorflow
-#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_TREE_UTILS_H_
+#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_TREE_UTILS_H_
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc
index 6cb2c881e2..7716536ba4 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc
@@ -54,17 +54,24 @@ InequalityDecisionNodeEvaluator::InequalityDecisionNodeEvaluator(
CHECK(safe_strto32(test.feature_id().id().value(), &feature_num_))
<< "Invalid feature ID: [" << test.feature_id().id().value() << "]";
threshold_ = test.threshold().float_value();
- include_equals_ =
- test.type() == decision_trees::InequalityTest::LESS_OR_EQUAL;
+ test_type_ = test.type();
}
int32 InequalityDecisionNodeEvaluator::Decide(
const std::unique_ptr<TensorDataSet>& dataset, int example) const {
const float val = dataset->GetExampleValue(example, feature_num_);
- if (val < threshold_ || (include_equals_ && val == threshold_)) {
- return left_child_id_;
- } else {
- return right_child_id_;
+ switch (test_type_) {
+ case decision_trees::InequalityTest::LESS_OR_EQUAL:
+ return val <= threshold_ ? left_child_id_ : right_child_id_;
+ case decision_trees::InequalityTest::LESS_THAN:
+ return val < threshold_ ? left_child_id_ : right_child_id_;
+ case decision_trees::InequalityTest::GREATER_OR_EQUAL:
+ return val >= threshold_ ? left_child_id_ : right_child_id_;
+ case decision_trees::InequalityTest::GREATER_THAN:
+ return val > threshold_ ? left_child_id_ : right_child_id_;
+ default:
+ LOG(ERROR) << "Unknown split test type: " << test_type_;
+ return -1;
}
}
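Note: the evaluator previously collapsed the inequality test into a single include_equals_ flag, which could only express the two LESS_* variants; GREATER_OR_EQUAL and GREATER_THAN tests effectively fell back to strict less-than. The switch above covers all four. A hedged Python sketch of the dispatch (the helper is illustrative; names mirror the C++):

    def decide(test_type, val, threshold, left_child_id, right_child_id):
      comparisons = {
          'LESS_OR_EQUAL': val <= threshold,
          'LESS_THAN': val < threshold,
          'GREATER_OR_EQUAL': val >= threshold,
          'GREATER_THAN': val > threshold,
      }
      if test_type not in comparisons:
        return -1  # unknown split test type
      return left_child_id if comparisons[test_type] else right_child_id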
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h
index 3db351c328..6497787f84 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h
@@ -55,9 +55,7 @@ class InequalityDecisionNodeEvaluator : public BinaryDecisionNodeEvaluator {
protected:
int32 feature_num_;
float threshold_;
-
- // If decision is '<=' as opposed to '<'.
- bool include_equals_;
+ ::tensorflow::decision_trees::InequalityTest_Type test_type_;
};
// Evaluator for splits with multiple weighted features.
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc
index af5cf72a3c..3db1335563 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc
@@ -60,6 +60,40 @@ TEST(InequalityDecisionNodeEvaluatorTest, TestStrictlyLess) {
ASSERT_EQ(eval->Decide(dataset, 4), 1);
}
+TEST(InequalityDecisionNodeEvaluatorTest, TestGreaterOrEqual) {
+ InequalityTest test;
+ test.mutable_feature_id()->mutable_id()->set_value("0");
+ test.mutable_threshold()->set_float_value(3.0);
+ test.set_type(InequalityTest::GREATER_OR_EQUAL);
+ std::unique_ptr<InequalityDecisionNodeEvaluator> eval(
+ new InequalityDecisionNodeEvaluator(test, 0, 1));
+
+ std::unique_ptr<tensorflow::tensorforest::TensorDataSet> dataset(
+ new tensorflow::tensorforest::TestableDataSet(
+ {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1));
+
+ ASSERT_EQ(eval->Decide(dataset, 2), 1);
+ ASSERT_EQ(eval->Decide(dataset, 3), 0);
+ ASSERT_EQ(eval->Decide(dataset, 4), 0);
+}
+
+TEST(InequalityDecisionNodeEvaluatorTest, TestStrictlyGreater) {
+ InequalityTest test;
+ test.mutable_feature_id()->mutable_id()->set_value("0");
+ test.mutable_threshold()->set_float_value(3.0);
+ test.set_type(InequalityTest::GREATER_THAN);
+ std::unique_ptr<InequalityDecisionNodeEvaluator> eval(
+ new InequalityDecisionNodeEvaluator(test, 0, 1));
+
+ std::unique_ptr<tensorflow::tensorforest::TensorDataSet> dataset(
+ new tensorflow::tensorforest::TestableDataSet(
+ {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1));
+
+ ASSERT_EQ(eval->Decide(dataset, 2), 1);
+ ASSERT_EQ(eval->Decide(dataset, 3), 1);
+ ASSERT_EQ(eval->Decide(dataset, 4), 0);
+}
+
TEST(MatchingDecisionNodeEvaluatorTest, Basic) {
MatchingValuesTest test;
test.mutable_feature_id()->mutable_id()->set_value("0");
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
index d43884481a..99c5800391 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
@@ -130,7 +130,11 @@ void TensorDataSet::RandomSample(int example,
num_total_features += num_sparse;
}
}
- int rand_feature = rng_->Uniform(num_total_features);
+ int rand_feature = 0;
+ {
+ mutex_lock lock(mu_);
+ rand_feature = rng_->Uniform(num_total_features);
+ }
if (rand_feature < available_features_.size()) { // it's dense.
*feature_id = available_features_[rand_feature];
*type = input_spec_.GetDenseFeatureType(rand_feature);
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h
index 95f75b4d7e..4945b53007 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h
@@ -25,6 +25,7 @@
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace tensorforest {
@@ -120,6 +121,8 @@ class TensorDataSet {
int32 split_sampling_random_seed_;
std::unique_ptr<random::PhiloxRandom> single_rand_;
std::unique_ptr<random::SimplePhilox> rng_;
+ // Mutex for using random number generator.
+ mutable mutex mu_;
};
} // namespace tensorforest
} // namespace tensorflow
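Note: the new mutex covers only the RandomSample statement that touches the shared SimplePhilox generator; rng_ is not thread-safe, so this guards the case where the data set is sampled concurrently while keeping the critical section as narrow as possible. The same pattern in a hedged Python sketch (the class is illustrative, not from the tree):

    import random
    import threading


    class LockedSampler(object):
      """Shared RNG guarded by a lock, mirroring the mutex_lock block above."""

      def __init__(self, seed=0):
        self._rng = random.Random(seed)  # shared; unsafe under concurrent use
        self._mu = threading.Lock()

      def uniform(self, n):
        with self._mu:  # lock only around the RNG call itself
          return self._rng.randrange(n)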
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index fc0d22d112..122a67a407 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -279,7 +279,9 @@ tf_cuda_library(
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
+ "//tensorflow/core:framework",
"//tensorflow/core:framework_lite",
+ "//tensorflow/core:gpu_runtime",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -293,6 +295,31 @@ tf_cuda_library(
]) + tf_custom_op_library_additional_deps(),
)
+tf_cuda_cc_test(
+ name = "convert_graph_test",
+ size = "medium",
+ srcs = ["convert/convert_graph_test.cc"],
+ tags = [
+ "no_cuda_on_cpu_tap",
+ "no_windows",
+ "nomac",
+ ],
+ deps = [
+ ":trt_conversion",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_base",
+ "//tensorflow/core:direct_session",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]),
+)
+
# Library for the segmenting portion of TensorRT operation creation
cc_library(
name = "segment",
@@ -387,17 +414,19 @@ cuda_py_tests(
name = "tf_trt_integration_test",
srcs = [
"test/base_test.py",
- # "test/batch_matmul_test.py",
- # "test/biasadd_matmul_test.py",
- # "test/binary_tensor_weight_broadcast_test.py", # Blocked by trt4 installation
- # "test/concatenation_test.py", # Blocked by trt4 installation
+ "test/batch_matmul_test.py",
+ "test/biasadd_matmul_test.py",
+ "test/binary_tensor_weight_broadcast_test.py",
+ "test/concatenation_test.py",
"test/const_broadcast_test.py",
+ "test/manual_test.py",
+ "test/memory_alignment_test.py",
"test/multi_connection_neighbor_engine_test.py",
"test/neighboring_engine_test.py",
- # "test/unary_test.py", # Blocked by trt4 installation
- # "test/vgg_block_nchw_test.py",
- # "test/vgg_block_test.py",
- "test/memory_alignment_test.py",
+ "test/rank_two_test.py",
+ "test/unary_test.py",
+ "test/vgg_block_nchw_test.py",
+ "test/vgg_block_test.py",
],
additional_deps = [
":tf_trt_integration_test_base",
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index 21ec8b0b30..b019c99882 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -31,6 +31,9 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "tensorflow/contrib/tensorrt/segment/segment.h"
#include "tensorflow/contrib/tensorrt/test/utils.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
@@ -772,33 +775,55 @@ std::pair<int, tensorflow::Allocator*> GetDeviceAndAllocator(
const ConversionParams& params, const EngineInfo& engine) {
int cuda_device_id = -1;
tensorflow::Allocator* dev_allocator = nullptr;
- if (params.cluster) {
- std::vector<tensorflow::Device*> devices;
- if (!engine.device.empty() && params.cluster->GetDeviceSet()) {
- DeviceNameUtils::ParsedName parsed_name;
- if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name) &&
- parsed_name.has_id) {
- params.cluster->GetDeviceSet()->FindMatchingDevices(parsed_name,
- &devices);
+ if (params.cluster == nullptr || params.cluster->GetDeviceSet() == nullptr ||
+ engine.device.empty()) {
+ // If device is not set, use the first found GPU device for the conversion.
+ for (int tf_gpu_id_value = 0; tf_gpu_id_value < 100; ++tf_gpu_id_value) {
+ TfGpuId tf_gpu_id(tf_gpu_id_value);
+ CudaGpuId cuda_gpu_id;
+ Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+ if (s.ok()) {
+ VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device "
+ << cuda_gpu_id.value();
+ cuda_device_id = cuda_gpu_id.value();
+ GPUOptions gpu_options;
+ // If the TF to Cuda gpu id mapping exist, the device and corresponding
+ // allocator must have been initialized already, so the
+ // GetGPUAllocator() call won't create a new allocator.
+ dev_allocator = GPUProcessState::singleton()->GetGPUAllocator(
+ gpu_options, tf_gpu_id, 1);
+ break;
}
+ LOG(ERROR) << "TF GPU with id " << tf_gpu_id_value << " does not exist "
+ << s;
}
- if (!devices.empty()) {
- if (devices.size() > 1) {
- string msg = "Found multiple matching devices using name '";
- StrAppend(&msg, engine.device, "': ");
- for (auto d : devices) StrAppend(&msg, d->name(), ", ");
- StrAppend(&msg, ". Will get the allocator from first one.");
- LOG(WARNING) << msg;
- }
- tensorflow::AllocatorAttributes alloc_attr;
- cuda_device_id = devices[0]->tensorflow_gpu_device_info()->gpu_id;
- dev_allocator = devices[0]->GetAllocator(alloc_attr);
- VLOG(1) << "Using allocator " << dev_allocator->Name()
- << " and cuda_device_id " << cuda_device_id;
- } else {
- LOG(WARNING) << "Cluster is set but device '" << engine.device
- << "' is not found in the cluster";
+ return std::make_pair(cuda_device_id, dev_allocator);
+ }
+
+ // Use the device requested by the engine.
+ auto device_set = params.cluster->GetDeviceSet();
+ std::vector<tensorflow::Device*> devices;
+ DeviceNameUtils::ParsedName parsed_name;
+ if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name) &&
+ parsed_name.has_id) {
+ device_set->FindMatchingDevices(parsed_name, &devices);
+ }
+ if (!devices.empty()) {
+ if (devices.size() > 1) {
+ string msg = "Found multiple matching devices using name '";
+ StrAppend(&msg, engine.device, "': ");
+ for (auto d : devices) StrAppend(&msg, d->name(), ", ");
+ StrAppend(&msg, ". Will get the allocator from first one.");
+ LOG(WARNING) << msg;
}
+ tensorflow::AllocatorAttributes alloc_attr;
+ cuda_device_id = devices[0]->tensorflow_gpu_device_info()->gpu_id;
+ dev_allocator = devices[0]->GetAllocator(alloc_attr);
+ VLOG(1) << "Using allocator " << dev_allocator->Name()
+ << " and cuda_device_id " << cuda_device_id;
+ } else {
+ LOG(WARNING) << "Cluster is set but device '" << engine.device
+ << "' is not found in the cluster";
}
return std::make_pair(cuda_device_id, dev_allocator);
}
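Note: the rewrite above flattens the old nested branches into an early return: when there is no cluster, no device set, or no device requested by the engine, it probes TF GPU ids 0-99 and takes the first one with a CUDA mapping; otherwise it resolves the engine's device name against the cluster. A hedged Python sketch of the control flow (the dict stands in for GpuIdManager and FakeDevice is invented for illustration):

    class FakeDevice(object):
      def __init__(self, name, cuda_id):
        self.name, self.cuda_id = name, cuda_id


    def pick_cuda_device(device_set, engine_device, tf_to_cuda_id):
      # No cluster/device set or no requested device: probe TF GPU ids in order.
      if device_set is None or not engine_device:
        for tf_gpu_id in range(100):
          if tf_gpu_id in tf_to_cuda_id:
            return tf_to_cuda_id[tf_gpu_id]
        return -1
      # Otherwise resolve the engine's device name against the device set.
      matches = [d for d in device_set if d.name == engine_device]
      return matches[0].cuda_id if matches else -1


    assert pick_cuda_device(None, '', {0: 0, 1: 1}) == 0
    assert pick_cuda_device([FakeDevice('/GPU:1', 1)], '/GPU:1', {}) == 1
    assert pick_cuda_device([FakeDevice('/GPU:1', 1)], '/GPU:3', {}) == -1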
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h
index 9d986e4890..3525202369 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h
@@ -17,6 +17,7 @@ limitations under the License.
#include <vector>
+#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
@@ -84,6 +85,11 @@ std::vector<int> GetLinkedTensorRTVersion();
// Return runtime time TensorRT library version information.
std::vector<int> GetLoadedTensorRTVersion();
+
+// Helper method for the conversion, expose for testing.
+std::pair<int, tensorflow::Allocator*> GetDeviceAndAllocator(
+ const ConversionParams& params, const EngineInfo& engine);
+
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc b/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc
new file mode 100644
index 0000000000..8146bed4b0
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc
@@ -0,0 +1,140 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
+
+#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/config.pb.h" // NOLINT
+#include "tensorflow/core/public/session.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
+namespace tensorflow {
+namespace tensorrt {
+namespace convert {
+
+class FakeCluster : public grappler::Cluster {
+ public:
+ FakeCluster() : Cluster(0) {}
+
+ void SetDeviceSet(const DeviceSet* device_set) { device_set_ = device_set; }
+
+ const DeviceSet* GetDeviceSet() const override { return device_set_; }
+
+ string type() const override { return ""; }
+ Status Provision() override { return Status::OK(); }
+ Status Initialize(const grappler::GrapplerItem& item) override {
+ return Status::OK();
+ }
+ Status Run(const GraphDef& graph_def,
+ const std::vector<std::pair<string, Tensor>>& feed,
+ const std::vector<string>& fetch,
+ RunMetadata* metadata) override {
+ return Status::OK();
+ }
+
+ private:
+ const DeviceSet* device_set_;
+};
+
+TEST(ConvertGraphTest, GetDeviceAndAllocator) {
+ ConversionParams params;
+ EngineInfo engine_info;
+ {
+ // params.cluster is not set, and no gpu device is available.
+ auto result = GetDeviceAndAllocator(params, engine_info);
+ EXPECT_EQ(-1, result.first);
+ EXPECT_EQ(nullptr, result.second);
+ }
+
+ // Create a session with two (virtual) gpu devices.
+ SessionOptions options;
+ ConfigProto* config = &options.config;
+ GPUOptions* gpu_options = config->mutable_gpu_options();
+ auto virtual_devices =
+ gpu_options->mutable_experimental()->add_virtual_devices();
+ virtual_devices->add_memory_limit_mb(200);
+ virtual_devices->add_memory_limit_mb(200);
+ std::unique_ptr<Session> session(NewSession(options));
+
+ {
+ // params.cluster is not set, should find and return first gpu id and
+ // corresponding allocator.
+ auto result = GetDeviceAndAllocator(params, engine_info);
+ EXPECT_EQ(0, result.first);
+ EXPECT_NE(nullptr, result.second);
+ EXPECT_EQ("GPU_0_bfc", result.second->Name());
+ }
+
+ FakeCluster cluster;
+ params.cluster = &cluster;
+ {
+ // params.cluster->GetDeviceSet() returns null, should find and return first
+ // gpu id and corresponding allocator.
+ auto result = GetDeviceAndAllocator(params, engine_info);
+ EXPECT_EQ(0, result.first);
+ EXPECT_NE(nullptr, result.second);
+ EXPECT_EQ("GPU_0_bfc", result.second->Name());
+ }
+
+ // Build the DeviceSet.
+ DeviceSet device_set;
+ const DeviceMgr* device_mgr = nullptr;
+ TF_ASSERT_OK(session->LocalDeviceManager(&device_mgr));
+ for (auto d : device_mgr->ListDevices()) {
+ device_set.AddDevice(d);
+ }
+ cluster.SetDeviceSet(&device_set);
+ {
+ // engine_info.device is not set, should find and return first gpu id and
+ // corresponding allocator.
+ auto result = GetDeviceAndAllocator(params, engine_info);
+ EXPECT_EQ(0, result.first);
+ EXPECT_NE(nullptr, result.second);
+ EXPECT_EQ("GPU_0_bfc", result.second->Name());
+ }
+
+ engine_info.device = "/GPU:1";
+ {
+ // Set to use second device.
+ auto result = GetDeviceAndAllocator(params, engine_info);
+ EXPECT_EQ(0, result.first);
+ EXPECT_NE(nullptr, result.second);
+ EXPECT_EQ("GPU_1_bfc", result.second->Name());
+ }
+
+ engine_info.device = "/GPU:3";
+ {
+ // Set to use nonexistent device.
+ auto result = GetDeviceAndAllocator(params, engine_info);
+ EXPECT_EQ(-1, result.first);
+ EXPECT_EQ(nullptr, result.second);
+ }
+}
+
+} // namespace convert
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 35fa590254..c98b07ad8b 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "tensorflow/core/framework/node_def.pb.h" // NOLINT
#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_shape.pb.h" // NOLINT
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
@@ -77,6 +78,10 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
+// TODO(aaroey): put these constants into some class.
+const char* const kInputPHName = "TensorRTInputPH_";
+const char* const kOutputPHName = "TensorRTOutputPH_";
+
namespace convert {
using ::tensorflow::str_util::Split;
using ::tensorflow::strings::StrAppend;
@@ -155,12 +160,22 @@ tensorflow::Status ValidateInputProperties(const PartialTensorShape& shape,
for (int d = 1; d < shape.dims(); ++d) {
if (shape.dim_size(d) < 0) {
return tensorflow::errors::InvalidArgument(
- "Input tensor has a unknown non-batch dimemension at dim ", d);
+ "Input tensor with shape ", shape.DebugString(),
+ " has an unknown non-batch dimemension at dim ", d);
}
}
return Status::OK();
}
+string DebugString(const nvinfer1::Dims& dims) {
+ string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d=");
+ for (int i = 0; i < nvinfer1::Dims::MAX_DIMS; ++i) {
+ StrAppend(&out, dims.d[i], ",");
+ }
+ StrAppend(&out, ")");
+ return out;
+}
+
// Return whether or not the broadcast is feasible.
bool TensorRTGetBroadcastShape(const nvinfer1::Dims& operand_l,
const bool operand_l_is_tensor,
@@ -353,6 +368,13 @@ class TRT_ShapedWeights {
// Default converter
operator nvinfer1::Weights() const { return GetWeightsForTRT(); }
+ string DebugString() const {
+ return StrCat(
+ "TRT_ShapedWeights(shape=", convert::DebugString(shape_), ", type=",
+ type_, ", values=", reinterpret_cast<uintptr_t>(values_),
+ ", empty_weight_flag=", empty_weight_flag_, ")");
+ }
+
// TODO(aaroey): make these private.
nvinfer1::Dims shape_;
tensorflow::DataType type_;
@@ -367,11 +389,14 @@ class TRT_TensorOrWeights {
public:
explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor)
: tensor_(tensor), weights_(DT_FLOAT), variant_(TRT_NODE_TENSOR) {}
+
explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights)
: tensor_(nullptr), weights_(weights), variant_(TRT_NODE_WEIGHTS) {}
+
// TODO(aaroey): use rvalue reference.
TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs)
: tensor_(rhs.tensor_), weights_(rhs.weights_), variant_(rhs.variant_) {}
+
~TRT_TensorOrWeights() {}
bool is_tensor() const { return variant_ == TRT_NODE_TENSOR; }
@@ -381,18 +406,22 @@ class TRT_TensorOrWeights {
CHECK(is_tensor());
return tensor_;
}
+
const nvinfer1::ITensor* tensor() const {
CHECK(is_tensor());
return tensor_;
}
+
TRT_ShapedWeights& weights() {
CHECK(is_weights());
return weights_;
}
+
const TRT_ShapedWeights& weights() const {
CHECK(is_weights());
return weights_;
}
+
nvinfer1::Dims shape() const {
if (is_tensor()) {
return tensor()->getDimensions();
@@ -401,6 +430,18 @@ class TRT_TensorOrWeights {
}
}
+ string DebugString() const {
+ string output = "TRT_TensorOrWeights(type=";
+ if (is_tensor()) {
+ StrAppend(&output, "tensor @", reinterpret_cast<uintptr_t>(tensor_),
+ ", shape=", convert::DebugString(tensor_->getDimensions()));
+ } else {
+ StrAppend(&output, "weights=", weights_.DebugString());
+ }
+ StrAppend(&output, ")");
+ return output;
+ }
+
private:
nvinfer1::ITensor* tensor_;
TRT_ShapedWeights weights_;
@@ -555,7 +596,7 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights,
}
void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
- TRT_ShapedWeights* oweights, int num_groups) {
+ TRT_ShapedWeights* oweights, const int num_groups) {
CHECK_EQ(iweights.type_, oweights->type_);
CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
// K indexes over output channels, C over input channels, and R and S over the
@@ -563,13 +604,13 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
const int r = iweights.shape_.d[0];
const int s = iweights.shape_.d[1];
// TRT requires GKcRS, while TF depthwise has RSCK where c=1, C=G
- VLOG(2) << "num_groups: " << num_groups;
const int c = iweights.shape_.d[2] / num_groups;
- VLOG(2) << "c" << iweights.shape_.d[2] << " then " << c;
const int k = iweights.shape_.d[3] * num_groups;
- VLOG(2) << "k" << iweights.shape_.d[3] << " then " << k;
- VLOG(2) << "r" << iweights.shape_.d[0] << " then " << r;
- VLOG(2) << "s" << iweights.shape_.d[1] << " then " << s;
+ VLOG(2) << "num_groups: " << num_groups
+ << "c" << iweights.shape_.d[2] << " then " << c
+ << "k" << iweights.shape_.d[3] << " then " << k
+ << "r" << iweights.shape_.d[0] << " then " << r
+ << "s" << iweights.shape_.d[1] << " then " << s;
oweights->shape_.d[0] = k / num_groups;
oweights->shape_.d[1] = c * num_groups;
oweights->shape_.d[2] = r;
@@ -607,63 +648,15 @@ using OpConverter =
std::vector<TRT_TensorOrWeights>*)>;
class Converter {
- // TODO(aaroey): fix the order of members.
- std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;
- std::unordered_map<string, OpConverter> op_registry_;
- OpConverter plugin_converter_;
- nvinfer1::INetworkDefinition* trt_network_;
- std::list<std::vector<uint8_t>> temp_bufs_;
- // TODO(aaroey): inline the definition of TRTWeightStore here, and add APIs to
- // operate the stored weights instead of operating it directly.
- TRTWeightStore* weight_store_;
- bool fp16_;
- void register_op_converters();
- tensorflow::Status get_inputs(const tensorflow::NodeDef& node_def,
- std::vector<TRT_TensorOrWeights>* inputs) {
- for (auto const& input_name : node_def.input()) {
- /*************************************************************************
- * TODO(jie): handle case 1) here.
- * Normalizes the inputs and extracts associated metadata:
- * 1) Inputs can contain a colon followed by a suffix of characters.
- * That suffix may be a single number (e.g. inputName:1) or several
- * word characters separated from a number by a colon
- * (e.g. inputName:foo:1). The
- * latter case is used to denote inputs and outputs of functions.
- * 2) Control dependency inputs contain caret at the beginning and we
- * remove this and annotate the edge as a control dependency.
- ************************************************************************/
- // skip control nodes
- if (input_name[0] == '^') continue;
- string name = input_name;
- auto first = name.find_first_of(':');
- // TODO(aaroey): why removing the colon but not the zero? A bug?
- if (first != string::npos && first + 2 == name.size() &&
- name[first + 1] == '0')
- name.erase(first);
-
- VLOG(2) << "retrieve input: " << name;
- if (trt_tensors_.count(name)) {
- inputs->push_back(trt_tensors_.at(name));
- } else {
- // TODO(aaroey): this should not happen, make it a CHECK.
- // TODO(aaroey): use StrCat for pattern like this.
- string msg("Node ");
- StrAppend(&msg, node_def.name(), " should have an input named '", name,
- "' but it is not available");
- LOG(ERROR) << msg;
- return tensorflow::errors::InvalidArgument(msg);
- }
- }
- return tensorflow::Status::OK();
- }
-
public:
explicit Converter(nvinfer1::INetworkDefinition* trt_network,
TRTWeightStore* ws, bool fp16)
: trt_network_(trt_network), weight_store_(ws), fp16_(fp16) {
this->register_op_converters();
}
+
TRTWeightStore* weight_store() { return weight_store_; }
+
TRT_ShapedWeights get_temp_weights(tensorflow::DataType type,
nvinfer1::Dims shape) {
TRT_ShapedWeights weights(type, nullptr, shape);
@@ -672,8 +665,10 @@ class Converter {
weights.SetValues(weight_store_->store_.back().data());
return weights;
}
+
// TODO(aaroey): fix all the namings.
bool isFP16() { return fp16_; }
+
TRT_ShapedWeights get_temp_weights_like(const TRT_ShapedWeights& weights) {
return this->get_temp_weights(weights.type_, weights.shape_);
}
@@ -684,7 +679,6 @@ class Converter {
const string& op = node_def.op();
std::vector<TRT_TensorOrWeights> outputs;
if (PluginFactoryTensorRT::GetInstance()->IsPlugin(op)) {
- // TODO(aaroey): plugin_converter_ is not set, fix it.
TF_RETURN_IF_ERROR(plugin_converter_(*this, node_def, inputs, &outputs));
} else {
if (!op_registry_.count(op)) {
@@ -702,7 +696,8 @@ class Converter {
if (output.is_tensor()) {
output.tensor()->setName(output_name.c_str());
}
- VLOG(2) << "Write out tensor: " << output_name;
+ VLOG(2) << "Adding out tensor " << output_name << ": "
+ << output.DebugString();
if (!trt_tensors_.insert({output_name, output}).second) {
return tensorflow::errors::AlreadyExists(
"Output tensor already exists for op: " + op);
@@ -751,6 +746,63 @@ class Converter {
layer->setReshapeDimensions(reshape_dims);
return layer->getOutput(0);
}
+
+ private:
+ std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;
+ std::unordered_map<string, OpConverter> op_registry_;
+ OpConverter plugin_converter_;
+ nvinfer1::INetworkDefinition* trt_network_;
+ std::list<std::vector<uint8_t>> temp_bufs_;
+
+ // TODO(aaroey): inline the definition of TRTWeightStore here, and add APIs to
+ // operate the stored weights instead of operating it directly.
+ TRTWeightStore* weight_store_;
+
+ bool fp16_;
+
+ void register_op_converters();
+
+ tensorflow::Status get_inputs(const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights>* inputs) {
+ for (auto const& input_name : node_def.input()) {
+ /*************************************************************************
+ * TODO(jie): handle case 1) here.
+ * Normalizes the inputs and extracts associated metadata:
+ * 1) Inputs can contain a colon followed by a suffix of characters.
+ * That suffix may be a single number (e.g. inputName:1) or several
+ * word characters separated from a number by a colon
+ * (e.g. inputName:foo:1). The
+ * latter case is used to denote inputs and outputs of functions.
+ * 2) Control dependency inputs contain caret at the beginning and we
+ * remove this and annotate the edge as a control dependency.
+ ************************************************************************/
+ // skip control nodes
+ if (input_name[0] == '^') continue;
+ string name = input_name;
+ auto first = name.find_first_of(':');
+ // TODO(aaroey): why removing the colon but not the zero? A bug?
+ // TODO(aaroey): use TensorId
+ if (first != string::npos && first + 2 == name.size() &&
+ name[first + 1] == '0') {
+ name.erase(first);
+ }
+
+ if (trt_tensors_.count(name)) {
+ TRT_TensorOrWeights& input = trt_tensors_.at(name);
+ inputs->push_back(input);
+ VLOG(2) << "Retrieved input " << name << ": " << input.DebugString();
+ } else {
+ // TODO(aaroey): this should not happen, make it a CHECK.
+ // TODO(aaroey): use StrCat for pattern like this.
+ string msg("Node ");
+ StrAppend(&msg, node_def.name(), " should have an input named '", name,
+ "' but it is not available");
+ LOG(ERROR) << msg;
+ return tensorflow::errors::InvalidArgument(msg);
+ }
+ }
+ return tensorflow::Status::OK();
+ }
};
TRT_ShapedWeights ConvertFP32ToFP16(Converter& ctx,
@@ -1187,17 +1239,11 @@ tensorflow::Status ConvertConv2DHelper(
VLOG(2) << "groups count: " << num_groups;
TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
-
- VLOG(2) << "weight shape: " << weights_rsck.shape_.nbDims;
- for (int i = 0; i < weights_rsck.shape_.nbDims; i++) {
- VLOG(2) << weights_rsck.shape_.d[i];
- }
-
+ VLOG(2) << "weight shape: " << weights_rsck.DebugString();
if (weights_rsck.shape_.nbDims != 4) {
return tensorflow::errors::Internal(
"Conv2D expects kernel of dimension 4, at: " + node_def.name());
}
-
if (ctx.isFP16()) {
weights_rsck = ConvertFP32ToFP16(ctx, inputs.at(1).weights());
}
@@ -1209,16 +1255,13 @@ tensorflow::Status ConvertConv2DHelper(
nvinfer1::DimsHW kernel_size;
kernel_size.h() = weights.shape_.d[2];
kernel_size.w() = weights.shape_.d[3];
- VLOG(2) << "RSCK: ";
- for (int i = 0; i < 4; i++) {
- VLOG(2) << " " << weights.shape_.d[i];
- }
+ VLOG(2) << "RSCK: " << weights.DebugString();
VLOG(2) << "kernel size: " << kernel_size.h() << ", " << kernel_size.w();
// TODO(jie): stride. (NHWC/NCHW)
const auto tf_stride = attrs.get<std::vector<int>>("strides");
VLOG(2) << "h_INDEX" << h_index << ", w_index " << w_index;
- VLOG(2) << "stride!!!: " << tf_stride[0] << tf_stride[1] << tf_stride[2]
+ VLOG(2) << "stride: " << tf_stride[0] << tf_stride[1] << tf_stride[2]
<< tf_stride[3];
const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
@@ -1240,10 +1283,7 @@ tensorflow::Status ConvertConv2DHelper(
// TODO(jie): handle asymmetric padding
VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second
<< padding[1].first << padding[1].second;
-
- auto dim_before = tensor->getDimensions();
- VLOG(2) << "TENSOR before: " << dim_before.d[0] << ", " << dim_before.d[1]
- << dim_before.d[2] << ", " << dim_before.d[3];
+ VLOG(2) << "TENSOR before: " << DebugString(tensor->getDimensions());
auto pad_layer = ctx.network()->addPadding(
*const_cast<nvinfer1::ITensor*>(tensor),
nvinfer1::DimsHW(padding[0].first, padding[1].first),
@@ -1251,9 +1291,7 @@ tensorflow::Status ConvertConv2DHelper(
TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name());
padding = {{0, 0}, {0, 0}};
tensor = pad_layer->getOutput(0);
- auto dim_after = tensor->getDimensions();
- VLOG(2) << "TENSOR after: " << dim_after.d[0] << ", " << dim_after.d[1]
- << dim_after.d[2] << ", " << dim_after.d[3];
+ VLOG(2) << "TENSOR after: " << DebugString(tensor->getDimensions());
}
nvinfer1::IConvolutionLayer* layer =
@@ -1266,17 +1304,12 @@ tensorflow::Status ConvertConv2DHelper(
layer->setName(node_def.name().c_str());
layer->setNbGroups(num_groups);
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
-
- auto dim_after = output_tensor->getDimensions();
- VLOG(2) << "TENSOR out: " << dim_after.d[0] << ", " << dim_after.d[1] << ", "
- << dim_after.d[2] << ", " << dim_after.d[3];
-
+ VLOG(2) << "TENSOR out: " << DebugString(output_tensor->getDimensions());
+ VLOG(2) << "data_format: " << data_format;
if (data_format == "NHWC") {
// TODO(jie): transpose it back!
output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name());
- } else {
- VLOG(2) << "NCHW !!!!";
}
outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
@@ -1990,22 +2023,22 @@ tensorflow::Status ConvertReduce(Converter& ctx,
return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32");
}
- const auto keep_dims = attrs.get<bool>("keep_dims");
- auto index_list_data =
- static_cast<int*>(const_cast<void*>(index_list.GetValues()));
-
int axes = 0;
if (index_list.count() == 0) {
return tensorflow::errors::InvalidArgument(
"TRT cannot support reduce on all (batch) dimensions, at",
node_def.name());
} else {
+ auto index_list_data =
+ static_cast<int*>(const_cast<void*>(index_list.GetValues()));
for (int i = 0; i < index_list.count(); i++) {
- if (index_list_data[i] == 0) {
+ int axis = index_list_data[i];
+ if (axis < 0) axis += tensor->getDimensions().nbDims + 1;
+ if (axis == 0) {
return tensorflow::errors::InvalidArgument(
"TRT cannot reduce at batch dimension, at", node_def.name());
}
- axes |= (1 << (index_list_data[i] - 1));
+ axes |= (1 << (axis - 1));
}
}
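Note: the reworked loop above first normalizes negative TF axes (TRT dims exclude the batch dimension, so a TF tensor of rank n appears in TRT with nbDims = n - 1, and a TF axis a < 0 maps to a + nbDims + 1), then converts each axis into TRT's reduce bitmask, where TF axis k sets bit k - 1. A hedged worked example of that arithmetic:

    def trt_reduce_bit(tf_axis, trt_nb_dims):
      axis = tf_axis + trt_nb_dims + 1 if tf_axis < 0 else tf_axis
      if axis == 0:
        raise ValueError('TRT cannot reduce over the batch dimension')
      return 1 << (axis - 1)

    # For an NHWC tensor of TF rank 4, TRT sees nbDims = 3:
    assert trt_reduce_bit(-1, 3) == 0b100  # TF axis -1 -> axis 3 -> TRT bit 2
    assert trt_reduce_bit(1, 3) == 0b001   # TF axis 1 -> TRT bit 0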
@@ -2025,6 +2058,7 @@ tensorflow::Status ConvertReduce(Converter& ctx,
" , at ", node_def.name());
}
+ const auto keep_dims = attrs.get<bool>("keep_dims");
nvinfer1::ILayer* layer =
ctx.network()->addReduce(*const_cast<nvinfer1::ITensor*>(tensor),
reduce_operation, axes, keep_dims);
@@ -2694,8 +2728,6 @@ tensorflow::Status ConvertGraphDefToEngine(
VLOG(2) << "Converting op name=" << node_name << ", op=" << node_def.op();
if (tensorflow::str_util::StartsWith(node_name, kInputPHName) &&
(node_def.op() == "Placeholder")) {
- nvinfer1::DimsCHW input_dim_pseudo_chw;
- for (int i = 0; i < 8; i++) input_dim_pseudo_chw.d[i] = 0;
int32 slot_number = -1;
if (!tensorflow::strings::safe_strto32(
node_name.c_str() + strlen(kInputPHName), &slot_number)) {
@@ -2713,28 +2745,25 @@ tensorflow::Status ConvertGraphDefToEngine(
LOG(WARNING) << error_message;
return Status(status.code(), error_message);
}
- if (VLOG_IS_ON(1)) {
- string dim_str("dims=");
- StrAppend(&dim_str, "[ ", shape.dim_size(0));
- for (int i = 1; i < shape.dims(); i++) {
- StrAppend(&dim_str, ", ", shape.dim_size(i));
- }
- StrAppend(&dim_str, " ]");
- VLOG(1) << dim_str;
- }
+
+#if NV_TENSORRT_MAJOR == 3
+ nvinfer1::DimsCHW input_dim;
+#elif NV_TENSORRT_MAJOR > 3
+ nvinfer1::Dims input_dim;
+#endif
for (int i = 1; i < shape.dims(); i++) {
- input_dim_pseudo_chw.d[i - 1] = shape.dim_size(i);
+ input_dim.d[i - 1] = shape.dim_size(i);
}
-
- input_dim_pseudo_chw.nbDims = shape.dims() - 1;
- nvinfer1::ITensor* input_tensor = converter.network()->addInput(
- node_name.c_str(), dtype, input_dim_pseudo_chw);
+ input_dim.nbDims = shape.dims() - 1;
+ nvinfer1::ITensor* input_tensor =
+ converter.network()->addInput(node_name.c_str(), dtype, input_dim);
if (!input_tensor) {
return tensorflow::errors::InvalidArgument(
"Failed to create Input layer tensor ", node_name,
" rank=", shape.dims() - 1);
}
- VLOG(1) << "Input tensor name :" << node_name;
+ VLOG(2) << "Adding engine input tensor " << node_name << " with shape "
+ << DebugString(input_dim);
if (!converter.insert_input_tensor(node_name, input_tensor)) {
return tensorflow::errors::AlreadyExists(
"Output tensor already exists for op: " + node_name);
@@ -2937,10 +2966,25 @@ bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const {
<< ": " << status;
return false;
}
- if (shape.dims() < 3 && in_edge->src()->type_string() != "Const") {
+
+ if (in_edge->src()->type_string() != "Const" &&
+#if NV_TENSORRT_MAJOR == 3
+ // TRT 3.x only support 4 dimensional input tensor.
+ shape.dims() != 4) {
+#else
+ // Single dimensional input tensor is not supported since the first
+ // dimension is treated as batch dimension.
+ shape.dims() < 2) {
+#endif
VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name()
- << " which has an input at port " << in_edge->dst_input()
- << " with #dim<3 and is not a const: " << shape;
+ << " which has an input at port " << in_edge->dst_input() << " with"
+#if NV_TENSORRT_MAJOR == 3
+ << " #dim!=4"
+#else
+ << " #dim<2"
+#endif
+ << " and is not a const: " << shape;
return false;
}
return true;
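Note: the rank check above encodes two constraints: TRT 3.x accepts only 4-dimensional inputs, while later versions accept anything with at least one non-batch dimension, since dimension 0 is consumed as the implicit batch dimension. A hedged sketch of the predicate (the helper is illustrative only):

    def input_rank_ok(dims, trt_major_version, is_const):
      if is_const:
        return True  # constants become weights, not engine inputs
      if trt_major_version == 3:
        return dims == 4  # TRT 3.x only supports 4-dimensional input tensors
      return dims >= 2    # need the batch dimension plus at least one more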
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
index a60253740f..9274027e63 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -36,8 +36,9 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
-static const char* kInputPHName = "TensorRTInputPH_";
-static const char* kOutputPHName = "TensorRTOutputPH_";
+extern const char* const kInputPHName;
+extern const char* const kOutputPHName;
+
namespace convert {
struct EngineConnection {
diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
index f33f2cc4d6..ff4fba58bf 100644
--- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
+++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
@@ -14,6 +14,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h"
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
+#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
@@ -37,7 +38,6 @@ tensorflow::Status TRTOptimizationPass::Init(
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) {
VLOG(1) << "Called INIT for " << name_ << " with config = " << config;
if (config == nullptr) {
- maximum_workspace_size_ = 2 << 30;
return tensorflow::Status::OK();
}
const auto params = config->parameter_map();
@@ -47,7 +47,6 @@ tensorflow::Status TRTOptimizationPass::Init(
if (params.count("max_batch_size")) {
maximum_batch_size_ = params.at("max_batch_size").i();
}
- is_dynamic_op_ = false;
if (params.count("is_dynamic_op")) {
is_dynamic_op_ = params.at("is_dynamic_op").b();
}
@@ -58,27 +57,15 @@ tensorflow::Status TRTOptimizationPass::Init(
batches_.push_back(i);
}
}
- max_cached_batches_ = 1;
if (params.count("maximum_cached_engines")) {
max_cached_batches_ = params.at("maximum_cached_engines").i();
}
if (params.count("max_workspace_size_bytes")) {
- maximum_workspace_size_ = params.at("max_workspace_size_bytes").i();
+ max_workspace_size_bytes_ = params.at("max_workspace_size_bytes").i();
}
if (params.count("precision_mode")) {
- string pm = Uppercase(params.at("precision_mode").s());
- if (pm == "FP32") {
- precision_mode_ = 0;
- } else if (pm == "FP16") {
- precision_mode_ = 1;
- } else if (pm == "INT8") {
- precision_mode_ = 2;
- } else {
- LOG(ERROR) << "Unknown precision mode '" << pm << "'";
- return tensorflow::errors::InvalidArgument(
- "Unknown precision mode argument" + pm +
- " Valid values are FP32, FP16, INT8");
- }
+ TF_RETURN_IF_ERROR(GetPrecisionMode(
+ Uppercase(params.at("precision_mode").s()), &precision_mode_));
}
return tensorflow::Status::OK();
}
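
Editorial note: GetPrecisionMode() (from convert/utils.h, now included above) owns the string-to-enum mapping this block previously inlined. A Python sketch of the equivalent mapping, per the removed code (FP32 -> 0, FP16 -> 1, INT8 -> 2):

_PRECISION_MODES = {"FP32": 0, "FP16": 1, "INT8": 2}

def get_precision_mode(name):
    # Same behavior as the removed inline code: unknown names are an error.
    if name not in _PRECISION_MODES:
        raise ValueError("Unknown precision mode '%s'. Valid values are "
                         "FP32, FP16, INT8" % name)
    return _PRECISION_MODES[name]
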
@@ -255,7 +242,7 @@ tensorflow::Status TRTOptimizationPass::Optimize(
cp.input_graph_def = &item.graph;
cp.output_names = &nodes_to_preserve;
cp.max_batch_size = maximum_batch_size_;
- cp.max_workspace_size_bytes = maximum_workspace_size_;
+ cp.max_workspace_size_bytes = max_workspace_size_bytes_;
cp.output_graph_def = optimized_graph;
cp.precision_mode = precision_mode_;
cp.minimum_segment_size = minimum_segment_size_;
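
Editorial note: these parameters arrive through the grappler custom-optimizer config. A sketch of how a client populates them (values illustrative; the keys match the ones Init() reads, and the same pattern appears in _GetConfigProto() of the test base further below):

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2

rewriter_cfg = rewriter_config_pb2.RewriterConfig()
custom_op = rewriter_cfg.custom_optimizers.add()
custom_op.name = "TensorRTOptimizer"
custom_op.parameter_map["max_batch_size"].i = 8
custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25
custom_op.parameter_map["precision_mode"].s = b"FP32"
custom_op.parameter_map["minimum_segment_size"].i = 2
custom_op.parameter_map["is_dynamic_op"].b = False
custom_op.parameter_map["maximum_cached_engines"].i = 1
config = config_pb2.ConfigProto(
    graph_options=config_pb2.GraphOptions(rewrite_options=rewriter_cfg))
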
diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h
index 463ed3883e..71b51d1368 100644
--- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h
+++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h
@@ -36,7 +36,9 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer {
minimum_segment_size_(3),
precision_mode_(0),
maximum_batch_size_(-1),
- maximum_workspace_size_(-1) {
+ is_dynamic_op_(false),
+ max_cached_batches_(1),
+ max_workspace_size_bytes_(256LL << 20) {
VLOG(1) << "Constructing " << name_;
}
@@ -57,14 +59,14 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer {
const tensorflow::grappler::GrapplerItem& item);
private:
- string name_;
+ const string name_;
int minimum_segment_size_;
int precision_mode_;
int maximum_batch_size_;
bool is_dynamic_op_;
std::vector<int> batches_;
int max_cached_batches_;
- int64_t maximum_workspace_size_;
+ int64_t max_workspace_size_bytes_;
};
} // namespace convert
diff --git a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h b/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h
index bc15b51e05..19f39e6d3d 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h
+++ b/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h
@@ -42,4 +42,4 @@ class TRTResourceManager {
} // namespace tensorrt
} // namespace tensorflow
-#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCE_TRT_RESOURCE_MANAGER_H_
+#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCE_MANAGER_H_
diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc
index b43f1b190f..c82d4a0183 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment.cc
@@ -74,6 +74,7 @@ class SimpleNode {
const std::vector<SimpleEdge*>& in_edges() const { return in_edges_; }
const std::vector<SimpleEdge*>& out_edges() const { return out_edges_; }
+
std::vector<SimpleNode*> in_nodes() const {
std::vector<SimpleNode*> res;
res.reserve(in_edges_.size());
@@ -82,6 +83,16 @@ class SimpleNode {
}
return res;
}
+
+ std::vector<SimpleNode*> out_nodes() const {
+ std::vector<SimpleNode*> res;
+ res.reserve(out_edges_.size());
+ for (const auto e : out_edges_) {
+ if (e) res.push_back(e->dst());
+ }
+ return res;
+ }
+
const string& name() const { return node_->name(); }
const tensorflow::Node* tf_node() const { return node_; }
int id() const { return id_; }
@@ -215,45 +226,53 @@ SimpleGraph::~SimpleGraph() {
namespace {
-bool CheckCycles(const std::unique_ptr<SimpleGraph>& g, const SimpleNode* src,
- const std::vector<SimpleNode*>& start) {
- // Copied from TF ReverseDFS, which only works for tensorflow::Graph.
+// Copied from TF ReverseDFS, which only works for tensorflow::Graph.
+void StableDFS(const SimpleGraph& g, bool reverse,
+ const std::vector<const SimpleNode*>& start,
+ const std::function<bool(const SimpleNode*)>& enter,
+ const std::function<bool(const SimpleNode*)>& leave) {
+ // Stack of work to do.
struct Work {
- SimpleNode* node;
+ const SimpleNode* node;
bool leave; // Are we entering or leaving n?
};
-
std::vector<Work> stack(start.size());
for (int i = 0; i < start.size(); ++i) {
stack[i] = Work{start[i], false};
}
- std::vector<bool> visited(g->num_node_ids(), false);
+ auto get_nodes = reverse ? [](const SimpleNode* n) { return n->in_nodes(); }
+ : [](const SimpleNode* n) { return n->out_nodes(); };
+ std::vector<bool> visited(g.num_node_ids(), false);
while (!stack.empty()) {
Work w = stack.back();
stack.pop_back();
auto n = w.node;
if (w.leave) {
- if (n == src) {
- return true;
- }
+ if (leave && !leave(n)) return;
continue;
}
if (visited[n->id()]) continue;
visited[n->id()] = true;
- // Arrange to call leave(n) when all done with descendants.
- stack.push_back(Work{n, true});
+ if (enter && !enter(n)) return;
- auto nodes = n->in_nodes();
- for (const auto node : nodes) {
+ // Arrange to call leave(n) when all done with descendants.
+ if (leave) stack.push_back(Work{n, true});
+
+ auto nodes = get_nodes(n);
+ std::vector<const SimpleNode*> nodes_sorted(nodes.begin(), nodes.end());
+ std::sort(nodes_sorted.begin(), nodes_sorted.end(),
+ [](const SimpleNode* lhs, const SimpleNode* rhs) {
+ return lhs->name() < rhs->name();
+ });
+ for (const SimpleNode* node : nodes_sorted) {
if (!visited[node->id()]) {
stack.push_back(Work{node, false});
}
}
}
- return false;
}
bool CanContractEdge(const SimpleEdge* edge,
@@ -289,14 +308,21 @@ bool CanContractEdge(const SimpleEdge* edge,
// To achieve this goal, the correct way seems to be:
// 1. remove any direct edge from src->dst;
// 2. detect if src can reach dst, if so they cannot be merged.
- std::vector<SimpleNode*> dfs_start_nodes;
- for (SimpleNode* node : dst->in_nodes()) {
+ std::vector<const SimpleNode*> dfs_start_nodes;
+ for (const SimpleNode* node : dst->in_nodes()) {
if (node != src) {
dfs_start_nodes.push_back(node);
}
}
-
- const bool has_cycle = CheckCycles(graph, src, dfs_start_nodes);
+ bool has_cycle = false;
+ StableDFS(*graph, /*reverse=*/true, dfs_start_nodes, /*enter=*/nullptr,
+ [&has_cycle, src](const SimpleNode* n) {
+ if (n == src) {
+ has_cycle = true;
+ return false;
+ }
+ return true;
+ });
return !has_cycle;
}
} // namespace
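
Editorial note: contracting edge src->dst is legal only if dst is not reachable from src through some other path. A simplified Python sketch of the check (the real code reuses StableDFS and tests for src in the leave callback):

def can_contract_edge(src, dst, in_nodes):
    # in_nodes maps a node to its predecessors. Walk backwards from dst's
    # other predecessors; reaching src means another src -> ... -> dst path
    # exists, so contraction would create a cycle.
    stack = [n for n in in_nodes[dst] if n is not src]
    visited = set()
    while stack:
        n = stack.pop()
        if n is src:
            return False
        if n in visited:
            continue
        visited.add(n)
        stack.extend(in_nodes.get(n, ()))
    return True
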
@@ -403,15 +429,13 @@ tensorflow::Status SegmentGraph(
// In the future if we have a measure of how beneficial it is to include a
// given node in a TRT subgraph then we can revisit this algorithm to take
// advantage of that information.
- std::vector<tensorflow::Node*> tforder;
- tensorflow::GetPostOrder(*tf_graph, &tforder);
- // use postorder implementation from tensorflow and construct mirror in
- // internal format
- std::vector<SimpleNode*> order;
- order.reserve(tforder.size());
- for (const auto tfnode : tforder) {
- order.push_back(graph->FindNodeId(tfnode->id()));
- }
+ std::vector<const SimpleNode*> order;
+ order.reserve(graph->num_node_ids());
+ StableDFS(*graph, /*reverse=*/false, {graph->source_node()},
+ /*enter=*/nullptr, [&order](const SimpleNode* n) {
+ order.push_back(n);
+ return true;
+ });
for (const SimpleNode* node : order) {
// All output nodes of 'node' have been visited...
VLOG(3) << "Trying node " << node->name() << " id=" << node->id();
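
Editorial note: SegmentGraph() now derives its visit order from StableDFS instead of GetPostOrder, which makes segment numbering deterministic: neighbors expand in name-sorted order, and leave(n) fires only after every node reachable from n has finished. A condensed Python sketch of the traversal (nodes represented by their names; the early-exit return values of the callbacks are omitted):

def stable_dfs(out_nodes, start, leave):
    # Iterative DFS; out_nodes maps name -> iterable of successor names.
    stack = [(n, False) for n in start]
    visited = set()
    while stack:
        node, leaving = stack.pop()
        if leaving:
            leave(node)  # all descendants of node are done
            continue
        if node in visited:
            continue
        visited.add(node)
        stack.append((node, True))  # revisit to run leave() afterwards
        # Sorted expansion keeps the order independent of hash or insertion
        # order, mirroring the name sort above.
        for nbr in sorted(out_nodes.get(node, ())):
            if nbr not in visited:
                stack.append((nbr, False))

order = []
stable_dfs({"a": ["c", "b"], "b": ["c"], "c": []}, ["a"], leave=order.append)
assert order == ["c", "b", "a"]  # each node appears after its out-nodes
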
diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py
index 8ea5a63735..e9ac833d55 100644
--- a/tensorflow/contrib/tensorrt/test/base_test.py
+++ b/tensorflow/contrib/tensorrt/test/base_test.py
@@ -40,6 +40,7 @@ class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [100, 24, 24, 2]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -62,19 +63,21 @@ class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase):
identity = array_ops.identity(relu, "identity")
pool = nn_ops.max_pool(
identity, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
- array_ops.squeeze(pool, name=self.output_name)
+ array_ops.squeeze(pool, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which
- # breaks the connection check, fix it.
- # - my_trt_op_0 should have ["weights", "conv", "bias", "bias_add",
- # "relu", "identity", "max_pool"]
- expected_engines=["my_trt_op_0"],
- expected_output_dims=(100, 6, 6, 6),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(100, 6, 6, 6)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which
+ # breaks the connection check, fix it.
+ # - my_trt_op_0 should have ["weights", "conv", "bias", "bias_add",
+ # "relu", "identity", "max_pool"]
+ return ["my_trt_op_0"]
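
Editorial note: this refactoring repeats across all the test files below. GetParams() now carries only the graph and I/O metadata, while expected engines and tolerances moved to overridable methods. A minimal hypothetical subclass showing the new shape of the API (imports as in the tests above):

class MinimalTrtTest(trt_test.TfTrtIntegrationTestBase):
  # Hypothetical example, not part of the change.

  def GetParams(self):
    g = ops.Graph()
    with g.as_default():
      inp = array_ops.placeholder(
          dtypes.float32, shape=[2, 8, 8, 3], name="input")
      n = math_ops.mul(inp, inp, name="mul")
      n = math_ops.add(n, n, name="add")
      array_ops.squeeze(n, name="output")
    return trt_test.TfTrtIntegrationTestParams(
        gdef=g.as_graph_def(),
        input_names=["input"],
        input_dims=[[2, 8, 8, 3]],
        output_names=["output"],
        expected_output_dims=[(2, 8, 8, 3)])

  def ExpectedEnginesToBuild(self, run_params):
    return ["my_trt_op_0"]
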
class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase):
@@ -85,6 +88,7 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [100, 24, 24, 2]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -115,20 +119,22 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase):
q = math_ops.mul(q, edge, name="mul1")
s = math_ops.add(p, q, name="add1")
s = math_ops.sub(s, r, name="sub1")
- array_ops.squeeze(s, name=self.output_name)
+ array_ops.squeeze(s, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which
- # breaks the connection check, fix it.
- # - my_trt_op_0 should have ["mul", "sub", "div1", "mul1", "add1",
- # "add", "sub1"];
- # - my_trt_op_1 should have ["weights","conv", "div"]
- expected_engines=["my_trt_op_0", "my_trt_op_1"],
- expected_output_dims=(100, 12, 12, 6),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(100, 12, 12, 6)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which
+ # breaks the connection check, fix it.
+ # - my_trt_op_0 should have ["mul", "sub", "div1", "mul1", "add1",
+ # "add", "sub1"];
+ # - my_trt_op_1 should have ["weights","conv", "div"]
+ return ["my_trt_op_0", "my_trt_op_1"]
class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase):
@@ -143,6 +149,7 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase):
"""Create a graph containing two segment."""
input_name = "input"
input_dims = [2, 32, 32, 3]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -161,18 +168,20 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase):
c = constant_op.constant(1.0, name="c3")
n = math_ops.add(n, c, name="add3")
n = math_ops.mul(n, n, name="mul3")
- array_ops.squeeze(n, name=self.output_name)
+ array_ops.squeeze(n, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines={
- # Only the first engine is built.
- "my_trt_op_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"]
- },
- expected_output_dims=tuple(input_dims),
- allclose_atol=1.e-06,
- allclose_rtol=1.e-06)
+ output_names=[output_name],
+ expected_output_dims=[tuple(input_dims)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {
+ # Only the first engine is built.
+ "my_trt_op_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"]
+ }
class PartiallyConvertedTestB(PartiallyConvertedTestA):
@@ -184,13 +193,12 @@ class PartiallyConvertedTestB(PartiallyConvertedTestA):
trt_convert.clear_test_values("")
trt_convert.add_test_value("my_trt_op_0:CreateTRTNode", "fail")
- def GetParams(self):
- """Create a graph containing two segment."""
- return super(PartiallyConvertedTestB, self).GetParams()._replace(
- expected_engines={
- # Only the second engine is built.
- "my_trt_op_1": ["c2", "c3", "add2", "add3", "mul2", "mul3"]
- })
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {
+ # Only the second engine is built.
+ "my_trt_op_1": ["c2", "c3", "add2", "add3", "mul2", "mul3"]
+ }
class ConstInputTest(trt_test.TfTrtIntegrationTestBase):
@@ -199,6 +207,7 @@ class ConstInputTest(trt_test.TfTrtIntegrationTestBase):
"""Create a graph containing multiple segment."""
input_name = "input"
input_dims = [2, 32, 32, 3]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -221,18 +230,20 @@ class ConstInputTest(trt_test.TfTrtIntegrationTestBase):
n = math_ops.add(n, c, name="add2")
n = math_ops.mul(n, n, name="mul1")
n = math_ops.add(n, n, name="add3")
- array_ops.squeeze(n, name=self.output_name)
+ array_ops.squeeze(n, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines={
- "my_trt_op_0": ["add", "add1", "mul"],
- "my_trt_op_1": ["add2", "add3", "mul1"]
- },
- expected_output_dims=tuple(input_dims),
- allclose_atol=1.e-06,
- allclose_rtol=1.e-06)
+ output_names=[output_name],
+ expected_output_dims=[tuple(input_dims)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {
+ "my_trt_op_0": ["add", "add1", "mul"],
+ "my_trt_op_1": ["add2", "add3", "mul1"]
+ }
class ConstDataInputSingleEngineTest(trt_test.TfTrtIntegrationTestBase):
@@ -241,6 +252,7 @@ class ConstDataInputSingleEngineTest(trt_test.TfTrtIntegrationTestBase):
"""Create a graph containing single segment."""
input_name = "input"
input_dims = [2, 32, 32, 3]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -251,15 +263,17 @@ class ConstDataInputSingleEngineTest(trt_test.TfTrtIntegrationTestBase):
n = math_ops.add(n, c, name="add")
n = math_ops.mul(n, n, name="mul")
n = math_ops.add(n, n, name="add1")
- array_ops.squeeze(n, name=self.output_name)
+ array_ops.squeeze(n, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines={"my_trt_op_0": ["c", "add", "add1", "mul"]},
- expected_output_dims=tuple(input_dims),
- allclose_atol=1.e-06,
- allclose_rtol=1.e-06)
+ output_names=[output_name],
+ expected_output_dims=[tuple(input_dims)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {"my_trt_op_0": ["c", "add", "add1", "mul"]}
class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase):
@@ -268,6 +282,7 @@ class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase):
"""Create a graph containing multiple segment."""
input_name = "input"
input_dims = [2, 32, 32, 3]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -282,22 +297,24 @@ class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase):
n = math_ops.add(n, c, name="add2")
n = math_ops.mul(n, n, name="mul1")
n = math_ops.add(n, n, name="add3")
- array_ops.squeeze(n, name=self.output_name)
+ array_ops.squeeze(n, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines={
- "my_trt_op_0": ["add2", "add3", "mul1"],
- # Why segment ["add", "add1", "mul"] was assigned segment id 1
- # instead of 0: the parent node of this segment is actually const
- # node 'c', but it's removed later since it's const output of the
- # segment which is not allowed.
- "my_trt_op_1": ["add", "add1", "mul"]
- },
- expected_output_dims=tuple(input_dims),
- allclose_atol=1.e-06,
- allclose_rtol=1.e-06)
+ output_names=[output_name],
+ expected_output_dims=[tuple(input_dims)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {
+ "my_trt_op_0": ["add2", "add3", "mul1"],
+ # Why segment ["add", "add1", "mul"] was assigned segment id 1
+        # instead of 0: the parent node of this segment is actually the const
+        # node 'c', but it is removed later since a const output of the
+        # segment is not allowed.
+ "my_trt_op_1": ["add", "add1", "mul"]
+ }
class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase):
@@ -306,6 +323,7 @@ class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase):
"""Create a graph containing multiple segment."""
input_name = "input"
input_dims = [2, 32, 32, 3]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -328,18 +346,20 @@ class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase):
mul1 = math_ops.mul(add2, add2, name="mul1")
with g.control_dependencies([d1, d2, add, add1]):
add3 = math_ops.add(mul1, mul1, name="add3")
- array_ops.squeeze(add3, name=self.output_name)
+ array_ops.squeeze(add3, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines={
- "my_trt_op_0": ["c1", "add", "add1", "mul"],
- "my_trt_op_1": ["c2", "add2", "add3", "mul1"]
- },
- expected_output_dims=tuple(input_dims),
- allclose_atol=1.e-06,
- allclose_rtol=1.e-06)
+ output_names=[output_name],
+ expected_output_dims=[tuple(input_dims)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {
+ "my_trt_op_0": ["c1", "add", "add1", "mul"],
+ "my_trt_op_1": ["c2", "add2", "add3", "mul1"]
+ }
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py
index 2e1107e303..2f153c6f2f 100644
--- a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py
+++ b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py
@@ -37,6 +37,7 @@ class BatchMatMulTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [12, 5, 8, 12]
+ output_name = "output"
w1_name = "matmul_w1"
w1_dims = [12, 5, 12, 7]
w2_name = "matmul_w2"
@@ -61,15 +62,46 @@ class BatchMatMulTest(trt_test.TfTrtIntegrationTestBase):
x3 = x3 + f
x3 = gen_array_ops.reshape(x3, [12, 5, 8, 7])
out = x1 + x2 + x3
- array_ops.squeeze(out, name=self.output_name)
+ array_ops.squeeze(out, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name, w1_name, w2_name],
input_dims=[input_dims, w1_dims, w2_dims],
- expected_engines=["my_trt_op_0"],
- expected_output_dims=(12, 5, 8, 7),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(12, 5, 8, 7)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ if (run_params.dynamic_engine and
+ not trt_test.IsQuantizationMode(run_params.precision_mode)):
+ return ["my_trt_op_0", "my_trt_op_1"]
+ return ["my_trt_op_1"]
+
+ def ExpectedEnginesToRun(self, run_params):
+ """Return the expected engines to run."""
+ return ["my_trt_op_1"]
+
+ def ShouldRunTest(self, run_params):
+ """Whether to run the test."""
+    # TODO(aaroey): TRT library will fail like:
+ #
+ # ../builder/cudnnBuilder2.cpp:685:
+ # virtual std::vector<nvinfer1::query::Ports<
+ # nvinfer1::query::TensorRequirements>>
+ # nvinfer1::builder::Node::getSupportedFormats(
+ # const nvinfer1::query::Ports<nvinfer1::query::AbstractTensor>&,
+ # const nvinfer1::cudnn::HardwareContext&,
+ # nvinfer1::builder::Format::Type,
+ # const nvinfer1::builder::FormatTypeHack&) const:
+ # Assertion `sf' failed.
+ #
+ # To reproduce, run:
+ # bazel test -c opt --copt=-mavx \
+ # --test_arg=BatchMatMulTest.testTfTrt_ToolConversion_INT8_DynamicEngine \
+ # tensorflow/contrib/tensorrt:batch_matmul_test
+ #
+ # Investigate and fix it.
+ return not trt_test.IsQuantizationMode(run_params.precision_mode)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
index 8be32f59b4..62f4e525f7 100644
--- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
+++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
@@ -38,6 +38,7 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [48, 12]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
@@ -97,18 +98,59 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase):
out = array_ops.concat(
[x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11], axis=-1)
- out = array_ops.squeeze(out, name=self.output_name)
+ out = array_ops.squeeze(out, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=[
- "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3",
- "my_trt_op_4", "my_trt_op_5", "my_trt_op_6"
- ],
- expected_output_dims=(48, 89),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(48, 89)])
+
+ def GetConversionParams(self, run_params):
+ """Return a ConversionParams for test."""
+ return super(BiasaddMatMulTest,
+ self).GetConversionParams(run_params)._replace(
+ max_batch_size=48, maximum_cached_engines=2)
+
+ def _ValidEngines(self):
+ """Engines expected to build and run."""
+ return [
+ "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_6",
+ "my_trt_op_7", "my_trt_op_8", "my_trt_op_9"
+ ]
+
+ def _InvalidEngines(self):
+    """Engines that will cause a conversion error at build time."""
+ return ["my_trt_op_3", "my_trt_op_4", "my_trt_op_5"]
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+    # In dynamic engine mode the engines are built at execution time, not at
+    # conversion time, so build errors occur later. Here three of the engines
+    # will fail to build, but the corresponding engine ops are still created.
+ # TODO(aaroey, jjsjann123): fix this.
+ if (run_params.dynamic_engine and
+ not trt_test.IsQuantizationMode(run_params.precision_mode)):
+ return self._ValidEngines() + self._InvalidEngines()
+ return self._ValidEngines()
+
+ def ExpectedEnginesToRun(self, run_params):
+ """Return the expected engines to run."""
+ return self._ValidEngines()
+
+ def ShouldRunTest(self, run_params):
+ """Whether to run the test."""
+    # TODO(aaroey): TRT 4.0 forbids conversion for tensors with rank <3 in INT8
+    # mode, which is a bug. Re-enable this when the TRT library is fixed.
+ return not trt_test.IsQuantizationMode(run_params.precision_mode)
+
+ def ExpectedAbsoluteTolerance(self, run_params):
+ """The absolute tolerance to compare floating point results."""
+ return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-03
+
+ def ExpectedRelativeTolerance(self, run_params):
+ """The relative tolerance to compare floating point results."""
+ return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-03
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
index 9316b14da0..f126ed4238 100644
--- a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
+++ b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
@@ -37,6 +37,7 @@ class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [10, 24, 24, 20]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
@@ -104,32 +105,34 @@ class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase):
a = constant_op.constant(np.random.randn(24, 20), dtype=dtype)
f = x + a
x = math_ops.sigmoid(f)
- gen_array_ops.reshape(x, [5, -1], name=self.output_name)
+ gen_array_ops.reshape(x, [5, -1], name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=[
- "my_trt_op_0",
- "my_trt_op_1",
- "my_trt_op_2",
- "my_trt_op_3",
- "my_trt_op_4",
- "my_trt_op_5",
- "my_trt_op_6",
- "my_trt_op_7",
- "my_trt_op_8",
- "my_trt_op_9",
- "my_trt_op_10",
- "my_trt_op_11",
- "my_trt_op_12",
- "my_trt_op_13",
- "my_trt_op_14",
- "my_trt_op_15",
- ],
- expected_output_dims=(5, 23040),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(5, 23040)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return [
+ "my_trt_op_0",
+ "my_trt_op_1",
+ "my_trt_op_2",
+ "my_trt_op_3",
+ "my_trt_op_4",
+ "my_trt_op_5",
+ "my_trt_op_6",
+ "my_trt_op_7",
+ "my_trt_op_8",
+ "my_trt_op_9",
+ "my_trt_op_10",
+ "my_trt_op_11",
+ "my_trt_op_12",
+ "my_trt_op_13",
+ "my_trt_op_14",
+ "my_trt_op_15",
+ ]
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/concatenation_test.py b/tensorflow/contrib/tensorrt/test/concatenation_test.py
index 1874b9dd45..465cb02296 100644
--- a/tensorflow/contrib/tensorrt/test/concatenation_test.py
+++ b/tensorflow/contrib/tensorrt/test/concatenation_test.py
@@ -37,6 +37,7 @@ class ConcatenationTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [2, 3, 3, 1]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
@@ -68,15 +69,17 @@ class ConcatenationTest(trt_test.TfTrtIntegrationTestBase):
concat1 = array_ops.concat([r1, r2, r3, r4, r5, r6], axis=-1)
concat2 = array_ops.concat([r7, r8, r9, r10, r11, r12], axis=3)
x = array_ops.concat([concat1, concat2], axis=-1)
- gen_array_ops.reshape(x, [2, -1], name=self.output_name)
+ gen_array_ops.reshape(x, [2, -1], name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=["my_trt_op_0"],
- expected_output_dims=(2, 126),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(2, 126)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return ["my_trt_op_0"]
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py
index 8c59000b70..e32f047866 100644
--- a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py
+++ b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py
@@ -36,6 +36,7 @@ class ConstBroadcastTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = 'input'
input_dims = [5, 12, 12, 2]
+ output_name = 'output'
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
@@ -53,15 +54,25 @@ class ConstBroadcastTest(trt_test.TfTrtIntegrationTestBase):
dtype=dtype,
name='filt3')
y3 = nn.conv2d(z2, filt3, strides=[1, 1, 1, 1], padding='SAME', name='y3')
- nn.relu(y3, name='output')
+ nn.relu(y3, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=['my_trt_op_0'],
- expected_output_dims=(5, 12, 12, 1),
- allclose_atol=1.e-02,
- allclose_rtol=1.e-02)
+ output_names=[output_name],
+ expected_output_dims=[(5, 12, 12, 1)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return ['my_trt_op_0']
+
+ def ExpectedAbsoluteTolerance(self, run_params):
+ """The absolute tolerance to compare floating point results."""
+ return 1.e-04 if run_params.precision_mode == 'FP32' else 1.e-02
+
+ def ExpectedRelativeTolerance(self, run_params):
+ """The relative tolerance to compare floating point results."""
+ return 1.e-04 if run_params.precision_mode == 'FP32' else 1.e-02
if __name__ == '__main__':
diff --git a/tensorflow/contrib/tensorrt/test/manual_test.py b/tensorflow/contrib/tensorrt/test/manual_test.py
new file mode 100644
index 0000000000..1187c759b4
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/manual_test.py
@@ -0,0 +1,114 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Basic tests for TF-TensorRT integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ast
+import os
+
+from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+
+
+class ManualTest(trt_test.TfTrtIntegrationTestBase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ super(ManualTest, self).__init__(methodName)
+ self._params_map = None
+
+ def _GetEnv(self):
+ """Get an environment variable specifying the manual test parameters.
+
+ The value of the environment variable is the string representation of a dict
+ which should contain the following keys:
+ - 'graph_path': the file path to the serialized frozen graphdef
+ - 'input_names': TfTrtIntegrationTestParams.input_names
+ - 'input_dims': TfTrtIntegrationTestParams.input_dims
+ - 'expected_output_dims': TfTrtIntegrationTestParams.expected_output_dims
+    - 'output_names': TfTrtIntegrationTestParams.output_names
+ - 'expected_engines_to_run': ExpectedEnginesToRun() will return this
+ - 'expected_engines_to_build': ExpectedEnginesToBuild() will return this
+ - 'max_batch_size': ConversionParams.max_batch_size
+
+ Returns:
+ The value of the environment variable.
+ """
+ return os.getenv('TRT_MANUAL_TEST_PARAMS', '')
+
+ def _GetParamsMap(self):
+ """Parse the environment variable as a dict and return it."""
+ if self._params_map is None:
+ self._params_map = ast.literal_eval(self._GetEnv())
+ return self._params_map
+
+ def GetParams(self):
+    """Testing conversion of a manually provided frozen graph."""
+ params_map = self._GetParamsMap()
+ gdef = graph_pb2.GraphDef()
+ with gfile.Open(params_map['graph_path'], 'rb') as f:
+ gdef.ParseFromString(f.read())
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=gdef,
+ input_names=params_map['input_names'],
+ input_dims=params_map['input_dims'],
+ output_names=params_map['output_names'],
+ expected_output_dims=params_map['expected_output_dims'])
+
+ def GetConversionParams(self, run_params):
+ """Return a ConversionParams for test."""
+ conversion_params = super(ManualTest, self).GetConversionParams(run_params)
+ params_map = self._GetParamsMap()
+ if 'max_batch_size' in params_map:
+ conversion_params = conversion_params._replace(
+ max_batch_size=params_map['max_batch_size'])
+ return conversion_params
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return self._GetParamsMap()['expected_engines_to_build']
+
+ def ExpectedEnginesToRun(self, run_params):
+ """Return the expected engines to run."""
+ params_map = self._GetParamsMap()
+ if 'expected_engines_to_run' in params_map:
+ return params_map['expected_engines_to_run']
+ return self.ExpectedEnginesToBuild(run_params)
+
+ def ExpectedAbsoluteTolerance(self, run_params):
+ """The absolute tolerance to compare floating point results."""
+ params_map = self._GetParamsMap()
+ if 'atol' in params_map:
+ return params_map['atol']
+ return 1.e-3
+
+ def ExpectedRelativeTolerance(self, run_params):
+ """The relative tolerance to compare floating point results."""
+ params_map = self._GetParamsMap()
+ if 'rtol' in params_map:
+ return params_map['rtol']
+ return 1.e-3
+
+ def ShouldRunTest(self, run_params):
+ """Whether to run the test."""
+ return len(self._GetEnv())
+
+
+if __name__ == '__main__':
+ test.main()
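
Editorial note: the manual test is driven entirely by the TRT_MANUAL_TEST_PARAMS environment variable, whose value is a Python dict literal parsed with ast.literal_eval(). A hypothetical invocation (path, names, and dims illustrative):

import os

os.environ["TRT_MANUAL_TEST_PARAMS"] = str({
    "graph_path": "/tmp/frozen_graph.pb",  # a serialized frozen GraphDef
    "input_names": ["input"],
    "input_dims": [[8, 224, 224, 3]],
    "output_names": ["output"],
    "expected_output_dims": [(8, 1001)],
    "expected_engines_to_build": ["my_trt_op_0"],
    "max_batch_size": 8,  # optional, as are 'expected_engines_to_run',
                          # 'atol' and 'rtol'
})
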
diff --git a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
index 66eb6be757..bc7c90081f 100644
--- a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
+++ b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
@@ -36,6 +36,7 @@ class MemoryAlignmentTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [2, 15, 15, 3]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -57,15 +58,25 @@ class MemoryAlignmentTest(trt_test.TfTrtIntegrationTestBase):
strides=[1, 1, 1, 1],
padding="VALID",
name="conv_2")
- array_ops.squeeze(out, name=self.output_name)
+ array_ops.squeeze(out, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=["my_trt_op_0"],
- expected_output_dims=(2, 15, 15, 10),
- allclose_atol=1.e-02,
- allclose_rtol=1.e-02)
+ output_names=[output_name],
+ expected_output_dims=[(2, 15, 15, 10)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return ["my_trt_op_0"]
+
+ def ExpectedAbsoluteTolerance(self, run_params):
+ """The absolute tolerance to compare floating point results."""
+ return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-02
+
+ def ExpectedRelativeTolerance(self, run_params):
+ """The relative tolerance to compare floating point results."""
+ return 0.1
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
index fd55b8cd99..11be4feaf7 100644
--- a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
+++ b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
@@ -38,6 +38,7 @@ class MultiConnectionNeighborEngineTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [2, 3, 7, 5]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
@@ -72,15 +73,17 @@ class MultiConnectionNeighborEngineTest(trt_test.TfTrtIntegrationTestBase):
t = t + q
t = t + d
t = t - edge3
- array_ops.squeeze(t, name=self.output_name)
+ array_ops.squeeze(t, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=["my_trt_op_0", "my_trt_op_1"],
- expected_output_dims=(2, 4, 5, 4),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(2, 4, 5, 4)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return ["my_trt_op_0", "my_trt_op_1"]
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
index 51c905a50b..eddeafa38b 100644
--- a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
+++ b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
@@ -37,6 +37,7 @@ class NeighboringEngineTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [2, 3, 7, 5]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
@@ -54,18 +55,20 @@ class NeighboringEngineTest(trt_test.TfTrtIntegrationTestBase):
t = math_ops.mul(conv, b, name="mul")
e = self.trt_incompatible_op(conv, name="incompatible")
t = math_ops.sub(t, e, name="sub")
- array_ops.squeeze(t, name=self.output_name)
+ array_ops.squeeze(t, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines={
- "my_trt_op_0": ["bias", "mul", "sub"],
- "my_trt_op_1": ["weights", "conv"]
- },
- expected_output_dims=(2, 4, 5, 4),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(2, 4, 5, 4)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {
+ "my_trt_op_0": ["bias", "mul", "sub"],
+ "my_trt_op_1": ["weights", "conv"]
+ }
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/rank_two_test.py b/tensorflow/contrib/tensorrt/test/rank_two_test.py
new file mode 100644
index 0000000000..74a4a05925
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/rank_two_test.py
@@ -0,0 +1,89 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Model script to test TF-TensorRT integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class RankTwoTest(trt_test.TfTrtIntegrationTestBase):
+
+ def GetParams(self):
+ """Test for rank 2 input in TF-TRT."""
+ input_names = ["input", "input2"]
+ # Two paths: first with rank 2 input, second with rank 4 input.
+ input_dims = [[12, 5], [12, 5, 2, 2]]
+ output_name = "output"
+ g = ops.Graph()
+ with g.as_default():
+ outputs = []
+ for i in range(2):
+ x = array_ops.placeholder(
+ dtype=dtypes.float32, shape=input_dims[i], name=input_names[i])
+ c = constant_op.constant(1.0, name="c%d_1" % i)
+ q = math_ops.add(x, c, name="add%d_1" % i)
+ q = math_ops.abs(q, name="abs%d_1" % i)
+ c = constant_op.constant(2.2, name="c%d_2" % i)
+ q = math_ops.add(q, c, name="add%d_2" % i)
+ q = math_ops.abs(q, name="abs%d_2" % i)
+ c = constant_op.constant(3.0, name="c%d_3" % i)
+ q = math_ops.add(q, c, name="add%d_3" % i)
+ if i == 0:
+ for j in range(2):
+ q = array_ops.expand_dims(q, -1, name="expand%d_%d" % (i, j))
+ q = gen_math_ops.reciprocal(q, name="reciprocal%d" % i)
+ outputs.append(q)
+ # Combine both paths
+ q = math_ops.add(outputs[0], outputs[1], name="add")
+ array_ops.squeeze(q, name=output_name)
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=g.as_graph_def(),
+ input_names=input_names,
+ input_dims=input_dims,
+ output_names=[output_name],
+ expected_output_dims=[tuple(input_dims[1])])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {
+ "my_trt_op_0": [
+ "add0_1", "add0_2", "add0_3", "c0_1", "c0_2", "c0_3", "abs0_1",
+ "abs0_2"
+ ],
+ "my_trt_op_1": [
+ "add", "add1_1", "add1_2", "add1_3", "c1_1", "c1_2", "c1_3",
+ "abs1_1", "abs1_2", "reciprocal0", "reciprocal1"
+ ],
+ }
+
+ def ShouldRunTest(self, run_params):
+ """Whether to run the test."""
+    # TODO(aaroey): TRT 4.0 forbids conversion for tensors with rank <3 in INT8
+    # mode, which is a bug. Re-enable this when the TRT library is fixed.
+ return not trt_test.IsQuantizationMode(run_params.precision_mode)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
index 6f85ada464..65ca21cf37 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
@@ -31,6 +31,7 @@ from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
# pylint: enable=unused-import
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
@@ -39,18 +40,23 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
TfTrtIntegrationTestParams = namedtuple("TfTrtIntegrationTestParams", [
- "gdef", "input_names", "input_dims", "expected_engines",
- "expected_output_dims", "allclose_atol", "allclose_rtol"
+ "gdef", "input_names", "input_dims", "output_names", "expected_output_dims"
])
RunParams = namedtuple(
"RunParams",
["use_optimizer", "precision_mode", "dynamic_engine", "test_name"])
+ConversionParams = namedtuple("ConversionParams", [
+ "max_batch_size", "max_workspace_size_bytes", "precision_mode",
+ "minimum_segment_size", "is_dynamic_op", "maximum_cached_engines",
+ "cached_engine_batches"
+])
+
PRECISION_MODES = ["FP32", "FP16", "INT8"]
-def _IsQuantizationMode(mode):
+def IsQuantizationMode(mode):
return mode == "INT8"
@@ -64,10 +70,6 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
"""Class to test Tensorflow-TensorRT integration."""
@property
- def output_name(self):
- return "output"
-
- @property
def trt_incompatible_op(self):
return math_ops.sin
@@ -112,6 +114,10 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
super(TfTrtIntegrationTestBase, cls).setUpClass()
trt_convert.enable_test_value()
+ def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
+ super(TfTrtIntegrationTestBase, self).__init__(methodName)
+ self._trt_test_params = None
+
def setUp(self):
"""Setup method."""
super(TfTrtIntegrationTestBase, self).setUp()
@@ -122,43 +128,97 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
"""Return a TfTrtIntegrationTestParams for test, implemented by subclass."""
raise NotImplementedError()
- def _PrepareRun(self, params, graph_state):
+ def GetConversionParams(self, run_params):
+ """Return a ConversionParams for test."""
+ return ConversionParams(
+ max_batch_size=max([
+ dims[0] for dims in self._GetParamsCached().input_dims if len(dims)
+ ]),
+ max_workspace_size_bytes=1 << 25,
+ precision_mode=self._ToBytes(run_params.precision_mode),
+ minimum_segment_size=2,
+ is_dynamic_op=run_params.dynamic_engine,
+ maximum_cached_engines=1,
+ cached_engine_batches=None)
+
+ def ShouldRunTest(self, run_params):
+ """Whether to run the test."""
+ return True
+
+ def VerifyRunForEngine(self, engine_name, graph_state, expect_run=True):
+ """Verify the state of a particular engine after sess.run()."""
+ if graph_state == GraphState.ORIGINAL:
+ self._ExpectCalibration(engine_name, "")
+ self._ExpectNativeSegment(engine_name, "")
+ self._ExpectTrtEngine(engine_name, "")
+ elif graph_state == GraphState.CALIBRATE:
+ self._ExpectCalibration(engine_name, "done")
+ self._ExpectNativeSegment(engine_name, "done")
+ self._ExpectTrtEngine(engine_name, "")
+ elif graph_state == GraphState.INFERENCE:
+ self._ExpectCalibration(engine_name, "")
+ if expect_run:
+ self._ExpectNativeSegment(engine_name, "")
+ self._ExpectTrtEngine(engine_name, "done")
+ else:
+ self._ExpectNativeSegment(engine_name, "done")
+ self._ExpectTrtEngine(engine_name, "")
+
+ def VerifyRun(self, run_params, graph_state):
+ """Verify the state of all engines after sess.run()."""
+ for engine_name in self.ExpectedEnginesToBuild(run_params):
+ expect_run = (engine_name in self.ExpectedEnginesToRun(run_params))
+ self.VerifyRunForEngine(engine_name, graph_state, expect_run)
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build, implemented by subclass."""
+ raise NotImplementedError()
+
+ def ExpectedEnginesToRun(self, run_params):
+ """Return the expected engines to run."""
+ return self.ExpectedEnginesToBuild(run_params)
+
+ def ExpectedAbsoluteTolerance(self, run_params):
+ """The absolute tolerance to compare floating point results."""
+ return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-03
+
+ def ExpectedRelativeTolerance(self, run_params):
+ """The relative tolerance to compare floating point results."""
+ return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-03
+
+ def _GetParamsCached(self):
+ if self._trt_test_params is None:
+ self._trt_test_params = self.GetParams()
+ return self._trt_test_params
+
+ def _PrepareRun(self, graph_state):
"""Set up necessary testing environment before calling sess.run()."""
# Clear test values added by TRTEngineOp.
trt_convert.clear_test_values("my_trt_op_.*:ExecuteTrtEngine")
trt_convert.clear_test_values("my_trt_op_.*:ExecuteCalibration")
trt_convert.clear_test_values("my_trt_op_.*:ExecuteNativeSegment")
- def _VerifyRun(self, params, graph_state):
- """Verify the state after sess.run()."""
- for engine_name in params.expected_engines:
- if graph_state == GraphState.ORIGINAL:
- self._ExpectCalibration(engine_name, "")
- self._ExpectNativeSegment(engine_name, "")
- self._ExpectTrtEngine(engine_name, "")
- elif graph_state == GraphState.CALIBRATE:
- self._ExpectCalibration(engine_name, "done")
- self._ExpectNativeSegment(engine_name, "done")
- self._ExpectTrtEngine(engine_name, "")
- elif graph_state == GraphState.INFERENCE:
- self._ExpectCalibration(engine_name, "")
- self._ExpectNativeSegment(engine_name, "")
- self._ExpectTrtEngine(engine_name, "done")
-
- def _GetConfigProto(self, params, run_params, graph_state):
+ def _GetConfigProto(self, run_params, graph_state):
"""Get config proto based on specific settings."""
if graph_state != GraphState.ORIGINAL and run_params.use_optimizer:
rewriter_cfg = rewriter_config_pb2.RewriterConfig()
rewriter_cfg.optimizers.extend(["constfold", "layout"])
custom_op = rewriter_cfg.custom_optimizers.add()
custom_op.name = "TensorRTOptimizer"
- custom_op.parameter_map["minimum_segment_size"].i = 2
- custom_op.parameter_map["max_batch_size"].i = max(
- [dims[0] for dims in params.input_dims])
- custom_op.parameter_map["is_dynamic_op"].b = run_params.dynamic_engine
- custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25
- custom_op.parameter_map["precision_mode"].s = self._ToBytes(
- run_params.precision_mode)
+ trt_params = self.GetConversionParams(run_params)
+ custom_op.parameter_map["max_batch_size"].i = trt_params.max_batch_size
+ custom_op.parameter_map["max_workspace_size_bytes"].i = (
+ trt_params.max_workspace_size_bytes)
+ custom_op.parameter_map["precision_mode"].s = trt_params.precision_mode
+ custom_op.parameter_map["minimum_segment_size"].i = (
+ trt_params.minimum_segment_size)
+ custom_op.parameter_map["is_dynamic_op"].b = trt_params.is_dynamic_op
+ custom_op.parameter_map["maximum_cached_engines"].i = (
+ trt_params.maximum_cached_engines)
+ if trt_params.cached_engine_batches:
+ custom_op.parameter_map["cached_engine_batches"].list.i.extend(
+ trt_params.cached_engine_batches)
+
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
else:
graph_options = config_pb2.GraphOptions()
@@ -190,53 +250,67 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
def _ExpectNativeSegment(self, engine_name, value):
self._ExpectTestValue(engine_name, "ExecuteNativeSegment", value)
- def _RunGraph(self, params, gdef, input_data, config, graph_state,
+ def _RunGraph(self,
+ run_params,
+ gdef,
+ input_data,
+ config,
+ graph_state,
num_runs=2):
"""Run given graphdef multiple times."""
+ params = self._GetParamsCached()
assert len(params.input_names) == len(input_data)
g = ops.Graph()
with g.as_default():
io_ops = importer.import_graph_def(
graph_def=gdef,
- return_elements=params.input_names + [self.output_name],
+ return_elements=params.input_names + params.output_names,
name="")
- inp = [i.outputs[0] for i in io_ops[:-1]]
- assert len(inp) == len(input_data)
- out = io_ops[-1].outputs[0]
+ inputs = [op.outputs[0] for op in io_ops[:len(params.input_names)]]
+ assert len(inputs) == len(input_data)
+ outputs = [op.outputs[0] for op in io_ops[len(params.input_names):]]
with self.test_session(
graph=g, config=config, use_gpu=True, force_gpu=True) as sess:
val = None
      # Defaults to 2 runs to verify the result is the same across runs.
for _ in range(num_runs):
- self._PrepareRun(params, graph_state)
- new_val = sess.run(out,
- {inp[i]: input_data[i] for i in range(len(inp))})
- self.assertEqual(params.expected_output_dims, new_val.shape)
+ self._PrepareRun(graph_state)
+ new_val = sess.run(
+ outputs, {inputs[i]: input_data[i] for i in range(len(inputs))})
+ output_len = len(params.expected_output_dims)
+ self.assertEqual(output_len, len(new_val))
+ for i in range(output_len):
+ self.assertEqual(params.expected_output_dims[i], new_val[i].shape)
if val is not None:
- self.assertAllEqual(val, new_val)
+ self.assertAllClose(val, new_val, atol=1.e-06, rtol=1.e-06)
val = new_val
- self._VerifyRun(params, graph_state)
+ self.VerifyRun(run_params, graph_state)
return val
# Use real data that is representative of the inference dataset
# for calibration. For this test script it is random data.
- def _RunCalibration(self, params, gdef, input_data, config):
+ def _RunCalibration(self, run_params, gdef, input_data, config):
"""Run calibration on given graph."""
return self._RunGraph(
- params, gdef, input_data, config, GraphState.CALIBRATE, num_runs=5)
+ run_params, gdef, input_data, config, GraphState.CALIBRATE, num_runs=5)
- def _GetTrtGraphDef(self, params, run_params, gdef):
+ def _GetTrtGraphDef(self, run_params, gdef):
"""Return trt converted graphdef."""
+ params = self._GetParamsCached()
+ trt_params = self.GetConversionParams(run_params)
+ logging.info(trt_params)
return trt_convert.create_inference_graph(
input_graph_def=gdef,
- outputs=[self.output_name],
- max_batch_size=max([dims[0] for dims in params.input_dims]),
- max_workspace_size_bytes=1 << 25,
- precision_mode=run_params.precision_mode,
- minimum_segment_size=2,
- is_dynamic_op=run_params.dynamic_engine)
-
- def _WriteGraph(self, params, run_params, gdef, graph_state):
+ outputs=params.input_names + params.output_names,
+ max_batch_size=trt_params.max_batch_size,
+ max_workspace_size_bytes=trt_params.max_workspace_size_bytes,
+ precision_mode=trt_params.precision_mode,
+ minimum_segment_size=trt_params.minimum_segment_size,
+ is_dynamic_op=trt_params.is_dynamic_op,
+ maximum_cached_engines=trt_params.maximum_cached_engines,
+ cached_engine_batches=trt_params.cached_engine_batches)
+
+ def _WriteGraph(self, run_params, gdef, graph_state):
if graph_state == GraphState.ORIGINAL:
label = "Original"
elif graph_state == GraphState.CALIBRATE:
@@ -247,15 +321,17 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
self.__class__.__name__ + "_" + run_params.test_name + "_" + label +
".pbtxt")
temp_dir = os.getenv("TRT_TEST_TMPDIR", self.get_temp_dir())
- logging.info("Writing graph to %s/%s", temp_dir, graph_name)
- graph_io.write_graph(gdef, temp_dir, graph_name)
+ if temp_dir:
+ logging.info("Writing graph to %s/%s", temp_dir, graph_name)
+ graph_io.write_graph(gdef, temp_dir, graph_name)
- def _VerifyConnections(self, params, converted_gdef):
+ def _VerifyConnections(self, expected_engines, converted_gdef):
+ params = self._GetParamsCached()
old_to_new_node_map = {
self._ToString(node.name): self._ToString(node.name)
for node in params.gdef.node
}
- for engine_name, node_names in params.expected_engines.items():
+ for engine_name, node_names in expected_engines.items():
for node_name in node_names:
old_to_new_node_map[node_name] = engine_name
name_to_node_map = {
@@ -310,97 +386,114 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
msg="expected:\n%s\nvs actual:\n%s" % (sorted(
expected_input_map.items()), sorted(actual_input_map.items())))
- def _VerifyGraphDef(self, params, run_params, gdef, graph_state):
- self._WriteGraph(params, run_params, gdef, graph_state)
+ def _VerifyGraphDef(self, run_params, gdef, graph_state):
+ self._WriteGraph(run_params, gdef, graph_state)
+ expected_engines = self.ExpectedEnginesToBuild(run_params)
num_engines = 0
for node in gdef.node:
if node.op == "TRTEngineOp":
+ logging.info("Found TRTEngineOp: " + node.name)
+ for node in gdef.node:
+ if node.op == "TRTEngineOp":
num_engines += 1
- self.assertTrue(node.name in params.expected_engines)
- self.assertTrue(len(node.attr["serialized_segment"].s))
- self.assertTrue(len(node.attr["segment_funcdef_name"].s))
+ self.assertTrue(node.name in expected_engines, node.name)
+ self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
+ self.assertTrue(len(node.attr["segment_funcdef_name"].s), node.name)
self.assertEqual(
self._ToBytes(run_params.precision_mode),
- node.attr["precision_mode"].s)
+ node.attr["precision_mode"].s, node.name)
is_dynamic_engine = not node.attr["static_engine"].b
- self.assertEqual(run_params.dynamic_engine, is_dynamic_engine)
+ self.assertEqual(run_params.dynamic_engine, is_dynamic_engine,
+ node.name)
has_calibration_data = len(node.attr["calibration_data"].s)
- if (_IsQuantizationMode(run_params.precision_mode) and
+ if (IsQuantizationMode(run_params.precision_mode) and
graph_state == GraphState.INFERENCE):
- self.assertTrue(has_calibration_data)
+ self.assertTrue(has_calibration_data, node.name)
else:
- self.assertFalse(has_calibration_data)
+ self.assertFalse(has_calibration_data, node.name)
if graph_state == GraphState.ORIGINAL:
self.assertEqual(0, num_engines)
else:
- self.assertEqual(num_engines, len(params.expected_engines))
- if isinstance(params.expected_engines, dict):
- self._VerifyConnections(params, gdef)
+ self.assertEqual(num_engines, len(expected_engines))
+ if isinstance(expected_engines, dict):
+ self._VerifyConnections(expected_engines, gdef)
# TODO(aaroey): consider verifying the corresponding TF function.
- def RunTest(self, params, run_params):
+ def RunTest(self, run_params):
+ if not self.ShouldRunTest(run_params):
+ return
assert run_params.precision_mode in PRECISION_MODES
- input_data = [np.random.random_sample(dims) for dims in params.input_dims]
+
+ params = self._GetParamsCached()
input_gdef = params.gdef
- self._VerifyGraphDef(params, run_params, input_gdef, GraphState.ORIGINAL)
+ input_dtypes = {}
+ for node in input_gdef.node:
+ if self._ToString(node.name) in params.input_names:
+ assert self._ToString(node.op) == "Placeholder"
+ input_dtypes[self._ToString(node.name)] = (
+ dtypes.as_dtype(node.attr["dtype"].type).as_numpy_dtype())
+ assert len(params.input_names) == len(input_dtypes)
+
+ input_data = []
+ for i in range(len(params.input_names)):
+ dtype = input_dtypes[params.input_names[i]]
+      # Multiply the input by some constant to avoid an all-zeros input for
+      # integer types.
+ scale = 10.0 if np.issubdtype(dtype, np.integer) else 1.0
+ dims = params.input_dims[i]
+ input_data.append((scale * np.random.random_sample(dims)).astype(dtype))
+ self._VerifyGraphDef(run_params, input_gdef, GraphState.ORIGINAL)
# Get reference result without running trt.
- config_no_trt = self._GetConfigProto(params, run_params,
- GraphState.ORIGINAL)
+ config_no_trt = self._GetConfigProto(run_params, GraphState.ORIGINAL)
logging.info("Running original graph w/o trt, config:\n%s",
str(config_no_trt))
- ref_result = self._RunGraph(params, input_gdef, input_data, config_no_trt,
- GraphState.ORIGINAL)
+ ref_result = self._RunGraph(run_params, input_gdef, input_data,
+ config_no_trt, GraphState.ORIGINAL)
# Run calibration if necessary.
- if _IsQuantizationMode(run_params.precision_mode):
+ if IsQuantizationMode(run_params.precision_mode):
- calib_config = self._GetConfigProto(params, run_params,
- GraphState.CALIBRATE)
+ calib_config = self._GetConfigProto(run_params, GraphState.CALIBRATE)
logging.info("Running calibration graph, config:\n%s", str(calib_config))
if run_params.use_optimizer:
- result = self._RunCalibration(params, input_gdef, input_data,
+ result = self._RunCalibration(run_params, input_gdef, input_data,
calib_config)
else:
- calib_gdef = self._GetTrtGraphDef(params, run_params, input_gdef)
- self._VerifyGraphDef(params, run_params, calib_gdef,
- GraphState.CALIBRATE)
- result = self._RunCalibration(params, calib_gdef, input_data,
+ calib_gdef = self._GetTrtGraphDef(run_params, input_gdef)
+ self._VerifyGraphDef(run_params, calib_gdef, GraphState.CALIBRATE)
+ result = self._RunCalibration(run_params, calib_gdef, input_data,
calib_config)
- infer_gdef = trt_convert.calib_graph_to_infer_graph(calib_gdef)
- self._VerifyGraphDef(params, run_params, infer_gdef, GraphState.INFERENCE)
+ infer_gdef = trt_convert.calib_graph_to_infer_graph(
+ calib_gdef, run_params.dynamic_engine)
+ self._VerifyGraphDef(run_params, infer_gdef, GraphState.INFERENCE)
self.assertAllClose(
ref_result,
result,
- atol=params.allclose_atol,
- rtol=params.allclose_rtol)
+ atol=self.ExpectedAbsoluteTolerance(run_params),
+ rtol=self.ExpectedRelativeTolerance(run_params))
else:
infer_gdef = input_gdef
# Run inference.
- infer_config = self._GetConfigProto(params, run_params,
- GraphState.INFERENCE)
+ infer_config = self._GetConfigProto(run_params, GraphState.INFERENCE)
logging.info("Running final inference graph, config:\n%s",
str(infer_config))
- if run_params.use_optimizer:
- result = self._RunGraph(params, infer_gdef, input_data, infer_config,
- GraphState.INFERENCE)
- else:
- trt_infer_gdef = self._GetTrtGraphDef(params, run_params, infer_gdef)
- self._VerifyGraphDef(params, run_params, trt_infer_gdef,
- GraphState.INFERENCE)
- result = self._RunGraph(params, trt_infer_gdef, input_data, infer_config,
- GraphState.INFERENCE)
+ if not run_params.use_optimizer:
+ infer_gdef = self._GetTrtGraphDef(run_params, infer_gdef)
+ self._VerifyGraphDef(run_params, infer_gdef, GraphState.INFERENCE)
+ result = self._RunGraph(run_params, infer_gdef, input_data, infer_config,
+ GraphState.INFERENCE)
self.assertAllClose(
ref_result,
result,
- atol=params.allclose_atol,
- rtol=params.allclose_rtol)
+ atol=self.ExpectedAbsoluteTolerance(run_params),
+ rtol=self.ExpectedRelativeTolerance(run_params))
def testIdempotence(self):
# Test that applying tensorrt optimizer or offline conversion tools multiple
@@ -421,13 +514,12 @@ def _AddTests(test_class):
"""Gets a single test method based on the parameters."""
def _Test(self):
- params = self.GetParams()
logging.info(
"Running test %s with parameters: use_optimizer=%s, "
"precision_mode=%s, dynamic_engine=%s",
"testTfTrt_" + run_params.test_name, run_params.use_optimizer,
run_params.precision_mode, run_params.dynamic_engine)
- self.RunTest(params, run_params)
+ self.RunTest(run_params)
return _Test
@@ -435,7 +527,7 @@ def _AddTests(test_class):
dynamic_engine_options = [False, True]
for (use_optimizer, precision_mode, dynamic_engine) in itertools.product(
use_optimizer_options, PRECISION_MODES, dynamic_engine_options):
- if _IsQuantizationMode(precision_mode):
+ if IsQuantizationMode(precision_mode):
if use_optimizer:
# TODO(aaroey): if use_optimizer is True we need to get the inference
# graphdef using custom python wrapper class, which is not currently
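Note: the tests above are generated dynamically, one method per parameter
combination. A minimal, self-contained sketch of that pattern (illustrative
names only, not part of this patch):

    # Generate one test method per parameter combination and attach it to the
    # test class via setattr, mirroring the _AddTests pattern above.
    import itertools
    import unittest

    def _add_tests(test_class):
      def _get_test(use_optimizer, dynamic_engine):
        def _test(self):
          # Each generated method captures its own parameter values.
          self.assertIn(use_optimizer, (True, False))
          self.assertIn(dynamic_engine, (True, False))
        return _test
      for use_optimizer, dynamic_engine in itertools.product([False, True],
                                                             [False, True]):
        name = 'testCombo_opt%s_dyn%s' % (use_optimizer, dynamic_engine)
        setattr(test_class, name, _get_test(use_optimizer, dynamic_engine))

    class _DemoTest(unittest.TestCase):
      pass

    _add_tests(_DemoTest)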
diff --git a/tensorflow/contrib/tensorrt/test/unary_test.py b/tensorflow/contrib/tensorrt/test/unary_test.py
index 500057a36d..8736bfb644 100644
--- a/tensorflow/contrib/tensorrt/test/unary_test.py
+++ b/tensorflow/contrib/tensorrt/test/unary_test.py
@@ -38,6 +38,7 @@ class UnaryTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [12, 5, 8, 1, 1, 12]
+ output_name = "output"
input2_name = "input_2"
input2_dims = [12, 5, 8, 1, 12, 1, 1]
g = ops.Graph()
@@ -95,18 +96,20 @@ class UnaryTest(trt_test.TfTrtIntegrationTestBase):
q = a * b
q = q / c
- array_ops.squeeze(q, name=self.output_name)
+ array_ops.squeeze(q, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name, input2_name],
input_dims=[input_dims, input2_dims],
- expected_engines=[
- "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3",
- "my_trt_op_4"
- ],
- expected_output_dims=(12, 5, 8, 12),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(12, 5, 8, 12)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return [
+ "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3",
+ "my_trt_op_4"
+ ]
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
index ab4d224db4..b0271a04b3 100644
--- a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
+++ b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
@@ -38,15 +38,14 @@ class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [5, 2, 8, 8]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
x, _, _ = nn_impl.fused_batch_norm(
- x,
- np.random.randn(2).astype(np.float32),
- np.random.randn(2).astype(np.float32),
- mean=np.random.randn(2).astype(np.float32),
- variance=np.random.randn(2).astype(np.float32),
+ x, [1.0, 1.0], [0.0, 0.0],
+ mean=[0.5, 0.5],
+ variance=[1.0, 1.0],
data_format="NCHW",
is_training=False)
e = constant_op.constant(
@@ -67,15 +66,17 @@ class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase):
"VALID",
data_format="NCHW",
name="max_pool")
- array_ops.squeeze(v, name="output")
+ array_ops.squeeze(v, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=["my_trt_op_0"],
- expected_output_dims=(5, 6, 2, 2),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(5, 6, 2, 2)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return ["my_trt_op_0"]
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_test.py
index 56bdf848ea..d7c165784b 100644
--- a/tensorflow/contrib/tensorrt/test/vgg_block_test.py
+++ b/tensorflow/contrib/tensorrt/test/vgg_block_test.py
@@ -38,15 +38,14 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [5, 8, 8, 2]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
x, _, _ = nn_impl.fused_batch_norm(
- x,
- np.random.randn(2).astype(np.float32),
- np.random.randn(2).astype(np.float32),
- mean=np.random.randn(2).astype(np.float32),
- variance=np.random.randn(2).astype(np.float32),
+ x, [1.0, 1.0], [0.0, 0.0],
+ mean=[0.5, 0.5],
+ variance=[1.0, 1.0],
is_training=False)
e = constant_op.constant(
np.random.randn(1, 1, 2, 6), name="weights", dtype=dtype)
@@ -58,15 +57,17 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase):
idty = array_ops.identity(relu, "ID")
v = nn_ops.max_pool(
idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
- array_ops.squeeze(v, name="output")
+ array_ops.squeeze(v, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=["my_trt_op_0"],
- expected_output_dims=(5, 2, 2, 6),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(5, 2, 2, 6)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return ["my_trt_op_0"]
if __name__ == "__main__":
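Note: the VGG-block tests above switch fused_batch_norm to fixed statistics so
the TF-vs-TRT comparison is deterministic. A minimal inference-mode sketch
(TF 1.x API, illustrative shapes):

    import tensorflow as tf

    # Inference-mode fused batch norm with constant scale/offset/statistics;
    # with is_training=False, mean and variance must be supplied.
    x = tf.placeholder(tf.float32, shape=[5, 8, 8, 2])
    y, _, _ = tf.nn.fused_batch_norm(
        x, scale=[1.0, 1.0], offset=[0.0, 0.0],
        mean=[0.5, 0.5], variance=[1.0, 1.0], is_training=False)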
diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD
index 355303acf6..21c0c30c19 100644
--- a/tensorflow/contrib/timeseries/examples/BUILD
+++ b/tensorflow/contrib/timeseries/examples/BUILD
@@ -16,6 +16,7 @@ config_setting(
py_binary(
name = "predict",
srcs = ["predict.py"],
+ data = ["data/period_trend.csv"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = select({
diff --git a/tensorflow/contrib/timeseries/examples/known_anomaly.py b/tensorflow/contrib/timeseries/examples/known_anomaly.py
index 71621abc71..1226433625 100644
--- a/tensorflow/contrib/timeseries/examples/known_anomaly.py
+++ b/tensorflow/contrib/timeseries/examples/known_anomaly.py
@@ -41,7 +41,7 @@ _MODULE_PATH = path.dirname(__file__)
_DATA_FILE = path.join(_MODULE_PATH, "data/changepoints.csv")
-def state_space_esitmator(exogenous_feature_columns):
+def state_space_estimator(exogenous_feature_columns):
"""Constructs a StructuralEnsembleRegressor."""
def _exogenous_update_condition(times, features):
@@ -68,7 +68,7 @@ def state_space_esitmator(exogenous_feature_columns):
4, 64)
-def autoregressive_esitmator(exogenous_feature_columns):
+def autoregressive_estimator(exogenous_feature_columns):
input_window_size = 8
output_window_size = 2
return (
@@ -169,10 +169,10 @@ def main(unused_argv):
"Please install matplotlib to generate a plot from this example.")
make_plot("Ignoring a known anomaly (state space)",
*train_and_evaluate_exogenous(
- estimator_fn=state_space_esitmator))
+ estimator_fn=state_space_estimator))
make_plot("Ignoring a known anomaly (autoregressive)",
*train_and_evaluate_exogenous(
- estimator_fn=autoregressive_esitmator, train_steps=3000))
+ estimator_fn=autoregressive_estimator, train_steps=3000))
pyplot.show()
diff --git a/tensorflow/contrib/timeseries/examples/known_anomaly_test.py b/tensorflow/contrib/timeseries/examples/known_anomaly_test.py
index 8c64f2e186..57ccf8f260 100644
--- a/tensorflow/contrib/timeseries/examples/known_anomaly_test.py
+++ b/tensorflow/contrib/timeseries/examples/known_anomaly_test.py
@@ -28,7 +28,7 @@ class KnownAnomalyExampleTest(test.TestCase):
def test_shapes_and_variance_structural_ar(self):
(times, observed, all_times, mean, upper_limit, lower_limit,
anomaly_locations) = known_anomaly.train_and_evaluate_exogenous(
- train_steps=1, estimator_fn=known_anomaly.autoregressive_esitmator)
+ train_steps=1, estimator_fn=known_anomaly.autoregressive_estimator)
self.assertAllEqual(
anomaly_locations,
[25, 50, 75, 100, 125, 150, 175, 249])
@@ -40,7 +40,7 @@ class KnownAnomalyExampleTest(test.TestCase):
def test_shapes_and_variance_structural_ssm(self):
(times, observed, all_times, mean, upper_limit, lower_limit,
anomaly_locations) = known_anomaly.train_and_evaluate_exogenous(
- train_steps=50, estimator_fn=known_anomaly.state_space_esitmator)
+ train_steps=50, estimator_fn=known_anomaly.state_space_estimator)
self.assertAllEqual(
anomaly_locations,
[25, 50, 75, 100, 125, 150, 175, 249])
diff --git a/tensorflow/contrib/timeseries/examples/predict.py b/tensorflow/contrib/timeseries/examples/predict.py
index 8147d40caa..b036911314 100644
--- a/tensorflow/contrib/timeseries/examples/predict.py
+++ b/tensorflow/contrib/timeseries/examples/predict.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import argparse
+import os
import sys
import numpy as np
@@ -40,6 +41,10 @@ except ImportError:
FLAGS = None
+_MODULE_PATH = os.path.dirname(__file__)
+_DEFAULT_DATA_FILE = os.path.join(_MODULE_PATH, "data/period_trend.csv")
+
+
def structural_ensemble_train_and_predict(csv_file_name):
# Cycle between 5 latent values over a period of 100. This leads to a very
# smooth periodic component (and a small model), which is a good fit for our
@@ -115,9 +120,12 @@ def main(unused_argv):
if not HAS_MATPLOTLIB:
raise ImportError(
"Please install matplotlib to generate a plot from this example.")
+ input_filename = FLAGS.input_filename
+ if input_filename is None:
+ input_filename = _DEFAULT_DATA_FILE
make_plot("Structural ensemble",
- *structural_ensemble_train_and_predict(FLAGS.input_filename))
- make_plot("AR", *ar_train_and_predict(FLAGS.input_filename))
+ *structural_ensemble_train_and_predict(input_filename))
+ make_plot("AR", *ar_train_and_predict(input_filename))
pyplot.show()
@@ -126,7 +134,7 @@ if __name__ == "__main__":
parser.add_argument(
"--input_filename",
type=str,
- required=True,
- help="Input csv file.")
+ required=False,
+      help="Input csv file (omit to use the default data/period_trend.csv).")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
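Note: the change above makes --input_filename optional, falling back to a
module-relative default. A minimal sketch of the pattern (paths illustrative):

    import argparse
    import os

    # Fall back to a data file shipped next to this module when the flag is
    # omitted.
    _DEFAULT = os.path.join(os.path.dirname(__file__), 'data/period_trend.csv')

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_filename', type=str, required=False,
                        help='Input csv file (omit to use the bundled default).')
    args, _ = parser.parse_known_args()
    input_filename = args.input_filename or _DEFAULT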
diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py
index 5eb4deefb9..de547f835d 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py
@@ -195,7 +195,7 @@ class ARModelTest(test.TestCase):
self.train_helper(input_window_size=10,
loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS,
train_steps=300,
- max_loss=2.5,
+ max_loss=50., # Just make sure there are no exceptions.
anomaly_distribution=None)
def test_autoregression_normal_multiple_periods(self):
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
index 983455f63d..461fe22210 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
@@ -69,8 +69,10 @@ class TimeSeriesRegressorTest(test.TestCase):
input_pipeline.NumpyReader(features), shuffle_seed=3, num_threads=1,
batch_size=16, window_size=16)
first_estimator.train(input_fn=train_input_fn, steps=1)
- first_loss_before_fit = first_estimator.evaluate(
- input_fn=eval_input_fn, steps=1)["loss"]
+ first_evaluation = first_estimator.evaluate(
+ input_fn=eval_input_fn, steps=1)
+ first_loss_before_fit = first_evaluation["loss"]
+ self.assertAllEqual(first_loss_before_fit, first_evaluation["average_loss"])
self.assertAllEqual([], first_loss_before_fit.shape)
first_estimator.train(input_fn=train_input_fn, steps=1)
first_loss_after_fit = first_estimator.evaluate(
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py
index 32194e400e..1f9f9b7aa6 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head.py
@@ -30,6 +30,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.summary import summary
@@ -123,6 +124,8 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce
metrics[feature_keys.FilteringResults.STATE_TUPLE] = (
_identity_metric_nested(feature_keys.FilteringResults.STATE_TUPLE,
model_outputs.end_state))
+ metrics[metric_keys.MetricKeys.LOSS_MEAN] = metrics_impl.mean(
+ model_outputs.loss, name="average_loss")
return estimator_lib.EstimatorSpec(
loss=model_outputs.loss,
mode=mode,
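Note: the hunk above exposes the mean loss as an extra eval metric. A minimal
sketch of the (value, update_op) contract that eval_metric_ops expects
(TF 1.x API; names illustrative):

    import tensorflow as tf

    def eval_metrics(loss):
      # tf.metrics.mean returns a (value_tensor, update_op) pair, which is the
      # form EstimatorSpec.eval_metric_ops requires.
      return {'average_loss': tf.metrics.mean(loss, name='average_loss')}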
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
index bda3b53aca..647455ae42 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
@@ -122,7 +122,7 @@ class EvaluationMetricsTests(test.TestCase):
metric[1] for metric in outputs.eval_metric_ops.values()]
loss_mean, loss_update = metrics.mean(outputs.loss)
metric_update_ops.append(loss_update)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(sess, coord=coordinator)
variables.local_variables_initializer().run()
@@ -172,6 +172,7 @@ class EvaluationMetricsTests(test.TestCase):
evaluation = estimator.evaluate(input_fn, steps=1)
self.assertIn("plain_boring_metric386", evaluation)
self.assertIn("fun_metric101", evaluation)
+ self.assertIn("average_loss", evaluation)
# The values are deterministic because of fixed tf_random_seed.
# However if they become flaky, remove such exacts comparisons.
self.assertAllClose(evaluation["plain_boring_metric386"], 1.130380)
@@ -398,6 +399,7 @@ class OneShotTests(parameterized.TestCase):
num_threads=1, batch_size=16, window_size=16)
estimator.train(input_fn=train_input_fn, steps=5)
result = estimator.evaluate(input_fn=train_input_fn, steps=1)
+ self.assertIn("average_loss", result)
self.assertNotIn(feature_keys.State.STATE_TUPLE, result)
input_receiver_fn = estimator.build_raw_serving_input_receiver_fn()
export_location = estimator.export_savedmodel(_new_temp_dir(),
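Note: this patch migrates many tests from self.test_session() to
self.cached_session(), which reuses a single session per test. A minimal
sketch (TF 1.x test API):

    import tensorflow as tf

    class ExampleTest(tf.test.TestCase):

      def testAdd(self):
        # cached_session() returns the same underlying session on repeated
        # calls within one test, avoiding repeated session construction.
        with self.cached_session() as sess:
          self.assertEqual(3, sess.run(tf.constant(1) + tf.constant(2)))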
diff --git a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py
index 703537abf0..f92148b788 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py
@@ -88,7 +88,7 @@ class RandomWindowInputFnTests(test.TestCase):
window_size=window_size, batch_size=batch_size)
result, _ = input_fn()
init_op = variables.local_variables_initializer()
- with self.test_session() as session:
+ with self.cached_session() as session:
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
session.run(init_op)
@@ -261,7 +261,7 @@ class WholeDatasetInputFnTests(test.TestCase):
def _whole_dataset_input_fn_test_template(
self, time_series_reader, num_features, num_samples):
result, _ = input_pipeline.WholeDatasetInputFn(time_series_reader)()
- with self.test_session() as session:
+ with self.cached_session() as session:
session.run(variables.local_variables_initializer())
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
@@ -340,7 +340,7 @@ class AllWindowInputFnTests(test.TestCase):
window_size=window_size)
features, _ = input_fn()
init_op = variables.local_variables_initializer()
- with self.test_session() as session:
+ with self.cached_session() as session:
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
session.run(init_op)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py
index b9f8620fd8..c0de42b15b 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py
@@ -55,7 +55,7 @@ class MathUtilsTest(test.TestCase):
running_sum = running_sum + current_contribution
# pylint: enable=g-no-augmented-assignment
transition_power = numpy.dot(transition, transition_power)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(result,
math_utils.power_sums_tensor(
array_size, transition, addition).eval())
@@ -66,7 +66,7 @@ class MathUtilsTest(test.TestCase):
result = []
for i in range(powers.shape[0]):
result.append(numpy.linalg.matrix_power(matrix, powers[i]))
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(result,
math_utils.matrix_to_powers(matrix, powers).eval(),
rtol=1e-5,
@@ -78,7 +78,7 @@ class MathUtilsTest(test.TestCase):
result = []
for i in range(batch.shape[0]):
result.append(numpy.linalg.matrix_power(batch[i], powers[i]))
- with self.test_session():
+ with self.cached_session():
# TODO(allenl): Numerical errors seem to be creeping in. Maybe it can be
# made slightly more stable?
self.assertAllClose(result,
@@ -91,7 +91,7 @@ class MathUtilsTest(test.TestCase):
left_transpose = numpy.transpose(left, [0, 2, 1])
right = numpy.random.normal(size=[2, 3]).astype(numpy.float32)
expected_result = numpy.dot(left, right)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(expected_result,
math_utils.batch_times_matrix(
left, right).eval())
@@ -114,7 +114,7 @@ class MathUtilsTest(test.TestCase):
right_transpose = numpy.transpose(right, [0, 2, 1])
expected_result = numpy.transpose(numpy.dot(right_transpose, left.T),
[0, 2, 1])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(expected_result,
math_utils.matrix_times_batch(
left, right).eval())
@@ -132,7 +132,7 @@ class MathUtilsTest(test.TestCase):
adj_x=True, adj_y=True).eval())
def test_make_diagonal_undefined_shapes(self):
- with self.test_session():
+ with self.cached_session():
completely_undefined = array_ops.placeholder(dtype=dtypes.float32)
partly_undefined = array_ops.placeholder(
shape=[None, None], dtype=dtypes.float32)
@@ -152,7 +152,7 @@ class MathUtilsTest(test.TestCase):
[5., 6.]]}))
def test_make_diagonal_mostly_defined_shapes(self):
- with self.test_session():
+ with self.cached_session():
mostly_defined = array_ops.placeholder(
shape=[None, 2], dtype=dtypes.float32)
blocked = math_utils.block_diagonal([[[2.]],
@@ -192,7 +192,7 @@ class TestMakeToeplitzMatrix(test.TestCase):
def _test_make_toeplitz_matrix(self, inputs, output_expected):
output_tf = math_utils.make_toeplitz_matrix(inputs)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output_tf_np = sess.run(output_tf)
self.assertAllClose(output_tf_np, output_expected)
@@ -201,13 +201,13 @@ class TestMakeCovarianceMatrix(test.TestCase):
def test_zero_size_matrix(self):
raw = numpy.zeros([0, 0])
- with self.test_session():
+ with self.cached_session():
constructed = math_utils.sign_magnitude_positive_definite(raw=raw).eval()
self.assertEqual((0, 0), constructed.shape)
def test_sign_magnitude_positive_definite(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
matrix_tensor = math_utils.sign_magnitude_positive_definite(
raw=constant_op.constant([[-1., -2.], [3., 4.]], dtype=dtype),
off_diagonal_scale=constant_op.constant(-1., dtype=dtype),
@@ -230,7 +230,8 @@ class TestLookupTable(test.TestCase):
name="test_lookup")
def stack_tensor(base_tensor):
return array_ops.stack([base_tensor + 1, base_tensor + 2])
- with self.test_session() as session:
+
+ with self.cached_session() as session:
((float_output, double_output), int_output) = session.run(
hash_table.lookup([2, 1, 0]))
def expected_output_before_insert(base_tensor):
@@ -290,7 +291,7 @@ class InputStatisticsTests(test.TestCase):
time_series_reader=input_pipeline.NumpyReader(features))
statistics = stat_object.initialize_graph(
features=input_fn()[0])
- with self.test_session(graph=graph) as session:
+ with self.session(graph=graph) as session:
variables.global_variables_initializer().run()
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py b/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py
index cfd31cc70d..a049dbe773 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py
@@ -29,7 +29,7 @@ class ModelUtilsTest(test.TestCase):
def test_parameter_switching(self):
parameter = array_ops.constant(5)
overridden_parameter = array_ops.constant(3)
- with self.test_session():
+ with self.cached_session():
getter = model_utils.parameter_switch({overridden_parameter: 4})
self.assertEqual(5, getter(parameter))
self.assertEqual(4, getter(overridden_parameter))
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py
index 5f7e3da2db..42ba6e1c25 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py
@@ -127,7 +127,7 @@ class ChainingStateManagerTest(test.TestCase):
chainer.initialize_graph(model=stub_model)
model_outputs = chainer.define_loss(
model=stub_model, features=features, mode=estimator_lib.ModeKeys.TRAIN)
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
@@ -178,7 +178,7 @@ class ChainingStateManagerTest(test.TestCase):
result_model_outputs = chainer.define_loss(
model=stub_model, features=result_input_fn()[0],
mode=estimator_lib.ModeKeys.TRAIN)
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
@@ -221,7 +221,7 @@ class ChainingStateManagerTest(test.TestCase):
chainer.initialize_graph(model=stub_model)
model_outputs = chainer.define_loss(
model=stub_model, features=features, mode=estimator_lib.ModeKeys.TRAIN)
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py
index 1fb4a3c121..c2eaa78493 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py
@@ -190,13 +190,13 @@ class StateSpaceEquivalenceTests(test.TestCase):
estimator.build_raw_serving_input_receiver_fn())
with ops.Graph().as_default() as graph:
random_model.initialize_graph()
- with self.test_session(graph=graph) as session:
+ with self.session(graph=graph) as session:
variables.global_variables_initializer().run()
evaled_start_state = session.run(random_model.get_start_state())
evaled_start_state = [
state_element[None, ...] for state_element in evaled_start_state]
with ops.Graph().as_default() as graph:
- with self.test_session(graph=graph) as session:
+ with self.session(graph=graph) as session:
signatures = loader.load(
session, [tag_constants.SERVING], export_location)
first_split_filtering = saved_model_utils.filter_continuation(
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 56e451e2e3..298ffc1ded 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -16,6 +16,7 @@ package(
"//cloud/vmm/testing/tests/tpu:__subpackages__",
"//learning/brain:__subpackages__",
"//learning/deepmind:__subpackages__",
+ "//medical/pathology:__subpackages__",
"//tensorflow:__subpackages__",
],
)
@@ -166,6 +167,7 @@ py_library(
name = "keras_support",
srcs = [
"python/tpu/keras_support.py",
+ "python/tpu/keras_tpu_variables.py",
],
srcs_version = "PY2AND3",
visibility = [
diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py
index 537d94b797..3c0456dc2f 100644
--- a/tensorflow/contrib/tpu/__init__.py
+++ b/tensorflow/contrib/tpu/__init__.py
@@ -33,6 +33,7 @@
@@shard
@@batch_parallel
@@rewrite
+@@outside_compilation
@@CrossShardOptimizer
diff --git a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
index 06553929dc..ea8e0e00ed 100644
--- a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
+++ b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
@@ -18,28 +18,111 @@ limitations under the License.
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
+REGISTER_OP("AllToAll")
+ .Input("input: T")
+ .Input("group_assignment: int32")
+ .Output("output: T")
+ .Attr("T: {bfloat16, float}")
+ .Attr("concat_dimension: int")
+ .Attr("split_dimension: int")
+ .Attr("split_count: int")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle input = c->input(0);
+ int64 rank;
+ if (c->RankKnown(input)) {
+ rank = c->Rank(input);
+ } else {
+ return errors::InvalidArgument("input's rank is unknown.");
+ }
+ int concat_dimension;
+ int split_dimension;
+
+ TF_RETURN_IF_ERROR(c->GetAttr("concat_dimension", &concat_dimension));
+
+ if (concat_dimension < 0 || concat_dimension >= rank) {
+ return errors::InvalidArgument("concat_dimension ", concat_dimension,
+ " is out of range of input rank ", rank);
+ }
+
+ TF_RETURN_IF_ERROR(c->GetAttr("split_dimension", &split_dimension));
+ if (split_dimension < 0 || split_dimension >= rank) {
+ return errors::InvalidArgument("split_dimension ", split_dimension,
+ " is out of range of input rank ", rank);
+ }
+
+ std::vector<DimensionHandle> dims;
+ dims.resize(rank);
+
+ for (int32 i = 0; i < rank; ++i) {
+ int64 in_idx = i;
+ if (i == concat_dimension) {
+ in_idx = split_dimension;
+ } else if (i == split_dimension) {
+ in_idx = concat_dimension;
+ }
+
+ dims[i] = c->Dim(input, in_idx);
+ }
+
+ c->set_output(0, c->MakeShape(dims));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+An Op to exchange data across TPU replicas. On each replica, the input is
+split into `split_count` blocks along `split_dimension` and sent to the other
+replicas given group_assignment. After receiving `split_count` - 1 blocks from
+other replicas, we concatenate the blocks along `concat_dimension` as the
+output.
+
+For example, suppose there are 2 TPU replicas:
+replica 0 receives input: `[[A, B]]`
+replica 1 receives input: `[[C, D]]`
+
+group_assignment=`[[0, 1]]`
+concat_dimension=0
+split_dimension=1
+split_count=2
+
+replica 0's output: `[[A], [C]]`
+replica 1's output: `[[B], [D]]`
+
+input: The local input to the exchange.
+group_assignment: An int32 tensor with shape
+ [num_groups, num_replicas_per_group]. `group_assignment[i]` represents the
+ replica ids in the ith subgroup.
+concat_dimension: The dimension number to concatenate.
+split_dimension: The dimension number to split.
+split_count: The number of splits; this number must equal the sub-group
+  size (group_assignment.get_shape()[1]).
+output: The exchanged result.
+T: The type of elements to be exchanged.
+)doc");
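Note: the shape function above swaps the sizes at concat_dimension and
split_dimension. A pure-Python sketch of that rule:

    def all_to_all_shape(shape, concat_dimension, split_dimension):
      # Output shape equals the input shape with the concat and split
      # dimensions exchanged, mirroring the C++ shape function above.
      dims = list(shape)
      dims[concat_dimension], dims[split_dimension] = (
          dims[split_dimension], dims[concat_dimension])
      return dims

    assert all_to_all_shape([2, 6], concat_dimension=0,
                            split_dimension=1) == [6, 2]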
REGISTER_OP("CrossReplicaSum")
.Input("input: T")
+ .Input("group_assignment: int32")
.Output("output: T")
.Attr("T: {bfloat16, float}")
- .Attr("group_assignment: list(int) = []")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
-An Op to sum inputs across replicated TPU instances. Each
-instance supplies its own input. If group_assignment is empty, the output of
-each is the sum of all the inputs, otherwise the output of each is the sum of
-the inputs belonging to the same group.
+An Op to sum inputs across replicated TPU instances. Each instance supplies its
+own input.
-For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
-group_assignment=`[0,1,0,1]` sets `A, C` as group 0, and `B, D` as group 1.
-Thus we get the outputs: `[A+C, B+D, A+C, B+D]`.
+For example, suppose there are 8 TPU instances: `[A, B, C, D, E, F, G, H]`.
+Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as group 0,
+and `B, D, F, H` as group 1. Thus we get the outputs:
+`[A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H]`.
input: The local input to the sum.
+group_assignment: An int32 tensor with shape
+ [num_groups, num_replicas_per_group]. `group_assignment[i]` represents the
+ replica ids in the ith subgroup.
output: The sum of all the distributed inputs.
T: The type of elements to be summed.
-group_assignment: The list of group ids. `group_assignment[i]` represents the
- group id of replica i.
)doc");
} // namespace tensorflow
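Note: a pure-Python illustration of how the new 2D group_assignment partitions
replicas for CrossReplicaSum, following the docstring above (values
illustrative):

    group_assignment = [[0, 2, 4, 6], [1, 3, 5, 7]]
    inputs = {r: float(r) for r in range(8)}  # replica id -> local value

    outputs = {}
    for group in group_assignment:
      group_sum = sum(inputs[r] for r in group)
      for r in group:
        # Every replica in a group receives the same group-wide sum.
        outputs[r] = group_sum

    assert outputs[0] == 0.0 + 2.0 + 4.0 + 6.0
    assert outputs[1] == 1.0 + 3.0 + 5.0 + 7.0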
diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
index 8e6e9aa0cd..b498599962 100644
--- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
@@ -156,7 +156,8 @@ bool NewSession(const string& service_addr,
channel_args));
NewProfileSessionResponse new_session_response;
TF_QCHECK_OK(FromGrpcStatus(
- stub->NewSession(&context, new_session_request, &new_session_response)));
+ stub->NewSession(&context, new_session_request, &new_session_response)))
+ << new_session_response.error_message();
  std::cout << "Profile session succeeded for host(s):"
<< str_util::Join(hostnames, ",") << std::endl;
diff --git a/tensorflow/contrib/tpu/profiler/op_profile.proto b/tensorflow/contrib/tpu/profiler/op_profile.proto
index 1f249de314..feb177a7da 100644
--- a/tensorflow/contrib/tpu/profiler/op_profile.proto
+++ b/tensorflow/contrib/tpu/profiler/op_profile.proto
@@ -8,6 +8,8 @@ message Profile {
Node by_category = 1;
// Root of a profile broken down by program structure.
Node by_program_structure = 2;
+ // Per program profile, indexed by hlo module name of the program.
+ map<string, Node> per_program = 3;
}
// An entry in the profile tree. (An instruction, or set of instructions).
diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
index 2b13343efa..f88dc51636 100644
--- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
+++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
@@ -79,12 +79,15 @@ message StepInfoResult {
// The step duration in picoseconds.
optional uint64 duration_ps = 2;
// The infeed duration in picoseconds.
- // Can turn into a map if we want a variable number of ops.
optional uint64 infeed_duration_ps = 3;
+ // The outfeed duration in picoseconds.
+ optional uint64 host_outfeed_ps = 8;
// The start time of this step in picoseconds.
optional uint64 begin_ps = 4;
// The waiting time within this step in picoseconds.
optional uint64 wait_duration_ps = 5;
+ // The unit b outfeed duration in picoseconds.
+ optional uint64 unit_b_outfeed_ps = 9;
// The time spent on cross-replica-sum in picoseconds.
optional uint64 crs_duration_ps = 6;
// Percentage of unit b time spent on infeed.
diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
index 2cc17d6d92..fc1320501b 100644
--- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto
+++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
@@ -9,8 +9,8 @@ message ClippingLimits {
google.protobuf.FloatValue upper = 2; // +inf if not set
}
-// Get the learning rate from a <yet to be determined> source that can change
-// dynamically.
+// Get the learning rate from the parameters of the SendTPUEmbeddingGradients
+// op.
message DynamicLearningRate {
}
@@ -119,7 +119,9 @@ message OptimizationParameters {
// Whether to use gradient accumulation (do two passes over the input
// gradients: one to accumulate them into a temporary array and another to
- // apply them using the actual optimization algorithm).
+ // apply them using the actual optimization algorithm). This feature is
+ // experimental -- it has not been fully verified and may cause training
+ // crashes and/or failures.
bool use_gradient_accumulation = 15;
// Optimization algorithm parameters; which field is selected determines which
diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
index bf442d9116..d92a0652bb 100644
--- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py
+++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
@@ -21,8 +21,10 @@ from __future__ import print_function
import platform
+from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.platform import tf_logging as logging
if platform.system() != "Windows":
# pylint: disable=wildcard-import,unused-import,g-import-not-at-top
@@ -36,10 +38,85 @@ if platform.system() != "Windows":
_tpu_ops = loader.load_op_library(
resource_loader.get_path_to_datafile("_tpu_ops.so"))
+ def _create_default_group_assignment():
+ num_shards = tpu_function.get_tpu_context().number_of_shards
+ if num_shards is None:
+ logging.warning(
+ "cross_replica_sum should be used within a tpu_shard_context, but "
+ "got unset number_of_shards. Assuming 1.")
+ num_shards = 1
+ group_assignment = [list(range(num_shards))]
+ return group_assignment
+
+ def all_to_all(x,
+ concat_dimension,
+ split_dimension,
+ split_count,
+ group_assignment=None,
+ name=None):
+ """Exchange data across TPU replicas.
+
+ Args:
+ x: The local tensor.
+ concat_dimension: The dimension number to concatenate.
+ split_dimension: The dimension number to split.
+      split_count: The number of splits; this number must equal the sub-group
+        size (group_assignment.get_shape()[1]).
+      group_assignment: Optional 2D int32 list with shape [num_groups,
+ num_replicas_per_group]. `group_assignment[i]` represents the replica
+ ids in the ith subgroup.
+ name: Optional op name.
+
+ Returns:
+      A `Tensor` with data concatenated from the different replicas.
+ """
+ if group_assignment is None:
+ group_assignment = _create_default_group_assignment()
+ return gen_tpu_ops.all_to_all(
+ x,
+ group_assignment,
+ concat_dimension=concat_dimension,
+ split_dimension=split_dimension,
+ split_count=split_count,
+ name=name)
+
+ @ops.RegisterGradient("AllToAll")
+ def _all_to_all_grad(op, grad):
+    # The gradient of an all-to-all is also an all-to-all, but with the
+    # split_dimension and concat_dimension swapped.
+    # The gradient with respect to group_assignment is None.
+ return [
+ gen_tpu_ops.all_to_all(
+ grad,
+ op.inputs[1],
+ concat_dimension=op.get_attr("split_dimension"),
+ split_dimension=op.get_attr("concat_dimension"),
+ split_count=op.get_attr("split_count")), None
+ ]
+
+ def cross_replica_sum(x, group_assignment=None, name=None):
+    """Sum the input tensor across replicas according to group_assignment.
+
+ Args:
+      x: The local tensor to sum.
+      group_assignment: Optional 2D int32 list with shape [num_groups,
+ num_replicas_per_group]. `group_assignment[i]` represents the replica
+ ids in the ith subgroup.
+ name: Optional op name.
+
+ Returns:
+ A `Tensor` which is summed across replicas.
+ """
+ if group_assignment is None:
+ group_assignment = _create_default_group_assignment()
+
+ return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
+
@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
# The gradient of a cross replica sum is also a cross-replica sum.
- return gen_tpu_ops.cross_replica_sum(grad, op.get_attr("group_assignment"))
+ # The graident with respect to group_assignment is None.
+ return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None]
# This extra type checking exists to give a more helpful error message in
# the common case that uint8 and int64 values are infed. Remove when both
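Note: both gradient functions above return one entry per op input, using None
for the non-differentiable group_assignment input. A minimal sketch of that
convention (hypothetical op name; TF 1.x API):

    import tensorflow as tf

    @tf.RegisterGradient('MyAllReduce')  # hypothetical op type
    def _my_all_reduce_grad(op, grad):
      # Gradient w.r.t. the data input; None w.r.t. the integer
      # group-assignment input.
      return [grad, None]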
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index ff893a722f..d8c3872363 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -54,26 +54,36 @@ import time
import numpy as np
-from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver
+from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver as tpu_cluster_resolver_lib
from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result
from tensorflow.contrib.tpu.python.ops import tpu_ops
+from tensorflow.contrib.tpu.python.tpu import keras_tpu_variables
from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
+from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.eager import context
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras import models
from tensorflow.python.keras import optimizers as keras_optimizers
from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.engine import training_arrays
+from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.layers import embeddings
+from tensorflow.python.keras.utils.generic_utils import make_batches
+from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import math_ops
@@ -82,10 +92,120 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
-# Work-around dependency cycle between DistributionStrategy and TPU lib.
-def TPUDistributionStrategy(*args, **kw): # pylint: disable=invalid-name
- from tensorflow.contrib.distribute.python import tpu_strategy # pylint: disable=g-import-not-at-top
- return tpu_strategy.TPUStrategy(*args, **kw)
+_SESSIONS = {}
+
+
+def tpu_session(cluster_resolver):
+ """Construct or return a `tf.Session` connected to the given cluster."""
+ global _SESSIONS
+ master = cluster_resolver.master()
+ if master not in _SESSIONS:
+ cluster_spec = cluster_resolver.cluster_spec()
+ config = config_pb2.ConfigProto(isolate_session_state=True)
+ if cluster_spec:
+ config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
+
+ logging.info('Connecting to: %s', master)
+ graph = ops.Graph()
+ session = tf_session.Session(graph=graph, target=master, config=config)
+ with graph.as_default():
+ session.run(tpu.initialize_system())
+
+ _SESSIONS[master] = session
+ return _SESSIONS[master]
+
+
+def reset_tpu_sessions():
+ _SESSIONS.clear()
+
+try:
+ from scipy.sparse import issparse # pylint: disable=g-import-not-at-top
+except ImportError:
+ issparse = None
+
+
+def get_tpu_system_metadata(tpu_cluster_resolver):
+ """Retrieves TPU system metadata given a TPUClusterResolver."""
+ master = tpu_cluster_resolver.master()
+
+ # pylint: disable=protected-access
+ cluster_spec = tpu_cluster_resolver.cluster_spec()
+ cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
+ tpu_system_metadata = (
+ tpu_system_metadata_lib._query_tpu_system_metadata(
+ master,
+ cluster_def=cluster_def,
+ query_topology=False))
+
+ return tpu_system_metadata
+
+
+class TPUDistributionStrategy(object):
+ """The strategy to run Keras model on TPU."""
+
+ def __init__(self, tpu_cluster_resolver=None, using_single_core=False):
+ """Construct a TPUDistributionStrategy.
+
+ Args:
+ tpu_cluster_resolver: Any instance of `TPUClusterResolver`. If None, will
+ create one with '' as master address.
+      using_single_core: Bool. This is a debugging option, which might be
+        removed in the future once the model replication functionality is
+        mature enough. If `False` (default behavior), the system automatically
+        finds the best configuration, in terms of number of TPU cores, for the
+        model replication, typically using all available TPU cores. If set to
+        `True`, model replication is forced onto a single core, i.e., no
+        replication.
+ """
+
+ if tpu_cluster_resolver is None:
+ tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('')
+
+ metadata = get_tpu_system_metadata(tpu_cluster_resolver)
+ self._tpu_metadata = metadata
+ self._tpu_cluster_resolver = tpu_cluster_resolver
+ self._num_cores = 1 if using_single_core else metadata.num_cores
+
+ # Walk device list to identify TPU worker for enqueue/dequeue operations.
+ worker_re = re.compile('/job:([^/]+)')
+ for device in metadata.devices:
+ if 'TPU:0' in device.name:
+ self._worker_name = worker_re.search(device.name).group(1)
+ break
+
+ def _make_assignment_for_model(self, cpu_model):
+ """Makes a `TPUAssignment` for the passed in `cpu_model`."""
+ num_cores = self._num_cores
+ if num_cores > 1 and cpu_model.stateful:
+ logging.warning(
+ 'Model replication does not currently support stateful models. '
+ 'Degrading to a single core.')
+ num_cores = 1
+
+ return TPUAssignment(
+ worker_name=self._worker_name, num_cores=num_cores)
+
+
+class TPUAssignment(object):
+  """An object holding the TPU resource assignment for a concrete model.
+
+  `TPUDistributionStrategy` is responsible for creating the instance of
+  `TPUAssignment`, so it can dynamically adjust the `num_cores` to use based
+  on model and input batch sizes.
+ """
+
+ def __init__(self, worker_name, num_cores):
+ self._worker_name = worker_name
+ self._num_cores = num_cores
+
+ @property
+ def worker_name(self):
+ return self._worker_name
+
+ @property
+ def num_towers(self):
+ # TODO(xiejw): Support automatically assign num_cores based on inputs.
+ return self._num_cores
class TPUEmbedding(embeddings.Embedding):
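Note: TPUDistributionStrategy above walks the device list and extracts the TPU
worker's job name with a regex. A self-contained sketch of that parsing:

    import re

    worker_re = re.compile('/job:([^/]+)')
    # Device names look like this; the job component identifies the worker.
    device = '/job:tpu_worker/replica:0/task:0/device:TPU:0'
    assert worker_re.search(device).group(1) == 'tpu_worker'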
@@ -138,6 +258,8 @@ class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads]
def set_weights(self, weights):
+ # TODO(power): Figure out whether we really need this given there is no
+ # caller for this API yet.
self._opt.set_weights()
def get_weights(self):
@@ -162,9 +284,9 @@ def _valid_name(tensor_name):
def _replicated_optimizer(opt):
"""Wrap the optimizer `opt` with CrossShardOptimizer if applicable."""
- if tpu_function.get_tpu_context().number_of_shards == 1:
- return opt
-
+ # Always wrap `opt` with CrossShardOptimizer, even if we are running on a
+ # single core. This ensures Keras properly tracks and initializes optimizer
+ # variables.
if isinstance(opt, keras_optimizers.TFOptimizer):
return tpu_optimizer.CrossShardOptimizer(opt.optimizer)
else:
@@ -405,8 +527,8 @@ class TPUNumpyInfeedManager(TPUInfeedManager):
infeed_dict[tensor] = value
return infeed_dict
- def __init__(self, distribution_strategy):
- self._strategy = distribution_strategy
+ def __init__(self, tpu_assignment):
+ self._tpu_assignment = tpu_assignment
def _split_tensors(self, inputs):
"""Split input data across shards.
@@ -419,16 +541,16 @@ class TPUNumpyInfeedManager(TPUInfeedManager):
Returns:
List of lists containing the input to feed to each TPU shard.
"""
- if self._strategy.num_towers == 1:
+ if self._tpu_assignment.num_towers == 1:
return [inputs]
batch_size = inputs[0].shape[0]
- assert batch_size % self._strategy.num_towers == 0, (
- 'batch_size must be divisible by strategy.num_towers (%s vs %s)' %
- (batch_size, self._strategy.num_towers))
- shard_size = batch_size // self._strategy.num_towers
+ assert batch_size % self._tpu_assignment.num_towers == 0, (
+ 'batch_size must be divisible by the number of TPU cores in use (%s '
+ 'vs %s)' % (batch_size, self._tpu_assignment.num_towers))
+ shard_size = batch_size // self._tpu_assignment.num_towers
input_list = []
- for index in range(self._strategy.num_towers):
+ for index in range(self._tpu_assignment.num_towers):
shard_inputs = [
x[index * shard_size:(index + 1) * shard_size] for x in inputs
]
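Note: _split_tensors above shards a batch evenly across towers along axis 0.
A NumPy sketch of the same logic:

    import numpy as np

    def split_batch(x, num_towers):
      # Batch size must divide evenly, as asserted in _split_tensors above.
      assert x.shape[0] % num_towers == 0, 'batch not divisible by towers'
      shard = x.shape[0] // num_towers
      return [x[i * shard:(i + 1) * shard] for i in range(num_towers)]

    parts = split_batch(np.zeros([8, 3]), num_towers=4)
    assert len(parts) == 4 and parts[0].shape == (2, 3)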
@@ -443,8 +565,9 @@ class TPUNumpyInfeedManager(TPUInfeedManager):
infeed_op = []
shard_infeed_tensors = []
- for shard_id in range(self._strategy.num_towers):
- with ops.device('/device:CPU:0'):
+ for shard_id in range(self._tpu_assignment.num_towers):
+ with ops.device(
+ '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
infeed_tensors = []
with ops.device('/device:TPU:%d' % shard_id):
for spec in input_specs:
@@ -483,30 +606,31 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
# TODO(saeta): Verify tpu_model_op is as expected!
return {}
- def __init__(self, dataset, distribution_strategy, tpu_session):
+ # pylint: disable=redefined-outer-name
+ def __init__(self, dataset, tpu_assignment, tpu_session):
"""Constructs a TPUDatasetInfeedManager.
Must be called within a `KerasTPUModel.tpu_session` context!
Args:
dataset: A `tf.data.Dataset` to infeed.
- distribution_strategy: The `TPUDistributionStrategy` used to configure the
+ tpu_assignment: The `TPUAssignment` used to configure the
Keras TPU model.
tpu_session: The `tf.Session` object used for running the TPU model.
"""
self._verify_dataset_shape(dataset)
self._dataset = dataset
- self._strategy = distribution_strategy
+ self._tpu_assignment = tpu_assignment
dummy_x_shape = dataset.output_shapes[0].as_list()
- dummy_x_shape[0] *= distribution_strategy.num_towers
+ dummy_x_shape[0] *= tpu_assignment.num_towers
dummy_y_shape = dataset.output_shapes[1].as_list()
- dummy_y_shape[0] *= distribution_strategy.num_towers
+ dummy_y_shape[0] *= tpu_assignment.num_towers
self._iterator = dataset.make_initializable_iterator()
tpu_session.run(self._iterator.initializer)
self._get_next_ops = []
ctrl_deps = []
- for i in range(distribution_strategy.num_towers):
+ for i in range(tpu_assignment.num_towers):
with ops.control_dependencies(ctrl_deps): # Ensure deterministic
# TODO(saeta): Ensure correct placement!
get_next_op = self._iterator.get_next()
@@ -570,7 +694,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
'currently requires static shapes. The provided '
'dataset only has a partially defined shape. '
'(Dimension %d of output tensor %d is not statically known '
- 'for output shapes: %s.%s)' % (i, j, dataset.output_shapes, hint))
+ 'for output shapes: %s.%s)' % (j, i, dataset.output_shapes, hint))
@property
def dummy_x(self):
@@ -586,10 +710,11 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
def build_infeed_from_input_specs(self, input_specs, execution_mode):
shard_infeed_tensors = self._get_next_ops
- assert len(shard_infeed_tensors) == self._strategy.num_towers
+ assert len(shard_infeed_tensors) == self._tpu_assignment.num_towers
infeed_ops = []
- for shard_id in range(self._strategy.num_towers):
- with ops.device('/device:CPU:0'):
+ for shard_id in range(self._tpu_assignment.num_towers):
+ with ops.device(
+ '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
infeed_ops.append(
tpu_ops.infeed_enqueue_tuple(
shard_infeed_tensors[shard_id],
@@ -612,10 +737,10 @@ class TPUFunction(object):
instead of being injected as `feed_dict` items or fetches.
"""
- def __init__(self, model, execution_mode, strategy):
+ def __init__(self, model, execution_mode, tpu_assignment):
self.model = model
self.execution_mode = execution_mode
- self._strategy = strategy
+ self._tpu_assignment = tpu_assignment
self._compilation_cache = {}
self._cloned_model = None
@@ -666,9 +791,10 @@ class TPUFunction(object):
# Clone our CPU model, running within the TPU device context.
with TPURewriteContext(tpu_input_map):
- # TODO(power): Replicate variables.
- with ops.device('/device:TPU:0'):
- self._cloned_model = models.clone_model(self.model)
+ with variable_scope.variable_scope('tpu_model_%s' % id(self.model)):
+ with keras_tpu_variables.replicated_scope(
+ self._tpu_assignment.num_towers):
+ self._cloned_model = models.clone_model(self.model)
# Create a copy of the optimizer for this graph.
if isinstance(self.model.optimizer, keras_optimizers.TFOptimizer):
@@ -737,7 +863,7 @@ class TPUFunction(object):
# `execute op` replicates `_model_fn` `num_replicas` times, with each shard
# running on a different logical core.
compile_op, execute_op = tpu.split_compile_and_replicate(
- _model_fn, inputs=[[]] * self._strategy.num_towers)
+ _model_fn, inputs=[[]] * self._tpu_assignment.num_towers)
# Generate CPU side operations to enqueue features/labels and dequeue
# outputs from the model call.
@@ -745,8 +871,9 @@ class TPUFunction(object):
input_specs, self.execution_mode)
# Build output ops.
outfeed_op = []
- for shard_id in range(self._strategy.num_towers):
- with ops.device('/device:CPU:0'):
+ for shard_id in range(self._tpu_assignment.num_towers):
+ with ops.device(
+ '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
outfeed_op.extend(
tpu_ops.outfeed_dequeue_tuple(
dtypes=[spec.dtype for spec in self._outfeed_spec],
@@ -764,7 +891,7 @@ class TPUFunction(object):
def _test_model_compiles(self, tpu_model_ops):
"""Verifies that the given TPUModelOp can be compiled via XLA."""
logging.info('Started compiling')
- start_time = time.clock()
+ start_time = time.time()
result = K.get_session().run(tpu_model_ops.compile_op)
proto = tpu_compilation_result.CompilationResultProto()
@@ -773,38 +900,52 @@ class TPUFunction(object):
raise RuntimeError('Compilation failed: {}'.format(
proto.status_error_message))
- end_time = time.clock()
+ end_time = time.time()
logging.info('Finished compiling. Time elapsed: %s secs',
end_time - start_time)
- def __call__(self, inputs):
- assert isinstance(inputs, list)
+ def _lookup_infeed_manager(self, inputs):
+ """Return an existing manager, or construct a new InfeedManager for inputs.
+
+ _lookup_infeed_manager will return an existing InfeedManager if one has been
+ previously assigned for this model and input. If not, it will construct a
+ new TPUNumpyInfeedManager.
+
+ Args:
+ inputs: A NumPy input to the model.
+
+ Returns:
+ A `TPUInfeedManager` object to manage infeeds for this input.
+ """
+ if inputs is None:
+ return None
- infeed_manager = None
for x, mgr in self.model._numpy_to_infeed_manager_list:
if inputs[0] is x:
- infeed_manager = mgr
- break
- if infeed_manager is None:
- infeed_manager = TPUNumpyInfeedManager(self.model._strategy)
+ return mgr
+ return TPUNumpyInfeedManager(self.model._tpu_assignment)
- # Strip sample weight from inputs
- if (self.execution_mode == model_fn_lib.ModeKeys.TRAIN or
- self.execution_mode == model_fn_lib.ModeKeys.EVAL):
- input_tensors = self.model._feed_inputs + self.model._feed_targets
- inputs = inputs[:len(input_tensors)]
- else:
- input_tensors = self.model._feed_inputs
+ def _tpu_model_ops_for_input_specs(self, input_specs, infeed_manager):
+ """Looks up the corresponding `TPUModelOp` for a given `input_specs`.
- infeed_instance = infeed_manager.make_infeed_instance(inputs)
- del inputs # To avoid accident usage.
- input_specs = infeed_instance.make_input_specs(input_tensors)
+ It instantiates a new copy of the model for each unique input shape.
+
+ Args:
+ input_specs: The specification of the inputs to train on.
+ infeed_manager: The infeed manager responsible for feeding in data.
+
+ Returns:
+ A `TPUModelOp` instance that can be used to execute a step of the model.
+ """
+ if input_specs is None or infeed_manager is None:
+ # Note: this condition is possible during the prologue or epilogue of the
+ # pipelined loop.
+ return None
# XLA requires every operation in the graph has a fixed shape. To
# handle varying batch sizes we recompile a new sub-graph for each
# unique input shape.
shape_key = tuple([tuple(spec.shape.as_list()) for spec in input_specs])
-
if shape_key not in self._compilation_cache:
with self.model.tpu_session():
logging.info('New input shapes; (re-)compiling: mode=%s, %s',
@@ -814,24 +955,47 @@ class TPUFunction(object):
self._compilation_cache[shape_key] = new_tpu_model_ops
self._test_model_compiles(new_tpu_model_ops)
- # Initialize our TPU weights on the first compile.
- self.model._initialize_weights(self._cloned_model)
- tpu_model_ops = self._compilation_cache[shape_key]
+ return self._compilation_cache[shape_key]
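For illustration (not part of this patch), with hypothetical image/label input
specs: a full batch and a trailing partial batch produce distinct shape keys,
hence two cache entries and two separate XLA compilations.

    # Illustrative sketch of shape_key behavior.
    full_batch_key = ((32, 28, 28), (32, 10))     # e.g. images and labels
    partial_batch_key = ((17, 28, 28), (17, 10))  # trailing partial batch
    assert full_batch_key != partial_batch_key    # separate compiled sub-graphs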
- infeed_dict = infeed_instance.make_feed_dict(tpu_model_ops)
+ def _construct_input_tensors_and_inputs(self, inputs):
+ """Returns input tensors and numpy array inputs corresponding to `inputs`.
- with self.model.tpu_session() as session:
- _, _, outfeed_outputs = session.run([
- tpu_model_ops.infeed_op, tpu_model_ops.execute_op,
- tpu_model_ops.outfeed_op
- ], infeed_dict)
+ Args:
+ inputs: NumPy inputs.
- # TODO(xiejw): Decide how to reduce outputs, or just discard all but first.
+ Returns:
+ A tuple of (`input_tensors`, `inputs`).
+ """
+ if inputs is None:
+ # Note: this condition is possible during the prologue or epilogue of the
+ # pipelined loop.
+ return None, None
+ # Strip sample weight from inputs
+ if (self.execution_mode == model_fn_lib.ModeKeys.TRAIN or
+ self.execution_mode == model_fn_lib.ModeKeys.EVAL):
+ input_tensors = self.model._feed_inputs + self.model._feed_targets
+ inputs = inputs[:len(input_tensors)]
+ return input_tensors, inputs
+ else:
+ input_tensors = self.model._feed_inputs
+ return input_tensors, inputs
+
+ def _process_outputs(self, outfeed_outputs):
+ """Processes the outputs of a model function execution.
+
+ Args:
+ outfeed_outputs: The sharded outputs of the TPU computation.
+
+ Returns:
+ The aggregated outputs of the TPU computation to be used in the rest of
+ the model execution.
+ """
+ # TODO(xiejw): Decide how to reduce outputs, or discard all but first.
if self.execution_mode == model_fn_lib.ModeKeys.PREDICT:
outputs = [[]] * len(self._outfeed_spec)
outputs_per_replica = len(self._outfeed_spec)
- for i in range(self._strategy.num_towers):
+ for i in range(self._tpu_assignment.num_towers):
output_group = outfeed_outputs[i * outputs_per_replica:(i + 1) *
outputs_per_replica]
for j in range(outputs_per_replica):
@@ -839,13 +1003,145 @@ class TPUFunction(object):
return [np.concatenate(group) for group in outputs]
else:
- return outfeed_outputs[:len(outfeed_outputs) // self._strategy.num_towers]
+ return outfeed_outputs[:len(outfeed_outputs) //
+ self._tpu_assignment.num_towers]
+
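For illustration (not part of this patch), assuming two towers and a single
outfeed spec carrying 16 samples per tower: PREDICT outputs are concatenated
along the batch dimension to reconstruct the full batch.

    # Illustrative sketch of per-tower PREDICT output merging.
    import numpy as np

    outfeed_outputs = [np.zeros((16, 10)), np.ones((16, 10))]  # one per tower
    merged = np.concatenate(outfeed_outputs)
    assert merged.shape == (32, 10)  # the full batch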
+ def __call__(self, inputs):
+ """__call__ executes the function on the computational hardware.
+
+ It handles executing the infeed and preprocessing, in addition to
+ executing the model on the TPU hardware.
+
+ Note: `__call__` has a sibling method `pipeline_run` which performs the same
+ operations, but with software pipelining.
+
+ Args:
+ inputs: The inputs on which to train the model.
+
+ Returns:
+ The output of the computation for the mode in which it is executed.
+
+ Raises:
+ RuntimeError: If there is an inappropriate use of the function.
+ """
+ assert isinstance(inputs, list)
+
+ infeed_manager = self._lookup_infeed_manager(inputs)
+ input_tensors, inputs = self._construct_input_tensors_and_inputs(inputs)
+ infeed_instance = infeed_manager.make_infeed_instance(inputs)
+ del inputs # To avoid accident usage.
+ input_specs = infeed_instance.make_input_specs(input_tensors)
+ tpu_model_ops = self._tpu_model_ops_for_input_specs(input_specs,
+ infeed_manager)
+ infeed_dict = infeed_instance.make_feed_dict(tpu_model_ops)
+
+ # Initialize our TPU weights on the first compile.
+ self.model._initialize_weights(self._cloned_model)
+
+ with self.model.tpu_session() as session:
+ _, _, outfeed_outputs = session.run([
+ tpu_model_ops.infeed_op, tpu_model_ops.execute_op,
+ tpu_model_ops.outfeed_op
+ ], infeed_dict)
+ return self._process_outputs(outfeed_outputs)
+
+ def pipeline_run(self, cur_step_inputs, next_step_inputs):
+ """pipeline_run executes the function on the computational hardware.
+
+ pipeline_run performs the same computation as __call__; however, it runs
+ the infeed for the next step concurrently with the on-device execution of
+ the current step (software pipelining).
+
+ Note: it is the responsibility of the caller to call `pipeline_run` in the
+ following sequence:
+ - Once with `cur_step_inputs=None` and `next_step_inputs=list(...)`
+ - `n` times with `cur_step_inputs` and `next_step_inputs` as `list`s
+ - Once with `cur_step_inputs=list(...)` and `next_step_inputs=None`
+ Additionally, it is the responsibility of the caller to pass
+ `next_step_inputs` as `cur_step_inputs` on the next invocation of
+ `pipeline_run`.
+
+ Args:
+ cur_step_inputs: The current step's inputs.
+ next_step_inputs: The next step's inputs.
+
+ Returns:
+ The output of the computation for the mode in which it is executed.
+
+ Raises:
+ RuntimeError: If there is an inappropriate use of the function.
+ """
+ # Software pipelined case.
+ next_step_infeed_manager = self._lookup_infeed_manager(next_step_inputs)
+ cur_step_infeed_manager = self._lookup_infeed_manager(cur_step_inputs)
+
+ if (next_step_infeed_manager is not None
+ and cur_step_infeed_manager is not None):
+ assert type(next_step_infeed_manager) is type(cur_step_infeed_manager)
+
+ next_input_tensors, next_step_inputs = (
+ self._construct_input_tensors_and_inputs(next_step_inputs))
+ cur_input_tensors, cur_step_inputs = (
+ self._construct_input_tensors_and_inputs(cur_step_inputs))
+
+ cur_infeed_instance = None
+ if cur_step_infeed_manager:
+ cur_infeed_instance = cur_step_infeed_manager.make_infeed_instance(
+ cur_step_inputs)
+ next_infeed_instance = None
+ if next_step_infeed_manager:
+ next_infeed_instance = next_step_infeed_manager.make_infeed_instance(
+ next_step_inputs)
+
+ del cur_step_inputs # Avoid accidental re-use.
+ del next_step_inputs # Avoid accidental re-use.
+
+ cur_tpu_model_ops = None
+ next_tpu_model_ops = None
+ infeed_dict = None
+
+ if cur_infeed_instance and cur_input_tensors and cur_step_infeed_manager:
+ cur_input_specs = cur_infeed_instance.make_input_specs(
+ cur_input_tensors)
+ cur_tpu_model_ops = self._tpu_model_ops_for_input_specs(
+ cur_input_specs, cur_step_infeed_manager)
+
+ if (next_infeed_instance
+ and next_input_tensors
+ and next_step_infeed_manager):
+ next_input_specs = next_infeed_instance.make_input_specs(
+ next_input_tensors)
+ next_tpu_model_ops = self._tpu_model_ops_for_input_specs(
+ next_input_specs, next_step_infeed_manager)
+ infeed_dict = next_infeed_instance.make_feed_dict(next_tpu_model_ops)
+
+ # Initialize our TPU weights on the first compile.
+ self.model._initialize_weights(self._cloned_model)
+
+ if next_tpu_model_ops and cur_tpu_model_ops:
+ with self.model.tpu_session() as session:
+ _, _, outfeed_outputs = session.run([
+ next_tpu_model_ops.infeed_op, cur_tpu_model_ops.execute_op,
+ cur_tpu_model_ops.outfeed_op
+ ], infeed_dict)
+ return self._process_outputs(outfeed_outputs)
+ if cur_tpu_model_ops:
+ with self.model.tpu_session() as session:
+ _, outfeed_outputs = session.run([
+ cur_tpu_model_ops.execute_op, cur_tpu_model_ops.outfeed_op])
+ return self._process_outputs(outfeed_outputs)
+ if next_tpu_model_ops:
+ with self.model.tpu_session() as session:
+ session.run(next_tpu_model_ops.infeed_op, infeed_dict)
+ return None
+ raise RuntimeError('Internal error: both current & next tpu_model_ops '
+ 'were None')
+
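For illustration (not part of this patch): a driver loop that honors the call
sequence documented in `pipeline_run`. `f` (a TPUFunction) and `batches` (a
non-empty list of per-step input lists) are hypothetical names.

    # Illustrative sketch of the pipeline_run calling protocol.
    def drive_pipeline(f, batches):
      outputs = []
      # Prologue: prime the infeed only; no step executes yet.
      f.pipeline_run(cur_step_inputs=None, next_step_inputs=batches[0])
      for i, batch in enumerate(batches):
        # Epilogue: the last call passes next_step_inputs=None.
        nxt = batches[i + 1] if i + 1 < len(batches) else None
        outputs.append(f.pipeline_run(cur_step_inputs=batch,
                                      next_step_inputs=nxt))
      return outputs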
class KerasTPUModel(models.Model):
"""TPU compatible Keras model wrapper."""
- def __init__(self, cpu_model, tpu_name_or_address, strategy):
+ def __init__(self, cpu_model, strategy):
super(models.Model, self).__init__( # pylint: disable=bad-super-call
inputs=cpu_model.inputs,
outputs=cpu_model.outputs,
@@ -860,29 +1156,15 @@ class KerasTPUModel(models.Model):
self.predict_function = None
self.test_function = None
self.train_function = None
- self._strategy = strategy
- self._tpu_name_or_address = tpu_name_or_address
+ cluster_resolver = strategy._tpu_cluster_resolver
+ self._tpu_name_or_address = cluster_resolver.get_master()
self._cpu_model = cpu_model
+ self._tpu_assignment = strategy._make_assignment_for_model(cpu_model)
self._tpu_model = None
self._tpu_weights_initialized = False
- self._graph = ops.Graph()
-
- self._cluster_resolver = tpu_cluster_resolver.TPUClusterResolver(
- tpu_name_or_address)
- master = self._cluster_resolver.master()
- cluster_spec = self._cluster_resolver.cluster_spec()
- self._session = tf_session.Session(
- graph=self._graph,
- target=master,
- config=config_pb2.ConfigProto(isolate_session_state=True))
-
- # TODO(saeta): Confirm the lines below work in ClusterSpec propagation env.
- if cluster_spec:
- self._session.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
- with self._graph.as_default():
- self._session.run(tpu.initialize_system())
+ self._session = tpu_session(cluster_resolver)
# If the input CPU model has already been compiled, compile our TPU model
# immediately.
@@ -901,7 +1183,7 @@ class KerasTPUModel(models.Model):
return {
'cpu_model': self._cpu_model,
'tpu_name_or_address': self._tpu_name_or_address,
- 'strategy': self._strategy,
+ 'tpu_assignment': self._tpu_assignment,
}
def compile(self,
@@ -945,6 +1227,10 @@ class KerasTPUModel(models.Model):
steps_per_epoch=None,
validation_steps=None,
**kwargs):
+ if context.executing_eagerly():
+ raise EnvironmentError('KerasTPUModel currently does not support eager '
+ 'mode.')
+
assert not self._numpy_to_infeed_manager_list # Ensure empty.
infeed_managers = [] # Managers to clean up at the end of the fit call.
@@ -957,7 +1243,8 @@ class KerasTPUModel(models.Model):
'https://github.com/tensorflow/tpu/tree/master/models/experimental'
'/keras')
if callable(x):
- with self.tpu_session() as sess:
+ with self.tpu_session() as sess,\
+ ops.device('/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
dataset = x()
if steps_per_epoch is None:
raise ValueError('When using tf.data as input to a model, you '
@@ -965,7 +1252,8 @@ class KerasTPUModel(models.Model):
if y is not None:
raise ValueError('When using tf.data as input to a model, y must be '
'None')
- infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess)
+ infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
+ sess)
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
x = infeed_manager.dummy_x
@@ -986,7 +1274,8 @@ class KerasTPUModel(models.Model):
if validation_steps is None:
raise ValueError('When using tf.data as validation for a model, you '
'should specify the validation_steps argument.')
- infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess)
+ infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
+ sess)
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
val_x = infeed_manager.dummy_x
@@ -996,7 +1285,28 @@ class KerasTPUModel(models.Model):
self._numpy_to_infeed_manager_list = infeed_managers
try:
- return super(KerasTPUModel, self).fit(
+ if not kwargs.get('_pipeline', True):
+ logging.info(
+ 'Running non-pipelined training loop (`_pipeline=%s`).',
+ kwargs['_pipeline'])
+ kwargs.pop('_pipeline')
+ return super(KerasTPUModel, self).fit(
+ x,
+ y,
+ batch_size,
+ epochs,
+ verbose,
+ callbacks,
+ validation_split,
+ validation_data,
+ shuffle,
+ class_weight,
+ sample_weight,
+ initial_epoch,
+ steps_per_epoch,
+ validation_steps,
+ **kwargs)
+ return self._pipeline_fit(
x,
y,
batch_size,
@@ -1015,23 +1325,479 @@ class KerasTPUModel(models.Model):
finally:
self._numpy_to_infeed_manager_list = []
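For illustration (not part of this patch): opting out of software pipelining
via the private `_pipeline` kwarg handled above; `tpu_model`, `x_train` and
`y_train` are hypothetical.

    # Illustrative sketch: fall back to the plain (non-pipelined) fit loop.
    tpu_model.fit(x_train, y_train, epochs=1, steps_per_epoch=100,
                  _pipeline=False)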
+ def evaluate(self,
+ x=None,
+ y=None,
+ batch_size=None,
+ verbose=1,
+ sample_weight=None,
+ steps=None):
+ assert not self._numpy_to_infeed_manager_list # Ensure empty.
+
+ infeed_managers = [] # Managers to clean up at the end of the fit call.
+ if isinstance(x, dataset_ops.Dataset):
+ # TODO(b/111413240): Support taking a tf.data.Dataset directly.
+ raise ValueError(
+ 'Taking a Dataset directly is not yet supported. Please '
+ 'wrap your dataset construction code in a function and '
+ 'pass that to fit instead. For examples, see: '
+ 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
+ '/keras')
+ if callable(x):
+ with self.tpu_session() as sess:
+ dataset = x()
+ if steps is None:
+ raise ValueError('When using tf.data as input to a model, you '
+ 'should specify the steps argument.')
+ if y is not None:
+ raise ValueError('When using tf.data as input to a model, y must be '
+ 'None')
+ infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
+ sess)
+ # Use dummy numpy inputs for the rest of Keras' shape checking. We
+ # intercept them when building the model.
+ x = infeed_manager.dummy_x
+ y = infeed_manager.dummy_y
+ infeed_managers.append((x, infeed_manager))
+
+ self._numpy_to_infeed_manager_list = infeed_managers
+ try:
+ return super(KerasTPUModel, self).evaluate(
+ x,
+ y,
+ batch_size,
+ verbose,
+ sample_weight,
+ steps)
+ finally:
+ self._numpy_to_infeed_manager_list = []
+
+ def _pipeline_fit(self,
+ x,
+ y,
+ batch_size,
+ epochs,
+ verbose,
+ callbacks,
+ validation_split,
+ validation_data,
+ shuffle,
+ class_weight,
+ sample_weight,
+ initial_epoch,
+ steps_per_epoch,
+ validation_steps,
+ **kwargs):
+ # Similar to super.fit(...), but modified to support software pipelining.
+
+ # Backwards compatibility
+ if batch_size is None and steps_per_epoch is None:
+ batch_size = 32
+ # Legacy support
+ if 'nb_epoch' in kwargs:
+ logging.warning('The `nb_epoch` argument in `fit` has been renamed '
+ '`epochs`.')
+ epochs = kwargs.pop('nb_epoch')
+ if kwargs:
+ raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
+
+ # Validate and standardize user data
+ x, y, sample_weights = self._standardize_user_data(
+ x,
+ y,
+ sample_weight=sample_weight,
+ class_weight=class_weight,
+ batch_size=batch_size,
+ check_steps=True,
+ steps_name='steps_per_epoch',
+ steps=steps_per_epoch,
+ validation_split=validation_split)
+
+ # Prepare validation data
+ val_x, val_y, val_sample_weights = self._prepare_validation_data(
+ validation_data,
+ validation_split,
+ validation_steps,
+ x,
+ y,
+ sample_weights,
+ batch_size)
+ return self._pipeline_fit_loop(
+ x,
+ y,
+ sample_weights=sample_weights,
+ batch_size=batch_size,
+ epochs=epochs,
+ verbose=verbose,
+ callbacks=callbacks,
+ val_inputs=val_x,
+ val_targets=val_y,
+ val_sample_weights=val_sample_weights,
+ shuffle=shuffle,
+ initial_epoch=initial_epoch,
+ steps_per_epoch=steps_per_epoch,
+ validation_steps=validation_steps)
+
+ def _pipeline_fit_loop(self,
+ inputs,
+ targets,
+ sample_weights,
+ batch_size,
+ epochs,
+ verbose,
+ callbacks,
+ val_inputs,
+ val_targets,
+ val_sample_weights,
+ shuffle,
+ initial_epoch,
+ steps_per_epoch,
+ validation_steps):
+ self._make_train_function()
+ sample_weights = sample_weights or []
+ val_sample_weights = val_sample_weights or []
+ if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ ins = inputs + targets + sample_weights + [1]
+ else:
+ ins = inputs + targets + sample_weights
+
+ do_validation = False
+ if val_inputs:
+ do_validation = True
+ if (steps_per_epoch is None and verbose and inputs and
+ hasattr(inputs[0], 'shape') and hasattr(val_inputs[0], 'shape')):
+ print('Train on %d samples, validate on %d samples' %
+ (inputs[0].shape[0], val_inputs[0].shape[0]))
+
+ if validation_steps:
+ do_validation = True
+ if steps_per_epoch is None:
+ raise ValueError('Can only use `validation_steps` when doing step-wise '
+ 'training, i.e. `steps_per_epoch` must be set.')
+
+ num_training_samples = training_utils.check_num_samples(
+ ins, batch_size, steps_per_epoch, 'steps_per_epoch')
+ count_mode = 'steps' if steps_per_epoch else 'samples'
+ callbacks = cbks.configure_callbacks(
+ callbacks,
+ self,
+ do_validation=do_validation,
+ val_inputs=val_inputs,
+ val_targets=val_targets,
+ val_sample_weights=val_sample_weights,
+ batch_size=batch_size,
+ epochs=epochs,
+ steps_per_epoch=steps_per_epoch,
+ samples=num_training_samples,
+ validation_steps=validation_steps,
+ verbose=verbose,
+ count_mode=count_mode)
+
+ if num_training_samples is not None:
+ index_array = np.arange(num_training_samples)
+
+ # To prevent a slowdown, we find beforehand the arrays that need conversion.
+ feed = self._feed_inputs + self._feed_targets + self._feed_sample_weights
+ indices_for_conversion_to_dense = []
+ for i in range(len(feed)):
+ if issparse is not None and issparse(ins[i]) and not K.is_sparse(feed[i]):
+ indices_for_conversion_to_dense.append(i)
+
+ callbacks.on_train_begin()
+ for epoch in range(initial_epoch, epochs):
+ # Reset stateful metrics
+ for m in self.stateful_metric_functions:
+ m.reset_states()
+ # Update callbacks
+ callbacks.on_epoch_begin(epoch)
+ epoch_logs = {}
+ if steps_per_epoch is not None:
+ # Step-wise fit loop.
+ self._pipeline_fit_loop_step_wise(
+ ins=ins,
+ callbacks=callbacks,
+ steps_per_epoch=steps_per_epoch,
+ epochs=epochs,
+ do_validation=do_validation,
+ val_inputs=val_inputs,
+ val_targets=val_targets,
+ val_sample_weights=val_sample_weights,
+ validation_steps=validation_steps,
+ epoch_logs=epoch_logs)
+ else:
+ # Sample-wise fit loop.
+ self._pipeline_fit_loop_sample_wise(
+ ins=ins,
+ callbacks=callbacks,
+ index_array=index_array,
+ shuffle=shuffle,
+ batch_size=batch_size,
+ num_training_samples=num_training_samples,
+ indices_for_conversion_to_dense=indices_for_conversion_to_dense,
+ do_validation=do_validation,
+ val_inputs=val_inputs,
+ val_targets=val_targets,
+ val_sample_weights=val_sample_weights,
+ validation_steps=validation_steps,
+ epoch_logs=epoch_logs)
+
+ callbacks.on_epoch_end(epoch, epoch_logs)
+ if callbacks.model.stop_training:
+ break
+ callbacks.on_train_end()
+ return self.history
+
+ def _pipeline_fit_loop_sample_wise(self,
+ ins,
+ callbacks,
+ index_array,
+ shuffle,
+ batch_size,
+ num_training_samples,
+ indices_for_conversion_to_dense,
+ do_validation,
+ val_inputs,
+ val_targets,
+ val_sample_weights,
+ validation_steps,
+ epoch_logs):
+ f = self.train_function
+ if shuffle == 'batch':
+ index_array = training_utils.batch_shuffle(index_array, batch_size)
+ elif shuffle:
+ np.random.shuffle(index_array)
+ batches = make_batches(num_training_samples, batch_size)
+
+ ins_last_batch = None
+ last_batch_logs = None
+ batch_index = 0
+
+ for batch_index, (batch_start, batch_end) in enumerate(batches):
+ batch_ids = index_array[batch_start:batch_end]
+ try:
+ if isinstance(ins[-1], int):
+ # Do not slice the training phase flag.
+ ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
+ else:
+ ins_batch = slice_arrays(ins, batch_ids)
+ except TypeError:
+ raise TypeError('TypeError while preparing batch. If using HDF5 '
+ 'input data, pass shuffle="batch".')
+
+ # Pipeline batch logs
+ next_batch_logs = {}
+ next_batch_logs['batch'] = batch_index
+ next_batch_logs['size'] = len(batch_ids)
+ if batch_index > 0:
+ # Callbacks operate one step behind in the software pipeline.
+ callbacks.on_batch_begin(batch_index - 1, last_batch_logs)
+ for i in indices_for_conversion_to_dense:
+ ins_batch[i] = ins_batch[i].toarray()
+
+ outs = f.pipeline_run(cur_step_inputs=ins_last_batch,
+ next_step_inputs=ins_batch)
+ ins_last_batch = ins_batch
+
+ if batch_index == 0:
+ assert outs is None
+ else:
+ if not isinstance(outs, list):
+ outs = [outs]
+ for l, o in zip(self.metrics_names, outs):
+ last_batch_logs[l] = o # pylint: disable=unsupported-assignment-operation
+ callbacks.on_batch_end(batch_index - 1, last_batch_logs)
+ if callbacks.model.stop_training:
+ return
+ last_batch_logs = next_batch_logs
+
+ # Final batch
+ callbacks.on_batch_begin(batch_index, last_batch_logs)
+ outs = f.pipeline_run(cur_step_inputs=ins_last_batch, next_step_inputs=None)
+ if not isinstance(outs, list):
+ outs = [outs]
+ for l, o in zip(self.metrics_names, outs):
+ last_batch_logs[l] = o
+ callbacks.on_batch_end(batch_index, last_batch_logs)
+ if callbacks.model.stop_training:
+ return
+
+ if do_validation:
+ val_outs = training_arrays.test_loop(
+ self,
+ val_inputs,
+ val_targets,
+ sample_weights=val_sample_weights,
+ batch_size=batch_size,
+ steps=validation_steps,
+ verbose=0)
+ if not isinstance(val_outs, list):
+ val_outs = [val_outs]
+ # Same labels assumed.
+ for l, o in zip(self.metrics_names, val_outs):
+ epoch_logs['val_' + l] = o
+
+ def _pipeline_fit_loop_step_wise(self,
+ ins,
+ callbacks,
+ steps_per_epoch,
+ epochs,
+ do_validation,
+ val_inputs,
+ val_targets,
+ val_sample_weights,
+ validation_steps,
+ epoch_logs):
+ f = self.train_function
+
+ # Loop prologue
+ try:
+ outs = f.pipeline_run(cur_step_inputs=None, next_step_inputs=ins)
+ assert outs is None # Function shouldn't return anything!
+ except errors.OutOfRangeError:
+ logging.warning('Your dataset iterator ran out of data on the first step '
+ 'of the epoch, preventing further training. Check to '
+ 'make sure your paths are correct and you have '
+ 'permissions to read the files. Skipping validation.')
+
+ for step_index in range(steps_per_epoch):
+ batch_logs = {'batch': step_index, 'size': 1}
+ callbacks.on_batch_begin(step_index, batch_logs)
+ try:
+ if step_index < steps_per_epoch - 1:
+ next_step_inputs = ins
+ else:
+ next_step_inputs = None
+ outs = f.pipeline_run(cur_step_inputs=ins,
+ next_step_inputs=next_step_inputs)
+ except errors.OutOfRangeError:
+ logging.warning('Your dataset iterator ran out of data; '
+ 'interrupting training. Make sure that your '
+ 'dataset can generate at least `steps_per_epoch * '
+ 'epochs` batches (in this case, %d batches). You '
+ 'may need to use the repeat() function when '
+ 'building your dataset.' % (steps_per_epoch * epochs))
+ break
+
+ if not isinstance(outs, list):
+ outs = [outs]
+ for l, o in zip(self.metrics_names, outs):
+ batch_logs[l] = o
+
+ callbacks.on_batch_end(step_index, batch_logs)
+ if callbacks.model.stop_training:
+ break
+
+ if do_validation:
+ val_outs = training_arrays.test_loop(self,
+ val_inputs,
+ val_targets,
+ sample_weights=val_sample_weights,
+ steps=validation_steps,
+ verbose=0)
+ if not isinstance(val_outs, list):
+ val_outs = [val_outs]
+ # Same labels assumed.
+ for l, o in zip(self.metrics_names, val_outs):
+ epoch_logs['val_' + l] = o
+
+ def _prepare_validation_data(self,
+ validation_data,
+ validation_split,
+ validation_steps,
+ x,
+ y,
+ sample_weights,
+ batch_size):
+ """Prepares the validation dataset.
+
+ Args:
+ validation_data: The validation data (if provided)
+ validation_split: The validation split (if provided)
+ validation_steps: The validation steps (if provided)
+ x: The main training data x (if provided)
+ y: The main training data y (if provided)
+ sample_weights: The sample weights (if provided)
+ batch_size: The training batch size (if provided)
+
+ Returns:
+ A 3-tuple of (val_x, val_y, val_sample_weights).
+
+ Raises:
+ ValueError: If the provided arguments are not compatible with
+ `KerasTPUModel`.
+ """
+ # Note: this is similar to a section of $tf/python/keras/engine/training.py
+ # It differs in that tf.data objects are not allowed to be passed directly.
+ # Additionally, it handles validating shapes & types appropriately for use
+ # on TPUs.
+ if validation_data:
+ if (isinstance(validation_data, iterator_ops.Iterator) or
+ isinstance(validation_data, iterator_ops.EagerIterator) or
+ isinstance(validation_data, dataset_ops.Dataset)):
+ raise ValueError('KerasTPUModel cannot handle a Dataset or Iterator '
+ 'for validation_data. Please instead pass a function '
+ 'that returns a `tf.data.Dataset`.')
+ if len(validation_data) == 2:
+ val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence
+ val_sample_weight = None
+ elif len(validation_data) == 3:
+ val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence
+ else:
+ raise ValueError('When passing a `validation_data` argument, it must '
+ 'contain either 2 items (x_val, y_val), or 3 items '
+ '(x_val, y_val, val_sample_weights). However we '
+ 'received `validation_data=%s`' % validation_data)
+ val_x, val_y, val_sample_weights = self._standardize_user_data(
+ val_x,
+ val_y,
+ sample_weight=val_sample_weight,
+ batch_size=batch_size,
+ steps=validation_steps)
+ elif validation_split and 0. < validation_split < 1.:
+ if training_utils.has_symbolic_tensors(x):
+ raise ValueError('If your data is in the form of symbolic tensors, you '
+ 'cannot use `validation_split`.')
+ if hasattr(x[0], 'shape'):
+ split_at = int(x[0].shape[0] * (1. - validation_split))
+ else:
+ split_at = int(len(x[0]) * (1. - validation_split))
+
+ x, val_x = (slice_arrays(x, 0, split_at), slice_arrays(x, split_at))
+ y, val_y = (slice_arrays(y, 0, split_at), slice_arrays(y, split_at))
+ sample_weights, val_sample_weights = (slice_arrays(
+ sample_weights, 0, split_at), slice_arrays(sample_weights, split_at))
+ elif validation_steps:
+ val_x = []
+ val_y = []
+ val_sample_weights = []
+ else:
+ val_x = None
+ val_y = None
+ val_sample_weights = None
+
+ return val_x, val_y, val_sample_weights
+
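For illustration (not part of this patch): the `validation_data` forms
accepted under this method; `model`, `x`, `y`, `val_x`, `val_y`, `val_w` and
`make_val_dataset` are hypothetical.

    # Illustrative sketch of validation_data handling.
    model.fit(x, y, validation_data=(val_x, val_y))          # 2-tuple
    model.fit(x, y, validation_data=(val_x, val_y, val_w))   # 3-tuple
    # Datasets/Iterators are rejected here; pass a zero-argument function
    # returning a tf.data.Dataset instead, plus validation_steps:
    model.fit(x, y, validation_data=make_val_dataset, validation_steps=10)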
def _make_train_function(self):
if not self.train_function:
self.train_function = TPUFunction(
- self, model_fn_lib.ModeKeys.TRAIN, strategy=self._strategy)
+ self,
+ model_fn_lib.ModeKeys.TRAIN,
+ tpu_assignment=self._tpu_assignment)
return self.train_function
def _make_test_function(self):
if not self.test_function:
self.test_function = TPUFunction(
- self, model_fn_lib.ModeKeys.EVAL, strategy=self._strategy)
+ self, model_fn_lib.ModeKeys.EVAL, tpu_assignment=self._tpu_assignment)
return self.test_function
def _make_predict_function(self):
if not self.predict_function:
self.predict_function = TPUFunction(
- self, model_fn_lib.ModeKeys.PREDICT, strategy=self._strategy)
+ self,
+ model_fn_lib.ModeKeys.PREDICT,
+ tpu_assignment=self._tpu_assignment)
return self.predict_function
def _initialize_weights(self, cloned_model):
@@ -1085,7 +1851,7 @@ class KerasTPUModel(models.Model):
@contextlib.contextmanager
def tpu_session(self):
"""Yields a TPU session and sets it as the default Keras session."""
- with self._graph.as_default():
+ with self._session.graph.as_default():
default_session = K.get_session()
# N.B. We have to call `K.set_session()` AND set our session as the
# TF default. `K.get_session()` surprisingly does not return the value
@@ -1103,6 +1869,7 @@ class KerasTPUModel(models.Model):
self._session.close()
+# pylint: disable=bad-continuation
def _validate_shapes(model):
"""Validate that all layers in `model` have constant shape."""
for layer in model.layers:
@@ -1130,14 +1897,17 @@ Layer: %(layer)s
Input shape: %(input_shape)s
Output shape: %(output_shape)s
""" % {
- 'layer': layer,
- 'input_shape': layer.input_shape,
- 'output_shape': layer.output_shape
- })
+ 'layer': layer,
+ 'input_shape': layer.input_shape,
+ 'output_shape': layer.output_shape
+ })
+
+
+# pylint: enable=bad-continuation
@experimental
-def tpu_model(model, tpu_name_or_address=None, strategy=None):
+def tpu_model(model, strategy=None):
"""Copy `model` along with weights to the TPU. Returns a TPU model.
Usage:
@@ -1148,7 +1918,7 @@ def tpu_model(model, tpu_name_or_address=None, strategy=None):
# If `num_cores_per_host` is greater than one, batch parallelism will be used
# to run on multiple TPU cores.
- strategy = keras_support.TPUDistributionStrategy(num_cores_per_host=8)
+ strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
model = keras_support.tpu_model(model, strategy)
model.compile(
optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0),
@@ -1158,10 +1928,6 @@ def tpu_model(model, tpu_name_or_address=None, strategy=None):
Args:
model: A `KerasTPUModel`.
- tpu_name_or_address: A string that is either the name of the Cloud TPU,
- the grpc address of the Cloud TPU, or (Googlers only) the BNS name of the
- Cloud TPU. If tpu_name_or_address is None, the TPUClusterResolver will
- examine the environment to determine a potential Cloud TPU to use.
strategy: `TPUDistributionStrategy`. The strategy to use for replicating
model across multiple TPU cores.
@@ -1176,9 +1942,13 @@ def tpu_model(model, tpu_name_or_address=None, strategy=None):
# TODO(xiejw): Validate TPU model. TPUModel only?
# TODO(xiejw): Validate replicas. Full or 1. Shall we allow subset?
# TODO(xiejw): Adds reduction option.
+
if strategy is None:
- strategy = TPUDistributionStrategy(num_cores_per_host=1)
- return KerasTPUModel(
- cpu_model=model,
- tpu_name_or_address=tpu_name_or_address,
- strategy=strategy)
+ strategy = TPUDistributionStrategy()
+ else:
+ if not isinstance(strategy, TPUDistributionStrategy):
+ raise TypeError(
+ '`strategy` must have type `tf.contrib.tpu.TPUDistributionStrategy`. '
+ 'Got: {}'.format(type(strategy)))
+
+ return KerasTPUModel(cpu_model=model, strategy=strategy)
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
new file mode 100644
index 0000000000..170977d8ab
--- /dev/null
+++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
@@ -0,0 +1,287 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Distributed variable implementation for TPUs.
+
+N.B. This is an experimental feature that should only be used for Keras support.
+
+It is unsupported and will be removed in favor of Distribution Strategy soon.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+
+from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.ops import variable_scope
+
+
+@contextlib.contextmanager
+def _handle_graph(handle):
+ with handle.graph.as_default():
+ yield
+
+
+def _enclosing_tpu_context():
+ # pylint: disable=protected-access
+ context = ops.get_default_graph()._get_control_flow_context()
+ # pylint: enable=protected-access
+ while context is not None and not isinstance(
+ context, control_flow_ops.XLAControlFlowContext):
+ context = context.outer_context
+ return context
+
+
+class ReplicatedVariable(object):
+ """A replicated variable for use on TPUs.
+
+ When accessed inside a tpu.replicate() context, this variable acts as if it
+ is a single variable whose handle is a replicated input to the computation.
+
+ Outside a tpu.replicate() context, this object currently has fairly murky
+ semantics, especially with respect to:
+ * initialization
+ * colocation
+ """
+
+ def __init__(self, name, variables):
+ self._name = name
+ self._primary_var = variables[0]
+ self._vars = variables
+ self._cached_value = None
+ self._dtype = variables[0].dtype
+
+ @property
+ def handle(self):
+ tpu_context = _enclosing_tpu_context()
+ if tpu_context is None:
+ return self._primary_var.handle
+
+ return tpu_context.get_replicated_var_handle(self)
+
+ @contextlib.contextmanager
+ def _assign_dependencies(self):
+ """Makes assignments depend on the cached value, if any.
+
+ This prevents undefined behavior with reads not ordered wrt writes.
+
+ Yields:
+ None.
+ """
+ if self._cached_value is not None:
+ with ops.control_dependencies([self._cached_value]):
+ yield
+ else:
+ yield
+
+ @property
+ def initializer(self):
+ return control_flow_ops.group([v.initializer for v in self._vars])
+
+ @property
+ def graph(self):
+ return self._primary_var.graph
+
+ @property
+ def _shared_name(self):
+ return self._common_name
+
+ @property
+ def _unique_id(self):
+ return self._primary_var._unique_id # pylint: disable=protected-access
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def dtype(self):
+ return self._primary_var.dtype
+
+ @property
+ def shape(self):
+ return self._primary_var.shape
+
+ def get_shape(self):
+ return self._primary_var.get_shape()
+
+ def to_proto(self, export_scope=None):
+ return self._primary_var.to_proto(export_scope=export_scope)
+
+ @property
+ def constraint(self):
+ return None
+
+ @property
+ def op(self):
+ return self.get().op
+
+ @property
+ def is_tensor_like(self):
+ return True
+
+ def _read_variable_op(self):
+ if _enclosing_tpu_context() is None:
+ return self._primary_var.read_value()
+ v = gen_resource_variable_ops.read_variable_op(self.handle, self._dtype)
+ return v
+
+ def read_value(self):
+ return self._read_variable_op()
+
+ def is_initialized(self, name=None):
+ return self._vars[0].is_initialized(name=name)
+
+ def __getitem__(self, *args):
+ return self.read_value().__getitem__(*args)
+
+ def assign(self, value, use_locking=None, name=None, read_value=False):
+ """Assign `value` to all replicas.
+
+ Outside of the tpu.rewrite context, assign explicitly to all replicas.
+ Inside of the tpu.rewrite context, assigns to the local replica.
+
+ Arguments:
+ value: Tensor to assign.
+ use_locking: Ignored.
+ name: Ignored.
+ read_value: Whether to return the value from the assignment.
+
+ Returns:
+ The assignment operation, or the new value of the variable if
+ `read_value` is True.
+ """
+ del use_locking
+ if _enclosing_tpu_context() is None:
+ assign_ops = []
+ with self._assign_dependencies():
+ for var in self._vars:
+ assign_ops.append(var.assign(value, use_locking=None, name=name))
+
+ if read_value:
+ with ops.control_dependencies(assign_ops):
+ return self.read_value()
+ else:
+ return control_flow_ops.group(assign_ops)
+
+ with _handle_graph(self.handle), self._assign_dependencies():
+ value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
+ assign_op = gen_resource_variable_ops.assign_variable_op(
+ self.handle, value_tensor, name=name)
+ if read_value:
+ return self._read_variable_op()
+ return assign_op
+
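For illustration (not part of this patch): assignment semantics of
ReplicatedVariable as implemented above; `rep_var`, `new_value` and `sess`
are hypothetical.

    # Illustrative sketch of ReplicatedVariable.assign.
    # Outside tpu.rewrite()/tpu.replicate(): fans out one assign per
    # per-core variable and groups them. Inside: assigns through the
    # replicated handle of the local core.
    update = rep_var.assign(new_value, read_value=False)
    sess.run(update)
    # With read_value=True, the new value is returned instead of the op.
    new_val = sess.run(rep_var.assign(new_value, read_value=True))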
+ def assign_add(self, delta, use_locking=None, name=None, read_value=True):
+ del use_locking
+ with _handle_graph(self.handle), self._assign_dependencies():
+ assign_add_op = gen_resource_variable_ops.assign_add_variable_op(
+ self.handle,
+ ops.convert_to_tensor(delta, dtype=self.dtype),
+ name=name)
+ if read_value:
+ return self._read_variable_op()
+ return assign_add_op
+
+ def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
+ del use_locking
+ with _handle_graph(self.handle), self._assign_dependencies():
+ assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(
+ self.handle,
+ ops.convert_to_tensor(delta, dtype=self.dtype),
+ name=name)
+ if read_value:
+ return self._read_variable_op()
+ return assign_sub_op
+
+ def get(self):
+ return self._primary_var
+
+ def _should_act_as_resource_variable(self):
+ """Pass resource_variable_ops.is_resource_variable check."""
+ pass
+
+ def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
+ """Converts a variable to a tensor."""
+ # pylint: disable=protected-access
+ if _enclosing_tpu_context() is None:
+ return self._primary_var._dense_var_to_tensor(dtype, name, as_ref)
+ # pylint: enable=protected-access
+ if dtype is not None and dtype != self.dtype:
+ return NotImplemented
+ if as_ref:
+ return self.handle
+ else:
+ return self.read_value()
+
+
+# Register a conversion function which reads the value of the variable,
+# allowing instances of the class to be used as tensors.
+def _tensor_conversion(var, dtype=None, name=None, as_ref=False):
+ return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access
+
+
+def replicated_fetch_function(var):
+ # pylint: disable=protected-access
+ return ([var._dense_var_to_tensor()], lambda v: v[0])
+ # pylint: enable=protected-access
+
+
+ops.register_tensor_conversion_function(ReplicatedVariable, _tensor_conversion)
+ops.register_dense_tensor_like_type(ReplicatedVariable)
+session_lib.register_session_run_conversion_functions(
+ ReplicatedVariable, replicated_fetch_function)
+
+
+def replicated_scope(num_replicas):
+ """Variable scope for constructing replicated variables."""
+
+ def _replicated_variable_getter(getter, name, *args, **kwargs):
+ """Getter that constructs replicated variables."""
+ collections = kwargs.pop("collections", None)
+ if collections is None:
+ collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+ kwargs["collections"] = []
+
+ variables = []
+ index = {}
+ for i in range(num_replicas):
+ replica_name = "{}/{}".format(name, i)
+ with ops.device("device:TPU:{}".format(i)):
+ v = getter(*args, name=replica_name, **kwargs)
+ variables.append(v)
+ index[i] = v
+ result = ReplicatedVariable(name, variables)
+
+ g = ops.get_default_graph()
+ # If "trainable" is True, next_creator() will add the member variables
+ # to the TRAINABLE_VARIABLES collection, so we manually remove
+ # them and replace with the MirroredVariable. We can't set
+ # "trainable" to False for next_creator() since that causes functions
+ # like implicit_gradients to skip those variables.
+ if kwargs.get("trainable", True):
+ collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+ l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
+ for v in index.values():
+ if v in l:
+ l.remove(v)
+ g.add_to_collections(collections, result)
+
+ return result
+
+ return variable_scope.variable_scope(
+ "", custom_getter=_replicated_variable_getter)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 7fa06d6d56..0f9f7cd91b 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -42,9 +42,9 @@ _BLACKLISTED_OPS = set([
"Placeholder",
])
-# These operations will currently fail to compile, but we should be able to
-# support them eventually via CPU offload or extending our operation set.
-_NOT_IMPLEMENTED_OPS = set([
+# XLA doesn't currently support reading of intermediate tensors, thus some ops
+# are not supported.
+_UNSUPPORTED_OPS = set([
"AudioSummary",
"AudioSummaryV2",
"HistogramSummary",
@@ -78,10 +78,10 @@ def initialize_system(embedding_config=None, job=None):
embedding_config: If not None, an `EmbeddingLayerConfiguration` proto
describing the desired configuration of the hardware embedding lookup
tables. If embedding_config is None, no hardware embeddings can be used.
- job: The job (the XXX in TensorFlow device specification /job:XXX)
- that contains the TPU devices that will be initialized. If job=None
- it is assumed there is only one job in the TensorFlow flock, and an
- error will be returned if this assumption does not hold.
+ job: The job (the XXX in TensorFlow device specification /job:XXX) that
+ contains the TPU devices that will be initialized. If job=None it is
+ assumed there is only one job in the TensorFlow flock, and an error will
+ be returned if this assumption does not hold.
Returns:
A serialized `TopologyProto` that describes the TPU system. Note:
the topology must be evaluated using `Session.run` before it can be used.
@@ -118,9 +118,9 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
tpu.replicate() computation with the attribute "_tpu_replicate=XYZ", where XYZ
is a unique name.
- We use a `ControlFlowContext` to perform the annotation since it
- integrates with Tensorflow constructs like ResourceVariables. For example,
- if a `ResourceVariable` is constructed inside a tpu.replicate() block, the
+ We use a `ControlFlowContext` to perform the annotation since it integrates
+ with Tensorflow constructs like ResourceVariables. For example, if a
+ `ResourceVariable` is constructed inside a tpu.replicate() block, the
`ResourceVariable` implementation can use
`with ops.control_dependencies(None)` to build the variable's definition
outside the replicated computation.
@@ -149,6 +149,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._gradient_colocation_stack = []
self._host_compute_core = []
self._name = name
+ self._name_as_bytes = compat.as_bytes(name)
self._unsupported_ops = []
self._pivot = pivot
self._replicated_vars = {}
@@ -156,8 +157,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
def get_replicated_var_handle(self, var):
"""Returns a variable handle for replicated TPU variable 'var'.
- This is an method used by an experimental replicated variable
- implementation and is not intended as a public API.
+ This is a method used by an experimental replicated variable implementation
+ and is not intended as a public API.
Args:
var: The replicated TPU variable.
@@ -210,28 +211,24 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
if gradient_uid == "__unsupported__":
raise NotImplementedError(
"No gradient_uid calling gradient within outside_compilation")
- # When we take the gradient of an op X in an
- # outside_compilation cluster C in a forward computation we
- # would like to put the ops corresponding to the gradient of
- # X into a new outside_compilation cluster C'. However, if
- # we take the gradient of X twice, the second one should get
- # yet another new outside_compilation cluster C''.
+ # When we take the gradient of an op X in an outside_compilation
+ # cluster C in a forward computation we would like to put the ops
+ # corresponding to the gradient of X into a new outside_compilation
+ # cluster C'. However, if we take the gradient of X twice, the second
+ # one should get yet another new outside_compilation cluster C''.
#
- # The mechanism we adopt is to use a 'root_cluster' which is
- # the cluster that X was in before we took gradients, and a
- # 'gradient_uid' which is different for every invocation of
- # gradients, and put the gradient of X in cluster
- # 'root_cluster.gradient_uid'.
+ # The mechanism we adopt is to use a 'root_cluster' which is the
+ # cluster that X was in before we took gradients, and a 'gradient_uid'
+ # which is different for every invocation of gradients, and put the
+ # gradient of X in cluster 'root_cluster.gradient_uid'.
#
- # When taking a gradient of a gradient, some ops will be
- # colocated with Op in the forward pass (e.g., cluster
- # root_cluster) and some in the backward pass (e.g., cluster
- # root_cluster.initial_gradient_uid). We need all of the
- # grad-of-grad ops to be in the same cluster to avoid cyclic
- # dependencies between clusters. We adopt a heuristic that
- # puts any op clustered with root_cluster.<xxx> in
- # root_cluster.gradient_uid, even if xxx was
- # initial_gradient_uid.
+ # When taking a gradient of a gradient, some ops will be colocated
+ # with Op in the forward pass (e.g., cluster root_cluster) and some in
+ # the backward pass (e.g., cluster root_cluster.initial_gradient_uid).
+ # We need all of the grad-of-grad ops to be in the same cluster to
+ # avoid cyclic dependencies between clusters. We adopt a heuristic
+ # that puts any op clustered with root_cluster.<xxx> in
+ # root_cluster.gradient_uid, even if xxx was initial_gradient_uid.
self._in_gradient_colocation = op
parts = outside_attr.split(".")
cluster = parts[0] + "." + gradient_uid
@@ -323,16 +320,13 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
return self._host_compute_core
def AddOp(self, op):
- self._AddOpInternal(op)
-
- def _AddOpInternal(self, op):
# pylint: disable=protected-access
if op.type in _BLACKLISTED_OPS:
logging.error("Operation of type %s (%s) is not supported on the TPU. "
"Execution will fail if this op is used in the graph. " %
(op.type, op.name))
- if op.type in _NOT_IMPLEMENTED_OPS:
+ if op.type in _UNSUPPORTED_OPS:
self._unsupported_ops.append(op)
if any(x.dtype._is_ref_dtype for x in op.inputs):
@@ -342,7 +336,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
if _TPU_REPLICATE_ATTR in op.node_def.attr:
raise ValueError("TPU computations cannot be nested")
op._set_attr(_TPU_REPLICATE_ATTR,
- attr_value_pb2.AttrValue(s=compat.as_bytes(self._name)))
+ attr_value_pb2.AttrValue(s=self._name_as_bytes))
if self._outside_compilation_cluster:
op._set_attr(
_OUTSIDE_COMPILATION_ATTR,
@@ -356,11 +350,12 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
# Remove any control edges from outer control flow contexts. These may cause
# mismatched frame errors.
- control_inputs, external_inputs = self._RemoveExternalControlEdges(op)
+ (internal_control_inputs,
+ external_control_inputs) = self._RemoveExternalControlEdges(op)
if not op.inputs:
# Add a control edge from the control pivot to this op.
- if not control_inputs:
+ if not internal_control_inputs:
# pylint: disable=protected-access
op._add_control_input(self.GetControlPivot())
# pylint: enable=protected-access
@@ -371,19 +366,19 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
if real_x != x:
op._update_input(index, real_x) # pylint: disable=protected-access
- if external_inputs:
+ if external_control_inputs:
# Use an identity to pull control inputs as data inputs. Note that we
# ignore ops which don't have outputs. TODO(phawkins): fix that.
with ops.control_dependencies(None):
self.Enter()
- external_inputs = [
+ external_control_inputs = [
array_ops.identity(x.outputs[0]).op
- for x in external_inputs
+ for x in external_control_inputs
if x.outputs
]
self.Exit()
# pylint: disable=protected-access
- op._add_control_inputs(external_inputs)
+ op._add_control_inputs(external_control_inputs)
# pylint: enable=protected-access
# Mark op's outputs as seen by this context and any outer contexts.
@@ -399,6 +394,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._outer_context.AddInnerOp(op)
def AddValue(self, val):
+ """Add `val` to the current context and its outer context recursively."""
if val.name in self._values:
# Use the real value if it comes from outer context.
result = self._external_values.get(val.name)
@@ -415,7 +411,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
return result
def AddInnerOp(self, op):
- self._AddOpInternal(op)
+ self.AddOp(op)
if self._outer_context:
self._outer_context.AddInnerOp(op)
@@ -656,13 +652,31 @@ def split_compile_and_replicate(computation,
# TODO(phawkins): consider removing this code. It will
# be less confusing to clients if they knowingly choose to use resource
# variables.
+ # Partitioned variables is not supported (b/112311320).
+ def custom_getter(getter, name, *args, **kwargs):
+ """Variables on TPU have a few restrictions."""
+ partitioner = kwargs["partitioner"]
+ if partitioner is not None:
+ kwargs["partitioner"] = None
+ logging.warning(
+ "Partitioned variables are not supported on TPU. Got "
+ "`partitioner` that is {} for variable {}. "
+ "Setting `partitioner` to `None`."
+ .format(partitioner, name))
+ return getter(name, *args, **kwargs)
+
vscope = variable_scope.get_variable_scope()
+
saved_use_resource = vscope.use_resource
+ saved_custom_getter = vscope.custom_getter
+
vscope.set_use_resource(True)
+ vscope.set_custom_getter(custom_getter)
outputs = computation(*computation_inputs)
vscope.set_use_resource(saved_use_resource)
+ vscope.set_custom_getter(saved_custom_getter)
# If the computation returns `None`, make it an empty tuple.
if outputs is None:
@@ -765,11 +779,10 @@ def shard(computation,
name=None):
"""Shards `computation` for parallel execution.
- `inputs` must be a list of Tensors or None (equivalent to an empty
- list), each of which has a corresponding split axis (from
- `input_shard_axes`). Each input is split into `num_shards` pieces
- along the corresponding axis, and computation is applied to each
- shard in parallel.
+ `inputs` must be a list of Tensors or None (equivalent to an empty list), each
+ of which has a corresponding split axis (from `input_shard_axes`). Each input
+ is split into `num_shards` pieces along the corresponding axis, and
+ computation is applied to each shard in parallel.
Tensors are broadcast to all shards if they are lexically captured by
`computation`. e.g.,
@@ -791,10 +804,9 @@ def shard(computation,
Args:
computation: A Python function that builds a computation to apply to each
shard of the input.
- inputs: A list of input tensors or None (equivalent to an empty
- list). Each input tensor has a corresponding shard axes, given
- by `input_shard_axes`, which must have size divisible by
- `num_shards`.
+ inputs: A list of input tensors or None (equivalent to an empty list). Each
+ input tensor has a corresponding shard axes, given by `input_shard_axes`,
+ which must have size divisible by `num_shards`.
num_shards: The number of shards.
input_shard_axes: A list of dimensions along which to shard `inputs`, or
`None`. `None` means "shard all inputs along dimension 0". If not `None`,
@@ -913,9 +925,9 @@ def batch_parallel(computation,
Convenience wrapper around shard().
- `inputs` must be a list of Tensors or None (equivalent to an empty
- list). Each input is split into `num_shards` pieces along the 0-th
- dimension, and computation is applied to each shard in parallel.
+ `inputs` must be a list of Tensors or None (equivalent to an empty list).
+ Each input is split into `num_shards` pieces along the 0-th dimension, and
+ computation is applied to each shard in parallel.
Tensors are broadcast to all shards if they are lexically captured by
`computation`. e.g.,
@@ -933,9 +945,8 @@ def batch_parallel(computation,
Args:
computation: A Python function that builds a computation to apply to each
shard of the input.
- inputs: A list of input tensors or None (equivalent to an empty
- list). The 0-th dimension of each Tensor must have size
- divisible by `num_shards`.
+ inputs: A list of input tensors or None (equivalent to an empty list). The
+ 0-th dimension of each Tensor must have size divisible by `num_shards`.
num_shards: The number of shards.
infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
of arguments as inputs to `computation`.
@@ -968,14 +979,14 @@ def rewrite(computation,
"""Rewrites `computation` for execution on a TPU system.
Args:
- computation: A Python function that builds a computation to apply
- to the input. If the function takes n inputs, 'inputs' should be
- a list of n tensors.
+ computation: A Python function that builds a computation to apply to the
+ input. If the function takes n inputs, 'inputs' should be a list of n
+ tensors.
- `computation` may return a list of operations and tensors. Tensors must
+ `computation` may return a list of operations and tensors. Tensors must
come before operations in the returned list. The return value of
`rewrite` is a list of tensors corresponding to the tensors from the
- from `computation`.
+ output of `computation`.
All `Operation`s returned from `computation` will be executed when
evaluating any of the returned output tensors.
@@ -1070,12 +1081,12 @@ class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext):
def validate_inference_rewrite_for_variables(graph):
"""Validates whether rewrite_for_inference() 'worked' for variables.
- The rewrite_for_inference() method is supposed to append
- GuaranteeConstOps after ReadVariableOps, but this mechanism works only
- if you are using tf.get_variable() to create and access variables in your
- tpu computation. This validation method can be called immediately after
- calling tpu.rewrite_for_inference() to check whether GuaranteeConstOps
- where added to the graph.
+ The rewrite_for_inference() method is supposed to append GuaranteeConstOps
+ after ReadVariableOps, but this mechanism works only if you are using
+ tf.get_variable() to create and access variables in your tpu computation.
+ This validation method can be called immediately after calling
+ tpu.rewrite_for_inference() to check whether GuaranteeConstOps where added
+ to the graph.
Typical usages:
tpu.validate_inference_rewrite_for_variables(tf.get_default_graph())
@@ -1089,10 +1100,9 @@ def validate_inference_rewrite_for_variables(graph):
"""
if not any([x.type == "GuaranteeConst" for x in graph.get_operations()]):
raise RuntimeError(
- "No GuaranteeConst ops found in the graph after "
- "running tpu.rewrite_for_inference(...). Please "
- "check that you are using tf.get_variable() to "
- "create and access variables in your tpu "
+ "No GuaranteeConst ops found in the graph after running "
+ "tpu.rewrite_for_inference(...). Please check that you are using "
+ "tf.get_variable() to create and access variables in your tpu "
"computation.")
@@ -1108,16 +1118,16 @@ def rewrite_for_inference(computation,
in your computation, it moves the ReadVariableOps outside the TPU
computation, and adds GuaranteeConst ops just after the ReadVariableOps.
This mechanism works only if you are using tf.get_variable() to create and
- access variables in your tpu computation. You can validate whether
- this worked, by calling validate_inference_rewrite_for_variables() method
+ access variables in your tpu computation. You can validate whether this
+ worked by calling the validate_inference_rewrite_for_variables() method
 immediately after this method to check whether GuaranteeConstOps were
added to the graph.
Args:
- computation: A Python function that builds a computation to apply
- to the input. If the function takes n inputs, 'inputs' should be
- a list of n tensors. If the function returns m outputs, rewrite
- will return a list of m tensors.
+ computation: A Python function that builds a computation to apply to the
+ input. If the function takes n inputs, 'inputs' should be a list of n
+ tensors. If the function returns m outputs, rewrite will return a list of
+ m tensors.
inputs: A list of input tensors or `None` (equivalent to an empty list).
infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
of arguments as inputs to `computation`.
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
index 8d05e081a7..18e0abdda2 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
@@ -65,7 +65,7 @@ class TPUConfig(
The number of model replicas in the system. For non-model-parallelism
case, this number equals the total number of TPU cores. For
model-parallelism, the total number of TPU cores equals
- product(computation_shape) * num_shards.
+ num_cores_per_replica * num_shards.
num_cores_per_replica: Defaults to `None`, which disables model parallelism.
An integer which describes the number of TPU cores per model replica. This
is required by model-parallelism which enables partitioning
@@ -103,7 +103,7 @@ class TPUConfig(
input mode.
Raises:
- ValueError: If `computation_shape` or `computation_shape` are invalid.
+ ValueError: If `num_cores_per_replica` is not 1, 2, 4 or 8.
"""
def __new__(cls,
@@ -137,7 +137,7 @@ class TPUConfig(
raise ValueError(
'input_partition_dims requires setting num_cores_per_replica.')
- # Parse computation_shape
+ # Check num_cores_per_replica
if num_cores_per_replica is not None:
if num_cores_per_replica not in [1, 2, 4, 8]:
raise ValueError(
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index 806ae1c4c9..19359cb612 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -390,12 +390,6 @@ class _InternalTPUContext(object):
logging.info('_is_running_on_cpu: eval_on_tpu disabled')
return True
- if mode != model_fn_lib.ModeKeys.PREDICT:
- return False
-
- # There are actually 2 use cases when running with mode.PREDICT: prediction
- # and saving the model. We run actual predictions on the TPU, but
- # model export is run on the CPU.
if is_export_mode:
return True
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index f221155568..1ff04f5c26 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -762,9 +762,13 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
if not is_dataset:
raise TypeError('`input_fn` must return a `Dataset` for the PER_HOST_V2 '
'input pipeline configuration.')
+
if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
- # TODO(b/XXX): Add predict support for PER_HOST_V2
- raise TypeError('Most PREDICT not yet supported in PER_HOST_V2 mode.')
+ inputs = _InputsWithStoppingSignals(
+ dataset=inputs.dataset,
+ batch_size=ctx.batch_size_for_input_fn,
+ add_padding=True,
+ num_invocations_per_step=ctx.num_of_replicas_per_host)
hooks.append(inputs.dataset_initializer_hook())
tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)
@@ -774,6 +778,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
control_deps = []
per_host_sharded_inputs = []
num_replicas_per_host = ctx.num_of_replicas_per_host
+ cached_signals = None
with ops.device(device):
if not inputs.is_dataset:
raise TypeError('`input_fn` must return a `Dataset` for this mode.')
@@ -781,21 +786,32 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
# Use control dependencies to ensure a deterministic ordering.
with ops.control_dependencies(control_deps):
features, labels = inputs.features_and_labels() # Calls get_next()
+ signals = inputs.signals()
+
+ # All the replicas share the replica 0's stopping signal.
+ # This avoids inconsistent state among different model replicas.
+ if cached_signals:
+ signals['stopping'] = cached_signals['stopping']
+ else:
+ cached_signals = signals
inputs_structure_recorder.validate_and_record_structure(
features, labels)
flattened_inputs = (
inputs_structure_recorder.flatten_features_and_labels(
- features, labels))
+ features, labels, signals))
control_deps.extend(flattened_inputs)
per_host_sharded_inputs.append(flattened_inputs)
if inputs_structure_recorder.flattened_input_dims:
+ input_partition_dims = inputs_structure_recorder.flattened_input_dims
+ if signals:
+ input_partition_dims += [None] * len(signals)
# pylint: disable=protected-access
infeed_queue = tpu_feed._PartitionedInfeedQueue(
number_of_tuple_elements=len(per_host_sharded_inputs[0]),
host_id=host_id,
- input_partition_dims=inputs_structure_recorder.flattened_input_dims,
+ input_partition_dims=input_partition_dims,
device_assignment=ctx.device_assignment)
per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
per_host_sharded_inputs)
@@ -807,7 +823,13 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
tpu_ordinal_function=tpu_ordinal_function_impl)
captured_infeed_queue.capture(infeed_queue)
- return per_host_enqueue_ops
+ if signals is None:
+ return per_host_enqueue_ops
+ else:
+ return {
+ 'ops': per_host_enqueue_ops,
+ 'signals': signals,
+ }
return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset
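The signal-sharing loop in the hunk above boils down to a simple caching
pattern; a self-contained sketch with plain dicts standing in for the real
tensor signals (names are illustrative):

```python
def share_replica_zero_stopping(per_replica_signals):
  """Makes every replica reuse replica 0's 'stopping' entry, so all
  replicas in a step agree on when to stop."""
  cached_signals = None
  for signals in per_replica_signals:
    if cached_signals:
      signals['stopping'] = cached_signals['stopping']
    else:
      cached_signals = signals
  return per_replica_signals

shared = share_replica_zero_stopping([{'stopping': 0}, {'stopping': 1}])
assert shared[1]['stopping'] == shared[0]['stopping'] == 0
```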
@@ -2124,9 +2146,10 @@ class TPUEstimator(estimator_lib.Estimator):
mode=model_fn_lib.ModeKeys.PREDICT,
export_tags=None,
check_variables=True):
- if mode != model_fn_lib.ModeKeys.PREDICT:
+ if self._export_to_tpu and mode != model_fn_lib.ModeKeys.PREDICT:
raise NotImplementedError(
- 'TPUEstimator only handles mode PREDICT for export_savedmodel(); '
+ 'TPUEstimator only handles mode PREDICT for exporting '
+ 'when `export_to_tpu` is `True`; '
'got {}.'.format(mode))
(super(TPUEstimator, self).
@@ -2424,16 +2447,12 @@ class TPUEstimator(estimator_lib.Estimator):
with self._ctx.with_mode(mode) as ctx:
model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx)
- if mode != model_fn_lib.ModeKeys.PREDICT:
+ # `input_fn` is called in `train()`, `evaluate()`, and `predict()`,
+ # but not in `export_savedmodel()`.
+ if self._is_input_fn_invoked:
is_export_mode = False
else:
- # For export_savedmodel, input_fn is never passed to Estimator. So, by
- # checking the self._is_input_fn_invoked bit, we can know, given the
- # mode == PREDICT, it is the .predict API, not export_savedmodel API.
- if self._is_input_fn_invoked:
- is_export_mode = False
- else:
- is_export_mode = True
+ is_export_mode = True
# Clear the bit.
self._is_input_fn_invoked = None
@@ -2805,8 +2824,6 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
"""Executes `model_fn_wrapper` multiple times on all TPU shards."""
- num_cores = ctx.num_cores
-
(single_tpu_predict_step, host_calls, captured_scaffold_fn,
captured_predict_hooks
) = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn)
@@ -2825,7 +2842,7 @@ def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
(dummy_predict_op,) = tpu.shard(
multi_tpu_predict_steps_on_single_shard,
inputs=[],
- num_shards=num_cores,
+ num_shards=ctx.num_replicas,
outputs_from_all_shards=False,
device_assignment=ctx.device_assignment)
@@ -3043,16 +3060,48 @@ class _Inputs(object):
class _InputsWithStoppingSignals(_Inputs):
"""Inputs with `_StopSignals` inserted into the dataset."""
- def __init__(self, dataset, batch_size, add_padding=False):
+ def __init__(self,
+ dataset,
+ batch_size,
+ add_padding=False,
+ num_invocations_per_step=1):
assert dataset is not None
-
user_provided_dataset = dataset.map(
_InputsWithStoppingSignals.insert_stopping_signal(
stop=False, batch_size=batch_size, add_padding=add_padding))
- final_batch_dataset = dataset.take(1).map(
- _InputsWithStoppingSignals.insert_stopping_signal(
- stop=True, batch_size=batch_size, add_padding=add_padding))
+ if num_invocations_per_step == 1:
+ final_batch_dataset = dataset.take(1).map(
+ _InputsWithStoppingSignals.insert_stopping_signal(
+ stop=True, batch_size=batch_size, add_padding=add_padding))
+ else:
+ # We append (2 * num_invocations_per_step - 1) batches to exhaust the
+ # user_provided_dataset and stop properly.
+ # For example, if num_invocations_per_step is 2, we append 3 additional
+ # padding batches: b1, b2, b3.
+ # If user_provided_dataset contains two batches: a1, a2
+ # Step 1: [a1, a2]
+ # Step 2: [b1, b2] -> STOP
+ # If user_provided_dataset contains three batches: a1, a2, a3.
+ # The training loop runs:
+ # Step 1: [a1, a2]
+ # Step 2: [a3, b1]
+ # Step 3: [b2, b3] -> STOP.
+ final_batch_dataset = dataset.take(1).map(
+ _InputsWithStoppingSignals.insert_stopping_signal(
+ stop=True, batch_size=batch_size, add_padding=add_padding))
+ final_batch_dataset = final_batch_dataset.repeat(
+ 2 * num_invocations_per_step - 1)
+
+ def _set_mask(data_dict):
+ signals = data_dict['signals']
+ signals['padding_mask'] = array_ops.ones_like(signals['padding_mask'])
+ data_dict['signals'] = signals
+ return data_dict
+
+ # Mask out the extra batch.
+ final_batch_dataset = final_batch_dataset.map(_set_mask)
+
dataset = user_provided_dataset.concatenate(final_batch_dataset).prefetch(2)
super(_InputsWithStoppingSignals, self).__init__(dataset=dataset)
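The `2 * num_invocations_per_step - 1` count above can be sanity-checked
without TensorFlow; a pure-Python sketch:

```python
def num_padding_batches(num_invocations_per_step):
  # Enough stop=True batches to fill out the final step no matter how
  # the real data lands on step boundaries.
  return 2 * num_invocations_per_step - 1

# num_invocations_per_step == 2 appends b1, b2, b3:
#   2 real batches:  step 1 [a1, a2], step 2 [b1, b2] -> STOP
#   3 real batches:  step 1 [a1, a2], step 2 [a3, b1], step 3 [b2, b3] -> STOP
assert num_padding_batches(1) == 1
assert num_padding_batches(2) == 3
```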
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py
index 3e90957e6d..bd530fdc3a 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py
@@ -286,6 +286,59 @@ class TPUEstimatorStoppingSignalsWithPaddingTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(sliced_features)
+ def test_slice_with_multi_invocations_per_step(self):
+ num_samples = 3
+ batch_size = 2
+
+ params = {'batch_size': batch_size}
+ input_fn, (a, b) = make_input_fn(num_samples=num_samples)
+
+ with ops.Graph().as_default():
+ dataset = input_fn(params)
+ inputs = tpu_estimator._InputsWithStoppingSignals(
+ dataset, batch_size, add_padding=True, num_invocations_per_step=2)
+ hook = inputs.dataset_initializer_hook()
+ features, _ = inputs.features_and_labels()
+ signals = inputs.signals()
+
+ sliced_features = (
+ tpu_estimator._PaddingSignals.slice_tensor_or_dict(features, signals))
+
+ with session.Session() as sess:
+ hook.begin()
+ hook.after_create_session(sess, coord=None)
+
+ result, evaluated_signals = sess.run([sliced_features, signals])
+ self.assertAllEqual(a[:batch_size], result['a'])
+ self.assertAllEqual(b[:batch_size], result['b'])
+ self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
+
+ # This is the final partial batch.
+ result, evaluated_signals = sess.run([sliced_features, signals])
+ self.assertEqual(1, len(result['a']))
+ self.assertAllEqual(a[batch_size:num_samples], result['a'])
+ self.assertAllEqual(b[batch_size:num_samples], result['b'])
+ self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
+
+ # We should see 3 consecutive batches with STOP ('1') as signals and all
+ # of them have mask 1.
+ _, evaluated_signals = sess.run([sliced_features, signals])
+ self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
+ self.assertAllEqual([1.] * batch_size,
+ evaluated_signals['padding_mask'])
+
+ _, evaluated_signals = sess.run([sliced_features, signals])
+ self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
+ self.assertAllEqual([1.] * batch_size,
+ evaluated_signals['padding_mask'])
+
+ _, evaluated_signals = sess.run([sliced_features, signals])
+ self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
+ self.assertAllEqual([1.] * batch_size,
+ evaluated_signals['padding_mask'])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(sliced_features)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py
index 53d33f4077..1e11de6421 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py
@@ -19,7 +19,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import collections
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
@@ -44,8 +43,9 @@ class CrossShardOptimizer(optimizer.Optimizer):
reduction: The reduction to apply to the shard losses.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "CrossShardOptimizer".
- group_assignment: Optional list of group ids for applying the optimizer
- to subgroups.
+ group_assignment: Optional 2d int32 list with shape
+ [num_groups, num_replicas_per_group] which describes how to apply the
+ optimizer to subgroups.
Raises:
ValueError: If reduction is not a valid cross-shard reduction.
@@ -74,11 +74,22 @@ class CrossShardOptimizer(optimizer.Optimizer):
"""
if not group_assignment:
return None
- if len(group_assignment) != num_shards:
- raise ValueError("The size of group_assignment does not equal to "
- "num_shard({0}). Got group_assignment={1}".format(
- num_shards, self._group_assignment))
- subgroup_size_list = dict(collections.Counter(group_assignment)).values()
+ if not (isinstance(group_assignment, list) and
+ all(isinstance(i, list) for i in group_assignment)):
+ raise ValueError("group_assignment must be a list of list. Got {}".format(
+ group_assignment))
+
+ replica_ids = set()
+ for g in group_assignment:
+ for i in g:
+ replica_ids.add(i)
+
+ if set(range(num_shards)) != replica_ids:
+ raise ValueError("group_assignment must be a permutation of range({0})."
+ " Got group_assignment={1}".format(
+ num_shards, group_assignment))
+
+ subgroup_size_list = [len(group) for group in group_assignment]
if all(subgroup_size_list[0] == size for size in subgroup_size_list):
return subgroup_size_list[0]
else:
@@ -186,3 +197,7 @@ class CrossShardOptimizer(optimizer.Optimizer):
A list of strings.
"""
return self._opt.get_slot_names(*args, **kwargs)
+
+ def variables(self):
+ """Forwarding the variables from the underlying optimizer."""
+ return self._opt.variables()
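Standalone, the new group_assignment validation amounts to the following
(a sketch; the real code is a method on `CrossShardOptimizer`):

```python
def subgroup_size(group_assignment, num_shards):
  if not (isinstance(group_assignment, list) and
          all(isinstance(g, list) for g in group_assignment)):
    raise ValueError('group_assignment must be a list of lists.')
  replica_ids = {i for g in group_assignment for i in g}
  if replica_ids != set(range(num_shards)):
    raise ValueError('group_assignment must cover range(%d).' % num_shards)
  sizes = [len(g) for g in group_assignment]
  # Equal-sized groups yield a well-defined subgroup size.
  return sizes[0] if all(s == sizes[0] for s in sizes) else None

assert subgroup_size([[0, 1], [2, 3]], num_shards=4) == 2
```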
diff --git a/tensorflow/contrib/training/__init__.py b/tensorflow/contrib/training/__init__.py
index edd71fb250..3547e71184 100644
--- a/tensorflow/contrib/training/__init__.py
+++ b/tensorflow/contrib/training/__init__.py
@@ -14,7 +14,9 @@
# ==============================================================================
"""Training and input utilities.
-See @{$python/contrib.training} guide.
+See
+[Contrib Training](https://tensorflow.org/api_guides/python/contrib.training)
+guide.
@@batch_sequences_with_states
@@NextQueuedSequenceBatch
diff --git a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
index df07ff44ee..afeef978f3 100644
--- a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
+++ b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
@@ -108,7 +108,7 @@ class BatchSequencesWithStatesTest(test.TestCase):
expected_seq4_batch1, expected_seq4_batch2,
key=None, make_keys_unique=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
next_batch = sqss.batch_sequences_with_states(
input_key=key if key is not None else self.key,
input_sequences=self.sequences,
@@ -332,7 +332,7 @@ class BatchSequencesWithStatesTest(test.TestCase):
"seq4": self.sequences["seq4"],
}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
".*should be a multiple of: 3, but saw "
"value: 4. Consider setting pad=True."):
@@ -508,7 +508,7 @@ class BatchSequencesWithStatesTest(test.TestCase):
class PaddingTest(test.TestCase):
def testPaddingInvalidLengths(self):
- with ops.Graph().as_default() as g, self.test_session(graph=g):
+ with ops.Graph().as_default() as g, self.session(graph=g):
sequences = {
"key_1": constant_op.constant([1, 2, 3]), # length 3
"key_2": constant_op.constant([1.5, 2.5]) # length 2
@@ -520,7 +520,7 @@ class PaddingTest(test.TestCase):
padded_seq["key_1"].eval()
def testPadding(self):
- with ops.Graph().as_default() as g, self.test_session(graph=g):
+ with ops.Graph().as_default() as g, self.session(graph=g):
sequences = {
"key_1": constant_op.constant([1, 2]),
"key_2": constant_op.constant([0.5, -1.0]),
@@ -549,7 +549,7 @@ class PaddingTest(test.TestCase):
val2 = np.array([9, 12])
shape2 = np.array([5])
- with ops.Graph().as_default() as g, self.test_session(graph=g):
+ with ops.Graph().as_default() as g, self.session(graph=g):
sp_tensor1 = sparse_tensor.SparseTensor(
indices=array_ops.constant(ind1, dtypes.int64),
values=array_ops.constant(val1, dtypes.int64),
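The `test_session()` → `cached_session()` substitutions here and in the test
files below all follow one idiom; a minimal self-contained sketch (class and
test names are illustrative):

```python
import tensorflow as tf
from tensorflow.python.platform import test

class CachedSessionExample(test.TestCase):

  def testAddition(self):
    # cached_session() reuses a single session for the duration of the
    # test, avoiding the cost of spinning up a fresh session per call.
    with self.cached_session() as sess:
      self.assertEqual(3, sess.run(tf.constant(1) + tf.constant(2)))

if __name__ == '__main__':
  test.main()
```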
diff --git a/tensorflow/contrib/training/python/training/bucket_ops_test.py b/tensorflow/contrib/training/python/training/bucket_ops_test.py
index 504f1fcd41..b259e0ee83 100644
--- a/tensorflow/contrib/training/python/training/bucket_ops_test.py
+++ b/tensorflow/contrib/training/python/training/bucket_ops_test.py
@@ -112,7 +112,7 @@ class BucketTest(test.TestCase):
self.assertAllEqual(
[[32], [32, None], [32, 3], [None, None]],
[out.get_shape().as_list() for out in bucketed_dynamic[1]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for v in range(32):
self.enqueue_inputs(sess, {
self.scalar_int_feed: v,
@@ -162,7 +162,7 @@ class BucketTest(test.TestCase):
self.assertAllEqual(
[[None], [None, None], [None, 3], [None, None]],
[out.get_shape().as_list() for out in bucketed_dynamic[1]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for v in range(15):
self.enqueue_inputs(sess, {
self.scalar_int_feed: v,
@@ -204,7 +204,7 @@ class BucketTest(test.TestCase):
self.assertAllEqual(
[[32], [32, None], [32, 3], [None, None]],
[out.get_shape().as_list() for out in bucketed_dynamic[1]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for v in range(64):
self.enqueue_inputs(sess, {
self.scalar_int_feed: v,
@@ -286,7 +286,7 @@ class BucketTest(test.TestCase):
self.assertAllEqual(
[[32], [32, None], [32, 3]],
[out.get_shape().as_list() for out in bucketed_dynamic[1]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for v in range(128):
self.enqueue_inputs(sess, {
self.scalar_int_feed: v,
@@ -405,7 +405,7 @@ class BucketBySequenceLengthTest(test.TestCase):
num_pairs_to_enqueue - (batch_size - 1) * num_buckets,
num_pairs_dequeued)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
# Feed the inputs, then close the input thread.
diff --git a/tensorflow/contrib/training/python/training/evaluation.py b/tensorflow/contrib/training/python/training/evaluation.py
index 01bac891da..16a647bf66 100644
--- a/tensorflow/contrib/training/python/training/evaluation.py
+++ b/tensorflow/contrib/training/python/training/evaluation.py
@@ -296,6 +296,7 @@ class SummaryAtEndHook(session_run_hook.SessionRunHook):
def begin(self):
if self._replace_summary_op:
+ # This can still remain None if there are no summaries.
self._summary_op = summary.merge_all()
self._global_step = training_util.get_or_create_global_step()
@@ -304,10 +305,12 @@ class SummaryAtEndHook(session_run_hook.SessionRunHook):
self._summary_writer = summary.FileWriterCache.get(self._log_dir)
def end(self, session):
- global_step = training_util.global_step(session, self._global_step)
- summary_str = session.run(self._summary_op, self._feed_dict)
+ if self._summary_op is not None:
+ global_step = training_util.global_step(session, self._global_step)
+ summary_str = session.run(self._summary_op, self._feed_dict)
+ if self._summary_writer:
+ self._summary_writer.add_summary(summary_str, global_step)
if self._summary_writer:
- self._summary_writer.add_summary(summary_str, global_step)
self._summary_writer.flush()
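The new guard exists because `tf.summary.merge_all()` returns `None` when no
summaries were created in the graph; a minimal sketch of the behavior the
hook now tolerates:

```python
import tensorflow as tf

with tf.Graph().as_default():
  # No summary ops created, so merge_all() has nothing to merge.
  assert tf.summary.merge_all() is None
  # The hook therefore skips running the summary op in end() and only
  # flushes the writer, which still records the GraphDef event.
```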
diff --git a/tensorflow/contrib/training/python/training/evaluation_test.py b/tensorflow/contrib/training/python/training/evaluation_test.py
index c36d00e842..ddd135f047 100644
--- a/tensorflow/contrib/training/python/training/evaluation_test.py
+++ b/tensorflow/contrib/training/python/training/evaluation_test.py
@@ -67,7 +67,7 @@ class CheckpointIteratorTest(test.TestCase):
global_step = variables.get_or_create_global_step()
saver = saver_lib.Saver() # Saves the global step.
- with self.test_session() as session:
+ with self.cached_session() as session:
session.run(variables_lib.global_variables_initializer())
save_path = os.path.join(checkpoint_dir, 'model.ckpt')
saver.save(session, save_path, global_step=global_step)
@@ -427,9 +427,11 @@ class EvaluateRepeatedlyTest(test.TestCase):
names_to_updates = {'Accuracy': update_op0, 'Another_accuracy': update_op1}
return names_to_values, names_to_updates
- def _verify_summaries(self, output_dir, names_to_values):
+ def _verify_events(self, output_dir, names_to_values):
"""Verifies that the given `names_to_values` are found in the summaries.
+ Also checks that a GraphDef was written out to the events file.
+
Args:
output_dir: An existing directory where summaries are found.
names_to_values: A dictionary of strings to values.
@@ -440,7 +442,13 @@ class EvaluateRepeatedlyTest(test.TestCase):
self.assertEqual(len(output_filepath), 1)
events = summary_iterator.summary_iterator(output_filepath[0])
- summaries = [e.summary for e in events if e.summary.value]
+ summaries = []
+ graph_def = None
+ for event in events:
+ if event.summary.value:
+ summaries.append(event.summary)
+ elif event.graph_def:
+ graph_def = event.graph_def
values = []
for summary in summaries:
for value in summary.value:
@@ -448,6 +456,7 @@ class EvaluateRepeatedlyTest(test.TestCase):
saved_results = {v.tag: v.simple_value for v in values}
for name in names_to_values:
self.assertAlmostEqual(names_to_values[name], saved_results[name], 5)
+ self.assertIsNotNone(graph_def)
def testSummariesAreFlushedToDisk(self):
checkpoint_dir = os.path.join(self.get_temp_dir(), 'summaries_are_flushed')
@@ -475,7 +484,23 @@ class EvaluateRepeatedlyTest(test.TestCase):
],
max_number_of_evaluations=1)
- self._verify_summaries(logdir, names_to_values)
+ self._verify_events(logdir, names_to_values)
+
+ def testSummaryAtEndHookWithoutSummaries(self):
+ logdir = os.path.join(self.get_temp_dir(),
+ 'summary_at_end_hook_without_summaries')
+ if gfile.Exists(logdir):
+ gfile.DeleteRecursively(logdir)
+
+ with ops.Graph().as_default():
+ # Purposefully don't add any summaries. The hook will just dump the
+ # GraphDef event.
+ hook = evaluation.SummaryAtEndHook(log_dir=logdir)
+ hook.begin()
+ with self.cached_session() as session:
+ hook.after_create_session(session, None)
+ hook.end(session)
+ self._verify_events(logdir, {})
if __name__ == '__main__':
diff --git a/tensorflow/contrib/training/python/training/resample_test.py b/tensorflow/contrib/training/python/training/resample_test.py
index 774241a816..8665a24883 100644
--- a/tensorflow/contrib/training/python/training/resample_test.py
+++ b/tensorflow/contrib/training/python/training/resample_test.py
@@ -44,7 +44,7 @@ class ResampleTest(test.TestCase):
([3], [0, 0, 0]),
([0, 1, 2, 3], [1, 2, 2, 3, 3, 3]),
]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for inputs, expected in cases:
array_inputs = numpy.array(inputs, dtype=numpy.int32)
actual = sess.run(resample._repeat_range(array_inputs))
@@ -65,7 +65,7 @@ class ResampleTest(test.TestCase):
init = control_flow_ops.group(variables.local_variables_initializer(),
variables.global_variables_initializer())
- with self.test_session() as s:
+ with self.cached_session() as s:
s.run(init) # initialize
# outputs
@@ -112,7 +112,7 @@ class ResampleTest(test.TestCase):
init = control_flow_ops.group(variables.local_variables_initializer(),
variables.global_variables_initializer())
expected_sum_op = math_ops.reduce_sum(vals)
- with self.test_session() as s:
+ with self.cached_session() as s:
s.run(init)
expected_sum = n * s.run(expected_sum_op)
@@ -147,7 +147,7 @@ class ResampleTest(test.TestCase):
resampled = resample.resample_at_rate([vals], rates)
- with self.test_session() as s:
+ with self.cached_session() as s:
rs, = s.run(resampled, {
vals: list(range(count)),
rates: numpy.zeros(
diff --git a/tensorflow/contrib/training/python/training/sampling_ops_test.py b/tensorflow/contrib/training/python/training/sampling_ops_test.py
index bf7fb4fd48..1aeff7dc80 100644
--- a/tensorflow/contrib/training/python/training/sampling_ops_test.py
+++ b/tensorflow/contrib/training/python/training/sampling_ops_test.py
@@ -146,7 +146,7 @@ class StratifiedSampleTest(test.TestCase):
for illegal_label in illegal_labels:
# Run session that should fail.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors_impl.InvalidArgumentError):
sess.run([val_tf, lbl_tf],
feed_dict={label_ph: illegal_label,
@@ -154,7 +154,7 @@ class StratifiedSampleTest(test.TestCase):
for illegal_prob in illegal_probs:
# Run session that should fail.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors_impl.InvalidArgumentError):
sess.run([prob_tf],
feed_dict={label_ph: valid_labels,
@@ -172,7 +172,7 @@ class StratifiedSampleTest(test.TestCase):
summary_op = logging_ops.merge_summary(
ops.get_collection(ops.GraphKeys.SUMMARIES))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
@@ -197,7 +197,7 @@ class StratifiedSampleTest(test.TestCase):
batch_size,
init_probs=[0, .3, 0, .7, 0],
enqueue_many=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
@@ -228,7 +228,7 @@ class StratifiedSampleTest(test.TestCase):
# Run graph to make sure there are no shape-related runtime errors.
for vals, labels in legal_input_pairs:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([val_tf, labels_tf],
feed_dict={vals_ph: vals,
labels_ph: labels})
@@ -253,7 +253,7 @@ class StratifiedSampleTest(test.TestCase):
self.assertEqual(len(val_list), len(val_input_batch))
self.assertTrue(isinstance(lbls, ops.Tensor))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
@@ -283,7 +283,7 @@ class StratifiedSampleTest(test.TestCase):
# Run session and keep track of how frequently the labels and values appear.
data_l = []
label_l = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Need to initialize variables that keep running total of classes seen.
variables.global_variables_initializer().run()
@@ -374,7 +374,7 @@ class RejectionSampleTest(test.TestCase):
'rejection_sample/prob_with_checks:0')
# Run session that should fail.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for illegal_prob in [-0.1, 1.1]:
with self.assertRaises(errors_impl.InvalidArgumentError):
sess.run(prob_tensor, feed_dict={prob_ph: illegal_prob})
@@ -393,7 +393,7 @@ class RejectionSampleTest(test.TestCase):
sample = sampling_ops.rejection_sample(tensor_list, accept_prob_fn,
batch_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
diff --git a/tensorflow/contrib/training/python/training/sampling_ops_threading_test.py b/tensorflow/contrib/training/python/training/sampling_ops_threading_test.py
index ca78c0029e..73ad859ab3 100644
--- a/tensorflow/contrib/training/python/training/sampling_ops_threading_test.py
+++ b/tensorflow/contrib/training/python/training/sampling_ops_threading_test.py
@@ -59,7 +59,7 @@ class SamplingOpsThreadingTest(test.TestCase):
out_tensor = queue.dequeue()
# Run the multi-threaded session.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Need to initialize variables that keep running total of classes seen.
variables.global_variables_initializer().run()
diff --git a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py
index 7aebd9d9fe..8932b905c9 100644
--- a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py
+++ b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py
@@ -36,7 +36,7 @@ from tensorflow.python.platform import test
class SequenceQueueingStateSaverTest(test.TestCase):
def testSequenceInputWrapper(self):
- with self.test_session():
+ with self.cached_session():
length = 3
key = "key"
padded_length = 4
@@ -54,7 +54,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
self.assertTrue(isinstance(input_wrapper.context["context1"], ops.Tensor))
def testStateSaverWithTwoSimpleSteps(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size_value = 2
batch_size = constant_op.constant(batch_size_value)
num_unroll = 2
@@ -159,7 +159,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
self.assertEqual(0, state_saver.barrier.ready_size().eval())
def testStateSaverFailsIfPaddedLengthIsNotMultipleOfNumUnroll(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = constant_op.constant(32)
num_unroll = 17
bad_padded_length = 3
@@ -194,7 +194,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
})
def _testStateSaverFailsIfCapacityTooSmall(self, batch_size):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_unroll = 2
length = array_ops.placeholder(dtypes.int32)
key = array_ops.placeholder(dtypes.string)
@@ -243,7 +243,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
self._testStateSaverFailsIfCapacityTooSmall(batch_size)
def testStateSaverFailsIfInconsistentPaddedLength(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = constant_op.constant(32)
num_unroll = 17
length = array_ops.placeholder(dtypes.int32)
@@ -282,7 +282,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
def testStateSaverFailsIfInconsistentWriteState(self):
# TODO(b/26910386): Identify why this infrequently causes timeouts.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = constant_op.constant(1)
num_unroll = 17
length = array_ops.placeholder(dtypes.int32)
@@ -326,7 +326,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
def testStateSaverWithManyInputsReadWriteThread(self):
batch_size_value = 32
num_proc_threads = 100
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = constant_op.constant(batch_size_value)
num_unroll = 17
length = array_ops.placeholder(dtypes.int32)
@@ -490,7 +490,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
self.assertGreater(processed_count[0], 2 * 20 * batch_size_value)
def testStateSaverProcessesExamplesInOrder(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size_value = 32
batch_size = constant_op.constant(batch_size_value)
num_unroll = 17
@@ -563,7 +563,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
self.assertEqual(get_ready_size.eval(), 0)
def testStateSaverCanHandleVariableBatchsize(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = array_ops.placeholder(dtypes.int32)
num_unroll = 17
length = array_ops.placeholder(dtypes.int32)
diff --git a/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py
index 4a46e9a49e..3269d5fef2 100644
--- a/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py
+++ b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py
@@ -62,7 +62,7 @@ class SGDRDecayTest(test_util.TensorFlowTestCase):
def get_sgdr_values(self, lr, initial_period_steps, t_mul, iters):
"""Get an array with learning rate values from the consecutive steps
using current tensorflow implementation."""
- with self.test_session():
+ with self.cached_session():
step = placeholder(dtypes.int32)
decay = sgdr_decay(lr, step, initial_period_steps, t_mul)
@@ -76,7 +76,7 @@ class SGDRDecayTest(test_util.TensorFlowTestCase):
"""Compare values generated by tensorflow implementation to the values
generated by the original implementation
(https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py)."""
- with self.test_session():
+ with self.cached_session():
lr = 10.0
init_steps = 2
t_mul = 3
@@ -92,7 +92,7 @@ class SGDRDecayTest(test_util.TensorFlowTestCase):
def testMDecay(self):
"""Test m_mul argument. Check values for learning rate at the beginning
of the first, second, third and fourth period. """
- with self.test_session():
+ with self.cached_session():
step = placeholder(dtypes.int32)
lr = 0.1
@@ -121,7 +121,7 @@ class SGDRDecayTest(test_util.TensorFlowTestCase):
def testCos(self):
"""Check learning rate values at the beginning, in the middle
and at the end of the period."""
- with self.test_session():
+ with self.cached_session():
step = placeholder(dtypes.int32)
lr = 0.2
t_e = 1000
diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
index df0a186f4f..d9b0511a98 100644
--- a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
+++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
@@ -79,7 +79,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
queue_handle, value = iterator.get_next()
enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([[0, 0, 0]], sess.run(value))
value_1, _ = sess.run([value, enqueue_negative])
self.assertAllEqual([[1, 0, 0]], value_1)
@@ -101,7 +101,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
queue_handle, value = iterator.get_next()
enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual([0], sess.run(value))
value_1, _ = sess.run([value, enqueue_negative])
self.assertEqual([1], value_1)
@@ -126,7 +126,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
enqueue_zeroth = tqd.enqueue_in_queue_dataset([queue_handle[0]],
array_ops.expand_dims(
value[0], axis=0))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
value_0, _ = sess.run([value, enqueue_negative])
self.assertAllEqual([0, 1], value_0)
value_1, _ = sess.run([value, enqueue_zeroth])
@@ -147,7 +147,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
tqd.enqueue_in_queue_dataset(queue_handle, value + 100 + i)
for i in range(1000)
]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
value_0, _ = sess.run((value, enqueue_many_more))
self.assertEqual([0], value_0)
rest = []
@@ -174,7 +174,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
queue_handle, value = iterator.get_next()
enqueue = tqd.enqueue_in_queue_dataset(queue_handle, value + 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
i = 0
while i < 4:
received, _ = sess.run((value, enqueue))
@@ -199,7 +199,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
batch_size=1, padded_shapes=[2]))
iterator = dataset.make_one_shot_iterator()
_, value = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesOpError(
r"Incompatible input shapes at component 0 between "
r"input dataset this dataset: \[3\] vs. \[2\]"):
@@ -224,7 +224,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
np.array(
[[1]], dtype=np.int32))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesOpError(
"mismatched number of tensors. Queue expects 1 tensors but "
"tried to insert 2"):
@@ -274,7 +274,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
with ops.control_dependencies([enqueue_rest_op]):
calc = array_ops.identity(value_head)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([[0, 0], [2, 2], [4, 4]], sess.run(calc))
self.assertAllEqual([[4, 4], [6, 6]], sess.run(calc))
self.assertAllEqual([[6, 6]], sess.run(calc))
@@ -304,7 +304,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
_, (unused_count, padded_value) = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([[-1, -1, -1, -1], [2, 2, -1, -1], [4, 4, 4, 4]],
sess.run(padded_value))
self.assertAllEqual([[6] * 6], sess.run(padded_value))
diff --git a/tensorflow/contrib/training/python/training/training_test.py b/tensorflow/contrib/training/python/training/training_test.py
index 94cf7788b2..3b524ac8c7 100644
--- a/tensorflow/contrib/training/python/training/training_test.py
+++ b/tensorflow/contrib/training/python/training/training_test.py
@@ -62,7 +62,7 @@ class ClipGradsTest(test.TestCase):
clipped_gradients_to_variables = training.clip_gradient_norms(
gradients_to_variables, 3.0)
- with self.test_session() as session:
+ with self.cached_session() as session:
session.run(variables_lib2.global_variables_initializer())
self.assertAlmostEqual(4.0, gradients_to_variables[0][0].eval())
self.assertAlmostEqual(3.0, clipped_gradients_to_variables[0][0].eval())
@@ -75,7 +75,7 @@ class ClipGradsTest(test.TestCase):
clipped_gradients_to_variables = training.clip_gradient_norms_fn(3.0)(
gradients_to_variables)
- with self.test_session() as session:
+ with self.cached_session() as session:
session.run(variables_lib2.global_variables_initializer())
self.assertAlmostEqual(4.0, gradients_to_variables[0][0].eval())
self.assertAlmostEqual(3.0, clipped_gradients_to_variables[0][0].eval())
@@ -122,7 +122,7 @@ class CreateTrainOpTest(test.TestCase):
moving_variance = variables_lib.get_variables_by_name('moving_variance')[
0]
- with self.test_session() as session:
+ with self.cached_session() as session:
# Initialize all variables
session.run(variables_lib2.global_variables_initializer())
mean, variance = session.run([moving_mean, moving_variance])
@@ -155,7 +155,7 @@ class CreateTrainOpTest(test.TestCase):
moving_variance = variables_lib.get_variables_by_name('moving_variance')[
0]
- with self.test_session() as session:
+ with self.cached_session() as session:
# Initialize all variables
session.run(variables_lib2.global_variables_initializer())
mean, variance = session.run([moving_mean, moving_variance])
@@ -186,7 +186,7 @@ class CreateTrainOpTest(test.TestCase):
global_step = variables_lib.get_or_create_global_step()
- with self.test_session() as session:
+ with self.cached_session() as session:
# Initialize all variables
session.run(variables_lib2.global_variables_initializer())
@@ -209,7 +209,7 @@ class CreateTrainOpTest(test.TestCase):
global_step = variables_lib.get_or_create_global_step()
- with self.test_session() as session:
+ with self.cached_session() as session:
# Initialize all variables
session.run(variables_lib2.global_variables_initializer())
@@ -535,7 +535,7 @@ class TrainTest(test.TestCase):
train_biases = training.create_train_op(
total_loss, optimizer, variables_to_train=[biases])
- with self.test_session() as session:
+ with self.cached_session() as session:
# Initialize the variables.
session.run(variables_lib2.global_variables_initializer())
diff --git a/tensorflow/contrib/util/__init__.py b/tensorflow/contrib/util/__init__.py
index 08741cf8ca..338acef63f 100644
--- a/tensorflow/contrib/util/__init__.py
+++ b/tensorflow/contrib/util/__init__.py
@@ -15,7 +15,7 @@
"""Utilities for dealing with Tensors.
-See @{$python/contrib.util} guide.
+See [Contrib Util](https://tensorflow.org/api_guides/python/contrib.util) guide.
@@constant_value
@@make_tensor_proto
diff --git a/tensorflow/contrib/verbs/grpc_verbs_client.h b/tensorflow/contrib/verbs/grpc_verbs_client.h
index 2cfaa4986c..e07085502f 100644
--- a/tensorflow/contrib/verbs/grpc_verbs_client.h
+++ b/tensorflow/contrib/verbs/grpc_verbs_client.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
-#define TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
+#ifndef TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_CLIENT_H_
+#define TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_CLIENT_H_
#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
@@ -47,4 +47,4 @@ class GrpcVerbsClient {
} // namespace tensorflow
-#endif // TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
+#endif // TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_CLIENT_H_
diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h
index abe5e08b07..cfb9b7ddd7 100644
--- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h
+++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
-#define TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
+#ifndef TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_IMPL_H_
+#define TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_IMPL_H_
#include "grpcpp/impl/codegen/async_stream.h"
#include "grpcpp/impl/codegen/async_unary_call.h"
@@ -86,4 +86,4 @@ class VerbsService GRPC_FINAL {
} // namespace tensorflow
-#endif // TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
+#endif // TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_IMPL_H_
diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
index ad3dce1784..d4951b156c 100644
--- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
+++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
@@ -63,7 +63,7 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
}
CHECK(dst_name.compare(rdma_mgr_->local_worker()) == 0);
RdmaChannel* rc = rdma_mgr_->FindChannel(src_name);
- string key(std::move(parsed.FullKey().ToString()));
+ string key(parsed.FullKey());
string key_with_step_id = VerbsUtil::AppendStepidToKey(key, step_id_);
Device* dst_dev;
diff --git a/tensorflow/contrib/verbs/verbs_util.h b/tensorflow/contrib/verbs/verbs_util.h
index 5cd0a3533a..6277bc4b41 100644
--- a/tensorflow/contrib/verbs/verbs_util.h
+++ b/tensorflow/contrib/verbs/verbs_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_RDMA_UTIL_H_
-#define TENSORFLOW_CONTRIB_RDMA_UTIL_H_
+#ifndef TENSORFLOW_CONTRIB_VERBS_VERBS_UTIL_H_
+#define TENSORFLOW_CONTRIB_VERBS_VERBS_UTIL_H_
#include <string>
@@ -30,4 +30,4 @@ class VerbsUtil {
};
} // namespace tensorflow
-#endif // TENSORFLOW_CONTRIB_RDMA_UTIL_H_
+#endif // TENSORFLOW_CONTRIB_VERBS_VERBS_UTIL_H_
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 64430a1418..79ad3b8e54 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -375,6 +375,7 @@ cc_library(
":lib_platform",
":platform_base",
"//tensorflow/core/platform/default/build_config:port",
+ "@com_google_absl//absl/base",
"@snappy",
],
)
@@ -611,6 +612,7 @@ cc_library(
copts = tf_copts(),
deps = tf_lib_proto_parsing_deps() + [
":platform_base",
+ "@com_google_absl//absl/strings",
"@double_conversion//:double-conversion",
],
)
@@ -668,8 +670,11 @@ cc_library(
"lib/io/table_builder.h",
"lib/io/table_options.h",
"lib/math/math_util.h",
+ "lib/monitoring/collected_metrics.h",
+ "lib/monitoring/collection_registry.h",
"lib/monitoring/counter.h",
"lib/monitoring/gauge.h",
+ "lib/monitoring/metric_def.h",
"lib/monitoring/sampler.h",
"lib/random/distribution_sampler.h",
"lib/random/philox_random.h",
@@ -690,6 +695,24 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":lib_internal",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ ],
+)
+
+cc_library(
+ name = "feature_util",
+ srcs = ["example/feature_util.cc"],
+ hdrs = [
+ "example/feature_util.h",
+ "platform/types.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":core_stringpiece",
+ ":platform_protobuf",
+ ":protos_all_cc",
],
)
@@ -726,10 +749,12 @@ cc_library(
# required to use tf_cc_test, and that rule will change / into _
cc_library(
name = "core_stringpiece",
- srcs = ["lib/core/stringpiece.cc"],
hdrs = ["lib/core/stringpiece.h"],
copts = tf_copts(),
- deps = [":platform_base"],
+ deps = [
+ ":platform_base",
+ "@com_google_absl//absl/strings",
+ ],
)
# Test support library needed for all tests
@@ -864,7 +889,6 @@ tf_cuda_library(
"util/sparse/sparse_tensor.h",
"util/stat_summarizer.h",
"util/stat_summarizer_options.h",
- "util/status_util.h",
"util/stream_executor_util.h",
"util/strided_slice_op.h",
"util/tensor_format.h",
@@ -931,15 +955,6 @@ cc_library(
)
cc_library(
- name = "status_util",
- hdrs = ["util/status_util.h"],
- deps = [
- ":graph",
- ":lib",
- ],
-)
-
-cc_library(
name = "reader_base",
srcs = ["framework/reader_base.cc"],
hdrs = ["framework/reader_base.h"],
@@ -1339,6 +1354,7 @@ cc_library(
"//tensorflow/core/kernels:mkl_relu_op",
"//tensorflow/core/kernels:mkl_reshape_op",
"//tensorflow/core/kernels:mkl_softmax_op",
+ "//tensorflow/core/kernels:mkl_transpose_op",
"//tensorflow/core/kernels:mkl_tfconv_op",
"//tensorflow/core/kernels:mkl_aggregate_ops",
]) + if_cuda([
@@ -1572,6 +1588,7 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
+ ":mobile_additional_lib_deps",
":protos_all_cc_impl",
":stats_calculator_portable",
"//third_party/eigen3",
@@ -1582,6 +1599,13 @@ cc_library(
alwayslink = 1,
)
+cc_library(
+ name = "mobile_additional_lib_deps",
+ deps = tf_additional_lib_deps() + [
+ "@com_google_absl//absl/strings",
+ ],
+)
+
# Native library support for iOS applications.
#
# bazel build --config=ios_x86_64 \
@@ -1613,6 +1637,7 @@ cc_library(
copts = tf_copts() + ["-Os"] + ["-std=c++11"],
visibility = ["//visibility:public"],
deps = [
+ ":mobile_additional_lib_deps",
":protos_all_cc_impl",
":stats_calculator_portable",
"//third_party/eigen3",
@@ -2009,9 +2034,6 @@ LIB_INTERNAL_PUBLIC_HEADERS = tf_additional_lib_hdrs() + [
"lib/io/zlib_compression_options.h",
"lib/io/zlib_inputstream.h",
"lib/io/zlib_outputbuffer.h",
- "lib/monitoring/collected_metrics.h",
- "lib/monitoring/collection_registry.h",
- "lib/monitoring/metric_def.h",
"lib/monitoring/mobile_counter.h",
"lib/monitoring/mobile_gauge.h",
"lib/monitoring/mobile_sampler.h",
@@ -2059,6 +2081,7 @@ cc_library(
],
}),
deps = tf_additional_lib_deps() + [
+ "@com_google_absl//absl/strings",
"//third_party/eigen3",
"//tensorflow/core/platform/default/build_config:platformlib",
] + if_static([":lib_internal_impl"]),
@@ -2076,7 +2099,6 @@ cc_library(
exclude = [
"**/*test*",
"framework/variant.cc",
- "lib/core/stringpiece.cc",
"lib/hash/crc32c_accelerate.cc",
"lib/gif/**/*",
"lib/jpeg/**/*",
@@ -2091,7 +2113,6 @@ cc_library(
) + tf_additional_lib_srcs(
exclude = [
"**/*test*",
- "lib/core/stringpiece.cc",
"platform/**/cuda.h",
"platform/**/cuda_libdevice_path.cc",
"platform/**/stream_executor.h",
@@ -2202,6 +2223,7 @@ cc_library(
":lib",
":lib_internal",
"//tensorflow/core/platform/default/build_config:png",
+ "@com_google_absl//absl/strings",
"@zlib_archive//:zlib",
],
)
@@ -2217,7 +2239,7 @@ cc_library(
"platform/macros.h",
"platform/platform.h",
"platform/types.h",
- ],
+ ] + if_windows(["platform/windows/integral_types.h"]),
copts = tf_copts(),
linkopts = ["-ldl"],
deps = [
@@ -2252,6 +2274,7 @@ cc_library(
deps = [
"//tensorflow/core/platform/default/build_config:jpeg",
"//tensorflow/core/platform/default/build_config:logging",
+ "@com_google_absl//absl/strings",
],
)
@@ -2260,6 +2283,8 @@ cc_library(
srcs = if_android([
"lib/gif/gif_io.cc",
"platform/gif.h",
+ "lib/strings/strcat.h",
+ "lib/strings/numbers.h",
]),
hdrs = [
"lib/bfloat16/bfloat16.h",
@@ -2281,6 +2306,7 @@ cc_library(
deps = [
"//tensorflow/core/platform/default/build_config:gif",
"//tensorflow/core/platform/default/build_config:logging",
+ "@com_google_absl//absl/strings",
],
)
@@ -2308,6 +2334,7 @@ cc_library(
linkopts = ["-ldl"],
deps = [
"//tensorflow/core/platform/default/build_config:logging",
+ "@com_google_absl//absl/strings",
"@png_archive//:png",
],
)
@@ -2350,6 +2377,7 @@ tf_generate_proto_text_sources(
srcs = COMMON_PROTO_SRCS,
protodeps = ERROR_CODES_PROTO_SRCS,
srcs_relative_dir = "tensorflow/core/",
+ visibility = ["//visibility:public"],
deps = [
":error_codes_proto_text",
":lib_internal",
@@ -2451,6 +2479,7 @@ tf_cuda_library(
cc_header_only_library(
name = "framework_internal_headers_lib",
+ includes = ["../../external/com_google_absl"],
deps = [
":lib",
":lib_internal",
@@ -2462,6 +2491,7 @@ cc_header_only_library(
cc_header_only_library(
name = "core_cpu_headers_lib",
+ visibility = ["//visibility:public"],
deps = [
":core_cpu_lib",
],
@@ -2534,6 +2564,11 @@ tf_cuda_library(
cc_header_only_library(
name = "framework_headers_lib",
+ extra_deps = [
+ # ABSL headers get dropped, so we add them back here.
+ "@com_google_absl//absl/strings",
+ ],
+ includes = ["../../external/com_google_absl"],
visibility = ["//visibility:public"],
deps = [
":framework",
@@ -2543,6 +2578,7 @@ cc_header_only_library(
cc_header_only_library(
name = "stream_executor_headers_lib",
+ includes = ["../../external/com_google_absl"],
visibility = ["//visibility:public"],
deps = [
":stream_executor",
@@ -2585,6 +2621,7 @@ tf_cuda_library(
# TODO(josh11b): Is this needed, or can we just use ":protos_all_cc"?
cc_library(
name = "protos_cc",
+ visibility = ["//visibility:public"],
deps = ["//tensorflow/core/platform/default/build_config:protos_cc"],
)
@@ -2694,12 +2731,13 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/allocator_retry.h",
"common_runtime/base_collective_executor.h",
"common_runtime/bfc_allocator.h",
- "common_runtime/broadcaster.h",
+ "common_runtime/hierarchical_tree_broadcaster.h",
"common_runtime/buf_rendezvous.h",
"common_runtime/build_graph_options.h",
"common_runtime/collective_executor_mgr.h",
"common_runtime/collective_param_resolver_local.h",
"common_runtime/collective_rma_local.h",
+ "common_runtime/collective_util.h",
"common_runtime/constant_folding.h",
"common_runtime/copy_tensor.h",
"common_runtime/costmodel_manager.h",
@@ -2712,6 +2750,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/graph_optimizer.h",
"common_runtime/local_device.h",
"common_runtime/lower_if_op.h",
+ "common_runtime/lower_while_op.h",
"common_runtime/memory_types.h",
"common_runtime/mkl_cpu_allocator.h",
"common_runtime/optimization_registry.h",
@@ -2730,6 +2769,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/stats_publisher_interface.h",
"common_runtime/step_stats_collector.h",
"common_runtime/threadpool_device.h",
+ "common_runtime/tracing_device.h",
"common_runtime/visitable_allocator.h",
"common_runtime/process_state.h",
"common_runtime/pool_allocator.h",
@@ -2744,12 +2784,12 @@ tf_cuda_library(
"common_runtime/allocator_retry.cc",
"common_runtime/base_collective_executor.cc",
"common_runtime/bfc_allocator.cc",
- "common_runtime/broadcaster.cc",
"common_runtime/buf_rendezvous.cc",
"common_runtime/build_graph_options.cc",
"common_runtime/collective_executor_mgr.cc",
"common_runtime/collective_param_resolver_local.cc",
"common_runtime/collective_rma_local.cc",
+ "common_runtime/collective_util.cc",
"common_runtime/constant_folding.cc",
"common_runtime/copy_tensor.cc",
"common_runtime/costmodel_manager.cc",
@@ -2764,8 +2804,10 @@ tf_cuda_library(
"common_runtime/function.cc",
"common_runtime/graph_optimizer.cc",
"common_runtime/graph_runner.cc",
+ "common_runtime/hierarchical_tree_broadcaster.cc",
"common_runtime/local_device.cc",
"common_runtime/lower_if_op.cc",
+ "common_runtime/lower_while_op.cc",
"common_runtime/memory_types.cc",
"common_runtime/mkl_cpu_allocator.cc",
"common_runtime/optimization_registry.cc",
@@ -3190,18 +3232,15 @@ tf_cc_tests(
"lib/core/status_test.cc",
"lib/core/stringpiece_test.cc",
"lib/core/threadpool_test.cc",
- "lib/gtl/array_slice_test.cc",
"lib/gtl/cleanup_test.cc",
"lib/gtl/compactptrset_test.cc",
"lib/gtl/edit_distance_test.cc",
"lib/gtl/flatmap_test.cc",
"lib/gtl/flatset_test.cc",
- "lib/gtl/inlined_vector_test.cc",
"lib/gtl/int_type_test.cc",
"lib/gtl/iterator_range_test.cc",
"lib/gtl/manual_constructor_test.cc",
"lib/gtl/map_util_test.cc",
- "lib/gtl/optional_test.cc",
"lib/gtl/top_n_test.cc",
"lib/hash/crc32c_test.cc",
"lib/hash/hash_test.cc",
@@ -3527,7 +3566,6 @@ tf_cc_tests(
"util/semver_test.cc",
"util/sparse/sparse_tensor_test.cc",
"util/stat_summarizer_test.cc",
- "util/status_util_test.cc",
"util/tensor_format_test.cc",
"util/tensor_slice_reader_test.cc",
"util/tensor_slice_set_test.cc",
@@ -3552,7 +3590,6 @@ tf_cc_tests(
":ops",
":protos_all_cc",
":protos_test_cc",
- ":status_util",
":test",
":test_main",
":testlib",
@@ -3650,10 +3687,10 @@ tf_cc_tests_gpu(
)
tf_cc_tests_gpu(
- name = "broadcaster_test",
+ name = "hierarchical_tree_broadcaster_test",
size = "small",
srcs = [
- "common_runtime/broadcaster_test.cc",
+ "common_runtime/hierarchical_tree_broadcaster_test.cc",
],
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags(),
@@ -3691,6 +3728,7 @@ tf_cc_test_mkl(
":core_cpu_internal",
":framework",
":framework_internal",
+ ":lib",
":test",
":test_main",
":testlib",
@@ -4045,6 +4083,7 @@ tf_cuda_cc_test(
":testlib",
"//third_party/eigen3",
"//tensorflow/cc:cc_ops",
+ "//tensorflow/core/kernels:collective_ops",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:dense_update_ops",
@@ -4086,6 +4125,7 @@ tf_cc_test(
"//tensorflow/cc:cc_ops",
# Link with support for TensorFlow Debugger (tfdbg).
"//tensorflow/core/debug",
+ "//tensorflow/core/kernels:collective_ops",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:dense_update_ops",
@@ -4570,6 +4610,29 @@ tf_cc_tests(
],
)
+tf_cc_tests(
+ name = "common_runtime_lower_while_op_test",
+ size = "small",
+ srcs = ["common_runtime/lower_while_op_test.cc"],
+ deps = [
+ ":all_kernels",
+ ":core_cpu",
+ ":core_cpu_internal",
+ ":direct_session",
+ ":framework",
+ ":framework_internal",
+ ":lib",
+ ":test",
+ ":test_main",
+ ":testlib",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:cc_ops_internal",
+ "//tensorflow/cc:client_session",
+ "//tensorflow/cc:function_ops",
+ "//tensorflow/cc:ops",
+ ],
+)
+
# Test data
filegroup(
name = "image_testdata",
diff --git a/tensorflow/core/api_def/base_api/api_def_ApplyAdam.pbtxt b/tensorflow/core/api_def/base_api/api_def_ApplyAdam.pbtxt
index b90f5473c8..6341eeda32 100644
--- a/tensorflow/core/api_def/base_api/api_def_ApplyAdam.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ApplyAdam.pbtxt
@@ -82,7 +82,7 @@ END
}
summary: "Update \'*var\' according to the Adam algorithm."
description: <<END
-$$lr_t := \text{learning_rate} * \sqrt{(1 - beta_2^t) / (1 - beta_1^t)}$$
+$$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
$$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
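The corrected bias-correction factor is easy to sanity-check numerically
(the values below are arbitrary):

```python
import math

learning_rate, beta1, beta2, t = 0.001, 0.9, 0.999, 10
# sqrt applies only to the beta2 term, per the corrected docstring.
lr_t = learning_rate * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
print(lr_t)  # ~0.000153
```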
diff --git a/tensorflow/core/api_def/base_api/api_def_DivNoNan.pbtxt b/tensorflow/core/api_def/base_api/api_def_DivNoNan.pbtxt
new file mode 100644
index 0000000000..5604a1a89e
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_DivNoNan.pbtxt
@@ -0,0 +1,9 @@
+op {
+ graph_op_name: "DivNoNan"
+ summary: "Returns 0 if the denominator is zero."
+ description: <<END
+
+*NOTE*: `DivNoNan` supports broadcasting. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+END
+}
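A usage sketch (hedged: assumes the op's Python endpoint `tf.div_no_nan`,
added around this release):

```python
import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0])
y = tf.constant([2.0, 0.0, 0.5])
z = tf.div_no_nan(x, y)  # zero denominators yield 0 instead of inf/nan
with tf.Session() as sess:
  print(sess.run(z))  # [0.5 0.  6. ]
```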
diff --git a/tensorflow/core/api_def/base_api/api_def_EnsureShape.pbtxt b/tensorflow/core/api_def/base_api/api_def_EnsureShape.pbtxt
new file mode 100644
index 0000000000..1658472209
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_EnsureShape.pbtxt
@@ -0,0 +1,26 @@
+op {
+ graph_op_name: "EnsureShape"
+ in_arg {
+ name: "input"
+ description: <<END
+A tensor, whose shape is to be validated.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A tensor with the same shape and contents as the input tensor or value.
+END
+ }
+ attr {
+ name: "shape"
+ description: <<END
+The expected (possibly partially specified) shape of the input tensor.
+END
+ }
+ summary: "Ensures that the tensor's shape matches the expected shape."
+ description: <<END
+Raises an error if the input tensor's shape does not match the specified shape.
+Returns the input tensor otherwise.
+END
+}
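A usage sketch (hedged: assumes the Python wrapper `tf.ensure_shape`, which
this op backs):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32)
checked = tf.ensure_shape(x, (None, 3))  # partially specified shape
with tf.Session() as sess:
  print(sess.run(checked, feed_dict={x: [[1.0, 2.0, 3.0]]}))  # passes
  # Feeding a [3]-shaped value instead would raise InvalidArgumentError.
```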
diff --git a/tensorflow/core/api_def/base_api/api_def_FeatureStatsDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_FeatureStatsDataset.pbtxt
deleted file mode 100644
index ffd01ba5cc..0000000000
--- a/tensorflow/core/api_def/base_api/api_def_FeatureStatsDataset.pbtxt
+++ /dev/null
@@ -1,3 +0,0 @@
-op {
- graph_op_name: "FeatureStatsDataset"
-}
diff --git a/tensorflow/core/api_def/base_api/api_def_Fill.pbtxt b/tensorflow/core/api_def/base_api/api_def_Fill.pbtxt
index 58262a385c..37d1a9dcbf 100644
--- a/tensorflow/core/api_def/base_api/api_def_Fill.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Fill.pbtxt
@@ -27,5 +27,15 @@ For example:
fill([2, 3], 9) ==> [[9, 9, 9]
[9, 9, 9]]
```
+
+`tf.fill` differs from `tf.constant` in a few ways:
+
+* `tf.fill` only supports scalar contents, whereas `tf.constant` supports
+ Tensor values.
+* `tf.fill` creates an Op in the computation graph that constructs the actual
+ Tensor value at runtime. This is in contrast to `tf.constant` which embeds
+ the entire Tensor into the graph with a `Const` node.
+* Because `tf.fill` evaluates at graph runtime, it supports dynamic shapes
+ based on other runtime Tensors, unlike `tf.constant`.
END
}
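The dynamic-shape point above is the practical difference between the two
ops; a short sketch:

```python
import tensorflow as tf

# tf.constant bakes the full value into the graph at construction time.
c = tf.constant(9, shape=[2, 3])

# tf.fill's dims may come from a runtime tensor.
n = tf.placeholder(tf.int32, shape=[])
f = tf.fill([n, 3], 9)
with tf.Session() as sess:
  print(sess.run(f, feed_dict={n: 2}))  # [[9 9 9] [9 9 9]]
```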
diff --git a/tensorflow/core/api_def/base_api/api_def_MatrixExponential.pbtxt b/tensorflow/core/api_def/base_api/api_def_MatrixExponential.pbtxt
index d7b56aec87..46da1de1c3 100644
--- a/tensorflow/core/api_def/base_api/api_def_MatrixExponential.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_MatrixExponential.pbtxt
@@ -1,32 +1,5 @@
op {
graph_op_name: "MatrixExponential"
- in_arg {
- name: "input"
- description: <<END
-Shape is `[..., M, M]`.
-END
- }
- out_arg {
- name: "output"
- description: <<END
-Shape is `[..., M, M]`.
-
-@compatibility(scipy)
-Equivalent to scipy.linalg.expm
-@end_compatibility
-END
- }
- summary: "Computes the matrix exponential of one or more square matrices:"
- description: <<END
-\\(exp(A) = \sum_{n=0}^\infty A^n/n!\\)
-
-The exponential is computed using a combination of the scaling and squaring
-method and the Pade approximation. Details can be founds in:
-Nicholas J. Higham, "The scaling and squaring method for the matrix exponential
-revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005.
-
-The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
-form square matrices. The output is a tensor of the same shape as the input
-containing the exponential for all input submatrices `[..., :, :]`.
-END
+ visibility: SKIP
+ summary: "Deprecated, use python implementation tf.linalg.matrix_exponential."
}
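For reference, the deleted description defined the op as \\(exp(A) = \sum_{n=0}^\infty A^n/n!\\), computed via scaling and squaring with a Padé approximation. A toy NumPy sketch of the raw series only, adequate for small well-scaled matrices and not how the replacement computes it:

```python
import numpy as np

def matrix_exp_taylor(a, terms=30):
    """Toy truncated series exp(A) = sum_n A^n / n!."""
    result = np.eye(a.shape[0])
    term = np.eye(a.shape[0])
    for n in range(1, terms):
        term = term @ a / n          # builds A^n / n! incrementally
        result = result + term
    return result

a = np.array([[0.0, 1.0], [-1.0, 0.0]])   # generator of a 2-D rotation
print(matrix_exp_taylor(a))               # ~[[cos 1, sin 1], [-sin 1, cos 1]]
```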
diff --git a/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt
new file mode 100644
index 0000000000..27bc4013c3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt
@@ -0,0 +1,13 @@
+op {
+ graph_op_name: "ParallelInterleaveDatasetV2"
+ visibility: HIDDEN
+ attr {
+ name: "f"
+ description: <<END
+A function mapping elements of `input_dataset`, concatenated with
+`other_arguments`, to a Dataset variant that contains elements matching
+`output_types` and `output_shapes`.
+END
+ }
+ summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ParseExampleDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ParseExampleDataset.pbtxt
new file mode 100644
index 0000000000..3de2f18fc2
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ParseExampleDataset.pbtxt
@@ -0,0 +1,69 @@
+op {
+ graph_op_name: "ParseExampleDataset"
+ in_arg {
+ name: "dense_defaults"
+ description: <<END
+A dict mapping string keys to `Tensor`s.
+The keys of the dict must match the dense_keys of the feature.
+END
+ }
+ attr {
+ name: "sparse_keys"
+ description: <<END
+A list of string keys in the examples features.
+The results for these keys will be returned as `SparseTensor` objects.
+END
+ }
+ attr {
+ name: "dense_keys"
+ description: <<END
+A list of Ndense string Tensors (scalars).
+The keys expected in the Examples features associated with dense values.
+END
+ }
+ attr {
+ name: "sparse_types"
+ description: <<END
+A list of `DTypes` of the same length as `sparse_keys`.
+Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
+and `tf.string` (`BytesList`) are supported.
+END
+ }
+ attr {
+ name: "Tdense"
+ description: <<END
+A list of DTypes of the same length as `dense_keys`.
+Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
+and `tf.string` (`BytesList`) are supported.
+
+END
+ }
+ attr {
+ name: "dense_shapes"
+ description: <<END
+List of tuples with the same length as `dense_keys`.
+The shape of the data for each dense feature referenced by `dense_keys`.
+Required for any input tensors identified by `dense_keys`. Must be
+either fully defined, or may contain an unknown first dimension.
+An unknown first dimension means the feature is treated as having
+a variable number of blocks, and the output shape along this dimension
+is considered unknown at graph build time. Padding is applied for
+minibatch elements smaller than the maximum number of blocks for the
+given feature along this dimension.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ attr {
+ name: "output_shapes"
+ description: <<END
+The list of shapes being produced.
+END
+ }
+ summary: "Transforms `input_dataset` containing `Example` protos as vectors of DT_STRING into a dataset of `Tensor` or `SparseTensor` objects representing the parsed features."
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_ParseSequenceExample.pbtxt b/tensorflow/core/api_def/base_api/api_def_ParseSequenceExample.pbtxt
new file mode 100644
index 0000000000..b1cb9a696d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ParseSequenceExample.pbtxt
@@ -0,0 +1,112 @@
+op {
+ graph_op_name: "ParseSequenceExample"
+ in_arg {
+ name: "serialized"
+ description: <<END
+A vector containing binary serialized SequenceExample protos.
+END
+ }
+ in_arg {
+ name: "debug_name"
+ description: <<END
+A vector containing the names of the serialized protos.
+May contain, for example, a table key (descriptive) name for the
+corresponding serialized proto. This is useful purely for debugging
+purposes, and the presence of values here has no effect on the output.
+May also be an empty vector if no name is available.
+END
+ }
+ in_arg {
+ name: "context_dense_defaults"
+ description: <<END
+A list of Ncontext_dense Tensors (some may be empty).
+context_dense_defaults[j] provides default values
+when the SequenceExample's context map lacks context_dense_key[j].
+If an empty Tensor is provided for context_dense_defaults[j],
+then the Feature context_dense_keys[j] is required.
+The input type is inferred from context_dense_defaults[j], even when it's
+empty. If context_dense_defaults[j] is not empty, its shape must match
+context_dense_shapes[j].
+END
+ }
+ attr {
+ name: "feature_list_dense_missing_assumed_empty"
+ description: <<END
+A vector listing the
+FeatureList keys which may be missing from the SequenceExamples. If the
+associated FeatureList is missing, it is treated as empty. By default,
+any FeatureList not listed in this vector must exist in the SequenceExamples.
+END
+ }
+ attr {
+ name: "context_sparse_keys"
+ description: <<END
+A list of Ncontext_sparse string Tensors (scalars).
+The keys expected in the Examples' features associated with context_sparse
+values.
+END
+ }
+ attr {
+ name: "context_dense_keys"
+ description: <<END
+A list of Ncontext_dense string Tensors (scalars).
+The keys expected in the SequenceExamples' context features associated with
+dense values.
+END
+ }
+ attr {
+ name: "feature_list_sparse_keys"
+ description: <<END
+A list of Nfeature_list_sparse string Tensors
+(scalars). The keys expected in the FeatureLists associated with sparse
+values.
+END
+ }
+ attr {
+ name: "feature_list_dense_keys"
+ description: <<END
+A list of Nfeature_list_dense string Tensors (scalars).
+The keys expected in the SequenceExamples' feature_lists associated
+with lists of dense values.
+END
+ }
+ attr {
+ name: "context_sparse_types"
+ description: <<END
+A list of Ncontext_sparse types; the data types of data in
+each context Feature given in context_sparse_keys.
+Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),
+DT_INT64 (Int64List), and DT_STRING (BytesList).
+END
+ }
+ attr {
+ name: "context_dense_shapes"
+ description: <<END
+A list of Ncontext_dense shapes; the shapes of data in
+each context Feature given in context_dense_keys.
+The number of elements in the Feature corresponding to context_dense_key[j]
+must always equal context_dense_shapes[j].NumEntries().
+The shape of context_dense_values[j] will match context_dense_shapes[j].
+END
+ }
+ attr {
+ name: "feature_list_sparse_types"
+ description: <<END
+A list of Nfeature_list_sparse types; the data types
+of data in each FeatureList given in feature_list_sparse_keys.
+Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),
+DT_INT64 (Int64List), and DT_STRING (BytesList).
+END
+ }
+ attr {
+ name: "feature_list_dense_shapes"
+ description: <<END
+A list of Nfeature_list_dense shapes; the shapes of
+data in each FeatureList given in feature_list_dense_keys.
+The shape of each Feature in the FeatureList corresponding to
+feature_list_dense_key[j] must always equal
+feature_list_dense_shapes[j].NumEntries().
+END
+ }
+ summary: "Transforms a vector of brain.SequenceExample protos (as strings) into typed tensors."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt
index 8cef243aee..30fd97a0d7 100644
--- a/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt
@@ -9,7 +9,7 @@ END
in_arg {
name: "pattern"
description: <<END
-A 1-D string tensor of the regular expression to match the input.
+A scalar string tensor containing the regular expression to match the input.
END
}
out_arg {
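The fix above makes the docs match the op: `pattern` is a scalar applied to every element of `input`. A small sketch of full-match semantics, with Python's `re` module standing in for the op's regex engine:

```python
import re

# A match must consume the whole element, not just a prefix or substring.
pattern = r"[a-z]+[0-9]"
inputs = ["abc1", "abc12", "1abc"]
print([re.fullmatch(pattern, s) is not None for s in inputs])
# [True, False, False]
```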
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceApplyAdam.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceApplyAdam.pbtxt
index ad0aeac004..2dcd136ae3 100644
--- a/tensorflow/core/api_def/base_api/api_def_ResourceApplyAdam.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceApplyAdam.pbtxt
@@ -76,7 +76,7 @@ END
}
summary: "Update \'*var\' according to the Adam algorithm."
description: <<END
-$$lr_t := \text{learning_rate} * \sqrt{(1 - beta_2^t) / (1 - beta_1^t)}$$
+$$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
$$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterNdUpdate.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterNdUpdate.pbtxt
index 1a75e67c0c..e400c7402b 100644
--- a/tensorflow/core/api_def/base_api/api_def_ScatterNdUpdate.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterNdUpdate.pbtxt
@@ -70,5 +70,7 @@ The resulting update to ref would look like this:
See `tf.scatter_nd` for more details about how to make updates to
slices.
+
+See also `tf.scatter_update` and `tf.batch_scatter_update`.
END
}
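For the rank-1 case the scatter reduces to fancy-indexed assignment; a NumPy sketch with illustrative values (the standard doc example):

```python
import numpy as np

# Rank-1 case only; the innermost dimension of `indices` indexes into ref.
ref = np.array([1, 2, 3, 4, 5, 6, 7, 8])
indices = np.array([[4], [3], [1], [7]])
updates = np.array([9, 10, 11, 12])

ref[indices[:, 0]] = updates
print(ref)  # [ 1 11  3 10  9  6  7 12]
```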
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt
index 4804908afc..4037dee432 100644
--- a/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt
@@ -59,5 +59,7 @@ Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/ScatterUpdate.png" alt>
</div>
+
+See also `tf.batch_scatter_update` and `tf.scatter_nd_update`.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt
index 5e2912fcdd..d33a36ce06 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt
@@ -3,7 +3,7 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
first dimension. Values should be sorted and can be repeated.
END
}
@@ -16,8 +16,9 @@ END
}
summary: "Computes the maximum along segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Computes a tensor such that
\\(output_i = \max_j(data_j)\\) where `max` is over `j` such
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt
index a7d85b3f4e..afdc39da96 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt
@@ -3,7 +3,7 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
first dimension. Values should be sorted and can be repeated.
END
}
@@ -16,8 +16,9 @@ END
}
summary: "Computes the mean along segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Computes a tensor such that
\\(output_i = \frac{\sum_j data_j}{N}\\) where `mean` is
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt
index 74fc598218..026b5b3991 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt
@@ -3,7 +3,7 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
first dimension. Values should be sorted and can be repeated.
END
}
@@ -16,8 +16,9 @@ END
}
summary: "Computes the minimum along segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Computes a tensor such that
\\(output_i = \min_j(data_j)\\) where `min` is over `j` such
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt
index 4c4363e524..a168eed87f 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt
@@ -3,7 +3,7 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
first dimension. Values should be sorted and can be repeated.
END
}
@@ -16,8 +16,9 @@ END
}
summary: "Computes the product along segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Computes a tensor such that
\\(output_i = \prod_j data_j\\) where the product is over `j` such
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt
index 583ab3904f..876b860824 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt
@@ -3,7 +3,7 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
first dimension. Values should be sorted and can be repeated.
END
}
@@ -16,8 +16,9 @@ END
}
summary: "Computes the sum along segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Computes a tensor such that
\\(output_i = \sum_j data_j\\) where sum is over `j` such
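In NumPy terms, the sorted segment sum described above can be sketched as follows (illustrative only, not the kernel's implementation):

```python
import numpy as np

data = np.array([5, 1, 7, 2, 3, 4])
segment_ids = np.array([0, 0, 0, 1, 2, 2])  # sorted, one id per row of data

output = np.zeros(segment_ids.max() + 1, dtype=data.dtype)
np.add.at(output, segment_ids, data)        # output_i = sum of data_j with id i
print(output)                                # [13  2  7]
```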
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentMean.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentMean.pbtxt
index 866e04e97b..138a6366c8 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentMean.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentMean.pbtxt
@@ -21,8 +21,9 @@ END
}
summary: "Computes the mean along sparse segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
dimension, selecting a subset of dimension 0, specified by `indices`.
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt
index af4bc75fa0..b8073d88ac 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt
@@ -30,7 +30,8 @@ END
Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is
missing, the `output` tensor at that position will be zeroed.
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtN.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtN.pbtxt
index 194bcea726..945bbdcf62 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtN.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtN.pbtxt
@@ -23,7 +23,8 @@ END
description: <<END
N is the size of the segment being reduced.
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt
index 8b502928a5..ff328c8a61 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt
@@ -32,7 +32,8 @@ N is the size of the segment being reduced.
Like `SparseSegmentSqrtN`, but allows missing ids in `segment_ids`. If an id is
missing, the `output` tensor at that position will be zeroed.
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSum.pbtxt
index dfd50bf273..a68e14607f 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSum.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSum.pbtxt
@@ -21,8 +21,9 @@ END
}
summary: "Computes the sum along sparse segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first
dimension, selecting a subset of dimension 0, specified by `indices`.
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt
index 3bc16577ff..aa5c1fc8d0 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt
@@ -30,8 +30,9 @@ END
Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is
missing, the `output` tensor at that position will be zeroed.
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
For example:
diff --git a/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt
new file mode 100644
index 0000000000..6d9d9908ca
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt
@@ -0,0 +1,29 @@
+op {
+ graph_op_name: "StaticRegexFullMatch"
+ in_arg {
+ name: "input"
+ description: <<END
+A string tensor of the text to be processed.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A bool tensor with the same shape as `input`.
+END
+ }
+ attr {
+ name: "pattern"
+ description: "The regular expression to match the input."
+ }
+ summary: "Check if the input matches the regex pattern."
+ description: <<END
+The input is a string tensor of any shape. The pattern is the
+regular expression to be matched with every element of the input tensor.
+The boolean values (True or False) of the output tensor indicate
+if the input matches the regex pattern provided.
+
+The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_StaticRegexReplace.pbtxt b/tensorflow/core/api_def/base_api/api_def_StaticRegexReplace.pbtxt
new file mode 100644
index 0000000000..e382bcec81
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StaticRegexReplace.pbtxt
@@ -0,0 +1,26 @@
+op {
+ graph_op_name: "StaticRegexReplace"
+ in_arg {
+ name: "input"
+ description: "The text to be processed."
+ }
+ out_arg {
+ name: "output"
+ description: "The text after applying pattern and rewrite."
+ }
+ attr {
+ name: "pattern"
+ description: "The regular expression to match the input."
+ }
+ attr {
+ name: "rewrite"
+ description: "The rewrite to be applied to the matched expresion."
+ }
+ attr {
+ name: "replace_global"
+ description: "If True, the replacement is global, otherwise the replacement\nis done only on the first match."
+ }
+ summary: "Replaces the match of pattern in input with rewrite."
+ description: "It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)"
+ visibility: HIDDEN
+}
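A quick sketch of the `replace_global` distinction, with Python's `re` module standing in for RE2:

```python
import re

text = "one fish two fish"
print(re.sub("fish", "cat", text))           # replace_global=True:  both matches
print(re.sub("fish", "cat", text, count=1))  # replace_global=False: first match only
```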
diff --git a/tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt b/tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt
index 8d6fc04847..9a89a4e8e7 100644
--- a/tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt
@@ -32,7 +32,7 @@ END
description: <<END
a bitmask where a bit i being 1 means to ignore the begin
value and instead use the largest interval possible. At runtime
-begin[i] will be replaced with `[0, n-1) if `stride[i] > 0` or
+begin[i] will be replaced with `[0, n-1)` if `stride[i] > 0` or
`[-1, n-1]` if `stride[i] < 0`
END
}
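Python slice syntax follows the same rule as the `begin_mask` behaviour described above, which makes it a convenient illustration:

```python
# With begin omitted, a positive stride starts at 0 and a negative
# stride starts at the end (index -1), i.e. the largest interval.
x = list(range(5))        # [0, 1, 2, 3, 4]
print(x[:3])              # [0, 1, 2]  -- begin treated as 0
print(x[:3:-1])           # [4]        -- begin treated as -1 (the end)
```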
diff --git a/tensorflow/core/api_def/base_api/api_def_TensorListGather.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorListGather.pbtxt
new file mode 100644
index 0000000000..3022fccb1e
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TensorListGather.pbtxt
@@ -0,0 +1,12 @@
+op {
+ graph_op_name: "TensorListGather"
+ summary: "Creates a Tensor by indexing into the TensorList."
+ description: <<END
+Each row in the produced Tensor corresponds to the element in the TensorList
+specified by the given index (see `tf.gather`).
+
+input_handle: The input tensor list.
+indices: The indices used to index into the list.
+values: The tensor.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_TensorListScatter.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorListScatter.pbtxt
new file mode 100644
index 0000000000..35194b353e
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TensorListScatter.pbtxt
@@ -0,0 +1,14 @@
+op {
+ graph_op_name: "TensorListScatter"
+ summary: "Creates a TensorList by indexing into a Tensor."
+ description: <<END
+Each member of the TensorList corresponds to one row of the input tensor,
+specified by the given index (see `tf.gather`).
+
+tensor: The input tensor.
+indices: The indices used to index into the list.
+element_shape: The shape of the elements in the list (may be less fully
+ specified than the shape of the tensor).
+output_handle: The TensorList.
+END
+}
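Taken together with `TensorListGather` above, one plausible reading of the documented semantics is that scatter places row `i` of the tensor at list position `indices[i]`, and gather stacks the listed positions back into a tensor. A hedged Python sketch of that round trip:

```python
import numpy as np

tensor = np.arange(12.0).reshape(4, 3)
indices = [2, 0, 3, 1]

# Scatter: list slot indices[i] receives row i of the input tensor.
tensor_list = [None] * len(indices)
for i, idx in enumerate(indices):
    tensor_list[idx] = tensor[i]

# Gather: stack the elements named by `indices` back into one tensor.
gathered = np.stack([tensor_list[idx] for idx in indices])
print(np.allclose(gathered, tensor))  # True: scatter then gather round-trips
```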
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsafeDiv.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsafeDiv.pbtxt
deleted file mode 100644
index 82c913d15e..0000000000
--- a/tensorflow/core/api_def/base_api/api_def_UnsafeDiv.pbtxt
+++ /dev/null
@@ -1,5 +0,0 @@
-op {
- graph_op_name: "UnsafeDiv"
- summary: "Returns 0 if the denominator is zero."
- description: ""
-}
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt
index 4ca6780c95..7a60e4387a 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt
@@ -3,33 +3,37 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
-first dimension.
-END
+A tensor whose shape is a prefix of `data.shape`.
+END
}
out_arg {
name: "output"
description: <<END
-Has same shape as data, except for dimension 0 which
-has size `num_segments`.
+Has same shape as data, except for the first `segment_ids.rank`
+dimensions, which are replaced with a single dimension which has size
+`num_segments`.
END
}
summary: "Computes the maximum along segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
This operator is similar to the unsorted segment sum operator found
[(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
Instead of computing the sum over segments, it computes the maximum such that:
-\\(output_i = \max_j data_j\\) where max is over `j` such
-that `segment_ids[j] == i`.
+\\(output_i = \max_{j...} data[j...]\\) where max is over tuples `j...` such
+that `segment_ids[j...] == i`.
If the maximum is empty for a given segment ID `i`, it outputs the smallest
possible value for the specific numeric type,
`output[i] = numeric_limits<T>::lowest()`.
+If the given segment ID `i` is negative, then the corresponding value is
+dropped, and will not be included in the result.
+
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentMax.png" alt>
</div>
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt
index 55ea69b5dd..7e139ddf4d 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt
@@ -3,31 +3,35 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
-first dimension.
+A tensor whose shape is a prefix of `data.shape`.
END
}
out_arg {
name: "output"
description: <<END
-Has same shape as data, except for dimension 0 which
-has size `num_segments`.
+Has same shape as data, except for the first `segment_ids.rank`
+dimensions, which are replaced with a single dimension which has size
+`num_segments`.
END
}
summary: "Computes the minimum along segments of a tensor."
description: <<END
-Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+for an explanation of segments.
This operator is similar to the unsorted segment sum operator found
[(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
Instead of computing the sum over segments, it computes the minimum such that:
-\\(output_i = \min_j data_j\\) where min is over `j` such
-that `segment_ids[j] == i`.
+\\(output_i = \min_{j...} data[j...]\\) where min is over tuples `j...` such
+that `segment_ids[j...] == i`.
If the minimum is empty for a given segment ID `i`, it outputs the largest
possible value for the specific numeric type,
`output[i] = numeric_limits<T>::max()`.
+
+If the given segment ID `i` is negative, then the corresponding value is
+dropped, and will not be included in the result.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt
index 577ff53d60..9c8ea3b620 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt
@@ -3,30 +3,34 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
-first dimension.
+A tensor whose shape is a prefix of `data.shape`.
END
}
out_arg {
name: "output"
description: <<END
-Has same shape as data, except for dimension 0 which
-has size `num_segments`.
+Has same shape as data, except for the first `segment_ids.rank`
+dimensions, which are replaced with a single dimension which has size
+`num_segments`.
END
}
summary: "Computes the product along segments of a tensor."
description: <<END
-Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+for an explanation of segments.
This operator is similar to the unsorted segment sum operator found
[(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
Instead of computing the sum over segments, it computes the product of all
entries belonging to a segment such that:
-\\(output_i = \prod_j data_j\\) where the product is over `j` such
-that `segment_ids[j] == i`.
+\\(output_i = \prod_{j...} data[j...]\\) where the product is over tuples
+`j...` such that `segment_ids[j...] == i`.
If there is no entry for a given segment ID `i`, it outputs 1.
+
+If the given segment ID `i` is negative, then the corresponding value is
+dropped, and will not be included in the result.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
index 9aeabd030d..7e5d9265c2 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
@@ -16,11 +16,12 @@ END
}
summary: "Computes the sum along segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Computes a tensor such that
-\\(output[i] = sum_{j...} data[j...]\\) where the sum is over tuples `j...` such
+\\(output[i] = \sum_{j...} data[j...]\\) where the sum is over tuples `j...` such
that `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids`
need not be sorted and need not cover all values in the full
range of valid values.
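A 1-D Python sketch of the unsorted variant, including the negative-id dropping that the sibling unsorted ops above document explicitly:

```python
import numpy as np

def unsorted_segment_sum(data, segment_ids, num_segments):
    """1-D sketch; negative ids are dropped, as for the other unsorted ops."""
    output = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    for value, sid in zip(data, segment_ids):
        if sid >= 0:               # a negative id contributes nothing
            output[sid] += value
    return output

data = np.array([1, 2, 3, 4])
print(unsorted_segment_sum(data, np.array([0, 2, 0, -1]), num_segments=3))
# [4 0 2]
```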
diff --git a/tensorflow/core/api_def/python_api/api_def_DivNoNan.pbtxt b/tensorflow/core/api_def/python_api/api_def_DivNoNan.pbtxt
new file mode 100644
index 0000000000..1bf3fba3c6
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_DivNoNan.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "DivNoNan"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_EnsureShape.pbtxt b/tensorflow/core/api_def/python_api/api_def_EnsureShape.pbtxt
new file mode 100644
index 0000000000..4414d973ac
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_EnsureShape.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "EnsureShape"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_FeatureStatsDataset.pbtxt b/tensorflow/core/api_def/python_api/api_def_FeatureStatsDataset.pbtxt
deleted file mode 100644
index 7f721f4fb7..0000000000
--- a/tensorflow/core/api_def/python_api/api_def_FeatureStatsDataset.pbtxt
+++ /dev/null
@@ -1,4 +0,0 @@
-op {
- graph_op_name: "FeatureStatsDataset"
- visibility: HIDDEN
-}
diff --git a/tensorflow/core/api_def/python_api/api_def_ParseExampleDataset.pbtxt b/tensorflow/core/api_def/python_api/api_def_ParseExampleDataset.pbtxt
new file mode 100644
index 0000000000..45826b6fdc
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ParseExampleDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ParseExampleDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ParseSequenceExample.pbtxt b/tensorflow/core/api_def/python_api/api_def_ParseSequenceExample.pbtxt
new file mode 100644
index 0000000000..4a7e75ba0e
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ParseSequenceExample.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ParseSequenceExample"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ScatterNdSub.pbtxt b/tensorflow/core/api_def/python_api/api_def_ScatterNdSub.pbtxt
new file mode 100644
index 0000000000..c1edef8c9d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ScatterNdSub.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ScatterNdSub"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_UnsafeDiv.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnsafeDiv.pbtxt
deleted file mode 100644
index 56caabcf3c..0000000000
--- a/tensorflow/core/api_def/python_api/api_def_UnsafeDiv.pbtxt
+++ /dev/null
@@ -1,4 +0,0 @@
-op {
- graph_op_name: "UnsafeDiv"
- visibility: HIDDEN
-}
diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc
index 637b43c844..5b01f7fa03 100644
--- a/tensorflow/core/common_runtime/base_collective_executor.cc
+++ b/tensorflow/core/common_runtime/base_collective_executor.cc
@@ -14,13 +14,28 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/base_collective_executor.h"
-#include "tensorflow/core/common_runtime/broadcaster.h"
+#include <algorithm>
+#include <functional>
+#include <utility>
+
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/ring_reducer.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
#define VALUE_IN_DEBUG_STRING false
@@ -83,7 +98,7 @@ class CollectiveAdapterImpl : public CollectiveAdapter {
// If necessary, flatten output.
void Flatten() {
- if (old_shape_.dims() > 1) {
+ if (old_shape_.dims() != 1) {
TensorShape new_shape = TensorShape({old_shape_.num_elements()});
DMAHelper::UnsafeSetShape(&output_, new_shape);
}
@@ -211,104 +226,67 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
};
Tensor* output = ctx->mutable_output(0);
- string error;
- switch (col_params.instance.type) {
- case REDUCTION_COLLECTIVE: {
- // TODO(tucker): support other reduction algorithms,
- // e.g. tree-reduce, hybrid tree/ring, delegate-to-NCCL, etc.
- const Tensor* input = &ctx->input(0);
- RingReducer* reducer =
- CreateReducer(ctx, CtxParams(ctx), col_params, exec_key, step_id_,
- input, output, &error);
- if (!reducer) {
- done_safe(errors::Internal(error));
- return;
- }
- // Run in an I/O thread, so as not to starve the executor threads.
- // TODO(tucker): Instead of forking every per-device Collective
- // Op off into its own thread, consider queuing them on a
- // fixed-size thread-pool dedicated to running CollectiveOps.
- SchedClosure([reducer, done_safe]() {
- reducer->Run([reducer, done_safe](const Status& s) {
- done_safe(s);
- delete reducer;
- });
- });
- } break;
-
- case BROADCAST_COLLECTIVE: {
- Broadcaster* broadcaster = CreateBroadcaster(
- ctx, CtxParams(ctx), col_params, exec_key, step_id_, output, &error);
- if (!broadcaster) {
- done_safe(errors::Internal(error));
- return;
- }
- // Run in an I/O thread, so as not to starve the executor threads.
- SchedClosure([broadcaster, done_safe]() {
- broadcaster->Run([broadcaster, done_safe](const Status& s) {
- done_safe(s);
- delete broadcaster;
- });
- });
- } break;
-
- default:
- done_safe(errors::Internal("Unimplemented CollectiveType ",
- col_params.instance.type));
+ const Tensor* input = (col_params.instance.type == REDUCTION_COLLECTIVE ||
+ (col_params.instance.type == BROADCAST_COLLECTIVE &&
+ col_params.is_source))
+ ? &ctx->input(0)
+ : nullptr;
+ CollectiveImplementationInterface* col_impl = nullptr;
+ Status status = CreateCollective(col_params, &col_impl);
+ if (!status.ok()) {
+ done_safe(status);
+ DCHECK_EQ(nullptr, col_impl);
+ return;
}
-}
-
-RingReducer* BaseCollectiveExecutor::CreateReducer(
- OpKernelContext* ctx, OpKernelContext::Params* params,
- const CollectiveParams& col_params, const string& exec_key, int64 step_id,
- const Tensor* input, Tensor* output, string* error) {
- switch (col_params.instance.data_type) {
- case DT_INT32:
- if (col_params.group.device_type == DEVICE_GPU) {
- *error =
- "Collective Reduce does not support datatype DT_INT32 on "
- "DEVICE_GPU";
- return nullptr;
- }
- TF_FALLTHROUGH_INTENDED;
- case DT_FLOAT:
- case DT_DOUBLE:
- case DT_INT64:
- return new RingReducer(this, dev_mgr_, ctx, params, col_params, exec_key,
- step_id, input, output);
- break;
- default:
- *error = strings::StrCat("Collective Reduce does not support datatype ",
- col_params.instance.data_type);
- return nullptr;
+ CollectiveContext* col_ctx =
+ new CollectiveContext(this, dev_mgr_, ctx, CtxParams(ctx), col_params,
+ exec_key, step_id_, input, output);
+ status = col_impl->InitializeCollectiveContext(col_ctx);
+ if (!status.ok()) {
+ done_safe(status);
+ delete col_ctx;
+ delete col_impl;
+ return;
}
+ // Run in an I/O thread, so as not to starve the executor threads.
+ // TODO(b/80529858): Instead of forking every per-device Collective
+ // Op off into its own thread, consider queuing them on a
+ // fixed-size thread-pool dedicated to running CollectiveOps.
+ SchedClosure([col_impl, col_ctx, done_safe]() {
+ col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) {
+ done_safe(s);
+ delete col_ctx;
+ delete col_impl;
+ });
+ });
}
-Broadcaster* BaseCollectiveExecutor::CreateBroadcaster(
- OpKernelContext* ctx, OpKernelContext::Params* params,
- const CollectiveParams& col_params, const string& exec_key, int64 step_id,
- Tensor* output, string* error) {
+Status BaseCollectiveExecutor::CreateCollective(
+ const CollectiveParams& col_params,
+ CollectiveImplementationInterface** col_impl) {
+ *col_impl = nullptr;
+ Status status;
switch (col_params.instance.data_type) {
case DT_INT32:
if (col_params.group.device_type == DEVICE_GPU) {
- *error =
- "Collective Broadcast does not support datatype DT_INT32 on "
- "DEVICE_GPU";
- return nullptr;
+ status = errors::Internal(
+ "CollectiveImplementation does not support datatype DT_INT32 on "
+ "DEVICE_GPU");
}
TF_FALLTHROUGH_INTENDED;
case DT_FLOAT:
case DT_DOUBLE:
case DT_INT64: {
- return new Broadcaster(this, dev_mgr_, ctx, params, col_params, exec_key,
- step_id, output);
- } break;
+ status = CollectiveRegistry::Lookup(
+ col_params.instance.impl_details.collective_name, col_impl);
+ break;
+ }
default:
- *error =
- strings::StrCat("Collective Broadcast does not support datatype ",
- DataTypeString(col_params.instance.data_type));
- return nullptr;
+ status = errors::Internal(
+ "CollectiveImplementation does not support datatype ",
+ col_params.instance.data_type);
}
+ return status;
}
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/base_collective_executor.h b/tensorflow/core/common_runtime/base_collective_executor.h
index 3af9286264..360ce4db7b 100644
--- a/tensorflow/core/common_runtime/base_collective_executor.h
+++ b/tensorflow/core/common_runtime/base_collective_executor.h
@@ -15,15 +15,17 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_
+#include <memory>
#include <string>
+
#include "tensorflow/core/common_runtime/buf_rendezvous.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
namespace tensorflow {
-class Broadcaster;
+class CollectiveImplementation;
class DeviceMgr;
-class RingReducer;
+class Device;
// Helper interface that aliases regular subfields of a Tensor as separate
// Tensors for in-place update.
@@ -133,18 +135,8 @@ class BaseCollectiveExecutor : public CollectiveExecutor {
std::unique_ptr<PerStepCollectiveRemoteAccess> remote_access_;
private:
- RingReducer* CreateReducer(OpKernelContext* ctx,
- OpKernelContext::Params* params,
- const CollectiveParams& col_params,
- const string& exec_key, int64 step_id,
- const Tensor* input, Tensor* output,
- string* error);
-
- Broadcaster* CreateBroadcaster(OpKernelContext* ctx,
- OpKernelContext::Params* params,
- const CollectiveParams& col_params,
- const string& exec_key, int64 step_id,
- Tensor* output, string* error);
+ Status CreateCollective(const CollectiveParams& col_params,
+ CollectiveImplementationInterface** col_impl);
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc
index 3bf0532491..84c6285bbe 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/bfc_allocator.cc
@@ -596,7 +596,7 @@ string BFCAllocator::RenderOccupancy() {
region_offset += region.memory_size();
}
- return std::string(rendered, resolution);
+ return string(rendered, resolution);
}
void BFCAllocator::DumpMemoryLog(size_t num_bytes) {
diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h
index 580e61e2ea..20e1dab1d5 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.h
+++ b/tensorflow/core/common_runtime/bfc_allocator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_BFC_ALLOCATOR_H_
-#define TENSORFLOW_COMMON_RUNTIME_BFC_ALLOCATOR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BFC_ALLOCATOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_BFC_ALLOCATOR_H_
#include <array>
#include <memory>
@@ -451,4 +451,4 @@ class BFCAllocator : public VisitableAllocator {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_BFC_ALLOCATOR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_BFC_ALLOCATOR_H_
diff --git a/tensorflow/core/common_runtime/broadcaster.cc b/tensorflow/core/common_runtime/broadcaster.cc
deleted file mode 100644
index e1c6b21939..0000000000
--- a/tensorflow/core/common_runtime/broadcaster.cc
+++ /dev/null
@@ -1,300 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/common_runtime/broadcaster.h"
-
-#include "tensorflow/core/common_runtime/collective_rma_local.h"
-#include "tensorflow/core/common_runtime/device_mgr.h"
-#include "tensorflow/core/common_runtime/dma_helper.h"
-#include "tensorflow/core/lib/core/notification.h"
-#include "tensorflow/core/platform/env.h"
-
-// Set true for greater intelligibility of debug mode log messages.
-#define READABLE_KEYS false
-
-namespace tensorflow {
-
-namespace {
-// Key to be used for BufRendezvous by Broadcaster.
-string BroadcastBufKey(const string& exec_key, int subdiv, int src_rank,
- int dst_rank) {
- if (READABLE_KEYS) {
- return strings::StrCat("broadcast(", exec_key, "):subdiv(", subdiv,
- "):src(", src_rank, "):dst(", dst_rank, ")");
- } else {
- // TODO(tucker): Try a denser format, e.g. a 64 or 128 bit hash.
- return strings::StrCat(exec_key, ":", subdiv, ":", src_rank, ":", dst_rank);
- }
-}
-} // namespace
-
-Broadcaster::Broadcaster(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
- OpKernelContext* ctx, OpKernelContext::Params* params,
- const CollectiveParams& col_params,
- const string& exec_key, int64 step_id, Tensor* output)
- : col_exec_(col_exec),
- dev_mgr_(dev_mgr),
- ctx_(ctx),
- col_params_(col_params),
- exec_key_(exec_key),
- rank_(col_params.subdiv_rank[0]),
- is_source_(col_params.is_source),
- output_(output),
- done_(nullptr),
- device_(nullptr) {}
-
-void Broadcaster::Run(StatusCallback done) {
- // The optimal data transfer choreography is going to very platform dependent.
- // That will be addressed by later improvements here or by platform-specific
- // overrides of collective broadcast. The initial version is simply
- // a binary tree that completely ignores DeviceLocality.
- done_ = std::move(done);
-
- // Get the device for which we're executing and look up its locality.
- status_ = dev_mgr_->LookupDevice(
- col_params_.instance.device_names[col_params_.default_rank], &device_);
- if (!status_.ok()) {
- done_(status_);
- return;
- }
- CHECK(device_);
- device_locality_ = device_->attributes().locality();
-
- RunTree();
-}
-
-// Binary tree parent/child relations are trivial to calculate, i.e.
-// device at rank r is the parent of 2r+1 and 2r+2. The one exception
-// is if the source is not rank 0. We treat that case as though the
-// source is appended to the front of the rank ordering as well as
-// continuing to occupy its current position. Hence we calculate as
-// though each device's rank is actually r+1, then subtract 1 again to
-// get the descendent ranks. If the source is not rank 0 then its
-// descendants include both {0,1} and the descendents of its current
-// position. Where a non-0-rank source is a descendent of another
-// device, no send to it is necessary.
-
-/* static*/
-int Broadcaster::TreeRecvFrom(const CollectiveParams& cp, int subdiv) {
- DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
- int my_rank = cp.subdiv_rank[subdiv];
- if (-1 == my_rank) return -1;
-
- const auto& impl = cp.instance.impl_details;
- DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
- int source_rank = impl.subdiv_source_rank[subdiv];
- if (my_rank == source_rank) return -1;
- if (source_rank == 0) {
- return (my_rank - 1) / 2;
- } else {
- int predecessor_rank = (my_rank / 2) - 1;
- return (predecessor_rank < 0) ? source_rank : predecessor_rank;
- }
-}
-
-/* static */
-void Broadcaster::TreeSendTo(const CollectiveParams& cp, int subdiv,
- std::vector<int>* targets) {
- DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
- int my_rank = cp.subdiv_rank[subdiv];
- if (-1 == my_rank) return;
-
- const auto& impl = cp.instance.impl_details;
- DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
- int source_rank = impl.subdiv_source_rank[subdiv];
-
- int group_size = 0;
- for (int i = 0; i < impl.subdiv_permutations[subdiv].size(); i++) {
- if (impl.subdiv_permutations[subdiv][i] >= 0) {
- group_size++;
- }
- }
-
- targets->clear();
- int successor_rank = 0;
- if (source_rank == 0) {
- successor_rank = (2 * my_rank) + 1;
- } else {
- successor_rank = (2 * (my_rank + 1));
- }
- DCHECK_NE(successor_rank, my_rank);
- if (cp.is_source && source_rank != 0) {
- // The source sends to rank 0,1 in addition to its positional
- // descendants.
- if (group_size > 1) {
- targets->push_back(0);
- }
- if (group_size > 2 && source_rank != 1) {
- targets->push_back(1);
- }
- }
- for (int i = 0; i < 2; ++i) {
- if (successor_rank < group_size && successor_rank != source_rank) {
- targets->push_back(successor_rank);
- }
- ++successor_rank;
- }
-}
-
-// Executes a hierarchical tree broadcast.
-// Each subdiv is a broadcast between a subset of the devices.
-// If there is only one task, there is one subdiv comprising a broadcast between
-// all devices belonging to the task.
-// If there are n tasks, n>1, then there are n+1 subdivs. In the first (global)
-// subdiv, one device from each task participates in a binary tree broadcast.
-// Each task receives a copy of the tensor on one device via this broadcast.
-// Subsequent subdivs correspond to intra-task broadcasts. Subdiv i+1
-// corresponds to broadcast between all devices on task i. Thus, each task
-// participates in at most 2 subdivs.
-void Broadcaster::RunTree() {
- int num_subdivs = static_cast<int>(col_params_.subdiv_rank.size());
- // TODO(ayushd): this is easily improved when a node participates in both
- // first and second subdivision. It would first send to its descendents in
- // the first subdiv, then wait until all pending ops are finished before
- // sending to descendents in second subdiv. A better implementation would
- // collapse the two send blocks.
- for (int si = 0; si < num_subdivs; si++) {
- int my_rank = col_params_.subdiv_rank[si];
- // If rank is -1, this device does not participate in this subdiv.
- if (-1 == my_rank) continue;
- int source_rank = col_params_.instance.impl_details.subdiv_source_rank[si];
- if (VLOG_IS_ON(1)) {
- string subdiv_buf;
- for (int r : col_params_.instance.impl_details.subdiv_permutations[si]) {
- strings::StrAppend(&subdiv_buf, r, ",");
- }
- VLOG(1) << "Running Broadcast tree device=" << device_->name()
- << " subdiv=" << si << " perm=" << subdiv_buf
- << " my_rank=" << my_rank << " source_rank=" << source_rank;
- }
-
- mutex mu; // also guards status_ while callbacks are pending
- int pending_count = 0; // GUARDED_BY(mu)
- condition_variable all_done;
-
- if (my_rank >= 0 && my_rank != source_rank) {
- // Begin by receiving the value.
- int recv_from_rank = TreeRecvFrom(col_params_, si);
- Notification note;
- DispatchRecv(si, recv_from_rank, my_rank, output_,
- [this, &mu, &note](const Status& s) {
- mutex_lock l(mu);
- status_.Update(s);
- note.Notify();
- });
- note.WaitForNotification();
- }
-
- // Then forward value to all descendent devices.
- if (my_rank >= 0 && status_.ok()) {
- std::vector<int> send_to_ranks;
- TreeSendTo(col_params_, si, &send_to_ranks);
- for (int i = 0; i < send_to_ranks.size(); ++i) {
- int target_rank = send_to_ranks[i];
- {
- mutex_lock l(mu);
- ++pending_count;
- }
- DispatchSend(si, target_rank, my_rank,
- (is_source_ ? &ctx_->input(0) : output_),
- [this, &mu, &pending_count, &all_done](const Status& s) {
- mutex_lock l(mu);
- status_.Update(s);
- --pending_count;
- if (pending_count == 0) {
- all_done.notify_all();
- }
- });
- }
- }
-
- // For the original source device, we copy input to output if they are
- // different.
- // If there is only 1 subdiv, we do this in that subdiv. If there is more
- // than 1 subdiv, then the original source device will participate in 2
- // subdivs - the global inter-task broadcast and one local intra-task
- // broadcast. In this case, we perform the copy in the second subdiv for
- // this device.
- if (status_.ok() && is_source_ && (1 == num_subdivs || 0 != si)) {
- VLOG(2) << "copying input to output for device=" << device_->name()
- << " subdiv=" << si;
- const Tensor* input = &ctx_->input(0);
- if (input != output_ &&
- (DMAHelper::base(input) != DMAHelper::base(output_))) {
- {
- mutex_lock l(mu);
- ++pending_count;
- }
- DeviceContext* op_dev_ctx = ctx_->op_device_context();
- CollectiveRemoteAccessLocal::MemCpyAsync(
- op_dev_ctx, op_dev_ctx, device_, device_, ctx_->input_alloc_attr(0),
- ctx_->output_alloc_attr(0), input, output_, 0, /*stream_index*/
- [this, &mu, &pending_count, &all_done](const Status& s) {
- mutex_lock l(mu);
- status_.Update(s);
- --pending_count;
- if (0 == pending_count) {
- all_done.notify_all();
- }
- });
- }
- }
-
- // Then wait for all pending actions to complete.
- {
- mutex_lock l(mu);
- if (pending_count > 0) {
- all_done.wait(l);
- }
- }
- }
- VLOG(2) << "device=" << device_->name() << " return status " << status_;
- done_(status_);
-}
-
-void Broadcaster::DispatchSend(int subdiv, int dst_rank, int src_rank,
- const Tensor* src_tensor,
- const StatusCallback& done) {
- string send_buf_key = BroadcastBufKey(exec_key_, subdiv, src_rank, dst_rank);
- int dst_idx =
- col_params_.instance.impl_details.subdiv_permutations[subdiv][dst_rank];
- VLOG(1) << "DispatchSend " << send_buf_key << " from_device "
- << device_->name() << " to_device "
- << col_params_.instance.device_names[dst_idx] << " subdiv=" << subdiv
- << " dst_rank=" << dst_rank << " dst_idx=" << dst_idx;
- col_exec_->PostToPeer(col_params_.instance.device_names[dst_idx],
- col_params_.instance.task_names[dst_idx], send_buf_key,
- device_, ctx_->op_device_context(),
- ctx_->output_alloc_attr(0), src_tensor,
- device_locality_, done);
-}
-
-void Broadcaster::DispatchRecv(int subdiv, int src_rank, int dst_rank,
- Tensor* dst_tensor, const StatusCallback& done) {
- string recv_buf_key = BroadcastBufKey(exec_key_, subdiv, src_rank, dst_rank);
- int src_idx =
- col_params_.instance.impl_details.subdiv_permutations[subdiv][src_rank];
- VLOG(1) << "DispatchRecv " << recv_buf_key << " from_device "
- << col_params_.instance.device_names[src_idx] << " to_device "
- << device_->name() << " subdiv=" << subdiv << " src_rank=" << src_rank
- << " src_idx=" << src_idx;
- col_exec_->RecvFromPeer(col_params_.instance.device_names[src_idx],
- col_params_.instance.task_names[src_idx],
- col_params_.task.is_local[src_idx], recv_buf_key,
- device_, ctx_->op_device_context(),
- ctx_->output_alloc_attr(0), dst_tensor,
- device_locality_, 0 /*stream_index*/, done);
-}
-
-} // namespace tensorflow
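The comments in the deleted file describe the binary-tree rank arithmetic; per the new include in base_collective_executor.cc, that role passes to the hierarchical tree broadcaster. For reference, a Python restatement of the deleted `TreeRecvFrom` rule only, a sketch of the removed logic rather than of the replacement code:

```python
def tree_recv_from(my_rank, source_rank):
    """Parent rank in the binary-tree broadcast, per the deleted comment."""
    if my_rank == source_rank:
        return -1                     # the source receives from no one
    if source_rank == 0:
        return (my_rank - 1) // 2     # ordinary heap parent
    predecessor = (my_rank // 2) - 1  # ranks shifted as if the source led the order
    return source_rank if predecessor < 0 else predecessor

print([tree_recv_from(r, 0) for r in range(1, 7)])  # [0, 0, 1, 1, 2, 2]
```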
diff --git a/tensorflow/core/common_runtime/buf_rendezvous.h b/tensorflow/core/common_runtime/buf_rendezvous.h
index 9eb9f060f6..065bbd008b 100644
--- a/tensorflow/core/common_runtime/buf_rendezvous.h
+++ b/tensorflow/core/common_runtime/buf_rendezvous.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
-#define TENSORFLOW_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
#include <functional>
#include <string>
@@ -100,4 +100,4 @@ class BufRendezvous {
void PurgeTable(const Status& s, HookTable* table);
};
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.h b/tensorflow/core/common_runtime/collective_executor_mgr.h
index 9de6ab8968..d53aca85b9 100644
--- a/tensorflow/core/common_runtime/collective_executor_mgr.h
+++ b/tensorflow/core/common_runtime/collective_executor_mgr.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_
-#define TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -72,4 +72,4 @@ class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
};
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
index 2a14493a67..3b2dc6a050 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -14,7 +14,20 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+#include <stddef.h>
+#include <algorithm>
+#include <unordered_map>
+#include <utility>
+
#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@@ -179,7 +192,9 @@ void OrderTaskDeviceMap(TaskDeviceMap* tdm) {
int next_rank = 0;
while (true) {
selected.insert(next_device);
- DevRec* dr = &(*tdm)[next_device];
+ auto next_dev_it = tdm->find(next_device);
+ CHECK(next_dev_it != tdm->end());
+ DevRec* dr = &next_dev_it->second;
dr->local_rank = next_rank;
++next_rank;
if (selected.size() == tdm->size()) {
@@ -193,9 +208,15 @@ void OrderTaskDeviceMap(TaskDeviceMap* tdm) {
parsed_name.id = il.device_id();
string endpoint_device =
DeviceNameUtils::ParsedNameToString(parsed_name);
+ // Skip the device if we've already seen it.
if (selected.find(endpoint_device) != selected.end()) {
continue;
}
+ // Skip the device if it is not participating in this collective
+ // instance.
+ if (tdm->find(endpoint_device) == tdm->end()) {
+ continue;
+ }
if (best_link == nullptr || il.strength() > best_link->strength()) {
best_link = &il;
}
@@ -319,206 +340,6 @@ void SortDevicesAndTasks(CollectiveParams* cp) {
}
} // namespace
-int GetDeviceTask(int device_rank, const std::vector<int>& dev_per_task) {
- int num_tasks = static_cast<int>(dev_per_task.size());
- int task_lo = 0;
- int task_hi;
- for (int ti = 0; ti < num_tasks; ti++) {
- task_hi = task_lo + dev_per_task[ti];
- if (task_lo <= device_rank && device_rank < task_hi) return ti;
- task_lo += dev_per_task[ti];
- }
- LOG(FATAL) << "Unexpected device rank " << device_rank << " for " << task_hi
- << " devices";
- return -1;
-}
-
-void CollectiveParamResolverLocal::GenerateBcastSubdivPerms(
- const string& device, int source_rank, const std::vector<int>& dev_per_task,
- CollectiveParams* cp) {
- if (VLOG_IS_ON(1)) {
- string dpt_buf;
- for (int dpt : dev_per_task) strings::StrAppend(&dpt_buf, dpt, ";");
- VLOG(1) << "GenerateBcastSubdivPerms device=" << device
- << " source_rank=" << source_rank << " dev_per_task=" << dpt_buf;
- }
- int num_tasks = cp->group.num_tasks;
- // If there is just 1 task, then execute binary tree broadcast over all
- // devices. Otherwise, the first subdiv is inter-task broadcast, and then
- // there are N more subdivs, where N is #task.
- int num_subdivs = num_tasks + (num_tasks > 1 ? 1 : 0);
- int total_num_devices = 0;
- for (int num_dev : dev_per_task) total_num_devices += num_dev;
-
- cp->instance.impl_details.subdiv_permutations.resize(num_subdivs);
- cp->subdiv_rank.reserve(num_subdivs);
- cp->instance.impl_details.subdiv_source_rank.reserve(num_subdivs);
-
- // Inter-task subdiv. Pick one device from each task - this is the source
- // device if it belongs to that task, or device 0 for that task. If a device
- // does not participate in the subdiv, set subdiv_rank to -1.
- if (num_tasks > 1) {
- const int sdi = 0;
- std::vector<int>& perm = cp->instance.impl_details.subdiv_permutations[sdi];
- CHECK_EQ(perm.size(), 0);
- int device_count = 0;
- int source_task = GetDeviceTask(source_rank, dev_per_task);
- for (int ti = 0; ti < cp->group.num_tasks; ti++) {
- bool participate = false;
- if (source_task == ti) {
- // Source device belongs to this task.
- perm.push_back(source_rank);
- participate = cp->instance.device_names[source_rank] == device;
- } else {
- // Source does not belong to this task, choose dev 0.
- perm.push_back(device_count);
- participate = cp->instance.device_names[device_count] == device;
- }
- if (participate) cp->subdiv_rank.push_back(ti);
- device_count += dev_per_task[ti];
- }
- if (cp->subdiv_rank.empty()) cp->subdiv_rank.push_back(-1);
- cp->instance.impl_details.subdiv_source_rank.push_back(source_task);
- }
-
- // Intra-task subdivs. Pick all devices in task ti for subdiv sdi. Set
- // source to dev 0 for that task if it does not contain original source, else
- // set to rank of original source. If a device does not participate in the
- // subdiv, set subdiv_rank to -1;
- int abs_di = 0;
- for (int ti = 0; ti < cp->group.num_tasks; ti++) {
- const int sdi = ti + (num_tasks > 1 ? 1 : 0);
- std::vector<int>& perm = cp->instance.impl_details.subdiv_permutations[sdi];
- CHECK_EQ(perm.size(), 0);
- bool participate = false;
- int subdiv_source = 0;
- for (int di = 0; di < dev_per_task[ti]; di++) {
- perm.push_back(abs_di);
- if (cp->instance.device_names[abs_di] == device) {
- participate = true;
- cp->subdiv_rank.push_back(di);
- }
- if (abs_di == source_rank) subdiv_source = di;
- abs_di++;
- }
- if (!participate) cp->subdiv_rank.push_back(-1);
- cp->instance.impl_details.subdiv_source_rank.push_back(subdiv_source);
- }
-
- for (int sri = 0; sri < num_subdivs; sri++) {
- CHECK_GE(cp->instance.impl_details.subdiv_source_rank[sri], 0);
- }
-}
-
-// Establish the requested number of subdivision permutations based on the
-// ring order implicit in the device order.
-/*static*/
-void CollectiveParamResolverLocal::GenerateSubdivPerms(const string& device,
- int source_rank,
- CollectiveParams* cp) {
- // Each subdiv permutation is a ring formed by rotating each
- // single-task subsequence of devices by an offset. This makes most
- // sense when each task has the same number of devices but we can't
- // depend on that being the case so we'll compute something that
- // works in any case.
-
- // Start by counting the devices in each task.
- // Precondition: device_names must be sorted so that all devices in
- // the same task are adjacent.
- VLOG(2) << "Sorted task names: "
- << str_util::Join(cp->instance.task_names, ", ");
- std::vector<int> dev_per_task;
- const string* prior_task_name = &cp->instance.task_names[0];
- int dev_count = 1;
- for (int di = 1; di < cp->group.group_size; ++di) {
- if (cp->instance.task_names[di] != *prior_task_name) {
- dev_per_task.push_back(dev_count);
- dev_count = 1;
- prior_task_name = &cp->instance.task_names[di];
- } else {
- ++dev_count;
- }
- }
- dev_per_task.push_back(dev_count);
- CHECK_EQ(cp->group.num_tasks, dev_per_task.size());
-
- CHECK(cp->instance.type == REDUCTION_COLLECTIVE ||
- cp->instance.type == BROADCAST_COLLECTIVE);
- if (cp->instance.type == REDUCTION_COLLECTIVE) {
- // Generate a ring permutation for each requested offset.
- CHECK_GT(cp->instance.impl_details.subdiv_offsets.size(), 0);
- VLOG(2) << "Setting up perms for cp " << cp << " subdiv_permutations "
- << &cp->instance.impl_details.subdiv_permutations;
- cp->instance.impl_details.subdiv_permutations.resize(
- cp->instance.impl_details.subdiv_offsets.size());
- cp->subdiv_rank.resize(cp->instance.impl_details.subdiv_offsets.size(), -1);
- for (int sdi = 0; sdi < cp->instance.impl_details.subdiv_offsets.size();
- ++sdi) {
- std::vector<int>& perm =
- cp->instance.impl_details.subdiv_permutations[sdi];
- CHECK_EQ(perm.size(), 0);
- int offset = cp->instance.impl_details.subdiv_offsets[sdi];
- // A negative subdivision offset is interpreted as follows:
- // 1. Reverse the local device ordering.
- // 2. Begin the subdivision at abs(offset) in the reversed ordering.
- bool reverse = false;
- if (offset < 0) {
- offset = abs(offset);
- reverse = true;
- }
- int prior_dev_count = 0; // sum over prior worker device counts
- for (int ti = 0; ti < cp->group.num_tasks; ++ti) {
- for (int di = 0; di < dev_per_task[ti]; ++di) {
- int di_offset = (di + offset) % dev_per_task[ti];
- int offset_di =
- reverse ? (dev_per_task[ti] - (di_offset + 1)) : di_offset;
- // Device index in global subdivision permutation.
- int permuted_di = prior_dev_count + offset_di;
- int rank = static_cast<int>(perm.size());
- perm.push_back(permuted_di);
- if (cp->instance.device_names[permuted_di] == device) {
- CHECK_EQ(permuted_di, cp->default_rank);
- cp->subdiv_rank[sdi] = rank;
- }
- }
- prior_dev_count += dev_per_task[ti];
- }
- CHECK_EQ(cp->group.group_size, perm.size());
- }
- } else if (cp->instance.type == BROADCAST_COLLECTIVE) {
- GenerateBcastSubdivPerms(device, source_rank, dev_per_task, cp);
- }
-
- if (VLOG_IS_ON(1)) {
- // Log the computed ring order for each subdiv.
- string buf;
- for (int sdi = 0;
- sdi < cp->instance.impl_details.subdiv_permutations.size(); ++sdi) {
- buf = strings::StrCat("Subdiv ", sdi, " device order:\n");
- for (int di = 0;
- di < cp->instance.impl_details.subdiv_permutations[sdi].size();
- ++di) {
- int idx = cp->instance.impl_details.subdiv_permutations[sdi][di];
- if (idx >= 0) {
- CHECK_GT(cp->instance.device_names.size(), idx);
- strings::StrAppend(&buf, cp->instance.device_names[idx], "\n");
- }
- }
- strings::StrAppend(&buf, " subdiv_offsets: ");
- for (auto o : cp->instance.impl_details.subdiv_offsets)
- strings::StrAppend(&buf, o, " ");
- strings::StrAppend(&buf, " SubdivRank: ");
- for (auto d : cp->subdiv_rank) strings::StrAppend(&buf, d, " ");
- if (cp->instance.type == BROADCAST_COLLECTIVE) {
- strings::StrAppend(&buf, " subdiv_source_rank: ");
- for (auto src : cp->instance.impl_details.subdiv_source_rank)
- strings::StrAppend(&buf, src, " ");
- }
- VLOG(1) << buf;
- }
- }
-}
-
void CollectiveParamResolverLocal::CompleteTaskIsLocal(const string& task_name,
CollectiveParams* cp) {
cp->task.is_local.resize(cp->group.group_size, false);
@@ -594,6 +415,10 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams(
});
}
+// NOTE(ayushd): The DeviceLocality objects in localities have LocalLinks to
+// all devices to which they are physically connected and which are visible
+// to the TensorFlow runtime. This set may be a superset of the devices
+// participating in this collective instance.
void CollectiveParamResolverLocal::CompleteDefaultRanking(
const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
const std::vector<DeviceLocality>& localities) {
@@ -785,29 +610,39 @@ void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
// Populate the fields common across task, also default_rank.
SetDefaultRank(device, cp);
CompleteTaskIsLocal(task_name_, cp);
+ // TODO(b/113171733): we need a better way to pick the collective
+ // implementation. Ideally the choice would take the topology and link
+ // strength into account.
+ cp->instance.impl_details.collective_name =
+ (cp->instance.type == BROADCAST_COLLECTIVE) ? "HierarchicalTreeBroadcast"
+ : "RingReduce";
+ CollectiveImplementationInterface* col_impl;
+ Status lookup_status = CollectiveRegistry::LookupParamResolverInstance(
+ cp->instance.impl_details.collective_name, &col_impl);
+ if (!lookup_status.ok()) {
+ done(lookup_status);
+ return;
+ }
// If broadcast, may need to wait for source discovery.
if (cp->instance.type == BROADCAST_COLLECTIVE) {
CompleteInstanceSource(ir, cp, is_source,
- [this, ir, device, cp, done](InstanceRec* irec) {
+ [col_impl, ir, device, cp, done](InstanceRec* irec) {
CHECK_EQ(ir, irec);
Status s;
- int source_rank;
{
mutex_lock l(irec->out_mu);
irec->WaitForOutMu(l);
s = irec->status;
- source_rank = irec->source_rank;
+ cp->source_rank = irec->source_rank;
}
if (s.ok()) {
- GenerateSubdivPerms(device, source_rank, cp);
+ s = col_impl->InitializeCollectiveParams(cp);
}
done(s);
});
- return;
} else {
- GenerateSubdivPerms(device, 0, cp);
+ done(col_impl->InitializeCollectiveParams(cp));
}
- done(Status::OK());
}
void CollectiveParamResolverLocal::CompleteInstanceSource(InstanceRec* ir,
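
The hunks above replace the resolver's built-in GenerateSubdivPerms path with a name-based lookup: the resolver picks "HierarchicalTreeBroadcast" or "RingReduce", asks the registry for a param-resolver instance, and delegates to InitializeCollectiveParams. A standalone sketch of that dispatch pattern follows; the Registry and Impl types here are illustrative stand-ins, not the actual TensorFlow classes:

#include <functional>
#include <map>
#include <string>

// Stand-in for CollectiveParams: only what the sketch needs.
struct Params {
  bool is_broadcast = false;
};

// Stand-in for CollectiveImplementationInterface.
class Impl {
 public:
  virtual ~Impl() = default;
  // Returns true on success, mirroring Status::ok().
  virtual bool InitializeCollectiveParams(Params* cp) = 0;
};

// A registry mapping implementation names to factories.
std::map<std::string, std::function<Impl*()>>& Registry() {
  static auto* r = new std::map<std::string, std::function<Impl*()>>;
  return *r;
}

// The dispatch step from the hunk: choose a name from the instance type,
// look it up, and let the implementation fill in the params.
bool CompleteParams(Params* cp) {
  const std::string name =
      cp->is_broadcast ? "HierarchicalTreeBroadcast" : "RingReduce";
  auto it = Registry().find(name);
  if (it == Registry().end()) return false;  // lookup failed
  Impl* impl = it->second();
  const bool ok = impl->InitializeCollectiveParams(cp);
  delete impl;
  return ok;
}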
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h
index 2e2aa801d9..c5c3497e28 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.h
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h
@@ -12,10 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
-#define TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
+#include <functional>
+#include <memory>
+#include <set>
#include <string>
+#include <vector>
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -79,6 +83,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
// Used to complete/verify CollInstance.
struct InstanceRec;
+
typedef std::function<void(InstanceRec*)> IRConsumer;
struct InstanceRec {
// This structure has two mutexes so that a possibly long
@@ -212,18 +217,6 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
void CallbackWithStatus(const InstanceRecCallback& done, InstanceRec* irec)
LOCKS_EXCLUDED(irec->out_mu);
- friend class CollectiveParamResolverLocalTest;
- // Establishes the requested number of subdivision permutations based on the
- // ring order implicit in the device order.
- static void GenerateSubdivPerms(const string& device, int source_rank,
- CollectiveParams* cp);
- // Establishes the subdivisions for broadcast op. The first subdiv executes
- // binary tree bcast with one device per task. Each subsequent subdiv
- // executes intra-task binary tree broadcast.
- static void GenerateBcastSubdivPerms(const string& device, int source_rank,
- const std::vector<int>& dev_per_task,
- CollectiveParams* cp);
-
const DeviceMgr* dev_mgr_;
DeviceResolverInterface* dev_resolver_; // Not owned.
string task_name_;
@@ -237,4 +230,4 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
index 9ea23b72d2..9e1e2e8d5b 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
@@ -44,31 +44,6 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
task_name));
}
- void GenSubdivPerms(const string& device, int source_rank,
- CollectiveParams* cp) {
- CollectiveParamResolverLocal::GenerateSubdivPerms(device, source_rank, cp);
- }
-
- // Calls GenerateBcastSubdivPerms for device at `device_rank`. Checks if the
- // generated subdiv perms, ranks, and source ranks match the expected values.
- void BcastSubdivPerms(
- CollectiveParams* cp, const std::vector<int>& dev_per_task,
- int device_rank, int source_rank,
- const std::vector<std::vector<int>>& expected_subdiv_perms,
- const std::vector<int>& expected_subdiv_rank,
- const std::vector<int>& expected_subdiv_source_rank) {
- cp->subdiv_rank.clear();
- cp->instance.impl_details.subdiv_permutations.clear();
- cp->instance.impl_details.subdiv_source_rank.clear();
- CollectiveParamResolverLocal::GenerateBcastSubdivPerms(
- cp->instance.device_names[device_rank], source_rank, dev_per_task, cp);
- EXPECT_EQ(expected_subdiv_perms,
- cp->instance.impl_details.subdiv_permutations);
- EXPECT_EQ(expected_subdiv_rank, cp->subdiv_rank);
- EXPECT_EQ(expected_subdiv_source_rank,
- cp->instance.impl_details.subdiv_source_rank);
- }
-
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<DeviceResolverLocal> drl_;
@@ -114,7 +89,6 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) {
cps[i].instance.device_names[j]);
EXPECT_TRUE(cps[i].task.is_local[j]);
}
- EXPECT_EQ(cps[i].subdiv_rank[0], i);
EXPECT_EQ(cps[i].instance.impl_details.subdiv_source_rank.size(), 0);
EXPECT_FALSE(cps[i].is_source);
EXPECT_EQ(cps[i].default_rank, i);
@@ -161,188 +135,10 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
cps[i].instance.device_names[j]);
EXPECT_TRUE(cps[i].task.is_local[j]);
}
- ASSERT_GT(cps[i].subdiv_rank.size(), 0);
- EXPECT_EQ(cps[i].subdiv_rank[0], i);
- ASSERT_GT(cps[i].instance.impl_details.subdiv_source_rank.size(), 0);
- EXPECT_EQ(cps[i].instance.impl_details.subdiv_source_rank[0], 1);
EXPECT_EQ(cps[i].is_source, (i == 1));
EXPECT_EQ(cps[i].default_rank, i);
EXPECT_TRUE(cps[i].instance.same_num_devices_per_task);
}
}
-TEST_F(CollectiveParamResolverLocalTest, GenerateSubdivPerms) {
- static const int kNumDevsPerTask = 8;
- static const int kNumTasks = 3;
- static const int kNumDevs = kNumDevsPerTask * kNumTasks;
- CollectiveParams cp;
- std::vector<string> device_names;
- std::vector<string> task_names;
- cp.group.group_key = 1;
- cp.group.group_size = kNumDevs;
- cp.group.device_type = DeviceType("GPU");
- cp.group.num_tasks = kNumTasks;
- cp.instance.instance_key = 3;
- cp.instance.type = REDUCTION_COLLECTIVE;
- cp.instance.data_type = DataType(DT_FLOAT);
- cp.instance.shape = TensorShape({5});
- cp.instance.impl_details.subdiv_offsets.push_back(0);
- cp.is_source = false;
- for (int i = 0; i < kNumDevs; ++i) {
- int task_id = i / kNumDevsPerTask;
- int dev_id = i % kNumDevsPerTask;
- string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id);
- task_names.push_back(task_name);
- string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id);
- device_names.push_back(device_name);
- cp.instance.task_names.push_back(task_name);
- cp.instance.device_names.push_back(device_name);
- }
-
- int test_rank = 0;
- cp.default_rank = test_rank;
- cp.instance.impl_details.subdiv_offsets = {0, 4};
- GenSubdivPerms(cp.instance.device_names[test_rank], 0, &cp);
- std::vector<int> expected_0 = {0, 1, 2, 3, 4, 5, 6, 7,
- 8, 9, 10, 11, 12, 13, 14, 15,
- 16, 17, 18, 19, 20, 21, 22, 23};
- std::vector<int> expected_1 = {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15,
- 8, 9, 10, 11, 20, 21, 22, 23, 16, 17, 18, 19};
- for (int i = 0; i < kNumDevs; ++i) {
- EXPECT_EQ(expected_0[i],
- cp.instance.impl_details.subdiv_permutations[0][i]);
- EXPECT_EQ(expected_1[i],
- cp.instance.impl_details.subdiv_permutations[1][i]);
- }
- EXPECT_EQ(0, cp.subdiv_rank[0]);
- EXPECT_EQ(4, cp.subdiv_rank[1]);
-
- test_rank = 3;
- cp.default_rank = test_rank;
- cp.instance.impl_details.subdiv_offsets = {3, -3};
- cp.instance.impl_details.subdiv_permutations.clear();
- GenSubdivPerms(cp.instance.device_names[test_rank], 0, &cp);
- expected_0 = {3, 4, 5, 6, 7, 0, 1, 2, 11, 12, 13, 14,
- 15, 8, 9, 10, 19, 20, 21, 22, 23, 16, 17, 18};
- expected_1 = {4, 3, 2, 1, 0, 7, 6, 5, 12, 11, 10, 9,
- 8, 15, 14, 13, 20, 19, 18, 17, 16, 23, 22, 21};
- for (int i = 0; i < kNumDevs; ++i) {
- EXPECT_EQ(expected_0[i],
- cp.instance.impl_details.subdiv_permutations[0][i]);
- EXPECT_EQ(expected_1[i],
- cp.instance.impl_details.subdiv_permutations[1][i]);
- }
- EXPECT_EQ(0, cp.subdiv_rank[0]);
- EXPECT_EQ(1, cp.subdiv_rank[1]);
-}
-
-TEST_F(CollectiveParamResolverLocalTest, GenerateBcastSubdivPerms1Task8GPU) {
- CollectiveParams cp;
- cp.group.device_type = DeviceType("GPU");
- cp.group.num_tasks = 1;
- cp.instance.type = BROADCAST_COLLECTIVE;
- for (int i = 0; i < 8; i++) {
- string dev_name =
- strings::StrCat("/job:worker/replica:0/task:0/device:GPU:", i);
- cp.instance.device_names.push_back(dev_name);
- }
- std::vector<int> dev_per_task = {8};
-
- // source 0 device 0
- BcastSubdivPerms(&cp, dev_per_task, 0, 0, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0},
- {0});
-
- // source 2 device 2
- BcastSubdivPerms(&cp, dev_per_task, 2, 2, {{0, 1, 2, 3, 4, 5, 6, 7}}, {2},
- {2});
-
- // source 2 device 0
- BcastSubdivPerms(&cp, dev_per_task, 0, 2, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0},
- {2});
-}
-
-TEST_F(CollectiveParamResolverLocalTest, GenerateBcastSubdivPerms4Tasks8GPU) {
- CollectiveParams cp;
- cp.group.device_type = DeviceType("GPU");
- cp.group.num_tasks = 4;
- cp.instance.type = BROADCAST_COLLECTIVE;
- for (int ti = 0; ti < cp.group.num_tasks; ti++) {
- for (int di = 0; di < 8; di++) {
- string dev_name = strings::StrCat("/job:worker/replica:0/task:", ti,
- "/device:GPU:", di);
- cp.instance.device_names.push_back(dev_name);
- }
- }
- std::vector<int> dev_per_task = {8, 8, 8, 8};
-
- // source 0 device 0
- BcastSubdivPerms(&cp, dev_per_task, 0, 0,
- {{0, 8, 16, 24},
- {0, 1, 2, 3, 4, 5, 6, 7},
- {8, 9, 10, 11, 12, 13, 14, 15},
- {16, 17, 18, 19, 20, 21, 22, 23},
- {24, 25, 26, 27, 28, 29, 30, 31}},
- {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0});
-
- // source 2 device 0
- BcastSubdivPerms(&cp, dev_per_task, 0, 2,
- {{2, 8, 16, 24},
- {0, 1, 2, 3, 4, 5, 6, 7},
- {8, 9, 10, 11, 12, 13, 14, 15},
- {16, 17, 18, 19, 20, 21, 22, 23},
- {24, 25, 26, 27, 28, 29, 30, 31}},
- {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0});
-
- // source 9 device 9
- BcastSubdivPerms(&cp, dev_per_task, 9, 9,
- {{0, 9, 16, 24},
- {0, 1, 2, 3, 4, 5, 6, 7},
- {8, 9, 10, 11, 12, 13, 14, 15},
- {16, 17, 18, 19, 20, 21, 22, 23},
- {24, 25, 26, 27, 28, 29, 30, 31}},
- {1, -1, 1, -1, -1}, {1, 0, 1, 0, 0});
-}
-
-TEST_F(CollectiveParamResolverLocalTest,
- GenerateBcastSubdivPerms4TasksVariableGPU) {
- CollectiveParams cp;
- cp.group.device_type = DeviceType("GPU");
- cp.group.num_tasks = 4;
- std::vector<int> dev_per_task = {4, 4, 6, 8};
- for (int ti = 0; ti < cp.group.num_tasks; ti++) {
- for (int di = 0; di < dev_per_task[ti]; di++) {
- string dev_name = strings::StrCat("/job:worker/replica:0/task:", ti,
- "/device:GPU:", di);
- cp.instance.device_names.push_back(dev_name);
- }
- }
-
- // source 0 device 0
- BcastSubdivPerms(&cp, dev_per_task, 0, 0,
- {{0, 4, 8, 14},
- {0, 1, 2, 3},
- {4, 5, 6, 7},
- {8, 9, 10, 11, 12, 13},
- {14, 15, 16, 17, 18, 19, 20, 21}},
- {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0});
-
- // source 2 device 0
- BcastSubdivPerms(&cp, dev_per_task, 0, 2,
- {{2, 4, 8, 14},
- {0, 1, 2, 3},
- {4, 5, 6, 7},
- {8, 9, 10, 11, 12, 13},
- {14, 15, 16, 17, 18, 19, 20, 21}},
- {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0});
-
- // source 9 device 5
- BcastSubdivPerms(&cp, dev_per_task, 5, 9,
- {{0, 4, 9, 14},
- {0, 1, 2, 3},
- {4, 5, 6, 7},
- {8, 9, 10, 11, 12, 13},
- {14, 15, 16, 17, 18, 19, 20, 21}},
- {-1, -1, 1, -1, -1}, {2, 0, 0, 1, 0});
-}
-
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/collective_rma_local.h b/tensorflow/core/common_runtime/collective_rma_local.h
index 44408438b9..2188087957 100644
--- a/tensorflow/core/common_runtime/collective_rma_local.h
+++ b/tensorflow/core/common_runtime/collective_rma_local.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_ACCESS_H_
-#define TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_ACCESS_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_H_
#include "tensorflow/core/common_runtime/buf_rendezvous.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/collective.h"
@@ -89,4 +89,4 @@ class CollectiveRemoteAccessLocal : public PerStepCollectiveRemoteAccess {
};
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_ACCESS_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_H_
diff --git a/tensorflow/core/common_runtime/collective_util.cc b/tensorflow/core/common_runtime/collective_util.cc
new file mode 100644
index 0000000000..195521a078
--- /dev/null
+++ b/tensorflow/core/common_runtime/collective_util.cc
@@ -0,0 +1,83 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/collective_util.h"
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace collective_util {
+
+/*static*/
+Status InitializeDeviceAndLocality(const DeviceMgr* dev_mgr,
+ const string& device_name, Device** device,
+ DeviceLocality* device_locality) {
+ if (!dev_mgr) {
+ return errors::Internal(
+ "InitializeDeviceAndLocality requires a non-null dev_mgr");
+ }
+
+ Status status = dev_mgr->LookupDevice(device_name, device);
+ if (status.ok()) {
+ CHECK(*device);
+ *device_locality = (*device)->attributes().locality();
+ } else {
+ LOG(ERROR) << "Failed to find device " << device_name;
+ for (auto d : dev_mgr->ListDevices()) {
+ LOG(ERROR) << "Available devices " << d->name();
+ }
+ }
+ return status;
+}
+
+/*static*/
+string SubdivPermDebugString(const CollectiveParams& col_params) {
+ const auto& subdiv_perms =
+ col_params.instance.impl_details.subdiv_permutations;
+ string buf;
+ for (int sdi = 0; sdi < subdiv_perms.size(); ++sdi) {
+ strings::StrAppend(&buf, "Subdiv ", sdi, " device order:\n");
+ for (int di = 0; di < subdiv_perms[sdi].size(); ++di) {
+ int idx = subdiv_perms[sdi][di];
+ if (idx >= 0) {
+ CHECK_GT(col_params.instance.device_names.size(), idx);
+ strings::StrAppend(&buf, col_params.instance.device_names[idx], "\n");
+ }
+ }
+ strings::StrAppend(&buf, " subdiv_offsets: ");
+ for (auto o : col_params.instance.impl_details.subdiv_offsets)
+ strings::StrAppend(&buf, o, " ");
+ strings::StrAppend(&buf, " SubdivRank: ");
+ for (auto d : col_params.subdiv_rank) strings::StrAppend(&buf, d, " ");
+ if (col_params.instance.type == BROADCAST_COLLECTIVE) {
+ strings::StrAppend(&buf, " subdiv_source_rank: ");
+ for (auto src : col_params.instance.impl_details.subdiv_source_rank)
+ strings::StrAppend(&buf, src, " ");
+ }
+ strings::StrAppend(&buf, "\n");
+ }
+ return buf;
+}
+
+} // namespace collective_util
+} // namespace tensorflow
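
The new helpers above are small and side-effect free. A hedged usage sketch follows; SetupAndMaybeLog and its arguments are hypothetical, while the two collective_util signatures are taken from the file itself:

#include "tensorflow/core/common_runtime/collective_util.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Hypothetical call site: resolve the device for `rank`, then dump the
// subdivision layout when verbose logging is enabled.
Status SetupAndMaybeLog(const DeviceMgr* dev_mgr,
                        const CollectiveParams& col_params, int rank) {
  Device* device = nullptr;
  DeviceLocality locality;
  TF_RETURN_IF_ERROR(collective_util::InitializeDeviceAndLocality(
      dev_mgr, col_params.instance.device_names[rank], &device, &locality));
  if (VLOG_IS_ON(1)) {
    VLOG(1) << collective_util::SubdivPermDebugString(col_params);
  }
  return Status::OK();
}

}  // namespace tensorflow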
diff --git a/tensorflow/core/common_runtime/collective_util.h b/tensorflow/core/common_runtime/collective_util.h
new file mode 100644
index 0000000000..ebb5731bec
--- /dev/null
+++ b/tensorflow/core/common_runtime/collective_util.h
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_UTIL_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_UTIL_H_
+
+#include <string>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace collective_util {
+
+Status InitializeDeviceAndLocality(const DeviceMgr* dev_mgr,
+ const string& device_name, Device** device,
+ DeviceLocality* device_locality);
+string SubdivPermDebugString(const CollectiveParams& col_params);
+
+} // namespace collective_util
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_UTIL_H_
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index b5a51d2526..97b6971c5b 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -37,6 +37,8 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/denormal.h"
+#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
@@ -553,6 +555,11 @@ bool ReplaceTensorWithConstant(
Status ConstantFold(const ConstantFoldingOptions& opts,
FunctionLibraryRuntime* function_library, Env* env,
Device* partition_device, Graph* graph, bool* was_mutated) {
+ // TensorFlow flushes denormals to zero and rounds to nearest, so we do
+ // the same here.
+ port::ScopedFlushDenormal flush;
+ port::ScopedSetRound round(FE_TONEAREST);
+
DumpGraph("Before", graph);
ConstantFoldNameGenerator generate_new_name = opts.generate_new_name;
if (generate_new_name == nullptr) {
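
The two scoped objects added above are RAII guards: denormal-flushing and round-to-nearest are set for the duration of ConstantFold and restored on every exit path. A standalone sketch of the rounding half, using only <cfenv>; the real guards live under tensorflow/core/platform/ and may differ in detail:

#include <cfenv>

// Sets round-to-nearest on construction and restores the previous
// rounding mode on destruction, whatever the exit path.
class ScopedRoundToNearest {
 public:
  ScopedRoundToNearest() : original_(std::fegetround()) {
    std::fesetround(FE_TONEAREST);
  }
  ~ScopedRoundToNearest() { std::fesetround(original_); }

 private:
  int original_;  // rounding mode saved at entry
};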
diff --git a/tensorflow/core/common_runtime/constant_folding.h b/tensorflow/core/common_runtime/constant_folding.h
index 84598880bb..a9a84f761b 100644
--- a/tensorflow/core/common_runtime/constant_folding.h
+++ b/tensorflow/core/common_runtime/constant_folding.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_CONSTANT_FOLDING_H_
-#define TENSORFLOW_COMMON_RUNTIME_CONSTANT_FOLDING_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_CONSTANT_FOLDING_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_CONSTANT_FOLDING_H_
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/function.h"
@@ -66,4 +66,4 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_CONSTANT_FOLDING_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_CONSTANT_FOLDING_H_
diff --git a/tensorflow/core/common_runtime/debugger_state_interface.h b/tensorflow/core/common_runtime/debugger_state_interface.h
index e0fa983373..797a0ade53 100644
--- a/tensorflow/core/common_runtime/debugger_state_interface.h
+++ b/tensorflow/core/common_runtime/debugger_state_interface.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
-#define TENSORFLOW_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
#include <memory>
@@ -117,4 +117,4 @@ class DebugGraphDecoratorRegistry {
} // end namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h
index b537666492..81d68e3be4 100644
--- a/tensorflow/core/common_runtime/device.h
+++ b/tensorflow/core/common_runtime/device.h
@@ -26,8 +26,8 @@ limitations under the License.
// * Task numbers are within the specified replica, so there are as
// many "task zeros" as replicas.
-#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_H_
-#define TENSORFLOW_COMMON_RUNTIME_DEVICE_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_
#include <memory>
#include <string>
@@ -183,4 +183,4 @@ class Device : public DeviceBase {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_
diff --git a/tensorflow/core/common_runtime/device_factory.h b/tensorflow/core/common_runtime/device_factory.h
index 10eb62afa8..db50226fe8 100644
--- a/tensorflow/core/common_runtime/device_factory.h
+++ b/tensorflow/core/common_runtime/device_factory.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_
-#define TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_FACTORY_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_FACTORY_H_
#include <string>
#include <vector>
@@ -126,4 +126,4 @@ class Registrar {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_FACTORY_H_
diff --git a/tensorflow/core/common_runtime/device_mgr.h b/tensorflow/core/common_runtime/device_mgr.h
index cd93f76324..c1ff10d9b5 100644
--- a/tensorflow/core/common_runtime/device_mgr.h
+++ b/tensorflow/core/common_runtime/device_mgr.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_MGR_H_
-#define TENSORFLOW_COMMON_RUNTIME_DEVICE_MGR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
#include <string>
#include <unordered_map>
@@ -77,4 +77,4 @@ class DeviceMgr {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_MGR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
diff --git a/tensorflow/core/common_runtime/device_resolver_local.h b/tensorflow/core/common_runtime/device_resolver_local.h
index 098eccdf84..bb6ff2efa0 100644
--- a/tensorflow/core/common_runtime/device_resolver_local.h
+++ b/tensorflow/core/common_runtime/device_resolver_local.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
-#define TENSORFLOW_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
#include <string>
@@ -45,4 +45,4 @@ class DeviceResolverLocal : public DeviceResolverInterface {
};
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
diff --git a/tensorflow/core/common_runtime/device_set.h b/tensorflow/core/common_runtime/device_set.h
index 4cd56e583c..c384d46e97 100644
--- a/tensorflow/core/common_runtime/device_set.h
+++ b/tensorflow/core/common_runtime/device_set.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_SET_H_
-#define TENSORFLOW_COMMON_RUNTIME_DEVICE_SET_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_SET_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_SET_H_
#include <memory>
#include <unordered_map>
@@ -86,4 +86,4 @@ class DeviceSet {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_SET_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_SET_H_
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index bf1d78ec65..eb388202fa 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -451,8 +451,22 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
RunState run_state(step_id, &devices_);
run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
#ifndef __ANDROID__
- // Set up for collectives if the RunOption declares a key.
- if (run_options.experimental().collective_graph_key() > 0) {
+ // Set up for collectives if ExecutorsAndKeys declares a key.
+ if (executors_and_keys->collective_graph_key !=
+ BuildGraphOptions::kNoCollectiveGraphKey) {
+ if (run_options.experimental().collective_graph_key() !=
+ BuildGraphOptions::kNoCollectiveGraphKey) {
+ // If a collective_graph_key was specified in run_options, ensure that it
+ // matches what came out of GraphExecutionState::BuildGraph().
+ if (run_options.experimental().collective_graph_key() !=
+ executors_and_keys->collective_graph_key) {
+ return errors::Internal(
+ "collective_graph_key in RunOptions ",
+ run_options.experimental().collective_graph_key(),
+ " should match collective_graph_key from optimized graph ",
+ executors_and_keys->collective_graph_key);
+ }
+ }
if (!collective_executor_mgr_) {
std::unique_ptr<DeviceResolverInterface> drl(
new DeviceResolverLocal(device_mgr_.get()));
@@ -678,10 +692,16 @@ Status DirectSession::Run(const RunOptions& run_options,
// Check if we already have an executor for these arguments.
ExecutorsAndKeys* executors_and_keys;
RunStateArgs run_state_args(run_options.debug_options());
+ run_state_args.collective_graph_key =
+ run_options.experimental().collective_graph_key();
TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
target_nodes, &executors_and_keys,
&run_state_args));
+ {
+ mutex_lock l(collective_graph_key_lock_);
+ collective_graph_key_ = executors_and_keys->collective_graph_key;
+ }
// Configure a call frame for the step, which we use to feed and
// fetch values to and from the executors.
@@ -1116,6 +1136,8 @@ Status DirectSession::CreateExecutors(
BuildGraphOptions options;
options.callable_options = callable_options;
options.use_function_convention = !run_state_args->is_partial_run;
+ options.collective_graph_key =
+ callable_options.run_options().experimental().collective_graph_key();
std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
@@ -1123,9 +1145,9 @@ Status DirectSession::CreateExecutors(
ek->callable_options = callable_options;
std::unordered_map<string, std::unique_ptr<Graph>> graphs;
- TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &func_info->flib_def,
- run_state_args, &ek->input_types,
- &ek->output_types));
+ TF_RETURN_IF_ERROR(CreateGraphs(
+ options, &graphs, &func_info->flib_def, run_state_args, &ek->input_types,
+ &ek->output_types, &ek->collective_graph_key));
if (run_state_args->is_partial_run) {
ek->graph = std::move(run_state_args->graph);
@@ -1353,6 +1375,9 @@ Status DirectSession::GetOrCreateExecutors(
}
*callable_options.mutable_run_options()->mutable_debug_options() =
run_state_args->debug_options;
+ callable_options.mutable_run_options()
+ ->mutable_experimental()
+ ->set_collective_graph_key(run_state_args->collective_graph_key);
std::unique_ptr<ExecutorsAndKeys> ek;
std::unique_ptr<FunctionInfo> func_info;
TF_RETURN_IF_ERROR(
@@ -1379,7 +1404,7 @@ Status DirectSession::CreateGraphs(
std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
RunStateArgs* run_state_args, DataTypeVector* input_types,
- DataTypeVector* output_types) {
+ DataTypeVector* output_types, int64* collective_graph_key) {
mutex_lock l(graph_def_lock_);
std::unique_ptr<ClientGraph> client_graph;
@@ -1403,6 +1428,7 @@ Status DirectSession::CreateGraphs(
TF_RETURN_IF_ERROR(
execution_state->BuildGraph(subgraph_options, &client_graph));
}
+ *collective_graph_key = client_graph->collective_graph_key;
if (subgraph_options.callable_options.feed_size() !=
client_graph->feed_types.size()) {
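
The consistency rule added in RunInternal reduces to: a key supplied via RunOptions is advisory and must agree with the key computed during graph optimization. Restated standalone, assuming (as in BuildGraphOptions) that kNoCollectiveGraphKey is the sentinel 0:

#include <cstdint>

// Sentinel matching BuildGraphOptions::kNoCollectiveGraphKey (assumed 0).
constexpr int64_t kNoCollectiveGraphKey = 0;

// Returns true when the step may run: either the caller requested no key,
// or the requested key matches the one the optimized graph produced.
bool CollectiveKeysConsistent(int64_t requested, int64_t from_graph) {
  return requested == kNoCollectiveGraphKey || requested == from_graph;
}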
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index 72a2be4816..c2cf3c7fd7 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_DIRECT_SESSION_H_
-#define TENSORFLOW_COMMON_RUNTIME_DIRECT_SESSION_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_
#include <atomic>
#include <memory>
@@ -117,6 +117,9 @@ class DirectSession : public Session {
::tensorflow::Status ReleaseCallable(CallableHandle handle) override;
private:
+ // For access to collective_graph_key_.
+ friend class DirectSessionCollectiveTest;
+
// We create one executor and its dependent library runtime for
// every partition.
struct PerPartitionExecutorsAndLib {
@@ -150,6 +153,8 @@ class DirectSession : public Session {
DataTypeVector output_types;
CallableOptions callable_options;
+
+ int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
};
// A FunctionInfo object is created for every unique set of feeds/fetches.
@@ -203,6 +208,7 @@ class DirectSession : public Session {
string handle;
std::unique_ptr<Graph> graph;
const DebugOptions& debug_options;
+ int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
};
// Initializes the base execution state given the 'graph',
@@ -234,7 +240,7 @@ class DirectSession : public Session {
std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
RunStateArgs* run_state_args, DataTypeVector* input_types,
- DataTypeVector* output_types);
+ DataTypeVector* output_types, int64* collective_graph_key);
::tensorflow::Status RunInternal(int64 step_id, const RunOptions& run_options,
CallFrameInterface* call_frame,
@@ -391,6 +397,10 @@ class DirectSession : public Session {
Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr;
+ // For testing collective graph key generation.
+ mutex collective_graph_key_lock_;
+ int64 collective_graph_key_ GUARDED_BY(collective_graph_key_lock_) = -1;
+
TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);
// EXPERIMENTAL: debugger (tfdbg) related
@@ -399,4 +409,4 @@ class DirectSession : public Session {
} // end namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_DIRECT_SESSION_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 4b51b20bb1..3f2355e530 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -2218,4 +2218,121 @@ BENCHMARK(BM_FeedFetch)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
BENCHMARK(BM_FeedFetchCallable)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
} // namespace
+
+class DirectSessionCollectiveTest : public ::testing::Test {
+ public:
+ // Creates a graph with CollectiveOps inside functions and runs it. Returns
+ // the generated collective_graph_key.
+ Status RunGraphWithCollectiveFunctions(bool add_unused_function,
+ int64* collective_graph_key) {
+ GraphDef g = CreateGraph(add_unused_function);
+ const Tensor t1 =
+ test::AsTensor<float>({0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1});
+ const Tensor t2 =
+ test::AsTensor<float>({0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3});
+ auto session = CreateSession();
+ TF_RETURN_IF_ERROR(session->Create(g));
+ std::vector<Tensor> outputs;
+ TF_RETURN_IF_ERROR(
+ session->Run({{"input1:0", t1}, {"input2:0", t2}}, {},
+ {"collective_call1:0", "collective_call2:0"}, &outputs));
+ DirectSession* direct_session = static_cast<DirectSession*>(session.get());
+ {
+ mutex_lock l(direct_session->collective_graph_key_lock_);
+ *collective_graph_key = direct_session->collective_graph_key_;
+ }
+ return Status::OK();
+ }
+
+ private:
+ // Creates a function with name `function_name` and a single CollectiveReduce
+ // node with instance key set as `instance_key`.
+ FunctionDef CollectiveFunction(const string& function_name,
+ int instance_key) {
+ return FunctionDefHelper::Define(
+ // Function name
+ function_name,
+ // In def
+ {"arg:float"},
+ // Out def
+ {"reduce:float"},
+ // Attr def
+ {},
+ // Node def
+ {{
+ {"reduce"},
+ "CollectiveReduce",
+ {"arg"},
+ {{"group_size", 2},
+ {"group_key", 1},
+ {"instance_key", instance_key},
+ {"subdiv_offsets", gtl::ArraySlice<int32>({0})},
+ {"merge_op", "Add"},
+ {"final_op", "Div"},
+ {"T", DT_FLOAT}},
+ }});
+ }
+
+ // Creates a GraphDef that adds two CollectiveFunctions, one each on CPU0 and
+ // CPU1, with instance_key 1, and appropriate placeholder inputs. If
+ // `add_unused_function` is true, adds another CollectiveFunction with
+ // instance_key 2 that is not invoked in the graph.
+ GraphDef CreateGraph(bool add_unused_function) {
+ GraphDef g;
+ FunctionDef collective_function =
+ CollectiveFunction("CollectiveFunction1", 1);
+ FunctionDefLibrary* lib = g.mutable_library();
+ *lib->add_function() = collective_function;
+ if (add_unused_function) {
+ FunctionDef unused_function =
+ CollectiveFunction("CollectiveFunction2", 2);
+ *lib->add_function() = unused_function;
+ }
+
+ // Inputs.
+ AttrValue dtype_attr;
+ SetAttrValue(DT_FLOAT, &dtype_attr);
+ NodeDef input1;
+ input1.set_name("input1");
+ input1.set_op("Placeholder");
+ input1.mutable_attr()->insert({"dtype", dtype_attr});
+ NodeDef input2;
+ input2.set_name("input2");
+ input2.set_op("Placeholder");
+ input2.mutable_attr()->insert({"dtype", dtype_attr});
+
+ // CollectiveReduce on CPU0 with instance_key 1.
+ NodeDef collective_call1;
+ collective_call1.set_name("collective_call1");
+ collective_call1.set_op("CollectiveFunction1");
+ collective_call1.add_input("input1");
+ collective_call1.set_device("/job:localhost/replica:0/task:0/device:CPU:0");
+ // CollectiveReduce on CPU1 with instance_key 1.
+ NodeDef collective_call2;
+ collective_call2.set_name("collective_call2");
+ collective_call2.set_op("CollectiveFunction1");
+ collective_call2.add_input("input2");
+ collective_call2.set_device("/job:localhost/replica:0/task:0/device:CPU:1");
+
+ *g.add_node() = input1;
+ *g.add_node() = input2;
+ *g.add_node() = collective_call1;
+ *g.add_node() = collective_call2;
+
+ return g;
+ }
+};
+
+#ifndef GOOGLE_CUDA
+// TODO(ayushd): enable this test for GPU builds.
+TEST_F(DirectSessionCollectiveTest,
+ TestCollectiveGraphKeyUsesOnlyCalledFunctions) {
+ int64 key1;
+ TF_ASSERT_OK(RunGraphWithCollectiveFunctions(false, &key1));
+ int64 key2;
+ TF_ASSERT_OK(RunGraphWithCollectiveFunctions(true, &key2));
+ ASSERT_EQ(key1, key2);
+}
+#endif
+
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/dma_helper.h b/tensorflow/core/common_runtime/dma_helper.h
index cdfce1f366..4a76cff1e3 100644
--- a/tensorflow/core/common_runtime/dma_helper.h
+++ b/tensorflow/core/common_runtime/dma_helper.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_DMA_HELPER_H_
-#define TENSORFLOW_COMMON_RUNTIME_DMA_HELPER_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DMA_HELPER_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_DMA_HELPER_H_
#include "tensorflow/core/framework/tensor.h"
@@ -35,4 +35,4 @@ class DMAHelper {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_DMA_HELPER_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DMA_HELPER_H_
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index 7f28f3b793..be5f3bae3a 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -219,7 +219,9 @@ tf_cuda_library(
visibility = ["//tensorflow:internal"],
deps = [
":kernel_and_device",
- "//tensorflow/c:c_api",
+ # Only the TF_AttrType enum is required, so pull in just the C headers.
+ # TODO(b/113535673): Break this dependency and avoid the C header completely.
+ "//tensorflow/c:c_api_headers",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
@@ -249,6 +251,7 @@ tf_cc_test(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
diff --git a/tensorflow/core/common_runtime/eager/attr_builder.cc b/tensorflow/core/common_runtime/eager/attr_builder.cc
index 92307d78f2..cf1cd4134e 100644
--- a/tensorflow/core/common_runtime/eager/attr_builder.cc
+++ b/tensorflow/core/common_runtime/eager/attr_builder.cc
@@ -103,7 +103,6 @@ Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) {
return *this; \
}
-DEFINE_SET_ATTR(StringPiece, string_attrs_);
DEFINE_SET_ATTR(float, float_attrs_);
DEFINE_SET_ATTR(int, int_attrs_);
DEFINE_SET_ATTR(bool, bool_attrs_);
@@ -119,9 +118,6 @@ AttrBuilder& AttrBuilder::NumInputs(int n) {
void AttrBuilder::FillAttrValueMap(AttrValueMap* m,
bool include_those_in_node_def) const {
- for (const auto& p : string_attrs_) {
- SetInAttrValueMap(m, p.first, p.second);
- }
for (const auto& p : int_attrs_) {
SetInAttrValueMap(m, p.first, p.second);
}
@@ -211,10 +207,6 @@ tensorflow::Fprint128 AttrBuilder::CacheKey(const string& device) const {
// not been called.
if (node_def_finalized_) return f;
}
- for (const auto& p : string_attrs_) {
- CombineUnordered(
- CacheKeyHelper(p.first, tensorflow::Fingerprint128(p.second)), &f);
- }
for (const auto& p : int_attrs_) {
CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
&f);
diff --git a/tensorflow/core/common_runtime/eager/attr_builder.h b/tensorflow/core/common_runtime/eager/attr_builder.h
index 929b1b8296..cbe6a1cb50 100644
--- a/tensorflow/core/common_runtime/eager/attr_builder.h
+++ b/tensorflow/core/common_runtime/eager/attr_builder.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_C_EAGER_RUNTIME_H_
-#define TENSORFLOW_C_EAGER_RUNTIME_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_
// Support for eager execution of TensorFlow kernels.
@@ -122,16 +122,15 @@ class AttrBuilder {
AttrValue attr_value;
if (found == nullptr) {
SetAttrValue(value, &attr_value);
- m->insert(AttrValueMap::value_type(attr_name.ToString(), attr_value));
+ m->insert(AttrValueMap::value_type(string(attr_name), attr_value));
} else {
// TODO(ashankar): Do what is done in
// NodeDefBuilder::CheckInconsistency(attr_name, *found, attr_value);
SetAttrValue(std::forward<T>(value), &attr_value);
- (*m)[attr_name.ToString()] = attr_value;
+ (*m)[string(attr_name)] = attr_value;
}
}
- AttrVec<StringPiece> string_attrs_;
AttrVec<int> int_attrs_;
AttrVec<float> float_attrs_;
AttrVec<bool> bool_attrs_;
@@ -143,8 +142,6 @@ class AttrBuilder {
}; // namespace tensorflow
template <>
-AttrBuilder& AttrBuilder::Set(StringPiece attr_name, StringPiece&& value);
-template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, int&& value);
template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, float&& value);
@@ -157,4 +154,4 @@ AttrBuilder& AttrBuilder::Set(StringPiece attr_name,
} // namespace tensorflow
-#endif // TENSORFLOW_C_EAGER_RUNTIME_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 5bdd547c7f..37fc031985 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
@@ -25,7 +26,7 @@ namespace {
bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
bool val;
- if (ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) {
+ if (tensorflow::ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) {
return val;
}
return default_val;
@@ -35,22 +36,35 @@ bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
EagerContext::EagerContext(const SessionOptions& opts,
ContextDevicePlacementPolicy default_policy,
- bool async, std::unique_ptr<DeviceMgr> device_mgr,
+ bool async,
+ std::unique_ptr<const DeviceMgr> device_mgr,
Rendezvous* rendezvous)
+ : EagerContext(opts, default_policy, async, device_mgr.release(),
+ /*device_mgr_owned*/ true, rendezvous) {}
+
+EagerContext::EagerContext(const SessionOptions& opts,
+ ContextDevicePlacementPolicy default_policy,
+ bool async, const DeviceMgr* device_mgr,
+ bool device_mgr_owned, Rendezvous* rendezvous)
: policy_(default_policy),
- local_device_manager_(std::move(device_mgr)),
- local_unowned_device_manager_(nullptr),
- devices_(local_device_manager_->ListDevices()),
+ devices_(device_mgr->ListDevices()),
rendezvous_(rendezvous),
thread_pool_(NewThreadPoolFromSessionOptions(opts)),
pflr_(new ProcessFunctionLibraryRuntime(
- local_device_manager_.get(), opts.env, TF_GRAPH_DEF_VERSION,
- &func_lib_def_, {}, thread_pool_.get())),
+ device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_, {},
+ thread_pool_.get())),
log_device_placement_(opts.config.log_device_placement()),
num_active_steps_(0),
async_default_(async),
+ log_memory_(LogMemory::IsEnabled()),
env_(opts.env),
use_send_tensor_rpc_(false) {
+ if (device_mgr_owned) {
+ local_device_manager_.reset(device_mgr);
+ local_unowned_device_manager_ = nullptr;
+ } else {
+ local_unowned_device_manager_ = device_mgr;
+ }
InitDeviceMapAndAsync();
if (opts.config.inter_op_parallelism_threads() > 0) {
runner_ = [this](std::function<void()> closure) {
@@ -78,6 +92,12 @@ void EagerContext::InitDeviceMapAndAsync() {
}
}
}
+
+ DeviceSet ds;
+ for (Device* d : devices_) {
+ ds.AddDevice(d);
+ }
+ prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
}
bool EagerContext::Async() const {
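
The constructor change above introduces an owned-vs-borrowed DeviceMgr: the legacy unique_ptr overload delegates to a raw-pointer overload plus an ownership flag. The pattern in miniature; the class names are placeholders, and only the ownership handling mirrors the diff:

#include <memory>

class Mgr {};

class Ctx {
 public:
  // Owning overload delegates to the borrowing one, passing ownership.
  explicit Ctx(std::unique_ptr<const Mgr> mgr) : Ctx(mgr.release(), true) {}

  Ctx(const Mgr* mgr, bool owned) {
    if (owned) {
      owned_.reset(mgr);  // destroyed together with the context
    } else {
      borrowed_ = mgr;    // caller retains ownership
    }
  }

 private:
  std::unique_ptr<const Mgr> owned_;
  const Mgr* borrowed_ = nullptr;
};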
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 9835b19511..5ed6057ec6 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#endif
+#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
@@ -65,10 +66,17 @@ enum ContextDevicePlacementPolicy {
class EagerContext {
public:
- explicit EagerContext(const SessionOptions& opts,
- ContextDevicePlacementPolicy default_policy, bool async,
- std::unique_ptr<DeviceMgr> device_mgr,
- Rendezvous* rendezvous);
+ // TODO: remove this constructor once we migrate all callers to the next one.
+ EagerContext(const SessionOptions& opts,
+ ContextDevicePlacementPolicy default_policy, bool async,
+ std::unique_ptr<const DeviceMgr> device_mgr,
+ Rendezvous* rendezvous);
+
+ EagerContext(const SessionOptions& opts,
+ ContextDevicePlacementPolicy default_policy, bool async,
+ const DeviceMgr* device_mgr, bool device_mgr_owned,
+ Rendezvous* rendezvous);
+
~EagerContext();
// Returns the function library runtime for the given device.
@@ -93,6 +101,9 @@ class EagerContext {
// TODO(apassos) make this return a constant reference
std::vector<Device*>* devices() { return &devices_; }
+ const std::vector<DeviceType>& prioritized_device_type_list() {
+ return prioritized_device_type_list_;
+ }
// Clears the kernel caches.
void ClearCaches();
@@ -131,6 +142,7 @@ class EagerContext {
void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
bool LogDevicePlacement() { return log_device_placement_; }
+ bool LogMemory() { return log_memory_; }
Rendezvous* GetRendezvous() { return rendezvous_; }
@@ -204,12 +216,13 @@ class EagerContext {
thread_local_policies_ GUARDED_BY(policy_map_mu_);
// Only one of the below is set.
- std::unique_ptr<DeviceMgr> local_device_manager_;
- DeviceMgr* local_unowned_device_manager_;
+ std::unique_ptr<const DeviceMgr> local_device_manager_;
+ const DeviceMgr* local_unowned_device_manager_;
std::unique_ptr<DeviceMgr> remote_device_manager_;
// Devices owned by device_manager
std::vector<Device*> devices_;
+ std::vector<DeviceType> prioritized_device_type_list_;
// All devices are not owned.
gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_;
Rendezvous* rendezvous_;
@@ -250,6 +263,8 @@ class EagerContext {
std::unordered_map<std::thread::id, bool> thread_local_async_
GUARDED_BY(async_map_mu_);
+ const bool log_memory_;
+
Env* const env_;
#ifndef __ANDROID__
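
The constructor pair declared above is the usual owned-or-borrowed pointer idiom: the unique_ptr overload releases into the raw-pointer overload with device_mgr_owned=true, and the flag decides which of the two manager members gets set. A self-contained sketch of the same shape:

    #include <memory>

    struct DeviceMgr {};

    class Holder {
     public:
      // Owning overload: delegates to the explicit one, transferring ownership.
      explicit Holder(std::unique_ptr<const DeviceMgr> mgr)
          : Holder(mgr.release(), /*owned=*/true) {}

      // The caller states ownership explicitly.
      Holder(const DeviceMgr* mgr, bool owned) {
        if (owned) {
          owned_.reset(mgr);  // freed when the Holder dies
        } else {
          unowned_ = mgr;     // caller must keep mgr alive past the Holder
        }
      }

      const DeviceMgr* get() const { return owned_ ? owned_.get() : unowned_; }

     private:
      // Only one of the two is ever set, mirroring the EagerContext members.
      std::unique_ptr<const DeviceMgr> owned_;
      const DeviceMgr* unowned_ = nullptr;
    };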
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 46065f399c..1da1326a9a 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -192,17 +192,14 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
}
Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
- DeviceSet ds;
- for (Device* d : *ctx->devices()) {
- ds.AddDevice(d);
- }
DeviceTypeVector final_devices;
- auto status = SupportedDeviceTypesForNode(ds.PrioritizedDeviceTypeList(),
- ndef, &final_devices);
- if (!status.ok()) return status;
+ TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
+ ctx->prioritized_device_type_list(), ndef, &final_devices));
if (final_devices.empty()) {
- return errors::Internal("Could not find valid device for node ",
- ndef.DebugString());
+ return errors::Internal(
+ "Could not find valid device for node.\nNode: ", SummarizeNodeDef(ndef),
+ "\nAll kernels registered for op ", ndef.op(), " :\n",
+ KernelsRegisteredForOp(ndef.op()));
}
for (Device* d : *ctx->devices()) {
if (d->device_type() == final_devices[0].type_string()) {
@@ -211,7 +208,7 @@ Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
}
}
return errors::Unknown("Could not find a device for node ",
- ndef.DebugString());
+ SummarizeNodeDef(ndef));
}
Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
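
Besides reusing the cached device-type list, the hunk above folds the manual `if (!status.ok()) return status;` dance into TF_RETURN_IF_ERROR. Roughly, the macro expands to an early-return block like the sketch below (simplified; the real definition in tensorflow/core/lib/core/errors.h also wraps the check in TF_PREDICT_FALSE for branch-prediction friendliness):

    // Simplified sketch of TF_RETURN_IF_ERROR; illustration only.
    #define TF_RETURN_IF_ERROR_SKETCH(...)              \
      do {                                              \
        ::tensorflow::Status _status = (__VA_ARGS__);   \
        if (!_status.ok()) return _status;              \
      } while (0)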
@@ -299,7 +296,7 @@ Status EagerLocalExecute(EagerOperation* op,
LOG(INFO) << "Executing op " << ndef.op() << " in device "
<< device->name();
}
- kernel = new KernelAndDevice(ctx->GetRendezvous());
+ kernel = new KernelAndDevice(ctx->GetRendezvous(), ctx->LogMemory());
auto* flr = ctx->func_lib(device);
if (flr == nullptr) {
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
index 3d61ff4dc2..59f94506b7 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -95,6 +95,7 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container,
params.slice_reader_cache = &slice_reader_cache_;
params.rendezvous = rendez_;
params.cancellation_manager = &cm_;
+ params.log_memory = log_memory_;
if (stats != nullptr) {
params.track_allocations = true;
}
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h
index 0ef419cbaa..ed76c4f601 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.h
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h
@@ -56,8 +56,11 @@ class KernelAndDevice {
static Status InitOp(Device* device, const NodeDef& ndef,
KernelAndDevice* out);
- KernelAndDevice(tensorflow::Rendezvous* rendez)
- : device_(nullptr), flib_(nullptr), rendez_(rendez) {}
+ KernelAndDevice(tensorflow::Rendezvous* rendez, bool log_memory)
+ : device_(nullptr),
+ flib_(nullptr),
+ rendez_(rendez),
+ log_memory_(log_memory) {}
// TODO(ashankar): Handle list-valued inputs.
Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs,
@@ -87,6 +90,7 @@ class KernelAndDevice {
DataTypeVector output_dtypes_;
std::function<void(std::function<void()>)>* runner_;
std::function<void(std::function<void()>)> default_runner_;
+ const bool log_memory_;
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
index 6abe98f53c..da280b2317 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
@@ -104,7 +104,7 @@ void BM_KernelAndDeviceInit(int iters) {
.NumInputs(2)
.BuildNodeDef());
TestEnv env;
- KernelAndDevice k(nullptr);
+ KernelAndDevice k(nullptr, false);
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(),
@@ -127,7 +127,7 @@ void BM_KernelAndDeviceRun(int iters) {
.NumInputs(inputs.size())
.BuildNodeDef());
TestEnv env;
- KernelAndDevice kernel(nullptr);
+ KernelAndDevice kernel(nullptr, false);
TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(),
nullptr, &kernel));
tensorflow::testing::StartTiming();
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index 85b0b79bce..b912f7d37b 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -193,7 +193,6 @@ Status TensorHandle::CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd,
// has device type XLA_CPU, and the other CPU.
const bool both_on_cpu = src_cpu && dst_cpu;
if (is_same_device || both_on_cpu) {
- dstd = dst_cpu ? nullptr : dstd;
*output = new tensorflow::TensorHandle(*src, dstd, dstd, ctx);
return tensorflow::Status::OK();
}
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 951bc4197e..84865397bc 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -72,141 +72,58 @@ bool IsInitializationOp(const Node* node) {
return node->op_def().allows_uninitialized_input();
}
-// Sets the timeline_label field of *node_stats, using data from *node.
-// Returns true iff the node is a transfer node.
-// TODO(tucker): merge with the DetailText function in session.cc
-// in a common location.
-bool SetTimelineLabel(const Node* node, NodeExecStatsWrapper* stats) {
- bool is_transfer_node = false;
- if (!stats) {
- return is_transfer_node;
- }
- string memory;
- for (auto& all : stats->stats()->memory()) {
- int64 tot = all.total_bytes();
- if (tot >= 0.1 * 1048576.0) {
- int64 peak = all.peak_bytes();
- if (peak > 0) {
- memory =
- strings::StrCat(memory, "[", all.allocator_name(),
- strings::Printf(" %.1fMB %.1fMB] ", tot / 1048576.0,
- peak / 1048576.0));
- } else {
- memory = strings::StrCat(memory, "[", all.allocator_name(),
- strings::Printf(" %.1fMB] ", tot / 1048576.0));
- }
- }
- }
- const AttrSlice attrs = node->attrs();
- string text;
- if (IsSend(node)) {
- string tensor_name;
- TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
- string recv_device;
- TF_CHECK_OK(GetNodeAttr(attrs, "recv_device", &recv_device));
- text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
- "(", tensor_name, " @", recv_device);
- is_transfer_node = true;
- } else if (IsRecv(node)) {
- string tensor_name;
- TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
- string send_device;
- TF_CHECK_OK(GetNodeAttr(attrs, "send_device", &send_device));
- text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
- "(", tensor_name, " @", send_device);
- is_transfer_node = true;
- } else {
- text =
- strings::StrCat(memory, node->name(), " = ", node->type_string(), "(",
- str_util::Join(node->requested_inputs(), ", "), ")");
- }
- stats->stats()->set_timeline_label(text);
- return is_transfer_node;
-}
-
// Helper routines for collecting step stats.
namespace nodestats {
-inline int64 NowInUsec() { return Env::Default()->NowMicros(); }
inline int64 NowInNsec() { return Env::Default()->NowNanos(); }
void SetScheduled(NodeExecStatsWrapper* stats, int64 nanos) {
if (!stats) return;
- stats->stats()->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos);
- stats->stats()->set_scheduled_nanos(nanos);
+ stats->SetScheduled(nanos);
}
void SetAllStart(NodeExecStatsWrapper* stats) {
if (!stats) return;
- int64 now_nanos = NowInNsec();
- stats->stats()->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
- stats->stats()->set_all_start_nanos(now_nanos);
+ stats->RecordExecutorStarted();
}
void SetOpStart(NodeExecStatsWrapper* stats) {
if (!stats) return;
- NodeExecStats* nt = stats->stats();
- DCHECK_NE(nt->all_start_micros(), 0);
- DCHECK_NE(nt->all_start_nanos(), 0);
- int64 now_nanos = NowInNsec();
- nt->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
- nt->all_start_micros());
- nt->set_op_start_rel_nanos(now_nanos - nt->all_start_nanos());
+ stats->RecordComputeStarted();
}
void SetOpEnd(NodeExecStatsWrapper* stats) {
if (!stats) return;
- NodeExecStats* nt = stats->stats();
- DCHECK_NE(nt->all_start_micros(), 0);
- DCHECK_NE(nt->all_start_nanos(), 0);
- int64 now_nanos = NowInNsec();
- nt->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
- nt->all_start_micros());
- nt->set_op_end_rel_nanos(now_nanos - nt->all_start_nanos());
+ stats->RecordComputeEnded();
}
void SetAllEnd(NodeExecStatsWrapper* stats) {
if (!stats) return;
- NodeExecStats* nt = stats->stats();
- DCHECK_NE(nt->all_start_micros(), 0);
- DCHECK_NE(nt->all_start_nanos(), 0);
- int64 now_nanos = NowInNsec();
- nt->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
- nt->all_start_micros());
- nt->set_all_end_rel_nanos(now_nanos - nt->all_start_nanos());
+ stats->RecordExecutorEnded();
}
void SetOutput(NodeExecStatsWrapper* stats, int slot, const Tensor* v) {
if (!stats) return;
- DCHECK(v);
- NodeOutput* no = stats->stats()->add_output();
- no->set_slot(slot);
- v->FillDescription(no->mutable_tensor_description());
+ stats->SetOutput(slot, v);
}
void SetMemory(NodeExecStatsWrapper* stats, OpKernelContext* ctx) {
if (!stats) return;
-
- for (const auto& allocator_pair : ctx->wrapped_allocators()) {
- stats->AddAllocation(allocator_pair.first, allocator_pair.second);
- }
- auto* ms = stats->stats()->mutable_memory_stats();
- ms->set_temp_memory_size(ctx->temp_memory_allocated());
- for (const auto& alloc_id : ctx->persistent_alloc_ids()) {
- ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id);
- }
- ms->set_persistent_memory_size(ctx->persistent_memory_allocated());
+ stats->SetMemory(ctx);
}
void SetReferencedTensors(NodeExecStatsWrapper* stats,
const TensorReferenceVector& tensors) {
if (!stats) return;
- // be careful not to increment the reference count on any tensor
- // while recording the information
- for (size_t i = 0; i < tensors.size(); ++i) {
- AllocationDescription* description =
- stats->stats()->add_referenced_tensor();
- tensors.at(i).FillDescription(description);
+ stats->SetReferencedTensors(tensors);
+}
+
+// Sets the timeline_label field of *stats, using data from *node.
+// Returns true iff the node is a transfer node.
+bool SetTimelineLabel(const Node* node, NodeExecStatsWrapper* stats) {
+ if (!stats) {
+ return false;
}
+ return stats->SetTimelineLabel(node);
}
} // namespace nodestats
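
The net effect of the nodestats rewrite above is encapsulation: call sites stop mutating the raw NodeExecStats proto through stats->stats() and instead call named methods on NodeExecStatsWrapper, keeping the micros/nanos bookkeeping in one class. A trimmed, hypothetical stand-in for the wrapper's timing side (method names mirror the calls above; fields and Now() are assumptions):

    #include <cstdint>

    // Hypothetical, trimmed stand-in for NodeExecStatsWrapper; the real class
    // wraps the NodeExecStats proto in step_stats_collector.h.
    class StatsWrapperSketch {
     public:
      void RecordExecutorStarted() { all_start_nanos_ = Now(); }
      void RecordComputeStarted() { op_start_rel_nanos_ = Now() - all_start_nanos_; }
      void RecordComputeEnded() { op_end_rel_nanos_ = Now() - all_start_nanos_; }
      void RecordExecutorEnded() { all_end_rel_nanos_ = Now() - all_start_nanos_; }

     private:
      static int64_t Now() { return 0; }  // stand-in for Env::Default()->NowNanos()
      int64_t all_start_nanos_ = 0;
      int64_t op_start_rel_nanos_ = 0;
      int64_t op_end_rel_nanos_ = 0;
      int64_t all_end_rel_nanos_ = 0;
    };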
@@ -1565,6 +1482,7 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
const Status fill_status =
device->FillContextMap(graph, &device_context_map_);
if (!fill_status.ok()) {
+ delete this;
done(fill_status);
return;
}
@@ -1575,6 +1493,7 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
ready.push_back(TaggedNode{n, root_frame_, 0, false});
}
if (ready.empty()) {
+ delete this;
done(Status::OK());
} else {
num_outstanding_ops_ = ready.size();
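
The two added `delete this;` statements plug a leak: ExecutorState is heap-allocated and self-deleting, so any path that invokes the done callback without reaching the normal completion machinery must free the state itself. A minimal sketch of the self-owning async pattern and the failure mode being fixed:

    #include <functional>
    #include <iostream>

    class AsyncState {
     public:
      using Done = std::function<void(bool /*ok*/)>;

      // The object owns itself: every terminal path must free it exactly once.
      void RunAsync(Done done) {
        if (!Setup()) {
          delete this;  // without this line, each failed setup leaks the state
          done(false);
          return;
        }
        Finish(std::move(done));  // normal path also ends in `delete this`
      }

     private:
      bool Setup() { return false; }  // simulate FillContextMap failing
      void Finish(Done done) {
        delete this;
        done(true);
      }
    };

    int main() {
      (new AsyncState)->RunAsync([](bool ok) {
        std::cout << (ok ? "ok" : "failed, state freed") << "\n";
      });
    }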
@@ -1694,15 +1613,14 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
if (stats_collector_ && !tagged_node.is_dead) {
// track allocations if and only if we are collecting statistics
params.track_allocations = true;
- stats = new NodeExecStatsWrapper;
- stats->stats()->set_node_name(node->name());
+ stats = new NodeExecStatsWrapper(node->name());
nodestats::SetScheduled(stats, scheduled_nsec);
nodestats::SetAllStart(stats);
}
if (vlog_) {
VLOG(1) << "Process node: " << id << " step " << params.step_id << " "
- << SummarizeNode(*node) << " is dead: " << tagged_node.is_dead
+ << SummarizeNode(*node) << (tagged_node.is_dead ? " is dead" : "")
<< " device: " << device->name();
}
@@ -1764,7 +1682,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
VLOG(2) << "Async kernel done: " << state->item->node->id()
<< " step " << step_id_ << " "
<< SummarizeNode(*state->item->node)
- << " is dead: " << state->tagged_node.is_dead
+ << (state->tagged_node.is_dead ? " is dead" : "")
<< " device: " << device->name();
}
@@ -1818,7 +1736,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
if (vlog_) {
VLOG(2) << "Synchronous kernel done: " << id << " step "
<< params.step_id << " " << SummarizeNode(*node)
- << " is dead: " << tagged_node.is_dead
+ << (tagged_node.is_dead ? " is dead" : "")
<< " device: " << device->name();
}
@@ -2165,7 +2083,8 @@ bool ExecutorState::NodeDone(const Status& s, const Node* node,
NodeExecStatsWrapper* stats,
TaggedNodeReadyQueue* inline_ready) {
nodestats::SetAllEnd(stats);
- if (stats_collector_ != nullptr && !SetTimelineLabel(node, stats)) {
+ if (stats_collector_ != nullptr &&
+ !nodestats::SetTimelineLabel(node, stats)) {
// Only record non-transfer nodes.
// Transfers 'stats' ownership to 'stats_collector_'.
stats_collector_->Save(impl_->params_.device->name(), stats);
@@ -2502,8 +2421,7 @@ void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
}
if (dst_ready) {
if (IsControlTrigger(dst_node)) dst_dead = false;
- ready->push_back(
- TaggedNode(dst_node, parent_frame, parent_iter, dst_dead));
+ ready->emplace_back(dst_node, parent_frame, parent_iter, dst_dead);
parent_iter_state->outstanding_ops++;
}
}
@@ -2627,7 +2545,7 @@ void ExecutorState::FrameState::ActivateNodes(const NodeItem* item,
// Add dst to the ready queue if it's ready
if (dst_ready) {
if (dst_item->is_control_trigger) dst_dead = false;
- ready->push_back(TaggedNode(dst_item->node, this, iter, dst_dead));
+ ready->emplace_back(dst_item->node, this, iter, dst_dead);
iter_state->outstanding_ops++;
}
}
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
index a238a6763a..6cd4fd22ea 100644
--- a/tensorflow/core/common_runtime/executor.h
+++ b/tensorflow/core/common_runtime/executor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_
-#define TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/rendezvous.h"
@@ -235,4 +235,4 @@ void DeleteNonCachedKernel(OpKernel* kernel);
} // end namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 54bbe84b57..1c9b69721d 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -310,6 +310,7 @@ class CallOp : public AsyncOpKernel {
opts.step_container = ctx->step_container();
opts.stats_collector = ctx->stats_collector();
opts.runner = ctx->runner();
+ opts.collective_executor = ctx->collective_executor();
std::vector<Tensor> args;
args.reserve(ctx->num_inputs());
for (int i = 0; i < ctx->num_inputs(); ++i) {
@@ -346,9 +347,10 @@ const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
return nullptr;
}
- mutex_lock l(mu_);
- CHECK_EQ(1, items_.count(local_handle));
- return items_[local_handle]->func_graph;
+ tf_shared_lock l(mu_);
+ auto iter = items_.find(local_handle);
+ CHECK(iter != items_.end());
+ return iter->second->func_graph;
}
Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
@@ -555,6 +557,12 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
next_handle_++;
}
}
+
+ if (options.create_kernels_eagerly) {
+ Item* item;
+ TF_RETURN_IF_ERROR(GetOrCreateItem(*handle, &item));
+ }
+
return Status::OK();
}
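
The block added above lets instantiation eagerly build the function's kernels via GetOrCreateItem, so kernel-construction errors surface at Instantiate time instead of inside the first Run. A hedged usage sketch; `flr` (an assumed FunctionLibraryRuntime*) and the function name "MyFunc" are placeholders, not code from this commit:

    // Force kernel construction at instantiation time so a bad kernel
    // fails here rather than inside the first Run().
    FunctionLibraryRuntime::InstantiateOptions options;
    options.create_kernels_eagerly = true;  // the field honored above

    FunctionLibraryRuntime::Handle handle;
    TF_RETURN_IF_ERROR(flr->Instantiate("MyFunc", AttrSlice(), options, &handle));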
@@ -607,11 +615,14 @@ void PruneFunctionBody(Graph* g) {
std::unordered_set<const Node*> nodes;
for (auto n : g->nodes()) {
// NOTE(mrry): "_Retval" nodes are stateful, and so will be added
- // to the seed set of `nodes`.
+ // to the seed set of `nodes`. "_Arg" nodes are also stateful, but we
+ // specifically exclude them as seeds, to avoid unconditionally executing
+ // unused argument nodes (e.g. in a function like `lambda x, y: y`).
// TODO(mrry): Investigate whether the `n->IsControlFlow()` test is
// still needed. It would be preferable to prune entire loops and/or
// conditionals if they are not used in the graph.
- if (n->IsControlFlow() || n->op_def().is_stateful()) {
+ if (n->IsControlFlow() ||
+ (n->op_def().is_stateful() && n->type_string() != kArgOp)) {
nodes.insert(n);
}
}
@@ -627,7 +638,7 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
const FunctionLibraryDefinition* lib_def;
string executor_type;
{
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
fbody = (*item)->func_graph;
lib_def = (*item)->overlay_lib;
executor_type = (*item)->executor_type;
@@ -676,12 +687,13 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
{
- mutex_lock l(mu_);
- if (items_.count(local_handle) == 0) {
+ tf_shared_lock l(mu_);
+ auto iter = items_.find(local_handle);
+ if (iter == items_.end()) {
return errors::NotFound("Function handle ", handle,
" is not valid. Likely an internal error.");
}
- *item = items_[local_handle].get();
+ *item = iter->second.get();
if ((*item)->exec != nullptr) {
return Status::OK();
}
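
The mutex_lock to tf_shared_lock conversions above let concurrent callers perform read-only lookups in items_ in parallel; only code that mutates the map still takes the exclusive lock. The same reader/writer split, sketched standalone with std::shared_mutex standing in for TF's mutex and tf_shared_lock:

    #include <shared_mutex>
    #include <unordered_map>

    struct Item { int value = 0; };

    class Registry {
     public:
      // Writers take the exclusive lock.
      void Insert(int handle, Item item) {
        std::unique_lock<std::shared_mutex> l(mu_);
        items_[handle] = item;
      }
      // Readers share the lock, so lookups no longer serialize each other.
      bool Find(int handle, Item* out) const {
        std::shared_lock<std::shared_mutex> l(mu_);
        auto iter = items_.find(handle);
        if (iter == items_.end()) return false;
        *out = iter->second;
        return true;
      }

     private:
      mutable std::shared_mutex mu_;
      std::unordered_map<int, Item> items_;
    };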
@@ -916,29 +928,18 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
}
DCHECK(run_opts.runner != nullptr);
- Executor::Args* exec_args = new Executor::Args;
+ Executor::Args exec_args;
// Inherit the step_id from the caller.
- exec_args->step_id = run_opts.step_id;
- exec_args->rendezvous = run_opts.rendezvous;
- exec_args->stats_collector = run_opts.stats_collector;
- exec_args->cancellation_manager = run_opts.cancellation_manager;
- exec_args->collective_executor = run_opts.collective_executor;
- exec_args->step_container = run_opts.step_container;
- exec_args->runner = *run_opts.runner;
- exec_args->call_frame = frame;
-
- item->exec->RunAsync(
- // Executor args
- *exec_args,
- // Done callback.
- std::bind(
- [item, frame, exec_args](DoneCallback done,
- // Start unbound arguments.
- const Status& status) {
- delete exec_args;
- done(status);
- },
- std::move(done), std::placeholders::_1));
+ exec_args.step_id = run_opts.step_id;
+ exec_args.rendezvous = run_opts.rendezvous;
+ exec_args.stats_collector = run_opts.stats_collector;
+ exec_args.cancellation_manager = run_opts.cancellation_manager;
+ exec_args.collective_executor = run_opts.collective_executor;
+ exec_args.step_container = run_opts.step_container;
+ exec_args.runner = *run_opts.runner;
+ exec_args.call_frame = frame;
+
+ item->exec->RunAsync(exec_args, std::move(done));
}
bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) {
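
The simplification above relies on Executor::RunAsync not needing its Args after the call returns (the executor copies what it needs into its own state), so the Args can live on the caller's stack and the heap allocation plus the deleting std::bind wrapper disappear. A generic before/after sketch with stand-in types, assuming that contract:

    #include <functional>
    #include <utility>

    struct Args { int step_id = 0; };
    using Done = std::function<void(int)>;

    // Assumed contract: Run() copies what it needs out of `args` before it
    // returns, so callers may stack-allocate the Args.
    void Run(const Args& args, Done done) { done(args.step_id); }

    // Old shape: heap-allocate Args, delete them inside a wrapping callback.
    void OldShape(Done done) {
      Args* args = new Args{42};
      Run(*args, [args, done = std::move(done)](int v) {
        delete args;  // easy to miss on early-error paths
        done(v);
      });
    }

    // New shape after the hunk above: stack Args, done passed through untouched.
    void NewShape(Done done) {
      Args args{42};
      Run(args, std::move(done));
    }

    int main() {
      OldShape([](int) {});
      NewShape([](int) {});
    }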
diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h
index a274f1ef51..eeca66f5d0 100644
--- a/tensorflow/core/common_runtime/function.h
+++ b/tensorflow/core/common_runtime/function.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_FUNCTION_H_
-#define TENSORFLOW_COMMON_RUNTIME_FUNCTION_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_
#include <functional>
#include <memory>
@@ -170,4 +170,4 @@ Status FunctionDefToBodyHelper(
FunctionBody** fbody);
} // end namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_FUNCTION_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 120f480198..7bab9be9a6 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -802,9 +802,9 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
// Name
"SquareAndAddOneWithStatefulNodes",
// Args
- {"x: int32"},
+ {"x: int32", "y: float32"},
// Return values
- {"y: int32"},
+ {"z: int32"},
// Attrs
{},
// Nodes
@@ -822,12 +822,13 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
"RandomUniform",
{"shape"},
{{"T", T}, {"dtype", DT_FLOAT}}},
- // y = Add<T>(a, o)
- {{"y"}, "Add", {"a", "o"}, {{"T", T}}}});
+ // z = Add<T>(a, o)
+ {{"z"}, "Add", {"a", "o"}, {{"T", T}}}});
Init({stateful_func});
auto x = test::AsTensor<int32>({1, 2, 3, 4});
- Tensor y;
+ auto y = test::AsTensor<float>({1.0, 2.0, 3.0, 4.0});
+ Tensor z;
FunctionLibraryRuntime::Handle handle;
TF_CHECK_OK(
@@ -837,18 +838,19 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
StepStatsCollector stats_collector(&stats);
FunctionLibraryRuntime::Options opts;
opts.stats_collector = &stats_collector;
- TF_CHECK_OK(Run(flr0_, handle, opts, {x}, {&y}));
+ TF_CHECK_OK(Run(flr0_, handle, opts, {x, y}, {&z}));
TF_CHECK_OK(flr0_->ReleaseHandle(handle));
TF_CHECK_OK(InstantiateAndRun(flr0_, "SquareAndAddOneWithStatefulNodes", {},
- {x}, {&y}));
- test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({2, 5, 10, 17}));
+ {x, y}, {&z}));
+ test::ExpectTensorEqual<int>(z, test::AsTensor<int32>({2, 5, 10, 17}));
stats_collector.FinalizeAndSwap(&stats);
- // Note that we do not expect the nodes named "x1", "x2", or "x3" to execute.
+ // Note that we do not expect the nodes named "y", "x1", "x2", or "x3" to
+ // execute.
std::set<string> expected_node_names(
- {"_SOURCE", "shape", "x", "o", "a", "keep_me", "y", "y_RetVal"});
+ {"_SOURCE", "shape", "x", "o", "a", "keep_me", "z", "z_RetVal"});
std::set<string> executed_node_names;
for (const auto& node_stats : stats.dev_stats()[0].node_stats()) {
executed_node_names.insert(node_stats.node_name());
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
index a3e0d0734f..f1cc2eace1 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_
-#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_
#include <memory>
#include <string>
@@ -89,4 +89,4 @@ class GPUMemAllocator : public SubAllocator {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
index 5043fac797..856fdc34b4 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_CUDA_MALLOC_ALLOCATOR_H_
-#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_CUDA_MALLOC_ALLOCATOR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_CUDAMALLOC_ALLOCATOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_CUDAMALLOC_ALLOCATOR_H_
#include <memory>
@@ -51,4 +51,4 @@ class GPUcudaMallocAllocator : public VisitableAllocator {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_CUDAMALLOC_ALLOCATOR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_CUDAMALLOC_ALLOCATOR_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
index c49ec2a566..0f9b72040c 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_
-#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_
#include <memory>
#include <string>
@@ -88,4 +88,4 @@ class GPUNanResetAllocator : public VisitableAllocator {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 3292ef2f62..2763ac0d4a 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -917,16 +917,21 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
}
const auto& gpu_options = options.config.gpu_options();
std::vector<CudaGpuId> visible_gpu_order;
- TF_RETURN_IF_ERROR(ParseVisibleDeviceList(gpu_options.visible_device_list(),
- &visible_gpu_order));
-
std::vector<CudaGpuId> valid_cuda_gpu_ids;
- TF_RETURN_IF_ERROR(GetValidDeviceIds(visible_gpu_order, &valid_cuda_gpu_ids));
+ // If we aren't going to use any GPUs, don't initialize them.
+ // We don't want to call ParseVisibleDeviceList if num_gpus_to_use is 0,
+ // because it treats an empty gpu_options.visible_device_list as 'all GPUs are
+ // visible'.
+ if (num_gpus_to_use > 0) {
+ TF_RETURN_IF_ERROR(ParseVisibleDeviceList(gpu_options.visible_device_list(),
+ &visible_gpu_order));
+ TF_RETURN_IF_ERROR(
+ GetValidDeviceIds(visible_gpu_order, &valid_cuda_gpu_ids));
+ }
if (num_gpus_to_use > valid_cuda_gpu_ids.size()) {
num_gpus_to_use = valid_cuda_gpu_ids.size();
}
- // If we aren't going to use any GPUs, don't initialize them.
- if (num_gpus_to_use > 0 && !valid_cuda_gpu_ids.empty()) {
+ if (!valid_cuda_gpu_ids.empty()) {
// Save the original device.
int original_device = 0;
cudaError_t err = cudaGetDevice(&original_device);
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
index f0a109cc10..2d406b676e 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_
-#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_
#include <deque>
#include <vector>
@@ -203,4 +203,4 @@ class EventMgr {
};
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_init.h b/tensorflow/core/common_runtime/gpu/gpu_init.h
index bfd7a77f83..4e1f06ac83 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_init.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_init.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_INIT_H_
-#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_INIT_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_INIT_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_INIT_H_
#include "tensorflow/core/lib/core/status.h"
@@ -36,4 +36,4 @@ stream_executor::Platform* GPUMachineManager();
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_INIT_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_INIT_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_stream_util.h b/tensorflow/core/common_runtime/gpu/gpu_stream_util.h
index 771c158267..c61ada96ef 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_stream_util.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_stream_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_STREAM_UTIL_H_
-#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_STREAM_UTIL_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_STREAM_UTIL_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_STREAM_UTIL_H_
#include <unordered_map>
@@ -42,4 +42,4 @@ Status AssignStreams(const Graph* graph, const AssignStreamsOpts& opts,
} // namespace gpu_stream_util
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_STREAM_UTIL_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_STREAM_UTIL_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.h b/tensorflow/core/common_runtime/gpu/gpu_util.h
index 57687a8364..8ac3febb01 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_util.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_UTIL_H_
-#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_UTIL_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_UTIL_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_UTIL_H_
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
@@ -108,4 +108,4 @@ class GPUUtil {
};
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_UTIL_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_UTIL_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc
index ea1b04feeb..4bc88ffc8c 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/framework/tensor.h"
@@ -36,4 +37,12 @@ void GPUDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
GPUUtil::CopyGPUTensorToCPU(device, this, device_tensor, cpu_tensor, done);
}
+Status GPUDeviceContext::ThenExecute(Device* device, se::Stream* stream,
+ std::function<void()> func) {
+ const DeviceBase::GpuDeviceInfo* gpu_info =
+ device->tensorflow_gpu_device_info();
+ gpu_info->event_mgr->ThenExecute(stream, func);
+ return Status::OK();
+}
+
} // namespace tensorflow
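
The new override gives GPUDeviceContext a generic hook for running a host callback once all work currently enqueued on a stream has finished, delegating to the device's EventMgr. A hedged call-site sketch; `dctx`, `device`, `stream`, `buffer`, and ReleaseStagingBuffer are assumptions for illustration, not code from this commit:

    // Run a host-side cleanup once the GPU work already queued on `stream`
    // completes. ReleaseStagingBuffer is a hypothetical helper.
    Status s = dctx->ThenExecute(device, stream,
                                 [buffer] { ReleaseStagingBuffer(buffer); });
    if (!s.ok()) {
      LOG(ERROR) << "ThenExecute failed: " << s;
    }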
diff --git a/tensorflow/core/common_runtime/gpu_device_context.h b/tensorflow/core/common_runtime/gpu_device_context.h
index d697d878dc..3603808152 100644
--- a/tensorflow/core/common_runtime/gpu_device_context.h
+++ b/tensorflow/core/common_runtime/gpu_device_context.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_
-#define TENSORFLOW_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/device_base.h"
@@ -60,6 +60,9 @@ class GPUDeviceContext : public DeviceContext {
void MaintainLifetimeOnStream(const Tensor* t,
se::Stream* stream) const override {}
+ Status ThenExecute(Device* device, se::Stream* stream,
+ std::function<void()> func) override;
+
private:
int stream_id_;
// The default primary stream to use for this context.
@@ -75,4 +78,4 @@ class GPUDeviceContext : public DeviceContext {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_
diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc
index c23b7d3699..4475fa979e 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/graph_execution_state.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_execution_state.h"
#include <memory>
+#include <set>
#include <string>
#include <unordered_set>
#include <utility>
@@ -560,6 +561,10 @@ Status GraphExecutionState::OptimizeGraph(
grappler::GrapplerItem item;
item.id = "tf_graph";
graph_->ToGraphDef(&item.graph);
+ // TODO(b/114748242): Add a unit test to test this bug fix.
+ if (flib_def_) {
+ *item.graph.mutable_library() = flib_def_->ToProto();
+ }
item.fetch.insert(item.fetch.end(),
options.callable_options.fetch().begin(),
@@ -581,7 +586,7 @@ Status GraphExecutionState::OptimizeGraph(
if (id.second != 0) {
return errors::InvalidArgument("Unsupported feed: ", feed);
}
- feeds.insert(id.first.ToString());
+ feeds.emplace(id.first);
}
for (const TensorConnection& tensor_connection :
options.callable_options.tensor_connection()) {
@@ -590,7 +595,7 @@ Status GraphExecutionState::OptimizeGraph(
return errors::InvalidArgument("Unsupported feed: ",
tensor_connection.to_tensor());
}
- feeds.insert(id.first.ToString());
+ feeds.emplace(id.first);
}
for (const NodeDef& node : original_graph_def_.node()) {
if (feeds.find(node.name()) == feeds.end()) {
@@ -727,12 +732,50 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
+ int64 collective_graph_key = options.collective_graph_key;
+ if (collective_graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
+ // BuildGraphOptions does not specify a collective_graph_key. Check all
+ // nodes in the Graph and FunctionLibraryDefinition for collective ops and
+ // if found, initialize a collective_graph_key as a hash of the ordered set
+ // of instance keys.
+ std::set<int32> instance_key_set;
+ for (Node* node : optimized_graph->nodes()) {
+ if (node->IsCollective()) {
+ int32 instance_key;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(node->attrs(), "instance_key", &instance_key));
+ instance_key_set.emplace(instance_key);
+ } else {
+ const FunctionDef* fdef = optimized_flib->Find(node->def().op());
+ if (fdef != nullptr) {
+ for (const NodeDef& ndef : fdef->node_def()) {
+ if (ndef.op() == "CollectiveReduce" ||
+ ndef.op() == "CollectiveBcastSend" ||
+ ndef.op() == "CollectiveBcastRecv") {
+ int32 instance_key;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(ndef, "instance_key", &instance_key));
+ instance_key_set.emplace(instance_key);
+ }
+ }
+ }
+ }
+ }
+ if (!instance_key_set.empty()) {
+ uint64 hash = 0x8774aa605c729c72ULL;
+ for (int32 instance_key : instance_key_set) {
+ hash = Hash64Combine(instance_key, hash);
+ }
+ collective_graph_key = hash;
+ }
+ }
+
// Copy the extracted graph in order to make its node ids dense,
// since the local CostModel used to record its stats is sized by
// the largest node id.
std::unique_ptr<ClientGraph> dense_copy(
new ClientGraph(std::move(optimized_flib), rewrite_metadata.feed_types,
- rewrite_metadata.fetch_types));
+ rewrite_metadata.fetch_types, collective_graph_key));
CopyGraph(*optimized_graph, &dense_copy->graph);
// TODO(vrv): We should check invariants of the graph here.
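
Because the instance keys are collected into a std::set, the fold above visits them in sorted order, so two graphs containing the same collective instances produce the same collective_graph_key regardless of node iteration order. A standalone sketch of the construction, with a stand-in mixer in place of tensorflow's Hash64Combine:

    #include <cstdint>
    #include <set>

    // Stand-in for Hash64Combine; any reasonable 64-bit mixer illustrates
    // the point.
    uint64_t CombineHash(uint64_t a, uint64_t b) {
      return a ^ (b + 0x9e3779b97f4a7c15ULL + (a << 6) + (a >> 2));
    }

    uint64_t CollectiveGraphKey(const std::set<int32_t>& instance_keys) {
      uint64_t hash = 0x8774aa605c729c72ULL;  // same seed as the hunk above
      for (int32_t key : instance_keys) {     // sorted iteration => determinism
        hash = CombineHash(key, hash);
      }
      return hash;
    }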
diff --git a/tensorflow/core/common_runtime/graph_execution_state.h b/tensorflow/core/common_runtime/graph_execution_state.h
index d44a24c87b..9cabe478a6 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.h
+++ b/tensorflow/core/common_runtime/graph_execution_state.h
@@ -50,17 +50,20 @@ struct GraphExecutionStateOptions {
// BuildGraphOptions.
struct ClientGraph {
explicit ClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib,
- DataTypeVector feed_types, DataTypeVector fetch_types)
+ DataTypeVector feed_types, DataTypeVector fetch_types,
+ int64 collective_graph_key)
: flib_def(std::move(flib)),
graph(flib_def.get()),
feed_types(std::move(feed_types)),
- fetch_types(std::move(fetch_types)) {}
+ fetch_types(std::move(fetch_types)),
+ collective_graph_key(collective_graph_key) {}
// Each client-graph gets its own function library since optimization passes
// post rewrite for execution might want to introduce new functions.
std::unique_ptr<FunctionLibraryDefinition> flib_def;
Graph graph;
DataTypeVector feed_types;
DataTypeVector fetch_types;
+ int64 collective_graph_key;
};
// GraphExecutionState is responsible for generating an
diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc
index 0a1797fa19..f9aef3af70 100644
--- a/tensorflow/core/common_runtime/graph_runner.cc
+++ b/tensorflow/core/common_runtime/graph_runner.cc
@@ -56,7 +56,7 @@ class SimpleRendezvous : public Rendezvous {
}
mutex_lock l(mu_);
- string edge_name = std::string(parsed.edge_name);
+ string edge_name(parsed.edge_name);
if (table_.count(edge_name) > 0) {
return errors::Internal("Send of an already sent tensor");
}
@@ -69,7 +69,7 @@ class SimpleRendezvous : public Rendezvous {
Tensor tensor;
Status status = Status::OK();
{
- string key = std::string(parsed.edge_name);
+ string key(parsed.edge_name);
mutex_lock l(mu_);
if (table_.count(key) <= 0) {
status = errors::Internal("Did not find key ", key);
diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc
new file mode 100644
index 0000000000..eae34997d9
--- /dev/null
+++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc
@@ -0,0 +1,440 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+#include "tensorflow/core/common_runtime/collective_util.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/types.h"
+
+// Set true for more readable debug-mode log messages.
+#define READABLE_KEYS false
+
+namespace tensorflow {
+
+namespace {
+// Key to be used for BufRendezvous by Broadcaster.
+string BroadcastBufKey(const string& exec_key, int subdiv, int src_rank,
+ int dst_rank) {
+ if (READABLE_KEYS) {
+ return strings::StrCat("broadcast(", exec_key, "):subdiv(", subdiv,
+ "):src(", src_rank, "):dst(", dst_rank, ")");
+ } else {
+ // TODO(b/78352018): Try a denser format, e.g. a 64 or 128 bit hash.
+ return strings::StrCat(exec_key, ":", subdiv, ":", src_rank, ":", dst_rank);
+ }
+}
+} // namespace
+
+HierarchicalTreeBroadcaster::HierarchicalTreeBroadcaster()
+ : col_ctx_(nullptr),
+ col_params_(nullptr),
+ done_(nullptr),
+ is_source_(false) {}
+
+int HierarchicalTreeBroadcaster::GetDeviceTask(
+ int device_rank, const std::vector<int>& dev_per_task) {
+ int num_tasks = static_cast<int>(dev_per_task.size());
+ int task_lo = 0;
+ int task_hi;
+ for (int ti = 0; ti < num_tasks; ti++) {
+ task_hi = task_lo + dev_per_task[ti];
+ if (task_lo <= device_rank && device_rank < task_hi) return ti;
+ task_lo = task_hi;
+ }
+ LOG(FATAL) << "Unexpected device rank " << device_rank << " for " << task_hi
+ << " devices";
+ return -1;
+}
+
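
GetDeviceTask above maps a global device rank to its task by scanning cumulative ranges: rank r belongs to the first task whose window [task_lo, task_hi) contains it. A standalone sanity check of that arithmetic:

    #include <cassert>
    #include <vector>

    int GetDeviceTask(int device_rank, const std::vector<int>& dev_per_task) {
      int task_lo = 0;
      for (int ti = 0; ti < static_cast<int>(dev_per_task.size()); ti++) {
        int task_hi = task_lo + dev_per_task[ti];
        if (task_lo <= device_rank && device_rank < task_hi) return ti;
        task_lo = task_hi;
      }
      return -1;  // the real code LOG(FATAL)s here
    }

    int main() {
      // dev_per_task {2, 3, 1}: ranks 0-1 -> task 0, 2-4 -> task 1, 5 -> task 2.
      std::vector<int> dpt = {2, 3, 1};
      assert(GetDeviceTask(0, dpt) == 0);
      assert(GetDeviceTask(4, dpt) == 1);
      assert(GetDeviceTask(5, dpt) == 2);
    }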
+Status HierarchicalTreeBroadcaster::InitializeCollectiveParams(
+ CollectiveParams* col_params) {
+ CHECK_EQ(col_params->instance.type, BROADCAST_COLLECTIVE);
+ CHECK_EQ(col_params->instance.impl_details.collective_name,
+ "HierarchicalTreeBroadcast");
+ const string& device_name =
+ col_params->instance.device_names[col_params->default_rank];
+ // Start by counting the devices in each task.
+ // Precondition: device_names must be sorted so that all devices in
+ // the same task are adjacent.
+ VLOG(2) << "Sorted task names: "
+ << str_util::Join(col_params->instance.task_names, ", ");
+ std::vector<int> dev_per_task;
+ const string* prior_task_name = &col_params->instance.task_names[0];
+ int dev_count = 1;
+ for (int di = 1; di < col_params->group.group_size; ++di) {
+ if (col_params->instance.task_names[di] != *prior_task_name) {
+ dev_per_task.push_back(dev_count);
+ dev_count = 1;
+ prior_task_name = &col_params->instance.task_names[di];
+ } else {
+ ++dev_count;
+ }
+ }
+ dev_per_task.push_back(dev_count);
+ CHECK_EQ(col_params->group.num_tasks, dev_per_task.size());
+
+ if (VLOG_IS_ON(2)) {
+ string dpt_buf;
+ for (int dpt : dev_per_task) strings::StrAppend(&dpt_buf, dpt, ";");
+ VLOG(2) << "HierarchicalTreeBroadcaster::InitializeCollectiveParams device="
+ << device_name << " source_rank=" << col_params->source_rank
+ << " dev_per_task=" << dpt_buf;
+ }
+ int num_tasks = col_params->group.num_tasks;
+ // If there is just 1 task, then execute binary tree broadcast over all
+ // devices. Otherwise, the first subdiv is inter-task broadcast, and then
+ // there are N more subdivs, where N is #task.
+ int num_subdivs = num_tasks + (num_tasks > 1 ? 1 : 0);
+ int total_num_devices = 0;
+ for (int num_dev : dev_per_task) total_num_devices += num_dev;
+
+ col_params->instance.impl_details.subdiv_permutations.resize(num_subdivs);
+ col_params->subdiv_rank.reserve(num_subdivs);
+ col_params->instance.impl_details.subdiv_source_rank.reserve(num_subdivs);
+
+ // Inter-task subdiv. Pick one device from each task - this is the source
+ // device if it belongs to that task, or device 0 for that task. If a device
+ // does not participate in the subdiv, set subdiv_rank to -1.
+ if (num_tasks > 1) {
+ const int sdi = 0;
+ std::vector<int>& perm =
+ col_params->instance.impl_details.subdiv_permutations[sdi];
+ CHECK_EQ(perm.size(), 0);
+ int device_count = 0;
+ int source_task = GetDeviceTask(col_params->source_rank, dev_per_task);
+ for (int ti = 0; ti < col_params->group.num_tasks; ti++) {
+ bool participate = false;
+ if (source_task == ti) {
+ // Source device belongs to this task.
+ perm.push_back(col_params->source_rank);
+ participate =
+ col_params->instance.device_names[col_params->source_rank] ==
+ device_name;
+ } else {
+ // Source does not belong to this task, choose dev 0.
+ perm.push_back(device_count);
+ participate =
+ col_params->instance.device_names[device_count] == device_name;
+ }
+ if (participate) col_params->subdiv_rank.push_back(ti);
+ device_count += dev_per_task[ti];
+ }
+ if (col_params->subdiv_rank.empty()) col_params->subdiv_rank.push_back(-1);
+ col_params->instance.impl_details.subdiv_source_rank.push_back(source_task);
+ }
+
+ // Intra-task subdivs. Pick all devices in task ti for subdiv sdi. Set the
+ // source to dev 0 for that task if it does not contain the original source,
+ // else set it to the rank of the original source. If a device does not
+ // participate in the subdiv, set subdiv_rank to -1.
+ int abs_di = 0;
+ for (int ti = 0; ti < col_params->group.num_tasks; ti++) {
+ const int sdi = ti + (num_tasks > 1 ? 1 : 0);
+ std::vector<int>& perm =
+ col_params->instance.impl_details.subdiv_permutations[sdi];
+ CHECK_EQ(perm.size(), 0);
+ bool participate = false;
+ int subdiv_source = 0;
+ for (int di = 0; di < dev_per_task[ti]; di++) {
+ perm.push_back(abs_di);
+ if (col_params->instance.device_names[abs_di] == device_name) {
+ participate = true;
+ col_params->subdiv_rank.push_back(di);
+ }
+ if (abs_di == col_params->source_rank) subdiv_source = di;
+ abs_di++;
+ }
+ if (!participate) col_params->subdiv_rank.push_back(-1);
+ col_params->instance.impl_details.subdiv_source_rank.push_back(
+ subdiv_source);
+ }
+
+ for (int sri = 0; sri < num_subdivs; sri++) {
+ CHECK_GE(col_params->instance.impl_details.subdiv_source_rank[sri], 0);
+ }
+
+ VLOG(2) << collective_util::SubdivPermDebugString(*col_params);
+ return Status::OK();
+}
+
+Status HierarchicalTreeBroadcaster::InitializeCollectiveContext(
+ CollectiveContext* col_ctx) {
+ CHECK(col_ctx->dev_mgr);
+ col_ctx_ = col_ctx;
+ col_params_ = &col_ctx->col_params;
+ return collective_util::InitializeDeviceAndLocality(
+ col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
+ &col_ctx->device_locality);
+}
+
+void HierarchicalTreeBroadcaster::Run(StatusCallback done) {
+ CHECK(col_ctx_);
+ CHECK(col_params_);
+ done_ = std::move(done);
+ is_source_ = col_params_->is_source;
+ RunTree();
+}
+
+// Binary tree parent/child relations are trivial to calculate, i.e.
+// device at rank r is the parent of 2r+1 and 2r+2. The one exception
+// is if the source is not rank 0. We treat that case as though the
+// source is appended to the front of the rank ordering as well as
+// continuing to occupy its current position. Hence we calculate as
+// though each device's rank is actually r+1, then subtract 1 again to
+// get the descendant ranks. If the source is not rank 0 then its
+// descendants include both {0,1} and the descendants of its current
+// position. Where a non-0-rank source is a descendant of another
+// device, no send to it is necessary.
+
+/* static*/
+int HierarchicalTreeBroadcaster::TreeRecvFrom(const CollectiveParams& cp,
+ int subdiv) {
+ DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
+ int my_rank = cp.subdiv_rank[subdiv];
+ if (-1 == my_rank) return -1;
+
+ const auto& impl = cp.instance.impl_details;
+ DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
+ int source_rank = impl.subdiv_source_rank[subdiv];
+ if (my_rank == source_rank) return -1;
+ if (source_rank == 0) {
+ return (my_rank - 1) / 2;
+ } else {
+ int predecessor_rank = (my_rank / 2) - 1;
+ return (predecessor_rank < 0) ? source_rank : predecessor_rank;
+ }
+}
+
+/* static */
+void HierarchicalTreeBroadcaster::TreeSendTo(const CollectiveParams& cp,
+ int subdiv,
+ std::vector<int>* targets) {
+ DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
+ int my_rank = cp.subdiv_rank[subdiv];
+ if (-1 == my_rank) return;
+
+ const auto& impl = cp.instance.impl_details;
+ DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
+ int source_rank = impl.subdiv_source_rank[subdiv];
+
+ int group_size = 0;
+ for (int i = 0; i < impl.subdiv_permutations[subdiv].size(); i++) {
+ if (impl.subdiv_permutations[subdiv][i] >= 0) {
+ group_size++;
+ }
+ }
+
+ targets->clear();
+ int successor_rank = 0;
+ if (source_rank == 0) {
+ successor_rank = (2 * my_rank) + 1;
+ } else {
+ successor_rank = (2 * (my_rank + 1));
+ }
+ DCHECK_NE(successor_rank, my_rank);
+ if (cp.is_source && source_rank != 0) {
+ // The source sends to rank 0,1 in addition to its positional
+ // descendants.
+ if (group_size > 1) {
+ targets->push_back(0);
+ }
+ if (group_size > 2 && source_rank != 1) {
+ targets->push_back(1);
+ }
+ }
+ for (int i = 0; i < 2; ++i) {
+ if (successor_rank < group_size && successor_rank != source_rank) {
+ targets->push_back(successor_rank);
+ }
+ ++successor_rank;
+ }
+}
+
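
TreeRecvFrom and TreeSendTo encode the shifted binary tree described in the comment above them; in the common source_rank == 0 case they reduce to the textbook parent/child relations. A standalone check of that case:

    #include <cassert>

    // Binary-tree relations used when the broadcast source has rank 0.
    int Parent(int r) { return (r - 1) / 2; }
    int LeftChild(int r) { return 2 * r + 1; }
    int RightChild(int r) { return 2 * r + 2; }

    int main() {
      // rank 0 sends to 1,2; rank 1 sends to 3,4; ranks 3 and 4 receive from 1.
      assert(LeftChild(0) == 1 && RightChild(0) == 2);
      assert(LeftChild(1) == 3 && RightChild(1) == 4);
      assert(Parent(3) == 1 && Parent(4) == 1);
    }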
+// Executes a hierarchical tree broadcast.
+// Each subdiv is a broadcast between a subset of the devices.
+// If there is only one task, there is one subdiv comprising a broadcast between
+// all devices belonging to the task.
+// If there are n tasks, n>1, then there are n+1 subdivs. In the first (global)
+// subdiv, one device from each task participates in a binary tree broadcast.
+// Each task receives a copy of the tensor on one device via this broadcast.
+// Subsequent subdivs correspond to intra-task broadcasts. Subdiv i+1
+// corresponds to broadcast between all devices on task i. Thus, each task
+// participates in at most 2 subdivs.
+void HierarchicalTreeBroadcaster::RunTree() {
+ int num_subdivs = static_cast<int>(col_params_->subdiv_rank.size());
+ // TODO(b/78352018): this is easily improved when a node participates in both
+ // the first and second subdivisions. It would first send to its descendants
+ // in the first subdiv, then wait until all pending ops are finished before
+ // sending to descendants in the second subdiv. A better implementation would
+ // collapse the two send blocks.
+ for (int si = 0; si < num_subdivs; si++) {
+ int my_rank = col_params_->subdiv_rank[si];
+ // If rank is -1, this device does not participate in this subdiv.
+ if (-1 == my_rank) continue;
+ int source_rank = col_params_->instance.impl_details.subdiv_source_rank[si];
+ if (VLOG_IS_ON(1)) {
+ string subdiv_buf;
+ for (int r : col_params_->instance.impl_details.subdiv_permutations[si]) {
+ strings::StrAppend(&subdiv_buf, r, ",");
+ }
+ VLOG(1) << "Running Broadcast tree device=" << col_ctx_->device_name
+ << " subdiv=" << si << " perm=" << subdiv_buf
+ << " my_rank=" << my_rank << " source_rank=" << source_rank;
+ }
+
+ mutex mu; // also guards status_ while callbacks are pending
+ int pending_count = 0; // GUARDED_BY(mu)
+ condition_variable all_done;
+
+ if (my_rank >= 0 && my_rank != source_rank) {
+ // Begin by receiving the value.
+ int recv_from_rank = TreeRecvFrom(*col_params_, si);
+ Notification note;
+ DispatchRecv(si, recv_from_rank, my_rank, col_ctx_->output,
+ [this, &mu, &note](const Status& s) {
+ mutex_lock l(mu);
+ status_.Update(s);
+ note.Notify();
+ });
+ note.WaitForNotification();
+ }
+
+ // Then forward value to all descendent devices.
+ if (my_rank >= 0 && status_.ok()) {
+ std::vector<int> send_to_ranks;
+ TreeSendTo(*col_params_, si, &send_to_ranks);
+ for (int i = 0; i < send_to_ranks.size(); ++i) {
+ int target_rank = send_to_ranks[i];
+ {
+ mutex_lock l(mu);
+ ++pending_count;
+ }
+ DispatchSend(si, target_rank, my_rank,
+ (is_source_ ? col_ctx_->input : col_ctx_->output),
+ [this, &mu, &pending_count, &all_done](const Status& s) {
+ mutex_lock l(mu);
+ status_.Update(s);
+ --pending_count;
+ if (pending_count == 0) {
+ all_done.notify_all();
+ }
+ });
+ }
+ }
+
+ // For the original source device, we copy input to output if they are
+ // different.
+ // If there is only 1 subdiv, we do this in that subdiv. If there is more
+ // than 1 subdiv, then the original source device will participate in 2
+ // subdivs - the global inter-task broadcast and one local intra-task
+ // broadcast. In this case, we perform the copy in the second subdiv for
+ // this device.
+ if (status_.ok() && is_source_ && (1 == num_subdivs || 0 != si)) {
+ VLOG(2) << "copying input to output for device=" << col_ctx_->device_name
+ << " subdiv=" << si;
+ if (col_ctx_->input != col_ctx_->output &&
+ (DMAHelper::base(col_ctx_->input) !=
+ DMAHelper::base(col_ctx_->output))) {
+ {
+ mutex_lock l(mu);
+ ++pending_count;
+ }
+ DeviceContext* op_dev_ctx = col_ctx_->op_ctx->op_device_context();
+ CollectiveRemoteAccessLocal::MemCpyAsync(
+ op_dev_ctx, op_dev_ctx, col_ctx_->device, col_ctx_->device,
+ col_ctx_->op_ctx->input_alloc_attr(0),
+ col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input,
+ col_ctx_->output, 0 /*stream_index*/,
+ [this, &mu, &pending_count, &all_done](const Status& s) {
+ mutex_lock l(mu);
+ status_.Update(s);
+ --pending_count;
+ if (0 == pending_count) {
+ all_done.notify_all();
+ }
+ });
+ }
+ }
+
+ // Then wait for all pending actions to complete.
+ {
+ mutex_lock l(mu);
+ if (pending_count > 0) {
+ all_done.wait(l);
+ }
+ }
+ }
+ VLOG(2) << "device=" << col_ctx_->device_name << " return status " << status_;
+ done_(status_);
+}
+
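
Each subdivision in RunTree synchronizes with a counting barrier built from mu, pending_count, and all_done: every dispatched send (or the final input-to-output copy) increments the counter, every completion callback decrements it and signals at zero, and the end of the loop body blocks until the counter drains. A standalone sketch with std primitives standing in for TF's mutex types:

    #include <condition_variable>
    #include <functional>
    #include <mutex>
    #include <thread>

    int main() {
      std::mutex mu;
      std::condition_variable all_done;
      int pending_count = 0;  // guarded by mu

      auto dispatch = [&](std::function<void()> work) {
        { std::lock_guard<std::mutex> l(mu); ++pending_count; }
        std::thread([&, work] {
          work();  // stands in for one async send completing
          std::lock_guard<std::mutex> l(mu);
          if (--pending_count == 0) all_done.notify_all();
        }).detach();
      };

      for (int i = 0; i < 4; ++i) dispatch([] { /* send to one target rank */ });

      // Block until every dispatched callback has fired, as RunTree does at
      // the end of each subdivision.
      std::unique_lock<std::mutex> l(mu);
      all_done.wait(l, [&] { return pending_count == 0; });
    }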
+void HierarchicalTreeBroadcaster::DispatchSend(int subdiv, int dst_rank,
+ int src_rank,
+ const Tensor* src_tensor,
+ const StatusCallback& done) {
+ string send_buf_key =
+ BroadcastBufKey(col_ctx_->exec_key, subdiv, src_rank, dst_rank);
+ int dst_idx =
+ col_params_->instance.impl_details.subdiv_permutations[subdiv][dst_rank];
+ VLOG(3) << "DispatchSend " << send_buf_key << " from_device "
+ << col_ctx_->device_name << " to_device "
+ << col_params_->instance.device_names[dst_idx] << " subdiv=" << subdiv
+ << " dst_rank=" << dst_rank << " dst_idx=" << dst_idx;
+ col_ctx_->col_exec->PostToPeer(col_params_->instance.device_names[dst_idx],
+ col_params_->instance.task_names[dst_idx],
+ send_buf_key, col_ctx_->device,
+ col_ctx_->op_ctx->op_device_context(),
+ col_ctx_->op_ctx->output_alloc_attr(0),
+ src_tensor, col_ctx_->device_locality, done);
+}
+
+void HierarchicalTreeBroadcaster::DispatchRecv(int subdiv, int src_rank,
+ int dst_rank, Tensor* dst_tensor,
+ const StatusCallback& done) {
+ string recv_buf_key =
+ BroadcastBufKey(col_ctx_->exec_key, subdiv, src_rank, dst_rank);
+ int src_idx =
+ col_params_->instance.impl_details.subdiv_permutations[subdiv][src_rank];
+ VLOG(3) << "DispatchRecv " << recv_buf_key << " from_device "
+ << col_params_->instance.device_names[src_idx] << " to_device "
+ << col_ctx_->device_name << " subdiv=" << subdiv
+ << " src_rank=" << src_rank << " src_idx=" << src_idx;
+ col_ctx_->col_exec->RecvFromPeer(
+ col_params_->instance.device_names[src_idx],
+ col_params_->instance.task_names[src_idx],
+ col_params_->task.is_local[src_idx], recv_buf_key, col_ctx_->device,
+ col_ctx_->op_ctx->op_device_context(),
+ col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,
+ col_ctx_->device_locality, 0 /*stream_index*/, done);
+}
+
+REGISTER_COLLECTIVE(HierarchicalTreeBroadcast, HierarchicalTreeBroadcaster);
+
+} // namespace tensorflow
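
DispatchSend and DispatchRecv never hand each other addresses; both ends derive the same string key from (exec_key, subdiv, src_rank, dst_rank), which lets PostToPeer and RecvFromPeer rendezvous on the posted buffer regardless of which side runs first. A sketch of such a key scheme (BroadcastBufKey itself is defined earlier in this file; this stand-in only illustrates the idea):

    #include <string>

    // Illustrative stand-in for BroadcastBufKey: the sender and receiver of
    // one edge compute the identical string, which serves as the rendezvous
    // key for the posted buffer.
    std::string ExampleBroadcastBufKey(const std::string& exec_key, int subdiv,
                                       int src_rank, int dst_rank) {
      // Any injective encoding works; the only requirement is that both
      // sides agree on it.
      return exec_key + ";bcast;" + std::to_string(subdiv) + ";" +
             std::to_string(src_rank) + ";" + std::to_string(dst_rank);
    }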
diff --git a/tensorflow/core/common_runtime/broadcaster.h b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h
index 799228b161..ceb9baad30 100644
--- a/tensorflow/core/common_runtime/broadcaster.h
+++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h
@@ -12,25 +12,40 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BROADCASTER_H_
-#define TENSORFLOW_CORE_COMMON_RUNTIME_BROADCASTER_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_
#include <vector>
+
#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/framework/collective.h"
-#include "tensorflow/core/framework/device_attributes.pb.h"
namespace tensorflow {
-// Tree-algorithm implementation of collective broadcast.
-class Broadcaster {
+// Hierarchical tree-algorithm implementation of collective broadcast.
+class HierarchicalTreeBroadcaster : public CollectiveImplementationInterface {
public:
- Broadcaster(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
- OpKernelContext* ctx, OpKernelContext::Params* params,
- const CollectiveParams& col_params, const string& exec_key,
- int64 step_id, Tensor* output);
+ HierarchicalTreeBroadcaster();
+ ~HierarchicalTreeBroadcaster() override = default;
+
+ // Establishes the subdiv permutations needed for a hierarchical broadcast.
+ // If all devices are local, establishes a single subdiv comprising all
+ // devices. If any devices are on a different task, establishes n+1 subdivs
+ // for n tasks.
+ // The first subdiv comprises one device per task which gets the tensor on
+ // each task. Subdiv i+1 corresponds to a task-local tree-broadcast for task
+ // i.
+ Status InitializeCollectiveParams(CollectiveParams* col_params) override;
- void Run(StatusCallback done);
+ // Initializes members of CollectiveContext not yet initialized, i.e. device
+ // and device_locality. Also saves the CollectiveContext in this object.
+ Status InitializeCollectiveContext(CollectiveContext* col_ctx) override;
+
+ // Begins async execution of the hierarchical tree broadcast.
+ // Must be called in a blockable thread.
+ // TODO(b/80529858): remove the previous warning when we have a dedicated
+ // collective threadpool.
+ void Run(StatusCallback done) override;
// Returns the rank of the device from which this device should receive
// its value, -1 if no value should be received.
@@ -42,32 +57,29 @@ class Broadcaster {
std::vector<int>* targets);
private:
+ // Gets the task to which the device at `device_rank` belongs.
+ int GetDeviceTask(int device_rank, const std::vector<int>& dev_per_task);
+
// Sends `src_tensor` asynchronously from this device to device at `dst_rank`
// in `subdiv`. Calls `done` upon completion.
void DispatchSend(int subdiv, int dst_rank, int src_rank,
const Tensor* src_tensor, const StatusCallback& done);
+
// Receives a tensor into the memory buffer owned by `dst_tensor` at this
// device from device at `src_rank` in `subdiv`. Calls `done` upon
// completion.
void DispatchRecv(int subdiv, int src_rank, int dst_rank, Tensor* dst_tensor,
const StatusCallback& done);
+
// Executes the hierarchical broadcast defined by this op.
void RunTree();
- Status status_;
- CollectiveExecutor* col_exec_; // Not owned
- const DeviceMgr* dev_mgr_; // Not owned
- OpKernelContext* ctx_; // Not owned
- const CollectiveParams& col_params_;
- const string exec_key_;
- const int rank_;
- const bool is_source_;
- Tensor* output_; // Not owned
- std::unique_ptr<CollectiveAdapter> ca_;
+ CollectiveContext* col_ctx_; // Not owned
+ const CollectiveParams* col_params_; // Not owned
StatusCallback done_;
- Device* device_; // The device for which this instance labors
- DeviceLocality device_locality_;
+ Status status_;
+ bool is_source_;
};
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_COMMON_RUNTIME_BROADCASTER_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_
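
The subdiv layout documented in the class comment shows up concretely in the tests below: with 4 tasks of 8 GPUs each and source 0, subdiv 0 is {0, 8, 16, 24} (one representative device per task) and subdivs 1..4 are the per-task device groups. A small sketch of how the inter-task subdiv could be derived, assuming, as the test expectations suggest, that each task is represented by its lowest-ranked device unless it owns the source, in which case the source represents it:

    #include <vector>

    // Computes the device ranks forming the inter-task subdiv (subdiv 0) from
    // the per-task device counts and the global source rank. Illustrative
    // only; the real logic lives in InitializeCollectiveParams.
    std::vector<int> InterTaskSubdiv(const std::vector<int>& dev_per_task,
                                     int source_rank) {
      std::vector<int> subdiv0;
      int task_start = 0;
      for (int num_devices : dev_per_task) {
        int representative = task_start;  // Lowest-ranked device of the task.
        if (source_rank >= task_start &&
            source_rank < task_start + num_devices) {
          representative = source_rank;  // The source represents its own task.
        }
        subdiv0.push_back(representative);
        task_start += num_devices;
      }
      return subdiv0;
    }

    // E.g. dev_per_task = {4, 4, 6, 8} with source_rank = 9 yields
    // {0, 4, 9, 14}, matching InitializeParams4TasksVariableGPU below.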
diff --git a/tensorflow/core/common_runtime/broadcaster_test.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
index 3960fc6c97..da0e359cf8 100644
--- a/tensorflow/core/common_runtime/broadcaster_test.cc
+++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/common_runtime/broadcaster.h"
+#include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
#include <algorithm>
#include "tensorflow/core/common_runtime/base_collective_executor.h"
@@ -41,7 +41,7 @@ static int64 kStepId = 123;
// The test harness won't allow a mixture of fixture and non-fixture
// tests in one file, so this is a trivial fixture for tests that don't
-// need the heavy-weight BroadcasterTest fixture.
+// need the heavy-weight HierarchicalTreeBroadcasterTest fixture.
class TrivialTest : public ::testing::Test {
protected:
TrivialTest() {}
@@ -53,23 +53,23 @@ class TrivialTest : public ::testing::Test {
// R = tested rank
// RF = receive-from rank
// ST = send_to rank vector
-#define DEF_TL_TEST(D, S, R, RF, ST) \
- TEST_F(TrivialTest, TreeLinks_##D##Devs_##S##Source_##R##Rank) { \
- CollectiveParams cp; \
- cp.group.group_size = D; \
- cp.instance.impl_details.subdiv_source_rank = {S}; \
- cp.instance.impl_details.subdiv_permutations.push_back( \
- std::vector<int>(D, 0)); \
- cp.subdiv_rank = {R}; \
- cp.is_source = (S == R); \
- EXPECT_EQ(RF, Broadcaster::TreeRecvFrom(cp, 0)); \
- std::vector<int> expected = ST; \
- std::vector<int> send_to; \
- Broadcaster::TreeSendTo(cp, 0, &send_to); \
- ASSERT_EQ(expected.size(), send_to.size()); \
- for (int i = 0; i < expected.size(); ++i) { \
- EXPECT_EQ(expected[i], send_to[i]); \
- } \
+#define DEF_TL_TEST(D, S, R, RF, ST) \
+ TEST_F(TrivialTest, TreeLinks_##D##Devs_##S##Source_##R##Rank) { \
+ CollectiveParams cp; \
+ cp.group.group_size = D; \
+ cp.instance.impl_details.subdiv_source_rank = {S}; \
+ cp.instance.impl_details.subdiv_permutations.push_back( \
+ std::vector<int>(D, 0)); \
+ cp.subdiv_rank = {R}; \
+ cp.is_source = (S == R); \
+ EXPECT_EQ(RF, HierarchicalTreeBroadcaster::TreeRecvFrom(cp, 0)); \
+ std::vector<int> expected = ST; \
+ std::vector<int> send_to; \
+ HierarchicalTreeBroadcaster::TreeSendTo(cp, 0, &send_to); \
+ ASSERT_EQ(expected.size(), send_to.size()); \
+ for (int i = 0; i < expected.size(); ++i) { \
+ EXPECT_EQ(expected[i], send_to[i]); \
+ } \
}
#define V(...) std::vector<int>({__VA_ARGS__})
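
Each DEF_TL_TEST instantiation token-pastes its arguments into a unique test name. For example, DEF_TL_TEST(8, 7, 7, -1, V(0, 1)) (visible in the hunk below) expands mechanically to:

    TEST_F(TrivialTest, TreeLinks_8Devs_7Source_7Rank) {
      CollectiveParams cp;
      cp.group.group_size = 8;
      cp.instance.impl_details.subdiv_source_rank = {7};
      cp.instance.impl_details.subdiv_permutations.push_back(
          std::vector<int>(8, 0));
      cp.subdiv_rank = {7};
      cp.is_source = (7 == 7);
      EXPECT_EQ(-1, HierarchicalTreeBroadcaster::TreeRecvFrom(cp, 0));
      std::vector<int> expected = std::vector<int>({0, 1});
      std::vector<int> send_to;
      HierarchicalTreeBroadcaster::TreeSendTo(cp, 0, &send_to);
      ASSERT_EQ(expected.size(), send_to.size());
      for (int i = 0; i < expected.size(); ++i) {
        EXPECT_EQ(expected[i], send_to[i]);
      }
    }

That is, rank 7 of an 8-device group with source 7 receives from nobody (-1) and sends to ranks 0 and 1.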
@@ -130,7 +130,7 @@ DEF_TL_TEST(8, 7, 7, -1, V(0, 1))
// Wraps CollectiveRemoteAccessLocal with the ability to return an
// error status to the N'th action.
-// TODO(tucker): factor out of this file and ring_reducer_test.cc
+// TODO(b/113171733): factor out of this file and ring_reducer_test.cc
// into a single common source.
class FailTestRMA : public CollectiveRemoteAccessLocal {
public:
@@ -187,31 +187,32 @@ class FailTestRMA : public CollectiveRemoteAccessLocal {
int fail_after_ GUARDED_BY(mu_);
};
-class BroadcasterTest : public ::testing::Test {
+class HierarchicalTreeBroadcasterTest : public ::testing::Test {
protected:
- BroadcasterTest() : device_type_(DEVICE_CPU) {}
+ HierarchicalTreeBroadcasterTest() : device_type_(DEVICE_CPU) {}
- ~BroadcasterTest() override {
+ ~HierarchicalTreeBroadcasterTest() override {
stop_ = true;
- for (auto i : instances_) {
- delete i;
- }
+ for (auto i : instances_) delete i;
if (col_exec_) col_exec_->Unref();
}
- void SetUp() override {
-#if GOOGLE_CUDA
+#ifdef GOOGLE_CUDA
+ void InitGPUDevices() {
auto device_factory = DeviceFactory::GetFactory("GPU");
CHECK(device_factory);
SessionOptions options;
Status s = device_factory->CreateDevices(
options, "/job:worker/replica:0/task:0", &gpu_devices_);
CHECK(s.ok());
-#endif
}
+#endif
void Init(int num_workers, int num_devices_per_worker, DataType dtype,
const DeviceType& device_type, int fail_after) {
+#ifdef GOOGLE_CUDA
+ InitGPUDevices();
+#endif
VLOG(2) << "num_workers=" << num_workers
<< " num_devices_per_worker=" << num_devices_per_worker;
int total_num_devices = num_workers * num_devices_per_worker;
@@ -400,8 +401,6 @@ class BroadcasterTest : public ::testing::Test {
return GetKernel(node_def, device_type, device);
}
- void BuildColParams() {}
-
template <typename T>
void RunTest(DataType dtype, const DeviceType& device_type, int num_workers,
int num_devices, int tensor_len, int fail_after,
@@ -511,10 +510,47 @@ class BroadcasterTest : public ::testing::Test {
}
}
+ void RunSubdivPermsTest(
+ CollectiveParams* cp,
+ const std::vector<std::vector<int>>& expected_subdiv_perms,
+ const std::vector<int>& expected_subdiv_rank,
+ const std::vector<int>& expected_subdiv_source_rank) {
+ col_exec_ = nullptr;
+ cp->instance.impl_details.subdiv_permutations.clear();
+ cp->subdiv_rank.clear();
+ cp->instance.impl_details.subdiv_source_rank.clear();
+ // Create a stub broadcaster only for testing param initialization.
+ HierarchicalTreeBroadcaster broadcaster;
+ TF_CHECK_OK(broadcaster.InitializeCollectiveParams(cp));
+ EXPECT_EQ(expected_subdiv_perms,
+ cp->instance.impl_details.subdiv_permutations);
+ EXPECT_EQ(expected_subdiv_rank, cp->subdiv_rank);
+ EXPECT_EQ(expected_subdiv_source_rank,
+ cp->instance.impl_details.subdiv_source_rank);
+ }
+
+ void PrepColParamsForSubdivPermsTest(CollectiveParams* cp, int num_tasks,
+ int num_gpus) {
+ cp->group.device_type = DeviceType("GPU");
+ cp->group.num_tasks = num_tasks;
+ cp->group.group_size = num_tasks * num_gpus;
+ cp->instance.type = BROADCAST_COLLECTIVE;
+ cp->instance.impl_details.collective_name = "HierarchicalTreeBroadcast";
+ for (int ti = 0; ti < num_tasks; ti++) {
+ string task_name = strings::StrCat("/job:worker/replica:0/task:", ti);
+ for (int di = 0; di < num_gpus; di++) {
+ string dev_name = strings::StrCat(task_name, "/device:GPU:", di);
+ cp->instance.task_names.push_back(task_name);
+ cp->instance.device_names.push_back(dev_name);
+ }
+ }
+ }
+
class DeviceInstance {
public:
DeviceInstance(int rank, const string& dev_name,
- const DeviceType& device_type, BroadcasterTest* parent)
+ const DeviceType& device_type,
+ HierarchicalTreeBroadcasterTest* parent)
: parent_(parent),
dev_name_(dev_name),
device_type_(device_type),
@@ -636,21 +672,20 @@ class BroadcasterTest : public ::testing::Test {
ctx.allocate_output(0, tensor_.shape(), &output_tensor_ptr));
}
CHECK_EQ(output_tensor_ptr, ctx.mutable_output(0));
+ const Tensor* input_tensor_ptr =
+ col_params_.is_source ? &tensor_ : nullptr;
// Prepare a Broadcaster instance.
string exec_key =
strings::StrCat(col_params_.instance.instance_key, ":0:0");
- Broadcaster broadcaster(parent_->col_exec_, parent_->dev_mgr_.get(), &ctx,
- &op_params, col_params_, exec_key, kStepId,
- output_tensor_ptr);
-
- // Start execution in a threadpool then wait for completion.
- Notification notification;
- broadcaster.Run([this, &notification](Status s) {
- status_ = s;
- notification.Notify();
- });
- notification.WaitForNotification();
+ HierarchicalTreeBroadcaster broadcaster;
+ CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
+ &ctx, &op_params, col_params_, exec_key,
+ kStepId, input_tensor_ptr, output_tensor_ptr);
+ TF_CHECK_OK(broadcaster.InitializeCollectiveContext(&col_ctx));
+
+ // Run the broadcast.
+ broadcaster.Run([this](Status s) { status_ = s; });
if (status_.ok()) {
CHECK(tensor_.CopyFrom(*ctx.mutable_output(0), tensor_.shape()));
}
@@ -658,15 +693,13 @@ class BroadcasterTest : public ::testing::Test {
dev_ctx->Unref();
}
- BroadcasterTest* parent_;
+ HierarchicalTreeBroadcasterTest* parent_;
string dev_name_;
DeviceType device_type_ = DEVICE_CPU;
int rank_;
Tensor tensor_;
Device* device_;
CollectiveParams col_params_;
- std::unique_ptr<CollectiveAdapter> ca_;
- std::unique_ptr<OpKernelContext> ctx_;
Status status_;
}; // class DeviceInstance
@@ -688,6 +721,118 @@ class BroadcasterTest : public ::testing::Test {
int failure_count_ GUARDED_BY(mu_) = 0;
};
+TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams1Task8GPU) {
+ CollectiveParams cp;
+ PrepColParamsForSubdivPermsTest(&cp, 1, 8);
+
+ // source 0 device 0
+ cp.source_rank = 0;
+ cp.default_rank = 0;
+ RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {0});
+
+ // source 2 device 2
+ cp.source_rank = 2;
+ cp.default_rank = 2;
+ RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {2}, {2});
+
+ // source 2 device 0
+ cp.source_rank = 2;
+ cp.default_rank = 0;
+ RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {2});
+}
+
+TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4Tasks8GPU) {
+ CollectiveParams cp;
+ PrepColParamsForSubdivPermsTest(&cp, 4, 8);
+
+ // source 0 device 0
+ cp.source_rank = 0;
+ cp.default_rank = 0;
+ RunSubdivPermsTest(&cp,
+ {{0, 8, 16, 24},
+ {0, 1, 2, 3, 4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13, 14, 15},
+ {16, 17, 18, 19, 20, 21, 22, 23},
+ {24, 25, 26, 27, 28, 29, 30, 31}},
+ {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0});
+
+ // source 2 device 0
+ cp.source_rank = 2;
+ cp.default_rank = 0;
+ RunSubdivPermsTest(&cp,
+ {{2, 8, 16, 24},
+ {0, 1, 2, 3, 4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13, 14, 15},
+ {16, 17, 18, 19, 20, 21, 22, 23},
+ {24, 25, 26, 27, 28, 29, 30, 31}},
+ {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0});
+
+ // source 9 device 9
+ cp.source_rank = 9;
+ cp.default_rank = 9;
+ RunSubdivPermsTest(&cp,
+ {{0, 9, 16, 24},
+ {0, 1, 2, 3, 4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13, 14, 15},
+ {16, 17, 18, 19, 20, 21, 22, 23},
+ {24, 25, 26, 27, 28, 29, 30, 31}},
+ {1, -1, 1, -1, -1}, {1, 0, 1, 0, 0});
+}
+
+TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4TasksVariableGPU) {
+ CollectiveParams cp;
+ int num_tasks = 4;
+ cp.group.device_type = DeviceType("GPU");
+ cp.group.num_tasks = num_tasks;
+ cp.group.group_size = 0;
+ cp.instance.type = BROADCAST_COLLECTIVE;
+ cp.instance.impl_details.collective_name = "HierarchicalTreeBroadcast";
+ std::vector<int> dev_per_task = {4, 4, 6, 8};
+ for (int ti = 0; ti < cp.group.num_tasks; ti++) {
+ string task_name = strings::StrCat("/job:worker/replica:0/task:", ti);
+ for (int di = 0; di < dev_per_task[ti]; di++) {
+ string dev_name = strings::StrCat(task_name, "/device:GPU:", di);
+ cp.instance.task_names.push_back(task_name);
+ cp.instance.device_names.push_back(dev_name);
+ cp.group.group_size++;
+ }
+ }
+
+ // source 0 device 0
+ cp.source_rank = 0;
+ cp.default_rank = 0;
+ RunSubdivPermsTest(&cp,
+ {{0, 4, 8, 14},
+ {0, 1, 2, 3},
+ {4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13},
+ {14, 15, 16, 17, 18, 19, 20, 21}},
+ {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0});
+
+ // source 2 device 0
+ cp.source_rank = 2;
+ cp.default_rank = 0;
+ RunSubdivPermsTest(&cp,
+ {{2, 4, 8, 14},
+ {0, 1, 2, 3},
+ {4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13},
+ {14, 15, 16, 17, 18, 19, 20, 21}},
+ {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0});
+
+ // source 9 device 5
+ cp.source_rank = 9;
+ cp.default_rank = 5;
+ RunSubdivPermsTest(&cp,
+ {{0, 4, 9, 14},
+ {0, 1, 2, 3},
+ {4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13},
+ {14, 15, 16, 17, 18, 19, 20, 21}},
+ {-1, -1, 1, -1, -1}, {2, 0, 0, 1, 0});
+}
+
+// TODO(b/113171733): change to use TEST_P.
// Tests of full broadcast algorithm, with different device and
// data types.
// B = data element type
@@ -697,7 +842,7 @@ class BroadcasterTest : public ::testing::Test {
// L = tensor length
// A = abort after count
#define DEF_TEST(B, T, W, D, L, A, F) \
- TEST_F(BroadcasterTest, \
+ TEST_F(HierarchicalTreeBroadcasterTest, \
DaTy##B##_DevTy##T##_Wkr##W##_Dev##D##_Len##L##_Abt##A##_Fw##F) { \
DataType dtype = DT_##B; \
switch (dtype) { \
diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
index 995a15a299..555b43f655 100644
--- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
+++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
-#define TENSORFLOW_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
#include <string>
#include <vector>
@@ -65,4 +65,4 @@ class Benchmark {
} // end namespace test
} // end namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
diff --git a/tensorflow/core/common_runtime/local_device.cc b/tensorflow/core/common_runtime/local_device.cc
index 873182371e..db5022d56e 100644
--- a/tensorflow/core/common_runtime/local_device.cc
+++ b/tensorflow/core/common_runtime/local_device.cc
@@ -62,7 +62,7 @@ struct LocalDevice::EigenThreadPoolInfo {
LocalDevice::LocalDevice(const SessionOptions& options,
const DeviceAttributes& attributes)
- : Device(options.env, attributes), owned_tp_info_(nullptr) {
+ : TracingDevice(options.env, attributes), owned_tp_info_(nullptr) {
// Log info messages if TensorFlow is not compiled with instructions that
// could speed up performance and are available on the current CPU.
port::InfoAboutUnusedCPUFeatures();
diff --git a/tensorflow/core/common_runtime/local_device.h b/tensorflow/core/common_runtime/local_device.h
index 84a4f66db4..9a82fb7204 100644
--- a/tensorflow/core/common_runtime/local_device.h
+++ b/tensorflow/core/common_runtime/local_device.h
@@ -13,10 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_LOCAL_DEVICE_H_
-#define TENSORFLOW_COMMON_RUNTIME_LOCAL_DEVICE_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOCAL_DEVICE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_LOCAL_DEVICE_H_
#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/tracing_device.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/platform/macros.h"
@@ -31,7 +32,7 @@ struct SessionOptions;
// initializes a shared Eigen compute device used by both. This
// should eventually be removed once we refactor ThreadPoolDevice and
// GPUDevice into more 'process-wide' abstractions.
-class LocalDevice : public Device {
+class LocalDevice : public TracingDevice {
public:
LocalDevice(const SessionOptions& options,
const DeviceAttributes& attributes);
@@ -54,4 +55,4 @@ class LocalDevice : public Device {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_LOCAL_DEVICE_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_LOCAL_DEVICE_H_
diff --git a/tensorflow/core/common_runtime/lower_while_op.cc b/tensorflow/core/common_runtime/lower_while_op.cc
new file mode 100644
index 0000000000..1f5da133e9
--- /dev/null
+++ b/tensorflow/core/common_runtime/lower_while_op.cc
@@ -0,0 +1,427 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/lower_while_op.h"
+#include "tensorflow/core/common_runtime/lower_if_op.h"
+
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
+
+namespace tensorflow {
+
+namespace {
+
+using NodeOut = NodeBuilder::NodeOut;
+
+// Helper to convert a functional While op to its lowered form.
+//
+// Example:
+//
+// Input graph:
+//
+// loop_var -> WhileOp<cond_func, body_func> -> consumer
+//
+// Output graph(top to down flow):
+//
+// loop_var
+// |
+// Enter
+// |
+// inlined_cond_func ---<--- Merge -----<----- NextIteration
+// | | |
+// V V ^
+// | | |
+// LoopCond ------>-------- Switch ---->---- inlined_body_func
+// |
+// Exit
+// |
+// consumer
+class LowerWhileHelper {
+ public:
+ static Status Run(Node* while_op, const string& cond_fn_name,
+ const string& body_fn_name, Graph* graph) {
+ LowerWhileHelper helper(while_op, cond_fn_name, body_fn_name, graph);
+ return helper.RunInternal();
+ }
+
+ private:
+ // Create a LowerWhileHelper to create the lowering of While op that has cond
+ // and body functions named `cond_fn_name` and `body_fn_name` respectively in
+ // the given graph.
+ LowerWhileHelper(Node* while_op, const string& cond_fn_name,
+ const string& body_fn_name, Graph* graph);
+
+ Status RunInternal();
+
+ // Creates an Enter node for each `while_op_` input and adds them to
+ // `enter_nodes_`. If the `while_op_` has an incoming control edge from a
+ // `src` node we add a control edge from `src` to each Enter node.
+ Status CreateEnterNodes();
+
+ // Creates a Merge node for each Enter node and adds it to `merge_nodes_`.
+ // Initially, both inputs of a Merge node are the Enter node. Input at
+ // index 1 is later updated to the output of the NextIteration node in
+ // `UpdateMergeNodes`.
+ Status CreateMergeNodes();
+
+ // Creates the call node for cond func and stores in `cond_call_node_`.
+ // This gets inlined later in `InlineCallNodes`.
+ Status CreateCondFuncCallNode();
+
+ // Creates a Switch node for each loop var and adds it to `switch_nodes_`.
+ // Output at index 1 (true) of a Switch node is fed into the loop body.
+ // Output at index 0 (false) of a Switch node is fed into the Exit nodes.
+ Status CreateSwitchNodes();
+
+ // Creates the call node for body func and stores in `body_call_node_`.
+ // This gets inlined later in `InlineCallNodes`.
+ Status CreateBodyFuncCallNode();
+
+ // Creates an Exit node for each loop var and adds it to `exit_nodes_`.
+ // These are fed into the consumers of the `while_op_`.
+ Status CreateExitNodes();
+
+ // Creates a NextIteration node for each loop var and adds it to
+ // `next_iterations_nodes_`.
+ Status CreateNextIterationNodes();
+
+ // Updates input at index 1 of each merge node created in `CreateMergeNodes`
+ // to use the output of the NextIteration node created in
+ // `CreateNextIterationNodes` instead.
+ Status UpdateMergeNodes();
+
+ // Updates consumers of the original `while_op_` to instead use the outputs
+ // from the exit nodes in `exit_nodes_`. Also updates any outgoing control
+ // edges to depend on `lowered_while_output_` instead.
+ Status UpdateConsumers();
+
+ // Inlines the cond and body functions.
+ Status InlineCallNodes();
+
+ // Returns a unique name containing the name of the While op being
+ // rewritten (name_), the given infix, and a suffix that ensures uniqueness
+ // within the graph.
+ string NewName(const string& infix);
+
+ // The original While op.
+ Node* while_op_;
+ // The call node for the cond branch. This gets inlined.
+ Node* cond_call_node_;
+ // The LoopCond node specifying the loop termination condition.
+ Node* loop_cond_node_;
+ // The call node for the body branch. This gets inlined.
+ Node* body_call_node_;
+ // The IdentityN node with the same outputs as the original While op.
+ Node* lowered_while_output_;
+ Graph* graph_;
+ // Name of the `while_op_`.
+ string name_;
+
+ NodeBuilder cond_call_builder_;
+ NodeBuilder body_call_builder_;
+
+ std::vector<Node*> enter_nodes_;
+ std::vector<Node*> merge_nodes_;
+ std::vector<Node*> switch_nodes_;
+ std::vector<Node*> exit_nodes_;
+ std::vector<Node*> next_iterations_nodes_;
+
+ size_t num_loop_inputs_;
+};
+
+LowerWhileHelper::LowerWhileHelper(Node* while_op, const string& cond_fn_name,
+ const string& body_fn_name, Graph* graph)
+ : while_op_(while_op),
+ graph_(graph),
+ name_(while_op->name()),
+ cond_call_builder_(NewName("cond"), cond_fn_name, graph->op_registry()),
+ body_call_builder_(NewName("body"), body_fn_name, graph->op_registry()),
+ num_loop_inputs_(while_op_->num_inputs()) {
+ // We intentionally `resize` instead of `reserve` space in `enter_nodes_`
+ // because we need to set its elements out of order in `CreateEnterNodes`.
+ enter_nodes_.resize(num_loop_inputs_);
+ merge_nodes_.reserve(num_loop_inputs_);
+ switch_nodes_.reserve(num_loop_inputs_);
+ exit_nodes_.reserve(num_loop_inputs_);
+ next_iterations_nodes_.reserve(num_loop_inputs_);
+}
+
+Status LowerWhileHelper::RunInternal() {
+ TF_RETURN_IF_ERROR(CreateEnterNodes());
+ TF_RETURN_IF_ERROR(CreateMergeNodes());
+ TF_RETURN_IF_ERROR(CreateCondFuncCallNode());
+ TF_RETURN_IF_ERROR(CreateSwitchNodes());
+ TF_RETURN_IF_ERROR(CreateBodyFuncCallNode());
+ TF_RETURN_IF_ERROR(CreateExitNodes());
+ TF_RETURN_IF_ERROR(CreateNextIterationNodes());
+ TF_RETURN_IF_ERROR(UpdateMergeNodes());
+ TF_RETURN_IF_ERROR(UpdateConsumers());
+ TF_RETURN_IF_ERROR(InlineCallNodes());
+ return Status::OK();
+}
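+
RunInternal wires these primitives together in dataflow form; as a mental model, the lowered graph computes the same result as the imperative sketch below (a simplification that collapses the per-loop-variable Enter/Merge/Switch/NextIteration/Exit nodes into plain control flow and ignores frame bookkeeping):

    #include <functional>
    #include <vector>

    // Imperative analogue of the lowered While graph. `T` stands in for the
    // tensor type; `cond` and `body` are the inlined function calls.
    template <typename T>
    std::vector<T> LoweredWhileSemantics(
        std::vector<T> loop_vars,  // Enter: values entering the loop frame.
        const std::function<bool(const std::vector<T>&)>& cond,
        const std::function<std::vector<T>(const std::vector<T>&)>& body) {
      // Merge: the first iteration sees the Enter values, later iterations
      // see the NextIteration values.
      while (cond(loop_vars)) {       // LoopCond feeding the Switch nodes.
        loop_vars = body(loop_vars);  // Switch(true) -> body -> NextIteration.
      }
      return loop_vars;               // Switch(false) -> Exit -> consumers.
    }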
+
+Status LowerWhileHelper::CreateEnterNodes() {
+ // Note: `Node::input_edge` runs in O(num_inputs), so we use
+ // `Node::input_edges` instead so that the loop below runs in O(num_inputs)
+ // time rather than O(num_inputs^2).
+ std::vector<const Edge*> edges;
+ TF_RETURN_IF_ERROR(while_op_->input_edges(&edges));
+ for (const Edge* edge : edges) {
+ Node* enter_node;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName("enter"), "Enter", graph_->op_registry())
+ .Input(NodeOut(edge->src(), edge->src_output()))
+ .Attr("frame_name", name_)
+ .Finalize(graph_, &enter_node));
+ enter_nodes_[edge->dst_input()] = enter_node;
+ }
+ // Create a NoOp node that takes the incoming control edges of the original
+ // While op as its control inputs, and use it as a control input for every
+ // Enter node.
+ std::vector<Node*> control_inputs;
+ for (const Edge* e : while_op_->in_edges()) {
+ if (e->IsControlEdge()) {
+ control_inputs.push_back(e->src());
+ }
+ }
+ if (!control_inputs.empty()) {
+ Node* incoming_control_node;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName("LoopControlInputs"), "NoOp", graph_->op_registry())
+ .ControlInputs(control_inputs)
+ .Finalize(graph_, &incoming_control_node));
+ for (Node* n : enter_nodes_) {
+ graph_->AddControlEdge(incoming_control_node, n);
+ }
+ }
+ return Status::OK();
+}
+
+Status LowerWhileHelper::CreateMergeNodes() {
+ for (Node* enter_node : enter_nodes_) {
+ Node* merge_node;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName("merge"), "Merge", graph_->op_registry())
+ .Input({NodeOut(enter_node, 0), NodeOut(enter_node, 0)})
+ .Finalize(graph_, &merge_node));
+ merge_nodes_.emplace_back(merge_node);
+ }
+ return Status::OK();
+}
+
+Status LowerWhileHelper::CreateCondFuncCallNode() {
+ for (Node* merge_node : merge_nodes_) {
+ cond_call_builder_.Input(NodeOut(merge_node, 0));
+ }
+ TF_RETURN_IF_ERROR(cond_call_builder_.Finalize(graph_, &cond_call_node_));
+ // Add a control edge to make sure the Const nodes in the cond function
+ // are in the same frame as the rest of the function; otherwise
+ // `BuildControlFlowInfo` throws an error.
+ graph_->AddControlEdge(merge_nodes_[0], cond_call_node_);
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName("LoopCond"), "LoopCond", graph_->op_registry())
+ .Input(NodeOut(cond_call_node_, 0))
+ .Finalize(graph_, &loop_cond_node_));
+ return Status::OK();
+}
+
+Status LowerWhileHelper::CreateSwitchNodes() {
+ for (int i = 0; i < num_loop_inputs_; i++) {
+ string op_name;
+ {
+ const Node* input_node;
+ TF_RETURN_IF_ERROR(while_op_->input_node(i, &input_node));
+ op_name = strings::StrCat(input_node->name(), "_switch");
+ }
+ Node* switch_node;
+ string op_type = "Switch";
+ if (IsRefType(merge_nodes_[i]->output_type(0))) {
+ op_type = "RefSwitch";
+ }
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName(op_name), op_type, graph_->op_registry())
+ .Input(NodeOut(merge_nodes_[i], 0))
+ .Input(NodeOut(loop_cond_node_, 0))
+ .Finalize(graph_, &switch_node));
+ switch_nodes_.emplace_back(switch_node);
+ }
+ return Status::OK();
+}
+
+Status LowerWhileHelper::CreateBodyFuncCallNode() {
+ for (Node* switch_node : switch_nodes_) {
+ body_call_builder_.Input(NodeOut(switch_node, 1));
+ }
+ TF_RETURN_IF_ERROR(body_call_builder_.Finalize(graph_, &body_call_node_));
+ // Add a control edge to make sure the Const nodes in the body function
+ // are in the same frame as the rest of the function; otherwise
+ // `BuildControlFlowInfo` throws an error.
+ // TODO(srbs): The choice of input at index 0 seems arbitrary (is it?);
+ // however, this is how tf.while_loop does it. Can this affect performance
+ // if the 0th node is not the first one to be ready? Can we speed that case
+ // up using some sort of multi-input Merge?
+ Node* body_control_node_;
+ string op_type = "Identity";
+ if (IsRefType(switch_nodes_[0]->output_type(1))) {
+ op_type = "RefIdentity";
+ }
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName("loop_body_control"), op_type, graph_->op_registry())
+ .Input(NodeOut(switch_nodes_[0], 1))
+ .Finalize(graph_, &body_control_node_));
+ graph_->AddControlEdge(body_control_node_, body_call_node_);
+ return Status::OK();
+}
+
+Status LowerWhileHelper::CreateExitNodes() {
+ std::vector<NodeOut> outputs;
+ outputs.reserve(num_loop_inputs_);
+ for (Node* switch_node : switch_nodes_) {
+ Node* exit_node;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName("exit"), "Exit", graph_->op_registry())
+ .Input(NodeOut(switch_node, 0))
+ .Finalize(graph_, &exit_node));
+ exit_nodes_.emplace_back(exit_node);
+ outputs.emplace_back(NodeOut(exit_node, 0));
+ }
+
+ // Add an IdentityN node that has the same outputs and same name as the
+ // original functional While op. This is used for:
+ // 1. Rewiring the control edges with the original while op as src.
+ // 2. Fetching the output of the While node by name in calls to sess.run.
+ NodeBuilder ib(name_, "IdentityN");
+ ib.Input(outputs);
+ TF_RETURN_IF_ERROR(ib.Finalize(graph_, &lowered_while_output_));
+ return Status::OK();
+}
+
+Status LowerWhileHelper::CreateNextIterationNodes() {
+ for (int i = 0; i < num_loop_inputs_; i++) {
+ Node* next_iteration;
+ TF_RETURN_IF_ERROR(NodeBuilder(NewName("next_iteration"), "NextIteration",
+ graph_->op_registry())
+ .Input(NodeOut(body_call_node_, i))
+ .Finalize(graph_, &next_iteration));
+ next_iterations_nodes_.emplace_back(next_iteration);
+ }
+ return Status::OK();
+}
+
+Status LowerWhileHelper::UpdateMergeNodes() {
+ for (int i = 0; i < num_loop_inputs_; i++) {
+ TF_RETURN_IF_ERROR(
+ graph_->UpdateEdge(next_iterations_nodes_[i], 0, merge_nodes_[i], 1));
+ }
+ return Status::OK();
+}
+
+Status LowerWhileHelper::UpdateConsumers() {
+ for (const Edge* e : while_op_->out_edges()) {
+ if (e->IsControlEdge()) {
+ graph_->AddControlEdge(lowered_while_output_, e->dst());
+ } else {
+ // Feed the outputs directly from the exit nodes so that downstream ops
+ // can start before all the outputs have been computed.
+ graph_->AddEdge(exit_nodes_[e->src_output()], 0, e->dst(),
+ e->dst_input());
+ }
+ }
+ return Status::OK();
+}
+
+string LowerWhileHelper::NewName(const string& infix) {
+ return graph_->NewName(strings::StrCat(name_, "/", infix));
+}
+
+Status InlineCallInGraph(Node* n, Graph* g) {
+ const auto& lib = g->flib_def();
+ const FunctionDef* fdef = lib.Find(n->type_string());
+ CHECK(fdef != nullptr);
+ FunctionBody* fbody;
+ TF_RETURN_IF_ERROR(
+ FunctionDefToBodyHelper(*fdef, n->attrs(), &lib,
+ [&lib](const string& op, const OpDef** sig) {
+ return lib.LookUpOpDef(op, sig);
+ },
+ &fbody));
+ // TODO(jpienaar): Improve this interface to make the need to delete it
+ // explicit.
+ InlineFunctionBody(g->flib_def(), g, n, fbody, false);
+ delete fbody;
+ return Status::OK();
+}
+
+Status LowerWhileHelper::InlineCallNodes() {
+ TF_RETURN_IF_ERROR(InlineCallInGraph(cond_call_node_, graph_));
+ TF_RETURN_IF_ERROR(InlineCallInGraph(body_call_node_, graph_));
+ return Status::OK();
+}
+
+} // namespace
+
+Status LowerWhileOpPass::Run(const GraphOptimizationPassOptions& options) {
+ if (options.partition_graphs != nullptr) {
+ return errors::Internal(
+ "Lowering While op should happen before partitioning.");
+ }
+ if (options.graph == nullptr) {
+ return Status::OK();
+ }
+
+ Graph* g = options.graph->get();
+ if (g == nullptr) {
+ return errors::Internal(
+ "Lowering While op requires a graph to be available.");
+ }
+
+ // Match all the nodes that need to be rewritten.
+ gtl::InlinedVector<Node*, 2> matches;
+ for (Node* n : g->op_nodes()) {
+ if (n->type_string() == "While") {
+ // Only rewrite if the While op is marked as needing to be lowered.
+ bool match;
+ Status s = GetNodeAttr(n->attrs(),
+ LowerIfOpPass::kLowerUsingSwitchMergeAttr, &match);
+ if (s.ok() && match) matches.push_back(n);
+ }
+ }
+ for (Node* n : matches) {
+ TF_RETURN_IF_ERROR(RewriteNode(n, g));
+ }
+ return Status::OK();
+}
+
+Status LowerWhileOpPass::RewriteNode(Node* n, Graph* g) {
+ const AttrValue* cond_attr = n->attrs().Find("cond");
+ if (cond_attr == nullptr) {
+ return errors::InvalidArgument("While cond function missing");
+ }
+ const AttrValue* body_attr = n->attrs().Find("body");
+ if (body_attr == nullptr) {
+ return errors::InvalidArgument("While body function missing");
+ }
+
+ TF_RETURN_IF_ERROR(LowerWhileHelper::Run(n, cond_attr->func().name(),
+ body_attr->func().name(), g));
+ g->RemoveNode(n);
+
+ return Status::OK();
+}
+
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0,
+ LowerWhileOpPass);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/status_util.h b/tensorflow/core/common_runtime/lower_while_op.h
index ea92f61dce..eadafbeb91 100644
--- a/tensorflow/core/util/status_util.h
+++ b/tensorflow/core/common_runtime/lower_while_op.h
@@ -13,24 +13,25 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_
-#define TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_
-#include "tensorflow/core/graph/graph.h"
-#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
-// Creates a tag to be used in an exception error message. This can be parsed by
-// the Python layer and replaced with information about the node.
-//
-// For example, error_format_tag(node, "${file}") returns
-// "^^node:NODE_NAME:${line}^^" which would be rewritten by the Python layer as
-// e.g. "file/where/node/was/created.py".
-inline string error_format_tag(const Node& node, const string& format) {
- return strings::StrCat("^^node:", node.name(), ":", format, "^^");
-}
+// Rewrite While ops to use lower level control flow primitives instead.
+class LowerWhileOpPass : public GraphOptimizationPass {
+ public:
+ Status Run(const GraphOptimizationPassOptions& options) override;
+
+ private:
+ // Rewrite the given While node `n` in graph `g` to use the lower level
+ // primitives Enter, Exit, Switch, Merge and NextIteration.
+ Status RewriteNode(Node* n, Graph* g);
+};
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_
diff --git a/tensorflow/core/common_runtime/lower_while_op_test.cc b/tensorflow/core/common_runtime/lower_while_op_test.cc
new file mode 100644
index 0000000000..27cbada004
--- /dev/null
+++ b/tensorflow/core/common_runtime/lower_while_op_test.cc
@@ -0,0 +1,249 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/lower_while_op.h"
+#include "tensorflow/core/common_runtime/lower_if_op.h"
+
+#include "tensorflow/cc/client/client_session.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/common_runtime/graph_runner.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/graph_def_builder_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+Status Rewrite(std::unique_ptr<Graph>* graph) {
+ FunctionDefLibrary flib;
+ FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
+
+ GraphOptimizationPassOptions opt_options;
+ opt_options.graph = graph;
+ opt_options.flib_def = &flib_def;
+ LowerWhileOpPass pass;
+ return pass.Run(opt_options);
+}
+
+TEST(LowerWhileOpTest, Simple) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+ // Add test functions for cond and body.
+ FunctionDefLibrary f_lib_proto;
+ *f_lib_proto.add_function() = test::function::XTimesTwo();
+ *f_lib_proto.add_function() = test::function::LessThanOrEqualToN(8);
+ FunctionLibraryDefinition f_lib(OpRegistry::Global(), f_lib_proto);
+
+ Scope root = Scope::NewRootScope().ExitOnError();
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
+ auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
+ Node* while_node;
+ std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
+ AttrValue cond_func;
+ cond_func.mutable_func()->set_name("LessThanOrEqualToN");
+ AttrValue body_func;
+ body_func.mutable_func()->set_name("XTimesTwo");
+ TF_ASSERT_OK(NodeBuilder("while", "While", &f_lib)
+ .Input(inputs)
+ .Attr("T", {DT_INT32})
+ .Attr("cond", cond_func)
+ .Attr("body", body_func)
+ .Attr(LowerIfOpPass::kLowerUsingSwitchMergeAttr, true)
+ .Finalize(root.graph(), &while_node));
+ TF_ASSERT_OK(root.DoShapeInference(while_node));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ // The input graph has no lower level control flow primitives.
+ int node_called_while_count = 0;
+ for (const auto* op : graph->op_nodes()) {
+ ASSERT_FALSE(op->IsEnter());
+ ASSERT_FALSE(op->IsExit());
+ ASSERT_FALSE(op->IsSwitch());
+ ASSERT_FALSE(op->IsMerge());
+ ASSERT_FALSE(op->IsNextIteration());
+ ASSERT_FALSE(op->IsLoopCond());
+ if (op->name() == "while") {
+ node_called_while_count++;
+ }
+ }
+ ASSERT_EQ(node_called_while_count, 1);
+
+ TF_ASSERT_OK(Rewrite(&graph));
+
+ int enter_count = 0;
+ int exit_count = 0;
+ int switch_count = 0;
+ int merge_count = 0;
+ int next_iteration_count = 0;
+ node_called_while_count = 0;
+ for (const auto* op : graph->op_nodes()) {
+ if (op->IsEnter()) {
+ ++enter_count;
+ }
+ if (op->IsExit()) {
+ ++exit_count;
+ }
+ if (op->IsSwitch()) {
+ ++switch_count;
+ }
+ if (op->IsMerge()) {
+ ++merge_count;
+ }
+ if (op->IsNextIteration()) {
+ ++next_iteration_count;
+ }
+ if (op->name() == "while") {
+ node_called_while_count++;
+ }
+ ASSERT_NE(op->type_string(), "While");
+ }
+ // One node per loop input.
+ ASSERT_EQ(enter_count, 1);
+ ASSERT_EQ(exit_count, 1);
+ ASSERT_EQ(switch_count, 1);
+ ASSERT_EQ(merge_count, 1);
+ ASSERT_EQ(next_iteration_count, 1);
+ ASSERT_EQ(node_called_while_count, 1);
+
+ // Verify execution.
+ ClientSession session(root);
+ {
+ ClientSession::FeedType feeds;
+ feeds.emplace(Output(a.node()), Input::Initializer(1));
+ std::vector<Tensor> out_tensors;
+ TF_ASSERT_OK(session.Run(feeds, {Output(while_node)}, &out_tensors));
+ ASSERT_EQ(out_tensors.size(), 1);
+ EXPECT_EQ(out_tensors[0].scalar<int>()(), 16);
+ }
+ {
+ ClientSession::FeedType feeds;
+ feeds.emplace(Output(a.node()), Input::Initializer(3));
+ std::vector<Tensor> out_tensors;
+ TF_ASSERT_OK(session.Run(feeds, {Output(while_node)}, &out_tensors));
+ ASSERT_EQ(out_tensors.size(), 1);
+ EXPECT_EQ(out_tensors[0].scalar<int>()(), 12);
+ }
+}
+
+TEST(LowerWhileOpTest, MultipleInputs) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+ // Add test functions for cond and body.
+ FunctionDefLibrary f_lib_proto;
+ *(f_lib_proto.add_function()) = test::function::XPlusOneXTimesY();
+ *(f_lib_proto.add_function()) = test::function::XYXLessThanOrEqualToN(4);
+ FunctionLibraryDefinition f_lib(OpRegistry::Global(), f_lib_proto);
+
+ Scope root = Scope::NewRootScope().ExitOnError();
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
+ auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
+ auto b = ops::_Arg(root.WithOpName("B"), DT_INT32, 1);
+ Node* while_node;
+ std::vector<NodeBuilder::NodeOut> inputs(
+ {NodeBuilder::NodeOut(a.node()), NodeBuilder::NodeOut(b.node())});
+ AttrValue cond_func;
+ cond_func.mutable_func()->set_name("XYXLessThanOrEqualToN");
+ AttrValue body_func;
+ body_func.mutable_func()->set_name("XPlusOneXTimesY");
+ TF_ASSERT_OK(NodeBuilder("while", "While", &f_lib)
+ .Input(inputs)
+ .Attr("T", {DT_INT32, DT_INT32})
+ .Attr("cond", cond_func)
+ .Attr("body", body_func)
+ .Attr(LowerIfOpPass::kLowerUsingSwitchMergeAttr, true)
+ .Finalize(root.graph(), &while_node));
+ TF_ASSERT_OK(root.DoShapeInference(while_node));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ // The input graph has no lower level control flow primitives.
+ for (const auto* op : graph->op_nodes()) {
+ ASSERT_FALSE(op->IsEnter());
+ ASSERT_FALSE(op->IsExit());
+ ASSERT_FALSE(op->IsSwitch());
+ ASSERT_FALSE(op->IsMerge());
+ ASSERT_FALSE(op->IsNextIteration());
+ ASSERT_FALSE(op->IsLoopCond());
+ }
+
+ TF_ASSERT_OK(Rewrite(&graph));
+
+ int enter_count = 0;
+ int exit_count = 0;
+ int switch_count = 0;
+ int merge_count = 0;
+ int next_iteration_count = 0;
+ for (const auto* op : graph->op_nodes()) {
+ if (op->IsEnter()) {
+ ++enter_count;
+ }
+ if (op->IsExit()) {
+ ++exit_count;
+ }
+ if (op->IsSwitch()) {
+ ++switch_count;
+ }
+ if (op->IsMerge()) {
+ ++merge_count;
+ }
+ if (op->IsNextIteration()) {
+ ++next_iteration_count;
+ }
+ ASSERT_NE(op->type_string(), "While");
+ }
+ // Two nodes per loop input.
+ ASSERT_EQ(enter_count, 2);
+ ASSERT_EQ(exit_count, 2);
+ ASSERT_EQ(switch_count, 2);
+ ASSERT_EQ(merge_count, 2);
+ ASSERT_EQ(next_iteration_count, 2);
+
+ // Verify execution.
+ ClientSession session(root);
+ {
+ ClientSession::FeedType feeds;
+ feeds.emplace(Output(a.node()), Input::Initializer(1));
+ feeds.emplace(Output(b.node()), Input::Initializer(1));
+ std::vector<Tensor> out_tensors;
+ TF_ASSERT_OK(session.Run(
+ feeds, {Output(while_node, 0), Output(while_node, 1)}, &out_tensors));
+ ASSERT_EQ(out_tensors.size(), 2);
+ EXPECT_EQ(out_tensors[0].scalar<int>()(), 5);
+ EXPECT_EQ(out_tensors[1].scalar<int>()(), 24);
+ }
+ {
+ ClientSession::FeedType feeds;
+ feeds.emplace(Output(a.node()), Input::Initializer(3));
+ feeds.emplace(Output(b.node()), Input::Initializer(5));
+ std::vector<Tensor> out_tensors;
+ TF_ASSERT_OK(session.Run(
+ feeds, {Output(while_node, 0), Output(while_node, 1)}, &out_tensors));
+ ASSERT_EQ(out_tensors.size(), 2);
+ EXPECT_EQ(out_tensors[0].scalar<int>()(), 5);
+ EXPECT_EQ(out_tensors[1].scalar<int>()(), 60);
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
index 99bd43e090..df9c3a686c 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
@@ -24,9 +24,11 @@ limitations under the License.
#include <cstdlib>
#include "tensorflow/core/common_runtime/bfc_allocator.h"
#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mem.h"
+#include "tensorflow/core/platform/mutex.h"
#ifndef INTEL_MKL_DNN_ONLY
#include "i_malloc.h"
@@ -48,6 +50,125 @@ class MklSubAllocator : public SubAllocator {
void Free(void* ptr, size_t num_bytes) override { port::AlignedFree(ptr); }
};
+// CPU allocator that handles small-size allocations by calling the
+// suballocator directly. It is essentially a wrapper around a suballocator
+// (which calls malloc and free directly), with added bookkeeping.
+class MklSmallSizeAllocator : public VisitableAllocator {
+ public:
+ MklSmallSizeAllocator(SubAllocator* sub_allocator, size_t total_memory,
+ const string& name)
+ : sub_allocator_(sub_allocator), name_(name) {
+ stats_.bytes_limit = total_memory;
+ }
+ ~MklSmallSizeAllocator() override {}
+
+ TF_DISALLOW_COPY_AND_ASSIGN(MklSmallSizeAllocator);
+
+ inline string Name() override { return name_; }
+
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override {
+ void* ptr = sub_allocator_->Alloc(alignment, num_bytes);
+ if (ptr != nullptr) {
+ std::pair<void*, size_t> map_val(ptr, num_bytes);
+ mutex_lock l(mutex_);
+ // Check that insertion in the hash map was successful.
+ CHECK(map_.insert(map_val).second);
+ // Increment statistics for small-size allocations.
+ IncrementStats(num_bytes);
+ // Call alloc visitors.
+ for (const auto& visitor : alloc_visitors_) {
+ visitor(ptr, num_bytes);
+ }
+ }
+ return ptr;
+ }
+
+ void DeallocateRaw(void* ptr) override {
+ if (ptr == nullptr) {
+ LOG(ERROR) << "tried to deallocate nullptr";
+ return;
+ }
+
+ mutex_lock l(mutex_);
+ auto map_iter = map_.find(ptr);
+ if (map_iter != map_.end()) {
+ // Call free visitors.
+ size_t dealloc_bytes = map_iter->second;
+ for (const auto& visitor : free_visitors_) {
+ visitor(ptr, dealloc_bytes);
+ }
+ sub_allocator_->Free(ptr, dealloc_bytes);
+ DecrementStats(dealloc_bytes);
+ map_.erase(map_iter);
+ } else {
+ LOG(ERROR) << "tried to deallocate invalid pointer";
+ return;
+ }
+ }
+
+ inline bool IsSmallSizeAllocation(const void* ptr) const {
+ mutex_lock l(mutex_);
+ return map_.find(ptr) != map_.end();
+ }
+
+ void GetStats(AllocatorStats* stats) override {
+ mutex_lock l(mutex_);
+ *stats = stats_;
+ }
+
+ void ClearStats() override {
+ mutex_lock l(mutex_);
+ stats_.Clear();
+ }
+
+ void AddAllocVisitor(Visitor visitor) override {
+ mutex_lock l(mutex_);
+ alloc_visitors_.push_back(visitor);
+ }
+
+ void AddFreeVisitor(Visitor visitor) override {
+ mutex_lock l(mutex_);
+ free_visitors_.push_back(visitor);
+ }
+
+ private:
+ // Increment statistics for the allocator handling small allocations.
+ inline void IncrementStats(size_t alloc_size)
+ EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
+ ++stats_.num_allocs;
+ stats_.bytes_in_use += alloc_size;
+ stats_.max_bytes_in_use =
+ std::max(stats_.max_bytes_in_use, stats_.bytes_in_use);
+ stats_.max_alloc_size =
+ std::max(alloc_size, static_cast<size_t>(stats_.max_alloc_size));
+ }
+
+ // Decrement statistics for the allocator handling small allocations.
+ inline void DecrementStats(size_t dealloc_size)
+ EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
+ stats_.bytes_in_use -= dealloc_size;
+ }
+
+ SubAllocator* sub_allocator_; // Not owned by this class.
+
+ // Mutex for protecting updates to map of allocations.
+ mutable mutex mutex_;
+
+ // Allocator name
+ string name_;
+
+ // Hash map to keep track of "small" allocations; we do not use the BFC
+ // allocator for small allocations.
+ std::unordered_map<const void*, size_t> map_ GUARDED_BY(mutex_);
+
+ // Allocator stats for small allocs
+ AllocatorStats stats_ GUARDED_BY(mutex_);
+
+ // Visitors
+ std::vector<Visitor> alloc_visitors_ GUARDED_BY(mutex_);
+ std::vector<Visitor> free_visitors_ GUARDED_BY(mutex_);
+};
+
/// CPU allocator for MKL that wraps BFC allocator and intercepts
/// and redirects memory allocation calls from MKL.
class MklCPUAllocator : public VisitableAllocator {
@@ -62,7 +183,10 @@ class MklCPUAllocator : public VisitableAllocator {
MklCPUAllocator() { TF_CHECK_OK(Initialize()); }
- ~MklCPUAllocator() override { delete allocator_; }
+ ~MklCPUAllocator() override {
+ delete small_size_allocator_;
+ delete large_size_allocator_;
+ }
Status Initialize() {
VLOG(2) << "MklCPUAllocator: In MklCPUAllocator";
@@ -96,8 +220,15 @@ class MklCPUAllocator : public VisitableAllocator {
}
VLOG(1) << "MklCPUAllocator: Setting max_mem_bytes: " << max_mem_bytes;
- allocator_ = new BFCAllocator(new MklSubAllocator, max_mem_bytes,
- kAllowGrowth, kName);
+
+ sub_allocator_ = new MklSubAllocator();
+
+ // SubAllocator is owned by BFCAllocator, so we do not need to deallocate
+ // it in MklSmallSizeAllocator.
+ small_size_allocator_ =
+ new MklSmallSizeAllocator(sub_allocator_, max_mem_bytes, kName);
+ large_size_allocator_ =
+ new BFCAllocator(sub_allocator_, max_mem_bytes, kAllowGrowth, kName);
#ifndef INTEL_MKL_DNN_ONLY
// For redirecting all allocations from MKL to this allocator
// From: http://software.intel.com/en-us/node/528565
@@ -112,23 +243,55 @@ class MklCPUAllocator : public VisitableAllocator {
inline string Name() override { return kName; }
inline void* AllocateRaw(size_t alignment, size_t num_bytes) override {
- return allocator_->AllocateRaw(alignment, num_bytes);
+ // If the allocation size is below the threshold, call the small-size
+ // allocator; otherwise call the large-size (BFC) allocator. We found that
+ // the BFC allocator does not deliver good performance for small allocations
+ // when inter_op_parallelism_threads is high.
+ return (num_bytes < kSmallAllocationsThreshold)
+ ? small_size_allocator_->AllocateRaw(alignment, num_bytes)
+ : large_size_allocator_->AllocateRaw(alignment, num_bytes);
}
inline void DeallocateRaw(void* ptr) override {
- allocator_->DeallocateRaw(ptr);
+ // Check whether ptr came from a "small" allocation. If so, free it through
+ // the small-size allocator; otherwise let the BFC allocator handle it.
+ if (small_size_allocator_->IsSmallSizeAllocation(ptr)) {
+ small_size_allocator_->DeallocateRaw(ptr);
+ } else {
+ large_size_allocator_->DeallocateRaw(ptr);
+ }
}
- void GetStats(AllocatorStats* stats) override { allocator_->GetStats(stats); }
+ void GetStats(AllocatorStats* stats) override {
+ AllocatorStats l_stats, s_stats;
+ small_size_allocator_->GetStats(&s_stats);
+ large_size_allocator_->GetStats(&l_stats);
+
+ // Combine statistics from small-size and large-size allocator.
+ stats->num_allocs = l_stats.num_allocs + s_stats.num_allocs;
+ stats->bytes_in_use = l_stats.bytes_in_use + s_stats.bytes_in_use;
+ stats->max_bytes_in_use =
+ l_stats.max_bytes_in_use + s_stats.max_bytes_in_use;
+
+ // Since small-size allocations are capped at kSmallAllocationsThreshold,
+ // the max_alloc_size reported by large_size_allocator_ is the maximum
+ // allocation size seen by MklCPUAllocator.
+ stats->max_alloc_size = l_stats.max_alloc_size;
+ }
- void ClearStats() override { allocator_->ClearStats(); }
+ void ClearStats() override {
+ small_size_allocator_->ClearStats();
+ large_size_allocator_->ClearStats();
+ }
void AddAllocVisitor(Visitor visitor) override {
- allocator_->AddAllocVisitor(visitor);
+ small_size_allocator_->AddAllocVisitor(visitor);
+ large_size_allocator_->AddAllocVisitor(visitor);
}
void AddFreeVisitor(Visitor visitor) override {
- allocator_->AddFreeVisitor(visitor);
+ small_size_allocator_->AddFreeVisitor(visitor);
+ large_size_allocator_->AddFreeVisitor(visitor);
}
private:
@@ -148,24 +311,36 @@ class MklCPUAllocator : public VisitableAllocator {
Status s = Status(error::Code::UNIMPLEMENTED,
"Unimplemented case for hooking MKL function.");
TF_CHECK_OK(s); // way to assert with an error message
+ return nullptr; // return a value and make static code analyzers happy
}
static inline void* ReallocHook(void* ptr, size_t size) {
Status s = Status(error::Code::UNIMPLEMENTED,
"Unimplemented case for hooking MKL function.");
TF_CHECK_OK(s); // way to assert with an error message
+ return nullptr; // return a value and make static code analyzers happy
}
- /// Do we allow growth in BFC Allocator
+ // Do we allow growth in BFC Allocator
static const bool kAllowGrowth = true;
- /// Name
+ // Name
static constexpr const char* kName = "mklcpu";
- /// The alignment that we need for the allocations
+ // The alignment that we need for the allocations
static constexpr const size_t kAlignment = 64;
- VisitableAllocator* allocator_; // owned by this class
+ VisitableAllocator* large_size_allocator_; // owned by this class
+ MklSmallSizeAllocator* small_size_allocator_; // owned by this class.
+
+ SubAllocator* sub_allocator_; // not owned by this class
+
+ // Size in bytes that defines the upper-bound for "small" allocations.
+ // Any allocation below this threshold is a "small" allocation.
+ static constexpr const size_t kSmallAllocationsThreshold = 4096;
+
+ // Prevent copying and assignment
+ TF_DISALLOW_COPY_AND_ASSIGN(MklCPUAllocator);
};
} // namespace tensorflow
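
The dispatch rule in MklCPUAllocator is simple: the requested size picks the allocator, and the small-size allocator's hash map of live allocations answers the ownership question at free time, since DeallocateRaw receives only a pointer. A condensed sketch of the same two-tier scheme (illustrative types, not the TensorFlow API):

    #include <cstddef>
    #include <cstdlib>
    #include <mutex>
    #include <unordered_map>

    // Two-tier allocator: sizes below the threshold go to a direct
    // malloc-based path, everything else to a pooled allocator. The map of
    // live small allocations lets Free() decide ownership from the pointer
    // alone, mirroring IsSmallSizeAllocation() above.
    class TwoTierAllocator {
     public:
      static constexpr size_t kSmallThreshold = 4096;

      void* Alloc(size_t num_bytes) {
        if (num_bytes < kSmallThreshold) {
          void* ptr = std::malloc(num_bytes);
          std::lock_guard<std::mutex> l(mu_);
          small_allocs_[ptr] = num_bytes;
          return ptr;
        }
        return PooledAlloc(num_bytes);  // Stand-in for the BFC path.
      }

      void Free(void* ptr) {
        {
          std::lock_guard<std::mutex> l(mu_);
          auto it = small_allocs_.find(ptr);
          if (it != small_allocs_.end()) {
            small_allocs_.erase(it);
            std::free(ptr);
            return;
          }
        }
        PooledFree(ptr);  // Not in the map, so it belongs to the pool.
      }

     private:
      // Trivial stand-ins for the pooled (BFC-like) allocator.
      void* PooledAlloc(size_t num_bytes) { return std::malloc(num_bytes); }
      void PooledFree(void* ptr) { std::free(ptr); }

      std::mutex mu_;
      std::unordered_map<void*, size_t> small_allocs_;
    };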
diff --git a/tensorflow/core/common_runtime/optimization_registry.h b/tensorflow/core/common_runtime/optimization_registry.h
index f5d265aa24..6fcd2afd27 100644
--- a/tensorflow/core/common_runtime/optimization_registry.h
+++ b/tensorflow/core/common_runtime/optimization_registry.h
@@ -132,11 +132,12 @@ class OptimizationPassRegistration {
#define REGISTER_OPTIMIZATION_UNIQ_HELPER(ctr, grouping, phase, optimization) \
REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization)
-#define REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization) \
- static optimization_registration::OptimizationPassRegistration \
- register_optimization_##ctr( \
- grouping, phase, \
- std::unique_ptr<GraphOptimizationPass>(new optimization()), \
+#define REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization) \
+ static ::tensorflow::optimization_registration::OptimizationPassRegistration \
+ register_optimization_##ctr( \
+ grouping, phase, \
+ ::std::unique_ptr<::tensorflow::GraphOptimizationPass>( \
+ new optimization()), \
#optimization)
} // namespace tensorflow
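
Qualifying OptimizationPassRegistration and unique_ptr with ::tensorflow:: and ::std:: makes the macro expansion independent of the namespace it is invoked from, so passes can be registered outside namespace tensorflow. A hedged example of the kind of registration this enables (MyPass is hypothetical, for illustration):

    // Registration from the global namespace now works because the macro's
    // expansion is fully qualified.
    class MyPass : public ::tensorflow::GraphOptimizationPass {
     public:
      ::tensorflow::Status Run(
          const ::tensorflow::GraphOptimizationPassOptions& options) override {
        return ::tensorflow::Status::OK();
      }
    };

    REGISTER_OPTIMIZATION(::tensorflow::OptimizationPassRegistry::PRE_PLACEMENT,
                          0, MyPass);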
diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc
index d581f45a90..3b59995433 100644
--- a/tensorflow/core/common_runtime/placer.cc
+++ b/tensorflow/core/common_runtime/placer.cc
@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/util/status_util.h"
namespace tensorflow {
@@ -255,9 +254,11 @@ class ColocationGraph {
old_root_member.device_name,
allow_soft_placement_);
if (!s.ok()) {
- return errors::InvalidArgument("Cannot colocate nodes '", x.name(),
- "' and '", y.name(), ": ",
- s.error_message());
+ return errors::InvalidArgument(
+ "Cannot colocate nodes ",
+ errors::FormatColocationNodeForError(x.name()), " and ",
+ errors::FormatColocationNodeForError(y.name()), ": ",
+ s.error_message());
}
// Ensure that the common root has at least one supported device
@@ -268,8 +269,10 @@ class ColocationGraph {
old_root_member.supported_device_types);
if (new_root_member.supported_device_types.empty()) {
return errors::InvalidArgument(
- "Cannot colocate nodes '", x.name(), "' and '", y.name(),
- "' because no device type supports both of those nodes and the "
+ "Cannot colocate nodes ",
+ errors::FormatColocationNodeForError(x.name()), " and ",
+ errors::FormatColocationNodeForError(y.name()),
+ " because no device type supports both of those nodes and the "
"other nodes colocated with them.",
DebugInfo(x_root), DebugInfo(y_root));
}
@@ -377,8 +380,9 @@ class ColocationGraph {
// merged set device is different, so print both.
return errors::InvalidArgument(
"Could not satisfy explicit device specification '",
- node->requested_device(),
- "' because the node was colocated with a group of nodes that "
+ node->requested_device(), "' because the node ",
+ errors::FormatColocationNodeForError(node->name()),
+ " was colocated with a group of nodes that ",
"required incompatible device '",
DeviceNameUtils::ParsedNameToString(
members_[node_root].device_name),
@@ -810,10 +814,10 @@ Status Placer::Run() {
std::vector<Device*>* devices;
Status status = colocation_graph.GetDevicesForNode(node, &devices);
if (!status.ok()) {
- return AttachDef(errors::InvalidArgument(
- "Cannot assign a device for operation ",
- RichNodeName(node), ": ", status.error_message()),
- *node);
+ return AttachDef(
+ errors::InvalidArgument("Cannot assign a device for operation ",
+ node->name(), ": ", status.error_message()),
+ *node);
}
// Returns the first device in sorted devices list so we will always
@@ -857,10 +861,10 @@ Status Placer::Run() {
std::vector<Device*>* devices;
Status status = colocation_graph.GetDevicesForNode(node, &devices);
if (!status.ok()) {
- return AttachDef(errors::InvalidArgument(
- "Cannot assign a device for operation ",
- RichNodeName(node), ": ", status.error_message()),
- *node);
+ return AttachDef(
+ errors::InvalidArgument("Cannot assign a device for operation ",
+ node->name(), ": ", status.error_message()),
+ *node);
}
int assigned_device = -1;
@@ -926,22 +930,4 @@ void Placer::LogDeviceAssignment(const Node* node) const {
}
}
-bool Placer::ClientHandlesErrorFormatting() const {
- return options_ != nullptr &&
- options_->config.experimental().client_handles_error_formatting();
-}
-
-// Returns the node name in single quotes. If the client handles formatted
-// errors, appends a formatting tag which the client will reformat into, for
-// example, " (defined at filename:123)".
-string Placer::RichNodeName(const Node* node) const {
- string quoted_name = strings::StrCat("'", node->name(), "'");
- if (ClientHandlesErrorFormatting()) {
- string file_and_line = error_format_tag(*node, "${defined_at}");
- return strings::StrCat(quoted_name, file_and_line);
- } else {
- return quoted_name;
- }
-}
-
} // namespace tensorflow
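
The removed RichNodeName/ClientHandlesErrorFormatting pair emitted "^^...^^" tags only when the session opted in; the replacement always embeds "{{...}}" markers that clients may rewrite into richer messages. Judging by the updated test expectations in placer_test.cc below, the formatted fragments look as follows; this is a hypothetical re-implementation for illustration only, the real helpers live in tensorflow/core/lib/core/errors.h:

    #include <iostream>
    #include <string>

    std::string FormatColocationNodeForError(const std::string& name) {
      return "{{colocation_node " + name + "}}";
    }

    int main() {
      // Prints:
      // Cannot colocate nodes {{colocation_node var}} and {{colocation_node assign}}
      std::cout << "Cannot colocate nodes " << FormatColocationNodeForError("var")
                << " and " << FormatColocationNodeForError("assign") << "\n";
    }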
diff --git a/tensorflow/core/common_runtime/placer.h b/tensorflow/core/common_runtime/placer.h
index fce87269c5..f97ffe7372 100644
--- a/tensorflow/core/common_runtime/placer.h
+++ b/tensorflow/core/common_runtime/placer.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_PLACER_H_
-#define TENSORFLOW_COMMON_RUNTIME_PLACER_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_H_
#include <string>
#include <unordered_map>
@@ -87,8 +87,6 @@ class Placer {
// placement if the SessionOptions entry in 'options_' requests it.
void AssignAndLog(int assigned_device, Node* node) const;
void LogDeviceAssignment(const Node* node) const;
- bool ClientHandlesErrorFormatting() const;
- string RichNodeName(const Node* node) const;
Graph* const graph_; // Not owned.
const DeviceSet* const devices_; // Not owned.
@@ -100,4 +98,4 @@ class Placer {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_PLACER_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_H_
diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc
index 87f2f2ceb9..9b8a95e3b6 100644
--- a/tensorflow/core/common_runtime/placer_test.cc
+++ b/tensorflow/core/common_runtime/placer_test.cc
@@ -800,11 +800,11 @@ TEST_F(PlacerTest, TestInvalidMultipleColocationGroups) {
}
Status s = Place(&g);
- EXPECT_TRUE(
- str_util::StrContains(s.error_message(),
- "Cannot colocate nodes 'foo' and 'in' because no "
- "device type supports both of those nodes and the "
- "other nodes colocated with them"));
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(),
+ "Cannot colocate nodes {{colocation_node foo}} and "
+ "{{colocation_node in}} because no device type supports both of those "
+ "nodes and the other nodes colocated with them"));
}
TEST_F(PlacerTest, TestColocationGroupWithReferenceConnections) {
@@ -867,9 +867,9 @@ TEST_F(PlacerTest, TestColocationGroupWithUnsatisfiableReferenceConnections) {
Status s = Place(&g);
EXPECT_TRUE(str_util::StrContains(
s.error_message(),
- "Cannot colocate nodes 'var3' and 'assign3' because no "
- "device type supports both of those nodes and the other "
- "nodes colocated with them."));
+ "Cannot colocate nodes {{colocation_node var3}} and {{colocation_node "
+ "assign3}} because no device type supports both of those nodes and the "
+ "other nodes colocated with them."));
}
TEST_F(PlacerTest, TestColocationAndReferenceConnections) {
@@ -1154,36 +1154,12 @@ TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementFormatTag) {
}
SessionOptions options;
- options.config.mutable_experimental()->set_client_handles_error_formatting(
- true);
Status s = Place(&g, &options);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
LOG(WARNING) << s.error_message();
EXPECT_TRUE(str_util::StrContains(s.error_message(),
- "Cannot assign a device for operation 'in'"
- "^^node:in:${defined_at}^^"));
-}
-
-// Test that the "Cannot assign a device" error message does not contain a
-// format tag when not it shouldn't
-TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementNoFormatTag) {
- Graph g(OpRegistry::Global());
- { // Scope for temporary variables used to construct g.
- GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
- ops::SourceOp("TestDevice",
- b.opts().WithName("in").WithDevice("/device:fakegpu:11"));
- TF_EXPECT_OK(BuildGraph(b, &g));
- }
-
- SessionOptions options;
- options.config.mutable_experimental()->set_client_handles_error_formatting(
- false);
- Status s = Place(&g, &options);
- EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
- EXPECT_TRUE(str_util::StrContains(
- s.error_message(), "Cannot assign a device for operation 'in'"));
- EXPECT_FALSE(str_util::StrContains(
- s.error_message(), "'in' (defined at ^^node:in:${file}:${line}^^)"));
+ "Cannot assign a device for operation in"));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(), "{{node in}}"));
}
// Test that placement fails when a node requests an explicit device that is not
@@ -1289,8 +1265,9 @@ TEST_F(PlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) {
Status s = Place(&g);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
- EXPECT_TRUE(str_util::StrContains(
- s.error_message(), "Cannot colocate nodes 'var' and 'assign'"));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(),
+ "Cannot colocate nodes {{colocation_node "
+ "var}} and {{colocation_node assign}}"));
}
// Test that a generator node follows its consumers (where there are several
diff --git a/tensorflow/core/common_runtime/pool_allocator.cc b/tensorflow/core/common_runtime/pool_allocator.cc
index 10a24ed14c..fdad8de8d6 100644
--- a/tensorflow/core/common_runtime/pool_allocator.cc
+++ b/tensorflow/core/common_runtime/pool_allocator.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 6dac4c3acf..c43a9d7dc2 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -113,7 +113,7 @@ void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
const string& key_prefix, int64 src_incarnation, int64 num_tensors,
DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs, Rendezvous* rendezvous,
- std::vector<Tensor>* received_tensors, const StatusCallback& done) {
+ std::vector<Tensor>* received_tensors, StatusCallback done) {
std::vector<string> keys;
for (int64 i = 0; i < num_tensors; ++i) {
string name = strings::StrCat(key_prefix, i);
@@ -121,9 +121,8 @@ void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
target_device, name, FrameAndIter(0, 0));
keys.push_back(key);
}
- RecvOutputsFromRendezvousAsync(
- rendezvous, device_context, alloc_attrs, keys, received_tensors,
- [done](const Status& status) { done(status); });
+ RecvOutputsFromRendezvousAsync(rendezvous, device_context, alloc_attrs, keys,
+ received_tensors, std::move(done));
}
Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation(
@@ -192,7 +191,7 @@ FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle(
const string& function_key) const {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return gtl::FindWithDefault(table_, function_key, kInvalidHandle);
}
@@ -204,11 +203,12 @@ bool ProcessFunctionLibraryRuntime::IsInstantiatedOnDevice(
FunctionLibraryRuntime::LocalHandle
ProcessFunctionLibraryRuntime::GetHandleOnDevice(
const string& device_name, FunctionLibraryRuntime::Handle handle) {
- mutex_lock l(mu_);
- if (function_data_.count(handle) == 0) {
+ tf_shared_lock l(mu_);
+ auto iter = function_data_.find(handle);
+ if (iter == function_data_.end()) {
return kInvalidLocalHandle;
}
- FunctionData* function_data = function_data_[handle].get();
+ FunctionData* function_data = iter->second.get();
if (function_data->target_device() != device_name) {
return kInvalidLocalHandle;
}
@@ -217,9 +217,10 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice(
string ProcessFunctionLibraryRuntime::GetDeviceName(
FunctionLibraryRuntime::Handle handle) {
- mutex_lock l(mu_);
- CHECK_EQ(1, function_data_.count(handle));
- FunctionData* function_data = function_data_[handle].get();
+ tf_shared_lock l(mu_);
+ auto iter = function_data_.find(handle);
+ CHECK(iter != function_data_.end());
+ FunctionData* function_data = iter->second.get();
return function_data->target_device();
}
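
Two things happen in these lookup paths: mutex_lock becomes tf_shared_lock, so concurrent readers of function_data_ no longer serialize, and operator[] becomes find, since operator[] can insert into the map and therefore must not run under a shared lock. A sketch of the same pattern with standard C++ primitives (tensorflow's mutex types wrap equivalents):

    #include <mutex>
    #include <shared_mutex>
    #include <string>
    #include <unordered_map>

    class HandleTable {
     public:
      // Read-only lookup: shared lock, many threads may hold it at once.
      int GetHandle(const std::string& key) const {
        std::shared_lock<std::shared_mutex> l(mu_);
        auto it = table_.find(key);  // find(), never operator[], under a shared lock
        return it == table_.end() ? -1 : it->second;  // -1 ~ kInvalidHandle
      }

      // Mutation: exclusive lock, like the remaining mutex_lock call sites.
      void AddHandle(const std::string& key, int handle) {
        std::unique_lock<std::shared_mutex> l(mu_);
        table_[key] = handle;
      }

     private:
      mutable std::shared_mutex mu_;
      std::unordered_map<std::string, int> table_;
    };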
@@ -302,13 +303,15 @@ void ProcessFunctionLibraryRuntime::Run(
string target_device;
FunctionLibraryRuntime::LocalHandle local_handle;
{
- mutex_lock l(mu_);
- if (function_data_.count(handle) == 0) {
+ tf_shared_lock l(mu_);
+ auto iter = function_data_.find(handle);
+ if (iter == function_data_.end()) {
done(errors::NotFound("Handle: ", handle, " not found."));
return;
}
- target_device = function_data_[handle]->target_device();
- local_handle = function_data_[handle]->local_handle();
+ FunctionData* function_data = iter->second.get();
+ target_device = function_data->target_device();
+ local_handle = function_data->local_handle();
}
flr = GetFLR(target_device);
if (flr != nullptr) {
@@ -339,26 +342,29 @@ void ProcessFunctionLibraryRuntime::Run(
opts.rets_alloc_attrs;
std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
flr->Run(opts, handle, args, remote_rets,
- [source_device, target_device, target_incarnation, rendezvous,
- device_context, rets_alloc_attrs, remote_rets, rets,
- done](const Status& status) {
- if (!status.ok()) {
- delete remote_rets;
- done(status);
- return;
- }
- int64 num_returns = remote_rets->size();
- delete remote_rets;
- // Now receive the return values from the target.
- ReceiveTensorsAsync(target_device, source_device, "ret_",
- target_incarnation, num_returns,
- device_context, rets_alloc_attrs, rendezvous,
- rets, done);
- });
+ std::bind(
+ [source_device, target_device, target_incarnation, rendezvous,
+ device_context, rets_alloc_attrs, remote_rets,
+ rets](const Status& status,
+ FunctionLibraryRuntime::DoneCallback& done) {
+ if (!status.ok()) {
+ delete remote_rets;
+ done(status);
+ return;
+ }
+ int64 num_returns = remote_rets->size();
+ delete remote_rets;
+ // Now receive the return values from the target.
+ ReceiveTensorsAsync(target_device, source_device, "ret_",
+ target_incarnation, num_returns,
+ device_context, rets_alloc_attrs,
+ rendezvous, rets, std::move(done));
+ },
+ std::placeholders::_1, std::move(done)));
return;
}
if (parent_ != nullptr) {
- parent_->Run(opts, local_handle, args, rets, done);
+ parent_->Run(opts, local_handle, args, rets, std::move(done));
return;
}
done(errors::Internal("Could not find device"));
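
The std::bind wrapper exists because C++11 lambdas cannot move-capture, and copying `done` into the capture list would copy whatever state it owns; binding std::move(done) as a trailing argument moves it into the callable instead, and the lambda receives it as an lvalue reference parameter. A standalone sketch of the trick (toy types, not the tensorflow signatures):

    #include <functional>
    #include <iostream>
    #include <memory>
    #include <utility>

    using DoneCallback = std::function<void(int)>;

    // Stands in for flr->Run: takes a completion callback and invokes it.
    void RunAsync(std::function<void(int)> cb) { cb(42); }

    int main() {
      auto state = std::make_shared<int>(7);  // state owned by the outer callback
      DoneCallback done = [state](int status) {
        std::cout << "done: " << (status + *state) << "\n";
      };
      // Move `done` into the bound callable; at invocation time the lambda
      // sees it as its second parameter, no copy of the callback required.
      RunAsync(std::bind(
          [](int status, DoneCallback& done) { done(status); },
          std::placeholders::_1, std::move(done)));
    }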
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index 69381dd34d..53815715d8 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.h
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
@@ -59,8 +60,6 @@ class ProcessFunctionLibraryRuntime {
const std::vector<AllocatorAttributes>& alloc_attrs,
Rendezvous* rendezvous);
- typedef std::function<void(const Status&)> StatusCallback;
-
// Receives `received_tensors` from `target_device` (originally sent from
// `source_device`) using `rendezvous`. Uses `key_prefix` to construct the
// keys to be retrieved. `device_context` should be for the device receiving
@@ -73,7 +72,7 @@ class ProcessFunctionLibraryRuntime {
DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs,
Rendezvous* rendezvous, std::vector<Tensor>* received_tensors,
- const StatusCallback& done);
+ StatusCallback done);
static const char kDefaultFLRDevice[];
// Returns the FunctionLibraryRuntime for the corresponding device_name.
diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.h b/tensorflow/core/common_runtime/rendezvous_mgr.h
index cb5848ede3..b4d8ab4eb2 100644
--- a/tensorflow/core/common_runtime/rendezvous_mgr.h
+++ b/tensorflow/core/common_runtime/rendezvous_mgr.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_RENDEZVOUS_MGR_H_
-#define TENSORFLOW_COMMON_RUNTIME_RENDEZVOUS_MGR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_MGR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_MGR_H_
#include <string>
#include <unordered_map>
@@ -87,4 +87,4 @@ class IntraProcessRendezvous : public Rendezvous {
} // end namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_RENDEZVOUS_MGR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_MGR_H_
diff --git a/tensorflow/core/common_runtime/rendezvous_util.cc b/tensorflow/core/common_runtime/rendezvous_util.cc
index 92dc03812e..1e3fed0d6f 100644
--- a/tensorflow/core/common_runtime/rendezvous_util.cc
+++ b/tensorflow/core/common_runtime/rendezvous_util.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/rendezvous_util.h"
+#include "tensorflow/core/util/reffed_status_callback.h"
+
namespace tensorflow {
Status SendTensorsToRendezvous(
@@ -54,7 +56,7 @@ void RecvOutputsFromRendezvousAsync(
Rendezvous* rendezvous, DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs,
const std::vector<string>& keys, std::vector<Tensor>* received_tensors,
- const StatusCallback& done) {
+ StatusCallback done) {
if (keys.empty()) {
done(Status::OK());
return;
@@ -85,13 +87,7 @@ void RecvOutputsFromRendezvousAsync(
alloc_attr);
}
- typedef struct {
- mutex mu;
- int64 done_counter;
- Status shared_status = Status::OK();
- } CallState;
- CallState* call_state = new CallState;
- call_state->done_counter = keys.size();
+ auto status_cb = new ReffedStatusCallback(std::move(done));
for (auto& p : arguments) {
const string& key = std::get<0>(p);
Tensor* val = std::get<1>(p);
@@ -99,13 +95,13 @@ void RecvOutputsFromRendezvousAsync(
Rendezvous::Args rendez_args;
rendez_args.device_context = device_context;
rendez_args.alloc_attrs = std::get<3>(p);
-
+ status_cb->Ref();
rendezvous->RecvAsync(
parsed, rendez_args,
- [val, done, key, call_state](const Status& s,
- const Rendezvous::Args& send_args,
- const Rendezvous::Args& recv_args,
- const Tensor& v, const bool is_dead) {
+ [val, key, status_cb](const Status& s,
+ const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args,
+ const Tensor& v, const bool is_dead) {
Status status = s;
if (status.ok()) {
*val = v;
@@ -114,20 +110,11 @@ void RecvOutputsFromRendezvousAsync(
" was not valid.");
}
}
- call_state->mu.lock();
- call_state->shared_status.Update(status);
- call_state->done_counter--;
- // If we are the last async call to return, call the done callback.
- if (call_state->done_counter == 0) {
- const Status& final_status = call_state->shared_status;
- call_state->mu.unlock();
- done(final_status);
- delete call_state;
- return;
- }
- call_state->mu.unlock();
+ status_cb->UpdateStatus(status);
+ status_cb->Unref();
});
}
+ status_cb->Unref();
}
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
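
The replacement collapses the hand-rolled mutex/counter/shared-status struct into ReffedStatusCallback: each pending RecvAsync takes a ref, each completion merges its status and drops its ref, and whoever drops the last ref (including the initial one released after the loop) fires `done` exactly once with the merged status. A simplified standalone sketch of that idea, assuming the real class's semantics (it lives in tensorflow/core/util/reffed_status_callback.h and builds on core::RefCounted and Status):

    #include <atomic>
    #include <functional>
    #include <mutex>
    #include <string>
    #include <utility>

    class ReffedStatusCallbackSketch {
     public:
      explicit ReffedStatusCallbackSketch(std::function<void(std::string)> done)
          : refs_(1), done_(std::move(done)) {}  // constructor holds one ref

      void Ref() { refs_.fetch_add(1); }

      // First error wins, mirroring Status::Update.
      void UpdateStatus(const std::string& s) {
        std::lock_guard<std::mutex> l(mu_);
        if (status_.empty()) status_ = s;
      }

      // Dropping the last ref invokes the callback exactly once.
      void Unref() {
        if (refs_.fetch_sub(1) == 1) {
          done_(status_.empty() ? "OK" : status_);
          delete this;
        }
      }

     private:
      ~ReffedStatusCallbackSketch() = default;
      std::atomic<int> refs_;
      std::mutex mu_;
      std::string status_;
      std::function<void(std::string)> done_;
    };

In the loop above, each RecvAsync completion maps to UpdateStatus plus Unref, and the trailing Unref after the loop releases the constructor's initial ref.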
diff --git a/tensorflow/core/common_runtime/rendezvous_util.h b/tensorflow/core/common_runtime/rendezvous_util.h
index aad910f6d8..deb9a7c822 100644
--- a/tensorflow/core/common_runtime/rendezvous_util.h
+++ b/tensorflow/core/common_runtime/rendezvous_util.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <map>
#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
@@ -42,7 +43,7 @@ void RecvOutputsFromRendezvousAsync(
Rendezvous* rendezvous, DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs,
const std::vector<string>& keys, std::vector<Tensor>* received_tensors,
- const StatusCallback& done);
+ StatusCallback done);
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
const Rendezvous::Args& args);
diff --git a/tensorflow/core/common_runtime/ring_reducer.cc b/tensorflow/core/common_runtime/ring_reducer.cc
index e26761703b..a81f8650bf 100644
--- a/tensorflow/core/common_runtime/ring_reducer.cc
+++ b/tensorflow/core/common_runtime/ring_reducer.cc
@@ -14,13 +14,30 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/ring_reducer.h"
+#include <stdlib.h>
+#include <atomic>
+#include <functional>
+#include <utility>
+
#include "tensorflow/core/common_runtime/collective_rma_local.h"
+#include "tensorflow/core/common_runtime/collective_util.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
+#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/types.h"
// Set true for greater intelligibility of debug mode log messages.
#define READABLE_KEYS false
@@ -36,7 +53,8 @@ string RingReduceBufKey(const string& exec_key, int pass, int section,
return strings::StrCat("rred(", exec_key, "):pass(", pass, "):section(",
section, "):srcrank(", source_rank, ")");
} else {
- // TODO(tucker): Try out some kind of denser encoding, e.g. 128 bit hash.
+ // TODO(b/78352018): Try out some kind of denser encoding, e.g. 128 bit
+ // hash.
return strings::StrCat(exec_key, ":", pass, ":", section, ":", source_rank);
}
}
@@ -65,105 +83,149 @@ RingReducer::RingField* RingReducer::PCQueue::Dequeue() {
return rf;
}
-RingReducer::RingReducer(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
- OpKernelContext* ctx,
- OpKernelContext::Params* op_params,
- const CollectiveParams& col_params,
- const string& exec_key, int64 step_id,
- const Tensor* input, Tensor* output)
- : col_exec_(col_exec),
- dev_mgr_(dev_mgr),
- ctx_(ctx),
- op_params_(op_params),
- col_params_(col_params),
- exec_key_(exec_key),
- input_(input),
- output_(output),
- rank_(col_params.subdiv_rank[0]),
- step_id_(step_id),
- group_size_(col_params.group.group_size),
- num_subdivs_(static_cast<int>(
- col_params.instance.impl_details.subdiv_permutations.size())),
+RingReducer::RingReducer()
+ : col_ctx_(nullptr),
+ col_params_(nullptr),
done_(nullptr),
- device_(nullptr),
- device_name_(
- col_params_.instance.device_names[col_params_.default_rank]) {
- CHECK_GT(group_size_, 0);
- CHECK_GT(num_subdivs_, 0);
-}
+ group_size_(-1),
+ num_subdivs_(-1) {}
RingReducer::~RingReducer() { group_size_tensor_ready_.WaitForNotification(); }
-string RingReducer::TensorDebugString(Tensor tensor) {
- const DeviceBase::GpuDeviceInfo* gpu_device_info =
- ctx_->device()->tensorflow_gpu_device_info();
- if (gpu_device_info) {
- Tensor cpu_tensor(tensor.dtype(), tensor.shape());
- Notification note;
- gpu_device_info->default_context->CopyDeviceTensorToCPU(
- &tensor, "" /*tensor_name*/, device_, &cpu_tensor,
- [&note](const Status& s) {
- CHECK(s.ok());
- note.Notify();
- });
- note.WaitForNotification();
- return cpu_tensor.SummarizeValue(64);
- } else {
- return tensor.SummarizeValue(64);
+Status RingReducer::InitializeCollectiveParams(CollectiveParams* col_params) {
+ CHECK_EQ(col_params->instance.type, REDUCTION_COLLECTIVE);
+ CHECK_EQ(col_params->instance.impl_details.collective_name, "RingReduce");
+ const string& device_name =
+ col_params->instance.device_names[col_params->default_rank];
+ // Each subdiv permutation is a ring formed by rotating each
+ // single-task subsequence of devices by an offset. This makes the most
+ // sense when each task has the same number of devices, but we can't
+ // depend on that being the case, so we compute something that works
+ // in any case.
+
+ // Start by counting the devices in each task.
+ // Precondition: device_names must be sorted so that all devices in
+ // the same task are adjacent.
+ VLOG(2) << "Sorted task names: "
+ << str_util::Join(col_params->instance.task_names, ", ");
+ std::vector<int> dev_per_task;
+ const string* prior_task_name = &col_params->instance.task_names[0];
+ int dev_count = 1;
+ for (int di = 1; di < col_params->group.group_size; ++di) {
+ if (col_params->instance.task_names[di] != *prior_task_name) {
+ dev_per_task.push_back(dev_count);
+ dev_count = 1;
+ prior_task_name = &col_params->instance.task_names[di];
+ } else {
+ ++dev_count;
+ }
+ }
+ dev_per_task.push_back(dev_count);
+ CHECK_EQ(col_params->group.num_tasks, dev_per_task.size());
+
+ // Generate a ring permutation for each requested offset.
+ if (col_params->instance.impl_details.subdiv_offsets.empty()) {
+ return errors::Internal(
+ "Subdiv offsets should be non-empty for ring reducer, size=",
+ col_params->instance.impl_details.subdiv_offsets.size());
+ }
+ VLOG(2) << "Setting up perms for col_params " << col_params
+ << " subdiv_permutations "
+ << &col_params->instance.impl_details.subdiv_permutations;
+ col_params->instance.impl_details.subdiv_permutations.resize(
+ col_params->instance.impl_details.subdiv_offsets.size());
+ col_params->subdiv_rank.resize(
+ col_params->instance.impl_details.subdiv_offsets.size(), -1);
+ for (int sdi = 0;
+ sdi < col_params->instance.impl_details.subdiv_offsets.size(); ++sdi) {
+ std::vector<int>& perm =
+ col_params->instance.impl_details.subdiv_permutations[sdi];
+ CHECK_EQ(perm.size(), 0);
+ int offset = col_params->instance.impl_details.subdiv_offsets[sdi];
+ // A negative subdivision offset is interpreted as follows:
+ // 1. Reverse the local device ordering.
+ // 2. Begin the subdivision at abs(offset) in the reversed ordering.
+ bool reverse = false;
+ if (offset < 0) {
+ offset = abs(offset);
+ reverse = true;
+ }
+ int prior_dev_count = 0; // sum over prior worker device counts
+ for (int ti = 0; ti < col_params->group.num_tasks; ++ti) {
+ for (int di = 0; di < dev_per_task[ti]; ++di) {
+ int di_offset = (di + offset) % dev_per_task[ti];
+ int offset_di =
+ reverse ? (dev_per_task[ti] - (di_offset + 1)) : di_offset;
+ // Device index in global subdivision permutation.
+ int permuted_di = prior_dev_count + offset_di;
+ int rank = static_cast<int>(perm.size());
+ perm.push_back(permuted_di);
+ if (col_params->instance.device_names[permuted_di] == device_name) {
+ CHECK_EQ(permuted_di, col_params->default_rank);
+ col_params->subdiv_rank[sdi] = rank;
+ }
+ }
+ prior_dev_count += dev_per_task[ti];
+ }
+ CHECK_EQ(col_params->group.group_size, perm.size());
}
+
+ VLOG(2) << collective_util::SubdivPermDebugString(*col_params);
+ return Status::OK();
+}
+
+Status RingReducer::InitializeCollectiveContext(CollectiveContext* col_ctx) {
+ CHECK(col_ctx->dev_mgr);
+ col_ctx_ = col_ctx;
+ col_params_ = &col_ctx->col_params;
+ return collective_util::InitializeDeviceAndLocality(
+ col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
+ &col_ctx->device_locality);
}
void RingReducer::Run(StatusCallback done) {
+ CHECK(col_ctx_);
+ CHECK(col_params_);
done_ = std::move(done);
+ group_size_ = col_params_->group.group_size;
+ num_subdivs_ = static_cast<int>(
+ col_params_->instance.impl_details.subdiv_permutations.size());
+ CHECK_GT(num_subdivs_, 0);
- // Get local execution device.
if (VLOG_IS_ON(1)) {
string buf;
- for (int r = 0; r < col_params_.instance.device_names.size(); ++r) {
+ for (int r = 0; r < col_params_->instance.device_names.size(); ++r) {
strings::StrAppend(&buf, "dev ", r, " : ",
- col_params_.instance.device_names[r], "\n");
+ col_params_->instance.device_names[r], "\n");
}
for (int sd = 0;
- sd < col_params_.instance.impl_details.subdiv_permutations.size();
+ sd < col_params_->instance.impl_details.subdiv_permutations.size();
++sd) {
strings::StrAppend(&buf, "\nsubdiv ", sd, " perm: ");
- for (auto x : col_params_.instance.impl_details.subdiv_permutations[sd]) {
+ for (auto x :
+ col_params_->instance.impl_details.subdiv_permutations[sd]) {
strings::StrAppend(&buf, x, ", ");
}
}
- VLOG(1) << "RingReducer::Run for device " << device_name_
- << " default_rank " << col_params_.default_rank << "\n"
+ VLOG(1) << "RingReducer::Run for device " << col_ctx_->device_name
+ << " default_rank " << col_params_->default_rank << "\n"
<< buf;
}
- CHECK(dev_mgr_);
- Status status = dev_mgr_->LookupDevice(
- col_params_.instance.device_names[col_params_.default_rank], &device_);
- if (!status.ok()) {
- LOG(ERROR) << "Failed to find device "
- << col_params_.instance.device_names[col_params_.default_rank];
- for (auto d : dev_mgr_->ListDevices()) {
- LOG(ERROR) << "Available device " << d->name();
- }
- done_(status);
- return;
- }
- CHECK(device_);
- device_locality_ = device_->attributes().locality();
-
- VLOG(1) << this << " default_rank " << col_params_.default_rank << " cp "
- << &col_params_ << ": " << col_params_.ToString();
// Start by copying input to output if they're not already the same, i.e. if
// we're not computing in-place on the input tensor.
- if ((input_ != output_) &&
- (DMAHelper::base(input_) != DMAHelper::base(output_))) {
+ if ((col_ctx_->input != col_ctx_->output) &&
+ (DMAHelper::base(col_ctx_->input) != DMAHelper::base(col_ctx_->output))) {
// We are running in a blockable thread and the callback can't block so
// just wait here on the copy.
Notification note;
+ Status status;
CollectiveRemoteAccessLocal::MemCpyAsync(
- ctx_->input_device_context(0), ctx_->op_device_context(), device_,
- device_, ctx_->input_alloc_attr(0), ctx_->output_alloc_attr(0), input_,
- output_, 0 /*dev_to_dev_stream_index*/,
+ col_ctx_->op_ctx->input_device_context(0),
+ col_ctx_->op_ctx->op_device_context(), col_ctx_->device,
+ col_ctx_->device, col_ctx_->op_ctx->input_alloc_attr(0),
+ col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input,
+ col_ctx_->output, 0 /*dev_to_dev_stream_index*/,
[this, &note, &status](const Status& s) {
status.Update(s);
note.Notify();
@@ -177,24 +239,43 @@ void RingReducer::Run(StatusCallback done) {
ContinueAfterInputCopy();
}
+string RingReducer::TensorDebugString(const Tensor& tensor) {
+ const DeviceBase::GpuDeviceInfo* gpu_device_info =
+ col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
+ if (gpu_device_info) {
+ Tensor cpu_tensor(tensor.dtype(), tensor.shape());
+ Notification note;
+ gpu_device_info->default_context->CopyDeviceTensorToCPU(
+ &tensor, "" /*tensor_name*/, col_ctx_->device, &cpu_tensor,
+ [&note](const Status& s) {
+ CHECK(s.ok());
+ note.Notify();
+ });
+ note.WaitForNotification();
+ return cpu_tensor.SummarizeValue(64);
+ } else {
+ return tensor.SummarizeValue(64);
+ }
+}
+
// Note that this function is blocking and must not run in any thread
// which cannot be blocked.
void RingReducer::ContinueAfterInputCopy() {
- AllocatorAttributes attr = ctx_->output_alloc_attr(0);
- ca_.reset(MakeCollectiveAdapter(output_, group_size_ * num_subdivs_,
- device_->GetAllocator(attr)));
+ AllocatorAttributes attr = col_ctx_->op_ctx->output_alloc_attr(0);
+ ca_.reset(MakeCollectiveAdapter(col_ctx_->output, group_size_ * num_subdivs_,
+ col_ctx_->device->GetAllocator(attr)));
- if (col_params_.final_op) {
+ if (col_params_->final_op) {
// Create an on-device scalar value from group_size_ that may be needed
// later.
// TODO(tucker): Cache and reuse across invocations? Or maybe the scalar
// can be provided to the kernel in host memory?
Tensor group_size_val = ca_->Scalar(group_size_);
- if (col_params_.group.device_type != "CPU") {
- group_size_tensor_ =
- ca_->Scalar(device_->GetAllocator(ctx_->input_alloc_attr(0)));
- DeviceContext* op_dev_ctx = ctx_->op_device_context();
- op_dev_ctx->CopyCPUTensorToDevice(&group_size_val, device_,
+ if (col_params_->group.device_type != "CPU") {
+ group_size_tensor_ = ca_->Scalar(col_ctx_->device->GetAllocator(
+ col_ctx_->op_ctx->input_alloc_attr(0)));
+ DeviceContext* op_dev_ctx = col_ctx_->op_ctx->op_device_context();
+ op_dev_ctx->CopyCPUTensorToDevice(&group_size_val, col_ctx_->device,
&group_size_tensor_,
[this](const Status& s) {
if (!s.ok()) {
@@ -231,14 +312,14 @@ void RingReducer::StartAbort(const Status& s) {
// cancellation on all of the outstanding CollectiveRemoteAccess
// actions.
if (abort_started) {
- col_exec_->StartAbort(s);
+ col_ctx_->col_exec->StartAbort(s);
}
}
void RingReducer::Finish(bool ok) {
if (ok) {
// Recover the output from the adaptor.
- ca_->ConsumeFinalValue(output_);
+ ca_->ConsumeFinalValue(col_ctx_->output);
}
Status s;
{
@@ -275,7 +356,7 @@ Status RingReducer::ComputeBinOp(Device* device, OpKernel* op, Tensor* output,
// TODO(tucker): Is it possible to cache and reuse these objects? They're
// mostly identical inside one device execution.
std::unique_ptr<SubContext> sub_ctx(
- new SubContext(ctx_, op_params_, op, output, input));
+ new SubContext(col_ctx_->op_ctx, col_ctx_->op_params, op, output, input));
device->Compute(op, sub_ctx->sub_ctx_);
return sub_ctx->sub_ctx_->status();
}
@@ -295,18 +376,18 @@ void RingReducer::InitRingField(RingField* rf, int chunk_idx, int subdiv_idx,
rf->chunk_idx = chunk_idx;
rf->subdiv_idx = subdiv_idx;
rf->sc_idx = field_idx;
- rf->rank = col_params_.subdiv_rank[subdiv_idx];
+ rf->rank = col_params_->subdiv_rank[subdiv_idx];
rf->second_pass = false;
rf->action = RF_INIT;
// Recv from the device with preceding rank within the subdivision.
int recv_from_rank = (rf->rank + (group_size_ - 1)) % group_size_;
int send_to_rank = (rf->rank + 1) % group_size_;
- rf->recv_dev_idx = col_params_.instance.impl_details
+ rf->recv_dev_idx = col_params_->instance.impl_details
.subdiv_permutations[subdiv_idx][recv_from_rank];
- int send_dev_idx = col_params_.instance.impl_details
+ int send_dev_idx = col_params_->instance.impl_details
.subdiv_permutations[subdiv_idx][send_to_rank];
- rf->recv_is_remote = !col_params_.task.is_local[rf->recv_dev_idx];
- rf->send_is_remote = !col_params_.task.is_local[send_dev_idx];
+ rf->recv_is_remote = !col_params_->task.is_local[rf->recv_dev_idx];
+ rf->send_is_remote = !col_params_->task.is_local[send_dev_idx];
if (ca_->ChunkBytes(rf->sc_idx) > 0) {
// In pass 0 we skip Recv when rank = chunk_idx
rf->do_recv = (rf->chunk_idx != rf->rank);
@@ -360,45 +441,47 @@ string RingReducer::RingField::DebugString() const {
void RingReducer::DispatchSend(RingField* rf, const StatusCallback& done) {
CHECK(rf->do_send);
- string send_buf_key =
- RingReduceBufKey(exec_key_, rf->second_pass, rf->sc_idx, rf->rank);
- VLOG(3) << "DispatchSend rank=" << col_params_.default_rank << " send key "
+ string send_buf_key = RingReduceBufKey(col_ctx_->exec_key, rf->second_pass,
+ rf->sc_idx, rf->rank);
+ VLOG(3) << "DispatchSend rank=" << col_params_->default_rank << " send key "
<< send_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " sc_idx "
<< rf->sc_idx;
int send_to_rank = (rf->rank + 1) % group_size_;
- int send_to_dev_idx = col_params_.instance.impl_details
+ int send_to_dev_idx = col_params_->instance.impl_details
.subdiv_permutations[rf->subdiv_idx][send_to_rank];
- col_exec_->PostToPeer(col_params_.instance.device_names[send_to_dev_idx],
- col_params_.instance.task_names[send_to_dev_idx],
- send_buf_key, device_, ctx_->op_device_context(),
- ctx_->output_alloc_attr(0), &rf->chunk,
- device_locality_, done);
+ col_ctx_->col_exec->PostToPeer(
+ col_params_->instance.device_names[send_to_dev_idx],
+ col_params_->instance.task_names[send_to_dev_idx], send_buf_key,
+ col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
+ col_ctx_->op_ctx->output_alloc_attr(0), &rf->chunk,
+ col_ctx_->device_locality, done);
}
void RingReducer::DispatchRecv(RingField* rf, const StatusCallback& done) {
CHECK(rf->do_recv);
string recv_buf_key =
- RingReduceBufKey(exec_key_, rf->second_pass, rf->sc_idx,
+ RingReduceBufKey(col_ctx_->exec_key, rf->second_pass, rf->sc_idx,
(rf->rank + (group_size_ - 1)) % group_size_);
- VLOG(3) << "DispatchRecv rank=" << col_params_.default_rank << " recv key "
+ VLOG(3) << "DispatchRecv rank=" << col_params_->default_rank << " recv key "
<< recv_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " into "
- << ((col_params_.merge_op != nullptr) ? "tmp_chunk" : "chunk");
- Tensor* dst_tensor = (!rf->second_pass && (col_params_.merge_op != nullptr))
+ << ((col_params_->merge_op != nullptr) ? "tmp_chunk" : "chunk");
+ Tensor* dst_tensor = (!rf->second_pass && (col_params_->merge_op != nullptr))
? &rf->tmp_chunk
: &rf->chunk;
- col_exec_->RecvFromPeer(col_params_.instance.device_names[rf->recv_dev_idx],
- col_params_.instance.task_names[rf->recv_dev_idx],
- col_params_.task.is_local[rf->recv_dev_idx],
- recv_buf_key, device_, ctx_->op_device_context(),
- ctx_->output_alloc_attr(0), dst_tensor,
- device_locality_, rf->subdiv_idx, done);
+ col_ctx_->col_exec->RecvFromPeer(
+ col_params_->instance.device_names[rf->recv_dev_idx],
+ col_params_->instance.task_names[rf->recv_dev_idx],
+ col_params_->task.is_local[rf->recv_dev_idx], recv_buf_key,
+ col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
+ col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,
+ col_ctx_->device_locality, rf->subdiv_idx, done);
}
string RingReducer::FieldState() {
- string s = strings::StrCat("RingReducer ",
- strings::Hex(reinterpret_cast<uint64>(this)),
- " exec ", exec_key_, " step_id=", step_id_,
- " state of all ", rfv_.size(), " fields:");
+ string s = strings::StrCat(
+ "RingReducer ", strings::Hex(reinterpret_cast<uint64>(this)), " exec ",
+ col_ctx_->exec_key, " step_id=", col_ctx_->step_id, " state of all ",
+ rfv_.size(), " fields:");
for (int i = 0; i < rfv_.size(); ++i) {
s.append("\n");
s.append(rfv_[i].DebugString());
@@ -415,13 +498,6 @@ bool RingReducer::RunAsyncParts() {
rfv_.clear();
rfv_.resize(group_size_ * num_subdivs_);
PCQueue ready_queue;
- int field_done_count = 0;
- int send_pending_count = 0;
- int recv_pending_count = 0;
- std::atomic<bool> aborted(false);
- field_done_count = 0;
- send_pending_count = 0;
- recv_pending_count = 0;
for (int chunk_idx = 0; chunk_idx < group_size_; ++chunk_idx) {
for (int subdiv_idx = 0; subdiv_idx < num_subdivs_; ++subdiv_idx) {
int rf_index = (chunk_idx * num_subdivs_) + subdiv_idx;
@@ -429,6 +505,30 @@ bool RingReducer::RunAsyncParts() {
ready_queue.Enqueue(&rfv_[rf_index]);
}
}
+ const DeviceBase::GpuDeviceInfo* gpu_info =
+ col_ctx_->device->tensorflow_gpu_device_info();
+ if (gpu_info) {
+ // Wait for all currently queued events on the CPU compute stream to
+ // complete before proceeding. The previous InitRingField calls allocated
+ // temp memory buffers that are not guaranteed to be valid (e.g. for RDMA
+ // write) unless we do so.
+ Notification note;
+ Status s = gpu_info->default_context->ThenExecute(
+ col_ctx_->device, gpu_info->stream, [&note]() { note.Notify(); });
+ if (s.ok()) {
+ note.WaitForNotification();
+ } else {
+ mutex_lock l(status_mu_);
+ status_ =
+ errors::Internal("Failed to dispatch ThenExecute in RingReducer");
+ return false;
+ }
+ }
+
+ int field_done_count = 0;
+ int send_pending_count = 0;
+ int recv_pending_count = 0;
+ std::atomic<bool> aborted(false);
// Loop until all RingFields have advanced to completion.
while (field_done_count < rfv_.size()) {
@@ -468,8 +568,9 @@ bool RingReducer::RunAsyncParts() {
--recv_pending_count;
if (!rf->second_pass) {
rf->action = RF_REDUCE;
- Status s = ComputeBinOp(device_, col_params_.merge_op.get(),
- &rf->chunk, &rf->tmp_chunk);
+ Status s =
+ ComputeBinOp(col_ctx_->device, col_params_->merge_op.get(),
+ &rf->chunk, &rf->tmp_chunk);
if (!s.ok()) {
aborted = true;
StartAbort(s);
@@ -479,11 +580,12 @@ bool RingReducer::RunAsyncParts() {
}
break;
case RF_REDUCE:
- if (!rf->second_pass && col_params_.final_op.get() && rf->is_final) {
+ if (!rf->second_pass && col_params_->final_op.get() && rf->is_final) {
rf->action = RF_FINALIZE;
group_size_tensor_ready_.WaitForNotification();
- Status s = ComputeBinOp(device_, col_params_.final_op.get(),
- &rf->chunk, &group_size_tensor_);
+ Status s =
+ ComputeBinOp(col_ctx_->device, col_params_->final_op.get(),
+ &rf->chunk, &group_size_tensor_);
if (!s.ok()) {
aborted = true;
StartAbort(s);
@@ -552,9 +654,11 @@ bool RingReducer::RunAsyncParts() {
CHECK_EQ(send_pending_count, 0);
CHECK_EQ(recv_pending_count, 0);
- VLOG(2) << this << " rank=" << rank_ << " finish;"
+ VLOG(2) << this << " device=" << col_ctx_->device_name << " finish;"
<< " final value " << TensorDebugString(ca_->Value());
return !aborted;
}
+REGISTER_COLLECTIVE(RingReduce, RingReducer);
+
} // namespace tensorflow
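
The permutation construction in InitializeCollectiveParams is easier to see with concrete numbers. Below is a self-contained re-implementation of its inner loop for a single subdivision; the values it produces line up with the InitializeParams test added to ring_reducer_test.cc further down:

    #include <cstdlib>
    #include <iostream>
    #include <vector>

    // Mirrors the per-subdivision loop above: rotate each task's device
    // subsequence by `offset`; a negative offset also reverses the ordering.
    std::vector<int> SubdivPerm(const std::vector<int>& dev_per_task, int offset) {
      bool reverse = offset < 0;
      offset = std::abs(offset);
      std::vector<int> perm;
      int prior_dev_count = 0;  // sum over prior task device counts
      for (int devs : dev_per_task) {
        for (int di = 0; di < devs; ++di) {
          int di_offset = (di + offset) % devs;
          int offset_di = reverse ? (devs - (di_offset + 1)) : di_offset;
          perm.push_back(prior_dev_count + offset_di);
        }
        prior_dev_count += devs;
      }
      return perm;
    }

    int main() {
      // Two tasks with 4 devices each, offset 1: each task's ring rotates by 1.
      for (int d : SubdivPerm({4, 4}, 1)) std::cout << d << ' ';
      std::cout << '\n';  // prints: 1 2 3 0 5 6 7 4
    }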
diff --git a/tensorflow/core/common_runtime/ring_reducer.h b/tensorflow/core/common_runtime/ring_reducer.h
index 3e1988e787..0848e37b52 100644
--- a/tensorflow/core/common_runtime/ring_reducer.h
+++ b/tensorflow/core/common_runtime/ring_reducer.h
@@ -16,25 +16,35 @@ limitations under the License.
#define TENSORFLOW_CORE_COMMON_RUNTIME_RING_REDUCER_H_
#include <deque>
+#include <memory>
+#include <string>
+#include <vector>
#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/framework/collective.h"
-#include "tensorflow/core/framework/device_attributes.pb.h"
namespace tensorflow {
-class DeviceMgr;
+class Device;
// Ring-algorithm implementation of collective all-reduce.
-class RingReducer {
+class RingReducer : public CollectiveImplementationInterface {
public:
- RingReducer(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
- OpKernelContext* ctx, OpKernelContext::Params* op_params,
- const CollectiveParams& col_params, const string& exec_key,
- int64 step_id, const Tensor* input, Tensor* output);
+ RingReducer();
+ ~RingReducer() override;
- virtual ~RingReducer();
+ // Establishes the requested number of subdivision permutations based on the
+ // ring order implicit in the device order.
+ Status InitializeCollectiveParams(CollectiveParams* col_params) override;
- void Run(StatusCallback done);
+ // Initializes members of CollectiveContext not yet initialized, i.e. device
+ // and device_locality. Also saves the CollectiveContext in this object.
+ Status InitializeCollectiveContext(CollectiveContext* col_ctx) override;
+
+ // Begins async execution of the ring reduce algorithm.
+ // Must be called in a blockable thread.
+ // TODO(b/80529858): remove the previous warning when we have a dedicated
+ // collective threadpool.
+ void Run(StatusCallback done) override;
private:
// Called when a bad status is received that implies we should terminate
@@ -101,7 +111,7 @@ class RingReducer {
// For constructing log messages for debugging.
string FieldState();
- string TensorDebugString(Tensor tensor);
+ string TensorDebugString(const Tensor& tensor);
// Producer/Consumer Queue of RingField structs.
class PCQueue {
@@ -116,30 +126,19 @@ class RingReducer {
std::deque<RingField*> deque_ GUARDED_BY(pcq_mu_);
};
- CollectiveExecutor* col_exec_; // Not owned
- const DeviceMgr* dev_mgr_; // Not owned
- OpKernelContext* ctx_; // Not owned
- OpKernelContext::Params* op_params_; // Not owned
- const CollectiveParams& col_params_;
- const string exec_key_;
- const Tensor* input_; // Not owned
- Tensor* output_; // Not owned
- const int rank_;
- const int64 step_id_;
- const int group_size_;
- const int num_subdivs_;
+ CollectiveContext* col_ctx_; // Not owned
+ const CollectiveParams* col_params_; // Not owned
+ StatusCallback done_;
+ int group_size_;
+ int num_subdivs_;
Tensor group_size_tensor_;
Notification group_size_tensor_ready_;
std::unique_ptr<CollectiveAdapter> ca_;
- StatusCallback done_;
- Device* device_; // The device for which this instance labors
- const string device_name_;
- DeviceLocality device_locality_;
-
mutex status_mu_;
Status status_ GUARDED_BY(status_mu_);
-
std::vector<RingField> rfv_;
+
+ friend class RingReducerTest;
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc
index fcdf9deff8..28df85399e 100644
--- a/tensorflow/core/common_runtime/ring_reducer_test.cc
+++ b/tensorflow/core/common_runtime/ring_reducer_test.cc
@@ -37,7 +37,6 @@ limitations under the License.
#include "tensorflow/core/public/version.h"
namespace tensorflow {
-namespace {
// Wraps CollectiveRemoteAccessLocal with the ability to return an
// error status to the N'th action.
@@ -135,27 +134,28 @@ class RingReducerTest : public ::testing::Test {
protected:
RingReducerTest() : device_type_(DEVICE_CPU) {}
- void SetUp() override {
-#if GOOGLE_CUDA
+#ifdef GOOGLE_CUDA
+ void InitGPUDevices() {
auto device_factory = DeviceFactory::GetFactory("GPU");
CHECK(device_factory);
SessionOptions options;
Status s = device_factory->CreateDevices(
options, "/job:worker/replica:0/task:0", &gpu_devices_);
CHECK(s.ok());
-#endif
}
+#endif
~RingReducerTest() override {
stop_ = true;
- for (auto i : instances_) {
- delete i;
- }
+ for (auto i : instances_) delete i;
if (col_exec_) col_exec_->Unref();
}
void Init(int num_workers, int num_devices, DataType dtype,
const DeviceType& device_type, int num_subdivs, int fail_after) {
+#ifdef GOOGLE_CUDA
+ InitGPUDevices();
+#endif
device_type_ = device_type;
std::vector<Device*> local_devices;
SessionOptions sess_opts;
@@ -201,6 +201,7 @@ class RingReducerTest : public ::testing::Test {
col_params_.instance.instance_key = kInstanceKey;
col_params_.instance.impl_details.subdiv_offsets.clear();
col_params_.instance.type = REDUCTION_COLLECTIVE;
+ col_params_.instance.impl_details.collective_name = "RingReduce";
col_params_.instance.data_type = dtype;
col_params_.instance.impl_details.subdiv_permutations.resize(num_subdivs);
col_params_.subdiv_rank.resize(num_subdivs);
@@ -259,13 +260,17 @@ class RingReducerTest : public ::testing::Test {
}
}
- void Reduce() {
+ void Reduce(int fail_after) {
std::atomic<int> done(0);
for (auto di : instances_) {
SchedClosure([di, &done] {
di->DoReduce();
++done;
});
+ if (fail_after > 0) {
+ // Stagger the op execution starts.
+ Env::Default()->SleepForMicroseconds(100);
+ }
}
while (done < static_cast<int>(instances_.size())) {
if (stop_) break;
@@ -295,7 +300,7 @@ class RingReducerTest : public ::testing::Test {
}
});
}
- Reduce();
+ Reduce(fail_after);
if (fail_after > 0) {
// Confirm that every device terminated with the expected error status.
for (int di = 0; di < static_cast<int>(instances_.size()); ++di) {
@@ -373,6 +378,22 @@ class RingReducerTest : public ::testing::Test {
return GetKernel(node_def, device_type, device);
}
+ void RunSubdivPermsTest(
+ CollectiveParams* cp,
+ const std::vector<std::vector<int>>& expected_subdiv_perms,
+ const std::vector<int>& expected_subdiv_rank) {
+ col_exec_ = nullptr;
+ cp->instance.impl_details.subdiv_permutations.clear();
+ cp->subdiv_rank.clear();
+ // Create a stub ring reducer only for testing param initialization.
+ RingReducer reducer;
+ TF_CHECK_OK(reducer.InitializeCollectiveParams(cp));
+ EXPECT_EQ(expected_subdiv_perms,
+ cp->instance.impl_details.subdiv_permutations);
+ EXPECT_EQ(expected_subdiv_rank, cp->subdiv_rank);
+ reducer.group_size_tensor_ready_.Notify(); // To unblock destructor.
+ }
+
class DeviceInstance {
public:
DeviceInstance(int rank, const string& dev_name,
@@ -475,8 +496,8 @@ class RingReducerTest : public ::testing::Test {
op_params.op_kernel = op.get();
OpKernelContext ctx(&op_params, 1);
- // We never actually execute the kernel, so we need to do the
- // output allocation that it would do, ourselves.
+ // We never actually execute the kernel, so we need to do the output
+ // allocation it would do, ourselves.
Tensor* output_tensor_ptr = nullptr;
TF_CHECK_OK(ctx.forward_input_or_allocate_output({0}, 0, tensor_.shape(),
&output_tensor_ptr));
@@ -485,20 +506,17 @@ class RingReducerTest : public ::testing::Test {
// Prepare a RingReducer instance.
string exec_key =
strings::StrCat(col_params_.instance.instance_key, ":0:0");
- RingReducer rr(parent_->col_exec_, parent_->dev_mgr_.get(), &ctx,
- &op_params, col_params_, exec_key, kStepId, &tensor_,
- &tensor_);
-
- // Start execution in a threadpool then wait for completion.
- Notification notification;
- SchedClosure([this, &notification, &rr]() {
- rr.Run([this, &notification](Status s) {
- status_ = s;
- notification.Notify();
- });
- });
- notification.WaitForNotification();
- CHECK(tensor_.CopyFrom(*ctx.mutable_output(0), tensor_.shape()));
+ RingReducer reducer;
+ CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
+ &ctx, &op_params, col_params_, exec_key,
+ kStepId, &tensor_, &tensor_);
+ TF_CHECK_OK(reducer.InitializeCollectiveContext(&col_ctx));
+
+ // Run the all-reduce.
+ reducer.Run([this](Status s) { status_ = s; });
+ if (status_.ok()) {
+ CHECK(tensor_.CopyFrom(*ctx.mutable_output(0), tensor_.shape()));
+ }
dev_ctx->Unref();
}
@@ -531,6 +549,57 @@ class RingReducerTest : public ::testing::Test {
int32 reduce_counter_ GUARDED_BY(mu_) = 0;
};
+TEST_F(RingReducerTest, InitializeParams) {
+ static const int kNumDevsPerTask = 8;
+ static const int kNumTasks = 3;
+ static const int kNumDevs = kNumDevsPerTask * kNumTasks;
+ CollectiveParams cp;
+ std::vector<string> device_names;
+ std::vector<string> task_names;
+ cp.group.group_key = 1;
+ cp.group.group_size = kNumDevs;
+ cp.group.device_type = DeviceType("GPU");
+ cp.group.num_tasks = kNumTasks;
+ cp.instance.instance_key = 3;
+ cp.instance.type = REDUCTION_COLLECTIVE;
+ cp.instance.data_type = DataType(DT_FLOAT);
+ cp.instance.shape = TensorShape({5});
+ cp.instance.impl_details.collective_name = "RingReduce";
+ cp.instance.impl_details.subdiv_offsets.push_back(0);
+ cp.is_source = false;
+ for (int i = 0; i < kNumDevs; ++i) {
+ int task_id = i / kNumDevsPerTask;
+ int dev_id = i % kNumDevsPerTask;
+ string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id);
+ task_names.push_back(task_name);
+ string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id);
+ device_names.push_back(device_name);
+ cp.instance.task_names.push_back(task_name);
+ cp.instance.device_names.push_back(device_name);
+ }
+
+ int test_rank = 0;
+ cp.default_rank = test_rank;
+ cp.instance.impl_details.subdiv_offsets = {0, 4};
+ RunSubdivPermsTest(&cp,
+ {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+ {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15,
+ 8, 9, 10, 11, 20, 21, 22, 23, 16, 17, 18, 19}},
+ {0, 4});
+
+ test_rank = 3;
+ cp.default_rank = test_rank;
+ cp.instance.impl_details.subdiv_offsets = {3, -3};
+ RunSubdivPermsTest(&cp,
+ {{3, 4, 5, 6, 7, 0, 1, 2, 11, 12, 13, 14,
+ 15, 8, 9, 10, 19, 20, 21, 22, 23, 16, 17, 18},
+ {4, 3, 2, 1, 0, 7, 6, 5, 12, 11, 10, 9,
+ 8, 15, 14, 13, 20, 19, 18, 17, 16, 23, 22, 21}},
+ {0, 1});
+}
+
+// TODO(b/113171733): change to use TEST_P.
#define DEF_TEST(B, T, W, D, S, L, A) \
TEST_F(RingReducerTest, \
DaTy##B##_DevTy##T##_Wkr##W##_Dev##D##_Sdiv##S##_Len##L##_Abrt##A) { \
@@ -575,6 +644,7 @@ DEF_TEST(INT64, CPU, 1, 2, 1, 1001, 0)
DEF_TEST(INT64, CPU, 2, 8, 3, 4095, 0)
// Failure tests
+DEF_TEST(FLOAT, CPU, 2, 8, 1, 9408, 1)
DEF_TEST(FLOAT, CPU, 2, 8, 1, 9408, 7)
DEF_TEST(FLOAT, CPU, 2, 8, 2, 9408, 11)
#endif
@@ -604,5 +674,4 @@ DEF_TEST(FLOAT, GPU, 1, 8, 1, 9408, 2)
DEF_TEST(FLOAT, GPU, 1, 8, 2, 9408, 5)
#endif
-} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/session_factory.h b/tensorflow/core/common_runtime/session_factory.h
index 81c172c6ae..8565088afc 100644
--- a/tensorflow/core/common_runtime/session_factory.h
+++ b/tensorflow/core/common_runtime/session_factory.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_SESSION_FACTORY_H_
-#define TENSORFLOW_COMMON_RUNTIME_SESSION_FACTORY_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_FACTORY_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_FACTORY_H_
#include <string>
@@ -73,4 +73,4 @@ class SessionFactory {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_SESSION_FACTORY_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_FACTORY_H_
diff --git a/tensorflow/core/common_runtime/session_state.cc b/tensorflow/core/common_runtime/session_state.cc
index 65ff356e73..5b1915755d 100644
--- a/tensorflow/core/common_runtime/session_state.cc
+++ b/tensorflow/core/common_runtime/session_state.cc
@@ -70,7 +70,7 @@ Status TensorStore::SaveTensors(const std::vector<string>& output_names,
// Save only the tensors in output_names in the session.
for (const string& name : output_names) {
TensorId id(ParseTensorName(name));
- const string& op_name = std::string(id.first);
+ const string op_name(id.first);
auto it = tensors_.find(op_name);
if (it != tensors_.end()) {
// Save the tensor to the session state.
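
The one-line change here (and the similar ones in step_stats_collector.cc below) replaces a const reference bound to a temporary with a plain named object. Both compile, the reference version only via lifetime extension, but it reads as if no copy were made. A minimal illustration, assuming a StringPiece-like input:

    #include <string>

    void Example(const char* id_first) {
      // Before: a const reference bound to a temporary std::string. Legal
      // only because of lifetime extension, and easy to misread as an alias.
      const std::string& aliasing = std::string(id_first);
      // After: an ordinary named object; same cost, no reference subtlety.
      const std::string op_name(id_first);
      (void)aliasing;
      (void)op_name;
    }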
diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc
index af6880c6b3..836cb8ed14 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.cc
+++ b/tensorflow/core/common_runtime/step_stats_collector.cc
@@ -16,12 +16,16 @@ limitations under the License.
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tracking_allocator.h"
#include "tensorflow/core/graph/costmodel.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/scanner.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
@@ -36,11 +40,89 @@ struct AllocStats {
};
} // namespace
-NodeExecStatsWrapper::NodeExecStatsWrapper()
- : NodeExecStatsWrapper(new NodeExecStats) {}
+NodeExecStatsWrapper::NodeExecStatsWrapper(const string& node_name)
+ : NodeExecStatsWrapper(new NodeExecStats) {
+ stats_->set_node_name(node_name);
+}
NodeExecStatsWrapper::NodeExecStatsWrapper(NodeExecStats* stats)
: stats_(stats) {}
+void NodeExecStatsWrapper::SetOutput(int slot, const Tensor* v) {
+ DCHECK(v);
+ NodeOutput* no = stats_->add_output();
+ no->set_slot(slot);
+ v->FillDescription(no->mutable_tensor_description());
+}
+
+void NodeExecStatsWrapper::SetMemory(OpKernelContext* ctx) {
+ for (const auto& allocator_pair : ctx->wrapped_allocators()) {
+ AddAllocation(allocator_pair.first, allocator_pair.second);
+ }
+ auto* ms = stats_->mutable_memory_stats();
+ ms->set_temp_memory_size(ctx->temp_memory_allocated());
+ for (const auto& alloc_id : ctx->persistent_alloc_ids()) {
+ ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id);
+ }
+ ms->set_persistent_memory_size(ctx->persistent_memory_allocated());
+}
+
+void NodeExecStatsWrapper::SetReferencedTensors(
+ const TensorReferenceVector& tensors) {
+ // Be careful not to increment the reference count on any tensor
+ // while recording the information.
+ for (size_t i = 0; i < tensors.size(); ++i) {
+ AllocationDescription* description = stats_->add_referenced_tensor();
+ tensors.at(i).FillDescription(description);
+ }
+}
+
+// TODO(tucker): merge with the DetailText function in session.cc
+// in a common location.
+bool NodeExecStatsWrapper::SetTimelineLabel(const Node* node) {
+ bool is_transfer_node = false;
+ string memory;
+ for (auto& all : stats_->memory()) {
+ int64 tot = all.total_bytes();
+ if (tot >= 0.1 * 1048576.0) {
+ int64 peak = all.peak_bytes();
+ if (peak > 0) {
+ memory =
+ strings::StrCat(memory, "[", all.allocator_name(),
+ strings::Printf(" %.1fMB %.1fMB] ", tot / 1048576.0,
+ peak / 1048576.0));
+ } else {
+ memory = strings::StrCat(memory, "[", all.allocator_name(),
+ strings::Printf(" %.1fMB] ", tot / 1048576.0));
+ }
+ }
+ }
+ const AttrSlice attrs = node->attrs();
+ string text;
+ if (IsSend(node)) {
+ string tensor_name;
+ TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
+ string recv_device;
+ TF_CHECK_OK(GetNodeAttr(attrs, "recv_device", &recv_device));
+ text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
+ "(", tensor_name, " @", recv_device);
+ is_transfer_node = true;
+ } else if (IsRecv(node)) {
+ string tensor_name;
+ TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
+ string send_device;
+ TF_CHECK_OK(GetNodeAttr(attrs, "send_device", &send_device));
+ text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
+ "(", tensor_name, " @", send_device);
+ is_transfer_node = true;
+ } else {
+ text =
+ strings::StrCat(memory, node->name(), " = ", node->type_string(), "(",
+ str_util::Join(node->requested_inputs(), ", "), ")");
+ }
+ stats_->set_timeline_label(text);
+ return is_transfer_node;
+}
+
void NodeExecStatsWrapper::AddAllocation(
Allocator* allocator, TrackingAllocator* tracking_allocator) {
AllocatorMemoryUsed* memory = stats_->add_memory();
@@ -94,7 +176,7 @@ static int ExtractGpuWithStreamAll(string device_name) {
} else {
// Convert the captured string into an integer. But first we need to put
// the digits back in order
- string ordered_capture = std::string(capture);
+ string ordered_capture(capture);
std::reverse(ordered_capture.begin(), ordered_capture.end());
int gpu_id;
CHECK(strings::safe_strto32(ordered_capture, &gpu_id));
@@ -123,7 +205,7 @@ static int ExtractGpuWithoutStream(string device_name) {
} else {
// Convert the captured string into an integer. But first we need to put
// the digits back in order
- string ordered_capture = std::string(capture);
+ string ordered_capture(capture);
std::reverse(ordered_capture.begin(), ordered_capture.end());
int gpu_id;
CHECK(strings::safe_strto32(ordered_capture, &gpu_id));
@@ -170,7 +252,7 @@ void StepStatsCollector::BuildCostModel(
for (auto& itr : per_device_stats) {
const StringPiece device_name = itr.first;
- const int gpu_id = ExtractGpuWithoutStream(std::string(device_name));
+ const int gpu_id = ExtractGpuWithoutStream(string(device_name));
if (gpu_id >= 0) {
// Reference the gpu hardware stats in addition to the regular stats
// for this gpu device if they're available.
diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h
index 0394f25839..7206fbf427 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.h
+++ b/tensorflow/core/common_runtime/step_stats_collector.h
@@ -19,7 +19,9 @@ limitations under the License.
#include <unordered_map>
#include <vector>
#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
@@ -30,33 +32,99 @@ class Allocator;
class AllocatorMemoryUsed;
class CostModelManager;
class Graph;
+class Node;
class NodeExecStats;
+class OpKernelContext;
class StepStats;
+class Tensor;
class TrackingAllocator;
// Wraps NodeExecStats and adds allocation to it.
class NodeExecStatsWrapper {
public:
- NodeExecStatsWrapper();
+ NodeExecStatsWrapper(const string& node_name);
// Owns 'stats'.
NodeExecStatsWrapper(NodeExecStats* stats);
// Destructor calls Finalize() to release the TrackingAllocators.
~NodeExecStatsWrapper() { Finalize(); }
- NodeExecStats* stats() { return stats_.get(); }
-
- // "Does not take ownership of the 'allocator'.
- // Transfers ownership of the 'tracking_allocator' to *this."
- void AddAllocation(Allocator* allocator,
- TrackingAllocator* tracking_allocator);
+ // Records the absolute time in nanoseconds at which this node became
+ // runnable (i.e. was scheduled for execution).
+ void SetScheduled(int64 nanos) {
+ stats_->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos);
+ stats_->set_scheduled_nanos(nanos);
+ }
+
+ // Called immediately after this node starts being processed by the executor.
+ void RecordExecutorStarted() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ stats_->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
+ stats_->set_all_start_nanos(now_nanos);
+ }
+
+ // Called immediately before this node's `Compute()` or `ComputeAsync()`
+ // method is called.
+ void RecordComputeStarted() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ DCHECK_NE(stats_->all_start_micros(), 0);
+ DCHECK_NE(stats_->all_start_nanos(), 0);
+ stats_->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ stats_->all_start_micros());
+ stats_->set_op_start_rel_nanos(now_nanos - stats_->all_start_nanos());
+ }
+
+ // Called immediately after this node's `Compute()` method returns (or, for
+ // asynchronous operations, after the callback passed to its `ComputeAsync()`
+ // method is called).
+ void RecordComputeEnded() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ DCHECK_NE(stats_->all_start_micros(), 0);
+ DCHECK_NE(stats_->all_start_nanos(), 0);
+ stats_->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ stats_->all_start_micros());
+ stats_->set_op_end_rel_nanos(now_nanos - stats_->all_start_nanos());
+ }
+
+ // Called immediately after the executor finishes processing this node.
+ void RecordExecutorEnded() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ DCHECK_NE(stats_->all_start_micros(), 0);
+ DCHECK_NE(stats_->all_start_nanos(), 0);
+ stats_->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ stats_->all_start_micros());
+ stats_->set_all_end_rel_nanos(now_nanos - stats_->all_start_nanos());
+ }
+
+ // Records information about the tensor produced by this node at the given
+ // output slot.
+ void SetOutput(int slot, const Tensor* v);
+
+ // Records information about the memory allocated during the execution of this
+ // node.
+ void SetMemory(OpKernelContext* ctx);
+
+ // Records information about the tensors that were accessed during the
+ // execution of this node.
+ void SetReferencedTensors(const TensorReferenceVector& tensors);
+
+ // Sets the timeline_label field of the wrapped NodeExecStats, using data
+ // from *node. Returns true iff the node is a transfer node.
+ bool SetTimelineLabel(const Node* node);
private:
friend class StepStatsCollector;
+ NodeExecStats* stats() { return stats_.get(); }
+
// Populates stats_ and releases TrackingAllocator.
void Finalize();
+ // Does not take ownership of the `allocator`.
+ // Takes ownership of `tracking_allocator`.
+ void AddAllocation(Allocator* allocator,
+ TrackingAllocator* tracking_allocator);
+
gtl::InlinedVector<std::pair<AllocatorMemoryUsed*, TrackingAllocator*>, 2>
allocations_;
std::unique_ptr<NodeExecStats> stats_;
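Taken together, these hooks trace the lifetime of a single node execution. A minimal sketch of the intended call order, assuming a hypothetical RunKernel() helper standing in for the executor's real kernel dispatch:

// Illustrative only: driving the NodeExecStatsWrapper timing hooks in order.
// `ProfileOneNode` and `RunKernel` are hypothetical names for this sketch.
void ProfileOneNode(OpKernelContext* ctx, const Node* node) {
  NodeExecStatsWrapper stats(node->name());
  stats.SetScheduled(Env::Default()->NowNanos());  // node became runnable
  stats.RecordExecutorStarted();                   // executor picked it up
  stats.RecordComputeStarted();                    // about to call Compute()
  RunKernel(ctx);                                  // hypothetical kernel call
  stats.RecordComputeEnded();                      // Compute() returned
  stats.SetMemory(ctx);                            // per-allocator usage
  stats.RecordExecutorEnded();                     // executor is done
  stats.SetTimelineLabel(node);                    // true iff Send/Recv node
}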
diff --git a/tensorflow/core/common_runtime/sycl/sycl_allocator.h b/tensorflow/core/common_runtime/sycl/sycl_allocator.h
index 550f193332..cc5909de17 100644
--- a/tensorflow/core/common_runtime/sycl/sycl_allocator.h
+++ b/tensorflow/core/common_runtime/sycl/sycl_allocator.h
@@ -17,8 +17,8 @@ limitations under the License.
#error This file must only be included when building TensorFlow with SYCL support
#endif
-#ifndef TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
-#define TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/allocator.h"
@@ -72,4 +72,4 @@ class SYCLAllocator : public Allocator {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc
index 7406ecf4f8..0fbc20b34b 100644
--- a/tensorflow/core/common_runtime/threadpool_device.cc
+++ b/tensorflow/core/common_runtime/threadpool_device.cc
@@ -70,17 +70,6 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
ThreadPoolDevice::~ThreadPoolDevice() {}
-void ThreadPoolDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
- // When Xprof/ThreadScape profiling is off (which is the default), the
- // following code is simple enough that its overhead is negligible.
- tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(),
- op_kernel->IsExpensive());
- tracing::ScopedRegion region(tracing::EventCategory::kCompute,
- op_kernel->name());
-
- op_kernel->Compute(context);
-}
-
Allocator* ThreadPoolDevice::GetAllocator(AllocatorAttributes attr) {
return allocator_;
}
diff --git a/tensorflow/core/common_runtime/threadpool_device.h b/tensorflow/core/common_runtime/threadpool_device.h
index afc5d15ebc..51bd038a1c 100644
--- a/tensorflow/core/common_runtime/threadpool_device.h
+++ b/tensorflow/core/common_runtime/threadpool_device.h
@@ -29,7 +29,6 @@ class ThreadPoolDevice : public LocalDevice {
Allocator* allocator);
~ThreadPoolDevice() override;
- void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
Allocator* GetAllocator(AllocatorAttributes attr) override;
Allocator* GetScopedAllocator(AllocatorAttributes attr,
int64 step_id) override;
diff --git a/tensorflow/core/common_runtime/tracing_device.h b/tensorflow/core/common_runtime/tracing_device.h
new file mode 100644
index 0000000000..e1b163074f
--- /dev/null
+++ b/tensorflow/core/common_runtime/tracing_device.h
@@ -0,0 +1,60 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_TRACING_DEVICE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_TRACING_DEVICE_H_
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/tracing.h"
+
+namespace tensorflow {
+
+namespace test {
+class Benchmark;
+}
+struct SessionOptions;
+
+// This class implements tracing functionality that is shared by its subclasses
+// (including ThreadPoolDevice and XlaDevice).
+class TracingDevice : public Device {
+ public:
+ TracingDevice(Env* env, const DeviceAttributes& attributes)
+ : Device(env, attributes) {}
+
+ void Compute(OpKernel* op_kernel, OpKernelContext* context) override {
+ const tracing::TraceCollector* trace_collector =
+ tracing::GetTraceCollector();
+ if (TF_PREDICT_FALSE(
+ (trace_collector &&
+ trace_collector->IsEnabled(op_kernel->IsExpensive())) ||
+ tracing::GetEventCollector(tracing::EventCategory::kCompute))) {
+ const string& op_name = op_kernel->name();
+ tracing::ScopedActivity activity(op_name, op_kernel->type_string(),
+ op_kernel->IsExpensive());
+ tracing::ScopedRegion region(tracing::EventCategory::kCompute, op_name);
+ op_kernel->Compute(context);
+ } else {
+ op_kernel->Compute(context);
+ }
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(TracingDevice);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_TRACING_DEVICE_H_
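With tracing hoisted into this base class, the ThreadPoolDevice::Compute override removed above becomes redundant: any device whose inheritance chain reaches TracingDevice gets the traced Compute() path for free. A rough sketch with a hypothetical device type:

// Hypothetical device; it inherits the traced Compute() from TracingDevice.
class ToyDevice : public TracingDevice {
 public:
  ToyDevice(Env* env, const DeviceAttributes& attributes)
      : TracingDevice(env, attributes) {}

  Allocator* GetAllocator(AllocatorAttributes attr) override {
    return cpu_allocator();  // sketch: hand out the process CPU allocator
  }
  Status Sync() override { return Status::OK(); }  // CPU work is synchronous
};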
diff --git a/tensorflow/core/common_runtime/visitable_allocator.h b/tensorflow/core/common_runtime/visitable_allocator.h
index 8edf922d11..ae0563a96a 100644
--- a/tensorflow/core/common_runtime/visitable_allocator.h
+++ b/tensorflow/core/common_runtime/visitable_allocator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
-#define TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
#include <functional>
#include "tensorflow/core/framework/allocator.h"
@@ -76,4 +76,4 @@ class TrackingVisitableAllocator : public TrackingAllocator,
VisitableAllocator* allocator_;
};
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
diff --git a/tensorflow/core/debug/debug_callback_registry.h b/tensorflow/core/debug/debug_callback_registry.h
index 8f08c656c2..bcd4ddc50c 100644
--- a/tensorflow/core/debug/debug_callback_registry.h
+++ b/tensorflow/core/debug/debug_callback_registry.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_DEBUG_CALLBACK_REGISTRY_H_
-#define TENSORFLOW_DEBUG_CALLBACK_REGISTRY_H_
+#ifndef TENSORFLOW_CORE_DEBUG_DEBUG_CALLBACK_REGISTRY_H_
+#define TENSORFLOW_CORE_DEBUG_DEBUG_CALLBACK_REGISTRY_H_
#include <functional>
#include <map>
@@ -68,4 +68,4 @@ class DebugCallbackRegistry {
} // namespace tensorflow
-#endif // TENSORFLOW_DEBUG_CALLBACK_REGISTRY_H_
+#endif // TENSORFLOW_CORE_DEBUG_DEBUG_CALLBACK_REGISTRY_H_
diff --git a/tensorflow/core/debug/debug_graph_utils.cc b/tensorflow/core/debug/debug_graph_utils.cc
index 7641edea52..5fc95a8f20 100644
--- a/tensorflow/core/debug/debug_graph_utils.cc
+++ b/tensorflow/core/debug/debug_graph_utils.cc
@@ -356,8 +356,8 @@ Status DebugNodeInserter::ParseDebugOpName(
"Malformed attributes in debug op name \"", debug_op_name, "\"");
}
- const string key = std::string(seg.substr(0, eq_index));
- const string value = std::string(
+ const string key(seg.substr(0, eq_index));
+ const string value(
seg.substr(eq_index + 1, attribute_seg.size() - eq_index - 1));
if (key.empty() || value.empty()) {
return errors::InvalidArgument(
diff --git a/tensorflow/core/debug/debug_graph_utils.h b/tensorflow/core/debug/debug_graph_utils.h
index 64deff1f00..86dc90a134 100644
--- a/tensorflow/core/debug/debug_graph_utils.h
+++ b/tensorflow/core/debug/debug_graph_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_DEBUG_NODE_INSERTER_H_
-#define TENSORFLOW_DEBUG_NODE_INSERTER_H_
+#ifndef TENSORFLOW_CORE_DEBUG_DEBUG_GRAPH_UTILS_H_
+#define TENSORFLOW_CORE_DEBUG_DEBUG_GRAPH_UTILS_H_
#include <unordered_map>
#include <vector>
@@ -123,4 +123,4 @@ class DebugNodeInserter {
};
} // namespace tensorflow
-#endif // TENSORFLOW_DEBUG_NODE_INSERTER_H_
+#endif // TENSORFLOW_CORE_DEBUG_DEBUG_GRAPH_UTILS_H_
diff --git a/tensorflow/core/debug/debug_grpc_testlib.h b/tensorflow/core/debug/debug_grpc_testlib.h
index 8d3c9ff575..93376613b6 100644
--- a/tensorflow/core/debug/debug_grpc_testlib.h
+++ b/tensorflow/core/debug/debug_grpc_testlib.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_DEBUG_GRPC_TESTLIB_H_
-#define TENSORFLOW_DEBUG_GRPC_TESTLIB_H_
+#ifndef TENSORFLOW_CORE_DEBUG_DEBUG_GRPC_TESTLIB_H_
+#define TENSORFLOW_CORE_DEBUG_DEBUG_GRPC_TESTLIB_H_
#include <atomic>
#include <unordered_set>
@@ -84,4 +84,4 @@ bool PollTillFirstRequestSucceeds(const string& server_url,
} // namespace tensorflow
-#endif // TENSORFLOW_DEBUG_GRPC_TESTLIB_H_
+#endif // TENSORFLOW_CORE_DEBUG_DEBUG_GRPC_TESTLIB_H_
diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc
index 9e8002d490..6994dec3b5 100644
--- a/tensorflow/core/debug/debug_io_utils.cc
+++ b/tensorflow/core/debug/debug_io_utils.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include <stddef.h>
#include <string.h>
#include <cmath>
+#include <cstdlib>
+#include <cstring>
#include <limits>
#include <utility>
#include <vector>
@@ -399,8 +401,8 @@ Status DebugIO::PublishDebugMetadata(
strings::Printf("%.14lld", session_run_index))),
Env::Default()->NowMicros());
status.Update(DebugFileIO::DumpEventProtoToFile(
- event, std::string(io::Dirname(core_metadata_path)),
- std::string(io::Basename(core_metadata_path))));
+ event, string(io::Dirname(core_metadata_path)),
+ string(io::Basename(core_metadata_path))));
}
}
@@ -418,6 +420,19 @@ Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
if (str_util::Lowercase(url).find(kFileURLScheme) == 0) {
const string dump_root_dir = url.substr(strlen(kFileURLScheme));
+ const int64 tensorBytes =
+ tensor.IsInitialized() ? tensor.TotalBytes() : 0;
+ if (!DebugFileIO::requestDiskByteUsage(tensorBytes)) {
+ return errors::ResourceExhausted(
+ "TensorFlow Debugger has exhausted file-system byte-size "
+ "allowance (",
+ DebugFileIO::globalDiskBytesLimit, "), therefore it cannot ",
+ "dump an additional ", tensorBytes, " byte(s) of tensor data ",
+ "for the debug tensor ", debug_node_key.node_name, ":",
+ debug_node_key.output_slot, ". You may use the environment ",
+ "variable TFDBG_DISK_BYTES_LIMIT to set a higher limit.");
+ }
+
Status s = DebugFileIO::DumpTensorToDir(
debug_node_key, tensor, wall_time_us, dump_root_dir, nullptr);
if (!s.ok()) {
@@ -632,8 +647,8 @@ Status DebugFileIO::DumpTensorToEventFile(const DebugNodeKey& debug_node_key,
std::vector<Event> events;
TF_RETURN_IF_ERROR(
WrapTensorAsEvents(debug_node_key, tensor, wall_time_us, 0, &events));
- return DumpEventProtoToFile(events[0], std::string(io::Dirname(file_path)),
- std::string(io::Basename(file_path)));
+ return DumpEventProtoToFile(events[0], string(io::Dirname(file_path)),
+ string(io::Basename(file_path)));
}
Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) {
@@ -642,7 +657,7 @@ Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) {
return Status::OK();
}
- string parent_dir = std::string(io::Dirname(dir));
+ string parent_dir(io::Dirname(dir));
if (!env->FileExists(parent_dir).ok()) {
// The parent path does not exist yet, create it first.
Status s = RecursiveCreateDir(env, parent_dir); // Recursive call
@@ -670,6 +685,42 @@ Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) {
}
}
+// Default total disk usage limit: 100 GBytes
+const uint64 DebugFileIO::defaultGlobalDiskBytesLimit = 107374182400L;
+uint64 DebugFileIO::globalDiskBytesLimit = 0;
+uint64 DebugFileIO::diskBytesUsed = 0;
+
+mutex DebugFileIO::bytes_mu(LINKER_INITIALIZED);
+
+bool DebugFileIO::requestDiskByteUsage(uint64 bytes) {
+ mutex_lock l(bytes_mu);
+ if (globalDiskBytesLimit == 0) {
+ const char* env_tfdbg_disk_bytes_limit = getenv("TFDBG_DISK_BYTES_LIMIT");
+ if (env_tfdbg_disk_bytes_limit == nullptr ||
+ strlen(env_tfdbg_disk_bytes_limit) == 0) {
+ globalDiskBytesLimit = defaultGlobalDiskBytesLimit;
+ } else {
+ strings::safe_strtou64(string(env_tfdbg_disk_bytes_limit),
+ &globalDiskBytesLimit);
+ }
+ }
+
+ if (bytes == 0) {
+ return true;
+ }
+ if (diskBytesUsed + bytes < globalDiskBytesLimit) {
+ diskBytesUsed += bytes;
+ return true;
+ } else {
+ return false;
+ }
+}
+
+void DebugFileIO::resetDiskByteUsage() {
+ mutex_lock l(bytes_mu);
+ diskBytesUsed = 0;
+}
+
#ifndef PLATFORM_WINDOWS
DebugGrpcChannel::DebugGrpcChannel(const string& server_stream_addr)
: server_stream_addr_(server_stream_addr),
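The tests further below exercise this gate end to end; as a condensed sketch, it can also be driven directly (the 1 MB limit and `tensor_bytes` value here are made up for illustration):

// Sketch: cap tfdbg dumps for this process, then ask for quota before dumping.
setenv("TFDBG_DISK_BYTES_LIMIT", "1048576", 1);  // 1 MB, illustrative value
DebugFileIO::resetDiskByteUsage();
DebugFileIO::globalDiskBytesLimit = 0;  // 0 forces a re-read of the env var

const uint64 tensor_bytes = 4096;  // hypothetical tensor size
if (!DebugFileIO::requestDiskByteUsage(tensor_bytes)) {
  // Over budget: skip the dump; PublishDebugTensor surfaces ResourceExhausted.
}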
diff --git a/tensorflow/core/debug/debug_io_utils.h b/tensorflow/core/debug/debug_io_utils.h
index c974a47051..5390ce408a 100644
--- a/tensorflow/core/debug/debug_io_utils.h
+++ b/tensorflow/core/debug/debug_io_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_DEBUG_IO_UTILS_H_
-#define TENSORFLOW_DEBUG_IO_UTILS_H_
+#ifndef TENSORFLOW_CORE_DEBUG_DEBUG_IO_UTILS_H_
+#define TENSORFLOW_CORE_DEBUG_DEBUG_IO_UTILS_H_
#include <cstddef>
#include <functional>
@@ -193,6 +193,26 @@ class DebugFileIO {
const string& dir_name,
const string& file_name);
+ // Requests permission to dump an additional `bytes` bytes to the file
+ // system.
+ //
+ // Does not actually dump the bytes, but instead just performs the
+ // bookkeeping necessary to prevent the total dumped amount of data from
+ // exceeding the limit (default: 100 GBytes; can be overridden through the
+ // environment variable TFDBG_DISK_BYTES_LIMIT).
+ //
+ // Args:
+ // bytes: Number of bytes to request.
+ //
+ // Returns:
+ // Whether the request is approved given the total dumping
+ // limit.
+ static bool requestDiskByteUsage(uint64 bytes);
+
+ // Reset the disk byte usage to zero.
+ static void resetDiskByteUsage();
+
+ static uint64 globalDiskBytesLimit;
+
private:
// Encapsulates the Tensor in an Event protobuf and write it to file.
static Status DumpTensorToEventFile(const DebugNodeKey& debug_node_key,
@@ -204,6 +224,15 @@ class DebugFileIO {
// TODO(cais): Replace with shared implementation once http://b/30497715 is
// fixed.
static Status RecursiveCreateDir(Env* env, const string& dir);
+
+ // Tracks how much disk has been used so far.
+ static uint64 diskBytesUsed;
+ // Mutex for thread-safe access to diskBytesUsed.
+ static mutex bytes_mu;
+ // Default limit for the disk space.
+ static const uint64 defaultGlobalDiskBytesLimit;
+
+ friend class DiskUsageLimitTest;
};
} // namespace tensorflow
@@ -398,4 +427,4 @@ class DebugGrpcIO {
} // namespace tensorflow
#endif // #ifndef(PLATFORM_WINDOWS)
-#endif // TENSORFLOW_DEBUG_IO_UTILS_H_
+#endif // TENSORFLOW_CORE_DEBUG_DEBUG_IO_UTILS_H_
diff --git a/tensorflow/core/debug/debug_io_utils_test.cc b/tensorflow/core/debug/debug_io_utils_test.cc
index 0807a85b8b..82e0ae5edb 100644
--- a/tensorflow/core/debug/debug_io_utils_test.cc
+++ b/tensorflow/core/debug/debug_io_utils_test.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <cstdlib>
#include <unordered_set>
#include "tensorflow/core/debug/debug_io_utils.h"
@@ -454,5 +455,50 @@ TEST_F(DebugIOUtilsTest, PublishTensorConcurrentlyToPartiallyOverlappingPaths) {
}
}
+class DiskUsageLimitTest : public ::testing::Test {
+ public:
+ void Initialize() {
+ setenv("TFDBG_DISK_BYTES_LIMIT", "", 1);
+ DebugFileIO::resetDiskByteUsage();
+ DebugFileIO::globalDiskBytesLimit = 0;
+ }
+};
+
+TEST_F(DiskUsageLimitTest, RequestWithZeroByteIsOkay) {
+ Initialize();
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(0L));
+}
+
+TEST_F(DiskUsageLimitTest, ExceedingLimitAfterOneCall) {
+ Initialize();
+ ASSERT_FALSE(DebugFileIO::requestDiskByteUsage(100L * 1024L * 1024L * 1024L));
+}
+
+TEST_F(DiskUsageLimitTest, ExceedingLimitAfterTwoCalls) {
+ Initialize();
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(50L * 1024L * 1024L * 1024L));
+ ASSERT_FALSE(DebugFileIO::requestDiskByteUsage(50L * 1024L * 1024L * 1024L));
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(1024L));
+}
+
+TEST_F(DiskUsageLimitTest, ResetDiskByteUsageWorks) {
+ Initialize();
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(50L * 1024L * 1024L * 1024L));
+ ASSERT_FALSE(DebugFileIO::requestDiskByteUsage(50L * 1024L * 1024L * 1024L));
+ DebugFileIO::resetDiskByteUsage();
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(50L * 1024L * 1024L * 1024L));
+}
+
+TEST_F(DiskUsageLimitTest, CustomEnvVarIsObeyed) {
+ Initialize();
+ setenv("TFDBG_DISK_BYTES_LIMIT", "1024", 1);
+ ASSERT_FALSE(DebugFileIO::requestDiskByteUsage(1024L));
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(1000L));
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(23L));
+ ASSERT_FALSE(DebugFileIO::requestDiskByteUsage(1L));
+ DebugFileIO::resetDiskByteUsage();
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(1023L));
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/debug/debug_node_key.h b/tensorflow/core/debug/debug_node_key.h
index b46054c013..eaeb369790 100644
--- a/tensorflow/core/debug/debug_node_key.h
+++ b/tensorflow/core/debug/debug_node_key.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_DEBUG_NODE_KEY_H_
-#define TENSORFLOW_DEBUG_NODE_KEY_H_
+#ifndef TENSORFLOW_CORE_DEBUG_DEBUG_NODE_KEY_H_
+#define TENSORFLOW_CORE_DEBUG_DEBUG_NODE_KEY_H_
#include <string>
@@ -48,4 +48,4 @@ struct DebugNodeKey {
} // namespace tensorflow
-#endif // TENSORFLOW_DEBUG_NODE_KEY_H_
+#endif // TENSORFLOW_CORE_DEBUG_DEBUG_NODE_KEY_H_
diff --git a/tensorflow/core/debug/debugger_state_impl.cc b/tensorflow/core/debug/debugger_state_impl.cc
index 2f5aaf93fa..79798f9392 100644
--- a/tensorflow/core/debug/debugger_state_impl.cc
+++ b/tensorflow/core/debug/debugger_state_impl.cc
@@ -27,6 +27,9 @@ DebuggerState::DebuggerState(const DebugOptions& debug_options) {
debug_urls_.insert(url);
}
}
+ if (debug_options.reset_disk_byte_usage()) {
+ DebugFileIO::resetDiskByteUsage();
+ }
}
DebuggerState::~DebuggerState() {
diff --git a/tensorflow/core/debug/debugger_state_impl.h b/tensorflow/core/debug/debugger_state_impl.h
index 52e2663d08..8f6e53fafe 100644
--- a/tensorflow/core/debug/debugger_state_impl.h
+++ b/tensorflow/core/debug/debugger_state_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_DEBUGGER_STATE_IMPL_H_
-#define TENSORFLOW_DEBUGGER_STATE_IMPL_H_
+#ifndef TENSORFLOW_CORE_DEBUG_DEBUGGER_STATE_IMPL_H_
+#define TENSORFLOW_CORE_DEBUG_DEBUGGER_STATE_IMPL_H_
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
@@ -58,4 +58,4 @@ class DebugGraphDecorator : public DebugGraphDecoratorInterface {
} // namespace tensorflow
-#endif // TENSORFLOW_DEBUGGER_STATE_IMPL_H_
+#endif // TENSORFLOW_CORE_DEBUG_DEBUGGER_STATE_IMPL_H_
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index b2192c5a80..37029f3f1a 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -562,6 +562,7 @@ cc_library(
deps = [
":worker_cache",
":worker_interface",
+ "//tensorflow/core:framework",
],
)
diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc
index a48f734d3e..269f620e42 100644
--- a/tensorflow/core/distributed_runtime/master.cc
+++ b/tensorflow/core/distributed_runtime/master.cc
@@ -53,6 +53,7 @@ limitations under the License.
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@@ -167,13 +168,55 @@ class DeviceFinder {
}
// Enumerates all known workers' targets. A target name is a
// prefix of a device name. E.g., /job:mnist/replica:0/task:10.
- CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
- const string& local_device_name = env_->local_devices[0]->name();
- std::vector<string> workers;
- worker_cache->ListWorkers(&workers);
if (filters_.empty()) {
+ // If no filters were specified, we list all known workers in
+ // `worker_cache`.
+ std::vector<string> workers;
+ worker_cache->ListWorkers(&workers);
std::swap(workers, targets_);
} else {
+ // When applying filters, we must include the local worker, even if it
+ // does not match any of the filters.
+ CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
+ const string& local_device_name = env_->local_devices[0]->name();
+ DeviceNameUtils::ParsedName local_parsed_name;
+ CHECK(DeviceNameUtils::ParseFullName(local_device_name,
+ &local_parsed_name));
+ bool all_filters_have_job = true;
+ std::unordered_set<string> filter_job_names({local_parsed_name.job});
+ for (const DeviceNameUtils::ParsedName& filter : filters_) {
+ all_filters_have_job = all_filters_have_job && filter.has_job;
+ if (filter.has_job) {
+ filter_job_names.insert(filter.job);
+ }
+ }
+
+ std::vector<string> workers;
+ if (all_filters_have_job) {
+ // If all of the device filters have a job specified, then we only need
+ // to list the workers in the jobs named in the filter, because a worker
+ // in any other job would not match any filter.
+ for (const string& job_name : filter_job_names) {
+ VLOG(2) << "Selectively listing workers in job: " << job_name;
+ std::vector<string> workers_in_job;
+ worker_cache->ListWorkersInJob(job_name, &workers_in_job);
+ workers.insert(workers.end(), workers_in_job.begin(),
+ workers_in_job.end());
+ }
+ } else {
+ // If any of the device filters does not have a job specified, then we
+ // must list the workers from all jobs.
+ VLOG(2) << "Listing workers in all jobs because some device "
+ << "filter has no job specified. Filters were:";
+ if (device_filters.empty()) {
+ VLOG(2) << "- <NO FILTERS>";
+ } else {
+ for (const string& filter : device_filters) {
+ VLOG(2) << "- " << filter;
+ }
+ }
+ worker_cache->ListWorkers(&workers);
+ }
for (const string& name : workers) {
if (MatchFilters(name) ||
DeviceNameUtils::IsSameAddressSpace(name, local_device_name)) {
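In other words, the selective path is taken only when every filter pins down a job. A standalone restatement of that decision, with a hypothetical helper name:

// Sketch of the decision above: one jobless filter forces the full listing.
bool AllFiltersHaveJob(
    const std::vector<DeviceNameUtils::ParsedName>& filters) {
  for (const DeviceNameUtils::ParsedName& filter : filters) {
    if (!filter.has_job) return false;  // must fall back to ListWorkers()
  }
  return true;  // safe to query ListWorkersInJob() per filtered job
}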
diff --git a/tensorflow/core/distributed_runtime/master_env.h b/tensorflow/core/distributed_runtime/master_env.h
index da26c42aca..837ccd1dd4 100644
--- a/tensorflow/core/distributed_runtime/master_env.h
+++ b/tensorflow/core/distributed_runtime/master_env.h
@@ -99,4 +99,4 @@ struct MasterEnv {
} // end namespace tensorflow
-#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index d34ca53f73..8e9eec1ed9 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -449,7 +449,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
*c->req.mutable_graph_options() = session_opts_.config.graph_options();
*c->req.mutable_debug_options() =
callable_opts_.run_options().debug_options();
- c->req.set_collective_graph_key(bg_opts_.collective_graph_key);
+ c->req.set_collective_graph_key(client_graph()->collective_graph_key);
VLOG(2) << "Register " << c->req.graph_def().DebugString();
auto cb = [c, &done](const Status& s) {
c->status = s;
@@ -615,7 +615,7 @@ Status MasterSession::ReffedClientGraph::RunPartitionsHelper(
// inadvertently slowing down the normal run path.
if (is_partial_) {
for (const auto& name_index : feeds) {
- const auto iter = part.feed_key.find(std::string(name_index.first));
+ const auto iter = part.feed_key.find(string(name_index.first));
if (iter == part.feed_key.end()) {
// The provided feed must be for a different partition.
continue;
@@ -959,7 +959,7 @@ Status MasterSession::ReffedClientGraph::CheckFetches(
// Skip if already fed.
if (input.second) continue;
TensorId id(ParseTensorName(input.first));
- const Node* n = execution_state->get_node_by_name(std::string(id.first));
+ const Node* n = execution_state->get_node_by_name(string(id.first));
if (n == nullptr) {
return errors::NotFound("Feed ", input.first, ": not found");
}
@@ -975,7 +975,7 @@ Status MasterSession::ReffedClientGraph::CheckFetches(
for (size_t i = 0; i < req.num_fetches(); ++i) {
const string& fetch = req.fetch_name(i);
const TensorId id(ParseTensorName(fetch));
- const Node* n = execution_state->get_node_by_name(std::string(id.first));
+ const Node* n = execution_state->get_node_by_name(string(id.first));
if (n == nullptr) {
return errors::NotFound("Fetch ", fetch, ": not found");
}
@@ -1111,10 +1111,6 @@ uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
}
- if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) {
- h = Hash64Combine(opts.collective_graph_key, h);
- }
-
return h;
}
@@ -1788,10 +1784,10 @@ Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
Status s = run_status;
if (s.ok()) {
pss->end_micros = Env::Default()->NowMicros();
- if (rcg->build_graph_options().collective_graph_key !=
+ if (rcg->client_graph()->collective_graph_key !=
BuildGraphOptions::kNoCollectiveGraphKey) {
env_->collective_executor_mgr->RetireStepId(
- rcg->build_graph_options().collective_graph_key, step_id);
+ rcg->client_graph()->collective_graph_key, step_id);
}
// Schedule post-processing and cleanup to be done asynchronously.
rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
@@ -1850,7 +1846,7 @@ Status MasterSession::DoRunWithLocalExecution(
// Keeps the highest 8 bits 0x01: we reserve some bits of the
// step_id for future use.
- uint64 step_id = NewStepId(bgopts.collective_graph_key);
+ uint64 step_id = NewStepId(rcg->client_graph()->collective_graph_key);
TRACEPRINTF("stepid %llu", step_id);
std::unique_ptr<ProfileHandler> ph;
@@ -1914,8 +1910,7 @@ Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
// Prepare.
int64 count = rcg->get_and_increment_execution_count();
- const uint64 step_id =
- NewStepId(rcg->build_graph_options().collective_graph_key);
+ const uint64 step_id = NewStepId(rcg->client_graph()->collective_graph_key);
TRACEPRINTF("stepid %llu", step_id);
const RunOptions& run_options = rcg->callable_options().run_options();
diff --git a/tensorflow/core/distributed_runtime/message_wrappers.h b/tensorflow/core/distributed_runtime/message_wrappers.h
index 72a0c7edd8..474ac0e186 100644
--- a/tensorflow/core/distributed_runtime/message_wrappers.h
+++ b/tensorflow/core/distributed_runtime/message_wrappers.h
@@ -721,4 +721,4 @@ class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
} // namespace tensorflow
-#endif // TENSORFLOW
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
diff --git a/tensorflow/core/distributed_runtime/remote_device.cc b/tensorflow/core/distributed_runtime/remote_device.cc
index 15e5919c54..a043c5dee6 100644
--- a/tensorflow/core/distributed_runtime/remote_device.cc
+++ b/tensorflow/core/distributed_runtime/remote_device.cc
@@ -37,7 +37,7 @@ string GetLocalDeviceName(StringPiece fullname) {
auto pos = fullname.rfind('/');
CHECK_NE(pos, StringPiece::npos);
fullname.remove_prefix(pos + 1);
- return std::string(fullname);
+ return string(fullname);
}
class RemoteDevice : public Device {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
index b7eb3c9015..456c30ecf4 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
@@ -163,6 +163,13 @@ class MultiGrpcChannelCache : public CachingGrpcChannelCache {
}
}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) override {
+ for (GrpcChannelCache* cache : caches_) {
+ cache->ListWorkersInJob(job_name, workers);
+ }
+ }
+
string TranslateTask(const string& target) override {
mutex_lock l(mu_); // could use reader lock
GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
@@ -223,6 +230,13 @@ class SparseGrpcChannelCache : public CachingGrpcChannelCache {
}
}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) override {
+ if (job_name == job_id_) {
+ ListWorkers(workers);
+ }
+ }
+
string TranslateTask(const string& target) override {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
index 4861cdb691..6fa99d7b14 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
@@ -66,6 +66,8 @@ class GrpcChannelCache {
// /job:<job identifier>/task:<task id>
// e.g. /job:mnist/task:2
virtual void ListWorkers(std::vector<string>* workers) = 0;
+ virtual void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) = 0;
// If found, returns a gRPC channel that is connected to the remote
// worker named by 'target'. 'target' is of the following
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
index f07a5a0974..a814ef85e2 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
@@ -89,13 +89,33 @@ TEST(GrpcChannelTest, HostPorts) {
EXPECT_NE(d_4_1.get(), e_5_2.get());
}
- std::vector<string> workers;
- cc->ListWorkers(&workers);
- EXPECT_EQ(std::vector<string>(
- {"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
- "/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
- "/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
- workers);
+ {
+ std::vector<string> workers;
+ cc->ListWorkers(&workers);
+ EXPECT_EQ(
+ std::vector<string>(
+ {"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
+ "/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
+ "/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
+ workers);
+ }
+
+ {
+ std::vector<string> workers;
+ cc->ListWorkersInJob("mnist", &workers);
+ EXPECT_EQ(
+ std::vector<string>(
+ {"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
+ "/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
+ "/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
+ workers);
+ }
+
+ {
+ std::vector<string> workers;
+ cc->ListWorkersInJob("other", &workers);
+ EXPECT_TRUE(workers.empty());
+ }
}
TEST(GrpcChannelTest, SparseHostPorts) {
@@ -135,13 +155,30 @@ TEST(GrpcChannelTest, SparseHostPorts) {
EXPECT_NE(d_4_1.get(), e_5_2.get());
}
- std::vector<string> workers;
- cc->ListWorkers(&workers);
- std::sort(workers.begin(), workers.end());
- EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
- "/job:mnist/replica:0/task:3",
- "/job:mnist/replica:0/task:4"}),
- workers);
+ {
+ std::vector<string> workers;
+ cc->ListWorkers(&workers);
+ std::sort(workers.begin(), workers.end());
+ EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
+ "/job:mnist/replica:0/task:3",
+ "/job:mnist/replica:0/task:4"}),
+ workers);
+ }
+
+ {
+ std::vector<string> workers;
+ cc->ListWorkersInJob("mnist", &workers);
+ EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
+ "/job:mnist/replica:0/task:3",
+ "/job:mnist/replica:0/task:4"}),
+ workers);
+ }
+
+ {
+ std::vector<string> workers;
+ cc->ListWorkersInJob("other", &workers);
+ EXPECT_TRUE(workers.empty());
+ }
}
TEST(GrpcChannelTest, NewHostPortGrpcChannelValidation) {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h
index 709c3833e7..b85c1dc5b4 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
-#define TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
#include <memory>
@@ -35,4 +35,4 @@ WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
} // namespace tensorflow
-#endif // TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index bcd46a4c06..c4f2247145 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -190,6 +190,8 @@ Status GrpcServer::Init(
builder.SetMaxMessageSize(std::numeric_limits<int32>::max());
builder.SetOption(
std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption));
+ // Allow subclasses to specify more args to pass to the gRPC server.
+ MaybeMutateBuilder(&builder);
master_impl_ = CreateMaster(&master_env_);
master_service_ = NewGrpcMasterService(master_impl_.get(), config, &builder);
worker_impl_ =
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
index 3366246afb..7979e96d3e 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
@@ -59,6 +59,9 @@ typedef std::function<std::unique_ptr<GrpcWorker>(WorkerEnv*)>
class GrpcServer : public ServerInterface {
protected:
GrpcServer(const ServerDef& server_def, Env* env);
+ // Allows subclasses to override this and provide custom args to the
+ // server before it is constructed. The default behavior is to do nothing.
+ virtual void MaybeMutateBuilder(::grpc::ServerBuilder* builder) {}
public:
static Status Create(const ServerDef& server_def, Env* env,
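A minimal sketch of this extension point, assuming a hypothetical subclass; AddChannelArgument is a standard ::grpc::ServerBuilder method:

// Hypothetical subclass that injects an extra gRPC server argument.
class TunedGrpcServer : public GrpcServer {
 public:
  using GrpcServer::GrpcServer;

 protected:
  void MaybeMutateBuilder(::grpc::ServerBuilder* builder) override {
    // Standard gRPC builder call; the value is illustrative.
    builder->AddChannelArgument("grpc.max_concurrent_streams", 1000);
  }
};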
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h
index d5baaae353..98164e750b 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_
-#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_TESTLIB_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_TESTLIB_H_
#include <memory>
#include <string>
@@ -71,4 +71,4 @@ class TestCluster {
} // end namespace test
} // end namespace tensorflow
-#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_TESTLIB_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
index b9f21ea211..e1541db69b 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
@@ -54,6 +54,11 @@ class GrpcWorkerCache : public WorkerCachePartial {
channel_cache_->ListWorkers(workers);
}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const override {
+ channel_cache_->ListWorkersInJob(job_name, workers);
+ }
+
WorkerInterface* CreateWorker(const string& target) override {
if (target == local_target_) {
return local_worker_;
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
index 25ff6512a0..b070dd13dd 100644
--- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
@@ -50,6 +50,8 @@ namespace {
// Fake cache implementation for WorkerEnv.
class DummyWorkerCache : public WorkerCacheInterface {
void ListWorkers(std::vector<string>* workers) const override {}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const override {}
WorkerInterface* CreateWorker(const string& target) override {
return nullptr;
}
diff --git a/tensorflow/core/distributed_runtime/tensor_coding.h b/tensorflow/core/distributed_runtime/tensor_coding.h
index bae4ec794c..4c34297990 100644
--- a/tensorflow/core/distributed_runtime/tensor_coding.h
+++ b/tensorflow/core/distributed_runtime/tensor_coding.h
@@ -87,6 +87,9 @@ class TensorResponse {
// modified.
const RecvTensorResponse& metadata() const { return meta_; }
+ // Returns a pointer to the device hosting the tensor.
+ DeviceBase* device() const { return device_; }
+
private:
bool ParseTensorSubmessage(protobuf::io::CodedInputStream* input,
TensorProto* tensor_meta);
diff --git a/tensorflow/core/distributed_runtime/test_utils.h b/tensorflow/core/distributed_runtime/test_utils.h
index 48d83845dd..88a97da34d 100644
--- a/tensorflow/core/distributed_runtime/test_utils.h
+++ b/tensorflow/core/distributed_runtime/test_utils.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
+#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@@ -138,6 +139,19 @@ class TestWorkerCache : public WorkerCacheInterface {
}
}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const override {
+ workers->clear();
+ for (auto it : workers_) {
+ DeviceNameUtils::ParsedName device_name;
+ CHECK(DeviceNameUtils::ParseFullName(it.first, &device_name));
+ CHECK(device_name.has_job);
+ if (job_name == device_name.job) {
+ workers->push_back(it.first);
+ }
+ }
+ }
+
WorkerInterface* CreateWorker(const string& target) override {
auto it = workers_.find(target);
if (it != workers_.end()) {
diff --git a/tensorflow/core/distributed_runtime/worker_cache.h b/tensorflow/core/distributed_runtime/worker_cache.h
index 8521f8956b..0c8575b4d5 100644
--- a/tensorflow/core/distributed_runtime/worker_cache.h
+++ b/tensorflow/core/distributed_runtime/worker_cache.h
@@ -36,6 +36,8 @@ class WorkerCacheInterface {
// Updates *workers with strings naming the remote worker tasks to
// which open channels have been established.
virtual void ListWorkers(std::vector<string>* workers) const = 0;
+ virtual void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const = 0;
// If "target" names a remote task for which an RPC channel exists
// or can be constructed, returns a pointer to a WorkerInterface object
diff --git a/tensorflow/core/distributed_runtime/worker_cache_wrapper.h b/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
index 43c3b6285b..1f309b4361 100644
--- a/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
+++ b/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
@@ -32,6 +32,10 @@ class WorkerCacheWrapper : public WorkerCacheInterface {
virtual void ListWorkers(std::vector<string>* workers) const {
return wrapped_->ListWorkers(workers);
}
+ virtual void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const {
+ return wrapped_->ListWorkersInJob(job_name, workers);
+ }
// If "target" names a remote task for which an RPC channel exists
// or can be constructed, returns a pointer to a WorkerInterface object
diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc
index ca6dc1b1de..c7d0c6b7f3 100644
--- a/tensorflow/core/distributed_runtime/worker_session.cc
+++ b/tensorflow/core/distributed_runtime/worker_session.cc
@@ -35,6 +35,11 @@ class WorkerFreeListCache : public WorkerCacheInterface {
wrapped_->ListWorkers(workers);
}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const override {
+ wrapped_->ListWorkersInJob(job_name, workers);
+ }
+
WorkerInterface* CreateWorker(const string& target) override {
mutex_lock l(mu_);
auto p = workers_.find(target);
diff --git a/tensorflow/core/example/example_parser_configuration.h b/tensorflow/core/example/example_parser_configuration.h
index 3d06bd55e2..8bbed28471 100644
--- a/tensorflow/core/example/example_parser_configuration.h
+++ b/tensorflow/core/example/example_parser_configuration.h
@@ -53,4 +53,4 @@ Status ExampleParserConfigurationProtoToFeatureVectors(
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_EXAMPLE_EXAMPLE_PARSE_CONFIGURATION_H_
+#endif // TENSORFLOW_CORE_EXAMPLE_EXAMPLE_PARSER_CONFIGURATION_H_
diff --git a/tensorflow/core/example/feature_util.h b/tensorflow/core/example/feature_util.h
index 2265498b5e..ec93b9aad9 100644
--- a/tensorflow/core/example/feature_util.h
+++ b/tensorflow/core/example/feature_util.h
@@ -97,8 +97,8 @@ limitations under the License.
// GetFeatureValues<FeatureType>(feature) -> RepeatedField<FeatureType>
// Returns values of the feature for the FeatureType.
-#ifndef TENSORFLOW_EXAMPLE_FEATURE_H_
-#define TENSORFLOW_EXAMPLE_FEATURE_H_
+#ifndef TENSORFLOW_CORE_EXAMPLE_FEATURE_UTIL_H_
+#define TENSORFLOW_CORE_EXAMPLE_FEATURE_UTIL_H_
#include <iterator>
#include <type_traits>
@@ -322,4 +322,4 @@ bool ExampleHasFeature(const string& key, const Example& example) {
}
} // namespace tensorflow
-#endif // TENSORFLOW_EXAMPLE_FEATURE_H_
+#endif // TENSORFLOW_CORE_EXAMPLE_FEATURE_UTIL_H_
diff --git a/tensorflow/core/framework/attr_value_util.h b/tensorflow/core/framework/attr_value_util.h
index 0da9b1081b..9fce488793 100644
--- a/tensorflow/core/framework/attr_value_util.h
+++ b/tensorflow/core/framework/attr_value_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
-#define TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_ATTR_VALUE_UTIL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_ATTR_VALUE_UTIL_H_
#include <functional>
#include <string>
@@ -126,4 +126,4 @@ bool SubstitutePlaceholders(const SubstituteFunc& substitute, AttrValue* value);
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_ATTR_VALUE_UTIL_H_
diff --git a/tensorflow/core/framework/bfloat16.h b/tensorflow/core/framework/bfloat16.h
index 2f79d0fa70..e9e94024f5 100644
--- a/tensorflow/core/framework/bfloat16.h
+++ b/tensorflow/core/framework/bfloat16.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_BFLOAT16_H_
-#define TENSORFLOW_FRAMEWORK_BFLOAT16_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_BFLOAT16_H_
+#define TENSORFLOW_CORE_FRAMEWORK_BFLOAT16_H_
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/platform/byte_order.h"
@@ -60,4 +60,4 @@ void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size);
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_BFLOAT16_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_BFLOAT16_H_
diff --git a/tensorflow/core/framework/cancellation.h b/tensorflow/core/framework/cancellation.h
index 90074c87b2..acdaaf6a90 100644
--- a/tensorflow/core/framework/cancellation.h
+++ b/tensorflow/core/framework/cancellation.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_CANCELLATION_H_
-#define TENSORFLOW_FRAMEWORK_CANCELLATION_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_
+#define TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_
#include <atomic>
#include <functional>
@@ -134,4 +134,4 @@ class CancellationManager {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_CANCELLATION_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_
diff --git a/tensorflow/core/framework/collective.cc b/tensorflow/core/framework/collective.cc
index d4ac50cbbe..4cb277d5a8 100644
--- a/tensorflow/core/framework/collective.cc
+++ b/tensorflow/core/framework/collective.cc
@@ -21,6 +21,31 @@ limitations under the License.
namespace tensorflow {
+namespace {
+// A RegistrationInfo object stores the registration details of a collective
+// implementation. `factory` is used to create instances of the collective
+// implementation.
+struct RegistrationInfo {
+ // This constructor also creates, and stores in `param_resolver_instance`,
+ // what is effectively a static instance of the collective implementation.
+ // During param resolution of collective ops we return this static instance.
+ // The actual op execution gets a fresh instance using `factory`.
+ RegistrationInfo(const string& n, CollectiveRegistry::Factory f)
+ : name(n),
+ factory(std::move(f)),
+ param_resolver_instance(this->factory()) {}
+ string name;
+ CollectiveRegistry::Factory factory;
+ CollectiveImplementationInterface* param_resolver_instance;
+};
+
+std::vector<RegistrationInfo>* MutableCollectiveRegistry() {
+ static std::vector<RegistrationInfo>* registry =
+ new std::vector<RegistrationInfo>;
+ return registry;
+}
+} // namespace
+
string CollGroupParams::ToString() const {
return strings::StrCat("CollGroupParams {group_key=", group_key,
" group_size=", group_size,
@@ -102,7 +127,8 @@ string CollectiveParams::ToString() const {
strings::StrAppend(&v, " ", instance.ToString());
strings::StrAppend(&v, " ", task.ToString());
strings::StrAppend(&v, " default_rank=", default_rank,
- " is_source=", is_source, " subdiv_rank={");
+ " is_source=", is_source, " source_rank=", source_rank,
+ " subdiv_rank={");
for (const auto& r : subdiv_rank) {
strings::StrAppend(&v, r, ",");
}
@@ -115,7 +141,81 @@ string CollectiveParams::ToString() const {
return ctx->params_;
}
+CollectiveContext::CollectiveContext(CollectiveExecutor* col_exec,
+ const DeviceMgr* dev_mgr,
+ OpKernelContext* ctx,
+ OpKernelContext::Params* op_params,
+ const CollectiveParams& col_params,
+ const string& exec_key, int64 step_id,
+ const Tensor* input, Tensor* output)
+ : col_exec(col_exec),
+ dev_mgr(dev_mgr),
+ op_ctx(ctx),
+ op_params(op_params),
+ col_params(col_params),
+ exec_key(exec_key),
+ step_id(step_id),
+ input(input),
+ output(output),
+ device(nullptr),
+ device_name(col_params.instance.device_names[col_params.default_rank]) {}
+
/*static*/
int64 CollectiveExecutor::kInvalidId = -1;
+/*static*/
+Status CollectiveRegistry::Lookup(
+ const string& collective_name,
+ CollectiveImplementationInterface** implementation) {
+ return LookupHelper(collective_name, implementation, false);
+}
+
+/*static*/
+Status CollectiveRegistry::LookupParamResolverInstance(
+ const string& collective_name,
+ CollectiveImplementationInterface** implementation) {
+ return LookupHelper(collective_name, implementation, true);
+}
+
+/*static*/
+void CollectiveRegistry::GetAll(
+ std::vector<CollectiveImplementationInterface*>* implementations) {
+ std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
+ for (const RegistrationInfo& reg_info : *registry)
+ implementations->emplace_back(reg_info.factory());
+}
+
+/*static*/
+Status CollectiveRegistry::Register(const string& collective_name,
+ Factory factory) {
+ std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
+ for (const RegistrationInfo& reg_info : *registry) {
+ if (reg_info.name == collective_name)
+ return errors::Internal("Already registered collective ",
+ collective_name);
+ }
+ registry->emplace_back(collective_name, std::move(factory));
+ return Status::OK();
+}
+
+/*static*/
+Status CollectiveRegistry::LookupHelper(
+ const string& collective_name,
+ CollectiveImplementationInterface** implementation, bool param_resolver) {
+ std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
+ for (const RegistrationInfo& reg_info : *registry) {
+ if (reg_info.name == collective_name) {
+ if (param_resolver) {
+ *implementation = reg_info.param_resolver_instance;
+ } else {
+ *implementation = reg_info.factory();
+ }
+ return Status::OK();
+ }
+ }
+ return errors::Internal(
+ "CollectiveRegistry::Lookup did not find collective implementation ",
+ collective_name);
+}
+
} // namespace tensorflow
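Registration itself is private to CollectiveRegistry (via the friend class CollectiveRegistration, declared in the header below), but lookup is public. A hedged sketch of consuming the registry, with "RingReduce" as an example collective name:

// Sketch: fetch a fresh instance and the shared param-resolver instance.
CollectiveImplementationInterface* impl = nullptr;
Status s = CollectiveRegistry::Lookup("RingReduce", &impl);
if (s.ok()) {
  // `impl` came from the registered factory; the caller owns this instance.
}

CollectiveImplementationInterface* resolver_instance = nullptr;
s = CollectiveRegistry::LookupParamResolverInstance("RingReduce",
                                                    &resolver_instance);
// `resolver_instance` is the shared static instance; it should only be used
// to call InitializeCollectiveParams().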
diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h
index c3e6388e28..e35edb09d0 100644
--- a/tensorflow/core/framework/collective.h
+++ b/tensorflow/core/framework/collective.h
@@ -12,12 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_COLLECTIVE_EXECUTOR_H_
-#define TENSORFLOW_FRAMEWORK_COLLECTIVE_EXECUTOR_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
+#define TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
#include <string>
#include <vector>
+#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
@@ -30,7 +31,8 @@ class CompleteGroupRequest;
class CompleteGroupResponse;
class CompleteInstanceRequest;
class CompleteInstanceResponse;
-class DeviceLocality;
+class Device;
+class DeviceMgr;
class GetStepSequenceRequest;
class GetStepSequenceResponse;
class Op;
@@ -64,10 +66,10 @@ struct CollGroupParams {
// interpretation. On first execution the runtime will update this
// structure with decisions that will guide all subsequent executions.
struct CollImplDetails {
+ string collective_name;
std::vector<std::vector<int>> subdiv_permutations;
std::vector<int> subdiv_offsets;
- // broadcast only: rank of source in each subdiv
- std::vector<int> subdiv_source_rank;
+ std::vector<int> subdiv_source_rank; // rank of source in each subdiv
};
// Data common to all members of a collective instance.
@@ -104,6 +106,7 @@ struct CollectiveParams {
string name = ""; // node name used only for log or error messages
int default_rank = -1; // index of this op within device_names
bool is_source = false; // broadcast only
+ int source_rank = -1; // broadcast only
// Rank of this device in each subdivision permutation.
std::vector<int> subdiv_rank;
std::unique_ptr<OpKernel> merge_op; // reduction only
@@ -306,6 +309,110 @@ class PerStepCollectiveRemoteAccess : public CollectiveRemoteAccess {
virtual void StartAbort(const Status& s) = 0;
};
+class CollectiveContext {
+ public:
+ CollectiveContext(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
+ OpKernelContext* ctx, OpKernelContext::Params* op_params,
+ const CollectiveParams& col_params, const string& exec_key,
+ int64 step_id, const Tensor* input, Tensor* output);
+
+ virtual ~CollectiveContext() = default;
+
+ CollectiveExecutor* col_exec; // Not owned
+ const DeviceMgr* dev_mgr; // Not owned
+ OpKernelContext* op_ctx; // Not owned
+ OpKernelContext::Params* op_params; // Not owned
+ const CollectiveParams& col_params;
+ const string exec_key;
+ const int64 step_id;
+ const Tensor* input; // Not owned
+ Tensor* output; // Not owned
+ Device* device; // The device for which this instance labors
+ const string device_name;
+ DeviceLocality device_locality;
+};
+
+// Interface of a Collective Op implementation. Each specific CollectiveOp will
+// implement this interface and register the implementation via the
+// CollectiveRegistry detailed below. See common_runtime/ring_reducer and
+// common_runtime/hierarchical_tree_broadcaster for examples.
+class CollectiveImplementationInterface {
+ public:
+ virtual ~CollectiveImplementationInterface() = default;
+
+ // Initializes the portions of `col_params` specific to this
+ // implementation. Called exactly once for every Collective instance during
+ // the CollectiveParams resolution process when the graph is first executed.
+ // NOTE(ayushd): This is effectively a static function because it modifies the
+ // `col_params` passed in and should not manipulate any data members. However,
+ // because it is virtual and needs to be implemented by every derived class,
+ // we do not mark it as static.
+ virtual Status InitializeCollectiveParams(CollectiveParams* col_params) = 0;
+
+ // Prepares the CollectiveContext for executing this CollectiveImplementation.
+ // Called from CollectiveExecutor right before calling Run(). The
+ // CollectiveContext passed in must outlive the CollectiveImplementation
+ // object.
+ virtual Status InitializeCollectiveContext(CollectiveContext* col_ctx) = 0;
+
+ // Processes and moves data according to the logic of this Collective
+ // implementation. Relies on appropriate initialization of op-specific
+ // CollectiveParams in InitializeCollectiveParams(), as well as appropriate
+ // context initialization in InitializeCollectiveContext().
+ virtual void Run(StatusCallback done) = 0;
+};
+
+// Static-methods-only class for registering and looking up collective
+// implementations.
+class CollectiveRegistry {
+ public:
+ using Factory = std::function<CollectiveImplementationInterface*()>;
+ // Looks up a previously registered CollectiveImplementation under
+ // `collective_name`. If found, creates an instance of the implementation and
+ // assigns it to `implementation`.
+ static Status Lookup(const string& collective_name,
+ CollectiveImplementationInterface** implementation);
+
+ // Looks up a previously registered CollectiveImplementation under
+ // `collective_name`. If found, returns the static instance of this
+ // implementation via `implementation`. This instance should only be used to
+ // call InitializeCollectiveParams.
+ static Status LookupParamResolverInstance(
+ const string& collective_name,
+ CollectiveImplementationInterface** implementation);
+
+ // Returns all registered collective implementations.
+ static void GetAll(
+ std::vector<CollectiveImplementationInterface*>* implementations);
+
+ private:
+ friend class CollectiveRegistration;
+ // Registers a CollectiveImplementation with name `collective_name` and
+ // factory `factory`. The latter is a function used to create instances of
+ // the CollectiveImplementation. Also creates a static instance of the
+ // implementation - this instance is used during param resolution and should
+ // only be used to call InitializeCollectiveParams.
+ static Status Register(const string& collective_name, Factory factory);
+
+ static Status LookupHelper(const string& collective_name,
+ CollectiveImplementationInterface** implementation,
+ bool param_resolver);
+};
+
+// Class used to call CollectiveRegistry::Register. This should only be used to
+// create a global static object.
+class CollectiveRegistration {
+ public:
+ CollectiveRegistration(const string& collective_name,
+ CollectiveRegistry::Factory factory) {
+ TF_CHECK_OK(CollectiveRegistry::Register(collective_name, factory));
+ }
+};
+
+#define REGISTER_COLLECTIVE(name, implementation) \
+ static CollectiveRegistration register_##name##_collective( \
+ #name, []() { return new implementation; });
+
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_COLLECTIVE_EXECUTOR_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
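To make the new registration flow concrete, here is a minimal sketch of how a collective implementation would plug into CollectiveRegistry. `MyAllReduce` is a hypothetical name, and the `instance.impl_details` field path on CollectiveParams is assumed from the CollImplDetails struct above; this is an illustration, not code from this change.

class MyAllReduce : public CollectiveImplementationInterface {
 public:
  Status InitializeCollectiveParams(CollectiveParams* col_params) override {
    // Records which implementation should execute this instance (assumed path).
    col_params->instance.impl_details.collective_name = "MyAllReduce";
    return Status::OK();
  }
  Status InitializeCollectiveContext(CollectiveContext* col_ctx) override {
    col_ctx_ = col_ctx;  // Must outlive this object, per the interface contract.
    return Status::OK();
  }
  void Run(StatusCallback done) override { done(Status::OK()); }

 private:
  CollectiveContext* col_ctx_ = nullptr;
};

// Defines the global static CollectiveRegistration object. The factory lambda
// runs on every CollectiveRegistry::Lookup(); the param-resolver instance is
// created once at registration time.
REGISTER_COLLECTIVE(MyAllReduce, MyAllReduce);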
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 21c6940b62..20a07d86a2 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -432,9 +432,9 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
DimensionHandle batch_size_dim;
DimensionHandle input_depth_dim;
gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2);
- TF_RETURN_IF_ERROR(DimensionsFromShape(conv_input_shape, data_format,
- &batch_size_dim, &input_spatial_dims,
- &input_depth_dim, c));
+ TF_RETURN_IF_ERROR(DimensionsFromShape(
+ conv_input_shape, data_format, &batch_size_dim,
+ absl::MakeSpan(input_spatial_dims), &input_depth_dim, c));
DimensionHandle output_depth_dim = c->Dim(
filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index 2bedce1d6a..e6f9f935f9 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_
-#define TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_
+#define TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_
#include <array>
@@ -311,4 +311,4 @@ Status ExplicitShapes(InferenceContext* c);
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_
diff --git a/tensorflow/core/framework/control_flow.h b/tensorflow/core/framework/control_flow.h
index 4dad0b4fef..4839e02e22 100644
--- a/tensorflow/core/framework/control_flow.h
+++ b/tensorflow/core/framework/control_flow.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_
-#define TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_CONTROL_FLOW_H_
+#define TENSORFLOW_CORE_FRAMEWORK_CONTROL_FLOW_H_
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
@@ -55,4 +55,4 @@ struct FrameAndIterHash {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_CONTROL_FLOW_H_
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc
index f3c7189292..5281c56f04 100644
--- a/tensorflow/core/framework/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/graph/node_builder.h"
namespace tensorflow {
+namespace data {
namespace {
@@ -133,22 +134,25 @@ Status GraphDefBuilderWrapper::AddDataset(
return Status::OK();
}
-Status GraphDefBuilderWrapper::AddFunction(
- const FunctionLibraryDefinition& flib_def, const string& function_name) {
+Status GraphDefBuilderWrapper::AddFunction(SerializationContext* ctx,
+ const string& function_name) {
if (b_->HasFunction(function_name)) {
VLOG(1) << "Function with name " << function_name << "already exists in"
<< " the graph. It will not be added again.";
return Status::OK();
}
- TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(flib_def, function_name));
- const FunctionDef* f_def = flib_def.Find(function_name);
+ if (!ctx->allow_stateful_functions()) {
+ TF_RETURN_IF_ERROR(
+ EnsureFunctionIsStateless(ctx->flib_def(), function_name));
+ }
+ const FunctionDef* f_def = ctx->flib_def().Find(function_name);
if (f_def == nullptr) {
return errors::InvalidArgument("Unable to find FunctionDef for ",
function_name, " in the registry.");
}
FunctionDefLibrary def;
*def.add_function() = *f_def;
- const string gradient_func = flib_def.FindGradient(function_name);
+ const string gradient_func = ctx->flib_def().FindGradient(function_name);
if (!gradient_func.empty()) {
GradientDef* g_def = def.add_gradient();
g_def->set_function_name(function_name);
@@ -159,23 +163,30 @@ Status GraphDefBuilderWrapper::AddFunction(
// Recursively add functions in inputs of function_name.
for (const NodeDef& node_def : f_def->node_def()) {
const OpRegistrationData* op_reg_data = nullptr;
- TF_RETURN_IF_ERROR(flib_def.LookUp(node_def.op(), &op_reg_data));
+ TF_RETURN_IF_ERROR(ctx->flib_def().LookUp(node_def.op(), &op_reg_data));
if (op_reg_data->is_function_op) {
- TF_RETURN_IF_ERROR(AddFunction(flib_def, op_reg_data->op_def.name()));
+ TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name()));
}
// Recursively add functions in attrs of this NodeDef.
for (const auto& pair : node_def.attr()) {
- TF_RETURN_IF_ERROR(AddAttrFunctions(pair.second, flib_def));
+ TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, pair.second));
}
}
// Recursively add functions in attrs of function_name.
for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) {
- TF_RETURN_IF_ERROR(AddAttrFunctions(iter->second, flib_def));
+ TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, iter->second));
}
return Status::OK();
}
+void GraphDefBuilderWrapper::AddPlaceholderInternal(const Tensor& val,
+ Node** output) {
+ *output = ops::SourceOp(
+ "Placeholder",
+ b_->opts().WithAttr("dtype", val.dtype()).WithAttr("shape", val.shape()));
+}
+
void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val,
Node** output) {
*output = ops::SourceOp(
@@ -319,4 +330,5 @@ void BackgroundWorker::WorkerLoop() {
}
}
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index e0c26d9286..4e51fba048 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -40,7 +40,15 @@ limitations under the License.
namespace tensorflow {
+// Forward declarations to avoid introducing a dependency on headers in
+// "tensorflow/core/graph/...".
+class GraphDefBuilder;
+class Node;
+
+namespace data {
+
class DatasetBase;
+class SerializationContext;
// Interface for reading values from a key-value store.
// Used for restoring iterator state.
@@ -65,11 +73,6 @@ class IteratorStateWriter {
virtual ~IteratorStateWriter() {}
};
-// Forward declarations to avoid introducing a dependency on headers in
-// "tensorflow/core/graph/...".
-class GraphDefBuilder;
-class Node;
-
// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
class GraphDefBuilderWrapper {
public:
@@ -109,10 +112,11 @@ class GraphDefBuilderWrapper {
return Status::OK();
}
- // Adds a Const node with Tensor value to the Graph.
+ // Adds a `Const` node for the given tensor value to the graph.
+ //
// `*output` contains a pointer to the output `Node`. It is guaranteed to be
- // non-null if the method returns with an OK status.
- // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
+ // non-null if the method returns with an OK status. The returned `Node`
+ // pointer is owned by the backing graph of `GraphDefBuilder`.
Status AddTensor(const Tensor& val, Node** output) {
AddTensorInternal(val, output);
if (*output == nullptr) {
@@ -121,6 +125,20 @@ class GraphDefBuilderWrapper {
return Status::OK();
}
+ // Adds a `Placeholder` node for the given tensor value to the graph.
+ //
+ // `*output` contains a pointer to the output `Node`. It is guaranteed to be
+ // non-null if the method returns with an OK status. The returned `Node`
+ // pointer is owned by the backing graph of `GraphDefBuilder`.
+ Status AddPlaceholder(const Tensor& val, Node** output) {
+ AddPlaceholderInternal(val, output);
+ if (*output == nullptr) {
+ return errors::Internal(
+ "AddPlaceholder: Failed to build Placeholder op.");
+ }
+ return Status::OK();
+ }
+
Status AddDataset(const DatasetBase* dataset,
const std::vector<Node*>& inputs, Node** output) {
return AddDataset(dataset, inputs, {}, output);
@@ -155,11 +173,11 @@ class GraphDefBuilderWrapper {
// Adds a user-defined function with name `function_name` to the graph and
// recursively adds all functions it references. If a function with a matching
// name has already been added, returns with OK status. If a user-defined function with
- // name `function_name` is not found in the FunctionLibraryDefinition, returns
- // an InvalidArgumentError. If the function with name `function_name` or any
- // of its dependent functions are stateful, returns an InvalidArgument error.
- Status AddFunction(const FunctionLibraryDefinition& flib_def,
- const string& function_name);
+ // name `function_name` is not found in the context's function library,
+ // returns an InvalidArgumentError. If the function with name `function_name`
+ // or any of its dependent functions are stateful, and the context does not
+ // explicitly permit stateful functions, returns an InvalidArgument error.
+ Status AddFunction(SerializationContext* ctx, const string& function_name);
template <typename T>
void BuildAttrValue(const T& value, AttrValue* attr) {
@@ -167,6 +185,7 @@ class GraphDefBuilderWrapper {
}
private:
+ void AddPlaceholderInternal(const Tensor& val, Node** output);
void AddTensorInternal(const Tensor& val, Node** output);
Status EnsureFunctionIsStateless(const FunctionLibraryDefinition& flib_def,
@@ -205,8 +224,7 @@ class GraphDefBuilderWrapper {
return (str_util::EndsWith(op_def->name(), "Dataset") &&
op_def->output_arg_size() == 1 &&
op_def->output_arg(0).type() == DT_VARIANT) ||
- dataset::WhitelistedStatefulOpRegistry::Global()->Contains(
- op_def->name());
+ WhitelistedStatefulOpRegistry::Global()->Contains(op_def->name());
}
bool HasAttr(const string& op_type_name, const string& attr_name) const;
@@ -220,13 +238,13 @@ class GraphDefBuilderWrapper {
return false;
}
- Status AddAttrFunctions(const AttrValue& attr_value,
- const FunctionLibraryDefinition& flib_def) {
+ Status AddAttrFunctions(SerializationContext* ctx,
+ const AttrValue& attr_value) {
if (attr_value.has_func()) {
- TF_RETURN_IF_ERROR(AddFunction(flib_def, attr_value.func().name()));
+ TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name()));
} else if (attr_value.has_list()) {
for (const NameAttrList& name_attr_list : attr_value.list().func()) {
- TF_RETURN_IF_ERROR(AddFunction(flib_def, name_attr_list.name()));
+ TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name()));
}
}
return Status::OK();
@@ -332,13 +350,21 @@ class IteratorContext {
class SerializationContext {
public:
struct Params {
- const FunctionLibraryDefinition* flib_def; // Not owned.
+ bool allow_stateful_functions = false;
+ const FunctionLibraryDefinition* flib_def = nullptr; // Not owned.
+ std::vector<std::pair<string, Tensor>>* input_list = nullptr; // Not owned.
};
explicit SerializationContext(Params params) : params_(std::move(params)) {}
+ bool allow_stateful_functions() { return params_.allow_stateful_functions; }
+
const FunctionLibraryDefinition& flib_def() { return *params_.flib_def; }
+ std::vector<std::pair<string, Tensor>>* input_list() {
+ return params_.input_list;
+ }
+
private:
Params params_;
@@ -726,6 +752,21 @@ class BackgroundWorker {
std::deque<std::function<void()>> work_queue_ GUARDED_BY(mu_);
};
+} // namespace data
+
+// TODO(b/114112161): Remove these aliases when all users have moved over to the
+// `tensorflow::data` namespace.
+using data::DatasetBase;
+using data::DatasetContext;
+using data::DatasetIterator;
+using data::DatasetOpKernel;
+using data::IteratorBase;
+using data::IteratorContext;
+using data::IteratorStateReader;
+using data::IteratorStateWriter;
+using data::SerializationContext;
+using data::UnaryDatasetOpKernel;
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
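As a rough illustration of the reworked serialization path, a caller now threads a SerializationContext through AddFunction instead of a bare FunctionLibraryDefinition. The surrounding objects (`flib_def`, `builder`, the function name) are placeholders:

// flib_def is a FunctionLibraryDefinition owned by the caller (assumed).
data::SerializationContext::Params params;
params.allow_stateful_functions = true;  // Skip the statelessness check.
params.flib_def = &flib_def;
params.input_list = nullptr;  // No placeholder inputs are captured here.
data::SerializationContext ctx(std::move(params));

// builder wraps a GraphDefBuilder (assumed). With the flag above, stateful
// functions referenced by "my_map_fn" are admitted instead of failing.
TF_RETURN_IF_ERROR(builder.AddFunction(&ctx, "my_map_fn"));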
diff --git a/tensorflow/core/framework/dataset_stateful_op_whitelist.h b/tensorflow/core/framework/dataset_stateful_op_whitelist.h
index 3b48999edb..74bd39cb61 100644
--- a/tensorflow/core/framework/dataset_stateful_op_whitelist.h
+++ b/tensorflow/core/framework/dataset_stateful_op_whitelist.h
@@ -16,38 +16,38 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_
#define TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_
+#include <unordered_set>
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
-namespace dataset {
+namespace data {
// Registry for stateful ops that need to be used in dataset functions.
// See below macro for usage details.
class WhitelistedStatefulOpRegistry {
public:
- Status Add(StringPiece op_name) {
- op_names_.insert(op_name);
+ Status Add(string op_name) {
+ op_names_.insert(std::move(op_name));
return Status::OK();
}
- bool Contains(StringPiece op_name) {
- return op_names_.find(op_name) != op_names_.end();
- }
+ bool Contains(const string& op_name) { return op_names_.count(op_name); }
static WhitelistedStatefulOpRegistry* Global() {
- static WhitelistedStatefulOpRegistry* reg =
- new WhitelistedStatefulOpRegistry;
+ static auto* reg = new WhitelistedStatefulOpRegistry;
return reg;
}
private:
- WhitelistedStatefulOpRegistry() {}
- WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy);
+ WhitelistedStatefulOpRegistry() = default;
+ WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy) =
+ delete;
WhitelistedStatefulOpRegistry operator=(
- WhitelistedStatefulOpRegistry const& copy);
- std::set<StringPiece> op_names_;
+ WhitelistedStatefulOpRegistry const& copy) = delete;
+
+ std::unordered_set<string> op_names_;
};
-} // namespace dataset
+} // namespace data
// Use this macro to whitelist an op that is marked stateful but needs to be
// used inside a map_fn in an input pipeline. This is only needed if you wish
@@ -67,10 +67,9 @@ class WhitelistedStatefulOpRegistry {
WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(__COUNTER__, name)
#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(ctr, name) \
WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name)
-#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \
- static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED = \
- ::tensorflow::dataset::WhitelistedStatefulOpRegistry::Global()->Add( \
- name)
+#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \
+ static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED = \
+ ::tensorflow::data::WhitelistedStatefulOpRegistry::Global()->Add(name)
} // namespace tensorflow
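For reference, the renamed registry keeps the same usage pattern: the macro is invoked once at namespace scope in the file that needs the exemption. The op name below is hypothetical:

// Admits the (hypothetical) stateful op "MyRandomOp" into dataset functions.
WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("MyRandomOp");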
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index b184fd91e1..794250a2c1 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -89,6 +89,15 @@ class DeviceContext : public core::RefCounted {
Tensor* cpu_tensor, StatusCallback done) {
done(errors::Internal("Unrecognized device type in device-to-CPU Copy"));
}
+
+ // If possible, wait for all events on *stream to complete, then execute func.
+ // A non-OK Status is returned otherwise. The stream argument should be the
+ // one provided by GpuDeviceInfo. This function is not applicable to devices
+ // that don't provide such a value.
+ virtual Status ThenExecute(Device* device, stream_executor::Stream* stream,
+ std::function<void()> func) {
+ return errors::Internal("ThenExecute not supported by device");
+ }
};
// map[i] is the DeviceContext* for the node with id i, if i < map.size().
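A device context that owns a GpuDeviceInfo stream could implement the new hook roughly as follows; the class name is illustrative, and Stream::ThenDoHostCallback is assumed to be the host-callback enqueue primitive:

class MyGpuDeviceContext : public DeviceContext {
 public:
  Status ThenExecute(Device* device, stream_executor::Stream* stream,
                     std::function<void()> func) override {
    // Runs `func` on the host once all work queued on `stream` completes.
    stream->ThenDoHostCallback(std::move(func));
    return Status::OK();
  }
};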
diff --git a/tensorflow/core/framework/fake_input.h b/tensorflow/core/framework/fake_input.h
index 103db47a99..c3062762ff 100644
--- a/tensorflow/core/framework/fake_input.h
+++ b/tensorflow/core/framework/fake_input.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_
-#define TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_FAKE_INPUT_H_
+#define TENSORFLOW_CORE_FRAMEWORK_FAKE_INPUT_H_
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.h"
@@ -37,4 +37,4 @@ inline FakeInputFunctor FakeInput(std::initializer_list<DataType> dts) {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_FAKE_INPUT_H_
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 6b92e10d76..26f32677af 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -504,7 +504,7 @@ string Print(const NodeDef& n) {
std::vector<string> dep;
for (StringPiece s : n.input()) {
if (str_util::ConsumePrefix(&s, "^")) {
- dep.push_back(std::string(s));
+ dep.emplace_back(s);
} else {
dat.push_back(s);
}
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index edb7ed01e9..03296a7761 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_FUNCTION_H_
-#define TENSORFLOW_FRAMEWORK_FUNCTION_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_
+#define TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
@@ -490,6 +490,11 @@ class FunctionLibraryRuntime {
// Instantiates the function using an executor of the given type. If empty,
// the default TensorFlow executor will be used.
string executor_type;
+
+ // If true, the runtime will attempt to create kernels for the function at
+ // instantiation time, rather than on the first run. This can be used to
+ // surface errors earlier.
+ bool create_kernels_eagerly = false;
};
typedef uint64 Handle;
virtual Status Instantiate(const string& function_name, AttrSlice attrs,
@@ -705,9 +710,10 @@ Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
#define REGISTER_OP_GRADIENT_UNIQ_HELPER(ctr, name, fn) \
REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)
-#define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn) \
- static bool unused_grad_##ctr = SHOULD_REGISTER_OP_GRADIENT && \
- ::tensorflow::gradient::RegisterOp(name, fn)
+#define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn) \
+ static bool unused_grad_##ctr TF_ATTRIBUTE_UNUSED = \
+ SHOULD_REGISTER_OP_GRADIENT && \
+ ::tensorflow::gradient::RegisterOp(name, fn)
namespace gradient {
// Register a gradient creator for the "op".
@@ -731,4 +737,4 @@ GET_ATTR(bool)
} // end namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_FUNCTION_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_
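A caller that wants kernel-construction errors surfaced at instantiation time rather than on the first Run() would set the new flag in the instantiate options. A sketch only: `lib` and `attrs` are placeholders, and the Instantiate overload taking options is assumed from the declaration above:

FunctionLibraryRuntime::InstantiateOptions opts;
opts.create_kernels_eagerly = true;  // Fail fast on bad kernels.
FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(lib->Instantiate("MyFunction", attrs, opts, &handle));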
diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc
index 41270b8e5e..c5a4f661d2 100644
--- a/tensorflow/core/framework/function_testlib.cc
+++ b/tensorflow/core/framework/function_testlib.cc
@@ -49,8 +49,8 @@ NodeDef NDef(StringPiece name, StringPiece op, gtl::ArraySlice<string> inputs,
gtl::ArraySlice<std::pair<string, FDH::AttrValueWrapper>> attrs,
const string& device) {
NodeDef n;
- n.set_name(name.ToString());
- n.set_op(op.ToString());
+ n.set_name(string(name));
+ n.set_op(string(op));
for (const auto& in : inputs) n.add_input(in);
n.set_device(device);
for (auto na : attrs) n.mutable_attr()->insert({na.first, na.second.proto});
@@ -110,6 +110,22 @@ FunctionDef XTimesTwo() {
});
}
+FunctionDef XAddX() {
+ return FDH::Define(
+ // Name
+ "XAddX",
+ // Args
+ {"x: T"},
+ // Return values
+ {"y: T"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {
+ {{"y"}, "Add", {"x", "x"}, {{"T", "$T"}}},
+ });
+}
+
FunctionDef XTimesTwoInt32() {
const Tensor kTwo = test::AsScalar<int64>(2);
return FDH::Define(
@@ -219,6 +235,62 @@ FunctionDef InvalidControlFlow() {
{{"o", "add:z"}});
}
+FunctionDef LessThanOrEqualToN(int64 N) {
+ const Tensor kN = test::AsScalar<int64>(N);
+ return FDH::Define(
+ // Name
+ "LessThanOrEqualToN",
+ // Args
+ {"x: T"},
+ // Return values
+ {"z: bool"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {
+ {{"N"}, "Const", {}, {{"value", kN}, {"dtype", DT_INT64}}},
+ {{"y"}, "Cast", {"N"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
+ {{"z"}, "LessEqual", {"x", "y"}, {{"T", "$T"}}},
+ });
+}
+
+FunctionDef XPlusOneXTimesY() {
+ const Tensor kOne = test::AsScalar<int64>(1);
+ return FDH::Define(
+ // Name
+ "XPlusOneXTimesY",
+ // Args
+ {"x: T", "y: T"},
+ // Return values
+ {"s: T", "t: T"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {{{"one"}, "Const", {}, {{"value", kOne}, {"dtype", DT_INT64}}},
+ {{"increment"}, "Cast", {"one"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
+ {{"s"}, "Add", {"x", "increment"}, {{"T", "$T"}}},
+ {{"t"}, "Mul", {"x", "y"}, {{"T", "$T"}}}});
+}
+
+FunctionDef XYXLessThanOrEqualToN(int64 N) {
+ const Tensor kN = test::AsScalar<int64>(N);
+ return FDH::Define(
+ // Name
+ "XYXLessThanOrEqualToN",
+ // Args
+ {"x: T", "y: T"},
+ // Return values
+ {"z: bool"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {
+ {{"N"}, "Const", {}, {{"value", kN}, {"dtype", DT_INT64}}},
+ {{"N1"}, "Cast", {"N"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
+ {{"z"}, "LessEqual", {"x", "N1"}, {{"T", "$T"}}},
+ });
+}
+
void FunctionTestSchedClosure(std::function<void()> fn) {
static thread::ThreadPool* w =
new thread::ThreadPool(Env::Default(), "Test", 8);
diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h
index af08d296b2..ad61a76f16 100644
--- a/tensorflow/core/framework/function_testlib.h
+++ b/tensorflow/core/framework/function_testlib.h
@@ -63,6 +63,9 @@ GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
// x:T -> x * 2.
FunctionDef XTimesTwo();
+// x:T -> x + x.
+FunctionDef XAddX();
+
// x:T -> x * 2, where x is int32.
FunctionDef XTimesTwoInt32();
@@ -87,6 +90,15 @@ FunctionDef Swap();
// Contains malformed control flow which can't be run by the executor.
FunctionDef InvalidControlFlow();
+// x:T -> x <= N.
+FunctionDef LessThanOrEqualToN(int64 N);
+
+// x:T, y:T -> x + 1, x * y.
+FunctionDef XPlusOneXTimesY();
+
+// x:T, y:T -> x <= N.
+FunctionDef XYXLessThanOrEqualToN(int64 N);
+
void FunctionTestSchedClosure(std::function<void()> fn);
} // end namespace function
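The new helpers compose with the existing testlib entry points; for example, a test graph whose function library holds two of them (a sketch, assuming the usual GDef helper declared in this header):

// Builds a GraphDef whose function library contains the new test functions.
GraphDef gdef = test::function::GDef(
    /*nodes=*/{},
    /*funcs=*/{test::function::XAddX(),
               test::function::LessThanOrEqualToN(8)});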
diff --git a/tensorflow/core/framework/graph_def_util.h b/tensorflow/core/framework/graph_def_util.h
index 525e84a989..2f8d5e8f51 100644
--- a/tensorflow/core/framework/graph_def_util.h
+++ b/tensorflow/core/framework/graph_def_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
-#define TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_GRAPH_DEF_UTIL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_GRAPH_DEF_UTIL_H_
#include <set>
#include "tensorflow/core/framework/op.h"
@@ -118,4 +118,4 @@ Status StrippedOpListForGraph(const GraphDef& graph_def,
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_GRAPH_DEF_UTIL_H_
diff --git a/tensorflow/core/framework/kernel_def_builder.h b/tensorflow/core/framework/kernel_def_builder.h
index 2966aa58de..32dd21f94e 100644
--- a/tensorflow/core/framework/kernel_def_builder.h
+++ b/tensorflow/core/framework/kernel_def_builder.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
-#define TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_BUILDER_H_
+#define TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_BUILDER_H_
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -84,4 +84,4 @@ KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* attr_name) {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_BUILDER_H_
diff --git a/tensorflow/core/framework/log_memory.h b/tensorflow/core/framework/log_memory.h
index faef7b8e98..1b926ddaa3 100644
--- a/tensorflow/core/framework/log_memory.h
+++ b/tensorflow/core/framework/log_memory.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_LOG_MEMORY_H_
-#define TENSORFLOW_FRAMEWORK_LOG_MEMORY_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_LOG_MEMORY_H_
+#define TENSORFLOW_CORE_FRAMEWORK_LOG_MEMORY_H_
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -108,4 +108,4 @@ class LogMemory {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_LOG_MEMORY_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_LOG_MEMORY_H_
diff --git a/tensorflow/core/framework/lookup_interface.h b/tensorflow/core/framework/lookup_interface.h
index 1381dd66a5..0622dd06cb 100644
--- a/tensorflow/core/framework/lookup_interface.h
+++ b/tensorflow/core/framework/lookup_interface.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
-#define TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_
+#define TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
@@ -142,4 +142,4 @@ class LookupInterface : public ResourceBase {
} // namespace lookup
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_
diff --git a/tensorflow/core/framework/memory_types.h b/tensorflow/core/framework/memory_types.h
index d3918513d3..f719131bcb 100644
--- a/tensorflow/core/framework/memory_types.h
+++ b/tensorflow/core/framework/memory_types.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_
-#define TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_MEMORY_TYPES_H_
+#define TENSORFLOW_CORE_FRAMEWORK_MEMORY_TYPES_H_
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
@@ -35,4 +35,4 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_MEMORY_TYPES_H_
diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc
index 8e00bfe4f8..348a825af9 100644
--- a/tensorflow/core/framework/node_def_builder.cc
+++ b/tensorflow/core/framework/node_def_builder.cc
@@ -24,23 +24,22 @@ limitations under the License.
namespace tensorflow {
NodeDefBuilder::NodeOut::NodeOut(StringPiece n, int i, DataType dt)
- : node(std::string(n)), index(i), data_type(dt) {}
+ : node(n), index(i), data_type(dt) {}
NodeDefBuilder::NodeOut::NodeOut() {
// uninitialized, call Reset() before use.
}
void NodeDefBuilder::NodeOut::Reset(StringPiece n, int i, DataType dt) {
- node = std::string(n);
+ node = string(n);
index = i;
data_type = dt;
}
NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name,
const OpRegistryInterface* op_registry) {
- node_def_.set_name(std::string(name));
- const Status status =
- op_registry->LookUpOpDef(std::string(op_name), &op_def_);
+ node_def_.set_name(string(name));
+ const Status status = op_registry->LookUpOpDef(string(op_name), &op_def_);
if (status.ok()) {
Initialize();
} else {
@@ -51,7 +50,7 @@ NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name,
NodeDefBuilder::NodeDefBuilder(StringPiece name, const OpDef* op_def)
: op_def_(op_def) {
- node_def_.set_name(std::string(name));
+ node_def_.set_name(string(name));
Initialize();
}
@@ -171,7 +170,7 @@ void NodeDefBuilder::AddInput(StringPiece src_node, int src_index) {
} else if (src_index > 0) {
node_def_.add_input(strings::StrCat(src_node, ":", src_index));
} else {
- node_def_.add_input(std::string(src_node));
+ node_def_.add_input(string(src_node));
}
}
@@ -194,12 +193,12 @@ void NodeDefBuilder::VerifyInputRef(const OpDef::ArgDef* input_arg,
}
NodeDefBuilder& NodeDefBuilder::ControlInput(StringPiece src_node) {
- control_inputs_.push_back(std::string(src_node));
+ control_inputs_.emplace_back(src_node);
return *this;
}
NodeDefBuilder& NodeDefBuilder::Device(StringPiece device_spec) {
- node_def_.set_device(std::string(device_spec));
+ node_def_.set_device(string(device_spec));
return *this;
}
diff --git a/tensorflow/core/framework/node_def_builder.h b/tensorflow/core/framework/node_def_builder.h
index c138332beb..ad07ec5480 100644
--- a/tensorflow/core/framework/node_def_builder.h
+++ b/tensorflow/core/framework/node_def_builder.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_
-#define TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_BUILDER_H_
+#define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_BUILDER_H_
#include <functional>
#include <vector>
@@ -175,4 +175,4 @@ class NodeDefBuilder {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_BUILDER_H_
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index 0bd79366eb..bacc1d72c4 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -254,7 +254,7 @@ DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;);
#undef DEFINE_GET_ATTR
bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) {
- return node_def.attr().find(std::string(attr_name)) != node_def.attr().end();
+ return node_def.attr().find(string(attr_name)) != node_def.attr().end();
}
static const string& kEmptyString = *new string();
@@ -653,7 +653,7 @@ Status AttachDef(const Status& status, const Node& node) {
void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) {
node_def->mutable_attr()->insert(
- AttrValueMap::value_type(std::string(name), value));
+ AttrValueMap::value_type(string(name), value));
}
#define ADD_NODE_ATTR(T) \
@@ -691,7 +691,7 @@ ADD_NODE_ATTR(gtl::ArraySlice<NameAttrList>)
#undef ADD_NODE_ATTR
void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) {
- map->insert(AttrValueMap::value_type(std::string(name), value));
+ map->insert(AttrValueMap::value_type(string(name), value));
}
#define ADD_ATTR(T) \
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index c012b7c3d3..499034cab2 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_
-#define TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
#include <string>
#include <unordered_map>
@@ -312,4 +312,4 @@ Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
NodeDef* node_def);
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
diff --git a/tensorflow/core/framework/numeric_op.h b/tensorflow/core/framework/numeric_op.h
index 4538ff053c..0167e21f11 100644
--- a/tensorflow/core/framework/numeric_op.h
+++ b/tensorflow/core/framework/numeric_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_
-#define TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_NUMERIC_OP_H_
+#define TENSORFLOW_CORE_FRAMEWORK_NUMERIC_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
@@ -110,4 +110,4 @@ class BinaryElementWiseOp : public BinaryOp<T> {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_NUMERIC_OP_H_
diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h
index b1d0127809..3236d1897c 100644
--- a/tensorflow/core/framework/numeric_types.h
+++ b/tensorflow/core/framework/numeric_types.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_
-#define TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_NUMERIC_TYPES_H_
+#define TENSORFLOW_CORE_FRAMEWORK_NUMERIC_TYPES_H_
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -122,4 +122,4 @@ struct hash<Eigen::half> {
} // namespace std
#endif // _MSC_VER
-#endif // TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_NUMERIC_TYPES_H_
diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h
index 3ccca4090d..25f8de8dcc 100644
--- a/tensorflow/core/framework/op.h
+++ b/tensorflow/core/framework/op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_OP_H_
-#define TENSORFLOW_FRAMEWORK_OP_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_H_
+#define TENSORFLOW_CORE_FRAMEWORK_OP_H_
#include <functional>
#include <unordered_map>
@@ -309,4 +309,4 @@ struct OpDefBuilderReceiver {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_OP_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_OP_H_
diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc
index 91eb6c0672..34a7a43d38 100644
--- a/tensorflow/core/framework/op_def_builder.cc
+++ b/tensorflow/core/framework/op_def_builder.cc
@@ -527,7 +527,7 @@ void FinalizeDoc(const string& text, OpDef* op_def,
} // namespace
OpDefBuilder::OpDefBuilder(StringPiece op_name) {
- op_def()->set_name(std::string(op_name)); // NOLINT
+ op_def()->set_name(string(op_name)); // NOLINT
}
OpDefBuilder& OpDefBuilder::Attr(StringPiece spec) {
@@ -584,7 +584,7 @@ OpDefBuilder& OpDefBuilder::Deprecated(int version, StringPiece explanation) {
} else {
OpDeprecation* deprecation = op_def()->mutable_deprecation();
deprecation->set_version(version);
- deprecation->set_explanation(std::string(explanation));
+ deprecation->set_explanation(string(explanation));
}
return *this;
}
diff --git a/tensorflow/core/framework/op_def_builder.h b/tensorflow/core/framework/op_def_builder.h
index fbfb4018aa..0b39d6e848 100644
--- a/tensorflow/core/framework/op_def_builder.h
+++ b/tensorflow/core/framework/op_def_builder.h
@@ -16,8 +16,8 @@ limitations under the License.
// Class and associated machinery for specifying an Op's OpDef and shape
// inference function for Op registration.
-#ifndef TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_
-#define TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_DEF_BUILDER_H_
+#define TENSORFLOW_CORE_FRAMEWORK_OP_DEF_BUILDER_H_
#include <string>
#include <vector>
@@ -162,4 +162,4 @@ class OpDefBuilder {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_OP_DEF_BUILDER_H_
diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc
index 9be0dc69d2..3597f43d51 100644
--- a/tensorflow/core/framework/op_def_util.cc
+++ b/tensorflow/core/framework/op_def_util.cc
@@ -172,6 +172,15 @@ const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def) {
return nullptr;
}
+const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
+ for (int i = 0; i < api_def.in_arg_size(); ++i) {
+ if (api_def.in_arg(i).name() == name) {
+ return &api_def.in_arg(i);
+ }
+ }
+ return nullptr;
+}
+
#define VALIDATE(EXPR, ...) \
do { \
if (!(EXPR)) { \
diff --git a/tensorflow/core/framework/op_def_util.h b/tensorflow/core/framework/op_def_util.h
index 0ba1325a03..85afe2bdea 100644
--- a/tensorflow/core/framework/op_def_util.h
+++ b/tensorflow/core/framework/op_def_util.h
@@ -16,10 +16,11 @@ limitations under the License.
// TODO(josh11b): Probably not needed for OpKernel authors, so doesn't
// need to be as publicly accessible as other files in framework/.
-#ifndef TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_
-#define TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_DEF_UTIL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_OP_DEF_UTIL_H_
#include <string>
+#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -47,6 +48,10 @@ OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def);
// Returns nullptr if no such attr is found.
const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def);
+// Searches api_def for an input argument with the indicated name.
+// Returns nullptr if no such argument is found.
+const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def);
+
// Produce a human-readable version of an op_def that is more concise
// than a text-format proto. Excludes descriptions.
string SummarizeOpDef(const OpDef& op_def);
@@ -98,4 +103,4 @@ uint64 OpDefHash(const OpDef& o);
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_OP_DEF_UTIL_H_
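The ApiDef overload mirrors the existing OpDef lookup; a minimal sketch, with the argument name hypothetical:

const ApiDef::Arg* arg = FindInputArg("images", api_def);
if (arg == nullptr) {
  return errors::InvalidArgument("ApiDef has no input arg named 'images'.");
}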
diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc
index 4b56d807df..505ab54775 100644
--- a/tensorflow/core/framework/op_gen_lib.cc
+++ b/tensorflow/core/framework/op_gen_lib.cc
@@ -186,7 +186,7 @@ static bool FindMultiline(StringPiece line, size_t colon, string* end) {
while (str_util::ConsumePrefix(&line, " ")) {
}
if (str_util::ConsumePrefix(&line, "<<")) {
- *end = std::string(line);
+ *end = string(line);
return true;
}
return false;
diff --git a/tensorflow/core/framework/op_gen_lib.h b/tensorflow/core/framework/op_gen_lib.h
index 533dd64805..c269e2df04 100644
--- a/tensorflow/core/framework/op_gen_lib.h
+++ b/tensorflow/core/framework/op_gen_lib.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_
-#define TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_GEN_LIB_H_
+#define TENSORFLOW_CORE_FRAMEWORK_OP_GEN_LIB_H_
#include <string>
#include <unordered_map>
@@ -97,4 +97,4 @@ class ApiDefMap {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_OP_GEN_LIB_H_
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index b285accce7..c694e10193 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -913,7 +913,7 @@ void OpKernelContext::clear_recorded_memory() {
struct KernelRegistration {
KernelRegistration(const KernelDef& d, StringPiece c,
kernel_factory::OpKernelRegistrar::Factory f)
- : def(d), kernel_class_name(std::string(c)), factory(f) {}
+ : def(d), kernel_class_name(c), factory(f) {}
const KernelDef def;
const string kernel_class_name;
const kernel_factory::OpKernelRegistrar::Factory factory;
diff --git a/tensorflow/core/framework/queue_interface.h b/tensorflow/core/framework/queue_interface.h
index 4aeaab3d9b..4ca4416c5a 100644
--- a/tensorflow/core/framework/queue_interface.h
+++ b/tensorflow/core/framework/queue_interface.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_QUEUE_INTERFACE_H_
-#define TENSORFLOW_FRAMEWORK_QUEUE_INTERFACE_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_QUEUE_INTERFACE_H_
+#define TENSORFLOW_CORE_FRAMEWORK_QUEUE_INTERFACE_H_
#include <string>
#include <vector>
@@ -99,4 +99,4 @@ class QueueInterface : public ResourceBase {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_QUEUE_INTERFACE_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_QUEUE_INTERFACE_H_
diff --git a/tensorflow/core/framework/reader_base.h b/tensorflow/core/framework/reader_base.h
index cb44be4dee..5b82e9181f 100644
--- a/tensorflow/core/framework/reader_base.h
+++ b/tensorflow/core/framework/reader_base.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_READER_BASE_H_
-#define TENSORFLOW_FRAMEWORK_READER_BASE_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_READER_BASE_H_
+#define TENSORFLOW_CORE_FRAMEWORK_READER_BASE_H_
#include <memory>
#include <string>
@@ -135,4 +135,4 @@ class ReaderBase : public ReaderInterface {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_READER_BASE_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_READER_BASE_H_
diff --git a/tensorflow/core/framework/reader_interface.h b/tensorflow/core/framework/reader_interface.h
index dac6056b5a..f894acbe1d 100644
--- a/tensorflow/core/framework/reader_interface.h
+++ b/tensorflow/core/framework/reader_interface.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_READER_INTERFACE_H_
-#define TENSORFLOW_FRAMEWORK_READER_INTERFACE_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_READER_INTERFACE_H_
+#define TENSORFLOW_CORE_FRAMEWORK_READER_INTERFACE_H_
#include <memory>
#include <string>
@@ -84,4 +84,4 @@ class ReaderInterface : public ResourceBase {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_READER_INTERFACE_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_READER_INTERFACE_H_
diff --git a/tensorflow/core/framework/reader_op_kernel.h b/tensorflow/core/framework/reader_op_kernel.h
index ffd6a1a184..e65a8695be 100644
--- a/tensorflow/core/framework/reader_op_kernel.h
+++ b/tensorflow/core/framework/reader_op_kernel.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_READER_OP_KERNEL_H_
-#define TENSORFLOW_FRAMEWORK_READER_OP_KERNEL_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_READER_OP_KERNEL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_READER_OP_KERNEL_H_
#include <functional>
#include <string>
@@ -85,4 +85,4 @@ class ReaderOpKernel : public ResourceOpKernel<ReaderInterface> {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_READER_OP_KERNEL_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_READER_OP_KERNEL_H_
diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h
index f1cd37ecda..ddb5b10c18 100644
--- a/tensorflow/core/framework/register_types.h
+++ b/tensorflow/core/framework/register_types.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
-#define TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_H_
+#define TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_H_
// This file is used by cuda code and must remain compilable by nvcc.
#include "tensorflow/core/framework/numeric_types.h"
@@ -161,9 +161,12 @@ limitations under the License.
TF_CALL_int64(m) TF_CALL_int32(m) TF_CALL_uint16(m) TF_CALL_int16(m) \
TF_CALL_uint8(m) TF_CALL_int8(m)
+#define TF_CALL_FLOAT_TYPES(m) \
+ TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)
+
#define TF_CALL_REAL_NUMBER_TYPES(m) \
TF_CALL_INTEGRAL_TYPES(m) \
- TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)
+ TF_CALL_FLOAT_TYPES(m)
#define TF_CALL_REAL_NUMBER_TYPES_NO_BFLOAT16(m) \
TF_CALL_INTEGRAL_TYPES(m) TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m)
@@ -225,4 +228,4 @@ limitations under the License.
#define TF_CALL_SYCL_NUMBER_TYPES(m) TF_CALL_float(m) TF_CALL_SYCL_double(m)
#endif // __ANDROID_TYPES_SLIM__
-#endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_H_
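The new TF_CALL_FLOAT_TYPES grouping drives kernel registration the same way as the existing TF_CALL_* macros; the op and kernel names below are made up:

#define REGISTER_MY_OP(T)                                     \
  REGISTER_KERNEL_BUILDER(                                    \
      Name("MyOp").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      MyOpKernel<T>)

// Expands the registration for half, bfloat16, float, and double.
TF_CALL_FLOAT_TYPES(REGISTER_MY_OP);
#undef REGISTER_MY_OP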
diff --git a/tensorflow/core/framework/register_types_traits.h b/tensorflow/core/framework/register_types_traits.h
index ab35c2f095..d475a1972d 100644
--- a/tensorflow/core/framework/register_types_traits.h
+++ b/tensorflow/core/framework/register_types_traits.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
-#define TENSORFLOW_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
+#define TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
// This file is used by cuda code and must remain compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -102,4 +102,4 @@ struct proxy_type {
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index 33d4cb77ff..f8a587c9b5 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_
-#define TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
+#define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
#include <string>
#include <typeindex>
@@ -61,8 +61,8 @@ namespace tensorflow {
//
// // Create a var.
// MyVar* my_var = new MyVar;
-// my_var.val = Tensor(DT_FLOAT, my_shape);
-// my_var.val.flat<float>().setZeros(); // 0 initialized.
+// my_var->val = Tensor(DT_FLOAT, my_shape);
+// my_var->val.flat<float>().setZeros(); // 0 initialized.
// ctx->SetStatus(rm.Create("my_container", "my_name", my_var));
//
// // += a variable.
@@ -555,4 +555,4 @@ void ResourceHandleOp<T>::Compute(OpKernelContext* ctx) {
} // end namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
diff --git a/tensorflow/core/framework/resource_op_kernel.h b/tensorflow/core/framework/resource_op_kernel.h
index 0a8da8b3bf..fbcd439dea 100644
--- a/tensorflow/core/framework/resource_op_kernel.h
+++ b/tensorflow/core/framework/resource_op_kernel.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_RESOURCE_OP_KERNEL_H_
-#define TENSORFLOW_FRAMEWORK_RESOURCE_OP_KERNEL_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_OP_KERNEL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_OP_KERNEL_H_
#include <string>
@@ -136,4 +136,4 @@ class ResourceOpKernel : public OpKernel {
};
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_RESOURCE_OP_KERNEL_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_OP_KERNEL_H_
diff --git a/tensorflow/core/framework/selective_registration.h b/tensorflow/core/framework/selective_registration.h
index 503947969d..4b281a04bf 100644
--- a/tensorflow/core/framework/selective_registration.h
+++ b/tensorflow/core/framework/selective_registration.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_SELECTIVE_REGISTRATION_H_
-#define TENSORFLOW_FRAMEWORK_SELECTIVE_REGISTRATION_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_SELECTIVE_REGISTRATION_H_
+#define TENSORFLOW_CORE_FRAMEWORK_SELECTIVE_REGISTRATION_H_
#include <string.h>
@@ -55,4 +55,4 @@ static_assert(false, "ops_to_register.h must define SHOULD_REGISTER macros");
#define SHOULD_REGISTER_OP_KERNEL(clz) true
#endif
-#endif // TENSORFLOW_FRAMEWORK_SELECTIVE_REGISTRATION_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_SELECTIVE_REGISTRATION_H_
diff --git a/tensorflow/core/framework/session_state.h b/tensorflow/core/framework/session_state.h
index 653a661dd2..63568685f2 100644
--- a/tensorflow/core/framework/session_state.h
+++ b/tensorflow/core/framework/session_state.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_SESSION_STATE_H_
-#define TENSORFLOW_FRAMEWORK_SESSION_STATE_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_SESSION_STATE_H_
+#define TENSORFLOW_CORE_FRAMEWORK_SESSION_STATE_H_
#include <string>
#include <unordered_map>
@@ -90,4 +90,4 @@ class TensorStore {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_SESSION_STATE_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_SESSION_STATE_H_
diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h
index f6656b3b45..bb4dc25da4 100644
--- a/tensorflow/core/framework/shape_inference_testutil.h
+++ b/tensorflow/core/framework/shape_inference_testutil.h
@@ -32,7 +32,7 @@ class Tensor;
struct ShapeInferenceTestOp {
typedef std::pair<string, DataType> ShapeAndType;
- explicit ShapeInferenceTestOp(StringPiece name) : name(std::string(name)) {}
+ explicit ShapeInferenceTestOp(StringPiece name) : name(string(name)) {}
string name;
NodeDef node_def;
std::vector<const Tensor*> input_tensors;
diff --git a/tensorflow/core/framework/stats_aggregator.h b/tensorflow/core/framework/stats_aggregator.h
index 4a18efc940..af53ed0a3c 100644
--- a/tensorflow/core/framework/stats_aggregator.h
+++ b/tensorflow/core/framework/stats_aggregator.h
@@ -25,6 +25,8 @@ namespace tensorflow {
class Summary;
+namespace data {
+
// A `StatsAggregator` accumulates statistics incrementally. A
// `StatsAggregator` can accumulate multiple different statistics, distinguished
// by a string name.
@@ -87,6 +89,7 @@ class StatsAggregatorResource : public ResourceBase {
const std::shared_ptr<StatsAggregator> stats_aggregator_;
};
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index a82beb7e8f..516afa517d 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -617,13 +617,13 @@ bool Tensor::IsInitialized() const {
}
void Tensor::CheckType(DataType expected_dtype) const {
- CHECK_EQ(dtype(), expected_dtype)
+ CHECK_EQ(dtype(), expected_dtype) << " "
<< DataTypeString(expected_dtype) << " expected, got "
<< DataTypeString(dtype());
}
void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const {
- CHECK_EQ(dtype(), expected_dtype)
+ CHECK_EQ(dtype(), expected_dtype) << " "
<< DataTypeString(expected_dtype) << " expected, got "
<< DataTypeString(dtype());
CHECK(IsAligned()) << "ptr = " << base<void>();
diff --git a/tensorflow/core/framework/tensor_slice.h b/tensorflow/core/framework/tensor_slice.h
index 6019737342..82f21fb17e 100644
--- a/tensorflow/core/framework/tensor_slice.h
+++ b/tensorflow/core/framework/tensor_slice.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_
-#define TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SLICE_H_
+#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SLICE_H_
#include <string>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -221,4 +221,4 @@ void TensorSlice::FillIndicesAndSizes(
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SLICE_H_
diff --git a/tensorflow/core/framework/tensor_types.h b/tensorflow/core/framework/tensor_types.h
index a5c1a56bfc..6f981db189 100644
--- a/tensorflow/core/framework/tensor_types.h
+++ b/tensorflow/core/framework/tensor_types.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_
-#define TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_TYPES_H_
+#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_TYPES_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -123,4 +123,4 @@ To32Bit(TensorType in) {
}
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_TYPES_H_
diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h
index 43d2d95311..4bda8f9eb8 100644
--- a/tensorflow/core/framework/tensor_util.h
+++ b/tensorflow/core/framework/tensor_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_
-#define TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
@@ -160,4 +160,4 @@ CreateTensorProto(const std::vector<Type>& values,
} // namespace tensor
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
diff --git a/tensorflow/core/framework/tracking_allocator.h b/tensorflow/core/framework/tracking_allocator.h
index 661c28969e..5eafce662e 100644
--- a/tensorflow/core/framework/tracking_allocator.h
+++ b/tensorflow/core/framework/tracking_allocator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_TRACKING_ALLOCATOR_H_
-#define TENSORFLOW_FRAMEWORK_TRACKING_ALLOCATOR_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_TRACKING_ALLOCATOR_H_
+#define TENSORFLOW_CORE_FRAMEWORK_TRACKING_ALLOCATOR_H_
#include <unordered_map>
#include "tensorflow/core/framework/allocator.h"
@@ -130,4 +130,4 @@ class TrackingAllocator : public Allocator {
} // end namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_TRACKING_ALLOCATOR_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_TRACKING_ALLOCATOR_H_
diff --git a/tensorflow/core/framework/type_index.h b/tensorflow/core/framework/type_index.h
index b978d90fa8..989fc42e26 100644
--- a/tensorflow/core/framework/type_index.h
+++ b/tensorflow/core/framework/type_index.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_TYPE_INDEX_H_
-#define TENSORFLOW_FRAMEWORK_TYPE_INDEX_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_TYPE_INDEX_H_
+#define TENSORFLOW_CORE_FRAMEWORK_TYPE_INDEX_H_
#include <string>
#if defined(__GXX_RTTI) || defined(_CPPRTTI)
@@ -84,4 +84,4 @@ inline TypeIndex MakeTypeIndex() {
#endif // __GXX_RTTI
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_TYPE_INDEX_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_TYPE_INDEX_H_
diff --git a/tensorflow/core/framework/type_traits.h b/tensorflow/core/framework/type_traits.h
index e8351e494f..96fbf92938 100644
--- a/tensorflow/core/framework/type_traits.h
+++ b/tensorflow/core/framework/type_traits.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_
-#define TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_TYPE_TRAITS_H_
+#define TENSORFLOW_CORE_FRAMEWORK_TYPE_TRAITS_H_
#include <limits>
#include <utility>
@@ -106,4 +106,4 @@ struct is_signed<tensorflow::qint32> : public is_signed<tensorflow::int32> {};
} // namespace std
-#endif // TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_TYPE_TRAITS_H_
diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h
index ff7c9855d6..15b1add2c1 100644
--- a/tensorflow/core/framework/types.h
+++ b/tensorflow/core/framework/types.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_TYPES_H_
-#define TENSORFLOW_FRAMEWORK_TYPES_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_TYPES_H_
+#define TENSORFLOW_CORE_FRAMEWORK_TYPES_H_
#include <map>
#include <set>
@@ -481,4 +481,4 @@ bool DataTypeAlwaysOnHost(DataType dt);
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_TYPES_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_TYPES_H_
diff --git a/tensorflow/core/framework/variant.h b/tensorflow/core/framework/variant.h
index c02391dae3..52732801a0 100644
--- a/tensorflow/core/framework/variant.h
+++ b/tensorflow/core/framework/variant.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_VARIANT_H_
-#define TENSORFLOW_FRAMEWORK_VARIANT_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_H_
+#define TENSORFLOW_CORE_FRAMEWORK_VARIANT_H_
#include <functional>
#include <iostream>
@@ -351,4 +351,4 @@ const void* Variant::get() const;
} // end namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_VARIANT_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_H_
diff --git a/tensorflow/core/framework/variant_encode_decode.h b/tensorflow/core/framework/variant_encode_decode.h
index ded04b2a30..f155aa4892 100644
--- a/tensorflow/core/framework/variant_encode_decode.h
+++ b/tensorflow/core/framework/variant_encode_decode.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
-#define TENSORFLOW_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
+#define TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
#include <iostream>
#include <type_traits>
@@ -271,4 +271,4 @@ bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d,
} // end namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h
index c9e8dd2217..e6a2665a56 100644
--- a/tensorflow/core/framework/variant_op_registry.h
+++ b/tensorflow/core/framework/variant_op_registry.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_
-#define TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
+#define TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
#include <string>
#include <unordered_set>
@@ -580,4 +580,4 @@ class UnaryVariantBinaryOpRegistration {
} // end namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
diff --git a/tensorflow/core/framework/variant_tensor_data.h b/tensorflow/core/framework/variant_tensor_data.h
index 1d87bc341a..7500e77d43 100644
--- a/tensorflow/core/framework/variant_tensor_data.h
+++ b/tensorflow/core/framework/variant_tensor_data.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_VARIANT_TENSOR_DATA_H
-#define TENSORFLOW_FRAMEWORK_VARIANT_TENSOR_DATA_H
+#ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_TENSOR_DATA_H_
+#define TENSORFLOW_CORE_FRAMEWORK_VARIANT_TENSOR_DATA_H_
#include <algorithm>
#include <vector>
@@ -112,4 +112,4 @@ string ProtoDebugString(const VariantTensorData& object);
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_VARIANT_TENSOR_DATA_H
+#endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_TENSOR_DATA_H_
diff --git a/tensorflow/core/graph/algorithm.h b/tensorflow/core/graph/algorithm.h
index 5bbbc6f6dc..45f8a29a92 100644
--- a/tensorflow/core/graph/algorithm.h
+++ b/tensorflow/core/graph/algorithm.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_ALGORITHM_H_
-#define TENSORFLOW_GRAPH_ALGORITHM_H_
+#ifndef TENSORFLOW_CORE_GRAPH_ALGORITHM_H_
+#define TENSORFLOW_CORE_GRAPH_ALGORITHM_H_
#include <functional>
#include <unordered_set>
@@ -117,4 +117,4 @@ bool FixupSourceAndSinkEdges(Graph* g);
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_ALGORITHM_H_
+#endif // TENSORFLOW_CORE_GRAPH_ALGORITHM_H_
diff --git a/tensorflow/core/graph/colors.h b/tensorflow/core/graph/colors.h
index c1e1940cac..43d2225571 100644
--- a/tensorflow/core/graph/colors.h
+++ b/tensorflow/core/graph/colors.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_COLORS_H_
-#define TENSORFLOW_GRAPH_COLORS_H_
+#ifndef TENSORFLOW_CORE_GRAPH_COLORS_H_
+#define TENSORFLOW_CORE_GRAPH_COLORS_H_
namespace tensorflow {
@@ -26,4 +26,4 @@ const char* ColorFor(int dindex);
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_COLORS_H_
+#endif // TENSORFLOW_CORE_GRAPH_COLORS_H_
diff --git a/tensorflow/core/graph/control_flow.h b/tensorflow/core/graph/control_flow.h
index 548820720b..5abe77f5a1 100644
--- a/tensorflow/core/graph/control_flow.h
+++ b/tensorflow/core/graph/control_flow.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_CONTROL_FLOW_H_
-#define TENSORFLOW_GRAPH_CONTROL_FLOW_H_
+#ifndef TENSORFLOW_CORE_GRAPH_CONTROL_FLOW_H_
+#define TENSORFLOW_CORE_GRAPH_CONTROL_FLOW_H_
#include <vector>
@@ -48,4 +48,4 @@ Status BuildControlFlowInfo(const Graph* g, std::vector<ControlFlowInfo>* info,
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_CONTROL_FLOW_H_
+#endif // TENSORFLOW_CORE_GRAPH_CONTROL_FLOW_H_
diff --git a/tensorflow/core/graph/costmodel.h b/tensorflow/core/graph/costmodel.h
index 9b703e4693..2d94dd5cdc 100644
--- a/tensorflow/core/graph/costmodel.h
+++ b/tensorflow/core/graph/costmodel.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_COSTMODEL_H_
-#define TENSORFLOW_GRAPH_COSTMODEL_H_
+#ifndef TENSORFLOW_CORE_GRAPH_COSTMODEL_H_
+#define TENSORFLOW_CORE_GRAPH_COSTMODEL_H_
#include <unordered_map>
#include <vector>
@@ -229,4 +229,4 @@ class CostModel {
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_COSTMODEL_H_
+#endif // TENSORFLOW_CORE_GRAPH_COSTMODEL_H_
diff --git a/tensorflow/core/graph/default_device.h b/tensorflow/core/graph/default_device.h
index 68d7c8e553..f0f53c91f4 100644
--- a/tensorflow/core/graph/default_device.h
+++ b/tensorflow/core/graph/default_device.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_DEFAULT_DEVICE_H_
-#define TENSORFLOW_GRAPH_DEFAULT_DEVICE_H_
+#ifndef TENSORFLOW_CORE_GRAPH_DEFAULT_DEVICE_H_
+#define TENSORFLOW_CORE_GRAPH_DEFAULT_DEVICE_H_
#include <string>
@@ -38,4 +38,4 @@ inline void SetDefaultDevice(const string& device, GraphDef* graph_def) {
} // namespace graph
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_DEFAULT_DEVICE_H_
+#endif // TENSORFLOW_CORE_GRAPH_DEFAULT_DEVICE_H_
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 568f0870c0..1630ab7a15 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -483,7 +483,7 @@ const Edge* Graph::AddControlEdge(Node* source, Node* dest,
void Graph::RemoveControlEdge(const Edge* e) {
if (!e->src_->IsSource() && !e->dst_->IsSink()) {
e->dst_->MaybeCopyOnWrite();
- std::string e_src_name = strings::StrCat("^", e->src_->name());
+ string e_src_name = strings::StrCat("^", e->src_->name());
auto* inputs = e->dst_->props_->node_def.mutable_input();
for (auto it = inputs->begin(); it != inputs->end(); ++it) {
if (*it == e_src_name) {
@@ -495,6 +495,15 @@ void Graph::RemoveControlEdge(const Edge* e) {
RemoveEdge(e);
}
+namespace {
+const Edge* FindEdge(const Node* dst, int index) {
+ for (const Edge* e : dst->in_edges()) {
+ if (e->dst_input() == index) return e;
+ }
+ return nullptr;
+}
+} // namespace
+
Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
int dst_index) {
TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
@@ -512,17 +521,6 @@ Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
return Status::OK();
}
-const Edge* Graph::FindEdge(const Node* dst, int index) {
- for (const Edge* e : edges_) {
- // edges_ will contain null edges if RemoveEdge() was called.
- if (e == nullptr) continue;
- if (e->dst() == dst && e->dst_input() == index) {
- return e;
- }
- }
- return nullptr;
-}
-
Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
// Need a new-enough consumer to support the functions we add to the graph.
if (fdef_lib.function_size() > 0 && versions_->min_consumer() < 12) {
@@ -721,7 +719,7 @@ Status Graph::AddWhileContext(StringPiece frame_name,
std::vector<OutputTensor> body_outputs,
WhileContext** result) {
auto pair = while_ctxs_.insert(std::pair<string, WhileContext>(
- std::string(frame_name),
+ string(frame_name),
WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
cond_output, std::move(body_inputs),
std::move(body_outputs))));
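Note: UpdateEdge's helper changes from a Graph member that scanned the whole edges_ vector (skipping null slots left by RemoveEdge) to a file-local function that walks only the destination node's in-edges, so the lookup costs O(in-degree of dst) rather than O(total edges), and the null check is no longer needed. A toy stand-in for the new shape of the lookup (the structs below are simplified stand-ins, not TF types):

    #include <vector>

    struct Edge { int dst_input; };
    struct Node { std::vector<const Edge*> in_edges; };

    // Inspect only edges arriving at dst; cost is dst's in-degree.
    const Edge* FindEdge(const Node* dst, int index) {
      for (const Edge* e : dst->in_edges)
        if (e->dst_input == index) return e;
      return nullptr;
    }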
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index a147c94689..52e9f23a76 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -680,10 +680,6 @@ class Graph {
// AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
std::map<string, WhileContext> while_ctxs_;
- // Searches through edges_ for the Edge whose destination node and index
- // matches dst. An edge with destination `dst` must exist in the graph.
- const Edge* FindEdge(const Node* dst, int index);
-
TF_DISALLOW_COPY_AND_ASSIGN(Graph);
};
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 8c73f8f712..ee10194142 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -513,7 +513,7 @@ Status GraphConstructor::InitFromEdges() {
num_control_edges++;
} else {
TensorId id(ParseTensorName(input_name));
- if (next_iteration_nodes_.find(std::string(id.first)) !=
+ if (next_iteration_nodes_.find(string(id.first)) !=
next_iteration_nodes_.end()) {
has_loop_back_edge = true;
}
@@ -835,7 +835,7 @@ void GraphConstructor::UniquifyNames(
// We require that UniquifyNames() is called on all NodeDefs in topological
// order. This guarantees that node_def's inputs will already be uniquified
// if necessary.
- auto iter = uniquified_names_.find(std::string(id.first));
+ auto iter = uniquified_names_.find(string(id.first));
if (iter == uniquified_names_.end()) continue;
id.first = iter->second;
node_def->set_input(i, id.ToString());
@@ -854,7 +854,7 @@ void GraphConstructor::UpdateUniquifiedColocationNames() {
for (int i = 0; i < coloc_values.size(); ++i) {
StringPiece val(coloc_values[i]);
if (str_util::ConsumePrefix(&val, kColocationGroupPrefix)) {
- const auto& name_pair = uniquified_names_.find(std::string(val));
+ const auto& name_pair = uniquified_names_.find(string(val));
if (name_pair == uniquified_names_.end()) continue;
updated = true;
coloc_values[i] =
@@ -880,7 +880,7 @@ bool GraphConstructor::NameExistsInGraphDef(StringPiece name) {
}
string GraphConstructor::FindUniqueName(StringPiece original_name) {
- string name = std::string(original_name);
+ string name(original_name);
int count = 0;
// Check that any generated names don't collide with imported NodeDefs (as
// well as nodes in g_).
@@ -997,7 +997,7 @@ Status GraphConstructor::Convert() {
src_node->num_outputs(), " outputs");
}
- inputs.emplace_back(std::string(id.first), src_node, src_index);
+ inputs.emplace_back(string(id.first), src_node, src_index);
}
if (has_data_back_edge && !IsMerge(*node_def)) {
diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h
index 889359a68a..f6e41faf9c 100644
--- a/tensorflow/core/graph/graph_constructor.h
+++ b/tensorflow/core/graph/graph_constructor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_GRAPH_CONSTRUCTOR_H_
-#define TENSORFLOW_GRAPH_GRAPH_CONSTRUCTOR_H_
+#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_CONSTRUCTOR_H_
+#define TENSORFLOW_CORE_GRAPH_GRAPH_CONSTRUCTOR_H_
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
@@ -186,4 +186,4 @@ extern void CopyGraph(const Graph& src, Graph* dest);
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_GRAPH_CONSTRUCTOR_H_
+#endif // TENSORFLOW_CORE_GRAPH_GRAPH_CONSTRUCTOR_H_
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index e338840eeb..73142ebde7 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -156,9 +156,8 @@ class GraphConstructorTest : public ::testing::Test {
return "";
}
StringPiece loc(value[0]);
- return str_util::ConsumePrefix(&loc, kColocationGroupPrefix)
- ? std::string(loc)
- : "";
+ return str_util::ConsumePrefix(&loc, kColocationGroupPrefix) ? string(loc)
+ : "";
}
string GraphDebugString() const {
diff --git a/tensorflow/core/graph/graph_def_builder.cc b/tensorflow/core/graph/graph_def_builder.cc
index dd84c4f7c7..6d5df7efba 100644
--- a/tensorflow/core/graph/graph_def_builder.cc
+++ b/tensorflow/core/graph/graph_def_builder.cc
@@ -44,12 +44,12 @@ GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputs(
}
GraphDefBuilder::Options GraphDefBuilder::Options::WithNameImpl(
StringPiece name) {
- name_ = std::string(name);
+ name_ = string(name);
return *this;
}
GraphDefBuilder::Options GraphDefBuilder::Options::WithDeviceImpl(
StringPiece device) {
- device_ = std::string(device);
+ device_ = string(device);
return *this;
}
GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputImpl(
diff --git a/tensorflow/core/graph/graph_def_builder.h b/tensorflow/core/graph/graph_def_builder.h
index 0d6aae4355..400d8b6c84 100644
--- a/tensorflow/core/graph/graph_def_builder.h
+++ b/tensorflow/core/graph/graph_def_builder.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_
-#define TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_
+#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_
+#define TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_
#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
@@ -128,7 +128,7 @@ class GraphDefBuilder {
Options WithControlInputsImpl(gtl::ArraySlice<Node*> control_inputs);
template <class T>
Options WithAttrImpl(StringPiece name, T&& value) {
- attrs_.emplace_back(std::string(name), AttrValue());
+ attrs_.emplace_back(string(name), AttrValue());
SetAttrValue(std::forward<T>(value), &attrs_.back().second);
return *this;
}
@@ -203,4 +203,4 @@ Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b,
} // namespace ops
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_
+#endif // TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
index ea0a814ab8..1dbcebab59 100644
--- a/tensorflow/core/graph/graph_partition.cc
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -793,7 +793,7 @@ Status TopologicalSortNodesWithTimePriority(
for (int n = 0; n < gdef->node_size(); ++n) {
const NodeDef* ndef = &gdef->node(n);
for (int i = 0; i < ndef->input_size(); ++i) {
- node_to_output_nodes[std::string(ParseTensorName(ndef->input(i)).first)]
+ node_to_output_nodes[string(ParseTensorName(ndef->input(i)).first)]
.push_back(ndef);
}
int64 start_time;
diff --git a/tensorflow/core/graph/graph_partition.h b/tensorflow/core/graph/graph_partition.h
index 67fafddd51..8020c2d247 100644
--- a/tensorflow/core/graph/graph_partition.h
+++ b/tensorflow/core/graph/graph_partition.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
-#define TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
+#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_PARTITION_H_
+#define TENSORFLOW_CORE_GRAPH_GRAPH_PARTITION_H_
#include <functional>
#include <string>
@@ -95,4 +95,4 @@ Status AddControlEdges(const PartitionOptions& opts,
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
+#endif // TENSORFLOW_CORE_GRAPH_GRAPH_PARTITION_H_
diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h
index 333bf761b0..bab1df87a4 100644
--- a/tensorflow/core/graph/mkl_graph_util.h
+++ b/tensorflow/core/graph/mkl_graph_util.h
@@ -41,7 +41,7 @@ namespace tensorflow {
typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering;
// NOTE: Currently, we use contiguous ordering. If you change this, then you
// would need to change Mkl op definitions in nn_ops.cc.
-static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;
+static const MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;
// Get index of MetaData tensor from index 'n' of Data tensor.
inline int DataIndexToMetaDataIndex(int n, int total_tensors) {
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 5683944e46..2e644fe987 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -334,6 +334,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.conv2d_grad_input,
mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
CopyAttrsConv2D, AlwaysRewrite, nullptr});
+
rinfo_.push_back({csinfo_.fused_batch_norm,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
@@ -546,14 +547,14 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// If Op has been specifically assigned to a non-CPU device, then No.
if (!n->assigned_device_name().empty() &&
- !str_util::StrContains(n->assigned_device_name(),kCPUDeviceSubStr)) {
+ !str_util::StrContains(n->assigned_device_name(), kCPUDeviceSubStr)) {
result = false;
reason = "Op has been assigned a runtime device that is not CPU.";
}
// If user has specifically assigned this op to a non-CPU device, then No.
if (!n->def().device().empty() &&
- !str_util::StrContains(n->def().device(),kCPUDeviceSubStr)) {
+ !str_util::StrContains(n->def().device(), kCPUDeviceSubStr)) {
result = false;
reason = "User has assigned a device that is not CPU.";
}
@@ -1042,6 +1043,7 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
// device of the original
// node.
.Finalize(&**g, out));
+ CHECK_NOTNULL(*out); // Make sure we got a valid object before using it
// If number of inputs to the original node is > 0, then we add
// control dependency between 1st input (index 0) of the original node and
@@ -1335,6 +1337,7 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
// device of the original
// node.
.Finalize(&**g, out));
+ CHECK_NOTNULL(*out); // Make sure we got a valid object before using it
// If number of inputs to the original node is > 0, then we add
// control dependency between 1st input (index 0) of the original node and
@@ -2408,6 +2411,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.addn = "AddN";
csinfo_.avg_pool = "AvgPool";
csinfo_.avg_pool_grad = "AvgPoolGrad";
+ csinfo_.avg_pool3d = "AvgPool3D";
+ csinfo_.avg_pool3d_grad = "AvgPool3DGrad";
csinfo_.bias_add = "BiasAdd";
csinfo_.bias_add_grad = "BiasAddGrad";
csinfo_.concat = "Concat";
@@ -2418,6 +2423,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
csinfo_.conv2d_grad_filter_with_bias =
"__MklDummyConv2DBackpropFilterWithBias";
+ csinfo_.conv3d = "Conv3D";
+ csinfo_.conv3d_grad_input = "Conv3DBackpropInputV2";
+ csinfo_.conv3d_grad_filter = "Conv3DBackpropFilterV2";
csinfo_.fused_batch_norm = "FusedBatchNorm";
csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
csinfo_.identity = "Identity";
@@ -2426,6 +2434,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.matmul = "MatMul";
csinfo_.max_pool = "MaxPool";
csinfo_.max_pool_grad = "MaxPoolGrad";
+ csinfo_.max_pool3d = "MaxPool3D";
+ csinfo_.max_pool3d_grad = "MaxPool3DGrad";
csinfo_.mkl_conv2d = "_MklConv2D";
csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
@@ -2460,6 +2470,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.avg_pool_grad,
mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
CopyAttrsPooling, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.avg_pool3d,
+ mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d),
+ CopyAttrsPooling, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.avg_pool3d_grad,
+ mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d_grad),
+ CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.concat,
mkl_op_registry::GetMklOpName(csinfo_.concat),
CopyAttrsConcat, AlwaysRewrite});
@@ -2468,18 +2484,27 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
CopyAttrsConcatV2, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d,
mkl_op_registry::GetMklOpName(csinfo_.conv2d),
- CopyAttrsConv2D, AlwaysRewrite});
+ CopyAttrsConv, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias,
- CopyAttrsConv2D, AlwaysRewrite});
+ CopyAttrsConv, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_grad_filter,
mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
- CopyAttrsConv2D, AlwaysRewrite});
+ CopyAttrsConv, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias,
- csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv2D,
+ csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv,
AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_grad_input,
mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
- CopyAttrsConv2D, AlwaysRewrite});
+ CopyAttrsConv, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.conv3d,
+ mkl_op_registry::GetMklOpName(csinfo_.conv3d),
+ CopyAttrsConv, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.conv3d_grad_filter,
+ mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_filter),
+ CopyAttrsConv, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.conv3d_grad_input,
+ mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_input),
+ CopyAttrsConv, AlwaysRewrite});
rinfo_.push_back({csinfo_.fused_batch_norm,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
CopyAttrsFusedBatchNorm, AlwaysRewrite});
@@ -2501,7 +2526,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.max_pool_grad,
mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
CopyAttrsPooling, MaxpoolGradRewrite});
-
+ rinfo_.push_back({csinfo_.max_pool3d,
+ mkl_op_registry::GetMklOpName(csinfo_.max_pool3d),
+ CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
+ rinfo_.push_back({csinfo_.max_pool3d_grad,
+ mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad),
+ CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.maximum,
mkl_op_registry::GetMklOpName(csinfo_.maximum),
CopyAttrsDataType, AlwaysRewrite});
@@ -2538,6 +2568,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// Add info about which ops to add workspace edge to and the slots.
wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});
+ wsinfo_.push_back(
+ {csinfo_.max_pool3d, csinfo_.max_pool3d_grad, 0, 1, 1, 3});
// Add a rule for merging nodes
minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add,
@@ -2605,6 +2637,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string add;
string avg_pool;
string avg_pool_grad;
+ string avg_pool3d;
+ string avg_pool3d_grad;
string bias_add;
string bias_add_grad;
string concat;
@@ -2614,6 +2648,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string conv2d_grad_input;
string conv2d_grad_filter;
string conv2d_grad_filter_with_bias;
+ string conv3d;
+ string conv3d_grad_input;
+ string conv3d_grad_filter;
string fused_batch_norm;
string fused_batch_norm_grad;
string identity;
@@ -2622,6 +2659,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string matmul;
string max_pool;
string max_pool_grad;
+ string max_pool3d;
+ string max_pool3d_grad;
string maximum;
string mkl_conv2d;
string mkl_conv2d_grad_input;
@@ -3086,7 +3125,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb);
- static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb);
+ static void CopyAttrsConv(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb);
@@ -3177,6 +3216,7 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
// device of the original
// node.
.Finalize(&**g, out));
+ CHECK_NOTNULL(*out); // Make sure we got a valid object before using it
// If number of inputs to the original node is > 0, then we add
// control dependency between 1st input (index 0) of the original node and
@@ -3571,14 +3611,13 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
// Op-specific functions to copy attributes from old node to new node
//////////////////////////////////////////////////////////////////////////
-void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node,
- NodeBuilder* nb) {
+void MklLayoutRewritePass::CopyAttrsConv(const Node* orig_node,
+ NodeBuilder* nb) {
DataType T;
string data_format;
string padding;
std::vector<int32> strides;
std::vector<int32> dilations;
- bool use_cudnn_on_gpu;
// Get all attributes from old node.
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
@@ -3586,8 +3625,6 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node,
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
- TF_CHECK_OK(
- GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
// Add attributes to new node.
nb->Attr("T", T);
@@ -3595,7 +3632,6 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node,
nb->Attr("dilations", dilations);
nb->Attr("padding", padding);
nb->Attr("data_format", data_format);
- nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
}
void MklLayoutRewritePass::CopyAttrsAddN(const Node* orig_node,
@@ -3896,7 +3932,7 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd
// Copy attributes from Conv2D to Conv2DWithBias.
- CopyAttrsConv2D(const_cast<const Node*>(pred), &nb);
+ CopyAttrsConv(const_cast<const Node*>(pred), &nb);
// Copy the device assigned to old node to new node.
nb.Device(succ->def().device());
@@ -4007,7 +4043,7 @@ Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad(
}
// Copy attributes from Conv2DBackpropFilter.
- CopyAttrsConv2D(const_cast<const Node*>(fltr), &nb);
+ CopyAttrsConv(const_cast<const Node*>(fltr), &nb);
// Copy the device assigned to old node to new node.
nb.Device(fltr->def().device());
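Note: CopyAttrsConv2D becomes CopyAttrsConv so a single copier serves the new Conv3D rewrites as well; the only casualty is `use_cudnn_on_gpu`, a Conv2D-only attribute. Keeping it would be fatal for 3D ops, as this sketch of the removed lookup suggests (assuming Conv3D's op definition indeed lacks the attr):

    bool use_cudnn_on_gpu;
    TF_CHECK_OK(GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu",
                            &use_cudnn_on_gpu));  // ok for Conv2D, CHECK-fails for Conv3D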
diff --git a/tensorflow/core/graph/mkl_layout_pass.h b/tensorflow/core/graph/mkl_layout_pass.h
index ffe5c1ecfc..e7175149df 100644
--- a/tensorflow/core/graph/mkl_layout_pass.h
+++ b/tensorflow/core/graph/mkl_layout_pass.h
@@ -15,8 +15,8 @@ limitations under the License.
// A graph pass that rewrites graph for propagating MKL layout as a tensor
-#ifndef TENSORFLOW_GRAPH_MKL_LAYOUT_PASS_H_
-#define TENSORFLOW_GRAPH_MKL_LAYOUT_PASS_H_
+#ifndef TENSORFLOW_CORE_GRAPH_MKL_LAYOUT_PASS_H_
+#define TENSORFLOW_CORE_GRAPH_MKL_LAYOUT_PASS_H_
#ifdef INTEL_MKL
@@ -33,4 +33,4 @@ extern bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g);
#endif
-#endif // TENSORFLOW_GRAPH_MKL_LAYOUT_PASS_H_
+#endif // TENSORFLOW_CORE_GRAPH_MKL_LAYOUT_PASS_H_
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index aa39af637f..b67a321fc1 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -175,7 +175,11 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge(
.Finalize(&**g, &conversion_node));
CHECK_NOTNULL(conversion_node);
- if (GetNodeAttr(src->def(), "data_format", &data_format) == Status::OK()) {
+ // TODO(Intel-tf) MklToTf accepts only NHWC or NCHW, but doesn't seem to be
+ // using data_format. This code might be redundant.
+ if (GetNodeAttr(src->def(), "data_format", &data_format) == Status::OK() &&
+ (data_format == ToString(FORMAT_NHWC) ||
+ data_format == ToString(FORMAT_NCHW))) {
conversion_node->AddAttr("data_format", data_format);
}
@@ -254,9 +258,13 @@ Status MklToTfConversionPass::InsertInputConversionNode(
}
}
+ // TODO(Intel-tf) MklInputConversion accepts only NHWC or NCHW, but doesn't
+ // seem to be using data_format. This code might be redundant.
string data_format;
if (GetNodeAttr(edges[0]->src()->def(), "data_format", &data_format) ==
- Status::OK()) {
+ Status::OK() &&
+ (data_format == ToString(FORMAT_NHWC) ||
+ data_format == ToString(FORMAT_NCHW))) {
conversion_node->AddAttr("data_format", data_format);
}
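Note: both conversion-node paths now copy `data_format` only when its value is NHWC or NCHW, matching what the Mkl conversion ops accept. The guard could equally be factored into a helper; an illustrative sketch, not part of the patch:

    bool IsMklSupportedDataFormat(const string& df) {
      return df == ToString(FORMAT_NHWC) || df == ToString(FORMAT_NCHW);
    }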
diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc
index 03f3bbd663..a446e0d136 100644
--- a/tensorflow/core/graph/node_builder.cc
+++ b/tensorflow/core/graph/node_builder.cc
@@ -30,7 +30,7 @@ NodeBuilder::NodeOut::NodeOut(Node* n, int32 i) // NOLINT(runtime/explicit)
dt(SafeGetOutput(node, i, &error)) {}
NodeBuilder::NodeOut::NodeOut(StringPiece n, int32 i, DataType t)
- : node(nullptr), error(false), name(std::string(n)), index(i), dt(t) {}
+ : node(nullptr), error(false), name(n), index(i), dt(t) {}
NodeBuilder::NodeOut::NodeOut()
: node(nullptr), error(true), index(0), dt(DT_FLOAT) {}
diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h
index f6b7b5674b..4727ee7b56 100644
--- a/tensorflow/core/graph/node_builder.h
+++ b/tensorflow/core/graph/node_builder.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_NODE_BUILDER_H_
-#define TENSORFLOW_GRAPH_NODE_BUILDER_H_
+#ifndef TENSORFLOW_CORE_GRAPH_NODE_BUILDER_H_
+#define TENSORFLOW_CORE_GRAPH_NODE_BUILDER_H_
#include <vector>
#include "tensorflow/core/framework/node_def_builder.h"
@@ -160,4 +160,4 @@ NodeBuilder& NodeBuilder::Attr(StringPiece attr_name,
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_NODE_BUILDER_H_
+#endif // TENSORFLOW_CORE_GRAPH_NODE_BUILDER_H_
diff --git a/tensorflow/core/graph/optimizer_cse.h b/tensorflow/core/graph/optimizer_cse.h
index b8f3230c70..ef466fb788 100644
--- a/tensorflow/core/graph/optimizer_cse.h
+++ b/tensorflow/core/graph/optimizer_cse.h
@@ -15,8 +15,8 @@ limitations under the License.
// An optimization pass that performs common subexpression elimination.
-#ifndef TENSORFLOW_GRAPH_OPTIMIZER_CSE_H_
-#define TENSORFLOW_GRAPH_OPTIMIZER_CSE_H_
+#ifndef TENSORFLOW_CORE_GRAPH_OPTIMIZER_CSE_H_
+#define TENSORFLOW_CORE_GRAPH_OPTIMIZER_CSE_H_
#include <sys/types.h>
#include "tensorflow/core/graph/graph.h"
@@ -34,4 +34,4 @@ extern bool OptimizeCSE(Graph* g,
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_OPTIMIZER_CSE_H_
+#endif // TENSORFLOW_CORE_GRAPH_OPTIMIZER_CSE_H_
diff --git a/tensorflow/core/graph/quantize_training.h b/tensorflow/core/graph/quantize_training.h
index 2bb4ee1cf0..dc3d7e3b1f 100644
--- a/tensorflow/core/graph/quantize_training.h
+++ b/tensorflow/core/graph/quantize_training.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_QUANTIZE_TRAINING_H_
-#define TENSORFLOW_GRAPH_QUANTIZE_TRAINING_H_
+#ifndef TENSORFLOW_CORE_GRAPH_QUANTIZE_TRAINING_H_
+#define TENSORFLOW_CORE_GRAPH_QUANTIZE_TRAINING_H_
#include "tensorflow/core/graph/graph.h"
@@ -53,4 +53,4 @@ Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef,
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_QUANTIZE_TRAINING_H_
+#endif // TENSORFLOW_CORE_GRAPH_QUANTIZE_TRAINING_H_
diff --git a/tensorflow/core/graph/subgraph.h b/tensorflow/core/graph/subgraph.h
index ba35846d93..3e99ff0c8c 100644
--- a/tensorflow/core/graph/subgraph.h
+++ b/tensorflow/core/graph/subgraph.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_SUBGRAPH_H_
-#define TENSORFLOW_GRAPH_SUBGRAPH_H_
+#ifndef TENSORFLOW_CORE_GRAPH_SUBGRAPH_H_
+#define TENSORFLOW_CORE_GRAPH_SUBGRAPH_H_
#include <string>
@@ -162,4 +162,4 @@ class SendFetchRewrite : public PruneRewrite {
} // namespace subgraph
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_SUBGRAPH_H_
+#endif // TENSORFLOW_CORE_GRAPH_SUBGRAPH_H_
diff --git a/tensorflow/core/graph/tensor_id.cc b/tensorflow/core/graph/tensor_id.cc
index 80c76df255..5a5b85e727 100644
--- a/tensorflow/core/graph/tensor_id.cc
+++ b/tensorflow/core/graph/tensor_id.cc
@@ -25,7 +25,7 @@ namespace tensorflow {
TensorId::TensorId(const SafeTensorId& id) : TensorId(id.first, id.second) {}
SafeTensorId::SafeTensorId(const TensorId& id)
- : SafeTensorId(id.first.ToString(), id.second) {}
+ : SafeTensorId(string(id.first), id.second) {}
TensorId ParseTensorName(const string& name) {
return ParseTensorName(StringPiece(name.data(), name.size()));
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc
index ea7788f654..0a38aa1c91 100644
--- a/tensorflow/core/graph/testlib.cc
+++ b/tensorflow/core/graph/testlib.cc
@@ -485,6 +485,33 @@ Node* DiagPart(Graph* g, Node* in, DataType type) {
return ret;
}
+Node* CheckNumerics(Graph* g, Node* in, const string& message) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "CheckNumerics")
+ .Input(in)
+ .Attr("message", message)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Arg(Graph* g, int64 index, DataType type) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Arg")
+ .Attr("T", type)
+ .Attr("index", index)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Retval(Graph* g, int64 index, Node* in) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Retval")
+ .Input(in)
+ .Attr("index", index)
+ .Finalize(g, &ret));
+ return ret;
+}
+
void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); }
} // end namespace graph
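Note: the three new testlib helpers make it easy to build function-shaped test graphs. A minimal usage sketch, assuming the usual test::graph namespace and the global op registry:

    Graph* g = new Graph(OpRegistry::Global());
    Node* in = test::graph::Arg(g, 0, DT_FLOAT);
    Node* checked = test::graph::CheckNumerics(g, in, "bad value in arg 0");
    test::graph::Retval(g, 0, checked);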
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
index eb9038d619..bd0284d43a 100644
--- a/tensorflow/core/graph/testlib.h
+++ b/tensorflow/core/graph/testlib.h
@@ -15,8 +15,8 @@ limitations under the License.
// DEPRECATED: Use the C++ API defined in tensorflow/cc instead.
-#ifndef TENSORFLOW_GRAPH_TESTLIB_H_
-#define TENSORFLOW_GRAPH_TESTLIB_H_
+#ifndef TENSORFLOW_CORE_GRAPH_TESTLIB_H_
+#define TENSORFLOW_CORE_GRAPH_TESTLIB_H_
#include <string>
#include <vector>
@@ -209,8 +209,17 @@ Node* Diag(Graph* g, Node* in, DataType type);
// Add a DiagPart node in "g".
Node* DiagPart(Graph* g, Node* in, DataType type);
+// Add a CheckNumerics node in "g".
+Node* CheckNumerics(Graph* g, Node* in, const string& message);
+
+// Add an _Arg node in "g".
+Node* Arg(Graph* g, int64 index, DataType type);
+
+// Add a _Retval node in "g".
+Node* Retval(Graph* g, int64 index, Node* in);
+
} // end namespace graph
} // end namespace test
} // end namespace tensorflow
-#endif // TENSORFLOW_GRAPH_TESTLIB_H_
+#endif // TENSORFLOW_CORE_GRAPH_TESTLIB_H_
diff --git a/tensorflow/core/graph/types.h b/tensorflow/core/graph/types.h
index c707809927..ac5a7f8229 100644
--- a/tensorflow/core/graph/types.h
+++ b/tensorflow/core/graph/types.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_TYPES_H_
-#define TENSORFLOW_GRAPH_TYPES_H_
+#ifndef TENSORFLOW_CORE_GRAPH_TYPES_H_
+#define TENSORFLOW_CORE_GRAPH_TYPES_H_
#include "tensorflow/core/lib/gtl/int_type.h"
#include "tensorflow/core/platform/types.h"
@@ -32,4 +32,4 @@ TF_LIB_GTL_DEFINE_INT_TYPE(Bytes, int64);
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_TYPES_H_
+#endif // TENSORFLOW_CORE_GRAPH_TYPES_H_
diff --git a/tensorflow/core/graph/while_context.cc b/tensorflow/core/graph/while_context.cc
index 1b38aac35d..8e89bc4c75 100644
--- a/tensorflow/core/graph/while_context.cc
+++ b/tensorflow/core/graph/while_context.cc
@@ -23,7 +23,7 @@ WhileContext::WhileContext(StringPiece frame_name,
OutputTensor cond_output,
std::vector<OutputTensor> body_inputs,
std::vector<OutputTensor> body_outputs)
- : frame_name_(std::string(frame_name)),
+ : frame_name_(frame_name),
enter_nodes_(std::move(enter_nodes)),
exit_nodes_(std::move(exit_nodes)),
cond_output_(cond_output),
diff --git a/tensorflow/core/graph/while_context.h b/tensorflow/core/graph/while_context.h
index 2a83eb7bd8..5405e62be2 100644
--- a/tensorflow/core/graph/while_context.h
+++ b/tensorflow/core/graph/while_context.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_WHILE_CONTEXT_H_
-#define TENSORFLOW_GRAPH_WHILE_CONTEXT_H_
+#ifndef TENSORFLOW_CORE_GRAPH_WHILE_CONTEXT_H_
+#define TENSORFLOW_CORE_GRAPH_WHILE_CONTEXT_H_
#include "tensorflow/core/graph/graph.h"
@@ -73,4 +73,4 @@ class WhileContext {
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_GRAPH_H_
+#endif // TENSORFLOW_CORE_GRAPH_WHILE_CONTEXT_H_
diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc
index 6ca379323e..7171ae059b 100644
--- a/tensorflow/core/grappler/clusters/cluster.cc
+++ b/tensorflow/core/grappler/clusters/cluster.cc
@@ -81,6 +81,8 @@ void Cluster::DisableOptimizer(bool disable) {
rewriter_config->set_dependency_optimization(RewriterConfig::OFF);
rewriter_config->set_constant_folding(RewriterConfig::OFF);
rewriter_config->set_memory_optimization(RewriterConfig::NO_MEM_OPT);
+ rewriter_config->set_shape_optimization(RewriterConfig::OFF);
+ rewriter_config->set_remapping(RewriterConfig::OFF);
rewriter_config->mutable_auto_parallel()->set_enable(false);
rewriter_config->clear_optimizers();
} else {
diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc
index 12e3e46f65..f543dca49e 100644
--- a/tensorflow/core/grappler/clusters/virtual_cluster.cc
+++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc
@@ -45,6 +45,8 @@ VirtualCluster::VirtualCluster(const DeviceSet* device_set)
for (const auto& device : device_set_->devices()) {
DeviceProperties props = GetDeviceInfo(device->parsed_name());
if (props.type() == "UNKNOWN") continue;
+ auto attrs = device->attributes();
+ props.set_memory_size(attrs.memory_limit());
devices_[device->name()] = props;
}
}
diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
index a60e3c7a9f..0690640ffa 100644
--- a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <limits>
#include <unordered_map>
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/types.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
diff --git a/tensorflow/core/grappler/costs/graph_memory.cc b/tensorflow/core/grappler/costs/graph_memory.cc
index a5736d40b1..b01aca610a 100644
--- a/tensorflow/core/grappler/costs/graph_memory.cc
+++ b/tensorflow/core/grappler/costs/graph_memory.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 231c7c63be..d24e7e8ee4 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
@@ -428,18 +429,22 @@ class SymbolicShapeRefiner {
// perform shape inference on the function body.
//
// Propagate shape information of final function body node
- // to function node `node`.
+ // to function node `function_node`.
//
- // In the event of an error, UpdateNode will simply set `node`'s
+ // In the event of an error, UpdateNode will simply set `function_node`'s
// output shape to be Unknown.
- Status UpdateFunction(const NodeDef* node) {
- auto it = fun_to_grappler_function_item_.find(node->op());
+ Status UpdateFunction(const NodeDef* function_node) {
+ auto it = fun_to_grappler_function_item_.find(function_node->op());
if (it == fun_to_grappler_function_item_.end()) {
return errors::InvalidArgument(
- node->op(), " was not previously added to SymbolicShapeRefiner.");
+ function_node->op(),
+ " was not previously added to SymbolicShapeRefiner.");
}
- GrapplerFunctionItem& grappler_function_item = it->second;
+ // Copy (not reference) so that changes we make here (e.g., replacing
+ // Placeholder with Const) don't affect the cached item in
+ // fun_to_grappler_function_item_.
+ GrapplerFunctionItem grappler_function_item = it->second;
GraphView gv(&grappler_function_item.graph);
// Forward shapes from function input nodes to argument nodes.
@@ -452,7 +457,7 @@ class SymbolicShapeRefiner {
"supported.");
}
NodeDef* fun_node = gv.GetNode(fun_input.input_name);
- const string& input = node->input(i);
+ const string& input = function_node->input(i);
const string& node_name = NodeName(input);
if (IsControlInput(input)) {
@@ -477,16 +482,35 @@ class SymbolicShapeRefiner {
TensorShapeProto proto;
const auto& handle = input_inference_context->output(output_port_num);
input_inference_context->ShapeHandleToProto(handle, &proto);
+ // SymbolicShapeRefiner may produce dim sizes < -1; clamp those to -1, the
+ // only "unknown" size TensorShapeProto defines.
+ for (int i = 0; i < proto.dim_size(); i++) {
+ if (proto.dim(i).size() < -1) {
+ proto.mutable_dim(i)->set_size(-1);
+ }
+ }
*attr_output_shape.mutable_shape() = proto;
(*fun_node->mutable_attr())["shape"] = attr_output_shape;
}
+ // Replace input Placeholders with Consts when their values are known. Input
+ // validity is not re-checked here; the loop above already did that.
+ for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
+ const string& input = function_node->input(i);
+ const string& node_name = NodeName(input);
+ NodeDef* input_node = graph_.GetNode(node_name);
+ // TODO(dyoon): also use Const when output_tensors_as_shape is available.
+ if (IsConstant(*input_node)) {
+ TF_CHECK_OK(
+ ReplaceInputWithConst(*input_node, i, &grappler_function_item));
+ }
+ }
+
// Perform inference on function body.
GraphProperties gp(grappler_function_item);
TF_RETURN_IF_ERROR(gp.InferStatically(true));
// Add return nodes for output shapes.
- auto ic = GetContext(node);
+ auto ic = GetContext(function_node);
int output = 0;
for (auto const& out_arg : grappler_function_item.outputs()) {
if (out_arg.output_tensors.size() > 1) {
@@ -504,8 +528,9 @@ class SymbolicShapeRefiner {
const NodeDef* retnode = gv.GetNode(node_name);
if (retnode == nullptr) {
- return errors::FailedPrecondition("Unable to find return node ",
- node_name, " for ", node->name());
+ return errors::FailedPrecondition(
+ "Unable to find return node ", node_name, " for ",
+ function_node->name());
}
auto output_properties = gp.GetOutputProperties(retnode->name());
@@ -670,11 +695,13 @@ class SymbolicShapeRefiner {
// true, as the updates to the call node will have changed, even if it's
// the same function being called twice with the same input shapes.
// Example: simple_function.pbtxt
- if (UpdateFunction(node).ok()) {
+ auto s = UpdateFunction(node);
+ if (s.ok()) {
return Status::OK();
} else {
VLOG(1) << "UpdateFunction failed for " << node->op()
- << ". Defaulting to ShapeUnknown.";
+ << ". Defaulting to ShapeUnknown.\n"
+ << s.ToString();
}
}
@@ -804,8 +831,9 @@ class SymbolicShapeRefiner {
CHECK_NOTNULL(function_library_.Find(function_node->op()));
GrapplerFunctionItem grappler_function_item;
- TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
- *function_def, function_library_, &grappler_function_item));
+ TF_RETURN_IF_ERROR(
+ MakeGrapplerFunctionItem(*function_def, function_library_,
+ graph_def_version_, &grappler_function_item));
if (grappler_function_item.inputs().size() > function_node->input_size()) {
return errors::FailedPrecondition(
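Note: the most consequential line in this hunk is the switch from `GrapplerFunctionItem&` to a by-value copy. UpdateFunction now mutates the item (clamping dims, swapping Placeholders for Consts), and mutating a reference into fun_to_grappler_function_item_ would corrupt the cache for every later call. A toy stand-in for that aliasing bug, with std::map playing the role of the cache:

    #include <map>
    #include <string>

    std::map<std::string, std::string> cache = {{"f", "Placeholder"}};

    void BadUpdate() {
      std::string& item = cache["f"];  // reference into the cache
      item = "Const";                  // every later lookup now sees "Const"
    }

    void GoodUpdate() {
      std::string item = cache["f"];   // copy, as the patch now does
      item = "Const";                  // cache entry stays "Placeholder"
    }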
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 5acfb56b05..3ec68a4e59 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -18,8 +18,10 @@ limitations under the License.
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
@@ -783,6 +785,97 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
EXPECT_EQ("float: [128,256]", PropToString(prop));
}
+TEST_F(GraphPropertiesTest, FunctionWithConstInput) {
+ FunctionDefLibrary library;
+ // This function is simply
+ // out = Fill(shape, value).
+ // Fill requires the values of its shape input, not just that input's shape,
+ // to infer the output shape; hence it exercises const value propagation
+ // into the function body.
+ *library.add_function() = FunctionDefHelper::Create(
+ // Name
+ "MyFillFunc",
+ // Inputs
+ {"shape: int32", "value: float"},
+ // Outputs
+ {"out: float"},
+ // Attrs
+ {},
+ // Nodes
+ {
+ {{"a"},
+ "Fill",
+ {"shape", "value"},
+ {{"T", DataType::DT_FLOAT}, {"index_type", DataType::DT_INT32}}},
+ },
+ // Returns
+ {{"out", "a:output:0"}});
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ TF_CHECK_OK(s.graph()->AddFunctionLibrary(library));
+ Output shape = ops::Const(s.WithOpName("shape"), {1, 2, 3, 4});
+ Output value = ops::Const(s.WithOpName("value"), 0.1f, {});
+ auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
+ s.graph()->op_registry());
+ tensorflow::Node* func_op;
+ auto _shape = tensorflow::ops::AsNodeOut(s, shape);
+ auto _value = tensorflow::ops::AsNodeOut(s, value);
+ TF_CHECK_OK(
+ builder.Input(_shape).Input(_value).Finalize(s.graph(), &func_op));
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("MyFillFunc");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
+ EXPECT_FALSE(out_prop0.shape().unknown_rank());
+ EXPECT_EQ(4, out_prop0.shape().dim_size());
+ EXPECT_EQ(1, out_prop0.shape().dim(0).size());
+ EXPECT_EQ(2, out_prop0.shape().dim(1).size());
+ EXPECT_EQ(3, out_prop0.shape().dim(2).size());
+ EXPECT_EQ(4, out_prop0.shape().dim(3).size());
+}
+
+TEST_F(GraphPropertiesTest, FunctionWithScalarInput) {
+ // Create graph with a function that takes a scalar value so that we use
+ // Placeholder with scalar as for input to the function shape inference.
+ // Placeholder -> Identity -> MyFunc, where MyFunc simply takes Identity of
+ // the input; all tensors are scalars.
+ FunctionDefLibrary library;
+ *library.add_function() = FunctionDefHelper::Create(
+ "MyFunc", // Name
+ {"x: float"}, // Inputs
+ {"out: float"}, // Outputs
+ {}, // Attrs
+ {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_FLOAT}}}}, // Nodes
+ {{"out", "a:output:0"}}); // Returns
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ TF_CHECK_OK(s.graph()->AddFunctionLibrary(library));
+ Output placeholder =
+ ops::Placeholder(s.WithOpName("Placeholder"), DataType::DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({})));
+ Output identity = ops::Identity(s.WithOpName("Identity"), placeholder);
+ auto _identity = tensorflow::ops::AsNodeOut(s, identity);
+ auto builder =
+ tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
+ tensorflow::Node* func_op;
+ TF_CHECK_OK(builder.Input(_identity).Finalize(s.graph(), &func_op));
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  // TensorFlow graphs with producer version < 21 infer the output shape of a
+  // Placeholder with an empty shape as unknown, instead of scalar.
+ EXPECT_GT(item.graph.versions().producer(), 21);
+
+ // MyFunc output shouldn't be unknown rank.
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(true));
+ const auto out_props = properties.GetOutputProperties("MyFunc");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
+ EXPECT_FALSE(out_prop0.shape().unknown_rank());
+}
+
TEST_F(GraphPropertiesTest, SimpleFunctionStaticShapeInference) {
// Test graph produced in python using:
/*
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 0341d7f8e1..71f4d9fd05 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/clusters/utils.h"
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index 9e579098ef..998bd59dce 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
index be54d98534..aad00ce039 100644
--- a/tensorflow/core/grappler/costs/utils.cc
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -99,7 +99,7 @@ static void ExtractExtraProperties(
continue;
}
TensorId input_tensor_id = ParseTensorName(input_name);
- const string input_node_name = input_tensor_id.first.ToString();
+ const string input_node_name(input_tensor_id.first);
auto iter = name_to_node.find(input_node_name);
if (iter == name_to_node.end()) continue;
@@ -172,7 +172,7 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures(
for (const auto& input_name : node.input()) {
CHECK(!input_name.empty());
TensorId input_tensor_id = ParseTensorName(input_name);
- const string input_node_name = input_tensor_id.first.ToString();
+ const string input_node_name(input_tensor_id.first);
const int output_index = input_tensor_id.second;
// Skip control inputs.
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 6e3ebdee12..037a823096 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -880,10 +880,15 @@ Costs VirtualScheduler::Summary() const {
// Print per device summary
VLOG(1) << "Devices:";
Costs critical_path_costs = Costs::ZeroCosts();
+ std::vector<string> device_names;
+ device_names.reserve(device_.size());
+ for (auto& it : device_) {
+ device_names.push_back(it.first);
+ }
+ std::sort(device_names.begin(), device_names.end());
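+  // Iterating over the sorted names keeps the per-device summary output
+  // below deterministic.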
- for (const auto& device : device_) {
- const auto& name = device.first;
- const auto& state = device.second;
+ for (const auto& name : device_names) {
+ const auto& state = device_.at(name);
std::map<string, int64> op_to_memory;
// First profile only persistent memory usage.
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
index b1373d8317..02a379fca8 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
diff --git a/tensorflow/core/grappler/graph_analyzer/BUILD b/tensorflow/core/grappler/graph_analyzer/BUILD
new file mode 100644
index 0000000000..d56a08d3c8
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/BUILD
@@ -0,0 +1,139 @@
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+
+cc_library(
+ name = "graph_analyzer_lib",
+ srcs = [
+ "gen_node.cc",
+ "graph_analyzer.cc",
+ "sig_node.cc",
+ "subgraph.cc",
+ ],
+ hdrs = [
+ "gen_node.h",
+ "graph_analyzer.h",
+ "hash_tools.h",
+ "map_tools.h",
+ "sig_node.h",
+ "subgraph.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ ],
+)
+
+cc_library(
+ name = "graph_analyzer_tool",
+ srcs = ["graph_analyzer_tool.cc"],
+ hdrs = ["graph_analyzer_tool.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_analyzer_lib",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core/grappler:grappler_item",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "test_tools_lib",
+ testonly = 1,
+ srcs = [
+ "test_tools.cc",
+ ],
+ hdrs = [
+ "test_tools.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_analyzer_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core/grappler:op_types",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ ],
+)
+
+tf_cc_test(
+ name = "hash_tools_test",
+ testonly = 1,
+ srcs = [
+ "hash_tools_test.cc",
+ ],
+ deps = [
+ ":graph_analyzer_lib",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+tf_cc_test(
+ name = "gen_node_test",
+ testonly = 1,
+ srcs = [
+ "gen_node_test.cc",
+ ],
+ deps = [
+ ":graph_analyzer_lib",
+ ":test_tools_lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+tf_cc_test(
+ name = "sig_node_test",
+ testonly = 1,
+ srcs = [
+ "sig_node_test.cc",
+ ],
+ deps = [
+ ":graph_analyzer_lib",
+ ":test_tools_lib",
+ "//tensorflow/core/grappler:utils",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+tf_cc_test(
+ name = "graph_analyzer_test",
+ testonly = 1,
+ srcs = [
+ "graph_analyzer_test.cc",
+ ],
+ deps = [
+ ":graph_analyzer_lib",
+ ":test_tools_lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+tf_cc_test(
+ name = "subgraph_test",
+ testonly = 1,
+ srcs = [
+ "subgraph_test.cc",
+ ],
+ deps = [
+ ":graph_analyzer_lib",
+ ":test_tools_lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/tensorflow/core/grappler/graph_analyzer/gen_node.cc b/tensorflow/core/grappler/graph_analyzer/gen_node.cc
new file mode 100644
index 0000000000..f8c15fd50e
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/gen_node.cc
@@ -0,0 +1,148 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/grappler/graph_analyzer/hash_tools.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+GenNode::GenNode(const NodeDef* node) : node_(node), op_(nullptr) {}
+
+Status GenNode::BuildGraphInMap(const GraphDef& source, GenNodeMap* map) {
+ for (const auto& n : source.node()) {
+ const string& name = n.name();
+ if (map->find(name) != map->end()) {
+ // This error code looks more meaningful than ALREADY_EXISTS.
+ return Status(error::INVALID_ARGUMENT,
+ "Duplicate node name '" + name + "'.");
+ }
+ (*map)[name] = absl::make_unique<GenNode>(&n);
+ }
+ // Now parse the links.
+ for (const auto& mapit : *map) {
+ Status st = mapit.second->ParseInputs(map);
+ if (!st.ok()) {
+ return st;
+ }
+ }
+ return Status::OK();
+}
+
+Status GenNode::ParseInputs(const GenNodeMap* map) {
+ all_inputs_or_none_ = false;
+ Status st = OpRegistry::Global()->LookUpOpDef(opcode(), &op_);
+ if (!st.ok()) {
+ return Status(
+ error::INVALID_ARGUMENT,
+ absl::StrFormat("Node '%s' contains an undefined operation '%s': %s",
+ name(), opcode(), st.error_message()));
+ }
+
+ int n_inputs = node_->input_size();
+
+ int n_named_inputs = op_->input_arg_size();
+
+ int n_multi_inputs = 0;
+ for (const auto& inarg : op_->input_arg()) {
+ if (!inarg.number_attr().empty() || !inarg.type_list_attr().empty()) {
+ ++n_multi_inputs;
+ }
+ }
+ bool is_commutative = grappler::IsCommutative(*node_);
+
+ if (n_multi_inputs > 1 || (n_multi_inputs > 0 && n_named_inputs > 1)) {
+    // Can't handle more than one multi-input at a time, nor commutativity
+    // of only some of the arguments rather than all of them.
+ is_commutative = false;
+ }
+
+ if (is_commutative) {
+    // If truly commutative, all the inputs can be treated as one multi-input.
+    // The commutative nodes could simply be treated as AllInputsOrNone, but
+    // handling the all-or-none property through a single input is a bit more
+    // efficient, and this code path may be worth extending in the future.
+ n_named_inputs = 1;
+ all_inputs_or_none_ = false;
+ } else if (n_multi_inputs > 0) {
+ all_inputs_or_none_ = true;
+ }
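+  // For example (exercised in gen_node_test.cc), Mul is commutative, so both
+  // of its inputs get collected on port i0; ShapeN is a non-commutative
+  // multi-input, so its node becomes all-inputs-or-none instead.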
+
+ for (int i = 0; i < n_inputs; ++i) {
+ int other_position;
+ string other_name = ParseNodeName(node_->input(i), &other_position);
+ auto other_it = map->find(other_name);
+ if (other_it == map->end()) {
+ return Status(
+ error::INVALID_ARGUMENT,
+ absl::StrFormat(
+ "Node '%s' input %d refers to a non-existing node '%s'.", name(),
+ i, other_name));
+ }
+ GenNode* other_node = other_it->second.get();
+
+ int this_position = other_position < 0 ? -1 : (is_commutative ? 0 : i);
+
+ if (this_position >= 0 && n_multi_inputs == 0 &&
+ this_position >= n_named_inputs) {
+ return Status(
+ error::INVALID_ARGUMENT,
+ absl::StrFormat(
+ "Node '%s' has a non-control input from '%s' at index %d but its "
+ "operation '%s' defines only %d inputs.",
+ name(), other_name, this_position, op_->name(), n_named_inputs));
+ }
+
+ Port this_port(/*inbound=*/true, this_position);
+ Port other_port(/*inbound=*/false, other_position);
+
+ links_[this_port].emplace_back(LinkTarget(other_node, other_port));
+ other_node->links_[other_port].emplace_back(LinkTarget(this, this_port));
+ }
+ return Status::OK();
+}
+
+bool GenNode::IsMultiInput(Port port) const {
+ if (!port.IsInbound()) {
+ return false;
+ }
+ auto it = links_.find(port);
+ if (it == links_.end()) {
+ return false; // Shouldn't happen.
+ }
+ return (it->second.size() > 1);
+}
+
+GenNode::Port::operator string() const {
+ string result = this->IsInbound() ? "i" : "o";
+ if (this->IsControl()) {
+ result.append("C");
+ } else {
+ result.append(absl::StrFormat("%d", this->Id()));
+ }
+ return result;
+}
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/gen_node.h b/tensorflow/core/grappler/graph_analyzer/gen_node.h
new file mode 100644
index 0000000000..faec9ecad8
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/gen_node.h
@@ -0,0 +1,167 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_
+
+#include <map>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+class GenNode;
+
+// To find nodes by name.
+using GenNodeMap = std::unordered_map<string, std::unique_ptr<GenNode>>;
+
+// One node in the graph, in a form convenient for traversal and generation of
+// subgraphs. It refers to the original NodeDef protobuf for most information
+// and adds extra enrichment on top.
+//
+// The graph building is 2-stage: first a GenNode is matched with each NodeDef
+// and collected into a map that finds them by name; then the map is processed,
+// deep-parsing the underlying NodeDefs and connecting the GenNodes together.
+class GenNode {
+ public:
+ // Will keep the pointer, so the underlying object must not be deleted while
+ // GenNode is alive.
+ explicit GenNode(const NodeDef* node);
+
+ // Access wrappers.
+ const string& name() const { return node_->name(); }
+ const string& opcode() const { return node_->op(); }
+ const NodeDef* node_def() const { return node_; }
+
+ // Parse the inputs of this node and update the map accordingly, creating the
+ // links (i.e. edges, connections between nodes) in itself and in the nodes
+ // it's linked to (the map itself is unchanged, only the nodes in it are
+ // updated).
+ Status ParseInputs(const GenNodeMap* map);
+
+ // Does the full 2-stage build of the graph. The map should be initially
+ // empty. The map keeps pointers to the nodes in source, so the source must
+ // not be destroyed before the map.
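+  //
+  // A minimal usage sketch (the graph_def name is hypothetical):
+  //   GenNodeMap map;
+  //   TF_RETURN_IF_ERROR(GenNode::BuildGraphInMap(graph_def, &map));
+  //   GenNode* node = map.at("node_name").get();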
+ static Status BuildGraphInMap(const GraphDef& source, GenNodeMap* map);
+
+ // The enrichment that constitutes the point of this class.
+
+ // Representation of a connection on a node.
+ class Port {
+ public:
+ // A port may be inbound or outbound.
+ // Negative ids (canonically -1) mean a control port.
+ Port(bool inbound, int32_t id) : value_(id << 1) {
+ if (inbound) {
+ value_ |= 1;
+ }
+ }
+ Port(const Port&) = default;
+ Port& operator=(const Port&) = default;
+
+ bool IsInbound() const { return (value_ & 0x1); }
+
+ bool IsControl() const { return (value_ < 0); }
+
+ int32_t Id() const {
+ // Arithmetic shift preserves the sign.
+ return (value_ >> 1);
+ }
+
+ // Integer type used to represent the encoded port value.
+ using IntPort = int32_t;
+
+ // Returns the encoded form of this port, so that it can be used
+ // as various map indexes.
+ IntPort Encoded() const { return value_; }
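+    // For instance, with this encoding, Port(true, 2).Encoded() == 5,
+    // Port(false, 2).Encoded() == 4, and the control port Port(true, -1)
+    // encodes to -1 (a negative encoded value always denotes a control port).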
+
+ static Port Decode(IntPort encoded) { return Port(encoded); }
+
+ bool operator==(const Port& other) const { return value_ == other.value_; }
+ bool operator<(const Port& other) const { return value_ < other.value_; }
+
+ struct Hasher {
+ size_t operator()(const Port& port) const noexcept {
+ return hasher(port.Encoded());
+ }
+ std::hash<int32_t> hasher;
+ };
+
+    // Convenient for printing. I'd have liked it to be implicit, but
+    // ClangTidy insists on making it explicit.
+ explicit operator string() const;
+
+ private:
+ explicit Port(IntPort value) : value_(value) {}
+
+ IntPort value_;
+ };
+
+ struct LinkTarget {
+ GenNode* node; // Node where this link points.
+ Port port; // Port on the remote side of this link.
+
+ LinkTarget(GenNode* a_node, Port a_port) : node(a_node), port(a_port) {}
+ };
+  // All the links that are connected to the same port of this node
+  // are collected in one vector. A link is an edge of the graph that connects
+  // two nodes. Each of the connected nodes has its own perspective on the
+  // link, seeing its local port, the remote port and the remote node. The
+  // direction of the link is encoded in the ports: one port is always inbound
+  // and the other outbound.
+ using LinkTargetVector = std::vector<LinkTarget>;
+ // Both inputs and outputs are stored in the same map.
+ using LinkMap = std::unordered_map<Port, LinkTargetVector, Port::Hasher>;
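+  // In the notation used by DumpLinkMap() in the tests, an entry reads, e.g.,
+  // "i0: node1[o0]": inbound port 0 of this node is connected to outbound
+  // port 0 of node1.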
+
+ // Access to the link map.
+ const LinkMap& links() const { return links_; }
+
+ // Check whether the port is an input (including the controls) with multiple
+ // connections. Such inputs get handled in a special way when building the
+ // subgraphs, in an "all or nothing" fashion.
+ bool IsMultiInput(Port port) const;
+
+  // When building the subgraphs, either all the non-control inputs of this
+  // node must be included into the subgraph, or none of them. This happens
+  // when at least one of the inputs is a multi-input (or when the opcode is
+  // commutative, which treats all the inputs as one multi-input).
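+  // For example, nodes with ShapeN or IdentityN opcodes are all-or-none (see
+  // the corresponding cases in gen_node_test.cc).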
+ bool AllInputsOrNone() const { return all_inputs_or_none_; }
+
+ private:
+ const NodeDef* node_;
+ // Becomes valid only after ParseInputs().
+ const OpDef* op_;
+
+  // Set when the opcode has a complicated structure of input args, i.e.
+  // multi-input args that are not commutative. For the subgraphs that include
+  // this node to make sense, they must then also include either all of its
+  // inputs or none of them.
+ bool all_inputs_or_none_ = false;
+
+ LinkMap links_;
+};
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_
diff --git a/tensorflow/core/grappler/graph_analyzer/gen_node_test.cc b/tensorflow/core/grappler/graph_analyzer/gen_node_test.cc
new file mode 100644
index 0000000000..d77daf7849
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/gen_node_test.cc
@@ -0,0 +1,491 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
+#include "tensorflow/core/grappler/graph_analyzer/test_tools.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::Ne;
+
+TEST(GenNodeTest, Port) {
+ {
+ GenNode::Port p(true, 100);
+ EXPECT_THAT(p.IsInbound(), Eq(true));
+ EXPECT_THAT(p.IsControl(), Eq(false));
+ EXPECT_THAT(p.Id(), Eq(100));
+ GenNode::Port p2 = GenNode::Port::Decode(p.Encoded());
+ EXPECT_THAT(p2.IsInbound(), Eq(true));
+ EXPECT_THAT(p2.IsControl(), Eq(false));
+ EXPECT_THAT(p2.Id(), Eq(100));
+ }
+ {
+ GenNode::Port p(false, 0);
+ EXPECT_THAT(p.IsInbound(), Eq(false));
+ EXPECT_THAT(p.IsControl(), Eq(false));
+ EXPECT_THAT(p.Id(), Eq(0));
+ GenNode::Port p2 = GenNode::Port::Decode(p.Encoded());
+ EXPECT_THAT(p2.IsInbound(), Eq(false));
+ EXPECT_THAT(p2.IsControl(), Eq(false));
+ EXPECT_THAT(p2.Id(), Eq(0));
+ }
+ {
+ GenNode::Port p(true, -100);
+ EXPECT_THAT(p.IsInbound(), Eq(true));
+ EXPECT_THAT(p.IsControl(), Eq(true));
+ EXPECT_THAT(p.Id(), Eq(-100));
+ GenNode::Port p2 = GenNode::Port::Decode(p.Encoded());
+ EXPECT_THAT(p2.IsInbound(), Eq(true));
+ EXPECT_THAT(p2.IsControl(), Eq(true));
+ EXPECT_THAT(p2.Id(), Eq(-100));
+ }
+ {
+ GenNode::Port p(false, -1);
+ EXPECT_THAT(p.IsInbound(), Eq(false));
+ EXPECT_THAT(p.IsControl(), Eq(true));
+ EXPECT_THAT(p.Id(), Eq(-1));
+ GenNode::Port p2 = GenNode::Port::Decode(p.Encoded());
+ EXPECT_THAT(p2.IsInbound(), Eq(false));
+ EXPECT_THAT(p2.IsControl(), Eq(true));
+ EXPECT_THAT(p2.Id(), Eq(-1));
+ }
+}
+
+TEST(GenNodeTest, ParseNodeNoInputs) {
+ GenNodeMap map;
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ auto gn1 = map["node1"].get();
+ ASSERT_THAT(gn1->ParseInputs(&map), Eq(Status::OK()));
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre());
+}
+
+// A general operation, and a control link.
+TEST(GenNodeTest, ParseNodeWithControl) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ NodeDef node3 = MakeNodeSub("node3", "node1", "node2");
+ node3.add_input("^node1"); // The control link.
+ node3.add_input("^node2"); // The control link.
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+
+ auto gn1 = map["node1"].get();
+ auto gn2 = map["node2"].get();
+ auto gn3 = map["node3"].get();
+ ASSERT_THAT(gn3->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre(
+ "o0: node3[i0]",
+ "oC: node3[iC]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre(
+ "o0: node3[i1]",
+ "oC: node3[iC]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre(
+ "i0: node1[o0]",
+ "i1: node2[o0]",
+ "iC: node1[oC], node2[oC]"
+ ));
+ // clang-format on
+
+ EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, 0)), Eq(false));
+
+ // This is a multi-control-input.
+ EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, -1)), Eq(true));
+
+ EXPECT_FALSE(gn1->AllInputsOrNone());
+ EXPECT_FALSE(gn2->AllInputsOrNone());
+ EXPECT_FALSE(gn3->AllInputsOrNone());
+}
+
+// Commutative nodes are treated as having a single input,
+// because their inputs are equivalent.
+TEST(GenNodeTest, ParseNodeCommutative) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ // TODO(babkin): grappler::IsCommutative() should return true for Add but
+ // apparently doesn't. So use Mul in the meantime.
+ NodeDef node3 = MakeNodeMul("node3", "node1", "node2");
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+
+ auto gn1 = map["node1"].get();
+ auto gn2 = map["node2"].get();
+ auto gn3 = map["node3"].get();
+ ASSERT_THAT(gn3->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre(
+ "o0: node3[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre(
+ "o0: node3[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre(
+ "i0: node1[o0], node2[o0]"
+ ));
+ // clang-format on
+
+ EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, 0)), Eq(true));
+
+ EXPECT_FALSE(gn3->AllInputsOrNone());
+}
+
+TEST(GenNodeTest, ParseNodeMultiInputCommutative) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ NodeDef node3 = MakeNodeAddN("node3", "node1", "node2");
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+
+ auto gn1 = map["node1"].get();
+ auto gn2 = map["node2"].get();
+ auto gn3 = map["node3"].get();
+ ASSERT_THAT(gn3->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre(
+ "o0: node3[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre(
+ "o0: node3[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre(
+ "i0: node1[o0], node2[o0]"
+ ));
+ // clang-format on
+
+ // This is a multi-output.
+ EXPECT_THAT(gn2->IsMultiInput(GenNode::Port(false, 0)), Eq(false));
+ // This is a multi-input.
+ EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, 0)), Eq(true));
+
+ EXPECT_FALSE(gn3->AllInputsOrNone());
+}
+
+TEST(GenNodeTest, ParseNodeMultiInputNotCommutative) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ NodeDef node3 = MakeNodeShapeN("node3", "node1", "node2");
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+
+ auto gn1 = map["node1"].get();
+ auto gn2 = map["node2"].get();
+ auto gn3 = map["node3"].get();
+ ASSERT_THAT(gn3->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre(
+ "o0: node3[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre(
+ "o0: node3[i1]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre(
+ "i0: node1[o0]",
+ "i1: node2[o0]"
+ ));
+ // clang-format on
+
+ // Non-commutative multi-input doesn't count.
+ EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, 0)), Eq(false));
+ EXPECT_TRUE(gn3->AllInputsOrNone());
+}
+
+TEST(GenNodeTest, ParseNodeMultiInputList) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ NodeDef node3 = MakeNodeIdentityN("node3", "node1", "node2");
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+
+ auto gn1 = map["node1"].get();
+ auto gn2 = map["node2"].get();
+ auto gn3 = map["node3"].get();
+ ASSERT_THAT(gn3->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre(
+ "o0: node3[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre(
+ "o0: node3[i1]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre(
+ "i0: node1[o0]",
+ "i1: node2[o0]"
+ ));
+ // clang-format on
+
+ // Non-commutative multi-input doesn't count.
+ EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, 0)), Eq(false));
+ EXPECT_TRUE(gn3->AllInputsOrNone());
+}
+
+TEST(GenNodeTest, ParseNodeMultiMultiInput) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ NodeDef node3 = MakeNodeConst("node3");
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+
+ NodeDef node4 = MakeNodeConst("node4");
+ map["node4"] = absl::make_unique<GenNode>(&node4);
+
+ NodeDef node5 =
+ MakeNodeQuantizedConcat("node5", "node1", "node2", "node3", "node4");
+ map["node5"] = absl::make_unique<GenNode>(&node5);
+
+ auto gn1 = map["node1"].get();
+ auto gn2 = map["node2"].get();
+ auto gn3 = map["node3"].get();
+ auto gn4 = map["node4"].get();
+ auto gn5 = map["node5"].get();
+ ASSERT_THAT(gn5->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre(
+ "o0: node5[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre(
+ "o0: node5[i1]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre(
+ "o0: node5[i2]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn4->links()), ElementsAre(
+ "o0: node5[i3]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn5->links()), ElementsAre(
+ "i0: node1[o0]",
+ "i1: node2[o0]",
+ "i2: node3[o0]",
+ "i3: node4[o0]"
+ ));
+ // clang-format on
+
+ // Non-commutative multi-input doesn't count.
+ EXPECT_THAT(gn5->IsMultiInput(GenNode::Port(true, 1)), Eq(false));
+ EXPECT_THAT(gn5->IsMultiInput(GenNode::Port(true, 2)), Eq(false));
+ EXPECT_TRUE(gn5->AllInputsOrNone());
+}
+
+TEST(GenNodeTest, ParseNodeMultiOutput) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ NodeDef node3 = MakeNodeBroadcastGradientArgs("node3", "node1", "node2");
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+
+ NodeDef node4 = MakeNodeSub("node4", "node3:1", "node3:0");
+ map["node4"] = absl::make_unique<GenNode>(&node4);
+
+ auto gn4 = map["node4"].get();
+ ASSERT_THAT(gn4->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn4->links()), ElementsAre(
+ "i0: node3[o1]",
+ "i1: node3[o0]"
+ ));
+ // clang-format on
+}
+
+TEST(GenNodeTest, ParseNodeUndefinedOp) {
+ GenNodeMap map;
+ NodeDef node1;
+ node1.set_name("node1");
+ node1.set_op("Zzzx");
+
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ const OpDef* opdef;
+ Status nested_error = OpRegistry::Global()->LookUpOpDef("Zzzx", &opdef);
+
+ auto gn = map["node1"].get();
+ ASSERT_THAT(
+ gn->ParseInputs(&map),
+ Eq(Status(error::INVALID_ARGUMENT,
+ "Node 'node1' contains an undefined operation 'Zzzx': " +
+ nested_error.error_message())));
+}
+
+TEST(GenNodeTest, ParseNodeUnexpectedInputs) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+ node1.add_input("node1");
+
+ auto gn1 = map["node1"].get();
+ EXPECT_THAT(gn1->ParseInputs(&map),
+ Eq(Status(error::INVALID_ARGUMENT,
+ "Node 'node1' has a non-control "
+ "input from 'node1' at index 0 but its operation "
+ "'Const' defines only 0 inputs.")));
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ NodeDef node3 = MakeNodeSub("node3", "node1", "node2");
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+ node3.add_input("node1");
+
+ auto gn3 = map["node3"].get();
+ EXPECT_THAT(gn3->ParseInputs(&map),
+ Eq(Status(error::INVALID_ARGUMENT,
+ "Node 'node3' has a non-control "
+ "input from 'node1' at index 2 but its operation "
+ "'Sub' defines only 2 inputs.")));
+}
+
+// Even if an opcode defines no inputs, the node may still accept control
+// inputs.
+TEST(GenNodeTest, ParseNodeControlInputsAlwaysOk) {
+ GenNodeMap map;
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+ node1.add_input("^node1");
+ auto gn1 = map["node1"].get();
+ ASSERT_THAT(gn1->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre(
+ "iC: node1[oC]",
+ "oC: node1[iC]"
+ ));
+ // clang-format on
+}
+
+TEST(GenNodeTest, ParseNodeInvalidInput) {
+ GenNodeMap map;
+ NodeDef node1 = MakeNodeAddN("node1", "node2", "node3");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+ node1.add_input("node1");
+ auto gn1 = map["node1"].get();
+ ASSERT_THAT(
+ gn1->ParseInputs(&map),
+ Eq(Status(
+ error::INVALID_ARGUMENT,
+ "Node 'node1' input 0 refers to a non-existing node 'node2'.")));
+}
+
+TEST(GenNodeTest, BuildGraphInMap) {
+ GraphDef graph;
+ // A topology with a loop.
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0");
+ (*graph.add_node()) =
+ MakeNodeBroadcastGradientArgs("node3", "node1", "node2");
+
+ GenNodeMap map;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK()));
+ ASSERT_THAT(map.find("node1"), Ne(map.end()));
+ ASSERT_THAT(map.find("node2"), Ne(map.end()));
+ ASSERT_THAT(map.find("node3"), Ne(map.end()));
+
+ EXPECT_THAT(map["node1"]->name(), Eq("node1"));
+ EXPECT_THAT(map["node2"]->name(), Eq("node2"));
+ EXPECT_THAT(map["node3"]->name(), Eq("node3"));
+
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(map["node1"]->links()), ElementsAre(
+ "o0: node3[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(map["node2"]->links()), ElementsAre(
+ "i0: node3[o1]",
+ "i1: node3[o0]",
+ "o0: node3[i1]"
+ ));
+ EXPECT_THAT(DumpLinkMap(map["node3"]->links()), ElementsAre(
+ "i0: node1[o0]",
+ "i1: node2[o0]",
+ "o0: node2[i1]",
+ "o1: node2[i0]"
+ ));
+ // clang-format on
+}
+
+TEST(GenNodeTest, BuildGraphInMapDuplicateNode) {
+ GraphDef graph;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeConst("node1");
+ GenNodeMap map;
+ ASSERT_THAT(
+ GenNode::BuildGraphInMap(graph, &map),
+ Eq(Status(error::INVALID_ARGUMENT, "Duplicate node name 'node1'.")));
+}
+
+TEST(GenNodeTest, BuildGraphInMapParseError) {
+ GraphDef graph;
+ // A topology with a loop.
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0");
+
+ GenNodeMap map;
+ ASSERT_THAT(
+ GenNode::BuildGraphInMap(graph, &map),
+ Eq(Status(
+ error::INVALID_ARGUMENT,
+ "Node 'node2' input 0 refers to a non-existing node 'node3'.")));
+}
+
+} // end namespace
+} // end namespace test
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc
new file mode 100644
index 0000000000..f3796fcf86
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc
@@ -0,0 +1,341 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <deque>
+#include <iostream>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
+#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer.h"
+#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+GraphAnalyzer::GraphAnalyzer(const GraphDef& graph, int subgraph_size)
+ : graph_(graph), subgraph_size_(subgraph_size) {}
+
+GraphAnalyzer::~GraphAnalyzer() {}
+
+Status GraphAnalyzer::Run() {
+  // The signature computation code would detect this too, but it's better
+  // to report it up front than to spend time computing all the subgraphs
+  // first.
+ if (subgraph_size_ > Signature::kMaxGraphSize) {
+ return Status(error::INVALID_ARGUMENT,
+ absl::StrFormat("Subgraphs of %d nodes are not supported, "
+ "the maximal supported node count is %d.",
+ subgraph_size_, Signature::kMaxGraphSize));
+ }
+
+ Status st = BuildMap();
+ if (!st.ok()) {
+ return st;
+ }
+
+ FindSubgraphs();
+ DropInvalidSubgraphs();
+ st = CollateResult();
+ if (!st.ok()) {
+ return st;
+ }
+
+ return Status::OK();
+}
+
+Status GraphAnalyzer::BuildMap() {
+ nodes_.clear();
+ return GenNode::BuildGraphInMap(graph_, &nodes_);
+}
+
+void GraphAnalyzer::FindSubgraphs() {
+ result_.clear();
+
+ if (subgraph_size_ < 1) {
+ return;
+ }
+
+ partial_.clear();
+ todo_.clear(); // Just in case.
+
+ // Start with all subgraphs of size 1.
+ const Subgraph::Identity empty_parent;
+ for (const auto& node : nodes_) {
+ if (subgraph_size_ == 1) {
+ result_.ExtendParent(empty_parent, node.second.get());
+ } else {
+ // At this point ExtendParent() is guaranteed to not return nullptr.
+ todo_.push_back(partial_.ExtendParent(empty_parent, node.second.get()));
+ }
+ }
+
+ // Then extend the subgraphs until no more extensions are possible.
+ while (!todo_.empty()) {
+ ExtendSubgraph(todo_.front());
+ todo_.pop_front();
+ }
+
+ partial_.clear();
+}
+
+void GraphAnalyzer::ExtendSubgraph(Subgraph* parent) {
+ bool will_complete = (parent->id().size() + 1 == subgraph_size_);
+ SubgraphPtrSet& sg_set = will_complete ? result_ : partial_;
+
+ const GenNode* last_all_or_none_node = nullptr;
+ for (SubgraphIterator sit(parent); !sit.AtEnd(); sit.Next()) {
+ const GenNode* node = sit.GetNode();
+ GenNode::Port port = sit.GetPort();
+ const GenNode::LinkTarget& neighbor = sit.GetNeighbor();
+
+ if (node->AllInputsOrNone() && port.IsInbound() && !port.IsControl()) {
+ if (node != last_all_or_none_node) {
+ ExtendSubgraphAllOrNone(parent, node);
+ last_all_or_none_node = node;
+ }
+ sit.SkipPort();
+ } else if (neighbor.node->AllInputsOrNone() && !port.IsInbound() &&
+ !port.IsControl()) {
+ if (parent->id().find(neighbor.node) == parent->id().end()) {
+ // Not added yet.
+ ExtendSubgraphAllOrNone(parent, neighbor.node);
+ }
+ } else if (node->IsMultiInput(port)) {
+ ExtendSubgraphPortAllOrNone(parent, node, port);
+ sit.SkipPort();
+ } else if (neighbor.node->IsMultiInput(neighbor.port)) {
+ // Would need to add all inputs of the neighbor node at this port at
+ // once.
+ if (parent->id().find(neighbor.node) != parent->id().end()) {
+ continue; // Already added.
+ }
+ ExtendSubgraphPortAllOrNone(parent, neighbor.node, neighbor.port);
+ } else {
+ Subgraph* sg = sg_set.ExtendParent(parent->id(), neighbor.node);
+ if (!will_complete && sg != nullptr) {
+ todo_.push_back(sg);
+ }
+ }
+ }
+}
+
+void GraphAnalyzer::ExtendSubgraphAllOrNone(Subgraph* parent,
+ const GenNode* node) {
+ Subgraph::Identity id = parent->id();
+ id.insert(node);
+
+ auto range_end = node->links().end();
+
+ for (auto nbit = node->links().begin(); nbit != range_end; ++nbit) {
+ auto port = nbit->first;
+ if (!port.IsInbound() || port.IsControl()) {
+ continue;
+ }
+
+ // Since there might be multiple links to the same nodes,
+ // have to add all links one-by-one to check whether the subgraph
+ // would grow too large. But if it does grow too large, there is no
+ // point in growing it more, can just skip over the rest of the links.
+ for (const auto& link : nbit->second) {
+ id.insert(link.node);
+ if (id.size() > subgraph_size_) {
+ return; // Too big.
+ }
+ }
+ }
+
+ AddExtendedSubgraph(parent, id);
+}
+
+void GraphAnalyzer::ExtendSubgraphPortAllOrNone(Subgraph* parent,
+ const GenNode* node,
+ GenNode::Port port) {
+ auto nbit = node->links().find(port);
+ if (nbit == node->links().end()) {
+ return; // Should never happen.
+ }
+
+ Subgraph::Identity id = parent->id();
+ id.insert(node);
+
+ // Since there might be multiple links to the same nodes,
+ // have to add all links one-by-one to check whether the subgraph
+ // would grow too large. But if it does grow too large, there is no
+ // point in growing it more, can just skip over the rest of the links.
+ for (const auto& link : nbit->second) {
+ id.insert(link.node);
+ if (id.size() > subgraph_size_) {
+ return; // Too big.
+ }
+ }
+
+ AddExtendedSubgraph(parent, id);
+}
+
+void GraphAnalyzer::AddExtendedSubgraph(Subgraph* parent,
+ const Subgraph::Identity& id) {
+ if (id.size() == parent->id().size()) {
+ return; // Nothing new was added.
+ }
+
+ auto sg = absl::make_unique<Subgraph>(id);
+ SubgraphPtrSet& spec_sg_set =
+ (id.size() == subgraph_size_) ? result_ : partial_;
+ if (spec_sg_set.find(sg) != spec_sg_set.end()) {
+ // This subgraph was already found by extending from a different path.
+ return;
+ }
+
+ if (id.size() != subgraph_size_) {
+ todo_.push_back(sg.get());
+ }
+ spec_sg_set.insert(std::move(sg));
+}
+
+void GraphAnalyzer::DropInvalidSubgraphs() {
+ auto resit = result_.begin();
+ while (resit != result_.end()) {
+ if (HasInvalidMultiInputs(resit->get())) {
+ auto delit = resit;
+ ++resit;
+ result_.erase(delit);
+ } else {
+ ++resit;
+ }
+ }
+}
+
+bool GraphAnalyzer::HasInvalidMultiInputs(Subgraph* sg) {
+ // Do the all-or-none-input nodes.
+ for (auto const& node : sg->id()) {
+ if (!node->AllInputsOrNone()) {
+ continue;
+ }
+
+ bool anyIn = false;
+ bool anyOut = false;
+
+ auto range_end = node->links().end();
+ for (auto nbit = node->links().begin(); nbit != range_end; ++nbit) {
+ auto port = nbit->first;
+ if (!port.IsInbound() || port.IsControl()) {
+ continue;
+ }
+
+      // Since there might be multiple links to the same nodes, check each
+      // link one-by-one to see whether it leads inside or outside the
+      // subgraph; a multi-input with links going both ways is invalid.
+ for (const auto& link : nbit->second) {
+ if (sg->id().find(link.node) == sg->id().end()) {
+ anyOut = true;
+ } else {
+ anyIn = true;
+ }
+ }
+ }
+
+ if (anyIn && anyOut) {
+ return true;
+ }
+ }
+
+ // Do the multi-input ports.
+ for (SubgraphIterator sit(sg); !sit.AtEnd(); sit.Next()) {
+ if (sit.GetNode()->IsMultiInput(sit.GetPort())) {
+ bool anyIn = false;
+ bool anyOut = false;
+ do {
+ GenNode* peer = sit.GetNeighbor().node;
+ if (sg->id().find(peer) == sg->id().end()) {
+ anyOut = true;
+ } else {
+ anyIn = true;
+ }
+ } while (sit.NextIfSamePort());
+
+ if (anyIn && anyOut) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+Status GraphAnalyzer::CollateResult() {
+ ordered_collation_.clear();
+ collation_map_.clear();
+
+ // Collate by the signatures of the graphs.
+ for (const auto& it : result_) {
+ auto sig = absl::make_unique<Signature>();
+ it->ExtractForSignature(&sig->map);
+ Status status = sig->Compute();
+ if (!status.ok()) {
+ return status;
+ }
+
+ auto& coll_entry = collation_map_[sig.get()];
+ if (coll_entry.sig == nullptr) {
+ coll_entry.sig = std::move(sig);
+ }
+ ++coll_entry.count;
+ }
+
+ // Then order them by the count.
+ for (auto& entry : collation_map_) {
+ ordered_collation_.insert(&entry.second);
+ }
+
+ result_.clear(); // Not needed after collation.
+
+ return Status::OK();
+}
+
+std::vector<string> GraphAnalyzer::DumpRawSubgraphs() {
+ std::vector<string> result;
+ for (const auto& it : result_) {
+ result.emplace_back(it->Dump());
+ }
+ return result;
+}
+
+std::vector<string> GraphAnalyzer::DumpSubgraphs() {
+ std::vector<string> result;
+ for (auto ptr : ordered_collation_) {
+ result.emplace_back(
+ absl::StrFormat("%d %s", ptr->count, ptr->sig->ToString()));
+ }
+ return result;
+}
+
+Status GraphAnalyzer::OutputSubgraphs() {
+ size_t total = 0;
+ for (auto ptr : ordered_collation_) {
+ std::cout << ptr->count << ' ' << ptr->sig->ToString() << '\n';
+ total += ptr->count;
+ }
+ std::cout << "Total: " << total << '\n';
+ if (std::cout.fail()) {
+ return Status(error::DATA_LOSS, "Failed to write to stdout");
+ } else {
+ return Status::OK();
+ }
+}
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h
new file mode 100644
index 0000000000..97626346c7
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h
@@ -0,0 +1,154 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_H_
+
+#include <deque>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/grappler/graph_analyzer/map_tools.h"
+#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
+#include "tensorflow/core/grappler/graph_analyzer/subgraph.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+namespace test {
+class GraphAnalyzerTest;
+} // end namespace test
+
+// Finds all the subgraphs of a given size and groups them by equivalence.
+class GraphAnalyzer {
+ public:
+ // Makes a copy of the graph.
+ GraphAnalyzer(const GraphDef& graph, int subgraph_size);
+
+ virtual ~GraphAnalyzer();
+
+ // Performs the analysis and collects the subgraphs.
+ Status Run();
+
+ // Returns the subgraphs found in Run() printed to text.
+ std::vector<string> DumpSubgraphs();
+
+ // Prints the subgraphs found in Run() to stdout.
+ Status OutputSubgraphs();
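+
+  // A minimal usage sketch (assuming an existing GraphDef graph_def):
+  //   GraphAnalyzer analyzer(graph_def, /*subgraph_size=*/3);
+  //   TF_RETURN_IF_ERROR(analyzer.Run());
+  //   TF_RETURN_IF_ERROR(analyzer.OutputSubgraphs());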
+
+ // TODO(babkin): add a way to extract the subgraphs as direct data
+ // structures and as protobufs, and to write protobufs to a RecordIO.
+
+ private:
+ GraphAnalyzer() = delete;
+ GraphAnalyzer(const GraphAnalyzer&) = delete;
+ void operator=(const GraphAnalyzer&) = delete;
+
+ friend class tensorflow::grappler::graph_analyzer::test::GraphAnalyzerTest;
+
+ // Builds the map of nodes from the original graph definition.
+ Status BuildMap();
+
+ // Using nodes_, finds all the subgraphs of size subgraph_size_ and places
+ // them into result_.
+ void FindSubgraphs();
+
+ // Deletes from result_ the unacceptable subgraphs. Those include the
+ // subgraphs where not all the inputs at a multi-input port are included (this
+ // could happen if some of these inputs were reached and included through
+ // different paths).
+ void DropInvalidSubgraphs();
+
+ // Deletes from result_ duplicate entries of equivalent topology.
+ Status CollateResult();
+
+ // Returns the raw subgraphs found in FindSubgraphs() printed to text.
+ std::vector<string> DumpRawSubgraphs();
+
+  // Finds all the subgraphs that can be created by extending the parent
+  // subgraph by one node, and adds them to either partial_ or result_ as
+  // appropriate. Ignores the duplicates.
+ void ExtendSubgraph(Subgraph* parent);
+
+ // Extends the parent subgraph by adding another node (if it wasn't already
+ // added) and all its non-control inputs in the link map range at once.
+ // If the subgraph would grow over subgraph_size_, it gets ignored.
+ void ExtendSubgraphAllOrNone(Subgraph* parent, const GenNode* node);
+ // Same but adds one specific inbound port (even control) all-or-none.
+ void ExtendSubgraphPortAllOrNone(Subgraph* parent, const GenNode* node,
+ GenNode::Port port);
+ // The common final step called by ExtendSubgraph*AllOrNone() methods.
+ void AddExtendedSubgraph(Subgraph* parent, const Subgraph::Identity& id);
+
+ // Returns true if this subgraph has any multi-inputs that aren't all-in or
+ // all-out.
+ bool HasInvalidMultiInputs(Subgraph* sg);
+
+ // Graph to run the analysis on.
+ GraphDef graph_;
+ int subgraph_size_;
+
+ // The enriched graph of parsed nodes and connections.
+ GenNodeMap nodes_;
+ // The resulting set of subgraphs.
+ SubgraphPtrSet result_;
+ // The subgraphs of partial size, stored while finding the result.
+ SubgraphPtrSet partial_;
+ // The subgraphs of partial size (stored in partial_) that are still waiting
+ // to be extended.
+ //
+  // TODO(babkin): This is rather simple-minded: each subgraph is examined from
+  // scratch, which means that all its internal links get iterated too. That's
+  // OK for small subgraphs. It could be improved by keeping not just subgraphs
+  // but iterators on the list, each holding the list of not-yet-examined nodes
+  // (and the position of the next link to be examined for the first node).
+  // This would add extra constant overhead, so the break-even subgraph size is
+  // not clear yet.
+ std::deque<Subgraph*> todo_;
+
+  // The collation map by signature is designed to allow the removal of entries
+  // and the moving of signature references from the keys of this map to the
+  // outside world. Care is needed on insertion and removal: when a new entry
+  // is inserted, its signature reference must be populated with the same data
+  // as the key of the map, and if a reference is moved out, the map entry must
+  // be removed before that reference gets destroyed.
+ struct CollationEntry {
+ std::shared_ptr<Signature> sig;
+ size_t count = 0;
+ };
+ using CollationMap =
+ std::unordered_map<Signature*, CollationEntry, HashAtPtr<Signature*>,
+ EqAtPtr<Signature*> >;
+ CollationMap collation_map_;
+
+  // The entries are owned by collation_map_, so they must be removed from
+  // ordered_collation_ before being removed from collation_map_.
+ struct ReverseLessByCount {
+ bool operator()(CollationEntry* left, CollationEntry* right) const {
+ return left->count > right->count; // Reverse order.
+ }
+ };
+ using CollationOrderByCount =
+ std::multiset<CollationEntry*, ReverseLessByCount>;
+ CollationOrderByCount ordered_collation_;
+};
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_H_
diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer_test.cc b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_test.cc
new file mode 100644
index 0000000000..e94c472056
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_test.cc
@@ -0,0 +1,569 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer.h"
+
+#include <algorithm>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
+#include "tensorflow/core/grappler/graph_analyzer/test_tools.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::Ne;
+using ::testing::SizeIs;
+using ::testing::UnorderedElementsAre;
+
+class GraphAnalyzerTest : public ::testing::Test, protected TestGraphs {
+ protected:
+ Status BuildMap() { return gran_->BuildMap(); }
+
+ void FindSubgraphs() { gran_->FindSubgraphs(); }
+
+ void DropInvalidSubgraphs() { gran_->DropInvalidSubgraphs(); }
+
+ Status CollateResult() { return gran_->CollateResult(); }
+
+ void ExtendSubgraph(Subgraph* parent) { gran_->ExtendSubgraph(parent); }
+
+ void ExtendSubgraphPortAllOrNone(Subgraph* parent, GenNode* node,
+ GenNode::Port port) {
+ gran_->ExtendSubgraphPortAllOrNone(parent, node, port);
+ }
+
+ void ExtendSubgraphAllOrNone(Subgraph* parent, GenNode* node) {
+ gran_->ExtendSubgraphAllOrNone(parent, node);
+ }
+
+ std::vector<string> DumpRawSubgraphs() { return gran_->DumpRawSubgraphs(); }
+
+ std::vector<string> DumpPartials() {
+ std::vector<string> result;
+ for (const auto& it : gran_->partial_) {
+ result.emplace_back(it->Dump());
+ }
+ return result;
+ }
+
+ const GenNodeMap& GetNodes() { return gran_->nodes_; }
+
+ GenNode* GetNode(const string& name) { return gran_->nodes_.at(name).get(); }
+
+ SubgraphPtrSet& GetResult() { return gran_->result_; }
+ SubgraphPtrSet& GetPartial() { return gran_->partial_; }
+ std::deque<Subgraph*>& GetTodo() { return gran_->todo_; }
+
+ // Gets initialized by a particular test from a suitable GraphDef.
+ std::unique_ptr<GraphAnalyzer> gran_;
+};
+
+TEST_F(GraphAnalyzerTest, BuildMap) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_3n_self_control_, 1);
+ Status st = BuildMap();
+ EXPECT_THAT(st, Eq(Status::OK()));
+
+ auto& map = GetNodes();
+ EXPECT_THAT(map.find("node1"), Ne(map.end()));
+ EXPECT_THAT(map.find("node2"), Ne(map.end()));
+ EXPECT_THAT(map.find("node3"), Ne(map.end()));
+}
+
+TEST_F(GraphAnalyzerTest, BuildMapError) {
+ // A duplicate node.
+ (*graph_3n_self_control_.add_node()) = MakeNodeConst("node1");
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_3n_self_control_, 1);
+ Status st = BuildMap();
+ ASSERT_THAT(
+ st, Eq(Status(error::INVALID_ARGUMENT, "Duplicate node name 'node1'.")));
+}
+
+TEST_F(GraphAnalyzerTest, FindSubgraphs0) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_3n_self_control_, 0);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ FindSubgraphs();
+ auto& subgraphs = GetResult();
+ EXPECT_THAT(subgraphs, SizeIs(0));
+ EXPECT_THAT(DumpRawSubgraphs(), ElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+TEST_F(GraphAnalyzerTest, FindSubgraphs1) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_3n_self_control_, 1);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ FindSubgraphs();
+ auto& subgraphs = GetResult();
+ EXPECT_THAT(subgraphs, SizeIs(3));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: BroadcastGradientArgs(node3)",
+ "1: Const(node1)",
+ "1: Sub(node2)"
+ ));
+ // clang-format on
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// The required subgraphs are larger than the graph.
+TEST_F(GraphAnalyzerTest, FindSubgraphsTooLarge) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_3n_self_control_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ FindSubgraphs();
+ EXPECT_THAT(DumpRawSubgraphs(), ElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+//===
+
+// Successfully propagate backwards through a multi-input link,
+// with the base (currently-extending) node already in the graph.
+TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsBaseIn) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("add2")}));
+
+ ExtendSubgraphPortAllOrNone(root.get(), GetNode("add2"),
+ GenNode::Port(true, 0));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)"
+ ));
+ // clang-format on
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Successfully propagate backwards through a multi-input link,
+// with the base (currently-extending) node not in the graph yet.
+TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsBaseOut) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto parent = absl::make_unique<Subgraph>(Subgraph::Identity());
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("add2")}));
+
+ ExtendSubgraphPortAllOrNone(parent.get(), GetNode("add2"),
+ GenNode::Port(true, 0));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)"
+ ));
+ // clang-format on
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Successfully propagate backwards through a multi-input link,
+// where the target subgraph size is larger.
+TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsIncomplete) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 5);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("add2")}));
+
+ ExtendSubgraphPortAllOrNone(root.get(), GetNode("add2"),
+ GenNode::Port(true, 0));
+
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre());
+ // clang-format off
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre(
+ "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)"
+ ));
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(1));
+}
+
+// Propagate backwards through a multi-input link, finding that the
+// resulting subgraph would be too large.
+TEST_F(GraphAnalyzerTest, MultiInputTooLargeBackwards) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 3);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("add2")}));
+
+ ExtendSubgraphPortAllOrNone(root.get(), GetNode("add2"),
+ GenNode::Port(true, 0));
+
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Propagate backwards through a multi-input link, finding that nothing
+// would be added to the parent subgraph.
+TEST_F(GraphAnalyzerTest, MultiInputNothingAddedBackwards) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root = absl::make_unique<Subgraph>(
+ Subgraph::Identity({GetNode("add2"), GetNode("const2_1"),
+ GetNode("const2_2"), GetNode("const2_3")}));
+
+ ExtendSubgraphPortAllOrNone(root.get(), GetNode("add2"),
+ GenNode::Port(true, 0));
+
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Successfully propagate forwards through a multi-input link,
+// with the base (currently-extending) node not in the subgraph yet.
+TEST_F(GraphAnalyzerTest, MultiInputSuccessForwardsBaseOut) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("const2_1")}));
+
+ ExtendSubgraphPortAllOrNone(root.get(), GetNode("add2"),
+ GenNode::Port(true, 0));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)"
+ ));
+ // clang-format on
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Successfully propagate backwards through a multi-input link.
+TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsFull) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("add2")}));
+
+ ExtendSubgraph(root.get());
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)"
+ ));
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre(
+ "1: AddN(add2), Sub(sub)"
+ ));
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(1));
+}
+
+// Successfully propagate forwards through a multi-input link.
+TEST_F(GraphAnalyzerTest, MultiInputSuccessForwardsFull) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("const2_1")}));
+
+ ExtendSubgraph(root.get());
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)"
+ ));
+ // clang-format on
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+TEST_F(GraphAnalyzerTest, DropInvalidSubgraphsMulti) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 3);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ // A good one, multi-input is all-in.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("const1_1"),
+ GetNode("const1_2"),
+ GetNode("add1"),
+ })));
+  // A good one, multi-input is all-out.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("add1"),
+ GetNode("add2"),
+ GetNode("sub"),
+ })));
+ // A bad one, multi-input is partially in.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("const1_1"),
+ GetNode("add1"),
+ GetNode("sub"),
+ })));
+ // A bad one, multi-input is partially in.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("add2"),
+ GetNode("const2_1"),
+ GetNode("const2_2"),
+ })));
+
+ DropInvalidSubgraphs();
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: AddN(add1), AddN(add2), Sub(sub)",
+ "1: AddN(add1), Const(const1_1), Const(const1_2)"
+ ));
+ // clang-format on
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+//===
+
+// Successfully propagate backwards through an all-or-none-input node,
+// with the base (currently-extending) node already in the subgraph.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessBackwards) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("pass2")}));
+
+ ExtendSubgraphAllOrNone(root.get(), GetNode("pass2"));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass2)"
+ ));
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Successfully propagate backwards through an all-or-none-input node,
+// but no control links propagate. It also tests the situation
+// where the target subgraph size is larger.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessBackwardsNoControl) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 5);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("pass1")}));
+
+ ExtendSubgraphAllOrNone(root.get(), GetNode("pass1"));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre(
+ "1: Const(const1_1), Const(const1_2), IdentityN(pass1)"
+ ));
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(1));
+}
+
+// The control links propagate separately as all-or-none, even on the nodes
+// that are all-or-none for the normal inputs.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputSeparateControl) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 5);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("pass1")}));
+
+ ExtendSubgraphPortAllOrNone(root.get(), GetNode("pass1"),
+ GenNode::Port(true, -1));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre(
+ "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass1)"
+ ));
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(1));
+}
+
+// Propagate backwards from an all-or-none-input node, finding that the
+// resulting subgraph would be too large.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputTooLargeBackwards) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 3);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("pass2")}));
+
+ ExtendSubgraphAllOrNone(root.get(), GetNode("pass2"));
+
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Propagate backwards from an all-or-none-input node, finding that nothing
+// would be added to the parent subgraph.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputNothingAddedBackwards) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root = absl::make_unique<Subgraph>(
+ Subgraph::Identity({GetNode("pass2"), GetNode("const2_1"),
+ GetNode("const2_2"), GetNode("const2_3")}));
+
+ ExtendSubgraphAllOrNone(root.get(), GetNode("pass2"));
+
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Successfully propagate forwards to an all-or-none-input node,
+// with the base (currently-extending) node not in the subgraph yet.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessForwardsBaseOut) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("const2_1")}));
+
+ ExtendSubgraphAllOrNone(root.get(), GetNode("pass2"));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass2)"
+ ));
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Successfully propagate backwards from an all-or-none-input node.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessBackwardsFull) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("pass2")}));
+
+ ExtendSubgraph(root.get());
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass2)"
+ ));
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre(
+ "1: IdentityN(pass2), Sub(sub)"
+ ));
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(1));
+}
+
+// Successfully propagate forwards to an all-or-none-input node. This includes
+// both all-or-none-input for the normal inputs, and multi-input by the
+// control path.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessForwardsFull) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("const2_1")}));
+
+ ExtendSubgraph(root.get());
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass2)",
+ "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass1)"
+ ));
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+TEST_F(GraphAnalyzerTest, DropInvalidSubgraphsAllOrNone) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 3);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ // A good one, all-or-none is all-in.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("const1_1"),
+ GetNode("const1_2"),
+ GetNode("pass1"),
+ })));
+  // A good one, all-or-none is all-out.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("pass1"),
+ GetNode("pass2"),
+ GetNode("sub"),
+ })));
+ // A bad one, all-or-none is partially in.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("const1_1"),
+ GetNode("pass1"),
+ GetNode("sub"),
+ })));
+ // A bad one, all-or-none is partially in.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("pass2"),
+ GetNode("const2_1"),
+ GetNode("const2_2"),
+ })));
+
+ DropInvalidSubgraphs();
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: IdentityN(pass1), IdentityN(pass2), Sub(sub)",
+ "1: Const(const1_1), Const(const1_2), IdentityN(pass1)"
+ ));
+ // clang-format on
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+} // end namespace test
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.cc b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.cc
new file mode 100644
index 0000000000..924ca11e61
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.cc
@@ -0,0 +1,98 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+// Dies on failure.
+static void LoadModel(const string& filename,
+ tensorflow::MetaGraphDef* metagraph) {
+ LOG(INFO) << "Loading model from " << filename;
+ Status st;
+ st = ReadBinaryProto(Env::Default(), filename, metagraph);
+ if (!st.ok()) {
+ LOG(WARNING) << "Failed to read a binary metagraph: " << st;
+ st = ReadTextProto(Env::Default(), filename, metagraph);
+ if (!st.ok()) {
+ LOG(FATAL) << "Failed to read a text metagraph: " << st;
+ }
+ }
+}
+
+// Prune the graph to only keep the transitive fanin part with respect to a set
+// of train ops (if provided).
+void MaybePruneGraph(const tensorflow::MetaGraphDef& metagraph,
+ tensorflow::GraphDef* graph) {
+ std::vector<string> fetch_nodes;
+ for (const auto& fetch :
+ metagraph.collection_def().at("train_op").node_list().value()) {
+ LOG(INFO) << "Fetch node: " << fetch;
+ fetch_nodes.push_back(fetch);
+ }
+ if (fetch_nodes.empty()) {
+ *graph = metagraph.graph_def();
+ } else {
+ std::vector<const tensorflow::NodeDef*> fanin_nodes =
+ tensorflow::grappler::ComputeTransitiveFanin(metagraph.graph_def(),
+ fetch_nodes);
+ for (const tensorflow::NodeDef* node : fanin_nodes) {
+ *(graph->add_node()) = *node;
+ }
+ LOG(INFO) << "Pruned "
+ << metagraph.graph_def().node_size() - graph->node_size()
+ << " nodes. Original graph size: "
+ << metagraph.graph_def().node_size()
+ << ". New graph size: " << graph->node_size() << ".";
+ }
+}
+
+void GraphAnalyzerTool(const string& file_name, int n) {
+ if (n < 1) {
+ LOG(FATAL) << "Invalid subgraph size " << n << ", must be at least 1";
+ }
+
+ tensorflow::MetaGraphDef metagraph;
+ LoadModel(file_name, &metagraph);
+ tensorflow::GraphDef graph;
+ MaybePruneGraph(metagraph, &graph);
+ tensorflow::grappler::graph_analyzer::GraphAnalyzer analyzer(graph, n);
+ LOG(INFO) << "Running the analysis";
+ tensorflow::Status st = analyzer.Run();
+ if (!st.ok()) {
+ LOG(FATAL) << "Analysis failed: " << st;
+ }
+
+ LOG(INFO) << "Printing the result";
+ st = analyzer.OutputSubgraphs();
+ if (!st.ok()) {
+ LOG(FATAL) << "Failed to print the result: " << st;
+ }
+
+ LOG(INFO) << "Completed";
+}
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
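+
+// A minimal, hypothetical driver for the entry point above, shown purely as a
+// usage sketch; the real binary and its argument handling live elsewhere in
+// the tree. It assumes argv[1] names the MetaGraphDef file (binary or text)
+// and argv[2] gives the subgraph size n; std::atoi would need <cstdlib>.
+//
+//   int main(int argc, char** argv) {
+//     tensorflow::port::InitMain(argv[0], &argc, &argv);
+//     if (argc != 3) {
+//       LOG(FATAL) << "Usage: " << argv[0]
+//                  << " <metagraph-file> <subgraph-size>";
+//     }
+//     tensorflow::grappler::graph_analyzer::GraphAnalyzerTool(
+//         argv[1], std::atoi(argv[2]));
+//     return 0;
+//   }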
diff --git a/tensorflow/compiler/jit/ops/parallel_check_op.cc b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h
index db5c195578..5a91fe7dc8 100644
--- a/tensorflow/compiler/jit/ops/parallel_check_op.cc
+++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,18 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/framework/op.h"
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_TOOL_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_TOOL_H_
+
+#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+void GraphAnalyzerTool(const string& file_name, int n);
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
-REGISTER_OP("ParallelCheck")
- .Attr("T: list(type) >= 0")
- .Input("expected: T")
- .Input("actual: T")
- .Output("result: T")
- .Doc(R"doc(
-Op that compares two sets of inputs for near-identity, and propagates the first.
-Inequality is logged to ERROR log.
-)doc");
-
-} // namespace tensorflow
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_TOOL_H_
diff --git a/tensorflow/core/grappler/graph_analyzer/hash_tools.h b/tensorflow/core/grappler/graph_analyzer/hash_tools.h
new file mode 100644
index 0000000000..b0e79f9a68
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/hash_tools.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_HASH_TOOLS_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_HASH_TOOLS_H_
+
+#include <cstddef>
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+// Unfortunately, std::hash provides no way to combine hashes, so everyone
+// is copying boost::hash_combine. This is a version that follows Google's
+// guidelines on the arguments, and contains only the combination, without
+// hashing.
+inline void CombineHash(size_t from, size_t* to) {
+ *to ^= from + 0x9e3779b9 + (*to << 6) + (*to >> 2);
+}
+
+// Combine two hashes in such a way that the order of combination doesn't matter
+// (so it's really both commutative and associative). The result is not a very
+// high-quality hash but can be used when the order of sub-elements must
+// not matter in the following comparison. An alternative would be to sort the
+// hashes of the sub-elements and then combine them normally in the sorted
+// order.
+inline void CombineHashCommutative(size_t from, size_t* to) {
+ *to = *to + from + 0x9e3779b9;
+}
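+
+// A small self-checking sketch of how the two combiners behave (illustrative
+// only, not used by the library): CombineHash is order-sensitive, while
+// CombineHashCommutative folds the same values to the same result in any
+// order, which is what makes it usable for unordered groups of elements.
+inline bool HashCombinersSketch() {
+  size_t ab = 1, ba = 1;
+  CombineHash(2, &ab);
+  CombineHash(3, &ab);
+  CombineHash(3, &ba);
+  CombineHash(2, &ba);
+  // ab and ba differ for these inputs: the order of combination matters.
+
+  size_t cd = 1, dc = 1;
+  CombineHashCommutative(2, &cd);
+  CombineHashCommutative(3, &cd);
+  CombineHashCommutative(3, &dc);
+  CombineHashCommutative(2, &dc);
+  // cd == dc always: the combination is a plain sum, so order is irrelevant.
+  return ab != ba && cd == dc;
+}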
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_HASH_TOOLS_H_
diff --git a/tensorflow/core/util/status_util_test.cc b/tensorflow/core/grappler/graph_analyzer/hash_tools_test.cc
index 1f06004db2..b5e9ce6b8e 100644
--- a/tensorflow/core/util/status_util_test.cc
+++ b/tensorflow/core/grappler/graph_analyzer/hash_tools_test.cc
@@ -13,24 +13,34 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/util/status_util.h"
+#include "tensorflow/core/grappler/graph_analyzer/hash_tools.h"
-#include "tensorflow/core/graph/graph_constructor.h"
-#include "tensorflow/core/graph/node_builder.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
namespace {
-TEST(TestStatusUtil, ErrorFormatTagForNode) {
- Graph graph(OpRegistry::Global());
- Node* node;
- TF_CHECK_OK(NodeBuilder("Foo", "NoOp").Finalize(&graph, &node));
- EXPECT_EQ(error_format_tag(*node, "${line}"), "^^node:Foo:${line}^^");
- EXPECT_EQ(error_format_tag(*node, "${file}:${line}"),
- "^^node:Foo:${file}:${line}^^");
+using ::testing::Eq;
+
+TEST(HashToolsTest, CombineHashCommutative) {
+ size_t a = 0;
+ size_t b = 999;
+
+ size_t c = a;
+ CombineHashCommutative(b, &c);
+
+ size_t d = b;
+ CombineHashCommutative(a, &d);
+
+ EXPECT_THAT(c, Eq(d));
}
} // namespace
-} // namespace tensorflow
+} // end namespace test
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/map_tools.h b/tensorflow/core/grappler/graph_analyzer/map_tools.h
new file mode 100644
index 0000000000..584062c5f2
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/map_tools.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_MAP_TOOLS_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_MAP_TOOLS_H_
+
+#include <functional>
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+// Helpers for building maps of pointers.
+
+template <typename Ptr>
+struct LessAtPtr : std::binary_function<Ptr, Ptr, bool> {
+ bool operator()(const Ptr& x, const Ptr& y) const { return *x < *y; }
+};
+
+template <typename Ptr>
+struct EqAtPtr : std::binary_function<Ptr, Ptr, bool> {
+ bool operator()(const Ptr& x, const Ptr& y) const { return *x == *y; }
+};
+
+template <typename Ptr>
+struct HashAtPtr : std::unary_function<Ptr, size_t> {
+ size_t operator()(const Ptr& x) const { return x->Hash(); }
+};
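+
+// A sketch of the intended use (illustrative only): containers keyed by
+// pointers that compare by the pointed-to values instead of by the pointer
+// addresses. HashAtPtr additionally assumes that the pointed-to type has a
+// Hash() method, as Signature in this directory does.
+//
+//   std::set<std::unique_ptr<string>, LessAtPtr<std::unique_ptr<string>>>
+//       by_value;
+//   by_value.insert(absl::make_unique<string>("b"));
+//   by_value.insert(absl::make_unique<string>("a"));
+//   by_value.insert(absl::make_unique<string>("a"));  // Equal by value,
+//                                                     // not inserted.
+//   // by_value now iterates as {"a", "b"}, ordered by string contents.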
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_MAP_TOOLS_H_
diff --git a/tensorflow/core/grappler/graph_analyzer/sig_node.cc b/tensorflow/core/grappler/graph_analyzer/sig_node.cc
new file mode 100644
index 0000000000..b5cca6a512
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/sig_node.cc
@@ -0,0 +1,453 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
+
+#include <algorithm>
+
+#include "absl/strings/str_format.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+static constexpr bool debug = false;
+
+//=== SigNode
+
+SigNode::SigNode(const NodeDef* node) : node_(node) {}
+
+void SigNode::CopyLinks(const GenNode& from, const TranslationMap& map) {
+ hash_to_link_.clear();
+ hashed_peers_.clear();
+
+ std::map<LinkTag, Link> link_map;
+ CopyLinksPass1(from, map, &link_map);
+ CopyLinksPass2(&link_map);
+}
+
+void SigNode::CopyLinksPass1(const GenNode& from, const TranslationMap& map,
+ std::map<LinkTag, Link>* link_map) {
+ LinkTag::Hasher link_hasher;
+
+ for (const auto& entry : from.links()) {
+ for (const auto& target : entry.second) {
+ auto nodeit = map.find(target.node);
+ if (nodeit == map.end()) {
+ // Node is not in the subgraph, ignore.
+ continue;
+ }
+
+ LinkTag tag(entry.first, target.port);
+ size_t hval = link_hasher(tag);
+
+ // This instantiates the entry if it was not present.
+ Link& map_entry = (*link_map)[tag];
+ if (map_entry.peers.empty()) {
+ map_entry.tag = tag;
+ map_entry.unique_hash = hval;
+ }
+ map_entry.peers.push_back(nodeit->second);
+ }
+ }
+}
+
+void SigNode::CopyLinksPass2(std::map<LinkTag, Link>* link_map) {
+ for (auto& entry : *link_map) {
+ Link* hl_entry_ptr = &hash_to_link_[entry.second.unique_hash];
+ // In case of a conflict, rehash. This should almost never happen.
+ // Because the order of iteration is predictable, the rehashed values
+ // will also be predictable.
+ while (!hl_entry_ptr->peers.empty()) {
+ CombineHash(1, &entry.second.unique_hash);
+ hl_entry_ptr = &hash_to_link_[entry.second.unique_hash];
+ }
+
+ for (const auto& peer : entry.second.peers) {
+ hashed_peers_.emplace_back(HashedPeer(entry.second.unique_hash, peer));
+ }
+
+ hl_entry_ptr->tag = entry.second.tag;
+ hl_entry_ptr->unique_hash = entry.second.unique_hash;
+ hl_entry_ptr->peers.swap(entry.second.peers);
+ }
+}
+
+void SigNode::ComputeTopoHash0() {
+ topo_hash_.clear();
+ last_hashed_nodes_ = next_hashed_nodes_ = node_mask_;
+
+  // TODO(babkin): include the attributes too, as an option.
+ size_t hval = std::hash<string>()(opcode());
+
+ // Getting the topology of the links in to the hash early should get more
+ // conflicts resolved early.
+ for (const auto& entry : hashed_peers_) {
+ CombineHash(entry.link_hash, &hval);
+ }
+
+ topo_hash_.push_back(hval);
+}
+
+void SigNode::ComputeTopoHash(int distance) {
+ // The new starting point.
+ next_hashed_nodes_ = last_hashed_nodes_;
+ if (debug) {
+ LOG(INFO) << "DEBUG node " << name() << " mask=" << std::hex
+ << next_hashed_nodes_;
+ }
+
+ if (hash_is_final_) {
+ return;
+ }
+
+ CHECK(topo_hash_.size() == distance);
+
+ int prev = distance - 1;
+
+  // Start with the node's own local topology hash. This value is stable, so
+ // if the hashes of the surrounding nodes don't change on the following
+ // distances, the hash of this node won't change either.
+ size_t hval = topo_hash_[0];
+
+ if (!hashed_peers_.empty()) {
+ size_t last_link_hash = hashed_peers_[0].link_hash;
+ size_t comm_hash = 0;
+
+ for (const auto& entry : hashed_peers_) {
+ if (entry.link_hash != last_link_hash) {
+ CombineHash(last_link_hash, &hval);
+ CombineHash(comm_hash, &hval);
+ comm_hash = 0;
+ last_link_hash = entry.link_hash;
+ }
+
+ // The links in the same vector are commutative, so combine their
+ // hashes in a commutative way.
+ CombineHashCommutative(entry.peer->GetTopoHash(prev), &comm_hash);
+ next_hashed_nodes_ |= entry.peer->last_hashed_nodes_;
+ if (debug) {
+ LOG(INFO) << "DEBUG node " << name() << " += " << entry.peer->name()
+ << " mask=" << std::hex << next_hashed_nodes_;
+ }
+ }
+
+ // The last commutative group.
+ CombineHash(last_link_hash, &hval);
+ CombineHash(comm_hash, &hval);
+ }
+
+ topo_hash_.push_back(hval);
+}
+
+size_t SigNode::GetTopoHash(int distance) const {
+ CHECK(!topo_hash_.empty());
+ if (distance >= topo_hash_.size()) {
+ CHECK(hash_is_final_);
+ return topo_hash_.back();
+ } else {
+ return topo_hash_[distance];
+ }
+}
+
+bool SigNode::operator==(const SigNode& other) const {
+ // TODO(babkin): add attributes too.
+ if (opcode() != other.opcode()) {
+ return false;
+ }
+
+ // Normally the caller is expected to compare the nodes
+ // at the same rank in different graphs, but just in case...
+ if (unique_rank_ != other.unique_rank_) {
+ return false;
+ }
+
+ if (hashed_peers_.size() != other.hashed_peers_.size()) {
+ return false;
+ }
+
+ for (auto it1 = hashed_peers_.begin(), it2 = other.hashed_peers_.begin();
+ it1 != hashed_peers_.end(); ++it1, ++it2) {
+ // TODO(babkin): might compare the actual values too
+ // but the hash is probably just as good.
+ if (it1->link_hash != it2->link_hash) {
+ return false;
+ }
+ if (it1->peer->unique_rank_ != it2->peer->unique_rank_) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+//=== Signature
+
+constexpr int Signature::kMaxGraphSize;
+
+string Signature::ToString() const {
+ string result;
+ for (size_t n = 0; n < nodes.size(); ++n) {
+ // TODO(babkin): add attributes too.
+ result += absl::StrFormat("%d:%s", n, nodes[n]->opcode());
+ for (const auto& entry : nodes[n]->hashed_peers_) {
+ const auto& link = nodes[n]->hash_to_link_[entry.link_hash];
+
+      // The link entries are already sorted by tags and then by the
+ // node ranks.
+ if (link.tag.local.IsInbound()) {
+ result +=
+ absl::StrFormat("[%s:%s:%d]", string(link.tag.local),
+ string(link.tag.remote), entry.peer->unique_rank_);
+ }
+ }
+ result.push_back(',');
+ }
+ return result;
+}
+
+Status Signature::Compute() {
+ if (map.size() > kMaxGraphSize) {
+ return Status(
+ error::INVALID_ARGUMENT,
+ absl::StrFormat(
+ "A graph of %d nodes is too big for signature computation, "
+ "the maximal supported node count is %d.",
+ map.size(), kMaxGraphSize));
+ }
+
+ // The value that will be assigned next as the unique node id.
+ // This also means that all the entries in nodes at indexes less than this
+ // have been finalized and don't need to be touched any more.
+ size_t next_node_id = 0;
+
+ sig_short = 0;
+ sig_full.resize(0); // Keep the storage.
+
+ // The main signature generation.
+ PrepareNodes();
+ FindUniqueHashes(&next_node_id);
+ while (next_node_id < map.size()) {
+ ComputeOneRound(next_node_id);
+ FindUniqueHashes(&next_node_id);
+ }
+
+ OrderLinks();
+
+ return Status::OK();
+}
+
+void Signature::PrepareNodes() {
+ nodes.resize(0); // Keep the storage.
+
+ // Initialize the nodes.
+  uint64_t mask = 1;
+ for (const auto& entry : map) {
+ SigNode* node = entry.second.get();
+ node->last_hashed_nodes_ = node->node_mask_ = mask;
+ mask <<= 1;
+ node->unique_rank_ = ~0;
+ node->hash_is_final_ = false;
+ node->ComputeTopoHash0();
+ if (node->GetHighTopoHash() <= map.size()) {
+ // Would conflict with one of the reserved values.
+ node->ReHighTopoHash();
+ }
+
+ // The initial order is random.
+ nodes.emplace_back(node);
+ }
+}
+
+void Signature::FindUniqueHashes(size_t* next_node_id_p) {
+ // Start by sorting by the hash value.
+ std::sort(nodes.begin() + *next_node_id_p, nodes.end(),
+ SigNode::NodeOrderLess());
+
+ // At each call, if no nodes have unique hashes, one node that has a
+ // non-unique (shared) hash can be made unique by assigning a unique id.
+ // This node gets picked predictably by taking the last node.
+ // TODO(babkin): Technically, more than one node can be unshared,
+ // as long as their last_hashed_nodes_ overlap only by the nodes that
+ // already had the assigned ids before the current round. But it's not clear
+  // yet how often this would be beneficial, because it looks like for many
+ // subgraphs unsharing one node should be enough to untangle them. This
+ // would need more measurement before implementing.
+ bool found_unique = false;
+ for (size_t n = *next_node_id_p; n < nodes.size(); ++n) {
+ size_t cur_hash = nodes[n]->GetHighTopoHash();
+ if (n + 1 < nodes.size() && nodes[n + 1]->GetHighTopoHash() == cur_hash) {
+ // A sequence of nodes sharing the same hash. Skip over it.
+ // TODO(babkin): check here for the arbitrary hash conflicts and resolve
+ // them.
+ for (++n;
+ n + 1 < nodes.size() && nodes[n + 1]->GetHighTopoHash() == cur_hash;
+ ++n) {
+ }
+ if (found_unique || n != nodes.size() - 1) {
+ // Either some unique nodes have already been found, or this is
+        // not the last chance; keep trying to find the unique nodes.
+ continue;
+ }
+ // Here we're at the last node and haven't found any unique ones.
+ // So fall through and make this last node unique.
+ }
+
+ found_unique = true;
+ size_t id = (*next_node_id_p)++;
+ nodes[n]->unique_rank_ = id;
+
+ size_t last_hash = nodes[n]->GetHighTopoHash();
+ CombineHash(last_hash, &sig_short);
+ sig_full.push_back(last_hash);
+
+ // Take the hash at 0 and mix the unique rank into it. After that it will
+ // stay fixed.
+ nodes[n]->topo_hash_.resize(1);
+ nodes[n]->topo_hash_[0] = id + 1; // Avoid the value of 0.
+
+ nodes[n]->hash_is_final_ = true;
+ nodes[n]->last_hashed_nodes_ = nodes[n]->node_mask_;
+ if (n != id) {
+ std::swap(nodes[id], nodes[n]);
+ }
+ }
+}
+
+void Signature::ComputeOneRound(size_t next_node_id) {
+ // Reset the state of the nodes.
+ int debug_i = 0;
+ for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) {
+ auto node = *it;
+ // The hash at distance 0 never changes, so preserve it.
+ node->topo_hash_.resize(1);
+ node->last_hashed_nodes_ = node->node_mask_;
+ node->hash_is_final_ = false;
+ if (debug) {
+ LOG(INFO) << "DEBUG distance=" << 0 << " node " << debug_i++ << " "
+ << node->name() << " mask=" << std::hex
+ << node->last_hashed_nodes_;
+ }
+ }
+
+ bool stop = false;
+ // The distance can reach up to nodes.size()+1, to include not only all the
+ // nodes but also all the redundant paths.
+ for (int distance = 1; !stop; ++distance) {
+ for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) {
+ auto node = *it;
+ if (node->hash_is_final_) {
+ continue;
+ }
+ node->ComputeTopoHash(distance);
+ if (node->GetHighTopoHash() <= nodes.size()) {
+ // Would conflict with one of the reserved values.
+ node->ReHighTopoHash();
+ }
+ }
+
+    // Now look for indications that we should not stop yet.
+ stop = true;
+
+ debug_i = 0;
+ // The bitmasks get moved after all the hash computations are done.
+ for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) {
+ auto node = *it;
+ if (debug) {
+ LOG(INFO) << "DEBUG distance=" << distance << " node " << debug_i++
+ << " " << node->name() << " oldmask=" << std::hex
+ << node->last_hashed_nodes_ << " mask=" << std::hex
+ << node->next_hashed_nodes_;
+ }
+ if (node->last_hashed_nodes_ == node->next_hashed_nodes_) {
+ // Stopped growing, this part of the graph must be fully
+ // surrounded by nodes that already have the unique ids.
+ node->hash_is_final_ = true;
+ } else {
+ node->last_hashed_nodes_ = node->next_hashed_nodes_;
+ stop = false;
+ }
+ }
+ }
+}
+
+void Signature::OrderLinks() {
+ for (const auto& node : nodes) {
+ if (node->hashed_peers_.empty()) {
+ continue;
+ }
+
+ size_t cur_link_hash = node->hashed_peers_[0].link_hash + 1;
+ int first_idx = -1;
+
+ int idx;
+ for (idx = 0; idx < node->hashed_peers_.size(); ++idx) {
+ auto& entry = node->hashed_peers_[idx];
+ if (entry.link_hash == cur_link_hash) {
+ continue;
+ }
+ if (idx - first_idx > 1) {
+ // Need to sort.
+ std::sort(node->hashed_peers_.begin() + first_idx,
+ node->hashed_peers_.begin() + idx,
+ SigNode::HashedPeer::LessByRank());
+ }
+
+ cur_link_hash = entry.link_hash;
+ first_idx = idx;
+ }
+ if (idx - first_idx > 1) {
+ // Sort the last bunch.
+ std::sort(node->hashed_peers_.begin() + first_idx,
+ node->hashed_peers_.begin() + idx,
+ SigNode::HashedPeer::LessByRank());
+ }
+ }
+}
+
+bool Signature::operator==(const Signature& other) const {
+ // Tries to find the differences as early as possible by
+ // comparing the hashes first.
+
+ if (sig_short != other.sig_short) {
+ return false;
+ }
+ if (sig_full.size() != other.sig_full.size()) {
+ return false;
+ }
+
+ for (auto it1 = sig_full.begin(), it2 = other.sig_full.begin();
+ it1 != sig_full.end(); ++it1, ++it2) {
+ if (*it1 != *it2) {
+ return false;
+ }
+ }
+
+ if (nodes.size() != other.nodes.size()) {
+ return false;
+ }
+ for (auto it1 = nodes.begin(), it2 = other.nodes.begin(); it1 != nodes.end();
+ ++it1, ++it2) {
+ if (**it1 != **it2) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/sig_node.h b/tensorflow/core/grappler/graph_analyzer/sig_node.h
new file mode 100644
index 0000000000..45c0ed3162
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/sig_node.h
@@ -0,0 +1,304 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_
+
+#include <map>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
+#include "tensorflow/core/grappler/graph_analyzer/hash_tools.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+namespace test {
+class SigBaseTest;
+} // end namespace test
+
+class SigNode;
+
+// To find nodes by name. Having the map ordered makes the tests easier,
+// and it isn't used in production code often enough to get any win from
+// using an unordered map.
+using SigNodeMap = std::map<string, std::unique_ptr<SigNode>>;
+
+// One node in the graph, in a form convenient for generating the signature
+// of the graph and comparing two (sub)graphs for equivalence. It refers to
+// the original NodeDef protobuf for most information and adds the extra
+// enrichment.
+//
+// The graph building is 2-stage: first match a SigNode with each NodeDef and
+// collect them into a map that finds them by name, then process the map,
+// deep-parse the underlying NodeDefs and connect the SigNodes together.
+class SigNode {
+ public:
+ friend struct Signature;
+
+  // Keeps a pointer to the underlying NodeDef, so the underlying
+  // object must not be deleted while the SigNode is alive.
+ explicit SigNode(const NodeDef* node);
+
+ // Access wrappers.
+ const string& name() const { return node_->name(); }
+ const string& opcode() const { return node_->op(); }
+ const NodeDef* node_def() const { return node_; }
+
+ // For extraction of subgraphs into a separate SigNodeMap, copies the links
+ // that point inside the subgraph from a full-graph SigNode to a subgraph
+ // SigNode. The translation map defines the subgraph and gives the mapping
+ // from the nodes in the full graph to the matching nodes in subgraph.
+ using TranslationMap =
+ std::unordered_map<const GenNode* /*full_graph*/, SigNode* /*subgraph*/>;
+ void CopyLinks(const GenNode& from, const TranslationMap& map);
+
+ // A link is an edge of the graph that connects 2 nodes. Each of the connected
+ // nodes has its own perspective on the link, seeing its local port, remote
+ // port and the remote node. The direction of the link is encoded in the
+ // ports, one port is always incoming and another one outgoing.
+ //
+ // The link tag here contains both ports of the link viewed from the
+ // perspective of this node; consisting of both the local port (i.e. at this
+ // node) and remote port (i.e. on the other node), the local one going first.
+ struct LinkTag {
+ struct Hasher {
+ size_t operator()(const LinkTag& tag) const noexcept {
+ size_t hval = port_hasher(tag.local);
+ CombineHash(port_hasher(tag.remote), &hval);
+ return hval;
+ }
+ GenNode::Port::Hasher port_hasher;
+ };
+
+ LinkTag(GenNode::Port a_local, GenNode::Port a_remote)
+ : local(a_local), remote(a_remote) {}
+
+ // The default constructor is used for the default values in maps.
+ // (false, 99) is an arbitrary value that makes the uninitialized
+    // links easy to spot when debugging (they should never happen).
+ LinkTag() : local(false, 99), remote(false, 99) {}
+
+ // Port of the link on the local node.
+ GenNode::Port local;
+ // Port of the link on the remote node.
+ GenNode::Port remote;
+
+ bool operator==(const LinkTag& other) const {
+ return local == other.local && remote == other.remote;
+ }
+ bool operator<(const LinkTag& other) const {
+ return local < other.local ||
+ (local == other.local && remote < other.remote);
+ }
+ };
+
+ // Since the signature logic doesn't differentiate between the links
+ // with the same tag (other than by the "peer" nodes on their other ends),
+ // all the links with the same tag are grouped into a single structure.
+ struct Link {
+ LinkTag tag;
+ size_t unique_hash; // Hash of the tag after conflict resolution.
+    // The remote node(s) on the other side of the link(s).
+ using PeerVector = std::vector<SigNode*>;
+ PeerVector peers;
+ };
+
+ // A way to look up the link description by its hash.
+ using LinkHashMap = std::map<size_t, Link>;
+ const LinkHashMap& hash_to_link() const { return hash_to_link_; }
+
+ // The enumeration of all the peer nodes in a predictable order.
+ // Before the signature generation, only the link values determine the
+  // order; after the signature generation the entries at the same
+ // links get further sorted by their peer node ranks.
+ struct HashedPeer {
+ HashedPeer(size_t l, SigNode* p) : link_hash(l), peer(p) {}
+
+ struct LessByRank {
+ bool operator()(const SigNode::HashedPeer& left,
+ const SigNode::HashedPeer& right) {
+ return left.peer->unique_rank_ < right.peer->unique_rank_;
+ }
+ };
+
+ size_t link_hash;
+ SigNode* peer;
+ };
+ using HashedPeerVector = std::vector<HashedPeer>;
+ const HashedPeerVector& hashed_peers() const { return hashed_peers_; }
+
+ // Compares two nodes in two different graphs for equivalence (two nodes in
+ // the same graph would never be equivalent). Expects that the signatures of
+ // the graphs have already been computed, so unique_rank_ is filled in and
+ // the hashed_peers_ properly ordered.
+ bool operator==(const SigNode& other) const;
+
+ bool operator!=(const SigNode& other) const { return !(*this == other); }
+
+ private:
+ friend class test::SigBaseTest;
+
+ // The CopyLinks code is split into 2 parts for testability.
+ // The first pass builds a map ordered by LinkTag for predictability.
+ void CopyLinksPass1(const GenNode& from, const TranslationMap& map,
+ std::map<LinkTag, Link>* link_map);
+  // The second pass converts it to a map keyed by the hash value,
+ // resolves any hash conflicts, and builds the hashed peer vector.
+ void CopyLinksPass2(std::map<LinkTag, Link>* link_map);
+
+ // Computes the topological hash at distance 0. Resets the topo_hash_ vector
+  // and the hashed node masks.
+ void ComputeTopoHash0();
+
+  // Computes the topological hash at the given distance. The hashes for all the
+ // lower distances must be already computed for all the nodes in the graph.
+ // Also computes next_hashed_nodes_ from last_hashed_nodes_.
+ void ComputeTopoHash(int distance);
+
+ // Get the hash value for a particular distance. It must be previously
+ // computed.
+ size_t GetTopoHash(int distance) const;
+
+  // The hash value for the highest computed distance. It must be previously
+ // computed.
+ size_t GetHighTopoHash() const {
+ CHECK(!topo_hash_.empty());
+ return topo_hash_.back();
+ }
+
+ // Rehash the topmost hash, to avoid conflicts.
+ void ReHighTopoHash() {
+ CHECK(!topo_hash_.empty());
+ CombineHash(1, &topo_hash_.back());
+ }
+
+  // Ordering of the nodes by the highest available hash (it must be
+ // previously computed).
+ struct NodeOrderLess {
+ bool operator()(const SigNode* left, const SigNode* right) {
+ return left->topo_hash_.back() < right->topo_hash_.back();
+ }
+ };
+
+ private:
+ const NodeDef* node_;
+
+ // The bitmap mask with 1 bit set that represents this node in the set
+ // during the computation of the signature.
+ uint64_t node_mask_ = 0;
+
+ // The code that populates this map makes sure that there are no hash
+ // conflicts, rehashing if necessary.
+ LinkHashMap hash_to_link_;
+
+ // The enumeration of all the direct peers in the predictable order (which
+  // happens to be the order of their link tags, but the order of the hashes
+ // would do too). It is used for the quick enumeration during the signature
+ // computation. After the signature building is completed, the entries that
+ // have the same link tag get further sorted in the order of the ranks of
+ // their nodes.
+ HashedPeerVector hashed_peers_;
+
+ // The unique rank represents the order in which the node will be included
+ // into the signature. It gets assigned in order either when the topo_hash_ of
+  // this node becomes unique in the graph, or, when several nodes are
+  // completely equivalent, when one of them is picked arbitrarily to receive
+  // the next rank and the rest then attempt to disambiguate based on that
+ // information.
+ size_t unique_rank_ = ~0;
+  // When hash_is_final_ is set, the topo_hash_ vector stops growing, and the
+ // last value from it is used for all the further hashes.
+ bool hash_is_final_ = false;
+ // The hashes that include the topology of the nodes up to the distance N. The
+ // hash for distance 0 is produced from the attributes of this node itself and
+ // its general connectivity properties but no information about the
+  // neighboring nodes. The hash for distance D+1 is built from hashes at level
+ // D of this node and of all its immediate neighbors. The neighbors that are
+ // connected by equivalent links are included in a commutative way.
+ std::vector<size_t> topo_hash_;
+ // The set of nodes that got included into the computation of the
+ // last topo_hash_ entry.
+ uint64_t last_hashed_nodes_ = 0;
+ // The next set of nodes that gets used for the current topo_hash entry.
+ uint64_t next_hashed_nodes_ = 0;
+};
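+
+// A reduced standalone illustration of the distance-hash scheme behind
+// ComputeTopoHash0()/ComputeTopoHash() (illustrative only, not used by the
+// implementation): a single link kind over plain adjacency lists. A node's
+// distance-(D+1) hash mixes its own stable distance-0 hash with the
+// distance-D hashes of its neighbors, combined commutatively because peers
+// over equivalent links have no inherent order.
+inline std::vector<size_t> SketchNextDistanceHashes(
+    const std::vector<std::vector<int>>& adjacency,
+    const std::vector<size_t>& hash0, const std::vector<size_t>& prev) {
+  std::vector<size_t> next(prev.size());
+  for (size_t n = 0; n < adjacency.size(); ++n) {
+    size_t hval = hash0[n];  // The node's own part never changes.
+    size_t comm = 0;
+    for (int peer : adjacency[n]) {
+      CombineHashCommutative(prev[peer], &comm);
+    }
+    CombineHash(comm, &hval);
+    next[n] = hval;
+  }
+  return next;
+}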
+
+// Signature of a graph. The computation is intertwined with the private methods
+// of SigNode, so keeping both in the same file looks more convenient.
+struct Signature {
+ friend class test::SigBaseTest;
+
+ // Maximal size of the graphs for which the signature can be computed.
+  // Changing this constant won't magically add support for a larger size;
+  // the rest of the implementation would have to be extended. The value of 64
+  // is driven by the number of bits in a uint64_t, and should be enough for our
+ // purposes, while having a high efficiency of implementation.
+ static constexpr int kMaxGraphSize = 64;
+
+ // Using the map, computes the rest of the fields of a signature.
+  // Returns an error if the graph is too big.
+ Status Compute();
+
+ // Convert the computed signature to a string representation.
+ string ToString() const;
+
+ SigNodeMap map; // The nodes in the graph, accessible by name.
+ size_t sig_short = 0; // Hash of the signature, for the quick equality check.
+ // The full signature: hashes of the nodes in a predictable order.
+ std::vector<size_t> sig_full;
+ // The nodes in the same order as they go in the signature.
+ std::vector<SigNode*> nodes;
+
+ // For building the unordered maps.
+ size_t Hash() const { return sig_short; }
+
+ // Returns true if the graphs are equivalent. The signature must be already
+ // computed.
+ bool operator==(const Signature& other) const;
+
+ private:
+ // Populates the nodes vector from the map and initializes the state of the
+ // nodes for the signature computation.
+ void PrepareNodes();
+
+ // Finds the nodes with the hashes that are unique and assigns the unique ids
+ // to them. If there are nodes with non-unique hashes, exactly one node from
+ // the first such sequence (in the order of hash values) will be picked and
+ // assigned a unique id. Assumes that the nodes[0...(next_node_id-1)] have
+ // been already assigned the unique ids. Advances next_node_id by at least 1.
+ void FindUniqueHashes(size_t* next_node_id_p);
+
+ // One round of the signature computation. Assumes that the
+ // nodes[0...(next_node_id-1)] have been already assigned the fixed
+ // positions, and thus computes the hashes only for the remaining nodes.
+ void ComputeOneRound(size_t next_node_id);
+
+ // Additional ordering of the hashed_peers_ links in the nodes, so that they
+ // can be compared and printed in a predictable order.
+ void OrderLinks();
+};
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_
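+
+// A usage sketch mirroring the helper used by the tests (illustrative only;
+// it assumes the Subgraph type from subgraph.h in this directory): build the
+// GenNode map from a GraphDef, extract the whole graph into a SigNodeMap,
+// then compute the signature. Two graphs are equivalent exactly when the
+// signatures computed this way compare equal.
+//
+//   Signature ComputeWholeGraphSignature(const GraphDef& graph) {
+//     GenNodeMap gen_map;
+//     CHECK(GenNode::BuildGraphInMap(graph, &gen_map).ok());
+//     Subgraph::Identity id;
+//     for (const auto& entry : gen_map) id.insert(entry.second.get());
+//     Subgraph sg(id);
+//     Signature sig;
+//     sg.ExtractForSignature(&sig.map);
+//     CHECK(sig.Compute().ok());
+//     return sig;
+//   }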
diff --git a/tensorflow/core/grappler/graph_analyzer/sig_node_test.cc b/tensorflow/core/grappler/graph_analyzer/sig_node_test.cc
new file mode 100644
index 0000000000..4c6a9ba9e0
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/sig_node_test.cc
@@ -0,0 +1,1235 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow/core/grappler/graph_analyzer/subgraph.h"
+#include "tensorflow/core/grappler/graph_analyzer/test_tools.h"
+#include "tensorflow/core/grappler/utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::Gt;
+using ::testing::Ne;
+using ::testing::SizeIs;
+
+//===
+
+TEST(SigNodeLinkTag, Compare) {
+ SigNode::LinkTag a(GenNode::Port(false, 1), GenNode::Port(false, 2));
+ SigNode::LinkTag b(GenNode::Port(false, 1), GenNode::Port(false, 2));
+ SigNode::LinkTag c(GenNode::Port(false, 2), GenNode::Port(false, 1));
+ SigNode::LinkTag d(GenNode::Port(false, 1), GenNode::Port(false, 3));
+ SigNode::LinkTag e(GenNode::Port(false, 2), GenNode::Port(false, 2));
+
+ EXPECT_TRUE(a == b);
+ EXPECT_FALSE(a == c);
+ EXPECT_FALSE(a == e);
+
+ EXPECT_FALSE(a < b);
+ EXPECT_FALSE(b < a);
+
+ EXPECT_TRUE(a < c);
+ EXPECT_FALSE(c < a);
+
+ EXPECT_TRUE(a < d);
+ EXPECT_FALSE(d < a);
+}
+
+//===
+
+class SigBaseTest : public ::testing::Test, protected TestGraphs {
+ protected:
+ void BuildSigMap(const GraphDef& graph) {
+ gen_map_.clear();
+ sig_.map.clear();
+ CHECK(GenNode::BuildGraphInMap(graph, &gen_map_).ok());
+ Subgraph::Identity id;
+ for (const auto& entry : gen_map_) {
+ id.insert(entry.second.get());
+ }
+ Subgraph sg(id);
+ sg.ExtractForSignature(&sig_.map);
+ }
+
+ static void CopyLinksPass2(
+ std::map<SigNode::LinkTag, SigNode::Link>* link_map, SigNode* node) {
+ node->CopyLinksPass2(link_map);
+ }
+
+ static void ComputeTopoHash0(SigNode* node) { node->ComputeTopoHash0(); }
+
+ static void ComputeTopoHash(int distance, SigNode* node) {
+ node->ComputeTopoHash(distance);
+ }
+
+ static size_t GetTopoHash(int distance, SigNode* node) {
+ return node->GetTopoHash(distance);
+ }
+
+ static size_t GetHighTopoHash(SigNode* node) {
+ return node->GetHighTopoHash();
+ }
+
+ static void ReHighTopoHash(SigNode* node) { node->ReHighTopoHash(); }
+
+ static SigNode::HashedPeerVector& RefHashedPeers(SigNode* node) {
+ return node->hashed_peers_;
+ }
+ static size_t& RefUniqueRank(SigNode* node) { return node->unique_rank_; }
+ static bool& RefHashIsFinal(SigNode* node) { return node->hash_is_final_; }
+ static std::vector<size_t>& RefTopoHash(SigNode* node) {
+ return node->topo_hash_;
+ }
+ static uint64_t& RefNodeMask(SigNode* node) { return node->node_mask_; }
+ static uint64_t& RefLastHashedNodes(SigNode* node) {
+ return node->last_hashed_nodes_;
+ }
+ static uint64_t& RefNextHashedNodes(SigNode* node) {
+ return node->next_hashed_nodes_;
+ }
+
+ static void PrepareNodes(Signature* signature) { signature->PrepareNodes(); }
+
+ static void FindUniqueHashes(size_t* next_node_id_p, Signature* signature) {
+ signature->FindUniqueHashes(next_node_id_p);
+ }
+
+ static void ComputeOneRound(size_t next_node_id, Signature* signature) {
+ signature->ComputeOneRound(next_node_id);
+ }
+
+ static void OrderLinks(Signature* signature) { signature->OrderLinks(); }
+
+ // These get initialized in BuildSigMap().
+ GenNodeMap gen_map_;
+ Signature sig_;
+};
+
+//===
+
+class SigNodeTest : public SigBaseTest {};
+
+// Tests that the duplicate hashes get resolved by rehashing.
+TEST_F(SigNodeTest, DuplicateHash) {
+ NodeDef node1 = MakeNodeConst("node1");
+ NodeDef node2 = MakeNodeConst("node2");
+ NodeDef node3 = MakeNodeShapeN("node3", "node1", "node2");
+
+ SigNode sn1(&node1);
+ SigNode sn2(&node2);
+ SigNode sn3(&node3);
+
+ constexpr size_t kSameHash = 999;
+
+ SigNode::Link link1;
+ link1.tag = SigNode::LinkTag(GenNode::Port(true, 0), GenNode::Port(false, 0));
+ link1.unique_hash = kSameHash;
+ link1.peers.emplace_back(&sn1);
+
+ SigNode::Link link2;
+ link2.tag = SigNode::LinkTag(GenNode::Port(true, 1), GenNode::Port(false, 0));
+ link2.unique_hash = kSameHash;
+ link2.peers.emplace_back(&sn2);
+
+ SigNode::Link link3;
+ link3.tag = SigNode::LinkTag(GenNode::Port(true, 2), GenNode::Port(false, 0));
+ link3.unique_hash = kSameHash;
+ link3.peers.emplace_back(&sn3);
+
+ std::map<SigNode::LinkTag, SigNode::Link> link_map;
+ link_map[link1.tag] = link1;
+ link_map[link2.tag] = link2;
+ link_map[link3.tag] = link3;
+
+ CopyLinksPass2(&link_map, &sn3);
+ auto& hl = sn3.hash_to_link();
+ EXPECT_THAT(hl, SizeIs(3));
+
+  // Check that the hashes are self-consistent, and put the entries into
+ // another map with a known order.
+ std::map<SigNode::LinkTag, SigNode::Link> rehashed;
+ auto hlit = hl.begin();
+ ASSERT_THAT(hlit, Ne(hl.end()));
+ EXPECT_THAT(hlit->second.unique_hash, Eq(hlit->first));
+ rehashed[hlit->second.tag] = hlit->second;
+ ++hlit;
+ ASSERT_THAT(hlit, Ne(hl.end()));
+ EXPECT_THAT(hlit->second.unique_hash, Eq(hlit->first));
+ rehashed[hlit->second.tag] = hlit->second;
+ ++hlit;
+ ASSERT_THAT(hlit, Ne(hl.end()));
+ EXPECT_THAT(hlit->second.unique_hash, Eq(hlit->first));
+ rehashed[hlit->second.tag] = hlit->second;
+
+ // Just in case.
+ ASSERT_THAT(rehashed, SizeIs(3));
+
+ auto rhit = rehashed.begin();
+ ASSERT_THAT(rhit, Ne(rehashed.end()));
+ EXPECT_TRUE(rhit->second.tag == link1.tag);
+ EXPECT_THAT(rhit->second.unique_hash, Eq(kSameHash));
+ EXPECT_THAT(rhit->second.peers, ElementsAre(&sn1));
+
+ ++rhit;
+ ASSERT_THAT(rhit, Ne(rehashed.end()));
+ EXPECT_TRUE(rhit->second.tag == link2.tag);
+ // This hash must be rehashed.
+ EXPECT_THAT(rhit->second.unique_hash, Ne(kSameHash));
+ size_t hash2 = rhit->second.unique_hash;
+ EXPECT_THAT(rhit->second.peers, ElementsAre(&sn2));
+
+ ++rhit;
+ ASSERT_THAT(rhit, Ne(rehashed.end()));
+ EXPECT_TRUE(rhit->second.tag == link3.tag);
+ // This hash must be rehashed.
+ EXPECT_THAT(rhit->second.unique_hash, Ne(kSameHash));
+ EXPECT_THAT(rhit->second.unique_hash, Ne(hash2));
+ size_t hash3 = rhit->second.unique_hash;
+ EXPECT_THAT(rhit->second.peers, ElementsAre(&sn3));
+
+ auto& peers = sn3.hashed_peers();
+ EXPECT_THAT(peers, SizeIs(3));
+
+ auto peerit = peers.begin();
+ ASSERT_THAT(peerit, Ne(peers.end()));
+ EXPECT_THAT(peerit->link_hash, Eq(kSameHash));
+ EXPECT_THAT(peerit->peer, Eq(&sn1));
+
+ ++peerit;
+ ASSERT_THAT(peerit, Ne(peers.end()));
+ EXPECT_THAT(peerit->link_hash, Eq(hash2));
+ EXPECT_THAT(peerit->peer, Eq(&sn2));
+
+ ++peerit;
+ ASSERT_THAT(peerit, Ne(peers.end()));
+ EXPECT_THAT(peerit->link_hash, Eq(hash3));
+ EXPECT_THAT(peerit->peer, Eq(&sn3));
+}
+
+// The full CopyLinks() is tested in (SubgraphTest, ExtractForSignature).
+
+TEST_F(SigNodeTest, GetTopoHash) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+
+ // Fake some hash values.
+ RefTopoHash(&sn1).emplace_back(123);
+ RefTopoHash(&sn1).emplace_back(456);
+
+ EXPECT_THAT(GetTopoHash(0, &sn1), Eq(123));
+ EXPECT_THAT(GetTopoHash(1, &sn1), Eq(456));
+
+ RefHashIsFinal(&sn1) = true;
+
+ EXPECT_THAT(GetTopoHash(0, &sn1), Eq(123));
+ EXPECT_THAT(GetTopoHash(1, &sn1), Eq(456));
+ EXPECT_THAT(GetTopoHash(2, &sn1), Eq(456));
+
+ EXPECT_THAT(GetHighTopoHash(&sn1), Eq(456));
+}
+
+TEST_F(SigNodeTest, ReTopoHash) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+
+ // Fake some hash values.
+ RefTopoHash(&sn1).emplace_back(123);
+ RefTopoHash(&sn1).emplace_back(456);
+
+ EXPECT_THAT(GetTopoHash(0, &sn1), Eq(123));
+ EXPECT_THAT(GetTopoHash(1, &sn1), Eq(456));
+
+ ReHighTopoHash(&sn1);
+
+ size_t expected_hash = 456;
+ CombineHash(1, &expected_hash);
+
+ EXPECT_THAT(GetTopoHash(0, &sn1), Eq(123));
+ EXPECT_THAT(GetTopoHash(1, &sn1), Eq(expected_hash));
+}
+
+TEST_F(SigNodeTest, ComputeTopoHash0) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+
+ // Fake a topology.
+ RefUniqueRank(&sn1) = 10;
+ RefNodeMask(&sn1) = 0x02;
+
+ RefTopoHash(&sn1).emplace_back(123);
+ RefTopoHash(&sn1).emplace_back(456);
+
+ // Fake a state.
+ RefLastHashedNodes(&sn1) = 0xFF;
+ RefNextHashedNodes(&sn1) = 0xFF;
+
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(1, nullptr));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(1, nullptr));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(2, nullptr));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(3, nullptr));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(3, nullptr));
+
+ // Run the test.
+ ComputeTopoHash0(&sn1);
+
+ EXPECT_THAT(RefLastHashedNodes(&sn1), Eq(0x02));
+ EXPECT_THAT(RefNextHashedNodes(&sn1), Eq(0x02));
+ EXPECT_THAT(RefTopoHash(&sn1), SizeIs(1));
+
+ size_t exp_hval = std::hash<string>()(sn1.opcode());
+ CombineHash(1, &exp_hval);
+ CombineHash(1, &exp_hval);
+ CombineHash(2, &exp_hval);
+ CombineHash(3, &exp_hval);
+ CombineHash(3, &exp_hval);
+
+ EXPECT_THAT(GetTopoHash(0, &sn1), Eq(exp_hval));
+}
+
+TEST_F(SigNodeTest, ComputeTopoHashNotFinal) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+ NodeDef node2 = MakeNodeConst("node2");
+ SigNode sn2(&node2);
+ NodeDef node3 = MakeNodeConst("node3");
+ SigNode sn3(&node3);
+
+ // Fake a topology.
+ RefUniqueRank(&sn1) = 0;
+ RefNodeMask(&sn1) = 0x01;
+ RefUniqueRank(&sn2) = 0;
+ RefNodeMask(&sn2) = 0x02;
+ RefUniqueRank(&sn3) = 0;
+ RefNodeMask(&sn3) = 0x04;
+
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(10, &sn2));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(10, &sn3));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(20, &sn2));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(30, &sn3));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(30, &sn2));
+
+ // Fake a state.
+ RefTopoHash(&sn1).emplace_back(123);
+ RefTopoHash(&sn1).emplace_back(321);
+
+ RefTopoHash(&sn2).emplace_back(456);
+ RefTopoHash(&sn2).emplace_back(654);
+
+ RefTopoHash(&sn3).emplace_back(789);
+ RefTopoHash(&sn3).emplace_back(987);
+
+  // These values are not realistic in that they don't include the bits from
+  // the masks of the nodes themselves, but that's the point of this test:
+  // only the previous nodes' node sets are used in the computation, not the
+  // nodes' own masks directly.
+ RefLastHashedNodes(&sn1) = 0x8;
+ RefLastHashedNodes(&sn2) = 0x10;
+ RefLastHashedNodes(&sn3) = 0x20;
+
+ // A scratch value to get overwritten.
+ RefNextHashedNodes(&sn1) = 0x100;
+
+ ComputeTopoHash(2, &sn1);
+
+ EXPECT_THAT(RefLastHashedNodes(&sn1), Eq(0x8)); // Unchanged.
+ EXPECT_THAT(RefNextHashedNodes(&sn1), Eq(0x38));
+
+  // This computes the hash from the explicit numbers above.
+ size_t exp_hash = 123; // The 0th hash is the starting point.
+ size_t comm_hash;
+
+ comm_hash = 0;
+ CombineHashCommutative(654, &comm_hash);
+ CombineHashCommutative(987, &comm_hash);
+
+ CombineHash(10, &exp_hash);
+ CombineHash(comm_hash, &exp_hash);
+
+ comm_hash = 0;
+ CombineHashCommutative(654, &comm_hash);
+
+ CombineHash(20, &exp_hash);
+ CombineHash(comm_hash, &exp_hash);
+
+ comm_hash = 0;
+ CombineHashCommutative(654, &comm_hash);
+ CombineHashCommutative(987, &comm_hash);
+
+ CombineHash(30, &exp_hash);
+ CombineHash(comm_hash, &exp_hash);
+
+ EXPECT_THAT(GetTopoHash(2, &sn1), Eq(exp_hash));
+ EXPECT_THAT(RefTopoHash(&sn1), SizeIs(3));
+}
+
+TEST_F(SigNodeTest, ComputeTopoHashFinal) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+ NodeDef node2 = MakeNodeConst("node2");
+ SigNode sn2(&node2);
+ NodeDef node3 = MakeNodeConst("node3");
+ SigNode sn3(&node3);
+
+ // Fake a topology - same as for ComputeTopoHashNotFinal.
+ RefUniqueRank(&sn1) = 0;
+ RefNodeMask(&sn1) = 0x01;
+ RefUniqueRank(&sn2) = 0;
+ RefNodeMask(&sn2) = 0x02;
+ RefUniqueRank(&sn3) = 0;
+ RefNodeMask(&sn3) = 0x04;
+
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(10, &sn2));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(10, &sn3));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(20, &sn2));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(30, &sn3));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(30, &sn2));
+
+ // Fake a state - mostly same as for ComputeTopoHashNotFinal.
+ RefTopoHash(&sn1).emplace_back(123);
+ RefTopoHash(&sn1).emplace_back(321);
+
+ RefTopoHash(&sn2).emplace_back(456);
+ RefTopoHash(&sn2).emplace_back(654);
+
+ RefTopoHash(&sn3).emplace_back(789);
+ RefTopoHash(&sn3).emplace_back(987);
+
+  // These values are not realistic in that they don't include the bits from
+  // the masks of the nodes themselves, but that's the point of this test:
+  // only the previous nodes' node sets are used in the computation, not the
+  // nodes' own masks directly.
+ RefLastHashedNodes(&sn1) = 0x8;
+ RefLastHashedNodes(&sn2) = 0x10;
+ RefLastHashedNodes(&sn3) = 0x20;
+
+ // A scratch value to get overwritten.
+ RefNextHashedNodes(&sn1) = 0x100;
+
+ // This is the difference in configuration.
+ RefHashIsFinal(&sn1) = true;
+
+ ComputeTopoHash(2, &sn1);
+
+ EXPECT_THAT(RefLastHashedNodes(&sn1), Eq(0x8)); // Unchanged.
+ EXPECT_THAT(RefNextHashedNodes(&sn1), Eq(0x8));
+ EXPECT_THAT(RefTopoHash(&sn1), SizeIs(2));
+ EXPECT_THAT(GetTopoHash(2, &sn1), Eq(321));
+}
+
+TEST_F(SigNodeTest, EqualsOpcode) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ SigNode sn2(&node2);
+
+ EXPECT_TRUE(sn1 == sn2);
+ EXPECT_FALSE(sn1 != sn2);
+
+ node2.set_op("Mul");
+
+ EXPECT_TRUE(sn1 != sn2);
+ EXPECT_FALSE(sn1 == sn2);
+}
+
+TEST_F(SigNodeTest, EqualsRank) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ SigNode sn2(&node2);
+
+ EXPECT_TRUE(sn1 == sn2);
+ EXPECT_FALSE(sn1 != sn2);
+
+ RefUniqueRank(&sn1) = 1;
+ RefUniqueRank(&sn2) = 2;
+
+ EXPECT_TRUE(sn1 != sn2);
+ EXPECT_FALSE(sn1 == sn2);
+}
+
+// Checks that if the nodes have a different number of links,
+// they will be considered unequal.
+TEST_F(SigNodeTest, EqualsLinkSize) {
+ GraphDef graph1;
+ (*graph1.add_node()) = MakeNodeConst("node1");
+ (*graph1.add_node()) = MakeNodeMul("node2", "node1", "node1");
+
+ GenNodeMap gen_map1;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph1, &gen_map1), Eq(Status::OK()));
+
+ Subgraph::Identity id1;
+ id1.insert(gen_map1["node1"].get());
+ id1.insert(gen_map1["node2"].get());
+ Subgraph sg1(id1);
+
+ SigNodeMap sig_map1;
+ sg1.ExtractForSignature(&sig_map1);
+
+ GraphDef graph2;
+ (*graph2.add_node()) = MakeNodeConst("node1");
+ // The difference between graph1 and graph2: one more input.
+ auto node22 = graph2.add_node();
+ *node22 = MakeNodeMul("node2", "node1", "node1");
+ node22->add_input("node2");
+
+ GenNodeMap gen_map2;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph2, &gen_map2), Eq(Status::OK()));
+
+ Subgraph::Identity id2;
+ id2.insert(gen_map2["node1"].get());
+ id2.insert(gen_map2["node2"].get());
+ Subgraph sg2(id2);
+
+ SigNodeMap sig_map2;
+ sg2.ExtractForSignature(&sig_map2);
+
+ EXPECT_TRUE(*sig_map1["node1"] == *sig_map2["node1"]);
+ EXPECT_FALSE(*sig_map1["node2"] == *sig_map2["node2"]);
+ EXPECT_FALSE(*sig_map2["node2"] == *sig_map1["node2"]);
+}
+
+TEST_F(SigNodeTest, EqualsLinks) {
+ // Start with 2 copies of the same graph.
+ GraphDef graph1;
+ (*graph1.add_node()) = MakeNodeConst("node1");
+ (*graph1.add_node()) = MakeNodeMul("node2", "node1", "node1");
+
+ GenNodeMap gen_map1;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph1, &gen_map1), Eq(Status::OK()));
+
+ Subgraph::Identity id1;
+ id1.insert(gen_map1["node1"].get());
+ id1.insert(gen_map1["node2"].get());
+ Subgraph sg1(id1);
+
+ SigNodeMap sig_map1;
+ sg1.ExtractForSignature(&sig_map1);
+
+ GenNodeMap gen_map2;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph1, &gen_map2), Eq(Status::OK()));
+
+ Subgraph::Identity id2;
+ id2.insert(gen_map2["node1"].get());
+ id2.insert(gen_map2["node2"].get());
+ Subgraph sg2(id2);
+
+ SigNodeMap sig_map2;
+ sg2.ExtractForSignature(&sig_map2);
+
+ EXPECT_TRUE(*sig_map1["node1"] == *sig_map2["node1"]);
+ EXPECT_TRUE(*sig_map1["node2"] == *sig_map2["node2"]);
+
+ // Alter the link hash of one of the nodes.
+ SigNode* sn2 = sig_map2["node2"].get();
+ ++RefHashedPeers(sn2)[0].link_hash;
+
+ EXPECT_FALSE(*sig_map1["node2"] == *sig_map2["node2"]);
+
+ // Restore back.
+ --RefHashedPeers(sn2)[0].link_hash;
+ EXPECT_TRUE(*sig_map1["node2"] == *sig_map2["node2"]);
+
+ // Alter the unique rank of a referenced node.
+ ++RefUniqueRank(sig_map2["node1"].get());
+
+ EXPECT_FALSE(*sig_map1["node2"] == *sig_map2["node2"]);
+}
+
+//===
+
+class SignatureTest : public SigBaseTest {
+ protected:
+  // Initializes the state used to generate the permutations of a given size.
+ static void InitPermutation(size_t size,
+ std::vector<size_t>* plain_permutation,
+ std::vector<size_t>* countdown) {
+ plain_permutation->clear();
+ countdown->clear();
+ for (size_t i = 0; i < size; ++i) {
+ plain_permutation->emplace_back(i);
+ countdown->emplace_back(size - 1 - i);
+ }
+ }
+
+ // Builds a permutation guided by the count-down value.
+ static void BuildPermutation(const std::vector<size_t>& plain_permutation,
+ const std::vector<size_t>& countdown,
+ std::vector<size_t>* result) {
+ *result = plain_permutation;
+ for (int i = 0; i < result->size(); ++i) {
+ std::swap((*result)[i], (*result)[i + countdown[i]]);
+ }
+ }
+
+ // Returns false when the count-down is finished.
+ static bool CountDown(std::vector<size_t>* countdown) {
+ // The last position always contains 0, so skip it.
+ int pos;
+ for (pos = countdown->size() - 2; pos >= 0; --pos) {
+ if ((*countdown)[pos] > 0) {
+ --(*countdown)[pos];
+ break;
+ }
+ (*countdown)[pos] = (countdown->size() - 1 - pos);
+ }
+
+ return pos >= 0;
+ }
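+
+  // For example, for size 3 the count-down values enumerate the permutations
+  // of plain_permutation = {0, 1, 2} like this (a sketch worked out by hand
+  // from the three methods above):
+  //
+  //   countdown {2, 1, 0} -> permutation {2, 0, 1}
+  //   countdown {2, 0, 0} -> permutation {2, 1, 0}
+  //   countdown {1, 1, 0} -> permutation {1, 2, 0}
+  //   countdown {1, 0, 0} -> permutation {1, 0, 2}
+  //   countdown {0, 1, 0} -> permutation {0, 2, 1}
+  //   countdown {0, 0, 0} -> permutation {0, 1, 2}
+  //
+  // i.e. all 3! = 6 distinct permutations, which is what the Permutation
+  // test below checks for size 5.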
+
+  // Permutes the nodes every which way and checks that all the signatures
+  // produced are the same. This is reasonable for graphs up to size 5, maybe
+  // 6 at a stretch. After that the number of permutations grows huge and the
+  // test becomes very slow.
+ void TestGraphEveryWay(const GraphDef& graph) {
+ size_t graph_size = graph.node_size();
+
+ gen_map_.clear();
+ sig_.map.clear();
+ Status result = GenNode::BuildGraphInMap(graph, &gen_map_);
+ ASSERT_THAT(result, Eq(Status::OK()));
+ Subgraph::Identity id;
+ for (const auto& entry : gen_map_) {
+ id.insert(entry.second.get());
+ }
+ Subgraph sg(id);
+ sg.ExtractForSignature(&sig_.map);
+
+ std::vector<size_t> plain_permutation;
+ std::vector<size_t> countdown;
+ InitPermutation(graph_size, &plain_permutation, &countdown);
+
+ std::set<string> signatures;
+ std::vector<size_t> permutation;
+ do {
+ BuildPermutation(plain_permutation, countdown, &permutation);
+
+ constexpr bool kDebugPermutation = false;
+ if (kDebugPermutation) {
+ string p;
+ for (int i = 0; i < permutation.size(); ++i) {
+ p.push_back('0' + permutation[i]);
+ }
+ LOG(INFO) << "Permutation: " << p;
+ }
+
+ std::vector<std::unique_ptr<SigNode>> hold(graph_size);
+ int idx;
+
+ // Permute the nodes.
+ sig_.nodes.clear();
+ idx = 0;
+ if (kDebugPermutation) {
+ LOG(INFO) << " nodes before permutation:";
+ }
+ for (auto& entry : sig_.map) {
+ if (kDebugPermutation) {
+ LOG(INFO) << " " << entry.second.get();
+ }
+ hold[idx++] = std::move(entry.second);
+ }
+ idx = 0;
+ if (kDebugPermutation) {
+ LOG(INFO) << " nodes after permutation:";
+ }
+ for (auto& entry : sig_.map) {
+ entry.second = std::move(hold[permutation[idx++]]);
+ if (kDebugPermutation) {
+ LOG(INFO) << " " << entry.second.get();
+ }
+ // This is used to order the links per permutation.
+ sig_.nodes.emplace_back(entry.second.get());
+ RefUniqueRank(entry.second.get()) = idx;
+ }
+ // Order the links with the same tags per permutation.
+ OrderLinks(&sig_);
+
+ // The test as such.
+ ASSERT_THAT(sig_.Compute(), Eq(Status::OK()));
+
+ signatures.insert(sig_.ToString());
+
+ EXPECT_THAT(sig_.sig_full, SizeIs(graph_size));
+ size_t hval = 0;
+ for (size_t ih : sig_.sig_full) {
+ // The space 1..graph_size is reserved.
+ EXPECT_THAT(ih, Gt(graph_size));
+ CombineHash(ih, &hval);
+ }
+ EXPECT_THAT(sig_.sig_short, Eq(hval));
+
+ // Un-permute the nodes for the next iteration.
+ idx = 0;
+ for (auto& entry : sig_.map) {
+ hold[permutation[idx++]] = std::move(entry.second);
+ }
+ idx = 0;
+ if (kDebugPermutation) {
+ LOG(INFO) << " nodes after un-permutation:";
+ }
+ for (auto& entry : sig_.map) {
+ entry.second = std::move(hold[idx++]);
+ if (kDebugPermutation) {
+ LOG(INFO) << " " << entry.second.get();
+ }
+ }
+ } while (CountDown(&countdown));
+
+ for (const auto& s : signatures) {
+ LOG(INFO) << "Signature: " << s;
+ }
+
+ // All the permutations should produce the same signature.
+ EXPECT_THAT(signatures, SizeIs(1));
+ }
+};
+
+TEST_F(SignatureTest, PrepareNodes) {
+ NodeDef node1 = MakeNodeConst("node1");
+ sig_.map["node1"] = absl::make_unique<SigNode>(&node1);
+ NodeDef node2 = MakeNodeConst("node2");
+ sig_.map["node2"] = absl::make_unique<SigNode>(&node2);
+ NodeDef node3 = MakeNodeConst("node3");
+ sig_.map["node3"] = absl::make_unique<SigNode>(&node3);
+
+ PrepareNodes(&sig_);
+
+ ASSERT_THAT(sig_.nodes, SizeIs(3));
+
+ int idx = 0;
+ for (const auto& entry : sig_.map) {
+ EXPECT_THAT(RefNodeMask(entry.second.get()), Eq(1 << idx))
+ << " at index " << idx;
+ EXPECT_THAT(RefUniqueRank(entry.second.get()), Eq(static_cast<size_t>(~0)))
+ << " at index " << idx;
+ EXPECT_THAT(RefHashIsFinal(entry.second.get()), false)
+ << " at index " << idx;
+ EXPECT_THAT(RefTopoHash(entry.second.get()), SizeIs(1))
+ << " at index " << idx;
+ ++idx;
+ }
+}
+
+TEST_F(SignatureTest, FindUniqueHashesAllDifferent) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+ NodeDef node2 = MakeNodeConst("node2");
+ SigNode sn2(&node2);
+ NodeDef node3 = MakeNodeConst("node3");
+ SigNode sn3(&node3);
+ NodeDef node4 = MakeNodeConst("node4");
+ SigNode sn4(&node4);
+
+  // The last values in the arrays go in backwards order.
+ RefTopoHash(&sn1).emplace_back(100);
+ RefTopoHash(&sn1).emplace_back(900);
+
+ RefTopoHash(&sn2).emplace_back(200);
+ RefTopoHash(&sn2).emplace_back(800);
+
+ RefTopoHash(&sn3).emplace_back(300);
+ RefTopoHash(&sn3).emplace_back(700);
+
+ RefTopoHash(&sn4).emplace_back(400);
+ RefTopoHash(&sn4).emplace_back(600);
+
+ sig_.nodes.emplace_back(&sn1);
+ sig_.nodes.emplace_back(&sn2);
+ sig_.nodes.emplace_back(&sn3);
+ sig_.nodes.emplace_back(&sn4);
+
+ size_t next = 1; // Skips over sn1.
+
+ FindUniqueHashes(&next, &sig_);
+ EXPECT_THAT(next, Eq(4));
+
+ EXPECT_THAT(sig_.nodes[0], Eq(&sn1));
+  // The nodes after the first one get sorted by the high hash.
+ EXPECT_THAT(sig_.nodes[1], Eq(&sn4));
+ EXPECT_THAT(sig_.nodes[2], Eq(&sn3));
+ EXPECT_THAT(sig_.nodes[3], Eq(&sn2));
+
+ EXPECT_THAT(RefHashIsFinal(&sn1), Eq(false));
+ // Nodes that get finalized are marked as such.
+ EXPECT_THAT(RefHashIsFinal(&sn2), Eq(true));
+ EXPECT_THAT(RefHashIsFinal(&sn3), Eq(true));
+ EXPECT_THAT(RefHashIsFinal(&sn4), Eq(true));
+
+ EXPECT_THAT(RefTopoHash(&sn1), SizeIs(2));
+ ASSERT_THAT(RefTopoHash(&sn2), SizeIs(1));
+ ASSERT_THAT(RefTopoHash(&sn3), SizeIs(1));
+ ASSERT_THAT(RefTopoHash(&sn4), SizeIs(1));
+
+ EXPECT_THAT(RefTopoHash(&sn2)[0], Eq(4));
+ EXPECT_THAT(RefTopoHash(&sn3)[0], Eq(3));
+ EXPECT_THAT(RefTopoHash(&sn4)[0], Eq(2));
+
+ EXPECT_THAT(sig_.sig_full, ElementsAre(600, 700, 800));
+
+ size_t exp_short_hash = 0;
+ CombineHash(600, &exp_short_hash);
+ CombineHash(700, &exp_short_hash);
+ CombineHash(800, &exp_short_hash);
+ EXPECT_THAT(sig_.sig_short, Eq(exp_short_hash));
+}
+
+TEST_F(SignatureTest, FindUniqueHashesDuplicatesExceptOne) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+ NodeDef node2 = MakeNodeConst("node2");
+ SigNode sn2(&node2);
+ NodeDef node3 = MakeNodeConst("node3");
+ SigNode sn3(&node3);
+ NodeDef node4 = MakeNodeConst("node4");
+ SigNode sn4(&node4);
+ NodeDef node5 = MakeNodeConst("node5");
+ SigNode sn5(&node5);
+
+ RefTopoHash(&sn1).emplace_back(100);
+ RefTopoHash(&sn1).emplace_back(600);
+
+ RefTopoHash(&sn2).emplace_back(200);
+ RefTopoHash(&sn2).emplace_back(600);
+
+ RefTopoHash(&sn3).emplace_back(300);
+ RefTopoHash(&sn3).emplace_back(700);
+
+ RefTopoHash(&sn4).emplace_back(400);
+ RefTopoHash(&sn4).emplace_back(800);
+
+ RefTopoHash(&sn5).emplace_back(500);
+ RefTopoHash(&sn5).emplace_back(800);
+
+ sig_.nodes.emplace_back(&sn1);
+ sig_.nodes.emplace_back(&sn2);
+ sig_.nodes.emplace_back(&sn3);
+ sig_.nodes.emplace_back(&sn4);
+ sig_.nodes.emplace_back(&sn5);
+
+ size_t next = 0;
+
+ FindUniqueHashes(&next, &sig_);
+ EXPECT_THAT(next, Eq(1));
+
+ // The unique node goes first.
+ EXPECT_THAT(sig_.nodes[0], Eq(&sn3));
+
+ // The rest of the nodes are assumed to be sorted in a stable order.
+ EXPECT_THAT(sig_.nodes[1], Eq(&sn2));
+ // Node 1 gets swapped with node 3.
+ EXPECT_THAT(sig_.nodes[2], Eq(&sn1));
+ EXPECT_THAT(sig_.nodes[3], Eq(&sn4));
+ EXPECT_THAT(sig_.nodes[4], Eq(&sn5));
+
+ EXPECT_THAT(RefHashIsFinal(&sn1), Eq(false));
+ EXPECT_THAT(RefHashIsFinal(&sn2), Eq(false));
+ EXPECT_THAT(RefHashIsFinal(&sn3), Eq(true));
+ EXPECT_THAT(RefHashIsFinal(&sn4), Eq(false));
+ EXPECT_THAT(RefHashIsFinal(&sn5), Eq(false));
+
+ EXPECT_THAT(RefTopoHash(&sn1), SizeIs(2));
+ EXPECT_THAT(RefTopoHash(&sn2), SizeIs(2));
+ EXPECT_THAT(RefTopoHash(&sn3), SizeIs(1));
+ EXPECT_THAT(RefTopoHash(&sn4), SizeIs(2));
+ EXPECT_THAT(RefTopoHash(&sn5), SizeIs(2));
+
+ EXPECT_THAT(RefTopoHash(&sn3)[0], Eq(1));
+}
+
+TEST_F(SignatureTest, FindUniqueHashesDuplicates) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+ NodeDef node2 = MakeNodeConst("node2");
+ SigNode sn2(&node2);
+ NodeDef node3 = MakeNodeConst("node3");
+ SigNode sn3(&node3);
+ NodeDef node4 = MakeNodeConst("node4");
+ SigNode sn4(&node4);
+ NodeDef node5 = MakeNodeConst("node5");
+ SigNode sn5(&node5);
+
+ RefTopoHash(&sn1).emplace_back(100);
+ RefTopoHash(&sn1).emplace_back(600);
+
+ RefTopoHash(&sn2).emplace_back(200);
+ RefTopoHash(&sn2).emplace_back(600);
+
+ RefTopoHash(&sn3).emplace_back(300);
+ RefTopoHash(&sn3).emplace_back(700);
+
+ RefTopoHash(&sn4).emplace_back(400);
+ RefTopoHash(&sn4).emplace_back(700);
+
+ RefTopoHash(&sn5).emplace_back(500);
+ RefTopoHash(&sn5).emplace_back(700);
+
+ sig_.nodes.emplace_back(&sn1);
+ sig_.nodes.emplace_back(&sn2);
+ sig_.nodes.emplace_back(&sn3);
+ sig_.nodes.emplace_back(&sn4);
+ sig_.nodes.emplace_back(&sn5);
+
+ size_t next = 0;
+
+ FindUniqueHashes(&next, &sig_);
+ EXPECT_THAT(next, Eq(1));
+
+ // The last copy of the last duplicate wins.
+ EXPECT_THAT(sig_.nodes[0], Eq(&sn5));
+
+ // The rest of the nodes are assumed to be sorted in a stable order.
+ // Node 1 gets swapped.
+ EXPECT_THAT(sig_.nodes[1], Eq(&sn2));
+ EXPECT_THAT(sig_.nodes[2], Eq(&sn3));
+ EXPECT_THAT(sig_.nodes[3], Eq(&sn4));
+ EXPECT_THAT(sig_.nodes[4], Eq(&sn1));
+
+ EXPECT_THAT(RefHashIsFinal(&sn1), Eq(false));
+ EXPECT_THAT(RefHashIsFinal(&sn2), Eq(false));
+ EXPECT_THAT(RefHashIsFinal(&sn3), Eq(false));
+ EXPECT_THAT(RefHashIsFinal(&sn4), Eq(false));
+ EXPECT_THAT(RefHashIsFinal(&sn5), Eq(true));
+
+ EXPECT_THAT(RefTopoHash(&sn1), SizeIs(2));
+ EXPECT_THAT(RefTopoHash(&sn2), SizeIs(2));
+ EXPECT_THAT(RefTopoHash(&sn3), SizeIs(2));
+ EXPECT_THAT(RefTopoHash(&sn4), SizeIs(2));
+ EXPECT_THAT(RefTopoHash(&sn5), SizeIs(1));
+
+ EXPECT_THAT(RefTopoHash(&sn5)[0], Eq(1));
+}
+
+// On a circular topology.
+TEST_F(SignatureTest, ComputeOneRoundCircular) {
+ BuildSigMap(graph_circular_onedir_);
+ PrepareNodes(&sig_);
+
+ ASSERT_THAT(sig_.nodes, SizeIs(5));
+
+  // This skips FindUniqueHashes(), which would pick one node, so that
+  // all the nodes are equivalent for ComputeOneRound().
+
+ ComputeOneRound(0, &sig_);
+
+ // All the nodes are the same, so the computed hashes will also be the same.
+ size_t hval = GetHighTopoHash(sig_.nodes[0]);
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_THAT(GetHighTopoHash(sig_.nodes[i]), Eq(hval)) << " at index " << i;
+ EXPECT_THAT(RefHashIsFinal(sig_.nodes[i]), Eq(true)) << " at index " << i;
+ EXPECT_THAT(RefLastHashedNodes(sig_.nodes[i]), Eq(0x1F))
+ << " at index " << i;
+ EXPECT_THAT(RefNextHashedNodes(sig_.nodes[i]), Eq(0x1F))
+ << " at index " << i;
+ // The sets of hashed nodes go like this:
+ // Step 0: self.
+ // Step 1: self, previous (-1) and next (+1) node.
+ // Step 2: self, (-1), (-2), (+1), (+2): all 5 nodes in the graph
+ // Step 3: still all 5 nodes in the graph
+ EXPECT_THAT(RefTopoHash(sig_.nodes[i]), SizeIs(4)) << " at index " << i;
+ }
+}
+
+// On a linear topology.
+TEST_F(SignatureTest, ComputeOneRoundLinear) {
+ BuildSigMap(graph_linear_);
+ PrepareNodes(&sig_);
+
+ ASSERT_THAT(sig_.nodes, SizeIs(5));
+
+  // This skips FindUniqueHashes(), which would pick one node, so that
+  // all the nodes are equivalent for ComputeOneRound().
+
+ ComputeOneRound(0, &sig_);
+
+ std::vector<size_t> hash_size;
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_THAT(RefHashIsFinal(sig_.nodes[i]), Eq(true)) << " at index " << i;
+ EXPECT_THAT(RefLastHashedNodes(sig_.nodes[i]), Eq(0x1F))
+ << " at index " << i;
+ EXPECT_THAT(RefNextHashedNodes(sig_.nodes[i]), Eq(0x1F))
+ << " at index " << i;
+ hash_size.emplace_back(RefTopoHash(sig_.nodes[i]).size());
+ }
+
+ // The sets of hashed nodes for the central node go like this:
+ // Step 0: self.
+ // Step 1: self, previous (-1) and next (+1) node.
+ // Step 2: self, (-1), (-2), (+1), (+2): all 5 nodes in the graph
+ // Step 3: still all 5 nodes in the graph
+ //
+ // The nodes one step closer to the ends require one more step. The end nodes
+ // require one more step yet.
+ std::sort(hash_size.begin(), hash_size.end());
+ EXPECT_THAT(hash_size, ElementsAre(4, 5, 5, 6, 6));
+}
+
+// On a linear topology where the central node has already been marked as
+// unique (not a very realistic case, but it tests the situation where
+// disconnected subgraphs get created).
+TEST_F(SignatureTest, ComputeOneRoundSplitLinear) {
+ BuildSigMap(graph_linear_);
+ PrepareNodes(&sig_);
+
+ ASSERT_THAT(sig_.nodes, SizeIs(5));
+
+ // This test relies on the order of SigNodeMap imposed on sig_.nodes.
+
+ // The middle node gets separated by moving it to the front.
+ std::swap(sig_.nodes[0], sig_.nodes[2]);
+ ASSERT_THAT(RefNodeMask(sig_.nodes[0]), Eq(0x04));
+ ASSERT_THAT(RefLastHashedNodes(sig_.nodes[0]), Eq(0x04));
+ ASSERT_THAT(RefNextHashedNodes(sig_.nodes[0]), Eq(0x04));
+ RefHashIsFinal(sig_.nodes[0]) = true;
+
+ ComputeOneRound(1, &sig_);
+
+ // These should stay unchanged.
+ EXPECT_THAT(RefLastHashedNodes(sig_.nodes[0]), Eq(0x04));
+ EXPECT_THAT(RefNextHashedNodes(sig_.nodes[0]), Eq(0x04));
+
+ std::vector<size_t> hash_size;
+ for (int i = 1; i < 5; ++i) {
+ EXPECT_THAT(RefHashIsFinal(sig_.nodes[i]), Eq(true)) << " at index " << i;
+ hash_size.emplace_back(RefTopoHash(sig_.nodes[i]).size());
+ }
+
+ std::sort(hash_size.begin(), hash_size.end());
+ // The end nodes take 4 steps, closer to the center 3 steps.
+ EXPECT_THAT(hash_size, ElementsAre(3, 3, 4, 4));
+
+ EXPECT_THAT(RefLastHashedNodes(sig_.nodes[1]), Eq(0x07));
+ EXPECT_THAT(RefNextHashedNodes(sig_.nodes[1]), Eq(0x07));
+ EXPECT_THAT(RefLastHashedNodes(sig_.nodes[2]), Eq(0x07));
+ EXPECT_THAT(RefNextHashedNodes(sig_.nodes[2]), Eq(0x07));
+
+ EXPECT_THAT(RefLastHashedNodes(sig_.nodes[3]), Eq(0x1C));
+ EXPECT_THAT(RefNextHashedNodes(sig_.nodes[3]), Eq(0x1C));
+ EXPECT_THAT(RefLastHashedNodes(sig_.nodes[4]), Eq(0x1C));
+ EXPECT_THAT(RefNextHashedNodes(sig_.nodes[4]), Eq(0x1C));
+}
+
+TEST_F(SignatureTest, OrderLinks) {
+ gen_map_.clear();
+ sig_.map.clear();
+ Status result = GenNode::BuildGraphInMap(graph_for_link_order_, &gen_map_);
+ ASSERT_THAT(result, Eq(Status::OK()));
+ Subgraph::Identity id;
+ for (const auto& entry : gen_map_) {
+ id.insert(entry.second.get());
+ }
+ Subgraph sg(id);
+ sg.ExtractForSignature(&sig_.map);
+
+  // Populate the fake signature and assign the ranks in backwards order.
+ for (auto it = sig_.map.rbegin(); it != sig_.map.rend(); ++it) {
+ auto& entry = *it;
+ RefUniqueRank(entry.second.get()) = sig_.nodes.size();
+ sig_.nodes.emplace_back(entry.second.get());
+ }
+
+ // How it was ordered in the original graph.
+ string before = sig_.ToString();
+ // clang-format off
+ EXPECT_THAT(before, Eq(
+ "0:Mul[i0:o0:5][i0:o0:4][i0:o1:4][i0:o2:3][i0:o2:2][i0:o3:2],"
+ "1:Mul[i0:o0:5][i0:o0:4][i0:o0:3][i0:o0:2],"
+ "2:Const,"
+ "3:Const,"
+ "4:Const,"
+ "5:Const,"
+ ));
+ // clang-format on
+
+ OrderLinks(&sig_);
+
+ string after = sig_.ToString();
+ // clang-format off
+ EXPECT_THAT(after, Eq(
+ "0:Mul[i0:o0:4][i0:o0:5][i0:o1:4][i0:o2:2][i0:o2:3][i0:o3:2],"
+ "1:Mul[i0:o0:2][i0:o0:3][i0:o0:4][i0:o0:5],"
+ "2:Const,"
+ "3:Const,"
+ "4:Const,"
+ "5:Const,"
+ ));
+ // clang-format on
+}
+
+TEST_F(SignatureTest, GraphTooBig) {
+ GraphDef graph;
+ for (int i = 0; i <= Signature::kMaxGraphSize; ++i) {
+ (*graph.add_node()) = MakeNodeConst(absl::StrFormat("node%d", i));
+ }
+
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &gen_map_), Eq(Status::OK()));
+
+ Subgraph::Identity id;
+ for (const auto& entry : gen_map_) {
+ id.insert(entry.second.get());
+ }
+ Subgraph sg(id);
+ sg.ExtractForSignature(&sig_.map);
+
+ ASSERT_THAT(sig_.Compute(),
+ Eq(Status(error::INVALID_ARGUMENT,
+ "A graph of 65 nodes is too big for signature "
+ "computation, the maximal supported node count is "
+ "64.")));
+}
+
+TEST_F(SignatureTest, ToString) {
+ BuildSigMap(graph_circular_onedir_);
+ PrepareNodes(&sig_);
+
+ ASSERT_THAT(sig_.nodes, SizeIs(5));
+
+  // Fake the whole computation by assigning unique ranks in the initial order.
+ for (int i = 0; i < 5; ++i) {
+ RefUniqueRank(sig_.nodes[i]) = i;
+ RefHashIsFinal(sig_.nodes[i]) = true;
+ }
+
+ string result = sig_.ToString();
+
+ // clang-format off
+ ASSERT_THAT(result, Eq(
+ "0:Mul[i0:o0:4][i0:o0:4],"
+ "1:Mul[i0:o0:0][i0:o0:0],"
+ "2:Mul[i0:o0:1][i0:o0:1],"
+ "3:Mul[i0:o0:2][i0:o0:2],"
+ "4:Mul[i0:o0:3][i0:o0:3],"
+ ));
+ // clang-format on
+}
+
+// This is a test of the permutation logic itself.
+TEST_F(SignatureTest, Permutation) {
+ std::vector<size_t> plain_permutation;
+ std::vector<size_t> countdown;
+ InitPermutation(5, &plain_permutation, &countdown);
+
+ std::set<string> results;
+
+ std::vector<size_t> permutation;
+ do {
+ BuildPermutation(plain_permutation, countdown, &permutation);
+ EXPECT_THAT(permutation, SizeIs(5));
+
+ string p;
+ for (int i = 0; i < permutation.size(); ++i) {
+ p.push_back('0' + permutation[i]);
+ }
+ LOG(INFO) << "Permutation: " << p;
+ results.insert(p);
+ } while (CountDown(&countdown));
+
+ EXPECT_THAT(results, SizeIs(5 * 4 * 3 * 2 * 1));
+}
+
+TEST_F(SignatureTest, ComputeCircularOneDir) {
+ TestGraphEveryWay(graph_circular_onedir_);
+}
+
+TEST_F(SignatureTest, ComputeCircularBiDir) {
+ TestGraphEveryWay(graph_circular_bidir_);
+}
+
+TEST_F(SignatureTest, ComputeLinear) { TestGraphEveryWay(graph_linear_); }
+
+TEST_F(SignatureTest, ComputeMultiInput) {
+ TestGraphEveryWay(graph_multi_input_);
+}
+
+TEST_F(SignatureTest, ComputeAllOrNone) {
+ TestGraphEveryWay(graph_all_or_none_);
+}
+
+TEST_F(SignatureTest, ComputeCross) { TestGraphEveryWay(graph_small_cross_); }
+
+TEST_F(SignatureTest, Equals) {
+ // Start with 2 copies of the same graph.
+ GenNodeMap gen_map1;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph_circular_bidir_, &gen_map1),
+ Eq(Status::OK()));
+
+ Subgraph::Identity id1;
+ id1.insert(gen_map1["node1"].get());
+ id1.insert(gen_map1["node2"].get());
+ Subgraph sg1(id1);
+
+ Signature sig1;
+ sg1.ExtractForSignature(&sig1.map);
+ ASSERT_THAT(sig1.Compute(), Eq(Status::OK()));
+
+ GenNodeMap gen_map2;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph_circular_bidir_, &gen_map2),
+ Eq(Status::OK()));
+
+ Subgraph::Identity id2;
+ id2.insert(gen_map2["node1"].get());
+ id2.insert(gen_map2["node2"].get());
+ Subgraph sg2(id2);
+
+ Signature sig2;
+ sg2.ExtractForSignature(&sig2.map);
+ ASSERT_THAT(sig2.Compute(), Eq(Status::OK()));
+
+ EXPECT_TRUE(sig1 == sig2);
+
+ // Change the short hash.
+ ++sig2.sig_short;
+ EXPECT_FALSE(sig1 == sig2);
+
+ // Restore back.
+ --sig2.sig_short;
+ EXPECT_TRUE(sig1 == sig2);
+
+ // Change the full hash.
+ ++sig2.sig_full[0];
+ EXPECT_FALSE(sig1 == sig2);
+
+ // Restore back.
+ --sig2.sig_full[0];
+ EXPECT_TRUE(sig1 == sig2);
+
+ // Make the nodes different.
+ std::swap(sig2.nodes[0], sig2.nodes[1]);
+ EXPECT_FALSE(sig1 == sig2);
+
+ // Restore back.
+ std::swap(sig2.nodes[0], sig2.nodes[1]);
+ EXPECT_TRUE(sig1 == sig2);
+
+ // Different number of nodes.
+ sig2.nodes.emplace_back(sig2.nodes[0]);
+ EXPECT_FALSE(sig1 == sig2);
+ EXPECT_FALSE(sig2 == sig1);
+}
+
+} // end namespace test
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/subgraph.cc b/tensorflow/core/grappler/graph_analyzer/subgraph.cc
new file mode 100644
index 0000000000..28a91e0f84
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/subgraph.cc
@@ -0,0 +1,235 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/subgraph.h"
+
+#include <functional>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/core/grappler/graph_analyzer/hash_tools.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+//=== Subgraph::Identity
+
+Subgraph::Identity::Identity(InitializerList init) {
+ for (auto element : init) {
+ insert(element);
+ }
+}
+
+bool Subgraph::Identity::operator<(const Identity& other) const {
+ // Shorter sets go first.
+ if (this->size() < other.size()) {
+ return true;
+ }
+ if (this->size() > other.size()) {
+ return false;
+ }
+ for (auto lit = this->begin(), rit = other.begin(); lit != this->end();
+ ++lit, ++rit) {
+ if (*lit < *rit) {
+ return true;
+ }
+ if (*lit > *rit) {
+ return false;
+ }
+ }
+ return false; // Equal.
+}
+
+bool Subgraph::Identity::operator==(const Identity& other) const {
+ if (this->size() != other.size()) {
+ return false;
+ }
+ for (auto lit = this->begin(), rit = other.begin(); lit != this->end();
+ ++lit, ++rit) {
+ if (*lit != *rit) {
+ return false;
+ }
+ }
+ return true; // Equal.
+}
+
+size_t Subgraph::Identity::Hash() const {
+ std::hash<const GenNode*> hasher;
+ size_t result = 0;
+ for (auto ptr : *this) {
+ CombineHash(hasher(ptr), &result);
+ }
+ return result;
+}
+
+string Subgraph::Dump() {
+ // TODO(babkin): this is simplified for now.
+ std::vector<string> nodes;
+ for (const auto& n : id_) {
+ if (specific_) {
+ nodes.emplace_back(absl::StrFormat("%s(%s)", n->opcode(), n->name()));
+ } else {
+ nodes.emplace_back(n->opcode());
+ }
+ }
+ std::sort(nodes.begin(), nodes.end());
+
+ return absl::StrFormat("%d: ", collation_count_) + absl::StrJoin(nodes, ", ");
+}
+
+void Subgraph::ExtractForSignature(SigNodeMap* result) {
+ // Mapping of nodes from the original graph to the new one.
+ SigNode::TranslationMap full_to_new;
+
+ for (auto node : id_) {
+ auto newnode_ref = absl::make_unique<SigNode>(node->node_def());
+ auto newnode = newnode_ref.get();
+ (*result)[node->name()] = std::move(newnode_ref);
+ full_to_new[node] = newnode;
+ }
+
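+  // With all the new nodes created, copy over the links between them; per the
+  // contract in subgraph.h, the links leading outside the subgraph (i.e. to
+  // nodes missing from the translation map) get dropped.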
+ for (const auto& mapping : full_to_new) {
+ mapping.second->CopyLinks(*mapping.first, full_to_new);
+ }
+}
+
+//=== Subgraph
+
+Subgraph::Subgraph(const Identity& parent_id, GenNode* add_node)
+ : id_(parent_id) {
+ id_.insert(add_node);
+ hash_ = id_.Hash();
+}
+
+//=== SubgraphIterator
+
+SubgraphIterator::SubgraphIterator(const Subgraph::Identity* id)
+ : id_(id), id_it_(id_->begin()) {
+ if (!id_->empty()) {
+ link_map_it_ = (*id_it_)->links().begin();
+    // In case the node has no links.
+ while (link_map_it_ == (*id_it_)->links().end()) {
+ if (++id_it_ == id_->end()) {
+ return;
+ }
+ link_map_it_ = (*id_it_)->links().begin();
+ }
+ link_idx_ = 0;
+    // The LinkTargetVector should never be empty, but safeguard against
+    // that too, just in case.
+ PropagateNext();
+ }
+}
+
+bool SubgraphIterator::Next() {
+ if (AtEnd()) {
+ return false;
+ }
+ ++link_idx_;
+ return PropagateNext();
+}
+
+bool SubgraphIterator::NextIfSamePort() {
+ if (AtEnd()) {
+ return false;
+ }
+ if (link_idx_ + 1 < link_map_it_->second.size()) {
+ ++link_idx_;
+ return true;
+ } else {
+ return false;
+ }
+}
+
+void SubgraphIterator::SkipPort() {
+ if (AtEnd()) {
+ return;
+ }
+ link_idx_ = link_map_it_->second.size() - 1;
+}
+
+void SubgraphIterator::SkipNode() {
+ if (AtEnd()) {
+ return;
+ }
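+  // Walk forward to the last entry of the link map; its iterator type is not
+  // necessarily bidirectional, so it cannot simply be decremented from end().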
+ for (auto next = link_map_it_; next != (*id_it_)->links().end(); ++next) {
+ link_map_it_ = next;
+ }
+ link_idx_ = link_map_it_->second.size() - 1;
+}
+
+bool SubgraphIterator::PropagateNext() {
+ // Loops are used to skip over the empty entries.
+ while (link_idx_ >= link_map_it_->second.size()) {
+ ++link_map_it_;
+ while (link_map_it_ == (*id_it_)->links().end()) {
+ if (++id_it_ == id_->end()) {
+ return false;
+ }
+ link_map_it_ = (*id_it_)->links().begin();
+ }
+ link_idx_ = 0;
+ }
+ return true;
+}
+
+bool SubgraphIterator::operator==(const SubgraphIterator& other) const {
+ if (id_ != other.id_) {
+ return false;
+ }
+ if (id_it_ != other.id_it_) {
+ return false;
+ }
+ // When AtEnd(), the rest of the fields are not valid.
+ if (AtEnd()) {
+ return true;
+ }
+ if (link_map_it_ != other.link_map_it_) {
+ return false;
+ }
+ if (link_idx_ != other.link_idx_) {
+ return false;
+ }
+ return true;
+}
+
+//=== SubgraphPtrSet
+
+Subgraph* SubgraphPtrSet::ExtendParent(const Subgraph::Identity& parent_id,
+ GenNode* node) {
+ if (parent_id.find(node) != parent_id.end()) {
+ // This was another link to the node that is already in the parent.
+ return nullptr;
+ }
+
+  // Constructing an object just to check that an equivalent one is already
+  // present is kind of ugly, but storing the references rather than the
+  // objects in the set avoids the need to make the objects copyable.
+ auto sg = absl::make_unique<Subgraph>(parent_id, node);
+ if (find(sg) != end()) {
+ // This subgraph was already found by extending from a different path.
+ return nullptr;
+ }
+
+ Subgraph* ptr = sg.get();
+ insert(std::move(sg));
+ return ptr;
+}
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/subgraph.h b/tensorflow/core/grappler/graph_analyzer/subgraph.h
new file mode 100644
index 0000000000..4de31d5dfa
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/subgraph.h
@@ -0,0 +1,189 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_
+
+#include <initializer_list>
+#include <set>
+
+#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
+#include "tensorflow/core/grappler/graph_analyzer/map_tools.h"
+#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+// The description of a single subgraph for processing.
+class Subgraph {
+ public:
+ // Identity of a single subgraph as a set of nodes.
+ class Identity : public gtl::FlatSet<const GenNode*> {
+ public:
+ using InitializerList = std::initializer_list<GenNode*>;
+
+ Identity() = default;
+ Identity(InitializerList init);
+ bool operator<(const Identity& other) const;
+ bool operator==(const Identity& other) const;
+
+ // Compute the hash.
+ size_t Hash() const;
+ };
+
+ explicit Subgraph(Identity id) : id_(std::move(id)), hash_(id_.Hash()) {}
+
+ // Construct by extending the parent identity with an extra node.
+ Subgraph(const Identity& parent_id, GenNode* add_node);
+
+ Subgraph() = delete;
+ Subgraph(const Subgraph& other) = delete;
+ void operator=(const Subgraph& other) = delete;
+
+ // Order for building sets of subgraphs.
+ bool operator<(const Subgraph& other) const { return this->id_ < other.id_; }
+ // Support for hashed sets.
+ bool operator==(const Subgraph& other) const {
+ return this->id_ == other.id_;
+ }
+ size_t Hash() const { return hash_; }
+
+ // Dump the subgraph information to a string.
+ string Dump();
+
+  // Extract this subgraph into a separate graph representation for signature
+  // building, which includes only the links between the nodes in the subgraph
+  // and drops all the external links. The result map should be empty before
+  // the call.
+ void ExtractForSignature(SigNodeMap* result);
+
+ const Identity& id() const { return id_; }
+ bool specific() const { return specific_; }
+ void SetSpecific(bool value) { specific_ = value; }
+ int32_t collation_count() const { return collation_count_; }
+ void AddCollation(int32_t n = 1) { collation_count_ += n; }
+ void ResetCollation() { collation_count_ = 1; }
+ void MergeCollation(const Subgraph& other) {
+ collation_count_ += other.collation_count_;
+ }
+
+ private:
+ // Identity also serves as the list of nodes. It never changes throughout the
+ // life of subgraph.
+ Identity id_;
+ size_t hash_; // Cached from the identity.
+ // Whether the dump should include the specific names of the nodes. The
+ // non-specific (i.e. generic) subgraphs represent a collation of multiple
+ // subgraphs.
+ bool specific_ = true;
+ // How many collated subgraphs are represented by this subgraph.
+ int32_t collation_count_ = 1;
+};
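+
+// A minimal construction sketch, mirroring how the tests build subgraphs
+// (here `map` is assumed to be a GenNodeMap already filled by
+// GenNode::BuildGraphInMap()):
+//
+//   Subgraph::Identity id;
+//   id.insert(map["node1"].get());
+//   id.insert(map["node2"].get());
+//   Subgraph sg(id);  // The identity hash is computed once and cached.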
+
+// Iteration of all links in a subgraph. This is more like a Java iterator
+// than a normal C++ iterator. It's simpler this way and there seems to be no
+// major reason to make it a proper C++ iterator.
+class SubgraphIterator {
+ public:
+ // Obviously an iterator is valid only until the original object
+ // gets destroyed.
+ explicit SubgraphIterator(const Subgraph::Identity* id);
+ explicit SubgraphIterator(const Subgraph* sg) : SubgraphIterator(&sg->id()) {}
+
+ // Check whether the built-in iterator is at the end.
+ bool AtEnd() const { return id_it_ == id_->end(); }
+
+ // Get the neighbor at the current iterator.
+  // MUST NOT be called when AtEnd().
+ const GenNode::LinkTarget& GetNeighbor() const {
+ return link_map_it_->second[link_idx_];
+ }
+
+ // Get the node at the current iterator.
+  // MUST NOT be called when AtEnd().
+ const GenNode* GetNode() const { return *id_it_; }
+
+ // Get the port leading to the neighbor at the current iterator.
+  // MUST NOT be called when AtEnd().
+ GenNode::Port GetPort() const { return link_map_it_->first; }
+
+ // Increases the iterator.
+ // Returns true if NOT AtEnd() after increasing the iterator.
+ // Safe to call if already AtEnd().
+ bool Next();
+
+ // If there are more links at the same port, increases the iterator and
+ // returns true. Otherwise leaves the iterator unchanged and returns false.
+ bool NextIfSamePort();
+
+ // Increases the iterator directly to the last position on the current port
+ // (or if already there then doesn't increase). Equivalent to calling
+ // NextIfSamePort() while it returns true, but faster.
+ // Safe to call if already AtEnd().
+ void SkipPort();
+
+ // Increases the iterator directly to the last position on the current node.
+ // Safe to call if already AtEnd().
+ void SkipNode();
+
+ // Returns true if the iterators are exactly the same.
+ bool operator==(const SubgraphIterator& other) const;
+ bool operator!=(const SubgraphIterator& other) const {
+ return !(*this == other);
+ }
+
+ private:
+ // After link_idx_ has been increased, make sure that it points to the
+ // next valid element (or end) by increasing the higher levels of iteration if
+ // needed.
+ // Returns true if NOT AtEnd() after increasing the iterator.
+ // NOT safe to call if already AtEnd().
+ bool PropagateNext();
+
+ // Identity of the subgraph being iterated over.
+ const Subgraph::Identity* id_;
+
+  // The current position, allowing iteration through the links (see the
+  // reasoning for it in the public section).
+ //
+ // (1) Iterator of the nodes in the subgraph.
+ Subgraph::Identity::const_iterator id_it_;
+ // (2) Iterator in the link map of the node.
+ GenNode::LinkMap::const_iterator link_map_it_;
+ // (3) Index in the vector of the links.
+ int32_t link_idx_;
+};
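+
+// A typical iteration sketch, as used in the tests (`sg` is a Subgraph):
+//
+//   for (SubgraphIterator sit(&sg); !sit.AtEnd(); sit.Next()) {
+//     const GenNode::LinkTarget& target = sit.GetNeighbor();
+//     ...  // Use target.node and target.port.
+//   }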
+
+// A convenient way to store subgraphs: in a set of unique_ptrs. This way the
+// addresses of subgraph objects will stay stable, and the objects themselves
+// won't be copied.
+class SubgraphPtrSet
+ : public std::unordered_set<std::unique_ptr<Subgraph>,
+ HashAtPtr<std::unique_ptr<Subgraph>>,
+ EqAtPtr<std::unique_ptr<Subgraph>>> {
+ public:
+ // Attempts to extend the set by adding a new subgraph that gets created by
+ // adding one node to the parent subgraph. If such a subgraph already exists,
+ // returns nullptr, otherwise returns the pointer to the new subgraph.
+ Subgraph* ExtendParent(const Subgraph::Identity& parent_id, GenNode* node);
+};
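+
+// A usage sketch of ExtendParent(), following the ExtendSet test in
+// subgraph_test.cc: a successful extension creates, stores and returns a new
+// subgraph, while extending by a node already in the parent (or to an
+// already-known subgraph) returns nullptr:
+//
+//   SubgraphPtrSet set;
+//   Subgraph* sg = set.ExtendParent(Subgraph::Identity(), node);  // New.
+//   Subgraph* dup = set.ExtendParent(sg->id(), node);             // nullptr.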
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_
diff --git a/tensorflow/core/grappler/graph_analyzer/subgraph_test.cc b/tensorflow/core/grappler/graph_analyzer/subgraph_test.cc
new file mode 100644
index 0000000000..0f90dc8f0d
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/subgraph_test.cc
@@ -0,0 +1,348 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/subgraph.h"
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow/core/grappler/graph_analyzer/test_tools.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::Ne;
+
+TEST(SubgraphTest, Comparison) {
+ GraphDef graph;
+ // A topology with a loop.
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeConst("node2");
+ GenNodeMap map;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK()));
+ auto gn1 = map["node1"].get();
+ auto gn2 = map["node2"].get();
+ ASSERT_THAT(gn1, Ne(nullptr));
+ ASSERT_THAT(gn2, Ne(nullptr));
+
+ Subgraph::Identity id1;
+ Subgraph::Identity id2;
+
+ id1.insert(gn1);
+ id2.insert(gn2);
+
+ Subgraph sg1(id1);
+ Subgraph sg2(id2);
+
+ EXPECT_TRUE(id1 == sg1.id());
+ EXPECT_TRUE(id2 == sg2.id());
+
+ EXPECT_THAT(sg1 < sg2, Eq(id1 < id2));
+}
+
+TEST(SubgraphTest, EmptyIteration) {
+ NodeDef node1 = MakeNodeConst("node1");
+ auto gn1 = absl::make_unique<GenNode>(&node1);
+ Subgraph::Identity id1;
+ id1.insert(gn1.get());
+ Subgraph sg1(id1);
+ SubgraphIterator sit(&sg1);
+
+ EXPECT_TRUE(sit.AtEnd());
+ EXPECT_FALSE(sit.Next());
+ EXPECT_TRUE(sit.AtEnd());
+
+ SubgraphIterator sit2(&sg1);
+ EXPECT_TRUE(sit == sit2);
+}
+
+TEST(SubgraphTest, Iteration) {
+ GraphDef graph;
+ // A topology with a loop.
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0");
+ auto node3 = graph.add_node();
+ *node3 = MakeNodeBroadcastGradientArgs("node3", "node1", "node2");
+ node3->add_input("^node3"); // The control link goes back to self.
+
+ GenNodeMap map;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK()));
+ ASSERT_THAT(map.find("node3"), Ne(map.end()));
+
+ Subgraph::Identity id;
+ id.insert(map["node3"].get());
+ Subgraph sg(id);
+
+  // node3 has 2 incoming data links, 2 outgoing data links, 1 incoming
+  // control link, and 1 outgoing control link = 6 total.
+ {
+ SubgraphIterator sit(&sg);
+ EXPECT_FALSE(sit.AtEnd()); // 1
+ EXPECT_TRUE(sit.Next());
+ EXPECT_FALSE(sit.AtEnd()); // 2
+ EXPECT_TRUE(sit.Next());
+ EXPECT_FALSE(sit.AtEnd()); // 3
+ EXPECT_TRUE(sit.Next());
+ EXPECT_FALSE(sit.AtEnd()); // 4
+ EXPECT_TRUE(sit.Next());
+ EXPECT_FALSE(sit.AtEnd()); // 5
+ EXPECT_TRUE(sit.Next());
+ EXPECT_FALSE(sit.AtEnd()); // 6
+ EXPECT_FALSE(sit.Next());
+ EXPECT_TRUE(sit.AtEnd());
+ }
+
+ // Now get the values out. And more equality testing along the way.
+ {
+ SubgraphIterator sit(&sg);
+ SubgraphIterator sit2(&sg);
+ std::vector<string> links;
+ for (; !sit.AtEnd(); sit.Next()) {
+ EXPECT_TRUE(sit == sit2);
+ sit2.Next();
+ EXPECT_FALSE(sit == sit2);
+
+ links.push_back(absl::StrFormat("[%s,%s,%s]", string(sit.GetPort()),
+ sit.GetNeighbor().node->name(),
+ string(sit.GetNeighbor().port)));
+ }
+ EXPECT_TRUE(sit == sit2);
+
+ std::sort(links.begin(), links.end());
+ // clang-format off
+ EXPECT_THAT(links, ElementsAre(
+ "[i0,node1,o0]",
+ "[i1,node2,o0]",
+ "[iC,node3,oC]",
+ "[o0,node2,i1]",
+ "[o1,node2,i0]",
+ "[oC,node3,iC]"
+ ));
+ // clang-format on
+ }
+}
+
+TEST(SubgraphTest, IterationSamePort) {
+ GraphDef graph;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3", "node3");
+ (*graph.add_node()) = MakeNodeAddN("node3", "node1", "node2");
+
+ GenNodeMap map;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK()));
+ ASSERT_THAT(map.find("node3"), Ne(map.end()));
+
+ Subgraph::Identity id;
+ id.insert(map["node3"].get());
+ Subgraph sg(id);
+
+ int total_links = 0;
+ for (SubgraphIterator sit(&sg); !sit.AtEnd(); sit.Next()) {
+ ++total_links;
+ }
+
+ // Initialize the port as control, which doesn't occur in this graph.
+ GenNode::Port last_port(false, -1);
+ int steps_total_same_port = 0;
+ int steps_with_same_port = 0;
+ for (SubgraphIterator sit(&sg); !sit.AtEnd(); sit.Next()) {
+ GenNode::Port new_port = sit.GetPort();
+ EXPECT_THAT(last_port.Encoded(), Ne(new_port.Encoded()))
+ << "At step " << steps_total_same_port;
+ last_port = new_port;
+
+ ++steps_total_same_port;
+
+ SubgraphIterator sit2(sit);
+ sit2.SkipPort();
+
+ while (sit.NextIfSamePort()) {
+ new_port = sit.GetPort();
+ EXPECT_THAT(last_port.Encoded(), Eq(new_port.Encoded()))
+ << "At step " << steps_total_same_port;
+ ++steps_total_same_port;
+ ++steps_with_same_port;
+ }
+
+ EXPECT_TRUE(sit == sit2);
+ }
+
+ EXPECT_THAT(steps_total_same_port, Eq(total_links));
+ // There is one 2-way input and one 2-way output.
+ EXPECT_THAT(steps_with_same_port, Eq(2));
+}
+
+TEST(SubgraphTest, IterationSameNode) {
+ GraphDef graph;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3", "node3");
+ (*graph.add_node()) = MakeNodeAddN("node3", "node1", "node2");
+
+ GenNodeMap map;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK()));
+ ASSERT_THAT(map.find("node3"), Ne(map.end()));
+
+ Subgraph::Identity id;
+ id.insert(map["node3"].get());
+ Subgraph sg(id);
+
+ const GenNode* last_node = nullptr;
+ SubgraphIterator sit(&sg);
+ while (!sit.AtEnd()) {
+ const GenNode* new_node = sit.GetNode();
+
+ EXPECT_THAT(new_node, Ne(last_node)) << "At node " << new_node->name();
+
+ SubgraphIterator sit2(sit);
+ sit2.SkipNode();
+
+ ASSERT_FALSE(sit2.AtEnd());
+ EXPECT_THAT(sit2.GetNode(), Eq(new_node))
+ << "At expected node " << new_node->name() << ", got "
+ << sit2.GetNode()->name();
+
+ while (sit != sit2 && !sit.AtEnd()) {
+ sit.Next();
+ }
+
+ ASSERT_FALSE(sit.AtEnd());
+ EXPECT_THAT(sit.GetNode(), Eq(new_node))
+ << "At expected node " << new_node->name() << ", got "
+ << sit2.GetNode()->name();
+
+ sit.Next();
+
+ last_node = new_node;
+ }
+
+ // Check that it doesn't fail if already at end.
+ sit.SkipNode();
+ EXPECT_TRUE(sit.AtEnd());
+}
+
+TEST(SubgraphTest, ExtendSet) {
+ GraphDef graph;
+ // A topology with a loop.
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0");
+ auto node3 = graph.add_node();
+ *node3 = MakeNodeBroadcastGradientArgs("node3", "node1", "node2");
+ node3->add_input("^node3"); // The control link goes back to self.
+
+ GenNodeMap map;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK()));
+ ASSERT_THAT(map.find("node2"), Ne(map.end()));
+ ASSERT_THAT(map.find("node3"), Ne(map.end()));
+
+ Subgraph::Identity id_empty;
+
+ Subgraph::Identity id3;
+ id3.insert(map["node3"].get());
+
+ Subgraph::Identity id23 = id3;
+ id23.insert(map["node2"].get());
+
+ Subgraph* sg;
+ SubgraphPtrSet set;
+
+ // Extend an empty identity.
+ sg = set.ExtendParent(id_empty, map["node3"].get());
+ EXPECT_THAT(set.size(), Eq(1));
+ ASSERT_THAT(sg, Ne(nullptr));
+ EXPECT_TRUE(sg->id() == id3);
+
+ // Extend with a node that is already in the parent.
+ sg = set.ExtendParent(id3, map["node3"].get());
+ EXPECT_THAT(set.size(), Eq(1));
+ EXPECT_THAT(sg, Eq(nullptr));
+
+ // Extend to a 2-node subgraph.
+ sg = set.ExtendParent(id3, map["node2"].get());
+ EXPECT_THAT(set.size(), Eq(2));
+ ASSERT_THAT(sg, Ne(nullptr));
+ EXPECT_TRUE(sg->id() == id23);
+
+ // The second insert of the same node gets ignored.
+ sg = set.ExtendParent(id3, map["node2"].get());
+ EXPECT_THAT(set.size(), Eq(2));
+ EXPECT_THAT(sg, Eq(nullptr));
+}
+
+TEST(SubgraphTest, ExtractForSignature) {
+ GraphDef graph;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0");
+ auto node3 = graph.add_node();
+ *node3 = MakeNodeBroadcastGradientArgs("node3", "node1", "node2");
+ node3->add_input("^node1");
+ node3->add_input("^node2");
+ node3->add_input("^node3"); // The control link goes back to self.
+
+ GenNodeMap map;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK()));
+ ASSERT_THAT(map.find("node1"), Ne(map.end()));
+ ASSERT_THAT(map.find("node2"), Ne(map.end()));
+ ASSERT_THAT(map.find("node3"), Ne(map.end()));
+
+ Subgraph::Identity id;
+ id.insert(map["node1"].get());
+ id.insert(map["node3"].get());
+
+ Subgraph sg(id);
+
+ SigNodeMap map2;
+ sg.ExtractForSignature(&map2);
+ ASSERT_THAT(map2.find("node1"), Ne(map2.end()));
+ ASSERT_THAT(map2.find("node2"), Eq(map2.end()));
+ ASSERT_THAT(map2.find("node3"), Ne(map2.end()));
+
+ // clang-format off
+ EXPECT_THAT(DumpLinkHashMap(map2["node1"]->hash_to_link()), ElementsAre(
+ "oC:iC: node3",
+ "o0:i0: node3"
+ ));
+ EXPECT_THAT(DumpHashedPeerVector(map2["node1"]->hashed_peers()), ElementsAre(
+ "node3",
+ "node3"
+ ));
+ EXPECT_THAT(DumpLinkHashMap(map2["node3"]->hash_to_link()), ElementsAre(
+ "oC:iC: node3",
+ "iC:oC: node1, node3",
+ "i0:o0: node1"
+ ));
+ EXPECT_THAT(DumpHashedPeerVector(map2["node3"]->hashed_peers()), ElementsAre(
+ "node3",
+ "node1",
+ "node3",
+ "node1"
+ ));
+ // clang-format on
+}
+
+} // end namespace
+} // end namespace test
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/test_tools.cc b/tensorflow/core/grappler/graph_analyzer/test_tools.cc
new file mode 100644
index 0000000000..fc9495bc7d
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/test_tools.cc
@@ -0,0 +1,296 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/test_tools.h"
+
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
+
+//=== Helper methods to construct the nodes.
+
+NodeDef MakeNodeConst(const string& name) {
+ NodeDef n;
+ n.set_name(name);
+ n.set_op("Const");
+ return n;
+}
+
+NodeDef MakeNode2Arg(const string& name, const string& opcode,
+ const string& arg1, const string& arg2) {
+ NodeDef n;
+ n.set_name(name);
+ n.set_op(opcode);
+ n.add_input(arg1);
+ n.add_input(arg2);
+ return n;
+}
+
+NodeDef MakeNode4Arg(const string& name, const string& opcode,
+ const string& arg1, const string& arg2, const string& arg3,
+ const string& arg4) {
+ NodeDef n;
+ n.set_name(name);
+ n.set_op(opcode);
+ n.add_input(arg1);
+ n.add_input(arg2);
+ n.add_input(arg3);
+ n.add_input(arg4);
+ return n;
+}
+
+// Not really a 2-argument but convenient to construct.
+NodeDef MakeNodeShapeN(const string& name, const string& arg1,
+ const string& arg2) {
+ // This opcode is multi-input but not commutative.
+ return MakeNode2Arg(name, "ShapeN", arg1, arg2);
+}
+
+// Not really a 2-argument but convenient to construct.
+NodeDef MakeNodeIdentityN(const string& name, const string& arg1,
+ const string& arg2) {
+ // The argument is of a list type.
+ return MakeNode2Arg(name, "IdentityN", arg1, arg2);
+}
+
+NodeDef MakeNodeQuantizedConcat(const string& name, const string& arg1,
+ const string& arg2, const string& arg3,
+ const string& arg4) {
+ // This opcode has multiple multi-inputs.
+ return MakeNode4Arg(name, "QuantizedConcat", arg1, arg2, arg3, arg4);
+}
+
+//=== Helper methods for analysing the structures.
+
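+// Dumps the link map of a node into a predictably ordered form, one string
+// per port; e.g. a hypothetical entry "i0: node1[o0]" (the exact format
+// follows the StrFormat()/StrJoin() calls below).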
+std::vector<string> DumpLinkMap(const GenNode::LinkMap& link_map) {
+ // This will order the entries first.
+ std::map<string, string> ordered;
+ for (const auto& link : link_map) {
+ string key = string(link.first);
+
+ // Order the other sides too. They may repeat, so store them
+ // in a multiset.
+ std::multiset<string> others;
+ for (const auto& other : link.second) {
+ others.emplace(
+ absl::StrFormat("%s[%s]", other.node->name(), string(other.port)));
+ }
+ ordered[key] = absl::StrJoin(others, ", ");
+ }
+ // Now dump the result in a predictable order.
+ std::vector<string> result;
+ result.reserve(ordered.size());
+ for (const auto& link : ordered) {
+ result.emplace_back(link.first + ": " + link.second);
+ }
+ return result;
+}
+
+std::vector<string> DumpLinkHashMap(const SigNode::LinkHashMap& link_hash_map) {
+ // The entries in this map are ordered by hash value, which might change
+ // at any point. Re-order them by the link tag instead.
+ std::map<SigNode::LinkTag, size_t> tags;
+ for (const auto& entry : link_hash_map) {
+ tags[entry.second.tag] = entry.first;
+ }
+
+ std::vector<string> result;
+ for (const auto& id : tags) {
+ // For predictability, the nodes need to be sorted.
+ std::vector<string> nodes;
+ for (const auto& peer : link_hash_map.at(id.second).peers) {
+ nodes.emplace_back(peer->name());
+ }
+ std::sort(nodes.begin(), nodes.end());
+ result.emplace_back(string(id.first.local) + ":" + string(id.first.remote) +
+ ": " + absl::StrJoin(nodes, ", "));
+ }
+ return result;
+}
+
+std::vector<string> DumpHashedPeerVector(
+ const SigNode::HashedPeerVector& hashed_peers) {
+ std::vector<string> result;
+
+ // Each subset of nodes with the same hash has to be sorted by name.
+ // Other than that, the vector is already ordered by full tags.
+ size_t last_hash = 0;
+ // Index, since iterators may get invalidated on append.
+ size_t subset_start = 0;
+
+ for (const auto& entry : hashed_peers) {
+ if (entry.link_hash != last_hash) {
+ std::sort(result.begin() + subset_start, result.end());
+ subset_start = result.size();
+ }
+ result.emplace_back(entry.peer->name());
+ }
+ std::sort(result.begin() + subset_start, result.end());
+
+ return result;
+}
+
+TestGraphs::TestGraphs() {
+ {
+ GraphDef& graph = graph_3n_self_control_;
+ // The topology includes a loop and a link to self.
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0");
+ auto node3 = graph.add_node();
+ *node3 = MakeNodeBroadcastGradientArgs("node3", "node1", "node2");
+ node3->add_input("^node3"); // The control link goes back to self.
+ }
+ {
+ GraphDef& graph = graph_multi_input_;
+ // The topology includes nodes with multi-input links.
+ (*graph.add_node()) = MakeNodeConst("const1_1");
+ (*graph.add_node()) = MakeNodeConst("const1_2");
+ (*graph.add_node()) = MakeNodeAddN("add1", "const1_1", "const1_2");
+
+ (*graph.add_node()) = MakeNodeConst("const2_1");
+ (*graph.add_node()) = MakeNodeConst("const2_2");
+ (*graph.add_node()) = MakeNodeConst("const2_3");
+
+ auto add2 = graph.add_node();
+ *add2 = MakeNodeAddN("add2", "const2_1", "const2_2");
+ // The 3rd node is connected twice, for 4 input links in total.
+ add2->add_input("const2_3");
+ add2->add_input("const2_3");
+
+ (*graph.add_node()) = MakeNodeSub("sub", "add1", "add2");
+ }
+ {
+ GraphDef& graph = graph_all_or_none_;
+ // The topology includes all-or-none (IdentityN) nodes.
+ (*graph.add_node()) = MakeNodeConst("const1_1");
+ (*graph.add_node()) = MakeNodeConst("const1_2");
+ auto pass1 = graph.add_node();
+ *pass1 = MakeNodeIdentityN("pass1", "const1_1", "const1_2");
+
+ (*graph.add_node()) = MakeNodeConst("const2_1");
+ (*graph.add_node()) = MakeNodeConst("const2_2");
+ (*graph.add_node()) = MakeNodeConst("const2_3");
+
+ auto pass2 = graph.add_node();
+ *pass2 = MakeNodeIdentityN("pass2", "const2_1", "const2_2");
+ // The 3rd node is connected twice, for 4 input links in total.
+ pass2->add_input("const2_3");
+ pass2->add_input("const2_3");
+
+ // Add the control links; they get handled separately from the normal
+ // links.
+ pass1->add_input("^const2_1");
+ pass1->add_input("^const2_2");
+ pass1->add_input("^const2_3");
+
+ (*graph.add_node()) = MakeNodeSub("sub", "pass1", "pass2");
+ }
+ {
+ GraphDef& graph = graph_circular_onedir_;
+ (*graph.add_node()) = MakeNodeMul("node1", "node5", "node5");
+ (*graph.add_node()) = MakeNodeMul("node2", "node1", "node1");
+ (*graph.add_node()) = MakeNodeMul("node3", "node2", "node2");
+ (*graph.add_node()) = MakeNodeMul("node4", "node3", "node3");
+ (*graph.add_node()) = MakeNodeMul("node5", "node4", "node4");
+ }
+ {
+ GraphDef& graph = graph_circular_bidir_;
+ // The left and right links are intentionally mixed up.
+ (*graph.add_node()) = MakeNodeMul("node1", "node5", "node2");
+ (*graph.add_node()) = MakeNodeMul("node2", "node3", "node1");
+ (*graph.add_node()) = MakeNodeMul("node3", "node2", "node4");
+ (*graph.add_node()) = MakeNodeMul("node4", "node5", "node3");
+ (*graph.add_node()) = MakeNodeMul("node5", "node4", "node1");
+ }
+ {
+ GraphDef& graph = graph_linear_;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeMul("node2", "node1", "node1");
+ (*graph.add_node()) = MakeNodeMul("node3", "node2", "node2");
+ (*graph.add_node()) = MakeNodeMul("node4", "node3", "node3");
+ (*graph.add_node()) = MakeNodeMul("node5", "node4", "node4");
+ }
+ {
+ GraphDef& graph = graph_cross_;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeMul("node2", "node1", "node1");
+ (*graph.add_node()) = MakeNodeConst("node3");
+ (*graph.add_node()) = MakeNodeMul("node4", "node3", "node3");
+ (*graph.add_node()) = MakeNodeConst("node5");
+ (*graph.add_node()) = MakeNodeMul("node6", "node5", "node5");
+ (*graph.add_node()) = MakeNodeConst("node7");
+ (*graph.add_node()) = MakeNodeMul("node8", "node7", "node7");
+
+ auto center = graph.add_node();
+ *center = MakeNodeMul("node9", "node2", "node4");
+ center->add_input("node6");
+ center->add_input("node8");
+ }
+ {
+ GraphDef& graph = graph_small_cross_;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeConst("node2");
+ (*graph.add_node()) = MakeNodeConst("node3");
+ (*graph.add_node()) = MakeNodeConst("node4");
+
+ auto center = graph.add_node();
+ *center = MakeNodeMul("node5", "node1", "node2");
+ center->add_input("node3");
+ center->add_input("node4");
+ }
+ {
+ GraphDef& graph = graph_for_link_order_;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeConst("node2");
+ (*graph.add_node()) = MakeNodeConst("node3");
+ (*graph.add_node()) = MakeNodeConst("node4");
+
+ // One group of equivalent links.
+ auto center = graph.add_node();
+ *center = MakeNodeMul("node5", "node1", "node2");
+ center->add_input("node3");
+ center->add_input("node4");
+
+ // Multiple groups, separated by unique links.
+ auto center2 = graph.add_node();
+ *center2 = MakeNodeMul("node6", "node1", "node2");
+ center2->add_input("node2:1");
+ center2->add_input("node3:2");
+ center2->add_input("node4:2");
+ center2->add_input("node4:3");
+ }
+ {
+ GraphDef& graph = graph_sun_;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeConst("node2");
+ (*graph.add_node()) = MakeNodeConst("node3");
+ (*graph.add_node()) = MakeNodeConst("node4");
+ (*graph.add_node()) = MakeNodeConst("node5");
+ (*graph.add_node()) = MakeNodeSub("node6", "node1", "node10");
+ (*graph.add_node()) = MakeNodeSub("node7", "node2", "node6");
+ (*graph.add_node()) = MakeNodeSub("node8", "node3", "node7");
+ (*graph.add_node()) = MakeNodeSub("node9", "node4", "node8");
+ (*graph.add_node()) = MakeNodeSub("node10", "node5", "node9");
+ }
+}
+
+} // end namespace test
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
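As a quick usage sketch of the helpers defined above (graph shape illustrative; it mirrors how TestGraphs builds its fixtures, with control dependencies written as "^node" inputs):

    #include "tensorflow/core/grappler/graph_analyzer/test_tools.h"

    namespace tensorflow {
    namespace grappler {
    namespace graph_analyzer {
    namespace test {

    GraphDef MakeTinyGraph() {
      GraphDef graph;
      (*graph.add_node()) = MakeNodeConst("c");
      auto* mul = graph.add_node();
      *mul = MakeNodeMul("m", "c", "c");  // two data links from "c"
      mul->add_input("^c");               // plus one control link
      return graph;
    }

    }  // end namespace test
    }  // end namespace graph_analyzer
    }  // end namespace grappler
    }  // end namespace tensorflow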
diff --git a/tensorflow/core/grappler/graph_analyzer/test_tools.h b/tensorflow/core/grappler/graph_analyzer/test_tools.h
new file mode 100644
index 0000000000..98e269d57e
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/test_tools.h
@@ -0,0 +1,120 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_TEST_TOOLS_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_TEST_TOOLS_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
+#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
+#include "tensorflow/core/grappler/op_types.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
+
+//=== Helper methods to construct the nodes.
+
+NodeDef MakeNodeConst(const string& name);
+
+NodeDef MakeNode2Arg(const string& name, const string& opcode,
+ const string& arg1, const string& arg2);
+
+NodeDef MakeNode4Arg(const string& name, const string& opcode,
+ const string& arg1, const string& arg2, const string& arg3,
+ const string& arg4);
+
+inline NodeDef MakeNodeMul(const string& name, const string& arg1,
+ const string& arg2) {
+ return MakeNode2Arg(name, "Mul", arg1, arg2);
+}
+
+// Not really a 2-argument op, but convenient to construct.
+inline NodeDef MakeNodeAddN(const string& name, const string& arg1,
+ const string& arg2) {
+ return MakeNode2Arg(name, "AddN", arg1, arg2);
+}
+
+inline NodeDef MakeNodeSub(const string& name, const string& arg1,
+ const string& arg2) {
+ return MakeNode2Arg(name, "Sub", arg1, arg2);
+}
+
+// Has 2 honest outputs.
+inline NodeDef MakeNodeBroadcastGradientArgs(const string& name,
+ const string& arg1,
+ const string& arg2) {
+ return MakeNode2Arg(name, "BroadcastGradientArgs", arg1, arg2);
+}
+
+NodeDef MakeNodeShapeN(const string& name, const string& arg1,
+ const string& arg2);
+
+NodeDef MakeNodeIdentityN(const string& name, const string& arg1,
+ const string& arg2);
+
+NodeDef MakeNodeQuantizedConcat(const string& name, const string& arg1,
+ const string& arg2, const string& arg3,
+ const string& arg4);
+
+//=== A container of pre-constructed graphs.
+
+class TestGraphs {
+ public:
+ TestGraphs();
+
+ // Graph with 3 nodes and a control link to self (which is not valid in
+ // reality but adds excitement to the tests).
+ GraphDef graph_3n_self_control_;
+ // Graph that has multi-input links.
+ GraphDef graph_multi_input_;
+ // Graph that has all-or-none (IdentityN) nodes.
+ GraphDef graph_all_or_none_;
+ // All the nodes are connected in a circle that goes in one direction.
+ GraphDef graph_circular_onedir_;
+ // All the nodes are connected in a circle that goes in both directions.
+ GraphDef graph_circular_bidir_;
+ // The nodes are connected in a line.
+ GraphDef graph_linear_;
+ // The nodes are connected in a cross shape.
+ GraphDef graph_cross_;
+ GraphDef graph_small_cross_;
+ // For testing the ordering of links at the end of signature generation,
+ // a variation of a cross.
+ GraphDef graph_for_link_order_;
+ // Sun-shaped, a ring with "rays".
+ GraphDef graph_sun_;
+};
+
+//=== Helper methods for analyzing the structures.
+
+std::vector<string> DumpLinkMap(const GenNode::LinkMap& link_map);
+
+// Also checks for the consistency of hash values.
+std::vector<string> DumpLinkHashMap(const SigNode::LinkHashMap& link_hash_map);
+
+std::vector<string> DumpHashedPeerVector(
+ const SigNode::HashedPeerVector& hashed_peers);
+
+} // end namespace test
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_TEST_TOOLS_H_
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index 288587ce9b..029515ad3c 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variable.pb.h"
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 653b088b1d..e78239bd43 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -135,16 +135,37 @@ bool IsDequeueOp(const NodeDef& node) {
bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }
-bool IsElementWiseMonotonic(const NodeDef& node) {
- static const std::unordered_set<string>* element_wise_monotonic_ops =
+// Returns true if node represents a unary elementwise function that is
+// monotonic. If it returns true, *is_non_decreasing is set to true when the
+// function is non-decreasing, e.g. sqrt, exp, and to false when the function
+// is non-increasing, e.g. inv.
+bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) {
+ static const std::unordered_set<string>* monotonic_non_decreasing_ops =
CHECK_NOTNULL((new std::unordered_set<string>{
- "Relu",
- "Relu6",
- "Sigmoid",
- "Sqrt",
- "Tanh",
+ "Asinh", "Atanh", "Ceil", "Elu", "Erf", "Exp", "Expm1",
+ "Floor", "Log", "Log1p", "Relu", "Relu", "Relu6", "Rint",
+ "Selu", "Sigmoid", "Sign", "Sinh", "Sqrt", "Tanh",
+ }));
+ static const std::unordered_set<string>* monotonic_non_increasing_ops =
+ CHECK_NOTNULL((new std::unordered_set<string>{
+ "Inv",
+ "Reciprocal",
+ "Erfc",
+ "Rsqrt",
+ "Neg",
}));
- return element_wise_monotonic_ops->count(node.op()) > 0;
+ if (monotonic_non_decreasing_ops->count(node.op()) > 0) {
+ if (is_non_decreasing) {
+ *is_non_decreasing = true;
+ }
+ return true;
+ } else if (monotonic_non_increasing_ops->count(node.op()) > 0) {
+ if (is_non_decreasing) {
+ *is_non_decreasing = false;
+ }
+ return true;
+ }
+ return false;
}
bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; }
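Caller-side sketch of the new two-argument IsElementWiseMonotonic (the helper name RequiresMinMaxFlip is hypothetical; note the implementation above tolerates a null out-parameter for callers that only need the boolean result):

    #include "tensorflow/core/framework/node_def.pb.h"
    #include "tensorflow/core/grappler/op_types.h"

    namespace tensorflow {
    namespace grappler {

    // Hypothetical helper: true when hoisting `node` past a Min/Max
    // reduction requires flipping the reduction, i.e. the op is
    // elementwise monotonic but non-increasing (Neg, Rsqrt, ...).
    bool RequiresMinMaxFlip(const NodeDef& node) {
      bool is_non_decreasing = false;
      return IsElementWiseMonotonic(node, &is_non_decreasing) &&
             !is_non_decreasing;
    }

    }  // namespace grappler
    }  // namespace tensorflow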
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 94439265c9..25ab6b65ac 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -55,7 +55,7 @@ bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node);
bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node);
bool IsDequeueOp(const NodeDef& node);
bool IsDiv(const NodeDef& node);
-bool IsElementWiseMonotonic(const NodeDef& node);
+bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing);
bool IsEluGrad(const NodeDef& node);
bool IsEnter(const NodeDef& node);
bool IsEqual(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index caaa5ac8db..f094c151e6 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -116,6 +116,7 @@ tf_cc_test(
shard_count = 5,
deps = [
":constant_folding",
+ ":dependency_optimizer",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/core:all_kernels",
@@ -827,11 +828,6 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/core/grappler:grappler_item",
- "//tensorflow/core/grappler:op_types",
- "//tensorflow/core/grappler:utils",
- "//tensorflow/core/grappler/clusters:cluster",
- "//tensorflow/core/grappler/costs:graph_properties",
],
)
@@ -850,3 +846,68 @@ tf_cc_test(
"//third_party/eigen3",
],
)
+
+cc_library(
+ name = "function_api_info",
+ srcs = ["function_api_info.cc"],
+ hdrs = ["function_api_info.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+tf_cc_test(
+ name = "function_api_info_test",
+ size = "small",
+ srcs = ["function_api_info_test.cc"],
+ deps = [
+ ":function_api_info",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "experimental_implementation_selector",
+ srcs = ["experimental_implementation_selector.cc"],
+ hdrs = ["experimental_implementation_selector.h"],
+ deps = [
+ ":custom_graph_optimizer",
+ ":custom_graph_optimizer_registry",
+ ":function_api_info",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/costs:graph_properties",
+ ],
+)
+
+tf_cc_test(
+ name = "experimental_implementation_selector_test",
+ size = "small",
+ srcs = ["experimental_implementation_selector_test.cc"],
+ deps = [
+ ":custom_graph_optimizer",
+ ":custom_graph_optimizer_registry",
+ ":experimental_implementation_selector",
+ ":function_api_info",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+ "//tensorflow/core/grappler/utils:grappler_test",
+ ],
+)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 889445bbd6..11ce121cba 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
@@ -1120,11 +1121,8 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
NodeDef* tail = node;
- // TODO(rmlarsen): Enable after debugging breakage in Bayesflow.
- if (ctx().opt_level == RewriterConfig::AGGRESSIVE) {
- tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
- *ctx().nodes_to_preserve);
- }
+ tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
+ *ctx().nodes_to_preserve);
NodeDef* first_transpose;
TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));
@@ -2702,22 +2700,37 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
NodeDef* inner_function;
TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function));
// Optimize only if:
+ // 0. inner_function is not in the preserve set,
// 1. inner_function's Op is element-wise monotonic
// 2. inner_function's output is not being consumed elsewhere.
- if (IsElementWiseMonotonic(*inner_function) &&
- (NumNonControlOutputs(*inner_function, *ctx().node_map) == 1)) {
+ bool is_non_decreasing = false;
+ if (!IsInPreserveSet(*inner_function) &&
+ IsElementWiseMonotonic(*inner_function, &is_non_decreasing) &&
+ ctx().node_map->GetOutputs(inner_function->name()).size() == 1) {
// Swap the first inputs of the inner function Op & the reduction Op.
NodeDef* inner_input;
TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input));
- inner_function->set_input(0, reduction_node->name());
- UpdateConsumersAvoidingLoop(inner_function, reduction_node->name());
reduction_node->set_input(0, inner_input->name());
- UpdateConsumersAvoidingLoop(reduction_node, inner_function->name());
+ ctx().node_map->UpdateInput(reduction_node->name(),
+ inner_function->name(), inner_input->name());
+ inner_function->set_input(0, reduction_node->name());
+ UpdateConsumers(reduction_node, inner_function->name());
+ ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(),
+ reduction_node->name());
+ if (!is_non_decreasing) {
+ // Flip Min<->Max if the function is non-increasing, e.g.
+ // Max(Neg(x)) = Neg(Min(x)).
+ const string opposite = IsMax(*reduction_node) ? "Min" : "Max";
+ reduction_node->set_op(opposite);
+ }
+ AddToOptimizationQueue(reduction_node);
+ AddToOptimizationQueue(inner_function);
+ AddToOptimizationQueue(inner_input);
}
return Status::OK();
}
- void UpdateConsumersAvoidingLoop(NodeDef* node, const string& new_input) {
+ void UpdateConsumers(NodeDef* node, const string& new_input) {
const string& node_name = node->name();
const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
for (NodeDef* consumer : consumers) {
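The Min<->Max flip above rests on the identity max(f(x)) = f(min(x)) for a non-increasing f; a self-contained numeric check (plain C++, independent of the optimizer):

    #include <algorithm>
    #include <cassert>
    #include <vector>

    int main() {
      std::vector<float> x = {1.0f, -3.0f, 2.5f};
      std::vector<float> neg;
      for (float v : x) neg.push_back(-v);  // f = Neg, non-increasing
      float max_of_neg = *std::max_element(neg.begin(), neg.end());
      float neg_of_min = -*std::min_element(x.begin(), x.end());
      assert(max_of_neg == neg_of_min);  // Max(Neg(x)) == Neg(Min(x))
      return 0;
    }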
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 551c3652bf..d457eb6d21 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -61,7 +61,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool fold_multiply_into_conv = true;
bool fold_transpose_into_matmul = true;
bool hoist_common_factor_out_of_aggregation = true;
- bool hoist_cwise_unary_chains = false;
+ bool hoist_cwise_unary_chains = true;
bool minimize_broadcasts = true;
bool optimize_max_or_min_of_monotonic = true;
bool remove_idempotent = true;
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 685b5379af..39517edc06 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -3224,6 +3224,72 @@ TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) {
EXPECT_EQ(2, required_node_count);
}
+TEST_F(ArithmeticOptimizerTest,
+ OptimizeMaxOrMinOfMonotonicElementWise_DoNotChangeFetchNode) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
+ Output reduce_max = ops::Max(s.WithOpName("reduce_max"), sqrt, {0});
+ Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);
+
+ GrapplerItem item;
+ item.fetch = {"sqrt", "final_out"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(2, tensors_expected.size());
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
+ OptimizeTwice(&optimizer, &item, &output);
+
+ // Should be a no-op since we are not allowed to change the output of fetch
+ // nodes.
+ VerifyGraphsMatch(item.graph, output, __LINE__);
+}
+
+TEST_F(ArithmeticOptimizerTest,
+ OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasing) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ Output neg = ops::Neg(s.WithOpName("neg"), x);
+ Output reduce_max = ops::Max(s.WithOpName("reduce_max"), neg, {0});
+ Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);
+
+ GrapplerItem item;
+ item.fetch = {"final_out"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+ EXPECT_EQ(item.graph.node_size(), output.node_size());
+ // Check that the inputs are swapped.
+ int required_node_count = 0;
+ for (int i = 0; i < output.node_size(); ++i) {
+ const NodeDef& node = output.node(i);
+ if (node.name() == "neg") {
+ EXPECT_EQ("Neg", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("reduce_max", node.input(0));
+ ++required_node_count;
+ } else if (node.name() == "reduce_max") {
+ EXPECT_EQ("Min", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ ++required_node_count;
+ }
+ }
+ EXPECT_EQ(2, required_node_count);
+}
+
TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index f2ac3a44c0..99737a71eb 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -136,6 +136,27 @@ bool MaybeRemoveControlInput(const string& old_input, NodeDef* node,
return removed_input;
}
+bool GetConcatAxis(const GraphProperties& properties, NodeDef* node,
+ int* axis) {
+ if (node->op() != "ConcatV2" ||
+ properties.GetInputProperties(node->name()).empty()) {
+ return false;
+ }
+ const auto& axis_input = properties.GetInputProperties(node->name()).back();
+ if (!TensorShape::IsValid(axis_input.shape()) || !axis_input.has_value()) {
+ return false;
+ }
+
+ Tensor axis_tensor(axis_input.dtype(), axis_input.shape());
+ if (!axis_tensor.FromProto(axis_input.value())) {
+ return false;
+ }
+ *axis = axis_input.dtype() == DT_INT64
+ ? static_cast<int>(axis_tensor.scalar<int64>()())
+ : axis_tensor.scalar<int32>()();
+ return true;
+}
+
} // namespace
ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level,
@@ -1719,6 +1740,11 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
return Status::OK();
}
+ if (MergeConcat(*properties, use_shape_info, optimized_graph, node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
return Status::OK();
}
@@ -2910,6 +2936,55 @@ bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
return false;
}
+bool ConstantFolding::MergeConcat(const GraphProperties& properties,
+ bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node) {
+ // We only optimize for ConcatV2.
+ int axis;
+ if (!use_shape_info || !GetConcatAxis(properties, node, &axis) ||
+ nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end() ||
+ node_map_->GetOutputs(node->name()).size() != 1) {
+ return false;
+ }
+
+ NodeDef* parent = *node_map_->GetOutputs(node->name()).begin();
+ int parent_axis;
+ if (!GetConcatAxis(properties, parent, &parent_axis) || axis != parent_axis) {
+ return false;
+ }
+
+ const int index = NumNonControlInputs(*node) - 1;
+ auto inputs = parent->input();
+ parent->clear_input();
+ for (int i = 0; i < inputs.size(); ++i) {
+ if (IsSameInput(inputs.Get(i), node->name())) {
+ for (int j = 0; j < node->input_size(); ++j) {
+ if (j < index) {
+ // Data inputs (non-axis): add them to the parent's input list.
+ parent->add_input(node->input(j));
+ node_map_->RemoveOutput(node->input(j), node->name());
+ node_map_->AddOutput(node->input(j), parent->name());
+ }
+ // Skip j == index, which is the axis tensor.
+ if (j > index) {
+ // Control dependencies: push them back onto `inputs` so they get
+ // forwarded to the parent.
+ *inputs.Add() = node->input(j);
+ }
+ }
+ } else {
+ parent->add_input(inputs.Get(i));
+ }
+ }
+ node->clear_input();
+ node->set_op("NoOp");
+ node->clear_attr();
+ node_map_->RemoveNode(node->name());
+ (*parent->mutable_attr())["N"].set_i(NumNonControlInputs(*parent) - 1);
+
+ return true;
+}
+
Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
const GrapplerItem& item,
GraphDef* optimized_graph) {
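The input splice MergeConcat performs can be sketched in isolation (names illustrative; this ignores the control inputs and repeated-consumer cases the real pass also handles): the child ConcatV2's data inputs replace its name in the parent's input list, and its trailing axis input is dropped.

    #include <string>
    #include <vector>

    std::vector<std::string> SpliceConcatInputs(
        const std::vector<std::string>& parent_inputs,
        const std::vector<std::string>& child_inputs,
        const std::string& child_name) {
      std::vector<std::string> result;
      for (const auto& in : parent_inputs) {
        if (in == child_name) {
          // Copy the child's data inputs, skipping its trailing axis input.
          result.insert(result.end(), child_inputs.begin(),
                        child_inputs.end() - 1);
        } else {
          result.push_back(in);
        }
      }
      return result;
    }

    // SpliceConcatInputs({"c1", "in3", "axis"}, {"in1", "in2", "axis"}, "c1")
    // -> {"in1", "in2", "in3", "axis"}, matching the MergeConcat test below.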
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index b42d5f201e..8593b3e0b8 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -209,6 +209,10 @@ class ConstantFolding : public GraphOptimizer {
// Removes Split or SplitV node if possible.
bool RemoveSplitOrSplitV(const GraphProperties& properties,
GraphDef* optimized_graph, NodeDef* node);
+
+ bool MergeConcat(const GraphProperties& properties, bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node);
+
// Points to an externally provided device or to owned_device_;
RewriterConfig::Toggle opt_level_;
DeviceBase* cpu_device_;
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index b9765b9292..2a19b3f95a 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -2030,6 +2030,130 @@ TEST_F(ConstantFoldingTest, TileWithMultipliesBeingOne) {
CompareGraphs(want, got);
}
+TEST_F(ConstantFoldingTest, MergeConcat) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
+ Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
+ Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
+ Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
+
+ ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
+ ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis);
+
+ GrapplerItem item;
+ item.fetch = {"c2"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = optimizer.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, {}, &want);
+ AddNode("in2", "VariableV2", {}, {}, &want);
+ AddNode("in3", "VariableV2", {}, {}, &want);
+ AddNode("axis", "Const", {}, {}, &want);
+ AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "axis"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
+TEST_F(ConstantFoldingTest, MergeConcat_SameInput) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
+ Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
+ Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
+ Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
+
+ ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
+ ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3, Output(c1)}, axis);
+
+ GrapplerItem item;
+ item.fetch = {"c2"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = optimizer.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, {}, &want);
+ AddNode("in2", "VariableV2", {}, {}, &want);
+ AddNode("in3", "VariableV2", {}, {}, &want);
+ AddNode("axis", "Const", {}, {}, &want);
+ AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "in1", "in2", "axis"}, {},
+ &want);
+
+ CompareGraphs(want, got);
+}
+
+TEST_F(ConstantFoldingTest, MergeConcat_ConcatWithConst) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 = ops::Variable(scope.WithOpName("in1"), {2, 6}, DT_FLOAT);
+ Output in2 = ops::Variable(scope.WithOpName("in2"), {}, DT_FLOAT);
+ Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
+ Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
+
+ ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
+ ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis);
+
+ GrapplerItem item;
+ item.fetch = {"c2"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = optimizer.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, {}, &want);
+ AddNode("in2", "VariableV2", {}, {}, &want);
+ AddNode("in3", "VariableV2", {}, {}, &want);
+ AddNode("axis", "Const", {}, {}, &want);
+ AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "axis"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
+TEST_F(ConstantFoldingTest, MergeConcat_AxisMismatch) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 = ops::Variable(scope.WithOpName("in1"), {2, 5}, DT_FLOAT);
+ Output in2 = ops::Variable(scope.WithOpName("in2"), {}, DT_FLOAT);
+ Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
+ Output axis1 = ops::Const(scope.WithOpName("axis1"), 0, {});
+ Output axis2 = ops::Const(scope.WithOpName("axis2"), 1, {});
+
+ ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis2);
+ ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis1);
+
+ GrapplerItem item;
+ item.fetch = {"c2"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = optimizer.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, {}, &want);
+ AddNode("in2", "VariableV2", {}, {}, &want);
+ AddNode("in3", "VariableV2", {}, {}, &want);
+ AddNode("axis1", "Const", {}, {}, &want);
+ AddNode("axis2", "Const", {}, {}, &want);
+ AddNode("c1", "ConcatV2", {"in1", "in2", "axis2"}, {}, &want);
+ AddNode("c2", "ConcatV2", {"c1", "in3", "axis1"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
TEST_F(ConstantFoldingTest, PaddingWithZeroSize) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
@@ -3047,6 +3171,39 @@ TEST_F(ConstantFoldingTest, TensorArraySize) {
test::ExpectTensorEqual<int32>(tensors_expected[1], tensors_actual[1]);
}
+TEST_F(ConstantFoldingTest, FoldingPreservesDenormalFlushing) {
+ // Multiplying min() with 0.1 gives a denormal without FTZ and zero with FTZ.
+ // Make sure constant folding behaves the same way as TensorFlow.
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ Output a =
+ ops::Const(s.WithOpName("a"), std::numeric_limits<float>::min(), {1});
+ Output b = ops::Const(s.WithOpName("b"), 0.1f, {1});
+ Output c = ops::Mul(s.WithOpName("c"), a, b);
+
+ GrapplerItem item;
+ item.fetch.push_back("c");
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ EXPECT_EQ(1, output.node_size());
+
+ const NodeDef& node_d = output.node(0);
+ EXPECT_EQ("c", node_d.name());
+ EXPECT_EQ("Const", node_d.op());
+
+ std::vector<string> fetch = {"c"};
+ auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ auto tensors = EvaluateNodes(output, fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
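For context on FoldingPreservesDenormalFlushing above: FLT_MIN * 0.1f is a subnormal number, and flush-to-zero (FTZ) hardware modes round it to zero, so the folded constant has to match whatever the TensorFlow kernels compute at runtime. A standalone illustration (plain C++, no TensorFlow):

    #include <cstdio>
    #include <limits>

    int main() {
      const float denormal = std::numeric_limits<float>::min() * 0.1f;
      // Prints roughly 1.17550e-39 without FTZ; prints 0 when FTZ/DAZ is
      // enabled (e.g. via -ffast-math on x86).
      std::printf("%g\n", denormal);
      return 0;
    }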
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index b8e69787e3..530c957068 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -4,36 +4,41 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
cc_library(
- name = "function_rename",
- srcs = ["function_rename.cc"],
+ name = "filter_fusion",
+ srcs = ["filter_fusion.cc"],
hdrs = [
- "function_rename.h",
+ "filter_fusion.h",
],
visibility = ["//visibility:public"],
deps = [
":graph_utils",
+ ":fusion_utils",
+ "//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core:lib",
- "//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/kernels:cast_op",
+ "//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
] + tf_protos_all(),
)
tf_cc_test(
- name = "function_rename_test",
- srcs = ["function_rename_test.cc"],
+ name = "filter_fusion_test",
+ srcs = ["filter_fusion_test.cc"],
visibility = ["//visibility:public"],
deps = [
- ":function_rename",
+ ":filter_fusion",
+ ":graph_utils",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
"//tensorflow/core/grappler:grappler_item",
- ] + tf_protos_all(),
+ ],
)
cc_library(
@@ -46,11 +51,13 @@ cc_library(
deps = [
":graph_utils",
"//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/kernels:cast_op",
+ "//tensorflow/core/kernels:functional_ops",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"//tensorflow/core:lib_internal",
] + tf_protos_all(),
@@ -125,6 +132,43 @@ cc_library(
)
cc_library(
+ name = "map_vectorization",
+ srcs = ["map_vectorization.cc"],
+ hdrs = [
+ "map_vectorization.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core:lib_internal",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "map_vectorization_test",
+ srcs = ["map_vectorization_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ ":map_vectorization",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/kernels:cast_op", # Must be linked for the testlib functions to work.
+ ],
+)
+
+cc_library(
name = "map_and_batch_fusion",
srcs = ["map_and_batch_fusion.cc"],
hdrs = [
@@ -306,11 +350,12 @@ cc_library(
name = "data",
visibility = ["//visibility:public"],
deps = [
- ":function_rename",
+ ":filter_fusion",
":latency_all_edges",
":map_and_batch_fusion",
":map_and_filter_fusion",
":map_fusion",
+ ":map_vectorization",
":noop_elimination",
":shuffle_and_repeat_fusion",
],
diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc
new file mode 100644
index 0000000000..c71aa6e804
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc
@@ -0,0 +1,141 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/filter_fusion.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/fusion_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+NodeDef MakeFusedFilterNode(const NodeDef& first_filter_node,
+ const NodeDef& second_filter_node,
+ const FunctionDef& fused_function,
+ MutableGraphView* graph) {
+ NodeDef fused_node;
+ graph_utils::SetUniqueGraphNodeName("fused_filter", graph->GetGraph(),
+ &fused_node);
+
+ fused_node.set_op("FilterDataset");
+ fused_node.add_input(first_filter_node.input(0));
+
+ auto copy_attribute = [](const string& attribute_name, const NodeDef& from,
+ NodeDef* to) {
+ (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
+ };
+
+ auto attr = first_filter_node.attr().at("predicate");
+ *attr.mutable_func()->mutable_name() = fused_function.signature().name();
+ (*fused_node.mutable_attr())["predicate"] = std::move(attr);
+
+ copy_attribute("Targuments", first_filter_node, &fused_node);
+
+ for (auto key : {"output_shapes", "output_types"})
+ copy_attribute(key, second_filter_node, &fused_node);
+
+ return fused_node;
+}
+
+} // namespace
+
+Status FilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ GraphDef sorted_old_graph = item.graph;
+ TF_RETURN_IF_ERROR(TopologicalSort(&sorted_old_graph));
+ *output = sorted_old_graph;
+
+ MutableGraphView graph(output);
+ std::set<string> nodes_to_delete;
+ FunctionLibraryDefinition function_library(OpRegistry::Global(),
+ output->library());
+
+ auto get_filter_node = [](const NodeDef& node) -> const NodeDef* {
+ if (node.op() == "FilterDataset") return &node;
+ return nullptr;
+ };
+
+ auto get_fused_predicate =
+ [&](const NodeDef* first_filter_node,
+ const NodeDef* second_filter_node) -> FunctionDef* {
+ const auto& parent_fun = first_filter_node->attr().at("predicate");
+ const FunctionDef* first_func =
+ function_library.Find(parent_fun.func().name());
+ const auto& fun = second_filter_node->attr().at("predicate");
+ const FunctionDef* second_func = function_library.Find(fun.func().name());
+
+ if (!fusion_utils::HasSameSignature(first_func->signature(),
+ second_func->signature())) {
+ VLOG(1) << "Can't fuse Filters because they have different signature\n";
+ return nullptr;
+ }
+
+ return fusion_utils::FuseFunctions(
+ *first_func, *second_func, "fused_predicate",
+ fusion_utils::SameSignature, fusion_utils::SameInput,
+ fusion_utils::LazyConjunctionOutput, fusion_utils::LazyConjunctionNodes,
+ output->mutable_library());
+ };
+
+ for (const NodeDef& node : sorted_old_graph.node()) {
+ const NodeDef* second_filter_node = get_filter_node(node);
+ if (!second_filter_node) continue;
+
+ const NodeDef* first_filter_node =
+ get_filter_node(*graph_utils::GetInputNode(*second_filter_node, graph));
+ if (!first_filter_node) continue;
+
+ const auto* fused_predicate =
+ get_fused_predicate(first_filter_node, second_filter_node);
+ if (!fused_predicate) continue;
+ const auto* fused_filter_node = graph.AddNode(MakeFusedFilterNode(
+ *first_filter_node, *second_filter_node, *fused_predicate, &graph));
+
+ graph.ReplaceInput(*second_filter_node, *fused_filter_node);
+
+ // TODO(prazek): we should run some optimizations on the fused filter
+ // functions, or make sure that optimization passes run after filter
+ // fusion.
+ TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_predicate));
+ // TODO(prazek): we could also remove the original predicate functions from
+ // the library if they are not used anymore.
+ nodes_to_delete.insert(first_filter_node->name());
+ nodes_to_delete.insert(second_filter_node->name());
+ }
+
+ graph.DeleteNodes(nodes_to_delete);
+ return Status::OK();
+}
+
+void FilterFusion::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) {
+ // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(FilterFusion, "filter_fusion");
+
+} // end namespace grappler
+} // end namespace tensorflow
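The fused predicate built by get_fused_predicate is a lazy conjunction: the second filter's predicate runs only on elements the first one accepts. Conceptually, in plain C++ (the real artifact is a FunctionDef emitted by fusion_utils::FuseFunctions, not this function):

    #include <functional>

    // && short-circuits, so `second` is evaluated lazily, mirroring
    // fusion_utils::LazyConjunctionOutput/LazyConjunctionNodes.
    bool FusedPredicate(const std::function<bool()>& first,
                        const std::function<bool()>& second) {
      return first() && second();
    }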
diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion.h b/tensorflow/core/grappler/optimizers/data/filter_fusion.h
new file mode 100644
index 0000000000..91a0364a46
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/filter_fusion.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FILTER_FUSION_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FILTER_FUSION_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// This optimization fuses filter transformations.
+class FilterFusion : public CustomGraphOptimizer {
+ public:
+ FilterFusion() = default;
+ ~FilterFusion() override = default;
+
+ string name() const override { return "filter_fusion"; }
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FILTER_FUSION_H_
diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc
new file mode 100644
index 0000000000..12b1924efd
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc
@@ -0,0 +1,91 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/filter_fusion.h"
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) {
+ return test::function::NDef(
+ name, "FilterDataset", {string(input_node_name)},
+ {{"predicate", FunctionDefHelper::FunctionRef("IsZero")},
+ {"Targuments", {}},
+ {"output_shapes", {}},
+ {"output_types", {}}});
+}
+
+TEST(FilterFusionTest, FuseTwoFilterIntoOne) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+ MakeFilterNode("filter1", "range"),
+ MakeFilterNode("filter2", "filter1")},
+ // FunctionLib
+ {
+ test::function::IsZero(),
+ });
+
+ FilterFusion optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp("FilterDataset", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter1", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter2", output));
+}
+
+TEST(FilterFusionTest, FuseThreeNodesIntoOne) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+ MakeFilterNode("filter1", "range"), MakeFilterNode("filter2", "filter1"),
+ MakeFilterNode("filter3", "filter2"),
+ NDef("cache", "CacheDataset", {"filter3", "filename"}, {})},
+ // FunctionLib
+ {
+ test::function::IsZero(),
+ });
+
+ FilterFusion optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp("FilterDataset", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter1", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter2", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter3", output));
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/function_rename.cc b/tensorflow/core/grappler/optimizers/data/function_rename.cc
deleted file mode 100644
index 8cf044d1bd..0000000000
--- a/tensorflow/core/grappler/optimizers/data/function_rename.cc
+++ /dev/null
@@ -1,51 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/grappler/optimizers/data/function_rename.h"
-
-#include "tensorflow/core/grappler/clusters/cluster.h"
-#include "tensorflow/core/grappler/graph_view.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/op_types.h"
-#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
-#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-#include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/platform/protobuf.h"
-
-namespace tensorflow {
-namespace grappler {
-
-Status FunctionRename::Optimize(Cluster* cluster, const GrapplerItem& item,
- GraphDef* output) {
- *output = item.graph;
- GraphView graph(output);
- int n = output->mutable_library()->function_size();
- for (int i = 0; i < n; ++i) {
- FunctionDef* fn = output->mutable_library()->mutable_function(i);
- fn->mutable_signature()->set_name(fn->signature().name() + "world");
- }
-
- return Status::OK();
-}
-
-void FunctionRename::Feedback(Cluster* cluster, const GrapplerItem& item,
- const GraphDef& optimize_output, double result) {
- // no-op
-}
-
-REGISTER_GRAPH_OPTIMIZER_AS(FunctionRename, "_test_only_function_rename");
-
-} // end namespace grappler
-} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/function_rename_test.cc b/tensorflow/core/grappler/optimizers/data/function_rename_test.cc
deleted file mode 100644
index 56b8a960a7..0000000000
--- a/tensorflow/core/grappler/optimizers/data/function_rename_test.cc
+++ /dev/null
@@ -1,42 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/grappler/optimizers/data/function_rename.h"
-
-#include "tensorflow/core/framework/function.pb.h"
-#include "tensorflow/core/framework/op_def.pb.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace grappler {
-namespace {
-
-TEST(FunctionRenameTest, RenameFunction) {
- GrapplerItem item;
- GraphDef *graph = &item.graph;
- FunctionDef *fn = graph->mutable_library()->add_function();
- fn->mutable_signature()->set_name("hello");
-
- FunctionRename optimizer;
- GraphDef output;
- TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
- EXPECT_EQ(output.library().function(0).signature().name(), "helloworld");
-}
-
-} // namespace
-} // namespace grappler
-} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
index f84f109af6..01a78c04b0 100644
--- a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/fusion_utils.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_def.pb.h"
-
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
@@ -52,6 +52,12 @@ string GetOutputNode(const FunctionDef& function, int output_idx) {
return function.ret().at(ret_output_name);
}
+string& GetMutableOutputNode(FunctionDef* function, int output_idx) {
+ const auto& ret_output_name =
+ function->signature().output_arg(output_idx).name();
+ return function->mutable_ret()->at(ret_output_name);
+}
+
template <typename Iterable>
StringCollection GetNames(const Iterable& iterable, int allocate_size) {
StringCollection names;
@@ -106,7 +112,6 @@ gtl::FlatMap<string, string> GetUniqueNames(const Iterable& first_iterable,
// Nodes that will be added to the function can have the same name as the nodes
// from parent function.
void RenameFunctionNodes(const FunctionDef& first_function,
- FunctionDef* fused_function,
protobuf::RepeatedPtrField<NodeDef>* nodes_to_fuse,
protobuf::Map<string, string>* rets_to_fuse) {
const gtl::FlatMap<string, string> changed_node_names =
@@ -149,6 +154,7 @@ OpDef GetUniqueSignature(const OpDef& first_signature,
const gtl::FlatMap<string, string> changed_input_names =
GetUniqueNames(first_signature.input_arg(), second_signature.input_arg());
OpDef signature;
+ signature.set_name(second_signature.name());
for (const auto& input_arg : second_signature.input_arg()) {
auto& input = *signature.add_input_arg();
@@ -221,12 +227,13 @@ void FuseFunctionNodes(const StringCollection& first_inputs,
}
// This function looks for direct edges from input to return and rewrites
-// them to the coresponding input of the return of `first_function`.
+// them to the corresponding input of the return of `first_function`.
void FuseReturns(const StringCollection& first_inputs,
const StringCollection& second_inputs,
const StringCollection& first_outputs,
- const SetInputFn& set_input, FunctionDef* fused_function) {
- for (auto& ret : *fused_function->mutable_ret()) {
+ const SetInputFn& set_input,
+ protobuf::Map<string, string>* fused_ret) {
+ for (auto& ret : *fused_ret) {
auto return_input = ParseNodeConnection(ret.second);
auto input_it =
std::find(second_inputs.begin(), second_inputs.end(), return_input);
@@ -249,6 +256,33 @@ StringCollection GetFunctionOutputs(const FunctionDef& function) {
return outputs;
}
+FunctionDef* CreateFalsePredicate(
+ const protobuf::RepeatedPtrField<OpDef_ArgDef>& fake_args,
+ FunctionDefLibrary* library) {
+ GraphDef graph;
+ MutableGraphView graph_view(&graph);
+ auto* node = graph_utils::AddScalarConstNode(false, &graph_view);
+ auto* false_predicate = library->add_function();
+ graph_utils::SetUniqueGraphFunctionName("false_predicate", library,
+ false_predicate);
+
+ int num = 0;
+ for (const auto& fake_arg : fake_args) {
+ auto* arg = false_predicate->mutable_signature()->add_input_arg();
+ arg->set_type(fake_arg.type());
+ arg->set_name(strings::StrCat("fake_arg", num));
+ num++;
+ }
+
+ auto* output = false_predicate->mutable_signature()->add_output_arg();
+ output->set_name("false_out");
+ output->set_type(DT_BOOL);
+
+ (*false_predicate->mutable_ret())["false_out"] = node->name() + ":output:0";
+ *false_predicate->mutable_node_def() = std::move(*graph.mutable_node());
+ return false_predicate;
+}
+
void CheckIfCanCompose(const OpDef& first_signature,
const OpDef& second_signature) {
CHECK(CanCompose(first_signature, second_signature))
@@ -259,6 +293,15 @@ void CheckIfCanCompose(const OpDef& first_signature,
} // namespace
+void MergeNodes(const FunctionDef& first_function,
+ const FunctionDef& second_function, FunctionDef* fused_function,
+ FunctionDefLibrary* library) {
+ // Copy all nodes from first_function.
+ fused_function->mutable_node_def()->CopyFrom(first_function.node_def());
+ // Copy transformed nodes from the second function.
+ fused_function->mutable_node_def()->MergeFrom(second_function.node_def());
+}
+
bool CanCompose(const OpDef& first_signature, const OpDef& second_signature) {
// TODO(prazek): Functions can have additional inputs being placeholders
 // for values used in the function. We should be able to also fuse these
@@ -285,8 +328,8 @@ void ComposeSignature(const OpDef& first_signature,
void ComposeOutput(const protobuf::Map<string, string>& first_ret,
const protobuf::Map<string, string>& second_ret,
- FunctionDef* fused_function) {
- *fused_function->mutable_ret() = second_ret;
+ protobuf::Map<string, string>* fused_ret) {
+ *fused_ret = second_ret;
}
void CombineSignature(const OpDef& first_signature,
@@ -302,41 +345,110 @@ void CombineSignature(const OpDef& first_signature,
void CombineOutput(const protobuf::Map<string, string>& first_ret,
const protobuf::Map<string, string>& second_ret,
- FunctionDef* fused_function) {
- *fused_function->mutable_ret() = first_ret;
- fused_function->mutable_ret()->insert(second_ret.begin(), second_ret.end());
+ protobuf::Map<string, string>* fused_ret) {
+ *fused_ret = first_ret;
+ fused_ret->insert(second_ret.begin(), second_ret.end());
+}
+
+string SameInput(const StringCollection& first_inputs,
+ const StringCollection& second_inputs,
+ const StringCollection& first_outputs, int arg_num) {
+ return first_inputs.at(arg_num);
+}
+
+bool HasSameSignature(const OpDef& first_signature,
+ const OpDef& second_signature) {
+ return first_signature.input_arg_size() ==
+ second_signature.input_arg_size() &&
+ first_signature.output_arg_size() ==
+ second_signature.output_arg_size();
+}
+
+void SameSignature(const OpDef& first_signature, const OpDef& second_signature,
+ OpDef* fused_signature) {
+ CHECK(HasSameSignature(first_signature, second_signature))
+ << "Functions do not have the same signature";
+ // Copy the signature from the first function.
+ *fused_signature = first_signature;
+}
+
+void LazyConjunctionNodes(const FunctionDef& first_function,
+ const FunctionDef& second_function,
+ FunctionDef* fused_function,
+ FunctionDefLibrary* library) {
+ fused_function->mutable_node_def()->CopyFrom(first_function.node_def());
+
+ NodeDefBuilder if_builder("", "If");
+ if_builder.Input(GetOutputNode(first_function, 0), 0, DT_BOOL);
+ DataTypeVector in_arg_types;
+ std::vector<NodeDefBuilder::NodeOut> inputs;
+ for (const auto& input_arg : first_function.signature().input_arg()) {
+ inputs.push_back({input_arg.name(), 0, input_arg.type()});
+ in_arg_types.push_back(input_arg.type());
+ }
+ if_builder.Attr("Tin", in_arg_types);
+
+ if_builder.Attr("Tcond", DT_BOOL);
+ if_builder.Attr("Tout", DataTypeVector{DT_BOOL});
+ if_builder.Attr("_lower_using_switch_merge", true);
+
+ NameAttrList then_branch;
+ then_branch.set_name(second_function.signature().name());
+ if_builder.Attr("then_branch", then_branch);
+
+ auto* false_predicate =
+ CreateFalsePredicate(first_function.signature().input_arg(), library);
+
+ NameAttrList else_branch;
+ else_branch.set_name(false_predicate->signature().name());
+ if_builder.Attr("else_branch", else_branch);
+ if_builder.Input(inputs);
+
+ auto* if_node = fused_function->add_node_def();
+ // This is guaranteed to succeed.
+ TF_CHECK_OK(if_builder.Finalize(if_node));
+ graph_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node);
+
+ GetMutableOutputNode(fused_function, 0) = if_node->name() + ":output:0";
+}
+
+void LazyConjunctionOutput(const protobuf::Map<string, string>& first_ret,
+ const protobuf::Map<string, string>& second_ret,
+ protobuf::Map<string, string>* fused_ret) {
+ CHECK_EQ(first_ret.size(), 1);
+ CHECK_EQ(second_ret.size(), 1);
+ // Temporarily copy returns from first_ret. LazyConjunctionNodes will
+ // redirect this output to the generated `If` node afterwards.
+ *fused_ret = first_ret;
}
-FunctionDef* FuseFunctions(const FunctionDef& first_function,
- const FunctionDef& function,
- StringPiece fused_name_prefix,
- const SetFunctionSignatureFn& set_signature,
- const SetInputFn& set_input,
- const SetOutputFn& set_output,
- FunctionDefLibrary* library) {
- if (first_function.attr_size() != 0 || function.attr_size() != 0)
+FunctionDef* FuseFunctions(
+ const FunctionDef& first_function, const FunctionDef& second_function,
+ StringPiece fused_name_prefix, const SetFunctionSignatureFn& set_signature,
+ const SetInputFn& set_input, const SetOutputFn& set_output,
+ const SetNodesFn& set_nodes, FunctionDefLibrary* library) {
+ if (first_function.attr_size() != 0 || second_function.attr_size() != 0)
return nullptr; // Functions with attributes are currently not supported
// This function will be used as a clone of second function, having unique
// names.
- FunctionDef setup_function = function;
+ FunctionDef setup_function = second_function;
*setup_function.mutable_signature() = GetUniqueSignature(
first_function.signature(), setup_function.signature(),
setup_function.mutable_ret(), setup_function.mutable_node_def());
FunctionDef* fused_function = library->add_function();
- // Copy all nodes from first_function.
- fused_function->mutable_node_def()->CopyFrom(first_function.node_def());
+
set_signature(first_function.signature(), setup_function.signature(),
fused_function->mutable_signature());
graph_utils::SetUniqueGraphFunctionName(fused_name_prefix, library,
fused_function);
- RenameFunctionNodes(first_function, fused_function,
- setup_function.mutable_node_def(),
+ RenameFunctionNodes(first_function, setup_function.mutable_node_def(),
setup_function.mutable_ret());
- set_output(first_function.ret(), setup_function.ret(), fused_function);
+ set_output(first_function.ret(), setup_function.ret(),
+ fused_function->mutable_ret());
CHECK(fused_function->signature().output_arg_size() ==
fused_function->ret_size())
@@ -351,10 +463,10 @@ FunctionDef* FuseFunctions(const FunctionDef& first_function,
FuseFunctionNodes(first_inputs, second_inputs, first_outputs, set_input,
setup_function.mutable_node_def());
FuseReturns(first_inputs, second_inputs, first_outputs, set_input,
- fused_function);
+ fused_function->mutable_ret());
+
+ set_nodes(first_function, setup_function, fused_function, library);
- // Copy transformed nodes from the second function.
- fused_function->mutable_node_def()->MergeFrom(setup_function.node_def());
return fused_function;
}
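For reference, a minimal sketch of how a caller drives the reworked FuseFunctions, mirroring the updated call sites later in this change (the two FunctionDef pointers and the output GraphDef are assumed to be in scope):

    // Compose two map functions: fused(x) = second(first(x)). MergeNodes is the
    // new SetNodesFn argument and simply copies the node_defs of both functions
    // into the fused FunctionDef.
    FunctionDef* fused = fusion_utils::FuseFunctions(
        *first_func, *second_func, "fused_map", fusion_utils::ComposeSignature,
        fusion_utils::ComposeInput, fusion_utils::ComposeOutput,
        fusion_utils::MergeNodes, output->mutable_library());
    if (fused == nullptr) {
      // Functions carrying attrs cannot be fused yet; keep the original nodes.
    }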
diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.h b/tensorflow/core/grappler/optimizers/data/fusion_utils.h
index 41f13f6cb8..19b7002dcd 100644
--- a/tensorflow/core/grappler/optimizers/data/fusion_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.h
@@ -48,14 +48,20 @@ using SetInputFn =
const StringCollection& second_function_inputs,
const StringCollection& parent_outputs, int arg_num)>;
-// This function is invoked with first function ret. It is used to set up
-// returns of fused function. If you need to combine outputs
-// of first and second function, then this is a right place to create a new
-// nodes.
+// This function is invoked with the rets of the first and second functions. It
+// is used to set up the returns of the fused function.
using SetOutputFn =
std::function<void(const protobuf::Map<string, string>& parent_ret,
const protobuf::Map<string, string>& second_function_ret,
- FunctionDef* fused_function)>;
+ protobuf::Map<string, string>* fused_ret)>;
+
+using SetNodesFn = std::function<void(
+ const FunctionDef& first_function, const FunctionDef& second_function,
+ FunctionDef* fused_function, FunctionDefLibrary* library)>;
+
+void MergeNodes(const FunctionDef& first_function,
+ const FunctionDef& second_function, FunctionDef* fused_function,
+ FunctionDefLibrary* library);
// Returns true if functions can be composed.
bool CanCompose(const OpDef& first_signature, const OpDef& second_signature);
@@ -71,7 +77,7 @@ string ComposeInput(const StringCollection& first_inputs,
// second_function(first_function(args...)).
void ComposeOutput(const protobuf::Map<string, string>& first_ret,
const protobuf::Map<string, string>& second_ret,
- FunctionDef* fused_function);
+ protobuf::Map<string, string>* fused_ret);
// Set input signature to `first_function_signature` and output signature
// to `first_function_signature` + `second_function_signature`
@@ -83,7 +89,32 @@ void CombineSignature(const OpDef& first_signature,
// return *first_function(...), *second_function(...)
void CombineOutput(const protobuf::Map<string, string>& first_ret,
const protobuf::Map<string, string>& second_ret,
- FunctionDef* fused_function);
+ protobuf::Map<string, string>* fused_ret);
+
+// Returns true if both signatures have the same number of input and output
+// args.
+bool HasSameSignature(const OpDef& first_signature,
+ const OpDef& second_signature);
+
+// Checks that both signatures are the same and copies `first_signature` into
+// the fused signature.
+void SameSignature(const OpDef& first_signature, const OpDef& second_signature,
+ OpDef* fused_signature);
+
+// Takes the same input as the first function.
+string SameInput(const StringCollection& first_inputs,
+ const StringCollection& second_inputs,
+ const StringCollection& first_outputs, int arg_num);
+
+// Creates a fused function that computes the short-circuit logical AND of the
+// result of the first function and the result of the second function.
+void LazyConjunctionOutput(const protobuf::Map<string, string>& first_ret,
+ const protobuf::Map<string, string>& second_ret,
+ protobuf::Map<string, string>* fused_ret);
+
+void LazyConjunctionNodes(const FunctionDef& first_function,
+ const FunctionDef& second_function,
+ FunctionDef* fused_function,
+ FunctionDefLibrary* library);
// Fuse `first_function` with `second_function`, setting `fused_name_prefix` as
// a name prefix. The nodes from `first_function` are copied unmodified. All
@@ -91,13 +122,11 @@ void CombineOutput(const protobuf::Map<string, string>& first_ret,
// that are not conflicting with first function. This means that copied nodes
// from the second function can end up with different names. For an explanation
// of the setup functions, see the documentation of the function types.
-FunctionDef* FuseFunctions(const FunctionDef& first_function,
- const FunctionDef& second_function,
- StringPiece fused_name_prefix,
- const SetFunctionSignatureFn& set_signature,
- const SetInputFn& set_input,
- const SetOutputFn& set_output,
- FunctionDefLibrary* library);
+FunctionDef* FuseFunctions(
+ const FunctionDef& first_function, const FunctionDef& second_function,
+ StringPiece fused_name_prefix, const SetFunctionSignatureFn& set_signature,
+ const SetInputFn& set_input, const SetOutputFn& set_output,
+ const SetNodesFn& set_nodes, FunctionDefLibrary* library);
} // namespace fusion_utils
} // namespace grappler
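The SameSignature/SameInput/LazyConjunction* helpers introduced above are the building blocks for fusing two filter predicates into a single short-circuit conjunction. A hedged sketch of such a call (the predicate FunctionDefs and `library` are illustrative; the pass that wires this up is not part of this hunk):

    // fused_pred(x) = first_pred(x) && second_pred(x), where second_pred is
    // evaluated only if first_pred returns true, via the generated "If" node
    // whose else branch is the constant-false predicate.
    FunctionDef* fused_pred = fusion_utils::FuseFunctions(
        *first_pred, *second_pred, "fused_filter_predicate",
        fusion_utils::SameSignature, fusion_utils::SameInput,
        fusion_utils::LazyConjunctionOutput, fusion_utils::LazyConjunctionNodes,
        library);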
diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
index 7ad5d63bf6..d5c6466080 100644
--- a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
@@ -57,10 +57,10 @@ TEST(FusionUtilsTest, FuseFunctionsByComposition) {
auto *function = graph.mutable_library()->add_function();
*function = test::function::XTimesTwo();
- auto *fused_function =
- FuseFunctions(*parent_function, *function, "fused_maps",
- fusion_utils::ComposeSignature, fusion_utils::ComposeInput,
- fusion_utils::ComposeOutput, graph.mutable_library());
+ auto *fused_function = FuseFunctions(
+ *parent_function, *function, "fused_maps", fusion_utils::ComposeSignature,
+ fusion_utils::ComposeInput, fusion_utils::ComposeOutput,
+ fusion_utils::MergeNodes, graph.mutable_library());
EXPECT_EQ(fused_function->signature().name(), "fused_maps");
EXPECT_EQ(fused_function->signature().input_arg_size(), 1);
@@ -98,7 +98,8 @@ TEST(FusionUtilsTest, FuseFunctionWithPredicate) {
auto *fused_function =
FuseFunctions(*xtimes_two, *is_zero, "fused_map_and_filter_function",
fusion_utils::CombineSignature, fusion_utils::ComposeInput,
- fusion_utils::CombineOutput, graph.mutable_library());
+ fusion_utils::CombineOutput, fusion_utils::MergeNodes,
+ graph.mutable_library());
EXPECT_EQ(fused_function->signature().name(),
"fused_map_and_filter_function");
@@ -134,10 +135,10 @@ TEST(FusionUtilsTest, FuseSameFunctionWithExtraOutput) {
auto *function = graph.mutable_library()->add_function();
*function = test::function::XTimesTwo();
- auto *fused_function =
- FuseFunctions(*parent_function, *function, "fused_maps",
- fusion_utils::CombineSignature, fusion_utils::ComposeInput,
- fusion_utils::CombineOutput, graph.mutable_library());
+ auto *fused_function = FuseFunctions(
+ *parent_function, *function, "fused_maps", fusion_utils::CombineSignature,
+ fusion_utils::ComposeInput, fusion_utils::CombineOutput,
+ fusion_utils::MergeNodes, graph.mutable_library());
EXPECT_EQ(fused_function->signature().input_arg_size(), 1);
EXPECT_EQ(fused_function->signature().output_arg_size(), 2);
@@ -169,7 +170,8 @@ TEST(FusionUtilsTest, ZipFusion) {
auto *fused_function =
FuseFunctions(*function, *function, "zip_maps", zip_signature, zip_input,
- fusion_utils::CombineOutput, graph.mutable_library());
+ fusion_utils::CombineOutput, fusion_utils::MergeNodes,
+ graph.mutable_library());
EXPECT_EQ(fused_function->signature().input_arg_size(), 2);
EXPECT_EQ(fused_function->signature().output_arg_size(), 2);
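Call sites that only need the old copy-and-merge behavior pass fusion_utils::MergeNodes, but SetNodesFn also admits a custom callable. A minimal inline equivalent of MergeNodes, for illustration:

    auto set_nodes = [](const FunctionDef& first, const FunctionDef& second,
                        FunctionDef* fused, FunctionDefLibrary* library) {
      // Same effect as fusion_utils::MergeNodes: the nodes of both functions
      // are copied verbatim into the fused function.
      fused->mutable_node_def()->CopyFrom(first.node_def());
      fused->mutable_node_def()->MergeFrom(second.node_def());
    };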
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index 0eceaf4017..5a7fe19265 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -94,11 +94,11 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
MutableGraphView* graph) {
NodeDef node;
if (!name.empty()) {
- node.set_name(name.ToString());
+ node.set_name(string(name));
} else {
SetUniqueGraphNodeName(op, graph->GetGraph(), &node);
}
- node.set_op(op.ToString());
+ node.set_op(string(op));
for (const string& input : inputs) {
node.add_input(input);
}
@@ -108,6 +108,26 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
return graph->AddNode(std::move(node));
}
+NodeDef* AddNode(StringPiece name, StringPiece op,
+ const std::vector<string>& inputs,
+ const std::vector<std::pair<string, AttrValue>>& attributes,
+ FunctionDef* fd) {
+ NodeDef* node = fd->add_node_def();
+ if (!name.empty()) {
+ node->set_name(string(name));
+ } else {
+ SetUniqueFunctionNodeName(op, fd, node);
+ }
+ node->set_op(string(op));
+ for (const string& input : inputs) {
+ node->add_input(input);
+ }
+ for (auto attr : attributes) {
+ (*node->mutable_attr())[attr.first] = attr.second;
+ }
+ return node;
+}
+
template <>
NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph) {
return AddScalarConstNodeHelper(
@@ -181,7 +201,7 @@ bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph) {
}
bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) {
- return FindNodeWithOp(op, graph) != -1;
+ return FindGraphNodeWithOp(op, graph) != -1;
}
bool ContainsGraphFunctionWithName(StringPiece name,
@@ -205,7 +225,7 @@ int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) {
return indices.empty() ? -1 : indices.front();
}
-int FindNodeWithOp(StringPiece op, const GraphDef& graph) {
+int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph) {
std::vector<int> indices = GetElementIndicesWithPredicate(
[&op](const NodeDef& node) { return node.op() == op; }, graph.node());
return indices.empty() ? -1 : indices.front();
@@ -242,9 +262,15 @@ int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
return indices.empty() ? -1 : indices.front();
}
+NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
+ if (node.input_size() == 0) return nullptr;
+ GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0);
+ return graph.GetRegularFanin(input_port).node;
+}
+
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
NodeDef* node) {
- string name = prefix.ToString();
+ string name = string(prefix);
int id = graph->node_size();
while (ContainsGraphNodeWithName(name, *graph)) {
if (name.rfind("_generated") != std::string::npos &&
@@ -260,7 +286,7 @@ void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
NodeDef* node) {
- string name = prefix.ToString();
+ string name = string(prefix);
int id = function->node_def_size();
while (ContainsFunctionNodeWithName(name, *function)) {
name = strings::StrCat(prefix, "/_", id);
@@ -271,7 +297,7 @@ void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
FunctionDef* function) {
- string name = prefix.ToString();
+ string name = string(prefix);
int id = library->function_size();
while (ContainsGraphFunctionWithName(name, *library)) {
name = strings::StrCat(prefix, "/_", id);
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index 28a1aff877..6f431c232d 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -37,6 +37,12 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
const std::vector<std::pair<string, AttrValue>>& attributes,
MutableGraphView* graph);
+// Adds a node to a FunctionDef.
+NodeDef* AddNode(StringPiece name, StringPiece op,
+ const std::vector<string>& inputs,
+ const std::vector<std::pair<string, AttrValue>>& attributes,
+ FunctionDef* fd);
+
// Adds a Const node with the given value to the graph.
template <typename T>
NodeDef* AddScalarConstNode(T v, MutableGraphView* graph) {
@@ -99,7 +105,10 @@ int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
// Returns the index of the first node with the given op or -1 if no such node
// exists.
-int FindNodeWithOp(StringPiece op, const GraphDef& graph);
+int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph);
+
+// Gets the 0th input to a node in the graph.
+NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph);
// Returns the list of indices of all nodes with the given op or empty list if
// no such node exists.
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
index 0a3af1a914..c19ac7b880 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
@@ -176,25 +176,25 @@ TEST(GraphUtilsTest, FindGraphFunctionWithName) {
FindGraphFunctionWithName(new_function->signature().name(), library), -1);
}
-TEST(GraphUtilsTest, FindNodeWithOp) {
+TEST(GraphUtilsTest, FindGraphNodeWithOp) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
- EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), -1);
+ EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), -1);
AddNode("A", "OpA", {}, {}, &graph);
AddNode("B", "OpB", {"A"}, {}, &graph);
AddNode("A2", "OpA", {"B"}, {}, &graph);
- EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), 0);
+ EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), 0);
graph.DeleteNodes({"B"});
- EXPECT_EQ(FindNodeWithOp("OpB", *graph.GetGraph()), -1);
+ EXPECT_EQ(FindGraphNodeWithOp("OpB", *graph.GetGraph()), -1);
EXPECT_EQ(FindGraphNodeWithName("A2", *graph.GetGraph()), 1);
}
TEST(GraphUtilsTest, FindAllGraphNodesWithOp) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
- EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), -1);
+ EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), -1);
AddNode("A", "OpA", {}, {}, &graph);
AddNode("B", "OpB", {"A"}, {}, &graph);
@@ -251,6 +251,54 @@ TEST(GraphUtilsTest, SetUniqueGraphFunctionName) {
other_function->signature().name());
}
+TEST(GraphUtilsTest, AddNodeToFunctionDef) {
+ FunctionDef func;
+ const char* op_name = "xxx";
+ AddNode(op_name, op_name, {}, {}, &func);
+
+ const NodeDef& node1 = func.node_def(FindFunctionNodeWithName("xxx", func));
+ EXPECT_EQ(node1.op(), op_name);
+ EXPECT_EQ(node1.input_size(), 0);
+ EXPECT_EQ(node1.attr_size(), 0);
+
+ const std::vector<string> inputs({"input1", "input2"});
+ AddNode("", op_name, inputs, {}, &func);
+ const NodeDef& node2 =
+ func.node_def(FindFunctionNodeWithName("xxx/_2", func));
+ EXPECT_EQ(node2.op(), op_name);
+ EXPECT_EQ(node2.attr_size(), 0);
+ EXPECT_EQ(node2.input_size(), inputs.size());
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ EXPECT_EQ(node2.input(i), inputs[i]);
+ }
+
+ AttrValue a1, a2;
+ a1.set_type(DT_INT32);
+ a2.set_type(DT_INT64);
+ const std::vector<std::pair<string, AttrValue>> attrs(
+ {{"attr1", a1}, {"attr2", a2}});
+ AddNode("", op_name, {}, attrs, &func);
+ const NodeDef& node3 =
+ func.node_def(FindFunctionNodeWithName("xxx/_3", func));
+ EXPECT_EQ(node3.op(), op_name);
+ EXPECT_EQ(node3.input_size(), 0);
+ EXPECT_EQ(node3.attr_size(), attrs.size());
+ for (size_t i = 0; i < attrs.size(); ++i) {
+ EXPECT_EQ(attrs[i].second.type(), node3.attr().at(attrs[i].first).type());
+ }
+}
+
+TEST(GraphUtilsTest, GetInputNode) {
+ GraphDef graph_def;
+ MutableGraphView graph(&graph_def);
+
+ NodeDef* node1 = AddNode("", "A", {}, {}, &graph);
+ NodeDef* node2 = AddNode("", "A", {node1->name()}, {}, &graph);
+
+ EXPECT_EQ(GetInputNode(*node2, graph), node1);
+ EXPECT_EQ(GetInputNode(*node1, graph), nullptr);
+}
+
} // namespace
} // namespace graph_utils
} // namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc b/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc
index 0b25b1ea9d..9e382aeef9 100644
--- a/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc
+++ b/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc
@@ -33,7 +33,7 @@ namespace {
constexpr char kInsertOpName[] = "LatencyStatsDataset";
-NodeDef make_latency_node(const NodeDef& node, MutableGraphView* graph) {
+NodeDef MakeLatencyNode(const NodeDef& node, MutableGraphView* graph) {
NodeDef new_node;
new_node.set_op(kInsertOpName);
graph_utils::SetUniqueGraphNodeName(
@@ -96,7 +96,7 @@ Status LatencyAllEdges::Optimize(Cluster* cluster, const GrapplerItem& item,
}
}
- graph.InsertNode(node, make_latency_node(node, &graph));
+ graph.InsertNode(node, MakeLatencyNode(node, &graph));
}
return Status::OK();
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
index 3ce238a30a..63945b8b9e 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
@@ -32,9 +32,8 @@ namespace {
constexpr char kFusedOpName[] = "MapAndBatchDatasetV2";
-NodeDef make_map_and_batch_node(const NodeDef& map_node,
- const NodeDef& batch_node,
- MutableGraphView* graph) {
+NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node,
+ MutableGraphView* graph) {
NodeDef new_node;
new_node.set_op(kFusedOpName);
graph_utils::SetUniqueGraphNodeName(kFusedOpName, graph->GetGraph(),
@@ -104,8 +103,8 @@ Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
// Use a more descriptive variable name now that we know the node type.
const NodeDef& batch_node = node;
- GraphView::InputPort input_port = graph.GetInputPort(batch_node.name(), 0);
- NodeDef* node2 = graph.GetRegularFanin(input_port).node;
+ NodeDef* node2 = graph_utils::GetInputNode(batch_node, graph);
+
if (node2->op() != "MapDataset" && node2->op() != "ParallelMapDataset") {
continue;
}
@@ -113,7 +112,7 @@ Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
NodeDef* map_node = node2;
auto* new_node =
- graph.AddNode(make_map_and_batch_node(*map_node, batch_node, &graph));
+ graph.AddNode(MakeMapAndBatchNode(*map_node, batch_node, &graph));
graph.ReplaceInput(batch_node, *new_node);
// Mark the `Map` and `Batch` nodes for removal.
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
index a46c504ac4..b676246b31 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
@@ -85,8 +85,8 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) {
EXPECT_FALSE(
graph_utils::ContainsGraphNodeWithName(batch_node->name(), output));
EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
- NodeDef map_and_batch_node =
- output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
+ NodeDef map_and_batch_node = output.node(
+ graph_utils::FindGraphNodeWithOp("MapAndBatchDatasetV2", output));
EXPECT_EQ(map_and_batch_node.input_size(), 5);
EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
@@ -170,8 +170,8 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchV2NodesIntoOne) {
EXPECT_FALSE(
graph_utils::ContainsGraphNodeWithName(batch_node->name(), output));
EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
- NodeDef map_and_batch_node =
- output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
+ NodeDef map_and_batch_node = output.node(
+ graph_utils::FindGraphNodeWithOp("MapAndBatchDatasetV2", output));
EXPECT_EQ(map_and_batch_node.input_size(), 5);
EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
@@ -253,8 +253,8 @@ TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) {
EXPECT_FALSE(
graph_utils::ContainsGraphNodeWithName(batch_node->name(), output));
EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
- NodeDef map_and_batch_node =
- output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
+ NodeDef map_and_batch_node = output.node(
+ graph_utils::FindGraphNodeWithOp("MapAndBatchDatasetV2", output));
EXPECT_EQ(map_and_batch_node.input_size(), 5);
EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
index 5e76c9f819..f1844a141c 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
@@ -116,22 +116,25 @@ Status MapAndFilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
const auto& fun = filter_node->attr().at("predicate");
const FunctionDef* filter_func = function_library.Find(fun.func().name());
if (!fusion_utils::CanCompose(map_func->signature(),
- filter_func->signature()))
+ filter_func->signature())) {
+ VLOG(1) << "Can't fuse map and filter because the output signature of "
+ "the map function does not match the input signature of the "
+ "filter function\n";
return nullptr;
+ }
return fusion_utils::FuseFunctions(
*map_func, *filter_func, "fused_map_and_filter_function",
fusion_utils::CombineSignature, fusion_utils::ComposeInput,
- fusion_utils::CombineOutput, output->mutable_library());
+ fusion_utils::CombineOutput, fusion_utils::MergeNodes,
+ output->mutable_library());
};
for (const NodeDef& node : sorted_old_graph.node()) {
const NodeDef* filter_node = get_filter_node(node);
if (!filter_node) continue;
- GraphView::InputPort input_port =
- graph.GetInputPort(filter_node->name(), 0);
const NodeDef* map_node =
- get_map_node(*graph.GetRegularFanin(input_port).node);
+ get_map_node(*graph_utils::GetInputNode(*filter_node, graph));
if (!map_node) continue;
const auto* fused_function = make_fused_function(map_node, filter_node);
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
index 027e0c1590..f029a093fa 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
@@ -30,7 +30,7 @@ namespace {
NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) {
return test::function::NDef(
- name, "MapDataset", {input_node_name.ToString()},
+ name, "MapDataset", {string(input_node_name)},
{{"f", FunctionDefHelper::FunctionRef("XTimesTwo")},
{"Targuments", {}},
{"output_shapes", {}},
@@ -39,7 +39,7 @@ NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) {
NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) {
return test::function::NDef(
- name, "FilterDataset", {input_node_name.ToString()},
+ name, "FilterDataset", {string(input_node_name)},
{{"predicate", FunctionDefHelper::FunctionRef("IsZero")},
{"Targuments", {}},
{"output_shapes", {}},
@@ -101,18 +101,18 @@ TEST(MapAndFilterFusionTest, FuseMapAndFilterWithExtraChild) {
graph_utils::ContainsNodeWithOp("FilterByLastComponentDataset", output));
ASSERT_TRUE(graph_utils::ContainsNodeWithOp("CacheDataset", output));
- int map_id = graph_utils::FindNodeWithOp("MapDataset", output);
+ int map_id = graph_utils::FindGraphNodeWithOp("MapDataset", output);
auto& map_node = output.node(map_id);
ASSERT_EQ(map_node.input_size(), 1);
EXPECT_EQ(map_node.input(0), "range");
int filter_by_component_id =
- graph_utils::FindNodeWithOp("FilterByLastComponentDataset", output);
+ graph_utils::FindGraphNodeWithOp("FilterByLastComponentDataset", output);
auto& filter_by_component = output.node(filter_by_component_id);
ASSERT_EQ(filter_by_component.input_size(), 1);
EXPECT_EQ(filter_by_component.input(0), map_node.name());
- int cache_id = graph_utils::FindNodeWithOp("CacheDataset", output);
+ int cache_id = graph_utils::FindGraphNodeWithOp("CacheDataset", output);
auto& cache_node = output.node(cache_id);
ASSERT_EQ(cache_node.input_size(), 2);
EXPECT_EQ(cache_node.input(0), filter_by_component.name());
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_fusion.cc
index feb370eb9d..a78ecb09f7 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion.cc
@@ -90,21 +90,25 @@ Status MapFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
const auto& fun = map_node->attr().at("f");
const FunctionDef* func = function_library.Find(fun.func().name());
- if (!fusion_utils::CanCompose(parent_func->signature(), func->signature()))
+ if (!fusion_utils::CanCompose(parent_func->signature(),
+ func->signature())) {
+ VLOG(1) << "Can't fuse two maps because the output signature of the "
+ "first map function does not match the input signature of the "
+ "second function\n";
return nullptr;
+ }
return fusion_utils::FuseFunctions(
*parent_func, *func, "fused_map", fusion_utils::ComposeSignature,
fusion_utils::ComposeInput, fusion_utils::ComposeOutput,
- output->mutable_library());
+ fusion_utils::MergeNodes, output->mutable_library());
};
for (const NodeDef& node : sorted_old_graph.node()) {
const NodeDef* map_node = get_map_node(node);
if (!map_node) continue;
- GraphView::InputPort input_port = graph.GetInputPort(map_node->name(), 0);
const NodeDef* parent_map_node =
- get_map_node(*graph.GetRegularFanin(input_port).node);
+ get_map_node(*graph_utils::GetInputNode(*map_node, graph));
if (!parent_map_node) continue;
const auto* fused_function = get_fused_function(parent_map_node, map_node);
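For intuition, stacking two XTimesTwo maps and running this pass yields a single MapDataset whose function composes them, i.e. x -> 4 * x. A sketch following the test pattern used throughout this change (`item` is assumed to hold a range -> map -> map pipeline):

    MapFusion optimizer;
    GraphDef output;
    TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
    // The two MapDataset nodes collapse into one whose "f" attr points at a
    // "fused_map" function computing XTimesTwo(XTimesTwo(x)) == 4 * x.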
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
index df6c19dc7c..b25dfbd0b8 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
@@ -30,7 +30,7 @@ namespace {
NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) {
return test::function::NDef(
- name, "MapDataset", {input_node_name.ToString()},
+ name, "MapDataset", {string(input_node_name)},
{{"f", FunctionDefHelper::FunctionRef("XTimesTwo")},
{"Targuments", {}},
{"output_shapes", {}},
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
new file mode 100644
index 0000000000..a019b77eb7
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -0,0 +1,258 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_vectorization.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+void CopyAttribute(const string& attr_name, const NodeDef& from, NodeDef* to) {
+ (*to->mutable_attr())[attr_name] = from.attr().at(attr_name);
+}
+
+FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
+ const FunctionDef& orig_func,
+ FunctionDefLibrary* library) {
+ // If we decide to use a different method of vectorization, we can just
+ // swap out this part.
+ FunctionDef* vectorized_func = library->add_function();
+ // Function inputs and outputs are the same as original, just
+ // with different shapes.
+ *vectorized_func->mutable_signature() = orig_func.signature();
+ graph_utils::SetUniqueGraphFunctionName("vectorized_function", library,
+ vectorized_func);
+
+ // Add MapDefun node
+ NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Add();
+ map_defun_node->set_op("MapDefun");
+ graph_utils::SetUniqueFunctionNodeName(map_defun_node->op(), vectorized_func,
+ map_defun_node);
+
+ // Set attrs and inputs
+ for (const string& k : {"f", "output_types", "output_shapes"}) {
+ // Function, output types and (unbatched) shapes are the same as the
+ // original map node.
+ CopyAttribute(k, map_node, map_defun_node);
+ }
+
+ // Get types of input arguments from original map function
+ AttrValue t_args;
+ for (const auto& input : vectorized_func->signature().input_arg()) {
+ t_args.mutable_list()->add_type(input.type());
+ map_defun_node->add_input(input.name());
+ }
+ (*map_defun_node->mutable_attr())["Targuments"] = t_args;
+
+ // Set return values to match output names
+ string output_prefix = strings::StrCat(map_defun_node->name(), ":output:");
+ for (size_t i = 0; i < vectorized_func->signature().output_arg_size(); ++i) {
+ const auto& output_arg = vectorized_func->signature().output_arg(i);
+ (*vectorized_func->mutable_ret())[output_arg.name()] =
+ strings::StrCat(output_prefix, i);
+ }
+
+ return vectorized_func;
+}
+
+bool IsOutputShapesFullyDefined(const NodeDef& node) {
+ auto* shapes_attr = gtl::FindOrNull(node.attr(), "output_shapes");
+ if (shapes_attr == nullptr) return false;
+ const auto& shapes = shapes_attr->list().shape();
+
+ for (const TensorShapeProto& shape : shapes) {
+ for (const auto& dim : shape.dim()) {
+ if (dim.size() == -1) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+bool IsStatefulFn(const FunctionLibraryDefinition& library,
+ const FunctionDef& function_def) {
+ for (const NodeDef& node_def : function_def.node_def()) {
+ const OpDef* op_def;
+ Status s = library.LookUpOpDef(node_def.op(), &op_def);
+ if (!s.ok() || op_def->is_stateful()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool HasCapturedInputs(const NodeDef& map_node) {
+ return map_node.attr().at("Targuments").list().type_size() > 0;
+}
+
+NodeDef MakeNewBatchNode(const NodeDef& old_batch_node,
+ const NodeDef& input_node,
+ const FunctionDef& vectorized_func,
+ MutableGraphView* graph) {
+ NodeDef batch_node;
+ batch_node.set_op(old_batch_node.op());
+ graph_utils::SetUniqueGraphNodeName(batch_node.op(), graph->GetGraph(),
+ &batch_node);
+
+ // Set the `input_dataset` input argument
+ batch_node.add_input(input_node.name());
+ // Set the `batch_size` input_argument
+ batch_node.add_input(old_batch_node.input(1));
+ if (batch_node.op() == "BatchDatasetV2") {
+ // Set the `drop_remainder` input argument
+ batch_node.add_input(old_batch_node.input(2));
+ }
+
+ // Set attrs
+ AttrValue output_types;
+ for (const auto& input : vectorized_func.signature().input_arg()) {
+ output_types.mutable_list()->add_type(input.type());
+ }
+ (*batch_node.mutable_attr())["output_types"] = output_types;
+
+ auto& output_shapes_attr = (*batch_node.mutable_attr())["output_shapes"];
+ const auto& input_shapes =
+ input_node.attr().at("output_shapes").list().shape();
+ int64 batch_size =
+ old_batch_node.attr().at("output_shapes").list().shape()[0].dim(0).size();
+ for (size_t i = 0; i < input_shapes.size(); ++i) {
+ TensorShapeProto* shape = output_shapes_attr.mutable_list()->add_shape();
+ TensorShapeProto_Dim* dim = shape->add_dim();
+ dim->set_size(batch_size);
+ shape->MergeFrom(input_shapes.Get(i));
+ }
+ return batch_node;
+}
+
+NodeDef MakeNewMapNode(const NodeDef& old_map_node,
+ const NodeDef& old_batch_node,
+ const NodeDef& new_batch_node,
+ const FunctionDef& vectorized_func,
+ MutableGraphView* graph) {
+ NodeDef map_node;
+ map_node.set_op(old_map_node.op());
+ graph_utils::SetUniqueGraphNodeName(map_node.op(), graph->GetGraph(),
+ &map_node);
+
+ // Set the `input_dataset` input argument
+ map_node.add_input(new_batch_node.name());
+ for (int i = 1; i < old_map_node.input_size(); i++) {
+ // Set the `other_arguments` and `num_parallel_calls` input arguments
+ map_node.add_input(old_map_node.input(i));
+ }
+
+ // Set attrs
+ CopyAttribute("Targuments", old_map_node, &map_node);
+ auto& func_attr = (*map_node.mutable_attr())["f"];
+ func_attr.mutable_func()->set_name(vectorized_func.signature().name());
+
+ for (auto key : {"output_shapes", "output_types"}) {
+ CopyAttribute(key, old_batch_node, &map_node);
+ }
+ return map_node;
+}
+
+} // namespace
+
+Status MapVectorization::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ *output = item.graph;
+ MutableGraphView graph(output);
+ std::set<string> nodes_to_delete;
+
+ for (const NodeDef& node : item.graph.node()) {
+ // Find Map->Batch nodes.
+ // TODO(rachelim): Optimize MapAndBatchDataset[V2] as well.
+ if (node.op() != "BatchDataset" && node.op() != "BatchDatasetV2") {
+ continue;
+ }
+
+ const NodeDef& batch_node(node);
+ NodeDef* node2 = graph_utils::GetInputNode(batch_node, graph);
+ if (node2->op() != "MapDataset" && node2->op() != "ParallelMapDataset") {
+ continue;
+ }
+
+ // Use a more descriptive variable name now that we know the node type.
+ NodeDef* map_node = node2;
+ // Input to the map node
+ NodeDef* input_node = graph_utils::GetInputNode(*map_node, graph);
+ CHECK_NOTNULL(input_node);
+
+ FunctionDefLibrary* library = output->mutable_library();
+
+ FunctionLibraryDefinition function_library(OpRegistry::Global(), *library);
+ const FunctionDef* orig_func =
+ function_library.Find(map_node->attr().at("f").func().name());
+
+ // Check that this is a valid optimization.
+ if (!IsOutputShapesFullyDefined(*input_node) ||
+ !IsOutputShapesFullyDefined(*map_node) ||
+ IsStatefulFn(function_library, *orig_func) ||
+ HasCapturedInputs(*map_node)) {
+ // 1. If any of the inputs have an unknown shape, don't optimize, since
+ // inputs might not be batchable.
+ // 2. If any of the map func outputs have an unknown shape, don't
+ // optimize, so that batching errors surface as before.
+ // 3. If the function is stateful, don't vectorize it.
+ // 4. TODO(rachelim): Make this work for MapDataset with captured inputs
+ // by tiling inputs or modifying the signature of MapDefun.
+ continue;
+ }
+
+ FunctionDef* vectorized_func =
+ AddVectorizedFunction(*map_node, *orig_func, library);
+ CHECK_NOTNULL(vectorized_func);
+
+ auto* new_batch_node = graph.AddNode(
+ MakeNewBatchNode(batch_node, *input_node, *vectorized_func, &graph));
+
+ auto* new_map_node = graph.AddNode(MakeNewMapNode(
+ *map_node, batch_node, *new_batch_node, *vectorized_func, &graph));
+ graph.ReplaceInput(batch_node, *new_map_node);
+
+ // Mark the `Map` and `Batch` nodes for removal.
+ nodes_to_delete.insert(map_node->name());
+ nodes_to_delete.insert(batch_node.name());
+ }
+ graph.DeleteNodes(nodes_to_delete);
+ return Status::OK();
+}
+
+void MapVectorization::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output,
+ double result) {
+ // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(MapVectorization, "map_vectorization");
+
+} // end namespace grappler
+} // end namespace tensorflow
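In short, the new pass swaps the order of Map and Batch and rewrites the map function to process whole batches through a single MapDefun node. A schematic of the rewrite (shapes illustrative):

    // Before: input -> MapDataset(f)            -> BatchDataset(batch_size)
    // After:  input -> BatchDataset(batch_size) -> MapDataset(vectorized_f)
    //
    // vectorized_f is a generated FunctionDef containing one MapDefun node that
    // applies f across the batch dimension:
    //   vectorized_f([x_0, ..., x_{n-1}]) = [f(x_0), ..., f(x_{n-1})]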
diff --git a/tensorflow/core/grappler/optimizers/data/function_rename.h b/tensorflow/core/grappler/optimizers/data/map_vectorization.h
index 23ad9470ff..cc56a8ee5e 100644
--- a/tensorflow/core/grappler/optimizers/data/function_rename.h
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.h
@@ -13,20 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_RENAME_H_
-#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_RENAME_H_
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_VECTORIZATION_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_VECTORIZATION_H_
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
namespace tensorflow {
namespace grappler {
-class FunctionRename : public CustomGraphOptimizer {
+class MapVectorization : public CustomGraphOptimizer {
public:
- FunctionRename() = default;
- ~FunctionRename() override = default;
+ MapVectorization() = default;
+ ~MapVectorization() override = default;
- string name() const override { return "_test_only_function_rename"; };
+ string name() const override { return "map_vectorization"; };
Status Init(
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
@@ -43,4 +43,4 @@ class FunctionRename : public CustomGraphOptimizer {
} // end namespace grappler
} // end namespace tensorflow
-#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_RENAME_H_
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_VECTORIZATION_H_
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
new file mode 100644
index 0000000000..ed1bd6bc97
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
@@ -0,0 +1,201 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_vectorization.h"
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+using test::function::GDef;
+using test::function::NDef;
+
+void MakeTensorShapeProtoHelper(const gtl::ArraySlice<int> dims,
+ TensorShapeProto* t) {
+ for (size_t i = 0; i < dims.size(); ++i) {
+ auto* d = t->add_dim();
+ d->set_size(dims[i]);
+ }
+}
+
+AttrValue MakeShapeListAttr(
+ const gtl::ArraySlice<const gtl::ArraySlice<int>>& shapes) {
+ AttrValue shapes_attr;
+ for (size_t i = 0; i < shapes.size(); ++i) {
+ MakeTensorShapeProtoHelper(shapes[i],
+ shapes_attr.mutable_list()->add_shape());
+ }
+
+ return shapes_attr;
+}
+
+NodeDef MakeMapNodeHelper(
+ StringPiece name, StringPiece input_node_name, StringPiece function_name,
+ StringPiece map_op_name,
+ const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
+ const gtl::ArraySlice<DataType>& output_types) {
+ return test::function::NDef(
+ name, map_op_name, {string(input_node_name)},
+ {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
+ {"Targuments", {}},
+ {"output_shapes", MakeShapeListAttr(output_shapes)},
+ {"output_types", output_types}});
+}
+
+NodeDef MakeMapNode(
+ StringPiece name, StringPiece input_node_name, StringPiece function_name,
+ const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
+ const gtl::ArraySlice<DataType>& output_types) {
+ return MakeMapNodeHelper(name, input_node_name, function_name, "MapDataset",
+ output_shapes, output_types);
+}
+
+NodeDef MakeBatchNode(
+ StringPiece name, StringPiece input_node_name,
+ StringPiece input_batch_size_name,
+ const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
+ const gtl::ArraySlice<DataType>& output_types) {
+ return NDef(name, "BatchDataset",
+ {string(input_node_name), string(input_batch_size_name)},
+ {{"output_types", output_types},
+ {"output_shapes", MakeShapeListAttr(output_shapes)}});
+}
+
+NodeDef MakeBatchV2Node(
+ StringPiece name, StringPiece input_node_name,
+ StringPiece input_batch_size_name, StringPiece input_drop_remainder_name,
+ const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
+ const gtl::ArraySlice<DataType>& output_types) {
+ return NDef(name, "BatchDatasetV2",
+ {string(input_node_name), string(input_batch_size_name),
+ string(input_drop_remainder_name)},
+ {{"output_types", output_types},
+ {"output_shapes", MakeShapeListAttr(output_shapes)}});
+}
+
+NodeDef MakeRangeNode(StringPiece name, const gtl::ArraySlice<string>& inputs) {
+ return NDef(name, "RangeDataset", inputs,
+ {{"output_shapes", MakeShapeListAttr({{}})},
+ {"output_types", gtl::ArraySlice<DataType>({DT_INT64})}});
+}
+
+TEST(MapVectorizationTest, VectorizeMapWithBatch) {
+ GrapplerItem item;
+ item.graph = GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ MakeRangeNode("range", {"start", "stop", "step"}),
+ MakeMapNode("map", "range", "XTimesTwo", {{}}, {DT_INT32}),
+ MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
+ // FunctionLib
+ {
+ test::function::XTimesTwo(),
+ });
+ MapVectorization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("MapDataset", output).size(),
+ 1);
+ EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("BatchDataset", output).size(),
+ 1);
+ const NodeDef& map_node =
+ output.node(graph_utils::FindGraphNodeWithOp("MapDataset", output));
+ const NodeDef& batch_node =
+ output.node(graph_utils::FindGraphNodeWithOp("BatchDataset", output));
+ EXPECT_EQ(map_node.input(0), batch_node.name());
+ EXPECT_EQ(batch_node.input(0), "range");
+}
+
+TEST(MapVectorizationTest, VectorizeMapWithBatchV2) {
+ GrapplerItem item;
+ item.graph = GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("drop_remainder", "Const", {},
+ {{"value", false}, {"dtype", DT_BOOL}}),
+ MakeRangeNode("range", {"start", "stop", "step"}),
+ MakeMapNode("map", "range", "XTimesTwo", {{}}, {DT_INT32}),
+ MakeBatchV2Node("batch", "map", "batch_size", "drop_remainder", {{-1}},
+ {DT_INT32})},
+ // FunctionLib
+ {
+ test::function::XTimesTwo(),
+ });
+ MapVectorization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("MapDataset", output).size(),
+ 1);
+ EXPECT_EQ(
+ graph_utils::FindAllGraphNodesWithOp("BatchDatasetV2", output).size(), 1);
+ const NodeDef& map_node =
+ output.node(graph_utils::FindGraphNodeWithOp("MapDataset", output));
+ const NodeDef& batch_node =
+ output.node(graph_utils::FindGraphNodeWithOp("BatchDatasetV2", output));
+ EXPECT_EQ(map_node.input(0), batch_node.name());
+ EXPECT_EQ(batch_node.input(0), "range");
+}
+
+TEST(MapVectorizationTest, VectorizeWithUndefinedOutputShape) {
+ GrapplerItem item;
+ item.graph = GDef(
+ {NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("input", "InputDataset", {},
+ {{"output_types", gtl::ArraySlice<DataType>({DT_INT32})}}),
+ MakeMapNode("map", "input", "XTimesTwo", {{}}, {DT_INT32}),
+ MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
+ // FunctionLib
+ {
+ test::function::XTimesTwo(),
+ });
+ MapVectorization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+}
+
+TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
+ GrapplerItem item;
+ item.graph = GDef(
+ {NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("input", "InputDataset", {},
+ {{"output_shapes", MakeShapeListAttr({{}})}}),
+ MakeMapNode("map", "input", "XTimesTwo", {{}}, {DT_INT32}),
+ MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
+ // FunctionLib
+ {
+ test::function::XTimesTwo(),
+ });
+ MapVectorization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
index 55d57b3b97..a26f1000a3 100644
--- a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
+++ b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
@@ -69,8 +69,7 @@ Status NoOpElimination::Optimize(Cluster* cluster, const GrapplerItem& item,
for (const NodeDef& node : item.graph.node()) {
if (!IsNoOp(node, graph)) continue;
- GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0);
- NodeDef* const parent = graph.GetRegularFanin(input_port).node;
+ NodeDef* const parent = graph_utils::GetInputNode(node, graph);
graph.ReplaceInput(node, *parent);
nodes_to_delete.insert(node.name());
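graph_utils::GetInputNode(node, graph) returns the 0th regular fanin, or nullptr for nodes with no inputs, so the passes touched by this change all share the same one-hop upstream walk. A defensive sketch of the pattern:

    NodeDef* parent = graph_utils::GetInputNode(node, graph);
    if (parent == nullptr) continue;  // Source node: nothing upstream to inspect.
    // ... match on parent->op() and rewrite as the pass requires ...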
diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
index 7c7161c5b2..cb0ff670e8 100644
--- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
@@ -76,8 +76,8 @@ Status ShuffleAndRepeatFusion::Optimize(Cluster* cluster,
// Use a more descriptive variable name now that we know the node type.
const NodeDef& repeat_node = node;
- GraphView::InputPort input_port = graph.GetInputPort(repeat_node.name(), 0);
- NodeDef* node2 = graph.GetRegularFanin(input_port).node;
+ NodeDef* node2 = graph_utils::GetInputNode(repeat_node, graph);
+
if (node2->op() != "ShuffleDataset") {
continue;
}
diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc
index a2e470e511..f0696eb76d 100644
--- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc
@@ -78,7 +78,7 @@ TEST(ShuffleAndRepeatFusionTest, FuseShuffleAndRepeatNodesIntoOne) {
EXPECT_TRUE(
graph_utils::ContainsNodeWithOp("ShuffleAndRepeatDataset", output));
NodeDef shuffle_and_repeat_node = output.node(
- graph_utils::FindNodeWithOp("ShuffleAndRepeatDataset", output));
+ graph_utils::FindGraphNodeWithOp("ShuffleAndRepeatDataset", output));
EXPECT_EQ(shuffle_and_repeat_node.input_size(), 5);
EXPECT_EQ(shuffle_and_repeat_node.input(0), shuffle_node->input(0));
EXPECT_EQ(shuffle_and_repeat_node.input(1), shuffle_node->input(1));
diff --git a/tensorflow/core/grappler/optimizers/evaluation_utils.cc b/tensorflow/core/grappler/optimizers/evaluation_utils.cc
index 00ad7494f4..79d9ea1608 100644
--- a/tensorflow/core/grappler/optimizers/evaluation_utils.cc
+++ b/tensorflow/core/grappler/optimizers/evaluation_utils.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/denormal.h"
diff --git a/tensorflow/core/grappler/optimizers/evaluation_utils.h b/tensorflow/core/grappler/optimizers/evaluation_utils.h
index 8414b5b8ca..c9dfb6dc0b 100644
--- a/tensorflow/core/grappler/optimizers/evaluation_utils.h
+++ b/tensorflow/core/grappler/optimizers/evaluation_utils.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
namespace Eigen {
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc
new file mode 100644
index 0000000000..eeea269fb0
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc
@@ -0,0 +1,93 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h"
+
+#include <string>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/function_api_info.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+
+REGISTER_GRAPH_OPTIMIZER(ExperimentalImplementationSelector);
+
+Status ExperimentalImplementationSelector::LoadFunctions(
+ const GraphDef& graph) {
+ lib_info_.reset(new FunctionLibraryApiInfo);
+ TF_RETURN_IF_ERROR(lib_info_->Init(graph.library()));
+ return Status::OK();
+}
+
+Status ExperimentalImplementationSelector::MaybeOptimizeFunctionCall(
+ NodeDef* node_def) const {
+ const FunctionApiInfo* info = lib_info_->GetApiInfo(node_def->op());
+ if (info == nullptr) {
+ // A regular op, or a function which has no interface.
+ return Status::OK();
+ }
+
+ string task, device;
+ if (!DeviceNameUtils::SplitDeviceName(node_def->device(), &task, &device)) {
+ return errors::Internal("Could not split device name: ", node_def->device());
+ }
+ VLOG(2) << "Op " << node_def->name() << " runs on " << node_def->device()
+ << " = (" << task << ", " << device << ")";
+ DeviceNameUtils::ParsedName parsed_name;
+ DeviceNameUtils::ParseLocalName(device, &parsed_name);
+
+ string best_function_name;
+ lib_info_->GetBestImplementation(node_def->op(), parsed_name.type,
+ &best_function_name);
+ if (node_def->op() != best_function_name) {
+ // The current implementation is not the best; swap the op to the best one.
+ // The duplicates left in the graph will be pruned by other grappler
+ // passes, since no other node uses their outputs as inputs.
+ // TODO(scottzhu): Update tf.eager.defun to register functions without
+ // having to call them with input data. That will reduce the graph size and
+ // save the work of pruning them.
+ node_def->set_op(best_function_name);
+ }
+ return Status::OK();
+}
+
+Status ExperimentalImplementationSelector::SelectImplementation(
+ GraphDef* graph) const {
+ for (int k = 0; k < graph->node_size(); ++k)
+ TF_RETURN_IF_ERROR(MaybeOptimizeFunctionCall(graph->mutable_node(k)));
+
+ return Status::OK();
+}
+
+Status ExperimentalImplementationSelector::Optimize(Cluster* cluster,
+ const GrapplerItem& item,
+ GraphDef* optimized_graph) {
+ *optimized_graph = item.graph;
+ TF_RETURN_IF_ERROR(LoadFunctions(*optimized_graph));
+ return SelectImplementation(optimized_graph);
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h
new file mode 100644
index 0000000000..82f7473a14
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h
@@ -0,0 +1,115 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/function_api_info.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// -- EXPERIMENTAL --
+// This transformation replaces function calls with the appropriate function
+// definition based on properties of the runtime system. For instance,
+// we may choose one implementation over another if we have a GPU with
+// enough memory available.
+//
+// It is a way for the programmer to specify alternative implementations
+// of the same functionality in the graph and to let TensorFlow pick the
+// most appropriate one at runtime.
+//
+// For instance, the python code might specify:
+// @Defun(tf.float32,
+// experimental_api_implements='plus_one',
+// experimental_api_preferred_device='GPU')
+// def plus_one_gpu(x): return x + 1.0
+//
+// @Defun(tf.float32,
+// experimental_api_implements='plus_one')
+// def plus_one_reference_implementation(x): return x + 1.0
+//
+// input = tf.constant(2.0, dtype=tf.float32)
+//
+// z = plus_one_reference_implementation(input)
+// z = plus_one_gpu(input)
+// print(sess.run(z))
+//
+// At runtime, we will trim either `plus_one_gpu` or
+// `plus_one_reference_implementation` based on the availability of the GPU.
+//
+// Available annotations:
+// - experimental_api_implements(string): all functions mapping to the same
+// string can be interchanged. For now, all functions must have the same
+// signature and overloads are not allowed. Defuns within defuns are
+// allowed.
+// - experimental_api_preferred_device(string): sets which device is preferred.
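+//
+// A minimal sketch of enabling this optimizer through RewriterConfig (the
+// same pattern used in meta_optimizer_test.cc; the session wiring is
+// assumed, not shown):
+//   RewriterConfig rewriter_config;
+//   rewriter_config.add_custom_optimizers()->set_name(
+//       "ExperimentalImplementationSelector");
+//   // Pass rewriter_config to the session through its GraphOptions.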
+class ExperimentalImplementationSelector : public CustomGraphOptimizer {
+ public:
+ ExperimentalImplementationSelector() = default;
+ ~ExperimentalImplementationSelector() override = default;
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+ string name() const override {
+ return "experimental_implementation_selector";
+ }
+
+ // This call is not thread-safe.
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) override;
+
+ // Does not take any feedback.
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) override {}
+
+ private:
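+ // Parses the function library of `graph` into lib_info_ so that call
+ // sites can later be matched against annotated implementations.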
+ Status LoadFunctions(const GraphDef& graph);
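+ // If `node_def` calls a function that implements a known interface,
+ // rewrites its op to the implementation preferred for the node's device.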
+ Status MaybeOptimizeFunctionCall(NodeDef* node_def) const;
+
+ // Finds all call sites for functions, then replaces them with the
+ // appropriate implementation.
+ // There are two ways of calling functions:
+ // 1. By specifying an op name as a function name, and
+ // 2. Via the functional interface, where the function name appears as an
+ // Attr.
+ //
+ // There may be multiple call sites for a given function. The function body
+ // may call into another function, so a function might have to be duplicated.
+ // For simplicity, we do not change function bodies. Also, we do not change
+ // gradients.
+ Status SelectImplementation(GraphDef* graph) const;
+
+ std::unique_ptr<FunctionLibraryApiInfo> lib_info_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ExperimentalImplementationSelector);
+};
+
+} // namespace grappler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc
new file mode 100644
index 0000000000..2368e577c2
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc
@@ -0,0 +1,139 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+constexpr char CpuDevice[] = "/device:CPU:0";
+constexpr char GpuDevice[] = "/device:GPU:0";
+
+class ExperimentalImplementationSelectorTest : public GrapplerTest {};
+
+TEST_F(ExperimentalImplementationSelectorTest, NoUpdate) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {CpuDevice});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ std::unique_ptr<CustomGraphOptimizer> optimizer =
+ CustomGraphOptimizerRegistry::CreateByNameOrNull(
+ "ExperimentalImplementationSelector");
+ ASSERT_NE(nullptr, optimizer);
+ TF_ASSERT_OK(optimizer->Init());
+
+ GraphDef output;
+ const Status status = optimizer->Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ // This is a trivial graph, so there is nothing to update.
+ EXPECT_EQ(item.graph.node_size(), output.node_size());
+}
+
+TEST_F(ExperimentalImplementationSelectorTest, SwapImplementation) {
+ using test::function::NDef;
+ auto cpu_def = test::function::XTimesTwo();
+ auto* func_attr = cpu_def.mutable_attr();
+ (*func_attr)["experimental_api_implements"].set_s("times_two");
+ (*func_attr)["experimental_api_preferred_device"].set_s("CPU");
+
+ auto gpu_def = test::function::XAddX();
+ auto* func2_attr = gpu_def.mutable_attr();
+ (*func2_attr)["experimental_api_implements"].set_s("times_two");
+ (*func2_attr)["experimental_api_preferred_device"].set_s("GPU");
+
+ ExperimentalImplementationSelector optimizer;
+ GraphDef output;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, GpuDevice),
+ NDef("y1", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
+ NDef("z1", "Identity", {"y1"}, {{"T", DT_FLOAT}}, GpuDevice),
+ NDef("y2", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, CpuDevice),
+ NDef("z2", "Identity", {"y2"}, {{"T", DT_FLOAT}}, CpuDevice)},
+ // FunctionLib
+ {cpu_def, gpu_def});
+
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(output.node_size(), 5);
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "y1") {
+ // Make sure the implementation has been swapped to use the GPU version.
+ EXPECT_EQ("XAddX", node.op());
+ } else if (node.name() == "y2") {
+ // Make sure the implementation is not changed.
+ EXPECT_EQ("XTimesTwo", node.op());
+ }
+ }
+}
+
+TEST_F(ExperimentalImplementationSelectorTest, SwapImplementationEval) {
+ using test::function::NDef;
+ auto cpu_def = test::function::XTimesTwo();
+ auto* func_attr = cpu_def.mutable_attr();
+ (*func_attr)["experimental_api_implements"].set_s("random_boost");
+ (*func_attr)["experimental_api_preferred_device"].set_s("CPU");
+
+ auto gpu_def = test::function::XTimesFour();
+ auto* func2_attr = gpu_def.mutable_attr();
+ (*func2_attr)["experimental_api_implements"].set_s("random_boost");
+ (*func2_attr)["experimental_api_preferred_device"].set_s("GPU");
+
+ ExperimentalImplementationSelector optimizer;
+ GraphDef output;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, CpuDevice),
+ NDef("y", "XTimesFour", {"x"}, {{"T", DT_FLOAT}}, CpuDevice),
+ NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, CpuDevice)},
+ // FunctionLib
+ {cpu_def, gpu_def});
+
+ const Tensor input = test::AsScalar<float>(1.0f);
+ item.fetch = {"z"};
+ item.feed.emplace_back("x", input);
+
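+ // Before optimization the graph calls XTimesFour, so fetching z yields 4.0.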
+ const auto four_times_boosted_tensor = EvaluateFetchNodes(item);
+ test::ExpectTensorEqual<float>(four_times_boosted_tensor[0],
+ test::AsScalar<float>(4.0f));
+
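+ // After optimization the call is swapped to the CPU-preferred XTimesTwo,
+ // so the same fetch yields 2.0.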
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+ GrapplerItem optimized(item, std::move(output));
+ const auto twice_boosted_tensor = EvaluateFetchNodes(optimized);
+ test::ExpectTensorEqual<float>(twice_boosted_tensor[0],
+ test::AsScalar<float>(2.0f));
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/function_api_info.cc b/tensorflow/core/grappler/optimizers/function_api_info.cc
new file mode 100644
index 0000000000..798e0f6fd5
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/function_api_info.cc
@@ -0,0 +1,167 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/function_api_info.h"
+
+#include <string>
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+FunctionApiInfo::FunctionApiInfo() {}
+FunctionApiInfo::~FunctionApiInfo() {}
+
+Status FunctionApiInfo::Init(const FunctionDef& function_def) {
+ for (const auto& attr : function_def.attr()) {
+ if (attr.first == "experimental_api_preferred_device") {
+ preferred_device_ = attr.second.s();
+ }
+ if (attr.first == "experimental_api_implements") {
+ interface_name_ = attr.second.s();
+ }
+ }
+ if (interface_name_.empty() && !preferred_device_.empty()) {
+ return errors::InvalidArgument(
+ "Function '", function_def.signature().name(),
+ "' has a preferred device, but does not implement an interface");
+ }
+ return Status::OK();
+}
+
+const string& FunctionApiInfo::preferred_device() const {
+ return preferred_device_;
+}
+
+const string& FunctionApiInfo::interface_name() const {
+ return interface_name_;
+}
+
+FunctionLibraryApiInfo::FunctionLibraryApiInfo() {}
+FunctionLibraryApiInfo::~FunctionLibraryApiInfo() {}
+
+namespace {
+bool IsSameSignature(const FunctionDef& f1, const FunctionDef& f2) {
+ if (f1.ret().size() != f2.ret().size()) return false;
+ const auto& sig1 = f1.signature();
+ const auto& sig2 = f2.signature();
+ // Functions have positional semantics, so we don't check for names.
+ if (sig1.input_arg_size() != sig2.input_arg_size()) return false;
+ for (int k = 0; k < sig1.input_arg_size(); ++k) {
+ const OpDef::ArgDef& arg1 = sig1.input_arg(k);
+ const OpDef::ArgDef& arg2 = sig2.input_arg(k);
+ if (arg1.type() != arg2.type()) return false;
+ if (arg1.type_attr() != arg2.type_attr()) return false;
+ if (arg1.number_attr() != arg2.number_attr()) return false;
+ if (arg1.type_list_attr() != arg2.type_list_attr()) return false;
+ if (arg1.is_ref() != arg2.is_ref()) return false;
+ }
+ return true;
+}
+
+Status ValidateSignature(const string& interface_name,
+ const std::vector<const FunctionDef*>& equiv_funcs) {
+ if (equiv_funcs.size() < 2) return Status::OK();
+ for (size_t k = 1; k < equiv_funcs.size(); ++k) {
+ if (!IsSameSignature(*equiv_funcs[0], *equiv_funcs[k]))
+ return errors::InvalidArgument(
+ "Functions '", equiv_funcs[0]->signature().name(), "' and '",
+ equiv_funcs[k]->signature().name(), "' both implement '",
+ interface_name, "' but their signatures do not match.");
+ }
+ return Status::OK();
+}
+
+Status ValidateSignatures(
+ const std::unordered_map<string, std::vector<const FunctionDef*>>&
+ intf_to_func) {
+ for (const auto& item : intf_to_func)
+ TF_RETURN_IF_ERROR(ValidateSignature(item.first, item.second));
+ return Status::OK();
+}
+} // namespace
+
+Status FunctionLibraryApiInfo::Init(
+ const FunctionDefLibrary& function_library) {
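+ // Temporary map from interface name to function definitions, used only to
+ // validate that all implementations of an interface share one signature.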
+ std::unordered_map<string, std::vector<const FunctionDef*>> intf_to_func;
+ for (const auto& function : function_library.function()) {
+ std::unique_ptr<FunctionApiInfo> func_info(new FunctionApiInfo);
+ TF_RETURN_IF_ERROR(func_info->Init(function));
+ // Ignore the function if it does not implement any interface.
+ if (func_info->interface_name().empty()) continue;
+
+ const string& function_name = function.signature().name();
+ const string& interface_name = func_info->interface_name();
+ func_to_intf_[function_name] = interface_name;
+ intf_to_funcs_[interface_name].emplace_back(function_name);
+ intf_to_func[interface_name].emplace_back(&function);
+ func_info_[function_name] = std::move(func_info);
+ }
+ TF_RETURN_IF_ERROR(ValidateSignatures(intf_to_func));
+ return Status::OK();
+}
+
+void FunctionLibraryApiInfo::GetEquivalentImplementations(
+ const string& function_name, std::vector<string>* other_names) const {
+ const auto intf_it = func_to_intf_.find(function_name);
+ // The function does not implement any interface.
+ if (intf_it == func_to_intf_.end()) return;
+ CHECK(!intf_it->second.empty()) << "Function " << function_name
+ << " should implement at least one interface.";
+ const auto it = intf_to_funcs_.find(intf_it->second);
+ CHECK(it != intf_to_funcs_.end())
+ << "Function " << function_name << " maps to " << intf_it->second
+ << " but no reverse mapping was found";
+ CHECK_GE(it->second.size(), 1) << "Class " << it->first << " is empty";
+ other_names->reserve(it->second.size() - 1);
+ for (const auto& other_name : it->second) {
+ if (other_name == function_name) continue;
+ other_names->emplace_back(other_name);
+ }
+}
+
+void FunctionLibraryApiInfo::GetBestImplementation(
+ const string& function_name, const string& device,
+ string* best_func_name) const {
+ CHECK(best_func_name != nullptr);
+ const auto func_it = func_to_intf_.find(function_name);
+ if (func_it == func_to_intf_.end()) return;
+
+ const auto it = intf_to_funcs_.find(func_it->second);
+ // No function found for the given interface.
+ if (it == intf_to_funcs_.end()) return;
+ for (const auto& func_name : it->second) {
+ const auto func_api_info = func_info_.find(func_name)->second.get();
+ if (func_api_info->preferred_device() == device) {
+ best_func_name->assign(func_name);
+ return;
+ }
+ }
+ // No function matched the preferred device name; choose the first one
+ // among all the available functions.
+ best_func_name->assign(it->second.front());
+}
+
+const FunctionApiInfo* FunctionLibraryApiInfo::GetApiInfo(
+ const string& function_name) const {
+ const auto it = func_info_.find(function_name);
+ if (it == func_info_.end()) return nullptr;
+ return it->second.get();
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/function_api_info.h b/tensorflow/core/grappler/optimizers/function_api_info.h
new file mode 100644
index 0000000000..412687c58c
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/function_api_info.h
@@ -0,0 +1,80 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+class FunctionApiInfo {
+ public:
+ FunctionApiInfo();
+ virtual ~FunctionApiInfo();
+
+ Status Init(const FunctionDef& function_def);
+
+ const string& interface_name() const;
+ const string& preferred_device() const;
+
+ private:
+ string interface_name_;
+ string preferred_device_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FunctionApiInfo);
+};
+
+// A collection of information about functions and the interfaces they
+// implement. An interface is a well-defined math operation, e.g.
+// I1 = 2 * x + y. Multiple functions can implement the same interface with
+// different behavior based on hardware conditions and limits, e.g.
+// F1 = math_ops.add(math_ops.add(x, x), y), or
+// F2 = math_ops.add(math_ops.matmul(x, 2), y).
+class FunctionLibraryApiInfo {
+ public:
+ FunctionLibraryApiInfo();
+ virtual ~FunctionLibraryApiInfo();
+ // Populates the internal fields for the functions within function_library.
+ Status Init(const FunctionDefLibrary& function_library);
+
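+ // Fills `other_names` with all other functions that implement the same
+ // interface as `function_name`, excluding `function_name` itself.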
+ void GetEquivalentImplementations(const string& function_name,
+ std::vector<string>* other_names) const;
+
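+ // Writes the implementation of `function_name`'s interface whose preferred
+ // device matches `device` into `best_func_name`, falling back to the first
+ // registered implementation when no preferred device matches. Leaves the
+ // output unchanged if `function_name` implements no interface.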
+ void GetBestImplementation(const string& function_name, const string& device,
+ string* best_func_name) const;
+
+ const FunctionApiInfo* GetApiInfo(const string& function_name) const;
+
+ private:
+ // Map from function name to function details.
+ std::unordered_map<string, std::unique_ptr<FunctionApiInfo>> func_info_;
+ // Map from function name to interface name.
+ std::unordered_map<string, string> func_to_intf_;
+ // Map from interface name to function names.
+ std::unordered_map<string, std::vector<string>> intf_to_funcs_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryApiInfo);
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_
diff --git a/tensorflow/core/grappler/optimizers/function_api_info_test.cc b/tensorflow/core/grappler/optimizers/function_api_info_test.cc
new file mode 100644
index 0000000000..582890d3e3
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/function_api_info_test.cc
@@ -0,0 +1,160 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/function_api_info.h"
+
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+void SetArg(const string& name, const string& type_name,
+ OpDef::ArgDef* arg_def) {
+ arg_def->set_name(name);
+ arg_def->set_type_attr(type_name);
+}
+
+typedef std::pair<string, string> ArgSpec; // name, type.
+
+void SetArgs(const std::vector<ArgSpec>& args_spec, OpDef* sig) {
+ for (const auto& arg_spec : args_spec)
+ SetArg(arg_spec.first, arg_spec.second, sig->add_input_arg());
+ SetArg("output", "float32", sig->add_output_arg());
+}
+
+void PopulateFunction(const string& name, const string& api_interface_name,
+ const string& preferred_device,
+ const std::vector<ArgSpec>& input_args,
+ FunctionDef* func_def) {
+ OpDef* sig = func_def->mutable_signature();
+ sig->set_name(name);
+
+ SetArgs(input_args, sig);
+
+ if (!api_interface_name.empty() || !preferred_device.empty()) {
+ auto* func_attr = func_def->mutable_attr();
+ if (!api_interface_name.empty())
+ (*func_attr)["experimental_api_implements"].set_s(api_interface_name);
+ if (!preferred_device.empty())
+ (*func_attr)["experimental_api_preferred_device"].set_s(preferred_device);
+ }
+}
+
+void PopulateSampleLibrary(const bool mismatch_args,
+ FunctionDefLibrary* func_lib) {
+ const std::vector<ArgSpec> func_args{{"in1", "float32"}, {"in2", "int32"}};
+ const std::vector<ArgSpec> func_wrong_args{{"in1", "int32"},
+ {"in2", "int32"}};
+ PopulateFunction("DoStuffCpu", "DoStuff", "CPU", func_args,
+ func_lib->add_function());
+ PopulateFunction("DoStuffGpu", "DoStuff", "GPU",
+ mismatch_args ? func_wrong_args : func_args,
+ func_lib->add_function());
+ PopulateFunction("DoThings", "DoThings", "", func_args,
+ func_lib->add_function());
+ PopulateFunction("OneOff", "", "", func_args, func_lib->add_function());
+ PopulateFunction("AnotherOneOff", "", "", func_args,
+ func_lib->add_function());
+}
+
+bool CheckEquivImpl(const FunctionLibraryApiInfo& lib_api_info,
+ const string& func_name,
+ const std::vector<string>& expected_other) {
+ std::vector<string> other_impl;
+ lib_api_info.GetEquivalentImplementations(func_name, &other_impl);
+ const std::unordered_set<string> actual(other_impl.begin(), other_impl.end());
+ const std::unordered_set<string> expected(expected_other.begin(),
+ expected_other.end());
+ return actual == expected;
+}
+
+bool CheckGetBestImpl(const FunctionLibraryApiInfo& lib_api_info,
+ const string& function_name, const string& device,
+ const string& expected_function_name) {
+ string best_function_name;
+ lib_api_info.GetBestImplementation(function_name, device,
+ &best_function_name);
+
+ return best_function_name == expected_function_name;
+}
+
+string GetInterfaceName(const FunctionLibraryApiInfo& lib_api_info,
+ const string& func_name) {
+ auto* info = lib_api_info.GetApiInfo(func_name);
+ CHECK_NOTNULL(info);
+ return info->interface_name();
+}
+
+string GetPreferredDevice(const FunctionLibraryApiInfo& lib_api_info,
+ const string& func_name) {
+ auto* info = lib_api_info.GetApiInfo(func_name);
+ CHECK_NOTNULL(info);
+ return info->preferred_device();
+}
+
+TEST(FunctionApiInfoTest, ParseTags) {
+ FunctionDefLibrary func_lib;
+ PopulateSampleLibrary(/* mismatch_args */ false, &func_lib);
+ FunctionLibraryApiInfo lib_api_info;
+ TF_ASSERT_OK(lib_api_info.Init(func_lib));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffCpu", {"DoStuffGpu"}));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffGpu", {"DoStuffCpu"}));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "Undefined", {}));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "OneOff", {}));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "AnotherOneOff", {}));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoThings", {}));
+
+ EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffCpu"));
+ EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffGpu"));
+ EXPECT_EQ("DoThings", GetInterfaceName(lib_api_info, "DoThings"));
+
+ EXPECT_EQ("CPU", GetPreferredDevice(lib_api_info, "DoStuffCpu"));
+ EXPECT_EQ("GPU", GetPreferredDevice(lib_api_info, "DoStuffGpu"));
+ EXPECT_EQ("", GetPreferredDevice(lib_api_info, "DoThings"));
+
+ EXPECT_TRUE(
+ CheckGetBestImpl(lib_api_info, "DoStuffCpu", "CPU", "DoStuffCpu"));
+ EXPECT_TRUE(
+ CheckGetBestImpl(lib_api_info, "DoStuffCpu", "GPU", "DoStuffGpu"));
+ EXPECT_TRUE(
+ CheckGetBestImpl(lib_api_info, "DoStuffGpu", "CPU", "DoStuffCpu"));
+ EXPECT_TRUE(
+ CheckGetBestImpl(lib_api_info, "DoStuffGpu", "GPU", "DoStuffGpu"));
+
+ EXPECT_TRUE(CheckGetBestImpl(lib_api_info, "DoThings", "GPU", "DoThings"));
+ // A TPU implementation is not available, so choose the first available
+ // one, which is the CPU version.
+ EXPECT_TRUE(
+ CheckGetBestImpl(lib_api_info, "DoStuffGpu", "TPU", "DoStuffCpu"));
+}
+
+TEST(FunctionApiInfoTest, MismatchedArguments) {
+ FunctionDefLibrary func_lib;
+ PopulateSampleLibrary(/* mismatch_args */ true, &func_lib);
+ FunctionLibraryApiInfo lib_api_info;
+ const Status ret = lib_api_info.Init(func_lib);
+ EXPECT_FALSE(ret.ok());
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index 645e4c2087..56364f0095 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -453,6 +453,7 @@ Status InitializeFunctionSpecializationSignature(
}
Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
+ const int graph_def_version,
FunctionOptimizerContext* ctx,
GraphDef* optimized_graph) {
VLOG(2) << "Specialize function instantiation: "
@@ -492,7 +493,8 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
// Make a GrapplerFunctionItem and convert it back to FunctionDef after
// pushing all constant inputs into the function body.
GrapplerFunctionItem item;
- TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, func_attr, flib,
+ graph_def_version, &item));
// Push const inputs into the function body, and keep track of their control
// dependencies.
@@ -576,15 +578,15 @@ NodeDef InlinedFunctionOutputsNode(const NodeDef& func_node,
Status InlineFunction(const NodeDef& func_node, const FunctionDef& func,
const FunctionOptimizerContext& ctx,
- GraphDef* optimized_graph) {
+ const int graph_def_version, GraphDef* optimized_graph) {
VLOG(2) << "Inline function instantiation: " << SummarizeNodeDef(func_node);
const std::unordered_map<string, AttrValue> func_attr(
func_node.attr().begin(), func_node.attr().end());
GrapplerFunctionItem item;
- Status item_status =
- MakeGrapplerFunctionItem(func, func_attr, ctx.function_library(), &item);
+ Status item_status = MakeGrapplerFunctionItem(
+ func, func_attr, ctx.function_library(), graph_def_version, &item);
if (!item_status.ok()) {
return errors::InvalidArgument("Failed to inline function ", func_node.op(),
@@ -645,7 +647,8 @@ Status InlineFunction(const NodeDef& func_node, const FunctionDef& func,
if (func_body_node_func != nullptr) {
// Recursively inline function calls.
TF_RETURN_IF_ERROR(InlineFunction(func_body_node, *func_body_node_func,
- ctx, optimized_graph));
+ ctx, graph_def_version,
+ optimized_graph));
} else {
// Annotate the node with the function attributes.
for (const auto& attr : func.attr()) {
@@ -824,7 +827,8 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
if (inline_func && ctx.IsInlinedFunction(func_name)) {
// Inline function body into the optimized graph.
TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(
- InlineFunction(node, *func, ctx, optimized_graph));
+ InlineFunction(node, *func, ctx, item.graph.versions().producer(),
+ optimized_graph));
continue;
}
@@ -837,7 +841,8 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// TODO(ezhulenev): Specialize function call if input has a known shape.
// Specialize function body for its instantiation attributes and inputs.
TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(
- SpecializeFunction(node, *func, &ctx, optimized_graph));
+ SpecializeFunction(node, *func, item.graph.versions().producer(),
+ &ctx, optimized_graph));
continue;
}
}
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index 1be5f8dcc2..c775a26914 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/graph_memory.h"
@@ -1070,11 +1071,13 @@ static bool IdentifySwappingCandidates(
// ensure that swapping the tensor back in won't recreate the memory
// bottleneck. Last but not least, we want the tensor to have as few
// remaining uses as possible.
+ //
+ // Note that we must perform the arithmetic inexactly as "double", since
+ // the values do not fit into any integral type.
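+ //
+ // Concretely (every term computed as a double):
+ //   fitness = (earliest_use - peak_time)^2 / uses_left^2
+ //             + (allocation_time - peak_time)^2
+ // with the sign flipped afterwards (see below).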
mem_info.fitness =
- MathUtil::IPow((earliest_use - peak_time).count(), 2);
- mem_info.fitness /= MathUtil::IPow(mem_info.uses_left.size(), 2);
- mem_info.fitness +=
- MathUtil::IPow((allocation_time - peak_time).count(), 2);
+ MathUtil::IPow<double>((earliest_use - peak_time).count(), 2) /
+ MathUtil::IPow<double>(mem_info.uses_left.size(), 2) +
+ MathUtil::IPow<double>((allocation_time - peak_time).count(), 2);
mem_info.fitness = -mem_info.fitness;
mem_state.push_back(mem_info);
}
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index e778b7879d..b75d6303b4 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -72,6 +72,16 @@ bool IsRunOnceOptimizer(const string& name) {
name == "loop_optimizer";
}
+// Checks if the GraphDef contains nodes that indicate TPU execution.
+bool IsTPUGraphDef(const GraphDef& def) {
+ for (const auto& node : def.node()) {
+ if (node.op() == "TPUCompile" || node.op() == "TPUPartitionedCall") {
+ return true;
+ }
+ }
+ return false;
+}
+
} // namespace
#define MK_OPT(NAME, VALUE) \
@@ -156,7 +166,7 @@ Status MetaOptimizer::InitializeOptimizers(
optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>(
cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts()));
}
- return Status::OK();
+ return InitializeCustomGraphOptimizers(optimizers);
}
Status MetaOptimizer::InitializeOptimizersByName(
@@ -180,6 +190,11 @@ Status MetaOptimizer::InitializeOptimizersByName(
VLOG(2) << "Can't register an optimizer by name: " << optimizer_name;
}
}
+ return InitializeCustomGraphOptimizers(optimizers);
+}
+
+Status MetaOptimizer::InitializeCustomGraphOptimizers(
+ std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
for (const auto& optimizer_config : cfg_.custom_optimizers()) {
auto custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
optimizer_config.name());
@@ -208,7 +223,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
}
std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
- if (cfg_.optimizers().empty() && cfg_.custom_optimizers().empty()) {
+ if (cfg_.optimizers().empty()) {
TF_RETURN_IF_ERROR(InitializeOptimizers(&optimizers));
} else {
TF_RETURN_IF_ERROR(InitializeOptimizersByName(&optimizers));
@@ -331,6 +346,19 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// 1. Optimize main graph
TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph));
+ // Skip optimizing functions if this is a TPU graph. Currently, Grappler
+ // passes do not handle TPU functions correctly in a variety of ways (note
+ // that due to the pre-placement TPU graph rewriting passes, the TPU-related
+ // ops are encapsulated away into functions). For example, TPU graphs contain
+ // a TPUReplicateMetadata node that carries relevant TPU metadata, which
+ // Grappler passes could prune away. Grappler passes could also cause issues
+ // around shape inference. Since the desired and existing behavior is to not
+ // optimize TPU functions with Grappler, this check preserves that.
+ if (IsTPUGraphDef(*optimized_graph)) {
+ VLOG(2) << "Skipping optimizing funcs for TPU graphs";
+ return Status::OK();
+ }
+
// 2. Optimize function library
FunctionLibraryDefinition flib(OpRegistry::Global(),
optimized_graph->library());
@@ -361,7 +389,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// Make a GrapplerItem from a FunctionDef.
GrapplerFunctionItem func_item;
- TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, flib, &func_item));
+ TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
+ func, flib, item.graph.versions().producer(), &func_item));
// Optimize function body graph.
GraphDef optimized_func_graph;
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index 151a54cbdf..831c5e37c0 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -52,6 +52,9 @@ class MetaOptimizer : public GraphOptimizer {
// Initialize active optimizers from RewriterConfig optimizer names.
Status InitializeOptimizersByName(
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
+ // Initialize active optimizers from RewriterConfig.custom_optimizers.
+ Status InitializeCustomGraphOptimizers(
+ std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
// Run optimization pass over a single GrapplerItem. Meta optimizer might run
// multiple such passes: 1) for the main graph 2) for the function library
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index 9a03c7dfef..e74e0f7501 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -64,6 +64,13 @@ bool TestOptimizer::optimized_;
REGISTER_GRAPH_OPTIMIZER(TestOptimizer);
+class TestGraphOptimizer : public TestOptimizer {
+ public:
+ string name() const override { return "test_graph_optimizer"; }
+};
+
+REGISTER_GRAPH_OPTIMIZER(TestGraphOptimizer);
+
class MetaOptimizerTest : public GrapplerTest {};
TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
@@ -83,6 +90,27 @@ TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
EXPECT_TRUE(TestOptimizer::IsOptimized());
}
+TEST_F(MetaOptimizerTest, RunsCustomOptimizerAndCustomGraphOptimizer) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ TestOptimizer::SetOptimized(false);
+ TestGraphOptimizer::SetOptimized(false);
+ RewriterConfig rewriter_config;
+ rewriter_config.add_optimizers("TestOptimizer");
+ auto customGraphOptimizer = rewriter_config.add_custom_optimizers();
+ customGraphOptimizer->set_name("TestGraphOptimizer");
+ rewriter_config.set_min_graph_nodes(-1);
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+ GraphDef output;
+ const Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ EXPECT_TRUE(TestOptimizer::IsOptimized());
+ EXPECT_TRUE(TestGraphOptimizer::IsOptimized());
+}
+
TEST_F(MetaOptimizerTest, RunOptimizersTwice) {
TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
GrapplerItem item;
@@ -98,6 +126,24 @@ TEST_F(MetaOptimizerTest, RunOptimizersTwice) {
TF_EXPECT_OK(status);
}
+TEST_F(MetaOptimizerTest, RunToggleOptimizersAndCustomGraphOptimizerTwice) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ RewriterConfig rewriter_config;
+ auto customGraphOptimizer = rewriter_config.add_custom_optimizers();
+ customGraphOptimizer->set_name("TestGraphOptimizer");
+ rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
+ rewriter_config.set_min_graph_nodes(-1);
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+ GraphDef output;
+ const Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ EXPECT_TRUE(TestGraphOptimizer::IsOptimized());
+}
+
TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
using test::function::NDef;
diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc
index 275568e464..0d4aaf6462 100644
--- a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc
@@ -203,7 +203,7 @@ void ScopedAllocatorOptimizer::ExtendNodeAttr(StringPiece name,
NodeDef* node_def) {
if (HasNodeAttr(*node_def, name)) {
VLOG(2) << "extending";
- AttrValue* existing = &(*node_def->mutable_attr())[name.ToString()];
+ AttrValue* existing = &(*node_def->mutable_attr())[string(name)];
for (int32 i : values) {
existing->mutable_list()->add_i(i);
}
diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc
index 89847f83d4..b033cff8e6 100644
--- a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/testlib.h"
diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
index 26c54df56b..caa0b7b0cb 100644
--- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/graph_view.h"
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index a9c34b6d08..20dbeea2cf 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -139,7 +139,7 @@ inline StringPiece ParseNodeNameAsStringPiece(const string& name,
// Returns the node name and position in a single call.
inline string ParseNodeName(const string& name, int* position) {
- return std::string(ParseNodeNameAsStringPiece(name, position));
+ return string(ParseNodeNameAsStringPiece(name, position));
}
// Add a prefix to a node name with a custom delimiter.
diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc
index 462b752316..a428aea7f5 100644
--- a/tensorflow/core/grappler/utils/functions.cc
+++ b/tensorflow/core/grappler/utils/functions.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -303,21 +304,22 @@ Status GrapplerFunctionItemInstantiation::GetArgType(
}
GrapplerFunctionItem::GrapplerFunctionItem(
- const string& func_name, const string& description,
- const AttrValueMap& func_attr,
- const std::vector<InputArgExpansion>& input_arg_expansions,
- const std::vector<OutputArgExpansion>& output_arg_expansions,
- const std::vector<string>& keep_nodes, bool is_stateful,
- GraphDef&& function_body)
- : description_(description),
- func_attr_(func_attr),
- input_arg_expansions_(input_arg_expansions),
- output_arg_expansions_(output_arg_expansions),
+ string func_name, string description, AttrValueMap func_attr,
+ std::vector<InputArgExpansion> input_arg_expansions,
+ std::vector<OutputArgExpansion> output_arg_expansions,
+ std::vector<string> keep_nodes, const int graph_def_version,
+ const bool is_stateful, GraphDef&& function_body)
+ : description_(std::move(description)),
+ func_attr_(std::move(func_attr)),
+ input_arg_expansions_(std::move(input_arg_expansions)),
+ output_arg_expansions_(std::move(output_arg_expansions)),
is_stateful_(is_stateful) {
- id = func_name;
- keep_ops = keep_nodes;
- // Swap the graph body.
- graph.Swap(&function_body);
+ // Move-assign GrapplerItem members.
+ keep_ops = std::move(keep_nodes);
+ id = std::move(func_name);
+ graph = std::move(function_body);
+
+ graph.mutable_versions()->set_producer(graph_def_version);
// Fill the feed nodes with input placeholders.
for (const InputArgExpansion& input_arg : input_arg_expansions_) {
for (const string& placeholder : input_arg.placeholders) {
@@ -472,6 +474,7 @@ Status InstantiationBodyParameters(
Status MakeGrapplerFunctionItem(const FunctionDef& func,
const AttrValueMap& func_instantiation_attr,
const FunctionLibraryDefinition& flib,
+ const int graph_def_version,
GrapplerFunctionItem* item) {
const OpDef& signature = func.signature();
@@ -595,14 +598,17 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
*item = GrapplerFunctionItem(
/*func_name=*/signature.name(), /*description=*/signature.description(),
/*func_attr=*/AttrValueMap(func.attr().begin(), func.attr().end()),
- inputs, outputs, keep_nodes, is_stateful, std::move(function_body));
+ std::move(inputs), std::move(outputs), std::move(keep_nodes),
+ graph_def_version, is_stateful, std::move(function_body));
return Status::OK();
}
Status MakeGrapplerFunctionItem(const FunctionDef& func,
const FunctionLibraryDefinition& flib,
+ const int graph_def_version,
GrapplerFunctionItem* item) {
- return MakeGrapplerFunctionItem(func, AttrValueMap(), flib, item);
+ return MakeGrapplerFunctionItem(func, AttrValueMap(), flib, graph_def_version,
+ item);
}
// Register GrapplerFunctionItem input arg expansion and function body outputs
diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h
index 9f607dc2ee..733caf325f 100644
--- a/tensorflow/core/grappler/utils/functions.h
+++ b/tensorflow/core/grappler/utils/functions.h
@@ -136,13 +136,12 @@ class GrapplerFunctionItemInstantiation {
class GrapplerFunctionItem : public GrapplerItem {
public:
GrapplerFunctionItem() = default;
- GrapplerFunctionItem(
- const string& func_name, const string& description,
- const AttrValueMap& func_attr,
- const std::vector<InputArgExpansion>& input_arg_expansions,
- const std::vector<OutputArgExpansion>& output_arg_expansions,
- const std::vector<string>& keep_nodes, bool is_stateful,
- GraphDef&& function_body);
+ GrapplerFunctionItem(string func_name, string description,
+ AttrValueMap func_attr,
+ std::vector<InputArgExpansion> input_arg_expansions,
+ std::vector<OutputArgExpansion> output_arg_expansions,
+ std::vector<string> keep_nodes, int graph_def_version,
+ bool is_stateful, GraphDef&& function_body);
const string& description() const;
@@ -222,6 +221,7 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
Status MakeGrapplerFunctionItem(const FunctionDef& func,
const AttrValueMap& func_instantiation_attr,
const FunctionLibraryDefinition& flib,
+ const int graph_def_version,
GrapplerFunctionItem* item);
// Make a GrapplerFunction item from the function definition. Function must be
@@ -231,6 +231,7 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
// without specializing it to its instantiation attributes (at least types)?
Status MakeGrapplerFunctionItem(const FunctionDef& func,
const FunctionLibraryDefinition& flib,
+ const int graph_def_version,
GrapplerFunctionItem* item);
// Make a FunctionDef from the GrapplerFunctionItem. Use function library
diff --git a/tensorflow/core/grappler/utils/functions_test.cc b/tensorflow/core/grappler/utils/functions_test.cc
index b2d059e0ac..b51f2781b8 100644
--- a/tensorflow/core/grappler/utils/functions_test.cc
+++ b/tensorflow/core/grappler/utils/functions_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace grappler {
@@ -239,7 +240,8 @@ TEST_F(FunctionsTest, FromSimpleFunctionDef) {
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
EXPECT_EQ("XTimesTwo", item.id);
EXPECT_EQ(4, item.function_body().node_size());
@@ -314,7 +316,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithMultiOutputNodes) {
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
EXPECT_EQ("SubGrad", item.id);
EXPECT_EQ(12, item.function_body().node_size());
@@ -395,7 +398,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithNestedFuncs) {
func_attr["T"].set_type(DT_FLOAT);
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
int count = 0;
for (const NodeDef &node : item.function_body().node()) {
@@ -456,7 +460,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithOutputMappings) {
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
EXPECT_EQ(1, item.output_size());
EXPECT_EQ("Exp", item.output(0).output_tensors[0]);
@@ -499,7 +504,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithInputForwarding) {
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
EXPECT_EQ("ForwardInputs", item.id);
EXPECT_EQ(5, item.function_body().node_size());
@@ -545,7 +551,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithoutInput) {
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
EXPECT_EQ(0, item.input_size());
EXPECT_EQ(1, item.output_size());
@@ -584,7 +591,8 @@ TEST_F(FunctionsTest, MakeFunctionDef) {
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
FunctionDef specialized;
TF_EXPECT_OK(MakeFunctionDef(item, flib, &specialized));
@@ -622,7 +630,8 @@ TEST_F(FunctionsTest, ReplaceInputWithConst) {
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
EXPECT_EQ(2, item.input_size());
EXPECT_EQ(1, item.output_size());
@@ -713,7 +722,8 @@ TEST_F(FunctionsTest, SwapFunctionBodyAndMakeFunctionDef) {
FunctionLibraryDefinition flib(OpRegistry::Global(), lib_def);
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
// Replace function body with identity function
item.SwapFunctionBody(std::move(id_func_body));
@@ -754,7 +764,8 @@ TEST_F(FunctionsTest, FunctionDefGrapplerFunctionItemRoundTrip) {
GrapplerFunctionItem item;
std::unordered_map<string, AttrValue> func_attr;
func_attr["T"].set_type(DT_INT32);
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
FunctionDef func2;
TF_EXPECT_OK(MakeFunctionDef(item, flib, &func2));
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index c8523d3201..c3c6013d83 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -495,16 +495,6 @@ cc_library(
],
)
-cc_library(
- name = "warn_about_ints",
- srcs = ["warn_about_ints.cc"],
- hdrs = ["warn_about_ints.h"],
- deps = [
- "//tensorflow/core:framework",
- "//tensorflow/core:protos_all_cc",
- ],
-)
-
# Private support libraries ---------------------------------------------------
cc_header_only_library(
@@ -653,14 +643,7 @@ cc_library(
":split_v_op",
":strided_slice_op",
":tile_ops",
- ] + if_mkl(
- [
- ":mkl_transpose_op",
- ],
- [
- ":transpose_op",
- ],
- ) + [
+ ":transpose_op",
":unique_op",
":unpack_op",
":unravel_index_op",
@@ -903,24 +886,13 @@ tf_kernel_library(
deps = ARRAY_DEPS,
)
-if_mkl(
- [tf_mkl_kernel_library(
- name = "mkl_transpose_op",
- srcs = [
- "mkl_transpose_op.cc",
- "transpose_op.cc",
- ],
- hdrs = ["transpose_op.h"],
- deps = ARRAY_DEPS + mkl_deps(),
- )],
- [tf_kernel_library(
- name = "transpose_op",
- srcs = [
- "transpose_op.cc",
- ],
- hdrs = ["transpose_op.h"],
- deps = ARRAY_DEPS,
- )],
+tf_kernel_library(
+ name = "transpose_op",
+ srcs = [
+ "transpose_op.cc",
+ ],
+ hdrs = ["transpose_op.h"],
+ deps = ARRAY_DEPS + if_mkl([":mkl_transpose_op"]),
)
tf_kernel_library(
@@ -1290,6 +1262,7 @@ tf_cuda_cc_test(
srcs = ["gather_op_test.cc"],
deps = [
":gather_op",
+ ":host_constant_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
@@ -1308,6 +1281,7 @@ tf_cuda_cc_test(
srcs = ["gather_nd_op_test.cc"],
deps = [
":gather_nd_op",
+ ":host_constant_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
@@ -2304,6 +2278,31 @@ tf_cc_tests(
],
)
+cc_library(
+ name = "eigen_benchmark",
+ testonly = 1,
+ hdrs = [
+ "eigen_benchmark.h",
+ ":eigen_helpers",
+ ],
+ deps = [
+ "//tensorflow/core:framework",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_cc_test(
+ name = "eigen_benchmark_cpu_test",
+ srcs = ["eigen_benchmark_cpu_test.cc"],
+ deps = [
+ ":eigen_benchmark",
+ ":eigen_helpers",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//third_party/eigen3",
+ ],
+)
+
tf_cc_tests(
name = "basic_ops_benchmark_test",
size = "small",
@@ -3176,6 +3175,7 @@ tf_cuda_cc_test(
"//conditions:default": [],
}),
deps = [
+ ":host_constant_op",
":ops_testutil",
":ops_util",
":reduction_ops",
@@ -3311,6 +3311,7 @@ tf_cuda_cc_test(
srcs = ["diag_op_test.cc"],
deps = [
":diag_op",
+ ":host_constant_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
@@ -3531,13 +3532,13 @@ tf_kernel_library(
tf_kernel_library(
name = "softplus_op",
prefix = "softplus_op",
- deps = NN_DEPS + [":warn_about_ints"],
+ deps = NN_DEPS,
)
tf_kernel_library(
name = "softsign_op",
prefix = "softsign_op",
- deps = NN_DEPS + [":warn_about_ints"],
+ deps = NN_DEPS,
)
tf_kernel_library(
@@ -3638,6 +3639,7 @@ tf_cuda_cc_test(
name = "nn_ops_test",
srcs = ["nn_ops_test.cc"],
deps = [
+ ":host_constant_op",
":nn",
":ops_testutil",
":ops_util",
@@ -3771,7 +3773,7 @@ tf_kernel_library(
"spacetobatch_functor.h",
"spacetobatch_functor_gpu.cu.cc",
],
- visibility = ["//visibility:private"],
+ visibility = [":friends"],
deps = [
":bounds_check",
"//tensorflow/core:framework",
@@ -3785,6 +3787,7 @@ tf_cuda_cc_test(
srcs = ["spacetobatch_benchmark_test.cc"],
deps = [
":batch_space_ops",
+ ":host_constant_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
@@ -3924,6 +3927,7 @@ tf_cuda_cc_test(
size = "small",
srcs = ["random_op_test.cc"],
deps = [
+ ":host_constant_op",
":random_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
@@ -4178,6 +4182,7 @@ tf_cuda_cc_tests(
"sparse_xent_op_test.cc",
],
deps = [
+ ":host_constant_op",
":ops_testutil",
":ops_util",
":sparse",
@@ -4198,6 +4203,7 @@ cc_library(
"hinge-loss.h",
"logistic-loss.h",
"loss.h",
+ "poisson-loss.h",
"smooth-hinge-loss.h",
"squared-loss.h",
],
@@ -4444,12 +4450,48 @@ tf_kernel_library(
deps = STRING_DEPS + ["@com_googlesource_code_re2//:re2"],
)
+tf_cc_test(
+ name = "regex_replace_op_test",
+ size = "small",
+ srcs = ["regex_replace_op_test.cc"],
+ deps = [
+ ":regex_replace_op",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
tf_kernel_library(
name = "string_split_op",
prefix = "string_split_op",
deps = STRING_DEPS,
)
+tf_cc_test(
+ name = "string_split_op_test",
+ size = "small",
+ srcs = ["string_split_op_test.cc"],
+ deps = [
+ ":string_split_op",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
tf_kernel_library(
name = "string_strip_op",
prefix = "string_strip_op",
@@ -4523,6 +4565,7 @@ tf_cuda_cc_test(
size = "small",
srcs = ["multinomial_op_test.cc"],
deps = [
+ ":host_constant_op",
":multinomial_op",
":ops_util",
"//tensorflow/core:core_cpu",
@@ -4550,6 +4593,7 @@ tf_cuda_cc_test(
size = "small",
srcs = ["parameterized_truncated_normal_op_test.cc"],
deps = [
+ ":host_constant_op",
":ops_util",
":parameterized_truncated_normal_op",
"//tensorflow/core:core_cpu",
@@ -5059,7 +5103,6 @@ filegroup(
"training_ops.h",
"transpose_functor.h",
"transpose_op.h",
- "warn_about_ints.h",
"where_op.h",
"xent_op.h",
],
@@ -5141,6 +5184,7 @@ filegroup(
"fifo_queue.cc",
"fifo_queue_op.cc",
"fused_batch_norm_op.cc",
+ "listdiff_op.cc",
"population_count_op.cc",
"population_count_op.h",
"winograd_transform.h",
@@ -5236,7 +5280,6 @@ filegroup(
"transpose_functor_cpu.cc",
"transpose_op.cc",
"unique_op.cc",
- "warn_about_ints.cc",
"where_op.cc",
"xent_op.cc",
":android_extended_ops_headers",
@@ -6291,6 +6334,15 @@ tf_mkl_kernel_library(
deps = NN_DEPS + mkl_deps() + [":cwise_op"],
)
+tf_mkl_kernel_library(
+ name = "mkl_transpose_op",
+ srcs = [
+ "mkl_transpose_op.cc",
+ ],
+ hdrs = ["transpose_op.h"],
+ deps = ARRAY_DEPS + mkl_deps(),
+)
+
# NOTE(lespeholt): This rule is deprecated, please use:
# tensorflow/core/util/batch_util.h
cc_library(
diff --git a/tensorflow/core/kernels/adjust_contrast_op.h b/tensorflow/core/kernels/adjust_contrast_op.h
index 7689c04214..f4a53c2ef9 100644
--- a/tensorflow/core/kernels/adjust_contrast_op.h
+++ b/tensorflow/core/kernels/adjust_contrast_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_ADJUST_CONTRAST_OP_H_
-#define TENSORFLOW_KERNELS_ADJUST_CONTRAST_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_ADJUST_CONTRAST_OP_H_
+#define TENSORFLOW_CORE_KERNELS_ADJUST_CONTRAST_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -153,4 +153,4 @@ struct AdjustContrastv2 {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_ADJUST_CONTRAST_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_ADJUST_CONTRAST_OP_H_
diff --git a/tensorflow/core/kernels/adjust_hue_op.h b/tensorflow/core/kernels/adjust_hue_op.h
index 03d52a9e77..983a4072bf 100644
--- a/tensorflow/core/kernels/adjust_hue_op.h
+++ b/tensorflow/core/kernels/adjust_hue_op.h
@@ -11,8 +11,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef _TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H
-#define _TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H
+#ifndef TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H_
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
@@ -37,4 +37,4 @@ struct AdjustHueGPU {
} // namespace tensorflow
#endif // GOOGLE_CUDA
-#endif // _TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H
+#endif // TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H_
diff --git a/tensorflow/core/kernels/adjust_saturation_op.h b/tensorflow/core/kernels/adjust_saturation_op.h
index 05c45c07c3..fd28ba536f 100644
--- a/tensorflow/core/kernels/adjust_saturation_op.h
+++ b/tensorflow/core/kernels/adjust_saturation_op.h
@@ -11,8 +11,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef _TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H
-#define _TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H
+#ifndef TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H_
+#define TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H_
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
@@ -37,4 +37,4 @@ struct AdjustSaturationGPU {
} // namespace tensorflow
#endif // GOOGLE_CUDA
-#endif // _TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H
+#endif // TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H_
diff --git a/tensorflow/core/kernels/aggregate_ops.h b/tensorflow/core/kernels/aggregate_ops.h
index 9ea49fc34b..e074d0c2d9 100644
--- a/tensorflow/core/kernels/aggregate_ops.h
+++ b/tensorflow/core/kernels/aggregate_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_AGGREGATE_OPS_H_
-#define TENSORFLOW_KERNELS_AGGREGATE_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_
// Functor definitions for Aggregate ops, must be compilable by nvcc.
@@ -223,4 +223,4 @@ struct Add9EigenImpl {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_AGGREGATE_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_
diff --git a/tensorflow/core/kernels/aggregate_ops_cpu.h b/tensorflow/core/kernels/aggregate_ops_cpu.h
index aa1cead928..3e87917b64 100644
--- a/tensorflow/core/kernels/aggregate_ops_cpu.h
+++ b/tensorflow/core/kernels/aggregate_ops_cpu.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_AGGREGATE_OPS_CPU_H_
-#define TENSORFLOW_KERNELS_AGGREGATE_OPS_CPU_H_
+#ifndef TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_CPU_H_
+#define TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_CPU_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -250,4 +250,4 @@ struct Add9Functor<SYCLDevice, T> {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_AGGREGATE_OPS_CPU_H_
+#endif // TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_CPU_H_
diff --git a/tensorflow/core/kernels/argmax_op.h b/tensorflow/core/kernels/argmax_op.h
index b8bc41e089..224aa4654d 100644
--- a/tensorflow/core/kernels/argmax_op.h
+++ b/tensorflow/core/kernels/argmax_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_ARGMAX_OP_H_
-#define TENSORFLOW_KERNELS_ARGMAX_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_ARGMAX_OP_H_
+#define TENSORFLOW_CORE_KERNELS_ARGMAX_OP_H_
// Generator definition for ArgMaxOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -65,4 +65,4 @@ struct ArgMin {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_ARGMAX_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_ARGMAX_OP_H_
diff --git a/tensorflow/core/kernels/assign_op.h b/tensorflow/core/kernels/assign_op.h
index a450b1d1ee..74f926bdc8 100644
--- a/tensorflow/core/kernels/assign_op.h
+++ b/tensorflow/core/kernels/assign_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_ASSIGN_OP_H_
-#define TENSORFLOW_KERNELS_ASSIGN_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_ASSIGN_OP_H_
+#define TENSORFLOW_CORE_KERNELS_ASSIGN_OP_H_
#define EIGEN_USE_THREADS
@@ -143,4 +143,4 @@ class AssignOp : public OpKernel {
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_ASSIGN_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_ASSIGN_OP_H_
diff --git a/tensorflow/core/kernels/avgpooling_op.h b/tensorflow/core/kernels/avgpooling_op.h
index f5e81dbc09..1e49a66af9 100644
--- a/tensorflow/core/kernels/avgpooling_op.h
+++ b/tensorflow/core/kernels/avgpooling_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_AVGPOOLING_OP_H_
-#define TENSORFLOW_KERNELS_AVGPOOLING_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_AVGPOOLING_OP_H_
+#define TENSORFLOW_CORE_KERNELS_AVGPOOLING_OP_H_
// Functor definition for AvgPoolingOp, must be compilable by nvcc.
#include "tensorflow/core/framework/tensor_types.h"
@@ -76,4 +76,4 @@ bool RunAvePoolBackwardNHWC(const T* const top_diff, const int num,
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_AVGPOOLING_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_AVGPOOLING_OP_H_
diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h
index 475bda848d..766713a338 100644
--- a/tensorflow/core/kernels/batch_matmul_op_impl.h
+++ b/tensorflow/core/kernels/batch_matmul_op_impl.h
@@ -15,6 +15,9 @@ limitations under the License.
// See docs in ../ops/math_ops.cc.
+#ifndef TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
+
#define EIGEN_USE_THREADS
#include <vector>
@@ -613,3 +616,5 @@ class BatchMatMul : public OpKernel {
BatchMatMul<SYCLDevice, TYPE>)
#endif // TENSORFLOW_USE_SYCL
} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
diff --git a/tensorflow/core/kernels/batch_norm_op.h b/tensorflow/core/kernels/batch_norm_op.h
index 48e73c8757..76b156f8fd 100644
--- a/tensorflow/core/kernels/batch_norm_op.h
+++ b/tensorflow/core/kernels/batch_norm_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_BATCH_NORM_OP_H_
-#define TENSORFLOW_KERNELS_BATCH_NORM_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_BATCH_NORM_OP_H_
+#define TENSORFLOW_CORE_KERNELS_BATCH_NORM_OP_H_
// Functor definition for BatchNormOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -153,4 +153,4 @@ struct BatchNormGrad {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_BATCH_NORM_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_BATCH_NORM_OP_H_
diff --git a/tensorflow/core/kernels/betainc_op.h b/tensorflow/core/kernels/betainc_op.h
index c4aa9543ab..b941b27ad3 100644
--- a/tensorflow/core/kernels/betainc_op.h
+++ b/tensorflow/core/kernels/betainc_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_BETAINC_OP_H_
-#define TENSORFLOW_KERNELS_BETAINC_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_BETAINC_OP_H_
+#define TENSORFLOW_CORE_KERNELS_BETAINC_OP_H_
// Functor definition for BetaincOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -48,4 +48,4 @@ struct Betainc {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_BETAINC_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_BETAINC_OP_H_
diff --git a/tensorflow/core/kernels/bias_op.h b/tensorflow/core/kernels/bias_op.h
index 065934c709..77f683455d 100644
--- a/tensorflow/core/kernels/bias_op.h
+++ b/tensorflow/core/kernels/bias_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_BIAS_OP_H_
-#define TENSORFLOW_KERNELS_BIAS_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_BIAS_OP_H_
+#define TENSORFLOW_CORE_KERNELS_BIAS_OP_H_
// Functor definition for BiasOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -52,4 +52,4 @@ struct Bias {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_BIAS_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_BIAS_OP_H_
diff --git a/tensorflow/core/kernels/bincount_op.h b/tensorflow/core/kernels/bincount_op.h
index cd3d560cd1..54cfb79de7 100644
--- a/tensorflow/core/kernels/bincount_op.h
+++ b/tensorflow/core/kernels/bincount_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_BINCOUNT_OP_H_
-#define TENSORFLOW_BINCOUNT_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_BINCOUNT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_BINCOUNT_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
@@ -38,4 +38,4 @@ struct BincountFunctor {
} // end namespace tensorflow
-#endif // TENSORFLOW_BINCOUNT_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_BINCOUNT_OP_H_
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/BUILD b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
new file mode 100644
index 0000000000..3163c63949
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
@@ -0,0 +1,63 @@
+# Description:
+# This directory contains common utilities used in boosted_trees.
+package(
+ default_visibility = ["//tensorflow:internal"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+
+# Quantiles
+
+cc_library(
+ name = "weighted_quantiles",
+ srcs = [],
+ hdrs = [
+ "weighted_quantiles_buffer.h",
+ "weighted_quantiles_stream.h",
+ "weighted_quantiles_summary.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ ],
+)
+
+tf_cc_test(
+ name = "weighted_quantiles_buffer_test",
+ size = "small",
+ srcs = ["weighted_quantiles_buffer_test.cc"],
+ deps = [
+ ":weighted_quantiles",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+tf_cc_test(
+ name = "weighted_quantiles_summary_test",
+ size = "small",
+ srcs = ["weighted_quantiles_summary_test.cc"],
+ deps = [
+ ":weighted_quantiles",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+tf_cc_test(
+ name = "weighted_quantiles_stream_test",
+ size = "small",
+ srcs = ["weighted_quantiles_stream_test.cc"],
+ deps = [
+ ":weighted_quantiles",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h
new file mode 100644
index 0000000000..07aa9831c4
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h
@@ -0,0 +1,132 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
+#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
+
+#include <algorithm>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace quantiles {
+
+// Buffering container ideally suited for scenarios where we need
+// to sort and dedupe/compact fixed chunks of a stream of weighted elements.
+template <typename ValueType, typename WeightType,
+ typename CompareFn = std::less<ValueType>>
+class WeightedQuantilesBuffer {
+ public:
+ struct BufferEntry {
+ BufferEntry(ValueType v, WeightType w)
+ : value(std::move(v)), weight(std::move(w)) {}
+ BufferEntry() : value(), weight(0) {}
+
+ bool operator<(const BufferEntry& other) const {
+ return kCompFn(value, other.value);
+ }
+ bool operator==(const BufferEntry& other) const {
+ return value == other.value && weight == other.weight;
+ }
+ friend std::ostream& operator<<(std::ostream& strm,
+ const BufferEntry& entry) {
+ return strm << "{" << entry.value << ", " << entry.weight << "}";
+ }
+ ValueType value;
+ WeightType weight;
+ };
+
+ explicit WeightedQuantilesBuffer(int64 block_size, int64 max_elements)
+ : max_size_(std::min(block_size << 1, max_elements)) {
+ QCHECK(max_size_ > 0) << "Invalid buffer specification: (" << block_size
+ << ", " << max_elements << ")";
+ vec_.reserve(max_size_);
+ }
+
+ // Disallow copying as it's semantically nonsensical in the Squawd algorithm
+ // but enable move semantics.
+ WeightedQuantilesBuffer(const WeightedQuantilesBuffer& other) = delete;
+ WeightedQuantilesBuffer& operator=(const WeightedQuantilesBuffer&) = delete;
+ WeightedQuantilesBuffer(WeightedQuantilesBuffer&& other) = default;
+ WeightedQuantilesBuffer& operator=(WeightedQuantilesBuffer&& other) = default;
+
+ // Push entry to the buffer and maintain a compact representation within
+ // the pre-defined size limit.
+ void PushEntry(ValueType value, WeightType weight) {
+ // Callers are expected to act on a full compacted buffer after the
+ // PushEntry call returns.
+ QCHECK(!IsFull()) << "Buffer already full: " << max_size_;
+
+ // Ignore zero and negative weight entries.
+ if (weight <= 0) {
+ return;
+ }
+
+ // Push back the entry to the buffer.
+ vec_.push_back(BufferEntry(std::move(value), std::move(weight)));
+ }
+
+ // Returns a sorted vector view of the base buffer and clears the buffer.
+ // Callers should minimize how often this is called, ideally only right after
+ // the buffer becomes full.
+ std::vector<BufferEntry> GenerateEntryList() {
+ std::vector<BufferEntry> ret;
+ if (vec_.size() == 0) {
+ return ret;
+ }
+ ret.swap(vec_);
+ vec_.reserve(max_size_);
+ std::sort(ret.begin(), ret.end());
+ size_t num_entries = 0;
+ for (size_t i = 1; i < ret.size(); ++i) {
+ if (ret[i].value != ret[i - 1].value) {
+ BufferEntry tmp = ret[i];
+ ++num_entries;
+ ret[num_entries] = tmp;
+ } else {
+ ret[num_entries].weight += ret[i].weight;
+ }
+ }
+ ret.resize(num_entries + 1);
+ return ret;
+ }
+
+ int64 Size() const { return vec_.size(); }
+ bool IsFull() const { return vec_.size() >= max_size_; }
+ void Clear() { vec_.clear(); }
+
+ private:
+ using BufferVector = typename std::vector<BufferEntry>;
+
+ // Comparison function.
+ static constexpr decltype(CompareFn()) kCompFn = CompareFn();
+
+ // Base buffer.
+ size_t max_size_;
+ BufferVector vec_;
+};
+
+template <typename ValueType, typename WeightType, typename CompareFn>
+constexpr decltype(CompareFn())
+ WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>::kCompFn;
+
+} // namespace quantiles
+} // namespace boosted_trees
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
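For reference, a minimal usage sketch of the buffer defined above; the values mirror the PushEntryFull unit test in the next file, and the main() wrapper is illustrative only, not part of the patch:

#include <iostream>

#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h"

using Buffer =
    tensorflow::boosted_trees::quantiles::WeightedQuantilesBuffer<double,
                                                                  double>;

int main() {
  // Capacity is min(block_size << 1, max_elements) = 4 entries.
  Buffer buffer(/*block_size=*/2, /*max_elements=*/100);
  buffer.PushEntry(5, 9);
  buffer.PushEntry(2, 3);
  buffer.PushEntry(-1, 7);
  buffer.PushEntry(2, 1);  // Duplicate value; weights merge on flush.
  // GenerateEntryList() sorts, merges duplicate values, and clears the
  // buffer, yielding {-1, 7}, {2, 4}, {5, 9}.
  for (const auto& entry : buffer.GenerateEntryList()) {
    std::cout << entry << "\n";  // BufferEntry provides operator<<.
  }
  return 0;
}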
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer_test.cc b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer_test.cc
new file mode 100644
index 0000000000..75f05d64f3
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer_test.cc
@@ -0,0 +1,99 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+using Buffer =
+ boosted_trees::quantiles::WeightedQuantilesBuffer<double, double>;
+using BufferEntry =
+ boosted_trees::quantiles::WeightedQuantilesBuffer<double,
+ double>::BufferEntry;
+
+class WeightedQuantilesBufferTest : public ::testing::Test {};
+
+TEST_F(WeightedQuantilesBufferTest, Invalid) {
+ EXPECT_DEATH(
+ ({
+ boosted_trees::quantiles::WeightedQuantilesBuffer<double, double>
+ buffer(2, 0);
+ }),
+ "Invalid buffer specification");
+ EXPECT_DEATH(
+ ({
+ boosted_trees::quantiles::WeightedQuantilesBuffer<double, double>
+ buffer(0, 2);
+ }),
+ "Invalid buffer specification");
+}
+
+TEST_F(WeightedQuantilesBufferTest, PushEntryNotFull) {
+ Buffer buffer(20, 100);
+ buffer.PushEntry(5, 9);
+ buffer.PushEntry(2, 3);
+ buffer.PushEntry(-1, 7);
+ buffer.PushEntry(3, 0); // This entry will be ignored.
+
+ EXPECT_FALSE(buffer.IsFull());
+ EXPECT_EQ(buffer.Size(), 3);
+}
+
+TEST_F(WeightedQuantilesBufferTest, PushEntryFull) {
+ // buffer capacity is 4.
+ Buffer buffer(2, 100);
+ buffer.PushEntry(5, 9);
+ buffer.PushEntry(2, 3);
+ buffer.PushEntry(-1, 7);
+ buffer.PushEntry(2, 1);
+
+ std::vector<BufferEntry> expected;
+ expected.emplace_back(-1, 7);
+ expected.emplace_back(2, 4);
+ expected.emplace_back(5, 9);
+
+ // At this point, we have pushed 4 entries and we expect the buffer to be
+ // full.
+ EXPECT_TRUE(buffer.IsFull());
+ EXPECT_EQ(buffer.GenerateEntryList(), expected);
+ EXPECT_FALSE(buffer.IsFull());
+}
+
+TEST_F(WeightedQuantilesBufferTest, PushEntryFullDeath) {
+ // buffer capacity is 4.
+ Buffer buffer(2, 100);
+ buffer.PushEntry(5, 9);
+ buffer.PushEntry(2, 3);
+ buffer.PushEntry(-1, 7);
+ buffer.PushEntry(2, 1);
+
+ std::vector<BufferEntry> expected;
+ expected.emplace_back(-1, 7);
+ expected.emplace_back(2, 4);
+ expected.emplace_back(5, 9);
+
+ // At this point, we have pushed 4 entries and we expect the buffer to be
+ // full.
+ EXPECT_TRUE(buffer.IsFull());
+ // Can't push any more entries before clearing.
+ EXPECT_DEATH(({ buffer.PushEntry(6, 6); }), "Buffer already full");
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h
new file mode 100644
index 0000000000..525e2a6a64
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h
@@ -0,0 +1,330 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_
+#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_
+
+#include <cmath>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace quantiles {
+
+// Class to compute approximate quantiles with error bound guarantees for
+// weighted data sets.
+// This implementation is an adaptation of techniques from the following papers:
+// * (2001) Space-efficient online computation of quantile summaries.
+// * (2004) Power-conserving computation of order-statistics over
+// sensor networks.
+// * (2007) A fast algorithm for approximate quantiles in high speed
+// data streams.
+// * (2016) XGBoost: A Scalable Tree Boosting System.
+//
+// The key ideas at play are the following:
+// - Maintain an in-memory multi-level quantile summary in a way to guarantee
+// a maximum approximation error of eps * W per bucket where W is the total
+// weight across all points in the input dataset.
+// - Two base operations are defined: MERGE and COMPRESS. MERGE combines two
+// summaries guaranteeing an epsNew = max(eps1, eps2). COMPRESS compresses
+// a summary to b + 1 elements guaranteeing epsNew = epsOld + 1/b.
+// - b * sizeof(summary entry) must ideally be small enough to fit in an
+// average CPU L2 cache.
+// - To distribute this algorithm while maintaining error bounds, we need
+// the worker-computed summaries to have no more than eps / h error,
+// where h is the height of the distributed computation graph, which
+// is 2 for a MapReduce with no combiner.
+//
+// We mainly want to max out I/O bandwidth by ensuring we're not compute-bound
+// while using a reasonable amount of RAM.
+//
+// Complexity:
+// Compute: O(n * log(1/eps * log(eps * n))).
+// Memory: O(1/eps * log^2(eps * n)) <- for one worker streaming through the
+// entire dataset.
+// An epsilon value of zero would make the algorithm extremely inefficient and
+// is therefore disallowed.
+template <typename ValueType, typename WeightType,
+ typename CompareFn = std::less<ValueType>>
+class WeightedQuantilesStream {
+ public:
+ using Buffer = WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>;
+ using BufferEntry = typename Buffer::BufferEntry;
+ using Summary = WeightedQuantilesSummary<ValueType, WeightType, CompareFn>;
+ using SummaryEntry = typename Summary::SummaryEntry;
+
+ explicit WeightedQuantilesStream(double eps, int64 max_elements)
+ : eps_(eps), buffer_(1LL, 2LL), finalized_(false) {
+ // See the class documentation. An epsilon value of zero could cause
+ // performance issues.
+ QCHECK(eps > 0) << "An epsilon value of zero is not allowed.";
+ std::tie(max_levels_, block_size_) = GetQuantileSpecs(eps, max_elements);
+ buffer_ = Buffer(block_size_, max_elements);
+ summary_levels_.reserve(max_levels_);
+ }
+
+ // Disallow copy and assign but enable move semantics for the stream.
+ WeightedQuantilesStream(const WeightedQuantilesStream& other) = delete;
+ WeightedQuantilesStream& operator=(const WeightedQuantilesStream&) = delete;
+ WeightedQuantilesStream(WeightedQuantilesStream&& other) = default;
+ WeightedQuantilesStream& operator=(WeightedQuantilesStream&& other) = default;
+
+ // Pushes one entry while maintaining approximation error invariants.
+ void PushEntry(const ValueType& value, const WeightType& weight) {
+ // Validate state.
+ QCHECK(!finalized_) << "Finalize() already called.";
+
+ // Push element to base buffer.
+ buffer_.PushEntry(value, weight);
+
+ // When the compacted buffer is full, we need to compress it
+ // and push the weighted quantile summary up the level chain.
+ if (buffer_.IsFull()) {
+ PushBuffer(buffer_);
+ }
+ }
+
+ // Pushes full buffer while maintaining approximation error invariants.
+ void PushBuffer(Buffer& buffer) {
+ // Validate state.
+ QCHECK(!finalized_) << "Finalize() already called.";
+
+ // Create local compressed summary and propagate.
+ local_summary_.BuildFromBufferEntries(buffer.GenerateEntryList());
+ local_summary_.Compress(block_size_, eps_);
+ PropagateLocalSummary();
+ }
+
+ // Pushes full summary while maintaining approximation error invariants.
+ void PushSummary(const std::vector<SummaryEntry>& summary) {
+ // Validate state.
+ QCHECK(!finalized_) << "Finalize() already called.";
+
+ // Create local compressed summary and propagate.
+ local_summary_.BuildFromSummaryEntries(summary);
+ local_summary_.Compress(block_size_, eps_);
+ PropagateLocalSummary();
+ }
+
+ // Flushes approximator and finalizes state.
+ void Finalize() {
+ // Validate state.
+ QCHECK(!finalized_) << "Finalize() may only be called once.";
+
+ // Flush any remaining buffer elements.
+ PushBuffer(buffer_);
+
+ // Create final merged summary.
+ local_summary_.Clear();
+ for (auto& summary : summary_levels_) {
+ local_summary_.Merge(summary);
+ summary.Clear();
+ }
+ summary_levels_.clear();
+ summary_levels_.shrink_to_fit();
+ finalized_ = true;
+ }
+
+ // Generates requested number of quantiles after finalizing stream.
+ // The returned quantiles can be queried using std::lower_bound to get
+ // the bucket for a given value.
+ std::vector<ValueType> GenerateQuantiles(int64 num_quantiles) const {
+ // Validate state.
+ QCHECK(finalized_)
+ << "Finalize() must be called before generating quantiles.";
+ return local_summary_.GenerateQuantiles(num_quantiles);
+ }
+
+ // Generates requested number of boundaries after finalizing stream.
+ // The returned boundaries can be queried using std::lower_bound to get
+ // the bucket for a given value.
+ // The boundaries, while still guaranteeing approximation bounds, don't
+ // necessarily represent the actual quantiles of the distribution.
+ // Boundaries are preferable over quantiles when the caller is less
+ // interested in the actual quantile distribution and more interested in
+ // getting a representative sample of boundary values.
+ std::vector<ValueType> GenerateBoundaries(int64 num_boundaries) const {
+ // Validate state.
+ QCHECK(finalized_)
+ << "Finalize() must be called before generating boundaries.";
+ return local_summary_.GenerateBoundaries(num_boundaries);
+ }
+
+ // Calculates approximation error for the specified level.
+ // If the passed level is negative, the approximation error for the entire
+ // summary is returned. Note that after Finalize is called, only the overall
+ // error is available.
+ WeightType ApproximationError(int64 level = -1) const {
+ if (finalized_) {
+ QCHECK(level <= 0) << "Only overall error is available after Finalize()";
+ return local_summary_.ApproximationError();
+ }
+
+ if (summary_levels_.empty()) {
+ // No error even if base buffer isn't empty.
+ return 0;
+ }
+
+ // If level is negative, we get the approximation error
+ // for the top-most level which is the max approximation error
+ // in all summaries by construction.
+ if (level < 0) {
+ level = summary_levels_.size() - 1;
+ }
+ QCHECK(level < summary_levels_.size()) << "Invalid level.";
+ return summary_levels_[level].ApproximationError();
+ }
+
+ size_t MaxDepth() const { return summary_levels_.size(); }
+
+ // Returns the final merged summary after finalizing the stream.
+ const Summary& GetFinalSummary() const {
+ // Validate state.
+ QCHECK(finalized_)
+ << "Finalize() must be called before requesting final summary.";
+ return local_summary_;
+ }
+
+ // Helper method which, given the desired approximation error
+ // and an upper bound on the number of elements, computes the optimal
+ // number of levels and block size and returns them as a tuple.
+ static std::tuple<int64, int64> GetQuantileSpecs(double eps,
+ int64 max_elements);
+
+ // Serializes the internal state of the stream.
+ std::vector<Summary> SerializeInternalSummaries() const {
+ // The buffer should be empty for serialize to work.
+ QCHECK_EQ(buffer_.Size(), 0);
+ std::vector<Summary> result;
+ result.reserve(summary_levels_.size() + 1);
+ for (const Summary& summary : summary_levels_) {
+ result.push_back(summary);
+ }
+ result.push_back(local_summary_);
+ return result;
+ }
+
+ // Resets the state of the stream with a serialized state.
+ void DeserializeInternalSummaries(const std::vector<Summary>& summaries) {
+ // Clear the state before deserializing.
+ buffer_.Clear();
+ summary_levels_.clear();
+ local_summary_.Clear();
+ QCHECK_GT(max_levels_, summaries.size() - 1);
+ for (int i = 0; i < summaries.size() - 1; ++i) {
+ summary_levels_.push_back(summaries[i]);
+ }
+ local_summary_ = summaries[summaries.size() - 1];
+ }
+
+ private:
+ // Propagates local summary through summary levels while maintaining
+ // approximation error invariants.
+ void PropagateLocalSummary() {
+ // Validate state.
+ QCHECK(!finalized_) << "Finalize() already called.";
+
+ // No-op if there's nothing to add.
+ if (local_summary_.Size() <= 0) {
+ return;
+ }
+
+ // Propagate summary through levels.
+ size_t level = 0;
+ for (bool settled = false; !settled; ++level) {
+ // Ensure we have enough depth.
+ if (summary_levels_.size() <= level) {
+ summary_levels_.emplace_back();
+ }
+
+ // Merge summaries.
+ Summary& current_summary = summary_levels_[level];
+ local_summary_.Merge(current_summary);
+
+ // Check if we need to compress and propagate summary higher.
+ if (current_summary.Size() == 0 ||
+ local_summary_.Size() <= block_size_ + 1) {
+ current_summary = std::move(local_summary_);
+ settled = true;
+ } else {
+ // Compress, empty current level and propagate.
+ local_summary_.Compress(block_size_, eps_);
+ current_summary.Clear();
+ }
+ }
+ }
+
+ // Desired approximation precision.
+ double eps_;
+ // Maximum number of levels.
+ int64 max_levels_;
+ // Max block size per level.
+ int64 block_size_;
+ // Base buffer.
+ Buffer buffer_;
+ // Local summary used to minimize memory allocation and cache misses.
+ // After the stream is finalized, this summary holds the final quantile
+ // estimates.
+ Summary local_summary_;
+ // Summary levels.
+ std::vector<Summary> summary_levels_;
+ // Flag indicating whether the stream is finalized.
+ bool finalized_;
+};
+
+template <typename ValueType, typename WeightType, typename CompareFn>
+inline std::tuple<int64, int64>
+WeightedQuantilesStream<ValueType, WeightType, CompareFn>::GetQuantileSpecs(
+ double eps, int64 max_elements) {
+ int64 max_level = 1LL;
+ int64 block_size = 2LL;
+ QCHECK(eps >= 0 && eps < 1);
+ QCHECK_GT(max_elements, 0);
+
+ if (eps <= std::numeric_limits<double>::epsilon()) {
+ // Exact quantile computation at the expense of RAM.
+ max_level = 1;
+ block_size = std::max(max_elements, int64{2});
+ } else {
+ // The bottom-most level will become full at most
+ // (max_elements / block_size) times, the level above at most
+ // (max_elements / (2 * block_size)) times and, in general, level l becomes
+ // full at most (max_elements / (2^l * block_size)) times, until the last
+ // level max_level becomes full at most once, i.e. when the inequality
+ // (2^max_level * block_size >= max_elements) is satisfied.
+ // In what follows, we jointly solve for max_level and block_size by
+ // gradually increasing the level until the inequality above is satisfied.
+ // We could alternatively set max_level = ceil(log2(eps * max_elements))
+ // and block_size = ceil(max_level / eps) + 1, but that tends to give more
+ // pessimistic bounds and wastes RAM needlessly.
+ for (max_level = 1, block_size = 2;
+ (1LL << max_level) * block_size < max_elements; ++max_level) {
+ // Update the upper bound on block size at the current level; we always
+ // increase the estimate by 2 to hold the min/max elements seen so far.
+ block_size = static_cast<size_t>(ceil(max_level / eps)) + 1;
+ }
+ }
+ return std::make_tuple(max_level, std::max(block_size, int64{2}));
+}
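For concreteness, a hedged sketch of the joint solve above (not part of the patch; the expected tuples match the GetQuantileSpecs unit tests added further down in this diff):

// Sketch only; assumes the class above and lives in the same namespace.
inline void GetQuantileSpecsExamples() {
  using Stream = WeightedQuantilesStream<double, double>;
  // eps == 0 requests exact quantiles: one level holding everything.
  auto exact = Stream::GetQuantileSpecs(0.0, 20LL);  // -> (1, 20)
  // eps == 0.1, max_elements == 320 iterates:
  //   level 1: 2 * 2 = 4 < 320   -> block_size = ceil(1 / 0.1) + 1 = 11
  //   level 2: 4 * 11 = 44 < 320  -> block_size = ceil(2 / 0.1) + 1 = 21
  //   level 3: 8 * 21 = 168 < 320 -> block_size = ceil(3 / 0.1) + 1 = 31
  //   level 4: 16 * 31 = 496 >= 320 -> stop.
  auto approx = Stream::GetQuantileSpecs(0.1, 320LL);  // -> (4, 31)
  (void)exact;
  (void)approx;
}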
+
+} // namespace quantiles
+} // namespace boosted_trees
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_
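And a minimal end-to-end sketch of the stream itself; the helper names are hypothetical, and the pattern mirrors the tests in the next file, including the std::lower_bound lookup suggested by the GenerateQuantiles comment:

#include <algorithm>
#include <iterator>
#include <vector>

#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"

using Stream =
    tensorflow::boosted_trees::quantiles::WeightedQuantilesStream<double,
                                                                  double>;

// Returns 11 cut points splitting `values` into 10 roughly equal buckets.
// `values` must be non-empty, since max_elements has to be positive.
std::vector<double> ApproxDeciles(const std::vector<double>& values) {
  Stream stream(/*eps=*/0.01,
                static_cast<tensorflow::int64>(values.size()));
  for (const double v : values) {
    stream.PushEntry(v, /*weight=*/1.0);
  }
  stream.Finalize();  // Flushes the buffer and merges all summary levels.
  return stream.GenerateQuantiles(/*num_quantiles=*/10);
}

// Bucket lookup via std::lower_bound, as the GenerateQuantiles comment
// suggests.
int BucketOf(const std::vector<double>& deciles, double x) {
  return std::distance(deciles.begin(),
                       std::lower_bound(deciles.begin(), deciles.end(), x));
}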
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream_test.cc b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream_test.cc
new file mode 100644
index 0000000000..6c5b9fd23b
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream_test.cc
@@ -0,0 +1,276 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+using Tuple = std::tuple<int64, int64>;
+
+using Summary =
+ boosted_trees::quantiles::WeightedQuantilesSummary<double, double>;
+using SummaryEntry =
+ boosted_trees::quantiles::WeightedQuantilesSummary<double,
+ double>::SummaryEntry;
+using Stream =
+ boosted_trees::quantiles::WeightedQuantilesStream<double, double>;
+
+TEST(GetQuantileSpecs, InvalidEps) {
+ EXPECT_DEATH({ Stream::GetQuantileSpecs(-0.01, 0L); }, "eps >= 0");
+ EXPECT_DEATH({ Stream::GetQuantileSpecs(1.01, 0L); }, "eps < 1");
+}
+
+TEST(GetQuantileSpecs, ZeroEps) {
+ EXPECT_DEATH({ Stream::GetQuantileSpecs(0.0, 0L); }, "max_elements > 0");
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.0, 1LL), Tuple(1LL, 2LL));
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.0, 20LL), Tuple(1LL, 20LL));
+}
+
+TEST(GetQuantileSpecs, NonZeroEps) {
+ EXPECT_DEATH({ Stream::GetQuantileSpecs(0.01, 0L); }, "max_elements > 0");
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.1, 320LL), Tuple(4LL, 31LL));
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.01, 25600LL), Tuple(6LL, 501LL));
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.01, 104857600LL), Tuple(17LL, 1601LL));
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.1, 104857600LL), Tuple(20LL, 191LL));
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.01, 1LL << 40), Tuple(29LL, 2801LL));
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.001, 1LL << 40), Tuple(26LL, 25001LL));
+}
+
+class WeightedQuantilesStreamTest : public ::testing::Test {};
+
+// Stream generators.
+void GenerateFixedUniformSummary(int32 worker_id, int64 max_elements,
+ double *total_weight, Stream *stream) {
+ for (int64 i = 0; i < max_elements; ++i) {
+ const double x = static_cast<double>(i) / max_elements;
+ stream->PushEntry(x, 1.0);
+ ++(*total_weight);
+ }
+ stream->Finalize();
+}
+
+void GenerateFixedNonUniformSummary(int32 worker_id, int64 max_elements,
+ double *total_weight, Stream *stream) {
+ for (int64 i = 0; i < max_elements; ++i) {
+ const double x = static_cast<double>(i) / max_elements;
+ stream->PushEntry(x, x);
+ (*total_weight) += x;
+ }
+ stream->Finalize();
+}
+
+void GenerateRandUniformFixedWeightsSummary(int32 worker_id, int64 max_elements,
+ double *total_weight,
+ Stream *stream) {
+ // Simulate uniform distribution stream.
+ random::PhiloxRandom philox(13 + worker_id);
+ random::SimplePhilox rand(&philox);
+ for (int64 i = 0; i < max_elements; ++i) {
+ const double x = rand.RandDouble();
+ stream->PushEntry(x, 1);
+ ++(*total_weight);
+ }
+ stream->Finalize();
+}
+
+void GenerateRandUniformRandWeightsSummary(int32 worker_id, int64 max_elements,
+ double *total_weight,
+ Stream *stream) {
+ // Simulate uniform distribution stream.
+ random::PhiloxRandom philox(13 + worker_id);
+ random::SimplePhilox rand(&philox);
+ for (int64 i = 0; i < max_elements; ++i) {
+ const double x = rand.RandDouble();
+ const double w = rand.RandDouble();
+ stream->PushEntry(x, w);
+ (*total_weight) += w;
+ }
+ stream->Finalize();
+}
+
+// Single worker tests.
+void TestSingleWorkerStreams(
+ double eps, int64 max_elements,
+ const std::function<void(int32, int64, double *, Stream *)>
+ &worker_summary_generator,
+ std::initializer_list<double> expected_quantiles,
+ double quantiles_matcher_epsilon) {
+ // Generate single stream.
+ double total_weight = 0;
+ Stream stream(eps, max_elements);
+ worker_summary_generator(0, max_elements, &total_weight, &stream);
+
+ // Ensure we didn't lose track of any elements and are
+ // within approximation error bound.
+ EXPECT_LE(stream.ApproximationError(), eps);
+ EXPECT_NEAR(stream.GetFinalSummary().TotalWeight(), total_weight, 1e-6);
+
+ // Verify expected quantiles.
+ int i = 0;
+ auto actuals = stream.GenerateQuantiles(expected_quantiles.size() - 1);
+ for (auto expected_quantile : expected_quantiles) {
+ EXPECT_NEAR(actuals[i], expected_quantile, quantiles_matcher_epsilon);
+ ++i;
+ }
+}
+
+// Stream generators.
+void GenerateOneValue(int32 worker_id, int64 max_elements, double *total_weight,
+ Stream *stream) {
+ stream->PushEntry(10, 1);
+ ++(*total_weight);
+ stream->Finalize();
+}
+
+void GenerateOneZeroWeightedValue(int32 worker_id, int64 max_elements,
+ double *total_weight, Stream *stream) {
+ stream->PushEntry(10, 0);
+ stream->Finalize();
+}
+
+TEST(WeightedQuantilesStreamTest, OneValue) {
+ const double eps = 0.01;
+ const int64 max_elements = 1 << 16;
+ TestSingleWorkerStreams(eps, max_elements, GenerateOneValue,
+ {10.0, 10.0, 10.0, 10.0, 10.0}, 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, OneZeroWeightValue) {
+ const double eps = 0.01;
+ const int64 max_elements = 1 << 16;
+ TestSingleWorkerStreams(eps, max_elements, GenerateOneZeroWeightedValue, {},
+ 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, FixedUniform) {
+ const double eps = 0.01;
+ const int64 max_elements = 1 << 16;
+ TestSingleWorkerStreams(eps, max_elements, GenerateFixedUniformSummary,
+ {0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0},
+ 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, FixedNonUniform) {
+ const double eps = 0.01;
+ const int64 max_elements = 1 << 16;
+ TestSingleWorkerStreams(eps, max_elements, GenerateFixedNonUniformSummary,
+ {0, std::sqrt(0.1), std::sqrt(0.2), std::sqrt(0.3),
+ std::sqrt(0.4), std::sqrt(0.5), std::sqrt(0.6),
+ std::sqrt(0.7), std::sqrt(0.8), std::sqrt(0.9), 1.0},
+ 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, RandUniformFixedWeights) {
+ const double eps = 0.01;
+ const int64 max_elements = 1 << 16;
+ TestSingleWorkerStreams(
+ eps, max_elements, GenerateRandUniformFixedWeightsSummary,
+ {0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}, 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, RandUniformRandWeights) {
+ const double eps = 0.01;
+ const int64 max_elements = 1 << 16;
+ TestSingleWorkerStreams(
+ eps, max_elements, GenerateRandUniformRandWeightsSummary,
+ {0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}, 1e-2);
+}
+
+// Distributed tests.
+void TestDistributedStreams(
+ int32 num_workers, double eps, int64 max_elements,
+ const std::function<void(int32, int64, double *, Stream *)>
+ &worker_summary_generator,
+ std::initializer_list<double> expected_quantiles,
+ double quantiles_matcher_epsilon) {
+ // Simulate streams on each worker running independently
+ double total_weight = 0;
+ std::vector<std::vector<SummaryEntry>> worker_summaries;
+ for (int32 i = 0; i < num_workers; ++i) {
+ Stream stream(eps / 2, max_elements);
+ worker_summary_generator(i, max_elements / num_workers, &total_weight,
+ &stream);
+ worker_summaries.push_back(stream.GetFinalSummary().GetEntryList());
+ }
+
+ // In the accumulation phase, we aggregate the summaries from each worker
+ // and build an overall summary while maintaining error bounds by ensuring we
+ // don't increase the error by more than eps / 2.
+ Stream reducer_stream(eps, max_elements);
+ for (const auto &summary : worker_summaries) {
+ reducer_stream.PushSummary(summary);
+ }
+ reducer_stream.Finalize();
+
+ // Ensure we didn't lose track of any elements and are
+ // within approximation error bound.
+ EXPECT_LE(reducer_stream.ApproximationError(), eps);
+ EXPECT_NEAR(reducer_stream.GetFinalSummary().TotalWeight(), total_weight,
+ total_weight);
+
+ // Verify expected quantiles.
+ int i = 0;
+ auto actuals =
+ reducer_stream.GenerateQuantiles(expected_quantiles.size() - 1);
+ for (auto expected_quantile : expected_quantiles) {
+ EXPECT_NEAR(actuals[i], expected_quantile, quantiles_matcher_epsilon);
+ ++i;
+ }
+}
+
+TEST(WeightedQuantilesStreamTest, FixedUniformDistributed) {
+ const int32 num_workers = 10;
+ const double eps = 0.01;
+ const int64 max_elements = num_workers * (1 << 16);
+ TestDistributedStreams(
+ num_workers, eps, max_elements, GenerateFixedUniformSummary,
+ {0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}, 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, FixedNonUniformDistributed) {
+ const int32 num_workers = 10;
+ const double eps = 0.01;
+ const int64 max_elements = num_workers * (1 << 16);
+ TestDistributedStreams(num_workers, eps, max_elements,
+ GenerateFixedNonUniformSummary,
+ {0, std::sqrt(0.1), std::sqrt(0.2), std::sqrt(0.3),
+ std::sqrt(0.4), std::sqrt(0.5), std::sqrt(0.6),
+ std::sqrt(0.7), std::sqrt(0.8), std::sqrt(0.9), 1.0},
+ 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, RandUniformFixedWeightsDistributed) {
+ const int32 num_workers = 10;
+ const double eps = 0.01;
+ const int64 max_elements = num_workers * (1 << 16);
+ TestDistributedStreams(
+ num_workers, eps, max_elements, GenerateRandUniformFixedWeightsSummary,
+ {0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}, 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, RandUniformRandWeightsDistributed) {
+ const int32 num_workers = 10;
+ const double eps = 0.01;
+ const int64 max_elements = num_workers * (1 << 16);
+ TestDistributedStreams(
+ num_workers, eps, max_elements, GenerateRandUniformRandWeightsSummary,
+ {0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}, 1e-2);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h
new file mode 100644
index 0000000000..31d7fe25a4
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h
@@ -0,0 +1,344 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
+#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
+
+#include <cstring>
+#include <vector>
+
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace quantiles {
+
+// Summary holding a sorted block of entries with upper bound guarantees
+// over the approximation error.
+template <typename ValueType, typename WeightType,
+ typename CompareFn = std::less<ValueType>>
+class WeightedQuantilesSummary {
+ public:
+ using Buffer = WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>;
+ using BufferEntry = typename Buffer::BufferEntry;
+
+ struct SummaryEntry {
+ SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min,
+ const WeightType& max) {
+ // Explicitly initialize all of the memory (including padding from memory
+ // alignment) to allow the struct to be msan-resistant "plain old data".
+ //
+ // POD = http://en.cppreference.com/w/cpp/concept/PODType
+ memset(this, 0, sizeof(*this));
+
+ value = v;
+ weight = w;
+ min_rank = min;
+ max_rank = max;
+ }
+
+ SummaryEntry() {
+ memset(this, 0, sizeof(*this));
+
+ value = ValueType();
+ weight = 0;
+ min_rank = 0;
+ max_rank = 0;
+ }
+
+ bool operator==(const SummaryEntry& other) const {
+ return value == other.value && weight == other.weight &&
+ min_rank == other.min_rank && max_rank == other.max_rank;
+ }
+ friend std::ostream& operator<<(std::ostream& strm,
+ const SummaryEntry& entry) {
+ return strm << "{" << entry.value << ", " << entry.weight << ", "
+ << entry.min_rank << ", " << entry.max_rank << "}";
+ }
+
+ // Max rank estimate for previous smaller value.
+ WeightType PrevMaxRank() const { return max_rank - weight; }
+
+ // Min rank estimate for next larger value.
+ WeightType NextMinRank() const { return min_rank + weight; }
+
+ ValueType value;
+ WeightType weight;
+ WeightType min_rank;
+ WeightType max_rank;
+ };
+
+ // Re-construct summary from the specified buffer.
+ void BuildFromBufferEntries(const std::vector<BufferEntry>& buffer_entries) {
+ entries_.clear();
+ entries_.reserve(buffer_entries.size());
+ WeightType cumulative_weight = 0;
+ for (const auto& entry : buffer_entries) {
+ WeightType current_weight = entry.weight;
+ entries_.emplace_back(entry.value, entry.weight, cumulative_weight,
+ cumulative_weight + current_weight);
+ cumulative_weight += current_weight;
+ }
+ }
+
+ // Re-construct summary from the specified summary entries.
+ void BuildFromSummaryEntries(
+ const std::vector<SummaryEntry>& summary_entries) {
+ entries_.clear();
+ entries_.reserve(summary_entries.size());
+ entries_.insert(entries_.begin(), summary_entries.begin(),
+ summary_entries.end());
+ }
+
+ // Merges two summaries through an algorithm that's derived from MergeSort
+ // for summary entries while guaranteeing that the max approximation error
+ // of the final merged summary is no greater than the approximation errors
+ // of each individual summary.
+ // For example consider summaries where each entry is of the form
+ // (element, weight, min rank, max rank):
+ // summary entries 1: (1, 3, 0, 3), (4, 2, 3, 5)
+ // summary entries 2: (3, 1, 0, 1), (4, 1, 1, 2)
+ // merged: (1, 3, 0, 3), (3, 1, 3, 4), (4, 3, 4, 7).
+ void Merge(const WeightedQuantilesSummary& other_summary) {
+ // Make sure we have something to merge.
+ const auto& other_entries = other_summary.entries_;
+ if (other_entries.empty()) {
+ return;
+ }
+ if (entries_.empty()) {
+ BuildFromSummaryEntries(other_summary.entries_);
+ return;
+ }
+
+ // Move current entries to make room for a new buffer.
+ std::vector<SummaryEntry> base_entries(std::move(entries_));
+ entries_.clear();
+ entries_.reserve(base_entries.size() + other_entries.size());
+
+ // Merge entries while maintaining ranks. The idea is to stack values
+ // in order, which we can do in linear time since the two summaries are
+ // already sorted. We keep track of the next lower rank from either
+ // summary and update it as we pop elements from the summaries.
+ // We handle the special case when the next two elements from either
+ // summary are equal, in which case we just merge the two elements
+ // and simultaneously update both ranks.
+ auto it1 = base_entries.cbegin();
+ auto it2 = other_entries.cbegin();
+ WeightType next_min_rank1 = 0;
+ WeightType next_min_rank2 = 0;
+ while (it1 != base_entries.cend() && it2 != other_entries.cend()) {
+ if (kCompFn(it1->value, it2->value)) { // value1 < value2
+ // Take value1 and use the last added value2 to compute
+ // the min rank and the current value2 to compute the max rank.
+ entries_.emplace_back(it1->value, it1->weight,
+ it1->min_rank + next_min_rank2,
+ it1->max_rank + it2->PrevMaxRank());
+ // Update next min rank 1.
+ next_min_rank1 = it1->NextMinRank();
+ ++it1;
+ } else if (kCompFn(it2->value, it1->value)) { // value1 > value2
+ // Take value2 and use the last added value1 to compute
+ // the min rank and the current value1 to compute the max rank.
+ entries_.emplace_back(it2->value, it2->weight,
+ it2->min_rank + next_min_rank1,
+ it2->max_rank + it1->PrevMaxRank());
+ // Update next min rank 2.
+ next_min_rank2 = it2->NextMinRank();
+ ++it2;
+ } else { // value1 == value2
+ // Straight additive merger of the two entries into one.
+ entries_.emplace_back(it1->value, it1->weight + it2->weight,
+ it1->min_rank + it2->min_rank,
+ it1->max_rank + it2->max_rank);
+ // Update next min ranks for both.
+ next_min_rank1 = it1->NextMinRank();
+ next_min_rank2 = it2->NextMinRank();
+ ++it1;
+ ++it2;
+ }
+ }
+
+ // Fill in any residual.
+ while (it1 != base_entries.cend()) {
+ entries_.emplace_back(it1->value, it1->weight,
+ it1->min_rank + next_min_rank2,
+ it1->max_rank + other_entries.back().max_rank);
+ ++it1;
+ }
+ while (it2 != other_entries.cend()) {
+ entries_.emplace_back(it2->value, it2->weight,
+ it2->min_rank + next_min_rank1,
+ it2->max_rank + base_entries.back().max_rank);
+ ++it2;
+ }
+ }
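The worked example in the Merge() comment can be reproduced with a short sketch (shown next to the algorithm for reading convenience; as free code it would live outside the class, and it is not part of the patch):

// Sketch: reproduces the merge example from the Merge() doc comment.
using ExampleSummary = WeightedQuantilesSummary<double, double>;

inline ExampleSummary MergeExample() {
  ExampleSummary s1, s2;
  // Entries are (value, weight, min_rank, max_rank).
  s1.BuildFromSummaryEntries({{1, 3, 0, 3}, {4, 2, 3, 5}});
  s2.BuildFromSummaryEntries({{3, 1, 0, 1}, {4, 1, 1, 2}});
  s1.Merge(s2);
  // s1 now holds (1, 3, 0, 3), (3, 1, 3, 4), (4, 3, 4, 7): the equal
  // values 4 were merged additively, weights and ranks alike.
  return s1;
}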
+
+ // Compresses buffer into desired size. The size specification is
+ // considered a hint as we always keep the first and last elements and
+ // maintain strict approximation error bounds.
+ // The approximation error delta is taken as the max of the requested
+ // min error and 1 / size_hint.
+ // After compression, the approximation error is guaranteed to increase
+ // by no more than that error delta.
+ // This algorithm is linear in the original size of the summary and is
+ // designed to be cache-friendly.
+ void Compress(int64 size_hint, double min_eps = 0) {
+ // No-op if we're already within the size requirement.
+ size_hint = std::max(size_hint, int64{2});
+ if (entries_.size() <= size_hint) {
+ return;
+ }
+
+ // First compute the max error bound delta resulting from this compression.
+ double eps_delta = TotalWeight() * std::max(1.0 / size_hint, min_eps);
+
+    // Compress elements, ensuring that both the approximation bounds and
+    // element diversity are maintained.
+ int64 add_accumulator = 0, add_step = entries_.size();
+ auto write_it = entries_.begin() + 1, last_it = write_it;
+ for (auto read_it = entries_.begin(); read_it + 1 != entries_.end();) {
+ auto next_it = read_it + 1;
+ while (next_it != entries_.end() && add_accumulator < add_step &&
+ next_it->PrevMaxRank() - read_it->NextMinRank() <= eps_delta) {
+ add_accumulator += size_hint;
+ ++next_it;
+ }
+ if (read_it == next_it - 1) {
+ ++read_it;
+ } else {
+ read_it = next_it - 1;
+ }
+ (*write_it++) = (*read_it);
+ last_it = read_it;
+ add_accumulator -= add_step;
+ }
+ // Write last element and resize.
+ if (last_it + 1 != entries_.end()) {
+ (*write_it++) = entries_.back();
+ }
+ entries_.resize(write_it - entries_.begin());
+ }
+
+  // To construct the boundaries, we first run a soft compress over a copy
+  // of the summary and retrieve the values.
+  // The resulting boundaries are guaranteed both to contain at least
+  // num_boundaries unique elements and to maintain approximation bounds.
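+  // For instance, GenerateBoundaries(10) on a summary larger than 10 entries
+  // yields roughly 10 to 12 values (the size hint plus up to two retained
+  // endpoints), always including the summary's minimum and maximum.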
+ std::vector<ValueType> GenerateBoundaries(int64 num_boundaries) const {
+ std::vector<ValueType> output;
+ if (entries_.empty()) {
+ return output;
+ }
+
+ // Generate soft compressed summary.
+ WeightedQuantilesSummary<ValueType, WeightType, CompareFn>
+ compressed_summary;
+ compressed_summary.BuildFromSummaryEntries(entries_);
+    // Set an epsilon for compression that's at most 1.0 / num_boundaries
+    // more than the epsilon of our original summary, since the compression
+    // operation adds ~1.0 / num_boundaries to the final approximation error.
+ float compression_eps = ApproximationError() + (1.0 / num_boundaries);
+ compressed_summary.Compress(num_boundaries, compression_eps);
+
+ // Return boundaries.
+ output.reserve(compressed_summary.entries_.size());
+ for (const auto& entry : compressed_summary.entries_) {
+ output.push_back(entry.value);
+ }
+ return output;
+ }
+
+  // To construct the desired n-quantiles we repeatedly query n ranks from
+  // the original summary. The following algorithm is an efficient,
+  // cache-friendly O(n) implementation of that idea which avoids the
+  // O(n log n) cost of repeated full rank queries.
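+  // For example, GenerateQuantiles(4) returns 5 values approximating the
+  // minimum, the three quartiles, and the maximum of the weighted inputs.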
+ std::vector<ValueType> GenerateQuantiles(int64 num_quantiles) const {
+ std::vector<ValueType> output;
+ if (entries_.empty()) {
+ return output;
+ }
+ num_quantiles = std::max(num_quantiles, int64{2});
+ output.reserve(num_quantiles + 1);
+
+ // Make successive rank queries to get boundaries.
+ // We always keep the first (min) and last (max) entries.
+ for (size_t cur_idx = 0, rank = 0; rank <= num_quantiles; ++rank) {
+      // This step boils down to finding the element sub-range [i, i + 1]
+      // such that the desired rank d < r = (rmax[i + 1] + rmin[i + 1]) / 2;
+      // we track d_2 = 2 * d so the comparison avoids the division.
+ WeightType d_2 = 2 * (rank * entries_.back().max_rank / num_quantiles);
+ size_t next_idx = cur_idx + 1;
+ while (next_idx < entries_.size() &&
+ d_2 >= entries_[next_idx].min_rank + entries_[next_idx].max_rank) {
+ ++next_idx;
+ }
+ cur_idx = next_idx - 1;
+
+ // Determine insertion order.
+ if (next_idx == entries_.size() ||
+ d_2 < entries_[cur_idx].NextMinRank() +
+ entries_[next_idx].PrevMaxRank()) {
+ output.push_back(entries_[cur_idx].value);
+ } else {
+ output.push_back(entries_[next_idx].value);
+ }
+ }
+ return output;
+ }
+
+  // Calculates the current approximation error, which should always be
+  // <= eps.
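+  // The error is the largest rank-uncertainty gap, measured either within a
+  // single entry or between adjacent entries, normalized by total weight.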
+ double ApproximationError() const {
+ if (entries_.empty()) {
+ return 0;
+ }
+
+ WeightType max_gap = 0;
+ for (auto it = entries_.cbegin() + 1; it < entries_.end(); ++it) {
+ max_gap = std::max(max_gap,
+ std::max(it->max_rank - it->min_rank - it->weight,
+ it->PrevMaxRank() - (it - 1)->NextMinRank()));
+ }
+ return static_cast<double>(max_gap) / TotalWeight();
+ }
+
+ ValueType MinValue() const {
+ return !entries_.empty() ? entries_.front().value
+ : std::numeric_limits<ValueType>::max();
+ }
+ ValueType MaxValue() const {
+ return !entries_.empty() ? entries_.back().value
+ : std::numeric_limits<ValueType>::max();
+ }
+ WeightType TotalWeight() const {
+ return !entries_.empty() ? entries_.back().max_rank : 0;
+ }
+ int64 Size() const { return entries_.size(); }
+ void Clear() { entries_.clear(); }
+ const std::vector<SummaryEntry>& GetEntryList() const { return entries_; }
+
+ private:
+ // Comparison function.
+ static constexpr decltype(CompareFn()) kCompFn = CompareFn();
+
+ // Summary entries.
+ std::vector<SummaryEntry> entries_;
+};
+
+template <typename ValueType, typename WeightType, typename CompareFn>
+constexpr decltype(CompareFn())
+ WeightedQuantilesSummary<ValueType, WeightType, CompareFn>::kCompFn;
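+
+// Illustrative usage sketch (assuming a WeightedQuantilesBuffer `buffer`, as
+// in the accompanying test file; a sketch, not canonical API documentation):
+//
+//   WeightedQuantilesSummary<float, float> summary;
+//   summary.BuildFromBufferEntries(buffer.GenerateEntryList());
+//   summary.Compress(64);  // Grows the approximation error by at most 1/64.
+//   std::vector<float> boundaries = summary.GenerateBoundaries(10);
+//   std::vector<float> quantiles = summary.GenerateQuantiles(4);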
+
+} // namespace quantiles
+} // namespace boosted_trees
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary_test.cc b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary_test.cc
new file mode 100644
index 0000000000..ccd1215cf4
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary_test.cc
@@ -0,0 +1,223 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+using Buffer = boosted_trees::quantiles::WeightedQuantilesBuffer<float, float>;
+using BufferEntry =
+ boosted_trees::quantiles::WeightedQuantilesBuffer<float,
+ float>::BufferEntry;
+using Summary =
+ boosted_trees::quantiles::WeightedQuantilesSummary<float, float>;
+using SummaryEntry =
+ boosted_trees::quantiles::WeightedQuantilesSummary<float,
+ float>::SummaryEntry;
+
+class WeightedQuantilesSummaryTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ // Constructs a buffer of 10 weighted unique entries.
+ buffer1_.reset(new Buffer(10, 1000));
+ buffer1_->PushEntry(5, 9);
+ buffer1_->PushEntry(2, 3);
+ buffer1_->PushEntry(-1, 7);
+ buffer1_->PushEntry(-7, 1);
+ buffer1_->PushEntry(3, 2);
+ buffer1_->PushEntry(-2, 3);
+ buffer1_->PushEntry(21, 8);
+ buffer1_->PushEntry(-13, 4);
+ buffer1_->PushEntry(8, 2);
+ buffer1_->PushEntry(-5, 6);
+
+ // Constructs a buffer of 7 weighted unique entries.
+ buffer2_.reset(new Buffer(7, 1000));
+ buffer2_->PushEntry(9, 2);
+ buffer2_->PushEntry(-7, 3);
+ buffer2_->PushEntry(2, 1);
+ buffer2_->PushEntry(4, 13);
+ buffer2_->PushEntry(0, 5);
+ buffer2_->PushEntry(-5, 3);
+ buffer2_->PushEntry(11, 3);
+ }
+
+ void TearDown() override { buffer1_->Clear(); }
+
+ std::unique_ptr<Buffer> buffer1_;
+ std::unique_ptr<Buffer> buffer2_;
+ const double buffer1_min_value_ = -13;
+ const double buffer1_max_value_ = 21;
+ const double buffer1_total_weight_ = 45;
+ const double buffer2_min_value_ = -7;
+ const double buffer2_max_value_ = 11;
+ const double buffer2_total_weight_ = 30;
+};
+
+TEST_F(WeightedQuantilesSummaryTest, BuildFromBuffer) {
+ Summary summary;
+ summary.BuildFromBufferEntries(buffer1_->GenerateEntryList());
+
+ // We expect no approximation error because no compress operation occurred.
+ EXPECT_EQ(summary.ApproximationError(), 0);
+
+ // Check first and last elements in the summary.
+ const auto& entries = summary.GetEntryList();
+ // First element's rmin should be zero.
+ EXPECT_EQ(summary.MinValue(), buffer1_min_value_);
+ EXPECT_EQ(entries.front(), SummaryEntry(-13, 4, 0, 4));
+ // Last element's rmax should be cumulative weight.
+ EXPECT_EQ(summary.MaxValue(), buffer1_max_value_);
+ EXPECT_EQ(entries.back(), SummaryEntry(21, 8, 37, 45));
+ // Check total weight.
+ EXPECT_EQ(summary.TotalWeight(), buffer1_total_weight_);
+}
+
+TEST_F(WeightedQuantilesSummaryTest, CompressSeparately) {
+ const auto entry_list = buffer1_->GenerateEntryList();
+ for (int new_size = 9; new_size >= 2; --new_size) {
+ Summary summary;
+ summary.BuildFromBufferEntries(entry_list);
+ summary.Compress(new_size);
+
+    // Expect a max approximation error of 1 / n,
+    // i.e. eps0 + 1 / n, where eps0 = 0 before compression.
+ EXPECT_TRUE(summary.Size() >= new_size && summary.Size() <= new_size + 2);
+ EXPECT_LE(summary.ApproximationError(), 1.0 / new_size);
+
+ // Min/Max elements and total weight should not change.
+ EXPECT_EQ(summary.MinValue(), buffer1_min_value_);
+ EXPECT_EQ(summary.MaxValue(), buffer1_max_value_);
+ EXPECT_EQ(summary.TotalWeight(), buffer1_total_weight_);
+ }
+}
+
+TEST_F(WeightedQuantilesSummaryTest, CompressSequentially) {
+ Summary summary;
+ summary.BuildFromBufferEntries(buffer1_->GenerateEntryList());
+ for (int new_size = 9; new_size >= 2; new_size -= 2) {
+ double prev_eps = summary.ApproximationError();
+ summary.Compress(new_size);
+
+ // Expect a max approximation error of prev_eps + 1 / n.
+ EXPECT_TRUE(summary.Size() >= new_size && summary.Size() <= new_size + 2);
+ EXPECT_LE(summary.ApproximationError(), prev_eps + 1.0 / new_size);
+
+ // Min/Max elements and total weight should not change.
+ EXPECT_EQ(summary.MinValue(), buffer1_min_value_);
+ EXPECT_EQ(summary.MaxValue(), buffer1_max_value_);
+ EXPECT_EQ(summary.TotalWeight(), buffer1_total_weight_);
+ }
+}
+
+TEST_F(WeightedQuantilesSummaryTest, CompressRandomized) {
+ // Check multiple size compressions and ensure approximation bounds
+ // are always respected.
+ int prev_size = 1;
+ int size = 2;
+ float max_value = 1 << 20;
+ while (size < (1 << 16)) {
+ // Create buffer of size from uniform random elements.
+ Buffer buffer(size, size << 4);
+ random::PhiloxRandom philox(13);
+ random::SimplePhilox rand(&philox);
+ for (int i = 0; i < size; ++i) {
+ buffer.PushEntry(rand.RandFloat() * max_value,
+ rand.RandFloat() * max_value);
+ }
+
+ // Create summary and compress.
+ Summary summary;
+ summary.BuildFromBufferEntries(buffer.GenerateEntryList());
+ int new_size = std::max(rand.Uniform(size), 2u);
+ summary.Compress(new_size);
+
+ // Ensure approximation error is acceptable.
+ EXPECT_TRUE(summary.Size() >= new_size && summary.Size() <= new_size + 2);
+ EXPECT_LE(summary.ApproximationError(), 1.0 / new_size);
+
+    // Update size to the next Fibonacci number.
+ size_t last_size = size;
+ size += prev_size;
+ prev_size = last_size;
+ }
+}
+
+TEST_F(WeightedQuantilesSummaryTest, MergeSymmetry) {
+ // Create two separate summaries and merge.
+ const auto list_1 = buffer1_->GenerateEntryList();
+ const auto list_2 = buffer2_->GenerateEntryList();
+ Summary summary1;
+ summary1.BuildFromBufferEntries(list_1);
+ Summary summary2;
+ summary2.BuildFromBufferEntries(list_2);
+
+ // Merge summary 2 into 1 and verify.
+ summary1.Merge(summary2);
+ EXPECT_EQ(summary1.ApproximationError(), 0.0);
+ EXPECT_EQ(summary1.MinValue(),
+ std::min(buffer1_min_value_, buffer2_min_value_));
+ EXPECT_EQ(summary1.MaxValue(),
+ std::max(buffer1_max_value_, buffer2_max_value_));
+ EXPECT_EQ(summary1.TotalWeight(),
+ buffer1_total_weight_ + buffer2_total_weight_);
+ EXPECT_EQ(summary1.Size(), 14); // 14 unique values.
+
+ // Merge summary 1 into 2 and verify same result.
+ summary1.BuildFromBufferEntries(list_1);
+ summary2.Merge(summary1);
+ EXPECT_EQ(summary2.ApproximationError(), 0.0);
+ EXPECT_EQ(summary2.MinValue(),
+ std::min(buffer1_min_value_, buffer2_min_value_));
+ EXPECT_EQ(summary2.MaxValue(),
+ std::max(buffer1_max_value_, buffer2_max_value_));
+ EXPECT_EQ(summary2.TotalWeight(),
+ buffer1_total_weight_ + buffer2_total_weight_);
+ EXPECT_EQ(summary2.Size(), 14); // 14 unique values.
+}
+
+TEST_F(WeightedQuantilesSummaryTest, CompressThenMerge) {
+ // Create two separate summaries and merge.
+ Summary summary1;
+ summary1.BuildFromBufferEntries(buffer1_->GenerateEntryList());
+ Summary summary2;
+ summary2.BuildFromBufferEntries(buffer2_->GenerateEntryList());
+
+ // Compress summaries.
+ summary1.Compress(5); // max error is 1/5.
+ const auto eps1 = 1.0 / 5;
+ EXPECT_LE(summary1.ApproximationError(), eps1);
+ summary2.Compress(3); // max error is 1/3.
+ const auto eps2 = 1.0 / 3;
+ EXPECT_LE(summary2.ApproximationError(), eps2);
+
+ // Merge guarantees an approximation error of max(eps1, eps2).
+ // Merge summary 2 into 1 and verify.
+ summary1.Merge(summary2);
+ EXPECT_LE(summary1.ApproximationError(), std::max(eps1, eps2));
+ EXPECT_EQ(summary1.MinValue(),
+ std::min(buffer1_min_value_, buffer2_min_value_));
+ EXPECT_EQ(summary1.MaxValue(),
+ std::max(buffer1_max_value_, buffer2_max_value_));
+ EXPECT_EQ(summary1.TotalWeight(),
+ buffer1_total_weight_ + buffer2_total_weight_);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/bounds_check.h b/tensorflow/core/kernels/bounds_check.h
index c8c60c5524..18727c0db3 100644
--- a/tensorflow/core/kernels/bounds_check.h
+++ b/tensorflow/core/kernels/bounds_check.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_BOUNDS_CHECK_H_
-#define TENSORFLOW_UTIL_BOUNDS_CHECK_H_
+#ifndef TENSORFLOW_CORE_KERNELS_BOUNDS_CHECK_H_
+#define TENSORFLOW_CORE_KERNELS_BOUNDS_CHECK_H_
#include <type_traits>
@@ -51,4 +51,4 @@ EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC const T SubtleMustCopy(const T &x) {
} // namespace internal
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_BOUNDS_CHECK_H_
+#endif // TENSORFLOW_CORE_KERNELS_BOUNDS_CHECK_H_
diff --git a/tensorflow/core/kernels/broadcast_to_op.h b/tensorflow/core/kernels/broadcast_to_op.h
index 73fdd5d28e..a2327a7272 100644
--- a/tensorflow/core/kernels/broadcast_to_op.h
+++ b/tensorflow/core/kernels/broadcast_to_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_BROADCAST_TO_OP_H_
-#define TENSORFLOW_KERNELS_BROADCAST_TO_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_BROADCAST_TO_OP_H_
+#define TENSORFLOW_CORE_KERNELS_BROADCAST_TO_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
@@ -239,4 +239,4 @@ struct BroadcastTo {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_BROADCAST_TO_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_BROADCAST_TO_OP_H_
diff --git a/tensorflow/core/kernels/bucketize_op.h b/tensorflow/core/kernels/bucketize_op.h
index c8e461beb9..32be475f86 100644
--- a/tensorflow/core/kernels/bucketize_op.h
+++ b/tensorflow/core/kernels/bucketize_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_BUCKETIZE_OP_H_
-#define TENSORFLOW_BUCKETIZE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_BUCKETIZE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_BUCKETIZE_OP_H_
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -38,4 +38,4 @@ struct BucketizeFunctor {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_BUCKETIZE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_BUCKETIZE_OP_H_
diff --git a/tensorflow/core/kernels/candidate_sampler_ops.cc b/tensorflow/core/kernels/candidate_sampler_ops.cc
index 654d99301a..663bff3657 100644
--- a/tensorflow/core/kernels/candidate_sampler_ops.cc
+++ b/tensorflow/core/kernels/candidate_sampler_ops.cc
@@ -89,9 +89,9 @@ class BaseCandidateSamplerOp : public OpKernel {
// Pick sampled candidates.
auto local_gen = generator_.ReserveSamples32(samples32);
random::SimplePhilox random(&local_gen);
- sampler_->SampleBatchGetExpectedCount(&random, unique_, &sampled_candidate,
- &sampled_expected_count,
- true_candidate, &true_expected_count);
+ sampler_->SampleBatchGetExpectedCount(&random, unique_, sampled_candidate,
+ sampled_expected_count,
+ true_candidate, true_expected_count);
if (sampler_->NeedsUpdates()) {
sampler_->Update(true_candidate);
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc
index 0478c93280..3a72567655 100644
--- a/tensorflow/core/kernels/cast_op.cc
+++ b/tensorflow/core/kernels/cast_op.cc
@@ -98,7 +98,13 @@ void CastOpBase::Compute(OpKernelContext* ctx) {
ctx->set_output(0, inp);
} else {
Tensor in;
- in.UnsafeCopyFromInternal(inp, src_dtype_, inp.shape());
+ if (external_src_dtype_ != src_dtype_) {
+      // If the type is a quantized type, we need to do an
+      // UnsafeCopyFromInternal since src_dtype_ differs from
+      // external_src_dtype_.
+ in.UnsafeCopyFromInternal(inp, src_dtype_, inp.shape());
+ } else {
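+      // Otherwise a plain Tensor assignment suffices; it shares the
+      // underlying buffer without copying.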
+ in = inp;
+ }
Tensor* out = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
out->set_dtype(dst_dtype_);
diff --git a/tensorflow/core/kernels/cast_op.h b/tensorflow/core/kernels/cast_op.h
index 527ab528c9..84c44f6b5e 100644
--- a/tensorflow/core/kernels/cast_op.h
+++ b/tensorflow/core/kernels/cast_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CAST_OP_H_
-#define TENSORFLOW_KERNELS_CAST_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_H_
+#define TENSORFLOW_CORE_KERNELS_CAST_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bfloat16.h"
@@ -323,4 +323,4 @@ struct functor_traits<scalar_cast_op<float, ::tensorflow::bfloat16>> {
} // namespace internal
} // namespace Eigen
-#endif // TENSORFLOW_KERNELS_CAST_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_CAST_OP_H_
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
index 5de41bac72..e0da91125b 100644
--- a/tensorflow/core/kernels/collective_ops.cc
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -132,14 +132,19 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel {
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
+ // Allocate output on the first pass through this function. This must be
+ // done immediately, while we're still in the executor thread. Otherwise
+ // the memory is not guaranteed to be unused by any concurrently executing
+ // GPU kernel.
+ if (c->mutable_output(0) == nullptr) {
+ // Allocate the output tensor, trying to reuse the input.
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK_ASYNC(c,
+ c->forward_input_or_allocate_output(
+ {0}, 0, c->input(0).shape(), &output),
+ done);
+ }
if (!CanProceedWithCompute(c, col_exec, done)) return;
- // Allocate the output tensor, trying to reuse the input.
- Tensor* output = nullptr;
- OP_REQUIRES_OK_ASYNC(c,
- c->forward_input_or_allocate_output(
- {0}, 0, c->input(0).shape(), &output),
- done);
-
auto actual_done = [c, col_exec, done](const Status& s) {
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
@@ -183,16 +188,23 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
+ // Allocate output on the first pass through this function. This must be
+ // done immediately, while we're still in the executor thread. Otherwise
+ // the memory is not guaranteed to be unused by any concurrently executing
+ // GPU kernel.
+ if (c->mutable_output(0) == nullptr) {
+ // Allocate the output tensor, trying to reuse the input.
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK_ASYNC(
+ c, c->forward_input_or_allocate_output({0}, 0, shape_, &output),
+ done);
+ }
if (!CanProceedWithCompute(c, col_exec, done)) return;
OP_REQUIRES_ASYNC(
c, shape_.IsSameSize(c->input(0).shape()),
errors::Internal("Declared shape of op ", col_params_.name,
" does not match shape of input"),
done);
- // Allocate the output Tensor, trying to reuse the input.
- Tensor* output = nullptr;
- OP_REQUIRES_OK_ASYNC(
- c, c->forward_input_or_allocate_output({0}, 0, shape_, &output), done);
auto actual_done = [c, col_exec, done](const Status& s) {
OP_REQUIRES_OK_ASYNC(c, s, done);
@@ -239,10 +251,16 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
+ // Allocate output on the first pass through this function. This must be
+ // done immediately, while we're still in the executor thread. Otherwise
+ // the memory is not guaranteed to be unused by any concurrently executing
+ // GPU kernel.
+ if (c->mutable_output(0) == nullptr) {
+ // No input, so must allocate output.
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape_, &output), done);
+ }
if (!CanProceedWithCompute(c, col_exec, done)) return;
- // No input, so must allocate output.
- Tensor* output = nullptr;
- OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape_, &output), done);
auto actual_done = [c, col_exec, done](const Status& s) {
OP_REQUIRES_OK_ASYNC(c, s, done);
diff --git a/tensorflow/core/kernels/colorspace_op.h b/tensorflow/core/kernels/colorspace_op.h
index 90bfce1419..4de14bc339 100644
--- a/tensorflow/core/kernels/colorspace_op.h
+++ b/tensorflow/core/kernels/colorspace_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_COLORSPACE_OP_H_
-#define TENSORFLOW_KERNELS_COLORSPACE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_COLORSPACE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_COLORSPACE_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -91,4 +91,4 @@ struct HSVToRGB {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_COLORSPACE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_COLORSPACE_OP_H_
diff --git a/tensorflow/core/kernels/concat_lib.h b/tensorflow/core/kernels/concat_lib.h
index 16784c4770..8b53ecf121 100644
--- a/tensorflow/core/kernels/concat_lib.h
+++ b/tensorflow/core/kernels/concat_lib.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CONCAT_LIB_H_
-#define TENSORFLOW_KERNELS_CONCAT_LIB_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CONCAT_LIB_H_
+#define TENSORFLOW_CORE_KERNELS_CONCAT_LIB_H_
#include <vector>
@@ -66,4 +66,4 @@ void ConcatSYCL(
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CONCAT_LIB_H_
+#endif // TENSORFLOW_CORE_KERNELS_CONCAT_LIB_H_
diff --git a/tensorflow/core/kernels/concat_lib_cpu.h b/tensorflow/core/kernels/concat_lib_cpu.h
index 720b506537..29f3a427fe 100644
--- a/tensorflow/core/kernels/concat_lib_cpu.h
+++ b/tensorflow/core/kernels/concat_lib_cpu.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_CONCAT_LIB_CPU_H_
+#define TENSORFLOW_CORE_KERNELS_CONCAT_LIB_CPU_H_
+
#define EIGEN_USE_THREADS
#include <vector>
@@ -162,3 +165,5 @@ void ConcatSYCLImpl(
}
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_CONCAT_LIB_CPU_H_
diff --git a/tensorflow/core/kernels/conditional_accumulator.h b/tensorflow/core/kernels/conditional_accumulator.h
index 414891b142..390db8fe5a 100644
--- a/tensorflow/core/kernels/conditional_accumulator.h
+++ b/tensorflow/core/kernels/conditional_accumulator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_H_
-#define TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_H_
+#define TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_H_
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/typed_conditional_accumulator_base.h"
@@ -51,9 +51,11 @@ class ConditionalAccumulator
// dtype: The datatype of the gradients to be accumulated.
// shape: The shape of the accumulated gradients.
// name: A name to use for the ConditionalAccumulator.
+  // reduction_type: The reduction type, i.e., MEAN or SUM.
ConditionalAccumulator(const DataType& dtype, const PartialTensorShape& shape,
- const string& name)
- : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name) {}
+ const string& name, const string& reduction_type)
+ : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name,
+ reduction_type) {}
~ConditionalAccumulator() override{};
protected:
@@ -133,4 +135,4 @@ class ConditionalAccumulator
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_H_
diff --git a/tensorflow/core/kernels/conditional_accumulator_base.cc b/tensorflow/core/kernels/conditional_accumulator_base.cc
index 90593c56b8..292cf0cd64 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base.cc
+++ b/tensorflow/core/kernels/conditional_accumulator_base.cc
@@ -14,12 +14,17 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/conditional_accumulator_base.h"
+#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
ConditionalAccumulatorBase::ConditionalAccumulatorBase(
- const DataType& dtype, const PartialTensorShape& shape, const string& name)
- : dtype_(dtype), shape_(shape), name_(name) {
+ const DataType& dtype, const PartialTensorShape& shape, const string& name,
+ const string& reduction_type)
+ : dtype_(dtype),
+ shape_(shape),
+ name_(name),
+ reduction_type_(reduction_type) {
counter_ = 0;
current_global_step_ = 0;
}
@@ -190,7 +195,9 @@ bool ConditionalAccumulatorBase::TakeGradLockedHelper(OpKernelContext* ctx,
current_global_step_++;
// Average the accumulated gradient
- DivideAccumGradByCounter(ctx);
+ if (reduction_type_ == "MEAN") {
+ DivideAccumGradByCounter(ctx);
+ }
// Set output for accumulated gradient tensor
bool successful_set_output = SetOutput(ctx);
diff --git a/tensorflow/core/kernels/conditional_accumulator_base.h b/tensorflow/core/kernels/conditional_accumulator_base.h
index c7c7c98369..4a5ec6f0fb 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base.h
+++ b/tensorflow/core/kernels/conditional_accumulator_base.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_
-#define TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_
+#define TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_
#include <deque>
@@ -52,7 +52,7 @@ class ConditionalAccumulatorBase : public ResourceBase {
// name: A name to use for the ConditionalAccumulator.
ConditionalAccumulatorBase(const DataType& dtype,
const PartialTensorShape& shape,
- const string& name);
+ const string& name, const string& reduction_type);
typedef AsyncOpKernel::DoneCallback DoneCallback;
@@ -125,6 +125,7 @@ class ConditionalAccumulatorBase : public ResourceBase {
const DataType dtype_;
const PartialTensorShape shape_;
const string name_;
+ const string reduction_type_;
mutex mu_;
int counter_ GUARDED_BY(mu_);
int64 current_global_step_ GUARDED_BY(mu_);
@@ -199,4 +200,4 @@ class TypeConverter<Eigen::half, U> {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_
+#endif // TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_
diff --git a/tensorflow/core/kernels/conditional_accumulator_base_op.h b/tensorflow/core/kernels/conditional_accumulator_base_op.h
index 33c2d596c8..ca24d690f8 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base_op.h
+++ b/tensorflow/core/kernels/conditional_accumulator_base_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_
-#define TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_
#define EIGEN_USE_THREADS
@@ -51,6 +51,8 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
&accumulator_handle_, nullptr));
OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("reduction_type", &reduction_type_));
}
void Compute(OpKernelContext* ctx) override {
@@ -81,6 +83,7 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
DataType dtype_;
PartialTensorShape shape_;
ContainerInfo cinfo_;
+ string reduction_type_;
private:
Status SetAccumulatorHandle(OpKernelContext* ctx)
@@ -234,4 +237,4 @@ class ConditionalAccumulatorBaseTakeGradientOp
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_
diff --git a/tensorflow/core/kernels/conditional_accumulator_op.cc b/tensorflow/core/kernels/conditional_accumulator_op.cc
index e13bf8a4c6..52ac51a9b6 100644
--- a/tensorflow/core/kernels/conditional_accumulator_op.cc
+++ b/tensorflow/core/kernels/conditional_accumulator_op.cc
@@ -34,7 +34,8 @@ class ConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
Creator GetCreator() const override {
return [this](ConditionalAccumulatorBase** ret) {
ConditionalAccumulator<Device, T>* accumulator =
- new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name());
+ new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name(),
+ reduction_type_);
*ret = accumulator;
return Status::OK();
};
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index 375819a8a2..426c404f43 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -259,8 +259,9 @@ class ZerosLikeOp : public OpKernel {
errors::InvalidArgument("ZerosLike non-scalar Tensor with "
"dtype=DT_VARIANT is not supported."));
const Variant& v = input.scalar<Variant>()();
- Tensor out(ctx->device()->GetAllocator(AllocatorAttributes()), DT_VARIANT,
- TensorShape({}));
+    // DT_VARIANT tensors must be allocated on the CPU since they wrap C++
+    // objects which cannot be efficiently represented in GPU memory.
+ Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
Variant* out_v = &(out.scalar<Variant>()());
OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>(
ctx, ZEROS_LIKE_VARIANT_UNARY_OP, v, out_v));
diff --git a/tensorflow/core/kernels/control_flow_ops.h b/tensorflow/core/kernels/control_flow_ops.h
index 8edbcc9077..c607fcf298 100644
--- a/tensorflow/core/kernels/control_flow_ops.h
+++ b/tensorflow/core/kernels/control_flow_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CONTROL_FLOW_OPS_H_
-#define TENSORFLOW_KERNELS_CONTROL_FLOW_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CONTROL_FLOW_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_CONTROL_FLOW_OPS_H_
#include "tensorflow/core/framework/op_kernel.h"
@@ -115,4 +115,4 @@ class LoopCondOp : public OpKernel {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CONTROL_FLOW_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_CONTROL_FLOW_OPS_H_
diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h
index 6b7544fd4c..de9b69828e 100644
--- a/tensorflow/core/kernels/conv_2d.h
+++ b/tensorflow/core/kernels/conv_2d.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CONV_2D_H_
-#define TENSORFLOW_KERNELS_CONV_2D_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CONV_2D_H_
+#define TENSORFLOW_CORE_KERNELS_CONV_2D_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -298,4 +298,4 @@ template <>
class ConvAlgorithmMap<Eigen::ThreadPoolDevice> {};
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CONV_2D_H_
+#endif // TENSORFLOW_CORE_KERNELS_CONV_2D_H_
diff --git a/tensorflow/core/kernels/conv_3d.h b/tensorflow/core/kernels/conv_3d.h
index 083dec63cc..02e3655ad1 100644
--- a/tensorflow/core/kernels/conv_3d.h
+++ b/tensorflow/core/kernels/conv_3d.h
@@ -15,8 +15,8 @@ limitations under the License.
// Functors for 3d convolution.
-#ifndef TENSORFLOW_KERNELS_CONV_3D_H_
-#define TENSORFLOW_KERNELS_CONV_3D_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CONV_3D_H_
+#define TENSORFLOW_CORE_KERNELS_CONV_3D_H_
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/eigen_cuboid_convolution.h"
@@ -45,4 +45,4 @@ struct CuboidConvolution<CPUDevice, T> {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CONV_3D_H_
+#endif // TENSORFLOW_CORE_KERNELS_CONV_3D_H_
diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h
index 09a3b78776..adf4601b43 100644
--- a/tensorflow/core/kernels/conv_ops.h
+++ b/tensorflow/core/kernels/conv_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CONV_OPS_H_
-#define TENSORFLOW_KERNELS_CONV_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CONV_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_CONV_OPS_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/resource_mgr.h"
@@ -68,4 +68,4 @@ struct Im2ColBufferResource : public ResourceBase {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CONV_OPS_H
+#endif // TENSORFLOW_CORE_KERNELS_CONV_OPS_H_
diff --git a/tensorflow/core/kernels/cross_op.h b/tensorflow/core/kernels/cross_op.h
index ca6beba52b..45bc46a921 100644
--- a/tensorflow/core/kernels/cross_op.h
+++ b/tensorflow/core/kernels/cross_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_COLORSPACE_OP_H_
-#define TENSORFLOW_KERNELS_COLORSPACE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CROSS_OP_H_
+#define TENSORFLOW_CORE_KERNELS_CROSS_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -51,4 +51,4 @@ struct Cross {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_COLORSPACE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_CROSS_OP_H_
diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h
index b2e8ee23a9..2c30d036df 100644
--- a/tensorflow/core/kernels/cuda_solvers.h
+++ b/tensorflow/core/kernels/cuda_solvers.h
@@ -14,6 +14,9 @@ limitations under the License.
==============================================================================
*/
+#ifndef TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_
+#define TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_
+
// This header declares the class CudaSolver, which contains wrappers of linear
// algebra solvers in the cuBlas and cuSolverDN libraries for use in TensorFlow
// kernels.
@@ -433,3 +436,5 @@ inline DeviceLapackInfo CudaSolver::GetDeviceLapackInfo(
} // namespace tensorflow
#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_
diff --git a/tensorflow/core/kernels/cudnn_pooling_gpu.h b/tensorflow/core/kernels/cudnn_pooling_gpu.h
index 280d697fc2..738e928246 100644
--- a/tensorflow/core/kernels/cudnn_pooling_gpu.h
+++ b/tensorflow/core/kernels/cudnn_pooling_gpu.h
@@ -15,8 +15,8 @@ limitations under the License.
// Helper functions to run 3d pooling on GPU using CuDNN.
-#ifndef TENSORFLOW_KERNELS_CUDNN_POOLING_GPU_H_
-#define TENSORFLOW_KERNELS_CUDNN_POOLING_GPU_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CUDNN_POOLING_GPU_H_
+#define TENSORFLOW_CORE_KERNELS_CUDNN_POOLING_GPU_H_
#include <array>
@@ -67,4 +67,4 @@ class DnnPooling3dGradOp {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CUDNN_POOLING_GPU_H_
+#endif // TENSORFLOW_CORE_KERNELS_CUDNN_POOLING_GPU_H_
diff --git a/tensorflow/core/kernels/cwise_op_div.cc b/tensorflow/core/kernels/cwise_op_div.cc
index d6a2403816..313d976e2c 100644
--- a/tensorflow/core/kernels/cwise_op_div.cc
+++ b/tensorflow/core/kernels/cwise_op_div.cc
@@ -24,8 +24,7 @@ REGISTER5(BinaryOp, CPU, "TruncateDiv", functor::safe_div, uint8, uint16, int16,
int32, int64);
REGISTER6(BinaryOp, CPU, "RealDiv", functor::div, float, Eigen::half, double,
bfloat16, complex64, complex128);
-REGISTER5(BinaryOp, CPU, "UnsafeDiv", functor::unsafe_div, float, double, int16,
- int32, int64);
+REGISTER2(BinaryOp, CPU, "DivNoNan", functor::div_no_nan, float, double);
#if GOOGLE_CUDA
REGISTER9(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8,
@@ -34,6 +33,7 @@ REGISTER4(BinaryOp, GPU, "TruncateDiv", functor::div, uint8, uint16, int16,
int64);
REGISTER5(BinaryOp, GPU, "RealDiv", functor::div, float, Eigen::half, double,
complex64, complex128);
+REGISTER2(BinaryOp, GPU, "DivNoNan", functor::div_no_nan, float, double);
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
diff --git a/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc
index 0b05416274..25ccdcfb00 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc
@@ -21,6 +21,7 @@ namespace tensorflow {
namespace functor {
DEFINE_BINARY10(div, Eigen::half, float, double, uint8, uint16, int16, int32,
int64, complex64, complex128);
+DEFINE_BINARY2(div_no_nan, float, double);
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_zeta.cc b/tensorflow/core/kernels/cwise_op_zeta.cc
index 2c5538534c..dc064eec5f 100644
--- a/tensorflow/core/kernels/cwise_op_zeta.cc
+++ b/tensorflow/core/kernels/cwise_op_zeta.cc
@@ -18,4 +18,9 @@ limitations under the License.
namespace tensorflow {
REGISTER2(BinaryOp, CPU, "Zeta", functor::zeta, float, double);
REGISTER2(BinaryOp, CPU, "Polygamma", functor::polygamma, float, double);
+
+#if GOOGLE_CUDA
+REGISTER2(BinaryOp, GPU, "Zeta", functor::zeta, float, double);
+REGISTER2(BinaryOp, GPU, "Polygamma", functor::polygamma, float, double);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 1014519059..22eb66e979 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CWISE_OPS_H_
-#define TENSORFLOW_KERNELS_CWISE_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_
#include <cmath>
#include <functional>
@@ -154,8 +154,8 @@ struct functor_traits<safe_div_or_mod_op<T, DivOrMod>> {
};
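+// Computes a / b, returning 0 when the divisor b is 0 instead of an Inf or
+// NaN result (hence the name "no NaN").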
template <typename T>
-struct unsafe_div_op {
- EIGEN_EMPTY_STRUCT_CTOR(unsafe_div_op)
+struct div_no_nan_op {
+ EIGEN_EMPTY_STRUCT_CTOR(div_no_nan_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a,
const T& b) const {
if (b != 0) {
@@ -167,7 +167,7 @@ struct unsafe_div_op {
};
template <typename T>
-struct functor_traits<unsafe_div_op<T>> {
+struct functor_traits<div_no_nan_op<T>> {
enum {
Cost = functor_traits<scalar_quotient_op<T>>::Cost + NumTraits<T>::AddCost,
PacketAccess = false,
@@ -742,7 +742,7 @@ struct safe_div : base<T, Eigen::internal::safe_div_or_mod_op<
};
template <typename T>
-struct unsafe_div : base<T, Eigen::internal::unsafe_div_op<T>> {};
+struct div_no_nan : base<T, Eigen::internal::div_no_nan_op<T>> {};
template <typename T>
struct fmod : base<T, Eigen::internal::scalar_fmod_op<T>> {};
@@ -1036,4 +1036,4 @@ struct BatchSelectFunctor {
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CWISE_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_
diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h
index e32eccf547..f77d7238af 100644
--- a/tensorflow/core/kernels/cwise_ops_common.h
+++ b/tensorflow/core/kernels/cwise_ops_common.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CWISE_OPS_COMMON_H_
-#define TENSORFLOW_KERNELS_CWISE_OPS_COMMON_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_
+#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_
// See docs in ../ops/math_ops.cc.
@@ -602,4 +602,4 @@ struct ApproximateEqual<CPUDevice, T> {
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CWISE_OPS_COMMON_H_
+#endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_
diff --git a/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h
index 965e42dcce..cfae273bf4 100644
--- a/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h
+++ b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h
@@ -17,8 +17,8 @@ limitations under the License.
#error This file must only be included when building with Cuda support
#endif
-#ifndef TENSORFLOW_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
-#define TENSORFLOW_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
+#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
#define EIGEN_USE_GPU
@@ -188,4 +188,4 @@ struct ApproximateEqual<GPUDevice, T> {
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
+#endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
diff --git a/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h b/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h
index e81b840a50..15e5de0f72 100644
--- a/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h
+++ b/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h
@@ -17,8 +17,8 @@ limitations under the License.
#error This file must only be included when building with Cuda support
#endif
-#ifndef TENSORFLOW_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
-#define TENSORFLOW_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
+#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
#define EIGEN_USE_GPU
@@ -68,4 +68,4 @@ struct SimpleBinaryFunctor<GPUDevice, Functor> {
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
+#endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
diff --git a/tensorflow/core/kernels/cwise_ops_gradients.h b/tensorflow/core/kernels/cwise_ops_gradients.h
index 7a6f14babc..53b53cc277 100644
--- a/tensorflow/core/kernels/cwise_ops_gradients.h
+++ b/tensorflow/core/kernels/cwise_ops_gradients.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CWISE_OPS_GRADIENTS_H_
-#define TENSORFLOW_KERNELS_CWISE_OPS_GRADIENTS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_
+#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -208,4 +208,4 @@ struct igamma_grad_a : base<T, Eigen::internal::scalar_igamma_der_a_op<T>> {};
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CWISE_OPS_GRADIENTS_H_
+#endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 607a694dba..3a1ac73f64 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -51,6 +51,7 @@ cc_library(
hdrs = ["captured_function.h"],
deps = [
":dataset",
+ ":single_threaded_executor",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -61,6 +62,42 @@ cc_library(
)
cc_library(
+ name = "single_threaded_executor",
+ srcs = ["single_threaded_executor.cc"],
+ hdrs = ["single_threaded_executor.h"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:lib",
+ ],
+ alwayslink = 1,
+)
+
+tf_cc_test(
+ name = "single_threaded_executor_test",
+ srcs = ["single_threaded_executor_test.cc"],
+ deps = [
+ ":single_threaded_executor",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:array",
+ "//tensorflow/core/kernels:control_flow_ops",
+ "//tensorflow/core/kernels:function_ops",
+ "//tensorflow/core/kernels:math",
+ "//tensorflow/core/kernels:random_ops",
+ "//tensorflow/core/kernels:state",
+ ],
+)
+
+cc_library(
name = "window_dataset",
srcs = ["window_dataset.cc"],
hdrs = ["window_dataset.h"],
@@ -233,6 +270,16 @@ cc_library(
)
tf_kernel_library(
+ name = "parse_example_dataset_op",
+ srcs = ["parse_example_dataset_op.cc"],
+ deps = [
+ ":parallel_map_iterator",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ ],
+)
+
+tf_kernel_library(
name = "parallel_map_dataset_op",
srcs = ["parallel_map_dataset_op.cc"],
deps = [
@@ -471,8 +518,7 @@ tf_kernel_library(
":dataset",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
+ "//tensorflow/core:graph",
],
)
@@ -495,8 +541,7 @@ tf_kernel_library(
":dataset",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
+ "//tensorflow/core:graph",
],
)
@@ -668,6 +713,7 @@ tf_kernel_library(
":padded_batch_dataset_op",
":parallel_interleave_dataset_op",
":parallel_map_dataset_op",
+ ":parse_example_dataset_op",
":prefetch_dataset_op",
":random_dataset_op",
":range_dataset_op",
diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc
index f9b5353724..a25f78c6f1 100644
--- a/tensorflow/core/kernels/data/batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/batch_dataset_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/core/util/batch_util.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -241,5 +241,5 @@ REGISTER_KERNEL_BUILDER(Name("BatchDatasetV2").Device(DEVICE_CPU),
BatchDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc
index 6ca0bcd37d..221b5ad835 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level description of
@@ -891,5 +891,5 @@ REGISTER_KERNEL_BUILDER(Name("CacheDataset").Device(DEVICE_CPU),
CacheDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index 82da385405..ad2365b25b 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -23,12 +23,22 @@ limitations under the License.
#include "tensorflow/core/platform/notification.h"
namespace tensorflow {
+namespace data {
/* static */
Status CapturedFunction::Create(
const NameAttrList& func, std::vector<Tensor> captured_inputs,
std::unique_ptr<CapturedFunction>* out_function) {
- out_function->reset(new CapturedFunction(func, std::move(captured_inputs)));
+ return Create(func, std::move(captured_inputs), true, out_function);
+}
+
+/* static */
+Status CapturedFunction::Create(
+ const NameAttrList& func, std::vector<Tensor> captured_inputs,
+ bool use_inter_op_parallelism,
+ std::unique_ptr<CapturedFunction>* out_function) {
+ out_function->reset(new CapturedFunction(func, std::move(captured_inputs),
+ use_inter_op_parallelism));
return Status::OK();
}
@@ -172,31 +182,17 @@ class BorrowedArgsCallFrame : public CallFrameBase {
} // namespace
-Status CapturedFunction::MaybeInstantiate(
- IteratorContext* ctx, FunctionLibraryRuntime::Handle* out_handle) {
- mutex_lock l(mu_);
+Status CapturedFunction::GetHandle(IteratorContext* ctx,
+ FunctionLibraryRuntime::Handle* out_handle) {
+ tf_shared_lock l(mu_);
if (lib_ == nullptr) {
- // The context's runtime will be used for all subsequent calls.
- lib_ = ctx->lib();
- DCHECK(f_handle_ == kInvalidHandle);
- FunctionLibraryRuntime::InstantiateOptions inst_opts;
- inst_opts.overlay_lib = ctx->function_library().get();
- inst_opts.state_handle = std::to_string(random::New64());
- TF_RETURN_IF_ERROR(lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()),
- inst_opts, &f_handle_));
- const FunctionBody* fbody = lib_->GetFunctionBody(f_handle_);
- if (fbody == nullptr) {
- return errors::Internal("Failed to instantiate function body.");
- }
- ret_types_ = fbody->ret_types;
- } else {
- // TODO(mrry): Consider moving this under a shared lock, as it is
- // the common case.
- if (ctx->lib() != lib_) {
- return errors::Internal(
- "Captured function was called with a different "
- "FunctionLibraryRuntime*, which is not permitted.");
- }
+ return errors::Internal("Captured function \"", func_.name(),
+ "\" was called before it was instantiated.");
+ }
+ if (ctx->lib() != lib_) {
+ return errors::Internal("Captured function \"", func_.name(),
+ "\" was called with a different "
+ "FunctionLibraryRuntime*, which is not permitted.");
}
*out_handle = f_handle_;
return Status::OK();
@@ -205,7 +201,7 @@ Status CapturedFunction::MaybeInstantiate(
Status CapturedFunction::Run(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets) {
FunctionLibraryRuntime::Handle handle;
- TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle));
+ TF_RETURN_IF_ERROR(GetHandle(ctx, &handle));
FunctionLibraryRuntime::Options f_opts;
f_opts.step_id = CapturedFunction::generate_step_id();
@@ -242,7 +238,7 @@ Status CapturedFunction::RunWithBorrowedArgs(IteratorContext* ctx,
const std::vector<Tensor>& args,
std::vector<Tensor>* rets) {
FunctionLibraryRuntime::Handle handle;
- TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle));
+ TF_RETURN_IF_ERROR(GetHandle(ctx, &handle));
FunctionLibraryRuntime::Options f_opts;
f_opts.step_id = CapturedFunction::generate_step_id();
@@ -277,9 +273,33 @@ Status CapturedFunction::RunWithBorrowedArgs(IteratorContext* ctx,
}
Status CapturedFunction::Instantiate(IteratorContext* ctx) {
- FunctionLibraryRuntime::Handle unused_handle;
- TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &unused_handle));
mutex_lock l(mu_);
+ if (lib_ == nullptr) {
+ // The context's runtime will be used for all subsequent calls.
+ lib_ = ctx->lib();
+ DCHECK(f_handle_ == kInvalidHandle);
+ FunctionLibraryRuntime::InstantiateOptions inst_opts;
+ inst_opts.overlay_lib = ctx->function_library().get();
+ inst_opts.state_handle = std::to_string(random::New64());
+ inst_opts.create_kernels_eagerly = true;
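+    // Datasets that opt out of inter-op parallelism run their functions on
+    // the single-threaded executor, avoiding thread-pool dispatch overhead
+    // for small functions.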
+ if (!use_inter_op_parallelism_) {
+ inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
+ }
+ Status s = (lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()),
+ inst_opts, &f_handle_));
+ TF_RETURN_IF_ERROR(s);
+ const FunctionBody* fbody = lib_->GetFunctionBody(f_handle_);
+ if (fbody == nullptr) {
+ return errors::Internal("Failed to instantiate function body.");
+ }
+ ret_types_ = fbody->ret_types;
+ } else {
+ if (ctx->lib() != lib_) {
+ return errors::Internal(
+ "Captured function was called with a different "
+ "FunctionLibraryRuntime*, which is not permitted.");
+ }
+ }
if (captured_runner_ == nullptr) {
captured_runner_ = *ctx->runner();
}
@@ -343,7 +363,7 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
// be deleted before `done` is called. Take care not to capture `ctx` in any
// code that may execute asynchronously in this function.
FunctionLibraryRuntime::Handle handle;
- Status s = MaybeInstantiate(ctx, &handle);
+ Status s = GetHandle(ctx, &handle);
if (!s.ok()) {
done(s);
return;
@@ -391,10 +411,13 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
}
CapturedFunction::CapturedFunction(const NameAttrList& func,
- std::vector<Tensor> captured_inputs)
+ std::vector<Tensor> captured_inputs,
+ bool use_inter_op_parallelism)
: func_(func),
lib_(nullptr),
f_handle_(kInvalidHandle),
- captured_inputs_(std::move(captured_inputs)) {}
+ captured_inputs_(std::move(captured_inputs)),
+ use_inter_op_parallelism_(use_inter_op_parallelism) {}
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h
index e9ad3e381d..e44bc78b1c 100644
--- a/tensorflow/core/kernels/data/captured_function.h
+++ b/tensorflow/core/kernels/data/captured_function.h
@@ -32,6 +32,8 @@ class Device;
class OpKernelContext;
class ResourceMgr;
+namespace data {
+
// A `CapturedFunction` encapsulates a TensorFlow function and all of
// the runtime support required to execute it.
//
@@ -48,6 +50,15 @@ class CapturedFunction {
std::vector<Tensor> captured_inputs,
std::unique_ptr<CapturedFunction>* out_function);
+ // Creates a new instance from a list of named attributes and captured inputs.
+ //
+ // If `use_inter_op_parallelism` is false, the runtime may use an executor
+ // that is optimized for small functions.
+ static Status Create(const NameAttrList& func,
+ std::vector<Tensor> captured_inputs,
+ bool use_inter_op_parallelism,
+ std::unique_ptr<CapturedFunction>* out_function);
+
// Creates a new instance using a list of named attributes, fetching captured
// inputs from a context argument.
static Status Create(const NameAttrList& func, OpKernelContext* ctx,
@@ -114,10 +125,11 @@ class CapturedFunction {
private:
CapturedFunction(const NameAttrList& func,
- std::vector<Tensor> captured_inputs);
+ std::vector<Tensor> captured_inputs,
+ bool use_inter_op_parallelism);
- Status MaybeInstantiate(IteratorContext* ctx,
- FunctionLibraryRuntime::Handle* out_handle);
+ Status GetHandle(IteratorContext* ctx,
+ FunctionLibraryRuntime::Handle* out_handle);
mutex mu_;
const NameAttrList func_;
@@ -126,10 +138,17 @@ class CapturedFunction {
const std::vector<Tensor> captured_inputs_;
DataTypeSlice ret_types_;
std::function<void(std::function<void()>)> captured_runner_ = nullptr;
+ const bool use_inter_op_parallelism_;
TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction);
};
+} // namespace data
+
+// TODO(b/114112161): Remove these aliases when all users have moved over to the
+// `tensorflow::data` namespace.
+using data::CapturedFunction;
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_
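For reference, a minimal sketch of how a dataset kernel is expected to call the new `Create` overload, mirroring the `MapDatasetOp` change later in this diff (`func_` and `other_arguments` are assumed to already be populated from the kernel's attrs and inputs):

    std::unique_ptr<CapturedFunction> captured_func;
    // Passing use_inter_op_parallelism = false allows the runtime to pick
    // the single-threaded executor, avoiding inter-op threadpool overhead
    // for short per-element functions.
    OP_REQUIRES_OK(ctx, CapturedFunction::Create(
                            func_, std::move(other_arguments),
                            /*use_inter_op_parallelism=*/false,
                            &captured_func));

Kernels that do not opt out can keep calling the existing three-argument overload, which preserves the default (parallel) behavior.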
diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
index c361a9adcb..a04f150e71 100644
--- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc
+++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -195,5 +195,5 @@ REGISTER_KERNEL_BUILDER(Name("ConcatenateDataset").Device(DEVICE_CPU),
ConcatenateDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/dataset_ops.cc b/tensorflow/core/kernels/data/dataset_ops.cc
index c71d027f23..bd1ccd5b5d 100644
--- a/tensorflow/core/kernels/data/dataset_ops.cc
+++ b/tensorflow/core/kernels/data/dataset_ops.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
+namespace data {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
@@ -48,4 +49,5 @@ class DatasetToGraphOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("DatasetToGraph").Device(DEVICE_CPU),
DatasetToGraphOp);
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc
index d85ef1cbab..e7ac368ae3 100644
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -17,8 +17,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
namespace tensorflow {
-
-namespace dataset {
+namespace data {
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const std::vector<Tensor>& input_element,
@@ -45,6 +44,5 @@ Status MakeIteratorFromInputElement(
ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator);
}
-} // namespace dataset
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index 6c4191c2be..234856ea39 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -20,16 +20,14 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
-namespace dataset {
+namespace data {
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const std::vector<Tensor>& input_element,
int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
std::unique_ptr<IteratorBase>* out_iterator);
-} // namespace dataset
-
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_
diff --git a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
index 9770bc025d..237511a07d 100644
--- a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -301,5 +301,5 @@ REGISTER_KERNEL_BUILDER(Name("DenseToSparseBatchDataset").Device(DEVICE_CPU),
DenseToSparseBatchDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc
index ce577397c5..a7e3a56727 100644
--- a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -166,5 +166,5 @@ REGISTER_KERNEL_BUILDER(Name("FilterByLastComponentDataset").Device(DEVICE_CPU),
FilterByLastComponentDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc
index a80e102ccf..bf0aecaf3c 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -112,7 +112,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_graph_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
@@ -149,7 +149,9 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<FilterDatasetBase>(params) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
@@ -278,5 +280,5 @@ REGISTER_KERNEL_BUILDER(Name("FilterDataset").Device(DEVICE_CPU),
FilterDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
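This file introduces the `Initialize` pattern that recurs throughout the rest of this diff (flat_map, group_by_*, interleave, map, map_and_batch, parallel_interleave): the captured function is instantiated eagerly when the iterator is created rather than lazily on first use, so instantiation errors surface at iterator-creation time instead of on the first `GetNext()` call. Every such change has the shape:

    Status Initialize(IteratorContext* ctx) override {
      // Build the upstream iterator first, then eagerly instantiate the
      // captured function against the context's function library runtime.
      TF_RETURN_IF_ERROR(
          dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
      return dataset()->captured_func_->Instantiate(ctx);
    }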
diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
index 07bcb9d414..e3c45ef86c 100644
--- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -94,7 +94,7 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
@@ -129,7 +129,9 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
@@ -243,7 +245,7 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
private:
Status BuildCurrentElementIteratorLocked(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- return dataset::MakeIteratorFromInputElement(
+ return MakeIteratorFromInputElement(
ctx, captured_func_inputs_, element_index_++,
dataset()->captured_func_.get(), prefix(),
&current_element_iterator_);
@@ -283,5 +285,5 @@ REGISTER_KERNEL_BUILDER(Name("FlatMapDataset").Device(DEVICE_CPU),
FlatMapDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index 3c3d78b724..ac5cc1b2c1 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -19,9 +19,11 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
+namespace data {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
@@ -80,20 +82,20 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
}
}
+ Status Initialize(IteratorContext* ctx) override {
+ TF_RETURN_IF_ERROR(dataset()->init_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(dataset()->next_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(
+ dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
+ return Status::OK();
+ }
+
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
- if (!initialized_) {
- TF_RETURN_IF_ERROR(
- dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
- // Explicitly instantiate the finalize function here so that
- // we can invoke it in the destructor.
- TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx));
- initialized_ = true;
- }
-
if (finalized_) {
*end_of_sequence = true;
return Status::OK();
@@ -121,7 +123,6 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
private:
mutex mu_;
- bool initialized_ GUARDED_BY(mu_) = false;
bool finalized_ GUARDED_BY(mu_) = false;
std::vector<Tensor> state_ GUARDED_BY(mu_);
};
@@ -188,10 +189,13 @@ void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx,
std::move(finalize_func), output_types_, output_shapes_);
}
+namespace {
REGISTER_KERNEL_BUILDER(Name("GeneratorDataset").Device(DEVICE_CPU),
GeneratorDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("GeneratorDataset").Device(DEVICE_GPU).HostMemory("handle"),
GeneratorDatasetOp);
+} // namespace
+} // namespace data
} // namespace tensorflow
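The generator iterator follows the same idea, in outline:

    // Lifecycle after this change (sketch; the destructor path is outside
    // this hunk):
    //   Initialize:      instantiate the init/next/finalize functions,
    //                    then state_ = init_func_()
    //   GetNextInternal: return early if finalized_; otherwise invoke
    //                    next_func_(state_)
    //   finalize_func_:  instantiated up front so it can be invoked later
    //                    (e.g. from the destructor), guarded by finalized_

This removes the `initialized_` flag and the lazy first-call setup from `GetNextInternal`, which now only needs to handle the `finalized_` case.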
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.h b/tensorflow/core/kernels/data/generator_dataset_op.h
index 3f84fa9c2e..d23ed97ec3 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.h
+++ b/tensorflow/core/kernels/data/generator_dataset_op.h
@@ -17,9 +17,9 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_
#include "tensorflow/core/framework/dataset.h"
-#include "tensorflow/core/kernels/data/captured_function.h"
namespace tensorflow {
+namespace data {
class GeneratorDatasetOp : public DatasetOpKernel {
public:
@@ -37,5 +37,6 @@ class GeneratorDatasetOp : public DatasetOpKernel {
NameAttrList finalize_func_;
};
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_
diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
index be4132a064..d6ee42a7c6 100644
--- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -109,11 +110,10 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), key_func().name()));
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), init_func().name()));
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), reduce_func().name()));
- TF_RETURN_IF_ERROR(
- b->AddFunction(ctx->flib_def(), finalize_func().name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, key_func().name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, init_func().name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, reduce_func().name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, finalize_func().name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
@@ -190,7 +190,14 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(dataset()->captured_init_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(
+ dataset()->captured_finalize_func_->Instantiate(ctx));
+ return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx,
@@ -427,4 +434,5 @@ REGISTER_KERNEL_BUILDER(Name("GroupByReducerDataset").Device(DEVICE_CPU),
GroupByReducerDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
index 288695f3cd..e4fa557598 100644
--- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -139,10 +140,9 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), key_func_.name()));
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), reduce_func_.name()));
- TF_RETURN_IF_ERROR(
- b->AddFunction(ctx->flib_def(), window_size_func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, key_func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, reduce_func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, window_size_func_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
@@ -205,7 +205,13 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(
+ dataset()->captured_window_size_func_->Instantiate(ctx));
+ return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx,
@@ -544,4 +550,5 @@ REGISTER_KERNEL_BUILDER(Name("GroupByWindowDataset").Device(DEVICE_CPU),
GroupByWindowDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc
index 58b79d6026..0768f46665 100644
--- a/tensorflow/core/kernels/data/interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc
@@ -1,4 +1,3 @@
-
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -22,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -117,7 +116,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
Node* cycle_length_node;
@@ -156,7 +155,9 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
args_list_(params.dataset->cycle_length_) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
}
void AdvanceToNextInCycle() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
@@ -200,7 +201,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(input_impl_->GetNext(
ctx, &args_list_[cycle_index_], &end_of_input_));
if (!end_of_input_) {
- TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
+ TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
ctx, args_list_[cycle_index_], cycle_index_,
dataset()->captured_func_.get(), prefix(),
&current_elements_[cycle_index_]));
@@ -287,7 +288,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
&args_list_[idx][i]));
}
- TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
+ TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
ctx, args_list_[idx], idx, dataset()->captured_func_.get(),
prefix(), &current_elements_[idx]));
TF_RETURN_IF_ERROR(
@@ -329,5 +330,5 @@ REGISTER_KERNEL_BUILDER(Name("InterleaveDataset").Device(DEVICE_CPU),
InterleaveDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 61a6c06135..fe6d705eab 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -36,7 +36,7 @@ limitations under the License.
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -104,9 +104,8 @@ class IteratorResource : public ResourceBase {
bool* end_of_sequence) {
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
if (captured_iterator) {
- if (lib_ != nullptr) {
- ctx->set_lib(lib_);
- }
+ CHECK_NOTNULL(lib_);
+ ctx->set_lib(lib_);
return captured_iterator->GetNext(ctx, out_tensors, end_of_sequence);
} else {
return errors::FailedPrecondition(
@@ -162,8 +161,10 @@ class IteratorResource : public ResourceBase {
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
std::unique_ptr<IteratorBase> iterator;
+ IteratorContext iter_ctx(ctx);
+ iter_ctx.set_lib(lib);
TF_RETURN_IF_ERROR(
- dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iterator));
+ dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
TF_RETURN_IF_ERROR(set_iterator(std::move(iterator)));
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
@@ -198,6 +199,8 @@ class IteratorResource : public ResourceBase {
return lib_def_;
}
+ FunctionLibraryRuntime* function_library_runtime() { return lib_; }
+
// Transfers ownership of iterator to this. This method is thread-safe.
Status set_iterator(std::unique_ptr<IteratorBase> iterator) {
if (iterator) {
@@ -233,6 +236,8 @@ class IteratorResource : public ResourceBase {
const std::vector<PartialTensorShape> output_shapes_;
};
+namespace {
+
// Helper class for reading data from a VariantTensorData object.
class VariantTensorDataReader : public IteratorStateReader {
public:
@@ -258,7 +263,7 @@ class VariantTensorDataReader : public IteratorStateReader {
}
bool Contains(StringPiece key) override {
- return map_.find(key.ToString()) != map_.end();
+ return map_.find(string(key)) != map_.end();
}
private:
@@ -279,18 +284,18 @@ class VariantTensorDataReader : public IteratorStateReader {
template <typename T>
Status ReadScalarInternal(StringPiece key, T* val) {
- if (map_.find(key.ToString()) == map_.end()) {
+ if (map_.find(string(key)) == map_.end()) {
return errors::NotFound(key);
}
- *val = data_->tensors(map_[key.ToString()]).scalar<T>()();
+ *val = data_->tensors(map_[string(key)]).scalar<T>()();
return Status::OK();
}
Status ReadTensorInternal(StringPiece key, Tensor* val) {
- if (map_.find(key.ToString()) == map_.end()) {
+ if (map_.find(string(key)) == map_.end()) {
return errors::NotFound(key);
}
- *val = data_->tensors(map_[key.ToString()]);
+ *val = data_->tensors(map_[string(key)]);
return Status::OK();
}
@@ -339,7 +344,7 @@ class VariantTensorDataWriter : public IteratorStateWriter {
// Write key to the metadata proto. This gets written to `data_`
// when `Flush()` is called. We do this lazily to avoid multiple
// serialization calls.
- metadata_proto_.add_keys(key.ToString());
+ metadata_proto_.add_keys(string(key));
// Update tensors.
*(data_->add_tensors()) = val;
@@ -440,6 +445,8 @@ class IteratorStateVariant {
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
kIteratorVariantTypeName);
+} // namespace
+
// Note that IteratorHandleOp holds a reference to the resource it creates. If
// cleaning up resources with DestroyResourceOp is important, consider creating
// resource containers with AnonymousIteratorHandleOp instead.
@@ -612,11 +619,15 @@ void MakeIteratorOp::Compute(OpKernelContext* ctx) {
core::ScopedUnref unref(iterator_resource);
std::unique_ptr<IteratorBase> iterator;
+ IteratorContext iter_ctx(ctx);
+ iter_ctx.set_lib(iterator_resource->function_library_runtime());
OP_REQUIRES_OK(
- ctx, dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iterator));
+ ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator)));
}
+namespace {
+
class ToSingleElementOp : public AsyncOpKernel {
public:
explicit ToSingleElementOp(OpKernelConstruction* ctx)
@@ -837,8 +848,10 @@ class OneShotIteratorOp : public AsyncOpKernel {
DatasetBase* dataset;
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
std::unique_ptr<IteratorBase> iter;
+ IteratorContext iter_ctx(ctx);
+ iter_ctx.set_lib(lib);
TF_RETURN_IF_ERROR(
- dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iter));
+ dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iter));
TF_RETURN_IF_ERROR((*iterator)->set_iterator(std::move(iter)));
(*iterator)->Ref();
@@ -880,6 +893,8 @@ class OneShotIteratorOp : public AsyncOpKernel {
const int graph_def_version_;
};
+} // namespace
+
void IteratorGetNextOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
IteratorResource* iterator;
OP_REQUIRES_OK_ASYNC(
@@ -922,39 +937,35 @@ void IteratorGetNextOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
std::move(done)));
}
-class IteratorGetNextSyncOp : public OpKernel {
- public:
- explicit IteratorGetNextSyncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
- void Compute(OpKernelContext* ctx) override {
- IteratorResource* iterator;
- OP_REQUIRES_OK(ctx,
- LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
- core::ScopedUnref unref_iterator(iterator);
-
- std::vector<Tensor> components;
- bool end_of_sequence = false;
-
- IteratorContext::Params params;
- params.env = ctx->env();
- params.runner = *(ctx->runner());
- params.function_library = iterator->function_library();
- DeviceBase* device = ctx->function_library()->device();
- params.allocator_getter = [device](AllocatorAttributes attrs) {
- return device->GetAllocator(attrs);
- };
- IteratorContext iter_ctx(std::move(params));
+void IteratorGetNextSyncOp::Compute(OpKernelContext* ctx) {
+ IteratorResource* iterator;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
+ core::ScopedUnref unref_iterator(iterator);
+
+ std::vector<Tensor> components;
+ bool end_of_sequence = false;
+
+ IteratorContext::Params params;
+ params.env = ctx->env();
+ params.runner = *(ctx->runner());
+ params.function_library = iterator->function_library();
+ DeviceBase* device = ctx->function_library()->device();
+ params.allocator_getter = [device](AllocatorAttributes attrs) {
+ return device->GetAllocator(attrs);
+ };
+ IteratorContext iter_ctx(std::move(params));
- OP_REQUIRES_OK(ctx,
- iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
- OP_REQUIRES(ctx, !end_of_sequence, errors::OutOfRange("End of sequence"));
+ OP_REQUIRES_OK(ctx,
+ iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
+ OP_REQUIRES(ctx, !end_of_sequence, errors::OutOfRange("End of sequence"));
- for (int i = 0; i < components.size(); ++i) {
- // TODO(mrry): Check that the shapes match the shape attrs.
- ctx->set_output(i, components[i]);
- }
+ for (int i = 0; i < components.size(); ++i) {
+ // TODO(mrry): Check that the shapes match the shape attrs.
+ ctx->set_output(i, components[i]);
}
-};
+}
+
+namespace {
class IteratorGetNextAsOptionalOp : public AsyncOpKernel {
public:
@@ -1036,6 +1047,8 @@ class IteratorGetNextAsOptionalOp : public AsyncOpKernel {
std::vector<PartialTensorShape> output_shapes_;
};
+} // namespace
+
void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) {
const Tensor& resource_handle_t = ctx->input(0);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
@@ -1107,6 +1120,8 @@ void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) {
resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
}
+namespace {
+
class SerializeIteratorOp : public OpKernel {
public:
explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
@@ -1201,4 +1216,7 @@ REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
DeserializeIteratorOp);
+} // namespace
+
+} // namespace data
} // namespace tensorflow
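Three call sites in this file (iterator restore, `MakeIteratorOp`, and `OneShotIteratorOp`) now plumb an explicit `FunctionLibraryRuntime` into the `IteratorContext` before building an iterator; together with the new `CHECK_NOTNULL(lib_)` in `GetNext`, the resource no longer silently falls back to whatever library the calling op happened to carry. The common shape is:

    IteratorContext iter_ctx(ctx);
    // Attach the runtime that owns the iterator's functions so that nested
    // datasets instantiate against the same library.
    iter_ctx.set_lib(lib);
    TF_RETURN_IF_ERROR(
        dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));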
diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h
index e426febcce..8a2b2639a7 100644
--- a/tensorflow/core/kernels/data/iterator_ops.h
+++ b/tensorflow/core/kernels/data/iterator_ops.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
namespace tensorflow {
+namespace data {
class IteratorResource;
@@ -116,6 +117,13 @@ class IteratorGetNextOp : public AsyncOpKernel {
BackgroundWorker background_worker_;
};
+class IteratorGetNextSyncOp : public OpKernel {
+ public:
+ explicit IteratorGetNextSyncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override;
+};
+
class IteratorToStringHandleOp : public OpKernel {
public:
explicit IteratorToStringHandleOp(OpKernelConstruction* ctx)
@@ -135,6 +143,7 @@ class IteratorFromStringHandleOp : public OpKernel {
std::vector<PartialTensorShape> output_shapes_;
};
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 0e17011b05..27c89b3661 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -29,7 +29,7 @@ limitations under the License.
#include "tensorflow/core/platform/tracing.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -147,7 +147,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), map_fn_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, map_fn_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* batch_size_node;
@@ -204,7 +204,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
@@ -673,5 +675,5 @@ REGISTER_KERNEL_BUILDER(Name("MapAndBatchDatasetV2").Device(DEVICE_CPU),
MapAndBatchDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index 294fb1c49a..af301e2b42 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -28,12 +28,12 @@ namespace {
class MapDatasetOp : public UnaryDatasetOpKernel {
public:
- explicit MapDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ explicit MapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism",
+ &use_inter_op_parallelism_));
}
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
@@ -48,7 +48,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ func_, std::move(other_arguments),
+ use_inter_op_parallelism_, &captured_func));
*output = new Dataset(ctx, input, func_, std::move(captured_func),
output_types_, output_shapes_);
@@ -92,7 +93,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
@@ -127,7 +128,9 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
@@ -181,14 +184,14 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList func_;
+ bool use_inter_op_parallelism_;
};
REGISTER_KERNEL_BUILDER(Name("MapDataset").Device(DEVICE_CPU), MapDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc
index d66716ef66..b87d61ee44 100644
--- a/tensorflow/core/kernels/data/map_defun_op.cc
+++ b/tensorflow/core/kernels/data/map_defun_op.cc
@@ -18,18 +18,20 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/batch_util.h"
#include "tensorflow/core/util/reffed_status_callback.h"
namespace tensorflow {
+namespace data {
namespace {
void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts,
bool always_collect_stats) {
opts->step_id = ctx->step_id();
opts->rendezvous = ctx->rendezvous();
- opts->cancellation_manager = ctx->cancellation_manager();
if (always_collect_stats) {
opts->stats_collector = ctx->stats_collector();
}
@@ -60,22 +62,43 @@ class MapDefunOp : public AsyncOpKernel {
~MapDefunOp() override {}
+ Status GetInputBatchSize(OpKernelContext* ctx, int64* batch_size) {
+ // Validates inputs and gets the size of their leading dimension.
+ *batch_size = ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
+ for (size_t i = 0; i < ctx->num_inputs(); ++i) {
+ if (ctx->input(i).dims() == 0) {
+ return errors::InvalidArgument(
+ "All inputs must have rank at least 1. Input ", i,
+ " has a rank of 0.");
+ } else if (ctx->input(i).dim_size(0) != *batch_size) {
+ return errors::InvalidArgument(
+ "All inputs must have the same dimension 0. Input ", i,
+ " has leading dimension ", ctx->input(i).dim_size(0),
+ ", while all previous inputs have leading dimension ", batch_size);
+ }
+ }
+ return Status::OK();
+ }
+
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
- int64 batch_size = ctx->input(0).dim_size(0);
+ int64 batch_size;
+ OP_REQUIRES_OK_ASYNC(ctx, GetInputBatchSize(ctx, &batch_size), done);
+
// Inputs
auto* args = new std::vector<Tensor>;
auto* arg_shapes = new std::vector<TensorShape>;
+
+ // Create a copy because every `Compute` may have different output shapes.
+ auto* output_shapes = new std::vector<PartialTensorShape>(output_shapes_);
arg_shapes->reserve(ctx->num_inputs());
args->reserve(ctx->num_inputs());
+ auto* mu = new mutex;
+
for (size_t i = 0; i < ctx->num_inputs(); ++i) {
args->push_back(ctx->input(i));
arg_shapes->push_back(ctx->input(i).shape());
arg_shapes->at(i).RemoveDim(0); // Remove the first batch dimension
- OP_REQUIRES_ASYNC(
- ctx, batch_size == ctx->input(i).dim_size(0),
- errors::InvalidArgument("All inputs must have the same dimension 0."),
- done);
}
// Outputs
@@ -83,10 +106,14 @@ class MapDefunOp : public AsyncOpKernel {
OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("output", output), done);
for (size_t i = 0; i < output_types().size(); ++i) {
- Tensor* out = nullptr;
- TensorShape output_shape = output_shapes_.at(i);
- output_shape.InsertDim(0, batch_size);
- OP_REQUIRES_OK_ASYNC(ctx, output->allocate(i, output_shape, &out), done);
+ if (output_shapes_.at(i).IsFullyDefined()) {
+ Tensor* out = nullptr;
+ TensorShape output_shape;
+ output_shapes_.at(i).AsTensorShape(&output_shape);
+ output_shape.InsertDim(0, batch_size);
+ OP_REQUIRES_OK_ASYNC(ctx, output->allocate(i, output_shape, &out),
+ done);
+ }
}
SetRunOptions(ctx, &opts_, false);
@@ -94,15 +121,19 @@ class MapDefunOp : public AsyncOpKernel {
// Run loop
StatusCallback callback = std::bind(
[](OpKernelContext* ctx, std::vector<Tensor>* args,
- std::vector<TensorShape>* arg_shapes, OpOutputList* output,
- DoneCallback& done, const Status& status) {
+ std::vector<TensorShape>* arg_shapes,
+ std::vector<PartialTensorShape>* output_shapes, OpOutputList* output,
+ mutex* mu, DoneCallback& done, const Status& status) {
delete args;
delete arg_shapes;
delete output;
+ delete output_shapes;
+ delete mu;
ctx->SetStatus(status);
done();
},
- ctx, args, arg_shapes, output, std::move(done), std::placeholders::_1);
+ ctx, args, arg_shapes, output_shapes, output, mu, std::move(done),
+ std::placeholders::_1);
auto* refcounted = new ReffedStatusCallback(std::move(callback));
@@ -110,13 +141,18 @@ class MapDefunOp : public AsyncOpKernel {
// Start from i = 1 because refcounted is initialized with refcount = 1
refcounted->Ref();
}
+
for (size_t i = 0; i < static_cast<size_t>(batch_size); ++i) {
- auto* call_frame =
- new MapFunctionCallFrame(*args, *arg_shapes, output, this, i);
+ auto* call_frame = new MapFunctionCallFrame(
+ *args, *arg_shapes, output_shapes, mu, output, this, i,
+ static_cast<size_t>(batch_size));
+ CancellationManager* c_mgr = new CancellationManager;
+ opts_.cancellation_manager = c_mgr;
ctx->function_library()->Run(
opts_, func_handle_, call_frame,
- [call_frame, refcounted](const Status& func_status) {
+ [call_frame, refcounted, c_mgr](const Status& func_status) {
delete call_frame;
+ delete c_mgr;
refcounted->UpdateStatus(func_status);
refcounted->Unref();
});
@@ -126,18 +162,23 @@ class MapDefunOp : public AsyncOpKernel {
private:
FunctionLibraryRuntime::Handle func_handle_;
FunctionLibraryRuntime::Options opts_;
- std::vector<TensorShape> output_shapes_;
+ std::vector<PartialTensorShape> output_shapes_;
class MapFunctionCallFrame : public CallFrameInterface {
public:
MapFunctionCallFrame(const std::vector<Tensor>& args,
const std::vector<TensorShape>& arg_shapes,
- OpOutputList* output, OpKernel* kernel, size_t iter)
+ std::vector<PartialTensorShape>* output_shapes,
+ mutex* output_shapes_mutex, OpOutputList* output,
+ OpKernel* kernel, size_t iter, size_t batch_size)
: args_(args),
arg_shapes_(arg_shapes),
+ output_shapes_(output_shapes),
+ output_shapes_mutex_(output_shapes_mutex),
output_(output),
kernel_(kernel),
- iter_(iter) {}
+ iter_(iter),
+ batch_size_(batch_size) {}
~MapFunctionCallFrame() override {}
@@ -175,18 +216,41 @@ class MapDefunOp : public AsyncOpKernel {
"output: ",
index);
}
+ { // Locking scope
+ mutex_lock l(*output_shapes_mutex_);
+ if (!output_shapes_->at(index).IsCompatibleWith(val.shape())) {
+ return errors::InvalidArgument(
+ "Mismatch in function retval shape, ", val.shape(),
+ ", and expected output shape,",
+ output_shapes_->at(index).DebugString(), ".");
+ }
+ if (!output_shapes_->at(index).IsFullyDefined()) {
+ // Given val, we have new information about the output shape at
+ // this index. Store the shape and allocate the output accordingly.
+ output_shapes_->at(index) = val.shape();
+
+ Tensor* out = nullptr;
+ TensorShape actual_shape = val.shape();
+ actual_shape.InsertDim(0, batch_size_);
+ TF_RETURN_IF_ERROR(output_->allocate(index, actual_shape, &out));
+ }
+ }
return batch_util::CopyElementToSlice(val, (*output_)[index], iter_);
}
private:
const std::vector<Tensor>& args_;
const std::vector<TensorShape>& arg_shapes_;
+ std::vector<PartialTensorShape>* output_shapes_;
+ mutex* output_shapes_mutex_;
OpOutputList* output_;
const OpKernel* kernel_;
const size_t iter_;
+ const size_t batch_size_;
};
-}; // namespace
+};
REGISTER_KERNEL_BUILDER(Name("MapDefun").Device(DEVICE_CPU), MapDefunOp);
} // namespace
+} // namespace data
} // namespace tensorflow
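With `output_shapes_` now a `std::vector<PartialTensorShape>`, fully defined outputs are still allocated up front, while outputs with unknown shape are allocated by whichever function invocation first reports a concrete result shape; later invocations must produce a compatible shape. A minimal standalone sketch of this first-writer-allocates protocol (illustrative only, not TensorFlow API):

    #include <mutex>
    #include <optional>
    #include <vector>

    struct SharedOutputSlot {
      std::mutex mu;
      // Unknown until the first invocation reports its result shape.
      std::optional<std::vector<long long>> shape;
    };

    // Returns false if `observed` disagrees with the recorded shape. The
    // first caller to record a shape would also allocate the batched
    // output buffer (allocation elided here).
    bool RecordResultShape(SharedOutputSlot* slot,
                           const std::vector<long long>& observed) {
      std::lock_guard<std::mutex> lock(slot->mu);
      if (!slot->shape.has_value()) {
        slot->shape = observed;  // first writer fixes the shape
        return true;
      }
      return *slot->shape == observed;  // later writers must match
    }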
diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc
index b097598cd9..d5b725eac9 100644
--- a/tensorflow/core/kernels/data/optimize_dataset_op.cc
+++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -92,24 +93,35 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
DatasetGraphDefBuilder db(&b);
Node* input_node = nullptr;
SerializationContext::Params params;
+ std::vector<std::pair<string, Tensor>> input_list;
+ params.allow_stateful_functions = true;
params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
+ params.input_list = &input_list;
SerializationContext serialization_ctx(params);
TF_RETURN_IF_ERROR(
db.AddInputDataset(&serialization_ctx, input_, &input_node));
string output_node = input_node->name();
+
GraphDef graph_def;
TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
VLOG(3) << "Before optimization: " << graph_def.DebugString();
+
TF_RETURN_IF_ERROR(ApplyOptimizations(ctx, &graph_def, &output_node));
VLOG(3) << "After optimization: " << graph_def.DebugString();
- flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
- graph_def.library()));
+
+ // Instantiate the optimized input pipeline by running the optimized graph
+ // using the optimized function library.
+ TF_RETURN_IF_ERROR(
+ ctx->function_library()->Clone(&flib_def_, &pflr_, &lib_));
+ TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph_def.library()));
+
Graph graph(OpRegistry::Global());
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
std::vector<Tensor> outputs;
GraphRunner graph_runner(ctx->function_library()->device());
- TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {},
- {output_node}, &outputs));
+
+ TF_RETURN_IF_ERROR(
+ graph_runner.Run(&graph, lib_, input_list, {output_node}, &outputs));
TF_RETURN_IF_ERROR(
GetDatasetFromVariantTensor(outputs[0], &optimized_input_));
optimized_input_->Ref();
@@ -142,8 +154,14 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->optimized_input_->MakeIterator(ctx, prefix(),
- &input_impl_);
+ IteratorContext::Params params;
+ params.env = ctx->env();
+ params.runner = *(ctx->runner());
+ params.stats_aggregator_getter = ctx->stats_aggregator_getter();
+ params.lib = dataset()->lib_;
+ params.allocator_getter = ctx->allocator_getter();
+ return dataset()->optimized_input_->MakeIterator(
+ IteratorContext(params), prefix(), &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
@@ -153,8 +171,7 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
params.env = ctx->env();
params.runner = *(ctx->runner());
params.stats_aggregator_getter = ctx->stats_aggregator_getter();
- params.lib = ctx->lib();
- params.function_library = dataset()->flib_def_;
+ params.lib = dataset()->lib_;
params.allocator_getter = ctx->allocator_getter();
IteratorContext iter_ctx(params);
return input_impl_->GetNext(&iter_ctx, out_tensors, end_of_sequence);
@@ -236,7 +253,9 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
}
DatasetBase* optimized_input_;
- std::shared_ptr<FunctionLibraryDefinition> flib_def_;
+ FunctionLibraryRuntime* lib_ = nullptr;
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_ = nullptr;
+ std::unique_ptr<FunctionLibraryDefinition> flib_def_ = nullptr;
const DatasetBase* input_;
const std::vector<string> optimizations_;
const DataTypeVector output_types_;
@@ -252,4 +271,5 @@ REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU),
OptimizeDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
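The substantive change here is how the optimized pipeline is instantiated: instead of running the rewritten graph against the caller's `FunctionLibraryRuntime` and stashing a side `FunctionLibraryDefinition`, the runtime is cloned, the rewritten function library is merged into the clone, and both `Initialize` and `GetNextInternal` hand downstream iterators that clone. In outline, as the hunk above does:

    // Clone the caller's runtime, then extend the clone with the functions
    // produced by the graph rewrite so the optimized pipeline resolves its
    // functions against a self-contained library.
    TF_RETURN_IF_ERROR(
        ctx->function_library()->Clone(&flib_def_, &pflr_, &lib_));
    TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph_def.library()));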
diff --git a/tensorflow/core/kernels/data/optional_ops.cc b/tensorflow/core/kernels/data/optional_ops.cc
index cfac45dbc7..b372d31a93 100644
--- a/tensorflow/core/kernels/data/optional_ops.cc
+++ b/tensorflow/core/kernels/data/optional_ops.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/variant_op_registry.h"
namespace tensorflow {
+namespace data {
namespace {
const char kOptionalVariantTypeName[] = "tensorflow::data::Optional";
@@ -267,4 +268,5 @@ Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index) {
return Status::OK();
}
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/optional_ops.h b/tensorflow/core/kernels/data/optional_ops.h
index 6f25567678..2cbf2933f5 100644
--- a/tensorflow/core/kernels/data/optional_ops.h
+++ b/tensorflow/core/kernels/data/optional_ops.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/variant_tensor_data.h"
namespace tensorflow {
+namespace data {
// Stores a DT_VARIANT value representing an Optional with the given value
// in the `output_index`^th output of the given kernel execution context.
@@ -31,6 +32,7 @@ Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
// in the `output_index`^th output of the given kernel execution context.
Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index);
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_
diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
index be45eac46e..fd0e6c4cd0 100644
--- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/core/util/batch_util.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -382,5 +382,5 @@ REGISTER_KERNEL_BUILDER(Name("PaddedBatchDatasetV2").Device(DEVICE_CPU),
PaddedBatchDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index e492a8215a..640f1565b7 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <deque>
+#include <utility>
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
@@ -21,11 +22,12 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -34,8 +36,7 @@ namespace {
class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
public:
explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
@@ -125,6 +126,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
const DataTypeVector& output_dtypes() const override {
return output_types_;
}
+
const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_;
}
@@ -137,8 +139,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(
- b->AddFunction(ctx->flib_def(), interleave_func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name()));
Node* input_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
Node* cycle_length_node;
@@ -251,7 +252,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
}
// It is implemented so that it matches the deterministic interleave
@@ -279,7 +282,12 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (!current_worker->outputs.empty()) {
// We have an element!
next_index_ = index;
- if (i == 0) {
+ const bool element_acquired_sloppily =
+ dataset()->sloppy_ && i > 1;
+ if (!element_acquired_sloppily) {
+ // If the element was acquired in the regular (non-sloppy)
+ // order, then advance the current block and cycle pointers to
+ // the next element in the regular order.
block_count_++;
if (block_count_ == dataset()->block_length_) {
next_index_ = (index + 1) % interleave_indices_.size();
@@ -678,7 +686,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
{
tf_shared_lock l(ckpt_mu_);
worker_thread_states_[thread_index].iterator_creation_status =
- dataset::MakeIteratorFromInputElement(
+ MakeIteratorFromInputElement(
ctx.get(), worker_thread_states_[thread_index].input,
thread_index, dataset()->captured_func_.get(), prefix(),
&worker_thread_states_[thread_index].iterator);
@@ -908,7 +916,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
worker_thread_states_[index].iterator.reset();
} else {
std::unique_ptr<IteratorBase> iterator;
- Status s = dataset::MakeIteratorFromInputElement(
+ Status s = MakeIteratorFromInputElement(
ctx, worker_thread_states_[index].input, index,
dataset()->captured_func_.get(), prefix(), &iterator);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator));
@@ -1052,7 +1060,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList interleave_func_;
@@ -1061,6 +1068,593 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
ParallelInterleaveDatasetOp);
-} // namespace
+// The motivation for creating an alternative implementation of parallel
+// interleave is to decouple the degree of parallelism from the cycle length.
+// This makes it possible to change the degree of parallelism (e.g. through
+// auto-tuning) without changing the cycle length (which would change the order
+// in which elements are produced).
+//
+// Furthermore, this class favors modularity over extended functionality. In
+// particular, it refrains from implementing configurable buffering of output
+// elements and prefetching of input iterators, relying on other parts of
+// tf.data to provide this functionality if necessary.
+//
+// The above design choices were made with automated optimizations in mind,
+// isolating the degree of parallelism as the single tunable knob of this
+// implementation.
+class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
+ public:
+ explicit ParallelInterleaveDatasetV2Op(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ OpInputList inputs;
+ OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
+
+ int64 cycle_length = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument(ctx, "cycle_length", &cycle_length));
+ OP_REQUIRES(ctx, cycle_length > 0,
+ errors::InvalidArgument("`cycle_length` must be > 0"));
+
+ int64 block_length = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument(ctx, "block_length", &block_length));
+ OP_REQUIRES(ctx, block_length > 0,
+ errors::InvalidArgument("`block_length` must be > 0"));
+
+ int64 num_parallel_calls;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
+ &num_parallel_calls));
+ OP_REQUIRES(ctx, num_parallel_calls > 0,
+ errors::InvalidArgument(
+ "num_parallel_calls must be greater than zero."));
+ OP_REQUIRES(
+ ctx, num_parallel_calls <= cycle_length,
+ errors::InvalidArgument(
+ "num_parallel_calls must less than or equal to cycle_length."));
+
+ // TODO(b/114267189): Use `other_arguments(inputs.begin(), inputs.end());`.
+ std::vector<Tensor> other_arguments;
+ other_arguments.reserve(inputs.size());
+ for (const Tensor& t : inputs) {
+ other_arguments.push_back(t);
+ }
+ std::unique_ptr<CapturedFunction> captured_func;
+ OP_REQUIRES_OK(
+ ctx, CapturedFunction::Create(
+ interleave_func_, std::move(other_arguments), &captured_func));
+
+ *output = new Dataset(ctx, input, interleave_func_,
+ std::move(captured_func), cycle_length, block_length,
+ num_parallel_calls, output_types_, output_shapes_);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const NameAttrList& func,
+ std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
+ int64 block_length, int64 num_parallel_calls,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ interleave_func_(func),
+ captured_func_(std::move(captured_func)),
+ cycle_length_(cycle_length),
+ block_length_(block_length),
+ num_parallel_calls_(num_parallel_calls),
+ output_types_(output_types),
+ output_shapes_(output_shapes) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::ParallelInterleaveV2")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override {
+ return "ParallelInterleaveDatasetV2Op::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name()));
+ Node* input_node;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
+ Node* cycle_length_node;
+ TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
+ Node* block_length_node;
+ TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
+ Node* num_parallel_calls_node;
+ TF_RETURN_IF_ERROR(
+ b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
+ DataTypeVector other_arguments_types;
+ other_arguments_types.reserve(captured_func_->captured_inputs().size());
+ std::vector<Node*> other_arguments;
+ other_arguments.reserve(captured_func_->captured_inputs().size());
+ for (const Tensor& t : captured_func_->captured_inputs()) {
+ Node* node;
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ other_arguments.emplace_back(node);
+ other_arguments_types.emplace_back(t.dtype());
+ }
+ AttrValue f;
+ b->BuildAttrValue(interleave_func_, &f);
+ AttrValue other_arguments_types_attr;
+ b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
+
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this,
+ {{0, input_node},
+ {2, cycle_length_node},
+ {3, block_length_node},
+ {4, num_parallel_calls_node}},
+ {{1, other_arguments}},
+ {{"f", f}, {"Targuments", other_arguments_types_attr}}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ args_list_(params.dataset->cycle_length_),
+ current_elements_(params.dataset->cycle_length_),
+ element_in_use_(params.dataset->cycle_length_, false),
+ thread_pool_(new thread::ThreadPool(
+ Env::Default(), ThreadOptions(), "parallel_interleave",
+ dataset()->cycle_length_ /* num_threads */,
+ false /* low_latency_hint */)) {}
+
+ ~Iterator() override {
+ mutex_lock l(mu_);
+ // Cancel the runner thread.
+ cancelled_ = true;
+ cond_var_.notify_all();
+ // Wait for all in-flight calls to complete.
+ while (num_calls_ > 0) {
+ cond_var_.wait(l);
+ }
+ }
+
+ Status Initialize(IteratorContext* ctx) override {
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ std::shared_ptr<InvocationResult> result;
+ do {
+ {
+ mutex_lock l(mu_);
+ EnsureRunnerThreadStarted(ctx);
+ while (invocation_results_.empty() &&
+ (!end_of_input_ || num_open_ > 0)) {
+ cond_var_.wait(l);
+ }
+ if (!invocation_results_.empty()) {
+ std::swap(result, invocation_results_.front());
+ invocation_results_.pop_front();
+ } else {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ }
+ cond_var_.notify_all();
+ result->notification.WaitForNotification();
+ } while (result->skip);
+
+ if (result->status.ok()) {
+ *out_tensors = std::move(result->return_values);
+ }
+ *end_of_sequence = false;
+ return result->status;
+ }
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ // Wait for all in-flight calls to complete.
+ while (num_calls_ > 0) {
+ cond_var_.wait(l);
+ }
+ CHECK_EQ(num_calls_, 0);
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name("invocation_results.size"), invocation_results_.size()));
+ for (size_t i = 0; i < invocation_results_.size(); i++) {
+ std::shared_ptr<InvocationResult> result = invocation_results_[i];
+ TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("invocation_results[", i, "].size")),
+ result->return_values.size()));
+ for (size_t j = 0; j < result->return_values.size(); j++) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(
+ strings::StrCat("invocation_results[", i, "][", j, "]")),
+ result->return_values[j]));
+ }
+ if (result->skip) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("invocation_results[", i, "].skip")),
+ ""));
+ }
+ }
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("cycle_index"), cycle_index_));
+ if (end_of_input_) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("end_of_input"), ""));
+ }
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("num_open"), num_open_));
+ TF_RETURN_IF_ERROR(WriteCurrentElements(writer));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ int64 invocation_results_size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name("invocation_results.size"), &invocation_results_size));
+ for (size_t i = 0; i < invocation_results_size; i++) {
+ std::shared_ptr<InvocationResult> result(new InvocationResult());
+ invocation_results_.push_back(result);
+ TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
+ size_t num_return_values;
+ {
+ int64 size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("invocation_results[", i, "].size")),
+ &size));
+ num_return_values = static_cast<size_t>(size);
+ if (num_return_values != size) {
+ return errors::InvalidArgument(strings::StrCat(
+ full_name(
+ strings::StrCat("invocation_results[", i, "].size")),
+ ": ", size, " is not a valid value of type size_t."));
+ }
+ }
+ result->return_values.reserve(num_return_values);
+ for (size_t j = 0; j < num_return_values; j++) {
+ result->return_values.emplace_back();
+ TF_RETURN_IF_ERROR(
+ reader->ReadTensor(full_name(strings::StrCat(
+ "invocation_results[", i, "][", j, "]")),
+ &result->return_values.back()));
+ }
+ result->skip = reader->Contains(
+ full_name(strings::StrCat("invocation_results[", i, "].skip")));
+ result->notification.Notify();
+ }
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("cycle_index"), &cycle_index_));
+ if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("num_open"), &num_open_));
+ TF_RETURN_IF_ERROR(ReadCurrentElements(ctx, reader));
+ return Status::OK();
+ }
+
+ private:
+ struct InvocationResult {
+ Notification notification; // used for coordination with the consumer
+ Status status; // the invocation status
+ std::vector<Tensor> return_values; // the invocation result values
+ bool skip; // if set the result should be skipped
+ };
+
+ void EnsureRunnerThreadStarted(IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!runner_thread_) {
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
+ runner_thread_.reset(ctx->env()->StartThread(
+ {}, "runner_thread",
+ [this, new_ctx]() { RunnerThread(new_ctx); }));
+ }
+ }
+
+ // Fetches up to `results.size()` outputs from the cycle element at
+ // position `cycle_index`.
+ //
+ // If end of input is encountered, the `skip` field of the invocation
+ // result is used to identify results that should be skipped.
+ void FetchOutputs(
+ const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index,
+ const std::vector<std::shared_ptr<InvocationResult>>& results)
+ LOCKS_EXCLUDED(mu_) {
+ bool end_of_input = false;
+ for (auto& result : results) {
+ if (!end_of_input) {
+ result->status = current_elements_[cycle_index]->GetNext(
+ ctx.get(), &result->return_values, &end_of_input);
+ }
+ if (end_of_input) {
+ result->skip = true;
+ }
+ result->notification.Notify();
+ if (!result->status.ok()) {
+ break;
+ }
+ }
+
+ // Release the ownership of the cycle element iterator, closing the
+ // iterator if end of input was encountered.
+ {
+ if (end_of_input) {
+ current_elements_[cycle_index].reset();
+ }
+ mutex_lock l(mu_);
+ element_in_use_[cycle_index] = false;
+ num_calls_--;
+ if (end_of_input) {
+ args_list_[cycle_index].clear();
+ num_open_--;
+ }
+ }
+ cond_var_.notify_all();
+ }
+
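+    // The results buffer holds at most one block of outputs per cycle
+    // element.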
+ int64 MaxInvocationResults() {
+ return dataset()->cycle_length_ * dataset()->block_length_;
+ }
+
+ // Method responsible for 1) creating iterators out of input elements, 2)
+ // determining the order in which elements are fetched from the iterators,
+ // and 3) scheduling the fetching of the elements to a threadpool.
+ //
+ // This method runs in the `runner_thread` background thread.
+ void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
+ while (true) {
+ {
+ mutex_lock l(mu_);
+ // Wait until this thread is cancelled, the end of input has been
+ // reached, or the cycle element at the `cycle_index_` position is
+ // not in use and there is space in the `invocation_results_` queue.
+ while (!cancelled_ && (!end_of_input_ || num_open_ > 0) &&
+ (element_in_use_[cycle_index_] ||
+ num_calls_ >= dataset()->num_parallel_calls_ ||
+ invocation_results_.size() >= MaxInvocationResults())) {
+ cond_var_.wait(l);
+ }
+
+ if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
+ return;
+ }
+
+ while (!element_in_use_[cycle_index_] &&
+ (!end_of_input_ || num_open_ > 0) &&
+ num_calls_ < dataset()->num_parallel_calls_ &&
+ invocation_results_.size() < MaxInvocationResults()) {
+ if (!current_elements_[cycle_index_]) {
+ // Try to create a new iterator from the next input element.
+ Status status = input_impl_->GetNext(
+ ctx.get(), &args_list_[cycle_index_], &end_of_input_);
+ if (!status.ok()) {
+ invocation_results_.emplace_back(new InvocationResult());
+ std::shared_ptr<InvocationResult>& result =
+ invocation_results_.back();
+ result->status.Update(status);
+ result->notification.Notify();
+ break;
+ }
+ if (!end_of_input_) {
+ Status status = MakeIteratorFromInputElement(
+ ctx.get(), args_list_[cycle_index_], cycle_index_,
+ dataset()->captured_func_.get(), prefix(),
+ &current_elements_[cycle_index_]);
+ if (!status.ok()) {
+ invocation_results_.emplace_back(new InvocationResult());
+ std::shared_ptr<InvocationResult>& result =
+ invocation_results_.back();
+ result->status.Update(status);
+ result->notification.Notify();
+ break;
+ }
+ ++num_open_;
+ }
+ }
+ if (current_elements_[cycle_index_]) {
+ // Pre-allocate invocation results for outputs to be fetched
+ // and then fetch the outputs asynchronously.
+ std::vector<std::shared_ptr<InvocationResult>> results;
+ results.reserve(dataset()->block_length_);
+ for (int i = 0; i < dataset()->block_length_; ++i) {
+ invocation_results_.emplace_back(new InvocationResult());
+ results.push_back(invocation_results_.back());
+ }
+ num_calls_++;
+ element_in_use_[cycle_index_] = true;
+ thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this,
+ ctx, cycle_index_,
+ std::move(results)));
+ }
+ cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
+ }
+ }
+ cond_var_.notify_all();
+ }
+ }
+
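+    // Statuses are checkpointed as a (code, error message) pair;
+    // `ReadStatusLocked` performs the inverse transformation.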
+ Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
+ const Status& status)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ CodeKey(index), static_cast<int64>(status.code())));
+ if (!status.ok()) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
+ status.error_message()));
+ }
+ return Status::OK();
+ }
+
+ Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int64 code_int;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
+ error::Code code = static_cast<error::Code>(code_int);
+
+ if (code != error::Code::OK) {
+ string error_message;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(ErrorMessageKey(index), &error_message));
+ *status = Status(code, error_message);
+ } else {
+ *status = Status::OK();
+ }
+ return Status::OK();
+ }
+
+ string CodeKey(size_t index) {
+ return full_name(
+ strings::StrCat("invocation_results[", index, "].code"));
+ }
+
+ string ErrorMessageKey(size_t index) {
+ return full_name(
+ strings::StrCat("invocation_results[", index, "].error_message"));
+ }
+
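+    // Saves each open cycle element together with the input arguments used
+    // to create it, so that `ReadCurrentElements` can re-create and restore
+    // the element iterators.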
+ Status WriteCurrentElements(IteratorStateWriter* writer)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ for (int idx = 0; idx < current_elements_.size(); idx++) {
+ if (current_elements_[idx]) {
+ TF_RETURN_IF_ERROR(SaveInput(writer, current_elements_[idx]));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("args_size[", idx, "]")),
+ args_list_[idx].size()));
+ for (int i = 0; i < args_list_[idx].size(); i++) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
+ args_list_[idx][i]));
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ Status ReadCurrentElements(IteratorContext* ctx,
+ IteratorStateReader* reader)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ for (int idx = 0; idx < current_elements_.size(); idx++) {
+ if (reader->Contains(
+ full_name(strings::StrCat("args_size[", idx, "]")))) {
+ int64 args_size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("args_size[", idx, "]")),
+ &args_size));
+ args_list_[idx].resize(args_size);
+ for (int i = 0; i < args_size; i++) {
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
+ &args_list_[idx][i]));
+ }
+ TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
+ ctx, args_list_[idx], idx, dataset()->captured_func_.get(),
+ prefix(), &current_elements_[idx]));
+ TF_RETURN_IF_ERROR(
+ RestoreInput(ctx, reader, current_elements_[idx]));
+ } else {
+ current_elements_[idx].reset();
+ }
+ }
+ return Status::OK();
+ }
+
+ // Used for coordination between the main thread, the runner thread, and
+ // the worker threads.
+ mutex mu_;
+
+ // Used for coordination between the main thread, the runner thread, and
+ // the worker threads. In particular, the runner thread should only
+ // schedule new calls when the number of in-flight calls is less than the
+ // user specified level of parallelism, there are slots available in the
+ // `invocation_results_` buffer, the current cycle element is not in use,
+ // and there are elements left to be fetched.
+ condition_variable cond_var_;
+
+ // Iterator for input elements.
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+
+ // Identifies current cycle element.
+ int64 cycle_index_ = 0;
+
+ // Arguments for creating an iterator for cycle elements.
+ std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_);
+
+ // Iterators for the current cycle elements. Concurrent access is
+ // protected by `element_in_use_`.
+ std::vector<std::unique_ptr<IteratorBase>> current_elements_;
+
+ // Identifies cycle elements that are in use by worker threads.
+ std::vector<bool> element_in_use_ GUARDED_BY(mu_);
+
+ // Buffer for storing the invocation results.
+ std::deque<std::shared_ptr<InvocationResult>> invocation_results_
+ GUARDED_BY(mu_);
+
+ // Identifies whether end of input has been reached.
+ bool end_of_input_ GUARDED_BY(mu_) = false;
+
+ // Identifies the number of open iterators.
+ int64 num_open_ GUARDED_BY(mu_) = 0;
+
+ // Identifies the number of outstanding calls.
+ int64 num_calls_ GUARDED_BY(mu_) = 0;
+
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
+
+ // Identifies whether background activity should be cancelled.
+ bool cancelled_ GUARDED_BY(mu_) = false;
+ };
+
+ const DatasetBase* const input_;
+ const NameAttrList interleave_func_;
+ const std::unique_ptr<CapturedFunction> captured_func_;
+ const int64 cycle_length_;
+ const int64 block_length_;
+ const int64 num_parallel_calls_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ };
+
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+ NameAttrList interleave_func_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV2").Device(DEVICE_CPU),
+ ParallelInterleaveDatasetV2Op);
+
+} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index a407abfce4..a0cb179eb8 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -33,11 +33,12 @@ namespace {
class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
public:
explicit ParallelMapDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism",
+ &use_inter_op_parallelism_));
}
protected:
@@ -60,10 +61,12 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ func_, std::move(other_arguments),
+ use_inter_op_parallelism_, &captured_func));
*output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_,
- output_shapes_, std::move(captured_func));
+ output_shapes_, use_inter_op_parallelism_,
+ std::move(captured_func));
}
private:
@@ -73,6 +76,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
const NameAttrList& func, int32 num_parallel_calls,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
+ bool use_inter_op_parallelism,
std::unique_ptr<CapturedFunction> captured_func)
: DatasetBase(DatasetContext(ctx)),
input_(input),
@@ -80,6 +84,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
num_parallel_calls_(num_parallel_calls),
output_types_(output_types),
output_shapes_(output_shapes),
+ use_inter_op_parallelism_(use_inter_op_parallelism),
captured_func_(std::move(captured_func)) {
input_->Ref();
}
@@ -88,16 +93,35 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- auto map_func = [this](IteratorContext* ctx,
- std::vector<Tensor> input_element,
- std::vector<Tensor>* result, StatusCallback done) {
- captured_func_->RunAsync(ctx, std::move(input_element), result,
- std::move(done));
+ auto init_func = [this](IteratorContext* ctx) {
+ return captured_func_->Instantiate(ctx);
};
+ ParallelMapIteratorFunction map_func;
+ if (use_inter_op_parallelism_) {
+ map_func = [this](IteratorContext* ctx,
+ std::vector<Tensor> input_element,
+ std::vector<Tensor>* result, StatusCallback done) {
+ captured_func_->RunAsync(ctx, std::move(input_element), result,
+ std::move(done));
+ };
+ } else {
+ map_func = [this](IteratorContext* ctx,
+ std::vector<Tensor> input_element,
+ std::vector<Tensor>* result, StatusCallback done) {
+ (*ctx->runner())(std::bind(
+ [this, ctx, result](std::vector<Tensor>& input_element,
+ StatusCallback& done) {
+ captured_func_->RunAsync(ctx, std::move(input_element), result,
+ std::move(done));
+ },
+ std::move(input_element), std::move(done)));
+ };
+ }
+
return NewParallelMapIterator(
{this, strings::StrCat(prefix, "::ParallelMap")}, input_,
- std::move(map_func), num_parallel_calls_);
+ std::move(init_func), std::move(map_func), num_parallel_calls_);
}
const DataTypeVector& output_dtypes() const override {
@@ -138,7 +162,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
b->AddScalar(num_parallel_calls_, &num_parallel_calls));
// Attr: f
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
AttrValue f;
b->BuildAttrValue(func_, &f);
@@ -163,12 +187,13 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
const int32 num_parallel_calls_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
+ const bool use_inter_op_parallelism_;
const std::unique_ptr<CapturedFunction> captured_func_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
+ bool use_inter_op_parallelism_;
NameAttrList func_;
};
@@ -176,5 +201,5 @@ REGISTER_KERNEL_BUILDER(Name("ParallelMapDataset").Device(DEVICE_CPU),
ParallelMapDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index 4d32b719a4..4ae742aaaf 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -20,16 +20,19 @@ limitations under the License.
#include <vector>
namespace tensorflow {
+namespace data {
namespace {
class ParallelMapIterator : public DatasetBaseIterator {
public:
explicit ParallelMapIterator(
const typename DatasetBaseIterator::BaseParams& params,
- const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
- int32 num_parallel_calls)
+ const DatasetBase* input_dataset,
+ std::function<Status(IteratorContext*)> init_func,
+ ParallelMapIteratorFunction map_func, int32 num_parallel_calls)
: DatasetBaseIterator(params),
input_dataset_(input_dataset),
+ init_func_(std::move(init_func)),
map_func_(std::move(map_func)),
num_parallel_calls_(num_parallel_calls) {}
@@ -50,7 +53,12 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
Status Initialize(IteratorContext* ctx) override {
- return input_dataset_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
+ if (init_func_) {
+ TF_RETURN_IF_ERROR(init_func_(ctx));
+ }
+ return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
@@ -285,6 +293,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
const DatasetBase* const input_dataset_; // Not owned.
+ const std::function<Status(IteratorContext*)> init_func_;
const ParallelMapIteratorFunction map_func_;
const int32 num_parallel_calls_;
// Used for coordination between the main thread and the runner thread.
@@ -311,8 +320,19 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBaseIterator::BaseParams& params,
const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
int32 num_parallel_calls) {
- return std::unique_ptr<IteratorBase>(new ParallelMapIterator(
- params, input_dataset, std::move(map_func), num_parallel_calls));
+ return NewParallelMapIterator(params, input_dataset, nullptr,
+ std::move(map_func), num_parallel_calls);
+}
+
+std::unique_ptr<IteratorBase> NewParallelMapIterator(
+ const DatasetBaseIterator::BaseParams& params,
+ const DatasetBase* input_dataset,
+ std::function<Status(IteratorContext*)> init_func,
+ ParallelMapIteratorFunction map_func, int32 num_parallel_calls) {
+ return std::unique_ptr<IteratorBase>(
+ new ParallelMapIterator(params, input_dataset, std::move(init_func),
+ std::move(map_func), num_parallel_calls));
}
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.h b/tensorflow/core/kernels/data/parallel_map_iterator.h
index 2ce36c3869..dc26c5cf25 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.h
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/dataset.h"
namespace tensorflow {
+namespace data {
// A function that transforms elements of one dataset into another
// asynchronously. The arguments are:
@@ -33,12 +34,21 @@ using ParallelMapIteratorFunction =
std::vector<Tensor>*, StatusCallback)>;
// Returns a new iterator that applies `map_func` to the elements of
-// `input_dataset` using the given degree of parallelism.
+// `input_dataset` using the given degree of parallelism. `init_func` (if
+// specified) will be executed when the iterator is initialized (see
+// `IteratorBase::Initialize()`) and enables the user to specify error checking
+// logic that can fail early.
+std::unique_ptr<IteratorBase> NewParallelMapIterator(
+ const DatasetBaseIterator::BaseParams& params,
+ const DatasetBase* input_dataset,
+ std::function<Status(IteratorContext*)> init_func,
+ ParallelMapIteratorFunction map_func, int32 num_parallel_calls);
std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBaseIterator::BaseParams& params,
const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
int32 num_parallel_calls);
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_MAP_ITERATOR_H_
diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
new file mode 100644
index 0000000000..0cf5db017b
--- /dev/null
+++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
@@ -0,0 +1,372 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <deque>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/stats_aggregator.h"
+#include "tensorflow/core/kernels/data/parallel_map_iterator.h"
+#include "tensorflow/core/util/example_proto_fast_parsing.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+
+class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit ParseExampleDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx),
+ graph_def_version_(ctx->graph_def_version()) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("sparse_keys", &sparse_keys_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dense_keys", &dense_keys_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("sparse_types", &sparse_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Tdense", &dense_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dense_shapes", &dense_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ for (int i = 0; i < dense_shapes_.size(); ++i) {
+ bool shape_ok = true;
+ if (dense_shapes_[i].dims() == -1) {
+ shape_ok = false;
+ } else {
+ for (int d = 1; d < dense_shapes_[i].dims(); ++d) {
+ if (dense_shapes_[i].dim_size(d) == -1) {
+ shape_ok = false;
+ }
+ }
+ }
+ OP_REQUIRES(ctx, shape_ok,
+ errors::InvalidArgument(
+ "dense_shapes[", i,
+ "] has unknown rank or unknown inner dimensions: ",
+ dense_shapes_[i].DebugString()));
+ TensorShape dense_shape;
+ if (dense_shapes_[i].dims() > 0 && dense_shapes_[i].dim_size(0) == -1) {
+ variable_length_.push_back(true);
+ for (int d = 1; d < dense_shapes_[i].dims(); ++d) {
+ dense_shape.AddDim(dense_shapes_[i].dim_size(d));
+ }
+ } else {
+ variable_length_.push_back(false);
+ dense_shapes_[i].AsTensorShape(&dense_shape);
+ }
+ elements_per_stride_.push_back(dense_shape.num_elements());
+ }
+ }
+
+ protected:
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ int64 num_parallel_calls;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
+ &num_parallel_calls));
+ OP_REQUIRES(ctx, num_parallel_calls > 0,
+ errors::InvalidArgument(
+ "num_parallel_calls must be greater than zero."));
+
+ OpInputList dense_default_tensors;
+ OP_REQUIRES_OK(ctx,
+ ctx->input_list("dense_defaults", &dense_default_tensors));
+
+ OP_REQUIRES(ctx, dense_default_tensors.size() == dense_keys_.size(),
+ errors::InvalidArgument(
+ "Expected len(dense_defaults) == len(dense_keys) but got: ",
+ dense_default_tensors.size(), " vs. ", dense_keys_.size()));
+
+ std::vector<Tensor> dense_defaults;
+ dense_defaults.reserve(dense_default_tensors.size());
+ for (const Tensor& dense_default_t : dense_default_tensors) {
+ dense_defaults.push_back(dense_default_t);
+ }
+
+ for (int d = 0; d < dense_keys_.size(); ++d) {
+ const Tensor& def_value = dense_defaults[d];
+ if (variable_length_[d]) {
+ OP_REQUIRES(ctx, def_value.NumElements() == 1,
+ errors::InvalidArgument(
+ "dense_shape[", d, "] is a variable length shape: ",
+ dense_shapes_[d].DebugString(),
+ ", therefore "
+ "def_value[",
+ d,
+ "] must contain a single element ("
+ "the padding element). But its shape is: ",
+ def_value.shape().DebugString()));
+ } else if (def_value.NumElements() > 0) {
+ OP_REQUIRES(ctx, dense_shapes_[d].IsCompatibleWith(def_value.shape()),
+ errors::InvalidArgument(
+ "def_value[", d,
+ "].shape() == ", def_value.shape().DebugString(),
+ " is not compatible with dense_shapes_[", d,
+ "] == ", dense_shapes_[d].DebugString()));
+ }
+ OP_REQUIRES(ctx, def_value.dtype() == dense_types_[d],
+ errors::InvalidArgument(
+ "dense_defaults[", d, "].dtype() == ",
+ DataTypeString(def_value.dtype()), " != dense_types_[", d,
+ "] == ", DataTypeString(dense_types_[d])));
+ }
+
+ example::FastParseExampleConfig config;
+ std::map<string, int> key_to_output_index;
+ for (int d = 0; d < dense_keys_.size(); ++d) {
+ config.dense.push_back({dense_keys_[d], dense_types_[d], dense_shapes_[d],
+ dense_default_tensors[d], variable_length_[d],
+ elements_per_stride_[d]});
+ auto result = key_to_output_index.insert({dense_keys_[d], 0});
+ OP_REQUIRES(ctx, result.second,
+ errors::InvalidArgument("Duplicate key not allowed: ",
+ dense_keys_[d]));
+ }
+ for (int d = 0; d < sparse_keys_.size(); ++d) {
+ config.sparse.push_back({sparse_keys_[d], sparse_types_[d]});
+ auto result = key_to_output_index.insert({sparse_keys_[d], 0});
+ OP_REQUIRES(ctx, result.second,
+ errors::InvalidArgument("Duplicate key not allowed: ",
+ sparse_keys_[d]));
+ }
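+    // std::map iterates keys in sorted order, so output indices are assigned
+    // lexicographically by feature key.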
+ int i = 0;
+    for (auto& key_and_index : key_to_output_index) {
+      key_and_index.second = i++;
+    }
+
+ *output = new Dataset(ctx, input, std::move(dense_defaults),
+ std::move(sparse_keys_), std::move(dense_keys_),
+ std::move(key_to_output_index), std::move(config),
+ num_parallel_calls, sparse_types_, dense_types_,
+ dense_shapes_, output_types_, output_shapes_);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ std::vector<Tensor> dense_defaults, std::vector<string> sparse_keys,
+ std::vector<string> dense_keys,
+ std::map<string, int> key_to_output_index,
+ example::FastParseExampleConfig config, int32 num_parallel_calls,
+ const DataTypeVector& sparse_types,
+ const DataTypeVector& dense_types,
+ const std::vector<PartialTensorShape>& dense_shapes,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ dense_defaults_(std::move(dense_defaults)),
+ sparse_keys_(std::move(sparse_keys)),
+ dense_keys_(std::move(dense_keys)),
+ key_to_output_index_(std::move(key_to_output_index)),
+ config_(std::move(config)),
+ num_parallel_calls_(num_parallel_calls),
+ sparse_types_(sparse_types),
+ dense_types_(dense_types),
+ dense_shapes_(dense_shapes),
+ output_types_(output_types),
+ output_shapes_(output_shapes) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ auto map_fn = [this](IteratorContext* ctx,
+ std::vector<Tensor> input_element,
+ std::vector<Tensor>* result, StatusCallback done) {
+ (*ctx->runner())([this, ctx, input_element, result, done]() {
+ thread::ThreadPool* device_threadpool =
+ ctx->lib()->device()->tensorflow_cpu_worker_threads()->workers;
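+        // Flatten the serialized Example strings from the input element into
+        // a single vector so they can be parsed as one batch.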
+ std::vector<string> slice_vec;
+        for (const Tensor& t : input_element) {
+          auto serialized_t = t.flat<string>();
+          gtl::ArraySlice<string> slice(serialized_t.data(),
+                                        serialized_t.size());
+          for (const string& s : slice) {
+            slice_vec.push_back(s);
+          }
+        }
+        // Make a local copy of config_ so it can be modified for this
+        // invocation.
+        example::FastParseExampleConfig config = config_;
+ auto stats_aggregator = ctx->stats_aggregator();
+ if (stats_aggregator) {
+ config.collect_feature_stats = true;
+ }
+ example::Result example_result;
+ Status s = FastParseExample(config, slice_vec, {}, device_threadpool,
+ &example_result);
+ if (s.ok()) {
+ (*result).resize(key_to_output_index_.size());
+ for (int d = 0; d < dense_keys_.size(); ++d) {
+ int output_index = key_to_output_index_.at(dense_keys_[d]);
+ CHECK(example_result.dense_values[d].dtype() ==
+ output_dtypes()[output_index])
+ << "Got wrong type for FastParseExample return value " << d
+ << " (expected "
+ << DataTypeString(output_dtypes()[output_index]) << ", got "
+ << DataTypeString(example_result.dense_values[d].dtype())
+ << ").";
+ CHECK(output_shapes()[output_index].IsCompatibleWith(
+ example_result.dense_values[d].shape()))
+ << "Got wrong shape for FastParseExample return value " << d
+ << " (expected "
+ << output_shapes()[output_index].DebugString() << ", got "
+ << example_result.dense_values[d].shape().DebugString()
+ << ").";
+ (*result)[output_index] = example_result.dense_values[d];
+ }
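+          // Encode each sparse component as a 3-element DT_VARIANT vector
+          // holding its (indices, values, dense_shape) tensors.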
+ for (int d = 0; d < sparse_keys_.size(); ++d) {
+ Tensor serialized_sparse = Tensor(DT_VARIANT, TensorShape({3}));
+ auto serialized_sparse_t = serialized_sparse.vec<Variant>();
+ serialized_sparse_t(0) = example_result.sparse_indices[d];
+ serialized_sparse_t(1) = example_result.sparse_values[d];
+ serialized_sparse_t(2) = example_result.sparse_shapes[d];
+ int output_index = key_to_output_index_.at(sparse_keys_[d]);
+ CHECK(serialized_sparse.dtype() == output_dtypes()[output_index])
+ << "Got wrong type for FastParseExample return value " << d
+ << " (expected "
+ << DataTypeString(output_dtypes()[output_index]) << ", got "
+ << DataTypeString(serialized_sparse.dtype()) << ").";
+ CHECK(output_shapes()[output_index].IsCompatibleWith(
+ serialized_sparse.shape()))
+ << "Got wrong shape for FastParseExample return value " << d
+ << " (expected "
+ << output_shapes()[output_index].DebugString() << ", got "
+ << serialized_sparse.shape().DebugString() << ").";
+ (*result)[output_index] = serialized_sparse;
+ }
+          // TODO(b/111553342): Use user-provided tags instead of the fixed tag.
+ if (stats_aggregator) {
+ stats_aggregator->IncrementCounter(
+ "examples_count", "trainer",
+ example_result.feature_stats.size());
+ for (example::PerExampleFeatureStats feature_stats :
+ example_result.feature_stats) {
+ stats_aggregator->AddToHistogram(
+ strings::StrCat("record_stats", ":features"),
+ {static_cast<double>(feature_stats.features_count)});
+ stats_aggregator->IncrementCounter(
+ "features_count", "trainer", feature_stats.features_count);
+ stats_aggregator->IncrementCounter(
+ "feature_values_count", "trainer",
+ feature_stats.feature_values_count);
+ stats_aggregator->AddToHistogram(
+ strings::StrCat("record_stats", ":feature-values"),
+ {static_cast<double>(feature_stats.feature_values_count)});
+ }
+ }
+ }
+ done(s);
+ });
+ };
+
+ return NewParallelMapIterator(
+ {this, strings::StrCat(prefix, "::ParseExample")}, input_,
+ std::move(map_fn), num_parallel_calls_);
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override {
+ return "ParseExampleDatasetOp::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+
+      Node* num_parallel_calls_node;
+ std::vector<Node*> dense_defaults_nodes;
+ dense_defaults_nodes.reserve(dense_defaults_.size());
+
+ TF_RETURN_IF_ERROR(
+          b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
+
+ for (const Tensor& dense_default : dense_defaults_) {
+ Node* node;
+ TF_RETURN_IF_ERROR(b->AddTensor(dense_default, &node));
+ dense_defaults_nodes.emplace_back(node);
+ }
+
+ AttrValue sparse_keys_attr;
+ AttrValue dense_keys_attr;
+ AttrValue sparse_types_attr;
+      AttrValue dense_types_attr;
+ AttrValue dense_shapes_attr;
+
+ b->BuildAttrValue(sparse_keys_, &sparse_keys_attr);
+ b->BuildAttrValue(dense_keys_, &dense_keys_attr);
+ b->BuildAttrValue(sparse_types_, &sparse_types_attr);
+      b->BuildAttrValue(dense_types_, &dense_types_attr);
+ b->BuildAttrValue(dense_shapes_, &dense_shapes_attr);
+
+ TF_RETURN_IF_ERROR(b->AddDataset(this,
+ {
+ {0, input_graph_node},
+                                        {1, num_parallel_calls_node},
+ },
+ {{2, dense_defaults_nodes}},
+ {{"sparse_keys", sparse_keys_attr},
+ {"dense_keys", dense_keys_attr},
+ {"sparse_types", sparse_types_attr},
+ {"Tdense", dense_attr},
+ {"dense_shapes", dense_shapes_attr}},
+ output));
+ return Status::OK();
+ }
+
+ private:
+ const DatasetBase* const input_;
+ const std::vector<Tensor> dense_defaults_;
+ const std::vector<string> sparse_keys_;
+ const std::vector<string> dense_keys_;
+ const std::map<string, int> key_to_output_index_;
+ const example::FastParseExampleConfig config_;
+ const int64 num_parallel_calls_;
+ const DataTypeVector sparse_types_;
+ const DataTypeVector dense_types_;
+ const std::vector<PartialTensorShape> dense_shapes_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ };
+
+ const int graph_def_version_;
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+ std::vector<string> sparse_keys_;
+ std::vector<string> dense_keys_;
+ DataTypeVector sparse_types_;
+ DataTypeVector dense_types_;
+ std::vector<PartialTensorShape> dense_shapes_;
+ std::vector<bool> variable_length_;
+ std::vector<std::size_t> elements_per_stride_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ParseExampleDataset").Device(DEVICE_CPU),
+ ParseExampleDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.cc b/tensorflow/core/kernels/data/prefetch_autotuner.cc
index b3272f6bcd..533d0bd5d2 100644
--- a/tensorflow/core/kernels/data/prefetch_autotuner.cc
+++ b/tensorflow/core/kernels/data/prefetch_autotuner.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/prefetch_autotuner.h"
namespace tensorflow {
+namespace data {
PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size)
: buffer_limit_(initial_buffer_size) {
@@ -43,4 +44,5 @@ void PrefetchAutotuner::RecordConsumption(size_t current_buffer_size) {
}
}
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.h b/tensorflow/core/kernels/data/prefetch_autotuner.h
index fa8a184072..8693205512 100644
--- a/tensorflow/core/kernels/data/prefetch_autotuner.h
+++ b/tensorflow/core/kernels/data/prefetch_autotuner.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
+namespace data {
// PrefetchAutotuner dynamically adjusts the buffer size of a prefetch iterator.
//
@@ -66,6 +67,7 @@ class PrefetchAutotuner {
Mode mode_ = Mode::kDisabled;
};
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_AUTOTUNER_H_
diff --git a/tensorflow/core/kernels/data/prefetch_autotuner_test.cc b/tensorflow/core/kernels/data/prefetch_autotuner_test.cc
index 29a8cc50cd..cfc324fc7e 100644
--- a/tensorflow/core/kernels/data/prefetch_autotuner_test.cc
+++ b/tensorflow/core/kernels/data/prefetch_autotuner_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
+namespace data {
namespace {
TEST(PrefetchAutotuner, Disabled) {
@@ -79,4 +80,5 @@ TEST(PrefetchAutotuner, EnabledSteady) {
}
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index 50efbcbe2a..ad7d5eb3ff 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -12,15 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <deque>
-
#include "tensorflow/core/kernels/data/prefetch_dataset_op.h"
+#include <deque>
+
#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
+namespace data {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
@@ -70,7 +73,11 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- auto_tuner_(params.dataset->buffer_size_) {}
+ auto_tuner_(params.dataset->buffer_size_) {
+ std::vector<string> components =
+ str_util::Split(params.prefix, "::", str_util::SkipEmpty());
+ prefix_end_ = components.back();
+ }
~Iterator() override {
// Signal the prefetch thread to terminate it. We will then
@@ -97,6 +104,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
bool* end_of_sequence) override {
{
mutex_lock l(mu_);
+ auto stats_aggregator = ctx->stats_aggregator();
TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
// Wait until the next element in the buffer has been
// produced, or we are shutting down.
@@ -112,7 +120,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
}
if (!buffer_.empty()) {
- return Consume(out_tensors, end_of_sequence);
+ return Consume(out_tensors, end_of_sequence, stats_aggregator);
}
if (prefetch_thread_finished_) {
@@ -200,14 +208,22 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
std::vector<Tensor> value;
};
- Status Consume(std::vector<Tensor>* out_tensors, bool* end_of_sequence)
+ Status Consume(std::vector<Tensor>* out_tensors, bool* end_of_sequence,
+ const std::shared_ptr<StatsAggregator>& stats_aggregator)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (stats_aggregator) {
+ stats_aggregator->AddToHistogram(
+ strings::StrCat(prefix_end_, "::buffer_utilization"),
+ {static_cast<float>(buffer_.size()) /
+ static_cast<float>(auto_tuner_.buffer_limit())});
+ }
// A new element is available. Forward the status from computing it, and
// (if we successfully got an element) the output values.
Status s = buffer_.front().status;
if (s.ok()) {
*out_tensors = std::move(buffer_.front().value);
}
+ auto_tuner_.RecordConsumption(buffer_.size());
buffer_.pop_front();
*end_of_sequence = false;
@@ -324,6 +340,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
mutex parent_mu_ ACQUIRED_BEFORE(mu_);
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(parent_mu_);
condition_variable cond_var_;
+ string prefix_end_;
PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_);
std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_);
@@ -346,6 +363,7 @@ void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
*output = new Dataset(ctx, input, buffer_size);
}
+namespace {
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset").Device(DEVICE_CPU),
PrefetchDatasetOp);
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset")
@@ -354,4 +372,7 @@ REGISTER_KERNEL_BUILDER(Name("PrefetchDataset")
.HostMemory("input_dataset")
.HostMemory("handle"),
PrefetchDatasetOp);
+} // namespace
+
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.h b/tensorflow/core/kernels/data/prefetch_dataset_op.h
index c40c4b00da..588fb25a06 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.h
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/prefetch_autotuner.h"
namespace tensorflow {
+namespace data {
class PrefetchDatasetOp : public UnaryDatasetOpKernel {
public:
@@ -34,6 +35,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
class Dataset;
};
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_DATASET_OP_H_
diff --git a/tensorflow/core/kernels/data/random_dataset_op.cc b/tensorflow/core/kernels/data/random_dataset_op.cc
index 7817170e73..044a791a3f 100644
--- a/tensorflow/core/kernels/data/random_dataset_op.cc
+++ b/tensorflow/core/kernels/data/random_dataset_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random_distributions.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -151,5 +151,5 @@ REGISTER_KERNEL_BUILDER(Name("RandomDataset").Device(DEVICE_CPU),
RandomDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc
index aa38775125..89fbaae369 100644
--- a/tensorflow/core/kernels/data/range_dataset_op.cc
+++ b/tensorflow/core/kernels/data/range_dataset_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -142,5 +142,5 @@ REGISTER_KERNEL_BUILDER(Name("RangeDataset").Device(DEVICE_CPU),
RangeDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/reader_dataset_ops.cc b/tensorflow/core/kernels/data/reader_dataset_ops.cc
index 086b552936..c474cb4773 100644
--- a/tensorflow/core/kernels/data/reader_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/reader_dataset_ops.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/zlib_inputstream.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -691,5 +691,5 @@ REGISTER_KERNEL_BUILDER(Name("TFRecordDataset").Device(DEVICE_CPU),
TFRecordDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc
index 5e9ace3486..94e96635ab 100644
--- a/tensorflow/core/kernels/data/repeat_dataset_op.cc
+++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -172,32 +172,39 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
class ForeverIterator : public DatasetIterator<Dataset> {
public:
explicit ForeverIterator(const Params& params)
- : DatasetIterator<Dataset>(params), input_impl_(nullptr) {}
+ : DatasetIterator<Dataset>(params),
+ input_impl_(nullptr),
+ first_call_(true) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(mu_);
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
do {
- bool first_call = false;
if (!input_impl_) {
- first_call = true;
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
}
- TF_RETURN_IF_ERROR(
- input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
- if (!*end_of_sequence) {
+ Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+ if (first_call_ && *end_of_sequence) {
+ // If the first call to GetNext() fails because the end
+ // of sequence has been reached, we terminate the
+ // iteration immediately. (Otherwise, this iterator
+ // would loop infinitely and never produce a value.)
+ input_impl_.reset();
return Status::OK();
+ }
+ first_call_ = false;
+ if (!*end_of_sequence) {
+ return s;
} else {
input_impl_.reset();
- if (first_call) {
- // If the first call to GetNext() fails because the end
- // of sequence has been reached, we terminate the
- // iteration immediately. (Otherwise, this iterator
- // would loop infinitely and never produce a value.)
- return Status::OK();
- }
+ first_call_ = true;
}
} while (true);
}
@@ -205,7 +212,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
- if (input_impl_)
+ if (!first_call_)
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
else
TF_RETURN_IF_ERROR(
@@ -218,10 +225,12 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
mutex_lock l(mu_);
if (reader->Contains(full_name("uninitialized"))) {
input_impl_.reset();
+ first_call_ = true;
} else {
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ first_call_ = false;
}
return Status::OK();
}
@@ -229,6 +238,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
private:
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ bool first_call_ GUARDED_BY(mu_);
};
const int64 count_;
@@ -240,5 +250,5 @@ REGISTER_KERNEL_BUILDER(Name("RepeatDataset").Device(DEVICE_CPU),
RepeatDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc
index e4cb31e2b2..6e515d6cc8 100644
--- a/tensorflow/core/kernels/data/scan_dataset_op.cc
+++ b/tensorflow/core/kernels/data/scan_dataset_op.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -109,7 +109,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
std::vector<Node*> initial_state_nodes;
@@ -153,7 +153,9 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
state_(params.dataset->initial_state_) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
@@ -277,5 +279,5 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("ScanDataset").Device(DEVICE_CPU), ScanDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
index 93a4376836..66466d6a36 100644
--- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc
+++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
-
+namespace data {
namespace {
const int64 kLogIntervalMicros = 10 * 1000000; // 10 seconds.
@@ -620,5 +620,5 @@ REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU),
ShuffleAndRepeatDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/single_threaded_executor.cc b/tensorflow/core/kernels/data/single_threaded_executor.cc
new file mode 100644
index 0000000000..5b084a16f0
--- /dev/null
+++ b/tensorflow/core/kernels/data/single_threaded_executor.cc
@@ -0,0 +1,380 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/single_threaded_executor.h"
+
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/executor_factory.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
+typedef gtl::InlinedVector<DeviceContext*, 4> DeviceContextVec;
+typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
+
+class SingleThreadedExecutorImpl : public Executor {
+ public:
+ explicit SingleThreadedExecutorImpl(const LocalExecutorParams& params)
+ : params_(params) {}
+
+ ~SingleThreadedExecutorImpl() override {
+ for (const KernelState& kernel_state : kernels_) {
+ params_.delete_kernel(kernel_state.kernel);
+ }
+ }
+
+ Status Initialize(const Graph& graph) {
+    // Topologically sort `graph` to get a sequence of OpKernels.
+ std::vector<Node*> ordered_nodes;
+ ordered_nodes.reserve(graph.num_nodes());
+ GetReversePostOrder(graph, &ordered_nodes);
+
+ if (ordered_nodes.size() != graph.num_nodes()) {
+ return errors::InvalidArgument("Graph had ", graph.num_nodes(),
+ " but reverse post-order had ",
+ ordered_nodes.size());
+ }
+
+ kernels_.resize(ordered_nodes.size());
+
+ std::unordered_map<Node*, size_t> node_to_index_map;
+
+ // Create the kernel and input-related structures for each node in `graph`.
+ for (size_t i = 0; i < ordered_nodes.size(); ++i) {
+ Node* n = ordered_nodes[i];
+ node_to_index_map[n] = i;
+
+ for (DataType dt : n->output_types()) {
+ if (IsRefType(dt)) {
+ return errors::Unimplemented(
+ "Single-threaded executor does not support reference-typed "
+ "edges.");
+ }
+ }
+
+ if (n->IsControlFlow()) {
+ return errors::Unimplemented(
+ "Single-threaded executor does not support control flow.");
+ }
+ if (n->IsSend() || n->IsHostSend() || n->IsRecv() || n->IsHostRecv()) {
+ return errors::Unimplemented(
+ "Single-threaded executor does not support partitioned graphs.");
+ }
+ if (n->IsCollective()) {
+ return errors::Unimplemented(
+ "Single-threaded executor does not support collective ops.");
+ }
+
+ KernelState& kernel_state = kernels_[i];
+ TF_RETURN_IF_ERROR(params_.create_kernel(n->def(), &kernel_state.kernel));
+ kernel_state.num_inputs = n->num_inputs();
+ kernel_state.num_outputs = n->num_outputs();
+
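+      // Inputs for all kernels are laid out contiguously, so kernel i's
+      // inputs begin where kernel i-1's inputs end.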
+ if (i == 0) {
+ kernel_state.input_start_index = 0;
+ } else {
+ const KernelState& previous_kernel_state = kernels_[i - 1];
+ kernel_state.input_start_index =
+ previous_kernel_state.input_start_index +
+ previous_kernel_state.num_inputs;
+ }
+ }
+
+ // Build the mapping from each node output to the input slot for the
+ // corresponding destination node.
+ for (size_t i = 0; i < ordered_nodes.size(); ++i) {
+ Node* n = ordered_nodes[i];
+ KernelState& kernel_state = kernels_[i];
+ kernel_state.output_locations.resize(kernel_state.num_outputs);
+ for (const Edge* e : n->out_edges()) {
+ if (!e->IsControlEdge()) {
+ kernel_state.output_locations[e->src_output()].push_back(
+ kernels_[node_to_index_map[e->dst()]].input_start_index +
+ e->dst_input());
+ }
+ }
+
+ // Compute allocator attributes for each node output, and corresponding
+ // node input.
+ kernel_state.output_alloc_attrs.resize(kernel_state.num_outputs);
+ AllocatorAttributes* attrs = kernel_state.output_alloc_attrs.data();
+
+ OpKernel* op_kernel = kernel_state.kernel;
+ for (int out = 0; out < n->num_outputs(); out++) {
+ DCHECK_LT(out, op_kernel->output_memory_types().size());
+ bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY;
+ if (on_host) {
+ AllocatorAttributes h;
+ h.set_on_host(on_host);
+ attrs[out].Merge(h);
+ }
+ }
+ }
+
+ if (!kernels_.empty()) {
+ const KernelState& last_kernel_state = kernels_.back();
+ total_num_inputs_ =
+ last_kernel_state.input_start_index + last_kernel_state.num_inputs;
+ input_alloc_attrs_.resize(total_num_inputs_);
+ for (size_t i = 0; i < ordered_nodes.size(); ++i) {
+ for (size_t j = 0; j < kernels_[i].output_locations.size(); ++j) {
+ for (size_t output_location : kernels_[i].output_locations[j]) {
+ input_alloc_attrs_[output_location] =
+ kernels_[i].output_alloc_attrs[j];
+ }
+ }
+ }
+ } else {
+ total_num_inputs_ = 0;
+ }
+ return Status::OK();
+ }
+
+ // TODO(mrry): Consider specializing the implementation of Executor::Run()
+ // instead, to avoid unnecessary atomic operations in the callback when
+ // running synchronously.
+ void RunAsync(const Args& args, DoneCallback done) override {
+ // The inputs to each kernel are stored contiguously in `inputs`.
+ //
+ // We use `kernels_[i].input_start_index` and `kernels_[i].num_inputs` to
+ // determine the range of elements in this vector that correspond to
+ // the inputs of `kernels_[i]`.
+ //
+ // This vector has the following layout:
+ //
+ // * Kernel 0, input 0.
+ // * Kernel 0, input 1.
+ // * ...
+ // * Kernel 0, input `kernels_[0].num_inputs - 1`.
+ // * Kernel 1, input 0.
+ // * ...
+ // * Kernel 1, input `kernels_[1].num_inputs - 1`.
+ // * ...
+ // * Kernel `kernels_.size() - 1`, input 0.
+ // * ...
+ // * Kernel `kernels_.size() - 1`, input `kernels_.back().num_inputs - 1`.
+ //
+ // Note that kernels with zero inputs do not correspond to any elements in
+ // this vector.
+ //
+ // We use `ManualConstructor<Tensor>` to avoid the overhead of
+ // default-constructing an invalid `Tensor` for each slot at the beginning
+ // of execution:
+ // * Elements are initialized when the outputs of a kernel execution are
+ // propagated to the inputs of kernels that depend on them.
+ // * The elements corresponding to the inputs for kernel `i` are destroyed
+ // after kernel `i` executes.
+ // * In an error case (see below), we use the connectivity information in
+ // `KernelState::output_locations` to determine which locations have been
+ // initialized, and manually destroy them.
+ std::vector<ManualConstructor<Tensor>> inputs(total_num_inputs_);
+
+ // TODO(mrry): Can we avoid copying into these vectors? Consider modifying
+ // OpKernelContext to take the TensorValueVec as a pointer into `inputs`.
+ TensorValueVec node_inputs;
+ DeviceContextVec input_device_contexts;
+ AllocatorAttributeVec input_alloc_attrs;
+
+ // Prepare the parameters that will be the same for all kernels.
+ OpKernelContext::Params params;
+ params.step_id = args.step_id;
+ Device* device = params_.device;
+ params.device = device;
+ params.log_memory = false; // TODO(mrry): Too severe?
+ params.record_tensor_accesses = false; // TODO(mrry): Too severe?
+ params.rendezvous = args.rendezvous;
+ params.session_state = args.session_state;
+ params.tensor_store = args.tensor_store;
+ params.cancellation_manager = args.cancellation_manager;
+ // TODO(mrry): ArgOp is a relatively expensive OpKernel due to the Tensor
+ // allocations that it performs. Consider specializing its handling in the
+ // executor.
+ params.call_frame = args.call_frame;
+ params.function_library = params_.function_library;
+ params.resource_manager = device->resource_manager();
+ params.step_container = args.step_container;
+ params.slice_reader_cache = nullptr; // TODO(mrry): Too severe?
+ params.inputs = &node_inputs;
+ params.input_device_contexts = &input_device_contexts;
+ params.input_alloc_attrs = &input_alloc_attrs;
+
+ Args::Runner runner_copy = args.runner;
+ params.runner = &runner_copy;
+ params.stats_collector = args.stats_collector;
+
+    // NOTE(mrry): We are assuming that the graph contains no loops and no
+    // conditionals (control flow).
+ params.frame_iter = FrameAndIter(0, 0);
+ params.is_input_dead = false;
+
+ // TODO(mrry): Add non-default device context inference.
+ params.op_device_context = nullptr;
+ // TODO(mrry): Consider implementing forwarding.
+ params.forward_from_array = nullptr;
+
+ // Execute the kernels one-at-a-time in topological order.
+ for (size_t i = 0; i < kernels_.size(); ++i) {
+ const KernelState& kernel_state = kernels_[i];
+
+ // Prepare the per-kernel parameters.
+ const size_t input_start_index = kernel_state.input_start_index;
+ const size_t num_inputs = kernel_state.num_inputs;
+ const size_t num_outputs = kernel_state.num_outputs;
+
+ node_inputs.clear();
+ node_inputs.resize(num_inputs);
+ input_alloc_attrs.clear();
+ input_alloc_attrs.resize(num_inputs);
+ for (size_t j = 0; j < num_inputs; ++j) {
+ auto t = inputs[input_start_index + j].get();
+ node_inputs[j].tensor = t;
+ input_alloc_attrs[j] = input_alloc_attrs_[input_start_index + j];
+ }
+ params.op_kernel = kernel_state.kernel;
+ input_device_contexts.clear();
+ input_device_contexts.resize(num_inputs);
+ params.output_attr_array = kernel_state.output_alloc_attrs.data();
+ OpKernelContext ctx(&params, num_outputs);
+
+ // Actually execute the kernel.
+ device->Compute(kernel_state.kernel, &ctx);
+
+ if (!ctx.status().ok()) {
+ // On failure, we must manually free all intermediate tensors. We have
+ // already freed all the inputs for kernels up to (but not including)
+ // the `i`th kernel. We scan through the previously executed kernels and
+ // destroy any tensors that were destined to be the input for a kernel
+ // that has not yet executed.
+ for (size_t j = 0; j < i; ++j) {
+ const KernelState& executed_kernel_state = kernels_[j];
+ for (size_t k = 0; k < executed_kernel_state.num_outputs; ++k) {
+ for (size_t output_location :
+ executed_kernel_state.output_locations[k]) {
+ if (output_location >= input_start_index) {
+ // Only destroy an output location if it is an input to an
+ // operation that has not yet executed.
+ inputs[output_location].Destroy();
+ }
+ }
+ }
+ }
+ done(ctx.status());
+ return;
+ }
+
+ // Free the inputs to the current kernel.
+ for (size_t j = 0; j < num_inputs; ++j) {
+ inputs[input_start_index + j].Destroy();
+ }
+
+ // Forward the outputs of the kernel to the inputs of subsequent kernels.
+ for (size_t j = 0; j < num_outputs; ++j) {
+ TensorValue val = ctx.release_output(j);
+ // TODO(mrry): Consider flattening the `output_locations` vector
+ // to improve the cache-friendliness of this loop.
+ for (size_t output_location : kernel_state.output_locations[j]) {
+ // TODO(mrry): Validate that the types match the expected values or
+ // ensure that the necessary validation has already happened.
+ inputs[output_location].Init(*val.tensor);
+ }
+ delete val.tensor;
+ }
+ }
+ done(Status::OK());
+ }
+
+ private:
+ const LocalExecutorParams params_;
+
+ // All following members are read-only after Initialize().
+
+ // The sum of the number of inputs for each node in the graph. This determines
+ // the length of the flat `inputs` vector. See comment at the beginning of
+ // `RunAsync()` for details.
+ size_t total_num_inputs_;
+
+ // Represents cached graph structure state for each kernel.
+ struct KernelState {
+ // The kernel object. Not owned.
+ //
+ // This pointer is managed by `params_.create_kernel()` and
+ // `params_.delete_kernel()`.
+ OpKernel* kernel;
+
+ // These fields determine the range of elements in `inputs` that corresponds
+ // to the inputs of `kernel`.
+ size_t input_start_index;
+ size_t num_inputs;
+
+ size_t num_outputs;
+
+ // For the `j`th output of `kernel`, `output_locations[j]` contains the
+ // locations in the flat `inputs` vector to which that output must be
+ // copied. See comment at the beginning of `RunAsync()` for details.
+ std::vector<std::vector<size_t>>
+ output_locations; // Length = `num_outputs`.
+
+ // Memory space information for each output of `kernel`.
+ std::vector<AllocatorAttributes>
+ output_alloc_attrs; // Length = `num_outputs`.
+ };
+ std::vector<KernelState> kernels_;
+
+ // Memory space information for each input. This information is stored in the
+ // same order as the flat `inputs` vector. See comment at the beginning of
+ // `RunAsync()` for details.
+ std::vector<AllocatorAttributes>
+ input_alloc_attrs_; // Length = `total_num_inputs_`.
+};
+
+class SingleThreadedExecutorRegistrar {
+ public:
+ SingleThreadedExecutorRegistrar() {
+ ExecutorFactory::Register("SINGLE_THREADED_EXECUTOR", new Factory());
+ }
+
+ private:
+ class Factory : public ExecutorFactory {
+ Status NewExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ std::unique_ptr<Executor>* out_executor) override {
+ Executor* ret;
+ TF_RETURN_IF_ERROR(
+ NewSingleThreadedExecutor(params, std::move(graph), &ret));
+ out_executor->reset(ret);
+ return Status::OK();
+ }
+ };
+};
+static SingleThreadedExecutorRegistrar registrar;
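+
+// A minimal lookup sketch (illustrative; assumes the `NewExecutor()` helper
+// declared in common_runtime/executor_factory.h, which resolves a registered
+// factory by its type string):
+//
+//   std::unique_ptr<Executor> executor;
+//   TF_CHECK_OK(NewExecutor("SINGLE_THREADED_EXECUTOR", params,
+//                           std::move(graph), &executor));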
+
+} // namespace
+
+Status NewSingleThreadedExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ Executor** executor) {
+ std::unique_ptr<SingleThreadedExecutorImpl> impl(
+ new SingleThreadedExecutorImpl(params));
+ TF_RETURN_IF_ERROR(impl->Initialize(*graph));
+ *executor = impl.release();
+ return Status::OK();
+}
+
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/single_threaded_executor.h b/tensorflow/core/kernels/data/single_threaded_executor.h
new file mode 100644
index 0000000000..e934352a1d
--- /dev/null
+++ b/tensorflow/core/kernels/data/single_threaded_executor.h
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_
+
+#include "tensorflow/core/common_runtime/executor.h"
+
+namespace tensorflow {
+namespace data {
+
+// Creates a new `Executor` for executing `graph` synchronously on the caller
+// thread.
+//
+// NOTE(mrry): The returned executor is optimized to impose low overhead on
+// graphs that perform a small amount of work (e.g. <15us of work per graph on
+// present architectures). It eschews concurrency, because issuing work to
+// multiple threads can dominate the cost of executing small ops synchronously,
+// and because contention in the executor data structures can reduce throughput
+// (in terms of ops executed per unit time).
+//
+// However, the current implementation has the following limitations:
+//
+// 1. Reference-typed tensors are not supported and will not be supported in
+// future.
+// 2. Graphs with control flow (containing "Switch" and "Merge" nodes) are not
+// currently supported. The current plan is to extend support to "functional"
+// control flow after the TensorFlow APIs transition to building graphs in
+// that form (e.g. `tf.cond_v2()`).
+// 3. Partitioned graphs (containing "_Recv" nodes) are not currently supported.
+// The present implementation executes kernels one at a time in topological
+// order, and cannot currently distinguish between disconnected subgraphs
+// that are logically connected by subgraphs on a different device.
+// 4. Memory logging is not currently supported.
+// 5. Allocation forwarding is not currently supported.
+// 6. Non-default device contexts are not currently supported. In effect, this
+// limits the executor to CPU devices.
+// 7. Ops that rely on `OpKernelContext::slice_reader_cache()` being non-null
+// are not currently supported.
+//
+// The single-threaded executor is primarily suitable for executing simple
+// TensorFlow functions, such as one might find in a `tf.data` pipeline.
+Status NewSingleThreadedExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ Executor** executor);
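+
+// A minimal usage sketch (illustrative; assumes `params` and `call_frame` are
+// populated as in single_threaded_executor_test.cc):
+//
+//   Executor* executor;
+//   TF_RETURN_IF_ERROR(
+//       NewSingleThreadedExecutor(params, std::move(graph), &executor));
+//   Executor::Args args;
+//   args.call_frame = &call_frame;
+//   args.runner = [](std::function<void()> fn) { fn(); };
+//   TF_RETURN_IF_ERROR(executor->Run(args));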
+
+} // namespace data
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_
diff --git a/tensorflow/core/kernels/data/single_threaded_executor_test.cc b/tensorflow/core/kernels/data/single_threaded_executor_test.cc
new file mode 100644
index 0000000000..6244e287bb
--- /dev/null
+++ b/tensorflow/core/kernels/data/single_threaded_executor_test.cc
@@ -0,0 +1,332 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/single_threaded_executor.h"
+
+#include <algorithm>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class ExecutorTest : public ::testing::Test {
+ protected:
+ ExecutorTest()
+ : device_(DeviceFactory::NewDevice("CPU", {},
+ "/job:localhost/replica:0/task:0")) {}
+
+ ~ExecutorTest() override {
+ // There should always be exactly one Ref left on the Rendezvous
+ // when the test completes.
+ CHECK(rendez_->Unref());
+ delete exec_;
+ delete device_;
+ }
+
+  // Resets exec_ with a new executor based on the given `graph`.
+ void Create(std::unique_ptr<const Graph> graph) {
+ const int version = graph->versions().producer();
+ LocalExecutorParams params;
+ params.device = device_;
+ params.create_kernel = [this, version](const NodeDef& ndef,
+ OpKernel** kernel) {
+ return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel);
+ };
+ params.delete_kernel = [](OpKernel* kernel) {
+ DeleteNonCachedKernel(kernel);
+ };
+ delete exec_;
+ TF_CHECK_OK(NewSingleThreadedExecutor(params, std::move(graph), &exec_));
+ runner_ = [](std::function<void()> fn) { fn(); };
+ rendez_ = NewLocalRendezvous();
+ }
+
+ Status Run(Rendezvous* rendez) {
+ Executor::Args args;
+ args.rendezvous = rendez;
+ args.runner = runner_;
+ return exec_->Run(args);
+ }
+
+ Status Run(CallFrameInterface* call_frame) {
+ Executor::Args args;
+ args.call_frame = call_frame;
+ args.runner = runner_;
+ return exec_->Run(args);
+ }
+
+ Device* device_ = nullptr;
+ Executor* exec_ = nullptr;
+ Executor::Args::Runner runner_;
+ Rendezvous* rendez_ = nullptr;
+};
+
+// A float val -> Tensor<float>
+Tensor V(const float val) {
+ Tensor tensor(DT_FLOAT, TensorShape({}));
+ tensor.scalar<float>()() = val;
+ return tensor;
+}
+
+// An int32 val -> Tensor<int32>
+Tensor VI(const int32 val) {
+ Tensor tensor(DT_INT32, TensorShape({}));
+ tensor.scalar<int32>()() = val;
+ return tensor;
+}
+
+// A bool val -> Tensor<bool>
+Tensor VB(const bool val) {
+ Tensor tensor(DT_BOOL, TensorShape({}));
+ tensor.scalar<bool>()() = val;
+ return tensor;
+}
+
+// A double val -> Tensor<double>
+Tensor VD(const double val) {
+ Tensor tensor(DT_DOUBLE, TensorShape({}));
+ tensor.scalar<double>()() = val;
+ return tensor;
+}
+
+// Tensor<float> -> a float val.
+float V(const Tensor& tensor) {
+ CHECK_EQ(tensor.dtype(), DT_FLOAT);
+ CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
+ return tensor.scalar<float>()();
+}
+
+Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation,
+ const string& receiver, const string& name) {
+ Rendezvous::ParsedKey result;
+ TF_CHECK_OK(
+ Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver,
+ name, FrameAndIter(0, 0)),
+ &result));
+ return result;
+}
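+
+// NOTE: `Key()` parses a string built by `Rendezvous::CreateKey()`, roughly
+// of the form (illustrative; the exact encoding of the incarnation is an
+// implementation detail):
+//   "<sender>;<incarnation>;<receiver>;<name>;<frame_id>:<iter_id>"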
+
+TEST_F(ExecutorTest, SimpleAdd) {
+ // c = a + b
+ std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+ auto in0 = test::graph::Arg(g.get(), 0, DT_FLOAT);
+  auto in1 = test::graph::Arg(g.get(), 1, DT_FLOAT);
+ auto tmp = test::graph::Add(g.get(), in0, in1);
+ test::graph::Retval(g.get(), 0, tmp);
+ FixupSourceAndSinkEdges(g.get());
+ Create(std::move(g));
+ FunctionCallFrame call_frame({DT_FLOAT, DT_FLOAT}, {DT_FLOAT});
+ TF_ASSERT_OK(call_frame.SetArgs({V(1.0), V(1.0)}));
+ TF_ASSERT_OK(Run(&call_frame));
+ std::vector<Tensor> retvals;
+ TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
+ EXPECT_EQ(2.0, V(retvals[0])); // out = 1.0 + 1.0 = 2.0
+}
+
+TEST_F(ExecutorTest, SelfAdd) {
+ // v0 <- a
+ // v1 = v0 + v0
+ // v2 = v1 + v1
+ // ... ...
+ // v10 = v9 + v9
+ //
+ // b <- v10
+ // All nodes are executed by one thread.
+ std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+ auto v = test::graph::Arg(g.get(), 0, DT_FLOAT);
+ const int N = 10;
+ for (int i = 1; i <= N; ++i) {
+ v = test::graph::Add(g.get(), v, v);
+ }
+ // out <- v10
+ test::graph::Retval(g.get(), 0, v);
+ FixupSourceAndSinkEdges(g.get());
+ Create(std::move(g));
+ FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT});
+ // a = 1.0
+ TF_ASSERT_OK(call_frame.SetArgs({V(1.0)}));
+ TF_ASSERT_OK(Run(&call_frame));
+ std::vector<Tensor> retvals;
+ TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
+ EXPECT_EQ(1024.0, V(retvals[0])); // b=v10=2*v9=4*v8=...=1024*a=1024.0
+}
+
+// Builds a graph which adds N copies of one variable "in". I.e.,
+// a + a + a + ... + a
+// The returned graph is parenthesized randomly. I.e.,
+//     a + ((a + a) + a)
+//     (a + a) + (a + a)
+//     ((a + a) + a) + a
+// are all possible results.
+void BuildTree(int N, Graph* g) {
+ CHECK_GT(N, 1);
+ // A single input node "in".
+ auto in = test::graph::Arg(g, 0, DT_FLOAT);
+ std::vector<Node*> nodes;
+ int i = 0;
+  // Duplicate "in" N times. Each copy is named l0, l1, l2, ....
+ for (; i < N; ++i) {
+ nodes.push_back(test::graph::Identity(g, in, 0));
+ }
+ random::PhiloxRandom philox(0, 17);
+ random::SimplePhilox rnd(&philox);
+ while (nodes.size() > 1) {
+    // Randomly pick two nodes from "nodes" and add them. The resulting node
+    // is named like n10, n11, ..., and is put back into "nodes".
+ int x = rnd.Uniform(nodes.size());
+ auto in0 = nodes[x];
+ nodes[x] = nodes.back();
+ nodes.resize(nodes.size() - 1);
+ x = rnd.Uniform(nodes.size());
+ auto in1 = nodes[x];
+ // node = in0 + in1.
+ nodes[x] = test::graph::Add(g, in0, in1);
+ }
+ // The final output node "out".
+ test::graph::Retval(g, 0, nodes.back());
+ FixupSourceAndSinkEdges(g);
+}
+
+TEST_F(ExecutorTest, RandomTree) {
+ std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+ BuildTree(4096, g.get());
+ Create(std::move(g));
+ FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT});
+ TF_ASSERT_OK(call_frame.SetArgs({V(1.0)}));
+ TF_ASSERT_OK(Run(&call_frame));
+ std::vector<Tensor> retvals;
+ TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
+ EXPECT_EQ(4096.0, V(retvals[0]));
+}
+
+TEST_F(ExecutorTest, OpError) {
+ std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+ auto zero = test::graph::Constant(g.get(), V(0.0));
+ auto inf = test::graph::Unary(g.get(), "Reciprocal", zero);
+ auto check = test::graph::CheckNumerics(g.get(), inf, "message");
+ auto two = test::graph::Constant(g.get(), V(2.0));
+ test::graph::Binary(g.get(), "Mul", check, two);
+ FixupSourceAndSinkEdges(g.get());
+ Create(std::move(g));
+ FunctionCallFrame call_frame({}, {});
+  // Fails because CheckNumerics reports the Inf as an invalid argument.
+ EXPECT_TRUE(errors::IsInvalidArgument(Run(&call_frame)));
+}
+
+static void BM_executor(int iters, int width, int depth) {
+#ifdef PLATFORM_GOOGLE
+ BenchmarkUseRealTime();
+#endif // PLATFORM_GOOGLE
+ Graph* g = new Graph(OpRegistry::Global());
+ random::PhiloxRandom philox(1729, 17);
+ random::SimplePhilox rand(&philox);
+ uint64 cur = 0;
+ uint32 r = 1 + rand.Rand32() % width;
+ std::vector<Node*> ready_nodes;
+ for (int i = 0; i < r; ++i) {
+ ready_nodes.push_back(test::graph::NoOp(g, {}));
+ ++cur;
+ }
+ for (int i = 0; i < depth; ++i) {
+ std::random_shuffle(ready_nodes.begin(), ready_nodes.end());
+ r = 1 + rand.Rand32() % (ready_nodes.size());
+ std::vector<Node*> control_inputs;
+ for (int j = 0; j < r; ++j) {
+ control_inputs.push_back(ready_nodes.back());
+ ready_nodes.pop_back();
+ }
+ Node* n = test::graph::NoOp(g, control_inputs);
+ ++cur;
+ r = 1 + rand.Rand32() % width;
+ for (int j = 0; j < r; ++j) {
+ ready_nodes.push_back(test::graph::NoOp(g, {n}));
+ ++cur;
+ }
+ }
+ FixupSourceAndSinkEdges(g);
+#ifdef PLATFORM_GOOGLE
+ SetBenchmarkLabel(strings::StrCat("Nodes = ", cur));
+ SetBenchmarkItemsProcessed(cur * static_cast<int64>(iters));
+#endif // PLATFORM_GOOGLE
+ test::Benchmark("cpu", g, nullptr, nullptr, nullptr,
+ "SINGLE_THREADED_EXECUTOR")
+ .Run(iters);
+}
+
+// Tall skinny graphs
+BENCHMARK(BM_executor)->ArgPair(16, 1024);
+BENCHMARK(BM_executor)->ArgPair(32, 8192);
+
+// Short fat graphs
+BENCHMARK(BM_executor)->ArgPair(1024, 16);
+BENCHMARK(BM_executor)->ArgPair(8192, 32);
+
+// Tall fat graph
+BENCHMARK(BM_executor)->ArgPair(1024, 1024);
+
+// TODO(mrry): This benchmark currently crashes with a use-after-free, because
+// test::Benchmark::RunWithArgs() assumes that the executor will take ownership
+// of the given graph, *and* keep its nodes (`x`, `y` and `z`) alive for the
+// duration of the benchmark. Since the single-threaded executor does not
+// retain a copy of the graph, this fails.
+//
+// TODO(mrry): Add support for Arg/Retval "function call convention" in
+// `test::Benchmark::RunWithArgs()`.
+#if 0
+#define ALICE "/job:j/replica:0/task:0/cpu:0"
+#define BOB "/job:j/replica:0/task:0/gpu:0"
+
+static void BM_FeedInputFetchOutput(int iters) {
+ Graph* g = new Graph(OpRegistry::Global());
+ // z = x + y: x and y are provided as benchmark inputs. z is the
+ // output of the benchmark. Conceptually, the caller is ALICE, the
+ // benchmark is BOB.
+ Node* x = test::graph::Recv(g, "x", "float", ALICE, 1, BOB);
+ Node* y = test::graph::Recv(g, "y", "float", ALICE, 1, BOB);
+ Node* sum = test::graph::Add(g, x, y);
+ Node* z = test::graph::Send(g, sum, "z", BOB, 1, ALICE);
+ FixupSourceAndSinkEdges(g);
+ Tensor val(DT_FLOAT, TensorShape({}));
+ val.scalar<float>()() = 3.14;
+ SetBenchmarkItemsProcessed(static_cast<int64>(iters));
+ test::Benchmark("cpu", g, nullptr, nullptr, nullptr,
+ "SINGLE_THREADED_EXECUTOR")
+ .RunWithArgs({{x, val}, {y, val}}, {z}, iters);
+}
+BENCHMARK(BM_FeedInputFetchOutput);
+#endif
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc
index fe7ef38d5f..b8c7fb15f4 100644
--- a/tensorflow/core/kernels/data/skip_dataset_op.cc
+++ b/tensorflow/core/kernels/data/skip_dataset_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -187,5 +187,5 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("SkipDataset").Device(DEVICE_CPU), SkipDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/slide_dataset_op.cc b/tensorflow/core/kernels/data/slide_dataset_op.cc
index 14df3a6801..1e73cfc753 100644
--- a/tensorflow/core/kernels/data/slide_dataset_op.cc
+++ b/tensorflow/core/kernels/data/slide_dataset_op.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/core/util/batch_util.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -293,5 +293,5 @@ REGISTER_KERNEL_BUILDER(Name("SlideDataset").Device(DEVICE_CPU),
SlideDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
index e526578701..85b1e50695 100644
--- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/util/sparse/sparse_tensor.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -274,5 +274,5 @@ TF_CALL_DATASET_TYPES(REGISTER_DATASET_KERNEL);
#undef REGISTER_DATASET_KERNEL
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/sql/driver_manager.cc b/tensorflow/core/kernels/data/sql/driver_manager.cc
index ffabda1a8a..783d1e6cb2 100644
--- a/tensorflow/core/kernels/data/sql/driver_manager.cc
+++ b/tensorflow/core/kernels/data/sql/driver_manager.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/sql/sqlite_query_connection.h"
namespace tensorflow {
-
+namespace data {
namespace sql {
std::unique_ptr<QueryConnection> DriverManager::CreateQueryConnection(
@@ -30,5 +30,5 @@ std::unique_ptr<QueryConnection> DriverManager::CreateQueryConnection(
}
} // namespace sql
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/sql/driver_manager.h b/tensorflow/core/kernels/data/sql/driver_manager.h
index a34691b5a2..c5428f396b 100644
--- a/tensorflow/core/kernels/data/sql/driver_manager.h
+++ b/tensorflow/core/kernels/data/sql/driver_manager.h
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/sql/query_connection.h"
namespace tensorflow {
-
+namespace data {
namespace sql {
// A factory class for creating `QueryConnection` instances.
@@ -35,7 +35,7 @@ class DriverManager {
};
} // namespace sql
-
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_DRIVER_MANAGER_H_
diff --git a/tensorflow/core/kernels/data/sql/query_connection.h b/tensorflow/core/kernels/data/sql/query_connection.h
index e9ffca202f..2fd229a9bf 100644
--- a/tensorflow/core/kernels/data/sql/query_connection.h
+++ b/tensorflow/core/kernels/data/sql/query_connection.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
namespace tensorflow {
+namespace data {
class IteratorContext;
@@ -63,7 +64,7 @@ class QueryConnection {
};
} // namespace sql
-
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_QUERY_CONNECTION_H_
diff --git a/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc b/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc
index 7cd07bd8ec..5108e83976 100644
--- a/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc
+++ b/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace tensorflow {
-
+namespace data {
namespace sql {
SqliteQueryConnection::SqliteQueryConnection() {}
@@ -115,5 +115,5 @@ void SqliteQueryConnection::FillTensorWithResultSetEntry(
}
} // namespace sql
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/sql/sqlite_query_connection.h b/tensorflow/core/kernels/data/sql/sqlite_query_connection.h
index 81b19530b7..175492c49d 100644
--- a/tensorflow/core/kernels/data/sql/sqlite_query_connection.h
+++ b/tensorflow/core/kernels/data/sql/sqlite_query_connection.h
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
-
+namespace data {
namespace sql {
class SqliteQueryConnection : public QueryConnection {
@@ -50,7 +50,7 @@ class SqliteQueryConnection : public QueryConnection {
};
} // namespace sql
-
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_SQLITE_QUERY_CONNECTION_H_
diff --git a/tensorflow/core/kernels/data/sql_dataset_ops.cc b/tensorflow/core/kernels/data/sql_dataset_ops.cc
index 2aa153fcfa..6bbe459332 100644
--- a/tensorflow/core/kernels/data/sql_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/sql_dataset_ops.cc
@@ -24,8 +24,9 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace tensorflow {
-
+namespace data {
namespace {
+
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following ops.
@@ -211,5 +212,5 @@ class SqlDatasetOp : public DatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("SqlDataset").Device(DEVICE_CPU), SqlDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
index 75af73df54..f5314f7a75 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
+namespace data {
namespace {
class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
@@ -135,4 +136,5 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("SetStatsAggregatorDataset").Device(DEVICE_CPU),
SetStatsAggregatorDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/stats_aggregator_ops.cc
index b133cfab54..a7ded67876 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_ops.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_ops.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
+namespace data {
namespace {
static mutex* get_counters_map_lock() {
@@ -145,4 +146,5 @@ REGISTER_KERNEL_BUILDER(Name("StatsAggregatorSummary").Device(DEVICE_CPU),
StatsAggregatorSummaryOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc
index 52753a3ccd..e9e42f05a1 100644
--- a/tensorflow/core/kernels/data/stats_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
+namespace data {
namespace {
// This op defines a `Dataset` that passes through its input elements and
@@ -242,206 +243,11 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel {
};
};
-class FeatureStatsDatasetOp : public UnaryDatasetOpKernel {
- public:
- explicit FeatureStatsDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx) {}
-
- void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
- DatasetBase** output) override {
- string tag;
- OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag));
- OP_REQUIRES(ctx, input->output_dtypes()[0] == DT_STRING,
- errors::InvalidArgument("FeatureStatsDataset only supports "
- "input with a single `tf.string` "
- "component."));
- *output = new Dataset(ctx, input, std::move(tag));
- }
-
- private:
- class Dataset : public DatasetBase {
- public:
- explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, string tag)
- : DatasetBase(DatasetContext(ctx)),
- input_(input),
- tag_(std::move(tag)) {
- input_->Ref();
- }
-
- ~Dataset() override { input_->Unref(); }
-
- std::unique_ptr<IteratorBase> MakeIteratorInternal(
- const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(new Iterator(
- {this, strings::StrCat(prefix, "::FeatureStatsDataset")}));
- }
-
- const DataTypeVector& output_dtypes() const override {
- return input_->output_dtypes();
- }
- const std::vector<PartialTensorShape>& output_shapes() const override {
- return input_->output_shapes();
- }
-
- string DebugString() const override {
- return "FeatureStatsDatasetOp::Dataset";
- }
-
- protected:
- Status AsGraphDefInternal(SerializationContext* ctx,
- DatasetGraphDefBuilder* b,
- Node** output) const override {
- Node* input_node;
- TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
- Node* tag_node;
- TF_RETURN_IF_ERROR(b->AddScalar(tag_, &tag_node));
- TF_RETURN_IF_ERROR(b->AddDataset(this, {input_node, tag_node}, output));
- return Status::OK();
- }
-
- private:
- class Iterator : public DatasetIterator<Dataset> {
- public:
- explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
-
- Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
- }
-
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) override {
- tf_shared_lock l(mu_);
- Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
- auto stats_aggregator = ctx->stats_aggregator();
- if (stats_aggregator && s.ok() && !*end_of_sequence) {
- for (const Tensor& t : *out_tensors) {
- auto record_t = t.flat<string>();
- Example example;
- // TODO(b/111553342): redundant parsing here, potential solutions
- // to improve performance is to a) have a potential
- // ParseExampleDataset and collect stats from there and b) make
- // changes to parse_example() where it returns stats as well.
- for (int i = 0; i < record_t.size(); ++i) {
- if (example.ParseFromString(record_t(i))) {
- stats_aggregator->IncrementCounter("examples_count", "trainer",
- 1);
- AddStatsFeatures(example, stats_aggregator);
- } else {
- SequenceExample sequence_example;
- if (sequence_example.ParseFromString(record_t(i))) {
- stats_aggregator->IncrementCounter("sequence_examples_count",
- "trainer", 1);
- AddStatsFeatures(sequence_example, stats_aggregator);
- }
- }
- }
- }
- }
- return s;
- }
-
- int AddStatsFeatureValues(const Feature& feature) {
- int feature_values_list_size = 0;
- switch (feature.kind_case()) {
- case Feature::kBytesList: {
- feature_values_list_size = feature.bytes_list().value().size();
- break;
- }
- case Feature::kFloatList: {
- feature_values_list_size = feature.float_list().value().size();
- break;
- }
- case Feature::kInt64List: {
- feature_values_list_size = feature.int64_list().value().size();
- break;
- }
- case Feature::KIND_NOT_SET:
- break;
- }
- return feature_values_list_size;
- }
-
- void AddStatsFeatures(
- const Example& example,
- const std::shared_ptr<StatsAggregator>& stats_aggregator) {
- stats_aggregator->AddToHistogram(
- strings::StrCat(dataset()->tag_, ":features"),
- {static_cast<double>(example.features().feature().size())});
-
- int feature_values_list_size_sum = 0;
- for (const auto& feature : example.features().feature()) {
- stats_aggregator->IncrementCounter("features_count", "trainer", 1);
- feature_values_list_size_sum += AddStatsFeatureValues(feature.second);
- }
- stats_aggregator->IncrementCounter("feature_values_count", "trainer",
- feature_values_list_size_sum);
- stats_aggregator->AddToHistogram(
- strings::StrCat(dataset()->tag_, ":feature-values"),
- {static_cast<double>(feature_values_list_size_sum)});
- }
-
- void AddStatsFeatures(
- const SequenceExample& example,
- const std::shared_ptr<StatsAggregator>& stats_aggregator) {
- stats_aggregator->AddToHistogram(
- strings::StrCat(dataset()->tag_, ":features"),
- {static_cast<double>(
- example.context().feature().size() +
- example.feature_lists().feature_list().size())});
-
- int feature_values_list_size_sum = 0;
- for (const auto& feature : example.context().feature()) {
- stats_aggregator->IncrementCounter("features_count", "trainer", 1);
- feature_values_list_size_sum += AddStatsFeatureValues(feature.second);
- }
-
- for (const auto& feature_list :
- example.feature_lists().feature_list()) {
- stats_aggregator->IncrementCounter("feature_lists_count", "trainer",
- 1);
- for (const auto& feature : feature_list.second.feature()) {
- feature_values_list_size_sum += AddStatsFeatureValues(feature);
- }
- }
- stats_aggregator->IncrementCounter("feature_values_count", "trainer",
- feature_values_list_size_sum);
- stats_aggregator->AddToHistogram(
- strings::StrCat(dataset()->tag_, ":feature-values"),
- {static_cast<double>(feature_values_list_size_sum)});
- }
-
- protected:
- Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
- return Status::OK();
- }
-
- Status RestoreInternal(IteratorContext* ctx,
- IteratorStateReader* reader) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
- return Status::OK();
- }
-
- private:
- mutex mu_;
- std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
- };
-
- const DatasetBase* const input_;
- const string tag_;
- };
-};
-
-REGISTER_KERNEL_BUILDER(Name("FeatureStatsDataset").Device(DEVICE_CPU),
- FeatureStatsDatasetOp);
REGISTER_KERNEL_BUILDER(Name("LatencyStatsDataset").Device(DEVICE_CPU),
LatencyStatsDatasetOp);
REGISTER_KERNEL_BUILDER(Name("BytesProducedStatsDataset").Device(DEVICE_CPU),
BytesProducedStatsDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc
index e5c237dfaa..e5cdfdd732 100644
--- a/tensorflow/core/kernels/data/take_dataset_op.cc
+++ b/tensorflow/core/kernels/data/take_dataset_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -174,5 +174,5 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("TakeDataset").Device(DEVICE_CPU), TakeDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc
index fc21c3235a..e1cefd23d8 100644
--- a/tensorflow/core/kernels/data/tensor_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc
@@ -14,10 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -28,8 +29,6 @@ class TensorDatasetOp : public DatasetOpKernel {
explicit TensorDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
- // Create a new TensorDatasetOp::Dataset, insert it in the step
- // container, and return it as the output.
OpInputList inputs;
OP_REQUIRES_OK(ctx, ctx->input_list("components", &inputs));
// TODO(mrry): Validate that the shapes of the "components" tensors match
@@ -74,7 +73,13 @@ class TensorDatasetOp : public DatasetOpKernel {
components.reserve(tensors_.size());
for (const Tensor& t : tensors_) {
Node* node;
- TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ std::vector<std::pair<string, Tensor>>* input_list = ctx->input_list();
+ if (input_list) {
+ TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
+ input_list->emplace_back(node->name(), t);
+ } else {
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ }
components.emplace_back(node);
}
AttrValue dtypes;
@@ -135,5 +140,5 @@ REGISTER_KERNEL_BUILDER(Name("TensorDataset").Device(DEVICE_CPU),
TensorDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
index ccd5e60acc..2ed636a400 100644
--- a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/core/util/batch_util.h"
namespace tensorflow {
-
+namespace data {
namespace {
bool IsGreaterEqualToOrCompatibleWith(const PartialTensorShape& a,
@@ -648,5 +648,5 @@ REGISTER_KERNEL_BUILDER(Name("EnqueueInQueueDataset").Device(DEVICE_CPU),
EnqueueInQueueDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
index 5b051e0e08..7dc64b0a75 100644
--- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
@@ -14,11 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/util/batch_util.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -30,8 +31,6 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
: DatasetOpKernel(ctx) {}
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
- // Create a new TensorDatasetOp::Dataset, insert it in the step
- // container, and return it as the output.
OpInputList inputs;
OP_REQUIRES_OK(ctx, ctx->input_list("components", &inputs));
std::vector<Tensor> components;
@@ -93,7 +92,13 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
components.reserve(tensors_.size());
for (const Tensor& t : tensors_) {
Node* node;
- TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ std::vector<std::pair<string, Tensor>>* input_list = ctx->input_list();
+ if (input_list) {
+ TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
+ input_list->emplace_back(node->name(), t);
+ } else {
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ }
components.emplace_back(node);
}
AttrValue dtypes;
@@ -163,5 +168,5 @@ REGISTER_KERNEL_BUILDER(Name("TensorSliceDataset").Device(DEVICE_CPU),
TensorSliceDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/unbatch_dataset_op.cc
index 1a79f72b28..81c432b938 100644
--- a/tensorflow/core/kernels/data/unbatch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/unbatch_dataset_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/core/util/batch_util.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -204,5 +204,5 @@ REGISTER_KERNEL_BUILDER(Name("UnbatchDataset").Device(DEVICE_CPU),
UnbatchDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/window_dataset.cc b/tensorflow/core/kernels/data/window_dataset.cc
index 0ab6beabfc..2ad4711aab 100644
--- a/tensorflow/core/kernels/data/window_dataset.cc
+++ b/tensorflow/core/kernels/data/window_dataset.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
+namespace data {
namespace {
class WindowDataset : public DatasetBase {
@@ -107,4 +108,5 @@ Status NewWindowDataset(std::vector<std::vector<Tensor>> elements,
return Status::OK();
}
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/window_dataset.h b/tensorflow/core/kernels/data/window_dataset.h
index 7bd31a0bc7..84cb3c7860 100644
--- a/tensorflow/core/kernels/data/window_dataset.h
+++ b/tensorflow/core/kernels/data/window_dataset.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
+namespace data {
// Creates a dataset representing an eagerly-collected window of elements.
//
@@ -43,6 +44,7 @@ Status NewWindowDataset(std::vector<std::vector<Tensor>> elements,
std::vector<PartialTensorShape> output_shapes,
DatasetBase** out_dataset);
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_H_
diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc
index 41bf9d43fe..3975086841 100644
--- a/tensorflow/core/kernels/data/window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/window_dataset_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/window_dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -195,5 +195,5 @@ REGISTER_KERNEL_BUILDER(Name("WindowDataset").Device(DEVICE_CPU),
WindowDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/writer_ops.cc b/tensorflow/core/kernels/data/writer_ops.cc
index 1c49874a6a..3f76695bb1 100644
--- a/tensorflow/core/kernels/data/writer_ops.cc
+++ b/tensorflow/core/kernels/data/writer_ops.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/core/platform/file_system.h"
namespace tensorflow {
-
+namespace data {
namespace {
class ToTFRecordOp : public AsyncOpKernel {
@@ -104,4 +104,5 @@ REGISTER_KERNEL_BUILDER(Name("DatasetToTFRecord").Device(DEVICE_CPU),
ToTFRecordOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/zip_dataset_op.cc b/tensorflow/core/kernels/data/zip_dataset_op.cc
index e4306579ed..61a2078f46 100644
--- a/tensorflow/core/kernels/data/zip_dataset_op.cc
+++ b/tensorflow/core/kernels/data/zip_dataset_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -175,5 +175,5 @@ class ZipDatasetOp : public DatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("ZipDataset").Device(DEVICE_CPU), ZipDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data_format_ops.h b/tensorflow/core/kernels/data_format_ops.h
index 1ca144cb40..bc416fa78b 100644
--- a/tensorflow/core/kernels/data_format_ops.h
+++ b/tensorflow/core/kernels/data_format_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_DATA_FORMAT_OPS_H_
-#define TENSORFLOW_KERNELS_DATA_FORMAT_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_FORMAT_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_FORMAT_OPS_H_
// Functor definition for data format dim mapping ops, must be compilable
// by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -83,4 +83,4 @@ struct DataFormatVecPermute {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_DATA_FORMAT_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_DATA_FORMAT_OPS_H_
diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h
index 53a23b1306..d705e82b0d 100644
--- a/tensorflow/core/kernels/debug_ops.h
+++ b/tensorflow/core/kernels/debug_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_DEBUG_OP_H_
-#define TENSORFLOW_KERNELS_DEBUG_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DEBUG_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_DEBUG_OPS_H_
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
@@ -177,8 +177,10 @@ class BaseDebugOp : public OpKernel {
// Publish a tensor to all debug URLs of the debug op.
// Log an error if the publishing failed.
- void PublishTensor(const Tensor& tensor) {
- if (!debug_urls_.empty()) {
+ Status PublishTensor(const Tensor& tensor) {
+ if (debug_urls_.empty()) {
+ return Status::OK();
+ } else {
Status status = DebugIO::PublishDebugTensor(*debug_watch_key_, tensor,
Env::Default()->NowMicros(),
debug_urls_, gated_grpc_);
@@ -189,6 +191,7 @@ class BaseDebugOp : public OpKernel {
<< str_util::Join(debug_urls_, ", ")
<< ", due to: " << status.error_message();
}
+ return status;
}
}
@@ -213,7 +216,7 @@ class DebugIdentityOp : public BaseDebugOp {
return;
}
- PublishTensor(context->input(0));
+ OP_REQUIRES_OK(context, PublishTensor(context->input(0)));
context->set_output(0, context->input(0));
}
};
@@ -252,7 +255,7 @@ class DebugNanCountOp : public BaseDebugOp {
TensorShape shape({1});
OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output_tensor));
output_tensor->vec<int64>()(0) = nan_count;
- PublishTensor(*output_tensor);
+ OP_REQUIRES_OK(context, PublishTensor(*output_tensor));
}
};
@@ -377,7 +380,7 @@ class DebugNumericSummaryOp : public BaseDebugOp {
bool mute = mute_if_healthy_ && nan_count == 0 && negative_inf_count == 0 &&
positive_inf_count == 0;
if (!mute) {
- PublishTensor(*output_tensor);
+ OP_REQUIRES_OK(context, PublishTensor(*output_tensor));
}
}
@@ -389,4 +392,4 @@ class DebugNumericSummaryOp : public BaseDebugOp {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_DEBUG_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_DEBUG_OPS_H_
diff --git a/tensorflow/core/kernels/dense_update_functor.h b/tensorflow/core/kernels/dense_update_functor.h
index 240c13261e..61b5731250 100644
--- a/tensorflow/core/kernels/dense_update_functor.h
+++ b/tensorflow/core/kernels/dense_update_functor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_DENSE_UPDATE_FUNCTOR_H_
-#define TENSORFLOW_KERNELS_DENSE_UPDATE_FUNCTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DENSE_UPDATE_FUNCTOR_H_
+#define TENSORFLOW_CORE_KERNELS_DENSE_UPDATE_FUNCTOR_H_
#define EIGEN_USE_THREADS
@@ -105,4 +105,4 @@ Status VariantCopyFn<GPUDevice>(OpKernelContext* context, const Tensor& from,
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_DENSE_UPDATE_FUNCTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_DENSE_UPDATE_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/dynamic_stitch_op.cc b/tensorflow/core/kernels/dynamic_stitch_op.cc
index b01db91720..fb2a4cc8ef 100644
--- a/tensorflow/core/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/core/kernels/dynamic_stitch_op.cc
@@ -247,8 +247,8 @@ class DynamicStitchOpImplCPU : public DynamicStitchOpImplBase<T> {
data.shaped<T, 2>({indices_vec.dimension(0), slice_size});
if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
- T* merged_base = &merged_flat(0, 0);
- const T* data_base = &data_flat(0, 0);
+ T* merged_base = merged_flat.data();
+ const T* data_base = data_flat.data();
for (int i = 0; i < indices_vec.size(); i++) {
int32 index = internal::SubtleMustCopy(indices_vec(i));
OP_REQUIRES(
diff --git a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
index e13e548f86..27918b410b 100644
--- a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
+++ b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
@@ -51,14 +51,18 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
internal::traits<OutputBackward>::NumDimensions>,
const TensorContractionOp<
const array<
- IndexPair<typename internal::traits<OutputBackward>::Index>, 2>,
- const TensorReshapingOp<
+ IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
+ const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
- 3>,
- const TensorReverseOp<const array<bool, 5>, const Kernel> >,
+ 2>,
+ const TensorShufflingOp<
+ const array<
+ typename internal::traits<OutputBackward>::Index, 5>,
+ const TensorReverseOp<const Eigen::array<bool, 5>,
+ const Kernel> > > >,
const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
- 3>,
+ 2>,
const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
const OutputBackward> > > >,
TensorReshapingOp<
@@ -66,24 +70,27 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
internal::traits<OutputBackward>::NumDimensions>,
const TensorContractionOp<
const array<
- IndexPair<typename internal::traits<OutputBackward>::Index>, 2>,
+ IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
- 3>,
+ 2>,
const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
const OutputBackward> >,
- const TensorReshapingOp<
+ const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
- 3>,
- const TensorReverseOp<const array<bool, 5>,
- const Kernel> > > > >::type
+ 2>,
+ const TensorShufflingOp<
+ const array<
+ typename internal::traits<OutputBackward>::Index, 5>,
+ const TensorReverseOp<const Eigen::array<bool, 5>,
+ const Kernel> > > > > > >::type
CuboidConvolutionBackwardInput(
const Kernel& kernel, const OutputBackward& output_backward,
typename internal::traits<OutputBackward>::Index inputPlanes,
typename internal::traits<OutputBackward>::Index inputRows,
typename internal::traits<OutputBackward>::Index inputCols,
- const DenseIndex stridePlanes = 1, const DenseIndex strideRows = 1,
- const DenseIndex strideCols = 1) {
+ const DenseIndex plane_stride = 1, const DenseIndex row_stride = 1,
+ const DenseIndex col_stride = 1) {
typedef typename internal::traits<OutputBackward>::Index TensorIndex;
const TensorRef<const Tensor<typename internal::traits<Kernel>::Scalar,
internal::traits<Kernel>::NumDimensions,
@@ -125,58 +132,45 @@ CuboidConvolutionBackwardInput(
const TensorIndex outputCols =
isColMajor ? out.dimensions()[3] : out.dimensions()[NumDims - 4];
- TensorIndex forward_pad_z, forward_pad_y, forward_pad_x;
- const TensorIndex size_z =
- Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
- const TensorIndex size_y =
- Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
- const TensorIndex size_x =
- Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
-
- // Infer padding type.
- if (size_z == outputPlanes && size_y == outputRows && size_x == outputCols) {
- // SAME padding.
- const TensorIndex dz = numext::maxi<TensorIndex>(
- 0, (size_z - 1) * stridePlanes + kernelPlanes - inputPlanes);
- const TensorIndex dy = numext::maxi<TensorIndex>(
- 0, (size_y - 1) * strideRows + kernelRows - inputRows);
- const TensorIndex dx = numext::maxi<TensorIndex>(
- 0, (size_x - 1) * strideCols + kernelCols - inputCols);
-
- forward_pad_z = dz / 2;
- forward_pad_y = dy / 2;
- forward_pad_x = dx / 2;
- } else {
- // VALID padding.
- forward_pad_z = 0;
- forward_pad_y = 0;
- forward_pad_x = 0;
- }
- const TensorIndex padding_ztop = kernelPlanes - 1 - forward_pad_z;
- const TensorIndex padding_top = kernelRows - 1 - forward_pad_y;
- const TensorIndex padding_left = kernelCols - 1 - forward_pad_x;
-
- const TensorIndex padding_zbottom = inputPlanes + kernelPlanes - 1 -
- (outputPlanes - 1) * stridePlanes - 1 -
- padding_ztop;
- const TensorIndex padding_bottom = inputRows + kernelRows - 1 -
- (outputRows - 1) * strideRows - 1 -
- padding_top;
- const TensorIndex padding_right = inputCols + kernelCols - 1 -
- (outputCols - 1) * strideCols - 1 -
- padding_left;
-
- eigen_assert(padding_ztop >= 0);
- eigen_assert(padding_zbottom >= 0);
+ // TODO(ezhulenev): Add support for inflated strides. Without inflated strides
+ // effective kernel planes/rows/cols are always the same as the kernel itself
+ // (see eigen_spatial_convolutions for details).
+ const TensorIndex kernelPlanesEff = kernelPlanes;
+ const TensorIndex kernelRowsEff = kernelRows;
+ const TensorIndex kernelColsEff = kernelCols;
+
+ // Computing the forward padding.
+ const TensorIndex forward_pad_top_z = numext::maxi<Index>(
+ 0,
+ ((outputPlanes - 1) * plane_stride + kernelPlanesEff - inputPlanes) / 2);
+ const TensorIndex forward_pad_top = numext::maxi<Index>(
+ 0, ((outputRows - 1) * row_stride + kernelRowsEff - inputRows) / 2);
+ const TensorIndex forward_pad_left = numext::maxi<Index>(
+ 0, ((outputCols - 1) * col_stride + kernelColsEff - inputCols) / 2);
+
+ const TensorIndex padding_top_z = kernelPlanesEff - 1 - forward_pad_top_z;
+ const TensorIndex padding_top = kernelRowsEff - 1 - forward_pad_top;
+ const TensorIndex padding_left = kernelColsEff - 1 - forward_pad_left;
+
+ const TensorIndex padding_bottom_z = inputPlanes -
+ (outputPlanes - 1) * plane_stride - 2 -
+ padding_top_z + kernelPlanesEff;
+ const TensorIndex padding_bottom = inputRows - (outputRows - 1) * row_stride -
+ 2 - padding_top + kernelRowsEff;
+ const TensorIndex padding_right = inputCols - (outputCols - 1) * col_stride -
+ 2 - padding_left + kernelColsEff;
+
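+  // Worked example (illustrative): with inputRows == outputRows == 5,
+  // kernelRowsEff == 3 and row_stride == 1 (i.e. SAME padding in the forward
+  // pass), forward_pad_top == (4 + 3 - 5) / 2 == 1, so padding_top ==
+  // 3 - 1 - 1 == 1 and padding_bottom == 5 - 4 - 2 - 1 + 3 == 1: the
+  // kernelRowsEff - 1 == 2 rows of backward padding are split evenly.
+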
+ eigen_assert(padding_top_z >= 0);
eigen_assert(padding_top >= 0);
eigen_assert(padding_left >= 0);
+ eigen_assert(padding_bottom_z >= 0);
eigen_assert(padding_bottom >= 0);
eigen_assert(padding_right >= 0);
- // The kernel has dimensions filters X channels X patch_planes X patch_rows X
- // patch_cols.
+  // The kernel has dimensions:
+ // filters x channels x patch_planes x patch_rows x patch_cols.
// We need to reverse the kernel along the spatial dimensions.
- array<bool, 5> kernel_reverse;
+ Eigen::array<bool, 5> kernel_reverse;
if (isColMajor) {
kernel_reverse[0] = false;
kernel_reverse[1] = false;
@@ -191,15 +185,35 @@ CuboidConvolutionBackwardInput(
kernel_reverse[4] = false;
}
- DSizes<TensorIndex, 3> kernel_dims;
+ // Reorder the dimensions to:
+ // filters x patch_planes x patch_rows x patch_cols x channels
+ array<TensorIndex, 5> kernel_shuffle;
if (isColMajor) {
- kernel_dims[0] = kernelFilters;
- kernel_dims[1] = kernelChannels;
- kernel_dims[2] = kernelRows * kernelCols * kernelPlanes;
+ // From: filters x channels x planes x rows x cols
+ // To: filters x planes x rows x cols x channels
+ kernel_shuffle[0] = 0;
+ kernel_shuffle[1] = 2;
+ kernel_shuffle[2] = 3;
+ kernel_shuffle[3] = 4;
+ kernel_shuffle[4] = 1;
} else {
- kernel_dims[0] = kernelRows * kernelCols * kernelPlanes;
+ // From: cols x rows x planes x channels x filters
+ // To: channels x cols x rows x planes x filters
+ kernel_shuffle[0] = 3;
+ kernel_shuffle[1] = 0;
+ kernel_shuffle[2] = 1;
+ kernel_shuffle[3] = 2;
+ kernel_shuffle[4] = 4;
+ }
+
+ // Collapse the dims
+ DSizes<TensorIndex, 2> kernel_dims;
+ if (isColMajor) {
+ kernel_dims[0] = kernelFilters * kernelPlanes * kernelRows * kernelCols;
kernel_dims[1] = kernelChannels;
- kernel_dims[2] = kernelFilters;
+ } else {
+ kernel_dims[1] = kernelFilters * kernelPlanes * kernelRows * kernelCols;
+ kernel_dims[0] = kernelChannels;
}
// The output_backward has dimensions out_depth X out_planes X out_rows X
@@ -208,36 +222,32 @@ CuboidConvolutionBackwardInput(
// dimensions:
// out_depth X (patch_planes * patch_rows * patch_cols) X (input_planes *
// input_rows * input_cols * OTHERS)
- DSizes<TensorIndex, 3> pre_contract_dims;
+ DSizes<TensorIndex, 2> pre_contract_dims;
if (isColMajor) {
- pre_contract_dims[0] = kernelFilters;
- pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes;
- pre_contract_dims[2] = inputRows * inputCols * inputPlanes;
+ pre_contract_dims[0] =
+ kernelFilters * kernelPlanes * kernelRows * kernelCols;
+ pre_contract_dims[1] = inputPlanes * inputRows * inputCols;
for (int i = 4; i < NumDims; ++i) {
- pre_contract_dims[2] *= out.dimension(i);
+ pre_contract_dims[1] *= out.dimension(i);
}
} else {
- pre_contract_dims[2] = kernelFilters;
- pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes;
- pre_contract_dims[0] = inputRows * inputCols * inputPlanes;
+ pre_contract_dims[1] =
+ kernelFilters * kernelPlanes * kernelRows * kernelCols;
+ pre_contract_dims[0] = inputPlanes * inputRows * inputCols;
for (int i = 0; i < NumDims - 4; ++i) {
pre_contract_dims[0] *= out.dimension(i);
}
}
- // We will contract along dimensions (0, 2) in kernel and (0, 1) in
- // output_backward, if this is col-major, and
- // dimensions (0, 2) in kernel and (1, 2) in output_backward, if this
- // row-major.
- array<IndexPair<TensorIndex>, 2> contract_dims;
+ // We will contract along the fused dimension that contains the kernelFilters,
+ // kernelPlanes, kernelRows and kernelCols.
+ array<IndexPair<TensorIndex>, 1> contract_dims;
if (isColMajor) {
// col-major: kernel.contract(output.patches)
contract_dims[0] = IndexPair<TensorIndex>(0, 0);
- contract_dims[1] = IndexPair<TensorIndex>(2, 1);
} else {
// row-major: output.patches.contract(kernel)
- contract_dims[0] = IndexPair<TensorIndex>(1, 0);
- contract_dims[1] = IndexPair<TensorIndex>(2, 2);
+ contract_dims[0] = IndexPair<TensorIndex>(1, 1);
}
// Post contraction, the dimensions of the input_backprop is
@@ -261,40 +271,31 @@ CuboidConvolutionBackwardInput(
}
}
- DSizes<TensorIndex, NumDims> strides;
- for (int i = 0; i < NumDims; i++) {
- strides[i] = 1;
- }
- if (isColMajor) {
- strides[1] = stridePlanes;
- strides[2] = strideRows;
- strides[3] = strideCols;
- } else {
- strides[NumDims - 2] = stridePlanes;
- strides[NumDims - 3] = strideRows;
- strides[NumDims - 4] = strideCols;
- }
-
return choose(
Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
kernel.reverse(kernel_reverse)
+ .shuffle(kernel_shuffle)
.reshape(kernel_dims)
+ .eval()
.contract(output_backward
.extract_volume_patches(
kernelPlanes, kernelRows, kernelCols, 1, 1, 1,
- stridePlanes, strideRows, strideCols, padding_ztop,
- padding_zbottom, padding_top, padding_bottom,
+ plane_stride, row_stride, col_stride, padding_top_z,
+ padding_bottom_z, padding_top, padding_bottom,
padding_left, padding_right)
.reshape(pre_contract_dims),
contract_dims)
.reshape(post_contract_dims),
output_backward
.extract_volume_patches(kernelPlanes, kernelRows, kernelCols, 1, 1, 1,
- stridePlanes, strideRows, strideCols,
- padding_ztop, padding_zbottom, padding_top,
+ plane_stride, row_stride, col_stride,
+ padding_top_z, padding_bottom_z, padding_top,
padding_bottom, padding_left, padding_right)
.reshape(pre_contract_dims)
- .contract(kernel.reverse(kernel_reverse).reshape(kernel_dims),
+ .contract(kernel.reverse(kernel_reverse)
+ .shuffle(kernel_shuffle)
+ .reshape(kernel_dims)
+ .eval(),
contract_dims)
.reshape(post_contract_dims));
}
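
The change from two contraction pairs over rank-3 tensors to a single pair over rank-2 tensors is the heart of this refactoring: the filters, planes, rows, and cols axes are fused into one contraction dimension, so the contraction lowers to a single matrix multiply. A standalone sketch of such a single-pair contraction (sizes are illustrative, not from the diff; only Eigen's unsupported Tensor module is assumed):

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  // kernel:  [filters * planes * rows * cols, channels]
  // patches: [filters * planes * rows * cols, spatial locations]
  Eigen::Tensor<float, 2, Eigen::ColMajor> kernel(24, 3);
  Eigen::Tensor<float, 2, Eigen::ColMajor> patches(24, 10);
  kernel.setRandom();
  patches.setRandom();

  // Contract dim 0 of kernel against dim 0 of patches, exactly like the
  // col-major IndexPair(0, 0) above; the free dims remain: [channels, locations].
  Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims;
  contract_dims[0] = Eigen::IndexPair<Eigen::Index>(0, 0);

  Eigen::Tensor<float, 2, Eigen::ColMajor> result =
      kernel.contract(patches, contract_dims);

  return (result.dimension(0) == 3 && result.dimension(1) == 10) ? 0 : 1;
}
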
@@ -323,47 +324,34 @@ CuboidConvolutionBackwardInput(
template <typename OutputBackward, typename Input>
EIGEN_ALWAYS_INLINE static const typename internal::conditional<
internal::traits<OutputBackward>::Layout == ColMajor,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index, 5>,
- const TensorReverseOp<
- const array<bool, 5>,
+ TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 5>,
+ const TensorContractionOp<
+ const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
const TensorReshapingOp<
- const DSizes<typename internal::traits<OutputBackward>::Index,
- 5>,
- const TensorContractionOp<
- const array<
- IndexPair<typename internal::traits<Input>::Index>, 2>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index,
- 3>,
- const Input>,
- const TensorReshapingOp<
- const DSizes<
- typename internal::traits<OutputBackward>::Index,
- 4>,
- const TensorVolumePatchOp<
- Dynamic, Dynamic, Dynamic,
- const OutputBackward> > > > > >,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index, 5>,
- const TensorReverseOp<
- const array<bool, 5>,
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const OutputBackward>,
+ const TensorShufflingOp<
+ const array<typename internal::traits<OutputBackward>::Index,
+ 2>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
+ const Input> > > > >,
+ TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 5>,
+ const TensorContractionOp<
+ const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
+ const TensorShufflingOp<
+ const array<typename internal::traits<OutputBackward>::Index,
+ 2>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
+ const Input> > >,
const TensorReshapingOp<
- const DSizes<typename internal::traits<OutputBackward>::Index,
- 5>,
- const TensorContractionOp<
- const array<
- IndexPair<typename internal::traits<Input>::Index>, 2>,
- const TensorReshapingOp<
- const DSizes<
- typename internal::traits<OutputBackward>::Index,
- 4>,
- const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
- const OutputBackward> >,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index,
- 3>,
- const Input> > > > > >::type
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const OutputBackward> > > >::type
CuboidConvolutionBackwardKernel(
const Input& input, const OutputBackward& output_backward,
typename internal::traits<Input>::Index kernelPlanes,
@@ -406,213 +394,114 @@ CuboidConvolutionBackwardKernel(
const TensorIndex outputCols =
isColMajor ? out.dimension(3) : out.dimension(NumDims - 4);
+ // Number of filters. This is the same as the output depth.
const TensorIndex kernelFilters =
isColMajor ? out.dimension(0) : out.dimension(NumDims - 1);
+ // Number of channels. This is the same as the input depth.
const TensorIndex kernelChannels =
isColMajor ? in.dimension(0) : in.dimension(NumDims - 1);
- TensorIndex forward_pad_z, forward_pad_y, forward_pad_x;
- const TensorIndex size_z =
- Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
- const TensorIndex size_y =
- Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
- const TensorIndex size_x =
- Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
-
- // Infer padding type.
- if (size_z == outputPlanes && size_y == outputRows && size_x == outputCols) {
- // SAME padding.
- const TensorIndex dz = numext::maxi<TensorIndex>(
- 0, (size_z - 1) * stridePlanes + kernelPlanes - inputPlanes);
- const TensorIndex dy = numext::maxi<TensorIndex>(
- 0, (size_y - 1) * strideRows + kernelRows - inputRows);
- const TensorIndex dx = numext::maxi<TensorIndex>(
- 0, (size_x - 1) * strideCols + kernelCols - inputCols);
-
- forward_pad_z = dz / 2;
- forward_pad_y = dy / 2;
- forward_pad_x = dx / 2;
- } else {
- // VALID padding.
- forward_pad_z = 0;
- forward_pad_y = 0;
- forward_pad_x = 0;
- }
-
- const TensorIndex padding_ztop = kernelPlanes - 1 - forward_pad_z;
- const TensorIndex padding_top = kernelRows - 1 - forward_pad_y;
- const TensorIndex padding_left = kernelCols - 1 - forward_pad_x;
-
- const TensorIndex padding_zbottom = inputPlanes + kernelPlanes - 1 -
- (outputPlanes - 1) * stridePlanes - 1 -
- padding_ztop;
- const TensorIndex padding_bottom = inputRows + kernelRows - 1 -
- (outputRows - 1) * strideRows - 1 -
- padding_top;
- const TensorIndex padding_right = inputCols + kernelCols - 1 -
- (outputCols - 1) * strideCols - 1 -
- padding_left;
-
- eigen_assert(padding_ztop >= 0);
- eigen_assert(padding_zbottom >= 0);
- eigen_assert(padding_top >= 0);
- eigen_assert(padding_left >= 0);
- eigen_assert(padding_bottom >= 0);
- eigen_assert(padding_right >= 0);
-
- // The output_backward has dimensions out_depth X out_plaens X out_rows X
- // out_cols X OTHERS
- // When we extract the image patches from output_backward (with input as the
- // kernel), it will have dimensions
- // (out_depth) X (input_planes * input_rows * input_cols) X (kernel_planes *
- // kernel_rows * kernel_cols) X OTHERS
- DSizes<TensorIndex, 4> pre_contract_dims;
+ // TODO(ezhulenev): Add support for inflated strides. Without inflated strides
+ // effective kernel planes/rows/cols are always the same as the kernel itself
+ // (see eigen_spatial_convolutions for details).
+ const TensorIndex kernelPlanesEff = kernelPlanes;
+ const TensorIndex kernelRowsEff = kernelRows;
+ const TensorIndex kernelColsEff = kernelCols;
+
+ const TensorIndex padPlanes = numext::maxi<Index>(
+ 0, (outputPlanes - 1) * stridePlanes + kernelPlanesEff - inputPlanes);
+ const TensorIndex padRows = numext::maxi<Index>(
+ 0, (outputRows - 1) * strideRows + kernelRowsEff - inputRows);
+ const TensorIndex padCols = numext::maxi<Index>(
+ 0, (outputCols - 1) * strideCols + kernelColsEff - inputCols);
+
+ const TensorIndex padding_top_z = padPlanes / 2;
+ const TensorIndex padding_bottom_z = padPlanes - padding_top_z;
+ const TensorIndex padding_top = padRows / 2;
+ const TensorIndex padding_bottom = padRows - padding_top;
+ const TensorIndex padding_left = padCols / 2;
+ const TensorIndex padding_right = padCols - padding_left;
+
+ // Reshaped output_backward before contraction.
+ DSizes<TensorIndex, 2> output_dims;
if (isColMajor) {
- pre_contract_dims[0] = kernelFilters;
- pre_contract_dims[1] = inputRows * inputCols * inputPlanes;
- pre_contract_dims[2] = kernelRows * kernelCols * kernelPlanes;
- pre_contract_dims[3] = 1;
+ output_dims[0] = kernelFilters;
+ output_dims[1] = outputPlanes * outputRows * outputCols;
for (int i = 4; i < NumDims; ++i) {
- pre_contract_dims[3] *= out.dimension(i);
+ output_dims[1] *= out.dimension(i);
}
} else {
- pre_contract_dims[3] = kernelFilters;
- pre_contract_dims[2] = inputRows * inputCols * inputPlanes;
- pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes;
- pre_contract_dims[0] = 1;
+ output_dims[1] = kernelFilters;
+ output_dims[0] = outputCols * outputRows * outputPlanes;
for (int i = 0; i < NumDims - 4; ++i) {
- pre_contract_dims[0] *= out.dimension(i);
+ output_dims[0] *= out.dimension(i);
}
}
- // The input has dimensions in_depth X (input_planes * input_rows *
- // input_cols) X OTHERS
- DSizes<TensorIndex, 3> input_dims;
+ // Reshaped extract_volume_patches(in)
+ DSizes<TensorIndex, 2> pre_contract_dims;
if (isColMajor) {
- input_dims[0] = kernelChannels;
- input_dims[1] = inputRows * inputCols * inputPlanes;
- input_dims[2] = 1;
+ pre_contract_dims[0] =
+ kernelChannels * kernelPlanes * kernelRows * kernelCols;
+ pre_contract_dims[1] = outputPlanes * outputRows * outputCols;
for (int i = 4; i < NumDims; ++i) {
- input_dims[2] *= in.dimension(i);
+ pre_contract_dims[1] *= in.dimension(i);
}
- eigen_assert(input_dims[2] == pre_contract_dims[3]);
+ eigen_assert(output_dims[1] == pre_contract_dims[1]);
} else {
- input_dims[2] = kernelChannels;
- input_dims[1] = inputRows * inputCols * inputPlanes;
- input_dims[0] = 1;
+ pre_contract_dims[1] =
+ kernelCols * kernelRows * kernelPlanes * kernelChannels;
+ pre_contract_dims[0] = outputCols * outputRows * outputPlanes;
for (int i = 0; i < NumDims - 4; ++i) {
- input_dims[0] *= in.dimension(i);
+ pre_contract_dims[0] *= in.dimension(i);
}
- eigen_assert(input_dims[0] == pre_contract_dims[0]);
+ eigen_assert(output_dims[0] == pre_contract_dims[0]);
}
- // We will contract along dimensions (1, 2) in and (1, 3) in out, if
- // this is col-major.
- // For row-major, it's dimensions (0, 1) in and (0, 2) in out.
- array<IndexPair<TensorIndex>, 2> contract_dims;
- if (isColMajor) {
- // col-major: in.contract(output.patches)
- contract_dims[0] = IndexPair<TensorIndex>(1, 1);
- contract_dims[1] = IndexPair<TensorIndex>(2, 3);
- } else {
- // row-major: output.patches.contract(in)
- contract_dims[0] = IndexPair<TensorIndex>(0, 0);
- contract_dims[1] = IndexPair<TensorIndex>(2, 1);
- }
+ array<TensorIndex, 2> shuffle_dims;
+ shuffle_dims[0] = 1;
+ shuffle_dims[1] = 0;
- // After the contraction, the kernel will have dimension
- // in_depth X out_depth X kernel_patches X kernel_rows X kernel_cols
- // We will need to shuffle the first two dimensions and reverse the spatial
- // dimensions.
- // The end shape is:
- // out_depth X in_shape X kernel_planes X kernel_rows X kernel_cols
+ array<IndexPair<TensorIndex>, 1> contract_dims;
+ contract_dims[0] = IndexPair<TensorIndex>(1, 0);
- // This is the shape of the kernel *before* the shuffling.
DSizes<TensorIndex, 5> kernel_dims;
if (isColMajor) {
- kernel_dims[0] = kernelChannels;
- kernel_dims[1] = kernelFilters;
+ kernel_dims[0] = kernelFilters;
+ kernel_dims[1] = kernelChannels;
kernel_dims[2] = kernelPlanes;
kernel_dims[3] = kernelRows;
kernel_dims[4] = kernelCols;
} else {
- kernel_dims[0] = kernelCols;
- kernel_dims[1] = kernelRows;
+ kernel_dims[4] = kernelFilters;
+ kernel_dims[3] = kernelChannels;
kernel_dims[2] = kernelPlanes;
- kernel_dims[3] = kernelFilters;
- kernel_dims[4] = kernelChannels;
- }
-
- // Flip filters and channels.
- array<TensorIndex, 5> kernel_shuffle;
- if (isColMajor) {
- kernel_shuffle[0] = 1;
- kernel_shuffle[1] = 0;
- kernel_shuffle[2] = 2;
- kernel_shuffle[3] = 3;
- kernel_shuffle[4] = 4;
- } else {
- kernel_shuffle[0] = 0;
- kernel_shuffle[1] = 1;
- kernel_shuffle[2] = 2;
- kernel_shuffle[3] = 4;
- kernel_shuffle[4] = 3;
- }
-
- // Reverse the spatial dimensions.
- array<bool, 5> kernel_reverse;
- if (isColMajor) {
- kernel_reverse[0] = false;
- kernel_reverse[1] = false;
- kernel_reverse[2] = true;
- kernel_reverse[3] = true;
- kernel_reverse[4] = true;
- } else {
- kernel_reverse[0] = true;
- kernel_reverse[1] = true;
- kernel_reverse[2] = true;
- kernel_reverse[3] = false;
- kernel_reverse[4] = false;
+ kernel_dims[1] = kernelRows;
+ kernel_dims[0] = kernelCols;
}
- DSizes<TensorIndex, NumDims> strides;
- for (int i = 0; i < NumDims; i++) {
- strides[i] = 1;
- }
- if (isColMajor) {
- strides[1] = stridePlanes;
- strides[2] = strideRows;
- strides[3] = strideCols;
- } else {
- strides[NumDims - 2] = stridePlanes;
- strides[NumDims - 3] = strideRows;
- strides[NumDims - 4] = strideCols;
- }
return choose(
Cond<internal::traits<Input>::Layout == ColMajor>(),
- input.reshape(input_dims)
- .contract(output_backward
+ output_backward.reshape(output_dims)
+ .contract(input
.extract_volume_patches(
- inputPlanes, inputRows, inputCols, 1, 1, 1,
- stridePlanes, strideRows, strideCols,
-
- padding_ztop, padding_zbottom, padding_top,
- padding_bottom, padding_left, padding_right)
- .reshape(pre_contract_dims),
+ kernelPlanes, kernelRows, kernelCols, stridePlanes,
+ strideRows, strideCols, 1, 1, 1, padding_top_z,
+ padding_bottom_z, padding_top, padding_bottom,
+ padding_left, padding_right)
+ .reshape(pre_contract_dims)
+ .shuffle(shuffle_dims),
contract_dims)
- .reshape(kernel_dims)
- .reverse(kernel_reverse)
- .shuffle(kernel_shuffle),
- output_backward
- .extract_volume_patches(inputPlanes, inputRows, inputCols, 1, 1, 1,
- stridePlanes, strideRows, strideCols,
- padding_ztop, padding_zbottom, padding_top,
+ .reshape(kernel_dims),
+ input
+ .extract_volume_patches(kernelPlanes, kernelRows, kernelCols,
+ stridePlanes, strideRows, strideCols, 1, 1, 1,
+ padding_top_z, padding_bottom_z, padding_top,
padding_bottom, padding_left, padding_right)
.reshape(pre_contract_dims)
- .contract(input.reshape(input_dims), contract_dims)
- .reshape(kernel_dims)
- .reverse(kernel_reverse)
- .shuffle(kernel_shuffle));
+ .shuffle(shuffle_dims)
+ .contract(output_backward.reshape(output_dims), contract_dims)
+ .reshape(kernel_dims));
}
} // end namespace Eigen
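
The transposed-padding algebra above is easy to sanity-check numerically: for a stride-1 forward convolution with SAME padding, the padding applied to the output gradient must yield exactly one extracted patch per input position. A standalone sketch of the row dimension (illustrative sizes, not from the diff):

#include <algorithm>
#include <cassert>

int main() {
  const long inputRows = 5, kernelRows = 3, outputRows = 5, row_stride = 1;

  // SAME padding of the forward pass (mirrors forward_pad_top above).
  const long forward_pad_top = std::max(
      0L, ((outputRows - 1) * row_stride + kernelRows - inputRows) / 2);

  // Transposed padding, per CuboidConvolutionBackwardInput.
  const long padding_top = kernelRows - 1 - forward_pad_top;  // == 1
  const long padding_bottom = inputRows - (outputRows - 1) * row_stride - 2 -
                              padding_top + kernelRows;  // == 1

  // Patch positions over the padded output gradient (patch stride is 1).
  const long patch_rows =
      outputRows + padding_top + padding_bottom - kernelRows + 1;
  assert(patch_rows == inputRows);  // 5 + 1 + 1 - 3 + 1 == 5
  return 0;
}
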
diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
index 099696105b..8d06107553 100644
--- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
+++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
@@ -189,14 +189,19 @@ SpatialConvolutionBackwardInput(
}
#endif
- // Reorder the dimensions to filters X patch_rows X patch_cols X channels
+ // Reorder the dimensions to:
+ // filters x patch_rows x patch_cols x channels
array<TensorIndex, 4> kernel_shuffle;
if (isColMajor) {
+ // From: filters x channels x rows x cols
+ // To: filters x rows x cols x channels
kernel_shuffle[0] = 0;
kernel_shuffle[1] = 2;
kernel_shuffle[2] = 3;
kernel_shuffle[3] = 1;
} else {
+ // From: cols x rows x channels x filters
+ // To: channels x cols x rows x filters
kernel_shuffle[0] = 2;
kernel_shuffle[1] = 0;
kernel_shuffle[2] = 1;
@@ -499,4 +504,4 @@ SpatialConvolutionBackwardKernel(
} // end namespace Eigen
-#endif // EIGEN_CXX11_NEURAL_NETWORKS_BACKWARD_SPATIAL_CONVOLUTIONS_H
+#endif // TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_
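
Both the spatial and the cuboid backward-input passes give the kernel the same reverse-then-shuffle treatment before the contraction. A standalone sketch of the col-major spatial case (illustrative sizes, not from the diff; only Eigen's unsupported Tensor module is assumed):

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  // Col-major kernel: filters x channels x rows x cols.
  Eigen::Tensor<float, 4, Eigen::ColMajor> kernel(2, 3, 5, 5);
  kernel.setRandom();

  Eigen::array<bool, 4> kernel_reverse;  // flip the spatial dimensions
  kernel_reverse[0] = false;
  kernel_reverse[1] = false;
  kernel_reverse[2] = true;
  kernel_reverse[3] = true;

  Eigen::array<int, 4> kernel_shuffle;  // to: filters x rows x cols x channels
  kernel_shuffle[0] = 0;
  kernel_shuffle[1] = 2;
  kernel_shuffle[2] = 3;
  kernel_shuffle[3] = 1;

  Eigen::Tensor<float, 4, Eigen::ColMajor> lhs =
      kernel.reverse(kernel_reverse).shuffle(kernel_shuffle);

  // Spatial axes are flipped and channels are moved to the last dimension:
  // lhs(f, r, c, ch) == kernel(f, ch, 4 - r, 4 - c).
  return lhs(0, 1, 2, 0) == kernel(0, 0, 3, 2) ? 0 : 1;
}
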
diff --git a/tensorflow/core/kernels/eigen_benchmark.h b/tensorflow/core/kernels/eigen_benchmark.h
new file mode 100644
index 0000000000..87e41b89b3
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_benchmark.h
@@ -0,0 +1,304 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_BENCHMARK_H_
+#define TENSORFLOW_CORE_KERNELS_EIGEN_BENCHMARK_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h"
+#include "tensorflow/core/kernels/eigen_backward_spatial_convolutions.h"
+#include "tensorflow/core/kernels/eigen_cuboid_convolution.h"
+#include "tensorflow/core/kernels/eigen_spatial_convolutions.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+using ::tensorflow::TTypes;
+
+template <typename Scalar, typename Device>
+class SpatialConvolutionBenchmarksSuite {
+ public:
+ using Input = TTypes<float, 4>::ConstTensor;
+ using Filter = TTypes<float, 4>::ConstTensor;
+ using Output = TTypes<float, 4>::Tensor;
+
+ using Dimensions = Eigen::DSizes<Eigen::Index, 4>;
+
+ SpatialConvolutionBenchmarksSuite(int iters, Device& device)
+ : iters_(iters), device_(device) {}
+
+ Eigen::Index BufferSize(const Dimensions& dims) {
+ return dims.TotalSize() * sizeof(Scalar);
+ }
+
+ void SpatialConvolution(Dimensions input_dims, Dimensions filter_dims) {
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ filter_dims[3]); // filter_count
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ Scalar* filter_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+ Scalar* output_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+
+ device_.memset(input_data, 123, BufferSize(input_dims));
+ device_.memset(filter_data, 123, BufferSize(filter_dims));
+
+ Input input(input_data, input_dims);
+ Filter filter(filter_data, filter_dims);
+ Output output(output_data, output_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ output.device(device_) = Eigen::SpatialConvolution(input, filter);
+ tensorflow::testing::DoNotOptimize(output);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(input_data);
+ device_.deallocate(filter_data);
+ device_.deallocate(output_data);
+ }
+
+ void SpatialConvolutionBackwardInput(Dimensions input_dims,
+ Dimensions filter_dims) {
+ using OutputBackward = TTypes<float, 4>::ConstTensor;
+ using InputBackward = TTypes<float, 4>::Tensor;
+
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ filter_dims[3]); // filter_count
+
+ // Assuming that the convolution had SAME padding.
+ Eigen::Index input_rows = input_dims[1];
+ Eigen::Index input_cols = input_dims[2];
+
+ Scalar* filter_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+ Scalar* output_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+ Scalar* input_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+
+ device_.memset(filter_data, 123, BufferSize(filter_dims));
+ device_.memset(output_backward_data, 123, BufferSize(output_dims));
+
+ Filter filter(filter_data, filter_dims);
+ OutputBackward output_backward(output_backward_data, output_dims);
+ InputBackward input_backward(input_backward_data, input_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ input_backward.device(device_) = Eigen::SpatialConvolutionBackwardInput(
+ filter, output_backward, input_rows, input_cols);
+ tensorflow::testing::DoNotOptimize(input_backward);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(filter_data);
+ device_.deallocate(output_backward_data);
+ device_.deallocate(input_backward_data);
+ }
+
+ void SpatialConvolutionBackwardKernel(Dimensions input_dims,
+ Dimensions filter_dims) {
+ using OutputBackward = TTypes<float, 4>::ConstTensor;
+ using FilterBackward = TTypes<float, 4>::Tensor;
+
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ filter_dims[3]); // filter_count
+
+ // Assuming that the convolution had SAME padding.
+ Eigen::Index filter_rows = filter_dims[0];
+ Eigen::Index filter_cols = filter_dims[1];
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ Scalar* output_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+ Scalar* filter_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+
+ device_.memset(input_data, 123, BufferSize(input_dims));
+ device_.memset(output_backward_data, 123, BufferSize(output_dims));
+
+ Input input(input_data, input_dims);
+    OutputBackward output_backward(output_backward_data, output_dims);
+ FilterBackward filter_backward(filter_backward_data, filter_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ filter_backward.device(device_) = Eigen::SpatialConvolutionBackwardKernel(
+ input, output_backward, filter_rows, filter_cols);
+ tensorflow::testing::DoNotOptimize(filter_backward);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(input_data);
+ device_.deallocate(output_backward_data);
+ device_.deallocate(filter_backward_data);
+ }
+
+ private:
+ int iters_;
+ Device& device_;
+};
+
+template <typename Scalar, typename Device>
+class CuboidConvolutionBenchmarksSuite {
+ public:
+ using Input = TTypes<float, 5>::ConstTensor;
+ using Filter = TTypes<float, 5>::ConstTensor;
+ using Output = TTypes<float, 5>::Tensor;
+
+ using Dimensions = Eigen::DSizes<Eigen::Index, 5>;
+
+ CuboidConvolutionBenchmarksSuite(int iters, Device& device)
+ : iters_(iters), device_(device) {}
+
+ Eigen::Index BufferSize(const Dimensions& dims) {
+ return dims.TotalSize() * sizeof(Scalar);
+ }
+
+ void CuboidConvolution(Dimensions input_dims, Dimensions filter_dims) {
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ input_dims[3], // input_planes
+ filter_dims[4]); // filter_count
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ Scalar* filter_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+ Scalar* output_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+
+ device_.memset(input_data, 123, BufferSize(input_dims));
+ device_.memset(filter_data, 123, BufferSize(filter_dims));
+
+ Input input(input_data, input_dims);
+ Filter filter(filter_data, filter_dims);
+ Output output(output_data, output_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ output.device(device_) = Eigen::CuboidConvolution(input, filter);
+ tensorflow::testing::DoNotOptimize(output);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(input_data);
+ device_.deallocate(filter_data);
+ device_.deallocate(output_data);
+ }
+
+ void CuboidConvolutionBackwardInput(Dimensions input_dims,
+ Dimensions filter_dims) {
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ input_dims[3], // input_planes
+ filter_dims[4]); // filter_count
+
+ using OutputBackward = TTypes<float, 5>::ConstTensor;
+ using InputBackward = TTypes<float, 5>::Tensor;
+
+ // Assuming that the convolution had SAME padding.
+ Eigen::Index input_rows = input_dims[1];
+ Eigen::Index input_cols = input_dims[2];
+ Eigen::Index input_planes = input_dims[3];
+
+ Scalar* filter_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+ Scalar* output_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+ Scalar* input_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+
+ device_.memset(filter_data, 123, BufferSize(filter_dims));
+ device_.memset(output_backward_data, 123, BufferSize(output_dims));
+
+ Filter filter(filter_data, filter_dims);
+ OutputBackward output_backward(output_backward_data, output_dims);
+ InputBackward input_backward(input_backward_data, input_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ input_backward.device(device_) = Eigen::CuboidConvolutionBackwardInput(
+ filter, output_backward, input_planes, input_rows, input_cols);
+ tensorflow::testing::DoNotOptimize(input_backward);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(filter_data);
+ device_.deallocate(output_backward_data);
+ device_.deallocate(input_backward_data);
+ }
+
+ void CuboidConvolutionBackwardKernel(Dimensions input_dims,
+ Dimensions filter_dims) {
+ using OutputBackward = TTypes<float, 5>::ConstTensor;
+ using FilterBackward = TTypes<float, 5>::Tensor;
+
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ input_dims[3], // input_planes
+ filter_dims[4]); // filter_count
+
+ // Assuming that the convolution had SAME padding.
+ Eigen::Index filter_rows = filter_dims[0];
+ Eigen::Index filter_cols = filter_dims[1];
+ Eigen::Index filter_planes = filter_dims[2];
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ Scalar* output_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+ Scalar* filter_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+
+ device_.memset(input_data, 123, BufferSize(input_dims));
+ device_.memset(output_backward_data, 123, BufferSize(output_dims));
+
+ Input input(input_data, input_dims);
+ OutputBackward output_backward(output_backward_data, output_dims);
+ FilterBackward filter_backward(filter_backward_data, filter_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ filter_backward.device(device_) = Eigen::CuboidConvolutionBackwardKernel(
+ input, output_backward, filter_planes, filter_rows, filter_cols);
+ tensorflow::testing::DoNotOptimize(filter_backward);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(input_data);
+ device_.deallocate(output_backward_data);
+ device_.deallocate(filter_backward_data);
+ }
+
+ private:
+ int iters_;
+ Device& device_;
+};
+
+#endif // TENSORFLOW_CORE_KERNELS_EIGEN_BENCHMARK_H_
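
A minimal usage sketch for the suites above (thread count and shapes are illustrative; iters would normally come from the TensorFlow benchmark runner):

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/eigen_benchmark.h"

static void RunSpatialForward(int iters) {
  Eigen::ThreadPool tp(4);
  Eigen::ThreadPoolDevice device(&tp, 4);

  SpatialConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice> suite(
      iters, device);

  using Dims = Eigen::DSizes<Eigen::Index, 4>;
  suite.SpatialConvolution(Dims(32, 56, 56, 64),  // batch, height, width, depth
                           Dims(3, 3, 64, 192));  // height, width, depth, count
}
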
diff --git a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
new file mode 100644
index 0000000000..ec949ddc84
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
@@ -0,0 +1,422 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#define EIGEN_USE_CUSTOM_THREAD_POOL
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/eigen_benchmark.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+#define CREATE_THREAD_POOL(threads) \
+ Eigen::ThreadPool tp(threads); \
+ Eigen::ThreadPoolDevice device(&tp, threads)
+
+// -------------------------------------------------------------------------- //
+// Spatial Convolutions //
+// -------------------------------------------------------------------------- //
+
+void SpatialConvolution(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height, int input_width,
+ int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height, int filter_width) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ SpatialConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(input_batches, input_height,
+ input_width, input_depth);
+ typename Benchmark::Dimensions filter_dims(filter_height, filter_width,
+ input_depth, filter_count);
+
+ benchmark.SpatialConvolution(input_dims, filter_dims);
+
+ auto num_computed_elements =
+ (input_dims.TotalSize() / input_depth) * filter_count;
+ auto flops =
+ num_computed_elements * (input_depth * filter_height * filter_width);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+void SpatialConvolutionBackwardInput(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height,
+ int input_width, int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height,
+ int filter_width) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ SpatialConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(input_batches, input_height,
+ input_width, input_depth);
+ typename Benchmark::Dimensions filter_dims(filter_height, filter_width,
+ input_depth, filter_count);
+
+ benchmark.SpatialConvolutionBackwardInput(input_dims, filter_dims);
+
+ auto num_computed_elements = input_dims.TotalSize();
+ auto flops =
+ num_computed_elements * (input_depth * filter_height * filter_width);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+void SpatialConvolutionBackwardKernel(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height,
+ int input_width, int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height,
+ int filter_width) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ SpatialConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(input_batches, input_height,
+ input_width, input_depth);
+ typename Benchmark::Dimensions filter_dims(filter_height, filter_width,
+ input_depth, filter_count);
+
+ benchmark.SpatialConvolutionBackwardKernel(input_dims, filter_dims);
+
+ auto num_computed_elements = filter_dims.TotalSize();
+ auto flops =
+ num_computed_elements * (input_batches * input_height * input_width);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+// Macro argument names: ---------------------------------------------------- //
+// NT: num threads
+// N: batch size
+// H: height
+// W: width
+// C: channels
+// FC: filter count
+// FH: filter height
+// FW: filter width
+
+#define BM_SPATIAL_NAME(prefix, NT, N, H, W, C, FC, FH, FW) \
+ BM_##prefix##_CPU_##NT##T_in_##N##_##H##_##W##_##C##_f_##FC##_##FH##_##FW
+
+#define BM_SpatialConvolution(NT, N, H, W, C, FC, FH, FW, LABEL) \
+ static void BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, \
+ FW)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
+ SpatialConvolution(iters, NT, N, H, W, C, FC, FH, FW); \
+ } \
+ BENCHMARK(BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, FW))
+
+#define BM_SpatialConvolutionBwdInput(NT, N, H, W, C, FC, FH, FW, LABEL) \
+ static void BM_SPATIAL_NAME(SpatialConvolutionBwdInput, NT, N, H, W, C, FC, \
+ FH, FW)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
+ SpatialConvolutionBackwardInput(iters, NT, N, H, W, C, FC, FH, FW); \
+ } \
+ BENCHMARK( \
+ BM_SPATIAL_NAME(SpatialConvolutionBwdInput, NT, N, H, W, C, FC, FH, FW))
+
+#define BM_SpatialConvolutionBwdKernel(NT, N, H, W, C, FC, FH, FW, LABEL) \
+ static void BM_SPATIAL_NAME(SpatialConvolutionBwdKernel, NT, N, H, W, C, FC, \
+ FH, FW)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
+ SpatialConvolutionBackwardKernel(iters, NT, N, H, W, C, FC, FH, FW); \
+ } \
+ BENCHMARK(BM_SPATIAL_NAME(SpatialConvolutionBwdKernel, NT, N, H, W, C, FC, \
+ FH, FW))
+
+#define BM_SpatialConvolutions(N, H, W, C, FC, FH, FW, LABEL) \
+ BM_SpatialConvolution(2, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolution(4, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolution(8, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolution(16, N, H, W, C, FC, FH, FW, LABEL);
+
+#define BM_SpatialConvolutionsBwdInput(N, H, W, C, FC, FH, FW, LABEL) \
+ BM_SpatialConvolutionBwdInput(2, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdInput(4, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdInput(8, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdInput(16, N, H, W, C, FC, FH, FW, LABEL);
+
+#define BM_SpatialConvolutionsBwdKernel(N, H, W, C, FC, FH, FW, LABEL) \
+ BM_SpatialConvolutionBwdKernel(2, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdKernel(4, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdKernel(8, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdKernel(16, N, H, W, C, FC, FH, FW, LABEL);
+
+// ImageNet Forward Convolutions -------------------------------------------- //
+
+BM_SpatialConvolutions(32, // batch size
+ 56, 56, 64, // input: height, width, depth
+ 192, 3, 3, // filter: count, height, width
+ "conv2_00");
+
+BM_SpatialConvolutions(32, 28, 28, 96, 128, 3, 3, "conv3a_00_3x3");
+BM_SpatialConvolutions(32, 28, 28, 16, 32, 5, 5, "conv3a_00_5x5");
+BM_SpatialConvolutions(32, 28, 28, 128, 192, 3, 3, "conv3_00_3x3");
+BM_SpatialConvolutions(32, 28, 28, 32, 96, 5, 5, "conv3_00_5x5");
+BM_SpatialConvolutions(32, 14, 14, 96, 204, 3, 3, "conv4a_00_3x3");
+BM_SpatialConvolutions(32, 14, 14, 16, 48, 5, 5, "conv4a_00_5x5");
+BM_SpatialConvolutions(32, 14, 14, 112, 224, 3, 3, "conv4b_00_3x3");
+BM_SpatialConvolutions(32, 14, 14, 24, 64, 5, 5,
+ "conv4b_00_5x5 / conv4c_00_5x5");
+BM_SpatialConvolutions(32, 14, 14, 128, 256, 3, 3, "conv4c_00_3x3");
+BM_SpatialConvolutions(32, 14, 14, 144, 288, 3, 3, "conv4d_00_3x3");
+BM_SpatialConvolutions(32, 14, 14, 32, 64, 5, 5, "conv4d_00_5x5");
+BM_SpatialConvolutions(32, 14, 14, 160, 320, 3, 3, "conv4_00_3x3");
+BM_SpatialConvolutions(32, 14, 14, 32, 128, 5, 5, "conv4_00_5x5");
+BM_SpatialConvolutions(32, 7, 7, 160, 320, 3, 3, "conv5a_00_3x3");
+BM_SpatialConvolutions(32, 7, 7, 48, 128, 5, 5, "conv5a_00_5x5 / conv5_00_5x5");
+BM_SpatialConvolutions(32, 7, 7, 192, 384, 3, 3, "conv5_00_3x3");
+
+// Benchmarks from https://github.com/soumith/convnet-benchmarks
+BM_SpatialConvolutions(128, 128, 128, 3, 96, 11, 11, "convnet-layer1");
+BM_SpatialConvolutions(128, 64, 64, 64, 128, 9, 9, "convnet-layer2");
+BM_SpatialConvolutions(128, 32, 32, 128, 128, 9, 9, "convnet-layer3");
+BM_SpatialConvolutions(128, 16, 16, 128, 128, 7, 7, "convnet-layer4");
+BM_SpatialConvolutions(128, 13, 13, 384, 384, 3, 3, "convnet-layer5");
+
+// ImageNet BackwardInput Convolutions -------------------------------------- //
+
+BM_SpatialConvolutionsBwdInput(32, 56, 56, 64, 192, 3, 3, "conv2_00");
+BM_SpatialConvolutionsBwdInput(32, 28, 28, 96, 128, 3, 3, "conv3a_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 28, 28, 16, 32, 5, 5, "conv3a_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 28, 28, 128, 192, 3, 3, "conv3_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 28, 28, 32, 96, 5, 5, "conv3_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 96, 204, 3, 3, "conv4a_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 16, 48, 5, 5, "conv4a_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 112, 224, 3, 3, "conv4b_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 24, 64, 5, 5,
+ "conv4b_00_5x5 / conv4c_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 128, 256, 3, 3, "conv4c_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 144, 288, 3, 3, "conv4d_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 32, 64, 5, 5, "conv4d_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 160, 320, 3, 3, "conv4_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 32, 128, 5, 5, "conv4_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 7, 7, 160, 320, 3, 3, "conv5a_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 7, 7, 48, 128, 5, 5,
+ "conv5a_00_5x5 / conv5_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 7, 7, 192, 384, 3, 3, "conv5_00_3x3");
+
+// ImageNet BackwardKernel Convolutions ------------------------------------- //
+
+BM_SpatialConvolutionsBwdKernel(32, 56, 56, 64, 192, 3, 3, "conv2_00");
+BM_SpatialConvolutionsBwdKernel(32, 28, 28, 96, 128, 3, 3, "conv3a_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 28, 28, 16, 32, 5, 5, "conv3a_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 28, 28, 128, 192, 3, 3, "conv3_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 28, 28, 32, 96, 5, 5, "conv3_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 96, 204, 3, 3, "conv4a_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 16, 48, 5, 5, "conv4a_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 112, 224, 3, 3, "conv4b_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 24, 64, 5, 5,
+ "conv4b_00_5x5 / conv4c_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 128, 256, 3, 3, "conv4c_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 144, 288, 3, 3, "conv4d_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 32, 64, 5, 5, "conv4d_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 160, 320, 3, 3, "conv4_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 32, 128, 5, 5, "conv4_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 7, 7, 160, 320, 3, 3, "conv5a_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 7, 7, 48, 128, 5, 5,
+ "conv5a_00_5x5 / conv5_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 7, 7, 192, 384, 3, 3, "conv5_00_3x3");
+
+// -------------------------------------------------------------------------- //
+// Cuboid Convolutions //
+// -------------------------------------------------------------------------- //
+
+void CuboidConvolution(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height, int input_width,
+ int input_planes, int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height, int filter_width,
+ int filter_planes) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ CuboidConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(
+ input_batches, input_height, input_width, input_planes, input_depth);
+ typename Benchmark::Dimensions filter_dims(
+ filter_height, filter_width, filter_planes, input_depth, filter_count);
+
+ benchmark.CuboidConvolution(input_dims, filter_dims);
+
+ auto num_computed_elements =
+ (input_dims.TotalSize() / input_depth) * filter_count;
+ auto flops = num_computed_elements *
+ (input_depth * filter_height * filter_width * filter_planes);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+void CuboidConvolutionBackwardInput(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height,
+ int input_width, int input_planes,
+ int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height,
+ int filter_width, int filter_planes) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ CuboidConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(
+ input_batches, input_height, input_width, input_planes, input_depth);
+ typename Benchmark::Dimensions filter_dims(
+ filter_height, filter_width, filter_planes, input_depth, filter_count);
+
+ benchmark.CuboidConvolutionBackwardInput(input_dims, filter_dims);
+
+ auto num_computed_elements = input_dims.TotalSize();
+ auto flops = num_computed_elements *
+ (input_depth * filter_height * filter_width * filter_planes);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+void CuboidConvolutionBackwardKernel(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height,
+ int input_width, int input_planes,
+ int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height,
+ int filter_width, int filter_planes) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ CuboidConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(
+ input_batches, input_height, input_width, input_planes, input_depth);
+ typename Benchmark::Dimensions filter_dims(
+ filter_height, filter_width, filter_planes, input_depth, filter_count);
+
+ benchmark.CuboidConvolutionBackwardKernel(input_dims, filter_dims);
+
+ auto num_computed_elements = filter_dims.TotalSize();
+ auto flops = num_computed_elements *
+ (input_batches * input_height * input_width * input_planes);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+// Macro argument names: ---------------------------------------------------- //
+// NT: num threads
+// N: batch size
+// H: height
+// W: width
+// P: planes
+// C: channels
+// FC: filter count
+// FH: filter height
+// FW: filter width
+// FP: filter planes
+
+#define BM_CONCAT(a, b) a##b
+
+#define BM_CUBOID_NAME(p, NT, N, H, W, P, C, FC, FH, FW, FP) \
+ BM_CONCAT(BM_##p##_CPU_##NT##T_in_##N##_##H##_##W##_##P##_##C, \
+ _f_##FC##_##FH##_##FW##_##FP)
+
+#define BM_CuboidConvolution(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ static void BM_CUBOID_NAME(CuboidConvolution, NT, N, H, W, P, C, FC, FH, FW, \
+ FP)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
+ CuboidConvolution(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
+ } \
+ BENCHMARK( \
+ BM_CUBOID_NAME(CuboidConvolution, NT, N, H, W, P, C, FC, FH, FW, FP))
+
+#define BM_CuboidConvolutionBwdInput(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ static void BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \
+ FH, FW, FP)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
+ CuboidConvolutionBackwardInput(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
+ } \
+ BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \
+ FH, FW, FP))
+
+#define BM_CuboidConvolutionBwdKernel(NT, N, H, W, P, C, FC, FH, FW, FP, \
+ LABEL) \
+ static void BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C, \
+ FC, FH, FW, FP)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
+ CuboidConvolutionBackwardKernel(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
+ } \
+ BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C, FC, \
+ FH, FW, FP))
+
+#define BM_CuboidConvolutions(N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ BM_CuboidConvolution(2, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolution(4, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolution(8, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolution(16, N, H, W, P, C, FC, FH, FW, FP, LABEL);
+
+#define BM_CuboidConvolutionsBwdInput(N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ BM_CuboidConvolutionBwdInput(2, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdInput(4, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdInput(8, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdInput(16, N, H, W, P, C, FC, FH, FW, FP, LABEL);
+
+#define BM_CuboidConvolutionsBwdKernel(N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ BM_CuboidConvolutionBwdKernel(2, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdKernel(4, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdKernel(8, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdKernel(16, N, H, W, P, C, FC, FH, FW, FP, LABEL);
+
+// Random Cuboid Convolutions ----------------------------------------------- //
+// TODO(ezhulenev): find representative dims for cuboid convolutions (find
+// models using Conv3D ops).
+
+BM_CuboidConvolutions(8, // batch size
+                      25, 25, 25, 4,  // input: height, width, planes, depth
+                      16, 5, 5, 5,    // filter: count, height, width, planes
+ "conv3d_depth4");
+BM_CuboidConvolutions(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8");
+BM_CuboidConvolutions(2, 9, 31, 31, 64, 64, 5, 5, 5, "b2_conv3d_1");
+BM_CuboidConvolutions(2, 5, 27, 27, 64, 64, 5, 5, 5, "b2_conv3d_2");
+
+BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d_depth4");
+BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8");
+BM_CuboidConvolutionsBwdInput(2, 9, 31, 31, 64, 64, 5, 5, 5, "b2_conv3d_1");
+BM_CuboidConvolutionsBwdInput(2, 5, 27, 27, 64, 64, 5, 5, 5, "b2_conv3d_2");
+
+BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d_depth4");
+BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8");
+BM_CuboidConvolutionsBwdKernel(2, 9, 31, 31, 64, 64, 5, 5, 5, "b2_conv3d_1");
+BM_CuboidConvolutionsBwdKernel(2, 5, 27, 27, 64, 64, 5, 5, 5, "b2_conv3d_2");
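
The benchmark naming scheme is easiest to see on a concrete expansion. The first spatial instantiation above, BM_SpatialConvolution(2, 32, 56, 56, 64, 192, 3, 3, "conv2_00"), expands (whitespace aside) to:

static void BM_SpatialConvolution_CPU_2T_in_32_56_56_64_f_192_3_3(int iters) {
  ::tensorflow::testing::SetLabel("conv2_00");
  SpatialConvolution(iters, 2, 32, 56, 56, 64, 192, 3, 3);
}
BENCHMARK(BM_SpatialConvolution_CPU_2T_in_32_56_56_64_f_192_3_3);
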
diff --git a/tensorflow/core/kernels/example_parsing_ops.cc b/tensorflow/core/kernels/example_parsing_ops.cc
index 83cd0e9b47..528b3c6bf0 100644
--- a/tensorflow/core/kernels/example_parsing_ops.cc
+++ b/tensorflow/core/kernels/example_parsing_ops.cc
@@ -264,9 +264,168 @@ class ParseSingleExampleOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("ParseSingleExample").Device(DEVICE_CPU),
ParseSingleExampleOp);
-class SingleSequenceExampleParserOp : public OpKernel {
+class ParseSequenceExampleOp : public OpKernel {
public:
- explicit SingleSequenceExampleParserOp(OpKernelConstruction* ctx)
+ explicit ParseSequenceExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, attrs_.Init(ctx));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* debug_name;
+ const Tensor* serialized;
+ OpInputList context_dense_defaults;
+
+ OP_REQUIRES_OK(ctx, ctx->input("debug_name", &debug_name));
+ OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized));
+ OP_REQUIRES_OK(ctx, ctx->input_list("context_dense_defaults",
+ &context_dense_defaults));
+
+ bool has_debug_name = (debug_name->NumElements() > 0);
+ if (has_debug_name) {
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(debug_name->shape()),
+ errors::InvalidArgument(
+ "Expected debug_name to be a vector, got shape: ",
+ debug_name->shape().DebugString()));
+ }
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(serialized->shape()),
+ errors::InvalidArgument(
+ "Expected serialized to be a vector, got shape: ",
+ serialized->shape().DebugString()));
+
+ OP_REQUIRES(ctx, context_dense_defaults.size() == attrs_.num_context_dense,
+ errors::InvalidArgument("Expected len(context_dense_defaults) "
+ "== len(context_dense_keys) but got: ",
+ context_dense_defaults.size(), " vs. ",
+ attrs_.num_context_dense));
+
+ std::vector<bool> required(attrs_.num_context_dense);
+ for (int d = 0; d < attrs_.num_context_dense; ++d) {
+ const Tensor& def_value = context_dense_defaults[d];
+ required[d] = (def_value.NumElements() == 0); // No default provided.
+
+ if (def_value.NumElements() > 0) {
+ OP_REQUIRES(ctx, def_value.shape() == attrs_.context_dense_shapes[d],
+ errors::InvalidArgument(
+ "default_value[", d,
+ "].shape() == ", def_value.shape().DebugString(),
+ " != context_dense_shapes[", d,
+ "] == ", attrs_.context_dense_shapes[d].DebugString()));
+ OP_REQUIRES(
+ ctx, def_value.dtype() == attrs_.context_dense_types[d],
+ errors::InvalidArgument(
+ "context_dense_defaults[", d, "].dtype() == ",
+ DataTypeString(def_value.dtype()), " != context_dense_types[",
+ d, "] == ", DataTypeString(attrs_.context_dense_types[d])));
+ }
+ }
+
+ example::Result context_result, feature_list_result;
+ std::vector<Tensor> dense_feature_lengths;
+
+ example::FastParseExampleConfig context_config;
+ for (int d = 0; d < attrs_.num_context_dense; ++d) {
+ context_config.dense.push_back(
+ {attrs_.context_dense_keys[d], attrs_.context_dense_types[d],
+ attrs_.context_dense_shapes[d], context_dense_defaults[d],
+ false /* attrs_.context_variable_length[d] */,
+ 0 /*attrs_.context_elements_per_stride[d] */});
+ }
+ for (int d = 0; d < attrs_.num_context_sparse; ++d) {
+ context_config.sparse.push_back(
+ {attrs_.context_sparse_keys[d], attrs_.context_sparse_types[d]});
+ }
+ example::FastParseExampleConfig feature_list_config;
+ for (int d = 0; d < attrs_.num_feature_list_dense; ++d) {
+ DataType dtype = attrs_.feature_list_dense_types[d];
+ Tensor default_value = Tensor(dtype, TensorShape({}));
+ feature_list_config.dense.push_back(
+ {attrs_.feature_list_dense_keys[d], dtype,
+ attrs_.feature_list_dense_shapes[d], default_value,
+ (attrs_.feature_list_dense_missing_assumed_empty.count(
+ attrs_.feature_list_dense_keys[d]) > 0),
+ 0 /*attrs_.context_elements_per_stride[d] */});
+ }
+ for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) {
+ feature_list_config.sparse.push_back(
+ {attrs_.feature_list_sparse_keys[d],
+ attrs_.feature_list_sparse_types[d]});
+ }
+
+ auto serialized_t = serialized->flat<string>();
+ auto debug_name_t = debug_name->flat<string>();
+ gtl::ArraySlice<string> slice(serialized_t.data(), serialized_t.size());
+ gtl::ArraySlice<string> names_slice(debug_name_t.data(),
+ debug_name_t.size());
+
+ OP_REQUIRES_OK(
+ ctx,
+ FastParseSequenceExample(
+ context_config, feature_list_config, slice, names_slice,
+ ctx->device()->tensorflow_cpu_worker_threads()->workers,
+ &context_result, &feature_list_result, &dense_feature_lengths));
+
+ OpOutputList context_sparse_indices;
+ OpOutputList context_sparse_values;
+ OpOutputList context_sparse_shapes;
+ OpOutputList context_dense_values;
+ OpOutputList feature_list_sparse_indices;
+ OpOutputList feature_list_sparse_values;
+ OpOutputList feature_list_sparse_shapes;
+ OpOutputList feature_list_dense_values;
+ OpOutputList feature_list_dense_lengths;
+
+ OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices",
+ &context_sparse_indices));
+ OP_REQUIRES_OK(
+ ctx, ctx->output_list("context_sparse_values", &context_sparse_values));
+ OP_REQUIRES_OK(
+ ctx, ctx->output_list("context_sparse_shapes", &context_sparse_shapes));
+ OP_REQUIRES_OK(
+ ctx, ctx->output_list("context_dense_values", &context_dense_values));
+ OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_indices",
+ &feature_list_sparse_indices));
+ OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_values",
+ &feature_list_sparse_values));
+ OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_shapes",
+ &feature_list_sparse_shapes));
+ OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_dense_values",
+ &feature_list_dense_values));
+ OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_dense_lengths",
+ &feature_list_dense_lengths));
+ for (int d = 0; d < attrs_.num_context_dense; ++d) {
+ context_dense_values.set(d, context_result.dense_values[d]);
+ }
+ TensorShape lengths_shape;
+ lengths_shape.AddDim(serialized_t.size());
+ for (int d = 0; d < attrs_.num_feature_list_dense; ++d) {
+ feature_list_dense_values.set(d, feature_list_result.dense_values[d]);
+ feature_list_dense_lengths.set(d, dense_feature_lengths[d]);
+ }
+ for (int d = 0; d < attrs_.num_context_sparse; ++d) {
+ context_sparse_indices.set(d, context_result.sparse_indices[d]);
+ context_sparse_values.set(d, context_result.sparse_values[d]);
+ context_sparse_shapes.set(d, context_result.sparse_shapes[d]);
+ }
+ for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) {
+ feature_list_sparse_indices.set(d, feature_list_result.sparse_indices[d]);
+ feature_list_sparse_values.set(d, feature_list_result.sparse_values[d]);
+ feature_list_sparse_shapes.set(d, feature_list_result.sparse_shapes[d]);
+ }
+ }
+
+ protected:
+ ParseSequenceExampleAttrs attrs_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ParseSequenceExample").Device(DEVICE_CPU),
+ ParseSequenceExampleOp);
+
+class ParseSingleSequenceExampleOp : public OpKernel {
+ public:
+ explicit ParseSingleSequenceExampleOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {
OP_REQUIRES_OK(ctx, attrs_.Init(ctx));
}
@@ -658,7 +817,7 @@ class SingleSequenceExampleParserOp : public OpKernel {
};
REGISTER_KERNEL_BUILDER(Name("ParseSingleSequenceExample").Device(DEVICE_CPU),
- SingleSequenceExampleParserOp);
+ ParseSingleSequenceExampleOp);
#ifndef IS_MOBILE_PLATFORM
// when using lite protos on mobile, decoding JSON is not available.
diff --git a/tensorflow/core/kernels/extract_image_patches_op.h b/tensorflow/core/kernels/extract_image_patches_op.h
index e430a23d20..64b8c0338b 100644
--- a/tensorflow/core/kernels/extract_image_patches_op.h
+++ b/tensorflow/core/kernels/extract_image_patches_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_EXTRACT_IMAGE_PATCHES_OP_H_
-#define TENSORFLOW_KERNELS_EXTRACT_IMAGE_PATCHES_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_EXTRACT_IMAGE_PATCHES_OP_H_
+#define TENSORFLOW_CORE_KERNELS_EXTRACT_IMAGE_PATCHES_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -53,4 +53,4 @@ struct ExtractImagePatchesForward {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_EXTRACT_IMAGE_PATCHES_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_EXTRACT_IMAGE_PATCHES_OP_H_
diff --git a/tensorflow/core/kernels/fake_quant_ops_functor.h b/tensorflow/core/kernels/fake_quant_ops_functor.h
index d51acc38ef..045a96ac1e 100644
--- a/tensorflow/core/kernels/fake_quant_ops_functor.h
+++ b/tensorflow/core/kernels/fake_quant_ops_functor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_
-#define TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_FAKE_QUANT_OPS_FUNCTOR_H_
+#define TENSORFLOW_CORE_KERNELS_FAKE_QUANT_OPS_FUNCTOR_H_
#include <tuple>
@@ -277,4 +277,4 @@ struct FakeQuantWithMinMaxVarsPerChannelGradientFunctor {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_FAKE_QUANT_OPS_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/fill_functor.h b/tensorflow/core/kernels/fill_functor.h
index 4c8b3f01a7..46bffa5173 100644
--- a/tensorflow/core/kernels/fill_functor.h
+++ b/tensorflow/core/kernels/fill_functor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_FILL_FUNCTOR_H_
-#define TENSORFLOW_KERNELS_FILL_FUNCTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_FILL_FUNCTOR_H_
+#define TENSORFLOW_CORE_KERNELS_FILL_FUNCTOR_H_
#define EIGEN_USE_THREADS
@@ -89,4 +89,4 @@ struct SetOneFunctor<Eigen::ThreadPoolDevice, string> {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_FILL_FUNCTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_FILL_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/fractional_pool_common.h b/tensorflow/core/kernels/fractional_pool_common.h
index 2d7a230fc0..55a959f3c3 100644
--- a/tensorflow/core/kernels/fractional_pool_common.h
+++ b/tensorflow/core/kernels/fractional_pool_common.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_FRACTIONAL_POOL_COMMON_H_
-#define TENSORFLOW_KERNELS_FRACTIONAL_POOL_COMMON_H_
+#ifndef TENSORFLOW_CORE_KERNELS_FRACTIONAL_POOL_COMMON_H_
+#define TENSORFLOW_CORE_KERNELS_FRACTIONAL_POOL_COMMON_H_
#include <algorithm>
#include <vector>
@@ -75,4 +75,4 @@ std::vector<int64> GeneratePoolingSequence(int input_length, int output_length,
bool pseudo_random);
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_FRACTIONAL_POOL_COMMON_H_
+#endif // TENSORFLOW_CORE_KERNELS_FRACTIONAL_POOL_COMMON_H_
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.h b/tensorflow/core/kernels/fused_batch_norm_op.h
index d6c68df986..c45b6f79e3 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.h
+++ b/tensorflow/core/kernels/fused_batch_norm_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_FUSED_BATCH_NORM_OP_H_
-#define TENSORFLOW_KERNELS_FUSED_BATCH_NORM_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_FUSED_BATCH_NORM_OP_H_
+#define TENSORFLOW_CORE_KERNELS_FUSED_BATCH_NORM_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
@@ -128,4 +128,4 @@ struct FusedBatchNormFreezeGrad {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_FUSED_BATCH_NORM_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_FUSED_BATCH_NORM_OP_H_
diff --git a/tensorflow/core/kernels/gather_functor.h b/tensorflow/core/kernels/gather_functor.h
index 2c6e8bf3bc..cd2873bdca 100644
--- a/tensorflow/core/kernels/gather_functor.h
+++ b/tensorflow/core/kernels/gather_functor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_GATHER_FUNCTOR_H_
-#define TENSORFLOW_KERNELS_GATHER_FUNCTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_H_
+#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -176,4 +176,4 @@ struct GatherFunctor<GPUDevice, Variant, Index> {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_GATHER_FUNCTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/gather_nd_op.h b/tensorflow/core/kernels/gather_nd_op.h
index 60780fb50c..003badb74d 100644
--- a/tensorflow/core/kernels/gather_nd_op.h
+++ b/tensorflow/core/kernels/gather_nd_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_GATHER_ND_OP_H_
-#define TENSORFLOW_KERNELS_GATHER_ND_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_H_
+#define TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_H_
// Functor definition for GatherOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -47,4 +47,4 @@ Status DoGatherNd(OpKernelContext* c, const Tensor& params,
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_GATHER_ND_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_H_
diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
index dc028c2f1e..277ee2be02 100644
--- a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
+++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
-#define TENSORFLOW_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
+#ifndef TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
// Specialization of GatherNdSlice to CPU
@@ -113,10 +113,25 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
#endif
generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator(
slice_size, Tindices, Tparams, Tout, &error_loc);
+
+#ifdef INTEL_MKL
+// The Eigen implementation below is not highly performant: gather_nd_generator
+// does not seem to be called in parallel, leading to very poor performance.
+// Additionally, since it invokes 'generate' through a scalar (Tscratch), it
+// must go through redundant operations like 'reshape', 'broadcast' and 'sum'.
+// The OpenMP loop below does essentially the same thing as the Eigen code,
+// but is considerably more efficient.
+#pragma omp parallel for
+ for (Eigen::DenseIndex i = 0; i < batch_size; i++) {
+ const Eigen::array<Eigen::DenseIndex, 1> loc{i};
+ gather_nd_generator(loc);
+ }
+#else // INTEL_MKL
Tscratch.device(d) = Tscratch.reshape(reshape_dims)
.broadcast(broadcast_dims)
.generate(gather_nd_generator)
.sum();
+#endif
// error_loc() returns -1 if there's no out-of-bounds index,
// otherwise it returns the location of an OOB index in Tindices.
@@ -142,4 +157,4 @@ TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
+#endif // TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
diff --git a/tensorflow/core/kernels/gemm_functors.h b/tensorflow/core/kernels/gemm_functors.h
index 4b30c1f17f..1c80844085 100644
--- a/tensorflow/core/kernels/gemm_functors.h
+++ b/tensorflow/core/kernels/gemm_functors.h
@@ -24,6 +24,9 @@ limitations under the License.
#error "EIGEN_USE_THREADS must be enabled by all .cc files including this."
#endif // EIGEN_USE_THREADS
+#ifndef TENSORFLOW_CORE_KERNELS_GEMM_FUNCTORS_H_
+#define TENSORFLOW_CORE_KERNELS_GEMM_FUNCTORS_H_
+
#include <string.h>
#include <map>
#include <vector>
@@ -116,3 +119,5 @@ class FastGemmFunctor<float, float, float> {
}
};
#endif // USE_CBLAS_GEMM
+
+#endif // TENSORFLOW_CORE_KERNELS_GEMM_FUNCTORS_H_
diff --git a/tensorflow/core/kernels/gpu_utils.h b/tensorflow/core/kernels/gpu_utils.h
index c7dbefa0b4..86146f75f4 100644
--- a/tensorflow/core/kernels/gpu_utils.h
+++ b/tensorflow/core/kernels/gpu_utils.h
@@ -123,8 +123,7 @@ class AutoTuneMap {
string GetActionSummary(StringPiece action, const Parameters& params,
const Config& config) {
return strings::Printf("autotune_map %s %s: %s -> (%s)", name_.c_str(),
- std::string(action).c_str(),
- params.ToString().c_str(),
+ string(action).c_str(), params.ToString().c_str(),
config.ToString().c_str());
}
diff --git a/tensorflow/core/kernels/hexagon/graph_transfer_utils.h b/tensorflow/core/kernels/hexagon/graph_transfer_utils.h
index ada96ae4ea..d0d5c3e018 100644
--- a/tensorflow/core/kernels/hexagon/graph_transfer_utils.h
+++ b/tensorflow/core/kernels/hexagon/graph_transfer_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_HEXAGON_GRAPH_TRANSFER_UTILS_H_
-#define TENSORFLOW_PLATFORM_HEXAGON_GRAPH_TRANSFER_UTILS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFER_UTILS_H_
+#define TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFER_UTILS_H_
#include <queue>
#include <utility>
@@ -56,4 +56,4 @@ class GraphTransferUtils {
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_HEXAGON_GRAPH_TRANSFER_UTILS_H_
+#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFER_UTILS_H_
diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.cc b/tensorflow/core/kernels/hexagon/graph_transferer.cc
index e05de3fe8e..477e729dcb 100644
--- a/tensorflow/core/kernels/hexagon/graph_transferer.cc
+++ b/tensorflow/core/kernels/hexagon/graph_transferer.cc
@@ -161,7 +161,7 @@ Status GraphTransferer::LoadGraphFromProto(
for (const string& output_node_name : output_node_names) {
const TensorId tid = ParseTensorName(output_node_name);
- const string node_name = std::string(tid.first);
+ const string node_name(tid.first);
const int port = tid.second;
const int node_id = node_name_to_id_cache_map_.at(node_name);
const Node* node = node_name_cache_list_.at(node_id);
diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.h b/tensorflow/core/kernels/hexagon/graph_transferer.h
index 86c1c5625f..4328d51916 100644
--- a/tensorflow/core/kernels/hexagon/graph_transferer.h
+++ b/tensorflow/core/kernels/hexagon/graph_transferer.h
@@ -228,4 +228,4 @@ class GraphTransferer {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H
+#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_
diff --git a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc
index 1580b72605..cc469f6dba 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc
+++ b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc
@@ -168,7 +168,7 @@ bool HexagonControlWrapper::SetupGraph() {
new_output_node_info.set_output_count(0);
const TensorId tid = ParseTensorName(graph_output.name());
- const string node_name = std::string(tid.first);
+ const string node_name(tid.first);
const int port = tid.second;
// Register node input for the new output node
const GraphTransferNodeInfo* node_info =
diff --git a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h
index 132cfde2db..1b382996f8 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h
+++ b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_
-#define TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_
+#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_CONTROL_WRAPPER_H_
+#define TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_CONTROL_WRAPPER_H_
#include <unordered_map>
#include <vector>
@@ -88,4 +88,4 @@ class HexagonControlWrapper final : public IRemoteFusedGraphExecutor {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_
+#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_CONTROL_WRAPPER_H_
diff --git a/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h b/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h
index b9328c8e0e..270d697e96 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h
+++ b/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h
@@ -55,4 +55,4 @@ class HexagonOpsDefinitions final : public IRemoteFusedGraphOpsDefinitions {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_OPS_DEFINITIONS_H
+#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_OPS_DEFINITIONS_H_
diff --git a/tensorflow/core/kernels/hexagon/soc_interface.h b/tensorflow/core/kernels/hexagon/soc_interface.h
index 062103ed98..d1a41d47c8 100644
--- a/tensorflow/core/kernels/hexagon/soc_interface.h
+++ b/tensorflow/core/kernels/hexagon/soc_interface.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_HEXAGON_SOC_INTERFACE_H_
-#define TENSORFLOW_PLATFORM_HEXAGON_SOC_INTERFACE_H_
+#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_SOC_INTERFACE_H_
+#define TENSORFLOW_CORE_KERNELS_HEXAGON_SOC_INTERFACE_H_
// Declaration of APIs provided by hexagon shared library. This header is shared
// with both hexagon library built with qualcomm SDK and tensorflow.
@@ -111,4 +111,4 @@ void soc_interface_SetDebugFlag(uint64_t flag);
}
#endif // __cplusplus
-#endif // TENSORFLOW_PLATFORM_HEXAGON_SOC_INTERFACE_H_
+#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_SOC_INTERFACE_H_
diff --git a/tensorflow/core/kernels/hinge-loss.h b/tensorflow/core/kernels/hinge-loss.h
index d303e9c877..b12910d27d 100644
--- a/tensorflow/core/kernels/hinge-loss.h
+++ b/tensorflow/core/kernels/hinge-loss.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_HINGE_LOSS_H_
-#define TENSORFLOW_KERNELS_HINGE_LOSS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_HINGE_LOSS_H_
+#define TENSORFLOW_CORE_KERNELS_HINGE_LOSS_H_
#include <algorithm>
#include <limits>
@@ -123,4 +123,4 @@ class HingeLossUpdater : public DualLossUpdater {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_HINGE_LOSS_H_
+#endif // TENSORFLOW_CORE_KERNELS_HINGE_LOSS_H_
diff --git a/tensorflow/core/kernels/histogram_op.h b/tensorflow/core/kernels/histogram_op.h
index 1b253f7fed..b14fc2bee3 100644
--- a/tensorflow/core/kernels/histogram_op.h
+++ b/tensorflow/core/kernels/histogram_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_HISTOGRAM_OP_H_
-#define TENSORFLOW_HISTOGRAM_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_HISTOGRAM_OP_H_
+#define TENSORFLOW_CORE_KERNELS_HISTOGRAM_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -35,4 +35,4 @@ struct HistogramFixedWidthFunctor {
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_HISTOGRAM_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_HISTOGRAM_OP_H_
diff --git a/tensorflow/core/kernels/i_remote_fused_graph_executor.h b/tensorflow/core/kernels/i_remote_fused_graph_executor.h
index 6072412689..b2329f4b61 100644
--- a/tensorflow/core/kernels/i_remote_fused_graph_executor.h
+++ b/tensorflow/core/kernels/i_remote_fused_graph_executor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_KERNELS_I_REMOTE_GRAPH_EXECUTOR_H_
-#define TENSORFLOW_CORE_KERNELS_I_REMOTE_GRAPH_EXECUTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_EXECUTOR_H_
+#define TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_EXECUTOR_H_
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
@@ -74,4 +74,4 @@ class IRemoteFusedGraphExecutor {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_KERNELS_I_REMOTE_GRAPH_EXECUTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_EXECUTOR_H_
diff --git a/tensorflow/core/kernels/identity_n_op.h b/tensorflow/core/kernels/identity_n_op.h
index 490bbf456c..7339cbbe29 100644
--- a/tensorflow/core/kernels/identity_n_op.h
+++ b/tensorflow/core/kernels/identity_n_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_IDENTITY_N_OP_H_
-#define TENSORFLOW_KERNELS_IDENTITY_N_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_IDENTITY_N_OP_H_
+#define TENSORFLOW_CORE_KERNELS_IDENTITY_N_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
@@ -41,4 +41,4 @@ class IdentityNOp : public OpKernel {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_IDENTITY_N_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_IDENTITY_N_OP_H_
diff --git a/tensorflow/core/kernels/identity_op.h b/tensorflow/core/kernels/identity_op.h
index f8856a1b9b..6b74868ad4 100644
--- a/tensorflow/core/kernels/identity_op.h
+++ b/tensorflow/core/kernels/identity_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_IDENTITY_OP_H_
-#define TENSORFLOW_KERNELS_IDENTITY_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_IDENTITY_OP_H_
+#define TENSORFLOW_CORE_KERNELS_IDENTITY_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
@@ -37,4 +37,4 @@ class IdentityOp : public OpKernel {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_IDENTITY_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_IDENTITY_OP_H_
diff --git a/tensorflow/core/kernels/image_resizer_state.h b/tensorflow/core/kernels/image_resizer_state.h
index 8dcb5977c6..1d4fa1a7db 100644
--- a/tensorflow/core/kernels/image_resizer_state.h
+++ b/tensorflow/core/kernels/image_resizer_state.h
@@ -18,8 +18,8 @@ limitations under the License.
// reduce code duplication and ensure consistency across the different
// resizers, it performs the input validation.
-#ifndef TENSORFLOW_KERNELS_IMAGE_RESIZER_STATE_H_
-#define TENSORFLOW_KERNELS_IMAGE_RESIZER_STATE_H_
+#ifndef TENSORFLOW_CORE_KERNELS_IMAGE_RESIZER_STATE_H_
+#define TENSORFLOW_CORE_KERNELS_IMAGE_RESIZER_STATE_H_
#define EIGEN_USE_THREADS
@@ -191,4 +191,4 @@ struct ImageResizerGradientState {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_IMAGE_RESIZER_STATE_H_
+#endif // TENSORFLOW_CORE_KERNELS_IMAGE_RESIZER_STATE_H_
diff --git a/tensorflow/core/kernels/immutable_constant_op.h b/tensorflow/core/kernels/immutable_constant_op.h
index 795331b4b2..97af8c7dc5 100644
--- a/tensorflow/core/kernels/immutable_constant_op.h
+++ b/tensorflow/core/kernels/immutable_constant_op.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_IMMUTABLE_CONSTANT_OP_H_
-#define TENSORFLOW_KERNELS_IMMUTABLE_CONSTANT_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_IMMUTABLE_CONSTANT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_IMMUTABLE_CONSTANT_OP_H_
#include <memory>
@@ -46,4 +46,4 @@ class ImmutableConstantOp : public OpKernel {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_IMMUTABLE_CONSTANT_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_IMMUTABLE_CONSTANT_OP_H_
diff --git a/tensorflow/core/kernels/initializable_lookup_table.cc b/tensorflow/core/kernels/initializable_lookup_table.cc
index 06d53eba30..fcf468f5a8 100644
--- a/tensorflow/core/kernels/initializable_lookup_table.cc
+++ b/tensorflow/core/kernels/initializable_lookup_table.cc
@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/initializable_lookup_table.h"
-
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@@ -32,6 +31,13 @@ Status InitializableLookupTable::Find(OpKernelContext* ctx, const Tensor& keys,
return DoFind(keys, values, default_value);
}
+Status InitializableLookupTable::ImportValues(OpKernelContext* ctx,
+ const Tensor& keys,
+ const Tensor& values) {
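+ // The iterator yields both tensors in a single step, so Initialize()
+ // performs one batched insertion.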
+ lookup::KeyValueTensorIterator iter(&keys, &values);
+ return Initialize(iter);
+}
+
Status InitializableLookupTable::Initialize(InitTableIterator& iter) {
if (!iter.Valid()) {
return iter.status();
diff --git a/tensorflow/core/kernels/initializable_lookup_table.h b/tensorflow/core/kernels/initializable_lookup_table.h
index b4f81d9a70..424fe5df3c 100644
--- a/tensorflow/core/kernels/initializable_lookup_table.h
+++ b/tensorflow/core/kernels/initializable_lookup_table.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
-#define TENSORFLOW_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
+#ifndef TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
+#define TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
#include "tensorflow/core/framework/lookup_interface.h"
#include "tensorflow/core/platform/macros.h"
@@ -58,11 +58,7 @@ class InitializableLookupTable : public LookupInterface {
}
Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
- const Tensor& values) final {
- return errors::Unimplemented(
- "ImportValues not supported by InitializableLookupTable "
- "implementations");
- }
+ const Tensor& values) final;
TensorShape key_shape() const final { return TensorShape(); }
@@ -155,7 +151,58 @@ class InitializableLookupTable : public LookupInterface {
bool is_initialized_ = false;
};
+// Iterator to initialize tables given 'keys' and 'values' tensors.
+//
+// Both tensors are returned in the first iteration. The iterator does not
+// loop over individual elements of the tensors, since lookup-table
+// insertions can process whole batches.
+class KeyValueTensorIterator
+ : public InitializableLookupTable::InitTableIterator {
+ public:
+ // keys and values are not owned by the iterator.
+ explicit KeyValueTensorIterator(const Tensor* keys, const Tensor* values)
+ : keys_(keys), values_(values), valid_(true), status_(Status::OK()) {
+ TensorShape key_shape = keys_->shape();
+ if (!key_shape.IsSameSize(values_->shape())) {
+ valid_ = false;
+ status_ = errors::InvalidArgument(
+ "keys and values should have the same dimension.",
+ key_shape.DebugString(), " vs ", values_->shape().DebugString());
+ }
+ if (key_shape.num_elements() == 0) {
+ valid_ = false;
+ status_ =
+ errors::InvalidArgument("keys and values cannot be empty tensors.");
+ }
+ }
+
+ bool Valid() const override { return valid_; }
+
+ void Next() override {
+ valid_ = false;
+ status_ = errors::OutOfRange("No more data.");
+ }
+
+ const Tensor& keys() const override { return *keys_; }
+
+ const Tensor& values() const override { return *values_; }
+
+ Status status() const override { return status_; }
+
+ int64 total_size() const override {
+ return keys_ == nullptr ? -1 : keys_->NumElements();
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(KeyValueTensorIterator);
+
+ const Tensor* keys_; // Doesn't own it.
+ const Tensor* values_; // Doesn't own it.
+ bool valid_; // true if the iterator points to an existing range.
+ Status status_;
+};
+
} // namespace lookup
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
+#endif // TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
diff --git a/tensorflow/core/kernels/inplace_ops_functor.h b/tensorflow/core/kernels/inplace_ops_functor.h
index b806787e91..2023869f49 100644
--- a/tensorflow/core/kernels/inplace_ops_functor.h
+++ b/tensorflow/core/kernels/inplace_ops_functor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_INPLACE_FUNCTOR_H_
-#define TENSORFLOW_KERNELS_INPLACE_FUNCTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_INPLACE_OPS_FUNCTOR_H_
+#define TENSORFLOW_CORE_KERNELS_INPLACE_OPS_FUNCTOR_H_
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
@@ -46,4 +46,4 @@ Status DoCopy(const Device& device, const Tensor& x, Tensor* y);
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_INPLACE_FUNCTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_INPLACE_OPS_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/l2loss_op.h b/tensorflow/core/kernels/l2loss_op.h
index 4953aa237c..465ef96a51 100644
--- a/tensorflow/core/kernels/l2loss_op.h
+++ b/tensorflow/core/kernels/l2loss_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_L2LOSS_OP_H_
-#define TENSORFLOW_KERNELS_L2LOSS_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_L2LOSS_OP_H_
+#define TENSORFLOW_CORE_KERNELS_L2LOSS_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -30,4 +30,4 @@ struct L2LossOp : public OpKernel {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_L2LOSS_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_L2LOSS_OP_H_
diff --git a/tensorflow/core/kernels/linalg_ops_common.h b/tensorflow/core/kernels/linalg_ops_common.h
index f7c3f1950b..692f916439 100644
--- a/tensorflow/core/kernels/linalg_ops_common.h
+++ b/tensorflow/core/kernels/linalg_ops_common.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_
-#define TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_
+#ifndef TENSORFLOW_CORE_KERNELS_LINALG_OPS_COMMON_H_
+#define TENSORFLOW_CORE_KERNELS_LINALG_OPS_COMMON_H_
// Classes to support linear algebra functionality, similar to the numpy.linalg
// module. Supports batch computation on several matrices at once, sharding the
@@ -194,4 +194,4 @@ extern template class LinearAlgebraOp<complex128>;
#define REGISTER_LINALG_OP(OpName, OpClass, Scalar) \
REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar)
-#endif // TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_
+#endif // TENSORFLOW_CORE_KERNELS_LINALG_OPS_COMMON_H_
diff --git a/tensorflow/core/kernels/list_kernels.cc b/tensorflow/core/kernels/list_kernels.cc
index 84fa63fc00..bca1cff41c 100644
--- a/tensorflow/core/kernels/list_kernels.cc
+++ b/tensorflow/core/kernels/list_kernels.cc
@@ -588,7 +588,11 @@ REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU(bfloat16);
REGISTER_KERNEL_BUILDER(Name("TensorListStack") \
.TypeConstraint<T>("element_dtype") \
.Device(DEVICE_CPU), \
- TensorListStack<CPUDevice, T>)
+ TensorListStack<CPUDevice, T>) \
+ REGISTER_KERNEL_BUILDER(Name("TensorListGather") \
+ .TypeConstraint<T>("element_dtype") \
+ .Device(DEVICE_CPU), \
+ TensorListGather<CPUDevice, T>)
TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_STACK_CPU);
REGISTER_TENSOR_LIST_STACK_CPU(quint8);
@@ -604,7 +608,11 @@ REGISTER_TENSOR_LIST_STACK_CPU(bfloat16);
REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor") \
.TypeConstraint<T>("element_dtype") \
.Device(DEVICE_CPU), \
- TensorListFromTensor<CPUDevice, T>)
+ TensorListFromTensor<CPUDevice, T>) \
+ REGISTER_KERNEL_BUILDER(Name("TensorListScatter") \
+ .TypeConstraint<T>("element_dtype") \
+ .Device(DEVICE_CPU), \
+ TensorListScatter<CPUDevice, T>)
TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_FROM_TENSOR_CPU);
REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(quint8);
diff --git a/tensorflow/core/kernels/list_kernels.cu.cc b/tensorflow/core/kernels/list_kernels.cu.cc
index 0ea9362cbe..c591226b76 100644
--- a/tensorflow/core/kernels/list_kernels.cu.cc
+++ b/tensorflow/core/kernels/list_kernels.cu.cc
@@ -40,7 +40,12 @@ typedef Eigen::GpuDevice GPUDevice;
REGISTER_KERNEL_BUILDER(Name("TensorListStack") \
.TypeConstraint<T>("element_dtype") \
.Device(DEVICE_GPU), \
- TensorListStack<GPUDevice, T>)
+ TensorListStack<GPUDevice, T>) \
+ REGISTER_KERNEL_BUILDER(Name("TensorListGather") \
+ .TypeConstraint<T>("element_dtype") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("indices"), \
+ TensorListGather<GPUDevice, T>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_STACK_GPU);
REGISTER_TENSOR_LIST_STACK_GPU(bfloat16);
@@ -71,7 +76,13 @@ REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_GPU(bool);
.TypeConstraint<T>("element_dtype") \
.Device(DEVICE_GPU) \
.HostMemory("element_shape"), \
- TensorListFromTensor<GPUDevice, T>)
+ TensorListFromTensor<GPUDevice, T>) \
+ REGISTER_KERNEL_BUILDER(Name("TensorListScatter") \
+ .TypeConstraint<T>("element_dtype") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("element_shape") \
+ .HostMemory("indices"), \
+ TensorListScatter<GPUDevice, T>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_FROM_TENSOR_GPU);
REGISTER_TENSOR_LIST_FROM_TENSOR_GPU(bfloat16);
diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h
index 42871c6113..72581c9293 100644
--- a/tensorflow/core/kernels/list_kernels.h
+++ b/tensorflow/core/kernels/list_kernels.h
@@ -134,6 +134,74 @@ class TensorListStack : public OpKernel {
};
template <typename Device, typename T>
+class TensorListGather : public OpKernel {
+ public:
+ typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
+ ConstMatrixVector;
+ explicit TensorListGather(OpKernelConstruction* c) : OpKernel(c) {
+ OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
+ }
+
+ void Compute(OpKernelContext* c) override {
+ const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>();
+ OP_REQUIRES(c, l != nullptr,
+ errors::InvalidArgument(
+ "Input handle is not a list. Saw: '",
+ c->input(0).scalar<Variant>()().DebugString(), "'"));
+ OP_REQUIRES(c, element_dtype_ == l->element_dtype,
+ errors::InvalidArgument("Invalid data types; op elements ",
+ DataTypeString(element_dtype_),
+ " but list elements ",
+ DataTypeString(l->element_dtype)));
+ OP_REQUIRES(c, l->element_shape.IsFullyDefined(),
+ errors::InvalidArgument("Tried to stack elements from a list "
+ "with non-fully-defined shape: ",
+ l->element_shape.DebugString()));
+ Tensor indices = c->input(1);
+ TensorShape resulting_shape;
+ resulting_shape.AddDim(indices.NumElements());
+ for (TensorShapeDim s : l->element_shape) {
+ resulting_shape.AddDim(s.size);
+ }
+ Tensor* output;
+ OP_REQUIRES_OK(c, c->allocate_output(0, resulting_shape, &output));
+ if (output->NumElements() == 0) {
+ return;
+ }
+
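+ // Each gathered element is viewed as a 1 x NumElements matrix; these are
+ // concatenated side by side into the flat output below.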
+ ConstMatrixVector inputs_flat;
+ inputs_flat.reserve(l->tensors.size());
+ for (int index = 0; index < indices.NumElements(); ++index) {
+ const int i = indices.flat<int32>()(index);
+ OP_REQUIRES(
+ c, i < l->tensors.size(),
+ errors::InvalidArgument("Index ", i, " out o range; list only has ",
+ l->tensors.size(), " elements."));
+ const Tensor& t = l->tensors[i];
+ OP_REQUIRES(c, l->element_shape.IsCompatibleWith(t.shape()),
+ errors::InvalidArgument(
+ "Tensor with invalid shape in list. List element shape: ",
+ l->element_shape.DebugString(),
+ " and tensor shape: ", t.shape().DebugString()));
+ inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
+ t.shaped<T, 2>({1, t.NumElements()})));
+ }
+ auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
+
+#if GOOGLE_CUDA
+ if (std::is_same<Device, Eigen::GpuDevice>::value) {
+ ConcatGPU<T>(c, inputs_flat, output, &output_flat);
+ return;
+ }
+#endif // GOOGLE_CUDA
+ ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
+ }
+
+ private:
+ DataType element_dtype_;
+};
+
+template <typename Device, typename T>
class TensorListFromTensor : public OpKernel {
public:
TensorListFromTensor(OpKernelConstruction* c) : OpKernel(c) {}
@@ -178,6 +246,59 @@ class TensorListFromTensor : public OpKernel {
}
};
+template <typename Device, typename T>
+class TensorListScatter : public OpKernel {
+ public:
+ TensorListScatter(OpKernelConstruction* c) : OpKernel(c) {}
+
+ void Compute(OpKernelContext* c) override {
+ Tensor* output_tensor;
+ AllocatorAttributes attr;
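+ // Variant-typed outputs live in host memory, so the output scalar is
+ // allocated on the host.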
+ attr.set_on_host(true);
+ OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
+ Tensor indices = c->input(1);
+ PartialTensorShape element_shape;
+ OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(2), &element_shape));
+ TensorList output_list;
+ const Tensor& t = c->input(0);
+ output_list.element_dtype = t.dtype();
+ OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
+ errors::InvalidArgument(
+ "Tensor must be at least a vector, but saw shape: ",
+ t.shape().DebugString()));
+ TensorShape output_shape(t.shape());
+ output_shape.RemoveDim(0);
+ OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
+ errors::InvalidArgument(
+ "Specified a list with shape ", element_shape.DebugString(),
+ " from a tensor with shape ", output_shape.DebugString()));
+ output_list.element_shape = element_shape;
+ output_list.tensors.reserve(indices.NumElements());
+ for (int index = 0; index < indices.NumElements(); ++index) {
+ const int i = indices.flat<int32>()(index);
+ OP_REQUIRES(c, i < t.shape().dim_size(0),
+ errors::InvalidArgument("Trying to scatter index ", i,
+ " from tensor with ",
+ t.shape().dim_size(0), " rows."));
+ Tensor tmp = t.Slice(i, i + 1);
+ TensorShape tmp_shape = tmp.shape();
+ tmp_shape.RemoveDim(0);
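+ // CopyFrom onto itself shares the buffer and only drops the leading
+ // dimension of size 1; no data is moved here.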
+ OP_REQUIRES(c, tmp.CopyFrom(tmp, tmp_shape),
+ errors::Unknown("Unexpected shape error."));
+ // TODO(apassos) Maybe avoid the aligned copy when it is unnecessary; weird
+ // compiler bugs seem to prevent this.
+ Tensor aligned;
+ OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
+ // TODO(apassos) Do all slices in a single kernel invocation instead of
+ // many small ones.
+ aligned.flat<T>().device(c->eigen_device<Device>()) =
+ tmp.unaligned_flat<T>();
+ output_list.tensors.push_back(aligned);
+ }
+ output_tensor->scalar<Variant>()() = std::move(output_list);
+ }
+};
+
template <typename Device>
Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a,
const TensorList& b, TensorList* out) {
@@ -253,7 +374,12 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
y->tensors.reserve(x.tensors.size());
for (const Tensor& t : x.tensors) {
Tensor out_tensor;
- TF_RETURN_IF_ERROR(c->allocate_temp(t.dtype(), t.shape(), &out_tensor));
+ AllocatorAttributes attr;
+ if (t.dtype() == DT_VARIANT) {
+ attr.set_on_host(true);
+ }
+ TF_RETURN_IF_ERROR(
+ c->allocate_temp(t.dtype(), t.shape(), &out_tensor, attr));
switch (out_tensor.dtype()) {
#define DTYPE_CASE(dtype) \
case DataTypeToEnum<dtype>::value: \
@@ -261,14 +387,29 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
out_tensor.flat<dtype>().constant(dtype(0)); \
break;
- TF_CALL_NUMBER_TYPES(DTYPE_CASE)
+ TF_CALL_POD_TYPES(DTYPE_CASE)
#undef DTYPE_CASE
+
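+ // A DT_VARIANT element must itself hold a TensorList; nested lists are
+ // zeroed recursively.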
+ case DataTypeToEnum<Variant>::value: {
+ const TensorList* inner_x = t.scalar<Variant>()().get<TensorList>();
+ if (inner_x == nullptr) {
+ return errors::InvalidArgument("Input handle is not a list. Saw: '",
+ t.scalar<Variant>()().DebugString(),
+ "'");
+ }
+ TensorList inner_y;
+ TF_RETURN_IF_ERROR(TensorListZerosLike<Device>(c, *inner_x, &inner_y));
+ out_tensor.scalar<Variant>()() = std::move(inner_y);
+ break;
+ }
+
default:
return errors::InvalidArgument(
- "Trying to compute zeros_like for unsupported dtype",
- out_tensor.dtype());
+ "Trying to compute zeros_like for unsupported dtype ",
+ DataTypeString(out_tensor.dtype()));
}
+ y->tensors.emplace_back(out_tensor);
}
return Status::OK();
}
diff --git a/tensorflow/core/kernels/logistic-loss.h b/tensorflow/core/kernels/logistic-loss.h
index 6479e6f5dc..9198a98e47 100644
--- a/tensorflow/core/kernels/logistic-loss.h
+++ b/tensorflow/core/kernels/logistic-loss.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_LOGISTIC_LOSS_H_
-#define TENSORFLOW_KERNELS_LOGISTIC_LOSS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_LOGISTIC_LOSS_H_
+#define TENSORFLOW_CORE_KERNELS_LOGISTIC_LOSS_H_
#include <cmath>
@@ -86,7 +86,7 @@ class LogisticLossUpdater : public DualLossUpdater {
} else {
inverse_exp_term = 1 / (1 + exp(label * wx));
}
- return inverse_exp_term * label * example_weight;
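+ // d/dwx of log(1 + exp(-label * wx)) equals -label / (1 + exp(label * wx)),
+ // hence the leading minus sign on the returned derivative.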
+ return -inverse_exp_term * label * example_weight;
}
// The smoothness constant is 4 since the derivative of logistic loss, which
@@ -131,4 +131,4 @@ class LogisticLossUpdater : public DualLossUpdater {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_LOGISTIC_LOSS_H_
+#endif // TENSORFLOW_CORE_KERNELS_LOGISTIC_LOSS_H_
diff --git a/tensorflow/core/kernels/lookup_table_init_op.cc b/tensorflow/core/kernels/lookup_table_init_op.cc
index b352dd257c..6e77e1ee01 100644
--- a/tensorflow/core/kernels/lookup_table_init_op.cc
+++ b/tensorflow/core/kernels/lookup_table_init_op.cc
@@ -74,13 +74,11 @@ class InitializeTableOp : public OpKernel {
"Keys and values must have the same size ",
keys.NumElements(), " vs ", values.NumElements()));
- lookup::KeyValueTensorIterator iter(&keys, &values);
-
int memory_used_before = 0;
if (ctx->track_allocations()) {
memory_used_before = table->MemoryUsed();
}
- OP_REQUIRES_OK(ctx, table->Initialize(iter));
+ OP_REQUIRES_OK(ctx, table->ImportValues(ctx, keys, values));
if (ctx->track_allocations()) {
ctx->record_persistent_memory_allocation(table->MemoryUsed() -
memory_used_before);
diff --git a/tensorflow/core/kernels/lookup_table_init_op.h b/tensorflow/core/kernels/lookup_table_init_op.h
index 177a26daa8..101e528659 100644
--- a/tensorflow/core/kernels/lookup_table_init_op.h
+++ b/tensorflow/core/kernels/lookup_table_init_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_LOOKUP_TABLE_INIT_OP_H_
-#define TENSORFLOW_KERNELS_LOOKUP_TABLE_INIT_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_INIT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_INIT_OP_H_
#include "tensorflow/core/kernels/initializable_lookup_table.h"
@@ -30,4 +30,4 @@ Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
} // namespace lookup
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_LOOKUP_TABLE_INIT_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_INIT_OP_H_
diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc
index 2e8d9c623c..a495758861 100644
--- a/tensorflow/core/kernels/lookup_table_op.cc
+++ b/tensorflow/core/kernels/lookup_table_op.cc
@@ -50,7 +50,7 @@ class MutableHashTableOfScalars final : public LookupInterface {
MutableHashTableOfScalars(OpKernelContext* ctx, OpKernel* kernel) {}
size_t size() const override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return table_.size();
}
@@ -60,7 +60,7 @@ class MutableHashTableOfScalars final : public LookupInterface {
const auto key_values = key.flat<K>();
auto value_values = value->flat<V>();
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (int64 i = 0; i < key_values.size(); ++i) {
value_values(i) = gtl::FindWithDefault(
table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
@@ -95,7 +95,7 @@ class MutableHashTableOfScalars final : public LookupInterface {
}
Status ExportValues(OpKernelContext* ctx) override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
int64 size = table_.size();
Tensor* keys;
@@ -125,7 +125,7 @@ class MutableHashTableOfScalars final : public LookupInterface {
int64 MemoryUsed() const override {
int64 ret = 0;
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (unsigned i = 0; i < table_.bucket_count(); ++i) {
size_t bucket_size = table_.bucket_size(i);
if (bucket_size == 0) {
@@ -138,7 +138,6 @@ class MutableHashTableOfScalars final : public LookupInterface {
}
private:
- // TODO(andreasst): consider using a read/write lock or a concurrent map
mutable mutex mu_;
std::unordered_map<K, V> table_ GUARDED_BY(mu_);
};
@@ -158,7 +157,7 @@ class MutableHashTableOfTensors final : public LookupInterface {
}
size_t size() const override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return table_.size();
}
@@ -169,7 +168,7 @@ class MutableHashTableOfTensors final : public LookupInterface {
auto value_values = value->flat_inner_dims<V, 2>();
int64 value_dim = value_shape_.dim_size(0);
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (int64 i = 0; i < key_values.size(); ++i) {
ValueArray* value_vec =
gtl::FindOrNull(table_, SubtleMustCopyIfIntegral(key_values(i)));
@@ -219,7 +218,7 @@ class MutableHashTableOfTensors final : public LookupInterface {
}
Status ExportValues(OpKernelContext* ctx) override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
int64 size = table_.size();
int64 value_dim = value_shape_.dim_size(0);
@@ -254,7 +253,7 @@ class MutableHashTableOfTensors final : public LookupInterface {
int64 MemoryUsed() const override {
int64 ret = 0;
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (unsigned i = 0; i < table_.bucket_count(); ++i) {
size_t bucket_size = table_.bucket_size(i);
if (bucket_size == 0) {
@@ -268,7 +267,6 @@ class MutableHashTableOfTensors final : public LookupInterface {
private:
TensorShape value_shape_;
- // TODO(andreasst): consider using a read/write lock or a concurrent map
mutable mutex mu_;
typedef gtl::InlinedVector<V, 4> ValueArray;
std::unordered_map<K, ValueArray> table_ GUARDED_BY(mu_);
@@ -335,7 +333,7 @@ class MutableDenseHashTable final : public LookupInterface {
}
size_t size() const override LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return num_entries_;
}
@@ -355,7 +353,7 @@ class MutableDenseHashTable final : public LookupInterface {
auto value_matrix = value->shaped<V, 2>({num_elements, value_size});
const auto default_flat = default_value.flat<V>();
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
const auto key_buckets_matrix =
key_buckets_.AccessTensor(ctx)->template matrix<K>();
const auto value_buckets_matrix =
@@ -451,7 +449,7 @@ class MutableDenseHashTable final : public LookupInterface {
}
Status ExportValues(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
Tensor key_buckets_tensor = *key_buckets_.AccessTensor(ctx);
Tensor value_buckets_tensor = *value_buckets_.AccessTensor(ctx);
TF_RETURN_IF_ERROR(ctx->set_output("keys", key_buckets_tensor));
@@ -493,7 +491,7 @@ class MutableDenseHashTable final : public LookupInterface {
TensorShape value_shape() const override { return value_shape_; }
int64 MemoryUsed() const override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return sizeof(MutableDenseHashTable) + key_buckets_.AllocatedBytes() +
value_buckets_.AllocatedBytes() + empty_key_.AllocatedBytes();
}
diff --git a/tensorflow/core/kernels/lookup_table_op.h b/tensorflow/core/kernels/lookup_table_op.h
index 3977f16299..9451247f26 100644
--- a/tensorflow/core/kernels/lookup_table_op.h
+++ b/tensorflow/core/kernels/lookup_table_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_LOOKUP_TABLE_OP_H_
-#define TENSORFLOW_KERNELS_LOOKUP_TABLE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_
#include "tensorflow/core/framework/lookup_interface.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -102,9 +102,12 @@ class LookupTableOp : public OpKernel {
~LookupTableOp() override {
// If the table object was not shared, delete it.
if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
- TF_CHECK_OK(
- cinfo_.resource_manager()->template Delete<lookup::LookupInterface>(
- cinfo_.container(), cinfo_.name()));
+ if (!cinfo_.resource_manager()
+ ->template Delete<lookup::LookupInterface>(cinfo_.container(),
+ cinfo_.name())
+ .ok()) {
+ // Do nothing; the resource may have been deleted by a session reset.
+ }
}
}
@@ -272,4 +275,4 @@ class HashTable : public InitializableLookupTable {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_LOOKUP_TABLE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_
diff --git a/tensorflow/core/kernels/lookup_util.h b/tensorflow/core/kernels/lookup_util.h
index 894769960a..ec28cf9fa7 100644
--- a/tensorflow/core/kernels/lookup_util.h
+++ b/tensorflow/core/kernels/lookup_util.h
@@ -46,57 +46,6 @@ Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
int32 value_index, Env* env,
InitializableLookupTable* table);
-// Iterator to initialize tables given 'keys' and 'values' tensors.
-//
-// The two tensors are returned in the first iteration. It doesn't loop
-// over each element of the tensor since insertions in the lookup table can
-// process batches.
-class KeyValueTensorIterator
- : public InitializableLookupTable::InitTableIterator {
- public:
- // keys and values are not owned by the iterator.
- explicit KeyValueTensorIterator(const Tensor* keys, const Tensor* values)
- : keys_(keys), values_(values), valid_(true), status_(Status::OK()) {
- TensorShape key_shape = keys_->shape();
- if (!key_shape.IsSameSize(values_->shape())) {
- valid_ = false;
- status_ = errors::InvalidArgument(
- "keys and values should have the same dimension.",
- key_shape.DebugString(), " vs ", values_->shape().DebugString());
- }
- if (key_shape.num_elements() == 0) {
- valid_ = false;
- status_ =
- errors::InvalidArgument("keys and values cannot be empty tensors.");
- }
- }
-
- bool Valid() const override { return valid_; }
-
- void Next() override {
- valid_ = false;
- status_ = errors::OutOfRange("No more data.");
- }
-
- const Tensor& keys() const override { return *keys_; }
-
- const Tensor& values() const override { return *values_; }
-
- Status status() const override { return status_; }
-
- int64 total_size() const override {
- return keys_ == nullptr ? -1 : keys_->NumElements();
- }
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(KeyValueTensorIterator);
-
- const Tensor* keys_; // Doesn't own it.
- const Tensor* values_; // Doesn't own it.
- bool valid_; // true if the iterator points to an existing range.
- Status status_;
-};
-
} // namespace lookup
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/loss.h b/tensorflow/core/kernels/loss.h
index a77aa7587b..7db348800e 100644
--- a/tensorflow/core/kernels/loss.h
+++ b/tensorflow/core/kernels/loss.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_LOSS_H_
-#define TENSORFLOW_KERNELS_LOSS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_LOSS_H_
+#define TENSORFLOW_CORE_KERNELS_LOSS_H_
#include "tensorflow/core/lib/core/status.h"
@@ -56,4 +56,4 @@ class DualLossUpdater {
};
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_LOSS_H_
+#endif // TENSORFLOW_CORE_KERNELS_LOSS_H_
diff --git a/tensorflow/core/kernels/loss_test.cc b/tensorflow/core/kernels/loss_test.cc
index 460d65c5c2..9209ed2ab7 100644
--- a/tensorflow/core/kernels/loss_test.cc
+++ b/tensorflow/core/kernels/loss_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/hinge-loss.h"
#include "tensorflow/core/kernels/logistic-loss.h"
+#include "tensorflow/core/kernels/poisson-loss.h"
#include "tensorflow/core/kernels/smooth-hinge-loss.h"
#include "tensorflow/core/kernels/squared-loss.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -30,6 +31,24 @@ namespace {
// TODO(sibyl-Aix6ihai): add a test to show the improvements of the Newton
// modification detailed in readme.md
+// This test checks that the dual value after the update is optimal.
+// At the optimum, the dual value should be the opposite of the primal
+// gradient. This does not hold at points where the primal is not
+// differentiable.
+void TestComputeUpdatedDual(const DualLossUpdater &loss_updater,
+ const int num_loss_partitions, const double label,
+ const double example_weight,
+ const double current_dual, const double wx,
+ const double weighted_example_norm) {
+ double new_dual = loss_updater.ComputeUpdatedDual(
+ num_loss_partitions, label, example_weight, current_dual, wx,
+ weighted_example_norm);
+ // The primal gradient needs to be computed after the weight update.
+ double new_wx = wx + (new_dual - current_dual) * num_loss_partitions *
+ weighted_example_norm * example_weight;
+ EXPECT_NEAR(new_dual, -loss_updater.PrimalLossDerivative(new_wx, label, 1.0),
+ 1e-5);
+}
+
TEST(LogisticLoss, ComputePrimalLoss) {
LogisticLossUpdater loss_updater;
EXPECT_NEAR(0.693147,
@@ -65,19 +84,12 @@ TEST(LogisticLoss, ComputeDualLoss) {
TEST(LogisticLoss, ComputeUpdatedDual) {
LogisticLossUpdater loss_updater;
- EXPECT_NEAR(0.479,
- loss_updater.ComputeUpdatedDual(
- 1 /* num partitions */, 1.0 /* label */,
- 1.0 /* example weight */, 0.5 /* current_dual */,
- 0.3 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
-
- EXPECT_NEAR(-0.031,
- loss_updater.ComputeUpdatedDual(
- 2 /* num partitions */, -1.0 /* label */,
- 1.0 /* example weight */, 0.1 /* current_dual */,
- -0.8 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 1.0 /* label */,
+ 1.0 /* example weight */, 0.5 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
+ TestComputeUpdatedDual(loss_updater, 2 /* num partitions */, -1.0 /* label */,
+ 1.0 /* example weight */, 0.1 /* current_dual */,
+ -0.8 /* wx */, 10.0 /* weighted_example_norm */);
}
TEST(SquaredLoss, ComputePrimalLoss) {
@@ -126,19 +138,12 @@ TEST(SquaredLoss, ComputeDualLoss) {
TEST(SquaredLoss, ComputeUpdatedDual) {
SquaredLossUpdater loss_updater;
- EXPECT_NEAR(0.336,
- loss_updater.ComputeUpdatedDual(
- 1 /* num partitions */, 1.0 /* label */,
- 1.0 /* example weight */, 0.3 /* current_dual */,
- 0.3 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
-
- EXPECT_NEAR(-0.427,
- loss_updater.ComputeUpdatedDual(
- 5 /* num partitions */, -1.0 /* label */,
- 1.0 /* example weight */, -0.4 /* current_dual */,
- 0.8 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 1.0 /* label */,
+ 1.0 /* example weight */, 0.3 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
+ TestComputeUpdatedDual(loss_updater, 5 /* num partitions */, -1.0 /* label */,
+ 1.0 /* example weight */, -0.4 /* current_dual */,
+ 0.8 /* wx */, 10.0 /* weighted_example_norm */);
}
TEST(HingeLoss, ComputePrimalLoss) {
@@ -207,48 +212,27 @@ TEST(HingeLoss, ConvertLabel) {
TEST(HingeLoss, ComputeUpdatedDual) {
HingeLossUpdater loss_updater;
- // When label=1.0, example_weight=1.0, current_dual=0.5, wx=0.3 and
- // weighted_example_norm=100.0, it turns out that the optimal value to update
- // the dual to is 0.507 which is within the permitted range and thus should be
- // the value returned.
+ // For the two tests below, y*wx=1 after the update, which is a
+ // non-differentiable point of the hinge loss, so TestComputeUpdatedDual
+ // cannot be used. Check the value of the dual variable instead.
EXPECT_NEAR(0.507,
loss_updater.ComputeUpdatedDual(
1 /* num partitions */, 1.0 /* label */,
1.0 /* example weight */, 0.5 /* current_dual */,
0.3 /* wx */, 100.0 /* weighted_example_norm */),
1e-3);
- // When label=-1.0, example_weight=1.0, current_dual=0.4, wx=0.6,
- // weighted_example_norm=10.0 and num_loss_partitions=10, it turns out that
- // the optimal value to update the dual to is 0.384 which is within the
- // permitted range and thus should be the value returned.
EXPECT_NEAR(-0.416,
loss_updater.ComputeUpdatedDual(
10 /* num partitions */, -1.0 /* label */,
1.0 /* example weight */, -0.4 /* current_dual */,
0.6 /* wx */, 10.0 /* weighted_example_norm */),
1e-3);
- // When label=1.0, example_weight=1.0, current_dual=-0.5, wx=0.3 and
- // weighted_example_norm=10.0, it turns out that the optimal value to update
- // the dual to is -0.43. However, this is outside the allowed [0.0, 1.0] range
- // and hence the closest permitted value (0.0) should be returned instead.
- EXPECT_NEAR(0.0,
- loss_updater.ComputeUpdatedDual(
- 1 /* num partitions */, 1.0 /* label */,
- 1.0 /* example weight */, -0.5 /* current_dual */,
- 0.3 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
-
- // When label=-1.0, example_weight=2.0, current_dual=-1.0, wx=0.3 and
- // weighted_example_norm=10.0, it turns out that the optimal value to update
- // the dual to is -1.065. However, this is outside the allowed [-1.0, 0.0]
- // range and hence the closest permitted value (-1.0) should be returned
- // instead.
- EXPECT_NEAR(-1.0,
- loss_updater.ComputeUpdatedDual(
- 1 /* num partitions */, -1.0 /* label */,
- 2.0 /* example weight */, -1.0 /* current_dual */,
- 0.3 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 1.0 /* label */,
+ 1.0 /* example weight */, -0.5 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, -1.0 /* label */,
+ 2.0 /* example weight */, -1.0 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
}
TEST(SmoothHingeLoss, ComputePrimalLoss) {
@@ -297,19 +281,75 @@ TEST(SmoothHingeLoss, ComputeDualLoss) {
TEST(SmoothHingeLoss, ComputeUpdatedDual) {
SmoothHingeLossUpdater loss_updater;
- EXPECT_NEAR(0.336,
- loss_updater.ComputeUpdatedDual(
- 1 /* num partitions */, 1.0 /* label */,
- 1.0 /* example weight */, 0.3 /* current_dual */,
- 0.3 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 1.0 /* label */,
+ 1.0 /* example weight */, 0.3 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
+ TestComputeUpdatedDual(loss_updater, 5 /* num partitions */, -1.0 /* label */,
+ 1.0 /* example weight */, -0.4 /* current_dual */,
+ 0.8 /* wx */, 10.0 /* weighted_example_norm */);
+}
- EXPECT_NEAR(-0.427,
- loss_updater.ComputeUpdatedDual(
- 5 /* num partitions */, -1.0 /* label */,
- 1.0 /* example weight */, -0.4 /* current_dual */,
- 0.8 /* wx */, 10.0 /* weighted_example_norm */),
+TEST(PoissonLoss, ComputePrimalLoss) {
+ PoissonLossUpdater loss_updater;
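+ // Expected values follow example_weight * (exp(wx) - label * wx).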
+ EXPECT_NEAR(1.0,
+ loss_updater.ComputePrimalLoss(0.0 /* wx */, 3.0 /* label */,
+ 1.0 /* example weight */),
1e-3);
+ EXPECT_NEAR(21996.0,
+ loss_updater.ComputePrimalLoss(10.0 /* wx */, 3.0 /* label */,
+ 1.0 /* example weight */),
+ 1.0);
+ EXPECT_NEAR(0.606,
+ loss_updater.ComputePrimalLoss(-0.5 /* wx */, 0.0 /* label */,
+ 1.0 /* example weight */),
+ 1e-3);
+ EXPECT_NEAR(6.64,
+ loss_updater.ComputePrimalLoss(1.2 /* wx */, 0.0 /* label */,
+ 2.0 /* example weight */),
+ 1e-2);
+}
+
+TEST(PoissonLoss, ComputeDualLoss) {
+ PoissonLossUpdater loss_updater;
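+ // Expected values follow
+ // example_weight * ((label - dual) * log(label - dual) - (label - dual)).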
+ // Dual is undefined.
+ EXPECT_NEAR(
+ std::numeric_limits<double>::max(),
+ loss_updater.ComputeDualLoss(1.0 /* current dual */, 0.0 /* label */,
+ 1.0 /* example weight */),
+ 1e-3);
+ EXPECT_NEAR(
+ 0.0,
+ loss_updater.ComputeDualLoss(0.0 /* current dual */, 0.0 /* label */,
+ 3.0 /* example weight */),
+ 1e-3);
+ EXPECT_NEAR(
+ -0.847,
+ loss_updater.ComputeDualLoss(1.5 /* current dual */, 2.0 /* label */,
+ 1.0 /* example weight */),
+ 1e-3);
+ EXPECT_NEAR(
+ -2.675,
+ loss_updater.ComputeDualLoss(0.5 /* current dual */, 2.0 /* label */,
+ 3.0 /* example weight */),
+ 1e-3);
+}
+
+TEST(PoissonLoss, ConvertLabel) {
+ PoissonLossUpdater loss_updater;
+ float example_label = -1.0;
+ // A negative label should yield an error status.
+ Status status = loss_updater.ConvertLabel(&example_label);
+ EXPECT_FALSE(status.ok());
+}
+
+TEST(PoissonLoss, ComputeUpdatedDual) {
+ PoissonLossUpdater loss_updater;
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 2.0 /* label */,
+ 1.0 /* example weight */, 0.5 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
+ TestComputeUpdatedDual(loss_updater, 2 /* num partitions */, 0.0 /* label */,
+ 1.0 /* example weight */, 0.0 /* current_dual */,
+ -0.8 /* wx */, 10.0 /* weighted_example_norm */);
}
} // namespace
diff --git a/tensorflow/core/kernels/map_stage_op.cc b/tensorflow/core/kernels/map_stage_op.cc
index bdc3b5778f..dd89597369 100644
--- a/tensorflow/core/kernels/map_stage_op.cc
+++ b/tensorflow/core/kernels/map_stage_op.cc
@@ -410,8 +410,9 @@ class StagingMap : public ResourceBase {
copy_or_move_tensors(&it->second, *key, *indices, tuple));
// Remove entry if all the values have been consumed
- if (!std::any_of(it->second.begin(), it->second.end(),
- std::mem_fn(&OptionalTensor::has_value))) {
+ if (!std::any_of(
+ it->second.begin(), it->second.end(),
+ [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
map_.erase(it);
}
@@ -444,8 +445,9 @@ class StagingMap : public ResourceBase {
*key = it->first;
// Remove entry if all the values have been consumed
- if (!std::any_of(it->second.begin(), it->second.end(),
- std::mem_fn(&OptionalTensor::has_value))) {
+ if (!std::any_of(
+ it->second.begin(), it->second.end(),
+ [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
map_.erase(it);
}
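The two hunks above replace std::mem_fn(&OptionalTensor::has_value) with an explicit lambda. A minimal sketch of the equivalent predicate, assuming OptionalTensor behaves like std::optional; the lambda states the check directly, with no member-function-pointer indirection:

#include <algorithm>
#include <iostream>
#include <optional>
#include <vector>

int main() {
  // Stand-in for a StagingMap bucket; std::optional plays the role of
  // OptionalTensor here.
  std::vector<std::optional<int>> values = {std::nullopt, 42, std::nullopt};

  // Same predicate as the lambda introduced in the diff above.
  bool any_left = std::any_of(
      values.begin(), values.end(),
      [](const std::optional<int>& v) { return v.has_value(); });

  std::cout << (any_left ? "entries remain" : "all consumed") << "\n";
  return 0;
}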
diff --git a/tensorflow/core/kernels/matmul_op.h b/tensorflow/core/kernels/matmul_op.h
index 628895ca86..4b74a64025 100644
--- a/tensorflow/core/kernels/matmul_op.h
+++ b/tensorflow/core/kernels/matmul_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_MATMUL_OP_H_
-#define TENSORFLOW_KERNELS_MATMUL_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_H_
+#define TENSORFLOW_CORE_KERNELS_MATMUL_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
@@ -117,4 +117,4 @@ typedef Eigen::GpuDevice GPUDevice;
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_MATMUL_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_MATMUL_OP_H_
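This rename, and the similar guard renames in the headers below, align each macro with the header's full repository path. A hypothetical foo_op.h showing the convention:

// Hypothetical tensorflow/core/kernels/foo_op.h: the guard macro is the
// uppercased path with '/' and '.' replaced by '_', plus a trailing '_'.
#ifndef TENSORFLOW_CORE_KERNELS_FOO_OP_H_
#define TENSORFLOW_CORE_KERNELS_FOO_OP_H_

namespace tensorflow {

struct FooOpTag {};  // trivial declaration so the header stands alone

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_FOO_OP_H_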
diff --git a/tensorflow/core/kernels/matrix_band_part_op.h b/tensorflow/core/kernels/matrix_band_part_op.h
index 97cc950793..b04e36db8e 100644
--- a/tensorflow/core/kernels/matrix_band_part_op.h
+++ b/tensorflow/core/kernels/matrix_band_part_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
-#define TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_BAND_PART_OP_H_
+#define TENSORFLOW_CORE_KERNELS_MATRIX_BAND_PART_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -34,4 +34,4 @@ struct MatrixBandPartFunctor {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_MATRIX_BAND_PART_OP_H_
diff --git a/tensorflow/core/kernels/matrix_diag_op.h b/tensorflow/core/kernels/matrix_diag_op.h
index 14095845b8..108ba0f56b 100644
--- a/tensorflow/core/kernels/matrix_diag_op.h
+++ b/tensorflow/core/kernels/matrix_diag_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
-#define TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_DIAG_OP_H_
+#define TENSORFLOW_CORE_KERNELS_MATRIX_DIAG_OP_H_
// Generator definition for MatrixDiagOp, must be compilable by nvcc.
@@ -91,4 +91,4 @@ struct MatrixDiag {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_MATRIX_DIAG_OP_H_
diff --git a/tensorflow/core/kernels/matrix_exponential_op.cc b/tensorflow/core/kernels/matrix_exponential_op.cc
index 99db898301..01d4894438 100644
--- a/tensorflow/core/kernels/matrix_exponential_op.cc
+++ b/tensorflow/core/kernels/matrix_exponential_op.cc
@@ -49,6 +49,7 @@ class MatrixExponentialOp : public LinearAlgebraOp<Scalar> {
TF_DISALLOW_COPY_AND_ASSIGN(MatrixExponentialOp);
};
+// Deprecated kernels (2018/08/21).
REGISTER_LINALG_OP("MatrixExponential", (MatrixExponentialOp<float>), float);
REGISTER_LINALG_OP("MatrixExponential", (MatrixExponentialOp<double>), double);
REGISTER_LINALG_OP("MatrixExponential", (MatrixExponentialOp<complex64>),
diff --git a/tensorflow/core/kernels/matrix_set_diag_op.h b/tensorflow/core/kernels/matrix_set_diag_op.h
index aeb144559f..341ef12e97 100644
--- a/tensorflow/core/kernels/matrix_set_diag_op.h
+++ b/tensorflow/core/kernels/matrix_set_diag_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_MATRIX_SET_DIAG_OP_H_
-#define TENSORFLOW_KERNELS_MATRIX_SET_DIAG_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_SET_DIAG_OP_H_
+#define TENSORFLOW_CORE_KERNELS_MATRIX_SET_DIAG_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -34,4 +34,4 @@ struct MatrixSetDiag {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_MATRIX_SET_DIAG_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_MATRIX_SET_DIAG_OP_H_
diff --git a/tensorflow/core/kernels/matrix_solve_ls_op_impl.h b/tensorflow/core/kernels/matrix_solve_ls_op_impl.h
index 0e09078365..00a05a87a3 100644
--- a/tensorflow/core/kernels/matrix_solve_ls_op_impl.h
+++ b/tensorflow/core/kernels/matrix_solve_ls_op_impl.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_SOLVE_LS_OP_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_MATRIX_SOLVE_LS_OP_IMPL_H_
+
// See docs in ../ops/linalg_ops.cc.
#include "third_party/eigen3/Eigen/Cholesky"
@@ -159,3 +162,5 @@ class MatrixSolveLsOp : public LinearAlgebraOp<Scalar> {
};
} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_MATRIX_SOLVE_LS_OP_IMPL_H_
diff --git a/tensorflow/core/kernels/maxpooling_op.h b/tensorflow/core/kernels/maxpooling_op.h
index f82e57d44c..2adb8081ce 100644
--- a/tensorflow/core/kernels/maxpooling_op.h
+++ b/tensorflow/core/kernels/maxpooling_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_MAXPOOLING_OP_H_
-#define TENSORFLOW_KERNELS_MAXPOOLING_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MAXPOOLING_OP_H_
+#define TENSORFLOW_CORE_KERNELS_MAXPOOLING_OP_H_
// Functor definition for MaxPoolingOp, must be compilable by nvcc.
#include "tensorflow/core/framework/numeric_types.h"
@@ -51,4 +51,4 @@ struct SpatialMaxPooling<Device, qint8> {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_MAXPOOLING_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_MAXPOOLING_OP_H_
diff --git a/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc b/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc
index 10e468ce46..693ed8a8f0 100644
--- a/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc
+++ b/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc
@@ -114,9 +114,7 @@ class MergeV2CheckpointsOpTest : public OpsTestBase {
// Exercises "delete_old_dirs".
for (int i = 0; i < 2; ++i) {
int directory_found =
- Env::Default()
- ->IsDirectory(std::string(io::Dirname(prefixes[i])))
- .code();
+ Env::Default()->IsDirectory(string(io::Dirname(prefixes[i]))).code();
if (delete_old_dirs) {
EXPECT_EQ(error::NOT_FOUND, directory_found);
} else {
diff --git a/tensorflow/core/kernels/mirror_pad_op.h b/tensorflow/core/kernels/mirror_pad_op.h
index 81150a9e79..cc4b6941b9 100644
--- a/tensorflow/core/kernels/mirror_pad_op.h
+++ b/tensorflow/core/kernels/mirror_pad_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_MIRROR_PAD_OP_H_
-#define TENSORFLOW_KERNELS_MIRROR_PAD_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MIRROR_PAD_OP_H_
+#define TENSORFLOW_CORE_KERNELS_MIRROR_PAD_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -437,4 +437,4 @@ struct MirrorPadGrad {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_MIRROR_PAD_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_MIRROR_PAD_OP_H_
diff --git a/tensorflow/core/kernels/mirror_pad_op_cpu_impl.h b/tensorflow/core/kernels/mirror_pad_op_cpu_impl.h
index f27ca139c9..98e3be082d 100644
--- a/tensorflow/core/kernels/mirror_pad_op_cpu_impl.h
+++ b/tensorflow/core/kernels/mirror_pad_op_cpu_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_MIRROR_PAD_OP_CPU_IMPL_H_
-#define TENSORFLOW_CORE_MIRROR_PAD_OP_CPU_IMPL_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MIRROR_PAD_OP_CPU_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_MIRROR_PAD_OP_CPU_IMPL_H_
#define EIGEN_USE_THREADS
@@ -42,4 +42,4 @@ TF_CALL_NUMBER_TYPES(DEFINE_CPU_SPECS);
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_MIRROR_PAD_OP_CPU_IMPL_H_
+#endif // TENSORFLOW_CORE_KERNELS_MIRROR_PAD_OP_CPU_IMPL_H_
diff --git a/tensorflow/core/kernels/mkl_aggregate_ops.cc b/tensorflow/core/kernels/mkl_aggregate_ops.cc
index 28edf51546..20aa1f7ea1 100644
--- a/tensorflow/core/kernels/mkl_aggregate_ops.cc
+++ b/tensorflow/core/kernels/mkl_aggregate_ops.cc
@@ -392,16 +392,28 @@ class MklAddNOp : public OpKernel {
memory::format src1_mkl_data_format = src1_mkl_shape.GetTfDataFormat();
auto src1_tf_data_format =
MklDnnDataFormatToTFDataFormat(src1_mkl_data_format);
- auto src2_dims =
- TFShapeToMklDnnDimsInNCHW(src2_tensor.shape(), src1_tf_data_format);
+ memory::dims src2_dims;
+ if (src2_tensor.dims() == 4) {
+ src2_dims = TFShapeToMklDnnDimsInNCHW(src2_tensor.shape(),
+ src1_tf_data_format);
+ } else {
+ src2_dims = TFShapeToMklDnnDimsInNCDHW(src2_tensor.shape(),
+ src1_tf_data_format);
+ }
md2 = memory::desc(src2_dims, MklDnnType<T>(), src1_mkl_data_format);
} else if (input2_in_mkl_format && !input1_in_mkl_format) {
// Same comment as above.
memory::format src2_mkl_data_format = src2_mkl_shape.GetTfDataFormat();
auto src2_tf_data_format =
MklDnnDataFormatToTFDataFormat(src2_mkl_data_format);
- auto src1_dims =
- TFShapeToMklDnnDimsInNCHW(src1_tensor.shape(), src2_tf_data_format);
+ memory::dims src1_dims;
+ if (src1_tensor.dims() == 4) {
+ src1_dims = TFShapeToMklDnnDimsInNCHW(src1_tensor.shape(),
+ src2_tf_data_format);
+ } else {
+ src1_dims = TFShapeToMklDnnDimsInNCDHW(src1_tensor.shape(),
+ src2_tf_data_format);
+ }
md1 = memory::desc(src1_dims, MklDnnType<T>(), src2_mkl_data_format);
md2 = src2_mkl_shape.GetMklLayout();
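The hunk above picks the NCHW helper for 4-D tensors and the NCDHW helper for 5-D ones. A simplified standalone sketch of that rank dispatch, assuming channels-last input and using stand-in conversion helpers (not the MKL utility functions, which also take a data format):

#include <cstdio>
#include <vector>

using Dims = std::vector<int>;

// Simplified stand-ins for TFShapeToMklDnnDimsInNCHW / ...InNCDHW.
Dims ToNCHW(const Dims& nhwc) { return {nhwc[0], nhwc[3], nhwc[1], nhwc[2]}; }
Dims ToNCDHW(const Dims& ndhwc) {
  return {ndhwc[0], ndhwc[4], ndhwc[1], ndhwc[2], ndhwc[3]};
}

// Rank decides the layout: 4-D tensors are Conv2D-style, 5-D Conv3D-style.
Dims ToMklOrder(const Dims& shape) {
  return shape.size() == 4 ? ToNCHW(shape) : ToNCDHW(shape);
}

int main() {
  Dims d = ToMklOrder({2, 5, 7, 7, 3});   // NDHWC -> NCDHW
  for (int v : d) std::printf("%d ", v);  // prints: 2 3 5 7 7
  std::printf("\n");
  return 0;
}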
diff --git a/tensorflow/core/kernels/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl_avgpooling_op.cc
index 969baecc51..2409f7e9dc 100644
--- a/tensorflow/core/kernels/mkl_avgpooling_op.cc
+++ b/tensorflow/core/kernels/mkl_avgpooling_op.cc
@@ -453,6 +453,8 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
// initialize variables for the pooling op
MklPoolParameters pool_params;
+ // check whether pooling is 2D or 3D
+ bool is_pool2d = (this->ksize_.size() == 4);
// Get the input tensor and initialize the pooling parameters
TensorShape input_tensor_shape = input_tensor.shape();
this->InitMklPoolParameters(context, &pool_params, dnn_shape_input,
@@ -473,23 +475,22 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
}
memory::dims filter_dims, strides, padding_left, padding_right;
+ // Get src/filter/stride/padding information
this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
- &padding_left, &padding_right);
+ &padding_left, &padding_right, is_pool2d);
// Get the input memory descriptor
- memory::desc input_md =
- dnn_shape_input.IsMklTensor()
- ? dnn_shape_input.GetMklLayout()
- : memory::desc(TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
- this->data_format_tf_),
- MklDnnType<T>(), this->data_format_mkldnn_);
-
- // Get src/filter/stride/padding information
memory::dims src_dims =
dnn_shape_input.IsMklTensor()
? dnn_shape_input.GetSizesAsMklDnnDims()
- : TFShapeToMklDnnDimsInNCHW(input_tensor.shape(),
- this->data_format_tf_);
+ : is_pool2d ? TFShapeToMklDnnDimsInNCHW(input_tensor.shape(),
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(input_tensor.shape(),
+ this->data_format_tf_);
+ memory::desc input_md = dnn_shape_input.IsMklTensor()
+ ? dnn_shape_input.GetMklLayout()
+ : memory::desc(src_dims, MklDnnType<T>(),
+ this->data_format_mkldnn_);
// Get an average pooling primitive from the op pool
MklPoolingFwdPrimitive<T>* pooling_fwd = nullptr;
@@ -562,24 +563,30 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
for (int i = 0; i < orig_input_tensor.NumElements(); i++) {
orig_input_shape.AddDim(shape_vec(i));
}
+
+ bool is_pool2d = (this->ksize_.size() == 4);
this->InitMklPoolParameters(context, &pool_params, orig_input_mkl_shape,
orig_input_shape);
memory::dims filter_dims, strides, padding_left, padding_right;
this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
- &padding_left, &padding_right);
+ &padding_left, &padding_right, is_pool2d);
memory::dims orig_input_dims_mkl_order =
orig_input_mkl_shape.IsMklTensor()
? orig_input_mkl_shape.GetSizesAsMklDnnDims()
- : TFShapeToMklDnnDimsInNCHW(orig_input_shape,
- this->data_format_tf_);
+ : is_pool2d ? TFShapeToMklDnnDimsInNCHW(orig_input_shape,
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(orig_input_shape,
+ this->data_format_tf_);
memory::dims diff_dst_dims =
grad_mkl_shape.IsMklTensor()
? grad_mkl_shape.GetSizesAsMklDnnDims()
- : TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(),
- this->data_format_tf_);
+ : is_pool2d ? TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(),
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(grad_tensor.shape(),
+ this->data_format_tf_);
memory::dims output_dims_mkl_order;
this->GetOutputDims(pool_params, &output_dims_mkl_order);
@@ -664,6 +671,18 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
}
}; // MklAvgPoolingGradOp
+REGISTER_KERNEL_BUILDER(Name("_MklAvgPool3D")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T")
+ .Label(mkl_op_registry::kMklOpLabel),
+ MklAvgPoolingOp<CPUDevice, float>);
+
+REGISTER_KERNEL_BUILDER(Name("_MklAvgPool3DGrad")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T")
+ .Label(mkl_op_registry::kMklOpLabel),
+ MklAvgPoolingGradOp<CPUDevice, float>);
+
#endif // INTEL_MKL_ML_ONLY
REGISTER_KERNEL_BUILDER(Name("_MklAvgPool")
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index 50c25e1da7..52157ed5fb 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -82,11 +82,11 @@ struct MklConvBwdFilterParams {
};
template <typename T>
-class MklConv2DBwdFilterPrimitive : public MklPrimitive {
+class MklConvBwdFilterPrimitive : public MklPrimitive {
public:
- explicit MklConv2DBwdFilterPrimitive(
- const MklConvBwdFilterParams& convBwdFilterDims) :
- cpu_engine_(engine::cpu, 0) {
+ explicit MklConvBwdFilterPrimitive(
+ const MklConvBwdFilterParams& convBwdFilterDims)
+ : cpu_engine_(engine::cpu, 0) {
context_.bwd_filter_stream.reset(new stream(stream::kind::eager));
// create conv primitive
if (context_.conv_bwd_filter == nullptr) {
@@ -94,7 +94,7 @@ class MklConv2DBwdFilterPrimitive : public MklPrimitive {
}
}
- ~MklConv2DBwdFilterPrimitive() {}
+ ~MklConvBwdFilterPrimitive() {}
// Convolution backward weights with bias
// src_data: input data buffer of src
@@ -297,38 +297,41 @@ class MklConv2DBwdFilterPrimitive : public MklPrimitive {
};
template <typename T>
-class MklConv2DBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
+class MklConvBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
- static MklConv2DBwdFilterPrimitive<T>* Get(
- const MklConvBwdFilterParams& convBwdFilterDims) {
- MklConv2DBwdFilterPrimitive<T>* conv2d_bwd_filter = nullptr;
-
- // look into the pool for reusable primitive
- conv2d_bwd_filter = dynamic_cast<MklConv2DBwdFilterPrimitive<T>*> (
- MklConv2DBwdFilterPrimitiveFactory<T>::GetInstance().GetConv2dBwdFilter(
- convBwdFilterDims));
-
- if (conv2d_bwd_filter == nullptr) {
- conv2d_bwd_filter = new MklConv2DBwdFilterPrimitive<T>(
- convBwdFilterDims);
- MklConv2DBwdFilterPrimitiveFactory<T>::GetInstance().SetConv2dBwdFilter(
- convBwdFilterDims, conv2d_bwd_filter);
+ static MklConvBwdFilterPrimitive<T>* Get(
+ const MklConvBwdFilterParams& convBwdFilterDims, bool do_not_cache) {
+ MklConvBwdFilterPrimitive<T>* conv_bwd_filter = nullptr;
+
+ if (do_not_cache) { /* Create new primitive always */
+ conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
+ } else {
+ // look into the pool for reusable primitive
+ conv_bwd_filter = dynamic_cast<MklConvBwdFilterPrimitive<T>*> (
+ MklConvBwdFilterPrimitiveFactory<T>::GetInstance().GetConvBwdFilter(
+ convBwdFilterDims));
+
+ if (conv_bwd_filter == nullptr) {
+ conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
+ MklConvBwdFilterPrimitiveFactory<T>::GetInstance().SetConvBwdFilter(
+ convBwdFilterDims, conv_bwd_filter);
+ }
}
- return conv2d_bwd_filter;
- }
+ return conv_bwd_filter;
+ }
private:
- MklConv2DBwdFilterPrimitiveFactory() {}
- ~MklConv2DBwdFilterPrimitiveFactory() {}
+ MklConvBwdFilterPrimitiveFactory() {}
+ ~MklConvBwdFilterPrimitiveFactory() {}
- static MklConv2DBwdFilterPrimitiveFactory& GetInstance() {
- static MklConv2DBwdFilterPrimitiveFactory instance_;
+ static MklConvBwdFilterPrimitiveFactory& GetInstance() {
+ static MklConvBwdFilterPrimitiveFactory instance_;
return instance_;
}
static string CreateKey(const MklConvBwdFilterParams& convBwdFilterDims) {
- string prefix = "conv2d_bwd_filter";
+ string prefix = "conv_bwd_filter";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(convBwdFilterDims.src_dims);
@@ -342,14 +345,14 @@ class MklConv2DBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
return key_creator.GetKey();
}
- MklPrimitive* GetConv2dBwdFilter(
+ MklPrimitive* GetConvBwdFilter(
const MklConvBwdFilterParams& convBwdFilterDims) {
string key = CreateKey(convBwdFilterDims);
return this->GetOp(key);
}
- void SetConv2dBwdFilter(
- const MklConvBwdFilterParams& convBwdFilterDims, MklPrimitive* op) {
+ void SetConvBwdFilter(const MklConvBwdFilterParams& convBwdFilterDims,
+ MklPrimitive* op) {
string key = CreateKey(convBwdFilterDims);
this->SetOp(key, op);
}
@@ -738,14 +741,13 @@ TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
#else
template <typename Device, class T, bool biasEnabled>
-class MklConv2DCustomBackpropFilterOp
- : public MklConv2DBackpropCommonOp<Device, T> {
+class MklConvCustomBackpropFilterOp
+ : public MklConvBackpropCommonOp<Device, T> {
public:
- explicit MklConv2DCustomBackpropFilterOp(OpKernelConstruction* context)
- : MklConv2DBackpropCommonOp<Device, T>(context) {
- }
+ explicit MklConvCustomBackpropFilterOp(OpKernelConstruction* context)
+ : MklConvBackpropCommonOp<Device, T>(context) {}
- ~MklConv2DCustomBackpropFilterOp() {}
+ ~MklConvCustomBackpropFilterOp() {}
void Compute(OpKernelContext* context) {
try {
@@ -753,6 +755,9 @@ class MklConv2DCustomBackpropFilterOp
MklDnnData<T> diff_dst(&cpu_engine_);
MklDnnData<T> diff_filter(&cpu_engine_); // output
+ // This flag indicates Conv2D or Conv3D
+ bool isConv2D = (this->strides_.size() == 4);
+
// Input tensors
const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
const Tensor& src_tensor = MklGetInput(context, kInputIdx);
@@ -813,7 +818,10 @@ class MklConv2DCustomBackpropFilterOp
&fwd_dst_dims, &padding_left, &padding_right);
if (!context->status().ok()) return;
- auto tf_fmt = TFDataFormatToMklDnnDataFormat(this->data_format_);
+ auto tf_fmt = isConv2D
+ ? TFDataFormatToMklDnnDataFormat(this->data_format_)
+ : TFDataFormatToMklDnn3DDataFormat(this->data_format_);
+
auto fwd_src_md =
src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
@@ -832,21 +840,24 @@ class MklConv2DCustomBackpropFilterOp
if (biasEnabled) {
TensorShape obp_tf_shape = GetTfShape(context, 2);
depth = (this->data_format_ == FORMAT_NCHW)
- ? obp_tf_shape.dim_size(1)
- : obp_tf_shape.dim_size(3);
+ ? obp_tf_shape.dim_size(1)
+ : obp_tf_shape.dim_size(isConv2D ? 3 : 4);
diff_bias_dims = {static_cast<int>(depth)};
}
+ for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;
- dilations[kDilationH] -= 1;
- dilations[kDilationW] -= 1;
-
- MklConv2DBwdFilterPrimitive<T> *conv2d_bwd_filter = nullptr;
+ MklConvBwdFilterPrimitive<T>* conv_bwd_filter = nullptr;
MklConvBwdFilterParams convBwdFilterDims(fwd_src_dims, fwd_filter_dims,
diff_bias_dims, diff_dst_dims, strides, dilations, padding_left,
padding_right, TFPaddingToMklDnnPadding(this->padding_));
- conv2d_bwd_filter = MklConv2DBwdFilterPrimitiveFactory<T>::Get(
- convBwdFilterDims);
- auto bwd_filter_pd = conv2d_bwd_filter->GetPrimitiveDesc();
+
+      // MKL-DNN allocates large buffers when a conv gradient filter primitive
+      // is created. So we don't cache conv backward primitives when the env
+      // variable TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is set to true.
+ bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled();
+ conv_bwd_filter = MklConvBwdFilterPrimitiveFactory<T>::Get(
+ convBwdFilterDims, do_not_cache);
+ auto bwd_filter_pd = conv_bwd_filter->GetPrimitiveDesc();
      // allocate output tensors: diff_filter and diff_bias (with bias)
auto bwd_output_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims);
@@ -854,14 +865,26 @@ class MklConv2DCustomBackpropFilterOp
// diff_filter
MklDnnShape diff_filter_mkl_shape;
diff_filter_mkl_shape.SetMklTensor(false);
- // output_dims_mkl_order is in OIHW format.
- TensorShape diff_filter_tf_shape(
- {bwd_output_dims[MklDnnDims::Dim_H],
- bwd_output_dims[MklDnnDims::Dim_W],
- bwd_output_dims[MklDnnDims::Dim_I],
- bwd_output_dims[MklDnnDims::Dim_O]});
- AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
- diff_filter_tf_shape, diff_filter_mkl_shape);
+
+ if (isConv2D) {
+ // Conv2D: output_dims_mkl_order is in OIHW format.
+ TensorShape diff_filter_tf_shape({bwd_output_dims[MklDnnDims::Dim_H],
+ bwd_output_dims[MklDnnDims::Dim_W],
+ bwd_output_dims[MklDnnDims::Dim_I],
+ bwd_output_dims[MklDnnDims::Dim_O]});
+ AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
+ diff_filter_tf_shape, diff_filter_mkl_shape);
+ } else {
+ // Conv3D: output_dims_mkl_order is in OIDHW format.
+ TensorShape diff_filter_tf_shape(
+ {bwd_output_dims[MklDnnDims3D::Dim3d_D],
+ bwd_output_dims[MklDnnDims3D::Dim3d_H],
+ bwd_output_dims[MklDnnDims3D::Dim3d_W],
+ bwd_output_dims[MklDnnDims3D::Dim3d_I],
+ bwd_output_dims[MklDnnDims3D::Dim3d_O]});
+ AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
+ diff_filter_tf_shape, diff_filter_mkl_shape);
+ }
Tensor* diff_bias_tensor = nullptr;
if (biasEnabled) {
@@ -871,7 +894,7 @@ class MklConv2DCustomBackpropFilterOp
// check if src and diff_dst need reorder
T *src_data = nullptr;
- if (fwd_src_md.data.format != conv2d_bwd_filter->GetSrcMemoryFormat()) {
+ if (fwd_src_md.data.format != conv_bwd_filter->GetSrcMemoryFormat()) {
src.SetUsrMem(fwd_src_md, &src_tensor);
src.CheckReorderToOpMem(bwd_filter_pd->src_primitive_desc());
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
@@ -882,7 +905,7 @@ class MklConv2DCustomBackpropFilterOp
T *diff_dst_data = nullptr;
if (diff_dst_md.data.format !=
- conv2d_bwd_filter->GetDiffDstMemoryFormat()) {
+ conv_bwd_filter->GetDiffDstMemoryFormat()) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(bwd_filter_pd->diff_dst_primitive_desc());
diff_dst_data = static_cast<T*>(
@@ -897,7 +920,7 @@ class MklConv2DCustomBackpropFilterOp
bool diff_filter_reorder_required = false;
T *diff_filter_data = nullptr;
if (GetOutputFormat(tf_fmt) !=
- conv2d_bwd_filter->GetDiffFilterMemoryFormat()) {
+ conv_bwd_filter->GetDiffFilterMemoryFormat()) {
// Allocate diff filter tensor as Tensorflow layout
diff_filter.SetUsrMem(bwd_output_dims, GetOutputFormat(tf_fmt),
diff_filter_tensor);
@@ -915,16 +938,19 @@ class MklConv2DCustomBackpropFilterOp
if (biasEnabled) {
T* diff_bias_data = static_cast<T*>(const_cast<T*>(
diff_bias_tensor->flat<T>().data()));
- conv2d_bwd_filter->Execute(src_data, diff_filter_data,
- diff_bias_data, diff_dst_data);
+ conv_bwd_filter->Execute(src_data, diff_filter_data, diff_bias_data,
+ diff_dst_data);
} else {
- conv2d_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data);
+ conv_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data);
}
// Reorder diff_filter back to Tensorflow layout if necessary
if (diff_filter_reorder_required) {
diff_filter.InsertReorderToUserMem();
}
+
+ // delete primitive since it is not cached.
+ if (do_not_cache) delete conv_bwd_filter;
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
@@ -947,7 +973,7 @@ class MklConv2DCustomBackpropFilterOp
const MklDnnShape& filter_mkl_shape,
const MklDnnShape& obp_mkl_shape) {
CHECK(!filter_mkl_shape.IsMklTensor())
- << "Conv2DBackpropFilter: filter should not be in MKL Layout";
+ << "ConvBackpropFilter: filter should not be in MKL Layout";
}
// Get TensorFlow shape of input tensor.
@@ -983,9 +1009,11 @@ class MklConv2DCustomBackpropFilterOp
return fwd_filter_dims;
}
- // Output layout is Tensorflow's filter layout (HWIO).
+ // Output layout is Tensorflow's filter layout
+ // Conv2D: HWIO; Conv3D: DHWIO
memory::format GetOutputFormat(const memory::format data_format) {
- return memory::format::hwio;
+ return (this->strides_.size() == 4) ? memory::format::hwio
+ : memory::format::dhwio;
}
// Allocate output tensor.
@@ -1027,24 +1055,27 @@ class MklConv2DCustomBackpropFilterOp
}
};
-#define REGISTER_MKL_FILTER_KERNELS(T) \
- REGISTER_KERNEL_BUILDER( \
- Name("_MklConv2DBackpropFilter") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_op_registry::kMklOpLabel), \
- MklConv2DCustomBackpropFilterOp<CPUDevice, T, false>); \
- REGISTER_KERNEL_BUILDER( \
- Name("_MklConv2DBackpropFilterWithBias") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_op_registry::kMklOpLabel), \
- MklConv2DCustomBackpropFilterOp<CPUDevice, T, true>); \
- REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DBackpropFilterWithBias") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_op_registry::kMklOpLabel), \
- MklDummyOp<CPUDevice, T>);
+#define REGISTER_MKL_FILTER_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvCustomBackpropFilterOp<CPUDevice, T, false>); \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilterWithBias") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvCustomBackpropFilterOp<CPUDevice, T, true>); \
+ REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DBackpropFilterWithBias") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklDummyOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv3DBackpropFilterV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvCustomBackpropFilterOp<CPUDevice, T, false>);
TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
#undef REGISTER_MKL_FILTER_KERNELS
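The factory change above threads a do_not_cache flag through Get(): when set, a fresh primitive is returned and the caller deletes it after Execute(); otherwise the pooled instance is reused and owned by the factory. A hedged sketch of that ownership split (the names are stand-ins, not the MKL primitive API):

#include <map>
#include <memory>
#include <string>

struct Primitive {  // stand-in for an MKL-DNN convolution primitive
  void Execute() {}
};

class Factory {
 public:
  static Primitive* Get(const std::string& key, bool do_not_cache) {
    if (do_not_cache) return new Primitive();  // fresh object, caller owns
    auto& pool = Pool();
    auto it = pool.find(key);
    if (it == pool.end())
      it = pool.emplace(key, std::make_unique<Primitive>()).first;
    return it->second.get();  // pooled object, factory owns
  }

 private:
  static std::map<std::string, std::unique_ptr<Primitive>>& Pool() {
    static std::map<std::string, std::unique_ptr<Primitive>> pool;
    return pool;
  }
};

int main() {
  bool do_not_cache = true;  // e.g. the memory-optimization env variable is set
  Primitive* p = Factory::Get("conv_bwd_filter", do_not_cache);
  p->Execute();
  if (do_not_cache) delete p;  // mirrors the kernel's cleanup path above
  return 0;
}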
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index 38e014d68e..c38c9cc27c 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -59,7 +59,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
#ifndef INTEL_MKL_ML_ONLY
-/// utility classes enabling primitive reuse for backward conv2d ops.
+/// utility classes enabling primitive reuse for backward conv ops.
struct MklConvBwdInputParams {
memory::dims diff_src_dims;
memory::dims filter_dims;
@@ -83,11 +83,11 @@ struct MklConvBwdInputParams {
};
template <typename T>
-class MklConv2DBwdInputPrimitive : public MklPrimitive {
+class MklConvBwdInputPrimitive : public MklPrimitive {
public:
- explicit MklConv2DBwdInputPrimitive(
- const MklConvBwdInputParams& convBwdInputDims) :
- cpu_engine_(engine::cpu, 0) {
+ explicit MklConvBwdInputPrimitive(
+ const MklConvBwdInputParams& convBwdInputDims)
+ : cpu_engine_(engine::cpu, 0) {
context_.bwd_input_stream.reset(new stream(stream::kind::eager));
// create conv primitive
@@ -95,7 +95,7 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive {
Setup(convBwdInputDims);
}
}
- ~MklConv2DBwdInputPrimitive() {}
+ ~MklConvBwdInputPrimitive() {}
// Convolution backward filter (weights)
// diff_src_data: output data buffer of diff_src
@@ -134,7 +134,7 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive {
}
private:
- // Primitive reuse context for Conv2D Bwd Input op
+ // Primitive reuse context for Conv Bwd Input op
struct ConvBwdInputContext {
// expected memory format for this primitive instance
memory::format filter_fmt;
@@ -174,7 +174,6 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive {
}
};
-
void Setup(const MklConvBwdInputParams& convBwdInputDims) {
// create memory descriptors for convolution data w/ no specified format
context_.diff_src_md.reset(new memory::desc(
@@ -235,38 +234,41 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive {
};
template <typename T>
-class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
+class MklConvBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
private:
- MklConv2DBwdInputPrimitiveFactory() {}
- ~MklConv2DBwdInputPrimitiveFactory() {}
+ MklConvBwdInputPrimitiveFactory() {}
+ ~MklConvBwdInputPrimitiveFactory() {}
public:
- static MklConv2DBwdInputPrimitive<T>* Get(
- const MklConvBwdInputParams& convBwdInputDims) {
- MklConv2DBwdInputPrimitive<T>* conv2d_bwd_input = nullptr;
-
- // look into the pool for reusable primitive
- conv2d_bwd_input = dynamic_cast<MklConv2DBwdInputPrimitive<T>*> (
- MklConv2DBwdInputPrimitiveFactory<T>::GetInstance().GetConv2dBwdInput(
- convBwdInputDims));
-
- if (conv2d_bwd_input == nullptr) {
- conv2d_bwd_input = new MklConv2DBwdInputPrimitive<T>(
- convBwdInputDims);
- MklConv2DBwdInputPrimitiveFactory<T>::GetInstance().SetConv2dBwdInput(
- convBwdInputDims, conv2d_bwd_input);
+ static MklConvBwdInputPrimitive<T>* Get(
+ const MklConvBwdInputParams& convBwdInputDims, bool do_not_cache) {
+ MklConvBwdInputPrimitive<T>* conv_bwd_input = nullptr;
+
+ if (do_not_cache) { /* Always allocate primitive */
+ conv_bwd_input = new MklConvBwdInputPrimitive<T>(convBwdInputDims);
+ } else {
+ // look into the pool for reusable primitive
+ conv_bwd_input = dynamic_cast<MklConvBwdInputPrimitive<T>*>(
+ MklConvBwdInputPrimitiveFactory<T>::GetInstance().GetConvBwdInput(
+ convBwdInputDims));
+ if (conv_bwd_input == nullptr) {
+ conv_bwd_input = new MklConvBwdInputPrimitive<T>(convBwdInputDims);
+ MklConvBwdInputPrimitiveFactory<T>::GetInstance().SetConvBwdInput(
+ convBwdInputDims, conv_bwd_input);
+ }
}
- return conv2d_bwd_input;
+
+ return conv_bwd_input;
}
private:
- static MklConv2DBwdInputPrimitiveFactory& GetInstance() {
- static MklConv2DBwdInputPrimitiveFactory instance_;
+ static MklConvBwdInputPrimitiveFactory& GetInstance() {
+ static MklConvBwdInputPrimitiveFactory instance_;
return instance_;
}
static string CreateKey(const MklConvBwdInputParams& convBwdInputDims) {
- string prefix = "conv2d_bwd_input";
+ string prefix = "conv_bwd_input";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(convBwdInputDims.diff_src_dims);
@@ -279,14 +281,13 @@ class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
return key_creator.GetKey();
}
- MklPrimitive* GetConv2dBwdInput(
- const MklConvBwdInputParams& convBwdInputDims) {
+ MklPrimitive* GetConvBwdInput(const MklConvBwdInputParams& convBwdInputDims) {
string key = CreateKey(convBwdInputDims);
return this->GetOp(key);
}
- void SetConv2dBwdInput(
- const MklConvBwdInputParams& convBwdInputDims, MklPrimitive *op) {
+ void SetConvBwdInput(const MklConvBwdInputParams& convBwdInputDims,
+ MklPrimitive* op) {
string key = CreateKey(convBwdInputDims);
this->SetOp(key, op);
}
@@ -594,23 +595,34 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
TensorFormat data_format;
};
+#define REGISTER_MKL_CPU_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConv2DCustomBackpropInputOp<CPUDevice, T>);
+
+TF_CALL_float(REGISTER_MKL_CPU_KERNELS);
+#undef REGISTER_MKL_CPU_KERNELS
+
#else
template <typename Device, class T>
-class MklConv2DCustomBackpropInputOp
- : public MklConv2DBackpropCommonOp<Device, T> {
+class MklConvCustomBackpropInputOp : public MklConvBackpropCommonOp<Device, T> {
public:
- explicit MklConv2DCustomBackpropInputOp(OpKernelConstruction* context)
- : MklConv2DBackpropCommonOp<Device, T>(context) {
- }
+ explicit MklConvCustomBackpropInputOp(OpKernelConstruction* context)
+ : MklConvBackpropCommonOp<Device, T>(context) {}
- ~MklConv2DCustomBackpropInputOp() {}
+ ~MklConvCustomBackpropInputOp() {}
void Compute(OpKernelContext* context) {
try {
MklDnnData<T> filter(&cpu_engine);
MklDnnData<T> diff_dst(&cpu_engine);
+      // This flag indicates Conv2D or Conv3D
+ bool isConv2D = (this->strides_.size() == 4);
+
// Input tensors
const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
const Tensor& src_tensor = MklGetInput(context, kInputIdx);
@@ -626,7 +638,7 @@ class MklConv2DCustomBackpropInputOp
diff_dst_mkl_shape);
// Allow operator-specific generation of shapes.
- // E.g., Conv2DBackpropFilter gets filter as filter_sizes. It is a
+ // E.g., ConvBackpropFilter gets filter as filter_sizes. It is a
// tensor containing shape of filter. So filter.shape() is not
// a correct way to get filter shape. These operator-specific calls
// allow this class to handle this case.
@@ -655,6 +667,7 @@ class MklConv2DCustomBackpropInputOp
}
return;
}
+
// By default, all dims are in MKL order. Only dims in TF order
// are those with postfix tf_order.
memory::dims diff_dst_dims, fwd_src_dims, fwd_filter_dims;
@@ -673,15 +686,18 @@ class MklConv2DCustomBackpropInputOp
// Create Convolution forward descriptor since Convolution backward
// API needs it. For that, we first need to create input, filter
// and output memory descriptors.
- auto tf_fmt = TFDataFormatToMklDnnDataFormat(this->data_format_);
+ auto tf_fmt = isConv2D
+ ? TFDataFormatToMklDnnDataFormat(this->data_format_)
+ : TFDataFormatToMklDnn3DDataFormat(this->data_format_);
// If filter is in MKL layout, then simply grab filter layout;
// otherwise, construct filter in TF layout.
// For TF layout, filter is in HWIO format.
auto fwd_filter_md = filter_mkl_shape.IsMklTensor()
- ? filter_mkl_shape.GetMklLayout()
- : memory::desc(fwd_filter_dims, MklDnnType<T>(),
- memory::format::hwio);
+ ? filter_mkl_shape.GetMklLayout()
+ : memory::desc(fwd_filter_dims, MklDnnType<T>(),
+ isConv2D ? memory::format::hwio
+ : memory::format::dhwio);
conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
if (!context->status().ok()) return;
@@ -689,18 +705,25 @@ class MklConv2DCustomBackpropInputOp
? diff_dst_mkl_shape.GetMklLayout()
: memory::desc(diff_dst_dims,
MklDnnType<T>(), tf_fmt);
+ for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;
- dilations[kDilationH] -= 1;
- dilations[kDilationW] -= 1;
-
- MklConv2DBwdInputPrimitive<T> *conv2d_bwd_input = nullptr;
- conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
+ MklConvBwdInputPrimitive<T>* conv_bwd_input = nullptr;
MklConvBwdInputParams convBwdInputDims(fwd_src_dims, fwd_filter_dims,
diff_dst_dims, strides, dilations, padding_left, padding_right,
TFPaddingToMklDnnPadding(this->padding_));
- conv2d_bwd_input = MklConv2DBwdInputPrimitiveFactory<T>::Get(
- convBwdInputDims);
- auto bwd_input_pd = conv2d_bwd_input->GetPrimitiveDesc();
+
+      // We don't cache those primitives if the env variable
+      // TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is true and if the primitive
+      // descriptor includes potentially large buffers. MKL-DNN allocates
+      // buffers in the following cases:
+      // 1. Legacy CPU without AVX512/AVX2, or
+      // 2. 1x1 convolution with stride != 1
+ bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled() &&
+ (MklPrimitiveFactory<T>::IsLegacyPlatform() ||
+ IsConv1x1StrideNot1(fwd_filter_dims, strides));
+ conv_bwd_input = MklConvBwdInputPrimitiveFactory<T>::Get(convBwdInputDims,
+ do_not_cache);
+ auto bwd_input_pd = conv_bwd_input->GetPrimitiveDesc();
// allocate output tensor
auto diff_src_pd = bwd_input_pd->diff_src_primitive_desc();
@@ -723,7 +746,7 @@ class MklConv2DCustomBackpropInputOp
// check if filter and diff_dst need reorder
T* filter_data = nullptr;
if (fwd_filter_md.data.format !=
- conv2d_bwd_input->GetFilterMemoryFormat()) {
+ conv_bwd_input->GetFilterMemoryFormat()) {
filter.SetUsrMem(fwd_filter_md, &filter_tensor);
filter.CheckReorderToOpMem(bwd_input_pd->weights_primitive_desc());
filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle());
@@ -733,8 +756,7 @@ class MklConv2DCustomBackpropInputOp
}
T* diff_dst_data = nullptr;
- if (diff_dst_md.data.format !=
- conv2d_bwd_input->GetDiffDstMemoryFormat()) {
+ if (diff_dst_md.data.format != conv_bwd_input->GetDiffDstMemoryFormat()) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(bwd_input_pd->diff_dst_primitive_desc());
diff_dst_data = static_cast<T*>(
@@ -745,7 +767,12 @@ class MklConv2DCustomBackpropInputOp
}
// execute convolution input bwd
- conv2d_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
+ conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
+
+ // delete primitive since it is not cached.
+ if (do_not_cache) {
+ delete conv_bwd_input;
+ }
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
@@ -770,7 +797,7 @@ class MklConv2DCustomBackpropInputOp
// of the Tensor and never an actual tensor. So it will never be in MKL
// layout.
CHECK(!input_mkl_shape.IsMklTensor())
- << "Conv2DBackpropInput: input should not be in MKL Layout";
+ << "ConvBackpropInput: input should not be in MKL Layout";
}
// Get TensorFlow shape of input tensor.
@@ -778,10 +805,10 @@ class MklConv2DCustomBackpropInputOp
const Tensor& input_tensor) {
TensorShape input_tf_shape;
CHECK_EQ(TensorShapeUtils::IsVector(input_tensor.shape()), true);
- CHECK_EQ(
- TensorShapeUtils::MakeShape(input_tensor.vec<int32>(), &input_tf_shape)
- .ok(),
- true);
+    // Conv[2D|3D]BackpropInputV2 supports both DT_INT32 and DT_INT64 for
+    // output_shape. MakeShape is able to handle both DT_INT32 and DT_INT64
+    // for input_tensor.
+ CHECK_EQ(this->MakeShape(input_tensor, &input_tf_shape).ok(), true);
return input_tf_shape;
}
@@ -792,7 +819,7 @@ class MklConv2DCustomBackpropInputOp
}
// Get the Tensorflow shape of Output (diff_src),
- // which is same as shape of Conv2D 'input'.
+ // which is same as shape of Conv 'input'.
TensorShape GetOutputTfShape(const TensorShape& input_shape,
const TensorShape& filter_shape,
const TensorShape& outbprop_shape) {
@@ -800,7 +827,7 @@ class MklConv2DCustomBackpropInputOp
}
// Get the Tensorflow shape of Output (diff_src),
- // which is same as shape of Conv2D 'input'.
+ // which is same as shape of Conv 'input'.
const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims,
const memory::dims& fwd_filter_dims) {
return fwd_input_dims;
@@ -839,17 +866,22 @@ class MklConv2DCustomBackpropInputOp
}
};
-#endif // INTEL_MKL_ML_ONLY
-
-#define REGISTER_MKL_CPU_KERNELS(T) \
- REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_op_registry::kMklOpLabel), \
- MklConv2DCustomBackpropInputOp<CPUDevice, T>);
+#define REGISTER_MKL_CPU_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv3DBackpropInputV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvCustomBackpropInputOp<CPUDevice, T>);
TF_CALL_float(REGISTER_MKL_CPU_KERNELS);
#undef REGISTER_MKL_CPU_KERNELS
+#endif // INTEL_MKL_ML_ONLY
+
} // namespace tensorflow
#endif // INTEL_MKL
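Several hunks above select the filter layout by rank: HWIO for Conv2D, DHWIO for Conv3D, keyed off strides_.size(). A tiny illustrative helper (an assumption about the pattern, not the MKL-DNN API):

#include <cstdio>

// Conv2D kernels carry 4 stride entries, Conv3D kernels 5, and the
// filter memory format follows from that rank.
const char* FilterFormat(int num_strides) {
  return num_strides == 4 ? "hwio" : "dhwio";
}

int main() {
  std::printf("%s\n", FilterFormat(4));  // hwio  (Conv2D)
  std::printf("%s\n", FilterFormat(5));  // dhwio (Conv3D)
  return 0;
}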
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index bca1aa21a8..184e0cb003 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -85,9 +85,9 @@ struct MklConvFwdParams {
};
template <typename T>
-class MklConv2DFwdPrimitive : public MklPrimitive {
+class MklConvFwdPrimitive : public MklPrimitive {
public:
- explicit MklConv2DFwdPrimitive(const MklConvFwdParams& convFwdDims)
+ explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
: cpu_engine_(engine::cpu, 0) {
context_.fwd_stream.reset(new stream(stream::kind::eager));
// create conv primitive
@@ -96,7 +96,7 @@ class MklConv2DFwdPrimitive : public MklPrimitive {
}
}
- ~MklConv2DFwdPrimitive() {}
+ ~MklConvFwdPrimitive() {}
// Convolution forward execute with bias
// src_data: input data buffer of src
@@ -269,37 +269,41 @@ class MklConv2DFwdPrimitive : public MklPrimitive {
};
template <typename T>
-class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
+class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
- static MklConv2DFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims) {
- MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr;
-
- // try to find a suitable one in pool
- conv2d_fwd = dynamic_cast<MklConv2DFwdPrimitive<T>*>(
- MklConv2DFwdPrimitiveFactory<T>::GetInstance().GetConv2DFwd(
- convFwdDims));
-
- if (conv2d_fwd == nullptr) {
- conv2d_fwd = new MklConv2DFwdPrimitive<T>(convFwdDims);
- MklConv2DFwdPrimitiveFactory<T>::GetInstance().SetConv2DFwd(convFwdDims,
- conv2d_fwd);
+ static MklConvFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims,
+ bool do_not_cache) {
+ MklConvFwdPrimitive<T>* conv_fwd = nullptr;
+
+ if (do_not_cache) { /* Always create new primitive */
+ conv_fwd = new MklConvFwdPrimitive<T>(convFwdDims);
+ } else {
+ // try to find a suitable one in pool
+ conv_fwd = dynamic_cast<MklConvFwdPrimitive<T>*>(
+ MklConvFwdPrimitiveFactory<T>::GetInstance().GetConvFwd(convFwdDims));
+ if (conv_fwd == nullptr) {
+ conv_fwd = new MklConvFwdPrimitive<T>(convFwdDims);
+ MklConvFwdPrimitiveFactory<T>::GetInstance().SetConvFwd(convFwdDims,
+ conv_fwd);
+ }
}
- return conv2d_fwd;
+
+ return conv_fwd;
}
private:
- MklConv2DFwdPrimitiveFactory() {}
- ~MklConv2DFwdPrimitiveFactory() {}
+ MklConvFwdPrimitiveFactory() {}
+ ~MklConvFwdPrimitiveFactory() {}
static const int kDilationH = 0, kDilationW = 1;
- static MklConv2DFwdPrimitiveFactory& GetInstance() {
- static MklConv2DFwdPrimitiveFactory instance_;
+ static MklConvFwdPrimitiveFactory& GetInstance() {
+ static MklConvFwdPrimitiveFactory instance_;
return instance_;
}
static string CreateKey(const MklConvFwdParams& convFwdDims) {
- string prefix = "conv2d_fwd_";
+ string prefix = "conv_fwd_";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(convFwdDims.src_dims);
@@ -313,12 +317,12 @@ class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
return key_creator.GetKey();
}
- MklPrimitive* GetConv2DFwd(const MklConvFwdParams& convFwdDims) {
+ MklPrimitive* GetConvFwd(const MklConvFwdParams& convFwdDims) {
string key = CreateKey(convFwdDims);
return this->GetOp(key);
}
- void SetConv2DFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) {
+ void SetConvFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) {
string key = CreateKey(convFwdDims);
this->SetOp(key, op);
}
@@ -331,11 +335,11 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
 // For now, MKL-ML is the default, so MKL-DNN is not the default choice.
#ifdef INTEL_MKL_ML_ONLY
template <typename Device, typename T, bool biasEnabled>
-class MklConv2DOp : public OpKernel {
+class MklConvOp : public OpKernel {
public:
- ~MklConv2DOp() {}
+ ~MklConvOp() {}
- explicit MklConv2DOp(OpKernelConstruction* context) : OpKernel(context) {
+ explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
string data_format;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
@@ -755,21 +759,22 @@ class MklConv2DOp : public OpKernel {
#else
+// Base class for convolution forward operations
template <typename Device, typename T, bool biasEnabled>
-class MklConv2DOp : public OpKernel {
+class MklConvOp : public OpKernel {
public:
- ~MklConv2DOp() {}
+ ~MklConvOp() {}
- explicit MklConv2DOp(OpKernelConstruction* context) : OpKernel(context) {
+ explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
string data_format;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
errors::InvalidArgument("Invalid data format"));
- OP_REQUIRES(context, strides_.size() == 4,
+ OP_REQUIRES(context, (strides_.size() == 4 || strides_.size() == 5),
errors::InvalidArgument("Sliding window strides field must "
- "specify 4 dimensions"));
+ "specify 4 or 5 dimensions"));
const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
@@ -778,20 +783,39 @@ class MklConv2DOp : public OpKernel {
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
- OP_REQUIRES(context, dilations_.size() == 4,
- errors::InvalidArgument("Sliding window dilations field must "
- "specify 4 dimensions"));
- const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
- const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
- const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
- const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
- OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
- errors::InvalidArgument(
- "Current implementation does not yet support "
- "dilations in the batch and depth dimensions."));
- OP_REQUIRES(
- context, dilation_h > 0 && dilation_w > 0,
- errors::InvalidArgument("Dilated rates should be larger than 0."));
+
+ if (strides_.size() == 4) {
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
+ const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
+ const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
+ const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
+ OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ OP_REQUIRES(
+ context, dilation_h > 0 && dilation_w > 0,
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
+ } else if (strides_.size() == 5) {
+ OP_REQUIRES(context, dilations_.size() == 5,
+ errors::InvalidArgument("Dilation rates field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(context,
+ (GetTensorDim(dilations_, data_format_, 'N') == 1 &&
+ GetTensorDim(dilations_, data_format_, 'C') == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+                      "dilation rates in the batch and depth dimensions."));
+ OP_REQUIRES(
+ context,
+ (GetTensorDim(dilations_, data_format_, '0') > 0 &&
+ GetTensorDim(dilations_, data_format_, '1') > 0 &&
+ GetTensorDim(dilations_, data_format_, '2') > 0),
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
+ }
}
void Compute(OpKernelContext* context) override {
@@ -837,7 +861,8 @@ class MklConv2DOp : public OpKernel {
AllocateOutputSetMklShape(context, kOutputIndex_Dst,
&dst_tensor, src_tf_shape, dst_mkl_shape);
- // MklConv2D also outputs converted filter as 2nd output of Conv2D.
+ // MklConv2D/3D also outputs converted filter
+ // as 2nd output of Conv2D/3D.
filter_mkl_shape.SetMklTensor(false);
Tensor* output_filter_tensor = nullptr;
AllocateOutputSetMklShape(context, kOutputIndex_Filter,
@@ -846,15 +871,20 @@ class MklConv2DOp : public OpKernel {
return;
}
+ bool isConv2D = (strides_.size() == 4);
+
// Create memory for user data.
// Describe how the inputs and outputs of Convolution look like. Also
// specify buffers containing actual input and output data.
- auto tf_fmt = TFDataFormatToMklDnnDataFormat(data_format_);
+ auto tf_fmt = isConv2D ? TFDataFormatToMklDnnDataFormat(data_format_)
+ : TFDataFormatToMklDnn3DDataFormat(data_format_);
// If input is in MKL layout, then simply grab input layout; otherwise,
// construct input Tf layout. For TF layout, although input shape
// (src_dims) required is in MKL-DNN order, the layout is Tensorflow's
- // layout (NHWC or NCHW depending on data format).
+ // layout depending on data format:
+ // Conv2D: NHWC or NCHW
+ // Conv3D: NDHWC or NCDHW
auto src_md = src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
: memory::desc(src_dims, MklDnnType<T>(), tf_fmt);
@@ -864,31 +894,43 @@ class MklConv2DOp : public OpKernel {
auto filter_md = filter_mkl_shape.IsMklTensor() // Should NEVER be true
? filter_mkl_shape.GetMklLayout()
: memory::desc(filter_dims, MklDnnType<T>(),
- memory::format::hwio);
-
+ isConv2D ? memory::format::hwio
+ : memory::format::dhwio);
// MKLDNN dilation starts from 0.
- dilations[kDilationH] -= 1;
- dilations[kDilationW] -= 1;
+ for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;
+
+      // In some cases, the primitive descriptor includes potentially large
+      // buffers; we don't cache those primitives if the env variable
+      // TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is true. MKL-DNN allocates buffers
+      // in the following cases:
+      //   1. Legacy CPU without AVX512/AVX2, or
+      //   2. 1x1 convolution with stride != 1
+ bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled() &&
+ (src_dims[MklDnnDims::Dim_N] > kSmallBatchSize) &&
+ (MklPrimitiveFactory<T>::IsLegacyPlatform() ||
+ IsConv1x1StrideNot1(filter_dims, strides));
// get a conv2d fwd from primitive pool
- MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr;
+ MklConvFwdPrimitive<T>* conv_fwd = nullptr;
if (biasEnabled) {
memory::dims bias_dims = {};
conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims);
MklConvFwdParams convFwdDims(src_dims, filter_dims, bias_dims,
dst_dims_mkl_order, strides, dilations,
padding_left, padding_right);
- conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims);
+ conv_fwd = MklConvFwdPrimitiveFactory<T>::Get(
+ convFwdDims, do_not_cache);
} else {
MklConvFwdParams convFwdDims(src_dims, filter_dims, NONE_DIMS,
dst_dims_mkl_order, strides, dilations,
padding_left, padding_right);
- conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims);
+ conv_fwd = MklConvFwdPrimitiveFactory<T>::Get(
+ convFwdDims, do_not_cache);
}
// allocate output tensors output_tensor and filter_out_tensor
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_fwd_pd =
- conv2d_fwd->GetPrimitiveDesc();
+ conv_fwd->GetPrimitiveDesc();
AllocateOutputTensor(context, *conv_fwd_pd,
dst_dims_mkl_order, tf_fmt, &dst_tensor);
Tensor* filter_out_tensor = nullptr;
@@ -900,7 +942,7 @@ class MklConv2DOp : public OpKernel {
// check whether src/filter need reorder
T *src_data = nullptr;
- if (src_md.data.format != conv2d_fwd->GetSrcMemoryFormat()) {
+ if (src_md.data.format != conv_fwd->GetSrcMemoryFormat()) {
src.SetUsrMem(src_md, &src_tensor);
src.CheckReorderToOpMem(conv_fwd_pd.get()->src_primitive_desc());
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
@@ -908,7 +950,7 @@ class MklConv2DOp : public OpKernel {
src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
}
T* filter_data = nullptr;
- if (filter_md.data.format != conv2d_fwd->GetFilterMemoryFormat()) {
+ if (filter_md.data.format != conv_fwd->GetFilterMemoryFormat()) {
filter.SetUsrMem(filter_md, &filter_tensor);
filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_primitive_desc(),
filter.GetTensorBuffer(filter_out_tensor));
@@ -918,17 +960,19 @@ class MklConv2DOp : public OpKernel {
static_cast<T*>(const_cast<T*>(filter_tensor.flat<T>().data()));
}
-
// execute convolution
if (biasEnabled) {
const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias);
T* bias_data = static_cast<T*>(const_cast<T*>(
bias_tensor.flat<T>().data()));
- conv2d_fwd->Execute(src_data, filter_data, bias_data, dst_data);
+ conv_fwd->Execute(src_data, filter_data, bias_data, dst_data);
} else {
- conv2d_fwd->Execute(src_data, filter_data, dst_data);
+ conv_fwd->Execute(src_data, filter_data, dst_data);
}
+
+ // delete primitive since it is not cached.
+ if (do_not_cache) delete conv_fwd;
} catch (mkldnn::error &e) {
string error_msg = tensorflow::strings::StrCat(
"Status: ", e.status, ", message: ", string(e.message), ", in file ",
@@ -1038,24 +1082,34 @@ class MklConv2DOp : public OpKernel {
#endif
-#define REGISTER_MKL_CPU(T) \
+// Register 2D operations
+#define REGISTER_MKL_CPU_2D(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklOpLabel), \
- MklConv2DOp<CPUDevice, T, false>); \
+ MklConvOp<CPUDevice, T, false>); \
REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklOpLabel), \
- MklConv2DOp<CPUDevice, T, true>); \
+ MklConvOp<CPUDevice, T, true>); \
REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DWithBias") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklOpLabel), \
MklDummyOp<CPUDevice, T>);
-TF_CALL_float(REGISTER_MKL_CPU);
+TF_CALL_float(REGISTER_MKL_CPU_2D);
+
+// Register 3D operations
+#define REGISTER_MKL_CPU_3D(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv3D") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvOp<CPUDevice, T, false>);
+TF_CALL_float(REGISTER_MKL_CPU_3D);
} // namespace tensorflow
#endif // INTEL_MKL
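The loop `for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;` introduced above converts TensorFlow's dilation convention (1 means no dilation) to MKL-DNN's (0 means no dilation) uniformly for the 2-D and 3-D cases. A minimal check of that conversion:

#include <cstdio>
#include <vector>

int main() {
  // TensorFlow convention: dilation 1 means "no dilation".
  std::vector<int> dilations = {1, 2, 2};  // planes, rows, cols
  // MKL-DNN convention: dilation 0 means "no dilation".
  for (int& d : dilations) d -= 1;
  for (int d : dilations) std::printf("%d ", d);  // prints: 0 1 1
  std::printf("\n");
  return 0;
}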
diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h
index 838c06f49d..01cc606f41 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.h
+++ b/tensorflow/core/kernels/mkl_conv_ops.h
@@ -79,9 +79,16 @@ class MklDnnConvUtil {
// For now we take the stride from the second and third dimensions only
// (we do not support striding on the batch or depth dimension).
CHECK_NOTNULL(strides);
- int stride_rows = GetTensorDim(strides_, data_format_, 'H');
- int stride_cols = GetTensorDim(strides_, data_format_, 'W');
- *strides = {stride_rows, stride_cols};
+ if (strides_.size() == 4) {
+ int stride_rows = GetTensorDim(strides_, data_format_, 'H');
+ int stride_cols = GetTensorDim(strides_, data_format_, 'W');
+ *strides = {stride_rows, stride_cols};
+ } else if (strides_.size() == 5) {
+ int stride_planes = GetTensorDim(strides_, data_format_, '0');
+ int stride_rows = GetTensorDim(strides_, data_format_, '1');
+ int stride_cols = GetTensorDim(strides_, data_format_, '2');
+ *strides = {stride_planes, stride_rows, stride_cols};
+ }
}
// Calculate Convolution dilations
@@ -89,13 +96,20 @@ class MklDnnConvUtil {
// For now we take the dilation from the second and third dimensions only
// (we do not support dilation on the batch or depth dimension).
CHECK_NOTNULL(dilations);
- int dilations_rows = GetTensorDim(dilations_, data_format_, 'H');
- int dilations_cols = GetTensorDim(dilations_, data_format_, 'W');
- *dilations = {dilations_rows, dilations_cols};
+ if (dilations_.size() == 4) {
+ int dilations_rows = GetTensorDim(dilations_, data_format_, 'H');
+ int dilations_cols = GetTensorDim(dilations_, data_format_, 'W');
+ *dilations = {dilations_rows, dilations_cols};
+ } else if (dilations_.size() == 5) {
+ int dilations_planes = GetTensorDim(dilations_, data_format_, '0');
+ int dilations_rows = GetTensorDim(dilations_, data_format_, '1');
+ int dilations_cols = GetTensorDim(dilations_, data_format_, '2');
+ *dilations = {dilations_planes, dilations_rows, dilations_cols};
+ }
}
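// Sketch (illustrative, not part of the patch): for Conv2D the stride and
// dilation attribute vectors have 4 entries (NHWC/NCHW) and for Conv3D 5
// entries (NDHWC/NCDHW); the '0'/'1'/'2' lookups above resolve to the
// spatial dimensions. Hard-coded channels-last illustration:
#include <cassert>
#include <vector>

std::vector<int> SpatialStrides(const std::vector<int>& strides) {
  if (strides.size() == 4) return {strides[1], strides[2]};  // NHWC: H, W
  assert(strides.size() == 5);
  return {strides[1], strides[2], strides[3]};               // NDHWC: D, H, W
}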
// Calculate Convolution input size in MKL-DNN order. MKL-DNN
- // requires input in NCHW format. Function does not return anything.
+ // requires input in NCHW/NCDHW format. Function does not return anything.
// But errors arising from sanity checks are returned in context's
// status.
virtual inline void GetInputSizeInMklOrder(const TensorShape& input_shape,
@@ -113,40 +127,62 @@ class MklDnnConvUtil {
int64 input_depth_raw = GetTensorDim(input_shape, data_format_, 'C');
int input_depth = static_cast<int>(input_depth_raw);
- // Input rows/height
- int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H');
- CHECK_BOUNDS(input_rows_raw, "Input rows too large");
- int input_rows = static_cast<int>(input_rows_raw);
-
- // Input columns/width
- int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W');
- CHECK_BOUNDS(input_cols_raw, "Input cols too large");
- int input_cols = static_cast<int>(input_cols_raw);
-
// Input batch
int64 input_batch_raw = GetTensorDim(input_shape, data_format_, 'N');
CHECK_BOUNDS(input_batch_raw, "Input batch too large");
int input_batch = static_cast<int>(input_batch_raw);
+ if (strides_.size() == 4) { // NCHW format for Conv2D
+ // Input rows/height
+ int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H');
+ CHECK_BOUNDS(input_rows_raw, "Input rows too large");
+ int input_rows = static_cast<int>(input_rows_raw);
+
+ // Input columns/width
+ int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W');
+ CHECK_BOUNDS(input_cols_raw, "Input cols too large");
+ int input_cols = static_cast<int>(input_cols_raw);
+
+ // MKL-DNN always requires input in NCHW format for Conv2D.
+ std::vector<int> mkldnn_sizes(4, -1);
+ mkldnn_sizes[MklDnnDims::Dim_N] = input_batch;
+ mkldnn_sizes[MklDnnDims::Dim_C] = input_depth;
+ mkldnn_sizes[MklDnnDims::Dim_H] = input_rows;
+ mkldnn_sizes[MklDnnDims::Dim_W] = input_cols;
+
+ *input_dims = mkldnn_sizes;
+ } else if (strides_.size() == 5) { // NCDHW format for Conv3D
+ // Input planes/third-dimension
+ int64 input_planes_raw = GetTensorDim(input_shape, data_format_, '0');
+ CHECK_BOUNDS(input_planes_raw, "Input depth too large");
+ int input_planes = static_cast<int>(input_planes_raw);
+
+ // Input rows/height
+ int64 input_rows_raw = GetTensorDim(input_shape, data_format_, '1');
+ CHECK_BOUNDS(input_rows_raw, "Input rows too large");
+ int input_rows = static_cast<int>(input_rows_raw);
+
+ // Input columns/width
+ int64 input_cols_raw = GetTensorDim(input_shape, data_format_, '2');
+ CHECK_BOUNDS(input_cols_raw, "Input cols too large");
+ int input_cols = static_cast<int>(input_cols_raw);
+
+ // MKL-DNN always requires input in NCDHW format for Conv3D.
+ std::vector<int> mkldnn_sizes(5, -1);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_batch;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_depth;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_planes;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_H] = input_rows;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_W] = input_cols;
+
+ *input_dims = mkldnn_sizes;
+ }
#undef CHECK_BOUNDS
-
- // MKL-DNN always requires input in NCHW format.
- std::vector<int> mkldnn_sizes(4, -1);
- mkldnn_sizes[MklDnnDims::Dim_N] = input_batch;
- mkldnn_sizes[MklDnnDims::Dim_C] = input_depth;
- mkldnn_sizes[MklDnnDims::Dim_H] = input_rows;
- mkldnn_sizes[MklDnnDims::Dim_W] = input_cols;
-
- *input_dims = mkldnn_sizes;
}
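// Sketch (illustrative, not part of the patch): the Dim3d_* indexing above
// amounts to a fixed permutation from TF's channels-last layout into the
// NCDHW order MKL-DNN expects (assuming Dim3d_{N,C,D,H,W} = 0..4):
#include <vector>

std::vector<int> NdhwcToNcdhw(const std::vector<int>& in) {  // {N,D,H,W,C}
  return {in[0], in[4], in[1], in[2], in[3]};                // {N,C,D,H,W}
}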
- // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
- // requires filter in OIHW format. Function does not return anything.
- // But errors arising from sanity checks are returned in context's
- // status.
- //
- // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
- // requires filter in OIHW format. Function does not return anything.
+ // Calculate Convolution filter size in MKL-DNN order.
+ // MKL-DNN requires filter in OIHW (Conv2D) or OIDHW (Conv3D) format.
+ // Function does not return anything.
// But errors arising from sanity checks are returned in context's
// status. This function differs from GetConvFilterSizeInMklOrder in
// parameter for input - it accepts src_shape since Convolution Backward
@@ -159,11 +195,13 @@ class MklDnnConvUtil {
memory::dims* filter_dims) {
CHECK_NOTNULL(filter_dims);
- OP_REQUIRES(context_, filter_shape.dims() == 4,
- errors::InvalidArgument("filter must be 4-dimensional: ",
+ OP_REQUIRES(context_, filter_shape.dims() == strides_.size(),
+ errors::InvalidArgument((strides_.size() == 4)
+ ? "filter must be 4-dimensional: "
+ : "filter must be 5-dimensional: ",
filter_shape.DebugString()));
- for (int i = 0; i < 3; i++) {
+ for (int i = 0; i < ((strides_.size() == 4) ? 3 : 5); i++) {
OP_REQUIRES(context_,
FastBoundsCheck(filter_shape.dim_size(i),
std::numeric_limits<int>::max()),
@@ -172,32 +210,57 @@ class MklDnnConvUtil {
int input_depth = GetTensorDim(input_shape, data_format_, 'C');
- OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2),
- errors::InvalidArgument(
- "input and filter must have the same depth: ", input_depth,
- " vs ", filter_shape.dim_size(2)));
-
- // TF filter is always in (rows, cols, in_depth, out_depth) order.
- int filter_rows = static_cast<int>(filter_shape.dim_size(0));
- int filter_cols = static_cast<int>(filter_shape.dim_size(1));
- int in_depth = static_cast<int>(filter_shape.dim_size(2));
- int out_depth = static_cast<int>(filter_shape.dim_size(3));
-
- // MKL-DNN always needs filter in OIHW format.
- // OIHW = (out_depth, in_depth, rows, cols)
- std::vector<int> mkldnn_sizes(4, -1);
- mkldnn_sizes[MklDnnDims::Dim_O] = out_depth;
- mkldnn_sizes[MklDnnDims::Dim_I] = in_depth;
- mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows;
- mkldnn_sizes[MklDnnDims::Dim_W] = filter_cols;
-
- *filter_dims = mkldnn_sizes;
+ if (strides_.size() == 4) { // Conv2D
+ OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2),
+ errors::InvalidArgument(
+ "input and filter must have the same depth: ",
+ input_depth, " vs ", filter_shape.dim_size(2)));
+
+ // TF filter is always in (rows, cols, in_depth, out_depth) order.
+ int filter_rows = static_cast<int>(filter_shape.dim_size(0));
+ int filter_cols = static_cast<int>(filter_shape.dim_size(1));
+ int in_depth = static_cast<int>(filter_shape.dim_size(2));
+ int out_depth = static_cast<int>(filter_shape.dim_size(3));
+
+ // MKL-DNN always needs filter in OIHW format.
+ // OIHW = (out_depth, in_depth, rows, cols)
+ std::vector<int> mkldnn_sizes(4, -1);
+ mkldnn_sizes[MklDnnDims::Dim_O] = out_depth;
+ mkldnn_sizes[MklDnnDims::Dim_I] = in_depth;
+ mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows;
+ mkldnn_sizes[MklDnnDims::Dim_W] = filter_cols;
+
+ *filter_dims = mkldnn_sizes;
+ } else { // Conv3D
+ OP_REQUIRES(context_, input_depth == filter_shape.dim_size(3),
+ errors::InvalidArgument(
+ "input and filter must have the same depth: ",
+ input_depth, " vs ", filter_shape.dim_size(3)));
+
+ // TF filter is always in (planes, rows, cols, in_depth, out_depth) order.
+ int filter_planes = static_cast<int>(filter_shape.dim_size(0));
+ int filter_rows = static_cast<int>(filter_shape.dim_size(1));
+ int filter_cols = static_cast<int>(filter_shape.dim_size(2));
+ int in_depth = static_cast<int>(filter_shape.dim_size(3));
+ int out_depth = static_cast<int>(filter_shape.dim_size(4));
+
+ // MKL-DNN always needs filter in OIDHW format.
+ // OIDHW = (out_depth, in_depth, planes, rows, cols)
+ std::vector<int> mkldnn_sizes(5, -1);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_O] = out_depth;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_I] = in_depth;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_D] = filter_planes;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_H] = filter_rows;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_W] = filter_cols;
+
+ *filter_dims = mkldnn_sizes;
+ }
}
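// Sketch (illustrative, not part of the patch): the Conv3D filter
// permutation above in isolation, from TF's (planes, rows, cols, in, out)
// layout to MKL-DNN's OIDHW:
#include <vector>

std::vector<int> DhwioToOidhw(const std::vector<int>& f) {  // {D,H,W,I,O}
  return {f[4], f[3], f[0], f[1], f[2]};                    // {O,I,D,H,W}
}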
- // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
- // requires filter in OIHW format. Function does not return anything.
- // But errors arising from sanity checks are returned in context's
- // status.
+ // Calculate Convolution filter size in MKL-DNN order.
+ // MKL-DNN requires filter in OIHW (Conv2D) or OIDHW (Conv3D) format.
+ // Function does not return anything. But errors arising from sanity
+ // checks are returned in context's status.
virtual inline void GetFilterSizeInMklOrder(size_t src_index,
size_t filter_index,
memory::dims* filter_dims) {
@@ -206,8 +269,8 @@ class MklDnnConvUtil {
GetTfShape(context_, filter_index), filter_dims);
}
- // Calculate Bias size for 2D Convolution. Function does not return
- // anything, but sets error in context status.
+ // Calculate Bias size for 2D or 3D Convolution. Function does not
+ // return anything, but may set an error in context status.
virtual inline void GetBiasSizeInMklOrder(size_t bias_index,
memory::dims* bias_dims) {
const Tensor& bias = MklGetInput(context_, bias_index);
@@ -218,73 +281,142 @@ class MklDnnConvUtil {
*bias_dims = {static_cast<int>(bias.dim_size(0))};
}
- // Function to calculate output and padding size for 2D convolution.
+ // Function to calculate output and padding size for 2D/3D convolution.
//
// Calculate output shape of Convolution in MKL-DNN and TensorFlow order.
- // MKL-DNN uses NCHW for output order. But TensorFlow output will be in
- // NHWC or NCHW format depending on data format. Function also calculates
- // left, right, top and bottom pads. Function does not return any status -
- // status is returned via context status.
+ // MKL-DNN uses NCHW (Conv2D) or NCDHW (Conv3D) for output order.
+ // But TensorFlow output will be in NHWC or NCHW (Conv2D), or
+ // NDHWC or NCDHW (Conv3D), depending on data format.
+ // Function also calculates left, right, top and bottom pads.
+ // Function does not return any status; errors are set in context's status.
//
// TODO(nhasabni): Add similar function for input and filter in MklShape.
virtual inline void GetOutputAndPadSizeInMklOrder(
const TensorShape& input_shape, const TensorShape& filter_shape,
const memory::dims& strides, const memory::dims& dilations,
- memory::dims* output_dims_tf_order,
- memory::dims* output_dims_mkl_order, memory::dims* pad_l,
- memory::dims* pad_r) {
+ memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
+ memory::dims* pad_l, memory::dims* pad_r) {
CHECK_NOTNULL(output_dims_tf_order);
CHECK_NOTNULL(output_dims_mkl_order);
CHECK_NOTNULL(pad_l);
CHECK_NOTNULL(pad_r);
- int input_rows = GetTensorDim(input_shape, data_format_, 'H');
- int input_cols = GetTensorDim(input_shape, data_format_, 'W');
+ bool isConv2D = (strides_.size() == 4);
+ int input_planes, input_rows, input_cols;
+ if (isConv2D) {
+ input_rows = GetTensorDim(input_shape, data_format_, 'H');
+ input_cols = GetTensorDim(input_shape, data_format_, 'W');
+ } else {
+ input_planes = GetTensorDim(input_shape, data_format_, '0');
+ input_rows = GetTensorDim(input_shape, data_format_, '1');
+ input_cols = GetTensorDim(input_shape, data_format_, '2');
+ }
- // The first dimension for filter is rows/height.
- int filter_rows = filter_shape.dim_size(0);
- // The second dimension for filter is cols/width.
- int filter_cols = filter_shape.dim_size(1);
+ // Filter dimension
+ // Conv2D:
+ // First dimension: rows/height.
+ // Second dimension: cols/width.
+ // Conv3D:
+ // First dimension: planes/depth.
+ // Second dimension: rows/height.
+ // Third dimension: cols/width.
+
+ int filter_planes, filter_rows, filter_cols;
+ if (isConv2D) {
+ filter_rows = filter_shape.dim_size(0);
+ filter_cols = filter_shape.dim_size(1);
+ } else {
+ filter_planes = filter_shape.dim_size(0);
+ filter_rows = filter_shape.dim_size(1);
+ filter_cols = filter_shape.dim_size(2);
+ }
- // Stride is vector of 2 elements: {s_r, s_c}
- int stride_rows = strides[0];
- int stride_cols = strides[1];
- int dilation_rows = dilations[0];
- int dilation_cols = dilations[1];
+ int stride_planes, stride_rows, stride_cols;
+ int dilation_planes, dilation_rows, dilation_cols;
+ if (isConv2D) {
+ // Conv2D stride is a vector of 2 elements: {s_r, s_c}
+ stride_rows = strides[0];
+ stride_cols = strides[1];
+ dilation_rows = dilations[0];
+ dilation_cols = dilations[1];
+ } else {
+ // Conv3D stride is a vector of 3 elements: {s_d, s_r, s_c}
+ stride_planes = strides[0];
+ stride_rows = strides[1];
+ stride_cols = strides[2];
+ dilation_planes = dilations[0];
+ dilation_rows = dilations[1];
+ dilation_cols = dilations[2];
+ }
// Output batch is same as input batch.
int out_batch = GetTensorDim(input_shape, data_format_, 'N');
+
// Output depth is same as last dimension for filter.
- int out_depth = filter_shape.dim_size(3);
+ int out_depth = filter_shape.dim_size(isConv2D ? 3 : 4);
- int64 out_rows = 0, out_cols = 0;
+ int64 out_rows = 0, out_cols = 0, out_planes = 0;
int64 pad_top = 0, pad_bottom = 0, pad_left, pad_right;
+ int64 pad_D1, pad_D2;
+
+ if (isConv2D) {
+ OP_REQUIRES_OK(context_,
+ GetWindowedOutputSizeVerboseV2(
+ input_rows, filter_rows, dilation_rows, stride_rows,
+ padding_, &out_rows, &pad_top, &pad_bottom));
+ OP_REQUIRES_OK(context_,
+ GetWindowedOutputSizeVerboseV2(
+ input_cols, filter_cols, dilation_cols, stride_cols,
+ padding_, &out_cols, &pad_left, &pad_right));
+ } else {
+ OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
+ input_planes, filter_planes, stride_planes,
+ padding_, &out_planes, &pad_D1, &pad_D2));
+ OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
+ input_rows, filter_rows, stride_rows,
+ padding_, &out_rows, &pad_top, &pad_bottom));
+ OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
+ input_cols, filter_cols, stride_cols,
+ padding_, &out_cols, &pad_left, &pad_right));
+ }
- OP_REQUIRES_OK(context_,
- GetWindowedOutputSizeVerboseV2(input_rows, filter_rows,
- dilation_rows, stride_rows, padding_,
- &out_rows, &pad_top, &pad_bottom));
- OP_REQUIRES_OK(context_,
- GetWindowedOutputSizeVerboseV2(input_cols, filter_cols,
- dilation_cols, stride_cols, padding_,
- &out_cols, &pad_left, &pad_right));
-
- // Tensorflow output is in data_format order. (NHWC or NCHW)
+ // Tensorflow output is in data_format order.
+ // Conv2D: NHWC or NCHW
+ // Conv3D: NDHWC or NCDHW
+ // MKL-DNN uses asymmetric padding.
TensorShape out_shape =
- ShapeFromFormat(data_format_, out_batch, out_rows, out_cols, out_depth);
+ isConv2D
+ ? ShapeFromFormat(data_format_, out_batch, out_rows, out_cols,
+ out_depth)
+ : ShapeFromFormat(data_format_, out_batch,
+ {{out_planes, out_rows, out_cols}}, out_depth);
*output_dims_tf_order = TFShapeToMklDnnDims(out_shape);
- // MKL-DNN always needs output in NCHW format.
- std::vector<int> mkldnn_sizes(4, -1);
- mkldnn_sizes[MklDnnDims::Dim_N] = out_batch;
- mkldnn_sizes[MklDnnDims::Dim_C] = out_depth;
- mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows);
- mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols);
- *output_dims_mkl_order = mkldnn_sizes;
-
- // Now handle padding. MKL-DNN uses asymetric padding.
- *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
- *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
+ if (isConv2D) {
+ // For Conv2D, MKL-DNN always needs output in NCHW format.
+ std::vector<int> mkldnn_sizes(4, -1);
+ mkldnn_sizes[MklDnnDims::Dim_N] = out_batch;
+ mkldnn_sizes[MklDnnDims::Dim_C] = out_depth;
+ mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows);
+ mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols);
+ *output_dims_mkl_order = mkldnn_sizes;
+
+ *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
+ *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
+ } else {
+ std::vector<int> mkldnn_sizes(5, -1);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_N] = out_batch;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_C] = out_depth;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_D] = static_cast<int>(out_planes);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_H] = static_cast<int>(out_rows);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_W] = static_cast<int>(out_cols);
+ *output_dims_mkl_order = mkldnn_sizes;
+
+ *pad_l = {static_cast<int>(pad_D1), static_cast<int>(pad_top),
+ static_cast<int>(pad_left)};
+ *pad_r = {static_cast<int>(pad_D2), static_cast<int>(pad_bottom),
+ static_cast<int>(pad_right)};
+ }
}
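// Sketch (illustrative, not part of the patch): the per-dimension arithmetic
// behind the GetWindowedOutputSizeVerbose calls above, without dilation, as
// documented for TF's padding modes:
//   VALID: out = ceil((in - filter + 1) / stride), no padding
//   SAME:  out = ceil(in / stride),
//          pad_total = max(0, (out - 1) * stride + filter - in)
#include <algorithm>
#include <cstdint>

void WindowedOutputSize(int64_t in, int64_t filter, int64_t stride, bool same,
                        int64_t* out, int64_t* pad_lo, int64_t* pad_hi) {
  if (same) {
    *out = (in + stride - 1) / stride;
    int64_t pad = std::max<int64_t>(0, (*out - 1) * stride + filter - in);
    *pad_lo = pad / 2;        // TF places the extra pad at the end
    *pad_hi = pad - *pad_lo;
  } else {  // VALID
    *out = (in - filter + stride) / stride;  // ceil((in - filter + 1) / stride)
    *pad_lo = *pad_hi = 0;
  }
}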
// Calculate output and pad size of forward Convolution operator.
@@ -292,10 +424,10 @@ class MklDnnConvUtil {
//
// Function does not return anything, but sets error in context status.
inline void GetOutputAndPadSizeInMklOrder(
- size_t src_index, size_t filter_index,
- const memory::dims& strides, const memory::dims& dilations,
- memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
- memory::dims* pad_l, memory::dims* pad_r) {
+ size_t src_index, size_t filter_index, const memory::dims& strides,
+ const memory::dims& dilations, memory::dims* output_dims_tf_order,
+ memory::dims* output_dims_mkl_order, memory::dims* pad_l,
+ memory::dims* pad_r) {
CHECK_NOTNULL(output_dims_tf_order);
CHECK_NOTNULL(output_dims_mkl_order);
CHECK_NOTNULL(pad_l);
@@ -304,9 +436,17 @@ class MklDnnConvUtil {
auto input_tf_shape = GetTfShape(context_, src_index);
auto filter_tf_shape = GetTfShape(context_, filter_index);
- OP_REQUIRES(context_, input_tf_shape.dims() == 4,
- errors::InvalidArgument("input must be 4-dimensional",
- input_tf_shape.DebugString()));
+ if (strides_.size() == 4) {
+ // Conv2D
+ OP_REQUIRES(context_, input_tf_shape.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input_tf_shape.DebugString()));
+ } else {
+ // Conv3D
+ OP_REQUIRES(context_, input_tf_shape.dims() == 5,
+ errors::InvalidArgument("input must be 5-dimensional",
+ input_tf_shape.DebugString()));
+ }
GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape,
strides, dilations, output_dims_tf_order,
@@ -314,9 +454,11 @@ class MklDnnConvUtil {
}
// Wrapper function to calculate input, filter, and output sizes of
- // 2D Convolution in MKL order (NCHW for input and output; OIHW for filter.)
- // Function also calculates output shape in Tensorflow order. Additionally, it
- // also calculates strides and paddings for 2D Convolution.
+ // Conv2D/Conv3D in MKL order:
+ // Conv2D: NCHW for input and output; OIHW for filter.
+ // Conv3D: NCDHW for input and output; OIDHW for filter.
+ // Function also calculates output shape in Tensorflow order.
+ // Additionally, it also calculates strides and paddings.
//
// Function does not return anything, but sets error in context status.
inline void GetConvFwdSizesInMklOrder(
@@ -349,16 +491,15 @@ class MklDnnConvUtil {
}
};
-
/////////////////////////////////////////////////////////////////////
-/// Common class that implements Conv2DBackpropFilter and Input
+/// Common class that implements ConvBackpropFilter and Input
/////////////////////////////////////////////////////////////////////
template <typename Device, class T>
-class MklConv2DBackpropCommonOp : public OpKernel {
+class MklConvBackpropCommonOp : public OpKernel {
public:
- ~MklConv2DBackpropCommonOp() {}
- explicit MklConv2DBackpropCommonOp(OpKernelConstruction* context)
+ ~MklConvBackpropCommonOp() {}
+ explicit MklConvBackpropCommonOp(OpKernelConstruction* context)
: OpKernel(context) {
string data_format_str;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
@@ -372,20 +513,25 @@ class MklConv2DBackpropCommonOp : public OpKernel {
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
- OP_REQUIRES(context, dilations_.size() == 4,
- errors::InvalidArgument("Sliding window dilations field must "
- "specify 4 dimensions"));
- int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
- int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
- int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
- int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
- OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1),
- errors::InvalidArgument(
- "Current implementation does not yet support "
- "dilations in the batch and depth dimensions."));
- OP_REQUIRES(
- context, dilation_h > 0 && dilation_w > 0,
- errors::InvalidArgument("Dilated rates should be larger than 0."));
+
+ if (strides_.size() == 4) {
+ // Check Conv2D dilations
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
+ int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
+ int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
+ int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
+ OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ OP_REQUIRES(
+ context, dilation_h > 0 && dilation_w > 0,
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
+ }
+
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
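// Sketch (illustrative, not part of the patch): why dilation rates must be
// positive. The V2 output-size helpers substitute an effective window extent
// for the raw filter size:
#include <cstdint>

int64_t EffectiveFilterSize(int64_t filter, int64_t dilation) {
  return (filter - 1) * dilation + 1;  // dilation == 1 gives plain convolution
}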
diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc
index 06ce820ae9..84ee241b8e 100644
--- a/tensorflow/core/kernels/mkl_input_conversion_op.cc
+++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc
@@ -296,7 +296,9 @@ class MklInputConversionOp : public OpKernel {
// implementation.
TensorShape tf_shape0 = input_shape_0.GetTfShape();
TensorShape tf_shape1 = input_shape_1.GetTfShape();
- if (tf_shape0 == tf_shape1) {
+ TensorShape tensor_shape0 = input_tensor_0.shape();
+ TensorShape tensor_shape1 = input_tensor_1.shape();
+ if (tf_shape0 == tf_shape1 && tensor_shape0 == tensor_shape1) {
auto input0_md = input_shape_0.GetMklLayout();
auto input1_md = input_shape_1.GetMklLayout();
@@ -350,7 +352,8 @@ class MklInputConversionOp : public OpKernel {
}
// Sanity check
- bool mkl_shapes_are_same = input_shape_0 == input_shape_1;
+ bool mkl_shapes_are_same = ((input_shape_0 == input_shape_1) &&
+ (tensor_shape0 == tensor_shape1));
if (mkl_shapes_are_same) {
CHECK(false) << "MklInputConversionOp: Unexpected: TF shapes are "
"different but MKL shapes are same";
@@ -403,7 +406,8 @@ class MklInputConversionOp : public OpKernel {
}
// Broadcast is needed if the shapes are not the same
- if (mkl_shape->GetTfShape().num_elements() == tf_tensor->shape().num_elements() ) {
+ if (mkl_shape->GetTfShape().num_elements() ==
+ tf_tensor->shape().num_elements()) {
// Both shapes are same, convert the TF input to MKL
VLOG(1) << "MklInputConversionOp: No broadcast needed.";
VLOG(1) << "MklInputConversionOp: Converting input " << tf_tensor_index
@@ -437,16 +441,17 @@ class MklInputConversionOp : public OpKernel {
bool reordered = tf_input.CheckReorderToOpMem(
memory::primitive_desc(output_mkl_md, cpu_engine),
tensor_out, &net);
- if(!reordered) {
+
+ if (!reordered) {
// This is the case that the TF tensor has the same shape and format of
// mkl tensor. However, tf_tensor can not be simply forwarded to the
// output tensor since mkl data tensor is always one dimensional tensor.
// Tensor::CopyFrom shares the buffer of the other tensor while set its
// shape to the other tensor.
CHECK(tensor_out->CopyFrom(*tf_tensor, tensor_out->shape()));
- }
- else
+ } else {
stream(stream::kind::eager).submit(net).wait();
+ }
// -- The tensor in MKL format passes through --
ForwardMklTensorInToOut(context, mkl_tensor_index, mkl_tensor_index);
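// Sketch (illustrative, not part of the patch): the broadcast test above
// compares element counts rather than shapes, so e.g. {2, 3} paired with
// {1, 2, 3} counts as "no broadcast needed":
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

bool NeedsBroadcast(const std::vector<int64_t>& a,
                    const std::vector<int64_t>& b) {
  auto count = [](const std::vector<int64_t>& s) {
    return std::accumulate(s.begin(), s.end(), int64_t{1},
                           std::multiplies<int64_t>());
  };
  return count(a) != count(b);
}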
diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc
index e149f003e5..256d48f4d5 100644
--- a/tensorflow/core/kernels/mkl_maxpooling_op.cc
+++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc
@@ -524,6 +524,8 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
// initialize variables for the pooling op
MklPoolParameters pool_params;
+ // check whether pooling is 2D or 3D
+ bool is_pool2d = (this->ksize_.size() == 4);
// Get the input tensor and initialize the pooling parameters
TensorShape input_tensor_shape = input_tensor.shape();
this->InitMklPoolParameters(context, &pool_params, dnn_shape_input,
@@ -547,20 +549,26 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
memory::desc input_md =
dnn_shape_input.IsMklTensor()
? dnn_shape_input.GetMklLayout()
- : memory::desc(TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
- this->data_format_tf_),
- MklDnnType<T>(), this->data_format_mkldnn_);
+ : is_pool2d ? memory::desc(
+ TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
+ this->data_format_tf_),
+ MklDnnType<T>(), this->data_format_mkldnn_)
+ : memory::desc(
+ TFShapeToMklDnnDimsInNCDHW(
+ input_tensor_shape, this->data_format_tf_),
+ MklDnnType<T>(), this->data_format_mkldnn_);
// Get src/filter/stride/padding information
memory::dims src_dims =
dnn_shape_input.IsMklTensor()
? dnn_shape_input.GetSizesAsMklDnnDims()
- : TFShapeToMklDnnDimsInNCHW(input_tensor.shape(),
- this->data_format_tf_);
-
+ : is_pool2d ? TFShapeToMklDnnDimsInNCHW(input_tensor.shape(),
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(input_tensor.shape(),
+ this->data_format_tf_);
memory::dims filter_dims, strides, padding_left, padding_right;
this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
- &padding_left, &padding_right);
+ &padding_left, &padding_right, is_pool2d);
// Get a pooling op from the cached pool
MklPoolingFwdPrimitive<T>* pooling_fwd = nullptr;
@@ -663,23 +671,30 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
MklPoolParameters pool_params;
TensorShape orig_input_shape = orig_input_tensor.shape();
+
+ bool is_pool2d = (this->ksize_.size() == 4);
this->InitMklPoolParameters(context, &pool_params, orig_input_mkl_shape,
orig_input_shape);
memory::dims filter_dims, strides, padding_left, padding_right;
this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
- &padding_left, &padding_right);
+ &padding_left, &padding_right, is_pool2d);
- memory::dims diff_dst_dims =
- grad_mkl_shape.IsMklTensor()
- ? grad_mkl_shape.GetSizesAsMklDnnDims()
- : TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(),
- this->data_format_tf_);
memory::dims orig_input_dims_mkl_order =
orig_input_mkl_shape.IsMklTensor()
? orig_input_mkl_shape.GetSizesAsMklDnnDims()
- : TFShapeToMklDnnDimsInNCHW(orig_input_shape,
- this->data_format_tf_);
+ : is_pool2d ? TFShapeToMklDnnDimsInNCHW(orig_input_shape,
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(orig_input_shape,
+ this->data_format_tf_);
+
+ memory::dims diff_dst_dims =
+ grad_mkl_shape.IsMklTensor()
+ ? grad_mkl_shape.GetSizesAsMklDnnDims()
+ : is_pool2d ? TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(),
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(grad_tensor.shape(),
+ this->data_format_tf_);
memory::dims output_dims_mkl_order;
this->GetOutputDims(pool_params, &output_dims_mkl_order);
@@ -715,7 +730,7 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
void* ws_data = static_cast<void*>(
const_cast<uint8*>(workspace_tensor.flat<uint8>().data()));
- ;
+
auto ws_md =
pooling_bwd->GetPoolingFwdPd()->workspace_primitive_desc().desc();
if (ws_md.data.format != pooling_bwd->GetWorkspaceFormat()) {
@@ -817,6 +832,18 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
}
}; // MklMaxPoolingGradOp
+REGISTER_KERNEL_BUILDER(Name("_MklMaxPool3D")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T")
+ .Label(mkl_op_registry::kMklOpLabel),
+ MklMaxPoolingOp<CPUDevice, float>);
+
+REGISTER_KERNEL_BUILDER(Name("_MklMaxPool3DGrad")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T")
+ .Label(mkl_op_registry::kMklOpLabel),
+ MklMaxPoolingGradOp<CPUDevice, float>);
+
#endif // INTEL_MKL_ML_ONLY
REGISTER_KERNEL_BUILDER(Name("_MklMaxPool")
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
index d7ad3f9dcd..5398e6113f 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.cc
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
@@ -24,7 +24,7 @@ limitations under the License.
namespace tensorflow {
-#ifndef INTEL_MKL_ML
+#ifndef INTEL_MKL_ML_ONLY
using mkldnn::pooling_avg;
using mkldnn::pooling_avg_exclude_padding;
@@ -34,11 +34,11 @@ using mkldnn::prop_kind;
template <typename T>
void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
- if (fwdParams.alg_kind != pooling_max && fwdParams.alg_kind != pooling_avg &&
- fwdParams.alg_kind != pooling_avg_include_padding &&
- fwdParams.alg_kind != pooling_avg_exclude_padding) {
- assert("Pooling algorithm kind is not supported\n");
- }
+ DCHECK(fwdParams.alg_kind == pooling_max ||
+ fwdParams.alg_kind == pooling_avg ||
+ fwdParams.alg_kind == pooling_avg_include_padding ||
+ fwdParams.alg_kind == pooling_avg_exclude_padding)
+ << "Pooling algorithm kind is not supported";
context_.alg_kind = fwdParams.alg_kind;
// create memory desc
@@ -46,9 +46,10 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
// so src format is currently hard-coded.
// A utility function is used to do this,
// which may be broken with future CPU architectures
+ bool is_2d = (fwdParams.src_dims.size() == 4);
context_.src_md.reset(
new memory::desc({fwdParams.src_dims}, MklDnnType<T>(),
- get_desired_format(fwdParams.src_dims[1])));
+ get_desired_format(fwdParams.src_dims[1], is_2d)));
context_.dst_md.reset(new memory::desc({fwdParams.dst_dims}, MklDnnType<T>(),
memory::format::any));
@@ -61,7 +62,7 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
new pooling_forward::primitive_desc(*context_.fwd_desc, cpu_engine_));
// store expected primitive format
- context_.src_fmt = get_desired_format(fwdParams.src_dims[1]);
+ context_.src_fmt = get_desired_format(fwdParams.src_dims[1], is_2d);
context_.dst_fmt = static_cast<mkldnn::memory::format>(
context_.fwd_pd.get()->dst_primitive_desc().desc().data.format);
@@ -101,7 +102,7 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
if (context_.alg_kind == pooling_max) { // max pooling must have ws
- assert(ws_data != nullptr);
+ DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(ws_data);
}
context_.fwd_stream->submit(context_.fwd_primitives);
@@ -110,7 +111,7 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
context_.src_mem->set_data_handle(DummyData);
context_.dst_mem->set_data_handle(DummyData);
if (context_.alg_kind == pooling_max) { // max pooling must have ws
- assert(ws_data != nullptr);
+ DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(DummyData);
}
}
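// Sketch (illustrative, not part of the patch): the Execute() pattern above.
// A cached primitive keeps memory objects bound to placeholder data; each
// call points them at the real buffers, submits the stream, then rebinds the
// placeholder so no stale user pointer outlives the call. Hypothetical
// reduction:
struct Mem {
  void* handle;
  void set(void* p) { handle = p; }
};

void RunCached(Mem& src, Mem& dst, void* src_buf, void* dst_buf,
               void (*submit)(), void* dummy) {
  src.set(src_buf);
  dst.set(dst_buf);
  submit();        // stand-in for fwd_stream->submit(fwd_primitives)
  src.set(dummy);  // rebind placeholders after execution
  dst.set(dummy);
}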
@@ -119,19 +120,21 @@ template class MklPoolingFwdPrimitive<float>;
template <typename T>
void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
- if (bwdParams.alg_kind != pooling_max && bwdParams.alg_kind != pooling_avg &&
- bwdParams.alg_kind != pooling_avg_include_padding &&
- bwdParams.alg_kind != pooling_avg_exclude_padding) {
- assert("Pooling algorithm kind is not supported\n");
- }
+ DCHECK(bwdParams.alg_kind == pooling_max ||
+ bwdParams.alg_kind == pooling_avg ||
+ bwdParams.alg_kind == pooling_avg_include_padding ||
+ bwdParams.alg_kind == pooling_avg_exclude_padding)
+ << "Pooling algorithm kind is not supported";
context_.alg_kind = bwdParams.alg_kind;
+ // Check whether it is 2D or 3D.
+ bool is_2d = (bwdParams.dst_dims.size() == 4);
// Create memory desc
context_.diff_src_md.reset(new memory::desc(
{bwdParams.src_dims}, MklDnnType<T>(), memory::format::any));
context_.diff_dst_md.reset(
new memory::desc({bwdParams.dst_dims}, MklDnnType<T>(),
- get_desired_format(bwdParams.dst_dims[1])));
+ get_desired_format(bwdParams.dst_dims[1], is_2d)));
context_.bwd_desc.reset(new pooling_backward::desc(
bwdParams.alg_kind, *context_.diff_src_md, *context_.diff_dst_md,
bwdParams.strides, bwdParams.filter_dims, bwdParams.padding_left,
@@ -151,7 +154,7 @@ void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
// store expected primitive format
context_.diff_src_fmt = static_cast<mkldnn::memory::format>(
context_.bwd_pd.get()->diff_src_primitive_desc().desc().data.format);
- context_.diff_dst_fmt = get_desired_format(bwdParams.dst_dims[1]);
+ context_.diff_dst_fmt = get_desired_format(bwdParams.dst_dims[1], is_2d);
// create MKL-DNN internal memory object with dummy data
context_.diff_src_mem.reset(
@@ -165,7 +168,7 @@ void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
if (bwdParams.alg_kind == pooling_max) {
auto ws_pd = context_.fwd_pd.get()->workspace_primitive_desc().desc().data;
context_.ws_dims.assign(ws_pd.dims, ws_pd.dims + ws_pd.ndims);
- context_.ws_fmt = get_desired_format(context_.ws_dims[1]);
+ context_.ws_fmt = get_desired_format(context_.ws_dims[1], is_2d);
context_.ws_dt = static_cast<mkldnn::memory::data_type>(ws_pd.data_type);
context_.ws_mem.reset(new memory(
{{{context_.ws_dims}, context_.ws_dt, context_.ws_fmt}, cpu_engine},
@@ -187,7 +190,7 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
static_cast<void*>(const_cast<T*>(diff_dst_data)));
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
if (context_.alg_kind == pooling_max) {
- assert(ws_data != nullptr);
+ DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(const_cast<void*>(ws_data));
}
@@ -196,7 +199,7 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
context_.diff_dst_mem->set_data_handle(DummyData);
context_.diff_src_mem->set_data_handle(DummyData);
if (context_.alg_kind == pooling_max) {
- assert(ws_data != nullptr);
+ DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(DummyData);
}
}
@@ -211,13 +214,22 @@ void MklPoolParameters::Init(OpKernelContext* context,
const std::vector<int32>& stride, Padding padding,
TensorFormat data_format,
const TensorShape& tensor_in_shape) {
- // For maxpooling, tensor_in should have 4 dimensions.
- OP_REQUIRES(context, tensor_in_shape.dims() == 4,
- errors::InvalidArgument("tensor_in must be 4-dimensional"));
+ // For maxpooling, tensor_in should have 4 or 5 dimensions.
+ OP_REQUIRES(context,
+ tensor_in_shape.dims() == 4 || tensor_in_shape.dims() == 5,
+ errors::InvalidArgument("tensor_in must be 4 or 5-dimensional"));
depth = GetTensorDim(tensor_in_shape, data_format, 'C');
- tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, 'W');
- tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, 'H');
+ if (tensor_in_shape.dims() == 4) {
+ // Pool2D
+ tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, 'W');
+ tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, 'H');
+ } else {
+ // Pool3D
+ tensor_in_planes = GetTensorDim(tensor_in_shape, data_format, '0');
+ tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, '1');
+ tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, '2');
+ }
tensor_in_batch = GetTensorDim(tensor_in_shape, data_format, 'N');
Init(context, ksize, stride, padding, data_format);
@@ -246,10 +258,20 @@ void MklPoolParameters::Init(OpKernelContext* context,
TensorFormat data_format,
const MklDnnShape* mklInputShape) {
// Get the input sizes
- depth = mklInputShape->GetDimension('C');
- tensor_in_cols = mklInputShape->GetDimension('W');
- tensor_in_rows = mklInputShape->GetDimension('H');
- tensor_in_batch = mklInputShape->GetDimension('N');
+ if (ksize.size() == 4) {
+ // Pool2D
+ depth = mklInputShape->GetDimension('C');
+ tensor_in_cols = mklInputShape->GetDimension('W');
+ tensor_in_rows = mklInputShape->GetDimension('H');
+ tensor_in_batch = mklInputShape->GetDimension('N');
+ } else {
+ // Pool3D
+ depth = mklInputShape->GetDimension3D('C');
+ tensor_in_cols = mklInputShape->GetDimension3D('W');
+ tensor_in_rows = mklInputShape->GetDimension3D('H');
+ tensor_in_planes = mklInputShape->GetDimension3D('D');
+ tensor_in_batch = mklInputShape->GetDimension3D('N');
+ }
Init(context, ksize, stride, padding, data_format);
}
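// Sketch (illustrative, not part of the patch): for an NDHWC MaxPool3D the
// ksize attribute is {1, k_d, k_h, k_w, 1}, and the '0'/'1'/'2' lookups
// above pick out the three spatial window sizes:
#include <array>

struct Window3D { int planes, rows, cols; };

Window3D WindowFromNdhwcKsize(const std::array<int, 5>& ksize) {
  return {ksize[1], ksize[2], ksize[3]};  // D, H, W
}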
@@ -262,25 +284,58 @@ void MklPoolParameters::Init(OpKernelContext* context,
// Get the data format
this->data_format = data_format;
- // Get the output sizes
- window_rows = GetTensorDim(ksize, data_format, 'H');
- window_cols = GetTensorDim(ksize, data_format, 'W');
- depth_window = GetTensorDim(ksize, data_format, 'C');
-
- // Get the strides
- row_stride = GetTensorDim(stride, data_format, 'H');
- col_stride = GetTensorDim(stride, data_format, 'W');
- depth_stride = GetTensorDim(stride, data_format, 'C');
+ bool is_pool2d = (ksize.size() == 4);
+ if (is_pool2d) {
+ // Pool2D
+ // Get the output sizes
+ window_rows = GetTensorDim(ksize, data_format, 'H');
+ window_cols = GetTensorDim(ksize, data_format, 'W');
+ depth_window = GetTensorDim(ksize, data_format, 'C');
+
+ // Get the strides
+ row_stride = GetTensorDim(stride, data_format, 'H');
+ col_stride = GetTensorDim(stride, data_format, 'W');
+ depth_stride = GetTensorDim(stride, data_format, 'C');
+
+ // We only support 2D pooling across width/height and depthwise
+ // pooling, not a combination.
+ OP_REQUIRES(context,
+ (depth_window == 1 || (window_rows == 1 && window_cols == 1)),
+ errors::Unimplemented(
+ "MaxPooling supports exactly one of pooling across depth "
+ "or pooling across width/height."));
+ } else {
+ // Pool3D
+ // Get the output sizes
+ window_planes = GetTensorDim(ksize, data_format, '0');
+ window_rows = GetTensorDim(ksize, data_format, '1');
+ window_cols = GetTensorDim(ksize, data_format, '2');
+ depth_window = GetTensorDim(ksize, data_format, 'C');
+
+ // Get the strides
+ planes_stride = GetTensorDim(stride, data_format, '0');
+ row_stride = GetTensorDim(stride, data_format, '1');
+ col_stride = GetTensorDim(stride, data_format, '2');
+ depth_stride = GetTensorDim(stride, data_format, 'C');
+
+ // We only support 3D pooling across depth/width/height and depthwise
+ // pooling, not a combination.
+ OP_REQUIRES(context,
+ (depth_window == 1 ||
+ (window_rows == 1 && window_cols == 1 && window_planes == 1)),
+ errors::Unimplemented(
+ "AvgPooling3D supports exactly one of pooling across depth "
+ "or pooling across depth/width/height."));
+ }
- // We only support 2D pooling across width/height and depthwise
- // pooling, not a combination.
- OP_REQUIRES(context,
- (depth_window == 1 || (window_rows == 1 && window_cols == 1)),
- errors::Unimplemented(
- "MaxPooling supports exactly one of pooling across depth "
- "or pooling across width/height."));
+ if (depth_window == 1) { // we are pooling in D (Pool3D only), H and W
+ if (!is_pool2d) {
+ OP_REQUIRES_OK(
+ context, GetWindowedOutputSizeVerbose(tensor_in_planes, window_planes,
+ planes_stride, padding,
+ &out_planes, &pad_P1, &pad_P2));
+ }
- if (depth_window == 1) { // we are pooling in the H and W
OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
tensor_in_rows, window_rows, row_stride,
padding, &out_height, &pad_top, &pad_bottom));
@@ -290,7 +345,14 @@ void MklPoolParameters::Init(OpKernelContext* context,
padding, &out_width, &pad_left, &pad_right));
#ifndef INTEL_MKL_ML_ONLY
// TF can work with int64, but mkldnn only supports int32
- // Fail if the height or width are greater than MAX_INT
+ // Fail if the depth, height or width are greater than MAX_INT
+ // We check depth only for the 3D pooling case.
+
+ if (!is_pool2d) {
+ OP_REQUIRES(context,
+ FastBoundsCheck(out_planes, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("output depth/planes is too large"));
+ }
OP_REQUIRES(context,
FastBoundsCheck(out_height, std::numeric_limits<int>::max()),
@@ -299,7 +361,6 @@ void MklPoolParameters::Init(OpKernelContext* context,
OP_REQUIRES(context,
FastBoundsCheck(out_width, std::numeric_limits<int>::max()),
errors::InvalidArgument("output width is too large"));
-
#endif
out_depth = depth; // output will have the same depth as the input
} else { // we are pooling in the depth dimension
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h
index ec7af5092d..49f799d7ba 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.h
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h
@@ -19,6 +19,7 @@ limitations under the License.
#ifdef INTEL_MKL
#include <memory>
#include <vector>
+#include <string>
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"
@@ -32,7 +33,7 @@ using mkldnn::stream;
namespace tensorflow {
-#ifndef INTEL_MKL_ML
+#ifndef INTEL_MKL_ML_ONLY
using mkldnn::memory;
using mkldnn::pooling_avg;
@@ -357,22 +358,28 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
struct MklPoolParameters {
int depth;
+ int tensor_in_planes; // Pool3D
int tensor_in_cols;
int tensor_in_rows;
int tensor_in_batch;
+ int window_planes; // Pool3D
int window_rows;
int window_cols;
int depth_window;
+ int planes_stride; // Pool3D
int row_stride;
int col_stride;
int depth_stride;
+ int64 out_planes; // Pool3D
int64 out_height;
int64 out_width;
int out_depth;
+ int64 pad_P1; // Pool3D
+ int64 pad_P2; // Pool3D
int64 pad_left;
int64 pad_right;
int64 pad_top;
@@ -382,18 +389,24 @@ struct MklPoolParameters {
TensorFormat data_format;
MklPoolParameters()
: depth(0),
+ tensor_in_planes(0),
tensor_in_cols(0),
tensor_in_rows(0),
tensor_in_batch(0),
+ window_planes(0),
window_rows(0),
window_cols(0),
depth_window(0),
+ planes_stride(0),
row_stride(0),
col_stride(0),
depth_stride(0),
+ out_planes(0),
out_height(0),
out_width(0),
out_depth(0),
+ pad_P1(0),
+ pad_P2(0),
pad_left(0),
pad_right(0),
pad_top(0),
@@ -433,20 +446,22 @@ class MklPoolingOpBase : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
OP_REQUIRES(context, FormatFromString(data_format, &this->data_format_tf_),
errors::InvalidArgument("Invalid data format"));
- this->data_format_mkldnn_ =
- TFDataFormatToMklDnnDataFormat(this->data_format_tf_);
OP_REQUIRES_OK(context, context->GetAttr("ksize", &this->ksize_));
- OP_REQUIRES(context, this->ksize_.size() == 4,
+ OP_REQUIRES(context, this->ksize_.size() == 4 || this->ksize_.size() == 5,
errors::InvalidArgument("Sliding window ksize field must "
- "specify 4 dimensions"));
+ "specify 4 or 5 dimensions"));
OP_REQUIRES_OK(context, context->GetAttr("strides", &this->stride_));
- OP_REQUIRES(context, this->stride_.size() == 4,
+ OP_REQUIRES(context, this->stride_.size() == 4 || this->stride_.size() == 5,
errors::InvalidArgument("Sliding window strides field must "
- "specify 4 dimensions"));
+ "specify 4 or 5 dimensions"));
OP_REQUIRES_OK(context, context->GetAttr("padding", &this->padding_));
OP_REQUIRES(context, this->ksize_[0] == 1 && this->stride_[0] == 1,
errors::Unimplemented("Pooling is not yet supported on the "
"batch dimension."));
+ bool is_pool2d = (this->ksize_.size() == 4);
+ this->data_format_mkldnn_ =
+ is_pool2d ? TFDataFormatToMklDnnDataFormat(this->data_format_tf_)
+ : TFDataFormatToMklDnn3DDataFormat(this->data_format_tf_);
// We may not get this attribute for this node if it does not go through
// graph rewrite pass. So we do not check for error while retrieving this
@@ -457,17 +472,26 @@ class MklPoolingOpBase : public OpKernel {
protected:
// Calculate output shape of pooling op in MKL-DNN and TensorFlow order.
- // MKL-DNN uses NCHW for output order. But TensorFlow output will be in
- // NHWC or NCHW format depending on data format. Function expects
- // output height and output width to have already been int32
- // bounds-checked
+ // MKL-DNN uses NCHW(Pool2D) or NCDHW(Pool3D) for output order.
+ // But TensorFlow output will be in NHWC/NCHW(Pool2D) or
+ // NDHWC/NCDHW(Pool3D) format depending on data format. Function expects
+ // output height and width to have already been int32 bounds-checked.
void GetOutputDims(const MklPoolParameters& mkl_pool_params,
memory::dims* output_dims_mkl_order) {
- // MKL-DNN always needs output in NCHW format.
- *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
- mkl_pool_params.out_depth,
- static_cast<int>(mkl_pool_params.out_height),
- static_cast<int>(mkl_pool_params.out_width)};
+ if (this->ksize_.size() == 4) {
+ // Pooling2D: MKL-DNN always needs output in NCHW format.
+ *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
+ mkl_pool_params.out_depth,
+ static_cast<int>(mkl_pool_params.out_height),
+ static_cast<int>(mkl_pool_params.out_width)};
+ } else {
+ // Pooling3D: MKL-DNN always needs output in NCDHW format.
+ *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
+ mkl_pool_params.out_depth,
+ static_cast<int>(mkl_pool_params.out_planes),
+ static_cast<int>(mkl_pool_params.out_height),
+ static_cast<int>(mkl_pool_params.out_width)};
+ }
}
void InitMklPoolParameters(OpKernelContext* context,
@@ -485,14 +509,34 @@ class MklPoolingOpBase : public OpKernel {
void PoolParamsToDims(const MklPoolParameters* pool_params,
memory::dims* filter_dims, memory::dims* strides,
- memory::dims* padding_left,
- memory::dims* padding_right) {
- *filter_dims = {pool_params->window_rows, pool_params->window_cols};
- *strides = {pool_params->row_stride, pool_params->col_stride};
- *padding_left = {static_cast<int>(pool_params->pad_top),
- static_cast<int>(pool_params->pad_left)};
- *padding_right = {static_cast<int>(pool_params->pad_bottom),
- static_cast<int>(pool_params->pad_right)};
+ memory::dims* padding_left, memory::dims* padding_right,
+ bool is_pool2d) {
+ if (is_pool2d) {
+ // Pool2D
+ *filter_dims =
+ memory::dims({pool_params->window_rows, pool_params->window_cols});
+ *strides =
+ memory::dims({pool_params->row_stride, pool_params->col_stride});
+ *padding_left = memory::dims({static_cast<int>(pool_params->pad_top),
+ static_cast<int>(pool_params->pad_left)});
+ *padding_right = memory::dims({static_cast<int>(pool_params->pad_bottom),
+ static_cast<int>(pool_params->pad_right)});
+ } else {
+ // Pool3D
+ *filter_dims =
+ memory::dims({pool_params->window_planes, pool_params->window_rows,
+ pool_params->window_cols});
+ *strides =
+ memory::dims({pool_params->planes_stride, pool_params->row_stride,
+ pool_params->col_stride});
+
+ *padding_left = memory::dims({static_cast<int>(pool_params->pad_P1),
+ static_cast<int>(pool_params->pad_top),
+ static_cast<int>(pool_params->pad_left)});
+ *padding_right = memory::dims({static_cast<int>(pool_params->pad_P2),
+ static_cast<int>(pool_params->pad_bottom),
+ static_cast<int>(pool_params->pad_right)});
+ }
}
void AllocateEmptyOutputTensor(OpKernelContext* context,
@@ -556,12 +600,27 @@ class MklPoolingForwardOpBase : public MklPoolingOpBase<T> {
TensorShape input_tensor_shape = input_tensor.shape();
if (input_tensor.NumElements() != 0) {
memory::desc input_md =
- input_mkl_shape.IsMklTensor()
- ? input_mkl_shape.GetMklLayout()
- : memory::desc(TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
+ input_mkl_shape.IsMklTensor()
+ ? input_mkl_shape.GetMklLayout()
+ : memory::desc(
+ (this->ksize_.size() == 4)
+ ? TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(input_tensor_shape,
this->data_format_tf_),
- MklDnnType<T>(), this->data_format_mkldnn_);
+ MklDnnType<T>(), this->data_format_mkldnn_);
dnn_data_input->SetUsrMem(input_md, &input_tensor);
+
+ if (this->ksize_.size() == 5) {
+ // Pool3D
+ std::vector<int> mkldnn_sizes(5, -1);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_md.data.dims[0];
+ mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_md.data.dims[1];
+ mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_md.data.dims[2];
+ mkldnn_sizes[MklDnnDims3D::Dim3d_H] = input_md.data.dims[3];
+ mkldnn_sizes[MklDnnDims3D::Dim3d_W] = input_md.data.dims[4];
+ dnn_data_input->SetOpMemDesc(mkldnn_sizes, this->data_format_mkldnn_);
+ }
}
this->InitMklPoolParameters(context, pool_params, input_mkl_shape,
input_tensor_shape);
@@ -593,12 +652,13 @@ class MklPoolingForwardOpBase : public MklPoolingOpBase<T> {
void SanityCheckInput(OpKernelContext* context, const Tensor& input_tensor,
const MklDnnShape& input_mkl_shape) {
if (!input_mkl_shape.IsMklTensor()) {
- OP_REQUIRES(context, input_tensor.dims() == 4,
- errors::InvalidArgument("Input must be 4-dimensional"));
+ OP_REQUIRES(context, input_tensor.dims() == 4 || input_tensor.dims() == 5,
+ errors::InvalidArgument("Input must be 4 or 5-dimensional"));
} else {
- OP_REQUIRES(context, input_mkl_shape.GetDimension() == 4,
+ OP_REQUIRES(context, input_mkl_shape.GetDimension() == 4 ||
+ input_mkl_shape.GetDimension() == 5,
errors::InvalidArgument("Input shape must be "
- "4-dimensional"));
+ "4 or 5-dimensional"));
}
}
// .Input("value: T")
@@ -649,8 +709,12 @@ class MklPoolingBackwardOpBase : public MklPoolingOpBase<T> {
input_gradient_mkl_shape.IsMklTensor()
? input_gradient_mkl_shape.GetMklLayout()
: memory::desc(
- TFShapeToMklDnnDimsInNCHW(input_gradient_tensor.shape(),
- this->data_format_tf_),
+ (this->ksize_.size() == 4)
+ ? TFShapeToMklDnnDimsInNCHW(input_gradient_tensor.shape(),
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(
+ input_gradient_tensor.shape(),
+ this->data_format_tf_),
MklDnnType<T>(), this->data_format_mkldnn_);
input_gradient_dnn_data->SetUsrMem(original_input_grad_md,
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index 05034894e5..84385356e1 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -30,10 +30,12 @@ using mkldnn::algorithm;
using mkldnn::eltwise_elu;
using mkldnn::eltwise_relu;
using mkldnn::eltwise_tanh;
+using mkldnn::memory;
using mkldnn::prop_kind;
using mkldnn::relu_backward;
using mkldnn::relu_forward;
using mkldnn::stream;
#else
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
@@ -42,6 +44,415 @@ using mkldnn::stream;
namespace tensorflow {
+#ifndef INTEL_MKL_ML_ONLY
+
+template <typename T>
+class MklEltwiseFwdParams {
+ public:
+ memory::dims src_dims; // TODO: check whether src_dims is needed
+ memory::desc src_md;
+ algorithm alg_kind;
+ T alpha;
+ T beta;
+
+ MklEltwiseFwdParams(memory::dims src_dims, memory::desc src_md,
+ algorithm alg_kind, T alpha, T beta)
+ : src_dims(src_dims),
+ src_md(src_md),
+ alg_kind(alg_kind),
+ alpha(alpha),
+ beta(beta) {}
+};
+
+template <typename T>
+class MklEltwiseFwdPrimitive : public MklPrimitive {
+ public:
+ explicit MklEltwiseFwdPrimitive(const MklEltwiseFwdParams<T>& fwdParams)
+ : cpu_engine_(engine::cpu, 0) {
+ // store expected format
+ context_.src_fmt =
+ static_cast<mkldnn::memory::format>(fwdParams.src_md.data.format);
+ context_.fwd_stream.reset(new stream(stream::kind::eager));
+
+ // create eltwise primitive
+ if (context_.eltwise_fwd == nullptr) {
+ Setup(fwdParams);
+ }
+ }
+
+ ~MklEltwiseFwdPrimitive() {}
+
+ // Eltwise forward execute
+ // src_data: input data buffer of src
+ // dst_data: output data buffer of dst
+ void Execute(const T* src_data, T* dst_data) {
+ context_.src_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(src_data)));
+ context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
+ context_.fwd_stream->submit(context_.fwd_primitives);
+
+ // after execution, set data handle back
+ context_.src_mem->set_data_handle(DummyData);
+ context_.dst_mem->set_data_handle(DummyData);
+ }
+
+ std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> GetEltwiseFwdPd() {
+ return context_.fwd_pd;
+ }
+
+ memory::format GetSrcMemoryFormat() { return context_.src_fmt; }
+
+ private:
+ // Primitive reuse context for eltwise Fwd ops: Relu, Elu, Tanh
+ struct EltwiseFwdContext {
+ // expected memory format for this primitive instance
+ mkldnn::memory::format src_fmt;
+
+ // MKLDNN memory
+ std::shared_ptr<memory> src_mem;
+ std::shared_ptr<memory> dst_mem;
+
+ // desc & prmitive desc
+ std::shared_ptr<mkldnn::eltwise_forward::desc> fwd_desc;
+ std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> fwd_pd;
+
+ // memory desc
+ std::shared_ptr<memory::desc> src_md;
+ std::shared_ptr<memory::desc> dst_md;
+
+ // memory primitive desc
+ std::shared_ptr<memory::primitive_desc> src_mpd;
+
+ // Eltwise primitive
+ std::shared_ptr<mkldnn::primitive> eltwise_fwd;
+
+ std::shared_ptr<stream> fwd_stream;
+ std::vector<mkldnn::primitive> fwd_primitives;
+
+ EltwiseFwdContext()
+ : src_fmt(memory::format::any),
+ src_mem(nullptr),
+ dst_mem(nullptr),
+ fwd_desc(nullptr),
+ fwd_pd(nullptr),
+ src_md(nullptr),
+ dst_md(nullptr),
+ src_mpd(nullptr),
+ eltwise_fwd(nullptr),
+ fwd_stream(nullptr) {}
+ };
+
+ // Eltwise forward primitive setup
+ void Setup(const MklEltwiseFwdParams<T>& fwdParams) {
+ // create memory descriptors for eltwise data with specified format
+ context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
+ context_.src_mpd.reset(
+ new memory::primitive_desc(*context_.src_md, cpu_engine_));
+
+ // create an eltwise forward descriptor and primitive descriptor
+ context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc(
+ prop_kind::forward, fwdParams.alg_kind, *context_.src_md,
+ fwdParams.alpha, fwdParams.beta));
+ context_.fwd_pd.reset(new mkldnn::eltwise_forward::primitive_desc(
+ *context_.fwd_desc, cpu_engine_));
+
+ // create memory primitive based on dummy data
+ context_.src_mem.reset(new memory(*context_.src_mpd, DummyData));
+ context_.dst_mem.reset(
+ new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
+
+ // create eltwise primitive and add it to net
+ context_.eltwise_fwd.reset(new mkldnn::eltwise_forward(
+ *context_.fwd_pd, *context_.src_mem, *context_.dst_mem));
+
+ context_.fwd_primitives.push_back(*context_.eltwise_fwd);
+ }
+
+ struct EltwiseFwdContext context_;
+ engine cpu_engine_;
+};
+
+template <typename T>
+class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
+ public:
+ static MklEltwiseFwdPrimitive<T>* Get(
+ const MklEltwiseFwdParams<T>& fwdParams) {
+ MklEltwiseFwdPrimitive<T>* eltwise_forward = nullptr;
+
+ auto src_fmt =
+ static_cast<mkldnn::memory::format>(fwdParams.src_md.data.format);
+
+ // Get a eltwise fwd primitive from the cached pool
+ eltwise_forward = static_cast<MklEltwiseFwdPrimitive<T>*>(
+ MklEltwiseFwdPrimitiveFactory<T>::GetInstance().GetEltwiseFwd(fwdParams,
+ src_fmt));
+ if (eltwise_forward == nullptr) {
+ eltwise_forward = new MklEltwiseFwdPrimitive<T>(fwdParams);
+ MklEltwiseFwdPrimitiveFactory<T>::GetInstance().SetEltwiseFwd(
+ fwdParams, src_fmt, eltwise_forward);
+ }
+ return eltwise_forward;
+ }
+
+ static MklEltwiseFwdPrimitiveFactory& GetInstance() {
+ static MklEltwiseFwdPrimitiveFactory instance_;
+ return instance_;
+ }
+
+ private:
+ MklEltwiseFwdPrimitiveFactory() {}
+ ~MklEltwiseFwdPrimitiveFactory() {}
+
+ static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams,
+ memory::format src_fmt) {
+ string prefix = "eltwise_fwd";
+ FactoryKeyCreator key_creator;
+ key_creator.AddAsKey(prefix);
+ key_creator.AddAsKey(fwdParams.src_dims);
+ key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
+ key_creator.AddAsKey<float>(static_cast<float>(fwdParams.alpha));
+ key_creator.AddAsKey<float>(static_cast<float>(fwdParams.beta));
+ key_creator.AddAsKey<int>(static_cast<int>(src_fmt));
+ return key_creator.GetKey();
+ }
+
+ MklPrimitive* GetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
+ memory::format src_fmt) {
+ string key = CreateKey(fwdParams, src_fmt);
+ return this->GetOp(key);
+ }
+
+ void SetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
+ memory::format src_fmt, MklPrimitive* op) {
+ string key = CreateKey(fwdParams, src_fmt);
+ this->SetOp(key, op);
+ }
+};
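// Sketch (illustrative, not part of the patch): CreateKey above folds op
// kind, dims, algorithm, alpha/beta and source format into one string, so
// two ReLUs over differently laid-out inputs map to distinct cached
// primitives. A hypothetical key builder with the same shape:
#include <sstream>
#include <string>
#include <vector>

std::string EltwiseKey(const std::vector<int>& dims, int alg, float alpha,
                       float beta, int src_fmt) {
  std::ostringstream key;
  key << "eltwise_fwd";
  for (int d : dims) key << ':' << d;
  key << ':' << alg << ':' << alpha << ':' << beta << ':' << src_fmt;
  return key.str();
}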
+
+template <typename T>
+class MklEltwiseBwdParams {
+ public:
+ memory::dims src_dims;
+ memory::desc common_md;
+ algorithm alg_kind;
+ T alpha;
+ T beta;
+
+ MklEltwiseBwdParams(const memory::dims& src_dims,
+ const memory::desc& common_md, algorithm alg_kind,
+ T alpha, T beta)
+ : src_dims(src_dims),
+ common_md(common_md),
+ alg_kind(alg_kind),
+ alpha(alpha),
+ beta(beta) {}
+};
+
+template <typename T>
+class MklEltwiseBwdPrimitive : public MklPrimitive {
+ public:
+ explicit MklEltwiseBwdPrimitive(const MklEltwiseBwdParams<T>& bwdParams)
+ : cpu_engine_(engine::cpu, 0) {
+ context_.src_fmt =
+ static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
+ context_.diff_dst_fmt =
+ static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
+ context_.bwd_stream.reset(new stream(stream::kind::eager));
+ // create eltwise primitive
+ if (context_.eltwise_bwd == nullptr) {
+ Setup(bwdParams);
+ }
+ }
+
+ ~MklEltwiseBwdPrimitive() {}
+
+ // Eltwise backward execute
+ // src_data: input data buffer of src
+ // diff_dst_data: input data buffer of diff_dst
+ // diff_src_data: output data buffer of diff_src
+ void Execute(const T* src_data, const T* diff_dst_data, T* diff_src_data) {
+ context_.src_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(src_data)));
+ context_.diff_dst_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(diff_dst_data)));
+ context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
+ context_.bwd_stream->submit(context_.bwd_primitives);
+
+ // after execution, set data handle back
+ context_.src_mem->set_data_handle(DummyData);
+ context_.diff_dst_mem->set_data_handle(DummyData);
+ context_.diff_src_mem->set_data_handle(DummyData);
+ }
+
+ std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> GetEltwiseBwdPd() {
+ return context_.bwd_pd;
+ }
+
+ memory::format GetSrcMemoryFormat() { return context_.src_fmt; }
+
+ memory::format GetDiffDstMemoryFormat() { return context_.diff_dst_fmt; }
+
+ private:
+ // Primitive reuse context for eltwise Bwd ops: Relu, Elu, Tanh
+ struct EltwiseBwdContext {
+ // expected memory format for this primitive instance
+ memory::format src_fmt;
+ memory::format diff_dst_fmt;
+
+ // MKLDNN memory
+ std::shared_ptr<memory> src_mem;
+ std::shared_ptr<memory> diff_dst_mem;
+ std::shared_ptr<memory> diff_src_mem;
+
+    // descriptor & primitive descriptor
+ std::shared_ptr<mkldnn::eltwise_backward::desc> bwd_desc;
+
+ // memory desc
+ std::shared_ptr<memory::desc> src_md;
+ std::shared_ptr<memory::desc> diff_dst_md;
+ std::shared_ptr<memory::desc> common_md;
+
+ // memory primitive desc
+ std::shared_ptr<memory::primitive_desc> src_mpd;
+ std::shared_ptr<memory::primitive_desc> diff_dst_mpd;
+
+ // fwd primitive desc
+ std::shared_ptr<mkldnn::eltwise_forward::desc> fwd_desc;
+ std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> fwd_pd;
+ std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> bwd_pd;
+
+ // Eltwise primitive
+ std::shared_ptr<mkldnn::primitive> eltwise_bwd;
+
+ std::shared_ptr<stream> bwd_stream;
+ std::vector<mkldnn::primitive> bwd_primitives;
+
+ EltwiseBwdContext()
+ : src_fmt(memory::format::any),
+ diff_dst_fmt(memory::format::any),
+ src_mem(nullptr),
+ diff_dst_mem(nullptr),
+ diff_src_mem(nullptr),
+ src_md(nullptr),
+ diff_dst_md(nullptr),
+ common_md(nullptr),
+ src_mpd(nullptr),
+ diff_dst_mpd(nullptr),
+ fwd_desc(nullptr),
+ fwd_pd(nullptr),
+ bwd_pd(nullptr),
+ eltwise_bwd(nullptr),
+ bwd_stream(nullptr) {}
+ };
+
+ // Eltwise backward primitive setup
+ void Setup(const MklEltwiseBwdParams<T>& bwdParams) {
+ // create memory descriptors for eltwise data w/ no specified format
+ context_.src_md.reset(new memory::desc(bwdParams.common_md.data));
+ context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data));
+
+ context_.src_mpd.reset(
+ new memory::primitive_desc(*context_.src_md, cpu_engine_));
+ context_.diff_dst_mpd.reset(
+ new memory::primitive_desc(*context_.diff_dst_md, cpu_engine_));
+
+ // create forward eltwise primitive
+ context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc(
+ prop_kind::forward_training, bwdParams.alg_kind, *context_.src_md,
+ bwdParams.alpha, bwdParams.beta));
+ context_.fwd_pd.reset(new mkldnn::eltwise_forward::primitive_desc(
+ *context_.fwd_desc, cpu_engine_));
+ context_.bwd_desc.reset(new mkldnn::eltwise_backward::desc(
+ bwdParams.alg_kind, *context_.diff_dst_md, *context_.src_md,
+ bwdParams.alpha, bwdParams.beta));
+ context_.bwd_pd.reset(new mkldnn::eltwise_backward::primitive_desc(
+ *context_.bwd_desc, cpu_engine_, *context_.fwd_pd));
+
+ // create memory primitive based on dummy data
+ context_.src_mem.reset(new memory(*context_.src_mpd, DummyData));
+ context_.diff_dst_mem.reset(new memory(*context_.diff_dst_mpd, DummyData));
+ context_.diff_src_mem.reset(new memory(
+ context_.bwd_pd.get()->diff_src_primitive_desc(), DummyData));
+
+ // create eltwise primitive and add it to net
+ context_.eltwise_bwd.reset(new mkldnn::eltwise_backward(
+ *context_.bwd_pd, *context_.src_mem, *context_.diff_dst_mem,
+ *context_.diff_src_mem));
+
+ context_.bwd_primitives.push_back(*context_.eltwise_bwd);
+ }
+
+ struct EltwiseBwdContext context_;
+ engine cpu_engine_;
+};
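+
+// Usage sketch (illustrative): Execute() wraps caller-owned buffers via
+// set_data_handle(), submits the cached primitive on the stored stream,
+// and then resets the handles to DummyData, so no ownership is taken.
+//
+//   const float* src = src_tensor.flat<float>().data();
+//   const float* diff_dst = grad_tensor.flat<float>().data();
+//   float* diff_src = out_tensor->flat<float>().data();
+//   bwd->Execute(src, diff_dst, diff_src);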
+
+template <typename T>
+class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
+ private:
+ MklEltwiseBwdPrimitiveFactory() {}
+ ~MklEltwiseBwdPrimitiveFactory() {}
+
+ public:
+ static MklEltwiseBwdPrimitive<T>* Get(
+ const MklEltwiseBwdParams<T>& bwdParams) {
+ MklEltwiseBwdPrimitive<T>* eltwise_backward = nullptr;
+
+ auto src_fmt =
+ static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
+ auto diff_dst_fmt =
+ static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
+
+    // try to find a suitable primitive in the cached pool
+ eltwise_backward = static_cast<MklEltwiseBwdPrimitive<T>*>(
+ MklEltwiseBwdPrimitiveFactory<T>::GetInstance().GetEltwiseBwd(
+ bwdParams, src_fmt, diff_dst_fmt));
+
+ if (eltwise_backward == nullptr) {
+ eltwise_backward = new MklEltwiseBwdPrimitive<T>(bwdParams);
+ MklEltwiseBwdPrimitiveFactory<T>::GetInstance().SetEltwiseBwd(
+ bwdParams, src_fmt, diff_dst_fmt, eltwise_backward);
+ }
+ return eltwise_backward;
+ }
+
+ static MklEltwiseBwdPrimitiveFactory& GetInstance() {
+ static MklEltwiseBwdPrimitiveFactory instance_;
+ return instance_;
+ }
+
+ private:
+ static string CreateKey(const MklEltwiseBwdParams<T>& bwdParams,
+ const memory::format& src_fmt,
+ const memory::format& diff_dst_fmt) {
+ string prefix = "eltwise_bwd";
+ FactoryKeyCreator key_creator;
+ key_creator.AddAsKey(prefix);
+ key_creator.AddAsKey(bwdParams.src_dims);
+ key_creator.AddAsKey(static_cast<int>(bwdParams.alg_kind));
+ key_creator.AddAsKey(static_cast<float>(bwdParams.alpha));
+ key_creator.AddAsKey(static_cast<float>(bwdParams.beta));
+ key_creator.AddAsKey(static_cast<int>(src_fmt));
+ key_creator.AddAsKey(static_cast<int>(diff_dst_fmt));
+ return key_creator.GetKey();
+ }
+
+ MklPrimitive* GetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams,
+ const memory::format& src_fmt,
+ const memory::format& diff_dst_fmt) {
+ string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt);
+ return this->GetOp(key);
+ }
+
+ void SetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams,
+ const memory::format& src_fmt,
+ const memory::format& diff_dst_fmt, MklPrimitive* op) {
+ string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt);
+ this->SetOp(key, op);
+ }
+};
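+
+// Note on cache keys (informal): CreateKey() serializes "eltwise_bwd",
+// src_dims, alg_kind, alpha, beta, src_fmt and diff_dst_fmt through
+// FactoryKeyCreator, so a cached primitive is reused only when all of
+// those fields match. A Relu grad and a Tanh grad over identical shapes
+// therefore occupy distinct cache entries, since alg_kind differs.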
+
+#endif
+
typedef Eigen::ThreadPoolDevice CPUDevice;
struct MklReluHelpers {
@@ -375,55 +786,63 @@ class MklReluOpBase : public OpKernel {
~MklReluOpBase() {}
explicit MklReluOpBase(OpKernelConstruction* context) : OpKernel(context) {}
-
virtual void Compute_Scalar(OpKernelContext* context) = 0;
void Compute(OpKernelContext* context) override {
try {
- auto cpu_engine = engine(engine::cpu, 0);
const size_t src_index = 0; // index of src input tensor
const size_t dst_index = 0; // index of dst output tensor
const Tensor& src_tensor = MklGetInput(context, src_index);
MklDnnShape dnn_shape_src;
GetMklShape(context, src_index, &dnn_shape_src);
- Tensor* dst_tensor = nullptr;
if (src_tensor.dims() == 0) {
- Compute_Scalar(context); // scalar case doesn't use in-place operation
+ Compute_Scalar(context);
return;
}
- // Create relu primitive.
- MklDnnData<T> src(&cpu_engine);
- MklDnnData<T> dst(&cpu_engine);
-
// Set DNN primitive - src
+ MklDnnData<T> src(&cpu_engine);
+ memory::dims src_dims;
memory::desc src_md({}, memory::data_undef, memory::format_undef);
if (dnn_shape_src.IsMklTensor()) {
src_md = dnn_shape_src.GetMklLayout();
+ src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
} else {
- auto src_dims = TFShapeToMklDnnDims(src_tensor.shape());
+ src_dims = TFShapeToMklDnnDims(src_tensor.shape());
auto src_strides = CalculateTFStrides(src_dims);
// Create blocked memory descriptor
src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
}
- src.SetUsrMem(src_md, &src_tensor);
T alpha = 0, beta = 0;
- std::shared_ptr<relu_forward::primitive_desc> relu_fwd_pd;
- auto relu_fwd_desc = relu_forward::desc(
- prop_kind::forward_training,
- // Operator memory descriptor is same as user memory descriptor.
- alg_kind, src.GetUsrMemDesc(), alpha, beta);
- relu_fwd_pd.reset(
- new relu_forward::primitive_desc(relu_fwd_desc, cpu_engine));
-
- // allocate dst tensor
+
+      // get an eltwise fwd primitive from the pool
+ MklEltwiseFwdParams<T> fwdParams(src_dims, src_md, alg_kind, alpha, beta);
+ MklEltwiseFwdPrimitive<T>* eltwise_fwd =
+ MklEltwiseFwdPrimitiveFactory<T>::Get(fwdParams);
+
+      // prepare for execution
+ const T* src_data = src_tensor.flat<T>().data();
+      // check whether src needs to be reordered
+ if (src_md.data.format != eltwise_fwd->GetSrcMemoryFormat()) {
+ src.SetUsrMem(src_md, &src_tensor);
+ auto src_target_pd = memory::primitive_desc(
+ {{src_dims}, MklDnnType<T>(), eltwise_fwd->GetSrcMemoryFormat()},
+ cpu_engine);
+ src.CheckReorderToOpMem(src_target_pd);
+ src_data = const_cast<T*>(
+ reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
+ }
+
+      // allocate dst tensor; use MKL-DNN layout if src is an MKL tensor
+ std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> eltwise_fwd_pd =
+ eltwise_fwd->GetEltwiseFwdPd();
MklDnnShape dnn_shape_dst;
TensorShape tf_shape_dst;
if (dnn_shape_src.IsMklTensor()) {
dnn_shape_dst.SetMklTensor(true);
- auto dst_pd = relu_fwd_pd->dst_primitive_desc();
+ auto dst_pd = eltwise_fwd_pd->dst_primitive_desc();
dnn_shape_dst.SetMklLayout(&dst_pd);
dnn_shape_dst.SetElemType(MklDnnType<T>());
dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(),
@@ -434,34 +853,32 @@ class MklReluOpBase : public OpKernel {
dnn_shape_dst.SetMklTensor(false);
tf_shape_dst = src_tensor.shape();
}
-
- // Allocate output and MklDnnShape tensors separately for possible
- // in-place operation
+
+ Tensor* dst_tensor = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{static_cast<const int>(src_index)},
static_cast<const int>(dst_index),
tf_shape_dst, &dst_tensor));
AllocateOutputSetMklShape(context, dst_index, dnn_shape_dst);
- // Destination memory descriptor is same as source memory descriptor.
- auto &dst_md = src_md;
- dst.SetUsrMem(dst_md, dst_tensor);
+ T* dst_data = dst_tensor->flat<T>().data();
- // execute net
- std::vector<primitive> net;
- auto relu_fwd =
- relu_forward(*relu_fwd_pd, src.GetOpMem(), dst.GetOpMem());
- net.push_back(relu_fwd);
- stream(stream::kind::eager).submit(net).wait();
+ // execute eltwise
+ eltwise_fwd->Execute(src_data, dst_data);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
- ", message: " + string(e.message) + ", in file " +
- string(__FILE__) + ":" + std::to_string(__LINE__);
- OP_REQUIRES_OK(
- context,
- errors::Aborted("Operation received an exception:", error_msg));
+ ", message: " + string(e.message) +
+ ", in file " + string(__FILE__) + ":" +
+ std::to_string(__LINE__);
+ OP_REQUIRES_OK(context,
+ errors::Aborted("Operation received an exception:",
+ error_msg));
}
}
+
+ private:
+ engine cpu_engine = engine(engine::cpu, 0);
+ std::shared_ptr<relu_forward::primitive_desc> relu_fwd_pd;
};
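+
+// Forward-path summary (informal recap of Compute() above): build
+// MklEltwiseFwdParams from the input's dims and layout, fetch a cached
+// MklEltwiseFwdPrimitive from the factory, reorder the source only when
+// its format differs from the primitive's expected one, then call
+// Execute() on raw buffers, avoiding per-call primitive construction.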
template <typename Device, typename T, algorithm alg_kind>
@@ -470,16 +887,15 @@ class MklReluGradOpBase : public OpKernel {
~MklReluGradOpBase() {}
explicit MklReluGradOpBase(OpKernelConstruction* context)
- : OpKernel(context) {}
+ : OpKernel(context) {
+ }
virtual void Compute_Scalar(OpKernelContext* context) = 0;
void Compute(OpKernelContext* context) {
try {
- auto cpu_engine = engine(engine::cpu, 0);
MklDnnData<T> src(&cpu_engine);
MklDnnData<T> diff_dst(&cpu_engine);
- MklDnnData<T> diff_src(&cpu_engine);
const size_t diff_dst_index = 0; // index of diff_dst input tensor
const size_t src_index = 1; // index of src input tensor
@@ -495,37 +911,23 @@ class MklReluGradOpBase : public OpKernel {
int src_dims_size = src_tensor.dims();
if (src_dims_size == 0) {
- Compute_Scalar(context); // scalar case doesn't use in-place operation
+ Compute_Scalar(context);
return;
}
- // Set DNN primitives for src & diff_dst
+      // get an eltwise bwd primitive from the pool
+ memory::dims src_dims = {};
memory::desc src_md({}, memory::data_undef, memory::format_undef);
memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef);
-
- // For creating Sum primitive, we need to ensure that all inputs are in
- // same format. What that means is if we have a mixed input case - where
- // one input is in Tensorflow format and one input is in MKL format -,
- // then we need to ensure that all inputs are in same format for
- // primitive construction. For performance reason, we say that all inputs
- // are in MKL format in such case, and insert reorder for input that is
- // in Tensorflow format into MKL format. On the other hand, if both the
- // inputs are in MKL format or both are in Tensorflow format, then we
- // dont need reorder.
if (!dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) {
- // If both the inputs are in Tensorflow format, we create blocked memory
- // descriptor.
- auto src_dims = TFShapeToMklDnnDims(src_tensor.shape());
+ src_dims = TFShapeToMklDnnDims(src_tensor.shape());
auto src_strides = CalculateTFStrides(src_dims);
src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
diff_dst_md = src_md;
} else if (dnn_shape_src.IsMklTensor() &&
!dnn_shape_diff_dst.IsMklTensor()) {
- // If one input is in MKL format and other is in Tensorflow, then
- // create respective descriptors describing the actual case. For input
- // in Mkl format, we just get Mkl layout from MklDnnShape. For input in
- // Tensorflow format, we create memory descriptor using data format.
src_md = dnn_shape_src.GetMklLayout();
+ src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
memory::format src_mkl_data_format = dnn_shape_src.GetTfDataFormat();
auto src_tf_data_format =
@@ -536,26 +938,27 @@ class MklReluGradOpBase : public OpKernel {
memory::desc(diff_dst_dims, MklDnnType<T>(), src_mkl_data_format);
} else if (!dnn_shape_src.IsMklTensor() &&
dnn_shape_diff_dst.IsMklTensor()) {
- // Same comment as above.
diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
memory::format diff_dst_mkl_data_format =
dnn_shape_diff_dst.GetTfDataFormat();
auto diff_dst_tf_data_format =
MklDnnDataFormatToTFDataFormat(diff_dst_mkl_data_format);
- auto src_dims = TFShapeToMklDnnDimsInNCHW(src_tensor.shape(),
- diff_dst_tf_data_format);
+
+ src_dims = (src_tensor.dims() == 4)
+ ? TFShapeToMklDnnDimsInNCHW(src_tensor.shape(),
+ diff_dst_tf_data_format)
+ : TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(),
+ diff_dst_tf_data_format);
src_md =
memory::desc(src_dims, MklDnnType<T>(), diff_dst_mkl_data_format);
} else {
- // If both the inputs are in MKL format, we use Mkl layout of the input
- // tensors.
src_md = dnn_shape_src.GetMklLayout();
diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
+ src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
}
- src.SetUsrMem(src_md, &src_tensor);
- diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
+ T alpha = 0, beta = 0;
// As per comment above, we tell MKLDNN that both the inputs are in same
// format. So we set common memory descriptor in MKL format, if any of the
@@ -570,24 +973,38 @@ class MklReluGradOpBase : public OpKernel {
common_md = src_md;
}
- T alpha = 0, beta = 0;
- std::shared_ptr<relu_forward::primitive_desc> relu_fwd_pd;
- auto relu_fwd_desc = relu_forward::desc(prop_kind::forward_training,
- alg_kind, src_md, alpha, beta);
- relu_fwd_pd.reset(
- new relu_forward::primitive_desc(relu_fwd_desc, cpu_engine));
- auto relu_bwd_desc =
- relu_backward::desc(alg_kind, common_md, common_md, alpha, beta);
- auto relu_bwd_pd = relu_backward::primitive_desc(
- relu_bwd_desc, cpu_engine, *relu_fwd_pd);
+ MklEltwiseBwdParams<T> bwdParams(src_dims, common_md, alg_kind, alpha,
+ beta);
+ MklEltwiseBwdPrimitive<T>* eltwise_bwd =
+ MklEltwiseBwdPrimitiveFactory<T>::Get(bwdParams);
+ auto eltwise_bwd_pd = eltwise_bwd->GetEltwiseBwdPd();
+
+      // check whether src / diff_dst need to be reordered
+ const T* src_data = src_tensor.flat<T>().data();
+ if (src_md.data.format != eltwise_bwd->GetSrcMemoryFormat()) {
+ src.SetUsrMem(src_md, &src_tensor);
+ src.CheckReorderToOpMem(
+ eltwise_bwd_pd.get()->diff_src_primitive_desc());
+ src_data = const_cast<T*>(
+ reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
+ }
+
+ const T* diff_dst_data = diff_dst_tensor.flat<T>().data();
+ if (diff_dst_md.data.format != eltwise_bwd->GetDiffDstMemoryFormat()) {
+ diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
+ diff_dst.CheckReorderToOpMem(
+ eltwise_bwd_pd.get()->diff_src_primitive_desc());
+ diff_dst_data = const_cast<T*>(
+ reinterpret_cast<T*>(diff_dst.GetOpMem().get_data_handle()));
+ }
// allocate diff_src tensor
MklDnnShape dnn_shape_diff_src;
TensorShape tf_shape_diff_src;
if (dnn_shape_src.IsMklTensor() ||
dnn_shape_diff_dst.IsMklTensor()) {
+ auto diff_src_pd = eltwise_bwd_pd->diff_src_primitive_desc();
dnn_shape_diff_src.SetMklTensor(true);
- auto diff_src_pd = relu_bwd_pd.diff_src_primitive_desc();
dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
dnn_shape_diff_src.SetElemType(MklDnnType<T>());
if (dnn_shape_src.IsMklTensor()) {
@@ -602,25 +1019,18 @@ class MklReluGradOpBase : public OpKernel {
tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
} else {
dnn_shape_diff_src.SetMklTensor(false);
- // both src and diff_dst are TensorFlow layout,
- // so it is ok to get TensorFlow shape.
tf_shape_diff_src = src_tensor.shape();
}
- // Allocate diff_src and MklDnnShape tensors separately for possible
- // in-place operation
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
- {static_cast<const int>(diff_dst_index)},
- static_cast<const int>(diff_src_index),
- tf_shape_diff_src,
- &diff_src_tensor));
+ {diff_dst_index}, diff_src_index,
+ tf_shape_diff_src, &diff_src_tensor));
AllocateOutputSetMklShape(context, diff_src_index, dnn_shape_diff_src);
- // diff_src memory descriptor is same as memory descriptor for both
- // inputs.
- diff_src.SetUsrMem(common_md, diff_src_tensor);
+ T* diff_src_data = diff_src_tensor->flat<T>().data();
- PrepareAndExecuteNet(relu_bwd_pd, &src, &diff_src, &diff_dst);
+ // execute eltwise bwd
+ eltwise_bwd->Execute(src_data, diff_dst_data, diff_src_data);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
@@ -631,22 +1041,9 @@ class MklReluGradOpBase : public OpKernel {
}
}
- void PrepareAndExecuteNet(const relu_backward::primitive_desc& relu_prim_desc,
- MklDnnData<T>* src, MklDnnData<T>* diff_src,
- MklDnnData<T>* diff_dst) {
- std::vector<primitive> net;
-
- // Check if we need to reorder original input tensors into common_md layout
- // that we set for primitive creation. diff_src_primitive_desc is same as
- // common_md.
- src->CheckReorderToOpMem(relu_prim_desc.diff_src_primitive_desc(), &net);
- diff_dst->CheckReorderToOpMem(relu_prim_desc.diff_src_primitive_desc(),
- &net);
-
- net.push_back(relu_backward(relu_prim_desc, src->GetOpMem(),
- diff_dst->GetOpMem(), diff_src->GetOpMem()));
- stream(stream::kind::eager).submit(net).wait();
- }
+ private:
+ engine cpu_engine = engine(engine::cpu, 0);
+ std::shared_ptr<relu_forward::primitive_desc> relu_fwd_pd;
};
template <typename Device, typename T>
diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc
index 8bde966be9..cfab529662 100644
--- a/tensorflow/core/kernels/mkl_softmax_op.cc
+++ b/tensorflow/core/kernels/mkl_softmax_op.cc
@@ -50,6 +50,7 @@ class MklSoftmaxOp : public OpKernel {
// src_tensor now points to the 0-th input of global data struct "context"
size_t src_idx = 0;
const Tensor& src_tensor = MklGetInput(context, src_idx);
+ const int input_dims = src_tensor.dims();
// Add: get MklShape
MklDnnShape src_mkl_shape;
@@ -62,7 +63,33 @@ class MklSoftmaxOp : public OpKernel {
: src_tensor.shape();
auto src_dims = TFShapeToMklDnnDims(src_tf_shape);
auto output_dims = src_dims;
-
+ memory::format layout_type;
+      // In MKL-DNN, the data format passed to the softmax op depends on the
+      // dimensionality of the input tensor: "x" is used for a 1-D tensor,
+      // "nc" for 2-D, "tnc" for 3-D, "nchw" for 4-D, and "ncdhw" for 5-D.
+      // Each of the symbols has the following meaning:
+      // n = batch, c = channels, t = sequence length, h = height,
+      // w = width, d = depth.
+ switch (input_dims) {
+ case 1:
+ layout_type = memory::format::x;
+ break;
+ case 2:
+ layout_type = memory::format::nc;
+ break;
+ case 3:
+ layout_type = memory::format::tnc;
+ break;
+ case 4:
+ layout_type = memory::format::nchw;
+ break;
+ case 5:
+ layout_type = memory::format::ncdhw;
+ break;
+ default:
+        OP_REQUIRES_OK(context,
+                       errors::Aborted("Input dims must be >= 1 and <= 5"));
+ return;
+ }
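+      // Example (illustrative): a rank-3 input of shape [batch, time,
+      // channels] selects memory::format::tnc above, and the default axis
+      // chosen later (input_dims - 1 == 2) applies softmax along the last
+      // (channels) dimension for each (batch, time) pair.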
// Create softmax memory for src, dst: both are defined in mkl_util.h,
// they are wrapper
MklDnnData<T> src(&cpu_engine);
@@ -75,7 +102,7 @@ class MklSoftmaxOp : public OpKernel {
auto src_md =
src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
- : memory::desc(src_dims, MklDnnType<T>(), memory::format::nc);
+ : memory::desc(src_dims, MklDnnType<T>(), layout_type);
// src: setting memory descriptor and op memory descriptor
// Basically following two functions maps the TF "src_tensor" to mkl
@@ -84,10 +111,11 @@ class MklSoftmaxOp : public OpKernel {
// data format is "nc" for src and dst; since the src and dst buffer is
// always in 2D shape
src.SetUsrMem(src_md, &src_tensor);
- src.SetOpMemDesc(src_dims, memory::format::nc);
+ src.SetOpMemDesc(src_dims, layout_type);
// creating a memory descriptor
- int axis = 1; // axis to which softmax will be applied
+      // pass the last dim as the default axis, along which softmax is applied
+ int axis = input_dims - 1;
auto softmax_fwd_desc = softmax_forward::desc(prop_kind::forward_scoring,
src.GetOpMemDesc(), axis);
auto softmax_fwd_pd =
@@ -107,7 +135,7 @@ class MklSoftmaxOp : public OpKernel {
output_mkl_shape.SetMklLayout(&dst_pd);
output_mkl_shape.SetElemType(MklDnnType<T>());
output_mkl_shape.SetTfLayout(output_dims.size(), output_dims,
- memory::format::nc);
+ layout_type);
output_tf_shape.AddDim((dst_pd.get_size() / sizeof(T)));
} else { // then output is also TF shape
output_mkl_shape.SetMklTensor(false);
diff --git a/tensorflow/core/kernels/multinomial_op.h b/tensorflow/core/kernels/multinomial_op.h
index 6e41060aa4..34e2123613 100644
--- a/tensorflow/core/kernels/multinomial_op.h
+++ b/tensorflow/core/kernels/multinomial_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_MULTINOMIAL_OP_H_
-#define TENSORFLOW_KERNELS_MULTINOMIAL_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MULTINOMIAL_OP_H_
+#define TENSORFLOW_CORE_KERNELS_MULTINOMIAL_OP_H_
namespace tensorflow {
@@ -27,4 +27,4 @@ struct MultinomialFunctor;
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_MULTINOMIAL_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_MULTINOMIAL_OP_H_
diff --git a/tensorflow/core/kernels/neon/depthwiseconv_float.h b/tensorflow/core/kernels/neon/depthwiseconv_float.h
index 11f5be7c03..0d5a42bf10 100644
--- a/tensorflow/core/kernels/neon/depthwiseconv_float.h
+++ b/tensorflow/core/kernels/neon/depthwiseconv_float.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_H_
-#define TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_H_
+#ifndef TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_FLOAT_H_
+#define TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_FLOAT_H_
#include "public/gemmlowp.h"
#include "tensorflow/core/kernels/neon/types.h"
@@ -722,4 +722,4 @@ void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
} // end namespace neon
} // end namespace tensorflow
-#endif // TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_H_
+#endif // TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_FLOAT_H_
diff --git a/tensorflow/core/kernels/no_op.h b/tensorflow/core/kernels/no_op.h
index 29ea46aed6..9e16d06978 100644
--- a/tensorflow/core/kernels/no_op.h
+++ b/tensorflow/core/kernels/no_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_NO_OP_H_
-#define TENSORFLOW_KERNELS_NO_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_NO_OP_H_
+#define TENSORFLOW_CORE_KERNELS_NO_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
@@ -29,4 +29,4 @@ class NoOp : public OpKernel {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_NO_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_NO_OP_H_
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc
index 5d9257e20b..81ce6d6e95 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cc
@@ -75,28 +75,28 @@ static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
}
// Return intersection-over-union overlap between boxes i and j
-static inline float IOUGreaterThanThreshold(
- typename TTypes<float, 2>::ConstTensor boxes, int i, int j,
- float iou_threshold) {
- const float ymin_i = std::min<float>(boxes(i, 0), boxes(i, 2));
- const float xmin_i = std::min<float>(boxes(i, 1), boxes(i, 3));
- const float ymax_i = std::max<float>(boxes(i, 0), boxes(i, 2));
- const float xmax_i = std::max<float>(boxes(i, 1), boxes(i, 3));
- const float ymin_j = std::min<float>(boxes(j, 0), boxes(j, 2));
- const float xmin_j = std::min<float>(boxes(j, 1), boxes(j, 3));
- const float ymax_j = std::max<float>(boxes(j, 0), boxes(j, 2));
- const float xmax_j = std::max<float>(boxes(j, 1), boxes(j, 3));
- const float area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i);
- const float area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j);
- if (area_i <= 0 || area_j <= 0) return 0.0;
- const float intersection_ymin = std::max<float>(ymin_i, ymin_j);
- const float intersection_xmin = std::max<float>(xmin_i, xmin_j);
- const float intersection_ymax = std::min<float>(ymax_i, ymax_j);
- const float intersection_xmax = std::min<float>(xmax_i, xmax_j);
- const float intersection_area =
- std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
- std::max<float>(intersection_xmax - intersection_xmin, 0.0);
- const float iou = intersection_area / (area_i + area_j - intersection_area);
+template <typename T>
+static inline bool IOUGreaterThanThreshold(
+ typename TTypes<T, 2>::ConstTensor boxes, int i, int j, T iou_threshold) {
+ const T ymin_i = std::min<T>(boxes(i, 0), boxes(i, 2));
+ const T xmin_i = std::min<T>(boxes(i, 1), boxes(i, 3));
+ const T ymax_i = std::max<T>(boxes(i, 0), boxes(i, 2));
+ const T xmax_i = std::max<T>(boxes(i, 1), boxes(i, 3));
+ const T ymin_j = std::min<T>(boxes(j, 0), boxes(j, 2));
+ const T xmin_j = std::min<T>(boxes(j, 1), boxes(j, 3));
+ const T ymax_j = std::max<T>(boxes(j, 0), boxes(j, 2));
+ const T xmax_j = std::max<T>(boxes(j, 1), boxes(j, 3));
+ const T area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i);
+ const T area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j);
+  if (area_i <= static_cast<T>(0) || area_j <= static_cast<T>(0)) return false;
+ const T intersection_ymin = std::max<T>(ymin_i, ymin_j);
+ const T intersection_xmin = std::max<T>(xmin_i, xmin_j);
+ const T intersection_ymax = std::min<T>(ymax_i, ymax_j);
+ const T intersection_xmax = std::min<T>(xmax_i, xmax_j);
+  const T intersection_area =
+      std::max<T>(intersection_ymax - intersection_ymin, static_cast<T>(0)) *
+      std::max<T>(intersection_xmax - intersection_xmin, static_cast<T>(0));
+ const T iou = intersection_area / (area_i + area_j - intersection_area);
return iou > iou_threshold;
}
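+
+// Worked example (illustrative): boxes i = (0, 0, 2, 2) and
+// j = (1, 1, 3, 3) each have area 4 and intersect over [1, 2] x [1, 2],
+// so intersection_area = 1 and iou = 1 / (4 + 4 - 1) ~= 0.143; with an
+// iou_threshold of 0.5 the function returns false and box j survives.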
@@ -106,11 +106,13 @@ static inline bool OverlapsGreaterThanThreshold(
return overlaps(i, j) > overlap_threshold;
}
+template <typename T>
static inline std::function<bool(int, int)> CreateIOUSuppressCheckFn(
const Tensor& boxes, float threshold) {
- typename TTypes<float, 2>::ConstTensor boxes_data = boxes.tensor<float, 2>();
- return std::bind(&IOUGreaterThanThreshold, boxes_data, std::placeholders::_1,
- std::placeholders::_2, threshold);
+ typename TTypes<T, 2>::ConstTensor boxes_data = boxes.tensor<T, 2>();
+ return std::bind(&IOUGreaterThanThreshold<T>, boxes_data,
+ std::placeholders::_1, std::placeholders::_2,
+ static_cast<T>(threshold));
}
static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn(
@@ -121,6 +123,7 @@ static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn(
std::placeholders::_1, std::placeholders::_2, threshold);
}
+template <typename T>
void DoNonMaxSuppressionOp(
OpKernelContext* context, const Tensor& scores, int num_boxes,
const Tensor& max_output_size, const float score_threshold,
@@ -128,13 +131,13 @@ void DoNonMaxSuppressionOp(
bool pad_to_max_output_size = false, int* ptr_num_valid_outputs = nullptr) {
const int output_size = max_output_size.scalar<int>()();
- std::vector<float> scores_data(num_boxes);
- std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin());
+ std::vector<T> scores_data(num_boxes);
+ std::copy_n(scores.flat<T>().data(), num_boxes, scores_data.begin());
// Data structure for selection candidate in NMS.
struct Candidate {
int box_index;
- float score;
+ T score;
};
auto cmp = [](const Candidate bs_i, const Candidate bs_j) {
@@ -143,13 +146,13 @@ void DoNonMaxSuppressionOp(
std::priority_queue<Candidate, std::deque<Candidate>, decltype(cmp)>
candidate_priority_queue(cmp);
for (int i = 0; i < scores_data.size(); ++i) {
- if (scores_data[i] > score_threshold) {
+ if (static_cast<float>(scores_data[i]) > score_threshold) {
candidate_priority_queue.emplace(Candidate({i, scores_data[i]}));
}
}
std::vector<int> selected;
- std::vector<float> selected_scores;
+ std::vector<T> selected_scores;
Candidate next_candidate;
while (selected.size() < output_size && !candidate_priority_queue.empty()) {
@@ -176,7 +179,7 @@ void DoNonMaxSuppressionOp(
int num_valid_outputs = selected.size();
if (pad_to_max_output_size) {
selected.resize(output_size, 0);
- selected_scores.resize(output_size, 0);
+ selected_scores.resize(output_size, static_cast<T>(0));
}
if (ptr_num_valid_outputs) {
*ptr_num_valid_outputs = num_valid_outputs;
@@ -221,18 +224,19 @@ class NonMaxSuppressionOp : public OpKernel {
if (!context->status().ok()) {
return;
}
- auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_);
+ auto suppress_check_fn =
+ CreateIOUSuppressCheckFn<float>(boxes, iou_threshold_);
const float score_threshold_val = std::numeric_limits<float>::lowest();
- DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
- score_threshold_val, suppress_check_fn);
+ DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size,
+ score_threshold_val, suppress_check_fn);
}
private:
float iou_threshold_;
};
-template <typename Device>
+template <typename Device, typename T>
class NonMaxSuppressionV2Op : public OpKernel {
public:
explicit NonMaxSuppressionV2Op(OpKernelConstruction* context)
@@ -264,11 +268,12 @@ class NonMaxSuppressionV2Op : public OpKernel {
if (!context->status().ok()) {
return;
}
- auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_val);
+ auto suppress_check_fn =
+ CreateIOUSuppressCheckFn<T>(boxes, iou_threshold_val);
const float score_threshold_val = std::numeric_limits<float>::lowest();
- DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
- score_threshold_val, suppress_check_fn);
+ DoNonMaxSuppressionOp<T>(context, scores, num_boxes, max_output_size,
+ score_threshold_val, suppress_check_fn);
}
};
@@ -325,7 +330,7 @@ class NonMaxSuppressionV3V4Base : public OpKernel {
float score_threshold_val_;
};
-template <typename Device>
+template <typename Device, typename T>
class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base {
public:
explicit NonMaxSuppressionV3Op(OpKernelConstruction* context)
@@ -334,14 +339,14 @@ class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base {
protected:
void DoComputeAndPostProcess(OpKernelContext* context) override {
auto suppress_check_fn =
- CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_);
+ CreateIOUSuppressCheckFn<T>(boxes_, iou_threshold_val_);
- DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_,
- score_threshold_val_, suppress_check_fn);
+ DoNonMaxSuppressionOp<T>(context, scores_, num_boxes_, max_output_size_,
+ score_threshold_val_, suppress_check_fn);
}
};
-template <typename Device>
+template <typename Device, typename T>
class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base {
public:
explicit NonMaxSuppressionV4Op(OpKernelConstruction* context)
@@ -353,12 +358,12 @@ class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base {
protected:
void DoComputeAndPostProcess(OpKernelContext* context) override {
auto suppress_check_fn =
- CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_);
+ CreateIOUSuppressCheckFn<T>(boxes_, iou_threshold_val_);
int num_valid_outputs;
- DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_,
- score_threshold_val_, suppress_check_fn,
- pad_to_max_output_size_, &num_valid_outputs);
+ DoNonMaxSuppressionOp<T>(context, scores_, num_boxes_, max_output_size_,
+ score_threshold_val_, suppress_check_fn,
+ pad_to_max_output_size_, &num_valid_outputs);
// Allocate scalar output tensor for number of indices computed.
Tensor* num_outputs_t = nullptr;
@@ -413,22 +418,37 @@ class NonMaxSuppressionWithOverlapsOp : public OpKernel {
auto suppress_check_fn =
CreateOverlapsSuppressCheckFn(overlaps, overlap_threshold_val);
- DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
- score_threshold_val, suppress_check_fn);
+ DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size,
+ score_threshold_val, suppress_check_fn);
}
};
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU),
NonMaxSuppressionOp<CPUDevice>);
-REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU),
- NonMaxSuppressionV2Op<CPUDevice>);
+REGISTER_KERNEL_BUILDER(
+ Name("NonMaxSuppressionV2").TypeConstraint<float>("T").Device(DEVICE_CPU),
+ NonMaxSuppressionV2Op<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2")
+ .TypeConstraint<Eigen::half>("T")
+ .Device(DEVICE_CPU),
+ NonMaxSuppressionV2Op<CPUDevice, Eigen::half>);
-REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").Device(DEVICE_CPU),
- NonMaxSuppressionV3Op<CPUDevice>);
+REGISTER_KERNEL_BUILDER(
+ Name("NonMaxSuppressionV3").TypeConstraint<float>("T").Device(DEVICE_CPU),
+ NonMaxSuppressionV3Op<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3")
+ .TypeConstraint<Eigen::half>("T")
+ .Device(DEVICE_CPU),
+ NonMaxSuppressionV3Op<CPUDevice, Eigen::half>);
-REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4").Device(DEVICE_CPU),
- NonMaxSuppressionV4Op<CPUDevice>);
+REGISTER_KERNEL_BUILDER(
+ Name("NonMaxSuppressionV4").TypeConstraint<float>("T").Device(DEVICE_CPU),
+ NonMaxSuppressionV4Op<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4")
+ .TypeConstraint<Eigen::half>("T")
+ .Device(DEVICE_CPU),
+ NonMaxSuppressionV4Op<CPUDevice, Eigen::half>);
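+
+// Informal note (an example, not in the original sources): with the "T"
+// attribute constrained above, NonMaxSuppressionV2-V4 now accept
+// Eigen::half boxes/scores as well as float. Score thresholding casts
+// each score to float before comparing, so both instantiations share the
+// same DoNonMaxSuppressionOp<T> body.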
REGISTER_KERNEL_BUILDER(
Name("NonMaxSuppressionWithOverlaps").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/nth_element_op.h b/tensorflow/core/kernels/nth_element_op.h
index e7d25daecc..7a5ec3d0b5 100644
--- a/tensorflow/core/kernels/nth_element_op.h
+++ b/tensorflow/core/kernels/nth_element_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_NTH_ELEMENT_OP_H_
-#define TENSORFLOW_NTH_ELEMENT_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_NTH_ELEMENT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_NTH_ELEMENT_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -34,4 +34,4 @@ struct NthElementFunctor {
} // namespace tensorflow
-#endif // TENSORFLOW_NTH_ELEMENT_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_NTH_ELEMENT_OP_H_
diff --git a/tensorflow/core/kernels/one_hot_op.h b/tensorflow/core/kernels/one_hot_op.h
index db59f0f0d4..879df2b59b 100644
--- a/tensorflow/core/kernels/one_hot_op.h
+++ b/tensorflow/core/kernels/one_hot_op.h
@@ -15,8 +15,8 @@ limitations under the License.
// See docs in ../ops/array_ops.cc
-#ifndef TENSORFLOW_KERNELS_ONE_HOT_OP_H_
-#define TENSORFLOW_KERNELS_ONE_HOT_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_ONE_HOT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_ONE_HOT_OP_H_
// Generator definition for OneHotOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -69,4 +69,4 @@ struct OneHot {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_ONE_HOT_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_ONE_HOT_OP_H_
diff --git a/tensorflow/core/kernels/ops_testutil.h b/tensorflow/core/kernels/ops_testutil.h
index 2c195beb7f..5d607b9044 100644
--- a/tensorflow/core/kernels/ops_testutil.h
+++ b/tensorflow/core/kernels/ops_testutil.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_OPS_TESTUTIL_H_
-#define TENSORFLOW_KERNELS_OPS_TESTUTIL_H_
+#ifndef TENSORFLOW_CORE_KERNELS_OPS_TESTUTIL_H_
+#define TENSORFLOW_CORE_KERNELS_OPS_TESTUTIL_H_
#include <memory>
#include <vector>
@@ -252,4 +252,4 @@ class OpsTestBase : public ::testing::Test {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_OPS_TESTUTIL_H_
+#endif // TENSORFLOW_CORE_KERNELS_OPS_TESTUTIL_H_
diff --git a/tensorflow/core/kernels/ops_util.h b/tensorflow/core/kernels/ops_util.h
index 93ef512778..a496487d1b 100644
--- a/tensorflow/core/kernels/ops_util.h
+++ b/tensorflow/core/kernels/ops_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_OPS_UTIL_H_
-#define TENSORFLOW_KERNELS_OPS_UTIL_H_
+#ifndef TENSORFLOW_CORE_KERNELS_OPS_UTIL_H_
+#define TENSORFLOW_CORE_KERNELS_OPS_UTIL_H_
// This file contains utilities for various operations.
@@ -113,4 +113,4 @@ gtl::InlinedVector<T, 8> ComputeEigenStrides(const EigenDimensions& shape) {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_OPS_UTIL_H_
+#endif // TENSORFLOW_CORE_KERNELS_OPS_UTIL_H_
diff --git a/tensorflow/core/kernels/pad_op.h b/tensorflow/core/kernels/pad_op.h
index ee9e0f0330..ae79f515d9 100644
--- a/tensorflow/core/kernels/pad_op.h
+++ b/tensorflow/core/kernels/pad_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_PAD_OP_H_
-#define TENSORFLOW_KERNELS_PAD_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_PAD_OP_H_
+#define TENSORFLOW_CORE_KERNELS_PAD_OP_H_
// Functor definition for PadOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -54,4 +54,4 @@ struct Pad<Device, T, Tpadding, 0> {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_PAD_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_PAD_OP_H_
diff --git a/tensorflow/core/kernels/padding_fifo_queue.h b/tensorflow/core/kernels/padding_fifo_queue.h
index 9d7c935068..b86b03c8f0 100644
--- a/tensorflow/core/kernels/padding_fifo_queue.h
+++ b/tensorflow/core/kernels/padding_fifo_queue.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_PADDING_FIFO_QUEUE_H_
-#define TENSORFLOW_KERNELS_PADDING_FIFO_QUEUE_H_
+#ifndef TENSORFLOW_CORE_KERNELS_PADDING_FIFO_QUEUE_H_
+#define TENSORFLOW_CORE_KERNELS_PADDING_FIFO_QUEUE_H_
#include <deque>
#include <vector>
@@ -86,4 +86,4 @@ class PaddingFIFOQueue : public FIFOQueue {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_PADDING_FIFO_QUEUE_H_
+#endif // TENSORFLOW_CORE_KERNELS_PADDING_FIFO_QUEUE_H_
diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op.cc
index 0ab9ff9f65..aa70ee06f5 100644
--- a/tensorflow/core/kernels/parameterized_truncated_normal_op.cc
+++ b/tensorflow/core/kernels/parameterized_truncated_normal_op.cc
@@ -47,7 +47,7 @@ using random::PhiloxRandom;
template <typename T>
struct TruncatedNormalFunctor<CPUDevice, T> {
- static const int kMaxIterations = 100;
+ static const int kMaxIterations = 1000;
void operator()(OpKernelContext* ctx, const CPUDevice& d, int64 num_batches,
int64 samples_per_batch, int64 num_elements,
@@ -124,6 +124,7 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
(normMin * (normMin - sqrtFactor)) / T(4)) /
(normMin + sqrtFactor);
const T diff = normMax - normMin;
+
if (diff < cutoff) {
// Sample from a uniform distribution on [normMin, normMax].
@@ -143,15 +144,20 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
const auto u = dist(&gen_copy);
for (int i = 0; i < size; i++) {
- if (u[i] <= Eigen::numext::exp(g[i]) ||
- numIterations + 1 >= kMaxIterations) {
+ auto accept = u[i] <= Eigen::numext::exp(g[i]);
+ if (accept || numIterations + 1 >= kMaxIterations) {
// Accept the sample z.
// If we run out of iterations, just use the current uniform
- // sample. Emperically, the probability of accepting each sample
- // is at least 50% for typical inputs, so we will always accept
- // by 100 iterations.
- // This introduces a slight inaccuracy when at least one bound
- // is large, minval is negative and maxval is positive.
+ // sample, but emit a warning.
+ // TODO(jjhunt) For small entropies (relative to the bounds),
+ // this sampler is poor and may take many iterations since
+ // the proposal distribution is the uniform distribution
+ // U(lower_bound, upper_bound).
+ if (!accept) {
+ LOG(WARNING) << "TruncatedNormal uniform rejection sampler "
+ << "exceeded max iterations. Sample may contain "
+ << "outliers.";
+ }
output(sample) = z[i] * stddev + mean;
sample++;
if (sample >= limit_sample) {
@@ -181,8 +187,13 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
const T g = Eigen::numext::exp(-x * x / T(2.0));
const T u = rand[i];
i++;
- if ((u <= g && z < normMax) ||
- numIterations + 1 >= kMaxIterations) {
+ auto accept = (u <= g && z < normMax);
+ if (accept || numIterations + 1 >= kMaxIterations) {
+ if (!accept) {
+ LOG(WARNING) << "TruncatedNormal exponential distribution "
+                         << "rejection sampler exceeded max iterations. "
+ << "Sample may contain outliers.";
+ }
output(sample) = z * stddev + mean;
sample++;
if (sample >= limit_sample) {
diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op.h b/tensorflow/core/kernels/parameterized_truncated_normal_op.h
index cc801eb810..2e54db31fe 100644
--- a/tensorflow/core/kernels/parameterized_truncated_normal_op.h
+++ b/tensorflow/core/kernels/parameterized_truncated_normal_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
-#define TENSORFLOW_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
+#define TENSORFLOW_CORE_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/random/random_distributions.h"
@@ -49,4 +49,4 @@ struct TruncatedNormalFunctor {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc
index 661d47d925..5b80a962bc 100644
--- a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc
@@ -190,7 +190,7 @@ __global__ void __launch_bounds__(1024)
// Partial specialization for GPU
template <typename T>
struct TruncatedNormalFunctor<GPUDevice, T> {
- static const int kMaxIterations = 100;
+ static const int kMaxIterations = 1000;
void operator()(OpKernelContext* ctx, const GPUDevice& d, int64 num_batches,
int64 samples_per_batch, int64 num_elements,
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index 8db78f9784..7bb403290d 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/placer.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/function.h"
@@ -98,20 +99,12 @@ class PartitionedCallOp : public AsyncOpKernel {
done);
auto graph = tensorflow::MakeUnique<Graph>(fbody->graph->flib_def());
CopyGraph(*fbody->graph, graph.get());
- OP_REQUIRES_OK_ASYNC(ctx, PropagateInheritedDevices(graph.get(), args),
- done);
+ OP_REQUIRES_OK_ASYNC(ctx, PinResourceArgs(graph.get(), args), done);
DeviceSet device_set;
for (auto d : lib->device_mgr()->ListDevices()) {
device_set.AddDevice(d);
}
- Placer placer(graph.get(), &device_set);
- OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done);
-
- std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
- OP_REQUIRES_OK_ASYNC(
- ctx, PartitionHelper(device_set, std::move(graph), &subgraphs),
- done);
// The FunctionLibraryRuntime's library cannot be mutated from within
// an OpKernel, so functions are instantiated in an overlay library.
@@ -125,6 +118,47 @@ class PartitionedCallOp : public AsyncOpKernel {
new FunctionLibraryDefinition(*lib->GetFunctionLibraryDefinition());
overlay_libs_.emplace(lib, overlay_lib);
+ GraphOptimizationPassOptions optimization_options;
+ // TODO(akshayka): Thread SessionOptions (if any) into this kernel, or
+ // make it possible to specify the relevant options via attributes.
+ SessionOptions session_options;
+ session_options.env = ctx->env();
+ optimization_options.session_options = &session_options;
+ optimization_options.graph = &graph;
+ optimization_options.flib_def = overlay_lib;
+ optimization_options.device_set = &device_set;
+ Placer placer(graph.get(), &device_set);
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ OptimizationPassRegistry::Global()->RunGrouping(
+ OptimizationPassRegistry::PRE_PLACEMENT, optimization_options),
+ done);
+ OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done);
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ OptimizationPassRegistry::Global()->RunGrouping(
+ OptimizationPassRegistry::POST_PLACEMENT, optimization_options),
+ done);
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ OptimizationPassRegistry::Global()->RunGrouping(
+ OptimizationPassRegistry::POST_REWRITE_FOR_EXEC,
+ optimization_options),
+ done);
+
+ std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, PartitionHelper(device_set, std::move(graph), &subgraphs),
+ done);
+ optimization_options.graph = nullptr;
+ optimization_options.device_set = nullptr;
+ optimization_options.partition_graphs = &subgraphs;
+ OP_REQUIRES_OK_ASYNC(ctx,
+ OptimizationPassRegistry::Global()->RunGrouping(
+ OptimizationPassRegistry::POST_PARTITIONING,
+ optimization_options),
+ done);
+
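+      // Pipeline sketch (summary of the code above): PRE_PLACEMENT passes,
+      // then Placer::Run(), then POST_PLACEMENT and POST_REWRITE_FOR_EXEC
+      // passes on the whole graph, then PartitionHelper(), and finally
+      // POST_PARTITIONING passes over the per-device subgraphs handed in
+      // via optimization_options.partition_graphs. This is intended to
+      // mirror the grouping order a Session applies when executing a
+      // graph directly.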
auto handles = tensorflow::MakeUnique<gtl::FlatMap<string, FHandle>>();
for (const auto& pair : subgraphs) {
// TODO(akshayka): Fail gracefully if the set of devices corresponds
@@ -163,15 +197,10 @@ class PartitionedCallOp : public AsyncOpKernel {
std::vector<AllocatorAttributes>>
ArgAndRetAllocAttrs;
- // Propagates device annotations from the outer graph to the function body.
- //
// Pins each arg that emits a `DT_RESOURCE` tensor to the device on which the
// corresponding resource lives. This ensures that the Placer assigns ops that
- // access these resources to the appropriate devices. Additionally, places
- // nodes that are unadorned with device annotations onto PartitiondCallOp's
- // device. This lets call-site device annotations influence the execution
- // of the function.
- Status PropagateInheritedDevices(Graph* graph, const OpInputList& args) {
+ // access these resources to the appropriate devices.
+ Status PinResourceArgs(Graph* graph, const OpInputList& args) {
for (Node* node : graph->op_nodes()) {
string node_type = node->type_string();
if (node_type == FunctionLibraryDefinition::kArgOp) {
@@ -184,18 +213,6 @@ class PartitionedCallOp : public AsyncOpKernel {
ResourceHandle handle = args[index].flat<ResourceHandle>()(0);
node->set_assigned_device_name(handle.device());
}
- } else if (node_type != FunctionLibraryDefinition::kRetOp) {
- // All non-RetVal nodes that weren't explicitly placed by the user
- // inherit PartitionedCallOp's device. RetVal placement is inferred by
- // the placer, to avoid forcing the function's outputs through a single
- // device.
- //
- // TODO(b/112166045): Plumb the original requested device into this
- // OpKernel (this->requested_device() isn't reliable), and merge it
- // with node->requested_device() if possible.
- if (node->requested_device().empty()) {
- node->set_requested_device(local_device_name_);
- }
}
}
return Status::OK();
diff --git a/tensorflow/core/kernels/poisson-loss.h b/tensorflow/core/kernels/poisson-loss.h
new file mode 100644
index 0000000000..f91244454e
--- /dev/null
+++ b/tensorflow/core/kernels/poisson-loss.h
@@ -0,0 +1,109 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_POISSON_LOSS_H_
+#define TENSORFLOW_CORE_KERNELS_POISSON_LOSS_H_
+
+#include <cmath>
+
+#include "tensorflow/core/kernels/loss.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+class PoissonLossUpdater : public DualLossUpdater {
+ public:
+ // Update is found by a Newton algorithm (see readme.md).
+ double ComputeUpdatedDual(const int num_loss_partitions, const double label,
+ const double example_weight,
+ const double current_dual, const double wx,
+ const double weighted_example_norm) const final {
+    // The Newton algorithm converges quadratically, so 10 steps are more
+    // than enough to achieve very good precision.
+ static const int newton_total_steps = 10;
+ // Initialize the Newton optimization at x such that
+ // exp(x) = label - current_dual
+ const double y_minus_a = label - current_dual;
+ double x = (y_minus_a > 0) ? log(y_minus_a) : 0;
+ for (int i = 0; i < newton_total_steps; ++i) {
+ x = NewtonStep(x, num_loss_partitions, label, wx, example_weight,
+ weighted_example_norm, current_dual);
+ }
+ return label - exp(x);
+ }
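+
+  // Derivation sketch (informal): with the substitution a = label - exp(x),
+  // the subproblem's stationarity condition becomes
+  //   g(x) = x - wx - c * (label - current_dual - exp(x)) = 0,
+  // where c = num_loss_partitions * weighted_example_norm * example_weight.
+  // NewtonStep() below performs one iteration x <- x - g(x) / g'(x) with
+  // g'(x) = 1 + c * exp(x), which matches its numerator and denominator.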
+
+  // Dual of the Poisson loss function.
+ // https://en.wikipedia.org/wiki/Convex_conjugate
+ double ComputeDualLoss(const double current_dual, const double example_label,
+ const double example_weight) const final {
+    // The dual of the Poisson loss function is
+    // (y - a) * (log(y - a) - 1), where a is the dual variable.
+    // It is defined only for a < y.
+ const double y_minus_a = example_label - current_dual;
+ if (y_minus_a == 0.0) {
+ // (y-a)*(log(y-a)-1) approaches 0 as y-a approaches 0.
+ return 0.0;
+ }
+ if (y_minus_a < 0.0) {
+ return std::numeric_limits<double>::max();
+ }
+ return y_minus_a * (log(y_minus_a) - 1) * example_weight;
+ }
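+
+  // Quick check (illustrative): for label y = 2 and dual a = 1, y - a = 1
+  // and the dual loss is 1 * (log(1) - 1) = -1 per unit example weight;
+  // as a approaches y the loss tends to 0, and any a > y is infeasible,
+  // hence the numeric_limits::max() sentinel above.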
+
+ double ComputePrimalLoss(const double wx, const double example_label,
+ const double example_weight) const final {
+ return (exp(wx) - wx * example_label) * example_weight;
+ }
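+
+  // Sanity note (informal): exp(wx) - wx * y is the Poisson negative
+  // log-likelihood with the label-only term log(y!) dropped, so its
+  // derivative in wx, exp(wx) - y, is exactly what PrimalLossDerivative()
+  // returns below.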
+
+ double PrimalLossDerivative(const double wx, const double label,
+ const double example_weight) const final {
+ return (exp(wx) - label) * example_weight;
+ }
+
+ // TODO(chapelle): We need to introduce a maximum_prediction parameter,
+ // expose that parameter to the user and have this method return
+ // 1.0/maximum_prediction.
+ // Setting this at 1 for now, it only impacts the adaptive sampling.
+ double SmoothnessConstant() const final { return 1; }
+
+ Status ConvertLabel(float* const example_label) const final {
+ if (*example_label < 0.0) {
+ return errors::InvalidArgument(
+ "Only non-negative labels can be used with the Poisson log loss. "
+ "Found example with label: ", *example_label);
+ }
+ return Status::OK();
+ }
+
+ private:
+ // One Newton step (see readme.md).
+ double NewtonStep(const double x, const int num_loss_partitions,
+ const double label, const double wx,
+ const double example_weight,
+ const double weighted_example_norm,
+ const double current_dual) const {
+ const double expx = exp(x);
+ const double numerator =
+ x - wx - num_loss_partitions * weighted_example_norm *
+ example_weight * (label - current_dual - expx);
+ const double denominator =
+ 1 + num_loss_partitions * weighted_example_norm * example_weight * expx;
+ return x - numerator / denominator;
+ }
+};
+
+} // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_POISSON_LOSS_H_
diff --git a/tensorflow/core/kernels/pooling_ops_3d.h b/tensorflow/core/kernels/pooling_ops_3d.h
index d1be3ba407..319b17397e 100644
--- a/tensorflow/core/kernels/pooling_ops_3d.h
+++ b/tensorflow/core/kernels/pooling_ops_3d.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_POOLING_OPS_3D_H_
-#define TENSORFLOW_KERNELS_POOLING_OPS_3D_H_
+#ifndef TENSORFLOW_CORE_KERNELS_POOLING_OPS_3D_H_
+#define TENSORFLOW_CORE_KERNELS_POOLING_OPS_3D_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/padding.h"
@@ -77,4 +77,4 @@ struct Pool3dParameters {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_POOLING_OPS_3D_H_
+#endif // TENSORFLOW_CORE_KERNELS_POOLING_OPS_3D_H_
diff --git a/tensorflow/core/kernels/pooling_ops_3d_gpu.h b/tensorflow/core/kernels/pooling_ops_3d_gpu.h
index 350b1b6732..2c3681455e 100644
--- a/tensorflow/core/kernels/pooling_ops_3d_gpu.h
+++ b/tensorflow/core/kernels/pooling_ops_3d_gpu.h
@@ -17,8 +17,8 @@ limitations under the License.
#error This file must only be included when building with Cuda support
#endif
-#ifndef TENSORFLOW_CORE_KERNELS_POOLING_OP_3D_GPU_H_
-#define TENSORFLOW_CORE_KERNELS_POOLING_OP_3D_GPU_H_
+#ifndef TENSORFLOW_CORE_KERNELS_POOLING_OPS_3D_GPU_H_
+#define TENSORFLOW_CORE_KERNELS_POOLING_OPS_3D_GPU_H_
#define EIGEN_USE_GPU
@@ -45,4 +45,4 @@ struct MaxPool3dGradBackward {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_KERNELS_POOLING_OP_3D_H_
+#endif // TENSORFLOW_CORE_KERNELS_POOLING_OPS_3D_GPU_H_
diff --git a/tensorflow/core/kernels/pooling_ops_common.h b/tensorflow/core/kernels/pooling_ops_common.h
index e9265551e3..dda2c80c49 100644
--- a/tensorflow/core/kernels/pooling_ops_common.h
+++ b/tensorflow/core/kernels/pooling_ops_common.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_POOLING_OPS_COMMON_H_
-#define TENSORFLOW_KERNELS_POOLING_OPS_COMMON_H_
+#ifndef TENSORFLOW_CORE_KERNELS_POOLING_OPS_COMMON_H_
+#define TENSORFLOW_CORE_KERNELS_POOLING_OPS_COMMON_H_
#include <vector>
@@ -605,4 +605,4 @@ void SpatialAvgPool(OpKernelContext* context, Tensor* output,
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_POOLING_OPS_COMMON_H_
+#endif // TENSORFLOW_CORE_KERNELS_POOLING_OPS_COMMON_H_
diff --git a/tensorflow/core/kernels/priority_queue.h b/tensorflow/core/kernels/priority_queue.h
index ff168df449..8e69b5b699 100644
--- a/tensorflow/core/kernels/priority_queue.h
+++ b/tensorflow/core/kernels/priority_queue.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_PRIORITY_QUEUE_H_
-#define TENSORFLOW_KERNELS_PRIORITY_QUEUE_H_
+#ifndef TENSORFLOW_CORE_KERNELS_PRIORITY_QUEUE_H_
+#define TENSORFLOW_CORE_KERNELS_PRIORITY_QUEUE_H_
#include <deque>
#include <queue>
@@ -90,4 +90,4 @@ class PriorityQueue
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_PRIORITY_QUEUE_H_
+#endif // TENSORFLOW_CORE_KERNELS_PRIORITY_QUEUE_H_
diff --git a/tensorflow/core/kernels/qr_op_complex128.cc b/tensorflow/core/kernels/qr_op_complex128.cc
index c5b73139bb..8a3e3dc0a9 100644
--- a/tensorflow/core/kernels/qr_op_complex128.cc
+++ b/tensorflow/core/kernels/qr_op_complex128.cc
@@ -20,7 +20,17 @@ namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<complex128>), complex128);
#if GOOGLE_CUDA
-REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<complex128>), complex128);
+// We temporarily disable QR on GPU due to a bug in the QR implementation in
+// cuSolver affecting older hardware. The cuSolver team is tracking the issue
+// (https://partners.nvidia.com/bug/viewbug/2171459) and we will re-enable
+// this feature when a fix is available.
+REGISTER_KERNEL_BUILDER(Name("Qr")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<complex128>("T")
+ .HostMemory("input")
+ .HostMemory("q")
+ .HostMemory("r"),
+ QrOp<complex128>);
#endif
} // namespace tensorflow
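The replacement registration above is the standard trick for keeping an op
placeable on GPU while actually running the CPU implementation: every input
and output is pinned to host memory, so placement on /gpu:0 succeeds but no
device copies or device kernels are involved. A minimal sketch of the pattern
with hypothetical names ("MyOp" and MyCpuKernel are assumptions, not TF
symbols; this compiles only inside the TensorFlow kernel framework):

#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Hypothetical pass-through kernel used to illustrate the registration.
class MyCpuKernel : public OpKernel {
 public:
  explicit MyCpuKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    ctx->set_output(0, ctx->input(0));  // executes on the host CPU
  }
};

// Registered under DEVICE_GPU, but HostMemory() keeps all tensors in host
// RAM, so graphs that place "MyOp" on GPU still run the CPU code path.
REGISTER_KERNEL_BUILDER(Name("MyOp")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output"),
                        MyCpuKernel);

}  // namespace tensorflow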
diff --git a/tensorflow/core/kernels/qr_op_complex64.cc b/tensorflow/core/kernels/qr_op_complex64.cc
index 4e14f2639c..467fa6c2d6 100644
--- a/tensorflow/core/kernels/qr_op_complex64.cc
+++ b/tensorflow/core/kernels/qr_op_complex64.cc
@@ -20,7 +20,11 @@ namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<complex64>), complex64);
#if GOOGLE_CUDA
-REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<complex64>), complex64);
+// We temporarily disable QR on GPU due to a bug in the QR implementation in
+// cuSolver affecting older hardware. The cuSolver team is tracking the issue
+// (https://partners.nvidia.com/bug/viewbug/2171459) and we will re-enable
+// this feature when a fix is available.
+// REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<complex64>), complex64);
#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/qr_op_double.cc b/tensorflow/core/kernels/qr_op_double.cc
index 51885eb355..05537a0eaa 100644
--- a/tensorflow/core/kernels/qr_op_double.cc
+++ b/tensorflow/core/kernels/qr_op_double.cc
@@ -20,7 +20,17 @@ namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<double>), double);
#if GOOGLE_CUDA
-REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<double>), double);
+// We temporarily disable QR on GPU due to a bug in the QR implementation in
+// cuSolver affecting older hardware. The cuSolver team is tracking the issue
+// (https://partners.nvidia.com/bug/viewbug/2171459) and we will re-enable
+// this feature when a fix is available.
+REGISTER_KERNEL_BUILDER(Name("Qr")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<double>("T")
+ .HostMemory("input")
+ .HostMemory("q")
+ .HostMemory("r"),
+ QrOp<double>);
#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/qr_op_float.cc b/tensorflow/core/kernels/qr_op_float.cc
index d0a1dd4204..6aebd98186 100644
--- a/tensorflow/core/kernels/qr_op_float.cc
+++ b/tensorflow/core/kernels/qr_op_float.cc
@@ -20,7 +20,17 @@ namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<float>), float);
#if GOOGLE_CUDA
-REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<float>), float);
+// We temporarily disable QR on GPU due to a bug in the QR implementation in
+// cuSolver affecting older hardware. The cuSolver team is tracking the issue
+// (https://partners.nvidia.com/bug/viewbug/2171459) and we will re-enable
+// this feature when a fix is available.
+REGISTER_KERNEL_BUILDER(Name("Qr")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T")
+ .HostMemory("input")
+ .HostMemory("q")
+ .HostMemory("r"),
+ QrOp<float>);
#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/qr_op_impl.h b/tensorflow/core/kernels/qr_op_impl.h
index 0552c034d2..535df9d160 100644
--- a/tensorflow/core/kernels/qr_op_impl.h
+++ b/tensorflow/core/kernels/qr_op_impl.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_QR_OP_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_QR_OP_IMPL_H_
+
// See docs in ../ops/linalg_ops.cc.
//
// This header file is used by the individual qr_*op*.cc files for registering
@@ -292,6 +295,8 @@ class QrOpGpu : public AsyncOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(QrOpGpu);
};
-#endif
+#endif // GOOGLE_CUDA
} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_QR_OP_IMPL_H_
diff --git a/tensorflow/core/kernels/random_op.h b/tensorflow/core/kernels/random_op.h
index 97bcaf1a49..d313a021dd 100644
--- a/tensorflow/core/kernels/random_op.h
+++ b/tensorflow/core/kernels/random_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_RANDOM_OP_H_
-#define TENSORFLOW_KERNELS_RANDOM_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_OP_H_
+#define TENSORFLOW_CORE_KERNELS_RANDOM_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/lib/random/random_distributions.h"
@@ -69,4 +69,4 @@ struct FillPhiloxRandom<SYCLDevice, Distribution> {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_RANDOM_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_RANDOM_OP_H_
diff --git a/tensorflow/core/kernels/random_poisson_op.h b/tensorflow/core/kernels/random_poisson_op.h
index 4e9fd62520..62ae01c16c 100644
--- a/tensorflow/core/kernels/random_poisson_op.h
+++ b/tensorflow/core/kernels/random_poisson_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_RANDOM_POISSON_OP_H_
-#define TENSORFLOW_KERNELS_RANDOM_POISSON_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_POISSON_OP_H_
+#define TENSORFLOW_CORE_KERNELS_RANDOM_POISSON_OP_H_
namespace tensorflow {
@@ -28,4 +28,4 @@ struct PoissonFunctor;
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_RANDOM_POISSON_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_RANDOM_POISSON_OP_H_
diff --git a/tensorflow/core/kernels/range_sampler.h b/tensorflow/core/kernels/range_sampler.h
index 3010666598..ed160adfb4 100644
--- a/tensorflow/core/kernels/range_sampler.h
+++ b/tensorflow/core/kernels/range_sampler.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_RANGE_SAMPLER_H_
-#define TENSORFLOW_KERNELS_RANGE_SAMPLER_H_
+#ifndef TENSORFLOW_CORE_KERNELS_RANGE_SAMPLER_H_
+#define TENSORFLOW_CORE_KERNELS_RANGE_SAMPLER_H_
#include <vector>
@@ -249,4 +249,4 @@ class FixedUnigramSampler : public RangeSampler {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_RANGE_SAMPLER_H_
+#endif // TENSORFLOW_CORE_KERNELS_RANGE_SAMPLER_H_
diff --git a/tensorflow/core/kernels/range_sampler_test.cc b/tensorflow/core/kernels/range_sampler_test.cc
index 9020121169..3d49af7cb1 100644
--- a/tensorflow/core/kernels/range_sampler_test.cc
+++ b/tensorflow/core/kernels/range_sampler_test.cc
@@ -45,7 +45,7 @@ class RangeSamplerTest : public ::testing::Test {
// Using a fixed random seed to make the test deterministic.
random::PhiloxRandom philox(123, 17);
random::SimplePhilox rnd(&philox);
- sampler_->SampleBatch(&rnd, false, &a);
+ sampler_->SampleBatch(&rnd, false, absl::MakeSpan(a));
for (int i = 0; i < num_samples; i++) {
int64 val = a[i];
ASSERT_GE(val, 0);
@@ -251,8 +251,9 @@ TEST_F(RangeSamplerTest, All) {
extras[0] = 0;
extras[1] = batch_size - 1;
sampler_->SampleBatchGetExpectedCount(nullptr, // no random numbers needed
- false, &batch, &batch_expected, extras,
- &extras_expected);
+ false, absl::MakeSpan(batch),
+ absl::MakeSpan(batch_expected), extras,
+ absl::MakeSpan(extras_expected));
for (int i = 0; i < batch_size; i++) {
EXPECT_EQ(i, batch[i]);
EXPECT_EQ(1, batch_expected[i]);
@@ -281,17 +282,18 @@ TEST_F(RangeSamplerTest, Unique) {
std::vector<float> expected(range);
// Sample one batch and get the expected counts of all values
- sampler_->SampleBatchGetExpectedCount(
- &rnd, true, &batch, MutableArraySlice<float>(), all_values, &expected);
+ sampler_->SampleBatchGetExpectedCount(&rnd, true, absl::MakeSpan(batch),
+ MutableArraySlice<float>(), all_values,
+ absl::MakeSpan(expected));
// Check that all elements are unique
std::set<int64> s(batch.begin(), batch.end());
CHECK_EQ(batch_size, s.size());
for (int trial = 0; trial < num_batches; trial++) {
std::vector<float> trial_expected(range);
- sampler_->SampleBatchGetExpectedCount(&rnd, true, &batch,
- MutableArraySlice<float>(),
- all_values, &trial_expected);
+ sampler_->SampleBatchGetExpectedCount(
+ &rnd, true, absl::MakeSpan(batch), MutableArraySlice<float>(),
+ all_values, absl::MakeSpan(trial_expected));
for (int i = 0; i < range; i++) {
EXPECT_NEAR(expected[i], trial_expected[i], expected[i] * 0.5);
}
@@ -318,8 +320,8 @@ TEST_F(RangeSamplerTest, Avoid) {
// We expect to pick all elements of [0, 100) except the avoided two.
sampler_->SampleBatchGetExpectedCountAvoid(
- &rnd, true, &batch, MutableArraySlice<float>(), ArraySlice<int64>(),
- MutableArraySlice<float>(), avoided);
+ &rnd, true, absl::MakeSpan(batch), MutableArraySlice<float>(),
+ ArraySlice<int64>(), MutableArraySlice<float>(), avoided);
int sum = 0;
for (auto val : batch) {
diff --git a/tensorflow/core/kernels/record_yielder.h b/tensorflow/core/kernels/record_yielder.h
index 34817ad51b..159b43b4cd 100644
--- a/tensorflow/core/kernels/record_yielder.h
+++ b/tensorflow/core/kernels/record_yielder.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_RECORD_YIELDER_H_
-#define TENSORFLOW_KERNELS_RECORD_YIELDER_H_
+#ifndef TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_
+#define TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_
#include <atomic>
#include <random>
@@ -157,4 +157,4 @@ class RecordYielder {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_RECORD_YIELDER_H_
+#endif // TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_
diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
index 9af4cc23b6..88b3c2ac76 100644
--- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
+++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_GPU_KERNELS_CU_H_
+#define TENSORFLOW_CORE_KERNELS_REDUCTION_GPU_KERNELS_CU_H_
+
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
@@ -1058,4 +1061,6 @@ struct ReduceFunctor<GPUDevice, Eigen::internal::OrReducer> {
} // namespace functor
} // namespace tensorflow
-#endif
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CORE_KERNELS_REDUCTION_GPU_KERNELS_CU_H_
diff --git a/tensorflow/core/kernels/reduction_ops.h b/tensorflow/core/kernels/reduction_ops.h
index e43d2828f3..eb264e0e5a 100644
--- a/tensorflow/core/kernels/reduction_ops.h
+++ b/tensorflow/core/kernels/reduction_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_REDUCTION_OPS_H_
-#define TENSORFLOW_KERNELS_REDUCTION_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_
// Functor definitions for Reduction ops, must be compilable by nvcc.
@@ -79,4 +79,4 @@ struct ReduceFunctor {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_REDUCTION_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_
diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h
index 03d6e82e01..d83e1c7d15 100644
--- a/tensorflow/core/kernels/reduction_ops_common.h
+++ b/tensorflow/core/kernels/reduction_ops_common.h
@@ -18,8 +18,8 @@ limitations under the License.
// is a header file because we split the various reduction ops into their
// own compilation units to get more parallelism in compilation.
-#ifndef TENSORFLOW_KERNELS_REDUCTION_OPS_COMMON_H_
-#define TENSORFLOW_KERNELS_REDUCTION_OPS_COMMON_H_
+#ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_H_
+#define TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_H_
#define EIGEN_USE_THREADS
@@ -277,4 +277,4 @@ struct ReduceFunctor<SYCLDevice, Reducer>
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_REDUCTION_OPS_COMMON_H_
+#endif // TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_H_
diff --git a/tensorflow/core/kernels/regex_full_match_op.cc b/tensorflow/core/kernels/regex_full_match_op.cc
index 5863a2c8e4..7edaaad8f7 100644
--- a/tensorflow/core/kernels/regex_full_match_op.cc
+++ b/tensorflow/core/kernels/regex_full_match_op.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -56,4 +57,36 @@ class RegexFullMatchOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("RegexFullMatch").Device(DEVICE_CPU),
RegexFullMatchOp);
+class StaticRegexFullMatchOp : public OpKernel {
+ public:
+ explicit StaticRegexFullMatchOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string pattern;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("pattern", &pattern));
+ re_ = MakeUnique<RE2>(pattern);
+ OP_REQUIRES(ctx, re_->ok(),
+ errors::InvalidArgument("Invalid pattern: ", pattern,
+ ", error: ", re_->error()));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* input_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
+ const auto& input_flat = input_tensor->flat<string>();
+
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
+ &output_tensor));
+ auto output_flat = output_tensor->flat<bool>();
+ for (size_t i = 0; i < input_flat.size(); ++i) {
+ output_flat(i) = RE2::FullMatch(input_flat(i), *re_);
+ }
+ }
+
+ private:
+ std::unique_ptr<RE2> re_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("StaticRegexFullMatch").Device(DEVICE_CPU),
+ StaticRegexFullMatchOp);
+
} // namespace tensorflow
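StaticRegexFullMatchOp receives the pattern as a node attribute, so the RE2
object is compiled once in the constructor and reused by every Compute call;
RegexFullMatchOp takes the pattern as an input tensor and cannot cache it the
same way. A standalone sketch of the cost difference (plain re2, nothing
TF-specific; function names are illustrative):

#include <string>
#include <vector>
#include "re2/re2.h"

int CountCompiledOnce(const std::vector<std::string>& inputs,
                      const std::string& pattern) {
  const RE2 re(pattern);  // compiled once, reused for the whole batch
  int n = 0;
  for (const std::string& s : inputs) n += RE2::FullMatch(s, re);
  return n;
}

int CountRecompiled(const std::vector<std::string>& inputs,
                    const std::string& pattern) {
  int n = 0;
  for (const std::string& s : inputs)
    n += RE2::FullMatch(s, RE2(pattern));  // pays pattern compilation per row
  return n;
}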
diff --git a/tensorflow/core/kernels/regex_replace_op.cc b/tensorflow/core/kernels/regex_replace_op.cc
index 59ec854a79..a1b948891d 100644
--- a/tensorflow/core/kernels/regex_replace_op.cc
+++ b/tensorflow/core/kernels/regex_replace_op.cc
@@ -20,8 +20,43 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
+namespace {
+
+// Execute the specified regex using the given context.
+// Context requirements:
+// - "input" string Tensor at input_index=0
+// - "output" string Tensor at output_index=0
+Status InternalCompute(const RE2& match, const string& rewrite,
+ const bool replace_global, OpKernelContext* ctx) {
+ const Tensor* input_tensor;
+ TF_RETURN_IF_ERROR(ctx->input("input", &input_tensor));
+ Tensor* output_tensor;
+ std::unique_ptr<Tensor> maybe_forwarded =
+ ctx->forward_input(0 /*input_index*/, 0 /*output_index*/,
+ tensorflow::DT_STRING, input_tensor->shape(),
+ ctx->input_memory_type(0), ctx->input_alloc_attr(0));
+ if (maybe_forwarded) {
+ output_tensor = maybe_forwarded.get();
+ TF_RETURN_IF_ERROR(ctx->set_output("output", *output_tensor));
+ } else {
+ TF_RETURN_IF_ERROR(
+ ctx->allocate_output("output", input_tensor->shape(), &output_tensor));
+ output_tensor->flat<string>() = input_tensor->flat<string>();
+ }
+ auto output_flat = output_tensor->flat<string>();
+ for (size_t i = 0; i < output_flat.size(); ++i) {
+ if (replace_global) {
+ RE2::GlobalReplace(&output_flat(i), match, rewrite);
+ } else {
+ RE2::Replace(&output_flat(i), match, rewrite);
+ }
+ }
+ return Status::OK();
+}
+} // namespace
class RegexReplaceOp : public OpKernel {
public:
@@ -30,10 +65,6 @@ class RegexReplaceOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
- const Tensor* input_tensor;
- OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
- const auto& input_flat = input_tensor->flat<string>();
-
const Tensor* pattern_tensor;
OP_REQUIRES_OK(ctx, ctx->input("pattern", &pattern_tensor));
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(pattern_tensor->shape()),
@@ -51,19 +82,7 @@ class RegexReplaceOp : public OpKernel {
errors::InvalidArgument("Rewrite must be scalar, but received ",
rewrite_tensor->shape().DebugString()));
const string rewrite = rewrite_tensor->flat<string>()(0);
-
- Tensor* output_tensor = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
- &output_tensor));
- auto output_flat = output_tensor->flat<string>();
- for (size_t i = 0; i < input_flat.size(); ++i) {
- output_flat(i) = input_flat(i);
- if (replace_global_) {
- RE2::GlobalReplace(&output_flat(i), match, rewrite);
- } else {
- RE2::Replace(&output_flat(i), match, rewrite);
- }
- }
+ OP_REQUIRES_OK(ctx, InternalCompute(match, rewrite, replace_global_, ctx));
}
private:
@@ -73,4 +92,31 @@ class RegexReplaceOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("RegexReplace").Device(DEVICE_CPU),
RegexReplaceOp);
+class StaticRegexReplaceOp : public OpKernel {
+ public:
+ explicit StaticRegexReplaceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string pattern;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("pattern", &pattern));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("rewrite", &rewrite_str_));
+ re_ = MakeUnique<RE2>(pattern);
+ OP_REQUIRES(ctx, re_->ok(),
+ errors::InvalidArgument("Invalid pattern: ", pattern,
+ ", error: ", re_->error()));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("replace_global", &replace_global_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ OP_REQUIRES_OK(ctx,
+ InternalCompute(*re_, rewrite_str_, replace_global_, ctx));
+ }
+
+ private:
+ string rewrite_str_;
+ std::unique_ptr<RE2> re_;
+ bool replace_global_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("StaticRegexReplace").Device(DEVICE_CPU),
+ StaticRegexReplaceOp);
+
} // namespace tensorflow
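InternalCompute above first asks the runtime, via ctx->forward_input(),
whether the input buffer can be donated to the output; only when forwarding
fails does it allocate a fresh output and copy. A standalone sketch of the
same ownership decision (shared_ptr stands in for TF's refcounted buffers,
and the character swap is a trivial stand-in for RE2 replacement):

#include <memory>
#include <string>
#include <vector>

using Batch = std::vector<std::string>;

std::shared_ptr<Batch> RewriteBatch(std::shared_ptr<Batch> input) {
  std::shared_ptr<Batch> output;
  if (input.use_count() == 1) {
    output = std::move(input);                 // "forwarded": rewrite in place
  } else {
    output = std::make_shared<Batch>(*input);  // another reader exists: copy
  }
  for (std::string& s : *output) {
    for (char& c : s) {
      if (c == ',') c = ' ';  // stand-in for RE2::GlobalReplace
    }
  }
  return output;
}

Callers that pass std::move(batch) while holding the only reference get the
in-place path, mirroring how the op avoids the copy when its input tensor is
not shared.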
diff --git a/tensorflow/core/kernels/regex_replace_op_test.cc b/tensorflow/core/kernels/regex_replace_op_test.cc
new file mode 100644
index 0000000000..9691d4a89f
--- /dev/null
+++ b/tensorflow/core/kernels/regex_replace_op_test.cc
@@ -0,0 +1,137 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+
+// Test data from the TensorFlow README.md.
+const char* lines[] = {
+ "**TensorFlow** is an open source software library for numerical "
+ "computation using data flow graphs.",
+ "The graph nodes represent mathematical operations, while the graph edges "
+ "represent the multidimensional data arrays (tensors) that flow between "
+ "them.",
+ "This flexible architecture enables you to deploy computation to one or "
+ "more CPUs or GPUs in a desktop, server, or mobile device without "
+ "rewriting code.",
+ "TensorFlow also includes "
+ "[TensorBoard](https://www.tensorflow.org/guide/"
+ "summaries_and_tensorboard), a data visualization toolkit.",
+ "TensorFlow was originally developed by researchers and engineers working "
+ "on the Google Brain team within Google's Machine Intelligence Research "
+ "organization for the purposes of conducting machine learning and deep "
+ "neural networks research.",
+ "The system is general enough to be applicable in a wide variety of other "
+ "domains, as well.",
+    "TensorFlow provides stable Python and C APIs, as well as APIs without "
+    "backwards-compatibility guarantees, for languages like C++, Go, Java, "
+    "JavaScript and Swift."};
+
+const char kRegExPattern[] = "\\p{P}";
+const char kRewrite[] = " ";
+
+Tensor GetTestTensor(int batch) {
+ const int sz = TF_ARRAYSIZE(lines);
+ Tensor t(DT_STRING, {batch});
+ auto s = t.flat<string>();
+ for (int i = 0; i < batch; ++i) {
+ s(i) = lines[i % sz];
+ }
+ return t;
+}
+
+Graph* SetupRegexReplaceGraph(const Tensor& input, const string& input_pattern,
+ const string& input_rewrite) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor pattern(DT_STRING, TensorShape({}));
+ pattern.flat<string>().setConstant(input_pattern);
+ Tensor rewrite(DT_STRING, TensorShape({}));
+ rewrite.flat<string>().setConstant(input_rewrite);
+
+ TF_CHECK_OK(NodeBuilder("regex_replace_op", "RegexReplace")
+ .Input(test::graph::Constant(g, input))
+ .Input(test::graph::Constant(g, pattern))
+ .Input(test::graph::Constant(g, rewrite))
+ .Attr("replace_global", true)
+ .Finalize(g, nullptr /* node */));
+ return g;
+}
+
+void BM_RegexReplace(int iters, int batch_size) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters));
+ testing::UseRealTime();
+ Tensor input = GetTestTensor(batch_size);
+ Graph* g = SetupRegexReplaceGraph(input, kRegExPattern, kRewrite);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+BENCHMARK(BM_RegexReplace)
+ ->Arg(1)
+ ->Arg(8)
+ ->Arg(16)
+ ->Arg(32)
+ ->Arg(64)
+ ->Arg(128)
+ ->Arg(256);
+
+Graph* SetupStaticGraph(const Tensor& input, const string& input_pattern,
+ const string& rewrite) {
+ Graph* g = new Graph(OpRegistry::Global());
+
+ TF_CHECK_OK(NodeBuilder("static_regex_replace_op", "StaticRegexReplace")
+ .Attr("pattern", input_pattern)
+ .Attr("rewrite", rewrite)
+ .Input(test::graph::Constant(g, input))
+ .Attr("replace_global", true)
+ .Finalize(g, nullptr /* node */));
+ return g;
+}
+void BM_StaticRegexReplace(int iters, int batch_size) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters));
+ testing::UseRealTime();
+ Tensor input = GetTestTensor(batch_size);
+ Graph* g = SetupStaticGraph(input, kRegExPattern, kRewrite);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+BENCHMARK(BM_StaticRegexReplace)
+ ->Arg(1)
+ ->Arg(8)
+ ->Arg(16)
+ ->Arg(32)
+ ->Arg(64)
+ ->Arg(128)
+ ->Arg(256);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc
index cafa49cbb6..e67695d54a 100644
--- a/tensorflow/core/kernels/relu_op.cc
+++ b/tensorflow/core/kernels/relu_op.cc
@@ -143,6 +143,12 @@ namespace functor {
typename TTypes<T>::Tensor backprops); \
extern template struct SeluGrad<GPUDevice, T>;
+template <>
+void Relu<GPUDevice, qint8>::operator()(
+ const GPUDevice& d, typename TTypes<qint8>::ConstTensor features,
+ typename TTypes<qint8>::Tensor activations);
+extern template struct Relu<GPUDevice, qint8>;
+
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
} // namespace functor
@@ -182,6 +188,27 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
+template <typename Device>
+class ReluOp<Device, qint8>
+ : public UnaryElementWiseOp<qint8, ReluOp<Device, qint8>> {
+ public:
+ using UnaryElementWiseOp<qint8, ReluOp<Device, qint8>>::UnaryElementWiseOp;
+
+ void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+ auto flat_input = input.flat<qint8>();
+ OP_REQUIRES(context, (flat_input.size() % 4) == 0,
+ errors::InvalidArgument(
+ "Tensor size must be a multiple of 4 for Relu<qint8>. Got ",
+ flat_input.size()));
+ functor::Relu<Device, qint8> func;
+ func(context->eigen_device<Device>(), flat_input, output->flat<qint8>());
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("Relu").Device(DEVICE_GPU).TypeConstraint<qint8>("T"),
+ ReluOp<GPUDevice, qint8>);
+
#endif // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h
index fa79ab03ae..a4638c70c2 100644
--- a/tensorflow/core/kernels/relu_op.h
+++ b/tensorflow/core/kernels/relu_op.h
@@ -15,8 +15,8 @@ limitations under the License.
// See docs in ../ops/nn_ops.cc.
-#ifndef TENSORFLOW_KERNELS_RELU_OP_H_
-#define TENSORFLOW_KERNELS_RELU_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_RELU_OP_H_
+#define TENSORFLOW_CORE_KERNELS_RELU_OP_H_
#define EIGEN_USE_THREADS
@@ -280,4 +280,4 @@ void SeluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
#undef EIGEN_USE_THREADS
-#endif // TENSORFLOW_KERNELS_RELU_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_RELU_OP_H_
diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h
index 548d5a277d..f917142a12 100644
--- a/tensorflow/core/kernels/relu_op_functor.h
+++ b/tensorflow/core/kernels/relu_op_functor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_RELU_OP_FUNCTOR_H_
-#define TENSORFLOW_KERNELS_RELU_OP_FUNCTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_RELU_OP_FUNCTOR_H_
+#define TENSORFLOW_CORE_KERNELS_RELU_OP_FUNCTOR_H_
// Functor definition for ReluOp and ReluGradOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -198,4 +198,4 @@ struct SeluGrad {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_RELU_OP_FUNCTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_RELU_OP_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc
index 4452f4dcc9..dd5f9495e2 100644
--- a/tensorflow/core/kernels/relu_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc
@@ -103,7 +103,7 @@ struct ReluGrad<Device, Eigen::half> {
int32 count = gradient.size();
if (count == 0) return;
int32 half2_count = Eigen::divup(count, 2);
- const int32 kThreadInBlock = 512;
+ constexpr int32 kThreadInBlock = 512;
CudaLaunchConfig config = GetCudaLaunchConfigFixedBlockSize(
half2_count, d, ReluGradHalfKernel, 0, kThreadInBlock);
ReluGradHalfKernel<<<config.block_count, config.thread_per_block, 0,
@@ -111,6 +111,37 @@ struct ReluGrad<Device, Eigen::half> {
backprop.data(), count);
}
};
+
+__global__ void Relu_int8x4_kernel(int vect_count, const int32* input,
+ int32* output) {
+ CUDA_1D_KERNEL_LOOP(index, vect_count) {
+ output[index] = __vmaxs4(input[index], 0);
+ }
+}
+
+// Functor used by ReluOp to do the computations.
+template <typename Device>
+struct Relu<Device, qint8> {
+ // Computes Relu activation of 'input' containing int8 elements, whose buffer
+ // size should be a multiple of 4, and aligned to an int32* boundary.
+ // (Alignment should be guaranteed by the GPU tensor allocator).
+ // 'output' should have the same size as 'input'.
+ void operator()(const Device& d, typename TTypes<qint8>::ConstTensor input,
+ typename TTypes<qint8>::Tensor output) {
+ int32 count = input.size();
+ if (count == 0) return;
+
+ int32 vect_count = Eigen::divup(count, 4);
+ constexpr int32 kThreadInBlock = 512;
+ CudaLaunchConfig config = GetCudaLaunchConfigFixedBlockSize(
+ vect_count, d, Relu_int8x4_kernel, 0, kThreadInBlock);
+ Relu_int8x4_kernel<<<config.block_count, config.thread_per_block, 0,
+ d.stream()>>>(
+ vect_count, reinterpret_cast<const int32*>(input.data()),
+ reinterpret_cast<int32*>(output.data()));
+ }
+};
+
} // namespace functor
// Definition of the GPU implementations declared in relu_op.cc.
@@ -128,6 +159,8 @@ struct ReluGrad<Device, Eigen::half> {
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
+template struct functor::Relu<GPUDevice, qint8>;
+
} // end namespace tensorflow
#endif // GOOGLE_CUDA
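Relu_int8x4_kernel above leans on the CUDA __vmaxs4 intrinsic, a per-byte
signed max against zero over the four int8 lanes packed into each 32-bit
word, which is exactly an elementwise ReLU on packed qint8 data and is why
the input size must be a multiple of 4. A scalar host-side reference of the
lane behaviour (for illustration only; not part of the patch):

#include <cstdint>
#include <cstdio>

// Per-lane relu on four signed bytes packed in one 32-bit word, matching
// what __vmaxs4(x, 0) computes on the GPU.
uint32_t Vmaxs4Zero(uint32_t packed) {
  uint32_t out = 0;
  for (int lane = 0; lane < 4; ++lane) {
    const int8_t v = static_cast<int8_t>((packed >> (8 * lane)) & 0xff);
    const uint8_t r = static_cast<uint8_t>(v > 0 ? v : 0);  // relu per lane
    out |= static_cast<uint32_t>(r) << (8 * lane);
  }
  return out;
}

int main() {
  // Lanes (low byte first): {-1, 5, -128, 127} -> {0, 5, 0, 127}.
  std::printf("%08x\n", static_cast<unsigned>(Vmaxs4Zero(0x7F8005FFu)));
  // Prints 7f000500.
  return 0;
}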
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
index 194a711d98..26f107f940 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
@@ -47,7 +47,7 @@ std::unordered_set<string> BuildNodeSetFromNodeNamesAndPorts(
std::unordered_set<string> retval;
for (const string& node_name_and_port : node_names_and_ports) {
const TensorId tid = ParseTensorName(node_name_and_port);
- retval.emplace(std::string(tid.first));
+ retval.emplace(tid.first);
}
return retval;
}
@@ -64,7 +64,7 @@ Node* FindMutableNodeByName(const string& name, Graph* graph) {
const NodeDef* FindNodeDefByName(const string& input,
const GraphDef& graph_def) {
const TensorId tid = ParseTensorName(input);
- const string name = std::string(tid.first);
+ const string name = string(tid.first);
for (const NodeDef& node_def : graph_def.node()) {
if (node_def.name() == name) {
return &node_def;
@@ -423,7 +423,7 @@ RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap(
std::vector<DataType> data_types;
std::vector<TensorShape> shapes;
const TensorId tid = ParseTensorName(name_and_port);
- const string node_name = std::string(tid.first);
+ const string node_name(tid.first);
const int port = tid.second;
const NodeDef* node_def = FindNodeDefByName(node_name, graph_def);
CHECK_NOTNULL(node_def);
@@ -522,8 +522,7 @@ RemoteFusedGraphExecuteUtils::GetTensorShapeType(
const TensorShapeMap& tensor_shape_map, const string& node_name) {
if (node_name.find(':') != string::npos) {
const TensorId tid = ParseTensorName(node_name);
- return GetTensorShapeType(tensor_shape_map, std::string(tid.first),
- tid.second);
+ return GetTensorShapeType(tensor_shape_map, string(tid.first), tid.second);
} else {
return GetTensorShapeType(tensor_shape_map, node_name, 0);
}
@@ -570,7 +569,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteGraphInputsAndOutputsFromProto(
const TensorId tid = ParseTensorName(name);
CHECK_EQ(tensor_shape_map->count(name), 0);
tensor_shape_map->emplace(
- std::string(tid.first),
+ string(tid.first),
std::make_pair(tid.second,
std::make_pair(tensor.dtype(), tensor.shape())));
}
@@ -692,7 +691,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
std::vector<NodeBuilder::NodeOut> node_out_list;
for (const string& input : inputs) {
const TensorId tid = ParseTensorName(input);
- Node* node = FindMutableNodeByName(std::string(tid.first), graph);
+ Node* node = FindMutableNodeByName(string(tid.first), graph);
CHECK_NOTNULL(node);
node_out_list.emplace_back(node, tid.second);
}
@@ -848,7 +847,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
for (const string& subgraph_input : std::get<1>(cluster)) {
const TensorId tid = ParseTensorName(subgraph_input);
- const string subgraph_input_name = std::string(tid.first);
+ const string subgraph_input_name(tid.first);
const int subgraph_input_port = tid.second;
const NodeDef* node_def = FindNodeDefByName(subgraph_input_name, graph_def);
CHECK_NOTNULL(node_def);
@@ -895,7 +894,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
std::deque<const Node*> queue;
for (const string& output : border_outputs) {
const TensorId tid = ParseTensorName(output);
- const string& output_node_name = std::string(tid.first);
+ const string output_node_name(tid.first);
for (const Node* node : graph.nodes()) {
if (output_node_name == node->name()) {
queue.push_back(node);
@@ -975,7 +974,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
for (int j = 0; j < border_outputs.size(); ++j) {
const string& output = border_outputs.at(j);
const TensorId tid = ParseTensorName(output);
- const string output_name = std::string(tid.first);
+ const string output_name(tid.first);
Node* src_node = edge->src();
if (src_node != nullptr && src_node->name() == output_name &&
edge->src_output() == tid.second) {
@@ -995,12 +994,11 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
// RemoteFusedGraphExecuteOpNode
for (const string& output : outputs) {
const TensorId output_tid = ParseTensorName(output);
- const string output_name = std::string(output_tid.first);
+ const string output_name(output_tid.first);
for (size_t i = 0; i < border_outputs.size(); ++i) {
const TensorId subgraph_output_tid =
ParseTensorName(border_outputs.at(i));
- const string& subgraph_output_name =
- std::string(subgraph_output_tid.first);
+ const string subgraph_output_name(subgraph_output_tid.first);
if (output_name == subgraph_output_name) {
LOG(INFO) << "As graph output and subgraph output are same, "
<< "the graph output node is replaced by identity node";
@@ -1435,7 +1433,7 @@ RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions(
GraphDef* graph_def) {
const TensorId tid = ParseTensorName(input);
CHECK_EQ(0, tid.second);
- const string node_name = std::string(tid.first);
+ const string node_name(tid.first);
for (NodeDef& node : *graph_def->mutable_node()) {
if (node.name() != node_name) {
continue;
diff --git a/tensorflow/core/kernels/reshape_op.h b/tensorflow/core/kernels/reshape_op.h
index 5db2d148b9..7458ac75ca 100644
--- a/tensorflow/core/kernels/reshape_op.h
+++ b/tensorflow/core/kernels/reshape_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_RESHAPE_OP_H_
-#define TENSORFLOW_KERNELS_RESHAPE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_RESHAPE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_RESHAPE_OP_H_
#include <memory>
#include "tensorflow/core/framework/op_kernel.h"
@@ -121,4 +121,4 @@ class ReshapeOp : public OpKernel {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_RESHAPE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_RESHAPE_OP_H_
diff --git a/tensorflow/core/kernels/resize_bilinear_op.cc b/tensorflow/core/kernels/resize_bilinear_op.cc
index dde59e8e74..f10c9a19a7 100644
--- a/tensorflow/core/kernels/resize_bilinear_op.cc
+++ b/tensorflow/core/kernels/resize_bilinear_op.cc
@@ -277,13 +277,13 @@ struct ResizeBilinearGrad<CPUDevice, T> {
typename TTypes<float, 4>::ConstTensor input_grad,
const float height_scale, const float width_scale,
typename TTypes<T, 4>::Tensor output_grad) {
- const int batch = output_grad.dimension(0);
- const int64 original_height = output_grad.dimension(1);
- const int64 original_width = output_grad.dimension(2);
- const int channels = output_grad.dimension(3);
+ const Eigen::Index batch = output_grad.dimension(0);
+ const Eigen::Index original_height = output_grad.dimension(1);
+ const Eigen::Index original_width = output_grad.dimension(2);
+ const Eigen::Index channels = output_grad.dimension(3);
- const int64 resized_height = input_grad.dimension(1);
- const int64 resized_width = input_grad.dimension(2);
+ const Eigen::Index resized_height = input_grad.dimension(1);
+ const Eigen::Index resized_width = input_grad.dimension(2);
output_grad.setZero();
@@ -294,22 +294,24 @@ struct ResizeBilinearGrad<CPUDevice, T> {
// + top_right * (1 - y) * x
// + bottom_left * y * (1 - x)
// + bottom_right * y * x
- for (int64 b = 0; b < batch; ++b) {
- for (int64 y = 0; y < resized_height; ++y) {
+ for (Eigen::Index b = 0; b < batch; ++b) {
+ for (Eigen::Index y = 0; y < resized_height; ++y) {
const float in_y = y * height_scale;
- const int64 top_y_index = static_cast<int64>(floorf(in_y));
- const int64 bottom_y_index =
- std::min(static_cast<int64>(ceilf(in_y)), original_height - 1);
+ const Eigen::Index top_y_index =
+ static_cast<Eigen::Index>(floorf(in_y));
+ const Eigen::Index bottom_y_index = std::min(
+ static_cast<Eigen::Index>(ceilf(in_y)), original_height - 1);
const float y_lerp = in_y - top_y_index;
const float inverse_y_lerp = (1.0f - y_lerp);
- for (int64 x = 0; x < resized_width; ++x) {
+ for (Eigen::Index x = 0; x < resized_width; ++x) {
const float in_x = x * width_scale;
- const int64 left_x_index = static_cast<int64>(floorf(in_x));
- const int64 right_x_index =
- std::min(static_cast<int64>(ceilf(in_x)), original_width - 1);
+ const Eigen::Index left_x_index =
+ static_cast<Eigen::Index>(floorf(in_x));
+ const Eigen::Index right_x_index = std::min(
+ static_cast<Eigen::Index>(ceilf(in_x)), original_width - 1);
const float x_lerp = in_x - left_x_index;
const float inverse_x_lerp = (1.0f - x_lerp);
- for (int64 c = 0; c < channels; ++c) {
+ for (Eigen::Index c = 0; c < channels; ++c) {
output_grad(b, top_y_index, left_x_index, c) +=
T(input_grad(b, y, x, c) * inverse_y_lerp * inverse_x_lerp);
output_grad(b, top_y_index, right_x_index, c) +=
diff --git a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
index 8ec526c2b2..e985d3e5a5 100644
--- a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
+++ b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
@@ -88,25 +88,27 @@ struct ResizeNearestNeighbor<CPUDevice, T, align_corners> {
bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
const float height_scale, const float width_scale,
typename TTypes<T, 4>::Tensor output) {
- const int batch_size = input.dimension(0);
- const int64 in_height = input.dimension(1);
- const int64 in_width = input.dimension(2);
- const int channels = input.dimension(3);
-
- const int64 out_height = output.dimension(1);
- const int64 out_width = output.dimension(2);
-
- for (int b = 0; b < batch_size; ++b) {
- for (int y = 0; y < out_height; ++y) {
- const int64 in_y = std::min(
- (align_corners) ? static_cast<int64>(roundf(y * height_scale))
- : static_cast<int64>(floorf(y * height_scale)),
- in_height - 1);
- for (int x = 0; x < out_width; ++x) {
- const int64 in_x = std::min(
- (align_corners) ? static_cast<int64>(roundf(x * width_scale))
- : static_cast<int64>(floorf(x * width_scale)),
- in_width - 1);
+ const Eigen::Index batch_size = input.dimension(0);
+ const Eigen::Index in_height = input.dimension(1);
+ const Eigen::Index in_width = input.dimension(2);
+ const Eigen::Index channels = input.dimension(3);
+
+ const Eigen::Index out_height = output.dimension(1);
+ const Eigen::Index out_width = output.dimension(2);
+
+ for (Eigen::Index b = 0; b < batch_size; ++b) {
+ for (Eigen::Index y = 0; y < out_height; ++y) {
+ const Eigen::Index in_y =
+ std::min((align_corners)
+ ? static_cast<Eigen::Index>(roundf(y * height_scale))
+ : static_cast<Eigen::Index>(floorf(y * height_scale)),
+ in_height - 1);
+ for (Eigen::Index x = 0; x < out_width; ++x) {
+ const Eigen::Index in_x =
+ std::min((align_corners)
+ ? static_cast<Eigen::Index>(roundf(x * width_scale))
+ : static_cast<Eigen::Index>(floorf(x * width_scale)),
+ in_width - 1);
std::copy_n(&input(b, in_y, in_x, 0), channels, &output(b, y, x, 0));
}
}
@@ -199,28 +201,29 @@ struct ResizeNearestNeighborGrad<CPUDevice, T, align_corners> {
bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
const float height_scale, const float width_scale,
typename TTypes<T, 4>::Tensor output) {
- const int batch_size = input.dimension(0);
- const int64 in_height = input.dimension(1);
- const int64 in_width = input.dimension(2);
- const int channels = input.dimension(3);
+ const Eigen::Index batch_size = input.dimension(0);
+ const Eigen::Index in_height = input.dimension(1);
+ const Eigen::Index in_width = input.dimension(2);
+ const Eigen::Index channels = input.dimension(3);
- const int64 out_height = output.dimension(1);
- const int64 out_width = output.dimension(2);
+ const Eigen::Index out_height = output.dimension(1);
+ const Eigen::Index out_width = output.dimension(2);
output.setZero();
- for (int y = 0; y < in_height; ++y) {
- const int64 out_y = std::min(
- (align_corners) ? static_cast<int64>(roundf(y * height_scale))
- : static_cast<int64>(floorf(y * height_scale)),
+ for (Eigen::Index y = 0; y < in_height; ++y) {
+ const Eigen::Index out_y = std::min(
+ (align_corners) ? static_cast<Eigen::Index>(roundf(y * height_scale))
+ : static_cast<Eigen::Index>(floorf(y * height_scale)),
out_height - 1);
- for (int x = 0; x < in_width; ++x) {
- const int64 out_x = std::min(
- (align_corners) ? static_cast<int64>(roundf(x * width_scale))
- : static_cast<int64>(floorf(x * width_scale)),
- out_width - 1);
- for (int b = 0; b < batch_size; ++b) {
- for (int c = 0; c < channels; ++c) {
+ for (Eigen::Index x = 0; x < in_width; ++x) {
+ const Eigen::Index out_x =
+ std::min((align_corners)
+ ? static_cast<Eigen::Index>(roundf(x * width_scale))
+ : static_cast<Eigen::Index>(floorf(x * width_scale)),
+ out_width - 1);
+ for (Eigen::Index b = 0; b < batch_size; ++b) {
+ for (Eigen::Index c = 0; c < channels; ++c) {
output(b, out_y, out_x, c) += input(b, y, x, c);
}
}
diff --git a/tensorflow/core/kernels/reverse_op.h b/tensorflow/core/kernels/reverse_op.h
index 934f0277a9..44e7967c5d 100644
--- a/tensorflow/core/kernels/reverse_op.h
+++ b/tensorflow/core/kernels/reverse_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_REVERSE_OP_H_
-#define TENSORFLOW_KERNELS_REVERSE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_REVERSE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_REVERSE_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -45,4 +45,4 @@ struct Reverse<Device, T, 0> {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_MIRROR_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_REVERSE_OP_H_
diff --git a/tensorflow/core/kernels/reverse_sequence_op.h b/tensorflow/core/kernels/reverse_sequence_op.h
index 8ccd32ea16..d6ba2781a9 100644
--- a/tensorflow/core/kernels/reverse_sequence_op.h
+++ b/tensorflow/core/kernels/reverse_sequence_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_REVERSE_SEQUENCE_OP_H_
-#define TENSORFLOW_KERNELS_REVERSE_SEQUENCE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_REVERSE_SEQUENCE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_REVERSE_SEQUENCE_OP_H_
// Generator definition for ReverseSequenceOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -75,4 +75,4 @@ struct ReverseSequence {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_REVERSE_SEQUENCE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_REVERSE_SEQUENCE_OP_H_
diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc
index e335e38bdc..82546d581a 100644
--- a/tensorflow/core/kernels/save_restore_tensor.cc
+++ b/tensorflow/core/kernels/save_restore_tensor.cc
@@ -161,9 +161,12 @@ void RestoreTensor(OpKernelContext* context,
// If we cannot find a cached reader we will allocate our own.
std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader;
- const checkpoint::TensorSliceReader* reader =
- context->slice_reader_cache()->GetReader(file_pattern, open_func,
- preferred_shard);
+ const checkpoint::TensorSliceReader* reader = nullptr;
+
+ if (context->slice_reader_cache()) {
+ reader = context->slice_reader_cache()->GetReader(file_pattern, open_func,
+ preferred_shard);
+ }
if (!reader) {
allocated_reader.reset(new checkpoint::TensorSliceReader(
file_pattern, open_func, preferred_shard));
diff --git a/tensorflow/core/kernels/save_restore_tensor.h b/tensorflow/core/kernels/save_restore_tensor.h
index 5b74b586e8..be7f4b889e 100644
--- a/tensorflow/core/kernels/save_restore_tensor.h
+++ b/tensorflow/core/kernels/save_restore_tensor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SAVE_RESTORE_TENSOR_H_
-#define TENSORFLOW_KERNELS_SAVE_RESTORE_TENSOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SAVE_RESTORE_TENSOR_H_
+#define TENSORFLOW_CORE_KERNELS_SAVE_RESTORE_TENSOR_H_
#include "tensorflow/core/util/tensor_slice_reader.h"
#include "tensorflow/core/util/tensor_slice_writer.h"
@@ -70,4 +70,4 @@ Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SAVE_RESTORE_TENSOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_SAVE_RESTORE_TENSOR_H_
diff --git a/tensorflow/core/kernels/save_restore_v2_ops.cc b/tensorflow/core/kernels/save_restore_v2_ops.cc
index ab4de6c815..180eb3ca34 100644
--- a/tensorflow/core/kernels/save_restore_v2_ops.cc
+++ b/tensorflow/core/kernels/save_restore_v2_ops.cc
@@ -220,9 +220,9 @@ class MergeV2Checkpoints : public OpKernel {
context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix));
if (delete_old_dirs_) {
- const string& merged_dir = std::string(io::Dirname(merged_prefix));
+ const string merged_dir(io::Dirname(merged_prefix));
for (const string& input_prefix : input_prefixes) {
- const string& dirname = std::string(io::Dirname(input_prefix));
+ const string dirname(io::Dirname(input_prefix));
if (dirname == merged_dir) continue;
Status status = env->DeleteDir(dirname);
// For sharded save, only the first delete will go through and all
diff --git a/tensorflow/core/kernels/scan_ops.h b/tensorflow/core/kernels/scan_ops.h
index 1a1f71d722..13831bb377 100644
--- a/tensorflow/core/kernels/scan_ops.h
+++ b/tensorflow/core/kernels/scan_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SCAN_OPS_H_
-#define TENSORFLOW_KERNELS_SCAN_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SCAN_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_SCAN_OPS_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -43,4 +43,4 @@ struct Scan {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SCAN_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_SCAN_OPS_H_
diff --git a/tensorflow/core/kernels/scatter_functor.h b/tensorflow/core/kernels/scatter_functor.h
index ebaa2bd9c6..2d43bde23f 100644
--- a/tensorflow/core/kernels/scatter_functor.h
+++ b/tensorflow/core/kernels/scatter_functor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SCATTER_FUNCTOR_H_
-#define TENSORFLOW_KERNELS_SCATTER_FUNCTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_
+#define TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_
#include <type_traits>
@@ -488,4 +488,4 @@ struct ScatterScalarFunctorSYCL {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SCATTER_FUNCTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/scatter_functor_gpu.cu.h b/tensorflow/core/kernels/scatter_functor_gpu.cu.h
index 70809e4dcf..057755a05c 100644
--- a/tensorflow/core/kernels/scatter_functor_gpu.cu.h
+++ b/tensorflow/core/kernels/scatter_functor_gpu.cu.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SCATTER_FUNCTOR_GPU_CU_H_
-#define TENSORFLOW_KERNELS_SCATTER_FUNCTOR_GPU_CU_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_GPU_CU_H_
+#define TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_GPU_CU_H_
#if GOOGLE_CUDA
@@ -161,4 +161,4 @@ struct ScatterScalarFunctor<GPUDevice, T, Index, op> {
#endif // GOOGLE_CUDA
-#endif // TENSORFLOW_KERNELS_SCATTER_FUNCTOR_GPU_CU_H_
+#endif // TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_GPU_CU_H_
diff --git a/tensorflow/core/kernels/sdca_internal.cc b/tensorflow/core/kernels/sdca_internal.cc
index 1c071d3d41..a8e9b3261c 100644
--- a/tensorflow/core/kernels/sdca_internal.cc
+++ b/tensorflow/core/kernels/sdca_internal.cc
@@ -251,7 +251,7 @@ Status Examples::SampleAdaptiveProbabilities(
num_weight_vectors);
const double kappa = example_state_data(example_id, 0) +
loss_updater->PrimalLossDerivative(
- example_statistics.wx[0], label, example_weight);
+ example_statistics.wx[0], label, 1.0);
probabilities_[example_id] = example_weight *
sqrt(examples_[example_id].squared_norm_ +
regularization.symmetric_l2() *
diff --git a/tensorflow/core/kernels/sdca_ops.cc b/tensorflow/core/kernels/sdca_ops.cc
index 05c835ebc4..3bd4168dc7 100644
--- a/tensorflow/core/kernels/sdca_ops.cc
+++ b/tensorflow/core/kernels/sdca_ops.cc
@@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/kernels/hinge-loss.h"
#include "tensorflow/core/kernels/logistic-loss.h"
#include "tensorflow/core/kernels/loss.h"
+#include "tensorflow/core/kernels/poisson-loss.h"
#include "tensorflow/core/kernels/sdca_internal.h"
#include "tensorflow/core/kernels/smooth-hinge-loss.h"
#include "tensorflow/core/kernels/squared-loss.h"
@@ -75,6 +76,8 @@ struct ComputeOptions {
loss_updater.reset(new HingeLossUpdater);
} else if (loss_type == "smooth_hinge_loss") {
loss_updater.reset(new SmoothHingeLossUpdater);
+ } else if (loss_type == "poisson_loss") {
+ loss_updater.reset(new PoissonLossUpdater);
} else {
OP_REQUIRES(
context, false,
diff --git a/tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h b/tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h
index 271dd2c485..b5274f8788 100644
--- a/tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h
+++ b/tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_SELF_ADJOINT_EIG_V2_OP_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_SELF_ADJOINT_EIG_V2_OP_IMPL_H_
+
// See docs in ../ops/linalg_ops.cc.
#include "third_party/eigen3/Eigen/Core"
@@ -85,3 +88,5 @@ class SelfAdjointEigV2Op : public LinearAlgebraOp<Scalar> {
};
} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_SELF_ADJOINT_EIG_V2_OP_IMPL_H_
diff --git a/tensorflow/core/kernels/sendrecv_ops.h b/tensorflow/core/kernels/sendrecv_ops.h
index 1ff8eff13f..223854de13 100644
--- a/tensorflow/core/kernels/sendrecv_ops.h
+++ b/tensorflow/core/kernels/sendrecv_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SENDRECV_OPS_H_
-#define TENSORFLOW_KERNELS_SENDRECV_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SENDRECV_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_SENDRECV_OPS_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/macros.h"
@@ -49,4 +49,4 @@ class RecvOp : public AsyncOpKernel {
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SENDRECV_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_SENDRECV_OPS_H_
diff --git a/tensorflow/core/kernels/set_kernels.cc b/tensorflow/core/kernels/set_kernels.cc
index f893d4e945..0428909145 100644
--- a/tensorflow/core/kernels/set_kernels.cc
+++ b/tensorflow/core/kernels/set_kernels.cc
@@ -269,7 +269,7 @@ void SetSizeOp<T>::Compute(OpKernelContext* ctx) {
// Group by all but last dimension, create a set of group values, and add set
// size to output.
- VarDimArray group_ix(set_st.order(), 0, set_st.order().size() - 1);
+ VarDimArray group_ix = set_st.order().subspan(0, set_st.order().size() - 1);
std::set<T> group_set;
for (const auto& group : set_st.group(group_ix)) {
PopulateFromSparseGroup<T>(ctx, group, set_st.shape(), &group_set);
@@ -500,8 +500,8 @@ void SetOperationOp<T>::ComputeDenseToSparse(OpKernelContext* ctx) const {
std::set<T> set1_group_set;
std::set<T> set2_group_set;
- auto set2_grouper = set2_st.group(
- VarDimArray(set2_st.order(), 0, set2_st.order().size() - 1));
+ auto set2_grouper =
+ set2_st.group(set2_st.order().subspan(0, set2_st.order().size() - 1));
auto set2_group_it = set2_grouper.begin();
std::vector<int64> group_indices;
int64 num_elements;
@@ -621,11 +621,11 @@ void SetOperationOp<T>::ComputeSparseToSparse(OpKernelContext* ctx) const {
std::set<T> set1_group_set;
std::set<T> set2_group_set;
- auto set1_grouper = set1_st.group(
- VarDimArray(set1_st.order(), 0, set1_st.order().size() - 1));
+ auto set1_grouper =
+ set1_st.group(set1_st.order().subspan(0, set1_st.order().size() - 1));
auto set1_group_it = set1_grouper.begin();
- auto set2_grouper = set2_st.group(
- VarDimArray(set2_st.order(), 0, set2_st.order().size() - 1));
+ auto set2_grouper =
+ set2_st.group(set2_st.order().subspan(0, set2_st.order().size() - 1));
auto set2_group_it = set2_grouper.begin();
// Group by rows, and iterate over rows of both sets in parallel, creating a
diff --git a/tensorflow/core/kernels/shape_ops.cc b/tensorflow/core/kernels/shape_ops.cc
index 28a39bae3f..ab1ce0f9c8 100644
--- a/tensorflow/core/kernels/shape_ops.cc
+++ b/tensorflow/core/kernels/shape_ops.cc
@@ -16,6 +16,7 @@ limitations under the License.
// See docs in ../ops/array_ops.cc.
#include "tensorflow/core/kernels/shape_ops.h"
+#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/register_types.h"
namespace tensorflow {
@@ -460,4 +461,96 @@ REGISTER_KERNEL_BUILDER(Name("Squeeze")
SqueezeOp);
#endif // TENSORFLOW_USE_SYCL
+class EnsureShapeOp : public OpKernel {
+ public:
+ explicit EnsureShapeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &expected_shape_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ TensorShape shape;
+ OP_REQUIRES_OK(ctx,
+ shape_op_helpers::GetRegularOrVariantShape(ctx, 0, &shape));
+
+ if (!expected_shape_.IsCompatibleWith(shape)) {
+ ctx->SetStatus(errors::InvalidArgument(
+ "Shape of tensor ", this->def().input(0), " ", shape.DebugString(),
+ " is not compatible with expected shape ",
+ expected_shape_.DebugString(), "."));
+ }
+
+ // If shape matches, outputs the tensor.
+ if (IsRefType(ctx->input_dtype(0))) {
+ ctx->forward_ref_input_to_ref_output(0, 0);
+ } else {
+ ctx->set_output(0, ctx->input(0));
+ }
+ }
+
+ bool IsExpensive() override { return false; }
+
+ private:
+ PartialTensorShape expected_shape_;
+};
+
+// NOTE(rachelim): The kernel registrations for EnsureShapeOp are identical to
+// those of the identity op, since the ops have the same device type
+// constraints.
+REGISTER_KERNEL_BUILDER(Name("EnsureShape").Device(DEVICE_CPU), EnsureShapeOp);
+
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("EnsureShape").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ EnsureShapeOp)
+
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+
+#undef REGISTER_SYCL_KERNEL
+
+#define REGISTER_SYCL_HOST_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("EnsureShape") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("input") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ EnsureShapeOp)
+
+REGISTER_SYCL_HOST_KERNEL(int32);
+REGISTER_SYCL_HOST_KERNEL(bool);
+
+#undef REGISTER_SYCL_HOST_KERNEL
+
+#endif // TENSORFLOW_USE_SYCL
+
+#define REGISTER_GPU_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("EnsureShape").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ EnsureShapeOp)
+
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
+REGISTER_GPU_KERNEL(Variant);
+
+#undef REGISTER_GPU_KERNEL
+
+#if GOOGLE_CUDA
+// A special GPU kernel for int32 and bool.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+#define REGISTER_GPU_HOST_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("EnsureShape") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("input") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ EnsureShapeOp)
+
+REGISTER_GPU_HOST_KERNEL(int32);
+REGISTER_GPU_HOST_KERNEL(bool);
+REGISTER_GPU_HOST_KERNEL(string);
+REGISTER_GPU_HOST_KERNEL(ResourceHandle);
+
+#undef REGISTER_GPU_HOST_KERNEL
+
+#endif
} // namespace tensorflow
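[Editorial note] The new kernel's check reduces to PartialTensorShape::IsCompatibleWith, where a dimension of -1 in the "shape" attr matches any actual extent. A minimal sketch, not part of the patch:

    PartialTensorShape expected({-1, 28, 28});  // -1 marks an unknown dimension
    TensorShape ok({32, 28, 28});    // compatible: first dim is unconstrained
    TensorShape bad({32, 28, 29});   // incompatible: 29 != 28
    CHECK(expected.IsCompatibleWith(ok));
    CHECK(!expected.IsCompatibleWith(bad));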
diff --git a/tensorflow/core/kernels/shape_ops.h b/tensorflow/core/kernels/shape_ops.h
index f75723af7d..7a50f158af 100644
--- a/tensorflow/core/kernels/shape_ops.h
+++ b/tensorflow/core/kernels/shape_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SHAPE_OPS_H_
-#define TENSORFLOW_KERNELS_SHAPE_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_
#include <limits>
#include <unordered_set>
@@ -274,4 +274,4 @@ class SqueezeOp : public OpKernel {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SHAPE_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_
diff --git a/tensorflow/core/kernels/slice_op.h b/tensorflow/core/kernels/slice_op.h
index db7eded745..1d662f6362 100644
--- a/tensorflow/core/kernels/slice_op.h
+++ b/tensorflow/core/kernels/slice_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SLICE_OP_H_
-#define TENSORFLOW_KERNELS_SLICE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SLICE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SLICE_OP_H_
// Functor definition for SliceOp, must be compilable by nvcc.
@@ -51,4 +51,4 @@ struct Slice {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SLICE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_SLICE_OP_H_
diff --git a/tensorflow/core/kernels/smooth-hinge-loss.h b/tensorflow/core/kernels/smooth-hinge-loss.h
index 5074ad0795..d51f5c130e 100644
--- a/tensorflow/core/kernels/smooth-hinge-loss.h
+++ b/tensorflow/core/kernels/smooth-hinge-loss.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SMOOTH_HINGE_LOSS_H_
-#define TENSORFLOW_KERNELS_SMOOTH_HINGE_LOSS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SMOOTH_HINGE_LOSS_H_
+#define TENSORFLOW_CORE_KERNELS_SMOOTH_HINGE_LOSS_H_
#include <limits>
@@ -110,5 +110,4 @@ class SmoothHingeLossUpdater : public DualLossUpdater {
} // namespace tensorflow
-#endif
-// TENSORFLOW_KERNELS_SMOOTH_HINGE_LOSS_H_
+#endif // TENSORFLOW_CORE_KERNELS_SMOOTH_HINGE_LOSS_H_
diff --git a/tensorflow/core/kernels/snapshot_op.h b/tensorflow/core/kernels/snapshot_op.h
index a18065d42b..02d492988e 100644
--- a/tensorflow/core/kernels/snapshot_op.h
+++ b/tensorflow/core/kernels/snapshot_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SNAPSHOT_OP_H_
-#define TENSORFLOW_KERNELS_SNAPSHOT_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SNAPSHOT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SNAPSHOT_OP_H_
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
@@ -41,4 +41,4 @@ struct Snapshot {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SNAPSHOT_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_SNAPSHOT_OP_H_
diff --git a/tensorflow/core/kernels/softmax_op_functor.h b/tensorflow/core/kernels/softmax_op_functor.h
index d3a267ed87..c8bc1ad3bb 100644
--- a/tensorflow/core/kernels/softmax_op_functor.h
+++ b/tensorflow/core/kernels/softmax_op_functor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SOFTMAX_OP_FUNCTOR_H_
-#define TENSORFLOW_KERNELS_SOFTMAX_OP_FUNCTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SOFTMAX_OP_FUNCTOR_H_
+#define TENSORFLOW_CORE_KERNELS_SOFTMAX_OP_FUNCTOR_H_
// Functor definition for SoftmaxOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -98,4 +98,4 @@ struct SoftmaxEigenImpl {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SOFTMAX_OP_FUNCTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_SOFTMAX_OP_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/softplus_op.cc b/tensorflow/core/kernels/softplus_op.cc
index 494a83ed14..d3fc0e1461 100644
--- a/tensorflow/core/kernels/softplus_op.cc
+++ b/tensorflow/core/kernels/softplus_op.cc
@@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/kernels/warn_about_ints.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@@ -35,9 +34,7 @@ template <typename Device, typename T>
class SoftplusOp : public UnaryElementWiseOp<T, SoftplusOp<Device, T>> {
public:
explicit SoftplusOp(OpKernelConstruction* context)
- : UnaryElementWiseOp<T, SoftplusOp<Device, T>>(context) {
- WarnAboutInts(context);
- }
+ : UnaryElementWiseOp<T, SoftplusOp<Device, T>>(context) {}
void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
functor::Softplus<Device, T> functor;
@@ -51,9 +48,7 @@ class SoftplusGradOp
: public BinaryElementWiseOp<T, SoftplusGradOp<Device, T>> {
public:
explicit SoftplusGradOp(OpKernelConstruction* context)
- : BinaryElementWiseOp<T, SoftplusGradOp<Device, T>>(context) {
- WarnAboutInts(context);
- }
+ : BinaryElementWiseOp<T, SoftplusGradOp<Device, T>>(context) {}
void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
const Tensor& a, Tensor* output);
@@ -89,7 +84,7 @@ void SoftplusGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
Name("SoftplusGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
SoftplusGradOp<CPUDevice, type>);
-TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+TF_CALL_FLOAT_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#if GOOGLE_CUDA
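[Editorial note] With the runtime warning removed, the registration macro above is what actually prevents integer instantiations. A gloss, assuming the type-list macros in register_types.h behave as their names suggest:

    // Before: TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS)
    //   -> instantiated floating-point *and* integer kernels, the accidental
    //      registrations that WarnAboutInts existed to flag at runtime.
    // After:  TF_CALL_FLOAT_TYPES(REGISTER_KERNELS)
    //   -> instantiates floating-point kernels only, so the integer variants
    //      disappear at compile time and the warning becomes dead code.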
diff --git a/tensorflow/core/kernels/softplus_op.h b/tensorflow/core/kernels/softplus_op.h
index e17e175d41..8c083ba158 100644
--- a/tensorflow/core/kernels/softplus_op.h
+++ b/tensorflow/core/kernels/softplus_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SOFTPLUS_OP_H_
-#define TENSORFLOW_KERNELS_SOFTPLUS_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SOFTPLUS_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SOFTPLUS_OP_H_
// Functor definition for SoftplusOp and SoftplusGradOp, must be compilable by
// nvcc.
@@ -73,4 +73,4 @@ struct SoftplusGrad {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SOFTPLUS_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_SOFTPLUS_OP_H_
diff --git a/tensorflow/core/kernels/softsign_op.cc b/tensorflow/core/kernels/softsign_op.cc
index 00ee649b17..d691f15651 100644
--- a/tensorflow/core/kernels/softsign_op.cc
+++ b/tensorflow/core/kernels/softsign_op.cc
@@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/kernels/warn_about_ints.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@@ -35,9 +34,7 @@ template <typename Device, typename T>
class SoftsignOp : public UnaryElementWiseOp<T, SoftsignOp<Device, T>> {
public:
explicit SoftsignOp(OpKernelConstruction* context)
- : UnaryElementWiseOp<T, SoftsignOp<Device, T>>(context) {
- WarnAboutInts(context);
- }
+ : UnaryElementWiseOp<T, SoftsignOp<Device, T>>(context) {}
void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
functor::Softsign<Device, T> functor;
@@ -51,9 +48,7 @@ class SoftsignGradOp
: public BinaryElementWiseOp<T, SoftsignGradOp<Device, T>> {
public:
explicit SoftsignGradOp(OpKernelConstruction* context)
- : BinaryElementWiseOp<T, SoftsignGradOp<Device, T>>(context) {
- WarnAboutInts(context);
- }
+ : BinaryElementWiseOp<T, SoftsignGradOp<Device, T>>(context) {}
void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
const Tensor& a, Tensor* output);
@@ -90,7 +85,7 @@ void SoftsignGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
Name("SoftsignGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
SoftsignGradOp<CPUDevice, type>);
-TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+TF_CALL_FLOAT_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/softsign_op.h b/tensorflow/core/kernels/softsign_op.h
index c2ababf697..61ff6eeede 100644
--- a/tensorflow/core/kernels/softsign_op.h
+++ b/tensorflow/core/kernels/softsign_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SOFTSIGN_OP_H_
-#define TENSORFLOW_KERNELS_SOFTSIGN_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SOFTSIGN_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SOFTSIGN_OP_H_
// Functor definition for SoftsignOp and SoftsignGradOp, must be compilable by
// nvcc.
@@ -57,4 +57,4 @@ struct SoftsignGrad {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SOFTSIGN_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_SOFTSIGN_OP_H_
diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator.h b/tensorflow/core/kernels/sparse_conditional_accumulator.h
index 2c1bffbee4..a4453bd7ab 100644
--- a/tensorflow/core/kernels/sparse_conditional_accumulator.h
+++ b/tensorflow/core/kernels/sparse_conditional_accumulator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SPARSE_CONDITIONAL_ACCUMULATOR_H_
-#define TENSORFLOW_KERNELS_SPARSE_CONDITIONAL_ACCUMULATOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_CONDITIONAL_ACCUMULATOR_H_
+#define TENSORFLOW_CORE_KERNELS_SPARSE_CONDITIONAL_ACCUMULATOR_H_
#include "tensorflow/core/kernels/typed_conditional_accumulator_base.h"
@@ -50,10 +50,10 @@ class SparseConditionalAccumulator
public:
SparseConditionalAccumulator(const DataType& dtype,
const PartialTensorShape& shape,
- const string& name)
+ const string& name, const string& reduction_type)
: TypedConditionalAccumulatorBase<
std::tuple<const Tensor*, const Tensor*, const Tensor*>>(
- dtype, shape, name) {
+ dtype, shape, name, reduction_type) {
accum_idx_vec_ = nullptr;
count_element_ = nullptr;
accum_val_ = nullptr;
@@ -459,4 +459,4 @@ class SparseConditionalAccumulator
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SPARSE_CONDITIONAL_ACCUMULATOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_SPARSE_CONDITIONAL_ACCUMULATOR_H_
diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc
index 80bc1f1934..1e542a26a7 100644
--- a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc
+++ b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc
@@ -34,8 +34,8 @@ class SparseConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
Creator GetCreator() const override {
return [this](ConditionalAccumulatorBase** ret) {
SparseConditionalAccumulator<Device, T>* accumulator =
- new SparseConditionalAccumulator<Device, T>(dtype_, shape_,
- cinfo_.name());
+ new SparseConditionalAccumulator<Device, T>(
+ dtype_, shape_, cinfo_.name(), reduction_type_);
*ret = accumulator;
return Status::OK();
};
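[Editorial note] The accumulator constructors now thread a reduction_type string from the op attribute down to ConditionalAccumulatorBase. A hypothetical construction sketch; the "MEAN" value and the argument shapes are assumptions for illustration, not taken from this patch:

    SparseConditionalAccumulator<CPUDevice, float> accumulator(
        DT_FLOAT, PartialTensorShape({-1}), "my_accumulator", "MEAN");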
diff --git a/tensorflow/core/kernels/sparse_matmul_op.h b/tensorflow/core/kernels/sparse_matmul_op.h
index e89280724e..6b9db8f471 100644
--- a/tensorflow/core/kernels/sparse_matmul_op.h
+++ b/tensorflow/core/kernels/sparse_matmul_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SPARSE_MATMUL_OP_H_
-#define TENSORFLOW_KERNELS_SPARSE_MATMUL_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/platform/byte_order.h"
@@ -465,4 +465,4 @@ EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_u(const Packet16f& from) {
#endif
} // namespace internal
} // namespace Eigen
-#endif
+#endif // TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_
diff --git a/tensorflow/core/kernels/sparse_softmax_op.cc b/tensorflow/core/kernels/sparse_softmax_op.cc
index dc3119bba4..37664fe8df 100644
--- a/tensorflow/core/kernels/sparse_softmax_op.cc
+++ b/tensorflow/core/kernels/sparse_softmax_op.cc
@@ -90,7 +90,7 @@ class SparseSoftmaxOp : public OpKernel {
// { 0, ..., rank-1 }.
const ArraySlice<int64> kReorderDims(dims);
// All but the last dim -- the class dimension to be max-reduced along.
- const ArraySlice<int64> kGroupByDims(kReorderDims, 0, rank - 1);
+ const ArraySlice<int64> kGroupByDims = kReorderDims.subspan(0, rank - 1);
st.Reorder<T>(kReorderDims);
int count = 0;
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_add_op.h b/tensorflow/core/kernels/sparse_tensor_dense_add_op.h
index 353cf0e519..c26ed5e874 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_add_op.h
+++ b/tensorflow/core/kernels/sparse_tensor_dense_add_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SPARSE_TENSOR_DENSE_ADD_OP_H_
-#define TENSORFLOW_KERNELS_SPARSE_TENSOR_DENSE_ADD_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_TENSOR_DENSE_ADD_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SPARSE_TENSOR_DENSE_ADD_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -39,4 +39,4 @@ struct ScatterNdFunctor {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SPARSE_TENSOR_DENSE_ADD_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_SPARSE_TENSOR_DENSE_ADD_OP_H_
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
index da13190494..d6dd2deca5 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SPARSE_TENSOR_DENSE_MATMUL_OP_H_
-#define TENSORFLOW_KERNELS_SPARSE_TENSOR_DENSE_MATMUL_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_TENSOR_DENSE_MATMUL_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SPARSE_TENSOR_DENSE_MATMUL_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -71,4 +71,4 @@ class MaybeAdjoint<MATRIX, true> {
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SPARSE_TENSOR_DENSE_MATMUL_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_SPARSE_TENSOR_DENSE_MATMUL_OP_H_
diff --git a/tensorflow/core/kernels/sparse_xent_op.h b/tensorflow/core/kernels/sparse_xent_op.h
index b5587aa9d7..6ba7931ab5 100644
--- a/tensorflow/core/kernels/sparse_xent_op.h
+++ b/tensorflow/core/kernels/sparse_xent_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_XENT_OP_H_
-#define TENSORFLOW_KERNELS_XENT_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_XENT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SPARSE_XENT_OP_H_
// Functor definition for SparseXentOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -224,4 +224,4 @@ struct SparseXentEigenImpl {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_XENT_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_SPARSE_XENT_OP_H_
diff --git a/tensorflow/core/kernels/split_lib.h b/tensorflow/core/kernels/split_lib.h
index bc1fa28f8f..9d43a00822 100644
--- a/tensorflow/core/kernels/split_lib.h
+++ b/tensorflow/core/kernels/split_lib.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SPLIT_LIB_H_
-#define TENSORFLOW_KERNELS_SPLIT_LIB_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SPLIT_LIB_H_
+#define TENSORFLOW_CORE_KERNELS_SPLIT_LIB_H_
// Functor definition for SplitOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -62,4 +62,4 @@ struct Split<Eigen::SyclDevice, T> {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SPLIT_LIB_H_
+#endif // TENSORFLOW_CORE_KERNELS_SPLIT_LIB_H_
diff --git a/tensorflow/core/kernels/squared-loss.h b/tensorflow/core/kernels/squared-loss.h
index 49e6db406e..d256a69350 100644
--- a/tensorflow/core/kernels/squared-loss.h
+++ b/tensorflow/core/kernels/squared-loss.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SQUARED_LOSS_H_
-#define TENSORFLOW_KERNELS_SQUARED_LOSS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SQUARED_LOSS_H_
+#define TENSORFLOW_CORE_KERNELS_SQUARED_LOSS_H_
#include "tensorflow/core/kernels/loss.h"
@@ -70,4 +70,4 @@ class SquaredLossUpdater : public DualLossUpdater {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SQUARED_LOSS_H_
+#endif // TENSORFLOW_CORE_KERNELS_SQUARED_LOSS_H_
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index 59fdc2262a..7b537fef5b 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -300,7 +300,8 @@ class StridedSliceAssignOp : public OpKernel {
gtl::InlinedVector<int64, 4> end;
gtl::InlinedVector<int64, 4> strides;
- Tensor old_lhs;
+ Tensor* old_lhs = nullptr;
+ Tensor tmp;
if (context->input_dtype(0) == DT_RESOURCE) {
Var* v;
OP_REQUIRES_OK(context,
@@ -308,29 +309,30 @@ class StridedSliceAssignOp : public OpKernel {
mutex_lock ml(*v->mu());
OP_REQUIRES_OK(context,
PrepareToUpdateVariable<Device, T>(context, v->tensor()));
- old_lhs = *v->tensor();
- OP_REQUIRES(context, old_lhs.dtype() == DataTypeToEnum<T>::value,
+ old_lhs = v->tensor();
+ OP_REQUIRES(context, old_lhs->dtype() == DataTypeToEnum<T>::value,
errors::InvalidArgument(
- "l-value dtype ", DataTypeString(old_lhs.dtype()),
+ "l-value dtype ", DataTypeString(old_lhs->dtype()),
" does not match r-value dtype ",
DataTypeString(DataTypeToEnum<T>::value)));
} else {
context->forward_ref_input_to_ref_output(0, 0);
- old_lhs = context->mutable_input(0, true);
+ tmp = context->mutable_input(0, true);
+ old_lhs = &tmp;
}
OP_REQUIRES_OK(
- context,
- ValidateStridedSliceOp(
- &context->input(1), &context->input(2), context->input(3),
- old_lhs.shape(), begin_mask, end_mask, ellipsis_mask, new_axis_mask,
- shrink_axis_mask, &processing_shape, &final_shape, &is_identity,
- &is_simple_slice, &slice_dim0, &begin, &end, &strides));
+ context, ValidateStridedSliceOp(
+ &context->input(1), &context->input(2), context->input(3),
+ old_lhs->shape(), begin_mask, end_mask, ellipsis_mask,
+ new_axis_mask, shrink_axis_mask, &processing_shape,
+ &final_shape, &is_identity, &is_simple_slice, &slice_dim0,
+ &begin, &end, &strides));
if (processing_shape.num_elements()) {
const Tensor& input = context->input(4);
TensorShape input_shape = input.shape();
- TensorShape original_shape = old_lhs.shape();
+ TensorShape original_shape = old_lhs->shape();
// TODO(aselle): This check is too strong, we only should need
// input_shape to be broadcastable to final_shape
OP_REQUIRES(
@@ -345,12 +347,12 @@ class StridedSliceAssignOp : public OpKernel {
// scalar shape
// Handle general dimensions
-#define HANDLE_DIM(NDIM) \
- if (processing_dims == NDIM) { \
- HandleStridedSliceAssignCase<Device, T, NDIM>()( \
- context, begin, end, strides, processing_shape, is_simple_slice, \
- &old_lhs); \
- return; \
+#define HANDLE_DIM(NDIM) \
+ if (processing_dims == NDIM) { \
+ HandleStridedSliceAssignCase<Device, T, NDIM>()(context, begin, end, \
+ strides, processing_shape, \
+ is_simple_slice, old_lhs); \
+ return; \
}
HANDLE_DIM(0);
HANDLE_DIM(1);
diff --git a/tensorflow/core/kernels/strided_slice_op.h b/tensorflow/core/kernels/strided_slice_op.h
index 2b58632298..86d105391d 100644
--- a/tensorflow/core/kernels/strided_slice_op.h
+++ b/tensorflow/core/kernels/strided_slice_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_STRIDED_SLICE_OP_H_
-#define TENSORFLOW_KERNELS_STRIDED_SLICE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_H_
// Functor definition for StridedSliceOp, must be compilable by nvcc.
@@ -137,4 +137,4 @@ struct StridedSliceAssignScalar {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SLICE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_H_
diff --git a/tensorflow/core/kernels/strided_slice_op_impl.h b/tensorflow/core/kernels/strided_slice_op_impl.h
index 1c4472bb1a..099083b2ff 100644
--- a/tensorflow/core/kernels/strided_slice_op_impl.h
+++ b/tensorflow/core/kernels/strided_slice_op_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_STRIDED_SLICE_OP_IMPL_H_
-#define TENSORFLOW_KERNELS_STRIDED_SLICE_OP_IMPL_H_
+#ifndef TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_
// Functor definition for StridedSliceOp, must be compilable by nvcc.
@@ -313,4 +313,4 @@ DECLARE_FOR_N_SYCL(int64);
} // end namespace tensorflow
#endif // END STRIDED_SLICE_INSTANTIATE_DIM
-#endif // TENSORFLOW_KERNELS_SLICE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_
diff --git a/tensorflow/core/kernels/string_split_op.cc b/tensorflow/core/kernels/string_split_op.cc
index 26ab72f12e..3884370a6c 100644
--- a/tensorflow/core/kernels/string_split_op.cc
+++ b/tensorflow/core/kernels/string_split_op.cc
@@ -26,25 +26,81 @@ limitations under the License.
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
-
namespace {
+// Split input string `str` based on a character delimiter.
+// Returns a vector of StringPieces which are valid as long as input `str`
+// is valid.
+// Note: The single character delimiter is a common case and is implemented as
+// a series of finds in the input string, making it much more efficient than
+// SplitOnCharSet.
+template <typename Predicate>
+std::vector<StringPiece> SplitOnChar(const string& str, const char delim,
+ Predicate p) {
+ std::vector<StringPiece> result;
+ StringPiece text(str);
+ auto f = text.find(delim);
+ while (f != StringPiece::npos) {
+ StringPiece token = text.substr(0, f);
+ if (p(token)) {
+ result.emplace_back(token);
+ }
+ text.remove_prefix(f + 1);
+ f = text.find(delim);
+ }
+ if (p(text)) {
+ result.push_back(text);
+ }
+ return result;
+}
-std::vector<string> Split(const string& str, const string& delimiter,
- const bool skipEmpty) {
- if (!delimiter.empty()) {
- if (skipEmpty) {
- return str_util::Split(str, delimiter, str_util::SkipEmpty());
+// Split input string `str` based on a set of character delimiters.
+// Returns a vector of StringPieces which are valid as long as input `str`
+// is valid.
+// Based on str_util::Split.
+template <typename Predicate>
+std::vector<StringPiece> SplitOnCharSet(const string& str,
+ const string& delim_set, Predicate p) {
+ std::vector<StringPiece> result;
+ StringPiece text(str);
+ StringPiece delims(delim_set);
+ size_t token_start = 0;
+ for (size_t i = 0; i < text.size() + 1; i++) {
+ if ((i == text.size()) || (delims.find(text[i]) != StringPiece::npos)) {
+ StringPiece token(text.data() + token_start, i - token_start);
+ if (p(token)) {
+ result.emplace_back(token);
+ }
+ token_start = i + 1;
}
- return str_util::Split(str, delimiter);
}
- std::vector<string> char_vector(str.size());
- for (size_t i = 0; i < str.size(); ++i) {
- char_vector[i] = str[i];
+ return result;
+}
+
+// Split input string `str` based on given delimiter.
+// Returns a vector of StringPieces which are valid as long as input `str`
+// is valid.
+template <typename Predicate>
+std::vector<StringPiece> Split(const string& str, const string& delimiter,
+ Predicate predicate) {
+ if (str.empty()) {
+ return std::vector<StringPiece>();
+ }
+ if (delimiter.empty()) {
+ std::vector<StringPiece> result;
+ result.resize(str.size());
+ for (size_t i = 0; i < str.size(); ++i) {
+ result[i] = StringPiece(str.data() + i, 1);
+ }
+ return result;
}
- return char_vector;
+ if (delimiter.size() == 1) {
+ return SplitOnChar(str, delimiter[0], predicate);
+ }
+ return SplitOnCharSet(str, delimiter, predicate);
}
-std::vector<string> SplitV2(const string& str, StringPiece sep, int maxsplit) {
+std::vector<StringPiece> SplitV2(const string& str, StringPiece sep,
+ int maxsplit) {
// This SplitV2 method matches the behavior of python's str.split:
// If sep is given, consecutive delimiters are not grouped together
// and are deemed to delimit empty strings (for example, '1,,2'.split(',')
@@ -59,11 +115,11 @@ std::vector<string> SplitV2(const string& str, StringPiece sep, int maxsplit) {
// splitting an empty string or a string consisting of just whitespace
// with a None separator returns [].
- std::vector<string> result;
+ std::vector<StringPiece> result;
StringPiece text(str);
if (maxsplit == 0) {
- result.emplace_back(std::string(text));
+ result.emplace_back(text);
return result;
}
@@ -73,11 +129,11 @@ std::vector<string> SplitV2(const string& str, StringPiece sep, int maxsplit) {
str_util::RemoveLeadingWhitespace(&text);
int split = 0;
while (str_util::ConsumeNonWhitespace(&text, &token)) {
- result.emplace_back(std::string(token));
+ result.push_back(token);
str_util::RemoveLeadingWhitespace(&text);
++split;
if (maxsplit > 0 && split == maxsplit) {
- result.emplace_back(std::string(text));
+ result.push_back(text);
return result;
}
}
@@ -87,17 +143,17 @@ std::vector<string> SplitV2(const string& str, StringPiece sep, int maxsplit) {
int split = 0;
while (p != text.end()) {
StringPiece token = text.substr(0, p - text.begin());
- result.emplace_back(std::string(token));
+ result.push_back(token);
text.remove_prefix(token.size());
text.remove_prefix(sep.size());
++split;
if (maxsplit > 0 && split == maxsplit) {
- result.emplace_back(std::string(text));
+ result.push_back(StringPiece(text));
return result;
}
p = std::search(text.begin(), text.end(), sep.begin(), sep.end());
}
- result.emplace_back(std::string(text));
+ result.push_back(text);
return result;
}
@@ -134,7 +190,7 @@ class StringSplitOp : public OpKernel {
const auto delimiter_vec = delimiter_tensor->flat<string>();
const string& delimiter = delimiter_vec(0);
// Empty delimiter means split the input character by character.
- std::vector<string> tokens;
+ std::vector<StringPiece> tokens;
// Guess that we'll be unpacking a handful of tokens per example.
static constexpr int kReserveSize = 4;
tokens.reserve(batch_size * kReserveSize);
@@ -143,12 +199,15 @@ class StringSplitOp : public OpKernel {
int64 max_num_entries = 0;
std::vector<int64> num_indices(batch_size);
for (int64 i = 0; i < batch_size; ++i) {
- std::vector<string> parts = Split(input_vec(i), delimiter, skip_empty_);
+ std::vector<StringPiece> parts =
+ skip_empty_ ? Split(input_vec(i), delimiter, str_util::SkipEmpty())
+ : Split(input_vec(i), delimiter, str_util::AllowEmpty());
int64 n_entries = parts.size();
num_indices[i] = n_entries;
output_size += n_entries;
max_num_entries = std::max(max_num_entries, n_entries);
- tokens.insert(tokens.end(), parts.begin(), parts.end());
+ tokens.insert(tokens.end(), std::make_move_iterator(parts.begin()),
+ std::make_move_iterator(parts.end()));
}
Tensor* sp_indices_t;
@@ -170,7 +229,7 @@ class StringSplitOp : public OpKernel {
for (size_t j = 0; j < num_indices[i]; ++j) {
sp_indices(c, 0) = i;
sp_indices(c, 1) = j;
- sp_tokens(c) = tokens[c];
+ sp_tokens(c).assign(tokens[c].data(), tokens[c].size());
++c;
}
}
@@ -204,7 +263,7 @@ class StringSplitV2Op : public OpKernel {
sep_tensor->shape().DebugString()));
const auto sep_vec = sep_tensor->flat<string>();
StringPiece sep(sep_vec(0));
- std::vector<string> tokens;
+ std::vector<StringPiece> tokens;
// Guess that we'll be unpacking a handful of tokens per example.
static constexpr int kReserveSize = 4;
tokens.reserve(batch_size * kReserveSize);
@@ -213,7 +272,7 @@ class StringSplitV2Op : public OpKernel {
int64 max_num_entries = 0;
std::vector<int64> num_indices(batch_size);
for (int64 i = 0; i < batch_size; ++i) {
- std::vector<string> parts = SplitV2(input_vec(i), sep, maxsplit_);
+ std::vector<StringPiece> parts = SplitV2(input_vec(i), sep, maxsplit_);
int64 n_entries = parts.size();
num_indices[i] = n_entries;
output_size += n_entries;
@@ -240,7 +299,7 @@ class StringSplitV2Op : public OpKernel {
for (size_t j = 0; j < num_indices[i]; ++j) {
sp_indices(c, 0) = i;
sp_indices(c, 1) = j;
- sp_tokens(c) = tokens[c];
+ sp_tokens(c).assign(tokens[c].data(), tokens[c].size());
++c;
}
}
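[Editorial note] The helpers above return StringPieces that alias the input, so callers must keep the source string alive while the tokens are in use. A minimal usage sketch with the predicate functors this file already employs, not part of the patch:

    const string input = "a,,b";
    std::vector<StringPiece> tokens =
        SplitOnChar(input, ',', str_util::SkipEmpty());   // {"a", "b"}
    std::vector<StringPiece> with_empties =
        SplitOnChar(input, ',', str_util::AllowEmpty());  // {"a", "", "b"}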
diff --git a/tensorflow/core/kernels/string_split_op_test.cc b/tensorflow/core/kernels/string_split_op_test.cc
new file mode 100644
index 0000000000..58ad61adc8
--- /dev/null
+++ b/tensorflow/core/kernels/string_split_op_test.cc
@@ -0,0 +1,129 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+
+// Test data from the TensorFlow README.md.
+const char* lines[] = {
+ "**TensorFlow** is an open source software library for numerical "
+ "computation using data flow graphs.",
+ "The graph nodes represent mathematical operations, while the graph edges "
+ "represent the multidimensional data arrays (tensors) that flow between "
+ "them.",
+ "This flexible architecture enables you to deploy computation to one or "
+ "more CPUs or GPUs in a desktop, server, or mobile device without "
+ "rewriting code.",
+ "TensorFlow also includes "
+ "[TensorBoard](https://www.tensorflow.org/guide/"
+ "summaries_and_tensorboard), a data visualization toolkit.",
+ "TensorFlow was originally developed by researchers and engineers working "
+ "on the Google Brain team within Google's Machine Intelligence Research "
+ "organization for the purposes of conducting machine learning and deep "
+ "neural networks research.",
+ "The system is general enough to be applicable in a wide variety of other "
+ "domains, as well.",
+ "TensorFlow provides stable Python API and C APIs as well as without API "
+ "backwards compatibility guarantee like C++, Go, Java, JavaScript and "
+ "Swift."};
+
+Tensor GetTestTensor(int batch) {
+ const int sz = TF_ARRAYSIZE(lines);
+ Tensor t(DT_STRING, {batch});
+ auto s = t.flat<string>();
+ for (int i = 0; i < batch; ++i) {
+ s(i) = lines[i % sz];
+ }
+ return t;
+}
+
+Graph* SetupStringSplitGraph(const Tensor& input) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor delim(DT_STRING, TensorShape({}));
+ delim.flat<string>().setConstant(" ");
+
+ TF_CHECK_OK(NodeBuilder("string_split_op", "StringSplit")
+ .Input(test::graph::Constant(g, input))
+ .Input(test::graph::Constant(g, delim))
+ .Finalize(g, nullptr /* node */));
+ return g;
+}
+
+void BM_StringSplit(int iters, int batch_size) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters));
+ testing::UseRealTime();
+ Tensor input = GetTestTensor(batch_size);
+ Graph* g = SetupStringSplitGraph(input);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+BENCHMARK(BM_StringSplit)
+ ->Arg(1)
+ ->Arg(8)
+ ->Arg(16)
+ ->Arg(32)
+ ->Arg(64)
+ ->Arg(128)
+ ->Arg(256);
+
+Graph* SetupStringSplitV2Graph(const Tensor& input) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor sep(DT_STRING, TensorShape({}));
+ sep.flat<string>().setConstant(" ");
+
+ TF_CHECK_OK(NodeBuilder("string_split_op", "StringSplitV2")
+ .Input(test::graph::Constant(g, input))
+ .Input(test::graph::Constant(g, sep))
+ .Finalize(g, nullptr /* node */));
+ return g;
+}
+
+void BM_StringSplitV2(int iters, int batch_size) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters));
+ testing::UseRealTime();
+ Tensor input = GetTestTensor(batch_size);
+ Graph* g = SetupStringSplitV2Graph(input);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+BENCHMARK(BM_StringSplitV2)
+ ->Arg(1)
+ ->Arg(8)
+ ->Arg(16)
+ ->Arg(32)
+ ->Arg(64)
+ ->Arg(128)
+ ->Arg(256);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/string_strip_op.cc b/tensorflow/core/kernels/string_strip_op.cc
index 2aeafa28c4..544dca96ba 100644
--- a/tensorflow/core/kernels/string_strip_op.cc
+++ b/tensorflow/core/kernels/string_strip_op.cc
@@ -43,7 +43,7 @@ class StringStripOp : public OpKernel {
for (int64 i = 0; i < input.size(); ++i) {
StringPiece entry(input(i));
str_util::RemoveWhitespaceContext(&entry);
- output(i) = std::string(entry);
+ output(i) = string(entry);
}
}
};
diff --git a/tensorflow/core/kernels/svd_op_impl.h b/tensorflow/core/kernels/svd_op_impl.h
index a996b67c62..2a67700c12 100644
--- a/tensorflow/core/kernels/svd_op_impl.h
+++ b/tensorflow/core/kernels/svd_op_impl.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_SVD_OP_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_SVD_OP_IMPL_H_
+
// See docs in ../ops/linalg_ops.cc.
//
// This header file is used by the individual svd_*op*.cc files for registering
@@ -101,3 +104,5 @@ class SvdOp : public LinearAlgebraOp<Scalar> {
};
} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_SVD_OP_IMPL_H_
diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h
index 68fab85770..e8dc4fad21 100644
--- a/tensorflow/core/kernels/tensor_array.h
+++ b/tensorflow/core/kernels/tensor_array.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_TENSOR_ARRAY_H_
-#define TENSORFLOW_KERNELS_TENSOR_ARRAY_H_
+#ifndef TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_
+#define TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_
#include <limits.h>
#include <vector>
@@ -629,4 +629,4 @@ Status TensorArray::LockedRead(OpKernelContext* ctx, const int32 index,
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_TENSOR_ARRAY_H_
+#endif // TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index b368ffc875..2ec2651c04 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -297,7 +297,7 @@ class TensorArrayGradOp : public TensorArrayCreationOp {
resource.name());
}
tensor_array_name =
- std::string(StringPiece(resource.name()).substr(container.size()));
+ string(StringPiece(resource.name()).substr(container.size()));
}
auto output_handle = tensor_array_output_handle->flat<string>();
@@ -1119,8 +1119,8 @@ class TensorArrayUnpackOrScatterOp : public OpKernel {
{1, num_values, element_shape.num_elements()});
Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, 0, 0};
- Eigen::DSizes<Eigen::DenseIndex, 3> sizes{1, 1,
- element_shape.num_elements()};
+ Eigen::DSizes<Eigen::DenseIndex, 3> sizes{
+ 1, 1, static_cast<Eigen::DenseIndex>(element_shape.num_elements())};
std::vector<PersistentTensor> write_values;
write_values.reserve(num_values);
@@ -1315,9 +1315,11 @@ class TensorArraySplitOp : public OpKernel {
PersistentTensor persistent_tensor;
int64 previous_length = (i == 0) ? 0 : cumulative_lengths[i - 1];
- Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, previous_length, 0};
- Eigen::DSizes<Eigen::DenseIndex, 3> sizes{1, tensor_lengths_t(i),
- elements_per_row};
+ Eigen::DSizes<Eigen::DenseIndex, 3> indices{
+ 0, static_cast<Eigen::DenseIndex>(previous_length), 0};
+ Eigen::DSizes<Eigen::DenseIndex, 3> sizes{
+ 1, static_cast<Eigen::DenseIndex>(tensor_lengths_t(i)),
+ static_cast<Eigen::DenseIndex>(elements_per_row)};
OP_REQUIRES_OK(ctx, ctx->allocate_persistent(
tensor_array->ElemType(), element_shapes[i],
diff --git a/tensorflow/core/kernels/tile_functor.h b/tensorflow/core/kernels/tile_functor.h
index 189be9239b..95986af8b7 100644
--- a/tensorflow/core/kernels/tile_functor.h
+++ b/tensorflow/core/kernels/tile_functor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_TILE_FUNCTOR_H_
-#define TENSORFLOW_KERNELS_TILE_FUNCTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_TILE_FUNCTOR_H_
+#define TENSORFLOW_CORE_KERNELS_TILE_FUNCTOR_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -106,4 +106,4 @@ struct Tile {
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_TILE_FUNCTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_TILE_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/tile_ops_impl.h b/tensorflow/core/kernels/tile_ops_impl.h
index 9861717a0b..6a9de388c6 100644
--- a/tensorflow/core/kernels/tile_ops_impl.h
+++ b/tensorflow/core/kernels/tile_ops_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_TILE_IMPL_OPS_H_
-#define TENSORFLOW_KERNELS_TILE_IMPL_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_TILE_OPS_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_TILE_OPS_IMPL_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -68,4 +68,4 @@ struct ReduceAndReshape {
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_TILE_OPS_IMPL_H_
+#endif // TENSORFLOW_CORE_KERNELS_TILE_OPS_IMPL_H_
diff --git a/tensorflow/core/kernels/topk_op.h b/tensorflow/core/kernels/topk_op.h
index a53e3ec8d4..1fdbc5b15f 100644
--- a/tensorflow/core/kernels/topk_op.h
+++ b/tensorflow/core/kernels/topk_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_TOPK_OP_H_
-#define TENSORFLOW_TOPK_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_TOPK_OP_H_
+#define TENSORFLOW_CORE_KERNELS_TOPK_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
@@ -39,4 +39,4 @@ struct TopKFunctor {
} // end namespace tensorflow
-#endif // TENSORFLOW_TOPK_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_TOPK_OP_H_
diff --git a/tensorflow/core/kernels/training_op_helpers.h b/tensorflow/core/kernels/training_op_helpers.h
index 765335d3a0..071cb371a7 100644
--- a/tensorflow/core/kernels/training_op_helpers.h
+++ b/tensorflow/core/kernels/training_op_helpers.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_TRAINING_OP_HELPERS_H_
-#define TENSORFLOW_KERNELS_TRAINING_OP_HELPERS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_
+#define TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/variant_op_registry.h"
@@ -90,4 +90,4 @@ Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_TRAINING_OP_HELPERS_H_
+#endif // TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 271329599f..9a07ded17d 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
-
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include <algorithm>
@@ -201,7 +200,7 @@ struct ApplyFtrlV2<CPUDevice, T> {
typename TTypes<T>::ConstScalar l2_shrinkage,
typename TTypes<T>::ConstScalar lr_power) {
auto grad_with_shrinkage = grad + static_cast<T>(2) * l2_shrinkage() * var;
- auto new_accum = accum + grad_with_shrinkage.square();
+ auto new_accum = accum + grad * grad;
// special case for which lr_power=-0.5.
if (lr_power() == static_cast<T>(-0.5)) {
linear.device(d) +=
@@ -226,7 +225,7 @@ struct ApplyFtrlV2<CPUDevice, T> {
var.device(d) = (linear.abs() > linear.constant(l1()))
.select(pre_shrink, var.constant(static_cast<T>(0)));
}
- accum.device(d) += grad_with_shrinkage.square();
+ accum.device(d) += grad * grad;
}
};
@@ -2167,15 +2166,15 @@ class SparseApplyFtrlOp : public OpKernel {
// Use a macro to implement the computation here due to the templating of the
// eigen tensor library.
-#define COMPUTE_FTRL(grad_to_use) \
- auto new_accum = accum + grad_to_use.square(); \
+#define COMPUTE_FTRL(grad, grad_maybe_with_shrinkage) \
+ auto new_accum = accum + grad.square(); \
if (lr_power_scalar == static_cast<T>(-0.5)) { \
- linear += \
- grad_to_use - (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var; \
+ linear += grad_maybe_with_shrinkage - \
+ (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var; \
} else { \
- linear += grad_to_use - (new_accum.pow(-lr_power_scalar) - \
- accum.pow(-lr_power_scalar)) / \
- lr_scalar * var; \
+ linear += grad_maybe_with_shrinkage - (new_accum.pow(-lr_power_scalar) - \
+ accum.pow(-lr_power_scalar)) / \
+ lr_scalar * var; \
} \
auto l1_reg_adjust = linear.cwiseMin(l1_scalar).cwiseMax(-l1_scalar); \
auto x = l1_reg_adjust - linear; \
@@ -2188,14 +2187,14 @@ class SparseApplyFtrlOp : public OpKernel {
linear.constant(static_cast<T>(2) * l2_scalar); \
var = x / y; \
} \
- accum += grad_to_use.square();
+ accum += grad.square();
if (has_l2_shrinkage) {
auto grad_with_shrinkage =
grad + static_cast<T>(2) * l2_shrinkage_scalar * var;
- COMPUTE_FTRL(grad_with_shrinkage);
+ COMPUTE_FTRL(grad, grad_with_shrinkage);
} else {
- COMPUTE_FTRL(grad);
+ COMPUTE_FTRL(grad, grad);
}
}
#undef COMPUTE_FTRL
@@ -2228,12 +2227,12 @@ class SparseApplyFtrlOp : public OpKernel {
T g;
if (has_l2_shrinkage) {
g = grad_flat(i) +
- (static_cast<T>(2) * l2_shrinkage_scalar * var_flat(i));
+ (static_cast<T>(2) * l2_shrinkage_scalar * var_flat(index));
} else {
g = grad_flat(i);
}
- T updated_a = a + g * g;
+ T updated_a = a + grad_flat(i) * grad_flat(i);
using Eigen::numext::pow;
T sigma = pow(updated_a, -lr_power_scalar) - pow(a, -lr_power_scalar);
sigma /= lr_scalar;
@@ -2856,9 +2855,8 @@ class ApplyAdaMaxOp : public OpKernel {
const Device& device = ctx->template eigen_device<Device>();
functor::ApplyAdaMax<Device, T>()(
device, var.flat<T>(), m.flat<T>(), v.flat<T>(),
- beta1_power.scalar<T>(), lr.scalar<T>(),
- beta1.scalar<T>(), beta2.scalar<T>(), epsilon.scalar<T>(),
- grad.flat<T>());
+ beta1_power.scalar<T>(), lr.scalar<T>(), beta1.scalar<T>(),
+ beta2.scalar<T>(), epsilon.scalar<T>(), grad.flat<T>());
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
@@ -2867,16 +2865,16 @@ class ApplyAdaMaxOp : public OpKernel {
bool use_exclusive_lock_;
};
-#define REGISTER_KERNELS(D, T) \
- REGISTER_KERNEL_BUILDER( \
+#define REGISTER_KERNELS(D, T) \
+ REGISTER_KERNEL_BUILDER( \
Name("ApplyAdaMax").Device(DEVICE_##D).TypeConstraint<T>("T"), \
ApplyAdaMaxOp<D##Device, T>); \
REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdaMax") \
- .HostMemory("var") \
- .HostMemory("m") \
- .HostMemory("v") \
- .Device(DEVICE_##D) \
- .TypeConstraint<T>("T"), \
+ .HostMemory("var") \
+ .HostMemory("m") \
+ .HostMemory("v") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<T>("T"), \
ApplyAdaMaxOp<D##Device, T>);
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
@@ -2889,7 +2887,7 @@ TF_CALL_double(REGISTER_CPU_KERNELS);
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
- void ApplyAdaMax<GPUDevice, T>::operator()( \
+ void ApplyAdaMax<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::Flat var, \
typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, \
typename TTypes<T>::ConstScalar beta1_power, \
@@ -2897,7 +2895,7 @@ namespace functor {
typename TTypes<T>::ConstScalar beta1, \
typename TTypes<T>::ConstScalar beta2, \
typename TTypes<T>::ConstScalar epsilon, \
- typename TTypes<T>::ConstFlat grad); \
+ typename TTypes<T>::ConstFlat grad); \
extern template struct ApplyAdaMax<GPUDevice, T>;
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
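[Editorial note] Both the dense functor and the sparse kernel above now square the raw gradient when growing the accumulator; only the linear term sees the shrinkage-adjusted gradient. A scalar sketch of one corrected step for lr_power = -0.5 (variable names are illustrative, not TF API):

    float g_shrunk = g + 2.0f * l2_shrinkage * var;  // used for the linear term only
    float new_accum = accum + g * g;                 // accumulator: raw gradient
    linear += g_shrunk - (std::sqrt(new_accum) - std::sqrt(accum)) / lr * var;
    accum = new_accum;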
diff --git a/tensorflow/core/kernels/training_ops.h b/tensorflow/core/kernels/training_ops.h
index 495a94f1a1..e10a4cb125 100644
--- a/tensorflow/core/kernels/training_ops.h
+++ b/tensorflow/core/kernels/training_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_TRAINING_OPS_H_
-#define TENSORFLOW_KERNELS_TRAINING_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -199,4 +199,4 @@ struct ApplyPowerSign {
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_TRAINING_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_
diff --git a/tensorflow/core/kernels/typed_conditional_accumulator_base.h b/tensorflow/core/kernels/typed_conditional_accumulator_base.h
index 1980f758fc..ca341e511e 100644
--- a/tensorflow/core/kernels/typed_conditional_accumulator_base.h
+++ b/tensorflow/core/kernels/typed_conditional_accumulator_base.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
-#define TENSORFLOW_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
+#ifndef TENSORFLOW_CORE_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
+#define TENSORFLOW_CORE_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
#include "tensorflow/core/kernels/conditional_accumulator_base.h"
@@ -35,8 +35,9 @@ class TypedConditionalAccumulatorBase : public ConditionalAccumulatorBase {
public:
TypedConditionalAccumulatorBase(const DataType& dtype,
const PartialTensorShape& shape,
- const string& name)
- : ConditionalAccumulatorBase(dtype, shape, name) {}
+ const string& name,
+ const string& reduction_type)
+ : ConditionalAccumulatorBase(dtype, shape, name, reduction_type) {}
/**
* Attempts to add a gradient to the accumulator. An ApplyGrad attempt is
@@ -91,4 +92,4 @@ class TypedConditionalAccumulatorBase : public ConditionalAccumulatorBase {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
+#endif // TENSORFLOW_CORE_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
diff --git a/tensorflow/core/kernels/variable_ops.h b/tensorflow/core/kernels/variable_ops.h
index f27dab4ddd..4742e429ed 100644
--- a/tensorflow/core/kernels/variable_ops.h
+++ b/tensorflow/core/kernels/variable_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_VARIABLE_OPS_H_
-#define TENSORFLOW_KERNELS_VARIABLE_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_VARIABLE_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_VARIABLE_OPS_H_
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -46,4 +46,4 @@ class VariableOp : public OpKernel {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_VARIABLE_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_VARIABLE_OPS_H_
diff --git a/tensorflow/core/kernels/warn_about_ints.cc b/tensorflow/core/kernels/warn_about_ints.cc
deleted file mode 100644
index 75ecdf2ae4..0000000000
--- a/tensorflow/core/kernels/warn_about_ints.cc
+++ /dev/null
@@ -1,33 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/kernels/warn_about_ints.h"
-#include "tensorflow/core/framework/node_def.pb.h"
-
-namespace tensorflow {
-
-void WarnAboutInts(OpKernelConstruction* context) {
- DataType dtype;
- OP_REQUIRES_OK(context, context->GetAttr("T", &dtype));
- if (DataTypeIsInteger(dtype)) {
- LOG(WARNING) << "Op " << context->def().name() << " of type "
- << context->def().op() << " used with integer dtype "
- << DataTypeString(dtype)
- << ". This op was registered with integer support "
- << "accidentally, and you won't like the result.";
- }
-}
-
-} // namespace tensorflow
diff --git a/tensorflow/core/kernels/where_op.h b/tensorflow/core/kernels/where_op.h
index d26849c8bd..e63b3ba8cd 100644
--- a/tensorflow/core/kernels/where_op.h
+++ b/tensorflow/core/kernels/where_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_WHERE_OP_H_
-#define TENSORFLOW_KERNELS_WHERE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_WHERE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_WHERE_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
@@ -63,4 +63,4 @@ struct Where {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_WHERE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_WHERE_OP_H_
diff --git a/tensorflow/core/kernels/where_op_gpu.cu.h b/tensorflow/core/kernels/where_op_gpu.cu.h
index 57f51889de..8879d9dd4c 100644
--- a/tensorflow/core/kernels/where_op_gpu.cu.h
+++ b/tensorflow/core/kernels/where_op_gpu.cu.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_WHERE_OP_GPU_CU_H_
+#define TENSORFLOW_CORE_KERNELS_WHERE_OP_GPU_CU_H_
+
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
@@ -346,3 +349,5 @@ TF_CALL_WHERE_GPU_TYPES(DECLARE_GPU_SPEC);
} // namespace tensorflow
#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CORE_KERNELS_WHERE_OP_GPU_CU_H_
diff --git a/tensorflow/core/kernels/whole_file_read_ops.cc b/tensorflow/core/kernels/whole_file_read_ops.cc
index ed2bf3e8e2..1bf46b5e46 100644
--- a/tensorflow/core/kernels/whole_file_read_ops.cc
+++ b/tensorflow/core/kernels/whole_file_read_ops.cc
@@ -134,7 +134,7 @@ class WriteFileOp : public OpKernel {
"Contents tensor must be scalar, but had shape: ",
contents_input->shape().DebugString()));
const string& filename = filename_input->scalar<string>()();
- const string dir = std::string(io::Dirname(filename));
+ const string dir(io::Dirname(filename));
if (!context->env()->FileExists(dir).ok()) {
OP_REQUIRES_OK(context, context->env()->RecursivelyCreateDir(dir));
}
diff --git a/tensorflow/core/kernels/xent_op.h b/tensorflow/core/kernels/xent_op.h
index 87be17fca9..23d3ad39a8 100644
--- a/tensorflow/core/kernels/xent_op.h
+++ b/tensorflow/core/kernels/xent_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_XENT_OP_H_
-#define TENSORFLOW_KERNELS_XENT_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_XENT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_XENT_OP_H_
// Functor definition for XentOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -125,4 +125,4 @@ struct XentEigenImpl {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_XENT_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_XENT_OP_H_
diff --git a/tensorflow/core/lib/bfloat16/bfloat16.h b/tensorflow/core/lib/bfloat16/bfloat16.h
index d6f3f26cd5..5c917e80c1 100644
--- a/tensorflow/core/lib/bfloat16/bfloat16.h
+++ b/tensorflow/core/lib/bfloat16/bfloat16.h
@@ -61,9 +61,7 @@ struct bfloat16 {
}
B16_DEVICE_FUNC explicit bfloat16(const float v) {
- // TODO(asabne) : change the below line to
- // value = round_to_bfloat16(v).value;
- value = truncate_to_bfloat16(v).value;
+ value = round_to_bfloat16(v).value;
}
B16_DEVICE_FUNC explicit bfloat16(const double val)
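[Editorial note] The constructor now rounds instead of truncating. A self-contained sketch of the usual round-to-nearest-even bit trick; TF's round_to_bfloat16 additionally handles NaN, which this sketch omits:

    #include <cstdint>
    #include <cstring>

    uint16_t RoundToBFloat16(float f) {
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof(bits));
      uint32_t lsb = (bits >> 16) & 1;  // LSB of the surviving mantissa
      bits += 0x7FFFu + lsb;            // bias implements round-half-to-even
      return static_cast<uint16_t>(bits >> 16);
    }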
diff --git a/tensorflow/core/lib/core/arena.h b/tensorflow/core/lib/core/arena.h
index 5698303247..624ee77027 100644
--- a/tensorflow/core/lib/core/arena.h
+++ b/tensorflow/core/lib/core/arena.h
@@ -15,8 +15,8 @@ limitations under the License.
// TODO(vrv): Switch this to an open-sourced version of Arena.
-#ifndef TENSORFLOW_LIB_CORE_ARENA_H_
-#define TENSORFLOW_LIB_CORE_ARENA_H_
+#ifndef TENSORFLOW_CORE_LIB_CORE_ARENA_H_
+#define TENSORFLOW_CORE_LIB_CORE_ARENA_H_
#include <assert.h>
@@ -107,4 +107,4 @@ class Arena {
} // namespace core
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_CORE_ARENA_H_
+#endif // TENSORFLOW_CORE_LIB_CORE_ARENA_H_
diff --git a/tensorflow/core/lib/core/bits.h b/tensorflow/core/lib/core/bits.h
index 1110ef5c2a..86e539a266 100644
--- a/tensorflow/core/lib/core/bits.h
+++ b/tensorflow/core/lib/core/bits.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_CORE_BITS_H_
-#define TENSORFLOW_LIB_CORE_BITS_H_
+#ifndef TENSORFLOW_CORE_LIB_CORE_BITS_H_
+#define TENSORFLOW_CORE_LIB_CORE_BITS_H_
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -106,4 +106,4 @@ inline uint64 NextPowerOfTwo64(uint64 value) {
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_CORE_BITS_H_
+#endif // TENSORFLOW_CORE_LIB_CORE_BITS_H_
diff --git a/tensorflow/core/lib/core/casts.h b/tensorflow/core/lib/core/casts.h
index 0f925c6051..7546d4edc5 100644
--- a/tensorflow/core/lib/core/casts.h
+++ b/tensorflow/core/lib/core/casts.h
@@ -20,8 +20,8 @@ limitations under the License.
// any changes here, make sure that you're not breaking any platforms.
//
-#ifndef TENSORFLOW_LIB_CORE_CASTS_H_
-#define TENSORFLOW_LIB_CORE_CASTS_H_
+#ifndef TENSORFLOW_CORE_LIB_CORE_CASTS_H_
+#define TENSORFLOW_CORE_LIB_CORE_CASTS_H_
#include <string.h> // for memcpy
@@ -97,4 +97,4 @@ inline Dest bit_cast(const Source& source) {
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_CORE_CASTS_H_
+#endif // TENSORFLOW_CORE_LIB_CORE_CASTS_H_
diff --git a/tensorflow/core/lib/core/coding.h b/tensorflow/core/lib/core/coding.h
index 8265aec870..4a70ffa619 100644
--- a/tensorflow/core/lib/core/coding.h
+++ b/tensorflow/core/lib/core/coding.h
@@ -18,8 +18,8 @@ limitations under the License.
// * In addition we support variable length "varint" encoding
// * Strings are encoded prefixed by their length in varint format
-#ifndef TENSORFLOW_LIB_CORE_CODING_H_
-#define TENSORFLOW_LIB_CORE_CODING_H_
+#ifndef TENSORFLOW_CORE_LIB_CORE_CODING_H_
+#define TENSORFLOW_CORE_LIB_CORE_CODING_H_
#include "tensorflow/core/lib/core/raw_coding.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -76,4 +76,4 @@ extern int VarintLength(uint64_t v);
} // namespace core
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_CORE_CODING_H_
+#endif // TENSORFLOW_CORE_LIB_CORE_CODING_H_
diff --git a/tensorflow/core/lib/core/errors.h b/tensorflow/core/lib/core/errors.h
index a631d9815a..d5cbe6c616 100644
--- a/tensorflow/core/lib/core/errors.h
+++ b/tensorflow/core/lib/core/errors.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_CORE_ERRORS_H_
-#define TENSORFLOW_LIB_CORE_ERRORS_H_
+#ifndef TENSORFLOW_CORE_LIB_CORE_ERRORS_H_
+#define TENSORFLOW_CORE_LIB_CORE_ERRORS_H_
#include <sstream>
@@ -131,11 +131,23 @@ inline string FormatNodeNameForError(const string& name) {
// LINT.ThenChange(//tensorflow/python/client/session.py)
template <typename T>
string FormatNodeNamesForError(const T& names) {
- ::tensorflow::str_util::Formatter<string> f(
- [](string* output, const string& s) {
+ return ::tensorflow::str_util::Join(
+ names, ", ", [](string* output, const string& s) {
::tensorflow::strings::StrAppend(output, FormatNodeNameForError(s));
});
- return ::tensorflow::str_util::Join(names, ", ", f);
+}
+// LINT.IfChange
+inline string FormatColocationNodeForError(const string& name) {
+ return strings::StrCat("{{colocation_node ", name, "}}");
+}
+// LINT.ThenChange(//tensorflow/python/framework/error_interpolation.py)
+template <typename T>
+string FormatColocationNodeForError(const T& names) {
+ return ::tensorflow::str_util::Join(
+ names, ", ", [](string* output, const string& s) {
+ ::tensorflow::strings::StrAppend(output,
+ FormatColocationNodeForError(s));
+ });
}
// The CanonicalCode() for non-errors.
@@ -144,4 +156,4 @@ using ::tensorflow::error::OK;
} // namespace errors
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_CORE_ERRORS_H_
+#endif // TENSORFLOW_CORE_LIB_CORE_ERRORS_H_
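
[Editorial sketch] Usage of the new colocation formatter, with the output implied by the code above (node names are hypothetical):

    std::vector<string> nodes = {"while/Enter", "while/Merge"};
    string msg = errors::FormatColocationNodeForError(nodes);
    // msg == "{{colocation_node while/Enter}}, {{colocation_node while/Merge}}"
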
diff --git a/tensorflow/core/lib/core/notification.h b/tensorflow/core/lib/core/notification.h
index b3e515e28f..5def958e6b 100644
--- a/tensorflow/core/lib/core/notification.h
+++ b/tensorflow/core/lib/core/notification.h
@@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_NOTIFICATION_H_
-#define TENSORFLOW_UTIL_NOTIFICATION_H_
+#ifndef TENSORFLOW_CORE_LIB_CORE_NOTIFICATION_H_
+#define TENSORFLOW_CORE_LIB_CORE_NOTIFICATION_H_
// Notification implementation is platform-dependent, to support
// alternative synchronization primitives.
#include "tensorflow/core/platform/notification.h"
-#endif // TENSORFLOW_UTIL_NOTIFICATION_H_
+#endif // TENSORFLOW_CORE_LIB_CORE_NOTIFICATION_H_
diff --git a/tensorflow/core/lib/core/raw_coding.h b/tensorflow/core/lib/core/raw_coding.h
index 37201b755d..f49214939b 100644
--- a/tensorflow/core/lib/core/raw_coding.h
+++ b/tensorflow/core/lib/core/raw_coding.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_CORE_RAW_CODING_H_
-#define TENSORFLOW_LIB_CORE_RAW_CODING_H_
+#ifndef TENSORFLOW_CORE_LIB_CORE_RAW_CODING_H_
+#define TENSORFLOW_CORE_LIB_CORE_RAW_CODING_H_
#include <string.h>
#include "tensorflow/core/platform/byte_order.h"
@@ -68,4 +68,4 @@ inline uint64 DecodeFixed64(const char* ptr) {
} // namespace core
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_CORE_RAW_CODING_H_
+#endif // TENSORFLOW_CORE_LIB_CORE_RAW_CODING_H_
diff --git a/tensorflow/core/lib/core/status.cc b/tensorflow/core/lib/core/status.cc
index 12dfcd284f..cb2a06e620 100644
--- a/tensorflow/core/lib/core/status.cc
+++ b/tensorflow/core/lib/core/status.cc
@@ -22,7 +22,7 @@ Status::Status(tensorflow::error::Code code, StringPiece msg) {
assert(code != tensorflow::error::OK);
state_ = std::unique_ptr<State>(new State);
state_->code = code;
- state_->msg = msg.ToString();
+ state_->msg = string(msg);
}
void Status::Update(const Status& new_status) {
diff --git a/tensorflow/core/lib/core/status_test_util.h b/tensorflow/core/lib/core/status_test_util.h
index b35633c9da..c695caa8d1 100644
--- a/tensorflow/core/lib/core/status_test_util.h
+++ b/tensorflow/core/lib/core/status_test_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_CORE_STATUS_TEST_UTIL_H_
-#define TENSORFLOW_LIB_CORE_STATUS_TEST_UTIL_H_
+#ifndef TENSORFLOW_CORE_LIB_CORE_STATUS_TEST_UTIL_H_
+#define TENSORFLOW_CORE_LIB_CORE_STATUS_TEST_UTIL_H_
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
@@ -31,4 +31,4 @@ limitations under the License.
// If you want to check for particular errors, a better alternative is:
// EXPECT_EQ(..expected tensorflow::error::Code..., status.code());
-#endif // TENSORFLOW_LIB_CORE_STATUS_TEST_UTIL_H_
+#endif // TENSORFLOW_CORE_LIB_CORE_STATUS_TEST_UTIL_H_
diff --git a/tensorflow/core/lib/core/stringpiece.cc b/tensorflow/core/lib/core/stringpiece.cc
deleted file mode 100644
index 4c488066e4..0000000000
--- a/tensorflow/core/lib/core/stringpiece.cc
+++ /dev/null
@@ -1,54 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/core/stringpiece.h"
-
-#include <algorithm>
-#include <iostream>
-
-namespace tensorflow {
-
-std::ostream& operator<<(std::ostream& o, StringPiece piece) {
- o.write(piece.data(), piece.size());
- return o;
-}
-
-size_t StringPiece::find(char c, size_t pos) const {
- if (pos >= size_) {
- return npos;
- }
- const char* result =
- reinterpret_cast<const char*>(memchr(data_ + pos, c, size_ - pos));
- return result != nullptr ? result - data_ : npos;
-}
-
-// Search range is [0..pos] inclusive. If pos == npos, search everything.
-size_t StringPiece::rfind(char c, size_t pos) const {
- if (size_ == 0) return npos;
- for (const char* p = data_ + std::min(pos, size_ - 1); p >= data_; p--) {
- if (*p == c) {
- return p - data_;
- }
- }
- return npos;
-}
-
-StringPiece StringPiece::substr(size_t pos, size_t n) const {
- if (pos > size_) pos = size_;
- if (n > size_ - pos) n = size_ - pos;
- return StringPiece(data_ + pos, n);
-}
-
-} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h
index d7ecc44e50..e7b17c9b36 100644
--- a/tensorflow/core/lib/core/stringpiece.h
+++ b/tensorflow/core/lib/core/stringpiece.h
@@ -23,129 +23,22 @@ limitations under the License.
// non-const method, all threads accessing the same StringPiece must use
// external synchronization.
-#ifndef TENSORFLOW_LIB_CORE_STRINGPIECE_H_
-#define TENSORFLOW_LIB_CORE_STRINGPIECE_H_
+#ifndef TENSORFLOW_CORE_LIB_CORE_STRINGPIECE_H_
+#define TENSORFLOW_CORE_LIB_CORE_STRINGPIECE_H_
#include <assert.h>
#include <stddef.h>
#include <string.h>
#include <iosfwd>
#include <string>
+#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
-class StringPiece {
- public:
- typedef size_t size_type;
-
- // Create an empty slice.
- StringPiece() : data_(nullptr), size_(0) {}
-
- // Create a slice that refers to d[0,n-1].
- StringPiece(const char* d, size_t n) : data_(d), size_(n) {}
-
- // Create a slice that refers to the contents of "s"
- StringPiece(const string& s) : data_(s.data()), size_(s.size()) {}
-
- // Create a slice that refers to s[0,strlen(s)-1]
- StringPiece(const char* s) : data_(s), size_(strlen(s)) {}
-
- // Return a pointer to the beginning of the referenced data
- const char* data() const { return data_; }
-
- // Return the length (in bytes) of the referenced data
- size_t size() const { return size_; }
-
- // Return true iff the length of the referenced data is zero
- bool empty() const { return size_ == 0; }
-
- typedef const char* const_iterator;
- typedef const char* iterator;
- iterator begin() const { return data_; }
- iterator end() const { return data_ + size_; }
-
- static const size_t npos = size_type(-1);
-
- // Return the ith byte in the referenced data.
- // REQUIRES: n < size()
- char operator[](size_t n) const {
- assert(n < size());
- return data_[n];
- }
-
- // Drop the first "n" bytes from this slice.
- void remove_prefix(size_t n) {
- assert(n <= size());
- data_ += n;
- size_ -= n;
- }
-
- void remove_suffix(size_t n) {
- assert(size_ >= n);
- size_ -= n;
- }
-
- size_t find(char c, size_t pos = 0) const;
- size_t rfind(char c, size_t pos = npos) const;
-
- StringPiece substr(size_t pos, size_t n = npos) const;
-
- // Return a string that contains the copy of the referenced data.
- // DEPRECATED: use std::string(sv) instead.
- std::string ToString() const { return std::string(data_, size_); }
-
- // Three-way comparison. Returns value:
- // < 0 iff "*this" < "b",
- // == 0 iff "*this" == "b",
- // > 0 iff "*this" > "b"
- int compare(StringPiece b) const;
-
- // Converts to `std::basic_string`.
- template <typename A>
- explicit operator std::basic_string<char, std::char_traits<char>, A>() const {
- if (!data()) return {};
- return std::basic_string<char, std::char_traits<char>, A>(data(), size());
- }
-
- private:
- const char* data_;
- size_t size_;
-
- // Intentionally copyable
-};
-
-inline bool operator==(StringPiece x, StringPiece y) {
- return ((x.size() == y.size()) &&
- (memcmp(x.data(), y.data(), x.size()) == 0));
-}
-
-inline bool operator!=(StringPiece x, StringPiece y) { return !(x == y); }
-
-inline bool operator<(StringPiece x, StringPiece y) { return x.compare(y) < 0; }
-inline bool operator>(StringPiece x, StringPiece y) { return x.compare(y) > 0; }
-inline bool operator<=(StringPiece x, StringPiece y) {
- return x.compare(y) <= 0;
-}
-inline bool operator>=(StringPiece x, StringPiece y) {
- return x.compare(y) >= 0;
-}
-
-inline int StringPiece::compare(StringPiece b) const {
- const size_t min_len = (size_ < b.size_) ? size_ : b.size_;
- int r = memcmp(data_, b.data_, min_len);
- if (r == 0) {
- if (size_ < b.size_)
- r = -1;
- else if (size_ > b.size_)
- r = +1;
- }
- return r;
-}
-
-// allow StringPiece to be logged
-extern std::ostream& operator<<(std::ostream& o, tensorflow::StringPiece piece);
+// Deprecated: please use absl::string_view directly.
+using StringPiece = absl::string_view;
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_CORE_STRINGPIECE_H_
+#endif // TENSORFLOW_CORE_LIB_CORE_STRINGPIECE_H_
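
[Editorial sketch] With the alias in place, the removed members map directly onto absl::string_view. A minimal migration sketch:

    tensorflow::StringPiece sp("hello world");
    size_t pos = sp.find(' ');                       // find/rfind/substr now come from absl
    tensorflow::StringPiece word = sp.substr(0, pos);
    std::string owned(sp);                           // replaces the removed sp.ToString()
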
diff --git a/tensorflow/core/lib/core/stringpiece_test.cc b/tensorflow/core/lib/core/stringpiece_test.cc
index 952b9eaaaa..e4b489fe17 100644
--- a/tensorflow/core/lib/core/stringpiece_test.cc
+++ b/tensorflow/core/lib/core/stringpiece_test.cc
@@ -56,8 +56,8 @@ TEST(StringPiece, Ctor) {
}
TEST(StringPiece, ConversionToString) {
- EXPECT_EQ("", std::string(StringPiece("")));
- EXPECT_EQ("foo", std::string(StringPiece("foo")));
+ EXPECT_EQ("", string(StringPiece("")));
+ EXPECT_EQ("foo", string(StringPiece("foo")));
}
} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/threadpool.h b/tensorflow/core/lib/core/threadpool.h
index b89b74b8de..74df7c84a4 100644
--- a/tensorflow/core/lib/core/threadpool.h
+++ b/tensorflow/core/lib/core/threadpool.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_CORE_THREADPOOL_H_
-#define TENSORFLOW_LIB_CORE_THREADPOOL_H_
+#ifndef TENSORFLOW_CORE_LIB_CORE_THREADPOOL_H_
+#define TENSORFLOW_CORE_LIB_CORE_THREADPOOL_H_
#include <functional>
#include <memory>
@@ -108,4 +108,4 @@ class ThreadPool {
} // namespace thread
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_CORE_THREADPOOL_H_
+#endif // TENSORFLOW_CORE_LIB_CORE_THREADPOOL_H_
diff --git a/tensorflow/core/lib/gtl/array_slice.h b/tensorflow/core/lib/gtl/array_slice.h
index 002d166c72..8f47faf89e 100644
--- a/tensorflow/core/lib/gtl/array_slice.h
+++ b/tensorflow/core/lib/gtl/array_slice.h
@@ -13,302 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// An ArraySlice<T> represents an immutable array of elements of type
-// T. It has a length "length", and a base pointer "ptr", and the
-// array it represents contains the elements "ptr[0] .. ptr[len-1]".
-// The backing store for the array is *not* owned by the ArraySlice
-// object, and clients must arrange for the backing store to remain
-// live while the ArraySlice object is in use.
-//
-// An ArraySlice<T> is somewhat analogous to a StringPiece, but for
-// array elements of type T.
-//
-// Implicit conversion operations are provided from types such as
-// std::vector<T> and util::gtl::InlinedVector<T, N>. Note that ArraySlice
-// objects constructed from types in this way may be invalidated by
-// any operations that mutate the underlying vector.
-//
-// One common use for ArraySlice is when passing arguments to a
-// routine where you want to be able to accept a variety of array
-// types (e.g. a vector, a util::gtl::InlinedVector, a C-style array,
-// etc.). The usual approach here is to have the client explicitly
-// pass in a pointer and a length, as in:
-//
-// void MyRoutine(const int* elems, int N) {
-// for (int i = 0; i < N; i++) { .. do something with elems[i] .. }
-// }
-//
-// Unfortunately, this leads to ugly and error-prone code at the call site:
-//
-// std::vector<int> my_vector;
-// MyRoutine(vector_as_array(&my_vector), my_vector.size());
-//
-// util::gtl::InlinedVector<int, 4> my_inline_vector;
-// MyRoutine(my_inline_vector.array(), my_inline_vector.size());
-//
-// int my_array[10];
-// MyRoutine(my_array, 10);
-//
-// Instead, you can use an ArraySlice as the argument to the routine:
-//
-// void MyRoutine(ArraySlice<int> a) {
-// for (int i = 0; i < a.size(); i++) { .. do something with a[i] .. }
-// }
-//
-// This makes the call sites cleaner, for the most part:
-//
-// std::vector<int> my_vector;
-// MyRoutine(my_vector);
-//
-// util::gtl::InlinedVector<int, 4> my_inline_vector;
-// MyRoutine(my_inline_vector);
-//
-// int my_array[10];
-// MyRoutine(my_array);
-//
-// int* my_array = new int[10];
-// MyRoutine(gtl::ArraySlice<int>(my_array, 10));
-//
-// MutableArraySlice<T> represents a mutable array of elements, and, like
-// ArraySlice, does not own the backing store. The implicit constructors it
-// provides allow functions not to worry about whether their mutable arguments
-// refer to vectors, arrays, proto2::RepeatedFields, etc.:
-//
-// void MyMutatingRoutine(MutableArraySlice<int> a) {
-// for (int i = 0; i < a.size(); i++) { .. mutate a[i] .. }
-// }
-//
-// std::vector<int> my_vector;
-// MyMutatingRoutine(&my_vector);
-//
-// int my_array[10];
-// MyMutatingRoutine(my_array);
-//
-// int* my_array = new int[10];
-// MyMutatingRoutine(gtl::MutableArraySlice<int>(my_array, 10));
-//
-// MyProto my_proto;
-// for (int i = 0; i < 10; ++i) { my_proto.add_value(i); }
-// MyMutatingRoutine(my_proto.mutable_value());
+#ifndef TENSORFLOW_CORE_LIB_GTL_ARRAY_SLICE_H_
+#define TENSORFLOW_CORE_LIB_GTL_ARRAY_SLICE_H_
-#ifndef TENSORFLOW_LIB_GTL_ARRAY_SLICE_H_
-#define TENSORFLOW_LIB_GTL_ARRAY_SLICE_H_
-
-#include <initializer_list>
-#include <type_traits>
-#include <vector>
-
-#include "tensorflow/core/lib/gtl/array_slice_internal.h"
+#include "absl/types/span.h"
+// TODO(timshen): This is kept only because lots of targets transitively depend
+// on it. Remove all targets' dependencies.
#include "tensorflow/core/lib/gtl/inlined_vector.h"
namespace tensorflow {
namespace gtl {
template <typename T>
-class ArraySlice {
- private:
- typedef array_slice_internal::ArraySliceImpl<T> Impl;
-
- public:
- typedef T value_type;
- typedef typename Impl::pointer pointer;
- typedef typename Impl::const_pointer const_pointer;
- typedef typename Impl::reference reference;
- typedef typename Impl::const_reference const_reference;
- typedef typename Impl::iterator iterator;
- typedef typename Impl::const_iterator const_iterator;
- typedef typename Impl::reverse_iterator reverse_iterator;
- typedef typename Impl::const_reverse_iterator const_reverse_iterator;
- typedef typename Impl::size_type size_type;
- typedef typename Impl::difference_type difference_type;
-
- static const size_type npos = Impl::npos;
-
- ArraySlice() : impl_(nullptr, 0) {}
- ArraySlice(const_pointer array, size_type length) : impl_(array, length) {}
-
- // Implicit conversion constructors
- ArraySlice(const std::vector<value_type>& v) // NOLINT(runtime/explicit)
- : impl_(v.data(), v.size()) {}
-
- template <size_t N>
- ArraySlice(const value_type (&a)[N]) // NOLINT(runtime/explicit)
- : impl_(a, N) {}
-
- template <int N>
- ArraySlice(const InlinedVector<value_type, N>& v) // NOLINT(runtime/explicit)
- : impl_(v.data(), v.size()) {}
-
- // The constructor for any class supplying 'data() const' that returns either
- // const T* or a less const-qualified version of it, and 'some_integral_type
- // size() const'. proto2::RepeatedField<T>, string and (since C++11)
- // std::vector<T,A> and std::array<T, N> are examples of this. See
- // array_slice_internal.h for details.
- template <typename V,
- typename = typename Impl::template EnableIfConvertibleFrom<V>>
- ArraySlice(const V& v) // NOLINT(runtime/explicit)
- : impl_(v) {}
-
- // Implicitly constructs an ArraySlice from an initializer list. This makes it
- // possible to pass a brace-enclosed initializer list to a function expecting
- // an ArraySlice:
- // void Process(ArraySlice<int> x);
- // Process({1, 2, 3});
- // The data referenced by the initializer_list must outlive this
- // ArraySlice. For example, "ArraySlice<int> s={1,2};" and "return
- // ArraySlice<int>({3,4});" are errors, as the resulting ArraySlice may
- // reference data that is no longer valid.
- ArraySlice(std::initializer_list<value_type> v) // NOLINT(runtime/explicit)
- : impl_(v.begin(), v.size()) {}
+using ArraySlice = absl::Span<const T>;
- // Substring of another ArraySlice.
- // pos must be non-negative and <= x.length().
- // len must be non-negative and will be pinned to at most x.length() - pos.
- // If len==npos, the substring continues till the end of x.
- ArraySlice(const ArraySlice& x, size_type pos, size_type len)
- : impl_(x.impl_, pos, len) {}
-
- const_pointer data() const { return impl_.data(); }
- size_type size() const { return impl_.size(); }
- size_type length() const { return size(); }
- bool empty() const { return size() == 0; }
-
- void clear() { impl_.clear(); }
-
- const_reference operator[](size_type i) const { return impl_[i]; }
- const_reference at(size_type i) const { return impl_.at(i); }
- const_reference front() const { return impl_.front(); }
- const_reference back() const { return impl_.back(); }
-
- const_iterator begin() const { return impl_.begin(); }
- const_iterator end() const { return impl_.end(); }
- const_reverse_iterator rbegin() const { return impl_.rbegin(); }
- const_reverse_iterator rend() const { return impl_.rend(); }
-
- void remove_prefix(size_type n) { impl_.remove_prefix(n); }
- void remove_suffix(size_type n) { impl_.remove_suffix(n); }
- void pop_back() { remove_suffix(1); }
- void pop_front() { remove_prefix(1); }
-
- // These relational operators have the same semantics as the
- // std::vector<T> relational operators: they do deep (element-wise)
- // comparisons. Array slices are equal iff their size is the same
- // and all their elements are equal.
- bool operator==(ArraySlice<T> other) const { return impl_ == other.impl_; }
- bool operator!=(ArraySlice<T> other) const { return impl_ != other.impl_; }
-
- private:
- Impl impl_;
-};
-
-// Mutable version of ArraySlice, which allows the clients to mutate the
-// underlying data. It is implicitly convertible to ArraySlice since it provides
-// the data() and size() methods with correct signatures. When a
-// MutableArraySlice is created from a pointer to a container (as opposed to raw
-// memory pointer), the pointer must not be null.
-//
-// A note on const-ness: "mutable" here refers to the mutability of the
-// underlying data, not of the slice itself. It is perfectly reasonable to have
-// a variable of type "const MutableArraySlice<T>"; this means that the bounds
-// of the view on the array cannot be changed, but the underlying data in the
-// array still may be modified. This is akin to a "T* const" pointer, as opposed
-// to a "const T*" pointer (corresponding to a non-const ArraySlice<T>).
-template <typename T>
-class MutableArraySlice {
- private:
- typedef array_slice_internal::MutableArraySliceImpl<T> Impl;
-
- public:
- typedef T value_type;
- typedef typename Impl::pointer pointer;
- typedef typename Impl::const_pointer const_pointer;
- typedef typename Impl::reference reference;
- typedef typename Impl::const_reference const_reference;
- typedef typename Impl::iterator iterator;
- typedef typename Impl::const_iterator const_iterator;
- typedef typename Impl::reverse_iterator reverse_iterator;
- typedef typename Impl::const_reverse_iterator const_reverse_iterator;
- typedef typename Impl::size_type size_type;
- typedef typename Impl::difference_type difference_type;
-
- static const size_type npos = Impl::npos;
-
- MutableArraySlice() : impl_(nullptr, 0) {}
- MutableArraySlice(pointer array, size_type length) : impl_(array, length) {}
-
- // Implicit conversion constructors
- MutableArraySlice(std::vector<value_type>* v) // NOLINT(runtime/explicit)
- : impl_(v->data(), v->size()) {}
-
- template <size_t N>
- MutableArraySlice(value_type (&a)[N]) // NOLINT(runtime/explicit)
- : impl_(a, N) {}
-
- template <int N>
- MutableArraySlice(
- InlinedVector<value_type, N>* v) // NOLINT(runtime/explicit)
- : impl_(v->data(), v->size()) {}
-
- // The constructor for any class supplying 'T* data()' or 'T* mutable_data()'
- // (the former is called if both exist), and 'some_integral_type size()
- // const'. proto2::RepeatedField is an example of this. Also supports string
- // arguments, when T==char. The appropriate ctor is selected using SFINAE. See
- // array_slice_internal.h for details.
- template <typename V,
- typename = typename Impl::template EnableIfConvertibleFrom<V>>
- MutableArraySlice(V* v) // NOLINT(runtime/explicit)
- : impl_(v) {}
-
- // Substring of another MutableArraySlice.
- // pos must be non-negative and <= x.length().
- // len must be non-negative and will be pinned to at most x.length() - pos.
- // If len==npos, the substring continues till the end of x.
- MutableArraySlice(const MutableArraySlice& x, size_type pos, size_type len)
- : impl_(x.impl_, pos, len) {}
-
- // Accessors.
- pointer data() const { return impl_.data(); }
- size_type size() const { return impl_.size(); }
- size_type length() const { return size(); }
- bool empty() const { return size() == 0; }
-
- void clear() { impl_.clear(); }
-
- reference operator[](size_type i) const { return impl_[i]; }
- reference at(size_type i) const { return impl_.at(i); }
- reference front() const { return impl_.front(); }
- reference back() const { return impl_.back(); }
-
- iterator begin() const { return impl_.begin(); }
- iterator end() const { return impl_.end(); }
- reverse_iterator rbegin() const { return impl_.rbegin(); }
- reverse_iterator rend() const { return impl_.rend(); }
-
- void remove_prefix(size_type n) { impl_.remove_prefix(n); }
- void remove_suffix(size_type n) { impl_.remove_suffix(n); }
- void pop_back() { remove_suffix(1); }
- void pop_front() { remove_prefix(1); }
-
- bool operator==(ArraySlice<T> other) const {
- return ArraySlice<T>(*this) == other;
- }
- bool operator!=(ArraySlice<T> other) const {
- return ArraySlice<T>(*this) != other;
- }
-
- // DEPRECATED(jacobsa): Please use data() instead.
- pointer mutable_data() const { return impl_.data(); }
-
- private:
- Impl impl_;
-};
-
-template <typename T>
-const typename ArraySlice<T>::size_type ArraySlice<T>::npos;
template <typename T>
-const typename MutableArraySlice<T>::size_type MutableArraySlice<T>::npos;
+using MutableArraySlice = absl::Span<T>;
} // namespace gtl
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_GTL_ARRAY_SLICE_H_
+#endif // TENSORFLOW_CORE_LIB_GTL_ARRAY_SLICE_H_
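
[Editorial sketch] Existing call sites keep compiling through the aliases. A minimal sketch of both directions, assuming absl::Span semantics:

    #include <vector>
    #include "absl/types/span.h"
    #include "tensorflow/core/lib/gtl/array_slice.h"

    int Sum(tensorflow::gtl::ArraySlice<int> a) {          // == absl::Span<const int>
      int total = 0;
      for (int v : a) total += v;
      return total;
    }

    void Zero(tensorflow::gtl::MutableArraySlice<int> a) {  // == absl::Span<int>
      for (int& v : a) v = 0;
    }

    // std::vector<int> v = {1, 2, 3};
    // Sum(v);                  // implicit conversion from containers, as before
    // Zero(absl::MakeSpan(v)); // explicit spelling for the mutable span
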
diff --git a/tensorflow/core/lib/gtl/array_slice_internal.h b/tensorflow/core/lib/gtl/array_slice_internal.h
deleted file mode 100644
index 689dd8a646..0000000000
--- a/tensorflow/core/lib/gtl/array_slice_internal.h
+++ /dev/null
@@ -1,269 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// NOT FOR INCLUSION BY CLIENT CODE. This file is only to be included by
-// array_slice.h.
-
-// Helper functions and templates for ArraySlice.
-
-#ifndef TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_
-#define TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_
-
-#include <stddef.h>
-#include <algorithm>
-#include <iterator>
-#include <memory>
-#include <string>
-#include <type_traits>
-#include <utility>
-#include <vector>
-#include "tensorflow/core/platform/logging.h"
-
-namespace tensorflow {
-namespace gtl {
-namespace array_slice_internal {
-
-// Template logic for generic constructors.
-
-// Wrappers whose Get() delegates to the appropriate method of a container, and
-// is defined when this method exists. Delegates to the const method if C is a
-// const type.
-struct Data {
- template <typename C>
- static decltype(std::declval<C>().data()) Get(C* v) {
- return v->data();
- }
-};
-
-struct MutableData {
- template <typename C>
- static decltype(std::declval<C>().mutable_data()) Get(C* v) {
- return v->mutable_data();
- }
-};
-
-struct Size {
- template <typename C>
- static decltype(std::declval<C>().size()) Get(C* v) {
- return v->size();
- }
-};
-
-struct MutableStringData {
- // Defined only for string.
- static char* Get(string* v) { return v->empty() ? nullptr : &*v->begin(); }
-};
-
-// Checks whether M::Get(C*) is defined and has a return type R such that
-// Checker::valid<R>()==true.
-template <typename M, typename Checker, typename C>
-struct HasGetHelper : public M {
- private:
- struct None {};
- // M::Get is selected when it is viable. Get(...) is selected otherwise.
- using M::Get;
- static None Get(...);
-
- public:
- static constexpr bool HasGet() {
- using Result = decltype(Get(std::declval<C*>()));
- return !std::is_same<Result, None>() && Checker::template valid<Result>();
- }
-};
-
-// Defines HasGet() for a particular method, container, and checker. If
-// HasGet()==true, provides Get() that delegates to the method.
-template <typename M, typename Checker, typename C,
- bool /*has_get*/ = HasGetHelper<M, Checker, C>::HasGet()>
-struct Wrapper {
- static constexpr bool HasGet() { return false; }
-};
-
-template <typename M, typename Checker, typename C>
-struct Wrapper<M, Checker, C, true> {
- static constexpr bool HasGet() { return true; }
- static decltype(M::Get(std::declval<C*>())) Get(C* v) { return M::Get(v); }
-};
-
-// Type checker for a method returning an integral value.
-struct SizeChecker {
- template <typename R>
- static constexpr bool valid() {
- return std::is_integral<R>::value;
- }
-};
-
-// Type checker for a method returning either a pointer to T or a less const
-// version of that.
-template <typename T>
-struct DataChecker {
- // We want to enable conversion from std::vector<T*> to ArraySlice<const T*>
- // but
- // disable conversion from std::vector<Derived> to ArraySlice<Base>. Here we
- // use
- // the fact that U** is convertible to Q* const* if and only if Q is the same
- // type or a more cv-qualified version of U.
- template <typename R>
- static constexpr bool valid() {
- return std::is_convertible<R*, T* const*>::value;
- }
-};
-
-// Aliases to A if A::HasGet()==true, or to B otherwise.
-template <typename A, typename B>
-using FirstWithGet = typename std::conditional<A::HasGet(), A, B>::type;
-
-// Wraps C::data() const, returning a pointer to const data.
-template <typename T, typename C>
-using ContainerData = Wrapper<Data, DataChecker<const T>, const C>;
-
-// Wraps a method returning a pointer to mutable data. Prefers data() over
-// mutable_data(), and handles strings when T==char. If data() returns a pointer
-// to mutable data, it is most likely overloaded, but may also be a single
-// method 'T* C::data() const' in a non-STL-compliant container.
-template <typename T, typename C>
-using ContainerMutableData =
- FirstWithGet<Wrapper<Data, DataChecker<T>, C>,
- FirstWithGet<Wrapper<MutableData, DataChecker<T>, C>,
- Wrapper<MutableStringData, DataChecker<T>, C>>>;
-
-// Wraps C::size() const.
-template <typename C>
-using ContainerSize = Wrapper<Size, SizeChecker, const C>;
-
-// Implementation class for ArraySlice and MutableArraySlice. In the case of
-// ArraySlice, T will be a const type; for MutableArraySlice, T will be a
-// mutable type.
-template <typename T>
-class ArraySliceImplBase {
- public:
- typedef T* pointer;
- typedef const T* const_pointer;
- typedef T& reference;
- typedef const T& const_reference;
- typedef pointer iterator;
- typedef const_pointer const_iterator;
- typedef std::reverse_iterator<iterator> reverse_iterator;
- typedef std::reverse_iterator<const_iterator> const_reverse_iterator;
- typedef size_t size_type;
- typedef ptrdiff_t difference_type;
-
- static const size_type npos = static_cast<size_type>(-1);
-
- ArraySliceImplBase(pointer array, size_type length)
- : ptr_(array), length_(length) {}
-
- // Substring of another ArraySlice.
- // pos must be non-negative and <= x.length().
- // len must be non-negative and will be pinned to at most x.length() - pos.
- ArraySliceImplBase(const ArraySliceImplBase& x, size_type pos, size_type len)
- : ptr_(x.ptr_ + pos), length_(std::min(x.length_ - pos, len)) {}
-
- // Some of the const methods below return pointers and references to mutable
- // data. This is only the case in this internal class; ArraySlice and
- // MutableArraySlice provide deep-constness.
-
- pointer data() const { return ptr_; }
- size_type size() const { return length_; }
-
- void clear() {
- ptr_ = nullptr;
- length_ = 0;
- }
-
- reference operator[](size_type i) const { return ptr_[i]; }
- reference at(size_type i) const {
- DCHECK_LT(i, length_);
- return ptr_[i];
- }
- reference front() const {
- DCHECK_GT(length_, 0);
- return ptr_[0];
- }
- reference back() const {
- DCHECK_GT(length_, 0);
- return ptr_[length_ - 1];
- }
-
- void remove_prefix(size_type n) {
- DCHECK_GE(length_, n);
- ptr_ += n;
- length_ -= n;
- }
- void remove_suffix(size_type n) {
- DCHECK_GE(length_, n);
- length_ -= n;
- }
-
- iterator begin() const { return ptr_; }
- iterator end() const { return ptr_ + length_; }
- reverse_iterator rbegin() const { return reverse_iterator(end()); }
- reverse_iterator rend() const { return reverse_iterator(begin()); }
-
- bool operator==(const ArraySliceImplBase& other) const {
- if (size() != other.size()) return false;
- if (data() == other.data()) return true;
- return std::equal(data(), data() + size(), other.data());
- }
- bool operator!=(const ArraySliceImplBase& other) const {
- return !(*this == other);
- }
-
- private:
- pointer ptr_;
- size_type length_;
-};
-
-template <typename T>
-class ArraySliceImpl : public ArraySliceImplBase<const T> {
- public:
- using ArraySliceImplBase<const T>::ArraySliceImplBase;
-
- // Defined iff the data and size accessors for the container C have been
- // defined.
- template <typename C>
- using EnableIfConvertibleFrom =
- typename std::enable_if<ContainerData<T, C>::HasGet() &&
- ContainerSize<C>::HasGet()>::type;
-
- // Constructs from a container when EnableIfConvertibleFrom is
- // defined. std::addressof handles types with overloaded operator&.
- template <typename C>
- explicit ArraySliceImpl(const C& v)
- : ArraySliceImplBase<const T>(ContainerData<T, C>::Get(std::addressof(v)),
- ContainerSize<C>::Get(std::addressof(v))) {}
-};
-
-template <typename T>
-class MutableArraySliceImpl : public ArraySliceImplBase<T> {
- public:
- using ArraySliceImplBase<T>::ArraySliceImplBase;
-
- template <typename C>
- using EnableIfConvertibleFrom =
- typename std::enable_if<ContainerMutableData<T, C>::HasGet() &&
- ContainerSize<C>::HasGet()>::type;
-
- template <typename C>
- explicit MutableArraySliceImpl(C* v)
- : ArraySliceImplBase<T>(ContainerMutableData<T, C>::Get(v),
- ContainerSize<C>::Get(v)) {}
-};
-
-} // namespace array_slice_internal
-} // namespace gtl
-} // namespace tensorflow
-
-#endif // TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_
diff --git a/tensorflow/core/lib/gtl/array_slice_test.cc b/tensorflow/core/lib/gtl/array_slice_test.cc
deleted file mode 100644
index 4d3da85b88..0000000000
--- a/tensorflow/core/lib/gtl/array_slice_test.cc
+++ /dev/null
@@ -1,666 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/gtl/array_slice.h"
-
-#include <algorithm>
-#include <array>
-#include <string>
-#include <vector>
-
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
-#include "tensorflow/core/lib/gtl/stl_util.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-namespace gtl {
-namespace {
-
-typedef ArraySlice<int> IntSlice;
-typedef ArraySlice<char> CharSlice;
-typedef MutableArraySlice<int> MutableIntSlice;
-typedef MutableArraySlice<char> MutableCharSlice;
-typedef std::vector<int> IntVec;
-
-// Append 0..len-1 to *v
-template <typename Vector>
-static void Fill(Vector* v, int len, int offset = 0) {
- for (int i = 0; i < len; i++) {
- v->push_back(i + offset);
- }
-}
-
-static void TestHelper(const IntSlice& vorig, const IntVec& vec) {
- IntSlice other; // To test the assignment return value.
- IntSlice v = other = vorig;
- const int len = vec.size();
- EXPECT_EQ(v.size(), vec.size());
-
- for (int i = 0; i < len; i++) {
- EXPECT_EQ(v[i], vec[i]);
- EXPECT_EQ(v.at(i), vec[i]);
- }
- EXPECT_EQ(v.begin(), gtl::vector_as_array(&vec));
-
- int counter = 0;
- for (IntSlice::iterator it = v.begin(); it != v.end(); ++it) {
- EXPECT_EQ(counter, *it);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- counter = 0;
- for (IntSlice::const_iterator it = v.begin(); it != v.end(); ++it) {
- EXPECT_EQ(counter, *it);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- if (len > 0) {
- EXPECT_EQ(0, v.front());
- EXPECT_EQ(len - 1, v.back());
- v.pop_back();
- EXPECT_EQ(len - 1, v.size());
- for (size_t i = 0; i < v.size(); ++i) {
- EXPECT_EQ(i, v[i]);
- }
- if (len > 1) {
- v.pop_front();
- EXPECT_EQ(len - 2, v.size());
- for (size_t i = 0; i < v.size(); ++i) {
- EXPECT_EQ(i + 1, v[i]);
- }
- }
- }
-}
-
-// The element access test that is applicable both when MutableArraySlice is
-// const and when it's not.
-template <class V>
-void MutableTestHelperTemplated(V v, int* ptr, const int len) {
- CHECK_EQ(v.size(), len);
-
- for (int i = 0; i < len; i++) {
- EXPECT_EQ(ptr + i, &v[i]);
- EXPECT_EQ(ptr + i, &v.at(i));
- }
- EXPECT_EQ(ptr, v.begin());
- EXPECT_EQ(ptr + len, v.end());
- EXPECT_EQ(ptr, v.data());
-
- int counter = 0;
- for (MutableIntSlice::const_iterator it = v.begin(); it != v.end(); ++it) {
- EXPECT_EQ(ptr + counter, &*it);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- EXPECT_EQ(len, std::distance(v.rbegin(), v.rend()));
-
- if (len > 0) {
- EXPECT_EQ(ptr, &v.front());
- EXPECT_EQ(ptr + len - 1, &v.back());
- EXPECT_EQ(ptr + len - 1, &*v.rbegin());
- EXPECT_EQ(ptr, &*(v.rend() - 1));
- }
-}
-
-static void MutableTestHelper(const MutableIntSlice& vorig, int* ptr,
- const int len) {
- // Test the data accessors both when the MutableArraySlice is declared const,
- // and when it is not.
- MutableTestHelperTemplated<const MutableIntSlice&>(vorig, ptr, len);
- MutableTestHelperTemplated<MutableIntSlice>(vorig, ptr, len);
-
- MutableIntSlice other; // To test the assignment return value.
- MutableIntSlice v = other = vorig;
- EXPECT_EQ(ptr, v.mutable_data());
-
- int counter = 0;
- for (MutableIntSlice::iterator it = v.begin(); it != v.end(); ++it) {
- EXPECT_EQ(ptr + counter, &*it);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- if (len > 0) {
- // Test that elements are assignable.
- v[0] = 1;
- v.front() = 2;
- v.back() = 5;
- *v.mutable_data() = 4;
- std::fill(v.begin(), v.end(), 5);
- std::fill(v.rbegin(), v.rend(), 6);
- // Test size-changing methods.
- v.pop_back();
- EXPECT_EQ(len - 1, v.size());
- for (size_t i = 0; i < v.size(); ++i) {
- EXPECT_EQ(ptr + i, &v[i]);
- }
- if (len > 1) {
- v.pop_front();
- EXPECT_EQ(len - 2, v.size());
- for (size_t i = 0; i < v.size(); ++i) {
- EXPECT_EQ(ptr + i + 1, &v[i]);
- }
- }
- }
-}
-
-template <typename Vector>
-static void TestImplicitConversion(const IntSlice& v, const Vector& vec) {
- EXPECT_EQ(v.size(), vec.size());
- for (size_t i = 0; i < v.size(); i++) {
- EXPECT_EQ(v[i], vec[i]);
- }
-}
-
-template <typename Vector>
-static void TestImplicitConversion(const CharSlice& v, const Vector& vec) {
- TestImplicitConversion(IntVec(v.begin(), v.end()), vec);
-}
-
-static void TestImplicitConversion(const MutableIntSlice& v, const int* data,
- int size) {
- EXPECT_EQ(size, v.size());
- for (size_t i = 0; i < v.size(); i++) {
- EXPECT_EQ(data + i, &v[i]);
- }
-}
-
-static void TestImplicitConversion(const MutableCharSlice& v, const char* data,
- int size) {
- EXPECT_EQ(size, v.size());
- for (size_t i = 0; i < v.size(); i++) {
- EXPECT_EQ(data + i, &v[i]);
- }
-}
-// A struct supplying the data(), mutable_data() and size() methods, just like
-// e.g. proto2::RepeatedField.
-struct RepeatedField {
- std::vector<int> storage;
- const int* data() const { return storage.data(); }
- int* mutable_data() { return storage.data(); }
- int size() const { return storage.size(); }
-};
-
-// A struct supplying the data() (both mutable and const versions) and
-// size(). It also supplies mutable_data() but we test that data() is selected
-// instead.
-struct ContainerWithOverloads {
- std::vector<int> storage;
- std::vector<int> wrong_storage;
- const int* data() const { return storage.data(); }
- int* data() { return storage.data(); }
- // MutableArraySlice should not call mutable_data(), preferring data()
- // instead.
- int* mutable_data() { return wrong_storage.data(); }
- int size() const { return storage.size(); }
-};
-
-// A struct supplying data() and size() methods.
-struct ContainerWithShallowConstData {
- std::vector<int> storage;
- int* data() const { return const_cast<int*>(storage.data()); }
- int size() const { return storage.size(); }
-};
-
-TEST(IntSlice, Simple) {
- for (int len = 0; len < 20; len++) {
- IntVec vec;
- Fill(&vec, len);
- TestHelper(IntSlice(vec), vec);
- TestHelper(IntSlice(vec.data(), vec.size()), vec);
- }
-}
-
-TEST(IntSlice, WithPosAndLen) {
- IntVec vec;
- Fill(&vec, 20);
- for (size_t len = 0; len < vec.size(); len++) {
- IntVec subvec(vec.begin(), vec.begin() + len);
- TestImplicitConversion(IntSlice(vec, 0, len), subvec);
- TestImplicitConversion(IntSlice(IntSlice(vec), 0, len), subvec);
- }
- EXPECT_EQ(0, IntSlice(vec, 0, 0).size());
- EXPECT_EQ(0, IntSlice(IntSlice(vec), 0, 0).size());
- TestImplicitConversion(IntSlice(vec, 0, IntSlice::npos), vec);
-}
-
-TEST(IntSlice, Clear) {
- for (int len = 0; len < 20; len++) {
- IntVec vec;
- Fill(&vec, len);
- IntSlice v(vec);
- v.clear();
- EXPECT_EQ(0, v.size());
- EXPECT_EQ(v.begin(), v.end());
- }
-}
-
-TEST(IntSlice, Swap) {
- for (int l1 = 0; l1 < 20; l1++) {
- for (int l2 = 0; l2 < 20; l2++) {
- IntVec avec, bvec;
- Fill(&avec, l1);
- Fill(&bvec, l2, 100);
- IntSlice a(avec), b(bvec);
- using std::swap;
- swap(a, b);
- EXPECT_EQ(l1, b.size());
- EXPECT_EQ(l2, a.size());
- for (int i = 0; i < l1; i++) {
- EXPECT_EQ(i, b[i]);
- }
- for (int i = 0; i < l2; i++) {
- EXPECT_EQ(100 + i, a[i]);
- }
- }
- }
-}
-
-TEST(IntSlice, ImplicitConversion) {
- for (int len = 0; len < 20; len++) {
- IntVec vec;
- Fill(&vec, len);
- IntSlice slice;
- slice = vec;
- TestImplicitConversion(vec, vec);
- TestImplicitConversion(slice, vec);
- TestImplicitConversion(IntSlice(vec.data(), vec.size()), vec);
- }
-}
-
-TEST(IntSlice, InlinedVectorConversion) {
- for (int len = 0; len < 20; len++) {
- InlinedVector<int, 4> inline_vec;
- for (int i = 0; i < len; i++) {
- inline_vec.push_back(i);
- }
- IntVec vec;
- Fill(&vec, len);
- IntSlice v = inline_vec; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(inline_vec, vec);
- }
-}
-
-TEST(IntSlice, StaticArrayConversion) {
- int array[20];
- IntVec vec;
- Fill(&vec, TF_ARRAYSIZE(array));
- std::copy(vec.begin(), vec.end(), array);
- IntSlice v = array; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(array, vec);
-}
-
-TEST(IntSlice, StdArrayConversion) {
- std::array<int, 20> array;
- IntVec vec;
- Fill(&vec, array.size());
- std::copy(vec.begin(), vec.end(), array.begin());
-
- // Check assignment.
- {
- IntSlice v = array;
- static_cast<void>(v);
- }
-
- // Check sub-slice initialization.
- {
- IntSlice v = {array, 10, 15};
- static_cast<void>(v);
- }
-
- TestImplicitConversion(array, vec);
-}
-
-// Values according to the Fill function.
-static const int test_const_array[] = {0, 1, 2};
-
-TEST(IntSlice, ConstStaticArrayConversion) {
- IntVec vec;
- Fill(&vec, TF_ARRAYSIZE(test_const_array));
- IntSlice v = test_const_array; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(test_const_array, vec);
-}
-
-TEST(IntSlice, RepeatedFieldConversion) {
- RepeatedField repeated_field;
- IntVec vec;
- Fill(&vec, 20);
- repeated_field.storage = vec;
- IntSlice v = repeated_field; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(repeated_field, vec);
-}
-
-TEST(IntSlice, ContainerWithOverloadsConversion) {
- ContainerWithOverloads container;
- Fill(&container.storage, 20);
- container.wrong_storage.resize(container.size());
- IntSlice v = container; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(container, container.storage);
-}
-
-TEST(IntSlice, ContainerWithShallowConstDataConversion) {
- ContainerWithShallowConstData container;
- Fill(&container.storage, 20);
- IntSlice v = container; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(container, container.storage);
-}
-
-TEST(IntSlice, MutableIntSliceConversion) {
- IntVec vec(20);
- IntSlice slice = MutableIntSlice(&vec);
- EXPECT_EQ(vec.size(), slice.size());
- EXPECT_EQ(vec.data(), slice.data());
-}
-
-TEST(IntSlice, Equality) {
- IntVec vec1(20);
- IntVec vec2(20);
- // These two slices are from different vectors, but have the same
- // size and have the same elements (right now). They should
- // compare equal.
- const IntSlice from1(vec1);
- const IntSlice from2(vec2);
- EXPECT_EQ(from1, from1);
- EXPECT_EQ(from1, from2);
-
- // This verifies that MutableArraySlices can be compared freely with
- // ArraySlices.
- const MutableIntSlice mutable_from1(&vec1);
- const MutableIntSlice mutable_from2(&vec2);
- EXPECT_EQ(from1, mutable_from1);
- EXPECT_EQ(mutable_from1, from1);
- EXPECT_EQ(mutable_from1, mutable_from2);
- EXPECT_EQ(mutable_from2, mutable_from1);
-
- // With a different size, the array slices should not be equal.
- EXPECT_NE(from1, IntSlice(from1, 0, from1.size() - 1));
-
- // With different contents, the array slices should not be equal.
- ++vec2.back();
- EXPECT_NE(from1, from2);
-}
-
-// Compile-asserts that the argument has the expected type.
-template <typename Expected, typename T>
-void CheckType(const T& value) {
- ::testing::StaticAssertTypeEq<Expected, T>();
-}
-
-TEST(IntSlice, ExposesContainerTypesAndConsts) {
- IntSlice slice;
- const IntSlice const_slice;
- CheckType<IntSlice::iterator>(slice.begin());
- CheckType<IntSlice::const_iterator>(const_slice.end());
- CheckType<IntSlice::const_reverse_iterator>(const_slice.rbegin());
- CheckType<IntSlice::reverse_iterator>(slice.rend());
- ::testing::StaticAssertTypeEq<int, IntSlice::value_type>();
- ::testing::StaticAssertTypeEq<const int*, IntSlice::pointer>();
- ::testing::StaticAssertTypeEq<const int&, IntSlice::const_reference>();
- EXPECT_EQ(static_cast<IntSlice::size_type>(-1), IntSlice::npos);
-}
-
-void TestEmpty(IntSlice slice) { ASSERT_TRUE(slice.empty()); }
-
-void TestRange(IntSlice slice, int from, int to) {
- ASSERT_EQ(to - from + 1, slice.size());
- for (size_t i = 0; i < slice.size(); ++i) {
- EXPECT_EQ(from + i, slice[i]);
- }
-}
-
-TEST(IntSlice, InitializerListConversion) {
- TestEmpty({});
- TestRange({1}, 1, 1);
- TestRange({10, 11, 12, 13}, 10, 13);
-}
-
-TEST(CharSlice, StringConversion) {
- IntVec vec;
- Fill(&vec, 20);
- string str(vec.begin(), vec.end());
- CharSlice v = str; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(str, vec);
-}
-
-TEST(IntPtrSlice, ConstConversion) {
- int one = 1;
- int two = 2;
- std::vector<int*> vec;
- vec.push_back(&one);
- vec.push_back(&two);
- ArraySlice<const int*> v = vec;
- ASSERT_EQ(2, v.size());
- EXPECT_EQ(&one, v[0]);
- EXPECT_EQ(&two, v[1]);
-}
-
-TEST(MutableIntSlice, Simple) {
- for (int len = 0; len < 20; len++) {
- IntVec vec(len);
- MutableTestHelper(MutableIntSlice(&vec), vec.data(), len);
- MutableTestHelper(MutableIntSlice(vec.data(), vec.size()), vec.data(), len);
- }
-}
-
-TEST(MutableIntSlice, WithPosAndLen) {
- IntVec vec(20);
- for (size_t len = 0; len < vec.size(); len++) {
- TestImplicitConversion(MutableIntSlice(&vec, 0, len), vec.data(), len);
- TestImplicitConversion(MutableIntSlice(MutableIntSlice(&vec), 0, len),
- vec.data(), len);
- }
- EXPECT_EQ(0, MutableIntSlice(&vec, 0, 0).size());
- EXPECT_EQ(0, MutableIntSlice(MutableIntSlice(&vec), 0, 0).size());
- TestImplicitConversion(MutableIntSlice(&vec, 0, MutableIntSlice::npos),
- vec.data(), vec.size());
-}
-
-TEST(MutableIntSlice, Clear) {
- for (int len = 0; len < 20; len++) {
- IntVec vec(len);
- MutableIntSlice v(&vec);
- v.clear();
- EXPECT_EQ(0, v.size());
- EXPECT_EQ(v.begin(), v.end());
- }
-}
-
-TEST(MutableIntSlice, Swap) {
- for (int l1 = 0; l1 < 20; l1++) {
- for (int l2 = 0; l2 < 20; l2++) {
- IntVec avec(l1), bvec(l2);
- MutableIntSlice a(&avec), b(&bvec);
- using std::swap;
- swap(a, b);
- EXPECT_EQ(l1, b.size());
- EXPECT_EQ(l2, a.size());
- for (int i = 0; i < l1; i++) {
- EXPECT_EQ(&avec[i], &b[i]);
- }
- for (int i = 0; i < l2; i++) {
- EXPECT_EQ(&bvec[i], &a[i]);
- }
- }
- }
-}
-
-TEST(MutableIntSlice, ImplicitConversion) {
- for (int len = 0; len < 20; len++) {
- IntVec vec(len);
- MutableIntSlice slice;
- slice = &vec;
- TestImplicitConversion(&vec, vec.data(), len);
- TestImplicitConversion(slice, vec.data(), len);
- TestImplicitConversion(MutableIntSlice(vec.data(), vec.size()), vec.data(),
- len);
- }
-}
-
-TEST(MutableIntSlice, InlinedVectorConversion) {
- for (int len = 0; len < 20; len++) {
- InlinedVector<int, 4> inline_vec;
- for (int i = 0; i < len; i++) {
- inline_vec.push_back(i);
- }
- MutableIntSlice v = &inline_vec; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(&inline_vec, inline_vec.data(), inline_vec.size());
- }
-}
-
-TEST(MutableIntSlice, StaticArrayConversion) {
- int array[20];
- MutableIntSlice v = array; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(array, array, TF_ARRAYSIZE(array));
-}
-
-TEST(MutableIntSlice, StdArrayConversion) {
- std::array<int, 20> array;
-
- // Check assignment.
- {
- MutableIntSlice v = &array;
- static_cast<void>(v);
- }
-
- // Check sub-slice initialization.
- {
- MutableIntSlice v = {&array, 10, 15};
- static_cast<void>(v);
- }
-
- TestImplicitConversion(&array, &array[0], array.size());
-}
-
-TEST(MutableIntSlice, RepeatedFieldConversion) {
- RepeatedField repeated_field;
- Fill(&repeated_field.storage, 20);
- MutableIntSlice v = &repeated_field; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(&repeated_field, repeated_field.storage.data(),
- repeated_field.storage.size());
-}
-
-TEST(MutableIntSlice, ContainerWithOverloadsConversion) {
- ContainerWithOverloads container;
- Fill(&container.storage, 20);
- container.wrong_storage.resize(container.size());
- MutableIntSlice v = &container; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(&container, container.storage.data(),
- container.storage.size());
-}
-
-TEST(MutableIntSlice, ContainerWithShallowConstDataConversion) {
- ContainerWithShallowConstData container;
- Fill(&container.storage, 20);
- MutableIntSlice v = &container; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(&container, container.storage.data(),
- container.storage.size());
-}
-
-TEST(MutableIntSlice, TypedefsAndConstants) {
- ::testing::StaticAssertTypeEq<int, MutableIntSlice::value_type>();
- ::testing::StaticAssertTypeEq<int*, MutableIntSlice::pointer>();
- ::testing::StaticAssertTypeEq<const int*, MutableIntSlice::const_pointer>();
- ::testing::StaticAssertTypeEq<int&, MutableIntSlice::reference>();
- ::testing::StaticAssertTypeEq<const int&, MutableIntSlice::const_reference>();
-
- EXPECT_EQ(static_cast<MutableIntSlice::size_type>(-1), MutableIntSlice::npos);
-}
-
-TEST(MutableIntSlice, IteratorsAndReferences) {
- auto accept_pointer = [](int* x) {};
- auto accept_reference = [](int& x) {};
- auto accept_iterator = [](MutableIntSlice::iterator x) {};
- auto accept_reverse_iterator = [](MutableIntSlice::reverse_iterator x) {};
-
- int a[1];
- MutableIntSlice s = a;
-
- accept_pointer(s.data());
- accept_pointer(s.mutable_data());
- accept_iterator(s.begin());
- accept_iterator(s.end());
- accept_reverse_iterator(s.rbegin());
- accept_reverse_iterator(s.rend());
-
- accept_reference(s[0]);
- accept_reference(s.at(0));
- accept_reference(s.front());
- accept_reference(s.back());
-}
-
-TEST(MutableIntSlice, IteratorsAndReferences_Const) {
- auto accept_pointer = [](int* x) {};
- auto accept_reference = [](int& x) {};
- auto accept_iterator = [](MutableIntSlice::iterator x) {};
- auto accept_reverse_iterator = [](MutableIntSlice::reverse_iterator x) {};
-
- int a[1];
- const MutableIntSlice s = a;
-
- accept_pointer(s.data());
- accept_pointer(s.mutable_data());
- accept_iterator(s.begin());
- accept_iterator(s.end());
- accept_reverse_iterator(s.rbegin());
- accept_reverse_iterator(s.rend());
-
- accept_reference(s[0]);
- accept_reference(s.at(0));
- accept_reference(s.front());
- accept_reference(s.back());
-}
-
-bool TestMutableOverload(MutableIntSlice slice) { return false; }
-
-bool TestMutableOverload(MutableCharSlice slice) { return true; }
-
-TEST(MutableCharSlice, StringConversion) {
- for (int len = 0; len < 20; len++) {
- string str(len, '\0');
- MutableCharSlice v = &str; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(v, str.data(), str.size());
- }
- // Verify that only the correct overload is feasible. Note that this would
- // fail if the string ctor was declared simply as MutableArraySlice(string*),
- // since in that case both overloads would be feasible.
- string str;
- EXPECT_TRUE(TestMutableOverload(&str));
-
- // Avoid warning "unused function 'TestMutableOverload'"
- int a[1];
- EXPECT_FALSE(TestMutableOverload(a));
-}
-
-} // namespace
-} // namespace gtl
-} // namespace tensorflow
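
[Editorial sketch] The deleted tests exercised a few ArraySlice features that absl::Span spells differently; for instance, the (slice, pos, len) constructor becomes subspan, and mutable_data() is subsumed by data(). A sketch:

    std::vector<int> vec = {0, 1, 2, 3, 4};
    absl::Span<const int> all(vec);
    absl::Span<const int> mid = all.subspan(1, 3);  // replaces ArraySlice(all, 1, 3)
    // On absl::Span<int>, data() already returns int*, so mutable_data() is gone.
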
diff --git a/tensorflow/core/lib/gtl/cleanup.h b/tensorflow/core/lib/gtl/cleanup.h
index 6bd60ca482..8c73dc6aa9 100644
--- a/tensorflow/core/lib/gtl/cleanup.h
+++ b/tensorflow/core/lib/gtl/cleanup.h
@@ -39,8 +39,8 @@ limitations under the License.
//
// You can call 'release()' on a Cleanup object to cancel the cleanup.
-#ifndef TENSORFLOW_LIB_GTL_CLEANUP_H_
-#define TENSORFLOW_LIB_GTL_CLEANUP_H_
+#ifndef TENSORFLOW_CORE_LIB_GTL_CLEANUP_H_
+#define TENSORFLOW_CORE_LIB_GTL_CLEANUP_H_
#include <type_traits>
#include <utility>
@@ -110,4 +110,4 @@ TF_MUST_USE_RESULT Cleanup<DecayF> MakeCleanup(F&& f) {
} // namespace gtl
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_GTL_CLEANUP_H_
+#endif // TENSORFLOW_CORE_LIB_GTL_CLEANUP_H_
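
[Editorial sketch] The scope-guard API itself is unchanged by the rename; a usage sketch based on the MakeCleanup signature and the release() comment visible in the hunks above:

    #include <cstdio>
    #include "tensorflow/core/lib/gtl/cleanup.h"

    void ReadSomething(const char* path) {
      FILE* f = std::fopen(path, "r");
      if (f == nullptr) return;
      auto closer = tensorflow::gtl::MakeCleanup([f] { std::fclose(f); });
      // ... early returns are safe: closer runs fclose on scope exit ...
      // closer.release();  // would cancel the cleanup, per the header comment
    }
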
diff --git a/tensorflow/core/lib/gtl/inlined_vector.h b/tensorflow/core/lib/gtl/inlined_vector.h
index 2011f7d4a1..2d622dc229 100644
--- a/tensorflow/core/lib/gtl/inlined_vector.h
+++ b/tensorflow/core/lib/gtl/inlined_vector.h
@@ -13,676 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// An InlinedVector<T,N,A> is like a std::vector<T,A>, except that storage
-// for sequences of length <= N are provided inline without requiring
-// any heap allocation. Typically N is very small (e.g., 4) so that
-// sequences that are expected to be short do not require allocations.
-//
-// Only some of the std::vector<> operations are currently implemented.
-// Other operations may be added as needed to facilitate migrating
-// code that uses std::vector<> to InlinedVector<>.
-//
-// NOTE: If you want an inlined version to replace use of a
-// std::vector<bool>, consider using util::bitmap::InlinedBitVector<NBITS>
-// in util/bitmap/inlined_bitvector.h
-//
-// TODO(billydonahue): change size_t to size_type where appropriate.
+#ifndef TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_
+#define TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_
-#ifndef TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_
-#define TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_
-
-#include <stddef.h>
-#include <stdlib.h>
-#include <string.h>
-#include <sys/types.h>
-#include <algorithm>
-#include <cstddef>
-#include <iterator>
-#include <memory>
-#include <type_traits>
-#include <vector>
-
-#include "tensorflow/core/lib/gtl/manual_constructor.h"
-#include "tensorflow/core/platform/byte_order.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/mem.h"
+#include "absl/container/inlined_vector.h"
+// TODO(kramerb): This is kept only because lots of targets transitively depend
+// on it. Remove all targets' dependencies.
+#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
-#include <initializer_list> // NOLINT(build/include_order)
-
namespace tensorflow {
namespace gtl {
-template <typename T, int N>
-class InlinedVector {
- public:
- typedef T value_type;
- typedef T* pointer;
- typedef const T* const_pointer;
- typedef T& reference;
- typedef const T& const_reference;
- typedef size_t size_type;
- typedef std::ptrdiff_t difference_type;
- typedef pointer iterator;
- typedef const_pointer const_iterator;
-
- // Create an empty vector
- InlinedVector();
-
- // Create a vector with n copies of value_type().
- explicit InlinedVector(size_t n);
-
- // Create a vector with n copies of elem
- InlinedVector(size_t n, const value_type& elem);
-
- // Create and initialize with the elements [range_start .. range_end).
- // The unused enable_if argument restricts this constructor so that it is
- // elided when value_type is an integral type. This prevents ambiguous
- // interpretation between a call to this constructor with two integral
- // arguments and a call to the preceding (n, elem) constructor.
- template <typename InputIterator>
- InlinedVector(
- InputIterator range_start, InputIterator range_end,
- typename std::enable_if<!std::is_integral<InputIterator>::value>::type* =
- NULL) {
- InitRep();
- AppendRange(range_start, range_end);
- }
-
- InlinedVector(std::initializer_list<value_type> init) {
- InitRep();
- AppendRange(init.begin(), init.end());
- }
-
- InlinedVector(const InlinedVector& v);
-
- ~InlinedVector() { clear(); }
-
- InlinedVector& operator=(const InlinedVector& v) {
- // Optimized to avoid reallocation.
- // Prefer reassignment to copy construction for elements.
- const size_t s = size();
- const size_t vs = v.size();
- if (s < vs) { // grow
- reserve(vs);
- if (s) std::copy(v.begin(), v.begin() + s, begin());
- std::copy(v.begin() + s, v.end(), std::back_inserter(*this));
- } else { // maybe shrink
- erase(begin() + vs, end());
- std::copy(v.begin(), v.end(), begin());
- }
- return *this;
- }
-
- size_t size() const { return size_internal(); }
-
- bool empty() const { return (size() == 0); }
-
- // Return the number of elements that can be stored in the vector
- // without requiring a reallocation of the underlying memory.
- size_t capacity() const {
- if (is_inline()) {
- return kFit;
- } else {
- return static_cast<size_t>(1) << u_.data[kSize - 2];
- }
- }
-
- // Return a pointer to the underlying array.
- // Only result[0,size()-1] are defined.
- pointer data() {
- if (is_inline()) {
- return reinterpret_cast<T*>(u_.data);
- } else {
- return outofline_pointer();
- }
- }
- const_pointer data() const {
- return const_cast<InlinedVector<T, N>*>(this)->data();
- }
-
- // Remove all elements
- void clear() {
- DiscardStorage();
- u_.data[kSize - 1] = 0;
- }
-
- // Return the ith element
- // REQUIRES: 0 <= i < size()
- const value_type& at(size_t i) const {
- DCHECK_LT(i, size());
- return data()[i];
- }
- const value_type& operator[](size_t i) const {
- DCHECK_LT(i, size());
- return data()[i];
- }
-
- // Return a non-const reference to the ith element
- // REQUIRES: 0 <= i < size()
- value_type& at(size_t i) {
- DCHECK_LT(i, size());
- return data()[i];
- }
- value_type& operator[](size_t i) {
- DCHECK_LT(i, size());
- return data()[i];
- }
-
- value_type& back() {
- DCHECK(!empty());
- return at(size() - 1);
- }
-
- const value_type& back() const {
- DCHECK(!empty());
- return at(size() - 1);
- }
-
- value_type& front() {
- DCHECK(!empty());
- return at(0);
- }
-
- const value_type& front() const {
- DCHECK(!empty());
- return at(0);
- }
-
- // Append a T constructed with args to the vector.
- // Increases size() by one.
- // Amortized complexity: O(1)
- // Worst-case complexity: O(size())
- template <typename... Args>
- void emplace_back(Args&&... args) {
- size_t s = size();
- DCHECK_LE(s, capacity());
- if (s < capacity()) {
- new (data() + s) T(std::forward<Args>(args)...);
- set_size_internal(s + 1);
- } else {
- EmplaceBackSlow(std::forward<Args>(args)...);
- }
- }
-
- // Append t to the vector.
- // Increases size() by one.
- // Amortized complexity: O(1)
- // Worst-case complexity: O(size())
- void push_back(const value_type& t) { emplace_back(t); }
- void push_back(value_type&& t) { emplace_back(std::move(t)); }
-
- inline void pop_back() {
- DCHECK(!empty());
- const size_t s = size();
- Destroy(data() + s - 1, 1);
- set_size_internal(s - 1);
- }
-
- // Resizes the vector to contain "n" elements.
- // If "n" is smaller than the initial size, extra elements are destroyed.
- // If "n" is larger than the initial size, enough copies of "elem"
- // are appended to increase the size to "n". If "elem" is omitted,
- // new elements are value-initialized.
- void resize(size_t n) { Resize<ValueInit>(n, nullptr); }
- void resize(size_t n, const value_type& elem) { Resize<Fill>(n, &elem); }
-
- iterator begin() { return data(); }
- const_iterator begin() const { return data(); }
-
- iterator end() { return data() + size(); }
- const_iterator end() const { return data() + size(); }
-
- iterator insert(iterator pos, const value_type& v);
-
- iterator erase(iterator pos) {
- DCHECK_LT(pos, end());
- DCHECK_GE(pos, begin());
- std::copy(pos + 1, end(), pos);
- pop_back();
- return pos;
- }
-
- iterator erase(iterator first, iterator last);
-
- // Enlarges the underlying representation so it can hold at least
- // "n" elements without reallocation.
- // Does not change size() or the actual contents of the vector.
- void reserve(size_t n) {
- if (n > capacity()) {
- // Make room for new elements
- Grow<Move>(n);
- }
- }
-
- // Swap the contents of *this with other.
- // REQUIRES: value_type is swappable and copyable.
- void swap(InlinedVector& other);
-
- private:
- // Representation can either be inlined or out-of-line.
- // In either case, at least sizeof(void*) + 8 bytes are available.
- //
- // Inlined:
- // Last byte holds the length.
- // First (length*sizeof(T)) bytes stores the elements.
- // Outlined:
- // Last byte holds kSentinel.
- // Second-last byte holds lg(capacity)
- // Preceding 6 bytes hold size.
- // First sizeof(T*) bytes hold pointer.
-
- // Compute rep size.
- static const size_t kSizeUnaligned = N * sizeof(T) + 1; // Room for tag
- static const size_t kSize = ((kSizeUnaligned + 15) / 16) * 16; // Align
-
- // See how many T we can fit inside kSize, but no more than 254
- // since 255 is used as sentinel tag for out-of-line allocation.
- static const unsigned int kSentinel = 255;
- static const size_t kFit1 = (kSize - 1) / sizeof(T);
- static const size_t kFit = (kFit1 >= kSentinel) ? (kSentinel - 1) : kFit1;
-
- union {
- unsigned char data[kSize];
- // Force data to be aligned enough for a pointer.
- T* unused_aligner;
- } u_;
-
- inline void InitRep() { u_.data[kSize - 1] = 0; }
- inline bool is_inline() const { return u_.data[kSize - 1] != kSentinel; }
-
- inline T* outofline_pointer() const {
- T* ptr;
- memcpy(&ptr, &u_.data[0], sizeof(ptr));
- return ptr;
- }
-
- inline void set_outofline_pointer(T* p) {
- memcpy(&u_.data[0], &p, sizeof(p));
- }
-
- inline uint64_t outofline_word() const {
- uint64_t word;
- memcpy(&word, &u_.data[kSize - 8], sizeof(word));
- return word;
- }
-
- inline void set_outofline_word(uint64_t w) {
- memcpy(&u_.data[kSize - 8], &w, sizeof(w));
- }
-
- inline size_t size_internal() const {
- uint8_t s = static_cast<uint8_t>(u_.data[kSize - 1]);
- if (s != kSentinel) {
- return static_cast<size_t>(s);
- } else {
- const uint64_t word = outofline_word();
- if (port::kLittleEndian) {
- // The sentinel and capacity bits are most-significant bits in word.
- return static_cast<size_t>(word & 0xffffffffffffull);
- } else {
- // The sentinel and capacity bits are least-significant bits in word.
- return static_cast<size_t>(word >> 16);
- }
- }
- }
-
- void set_size_internal(size_t n) {
- if (is_inline()) {
- DCHECK_LT(n, kSentinel);
- u_.data[kSize - 1] = static_cast<unsigned char>(n);
- } else {
- uint64_t word;
- if (port::kLittleEndian) {
- // The sentinel and capacity bits are most-significant bits in word.
- word = (static_cast<uint64_t>(n) |
- (static_cast<uint64_t>(u_.data[kSize - 2]) << 48) |
- (static_cast<uint64_t>(kSentinel) << 56));
- } else {
- // The sentinel and capacity bits are least-significant bits in word.
- word = ((static_cast<uint64_t>(n) << 16) |
- (static_cast<uint64_t>(u_.data[kSize - 2]) << 8) |
- (static_cast<uint64_t>(kSentinel)));
- }
- set_outofline_word(word);
- DCHECK_EQ(u_.data[kSize - 1], kSentinel) << n;
- }
- }
-
- void DiscardStorage() {
- T* base = data();
- size_t n = size();
- Destroy(base, n);
- if (!is_inline()) {
- port::Free(base);
- }
- }
-
- template <typename... Args>
- void EmplaceBackSlow(Args&&... args) {
- const size_t s = size();
- DCHECK_EQ(s, capacity());
- Grow<Move, Construct>(s + 1, std::forward<Args>(args)...);
- set_size_internal(s + 1);
- }
-
- // Movers for Grow
- // Does nothing.
- static void Nop(T* src, size_t n, T* dst) {}
-
- // Moves srcs[0,n-1] contents to dst[0,n-1].
- static void Move(T* src, size_t n, T* dst) {
- for (size_t i = 0; i < n; i++) {
- new (dst + i) T(std::move(*(src + i)));
- }
- }
-
- // Initializers for Resize.
- // Initializes dst[0,n-1] with empty constructor.
- static void ValueInit(const T*, size_t n, T* dst) {
- for (size_t i = 0; i < n; i++) {
- new (dst + i) T();
- }
- }
-
- // Initializes dst[0,n-1] with copies of *src.
- static void Fill(const T* src, size_t n, T* dst) {
- for (size_t i = 0; i < n; i++) {
- new (dst + i) T(*src);
- }
- }
-
- void Destroy(T* src, int n) {
- if (!std::is_trivially_destructible<T>::value) {
- for (int i = 0; i < n; i++) {
- (src + i)->~T();
- }
- }
- }
-
- // Initialization methods for Grow.
- // 1) Leave uninitialized memory.
- struct Uninitialized {
- void operator()(T*) const {}
- };
- // 2) Construct a T with args at not-yet-initialized memory pointed by dst.
- struct Construct {
- template <class... Args>
- void operator()(T* dst, Args&&... args) const {
- new (dst) T(std::forward<Args>(args)...);
- }
- };
-
- // Grow so that capacity >= n. Uses Mover to move existing elements
- // to the new buffer, and possibly initializes the new element according
- // to InitType.
- // We pass the InitType and Mover as template arguments so that
- // this code compiles even if T does not support copying or default
- // construction.
- template <void(Mover)(T*, size_t, T*), class InitType = Uninitialized,
- class... Args>
- void Grow(size_t n, Args&&... args) {
- size_t s = size();
- DCHECK_LE(s, capacity());
-
- // Compute new capacity by repeatedly doubling current capacity
- size_t target = 1;
- size_t target_lg = 0;
- while (target < kFit || target < n) {
- // TODO(psrc): Check and avoid overflow?
- target_lg++;
- target <<= 1;
- }
-
- T* src = data();
- T* dst = static_cast<T*>(port::Malloc(target * sizeof(T)));
-
- // Need to copy elem before discarding src since it might alias src.
- InitType{}(dst + s, std::forward<Args>(args)...);
- Mover(src, s, dst);
- DiscardStorage();
-
- u_.data[kSize - 1] = kSentinel;
- u_.data[kSize - 2] = static_cast<unsigned char>(target_lg);
- set_size_internal(s);
- DCHECK_EQ(capacity(), target);
- set_outofline_pointer(dst);
- }
-
- // Resize to size n. Any new elements are initialized by passing
- // elem and the destination to Initializer. We pass the Initializer
- // as a template argument so that this code compiles even if T does
- // not support copying.
- template <void(Initializer)(const T*, size_t, T*)>
- void Resize(size_t n, const T* elem) {
- size_t s = size();
- if (n <= s) {
- Destroy(data() + n, s - n);
- set_size_internal(n);
- return;
- }
- reserve(n);
- DCHECK_GE(capacity(), n);
- set_size_internal(n);
- Initializer(elem, n - s, data() + s);
- }
-
- template <typename Iter>
- void AppendRange(Iter first, Iter last, std::input_iterator_tag);
-
- // Faster path for forward iterators.
- template <typename Iter>
- void AppendRange(Iter first, Iter last, std::forward_iterator_tag);
-
- template <typename Iter>
- void AppendRange(Iter first, Iter last);
-};
-
-// Provide linkage for constants.
-template <typename T, int N>
-const size_t InlinedVector<T, N>::kSizeUnaligned;
-template <typename T, int N>
-const size_t InlinedVector<T, N>::kSize;
-template <typename T, int N>
-const unsigned int InlinedVector<T, N>::kSentinel;
-template <typename T, int N>
-const size_t InlinedVector<T, N>::kFit1;
-template <typename T, int N>
-const size_t InlinedVector<T, N>::kFit;
-
-template <typename T, int N>
-inline void swap(InlinedVector<T, N>& a, InlinedVector<T, N>& b) {
- a.swap(b);
-}
-
-template <typename T, int N>
-inline bool operator==(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return a.size() == b.size() && std::equal(a.begin(), a.end(), b.begin());
-}
-
-template <typename T, int N>
-inline bool operator!=(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return !(a == b);
-}
-
-template <typename T, int N>
-inline bool operator<(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end());
-}
-
-template <typename T, int N>
-inline bool operator>(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return b < a;
-}
-
-template <typename T, int N>
-inline bool operator<=(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return !(b < a);
-}
-
-template <typename T, int N>
-inline bool operator>=(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return !(a < b);
-}
-
-// ========================================
-// Implementation
-
-template <typename T, int N>
-inline InlinedVector<T, N>::InlinedVector() {
- InitRep();
-}
-
-template <typename T, int N>
-inline InlinedVector<T, N>::InlinedVector(size_t n) {
- InitRep();
- if (n > capacity()) {
- Grow<Nop>(n); // Must use Nop in case T is not copyable
- }
- set_size_internal(n);
- ValueInit(nullptr, n, data());
-}
-
-template <typename T, int N>
-inline InlinedVector<T, N>::InlinedVector(size_t n, const value_type& elem) {
- InitRep();
- if (n > capacity()) {
- Grow<Nop>(n); // Can use Nop since we know we have nothing to copy
- }
- set_size_internal(n);
- Fill(&elem, n, data());
-}
-
-template <typename T, int N>
-inline InlinedVector<T, N>::InlinedVector(const InlinedVector& v) {
- InitRep();
- *this = v;
-}
-
-template <typename T, int N>
-typename InlinedVector<T, N>::iterator InlinedVector<T, N>::insert(
- iterator pos, const value_type& v) {
- DCHECK_GE(pos, begin());
- DCHECK_LE(pos, end());
- if (pos == end()) {
- push_back(v);
- return end() - 1;
- }
- size_t s = size();
- size_t idx = std::distance(begin(), pos);
- if (s == capacity()) {
- Grow<Move>(s + 1);
- }
- CHECK_LT(s, capacity());
- pos = begin() + idx; // Reset 'pos' into a post-enlarge iterator.
- Fill(data() + s - 1, 1, data() + s); // data[s] = data[s-1]
- std::copy_backward(pos, data() + s - 1, data() + s);
- *pos = v;
-
- set_size_internal(s + 1);
- return pos;
-}
-
-template <typename T, int N>
-typename InlinedVector<T, N>::iterator InlinedVector<T, N>::erase(
- iterator first, iterator last) {
- DCHECK_LE(begin(), first);
- DCHECK_LE(first, last);
- DCHECK_LE(last, end());
-
- size_t s = size();
- ptrdiff_t erase_gap = std::distance(first, last);
- std::copy(last, data() + s, first);
- Destroy(data() + s - erase_gap, erase_gap);
- set_size_internal(s - erase_gap);
- return first;
-}
-
-template <typename T, int N>
-void InlinedVector<T, N>::swap(InlinedVector& other) {
- using std::swap; // Augment ADL with std::swap.
- if (&other == this) {
- return;
- }
-
- InlinedVector* a = this;
- InlinedVector* b = &other;
-
- const bool a_inline = a->is_inline();
- const bool b_inline = b->is_inline();
-
- if (!a_inline && !b_inline) {
- // Just swap the top-level representations.
- T* aptr = a->outofline_pointer();
- T* bptr = b->outofline_pointer();
- a->set_outofline_pointer(bptr);
- b->set_outofline_pointer(aptr);
-
- uint64_t aword = a->outofline_word();
- uint64_t bword = b->outofline_word();
- a->set_outofline_word(bword);
- b->set_outofline_word(aword);
- return;
- }
-
- // Make a the larger of the two to reduce the number of cases.
- size_t a_size = a->size();
- size_t b_size = b->size();
- if (a->size() < b->size()) {
- swap(a, b);
- swap(a_size, b_size);
- }
- DCHECK_GE(a_size, b_size);
-
- if (b->capacity() < a_size) {
- b->Grow<Move>(a_size);
- }
-
- // One is inline and one is not.
- // 'a' is larger. Swap the elements up to the smaller array size.
- std::swap_ranges(a->data(), a->data() + b_size, b->data());
- std::uninitialized_copy(a->data() + b_size, a->data() + a_size,
- b->data() + b_size);
- Destroy(a->data() + b_size, a_size - b_size);
- a->set_size_internal(b_size);
- b->set_size_internal(a_size);
- DCHECK_EQ(b->size(), a_size);
- DCHECK_EQ(a->size(), b_size);
-}
-
-template <typename T, int N>
-template <typename Iter>
-inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last,
- std::input_iterator_tag) {
- std::copy(first, last, std::back_inserter(*this));
-}
-
-template <typename T, int N>
-template <typename Iter>
-inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last,
- std::forward_iterator_tag) {
- typedef typename std::iterator_traits<Iter>::difference_type Length;
- Length length = std::distance(first, last);
- size_t s = size();
- reserve(s + length);
- std::uninitialized_copy_n(first, length, data() + s);
- set_size_internal(s + length);
-}
-
-template <typename T, int N>
-template <typename Iter>
-inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last) {
- typedef typename std::iterator_traits<Iter>::iterator_category IterTag;
- AppendRange(first, last, IterTag());
-}
+using absl::InlinedVector;
} // namespace gtl
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_
+#endif // TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_
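With this change tensorflow::gtl::InlinedVector is only an alias for absl::InlinedVector, so existing call sites should compile unchanged. A small sketch (illustrative function name) of what the alias preserves:

#include "tensorflow/core/lib/gtl/inlined_vector.h"

// Sketch: the first four ints fit in the inline storage; the fifth may
// force a heap allocation, as with the deleted implementation.
int DemoInlinedVectorAlias() {
  tensorflow::gtl::InlinedVector<int, 4> v = {1, 2, 3, 4};
  v.push_back(5);
  return static_cast<int>(v.size());  // 5
}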
diff --git a/tensorflow/core/lib/gtl/inlined_vector_test.cc b/tensorflow/core/lib/gtl/inlined_vector_test.cc
deleted file mode 100644
index 2721885c4a..0000000000
--- a/tensorflow/core/lib/gtl/inlined_vector_test.cc
+++ /dev/null
@@ -1,898 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
-
-#include <list>
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/test_benchmark.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-
-typedef tensorflow::gtl::InlinedVector<int, 8> IntVec;
-
-// A type that counts the number of live occurrences of the type
-static int64 instances = 0;
-class Instance {
- public:
- int value_;
- explicit Instance(int x) : value_(x) { instances++; }
- Instance(const Instance& x) : value_(x.value_) { instances++; }
- ~Instance() { instances--; }
-
- friend inline void swap(Instance& a, Instance& b) {
- using std::swap;
- swap(a.value_, b.value_);
- }
-
- friend std::ostream& operator<<(std::ostream& o, const Instance& v) {
- return o << "[value:" << v.value_ << "]";
- }
-};
-
-typedef tensorflow::gtl::InlinedVector<Instance, 8> InstanceVec;
-
-// A simple reference-counted class to make sure that the proper elements are
-// destroyed in the erase(begin, end) test.
-class RefCounted {
- public:
- RefCounted(int value, int* count) : value_(value), count_(count) { Ref(); }
-
- RefCounted(const RefCounted& v) : value_(v.value_), count_(v.count_) {
- VLOG(5) << "[RefCounted: copy"
- << " from count @" << v.count_ << "]";
- Ref();
- }
-
- ~RefCounted() {
- Unref();
- count_ = nullptr;
- }
-
- friend void swap(RefCounted& a, RefCounted& b) {
- using std::swap;
- swap(a.value_, b.value_);
- swap(a.count_, b.count_);
- }
-
- RefCounted& operator=(RefCounted v) {
- using std::swap;
- swap(*this, v);
- return *this;
- }
-
- void Ref() const {
- CHECK(count_ != nullptr);
- ++(*count_);
- VLOG(5) << "[Ref: refcount " << *count_ << " on count @" << count_ << "]";
- }
-
- void Unref() const {
- --(*count_);
- CHECK_GE(*count_, 0);
- VLOG(5) << "[Unref: refcount " << *count_ << " on count @" << count_ << "]";
- }
-
- int count() const { return *count_; }
-
- friend std::ostream& operator<<(std::ostream& o, const RefCounted& v) {
- return o << "[value:" << v.value_ << ", count:" << *v.count_ << "]";
- }
-
- int value_;
- int* count_;
-};
-
-typedef tensorflow::gtl::InlinedVector<RefCounted, 8> RefCountedVec;
-
-// A class with a vtable pointer
-class Dynamic {
- public:
- virtual ~Dynamic() {}
-
- friend std::ostream& operator<<(std::ostream& o, const Dynamic& v) {
- return o << "[Dynamic]";
- }
-};
-
-typedef tensorflow::gtl::InlinedVector<Dynamic, 8> DynamicVec;
-
-// Append 0..len-1 to *v
-static void Fill(IntVec* v, int len, int offset = 0) {
- for (int i = 0; i < len; i++) {
- v->push_back(i + offset);
- }
-}
-
-static IntVec Fill(int len, int offset = 0) {
- IntVec v;
- Fill(&v, len, offset);
- return v;
-}
-
-TEST(IntVec, SimpleOps) {
- for (int len = 0; len < 20; len++) {
- IntVec v;
- const IntVec& cv = v; // const alias
-
- Fill(&v, len);
- EXPECT_EQ(len, v.size());
- EXPECT_LE(len, v.capacity());
-
- for (int i = 0; i < len; i++) {
- EXPECT_EQ(i, v[i]);
- }
- EXPECT_EQ(v.begin(), v.data());
- EXPECT_EQ(cv.begin(), cv.data());
-
- int counter = 0;
- for (IntVec::iterator iter = v.begin(); iter != v.end(); ++iter) {
- EXPECT_EQ(counter, *iter);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- counter = 0;
- for (IntVec::const_iterator iter = v.begin(); iter != v.end(); ++iter) {
- EXPECT_EQ(counter, *iter);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- if (len > 0) {
- EXPECT_EQ(0, v.front());
- EXPECT_EQ(len - 1, v.back());
- v.pop_back();
- EXPECT_EQ(len - 1, v.size());
- for (size_t i = 0; i < v.size(); ++i) {
- EXPECT_EQ(i, v[i]);
- }
- }
- }
-}
-
-TEST(IntVec, Erase) {
- for (int len = 1; len < 20; len++) {
- for (int i = 0; i < len; ++i) {
- IntVec v;
- Fill(&v, len);
- v.erase(v.begin() + i);
- EXPECT_EQ(len - 1, v.size());
- for (int j = 0; j < i; ++j) {
- EXPECT_EQ(j, v[j]);
- }
- for (int j = i; j < len - 1; ++j) {
- EXPECT_EQ(j + 1, v[j]);
- }
- }
- }
-}
-
-// At the end of this test loop, the elements between [erase_begin, erase_end)
-// should have reference counts == 0, and all other elements should have
-// reference counts == 1.
-TEST(RefCountedVec, EraseBeginEnd) {
- for (int len = 1; len < 20; ++len) {
- for (int erase_begin = 0; erase_begin < len; ++erase_begin) {
- for (int erase_end = erase_begin; erase_end <= len; ++erase_end) {
- std::vector<int> counts(len, 0);
- RefCountedVec v;
- for (int i = 0; i < len; ++i) {
- v.push_back(RefCounted(i, &counts[i]));
- }
-
- int erase_len = erase_end - erase_begin;
-
- v.erase(v.begin() + erase_begin, v.begin() + erase_end);
-
- EXPECT_EQ(len - erase_len, v.size());
-
- // Check the elements before the first element erased.
- for (int i = 0; i < erase_begin; ++i) {
- EXPECT_EQ(i, v[i].value_);
- }
-
- // Check the elements after the first element erased.
- for (size_t i = erase_begin; i < v.size(); ++i) {
- EXPECT_EQ(i + erase_len, v[i].value_);
- }
-
- // Check that the elements at the beginning are preserved.
- for (int i = 0; i < erase_begin; ++i) {
- EXPECT_EQ(1, counts[i]);
- }
-
- // Check that the erased elements are destroyed
- for (int i = erase_begin; i < erase_end; ++i) {
- EXPECT_EQ(0, counts[i]);
- }
-
- // Check that the elements at the end are preserved.
- for (int i = erase_end; i < len; ++i) {
- EXPECT_EQ(1, counts[i]);
- }
- }
- }
- }
-}
-
-struct NoDefaultCtor {
- explicit NoDefaultCtor(int) {}
-};
-struct NoCopy {
- NoCopy() {}
- NoCopy(const NoCopy&) = delete;
-};
-struct NoAssign {
- NoAssign() {}
- NoAssign& operator=(const NoAssign&) = delete;
-};
-struct MoveOnly {
- MoveOnly() {}
- MoveOnly(MoveOnly&&) = default;
- MoveOnly& operator=(MoveOnly&&) = default;
-};
-TEST(InlinedVectorTest, NoDefaultCtor) {
- tensorflow::gtl::InlinedVector<NoDefaultCtor, 1> v(10, NoDefaultCtor(2));
- (void)v;
-}
-TEST(InlinedVectorTest, NoCopy) {
- tensorflow::gtl::InlinedVector<NoCopy, 1> v(10);
- (void)v;
-}
-TEST(InlinedVectorTest, NoAssign) {
- tensorflow::gtl::InlinedVector<NoAssign, 1> v(10);
- (void)v;
-}
-TEST(InlinedVectorTest, MoveOnly) {
- gtl::InlinedVector<MoveOnly, 2> v;
- v.push_back(MoveOnly{});
- v.push_back(MoveOnly{});
- v.push_back(MoveOnly{});
-}
-
-TEST(IntVec, Insert) {
- for (int len = 0; len < 20; len++) {
- for (int pos = 0; pos <= len; pos++) {
- IntVec v;
- Fill(&v, len);
- v.insert(v.begin() + pos, 9999);
- EXPECT_EQ(v.size(), len + 1);
- for (int i = 0; i < pos; i++) {
- EXPECT_EQ(v[i], i);
- }
- EXPECT_EQ(v[pos], 9999);
- for (size_t i = pos + 1; i < v.size(); i++) {
- EXPECT_EQ(v[i], i - 1);
- }
- }
- }
-}
-
-TEST(RefCountedVec, InsertConstructorDestructor) {
- // Make sure the proper construction/destruction happen during insert
- // operations.
- for (int len = 0; len < 20; len++) {
- SCOPED_TRACE(len);
- for (int pos = 0; pos <= len; pos++) {
- SCOPED_TRACE(pos);
- std::vector<int> counts(len, 0);
- int inserted_count = 0;
- RefCountedVec v;
- for (int i = 0; i < len; ++i) {
- SCOPED_TRACE(i);
- v.push_back(RefCounted(i, &counts[i]));
- }
-
- for (auto elem : counts) {
- EXPECT_EQ(1, elem);
- }
-
- RefCounted insert_element(9999, &inserted_count);
- EXPECT_EQ(1, inserted_count);
- v.insert(v.begin() + pos, insert_element);
- EXPECT_EQ(2, inserted_count);
- // Check that the elements at the end are preserved.
- for (auto elem : counts) {
- EXPECT_EQ(1, elem);
- }
- EXPECT_EQ(2, inserted_count);
- }
- }
-}
-
-TEST(IntVec, Resize) {
- for (int len = 0; len < 20; len++) {
- IntVec v;
- Fill(&v, len);
-
- // Try resizing up and down by k elements
- static const int kResizeElem = 1000000;
- for (int k = 0; k < 10; k++) {
- // Enlarging resize
- v.resize(len + k, kResizeElem);
- EXPECT_EQ(len + k, v.size());
- EXPECT_LE(len + k, v.capacity());
- for (int i = 0; i < len + k; i++) {
- if (i < len) {
- EXPECT_EQ(i, v[i]);
- } else {
- EXPECT_EQ(kResizeElem, v[i]);
- }
- }
-
- // Shrinking resize
- v.resize(len, kResizeElem);
- EXPECT_EQ(len, v.size());
- EXPECT_LE(len, v.capacity());
- for (int i = 0; i < len; i++) {
- EXPECT_EQ(i, v[i]);
- }
- }
- }
-}
-
-TEST(IntVec, InitWithLength) {
- for (int len = 0; len < 20; len++) {
- IntVec v(len, 7);
- EXPECT_EQ(len, v.size());
- EXPECT_LE(len, v.capacity());
- for (int i = 0; i < len; i++) {
- EXPECT_EQ(7, v[i]);
- }
- }
-}
-
-TEST(IntVec, CopyConstructorAndAssignment) {
- for (int len = 0; len < 20; len++) {
- IntVec v;
- Fill(&v, len);
- EXPECT_EQ(len, v.size());
- EXPECT_LE(len, v.capacity());
-
- IntVec v2(v);
- EXPECT_EQ(v, v2);
-
- for (int start_len = 0; start_len < 20; start_len++) {
- IntVec v3;
- Fill(&v3, start_len, 99); // Add dummy elements that should go away
- v3 = v;
- EXPECT_EQ(v, v3);
- }
- }
-}
-
-TEST(OverheadTest, Storage) {
- // Check for size overhead.
- using tensorflow::gtl::InlinedVector;
- EXPECT_EQ(2 * sizeof(int*), sizeof(InlinedVector<int*, 1>));
- EXPECT_EQ(4 * sizeof(int*), sizeof(InlinedVector<int*, 2>));
- EXPECT_EQ(4 * sizeof(int*), sizeof(InlinedVector<int*, 3>));
- EXPECT_EQ(6 * sizeof(int*), sizeof(InlinedVector<int*, 4>));
-
- EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 1>));
- EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 2>));
- EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 3>));
- EXPECT_EQ(2 * sizeof(char*),
- sizeof(InlinedVector<char, 2 * sizeof(char*) - 1>));
- EXPECT_EQ(4 * sizeof(char*), sizeof(InlinedVector<char, 2 * sizeof(char*)>));
-}
-
-TEST(IntVec, Clear) {
- for (int len = 0; len < 20; len++) {
- SCOPED_TRACE(len);
- IntVec v;
- Fill(&v, len);
- v.clear();
- EXPECT_EQ(0, v.size());
- EXPECT_EQ(v.begin(), v.end());
- }
-}
-
-TEST(IntVec, Reserve) {
- for (size_t len = 0; len < 20; len++) {
- IntVec v;
- Fill(&v, len);
-
- for (size_t newlen = 0; newlen < 100; newlen++) {
- const int* start_rep = v.data();
- v.reserve(newlen);
- const int* final_rep = v.data();
- if (newlen <= len) {
- EXPECT_EQ(start_rep, final_rep);
- }
- EXPECT_LE(newlen, v.capacity());
-
- // Filling up to newlen should not change rep
- while (v.size() < newlen) {
- v.push_back(0);
- }
- EXPECT_EQ(final_rep, v.data());
- }
- }
-}
-
-template <typename T>
-static std::vector<typename T::value_type> Vec(const T& src) {
- std::vector<typename T::value_type> result;
- for (const auto& elem : src) {
- result.push_back(elem);
- }
- return result;
-}
-
-TEST(IntVec, SelfRefPushBack) {
- std::vector<string> std_v;
- tensorflow::gtl::InlinedVector<string, 4> v;
- const string s = "A quite long string to ensure heap.";
- std_v.push_back(s);
- v.push_back(s);
- for (int i = 0; i < 20; ++i) {
- EXPECT_EQ(std_v, Vec(v));
-
- v.push_back(v.back());
- std_v.push_back(std_v.back());
- }
- EXPECT_EQ(std_v, Vec(v));
-}
-
-TEST(IntVec, SelfRefPushBackWithMove) {
- std::vector<string> std_v;
- gtl::InlinedVector<string, 4> v;
- const string s = "A quite long string to ensure heap.";
- std_v.push_back(s);
- v.push_back(s);
- for (int i = 0; i < 20; ++i) {
- EXPECT_EQ(v.back(), std_v.back());
-
- v.push_back(std::move(v.back()));
- std_v.push_back(std::move(std_v.back()));
- }
- EXPECT_EQ(v.back(), std_v.back());
-}
-
-TEST(IntVec, Swap) {
- for (int l1 = 0; l1 < 20; l1++) {
- SCOPED_TRACE(l1);
- for (int l2 = 0; l2 < 20; l2++) {
- SCOPED_TRACE(l2);
- IntVec a = Fill(l1, 0);
- IntVec b = Fill(l2, 100);
- {
- using std::swap;
- swap(a, b);
- }
- EXPECT_EQ(l1, b.size());
- EXPECT_EQ(l2, a.size());
- for (int i = 0; i < l1; i++) {
- SCOPED_TRACE(i);
- EXPECT_EQ(i, b[i]);
- }
- for (int i = 0; i < l2; i++) {
- SCOPED_TRACE(i);
- EXPECT_EQ(100 + i, a[i]);
- }
- }
- }
-}
-
-TEST(InstanceVec, Swap) {
- for (int l1 = 0; l1 < 20; l1++) {
- for (int l2 = 0; l2 < 20; l2++) {
- InstanceVec a, b;
- for (int i = 0; i < l1; i++) a.push_back(Instance(i));
- for (int i = 0; i < l2; i++) b.push_back(Instance(100 + i));
- EXPECT_EQ(l1 + l2, instances);
- {
- using std::swap;
- swap(a, b);
- }
- EXPECT_EQ(l1 + l2, instances);
- EXPECT_EQ(l1, b.size());
- EXPECT_EQ(l2, a.size());
- for (int i = 0; i < l1; i++) {
- EXPECT_EQ(i, b[i].value_);
- }
- for (int i = 0; i < l2; i++) {
- EXPECT_EQ(100 + i, a[i].value_);
- }
- }
- }
-}
-
-TEST(IntVec, EqualAndNotEqual) {
- IntVec a, b;
- EXPECT_TRUE(a == b);
- EXPECT_FALSE(a != b);
-
- a.push_back(3);
- EXPECT_FALSE(a == b);
- EXPECT_TRUE(a != b);
-
- b.push_back(3);
- EXPECT_TRUE(a == b);
- EXPECT_FALSE(a != b);
-
- b.push_back(7);
- EXPECT_FALSE(a == b);
- EXPECT_TRUE(a != b);
-
- a.push_back(6);
- EXPECT_FALSE(a == b);
- EXPECT_TRUE(a != b);
-
- a.clear();
- b.clear();
- for (int i = 0; i < 100; i++) {
- a.push_back(i);
- b.push_back(i);
- EXPECT_TRUE(a == b);
- EXPECT_FALSE(a != b);
-
- b[i] = b[i] + 1;
- EXPECT_FALSE(a == b);
- EXPECT_TRUE(a != b);
-
- b[i] = b[i] - 1; // Back to before
- EXPECT_TRUE(a == b);
- EXPECT_FALSE(a != b);
- }
-}
-
-TEST(IntVec, RelationalOps) {
- IntVec a, b;
- EXPECT_FALSE(a < b);
- EXPECT_FALSE(b < a);
- EXPECT_FALSE(a > b);
- EXPECT_FALSE(b > a);
- EXPECT_TRUE(a <= b);
- EXPECT_TRUE(b <= a);
- EXPECT_TRUE(a >= b);
- EXPECT_TRUE(b >= a);
- b.push_back(3);
- EXPECT_TRUE(a < b);
- EXPECT_FALSE(b < a);
- EXPECT_FALSE(a > b);
- EXPECT_TRUE(b > a);
- EXPECT_TRUE(a <= b);
- EXPECT_FALSE(b <= a);
- EXPECT_FALSE(a >= b);
- EXPECT_TRUE(b >= a);
-}
-
-TEST(InstanceVec, CountConstructorsDestructors) {
- const int start = instances;
- for (int len = 0; len < 20; len++) {
- InstanceVec v;
- for (int i = 0; i < len; i++) {
- v.push_back(Instance(i));
- }
- EXPECT_EQ(start + len, instances);
-
- { // Copy constructor should create 'len' more instances.
- InstanceVec v_copy(v);
- EXPECT_EQ(start + len + len, instances);
- }
- EXPECT_EQ(start + len, instances);
-
- // Enlarging resize() must construct some objects
- v.resize(len + 10, Instance(100));
- EXPECT_EQ(start + len + 10, instances);
-
- // Shrinking resize() must destroy some objects
- v.resize(len, Instance(100));
- EXPECT_EQ(start + len, instances);
-
- // reserve() must not increase the number of initialized objects
- v.reserve(len + 1000);
- EXPECT_EQ(start + len, instances);
-
- // pop_back() and erase() must destroy one object
- if (len > 0) {
- v.pop_back();
- EXPECT_EQ(start + len - 1, instances);
- if (!v.empty()) {
- v.erase(v.begin());
- EXPECT_EQ(start + len - 2, instances);
- }
- }
- }
- EXPECT_EQ(start, instances);
-}
-
-TEST(InstanceVec, CountConstructorsDestructorsOnAssignment) {
- const int start = instances;
- for (int len = 0; len < 20; len++) {
- for (int longorshort = 0; longorshort <= 1; ++longorshort) {
- InstanceVec longer, shorter;
- for (int i = 0; i < len; i++) {
- longer.push_back(Instance(i));
- shorter.push_back(Instance(i));
- }
- longer.push_back(Instance(len));
- EXPECT_EQ(start + len + len + 1, instances);
-
- if (longorshort) {
- shorter = longer;
- EXPECT_EQ(start + (len + 1) + (len + 1), instances);
- } else {
- longer = shorter;
- EXPECT_EQ(start + len + len, instances);
- }
- }
- }
- EXPECT_EQ(start, instances);
-}
-
-TEST(RangedConstructor, SimpleType) {
- std::vector<int> source_v = {4, 5, 6, 7};
- // First try to fit in inline backing
- tensorflow::gtl::InlinedVector<int, 4> v(source_v.begin(), source_v.end());
- tensorflow::gtl::InlinedVector<int, 4> empty4;
- EXPECT_EQ(4, v.size());
- EXPECT_EQ(empty4.capacity(), v.capacity()); // Must still be inline
- EXPECT_EQ(4, v[0]);
- EXPECT_EQ(5, v[1]);
- EXPECT_EQ(6, v[2]);
- EXPECT_EQ(7, v[3]);
-
- // Now, force a re-allocate
- tensorflow::gtl::InlinedVector<int, 2> realloc_v(source_v.begin(),
- source_v.end());
- tensorflow::gtl::InlinedVector<int, 2> empty2;
- EXPECT_EQ(4, realloc_v.size());
- EXPECT_LT(empty2.capacity(), realloc_v.capacity());
- EXPECT_EQ(4, realloc_v[0]);
- EXPECT_EQ(5, realloc_v[1]);
- EXPECT_EQ(6, realloc_v[2]);
- EXPECT_EQ(7, realloc_v[3]);
-}
-
-TEST(RangedConstructor, ComplexType) {
- // We also use a list here to pass a different flavor of iterator (e.g. not
- // random-access).
- std::list<Instance> source_v = {Instance(0)};
-
- // First try to fit in inline backing
- tensorflow::gtl::InlinedVector<Instance, 1> v(source_v.begin(),
- source_v.end());
- tensorflow::gtl::InlinedVector<Instance, 1> empty1;
- EXPECT_EQ(1, v.size());
- EXPECT_EQ(empty1.capacity(), v.capacity()); // Must still be inline
- EXPECT_EQ(0, v[0].value_);
-
- std::list<Instance> source_v2 = {Instance(0), Instance(1), Instance(2),
- Instance(3)};
- // Now, force a re-allocate
- tensorflow::gtl::InlinedVector<Instance, 1> realloc_v(source_v2.begin(),
- source_v2.end());
- EXPECT_EQ(4, realloc_v.size());
- EXPECT_LT(empty1.capacity(), realloc_v.capacity());
- EXPECT_EQ(0, realloc_v[0].value_);
- EXPECT_EQ(1, realloc_v[1].value_);
- EXPECT_EQ(2, realloc_v[2].value_);
- EXPECT_EQ(3, realloc_v[3].value_);
-}
-
-TEST(RangedConstructor, ElementsAreConstructed) {
- std::vector<string> source_v = {"cat", "dog"};
-
- // Force expansion and re-allocation of v. Ensures that when the vector is
- // expanded, new elements are constructed.
- tensorflow::gtl::InlinedVector<string, 1> v(source_v.begin(), source_v.end());
- EXPECT_EQ("cat", v[0]);
- EXPECT_EQ("dog", v[1]);
-}
-
-TEST(InitializerListConstructor, SimpleTypeWithInlineBacking) {
- auto vec = tensorflow::gtl::InlinedVector<int, 3>{4, 5, 6};
- EXPECT_EQ(3, vec.size());
- EXPECT_EQ(3, vec.capacity());
- EXPECT_EQ(4, vec[0]);
- EXPECT_EQ(5, vec[1]);
- EXPECT_EQ(6, vec[2]);
-}
-
-TEST(InitializerListConstructor, SimpleTypeWithReallocationRequired) {
- auto vec = tensorflow::gtl::InlinedVector<int, 2>{4, 5, 6};
- EXPECT_EQ(3, vec.size());
- EXPECT_LE(3, vec.capacity());
- EXPECT_EQ(4, vec[0]);
- EXPECT_EQ(5, vec[1]);
- EXPECT_EQ(6, vec[2]);
-}
-
-TEST(InitializerListConstructor, DisparateTypesInList) {
- EXPECT_EQ((std::vector<int>{-7, 8}),
- Vec(tensorflow::gtl::InlinedVector<int, 2>{-7, 8ULL}));
-
- EXPECT_EQ(
- (std::vector<string>{"foo", "bar"}),
- Vec(tensorflow::gtl::InlinedVector<string, 2>{"foo", string("bar")}));
-}
-
-TEST(InitializerListConstructor, ComplexTypeWithInlineBacking) {
- tensorflow::gtl::InlinedVector<Instance, 1> empty;
- auto vec = tensorflow::gtl::InlinedVector<Instance, 1>{Instance(0)};
- EXPECT_EQ(1, vec.size());
- EXPECT_EQ(empty.capacity(), vec.capacity());
- EXPECT_EQ(0, vec[0].value_);
-}
-
-TEST(InitializerListConstructor, ComplexTypeWithReallocationRequired) {
- auto vec =
- tensorflow::gtl::InlinedVector<Instance, 1>{Instance(0), Instance(1)};
- EXPECT_EQ(2, vec.size());
- EXPECT_LE(2, vec.capacity());
- EXPECT_EQ(0, vec[0].value_);
- EXPECT_EQ(1, vec[1].value_);
-}
-
-TEST(DynamicVec, DynamicVecCompiles) {
- DynamicVec v;
- (void)v;
-}
-
-static void BM_InlinedVectorFill(int iters, int len) {
- for (int i = 0; i < iters; i++) {
- IntVec v;
- for (int j = 0; j < len; j++) {
- v.push_back(j);
- }
- }
- testing::BytesProcessed((int64{iters} * len) * sizeof(int));
-}
-BENCHMARK(BM_InlinedVectorFill)->Range(0, 1024);
-
-static void BM_InlinedVectorFillRange(int iters, int len) {
- std::unique_ptr<int[]> ia(new int[len]);
- for (int j = 0; j < len; j++) {
- ia[j] = j;
- }
- for (int i = 0; i < iters; i++) {
- IntVec TF_ATTRIBUTE_UNUSED v(ia.get(), ia.get() + len);
- }
- testing::BytesProcessed((int64{iters} * len) * sizeof(int));
-}
-BENCHMARK(BM_InlinedVectorFillRange)->Range(0, 1024);
-
-static void BM_StdVectorFill(int iters, int len) {
- for (int i = 0; i < iters; i++) {
- std::vector<int> v;
- v.reserve(len);
- for (int j = 0; j < len; j++) {
- v.push_back(j);
- }
- }
- testing::BytesProcessed((int64{iters} * len) * sizeof(int));
-}
-BENCHMARK(BM_StdVectorFill)->Range(0, 1024);
-
-bool StringRepresentedInline(string s) {
- const char* chars = s.data();
- string s1 = std::move(s);
- return s1.data() != chars;
-}
-
-static void BM_InlinedVectorFillString(int iters, int len) {
- string strings[4] = {"a quite long string", "another long string",
- "012345678901234567", "to cause allocation"};
- for (int i = 0; i < iters; i++) {
- gtl::InlinedVector<string, 8> v;
- for (int j = 0; j < len; j++) {
- v.push_back(strings[j & 3]);
- }
- }
- testing::ItemsProcessed(int64{iters} * len);
-}
-BENCHMARK(BM_InlinedVectorFillString)->Range(0, 1024);
-
-static void BM_StdVectorFillString(int iters, int len) {
- string strings[4] = {"a quite long string", "another long string",
- "012345678901234567", "to cause allocation"};
- for (int i = 0; i < iters; i++) {
- std::vector<string> v;
- v.reserve(len);
- for (int j = 0; j < len; j++) {
- v.push_back(strings[j & 3]);
- }
- }
- testing::ItemsProcessed(int64{iters} * len);
- // The purpose of the benchmark is to verify that inlined vector is
- // efficient when moving is more efficient than copying. To do so, we
- // use strings that are too large for the small-string optimization.
- CHECK(!StringRepresentedInline(strings[0]));
-}
-BENCHMARK(BM_StdVectorFillString)->Range(0, 1024);
-
-namespace {
-struct Buffer { // some arbitrary structure for benchmarking.
- char* base;
- int length;
- int capacity;
- void* user_data;
-};
-} // anonymous namespace
-
-static void BM_InlinedVectorTenAssignments(int iters, int len) {
- typedef tensorflow::gtl::InlinedVector<Buffer, 2> BufferVec;
-
- BufferVec src;
- src.resize(len);
-
- iters *= 10;
- BufferVec dst;
- for (int i = 0; i < iters; i++) {
- dst = src;
- }
-}
-BENCHMARK(BM_InlinedVectorTenAssignments)
- ->Arg(0)
- ->Arg(1)
- ->Arg(2)
- ->Arg(3)
- ->Arg(4)
- ->Arg(20);
-
-static void BM_CreateFromInitializerList(int iters) {
- for (; iters > 0; iters--) {
- tensorflow::gtl::InlinedVector<int, 4> x{1, 2, 3};
- (void)x[0];
- }
-}
-BENCHMARK(BM_CreateFromInitializerList);
-
-namespace {
-
-struct LargeSwappable {
- LargeSwappable() : d_(1024, 17) {}
- ~LargeSwappable() {}
- LargeSwappable(const LargeSwappable& o) : d_(o.d_) {}
-
- friend void swap(LargeSwappable& a, LargeSwappable& b) {
- using std::swap;
- swap(a.d_, b.d_);
- }
-
- LargeSwappable& operator=(LargeSwappable o) {
- using std::swap;
- swap(*this, o);
- return *this;
- }
-
- std::vector<int> d_;
-};
-
-} // namespace
-
-static void BM_LargeSwappableElements(int iters, int len) {
- typedef tensorflow::gtl::InlinedVector<LargeSwappable, 32> Vec;
- Vec a(len);
- Vec b;
- while (--iters >= 0) {
- using std::swap;
- swap(a, b);
- }
-}
-BENCHMARK(BM_LargeSwappableElements)->Range(0, 1024);
-
-} // namespace tensorflow
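The deleted suite exercised semantics that absl::InlinedVector's own tests now cover. As one hedged example, the erase ordering check from IntVec.Erase restates directly against the alias (hypothetical function name):

#include <cassert>

#include "tensorflow/core/lib/gtl/inlined_vector.h"

// Erasing an element shifts later elements left and shrinks the size,
// mirroring one case from the deleted IntVec.Erase test.
void EraseKeepsOrder() {
  tensorflow::gtl::InlinedVector<int, 8> v = {0, 1, 2, 3, 4};
  v.erase(v.begin() + 2);  // remove the value 2
  assert(v.size() == 4);
  assert(v[2] == 3 && v[3] == 4);
}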
diff --git a/tensorflow/core/lib/gtl/optional.h b/tensorflow/core/lib/gtl/optional.h
index 4ee3f88d18..238aa18e1e 100644
--- a/tensorflow/core/lib/gtl/optional.h
+++ b/tensorflow/core/lib/gtl/optional.h
@@ -13,864 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_GTL_OPTIONAL_H_
-#define TENSORFLOW_LIB_GTL_OPTIONAL_H_
+#ifndef TENSORFLOW_CORE_LIB_GTL_OPTIONAL_H_
+#define TENSORFLOW_CORE_LIB_GTL_OPTIONAL_H_
-#include <assert.h>
-#include <functional>
-#include <initializer_list>
-#include <type_traits>
-#include <utility>
-
-#include "tensorflow/core/platform/logging.h"
+#include "absl/types/optional.h"
namespace tensorflow {
namespace gtl {
-// A value of type gtl::optional<T> holds either a value of T or an
-// "empty" value. When it holds a value of T, it stores it as a direct
-// subobject, so sizeof(optional<T>) is approximately sizeof(T)+1. The interface
-// is based on the upcoming std::optional<T>, and gtl::optional<T> is
-// designed to be cheaply drop-in replaceable by std::optional<T>, once it is
-// rolled out.
-//
-// This implementation is based on the specification in the latest draft as of
-// 2017-01-05, section 20.6.
-//
-// Differences between gtl::optional<T> and std::optional<T> include:
-// - constexpr not used for nonconst member functions.
-// (dependency on some differences between C++11 and C++14.)
-// - nullopt and in_place are not constexpr. We need the inline variable
-// support in C++17 for external linkage.
-// - CHECK instead of throwing std::bad_optional_access.
-// - optional::swap() and swap() rely on std::is_(nothrow_)swappable,
-// which was introduced in C++17. So we assume is_swappable is always true
-// and is_nothrow_swappable is the same as std::is_trivial.
-// - make_optional cannot be constexpr due to absence of guaranteed copy
-// elision.
-//
-// Synopsis:
-//
-// #include "tensorflow/core/lib/gtl/optional.h"
-//
-// tensorflow::gtl::optional<string> f() {
-// string result;
-// if (...) {
-// ...
-// result = ...;
-// return result;
-// } else {
-// ...
-// return tensorflow::gtl::nullopt;
-// }
-// }
-//
-// int main() {
-// tensorflow::gtl::optional<string> optstr = f();
-// if (optstr) {
-// // non-empty
-// print(optstr.value());
-// } else {
-// // empty
-// error();
-// }
-// }
-template <typename T>
-class optional;
-
-// The tag constant `in_place` is used as the first parameter of an optional<T>
-// constructor to indicate that the remaining arguments should be forwarded
-// to the underlying T constructor.
-struct in_place_t {};
-extern const in_place_t in_place;
-
-// The tag constant `nullopt` is used to indicate an empty optional<T> in
-// certain functions, such as construction or assignment.
-struct nullopt_t {
- struct init_t {};
- static init_t init;
- // It must not be default-constructible to avoid ambiguity for opt = {}.
- // Note the non-const reference, it is to eliminate ambiguity for code like:
- // struct S { int value; };
- //
- // void Test() {
- // optional<S> opt;
- // opt = {{}};
- // }
- explicit constexpr nullopt_t(init_t& /*unused*/) {} // NOLINT
-};
-extern const nullopt_t nullopt;
-
-namespace internal_optional {
-
-// define forward locally because std::forward is not constexpr until C++14
-template <typename T>
-constexpr T&& forward(typename std::remove_reference<T>::type&
- t) noexcept { // NOLINT(runtime/references)
- return static_cast<T&&>(t);
-}
-
-struct empty_struct {};
-// This class stores the data in optional<T>.
-// It is specialized based on whether T is trivially destructible.
-// This is the specialization for non-trivially destructible types.
-template <typename T, bool = std::is_trivially_destructible<T>::value>
-class optional_data_dtor_base {
- protected:
- // Whether there is data or not.
- bool engaged_;
- // data storage
- union {
- empty_struct dummy_;
- T data_;
- };
-
- void destruct() noexcept {
- if (engaged_) {
- data_.~T();
- engaged_ = false;
- }
- }
-
- // dummy_ must be initialized for constexpr constructor
- constexpr optional_data_dtor_base() noexcept : engaged_(false), dummy_{} {}
-
- template <typename... Args>
- constexpr explicit optional_data_dtor_base(in_place_t, Args&&... args)
- : engaged_(true), data_(internal_optional::forward<Args>(args)...) {}
-
- ~optional_data_dtor_base() { destruct(); }
-};
-
-// Specialization for trivially destructible type.
-template <typename T>
-class optional_data_dtor_base<T, true> {
- protected:
- // Whether there is data or not.
- bool engaged_;
- // data storage
- union {
- empty_struct dummy_;
- T data_;
- };
- void destruct() noexcept { engaged_ = false; }
-
- // dummy_ must be initialized for constexpr constructor
- constexpr optional_data_dtor_base() noexcept : engaged_(false), dummy_{} {}
-
- template <typename... Args>
- constexpr explicit optional_data_dtor_base(in_place_t, Args&&... args)
- : engaged_(true), data_(internal_optional::forward<Args>(args)...) {}
-
- ~optional_data_dtor_base() = default;
-};
-
-template <typename T>
-class optional_data : public optional_data_dtor_base<T> {
- protected:
- using base = optional_data_dtor_base<T>;
- using base::base;
-
- T* pointer() { return &this->data_; }
-
- constexpr const T* pointer() const { return &this->data_; }
-
- template <typename... Args>
- void construct(Args&&... args) {
- new (pointer()) T(std::forward<Args>(args)...);
- this->engaged_ = true;
- }
-
- template <typename U>
- void assign(U&& u) {
- if (this->engaged_) {
- this->data_ = std::forward<U>(u);
- } else {
- construct(std::forward<U>(u));
- }
- }
-
- optional_data() = default;
-
- optional_data(const optional_data& rhs) {
- if (rhs.engaged_) {
- construct(rhs.data_);
- }
- }
-
- optional_data(optional_data&& rhs) noexcept(
- std::is_nothrow_move_constructible<T>::value) {
- if (rhs.engaged_) {
- construct(std::move(rhs.data_));
- }
- }
-
- optional_data& operator=(const optional_data& rhs) {
- if (rhs.engaged_) {
- assign(rhs.data_);
- } else {
- this->destruct();
- }
- return *this;
- }
-
- optional_data& operator=(optional_data&& rhs) noexcept(
- std::is_nothrow_move_assignable<T>::value&&
- std::is_nothrow_move_constructible<T>::value) {
- if (rhs.engaged_) {
- assign(std::move(rhs.data_));
- } else {
- this->destruct();
- }
- return *this;
- }
-};
-
-// ordered by level of restriction, from low to high.
-// copyable implies movable.
-enum class copy_traits { copyable = 0, movable = 1, non_movable = 2 };
-
-// base class for enabling/disabling copy/move constructor.
-template <copy_traits>
-class optional_ctor_base;
-
-template <>
-class optional_ctor_base<copy_traits::copyable> {
- public:
- constexpr optional_ctor_base() = default;
- optional_ctor_base(const optional_ctor_base&) = default;
- optional_ctor_base(optional_ctor_base&&) = default;
- optional_ctor_base& operator=(const optional_ctor_base&) = default;
- optional_ctor_base& operator=(optional_ctor_base&&) = default;
-};
-
-template <>
-class optional_ctor_base<copy_traits::movable> {
- public:
- constexpr optional_ctor_base() = default;
- optional_ctor_base(const optional_ctor_base&) = delete;
- optional_ctor_base(optional_ctor_base&&) = default;
- optional_ctor_base& operator=(const optional_ctor_base&) = default;
- optional_ctor_base& operator=(optional_ctor_base&&) = default;
-};
-
-template <>
-class optional_ctor_base<copy_traits::non_movable> {
- public:
- constexpr optional_ctor_base() = default;
- optional_ctor_base(const optional_ctor_base&) = delete;
- optional_ctor_base(optional_ctor_base&&) = delete;
- optional_ctor_base& operator=(const optional_ctor_base&) = default;
- optional_ctor_base& operator=(optional_ctor_base&&) = default;
-};
-
-// base class for enabling/disabling copy/move assignment.
-template <copy_traits>
-class optional_assign_base;
-
-template <>
-class optional_assign_base<copy_traits::copyable> {
- public:
- constexpr optional_assign_base() = default;
- optional_assign_base(const optional_assign_base&) = default;
- optional_assign_base(optional_assign_base&&) = default;
- optional_assign_base& operator=(const optional_assign_base&) = default;
- optional_assign_base& operator=(optional_assign_base&&) = default;
-};
-
-template <>
-class optional_assign_base<copy_traits::movable> {
- public:
- constexpr optional_assign_base() = default;
- optional_assign_base(const optional_assign_base&) = default;
- optional_assign_base(optional_assign_base&&) = default;
- optional_assign_base& operator=(const optional_assign_base&) = delete;
- optional_assign_base& operator=(optional_assign_base&&) = default;
-};
-
-template <>
-class optional_assign_base<copy_traits::non_movable> {
- public:
- constexpr optional_assign_base() = default;
- optional_assign_base(const optional_assign_base&) = default;
- optional_assign_base(optional_assign_base&&) = default;
- optional_assign_base& operator=(const optional_assign_base&) = delete;
- optional_assign_base& operator=(optional_assign_base&&) = delete;
-};
-
+// Deprecated: please use absl::optional directly.
+using absl::make_optional;
+using absl::nullopt;
template <typename T>
-constexpr copy_traits get_ctor_copy_traits() {
- return std::is_copy_constructible<T>::value
- ? copy_traits::copyable
- : std::is_move_constructible<T>::value ? copy_traits::movable
- : copy_traits::non_movable;
-}
-
-template <typename T>
-constexpr copy_traits get_assign_copy_traits() {
- return std::is_copy_assignable<T>::value &&
- std::is_copy_constructible<T>::value
- ? copy_traits::copyable
- : std::is_move_assignable<T>::value &&
- std::is_move_constructible<T>::value
- ? copy_traits::movable
- : copy_traits::non_movable;
-}
-
-// Whether T is constructible or convertible from optional<U>.
-template <typename T, typename U>
-struct is_constructible_convertible_from_optional
- : std::integral_constant<
- bool, std::is_constructible<T, optional<U>&>::value ||
- std::is_constructible<T, optional<U>&&>::value ||
- std::is_constructible<T, const optional<U>&>::value ||
- std::is_constructible<T, const optional<U>&&>::value ||
- std::is_convertible<optional<U>&, T>::value ||
- std::is_convertible<optional<U>&&, T>::value ||
- std::is_convertible<const optional<U>&, T>::value ||
- std::is_convertible<const optional<U>&&, T>::value> {};
-
-// Whether T is constructible or convertible or assignable from optional<U>.
-template <typename T, typename U>
-struct is_constructible_convertible_assignable_from_optional
- : std::integral_constant<
- bool, is_constructible_convertible_from_optional<T, U>::value ||
- std::is_assignable<T&, optional<U>&>::value ||
- std::is_assignable<T&, optional<U>&&>::value ||
- std::is_assignable<T&, const optional<U>&>::value ||
- std::is_assignable<T&, const optional<U>&&>::value> {};
-
-} // namespace internal_optional
-
-template <typename T>
-class optional : private internal_optional::optional_data<T>,
- private internal_optional::optional_ctor_base<
- internal_optional::get_ctor_copy_traits<T>()>,
- private internal_optional::optional_assign_base<
- internal_optional::get_assign_copy_traits<T>()> {
- using data_base = internal_optional::optional_data<T>;
-
- public:
- typedef T value_type;
-
- // [optional.ctor], constructors
-
- // A default constructed optional holds the empty value, NOT a default
- // constructed T.
- constexpr optional() noexcept {}
-
- // An optional initialized with `nullopt` holds the empty value.
- constexpr optional(nullopt_t) noexcept {} // NOLINT(runtime/explicit)
-
- // Copy constructor, standard semantics.
- optional(const optional& src) = default;
-
- // Move constructor, standard semantics.
- optional(optional&& src) = default;
-
- // optional<T>(in_place, arg1, arg2, arg3) constructs a non-empty optional
- // with an in-place constructed value of T(arg1,arg2,arg3).
- // TODO(b/34201852): Add std::is_constructible<T, Args&&...> SFINAE.
- template <typename... Args>
- constexpr explicit optional(in_place_t, Args&&... args)
- : data_base(in_place_t(), internal_optional::forward<Args>(args)...) {}
-
- // optional<T>(in_place, {arg1, arg2, arg3}) constructs a non-empty optional
- // with an in-place list-initialized value of T({arg1, arg2, arg3}).
- template <typename U, typename... Args,
- typename = typename std::enable_if<std::is_constructible<
- T, std::initializer_list<U>&, Args&&...>::value>::type>
- constexpr explicit optional(in_place_t, std::initializer_list<U> il,
- Args&&... args)
- : data_base(in_place_t(), il, internal_optional::forward<Args>(args)...) {
- }
-
- template <
- typename U = T,
- typename std::enable_if<
- std::is_constructible<T, U&&>::value &&
- !std::is_same<in_place_t, typename std::decay<U>::type>::value &&
- !std::is_same<optional<T>, typename std::decay<U>::type>::value &&
- std::is_convertible<U&&, T>::value,
- bool>::type = false>
- constexpr optional(U&& v) // NOLINT
- : data_base(in_place_t(), internal_optional::forward<U>(v)) {}
-
- template <
- typename U = T,
- typename std::enable_if<
- std::is_constructible<T, U&&>::value &&
- !std::is_same<in_place_t, typename std::decay<U>::type>::value &&
- !std::is_same<optional<T>, typename std::decay<U>::type>::value &&
- !std::is_convertible<U&&, T>::value,
- bool>::type = false>
- explicit constexpr optional(U&& v)
- : data_base(in_place_t(), internal_optional::forward<U>(v)) {}
-
- // Converting copy constructor (implicit)
- template <
- typename U,
- typename std::enable_if<
- std::is_constructible<T, const U&>::value &&
- !internal_optional::is_constructible_convertible_from_optional<
- T, U>::value &&
- std::is_convertible<const U&, T>::value,
- bool>::type = false>
- optional(const optional<U>& rhs) { // NOLINT
- if (rhs) {
- this->construct(*rhs);
- }
- }
-
- // Converting copy constructor (explicit)
- template <
- typename U,
- typename std::enable_if<
- std::is_constructible<T, const U&>::value &&
- !internal_optional::is_constructible_convertible_from_optional<
- T, U>::value &&
- !std::is_convertible<const U&, T>::value,
- bool>::type = false>
- explicit optional(const optional<U>& rhs) {
- if (rhs) {
- this->construct(*rhs);
- }
- }
-
- // Converting move constructor (implicit)
- template <
- typename U,
- typename std::enable_if<
- std::is_constructible<T, U&&>::value &&
- !internal_optional::is_constructible_convertible_from_optional<
- T, U>::value &&
- std::is_convertible<U&&, T>::value,
- bool>::type = false>
- optional(optional<U>&& rhs) { // NOLINT
- if (rhs) {
- this->construct(std::move(*rhs));
- }
- }
-
- // Converting move constructor (explicit)
- template <
- typename U,
- typename std::enable_if<
- std::is_constructible<T, U&&>::value &&
- !internal_optional::is_constructible_convertible_from_optional<
- T, U>::value &&
- !std::is_convertible<U&&, T>::value,
- bool>::type = false>
- explicit optional(optional<U>&& rhs) {
- if (rhs) {
- this->construct(std::move(*rhs));
- }
- }
-
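(For illustration, a minimal sketch of how these converting constructors behave; it assumes absl::optional, the replacement this diff introduces, whose semantics match the deleted code. Wide and Narrow are hypothetical stand-in types.)

#include <cassert>
#include "absl/types/optional.h"

struct Wide {
  Wide(int v) : v(v) {}             // implicit from int
  int v;
};
struct Narrow {
  explicit Narrow(int v) : v(v) {}  // explicit from int
  int v;
};

int main() {
  absl::optional<int> oi = 42;
  absl::optional<Wide> ow = oi;     // implicit converting copy constructor
  absl::optional<Narrow> on(oi);    // explicit: direct-initialization required
  assert(ow->v == 42 && on->v == 42);
  absl::optional<int> empty;
  absl::optional<Wide> we = empty;  // emptiness propagates; no Wide is built
  assert(!we);
  return 0;
}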
- // [optional.dtor], destructor, trivial if T is trivially destructible.
- ~optional() = default;
-
- // [optional.assign], assignment
-
- // Assignment from nullopt: opt = nullopt
- optional& operator=(nullopt_t) noexcept {
- this->destruct();
- return *this;
- }
-
- // Copy assignment, standard semantics.
- optional& operator=(const optional& src) = default;
-
- // Move assignment, standard semantics.
- optional& operator=(optional&& src) = default;
-
- // Value assignment
- template <
- typename U = T,
- typename = typename std::enable_if<
- !std::is_same<optional<T>, typename std::decay<U>::type>::value &&
- (!std::is_scalar<T>::value ||
- !std::is_same<T, typename std::decay<U>::type>::value) &&
- std::is_constructible<T, U>::value &&
- std::is_assignable<T&, U>::value>::type>
- optional& operator=(U&& v) {
- this->assign(std::forward<U>(v));
- return *this;
- }
-
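(A sketch of a subtlety in this overload set, assuming absl::optional and matching the test expectations later in this diff: an empty braced list selects optional move-assignment and clears the optional; it does not value-assign T{}.)

#include <cassert>
#include "absl/types/optional.h"

int main() {
  absl::optional<int> opt = 1;
  opt = {44};   // value assignment: holds 44
  assert(opt && *opt == 44);
  opt = {};     // assigns a default-constructed optional: now empty, not 0
  assert(!opt);
  return 0;
}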
- template <typename U,
- typename = typename std::enable_if<
- std::is_constructible<T, const U&>::value &&
- std::is_assignable<T&, const U&>::value &&
- !internal_optional::
- is_constructible_convertible_assignable_from_optional<
- T, U>::value>::type>
- optional& operator=(const optional<U>& rhs) {
- if (rhs) {
- this->assign(*rhs);
- } else {
- this->destruct();
- }
- return *this;
- }
-
- template <typename U,
- typename = typename std::enable_if<
- std::is_constructible<T, U>::value &&
- std::is_assignable<T&, U>::value &&
- !internal_optional::
- is_constructible_convertible_assignable_from_optional<
- T, U>::value>::type>
- optional& operator=(optional<U>&& rhs) {
- if (rhs) {
- this->assign(std::move(*rhs));
- } else {
- this->destruct();
- }
- return *this;
- }
-
- // [optional.mod], modifiers
- // Destroys the inner T value if one is present.
- void reset() noexcept { this->destruct(); }
-
- // Emplace reconstruction. (Re)constructs the underlying T in-place with the
- // given arguments forwarded:
- //
- // optional<Foo> opt;
- // opt.emplace(arg1,arg2,arg3); (Constructs Foo(arg1,arg2,arg3))
- //
- // If the optional is non-empty, and the `args` refer to subobjects of the
- // current object, then behavior is undefined. This is because the current
- // object will be destructed before the new object is constructed with `args`.
- //
- template <typename... Args,
- typename = typename std::enable_if<
- std::is_constructible<T, Args&&...>::value>::type>
- void emplace(Args&&... args) {
- this->destruct();
- this->construct(std::forward<Args>(args)...);
- }
-
- // Emplace reconstruction with initializer-list. See immediately above.
- template <class U, class... Args,
- typename = typename std::enable_if<std::is_constructible<
- T, std::initializer_list<U>&, Args&&...>::value>::type>
- void emplace(std::initializer_list<U> il, Args&&... args) {
- this->destruct();
- this->construct(il, std::forward<Args>(args)...);
- }
-
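(A sketch of the emplace semantics documented above, assuming absl::optional behaves identically: any old value is destroyed first, then T is constructed in place.)

#include <cassert>
#include <string>
#include "absl/types/optional.h"

int main() {
  absl::optional<std::string> opt;
  opt.emplace(3, 'x');      // constructs std::string(3, 'x') in place
  assert(*opt == "xxx");
  opt.emplace({'a', 'b'});  // initializer-list overload: "ab"
  assert(*opt == "ab");
  // Caution: opt.emplace(opt->c_str()) would be undefined behavior; the
  // argument points into the old value, which is destroyed first.
  return 0;
}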
- // [optional.swap], swap
- // Swap, standard semantics.
- void swap(optional& rhs) noexcept(
- std::is_nothrow_move_constructible<T>::value&&
- std::is_trivial<T>::value) {
- if (*this) {
- if (rhs) {
- using std::swap;
- swap(**this, *rhs);
- } else {
- rhs.construct(std::move(**this));
- this->destruct();
- }
- } else {
- if (rhs) {
- this->construct(std::move(*rhs));
- rhs.destruct();
- } else {
- // no effect (swap(disengaged, disengaged))
- }
- }
- }
-
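(A sketch of the mixed engaged/empty case handled above, assuming absl::optional: the value is move-constructed into the empty side and the source becomes empty.)

#include <cassert>
#include "absl/types/optional.h"

int main() {
  absl::optional<int> a = 1;
  absl::optional<int> b;
  a.swap(b);                   // value moves from a into b
  assert(!a && b && *b == 1);
  using std::swap;
  swap(a, b);                  // namespace-scope swap, found by ADL
  assert(a && *a == 1 && !b);
  return 0;
}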
- // [optional.observe], observers
-  // You may use `*opt` and `opt->m` to access the underlying T value and T's
- // member `m`, respectively. If the optional is empty, behavior is
- // undefined.
- constexpr const T* operator->() const { return this->pointer(); }
- T* operator->() {
- assert(this->engaged_);
- return this->pointer();
- }
- constexpr const T& operator*() const& { return reference(); }
- T& operator*() & {
- assert(this->engaged_);
- return reference();
- }
- constexpr const T&& operator*() const&& { return std::move(reference()); }
- T&& operator*() && {
- assert(this->engaged_);
- return std::move(reference());
- }
-
- // In a bool context an optional<T> will return false if and only if it is
- // empty.
- //
- // if (opt) {
- // // do something with opt.value();
- // } else {
- // // opt is empty
- // }
- //
- constexpr explicit operator bool() const noexcept { return this->engaged_; }
-
- // Returns false if and only if *this is empty.
- constexpr bool has_value() const noexcept { return this->engaged_; }
-
-  // Use `opt.value()` to get a reference to the underlying value. The
-  // constness and lvalue/rvalue-ness of `opt` are preserved in the view of
-  // the T subobject.
- const T& value() const& {
- CHECK(*this) << "Bad optional access";
- return reference();
- }
- T& value() & {
- CHECK(*this) << "Bad optional access";
- return reference();
- }
- T&& value() && { // NOLINT(build/c++11)
- CHECK(*this) << "Bad optional access";
- return std::move(reference());
- }
- const T&& value() const&& { // NOLINT(build/c++11)
- CHECK(*this) << "Bad optional access";
- return std::move(reference());
- }
-
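(A sketch of the value() overload set, assuming absl::optional, which reports a bad access instead of the CHECK used here: the ref-qualifiers let an rvalue optional move its contents out.)

#include <cassert>
#include <string>
#include <utility>
#include "absl/types/optional.h"

int main() {
  absl::optional<std::string> opt("foo");
  assert(opt.value() == "foo");                // T& overload
  std::string taken = std::move(opt).value();  // T&& overload moves out
  assert(taken == "foo");
  return 0;
}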
- // Use `opt.value_or(val)` to get either the value of T or the given default
- // `val` in the empty case.
- template <class U>
- constexpr T value_or(U&& v) const& {
- return static_cast<bool>(*this) ? **this
- : static_cast<T>(std::forward<U>(v));
- }
- template <class U>
- T value_or(U&& v) && { // NOLINT(build/c++11)
- return static_cast<bool>(*this) ? std::move(**this)
- : static_cast<T>(std::forward<U>(v));
- }
-
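(A sketch of value_or, assuming absl::optional: the default is converted to T, and the rvalue overload may move the contained value.)

#include <cassert>
#include "absl/types/optional.h"

int main() {
  absl::optional<double> empty, set = 1.5;
  assert(empty.value_or(42) == 42.0);  // int 42 is converted to double
  assert(set.value_or(42) == 1.5);
  assert(absl::optional<double>(2.5).value_or(0) == 2.5);  // rvalue overload
  return 0;
}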
- private:
- // Private accessors for internal storage viewed as reference to T.
- constexpr const T& reference() const { return *this->pointer(); }
- T& reference() { return *(this->pointer()); }
-
- // T constraint checks. You can't have an optional of nullopt_t, in_place_t
- // or a reference.
- static_assert(
- !std::is_same<nullopt_t, typename std::remove_cv<T>::type>::value,
- "optional<nullopt_t> is not allowed.");
- static_assert(
- !std::is_same<in_place_t, typename std::remove_cv<T>::type>::value,
- "optional<in_place_t> is not allowed.");
- static_assert(!std::is_reference<T>::value,
- "optional<reference> is not allowed.");
-};
-
-// [optional.specalg]
-// Swap, standard semantics.
-// This function shall not participate in overload resolution unless
-// is_move_constructible_v<T> is true and is_swappable_v<T> is true.
-// NOTE: we assume is_swappable is always true. There will be a compile error
-// if T is actually not Swappable.
-template <typename T,
- typename std::enable_if<std::is_move_constructible<T>::value,
- bool>::type = false>
-void swap(optional<T>& a, optional<T>& b) noexcept(noexcept(a.swap(b))) {
- a.swap(b);
-}
-
-// NOTE: make_optional cannot be constexpr in C++11 because the copy/move
-// constructor is not constexpr and we don't have guaranteed copy elision
-// until C++17. But they are still declared constexpr for consistency with
-// the standard.
-
-// make_optional(v) creates a non-empty optional<T> where the type T is deduced
-// from v. Can also be explicitly instantiated as make_optional<T>(v).
-template <typename T>
-constexpr optional<typename std::decay<T>::type> make_optional(T&& v) {
- return optional<typename std::decay<T>::type>(std::forward<T>(v));
-}
-
-template <typename T, typename... Args>
-constexpr optional<T> make_optional(Args&&... args) {
- return optional<T>(in_place_t(), internal_optional::forward<Args>(args)...);
-}
-
-template <typename T, typename U, typename... Args>
-constexpr optional<T> make_optional(std::initializer_list<U> il,
- Args&&... args) {
- return optional<T>(in_place_t(), il,
- internal_optional::forward<Args>(args)...);
-}
-
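(A sketch of the three make_optional forms, assuming absl::make_optional as the drop-in replacement.)

#include <cassert>
#include <string>
#include <type_traits>
#include "absl/types/optional.h"

int main() {
  auto a = absl::make_optional(42);  // T deduced: absl::optional<int>
  static_assert(std::is_same<decltype(a), absl::optional<int>>::value, "");
  auto b = absl::make_optional<std::string>(3, 'x');      // in-place T(3, 'x')
  assert(*b == "xxx");
  auto c = absl::make_optional<std::string>({'a', 'b'});  // list form: "ab"
  assert(*c == "ab");
  return 0;
}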
-// Relational operators. Empty optionals are considered equal to each
-// other and less than non-empty optionals. Supports relations between
-// optional<T> and optional<T>, between optional<T> and T, and between
-// optional<T> and nullopt.
-// Note: We're careful to support T having non-bool relationals.
-
-// Relational operators [optional.relops]
-// The C++17 (N4606) "Returns:" statements are translated into code
-// in an obvious way here, and the original text retained as function docs.
-// Returns: If bool(x) != bool(y), false; otherwise if bool(x) == false, true;
-// otherwise *x == *y.
-template <class T>
-constexpr bool operator==(const optional<T>& x, const optional<T>& y) {
- return static_cast<bool>(x) != static_cast<bool>(y)
- ? false
- : static_cast<bool>(x) == false ? true : *x == *y;
-}
-// Returns: If bool(x) != bool(y), true; otherwise, if bool(x) == false, false;
-// otherwise *x != *y.
-template <class T>
-constexpr bool operator!=(const optional<T>& x, const optional<T>& y) {
- return static_cast<bool>(x) != static_cast<bool>(y)
- ? true
- : static_cast<bool>(x) == false ? false : *x != *y;
-}
-// Returns: If !y, false; otherwise, if !x, true; otherwise *x < *y.
-template <class T>
-constexpr bool operator<(const optional<T>& x, const optional<T>& y) {
- return !y ? false : !x ? true : *x < *y;
-}
-// Returns: If !x, false; otherwise, if !y, true; otherwise *x > *y.
-template <class T>
-constexpr bool operator>(const optional<T>& x, const optional<T>& y) {
- return !x ? false : !y ? true : *x > *y;
-}
-// Returns: If !x, true; otherwise, if !y, false; otherwise *x <= *y.
-template <class T>
-constexpr bool operator<=(const optional<T>& x, const optional<T>& y) {
- return !x ? true : !y ? false : *x <= *y;
-}
-// Returns: If !y, true; otherwise, if !x, false; otherwise *x >= *y.
-template <class T>
-constexpr bool operator>=(const optional<T>& x, const optional<T>& y) {
- return !y ? true : !x ? false : *x >= *y;
-}
-
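(A sketch of the ordering these operators define, assuming absl::optional: empty compares equal to empty and less than any engaged value.)

#include <cassert>
#include "absl/types/optional.h"

int main() {
  absl::optional<int> empty, two = 2, four = 4;
  assert(empty == absl::nullopt);
  assert(empty < two);          // empty sorts before any engaged optional
  assert(two < four);           // engaged: compares contained values
  assert(two < 3 && 3 < four);  // mixed optional<T> vs. T also works
  return 0;
}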
-// Comparison with nullopt [optional.nullops]
-// The C++17 (N4606) "Returns:" statements are used directly here.
-template <class T>
-constexpr bool operator==(const optional<T>& x, nullopt_t) noexcept {
- return !x;
-}
-template <class T>
-constexpr bool operator==(nullopt_t, const optional<T>& x) noexcept {
- return !x;
-}
-template <class T>
-constexpr bool operator!=(const optional<T>& x, nullopt_t) noexcept {
- return static_cast<bool>(x);
-}
-template <class T>
-constexpr bool operator!=(nullopt_t, const optional<T>& x) noexcept {
- return static_cast<bool>(x);
-}
-template <class T>
-constexpr bool operator<(const optional<T>& x, nullopt_t) noexcept {
- return false;
-}
-template <class T>
-constexpr bool operator<(nullopt_t, const optional<T>& x) noexcept {
- return static_cast<bool>(x);
-}
-template <class T>
-constexpr bool operator<=(const optional<T>& x, nullopt_t) noexcept {
- return !x;
-}
-template <class T>
-constexpr bool operator<=(nullopt_t, const optional<T>& x) noexcept {
- return true;
-}
-template <class T>
-constexpr bool operator>(const optional<T>& x, nullopt_t) noexcept {
- return static_cast<bool>(x);
-}
-template <class T>
-constexpr bool operator>(nullopt_t, const optional<T>& x) noexcept {
- return false;
-}
-template <class T>
-constexpr bool operator>=(const optional<T>& x, nullopt_t) noexcept {
- return true;
-}
-template <class T>
-constexpr bool operator>=(nullopt_t, const optional<T>& x) noexcept {
- return !x;
-}
-
-// Comparison with T [optional.comp_with_t]
-// The C++17 (N4606) "Equivalent to:" statements are used directly here.
-template <class T>
-constexpr bool operator==(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x == v : false;
-}
-template <class T>
-constexpr bool operator==(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v == *x : false;
-}
-template <class T>
-constexpr bool operator!=(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x != v : true;
-}
-template <class T>
-constexpr bool operator!=(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v != *x : true;
-}
-template <class T>
-constexpr bool operator<(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x < v : true;
-}
-template <class T>
-constexpr bool operator<(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v < *x : false;
-}
-template <class T>
-constexpr bool operator<=(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x <= v : true;
-}
-template <class T>
-constexpr bool operator<=(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v <= *x : false;
-}
-template <class T>
-constexpr bool operator>(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x > v : false;
-}
-template <class T>
-constexpr bool operator>(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v > *x : true;
-}
-template <class T>
-constexpr bool operator>=(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x >= v : false;
-}
-template <class T>
-constexpr bool operator>=(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v >= *x : true;
-}
+template <typename T>
+using optional = absl::optional<T>;
} // namespace gtl
} // namespace tensorflow
-namespace std {
-
-// Normally std::hash specializations are not recommended in tensorflow code,
-// but we allow this one because it follows a standard library component.
-template <class T>
-struct hash<::tensorflow::gtl::optional<T>> {
- size_t operator()(const ::tensorflow::gtl::optional<T>& opt) const {
- if (opt) {
- return hash<T>()(*opt);
- } else {
- return static_cast<size_t>(0x297814aaad196e6dULL);
- }
- }
-};
-
-} // namespace std
-
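(For reference, a short sketch assuming absl::optional's std::hash support, which parallels this deleted specialization: hashable T's optionals work as unordered-container keys, with all empty optionals hashing to one sentinel value.)

#include <cassert>
#include <unordered_set>
#include "absl/types/optional.h"

int main() {
  std::unordered_set<absl::optional<int>> keys;
  keys.insert(absl::nullopt);  // the empty key
  keys.insert(1);
  keys.insert(1);              // duplicate; set size stays at 2
  assert(keys.size() == 2);
  return 0;
}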
-#endif // TENSORFLOW_LIB_GTL_OPTIONAL_H_
+#endif // TENSORFLOW_CORE_LIB_GTL_OPTIONAL_H_
diff --git a/tensorflow/core/lib/gtl/optional_test.cc b/tensorflow/core/lib/gtl/optional_test.cc
deleted file mode 100644
index 12b5bbc60b..0000000000
--- a/tensorflow/core/lib/gtl/optional_test.cc
+++ /dev/null
@@ -1,1098 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/gtl/optional.h"
-
-#include <string>
-#include <utility>
-
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-namespace {
-
-using tensorflow::gtl::in_place;
-using tensorflow::gtl::in_place_t;
-using tensorflow::gtl::make_optional;
-using tensorflow::gtl::nullopt;
-using tensorflow::gtl::nullopt_t;
-using tensorflow::gtl::optional;
-
-template <typename T>
-string TypeQuals(T&) {
- return "&";
-}
-template <typename T>
-string TypeQuals(T&&) {
- return "&&";
-}
-template <typename T>
-string TypeQuals(const T&) {
- return "c&";
-}
-template <typename T>
-string TypeQuals(const T&&) {
- return "c&&";
-}
-
-struct StructorListener {
- int construct0 = 0;
- int construct1 = 0;
- int construct2 = 0;
- int listinit = 0;
- int copy = 0;
- int move = 0;
- int copy_assign = 0;
- int move_assign = 0;
- int destruct = 0;
-};
-
-struct Listenable {
- static StructorListener* listener;
-
- Listenable() { ++listener->construct0; }
- Listenable(int /*unused*/) { ++listener->construct1; } // NOLINT
- Listenable(int /*unused*/, int /*unused*/) { ++listener->construct2; }
- Listenable(std::initializer_list<int> /*unused*/) { ++listener->listinit; }
- Listenable(const Listenable& /*unused*/) { ++listener->copy; }
- Listenable(Listenable&& /*unused*/) { ++listener->move; } // NOLINT
- Listenable& operator=(const Listenable& /*unused*/) {
- ++listener->copy_assign;
- return *this;
- }
- Listenable& operator=(Listenable&& /*unused*/) { // NOLINT
- ++listener->move_assign;
- return *this;
- }
- ~Listenable() { ++listener->destruct; }
-};
-
-StructorListener* Listenable::listener = nullptr;
-
-// clang on macos -- even the latest major version at time of writing (8.x) --
-// does not like much of our constexpr business. clang < 3.0 also has trouble.
-#if defined(__clang__) && defined(__APPLE__)
-#define SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG
-#endif
-
-struct ConstexprType {
- constexpr ConstexprType() : x(0) {}
- constexpr explicit ConstexprType(int i) : x(i) {}
-#ifndef SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG
- constexpr ConstexprType(std::initializer_list<int> il) : x(il.size()) {}
-#endif
- constexpr ConstexprType(const char* s) : x(-1) {} // NOLINT
- int x;
-};
-
-struct Copyable {
- Copyable() {}
- Copyable(const Copyable&) {}
- Copyable& operator=(const Copyable&) { return *this; }
-};
-
-struct MoveableThrow {
- MoveableThrow() {}
- MoveableThrow(MoveableThrow&&) {}
- MoveableThrow& operator=(MoveableThrow&&) { return *this; }
-};
-
-struct MoveableNoThrow {
- MoveableNoThrow() {}
- MoveableNoThrow(MoveableNoThrow&&) noexcept {}
- MoveableNoThrow& operator=(MoveableNoThrow&&) noexcept { return *this; }
-};
-
-struct NonMovable {
- NonMovable() {}
- NonMovable(const NonMovable&) = delete;
- NonMovable& operator=(const NonMovable&) = delete;
- NonMovable(NonMovable&&) = delete;
- NonMovable& operator=(NonMovable&&) = delete;
-};
-
-TEST(optionalTest, DefaultConstructor) {
- optional<int> empty;
- EXPECT_FALSE(!!empty);
- constexpr optional<int> cempty;
- static_assert(!cempty.has_value(), "");
- EXPECT_TRUE(std::is_nothrow_default_constructible<optional<int>>::value);
-}
-
-TEST(optionalTest, NullOptConstructor) {
- optional<int> empty(nullopt);
- EXPECT_FALSE(!!empty);
- // Creating a temporary nullopt_t object instead of using nullopt because
- // nullopt cannot be constexpr and have external linkage at the same time.
- constexpr optional<int> cempty{nullopt_t(nullopt_t::init)};
- static_assert(!cempty.has_value(), "");
- EXPECT_TRUE((std::is_nothrow_constructible<optional<int>, nullopt_t>::value));
-}
-
-TEST(optionalTest, CopyConstructor) {
- optional<int> empty, opt42 = 42;
- optional<int> empty_copy(empty);
- EXPECT_FALSE(!!empty_copy);
- optional<int> opt42_copy(opt42);
- EXPECT_TRUE(!!opt42_copy);
- EXPECT_EQ(42, opt42_copy);
-  // test copyability
- EXPECT_TRUE(std::is_copy_constructible<optional<int>>::value);
- EXPECT_TRUE(std::is_copy_constructible<optional<Copyable>>::value);
- EXPECT_FALSE(std::is_copy_constructible<optional<MoveableThrow>>::value);
- EXPECT_FALSE(std::is_copy_constructible<optional<MoveableNoThrow>>::value);
- EXPECT_FALSE(std::is_copy_constructible<optional<NonMovable>>::value);
-}
-
-TEST(optionalTest, MoveConstructor) {
- optional<int> empty, opt42 = 42;
- optional<int> empty_move(std::move(empty));
- EXPECT_FALSE(!!empty_move);
- optional<int> opt42_move(std::move(opt42));
- EXPECT_TRUE(!!opt42_move);
- EXPECT_EQ(42, opt42_move);
- // test movability
- EXPECT_TRUE(std::is_move_constructible<optional<int>>::value);
- EXPECT_TRUE(std::is_move_constructible<optional<Copyable>>::value);
- EXPECT_TRUE(std::is_move_constructible<optional<MoveableThrow>>::value);
- EXPECT_TRUE(std::is_move_constructible<optional<MoveableNoThrow>>::value);
- EXPECT_FALSE(std::is_move_constructible<optional<NonMovable>>::value);
- // test noexcept
- EXPECT_TRUE(std::is_nothrow_move_constructible<optional<int>>::value);
- EXPECT_FALSE(
- std::is_nothrow_move_constructible<optional<MoveableThrow>>::value);
- EXPECT_TRUE(
- std::is_nothrow_move_constructible<optional<MoveableNoThrow>>::value);
-}
-
-TEST(optionalTest, Destructor) {
- struct Trivial {};
-
- struct NonTrivial {
- ~NonTrivial() {}
- };
-
- EXPECT_TRUE(std::is_trivially_destructible<optional<int>>::value);
- EXPECT_TRUE(std::is_trivially_destructible<optional<Trivial>>::value);
- EXPECT_FALSE(std::is_trivially_destructible<optional<NonTrivial>>::value);
-}
-
-TEST(optionalTest, InPlaceConstructor) {
- constexpr optional<ConstexprType> opt0{in_place_t()};
- static_assert(opt0, "");
- static_assert(opt0->x == 0, "");
- constexpr optional<ConstexprType> opt1{in_place_t(), 1};
- static_assert(opt1, "");
- static_assert(opt1->x == 1, "");
-#ifndef SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG
- constexpr optional<ConstexprType> opt2{in_place_t(), {1, 2}};
- static_assert(opt2, "");
- static_assert(opt2->x == 2, "");
-#endif
-
- // TODO(b/34201852): uncomment these when std::is_constructible<T, Args&&...>
- // SFINAE is added to optional::optional(in_place_t, Args&&...).
- // struct I {
- // I(in_place_t);
- // };
-
- // EXPECT_FALSE((std::is_constructible<optional<I>, in_place_t>::value));
- // EXPECT_FALSE((std::is_constructible<optional<I>, const
- // in_place_t&>::value));
-}
-
-// template<U=T> optional(U&&);
-TEST(optionalTest, ValueConstructor) {
- constexpr optional<int> opt0(0);
- static_assert(opt0, "");
- static_assert(*opt0 == 0, "");
- EXPECT_TRUE((std::is_convertible<int, optional<int>>::value));
-  // Copy initialization ( = "abc") won't work because optional(optional&&)
-  // is not constexpr. Use list initialization instead. This invokes
- // optional<ConstexprType>::optional<U>(U&&), with U = const char (&) [4],
- // which direct-initializes the ConstexprType value held by the optional
- // via ConstexprType::ConstexprType(const char*).
- constexpr optional<ConstexprType> opt1 = {"abc"};
- static_assert(opt1, "");
- static_assert(-1 == opt1->x, "");
- EXPECT_TRUE(
- (std::is_convertible<const char*, optional<ConstexprType>>::value));
- // direct initialization
- constexpr optional<ConstexprType> opt2{2};
- static_assert(opt2, "");
- static_assert(2 == opt2->x, "");
- EXPECT_FALSE((std::is_convertible<int, optional<ConstexprType>>::value));
-
- // this invokes optional<int>::optional(int&&)
- // NOTE: this has different behavior than assignment, e.g.
- // "opt3 = {};" clears the optional rather than setting the value to 0
- constexpr optional<int> opt3({});
- static_assert(opt3, "");
- static_assert(*opt3 == 0, "");
-
- // this invokes the move constructor with a default constructed optional
-  // because a non-template function is a better match than a template one.
- optional<ConstexprType> opt4({});
- EXPECT_FALSE(!!opt4);
-}
-
-struct Implicit {};
-
-struct Explicit {};
-
-struct Convert {
- Convert(const Implicit&) // NOLINT(runtime/explicit)
- : implicit(true), move(false) {}
- Convert(Implicit&&) // NOLINT(runtime/explicit)
- : implicit(true), move(true) {}
- explicit Convert(const Explicit&) : implicit(false), move(false) {}
- explicit Convert(Explicit&&) : implicit(false), move(true) {}
-
- bool implicit;
- bool move;
-};
-
-struct ConvertFromOptional {
- ConvertFromOptional(const Implicit&) // NOLINT(runtime/explicit)
- : implicit(true), move(false), from_optional(false) {}
- ConvertFromOptional(Implicit&&) // NOLINT(runtime/explicit)
- : implicit(true), move(true), from_optional(false) {}
- ConvertFromOptional(const optional<Implicit>&) // NOLINT(runtime/explicit)
- : implicit(true), move(false), from_optional(true) {}
- ConvertFromOptional(optional<Implicit>&&) // NOLINT(runtime/explicit)
- : implicit(true), move(true), from_optional(true) {}
- explicit ConvertFromOptional(const Explicit&)
- : implicit(false), move(false), from_optional(false) {}
- explicit ConvertFromOptional(Explicit&&)
- : implicit(false), move(true), from_optional(false) {}
- explicit ConvertFromOptional(const optional<Explicit>&)
- : implicit(false), move(false), from_optional(true) {}
- explicit ConvertFromOptional(optional<Explicit>&&)
- : implicit(false), move(true), from_optional(true) {}
-
- bool implicit;
- bool move;
- bool from_optional;
-};
-
-TEST(optionalTest, ConvertingConstructor) {
- optional<Implicit> i_empty;
- optional<Implicit> i(in_place);
- optional<Explicit> e_empty;
- optional<Explicit> e(in_place);
- {
- // implicitly constructing optional<Convert> from optional<Implicit>
- optional<Convert> empty = i_empty;
- EXPECT_FALSE(!!empty);
- optional<Convert> opt_copy = i;
- EXPECT_TRUE(!!opt_copy);
- EXPECT_TRUE(opt_copy->implicit);
- EXPECT_FALSE(opt_copy->move);
- optional<Convert> opt_move = optional<Implicit>(in_place);
- EXPECT_TRUE(!!opt_move);
- EXPECT_TRUE(opt_move->implicit);
- EXPECT_TRUE(opt_move->move);
- }
- {
- // explicitly constructing optional<Convert> from optional<Explicit>
- optional<Convert> empty(e_empty);
- EXPECT_FALSE(!!empty);
- optional<Convert> opt_copy(e);
- EXPECT_TRUE(!!opt_copy);
- EXPECT_FALSE(opt_copy->implicit);
- EXPECT_FALSE(opt_copy->move);
- EXPECT_FALSE((std::is_convertible<const optional<Explicit>&,
- optional<Convert>>::value));
- optional<Convert> opt_move{optional<Explicit>(in_place)};
- EXPECT_TRUE(!!opt_move);
- EXPECT_FALSE(opt_move->implicit);
- EXPECT_TRUE(opt_move->move);
- EXPECT_FALSE(
- (std::is_convertible<optional<Explicit>&&, optional<Convert>>::value));
- }
- {
- // implicitly constructing optional<ConvertFromOptional> from
- // optional<Implicit> via ConvertFromOptional(optional<Implicit>&&)
- // check that ConvertFromOptional(Implicit&&) is NOT called
- static_assert(
- gtl::internal_optional::is_constructible_convertible_from_optional<
- ConvertFromOptional, Implicit>::value,
- "");
- optional<ConvertFromOptional> opt0 = i_empty;
- EXPECT_TRUE(!!opt0);
- EXPECT_TRUE(opt0->implicit);
- EXPECT_FALSE(opt0->move);
- EXPECT_TRUE(opt0->from_optional);
- optional<ConvertFromOptional> opt1 = optional<Implicit>();
- EXPECT_TRUE(!!opt1);
- EXPECT_TRUE(opt1->implicit);
- EXPECT_TRUE(opt1->move);
- EXPECT_TRUE(opt1->from_optional);
- }
- {
- // implicitly constructing optional<ConvertFromOptional> from
- // optional<Explicit> via ConvertFromOptional(optional<Explicit>&&)
- // check that ConvertFromOptional(Explicit&&) is NOT called
- optional<ConvertFromOptional> opt0(e_empty);
- EXPECT_TRUE(!!opt0);
- EXPECT_FALSE(opt0->implicit);
- EXPECT_FALSE(opt0->move);
- EXPECT_TRUE(opt0->from_optional);
- EXPECT_FALSE((std::is_convertible<const optional<Explicit>&,
- optional<ConvertFromOptional>>::value));
- optional<ConvertFromOptional> opt1{optional<Explicit>()};
- EXPECT_TRUE(!!opt1);
- EXPECT_FALSE(opt1->implicit);
- EXPECT_TRUE(opt1->move);
- EXPECT_TRUE(opt1->from_optional);
- EXPECT_FALSE((std::is_convertible<optional<Explicit>&&,
- optional<ConvertFromOptional>>::value));
- }
-}
-
-TEST(optionalTest, StructorBasic) {
- StructorListener listener;
- Listenable::listener = &listener;
- {
- optional<Listenable> empty;
- EXPECT_FALSE(!!empty);
- optional<Listenable> opt0(in_place);
- EXPECT_TRUE(!!opt0);
- optional<Listenable> opt1(in_place, 1);
- EXPECT_TRUE(!!opt1);
- optional<Listenable> opt2(in_place, 1, 2);
- EXPECT_TRUE(!!opt2);
- }
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.construct1);
- EXPECT_EQ(1, listener.construct2);
- EXPECT_EQ(3, listener.destruct);
-}
-
-TEST(optionalTest, CopyMoveStructor) {
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> original(in_place);
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(0, listener.copy);
- EXPECT_EQ(0, listener.move);
- optional<Listenable> copy(original);
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.copy);
- EXPECT_EQ(0, listener.move);
- optional<Listenable> move(std::move(original));
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.copy);
- EXPECT_EQ(1, listener.move);
-}
-
-TEST(optionalTest, ListInit) {
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> listinit1(in_place, {1});
- optional<Listenable> listinit2(in_place, {1, 2});
- EXPECT_EQ(2, listener.listinit);
-}
-
-TEST(optionalTest, AssignFromNullopt) {
- optional<int> opt(1);
- opt = nullopt;
- EXPECT_FALSE(!!opt);
-
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> opt1(in_place);
- opt1 = nullopt;
- EXPECT_FALSE(opt1);
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.destruct);
-
- EXPECT_TRUE((std::is_nothrow_assignable<optional<int>, nullopt_t>::value));
- EXPECT_TRUE(
- (std::is_nothrow_assignable<optional<Listenable>, nullopt_t>::value));
-}
-
-TEST(optionalTest, CopyAssignment) {
- const optional<int> empty, opt1 = 1, opt2 = 2;
- optional<int> empty_to_opt1, opt1_to_opt2, opt2_to_empty;
-
- EXPECT_FALSE(!!empty_to_opt1);
- empty_to_opt1 = empty;
- EXPECT_FALSE(!!empty_to_opt1);
- empty_to_opt1 = opt1;
- EXPECT_TRUE(!!empty_to_opt1);
- EXPECT_EQ(1, empty_to_opt1.value());
-
- EXPECT_FALSE(!!opt1_to_opt2);
- opt1_to_opt2 = opt1;
- EXPECT_TRUE(!!opt1_to_opt2);
- EXPECT_EQ(1, opt1_to_opt2.value());
- opt1_to_opt2 = opt2;
- EXPECT_TRUE(!!opt1_to_opt2);
- EXPECT_EQ(2, opt1_to_opt2.value());
-
- EXPECT_FALSE(!!opt2_to_empty);
- opt2_to_empty = opt2;
- EXPECT_TRUE(!!opt2_to_empty);
- EXPECT_EQ(2, opt2_to_empty.value());
- opt2_to_empty = empty;
- EXPECT_FALSE(!!opt2_to_empty);
-
- EXPECT_TRUE(std::is_copy_assignable<optional<Copyable>>::value);
- EXPECT_FALSE(std::is_copy_assignable<optional<MoveableThrow>>::value);
- EXPECT_FALSE(std::is_copy_assignable<optional<MoveableNoThrow>>::value);
- EXPECT_FALSE(std::is_copy_assignable<optional<NonMovable>>::value);
-}
-
-TEST(optionalTest, MoveAssignment) {
- StructorListener listener;
- Listenable::listener = &listener;
-
- optional<Listenable> empty1, empty2, set1(in_place), set2(in_place);
- EXPECT_EQ(2, listener.construct0);
- optional<Listenable> empty_to_empty, empty_to_set, set_to_empty(in_place),
- set_to_set(in_place);
- EXPECT_EQ(4, listener.construct0);
- empty_to_empty = std::move(empty1);
- empty_to_set = std::move(set1);
- set_to_empty = std::move(empty2);
- set_to_set = std::move(set2);
- EXPECT_EQ(0, listener.copy);
- EXPECT_EQ(1, listener.move);
- EXPECT_EQ(1, listener.destruct);
- EXPECT_EQ(1, listener.move_assign);
-
- EXPECT_TRUE(std::is_move_assignable<optional<Copyable>>::value);
- EXPECT_TRUE(std::is_move_assignable<optional<MoveableThrow>>::value);
- EXPECT_TRUE(std::is_move_assignable<optional<MoveableNoThrow>>::value);
- EXPECT_FALSE(std::is_move_assignable<optional<NonMovable>>::value);
-
- EXPECT_FALSE(std::is_nothrow_move_assignable<optional<MoveableThrow>>::value);
- EXPECT_TRUE(
- std::is_nothrow_move_assignable<optional<MoveableNoThrow>>::value);
-}
-
-struct NoConvertToOptional {
- // disable implicit conversion from const NoConvertToOptional&
- // to optional<NoConvertToOptional>.
- NoConvertToOptional(const NoConvertToOptional&) = delete;
-};
-
-struct CopyConvert {
- CopyConvert(const NoConvertToOptional&);
- CopyConvert& operator=(const CopyConvert&) = delete;
- CopyConvert& operator=(const NoConvertToOptional&);
-};
-
-struct CopyConvertFromOptional {
- CopyConvertFromOptional(const NoConvertToOptional&);
- CopyConvertFromOptional(const optional<NoConvertToOptional>&);
- CopyConvertFromOptional& operator=(const CopyConvertFromOptional&) = delete;
- CopyConvertFromOptional& operator=(const NoConvertToOptional&);
- CopyConvertFromOptional& operator=(const optional<NoConvertToOptional>&);
-};
-
-struct MoveConvert {
- MoveConvert(NoConvertToOptional&&);
- MoveConvert& operator=(const MoveConvert&) = delete;
- MoveConvert& operator=(NoConvertToOptional&&);
-};
-
-struct MoveConvertFromOptional {
- MoveConvertFromOptional(NoConvertToOptional&&);
- MoveConvertFromOptional(optional<NoConvertToOptional>&&);
- MoveConvertFromOptional& operator=(const MoveConvertFromOptional&) = delete;
- MoveConvertFromOptional& operator=(NoConvertToOptional&&);
- MoveConvertFromOptional& operator=(optional<NoConvertToOptional>&&);
-};
-
-// template <class U = T> optional<T>& operator=(U&& v);
-TEST(optionalTest, ValueAssignment) {
- optional<int> opt;
- EXPECT_FALSE(!!opt);
- opt = 42;
- EXPECT_TRUE(!!opt);
- EXPECT_EQ(42, opt.value());
- opt = nullopt;
- EXPECT_FALSE(!!opt);
- opt = 42;
- EXPECT_TRUE(!!opt);
- EXPECT_EQ(42, opt.value());
- opt = 43;
- EXPECT_TRUE(!!opt);
- EXPECT_EQ(43, opt.value());
- opt = {}; // this should clear optional
- EXPECT_FALSE(!!opt);
-
- opt = {44};
- EXPECT_TRUE(!!opt);
- EXPECT_EQ(44, opt.value());
-
- // U = const NoConvertToOptional&
- EXPECT_TRUE((std::is_assignable<optional<CopyConvert>&,
- const NoConvertToOptional&>::value));
- // U = const optional<NoConvertToOptional>&
- EXPECT_TRUE((std::is_assignable<optional<CopyConvertFromOptional>&,
- const NoConvertToOptional&>::value));
- // U = const NoConvertToOptional& triggers SFINAE because
- // std::is_constructible_v<MoveConvert, const NoConvertToOptional&> is false
- EXPECT_FALSE((std::is_assignable<optional<MoveConvert>&,
- const NoConvertToOptional&>::value));
- // U = NoConvertToOptional
- EXPECT_TRUE((std::is_assignable<optional<MoveConvert>&,
- NoConvertToOptional&&>::value));
- // U = const NoConvertToOptional& triggers SFINAE because
- // std::is_constructible_v<MoveConvertFromOptional, const
- // NoConvertToOptional&> is false
- EXPECT_FALSE((std::is_assignable<optional<MoveConvertFromOptional>&,
- const NoConvertToOptional&>::value));
- // U = NoConvertToOptional
- EXPECT_TRUE((std::is_assignable<optional<MoveConvertFromOptional>&,
- NoConvertToOptional&&>::value));
- // U = const optional<NoConvertToOptional>&
- EXPECT_TRUE(
- (std::is_assignable<optional<CopyConvertFromOptional>&,
- const optional<NoConvertToOptional>&>::value));
- // U = optional<NoConvertToOptional>
- EXPECT_TRUE((std::is_assignable<optional<MoveConvertFromOptional>&,
- optional<NoConvertToOptional>&&>::value));
-}
-
-// template <class U> optional<T>& operator=(const optional<U>& rhs);
-// template <class U> optional<T>& operator=(optional<U>&& rhs);
-TEST(optionalTest, ConvertingAssignment) {
- optional<int> opt_i;
- optional<char> opt_c('c');
- opt_i = opt_c;
- EXPECT_TRUE(!!opt_i);
- EXPECT_EQ(*opt_c, *opt_i);
- opt_i = optional<char>();
- EXPECT_FALSE(!!opt_i);
- opt_i = optional<char>('d');
- EXPECT_TRUE(!!opt_i);
- EXPECT_EQ('d', *opt_i);
-
- optional<string> opt_str;
- optional<const char*> opt_cstr("abc");
- opt_str = opt_cstr;
- EXPECT_TRUE(!!opt_str);
- EXPECT_EQ(string("abc"), *opt_str);
- opt_str = optional<const char*>();
- EXPECT_FALSE(!!opt_str);
- opt_str = optional<const char*>("def");
- EXPECT_TRUE(!!opt_str);
- EXPECT_EQ(string("def"), *opt_str);
-
- // operator=(const optional<U>&) with U = NoConvertToOptional
- EXPECT_TRUE(
- (std::is_assignable<optional<CopyConvert>,
- const optional<NoConvertToOptional>&>::value));
- // operator=(const optional<U>&) with U = NoConvertToOptional
- // triggers SFINAE because
- // std::is_constructible_v<MoveConvert, const NoConvertToOptional&> is false
- EXPECT_FALSE(
- (std::is_assignable<optional<MoveConvert>&,
- const optional<NoConvertToOptional>&>::value));
- // operator=(optional<U>&&) with U = NoConvertToOptional
- EXPECT_TRUE((std::is_assignable<optional<MoveConvert>&,
- optional<NoConvertToOptional>&&>::value));
- // operator=(const optional<U>&) with U = NoConvertToOptional triggers SFINAE
- // because std::is_constructible_v<MoveConvertFromOptional,
- // const NoConvertToOptional&> is false.
-  // operator=(U&&) with U = const optional<NoConvertToOptional>& triggers SFINAE
- // because std::is_constructible<MoveConvertFromOptional,
- // optional<NoConvertToOptional>&&> is true.
- EXPECT_FALSE(
- (std::is_assignable<optional<MoveConvertFromOptional>&,
- const optional<NoConvertToOptional>&>::value));
-}
-
-TEST(optionalTest, ResetAndHasValue) {
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> opt;
- EXPECT_FALSE(!!opt);
- EXPECT_FALSE(opt.has_value());
- opt.emplace();
- EXPECT_TRUE(!!opt);
- EXPECT_TRUE(opt.has_value());
- opt.reset();
- EXPECT_FALSE(!!opt);
- EXPECT_FALSE(opt.has_value());
- EXPECT_EQ(1, listener.destruct);
- opt.reset();
- EXPECT_FALSE(!!opt);
- EXPECT_FALSE(opt.has_value());
-
- constexpr optional<int> empty;
- static_assert(!empty.has_value(), "");
- constexpr optional<int> nonempty(1);
- static_assert(nonempty.has_value(), "");
-}
-
-TEST(optionalTest, Emplace) {
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> opt;
- EXPECT_FALSE(!!opt);
- opt.emplace(1);
- EXPECT_TRUE(!!opt);
- opt.emplace(1, 2);
- EXPECT_EQ(1, listener.construct1);
- EXPECT_EQ(1, listener.construct2);
- EXPECT_EQ(1, listener.destruct);
-}
-
-TEST(optionalTest, ListEmplace) {
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> opt;
- EXPECT_FALSE(!!opt);
- opt.emplace({1});
- EXPECT_TRUE(!!opt);
- opt.emplace({1, 2});
- EXPECT_EQ(2, listener.listinit);
- EXPECT_EQ(1, listener.destruct);
-}
-
-TEST(optionalTest, Swap) {
- optional<int> opt_empty, opt1 = 1, opt2 = 2;
- EXPECT_FALSE(!!opt_empty);
- EXPECT_TRUE(!!opt1);
- EXPECT_EQ(1, opt1.value());
- EXPECT_TRUE(!!opt2);
- EXPECT_EQ(2, opt2.value());
- swap(opt_empty, opt1);
- EXPECT_FALSE(!!opt1);
- EXPECT_TRUE(!!opt_empty);
- EXPECT_EQ(1, opt_empty.value());
- EXPECT_TRUE(!!opt2);
- EXPECT_EQ(2, opt2.value());
- swap(opt_empty, opt1);
- EXPECT_FALSE(!!opt_empty);
- EXPECT_TRUE(!!opt1);
- EXPECT_EQ(1, opt1.value());
- EXPECT_TRUE(!!opt2);
- EXPECT_EQ(2, opt2.value());
- swap(opt1, opt2);
- EXPECT_FALSE(!!opt_empty);
- EXPECT_TRUE(!!opt1);
- EXPECT_EQ(2, opt1.value());
- EXPECT_TRUE(!!opt2);
- EXPECT_EQ(1, opt2.value());
-
- EXPECT_TRUE(noexcept(opt1.swap(opt2)));
- EXPECT_TRUE(noexcept(swap(opt1, opt2)));
-}
-
-TEST(optionalTest, PointerStuff) {
- optional<string> opt(in_place, "foo");
- EXPECT_EQ("foo", *opt);
- const auto& opt_const = opt;
- EXPECT_EQ("foo", *opt_const);
- EXPECT_EQ(opt->size(), 3);
- EXPECT_EQ(opt_const->size(), 3);
-
- constexpr optional<ConstexprType> opt1(1);
- static_assert(opt1->x == 1, "");
-}
-
-// gcc before 4.9 has a bug where it doesn't do correct overload resolution
-// between rvalue-reference-qualified member methods. Skip those tests to
-// keep the build green when using the old compiler.
-#if defined(__GNUC__) && !defined(__clang__)
-#if __GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 9)
-#define SKIP_OVERLOAD_TEST_DUE_TO_GCC_BUG
-#endif
-#endif
-
-TEST(optionalTest, Value) {
- using O = optional<string>;
- using CO = const optional<string>;
- O lvalue(in_place, "lvalue");
- CO clvalue(in_place, "clvalue");
- EXPECT_EQ("lvalue", lvalue.value());
- EXPECT_EQ("clvalue", clvalue.value());
- EXPECT_EQ("xvalue", O(in_place, "xvalue").value());
-#ifndef SKIP_OVERLOAD_TEST_DUE_TO_GCC_BUG
- EXPECT_EQ("cxvalue", CO(in_place, "cxvalue").value());
- EXPECT_EQ("&", TypeQuals(lvalue.value()));
- EXPECT_EQ("c&", TypeQuals(clvalue.value()));
- EXPECT_EQ("&&", TypeQuals(O(in_place, "xvalue").value()));
- EXPECT_EQ("c&&", TypeQuals(CO(in_place, "cxvalue").value()));
-#endif
-}
-
-TEST(optionalTest, DerefOperator) {
- using O = optional<string>;
- using CO = const optional<string>;
- O lvalue(in_place, "lvalue");
- CO clvalue(in_place, "clvalue");
- EXPECT_EQ("lvalue", *lvalue);
- EXPECT_EQ("clvalue", *clvalue);
- EXPECT_EQ("xvalue", *O(in_place, "xvalue"));
-#ifndef SKIP_OVERLOAD_TEST_DUE_TO_GCC_BUG
- EXPECT_EQ("cxvalue", *CO(in_place, "cxvalue"));
- EXPECT_EQ("&", TypeQuals(*lvalue));
- EXPECT_EQ("c&", TypeQuals(*clvalue));
- EXPECT_EQ("&&", TypeQuals(*O(in_place, "xvalue")));
- EXPECT_EQ("c&&", TypeQuals(*CO(in_place, "cxvalue")));
-#endif
-
- constexpr optional<int> opt1(1);
- static_assert(*opt1 == 1, "");
-
-#if !defined(SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG) && \
- !defined(SKIP_OVERLOAD_TEST_DUE_TO_GCC_BUG)
- using COI = const optional<int>;
- static_assert(*COI(2) == 2, "");
-#endif
-}
-
-TEST(optionalTest, ValueOr) {
- optional<double> opt_empty, opt_set = 1.2;
- EXPECT_EQ(42.0, opt_empty.value_or(42));
- EXPECT_EQ(1.2, opt_set.value_or(42));
- EXPECT_EQ(42.0, optional<double>().value_or(42));
- EXPECT_EQ(1.2, optional<double>(1.2).value_or(42));
-
-#ifndef SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG
- constexpr optional<double> copt_empty;
- static_assert(42.0 == copt_empty.value_or(42), "");
-
- constexpr optional<double> copt_set = {1.2};
- static_assert(1.2 == copt_set.value_or(42), "");
-
- using COD = const optional<double>;
- static_assert(42.0 == COD().value_or(42), "");
- static_assert(1.2 == COD(1.2).value_or(42), "");
-#endif
-}
-
-// make_optional cannot be constexpr until C++17
-TEST(optionalTest, make_optional) {
- auto opt_int = make_optional(42);
- EXPECT_TRUE((std::is_same<decltype(opt_int), optional<int>>::value));
- EXPECT_EQ(42, opt_int);
-
- StructorListener listener;
- Listenable::listener = &listener;
-
- optional<Listenable> opt0 = make_optional<Listenable>();
- EXPECT_EQ(1, listener.construct0);
- optional<Listenable> opt1 = make_optional<Listenable>(1);
- EXPECT_EQ(1, listener.construct1);
- optional<Listenable> opt2 = make_optional<Listenable>(1, 2);
- EXPECT_EQ(1, listener.construct2);
- optional<Listenable> opt3 = make_optional<Listenable>({1});
- optional<Listenable> opt4 = make_optional<Listenable>({1, 2});
- EXPECT_EQ(2, listener.listinit);
-}
-
-TEST(optionalTest, Comparisons) {
- optional<int> ae, be, a2 = 2, b2 = 2, a4 = 4, b4 = 4;
-
-#define optionalTest_Comparisons_EXPECT_LESS(x, y) \
- EXPECT_FALSE((x) == (y)); \
- EXPECT_TRUE((x) != (y)); \
- EXPECT_TRUE((x) < (y)); \
- EXPECT_FALSE((x) > (y)); \
- EXPECT_TRUE((x) <= (y)); \
- EXPECT_FALSE((x) >= (y));
-
-#define optionalTest_Comparisons_EXPECT_SAME(x, y) \
- EXPECT_TRUE((x) == (y)); \
- EXPECT_FALSE((x) != (y)); \
- EXPECT_FALSE((x) < (y)); \
- EXPECT_FALSE((x) > (y)); \
- EXPECT_TRUE((x) <= (y)); \
- EXPECT_TRUE((x) >= (y));
-
-#define optionalTest_Comparisons_EXPECT_GREATER(x, y) \
- EXPECT_FALSE((x) == (y)); \
- EXPECT_TRUE((x) != (y)); \
- EXPECT_FALSE((x) < (y)); \
- EXPECT_TRUE((x) > (y)); \
- EXPECT_FALSE((x) <= (y)); \
- EXPECT_TRUE((x) >= (y));
-
- // LHS: nullopt, ae, a2, 3, a4
- // RHS: nullopt, be, b2, 3, b4
-
- // optionalTest_Comparisons_EXPECT_NOT_TO_WORK(nullopt,nullopt);
- optionalTest_Comparisons_EXPECT_SAME(nullopt, be);
- optionalTest_Comparisons_EXPECT_LESS(nullopt, b2);
- // optionalTest_Comparisons_EXPECT_NOT_TO_WORK(nullopt,3);
- optionalTest_Comparisons_EXPECT_LESS(nullopt, b4);
-
- optionalTest_Comparisons_EXPECT_SAME(ae, nullopt);
- optionalTest_Comparisons_EXPECT_SAME(ae, be);
- optionalTest_Comparisons_EXPECT_LESS(ae, b2);
- optionalTest_Comparisons_EXPECT_LESS(ae, 3);
- optionalTest_Comparisons_EXPECT_LESS(ae, b4);
-
- optionalTest_Comparisons_EXPECT_GREATER(a2, nullopt);
- optionalTest_Comparisons_EXPECT_GREATER(a2, be);
- optionalTest_Comparisons_EXPECT_SAME(a2, b2);
- optionalTest_Comparisons_EXPECT_LESS(a2, 3);
- optionalTest_Comparisons_EXPECT_LESS(a2, b4);
-
- // optionalTest_Comparisons_EXPECT_NOT_TO_WORK(3,nullopt);
- optionalTest_Comparisons_EXPECT_GREATER(3, be);
- optionalTest_Comparisons_EXPECT_GREATER(3, b2);
- optionalTest_Comparisons_EXPECT_SAME(3, 3);
- optionalTest_Comparisons_EXPECT_LESS(3, b4);
-
- optionalTest_Comparisons_EXPECT_GREATER(a4, nullopt);
- optionalTest_Comparisons_EXPECT_GREATER(a4, be);
- optionalTest_Comparisons_EXPECT_GREATER(a4, b2);
- optionalTest_Comparisons_EXPECT_GREATER(a4, 3);
- optionalTest_Comparisons_EXPECT_SAME(a4, b4);
-}
-
-TEST(optionalTest, SwapRegression) {
- StructorListener listener;
- Listenable::listener = &listener;
-
- {
- optional<Listenable> a;
- optional<Listenable> b(in_place);
- a.swap(b);
- }
-
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.move);
- EXPECT_EQ(2, listener.destruct);
-
- {
- optional<Listenable> a(in_place);
- optional<Listenable> b;
- a.swap(b);
- }
-
- EXPECT_EQ(2, listener.construct0);
- EXPECT_EQ(2, listener.move);
- EXPECT_EQ(4, listener.destruct);
-}
-
-TEST(optionalTest, BigStringLeakCheck) {
- constexpr size_t n = 1 << 16;
-
- using OS = optional<string>;
-
- OS a;
- OS b = nullopt;
- OS c = string(n, 'c');
- string sd(n, 'd');
- OS d = sd;
- OS e(in_place, n, 'e');
- OS f;
- f.emplace(n, 'f');
-
- OS ca(a);
- OS cb(b);
- OS cc(c);
- OS cd(d);
- OS ce(e);
-
- OS oa;
- OS ob = nullopt;
- OS oc = string(n, 'c');
- string sod(n, 'd');
- OS od = sod;
- OS oe(in_place, n, 'e');
- OS of;
- of.emplace(n, 'f');
-
- OS ma(std::move(oa));
- OS mb(std::move(ob));
- OS mc(std::move(oc));
- OS md(std::move(od));
- OS me(std::move(oe));
- OS mf(std::move(of));
-
- OS aa1;
- OS ab1 = nullopt;
- OS ac1 = string(n, 'c');
- string sad1(n, 'd');
- OS ad1 = sad1;
- OS ae1(in_place, n, 'e');
- OS af1;
- af1.emplace(n, 'f');
-
- OS aa2;
- OS ab2 = nullopt;
- OS ac2 = string(n, 'c');
- string sad2(n, 'd');
- OS ad2 = sad2;
- OS ae2(in_place, n, 'e');
- OS af2;
- af2.emplace(n, 'f');
-
- aa1 = af2;
- ab1 = ae2;
- ac1 = ad2;
- ad1 = ac2;
- ae1 = ab2;
- af1 = aa2;
-
- OS aa3;
- OS ab3 = nullopt;
- OS ac3 = string(n, 'c');
- string sad3(n, 'd');
- OS ad3 = sad3;
- OS ae3(in_place, n, 'e');
- OS af3;
- af3.emplace(n, 'f');
-
- aa3 = nullopt;
- ab3 = nullopt;
- ac3 = nullopt;
- ad3 = nullopt;
- ae3 = nullopt;
- af3 = nullopt;
-
- OS aa4;
- OS ab4 = nullopt;
- OS ac4 = string(n, 'c');
- string sad4(n, 'd');
- OS ad4 = sad4;
- OS ae4(in_place, n, 'e');
- OS af4;
- af4.emplace(n, 'f');
-
- aa4 = OS(in_place, n, 'a');
- ab4 = OS(in_place, n, 'b');
- ac4 = OS(in_place, n, 'c');
- ad4 = OS(in_place, n, 'd');
- ae4 = OS(in_place, n, 'e');
- af4 = OS(in_place, n, 'f');
-
- OS aa5;
- OS ab5 = nullopt;
- OS ac5 = string(n, 'c');
- string sad5(n, 'd');
- OS ad5 = sad5;
- OS ae5(in_place, n, 'e');
- OS af5;
- af5.emplace(n, 'f');
-
- string saa5(n, 'a');
- string sab5(n, 'a');
- string sac5(n, 'a');
- string sad52(n, 'a');
- string sae5(n, 'a');
- string saf5(n, 'a');
-
- aa5 = saa5;
- ab5 = sab5;
- ac5 = sac5;
- ad5 = sad52;
- ae5 = sae5;
- af5 = saf5;
-
- OS aa6;
- OS ab6 = nullopt;
- OS ac6 = string(n, 'c');
- string sad6(n, 'd');
- OS ad6 = sad6;
- OS ae6(in_place, n, 'e');
- OS af6;
- af6.emplace(n, 'f');
-
- aa6 = string(n, 'a');
- ab6 = string(n, 'b');
- ac6 = string(n, 'c');
- ad6 = string(n, 'd');
- ae6 = string(n, 'e');
- af6 = string(n, 'f');
-
- OS aa7;
- OS ab7 = nullopt;
- OS ac7 = string(n, 'c');
- string sad7(n, 'd');
- OS ad7 = sad7;
- OS ae7(in_place, n, 'e');
- OS af7;
- af7.emplace(n, 'f');
-
- aa7.emplace(n, 'A');
- ab7.emplace(n, 'B');
- ac7.emplace(n, 'C');
- ad7.emplace(n, 'D');
- ae7.emplace(n, 'E');
- af7.emplace(n, 'F');
-}
-
-TEST(optionalTest, MoveAssignRegression) {
- StructorListener listener;
- Listenable::listener = &listener;
-
- {
- optional<Listenable> a;
- Listenable b;
- a = std::move(b);
- }
-
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.move);
- EXPECT_EQ(2, listener.destruct);
-}
-
-TEST(optionalTest, ValueType) {
- EXPECT_TRUE((std::is_same<optional<int>::value_type, int>::value));
- EXPECT_TRUE((std::is_same<optional<string>::value_type, string>::value));
- EXPECT_FALSE((std::is_same<optional<int>::value_type, nullopt_t>::value));
-}
-
-TEST(optionalTest, Hash) {
- std::hash<optional<int>> hash;
- std::set<size_t> hashcodes;
- hashcodes.insert(hash(nullopt));
- for (int i = 0; i < 100; ++i) {
- hashcodes.insert(hash(i));
- }
- EXPECT_GT(hashcodes.size(), 90);
-}
-
-struct MoveMeNoThrow {
- MoveMeNoThrow() : x(0) {}
- MoveMeNoThrow(const MoveMeNoThrow& other) : x(other.x) {
- LOG(FATAL) << "Should not be called.";
- }
- MoveMeNoThrow(MoveMeNoThrow&& other) noexcept : x(other.x) {}
- int x;
-};
-
-struct MoveMeThrow {
- MoveMeThrow() : x(0) {}
- MoveMeThrow(const MoveMeThrow& other) : x(other.x) {}
- MoveMeThrow(MoveMeThrow&& other) : x(other.x) {}
- int x;
-};
-
-TEST(optionalTest, NoExcept) {
- static_assert(
- std::is_nothrow_move_constructible<optional<MoveMeNoThrow>>::value, "");
- static_assert(
- !std::is_nothrow_move_constructible<optional<MoveMeThrow>>::value, "");
- std::vector<optional<MoveMeNoThrow>> v;
- v.reserve(10);
- for (int i = 0; i < 10; ++i) v.emplace_back();
-}
-
-} // namespace
-} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/priority_queue_util.h b/tensorflow/core/lib/gtl/priority_queue_util.h
index 07311e3725..93bf3d3037 100644
--- a/tensorflow/core/lib/gtl/priority_queue_util.h
+++ b/tensorflow/core/lib/gtl/priority_queue_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_GTL_PRIORITY_QUEUE_UTIL_H_
-#define TENSORFLOW_LIB_GTL_PRIORITY_QUEUE_UTIL_H_
+#ifndef TENSORFLOW_CORE_LIB_GTL_PRIORITY_QUEUE_UTIL_H_
+#define TENSORFLOW_CORE_LIB_GTL_PRIORITY_QUEUE_UTIL_H_
#include <algorithm>
#include <queue>
@@ -52,4 +52,4 @@ T ConsumeTop(std::priority_queue<T, Container, Comparator>* q) {
} // namespace gtl
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_GTL_PRIORITY_QUEUE_UTIL_H_
+#endif // TENSORFLOW_CORE_LIB_GTL_PRIORITY_QUEUE_UTIL_H_
diff --git a/tensorflow/core/lib/hash/crc32c.h b/tensorflow/core/lib/hash/crc32c.h
index ee0bda93b1..2718cd31b3 100644
--- a/tensorflow/core/lib/hash/crc32c.h
+++ b/tensorflow/core/lib/hash/crc32c.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_HASH_CRC32C_H_
-#define TENSORFLOW_LIB_HASH_CRC32C_H_
+#ifndef TENSORFLOW_CORE_LIB_HASH_CRC32C_H_
+#define TENSORFLOW_CORE_LIB_HASH_CRC32C_H_
#include <stddef.h>
#include "tensorflow/core/platform/types.h"
@@ -51,4 +51,4 @@ inline uint32 Unmask(uint32 masked_crc) {
} // namespace crc32c
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_HASH_CRC32C_H_
+#endif // TENSORFLOW_CORE_LIB_HASH_CRC32C_H_
diff --git a/tensorflow/core/lib/hash/hash.h b/tensorflow/core/lib/hash/hash.h
index 737d23f699..675bab7191 100644
--- a/tensorflow/core/lib/hash/hash.h
+++ b/tensorflow/core/lib/hash/hash.h
@@ -15,8 +15,8 @@ limitations under the License.
// Simple hash functions used for internal data structures
-#ifndef TENSORFLOW_LIB_HASH_HASH_H_
-#define TENSORFLOW_LIB_HASH_HASH_H_
+#ifndef TENSORFLOW_CORE_LIB_HASH_HASH_H_
+#define TENSORFLOW_CORE_LIB_HASH_HASH_H_
#include <stddef.h>
#include <stdint.h>
@@ -110,4 +110,4 @@ struct hash<std::pair<T, U>> {
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_HASH_HASH_H_
+#endif // TENSORFLOW_CORE_LIB_HASH_HASH_H_
diff --git a/tensorflow/core/lib/histogram/histogram.h b/tensorflow/core/lib/histogram/histogram.h
index 65ce10786d..f882ee9abe 100644
--- a/tensorflow/core/lib/histogram/histogram.h
+++ b/tensorflow/core/lib/histogram/histogram.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_HISTOGRAM_HISTOGRAM_H_
-#define TENSORFLOW_LIB_HISTOGRAM_HISTOGRAM_H_
+#ifndef TENSORFLOW_CORE_LIB_HISTOGRAM_HISTOGRAM_H_
+#define TENSORFLOW_CORE_LIB_HISTOGRAM_HISTOGRAM_H_
#include <string>
#include <vector>
@@ -136,4 +136,4 @@ class ThreadSafeHistogram {
} // namespace histogram
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_HISTOGRAM_HISTOGRAM_H_
+#endif // TENSORFLOW_CORE_LIB_HISTOGRAM_HISTOGRAM_H_
diff --git a/tensorflow/core/lib/io/buffered_inputstream.h b/tensorflow/core/lib/io/buffered_inputstream.h
index 924619f40f..96a95b7ed9 100644
--- a/tensorflow/core/lib/io/buffered_inputstream.h
+++ b/tensorflow/core/lib/io/buffered_inputstream.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_IO_BUFFERED_INPUTSTREAM_H_
-#define TENSORFLOW_LIB_IO_BUFFERED_INPUTSTREAM_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_BUFFERED_INPUTSTREAM_H_
+#define TENSORFLOW_CORE_LIB_IO_BUFFERED_INPUTSTREAM_H_
#include "tensorflow/core/lib/io/inputstream_interface.h"
#include "tensorflow/core/platform/file_system.h"
@@ -104,4 +104,4 @@ class BufferedInputStream : public InputStreamInterface {
} // namespace io
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_IO_BUFFERED_INPUTSTREAM_H_
+#endif // TENSORFLOW_CORE_LIB_IO_BUFFERED_INPUTSTREAM_H_
diff --git a/tensorflow/core/lib/io/inputstream_interface.h b/tensorflow/core/lib/io/inputstream_interface.h
index 3083d20776..cbfc509d93 100644
--- a/tensorflow/core/lib/io/inputstream_interface.h
+++ b/tensorflow/core/lib/io/inputstream_interface.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_IO_INPUTSTREAM_INTERFACE_H_
-#define TENSORFLOW_LIB_IO_INPUTSTREAM_INTERFACE_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_INPUTSTREAM_INTERFACE_H_
+#define TENSORFLOW_CORE_LIB_IO_INPUTSTREAM_INTERFACE_H_
#include <string>
#include "tensorflow/core/lib/core/status.h"
diff --git a/tensorflow/core/lib/io/path.cc b/tensorflow/core/lib/io/path.cc
index b62206012c..b75dcecadf 100644
--- a/tensorflow/core/lib/io/path.cc
+++ b/tensorflow/core/lib/io/path.cc
@@ -42,7 +42,7 @@ string JoinPathImpl(std::initializer_list<StringPiece> paths) {
if (path.empty()) continue;
if (result.empty()) {
- result = std::string(path);
+ result = string(path);
continue;
}
@@ -124,7 +124,7 @@ StringPiece Extension(StringPiece path) {
}
string CleanPath(StringPiece unclean_path) {
- string path = std::string(unclean_path);
+ string path(unclean_path);
const char* src = path.c_str();
string::iterator dst = path.begin();
@@ -237,7 +237,7 @@ void ParseURI(StringPiece remaining, StringPiece* scheme, StringPiece* host,
string CreateURI(StringPiece scheme, StringPiece host, StringPiece path) {
if (scheme.empty()) {
- return std::string(path);
+ return string(path);
}
return strings::StrCat(scheme, "://", host, path);
}
diff --git a/tensorflow/core/lib/io/path.h b/tensorflow/core/lib/io/path.h
index 818ba99888..e3649fd0c9 100644
--- a/tensorflow/core/lib/io/path.h
+++ b/tensorflow/core/lib/io/path.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_IO_PATH_H_
-#define TENSORFLOW_LIB_IO_PATH_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_PATH_H_
+#define TENSORFLOW_CORE_LIB_IO_PATH_H_
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -94,4 +94,4 @@ string GetTempFilename(const string& extension);
} // namespace io
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_IO_PATH_H_
+#endif // TENSORFLOW_CORE_LIB_IO_PATH_H_
diff --git a/tensorflow/core/lib/io/path_test.cc b/tensorflow/core/lib/io/path_test.cc
index e3275b93b6..0090b9100c 100644
--- a/tensorflow/core/lib/io/path_test.cc
+++ b/tensorflow/core/lib/io/path_test.cc
@@ -104,9 +104,9 @@ TEST(PathTest, CleanPath) {
StringPiece u(uri); \
StringPiece s, h, p; \
ParseURI(u, &s, &h, &p); \
- EXPECT_EQ(scheme, s.ToString()); \
- EXPECT_EQ(host, h.ToString()); \
- EXPECT_EQ(path, p.ToString()); \
+ EXPECT_EQ(scheme, s); \
+ EXPECT_EQ(host, h); \
+ EXPECT_EQ(path, p); \
EXPECT_EQ(uri, CreateURI(scheme, host, path)); \
EXPECT_LE(u.begin(), s.begin()); \
EXPECT_GE(u.end(), s.begin()); \
diff --git a/tensorflow/core/lib/io/proto_encode_helper.h b/tensorflow/core/lib/io/proto_encode_helper.h
index f70e1cbaab..34905520f1 100644
--- a/tensorflow/core/lib/io/proto_encode_helper.h
+++ b/tensorflow/core/lib/io/proto_encode_helper.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_IO_PROTO_ENCODE_HELPER_H_
-#define TENSORFLOW_LIB_IO_PROTO_ENCODE_HELPER_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_PROTO_ENCODE_HELPER_H_
+#define TENSORFLOW_CORE_LIB_IO_PROTO_ENCODE_HELPER_H_
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -95,4 +95,4 @@ class ProtoEncodeHelper {
} // namespace io
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_IO_PROTO_ENCODE_HELPER_H_
+#endif // TENSORFLOW_CORE_LIB_IO_PROTO_ENCODE_HELPER_H_
diff --git a/tensorflow/core/lib/io/random_inputstream.h b/tensorflow/core/lib/io/random_inputstream.h
index bdbdbd71ff..c822fe50e9 100644
--- a/tensorflow/core/lib/io/random_inputstream.h
+++ b/tensorflow/core/lib/io/random_inputstream.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_IO_RANDOM_INPUTSTREAM_H_
-#define TENSORFLOW_LIB_IO_RANDOM_INPUTSTREAM_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_RANDOM_INPUTSTREAM_H_
+#define TENSORFLOW_CORE_LIB_IO_RANDOM_INPUTSTREAM_H_
#include "tensorflow/core/lib/io/inputstream_interface.h"
#include "tensorflow/core/platform/file_system.h"
@@ -54,4 +54,4 @@ class RandomAccessInputStream : public InputStreamInterface {
} // namespace io
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_IO_RANDOM_INPUTSTREAM_H_
+#endif // TENSORFLOW_CORE_LIB_IO_RANDOM_INPUTSTREAM_H_
diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc
index c24628be57..f93ebea771 100644
--- a/tensorflow/core/lib/io/record_reader.cc
+++ b/tensorflow/core/lib/io/record_reader.cc
@@ -109,9 +109,6 @@ Status RecordReader::ReadChecksummed(uint64 offset, size_t n, string* result) {
}
Status RecordReader::ReadRecord(uint64* offset, string* record) {
- static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
- static const size_t kFooterSize = sizeof(uint32);
-
// Position the input stream.
int64 curr_pos = input_stream_->Tell();
int64 desired_pos = static_cast<int64>(*offset);
diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h
index f6d587dfa0..11af1366b0 100644
--- a/tensorflow/core/lib/io/record_reader.h
+++ b/tensorflow/core/lib/io/record_reader.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_IO_RECORD_READER_H_
-#define TENSORFLOW_LIB_IO_RECORD_READER_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_RECORD_READER_H_
+#define TENSORFLOW_CORE_LIB_IO_RECORD_READER_H_
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -58,6 +58,14 @@ class RecordReaderOptions {
// Note: this class is not thread safe; external synchronization required.
class RecordReader {
public:
+ // Format of a single record:
+ // uint64 length
+ // uint32 masked crc of length
+ // byte data[length]
+ // uint32 masked crc of data
+ static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
+ static const size_t kFooterSize = sizeof(uint32);
+
// Create a reader that will return log records from "*file".
// "*file" must remain live while this Reader is in use.
explicit RecordReader(
@@ -122,4 +130,4 @@ class SequentialRecordReader {
} // namespace io
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_IO_RECORD_READER_H_
+#endif // TENSORFLOW_CORE_LIB_IO_RECORD_READER_H_
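
Making kHeaderSize/kFooterSize public members lets callers reason about a record's on-disk extent without re-deriving the framing constants; a minimal sketch:

#include <cstddef>
#include "tensorflow/core/lib/io/record_reader.h"

// Sketch: on-disk footprint of a record carrying an n-byte payload.
std::size_t RecordBytesOnDisk(std::size_t n) {
  return tensorflow::io::RecordReader::kHeaderSize + n +
         tensorflow::io::RecordReader::kFooterSize;
}
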
diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc
index 6e71d23e71..2c6db2487e 100644
--- a/tensorflow/core/lib/io/record_writer.cc
+++ b/tensorflow/core/lib/io/record_writer.cc
@@ -88,10 +88,6 @@ RecordWriter::~RecordWriter() {
}
}
-static uint32 MaskedCrc(const char* data, size_t n) {
- return crc32c::Mask(crc32c::Value(data, n));
-}
-
Status RecordWriter::WriteRecord(StringPiece data) {
if (dest_ == nullptr) {
return Status(::tensorflow::error::FAILED_PRECONDITION,
@@ -102,13 +98,10 @@ Status RecordWriter::WriteRecord(StringPiece data) {
// uint32 masked crc of length
// byte data[length]
// uint32 masked crc of data
- char header[sizeof(uint64) + sizeof(uint32)];
- core::EncodeFixed64(header + 0, data.size());
- core::EncodeFixed32(header + sizeof(uint64),
- MaskedCrc(header, sizeof(uint64)));
- char footer[sizeof(uint32)];
- core::EncodeFixed32(footer, MaskedCrc(data.data(), data.size()));
-
+ char header[kHeaderSize];
+ char footer[kFooterSize];
+ PopulateHeader(header, data.data(), data.size());
+ PopulateFooter(footer, data.data(), data.size());
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
TF_RETURN_IF_ERROR(dest_->Append(data));
return dest_->Append(StringPiece(footer, sizeof(footer)));
diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h
index daed809af3..1212e1fafb 100644
--- a/tensorflow/core/lib/io/record_writer.h
+++ b/tensorflow/core/lib/io/record_writer.h
@@ -13,11 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_IO_RECORD_WRITER_H_
-#define TENSORFLOW_LIB_IO_RECORD_WRITER_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_RECORD_WRITER_H_
+#define TENSORFLOW_CORE_LIB_IO_RECORD_WRITER_H_
+#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/hash/crc32c.h"
#if !defined(IS_SLIM_BUILD)
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
@@ -41,12 +43,20 @@ class RecordWriterOptions {
// Options specific to zlib compression.
#if !defined(IS_SLIM_BUILD)
- ZlibCompressionOptions zlib_options;
+ tensorflow::io::ZlibCompressionOptions zlib_options;
#endif // IS_SLIM_BUILD
};
class RecordWriter {
public:
+ // Format of a single record:
+ // uint64 length
+ // uint32 masked crc of length
+ // byte data[length]
+ // uint32 masked crc of data
+ static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
+ static const size_t kFooterSize = sizeof(uint32);
+
// Create a writer that will append data to "*dest".
// "*dest" must be initially empty.
// "*dest" must remain live while this Writer is in use.
@@ -72,14 +82,36 @@ class RecordWriter {
// are invalid.
Status Close();
+ // Utility method to populate TFRecord headers. Populates record-header in
+ // "header[0,kHeaderSize-1]". The record-header is based on data[0, n-1].
+ inline static void PopulateHeader(char* header, const char* data, size_t n);
+
+ // Utility method to populate TFRecord footers. Populates record-footer in
+ // "footer[0,kFooterSize-1]". The record-footer is based on data[0, n-1].
+ inline static void PopulateFooter(char* footer, const char* data, size_t n);
+
private:
WritableFile* dest_;
RecordWriterOptions options_;
+ inline static uint32 MaskedCrc(const char* data, size_t n) {
+ return crc32c::Mask(crc32c::Value(data, n));
+ }
+
TF_DISALLOW_COPY_AND_ASSIGN(RecordWriter);
};
+void RecordWriter::PopulateHeader(char* header, const char* data, size_t n) {
+ core::EncodeFixed64(header + 0, n);
+ core::EncodeFixed32(header + sizeof(uint64),
+ MaskedCrc(header, sizeof(uint64)));
+}
+
+void RecordWriter::PopulateFooter(char* footer, const char* data, size_t n) {
+ core::EncodeFixed32(footer, MaskedCrc(data, n));
+}
+
} // namespace io
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_IO_RECORD_WRITER_H_
+#endif // TENSORFLOW_CORE_LIB_IO_RECORD_WRITER_H_
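
With PopulateHeader/PopulateFooter now public statics, a record can be framed in user-owned memory exactly the way WriteRecord does it internally; a hedged sketch (FrameRecord is hypothetical, not part of the patch):

#include <cstring>
#include <string>
#include "tensorflow/core/lib/io/record_writer.h"

// Hedged sketch: frame one TFRecord (header | data | footer) in a buffer.
std::string FrameRecord(const std::string& data) {
  using tensorflow::io::RecordWriter;
  std::string out(RecordWriter::kHeaderSize + data.size() +
                      RecordWriter::kFooterSize,
                  '\0');
  RecordWriter::PopulateHeader(&out[0], data.data(), data.size());
  std::memcpy(&out[RecordWriter::kHeaderSize], data.data(), data.size());
  RecordWriter::PopulateFooter(&out[RecordWriter::kHeaderSize + data.size()],
                               data.data(), data.size());
  return out;
}
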
diff --git a/tensorflow/core/lib/io/table.h b/tensorflow/core/lib/io/table.h
index a1b78eae5b..b9c6b8d9d2 100644
--- a/tensorflow/core/lib/io/table.h
+++ b/tensorflow/core/lib/io/table.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_IO_TABLE_H_
-#define TENSORFLOW_LIB_IO_TABLE_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_TABLE_H_
+#define TENSORFLOW_CORE_LIB_IO_TABLE_H_
#include <stdint.h>
#include "tensorflow/core/lib/io/iterator.h"
@@ -84,4 +84,4 @@ class Table {
} // namespace table
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_IO_TABLE_H_
+#endif // TENSORFLOW_CORE_LIB_IO_TABLE_H_
diff --git a/tensorflow/core/lib/io/table_builder.h b/tensorflow/core/lib/io/table_builder.h
index 0202f90446..0e37e0a77f 100644
--- a/tensorflow/core/lib/io/table_builder.h
+++ b/tensorflow/core/lib/io/table_builder.h
@@ -21,8 +21,8 @@ limitations under the License.
// non-const method, all threads accessing the same TableBuilder must use
// external synchronization.
-#ifndef TENSORFLOW_LIB_IO_TABLE_BUILDER_H_
-#define TENSORFLOW_LIB_IO_TABLE_BUILDER_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_TABLE_BUILDER_H_
+#define TENSORFLOW_CORE_LIB_IO_TABLE_BUILDER_H_
#include <stdint.h>
#include "tensorflow/core/lib/core/status.h"
@@ -96,4 +96,4 @@ class TableBuilder {
} // namespace table
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_IO_TABLE_BUILDER_H_
+#endif // TENSORFLOW_CORE_LIB_IO_TABLE_BUILDER_H_
diff --git a/tensorflow/core/lib/io/table_options.h b/tensorflow/core/lib/io/table_options.h
index fd8a9d4a78..9a36bf1631 100644
--- a/tensorflow/core/lib/io/table_options.h
+++ b/tensorflow/core/lib/io/table_options.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_IO_TABLE_OPTIONS_H_
-#define TENSORFLOW_LIB_IO_TABLE_OPTIONS_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_TABLE_OPTIONS_H_
+#define TENSORFLOW_CORE_LIB_IO_TABLE_OPTIONS_H_
#include <stddef.h>
@@ -65,4 +65,4 @@ struct Options {
} // namespace table
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_IO_TABLE_OPTIONS_H_
+#endif // TENSORFLOW_CORE_LIB_IO_TABLE_OPTIONS_H_
diff --git a/tensorflow/core/lib/io/table_test.cc b/tensorflow/core/lib/io/table_test.cc
index 9e3309f0a7..877ac40f1c 100644
--- a/tensorflow/core/lib/io/table_test.cc
+++ b/tensorflow/core/lib/io/table_test.cc
@@ -147,7 +147,7 @@ class Constructor {
virtual ~Constructor() {}
void Add(const string& key, const StringPiece& value) {
- data_[key] = std::string(value);
+ data_[key] = string(value);
}
// Finish constructing the data structure with all the keys that have
@@ -188,7 +188,7 @@ class BlockConstructor : public Constructor {
builder.Add(it->first, it->second);
}
// Open the block
- data_ = std::string(builder.Finish());
+ data_ = string(builder.Finish());
BlockContents contents;
contents.data = data_;
contents.cachable = false;
@@ -515,7 +515,7 @@ TEST_F(Harness, Randomized) {
for (int e = 0; e < num_entries; e++) {
string v;
Add(test::RandomKey(&rnd, rnd.Skewed(4)),
- std::string(test::RandomString(&rnd, rnd.Skewed(5), &v)));
+ string(test::RandomString(&rnd, rnd.Skewed(5), &v)));
}
Test(&rnd);
}
diff --git a/tensorflow/core/lib/jpeg/jpeg_handle.h b/tensorflow/core/lib/jpeg/jpeg_handle.h
index 7d86be51da..86fa3ac5c2 100644
--- a/tensorflow/core/lib/jpeg/jpeg_handle.h
+++ b/tensorflow/core/lib/jpeg/jpeg_handle.h
@@ -16,8 +16,8 @@ limitations under the License.
// This file declares the functions and structures for memory I/O with libjpeg.
// These functions are not meant to be used directly; see jpeg_mem.h instead.
-#ifndef TENSORFLOW_LIB_JPEG_JPEG_HANDLE_H_
-#define TENSORFLOW_LIB_JPEG_JPEG_HANDLE_H_
+#ifndef TENSORFLOW_CORE_LIB_JPEG_JPEG_HANDLE_H_
+#define TENSORFLOW_CORE_LIB_JPEG_JPEG_HANDLE_H_
#include "tensorflow/core/platform/jpeg.h"
#include "tensorflow/core/platform/types.h"
@@ -57,4 +57,4 @@ void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize,
} // namespace jpeg
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_JPEG_JPEG_HANDLE_H_
+#endif // TENSORFLOW_CORE_LIB_JPEG_JPEG_HANDLE_H_
diff --git a/tensorflow/core/lib/jpeg/jpeg_mem.h b/tensorflow/core/lib/jpeg/jpeg_mem.h
index 59342d28c0..03437a4e78 100644
--- a/tensorflow/core/lib/jpeg/jpeg_mem.h
+++ b/tensorflow/core/lib/jpeg/jpeg_mem.h
@@ -18,8 +18,8 @@ limitations under the License.
// (data array and size fields).
// Direct manipulations of JPEG strings are supplied: Flip, Rotate, Crop...
-#ifndef TENSORFLOW_LIB_JPEG_JPEG_MEM_H_
-#define TENSORFLOW_LIB_JPEG_JPEG_MEM_H_
+#ifndef TENSORFLOW_CORE_LIB_JPEG_JPEG_MEM_H_
+#define TENSORFLOW_CORE_LIB_JPEG_JPEG_MEM_H_
#include <functional>
#include <string>
@@ -159,4 +159,4 @@ bool Compress(const void* srcdata, int width, int height,
} // namespace jpeg
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_JPEG_JPEG_MEM_H_
+#endif // TENSORFLOW_CORE_LIB_JPEG_JPEG_MEM_H_
diff --git a/tensorflow/core/lib/math/math_util.h b/tensorflow/core/lib/math/math_util.h
index 41d486f2bd..502d741512 100644
--- a/tensorflow/core/lib/math/math_util.h
+++ b/tensorflow/core/lib/math/math_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_MATH_MATH_UTIL_H_
-#define TENSORFLOW_LIB_MATH_MATH_UTIL_H_
+#ifndef TENSORFLOW_CORE_LIB_MATH_MATH_UTIL_H_
+#define TENSORFLOW_CORE_LIB_MATH_MATH_UTIL_H_
#include <type_traits>
@@ -160,4 +160,4 @@ T MathUtil::IPow(T base, int exp) {
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_MATH_MATH_UTIL_H_
+#endif // TENSORFLOW_CORE_LIB_MATH_MATH_UTIL_H_
diff --git a/tensorflow/core/lib/monitoring/collection_registry.cc b/tensorflow/core/lib/monitoring/collection_registry.cc
index 8c28620ff9..fface033cb 100644
--- a/tensorflow/core/lib/monitoring/collection_registry.cc
+++ b/tensorflow/core/lib/monitoring/collection_registry.cc
@@ -38,15 +38,15 @@ void Collector::CollectMetricDescriptor(
mutex_lock l(mu_);
return collected_metrics_->metric_descriptor_map
.insert(std::make_pair(
- std::string(metric_def->name()),
+ string(metric_def->name()),
std::unique_ptr<MetricDescriptor>(new MetricDescriptor())))
.first->second.get();
}();
- metric_descriptor->name = std::string(metric_def->name());
- metric_descriptor->description = std::string(metric_def->description());
+ metric_descriptor->name = string(metric_def->name());
+ metric_descriptor->description = string(metric_def->description());
for (const StringPiece label_name : metric_def->label_descriptions()) {
- metric_descriptor->label_names.push_back(std::string(label_name));
+ metric_descriptor->label_names.emplace_back(label_name);
}
metric_descriptor->metric_kind = metric_def->kind();
diff --git a/tensorflow/core/lib/monitoring/collection_registry.h b/tensorflow/core/lib/monitoring/collection_registry.h
index 20f0444f8b..c204d52cfe 100644
--- a/tensorflow/core/lib/monitoring/collection_registry.h
+++ b/tensorflow/core/lib/monitoring/collection_registry.h
@@ -72,7 +72,7 @@ class MetricCollector {
registration_time_millis_(registration_time_millis),
collector_(collector),
point_set_(point_set) {
- point_set_->metric_name = std::string(metric_def->name());
+ point_set_->metric_name = string(metric_def->name());
}
const MetricDef<metric_kind, Value, NumLabels>* const metric_def_;
@@ -261,7 +261,7 @@ class Collector {
auto* const point_set = [&]() {
mutex_lock l(mu_);
return collected_metrics_->point_set_map
- .insert(std::make_pair(std::string(metric_def->name()),
+ .insert(std::make_pair(string(metric_def->name()),
std::unique_ptr<PointSet>(new PointSet())))
.first->second.get();
}();
diff --git a/tensorflow/core/lib/monitoring/metric_def.h b/tensorflow/core/lib/monitoring/metric_def.h
index 6f94685665..756e5c2af8 100644
--- a/tensorflow/core/lib/monitoring/metric_def.h
+++ b/tensorflow/core/lib/monitoring/metric_def.h
@@ -98,8 +98,8 @@ class AbstractMetricDef {
const std::vector<string>& label_descriptions)
: kind_(kind),
value_type_(value_type),
- name_(std::string(name)),
- description_(std::string(description)),
+ name_(name),
+ description_(description),
label_descriptions_(std::vector<string>(label_descriptions.begin(),
label_descriptions.end())) {}
diff --git a/tensorflow/core/lib/random/distribution_sampler.h b/tensorflow/core/lib/random/distribution_sampler.h
index 25605d8ed4..7aa50ece03 100644
--- a/tensorflow/core/lib/random/distribution_sampler.h
+++ b/tensorflow/core/lib/random/distribution_sampler.h
@@ -28,8 +28,8 @@ limitations under the License.
//
// The algorithm used is Walker's Aliasing algorithm, described in Knuth, Vol 2.
-#ifndef TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
-#define TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
+#ifndef TENSORFLOW_CORE_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
+#define TENSORFLOW_CORE_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
#include <memory>
#include <utility>
@@ -91,4 +91,4 @@ class DistributionSampler {
} // namespace random
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
+#endif // TENSORFLOW_CORE_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
diff --git a/tensorflow/core/lib/random/philox_random.h b/tensorflow/core/lib/random/philox_random.h
index b2adb4462b..058ed95ffb 100644
--- a/tensorflow/core/lib/random/philox_random.h
+++ b/tensorflow/core/lib/random/philox_random.h
@@ -17,8 +17,8 @@ limitations under the License.
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
-#ifndef TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_H_
-#define TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_H_
+#ifndef TENSORFLOW_CORE_LIB_RANDOM_PHILOX_RANDOM_H_
+#define TENSORFLOW_CORE_LIB_RANDOM_PHILOX_RANDOM_H_
#include <stdlib.h>
@@ -248,4 +248,4 @@ class PhiloxRandom {
} // namespace random
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_H_
+#endif // TENSORFLOW_CORE_LIB_RANDOM_PHILOX_RANDOM_H_
diff --git a/tensorflow/core/lib/random/random_distributions.h b/tensorflow/core/lib/random/random_distributions.h
index e963511f5c..c3801a0412 100644
--- a/tensorflow/core/lib/random/random_distributions.h
+++ b/tensorflow/core/lib/random/random_distributions.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
-#define TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
+#ifndef TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
+#define TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
#define _USE_MATH_DEFINES
#include <math.h>
@@ -744,4 +744,4 @@ PHILOX_DEVICE_INLINE double Uint64ToDouble(uint32 x0, uint32 x1) {
} // namespace random
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
+#endif // TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
diff --git a/tensorflow/core/lib/random/simple_philox.h b/tensorflow/core/lib/random/simple_philox.h
index d529e08913..6464036856 100644
--- a/tensorflow/core/lib/random/simple_philox.h
+++ b/tensorflow/core/lib/random/simple_philox.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_RANDOM_SIMPLE_PHILOX_H_
-#define TENSORFLOW_LIB_RANDOM_SIMPLE_PHILOX_H_
+#ifndef TENSORFLOW_CORE_LIB_RANDOM_SIMPLE_PHILOX_H_
+#define TENSORFLOW_CORE_LIB_RANDOM_SIMPLE_PHILOX_H_
#include <math.h>
#include <string.h>
@@ -73,4 +73,4 @@ class SimplePhilox {
} // namespace random
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_RANDOM_SIMPLE_PHILOX_H_
+#endif // TENSORFLOW_CORE_LIB_RANDOM_SIMPLE_PHILOX_H_
diff --git a/tensorflow/core/lib/strings/numbers.h b/tensorflow/core/lib/strings/numbers.h
index 1d5bacac93..959290ba8c 100644
--- a/tensorflow/core/lib/strings/numbers.h
+++ b/tensorflow/core/lib/strings/numbers.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_STRINGS_NUMBERS_H_
-#define TENSORFLOW_LIB_STRINGS_NUMBERS_H_
+#ifndef TENSORFLOW_CORE_LIB_STRINGS_NUMBERS_H_
+#define TENSORFLOW_CORE_LIB_STRINGS_NUMBERS_H_
#include <string>
@@ -140,11 +140,11 @@ inline bool ProtoParseNumeric(StringPiece s, uint64* value) {
}
inline bool ProtoParseNumeric(StringPiece s, float* value) {
- return safe_strtof(std::string(s).c_str(), value);
+ return safe_strtof(s, value);
}
inline bool ProtoParseNumeric(StringPiece s, double* value) {
- return safe_strtod(std::string(s).c_str(), value);
+ return safe_strtod(s, value);
}
// Convert strings to number of type T.
@@ -176,4 +176,4 @@ string HumanReadableElapsedTime(double seconds);
} // namespace strings
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_STRINGS_NUMBERS_H_
+#endif // TENSORFLOW_CORE_LIB_STRINGS_NUMBERS_H_
diff --git a/tensorflow/core/lib/strings/str_util.cc b/tensorflow/core/lib/strings/str_util.cc
index cab8f81585..3aba5ec80e 100644
--- a/tensorflow/core/lib/strings/str_util.cc
+++ b/tensorflow/core/lib/strings/str_util.cc
@@ -332,7 +332,7 @@ string StringReplace(StringPiece s, StringPiece oldsub, StringPiece newsub,
bool replace_all) {
// TODO(jlebar): We could avoid having to shift data around in the string if
// we had a StringPiece::find() overload that searched for a StringPiece.
- string res = std::string(s);
+ string res(s);
size_t pos = 0;
while ((pos = res.find(oldsub.data(), pos, oldsub.size())) != string::npos) {
res.replace(pos, oldsub.size(), newsub.data(), newsub.size());
@@ -448,8 +448,7 @@ bool SplitAndParseAsFloats(StringPiece text, char delim,
std::vector<float>* result) {
return SplitAndParseAsInts<float>(text, delim,
[](StringPiece str, float* value) {
- return strings::safe_strtof(
- std::string(str).c_str(), value);
+ return strings::safe_strtof(str, value);
},
result);
}
diff --git a/tensorflow/core/lib/strings/str_util.h b/tensorflow/core/lib/strings/str_util.h
index c887db7eff..9f52cf29fc 100644
--- a/tensorflow/core/lib/strings/str_util.h
+++ b/tensorflow/core/lib/strings/str_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_STRINGS_STR_UTIL_H_
-#define TENSORFLOW_LIB_STRINGS_STR_UTIL_H_
+#ifndef TENSORFLOW_CORE_LIB_STRINGS_STR_UTIL_H_
+#define TENSORFLOW_CORE_LIB_STRINGS_STR_UTIL_H_
#include <functional>
#include <string>
@@ -205,7 +205,7 @@ std::vector<string> Split(StringPiece text, StringPiece delims, Predicate p) {
if ((i == text.size()) || (delims.find(text[i]) != StringPiece::npos)) {
StringPiece token(text.data() + token_start, i - token_start);
if (p(token)) {
- result.push_back(std::string(token));
+ result.emplace_back(token);
}
token_start = i + 1;
}
@@ -231,4 +231,4 @@ size_t Strnlen(const char* str, const size_t string_max_len);
} // namespace str_util
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_STRINGS_STR_UTIL_H_
+#endif // TENSORFLOW_CORE_LIB_STRINGS_STR_UTIL_H_
diff --git a/tensorflow/core/lib/strings/strcat.h b/tensorflow/core/lib/strings/strcat.h
index fb2cd5bc7e..a620f59447 100644
--- a/tensorflow/core/lib/strings/strcat.h
+++ b/tensorflow/core/lib/strings/strcat.h
@@ -17,8 +17,8 @@ limitations under the License.
// #category: operations on strings
// #summary: Merges strings or numbers with no delimiter.
//
-#ifndef TENSORFLOW_LIB_STRINGS_STRCAT_H_
-#define TENSORFLOW_LIB_STRINGS_STRCAT_H_
+#ifndef TENSORFLOW_CORE_LIB_STRINGS_STRCAT_H_
+#define TENSORFLOW_CORE_LIB_STRINGS_STRCAT_H_
#include <string>
@@ -59,29 +59,29 @@ namespace tensorflow {
namespace strings {
enum PadSpec {
- NO_PAD = 1,
- ZERO_PAD_2,
- ZERO_PAD_3,
- ZERO_PAD_4,
- ZERO_PAD_5,
- ZERO_PAD_6,
- ZERO_PAD_7,
- ZERO_PAD_8,
- ZERO_PAD_9,
- ZERO_PAD_10,
- ZERO_PAD_11,
- ZERO_PAD_12,
- ZERO_PAD_13,
- ZERO_PAD_14,
- ZERO_PAD_15,
- ZERO_PAD_16,
+ kNoPad = 1,
+ kZeroPad2,
+ kZeroPad3,
+ kZeroPad4,
+ kZeroPad5,
+ kZeroPad6,
+ kZeroPad7,
+ kZeroPad8,
+ kZeroPad9,
+ kZeroPad10,
+ kZeroPad11,
+ kZeroPad12,
+ kZeroPad13,
+ kZeroPad14,
+ kZeroPad15,
+ kZeroPad16
};
struct Hex {
uint64 value;
enum PadSpec spec;
template <class Int>
- explicit Hex(Int v, PadSpec s = NO_PAD) : spec(s) {
+ explicit Hex(Int v, PadSpec s = kNoPad) : spec(s) {
// Prevent sign-extension by casting integers to
// their unsigned counterparts.
static_assert(
@@ -124,6 +124,9 @@ class AlphaNum {
AlphaNum(const StringPiece &pc) : piece_(pc) {} // NOLINT(runtime/explicit)
AlphaNum(const tensorflow::string &str) // NOLINT(runtime/explicit)
: piece_(str) {}
+ template <typename A>
+ AlphaNum(const std::basic_string<char, std::char_traits<char>, A> &str)
+ : piece_(str) {} // NOLINT(runtime/explicit)
StringPiece::size_type size() const { return piece_.size(); }
const char *data() const { return piece_.data(); }
@@ -233,4 +236,4 @@ inline void StrAppend(string *dest, const AlphaNum &a, const AlphaNum &b,
} // namespace strings
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_STRINGS_STRCAT_H_
+#endif // TENSORFLOW_CORE_LIB_STRINGS_STRCAT_H_
diff --git a/tensorflow/core/lib/strings/strcat_test.cc b/tensorflow/core/lib/strings/strcat_test.cc
index 8cc64a6f0a..6c4e5526b1 100644
--- a/tensorflow/core/lib/strings/strcat_test.cc
+++ b/tensorflow/core/lib/strings/strcat_test.cc
@@ -308,11 +308,11 @@ TEST(StrAppend, Death) {
static void CheckHex64(uint64 v) {
using tensorflow::strings::Hex;
- string actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_16));
+ string actual = StrCat(Hex(v, tensorflow::strings::kZeroPad16));
string expected = Printf("%016llx", static_cast<unsigned long long>(v));
EXPECT_EQ(expected, actual) << " decimal value " << v;
- actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_8));
+ actual = StrCat(Hex(v, tensorflow::strings::kZeroPad8));
expected = Printf("%08llx", static_cast<unsigned long long>(v));
EXPECT_EQ(expected, actual) << " decimal value " << v;
@@ -323,7 +323,7 @@ static void CheckHex64(uint64 v) {
static void CheckHex32(uint32 v) {
using tensorflow::strings::Hex;
- string actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_8));
+ string actual = StrCat(Hex(v, tensorflow::strings::kZeroPad8));
string expected = Printf("%08x", v);
EXPECT_EQ(expected, actual) << " decimal value " << v;
@@ -334,7 +334,7 @@ static void CheckHex32(uint32 v) {
static void CheckHexSigned32(int32 v) {
using tensorflow::strings::Hex;
- string actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_8));
+ string actual = StrCat(Hex(v, tensorflow::strings::kZeroPad8));
string expected = Printf("%08x", v);
EXPECT_EQ(expected, actual) << " decimal value " << v;
diff --git a/tensorflow/core/lib/strings/stringprintf.h b/tensorflow/core/lib/strings/stringprintf.h
index f7957252ea..52af410d42 100644
--- a/tensorflow/core/lib/strings/stringprintf.h
+++ b/tensorflow/core/lib/strings/stringprintf.h
@@ -20,8 +20,8 @@ limitations under the License.
// strings::SPrintf(&result, "%d %s\n", 10, "hello");
// strings::Appendf(&result, "%d %s\n", 20, "there");
-#ifndef TENSORFLOW_LIB_STRINGS_STRINGPRINTF_H_
-#define TENSORFLOW_LIB_STRINGS_STRINGPRINTF_H_
+#ifndef TENSORFLOW_CORE_LIB_STRINGS_STRINGPRINTF_H_
+#define TENSORFLOW_CORE_LIB_STRINGS_STRINGPRINTF_H_
#include <stdarg.h>
#include <string>
@@ -49,4 +49,4 @@ extern void Appendv(string* dst, const char* format, va_list ap);
} // namespace strings
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_STRINGS_STRINGPRINTF_H_
+#endif // TENSORFLOW_CORE_LIB_STRINGS_STRINGPRINTF_H_
diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc
index 1f2e57e9a9..3d03bc1d5f 100644
--- a/tensorflow/core/ops/array_grad.cc
+++ b/tensorflow/core/ops/array_grad.cc
@@ -354,6 +354,27 @@ Status TransposeGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("Transpose", TransposeGrad);
+Status GatherNdGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ *g = FDH::Define(
+ // Arg defs
+ {"params: Tparams", "indices: Tindices", "doutput: Tparams"},
+ // Ret val defs
+ {"dparams: Tparams", "dindices: Tindices"},
+ // Attr defs
+ {"Tparams: type", "Tindices: type"},
+ // Nodes
+ {
+ {{"x_shape"}, "Shape", {"params"}, {{"T", "$Tparams"}}},
+ {{"dparams"}, "ScatterNd", {"indices", "doutput", "x_shape"},
+ {{"T", "$Tparams"}, {"Tindices", "$Tindices"}}},
+ {{"dindices"}, "ZerosLike", {"indices"}, {{"T", "$Tindices"}}},
+ });
+ // clang-format on
+ return Status::OK();
+}
+REGISTER_OP_GRADIENT("GatherNd", GatherNdGrad);
+
Status ConjugateTransposeGrad(const AttrSlice& attrs, FunctionDef* g) {
*g = FDH::Define(
// Arg defs
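
The registered gradient is the gather/scatter duality: each output cotangent flows back to the row it was read from, and integer indices get zeros. A toy, TF-free model of the same computation for rank-1 indices:

#include <cstddef>
#include <vector>

// Toy model (not TF code) of GatherNdGrad above:
// forward:  output[i] = params[indices[i]]
// backward: dparams = ScatterNd(indices, doutput, shape(params)); dindices = 0.
std::vector<float> GatherGradToy(const std::vector<float>& params,
                                 const std::vector<std::size_t>& indices,
                                 const std::vector<float>& doutput) {
  std::vector<float> dparams(params.size(), 0.f);  // ZerosLike(params)
  for (std::size_t i = 0; i < indices.size(); ++i)
    dparams[indices[i]] += doutput[i];  // scatter-add the cotangents
  return dparams;
}
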
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 1d11ec00ce..7dbb18aa5d 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -1446,6 +1446,30 @@ REGISTER_OP("ShapeN")
.Attr("out_type: {int32, int64} = DT_INT32")
.SetShapeFn(ShapeShapeFn);
+REGISTER_OP("EnsureShape")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("shape: shape")
+ .Attr("T: type")
+ .SetShapeFn([](InferenceContext* c) {
+ // Merges desired shape and statically known shape of input
+ PartialTensorShape desired_shape;
+ TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape));
+
+ int rank = desired_shape.dims();
+ ShapeHandle input_shape_handle;
+ ShapeHandle desired_shape_handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape_handle));
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ desired_shape, &desired_shape_handle));
+
+ ShapeHandle merged_shape;
+ TF_RETURN_IF_ERROR(
+ c->Merge(desired_shape_handle, input_shape_handle, &merged_shape));
+ c->set_output(0, merged_shape);
+ return Status::OK();
+ });
+
// --------------------------------------------------------------------------
REGISTER_OP("ReverseSequence")
.Input("input: T")
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index c15409a246..03dab390a7 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -1620,6 +1620,24 @@ TEST(ArrayOpsTest, Slice_ShapeFn) {
INFER_ERROR("cannot be < -1", op, "[2,3,4,5];[4];[4]");
}
+TEST(ArrayOpsTest, StridedSlice_ShapeFn) {
+ ShapeInferenceTestOp op("StridedSlice");
+ TF_ASSERT_OK(NodeDefBuilder("test", "StridedSlice")
+ .Input("input", 0, DT_FLOAT)
+ .Input("begin", 1, DT_INT32)
+ .Input("end", 2, DT_INT32)
+ .Input("strides", 3, DT_INT32)
+ .Attr("shrink_axis_mask", 1)
+ .Finalize(&op.node_def));
+ op.input_tensors.resize(4);
+ Tensor strides = test::AsTensor<int32>({1});
+ op.input_tensors[3] = &strides;
+ // Slicing on the 0-th dimension.
+ INFER_OK(op, "[2,3,4,5];[1];[1];[1]", "[3,4,5]");
+ // Slicing on the 0-th dimension. This time some of the result dimensions are 0.
+ INFER_OK(op, "[2,0,3,4];[1];[1];[1]", "[0,3,4]");
+}
+
TEST(ArrayOpsTest, StridedSliceGrad_ShapeFn) {
ShapeInferenceTestOp op("StridedSliceGrad");
op.input_tensors.resize(5);
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 72b9477f28..34e6b5560b 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -13070,6 +13070,71 @@ op {
is_stateful: true
}
op {
+ name: "ConditionalAccumulator"
+ output_arg {
+ name: "handle"
+ type: DT_STRING
+ is_ref: true
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "reduction_type"
+ type: "string"
+ default_value {
+ s: "MEAN"
+ }
+ allowed_values {
+ list {
+ s: "MEAN"
+ s: "SUM"
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "Conj"
input_arg {
name: "input"
@@ -20317,6 +20382,31 @@ op {
}
}
op {
+ name: "DivNoNan"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "DrawBoundingBoxes"
input_arg {
name: "images"
@@ -20865,6 +20955,25 @@ op {
is_stateful: true
}
op {
+ name: "EnsureShape"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+}
+op {
name: "Enter"
input_arg {
name: "data"
@@ -22461,33 +22570,6 @@ op {
is_stateful: true
}
op {
- name: "FeatureStatsDataset"
- input_arg {
- name: "input_dataset"
- type: DT_VARIANT
- }
- input_arg {
- name: "tag"
- type: DT_STRING
- }
- output_arg {
- name: "handle"
- type: DT_VARIANT
- }
- attr {
- name: "output_types"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "output_shapes"
- type: "list(shape)"
- has_minimum: true
- minimum: 1
- }
-}
-op {
name: "Fill"
input_arg {
name: "dims"
@@ -29364,6 +29446,49 @@ op {
}
}
op {
+ name: "MapDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "MapDefun"
input_arg {
name: "arguments"
@@ -29991,6 +30116,32 @@ op {
}
}
op {
+ name: "MatrixExponential"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_DOUBLE
+ type: DT_FLOAT
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ deprecation {
+ version: 27
+ }
+}
+op {
name: "MatrixInverse"
input_arg {
name: "input"
@@ -35639,6 +35790,42 @@ op {
}
}
op {
+ name: "NonMaxSuppressionV2"
+ input_arg {
+ name: "boxes"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scores"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "max_output_size"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "iou_threshold"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "selected_indices"
+ type: DT_INT32
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
+}
+op {
name: "NonMaxSuppressionV3"
input_arg {
name: "boxes"
@@ -35666,6 +35853,46 @@ op {
}
}
op {
+ name: "NonMaxSuppressionV3"
+ input_arg {
+ name: "boxes"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scores"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "max_output_size"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "iou_threshold"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "score_threshold"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "selected_indices"
+ type: DT_INT32
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
+}
+op {
name: "NonMaxSuppressionV4"
input_arg {
name: "boxes"
@@ -35704,6 +35931,57 @@ op {
}
}
op {
+ name: "NonMaxSuppressionV4"
+ input_arg {
+ name: "boxes"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scores"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "max_output_size"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "iou_threshold"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "score_threshold"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "selected_indices"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "valid_outputs"
+ type: DT_INT32
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "pad_to_max_output_size"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+}
+op {
name: "NonMaxSuppressionWithOverlaps"
input_arg {
name: "overlaps"
@@ -36994,6 +37272,54 @@ op {
}
}
op {
+ name: "ParallelInterleaveDatasetV2"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "cycle_length"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "block_length"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "num_parallel_calls"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "ParallelMapDataset"
input_arg {
name: "input_dataset"
@@ -37075,6 +37401,53 @@ op {
}
}
op {
+ name: "ParallelMapDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "num_parallel_calls"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "ParameterizedTruncatedNormal"
input_arg {
name: "shape"
@@ -37284,6 +37657,271 @@ op {
}
}
op {
+ name: "ParseExampleDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "num_parallel_calls"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "dense_defaults"
+ type_list_attr: "Tdense"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "sparse_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "dense_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "sparse_types"
+ type: "list(type)"
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "Tdense"
+ type: "list(type)"
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "dense_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ParseSequenceExample"
+ input_arg {
+ name: "serialized"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "debug_name"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "context_dense_defaults"
+ type_list_attr: "Tcontext_dense"
+ }
+ output_arg {
+ name: "context_sparse_indices"
+ type: DT_INT64
+ number_attr: "Ncontext_sparse"
+ }
+ output_arg {
+ name: "context_sparse_values"
+ type_list_attr: "context_sparse_types"
+ }
+ output_arg {
+ name: "context_sparse_shapes"
+ type: DT_INT64
+ number_attr: "Ncontext_sparse"
+ }
+ output_arg {
+ name: "context_dense_values"
+ type_list_attr: "Tcontext_dense"
+ }
+ output_arg {
+ name: "feature_list_sparse_indices"
+ type: DT_INT64
+ number_attr: "Nfeature_list_sparse"
+ }
+ output_arg {
+ name: "feature_list_sparse_values"
+ type_list_attr: "feature_list_sparse_types"
+ }
+ output_arg {
+ name: "feature_list_sparse_shapes"
+ type: DT_INT64
+ number_attr: "Nfeature_list_sparse"
+ }
+ output_arg {
+ name: "feature_list_dense_values"
+ type_list_attr: "feature_list_dense_types"
+ }
+ output_arg {
+ name: "feature_list_dense_lengths"
+ type: DT_INT64
+ number_attr: "Nfeature_list_dense"
+ }
+ attr {
+ name: "feature_list_dense_missing_assumed_empty"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "context_sparse_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "context_dense_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "feature_list_sparse_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "feature_list_dense_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "Ncontext_sparse"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "Ncontext_dense"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "Nfeature_list_sparse"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "Nfeature_list_dense"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "context_sparse_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "Tcontext_dense"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "feature_list_dense_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "context_dense_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "feature_list_sparse_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "feature_list_dense_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+}
+op {
name: "ParseSingleExample"
input_arg {
name: "serialized"
@@ -43819,6 +44457,38 @@ op {
}
}
op {
+ name: "Relu"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "activations"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ type: DT_QINT8
+ }
+ }
+ }
+}
+op {
name: "Relu6"
input_arg {
name: "features"
@@ -56325,6 +56995,125 @@ op {
}
}
op {
+ name: "SdcaOptimizer"
+ input_arg {
+ name: "sparse_example_indices"
+ type: DT_INT64
+ number_attr: "num_sparse_features"
+ }
+ input_arg {
+ name: "sparse_feature_indices"
+ type: DT_INT64
+ number_attr: "num_sparse_features"
+ }
+ input_arg {
+ name: "sparse_feature_values"
+ type: DT_FLOAT
+ number_attr: "num_sparse_features_with_values"
+ }
+ input_arg {
+ name: "dense_features"
+ type: DT_FLOAT
+ number_attr: "num_dense_features"
+ }
+ input_arg {
+ name: "example_weights"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "example_labels"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "sparse_indices"
+ type: DT_INT64
+ number_attr: "num_sparse_features"
+ }
+ input_arg {
+ name: "sparse_weights"
+ type: DT_FLOAT
+ number_attr: "num_sparse_features"
+ }
+ input_arg {
+ name: "dense_weights"
+ type: DT_FLOAT
+ number_attr: "num_dense_features"
+ }
+ input_arg {
+ name: "example_state_data"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "out_example_state_data"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "out_delta_sparse_weights"
+ type: DT_FLOAT
+ number_attr: "num_sparse_features"
+ }
+ output_arg {
+ name: "out_delta_dense_weights"
+ type: DT_FLOAT
+ number_attr: "num_dense_features"
+ }
+ attr {
+ name: "loss_type"
+ type: "string"
+ allowed_values {
+ list {
+ s: "logistic_loss"
+ s: "squared_loss"
+ s: "hinge_loss"
+ s: "smooth_hinge_loss"
+ s: "poisson_loss"
+ }
+ }
+ }
+ attr {
+ name: "adaptative"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "num_sparse_features"
+ type: "int"
+ has_minimum: true
+ }
+ attr {
+ name: "num_sparse_features_with_values"
+ type: "int"
+ has_minimum: true
+ }
+ attr {
+ name: "num_dense_features"
+ type: "int"
+ has_minimum: true
+ }
+ attr {
+ name: "l1"
+ type: "float"
+ }
+ attr {
+ name: "l2"
+ type: "float"
+ }
+ attr {
+ name: "num_loss_partitions"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "num_inner_iterations"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "SdcaShrinkL1"
input_arg {
name: "weights"
@@ -64041,6 +64830,71 @@ op {
is_stateful: true
}
op {
+ name: "SparseConditionalAccumulator"
+ output_arg {
+ name: "handle"
+ type: DT_STRING
+ is_ref: true
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "reduction_type"
+ type: "string"
+ default_value {
+ s: "MEAN"
+ }
+ allowed_values {
+ list {
+ s: "MEAN"
+ s: "SUM"
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "SparseCross"
input_arg {
name: "indices"
@@ -68834,6 +69688,47 @@ op {
}
}
op {
+ name: "StaticRegexFullMatch"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_BOOL
+ }
+ attr {
+ name: "pattern"
+ type: "string"
+ }
+}
+op {
+ name: "StaticRegexReplace"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "pattern"
+ type: "string"
+ }
+ attr {
+ name: "rewrite"
+ type: "string"
+ }
+ attr {
+ name: "replace_global"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "StatsAggregatorHandle"
output_arg {
name: "handle"
@@ -71669,6 +72564,25 @@ op {
}
}
op {
+ name: "TensorListGather"
+ input_arg {
+ name: "input_handle"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "indices"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "values"
+ type_attr: "element_dtype"
+ }
+ attr {
+ name: "element_dtype"
+ type: "type"
+ }
+}
+op {
name: "TensorListGetItem"
input_arg {
name: "input_handle"
@@ -71785,6 +72699,39 @@ op {
}
}
op {
+ name: "TensorListScatter"
+ input_arg {
+ name: "tensor"
+ type_attr: "element_dtype"
+ }
+ input_arg {
+ name: "indices"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "element_shape"
+ type_attr: "shape_type"
+ }
+ output_arg {
+ name: "output_handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "element_dtype"
+ type: "type"
+ }
+ attr {
+ name: "shape_type"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "TensorListSetItem"
input_arg {
name: "input_handle"
@@ -73417,41 +74364,6 @@ op {
}
}
op {
- name: "UnsafeDiv"
- input_arg {
- name: "x"
- type_attr: "T"
- }
- input_arg {
- name: "y"
- type_attr: "T"
- }
- output_arg {
- name: "z"
- type_attr: "T"
- }
- attr {
- name: "T"
- type: "type"
- allowed_values {
- list {
- type: DT_BFLOAT16
- type: DT_HALF
- type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_UINT8
- type: DT_INT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT32
- type: DT_INT64
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- }
- }
- }
-}
-op {
name: "UnsortedSegmentMax"
input_arg {
name: "data"
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index eed0bce174..ffab8ad661 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -419,6 +419,7 @@ REGISTER_OP("ConditionalAccumulator")
.Attr("shape: shape")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
+ .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Vector(2));
@@ -456,6 +457,7 @@ REGISTER_OP("SparseConditionalAccumulator")
.Attr("shape: shape")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
+ .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Vector(2));
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 13733d48f0..9d2b3af51d 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -166,17 +166,21 @@ REGISTER_OP("LatencyStatsDataset")
return shape_inference::ScalarShape(c);
});
-REGISTER_OP("FeatureStatsDataset")
+REGISTER_OP("ParseExampleDataset")
.Input("input_dataset: variant")
- .Input("tag: string")
+ .Input("num_parallel_calls: int64")
+ .Input("dense_defaults: Tdense")
.Output("handle: variant")
+ .Attr("sparse_keys: list(string) >= 0")
+ .Attr("dense_keys: list(string) >= 0")
+ .Attr("sparse_types: list({float,int64,string}) >= 0")
+ .Attr("Tdense: list({float,int64,string}) >= 0")
+ .Attr("dense_shapes: list(shape) >= 0")
.Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle tag_shape;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
- return shape_inference::ScalarShape(c);
- });
+ .Attr("output_shapes: list(shape) >= 1") // Output components will be
+ // sorted by key (dense_keys and
+ // sparse_keys combined) here.
+ .SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("SetStatsAggregatorDataset")
.Input("input_dataset: variant")
@@ -194,6 +198,7 @@ REGISTER_OP("MapDataset")
.Attr("Targuments: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
+ .Attr("use_inter_op_parallelism: bool = true")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ParallelMapDataset")
@@ -205,6 +210,7 @@ REGISTER_OP("ParallelMapDataset")
.Attr("Targuments: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
+ .Attr("use_inter_op_parallelism: bool = true")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("MapAndBatchDataset")
@@ -321,6 +327,19 @@ REGISTER_OP("ParallelInterleaveDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("ParallelInterleaveDatasetV2")
+ .Input("input_dataset: variant")
+ .Input("other_arguments: Targuments")
+ .Input("cycle_length: int64")
+ .Input("block_length: int64")
+ .Input("num_parallel_calls: int64")
+ .Output("handle: variant")
+ .Attr("f: func")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("GroupByReducerDataset")
.Input("input_dataset: variant")
.Input("key_func_other_arguments: Tkey_func_other_arguments")
@@ -862,7 +881,7 @@ REGISTER_OP("MapDefun")
.Attr("output_shapes: list(shape) >= 1")
.Attr("f: func")
.SetShapeFn([](shape_inference::InferenceContext* c) {
- std::vector<TensorShape> output_shapes;
+ std::vector<PartialTensorShape> output_shapes;
TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
if (output_shapes.size() != c->num_outputs()) {
return errors::InvalidArgument(
@@ -872,6 +891,10 @@ REGISTER_OP("MapDefun")
int64 dim_zero = -1;
for (size_t i = 0; i < static_cast<size_t>(c->num_inputs()); ++i) {
+ if (c->Rank(c->input(i)) == 0) {
+ return errors::InvalidArgument(
+ "Inputs must have rank at least 1. Input ", i, " has rank of 0");
+ }
auto dim_handle = c->Dim(c->input(i), 0);
if (c->ValueKnown(dim_handle)) {
if (dim_zero == -1) {
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 11ca0bd259..5427275284 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -683,11 +683,12 @@ REGISTER_OP("NonMaxSuppression")
});
REGISTER_OP("NonMaxSuppressionV2")
- .Input("boxes: float")
- .Input("scores: float")
+ .Input("boxes: T")
+ .Input("scores: T")
.Input("max_output_size: int32")
.Input("iou_threshold: float")
.Output("selected_indices: int32")
+ .Attr("T: {half, float} = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
// Get inputs and validate ranks.
ShapeHandle boxes;
@@ -711,22 +712,24 @@ REGISTER_OP("NonMaxSuppressionV2")
});
REGISTER_OP("NonMaxSuppressionV3")
- .Input("boxes: float")
- .Input("scores: float")
+ .Input("boxes: T")
+ .Input("scores: T")
.Input("max_output_size: int32")
.Input("iou_threshold: float")
.Input("score_threshold: float")
.Output("selected_indices: int32")
+ .Attr("T: {half, float} = DT_FLOAT")
.SetShapeFn(NMSShapeFn);
REGISTER_OP("NonMaxSuppressionV4")
- .Input("boxes: float")
- .Input("scores: float")
+ .Input("boxes: T")
+ .Input("scores: T")
.Input("max_output_size: int32")
.Input("iou_threshold: float")
.Input("score_threshold: float")
.Output("selected_indices: int32")
.Output("valid_outputs: int32")
+ .Attr("T: {half, float} = DT_FLOAT")
.Attr("pad_to_max_output_size: bool = false")
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(NMSShapeFn(c));
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
index f37f79ddbf..1d4d51a25d 100644
--- a/tensorflow/core/ops/linalg_ops.cc
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -235,6 +235,8 @@ REGISTER_OP("MatrixInverse")
.SetShapeFn(BatchUnchangedSquareShapeFn);
REGISTER_OP("MatrixExponential")
+ .Deprecated(
+ 27, "Use Python implementation tf.linalg.matrix_exponential instead.")
.Input("input: T")
.Output("output: T")
.Attr("T: {double, float, complex64, complex128}")
diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc
index b9f94ba1c5..7d79df9c1c 100644
--- a/tensorflow/core/ops/list_ops.cc
+++ b/tensorflow/core/ops/list_ops.cc
@@ -210,7 +210,8 @@ REGISTER_OP("TensorListFromTensor")
shape_inference::ShapeHandle o;
TF_RETURN_IF_ERROR(c->Subshape(s, 1, &o));
shape_inference::ShapeHandle element_shape;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &element_shape));
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
+ 1, &element_shape));
TF_RETURN_IF_ERROR(c->Merge(o, element_shape, &o));
c->set_output_handle_shapes_and_types(
0, std::vector<shape_inference::ShapeAndType>{{element_shape, t}});
@@ -240,7 +241,8 @@ REGISTER_OP("TensorListReserve")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->Scalar());
shape_inference::ShapeHandle s;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(
+ c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(0, &s));
DataType t;
TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t));
c->set_output_handle_shapes_and_types(
@@ -295,6 +297,51 @@ REGISTER_OP("TensorListSetItem")
return Status::OK();
});
+REGISTER_OP("TensorListGather")
+ .Input("input_handle: variant")
+ .Input("indices: int32")
+ .Output("values: element_dtype")
+ .Attr("element_dtype: type")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ DataType t;
+ TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t));
+ auto* handle_data = c->input_handle_shapes_and_types(0);
+ shape_inference::ShapeHandle element_shape = c->UnknownShape();
+ if (handle_data != nullptr) {
+ const shape_inference::ShapeAndType& list_shape_type =
+ (*handle_data)[0];
+ element_shape = list_shape_type.shape;
+ if (list_shape_type.dtype != t) {
+ return errors::InvalidArgument("Expected list with element dtype ",
+ DataTypeString(t),
+ " but got list with element dtype ",
+ DataTypeString(list_shape_type.dtype));
+ }
+ }
+ shape_inference::ShapeHandle out;
+ TF_RETURN_IF_ERROR(c->Concatenate(c->input(1), element_shape, &out));
+ c->set_output(0, out);
+ return Status::OK();
+ });
+
+REGISTER_OP("TensorListScatter")
+ .Input("tensor: element_dtype")
+ .Input("indices: int32")
+ .Input("element_shape: shape_type")
+ .Output("output_handle: variant")
+ .Attr("element_dtype: type")
+ .Attr("shape_type: {int32, int64}")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ DataType t;
+ TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t));
+ shape_inference::ShapeHandle s;
+ TF_RETURN_IF_ERROR(
+ c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(2, &s));
+ c->set_output_handle_shapes_and_types(0, {{s, t}});
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ });
+
REGISTER_OP("TensorListConcatLists")
.Input("input_a: variant")
.Input("input_b: variant")
diff --git a/tensorflow/core/ops/logging_ops.cc b/tensorflow/core/ops/logging_ops.cc
index fbde692e95..639d211767 100644
--- a/tensorflow/core/ops/logging_ops.cc
+++ b/tensorflow/core/ops/logging_ops.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
@@ -27,6 +28,8 @@ REGISTER_OP("Assert")
.Attr("summarize: int = 3")
.SetShapeFn(shape_inference::NoOutputs);
+WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("Assert");
+
REGISTER_OP("Print")
.Input("input: T")
.Input("data: U")
@@ -39,6 +42,8 @@ REGISTER_OP("Print")
.Attr("summarize: int = 3")
.SetShapeFn(shape_inference::UnchangedShape);
+WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("Print");
+
// ----------------------------------------------------------------------------
// Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as
// inputs or outputs in various ways.
@@ -116,4 +121,6 @@ REGISTER_OP("Timestamp")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape);
+WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("Timestamp");
+
} // end namespace tensorflow
diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc
index 7c71406c6b..72a77be70d 100644
--- a/tensorflow/core/ops/lookup_ops.cc
+++ b/tensorflow/core/ops/lookup_ops.cc
@@ -294,7 +294,9 @@ REGISTER_OP("LookupTableImportV2")
ShapeHandle handle;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
- // TODO: Validate keys and values shape.
+ ShapeHandle keys;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys));
+ TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys));
return Status::OK();
});
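
The added checks require rank-1 keys and merge their shape against the values
input. A sketch under the same shape-inference test conventions used elsewhere
in this commit (the exact expectation strings are assumptions):

#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {

TEST(LookupOpsTest, LookupTableImportV2_ShapeFn) {
  ShapeInferenceTestOp op("LookupTableImportV2");
  // Scalar handle; rank-1 keys whose shape merges with the values shape.
  INFER_OK(op, "[];[4];[4]", "");
  // Scalar keys are rejected by the new WithRank(c->input(1), 1, ...) check.
  INFER_ERROR("must be rank 1", op, "[];[];[]");
}

}  // namespace tensorflow
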
diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc
index 57499a6f1d..07f876cb90 100644
--- a/tensorflow/core/ops/math_grad.cc
+++ b/tensorflow/core/ops/math_grad.cc
@@ -495,18 +495,18 @@ Status RealDivGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("RealDiv", RealDivGrad);
-Status UnsafeDivGrad(const AttrSlice& attrs, FunctionDef* g) {
+Status DivNoNanGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
return GradForBinaryCwise(g, {
- {{"gx"}, "UnsafeDiv", {"dz", "y"}},
+ {{"gx"}, "DivNoNan", {"dz", "y"}},
{{"nx"}, "Neg", {"x"}, {}, {"dz"}},
{{"y2"}, "Square", {"y"}, {}, {"dz"}},
- {{"nx_y2"}, "UnsafeDiv", {"nx", "y2"}},
+ {{"nx_y2"}, "DivNoNan", {"nx", "y2"}},
{{"gy"}, "Mul", {"dz", "nx_y2"}}, // dz * (- x / y^2)
});
// clang-format on
}
-REGISTER_OP_GRADIENT("UnsafeDiv", UnsafeDivGrad);
+REGISTER_OP_GRADIENT("DivNoNan", DivNoNanGrad);
Status PowGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
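
For reference, the rewritten gradient follows from differentiating the op's
piecewise definition z = x / y for y != 0 and z = 0 for y = 0, with the same
zero-instead-of-NaN convention applied to the gradient itself:

\[
\frac{\partial z}{\partial x} =
\begin{cases} 1/y, & y \neq 0 \\ 0, & y = 0 \end{cases}
\qquad
\frac{\partial z}{\partial y} =
\begin{cases} -x/y^{2}, & y \neq 0 \\ 0, & y = 0 \end{cases}
\]

so gx = DivNoNan(dz, y) and gy = dz * DivNoNan(-x, y^2), which is exactly the
{"gx"}, {"nx"}, {"y2"}, {"nx_y2"}, {"gy"} pipeline built above.
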
diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc
index b0d1595c31..5ee79809ac 100644
--- a/tensorflow/core/ops/math_grad_test.cc
+++ b/tensorflow/core/ops/math_grad_test.cc
@@ -753,14 +753,14 @@ TEST_F(MathGradTest, Div) {
}
}
-TEST_F(MathGradTest, UnsafeDiv) {
+TEST_F(MathGradTest, DivNoNan) {
auto x = test::AsTensor<float>(
{0.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 0.f}, TensorShape({3, 3}));
auto y = test::AsTensor<float>({-10.f, 0.f, 10.f}, TensorShape({3, 1}));
Tensor dx;
Tensor dy;
{
- SymGrad("UnsafeDiv", x, y, &dx, &dy);
+ SymGrad("DivNoNan", x, y, &dx, &dy);
{
auto g = [](float x, float y) {
if (y == 0.f) {
@@ -792,7 +792,7 @@ TEST_F(MathGradTest, UnsafeDiv) {
}
}
{ // Swap x and y.
- SymGrad("UnsafeDiv", y, x, &dy, &dx);
+ SymGrad("DivNoNan", y, x, &dy, &dx);
{
auto g = [](float x, float y) {
if (y == 0.f) {
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 49646f1f3a..717263a9b0 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -392,8 +392,11 @@ Returns x * y element-wise.
REGISTER_OP("Div").BINARY_MORE().SetShapeFn(
shape_inference::BroadcastBinaryOpShapeFn);
-REGISTER_OP("UnsafeDiv")
- .BINARY_MORE()
+REGISTER_OP("DivNoNan")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("z: T")
+ .Attr("T: {float, double}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
REGISTER_OP("FloorDiv")
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index ebeb048157..be4c3ed2b6 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -121,7 +121,7 @@ TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
"Mod", "Mul",
"NotEqual", "Pow",
"Sub", "SquaredDifference",
- "UnsafeDiv"}) {
+ "DivNoNan"}) {
ShapeInferenceTestOp op(op_name);
INFER_OK(op, "?;?", "?");
INFER_OK(op, "[1,2];?", "?");
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 023f988f80..6c318e358a 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -960,7 +960,7 @@ REGISTER_OP("Dilation2DBackpropFilter")
REGISTER_OP("Relu")
.Input("features: T")
.Output("activations: T")
- .Attr("T: realnumbertype")
+ .Attr("T: {realnumbertype, qint8}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("ReluGrad")
@@ -1024,6 +1024,7 @@ REGISTER_OP("SeluGrad")
.Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);
+// TODO(b/111515541): change T to {half, bfloat16, float, double}
REGISTER_OP("Softplus")
.Input("features: T")
.Output("activations: T")
@@ -1037,6 +1038,7 @@ REGISTER_OP("SoftplusGrad")
.Attr("T: realnumbertype")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);
+// TODO(b/111515541): change T to {half, bfloat16, float, double}
REGISTER_OP("Softsign")
.Input("features: T")
.Output("activations: T")
@@ -1751,6 +1753,87 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
+REGISTER_OP("_MklConv3D")
+ .Input("input: T")
+ .Input("filter: T")
+ .Input("mkl_input: uint8")
+ .Input("mkl_filter: uint8")
+ .Output("output: T")
+ .Output("filter_output: T")
+ .Output("mkl_output: uint8")
+ .Output("mkl_filter_output: uint8")
+ .Attr("T: {half, float, double}")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
+ .SetShapeFn(shape_inference::Conv3DShape)
+ .Doc(R"doc(
+MKL version of Conv3D operator. Uses MKL DNN APIs to perform 3D convolution.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("_MklConv3DBackpropInputV2")
+ .Input("input_sizes: Tshape")
+ .Input("filter: T")
+ .Input("out_backprop: T")
+ .Input("mkl_input_sizes: uint8")
+ .Input("mkl_filter: uint8")
+ .Input("mkl_out_backprop: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("T: {half, float, double}")
+ .Attr("strides: list(int) >= 5")
+ .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
+ .Attr("Tshape: {int32, int64} = DT_INT32")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
+ c->set_output(0, s);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+MKL version of Convolution3D backward input. Uses MKL DNN APIs to compute the
+gradients of convolution with respect to the input.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("_MklConv3DBackpropFilterV2")
+ .Input("input: T")
+ .Input("filter_sizes: int32")
+ .Input("out_backprop: T")
+ .Input("mkl_input: uint8")
+ .Input("mkl_filter_size: uint8")
+ .Input("mkl_out_backprop: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("T: {half, float, double}")
+ .Attr("strides: list(int)")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
+ c->set_output(0, s);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+MKL version of Conv3DBackpropFilter. Uses MKL DNN APIs to compute the
+gradients of convolution with respect to the filter.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
REGISTER_OP("_MklRelu")
.Input("features: T")
.Input("mkl_features: uint8")
@@ -1958,6 +2041,104 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
+REGISTER_OP("_MklAvgPool3D")
+ .Input("value: T")
+ .Input("mkl_input: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("ksize: list(int) >= 5")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("T: {float, half, double}")
+ .SetShapeFn(shape_inference::Pool3DShape)
+ .Doc(R"doc(
+MKL version of AvgPool3D operator. Uses MKL DNN APIs to perform average pooling
+on the input.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+
+REGISTER_OP("_MklAvgPool3DGrad")
+ .Input("orig_input_shape: int32")
+ .Input("grad: T")
+ .Input("mkl_orig_input: uint8")
+ .Input("mkl_grad: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("ksize: list(int) >= 5")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("T: {float, half, double}")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
+ c->set_output(0, s);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+MKL version of AvgPool3DGrad operator. Uses MKL DNN APIs to compute gradients
+of AvgPool function.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("_MklMaxPool3D")
+ .Input("input: T")
+ .Input("mkl_input: uint8")
+ .Output("output: T")
+ .Output("workspace: uint8")
+ .Output("mkl_output: uint8")
+ .Output("mkl_workspace: uint8")
+ .Attr("ksize: list(int) >= 5")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("T: {half, bfloat16, float}")
+ .Attr("workspace_enabled: bool = false")
+ .SetShapeFn(shape_inference::Pool3DShape)
+ .Doc(R"doc(
+MKL version of MaxPool3D operator. Uses MKL DNN APIs to perform max pooling
+on the input.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("_MklMaxPool3DGrad")
+ .Input("orig_input: TInput")
+ .Input("orig_output: TInput")
+ .Input("grad: T")
+ .Input("workspace: uint8")
+ .Input("mkl_orig_input: uint8")
+ .Input("mkl_orig_output: uint8")
+ .Input("mkl_grad: uint8")
+ .Input("mkl_workspace: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("ksize: list(int) >= 5")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("T: {half, bfloat16, float} = DT_FLOAT")
+ .Attr("TInput: {half, bfloat16, float} = DT_FLOAT")
+ .Attr("workspace_enabled: bool = false")
+ .SetShapeFn([](InferenceContext* c) {
+ return UnchangedShapeWithRank(c, 5);
+ })
+ .Doc(R"doc(
+MKL version of MaxPool3DGrad operator. Uses MKL DNN APIs to compute gradients
+of the MaxPool3D function.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
REGISTER_OP("_MklLRN")
.Input("input: T")
.Input("mkl_input: uint8")
@@ -2176,7 +2357,7 @@ REGISTER_OP("_MklToTf")
.Input("mkl_input: uint8")
.Output("output: T")
.Attr("T: {half, float, double}")
- .Attr(GetConvnetDataFormatAttrString())
+ .Attr(GetConvnetDataFormat2D3DAttrString())
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
MKL operator to convert a tensor from MKL layout to TensorFlow layout.
@@ -2198,7 +2379,7 @@ REGISTER_OP("_MklInputConversion")
.Attr(
"T: {half, float, double, uint8, int8, uint16, int16, int32, int64, "
"complex64, complex128}")
- .Attr(GetConvnetDataFormatAttrString())
+ .Attr(GetConvnetDataFormat2D3DAttrString())
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
MKL operator to process the inputs to an elementwise MKL op. Both inputs
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 7693c2d485..a2fc76c8b6 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -5592,6 +5592,19 @@ op {
s: ""
}
}
+ attr {
+ name: "reduction_type"
+ type: "string"
+ default_value {
+ s: "MEAN"
+ }
+ allowed_values {
+ list {
+ s: "MEAN"
+ s: "SUM"
+ }
+ }
+ }
is_stateful: true
}
op {
@@ -9190,6 +9203,31 @@ op {
}
}
op {
+ name: "DivNoNan"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "DrawBoundingBoxes"
input_arg {
name: "images"
@@ -9642,6 +9680,25 @@ op {
is_stateful: true
}
op {
+ name: "EnsureShape"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+}
+op {
name: "Enter"
input_arg {
name: "data"
@@ -10415,33 +10472,6 @@ op {
is_stateful: true
}
op {
- name: "FeatureStatsDataset"
- input_arg {
- name: "input_dataset"
- type: DT_VARIANT
- }
- input_arg {
- name: "tag"
- type: DT_STRING
- }
- output_arg {
- name: "handle"
- type: DT_VARIANT
- }
- attr {
- name: "output_types"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "output_shapes"
- type: "list(shape)"
- has_minimum: true
- minimum: 1
- }
-}
-op {
name: "Fill"
input_arg {
name: "dims"
@@ -14593,6 +14623,13 @@ op {
has_minimum: true
minimum: 1
}
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
}
op {
name: "MapDefun"
@@ -15089,6 +15126,10 @@ op {
}
}
}
+ deprecation {
+ version: 27
+ explanation: "Use Python implementation tf.linalg.matrix_exponential instead."
+ }
}
op {
name: "MatrixInverse"
@@ -17125,11 +17166,11 @@ op {
name: "NonMaxSuppressionV2"
input_arg {
name: "boxes"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "scores"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "max_output_size"
@@ -17143,16 +17184,29 @@ op {
name: "selected_indices"
type: DT_INT32
}
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
}
op {
name: "NonMaxSuppressionV3"
input_arg {
name: "boxes"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "scores"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "max_output_size"
@@ -17170,16 +17224,29 @@ op {
name: "selected_indices"
type: DT_INT32
}
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
}
op {
name: "NonMaxSuppressionV4"
input_arg {
name: "boxes"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "scores"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "max_output_size"
@@ -17202,6 +17269,19 @@ op {
type: DT_INT32
}
attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
name: "pad_to_max_output_size"
type: "bool"
default_value {
@@ -18239,6 +18319,54 @@ op {
}
}
op {
+ name: "ParallelInterleaveDatasetV2"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "cycle_length"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "block_length"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "num_parallel_calls"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "ParallelMapDataset"
input_arg {
name: "input_dataset"
@@ -18277,6 +18405,13 @@ op {
has_minimum: true
minimum: 1
}
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
}
op {
name: "ParameterizedTruncatedNormal"
@@ -18425,6 +18560,271 @@ op {
}
}
op {
+ name: "ParseExampleDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "num_parallel_calls"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "dense_defaults"
+ type_list_attr: "Tdense"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "sparse_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "dense_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "sparse_types"
+ type: "list(type)"
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "Tdense"
+ type: "list(type)"
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "dense_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ParseSequenceExample"
+ input_arg {
+ name: "serialized"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "debug_name"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "context_dense_defaults"
+ type_list_attr: "Tcontext_dense"
+ }
+ output_arg {
+ name: "context_sparse_indices"
+ type: DT_INT64
+ number_attr: "Ncontext_sparse"
+ }
+ output_arg {
+ name: "context_sparse_values"
+ type_list_attr: "context_sparse_types"
+ }
+ output_arg {
+ name: "context_sparse_shapes"
+ type: DT_INT64
+ number_attr: "Ncontext_sparse"
+ }
+ output_arg {
+ name: "context_dense_values"
+ type_list_attr: "Tcontext_dense"
+ }
+ output_arg {
+ name: "feature_list_sparse_indices"
+ type: DT_INT64
+ number_attr: "Nfeature_list_sparse"
+ }
+ output_arg {
+ name: "feature_list_sparse_values"
+ type_list_attr: "feature_list_sparse_types"
+ }
+ output_arg {
+ name: "feature_list_sparse_shapes"
+ type: DT_INT64
+ number_attr: "Nfeature_list_sparse"
+ }
+ output_arg {
+ name: "feature_list_dense_values"
+ type_list_attr: "feature_list_dense_types"
+ }
+ output_arg {
+ name: "feature_list_dense_lengths"
+ type: DT_INT64
+ number_attr: "Nfeature_list_dense"
+ }
+ attr {
+ name: "feature_list_dense_missing_assumed_empty"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "context_sparse_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "context_dense_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "feature_list_sparse_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "feature_list_dense_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "Ncontext_sparse"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "Ncontext_dense"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "Nfeature_list_sparse"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "Nfeature_list_dense"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "context_sparse_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "Tcontext_dense"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "feature_list_dense_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "context_dense_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "feature_list_sparse_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "feature_list_dense_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+}
+op {
name: "ParseSingleExample"
input_arg {
name: "serialized"
@@ -22358,6 +22758,7 @@ op {
type: DT_HALF
type: DT_UINT32
type: DT_UINT64
+ type: DT_QINT8
}
}
}
@@ -26758,6 +27159,7 @@ op {
s: "squared_loss"
s: "hinge_loss"
s: "smooth_hinge_loss"
+ s: "poisson_loss"
}
}
}
@@ -29390,6 +29792,19 @@ op {
s: ""
}
}
+ attr {
+ name: "reduction_type"
+ type: "string"
+ default_value {
+ s: "MEAN"
+ }
+ allowed_values {
+ list {
+ s: "MEAN"
+ s: "SUM"
+ }
+ }
+ }
is_stateful: true
}
op {
@@ -31888,6 +32303,47 @@ op {
}
}
op {
+ name: "StaticRegexFullMatch"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_BOOL
+ }
+ attr {
+ name: "pattern"
+ type: "string"
+ }
+}
+op {
+ name: "StaticRegexReplace"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "pattern"
+ type: "string"
+ }
+ attr {
+ name: "rewrite"
+ type: "string"
+ }
+ attr {
+ name: "replace_global"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "StatsAggregatorHandle"
output_arg {
name: "handle"
@@ -33904,6 +34360,25 @@ op {
}
}
op {
+ name: "TensorListGather"
+ input_arg {
+ name: "input_handle"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "indices"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "values"
+ type_attr: "element_dtype"
+ }
+ attr {
+ name: "element_dtype"
+ type: "type"
+ }
+}
+op {
name: "TensorListGetItem"
input_arg {
name: "input_handle"
@@ -34020,6 +34495,39 @@ op {
}
}
op {
+ name: "TensorListScatter"
+ input_arg {
+ name: "tensor"
+ type_attr: "element_dtype"
+ }
+ input_arg {
+ name: "indices"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "element_shape"
+ type_attr: "shape_type"
+ }
+ output_arg {
+ name: "output_handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "element_dtype"
+ type: "type"
+ }
+ attr {
+ name: "shape_type"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "TensorListSetItem"
input_arg {
name: "input_handle"
@@ -35028,41 +35536,6 @@ op {
}
}
op {
- name: "UnsafeDiv"
- input_arg {
- name: "x"
- type_attr: "T"
- }
- input_arg {
- name: "y"
- type_attr: "T"
- }
- output_arg {
- name: "z"
- type_attr: "T"
- }
- attr {
- name: "T"
- type: "type"
- allowed_values {
- list {
- type: DT_BFLOAT16
- type: DT_HALF
- type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_UINT8
- type: DT_INT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT32
- type: DT_INT64
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- }
- }
- }
-}
-op {
name: "UnsortedSegmentMax"
input_arg {
name: "data"
diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc
index ddb714b4e9..79ca96d249 100644
--- a/tensorflow/core/ops/parsing_ops.cc
+++ b/tensorflow/core/ops/parsing_ops.cc
@@ -132,6 +132,99 @@ REGISTER_OP("ParseSingleExample")
return Status::OK();
});
+REGISTER_OP("ParseSequenceExample")
+ .Input("serialized: string")
+ .Input("debug_name: string")
+ .Input("context_dense_defaults: Tcontext_dense")
+ .Output("context_sparse_indices: Ncontext_sparse * int64")
+ .Output("context_sparse_values: context_sparse_types")
+ .Output("context_sparse_shapes: Ncontext_sparse * int64")
+ .Output("context_dense_values: Tcontext_dense")
+ .Output("feature_list_sparse_indices: Nfeature_list_sparse * int64")
+ .Output("feature_list_sparse_values: feature_list_sparse_types")
+ .Output("feature_list_sparse_shapes: Nfeature_list_sparse * int64")
+ .Output("feature_list_dense_values: feature_list_dense_types")
+ .Output("feature_list_dense_lengths: Nfeature_list_dense * int64")
+ .Attr("feature_list_dense_missing_assumed_empty: list(string) >= 0")
+ .Attr("context_sparse_keys: list(string) >= 0")
+ .Attr("context_dense_keys: list(string) >= 0")
+ .Attr("feature_list_sparse_keys: list(string) >= 0")
+ .Attr("feature_list_dense_keys: list(string) >= 0")
+ .Attr("Ncontext_sparse: int >= 0 = 0")
+ .Attr("Ncontext_dense: int >= 0 = 0")
+ .Attr("Nfeature_list_sparse: int >= 0 = 0")
+ .Attr("Nfeature_list_dense: int >= 0 = 0")
+ .Attr("context_sparse_types: list({float,int64,string}) >= 0 = []")
+ .Attr("Tcontext_dense: list({float,int64,string}) >= 0 = []")
+ .Attr("feature_list_dense_types: list({float,int64,string}) >= 0 = []")
+ .Attr("context_dense_shapes: list(shape) >= 0 = []")
+ .Attr("feature_list_sparse_types: list({float,int64,string}) >= 0 = []")
+ .Attr("feature_list_dense_shapes: list(shape) >= 0 = []")
+ .SetShapeFn([](InferenceContext* c) {
+ ParseSequenceExampleAttrs attrs;
+ TF_RETURN_IF_ERROR(attrs.Init(c));
+
+ // Verify that the input is a vector, and carry the shape if known.
+ ShapeHandle input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input));
+ shape_inference::DimensionHandle num_examples = c->Dim(input, 0);
+
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); // debug_name
+
+ int output_idx = 0;
+
+ // Output context_sparse_indices, context_sparse_values, and
+ // context_sparse_shapes.
+ for (int i = 0; i < attrs.num_context_sparse; ++i) {
+ c->set_output(output_idx++, c->Matrix(c->UnknownDim(), 2));
+ }
+ for (int i = 0; i < attrs.num_context_sparse; ++i) {
+ c->set_output(output_idx++, c->Vector(c->UnknownDim()));
+ }
+ for (int i = 0; i < attrs.num_context_sparse; ++i) {
+ c->set_output(output_idx++, c->Vector(2));
+ }
+
+ // Output context_dense_values.
+ for (int i = 0; i < attrs.num_context_dense; ++i) {
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ attrs.context_dense_shapes[i], &s));
+ TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(num_examples), s, &s));
+ c->set_output(output_idx++, s);
+ }
+
+ // Output feature_list_sparse_indices, feature_list_sparse_values,
+ // feature_list_sparse_shapes.
+ for (int i = 0; i < attrs.num_feature_list_sparse; ++i) {
+ c->set_output(output_idx++, c->Matrix(c->UnknownDim(), 3));
+ }
+ for (int i = 0; i < attrs.num_feature_list_sparse; ++i) {
+ c->set_output(output_idx++, c->Vector(c->UnknownDim()));
+ }
+ for (int i = 0; i < attrs.num_feature_list_sparse; ++i) {
+ c->set_output(output_idx++, c->Vector(3));
+ }
+
+ // Output feature_list_dense_shapes.
+ for (int i = 0; i < attrs.num_feature_list_dense; ++i) {
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ attrs.feature_list_dense_shapes[i], &s));
+ TF_RETURN_IF_ERROR(
+ c->Concatenate(c->Matrix(num_examples, c->UnknownDim()), s, &s));
+ c->set_output(output_idx++, s);
+ }
+
+ // Output feature_list_dense_lengths.
+ for (int i = 0; i < attrs.num_feature_list_dense; ++i) {
+ c->set_output(output_idx++, c->Vector(num_examples));
+ }
+
+ return Status::OK();
+ });
+
REGISTER_OP("ParseSingleSequenceExample")
.Input("serialized: string")
.Input("feature_list_dense_missing_assumed_empty: string")
diff --git a/tensorflow/core/ops/parsing_ops_test.cc b/tensorflow/core/ops/parsing_ops_test.cc
index 9121d7ae92..c65e66d1a8 100644
--- a/tensorflow/core/ops/parsing_ops_test.cc
+++ b/tensorflow/core/ops/parsing_ops_test.cc
@@ -143,6 +143,88 @@ TEST(ParsingOpsTest, ParseExample_ShapeFn) {
"?;?;?;?;?;?;?;?;?;?");
}
+TEST(ParsingOpsTest, ParseSequenceExample_ShapeFn) {
+ ShapeInferenceTestOp op("ParseSequenceExample");
+ auto set_outputs = [&op](int num_context_sparse, int num_context_dense,
+ int num_feature_list_sparse,
+ int num_feature_list_dense,
+ bool add_extra_shape = false) {
+ using NodeOutList = std::vector<NodeDefBuilder::NodeOut>;
+ using DataTypeList = std::vector<DataType>;
+ string string_in("test");
+ NodeDefBuilder::NodeOut node_in{"a", 0, DT_STRING};
+ TF_ASSERT_OK(
+ NodeDefBuilder("test", "ParseSequenceExample")
+ .Input("serialized", 0, DT_STRING)
+ .Input("debug_name", 0, DT_STRING)
+ .Input(NodeOutList(num_context_dense, node_in))
+ .Attr("Ncontext_sparse", num_context_sparse)
+ .Attr("Ncontext_dense", num_context_dense)
+ .Attr("Nfeature_list_sparse", num_feature_list_sparse)
+ .Attr("Nfeature_list_dense", num_feature_list_dense)
+ .Attr("feature_list_dense_missing_assumed_empty",
+ std::vector<string>(num_feature_list_dense, string_in))
+ .Attr("context_sparse_keys",
+ std::vector<string>(num_context_sparse, string_in))
+ .Attr("context_dense_keys",
+ std::vector<string>(num_context_dense, string_in))
+ .Attr("feature_list_sparse_keys",
+ std::vector<string>(num_feature_list_sparse, string_in))
+ .Attr("feature_list_dense_keys",
+ std::vector<string>(num_feature_list_dense, string_in))
+ .Attr("context_sparse_types",
+ DataTypeList(num_context_sparse, DT_FLOAT))
+ .Attr("context_dense_types",
+ DataTypeList(num_context_dense, DT_FLOAT))
+ .Attr("context_dense_shapes",
+ MakeDenseShapes(num_context_dense, add_extra_shape, 0))
+ .Attr("feature_list_sparse_types",
+ DataTypeList(num_feature_list_sparse, DT_FLOAT))
+ .Attr("feature_list_dense_types",
+ DataTypeList(num_feature_list_dense, DT_FLOAT))
+ .Attr("feature_list_dense_shapes",
+ MakeDenseShapes(num_feature_list_dense, add_extra_shape, 0))
+ .Finalize(&op.node_def));
+ };
+
+ // Verify inputs 'serialized' and 'debug_name'.
+ set_outputs(0, 0, 0, 0);
+ INFER_OK(op, "[?];[?]", "");
+ INFER_OK(op, "[8];[8]", "");
+ INFER_ERROR("must be rank 1", op, "[];[?]");
+ INFER_ERROR("must be rank 1", op, "[?];[]");
+
+ // context inputs with no feature_list inputs.
+ set_outputs(2 /* num_context_sparse */, 3 /* num_context_dense */, 0, 0);
+ INFER_OK(op, "[?];[?];?;?;?",
+ ("[?,2];[?,2];[?];[?];[2];[2];" // context sparse
+ "[d0_0,1];[d0_0,1,2];[d0_0,1,2,3]")); // context dense
+
+ // feature_list inputs with no context inputs.
+ set_outputs(0, 0, 2 /* num_feature_list_sparse */,
+ 3 /* num_feature_list_dense */);
+ INFER_OK(op, "[?];[?]",
+ ("[?,3];[?,3];[?];[?];[3];[3];" // feature_list sparse
+ "[d0_0,?,1];[d0_0,?,1,2];[d0_0,?,1,2,3];" // feature_list dense
+ "[d0_0];[d0_0];[d0_0]")); // feature_list length
+
+ // Combine previous two test cases.
+ set_outputs(2, 3, 2, 3);
+ INFER_OK(op, "[7];[7];?;?;?",
+ ("[?,2];[?,2];[?];[?];[2];[2];" // context sparse
+ "[d0_0,1];[d0_0,1,2];[d0_0,1,2,3];" // context dense
+ "[?,3];[?,3];[?];[?];[3];[3];" // feature_list sparse
+ "[d0_0,?,1];[d0_0,?,1,2];[d0_0,?,1,2,3];" // feature_list dense
+ "[d0_0];[d0_0];[d0_0]")); // feature_list length
+
+ // Confirm an error from ParseSequenceExampleAttrs.Init().
+ set_outputs(1, 1, 1, 1, true /* add_extra_shape */);
+ INFER_ERROR(
+ "num_context_dense (1) must match the size of context_dense_keys (1), "
+ "context_dense_types (1) and context_dense_shapes (2)",
+ op, "[?];[?];?");
+}
+
TEST(ParsingOpsTest, ParseSingleSequenceExample_ShapeFn) {
ShapeInferenceTestOp op("ParseSingleSequenceExample");
auto set_outputs = [&op](int num_context_sparse, int num_context_dense,
diff --git a/tensorflow/core/ops/sdca_ops.cc b/tensorflow/core/ops/sdca_ops.cc
index 4025070adb..fdf53a55dd 100644
--- a/tensorflow/core/ops/sdca_ops.cc
+++ b/tensorflow/core/ops/sdca_ops.cc
@@ -41,7 +41,7 @@ static Status ApplySdcaOptimizerShapeFn(InferenceContext* c) {
REGISTER_OP("SdcaOptimizer")
.Attr(
"loss_type: {'logistic_loss', 'squared_loss', 'hinge_loss',"
- "'smooth_hinge_loss'}")
+ "'smooth_hinge_loss', 'poisson_loss'}")
.Attr("adaptative : bool=false")
.Attr("num_sparse_features: int >= 0")
.Attr("num_sparse_features_with_values: int >= 0")
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index d1e38e6d22..ef8b15dc8a 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -37,6 +37,14 @@ REGISTER_OP("RegexReplace")
return Status::OK();
});
+REGISTER_OP("StaticRegexReplace")
+ .Input("input: string")
+ .Attr("pattern: string")
+ .Attr("rewrite: string")
+ .Output("output: string")
+ .Attr("replace_global: bool = true")
+ .SetShapeFn(shape_inference::UnchangedShape);
+
REGISTER_OP("RegexFullMatch")
.Input("input: string")
.Input("pattern: string")
@@ -48,6 +56,12 @@ REGISTER_OP("RegexFullMatch")
return Status::OK();
});
+REGISTER_OP("StaticRegexFullMatch")
+ .Input("input: string")
+ .Attr("pattern: string")
+ .Output("output: bool")
+ .SetShapeFn(shape_inference::UnchangedShape);
+
REGISTER_OP("StringToHashBucketFast")
.Input("input: string")
.Output("output: int64")
diff --git a/tensorflow/core/platform/abi.h b/tensorflow/core/platform/abi.h
index 763d467457..591e83b0c4 100644
--- a/tensorflow/core/platform/abi.h
+++ b/tensorflow/core/platform/abi.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_ABI_H_
-#define TENSORFLOW_PLATFORM_ABI_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_ABI_H_
+#define TENSORFLOW_CORE_PLATFORM_ABI_H_
#include <string>
@@ -26,4 +26,4 @@ std::string MaybeAbiDemangle(const char* name);
} // namespace port
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_ABI_H_
+#endif // TENSORFLOW_CORE_PLATFORM_ABI_H_
diff --git a/tensorflow/core/platform/cloud/auth_provider.h b/tensorflow/core/platform/cloud/auth_provider.h
index 465ff248d9..7347bc626d 100644
--- a/tensorflow/core/platform/cloud/auth_provider.h
+++ b/tensorflow/core/platform/cloud/auth_provider.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_PLATFORM_AUTH_PROVIDER_H_
-#define TENSORFLOW_CORE_PLATFORM_AUTH_PROVIDER_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_AUTH_PROVIDER_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_AUTH_PROVIDER_H_
#include <string>
#include "tensorflow/core/lib/core/errors.h"
@@ -51,4 +51,4 @@ class EmptyAuthProvider : public AuthProvider {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_PLATFORM_AUTH_PROVIDER_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_AUTH_PROVIDER_H_
diff --git a/tensorflow/core/platform/cloud/compute_engine_zone_provider.cc b/tensorflow/core/platform/cloud/compute_engine_zone_provider.cc
index dacf56187c..e147d88371 100644
--- a/tensorflow/core/platform/cloud/compute_engine_zone_provider.cc
+++ b/tensorflow/core/platform/cloud/compute_engine_zone_provider.cc
@@ -43,7 +43,7 @@ Status ComputeEngineZoneProvider::GetZone(string* zone) {
*zone = cached_zone;
} else {
LOG(ERROR) << "Failed to parse the zone name from location: "
- << location.ToString();
+ << string(location);
}
return Status::OK();
diff --git a/tensorflow/core/platform/cloud/curl_http_request.cc b/tensorflow/core/platform/cloud/curl_http_request.cc
index a1be4aacce..5e1eabee5b 100644
--- a/tensorflow/core/platform/cloud/curl_http_request.cc
+++ b/tensorflow/core/platform/cloud/curl_http_request.cc
@@ -394,9 +394,9 @@ size_t CurlHttpRequest::HeaderCallback(const void* ptr, size_t size,
.StopCapture()
.OneLiteral(": ")
.GetResult(&value, &name)) {
- string str_value = std::string(value);
+ string str_value(value);
str_util::StripTrailingWhitespace(&str_value);
- that->response_headers_[std::string(name)] = str_value;
+ that->response_headers_[string(name)] = str_value;
}
return size * nmemb;
}
diff --git a/tensorflow/core/platform/cloud/gcs_dns_cache.h b/tensorflow/core/platform/cloud/gcs_dns_cache.h
index 40f16f1044..07d0e59fd5 100644
--- a/tensorflow/core/platform/cloud/gcs_dns_cache.h
+++ b/tensorflow/core/platform/cloud/gcs_dns_cache.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_
-#define TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_DNS_CACHE_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_DNS_CACHE_H_
#include <random>
@@ -74,4 +74,4 @@ class GcsDnsCache {
} // namespace tensorflow
-#endif // TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_DNS_CACHE_H_
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 9d33787bd5..8f959c018e 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -179,13 +179,13 @@ Status ParseGcsPath(StringPiece fname, bool empty_object_ok, string* bucket,
return errors::InvalidArgument("GCS path doesn't start with 'gs://': ",
fname);
}
- *bucket = std::string(bucketp);
+ *bucket = string(bucketp);
if (bucket->empty() || *bucket == ".") {
return errors::InvalidArgument("GCS path doesn't contain a bucket name: ",
fname);
}
str_util::ConsumePrefix(&objectp, "/");
- *object = std::string(objectp);
+ *object = string(objectp);
if (!empty_object_ok && object->empty()) {
return errors::InvalidArgument("GCS path doesn't contain an object name: ",
fname);
@@ -224,7 +224,7 @@ std::set<string> AddAllSubpaths(const std::vector<string>& paths) {
for (const string& path : paths) {
StringPiece subpath = io::Dirname(path);
while (!subpath.empty()) {
- result.emplace(std::string(subpath));
+ result.emplace(string(subpath));
subpath = io::Dirname(subpath);
}
}
@@ -723,7 +723,7 @@ GcsFileSystem::GcsFileSystem() {
if (!header_name.empty() && !header_value.empty()) {
additional_header_.reset(new std::pair<const string, const string>(
- std::string(header_name), std::string(header_value)));
+ string(header_name), string(header_value)));
VLOG(1) << "GCS additional header ENABLED. "
<< "Name: " << additional_header_->first << ", "
@@ -1229,7 +1229,7 @@ Status GcsFileSystem::GetMatchingPaths(const string& pattern,
// Find the fixed prefix by looking for the first wildcard.
const string& fixed_prefix =
pattern.substr(0, pattern.find_first_of("*?[\\"));
- const string& dir = std::string(io::Dirname(fixed_prefix));
+ const string dir(io::Dirname(fixed_prefix));
if (dir.empty()) {
return errors::InvalidArgument(
"A GCS pattern doesn't have a bucket name: ", pattern);
@@ -1326,7 +1326,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname,
" doesn't match the prefix ", object_prefix));
}
if (!relative_path.empty() || include_self_directory_marker) {
- result->emplace_back(std::string(relative_path));
+ result->emplace_back(relative_path);
}
if (++retrieved_results >= max_results) {
return Status::OK();
@@ -1354,7 +1354,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname,
"Unexpected response: the returned folder name ", prefix_str,
" doesn't match the prefix ", object_prefix);
}
- result->emplace_back(std::string(relative_path));
+ result->emplace_back(relative_path);
if (++retrieved_results >= max_results) {
return Status::OK();
}
diff --git a/tensorflow/core/platform/cloud/google_auth_provider.h b/tensorflow/core/platform/cloud/google_auth_provider.h
index 58a785fd60..3755b124a8 100644
--- a/tensorflow/core/platform/cloud/google_auth_provider.h
+++ b/tensorflow/core/platform/cloud/google_auth_provider.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_PLATFORM_GOOGLE_AUTH_PROVIDER_H_
-#define TENSORFLOW_CORE_PLATFORM_GOOGLE_AUTH_PROVIDER_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_GOOGLE_AUTH_PROVIDER_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_GOOGLE_AUTH_PROVIDER_H_
#include <memory>
#include "tensorflow/core/platform/cloud/auth_provider.h"
@@ -65,4 +65,4 @@ class GoogleAuthProvider : public AuthProvider {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_PLATFORM_GOOGLE_AUTH_PROVIDER_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_GOOGLE_AUTH_PROVIDER_H_
diff --git a/tensorflow/core/platform/cloud/http_request.h b/tensorflow/core/platform/cloud/http_request.h
index 2343bca608..e925eefb1f 100644
--- a/tensorflow/core/platform/cloud/http_request.h
+++ b/tensorflow/core/platform/cloud/http_request.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_H_
-#define TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_H_
#include <string>
#include <unordered_map>
@@ -188,4 +188,4 @@ class HttpRequest {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_H_
diff --git a/tensorflow/core/platform/cloud/http_request_fake.h b/tensorflow/core/platform/cloud/http_request_fake.h
index 7711eaceb2..0a1164b64a 100644
--- a/tensorflow/core/platform/cloud/http_request_fake.h
+++ b/tensorflow/core/platform/cloud/http_request_fake.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_
-#define TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_FAKE_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_FAKE_H_
#include <algorithm>
#include <fstream>
@@ -212,4 +212,4 @@ class FakeHttpRequestFactory : public HttpRequest::Factory {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_FAKE_H_
diff --git a/tensorflow/core/platform/cloud/oauth_client.cc b/tensorflow/core/platform/cloud/oauth_client.cc
index ee6ba7b041..9b85cae9b9 100644
--- a/tensorflow/core/platform/cloud/oauth_client.cc
+++ b/tensorflow/core/platform/cloud/oauth_client.cc
@@ -216,7 +216,7 @@ Status OAuthClient::GetTokenFromServiceAccountJson(
// Send the request to the Google OAuth 2.0 server to get the token.
std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
std::vector<char> response_buffer;
- request->SetUri(std::string(oauth_server_uri));
+ request->SetUri(string(oauth_server_uri));
request->SetPostFromBuffer(request_body.c_str(), request_body.size());
request->SetResultBuffer(&response_buffer);
TF_RETURN_IF_ERROR(request->Send());
@@ -248,7 +248,7 @@ Status OAuthClient::GetTokenFromRefreshTokenJson(
std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
std::vector<char> response_buffer;
- request->SetUri(std::string(oauth_server_uri));
+ request->SetUri(string(oauth_server_uri));
request->SetPostFromBuffer(request_body.c_str(), request_body.size());
request->SetResultBuffer(&response_buffer);
TF_RETURN_IF_ERROR(request->Send());
diff --git a/tensorflow/core/platform/cloud/oauth_client_test.cc b/tensorflow/core/platform/cloud/oauth_client_test.cc
index 4ffa72288b..1cd0641cd3 100644
--- a/tensorflow/core/platform/cloud/oauth_client_test.cc
+++ b/tensorflow/core/platform/cloud/oauth_client_test.cc
@@ -126,9 +126,9 @@ TEST(OAuthClientTest, GetTokenFromServiceAccountJson) {
EXPECT_EQ("urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer",
grant_type);
- int last_dot = std::string(assertion).find_last_of(".");
- string header_dot_claim = std::string(assertion.substr(0, last_dot));
- string signature_encoded = std::string(assertion.substr(last_dot + 1));
+ int last_dot = assertion.rfind('.');
+ string header_dot_claim(assertion.substr(0, last_dot));
+ string signature_encoded(assertion.substr(last_dot + 1));
// Check that 'signature' signs 'header_dot_claim'.
diff --git a/tensorflow/core/platform/context.h b/tensorflow/core/platform/context.h
index 728ef91631..9f7beb7a68 100644
--- a/tensorflow/core/platform/context.h
+++ b/tensorflow/core/platform/context.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_CONTEXT_H_
-#define TENSORFLOW_PLATFORM_CONTEXT_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CONTEXT_H_
+#define TENSORFLOW_CORE_PLATFORM_CONTEXT_H_
namespace tensorflow {
@@ -42,4 +42,4 @@ class WithContext;
#include "tensorflow/core/platform/default/context.h"
#endif
-#endif // TENSORFLOW_PLATFORM_CONTEXT_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CONTEXT_H_
diff --git a/tensorflow/core/platform/cpu_feature_guard.h b/tensorflow/core/platform/cpu_feature_guard.h
index 586a6be55e..3d7bfe95b1 100644
--- a/tensorflow/core/platform/cpu_feature_guard.h
+++ b/tensorflow/core/platform/cpu_feature_guard.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_CPU_FEATURE_GUARD_H_
-#define TENSORFLOW_PLATFORM_CPU_FEATURE_GUARD_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CPU_FEATURE_GUARD_H_
+#define TENSORFLOW_CORE_PLATFORM_CPU_FEATURE_GUARD_H_
namespace tensorflow {
namespace port {
@@ -29,4 +29,4 @@ void InfoAboutUnusedCPUFeatures();
} // namespace port
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_CPU_FEATURE_GUARD_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CPU_FEATURE_GUARD_H_
diff --git a/tensorflow/core/platform/cpu_info.h b/tensorflow/core/platform/cpu_info.h
index 175c9ae8b1..6eba83224a 100644
--- a/tensorflow/core/platform/cpu_info.h
+++ b/tensorflow/core/platform/cpu_info.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_CPU_INFO_H_
-#define TENSORFLOW_PLATFORM_CPU_INFO_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CPU_INFO_H_
+#define TENSORFLOW_CORE_PLATFORM_CPU_INFO_H_
#include <string>
@@ -117,4 +117,4 @@ int CPUIDNumSMT();
} // namespace port
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_CPU_INFO_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CPU_INFO_H_
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 7251c6c725..bb841aeab7 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -13,219 +13,224 @@ load(
# Appends a suffix to a list of deps.
def tf_deps(deps, suffix):
- tf_deps = []
+ tf_deps = []
- # If the package name is in shorthand form (ie: does not contain a ':'),
- # expand it to the full name.
- for dep in deps:
- tf_dep = dep
+ # If the package name is in shorthand form (ie: does not contain a ':'),
+ # expand it to the full name.
+ for dep in deps:
+ tf_dep = dep
- if not ":" in dep:
- dep_pieces = dep.split("/")
- tf_dep += ":" + dep_pieces[len(dep_pieces) - 1]
+ if not ":" in dep:
+ dep_pieces = dep.split("/")
+ tf_dep += ":" + dep_pieces[len(dep_pieces) - 1]
- tf_deps += [tf_dep + suffix]
+ tf_deps += [tf_dep + suffix]
- return tf_deps
+ return tf_deps
# Modified from @cython//:Tools/rules.bzl
def pyx_library(
- name,
- deps=[],
- py_deps=[],
- srcs=[],
- **kwargs):
- """Compiles a group of .pyx / .pxd / .py files.
-
- First runs Cython to create .cpp files for each input .pyx or .py + .pxd
- pair. Then builds a shared object for each, passing "deps" to each cc_binary
- rule (includes Python headers by default). Finally, creates a py_library rule
- with the shared objects and any pure Python "srcs", with py_deps as its
- dependencies; the shared objects can be imported like normal Python files.
-
- Args:
- name: Name for the rule.
- deps: C/C++ dependencies of the Cython (e.g. Numpy headers).
- py_deps: Pure Python dependencies of the final library.
- srcs: .py, .pyx, or .pxd files to either compile or pass through.
- **kwargs: Extra keyword arguments passed to the py_library.
- """
- # First filter out files that should be run compiled vs. passed through.
- py_srcs = []
- pyx_srcs = []
- pxd_srcs = []
- for src in srcs:
- if src.endswith(".pyx") or (src.endswith(".py")
- and src[:-3] + ".pxd" in srcs):
- pyx_srcs.append(src)
- elif src.endswith(".py"):
- py_srcs.append(src)
- else:
- pxd_srcs.append(src)
- if src.endswith("__init__.py"):
- pxd_srcs.append(src)
-
- # Invoke cython to produce the shared object libraries.
- for filename in pyx_srcs:
- native.genrule(
- name = filename + "_cython_translation",
- srcs = [filename],
- outs = [filename.split(".")[0] + ".cpp"],
- # Optionally use PYTHON_BIN_PATH on Linux platforms so that python 3
- # works. Windows has issues with cython_binary so skip PYTHON_BIN_PATH.
- cmd = "PYTHONHASHSEED=0 $(location @cython//:cython_binary) --cplus $(SRCS) --output-file $(OUTS)",
- tools = ["@cython//:cython_binary"] + pxd_srcs,
+ name,
+ deps = [],
+ py_deps = [],
+ srcs = [],
+ **kwargs):
+ """Compiles a group of .pyx / .pxd / .py files.
+
+ First runs Cython to create .cpp files for each input .pyx or .py + .pxd
+ pair. Then builds a shared object for each, passing "deps" to each cc_binary
+ rule (includes Python headers by default). Finally, creates a py_library rule
+ with the shared objects and any pure Python "srcs", with py_deps as its
+ dependencies; the shared objects can be imported like normal Python files.
+
+ Args:
+ name: Name for the rule.
+ deps: C/C++ dependencies of the Cython (e.g. Numpy headers).
+ py_deps: Pure Python dependencies of the final library.
+ srcs: .py, .pyx, or .pxd files to either compile or pass through.
+ **kwargs: Extra keyword arguments passed to the py_library.
+ """
+
+ # First filter out files that should be run compiled vs. passed through.
+ py_srcs = []
+ pyx_srcs = []
+ pxd_srcs = []
+ for src in srcs:
+ if src.endswith(".pyx") or (src.endswith(".py") and
+ src[:-3] + ".pxd" in srcs):
+ pyx_srcs.append(src)
+ elif src.endswith(".py"):
+ py_srcs.append(src)
+ else:
+ pxd_srcs.append(src)
+ if src.endswith("__init__.py"):
+ pxd_srcs.append(src)
+
+ # Invoke cython to produce the shared object libraries.
+ for filename in pyx_srcs:
+ native.genrule(
+ name = filename + "_cython_translation",
+ srcs = [filename],
+ outs = [filename.split(".")[0] + ".cpp"],
+ # Optionally use PYTHON_BIN_PATH on Linux platforms so that python 3
+ # works. Windows has issues with cython_binary so skip PYTHON_BIN_PATH.
+ cmd = "PYTHONHASHSEED=0 $(location @cython//:cython_binary) --cplus $(SRCS) --output-file $(OUTS)",
+ tools = ["@cython//:cython_binary"] + pxd_srcs,
+ )
+
+ shared_objects = []
+ for src in pyx_srcs:
+ stem = src.split(".")[0]
+ shared_object_name = stem + ".so"
+ native.cc_binary(
+ name = shared_object_name,
+ srcs = [stem + ".cpp"],
+ deps = deps + ["//third_party/python_runtime:headers"],
+ linkshared = 1,
+ )
+ shared_objects.append(shared_object_name)
+
+ # Now create a py_library with these shared objects as data.
+ native.py_library(
+ name = name,
+ srcs = py_srcs,
+ deps = py_deps,
+ srcs_version = "PY2AND3",
+ data = shared_objects,
+ **kwargs
)
- shared_objects = []
- for src in pyx_srcs:
- stem = src.split(".")[0]
- shared_object_name = stem + ".so"
- native.cc_binary(
- name=shared_object_name,
- srcs=[stem + ".cpp"],
- deps=deps + ["//third_party/python_runtime:headers"],
- linkshared = 1,
- )
- shared_objects.append(shared_object_name)
-
- # Now create a py_library with these shared objects as data.
- native.py_library(
- name=name,
- srcs=py_srcs,
- deps=py_deps,
- srcs_version = "PY2AND3",
- data=shared_objects,
- **kwargs
- )
-
-def _proto_cc_hdrs(srcs, use_grpc_plugin=False):
- ret = [s[:-len(".proto")] + ".pb.h" for s in srcs]
- if use_grpc_plugin:
- ret += [s[:-len(".proto")] + ".grpc.pb.h" for s in srcs]
- return ret
-
-def _proto_cc_srcs(srcs, use_grpc_plugin=False):
- ret = [s[:-len(".proto")] + ".pb.cc" for s in srcs]
- if use_grpc_plugin:
- ret += [s[:-len(".proto")] + ".grpc.pb.cc" for s in srcs]
- return ret
-
-def _proto_py_outs(srcs, use_grpc_plugin=False):
- ret = [s[:-len(".proto")] + "_pb2.py" for s in srcs]
- if use_grpc_plugin:
- ret += [s[:-len(".proto")] + "_pb2_grpc.py" for s in srcs]
- return ret
+def _proto_cc_hdrs(srcs, use_grpc_plugin = False):
+ ret = [s[:-len(".proto")] + ".pb.h" for s in srcs]
+ if use_grpc_plugin:
+ ret += [s[:-len(".proto")] + ".grpc.pb.h" for s in srcs]
+ return ret
+
+def _proto_cc_srcs(srcs, use_grpc_plugin = False):
+ ret = [s[:-len(".proto")] + ".pb.cc" for s in srcs]
+ if use_grpc_plugin:
+ ret += [s[:-len(".proto")] + ".grpc.pb.cc" for s in srcs]
+ return ret
+
+def _proto_py_outs(srcs, use_grpc_plugin = False):
+ ret = [s[:-len(".proto")] + "_pb2.py" for s in srcs]
+ if use_grpc_plugin:
+ ret += [s[:-len(".proto")] + "_pb2_grpc.py" for s in srcs]
+ return ret
# Re-defined protocol buffer rule to allow building "header only" protocol
# buffers, to avoid duplicate registrations. Also allows non-iterable cc_libs
# containing select() statements.
def cc_proto_library(
- name,
- srcs=[],
- deps=[],
- cc_libs=[],
- include=None,
- protoc="@protobuf_archive//:protoc",
- internal_bootstrap_hack=False,
- use_grpc_plugin=False,
- use_grpc_namespace=False,
- default_header=False,
- **kargs):
- """Bazel rule to create a C++ protobuf library from proto source files.
-
- Args:
- name: the name of the cc_proto_library.
- srcs: the .proto files of the cc_proto_library.
- deps: a list of dependency labels; must be cc_proto_library.
- cc_libs: a list of other cc_library targets depended by the generated
- cc_library.
- include: a string indicating the include path of the .proto files.
- protoc: the label of the protocol compiler to generate the sources.
- internal_bootstrap_hack: a flag indicate the cc_proto_library is used only
- for bootstraping. When it is set to True, no files will be generated.
- The rule will simply be a provider for .proto files, so that other
- cc_proto_library can depend on it.
- use_grpc_plugin: a flag to indicate whether to call the grpc C++ plugin
- when processing the proto files.
- default_header: Controls the naming of generated rules. If True, the `name`
- rule will be header-only, and an _impl rule will contain the
- implementation. Otherwise the header-only rule (name + "_headers_only")
- must be referred to explicitly.
- **kargs: other keyword arguments that are passed to cc_library.
- """
-
- includes = []
- if include != None:
- includes = [include]
-
- if internal_bootstrap_hack:
- # For pre-checked-in generated files, we add the internal_bootstrap_hack
- # which will skip the codegen action.
+ name,
+ srcs = [],
+ deps = [],
+ cc_libs = [],
+ include = None,
+ protoc = "@protobuf_archive//:protoc",
+ internal_bootstrap_hack = False,
+ use_grpc_plugin = False,
+ use_grpc_namespace = False,
+ default_header = False,
+ **kargs):
+ """Bazel rule to create a C++ protobuf library from proto source files.
+
+ Args:
+ name: the name of the cc_proto_library.
+ srcs: the .proto files of the cc_proto_library.
+ deps: a list of dependency labels; must be cc_proto_library.
+ cc_libs: a list of other cc_library targets depended on by the generated
+ cc_library.
+ include: a string indicating the include path of the .proto files.
+ protoc: the label of the protocol compiler to generate the sources.
+ internal_bootstrap_hack: a flag indicating the cc_proto_library is used only
+ for bootstrapping. When it is set to True, no files will be generated.
+ The rule will simply be a provider for .proto files, so that other
+ cc_proto_library can depend on it.
+ use_grpc_plugin: a flag to indicate whether to call the grpc C++ plugin
+ when processing the proto files.
+ default_header: Controls the naming of generated rules. If True, the `name`
+ rule will be header-only, and an _impl rule will contain the
+ implementation. Otherwise the header-only rule (name + "_headers_only")
+ must be referred to explicitly.
+ **kargs: other keyword arguments that are passed to cc_library.
+ """
+
+ includes = []
+ if include != None:
+ includes = [include]
+
+ if internal_bootstrap_hack:
+ # For pre-checked-in generated files, we add the internal_bootstrap_hack
+ # which will skip the codegen action.
+ proto_gen(
+ name = name + "_genproto",
+ srcs = srcs,
+ deps = [s + "_genproto" for s in deps],
+ includes = includes,
+ protoc = protoc,
+ visibility = ["//visibility:public"],
+ )
+
+ # An empty cc_library to make rule dependency consistent.
+ native.cc_library(
+ name = name,
+ **kargs
+ )
+ return
+
+ grpc_cpp_plugin = None
+ plugin_options = []
+ if use_grpc_plugin:
+ grpc_cpp_plugin = "//external:grpc_cpp_plugin"
+ if use_grpc_namespace:
+ plugin_options = ["services_namespace=grpc"]
+
+ gen_srcs = _proto_cc_srcs(srcs, use_grpc_plugin)
+ gen_hdrs = _proto_cc_hdrs(srcs, use_grpc_plugin)
+ outs = gen_srcs + gen_hdrs
+
proto_gen(
- name=name + "_genproto",
- srcs=srcs,
- deps=[s + "_genproto" for s in deps],
- includes=includes,
- protoc=protoc,
- visibility=["//visibility:public"],
+ name = name + "_genproto",
+ srcs = srcs,
+ deps = [s + "_genproto" for s in deps],
+ includes = includes,
+ protoc = protoc,
+ plugin = grpc_cpp_plugin,
+ plugin_language = "grpc",
+ plugin_options = plugin_options,
+ gen_cc = 1,
+ outs = outs,
+ visibility = ["//visibility:public"],
)
- # An empty cc_library to make rule dependency consistent.
- native.cc_library(
- name=name,
- **kargs)
- return
-
- grpc_cpp_plugin = None
- plugin_options = []
- if use_grpc_plugin:
- grpc_cpp_plugin = "//external:grpc_cpp_plugin"
- if use_grpc_namespace:
- plugin_options = ["services_namespace=grpc"]
-
- gen_srcs = _proto_cc_srcs(srcs, use_grpc_plugin)
- gen_hdrs = _proto_cc_hdrs(srcs, use_grpc_plugin)
- outs = gen_srcs + gen_hdrs
-
- proto_gen(
- name=name + "_genproto",
- srcs=srcs,
- deps=[s + "_genproto" for s in deps],
- includes=includes,
- protoc=protoc,
- plugin=grpc_cpp_plugin,
- plugin_language="grpc",
- plugin_options=plugin_options,
- gen_cc=1,
- outs=outs,
- visibility=["//visibility:public"],
- )
-
- if use_grpc_plugin:
- cc_libs += select({
- "//tensorflow:linux_s390x": ["//external:grpc_lib_unsecure"],
- "//conditions:default": ["//external:grpc_lib"],
- })
- if default_header:
- header_only_name = name
- impl_name = name + "_impl"
- else:
- header_only_name = name + "_headers_only"
- impl_name = name
-
- native.cc_library(
- name=impl_name,
- srcs=gen_srcs,
- hdrs=gen_hdrs,
- deps=cc_libs + deps,
- includes=includes,
- **kargs)
- native.cc_library(
- name=header_only_name,
- deps=["@protobuf_archive//:protobuf_headers"] + if_static([impl_name]),
- hdrs=gen_hdrs,
- **kargs)
+ if use_grpc_plugin:
+ cc_libs += select({
+ "//tensorflow:linux_s390x": ["//external:grpc_lib_unsecure"],
+ "//conditions:default": ["//external:grpc_lib"],
+ })
+
+ if default_header:
+ header_only_name = name
+ impl_name = name + "_impl"
+ else:
+ header_only_name = name + "_headers_only"
+ impl_name = name
+
+ native.cc_library(
+ name = impl_name,
+ srcs = gen_srcs,
+ hdrs = gen_hdrs,
+ deps = cc_libs + deps,
+ includes = includes,
+ **kargs
+ )
+ native.cc_library(
+ name = header_only_name,
+ deps = ["@protobuf_archive//:protobuf_headers"] + if_static([impl_name]),
+ hdrs = gen_hdrs,
+ **kargs
+ )
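For illustration, a minimal sketch of how this macro might be invoked from a BUILD file; the load path, target names, and example.proto are hypothetical, not targets that exist in the repository:

    load("//tensorflow/core/platform/default:build_config.bzl", "cc_proto_library")

    cc_proto_library(
        name = "example_proto_cc",
        srcs = ["example.proto"],
        deps = [":base_proto_cc"],  # deps must themselves be cc_proto_library targets
        use_grpc_plugin = False,
        # With default_header = False (the default), ":example_proto_cc" carries the
        # implementation and ":example_proto_cc_headers_only" is the header-only rule.
        default_header = False,
    )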
# Re-defined protocol buffer rule to bring in the change introduced in commit
# https://github.com/google/protobuf/commit/294b5758c373cbab4b72f35f4cb62dc1d8332b68
@@ -234,484 +239,517 @@ def cc_proto_library(
# to include the above commit.
def py_proto_library(
name,
- srcs=[],
- deps=[],
- py_libs=[],
- py_extra_srcs=[],
- include=None,
- default_runtime="@protobuf_archive//:protobuf_python",
- protoc="@protobuf_archive//:protoc",
- use_grpc_plugin=False,
+ srcs = [],
+ deps = [],
+ py_libs = [],
+ py_extra_srcs = [],
+ include = None,
+ default_runtime = "@protobuf_archive//:protobuf_python",
+ protoc = "@protobuf_archive//:protoc",
+ use_grpc_plugin = False,
**kargs):
- """Bazel rule to create a Python protobuf library from proto source files
-
- NOTE: the rule is only an internal workaround to generate protos. The
- interface may change and the rule may be removed when bazel has introduced
- the native rule.
-
- Args:
- name: the name of the py_proto_library.
- srcs: the .proto files of the py_proto_library.
- deps: a list of dependency labels; must be py_proto_library.
- py_libs: a list of other py_library targets depended by the generated
- py_library.
- py_extra_srcs: extra source files that will be added to the output
- py_library. This attribute is used for internal bootstrapping.
- include: a string indicating the include path of the .proto files.
- default_runtime: the implicitly default runtime which will be depended on by
- the generated py_library target.
- protoc: the label of the protocol compiler to generate the sources.
- use_grpc_plugin: a flag to indicate whether to call the Python C++ plugin
- when processing the proto files.
- **kargs: other keyword arguments that are passed to cc_library.
- """
- outs = _proto_py_outs(srcs, use_grpc_plugin)
-
- includes = []
- if include != None:
- includes = [include]
-
- grpc_python_plugin = None
- if use_grpc_plugin:
- grpc_python_plugin = "//external:grpc_python_plugin"
- # Note: Generated grpc code depends on Python grpc module. This dependency
- # is not explicitly listed in py_libs. Instead, host system is assumed to
- # have grpc installed.
-
- proto_gen(
- name=name + "_genproto",
- srcs=srcs,
- deps=[s + "_genproto" for s in deps],
- includes=includes,
- protoc=protoc,
- gen_py=1,
- outs=outs,
- visibility=["//visibility:public"],
- plugin=grpc_python_plugin,
- plugin_language="grpc"
- )
-
- if default_runtime and not default_runtime in py_libs + deps:
- py_libs = py_libs + [default_runtime]
-
- native.py_library(
- name=name,
- srcs=outs+py_extra_srcs,
- deps=py_libs+deps,
- imports=includes,
- **kargs)
-
-def tf_proto_library_cc(name, srcs = [], has_services = None,
- protodeps = [],
- visibility = [], testonly = 0,
- cc_libs = [],
- cc_stubby_versions = None,
- cc_grpc_version = None,
- j2objc_api_version = 1,
- cc_api_version = 2,
- dart_api_version = 2,
- java_api_version = 2, py_api_version = 2,
- js_api_version = 2, js_codegen = "jspb",
- default_header = False):
- js_codegen = js_codegen # unused argument
- js_api_version = js_api_version # unused argument
- native.filegroup(
- name = name + "_proto_srcs",
- srcs = srcs + tf_deps(protodeps, "_proto_srcs"),
- testonly = testonly,
- visibility = visibility,
- )
-
- use_grpc_plugin = None
- if cc_grpc_version:
- use_grpc_plugin = True
-
- cc_deps = tf_deps(protodeps, "_cc")
- cc_name = name + "_cc"
- if not srcs:
- # This is a collection of sub-libraries. Build header-only and impl
- # libraries containing all the sources.
+    """Bazel rule to create a Python protobuf library from proto source files.
+
+    NOTE: the rule is only an internal workaround to generate protos. The
+    interface may change and the rule may be removed once Bazel introduces
+    the native rule.
+
+ Args:
+ name: the name of the py_proto_library.
+ srcs: the .proto files of the py_proto_library.
+ deps: a list of dependency labels; must be py_proto_library.
+      py_libs: a list of other py_library targets that the generated
+          py_library depends on.
+ py_extra_srcs: extra source files that will be added to the output
+ py_library. This attribute is used for internal bootstrapping.
+ include: a string indicating the include path of the .proto files.
+      default_runtime: the implicit default runtime that the generated
+          py_library target will depend on.
+ protoc: the label of the protocol compiler to generate the sources.
+      use_grpc_plugin: a flag to indicate whether to call the grpc Python plugin
+          when processing the proto files.
+      **kargs: other keyword arguments that are passed to py_library.
+ """
+ outs = _proto_py_outs(srcs, use_grpc_plugin)
+
+ includes = []
+ if include != None:
+ includes = [include]
+
+ grpc_python_plugin = None
+ if use_grpc_plugin:
+ grpc_python_plugin = "//external:grpc_python_plugin"
+        # Note: the generated grpc code depends on the Python grpc module. This
+        # dependency is not explicitly listed in py_libs; instead, the host
+        # system is assumed to have grpc installed.
+
proto_gen(
- name = cc_name + "_genproto",
- deps = [s + "_genproto" for s in cc_deps],
- protoc = "@protobuf_archive//:protoc",
- visibility=["//visibility:public"],
+ name = name + "_genproto",
+ srcs = srcs,
+ deps = [s + "_genproto" for s in deps],
+ includes = includes,
+ protoc = protoc,
+ gen_py = 1,
+ outs = outs,
+ visibility = ["//visibility:public"],
+ plugin = grpc_python_plugin,
+ plugin_language = "grpc",
)
- native.cc_library(
- name = cc_name,
- deps = cc_deps + ["@protobuf_archive//:protobuf_headers"] +
- if_static([name + "_cc_impl"]),
+
+ if default_runtime and not default_runtime in py_libs + deps:
+ py_libs = py_libs + [default_runtime]
+
+ native.py_library(
+ name = name,
+ srcs = outs + py_extra_srcs,
+ deps = py_libs + deps,
+ imports = includes,
+ **kargs
+ )
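A hypothetical py_proto_library call site, for comparison (names are placeholders; with use_grpc_plugin = True the generated *_pb2_grpc.py stubs additionally require grpc on the host, per the note above):

    py_proto_library(
        name = "example_proto_py",
        srcs = ["example.proto"],
        deps = [":base_proto_py"],  # deps must themselves be py_proto_library targets
        srcs_version = "PY2AND3",   # forwarded to the underlying py_library via **kargs
        use_grpc_plugin = True,     # also emits example_pb2_grpc.py next to example_pb2.py
    )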
+
+def tf_proto_library_cc(
+ name,
+ srcs = [],
+ has_services = None,
+ protodeps = [],
+ visibility = [],
+ testonly = 0,
+ cc_libs = [],
+ cc_stubby_versions = None,
+ cc_grpc_version = None,
+ j2objc_api_version = 1,
+ cc_api_version = 2,
+ dart_api_version = 2,
+ java_api_version = 2,
+ py_api_version = 2,
+ js_api_version = 2,
+ js_codegen = "jspb",
+ default_header = False):
+ js_codegen = js_codegen # unused argument
+ js_api_version = js_api_version # unused argument
+ native.filegroup(
+ name = name + "_proto_srcs",
+ srcs = srcs + tf_deps(protodeps, "_proto_srcs"),
testonly = testonly,
visibility = visibility,
)
- native.cc_library(
- name = cc_name + "_impl",
- deps = [s + "_impl" for s in cc_deps] + ["@protobuf_archive//:cc_wkt_protos"],
- )
- return
-
- cc_proto_library(
- name = cc_name,
- srcs = srcs,
- deps = cc_deps + ["@protobuf_archive//:cc_wkt_protos"],
- cc_libs = cc_libs + if_static(
- ["@protobuf_archive//:protobuf"],
- ["@protobuf_archive//:protobuf_headers"]
- ),
- copts = if_not_windows([
- "-Wno-unknown-warning-option",
- "-Wno-unused-but-set-variable",
- "-Wno-sign-compare",
- ]),
- protoc = "@protobuf_archive//:protoc",
- use_grpc_plugin = use_grpc_plugin,
- testonly = testonly,
- visibility = visibility,
- default_header = default_header,
- )
-
-def tf_proto_library_py(name, srcs=[], protodeps=[], deps=[], visibility=[],
- testonly=0, srcs_version="PY2AND3", use_grpc_plugin=False):
- py_deps = tf_deps(protodeps, "_py")
- py_name = name + "_py"
- if not srcs:
- # This is a collection of sub-libraries. Build header-only and impl
- # libraries containing all the sources.
- proto_gen(
- name = py_name + "_genproto",
- deps = [s + "_genproto" for s in py_deps],
+ use_grpc_plugin = None
+ if cc_grpc_version:
+ use_grpc_plugin = True
+
+ cc_deps = tf_deps(protodeps, "_cc")
+ cc_name = name + "_cc"
+ if not srcs:
+ # This is a collection of sub-libraries. Build header-only and impl
+ # libraries containing all the sources.
+ proto_gen(
+ name = cc_name + "_genproto",
+ deps = [s + "_genproto" for s in cc_deps],
+ protoc = "@protobuf_archive//:protoc",
+ visibility = ["//visibility:public"],
+ )
+ native.cc_library(
+ name = cc_name,
+ deps = cc_deps + ["@protobuf_archive//:protobuf_headers"] +
+ if_static([name + "_cc_impl"]),
+ testonly = testonly,
+ visibility = visibility,
+ )
+ native.cc_library(
+ name = cc_name + "_impl",
+ deps = [s + "_impl" for s in cc_deps] + ["@protobuf_archive//:cc_wkt_protos"],
+ )
+
+ return
+
+ cc_proto_library(
+ name = cc_name,
+ srcs = srcs,
+ deps = cc_deps + ["@protobuf_archive//:cc_wkt_protos"],
+ cc_libs = cc_libs + if_static(
+ ["@protobuf_archive//:protobuf"],
+ ["@protobuf_archive//:protobuf_headers"],
+ ),
+ copts = if_not_windows([
+ "-Wno-unknown-warning-option",
+ "-Wno-unused-but-set-variable",
+ "-Wno-sign-compare",
+ ]),
protoc = "@protobuf_archive//:protoc",
- visibility=["//visibility:public"],
+ use_grpc_plugin = use_grpc_plugin,
+ testonly = testonly,
+ visibility = visibility,
+ default_header = default_header,
)
- native.py_library(
+
+def tf_proto_library_py(
+ name,
+ srcs = [],
+ protodeps = [],
+ deps = [],
+ visibility = [],
+ testonly = 0,
+ srcs_version = "PY2AND3",
+ use_grpc_plugin = False):
+ py_deps = tf_deps(protodeps, "_py")
+ py_name = name + "_py"
+ if not srcs:
+ # This is a collection of sub-libraries. Build header-only and impl
+ # libraries containing all the sources.
+ proto_gen(
+ name = py_name + "_genproto",
+ deps = [s + "_genproto" for s in py_deps],
+ protoc = "@protobuf_archive//:protoc",
+ visibility = ["//visibility:public"],
+ )
+ native.py_library(
+ name = py_name,
+ deps = py_deps + ["@protobuf_archive//:protobuf_python"],
+ testonly = testonly,
+ visibility = visibility,
+ )
+ return
+
+ py_proto_library(
name = py_name,
- deps = py_deps + ["@protobuf_archive//:protobuf_python"],
- testonly = testonly,
+ srcs = srcs,
+ srcs_version = srcs_version,
+ deps = deps + py_deps + ["@protobuf_archive//:protobuf_python"],
+ protoc = "@protobuf_archive//:protoc",
+ default_runtime = "@protobuf_archive//:protobuf_python",
visibility = visibility,
+ testonly = testonly,
+ use_grpc_plugin = use_grpc_plugin,
)
- return
-
- py_proto_library(
- name = py_name,
- srcs = srcs,
- srcs_version = srcs_version,
- deps = deps + py_deps + ["@protobuf_archive//:protobuf_python"],
- protoc = "@protobuf_archive//:protoc",
- default_runtime = "@protobuf_archive//:protobuf_python",
- visibility = visibility,
- testonly = testonly,
- use_grpc_plugin = use_grpc_plugin,
- )
def tf_jspb_proto_library(**kwargs):
- pass
+ pass
def tf_nano_proto_library(**kwargs):
- pass
-
-def tf_proto_library(name, srcs = [], has_services = None,
- protodeps = [],
- visibility = [], testonly = 0,
- cc_libs = [],
- cc_api_version = 2, cc_grpc_version = None,
- dart_api_version = 2, j2objc_api_version = 1,
- java_api_version = 2, py_api_version = 2,
- js_api_version = 2, js_codegen = "jspb",
- provide_cc_alias = False,
- default_header = False):
- """Make a proto library, possibly depending on other proto libraries."""
- _ignore = (js_api_version, js_codegen, provide_cc_alias)
-
- tf_proto_library_cc(
- name = name,
- srcs = srcs,
- protodeps = protodeps,
- cc_grpc_version = cc_grpc_version,
- cc_libs = cc_libs,
- testonly = testonly,
- visibility = visibility,
- default_header = default_header,
- )
-
- tf_proto_library_py(
- name = name,
- srcs = srcs,
- protodeps = protodeps,
- srcs_version = "PY2AND3",
- testonly = testonly,
- visibility = visibility,
- use_grpc_plugin = has_services,
- )
+ pass
+
+def tf_proto_library(
+ name,
+ srcs = [],
+ has_services = None,
+ protodeps = [],
+ visibility = [],
+ testonly = 0,
+ cc_libs = [],
+ cc_api_version = 2,
+ cc_grpc_version = None,
+ dart_api_version = 2,
+ j2objc_api_version = 1,
+ java_api_version = 2,
+ py_api_version = 2,
+ js_api_version = 2,
+ js_codegen = "jspb",
+ provide_cc_alias = False,
+ default_header = False):
+ """Make a proto library, possibly depending on other proto libraries."""
+ _ignore = (js_api_version, js_codegen, provide_cc_alias)
+
+ tf_proto_library_cc(
+ name = name,
+ srcs = srcs,
+ protodeps = protodeps,
+ cc_grpc_version = cc_grpc_version,
+ cc_libs = cc_libs,
+ testonly = testonly,
+ visibility = visibility,
+ default_header = default_header,
+ )
+
+ tf_proto_library_py(
+ name = name,
+ srcs = srcs,
+ protodeps = protodeps,
+ srcs_version = "PY2AND3",
+ testonly = testonly,
+ visibility = visibility,
+ use_grpc_plugin = has_services,
+ )
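Tying the two wrappers together, a sketch of a typical tf_proto_library call (hypothetical names). The macro fans out to tf_proto_library_cc and tf_proto_library_py, so consumers depend on the derived ":example_proto_cc" and ":example_proto_py" targets:

    tf_proto_library(
        name = "example_proto",
        srcs = ["example.proto"],
        has_services = 1,      # forwarded to the py rule as use_grpc_plugin
        cc_grpc_version = 1,   # any truthy value enables the grpc C++ plugin
        visibility = ["//visibility:public"],
    )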
# A list of all files under platform matching the pattern in 'files'. In
# contrast with 'tf_platform_srcs' below, which selectively collects files that
# must be compiled in the 'default' platform, this is a list of all headers
# mentioned in the platform/* files.
def tf_platform_hdrs(files):
- return native.glob(["platform/*/" + f for f in files])
+ return native.glob(["platform/*/" + f for f in files])
def tf_platform_srcs(files):
- base_set = ["platform/default/" + f for f in files]
- windows_set = base_set + ["platform/windows/" + f for f in files]
- posix_set = base_set + ["platform/posix/" + f for f in files]
-
- # Handle cases where we must also bring the posix file in. Usually, the list
- # of files to build on windows builds is just all the stuff in the
- # windows_set. However, in some cases the implementations in 'posix/' are
- # just what is necessary and historically we choose to simply use the posix
- # file instead of making a copy in 'windows'.
- for f in files:
- if f == "error.cc":
- windows_set.append("platform/posix/" + f)
-
- return select({
- "//tensorflow:windows" : native.glob(windows_set),
- "//conditions:default" : native.glob(posix_set),
- })
+ base_set = ["platform/default/" + f for f in files]
+ windows_set = base_set + ["platform/windows/" + f for f in files]
+ posix_set = base_set + ["platform/posix/" + f for f in files]
+
+ # Handle cases where we must also bring the posix file in. Usually, the list
+ # of files to build on windows builds is just all the stuff in the
+ # windows_set. However, in some cases the implementations in 'posix/' are
+ # just what is necessary and historically we choose to simply use the posix
+ # file instead of making a copy in 'windows'.
+ for f in files:
+ if f == "error.cc":
+ windows_set.append("platform/posix/" + f)
+
+ return select({
+ "//tensorflow:windows": native.glob(windows_set),
+ "//conditions:default": native.glob(posix_set),
+ })
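To make the select() concrete, tf_platform_srcs(["subprocess.cc"]) would expand to roughly the following (illustrative file name; only error.cc also pulls its posix variant into the windows branch):

    select({
        "//tensorflow:windows": native.glob([
            "platform/default/subprocess.cc",
            "platform/windows/subprocess.cc",
        ]),
        "//conditions:default": native.glob([
            "platform/default/subprocess.cc",
            "platform/posix/subprocess.cc",
        ]),
    })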
def tf_additional_lib_hdrs(exclude = []):
- windows_hdrs = native.glob([
- "platform/default/*.h",
- "platform/windows/*.h",
- "platform/posix/error.h",
- ], exclude = exclude)
- return select({
- "//tensorflow:windows" : windows_hdrs,
- "//conditions:default" : native.glob([
+ windows_hdrs = native.glob([
"platform/default/*.h",
- "platform/posix/*.h",
- ], exclude = exclude),
- })
+ "platform/windows/*.h",
+ "platform/posix/error.h",
+ ], exclude = exclude)
+ return select({
+ "//tensorflow:windows": windows_hdrs,
+ "//conditions:default": native.glob([
+ "platform/default/*.h",
+ "platform/posix/*.h",
+ ], exclude = exclude),
+ })
def tf_additional_lib_srcs(exclude = []):
- windows_srcs = native.glob([
- "platform/default/*.cc",
- "platform/windows/*.cc",
- "platform/posix/error.cc",
- ], exclude = exclude)
- return select({
- "//tensorflow:windows" : windows_srcs,
- "//conditions:default" : native.glob([
+ windows_srcs = native.glob([
"platform/default/*.cc",
- "platform/posix/*.cc",
- ], exclude = exclude),
- })
+ "platform/windows/*.cc",
+ "platform/posix/error.cc",
+ ], exclude = exclude)
+ return select({
+ "//tensorflow:windows": windows_srcs,
+ "//conditions:default": native.glob([
+ "platform/default/*.cc",
+ "platform/posix/*.cc",
+ ], exclude = exclude),
+ })
def tf_additional_minimal_lib_srcs():
- return [
- "platform/default/integral_types.h",
- "platform/default/mutex.h",
- ]
+ return [
+ "platform/default/integral_types.h",
+ "platform/default/mutex.h",
+ ]
def tf_additional_proto_hdrs():
- return [
- "platform/default/integral_types.h",
- "platform/default/logging.h",
- "platform/default/protobuf.h"
- ] + if_windows([
- "platform/windows/integral_types.h",
- ])
+ return [
+ "platform/default/integral_types.h",
+ "platform/default/logging.h",
+ "platform/default/protobuf.h",
+ ] + if_windows([
+ "platform/windows/integral_types.h",
+ ])
def tf_additional_proto_compiler_hdrs():
- return [
- "platform/default/protobuf_compiler.h"
- ]
+ return [
+ "platform/default/protobuf_compiler.h",
+ ]
def tf_additional_proto_srcs():
- return [
- "platform/default/protobuf.cc",
- ]
+ return [
+ "platform/default/protobuf.cc",
+ ]
def tf_additional_human_readable_json_deps():
- return []
+ return []
def tf_additional_all_protos():
- return ["//tensorflow/core:protos_all"]
+ return ["//tensorflow/core:protos_all"]
def tf_protos_all_impl():
- return ["//tensorflow/core:protos_all_cc_impl"]
+ return ["//tensorflow/core:protos_all_cc_impl"]
def tf_protos_all():
- return if_static(
- extra_deps=tf_protos_all_impl(),
- otherwise=["//tensorflow/core:protos_all_cc"])
+ return if_static(
+ extra_deps = tf_protos_all_impl(),
+ otherwise = ["//tensorflow/core:protos_all_cc"],
+ )
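The if_static() pattern here (and in tf_protos_grappler below) selects the monolithic _cc_impl target for statically linked builds and the header-only _cc target otherwise. A hypothetical consumer, assuming the macro lives in the usual build_config.bzl:

    load("//tensorflow/core/platform/default:build_config.bzl", "tf_protos_all")

    cc_library(
        name = "uses_protos",       # hypothetical target
        srcs = ["uses_protos.cc"],
        deps = tf_protos_all(),     # protos_all_cc_impl when static, protos_all_cc otherwise
    )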
def tf_protos_grappler_impl():
- return ["//tensorflow/core/grappler/costs:op_performance_data_cc_impl"]
+ return ["//tensorflow/core/grappler/costs:op_performance_data_cc_impl"]
def tf_protos_grappler():
- return if_static(
- extra_deps=tf_protos_grappler_impl(),
- otherwise=["//tensorflow/core/grappler/costs:op_performance_data_cc"])
+ return if_static(
+ extra_deps = tf_protos_grappler_impl(),
+ otherwise = ["//tensorflow/core/grappler/costs:op_performance_data_cc"],
+ )
def tf_additional_cupti_wrapper_deps():
- return ["//tensorflow/core/platform/default/gpu:cupti_wrapper"]
+ return ["//tensorflow/core/platform/default/gpu:cupti_wrapper"]
def tf_additional_device_tracer_srcs():
- return ["platform/default/device_tracer.cc"]
+ return ["platform/default/device_tracer.cc"]
def tf_additional_device_tracer_cuda_deps():
- return []
+ return []
def tf_additional_device_tracer_deps():
- return []
+ return []
def tf_additional_libdevice_data():
- return []
+ return []
def tf_additional_libdevice_deps():
- return ["@local_config_cuda//cuda:cuda_headers"]
+ return ["@local_config_cuda//cuda:cuda_headers"]
def tf_additional_libdevice_srcs():
- return ["platform/default/cuda_libdevice_path.cc"]
+ return ["platform/default/cuda_libdevice_path.cc"]
def tf_additional_test_deps():
- return []
+ return []
def tf_additional_test_srcs():
- return [
- "platform/default/test_benchmark.cc",
- ] + select({
- "//tensorflow:windows" : [
- "platform/windows/test.cc"
+ return [
+ "platform/default/test_benchmark.cc",
+ ] + select({
+ "//tensorflow:windows": [
+ "platform/windows/test.cc",
],
- "//conditions:default" : [
- "platform/posix/test.cc",
+ "//conditions:default": [
+ "platform/posix/test.cc",
],
})
def tf_kernel_tests_linkstatic():
- return 0
+ return 0
def tf_additional_lib_defines():
- """Additional defines needed to build TF libraries."""
- return select({
- "//tensorflow:with_jemalloc_linux_x86_64": ["TENSORFLOW_USE_JEMALLOC"],
- "//tensorflow:with_jemalloc_linux_ppc64le":["TENSORFLOW_USE_JEMALLOC"],
- "//conditions:default": [],
- }) + if_not_mobile(["TENSORFLOW_USE_ABSL"])
+ """Additional defines needed to build TF libraries."""
+ return select({
+ "//tensorflow:with_jemalloc_linux_x86_64": ["TENSORFLOW_USE_JEMALLOC"],
+ "//tensorflow:with_jemalloc_linux_ppc64le": ["TENSORFLOW_USE_JEMALLOC"],
+ "//conditions:default": [],
+ })
def tf_additional_lib_deps():
- """Additional dependencies needed to build TF libraries."""
- return if_not_mobile(["@com_google_absl//absl/base:base"]) + if_static(
- ["@nsync//:nsync_cpp"],
- ["@nsync//:nsync_headers"]
- ) + select({
- "//tensorflow:with_jemalloc_linux_x86_64_dynamic": ["@jemalloc//:jemalloc_headers"],
- "//tensorflow:with_jemalloc_linux_ppc64le_dynamic": ["@jemalloc//:jemalloc_headers"],
- "//tensorflow:with_jemalloc_linux_x86_64": ["@jemalloc//:jemalloc_impl"],
- "//tensorflow:with_jemalloc_linux_ppc64le": ["@jemalloc//:jemalloc_impl"],
- "//conditions:default": [],
- })
+ """Additional dependencies needed to build TF libraries."""
+ return [
+ "@com_google_absl//absl/base:base",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/types:span",
+ "@com_google_absl//absl/types:optional",
+ ] + if_static(
+ ["@nsync//:nsync_cpp"],
+ ["@nsync//:nsync_headers"],
+ ) + select({
+ "//tensorflow:with_jemalloc_linux_x86_64_dynamic": ["@jemalloc//:jemalloc_headers"],
+ "//tensorflow:with_jemalloc_linux_ppc64le_dynamic": ["@jemalloc//:jemalloc_headers"],
+ "//tensorflow:with_jemalloc_linux_x86_64": ["@jemalloc//:jemalloc_impl"],
+ "//tensorflow:with_jemalloc_linux_ppc64le": ["@jemalloc//:jemalloc_impl"],
+ "//conditions:default": [],
+ })
def tf_additional_core_deps():
- return select({
- "//tensorflow:with_gcp_support_android_override": [],
- "//tensorflow:with_gcp_support_ios_override": [],
- "//tensorflow:with_gcp_support": [
- "//tensorflow/core/platform/cloud:gcs_file_system",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_hdfs_support_windows_override": [],
- "//tensorflow:with_hdfs_support_android_override": [],
- "//tensorflow:with_hdfs_support_ios_override": [],
- "//tensorflow:with_hdfs_support": [
- "//tensorflow/core/platform/hadoop:hadoop_file_system",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_aws_support_windows_override": [],
- "//tensorflow:with_aws_support_android_override": [],
- "//tensorflow:with_aws_support_ios_override": [],
- "//tensorflow:with_aws_support": [
- "//tensorflow/core/platform/s3:s3_file_system",
- ],
- "//conditions:default": [],
- })
+ return select({
+ "//tensorflow:with_gcp_support_android_override": [],
+ "//tensorflow:with_gcp_support_ios_override": [],
+ "//tensorflow:with_gcp_support": [
+ "//tensorflow/core/platform/cloud:gcs_file_system",
+ ],
+ "//conditions:default": [],
+ }) + select({
+ "//tensorflow:with_hdfs_support_windows_override": [],
+ "//tensorflow:with_hdfs_support_android_override": [],
+ "//tensorflow:with_hdfs_support_ios_override": [],
+ "//tensorflow:with_hdfs_support": [
+ "//tensorflow/core/platform/hadoop:hadoop_file_system",
+ ],
+ "//conditions:default": [],
+ }) + select({
+ "//tensorflow:with_aws_support_windows_override": [],
+ "//tensorflow:with_aws_support_android_override": [],
+ "//tensorflow:with_aws_support_ios_override": [],
+ "//tensorflow:with_aws_support": [
+ "//tensorflow/core/platform/s3:s3_file_system",
+ ],
+ "//conditions:default": [],
+ })
# TODO(jart, jhseu): Delete when GCP is default on.
def tf_additional_cloud_op_deps():
- return select({
- "//tensorflow:with_gcp_support_windows_override": [],
- "//tensorflow:with_gcp_support_android_override": [],
- "//tensorflow:with_gcp_support_ios_override": [],
- "//tensorflow:with_gcp_support": [
- "//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib",
- "//tensorflow/contrib/cloud:gcs_config_ops_op_lib",
- ],
- "//conditions:default": [],
- })
+ return select({
+ "//tensorflow:with_gcp_support_windows_override": [],
+ "//tensorflow:with_gcp_support_android_override": [],
+ "//tensorflow:with_gcp_support_ios_override": [],
+ "//tensorflow:with_gcp_support": [
+ "//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib",
+ "//tensorflow/contrib/cloud:gcs_config_ops_op_lib",
+ ],
+ "//conditions:default": [],
+ })
# TODO(jart, jhseu): Delete when GCP is default on.
def tf_additional_cloud_kernel_deps():
- return select({
- "//tensorflow:with_gcp_support_windows_override": [],
- "//tensorflow:with_gcp_support_android_override": [],
- "//tensorflow:with_gcp_support_ios_override": [],
- "//tensorflow:with_gcp_support": [
- "//tensorflow/contrib/cloud/kernels:bigquery_reader_ops",
- "//tensorflow/contrib/cloud/kernels:gcs_config_ops",
- ],
- "//conditions:default": [],
- })
+ return select({
+ "//tensorflow:with_gcp_support_windows_override": [],
+ "//tensorflow:with_gcp_support_android_override": [],
+ "//tensorflow:with_gcp_support_ios_override": [],
+ "//tensorflow:with_gcp_support": [
+ "//tensorflow/contrib/cloud/kernels:bigquery_reader_ops",
+ "//tensorflow/contrib/cloud/kernels:gcs_config_ops",
+ ],
+ "//conditions:default": [],
+ })
def tf_lib_proto_parsing_deps():
- return [
- ":protos_all_cc",
- "//third_party/eigen3",
- "//tensorflow/core/platform/default/build_config:proto_parsing",
- ]
+ return [
+ ":protos_all_cc",
+ "//third_party/eigen3",
+ "//tensorflow/core/platform/default/build_config:proto_parsing",
+ ]
def tf_lib_proto_compiler_deps():
- return [
- "@protobuf_archive//:protoc_lib",
- ]
+ return [
+ "@protobuf_archive//:protoc_lib",
+ ]
def tf_additional_verbs_lib_defines():
- return select({
- "//tensorflow:with_verbs_support": ["TENSORFLOW_USE_VERBS"],
- "//conditions:default": [],
- })
+ return select({
+ "//tensorflow:with_verbs_support": ["TENSORFLOW_USE_VERBS"],
+ "//conditions:default": [],
+ })
def tf_additional_mpi_lib_defines():
- return select({
- "//tensorflow:with_mpi_support": ["TENSORFLOW_USE_MPI"],
- "//conditions:default": [],
- })
+ return select({
+ "//tensorflow:with_mpi_support": ["TENSORFLOW_USE_MPI"],
+ "//conditions:default": [],
+ })
def tf_additional_gdr_lib_defines():
- return select({
- "//tensorflow:with_gdr_support": ["TENSORFLOW_USE_GDR"],
- "//conditions:default": [],
- })
+ return select({
+ "//tensorflow:with_gdr_support": ["TENSORFLOW_USE_GDR"],
+ "//conditions:default": [],
+ })
-def tf_py_clif_cc(name, visibility=None, **kwargs):
- pass
+def tf_py_clif_cc(name, visibility = None, **kwargs):
+ pass
-def tf_pyclif_proto_library(name, proto_lib, proto_srcfile="", visibility=None,
- **kwargs):
- pass
+def tf_pyclif_proto_library(
+ name,
+ proto_lib,
+ proto_srcfile = "",
+ visibility = None,
+ **kwargs):
+ pass
def tf_additional_binary_deps():
- return ["@nsync//:nsync_cpp"] + if_cuda(
- [
- "//tensorflow/stream_executor:cuda_platform",
- "//tensorflow/core/platform/default/build_config:cuda",
- ],
- ) + select({
- "//tensorflow:with_jemalloc_linux_x86_64": ["@jemalloc//:jemalloc_impl"],
- "//tensorflow:with_jemalloc_linux_ppc64le": ["@jemalloc//:jemalloc_impl"],
- "//conditions:default": [],
- }) + [
- # TODO(allenl): Split these out into their own shared objects (they are
- # here because they are shared between contrib/ op shared objects and
- # core).
- "//tensorflow/core/kernels:lookup_util",
- "//tensorflow/core/util/tensor_bundle",
- ] + if_mkl_ml(
- [
- "//third_party/intel_mkl_ml",
- ],
- )
+ return ["@nsync//:nsync_cpp"] + if_cuda(
+ [
+ "//tensorflow/stream_executor:cuda_platform",
+ "//tensorflow/core/platform/default/build_config:cuda",
+ ],
+ ) + select({
+ "//tensorflow:with_jemalloc_linux_x86_64": ["@jemalloc//:jemalloc_impl"],
+ "//tensorflow:with_jemalloc_linux_ppc64le": ["@jemalloc//:jemalloc_impl"],
+ "//conditions:default": [],
+ }) + [
+ # TODO(allenl): Split these out into their own shared objects (they are
+ # here because they are shared between contrib/ op shared objects and
+ # core).
+ "//tensorflow/core/kernels:lookup_util",
+ "//tensorflow/core/util/tensor_bundle",
+ ] + if_mkl_ml(
+ [
+ "//third_party/mkl:intel_binary_blob",
+ ],
+ )
diff --git a/tensorflow/core/platform/default/device_tracer.cc b/tensorflow/core/platform/default/device_tracer.cc
index ccddf1eafc..0389149469 100644
--- a/tensorflow/core/platform/default/device_tracer.cc
+++ b/tensorflow/core/platform/default/device_tracer.cc
@@ -321,6 +321,11 @@ class DeviceTracerImpl : public DeviceTracer,
return nullptr;
}
+ bool IsEnabled(bool is_expensive) const override {
+ // We don't do anything with 'Activities' so we are never 'enabled'.
+ return false;
+ }
+
protected:
// This callback is used exclusively by CUPTIManager.
friend class CUPTIManager;
diff --git a/tensorflow/core/platform/default/integral_types.h b/tensorflow/core/platform/default/integral_types.h
index 7cbe7d62f7..92186bc912 100644
--- a/tensorflow/core/platform/default/integral_types.h
+++ b/tensorflow/core/platform/default/integral_types.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_DEFAULT_INTEGRAL_TYPES_H_
-#define TENSORFLOW_PLATFORM_DEFAULT_INTEGRAL_TYPES_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_INTEGRAL_TYPES_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_INTEGRAL_TYPES_H_
// IWYU pragma: private, include "third_party/tensorflow/core/platform/types.h"
// IWYU pragma: friend third_party/tensorflow/core/platform/types.h
@@ -33,4 +33,4 @@ typedef unsigned long long uint64;
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_DEFAULT_INTEGRAL_TYPES_H_
+#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_INTEGRAL_TYPES_H_
diff --git a/tensorflow/core/platform/default/logging.h b/tensorflow/core/platform/default/logging.h
index 2c134f1be9..08a692fff7 100644
--- a/tensorflow/core/platform/default/logging.h
+++ b/tensorflow/core/platform/default/logging.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_DEFAULT_LOGGING_H_
-#define TENSORFLOW_PLATFORM_DEFAULT_LOGGING_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_LOGGING_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_LOGGING_H_
// IWYU pragma: private, include "third_party/tensorflow/core/platform/logging.h"
// IWYU pragma: friend third_party/tensorflow/core/platform/logging.h
@@ -314,4 +314,4 @@ int64 MinVLogLevelFromEnv();
} // namespace internal
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_DEFAULT_LOGGING_H_
+#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_LOGGING_H_
diff --git a/tensorflow/core/platform/default/mutex.h b/tensorflow/core/platform/default/mutex.h
index 48d90779e1..bef7801037 100644
--- a/tensorflow/core/platform/default/mutex.h
+++ b/tensorflow/core/platform/default/mutex.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_DEFAULT_MUTEX_H_
-#define TENSORFLOW_PLATFORM_DEFAULT_MUTEX_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_MUTEX_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_MUTEX_H_
// IWYU pragma: private, include "third_party/tensorflow/core/platform/mutex.h"
// IWYU pragma: friend third_party/tensorflow/core/platform/mutex.h
@@ -173,4 +173,4 @@ inline ConditionResult WaitForMilliseconds(mutex_lock* mu,
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_DEFAULT_MUTEX_H_
+#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_MUTEX_H_
diff --git a/tensorflow/core/platform/default/thread_annotations.h b/tensorflow/core/platform/default/thread_annotations.h
index a6aa5b1b5e..d21d60ab0b 100644
--- a/tensorflow/core/platform/default/thread_annotations.h
+++ b/tensorflow/core/platform/default/thread_annotations.h
@@ -32,8 +32,8 @@ limitations under the License.
// (e.g. &MyClass::mutex_) to refer to a mutex in some (unknown) object.
//
-#ifndef TENSORFLOW_PLATFORM_DEFAULT_THREAD_ANNOTATIONS_H_
-#define TENSORFLOW_PLATFORM_DEFAULT_THREAD_ANNOTATIONS_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_THREAD_ANNOTATIONS_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_THREAD_ANNOTATIONS_H_
// IWYU pragma: private, include "third_party/tensorflow/core/platform/thread_annotations.h"
// IWYU pragma: friend third_party/tensorflow/core/platform/thread_annotations.h
@@ -174,4 +174,4 @@ inline T& ts_unchecked_read(T& v) NO_THREAD_SAFETY_ANALYSIS {
} // namespace thread_safety_analysis
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_DEFAULT_THREAD_ANNOTATIONS_H_
+#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_THREAD_ANNOTATIONS_H_
diff --git a/tensorflow/core/platform/default/tracing_impl.h b/tensorflow/core/platform/default/tracing_impl.h
index b161378405..b7a5f1386c 100644
--- a/tensorflow/core/platform/default/tracing_impl.h
+++ b/tensorflow/core/platform/default/tracing_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_DEFAULT_TRACING_IMPL_H_
-#define TENSORFLOW_PLATFORM_DEFAULT_TRACING_IMPL_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_TRACING_IMPL_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_TRACING_IMPL_H_
// Stub implementations of tracing functionality.
@@ -43,4 +43,4 @@ inline bool EventCollector::IsEnabled() { return false; }
} // namespace tracing
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_DEFAULT_TRACING_IMPL_H_
+#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_TRACING_IMPL_H_
diff --git a/tensorflow/core/platform/denormal.h b/tensorflow/core/platform/denormal.h
index 09bb0352a2..555ac023db 100644
--- a/tensorflow/core/platform/denormal.h
+++ b/tensorflow/core/platform/denormal.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_DENORMAL_H_
-#define TENSORFLOW_PLATFORM_DENORMAL_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_DENORMAL_H_
+#define TENSORFLOW_CORE_PLATFORM_DENORMAL_H_
#include "tensorflow/core/platform/macros.h"
@@ -59,4 +59,4 @@ class ScopedDontFlushDenormal {
} // namespace port
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_DENORMAL_H_
+#endif // TENSORFLOW_CORE_PLATFORM_DENORMAL_H_
diff --git a/tensorflow/core/platform/dynamic_annotations.h b/tensorflow/core/platform/dynamic_annotations.h
index f51f3f33a3..dad0d0f4e4 100644
--- a/tensorflow/core/platform/dynamic_annotations.h
+++ b/tensorflow/core/platform/dynamic_annotations.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_DYNAMIC_ANNOTATIONS_H_
-#define TENSORFLOW_PLATFORM_DYNAMIC_ANNOTATIONS_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_DYNAMIC_ANNOTATIONS_H_
+#define TENSORFLOW_CORE_PLATFORM_DYNAMIC_ANNOTATIONS_H_
#include "tensorflow/core/platform/platform.h"
@@ -28,4 +28,4 @@ limitations under the License.
#error Define the appropriate PLATFORM_<foo> macro for this platform
#endif
-#endif // TENSORFLOW_PLATFORM_DYNAMIC_ANNOTATIONS_H_
+#endif // TENSORFLOW_CORE_PLATFORM_DYNAMIC_ANNOTATIONS_H_
diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc
index 47c59d435b..afc4201e53 100644
--- a/tensorflow/core/platform/env.cc
+++ b/tensorflow/core/platform/env.cc
@@ -92,7 +92,7 @@ Env::Env() : file_system_registry_(new FileSystemRegistryImpl) {}
Status Env::GetFileSystemForFile(const string& fname, FileSystem** result) {
StringPiece scheme, host, path;
io::ParseURI(fname, &scheme, &host, &path);
- FileSystem* file_system = file_system_registry_->Lookup(std::string(scheme));
+ FileSystem* file_system = file_system_registry_->Lookup(string(scheme));
if (!file_system) {
if (scheme.empty()) {
scheme = "[local]";
@@ -166,7 +166,7 @@ bool Env::FilesExist(const std::vector<string>& files,
for (const auto& file : files) {
StringPiece scheme, host, path;
io::ParseURI(file, &scheme, &host, &path);
- files_per_fs[std::string(scheme)].push_back(file);
+ files_per_fs[string(scheme)].push_back(file);
}
std::unordered_map<string, Status> per_file_status;
diff --git a/tensorflow/core/platform/file_system.cc b/tensorflow/core/platform/file_system.cc
index 922773684b..3ab542a5d8 100644
--- a/tensorflow/core/platform/file_system.cc
+++ b/tensorflow/core/platform/file_system.cc
@@ -158,7 +158,7 @@ Status FileSystem::RecursivelyCreateDir(const string& dirname) {
std::reverse(sub_dirs.begin(), sub_dirs.end());
// Now create the directories.
- string built_path = std::string(remaining_dir);
+ string built_path(remaining_dir);
for (const StringPiece sub_dir : sub_dirs) {
built_path = io::JoinPath(built_path, sub_dir);
Status status = CreateDir(io::CreateURI(scheme, host, built_path));
diff --git a/tensorflow/core/platform/file_system_helper.cc b/tensorflow/core/platform/file_system_helper.cc
index 0ba0e6304f..342cf28e38 100644
--- a/tensorflow/core/platform/file_system_helper.cc
+++ b/tensorflow/core/platform/file_system_helper.cc
@@ -59,7 +59,7 @@ Status GetMatchingPaths(FileSystem* fs, Env* env, const string& pattern,
string fixed_prefix = pattern.substr(0, pattern.find_first_of("*?[\\"));
string eval_pattern = pattern;
std::vector<string> all_files;
- string dir = std::string(io::Dirname(fixed_prefix));
+ string dir(io::Dirname(fixed_prefix));
// If dir is empty then we need to fix up fixed_prefix and eval_pattern to
// include . as the top level directory.
if (dir.empty()) {
diff --git a/tensorflow/core/platform/file_system_test.cc b/tensorflow/core/platform/file_system_test.cc
index c0a16c95f9..a637d42a92 100644
--- a/tensorflow/core/platform/file_system_test.cc
+++ b/tensorflow/core/platform/file_system_test.cc
@@ -125,7 +125,7 @@ class InterPlanetaryFileSystem : public NullFileSystem {
ASSERT_EQ(scheme, "ipfs");
ASSERT_EQ(host, "solarsystem");
str_util::ConsumePrefix(&path, "/");
- *parsed_path = std::string(path);
+ *parsed_path = string(path);
}
std::map<string, std::set<string>> celestial_bodies_ = {
diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
index ff4b4436bb..8cdb08f51b 100644
--- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc
+++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
@@ -144,7 +144,7 @@ Status HadoopFileSystem::Connect(StringPiece fname, hdfsFS* fs) {
StringPiece scheme, namenode, path;
io::ParseURI(fname, &scheme, &namenode, &path);
- const string nn = namenode.ToString();
+ const string nn(namenode);
hdfsBuilder* builder = hdfs_->hdfsNewBuilder();
if (scheme == "file") {
@@ -183,7 +183,7 @@ Status HadoopFileSystem::Connect(StringPiece fname, hdfsFS* fs) {
string HadoopFileSystem::TranslateName(const string& name) const {
StringPiece scheme, namenode, path;
io::ParseURI(name, &scheme, &namenode, &path);
- return path.ToString();
+ return string(path);
}
class HDFSRandomAccessFile : public RandomAccessFile {
@@ -392,7 +392,7 @@ Status HadoopFileSystem::GetChildren(const string& dir,
return IOError(dir, errno);
}
for (int i = 0; i < entries; i++) {
- result->push_back(io::Basename(info[i].mName).ToString());
+ result->push_back(string(io::Basename(info[i].mName)));
}
hdfs_->hdfsFreeFileInfo(info, entries);
return Status::OK();
diff --git a/tensorflow/core/platform/host_info.h b/tensorflow/core/platform/host_info.h
index 6124c95923..e76b83adf3 100644
--- a/tensorflow/core/platform/host_info.h
+++ b/tensorflow/core/platform/host_info.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_HOST_INFO_H_
-#define TENSORFLOW_PLATFORM_HOST_INFO_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_HOST_INFO_H_
+#define TENSORFLOW_CORE_PLATFORM_HOST_INFO_H_
#include "tensorflow/core/platform/types.h"
@@ -27,4 +27,4 @@ string Hostname();
} // namespace port
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_HOST_INFO_H_
+#endif // TENSORFLOW_CORE_PLATFORM_HOST_INFO_H_
diff --git a/tensorflow/core/platform/init_main.h b/tensorflow/core/platform/init_main.h
index 20cbc615b1..834c529816 100644
--- a/tensorflow/core/platform/init_main.h
+++ b/tensorflow/core/platform/init_main.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_INIT_MAIN_H_
-#define TENSORFLOW_PLATFORM_INIT_MAIN_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_INIT_MAIN_H_
+#define TENSORFLOW_CORE_PLATFORM_INIT_MAIN_H_
namespace tensorflow {
namespace port {
@@ -28,4 +28,4 @@ void InitMain(const char* usage, int* argc, char*** argv);
} // namespace port
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_INIT_MAIN_H_
+#endif // TENSORFLOW_CORE_PLATFORM_INIT_MAIN_H_
diff --git a/tensorflow/core/platform/load_library.h b/tensorflow/core/platform/load_library.h
index 9038de25f3..c7eeb2918c 100644
--- a/tensorflow/core/platform/load_library.h
+++ b/tensorflow/core/platform/load_library.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_LOAD_LIBRARY_H_
-#define TENSORFLOW_PLATFORM_LOAD_LIBRARY_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_LOAD_LIBRARY_H_
+#define TENSORFLOW_CORE_PLATFORM_LOAD_LIBRARY_H_
#include "tensorflow/core/lib/core/status.h"
@@ -31,4 +31,4 @@ string FormatLibraryFileName(const string& name, const string& version);
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_LOAD_LIBRARY_H_
+#endif // TENSORFLOW_CORE_PLATFORM_LOAD_LIBRARY_H_
diff --git a/tensorflow/core/platform/logging.h b/tensorflow/core/platform/logging.h
index 985c061676..17a5d5fb5b 100644
--- a/tensorflow/core/platform/logging.h
+++ b/tensorflow/core/platform/logging.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_LOGGING_H_
-#define TENSORFLOW_PLATFORM_LOGGING_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_LOGGING_H_
+#define TENSORFLOW_CORE_PLATFORM_LOGGING_H_
#include "tensorflow/core/platform/platform.h" // To pick up PLATFORM_define
@@ -36,4 +36,4 @@ void LogString(const char* fname, int line, int severity,
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_LOGGING_H_
+#endif // TENSORFLOW_CORE_PLATFORM_LOGGING_H_
diff --git a/tensorflow/core/platform/macros.h b/tensorflow/core/platform/macros.h
index b65eb43146..e1d83e18ac 100644
--- a/tensorflow/core/platform/macros.h
+++ b/tensorflow/core/platform/macros.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_MACROS_H_
-#define TENSORFLOW_PLATFORM_MACROS_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_MACROS_H_
+#define TENSORFLOW_CORE_PLATFORM_MACROS_H_
// Compiler attributes
#if (defined(__GNUC__) || defined(__APPLE__)) && !defined(SWIG)
@@ -125,4 +125,4 @@ limitations under the License.
} while (0)
#endif
-#endif // TENSORFLOW_PLATFORM_MACROS_H_
+#endif // TENSORFLOW_CORE_PLATFORM_MACROS_H_
diff --git a/tensorflow/core/platform/mem.h b/tensorflow/core/platform/mem.h
index fca3a2332d..e8150f7322 100644
--- a/tensorflow/core/platform/mem.h
+++ b/tensorflow/core/platform/mem.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_MEM_H_
-#define TENSORFLOW_PLATFORM_MEM_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_MEM_H_
+#define TENSORFLOW_CORE_PLATFORM_MEM_H_
// TODO(cwhipkey): remove this when callers use annotations directly.
#include "tensorflow/core/platform/dynamic_annotations.h"
@@ -65,4 +65,4 @@ int64 AvailableRam();
} // namespace port
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_MEM_H_
+#endif // TENSORFLOW_CORE_PLATFORM_MEM_H_
diff --git a/tensorflow/core/platform/mutex.h b/tensorflow/core/platform/mutex.h
index 42d46ceb5b..66b20da95a 100644
--- a/tensorflow/core/platform/mutex.h
+++ b/tensorflow/core/platform/mutex.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_MUTEX_H_
-#define TENSORFLOW_PLATFORM_MUTEX_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_MUTEX_H_
+#define TENSORFLOW_CORE_PLATFORM_MUTEX_H_
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/types.h"
@@ -50,4 +50,4 @@ ConditionResult WaitForMilliseconds(mutex_lock* mu, condition_variable* cv,
int64 ms);
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_MUTEX_H_
+#endif // TENSORFLOW_CORE_PLATFORM_MUTEX_H_
diff --git a/tensorflow/core/platform/net.h b/tensorflow/core/platform/net.h
index 9e7851728d..7dbc92f058 100644
--- a/tensorflow/core/platform/net.h
+++ b/tensorflow/core/platform/net.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_NET_H_
-#define TENSORFLOW_PLATFORM_NET_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_NET_H_
+#define TENSORFLOW_CORE_PLATFORM_NET_H_
namespace tensorflow {
namespace internal {
@@ -24,4 +24,4 @@ int PickUnusedPortOrDie();
} // namespace internal
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_NET_H_
+#endif // TENSORFLOW_CORE_PLATFORM_NET_H_
diff --git a/tensorflow/core/platform/png.h b/tensorflow/core/platform/png.h
index b110d63aba..93b1425f7a 100644
--- a/tensorflow/core/platform/png.h
+++ b/tensorflow/core/platform/png.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_PNG_H_
-#define TENSORFLOW_PLATFORM_PNG_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_PNG_H_
+#define TENSORFLOW_CORE_PLATFORM_PNG_H_
#include "tensorflow/core/platform/platform.h"
@@ -27,4 +27,4 @@ limitations under the License.
#error Define the appropriate PLATFORM_<foo> macro for this platform
#endif
-#endif // TENSORFLOW_PLATFORM_PNG_H_
+#endif // TENSORFLOW_CORE_PLATFORM_PNG_H_
diff --git a/tensorflow/core/platform/posix/error.h b/tensorflow/core/platform/posix/error.h
index 9b614d0f70..9df5f2daa1 100644
--- a/tensorflow/core/platform/posix/error.h
+++ b/tensorflow/core/platform/posix/error.h
@@ -24,4 +24,4 @@ Status IOError(const string& context, int err_number);
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_PLATFORM_POSIX_POSIX_FILE_SYSTEM_H_
+#endif // TENSORFLOW_CORE_PLATFORM_POSIX_ERROR_H_
diff --git a/tensorflow/core/platform/posix/port.cc b/tensorflow/core/platform/posix/port.cc
index 1939cf72fb..b46b9927cd 100644
--- a/tensorflow/core/platform/posix/port.cc
+++ b/tensorflow/core/platform/posix/port.cc
@@ -17,9 +17,7 @@ limitations under the License.
#include "jemalloc/jemalloc.h"
#endif
-#ifdef TENSORFLOW_USE_ABSL
#include "absl/base/internal/sysinfo.h"
-#endif
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"
@@ -194,11 +192,7 @@ bool Snappy_Uncompress(const char* input, size_t length, char* output) {
string Demangle(const char* mangled) { return mangled; }
double NominalCPUFrequency() {
-#ifdef TENSORFLOW_USE_ABSL
return absl::base_internal::NominalCPUFrequency();
-#else
- return 1.0;
-#endif
}
int64 AvailableRam() {
diff --git a/tensorflow/core/platform/posix/posix_file_system.h b/tensorflow/core/platform/posix/posix_file_system.h
index e8898d0a97..752eccea66 100644
--- a/tensorflow/core/platform/posix/posix_file_system.h
+++ b/tensorflow/core/platform/posix/posix_file_system.h
@@ -70,7 +70,7 @@ class LocalPosixFileSystem : public PosixFileSystem {
string TranslateName(const string& name) const override {
StringPiece scheme, host, path;
io::ParseURI(name, &scheme, &host, &path);
- return path.ToString();
+ return string(path);
}
};
diff --git a/tensorflow/core/platform/posix/subprocess.h b/tensorflow/core/platform/posix/subprocess.h
index 53f95f3c14..9740d75595 100644
--- a/tensorflow/core/platform/posix/subprocess.h
+++ b/tensorflow/core/platform/posix/subprocess.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_DEFAULT_SUBPROCESS_H_
-#define TENSORFLOW_PLATFORM_DEFAULT_SUBPROCESS_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_POSIX_SUBPROCESS_H_
+#define TENSORFLOW_CORE_PLATFORM_POSIX_SUBPROCESS_H_
#include <errno.h>
#include <unistd.h>
@@ -128,4 +128,4 @@ class SubProcess {
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_DEFAULT_SUBPROCESS_H_
+#endif // TENSORFLOW_CORE_PLATFORM_POSIX_SUBPROCESS_H_
diff --git a/tensorflow/core/platform/prefetch.h b/tensorflow/core/platform/prefetch.h
index 81e1a5210a..9cefab3c1b 100644
--- a/tensorflow/core/platform/prefetch.h
+++ b/tensorflow/core/platform/prefetch.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_PREFETCH_H_
-#define TENSORFLOW_PLATFORM_PREFETCH_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_PREFETCH_H_
+#define TENSORFLOW_CORE_PLATFORM_PREFETCH_H_
#include "tensorflow/core/platform/platform.h"
@@ -56,4 +56,4 @@ inline void prefetch(const void* x) {
} // namespace port
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_PREFETCH_H_
+#endif // TENSORFLOW_CORE_PLATFORM_PREFETCH_H_
diff --git a/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.h b/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.h
index ce2069b004..2d94736c97 100644
--- a/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.h
+++ b/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_PROFILEUTILS_ANDROID_ARMV7A_CPU_UTILS_HELPER_H__
-#define TENSORFLOW_PLATFORM_PROFILEUTILS_ANDROID_ARMV7A_CPU_UTILS_HELPER_H__
+#ifndef TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_ANDROID_ARMV7A_CPU_UTILS_HELPER_H_
+#define TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_ANDROID_ARMV7A_CPU_UTILS_HELPER_H_
#include <sys/types.h>
@@ -64,4 +64,4 @@ class AndroidArmV7ACpuUtilsHelper : public ICpuUtilsHelper {
#endif // defined(__ANDROID__) && (__ANDROID_API__ >= 21) &&
// (defined(__ARM_ARCH_7A__) || defined(__aarch64__))
-#endif // TENSORFLOW_PLATFORM_PROFILEUTILS_ANDROID_ARMV7A_CPU_UTILS_HELPER_H__
+#endif // TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_ANDROID_ARMV7A_CPU_UTILS_HELPER_H_
diff --git a/tensorflow/core/platform/profile_utils/clock_cycle_profiler.h b/tensorflow/core/platform/profile_utils/clock_cycle_profiler.h
index de4eec28e3..e25456374c 100644
--- a/tensorflow/core/platform/profile_utils/clock_cycle_profiler.h
+++ b/tensorflow/core/platform/profile_utils/clock_cycle_profiler.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_PROFILE_UTILS_CLOCK_CYCLE_PROFILER_H_
-#define TENSORFLOW_PLATFORM_PROFILE_UTILS_CLOCK_CYCLE_PROFILER_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_CLOCK_CYCLE_PROFILER_H_
+#define TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_CLOCK_CYCLE_PROFILER_H_
#include <algorithm>
@@ -103,4 +103,4 @@ class ClockCycleProfiler {
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_PROFILE_UTILS_CLOCK_CYCLE_PROFILER_H_
+#endif // TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_CLOCK_CYCLE_PROFILER_H_
diff --git a/tensorflow/core/platform/profile_utils/cpu_utils.h b/tensorflow/core/platform/profile_utils/cpu_utils.h
index 8f06290303..b0b1ef0363 100644
--- a/tensorflow/core/platform/profile_utils/cpu_utils.h
+++ b/tensorflow/core/platform/profile_utils/cpu_utils.h
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
// This class is designed to get accurate profile for programs.
-#ifndef TENSORFLOW_PLATFORM_PROFILEUTILS_CPU_UTILS_H__
-#define TENSORFLOW_PLATFORM_PROFILEUTILS_CPU_UTILS_H__
+#ifndef TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_CPU_UTILS_H_
+#define TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_CPU_UTILS_H_
#include <chrono>
#include <memory>
@@ -164,4 +164,4 @@ class CpuUtils {
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_PROFILEUTILS_CPU_UTILS_H__
+#endif // TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_CPU_UTILS_H_
diff --git a/tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h b/tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h
index 11b739c009..cab7618a70 100644
--- a/tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h
+++ b/tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_PROFILEUTILS_I_CPU_UTILS_HELPER_H__
-#define TENSORFLOW_PLATFORM_PROFILEUTILS_I_CPU_UTILS_HELPER_H__
+#ifndef TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_I_CPU_UTILS_HELPER_H_
+#define TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_I_CPU_UTILS_HELPER_H_
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -50,4 +50,4 @@ class ICpuUtilsHelper {
} // namespace profile_utils
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_PROFILEUTILS_I_CPU_UTILS_HELPER_H__
+#endif // TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_I_CPU_UTILS_HELPER_H_
diff --git a/tensorflow/core/platform/protobuf.h b/tensorflow/core/platform/protobuf.h
index 288d091624..fcbf1fc8c5 100644
--- a/tensorflow/core/platform/protobuf.h
+++ b/tensorflow/core/platform/protobuf.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_PROTOBUF_H_
-#define TENSORFLOW_PLATFORM_PROTOBUF_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_PROTOBUF_H_
+#define TENSORFLOW_CORE_PLATFORM_PROTOBUF_H_
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/types.h"
@@ -52,4 +52,4 @@ inline void SetProtobufStringSwapAllowed(string* src, string* dest) {
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_PROTOBUF_H_
+#endif // TENSORFLOW_CORE_PLATFORM_PROTOBUF_H_
diff --git a/tensorflow/core/platform/protobuf_internal.h b/tensorflow/core/platform/protobuf_internal.h
index 2f151a5aee..d0cfde09bc 100644
--- a/tensorflow/core/platform/protobuf_internal.h
+++ b/tensorflow/core/platform/protobuf_internal.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_PROTOBUF_INTERNAL_H_
-#define TENSORFLOW_PLATFORM_PROTOBUF_INTERNAL_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_PROTOBUF_INTERNAL_H_
+#define TENSORFLOW_CORE_PLATFORM_PROTOBUF_INTERNAL_H_
#include "google/protobuf/any.pb.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -69,4 +69,4 @@ Status ParseAny(const google::protobuf::Any& any, T* message,
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_PROTOBUF_INTERNAL_H_
+#endif // TENSORFLOW_CORE_PLATFORM_PROTOBUF_INTERNAL_H_
diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc
index 462113f9bb..ce0f6cd741 100644
--- a/tensorflow/core/platform/s3/s3_file_system.cc
+++ b/tensorflow/core/platform/s3/s3_file_system.cc
@@ -150,13 +150,13 @@ Status ParseS3Path(const string& fname, bool empty_object_ok, string* bucket,
return errors::InvalidArgument("S3 path doesn't start with 's3://': ",
fname);
}
- *bucket = bucketp.ToString();
+ *bucket = string(bucketp);
if (bucket->empty() || *bucket == ".") {
return errors::InvalidArgument("S3 path doesn't contain a bucket name: ",
fname);
}
str_util::ConsumePrefix(&objectp, "/");
- *object = objectp.ToString();
+ *object = string(objectp);
if (!empty_object_ok && object->empty()) {
return errors::InvalidArgument("S3 path doesn't contain an object name: ",
fname);
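
This hunk applies a conversion pattern repeated throughout the change: StringPiece::ToString() calls become explicit string construction, which keeps working once StringPiece is an alias for a string-view type with no ToString() member. A minimal standalone sketch, with std::string_view standing in for StringPiece:

    #include <string>
    #include <string_view>

    // Copies the viewed bytes into an owning string; this direct construction
    // replaces the old piece.ToString() member call.
    std::string ToOwned(std::string_view piece) {
      return std::string(piece);
    }
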
diff --git a/tensorflow/core/platform/setround.h b/tensorflow/core/platform/setround.h
index d076e7acc6..ded00b23b1 100644
--- a/tensorflow/core/platform/setround.h
+++ b/tensorflow/core/platform/setround.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_SETROUND_H_
-#define TENSORFLOW_PLATFORM_SETROUND_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_SETROUND_H_
+#define TENSORFLOW_CORE_PLATFORM_SETROUND_H_
#include <cfenv>
@@ -42,4 +42,4 @@ class ScopedSetRound {
} // namespace port
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_SETROUND_H_
+#endif // TENSORFLOW_CORE_PLATFORM_SETROUND_H_
diff --git a/tensorflow/core/platform/snappy.h b/tensorflow/core/platform/snappy.h
index 62c208ffb4..5477b097ef 100644
--- a/tensorflow/core/platform/snappy.h
+++ b/tensorflow/core/platform/snappy.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_SNAPPY_H_
-#define TENSORFLOW_PLATFORM_SNAPPY_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_SNAPPY_H_
+#define TENSORFLOW_CORE_PLATFORM_SNAPPY_H_
#include "tensorflow/core/platform/types.h"
@@ -31,4 +31,4 @@ bool Snappy_Uncompress(const char* input, size_t length, char* output);
} // namespace port
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_SNAPPY_H_
+#endif // TENSORFLOW_CORE_PLATFORM_SNAPPY_H_
diff --git a/tensorflow/core/platform/stacktrace_handler.h b/tensorflow/core/platform/stacktrace_handler.h
index a52970fdaa..9f118b91b8 100644
--- a/tensorflow/core/platform/stacktrace_handler.h
+++ b/tensorflow/core/platform/stacktrace_handler.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_PLATFORM_BACKTRACE_H_
-#define TENSORFLOW_CORE_PLATFORM_BACKTRACE_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_STACKTRACE_HANDLER_H_
+#define TENSORFLOW_CORE_PLATFORM_STACKTRACE_HANDLER_H_
namespace tensorflow {
namespace testing {
@@ -25,4 +25,4 @@ void InstallStacktraceHandler();
} // namespace testing
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_PLATFORM_BACKTRACE_H_
+#endif // TENSORFLOW_CORE_PLATFORM_STACKTRACE_HANDLER_H_
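
The renamed header declares tensorflow::testing::InstallStacktraceHandler(), visible in the context lines above. A hedged usage sketch (the surrounding main is illustrative, not taken from this change):

    #include "tensorflow/core/platform/stacktrace_handler.h"

    int main(int argc, char** argv) {
      // Dump a stack trace if the binary crashes during the run below.
      tensorflow::testing::InstallStacktraceHandler();
      // ... run tests or application code ...
      return 0;
    }
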
diff --git a/tensorflow/core/platform/subprocess.h b/tensorflow/core/platform/subprocess.h
index dcc0c1a4ee..7c11e6232f 100644
--- a/tensorflow/core/platform/subprocess.h
+++ b/tensorflow/core/platform/subprocess.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_SUBPROCESS_H_
-#define TENSORFLOW_PLATFORM_SUBPROCESS_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_SUBPROCESS_H_
+#define TENSORFLOW_CORE_PLATFORM_SUBPROCESS_H_
#include <memory>
#include <vector>
@@ -67,4 +67,4 @@ std::unique_ptr<SubProcess> CreateSubProcess(const std::vector<string>& argv);
#error Define the appropriate PLATFORM_<foo> macro for this platform
#endif
-#endif // TENSORFLOW_PLATFORM_SUBPROCESS_H_
+#endif // TENSORFLOW_CORE_PLATFORM_SUBPROCESS_H_
diff --git a/tensorflow/core/platform/test.h b/tensorflow/core/platform/test.h
index 99bae63edf..f5d3282f57 100644
--- a/tensorflow/core/platform/test.h
+++ b/tensorflow/core/platform/test.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_TEST_H_
-#define TENSORFLOW_PLATFORM_TEST_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_TEST_H_
+#define TENSORFLOW_CORE_PLATFORM_TEST_H_
#include <memory>
#include <vector>
@@ -55,4 +55,4 @@ int PickUnusedPortOrDie();
} // namespace testing
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_TEST_H_
+#endif // TENSORFLOW_CORE_PLATFORM_TEST_H_
diff --git a/tensorflow/core/platform/test_benchmark.h b/tensorflow/core/platform/test_benchmark.h
index 9b8726d98f..61fcd0d372 100644
--- a/tensorflow/core/platform/test_benchmark.h
+++ b/tensorflow/core/platform/test_benchmark.h
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
// Simple benchmarking facility.
-#ifndef TENSORFLOW_PLATFORM_TEST_BENCHMARK_H_
-#define TENSORFLOW_PLATFORM_TEST_BENCHMARK_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_TEST_BENCHMARK_H_
+#define TENSORFLOW_CORE_PLATFORM_TEST_BENCHMARK_H_
#include <utility>
#include <vector>
@@ -115,4 +115,4 @@ void UseRealTime();
} // namespace testing
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_TEST_BENCHMARK_H_
+#endif // TENSORFLOW_CORE_PLATFORM_TEST_BENCHMARK_H_
diff --git a/tensorflow/core/platform/thread_annotations.h b/tensorflow/core/platform/thread_annotations.h
index 50195cbbc7..aec34df8a1 100644
--- a/tensorflow/core/platform/thread_annotations.h
+++ b/tensorflow/core/platform/thread_annotations.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_THREAD_ANNOTATIONS_H_
-#define TENSORFLOW_PLATFORM_THREAD_ANNOTATIONS_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_THREAD_ANNOTATIONS_H_
+#define TENSORFLOW_CORE_PLATFORM_THREAD_ANNOTATIONS_H_
#include "tensorflow/core/platform/types.h"
@@ -27,4 +27,4 @@ limitations under the License.
#error Define the appropriate PLATFORM_<foo> macro for this platform
#endif
-#endif // TENSORFLOW_PLATFORM_THREAD_ANNOTATIONS_H_
+#endif // TENSORFLOW_CORE_PLATFORM_THREAD_ANNOTATIONS_H_
diff --git a/tensorflow/core/platform/tracing.h b/tensorflow/core/platform/tracing.h
index c322777705..9974bbbb4e 100644
--- a/tensorflow/core/platform/tracing.h
+++ b/tensorflow/core/platform/tracing.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_TRACING_H_
-#define TENSORFLOW_PLATFORM_TRACING_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_TRACING_H_
+#define TENSORFLOW_CORE_PLATFORM_TRACING_H_
// Tracing interface
@@ -155,6 +155,10 @@ class TraceCollector {
StringPiece name_part1, StringPiece name_part2,
bool is_expensive) const = 0;
+  // Returns true if activity handle tracking is enabled for an op of the
+  // given expensiveness.
+ virtual bool IsEnabled(bool is_expensive) const = 0;
+
protected:
static string ConcatenateNames(StringPiece first, StringPiece second);
@@ -238,4 +242,4 @@ const char* GetLogDir();
#include "tensorflow/core/platform/default/tracing_impl.h"
#endif
-#endif // TENSORFLOW_PLATFORM_TRACING_H_
+#endif // TENSORFLOW_CORE_PLATFORM_TRACING_H_
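
The new pure-virtual IsEnabled lets call sites ask whether tracing is active for an op of a given cost class before paying for name construction. A hedged sketch of the intended caller-side pattern; collector, is_expensive, and the handle-creation call are illustrative assumptions, not code from this change:

    // 'collector' is assumed to point at some TraceCollector implementation.
    if (collector != nullptr && collector->IsEnabled(is_expensive)) {
      // Only now build the (possibly expensive) activity name and handle,
      // e.g. via the CreateActivityHandle-style virtual shown above.
    }
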
diff --git a/tensorflow/core/platform/types.h b/tensorflow/core/platform/types.h
index 68897ac423..a4fa790317 100644
--- a/tensorflow/core/platform/types.h
+++ b/tensorflow/core/platform/types.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_TYPES_H_
-#define TENSORFLOW_PLATFORM_TYPES_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_TYPES_H_
+#define TENSORFLOW_CORE_PLATFORM_TYPES_H_
#include <string>
#include "tensorflow/core/platform/platform.h"
@@ -66,4 +66,4 @@ namespace tensorflow {
namespace se = ::stream_executor;
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_TYPES_H_
+#endif // TENSORFLOW_CORE_PLATFORM_TYPES_H_
diff --git a/tensorflow/core/platform/windows/cpu_info.h b/tensorflow/core/platform/windows/cpu_info.h
index ba2126abcf..8b42cbec7a 100644
--- a/tensorflow/core/platform/windows/cpu_info.h
+++ b/tensorflow/core/platform/windows/cpu_info.h
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_WINDOWS_CPU_INFO_H_
-#define TENSORFLOW_PLATFORM_WINDOWS_CPU_INFO_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_WINDOWS_CPU_INFO_H_
+#define TENSORFLOW_CORE_PLATFORM_WINDOWS_CPU_INFO_H_
// Included so the __cpuidex function is available for GETCPUID on Windows.
#include <intrin.h>
-#endif // TENSORFLOW_PLATFORM_WINDOWS_CPU_INFO_H_
+#endif // TENSORFLOW_CORE_PLATFORM_WINDOWS_CPU_INFO_H_
diff --git a/tensorflow/core/platform/windows/integral_types.h b/tensorflow/core/platform/windows/integral_types.h
index 46338a536d..283af49f20 100644
--- a/tensorflow/core/platform/windows/integral_types.h
+++ b/tensorflow/core/platform/windows/integral_types.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_WINDOWS_INTEGRAL_TYPES_H_
-#define TENSORFLOW_PLATFORM_WINDOWS_INTEGRAL_TYPES_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_WINDOWS_INTEGRAL_TYPES_H_
+#define TENSORFLOW_CORE_PLATFORM_WINDOWS_INTEGRAL_TYPES_H_
#include "tensorflow/core/platform/default/integral_types.h"
@@ -22,4 +22,4 @@ limitations under the License.
typedef std::ptrdiff_t ssize_t;
-#endif // TENSORFLOW_PLATFORM_WINDOWS_INTEGRAL_TYPES_H_
+#endif // TENSORFLOW_CORE_PLATFORM_WINDOWS_INTEGRAL_TYPES_H_
diff --git a/tensorflow/core/platform/windows/subprocess.h b/tensorflow/core/platform/windows/subprocess.h
index f00471d484..9084ff5a92 100644
--- a/tensorflow/core/platform/windows/subprocess.h
+++ b/tensorflow/core/platform/windows/subprocess.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_WINDOWS_SUBPROCESS_H_
-#define TENSORFLOW_PLATFORM_WINDOWS_SUBPROCESS_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_WINDOWS_SUBPROCESS_H_
+#define TENSORFLOW_CORE_PLATFORM_WINDOWS_SUBPROCESS_H_
#include <memory>
#include <vector>
@@ -33,4 +33,4 @@ std::unique_ptr<SubProcess> CreateSubProcess(const std::vector<string>& argv) {
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_WINDOWS_SUBPROCESS_H_
+#endif // TENSORFLOW_CORE_PLATFORM_WINDOWS_SUBPROCESS_H_
diff --git a/tensorflow/core/platform/windows/windows_file_system.h b/tensorflow/core/platform/windows/windows_file_system.h
index 6b04720c68..1f4c535f24 100644
--- a/tensorflow/core/platform/windows/windows_file_system.h
+++ b/tensorflow/core/platform/windows/windows_file_system.h
@@ -71,7 +71,7 @@ class LocalWinFileSystem : public WindowsFileSystem {
string TranslateName(const string& name) const override {
StringPiece scheme, host, path;
io::ParseURI(name, &scheme, &host, &path);
- return path.ToString();
+ return string(path);
}
};
diff --git a/tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h b/tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h
index f5ac5c9c5a..0d1c92eb08 100644
--- a/tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h
+++ b/tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h
@@ -137,4 +137,4 @@ class ExpensiveOperationChecker : public Checker {
} // namespace tfprof
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_EXPENSIVE_OP_CHECKER_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_EXPENSIVE_OPERATION_CHECKER_H_
diff --git a/tensorflow/core/profiler/internal/advisor/tfprof_advisor.h b/tensorflow/core/profiler/internal/advisor/tfprof_advisor.h
index 270662bd4a..e1533f882f 100644
--- a/tensorflow/core/profiler/internal/advisor/tfprof_advisor.h
+++ b/tensorflow/core/profiler/internal/advisor/tfprof_advisor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVICE_H_
-#define TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVICE_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVISOR_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVISOR_H_
#include "tensorflow/core/profiler/internal/advisor/accelerator_utilization_checker.h"
#include "tensorflow/core/profiler/internal/advisor/checker.h"
@@ -78,4 +78,4 @@ class Advisor {
} // namespace tfprof
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVICE_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVISOR_H_
diff --git a/tensorflow/core/profiler/internal/tfprof_code.cc b/tensorflow/core/profiler/internal/tfprof_code.cc
index 2c4f52e3ad..744e1e95de 100644
--- a/tensorflow/core/profiler/internal/tfprof_code.cc
+++ b/tensorflow/core/profiler/internal/tfprof_code.cc
@@ -37,7 +37,7 @@ const char* const kGradientSuffix = " (gradient)";
// Converts a Trace proto into a short readable string.
string GetTraceString(const CallStack::Trace& trace) {
- string ntrace = io::Basename(trace.file()).ToString();
+ string ntrace(io::Basename(trace.file()));
ntrace += strings::StrCat(":", trace.lineno());
if (trace.function().length() < 20) {
ntrace += ":" + trace.function();
@@ -113,7 +113,7 @@ class FunctionTable {
// function index should start from 1.
func_pb->set_id(function_table_.size());
- string file_base = io::Basename(file_path).ToString();
+ string file_base(io::Basename(file_path));
file_base = file_base.substr(0, file_base.find_last_of("."));
func_pb->set_name(
string_table_->GetIndex(strings::StrCat(file_base, ":", func_name)));
diff --git a/tensorflow/core/profiler/tfprof_options.h b/tensorflow/core/profiler/tfprof_options.h
index d61deb72ac..57c7e11fa2 100644
--- a/tensorflow/core/profiler/tfprof_options.h
+++ b/tensorflow/core/profiler/tfprof_options.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OPTIONS_H_
-#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OPTIONS_H_
+#ifndef TENSORFLOW_CORE_PROFILER_TFPROF_OPTIONS_H_
+#define TENSORFLOW_CORE_PROFILER_TFPROF_OPTIONS_H_
#include <set>
#include <string>
@@ -183,4 +183,4 @@ tensorflow::Status ParseOutput(const string& output_opt, string* output_type,
} // namespace tfprof
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OPTIONS_H_
+#endif // TENSORFLOW_CORE_PROFILER_TFPROF_OPTIONS_H_
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index da3a99565e..625d5649e6 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -390,9 +390,12 @@ message ConfigProto {
message Experimental {
// Task name for group resolution.
string collective_group_leader = 1;
- // Whether the client will format templated errors. For example, the string:
- // "The node was defined on ^^node:Foo:${file}:${line}^^".
- bool client_handles_error_formatting = 2;
+
+      // We removed the flag client_handles_error_formatting and reserved its
+      // tag number.
+      // TODO(shikharagarwal): Should we just remove this tag so that it can
+      // be reused for another purpose in the future?
+ reserved 2;
// Which executor to use. The default executor will be used
// if it is an empty string or "DEFAULT".
diff --git a/tensorflow/core/protobuf/debug.proto b/tensorflow/core/protobuf/debug.proto
index 811cf406b9..8ca76c44c0 100644
--- a/tensorflow/core/protobuf/debug.proto
+++ b/tensorflow/core/protobuf/debug.proto
@@ -60,6 +60,12 @@ message DebugOptions {
// Note that this is distinct from the session run count and the executor
// step count.
int64 global_step = 10;
+
+ // Whether the total disk usage of tfdbg is to be reset to zero
+ // in this Session.run call. This is used by wrappers and hooks
+ // such as the local CLI ones to indicate that the dumped tensors
+ // are cleaned up from the disk after each Session.run.
+ bool reset_disk_byte_usage = 11;
}
message DebuggedSourceFile {
diff --git a/tensorflow/core/public/session.h b/tensorflow/core/public/session.h
index cc8596ef3d..536a07c413 100644
--- a/tensorflow/core/public/session.h
+++ b/tensorflow/core/public/session.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PUBLIC_SESSION_H_
-#define TENSORFLOW_PUBLIC_SESSION_H_
+#ifndef TENSORFLOW_CORE_PUBLIC_SESSION_H_
+#define TENSORFLOW_CORE_PUBLIC_SESSION_H_
#include <string>
#include <vector>
@@ -279,4 +279,4 @@ Session* NewSession(const SessionOptions& options);
} // end namespace tensorflow
-#endif // TENSORFLOW_PUBLIC_SESSION_H_
+#endif // TENSORFLOW_CORE_PUBLIC_SESSION_H_
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 563564119f..4129c93af5 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -96,10 +96,12 @@ limitations under the License.
// GraphDef. (7dec2017)
// 27. Deprecate TensorArray ops v2 in favor of v3 and deprecated io_ops
// deprecated in favor of V2 ops. (2018/01/23)
+// 28. Deprecate MatrixExponential op in favor of Python implementation.
+// (2018/08/21).
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 26
+#define TF_GRAPH_DEF_VERSION 27
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
//
diff --git a/tensorflow/core/util/activation_mode.h b/tensorflow/core/util/activation_mode.h
index 2e03ccd5c8..2f7820fb47 100644
--- a/tensorflow/core/util/activation_mode.h
+++ b/tensorflow/core/util/activation_mode.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_ACTIVATION_MODE_H_
-#define TENSORFLOW_UTIL_ACTIVATION_MODE_H_
+#ifndef TENSORFLOW_CORE_UTIL_ACTIVATION_MODE_H_
+#define TENSORFLOW_CORE_UTIL_ACTIVATION_MODE_H_
// This file contains helper routines to deal with activation mode in various
// ops and kernels.
@@ -43,4 +43,4 @@ Status GetActivationModeFromString(const string& str_value,
} // end namespace tensorflow
-#endif // TENSORFLOW_UTIL_ACTIVATION_MODE_H_
+#endif // TENSORFLOW_CORE_UTIL_ACTIVATION_MODE_H_
diff --git a/tensorflow/core/util/bcast.h b/tensorflow/core/util/bcast.h
index 81d64e5676..6d73c38e3c 100644
--- a/tensorflow/core/util/bcast.h
+++ b/tensorflow/core/util/bcast.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_BCAST_H_
-#define TENSORFLOW_UTIL_BCAST_H_
+#ifndef TENSORFLOW_CORE_UTIL_BCAST_H_
+#define TENSORFLOW_CORE_UTIL_BCAST_H_
#include <algorithm>
@@ -132,4 +132,4 @@ class BCast {
} // end namespace tensorflow
-#endif // TENSORFLOW_UTIL_BCAST_H_
+#endif // TENSORFLOW_CORE_UTIL_BCAST_H_
diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc
index b281acb2b0..55f1e30880 100644
--- a/tensorflow/core/util/command_line_flags.cc
+++ b/tensorflow/core/util/command_line_flags.cc
@@ -32,7 +32,7 @@ bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
if (str_util::ConsumePrefix(&arg, "--") &&
str_util::ConsumePrefix(&arg, flag) &&
str_util::ConsumePrefix(&arg, "=")) {
- *value_parsing_ok = hook(std::string(arg));
+ *value_parsing_ok = hook(string(arg));
return true;
}
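
The hook here now receives string(arg), matching TensorFlow's string alias. A self-contained sketch of the consume-prefix parsing pattern used by ParseStringFlag, with std::string_view standing in for StringPiece and a simplified ConsumePrefix assumed to mirror str_util::ConsumePrefix:

    #include <string>
    #include <string_view>

    // Strips 'prefix' from '*s' and returns true when it matches (simplified
    // stand-in for str_util::ConsumePrefix).
    bool ConsumePrefix(std::string_view* s, std::string_view prefix) {
      if (s->substr(0, prefix.size()) != prefix) return false;
      s->remove_prefix(prefix.size());
      return true;
    }

    // Parses "--flag=value" into *value, mirroring the logic above.
    bool ParseFlag(std::string_view arg, std::string_view flag,
                   std::string* value) {
      if (ConsumePrefix(&arg, "--") && ConsumePrefix(&arg, flag) &&
          ConsumePrefix(&arg, "=")) {
        *value = std::string(arg);
        return true;
      }
      return false;
    }
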
diff --git a/tensorflow/core/util/ctc/ctc_beam_entry.h b/tensorflow/core/util/ctc/ctc_beam_entry.h
index 973e315f09..24002e72a0 100644
--- a/tensorflow/core/util/ctc/ctc_beam_entry.h
+++ b/tensorflow/core/util/ctc/ctc_beam_entry.h
@@ -1,4 +1,3 @@
-// LINT.IfChange
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// LINT.IfChange
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
diff --git a/tensorflow/core/util/ctc/ctc_beam_scorer.h b/tensorflow/core/util/ctc/ctc_beam_scorer.h
index 1a622babe1..1e45a8abd3 100644
--- a/tensorflow/core/util/ctc/ctc_beam_scorer.h
+++ b/tensorflow/core/util/ctc/ctc_beam_scorer.h
@@ -1,4 +1,3 @@
-// LINT.IfChange
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// LINT.IfChange
// Collection of scoring classes that can be extended and provided to the
// CTCBeamSearchDecoder to incorporate additional scoring logic (such as a
diff --git a/tensorflow/core/util/ctc/ctc_beam_search.h b/tensorflow/core/util/ctc/ctc_beam_search.h
index aee647a1b3..6fbb1ed0da 100644
--- a/tensorflow/core/util/ctc/ctc_beam_search.h
+++ b/tensorflow/core/util/ctc/ctc_beam_search.h
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// LINT.IfChange
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
@@ -259,6 +260,16 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
} else {
max_coeff = raw_input.maxCoeff();
}
+
+ // Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))).
+ float logsumexp = 0.0;
+ for (int j = 0; j < raw_input.size(); ++j) {
+ logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff);
+ }
+ logsumexp = Eigen::numext::log(logsumexp);
+ // Final normalization offset to get correct log probabilities.
+ float norm_offset = max_coeff + logsumexp;
+
const float label_selection_input_min =
(label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_)
: -std::numeric_limits<float>::infinity();
@@ -290,10 +301,10 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
beam_scorer_->GetStateExpansionScore(b->state, previous));
}
// Plabel(l=abc @ t=6) *= P(c @ 6)
- b->newp.label += raw_input(b->label) - max_coeff;
+ b->newp.label += raw_input(b->label) - norm_offset;
}
// Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
- b->newp.blank = b->oldp.total + raw_input(blank_index_) - max_coeff;
+ b->newp.blank = b->oldp.total + raw_input(blank_index_) - norm_offset;
// P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
b->newp.total = LogSumExp(b->newp.blank, b->newp.label);
@@ -328,6 +339,8 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
const float logit = top_k ? top_k_logits[ind] : raw_input(ind);
// Perform label selection: if input for this label looks very
// unpromising, never evaluate it with a scorer.
+          // Label selection can compare raw logits instead of normalized log
+          // probabilities here, since both sides of the comparison are
+          // shifted by the same normalization offset.
if (logit < label_selection_input_min) {
continue;
}
@@ -341,7 +354,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
// Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
- c.newp.label = logit - max_coeff +
+ c.newp.label = logit - norm_offset +
beam_scorer_->GetStateExpansionScore(c.state, previous);
// P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
c.newp.total = c.newp.label;
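
The change above turns each step's raw logits into proper log probabilities by subtracting norm_offset = max_coeff + log(sum_j exp(logit_j - max_coeff)), the numerically stable log-sum-exp. A minimal standalone sketch of the same computation:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Numerically stable log(sum_j exp(x[j])); assumes x is non-empty.
    float LogSumExpOf(const std::vector<float>& x) {
      const float max_coeff = *std::max_element(x.begin(), x.end());
      float sum = 0.0f;
      for (float v : x) sum += std::exp(v - max_coeff);
      // x[j] - (max_coeff + log(sum)) is then a valid log probability.
      return max_coeff + std::log(sum);
    }
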
diff --git a/tensorflow/core/util/ctc/ctc_decoder.h b/tensorflow/core/util/ctc/ctc_decoder.h
index 3be36822e5..b55d7d77ac 100644
--- a/tensorflow/core/util/ctc/ctc_decoder.h
+++ b/tensorflow/core/util/ctc/ctc_decoder.h
@@ -1,4 +1,3 @@
-// LINT.IfChange
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// LINT.IfChange
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_
diff --git a/tensorflow/core/util/ctc/ctc_loss_util.h b/tensorflow/core/util/ctc/ctc_loss_util.h
index 36be9e92ef..054412d388 100644
--- a/tensorflow/core/util/ctc/ctc_loss_util.h
+++ b/tensorflow/core/util/ctc/ctc_loss_util.h
@@ -1,4 +1,3 @@
-// LINT.IfChange
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// LINT.IfChange
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_
diff --git a/tensorflow/core/util/device_name_utils.h b/tensorflow/core/util/device_name_utils.h
index 4071a70836..3f0bc60562 100644
--- a/tensorflow/core/util/device_name_utils.h
+++ b/tensorflow/core/util/device_name_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_DEVICE_NAME_UTILS_H_
-#define TENSORFLOW_UTIL_DEVICE_NAME_UTILS_H_
+#ifndef TENSORFLOW_CORE_UTIL_DEVICE_NAME_UTILS_H_
+#define TENSORFLOW_CORE_UTIL_DEVICE_NAME_UTILS_H_
#include <string>
@@ -173,4 +173,4 @@ class DeviceNameUtils {
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_DEVICE_NAME_UTILS_H_
+#endif // TENSORFLOW_CORE_UTIL_DEVICE_NAME_UTILS_H_
diff --git a/tensorflow/core/util/env_var.cc b/tensorflow/core/util/env_var.cc
index 8d43bcc927..2604a5d66a 100644
--- a/tensorflow/core/util/env_var.cc
+++ b/tensorflow/core/util/env_var.cc
@@ -28,7 +28,7 @@ namespace tensorflow {
Status ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val,
bool* value) {
*value = default_val;
- const char* tf_env_var_val = getenv(std::string(env_var_name).c_str());
+ const char* tf_env_var_val = getenv(string(env_var_name).c_str());
if (tf_env_var_val == nullptr) {
return Status::OK();
}
@@ -48,7 +48,7 @@ Status ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val,
Status ReadInt64FromEnvVar(StringPiece env_var_name, int64 default_val,
int64* value) {
*value = default_val;
- const char* tf_env_var_val = getenv(std::string(env_var_name).c_str());
+ const char* tf_env_var_val = getenv(string(env_var_name).c_str());
if (tf_env_var_val == nullptr) {
return Status::OK();
}
@@ -62,11 +62,11 @@ Status ReadInt64FromEnvVar(StringPiece env_var_name, int64 default_val,
Status ReadStringFromEnvVar(StringPiece env_var_name, StringPiece default_val,
string* value) {
- const char* tf_env_var_val = getenv(std::string(env_var_name).c_str());
+ const char* tf_env_var_val = getenv(string(env_var_name).c_str());
if (tf_env_var_val != nullptr) {
*value = tf_env_var_val;
} else {
- *value = std::string(default_val);
+ *value = string(default_val);
}
return Status::OK();
}
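
A hedged usage sketch of these helpers, with the signatures shown above; "TF_MY_FLAG" and the wrapper function are made up for illustration:

    #include "tensorflow/core/util/env_var.h"

    // Reads the hypothetical TF_MY_FLAG variable, falling back to 'false'.
    tensorflow::Status ReadMyFlag(bool* enable_feature) {
      return tensorflow::ReadBoolFromEnvVar("TF_MY_FLAG",
                                            /*default_val=*/false,
                                            enable_feature);
    }
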
diff --git a/tensorflow/core/util/env_var.h b/tensorflow/core/util/env_var.h
index 47f9ff3a3b..724ca35729 100644
--- a/tensorflow/core/util/env_var.h
+++ b/tensorflow/core/util/env_var.h
@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_ENV_VAR_H_
+#ifndef TENSORFLOW_CORE_UTIL_ENV_VAR_H_
+#define TENSORFLOW_CORE_UTIL_ENV_VAR_H_
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -42,4 +43,4 @@ Status ReadStringFromEnvVar(StringPiece env_var_name, StringPiece default_val,
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_ENV_VAR_H_
+#endif // TENSORFLOW_CORE_UTIL_ENV_VAR_H_
diff --git a/tensorflow/core/util/events_writer.h b/tensorflow/core/util/events_writer.h
index 5dbaf97af4..d5952c3cbd 100644
--- a/tensorflow/core/util/events_writer.h
+++ b/tensorflow/core/util/events_writer.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_EVENTS_WRITER_H_
-#define TENSORFLOW_UTIL_EVENTS_WRITER_H_
+#ifndef TENSORFLOW_CORE_UTIL_EVENTS_WRITER_H_
+#define TENSORFLOW_CORE_UTIL_EVENTS_WRITER_H_
#include <memory>
#include <string>
@@ -95,4 +95,4 @@ class EventsWriter {
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_EVENTS_WRITER_H_
+#endif // TENSORFLOW_CORE_UTIL_EVENTS_WRITER_H_
diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc
index 1fec0010a1..e52d55e2ff 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.cc
+++ b/tensorflow/core/util/example_proto_fast_parsing.cc
@@ -353,7 +353,7 @@ bool TestFastParse(const string& serialized, Example* example) {
// I.e. last entry in the map overwrites all the previous ones.
parsed::FeatureMapEntry& name_and_feature =
parsed_example[parsed_example_size - i - 1];
- string name = std::string(name_and_feature.first);
+ string name(name_and_feature.first);
if ((*features.mutable_feature()).count(name) > 0) continue;
auto& value = (*features.mutable_feature())[name];
@@ -1722,10 +1722,11 @@ Status FastParseSequenceExample(
const FastParseExampleConfig& feature_list_config,
gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names,
thread::ThreadPool* thread_pool, Result* context_result,
- Result* feature_list_result) {
+ Result* feature_list_result, std::vector<Tensor>* dense_feature_lengths) {
int num_examples = serialized.size();
DCHECK(context_result != nullptr);
DCHECK(feature_list_result != nullptr);
+ DCHECK(dense_feature_lengths != nullptr);
std::map<StringPiece, bool> context_is_sparse;
std::map<StringPiece, std::pair<DataType, size_t>>
context_feature_type_and_lengths;
@@ -1740,9 +1741,22 @@ Status FastParseSequenceExample(
context_is_sparse[c.feature_name] = true;
}
for (auto& c : context_config.dense) {
+ if (context_is_sparse[c.feature_name]) {
+ return errors::InvalidArgument("Context feature " + c.feature_name +
+ " cannot be both dense and sparse");
+ }
TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
context_feature_type_and_lengths[c.feature_name] =
- std::make_pair(c.dtype, 0);
+ std::make_pair(c.dtype, c.default_value.NumElements());
+ if (c.default_value.NumElements() > 0) {
+ if (!c.shape.IsCompatibleWith(c.default_value.shape())) {
+ return errors::InvalidArgument("Default value for context feature ",
+ c.feature_name,
+ " has an incorrect shape: saw ",
+ c.default_value.shape().DebugString(),
+ " but expected ", c.shape.DebugString());
+ }
+ }
context_is_sparse[c.feature_name] = false;
}
std::map<StringPiece, bool> sequence_is_sparse;
@@ -1755,6 +1769,10 @@ Status FastParseSequenceExample(
sequence_is_sparse[c.feature_name] = true;
}
for (auto& c : feature_list_config.dense) {
+ if (sequence_is_sparse[c.feature_name]) {
+ return errors::InvalidArgument("Sequence feature " + c.feature_name +
+ " cannot be both dense and sparse");
+ }
TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
sequence_feature_type_and_lengths[c.feature_name] =
std::make_pair(c.dtype, 0);
@@ -1792,14 +1810,14 @@ Status FastParseSequenceExample(
features = sequence_features;
config = &sequence_feature_type_and_lengths;
} else if (!SkipExtraneousTag(&stream)) {
- return errors::InvalidArgument(strings::StrCat(
- "Invalid protocol message input, example id: ", example_name));
+ return errors::InvalidArgument(
+ "Invalid protocol message input, example id: ", example_name);
}
if (features != nullptr) {
uint32 length;
if (!stream.ReadVarint32(&length)) {
- return errors::InvalidArgument(strings::StrCat(
- "Invalid protocol message input, example id: ", example_name));
+ return errors::InvalidArgument(
+ "Invalid protocol message input, example id: ", example_name);
}
auto limit = stream.PushLimit(length);
while (!stream.ExpectAtEnd()) {
@@ -1807,16 +1825,16 @@ Status FastParseSequenceExample(
uint32 length;
if (!stream.ExpectTag(kDelimitedTag(1)) ||
!stream.ReadVarint32(&length)) {
- return errors::InvalidArgument(strings::StrCat(
- "Invalid protocol message input, example id: ", example_name));
+ return errors::InvalidArgument(
+ "Invalid protocol message input, example id: ", example_name);
}
auto limit = stream.PushLimit(length);
if (!stream.ExpectTag(kDelimitedTag(1)) ||
!ParseString(&stream, &key) ||
!stream.ExpectTag(kDelimitedTag(2)) ||
!ParseString(&stream, &value) || !stream.ExpectAtEnd()) {
- return errors::InvalidArgument(strings::StrCat(
- "Invalid protocol message input, example id: ", example_name));
+ return errors::InvalidArgument(
+ "Invalid protocol message input, example id: ", example_name);
}
stream.PopLimit(limit);
// Only save if this feature was requested.
@@ -1851,9 +1869,8 @@ Status FastParseSequenceExample(
break;
}
if (num == -1) {
- return errors::InvalidArgument(
- strings::StrCat("Error in context feature ", c.first,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in context feature ", c.first,
+ " in example ", example_name);
}
num_elements += num;
}
@@ -1876,9 +1893,9 @@ Status FastParseSequenceExample(
uint32 feature_length;
if (!stream.ExpectTag(kDelimitedTag(1)) ||
!stream.ReadVarint32(&feature_length)) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.first,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.first, " in example ",
+ example_name);
}
if (feature_length > 2) {
auto limit = stream.PushLimit(feature_length);
@@ -1898,22 +1915,22 @@ Status FastParseSequenceExample(
break;
}
if (num == -1) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.first,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.first, " in example ",
+ example_name);
}
num_elements += num;
stream.PopLimit(limit);
} else if (feature_length == 2) {
if (!SkipEmptyFeature(&stream, dtype)) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.first,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.first, " in example ",
+ example_name);
}
} else if (feature_length != 0) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.first,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.first, " in example ",
+ example_name);
}
}
}
@@ -1936,15 +1953,19 @@ Status FastParseSequenceExample(
feature_list_result->sparse_indices.resize(feature_list_config.sparse.size());
feature_list_result->sparse_shapes.resize(feature_list_config.sparse.size());
feature_list_result->dense_values.resize(feature_list_config.dense.size());
+ dense_feature_lengths->resize(feature_list_config.dense.size());
+
int t = 0;
for (const auto& c : context_config.dense) {
- TensorShape dense_shape;
+ TensorShape dense_shape, example_shape;
DataType dtype = c.dtype;
- size_t expected_max_elements =
+ const size_t expected_max_elements =
context_feature_type_and_lengths[c.feature_name].second;
- if (expected_max_elements != dense_shape.num_elements()) {
- return errors::InvalidArgument(strings::StrCat(
- "Inconsistent number of elements for feature ", c.feature_name));
+ if (!c.shape.AsTensorShape(&example_shape) ||
+ expected_max_elements != example_shape.num_elements()) {
+ return errors::InvalidArgument(
+ "Inconsistent number of elements for feature ", c.feature_name, ": ",
+          expected_max_elements, " vs ", example_shape.num_elements());
}
dense_shape.AddDim(num_examples);
for (const int dim : c.shape.dim_sizes()) {
@@ -1968,18 +1989,58 @@ Status FastParseSequenceExample(
out_int64 = context_result->dense_values[t].flat<int64>().data();
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in feature ", c.feature_name);
}
t++;
// Fill in the values.
for (int e = 0; e < num_examples; e++) {
size_t num_elements = 0;
- const auto& feature = all_context_features[e][c.feature_name];
+ const auto feature_iter = all_context_features[e].find(c.feature_name);
const string& example_name =
example_names.empty() ? kUnknown : example_names[e];
- if (!feature.empty()) {
+ if (feature_iter == all_context_features[e].end()) {
+ // Copy the default value, if present. If not, return an error.
+ if (c.default_value.NumElements() == 0) {
+ return errors::InvalidArgument(
+ "Feature: ", c.feature_name,
+ " (data type: ", DataTypeString(c.dtype), ")",
+ " is required but could not be found.");
+ }
+ const string* in_bytes = nullptr;
+ const float* in_float = nullptr;
+ const int64* in_int64 = nullptr;
+ size_t num = 0;
+ switch (dtype) {
+ case DT_STRING:
+ in_bytes = c.default_value.flat<string>().data();
+ num = c.default_value.NumElements();
+ for (int p = 0; p < num; p++) {
+ *out_bytes++ = *in_bytes++;
+ }
+ break;
+ case DT_FLOAT:
+ in_float = c.default_value.flat<float>().data();
+ num = c.default_value.NumElements();
+ for (int p = 0; p < num; p++) {
+ *out_float++ = *in_float++;
+ }
+ break;
+ case DT_INT64:
+ in_int64 = c.default_value.flat<int64>().data();
+ num = c.default_value.NumElements();
+ for (int p = 0; p < num; p++) {
+ *out_int64++ = *in_int64++;
+ }
+ break;
+ default:
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
+ }
+ num_elements += num;
+ } else if (!feature_iter->second.empty()) {
+ const auto& feature = feature_iter->second;
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(feature.data()), feature.size());
EnableAliasing(&stream);
@@ -1998,14 +2059,14 @@ Status FastParseSequenceExample(
out_int64 += num_added;
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in example ", example_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
}
num_elements += num_added;
}
if (num_elements != expected_max_elements) {
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected number of elements in example ", example_name));
+ return errors::InvalidArgument(
+ "Unexpected number of elements in example ", example_name);
}
}
}
@@ -2037,8 +2098,8 @@ Status FastParseSequenceExample(
out_int64 = context_result->sparse_values[t].flat<int64>().data();
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in feature ", c.feature_name);
}
int64* out_indices = context_result->sparse_indices[t].flat<int64>().data();
auto out_shape = context_result->sparse_shapes[t].vec<int64>();
@@ -2070,8 +2131,8 @@ Status FastParseSequenceExample(
out_int64 += num_added;
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in example ", example_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
}
num_elements += num_added;
max_num_cols = std::max(max_num_cols, num_added);
@@ -2082,30 +2143,35 @@ Status FastParseSequenceExample(
}
}
if (num_elements != expected_num_elements) {
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected total number of elements in feature ", c.feature_name));
+ return errors::InvalidArgument(
+ "Unexpected total number of elements in feature ", c.feature_name);
}
out_shape(0) = num_examples;
out_shape(1) = max_num_cols;
}
t = 0;
+ TensorShape dense_length_shape({num_examples});
for (const auto& c : feature_list_config.dense) {
TensorShape dense_shape, row_shape;
DataType dtype = c.dtype;
- size_t expected_max_elements =
+ const size_t expected_max_elements =
sequence_feature_type_and_lengths[c.feature_name].second;
- int64 expected_max_rows = expected_max_elements / row_shape.num_elements();
if (!c.shape.AsTensorShape(&row_shape) ||
- expected_max_elements != expected_max_rows * row_shape.num_elements()) {
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected shape error in feature ", c.feature_name));
+ expected_max_elements !=
+ (expected_max_elements / row_shape.num_elements()) *
+ row_shape.num_elements()) {
+ return errors::InvalidArgument("Unexpected shape error in feature ",
+ c.feature_name);
}
+ int64 expected_max_rows = expected_max_elements / row_shape.num_elements();
dense_shape.AddDim(num_examples);
dense_shape.AddDim(expected_max_rows);
for (const int dim : feature_list_config.dense[t].shape.dim_sizes()) {
dense_shape.AddDim(dim);
}
feature_list_result->dense_values[t] = Tensor(dtype, dense_shape);
+ (*dense_feature_lengths)[t] = Tensor(DT_INT64, dense_length_shape);
+ int64* out_lengths = (*dense_feature_lengths)[t].flat<int64>().data();
string* out_bytes = nullptr;
float* out_float = nullptr;
@@ -2121,18 +2187,26 @@ Status FastParseSequenceExample(
out_int64 = feature_list_result->dense_values[t].flat<int64>().data();
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in feature ", c.feature_name);
}
t++;
// Fill in the values.
for (int e = 0; e < num_examples; e++) {
- size_t num_elements = 0;
- const auto& feature = all_sequence_features[e][c.feature_name];
+ size_t num_elements = 0, num_rows = 0;
+ const auto feature_iter = all_sequence_features[e].find(c.feature_name);
const string& example_name =
example_names.empty() ? kUnknown : example_names[e];
- if (!feature.empty()) {
+ if (feature_iter == all_sequence_features[e].end()) {
+ // Return an error if this feature was not allowed to be missing.
+ // Otherwise, we'll pad as needed below.
+ if (!c.variable_length) {
+ return errors::InvalidArgument("Missing feature ", c.feature_name,
+ " in example ", example_name);
+ }
+ } else if (!feature_iter->second.empty()) {
+ const auto& feature = feature_iter->second;
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(feature.data()), feature.size());
EnableAliasing(&stream);
@@ -2140,9 +2214,9 @@ Status FastParseSequenceExample(
uint32 feature_length;
if (!stream.ExpectTag(kDelimitedTag(1)) ||
!stream.ReadVarint32(&feature_length)) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.feature_name,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.feature_name, " in example ",
+ example_name);
}
auto limit = stream.PushLimit(feature_length);
size_t num_added;
@@ -2160,10 +2234,11 @@ Status FastParseSequenceExample(
out_int64 += num_added;
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in example ", example_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
}
num_elements += num_added;
+ num_rows++;
if (num_added != row_shape.num_elements()) {
return errors::InvalidArgument(
"Unexpected number of elements in feature ", c.feature_name,
@@ -2172,6 +2247,7 @@ Status FastParseSequenceExample(
stream.PopLimit(limit);
}
}
+ *out_lengths++ = num_rows;
// Pad as necessary.
int num_to_pad = expected_max_elements - num_elements;
switch (dtype) {
@@ -2187,8 +2263,8 @@ Status FastParseSequenceExample(
out_int64 += num_to_pad;
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in example ", example_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
}
}
}
@@ -2219,8 +2295,8 @@ Status FastParseSequenceExample(
out_int64 = feature_list_result->sparse_values[t].flat<int64>().data();
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in feature ", c.feature_name);
}
int64* out_indices =
feature_list_result->sparse_indices[t].flat<int64>().data();
@@ -2244,9 +2320,9 @@ Status FastParseSequenceExample(
uint32 feature_length;
if (!stream.ExpectTag(kDelimitedTag(1)) ||
!stream.ReadVarint32(&feature_length)) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.feature_name,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.feature_name, " in example ",
+ example_name);
}
if (feature_length > 2) {
auto limit = stream.PushLimit(feature_length);
@@ -2265,8 +2341,8 @@ Status FastParseSequenceExample(
out_int64 += num_added;
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in example ", example_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
}
num_elements += num_added;
max_num_cols = std::max(max_num_cols, num_added);
@@ -2278,14 +2354,14 @@ Status FastParseSequenceExample(
stream.PopLimit(limit);
} else if (feature_length == 2) {
if (!SkipEmptyFeature(&stream, dtype)) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.feature_name,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.feature_name, " in example ",
+ example_name);
}
} else if (feature_length != 0) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.feature_name,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.feature_name, " in example ",
+ example_name);
}
num_rows++;
}
@@ -2293,8 +2369,8 @@ Status FastParseSequenceExample(
}
}
if (num_elements != expected_num_elements) {
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected number of elements in feature ", c.feature_name));
+ return errors::InvalidArgument(
+ "Unexpected number of elements in feature ", c.feature_name);
}
out_shape(0) = num_examples;
out_shape(1) = max_num_rows;
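
Alongside the behavioral changes, this file consistently drops strings::StrCat inside error constructors: errors::InvalidArgument is itself variadic and concatenates its arguments, so the inner StrCat built a redundant temporary string. A hedged sketch of the two equivalent forms (the includes and wrapper function are assumptions):

    #include <string>
    #include "tensorflow/core/framework/types.h"
    #include "tensorflow/core/lib/core/errors.h"

    tensorflow::Status DtypeError(tensorflow::DataType dtype,
                                  const std::string& name) {
      // Before: errors::InvalidArgument(
      //     strings::StrCat("Unexpected dtype ", dtype, " in feature ", name));
      // After: pass the pieces directly; the constructor concatenates them.
      return tensorflow::errors::InvalidArgument("Unexpected dtype ", dtype,
                                                 " in feature ", name);
    }
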
diff --git a/tensorflow/core/util/example_proto_fast_parsing.h b/tensorflow/core/util/example_proto_fast_parsing.h
index db5b5ff929..055d9c2c30 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.h
+++ b/tensorflow/core/util/example_proto_fast_parsing.h
@@ -118,7 +118,8 @@ Status FastParseSequenceExample(
const example::FastParseExampleConfig& feature_list_config,
gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names,
thread::ThreadPool* thread_pool, example::Result* context_result,
- example::Result* feature_list_result);
+ example::Result* feature_list_result,
+ std::vector<Tensor>* dense_feature_lengths);
// This function parses serialized Example and populates given example.
// It uses the same specialized parser as FastParseExample which is efficient.
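
With the added out-parameter, callers of FastParseSequenceExample also receive, for each dense feature list, a DT_INT64 tensor of per-example row counts. A hedged calling sketch; the wrapper and its parameters are illustrative:

    Status ParseSketch(const example::FastParseExampleConfig& context_config,
                       const example::FastParseExampleConfig& feature_list_config,
                       gtl::ArraySlice<string> serialized,
                       gtl::ArraySlice<string> example_names,
                       thread::ThreadPool* thread_pool) {
      example::Result context_result, feature_list_result;
      // One DT_INT64 tensor per dense feature list, shaped {num_examples}.
      std::vector<Tensor> dense_feature_lengths;
      return FastParseSequenceExample(context_config, feature_list_config,
                                      serialized, example_names, thread_pool,
                                      &context_result, &feature_list_result,
                                      &dense_feature_lengths);
    }
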
diff --git a/tensorflow/core/util/example_proto_fast_parsing_test.cc b/tensorflow/core/util/example_proto_fast_parsing_test.cc
index 37faa927bf..6c5f80a535 100644
--- a/tensorflow/core/util/example_proto_fast_parsing_test.cc
+++ b/tensorflow/core/util/example_proto_fast_parsing_test.cc
@@ -42,7 +42,7 @@ string SerializedToReadable(string serialized) {
string result;
result += '"';
for (char c : serialized)
- result += strings::StrCat("\\x", strings::Hex(c, strings::ZERO_PAD_2));
+ result += strings::StrCat("\\x", strings::Hex(c, strings::kZeroPad2));
result += '"';
return result;
}
diff --git a/tensorflow/core/util/example_proto_helper.cc b/tensorflow/core/util/example_proto_helper.cc
index e156a3bc8f..41fb20c00a 100644
--- a/tensorflow/core/util/example_proto_helper.cc
+++ b/tensorflow/core/util/example_proto_helper.cc
@@ -443,6 +443,59 @@ Status ParseSingleExampleAttrs::FinishInit() {
return Status::OK();
}
+Status ParseSequenceExampleAttrs::FinishInit() {
+ if (num_context_sparse != context_sparse_keys.size() ||
+ num_context_sparse != context_sparse_types.size()) {
+ return errors::InvalidArgument(
+ "num_context_sparse (", num_context_sparse,
+ ") must match the size of context_sparse_keys (",
+ context_sparse_keys.size(), ") and context_sparse_types (",
+ context_sparse_types.size(), ")");
+ }
+ if (num_context_dense != context_dense_keys.size() ||
+ num_context_dense != context_dense_types.size() ||
+ num_context_dense != context_dense_shapes.size()) {
+ return errors::InvalidArgument(
+ "num_context_dense (", num_context_dense,
+ ") must match the size of context_dense_keys (",
+ context_dense_keys.size(), "), context_dense_types (",
+ context_dense_types.size(), ") and context_dense_shapes (",
+ context_dense_shapes.size(), ")");
+ }
+ if (num_feature_list_sparse != feature_list_sparse_keys.size() ||
+ num_feature_list_sparse != feature_list_sparse_types.size()) {
+ return errors::InvalidArgument(
+ "num_feature_list_sparse (", num_feature_list_sparse,
+ ") must match the size of feature_list_sparse_keys (",
+ feature_list_sparse_keys.size(), ") and feature_list_sparse_types (",
+ feature_list_sparse_types.size(), ")");
+ }
+ if (num_feature_list_dense != feature_list_dense_keys.size() ||
+ num_feature_list_dense != feature_list_dense_types.size() ||
+ num_feature_list_dense != feature_list_dense_shapes.size()) {
+ return errors::InvalidArgument(
+ "num_feature_list_dense (", num_feature_list_dense,
+ ") must match the size of feature_list_dense_keys (",
+ feature_list_dense_keys.size(), "), feature_list_dense_types (",
+ feature_list_dense_types.size(), ") and feature_list_dense_shapes (",
+ feature_list_dense_shapes.size(), ")");
+ }
+ for (const DataType& type : context_dense_types) {
+ TF_RETURN_IF_ERROR(CheckValidType(type));
+ }
+ for (const DataType& type : context_sparse_types) {
+ TF_RETURN_IF_ERROR(CheckValidType(type));
+ }
+ for (const DataType& type : feature_list_dense_types) {
+ TF_RETURN_IF_ERROR(CheckValidType(type));
+ }
+ for (const DataType& type : feature_list_sparse_types) {
+ TF_RETURN_IF_ERROR(CheckValidType(type));
+ }
+
+ return Status::OK();
+}
+
Status ParseSingleSequenceExampleAttrs::FinishInit() {
if (static_cast<size_t>(num_context_sparse) != context_sparse_types.size()) {
return errors::InvalidArgument(
diff --git a/tensorflow/core/util/example_proto_helper.h b/tensorflow/core/util/example_proto_helper.h
index e511704962..c183ee4d96 100644
--- a/tensorflow/core/util/example_proto_helper.h
+++ b/tensorflow/core/util/example_proto_helper.h
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
@@ -271,6 +272,66 @@ class ParseSingleExampleAttrs {
Status FinishInit(); // for context-independent parts of Init.
};
+// Parses the attributes passed to ParseSequenceExample.
+// REQUIRES: Init must be called after construction.
+class ParseSequenceExampleAttrs {
+ public:
+ template <typename ContextType>
+ Status Init(ContextType* ctx) {
+ std::vector<string> feature_list_dense_missing_assumed_empty_tmp;
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_dense_missing_assumed_empty",
+ &feature_list_dense_missing_assumed_empty_tmp));
+ for (const string& feature : feature_list_dense_missing_assumed_empty_tmp) {
+ feature_list_dense_missing_assumed_empty.insert(feature);
+ }
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("context_sparse_keys", &context_sparse_keys));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("context_dense_keys", &context_dense_keys));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_sparse_keys", &feature_list_sparse_keys));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_dense_keys", &feature_list_dense_keys));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("context_sparse_types", &context_sparse_types));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_dense", &num_context_dense));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("Nfeature_list_dense", &num_feature_list_dense));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_sparse", &num_context_sparse));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("Tcontext_dense", &context_dense_types));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_sparse_types", &feature_list_sparse_types));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_dense_types", &feature_list_dense_types));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("Nfeature_list_sparse", &num_feature_list_sparse));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("context_dense_shapes", &context_dense_shapes));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_dense_shapes", &feature_list_dense_shapes));
+ return FinishInit();
+ }
+
+ std::unordered_set<string> feature_list_dense_missing_assumed_empty;
+ int64 num_context_sparse;
+ int64 num_context_dense;
+ int64 num_feature_list_sparse;
+ int64 num_feature_list_dense;
+ std::vector<string> context_sparse_keys;
+ std::vector<string> context_dense_keys;
+ std::vector<string> feature_list_sparse_keys;
+ std::vector<string> feature_list_dense_keys;
+ std::vector<DataType> context_sparse_types;
+ std::vector<DataType> context_dense_types;
+ std::vector<TensorShape> context_dense_shapes;
+ std::vector<DataType> feature_list_sparse_types;
+ std::vector<DataType> feature_list_dense_types;
+ std::vector<TensorShape> feature_list_dense_shapes;
+
+ private:
+ Status FinishInit(); // for context-independent parts of Init.
+};
+
// Parses the attributes passed to ParseSingleSequenceExample.
// REQUIRES: Init must be called after construction.
class ParseSingleSequenceExampleAttrs {
diff --git a/tensorflow/core/util/guarded_philox_random.h b/tensorflow/core/util/guarded_philox_random.h
index 44970eb949..8be7a374f0 100644
--- a/tensorflow/core/util/guarded_philox_random.h
+++ b/tensorflow/core/util/guarded_philox_random.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_
-#define TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_
+#ifndef TENSORFLOW_CORE_UTIL_GUARDED_PHILOX_RANDOM_H_
+#define TENSORFLOW_CORE_UTIL_GUARDED_PHILOX_RANDOM_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/random/philox_random.h"
@@ -79,4 +79,4 @@ class GuardedPhiloxRandom {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_
+#endif // TENSORFLOW_CORE_UTIL_GUARDED_PHILOX_RANDOM_H_
diff --git a/tensorflow/core/util/mirror_pad_mode.h b/tensorflow/core/util/mirror_pad_mode.h
index f703d47ab1..ceee9b06b0 100644
--- a/tensorflow/core/util/mirror_pad_mode.h
+++ b/tensorflow/core/util/mirror_pad_mode.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_MIRROR_PAD_MODE_H_
-#define TENSORFLOW_UTIL_MIRROR_PAD_MODE_H_
+#ifndef TENSORFLOW_CORE_UTIL_MIRROR_PAD_MODE_H_
+#define TENSORFLOW_CORE_UTIL_MIRROR_PAD_MODE_H_
// This file contains helper routines to deal with padding in various ops and
// kernels.
@@ -49,4 +49,4 @@ Status GetNodeAttr(const NodeDef& node_def, StringPiece attr_name,
} // end namespace tensorflow
-#endif // TENSORFLOW_UTIL_MIRROR_PAD_MODE_H_
+#endif // TENSORFLOW_CORE_UTIL_MIRROR_PAD_MODE_H_
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 159a787d05..680211edff 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
#ifdef INTEL_MKL
+#include <string>
#include <memory>
#include <unordered_map>
#include <utility>
@@ -33,6 +34,12 @@ limitations under the License.
#endif
#ifdef INTEL_MKL_ML_ONLY
+// Using pragma message since #warning doesn't work with all compilers
+#pragma message("Compiling for INTEL MKL ML only will be deprecated soon.")
+#pragma message("Please use MKL DNN (the default option for --config=mkl)")
+#endif
+
+#ifdef INTEL_MKL_ML_ONLY
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
#include "mkl_service.h"
@@ -50,6 +57,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/util/env_var.h"
#ifndef INTEL_MKL_ML_ONLY
#include "mkldnn.hpp"
@@ -66,7 +74,6 @@ using mkldnn::reorder;
typedef unsigned int uint;
#endif
-
namespace tensorflow {
// The file contains a number of utility classes and functions used by MKL
@@ -87,6 +94,18 @@ typedef enum {
Dim_I = 1
} MklDnnDims;
+typedef enum {
+ Dim3d_N = 0,
+ Dim3d_C = 1,
+ Dim3d_D = 2,
+ Dim3d_H = 3,
+ Dim3d_W = 4,
+ Dim3d_O = 0,
+ Dim3d_I = 1
+} MklDnnDims3D;
+
+static const int kSmallBatchSize = 32;
+
#ifdef INTEL_MKL_ML_ONLY
class MklShape {
public:
@@ -351,6 +370,7 @@ class MklShape {
#else
// Forward decl
+TensorFormat MklDnn3DDataFormatToTFDataFormat(memory::format format);
TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format);
memory::dims CalculateTFStrides(const memory::dims& dims_tf_order);
memory::desc CreateBlockedMemDescHelper(const memory::dims& dim,
@@ -453,6 +473,13 @@ class MklDnnShape {
return this->DimSize(index);
}
+ inline size_t GetDimension3D(char dimension) const {
+ int index = GetMklDnnTensor3DDimIndex(dimension);
+ CHECK(index >= 0 && index < this->GetDimension())
+ << "Invalid index " << index << " for dimension '" << dimension << "'";
+ return this->DimSize(index);
+ }
+
inline int32 GetMklDnnTensorDimIndex(char dimension) const {
switch (dimension) {
case 'N':
@@ -469,6 +496,24 @@ class MklDnnShape {
}
}
+ inline int32 GetMklDnnTensor3DDimIndex(char dimension) const {
+ switch (dimension) {
+ case 'N':
+ return MklDnnDims3D::Dim3d_N;
+ case 'C':
+ return MklDnnDims3D::Dim3d_C;
+ case 'D':
+ return MklDnnDims3D::Dim3d_D;
+ case 'H':
+ return MklDnnDims3D::Dim3d_H;
+ case 'W':
+ return MklDnnDims3D::Dim3d_W;
+ default:
+ LOG(FATAL) << "Invalid dimension: " << dimension;
+ return -1; // Avoid compiler warning about missing return value
+ }
+ }
+
inline size_t GetDimension() const { return data_.dimension_; }
inline const int* GetSizes() const {
return reinterpret_cast<const int*>(&data_.sizes_[0]);
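A brief usage sketch for the new accessor (illustrative only; it assumes the shape was already initialized with a 5D layout, since the `CHECK` requires the index to be within `GetDimension()`):

```c++
// Illustrative: query per-dimension sizes of a 5D MKL tensor by semantic
// dimension character, mirroring the existing 4D GetDimension(char) API.
// `shape` is assumed to hold a 5D (NCDHW-ordered) tensor.
size_t batch = shape.GetDimension3D('N');  // index MklDnnDims3D::Dim3d_N == 0
size_t depth = shape.GetDimension3D('D');  // index MklDnnDims3D::Dim3d_D == 2
size_t width = shape.GetDimension3D('W');  // index MklDnnDims3D::Dim3d_W == 4
```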
@@ -587,15 +632,29 @@ class MklDnnShape {
}
inline void SetTfDimOrder(const size_t dimension, TensorFormat data_format) {
- // TODO(nhasabni): Why do we restrict this to 4D?
- CHECK_EQ(dimension, 4);
- CHECK(dimension == data_.dimension_);
- data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W;
- data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H;
- data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C;
- data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N;
+ if (dimension == 5) {
+ CHECK(dimension == data_.dimension_);
+ data_.map_[GetTensorDimIndex<3>(data_format, '0')] =
+ MklDnnDims3D::Dim3d_D;
+ data_.map_[GetTensorDimIndex<3>(data_format, '1')] =
+ MklDnnDims3D::Dim3d_H;
+ data_.map_[GetTensorDimIndex<3>(data_format, '2')] =
+ MklDnnDims3D::Dim3d_W;
+ data_.map_[GetTensorDimIndex<3>(data_format, 'C')] =
+ MklDnnDims3D::Dim3d_C;
+ data_.map_[GetTensorDimIndex<3>(data_format, 'N')] =
+ MklDnnDims3D::Dim3d_N;
+ } else {
+ CHECK_EQ(dimension, 4);
+ CHECK(dimension == data_.dimension_);
+ data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W;
+ data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H;
+ data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C;
+ data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N;
+ }
}
+
inline void SetTfDimOrder(const size_t dimension, memory::format format) {
TensorFormat data_format = MklDnnDataFormatToTFDataFormat(format);
SetTfDimOrder(dimension, data_format);
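A sketch of what the new 5D branch records (assuming the TF-side layout is NDHWC, which the 3D helpers encode as `FORMAT_NHWC`; `SetDimensions` is taken from the surrounding class, outside this hunk):

```c++
// Illustrative: for a 5D NDHWC tensor, SetTfDimOrder(5, FORMAT_NHWC) stores
// the TF-index -> MKL-DNN-index map N(0)->0, D(1)->2, H(2)->3, W(3)->4,
// C(4)->1, i.e. the permutation that reorders NDHWC into canonical NCDHW.
MklDnnShape shape;
shape.SetDimensions(5);
shape.SetTfDimOrder(5, FORMAT_NHWC);
```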
@@ -1329,6 +1388,19 @@ memory::data_type MklDnnType<float>() {
return memory::data_type::f32;
}
+/// Map TensorFlow's data format into MKL-DNN 3D data format
+/// @input: TensorFlow data format
+/// @return: memory::format corresponding to TensorFlow data format;
+/// Fails with an error if the data format is invalid.
+inline memory::format TFDataFormatToMklDnn3DDataFormat(TensorFormat format) {
+ if (format == FORMAT_NHWC)
+ return memory::format::ndhwc;
+ else if (format == FORMAT_NCHW)
+ return memory::format::ncdhw;
+ TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
+ return memory::format::format_undef;
+}
+
/// Map TensorFlow's data format into MKL-DNN data format
///
/// @input: TensorFlow data format
@@ -1340,7 +1412,6 @@ inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
else if (format == FORMAT_NCHW)
return memory::format::nchw;
TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
- // Return to get rid of compiler warning
return memory::format::format_undef;
}
@@ -1350,9 +1421,9 @@ inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
/// @return: Tensorflow data format corresponding to memory::format
/// Fails with an error if invalid data format.
inline TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format) {
- if (format == memory::format::nhwc)
+ if (format == memory::format::nhwc || format == memory::format::ndhwc)
return FORMAT_NHWC;
- else if (format == memory::format::nchw)
+ else if (format == memory::format::nchw || format == memory::format::ncdhw)
return FORMAT_NCHW;
TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
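A quick round-trip sketch of the format mapping as extended here (3D TensorFlow layouts reuse the 2D format tags):

```c++
// Illustrative: the forward map picks the 3D MKL-DNN format...
memory::format f1 = TFDataFormatToMklDnn3DDataFormat(FORMAT_NHWC);  // ndhwc
memory::format f2 = TFDataFormatToMklDnn3DDataFormat(FORMAT_NCHW);  // ncdhw
// ...and the reverse map now folds both 2D and 3D formats back onto the
// shared TF tags.
TensorFormat t1 = MklDnnDataFormatToTFDataFormat(memory::format::ndhwc);
// t1 == FORMAT_NHWC
```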
@@ -1402,6 +1473,22 @@ inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape,
return memory::dims({n, c, h, w});
}
+inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape,
+ TensorFormat format) {
+ // Check validity of format.
+ CHECK_NE(TFDataFormatToMklDnn3DDataFormat(format),
+ memory::format::format_undef);
+
+ int n = shape.dim_size(GetTensorDimIndex<3>(format, 'N'));
+ int c = shape.dim_size(GetTensorDimIndex<3>(format, 'C'));
+ int d = shape.dim_size(GetTensorDimIndex<3>(format, '0'));
+ int h = shape.dim_size(GetTensorDimIndex<3>(format, '1'));
+ int w = shape.dim_size(GetTensorDimIndex<3>(format, '2'));
+
+ // MKL-DNN requires dimensions in NCDHW format.
+ return memory::dims({n, c, d, h, w});
+}
+
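A usage sketch for the new shape helper (the concrete sizes are invented; recall that 3D ops tag NDHWC as `FORMAT_NHWC`):

```c++
// Illustrative: convert a TF NDHWC shape into MKL-DNN's required NCDHW order.
TensorShape shape({8, 16, 32, 32, 3});  // N=8, D=16, H=32, W=32, C=3 (NDHWC)
memory::dims dims = TFShapeToMklDnnDimsInNCDHW(shape, FORMAT_NHWC);
// dims == {8, 3, 16, 32, 32}, i.e. {N, C, D, H, W}
```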
/// Overloaded version of function above. Input parameters are
/// self-explanatory.
inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims,
@@ -1514,6 +1601,8 @@ class MklDnnData {
/// Operations memory descriptor
memory::desc* op_md_;
+ // Flag to indicate whether the data is 3D.
+ bool bIs3D;
/// Operations temp buffer
void* allocated_buffer_;
/// CPU engine on which operation will be executed
@@ -1540,6 +1629,10 @@ class MklDnnData {
static_cast<const void*>(tensor->flat<T>().data()));
}
+ void SetIs3DData(bool bIs3D_) { bIs3D = bIs3D_; }
+
+ bool GetIs3D() { return bIs3D; }
+
/// Set user memory primitive using specified dimensions, memory format and
/// data_buffer. Function automatically uses element data type by using
/// input type T used for creating call object.
@@ -1911,7 +2004,9 @@ const mkldnn::memory::dims NONE_DIMS = {};
template <typename T>
class MklPrimitiveFactory {
public:
- MklPrimitiveFactory() {}
+ MklPrimitiveFactory() {
+ }
+
~MklPrimitiveFactory() {}
MklPrimitive* GetOp(const string& key) {
@@ -1934,6 +2029,22 @@ class MklPrimitiveFactory {
map[key] = op;
}
+ /// Function to decide whether the hardware has AVX512 or AVX2.
+ /// On legacy devices (without AVX512 or AVX2),
+ /// MKL-DNN GEMM will be used instead.
+ static inline bool IsLegacyPlatform() {
+ return (!port::TestCPUFeature(port::CPUFeature::AVX512F)
+ && !port::TestCPUFeature(port::CPUFeature::AVX2));
+ }
+
+ /// Function to check whether primitive memory optimization is enabled.
+ static inline bool IsPrimitiveMemOptEnabled() {
+ bool is_primitive_mem_opt_enabled = true;
+ TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE", true,
+ &is_primitive_mem_opt_enabled));
+ return is_primitive_mem_opt_enabled;
+ }
+
private:
static inline std::unordered_map<string, MklPrimitive*>& GetHashMap() {
static thread_local std::unordered_map<string, MklPrimitive*> map_;
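Both new toggles are plain predicates; a sketch of how a caller might branch on them (the call sites are invented for illustration):

```c++
// Illustrative gating logic, not from this diff:
if (MklPrimitiveFactory<float>::IsLegacyPlatform()) {
  // Neither AVX512F nor AVX2 is available: fall back to MKL-DNN GEMM paths.
}
if (MklPrimitiveFactory<float>::IsPrimitiveMemOptEnabled()) {
  // Defaults to true; export TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE=false to disable
  // (the env var name is spelled exactly as in the code above).
}
```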
@@ -1971,21 +2082,25 @@ class FactoryKeyCreator {
const char delimiter = 'x';
const int kMaxKeyLength = 256;
void Append(StringPiece s) {
- key_.append(s.ToString());
+ key_.append(string(s));
key_.append(1, delimiter);
}
};
-static inline memory::format get_desired_format(int channel) {
+
+static inline memory::format get_desired_format(int channel,
+ bool is_2d = true) {
memory::format fmt_desired = memory::format::any;
- if (port::TestCPUFeature(port::CPUFeature::AVX512F) && (channel % 16) == 0) {
- fmt_desired = memory::format::nChw16c;
+ if (port::TestCPUFeature(port::CPUFeature::AVX512F)) {
+ fmt_desired = is_2d ? memory::format::nChw16c : memory::format::nCdhw16c;
} else if (port::TestCPUFeature(port::CPUFeature::AVX2) &&
(channel % 8) == 0) {
- fmt_desired = memory::format::nChw8c;
+ fmt_desired = is_2d
+ ? memory::format::nChw8c
+ : memory::format::ncdhw;  // AVX2 is not yet supported for 3D.
} else {
- fmt_desired = memory::format::nchw;
+ fmt_desired = is_2d ? memory::format::nchw : memory::format::ncdhw;
}
return fmt_desired;
}
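A usage sketch for the extended helper (which format is returned depends on the host CPU; the comments state the assumed hardware):

```c++
// Illustrative: choose a blocked memory format for a 3D convolution.
// On an AVX512F machine this returns memory::format::nCdhw16c; on AVX2-only
// hardware the 3D path currently falls back to plain ncdhw.
memory::format fmt = get_desired_format(/*channel=*/64, /*is_2d=*/false);
```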
@@ -2006,7 +2121,7 @@ class MklReorderPrimitive : public MklPrimitive {
context_.dst_mem->set_data_handle(to->get_data_handle());
}
- private:
+ private:
struct ReorderContext {
std::shared_ptr<mkldnn::memory> src_mem;
std::shared_ptr<mkldnn::memory> dst_mem;
@@ -2048,7 +2163,7 @@ class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
return instance_;
}
- private:
+ private:
MklReorderPrimitiveFactory() {}
~MklReorderPrimitiveFactory() {}
@@ -2093,6 +2208,15 @@ inline primitive FindOrCreateReorder(const memory* from, const memory* to) {
return *reorder_prim->GetPrimitive();
}
+// Utility function to determine whether a convolution is 1x1 with a stride
+// other than 1, for the purpose of temporarily disabling primitive reuse.
+inline bool IsConv1x1StrideNot1(memory::dims filter_dims, memory::dims strides) {
+ if (filter_dims.size() != 4 || strides.size() != 2) return false;
+
+ return ((filter_dims[2] == 1) && (filter_dims[3] == 1) &&
+ ((strides[0] != 1) || (strides[1] != 1)));
+}
+
#endif // INTEL_MKL_DNN
} // namespace tensorflow
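And a sketch of the 1x1-convolution check just added above (filter dims are in MKL-DNN OIHW order and strides are {stride_h, stride_w}; the values are invented):

```c++
// Illustrative: a 1x1 filter with stride 2 triggers the reuse opt-out.
memory::dims filter_dims = {64, 32, 1, 1};  // {O, I, H, W}
memory::dims strides = {2, 2};              // {stride_h, stride_w}
bool disable_reuse = IsConv1x1StrideNot1(filter_dims, strides);  // true
```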
diff --git a/tensorflow/core/util/padding.h b/tensorflow/core/util/padding.h
index a4278ff2b4..76f9b4dd9a 100644
--- a/tensorflow/core/util/padding.h
+++ b/tensorflow/core/util/padding.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_PADDING_H_
-#define TENSORFLOW_UTIL_PADDING_H_
+#ifndef TENSORFLOW_CORE_UTIL_PADDING_H_
+#define TENSORFLOW_CORE_UTIL_PADDING_H_
// This file contains helper routines to deal with padding in various ops and
// kernels.
@@ -50,4 +50,4 @@ Status GetNodeAttr(const NodeDef& node_def, StringPiece attr_name,
} // end namespace tensorflow
-#endif // TENSORFLOW_UTIL_PADDING_H_
+#endif // TENSORFLOW_CORE_UTIL_PADDING_H_
diff --git a/tensorflow/core/util/port.h b/tensorflow/core/util/port.h
index 981def9d22..e9b9cb1cd2 100644
--- a/tensorflow/core/util/port.h
+++ b/tensorflow/core/util/port.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_PORT_H_
-#define TENSORFLOW_UTIL_PORT_H_
+#ifndef TENSORFLOW_CORE_UTIL_PORT_H_
+#define TENSORFLOW_CORE_UTIL_PORT_H_
namespace tensorflow {
@@ -30,4 +30,4 @@ bool IsMklEnabled();
} // end namespace tensorflow
-#endif // TENSORFLOW_UTIL_PORT_H_
+#endif // TENSORFLOW_CORE_UTIL_PORT_H_
diff --git a/tensorflow/core/util/saved_tensor_slice_util.h b/tensorflow/core/util/saved_tensor_slice_util.h
index 90672a10a8..7c9cfa35f7 100644
--- a/tensorflow/core/util/saved_tensor_slice_util.h
+++ b/tensorflow/core/util/saved_tensor_slice_util.h
@@ -15,8 +15,8 @@ limitations under the License.
// Utilities for saving/restoring tensor slice checkpoints.
-#ifndef TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
-#define TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
+#ifndef TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
+#define TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
#include <string> // for string
#include "tensorflow/core/framework/tensor.pb.h"
@@ -210,4 +210,4 @@ inline void Fill(const string* data, size_t n, TensorProto* t) {
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
+#endif // TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
diff --git a/tensorflow/core/util/strided_slice_op.cc b/tensorflow/core/util/strided_slice_op.cc
index aca60b942d..ad8a44a518 100644
--- a/tensorflow/core/util/strided_slice_op.cc
+++ b/tensorflow/core/util/strided_slice_op.cc
@@ -326,7 +326,7 @@ Status ValidateStridedSliceOp(
// Even if we don't have values for begin or end, we do know that this
// dimension covers the whole interval. If we have shape information for
// this dimension, that tells us the interval length.
- if (dim_i > 0) {
+ if (dim_i >= 0) {
if (stride_i < 0) {
interval_length = -dim_i;
} else {
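The relaxed condition matters for empty dimensions: with `dim_i == 0` the old test treated a zero-length axis as having unknown extent. A hypothetical illustration of the fixed path:

```c++
// Hypothetical values: slicing a dimension of known size 0 with stride 1.
int64 dim_i = 0, stride_i = 1;
int64 interval_length;
if (dim_i >= 0) {  // new condition: size 0 is a known interval length
  interval_length = (stride_i < 0) ? -dim_i : dim_i;  // -> 0, an empty slice
}
// Under the old `dim_i > 0`, the branch was skipped and the length remained
// unknown, mis-sizing strided slices of empty tensors.
```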
diff --git a/tensorflow/core/util/tensor_bundle/naming.h b/tensorflow/core/util/tensor_bundle/naming.h
index 3d21570c74..6539d565e2 100644
--- a/tensorflow/core/util/tensor_bundle/naming.h
+++ b/tensorflow/core/util/tensor_bundle/naming.h
@@ -31,8 +31,8 @@ limitations under the License.
//
// Regexp can also be used: e.g. R"<prefix>.data-\d{5}-of-\d{5}" for data files.
-#ifndef TENSORFLOW_UTIL_TENSOR_BUNDLE_NAMING_H_
-#define TENSORFLOW_UTIL_TENSOR_BUNDLE_NAMING_H_
+#ifndef TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_NAMING_H_
+#define TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_NAMING_H_
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -43,4 +43,4 @@ string DataFilename(StringPiece prefix, int32 shard_id, int32 num_shards);
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_TENSOR_BUNDLE_NAMING_H_
+#endif // TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_NAMING_H_
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
index 7190614706..ea8a259d1a 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
@@ -370,14 +370,14 @@ Status PadAlignment(FileOutputBuffer* out, int alignment, int64* size) {
BundleWriter::BundleWriter(Env* env, StringPiece prefix, const Options& options)
: env_(env),
options_(options),
- prefix_(std::string(prefix)),
+ prefix_(prefix),
tmp_metadata_path_(strings::StrCat(MetaFilename(prefix_), ".tempstate",
random::New64())),
tmp_data_path_(strings::StrCat(DataFilename(prefix_, 0, 1), ".tempstate",
random::New64())),
out_(nullptr),
size_(0) {
- status_ = env_->CreateDir(std::string(io::Dirname(prefix_)));
+ status_ = env_->CreateDir(string(io::Dirname(prefix_)));
if (!status_.ok() && !errors::IsAlreadyExists(status_)) {
return;
}
@@ -394,7 +394,7 @@ BundleWriter::BundleWriter(Env* env, StringPiece prefix, const Options& options)
Status BundleWriter::Add(StringPiece key, const Tensor& val) {
if (!status_.ok()) return status_;
CHECK_NE(key, kHeaderEntryKey);
- const string key_string = std::string(key);
+ const string key_string(key);
if (entries_.find(key_string) != entries_.end()) {
status_ = errors::InvalidArgument("Adding duplicate key: ", key);
return status_;
@@ -445,7 +445,7 @@ Status BundleWriter::AddSlice(StringPiece full_tensor_key,
// In the case of a sharded save, MergeBundles() is responsible for merging
// the "slices" field of multiple metadata entries corresponding to the same
// full tensor.
- const string full_tensor_key_string = std::string(full_tensor_key);
+ const string full_tensor_key_string(full_tensor_key);
BundleEntryProto* full_entry = &entries_[full_tensor_key_string];
if (full_entry->dtype() != DT_INVALID) {
CHECK_EQ(full_entry->dtype(), slice_tensor.dtype());
@@ -600,7 +600,7 @@ static Status MergeOneBundle(Env* env, StringPiece prefix,
// Loops through the non-header to-merge entries.
BundleEntryProto to_merge_entry;
for (; iter->Valid(); iter->Next()) {
- const string key = std::string(iter->key());
+ const string key(iter->key());
const auto entry_iter = merge_state->entries.find(key);
// Illegal: the duplicated entry is a non-slice tensor.
@@ -649,7 +649,7 @@ Status MergeBundles(Env* env, gtl::ArraySlice<string> prefixes,
// Merges all metadata tables.
// TODO(zhifengc): KeyValue sorter if it becomes too big.
MergeState merge;
- Status status = env->CreateDir(std::string(io::Dirname(merged_prefix)));
+ Status status = env->CreateDir(string(io::Dirname(merged_prefix)));
if (!status.ok() && !errors::IsAlreadyExists(status)) return status;
for (int i = 0; i < prefixes.size(); ++i) {
TF_RETURN_IF_ERROR(MergeOneBundle(env, prefixes[i], &merge));
@@ -697,7 +697,7 @@ Status MergeBundles(Env* env, gtl::ArraySlice<string> prefixes,
BundleReader::BundleReader(Env* env, StringPiece prefix)
: env_(env),
- prefix_(std::string(prefix)),
+ prefix_(prefix),
metadata_(nullptr),
table_(nullptr),
iter_(nullptr) {
@@ -919,7 +919,7 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
const TensorShape full_shape(TensorShape(full_tensor_entry.shape()));
std::vector<std::pair<TensorSlice, string>> details;
- const string full_tensor_key_string = std::string(full_tensor_key);
+ const string full_tensor_key_string(full_tensor_key);
const TensorSliceSet* tss =
gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.h b/tensorflow/core/util/tensor_bundle/tensor_bundle.h
index d30ce3f0cf..3a2ffbb495 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.h
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.h
@@ -58,8 +58,8 @@ limitations under the License.
// "/fs/model/train/ckpt-step/ckpt" /* merged prefix */);
//
-#ifndef TENSORFLOW_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
-#define TENSORFLOW_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
+#ifndef TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
+#define TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
#include "tensorflow/core/protobuf/tensor_bundle.pb.h"
@@ -346,4 +346,4 @@ class FileOutputBuffer {
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
+#endif // TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
index 92ce8ae00e..59c42baa06 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
@@ -107,7 +107,7 @@ std::vector<string> AllTensorKeys(BundleReader* reader) {
reader->Seek(kHeaderEntryKey);
reader->Next();
for (; reader->Valid(); reader->Next()) {
- ret.push_back(std::string(reader->key()));
+ ret.emplace_back(reader->key());
}
return ret;
}
diff --git a/tensorflow/core/util/tensor_format.cc b/tensorflow/core/util/tensor_format.cc
index a5f7ecf0d1..f331973f5c 100644
--- a/tensorflow/core/util/tensor_format.cc
+++ b/tensorflow/core/util/tensor_format.cc
@@ -25,6 +25,10 @@ string GetConvnet3dDataFormatAttrString() {
return "data_format: { 'NDHWC', 'NCDHW' } = 'NDHWC' ";
}
+string GetConvnetDataFormat2D3DAttrString() {
+ return "data_format: { 'NHWC', 'NCHW', 'NDHWC', 'NCDHW' } = 'NHWC' ";
+}
+
string GetConvnetFilterFormatAttrString() {
return "filter_format: { 'HWIO', 'OIHW' } = 'HWIO' ";
}
diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h
index 918835e1fb..b0c349dd90 100644
--- a/tensorflow/core/util/tensor_format.h
+++ b/tensorflow/core/util/tensor_format.h
@@ -483,6 +483,7 @@ string GetConvnet3dDataFormatAttrString();
// Return the string that specifies the filter format for convnet operations.
string GetConvnetFilterFormatAttrString();
string GetConvnet3dFilterFormatAttrString();
+string GetConvnetDataFormat2D3DAttrString();
// Returns a tensor shape for the specified format and dimension sizes.
// Works for both 2D and 3D operations. The output shapes are as follows:
diff --git a/tensorflow/core/util/tensor_slice_reader.h b/tensorflow/core/util/tensor_slice_reader.h
index 263f56c7fc..4aa9a4708e 100644
--- a/tensorflow/core/util/tensor_slice_reader.h
+++ b/tensorflow/core/util/tensor_slice_reader.h
@@ -16,8 +16,8 @@ limitations under the License.
// The utility to read checkpoints for google brain tensor ops and v3
// checkpoints for dist_belief.
-#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_READER_H_
-#define TENSORFLOW_UTIL_TENSOR_SLICE_READER_H_
+#ifndef TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_H_
+#define TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_H_
#include <unordered_map>
@@ -192,4 +192,4 @@ bool TensorSliceReader::CopySliceData(const string& name,
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_TENSOR_SLICE_READER_H_
+#endif // TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_H_
diff --git a/tensorflow/core/util/tensor_slice_reader_cache.h b/tensorflow/core/util/tensor_slice_reader_cache.h
index 63a8d0b068..9f1919df4e 100644
--- a/tensorflow/core/util/tensor_slice_reader_cache.h
+++ b/tensorflow/core/util/tensor_slice_reader_cache.h
@@ -16,8 +16,8 @@ limitations under the License.
// The utility to read checkpoints for google brain tensor ops and v3
// checkpoints for dist_belief.
-#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_READER_CACHE_H_
-#define TENSORFLOW_UTIL_TENSOR_SLICE_READER_CACHE_H_
+#ifndef TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_CACHE_H_
+#define TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_CACHE_H_
#include <unordered_map>
@@ -85,4 +85,4 @@ class TensorSliceReaderCache {
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_TENSOR_SLICE_READER_CACHE_H_
+#endif // TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_CACHE_H_
diff --git a/tensorflow/core/util/tensor_slice_writer.h b/tensorflow/core/util/tensor_slice_writer.h
index 2888c66d10..0db2fb4804 100644
--- a/tensorflow/core/util/tensor_slice_writer.h
+++ b/tensorflow/core/util/tensor_slice_writer.h
@@ -16,8 +16,8 @@ limitations under the License.
// The utility to write checkpoints for google brain tensor ops and v3
// checkpoints for dist_belief.
-#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_WRITER_H_
-#define TENSORFLOW_UTIL_TENSOR_SLICE_WRITER_H_
+#ifndef TENSORFLOW_CORE_UTIL_TENSOR_SLICE_WRITER_H_
+#define TENSORFLOW_CORE_UTIL_TENSOR_SLICE_WRITER_H_
#include <unordered_map>
@@ -192,4 +192,4 @@ Status CreateTableTensorSliceBuilder(const string& filename,
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_TENSOR_SLICE_WRITER_H_
+#endif // TENSORFLOW_CORE_UTIL_TENSOR_SLICE_WRITER_H_
diff --git a/tensorflow/core/util/util.h b/tensorflow/core/util/util.h
index 4adf2f14dc..93dfd51ab5 100644
--- a/tensorflow/core/util/util.h
+++ b/tensorflow/core/util/util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_UTIL_H_
-#define TENSORFLOW_UTIL_UTIL_H_
+#ifndef TENSORFLOW_CORE_UTIL_UTIL_H_
+#define TENSORFLOW_CORE_UTIL_UTIL_H_
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -58,4 +58,4 @@ string SliceDebugString(const TensorShape& shape, const int64 flat);
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_UTIL_H_
+#endif // TENSORFLOW_CORE_UTIL_UTIL_H_
diff --git a/tensorflow/core/util/work_sharder.h b/tensorflow/core/util/work_sharder.h
index 72ce493c1b..b12c31c1ae 100644
--- a/tensorflow/core/util/work_sharder.h
+++ b/tensorflow/core/util/work_sharder.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_WORK_SHARDER_H_
-#define TENSORFLOW_UTIL_WORK_SHARDER_H_
+#ifndef TENSORFLOW_CORE_UTIL_WORK_SHARDER_H_
+#define TENSORFLOW_CORE_UTIL_WORK_SHARDER_H_
#include <functional>
@@ -95,4 +95,4 @@ class Sharder {
} // end namespace tensorflow
-#endif // TENSORFLOW_UTIL_WORK_SHARDER_H_
+#endif // TENSORFLOW_CORE_UTIL_WORK_SHARDER_H_
diff --git a/tensorflow/docs_src/README.md b/tensorflow/docs_src/README.md
new file mode 100644
index 0000000000..5b824f1150
--- /dev/null
+++ b/tensorflow/docs_src/README.md
@@ -0,0 +1,3 @@
+# This directory has moved
+
+The new location is: https://github.com/tensorflow/docs/
diff --git a/tensorflow/docs_src/about/attribution.md b/tensorflow/docs_src/about/attribution.md
deleted file mode 100644
index a4858b400a..0000000000
--- a/tensorflow/docs_src/about/attribution.md
+++ /dev/null
@@ -1,9 +0,0 @@
-# Attribution
-
-Please only use the TensorFlow name and marks when accurately referencing this
-software distribution, and do not use our marks in a way that suggests you are
-endorsed by or otherwise affiliated with Google. When referring to our marks,
-please include the following attribution statement: "TensorFlow, the TensorFlow
-logo and any related marks are trademarks of Google Inc."
-
-
diff --git a/tensorflow/docs_src/about/bib.md b/tensorflow/docs_src/about/bib.md
deleted file mode 100644
index 5593a3d95c..0000000000
--- a/tensorflow/docs_src/about/bib.md
+++ /dev/null
@@ -1,131 +0,0 @@
-# TensorFlow White Papers
-
-This document identifies white papers about TensorFlow.
-
-## Large-Scale Machine Learning on Heterogeneous Distributed Systems
-
-[Access this white paper.](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45166.pdf)
-
-**Abstract:** TensorFlow is an interface for expressing machine learning
-algorithms, and an implementation for executing such algorithms.
-A computation expressed using TensorFlow can be
-executed with little or no change on a wide variety of heterogeneous
-systems, ranging from mobile devices such as phones
-and tablets up to large-scale distributed systems of hundreds
-of machines and thousands of computational devices such as
-GPU cards. The system is flexible and can be used to express
-a wide variety of algorithms, including training and inference
-algorithms for deep neural network models, and it has been
-used for conducting research and for deploying machine learning
-systems into production across more than a dozen areas of
-computer science and other fields, including speech recognition,
-computer vision, robotics, information retrieval, natural
-language processing, geographic information extraction, and
-computational drug discovery. This paper describes the TensorFlow
-interface and an implementation of that interface that
-we have built at Google. The TensorFlow API and a reference
-implementation were released as an open-source package under
-the Apache 2.0 license in November, 2015 and are available at
-www.tensorflow.org.
-
-
-### In BibTeX format
-
-If you use TensorFlow in your research and would like to cite the TensorFlow
-system, we suggest you cite this whitepaper.
-
-<pre>
-@misc{tensorflow2015-whitepaper,
-title={ {TensorFlow}: Large-Scale Machine Learning on Heterogeneous Systems},
-url={https://www.tensorflow.org/},
-note={Software available from tensorflow.org},
-author={
- Mart\'{\i}n~Abadi and
- Ashish~Agarwal and
- Paul~Barham and
- Eugene~Brevdo and
- Zhifeng~Chen and
- Craig~Citro and
- Greg~S.~Corrado and
- Andy~Davis and
- Jeffrey~Dean and
- Matthieu~Devin and
- Sanjay~Ghemawat and
- Ian~Goodfellow and
- Andrew~Harp and
- Geoffrey~Irving and
- Michael~Isard and
- Yangqing Jia and
- Rafal~Jozefowicz and
- Lukasz~Kaiser and
- Manjunath~Kudlur and
- Josh~Levenberg and
- Dandelion~Man\'{e} and
- Rajat~Monga and
- Sherry~Moore and
- Derek~Murray and
- Chris~Olah and
- Mike~Schuster and
- Jonathon~Shlens and
- Benoit~Steiner and
- Ilya~Sutskever and
- Kunal~Talwar and
- Paul~Tucker and
- Vincent~Vanhoucke and
- Vijay~Vasudevan and
- Fernanda~Vi\'{e}gas and
- Oriol~Vinyals and
- Pete~Warden and
- Martin~Wattenberg and
- Martin~Wicke and
- Yuan~Yu and
- Xiaoqiang~Zheng},
- year={2015},
-}
-</pre>
-
-Or in textual form:
-
-<pre>
-Martín Abadi, Ashish Agarwal, Paul Barham, Eugene Brevdo,
-Zhifeng Chen, Craig Citro, Greg S. Corrado, Andy Davis,
-Jeffrey Dean, Matthieu Devin, Sanjay Ghemawat, Ian Goodfellow,
-Andrew Harp, Geoffrey Irving, Michael Isard, Rafal Jozefowicz, Yangqing Jia,
-Lukasz Kaiser, Manjunath Kudlur, Josh Levenberg, Dan Mané, Mike Schuster,
-Rajat Monga, Sherry Moore, Derek Murray, Chris Olah, Jonathon Shlens,
-Benoit Steiner, Ilya Sutskever, Kunal Talwar, Paul Tucker,
-Vincent Vanhoucke, Vijay Vasudevan, Fernanda Viégas,
-Oriol Vinyals, Pete Warden, Martin Wattenberg, Martin Wicke,
-Yuan Yu, and Xiaoqiang Zheng.
-TensorFlow: Large-scale machine learning on heterogeneous systems,
-2015. Software available from tensorflow.org.
-</pre>
-
-
-
-## TensorFlow: A System for Large-Scale Machine Learning
-
-[Access this white paper.](https://www.usenix.org/system/files/conference/osdi16/osdi16-abadi.pdf)
-
-**Abstract:** TensorFlow is a machine learning system that operates at
-large scale and in heterogeneous environments. TensorFlow
-uses dataflow graphs to represent computation,
-shared state, and the operations that mutate that state. It
-maps the nodes of a dataflow graph across many machines
-in a cluster, and within a machine across multiple computational
-devices, including multicore CPUs, general-purpose
-GPUs, and custom-designed ASICs known as
-Tensor Processing Units (TPUs). This architecture gives
-flexibility to the application developer: whereas in previous
-“parameter server” designs the management of shared
-state is built into the system, TensorFlow enables developers
-to experiment with novel optimizations and training algorithms.
-TensorFlow supports a variety of applications,
-with a focus on training and inference on deep neural networks.
-Several Google services use TensorFlow in production,
-we have released it as an open-source project, and
-it has become widely used for machine learning research.
-In this paper, we describe the TensorFlow dataflow model
-and demonstrate the compelling performance that TensorFlow
-achieves for several real-world applications.
-
diff --git a/tensorflow/docs_src/about/index.md b/tensorflow/docs_src/about/index.md
deleted file mode 100644
index dc1e9af876..0000000000
--- a/tensorflow/docs_src/about/index.md
+++ /dev/null
@@ -1,11 +0,0 @@
-# About TensorFlow
-
-This section provides a few documents about TensorFlow itself,
-including the following:
-
- * @{$uses$TensorFlow in Use}, which provides a link to our model zoo and
- lists some popular ways that TensorFlow is being used.
- * @{$bib$TensorFlow White Papers}, which provides abstracts of white papers
- about TensorFlow.
- * @{$attribution$Attribution}, which specifies how to attribute and refer
- to TensorFlow.
diff --git a/tensorflow/docs_src/about/leftnav_files b/tensorflow/docs_src/about/leftnav_files
deleted file mode 100644
index 63763b9d9c..0000000000
--- a/tensorflow/docs_src/about/leftnav_files
+++ /dev/null
@@ -1,4 +0,0 @@
-index.md
-uses.md
-bib.md
-attribution.md
diff --git a/tensorflow/docs_src/about/uses.md b/tensorflow/docs_src/about/uses.md
deleted file mode 100644
index d3db98203e..0000000000
--- a/tensorflow/docs_src/about/uses.md
+++ /dev/null
@@ -1,68 +0,0 @@
-# TensorFlow In Use
-
-This page highlights TensorFlow models in real world use.
-
-
-## Model zoo
-
-Please visit our collection of TensorFlow models in the
-[TensorFlow Zoo](https://github.com/tensorflow/models).
-
-If you have built a model with TensorFlow, please consider publishing it in
-the Zoo.
-
-
-## Current uses
-
-This section describes some of the current uses of the TensorFlow system.
-
-> If you are using TensorFlow for research, for education, or for production
-> usage in some product, we would love to add something about your usage here.
-> Please feel free to [email us](mailto:usecases@tensorflow.org) a brief
-> description of how you're using TensorFlow, or even better, send us a
-> pull request to add an entry to this file.
-
-* **Deep Speech**
-<ul>
- <li>**Organization**: Mozilla</li>
- <li> **Domain**: Speech Recognition</li>
- <li> **Description**: A TensorFlow implementation motivated by Baidu's Deep Speech architecture.</li>
- <li> **More info**: [GitHub Repo](https://github.com/mozilla/deepspeech)</li>
-</ul>
-
-* **RankBrain**
-<ul>
- <li>**Organization**: Google</li>
- <li> **Domain**: Information Retrieval</li>
- <li> **Description**: A large-scale deployment of deep neural nets for search ranking on www.google.com.</li>
- <li> **More info**: ["Google Turning Over Its Lucrative Search to AI Machines"](http://www.bloomberg.com/news/articles/2015-10-26/google-turning-its-lucrative-web-search-over-to-ai-machines)</li>
-</ul>
-
-* **Inception Image Classification Model**
-<ul>
- <li> **Organization**: Google</li>
- <li> **Description**: Baseline model and follow on research into highly accurate computer vision models, starting with the model that won the 2014 Imagenet image classification challenge</li>
- <li> **More Info**: Baseline model described in [Arxiv paper](http://arxiv.org/abs/1409.4842)</li>
-</ul>
-
-* **SmartReply**
-<ul>
- <li> **Organization**: Google</li>
- <li> **Description**: Deep LSTM model to automatically generate email responses</li>
- <li> **More Info**: [Google research blog post](http://googleresearch.blogspot.com/2015/11/computer-respond-to-this-email.html)</li>
-</ul>
-
-* **Massively Multitask Networks for Drug Discovery**
-<ul>
- <li> **Organization**: Google and Stanford University</li>
- <li> **Domain**: Drug discovery</li>
- <li> **Description**: A deep neural network model for identifying promising drug candidates.</li>
- <li> **More info**: [Arxiv paper](http://arxiv.org/abs/1502.02072)</li>
-</ul>
-
-* **On-Device Computer Vision for OCR**
-<ul>
- <li> **Organization**: Google</li>
- <li> **Description**: On-device computer vision model to do optical character recognition to enable real-time translation.</li>
- <li> **More info**: [Google Research blog post](http://googleresearch.blogspot.com/2015/07/how-google-translate-squeezes-deep.html)</li>
-</ul>
diff --git a/tensorflow/docs_src/api_guides/cc/guide.md b/tensorflow/docs_src/api_guides/cc/guide.md
deleted file mode 100644
index 2cd645afa7..0000000000
--- a/tensorflow/docs_src/api_guides/cc/guide.md
+++ /dev/null
@@ -1,301 +0,0 @@
-# C++ API
-
-Note: By default [tensorflow.org](https://www.tensorflow.org) shows docs for the
-most recent stable version. The instructions in this doc require building from
-source. You will probably want to build from the `master` version of tensorflow.
-You should, as a result, be sure you are following the
-[`master` version of this doc](https://www.tensorflow.org/versions/master/api_guides/cc/guide),
-in case there have been any changes.
-
-Note: The C++ API is only designed to work with TensorFlow `bazel build`.
-If you need a stand-alone option, use the [C API](../../install/install_c.md).
-See [these instructions](https://docs.bazel.build/versions/master/external.html)
-for details on how to include TensorFlow as a subproject (instead of building
-your project from inside TensorFlow, as in this example).
-
-[TOC]
-
-TensorFlow's C++ API provides mechanisms for constructing and executing a data
-flow graph. The API is designed to be simple and concise: graph operations are
-clearly expressed using a "functional" construction style, including easy
-specification of names, device placement, etc., and the resulting graph can be
-efficiently run and the desired outputs fetched in a few lines of code. This
-guide explains the basic concepts and data structures needed to get started with
-TensorFlow graph construction and execution in C++.
-
-## The Basics
-
-Let's start with a simple example that illustrates graph construction and
-execution using the C++ API.
-
-```c++
-// tensorflow/cc/example/example.cc
-
-#include "tensorflow/cc/client/client_session.h"
-#include "tensorflow/cc/ops/standard_ops.h"
-#include "tensorflow/core/framework/tensor.h"
-
-int main() {
- using namespace tensorflow;
- using namespace tensorflow::ops;
- Scope root = Scope::NewRootScope();
- // Matrix A = [3 2; -1 0]
- auto A = Const(root, { {3.f, 2.f}, {-1.f, 0.f} });
- // Vector b = [3 5]
- auto b = Const(root, { {3.f, 5.f} });
- // v = Ab^T
- auto v = MatMul(root.WithOpName("v"), A, b, MatMul::TransposeB(true));
- std::vector<Tensor> outputs;
- ClientSession session(root);
- // Run and fetch v
- TF_CHECK_OK(session.Run({v}, &outputs));
- // Expect outputs[0] == [19; -3]
- LOG(INFO) << outputs[0].matrix<float>();
- return 0;
-}
-```
-
-Place this example code in the file `tensorflow/cc/example/example.cc` inside a
-clone of the
-TensorFlow
-[github repository](http://www.github.com/tensorflow/tensorflow). Also place a
-`BUILD` file in the same directory with the following contents:
-
-```python
-load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
-
-tf_cc_binary(
- name = "example",
- srcs = ["example.cc"],
- deps = [
- "//tensorflow/cc:cc_ops",
- "//tensorflow/cc:client_session",
- "//tensorflow/core:tensorflow",
- ],
-)
-```
-
-Use `tf_cc_binary` rather than Bazel's native `cc_binary` to link in necessary
-symbols from `libtensorflow_framework.so`. You should be able to build and run
-the example using the following command (be sure to run `./configure` in your
-build sandbox first):
-
-```shell
-bazel run -c opt //tensorflow/cc/example:example
-```
-
-This example shows some of the important features of the C++ API such as the
-following:
-
-* Constructing tensor constants from C++ nested initializer lists
-* Constructing and naming of TensorFlow operations
-* Specifying optional attributes to operation constructors
-* Executing and fetching the tensor values from the TensorFlow session.
-
-We will delve into the details of each below.
-
-## Graph Construction
-
-### Scope
-
-`tensorflow::Scope` is the main data structure that holds the current state
-of graph construction. A `Scope` acts as a handle to the graph being
-constructed, as well as storing TensorFlow operation properties. The `Scope`
-object is the first argument to operation constructors, and operations that use
-a given `Scope` as their first argument inherit that `Scope`'s properties, such
-as a common name prefix. Multiple `Scope`s can refer to the same graph, as
-explained further below.
-
-Create a new `Scope` object by calling `Scope::NewRootScope`. This creates
-some resources such as a graph to which operations are added. It also creates a
-`tensorflow::Status` object which will be used to indicate errors encountered
-when constructing operations. The `Scope` class has value semantics, thus, a
-`Scope` object can be freely copied and passed around.
-
-The `Scope` object returned by `Scope::NewRootScope` is referred
-to as the root scope. "Child" scopes can be constructed from the root scope by
-calling various member functions of the `Scope` class, thus forming a hierarchy
-of scopes. A child scope inherits all of the properties of the parent scope and
-typically has one property added or changed. For instance, `NewSubScope(name)`
-appends `name` to the prefix of names for operations created using the returned
-`Scope` object.
-
-Here are some of the properties controlled by a `Scope` object:
-
-* Operation names
-* Set of control dependencies for an operation
-* Device placement for an operation
-* Kernel attribute for an operation
-
-Please refer to `tensorflow::Scope` for the complete list of member functions
-that let you create child scopes with new properties.
-
-### Operation Constructors
-
-You can create graph operations with operation constructors, one C++ class per
-TensorFlow operation. Unlike the Python API which uses snake-case to name the
-operation constructors, the C++ API uses camel-case to conform to C++ coding
-style. For instance, the `MatMul` operation has a C++ class with the same name.
-
-Using this class-per-operation method, it is possible, though not recommended,
-to construct an operation as follows:
-
-```c++
-// Not recommended
-MatMul m(scope, a, b);
-```
-
-Instead, we recommend the following "functional" style for constructing
-operations:
-
-```c++
-// Recommended
-auto m = MatMul(scope, a, b);
-```
-
-The first parameter for all operation constructors is always a `Scope` object.
-Tensor inputs and mandatory attributes form the rest of the arguments.
-
-For optional arguments, constructors have an optional parameter that allows
-optional attributes. For operations with optional arguments, the constructor's
-last optional parameter is a `struct` type called `[operation]::Attrs` that
-contains data members for each optional attribute. You can construct such
-`Attrs` in multiple ways:
-
-* You can specify a single optional attribute by constructing an `Attrs` object
-using the `static` functions provided in the C++ class for the operation. For
-example:
-
-```c++
-auto m = MatMul(scope, a, b, MatMul::TransposeA(true));
-```
-
-* You can specify multiple optional attributes by chaining together functions
- available in the `Attrs` struct. For example:
-
-```c++
-auto m = MatMul(scope, a, b, MatMul::TransposeA(true).TransposeB(true));
-
-// Or, alternatively
-auto m = MatMul(scope, a, b, MatMul::Attrs().TransposeA(true).TransposeB(true));
-```
-
-The arguments and return values of operations are handled in different ways
-depending on their type:
-
-* For operations that return single tensors, the object returned by
- the operation object can be passed directly to other operation
- constructors. For example:
-
-```c++
-auto m = MatMul(scope, x, W);
-auto sum = Add(scope, m, bias);
-```
-
-* For operations producing multiple outputs, the object returned by the
- operation constructor has a member for each of the outputs. The names of those
- members are identical to the names present in the `OpDef` for the
- operation. For example:
-
-```c++
-auto u = Unique(scope, a);
-// u.y has the unique values and u.idx has the unique indices
-auto m = Add(scope, u.y, b);
-```
-
-* Operations producing a list-typed output return an object that can
- be indexed using the `[]` operator. That object can also be directly passed to
- other constructors that expect list-typed inputs. For example:
-
-```c++
-auto s = Split(scope, 0, a, 2);
-// Access elements of the returned list.
-auto b = Add(scope, s[0], s[1]);
-// Pass the list as a whole to other constructors.
-auto c = Concat(scope, s, 0);
-```
-
-### Constants
-
-You may pass many different types of C++ values directly to tensor
-constants. You may explicitly create a tensor constant by calling the
-`tensorflow::ops::Const` function from various kinds of C++ values. For
-example:
-
-* Scalars
-
-```c++
-auto f = Const(scope, 42.0f);
-auto s = Const(scope, "hello world!");
-```
-
-* Nested initializer lists
-
-```c++
-// 2x2 matrix
-auto c1 = Const(scope, { {1, 2}, {2, 4} });
-// 1x3x1 tensor
-auto c2 = Const(scope, { { {1}, {2}, {3} } });
-// 1x2x0 tensor
-auto c3 = ops::Const(scope, { { {}, {} } });
-```
-
-* Shapes explicitly specified
-
-```c++
-// 2x2 matrix with all elements = 10
-auto c1 = Const(scope, 10, /* shape */ {2, 2});
-// 1x3x2x1 tensor
-auto c2 = Const(scope, {1, 2, 3, 4, 5, 6}, /* shape */ {1, 3, 2, 1});
-```
-
-You may directly pass constants to other operation constructors, either by
-explicitly constructing one using the `Const` function, or implicitly as any of
-the above types of C++ values. For example:
-
-```c++
-// [1 1] * [41; 1]
-auto x = MatMul(scope, { {1, 1} }, { {41}, {1} });
-// [1 2 3 4] + 10
-auto y = Add(scope, {1, 2, 3, 4}, 10);
-```
-
-## Graph Execution
-
-When executing a graph, you will need a session. The C++ API provides a
-`tensorflow::ClientSession` class that will execute ops created by the
-operation constructors. TensorFlow will automatically determine which parts of
-the graph need to be executed, and what values need feeding. For example:
-
-```c++
-Scope root = Scope::NewRootScope();
-auto c = Const(root, { {1, 1} });
-auto m = MatMul(root, c, { {42}, {1} });
-
-ClientSession session(root);
-std::vector<Tensor> outputs;
-session.Run({m}, &outputs);
-// outputs[0] == {42}
-```
-
-Similarly, the object returned by the operation constructor can be used as the
-argument to specify a value being fed when executing the graph. Furthermore, the
-value to feed can be specified with the different kinds of C++ values used to
-specify tensor constants. For example:
-
-```c++
-Scope root = Scope::NewRootScope();
-auto a = Placeholder(root, DT_INT32);
-// [3 3; 3 3]
-auto b = Const(root, 3, {2, 2});
-auto c = Add(root, a, b);
-ClientSession session(root);
-std::vector<Tensor> outputs;
-
-// Feed a <- [1 2; 3 4]
-session.Run({ {a, { {1, 2}, {3, 4} } } }, {c}, &outputs);
-// outputs[0] == [4 5; 6 7]
-```
-
-Please see the `tensorflow::Tensor` documentation for more information on how
-to use the execution output.
diff --git a/tensorflow/docs_src/api_guides/python/array_ops.md b/tensorflow/docs_src/api_guides/python/array_ops.md
deleted file mode 100644
index ddeea80c56..0000000000
--- a/tensorflow/docs_src/api_guides/python/array_ops.md
+++ /dev/null
@@ -1,87 +0,0 @@
-# Tensor Transformations
-
-Note: Functions taking `Tensor` arguments can also take anything accepted by
-`tf.convert_to_tensor`.
-
-[TOC]
-
-## Casting
-
-TensorFlow provides several operations that you can use to cast tensor data
-types in your graph.
-
-* `tf.string_to_number`
-* `tf.to_double`
-* `tf.to_float`
-* `tf.to_bfloat16`
-* `tf.to_int32`
-* `tf.to_int64`
-* `tf.cast`
-* `tf.bitcast`
-* `tf.saturate_cast`
-
-## Shapes and Shaping
-
-TensorFlow provides several operations that you can use to determine the shape
-of a tensor and change the shape of a tensor.
-
-* `tf.broadcast_dynamic_shape`
-* `tf.broadcast_static_shape`
-* `tf.shape`
-* `tf.shape_n`
-* `tf.size`
-* `tf.rank`
-* `tf.reshape`
-* `tf.squeeze`
-* `tf.expand_dims`
-* `tf.meshgrid`
-
-## Slicing and Joining
-
-TensorFlow provides several operations to slice or extract parts of a tensor,
-or join multiple tensors together.
-
-* `tf.slice`
-* `tf.strided_slice`
-* `tf.split`
-* `tf.tile`
-* `tf.pad`
-* `tf.concat`
-* `tf.stack`
-* `tf.parallel_stack`
-* `tf.unstack`
-* `tf.reverse_sequence`
-* `tf.reverse`
-* `tf.reverse_v2`
-* `tf.transpose`
-* `tf.extract_image_patches`
-* `tf.space_to_batch_nd`
-* `tf.space_to_batch`
-* `tf.required_space_to_batch_paddings`
-* `tf.batch_to_space_nd`
-* `tf.batch_to_space`
-* `tf.space_to_depth`
-* `tf.depth_to_space`
-* `tf.gather`
-* `tf.gather_nd`
-* `tf.unique_with_counts`
-* `tf.scatter_nd`
-* `tf.dynamic_partition`
-* `tf.dynamic_stitch`
-* `tf.boolean_mask`
-* `tf.one_hot`
-* `tf.sequence_mask`
-* `tf.dequantize`
-* `tf.quantize_v2`
-* `tf.quantized_concat`
-* `tf.setdiff1d`
-
-## Fake quantization
-Operations used to help train for better quantization accuracy.
-
-* `tf.fake_quant_with_min_max_args`
-* `tf.fake_quant_with_min_max_args_gradient`
-* `tf.fake_quant_with_min_max_vars`
-* `tf.fake_quant_with_min_max_vars_gradient`
-* `tf.fake_quant_with_min_max_vars_per_channel`
-* `tf.fake_quant_with_min_max_vars_per_channel_gradient`
diff --git a/tensorflow/docs_src/api_guides/python/check_ops.md b/tensorflow/docs_src/api_guides/python/check_ops.md
deleted file mode 100644
index b52fdaa3ab..0000000000
--- a/tensorflow/docs_src/api_guides/python/check_ops.md
+++ /dev/null
@@ -1,19 +0,0 @@
-# Asserts and boolean checks
-
-* `tf.assert_negative`
-* `tf.assert_positive`
-* `tf.assert_proper_iterable`
-* `tf.assert_non_negative`
-* `tf.assert_non_positive`
-* `tf.assert_equal`
-* `tf.assert_integer`
-* `tf.assert_less`
-* `tf.assert_less_equal`
-* `tf.assert_greater`
-* `tf.assert_greater_equal`
-* `tf.assert_rank`
-* `tf.assert_rank_at_least`
-* `tf.assert_type`
-* `tf.is_non_decreasing`
-* `tf.is_numeric_tensor`
-* `tf.is_strictly_increasing`
diff --git a/tensorflow/docs_src/api_guides/python/client.md b/tensorflow/docs_src/api_guides/python/client.md
deleted file mode 100644
index 56367e6671..0000000000
--- a/tensorflow/docs_src/api_guides/python/client.md
+++ /dev/null
@@ -1,36 +0,0 @@
-# Running Graphs
-[TOC]
-
-This library contains classes for launching graphs and executing operations.
-
-@{$guide/low_level_intro$This guide} has examples of how a graph
-is launched in a `tf.Session`.
-
-## Session management
-
-* `tf.Session`
-* `tf.InteractiveSession`
-* `tf.get_default_session`
-
-## Error classes and convenience functions
-
-* `tf.OpError`
-* `tf.errors.CancelledError`
-* `tf.errors.UnknownError`
-* `tf.errors.InvalidArgumentError`
-* `tf.errors.DeadlineExceededError`
-* `tf.errors.NotFoundError`
-* `tf.errors.AlreadyExistsError`
-* `tf.errors.PermissionDeniedError`
-* `tf.errors.UnauthenticatedError`
-* `tf.errors.ResourceExhaustedError`
-* `tf.errors.FailedPreconditionError`
-* `tf.errors.AbortedError`
-* `tf.errors.OutOfRangeError`
-* `tf.errors.UnimplementedError`
-* `tf.errors.InternalError`
-* `tf.errors.UnavailableError`
-* `tf.errors.DataLossError`
-* `tf.errors.exception_type_from_error_code`
-* `tf.errors.error_code_from_exception_type`
-* `tf.errors.raise_exception_on_not_ok_status`
diff --git a/tensorflow/docs_src/api_guides/python/constant_op.md b/tensorflow/docs_src/api_guides/python/constant_op.md
deleted file mode 100644
index 498ec3db5d..0000000000
--- a/tensorflow/docs_src/api_guides/python/constant_op.md
+++ /dev/null
@@ -1,87 +0,0 @@
-# Constants, Sequences, and Random Values
-
-Note: Functions taking `Tensor` arguments can also take anything accepted by
-`tf.convert_to_tensor`.
-
-[TOC]
-
-## Constant Value Tensors
-
-TensorFlow provides several operations that you can use to generate constants.
-
-* `tf.zeros`
-* `tf.zeros_like`
-* `tf.ones`
-* `tf.ones_like`
-* `tf.fill`
-* `tf.constant`
-
-## Sequences
-
-* `tf.linspace`
-* `tf.range`
-
-## Random Tensors
-
-TensorFlow has several ops that create random tensors with different
-distributions. The random ops are stateful, and create new random values each
-time they are evaluated.
-
-The `seed` keyword argument in these functions acts in conjunction with
-the graph-level random seed. Changing either the graph-level seed using
-`tf.set_random_seed` or the
-op-level seed will change the underlying seed of these operations. Setting
-neither graph-level nor op-level seed results in a random seed for all
-operations.
-See `tf.set_random_seed`
-for details on the interaction between operation-level and graph-level random
-seeds.
-
-### Examples:
-
-```python
-# Create a tensor of shape [2, 3] consisting of random normal values, with mean
-# -1 and standard deviation 4.
-norm = tf.random_normal([2, 3], mean=-1, stddev=4)
-
-# Shuffle the first dimension of a tensor
-c = tf.constant([[1, 2], [3, 4], [5, 6]])
-shuff = tf.random_shuffle(c)
-
-# Each time we run these ops, different results are generated
-sess = tf.Session()
-print(sess.run(norm))
-print(sess.run(norm))
-
-# Set an op-level seed to generate repeatable sequences across sessions.
-norm = tf.random_normal([2, 3], seed=1234)
-sess = tf.Session()
-print(sess.run(norm))
-print(sess.run(norm))
-sess = tf.Session()
-print(sess.run(norm))
-print(sess.run(norm))
-```
-
-Another common use of random values is the initialization of variables. Also see
-the @{$variables$Variables How To}.
-
-```python
-# Use random uniform values in [0, 1) as the initializer for a variable of shape
-# [2, 3]. The default type is float32.
-var = tf.Variable(tf.random_uniform([2, 3]), name="var")
-init = tf.global_variables_initializer()
-
-sess = tf.Session()
-sess.run(init)
-print(sess.run(var))
-```
-
-* `tf.random_normal`
-* `tf.truncated_normal`
-* `tf.random_uniform`
-* `tf.random_shuffle`
-* `tf.random_crop`
-* `tf.multinomial`
-* `tf.random_gamma`
-* `tf.set_random_seed`
diff --git a/tensorflow/docs_src/api_guides/python/contrib.crf.md b/tensorflow/docs_src/api_guides/python/contrib.crf.md
deleted file mode 100644
index a544f136b3..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.crf.md
+++ /dev/null
@@ -1,11 +0,0 @@
-# CRF (contrib)
-
-Linear-chain CRF layer.
-
-* `tf.contrib.crf.crf_sequence_score`
-* `tf.contrib.crf.crf_log_norm`
-* `tf.contrib.crf.crf_log_likelihood`
-* `tf.contrib.crf.crf_unary_score`
-* `tf.contrib.crf.crf_binary_score`
-* `tf.contrib.crf.CrfForwardRnnCell`
-* `tf.contrib.crf.viterbi_decode`
diff --git a/tensorflow/docs_src/api_guides/python/contrib.ffmpeg.md b/tensorflow/docs_src/api_guides/python/contrib.ffmpeg.md
deleted file mode 100644
index 7df7547131..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.ffmpeg.md
+++ /dev/null
@@ -1,23 +0,0 @@
-# FFmpeg (contrib)
-[TOC]
-
-## Encoding and decoding audio using FFmpeg
-
-TensorFlow provides Ops to decode and encode audio files using the
-[FFmpeg](https://www.ffmpeg.org/) library. FFmpeg must be
-locally [installed](https://ffmpeg.org/download.html) for these Ops to succeed.
-
-Example:
-
-```python
-from tensorflow.contrib import ffmpeg
-
-audio_binary = tf.read_file('song.mp3')
-waveform = ffmpeg.decode_audio(
- audio_binary, file_format='mp3', samples_per_second=44100, channel_count=2)
-uncompressed_binary = ffmpeg.encode_audio(
- waveform, file_format='wav', samples_per_second=44100)
-```
-
-* `tf.contrib.ffmpeg.decode_audio`
-* `tf.contrib.ffmpeg.encode_audio`
diff --git a/tensorflow/docs_src/api_guides/python/contrib.framework.md b/tensorflow/docs_src/api_guides/python/contrib.framework.md
deleted file mode 100644
index 00fb8b0ac3..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.framework.md
+++ /dev/null
@@ -1,64 +0,0 @@
-# Framework (contrib)
-[TOC]
-
-Framework utilities.
-
-* `tf.contrib.framework.assert_same_float_dtype`
-* `tf.contrib.framework.assert_scalar`
-* `tf.contrib.framework.assert_scalar_int`
-* `tf.convert_to_tensor_or_sparse_tensor`
-* `tf.contrib.framework.get_graph_from_inputs`
-* `tf.is_numeric_tensor`
-* `tf.is_non_decreasing`
-* `tf.is_strictly_increasing`
-* `tf.contrib.framework.is_tensor`
-* `tf.contrib.framework.reduce_sum_n`
-* `tf.contrib.framework.remove_squeezable_dimensions`
-* `tf.contrib.framework.with_shape`
-* `tf.contrib.framework.with_same_shape`
-
-## Deprecation
-
-* `tf.contrib.framework.deprecated`
-* `tf.contrib.framework.deprecated_args`
-* `tf.contrib.framework.deprecated_arg_values`
-
-## Arg_Scope
-
-* `tf.contrib.framework.arg_scope`
-* `tf.contrib.framework.add_arg_scope`
-* `tf.contrib.framework.has_arg_scope`
-* `tf.contrib.framework.arg_scoped_arguments`
-
-## Variables
-
-* `tf.contrib.framework.add_model_variable`
-* `tf.train.assert_global_step`
-* `tf.contrib.framework.assert_or_get_global_step`
-* `tf.contrib.framework.assign_from_checkpoint`
-* `tf.contrib.framework.assign_from_checkpoint_fn`
-* `tf.contrib.framework.assign_from_values`
-* `tf.contrib.framework.assign_from_values_fn`
-* `tf.contrib.framework.create_global_step`
-* `tf.contrib.framework.filter_variables`
-* `tf.train.get_global_step`
-* `tf.contrib.framework.get_or_create_global_step`
-* `tf.contrib.framework.get_local_variables`
-* `tf.contrib.framework.get_model_variables`
-* `tf.contrib.framework.get_unique_variable`
-* `tf.contrib.framework.get_variables_by_name`
-* `tf.contrib.framework.get_variables_by_suffix`
-* `tf.contrib.framework.get_variables_to_restore`
-* `tf.contrib.framework.get_variables`
-* `tf.contrib.framework.local_variable`
-* `tf.contrib.framework.model_variable`
-* `tf.contrib.framework.variable`
-* `tf.contrib.framework.VariableDeviceChooser`
-* `tf.contrib.framework.zero_initializer`
-
-## Checkpoint utilities
-
-* `tf.contrib.framework.load_checkpoint`
-* `tf.contrib.framework.list_variables`
-* `tf.contrib.framework.load_variable`
-* `tf.contrib.framework.init_from_checkpoint`
diff --git a/tensorflow/docs_src/api_guides/python/contrib.graph_editor.md b/tensorflow/docs_src/api_guides/python/contrib.graph_editor.md
deleted file mode 100644
index 8ce49b952b..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.graph_editor.md
+++ /dev/null
@@ -1,177 +0,0 @@
-# Graph Editor (contrib)
-[TOC]
-
-TensorFlow Graph Editor.
-
-The TensorFlow Graph Editor library allows for modification of an existing
-`tf.Graph` instance in-place.
-
-The author's github username is [purpledog](https://github.com/purpledog).
-
-## Library overview
-
-Appending new nodes is the only graph editing operation allowed by the
-TensorFlow core library. The Graph Editor library is an attempt to allow for
-other kinds of editing operations, namely, *rerouting* and *transforming*.
-
-* *rerouting* is a local operation that re-plugs existing tensors (the edges
-  of the graph). Operations (the nodes) are not modified by this operation.
-  For example, rerouting can be used to insert an operation that adds noise in
-  place of an existing tensor.
-* *transforming* is a global operation that transforms one graph into another.
-  By default, a transformation is a simple copy, but it can be customized to
-  achieve other goals. For instance, a graph can be transformed into another
-  one in which noise is added after all the operations of a specific type.
-
-**Important: modifying a graph in-place with the Graph Editor must be done
-`offline`, that is, without any active sessions.**
-
-Of course, new operations can be appended online, but Graph Editor-specific
-operations like rerouting and transforming can currently only be done offline.
-
-Here is an example of what you **cannot** do:
-
-* Build a graph.
-* Create a session and run the graph.
-* Modify the graph with the Graph Editor.
-* Re-run the graph with the `same` previously created session.
-
-To edit an already running graph, follow these steps:
-
-* Build a graph.
-* Create a session and run the graph.
-* Save the graph state and terminate the session.
-* Modify the graph with the Graph Editor.
-* Create a new session and restore the graph state.
-* Re-run the graph with the newly created session.
-
-Note that this procedure is very costly because a new session must be created
-after any modifications. Among other things, it takes time because the entire
-graph state must be saved and restored again.
-
-## Sub-graph
-
-Most of the functions in the Graph Editor library operate on *sub-graphs*.
-More precisely, they take as input arguments instances of the SubGraphView
-class (or anything which can be converted to it). Doing so allows the same
-function to transparently operate on single operations as well as sub-graphs
-of any size.
-
-A subgraph can be created in several ways:
-
-* using a list of ops:
-
- ```python
- my_sgv = ge.sgv(ops)
- ```
-
-* from a name scope:
-
- ```python
- my_sgv = ge.sgv_scope("foo/bar", graph=tf.get_default_graph())
- ```
-
-* using a regular expression:
-
- ```python
- my_sgv = ge.sgv("foo/.*/.*read$", graph=tf.get_default_graph())
- ```
-
-Note that the Graph Editor is meant to manipulate several graphs at the same
-time, typically during transform or copy operations. For that reason,
-to avoid any confusion, the default graph is never used and the graph on
-which to operate must always be given explicitly. This is the reason why
-*`graph=tf.get_default_graph()`* is used in the code snippets above.
-
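-For example, a rerouting-style edit with `ge.graph_replace`, which copies the
-computation of a target tensor while substituting one of its inputs (the
-tensor names here are illustrative only; `graph_replace` infers the graph from
-the tensors it is given):
-
-```python
-import tensorflow as tf
-from tensorflow.contrib import graph_editor as ge
-
-a = tf.constant(1.0, name="a")
-b = tf.constant(2.0, name="b")
-c = tf.add(a, b, name="c")
-
-a_new = tf.constant(100.0, name="a_new")
-# `c_new` computes a_new + b; the original `c` is left untouched.
-c_new = ge.graph_replace(c, {a: a_new})
-```
-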
-## Modules overview
-
-* util: utility functions.
-* select: various selection methods of TensorFlow tensors and operations.
-* match: TensorFlow graph matching. Think of this as regular expressions for
- graphs (but not quite yet).
-* reroute: various ways of rerouting tensors to different consuming ops like
- *swap* or *reroute_a2b*.
-* subgraph: the SubGraphView class, which enables subgraph manipulations in a
- TensorFlow `tf.Graph`.
-* edit: various editing functions operating on subgraphs like *detach*,
- *connect* or *bypass*.
-* transform: the Transformer class, which enables transforming
- (or simply copying) a subgraph into another one.
-
-## Module: util
-
-* `tf.contrib.graph_editor.make_list_of_op`
-* `tf.contrib.graph_editor.get_tensors`
-* `tf.contrib.graph_editor.make_list_of_t`
-* `tf.contrib.graph_editor.get_generating_ops`
-* `tf.contrib.graph_editor.get_consuming_ops`
-* `tf.contrib.graph_editor.ControlOutputs`
-* `tf.contrib.graph_editor.placeholder_name`
-* `tf.contrib.graph_editor.make_placeholder_from_tensor`
-* `tf.contrib.graph_editor.make_placeholder_from_dtype_and_shape`
-
-## Module: select
-
-* `tf.contrib.graph_editor.filter_ts`
-* `tf.contrib.graph_editor.filter_ts_from_regex`
-* `tf.contrib.graph_editor.filter_ops`
-* `tf.contrib.graph_editor.filter_ops_from_regex`
-* `tf.contrib.graph_editor.get_name_scope_ops`
-* `tf.contrib.graph_editor.check_cios`
-* `tf.contrib.graph_editor.get_ops_ios`
-* `tf.contrib.graph_editor.compute_boundary_ts`
-* `tf.contrib.graph_editor.get_within_boundary_ops`
-* `tf.contrib.graph_editor.get_forward_walk_ops`
-* `tf.contrib.graph_editor.get_backward_walk_ops`
-* `tf.contrib.graph_editor.get_walks_intersection_ops`
-* `tf.contrib.graph_editor.get_walks_union_ops`
-* `tf.contrib.graph_editor.select_ops`
-* `tf.contrib.graph_editor.select_ts`
-* `tf.contrib.graph_editor.select_ops_and_ts`
-
-## Module: subgraph
-
-* `tf.contrib.graph_editor.SubGraphView`
-* `tf.contrib.graph_editor.make_view`
-* `tf.contrib.graph_editor.make_view_from_scope`
-
-## Module: reroute
-
-* `tf.contrib.graph_editor.swap_ts`
-* `tf.contrib.graph_editor.reroute_ts`
-* `tf.contrib.graph_editor.swap_inputs`
-* `tf.contrib.graph_editor.reroute_inputs`
-* `tf.contrib.graph_editor.swap_outputs`
-* `tf.contrib.graph_editor.reroute_outputs`
-* `tf.contrib.graph_editor.swap_ios`
-* `tf.contrib.graph_editor.reroute_ios`
-* `tf.contrib.graph_editor.remove_control_inputs`
-* `tf.contrib.graph_editor.add_control_inputs`
-
-## Module: edit
-
-* `tf.contrib.graph_editor.detach_control_inputs`
-* `tf.contrib.graph_editor.detach_control_outputs`
-* `tf.contrib.graph_editor.detach_inputs`
-* `tf.contrib.graph_editor.detach_outputs`
-* `tf.contrib.graph_editor.detach`
-* `tf.contrib.graph_editor.connect`
-* `tf.contrib.graph_editor.bypass`
-
-## Module: transform
-
-* `tf.contrib.graph_editor.replace_t_with_placeholder_handler`
-* `tf.contrib.graph_editor.keep_t_if_possible_handler`
-* `tf.contrib.graph_editor.assign_renamed_collections_handler`
-* `tf.contrib.graph_editor.transform_op_if_inside_handler`
-* `tf.contrib.graph_editor.copy_op_handler`
-* `tf.contrib.graph_editor.Transformer`
-* `tf.contrib.graph_editor.copy`
-* `tf.contrib.graph_editor.copy_with_input_replacements`
-* `tf.contrib.graph_editor.graph_replace`
-
-## Useful aliases
-
-* `tf.contrib.graph_editor.ph`
-* `tf.contrib.graph_editor.sgv`
-* `tf.contrib.graph_editor.sgv_scope`
diff --git a/tensorflow/docs_src/api_guides/python/contrib.integrate.md b/tensorflow/docs_src/api_guides/python/contrib.integrate.md
deleted file mode 100644
index a70d202ab5..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.integrate.md
+++ /dev/null
@@ -1,41 +0,0 @@
-# Integrate (contrib)
-[TOC]
-
-Integration and ODE solvers for TensorFlow.
-
-## Example: Lorenz attractor
-
-We can use `odeint` to solve the
-[Lorenz system](https://en.wikipedia.org/wiki/Lorenz_system) of ordinary
-differential equations, a prototypical example of chaotic dynamics:
-
-```python
-import matplotlib.pyplot as plt
-import numpy as np
-import tensorflow as tf
-
-rho = 28.0
-sigma = 10.0
-beta = 8.0/3.0
-
-def lorenz_equation(state, t):
- x, y, z = tf.unstack(state)
- dx = sigma * (y - x)
- dy = x * (rho - z) - y
- dz = x * y - beta * z
- return tf.stack([dx, dy, dz])
-
-init_state = tf.constant([0, 2, 20], dtype=tf.float64)
-t = np.linspace(0, 50, num=5000)
-tensor_state, tensor_info = tf.contrib.integrate.odeint(
- lorenz_equation, init_state, t, full_output=True)
-
-sess = tf.Session()
-state, info = sess.run([tensor_state, tensor_info])
-x, y, z = state.T
-plt.plot(x, z)
-```
-
-<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="https://www.tensorflow.org/images/lorenz_attractor.png" alt>
-</div>
-
-## Ops
-
-* `tf.contrib.integrate.odeint`
diff --git a/tensorflow/docs_src/api_guides/python/contrib.layers.md b/tensorflow/docs_src/api_guides/python/contrib.layers.md
deleted file mode 100644
index 4c176a129c..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.layers.md
+++ /dev/null
@@ -1,109 +0,0 @@
-# Layers (contrib)
-[TOC]
-
-Ops for building neural network layers, regularizers, summaries, etc.
-
-## Higher level ops for building neural network layers
-
-This package provides several ops that take care of creating variables that are
-used internally in a consistent way and provide the building blocks for many
-common machine learning algorithms.
-
-* `tf.contrib.layers.avg_pool2d`
-* `tf.contrib.layers.batch_norm`
-* `tf.contrib.layers.convolution2d`
-* `tf.contrib.layers.conv2d_in_plane`
-* `tf.contrib.layers.convolution2d_in_plane`
-* `tf.nn.conv2d_transpose`
-* `tf.contrib.layers.convolution2d_transpose`
-* `tf.nn.dropout`
-* `tf.contrib.layers.flatten`
-* `tf.contrib.layers.fully_connected`
-* `tf.contrib.layers.layer_norm`
-* `tf.contrib.layers.max_pool2d`
-* `tf.contrib.layers.one_hot_encoding`
-* `tf.nn.relu`
-* `tf.nn.relu6`
-* `tf.contrib.layers.repeat`
-* `tf.contrib.layers.safe_embedding_lookup_sparse`
-* `tf.nn.separable_conv2d`
-* `tf.contrib.layers.separable_convolution2d`
-* `tf.nn.softmax`
-* `tf.stack`
-* `tf.contrib.layers.unit_norm`
-* `tf.contrib.layers.embed_sequence`
-
-Aliases for fully_connected which set a default activation function are
-available: `relu`, `relu6` and `linear`.
-
-The `stack` operation is also available. It builds a stack of layers by
-applying a layer repeatedly.
-
-## Regularizers
-
-Regularization can help prevent overfitting. These have the signature
-`fn(weights)`. The loss is typically added to
-`tf.GraphKeys.REGULARIZATION_LOSSES`.
-
-* `tf.contrib.layers.apply_regularization`
-* `tf.contrib.layers.l1_regularizer`
-* `tf.contrib.layers.l2_regularizer`
-* `tf.contrib.layers.sum_regularizer`
-
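-A minimal sketch of how these pieces fit together (the variable is
-illustrative):
-
-```python
-weights = tf.get_variable("weights", shape=[10, 10])
-regularizer = tf.contrib.layers.l2_regularizer(scale=0.1)
-# Scalar penalty; also added to tf.GraphKeys.REGULARIZATION_LOSSES.
-penalty = tf.contrib.layers.apply_regularization(regularizer, [weights])
-```
-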
-## Initializers
-
-Initializers are used to initialize variables with sensible values given their
-size, data type, and purpose.
-
-* `tf.contrib.layers.xavier_initializer`
-* `tf.contrib.layers.xavier_initializer_conv2d`
-* `tf.contrib.layers.variance_scaling_initializer`
-
-## Optimization
-
-Optimize weights given a loss.
-
-* `tf.contrib.layers.optimize_loss`
-
-## Summaries
-
-Helper functions to summarize specific variables or ops.
-
-* `tf.contrib.layers.summarize_activation`
-* `tf.contrib.layers.summarize_tensor`
-* `tf.contrib.layers.summarize_tensors`
-* `tf.contrib.layers.summarize_collection`
-
-The layers module defines convenience functions `summarize_variables`,
-`summarize_weights` and `summarize_biases`, which set the `collection` argument
-of `summarize_collection` to `VARIABLES`, `WEIGHTS` and `BIASES`, respectively.
-
-* `tf.contrib.layers.summarize_activations`
-
-## Feature columns
-
-Feature columns provide a mechanism to map data to a model.
-
-* `tf.contrib.layers.bucketized_column`
-* `tf.contrib.layers.check_feature_columns`
-* `tf.contrib.layers.create_feature_spec_for_parsing`
-* `tf.contrib.layers.crossed_column`
-* `tf.contrib.layers.embedding_column`
-* `tf.contrib.layers.scattered_embedding_column`
-* `tf.contrib.layers.input_from_feature_columns`
-* `tf.contrib.layers.joint_weighted_sum_from_feature_columns`
-* `tf.contrib.layers.make_place_holder_tensors_for_base_features`
-* `tf.contrib.layers.multi_class_target`
-* `tf.contrib.layers.one_hot_column`
-* `tf.contrib.layers.parse_feature_columns_from_examples`
-* `tf.contrib.layers.parse_feature_columns_from_sequence_examples`
-* `tf.contrib.layers.real_valued_column`
-* `tf.contrib.layers.shared_embedding_columns`
-* `tf.contrib.layers.sparse_column_with_hash_bucket`
-* `tf.contrib.layers.sparse_column_with_integerized_feature`
-* `tf.contrib.layers.sparse_column_with_keys`
-* `tf.contrib.layers.sparse_column_with_vocabulary_file`
-* `tf.contrib.layers.weighted_sparse_column`
-* `tf.contrib.layers.weighted_sum_from_feature_columns`
-* `tf.contrib.layers.infer_real_valued_columns`
-* `tf.contrib.layers.sequence_input_from_feature_columns`
diff --git a/tensorflow/docs_src/api_guides/python/contrib.learn.md b/tensorflow/docs_src/api_guides/python/contrib.learn.md
deleted file mode 100644
index 635849ead5..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.learn.md
+++ /dev/null
@@ -1,63 +0,0 @@
-# Learn (contrib)
-[TOC]
-
-High level API for learning with TensorFlow.
-
-## Estimators
-
-Train and evaluate TensorFlow models.
-
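-For example, a canned estimator can be trained and evaluated in a few lines
-(a sketch, assuming `feature_columns`, `train_input_fn`, and `eval_input_fn`
-are defined elsewhere):
-
-```python
-classifier = tf.contrib.learn.DNNClassifier(
-    feature_columns=feature_columns,
-    hidden_units=[64, 32],
-    n_classes=3)
-classifier.fit(input_fn=train_input_fn, steps=1000)
-metrics = classifier.evaluate(input_fn=eval_input_fn, steps=100)
-```
-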
-* `tf.contrib.learn.BaseEstimator`
-* `tf.contrib.learn.Estimator`
-* `tf.contrib.learn.Trainable`
-* `tf.contrib.learn.Evaluable`
-* `tf.contrib.learn.KMeansClustering`
-* `tf.contrib.learn.ModeKeys`
-* `tf.contrib.learn.ModelFnOps`
-* `tf.contrib.learn.MetricSpec`
-* `tf.contrib.learn.PredictionKey`
-* `tf.contrib.learn.DNNClassifier`
-* `tf.contrib.learn.DNNRegressor`
-* `tf.contrib.learn.DNNLinearCombinedRegressor`
-* `tf.contrib.learn.DNNLinearCombinedClassifier`
-* `tf.contrib.learn.LinearClassifier`
-* `tf.contrib.learn.LinearRegressor`
-* `tf.contrib.learn.LogisticRegressor`
-
-## Distributed training utilities
-
-* `tf.contrib.learn.Experiment`
-* `tf.contrib.learn.ExportStrategy`
-* `tf.contrib.learn.TaskType`
-
-## Graph actions
-
-Perform various training, evaluation, and inference actions on a graph.
-
-* `tf.train.NanLossDuringTrainingError`
-* `tf.contrib.learn.RunConfig`
-* `tf.contrib.learn.evaluate`
-* `tf.contrib.learn.infer`
-* `tf.contrib.learn.run_feeds`
-* `tf.contrib.learn.run_n`
-* `tf.contrib.learn.train`
-
-## Input processing
-
-Queue and read batched input data.
-
-* `tf.contrib.learn.extract_dask_data`
-* `tf.contrib.learn.extract_dask_labels`
-* `tf.contrib.learn.extract_pandas_data`
-* `tf.contrib.learn.extract_pandas_labels`
-* `tf.contrib.learn.extract_pandas_matrix`
-* `tf.contrib.learn.infer_real_valued_columns_from_input`
-* `tf.contrib.learn.infer_real_valued_columns_from_input_fn`
-* `tf.contrib.learn.read_batch_examples`
-* `tf.contrib.learn.read_batch_features`
-* `tf.contrib.learn.read_batch_record_features`
-
-## Export utilities
-
-* `tf.contrib.learn.build_parsing_serving_input_fn`
-* `tf.contrib.learn.ProblemType`
diff --git a/tensorflow/docs_src/api_guides/python/contrib.linalg.md b/tensorflow/docs_src/api_guides/python/contrib.linalg.md
deleted file mode 100644
index 3055449dc2..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.linalg.md
+++ /dev/null
@@ -1,30 +0,0 @@
-# Linear Algebra (contrib)
-[TOC]
-
-Linear algebra libraries for TensorFlow.
-
-## `LinearOperator`
-
-Subclasses of `LinearOperator` provide access to common methods on a
-(batch) matrix, without the need to materialize the matrix. This allows:
-
-* Matrix-free computations
-* Different operators to take advantage of special structure, while providing a
-  consistent API to users.
-
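-For example, a minimal sketch with a diagonal operator, which acts like a
-`[3, 3]` matrix without ever materializing it:
-
-```python
-operator = tf.contrib.linalg.LinearOperatorDiag(tf.constant([1.0, 2.0, 3.0]))
-x = tf.ones([3, 2])
-y = operator.matmul(x)  # Same result as a dense matmul, but O(n) work.
-log_det = operator.log_abs_determinant()
-```
-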
-### Base class
-
-* `tf.contrib.linalg.LinearOperator`
-
-### Individual operators
-
-* `tf.contrib.linalg.LinearOperatorDiag`
-* `tf.contrib.linalg.LinearOperatorIdentity`
-* `tf.contrib.linalg.LinearOperatorScaledIdentity`
-* `tf.contrib.linalg.LinearOperatorFullMatrix`
-* `tf.contrib.linalg.LinearOperatorLowerTriangular`
-* `tf.contrib.linalg.LinearOperatorLowRankUpdate`
-
-### Transformations and Combinations of operators
-
-* `tf.contrib.linalg.LinearOperatorComposition`
diff --git a/tensorflow/docs_src/api_guides/python/contrib.losses.md b/tensorflow/docs_src/api_guides/python/contrib.losses.md
deleted file mode 100644
index 8787454af6..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.losses.md
+++ /dev/null
@@ -1,125 +0,0 @@
-# Losses (contrib)
-
-## Deprecated
-
-This module is deprecated. Instructions for updating: Use `tf.losses` instead.
-
-## Loss operations for use in neural networks.
-
-Note: By default, all the losses are collected into the `GraphKeys.LOSSES`
-collection.
-
-All of the loss functions take a pair of predictions and ground truth labels,
-from which the loss is computed. It is assumed that the shape of both these
-tensors is of the form [batch_size, d1, ... dN] where `batch_size` is the number
-of samples in the batch and `d1` ... `dN` are the remaining dimensions.
-
-It is common, when training with multiple loss functions, to adjust the relative
-strengths of individual losses. This is performed by rescaling the losses via
-a `weight` parameter passed to the loss functions. For example, if we were
-training with both log_loss and mean_squared_error, and we wished that the
-log_loss penalty be twice as severe as the mean_squared_error, we would
-implement this as:
-
-```python
- # Explicitly set the weight.
-  tf.contrib.losses.log_loss(predictions, labels, weight=2.0)
-
- # Uses default weight of 1.0
- tf.contrib.losses.mean_squared_error(predictions, labels)
-
- # All the losses are collected into the `GraphKeys.LOSSES` collection.
- losses = tf.get_collection(tf.GraphKeys.LOSSES)
-```
-
-While specifying a scalar loss rescales the loss over the entire batch,
-we sometimes want to rescale the loss per batch sample. For example, if we have
-certain examples that we care more about getting right, we might want a higher
-loss for those samples than for others whose mistakes matter less. In this
-case, we can provide a weight vector of length `batch_size` which results in
-the loss for each sample in the batch being scaled by the corresponding weight
-element. For example, consider the case of a classification problem where we
-want to maximize our accuracy but are especially interested in obtaining high
-accuracy for a specific class:
-
-```python
- inputs, labels = LoadData(batch_size=3)
- logits = MyModelPredictions(inputs)
-
- # Ensures that the loss for examples whose ground truth class is `3` is 5x
- # higher than the loss for all other examples.
-  weight = tf.multiply(4.0, tf.cast(tf.equal(labels, 3), tf.float32)) + 1
-
- onehot_labels = tf.one_hot(labels, num_classes=5)
- tf.contrib.losses.softmax_cross_entropy(logits, onehot_labels, weight=weight)
-```
-
-Finally, in certain cases, we may want to specify a different loss for every
-single measurable value. For example, if we are performing per-pixel depth
-prediction, or per-pixel denoising, a single batch sample has P values where P
-is the number of pixels in the image. For many losses, the number of measurable
-values matches the number of elements in the predictions and labels tensors.
-For others, such as softmax_cross_entropy and cosine_distance, the loss
-function reduces the dimensions of the inputs to produce a tensor of
-losses for each measurable value. For example, softmax_cross_entropy takes as
-input predictions and labels of dimension [batch_size, num_classes] but the
-number of measurable values is [batch_size]. Consequently, when passing a weight
-tensor to specify a different loss for every measurable value, the dimension of
-the tensor will depend on the loss being used.
-
-For a concrete example, consider the case of per-pixel depth prediction where
-certain ground truth depth values are missing (due to sensor noise in the
-capture process). In this case, we want to assign zero weight to losses for
-these predictions.
-
-```python
- # 'depths' that are missing have a value of 0:
- images, depths = LoadData(...)
- predictions = MyModelPredictions(images)
-
- weight = tf.cast(tf.greater(depths, 0), tf.float32)
- loss = tf.contrib.losses.mean_squared_error(predictions, depths, weight)
-```
-
-Note that when using weights for the losses, the final average is computed
-by rescaling the losses by the weights and then dividing by the total number of
-non-zero samples. For an arbitrary set of weights, this may not necessarily
-produce a weighted average. Instead, it simply and transparently rescales the
-per-element losses before averaging over the number of observations. For
-example, if the losses computed by the loss function are the array [4, 1, 2, 3]
-and the weights are the array [1, 0.5, 3, 9], then the average loss is:
-
-```python
- (4*1 + 1*0.5 + 2*3 + 3*9) / 4
-```
-
-However, with a single loss function and an arbitrary set of weights, one can
-still easily create a loss function such that the resulting loss is a
-weighted average over the individual prediction errors:
-
-
-```python
- images, labels = LoadData(...)
- predictions = MyModelPredictions(images)
-
- weight = MyComplicatedWeightingFunction(labels)
-  weight = tf.div(weight, tf.cast(tf.size(weight), tf.float32))
-  loss = tf.contrib.losses.mean_squared_error(predictions, labels, weight)
-```
-
-* `tf.contrib.losses.absolute_difference`
-* `tf.contrib.losses.add_loss`
-* `tf.contrib.losses.hinge_loss`
-* `tf.contrib.losses.compute_weighted_loss`
-* `tf.contrib.losses.cosine_distance`
-* `tf.contrib.losses.get_losses`
-* `tf.contrib.losses.get_regularization_losses`
-* `tf.contrib.losses.get_total_loss`
-* `tf.contrib.losses.log_loss`
-* `tf.contrib.losses.mean_pairwise_squared_error`
-* `tf.contrib.losses.mean_squared_error`
-* `tf.contrib.losses.sigmoid_cross_entropy`
-* `tf.contrib.losses.softmax_cross_entropy`
-* `tf.contrib.losses.sparse_softmax_cross_entropy`
-
-
diff --git a/tensorflow/docs_src/api_guides/python/contrib.metrics.md b/tensorflow/docs_src/api_guides/python/contrib.metrics.md
deleted file mode 100644
index de6346ca80..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.metrics.md
+++ /dev/null
@@ -1,133 +0,0 @@
-# Metrics (contrib)
-[TOC]
-
-## Ops for evaluation metrics and summary statistics.
-
-### API
-
-This module provides functions for computing streaming metrics: metrics computed
-on dynamically valued `Tensors`. Each metric declaration returns a
-"value_tensor", an idempotent operation that returns the current value of the
-metric, and an "update_op", an operation that accumulates the information
-from the current value of the `Tensors` being measured as well as returns the
-value of the "value_tensor".
-
-To use any of these metrics, one need only declare the metric, call `update_op`
-repeatedly to accumulate data over the desired number of `Tensor` values (often
-each one is a single batch) and finally evaluate the value_tensor. For example,
-to use the `streaming_mean`:
-
-```python
-values = ...
-mean_value, update_op = tf.contrib.metrics.streaming_mean(values)
-sess.run(tf.local_variables_initializer())
-
-for i in range(number_of_batches):
-  print('Mean after batch %d: %f' % (i, update_op.eval()))
-print('Final Mean: %f' % mean_value.eval())
-```
-
-Each metric function adds nodes to the graph that hold the state necessary to
-compute the value of the metric as well as a set of operations that actually
-perform the computation. Every metric evaluation is composed of three steps:
-
-* Initialization: initializing the metric state.
-* Aggregation: updating the values of the metric state.
-* Finalization: computing the final metric value.
-
-In the above example, calling streaming_mean creates a pair of state variables
-that will contain (1) the running sum and (2) the count of the number of samples
-in the sum. Because the streaming metrics use local variables,
-the Initialization stage is performed by running the op returned
-by `tf.local_variables_initializer()`. It sets the sum and count variables to
-zero.
-
-Next, Aggregation is performed by examining the current state of `values`
-and incrementing the state variables appropriately. This step is executed by
-running the `update_op` returned by the metric.
-
-Finally, Finalization is performed by evaluating the "value_tensor".
-
-In practice, we commonly want to evaluate across many batches and multiple
-metrics. To do so, we need only run the metric computation operations multiple
-times:
-
-```python
-labels = ...
-predictions = ...
-accuracy, update_op_acc = tf.contrib.metrics.streaming_accuracy(
- labels, predictions)
-error, update_op_error = tf.contrib.metrics.streaming_mean_absolute_error(
- labels, predictions)
-
-sess.run(tf.local_variables_initializer())
-for batch in range(num_batches):
- sess.run([update_op_acc, update_op_error])
-
-accuracy, error = sess.run([accuracy, error])
-```
-
-Note that when evaluating the same metric multiple times on different inputs,
-one must specify the scope of each metric to avoid accumulating the results
-together:
-
-```python
-labels = ...
-predictions0 = ...
-predictions1 = ...
-
-accuracy0 = tf.contrib.metrics.accuracy(labels, predictions0, name='preds0')
-accuracy1 = tf.contrib.metrics.accuracy(labels, predictions1, name='preds1')
-```
-
-Certain metrics, such as streaming_mean or streaming_accuracy, can be weighted
-via a `weights` argument. The `weights` tensor must be the same size as the
-labels and predictions tensors and results in a weighted average of the metric.
-
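-For example, a sketch that gives the final example zero weight:
-
-```python
-values = tf.constant([1.0, 2.0, 4.0])
-weights = tf.constant([1.0, 1.0, 0.0])  # Ignore the final example.
-mean_value, update_op = tf.contrib.metrics.streaming_mean(values, weights)
-```
-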
-## Metric `Ops`
-
-* `tf.contrib.metrics.streaming_accuracy`
-* `tf.contrib.metrics.streaming_mean`
-* `tf.contrib.metrics.streaming_recall`
-* `tf.contrib.metrics.streaming_recall_at_thresholds`
-* `tf.contrib.metrics.streaming_precision`
-* `tf.contrib.metrics.streaming_precision_at_thresholds`
-* `tf.contrib.metrics.streaming_auc`
-* `tf.contrib.metrics.streaming_recall_at_k`
-* `tf.contrib.metrics.streaming_mean_absolute_error`
-* `tf.contrib.metrics.streaming_mean_iou`
-* `tf.contrib.metrics.streaming_mean_relative_error`
-* `tf.contrib.metrics.streaming_mean_squared_error`
-* `tf.contrib.metrics.streaming_mean_tensor`
-* `tf.contrib.metrics.streaming_root_mean_squared_error`
-* `tf.contrib.metrics.streaming_covariance`
-* `tf.contrib.metrics.streaming_pearson_correlation`
-* `tf.contrib.metrics.streaming_mean_cosine_distance`
-* `tf.contrib.metrics.streaming_percentage_less`
-* `tf.contrib.metrics.streaming_sensitivity_at_specificity`
-* `tf.contrib.metrics.streaming_sparse_average_precision_at_k`
-* `tf.contrib.metrics.streaming_sparse_precision_at_k`
-* `tf.contrib.metrics.streaming_sparse_precision_at_top_k`
-* `tf.contrib.metrics.streaming_sparse_recall_at_k`
-* `tf.contrib.metrics.streaming_specificity_at_sensitivity`
-* `tf.contrib.metrics.streaming_concat`
-* `tf.contrib.metrics.streaming_false_negatives`
-* `tf.contrib.metrics.streaming_false_negatives_at_thresholds`
-* `tf.contrib.metrics.streaming_false_positives`
-* `tf.contrib.metrics.streaming_false_positives_at_thresholds`
-* `tf.contrib.metrics.streaming_true_negatives`
-* `tf.contrib.metrics.streaming_true_negatives_at_thresholds`
-* `tf.contrib.metrics.streaming_true_positives`
-* `tf.contrib.metrics.streaming_true_positives_at_thresholds`
-* `tf.contrib.metrics.auc_using_histogram`
-* `tf.contrib.metrics.accuracy`
-* `tf.contrib.metrics.aggregate_metrics`
-* `tf.contrib.metrics.aggregate_metric_map`
-* `tf.contrib.metrics.confusion_matrix`
-
-## Set `Ops`
-
-* `tf.contrib.metrics.set_difference`
-* `tf.contrib.metrics.set_intersection`
-* `tf.contrib.metrics.set_size`
-* `tf.contrib.metrics.set_union`
diff --git a/tensorflow/docs_src/api_guides/python/contrib.rnn.md b/tensorflow/docs_src/api_guides/python/contrib.rnn.md
deleted file mode 100644
index d265ab6925..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.rnn.md
+++ /dev/null
@@ -1,61 +0,0 @@
-# RNN and Cells (contrib)
-[TOC]
-
-Module for constructing RNN Cells and additional RNN operations.
-
-## Base interface for all RNN Cells
-
-* `tf.contrib.rnn.RNNCell`
-
-## Core RNN Cells for use with TensorFlow's core RNN methods
-
-* `tf.contrib.rnn.BasicRNNCell`
-* `tf.contrib.rnn.BasicLSTMCell`
-* `tf.contrib.rnn.GRUCell`
-* `tf.contrib.rnn.LSTMCell`
-* `tf.contrib.rnn.LayerNormBasicLSTMCell`
-
-## Classes storing split `RNNCell` state
-
-* `tf.contrib.rnn.LSTMStateTuple`
-
-## Core RNN Cell wrappers (RNNCells that wrap other RNNCells)
-
-* `tf.contrib.rnn.MultiRNNCell`
-* `tf.contrib.rnn.LSTMBlockWrapper`
-* `tf.contrib.rnn.DropoutWrapper`
-* `tf.contrib.rnn.EmbeddingWrapper`
-* `tf.contrib.rnn.InputProjectionWrapper`
-* `tf.contrib.rnn.OutputProjectionWrapper`
-* `tf.contrib.rnn.DeviceWrapper`
-* `tf.contrib.rnn.ResidualWrapper`
-
-### Block RNNCells
-* `tf.contrib.rnn.LSTMBlockCell`
-* `tf.contrib.rnn.GRUBlockCell`
-
-### Fused RNNCells
-* `tf.contrib.rnn.FusedRNNCell`
-* `tf.contrib.rnn.FusedRNNCellAdaptor`
-* `tf.contrib.rnn.TimeReversedFusedRNN`
-* `tf.contrib.rnn.LSTMBlockFusedCell`
-
-### LSTM-like cells
-* `tf.contrib.rnn.CoupledInputForgetGateLSTMCell`
-* `tf.contrib.rnn.TimeFreqLSTMCell`
-* `tf.contrib.rnn.GridLSTMCell`
-
-### RNNCell wrappers
-* `tf.contrib.rnn.AttentionCellWrapper`
-* `tf.contrib.rnn.CompiledWrapper`
-
-
-## Recurrent Neural Networks
-
-TensorFlow provides a number of methods for constructing Recurrent Neural
-Networks.
-
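-For example, a minimal sketch that unrolls a two-layer LSTM with `static_rnn`,
-assuming `inputs` is a Python list of `[batch_size, input_size]` tensors:
-
-```python
-cell = tf.contrib.rnn.MultiRNNCell(
-    [tf.contrib.rnn.BasicLSTMCell(128) for _ in range(2)])
-outputs, final_state = tf.contrib.rnn.static_rnn(cell, inputs, dtype=tf.float32)
-```
-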
-* `tf.contrib.rnn.static_rnn`
-* `tf.contrib.rnn.static_state_saving_rnn`
-* `tf.contrib.rnn.static_bidirectional_rnn`
-* `tf.contrib.rnn.stack_bidirectional_dynamic_rnn`
diff --git a/tensorflow/docs_src/api_guides/python/contrib.seq2seq.md b/tensorflow/docs_src/api_guides/python/contrib.seq2seq.md
deleted file mode 100644
index 54f2fafc71..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.seq2seq.md
+++ /dev/null
@@ -1,138 +0,0 @@
-# Seq2seq Library (contrib)
-[TOC]
-
-Module for constructing seq2seq models and dynamic decoding. Builds on top of
-libraries in `tf.contrib.rnn`.
-
-This library is composed of two primary components:
-
-* New attention wrappers for `tf.contrib.rnn.RNNCell` objects.
-* A new object-oriented dynamic decoding framework.
-
-## Attention
-
-Attention wrappers are `RNNCell` objects that wrap other `RNNCell` objects and
-implement attention. The form of attention is determined by a subclass of
-`tf.contrib.seq2seq.AttentionMechanism`. These subclasses describe the form
-of attention (e.g. additive vs. multiplicative) to use when creating the
-wrapper. An instance of an `AttentionMechanism` is constructed with a
-`memory` tensor, from which lookup keys and values tensors are created.
-
-### Attention Mechanisms
-
-The two basic attention mechanisms are:
-
-* `tf.contrib.seq2seq.BahdanauAttention` (additive attention,
- [ref.](https://arxiv.org/abs/1409.0473))
-* `tf.contrib.seq2seq.LuongAttention` (multiplicative attention,
- [ref.](https://arxiv.org/abs/1508.04025))
-
-The `memory` tensor passed to the attention mechanism's constructor is expected
-to be shaped `[batch_size, memory_max_time, memory_depth]`; often an additional
-`memory_sequence_length` vector is accepted as well. If provided, the `memory`
-tensors' rows are masked with zeros past their true sequence lengths.
-
-Attention mechanisms also have a concept of depth, usually determined as a
-construction parameter `num_units`. For some kinds of attention (like
-`BahdanauAttention`), both queries and memory are projected to tensors of depth
-`num_units`. For other kinds (like `LuongAttention`), `num_units` should match
-the depth of the queries; and the `memory` tensor will be projected to this
-depth.
-
-### Attention Wrappers
-
-The basic attention wrapper is `tf.contrib.seq2seq.AttentionWrapper`.
-This wrapper accepts an `RNNCell` instance, an instance of `AttentionMechanism`,
-and an attention depth parameter (`attention_size`); as well as several
-optional arguments that allow one to customize intermediate calculations.
-
-At each time step, the basic calculation performed by this wrapper is:
-
-```python
-cell_inputs = concat([inputs, prev_state.attention], -1)
-cell_output, next_cell_state = cell(cell_inputs, prev_state.cell_state)
-score = attention_mechanism(cell_output)
-alignments = softmax(score)
-context = matmul(alignments, attention_mechanism.values)
-attention = tf.layers.Dense(attention_size)(concat([cell_output, context], 1))
-next_state = AttentionWrapperState(
- cell_state=next_cell_state,
- attention=attention)
-output = attention
-return output, next_state
-```
-
-In practice, a number of the intermediate calculations are configurable.
-For example, the initial concatenation of `inputs` and `prev_state.attention`
-can be replaced with another mixing function. The function `softmax` can
-be replaced with alternative options when calculating `alignments` from the
-`score`. Finally, the outputs returned by the wrapper can be configured to
-be the value `cell_output` instead of `attention`.
-
-The benefit of using an `AttentionWrapper` is that it plays nicely with
-other wrappers and the dynamic decoder described below. For example, one can
-write:
-
-```python
-cell = tf.contrib.rnn.DeviceWrapper(LSTMCell(512), "/device:GPU:0")
-attention_mechanism = tf.contrib.seq2seq.LuongAttention(512, encoder_outputs)
-attn_cell = tf.contrib.seq2seq.AttentionWrapper(
- cell, attention_mechanism, attention_size=256)
-attn_cell = tf.contrib.rnn.DeviceWrapper(attn_cell, "/device:GPU:1")
-top_cell = tf.contrib.rnn.DeviceWrapper(LSTMCell(512), "/device:GPU:1")
-multi_cell = MultiRNNCell([attn_cell, top_cell])
-```
-
-The resulting `multi_cell` will perform the bottom layer calculations on GPU 0;
-attention calculations will be performed on GPU 1 and immediately passed
-up to the top layer which is also calculated on GPU 1. The attention is
-also passed forward in time to the next time step and copied to GPU 0 for the
-next time step of `cell`. (*Note*: This is just an example of use,
-not a suggested device partitioning strategy.)
-
-## Dynamic Decoding
-
-Example usage:
-
-``` python
-cell = ...  # An instance of RNNCell, e.g. tf.contrib.rnn.LSTMCell(512).
-
-if mode == "train":
- helper = tf.contrib.seq2seq.TrainingHelper(
- input=input_vectors,
- sequence_length=input_lengths)
-elif mode == "infer":
- helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
- embedding=embedding,
- start_tokens=tf.tile([GO_SYMBOL], [batch_size]),
- end_token=END_SYMBOL)
-
-decoder = tf.contrib.seq2seq.BasicDecoder(
- cell=cell,
- helper=helper,
- initial_state=cell.zero_state(batch_size, tf.float32))
-outputs, _ = tf.contrib.seq2seq.dynamic_decode(
- decoder=decoder,
- output_time_major=False,
- impute_finished=True,
- maximum_iterations=20)
-```
-
-### Decoder base class and functions
-
-* `tf.contrib.seq2seq.Decoder`
-* `tf.contrib.seq2seq.dynamic_decode`
-
-### Basic Decoder
-
-* `tf.contrib.seq2seq.BasicDecoderOutput`
-* `tf.contrib.seq2seq.BasicDecoder`
-
-### Decoder Helpers
-
-* `tf.contrib.seq2seq.Helper`
-* `tf.contrib.seq2seq.CustomHelper`
-* `tf.contrib.seq2seq.GreedyEmbeddingHelper`
-* `tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper`
-* `tf.contrib.seq2seq.ScheduledOutputTrainingHelper`
-* `tf.contrib.seq2seq.TrainingHelper`
diff --git a/tensorflow/docs_src/api_guides/python/contrib.signal.md b/tensorflow/docs_src/api_guides/python/contrib.signal.md
deleted file mode 100644
index 66df561084..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.signal.md
+++ /dev/null
@@ -1,172 +0,0 @@
-# Signal Processing (contrib)
-[TOC]
-
-`tf.contrib.signal` is a module for signal processing primitives. All
-operations have GPU support and are differentiable. This module is especially
-helpful for building TensorFlow models that process or generate audio, though
-the techniques are useful in many domains.
-
-## Framing variable length sequences
-
-When dealing with variable length signals (e.g. audio) it is common to "frame"
-them into multiple fixed length windows. These windows can overlap if the 'step'
-of the frame is less than the frame length. `tf.contrib.signal.frame` does
-exactly this. For example:
-
-```python
-# A batch of float32 time-domain signals in the range [-1, 1] with shape
-# [batch_size, signal_length]. Both batch_size and signal_length may be unknown.
-signals = tf.placeholder(tf.float32, [None, None])
-
-# Compute a [batch_size, ?, 128] tensor of fixed length, overlapping windows
-# where each window overlaps the previous by 75% (frame_length - frame_step
-# samples of overlap).
-frames = tf.contrib.signal.frame(signals, frame_length=128, frame_step=32)
-```
-
-The `axis` parameter to `tf.contrib.signal.frame` allows you to frame tensors
-with inner structure (e.g. a spectrogram):
-
-```python
-# `magnitude_spectrograms` is a [batch_size, ?, 129] tensor of spectrograms. We
-# would like to produce overlapping fixed-size spectrogram patches; for example,
-# for use in a situation where a fixed size input is needed.
-magnitude_spectrograms = tf.abs(tf.contrib.signal.stft(
- signals, frame_length=256, frame_step=64, fft_length=256))
-
-# `spectrogram_patches` is a [batch_size, ?, 64, 129] tensor containing a
-# variable number of [64, 129] spectrogram patches per batch item.
-spectrogram_patches = tf.contrib.signal.frame(
- magnitude_spectrograms, frame_length=64, frame_step=16, axis=1)
-```
-
-## Reconstructing framed sequences and applying a tapering window
-
-`tf.contrib.signal.overlap_and_add` can be used to reconstruct a signal from a
-framed representation. For example, the following code reconstructs the signal
-produced in the preceding example:
-
-```python
-# Reconstructs `signals` from `frames` produced in the above example. However,
-# the magnitude of `reconstructed_signals` will be greater than `signals`.
-reconstructed_signals = tf.contrib.signal.overlap_and_add(frames, frame_step=32)
-```
-
-Note that because `frame_step` is 25% of `frame_length` in the above example,
-the resulting reconstruction will have a greater magnitude than the original
-`signals`. To compensate for this, we can use a tapering window function. If the
-window function satisfies the Constant Overlap-Add (COLA) property for the given
-frame step, then it will recover the original `signals`.
-
-`tf.contrib.signal.hamming_window` and `tf.contrib.signal.hann_window` both
-satisfy the COLA property for a 75% overlap.
-
-```python
-frame_length = 128
-frame_step = 32
-windowed_frames = frames * tf.contrib.signal.hann_window(frame_length)
-reconstructed_signals = tf.contrib.signal.overlap_and_add(
- windowed_frames, frame_step)
-```
-
-## Computing spectrograms
-
-A spectrogram is a time-frequency decomposition of a signal that indicates its
-frequency content over time. The most common approach to computing spectrograms
-is to take the magnitude of the [Short-time Fourier Transform][stft] (STFT),
-which `tf.contrib.signal.stft` can compute as follows:
-
-```python
-# A batch of float32 time-domain signals in the range [-1, 1] with shape
-# [batch_size, signal_length]. Both batch_size and signal_length may be unknown.
-signals = tf.placeholder(tf.float32, [None, None])
-
-# `stfts` is a complex64 Tensor representing the Short-time Fourier Transform of
-# each signal in `signals`. Its shape is [batch_size, ?, fft_unique_bins]
-# where fft_unique_bins = fft_length // 2 + 1 = 513.
-stfts = tf.contrib.signal.stft(signals, frame_length=1024, frame_step=512,
- fft_length=1024)
-
-# A power spectrogram is the squared magnitude of the complex-valued STFT.
-# A float32 Tensor of shape [batch_size, ?, 513].
-power_spectrograms = tf.real(stfts * tf.conj(stfts))
-
-# An energy spectrogram is the magnitude of the complex-valued STFT.
-# A float32 Tensor of shape [batch_size, ?, 513].
-magnitude_spectrograms = tf.abs(stfts)
-```
-
-You may use a power spectrogram or a magnitude spectrogram; each has its
-advantages. Note that if you apply logarithmic compression, the power
-spectrogram and magnitude spectrogram will differ by a factor of 2.
-
-## Logarithmic compression
-
-It is common practice to apply a compressive nonlinearity such as a logarithm or
-power-law compression to spectrograms. This helps to balance the importance of
-detail in low and high energy regions of the spectrum, which more closely
-matches human auditory sensitivity.
-
-When compressing with a logarithm, it's a good idea to use a stabilizing offset
-to avoid high dynamic ranges caused by the singularity at zero.
-
-```python
-log_offset = 1e-6
-log_magnitude_spectrograms = tf.log(magnitude_spectrograms + log_offset)
-```
-
-## Computing log-mel spectrograms
-
-When working with spectral representations of audio, the [mel scale][mel] is a
-common reweighting of the frequency dimension, which results in a
-lower-dimensional and more perceptually-relevant representation of the audio.
-
-`tf.contrib.signal.linear_to_mel_weight_matrix` produces a matrix you can use
-to convert a spectrogram to the mel scale.
-
-```python
-# Warp the linear-scale, magnitude spectrograms into the mel-scale.
-num_spectrogram_bins = magnitude_spectrograms.shape[-1].value
-sample_rate = 16000.0  # Assumed sample rate of `signals`, in Hz.
-lower_edge_hertz, upper_edge_hertz, num_mel_bins = 80.0, 7600.0, 64
-linear_to_mel_weight_matrix = tf.contrib.signal.linear_to_mel_weight_matrix(
- num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz,
- upper_edge_hertz)
-mel_spectrograms = tf.tensordot(
- magnitude_spectrograms, linear_to_mel_weight_matrix, 1)
-# Note: Shape inference for `tf.tensordot` does not currently handle this case.
-mel_spectrograms.set_shape(magnitude_spectrograms.shape[:-1].concatenate(
- linear_to_mel_weight_matrix.shape[-1:]))
-```
-
-If desired, compress the mel spectrogram magnitudes. For example, you may use
-logarithmic compression (as discussed in the previous section).
-
-Order matters! Compressing the spectrogram magnitudes after
-reweighting the frequencies is different from reweighting the compressed
-spectrogram magnitudes. According to the perceptual justification of the mel
-scale, conversion from linear scale entails summing intensity or energy among
-adjacent bands, i.e. it should be applied before logarithmic compression. Taking
-the weighted sum of log-compressed values amounts to multiplying the
-pre-logarithm values, which rarely, if ever, makes sense.
-
-```python
-log_offset = 1e-6
-log_mel_spectrograms = tf.log(mel_spectrograms + log_offset)
-```
-
-## Computing Mel-Frequency Cepstral Coefficients (MFCCs)
-
-Call `tf.contrib.signal.mfccs_from_log_mel_spectrograms` to compute
-[MFCCs][mfcc] from log-magnitude, mel-scale spectrograms (as computed in the
-preceding example):
-
-```python
-num_mfccs = 13
-# Keep the first `num_mfccs` MFCCs.
-mfccs = tf.contrib.signal.mfccs_from_log_mel_spectrograms(
- log_mel_spectrograms)[..., :num_mfccs]
-```
-
-[stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
-[mel]: https://en.wikipedia.org/wiki/Mel_scale
-[mfcc]: https://en.wikipedia.org/wiki/Mel-frequency_cepstrum
diff --git a/tensorflow/docs_src/api_guides/python/contrib.staging.md b/tensorflow/docs_src/api_guides/python/contrib.staging.md
deleted file mode 100644
index de143a7bd3..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.staging.md
+++ /dev/null
@@ -1,6 +0,0 @@
-# Staging (contrib)
-[TOC]
-
-This library contains utilities for adding pipelining to a model.
-
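-A minimal sketch: stage one batch while the previous one is consumed, so
-host-to-device copies can overlap with compute (`batch` is a hypothetical
-input tensor):
-
-```python
-area = tf.contrib.staging.StagingArea(dtypes=[tf.float32])
-put_op = area.put([batch])  # Enqueue the next batch.
-staged = area.get()         # Dequeue the previously staged batch.
-```
-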
-* `tf.contrib.staging.StagingArea`
diff --git a/tensorflow/docs_src/api_guides/python/contrib.training.md b/tensorflow/docs_src/api_guides/python/contrib.training.md
deleted file mode 100644
index 068efdc829..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.training.md
+++ /dev/null
@@ -1,50 +0,0 @@
-# Training (contrib)
-[TOC]
-
-Training and input utilities.
-
-## Splitting sequence inputs into minibatches with state saving
-
-Use `tf.contrib.training.SequenceQueueingStateSaver` or
-its wrapper `tf.contrib.training.batch_sequences_with_states` if
-you have input data with a dynamic primary time / frame count axis which
-you'd like to convert into fixed size segments during minibatching, and would
-like to store state in the forward direction across segments of an example.
-
-* `tf.contrib.training.batch_sequences_with_states`
-* `tf.contrib.training.NextQueuedSequenceBatch`
-* `tf.contrib.training.SequenceQueueingStateSaver`
-
-
-## Online data resampling
-
-To resample data with replacement on a per-example basis, use
-`tf.contrib.training.rejection_sample` or
-`tf.contrib.training.resample_at_rate`. For `rejection_sample`, provide
-a boolean Tensor describing whether to accept or reject. Resulting batch sizes
-are always the same. For `resample_at_rate`, provide the desired rate for each
-example. Resulting batch sizes may vary. If you wish to specify relative
-rates, rather than absolute ones, use `tf.contrib.training.weighted_resample`
-(which also returns the actual resampling rate used for each output example).
-
-Use `tf.contrib.training.stratified_sample` to resample without replacement
-from the data to achieve a desired mix of class proportions that the TensorFlow
-graph sees. For instance, if you have a binary classification dataset that is
-99.9% class 1, a common approach is to resample from the data so that the data
-is more balanced.
-
-* `tf.contrib.training.rejection_sample`
-* `tf.contrib.training.resample_at_rate`
-* `tf.contrib.training.stratified_sample`
-* `tf.contrib.training.weighted_resample`
-
-## Bucketing
-
-Use `tf.contrib.training.bucket` or
-`tf.contrib.training.bucket_by_sequence_length` to stratify
-minibatches into groups ("buckets"). Use `bucket_by_sequence_length`
-with the argument `dynamic_pad=True` to receive minibatches of similarly
-sized sequences for efficient training via `dynamic_rnn`.
-
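-For example, a sketch of bucketing by length, assuming `length` and `tensors`
-describe a single example coming from an input queue:
-
-```python
-bucketed_length, bucketed_tensors = (
-    tf.contrib.training.bucket_by_sequence_length(
-        input_length=length,
-        tensors=tensors,
-        batch_size=32,
-        bucket_boundaries=[10, 20, 40],
-        dynamic_pad=True))
-```
-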
-* `tf.contrib.training.bucket`
-* `tf.contrib.training.bucket_by_sequence_length`
diff --git a/tensorflow/docs_src/api_guides/python/contrib.util.md b/tensorflow/docs_src/api_guides/python/contrib.util.md
deleted file mode 100644
index e5fd97e9f2..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.util.md
+++ /dev/null
@@ -1,12 +0,0 @@
-# Utilities (contrib)
-[TOC]
-
-Utilities for dealing with Tensors.
-
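-For example, a minimal sketch that round-trips between a `TensorProto` and a
-numpy array:
-
-```python
-proto = tf.contrib.util.make_tensor_proto([[1.0, 2.0], [3.0, 4.0]])
-array = tf.contrib.util.make_ndarray(proto)  # Back to a numpy array.
-```
-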
-## Miscellaneous Utility Functions
-
-* `tf.contrib.util.constant_value`
-* `tf.contrib.util.make_tensor_proto`
-* `tf.contrib.util.make_ndarray`
-* `tf.contrib.util.ops_used_by_graph_def`
-* `tf.contrib.util.stripped_op_list_for_graph`
diff --git a/tensorflow/docs_src/api_guides/python/control_flow_ops.md b/tensorflow/docs_src/api_guides/python/control_flow_ops.md
deleted file mode 100644
index 42c86d9978..0000000000
--- a/tensorflow/docs_src/api_guides/python/control_flow_ops.md
+++ /dev/null
@@ -1,57 +0,0 @@
-# Control Flow
-
-Note: Functions taking `Tensor` arguments can also take anything accepted by
-`tf.convert_to_tensor`.
-
-[TOC]
-
-## Control Flow Operations
-
-TensorFlow provides several operations and classes that you can use to control
-the execution of operations and add conditional dependencies to your graph.
-
-* `tf.identity`
-* `tf.tuple`
-* `tf.group`
-* `tf.no_op`
-* `tf.count_up_to`
-* `tf.cond`
-* `tf.case`
-* `tf.while_loop`
-
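-For example, a minimal sketch of the two most commonly used ones:
-
-```python
-x = tf.constant(4.0)
-# Both branches of `tf.cond` must return tensors of matching types.
-y = tf.cond(x > 0, lambda: tf.sqrt(x), lambda: tf.zeros_like(x))
-
-# `tf.while_loop`: sum the integers 0..9.
-_, total = tf.while_loop(lambda i, acc: i < 10,
-                         lambda i, acc: (i + 1, acc + i),
-                         [tf.constant(0), tf.constant(0)])
-```
-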
-## Logical Operators
-
-TensorFlow provides several operations that you can use to add logical operators
-to your graph.
-
-* `tf.logical_and`
-* `tf.logical_not`
-* `tf.logical_or`
-* `tf.logical_xor`
-
-## Comparison Operators
-
-TensorFlow provides several operations that you can use to add comparison
-operators to your graph.
-
-* `tf.equal`
-* `tf.not_equal`
-* `tf.less`
-* `tf.less_equal`
-* `tf.greater`
-* `tf.greater_equal`
-* `tf.where`
-
-## Debugging Operations
-
-TensorFlow provides several operations that you can use to validate values and
-debug your graph.
-
-* `tf.is_finite`
-* `tf.is_inf`
-* `tf.is_nan`
-* `tf.verify_tensor_all_finite`
-* `tf.check_numerics`
-* `tf.add_check_numerics_ops`
-* `tf.Assert`
-* `tf.Print`
diff --git a/tensorflow/docs_src/api_guides/python/framework.md b/tensorflow/docs_src/api_guides/python/framework.md
deleted file mode 100644
index 40a6c0783a..0000000000
--- a/tensorflow/docs_src/api_guides/python/framework.md
+++ /dev/null
@@ -1,51 +0,0 @@
-# Building Graphs
-[TOC]
-
-Classes and functions for building TensorFlow graphs.
-
-## Core graph data structures
-
-* `tf.Graph`
-* `tf.Operation`
-* `tf.Tensor`
-
-## Tensor types
-
-* `tf.DType`
-* `tf.as_dtype`
-
-## Utility functions
-
-* `tf.device`
-* `tf.container`
-* `tf.name_scope`
-* `tf.control_dependencies`
-* `tf.convert_to_tensor`
-* `tf.convert_to_tensor_or_indexed_slices`
-* `tf.convert_to_tensor_or_sparse_tensor`
-* `tf.get_default_graph`
-* `tf.reset_default_graph`
-* `tf.import_graph_def`
-* `tf.load_file_system_library`
-* `tf.load_op_library`
-
-## Graph collections
-
-* `tf.add_to_collection`
-* `tf.get_collection`
-* `tf.get_collection_ref`
-* `tf.GraphKeys`
-
-## Defining new operations
-
-* `tf.RegisterGradient`
-* `tf.NotDifferentiable`
-* `tf.NoGradient`
-* `tf.TensorShape`
-* `tf.Dimension`
-* `tf.op_scope`
-* `tf.get_seed`
-
-## For libraries building on TensorFlow
-
-* `tf.register_tensor_conversion_function`
diff --git a/tensorflow/docs_src/api_guides/python/functional_ops.md b/tensorflow/docs_src/api_guides/python/functional_ops.md
deleted file mode 100644
index 0a9fe02ad5..0000000000
--- a/tensorflow/docs_src/api_guides/python/functional_ops.md
+++ /dev/null
@@ -1,18 +0,0 @@
-# Higher Order Functions
-
-Note: Functions taking `Tensor` arguments can also take anything accepted by
-`tf.convert_to_tensor`.
-
-[TOC]
-
-Functional operations.
-
-## Higher Order Operators
-
-TensorFlow provides several higher order operators to simplify the common
-map-reduce programming patterns.
-
-* `tf.map_fn`
-* `tf.foldl`
-* `tf.foldr`
-* `tf.scan`
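-
-For example, a minimal sketch of `tf.map_fn` and `tf.foldl`:
-
-```python
-elems = tf.constant([1, 2, 3, 4])
-squares = tf.map_fn(lambda x: x * x, elems)      # [1, 4, 9, 16]
-total = tf.foldl(lambda acc, x: acc + x, elems)  # 10
-```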
diff --git a/tensorflow/docs_src/api_guides/python/image.md b/tensorflow/docs_src/api_guides/python/image.md
deleted file mode 100644
index c51b92db05..0000000000
--- a/tensorflow/docs_src/api_guides/python/image.md
+++ /dev/null
@@ -1,144 +0,0 @@
-# Images
-
-Note: Functions taking `Tensor` arguments can also take anything accepted by
-`tf.convert_to_tensor`.
-
-[TOC]
-
-## Encoding and Decoding
-
-TensorFlow provides Ops to decode and encode JPEG and PNG formats. Encoded
-images are represented by scalar string Tensors, decoded images by 3-D uint8
-tensors of shape `[height, width, channels]`. (PNG also supports uint16.)
-
-The encode and decode Ops apply to one image at a time. Their input and output
-are all of variable size. If you need fixed size images, pass the output of
-the decode Ops to one of the cropping and resizing Ops.
-
-Note: The PNG encode and decode Ops support RGBA, but the conversion Ops
-presently only support RGB, HSV, and GrayScale. Presently, the alpha channel has
-to be stripped from the image and re-attached using slicing ops.
-
-* `tf.image.decode_bmp`
-* `tf.image.decode_gif`
-* `tf.image.decode_jpeg`
-* `tf.image.encode_jpeg`
-* `tf.image.decode_png`
-* `tf.image.encode_png`
-* `tf.image.decode_image`
-
-## Resizing
-
-The resizing Ops accept input images as tensors of several types. They always
-output resized images as float32 tensors.
-
-The convenience function `tf.image.resize_images` supports both 4-D
-and 3-D tensors as input and output. 4-D tensors are for batches of images,
-3-D tensors for individual images.
-
-Other resizing Ops only support 4-D batches of images as input:
-`tf.image.resize_area`, `tf.image.resize_bicubic`,
-`tf.image.resize_bilinear`,
-`tf.image.resize_nearest_neighbor`.
-
-Example:
-
-```python
-# Decode a JPEG image and resize it to 299 by 299 using the default method.
-image = tf.image.decode_jpeg(...)
-resized_image = tf.image.resize_images(image, [299, 299])
-```
-
-* `tf.image.resize_images`
-* `tf.image.resize_area`
-* `tf.image.resize_bicubic`
-* `tf.image.resize_bilinear`
-* `tf.image.resize_nearest_neighbor`
-
-## Cropping
-
-* `tf.image.resize_image_with_crop_or_pad`
-* `tf.image.central_crop`
-* `tf.image.pad_to_bounding_box`
-* `tf.image.crop_to_bounding_box`
-* `tf.image.extract_glimpse`
-* `tf.image.crop_and_resize`
-
-## Flipping, Rotating and Transposing
-
-* `tf.image.flip_up_down`
-* `tf.image.random_flip_up_down`
-* `tf.image.flip_left_right`
-* `tf.image.random_flip_left_right`
-* `tf.image.transpose_image`
-* `tf.image.rot90`
-
-## Converting Between Colorspaces
-
-Image ops work either on individual images or on batches of images, depending on
-the shape of their input Tensor.
-
-If 3-D, the shape is `[height, width, channels]`, and the Tensor represents one
-image. If 4-D, the shape is `[batch_size, height, width, channels]`, and the
-Tensor represents `batch_size` images.
-
-Currently, `channels` can usefully be 1, 2, 3, or 4. Single-channel images are
-grayscale; images with 3 channels are encoded as either RGB or HSV. Images
-with 2 or 4 channels include an alpha channel, which has to be stripped from the
-image before passing the image to most image processing functions (and can be
-re-attached later).
-
-Internally, images are either stored as one `float32` per channel per pixel
-(implicitly, values are assumed to lie in `[0,1)`) or one `uint8` per channel
-per pixel (values are assumed to lie in `[0,255]`).
-
-TensorFlow can convert between images in RGB or HSV. The conversion functions
-work only on float images, so you need to convert images in other formats using
-`tf.image.convert_image_dtype`.
-
-Example:
-
-```python
-# Decode an image and convert it to HSV.
-rgb_image = tf.image.decode_png(..., channels=3)
-rgb_image_float = tf.image.convert_image_dtype(rgb_image, tf.float32)
-hsv_image = tf.image.rgb_to_hsv(rgb_image_float)
-```
-
-* `tf.image.rgb_to_grayscale`
-* `tf.image.grayscale_to_rgb`
-* `tf.image.hsv_to_rgb`
-* `tf.image.rgb_to_hsv`
-* `tf.image.convert_image_dtype`
-
-## Image Adjustments
-
-TensorFlow provides functions to adjust images in various ways: brightness,
-contrast, hue, and saturation. Each adjustment can be done with predefined
-parameters or with random parameters picked from predefined intervals. Random
-adjustments are often useful to expand a training set and reduce overfitting.
-
-If several adjustments are chained it is advisable to minimize the number of
-redundant conversions by first converting the images to the most natural data
-type and representation (RGB or HSV).
-
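-For example, a sketch of random augmentation (the file name is illustrative):
-
-```python
-image = tf.image.convert_image_dtype(
-    tf.image.decode_jpeg(tf.read_file("input.jpg"), channels=3), tf.float32)
-image = tf.image.random_brightness(image, max_delta=0.2)
-image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
-```
-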
-* `tf.image.adjust_brightness`
-* `tf.image.random_brightness`
-* `tf.image.adjust_contrast`
-* `tf.image.random_contrast`
-* `tf.image.adjust_hue`
-* `tf.image.random_hue`
-* `tf.image.adjust_gamma`
-* `tf.image.adjust_saturation`
-* `tf.image.random_saturation`
-* `tf.image.per_image_standardization`
-
-## Working with Bounding Boxes
-
-* `tf.image.draw_bounding_boxes`
-* `tf.image.non_max_suppression`
-* `tf.image.sample_distorted_bounding_box`
-
-## Denoising
-
-* `tf.image.total_variation`
diff --git a/tensorflow/docs_src/api_guides/python/index.md b/tensorflow/docs_src/api_guides/python/index.md
deleted file mode 100644
index a791a1432a..0000000000
--- a/tensorflow/docs_src/api_guides/python/index.md
+++ /dev/null
@@ -1,52 +0,0 @@
-# Python API Guides
-
-* [Asserts and boolean checks](check_ops.md)
-* [Building Graphs](framework.md)
-* [Constants, Sequences, and Random Values](constant_op.md)
-* [Control Flow](control_flow_ops.md)
-* [Data IO (Python functions)](python_io.md)
-* [Exporting and Importing a MetaGraph](meta_graph.md)
-* [Higher Order Functions](functional_ops.md)
-* [Histograms](histogram_ops.md)
-* [Images](image.md)
-* [Inputs and Readers](io_ops.md)
-* [Math](math_ops.md)
-* [Neural Network](nn.md)
-* [Reading data](reading_data.md)
-* [Running Graphs](client.md)
-* [Sparse Tensors](sparse_ops.md)
-* [Spectral Functions](spectral_ops.md)
-* [Strings](string_ops.md)
-* [Summary Operations](summary.md)
-* [TensorFlow Debugger](tfdbg.md)
-* [Tensor Handle Operations](session_ops.md)
-* [Tensor Transformations](array_ops.md)
-* [Testing](test.md)
-* [Training](train.md)
-* [Variables](state_ops.md)
-* [Wraps python functions](script_ops.md)
-* [BayesFlow Entropy (contrib)](contrib.bayesflow.entropy.md)
-* [BayesFlow Monte Carlo (contrib)](contrib.bayesflow.monte_carlo.md)
-* [BayesFlow Stochastic Graph (contrib)](contrib.bayesflow.stochastic_graph.md)
-* [BayesFlow Stochastic Tensors (contrib)](contrib.bayesflow.stochastic_tensor.md)
-* [BayesFlow Variational Inference (contrib)](contrib.bayesflow.variational_inference.md)
-* [Copying Graph Elements (contrib)](contrib.copy_graph.md)
-* [CRF (contrib)](contrib.crf.md)
-* [FFmpeg (contrib)](contrib.ffmpeg.md)
-* [Framework (contrib)](contrib.framework.md)
-* [Graph Editor (contrib)](contrib.graph_editor.md)
-* [Integrate (contrib)](contrib.integrate.md)
-* [Layers (contrib)](contrib.layers.md)
-* [Learn (contrib)](contrib.learn.md)
-* [Linear Algebra (contrib)](contrib.linalg.md)
-* [Losses (contrib)](contrib.losses.md)
-* [Metrics (contrib)](contrib.metrics.md)
-* [Optimization (contrib)](contrib.opt.md)
-* [Random variable transformations (contrib)](contrib.distributions.bijectors.md)
-* [RNN and Cells (contrib)](contrib.rnn.md)
-* [Seq2seq Library (contrib)](contrib.seq2seq.md)
-* [Signal Processing (contrib)](contrib.signal.md)
-* [Staging (contrib)](contrib.staging.md)
-* [Statistical Distributions (contrib)](contrib.distributions.md)
-* [Training (contrib)](contrib.training.md)
-* [Utilities (contrib)](contrib.util.md)
diff --git a/tensorflow/docs_src/api_guides/python/input_dataset.md b/tensorflow/docs_src/api_guides/python/input_dataset.md
deleted file mode 100644
index ab572e53d4..0000000000
--- a/tensorflow/docs_src/api_guides/python/input_dataset.md
+++ /dev/null
@@ -1,85 +0,0 @@
-# Dataset Input Pipeline
-[TOC]
-
-`tf.data.Dataset` allows you to build complex input pipelines. See the
-@{$guide/datasets} for an in-depth explanation of how to use this API.
-
-## Reader classes
-
-Classes that create a dataset from input files.
-
-* `tf.data.FixedLengthRecordDataset`
-* `tf.data.TextLineDataset`
-* `tf.data.TFRecordDataset`
-
-## Creating new datasets
-
-Static methods in `Dataset` that create new datasets.
-
-* `tf.data.Dataset.from_generator`
-* `tf.data.Dataset.from_tensor_slices`
-* `tf.data.Dataset.from_tensors`
-* `tf.data.Dataset.list_files`
-* `tf.data.Dataset.range`
-* `tf.data.Dataset.zip`
-
-## Transformations on existing datasets
-
-These functions transform an existing dataset, and return a new dataset. Calls
-can be chained together, as shown in the example below:
-
-```python
-train_data = train_data.batch(100).shuffle(buffer_size=10000).repeat()
-```
-
-* `tf.data.Dataset.apply`
-* `tf.data.Dataset.batch`
-* `tf.data.Dataset.cache`
-* `tf.data.Dataset.concatenate`
-* `tf.data.Dataset.filter`
-* `tf.data.Dataset.flat_map`
-* `tf.data.Dataset.interleave`
-* `tf.data.Dataset.map`
-* `tf.data.Dataset.padded_batch`
-* `tf.data.Dataset.prefetch`
-* `tf.data.Dataset.repeat`
-* `tf.data.Dataset.shard`
-* `tf.data.Dataset.shuffle`
-* `tf.data.Dataset.skip`
-* `tf.data.Dataset.take`
-
-### Custom transformation functions
-
-Custom transformation functions can be applied to a `Dataset` using `tf.data.Dataset.apply`. Below are custom transformation functions from `tf.contrib.data`:
-
-* `tf.contrib.data.batch_and_drop_remainder`
-* `tf.contrib.data.dense_to_sparse_batch`
-* `tf.contrib.data.enumerate_dataset`
-* `tf.contrib.data.group_by_window`
-* `tf.contrib.data.ignore_errors`
-* `tf.contrib.data.map_and_batch`
-* `tf.contrib.data.padded_batch_and_drop_remainder`
-* `tf.contrib.data.parallel_interleave`
-* `tf.contrib.data.rejection_resample`
-* `tf.contrib.data.scan`
-* `tf.contrib.data.shuffle_and_repeat`
-* `tf.contrib.data.unbatch`
-
-## Iterating over datasets
-
-These functions make a `tf.data.Iterator` from a `Dataset`.
-
-* `tf.data.Dataset.make_initializable_iterator`
-* `tf.data.Dataset.make_one_shot_iterator`
-
-The `Iterator` class also contains static methods that create a `tf.data.Iterator` that can be used with multiple `Dataset` objects.
-
-* `tf.data.Iterator.from_structure`
-* `tf.data.Iterator.from_string_handle`
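-
-For example, a minimal sketch of a reinitializable iterator shared by two
-datasets (the file names are hypothetical; both datasets must have matching
-types and shapes):
-
-```python
-train_ds = tf.data.TFRecordDataset("train.tfrecords").batch(32)
-eval_ds = tf.data.TFRecordDataset("eval.tfrecords").batch(32)
-
-# Build one iterator from the common structure, then point it at either
-# dataset by running the corresponding initializer.
-iterator = tf.data.Iterator.from_structure(train_ds.output_types,
-                                           train_ds.output_shapes)
-next_batch = iterator.get_next()
-
-train_init = iterator.make_initializer(train_ds)
-eval_init = iterator.make_initializer(eval_ds)
-```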
-
-## Extra functions from `tf.contrib.data`
-
-* `tf.contrib.data.get_single_element`
-* `tf.contrib.data.make_saveable_from_iterator`
-* `tf.contrib.data.read_batch_features`
-
diff --git a/tensorflow/docs_src/api_guides/python/io_ops.md b/tensorflow/docs_src/api_guides/python/io_ops.md
deleted file mode 100644
index ab3c70daa0..0000000000
--- a/tensorflow/docs_src/api_guides/python/io_ops.md
+++ /dev/null
@@ -1,130 +0,0 @@
-# Inputs and Readers
-
-Note: Functions taking `Tensor` arguments can also take anything accepted by
-`tf.convert_to_tensor`.
-
-[TOC]
-
-## Placeholders
-
-TensorFlow provides a placeholder operation that must be fed with data
-on execution. For more info, see the section on @{$reading_data#Feeding$Feeding data}.
-
-* `tf.placeholder`
-* `tf.placeholder_with_default`
-
-For feeding `SparseTensor`s, which are a composite type,
-there is a convenience function:
-
-* `tf.sparse_placeholder`
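-
-For example, a minimal sketch of feeding a dense placeholder at run time
-(values are illustrative):
-
-```python
-x = tf.placeholder(tf.float32, shape=[None, 3])
-y = tf.reduce_sum(x, axis=1)
-
-with tf.Session() as sess:
-  # The placeholder must be fed, or running `y` raises an error.
-  print(sess.run(y, feed_dict={x: [[1.0, 2.0, 3.0]]}))  # ==> [6.]
-```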
-
-## Readers
-
-TensorFlow provides a set of Reader classes for reading data formats.
-For more information on inputs and readers, see @{$reading_data$Reading data}.
-
-* `tf.ReaderBase`
-* `tf.TextLineReader`
-* `tf.WholeFileReader`
-* `tf.IdentityReader`
-* `tf.TFRecordReader`
-* `tf.FixedLengthRecordReader`
-
-## Converting
-
-TensorFlow provides several operations that you can use to convert various data
-formats into tensors.
-
-* `tf.decode_csv`
-* `tf.decode_raw`
-
-- - -
-
-### Example protocol buffer
-
-TensorFlow's @{$reading_data#standard_tensorflow_format$recommended format for training examples}
-is serialized `Example` protocol buffers, [described
-here](https://www.tensorflow.org/code/tensorflow/core/example/example.proto).
-They contain `Features`, [described
-here](https://www.tensorflow.org/code/tensorflow/core/example/feature.proto).
-
-* `tf.VarLenFeature`
-* `tf.FixedLenFeature`
-* `tf.FixedLenSequenceFeature`
-* `tf.SparseFeature`
-* `tf.parse_example`
-* `tf.parse_single_example`
-* `tf.parse_tensor`
-* `tf.decode_json_example`
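-
-For example, a minimal parsing sketch (the feature names are hypothetical, and
-`serialized_example` is assumed to come from a reader or dataset):
-
-```python
-features = tf.parse_single_example(
-    serialized_example,
-    features={
-        "image_raw": tf.FixedLenFeature([], tf.string),
-        "label": tf.FixedLenFeature([], tf.int64),
-    })
-image = tf.decode_raw(features["image_raw"], tf.uint8)
-label = tf.cast(features["label"], tf.int32)
-```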
-
-## Queues
-
-TensorFlow provides several implementations of 'Queues', which are
-structures within the TensorFlow computation graph to stage pipelines
-of tensors together. The following describes the basic Queue interface
-and some implementations. To see an example use, see @{$threading_and_queues$Threading and Queues}.
-
-* `tf.QueueBase`
-* `tf.FIFOQueue`
-* `tf.PaddingFIFOQueue`
-* `tf.RandomShuffleQueue`
-* `tf.PriorityQueue`
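-
-For example, a minimal sketch of staging values through a `FIFOQueue` by hand
-(real pipelines usually drive queues with a `QueueRunner` instead):
-
-```python
-q = tf.FIFOQueue(capacity=10, dtypes=[tf.float32])
-enqueue_op = q.enqueue([1.0])
-dequeue_op = q.dequeue()
-
-with tf.Session() as sess:
-  sess.run(enqueue_op)
-  print(sess.run(dequeue_op))  # ==> 1.0
-```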
-
-## Conditional Accumulators
-
-* `tf.ConditionalAccumulatorBase`
-* `tf.ConditionalAccumulator`
-* `tf.SparseConditionalAccumulator`
-
-## Dealing with the filesystem
-
-* `tf.matching_files`
-* `tf.read_file`
-* `tf.write_file`
-
-## Input pipeline
-
-TensorFlow functions for setting up an input-prefetching pipeline.
-Please see the @{$reading_data$reading data how-to}
-for context.
-
-### Beginning of an input pipeline
-
-The "producer" functions add a queue to the graph and a corresponding
-`QueueRunner` for running the subgraph that fills that queue.
-
-* `tf.train.match_filenames_once`
-* `tf.train.limit_epochs`
-* `tf.train.input_producer`
-* `tf.train.range_input_producer`
-* `tf.train.slice_input_producer`
-* `tf.train.string_input_producer`
-
-### Batching at the end of an input pipeline
-
-These functions add a queue to the graph to assemble a batch of
-examples, with possible shuffling. They also add a `QueueRunner` for
-running the subgraph that fills that queue.
-
-Use `tf.train.batch` or `tf.train.batch_join` for batching
-examples that have already been well shuffled. Use
-`tf.train.shuffle_batch` or
-`tf.train.shuffle_batch_join` for examples that would
-benefit from additional shuffling.
-
-Use `tf.train.batch` or `tf.train.shuffle_batch` if you want a
-single thread producing examples to batch, or if you have a
-single subgraph producing examples but you want to run it in *N* threads
-(where you increase *N* until it can keep the queue full). Use
-`tf.train.batch_join` or `tf.train.shuffle_batch_join`
-if you have *N* different subgraphs producing examples to batch and you
-want them run by *N* threads. Use `maybe_*` to enqueue conditionally.
-
-* `tf.train.batch`
-* `tf.train.maybe_batch`
-* `tf.train.batch_join`
-* `tf.train.maybe_batch_join`
-* `tf.train.shuffle_batch`
-* `tf.train.maybe_shuffle_batch`
-* `tf.train.shuffle_batch_join`
-* `tf.train.maybe_shuffle_batch_join`
diff --git a/tensorflow/docs_src/api_guides/python/math_ops.md b/tensorflow/docs_src/api_guides/python/math_ops.md
deleted file mode 100644
index e738161e49..0000000000
--- a/tensorflow/docs_src/api_guides/python/math_ops.md
+++ /dev/null
@@ -1,199 +0,0 @@
-# Math
-
-Note: Functions taking `Tensor` arguments can also take anything accepted by
-`tf.convert_to_tensor`.
-
-[TOC]
-
-Note: Elementwise binary operations in TensorFlow follow [numpy-style
-broadcasting](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
-
-## Arithmetic Operators
-
-TensorFlow provides several operations that you can use to add basic arithmetic
-operators to your graph.
-
-* `tf.add`
-* `tf.subtract`
-* `tf.multiply`
-* `tf.scalar_mul`
-* `tf.div`
-* `tf.divide`
-* `tf.truediv`
-* `tf.floordiv`
-* `tf.realdiv`
-* `tf.truncatediv`
-* `tf.floor_div`
-* `tf.truncatemod`
-* `tf.floormod`
-* `tf.mod`
-* `tf.cross`
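-
-For example, a minimal sketch of a few of these ops applied elementwise
-(illustrative values):
-
-```python
-a = tf.constant([10, 7])
-b = tf.constant([3, 2])
-tf.add(a, b)          # ==> [13, 9]
-tf.floordiv(a, b)     # ==> [3, 3]
-tf.truncatemod(a, b)  # ==> [1, 1]
-```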
-
-## Basic Math Functions
-
-TensorFlow provides several operations that you can use to add basic
-mathematical functions to your graph.
-
-* `tf.add_n`
-* `tf.abs`
-* `tf.negative`
-* `tf.sign`
-* `tf.reciprocal`
-* `tf.square`
-* `tf.round`
-* `tf.sqrt`
-* `tf.rsqrt`
-* `tf.pow`
-* `tf.exp`
-* `tf.expm1`
-* `tf.log`
-* `tf.log1p`
-* `tf.ceil`
-* `tf.floor`
-* `tf.maximum`
-* `tf.minimum`
-* `tf.cos`
-* `tf.sin`
-* `tf.lbeta`
-* `tf.tan`
-* `tf.acos`
-* `tf.asin`
-* `tf.atan`
-* `tf.cosh`
-* `tf.sinh`
-* `tf.asinh`
-* `tf.acosh`
-* `tf.atanh`
-* `tf.lgamma`
-* `tf.digamma`
-* `tf.erf`
-* `tf.erfc`
-* `tf.squared_difference`
-* `tf.igamma`
-* `tf.igammac`
-* `tf.zeta`
-* `tf.polygamma`
-* `tf.betainc`
-* `tf.rint`
-
-## Matrix Math Functions
-
-TensorFlow provides several operations that you can use to add linear algebra
-functions on matrices to your graph.
-
-* `tf.diag`
-* `tf.diag_part`
-* `tf.trace`
-* `tf.transpose`
-* `tf.eye`
-* `tf.matrix_diag`
-* `tf.matrix_diag_part`
-* `tf.matrix_band_part`
-* `tf.matrix_set_diag`
-* `tf.matrix_transpose`
-* `tf.matmul`
-* `tf.norm`
-* `tf.matrix_determinant`
-* `tf.matrix_inverse`
-* `tf.cholesky`
-* `tf.cholesky_solve`
-* `tf.matrix_solve`
-* `tf.matrix_triangular_solve`
-* `tf.matrix_solve_ls`
-* `tf.qr`
-* `tf.self_adjoint_eig`
-* `tf.self_adjoint_eigvals`
-* `tf.svd`
-
-
-## Tensor Math Functions
-
-TensorFlow provides operations that you can use to add tensor functions to your
-graph.
-
-* `tf.tensordot`
-
-
-## Complex Number Functions
-
-TensorFlow provides several operations that you can use to add complex number
-functions to your graph.
-
-* `tf.complex`
-* `tf.conj`
-* `tf.imag`
-* `tf.angle`
-* `tf.real`
-
-
-## Reduction
-
-TensorFlow provides several operations that you can use to perform
-common math computations that reduce various dimensions of a tensor.
-
-* `tf.reduce_sum`
-* `tf.reduce_prod`
-* `tf.reduce_min`
-* `tf.reduce_max`
-* `tf.reduce_mean`
-* `tf.reduce_all`
-* `tf.reduce_any`
-* `tf.reduce_logsumexp`
-* `tf.count_nonzero`
-* `tf.accumulate_n`
-* `tf.einsum`
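-
-For example, a minimal sketch of reducing over different axes (illustrative
-values):
-
-```python
-x = tf.constant([[1., 2.], [3., 4.]])
-tf.reduce_sum(x)           # ==> 10.0
-tf.reduce_sum(x, axis=0)   # ==> [4., 6.]
-tf.reduce_mean(x, axis=1)  # ==> [1.5, 3.5]
-```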
-
-## Scan
-
-TensorFlow provides several operations that you can use to perform scans
-(running totals) across one axis of a tensor.
-
-* `tf.cumsum`
-* `tf.cumprod`
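-
-For example (illustrative values):
-
-```python
-x = tf.constant([1., 2., 3.])
-tf.cumsum(x)   # ==> [1., 3., 6.]
-tf.cumprod(x)  # ==> [1., 2., 6.]
-```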
-
-## Segmentation
-
-TensorFlow provides several operations that you can use to perform common
-math computations on tensor segments.
-Here a segmentation is a partitioning of a tensor along
-the first dimension, i.e. it defines a mapping from the first dimension onto
-`segment_ids`. The `segment_ids` tensor should be the size of
-the first dimension, `d0`, with consecutive IDs in the range `0` to `k`,
-where `k<d0`.
-In particular, a segmentation of a matrix tensor is a mapping of rows to
-segments.
-
-For example:
-
-```python
-c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
-tf.segment_sum(c, tf.constant([0, 0, 1]))
- ==> [[0 0 0 0]
- [5 6 7 8]]
-```
-
-* `tf.segment_sum`
-* `tf.segment_prod`
-* `tf.segment_min`
-* `tf.segment_max`
-* `tf.segment_mean`
-* `tf.unsorted_segment_sum`
-* `tf.sparse_segment_sum`
-* `tf.sparse_segment_mean`
-* `tf.sparse_segment_sqrt_n`
-
-
-## Sequence Comparison and Indexing
-
-TensorFlow provides several operations that you can use to add sequence
-comparison and index extraction to your graph. You can use these operations to
-determine sequence differences and determine the indexes of specific values in
-a tensor.
-
-* `tf.argmin`
-* `tf.argmax`
-* `tf.setdiff1d`
-* `tf.where`
-* `tf.unique`
-* `tf.edit_distance`
-* `tf.invert_permutation`
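-
-For example (illustrative values):
-
-```python
-x = tf.constant([1, 3, 2])
-tf.argmax(x)  # ==> 1
-
-y, idx = tf.unique(tf.constant([1, 1, 2, 3, 3]))
-# y ==> [1, 2, 3]; idx ==> [0, 0, 1, 2, 2]
-```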
diff --git a/tensorflow/docs_src/api_guides/python/meta_graph.md b/tensorflow/docs_src/api_guides/python/meta_graph.md
deleted file mode 100644
index 7dbd9a56f4..0000000000
--- a/tensorflow/docs_src/api_guides/python/meta_graph.md
+++ /dev/null
@@ -1,277 +0,0 @@
-# Exporting and Importing a MetaGraph
-
-A [`MetaGraph`](https://www.tensorflow.org/code/tensorflow/core/protobuf/meta_graph.proto) contains both a TensorFlow GraphDef
-as well as associated metadata necessary for running computation in a
-graph when crossing a process boundary. It can also be used for long
-term storage of graphs. The MetaGraph contains the information required
-to continue training, perform evaluation, or run inference on a previously trained graph.
-
-The APIs for exporting and importing the complete model are in
-the `tf.train.Saver` class:
-`tf.train.export_meta_graph`
-and
-`tf.train.import_meta_graph`.
-
-## What's in a MetaGraph
-
-The information contained in a MetaGraph is expressed as a
-[`MetaGraphDef`](https://www.tensorflow.org/code/tensorflow/core/protobuf/meta_graph.proto)
-protocol buffer. It contains the following fields:
-
-* [`MetaInfoDef`](https://www.tensorflow.org/code/tensorflow/core/protobuf/meta_graph.proto) for meta information, such as version and other user information.
-* [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto) for describing the graph.
-* [`SaverDef`](https://www.tensorflow.org/code/tensorflow/core/protobuf/saver.proto) for the saver.
-* [`CollectionDef`](https://www.tensorflow.org/code/tensorflow/core/protobuf/meta_graph.proto)
-map that further describes additional components of the model such as
-@{$python/state_ops$`Variables`},
-`tf.train.QueueRunner`, etc.
-
-In order for a Python object to be serialized
-to and from `MetaGraphDef`, the Python class must implement `to_proto()` and
-`from_proto()` methods, and register them with the system using
-`register_proto_function`. For example:
-
- ```Python
- def to_proto(self, export_scope=None):
-
- """Converts a `Variable` to a `VariableDef` protocol buffer.
-
- Args:
- export_scope: Optional `string`. Name scope to remove.
-
- Returns:
- A `VariableDef` protocol buffer, or `None` if the `Variable` is not
- in the specified name scope.
- """
- if (export_scope is None or
- self._variable.name.startswith(export_scope)):
- var_def = variable_pb2.VariableDef()
- var_def.variable_name = ops.strip_name_scope(
- self._variable.name, export_scope)
- var_def.initializer_name = ops.strip_name_scope(
- self.initializer.name, export_scope)
- var_def.snapshot_name = ops.strip_name_scope(
- self._snapshot.name, export_scope)
- if self._save_slice_info:
- var_def.save_slice_info_def.MergeFrom(self._save_slice_info.to_proto(
- export_scope=export_scope))
- return var_def
- else:
- return None
-
- @staticmethod
- def from_proto(variable_def, import_scope=None):
- """Returns a `Variable` object created from `variable_def`."""
- return Variable(variable_def=variable_def, import_scope=import_scope)
-
- ops.register_proto_function(ops.GraphKeys.GLOBAL_VARIABLES,
- proto_type=variable_pb2.VariableDef,
- to_proto=Variable.to_proto,
- from_proto=Variable.from_proto)
- ```
-
-## Exporting a Complete Model to MetaGraph
-
-The API for exporting a running model as a MetaGraph is `export_meta_graph()`.
-
- ```Python
- def export_meta_graph(filename=None, collection_list=None, as_text=False):
- """Writes `MetaGraphDef` to save_path/filename.
-
- Args:
- filename: Optional meta_graph filename including the path.
- collection_list: List of string keys to collect.
- as_text: If `True`, writes the meta_graph as an ASCII proto.
-
- Returns:
- A `MetaGraphDef` proto.
- """
- ```
-
- A `collection` can contain any Python objects that users would like to
- be able to uniquely identify and easily retrieve. These objects can be
- special operations in the graph, such as `train_op`, or hyper parameters,
- such as "learning rate". Users can specify the list of collections
- they would like to export. If no `collection_list` is specified,
- all collections in the model will be exported.
-
- The API returns a serialized protocol buffer. If `filename` is
- specified, the protocol buffer will also be written to a file.
-
- Here are some of the typical usage models:
-
- * Export the default running graph:
-
- ```Python
- # Build the model
- ...
- with tf.Session() as sess:
- # Use the model
- ...
- # Export the model to /tmp/my-model.meta.
- meta_graph_def = tf.train.export_meta_graph(filename='/tmp/my-model.meta')
- ```
-
- * Export the default running graph and only a subset of the collections.
-
- ```Python
- meta_graph_def = tf.train.export_meta_graph(
- filename='/tmp/my-model.meta',
- collection_list=["input_tensor", "output_tensor"])
- ```
-
-
-The MetaGraph is also automatically exported via the `save()` API in
-`tf.train.Saver`.
-
-
-## Import a MetaGraph
-
-The API for importing a MetaGraph file into a graph is `import_meta_graph()`.
-
-Here are some of the typical usage models:
-
-* Import and continue training without building the model from scratch.
-
- ```Python
- ...
- # Create a saver.
- saver = tf.train.Saver(...variables...)
- # Remember the training_op we want to run by adding it to a collection.
- tf.add_to_collection('train_op', train_op)
- sess = tf.Session()
- for step in range(1000000):
- sess.run(train_op)
- if step % 1000 == 0:
- # Saves checkpoint, which by default also exports a meta_graph
- # named 'my-model-global_step.meta'.
- saver.save(sess, 'my-model', global_step=step)
- ```
-
- Later we can continue training from this saved `meta_graph` without building
- the model from scratch.
-
- ```Python
- with tf.Session() as sess:
- new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
- new_saver.restore(sess, 'my-save-dir/my-model-10000')
- # tf.get_collection() returns a list. In this example we only want the
- # first one.
- train_op = tf.get_collection('train_op')[0]
- for step in range(1000000):
- sess.run(train_op)
- ```
-
-* Import and extend the graph.
-
- For example, we can first build an inference graph, export it as a meta graph:
-
- ```Python
- # Creates an inference graph.
- # Hidden 1
- images = tf.constant(1.2, tf.float32, shape=[100, 28])
- with tf.name_scope("hidden1"):
- weights = tf.Variable(
- tf.truncated_normal([28, 128],
- stddev=1.0 / math.sqrt(float(28))),
- name="weights")
- biases = tf.Variable(tf.zeros([128]),
- name="biases")
- hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
- # Hidden 2
- with tf.name_scope("hidden2"):
- weights = tf.Variable(
- tf.truncated_normal([128, 32],
- stddev=1.0 / math.sqrt(float(128))),
- name="weights")
- biases = tf.Variable(tf.zeros([32]),
- name="biases")
- hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
- # Linear
- with tf.name_scope("softmax_linear"):
- weights = tf.Variable(
- tf.truncated_normal([32, 10],
- stddev=1.0 / math.sqrt(float(32))),
- name="weights")
- biases = tf.Variable(tf.zeros([10]),
- name="biases")
- logits = tf.matmul(hidden2, weights) + biases
- tf.add_to_collection("logits", logits)
-
- init_all_op = tf.global_variables_initializer()
-
- with tf.Session() as sess:
- # Initializes all the variables.
- sess.run(init_all_op)
- # Runs the logits op.
- sess.run(logits)
- # Creates a saver.
- saver0 = tf.train.Saver()
- saver0.save(sess, 'my-save-dir/my-model-10000')
- # Generates MetaGraphDef.
- saver0.export_meta_graph('my-save-dir/my-model-10000.meta')
- ```
-
- Then later import it and extend it to a training graph.
-
- ```Python
- with tf.Session() as sess:
- new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
- new_saver.restore(sess, 'my-save-dir/my-model-10000')
- # Adds loss and train ops.
- labels = tf.constant(0, tf.int32, shape=[100], name="labels")
- batch_size = tf.size(labels)
- logits = tf.get_collection("logits")[0]
- loss = tf.losses.sparse_softmax_cross_entropy(labels=labels,
- logits=logits)
-
- tf.summary.scalar('loss', loss)
- # Creates the gradient descent optimizer with the given learning rate.
- optimizer = tf.train.GradientDescentOptimizer(0.01)
-
- # Runs train_op.
- train_op = optimizer.minimize(loss)
- sess.run(train_op)
- ```
-
-* Import a graph with preset devices.
-
- Sometimes an exported meta graph is from a training environment that the
- importer doesn't have. For example, the model might have been trained
- on GPUs, or in a distributed environment with replicas. When importing
- such models, it's useful to be able to clear the device settings in
- the graph so that we can run it on locally available devices. This can
- be achieved by calling `import_meta_graph` with the `clear_devices`
- option set to `True`.
-
- ```Python
- with tf.Session() as sess:
- new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta',
- clear_devices=True)
- new_saver.restore(sess, 'my-save-dir/my-model-10000')
- ...
- ```
-
-* Import within the default graph.
-
- Sometimes you might want to run `export_meta_graph` and `import_meta_graph`
- in a codelab using the default graph. In that case, you need to reset
- the default graph by calling `tf.reset_default_graph()` before
- running the import.
-
- ```Python
- meta_graph_def = tf.train.export_meta_graph()
- ...
- tf.reset_default_graph()
- ...
- tf.train.import_meta_graph(meta_graph_def)
- ...
- ```
-
-* Retrieve Hyper Parameters
-
- ```Python
- filename = ".".join([tf.train.latest_checkpoint(train_dir), "meta"])
- tf.train.import_meta_graph(filename)
- hparams = tf.get_collection("hparams")
- ```
diff --git a/tensorflow/docs_src/api_guides/python/nn.md b/tensorflow/docs_src/api_guides/python/nn.md
deleted file mode 100644
index 40dda3941d..0000000000
--- a/tensorflow/docs_src/api_guides/python/nn.md
+++ /dev/null
@@ -1,418 +0,0 @@
-# Neural Network
-
-Note: Functions taking `Tensor` arguments can also take anything accepted by
-`tf.convert_to_tensor`.
-
-[TOC]
-
-## Activation Functions
-
-The activation ops provide different types of nonlinearities for use in neural
-networks. These include smooth nonlinearities (`sigmoid`, `tanh`, `elu`, `selu`,
-`softplus`, and `softsign`), continuous but not everywhere differentiable
-functions (`relu`, `relu6`, `crelu` and `relu_x`), and random regularization
-(`dropout`).
-
-All activation ops apply componentwise, and produce a tensor of the same
-shape as the input tensor.
-
-* `tf.nn.relu`
-* `tf.nn.relu6`
-* `tf.nn.crelu`
-* `tf.nn.elu`
-* `tf.nn.selu`
-* `tf.nn.softplus`
-* `tf.nn.softsign`
-* `tf.nn.dropout`
-* `tf.nn.bias_add`
-* `tf.sigmoid`
-* `tf.tanh`
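-
-For example, a minimal sketch showing the componentwise behavior (illustrative
-values):
-
-```python
-x = tf.constant([-1.0, 0.0, 2.0])
-tf.nn.relu(x)   # ==> [0., 0., 2.]
-tf.nn.relu6(x)  # ==> [0., 0., 2.]
-```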
-
-## Convolution
-
-The convolution ops sweep a 2-D filter over a batch of images, applying the
-filter to each window of each image of the appropriate size. The different
-ops trade off between generic vs. specific filters:
-
-* `conv2d`: Arbitrary filters that can mix channels together.
-* `depthwise_conv2d`: Filters that operate on each channel independently.
-* `separable_conv2d`: A depthwise spatial filter followed by a pointwise filter.
-
-Note that although these ops are called "convolution", they are strictly
-speaking "cross-correlation" since the filter is combined with an input window
-without reversing the filter. For details, see [the properties of
-cross-correlation](https://en.wikipedia.org/wiki/Cross-correlation#Properties).
-
-The filter is applied to image patches of the same size as the filter and
-strided according to the `strides` argument. `strides = [1, 1, 1, 1]` applies
-the filter to a patch at every offset, `strides = [1, 2, 2, 1]` applies the
-filter to every other image patch in each dimension, etc.
-
-Ignoring channels for the moment, assume that the 4-D `input` has shape
-`[batch, in_height, in_width, ...]` and the 4-D `filter` has shape
-`[filter_height, filter_width, ...]`. The spatial semantics of the
-convolution ops depend on the padding scheme chosen: `'SAME'` or `'VALID'`.
-Note that the padding values are always zero.
-
-First, consider the `'SAME'` padding scheme. A detailed explanation of the
-reasoning behind it is given in
-[these notes](#Notes_on_SAME_Convolution_Padding). Here, we summarize the
-mechanics of this padding scheme. When using `'SAME'`, the output height and
-width are computed as:
-
- out_height = ceil(float(in_height) / float(strides[1]))
- out_width = ceil(float(in_width) / float(strides[2]))
-
-The total padding applied along the height and width is computed as:
-
- if (in_height % strides[1] == 0):
- pad_along_height = max(filter_height - strides[1], 0)
- else:
- pad_along_height = max(filter_height - (in_height % strides[1]), 0)
- if (in_width % strides[2] == 0):
- pad_along_width = max(filter_width - strides[2], 0)
- else:
- pad_along_width = max(filter_width - (in_width % strides[2]), 0)
-
-Finally, the padding on the top, bottom, left and right are:
-
- pad_top = pad_along_height // 2
- pad_bottom = pad_along_height - pad_top
- pad_left = pad_along_width // 2
- pad_right = pad_along_width - pad_left
-
-Note that the division by 2 means that there might be cases when the padding on
-the two sides (top vs. bottom, left vs. right) differs by one. In this case, the
-bottom and right sides always get the one additional padded pixel. For example,
-when `pad_along_height` is 5, we pad 2 pixels at the top and 3 pixels at the
-bottom. Note that this is different from existing libraries such as cuDNN and
-Caffe, which explicitly specify the number of padded pixels and always pad the
-same number of pixels on both sides.
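-
-For example, with `in_height = 5`, `filter_height = 3` and `strides[1] = 2`:
-`out_height = ceil(5 / 2) = 3`, `pad_along_height = max(3 - (5 % 2), 0) = 2`,
-so `pad_top = 1` and `pad_bottom = 1`.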
-
-For the `'VALID'` scheme, the output height and width are computed as:
-
- out_height = ceil(float(in_height - filter_height + 1) / float(strides[1]))
- out_width = ceil(float(in_width - filter_width + 1) / float(strides[2]))
-
-and no padding is used.
-
-Given the output size and the padding, the output can be computed as
-
-$$ output[b, i, j, :] =
-    \sum_{d_i, d_j} input[b, strides[1] \cdot i + d_i - pad_{top},
-                          strides[2] \cdot j + d_j - pad_{left}, ...] \cdot
-                    filter[d_i, d_j, ...]$$
-
-where any values outside the original input image region are treated as zero
-(i.e., we pad zero values around the border of the image).
-
-Since `input` is 4-D, each `input[b, i, j, :]` is a vector. For `conv2d`, these
-vectors are multiplied by the `filter[di, dj, :, :]` matrices to produce new
-vectors. For `depthwise_conv2d`, each scalar component `input[b, i, j, k]`
-is multiplied by a vector `filter[di, dj, k]`, and all the vectors are
-concatenated.
-
-* `tf.nn.convolution`
-* `tf.nn.conv2d`
-* `tf.nn.depthwise_conv2d`
-* `tf.nn.depthwise_conv2d_native`
-* `tf.nn.separable_conv2d`
-* `tf.nn.atrous_conv2d`
-* `tf.nn.atrous_conv2d_transpose`
-* `tf.nn.conv2d_transpose`
-* `tf.nn.conv1d`
-* `tf.nn.conv3d`
-* `tf.nn.conv3d_transpose`
-* `tf.nn.conv2d_backprop_filter`
-* `tf.nn.conv2d_backprop_input`
-* `tf.nn.conv3d_backprop_filter_v2`
-* `tf.nn.depthwise_conv2d_native_backprop_filter`
-* `tf.nn.depthwise_conv2d_native_backprop_input`
-
-## Pooling
-
-The pooling ops sweep a rectangular window over the input tensor, computing a
-reduction operation for each window (average, max, or max with argmax). Each
-pooling op uses rectangular windows of size `ksize` separated by offset
-`strides`. For example, if `strides` is all ones every window is used, if
-`strides` is all twos every other window is used in each dimension, etc.
-
-In detail, the output is
-
- output[i] = reduce(value[strides * i:strides * i + ksize])
-
-where the indices also take into consideration the padding values. Please refer
-to the `Convolution` section for details about the padding calculation.
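-
-For example, a minimal sketch of 2x2 max pooling with stride 2 (illustrative
-values):
-
-```python
-# A 4x4 single-channel input holding the values 0..15.
-x = tf.reshape(tf.range(16, dtype=tf.float32), [1, 4, 4, 1])
-pooled = tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
-                        padding='VALID')
-# pooled has shape [1, 2, 2, 1] with values [[5., 7.], [13., 15.]].
-```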
-
-* `tf.nn.avg_pool`
-* `tf.nn.max_pool`
-* `tf.nn.max_pool_with_argmax`
-* `tf.nn.avg_pool3d`
-* `tf.nn.max_pool3d`
-* `tf.nn.fractional_avg_pool`
-* `tf.nn.fractional_max_pool`
-* `tf.nn.pool`
-
-## Morphological filtering
-
-Morphological operators are non-linear filters used in image processing.
-
-[Greyscale morphological dilation
-](https://en.wikipedia.org/wiki/Dilation_(morphology))
-is the max-sum counterpart of standard sum-product convolution:
-
-$$ output[b, y, x, c] =
-    \max_{dy, dx} input[b,
-                        strides[1] * y + rates[1] * dy,
-                        strides[2] * x + rates[2] * dx,
-                        c] +
-                  filter[dy, dx, c]$$
-
-The `filter` is usually called the structuring function. Max-pooling is a
-special case of greyscale morphological dilation when the filter is all zeros
-(a.k.a. a flat structuring function).
-
-[Greyscale morphological erosion
-](https://en.wikipedia.org/wiki/Erosion_(morphology))
-is the min-sum counterpart of standard sum-product convolution:
-
-$$ output[b, y, x, c] =
-    \min_{dy, dx} input[b,
-                        strides[1] * y - rates[1] * dy,
-                        strides[2] * x - rates[2] * dx,
-                        c] -
-                  filter[dy, dx, c]$$
-
-Dilation and erosion are dual to each other. The dilation of the input signal
-`f` by the structuring signal `g` is equal to the negation of the erosion of
-`-f` by the reflected `g`, and vice versa.
-
-Striding and padding is carried out in exactly the same way as in standard
-convolution. Please refer to the `Convolution` section for details.
-
-* `tf.nn.dilation2d`
-* `tf.nn.erosion2d`
-* `tf.nn.with_space_to_batch`
-
-## Normalization
-
-Normalization is useful to prevent neurons from saturating when inputs may
-have varying scale, and to aid generalization.
-
-* `tf.nn.l2_normalize`
-* `tf.nn.local_response_normalization`
-* `tf.nn.sufficient_statistics`
-* `tf.nn.normalize_moments`
-* `tf.nn.moments`
-* `tf.nn.weighted_moments`
-* `tf.nn.fused_batch_norm`
-* `tf.nn.batch_normalization`
-* `tf.nn.batch_norm_with_global_normalization`
-
-## Losses
-
-The loss ops measure error between two tensors, or between a tensor and zero.
-These can be used for measuring accuracy of a network in a regression task
-or for regularization purposes (weight decay).
-
-* `tf.nn.l2_loss`
-* `tf.nn.log_poisson_loss`
-
-## Classification
-
-TensorFlow provides several operations that help you perform classification.
-
-* `tf.nn.sigmoid_cross_entropy_with_logits`
-* `tf.nn.softmax`
-* `tf.nn.log_softmax`
-* `tf.nn.softmax_cross_entropy_with_logits`
-* `tf.nn.softmax_cross_entropy_with_logits_v2` - identical to the base
- version, except it allows gradient propagation into the labels.
-* `tf.nn.sparse_softmax_cross_entropy_with_logits`
-* `tf.nn.weighted_cross_entropy_with_logits`
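-
-For example, a minimal sketch computing a cross entropy loss from raw logits
-(illustrative values; note that these ops expect *unscaled* logits, not
-softmax outputs):
-
-```python
-logits = tf.constant([[2.0, 0.5, -1.0]])
-labels = tf.constant([0])
-loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
-                                                      logits=logits)
-```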
-
-## Embeddings
-
-TensorFlow provides library support for looking up values in embedding
-tensors.
-
-* `tf.nn.embedding_lookup`
-* `tf.nn.embedding_lookup_sparse`
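-
-For example, a minimal lookup sketch (illustrative values): each id selects a
-row of the embedding matrix.
-
-```python
-params = tf.constant([[0., 0.], [1., 1.], [2., 2.]])
-ids = tf.constant([2, 0])
-tf.nn.embedding_lookup(params, ids)  # ==> [[2., 2.], [0., 0.]]
-```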
-
-## Recurrent Neural Networks
-
-TensorFlow provides a number of methods for constructing Recurrent
-Neural Networks. Most accept an `RNNCell`-subclassed object
-(see the documentation for `tf.contrib.rnn`).
-
-* `tf.nn.dynamic_rnn`
-* `tf.nn.bidirectional_dynamic_rnn`
-* `tf.nn.raw_rnn`
-
-## Connectionist Temporal Classification (CTC)
-
-* `tf.nn.ctc_loss`
-* `tf.nn.ctc_greedy_decoder`
-* `tf.nn.ctc_beam_search_decoder`
-
-## Evaluation
-
-The evaluation ops are useful for measuring the performance of a network.
-They are typically used at evaluation time.
-
-* `tf.nn.top_k`
-* `tf.nn.in_top_k`
-
-## Candidate Sampling
-
-Do you want to train a multiclass or multilabel model with thousands
-or millions of output classes (for example, a language model with a
-large vocabulary)? Training with a full Softmax is slow in this case,
-since all of the classes are evaluated for every training example.
-Candidate Sampling training algorithms can speed up your step times by
-only considering a small randomly-chosen subset of contrastive classes
-(called candidates) for each batch of training examples.
-
-See our
-[Candidate Sampling Algorithms
-Reference](https://www.tensorflow.org/extras/candidate_sampling.pdf).
-
-### Sampled Loss Functions
-
-TensorFlow provides the following sampled loss functions for faster training.
-
-* `tf.nn.nce_loss`
-* `tf.nn.sampled_softmax_loss`
-
-### Candidate Samplers
-
-TensorFlow provides the following samplers for randomly sampling candidate
-classes when using one of the sampled loss functions above.
-
-* `tf.nn.uniform_candidate_sampler`
-* `tf.nn.log_uniform_candidate_sampler`
-* `tf.nn.learned_unigram_candidate_sampler`
-* `tf.nn.fixed_unigram_candidate_sampler`
-
-### Miscellaneous candidate sampling utilities
-
-* `tf.nn.compute_accidental_hits`
-
-### Quantization ops
-
-* `tf.nn.quantized_conv2d`
-* `tf.nn.quantized_relu_x`
-* `tf.nn.quantized_max_pool`
-* `tf.nn.quantized_avg_pool`
-
-## Notes on SAME Convolution Padding
-
-In these notes, we provide more background on the use of the `'SAME'` padding
-scheme for convolution operations.
-
-TensorFlow uses the smallest possible padding to achieve the desired output
-size. To understand what is done, consider the \\(1\\)-dimensional case. Denote
-by \\(n_i\\) and \\(n_o\\) the input and output sizes, respectively, and denote
-the kernel size by \\(k\\) and the stride by \\(s\\). As discussed in the
-[Convolution section](#Convolution), for `'SAME'`,
-\\(n_o = \left \lceil{\frac{n_i}{s}}\right \rceil\\).
-
-To achieve a desired output size \\(n_o\\), we need to pad the input such that the
-output size after a `'VALID'` convolution is \\(n_o\\). In other words, we need to
-have padding \\(p_i\\) such that:
-
-\begin{equation}
-\left \lceil{\frac{n_i + p_i - k + 1}{s}}\right \rceil = n_o
-\label{eq:tf_pad_1}
-\end{equation}
-
-What is the smallest \\(p_i\\) that we could possibly use? In general, \\(\left
-\lceil{\frac{x}{a}}\right \rceil = b\\) (with \\(a > 0\\)) means that \\(b-1 <
-\frac{x}{a} \leq b\\), and the smallest integer \\(x\\) we can choose to satisfy
-this is \\(x = a\cdot (b-1) + 1\\). The same applies to our problem; we need
-\\(p_i\\) such that:
-
-\begin{equation}
-n_i + p_i - k + 1 = s\cdot (n_o - 1) + 1
-\label{eq:tf_pad_2}
-\end{equation}
-
-which leads to:
-
-\begin{equation}
-p_i = s\cdot (n_o - 1) + k - n_i
-\label{eq:tf_pad_3}
-\end{equation}
-
-Note that this might lead to negative \\(p_i\\), since in some cases we might
-already have more input samples than we actually need. Thus,
-
-\begin{equation}
-p_i = max(s\cdot (n_o - 1) + k - n_i, 0)
-\label{eq:tf_pad_4}
-\end{equation}
-
-Remember that, for `'SAME'` padding,
-\\(n_o = \left \lceil{\frac{n_i}{s}}\right \rceil\\), as mentioned above.
-We need to analyze in detail two cases:
-
-- \\(n_i \text{ mod } s = 0\\)
-
-In this simple case, \\(n_o = \frac{n_i}{s}\\), and the expression for \\(p_i\\)
-becomes:
-
-\begin{equation}
-p_i = max(k - s, 0)
-\label{eq:tf_pad_5}
-\end{equation}
-
-- \\(n_i \text{ mod } s \neq 0\\)
-
-This case is more involved. First, we write:
-
-\begin{equation}
-n_i = s\cdot\left \lceil{\frac{n_i}{s}}\right \rceil
-- s \left(\left \lceil{\frac{n_i}{s}}\right \rceil -
- \left \lfloor{\frac{n_i}{s}}\right \rfloor\right)
-+ (n_i \text{ mod } s)
-\label{eq:tf_pad_6}
-\end{equation}
-
-For the case where \\((n_i \text{ mod } s) \neq 0\\), we have \\(\left
-\lceil{\frac{n_i}{s}}\right \rceil -\left \lfloor{\frac{n_i}{s}}\right \rfloor =
-1\\), leading to:
-
-\begin{equation}
-n_i = s\cdot\left \lceil{\frac{n_i}{s}}\right \rceil
-- s
-+ (n_i \text{ mod } s)
-\label{eq:tf_pad_7}
-\end{equation}
-
-We can use this expression to substitute \\(n_o = \left
-\lceil{\frac{n_i}{s}}\right \rceil\\) and get:
-
-$$\begin{align}
-p_i &= max\left(s\cdot \left(\frac{n_i + s - (n_i \text{ mod } s)}{s}
- - 1\right) + k - n_i, 0\right) \nonumber\\
-&= max(n_i + s - (n_i \text{ mod } s) - s + k - n_i,0) \nonumber \\
-&= max(k - (n_i \text{ mod } s),0)
-\label{eq:tf_pad_8}
-\end{align}$$
-
-### Final expression
-
-Putting it all together, the total padding used by TensorFlow's convolution in
-`'SAME'` mode is:
-
-$$\begin{align}
-p_i =
- \begin{cases}
- max(k - s, 0), & \text{if $(n_i \text{ mod } s) = 0$} \\
- max(k - (n_i \text{ mod } s),0), & \text{if $(n_i \text{ mod } s) \neq 0$}
- \end{cases}
- \label{eq:tf_pad_9}
-\end{align}$$
-
-This expression is exactly equal to the ones presented for `pad_along_height`
-and `pad_along_width` in the [Convolution section](#Convolution).
diff --git a/tensorflow/docs_src/api_guides/python/python_io.md b/tensorflow/docs_src/api_guides/python/python_io.md
deleted file mode 100644
index e7e82a8701..0000000000
--- a/tensorflow/docs_src/api_guides/python/python_io.md
+++ /dev/null
@@ -1,29 +0,0 @@
-# Data IO (Python functions)
-[TOC]
-
-A TFRecords file represents a sequence of (binary) strings. The format is not
-random access, so it is suitable for streaming large amounts of data but not
-suitable if fast sharding or other non-sequential access is desired.
-
-* `tf.python_io.TFRecordWriter`
-* `tf.python_io.tf_record_iterator`
-* `tf.python_io.TFRecordCompressionType`
-* `tf.python_io.TFRecordOptions`
-
-- - -
-
-## TFRecords Format Details
-
-A TFRecords file contains a sequence of strings with CRC32C (32-bit CRC using
-the Castagnoli polynomial) hashes. Each record has the format
-
- uint64 length
- uint32 masked_crc32_of_length
- byte data[length]
- uint32 masked_crc32_of_data
-
-and the records are concatenated together to produce the file. CRCs are
-[described here](https://en.wikipedia.org/wiki/Cyclic_redundancy_check), and
-the mask of a CRC is
-
- masked_crc = ((crc >> 15) | (crc << 17)) + 0xa282ead8ul
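-
-For example, a minimal sketch of the masking step in Python, constraining the
-rotation and the addition to 32 bits (`crc` is assumed to be the CRC32C of the
-record, computed elsewhere):
-
-```python
-def masked_crc32c(crc):
-  # Rotate right by 15 bits within 32 bits, then add the masking constant.
-  rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
-  return (rotated + 0xA282EAD8) & 0xFFFFFFFF
-```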
diff --git a/tensorflow/docs_src/api_guides/python/reading_data.md b/tensorflow/docs_src/api_guides/python/reading_data.md
deleted file mode 100644
index 78c36d965c..0000000000
--- a/tensorflow/docs_src/api_guides/python/reading_data.md
+++ /dev/null
@@ -1,522 +0,0 @@
-# Reading data
-
-Note: The preferred way to feed data into a tensorflow program is using the
-@{$datasets$`tf.data` API}.
-
-There are four methods of getting data into a TensorFlow program:
-
-* `tf.data` API: Easily construct a complex input pipeline. (preferred method)
-* Feeding: Python code provides the data when running each step.
-* `QueueRunner`: a queue-based input pipeline reads the data from files
- at the beginning of a TensorFlow graph.
-* Preloaded data: a constant or variable in the TensorFlow graph holds
- all the data (for small data sets).
-
-[TOC]
-
-## `tf.data` API
-
-See the @{$guide/datasets} for an in-depth explanation of `tf.data.Dataset`.
-The `tf.data` API enables you to extract and preprocess data
-from different input/file formats, and apply transformations such as batching,
-shuffling, and mapping functions over the dataset. This is an improved version
-of the old input methods---feeding and `QueueRunner`---which are described
-below for historical purposes.
-
-## Feeding
-
-Warning: "Feeding" is the least efficient way to feed data into a TensorFlow
-program and should only be used for small experiments and debugging.
-
-TensorFlow's feed mechanism lets you inject data into any Tensor in a
-computation graph. A Python computation can thus feed data directly into the
-graph.
-
-Supply feed data through the `feed_dict` argument to a run() or eval() call
-that initiates computation.
-
-```python
-with tf.Session():
- input = tf.placeholder(tf.float32)
- classifier = ...
- print(classifier.eval(feed_dict={input: my_python_preprocessing_fn()}))
-```
-
-While you can replace any Tensor with feed data, including variables and
-constants, the best practice is to use a
-`tf.placeholder` node. A
-`placeholder` exists solely to serve as the target of feeds. It is not
-initialized and contains no data. A placeholder generates an error if
-it is executed without a feed, so you won't forget to feed it.
-
-An example using `placeholder` and feeding to train on MNIST data can be found
-in
-[`tensorflow/examples/tutorials/mnist/fully_connected_feed.py`](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/fully_connected_feed.py).
-
-## `QueueRunner`
-
-Warning: This section discusses implementing input pipelines using the
-queue-based APIs which can be cleanly replaced by the @{$datasets$`tf.data`
-API}.
-
-A typical queue-based pipeline for reading records from files has the following stages:
-
-1. The list of filenames
-2. *Optional* filename shuffling
-3. *Optional* epoch limit
-4. Filename queue
-5. A Reader for the file format
-6. A decoder for a record read by the reader
-7. *Optional* preprocessing
-8. Example queue
-
-### Filenames, shuffling, and epoch limits
-
-For the list of filenames, use either a constant string Tensor (like
-`["file0", "file1"]` or `[("file%d" % i) for i in range(2)]`) or the
-`tf.train.match_filenames_once` function.
-
-Pass the list of filenames to the `tf.train.string_input_producer` function.
-`string_input_producer` creates a FIFO queue for holding the filenames until
-the reader needs them.
-
-`string_input_producer` has options for shuffling and setting a maximum number
-of epochs. A queue runner adds the whole list of filenames to the queue once
-for each epoch, shuffling the filenames within an epoch if `shuffle=True`.
-This procedure provides a uniform sampling of files, so that examples are not
-under- or over-sampled relative to each other.
-
-The queue runner works in a thread separate from the reader that pulls
-filenames from the queue, so the shuffling and enqueuing process does not
-block the reader.
-
-### File formats
-
-Select the reader that matches your input file format and pass the filename
-queue to the reader's read method. The read method outputs a key identifying
-the file and record (useful for debugging if you have some weird records), and
-a scalar string value. Use one (or more) of the decoder and conversion ops to
-decode this string into the tensors that make up an example.
-
-#### CSV files
-
-To read text files in [comma-separated value (CSV)
-format](https://tools.ietf.org/html/rfc4180), use a
-`tf.TextLineReader` with the
-`tf.decode_csv` operation. For example:
-
-```python
-filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])
-
-reader = tf.TextLineReader()
-key, value = reader.read(filename_queue)
-
-# Default values, in case of empty columns. Also specifies the type of the
-# decoded result.
-record_defaults = [[1], [1], [1], [1], [1]]
-col1, col2, col3, col4, col5 = tf.decode_csv(
- value, record_defaults=record_defaults)
-features = tf.stack([col1, col2, col3, col4])
-
-with tf.Session() as sess:
- # Start populating the filename queue.
- coord = tf.train.Coordinator()
- threads = tf.train.start_queue_runners(coord=coord)
-
- for i in range(1200):
- # Retrieve a single instance:
- example, label = sess.run([features, col5])
-
- coord.request_stop()
- coord.join(threads)
-```
-
-Each execution of `read` reads a single line from the file. The
-`decode_csv` op then parses the result into a list of tensors. The
-`record_defaults` argument determines the type of the resulting tensors and
-sets the default value to use if a value is missing in the input string.
-
-You must call `tf.train.start_queue_runners` to populate the queue before
-you call `run` or `eval` to execute the `read`. Otherwise `read` will
-block while it waits for filenames from the queue.
-
-#### Fixed length records
-
-To read binary files in which each record is a fixed number of bytes, use
-`tf.FixedLengthRecordReader`
-with the `tf.decode_raw` operation.
-The `decode_raw` op converts from a string to a uint8 tensor.
-
-For example, [the CIFAR-10 dataset](http://www.cs.toronto.edu/~kriz/cifar.html)
-uses a file format where each record is represented using a fixed number of
-bytes: 1 byte for the label followed by 3072 bytes of image data. Once you have
-a uint8 tensor, standard operations can slice out each piece and reformat as
-needed. For CIFAR-10, you can see how to do the reading and decoding in
-[`tensorflow_models/tutorials/image/cifar10/cifar10_input.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10_input.py)
-and described in
-@{$deep_cnn#prepare-the-data$this tutorial}.
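-
-For example, a minimal decoding sketch for a CIFAR-10-style record (1 label
-byte followed by 3072 image bytes; `filename_queue` is assumed to be built as
-in the CSV example above):
-
-```python
-record_bytes = 1 + 32 * 32 * 3
-reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
-key, value = reader.read(filename_queue)
-
-raw = tf.decode_raw(value, tf.uint8)      # 1-D uint8 tensor of length 3073
-label = tf.cast(raw[0], tf.int32)         # first byte is the label
-image = tf.reshape(raw[1:], [3, 32, 32])  # remaining bytes are the image
-```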
-
-#### Standard TensorFlow format
-
-Another approach is to convert whatever data you have into a supported format.
-This approach makes it easier to mix and match data sets and network
-architectures. The recommended format for TensorFlow is a
-@{$python/python_io#tfrecords_format_details$TFRecords file}
-containing
-[`tf.train.Example` protocol buffers](https://www.tensorflow.org/code/tensorflow/core/example/example.proto)
-(which contain
-[`Features`](https://www.tensorflow.org/code/tensorflow/core/example/feature.proto)
-as a field). You write a little program that gets your data, stuffs it in an
-`Example` protocol buffer, serializes the protocol buffer to a string, and then
-writes the string to a TFRecords file using the
-`tf.python_io.TFRecordWriter`.
-For example,
-[`tensorflow/examples/how_tos/reading_data/convert_to_records.py`](https://www.tensorflow.org/code/tensorflow/examples/how_tos/reading_data/convert_to_records.py)
-converts MNIST data to this format.
-
-The recommended way to read a TFRecord file is with a `tf.data.TFRecordDataset`, [as in this example](https://www.tensorflow.org/code/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py):
-
-``` python
- dataset = tf.data.TFRecordDataset(filename)
- dataset = dataset.repeat(num_epochs)
-
- # map takes a python function and applies it to every sample
- dataset = dataset.map(decode)
-```
-
-Accomplishing the same task with a queue-based input pipeline requires the
-following code (using the same `decode` function from the above example):
-
-``` python
- filename_queue = tf.train.string_input_producer([filename], num_epochs=num_epochs)
- reader = tf.TFRecordReader()
- _, serialized_example = reader.read(filename_queue)
- image,label = decode(serialized_example)
-```
-
-### Preprocessing
-
-You can then do any preprocessing of these examples you want. This would be any
-processing that doesn't depend on trainable parameters. Examples include
-normalization of your data, picking a random slice, adding noise or distortions,
-etc. See
-[`tensorflow_models/tutorials/image/cifar10/cifar10_input.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10_input.py)
-for an example.
-
-### Batching
-
-At the end of the pipeline we use another queue to batch together examples for
-training, evaluation, or inference. For this we use a queue that randomizes the
-order of examples, using
-`tf.train.shuffle_batch`.
-
-Example:
-
-```python
-def read_my_file_format(filename_queue):
- reader = tf.SomeReader()
- key, record_string = reader.read(filename_queue)
- example, label = tf.some_decoder(record_string)
- processed_example = some_processing(example)
- return processed_example, label
-
-def input_pipeline(filenames, batch_size, num_epochs=None):
- filename_queue = tf.train.string_input_producer(
- filenames, num_epochs=num_epochs, shuffle=True)
- example, label = read_my_file_format(filename_queue)
- # min_after_dequeue defines how big a buffer we will randomly sample
- # from -- bigger means better shuffling but slower start up and more
- # memory used.
- # capacity must be larger than min_after_dequeue and the amount larger
- # determines the maximum we will prefetch. Recommendation:
- # min_after_dequeue + (num_threads + a small safety margin) * batch_size
- min_after_dequeue = 10000
- capacity = min_after_dequeue + 3 * batch_size
- example_batch, label_batch = tf.train.shuffle_batch(
- [example, label], batch_size=batch_size, capacity=capacity,
- min_after_dequeue=min_after_dequeue)
- return example_batch, label_batch
-```
-
-If you need more parallelism or shuffling of examples between files, use
-multiple reader instances via `tf.train.shuffle_batch_join`.
-For example:
-
-```python
-def read_my_file_format(filename_queue):
-  # Same as above.
-  ...
-
-def input_pipeline(filenames, batch_size, read_threads, num_epochs=None):
- filename_queue = tf.train.string_input_producer(
- filenames, num_epochs=num_epochs, shuffle=True)
- example_list = [read_my_file_format(filename_queue)
- for _ in range(read_threads)]
- min_after_dequeue = 10000
- capacity = min_after_dequeue + 3 * batch_size
- example_batch, label_batch = tf.train.shuffle_batch_join(
- example_list, batch_size=batch_size, capacity=capacity,
- min_after_dequeue=min_after_dequeue)
- return example_batch, label_batch
-```
-
-You still only use a single filename queue that is shared by all the readers.
-That way we ensure that the different readers use different files from the same
-epoch until all the files from the epoch have been started. (It is also usually
-sufficient to have a single thread filling the filename queue.)
-
-An alternative is to use a single reader via
-`tf.train.shuffle_batch`
-with `num_threads` bigger than 1. This will make it read from a single file at
-the same time (but faster than with 1 thread), instead of N files at once.
-This can be important:
-
-* If you have more reading threads than input files, to avoid the risk that
- you will have two threads reading the same example from the same file near
- each other.
-* Or if reading N files in parallel causes too many disk seeks.
-
-How many threads do you need? The `tf.train.shuffle_batch*` functions add a
-summary to the graph that indicates how full the example queue is. If you have
-enough reading threads, that summary will stay above zero. You can
-@{$summaries_and_tensorboard$view your summaries as training progresses using TensorBoard}.
-
-### Creating threads to prefetch using `QueueRunner` objects
-
-The short version: many of the `tf.train` functions listed above add
-`tf.train.QueueRunner` objects to your
-graph. These require that you call
-`tf.train.start_queue_runners`
-before running any training or inference steps, or it will hang forever. This
-will start threads that run the input pipeline, filling the example queue so
-that the dequeue to get the examples will succeed. This is best combined with a
-`tf.train.Coordinator` to cleanly
-shut down these threads when there are errors. If you set a limit on the number
-of epochs, that will use an epoch counter that will need to be initialized. The
-recommended code pattern combining these is:
-
-```python
-# Create the graph, etc.
-init_op = tf.global_variables_initializer()
-
-# Create a session for running operations in the Graph.
-sess = tf.Session()
-
-# Initialize the variables (like the epoch counter).
-sess.run(init_op)
-
-# Start input enqueue threads.
-coord = tf.train.Coordinator()
-threads = tf.train.start_queue_runners(sess=sess, coord=coord)
-
-try:
- while not coord.should_stop():
- # Run training steps or whatever
- sess.run(train_op)
-
-except tf.errors.OutOfRangeError:
- print('Done training -- epoch limit reached')
-finally:
- # When done, ask the threads to stop.
- coord.request_stop()
-
-# Wait for threads to finish.
-coord.join(threads)
-sess.close()
-```
-
-#### Aside: What is happening here?
-
-First we create the graph. It will have a few pipeline stages that are
-connected by queues. The first stage will generate filenames to read and enqueue
-them in the filename queue. The second stage consumes filenames (using a
-`Reader`), produces examples, and enqueues them in an example queue. Depending
-on how you have set things up, you may actually have a few independent copies of
-the second stage, so that you can read from multiple files in parallel. At the
-end of these stages is an enqueue operation, which enqueues into a queue that
-the next stage dequeues from. We want to start threads running these enqueuing
-operations, so that our training loop can dequeue examples from the example
-queue.
-
-<div style="width:70%; margin-left:12%; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="https://www.tensorflow.org/images/AnimatedFileQueues.gif">
-</div>
-
-The helpers in `tf.train` that create these queues and enqueuing operations add
-a `tf.train.QueueRunner` to the
-graph using the
-`tf.train.add_queue_runner`
-function. Each `QueueRunner` is responsible for one stage, and holds the list of
-enqueue operations that need to be run in threads. Once the graph is
-constructed, the
-`tf.train.start_queue_runners`
-function asks each QueueRunner in the graph to start its threads running the
-enqueuing operations.
-
-If all goes well, you can now run your training steps and the queues will be
-filled by the background threads. If you have set an epoch limit, at some point
-an attempt to dequeue examples will get an
-`tf.errors.OutOfRangeError`. This
-is the TensorFlow equivalent of "end of file" (EOF) -- this means the epoch
-limit has been reached and no more examples are available.
-
-The last ingredient is the
-`tf.train.Coordinator`. This is responsible
-for letting all the threads know if anything has signaled a shut down. Most
-commonly this would be because an exception was raised, for example one of the
-threads got an error when running some operation (or an ordinary Python
-exception).
-
-For more about threading, queues, QueueRunners, and Coordinators
-@{$threading_and_queues$see here}.
-
-#### Aside: How clean shut-down when limiting epochs works
-
-Imagine you have a model that has set a limit on the number of epochs to train
-on. That means that the thread generating filenames will only run that many
-times before generating an `OutOfRange` error. The QueueRunner will catch that
-error, close the filename queue, and exit the thread. Closing the queue does two
-things:
-
-* Any future attempt to enqueue in the filename queue will generate an error.
- At this point there shouldn't be any threads trying to do that, but this
- is helpful when queues are closed due to other errors.
-* Any current or future dequeue will either succeed (if there are enough
- elements left) or fail (with an `OutOfRange` error) immediately. They won't
- block waiting for more elements to be enqueued, since by the previous point
- that can't happen.
-
-The point is that when the filename queue is closed, there will likely still be
-many filenames in that queue, so the next stage of the pipeline (with the reader
-and other preprocessing) may continue running for some time. Once the filename
-queue is exhausted, though, the next attempt to dequeue a filename (e.g. from a
-reader that has finished with the file it was working on) will trigger an
-`OutOfRange` error. In this case, though, you might have multiple threads
-associated with a single QueueRunner. If this isn't the last thread in the
-QueueRunner, the `OutOfRange` error just causes the one thread to exit. This
-allows the other threads, which are still finishing up their last file, to
-proceed until they finish as well. (Assuming you are using a
-`tf.train.Coordinator`,
-other types of errors will cause all the threads to stop.) Once all the reader
-threads hit the `OutOfRange` error, only then does the next queue, the example
-queue, get closed.
-
-Again, the example queue will have some elements queued, so training will
-continue until those are exhausted. If the example queue is a
-`tf.RandomShuffleQueue`, say
-because you are using `shuffle_batch` or `shuffle_batch_join`, it normally will
-avoid ever having fewer than its `min_after_dequeue` attr elements buffered.
-However, once the queue is closed that restriction will be lifted and the queue
-will eventually empty. At that point the actual training threads, when they
-try to dequeue from the example queue, will start getting `OutOfRange` errors and
-exiting. Once all the training threads are done,
-`tf.train.Coordinator.join`
-will return and you can exit cleanly.
-
-### Filtering records or producing multiple examples per record
-
-Instead of examples with shape `[x, y, z]`, you will produce a batch of
-examples with shape `[batch, x, y, z]`. The batch size can be 0 if you want to
-filter this record out (maybe it is in a hold-out set?), or bigger than 1 if you
-are producing multiple examples per record. Then simply set `enqueue_many=True`
-when calling one of the batching functions (such as `shuffle_batch` or
-`shuffle_batch_join`).
-
-### Sparse input data
-
-SparseTensors don't play well with queues. If you use SparseTensors you have
-to decode the string records using
-`tf.parse_example` **after**
-batching (instead of using `tf.parse_single_example` before batching).
-
-## Preloaded data
-
-This is only used for small data sets that can be loaded entirely in memory.
-There are two approaches:
-
-* Store the data in a constant.
-* Store the data in a variable, that you initialize (or assign to) and then
- never change.
-
-Using a constant is a bit simpler, but uses more memory (since the constant is
-stored inline in the graph data structure, which may be duplicated a few times).
-
-```python
-training_data = ...
-training_labels = ...
-with tf.Session():
- input_data = tf.constant(training_data)
- input_labels = tf.constant(training_labels)
- ...
-```
-
-To instead use a variable, you also need to initialize it after the graph has
-been built.
-
-```python
-training_data = ...
-training_labels = ...
-with tf.Session() as sess:
- data_initializer = tf.placeholder(dtype=training_data.dtype,
- shape=training_data.shape)
- label_initializer = tf.placeholder(dtype=training_labels.dtype,
- shape=training_labels.shape)
- input_data = tf.Variable(data_initializer, trainable=False, collections=[])
- input_labels = tf.Variable(label_initializer, trainable=False, collections=[])
- ...
- sess.run(input_data.initializer,
- feed_dict={data_initializer: training_data})
- sess.run(input_labels.initializer,
- feed_dict={label_initializer: training_labels})
-```
-
-Setting `trainable=False` keeps the variable out of the
-`GraphKeys.TRAINABLE_VARIABLES` collection in the graph, so we won't try and
-update it when training. Setting `collections=[]` keeps the variable out of the
-`GraphKeys.GLOBAL_VARIABLES` collection used for saving and restoring checkpoints.
-
-Either way,
-`tf.train.slice_input_producer`
-can be used to produce a slice at a time. This shuffles the examples across an
-entire epoch, so further shuffling when batching is undesirable. So instead of
-using the `shuffle_batch` functions, we use the plain
-`tf.train.batch` function. To use
-multiple preprocessing threads, set the `num_threads` parameter to a number
-bigger than 1.
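-
-A minimal sketch, reusing `input_data` and `input_labels` from the variable
-example above (the epoch count and batch size are illustrative):
-
-```python
-# Produce one (data, label) slice at a time, shuffled across the epoch.
-data_slice, label_slice = tf.train.slice_input_producer(
-    [input_data, input_labels], num_epochs=10)
-
-# Plain batching; no further shuffling is needed. The extra threads are
-# purely for preprocessing throughput.
-data_batch, label_batch = tf.train.batch(
-    [data_slice, label_slice], batch_size=128, num_threads=4)
-```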
-
-An MNIST example that preloads the data using constants can be found in
-[`tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py`](https://www.tensorflow.org/code/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py), and one that preloads the data using variables can be found in
-[`tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py`](https://www.tensorflow.org/code/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py).
-You can compare these with the `fully_connected_feed` and
-`fully_connected_reader` versions above.
-
-## Multiple input pipelines
-
-Commonly you will want to train on one dataset and evaluate (or "eval") on
-another. One way to do this is to actually have two separate graphs and
-sessions, maybe in separate processes:
-
-* The training process reads training input data and periodically writes
- checkpoint files with all the trained variables.
-* The evaluation process restores the checkpoint files into an inference
- model that reads validation input data.
-
-This is what is done in `tf.estimator` and manually in
-@{$deep_cnn#save-and-restore-checkpoints$the example CIFAR-10 model}.
-This has a couple of benefits:
-
-* The eval is performed on a single snapshot of the trained variables.
-* You can perform the eval even after training has completed and exited.
-
-You can have the train and eval in the same graph in the same process, and share
-their trained variables or layers. See @{$variables$the shared variables tutorial}.
-
-To support the single-graph approach
-@{$guide/datasets$`tf.data`} also supplies
-@{$guide/datasets#creating_an_iterator$advanced iterator types} that
-allow the user to change the input pipeline without rebuilding the graph or
-session.
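-
-As a sketch of that approach (the dataset definitions are elided), a
-reinitializable iterator lets one graph switch between input sources:
-
-```python
-# `train_dataset` and `eval_dataset` are assumed to have matching structure.
-iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
-                                           train_dataset.output_shapes)
-next_batch = iterator.get_next()
-
-train_init_op = iterator.make_initializer(train_dataset)
-eval_init_op = iterator.make_initializer(eval_dataset)
-
-# Running `train_init_op` or `eval_init_op` repoints the pipeline at the
-# corresponding dataset without rebuilding the graph or session.
-```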
-
-Note: Regardless of the implementation, many
-operations (like `tf.layers.batch_normalization`, and `tf.layers.dropout`)
-need to know if they are in training or evaluation mode, and you must be
-careful to set this appropriately if you change the data source.
diff --git a/tensorflow/docs_src/api_guides/python/regression_examples.md b/tensorflow/docs_src/api_guides/python/regression_examples.md
deleted file mode 100644
index f8abbf0f97..0000000000
--- a/tensorflow/docs_src/api_guides/python/regression_examples.md
+++ /dev/null
@@ -1,232 +0,0 @@
-# Regression Examples
-
-This unit provides the following short examples demonstrating how
-to implement regression in Estimators:
-
-<table>
- <tr> <th>Example</th> <th>Demonstrates How To...</th></tr>
-
- <tr>
- <td><a href="https://www.tensorflow.org/code/tensorflow/examples/get_started/regression/linear_regression.py">linear_regression.py</a></td>
- <td>Use the `tf.estimator.LinearRegressor` Estimator to train a
- regression model on numeric data.</td>
- </tr>
-
- <tr>
- <td><a href="https://www.tensorflow.org/code/tensorflow/examples/get_started/regression/linear_regression_categorical.py">linear_regression_categorical.py</a></td>
- <td>Use the `tf.estimator.LinearRegressor` Estimator to train a
- regression model on categorical data.</td>
- </tr>
-
- <tr>
- <td><a href="https://www.tensorflow.org/code/tensorflow/examples/get_started/regression/dnn_regression.py">dnn_regression.py</a></td>
- <td>Use the `tf.estimator.DNNRegressor` Estimator to train a
- regression model on discrete data with a deep neural network.</td>
- </tr>
-
- <tr>
- <td><a href="https://www.tensorflow.org/code/tensorflow/examples/get_started/regression/custom_regression.py">custom_regression.py</a></td>
- <td>Use `tf.estimator.Estimator` to train a customized dnn
- regression model.</td>
- </tr>
-
-</table>
-
-The preceding examples rely on the following data set utility:
-
-<table>
- <tr> <th>Utility</th> <th>Description</th></tr>
-
- <tr>
- <td><a href="https://www.tensorflow.org/code/tensorflow/examples/get_started/regression/imports85.py">imports85.py</a></td>
- <td>This program provides utility functions that load the
- <tt>imports85</tt> data set into formats that other TensorFlow
- programs (for example, <tt>linear_regression.py</tt> and
- <tt>dnn_regression.py</tt>) can use.</td>
- </tr>
-
-
-</table>
-
-
-<!--
-## Linear regression concepts
-
-If you are new to machine learning and want to learn about regression,
-watch the following video:
-
-(todo:jbgordon) Video introduction goes here.
--->
-
-<!--
-[When MLCC becomes available externally, add links to the relevant MLCC units.]
--->
-
-
-<a name="running"></a>
-## Running the examples
-
-You must @{$install$install TensorFlow} prior to running these examples.
-Depending on the way you've installed TensorFlow, you might also
-need to activate your TensorFlow environment. Then, do the following:
-
-1. Clone the TensorFlow repository from github.
-2. `cd` to the top of the downloaded tree.
-3. Check out the branch for your current TensorFlow version: `git checkout rX.X`
-4. `cd tensorflow/examples/get_started/regression`.
-
-You can now run any of the example TensorFlow programs in the
-`tensorflow/examples/get_started/regression` directory as you
-would run any Python program:
-
-```bash
-python linear_regression.py
-```
-
-During training, all three programs output the following information:
-
-* The name of the checkpoint directory, which is important for TensorBoard.
-* The training loss after every 100 iterations, which helps you
- determine whether the model is converging.
-
-For example, here's some possible output for the `linear_regression.py`
-program:
-
-```none
-INFO:tensorflow:Saving checkpoints for 1 into /tmp/tmpAObiz9/model.ckpt.
-INFO:tensorflow:loss = 161.308, step = 1
-INFO:tensorflow:global_step/sec: 1557.24
-INFO:tensorflow:loss = 15.7937, step = 101 (0.065 sec)
-INFO:tensorflow:global_step/sec: 1529.17
-INFO:tensorflow:loss = 12.1988, step = 201 (0.065 sec)
-INFO:tensorflow:global_step/sec: 1663.86
-...
-INFO:tensorflow:loss = 6.99378, step = 901 (0.058 sec)
-INFO:tensorflow:Saving checkpoints for 1000 into /tmp/tmpAObiz9/model.ckpt.
-INFO:tensorflow:Loss for final step: 5.12413.
-```
-
-
-<a name="basic"></a>
-## linear_regression.py
-
-`linear_regression.py` trains a model that predicts the price of a car from
-two numerical features.
-
-<table>
- <tr>
- <td>Estimator</td>
- <td><tt>LinearRegressor</tt>, which is a pre-made Estimator for linear
- regression.</td>
- </tr>
-
- <tr>
- <td>Features</td>
-  <td>Numerical: <tt>curb-weight</tt> and <tt>highway-mpg</tt>.</td>
- </tr>
-
- <tr>
- <td>Label</td>
-  <td>Numerical: <tt>price</tt>.</td>
- </tr>
-
- <tr>
- <td>Algorithm</td>
- <td>Linear regression.</td>
- </tr>
-</table>
-
-After training the model, the program concludes by outputting predicted
-car prices for two car models.
-
-
-
-<a name="categorical"></a>
-## linear_regression_categorical.py
-
-This program illustrates ways to represent categorical features. It
-also demonstrates how to train a linear model based on a mix of
-categorical and numerical features.
-
-<table>
- <tr>
- <td>Estimator</td>
- <td><tt>LinearRegressor</tt>, which is a pre-made Estimator for linear
- regression. </td>
- </tr>
-
- <tr>
- <td>Features</td>
-  <td>Categorical: <tt>body-style</tt> and <tt>make</tt>.<br/>
-  Numerical: <tt>curb-weight</tt> and <tt>highway-mpg</tt>.</td>
- </tr>
-
- <tr>
- <td>Label</td>
- <td>Numerical: <tt>price</tt>.</td>
- </tr>
-
- <tr>
- <td>Algorithm</td>
- <td>Linear regression.</td>
- </tr>
-</table>
-
-
-<a name="dnn"></a>
-## dnn_regression.py
-
-Like `linear_regression_categorical.py`, the `dnn_regression.py` example
-trains a model that predicts the price of a car from two features.
-Unlike `linear_regression_categorical.py`, the `dnn_regression.py` example uses
-a deep neural network to train the model. Both examples rely on the same
-features; `dnn_regression.py` demonstrates how to treat categorical features
-in a deep neural network.
-
-<table>
- <tr>
- <td>Estimator</td>
- <td><tt>DNNRegressor</tt>, which is a pre-made Estimator for
- regression that relies on a deep neural network. The
- `hidden_units` parameter defines the topography of the network.</td>
- </tr>
-
- <tr>
- <td>Features</td>
-  <td>Categorical: <tt>body-style</tt> and <tt>make</tt>.<br/>
-  Numerical: <tt>curb-weight</tt> and <tt>highway-mpg</tt>.</td>
- </tr>
-
- <tr>
- <td>Label</td>
- <td>Numerical: <tt>price</tt>.</td>
- </tr>
-
- <tr>
- <td>Algorithm</td>
- <td>Regression through a deep neural network.</td>
- </tr>
-</table>
-
-After printing loss values, the program outputs the mean squared error
-on a test set.
-
-
-<a name="dnn"></a>
-## custom_regression.py
-
-The `custom_regression.py` example also trains a model that predicts the price
-of a car based on mixed real-valued and categorical input features, described
-by `feature_columns`. Unlike `linear_regression_categorical.py` and
-`dnn_regression.py`, this example does not use a pre-made Estimator, but
-defines a custom model using the base `tf.estimator.Estimator` class. The
-custom model is quite similar to the model defined by `dnn_regression.py`.
-
-The custom model is defined by the `model_fn` argument to the constructor. The
-customization is made more reusable through the `params` dictionary, which is
-later passed to the `model_fn` each time it is called.
-
-The `model_fn` returns a
-`tf.estimator.EstimatorSpec`, which is a simple structure
-indicating to the `Estimator` which operations should be run to accomplish
-various tasks.
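-
-The example's own `model_fn` is more elaborate, but a minimal sketch of the
-general shape (the layer sizes and parameter names here are illustrative, not
-taken from the example) looks like this:
-
-```python
-def model_fn(features, labels, mode, params):
-  # Build the network from the feature columns passed in via `params`.
-  net = tf.feature_column.input_layer(features, params['feature_columns'])
-  net = tf.layers.dense(net, units=20, activation=tf.nn.relu)
-  predictions = tf.squeeze(tf.layers.dense(net, units=1), axis=-1)
-
-  if mode == tf.estimator.ModeKeys.PREDICT:
-    return tf.estimator.EstimatorSpec(mode, predictions={'price': predictions})
-
-  loss = tf.losses.mean_squared_error(labels, predictions)
-  optimizer = tf.train.AdamOptimizer(params.get('learning_rate', 0.001))
-  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
-  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
-
-estimator = tf.estimator.Estimator(
-    model_fn=model_fn, params={'feature_columns': feature_columns})
-```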
diff --git a/tensorflow/docs_src/api_guides/python/session_ops.md b/tensorflow/docs_src/api_guides/python/session_ops.md
deleted file mode 100644
index 5f41bcf209..0000000000
--- a/tensorflow/docs_src/api_guides/python/session_ops.md
+++ /dev/null
@@ -1,15 +0,0 @@
-# Tensor Handle Operations
-
-Note: Functions taking `Tensor` arguments can also take anything accepted by
-`tf.convert_to_tensor`.
-
-[TOC]
-
-## Tensor Handle Operations
-
-TensorFlow provides several operators that allow the user to keep tensors
-"in-place" across run calls.
-
-* `tf.get_session_handle`
-* `tf.get_session_tensor`
-* `tf.delete_session_tensor`
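-
-A minimal sketch of the round trip, following the pattern from the
-`tf.get_session_handle` documentation:
-
-```python
-a = tf.multiply(tf.constant(2.0), tf.constant(10.0))
-h = tf.get_session_handle(a)
-
-with tf.Session() as sess:
-  # Running the handle op stores the value inside the session.
-  handle = sess.run(h)
-
-  # Later, feed the handle back in to keep computing with the stored value.
-  p, x = tf.get_session_tensor(handle.handle, tf.float32)
-  y = tf.multiply(x, 2.0)
-  print(sess.run(y, feed_dict={p: handle.handle}))  # ==> 40.0
-```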
diff --git a/tensorflow/docs_src/api_guides/python/sparse_ops.md b/tensorflow/docs_src/api_guides/python/sparse_ops.md
deleted file mode 100644
index b360055ed0..0000000000
--- a/tensorflow/docs_src/api_guides/python/sparse_ops.md
+++ /dev/null
@@ -1,45 +0,0 @@
-# Sparse Tensors
-
-Note: Functions taking `Tensor` arguments can also take anything accepted by
-`tf.convert_to_tensor`.
-
-[TOC]
-
-## Sparse Tensor Representation
-
-TensorFlow supports a `SparseTensor` representation for data that is sparse
-in multiple dimensions. Contrast this representation with `IndexedSlices`,
-which is efficient for representing tensors that are sparse in their first
-dimension, and dense along all other dimensions.
-
-* `tf.SparseTensor`
-* `tf.SparseTensorValue`
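-
-For example, a 3x4 matrix with only two nonzero entries can be written as:
-
-```python
-st = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2],
-                     dense_shape=[3, 4])
-dense = tf.sparse_tensor_to_dense(st)
-with tf.Session() as sess:
-  print(sess.run(dense))
-  # [[1 0 0 0]
-  #  [0 0 2 0]
-  #  [0 0 0 0]]
-```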
-
-## Conversion
-
-* `tf.sparse_to_dense`
-* `tf.sparse_tensor_to_dense`
-* `tf.sparse_to_indicator`
-* `tf.sparse_merge`
-
-## Manipulation
-
-* `tf.sparse_concat`
-* `tf.sparse_reorder`
-* `tf.sparse_reshape`
-* `tf.sparse_split`
-* `tf.sparse_retain`
-* `tf.sparse_reset_shape`
-* `tf.sparse_fill_empty_rows`
-* `tf.sparse_transpose`
-
-## Reduction
-* `tf.sparse_reduce_sum`
-* `tf.sparse_reduce_sum_sparse`
-
-## Math Operations
-* `tf.sparse_add`
-* `tf.sparse_softmax`
-* `tf.sparse_tensor_dense_matmul`
-* `tf.sparse_maximum`
-* `tf.sparse_minimum`
diff --git a/tensorflow/docs_src/api_guides/python/spectral_ops.md b/tensorflow/docs_src/api_guides/python/spectral_ops.md
deleted file mode 100644
index f6d109a3a0..0000000000
--- a/tensorflow/docs_src/api_guides/python/spectral_ops.md
+++ /dev/null
@@ -1,26 +0,0 @@
-# Spectral Functions
-
-[TOC]
-
-The `tf.spectral` module supports several spectral decomposition operations
-that you can use to transform Tensors of real and complex signals.
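-
-A minimal sketch of a real-to-complex transform and its inverse:
-
-```python
-signal = tf.constant([1.0, 2.0, 3.0, 4.0])
-spectrum = tf.spectral.rfft(signal)      # complex64 tensor of length 3
-recovered = tf.spectral.irfft(spectrum)  # approximately the original signal
-```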
-
-## Discrete Fourier Transforms
-
-* `tf.spectral.fft`
-* `tf.spectral.ifft`
-* `tf.spectral.fft2d`
-* `tf.spectral.ifft2d`
-* `tf.spectral.fft3d`
-* `tf.spectral.ifft3d`
-* `tf.spectral.rfft`
-* `tf.spectral.irfft`
-* `tf.spectral.rfft2d`
-* `tf.spectral.irfft2d`
-* `tf.spectral.rfft3d`
-* `tf.spectral.irfft3d`
-
-## Discrete Cosine Transforms
-
-* `tf.spectral.dct`
-* `tf.spectral.idct`
diff --git a/tensorflow/docs_src/api_guides/python/state_ops.md b/tensorflow/docs_src/api_guides/python/state_ops.md
deleted file mode 100644
index fc55ea1481..0000000000
--- a/tensorflow/docs_src/api_guides/python/state_ops.md
+++ /dev/null
@@ -1,110 +0,0 @@
-# Variables
-
-Note: Functions taking `Tensor` arguments can also take anything accepted by
-`tf.convert_to_tensor`.
-
-[TOC]
-
-## Variables
-
-* `tf.Variable`
-
-## Variable helper functions
-
-TensorFlow provides a set of functions to help manage the set of variables
-collected in the graph.
-
-* `tf.global_variables`
-* `tf.local_variables`
-* `tf.model_variables`
-* `tf.trainable_variables`
-* `tf.moving_average_variables`
-* `tf.global_variables_initializer`
-* `tf.local_variables_initializer`
-* `tf.variables_initializer`
-* `tf.is_variable_initialized`
-* `tf.report_uninitialized_variables`
-* `tf.assert_variables_initialized`
-* `tf.assign`
-* `tf.assign_add`
-* `tf.assign_sub`
-
-## Saving and Restoring Variables
-
-* `tf.train.Saver`
-* `tf.train.latest_checkpoint`
-* `tf.train.get_checkpoint_state`
-* `tf.train.update_checkpoint_state`
-
-## Sharing Variables
-
-TensorFlow provides several classes and operations that you can use to
-create variables contingent on certain conditions.
-
-* `tf.get_variable`
-* `tf.get_local_variable`
-* `tf.VariableScope`
-* `tf.variable_scope`
-* `tf.variable_op_scope`
-* `tf.get_variable_scope`
-* `tf.make_template`
-* `tf.no_regularizer`
-* `tf.constant_initializer`
-* `tf.random_normal_initializer`
-* `tf.truncated_normal_initializer`
-* `tf.random_uniform_initializer`
-* `tf.uniform_unit_scaling_initializer`
-* `tf.zeros_initializer`
-* `tf.ones_initializer`
-* `tf.orthogonal_initializer`
-
-## Variable Partitioners for Sharding
-
-* `tf.fixed_size_partitioner`
-* `tf.variable_axis_size_partitioner`
-* `tf.min_max_variable_partitioner`
-
-## Sparse Variable Updates
-
-The sparse update ops modify a subset of the entries in a dense `Variable`,
-either overwriting the entries or adding / subtracting a delta. These are
-useful for training embedding models and similar lookup-based networks, since
-only a small subset of embedding vectors change in any given step.
-
-Since a sparse update of a large tensor may be generated automatically during
-gradient computation (as in the gradient of
-`tf.gather`),
-a `tf.IndexedSlices` class is provided that encapsulates a set
-of sparse indices and values. `IndexedSlices` objects are detected and handled
-automatically by the optimizers in most cases. (A minimal sketch of a direct
-sparse update follows the list below.)
-
-* `tf.scatter_update`
-* `tf.scatter_add`
-* `tf.scatter_sub`
-* `tf.scatter_mul`
-* `tf.scatter_div`
-* `tf.scatter_min`
-* `tf.scatter_max`
-* `tf.scatter_nd_update`
-* `tf.scatter_nd_add`
-* `tf.scatter_nd_sub`
-* `tf.sparse_mask`
-* `tf.IndexedSlices`
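-
-A minimal sketch of a direct sparse update:
-
-```python
-v = tf.Variable([0.0, 0.0, 0.0, 0.0, 0.0])
-# Overwrite entries 1 and 3, leaving the rest untouched.
-update = tf.scatter_update(v, indices=[1, 3], updates=[10.0, 20.0])
-
-with tf.Session() as sess:
-  sess.run(tf.global_variables_initializer())
-  print(sess.run(update))  # ==> [ 0. 10.  0. 20.  0.]
-```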
-
-### Read-only Lookup Tables
-
-* `tf.initialize_all_tables`
-* `tf.tables_initializer`
-
-
-## Exporting and Importing Meta Graphs
-
-* `tf.train.export_meta_graph`
-* `tf.train.import_meta_graph`
-
-## Deprecated functions (removed after 2017-03-02). Please don't use them.
-
-* `tf.all_variables`
-* `tf.initialize_all_variables`
-* `tf.initialize_local_variables`
-* `tf.initialize_variables`
diff --git a/tensorflow/docs_src/api_guides/python/string_ops.md b/tensorflow/docs_src/api_guides/python/string_ops.md
deleted file mode 100644
index 24a3aad642..0000000000
--- a/tensorflow/docs_src/api_guides/python/string_ops.md
+++ /dev/null
@@ -1,39 +0,0 @@
-# Strings
-
-Note: Functions taking `Tensor` arguments can also take anything accepted by
-`tf.convert_to_tensor`.
-
-[TOC]
-
-## Hashing
-
-String hashing ops take a string input tensor and map each element to an
-integer.
-
-* `tf.string_to_hash_bucket_fast`
-* `tf.string_to_hash_bucket_strong`
-* `tf.string_to_hash_bucket`
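-
-For example:
-
-```python
-words = tf.constant(["hello", "tensorflow"])
-# Deterministically map each string to one of 100 buckets.
-buckets = tf.string_to_hash_bucket_fast(words, num_buckets=100)
-```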
-
-## Joining
-
-String joining ops concatenate elements of input string tensors to produce a new
-string tensor.
-
-* `tf.reduce_join`
-* `tf.string_join`
-
-## Splitting
-
-* `tf.string_split`
-* `tf.substr`
-
-## Conversion
-
-* `tf.as_string`
-* `tf.string_to_number`
-
-* `tf.decode_raw`
-* `tf.decode_csv`
-
-* `tf.encode_base64`
-* `tf.decode_base64`
diff --git a/tensorflow/docs_src/api_guides/python/summary.md b/tensorflow/docs_src/api_guides/python/summary.md
deleted file mode 100644
index e290703b7d..0000000000
--- a/tensorflow/docs_src/api_guides/python/summary.md
+++ /dev/null
@@ -1,23 +0,0 @@
-# Summary Operations
-[TOC]
-
-Summaries provide a way to export condensed information about a model, which is
-then accessible in tools such as @{$summaries_and_tensorboard$TensorBoard}.
-
-## Generation of Summaries
-
-### Class for writing Summaries
-* `tf.summary.FileWriter`
-* `tf.summary.FileWriterCache`
-
-### Summary Ops
-* `tf.summary.tensor_summary`
-* `tf.summary.scalar`
-* `tf.summary.histogram`
-* `tf.summary.audio`
-* `tf.summary.image`
-* `tf.summary.merge`
-* `tf.summary.merge_all`
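-
-As a sketch of how these pieces typically fit together (assuming an existing
-`loss` tensor, `train_op`, session `sess`, and step counter `step`):
-
-```python
-tf.summary.scalar('loss', loss)
-merged = tf.summary.merge_all()
-writer = tf.summary.FileWriter('/tmp/train_logs', sess.graph)
-
-summary, _ = sess.run([merged, train_op])
-writer.add_summary(summary, global_step=step)
-```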
-
-## Utilities
-* `tf.summary.get_summary_description`
diff --git a/tensorflow/docs_src/api_guides/python/test.md b/tensorflow/docs_src/api_guides/python/test.md
deleted file mode 100644
index b6e0a332b9..0000000000
--- a/tensorflow/docs_src/api_guides/python/test.md
+++ /dev/null
@@ -1,47 +0,0 @@
-# Testing
-[TOC]
-
-## Unit tests
-
-TensorFlow provides a convenience class inheriting from `unittest.TestCase`
-which adds methods relevant to TensorFlow tests. Here is an example:
-
-```python
- import tensorflow as tf
-
-
- class SquareTest(tf.test.TestCase):
-
- def testSquare(self):
- with self.test_session():
- x = tf.square([2, 3])
- self.assertAllEqual(x.eval(), [4, 9])
-
-
- if __name__ == '__main__':
- tf.test.main()
-```
-
-`tf.test.TestCase` inherits from `unittest.TestCase` but adds a few additional
-methods. See `tf.test.TestCase` for details.
-
-* `tf.test.main`
-* `tf.test.TestCase`
-* `tf.test.test_src_dir_path`
-
-## Utilities
-
-Note: `tf.test.mock` is an alias for the Python `mock` or `unittest.mock`
-module, depending on the Python version.
-
-* `tf.test.assert_equal_graph_def`
-* `tf.test.get_temp_dir`
-* `tf.test.is_built_with_cuda`
-* `tf.test.is_gpu_available`
-* `tf.test.gpu_device_name`
-
-## Gradient checking
-
-`tf.test.compute_gradient` and `tf.test.compute_gradient_error` perform
-numerical differentiation of graphs for comparison against registered analytic
-gradients.
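-
-A minimal sketch:
-
-```python
-with tf.Session():
-  x = tf.constant([[1.0, 2.0]])
-  y = tf.square(x)
-  # Compare the numeric and analytic Jacobians; a small error means agreement.
-  error = tf.test.compute_gradient_error(x, [1, 2], y, [1, 2])
-  assert error < 1e-3
-```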
diff --git a/tensorflow/docs_src/api_guides/python/tfdbg.md b/tensorflow/docs_src/api_guides/python/tfdbg.md
deleted file mode 100644
index 9778cdc0b0..0000000000
--- a/tensorflow/docs_src/api_guides/python/tfdbg.md
+++ /dev/null
@@ -1,50 +0,0 @@
-# TensorFlow Debugger
-[TOC]
-
-Public Python API of TensorFlow Debugger (tfdbg).
-
-## Functions for adding debug watches
-
-These functions help you modify `RunOptions` to specify which `Tensor`s are to
-be watched when the TensorFlow graph is executed at runtime.
-
-* `tfdbg.add_debug_tensor_watch`
-* `tfdbg.watch_graph`
-* `tfdbg.watch_graph_with_blacklists`
-
-
-## Classes for debug-dump data and directories
-
-These classes allow you to load and inspect tensor values dumped from
-TensorFlow graphs during runtime.
-
-* `tfdbg.DebugTensorDatum`
-* `tfdbg.DebugDumpDir`
-
-
-## Functions for loading debug-dump data
-
-* `tfdbg.load_tensor_from_event_file`
-
-
-## Tensor-value predicates
-
-Built-in tensor-filter predicates to support conditional breakpoints between
-runs. See `DebugDumpDir.find()` for more details.
-
-* `tfdbg.has_inf_or_nan`
-
-
-## Session wrapper class and `SessionRunHook` implementations
-
-These classes allow you to
-
-* wrap around TensorFlow `Session` objects to debug plain TensorFlow models
- (see `DumpingDebugWrapperSession` and `LocalCLIDebugWrapperSession`), or
-* generate `SessionRunHook` objects to debug `tf.contrib.learn` models (see
- `DumpingDebugHook` and `LocalCLIDebugHook`).
-
-* `tfdbg.DumpingDebugHook`
-* `tfdbg.DumpingDebugWrapperSession`
-* `tfdbg.LocalCLIDebugHook`
-* `tfdbg.LocalCLIDebugWrapperSession`
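-
-A minimal sketch of wrapping a session for the local CLI debugger:
-
-```python
-from tensorflow.python import debug as tf_debug
-
-sess = tf.Session()
-sess = tf_debug.LocalCLIDebugWrapperSession(sess)
-# Register the built-in inf/nan filter as a conditional breakpoint.
-sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
-```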
diff --git a/tensorflow/docs_src/api_guides/python/threading_and_queues.md b/tensorflow/docs_src/api_guides/python/threading_and_queues.md
deleted file mode 100644
index 48f0778b73..0000000000
--- a/tensorflow/docs_src/api_guides/python/threading_and_queues.md
+++ /dev/null
@@ -1,270 +0,0 @@
-# Threading and Queues
-
-Note: In versions of TensorFlow before 1.2, we recommended using multi-threaded,
-queue-based input pipelines for performance. Beginning with TensorFlow 1.4,
-however, we recommend using the `tf.data` module instead. (See
-@{$datasets$Datasets} for details. In TensorFlow 1.2 and 1.3, the module was
-called `tf.contrib.data`.) The `tf.data` module offers an easier-to-use
-interface for constructing efficient input pipelines. Furthermore, we've stopped
-developing the old multi-threaded, queue-based input pipelines. We've retained
-the documentation in this file to help developers who are still maintaining
-older code.
-
-Multithreaded queues are a powerful and widely used mechanism supporting
-asynchronous computation.
-
-Following the [dataflow programming model](graphs.md), TensorFlow's queues are
-implemented using nodes in the computation graph. A queue is a stateful node,
-like a variable: other nodes can modify its content. In particular, nodes can
-enqueue new items into the queue, or dequeue existing items from the
-queue. TensorFlow's queues provide a way to coordinate multiple steps of a
-computation: a queue will **block** any step that attempts to dequeue from it
-when it is empty, or enqueue to it when it is full. When that condition no
-longer holds, the queue will unblock the step and allow execution to proceed.
-
-TensorFlow implements several classes of queue. The principal difference between
-these classes is the order that items are removed from the queue. To get a feel
-for queues, let's consider a simple example. We will create a "first in, first
-out" queue (`tf.FIFOQueue`) and fill it with zeros. Then we'll construct a
-graph that takes an item off the queue, adds one to that item, and puts it back
-on the end of the queue. Slowly, the numbers on the queue increase.
-
-<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="https://www.tensorflow.org/images/IncremeterFifoQueue.gif">
-</div>
-
-`Enqueue`, `EnqueueMany`, and `Dequeue` are special nodes. They take a pointer
-to the queue instead of a normal value, allowing them to mutate its state. We
-recommend that you think of these operations as being like methods of the queue
-in an object-oriented sense. In fact, in the Python API, these operations are
-created by calling methods on a queue object (e.g. `q.enqueue(...)`).
-
-Note: Queue methods (such as `q.enqueue(...)`) *must* run on the same device
-as the queue. Incompatible device placement directives will be ignored when
-creating these operations.
-
-Now that you have a bit of a feel for queues, let's dive into the details...
-
-## Queue usage overview
-
-Queues, such as `tf.FIFOQueue`
-and `tf.RandomShuffleQueue`,
-are important TensorFlow objects that aid in computing tensors asynchronously
-in a graph.
-
-For example, a typical queue-based input pipeline uses a `RandomShuffleQueue` to
-prepare inputs for training a model as follows:
-
-* Multiple threads prepare training examples and enqueue them.
-* A training thread executes a training op that dequeues mini-batches from the
-  queue.
-
-We recommend using the `tf.data.Dataset.shuffle`
-and `tf.data.Dataset.batch` methods of a
-`tf.data.Dataset` to accomplish this. However, if you'd prefer
-to use a queue-based version instead, you can find a full implementation in the
-`tf.train.shuffle_batch` function.
-
-For demonstration purposes, a simplified implementation is given below.
-
-This function takes a source tensor, a capacity, and a batch size as arguments
-and returns a tensor that dequeues a shuffled batch when executed.
-
-``` python
-def simple_shuffle_batch(source, capacity, batch_size=10):
- # Create a random shuffle queue.
- queue = tf.RandomShuffleQueue(capacity=capacity,
- min_after_dequeue=int(0.9*capacity),
- shapes=source.shape, dtypes=source.dtype)
-
- # Create an op to enqueue one item.
- enqueue = queue.enqueue(source)
-
- # Create a queue runner that, when started, will launch 4 threads applying
- # that enqueue op.
- num_threads = 4
- qr = tf.train.QueueRunner(queue, [enqueue] * num_threads)
-
- # Register the queue runner so it can be found and started by
- # `tf.train.start_queue_runners` later (the threads are not launched yet).
- tf.train.add_queue_runner(qr)
-
- # Create an op to dequeue a batch
- return queue.dequeue_many(batch_size)
-```
-
-Once started by `tf.train.start_queue_runners`, or indirectly through
-`tf.train.MonitoredSession`, the `QueueRunner` will launch the
-threads in the background to fill the queue. Meanwhile the main thread will
-execute the `dequeue_many` op to pull data from it. Note how these ops do not
-depend on each other, except indirectly through the internal state of the queue.
-
-The simplest possible use of this function might be something like this:
-
-``` python
-# create a dataset that counts from 0 to 99
-input = tf.constant(list(range(100)))
-input = tf.data.Dataset.from_tensor_slices(input)
-input = input.make_one_shot_iterator().get_next()
-
-# Create a slightly shuffled batch from the sorted elements
-get_batch = simple_shuffle_batch(input, capacity=20)
-
-# `MonitoredSession` will start and manage the `QueueRunner` threads.
-with tf.train.MonitoredSession() as sess:
- # Since the `QueueRunners` have been started, data is available in the
- # queue, so the `sess.run(get_batch)` call will not hang.
- while not sess.should_stop():
- print(sess.run(get_batch))
-```
-
-```
-[ 8 10 7 5 4 13 15 14 25 0]
-[23 29 28 31 33 18 19 11 34 27]
-[12 21 37 39 35 22 44 36 20 46]
-...
-```
-
-For most use cases, the automatic thread startup and management provided
-by `tf.train.MonitoredSession` is sufficient. In the rare case that it is not,
-TensorFlow provides tools for manually managing your threads and queues.
-
-## Manual Thread Management
-
-As we have seen, the TensorFlow `Session` object is multithreaded and
-thread-safe, so multiple threads can
-easily use the same session and run ops in parallel. However, it is not always
-easy to implement a Python program that drives threads as required. All
-threads must be able to stop together, exceptions must be caught and
-reported, and queues must be properly closed when stopping.
-
-TensorFlow provides two classes to help:
-`tf.train.Coordinator` and
-`tf.train.QueueRunner`. These two classes
-are designed to be used together. The `Coordinator` class helps multiple threads
-stop together and report exceptions to a program that waits for them to stop.
-The `QueueRunner` class is used to create a number of threads cooperating to
-enqueue tensors in the same queue.
-
-### Coordinator
-
-The `tf.train.Coordinator` class manages background threads in a TensorFlow
-program and helps multiple threads stop together.
-
-Its key methods are:
-
-* `tf.train.Coordinator.should_stop`: returns `True` if the threads should stop.
-* `tf.train.Coordinator.request_stop`: requests that threads should stop.
-* `tf.train.Coordinator.join`: waits until the specified threads have stopped.
-
-You first create a `Coordinator` object, and then create a number of threads
-that use the coordinator. The threads typically run loops that stop when
-`should_stop()` returns `True`.
-
-Any thread can decide that the computation should stop. It only has to call
-`request_stop()` and the other threads will stop as `should_stop()` will then
-return `True`.
-
-```python
-# Using Python's threading library.
-import threading
-
-# Thread body: loop until the coordinator indicates a stop was requested.
-# If some condition becomes true, ask the coordinator to stop.
-def MyLoop(coord):
- while not coord.should_stop():
- ...do something...
- if ...some condition...:
- coord.request_stop()
-
-# Main thread: create a coordinator.
-coord = tf.train.Coordinator()
-
-# Create 10 threads that run 'MyLoop()'
-threads = [threading.Thread(target=MyLoop, args=(coord,)) for i in xrange(10)]
-
-# Start the threads and wait for all of them to stop.
-for t in threads:
- t.start()
-coord.join(threads)
-```
-
-Obviously, the coordinator can manage threads doing very different things.
-They don't all have to be the same as in the example above. The coordinator
-can also capture and report exceptions. See the `tf.train.Coordinator`
-documentation for more details.
-
-### QueueRunner
-
-The `tf.train.QueueRunner` class creates a number of threads that repeatedly
-run an enqueue op. These threads can use a coordinator to stop together. In
-addition, a queue runner will run a *closer operation* that closes the queue if
-an exception is reported to the coordinator.
-
-You can use a queue runner to implement the architecture described above.
-
-First build a graph that uses a TensorFlow queue (e.g. a `tf.RandomShuffleQueue`) for input examples. Add ops that
-process examples and enqueue them in the queue. Add training ops that start by
-dequeueing from the queue.
-
-```python
-example = ...ops to create one example...
-# Create a queue, and an op that enqueues examples one at a time in the queue.
-queue = tf.RandomShuffleQueue(...)
-enqueue_op = queue.enqueue(example)
-# Create a training graph that starts by dequeueing a batch of examples.
-inputs = queue.dequeue_many(batch_size)
-train_op = ...use 'inputs' to build the training part of the graph...
-```
-
-In the Python training program, create a `QueueRunner` that will run a few
-threads to process and enqueue examples. Create a `Coordinator` and ask the
-queue runner to start its threads with the coordinator. Write a training loop
-that also uses the coordinator.
-
-```python
-# Create a queue runner that will run 4 threads in parallel to enqueue
-# examples.
-qr = tf.train.QueueRunner(queue, [enqueue_op] * 4)
-
-# Launch the graph.
-sess = tf.Session()
-# Create a coordinator, launch the queue runner threads.
-coord = tf.train.Coordinator()
-enqueue_threads = qr.create_threads(sess, coord=coord, start=True)
-# Run the training loop, controlling termination with the coordinator.
-for step in xrange(1000000):
- if coord.should_stop():
- break
- sess.run(train_op)
-# When done, ask the threads to stop.
-coord.request_stop()
-# And wait for them to actually do it.
-coord.join(enqueue_threads)
-```
-
-### Handling exceptions
-
-Threads started by queue runners do more than just run the enqueue ops. They
-also catch and handle exceptions generated by queues, including the
-`tf.errors.OutOfRangeError` exception, which is used to report that a queue was
-closed.
-
-A training program that uses a coordinator must similarly catch and report
-exceptions in its main loop.
-
-Here is an improved version of the training loop above.
-
-```python
-try:
- for step in xrange(1000000):
- if coord.should_stop():
- break
- sess.run(train_op)
-except Exception as e:
- # Report exceptions to the coordinator.
- coord.request_stop(e)
-finally:
- # Terminate as usual. It is safe to call `coord.request_stop()` twice.
- coord.request_stop()
- coord.join(threads)
-```
diff --git a/tensorflow/docs_src/api_guides/python/train.md b/tensorflow/docs_src/api_guides/python/train.md
deleted file mode 100644
index a118123665..0000000000
--- a/tensorflow/docs_src/api_guides/python/train.md
+++ /dev/null
@@ -1,139 +0,0 @@
-# Training
-[TOC]
-
-`tf.train` provides a set of classes and functions that help train models.
-
-## Optimizers
-
-The Optimizer base class provides methods to compute gradients for a loss and
-apply gradients to variables. A collection of subclasses implement classic
-optimization algorithms such as GradientDescent and Adagrad.
-
-You never instantiate the Optimizer class itself, but instead instantiate one
-of the subclasses.
-
-* `tf.train.Optimizer`
-* `tf.train.GradientDescentOptimizer`
-* `tf.train.AdadeltaOptimizer`
-* `tf.train.AdagradOptimizer`
-* `tf.train.AdagradDAOptimizer`
-* `tf.train.MomentumOptimizer`
-* `tf.train.AdamOptimizer`
-* `tf.train.FtrlOptimizer`
-* `tf.train.ProximalGradientDescentOptimizer`
-* `tf.train.ProximalAdagradOptimizer`
-* `tf.train.RMSPropOptimizer`
-
-See `tf.contrib.opt` for more optimizers.
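-
-A minimal sketch of the usual pattern (assuming an existing `loss` tensor):
-
-```python
-global_step = tf.train.get_or_create_global_step()
-optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
-# `minimize` both computes gradients and applies them to the variables.
-train_op = optimizer.minimize(loss, global_step=global_step)
-```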
-
-## Gradient Computation
-
-TensorFlow provides functions to compute the derivatives for a given
-TensorFlow computation graph, adding operations to the graph. The
-optimizer classes automatically compute derivatives on your graph, but
-creators of new Optimizers or expert users can call the lower-level
-functions below.
-
-* `tf.gradients`
-* `tf.AggregationMethod`
-* `tf.stop_gradient`
-* `tf.hessians`
-
-
-## Gradient Clipping
-
-TensorFlow provides several operations that you can use to add clipping
-functions to your graph. You can use these functions to perform general data
-clipping, but they're particularly useful for handling exploding or vanishing
-gradients.
-
-* `tf.clip_by_value`
-* `tf.clip_by_norm`
-* `tf.clip_by_average_norm`
-* `tf.clip_by_global_norm`
-* `tf.global_norm`
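-
-A minimal sketch of global-norm clipping (assuming an existing `loss` tensor
-and `optimizer`):
-
-```python
-params = tf.trainable_variables()
-grads = tf.gradients(loss, params)
-# Rescale all gradients together if their global norm exceeds 5.0.
-clipped_grads, _ = tf.clip_by_global_norm(grads, clip_norm=5.0)
-train_op = optimizer.apply_gradients(zip(clipped_grads, params))
-```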
-
-## Decaying the learning rate
-
-* `tf.train.exponential_decay`
-* `tf.train.inverse_time_decay`
-* `tf.train.natural_exp_decay`
-* `tf.train.piecewise_constant`
-* `tf.train.polynomial_decay`
-* `tf.train.cosine_decay`
-* `tf.train.linear_cosine_decay`
-* `tf.train.noisy_linear_cosine_decay`
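-
-A minimal sketch (the decay schedule here is illustrative):
-
-```python
-global_step = tf.train.get_or_create_global_step()
-# Start at 0.1 and multiply by 0.96 every 1000 steps.
-learning_rate = tf.train.exponential_decay(
-    learning_rate=0.1, global_step=global_step,
-    decay_steps=1000, decay_rate=0.96, staircase=True)
-```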
-
-## Moving Averages
-
-Some training algorithms, such as GradientDescent and Momentum, often benefit
-from maintaining a moving average of variables during optimization. Using the
-moving averages for evaluation often improves results significantly.
-
-* `tf.train.ExponentialMovingAverage`
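-
-A minimal sketch (assuming an existing `train_op`):
-
-```python
-ema = tf.train.ExponentialMovingAverage(decay=0.999)
-# After each training step, update the shadow (averaged) copies of the
-# trainable variables.
-with tf.control_dependencies([train_op]):
-  train_with_averages_op = ema.apply(tf.trainable_variables())
-```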
-
-## Coordinator and QueueRunner
-
-See @{$threading_and_queues$Threading and Queues}
-for how to use threads and queues. For documentation on the Queue API,
-see @{$python/io_ops#queues$Queues}.
-
-
-* `tf.train.Coordinator`
-* `tf.train.QueueRunner`
-* `tf.train.LooperThread`
-* `tf.train.add_queue_runner`
-* `tf.train.start_queue_runners`
-
-## Distributed execution
-
-See @{$distributed$Distributed TensorFlow} for
-more information about how to configure a distributed TensorFlow program.
-
-* `tf.train.Server`
-* `tf.train.Supervisor`
-* `tf.train.SessionManager`
-* `tf.train.ClusterSpec`
-* `tf.train.replica_device_setter`
-* `tf.train.MonitoredTrainingSession`
-* `tf.train.MonitoredSession`
-* `tf.train.SingularMonitoredSession`
-* `tf.train.Scaffold`
-* `tf.train.SessionCreator`
-* `tf.train.ChiefSessionCreator`
-* `tf.train.WorkerSessionCreator`
-
-## Reading Summaries from Event Files
-
-See @{$summaries_and_tensorboard$Summaries and TensorBoard} for an
-overview of summaries, event files, and visualization in TensorBoard.
-
-* `tf.train.summary_iterator`
-
-## Training Hooks
-
-Hooks are tools that run in the process of training/evaluation of the model.
-
-* `tf.train.SessionRunHook`
-* `tf.train.SessionRunArgs`
-* `tf.train.SessionRunContext`
-* `tf.train.SessionRunValues`
-* `tf.train.LoggingTensorHook`
-* `tf.train.StopAtStepHook`
-* `tf.train.CheckpointSaverHook`
-* `tf.train.NewCheckpointReader`
-* `tf.train.StepCounterHook`
-* `tf.train.NanLossDuringTrainingError`
-* `tf.train.NanTensorHook`
-* `tf.train.SummarySaverHook`
-* `tf.train.GlobalStepWaiterHook`
-* `tf.train.FinalOpsHook`
-* `tf.train.FeedFnHook`
-
-## Training Utilities
-
-* `tf.train.global_step`
-* `tf.train.basic_train_loop`
-* `tf.train.get_global_step`
-* `tf.train.assert_global_step`
-* `tf.train.write_graph`
diff --git a/tensorflow/docs_src/community/benchmarks.md b/tensorflow/docs_src/community/benchmarks.md
deleted file mode 100644
index 153ef4a015..0000000000
--- a/tensorflow/docs_src/community/benchmarks.md
+++ /dev/null
@@ -1,108 +0,0 @@
-# Defining and Running Benchmarks
-
-This guide contains instructions for defining and running a TensorFlow benchmark. These benchmarks store output in [TestResults](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/test_log.proto) format. If these benchmarks are added to the TensorFlow github repo, we will run them daily with our continuous build and display a graph on our dashboard: https://benchmarks-dot-tensorflow-testing.appspot.com/.
-
-[TOC]
-
-
-## Defining a Benchmark
-
-Defining a TensorFlow benchmark requires extending the `tf.test.Benchmark`
-class and calling the `self.report_benchmark` method. Below, you'll find an example of benchmark code:
-
-```python
-import time
-
-import tensorflow as tf
-
-
-# Define a class that extends from tf.test.Benchmark.
-class SampleBenchmark(tf.test.Benchmark):
-
- # Note: benchmark method name must start with `benchmark`.
- def benchmarkSum(self):
- with tf.Session() as sess:
- x = tf.constant(10)
- y = tf.constant(5)
- result = tf.add(x, y)
-
- iters = 100
- start_time = time.time()
- for _ in range(iters):
- sess.run(result)
- total_wall_time = time.time() - start_time
-
- # Call report_benchmark to report a metric value.
- self.report_benchmark(
- name="sum_wall_time",
- # This value should always be per iteration.
- wall_time=total_wall_time/iters,
- iters=iters)
-
-if __name__ == "__main__":
- tf.test.main()
-```
-See the full example for [SampleBenchmark](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/benchmark/).
-
-
-Key points to note in the example above:
-
-* Benchmark class extends from `tf.test.Benchmark`.
-* Each benchmark method should start with `benchmark` prefix.
-* Benchmark method calls `report_benchmark` to report the metric value.
-
-
-## Running with Python
-
-Use the `--benchmarks` flag to run the benchmark with Python. A [BenchmarkEntries](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/util/test_log.proto) proto will be printed.
-
-```
-python sample_benchmark.py --benchmarks=SampleBenchmark
-```
-
-Setting the flag as `--benchmarks=.` or `--benchmarks=all` works as well.
-
-(Please ensure that TensorFlow is installed to successfully import the package in the line `import tensorflow as tf`. For installation instructions, see [Installing TensorFlow](https://www.tensorflow.org/install/). This step is not necessary when running with Bazel.)
-
-
-## Adding a `bazel` Target
-
-We have a special target called `tf_py_logged_benchmark` for benchmarks defined under the TensorFlow github repo. `tf_py_logged_benchmark` should wrap around a regular `py_test` target. Running a `tf_py_logged_benchmark` would print a [TestResults](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/test_log.proto) proto. Defining a `tf_py_logged_benchmark` also lets us run it with TensorFlow continuous build.
-
-First, define a regular `py_test` target. See example below:
-
-```build
-py_test(
- name = "sample_benchmark",
- srcs = ["sample_benchmark.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow:tensorflow_py",
- ],
-)
-```
-
-You can run benchmarks in a `py_test` target by passing the `--benchmarks` flag. The benchmark should just print out a [BenchmarkEntries](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/util/test_log.proto) proto.
-
-```shell
-bazel test :sample_benchmark --test_arg=--benchmarks=all
-```
-
-
-Now, add the `tf_py_logged_benchmark` target (if available). This target would
-pass in `--benchmarks=all` to the wrapped `py_test` target and provide a way to
-store output for our TensorFlow continuous build. The target
-`tf_py_logged_benchmark` should be available in the TensorFlow repository.
-
-```build
-load("//tensorflow/tools/test:performance.bzl", "tf_py_logged_benchmark")
-
-tf_py_logged_benchmark(
- name = "sample_logged_benchmark",
- target = "//tensorflow/examples/benchmark:sample_benchmark",
-)
-```
-
-Use the following command to run the benchmark target:
-
-```shell
-bazel test :sample_logged_benchmark
-```
diff --git a/tensorflow/docs_src/community/contributing.md b/tensorflow/docs_src/community/contributing.md
deleted file mode 100644
index afbb8bbdd0..0000000000
--- a/tensorflow/docs_src/community/contributing.md
+++ /dev/null
@@ -1,49 +0,0 @@
-# Contributing to TensorFlow
-
-TensorFlow is an open-source project, and we welcome your participation
-and contribution. This page describes how to get involved.
-
-## Repositories
-
-The code for TensorFlow is hosted in the [TensorFlow GitHub
-organization](https://github.com/tensorflow). Multiple projects are located
-inside the organization, including:
-
-* [TensorFlow](https://github.com/tensorflow/tensorflow)
-* [Models](https://github.com/tensorflow/models)
-* [TensorBoard](https://github.com/tensorflow/tensorboard)
-* [TensorFlow.js](https://github.com/tensorflow/tfjs)
-* [TensorFlow Serving](https://github.com/tensorflow/serving)
-* [TensorFlow Documentation](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/docs_src)
-
-## Contributor checklist
-
-* Before contributing to TensorFlow source code, please review the [contribution
-guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md).
-
-* Join the
-[developers@tensorflow.org](https://groups.google.com/a/tensorflow.org/d/forum/developers)
-mailing list, to coordinate and discuss with others contributing to TensorFlow.
-
-* For coding style conventions, read the @{$style_guide$TensorFlow Style Guide}.
-
-* Finally, review @{$documentation$Writing TensorFlow Documentation}, which
- explains documentation conventions.
-
-You may also wish to review our guide to @{$benchmarks$defining and running benchmarks}.
-
-## Special Interest Groups
-
-To enable focused collaboration on particular areas of TensorFlow, we host
-Special Interest Groups (SIGs). SIGs do their work in public: if you want to
-join and contribute, review the work of the group, and get in touch with the
-relevant SIG leader. Membership policies vary on a per-SIG basis.
-
-* **SIG Build** focuses on issues surrounding building, packaging, and
- distribution of TensorFlow. [Mailing list](https://groups.google.com/a/tensorflow.org/d/forum/build).
-
-* **SIG TensorBoard** furthers the development and direction of TensorBoard and its plugins.
- [Mailing list](https://groups.google.com/a/tensorflow.org/d/forum/sig-tensorboard).
-
-* **SIG Rust** collaborates on the development of TensorFlow's Rust bindings.
- [Mailing list](https://groups.google.com/a/tensorflow.org/d/forum/rust).
diff --git a/tensorflow/docs_src/community/documentation.md b/tensorflow/docs_src/community/documentation.md
deleted file mode 100644
index 8639656d07..0000000000
--- a/tensorflow/docs_src/community/documentation.md
+++ /dev/null
@@ -1,673 +0,0 @@
-# Writing TensorFlow Documentation
-
-We welcome contributions to the TensorFlow documentation from the community.
-This document explains how you can contribute to that documentation. In
-particular, this document explains the following:
-
-* Where the documentation is located.
-* How to make conformant edits.
-* How to build and test your documentation changes before you submit them.
-
-You can view TensorFlow documentation on https://www.tensorflow.org, and you
-can view and edit the raw files on
-[GitHub](https://www.tensorflow.org/code/tensorflow/docs_src/).
-We're publishing our docs on GitHub so everybody can contribute. Whatever gets
-checked in to `tensorflow/docs_src` will be published soon after on
-https://www.tensorflow.org.
-
-Republishing TensorFlow documentation in different forms is absolutely allowed,
-but we are unlikely to accept other documentation formats (or the tooling to
-generate them) into our repository. If you do choose to republish our
-documentation in another form, please be sure to include:
-
-* The version of the API this represents (for example, r1.0, master, etc.)
-* The commit or version from which the documentation was generated
-* Where to get the latest documentation (that is, https://www.tensorflow.org)
-* The Apache 2.0 license.
-
-## A note on versions
-
-tensorflow.org, at root, shows documentation for the latest stable binary. This
-is the documentation you should be reading if you are using `pip` to install
-TensorFlow.
-
-However, most developers will contribute documentation into the master GitHub
-branch, which is published, occasionally,
-at [tensorflow.org/versions/master](https://www.tensorflow.org/versions/master).
-
-If you want documentation changes to appear at root, you will need to also
-contribute that change to the current stable binary branch (and/or
-[cherrypick](https://stackoverflow.com/questions/9339429/what-does-cherry-picking-a-commit-with-git-mean)).
-
-## Reference vs. non-reference documentation
-
-The following reference documentation is automatically generated from comments
-in the code:
-
-- C++ API reference docs
-- Java API reference docs
-- Python API reference docs
-
-To modify the reference documentation, you edit the appropriate code comments.
-
-Non-reference documentation (for example, the TensorFlow installation guides) is
-authored by humans. This documentation is located in the
-[`tensorflow/docs_src`](https://www.tensorflow.org/code/tensorflow/docs_src/)
-directory. Each subdirectory of `docs_src` contains a set of related TensorFlow
-documentation. For example, the TensorFlow installation guides are all in the
-`docs_src/install` directory.
-
-The C++ documentation is generated from XML files produced by doxygen;
-however, those tools are not available in open source at this time.
-
-## Markdown
-
-Editable TensorFlow documentation is written in Markdown. With a few exceptions,
-TensorFlow uses
-the [standard Markdown rules](https://daringfireball.net/projects/markdown/).
-
-This section explains the primary differences between standard Markdown rules
-and the Markdown rules that editable TensorFlow documentation uses.
-
-### Math in Markdown
-
-You may use MathJax within TensorFlow when editing Markdown files, but note the
-following:
-
-- MathJax renders properly on [tensorflow.org](https://www.tensorflow.org)
-- MathJax does not render properly on [github](https://github.com/tensorflow/tensorflow).
-
-When writing MathJax, you can use <code>&#36;&#36;</code> and `\\(` and `\\)` to
-surround your math. <code>&#36;&#36;</code> guards will cause line breaks, so
-within text, use `\\(` `\\)` instead.
-
-### Links in Markdown
-
-Links fall into a few categories:
-
-- Links to a different part of the same file
-- Links to a URL outside of tensorflow.org
-- Links from a Markdown file (or code comments) to another file within tensorflow.org
-
-For the first two link categories, you may use standard Markdown links, but put
-the link entirely on one line, rather than splitting it across lines. For
-example:
-
-- `[text](link) # Good link`
-- `[text]\n(link) # Bad link`
-- `[text](\nlink) # Bad link`
-
-For the final link category (links to another file within tensorflow.org),
-please use a special link parameterization mechanism. This mechanism enables
-authors to move and reorganize files without breaking links.
-
-The parameterization scheme is as follows. Use:
-
-<!-- Note: the use of &#64; is a hack so we don't translate these as symbols -->
-- <code>&#64;{tf.symbol}</code> to make a link to the reference page for a
- Python symbol. Note that class members don't get their own page, but the
- syntax still works, since <code>&#64;{tf.MyClass.method}</code> links to the
- proper part of the tf.MyClass page.
-
-- <code>&#64;{tensorflow::symbol}</code> to make a link to the reference page
- for a C++ symbol.
-
-- <code>&#64;{$doc_page}</code> to make a link to another (not an API reference)
- doc page. To link to
-
- - `red/green/blue/index.md` use <code>&#64;{$blue}</code> or
- <code>&#64;{$green/blue}</code>,
-
- - `foo/bar/baz.md` use <code>&#64;{$baz}</code> or
- <code>&#64;{$bar/baz}</code>.
-
- The shorter one is preferred, so we can move pages around without breaking
- these references. The main exception is that the Python API guides should
- probably be referred to using <code>&#64;{$python/<guide-name>}</code> to
- avoid ambiguity.
-
-- <code>&#64;{$doc_page#anchor-tag$link-text}</code> to link to an anchor in
- that doc and use different link text (by default, the link text is the title
- of the target page).
-
- To override the link text only, omit the `#anchor-tag`.
-
-To link to source code, use a link starting with:
-`https://www.tensorflow.org/code/`, followed by
-the file name starting at the github root. For instance, a link to the file you
-are currently reading should be written as
-`https://www.tensorflow.org/code/tensorflow/docs_src/community/documentation.md`.
-
-This URL naming scheme ensures
-that [tensorflow.org](https://www.tensorflow.org/) can forward the link to the
-branch of the code corresponding to the version of the documentation you're
-viewing. Do not include url parameters in the source code URL.
-
-## Generating docs and previewing links
-
-Before building the documentation, you must first set up your environment by
-doing the following:
-
-1. If bazel is not installed on your machine, install it now. If you are on
- Linux, install bazel by issuing the following command:
-
- $ sudo apt-get install bazel # Linux
-
- If you are on Mac OS, find bazel installation instructions on
- [this page](https://bazel.build/versions/master/docs/install.html#mac-os-x).
-
-2. Change directory to the top-level `tensorflow` directory of the TensorFlow
- source code.
-
-3. Run the `configure` script and answer its prompts appropriately for your
- system.
-
- $ ./configure
-
-Then, change to the `tensorflow` directory which contains `docs_src` (`cd
-tensorflow`). Run the following command to compile TensorFlow and generate the
-documentation in the `/tmp/tfdocs` dir:
-
- bazel run tools/docs:generate -- \
- --src_dir="$(pwd)/docs_src/" \
- --output_dir=/tmp/tfdocs/
-
-Note: You must set `src_dir` and `output_dir` to absolute file paths.
-
-## Generating Python API documentation
-
-Ops, classes, and utility functions are defined in Python modules, such as
-`image_ops.py`. Python modules contain a module docstring. For example:
-
-```python
-"""Image processing and decoding ops."""
-```
-
-The documentation generator places this module docstring at the beginning of the
-Markdown file generated for the module, in this
-case, [tf.image](https://www.tensorflow.org/api_docs/python/tf/image).
-
-It used to be a requirement to list every member of a module inside the module
-file at the beginning, putting a `@@` before each member. The `@@member_name`
-syntax is deprecated and no longer generates any docs. But depending on how a
-module is [sealed](#sealing_modules) it may still be necessary to mark the
-elements of the module’s contents as public. The called-out op, function, or
-class does not have to be defined in the same file. The next few sections of
-this document discuss sealing and how to add elements to the public
-documentation.
-
-The new documentation system automatically documents public symbols, except for
-the following:
-
-- Private symbols whose names start with an underscore.
-- Symbols originally defined in `object` or protobuf’s `Message`.
-- Some class members, such as `__base__`, `__class__`, which are dynamically
- created but generally have no useful documentation.
-
-Only top level modules (currently just `tf` and `tfdbg`) need to be manually
-added to the generate script.
-
-### Sealing modules
-
-Because the doc generator walks all visible symbols, and descends into anything
-it finds, it will document any accidentally exposed symbols. If a module only
-exposes symbols that are meant to be part of the public API, we call it
-**sealed**. Because of Python’s loose import and visibility conventions, naively
-written Python code will inadvertently expose a lot of modules which are
-implementation details. Improperly sealed modules may expose other unsealed
-modules, which will typically lead the doc generator to fail. **This failure is
-the intended behavior.** It ensures that our API is well defined, and allows us
-to change implementation details (including which modules are imported where)
-without fear of accidentally breaking users.
-
-If a module is accidentally imported, it typically breaks the doc generator
-(`generate_test`). This is a clear sign you need to seal your modules. However,
-even if the doc generator succeeds, unwanted symbols may show up in the
-docs. Check the generated docs to make sure that all symbols that are documented
-are expected. If there are symbols that shouldn’t be there, you have the
-following options for dealing with them:
-
-- Private symbols and imports
-- The `remove_undocumented` filter
-- A traversal blacklist.
-
-We'll discuss these options in detail below.
-
-#### Private symbols and imports
-
-The easiest way to conform to the API sealing expectations is to make non-public
-symbols private (by prepending an underscore _). The doc generator respects
-private symbols. This also applies to modules. If the only problem is that there
-is a small number of imported modules that show up in the docs (or break the
-generator), you can simply rename them on import, e.g.: `import sys as _sys`.
-
-Because Python considers all files to be modules, this applies to files as
-well. If you have a directory containing the following two files/modules:
-
- module/__init__.py
- module/private_impl.py
-
-Then, after `module` is imported, it will be possible to access
-`module.private_impl`. Renaming `private_impl.py` to `_private_impl.py` solves
-the problem. If renaming modules is awkward, read on.
-
-#### Use the `remove_undocumented` filter
-
-Another way to seal a module is to split your implementation from the API. To do
-so, consider using `remove_undocumented`, which takes a list of allowed symbols,
-and deletes everything else from the module. For example, the following snippet
-demonstrates how to put `remove_undocumented` in the `__init__.py` file for a
-module:
-
-__init__.py:
-
- # Use * imports only if __all__ defined in some_file
- from tensorflow.some_module.some_file import *
-
- # Otherwise import symbols directly
- from tensorflow.some_module.some_other_file import some_symbol
-
- from tensorflow.python.util.all_util import remove_undocumented
-
-    _allowed_symbols = ['some_symbol', 'some_other_symbol']
-
- remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
-
-The `@@member_name` syntax is deprecated, but it still exists in some places in
-the documentation as an indicator to `remove_undocumented` that those symbols
-are public. All `@@`s will eventually be removed. If you see them, however,
-please do not randomly delete them as they are still in use by some of our
-systems.
-
-#### Traversal blacklist
-
-If all else fails, you may add entries to the traversal blacklist in
-`generate_lib.py`. **Almost all entries in this list are an abuse of its
-purpose; avoid adding to it if you can!**
-
-The traversal blacklist maps qualified module names (without the leading `tf.`)
-to local names that are not to be descended into. For instance, the following
-entry will exclude `some_module` from traversal.
-
- { ...
-      'contrib.my_module': ['some_module']
- ...
- }
-
-That means that the doc generator will show that `some_module` exists, but it
-will not enumerate its content.
-
-This blacklist was originally intended to make sure that system modules (mock,
-flags, ...) included for platform abstraction can be documented without
-documenting their interior. Its use beyond this purpose is a shortcut that may
-be acceptable for contrib, but not for core tensorflow.
-
-## Op documentation style guide
-
-Long, descriptive module-level documentation should go in the API Guides in
-`docs_src/api_guides/python`.
-
-For classes and ops, ideally, you should provide the following information, in
-order of presentation:
-
-* A short sentence that describes what the op does.
-* A short description of what happens when you pass arguments to the op.
-* An example showing how the op works (pseudocode is best).
-* Requirements, caveats, important notes (if there are any).
-* Descriptions of inputs, outputs, and Attrs or other parameters of the op
- constructor.
-
-Each of these is described in more
-detail [below](#description-of-the-docstring-sections).
-
-Write your text in Markdown format. A basic syntax reference
-is [here](https://daringfireball.net/projects/markdown/). You are allowed to
-use [MathJax](https://www.mathjax.org) notation for equations (see above for
-restrictions).
-
-### Writing about code
-
-Put backticks around these things when they're used in text (a combined
-example follows this list):
-
-* Argument names (for example, `input`, `x`, `tensor`)
-* Returned tensor names (for example, `output`, `idx`, `out`)
-* Data types (for example, `int32`, `float`, `uint8`)
-* Other op names referenced in text (for example, `list_diff()`, `shuffle()`)
-* Class names (for example, `Tensor` when you actually mean a `Tensor` object;
- don't capitalize or use backticks if you're just explaining what an op does to
- a tensor, or a graph, or an operation in general)
-* File names (for example, `image_ops.py`, or
- `/path-to-your-data/xml/example-name`)
-* Math expressions or conditions (for example, `-1-input.dims() <= dim <=
- input.dims()`)
-
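-Putting several of these rules together, a sentence in a docstring might read
-as follows (an illustrative example, not from a real op):
-
-    Returns `output`, an `int32` `Tensor` holding the shuffled values; see
-    `shuffle()` in `image_ops.py`. Requires `-1-input.dims() <= dim <=
-    input.dims()`.
-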
-Put three backticks around sample code and pseudocode examples. And use `==>`
-instead of a single equal sign when you want to show what an op returns. For
-example:
-
- ```
- # 'input' is a tensor of shape [2, 3, 5]
-    shape(tf.expand_dims(input, 0)) ==> [1, 2, 3, 5]
- ```
-
-If you're providing a Python code sample, add the python style label to ensure
-proper syntax highlighting:
-
- ```python
- # some Python code
- ```
-
-Two notes about backticks for code samples in Markdown:
-
-1. You can use backticks for pretty printing languages other than Python, if
- necessary. A full list of languages is available
- [here](https://github.com/google/code-prettify#how-do-i-specify-the-language-of-my-code).
-2. Markdown also allows you to indent four spaces to specify a code sample.
- However, do NOT indent four spaces and use backticks simultaneously. Use one
- or the other.
-
-### Tensor dimensions
-
-When you're talking about a tensor in general, don't capitalize the word tensor.
-When you're talking about the specific object that's provided to an op as an
-argument or returned by an op, then you should capitalize the word Tensor and
-add backticks around it because you're talking about a `Tensor` object.
-
-Don't use the word `Tensors` to describe multiple `Tensor` objects unless you
-really are talking about a `Tensors` object. Better to say "a list of `Tensor`
-objects."
-
-Use the term "dimension" to refer to the size of a tensor. If you need to be
-specific about the size, use these conventions:
-
-- Refer to a scalar as a "0-D tensor"
-- Refer to a vector as a "1-D tensor"
-- Refer to a matrix as a "2-D tensor"
-- Refer to tensors with 3 or more dimensions as 3-D tensors or n-D tensors. Use
- the word "rank" only if it makes sense, but try to use "dimension" instead.
- Never use the word "order" to describe the size of a tensor.
-
-Use the word "shape" to detail the dimensions of a tensor, and show the shape in
-square brackets with backticks. For example:
-
- If `input` is a 3-D tensor with shape `[3, 4, 3]`, this operation
- returns a 3-D tensor with shape `[6, 8, 6]`.
-
-### Ops defined in C++
-
-All Ops defined in C++ (and accessible from other languages) must be documented
-with a `REGISTER_OP` declaration. The docstring in the C++ file is processed to
-automatically add some information for the input types, output types, and Attr
-types and default values.
-
-For example:
-
-```c++
-REGISTER_OP("PngDecode")
- .Input("contents: string")
- .Attr("channels: int = 0")
- .Output("image: uint8")
- .Doc(R"doc(
-Decodes the contents of a PNG file into a uint8 tensor.
-
-contents: PNG file contents.
-channels: Number of color channels, or 0 to autodetect based on the input.
- Must be 0 for autodetect, 1 for grayscale, 3 for RGB, or 4 for RGBA.
- If the input has a different number of channels, it will be transformed
- accordingly.
-image:= A 3-D uint8 tensor of shape `[height, width, channels]`.
- If `channels` is 0, the last dimension is determined
- from the png contents.
-)doc");
-```
-
-Results in this piece of Markdown:
-
- ### tf.image.png_decode(contents, channels=None, name=None) {#png_decode}
-
- Decodes the contents of a PNG file into a uint8 tensor.
-
- #### Args:
-
- * **contents**: A string Tensor. PNG file contents.
- * **channels**: An optional int. Defaults to 0.
- Number of color channels, or 0 to autodetect based on the input.
- Must be 0 for autodetect, 1 for grayscale, 3 for RGB, or 4 for RGBA. If the
- input has a different number of channels, it will be transformed accordingly.
- * **name**: A name for the operation (optional).
-
- #### Returns:
- A 3-D uint8 tensor of shape `[height, width, channels]`. If `channels` is
- 0, the last dimension is determined from the png contents.
-
-Much of the argument description is added automatically. In particular, the doc
-generator automatically adds the name and type of all inputs, attrs, and
-outputs. In the above example, `contents: A string Tensor.` was added
-automatically. You should write your additional text to flow naturally after
-that description.
-
-For inputs and outputs, you can prefix your additional text with an equal sign
-to prevent the automatically added name and type. In the above example, the
-description for the output named `image` starts with `=` to prevent the addition
-of `A uint8 Tensor.` before our text `A 3-D uint8 Tensor...`. You cannot prevent
-the addition of the name, type, and default value of attrs this way, so write
-your text carefully.
-
-### Ops defined in Python
-
-If your op is defined in a `python/ops/*.py` file, then you need to provide text
-for all of the arguments and output (returned) tensors. The doc generator does
-not auto-generate any text for ops that are defined in Python, so what you write
-is what you get.
-
-You should conform to the usual Python docstring conventions, except that you
-should use Markdown in the docstring.
-
-Here's a simple example:
-
- def foo(x, y, name="bar"):
- """Computes foo.
-
- Given two 1-D tensors `x` and `y`, this operation computes the foo.
-
- Example:
-
- ```
- # x is [1, 1]
- # y is [2, 2]
- tf.foo(x, y) ==> [3, 3]
- ```
- Args:
- x: A `Tensor` of type `int32`.
- y: A `Tensor` of type `int32`.
- name: A name for the operation (optional).
-
- Returns:
- A `Tensor` of type `int32` that is the foo of `x` and `y`.
-
- Raises:
- ValueError: If `x` or `y` are not of type `int32`.
- """
-
-## Description of the docstring sections
-
-This section details each of the elements in docstrings.
-
-### Short sentence describing what the op does
-
-Examples:
-
-```
-Concatenates tensors.
-```
-
-```
-Flips an image horizontally from left to right.
-```
-
-```
-Computes the Levenshtein distance between two sequences.
-```
-
-```
-Saves a list of tensors to a file.
-```
-
-```
-Extracts a slice from a tensor.
-```
-
-### Short description of what happens when you pass arguments to the op
-
-Examples:
-
- Given a tensor input of numerical type, this operation returns a tensor of
- the same type and size with values reversed along dimension `seq_dim`. A
- vector `seq_lengths` determines which elements are reversed for each index
- within dimension 0 (usually the batch dimension).
-
-
- This operation returns a tensor of type `dtype` and dimensions `shape`, with
- all elements set to zero.
-
-### Example demonstrating the op
-
-Good code samples are short and easy to understand, typically containing a brief
-snippet of code to clarify what the example is demonstrating. When an op
-manipulates the shape of a Tensor it is often useful to include an example of
-the before and after, as well.
-
-The `squeeze()` op has a nice pseudocode example:
-
- # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
- shape(squeeze(t)) ==> [2, 3]
-
-The `tile()` op provides a good example in descriptive text:
-
- For example, tiling `[a, b, c, d]` by `[2]` produces `[a b c d a b c d]`.
-
-It is often helpful to show code samples in Python. Never put them in the C++
-Ops file, and avoid putting them in the Python Ops doc. We recommend, if
-possible, putting code samples in the
-[API guides](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/docs_src/api_guides).
-Otherwise, add them to the module or class docstring where the Ops constructors
-are called out.
-
-Here's an example from the module docstring in `api_guides/python/math_ops.md`:
-
- ## Segmentation
-
- TensorFlow provides several operations that you can use to perform common
- math computations on tensor segments.
- ...
- In particular, a segmentation of a matrix tensor is a mapping of rows to
- segments.
-
- For example:
-
- ```python
- c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
- tf.segment_sum(c, tf.constant([0, 0, 1]))
- ==> [[0 0 0 0]
- [5 6 7 8]]
- ```
-
-### Requirements, caveats, important notes
-
-Examples:
-
-```
-This operation requires that: `-1-input.dims() <= dim <= input.dims()`
-```
-
-```
-Note: This tensor will produce an error if evaluated. Its value must
-be fed using the `feed_dict` optional argument to `Session.run()`,
-`Tensor.eval()`, or `Operation.run()`.
-```
-
-### Descriptions of arguments and output (returned) tensors
-
-Keep the descriptions brief and to the point. You should not have to explain how
-the operation works in the argument sections.
-
-Mention if the Op has strong constraints on the dimensions of the input or
-output tensors. Remember that for C++ Ops, the type of the tensor is
-automatically added as either "A ..type.. Tensor" or "A Tensor with type in
-{...list of types...}". In such cases, if the Op has a constraint on the
-dimensions, either add text such as "Must be 4-D" or start the description with
-`=` (to prevent the tensor type from being added) and write something like "A
-4-D float tensor".
-
-For example, here are two ways to document an image argument of a C++ op (note
-the "=" sign):
-
-```
-image: Must be 4-D. The image to resize.
-```
-
-```
-image:= A 4-D `float` tensor. The image to resize.
-```
-
-In the documentation, these will be rendered in Markdown as
-
-```
-image: A `float` Tensor. Must be 4-D. The image to resize.
-```
-
-```
-image: A 4-D `float` Tensor. The image to resize.
-```
-
-### Optional arguments descriptions ("attrs")
-
-The doc generator always describes the type for each attr and its default
-value, if any. You cannot override that with an equal sign because the
-description is very different in the C++ and Python generated docs.
-
-Phrase any additional attr description so that it flows well after the type
-and default value. The type and defaults are displayed first, and additional
-descriptions follow afterwards. Therefore, complete sentences are best.
-
-Here's an example from `image_ops.cc`:
-
- REGISTER_OP("DecodePng")
- .Input("contents: string")
- .Attr("channels: int = 0")
- .Attr("dtype: {uint8, uint16} = DT_UINT8")
- .Output("image: dtype")
- .SetShapeFn(DecodeImageShapeFn)
- .Doc(R"doc(
- Decode a PNG-encoded image to a uint8 or uint16 tensor.
-
- The attr `channels` indicates the desired number of color channels for the
- decoded image.
-
- Accepted values are:
-
- * 0: Use the number of channels in the PNG-encoded image.
- * 1: output a grayscale image.
- * 3: output an RGB image.
- * 4: output an RGBA image.
-
- If needed, the PNG-encoded image is transformed to match the requested
- number of color channels.
-
- contents: 0-D. The PNG-encoded image.
- channels: Number of color channels for the decoded image.
- image: 3-D with shape `[height, width, channels]`.
- )doc");
-
-This generates the following Args section in
-`api_docs/python/tf/image/decode_png.md`:
-
- #### Args:
-
- * **`contents`**: A `Tensor` of type `string`. 0-D. The PNG-encoded
- image.
- * **`channels`**: An optional `int`. Defaults to `0`. Number of color
- channels for the decoded image.
- * **`dtype`**: An optional `tf.DType` from: `tf.uint8,
-      tf.uint16`. Defaults to `tf.uint8`.
- * **`name`**: A name for the operation (optional).
diff --git a/tensorflow/docs_src/community/groups.md b/tensorflow/docs_src/community/groups.md
deleted file mode 100644
index 0b07d413da..0000000000
--- a/tensorflow/docs_src/community/groups.md
+++ /dev/null
@@ -1,38 +0,0 @@
-# User Groups
-
-TensorFlow has communities around the world. [Submit your community!](https://docs.google.com/forms/d/e/1FAIpQLSc_RQIUYtVgLLihzATaO_WUXkEyBDE_OoRoOXYDPmBEvHuEBA/viewform)
-
-## Asia
-
-* [TensorFlow China community](https://www.tensorflowers.cn)
-* [TensorFlow Korea (TF-KR) User Group](https://www.facebook.com/groups/TensorFlowKR/)
-* [TensorFlow User Group Tokyo](https://tfug-tokyo.connpass.com/)
-* [Soleil Data Dojo](https://soleildatadojo.connpass.com/)
-* [TensorFlow User Group Utsunomiya](https://tfug-utsunomiya.connpass.com/)
-* [TensorFlow Philippines Community](https://www.facebook.com/groups/TensorFlowPH/)
-* [TensorFlow and Deep Learning Singapore](https://www.meetup.com/TensorFlow-and-Deep-Learning-Singapore/)
-* [TensorFlow India](https://www.facebook.com/tensorflowindia)
-
-
-## Europe
-
-* [TensorFlow Barcelona](https://www.meetup.com/Barcelona-Machine-Learning-Meetup/)
-* [TensorFlow Madrid](https://www.meetup.com/TensorFlow-Madrid/)
-* [Tensorflow Belgium](https://www.meetup.com/TensorFlow-Belgium)
-* [TensorFlow x Rome Meetup](https://www.meetup.com/it-IT/TensorFlow-x-Rome-Meetup)
-* [TensorFlow London](https://www.meetup.com/TensorFlow-London/)
-* [TensorFlow Edinburgh](https://www.meetup.com/tensorflow-edinburgh/)
-
-
-## America
-
-* [TensorFlow Buenos Aires](https://www.meetup.com/TensorFlow-Buenos-Aires/)
-
-
-## Oceania
-* [Melbourne TensorFlow Meetup](https://www.meetup.com/Melbourne-TensorFlow-Meetup)
-
-
-## Africa
-
-* [TensorFlow Tunis Meetup](https://www.meetup.com/fr-FR/TensorFlow-Tunis-Meetup/)
diff --git a/tensorflow/docs_src/community/index.md b/tensorflow/docs_src/community/index.md
deleted file mode 100644
index 865a203bf8..0000000000
--- a/tensorflow/docs_src/community/index.md
+++ /dev/null
@@ -1,85 +0,0 @@
-# Community
-
-Welcome to the TensorFlow community! This page explains where to get help, and
-different ways to be part of the community. We are committed to fostering an
-open and welcoming environment, and request that you review our [code of
-conduct](https://github.com/tensorflow/tensorflow/blob/master/CODE_OF_CONDUCT.md).
-
-## Get Help
-
-### Technical Questions
-
-To ask or answer technical questions about TensorFlow, use [Stack
-Overflow](https://stackoverflow.com/questions/tagged/tensorflow). For example,
-ask or search about a particular error message you encountered during
-installation.
-
-### Bugs and Feature Requests
-
-To report bugs or make feature requests, file an issue on GitHub. Please choose
-the appropriate repository for the project. Major repositories include:
-
- * [TensorFlow](https://github.com/tensorflow/tensorflow/issues)
- * [TensorBoard](https://github.com/tensorflow/tensorboard/issues)
- * [TensorFlow models](https://github.com/tensorflow/models/issues)
-
-### Security
-
-Before using TensorFlow, please take a look at our [security model](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md#tensorflow-models-are-programs),
-[list of recent security advisories and announcements](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md),
-and [ways you can report security issues](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md#reporting-vulnerabilities)
-to the TensorFlow team, described on the [Using TensorFlow Securely](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md) page on GitHub.
-
-## Stay Informed
-
-### Announcements Mailing List
-
-All major releases and important announcements are sent to
-[announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce).
-We recommend that you join this list if you depend on TensorFlow in any way.
-
-### Development Roadmap
-
-The @{$roadmap$Roadmap} summarizes plans for upcoming additions to TensorFlow.
-
-### Social Media
-
-For news and updates from around the universe of TensorFlow projects, follow
-[@tensorflow](https://twitter.com/tensorflow) on Twitter.
-
-### Blog
-
-We post regularly to the [TensorFlow Blog](http://blog.tensorflow.org/),
-with content from the TensorFlow team and the best articles from the community.
-
-### YouTube
-
-Our [YouTube Channel](http://youtube.com/tensorflow/) focuses on machine learning
-and AI with TensorFlow. On it we have a number of new shows, including:
-
-- TensorFlow Meets: meet with community contributors to learn and share what they're doing
-- Ask TensorFlow: the team answers the best questions tagged #AskTensorFlow from social media
-- Coding TensorFlow: short bites with tips for success with TensorFlow
-
-## Community Support
-
-### Mailing Lists
-
-For general discussion about TensorFlow development and direction, please join
-the [TensorFlow discuss mailing
-list](https://groups.google.com/a/tensorflow.org/d/forum/discuss).
-
-A number of other mailing lists, focused on different project areas, can be
-found at @{$lists$TensorFlow Mailing Lists}.
-
-### User Groups
-
-To meet with like-minded people local to you, check out the many
-@{$groups$TensorFlow user groups} around the world.
-
-
-## Contributing To TensorFlow
-
-We welcome contributions and collaboration on TensorFlow. For more information,
-please read [Contributing to TensorFlow](contributing.md).
-
diff --git a/tensorflow/docs_src/community/leftnav_files b/tensorflow/docs_src/community/leftnav_files
deleted file mode 100644
index 0bd1f14de9..0000000000
--- a/tensorflow/docs_src/community/leftnav_files
+++ /dev/null
@@ -1,8 +0,0 @@
-index.md
-roadmap.md
-contributing.md
-lists.md
-groups.md
-documentation.md
-style_guide.md
-benchmarks.md
diff --git a/tensorflow/docs_src/community/lists.md b/tensorflow/docs_src/community/lists.md
deleted file mode 100644
index bc2f573c29..0000000000
--- a/tensorflow/docs_src/community/lists.md
+++ /dev/null
@@ -1,53 +0,0 @@
-# Mailing Lists
-
-As a community, we do much of our collaboration on public mailing lists.
-Please note that if you're looking for help using TensorFlow, [Stack
-Overflow](https://stackoverflow.com/questions/tagged/tensorflow) and
-[GitHub issues](https://github.com/tensorflow/tensorflow/issues)
-are the best initial places to look. For more information,
-see [how to get help](/community/#get_help).
-
-## General TensorFlow lists
-
-* [announce](https://groups.google.com/a/tensorflow.org/d/forum/announce) - Low-volume announcements of new releases.
-* [discuss](https://groups.google.com/a/tensorflow.org/d/forum/discuss) - General community discussion around TensorFlow.
-* [developers](https://groups.google.com/a/tensorflow.org/d/forum/developers) - Discussion for developers contributing to TensorFlow.
-
-## Project-specific lists
-
-These projects inside the TensorFlow GitHub organization have lists dedicated to their communities:
-
-* [hub](https://groups.google.com/a/tensorflow.org/d/forum/hub) -
- Discussion and collaboration around [TensorFlow Hub](https://github.com/tensorflow/hub).
-* [magenta-discuss](https://groups.google.com/a/tensorflow.org/d/forum/magenta-discuss) -
- General discussion about [Magenta](https://magenta.tensorflow.org/)
- development and directions.
-* [swift](https://groups.google.com/a/tensorflow.org/d/forum/swift) -
- Community and collaboration around Swift for TensorFlow.
-* [tensor2tensor](https://groups.google.com/d/forum/tensor2tensor) - Discussion
- and peer support for Tensor2Tensor.
-* [tfjs-announce](https://groups.google.com/a/tensorflow.org/d/forum/tfjs-announce) -
- Announcements of new TensorFlow.js releases.
-* [tfjs](https://groups.google.com/a/tensorflow.org/d/forum/tfjs) - Discussion
- and peer support for TensorFlow.js.
-* [tflite](https://groups.google.com/a/tensorflow.org/d/forum/tflite) - Discussion and
- peer support for TensorFlow Lite.
-* [tfprobability](https://groups.google.com/a/tensorflow.org/d/forum/tfprobability) - Discussion and
- peer support for TensorFlow Probability.
-* [tpu-users](https://groups.google.com/a/tensorflow.org/d/forum/tpu-users) - Community discussion
- and support for TPU users.
-
-## Special Interest Groups
-
-TensorFlow's [Special Interest
-Groups](/community/contributing#special_interest_groups) (SIGs) support
-community collaboration on particular project focuses. Members of these groups
-work together to build and support TensorFlow-related projects. While their
-archives are public, different SIGs have their own membership policies.
-
-* [build](https://groups.google.com/a/tensorflow.org/d/forum/build) -
- Supporting SIG Build, for build, distribution and packaging of TensorFlow.
-* [sig-tensorboard](https://groups.google.com/a/tensorflow.org/d/forum/sig-tensorboard) -
- Supporting SIG TensorBoard, for plugin development and other contribution.
-* [rust](https://groups.google.com/a/tensorflow.org/d/forum/rust) -
- Supporting SIG Rust, for the Rust language bindings.
diff --git a/tensorflow/docs_src/community/roadmap.md b/tensorflow/docs_src/community/roadmap.md
deleted file mode 100644
index 0463ca05fe..0000000000
--- a/tensorflow/docs_src/community/roadmap.md
+++ /dev/null
@@ -1,121 +0,0 @@
-# Roadmap
-**Last updated: Apr 27, 2018**
-
-TensorFlow is a rapidly moving, community supported project. This document is intended
-to provide guidance about priorities and focus areas of the core set of TensorFlow
-developers and about functionality that can be expected in the upcoming releases of
-TensorFlow. Many of these areas are driven by community use cases, and we welcome
-further
-[contributions](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md)
-to TensorFlow.
-
-The features below do not have concrete release dates. However, the majority can be
-expected in the next one to two releases.
-
-### APIs
-#### High Level APIs:
-* Easy multi-GPU and TPU utilization with Estimators
-* Easy-to-use high-level pre-made estimators for Gradient Boosted Trees, Time Series, and other models
-
-#### Eager Execution:
-* Efficient utilization of multiple GPUs
-* Distributed training support (multi-machine)
-* Performance improvements
-* Simpler export to a GraphDef/SavedModel
-
-#### Keras API:
-* Better integration with tf.data (ability to call `model.fit` with data tensors)
-* Full support for Eager Execution (both Eager support for the regular Keras API, and the ability
-to create Keras models Eager-style via Model subclassing)
-* Better distribution/multi-GPU support and TPU support (including a smoother model-to-estimator workflow)
-
-#### Official Models:
-* A set of
-[models](https://github.com/tensorflow/models/tree/master/official)
-across image recognition, speech, object detection, and
- translation that demonstrate best practices and serve as a starting point for
- high-performance model development.
-
-#### Contrib:
-* Deprecate parts of tf.contrib where preferred implementations exist outside of tf.contrib.
-* As much as possible, move large projects inside tf.contrib to separate repositories.
-* The tf.contrib module will eventually be discontinued in its current form; future experimental development will happen in other repositories.
-
-
-#### Probabilistic Reasoning and Statistical Analysis:
-* Rich set of tools for probabilistic and statistical analysis in tf.distributions
- and tf.probability. These include new samplers, layers, optimizers, losses, and structured models
-* Statistical tools for hypothesis testing, convergence diagnostics, and sample statistics
-* Edward 2.0: High-level API for probabilistic programming
-
-### Platforms
-#### TensorFlow Lite:
-* Increase coverage of supported ops in TensorFlow Lite
-* Easier conversion of a trained TensorFlow graph for use on TensorFlow Lite
-* Support for GPU acceleration in TensorFlow Lite (iOS and Android)
-* Support for hardware accelerators via Android NeuralNets API
-* Improve CPU performance by quantization and other network optimizations (e.g. pruning, distillation)
-* Increase support for devices beyond Android and iOS (e.g. RPi, Cortex-M)
-
-#### TensorFlow.js:
-* Release package for Node.js bindings to the TensorFlow C API through the TensorFlow.js backend interface
-* Expand support for importing TensorFlow SavedModels and Keras models into browser with unified APIs supporting retraining in browser
-* Improve Layers API and allow model exporting/saving
-* Release tfjs-data API for efficient data input pipelines
-
-#### TensorFlow with Swift:
-* Establish open source project including documentation, open design, and code availability.
-* Continue implementing and refining implementation and design through 2018.
-* Aim for implementation to be solid enough for general use later in 2018.
-
-### Performance
-#### Distributed TensorFlow:
-* Optimize Multi-GPU support for a variety of GPU topologies
-* Improve mechanisms for distributing computations on several machines
-
-#### GPU Optimizations:
-* Simplify mixed precision API with initial example model and guide.
-* Finalize TensorRT API and move to core.
-* CUDA 9.2 and NCCL 2.x default in TensorFlow builds.
-* Optimizations for DGX-2.
-* Remove support for CUDA less than 8.x and cuDNN less than 6.x.
-
-
-#### CPU Optimizations
-* Int8 support for SkyLake via MKL
-* Dynamic loading of SIMD-optimized kernels
-* MKL for Linux and Windows
-
-### End-to-end ML systems:
-#### TensorFlow Hub:
-* Expand support for module-types in TF Hub with TF Eager integration, Keras layers integration, and TensorFlow.js integration
-* Accept variable-sized image input
-* Improve multi-GPU estimator support
-* Document and improve TPU integration
-
-#### TensorFlow Extended:
-* Open source more of the TensorFlow Extended platform to facilitate adoption of TensorFlow in production settings.
-* Release TFX libraries for Data Validation
-
-### Documentation and Resources:
-* Update documentation, tutorials and Getting Started guides on all features and APIs
-* Update the [YouTube TensorFlow channel](https://youtube.com/tensorflow) weekly with new content:
-  * Coding TensorFlow - where we teach folks coding with TensorFlow
-  * TensorFlow Meets - where we highlight community contributions
-  * Ask TensorFlow - where we answer community questions
-  * Guest and Showcase videos
-* Update [Official TensorFlow blog](https://blog.tensorflow.org) with regular articles from Google team and the Community
-
-
-### Community and Partner Engagement
-#### Special Interest Groups:
-* Mobilize the community to work together in focused domains
-* [tf-distribute](https://groups.google.com/a/tensorflow.org/forum/#!forum/tf-distribute): build and packaging of TensorFlow
-* SIG TensorBoard, SIG Rust, and more to be identified and launched
-
-#### Community:
-* Incorporate public feedback on significant design decisions via a Request-for-Comment (RFC) process
-* Formalize process for external contributions to land in TensorFlow and associated projects
-* Grow global TensorFlow communities and user groups
-* Collaborate with partners to co-develop and publish research papers
-* Process to enable external contributions to tutorials, documentation, and blogs showcasing best practice use-cases of TensorFlow and high-impact applications
diff --git a/tensorflow/docs_src/community/style_guide.md b/tensorflow/docs_src/community/style_guide.md
deleted file mode 100644
index daf0d2fdc0..0000000000
--- a/tensorflow/docs_src/community/style_guide.md
+++ /dev/null
@@ -1,136 +0,0 @@
-# TensorFlow Style Guide
-
-This page contains style decisions that both developers and users of TensorFlow
-should follow to increase the readability of their code, reduce the number of
-errors, and promote consistency.
-
-[TOC]
-
-## Python style
-
-Generally follow the
-[PEP8 Python style guide](https://www.python.org/dev/peps/pep-0008/),
-except for using 2 spaces for indentation.
-
-
-## Python 2 and 3 compatible
-
-* All code needs to be compatible with Python 2 and 3.
-
-* The following lines should be present in all Python files:
-
-```
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-```
-
-* Use `six` to write compatible code (for example `six.moves.range`); a short
-  sketch follows below.
-
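-For example, a minimal sketch of compatible code (nothing here beyond the
-`six` package and the `__future__` imports shown above):
-
-```python
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import six
-from six.moves import range  # iterator-based range on Python 2 and 3
-
-for i in range(3):
-  print(i)
-
-for key, value in six.iteritems({'a': 1}):
-  print(key, value)
-```
-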
-
-## Bazel BUILD rules
-
-TensorFlow uses the Bazel build system and enforces the following requirements:
-
-* Every BUILD file should contain the following header:
-
-```
-# Description:
-# <...>
-
-package(
- default_visibility = ["//visibility:private"],
-)
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-```
-
-
-
-* For all Python BUILD targets (libraries and tests) add the following line:
-
-```
-srcs_version = "PY2AND3",
-```
-
-
-## Tensor
-
-* Operations that deal with batches may assume that the first dimension of a Tensor is the batch dimension.
-
-* In most models the *last dimension* is the number of channels.
-
-* Dimensions excluding the first and last usually make up the "space" dimensions: sequence length or image size (see the sketch below).
-
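-For example, a typical image batch follows these conventions (an illustrative
-sketch):
-
-```python
-import tensorflow as tf
-
-# Shape [batch, height, width, channels]: the first dimension is the
-# batch, the last is the number of channels, and the middle two are the
-# "space" dimensions.
-images = tf.placeholder(tf.float32, shape=[None, 224, 224, 3])
-```
-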
-## Python operations
-
-A *Python operation* is a function that, given input tensors and parameters,
-creates a part of the graph and returns output tensors.
-
-* The first arguments should be tensors, followed by basic Python parameters.
-  The last argument is `name` with a default value of `None`.
-  If the operation needs to save some `Tensor`s to Graph collections,
-  put the arguments with the names of the collections right before the `name`
-  argument.
-
-* Tensor arguments should be either a single tensor or an iterable of tensors.
-  For example, a "Tensor or list of Tensors" is too broad. See
-  `assert_proper_iterable`.
-
-* Operations that take tensors as arguments should call `convert_to_tensor`
- to convert non-tensor inputs into tensors if they are using C++ operations.
- Note that the arguments are still described as a `Tensor` object
- of a specific dtype in the documentation.
-
-* Each Python operation should have a `name_scope` like below. Pass as
- arguments `name`, a default name of the op, and a list of the input tensors.
-
-* Operations should contain an extensive Python comment with Args and Returns
- declarations that explain both the type and meaning of each value. Possible
- shapes, dtypes, or ranks should be specified in the description.
- @{$documentation$See documentation details}
-
-* For increased usability, include an example of usage with inputs / outputs
-  of the op in the Example section.
-
-Example:
-
- def my_op(tensor_in, other_tensor_in, my_param, other_param=0.5,
- output_collections=(), name=None):
- """My operation that adds two tensors with given coefficients.
-
- Args:
- tensor_in: `Tensor`, input tensor.
- other_tensor_in: `Tensor`, same shape as `tensor_in`, other input tensor.
- my_param: `float`, coefficient for `tensor_in`.
- other_param: `float`, coefficient for `other_tensor_in`.
- output_collections: `tuple` of `string`s, name of the collection to
- collect result of this op.
- name: `string`, name of the operation.
-
- Returns:
- `Tensor` of same shape as `tensor_in`, sum of input values with coefficients.
-
- Example:
- >>> my_op([1., 2.], [3., 4.], my_param=0.5, other_param=0.6,
- output_collections=['MY_OPS'], name='add_t1t2')
- [2.3, 3.4]
- """
- with tf.name_scope(name, "my_op", [tensor_in, other_tensor_in]):
- tensor_in = tf.convert_to_tensor(tensor_in)
- other_tensor_in = tf.convert_to_tensor(other_tensor_in)
- result = my_param * tensor_in + other_param * other_tensor_in
- tf.add_to_collection(output_collections, result)
- return result
-
-Usage:
-
- output = my_op(t1, t2, my_param=0.5, other_param=0.6,
- output_collections=['MY_OPS'], name='add_t1t2')
-
-
-## Layers
-
-Use `tf.keras.layers`, not `tf.layers`.
-
-See `tf.keras.layers` and [the Keras guide](../guide/keras.md#custom_layers) for details on how to sub-class layers.
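-
-A minimal sketch of a sub-classed layer (illustrative only, not a reviewed
-TensorFlow implementation):
-
-```python
-import tensorflow as tf
-
-
-class Scale(tf.keras.layers.Layer):
-  """Multiplies its input by a single learned scalar."""
-
-  def build(self, input_shape):
-    self.alpha = self.add_weight('alpha', shape=[],
-                                 initializer=tf.ones_initializer())
-    super(Scale, self).build(input_shape)
-
-  def call(self, inputs):
-    return self.alpha * inputs
-```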
diff --git a/tensorflow/docs_src/deploy/deploy_to_js.md b/tensorflow/docs_src/deploy/deploy_to_js.md
deleted file mode 100644
index d7ce3ea90b..0000000000
--- a/tensorflow/docs_src/deploy/deploy_to_js.md
+++ /dev/null
@@ -1,4 +0,0 @@
-# Deploy to JavaScript
-
-You can find details about deploying JavaScript TensorFlow programs
-in the separate [js.tensorflow.org site](https://js.tensorflow.org).
diff --git a/tensorflow/docs_src/deploy/distributed.md b/tensorflow/docs_src/deploy/distributed.md
deleted file mode 100644
index 6a760f53c8..0000000000
--- a/tensorflow/docs_src/deploy/distributed.md
+++ /dev/null
@@ -1,354 +0,0 @@
-# Distributed TensorFlow
-
-This document shows how to create a cluster of TensorFlow servers, and how to
-distribute a computation graph across that cluster. We assume that you are
-familiar with the @{$guide/low_level_intro$basic concepts} of
-writing low level TensorFlow programs.
-
-## Hello distributed TensorFlow!
-
-To see a simple TensorFlow cluster in action, execute the following:
-
-```shell
-# Start a TensorFlow server as a single-process "cluster".
-$ python
->>> import tensorflow as tf
->>> c = tf.constant("Hello, distributed TensorFlow!")
->>> server = tf.train.Server.create_local_server()
->>> sess = tf.Session(server.target) # Create a session on the server.
->>> sess.run(c)
-'Hello, distributed TensorFlow!'
-```
-
-The
-`tf.train.Server.create_local_server`
-method creates a single-process cluster, with an in-process server.
-
-## Create a cluster
-
-<div class="video-wrapper">
- <iframe class="devsite-embedded-youtube-video" data-video-id="la_M6bCV91M"
- data-autohide="1" data-showinfo="0" frameborder="0" allowfullscreen>
- </iframe>
-</div>
-
-A TensorFlow "cluster" is a set of "tasks" that participate in the distributed
-execution of a TensorFlow graph. Each task is associated with a TensorFlow
-"server", which contains a "master" that can be used to create sessions, and a
-"worker" that executes operations in the graph. A cluster can also be divided
-into one or more "jobs", where each job contains one or more tasks.
-
-To create a cluster, you start one TensorFlow server per task in the cluster.
-Each task typically runs on a different machine, but you can run multiple tasks
-on the same machine (e.g. to control different GPU devices). In each task, do
-the following:
-
-1. **Create a `tf.train.ClusterSpec`** that describes all of the tasks
- in the cluster. This should be the same for each task.
-
-2. **Create a `tf.train.Server`**, passing the `tf.train.ClusterSpec` to
- the constructor, and identifying the local task with a job name
- and task index.
-
-
-### Create a `tf.train.ClusterSpec` to describe the cluster
-
-The cluster specification dictionary maps job names to lists of network
-addresses. Pass this dictionary to
-the `tf.train.ClusterSpec`
-constructor. For example:
-
-<table>
- <tr><th><code>tf.train.ClusterSpec</code> construction</th><th>Available tasks</th>
- <tr>
- <td><pre>
-tf.train.ClusterSpec({"local": ["localhost:2222", "localhost:2223"]})
-</pre></td>
-<td><code>/job:local/task:0<br/>/job:local/task:1</code></td>
- </tr>
- <tr>
- <td><pre>
-tf.train.ClusterSpec({
- "worker": [
- "worker0.example.com:2222",
- "worker1.example.com:2222",
- "worker2.example.com:2222"
- ],
- "ps": [
- "ps0.example.com:2222",
- "ps1.example.com:2222"
- ]})
-</pre></td><td><code>/job:worker/task:0</code><br/><code>/job:worker/task:1</code><br/><code>/job:worker/task:2</code><br/><code>/job:ps/task:0</code><br/><code>/job:ps/task:1</code></td>
- </tr>
-</table>
-
-### Create a `tf.train.Server` instance in each task
-
-A `tf.train.Server` object contains a
-set of local devices, a set of connections to other tasks in its
-`tf.train.ClusterSpec`, and a
-`tf.Session` that can use these
-to perform a distributed computation. Each server is a member of a specific
-named job and has a task index within that job. A server can communicate with
-any other server in the cluster.
-
-For example, to launch a cluster with two servers running on `localhost:2222`
-and `localhost:2223`, run the following snippets in two different processes on
-the local machine:
-
-```python
-# In task 0:
-cluster = tf.train.ClusterSpec({"local": ["localhost:2222", "localhost:2223"]})
-server = tf.train.Server(cluster, job_name="local", task_index=0)
-```
-```python
-# In task 1:
-cluster = tf.train.ClusterSpec({"local": ["localhost:2222", "localhost:2223"]})
-server = tf.train.Server(cluster, job_name="local", task_index=1)
-```
-
-**Note:** Manually specifying these cluster specifications can be tedious,
-especially for large clusters. We are working on tools for launching tasks
-programmatically, e.g. using a cluster manager like
-[Kubernetes](http://kubernetes.io). If there are particular cluster managers for
-which you'd like to see support, please raise a
-[GitHub issue](https://github.com/tensorflow/tensorflow/issues).
-
-## Specifying distributed devices in your model
-
-To place operations on a particular process, you can use the same
-`tf.device`
-function that is used to specify whether ops run on the CPU or GPU. For example:
-
-```python
-with tf.device("/job:ps/task:0"):
- weights_1 = tf.Variable(...)
- biases_1 = tf.Variable(...)
-
-with tf.device("/job:ps/task:1"):
- weights_2 = tf.Variable(...)
- biases_2 = tf.Variable(...)
-
-with tf.device("/job:worker/task:7"):
- input, labels = ...
- layer_1 = tf.nn.relu(tf.matmul(input, weights_1) + biases_1)
- logits = tf.nn.relu(tf.matmul(layer_1, weights_2) + biases_2)
- # ...
- train_op = ...
-
-with tf.Session("grpc://worker7.example.com:2222") as sess:
- for _ in range(10000):
- sess.run(train_op)
-```
-
-In the above example, the variables are created on two tasks in the `ps` job,
-and the compute-intensive part of the model is created in the `worker`
-job. TensorFlow will insert the appropriate data transfers between the jobs
-(from `ps` to `worker` for the forward pass, and from `worker` to `ps` for
-applying gradients).
-
-## Replicated training
-
-A common training configuration, called "data parallelism," involves multiple
-tasks in a `worker` job training the same model on different mini-batches of
-data, updating shared parameters hosted in one or more tasks in a `ps`
-job. All tasks typically run on different machines. There are many ways to
-specify this structure in TensorFlow, and we are building libraries that will
-simplify the work of specifying a replicated model. Possible approaches include:
-
-* **In-graph replication.** In this approach, the client builds a single
-  `tf.Graph` that contains one set of parameters (in `tf.Variable` nodes pinned
-  to `/job:ps`); and multiple copies of the compute-intensive part of the model,
-  each pinned to a different task in `/job:worker` (see the sketch after this
-  list).
-
-* **Between-graph replication.** In this approach, there is a separate client
- for each `/job:worker` task, typically in the same process as the worker
- task. Each client builds a similar graph containing the parameters (pinned to
- `/job:ps` as before using
- `tf.train.replica_device_setter`
- to map them deterministically to the same tasks); and a single copy of the
- compute-intensive part of the model, pinned to the local task in
- `/job:worker`.
-
-* **Asynchronous training.** In this approach, each replica of the graph has an
- independent training loop that executes without coordination. It is compatible
- with both forms of replication above.
-
-* **Synchronous training.** In this approach, all of the replicas read the same
- values for the current parameters, compute gradients in parallel, and then
- apply them together. It is compatible with in-graph replication (e.g. using
- gradient averaging as in the
- [CIFAR-10 multi-GPU trainer](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py)),
- and between-graph replication (e.g. using the
- `tf.train.SyncReplicasOptimizer`).
-
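-For instance, in-graph replication can be sketched as follows (illustrative
-only; the device names assume a cluster like the ones above):
-
-```python
-import tensorflow as tf
-
-# One set of parameters, pinned to the parameter server.
-with tf.device("/job:ps/task:0"):
-  weights = tf.Variable(tf.zeros([784, 10]))
-
-# Multiple copies of the compute-intensive part, one per worker task.
-outputs = []
-for i in range(2):
-  with tf.device("/job:worker/task:%d" % i):
-    inputs = tf.placeholder(tf.float32, shape=[None, 784])
-    outputs.append(tf.matmul(inputs, weights))
-```
-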
-### Putting it all together: example trainer program
-
-The following code shows the skeleton of a distributed trainer program,
-implementing **between-graph replication** and **asynchronous training**. It
-includes the code for the parameter server and worker tasks.
-
-```python
-import argparse
-import sys
-
-import tensorflow as tf
-
-FLAGS = None
-
-
-def main(_):
- ps_hosts = FLAGS.ps_hosts.split(",")
- worker_hosts = FLAGS.worker_hosts.split(",")
-
- # Create a cluster from the parameter server and worker hosts.
- cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
-
- # Create and start a server for the local task.
- server = tf.train.Server(cluster,
- job_name=FLAGS.job_name,
- task_index=FLAGS.task_index)
-
- if FLAGS.job_name == "ps":
- server.join()
- elif FLAGS.job_name == "worker":
-
- # Assigns ops to the local worker by default.
- with tf.device(tf.train.replica_device_setter(
- worker_device="/job:worker/task:%d" % FLAGS.task_index,
- cluster=cluster)):
-
- # Build model...
- loss = ...
- global_step = tf.contrib.framework.get_or_create_global_step()
-
- train_op = tf.train.AdagradOptimizer(0.01).minimize(
- loss, global_step=global_step)
-
- # The StopAtStepHook handles stopping after running given steps.
-      hooks = [tf.train.StopAtStepHook(last_step=1000000)]
-
- # The MonitoredTrainingSession takes care of session initialization,
- # restoring from a checkpoint, saving to a checkpoint, and closing when done
- # or an error occurs.
- with tf.train.MonitoredTrainingSession(master=server.target,
- is_chief=(FLAGS.task_index == 0),
- checkpoint_dir="/tmp/train_logs",
- hooks=hooks) as mon_sess:
- while not mon_sess.should_stop():
- # Run a training step asynchronously.
- # See `tf.train.SyncReplicasOptimizer` for additional details on how to
- # perform *synchronous* training.
- # mon_sess.run handles AbortedError in case of preempted PS.
- mon_sess.run(train_op)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.register("type", "bool", lambda v: v.lower() == "true")
- # Flags for defining the tf.train.ClusterSpec
- parser.add_argument(
- "--ps_hosts",
- type=str,
- default="",
- help="Comma-separated list of hostname:port pairs"
- )
- parser.add_argument(
- "--worker_hosts",
- type=str,
- default="",
- help="Comma-separated list of hostname:port pairs"
- )
- parser.add_argument(
- "--job_name",
- type=str,
- default="",
- help="One of 'ps', 'worker'"
- )
- # Flags for defining the tf.train.Server
- parser.add_argument(
- "--task_index",
- type=int,
- default=0,
- help="Index of task within the job"
- )
- FLAGS, unparsed = parser.parse_known_args()
- tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
-```
-
-To start the trainer with two parameter servers and two workers, use the
-following command line (assuming the script is called `trainer.py`):
-
-```shell
-# On ps0.example.com:
-$ python trainer.py \
- --ps_hosts=ps0.example.com:2222,ps1.example.com:2222 \
- --worker_hosts=worker0.example.com:2222,worker1.example.com:2222 \
- --job_name=ps --task_index=0
-# On ps1.example.com:
-$ python trainer.py \
- --ps_hosts=ps0.example.com:2222,ps1.example.com:2222 \
- --worker_hosts=worker0.example.com:2222,worker1.example.com:2222 \
- --job_name=ps --task_index=1
-# On worker0.example.com:
-$ python trainer.py \
- --ps_hosts=ps0.example.com:2222,ps1.example.com:2222 \
- --worker_hosts=worker0.example.com:2222,worker1.example.com:2222 \
- --job_name=worker --task_index=0
-# On worker1.example.com:
-$ python trainer.py \
- --ps_hosts=ps0.example.com:2222,ps1.example.com:2222 \
- --worker_hosts=worker0.example.com:2222,worker1.example.com:2222 \
- --job_name=worker --task_index=1
-```
-
-## Glossary
-
-**Client**
-
-A client is typically a program that builds a TensorFlow graph and constructs a
-`tensorflow::Session` to interact with a cluster. Clients are typically written
-in Python or C++. A single client process can directly interact with multiple
-TensorFlow servers (see "Replicated training" above), and a single server can
-serve multiple clients.
-
-**Cluster**
-
-A TensorFlow cluster comprises one or more "jobs", each divided into lists of
-one or more "tasks". A cluster is typically dedicated to a particular high-level
-objective, such as training a neural network, using many machines in parallel. A
-cluster is defined by
-a `tf.train.ClusterSpec` object.
-
-**Job**
-
-A job comprises a list of "tasks", which typically serve a common purpose.
-For example, a job named `ps` (for "parameter server") typically hosts nodes
-that store and update variables; while a job named `worker` typically hosts
-stateless nodes that perform compute-intensive tasks. The tasks in a job
-typically run on different machines. The set of job roles is flexible:
-for example, a `worker` may maintain some state.
-
-**Master service**
-
-An RPC service that provides remote access to a set of distributed devices,
-and acts as a session target. The master service implements the
-`tensorflow::Session` interface, and is responsible for coordinating work across
-one or more "worker services". All TensorFlow servers implement the master
-service.
-
-**Task**
-
-A task corresponds to a specific TensorFlow server, and typically corresponds
-to a single process. A task belongs to a particular "job" and is identified by
-its index within that job's list of tasks.
-
-**TensorFlow server**
-
-A process running a `tf.train.Server` instance, which is a member of a
-cluster, and exports a "master service" and "worker service".
-
-**Worker service**
-
-An RPC service that executes parts of a TensorFlow graph using its local devices.
-A worker service implements [worker_service.proto](https://www.tensorflow.org/code/tensorflow/core/protobuf/worker_service.proto).
-All TensorFlow servers implement the worker service.
diff --git a/tensorflow/docs_src/deploy/hadoop.md b/tensorflow/docs_src/deploy/hadoop.md
deleted file mode 100644
index c4471562b9..0000000000
--- a/tensorflow/docs_src/deploy/hadoop.md
+++ /dev/null
@@ -1,65 +0,0 @@
-# How to run TensorFlow on Hadoop
-
-This document describes how to run TensorFlow on Hadoop. It will be expanded to
-describe running on various cluster managers, but only describes running on HDFS
-at the moment.
-
-## HDFS
-
-We assume that you are familiar with @{$reading_data$reading data}.
-
-To use HDFS with TensorFlow, change the file paths you use to read and write
-data to an HDFS path. For example:
-
-```python
-filename_queue = tf.train.string_input_producer([
- "hdfs://namenode:8020/path/to/file1.csv",
- "hdfs://namenode:8020/path/to/file2.csv",
-])
-```
-
-If you want to use the namenode specified in your HDFS configuration files, then
-change the file prefix to `hdfs://default/`.
-
-When launching your TensorFlow program, the following environment variables must
-be set:
-
-* **JAVA_HOME**: The location of your Java installation.
-* **HADOOP_HDFS_HOME**: The location of your HDFS installation. You can also
- set this environment variable by running:
-
- ```shell
- source ${HADOOP_HOME}/libexec/hadoop-config.sh
- ```
-
-* **LD_LIBRARY_PATH**: To include the path to `libjvm.so`, and optionally the path
-  to `libhdfs.so` if your Hadoop distribution does not install `libhdfs.so` in
-  `$HADOOP_HDFS_HOME/lib/native`. On Linux:
-
- ```shell
- export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${JAVA_HOME}/jre/lib/amd64/server
- ```
-
-* **CLASSPATH**: The Hadoop jars must be added prior to running your
- TensorFlow program. The CLASSPATH set by
- `${HADOOP_HOME}/libexec/hadoop-config.sh` is insufficient. Globs must be
- expanded as described in the libhdfs documentation:
-
- ```shell
- CLASSPATH=$(${HADOOP_HDFS_HOME}/bin/hadoop classpath --glob) python your_script.py
- ```
-  For older versions of Hadoop/libhdfs (older than 2.6.0), you have to expand the
- classpath wildcard manually. For more details, see
- [HADOOP-10903](https://issues.apache.org/jira/browse/HADOOP-10903).
-
-If the Hadoop cluster is in secure mode, the following environment variable must
-be set:
-
-* **KRB5CCNAME**: The path of Kerberos ticket cache file. For example:
-
- ```shell
- export KRB5CCNAME=/tmp/krb5cc_10002
- ```
-
-If you are running @{$distributed$Distributed TensorFlow}, then all
-workers must have the environment variables set and Hadoop installed.
diff --git a/tensorflow/docs_src/deploy/index.md b/tensorflow/docs_src/deploy/index.md
deleted file mode 100644
index 3322004189..0000000000
--- a/tensorflow/docs_src/deploy/index.md
+++ /dev/null
@@ -1,21 +0,0 @@
-# Deploy
-
-This section focuses on deploying real-world models. It contains
-the following documents:
-
- * @{$distributed$Distributed TensorFlow}, which explains how to create
- a cluster of TensorFlow servers.
- * @{$hadoop$How to run TensorFlow on Hadoop}, which has a highly
- self-explanatory title.
- * @{$s3$How to run TensorFlow with the S3 filesystem}, which explains how
- to run TensorFlow with the S3 file system.
- * The entire document set for [TensorFlow serving](/serving), an open-source,
- flexible, high-performance serving system for machine-learned models
- designed for production environments. TensorFlow Serving provides
- out-of-the-box integration with TensorFlow models.
- [Source code for TensorFlow Serving](https://github.com/tensorflow/serving)
- is available on GitHub.
-
-[TensorFlow Extended (TFX)](/tfx) is an end-to-end machine learning platform for
-TensorFlow. Implemented at Google, we've open sourced some TFX libraries with the
-rest of the system to come.
diff --git a/tensorflow/docs_src/deploy/leftnav_files b/tensorflow/docs_src/deploy/leftnav_files
deleted file mode 100644
index 93f5bd1ed2..0000000000
--- a/tensorflow/docs_src/deploy/leftnav_files
+++ /dev/null
@@ -1,5 +0,0 @@
-index.md
-distributed.md
-hadoop.md
-s3.md
-deploy_to_js.md
diff --git a/tensorflow/docs_src/deploy/s3.md b/tensorflow/docs_src/deploy/s3.md
deleted file mode 100644
index 079c796aa7..0000000000
--- a/tensorflow/docs_src/deploy/s3.md
+++ /dev/null
@@ -1,93 +0,0 @@
-# How to run TensorFlow on S3
-
-TensorFlow supports reading and writing data to S3. S3 is an object storage API which is nearly ubiquitous, and can help in situations where data must be accessed by multiple actors, such as in distributed training.
-
-This document guides you through the required setup, and provides examples of usage.
-
-## Configuration
-
-When reading or writing data on S3 with your TensorFlow program, the behavior
-can be controlled by various environmental variables:
-
-* **AWS_REGION**: By default, the regional endpoint is used for S3, with the
-  region controlled by `AWS_REGION`. If `AWS_REGION` is not specified, then
- `us-east-1` is used.
-* **S3_ENDPOINT**: The endpoint can be overridden explicitly by setting
-  `S3_ENDPOINT`.
-* **S3_USE_HTTPS**: HTTPS is used to access S3 by default, unless
- `S3_USE_HTTPS=0`.
-* **S3_VERIFY_SSL**: If HTTPS is used, SSL verification can be disabled
-  with `S3_VERIFY_SSL=0`.
-
-To read or write objects in a bucket that is not publicly accessible,
-AWS credentials must be provided through one of the following methods:
-
-* Set credentials in the AWS credentials profile file on the local system,
- located at: `~/.aws/credentials` on Linux, macOS, or Unix, or
- `C:\Users\USERNAME\.aws\credentials` on Windows.
-* Set the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment
- variables.
-* If TensorFlow is deployed on an EC2 instance, specify an IAM role and then
- give the EC2 instance access to that role.
-
-## Example Setup
-
-Using the above information, we can configure TensorFlow to communicate with an S3 endpoint by setting the following environment variables:
-
-```bash
-AWS_ACCESS_KEY_ID=XXXXX # Credentials only needed if connecting to a private endpoint
-AWS_SECRET_ACCESS_KEY=XXXXX
-AWS_REGION=us-east-1 # Region for the S3 bucket, this is not always needed. Default is us-east-1.
-S3_ENDPOINT=s3.us-east-1.amazonaws.com # The S3 API Endpoint to connect to. This is specified in a HOST:PORT format.
-S3_USE_HTTPS=1 # Whether or not to use HTTPS. Disable with 0.
-S3_VERIFY_SSL=1 # If HTTPS is used, controls if SSL should be enabled. Disable with 0.
-```
-
-## Usage
-
-Once setup is completed, TensorFlow can interact with S3 in a variety of ways. Anywhere there is a TensorFlow IO function, an S3 URL can be used.
-
-### Smoke Test
-
-To test your setup, stat a file:
-
-```python
-from tensorflow.python.lib.io import file_io
-print(file_io.stat('s3://bucketname/path/'))
-```
-
-You should see output similar to this:
-
-```console
-<tensorflow.python.pywrap_tensorflow_internal.FileStatistics; proxy of <Swig Object of type 'tensorflow::FileStatistics *' at 0x10c2171b0> >
-```
-
-### Reading Data
-
-When @{$reading_data$reading data}, change the file paths you use to read and write
-data to an S3 path. For example:
-
-```python
-filenames = ["s3://bucketname/path/to/file1.tfrecord",
- "s3://bucketname/path/to/file2.tfrecord"]
-dataset = tf.data.TFRecordDataset(filenames)
-```
-
-### Tensorflow Tools
-
-Many TensorFlow tools, such as TensorBoard or model serving, can also take S3 URLs as arguments:
-
-```bash
-tensorboard --logdir s3://bucketname/path/to/model/
-tensorflow_model_server --port=9000 --model_name=model --model_base_path=s3://bucketname/path/to/model/export/
-```
-
-This enables an end-to-end workflow using S3 for all data needs.
-
-## S3 Endpoint Implementations
-
-S3 was invented by Amazon, but the S3 API has spread in popularity and has several implementations. The following implementations have passed basic compatibility tests:
-
-* [Amazon S3](https://aws.amazon.com/s3/)
-* [Google Storage](https://cloud.google.com/storage/docs/interoperability)
-* [Minio](https://www.minio.io/kubernetes.html)
diff --git a/tensorflow/docs_src/extend/add_filesys.md b/tensorflow/docs_src/extend/add_filesys.md
deleted file mode 100644
index bc0f662f0c..0000000000
--- a/tensorflow/docs_src/extend/add_filesys.md
+++ /dev/null
@@ -1,260 +0,0 @@
-# Adding a Custom Filesystem Plugin
-
-## Background
-
-The TensorFlow framework is often used in multi-process and
-multi-machine environments, such as Google data centers, Google Cloud
-Machine Learning, Amazon Web Services (AWS), and on-site distributed clusters.
-In order to both share and save certain types of state produced by TensorFlow,
-the framework assumes the existence of a reliable, shared filesystem. This
-shared filesystem has numerous uses, for example:
-
-* Checkpoints of state are often saved to a distributed filesystem for
- reliability and fault-tolerance.
-* Training processes communicate with TensorBoard by writing event files
- to a directory, which TensorBoard watches. A shared filesystem allows this
- communication to work even when TensorBoard runs in a different process or
- machine.
-
-There are many different implementations of shared or distributed filesystems in
-the real world, so TensorFlow provides the ability for users to implement a
-custom FileSystem plugin that can be registered with the TensorFlow runtime.
-When the TensorFlow runtime attempts to write to a file through the `FileSystem`
-interface, it uses a portion of the pathname to dynamically select the
-implementation that should be used for filesystem operations. Thus, adding
-support for your custom filesystem requires implementing a `FileSystem`
-interface, building a shared object containing that implementation, and loading
-that object at runtime in whichever process needs to write to that filesystem.
-
-Note that TensorFlow already includes many filesystem implementations, such as:
-
-* A standard POSIX filesystem
-
- Note: NFS filesystems often mount as a POSIX interface, and so standard
- TensorFlow can work on top of NFS-mounted remote filesystems.
-
-* HDFS - the Hadoop File System
-* GCS - Google Cloud Storage filesystem
-* S3 - Amazon Simple Storage Service filesystem
-* A "memory-mapped-file" filesystem
-
-The rest of this guide describes how to implement a custom filesystem.
-
-## Implementing a custom filesystem plugin
-
-To implement a custom filesystem plugin, you must do the following:
-
-* Implement subclasses of `RandomAccessFile`, `WriteableFile`,
- `AppendableFile`, and `ReadOnlyMemoryRegion`.
-* Implement the `FileSystem` interface as a subclass.
-* Register the `FileSystem` implementation with an appropriate prefix pattern.
-* Load the filesystem plugin in a process that wants to write to that
- filesystem.
-
-### The FileSystem interface
-
-The `FileSystem` interface is an abstract C++ interface defined in
-[file_system.h](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/file_system.h).
-An implementation of the `FileSystem` interface should implement all of the
-relevant methods defined by the interface. Implementing the interface requires
-defining factory operations that create objects such as `RandomAccessFile` and
-`WritableFile`, and implementing standard filesystem operations such as
-`FileExists`, `IsDirectory`, `GetMatchingPaths`, `DeleteFile`, and so on. An
-implementation of these interfaces will often involve translating the
-function's input arguments so that they delegate to an already-existing library
-function that implements the equivalent functionality in your custom
-filesystem.
-
-For example, the `PosixFileSystem` implementation implements `DeleteFile` using
-the POSIX `unlink()` function; `CreateDir` simply calls `mkdir()`; and
-`GetFileSize` calls `stat()` on the file and then returns the file size
-reported by the stat structure. Similarly, for the `HDFSFileSystem`
-implementation, these calls simply delegate to the `libHDFS` functions with
-similar functionality, such as `hdfsDelete` for
-[DeleteFile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/hadoop/hadoop_file_system.cc#L386).
-
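-To make this delegation pattern concrete, here is a minimal sketch in the
-spirit of the POSIX plugin (not the verbatim source); `TranslateName()` is the
-`FileSystem` helper that strips the scheme portion from a path:
-
-```C++
-Status PosixFileSystem::DeleteFile(const string& fname) {
-  Status result;
-  // Delegate the actual deletion to the POSIX unlink() call.
-  if (unlink(TranslateName(fname).c_str()) != 0) {
-    result = IOError(fname, errno);
-  }
-  return result;
-}
-```
-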
-We suggest looking through these code examples to get an idea of how different
-filesystem implementations call their existing libraries. Examples include:
-
-* [POSIX
- plugin](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/posix/posix_file_system.h)
-* [HDFS
- plugin](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/hadoop/hadoop_file_system.h)
-* [GCS
- plugin](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/cloud/gcs_file_system.h)
-* [S3
- plugin](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/s3/s3_file_system.h)
-
-#### The File interfaces
-
-Beyond operations that allow you to query and manipulate files and directories
-in a filesystem, the `FileSystem` interface requires you to implement factories
-that return implementations of abstract objects such as the
-[RandomAccessFile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/file_system.h#L223),
-and the `WritableFile`, so that TensorFlow code can read and write files in
-that `FileSystem` implementation.
-
-To implement a `RandomAccessFile`, you must implement a single method called
-`Read()`, which must provide a way to read from an offset within a named
-file.
-
-For example, below is the implementation of RandomAccessFile for the POSIX
-filesystem, which uses the `pread()` random-access POSIX function to implement
-read. Notice that the particular implementation must know how to retry or
-propagate errors from the underlying filesystem.
-
-```C++
- class PosixRandomAccessFile : public RandomAccessFile {
- public:
- PosixRandomAccessFile(const string& fname, int fd)
- : filename_(fname), fd_(fd) {}
- ~PosixRandomAccessFile() override { close(fd_); }
-
- Status Read(uint64 offset, size_t n, StringPiece* result,
- char* scratch) const override {
- Status s;
- char* dst = scratch;
- while (n > 0 && s.ok()) {
- ssize_t r = pread(fd_, dst, n, static_cast<off_t>(offset));
- if (r > 0) {
- dst += r;
- n -= r;
- offset += r;
- } else if (r == 0) {
- s = Status(error::OUT_OF_RANGE, "Read less bytes than requested");
- } else if (errno == EINTR || errno == EAGAIN) {
- // Retry
- } else {
- s = IOError(filename_, errno);
- }
- }
- *result = StringPiece(scratch, dst - scratch);
- return s;
- }
-
- private:
- string filename_;
- int fd_;
- };
-```
-
-To implement the WritableFile sequential-writing abstraction, one must
-implement a few methods, such as `Append()`, `Flush()`, `Sync()`, and
-`Close()`.
-
-For example, below is the implementation of WritableFile for the POSIX
-filesystem, which takes a `FILE` object in its constructor and uses standard
-POSIX functions on that object to implement the interface.
-
-```C++
- class PosixWritableFile : public WritableFile {
- public:
- PosixWritableFile(const string& fname, FILE* f)
- : filename_(fname), file_(f) {}
-
- ~PosixWritableFile() override {
- if (file_ != NULL) {
- fclose(file_);
- }
- }
-
- Status Append(const StringPiece& data) override {
- size_t r = fwrite(data.data(), 1, data.size(), file_);
- if (r != data.size()) {
- return IOError(filename_, errno);
- }
- return Status::OK();
- }
-
- Status Close() override {
- Status result;
- if (fclose(file_) != 0) {
- result = IOError(filename_, errno);
- }
- file_ = NULL;
- return result;
- }
-
- Status Flush() override {
- if (fflush(file_) != 0) {
- return IOError(filename_, errno);
- }
- return Status::OK();
- }
-
- Status Sync() override {
- Status s;
- if (fflush(file_) != 0) {
- s = IOError(filename_, errno);
- }
- return s;
- }
-
- private:
- string filename_;
- FILE* file_;
- };
-```
-
-For more details, please see the documentation of these interfaces, and look
-at the example implementations for inspiration.
-
-### Registering and loading the filesystem
-
-Once you have implemented the `FileSystem` implementation for your custom
-filesystem, you need to register it under a "scheme" so that paths prefixed with
-that scheme are directed to your implementation. To do this, you call
-`REGISTER_FILE_SYSTEM`:
-
-```C++
- REGISTER_FILE_SYSTEM("foobar", FooBarFileSystem);
-```
-
-When TensorFlow tries to operate on a file whose path starts with `foobar://`,
-it will use the `FooBarFileSystem` implementation.
-
-```C++
- string filename = "foobar://path/to/file.txt";
- std::unique_ptr<WritableFile> file;
-
- // Calls FooBarFileSystem::NewWritableFile to return
- // a WritableFile class, which happens to be the FooBarFileSystem's
- // WritableFile implementation.
- TF_RETURN_IF_ERROR(env->NewWritableFile(filename, &file));
-```
-
-Next, you must build a shared object containing this implementation. An example
-of doing so using bazel's `cc_binary` rule can be found
-[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/BUILD#L244),
-but you may use any build system to do so. See the section on @{$adding_an_op#build_the_op_library$building the op library} for similar
-instructions.
-
-The result of building this target is a `.so` shared object file.
-
-Lastly, you must dynamically load this implementation in the process. In Python,
-you can call the `tf.load_file_system_library(file_system_library)` function,
-passing the path to the shared object. Calling this in your client program loads
-the shared object in the process, thus registering your implementation as
-available for any file operations going through the `FileSystem` interface. You
-can see
-[test_file_system.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/file_system_test.py)
-for an example.
-
-## What goes through this interface?
-
-Almost all core C++ file operations within TensorFlow use the `FileSystem`
-interface, such as the `CheckpointWriter`, the `EventsWriter`, and many other
-utilities. This means implementing a `FileSystem` implementation allows most of
-your TensorFlow programs to write to your shared filesystem.
-
-In Python, the `gfile` and `file_io` classes bind underneath to the
-`FileSystem` implementation via SWIG, which means that once you have loaded
-this filesystem library, you can do:
-
-```python
-with gfile.Open("foobar://path/to/file.txt") as w:
-  w.write("hi")
-```
-
-When you do this, a file containing "hi" will appear at "/path/to/file.txt" on
-your shared filesystem.
diff --git a/tensorflow/docs_src/extend/adding_an_op.md b/tensorflow/docs_src/extend/adding_an_op.md
deleted file mode 100644
index fbf5c0b90d..0000000000
--- a/tensorflow/docs_src/extend/adding_an_op.md
+++ /dev/null
@@ -1,1460 +0,0 @@
-# Adding a New Op
-
-Note: By default [www.tensorflow.org](https://www.tensorflow.org) shows docs for the
-most recent stable version. The instructions in this doc require building from
-source. You will probably want to build from the `master` version of TensorFlow.
-You should, as a result, be sure you are following the
-[`master` version of this doc](https://www.tensorflow.org/versions/master/extend/adding_an_op),
-in case there have been any changes.
-
-If you'd like to create an op that isn't covered by the existing TensorFlow
-library, we recommend that you first try writing the op in Python as
-a composition of existing Python ops or functions. If that isn't possible, you
-can create a custom C++ op. There are several reasons why you might want to
-create a custom C++ op:
-
-* It's not easy or possible to express your operation as a composition of
- existing ops.
-* It's not efficient to express your operation as a composition of existing
- primitives.
-* You want to hand-fuse a composition of primitives that a future compiler
-  would find difficult to fuse.
-
-For example, imagine you want to implement something like "median pooling",
-similar to the "MaxPool" operator, but computing medians over sliding windows
-instead of maximum values. Doing this using a composition of operations may be
-possible (e.g., using `ExtractImagePatches` and `TopK`), but may not be as
-performance- or memory-efficient as a native operation where you can do
-something more clever in a single, fused operation. As always, it is typically
-worth first trying to express what you want using operator composition, and
-only choosing to add a new operation if that proves to be difficult or
-inefficient.
-
-To incorporate your custom op you'll need to:
-
-1. Register the new op in a C++ file. Op registration defines an interface
- (specification) for the op's functionality, which is independent of the
- op's implementation. For example, op registration defines the op's name and
- the op's inputs and outputs. It also defines the shape function
- that is used for tensor shape inference.
-2. Implement the op in C++. The implementation of an op is known
- as a kernel, and it is the concrete implementation of the specification you
- registered in Step 1. There can be multiple kernels for different input /
- output types or architectures (for example, CPUs, GPUs).
-3. Create a Python wrapper (optional). This wrapper is the public API that's
- used to create the op in Python. A default wrapper is generated from the
- op registration, which can be used directly or added to.
-4. Write a function to compute gradients for the op (optional).
-5. Test the op. We usually do this in Python for convenience, but you can also
- test the op in C++. If you define gradients, you can verify them with the
- Python `tf.test.compute_gradient_error`.
- See
- [`relu_op_test.py`](https://www.tensorflow.org/code/tensorflow/python/kernel_tests/relu_op_test.py) as
- an example that tests the forward functions of Relu-like operators and
- their gradients.
-
-PREREQUISITES:
-
-* Some familiarity with C++.
-* Must have installed the
- @{$install$TensorFlow binary}, or must have
- @{$install_sources$downloaded TensorFlow source},
- and be able to build it.
-
-[TOC]
-
-## Define the op's interface
-
-You define the interface of an op by registering it with the TensorFlow system.
-In the registration, you specify the name of your op, its inputs (types and
-names) and outputs (types and names), as well as docstrings and
-any [attrs](#attrs) the op might require.
-
-To see how this works, suppose you'd like to create an op that takes a tensor of
-`int32`s and outputs a copy of the tensor, with all but the first element set to
-zero. To do this, create a file named `zero_out.cc`. Then add a call to the
-`REGISTER_OP` macro that defines the interface for your op:
-
-```c++
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/shape_inference.h"
-
-using namespace tensorflow;
-
-REGISTER_OP("ZeroOut")
- .Input("to_zero: int32")
- .Output("zeroed: int32")
- .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
- c->set_output(0, c->input(0));
- return Status::OK();
- });
-```
-
-This `ZeroOut` op takes one tensor `to_zero` of 32-bit integers as input, and
-outputs a tensor `zeroed` of 32-bit integers. The op also uses a shape function
-to ensure that the output tensor is the same shape as the input tensor. For
-example, if the input is a tensor of shape [10, 20], then this shape function
-specifies that the output shape is also [10, 20].
-
-
-> A note on naming: The op name must be in CamelCase and it must be unique
-> among all other ops that are registered in the binary.
-
-## Implement the kernel for the op
-
-After you define the interface, provide one or more implementations of the op.
-To create one of these kernels, create a class that extends `OpKernel` and
-overrides the `Compute` method. The `Compute` method provides one `context`
-argument of type `OpKernelContext*`, from which you can access useful things
-like the input and output tensors.
-
-Add your kernel to the file you created above. The kernel might look something
-like this:
-
-```c++
-#include "tensorflow/core/framework/op_kernel.h"
-
-using namespace tensorflow;
-
-class ZeroOutOp : public OpKernel {
- public:
- explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) override {
- // Grab the input tensor
- const Tensor& input_tensor = context->input(0);
- auto input = input_tensor.flat<int32>();
-
- // Create an output tensor
- Tensor* output_tensor = NULL;
- OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
- &output_tensor));
- auto output_flat = output_tensor->flat<int32>();
-
- // Set all but the first element of the output tensor to 0.
- const int N = input.size();
- for (int i = 1; i < N; i++) {
- output_flat(i) = 0;
- }
-
- // Preserve the first input value if possible.
- if (N > 0) output_flat(0) = input(0);
- }
-};
-```
-
-After implementing your kernel, you register it with the TensorFlow system. In
-the registration, you specify different constraints under which this kernel
-will run. For example, you might have one kernel made for CPUs, and a separate
-one for GPUs.
-
-To do this for the `ZeroOut` op, add the following to `zero_out.cc`:
-
-```c++
-REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
-```
-
-> Important: Instances of your OpKernel may be accessed concurrently.
-> Your `Compute` method must be thread-safe. Guard any access to class
-> members with a mutex. Or better yet, don't share state via class members!
-> Consider using a [`ResourceMgr`](https://www.tensorflow.org/code/tensorflow/core/framework/resource_mgr.h)
-> to keep track of op state.
-
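-For illustration, here is a minimal sketch (a hypothetical `StatefulCounterOp`,
-not an op from the TensorFlow codebase) of guarding mutable kernel state with a
-mutex:
-
-```c++
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/platform/mutex.h"
-
-using namespace tensorflow;
-
-class StatefulCounterOp : public OpKernel {
- public:
-  explicit StatefulCounterOp(OpKernelConstruction* context)
-      : OpKernel(context) {}
-
-  void Compute(OpKernelContext* context) override {
-    mutex_lock l(mu_);  // Serializes concurrent Compute() calls.
-    ++calls_;
-    // ... produce outputs as usual ...
-  }
-
- private:
-  mutex mu_;
-  int64 calls_ = 0;  // Accessed only while holding mu_.
-};
-```
-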
-### Multi-threaded CPU kernels
-
-To write a multi-threaded CPU kernel, you can use the `Shard` function in
-[`work_sharder.h`](https://www.tensorflow.org/code/tensorflow/core/util/work_sharder.h).
-This function shards a computation across the
-threads configured to be used for intra-op threading (see
-`intra_op_parallelism_threads` in
-[`config.proto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)).
-
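-A minimal sketch, assuming the input/output setup from the `ZeroOut` kernel
-above and an illustrative per-element cost estimate:
-
-```c++
-#include "tensorflow/core/util/work_sharder.h"
-
-  void Compute(OpKernelContext* context) override {
-    // ... grab the input and allocate output_flat as in the ZeroOut kernel ...
-    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
-    const int64 total = input.size();
-    // The cost-per-unit argument (here 10 cycles per element) is a rough
-    // estimate that controls how finely the work is sharded.
-    Shard(worker_threads.num_threads, worker_threads.workers, total, 10,
-          [&output_flat](int64 start, int64 end) {
-            for (int64 i = start; i < end; ++i) {
-              output_flat(i) = 0;
-            }
-          });
-  }
-```
-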
-### GPU kernels
-
-A GPU kernel is implemented in two parts: the OpKernel, and the CUDA kernel
-together with its launch code.
-
-Sometimes the OpKernel implementation is common between a CPU and GPU kernel,
-such as the code that inspects inputs and allocates outputs. In that case, a
-suggested implementation is to:
-
-1. Define the OpKernel templated on the Device and the primitive type of the
- tensor.
-2. To do the actual computation of the output, the Compute function calls a
- templated functor struct.
-3. The specialization of that functor for the CPUDevice is defined in the same
- file, but the specialization for the GPUDevice is defined in a .cu.cc file,
- since it will be compiled with the CUDA compiler.
-
-Here is an example implementation.
-
-```c++
-// kernel_example.h
-#ifndef KERNEL_EXAMPLE_H_
-#define KERNEL_EXAMPLE_H_
-
-template <typename Device, typename T>
-struct ExampleFunctor {
- void operator()(const Device& d, int size, const T* in, T* out);
-};
-
-#if GOOGLE_CUDA
-// Partially specialize functor for GpuDevice.
-template <typename T>
-struct ExampleFunctor<Eigen::GpuDevice, T> {
-  void operator()(const Eigen::GpuDevice& d, int size, const T* in, T* out);
-};
-#endif
-
-#endif  // KERNEL_EXAMPLE_H_
-```
-
-```c++
-// kernel_example.cc
-#include "example.h"
-#include "tensorflow/core/framework/op_kernel.h"
-
-using namespace tensorflow;
-
-using CPUDevice = Eigen::ThreadPoolDevice;
-using GPUDevice = Eigen::GpuDevice;
-
-// CPU specialization of actual computation.
-template <typename T>
-struct ExampleFunctor<CPUDevice, T> {
- void operator()(const CPUDevice& d, int size, const T* in, T* out) {
- for (int i = 0; i < size; ++i) {
- out[i] = 2 * in[i];
- }
- }
-};
-
-// OpKernel definition.
-// template parameter <T> is the datatype of the tensors.
-template <typename Device, typename T>
-class ExampleOp : public OpKernel {
- public:
- explicit ExampleOp(OpKernelConstruction* context) : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) override {
- // Grab the input tensor
- const Tensor& input_tensor = context->input(0);
-
- // Create an output tensor
- Tensor* output_tensor = NULL;
- OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
- &output_tensor));
-
- // Do the computation.
- OP_REQUIRES(context, input_tensor.NumElements() <= tensorflow::kint32max,
- errors::InvalidArgument("Too many elements in tensor"));
- ExampleFunctor<Device, T>()(
- context->eigen_device<Device>(),
- static_cast<int>(input_tensor.NumElements()),
- input_tensor.flat<T>().data(),
- output_tensor->flat<T>().data());
- }
-};
-
-// Register the CPU kernels.
-#define REGISTER_CPU(T) \
- REGISTER_KERNEL_BUILDER( \
- Name("Example").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- ExampleOp<CPUDevice, T>);
-REGISTER_CPU(float);
-REGISTER_CPU(int32);
-
-// Register the GPU kernels.
-#ifdef GOOGLE_CUDA
-#define REGISTER_GPU(T) \
- /* Declare explicit instantiations in kernel_example.cu.cc. */ \
-    extern template struct ExampleFunctor<GPUDevice, T>;           \
- REGISTER_KERNEL_BUILDER( \
- Name("Example").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
- ExampleOp<GPUDevice, T>);
-REGISTER_GPU(float);
-REGISTER_GPU(int32);
-#endif // GOOGLE_CUDA
-```
-
-```c++
-// kernel_example.cu.cc
-#ifdef GOOGLE_CUDA
-#define EIGEN_USE_GPU
-#include "example.h"
-#include "tensorflow/core/util/cuda_kernel_helper.h"
-
-using namespace tensorflow;
-
-using GPUDevice = Eigen::GpuDevice;
-
-// Define the CUDA kernel.
-template <typename T>
-__global__ void ExampleCudaKernel(const int size, const T* in, T* out) {
- for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
- i += blockDim.x * gridDim.x) {
- out[i] = 2 * ldg(in + i);
- }
-}
-
-// Define the GPU implementation that launches the CUDA kernel.
-template <typename T>
-void ExampleFunctor<GPUDevice, T>::operator()(
- const GPUDevice& d, int size, const T* in, T* out) {
- // Launch the cuda kernel.
- //
- // See core/util/cuda_kernel_helper.h for example of computing
- // block count and thread_per_block count.
- int block_count = 1024;
- int thread_per_block = 20;
- ExampleCudaKernel<T>
- <<<block_count, thread_per_block, 0, d.stream()>>>(size, in, out);
-}
-
-// Explicitly instantiate functors for the types of OpKernels registered.
-template struct ExampleFunctor<GPUDevice, float>;
-template struct ExampleFunctor<GPUDevice, int32>;
-
-#endif // GOOGLE_CUDA
-```
-
-## Build the op library
-### Compile the op using your system compiler (TensorFlow binary installation)
-
-You should be able to compile `zero_out.cc` with a `C++` compiler such as `g++`
-or `clang` available on your system. The binary PIP package installs the header
-files and the library that you need to compile your op in locations that are
-system specific. However, the TensorFlow Python library provides the
-`get_include` function to get the header directory, and the `get_lib` function
-to get the directory containing the shared object to link against.
-Here are the outputs of these functions on an Ubuntu machine:
-
-```bash
-$ python
->>> import tensorflow as tf
->>> tf.sysconfig.get_include()
-'/usr/local/lib/python2.7/site-packages/tensorflow/include'
->>> tf.sysconfig.get_lib()
-'/usr/local/lib/python2.7/site-packages/tensorflow'
-```
-
-Assuming you have `g++` installed, here is the sequence of commands you can use
-to compile your op into a dynamic library.
-
-```bash
-TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') )
-TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') )
-g++ -std=c++11 -shared zero_out.cc -o zero_out.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2
-```
-
-On Mac OS X, the additional flag `-undefined dynamic_lookup` is required when
-building the `.so` file.
-
-> Note on `gcc` version `>=5`: gcc uses the new C++
-> [ABI](https://gcc.gnu.org/gcc-5/changes.html#libstdcxx) since version `5`. The binary pip
-> packages available on the TensorFlow website are built with `gcc4`, which uses
-> the older ABI. If you compile your op library with `gcc>=5`, add
-> `-D_GLIBCXX_USE_CXX11_ABI=0` to the command line to make the library
-> compatible with the older ABI.
-> Furthermore, if you are using a TensorFlow package built from source, remember
-> to add `--cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"` to the bazel command used to
-> compile the Python package.
-
-### Compile the op using bazel (TensorFlow source installation)
-
-If you have TensorFlow sources installed, you can make use of TensorFlow's build
-system to compile your op. Place a BUILD file with the following Bazel build
-rule in the [`tensorflow/core/user_ops`][user_ops] directory.
-
-```python
-load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
-
-tf_custom_op_library(
- name = "zero_out.so",
- srcs = ["zero_out.cc"],
-)
-```
-
-Run the following command to build `zero_out.so`.
-
-```bash
-$ bazel build --config opt //tensorflow/core/user_ops:zero_out.so
-```
-
-> Note: Although you can create a shared library (a `.so` file) with the
-> standard `cc_library` rule, we strongly recommend that you use the
-> `tf_custom_op_library` macro. It adds some required dependencies, and
-> performs checks to ensure that the shared library is compatible with
-> TensorFlow's plugin loading mechanism.
-
-## Use the op in Python
-
-The TensorFlow Python API provides the
-`tf.load_op_library` function to
-load the dynamic library and register the op with the TensorFlow
-framework. `load_op_library` returns a Python module that contains the Python
-wrappers for the op and the kernel. Thus, once you have built the op, you can
-do the following to run it from Python:
-
-```python
-import tensorflow as tf
-zero_out_module = tf.load_op_library('./zero_out.so')
-with tf.Session(''):
- zero_out_module.zero_out([[1, 2], [3, 4]]).eval()
-
-# Prints
-array([[1, 0], [0, 0]], dtype=int32)
-```
-
-Keep in mind, the generated function will be given a snake\_case name (to comply
-with [PEP8](https://www.python.org/dev/peps/pep-0008/)). So, if your op is
-named `ZeroOut` in the C++ files, the Python function will be called `zero_out`.
-
-To make the op available as a regular function `import`-able from a Python
-module, it may be useful to have the `load_op_library` call in a Python source
-file as follows:
-
-```python
-import tensorflow as tf
-
-zero_out_module = tf.load_op_library('./zero_out.so')
-zero_out = zero_out_module.zero_out
-```
-
-## Verify that the op works
-
-A good way to verify that you've successfully implemented your op is to write a
-test for it. Create the file
-`zero_out_op_test.py` with the contents:
-
-```python
-import tensorflow as tf
-
-class ZeroOutTest(tf.test.TestCase):
- def testZeroOut(self):
- zero_out_module = tf.load_op_library('./zero_out.so')
- with self.test_session():
- result = zero_out_module.zero_out([5, 4, 3, 2, 1])
- self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
-
-if __name__ == "__main__":
- tf.test.main()
-```
-
-Then run your test (assuming you have TensorFlow installed):
-
-```sh
-$ python zero_out_op_test.py
-```
-
-## Building advanced features into your op
-
-Now that you know how to build a basic (and somewhat restricted) op and
-implementation, we'll look at some of the more complicated things you will
-typically need to build into your op. This includes:
-
-* [Conditional checks and validation](#conditional-checks-and-validation)
-* [Op registration](#op-registration)
- * [Attrs](#attrs)
- * [Attr types](#attr-types)
- * [Polymorphism](#polymorphism)
- * [Inputs and outputs](#inputs-and-outputs)
- * [Backwards compatibility](#backwards-compatibility)
-* [GPU support](#gpu-support)
- * [Compiling the kernel for the GPU device](#compiling-the-kernel-for-the-gpu-device)
-* [Implement the gradient in Python](#implement-the-gradient-in-python)
-* [Shape functions in C++](#shape-functions-in-c)
-
-### Conditional checks and validation
-
-The example above assumed that the op applied to a tensor of any shape. What
-if it only applied to vectors? That means adding a check to the above OpKernel
-implementation.
-
-```c++
- void Compute(OpKernelContext* context) override {
- // Grab the input tensor
- const Tensor& input_tensor = context->input(0);
-
- OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()),
- errors::InvalidArgument("ZeroOut expects a 1-D vector."));
- // ...
- }
-```
-
-This asserts that the input is a vector, and returns after setting the
-`InvalidArgument` status if it isn't. The
-[`OP_REQUIRES` macro][validation-macros] takes three arguments:
-
-* The `context`, which can either be an `OpKernelContext` or
- `OpKernelConstruction` pointer (see
- [`tensorflow/core/framework/op_kernel.h`](https://www.tensorflow.org/code/tensorflow/core/framework/op_kernel.h)),
- for its `SetStatus()` method.
-* The condition. For example, there are functions for validating the shape
- of a tensor in
-  [`tensorflow/core/framework/tensor_shape.h`](https://www.tensorflow.org/code/tensorflow/core/framework/tensor_shape.h).
-* The error itself, which is represented by a `Status` object, see
- [`tensorflow/core/lib/core/status.h`](https://www.tensorflow.org/code/tensorflow/core/lib/core/status.h). A
- `Status` has both a type (frequently `InvalidArgument`, but see the list of
- types) and a message. Functions for constructing an error may be found in
- [`tensorflow/core/lib/core/errors.h`][validation-macros].
-
-Alternatively, if you want to test whether a `Status` object returned from some
-function is an error, and if so return it, use
-[`OP_REQUIRES_OK`][validation-macros]. Both of these macros return from the
-function on error.
-
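-For instance, a minimal sketch of `OP_REQUIRES_OK` guarding a
-`Status`-returning call (here the output allocation from the kernel above):
-
-```c++
-  void Compute(OpKernelContext* context) override {
-    // ...
-    Tensor* output = nullptr;
-    // If allocate_output returns a non-OK Status, OP_REQUIRES_OK sets that
-    // status on the context and returns from Compute immediately.
-    OP_REQUIRES_OK(context, context->allocate_output(
-                                0, input_tensor.shape(), &output));
-    // ...
-  }
-```
-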
-### Op registration
-
-#### Attrs
-
-Ops can have attrs, whose values are set when the op is added to a graph. These
-are used to configure the op, and their values can be accessed both within the
-kernel implementation and in the types of inputs and outputs in the op
-registration. Prefer using an input instead of an attr when possible, since
-inputs are more flexible. This is because attrs are constants and must be
-defined at graph construction time. In contrast, inputs are Tensors whose
-values can be dynamic; that is, inputs can change every step, be set using a
-feed, etc. Attrs are used for things that can't be done with inputs: any
-configuration that affects the signature (number or type of inputs or outputs)
-or that can't change from step-to-step.
-
-You define an attr when you register the op, by specifying its name and type
-using the `Attr` method, which expects a spec of the form:
-
-```
-<name>: <attr-type-expr>
-```
-
-where `<name>` begins with a letter and can be composed of alphanumeric
-characters and underscores, and `<attr-type-expr>` is a type expression of the
-form [described below](#attr-types).
-
-For example, if you'd like the `ZeroOut` op to preserve a user-specified index,
-instead of only the 0th element, you can register the op like so:
-```c++
-REGISTER_OP("ZeroOut")
- .Attr("preserve_index: int")
- .Input("to_zero: int32")
- .Output("zeroed: int32");
-```
-
-(Note that the set of [attribute types](#attr-types) is different from the
-`tf.DType` used for inputs and outputs.)
-
-Your kernel can then access this attr in its constructor via the `context`
-parameter:
-```c++
-class ZeroOutOp : public OpKernel {
- public:
- explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {
- // Get the index of the value to preserve
- OP_REQUIRES_OK(context,
- context->GetAttr("preserve_index", &preserve_index_));
- // Check that preserve_index is positive
- OP_REQUIRES(context, preserve_index_ >= 0,
- errors::InvalidArgument("Need preserve_index >= 0, got ",
- preserve_index_));
- }
- void Compute(OpKernelContext* context) override {
- // ...
- }
- private:
- int preserve_index_;
-};
-```
-
-which can then be used in the `Compute` method:
-```c++
- void Compute(OpKernelContext* context) override {
- // ...
-
-    // We're using the saved attr to validate a potentially dynamic input,
-    // so we check that preserve_index is in range.
- OP_REQUIRES(context, preserve_index_ < input.dimension(0),
- errors::InvalidArgument("preserve_index out of range"));
-
- // Set all the elements of the output tensor to 0
- const int N = input.size();
- for (int i = 0; i < N; i++) {
-      output_flat(i) = 0;
- }
-
- // Preserve the requested input value
- output_flat(preserve_index_) = input(preserve_index_);
- }
-```
-
-#### Attr types
-
-The following types are supported in an attr:
-
-* `string`: Any sequence of bytes (not required to be valid UTF-8).
-* `int`: A signed integer.
-* `float`: A floating point number.
-* `bool`: True or false.
-* `type`: One of the (non-ref) values of [`DataType`][DataTypeString].
-* `shape`: A [`TensorShapeProto`][TensorShapeProto].
-* `tensor`: A [`TensorProto`][TensorProto].
-* `list(<type>)`: A list of `<type>`, where `<type>` is one of the above types.
- Note that `list(list(<type>))` is invalid.
-
-See also: [`op_def_builder.cc:FinalizeAttr`][FinalizeAttr] for a definitive list.
-
-##### Default values & constraints
-
-Attrs may have default values, and some types of attrs can have constraints. To
-define an attr with constraints, you can use the following `<attr-type-expr>`s:
-
-* `{'<string1>', '<string2>'}`: The value must be a string that has either the
- value `<string1>` or `<string2>`. The name of the type, `string`, is implied
- when you use this syntax. This emulates an enum:
-
- ```c++
- REGISTER_OP("EnumExample")
- .Attr("e: {'apple', 'orange'}");
- ```
-
-* `{<type1>, <type2>}`: The value is of type `type`, and must be one of
- `<type1>` or `<type2>`, where `<type1>` and `<type2>` are supported
-  `tf.DType`s. You don't specify
- that the type of the attr is `type`. This is implied when you have a list of
- types in `{...}`. For example, in this case the attr `t` is a type that must
- be an `int32`, a `float`, or a `bool`:
-
- ```c++
- REGISTER_OP("RestrictedTypeExample")
- .Attr("t: {int32, float, bool}");
- ```
-
-* There are shortcuts for common type constraints:
- * `numbertype`: Type `type` restricted to the numeric (non-string and
- non-bool) types.
- * `realnumbertype`: Like `numbertype` without complex types.
- * `quantizedtype`: Like `numbertype` but just the quantized number types.
-
- The specific lists of types allowed by these are defined by the functions
- (like `NumberTypes()`) in
- [`tensorflow/core/framework/types.h`](https://www.tensorflow.org/code/tensorflow/core/framework/types.h).
- In this example the attr `t` must be one of the numeric types:
-
- ```c++
- REGISTER_OP("NumberType")
- .Attr("t: numbertype");
- ```
-
- For this op:
-
- ```python
- tf.number_type(t=tf.int32) # Valid
- tf.number_type(t=tf.bool) # Invalid
- ```
-
- Lists can be combined with other lists and single types. The following
- op allows attr `t` to be any of the numeric types, or the bool type:
-
- ```c++
- REGISTER_OP("NumberOrBooleanType")
- .Attr("t: {numbertype, bool}");
- ```
-
- For this op:
-
- ```python
- tf.number_or_boolean_type(t=tf.int32) # Valid
- tf.number_or_boolean_type(t=tf.bool) # Valid
- tf.number_or_boolean_type(t=tf.string) # Invalid
- ```
-
-* `int >= <n>`: The value must be an int whose value is greater than or equal to
- `<n>`, where `<n>` is a natural number.
-
- For example, the following op registration specifies that the attr `a` must
- have a value that is at least `2`:
-
- ```c++
- REGISTER_OP("MinIntExample")
- .Attr("a: int >= 2");
- ```
-
-* `list(<type>) >= <n>`: A list of type `<type>` whose length is greater than
- or equal to `<n>`.
-
- For example, the following op registration specifies that the attr `a` is a
- list of types (either `int32` or `float`), and that there must be at least 3
- of them:
-
- ```c++
- REGISTER_OP("TypeListExample")
- .Attr("a: list({int32, float}) >= 3");
- ```
-
-To set a default value for an attr (making it optional in the generated code),
-add `= <default>` to the end, as in:
-
-```c++
-REGISTER_OP("AttrDefaultExample")
- .Attr("i: int = 0");
-```
-
-The supported syntax of the default value is what would be used in the proto
-representation of the resulting GraphDef definition.
-
-Here are examples for how to specify a default for all types:
-
-```c++
-REGISTER_OP("AttrDefaultExampleForAllTypes")
- .Attr("s: string = 'foo'")
- .Attr("i: int = 0")
- .Attr("f: float = 1.0")
- .Attr("b: bool = true")
- .Attr("ty: type = DT_INT32")
- .Attr("sh: shape = { dim { size: 1 } dim { size: 2 } }")
- .Attr("te: tensor = { dtype: DT_INT32 int_val: 5 }")
- .Attr("l_empty: list(int) = []")
- .Attr("l_int: list(int) = [2, 3, 5, 7]");
-```
-
-Note in particular that the values of type `type`
-use `tf.DType`.
-
-#### Polymorphism
-
-##### Type Polymorphism
-
-For ops that can take different types as input or produce different output
-types, you can specify [an attr](#attrs) in
-[an input or output type](#inputs-and-outputs) in the op registration. Typically
-you would then register an `OpKernel` for each supported type.
-
-For instance, if you'd like the `ZeroOut` op to work on `float`s
-in addition to `int32`s, your op registration might look like:
-```c++
-REGISTER_OP("ZeroOut")
- .Attr("T: {float, int32}")
- .Input("to_zero: T")
- .Output("zeroed: T");
-```
-
-Your op registration now specifies that the input's type must be `float` or
-`int32`, and that its output will be the same type, since both have type `T`.
-
-> <a id="naming"></a>A note on naming: Inputs, outputs, and attrs generally should be
-> given snake\_case names. The one exception is attrs that are used as the type
-> of an input or as part of the type of an input. Those attrs can be inferred when the
-> op is added to the graph and so don't appear in the op's function. For
-> example, this last definition of ZeroOut will generate a Python function that
-> looks like:
->
-> ```python
-> def zero_out(to_zero, name=None):
-> """...
-> Args:
-> to_zero: A `Tensor`. Must be one of the following types:
-> `float32`, `int32`.
-> name: A name for the operation (optional).
->
-> Returns:
-> A `Tensor`. Has the same type as `to_zero`.
-> """
-> ```
->
-> If `to_zero` is passed an `int32` tensor, then `T` is automatically set to
-> `int32` (well, actually `DT_INT32`). Those inferred attrs are given
-> Capitalized or CamelCase names.
->
-> Compare this with an op that has a type attr that determines the output
-> type:
->
-> ```c++
-> REGISTER_OP("StringToNumber")
-> .Input("string_tensor: string")
-> .Output("output: out_type")
-> .Attr("out_type: {float, int32} = DT_FLOAT");
-> .Doc(R"doc(
-> Converts each string in the input Tensor to the specified numeric type.
-> )doc");
-> ```
->
-> In this case, the user has to specify the output type, as in the generated
-> Python:
->
-> ```python
-> def string_to_number(string_tensor, out_type=None, name=None):
-> """Converts each string in the input Tensor to the specified numeric type.
->
-> Args:
-> string_tensor: A `Tensor` of type `string`.
-> out_type: An optional `tf.DType` from: `tf.float32, tf.int32`.
-> Defaults to `tf.float32`.
-> name: A name for the operation (optional).
->
-> Returns:
-> A `Tensor` of type `out_type`.
-> """
-> ```
-
-You could then provide a separate kernel for each supported type, registering
-each one with a matching `TypeConstraint`:
-
-```c++
-#include "tensorflow/core/framework/op_kernel.h"
-
-class ZeroOutInt32Op : public OpKernel {
- // as before
-};
-
-class ZeroOutFloatOp : public OpKernel {
- public:
- explicit ZeroOutFloatOp(OpKernelConstruction* context)
- : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) override {
- // Grab the input tensor
- const Tensor& input_tensor = context->input(0);
- auto input = input_tensor.flat<float>();
-
- // Create an output tensor
- Tensor* output = NULL;
- OP_REQUIRES_OK(context,
- context->allocate_output(0, input_tensor.shape(), &output));
- auto output_flat = output->template flat<float>();
-
- // Set all the elements of the output tensor to 0
- const int N = input.size();
- for (int i = 0; i < N; i++) {
- output_flat(i) = 0;
- }
-
- // Preserve the first input value
- if (N > 0) output_flat(0) = input(0);
- }
-};
-
-// Note that TypeConstraint<int32>("T") means that attr "T" (defined
-// in the op registration above) must be "int32" to use this template
-// instantiation.
-REGISTER_KERNEL_BUILDER(
- Name("ZeroOut")
- .Device(DEVICE_CPU)
- .TypeConstraint<int32>("T"),
-    ZeroOutInt32Op);
-REGISTER_KERNEL_BUILDER(
- Name("ZeroOut")
- .Device(DEVICE_CPU)
- .TypeConstraint<float>("T"),
- ZeroOutFloatOp);
-```
-
-> To preserve [backwards compatibility](#backwards-compatibility), you should
-> specify a [default value](#default-values-constraints) when adding an attr to
-> an existing op:
->
-> ```c++
-> REGISTER_OP("ZeroOut")
-> .Attr("T: {float, int32} = DT_INT32")
-> .Input("to_zero: T")
-> .Output("zeroed: T")
-> ```
-
-Let's say you wanted to add more types, say `double`:
-```c++
-REGISTER_OP("ZeroOut")
- .Attr("T: {float, double, int32}")
- .Input("to_zero: T")
- .Output("zeroed: T");
-```
-
-Instead of writing another `OpKernel` with redundant code as above, you can
-often use a C++ template. You will still have one kernel
-registration (`REGISTER_KERNEL_BUILDER` call) per overload.
-```c++
-template <typename T>
-class ZeroOutOp : public OpKernel {
- public:
- explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) override {
- // Grab the input tensor
- const Tensor& input_tensor = context->input(0);
- auto input = input_tensor.flat<T>();
-
- // Create an output tensor
- Tensor* output = NULL;
- OP_REQUIRES_OK(context,
- context->allocate_output(0, input_tensor.shape(), &output));
- auto output_flat = output->template flat<T>();
-
- // Set all the elements of the output tensor to 0
- const int N = input.size();
- for (int i = 0; i < N; i++) {
- output_flat(i) = 0;
- }
-
- // Preserve the first input value
- if (N > 0) output_flat(0) = input(0);
- }
-};
-
-// Note that TypeConstraint<int32>("T") means that attr "T" (defined
-// in the op registration above) must be "int32" to use this template
-// instantiation.
-REGISTER_KERNEL_BUILDER(
- Name("ZeroOut")
- .Device(DEVICE_CPU)
- .TypeConstraint<int32>("T"),
- ZeroOutOp<int32>);
-REGISTER_KERNEL_BUILDER(
- Name("ZeroOut")
- .Device(DEVICE_CPU)
- .TypeConstraint<float>("T"),
- ZeroOutOp<float>);
-REGISTER_KERNEL_BUILDER(
- Name("ZeroOut")
- .Device(DEVICE_CPU)
- .TypeConstraint<double>("T"),
- ZeroOutOp<double>);
-```
-
-If you have more than a couple of overloads, you can put the registration in a
-macro.
-
-```c++
-#include "tensorflow/core/framework/op_kernel.h"
-
-#define REGISTER_KERNEL(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- ZeroOutOp<type>)
-
-REGISTER_KERNEL(int32);
-REGISTER_KERNEL(float);
-REGISTER_KERNEL(double);
-
-#undef REGISTER_KERNEL
-```
-
-Depending on the list of types you are registering the kernel for, you may be
-able to use a macro provided by
-[`tensorflow/core/framework/register_types.h`][register_types]:
-
-```c++
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/register_types.h"
-
-REGISTER_OP("ZeroOut")
- .Attr("T: realnumbertype")
- .Input("to_zero: T")
- .Output("zeroed: T");
-
-template <typename T>
-class ZeroOutOp : public OpKernel { ... };
-
-#define REGISTER_KERNEL(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- ZeroOutOp<type>)
-
-TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
-
-#undef REGISTER_KERNEL
-```
-
-##### List Inputs and Outputs
-
-In addition to being able to accept or produce different types, ops can consume
-or produce a variable number of tensors.
-
-In the next example, the attr `T` holds a *list* of types, and is used as the
-type of both the input `in` and the output `out`. The input and output are
-lists of tensors of that type (and the number and types of tensors in the output
-are the same as the input, since both have type `T`).
-
-```c++
-REGISTER_OP("PolymorphicListExample")
- .Attr("T: list(type)")
- .Input("in: T")
- .Output("out: T");
-```
-
-You can also place restrictions on what types can be specified in the list. In
-this next case, the input is a list of `float` and `double` tensors. The op
-accepts, for example, input types `(float, double, float)` and in that case the
-output type would also be `(float, double, float)`.
-
-```c++
-REGISTER_OP("ListTypeRestrictionExample")
- .Attr("T: list({float, double})")
- .Input("in: T")
- .Output("out: T");
-```
-
-If you want all the tensors in a list to be of the same type, you might do
-something like:
-
-```c++
-REGISTER_OP("IntListInputExample")
- .Attr("N: int")
- .Input("in: N * int32")
- .Output("out: int32");
-```
-
-This accepts a list of `int32` tensors, and uses an `int` attr `N` to
-specify the length of the list.
-
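-Inside a kernel, a list input declared this way can be consumed via
-`OpInputList`; a minimal sketch (the loop body is elided):
-
-```c++
-  void Compute(OpKernelContext* context) override {
-    OpInputList inputs;
-    OP_REQUIRES_OK(context, context->input_list("in", &inputs));
-    for (int i = 0; i < inputs.size(); ++i) {
-      const Tensor& t = inputs[i];
-      // ... use t ...
-    }
-  }
-```
-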
-This can be made [type polymorphic](#type-polymorphism) as well. In the next
-example, the input is a list of tensors (with length `"N"`) of the same (but
-unspecified) type (`"T"`), and the output is a single tensor of matching type:
-
-```c++
-REGISTER_OP("SameListInputExample")
- .Attr("N: int")
- .Attr("T: type")
- .Input("in: N * T")
- .Output("out: T");
-```
-
-By default, tensor lists have a minimum length of 1. You can change that default
-using
-[a `">="` constraint on the corresponding attr](#default-values-constraints).
-In this next example, the input is a list of at least 2 `int32` tensors:
-
-```c++
-REGISTER_OP("MinLengthIntListExample")
- .Attr("N: int >= 2")
- .Input("in: N * int32")
- .Output("out: int32");
-```
-
-The same syntax works with `"list(type)"` attrs:
-
-```c++
-REGISTER_OP("MinimumLengthPolymorphicListExample")
- .Attr("T: list(type) >= 3")
- .Input("in: T")
- .Output("out: T");
-```
-
-#### Inputs and Outputs
-
-To summarize the above, an op registration can have multiple inputs and outputs:
-
-```c++
-REGISTER_OP("MultipleInsAndOuts")
- .Input("y: int32")
- .Input("z: float")
- .Output("a: string")
- .Output("b: int32");
-```
-
-Each input or output spec is of the form:
-
-```
-<name>: <io-type-expr>
-```
-
-where `<name>` begins with a letter and can be composed of alphanumeric
-characters and underscores. `<io-type-expr>` is one of the following type
-expressions:
-
-* `<type>`, where `<type>` is a supported input type (e.g. `float`, `int32`,
- `string`). This specifies a single tensor of the given type.
-
- See
- `tf.DType`.
-
- ```c++
- REGISTER_OP("BuiltInTypesExample")
- .Input("integers: int32")
- .Input("complex_numbers: complex64");
- ```
-
-* `<attr-type>`, where `<attr-type>` is the name of an [Attr](#attrs) with type
- `type` or `list(type)` (with a possible type restriction). This syntax allows
- for [polymorphic ops](#polymorphism).
-
- ```c++
- REGISTER_OP("PolymorphicSingleInput")
- .Attr("T: type")
- .Input("in: T");
-
- REGISTER_OP("RestrictedPolymorphicSingleInput")
- .Attr("T: {int32, int64}")
- .Input("in: T");
- ```
-
- Referencing an attr of type `list(type)` allows you to accept a sequence of
- tensors.
-
- ```c++
- REGISTER_OP("ArbitraryTensorSequenceExample")
- .Attr("T: list(type)")
- .Input("in: T")
- .Output("out: T");
-
- REGISTER_OP("RestrictedTensorSequenceExample")
- .Attr("T: list({int32, int64})")
- .Input("in: T")
- .Output("out: T");
- ```
-
-  Note that the number and types of tensors in the output `out` are the same as
-  in the input `in`, since both are of type `T`.
-
-* For a sequence of tensors with the same type: `<number> * <type>`, where
- `<number>` is the name of an [Attr](#attrs) with type `int`. The `<type>` can
- either be a `tf.DType`,
- or the name of an attr with type `type`. As an example of the first, this
- op accepts a list of `int32` tensors:
-
- ```c++
- REGISTER_OP("Int32SequenceExample")
- .Attr("NumTensors: int")
- .Input("in: NumTensors * int32")
- ```
-
- Whereas this op accepts a list of tensors of any type, as long as they are all
- the same:
-
- ```c++
- REGISTER_OP("SameTypeSequenceExample")
- .Attr("NumTensors: int")
- .Attr("T: type")
- .Input("in: NumTensors * T")
- ```
-
-* For a reference to a tensor: `Ref(<type>)`, where `<type>` is one of the
- previous types.
-
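-  For example, this hypothetical registration (`RefInputExample` is not an op
-  in the TensorFlow codebase) takes a reference to an `int32` tensor:
-
-  ```c++
-  REGISTER_OP("RefInputExample")
-      .Input("ref_in: Ref(int32)");
-  ```
-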
-> A note on naming: Any attr used in the type of an input will be inferred. By
-> convention those inferred attrs use capital names (like `T` or `N`).
-> Otherwise inputs, outputs, and attrs have names like function parameters
-> (e.g. `num_outputs`). For more details, see the
-> [earlier note on naming](#naming).
-
-For more details, see
-[`tensorflow/core/framework/op_def_builder.h`][op_def_builder].
-
-#### Backwards compatibility
-
-Let's assume you have written a nice custom op and shared it with others, so
-you have happy customers using your operation. However, you'd like to make
-changes to the op in some way.
-
-In general, changes to existing, checked-in specifications must be
-backwards-compatible: changing the specification of an op must not break prior
-serialized `GraphDef` protocol buffers constructed from older specifications.
-The details of `GraphDef` compatibility are
-@{$version_compat#compatibility_of_graphs_and_checkpoints$described here}.
-
-There are several ways to preserve backwards-compatibility.
-
-1. Any new attrs added to an operation must have default values defined, and
- with that default value the op must have the original behavior. To change an
- operation from not polymorphic to polymorphic, you *must* give a default
- value to the new type attr to preserve the original signature by default. For
- example, if your operation was:
-
- REGISTER_OP("MyGeneralUnaryOp")
- .Input("in: float")
- .Output("out: float");
-
- you can make it polymorphic in a backwards-compatible way using:
-
- REGISTER_OP("MyGeneralUnaryOp")
- .Input("in: T")
- .Output("out: T")
- .Attr("T: numerictype = DT_FLOAT");
-
-2. You can safely make a constraint on an attr less restrictive. For example,
- you can change from `{int32, int64}` to `{int32, int64, float}` or `type`.
- Or you may change from `{"apple", "orange"}` to `{"apple", "banana",
- "orange"}` or `string`.
-
-3. You can change single inputs / outputs into list inputs / outputs, as long as
- the default for the list type matches the old signature.
-
-4. You can add a new list input / output, if it defaults to empty.
-
-5. Namespace any new ops you create, by prefixing the op names with something
-   unique to your project. This avoids having your op collide with any ops
- that might be included in future versions of TensorFlow.
-
-6. Plan ahead! Try to anticipate future uses for the op. Some signature changes
- can't be done in a compatible way (for example, making a list of the same
- type into a list of varying types).
-
-The full list of safe and unsafe changes can be found in
-[`tensorflow/core/framework/op_compatibility_test.cc`](https://www.tensorflow.org/code/tensorflow/core/framework/op_compatibility_test.cc).
-If you cannot make your change to an operation backwards compatible, then create
-a new operation with a new name with the new semantics.
-
-Also note that while these changes can maintain `GraphDef` compatibility, the
-generated Python code may change in a way that isn't compatible with old
-callers. The Python API may be kept compatible by careful changes in a
-hand-written Python wrapper, by keeping the old signature except possibly adding
-new optional arguments to the end. Generally incompatible changes may only be
-made when TensorFlow changes major versions, and must conform to the
-@{$version_compat#compatibility_of_graphs_and_checkpoints$`GraphDef` version semantics}.
-
-### GPU Support
-
-You can implement different OpKernels and register one for CPU and another for
-GPU, just like you can [register kernels for different types](#polymorphism).
-There are several examples of kernels with GPU support in
-[`tensorflow/core/kernels/`](https://www.tensorflow.org/code/tensorflow/core/kernels/).
-Notice some kernels have a CPU version in a `.cc` file, a GPU version in a file
-ending in `_gpu.cu.cc`, and some code shared in common in a `.h` file.
-
-For example, the `tf.pad` op has
-everything but the GPU kernel in [`tensorflow/core/kernels/pad_op.cc`][pad_op].
-The GPU kernel is in
-[`tensorflow/core/kernels/pad_op_gpu.cu.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/pad_op_gpu.cu.cc),
-and the shared code is a templated class defined in
-[`tensorflow/core/kernels/pad_op.h`](https://www.tensorflow.org/code/tensorflow/core/kernels/pad_op.h).
-We organize the code this way for two reasons: it allows you to share common
-code among the CPU and GPU implementations, and it puts the GPU implementation
-into a separate file so that it can be compiled only by the GPU compiler.
-
-One thing to note: even when the GPU kernel version of `pad` is used, it still
-needs its `"paddings"` input in CPU memory. To mark that inputs or outputs are
-kept on the CPU, add a `HostMemory()` call to the kernel registration, e.g.:
-
-```c++
-#define REGISTER_GPU_KERNEL(T) \
- REGISTER_KERNEL_BUILDER(Name("Pad") \
- .Device(DEVICE_GPU) \
- .TypeConstraint<T>("T") \
- .HostMemory("paddings"), \
- PadOp<GPUDevice, T>)
-```
-
-#### Compiling the kernel for the GPU device
-
-Look at
-[cuda_op_kernel.cu.cc](https://www.tensorflow.org/code/tensorflow/examples/adding_an_op/cuda_op_kernel.cu.cc)
-for an example that uses a CUDA kernel to implement an op. The
-`tf_custom_op_library` accepts a `gpu_srcs` argument in which the list of source
-files containing the CUDA kernels (`*.cu.cc` files) can be specified. For use
-with a binary installation of TensorFlow, the CUDA kernels have to be compiled
-with NVIDIA's `nvcc` compiler. Here is the sequence of commands you can use to
-compile the
-[cuda_op_kernel.cu.cc](https://www.tensorflow.org/code/tensorflow/examples/adding_an_op/cuda_op_kernel.cu.cc)
-and
-[cuda_op_kernel.cc](https://www.tensorflow.org/code/tensorflow/examples/adding_an_op/cuda_op_kernel.cc)
-into a single dynamically loadable library:
-
-```bash
-nvcc -std=c++11 -c -o cuda_op_kernel.cu.o cuda_op_kernel.cu.cc \
- ${TF_CFLAGS[@]} -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC
-
-g++ -std=c++11 -shared -o cuda_op_kernel.so cuda_op_kernel.cc \
- cuda_op_kernel.cu.o ${TF_CFLAGS[@]} -fPIC -lcudart ${TF_LFLAGS[@]}
-```
-
-`cuda_op_kernel.so` produced above can be loaded as usual in Python, using the
-`tf.load_op_library` function.
-
-Note that if your CUDA libraries are not installed in `/usr/local/lib64`,
-you'll need to specify the path explicitly in the second (g++) command above.
-For example, add `-L /usr/local/cuda-8.0/lib64/` if your CUDA is installed in
-`/usr/local/cuda-8.0`.
-
-> Note that in some Linux settings, additional options to the `nvcc` compile step are needed. Add `-D_MWAITXINTRIN_H_INCLUDED` to the `nvcc` command line to avoid errors from `mwaitxintrin.h`.
-
-### Implement the gradient in Python
-
-Given a graph of ops, TensorFlow uses automatic differentiation
-(backpropagation) to add new ops representing gradients with respect to the
-existing ops (see
-@{$python/train#gradient_computation$Gradient Computation}).
-To make automatic differentiation work for new ops, you must register a gradient
-function which computes gradients with respect to the ops' inputs given
-gradients with respect to the ops' outputs.
-
-Mathematically, if an op computes \\(y = f(x)\\) the registered gradient op
-converts gradients \\(\partial L/ \partial y\\) of loss \\(L\\) with respect to
-\\(y\\) into gradients \\(\partial L/ \partial x\\) with respect to \\(x\\) via
-the chain rule:
-
-$$\frac{\partial L}{\partial x}
- = \frac{\partial L}{\partial y} \frac{\partial y}{\partial x}
- = \frac{\partial L}{\partial y} \frac{\partial f}{\partial x}.$$
-
-In the case of `ZeroOut`, only one entry in the input affects the output, so the
-gradient with respect to the input is a sparse "one hot" tensor. This is
-expressed as follows:
-
-```python
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import sparse_ops
-
-@ops.RegisterGradient("ZeroOut")
-def _zero_out_grad(op, grad):
- """The gradients for `zero_out`.
-
- Args:
- op: The `zero_out` `Operation` that we are differentiating, which we can use
- to find the inputs and outputs of the original op.
- grad: Gradient with respect to the output of the `zero_out` op.
-
- Returns:
- Gradients with respect to the input of `zero_out`.
- """
- to_zero = op.inputs[0]
- shape = array_ops.shape(to_zero)
- index = array_ops.zeros_like(shape)
- first_grad = array_ops.reshape(grad, [-1])[0]
- to_zero_grad = sparse_ops.sparse_to_dense([index], shape, first_grad, 0)
- return [to_zero_grad] # List of one Tensor, since we have one input
-```
-
-Details about registering gradient functions with
-`tf.RegisterGradient`:
-
-* For an op with one output, the gradient function will take an
- `tf.Operation` `op` and a
- `tf.Tensor` `grad` and build new ops
- out of the tensors
- [`op.inputs[i]`](../../api_docs/python/framework.md#Operation.inputs),
- [`op.outputs[i]`](../../api_docs/python/framework.md#Operation.outputs), and `grad`. Information
- about any attrs can be found via
- `tf.Operation.get_attr`.
-
-* If the op has multiple outputs, the gradient function will take `op` and
- `grads`, where `grads` is a list of gradients with respect to each output.
- The result of the gradient function must be a list of `Tensor` objects
- representing the gradients with respect to each input.
-
-* If there is no well-defined gradient for some input, such as for integer
- inputs used as indices, the corresponding returned gradient should be
- `None`. For example, for an op taking a floating point tensor `x` and an
- integer index `i`, the gradient function would `return [x_grad, None]`.
-
-* If there is no meaningful gradient for the op at all, you often will not have
- to register any gradient, and as long as the op's gradient is never needed,
- you will be fine. In some cases, an op has no well-defined gradient but can
- be involved in the computation of the gradient. Here you can use
- `ops.NotDifferentiable` to automatically propagate zeros backwards.
-
-Note that at the time the gradient function is called, only the data flow graph
-of ops is available, not the tensor data itself. Thus, all computation must be
-performed using other TensorFlow ops, to be run at graph execution time.
-
-### Shape functions in C++
-
-The TensorFlow API has a feature called "shape inference" that provides
-information about the shapes of tensors without having to execute the
-graph. Shape inference is supported by "shape functions" that are registered for
-each op type in the C++ `REGISTER_OP` declaration, and perform two roles:
-asserting that the shapes of the inputs are compatible during graph
-construction, and specifying the shapes for the outputs.
-
-Shape functions are defined as operations on the
-`shape_inference::InferenceContext` class. For example, in the shape function
-for ZeroOut:
-
-```c++
- .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
- c->set_output(0, c->input(0));
- return Status::OK();
- });
-```
-
-`c->set_output(0, c->input(0));` declares that the first output's shape should
-be set to the first input's shape. If the output is selected by its index as in
-the above example, the second parameter of `set_output` should be a
-`ShapeHandle` object. You can create an empty `ShapeHandle` object with its
-default constructor, and the `ShapeHandle` object for the input with index
-`idx` can be obtained with `c->input(idx)`.
-
-There are a number of common shape functions that apply to many ops, such as
-`shape_inference::UnchangedShape`, which can be found in
-[common_shape_fns.h](https://www.tensorflow.org/code/tensorflow/core/framework/common_shape_fns.h) and used as follows:
-
-```c++
-REGISTER_OP("ZeroOut")
-    .Input("to_zero: int32")
-    .Output("zeroed: int32")
-    .SetShapeFn(::tensorflow::shape_inference::UnchangedShape);
-```
-
-A shape function can also constrain the shape of an input. For the version of
-[`ZeroOut` with a vector shape constraint](#validation), the shape function
-would be as follows:
-
-```c++
-    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
-      ::tensorflow::shape_inference::ShapeHandle input;
-      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input));
-      c->set_output(0, input);
-      return Status::OK();
-    });
-```
-
-The `WithRank` call validates that the input shape `c->input(0)` has
-a shape with exactly one dimension (or if the input shape is unknown,
-the output shape will be a vector with one unknown dimension).
-
-If your op is [polymorphic with multiple inputs](#polymorphism), you can use
-members of `InferenceContext` to determine the number of shapes to check, and
-`Merge` to validate that the shapes are all compatible (alternatively, access
-attributes that indicate the lengths, with `InferenceContext::GetAttr`, which
-provides access to the attributes of the op).
-
-```c++
-    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
-      ::tensorflow::shape_inference::ShapeHandle input;
-      ::tensorflow::shape_inference::ShapeHandle output;
-      for (int i = 0; i < c->num_inputs(); ++i) {
-        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &input));
-        TF_RETURN_IF_ERROR(c->Merge(output, input, &output));
-      }
-      c->set_output(0, output);
-      return Status::OK();
-    });
-```
-
-Since shape inference is an optional feature, and the shapes of tensors may vary
-dynamically, shape functions must be robust to incomplete shape information for
-any of the inputs. The `Merge` method in [`InferenceContext`](https://www.tensorflow.org/code/tensorflow/core/framework/shape_inference.h)
-allows the caller to assert that two shapes are the same, even if either
-or both of them do not have complete information. Shape functions are defined
-for all of the core TensorFlow ops and provide many different usage examples.
-
-The `InferenceContext` class has a number of functions that can be used to
-define shape function manipulations. For example, you can validate that a
-particular dimension has a specific value using `InferenceContext::Dim` and
-`InferenceContext::WithValue`; you can specify that an output dimension is the
-sum / product of two input dimensions using `InferenceContext::Add` and
-`InferenceContext::Multiply`. See the `InferenceContext` class for all of the
-various shape manipulations you can specify. The following example sets the
-shape of the first output to (n, 3), where the first input has shape (n, ...):
-
-```c++
-.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
-  c->set_output(0, c->Matrix(c->Dim(c->input(0), 0), 3));
-  return Status::OK();
-});
-```
-
-If you have a complicated shape function, you should consider adding a test
-that validates that various input shape combinations produce the expected
-output shape combinations. You can see examples of how to write these tests in
-some of our
-[core ops tests](https://www.tensorflow.org/code/tensorflow/core/ops/array_ops_test.cc).
-(The syntax of `INFER_OK` and `INFER_ERROR` is a little cryptic, but it aims to
-be compact in representing input and output shape specifications in tests. For
-now, see the surrounding comments in those tests to get a sense of the shape
-string specification.)
-
-
-[core-array_ops]:https://www.tensorflow.org/code/tensorflow/core/ops/array_ops.cc
-[python-user_ops]:https://www.tensorflow.org/code/tensorflow/python/user_ops/user_ops.py
-[tf-kernels]:https://www.tensorflow.org/code/tensorflow/core/kernels/
-[user_ops]:https://www.tensorflow.org/code/tensorflow/core/user_ops/
-[pad_op]:https://www.tensorflow.org/code/tensorflow/core/kernels/pad_op.cc
-[standard_ops-py]:https://www.tensorflow.org/code/tensorflow/python/ops/standard_ops.py
-[standard_ops-cc]:https://www.tensorflow.org/code/tensorflow/cc/ops/standard_ops.h
-[python-BUILD]:https://www.tensorflow.org/code/tensorflow/python/BUILD
-[validation-macros]:https://www.tensorflow.org/code/tensorflow/core/lib/core/errors.h
-[op_def_builder]:https://www.tensorflow.org/code/tensorflow/core/framework/op_def_builder.h
-[register_types]:https://www.tensorflow.org/code/tensorflow/core/framework/register_types.h
-[FinalizeAttr]:https://www.tensorflow.org/code/tensorflow/core/framework/op_def_builder.cc
-[DataTypeString]:https://www.tensorflow.org/code/tensorflow/core/framework/types.cc
-[types-proto]:https://www.tensorflow.org/code/tensorflow/core/framework/types.proto
-[TensorShapeProto]:https://www.tensorflow.org/code/tensorflow/core/framework/tensor_shape.proto
-[TensorProto]:https://www.tensorflow.org/code/tensorflow/core/framework/tensor.proto
diff --git a/tensorflow/docs_src/extend/architecture.md b/tensorflow/docs_src/extend/architecture.md
deleted file mode 100644
index 83d70c9468..0000000000
--- a/tensorflow/docs_src/extend/architecture.md
+++ /dev/null
@@ -1,217 +0,0 @@
-# TensorFlow Architecture
-
-We designed TensorFlow for large-scale distributed training and inference, but
-it is also flexible enough to support experimentation with new machine
-learning models and system-level optimizations.
-
-This document describes the system architecture that makes this
-combination of scale and flexibility possible. It assumes that you have basic familiarity
-with TensorFlow programming concepts such as the computation graph, operations,
-and sessions. See @{$guide/low_level_intro$this document} for an introduction to
-these topics. Some familiarity with @{$distributed$distributed TensorFlow}
-will also be helpful.
-
-This document is for developers who want to extend TensorFlow in some way not
-supported by current APIs, hardware engineers who want to optimize for
-TensorFlow, implementers of machine learning systems working on scaling and
-distribution, or anyone who wants to look under TensorFlow's hood. By the end
-of this document you should understand the TensorFlow architecture well enough
-to read and modify the core TensorFlow code.
-
-## Overview
-
-The TensorFlow runtime is a cross-platform library. Figure 1 illustrates its
-general architecture. A C API separates user-level code in different languages
-from the core runtime.
-
-![TensorFlow Layers](https://www.tensorflow.org/images/layers.png){: width="300"}
-
-**Figure 1**
-
-
-This document focuses on the following layers:
-
-* **Client**:
- * Defines the computation as a dataflow graph.
- * Initiates graph execution using a [**session**](
- https://www.tensorflow.org/code/tensorflow/python/client/session.py).
-* **Distributed Master**
- * Prunes a specific subgraph from the graph, as defined by the arguments
- to Session.run().
- * Partitions the subgraph into multiple pieces that run in different
- processes and devices.
- * Distributes the graph pieces to worker services.
- * Initiates graph piece execution by worker services.
-* **Worker Services** (one for each task)
- * Schedule the execution of graph operations using kernel implementations
- appropriate to the available hardware (CPUs, GPUs, etc).
- * Send and receive operation results to and from other worker services.
-* **Kernel Implementations**
- * Perform the computation for individual graph operations.
-
-Figure 2 illustrates the interaction of these components. "/job:worker/task:0"
-and "/job:ps/task:0" are both tasks with worker services. "PS" stands for
-"parameter server": a task responsible for storing and updating the model's
-parameters. Other tasks send updates to these parameters as they work on
-optimizing the parameters. This particular division of labor between tasks is
-not required, but it is common for distributed training.
-
-![TensorFlow Architecture Diagram](https://www.tensorflow.org/images/diag1.svg){: width="500"}
-
-**Figure 2**
-
-Note that the Distributed Master and Worker Service only exist in
-distributed TensorFlow. The single-process version of TensorFlow includes a
-special Session implementation that does everything the distributed master does
-but only communicates with devices in the local process.
-
-The following sections describe the core TensorFlow layers in greater detail and
-step through the processing of an example graph.
-
-## Client
-
-Users write the client TensorFlow program that builds the computation graph.
-This program can either directly compose individual operations or use a
-convenience library like the Estimators API to compose neural network layers and
-other higher-level abstractions. TensorFlow supports multiple client
-languages, and we have prioritized Python and C++, because our internal users
-are most familiar with these languages. As features become more established,
-we typically port them to C++, so that users can access an optimized
-implementation from all client languages. Most of the training libraries are
-still Python-only, but C++ does have support for efficient inference.
-
-The client creates a session, which sends the graph definition to the
-distributed master as a `tf.GraphDef`
-protocol buffer. When the client evaluates a node or nodes in the
-graph, the evaluation triggers a call to the distributed master to initiate
-computation.
-
-In Figure 3, the client has built a graph that applies weights (w) to a
-feature vector (x), adds a bias term (b) and saves the result in a variable
-(s).
-
-![TensorFlow Architecture Diagram: Client](https://www.tensorflow.org/images/graph_client.svg){: width="700"}
-
-**Figure 3**
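-
-As a rough sketch, a client program for a graph like this might look as
-follows (a minimal TF 1.x example; the names `w`, `x`, `b`, and `s` mirror the
-figure, and the values are made up for illustration):
-
-```python
-import tensorflow as tf
-
-x = tf.placeholder(tf.float32, shape=[3], name="x")  # feature vector
-w = tf.Variable(tf.ones([3]), name="w")              # weights
-b = tf.Variable(0.0, name="b")                       # bias
-s = tf.Variable(0.0, name="s")                       # result variable
-update_s = tf.assign(s, tf.reduce_sum(w * x) + b)
-
-with tf.Session() as sess:
-  sess.run(tf.global_variables_initializer())
-  # Evaluating a node triggers a call to the (distributed) master.
-  print(sess.run(update_s, feed_dict={x: [1., 2., 3.]}))
-```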
-
-### Code
-
-* `tf.Session`
-
-## Distributed master
-
-The distributed master:
-
-* prunes the graph to obtain the subgraph required to evaluate the nodes
- requested by the client,
-* partitions the graph to obtain graph pieces for
- each participating device, and
-* caches these pieces so that they may be re-used in subsequent steps.
-
-Since the master sees the overall computation for
-a step, it applies standard optimizations such as common subexpression
-elimination and constant folding. It then coordinates execution of the
-optimized subgraphs across a set of tasks.
-
-![TensorFlow Architecture Diagram: Master](https://www.tensorflow.org/images/graph_master_cln.svg){: width="700"}
-
-**Figure 4**
-
-
-Figure 5 shows a possible partition of our example graph. The distributed
-master has grouped the model parameters in order to place them together on the
-parameter server.
-
-![Partitioned Graph](https://www.tensorflow.org/images/graph_split1.svg){: width="700"}
-
-**Figure 5**
-
-
-Where graph edges are cut by the partition, the distributed master inserts
-send and receive nodes to pass information between the distributed tasks
-(Figure 6).
-
-![Partitioned Graph](https://www.tensorflow.org/images/graph_split2.svg){: width="700"}
-
-**Figure 6**
-
-
-The distributed master then ships the graph pieces to the distributed tasks.
-
-![Partitioned Graph](https://www.tensorflow.org/images/graph_workers_cln.svg){: width="700"}
-
-**Figure 7**
-
-### Code
-
-* [MasterService API definition](https://www.tensorflow.org/code/tensorflow/core/protobuf/master_service.proto)
-* [Master interface](https://www.tensorflow.org/code/tensorflow/core/distributed_runtime/master_interface.h)
-
-## Worker Service
-
-The worker service in each task:
-
-* handles requests from the master,
-* schedules the execution of the kernels for the operations that comprise a
- local subgraph, and
-* mediates direct communication between tasks.
-
-We optimize the worker service for running large graphs with low overhead. Our
-current implementation can execute tens of thousands of subgraphs per second,
-which enables a large number of replicas to make rapid, fine-grained training
-steps. The worker service dispatches kernels to local devices and runs kernels
-in parallel when possible, for example by using multiple CPU cores or GPU
-streams.
-
-We specialize Send and Recv operations for each pair of source and destination
-device types:
-
-* Transfers between local CPU and GPU devices use the
- `cudaMemcpyAsync()` API to overlap computation and data transfer.
-* Transfers between two local GPUs use peer-to-peer DMA, to avoid an expensive
- copy via the host CPU.
-
-For transfers between tasks, TensorFlow uses multiple protocols, including:
-
-* gRPC over TCP.
-* RDMA over Converged Ethernet.
-
-We also have preliminary support for NVIDIA's NCCL library for multi-GPU
-communication (see [`tf.contrib.nccl`](
-https://www.tensorflow.org/code/tensorflow/contrib/nccl/python/ops/nccl_ops.py)).
-
-![Partitioned Graph](https://www.tensorflow.org/images/graph_send_recv.svg){: width="700"}
-
-**Figure 8**
-
-### Code
-
-* [WorkerService API definition](https://www.tensorflow.org/code/tensorflow/core/protobuf/worker_service.proto)
-* [Worker interface](https://www.tensorflow.org/code/tensorflow/core/distributed_runtime/worker_interface.h)
-* [Remote rendezvous (for Send and Recv implementations)](https://www.tensorflow.org/code/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h)
-
-## Kernel Implementations
-
-The runtime contains over 200 standard operations, including mathematical,
-array manipulation, control flow, and state management operations. Each of these
-operations can have kernel implementations optimized for a variety of devices.
-Many of the operation kernels are implemented using Eigen::Tensor, which uses
-C++ templates to generate efficient parallel code for multicore CPUs and GPUs;
-however, we liberally use libraries like cuDNN where a more efficient kernel
-implementation is possible. We have also implemented
-@{$quantization$quantization}, which enables
-faster inference in environments such as mobile devices and high-throughput
-datacenter applications, and use the
-[gemmlowp](https://github.com/google/gemmlowp) low-precision matrix library to
-accelerate quantized computation.
-
-If it is difficult or inefficient to represent a subcomputation as a composition
-of operations, users can register additional kernels that provide an efficient
-implementation written in C++. For example, we recommend registering your own
-fused kernels for some performance critical operations, such as the ReLU and
-Sigmoid activation functions and their corresponding gradients. The @{$xla$XLA Compiler} has an
-experimental implementation of automatic kernel fusion.
-
-### Code
-
-* [`OpKernel` interface](https://www.tensorflow.org/code/tensorflow/core/framework/op_kernel.h)
diff --git a/tensorflow/docs_src/extend/index.md b/tensorflow/docs_src/extend/index.md
deleted file mode 100644
index 0e4bfd1dc4..0000000000
--- a/tensorflow/docs_src/extend/index.md
+++ /dev/null
@@ -1,34 +0,0 @@
-# Extend
-
-This section explains how developers can add functionality to TensorFlow's
-capabilities. Begin by reading the following architectural overview:
-
- * @{$architecture$TensorFlow Architecture}
-
-The following guides explain how to extend particular aspects of
-TensorFlow:
-
- * @{$adding_an_op$Adding a New Op}, which explains how to create your own
- operations.
- * @{$add_filesys$Adding a Custom Filesystem Plugin}, which explains how to
- add support for your own shared or distributed filesystem.
- * @{$new_data_formats$Custom Data Readers}, which details how to add support
- for your own file and record formats.
-
-Python is currently the only language supported by TensorFlow's API stability
-promises. However, TensorFlow also provides functionality in C++, Go, Java and
-[JavaScript](https://js.tensorflow.org) (including
-[Node.js](https://github.com/tensorflow/tfjs-node)),
-plus community support for [Haskell](https://github.com/tensorflow/haskell) and
-[Rust](https://github.com/tensorflow/rust). If you'd like to create or
-develop TensorFlow features in another language, read the
-following guide:
-
- * @{$language_bindings$TensorFlow in Other Languages}
-
-To create tools compatible with TensorFlow's model format, read the following
-guide:
-
- * @{$tool_developers$A Tool Developer's Guide to TensorFlow Model Files}
-
-
diff --git a/tensorflow/docs_src/extend/language_bindings.md b/tensorflow/docs_src/extend/language_bindings.md
deleted file mode 100644
index 9a968d365b..0000000000
--- a/tensorflow/docs_src/extend/language_bindings.md
+++ /dev/null
@@ -1,231 +0,0 @@
-# TensorFlow in other languages
-
-## Background
-
-This document is intended as a guide for those interested in the creation or
-development of TensorFlow functionality in other programming languages. It
-describes the features of TensorFlow and recommended steps for making the same
-available in other programming languages.
-
-Python was the first client language supported by TensorFlow and currently
-supports the most features. More and more of that functionality is being moved
-into the core of TensorFlow (implemented in C++) and exposed via a [C API].
-Client languages should use the language's [foreign function interface
-(FFI)](https://en.wikipedia.org/wiki/Foreign_function_interface) to call into
-this [C API] to provide TensorFlow functionality.
-
-## Overview
-
-Providing TensorFlow functionality in a programming language can be broken down
-into broad categories:
-
-- *Run a predefined graph*: Given a `GraphDef` (or
- `MetaGraphDef`) protocol message, be able to create a session, run queries,
- and get tensor results. This is sufficient for a mobile app or server that
- wants to run inference on a pre-trained model.
-- *Graph construction*: At least one function per defined
- TensorFlow op that adds an operation to the graph. Ideally these functions
- would be automatically generated so they stay in sync as the op definitions
- are modified.
-- *Gradients (AKA automatic differentiation)*: Given a graph and a list of
- input and output operations, add operations to the graph that compute the
- partial derivatives (gradients) of the inputs with respect to the outputs.
- Allows for customization of the gradient function for a particular operation
- in the graph.
-- *Functions*: Define a subgraph that may be called in multiple places in the
- main `GraphDef`. Defines a `FunctionDef` in the `FunctionDefLibrary`
- included in a `GraphDef`.
-- *Control Flow*: Construct "If" and "While" with user-specified subgraphs.
- Ideally these work with gradients (see above).
-- *Neural Network library*: A number of components that together support the
- creation of neural network models and training them (possibly in a
- distributed setting). While it would be convenient to have this available in
- other languages, there are currently no plans to support this in languages
- other than Python. These libraries are typically wrappers over the features
- described above.
-
-At a minimum, a language binding should support running a predefined graph, but
-most should also support graph construction. The TensorFlow Python API provides
-all these features.
-
-## Current Status
-
-New language support should be built on top of the [C API]. However, as you can
-see in the table below, not all functionality is available in C yet. Providing
-more functionality in the [C API] is an ongoing project.
-
-Feature | Python | C
-:--------------------------------------------- | :---------------------------------------------------------- | :--
-Run a predefined Graph | `tf.import_graph_def`, `tf.Session` | `TF_GraphImportGraphDef`, `TF_NewSession`
-Graph construction with generated op functions | Yes | Yes (The C API supports client languages that do this)
-Gradients | `tf.gradients` |
-Functions | `tf.python.framework.function.Defun` |
-Control Flow | `tf.cond`, `tf.while_loop` |
-Neural Network library | `tf.train`, `tf.nn`, `tf.contrib.layers`, `tf.contrib.slim` |
-
-## Recommended Approach
-
-### Run a predefined graph
-
-A language binding is expected to define the following classes (a usage sketch
-in Python follows the list):
-
-- `Graph`: A graph representing a TensorFlow computation. Consists of
- operations (represented in the client language by `Operation`s) and
- corresponds to a `TF_Graph` in the C API. Mainly used as an argument when
- creating new `Operation` objects and when starting a `Session`. Also
- supports iterating through the operations in the graph
- (`TF_GraphNextOperation`), looking up operations by name
- (`TF_GraphOperationByName`), and converting to and from a `GraphDef`
- protocol message (`TF_GraphToGraphDef` and `TF_GraphImportGraphDef` in the C
- API).
-- `Operation`: Represents a computation node in the graph. Corresponds to a
- `TF_Operation` in the C API.
-- `Output`: Represents one of the outputs of an operation in the graph. Has a
- `DataType` (and eventually a shape). May be passed as an input argument to a
- function for adding operations to a graph, or to a `Session`'s `Run()`
- method to fetch that output as a tensor. Corresponds to a `TF_Output` in the
- C API.
-- `Session`: Represents a client to a particular instance of the TensorFlow
- runtime. Its main job is to be constructed with a `Graph` and some options
- and then field calls to `Run()` the graph. Corresponds to a `TF_Session` in
- the C API.
-- `Tensor`: Represents an N-dimensional (rectangular) array with elements all
- the same `DataType`. Gets data into and out of a `Session`'s `Run()` call.
- Corresponds to a `TF_Tensor` in the C API.
-- `DataType`: An enumerant with all the possible tensor types supported by
- TensorFlow. Corresponds to `TF_DataType` in the C API and often referred to
- as `dtype` in the Python API.
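-
-As a rough sketch, here is how these classes fit together in the Python API
-(`graph_def`, the tensor names `"input:0"` and `"output:0"`, and `data` are
-assumptions for illustration):
-
-```python
-import tensorflow as tf
-
-graph = tf.Graph()                      # Graph
-with graph.as_default():
-  tf.import_graph_def(graph_def, name="")
-with tf.Session(graph=graph) as sess:   # Session
-  # Feeds and fetches name Outputs; the result comes back as a Tensor value.
-  result = sess.run("output:0", feed_dict={"input:0": data})
-```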
-
-### Graph construction
-
-TensorFlow has many ops, and the list is not static, so we recommend generating
-the functions for adding ops to a graph instead of writing them individually
-by hand (though writing a few by hand is a good way to figure out what the
-generator should generate). The information needed to generate a function is
-contained in an `OpDef` protocol message.
-
-There are a few ways to get a list of the `OpDef`s for the registered ops:
-
-- `TF_GetAllOpList` in the C API retrieves all registered `OpDef` protocol
- messages. This can be used to write the generator in the client language.
- This requires that the client language have protocol buffer support in order
- to interpret the `OpDef` messages.
-- The C++ function `OpRegistry::Global()->GetRegisteredOps()` returns the same
- list of all registered `OpDef`s (defined in
- [`tensorflow/core/framework/op.h`](https://www.tensorflow.org/code/tensorflow/core/framework/op.h)). This can be used to write the generator
- in C++ (particularly useful for languages that do not have protocol buffer
- support).
-- The ASCII-serialized version of that list is periodically checked in to
- [`tensorflow/core/ops/ops.pbtxt`](https://www.tensorflow.org/code/tensorflow/core/ops/ops.pbtxt) by an automated process.
-
-The `OpDef` specifies the following:
-
-- Name of the op in CamelCase. For generated functions follow the conventions
- of the language. For example, if the language uses snake_case, use that
- instead of CamelCase for the op's function name.
-- A list of inputs and outputs. The types for these may be polymorphic by
- referencing attributes, as described in the inputs and outputs section of
- @{$adding_an_op$Adding an op}.
-- A list of attributes, along with their default values (if any). Note that
- some of these will be inferred (if they are determined by an input), some
- will be optional (if they have a default), and some will be required (no
- default).
-- Documentation for the op in general and the inputs, outputs, and
- non-inferred attributes.
-- Some other fields that are used by the runtime and can be ignored by the
- code generators.
-
-An `OpDef` can be converted into the text of a function that adds that op to the
-graph using the `TF_OperationDescription` C API (wrapped in the language's FFI):
-
-- Start with `TF_NewOperation()` to create the `TF_OperationDescription*`.
-- Call `TF_AddInput()` or `TF_AddInputList()` once per input (depending on
- whether the input has a list type).
-- Call `TF_SetAttr*()` functions to set non-inferred attributes. May skip
- attributes with defaults if you don't want to override the default value.
-- Set optional fields if necessary:
- - `TF_SetDevice()`: force the operation onto a specific device.
-  - `TF_AddControlInput()`: add a requirement that another operation finish
-    before this operation starts running.
-  - `TF_SetAttrString("_kernel")`: set the kernel label (rarely used).
-  - `TF_ColocateWith()`: colocate one op with another.
-- Call `TF_FinishOperation()` when done. This adds the operation to the graph,
- after which it can't be modified.
-
-The existing examples run the code generator as part of the build process (using
-a Bazel genrule). Alternatively, the code generator can be run by an automated
-cron process, possibly checking in the result. This creates a risk of divergence
-between the generated code and the `OpDef`s checked into the repository, but is
-useful for languages where code is expected to be generated ahead of time like
-`go get` for Go and `cargo ops` for Rust. At the other end of the spectrum, for
-some languages the code could be generated dynamically from
-[`tensorflow/core/ops/ops.pbtxt`](https://www.tensorflow.org/code/tensorflow/core/ops/ops.pbtxt).
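-
-For instance, a generator written in Python could read that file directly. A
-minimal sketch (the path assumes a TensorFlow source checkout):
-
-```python
-from google.protobuf import text_format
-from tensorflow.core.framework import op_def_pb2
-
-op_list = op_def_pb2.OpList()
-with open("tensorflow/core/ops/ops.pbtxt") as f:
-  text_format.Merge(f.read(), op_list)
-for op in op_list.op:
-  # Each entry is an OpDef; a generator would emit one function per op.
-  print(op.name, [arg.name for arg in op.input_arg])
-```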
-
-#### Handling Constants
-
-Calling code will be much more concise if users can provide constants to input
-arguments. The generated code should convert those constants to operations that
-are added to the graph and used as input to the op being instantiated.
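-
-The generated Python functions, for example, rely on `tf.convert_to_tensor`
-for this. A minimal sketch of the resulting behavior:
-
-```python
-import tensorflow as tf
-
-x = tf.placeholder(tf.float32)
-y = tf.add(x, 1.0)  # the Python float 1.0 becomes a Const op in the graph
-print(y.op.inputs[1].op.type)  # prints "Const"
-```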
-
-#### Optional parameters
-
-If the language allows for optional parameters to a function (like keyword
-arguments with defaults in Python), use them for optional attributes, operation
-names, devices, control inputs etc. In some languages, these optional parameters
-can be set using dynamic scopes (like "with" blocks in Python). Without these
-features, the library may resort to the "builder pattern", as is done in the C++
-version of the TensorFlow API.
-
-#### Name scopes
-
-It is a good idea to have support for naming graph operations using some sort of
-scoping hierarchy, especially considering the fact that TensorBoard relies on it
-to display large graphs in a reasonable way. The existing Python and C++ APIs
-take different approaches: In Python, the "directory" part of the name
-(everything up to the last "/") comes from `with` blocks. In effect, there is a
-thread-local stack with the scopes defining the name hierarchy. The last
-component of the name is either supplied explicitly by the user (using the
-optional `name` keyword argument) or defaults to the name of the type of the op
-being added. In C++ the "directory" part of the name is stored in an explicit
-`Scope` object. The `NewSubScope()` method appends to that part of the name and
-returns a new `Scope`. The last component of the name is set using the
-`WithOpName()` method, and like Python defaults to the name of the type of op
-being added. `Scope` objects are explicitly passed around to specify the name of
-the context.
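-
-A minimal sketch of the Python behavior described above:
-
-```python
-import tensorflow as tf
-
-with tf.name_scope("layer1"):  # pushes "layer1/" onto the name-scope stack
-  w = tf.Variable(tf.zeros([10, 10]), name="weights")
-print(w.op.name)  # prints "layer1/weights"
-```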
-
-#### Wrappers
-
-It may make sense to keep the generated functions private for some ops so that
-wrapper functions that do a little bit of additional work can be used instead.
-This also gives an escape hatch for supporting features outside the scope of
-generated code.
-
-One use of a wrapper is for supporting `SparseTensor` input and output. A
-`SparseTensor` is a tuple of 3 dense tensors: indices, values, and shape.
-values is a vector of size [n], shape is a vector of size [rank], and indices
-is a matrix of size [n, rank]. There are some sparse ops that use this triple
-to represent a single sparse tensor.
-
-Another reason to use wrappers is for ops that hold state. There are a few such
-ops (e.g. a variable) that have several companion ops for operating on that
-state. The Python API has classes for these ops where the constructor creates
-the op, and methods on that class add operations to the graph that operate on
-the state.
-
-#### Other Considerations
-
-- It is good to have a list of keywords used to rename op functions and
- arguments that collide with language keywords (or other symbols that will
- cause trouble, like the names of library functions or variables referenced
- in the generated code).
-- The function for adding a `Const` operation to a graph typically is a
- wrapper since the generated function will typically have redundant
- `DataType` inputs.
-
-### Gradients, functions and control flow
-
-At this time, support for gradients, functions and control flow operations ("if"
-and "while") is not available in languages other than Python. This will be
-updated when the [C API] provides necessary support.
-
-[C API]: https://www.tensorflow.org/code/tensorflow/c/c_api.h
diff --git a/tensorflow/docs_src/extend/leftnav_files b/tensorflow/docs_src/extend/leftnav_files
deleted file mode 100644
index 12315b711b..0000000000
--- a/tensorflow/docs_src/extend/leftnav_files
+++ /dev/null
@@ -1,7 +0,0 @@
-index.md
-architecture.md
-adding_an_op.md
-add_filesys.md
-new_data_formats.md
-language_bindings.md
-tool_developers/index.md
diff --git a/tensorflow/docs_src/extend/new_data_formats.md b/tensorflow/docs_src/extend/new_data_formats.md
deleted file mode 100644
index 47a8344b70..0000000000
--- a/tensorflow/docs_src/extend/new_data_formats.md
+++ /dev/null
@@ -1,305 +0,0 @@
-# Reading custom file and record formats
-
-PREREQUISITES:
-
-* Some familiarity with C++.
-* Must have
- @{$install_sources$downloaded TensorFlow source}, and be
- able to build it.
-
-We divide the task of supporting a file format into two pieces:
-
-* File formats: We use a reader `tf.data.Dataset` to read raw *records* (which
- are typically represented by scalar string tensors, but can have more
- structure) from a file.
-* Record formats: We use decoder or parsing ops to turn a string record
- into tensors usable by TensorFlow.
-
-For example, to re-implement the `tf.contrib.data.make_csv_dataset` function,
-we could use `tf.data.TextLineDataset` to extract the records, and then use
-`tf.data.Dataset.map` and `tf.decode_csv` to parse the CSV records from each
-line of text in the dataset.
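-
-A minimal sketch of that approach (assuming a file `data.csv` with two float
-columns):
-
-```python
-import tensorflow as tf
-
-dataset = tf.data.TextLineDataset("data.csv")
-# Each line becomes a tuple of two float32 scalars.
-dataset = dataset.map(
-    lambda line: tf.decode_csv(line, record_defaults=[[0.0], [0.0]]))
-```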
-
-[TOC]
-
-## Writing a `Dataset` for a file format
-
-A `tf.data.Dataset` represents a sequence of *elements*, which can be the
-individual records in a file. There are several examples of "reader" datasets
-that are already built into TensorFlow:
-
-* `tf.data.TFRecordDataset`
- ([source in `kernels/data/reader_dataset_ops.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/data/reader_dataset_ops.cc))
-* `tf.data.FixedLengthRecordDataset`
- ([source in `kernels/data/reader_dataset_ops.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/data/reader_dataset_ops.cc))
-* `tf.data.TextLineDataset`
- ([source in `kernels/data/reader_dataset_ops.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/data/reader_dataset_ops.cc))
-
-Each of these implementations comprises three related classes:
-
-* A `tensorflow::DatasetOpKernel` subclass (e.g. `TextLineDatasetOp`), which
- tells TensorFlow how to construct a dataset object from the inputs to and
- attrs of an op, in its `MakeDataset()` method.
-
-* A `tensorflow::GraphDatasetBase` subclass (e.g. `TextLineDatasetOp::Dataset`),
- which represents the *immutable* definition of the dataset itself, and tells
- TensorFlow how to construct an iterator object over that dataset, in its
- `MakeIteratorInternal()` method.
-
-* A `tensorflow::DatasetIterator<Dataset>` subclass (e.g.
- `TextLineDatasetOp::Dataset::Iterator`), which represents the *mutable* state
- of an iterator over a particular dataset, and tells TensorFlow how to get the
- next element from the iterator, in its `GetNextInternal()` method.
-
-The most important method is `GetNextInternal()`, since it defines
-how to actually read records from the file and represent them as one or more
-`Tensor` objects.
-
-To create a new reader dataset called (for example) `MyReaderDataset`, you will
-need to:
-
-1. In C++, define subclasses of `tensorflow::DatasetOpKernel`,
- `tensorflow::GraphDatasetBase`, and `tensorflow::DatasetIterator<Dataset>`
- that implement the reading logic.
-2. In C++, register a new reader op and kernel with the name
- `"MyReaderDataset"`.
-3. In Python, define a subclass of `tf.data.Dataset` called `MyReaderDataset`.
-
-You can put all the C++ code in a single file, such as
-`my_reader_dataset_op.cc`. It will help if you are
-familiar with @{$adding_an_op$the adding an op how-to}. The following skeleton
-can be used as a starting point for your implementation:
-
-```c++
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/dataset.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/shape_inference.h"
-
-namespace myproject {
-namespace {
-
-using ::tensorflow::DT_STRING;
-using ::tensorflow::PartialTensorShape;
-using ::tensorflow::Status;
-
-class MyReaderDatasetOp : public tensorflow::DatasetOpKernel {
- public:
-
-  MyReaderDatasetOp(tensorflow::OpKernelConstruction* ctx)
-      : DatasetOpKernel(ctx) {
-    // Parse and validate any attrs that define the dataset using
-    // `ctx->GetAttr()`, and store them in member variables.
-  }
-
-  void MakeDataset(tensorflow::OpKernelContext* ctx,
-                   tensorflow::DatasetBase** output) override {
-    // Parse and validate any input tensors that define the dataset using
-    // `ctx->input()` or the utility function
-    // `ParseScalarArgument<T>(ctx, &arg)`.
-
-    // Create the dataset object, passing any (already-validated) arguments
-    // from attrs or input tensors.
-    *output = new Dataset(ctx);
-  }
-
- private:
-  class Dataset : public tensorflow::GraphDatasetBase {
-   public:
-    Dataset(tensorflow::OpKernelContext* ctx) : GraphDatasetBase(ctx) {}
-
-    std::unique_ptr<tensorflow::IteratorBase> MakeIteratorInternal(
-        const string& prefix) const override {
-      return std::unique_ptr<tensorflow::IteratorBase>(new Iterator(
-          {this, tensorflow::strings::StrCat(prefix, "::MyReader")}));
-    }
-
-    // Record structure: Each record is represented by a scalar string tensor.
-    //
-    // Dataset elements can have a fixed number of components of different
-    // types and shapes; replace the following two methods to customize this
-    // aspect of the dataset.
-    const tensorflow::DataTypeVector& output_dtypes() const override {
-      static auto* const dtypes = new tensorflow::DataTypeVector({DT_STRING});
-      return *dtypes;
-    }
-    const std::vector<PartialTensorShape>& output_shapes() const override {
-      static std::vector<PartialTensorShape>* shapes =
-          new std::vector<PartialTensorShape>({{}});
-      return *shapes;
-    }
-
-    string DebugString() const override { return "MyReaderDatasetOp::Dataset"; }
-
-   protected:
-    // Optional: Implementation of `GraphDef` serialization for this dataset.
-    //
-    // Implement this method if you want to be able to save and restore
-    // instances of this dataset (and any iterators over it).
-    Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
-                              tensorflow::Node** output) const override {
-      // Construct nodes to represent any of the input tensors from this
-      // object's member variables using `b->AddScalar()` and `b->AddVector()`.
-      std::vector<tensorflow::Node*> input_tensors;
-      TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output));
-      return Status::OK();
-    }
-
-   private:
-    class Iterator : public tensorflow::DatasetIterator<Dataset> {
-     public:
-      explicit Iterator(const Params& params)
-          : DatasetIterator<Dataset>(params), i_(0) {}
-
-      // Implementation of the reading logic.
-      //
-      // The example implementation in this file yields the string "MyReader!"
-      // ten times. In general there are three cases:
-      //
-      // 1. If an element is successfully read, store it as one or more tensors
-      //    in `*out_tensors`, set `*end_of_sequence = false` and return
-      //    `Status::OK()`.
-      // 2. If the end of input is reached, set `*end_of_sequence = true` and
-      //    return `Status::OK()`.
-      // 3. If an error occurs, return an error status using one of the helper
-      //    functions from "tensorflow/core/lib/core/errors.h".
-      Status GetNextInternal(tensorflow::IteratorContext* ctx,
-                             std::vector<tensorflow::Tensor>* out_tensors,
-                             bool* end_of_sequence) override {
-        // NOTE: `GetNextInternal()` may be called concurrently, so it is
-        // recommended that you protect the iterator state with a mutex.
-        tensorflow::mutex_lock l(mu_);
-        if (i_ < 10) {
-          // Create a scalar string tensor and add it to the output.
-          tensorflow::Tensor record_tensor(ctx->allocator({}), DT_STRING, {});
-          record_tensor.scalar<string>()() = "MyReader!";
-          out_tensors->emplace_back(std::move(record_tensor));
-          ++i_;
-          *end_of_sequence = false;
-        } else {
-          *end_of_sequence = true;
-        }
-        return Status::OK();
-      }
-
-     protected:
-      // Optional: Implementation of iterator state serialization for this
-      // iterator.
-      //
-      // Implement these two methods if you want to be able to save and restore
-      // instances of this iterator.
-      Status SaveInternal(tensorflow::IteratorStateWriter* writer) override {
-        tensorflow::mutex_lock l(mu_);
-        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
-        return Status::OK();
-      }
-      Status RestoreInternal(tensorflow::IteratorContext* ctx,
-                             tensorflow::IteratorStateReader* reader) override {
-        tensorflow::mutex_lock l(mu_);
-        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
-        return Status::OK();
-      }
-
-     private:
-      tensorflow::mutex mu_;
-      int64 i_ GUARDED_BY(mu_);
-    };
-  };
-};
-
-// Register the op definition for MyReaderDataset.
-//
-// Dataset ops always have a single output, of type `variant`, which represents
-// the constructed `Dataset` object.
-//
-// Add any attrs and input tensors that define the dataset here.
-REGISTER_OP("MyReaderDataset")
-    .Output("handle: variant")
-    .SetIsStateful()
-    .SetShapeFn(tensorflow::shape_inference::ScalarShape);
-
-// Register the kernel implementation for MyReaderDataset.
-REGISTER_KERNEL_BUILDER(Name("MyReaderDataset").Device(tensorflow::DEVICE_CPU),
-                        MyReaderDatasetOp);
-
-}  // namespace
-}  // namespace myproject
-```
-
-The last step is to build the C++ code and add a Python wrapper. The easiest way
-to do this is by @{$adding_an_op#build_the_op_library$compiling a dynamic
-library} (e.g. called `"my_reader_dataset_op.so"`), and adding a Python class
-that subclasses `tf.data.Dataset` to wrap it. An example Python program is
-given here:
-
-```python
-import tensorflow as tf
-
-# Assumes the file is in the current working directory.
-my_reader_dataset_module = tf.load_op_library("./my_reader_dataset_op.so")
-
-class MyReaderDataset(tf.data.Dataset):
-
-  def __init__(self):
-    super(MyReaderDataset, self).__init__()
-    # Create any input attrs or tensors as members of this class.
-
-  def _as_variant_tensor(self):
-    # Actually construct the graph node for the dataset op.
-    #
-    # This method will be invoked when you create an iterator on this dataset
-    # or a dataset derived from it.
-    return my_reader_dataset_module.my_reader_dataset()
-
-  # The following properties define the structure of each element: a scalar
-  # `tf.string` tensor. Change these properties to match the `output_dtypes()`
-  # and `output_shapes()` methods of `MyReaderDataset::Dataset` if you modify
-  # the structure of each element.
-  @property
-  def output_types(self):
-    return tf.string
-
-  @property
-  def output_shapes(self):
-    return tf.TensorShape([])
-
-  @property
-  def output_classes(self):
-    return tf.Tensor
-
-if __name__ == "__main__":
-  # Create a MyReaderDataset and print its elements.
-  with tf.Session() as sess:
-    iterator = MyReaderDataset().make_one_shot_iterator()
-    next_element = iterator.get_next()
-    try:
-      while True:
-        print(sess.run(next_element))  # Prints "MyReader!" ten times.
-    except tf.errors.OutOfRangeError:
-      pass
-```
-
-You can see some examples of `Dataset` wrapper classes in
-[`tensorflow/python/data/ops/dataset_ops.py`](https://www.tensorflow.org/code/tensorflow/python/data/ops/dataset_ops.py).
-
-## Writing an Op for a record format
-
-Generally this is an ordinary op that takes a scalar string record as input,
-so follow @{$adding_an_op$the instructions to add an Op}.
-You may optionally take a scalar string key as input, and include that in
-error messages reporting improperly formatted data. That way users can more
-easily track down where the bad data came from.
-
-Examples of Ops useful for decoding records:
-
-* `tf.parse_single_example` (and `tf.parse_example`)
-* `tf.decode_csv`
-* `tf.decode_raw`
-
-Note that it can be useful to use multiple Ops to decode a particular record
-format. For example, you may have an image saved as a string in
-[a `tf.train.Example` protocol buffer](https://www.tensorflow.org/code/tensorflow/core/example/example.proto).
-Depending on the format of that image, you might take the corresponding output
-from a `tf.parse_single_example` op and call `tf.image.decode_jpeg`,
-`tf.image.decode_png`, or `tf.decode_raw`. It is common to take the output
-of `tf.decode_raw` and use `tf.slice` and `tf.reshape` to extract pieces.
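-
-A minimal sketch of that pattern (the feature name `"image_raw"`, the image
-shape, and `serialized_example` are assumptions for illustration):
-
-```python
-import tensorflow as tf
-
-features = tf.parse_single_example(
-    serialized_example,
-    features={"image_raw": tf.FixedLenFeature([], tf.string)})
-image = tf.decode_raw(features["image_raw"], tf.uint8)  # bytes -> uint8
-image = tf.reshape(image, [28, 28, 1])
-```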
diff --git a/tensorflow/docs_src/extend/tool_developers/index.md b/tensorflow/docs_src/extend/tool_developers/index.md
deleted file mode 100644
index f02cd23be8..0000000000
--- a/tensorflow/docs_src/extend/tool_developers/index.md
+++ /dev/null
@@ -1,186 +0,0 @@
-# A Tool Developer's Guide to TensorFlow Model Files
-
-Most users shouldn't need to care about the internal details of how TensorFlow
-stores data on disk, but you might if you're a tool developer. For example, you
-may want to analyze models, or convert back and forth between TensorFlow and
-other formats. This guide tries to explain some of the details of how you can
-work with the main files that hold model data, to make it easier to develop
-those kinds of tools.
-
-[TOC]
-
-## Protocol Buffers
-
-All of TensorFlow's file formats are based on
-[Protocol Buffers](https://developers.google.com/protocol-buffers/?hl=en), so to
-start it's worth getting familiar with how they work. The summary is that you
-define data structures in text files, and the protobuf tools generate classes in
-C++, Python, and other languages that can load, save, and access the data in a
-friendly way. We often refer to Protocol Buffers as protobufs, and I'll use
-that convention in this guide.
-
-## GraphDef
-
-The foundation of computation in TensorFlow is the `Graph` object. This holds a
-network of nodes, each representing one operation, connected to each other as
-inputs and outputs. After you've created a `Graph` object, you can save it out
-by calling `as_graph_def()`, which returns a `GraphDef` object.
-
-The GraphDef class is an object created by the ProtoBuf library from the
-definition in
-[tensorflow/core/framework/graph.proto](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto). The protobuf tools parse
-this text file, and generate the code to load, store, and manipulate graph
-definitions. If you see a standalone TensorFlow file representing a model, it's
-likely to contain a serialized version of one of these `GraphDef` objects
-saved out by the protobuf code.
-
-This generated code is used to save and load the GraphDef files from disk. The code that actually loads the model looks like this:
-
-```python
-graph_def = graph_pb2.GraphDef()
-```
-
-This line creates an empty `GraphDef` object, the class that's been created
-from the textual definition in graph.proto. This is the object we're going to
-populate with the data from our file.
-
-```python
-with open(FLAGS.graph, "rb") as f:
-```
-
-Here we get a file handle for the path we've passed in to the script.
-
-```python
-  if FLAGS.input_binary:
-    graph_def.ParseFromString(f.read())
-  else:
-    text_format.Merge(f.read(), graph_def)
-```
-
-## Text or Binary?
-
-There are actually two different formats that a ProtoBuf can be saved in.
-TextFormat is a human-readable form, which makes it nice for debugging and
-editing, but can get large when there's numerical data like weights stored in
-it. You can see a small example of that in
-[graph_run_run2.pbtxt](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/demo/data/graph_run_run2.pbtxt).
-
-Binary format files are a lot smaller than their text equivalents, even though
-they're not as readable for us. In this script, we ask the user to supply a
-flag indicating whether the input file is binary or text, so we know the right
-function to call. You can find an example of a large binary file inside the
-[inception_v3 archive](https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz),
-as `inception_v3_2016_08_28_frozen.pb`.
-
-The API itself can be a bit confusing - the binary call is actually
-`ParseFromString()`, whereas you use a utility function from the `text_format`
-module to load textual files.
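-
-Putting the fragments above together, a minimal loader might look like this
-(using the same `graph_pb2` module as the original script):
-
-```python
-from google.protobuf import text_format
-from tensorflow.core.framework import graph_pb2
-
-def load_graph_def(path, input_binary):
-  graph_def = graph_pb2.GraphDef()
-  with open(path, "rb" if input_binary else "r") as f:
-    if input_binary:
-      graph_def.ParseFromString(f.read())
-    else:
-      text_format.Merge(f.read(), graph_def)
-  return graph_def
-```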
-
-## Nodes
-
-Once you've loaded a file into the `graph_def` variable, you can now access the
-data inside it. For most practical purposes, the important section is the list
-of nodes stored in the `node` member. Here's the code that loops through those:
-
-```python
-for node in graph_def.node:
-  print(node.name)  # e.g. inspect each NodeDef
-```
-
-Each node is a `NodeDef` object, defined in
-[tensorflow/core/framework/node_def.proto](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/node_def.proto). These
-are the fundamental building blocks of TensorFlow graphs, with each one defining
-a single operation along with its input connections. Here are the members of a
-`NodeDef`, and what they mean.
-
-### `name`
-
-Every node should have a unique identifier that's not used by any other nodes
-in the graph. If you don't specify one as you're building a graph using the
-Python API, one reflecting the name of the operation, such as "MatMul",
-concatenated with a monotonically increasing number, such as "5", will be
-picked for you. The name is used when defining the connections between nodes,
-and when setting inputs and outputs for the whole graph when it's run.
-
-### `op`
-
-This defines what operation to run, for example `"Add"`, `"MatMul"`, or
-`"Conv2D"`. When a graph is run, this op name is looked up in a registry to
-find an implementation. The registry is populated by calls to the
-`REGISTER_OP()` macro, like those in
-[tensorflow/core/ops/nn_ops.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/nn_ops.cc).
-
-### `input`
-
-A list of strings, each one of which is the name of another node, optionally
-followed by a colon and an output port number. For example, a node with two
-inputs might have a list like `["some_node_name", "another_node_name"]`, which
-is equivalent to `["some_node_name:0", "another_node_name:0"]`, and defines the
-node's first input as the first output from the node with the name
-`"some_node_name"`, and a second input from the first output of
-`"another_node_name"`
-
-### `device`
-
-In most cases you can ignore this, since it defines where to run a node in a
-distributed environment, or when you want to force the operation onto CPU or
-GPU.
-
-### `attr`
-
-This is a key/value store holding all the attributes of a node. These are the
-permanent properties of nodes, things that don't change at runtime such as the
-size of filters for convolutions, or the values of constant ops. Because there
-can be so many different types of attribute values, from strings, to ints, to
-arrays of tensor values, there's a separate protobuf file defining the data
-structure that holds them, in
-[tensorflow/core/framework/attr_value.proto](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/attr_value.proto).
-
-Each attribute has a unique name string, and the expected attributes are listed
-when the operation is defined. If an attribute isn't present in a node, but it
-has a default listed in the operation definition, that default is used when the
-graph is created.
-
-You can access all of these members by calling `node.name`, `node.op`, etc. in
-Python. The list of nodes stored in the `GraphDef` is a full definition of the
-model architecture.
-
-## Freezing
-
-One confusing part about this is that the weights usually aren't stored inside
-the file format during training. Instead, they're held in separate checkpoint
-files, and there are `Variable` ops in the graph that load the latest values
-when they're initialized. It's often not very convenient to have separate files
-when you're deploying to production, so there's the
-[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py) script that takes a graph definition and a set
-of checkpoints and freezes them together into a single file.
-
-What this does is load the `GraphDef`, pull in the values for all the variables
-from the latest checkpoint file, and then replace each `Variable` op with a
-`Const` that has the numerical data for the weights stored in its attributes.
-It then strips away all the extraneous nodes that aren't used for forward
-inference, and saves out the resulting `GraphDef` into an output file.
-
-## Weight Formats
-
-If you're dealing with TensorFlow models that represent neural networks, one of
-the most common problems is extracting and interpreting the weight values. A
-common way to store them, for example in graphs created by the freeze_graph
-script, is as `Const` ops containing the weights as `Tensors`. These are
-defined in
-[tensorflow/core/framework/tensor.proto](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/tensor.proto), and contain information
-about the size and type of the data, as well as the values themselves. In
-Python, you get a `TensorProto` object from a `NodeDef` representing a `Const`
-op by calling something like `some_node_def.attr['value'].tensor`.
-
-This will give you an object representing the weights data. The data itself
-will be stored in one of the lists with the suffix `_val`, as indicated by the
-type of the object, for example `float_val` for 32-bit float data types.
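-
-A minimal sketch of pulling the weights out of every `Const` node in a loaded
-`graph_def`, using the `tensor_util` helper to convert a `TensorProto` into a
-NumPy array:
-
-```python
-from tensorflow.python.framework import tensor_util
-
-for node in graph_def.node:
-  if node.op == "Const":
-    weights = tensor_util.MakeNdarray(node.attr["value"].tensor)
-    print(node.name, weights.shape, weights.dtype)
-```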
-
-The ordering of convolution weight values is often tricky to deal with when
-converting between different frameworks. In TensorFlow, the filter weights for
-the `Conv2D` operation are stored on the second input, and are expected to be
-in the order `[filter_height, filter_width, input_depth, output_depth]`, where
-`output_depth` increasing by one means moving to an adjacent value in memory.
-
-Hopefully this rundown gives you a better idea of what's going on inside
-TensorFlow model files, and will help you if you ever need to manipulate them.
diff --git a/tensorflow/docs_src/extras/README.txt b/tensorflow/docs_src/extras/README.txt
deleted file mode 100644
index 765809a762..0000000000
--- a/tensorflow/docs_src/extras/README.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-This directory holds extra files we'd like to be able
-to link to and serve from within tensorflow.org.
-They are excluded from versioning.
\ No newline at end of file
diff --git a/tensorflow/docs_src/guide/autograph.md b/tensorflow/docs_src/guide/autograph.md
deleted file mode 100644
index 823e1c6d6b..0000000000
--- a/tensorflow/docs_src/guide/autograph.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# AutoGraph: Easy control flow for graphs
-
-[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/guide/autograph.ipynb)
diff --git a/tensorflow/docs_src/guide/checkpoints.md b/tensorflow/docs_src/guide/checkpoints.md
deleted file mode 100644
index e1add29852..0000000000
--- a/tensorflow/docs_src/guide/checkpoints.md
+++ /dev/null
@@ -1,238 +0,0 @@
-# Checkpoints
-
-This document examines how to save and restore TensorFlow models built with
-Estimators. TensorFlow provides two model formats:
-
-* checkpoints, a format that depends on the code that created
-  the model.
-* SavedModel, a format that is independent of the code that created
-  the model.
-
-This document focuses on checkpoints. For details on `SavedModel`, see the
-@{$saved_model$Saving and Restoring} guide.
-
-
-## Sample code
-
-This document relies on the same
-[Iris classification example](https://github.com/tensorflow/models/blob/master/samples/core/get_started/premade_estimator.py) detailed in @{$premade_estimators$Getting Started with TensorFlow}.
-To download and access the example, invoke the following two commands:
-
-```shell
-git clone https://github.com/tensorflow/models/
-cd models/samples/core/get_started
-```
-
-Most of the code snippets in this document are minor variations
-on `premade_estimator.py`.
-
-
-## Saving partially-trained models
-
-Estimators automatically write the following to disk:
-
-* **checkpoints**, which are versions of the model created during training.
-* **event files**, which contain information that
- [TensorBoard](https://developers.google.com/machine-learning/glossary/#TensorBoard)
- uses to create visualizations.
-
-To specify the top-level directory in which the Estimator stores its
-information, assign a value to the optional `model_dir` argument of *any*
-`Estimator`'s constructor.
-Taking `DNNClassifier` as an example,
-the following code sets the `model_dir`
-argument to the `models/iris` directory:
-
-```python
-classifier = tf.estimator.DNNClassifier(
-    feature_columns=my_feature_columns,
-    hidden_units=[10, 10],
-    n_classes=3,
-    model_dir='models/iris')
-```
-
-Suppose you call the Estimator's `train` method. For example:
-
-
-```python
-classifier.train(
-    input_fn=lambda: train_input_fn(train_x, train_y, batch_size=100),
-    steps=200)
-```
-
-As suggested by the following diagrams, the first call to `train`
-adds checkpoints and other files to the `model_dir` directory:
-
-<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="../images/first_train_calls.png">
-</div>
-<div style="text-align: center">
-The first call to train().
-</div>
-
-
-To see the objects in the created `model_dir` directory on a
-UNIX-based system, just call `ls` as follows:
-
-```none
-$ ls -1 models/iris
-checkpoint
-events.out.tfevents.timestamp.hostname
-graph.pbtxt
-model.ckpt-1.data-00000-of-00001
-model.ckpt-1.index
-model.ckpt-1.meta
-model.ckpt-200.data-00000-of-00001
-model.ckpt-200.index
-model.ckpt-200.meta
-```
-
-The preceding `ls` command shows that the Estimator created checkpoints
-at steps 1 (the start of training) and 200 (the end of training).
-
-
-### Default checkpoint directory
-
-If you don't specify `model_dir` in an Estimator's constructor, the Estimator
-writes checkpoint files to a temporary directory chosen by Python's
-[tempfile.mkdtemp](https://docs.python.org/3/library/tempfile.html#tempfile.mkdtemp)
-function. For example, the following Estimator constructor does *not* specify
-the `model_dir` argument:
-
-```python
-classifier = tf.estimator.DNNClassifier(
-    feature_columns=my_feature_columns,
-    hidden_units=[10, 10],
-    n_classes=3)
-
-print(classifier.model_dir)
-```
-
-The `tempfile.mkdtemp` function picks a secure, temporary directory
-appropriate for your operating system. For example, a typical temporary
-directory on macOS might be something like the following:
-
-```None
-/var/folders/0s/5q9kfzfj3gx2knj0vj8p68yc00dhcr/T/tmpYm1Rwa
-```
-
-### Checkpointing Frequency
-
-By default, the Estimator saves
-[checkpoints](https://developers.google.com/machine-learning/glossary/#checkpoint)
-in the `model_dir` according to the following schedule:
-
-* Writes a checkpoint every 10 minutes (600 seconds).
-* Writes a checkpoint when the `train` method starts (first iteration)
- and completes (final iteration).
-* Retains only the 5 most recent checkpoints in the directory.
-
-You may alter the default schedule by taking the following steps:
-
-1. Create a `tf.estimator.RunConfig` object that defines the
- desired schedule.
-2. When instantiating the Estimator, pass that `RunConfig` object to the
- Estimator's `config` argument.
-
-For example, the following code changes the checkpointing schedule to every
-20 minutes and retains the 10 most recent checkpoints:
-
-```python
-my_checkpointing_config = tf.estimator.RunConfig(
-    save_checkpoints_secs=20*60,  # Save checkpoints every 20 minutes.
-    keep_checkpoint_max=10,       # Retain the 10 most recent checkpoints.
-)
-
-classifier = tf.estimator.DNNClassifier(
-    feature_columns=my_feature_columns,
-    hidden_units=[10, 10],
-    n_classes=3,
-    model_dir='models/iris',
-    config=my_checkpointing_config)
-```
-
-## Restoring your model
-
-The first time you call an Estimator's `train` method, TensorFlow saves a
-checkpoint to the `model_dir`. Each subsequent call to the Estimator's
-`train`, `evaluate`, or `predict` method causes the following:
-
-1. The Estimator builds the model's
- [graph](https://developers.google.com/machine-learning/glossary/#graph)
- by running the `model_fn()`. (For details on the `model_fn()`, see
- @{$custom_estimators$Creating Custom Estimators.})
-2. The Estimator initializes the weights of the new model from the data
- stored in the most recent checkpoint.
-
-In other words, as the following illustration suggests, once checkpoints
-exist, TensorFlow rebuilds the model each time you call `train()`,
-`evaluate()`, or `predict()`.
-
-<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="../images/subsequent_calls.png">
-</div>
-<div style="text-align: center">
-Subsequent calls to train(), evaluate(), or predict()
-</div>
-
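-For example, constructing an identical Estimator in a fresh process and then
-calling `evaluate` restores the weights from the latest checkpoint in
-`models/iris`. (A sketch; `eval_input_fn`, `test_x`, and `test_y` are assumed
-from the accompanying example code, not defined in this document.)
-
-```python
-classifier = tf.estimator.DNNClassifier(
-    feature_columns=my_feature_columns,
-    hidden_units=[10, 10],
-    n_classes=3,
-    model_dir='models/iris')
-
-# No call to `train` is needed here; `evaluate` rebuilds the graph and
-# initializes the weights from the most recent checkpoint in `model_dir`.
-eval_result = classifier.evaluate(
-    input_fn=lambda:eval_input_fn(test_x, test_y, batch_size=100))
-```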
-
-### Avoiding a bad restoration
-
-Restoring a model's state from a checkpoint only works if the model
-and checkpoint are compatible. For example, suppose you trained a
-`DNNClassifier` Estimator containing two hidden layers,
-each having 10 nodes:
-
-```python
-classifier = tf.estimator.DNNClassifier(
- feature_columns=feature_columns,
- hidden_units=[10, 10],
- n_classes=3,
- model_dir='models/iris')
-
-classifier.train(
- input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
- steps=200)
-```
-
-After training (and, therefore, after creating checkpoints in `models/iris`),
-imagine that you changed the number of neurons in each hidden layer from 10 to
-20 and then attempted to retrain the model:
-
-``` python
-classifier2 = tf.estimator.DNNClassifier(
- feature_columns=my_feature_columns,
- hidden_units=[20, 20], # Change the number of neurons in the model.
- n_classes=3,
- model_dir='models/iris')
-
-classifier2.train(
- input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
- steps=200)
-```
-
-Since the state in the checkpoint is incompatible with the model described
-in `classifier2`, retraining fails with the following error:
-
-```none
-...
-InvalidArgumentError (see above for traceback): tensor_name =
-dnn/hiddenlayer_1/bias/t_0/Adagrad; shape in shape_and_slice spec [10]
-does not match the shape stored in checkpoint: [20]
-```
-
-To run experiments in which you train and compare slightly different
-versions of a model, save a copy of the code that created each
-`model_dir`, possibly by creating a separate git branch for each version.
-This separation will keep your checkpoints recoverable.
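-
-For example, one possible layout (the branch and directory names here are
-illustrative, not part of the example code):
-
-```shell
-# One git branch and one model_dir per experiment variant.
-git checkout -b iris-hidden-20   # code for the 20-unit variant lives here
-# In this branch, construct the Estimator with a fresh directory, for
-# example model_dir='models/iris_hidden_20', so checkpoints never collide.
-```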
-
-## Summary
-
-Checkpoints provide an easy automatic mechanism for saving and restoring
-models created by Estimators.
-
-See the @{$saved_model$Saving and Restoring} guide for details about:
-
-* Saving and restoring models using low-level TensorFlow APIs.
-* Exporting and importing models in the SavedModel format, which is a
- language-neutral, recoverable, serialization format.
diff --git a/tensorflow/docs_src/guide/custom_estimators.md b/tensorflow/docs_src/guide/custom_estimators.md
deleted file mode 100644
index 199a0e93de..0000000000
--- a/tensorflow/docs_src/guide/custom_estimators.md
+++ /dev/null
@@ -1,602 +0,0 @@
-
-# Creating Custom Estimators
-
-This document introduces custom Estimators. In particular, this document
-demonstrates how to create a custom `tf.estimator.Estimator` that
-mimics the behavior of the pre-made Estimator
-`tf.estimator.DNNClassifier` in solving the Iris problem. See
-the @{$premade_estimators$Pre-Made Estimators chapter} for details
-on the Iris problem.
-
-To download and access the example code, invoke the following two commands:
-
-```shell
-git clone https://github.com/tensorflow/models/
-cd models/samples/core/get_started
-```
-
-In this document we will be looking at
-[`custom_estimator.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/custom_estimator.py).
-You can run it with the following command:
-
-```shell
-python custom_estimator.py
-```
-
-If you are feeling impatient, feel free to compare and contrast
-[`custom_estimator.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/custom_estimator.py)
-with
-[`premade_estimator.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/premade_estimator.py),
-which is in the same directory.
-
-
-
-## Pre-made vs. custom
-
-As the following figure shows, pre-made Estimators are subclasses of the
-`tf.estimator.Estimator` base class, while custom Estimators are (usually
-direct) instances of `tf.estimator.Estimator`:
-
-<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="display:block; margin: 0 auto"
- alt="Premade estimators are sub-classes of `Estimator`. Custom Estimators are usually (direct) instances of `Estimator`"
- src="../images/custom_estimators/estimator_types.png">
-</div>
-<div style="text-align: center">
-Pre-made and custom Estimators are all Estimators.
-</div>
-
-Pre-made Estimators are fully baked. Sometimes though, you need more control
-over an Estimator's behavior. That's where custom Estimators come in. You can
-create a custom Estimator to do just about anything. If you want hidden layers
-connected in some unusual fashion, write a custom Estimator. If you want to
-calculate a unique
-[metric](https://developers.google.com/machine-learning/glossary/#metric)
-for your model, write a custom Estimator. Basically, if you want an Estimator
-optimized for your specific problem, write a custom Estimator.
-
-A model function (or `model_fn`) implements the ML algorithm. The
-only difference between working with pre-made Estimators and custom Estimators
-is:
-
-* With pre-made Estimators, someone already wrote the model function for you.
-* With custom Estimators, you must write the model function.
-
-Your model function could implement a wide range of algorithms, defining all
-sorts of hidden layers and metrics. Like input functions, all model functions
-must accept a standard group of input parameters and return a standard group of
-output values. Just as input functions can leverage the Dataset API, model
-functions can leverage the Layers API and the Metrics API.
-
-Let's see how to solve the Iris problem with a custom Estimator. A quick
-reminder--here's the organization of the Iris model that we're trying to mimic:
-
-<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="display:block; margin: 0 auto"
- alt="A diagram of the network architecture: Inputs, 2 hidden layers, and outputs"
- src="../images/custom_estimators/full_network.png">
-</div>
-<div style="text-align: center">
-Our implementation of Iris contains four features, two hidden layers,
-and a logits output layer.
-</div>
-
-## Write an Input function
-
-Our custom Estimator implementation uses the same input function as our
-@{$premade_estimators$pre-made Estimator implementation}, from
-[`iris_data.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py).
-Namely:
-
-```python
-def train_input_fn(features, labels, batch_size):
- """An input function for training"""
- # Convert the inputs to a Dataset.
- dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
-
- # Shuffle, repeat, and batch the examples.
- dataset = dataset.shuffle(1000).repeat().batch(batch_size)
-
- # Return the read end of the pipeline.
- return dataset.make_one_shot_iterator().get_next()
-```
-
-This input function builds an input pipeline that yields batches of
-`(features, labels)` pairs, where `features` is a dictionary of features.
-
-## Create feature columns
-
-As detailed in the @{$premade_estimators$Premade Estimators} and
-@{$feature_columns$Feature Columns} chapters, you must define
-your model's feature columns to specify how the model should use each feature.
-Whether working with pre-made Estimators or custom Estimators, you define
-feature columns in the same fashion.
-
-The following code creates a simple `numeric_column` for each input feature,
-indicating that the value of the input feature should be used directly as an
-input to the model:
-
-```python
-# Feature columns describe how to use the input.
-my_feature_columns = []
-for key in train_x.keys():
- my_feature_columns.append(tf.feature_column.numeric_column(key=key))
-```
-
-## Write a model function
-
-The model function we'll use has the following call signature:
-
-```python
-def my_model_fn(
- features, # This is batch_features from input_fn
- labels, # This is batch_labels from input_fn
- mode, # An instance of tf.estimator.ModeKeys
- params): # Additional configuration
-```
-
-The first two arguments are the batches of features and labels returned from
-the input function; that is, `features` and `labels` are the handles to the
-data your model will use. The `mode` argument indicates whether the caller is
-requesting training, prediction, or evaluation.
-
-The caller may pass `params` to an Estimator's constructor. Any `params` passed
-to the constructor are in turn passed on to the `model_fn`. In
-[`custom_estimator.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/custom_estimator.py)
-the following lines create the estimator and set the params to configure the
-model. This configuration step is similar to how we configured the `tf.estimator.DNNClassifier` in
-@{$premade_estimators}.
-
-```python
-classifier = tf.estimator.Estimator(
- model_fn=my_model_fn,
- params={
- 'feature_columns': my_feature_columns,
- # Two hidden layers of 10 nodes each.
- 'hidden_units': [10, 10],
- # The model must choose between 3 classes.
- 'n_classes': 3,
- })
-```
-
-To implement a typical model function, you must do the following:
-
-* [Define the model](#define_the_model).
-* Specify additional calculations for each of
- the [three different modes](#modes):
- * [Predict](#predict)
- * [Evaluate](#evaluate)
- * [Train](#train)
-
-## Define the model
-
-The basic deep neural network model must define the following three sections:
-
-* An [input layer](https://developers.google.com/machine-learning/glossary/#input_layer)
-* One or more [hidden layers](https://developers.google.com/machine-learning/glossary/#hidden_layer)
-* An [output layer](https://developers.google.com/machine-learning/glossary/#output_layer)
-
-### Define the input layer
-
-The first line of the `model_fn` calls `tf.feature_column.input_layer` to
-convert the feature dictionary and `feature_columns` into input for your model,
-as follows:
-
-```python
- # Use `input_layer` to apply the feature columns.
- net = tf.feature_column.input_layer(features, params['feature_columns'])
-```
-
-The preceding line applies the transformations defined by your feature columns,
-creating the model's input layer.
-
-<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="display:block; margin: 0 auto"
- alt="A diagram of the input layer, in this case a 1:1 mapping from raw-inputs to features."
- src="../images/custom_estimators/input_layer.png">
-</div>
-
-
-### Hidden Layers
-
-If you are creating a deep neural network, you must define one or more hidden
-layers. The Layers API provides a rich set of functions to define all types of
-hidden layers, including convolutional, pooling, and dropout layers. For Iris,
-we're simply going to call `tf.layers.dense` to create hidden layers, with
-dimensions defined by `params['hidden_units']`. In a `dense` layer each node
-is connected to every node in the preceding layer. Here's the relevant code:
-
-``` python
- # Build the hidden layers, sized according to the 'hidden_units' param.
- for units in params['hidden_units']:
- net = tf.layers.dense(net, units=units, activation=tf.nn.relu)
-```
-
-* The `units` parameter defines the number of output neurons in a given layer.
-* The `activation` parameter defines the [activation function](https://developers.google.com/machine-learning/glossary/#activation_function) —
- [Relu](https://developers.google.com/machine-learning/glossary/#ReLU) in this
- case.
-
-The variable `net` here signifies the current top layer of the network. During
-the first iteration, `net` signifies the input layer. On each loop iteration
-`tf.layers.dense` creates a new layer, which takes the previous layer's output
-as its input, using the variable `net`.
-
-After creating two hidden layers, our network looks as follows. For
-simplicity, the figure does not show all the units in each layer.
-
-<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="display:block; margin: 0 auto"
- alt="The input layer with two hidden layers added."
- src="../images/custom_estimators/add_hidden_layer.png">
-</div>
-
-Note that `tf.layers.dense` provides many additional capabilities, including
-the ability to set a multitude of regularization parameters. For the sake of
-simplicity, though, we're going to simply accept the default values of the
-other parameters.
-
-### Output Layer
-
-We'll define the output layer by calling `tf.layers.dense` yet again, this
-time without an activation function:
-
-```python
- # Compute logits (1 per class).
- logits = tf.layers.dense(net, params['n_classes'], activation=None)
-```
-
-Here, `net` signifies the final hidden layer. Therefore, the full set of layers
-is now connected as follows:
-
-<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="display:block; margin: 0 auto"
- alt="A logit output layer connected to the top hidden layer"
- src="../images/custom_estimators/add_logits.png">
-</div>
-<div style="text-align: center">
-The final hidden layer feeds into the output layer.
-</div>
-
-When defining an output layer, the `units` parameter specifies the number of
-outputs. So, by setting `units` to `params['n_classes']`, the model produces
-one output value per class. Each element of the output vector will contain the
-score, or "logit", calculated for the associated class of Iris: Setosa,
-Versicolor, or Virginica, respectively.
-
-Later on, these logits will be transformed into probabilities by the
-`tf.nn.softmax` function.
-
-## Implement training, evaluation, and prediction {#modes}
-
-The final step in creating a model function is to write branching code that
-implements prediction, evaluation, and training.
-
-The model function gets invoked whenever someone calls the Estimator's `train`,
-`evaluate`, or `predict` methods. Recall that the signature for the model
-function looks like this:
-
-``` python
-def my_model_fn(
- features, # This is batch_features from input_fn
- labels, # This is batch_labels from input_fn
- mode, # An instance of tf.estimator.ModeKeys, see below
- params): # Additional configuration
-```
-
-Focus on the third argument, `mode`. As the following table shows, when someone
-calls `train`, `evaluate`, or `predict`, the Estimator framework invokes your
-model function with the `mode` parameter set as follows:
-
-| Estimator method                  | Estimator Mode                  |
-|:----------------------------------|:--------------------------------|
-| `tf.estimator.Estimator.train`    | `tf.estimator.ModeKeys.TRAIN`   |
-| `tf.estimator.Estimator.evaluate` | `tf.estimator.ModeKeys.EVAL`    |
-| `tf.estimator.Estimator.predict`  | `tf.estimator.ModeKeys.PREDICT` |
-
-For example, suppose you instantiate a custom Estimator to generate an object
-named `classifier`. Then, you make the following call:
-
-``` python
-classifier = tf.estimator.Estimator(...)
-classifier.train(input_fn=lambda: my_input_fn(FILE_TRAIN, True, 500))
-```
-The Estimator framework then calls your model function with mode set to
-`ModeKeys.TRAIN`.
-
-Your model function must provide code to handle all three of the mode values.
-For each mode value, your code must return an instance of
-`tf.estimator.EstimatorSpec`, which contains the information the caller
-requires. Let's examine each mode.
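-
-Assembled from the pieces developed in the following sections, the complete
-branching structure looks like this; treat it as a preview, since each branch
-is explained in detail below:
-
-```python
-def my_model_fn(features, labels, mode, params):
-    # Build the network: input layer, hidden layers, and logits.
-    net = tf.feature_column.input_layer(features, params['feature_columns'])
-    for units in params['hidden_units']:
-        net = tf.layers.dense(net, units=units, activation=tf.nn.relu)
-    logits = tf.layers.dense(net, params['n_classes'], activation=None)
-
-    # PREDICT branch: return the predictions only.
-    predicted_classes = tf.argmax(logits, 1)
-    if mode == tf.estimator.ModeKeys.PREDICT:
-        predictions = {
-            'class_ids': predicted_classes[:, tf.newaxis],
-            'probabilities': tf.nn.softmax(logits),
-            'logits': logits,
-        }
-        return tf.estimator.EstimatorSpec(mode, predictions=predictions)
-
-    # The loss is needed by both the EVAL and TRAIN branches.
-    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
-
-    # EVAL branch: return the loss and any metrics.
-    accuracy = tf.metrics.accuracy(labels=labels,
-                                   predictions=predicted_classes,
-                                   name='acc_op')
-    if mode == tf.estimator.ModeKeys.EVAL:
-        return tf.estimator.EstimatorSpec(
-            mode, loss=loss, eval_metric_ops={'accuracy': accuracy})
-
-    # TRAIN branch: return the loss and a training operation.
-    optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
-    train_op = optimizer.minimize(loss,
-                                  global_step=tf.train.get_global_step())
-    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
-```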
-
-### Predict
-
-When the Estimator's `predict` method is called, the `model_fn` receives
-`mode = ModeKeys.PREDICT`. In this case, the model function must return a
-`tf.estimator.EstimatorSpec` containing the prediction.
-
-The model must have been trained prior to making a prediction. The trained model
-is stored on disk in the `model_dir` directory established when you
-instantiated the Estimator.
-
-The code to generate the prediction for this model looks as follows:
-
-```python
-# Compute predictions.
-predicted_classes = tf.argmax(logits, 1)
-if mode == tf.estimator.ModeKeys.PREDICT:
- predictions = {
- 'class_ids': predicted_classes[:, tf.newaxis],
- 'probabilities': tf.nn.softmax(logits),
- 'logits': logits,
- }
- return tf.estimator.EstimatorSpec(mode, predictions=predictions)
-```
-The prediction dictionary contains everything that your model returns when run
-in prediction mode.
-
-<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="display:block; margin: 0 auto"
- alt="Additional outputs added to the output layer."
- src="../images/custom_estimators/add_predictions.png">
-</div>
-
-The `predictions` dictionary holds the following three key/value pairs:
-
-* `class_ids` holds the class id (0, 1, or 2) representing the model's
- prediction of the most likely species for this example.
-* `probabilities` holds the three probabilities (in this example, 0.02, 0.95,
-  and 0.03).
-* `logits` holds the raw logit values (in this example, -1.3, 2.6, and -0.9).
-
-We return that dictionary to the caller via the `predictions` parameter of the
-`tf.estimator.EstimatorSpec`. The Estimator's
-`tf.estimator.Estimator.predict` method will yield these
-dictionaries.
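-
-As a usage sketch (`eval_input_fn` and `predict_x` are assumed from the
-accompanying example code, not defined in this document):
-
-```python
-predictions = classifier.predict(
-    input_fn=lambda:eval_input_fn(predict_x, labels=None, batch_size=100))
-for pred_dict in predictions:
-    print(pred_dict['class_ids'], pred_dict['probabilities'])
-```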
-
-### Calculate the loss
-
-For both [training](#train) and [evaluation](#evaluate) we need to calculate the
-model's loss. This is the
-[objective](https://developers.google.com/machine-learning/glossary/#objective)
-that will be optimized.
-
-We can calculate the loss by calling `tf.losses.sparse_softmax_cross_entropy`.
-The value returned by this function will be at its lowest, approximately 0,
-when the probability of the correct class (at index `label`) is near 1.0.
-The returned loss grows progressively larger as the probability of the
-correct class decreases.
-
-This function returns the average over the whole batch.
-
-```python
-# Compute loss.
-loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
-```
-
-### Evaluate
-
-When the Estimator's `evaluate` method is called, the `model_fn` receives
-`mode = ModeKeys.EVAL`. In this case, the model function must return a
-`tf.estimator.EstimatorSpec` containing the model's loss and optionally one
-or more metrics.
-
-Although returning metrics is optional, most custom Estimators do return at
-least one metric. TensorFlow provides a Metrics module `tf.metrics` to
-calculate common metrics. For brevity's sake, we'll only return accuracy. The
-`tf.metrics.accuracy` function compares our predictions against the
-true values, that is, against the labels provided by the input function. The
-`tf.metrics.accuracy` function requires the labels and predictions to have the
-same shape. Here's the call to `tf.metrics.accuracy`:
-
-``` python
-# Compute evaluation metrics.
-accuracy = tf.metrics.accuracy(labels=labels,
- predictions=predicted_classes,
- name='acc_op')
-```
-
-The `tf.estimator.EstimatorSpec` returned for evaluation
-typically contains the following information:
-
-* `loss`, which is the model's loss
-* `eval_metric_ops`, which is an optional dictionary of metrics.
-
-So, we'll create a dictionary containing our sole metric. If we had calculated
-other metrics, we would have added them as additional key/value pairs to that
-same dictionary. Then, we'll pass that dictionary in the `eval_metric_ops`
-argument of `tf.estimator.EstimatorSpec`. Here's the code:
-
-```python
-metrics = {'accuracy': accuracy}
-tf.summary.scalar('accuracy', accuracy[1])
-
-if mode == tf.estimator.ModeKeys.EVAL:
- return tf.estimator.EstimatorSpec(
- mode, loss=loss, eval_metric_ops=metrics)
-```
-
-The `tf.summary.scalar` will make accuracy available to TensorBoard
-in both `TRAIN` and `EVAL` modes. (More on this later).
-
-### Train
-
-When the Estimator's `train` method is called, the `model_fn` is called
-with `mode = ModeKeys.TRAIN`. In this case, the model function must return an
-`EstimatorSpec` that contains the loss and a training operation.
-
-Building the training operation will require an optimizer. We will use
-`tf.train.AdagradOptimizer` because we're mimicking the `DNNClassifier`, which
-also uses `Adagrad` by default. The `tf.train` package provides many other
-optimizers—feel free to experiment with them.
-
-Here is the code that builds the optimizer:
-
-``` python
-optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
-```
-
-Next, we build the training operation using the optimizer's
-`tf.train.Optimizer.minimize` method on the loss we calculated
-earlier.
-
-The `minimize` method also takes a `global_step` parameter. TensorFlow uses this
-parameter to count the number of training steps that have been processed
-(to know when to end a training run). Furthermore, the `global_step` is
-essential for TensorBoard graphs to work correctly. Simply call
-`tf.train.get_global_step` and pass the result to the `global_step`
-argument of `minimize`.
-
-Here's the code to train the model:
-
-``` python
-train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
-```
-
-The `tf.estimator.EstimatorSpec` returned for training
-must have the following fields set:
-
-* `loss`, which contains the value of the loss function.
-* `train_op`, which executes a training step.
-
-Here's our code to call `EstimatorSpec`:
-
-```python
-return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
-```
-
-The model function is now complete.
-
-## The custom Estimator
-
-Instantiate the custom Estimator through the Estimator base class as follows:
-
-```python
- # Build 2 hidden layer DNN with 10, 10 units respectively.
- classifier = tf.estimator.Estimator(
- model_fn=my_model_fn,
- params={
- 'feature_columns': my_feature_columns,
- # Two hidden layers of 10 nodes each.
- 'hidden_units': [10, 10],
- # The model must choose between 3 classes.
- 'n_classes': 3,
- })
-```
-Here the `params` dictionary serves the same purpose as the keyword
-arguments of `DNNClassifier`; that is, the `params` dictionary lets you
-configure your Estimator without modifying the code in the `model_fn`.
-
-The rest of the code to train, evaluate, and generate predictions using our
-Estimator is the same as in the
-@{$premade_estimators$Premade Estimators} chapter. For
-example, the following line will train the model:
-
-```python
-# Train the Model.
-classifier.train(
- input_fn=lambda:iris_data.train_input_fn(train_x, train_y, args.batch_size),
- steps=args.train_steps)
-```
-
-## TensorBoard
-
-You can view training results for your custom Estimator in TensorBoard. To see
-this reporting, start TensorBoard from your command line as follows:
-
-```shell
-# Replace PATH with the actual path passed as model_dir
-tensorboard --logdir=PATH
-```
-
-Then, open TensorBoard by browsing to: [http://localhost:6006](http://localhost:6006)
-
-All the pre-made Estimators automatically log a lot of information to
-TensorBoard. With custom Estimators, however, TensorBoard only provides one
-default log (a graph of the loss) plus the information you explicitly tell
-TensorBoard to log. For the custom Estimator you just created, TensorBoard
-generates the following:
-
-<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
-
-<img style="display:block; margin: 0 auto"
- alt="Accuracy, 'scalar' graph from tensorboard"
- src="../images/custom_estimators/accuracy.png">
-
-<img style="display:block; margin: 0 auto"
- alt="loss 'scalar' graph from tensorboard"
- src="../images/custom_estimators/loss.png">
-
-<img style="display:block; margin: 0 auto"
- alt="steps/second 'scalar' graph from tensorboard"
- src="../images/custom_estimators/steps_per_second.png">
-</div>
-
-<div style="text-align: center">
-TensorBoard displays three graphs.
-</div>
-
-
-In brief, here's what the three graphs tell you:
-
-* global_step/sec: A performance indicator showing how many batches (gradient
- updates) we processed per second as the model trains.
-
-* loss: The loss reported during training and evaluation.
-
-* accuracy: The accuracy is recorded by the following two lines:
-
-    * `eval_metric_ops={'accuracy': accuracy}`, during evaluation.
-    * `tf.summary.scalar('accuracy', accuracy[1])`, during training.
-
-These tensorboard graphs are one of the main reasons it's important to pass a
-`global_step` to your optimizer's `minimize` method. The model can't record
-the x-coordinate for these graphs without it.
-
-Note the following in the `accuracy` and `loss` graphs:
-
-* The orange line represents training.
-* The blue dot represents evaluation.
-
-During training, summaries (the orange line) are recorded periodically as
-batches are processed, which is why the graph spans the x-axis range.
-
-By contrast, evaluation produces only a single point on the graph for each call
-to `evaluate`. This point contains the average over the entire evaluation call.
-It has no width on the graph because it is evaluated entirely from the model
-state at a particular training step (from a single checkpoint).
-
-As suggested in the following figure, you can view and selectively
-enable or disable reporting using the controls on the left side.
-
-<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="display:block; margin: 0 auto"
- alt="Check-boxes allowing the user to select which runs are shown."
- src="../images/custom_estimators/select_run.jpg">
-</div>
-<div style="text-align: center">
-Enable or disable reporting.
-</div>
-
-
-## Summary
-
-Although pre-made Estimators can be an effective way to quickly create new
-models, you will often need the additional flexibility that custom Estimators
-provide. Fortunately, pre-made and custom Estimators follow the same
-programming model. The only practical difference is that you must write a model
-function for custom Estimators; everything else is the same.
-
-For more details, be sure to check out:
-
-* The
- [official TensorFlow implementation of MNIST](https://github.com/tensorflow/models/tree/master/official/mnist),
- which uses a custom estimator.
-* The TensorFlow
- [official models repository](https://github.com/tensorflow/models/tree/master/official),
- which contains more curated examples using custom estimators.
-* This [TensorBoard video](https://youtu.be/eBbEDRsCmv4), which introduces
- TensorBoard.
-* The @{$low_level_intro$Low Level Introduction}, which demonstrates
- how to experiment directly with TensorFlow's low level APIs, making debugging
- easier.
diff --git a/tensorflow/docs_src/guide/datasets.md b/tensorflow/docs_src/guide/datasets.md
deleted file mode 100644
index bb18e8b79c..0000000000
--- a/tensorflow/docs_src/guide/datasets.md
+++ /dev/null
@@ -1,823 +0,0 @@
-# Importing Data
-
-The `tf.data` API enables you to build complex input pipelines from
-simple, reusable pieces. For example, the pipeline for an image model might
-aggregate data from files in a distributed file system, apply random
-perturbations to each image, and merge randomly selected images into a batch
-for training. The pipeline for a text model might involve extracting symbols
-from raw text data, converting them to embedding identifiers with a lookup
-table, and batching together sequences of different lengths. The `tf.data` API
-makes it easy to deal with large amounts of data, different data formats, and
-complicated transformations.
-
-The `tf.data` API introduces two new abstractions to TensorFlow:
-
-* A `tf.data.Dataset` represents a sequence of elements, in which
- each element contains one or more `Tensor` objects. For example, in an image
- pipeline, an element might be a single training example, with a pair of
- tensors representing the image data and a label. There are two distinct
- ways to create a dataset:
-
- * Creating a **source** (e.g. `Dataset.from_tensor_slices()`) constructs a
- dataset from
- one or more `tf.Tensor` objects.
-
- * Applying a **transformation** (e.g. `Dataset.batch()`) constructs a dataset
- from one or more `tf.data.Dataset` objects.
-
-* A `tf.data.Iterator` provides the main way to extract elements from a
- dataset. The operation returned by `Iterator.get_next()` yields the next
- element of a `Dataset` when executed, and typically acts as the interface
- between input pipeline code and your model. The simplest iterator is a
- "one-shot iterator", which is associated with a particular `Dataset` and
- iterates through it once. For more sophisticated uses, the
- `Iterator.initializer` operation enables you to reinitialize and parameterize
- an iterator with different datasets, so that you can, for example, iterate
- over training and validation data multiple times in the same program.
-
-## Basic mechanics
-
-This section of the guide describes the fundamentals of creating different kinds
-of `Dataset` and `Iterator` objects, and how to extract data from them.
-
-To start an input pipeline, you must define a *source*. For example, to
-construct a `Dataset` from some tensors in memory, you can use
-`tf.data.Dataset.from_tensors()` or
-`tf.data.Dataset.from_tensor_slices()`. Alternatively, if your input
-data are on disk in the recommended TFRecord format, you can construct a
-`tf.data.TFRecordDataset`.
-
-Once you have a `Dataset` object, you can *transform* it into a new `Dataset` by
-chaining method calls on the `tf.data.Dataset` object. For example, you
-can apply per-element transformations such as `Dataset.map()` (to apply a
-function to each element), and multi-element transformations such as
-`Dataset.batch()`. See the documentation for `tf.data.Dataset`
-for a complete list of transformations.
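-
-For instance, a minimal sketch of chaining two transformations:
-
-```python
-dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])
-dataset = dataset.map(lambda x: x * 2).batch(3)
-# The dataset now yields [2, 4, 6] followed by [8, 10, 12].
-```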
-
-The most common way to consume values from a `Dataset` is to make an
-**iterator** object that provides access to one element of the dataset at a time
-(for example, by calling `Dataset.make_one_shot_iterator()`). A
-`tf.data.Iterator` provides two operations: `Iterator.initializer`,
-which enables you to (re)initialize the iterator's state; and
-`Iterator.get_next()`, which returns `tf.Tensor` objects that correspond to the
-symbolic next element. Depending on your use case, you might choose a different
-type of iterator, and the options are outlined below.
-
-### Dataset structure
-
-A dataset comprises elements that each have the same structure. An element
-contains one or more `tf.Tensor` objects, called *components*. Each component
-has a `tf.DType` representing the type of elements in the tensor, and a
-`tf.TensorShape` representing the (possibly partially specified) static shape of
-each element. The `Dataset.output_types` and `Dataset.output_shapes` properties
-allow you to inspect the inferred types and shapes of each component of a
-dataset element. The *nested structure* of these properties maps to the structure
-of an element, which may be a single tensor, a tuple of tensors, or a nested
-tuple of tensors. For example:
-
-```python
-dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
-print(dataset1.output_types) # ==> "tf.float32"
-print(dataset1.output_shapes) # ==> "(10,)"
-
-dataset2 = tf.data.Dataset.from_tensor_slices(
- (tf.random_uniform([4]),
- tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)))
-print(dataset2.output_types) # ==> "(tf.float32, tf.int32)"
-print(dataset2.output_shapes) # ==> "((), (100,))"
-
-dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
-print(dataset3.output_types) # ==> (tf.float32, (tf.float32, tf.int32))
-print(dataset3.output_shapes) # ==> "(10, ((), (100,)))"
-```
-
-It is often convenient to give names to each component of an element, for
-example if they represent different features of a training example. In addition
-to tuples, you can use `collections.namedtuple` or a dictionary mapping strings
-to tensors to represent a single element of a `Dataset`.
-
-```python
-dataset = tf.data.Dataset.from_tensor_slices(
- {"a": tf.random_uniform([4]),
- "b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
-print(dataset.output_types) # ==> "{'a': tf.float32, 'b': tf.int32}"
-print(dataset.output_shapes) # ==> "{'a': (), 'b': (100,)}"
-```
-
-The `Dataset` transformations support datasets of any structure. When using the
-`Dataset.map()`, `Dataset.flat_map()`, and `Dataset.filter()` transformations,
-which apply a function to each element, the element structure determines the
-arguments of the function:
-
-```python
-dataset1 = dataset1.map(lambda x: ...)
-
-dataset2 = dataset2.flat_map(lambda x, y: ...)
-
-# Note: Argument destructuring, as used here, is only available in Python 2.
-# In Python 3, accept the nested component as a single argument instead,
-# e.g. `lambda x, yz: ...`.
-dataset3 = dataset3.filter(lambda x, (y, z): ...)
-```
-
-### Creating an iterator
-
-Once you have built a `Dataset` to represent your input data, the next step is to
-create an `Iterator` to access elements from that dataset. The `tf.data` API
-currently supports the following iterators, in increasing level of
-sophistication:
-
-* **one-shot**,
-* **initializable**,
-* **reinitializable**, and
-* **feedable**.
-
-A **one-shot** iterator is the simplest form of iterator, which only supports
-iterating once through a dataset, with no need for explicit initialization.
-One-shot iterators handle almost all of the cases that the existing queue-based
-input pipelines support, but they do not support parameterization. Using the
-example of `Dataset.range()`:
-
-```python
-dataset = tf.data.Dataset.range(100)
-iterator = dataset.make_one_shot_iterator()
-next_element = iterator.get_next()
-
-for i in range(100):
- value = sess.run(next_element)
- assert i == value
-```
-
-Note: Currently, one-shot iterators are the only type that is easily usable
-with an `Estimator`.
-
-An **initializable** iterator requires you to run an explicit
-`iterator.initializer` operation before using it. In exchange for this
-inconvenience, it enables you to *parameterize* the definition of the dataset,
-using one or more `tf.placeholder()` tensors that can be fed when you
-initialize the iterator. Continuing the `Dataset.range()` example:
-
-```python
-max_value = tf.placeholder(tf.int64, shape=[])
-dataset = tf.data.Dataset.range(max_value)
-iterator = dataset.make_initializable_iterator()
-next_element = iterator.get_next()
-
-# Initialize an iterator over a dataset with 10 elements.
-sess.run(iterator.initializer, feed_dict={max_value: 10})
-for i in range(10):
- value = sess.run(next_element)
- assert i == value
-
-# Initialize the same iterator over a dataset with 100 elements.
-sess.run(iterator.initializer, feed_dict={max_value: 100})
-for i in range(100):
- value = sess.run(next_element)
- assert i == value
-```
-
-A **reinitializable** iterator can be initialized from multiple different
-`Dataset` objects. For example, you might have a training input pipeline that
-uses random perturbations to the input images to improve generalization, and
-a validation input pipeline that evaluates predictions on unmodified data. These
-pipelines will typically use different `Dataset` objects that have the same
-structure (i.e. the same types and compatible shapes for each component).
-
-```python
-# Define training and validation datasets with the same structure.
-training_dataset = tf.data.Dataset.range(100).map(
- lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
-validation_dataset = tf.data.Dataset.range(50)
-
-# A reinitializable iterator is defined by its structure. We could use the
-# `output_types` and `output_shapes` properties of either `training_dataset`
-# or `validation_dataset` here, because they are compatible.
-iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
- training_dataset.output_shapes)
-next_element = iterator.get_next()
-
-training_init_op = iterator.make_initializer(training_dataset)
-validation_init_op = iterator.make_initializer(validation_dataset)
-
-# Run 20 epochs in which the training dataset is traversed, followed by the
-# validation dataset.
-for _ in range(20):
- # Initialize an iterator over the training dataset.
- sess.run(training_init_op)
- for _ in range(100):
- sess.run(next_element)
-
- # Initialize an iterator over the validation dataset.
- sess.run(validation_init_op)
- for _ in range(50):
- sess.run(next_element)
-```
-
-A **feedable** iterator can be used together with `tf.placeholder` to select
-what `Iterator` to use in each call to `tf.Session.run`, via the familiar
-`feed_dict` mechanism. It offers the same functionality as a reinitializable
-iterator, but it does not require you to initialize the iterator from the start
-of a dataset when you switch between iterators. For example, using the same
-training and validation example from above, you can use
-`tf.data.Iterator.from_string_handle` to define a feedable iterator
-that allows you to switch between the two datasets:
-
-```python
-# Define training and validation datasets with the same structure.
-training_dataset = tf.data.Dataset.range(100).map(
- lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
-validation_dataset = tf.data.Dataset.range(50)
-
-# A feedable iterator is defined by a handle placeholder and its structure. We
-# could use the `output_types` and `output_shapes` properties of either
-# `training_dataset` or `validation_dataset` here, because they have
-# identical structure.
-handle = tf.placeholder(tf.string, shape=[])
-iterator = tf.data.Iterator.from_string_handle(
- handle, training_dataset.output_types, training_dataset.output_shapes)
-next_element = iterator.get_next()
-
-# You can use feedable iterators with a variety of different kinds of iterator
-# (such as one-shot and initializable iterators).
-training_iterator = training_dataset.make_one_shot_iterator()
-validation_iterator = validation_dataset.make_initializable_iterator()
-
-# The `Iterator.string_handle()` method returns a tensor that can be evaluated
-# and used to feed the `handle` placeholder.
-training_handle = sess.run(training_iterator.string_handle())
-validation_handle = sess.run(validation_iterator.string_handle())
-
-# Loop forever, alternating between training and validation.
-while True:
- # Run 200 steps using the training dataset. Note that the training dataset is
- # infinite, and we resume from where we left off in the previous `while` loop
- # iteration.
- for _ in range(200):
- sess.run(next_element, feed_dict={handle: training_handle})
-
- # Run one pass over the validation dataset.
- sess.run(validation_iterator.initializer)
- for _ in range(50):
- sess.run(next_element, feed_dict={handle: validation_handle})
-```
-
-### Consuming values from an iterator
-
-The `Iterator.get_next()` method returns one or more `tf.Tensor` objects that
-correspond to the symbolic next element of an iterator. Each time these tensors
-are evaluated, they take the value of the next element in the underlying
-dataset. (Note that, like other stateful objects in TensorFlow, calling
-`Iterator.get_next()` does not immediately advance the iterator. Instead you
-must use the returned `tf.Tensor` objects in a TensorFlow expression, and pass
-the result of that expression to `tf.Session.run()` to get the next elements and
-advance the iterator.)
-
-If the iterator reaches the end of the dataset, executing
-the `Iterator.get_next()` operation will raise a `tf.errors.OutOfRangeError`.
-After this point the iterator will be in an unusable state, and you must
-initialize it again if you want to use it further.
-
-```python
-dataset = tf.data.Dataset.range(5)
-iterator = dataset.make_initializable_iterator()
-next_element = iterator.get_next()
-
-# Typically `result` will be the output of a model, or an optimizer's
-# training operation.
-result = tf.add(next_element, next_element)
-
-sess.run(iterator.initializer)
-print(sess.run(result)) # ==> "0"
-print(sess.run(result)) # ==> "2"
-print(sess.run(result)) # ==> "4"
-print(sess.run(result)) # ==> "6"
-print(sess.run(result)) # ==> "8"
-try:
- sess.run(result)
-except tf.errors.OutOfRangeError:
- print("End of dataset") # ==> "End of dataset"
-```
-
-A common pattern is to wrap the "training loop" in a `try`-`except` block:
-
-```python
-sess.run(iterator.initializer)
-while True:
- try:
- sess.run(result)
- except tf.errors.OutOfRangeError:
- break
-```
-
-If each element of the dataset has a nested structure, the return value of
-`Iterator.get_next()` will be one or more `tf.Tensor` objects in the same
-nested structure:
-
-```python
-dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
-dataset2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]), tf.random_uniform([4, 100])))
-dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
-
-iterator = dataset3.make_initializable_iterator()
-
-sess.run(iterator.initializer)
-next1, (next2, next3) = iterator.get_next()
-```
-
-Note that `next1`, `next2`, and `next3` are tensors produced by the
-same op/node (created by `Iterator.get_next()`). Therefore, evaluating *any* of
-these tensors will advance the iterator for all components. A typical consumer
-of an iterator will include all components in a single expression.
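-
-For example, fetching the whole nested structure in one `sess.run` call
-consumes exactly one element:
-
-```python
-value1, (value2, value3) = sess.run((next1, (next2, next3)))
-```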
-
-### Saving iterator state
-
-The `tf.contrib.data.make_saveable_from_iterator` function creates a
-`SaveableObject` from an iterator, which can be used to save and
-restore the current state of the iterator (and, effectively, the whole input
-pipeline). A saveable object thus created can be added to `tf.train.Saver`
-variables list or the `tf.GraphKeys.SAVEABLE_OBJECTS` collection for saving and
-restoring in the same manner as a `tf.Variable`. Refer to
-@{$saved_model$Saving and Restoring} for details on how to save and restore
-variables.
-
-```python
-# Create saveable object from iterator.
-saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
-
-# Save the iterator state by adding it to the saveable objects collection.
-tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
-saver = tf.train.Saver()
-
-with tf.Session() as sess:
-
-  if should_checkpoint:
-    saver.save(sess, path_to_checkpoint)
-
-# Restore the iterator state.
-with tf.Session() as sess:
- saver.restore(sess, path_to_checkpoint)
-```
-
-## Reading input data
-
-### Consuming NumPy arrays
-
-If all of your input data fit in memory, the simplest way to create a `Dataset`
-from them is to convert them to `tf.Tensor` objects and use
-`Dataset.from_tensor_slices()`.
-
-```python
-# Load the training data into two NumPy arrays, for example using `np.load()`.
-with np.load("/var/data/training_data.npy") as data:
- features = data["features"]
- labels = data["labels"]
-
-# Assume that each row of `features` corresponds to the same row as `labels`.
-assert features.shape[0] == labels.shape[0]
-
-dataset = tf.data.Dataset.from_tensor_slices((features, labels))
-```
-
-Note that the above code snippet will embed the `features` and `labels` arrays
-in your TensorFlow graph as `tf.constant()` operations. This works well for a
-small dataset, but wastes memory---because the contents of the array will be
-copied multiple times---and can run into the 2GB limit for the `tf.GraphDef`
-protocol buffer.
-
-As an alternative, you can define the `Dataset` in terms of `tf.placeholder()`
-tensors, and *feed* the NumPy arrays when you initialize an `Iterator` over the
-dataset.
-
-```python
-# Load the training data into two NumPy arrays, for example using `np.load()`.
-with np.load("/var/data/training_data.npy") as data:
- features = data["features"]
- labels = data["labels"]
-
-# Assume that each row of `features` corresponds to the same row as `labels`.
-assert features.shape[0] == labels.shape[0]
-
-features_placeholder = tf.placeholder(features.dtype, features.shape)
-labels_placeholder = tf.placeholder(labels.dtype, labels.shape)
-
-dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
-# [Other transformations on `dataset`...]
-dataset = ...
-iterator = dataset.make_initializable_iterator()
-
-sess.run(iterator.initializer, feed_dict={features_placeholder: features,
- labels_placeholder: labels})
-```
-
-### Consuming TFRecord data
-
-The `tf.data` API supports a variety of file formats so that you can process
-large datasets that do not fit in memory. For example, the TFRecord file format
-is a simple record-oriented binary format that many TensorFlow applications use
-for training data. The `tf.data.TFRecordDataset` class enables you to
-stream over the contents of one or more TFRecord files as part of an input
-pipeline.
-
-```python
-# Creates a dataset that reads all of the examples from two files.
-filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
-dataset = tf.data.TFRecordDataset(filenames)
-```
-
-The `filenames` argument to the `TFRecordDataset` initializer can either be a
-string, a list of strings, or a `tf.Tensor` of strings. Therefore if you have
-two sets of files for training and validation purposes, you can use a
-`tf.placeholder(tf.string)` to represent the filenames, and initialize an
-iterator from the appropriate filenames:
-
-```python
-filenames = tf.placeholder(tf.string, shape=[None])
-dataset = tf.data.TFRecordDataset(filenames)
-dataset = dataset.map(...) # Parse the record into tensors.
-dataset = dataset.repeat() # Repeat the input indefinitely.
-dataset = dataset.batch(32)
-iterator = dataset.make_initializable_iterator()
-
-# You can feed the initializer with the appropriate filenames for the current
-# phase of execution, e.g. training vs. validation.
-
-# Initialize `iterator` with training data.
-training_filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
-sess.run(iterator.initializer, feed_dict={filenames: training_filenames})
-
-# Initialize `iterator` with validation data.
-validation_filenames = ["/var/data/validation1.tfrecord", ...]
-sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})
-```
-
-### Consuming text data
-
-Many datasets are distributed as one or more text files. The
-`tf.data.TextLineDataset` provides an easy way to extract lines from
-one or more text files. Given one or more filenames, a `TextLineDataset` will
-produce one string-valued element per line of those files. Like a
-`TFRecordDataset`, `TextLineDataset` accepts `filenames` as a `tf.Tensor`, so
-you can parameterize it by passing a `tf.placeholder(tf.string)`.
-
-```python
-filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
-dataset = tf.data.TextLineDataset(filenames)
-```
-
-By default, a `TextLineDataset` yields *every* line of each file, which may
-not be desirable, for example if the file starts with a header line, or contains
-comments. These lines can be removed using the `Dataset.skip()` and
-`Dataset.filter()` transformations. To apply these transformations to each
-file separately, we use `Dataset.flat_map()` to create a nested `Dataset` for
-each file.
-
-```python
-filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
-
-dataset = tf.data.Dataset.from_tensor_slices(filenames)
-
-# Use `Dataset.flat_map()` to transform each file as a separate nested dataset,
-# and then concatenate their contents sequentially into a single "flat" dataset.
-# * Skip the first line (header row).
-# * Filter out lines beginning with "#" (comments).
-dataset = dataset.flat_map(
- lambda filename: (
- tf.data.TextLineDataset(filename)
- .skip(1)
- .filter(lambda line: tf.not_equal(tf.substr(line, 0, 1), "#"))))
-```
-
-### Consuming CSV data
-
-The CSV file format is a popular format for storing tabular data in plain text.
-The `tf.contrib.data.CsvDataset` class provides a way to extract records from
-one or more CSV files that comply with [RFC 4180](https://tools.ietf.org/html/rfc4180).
-Given one or more filenames and a list of defaults, a `CsvDataset` will produce
-a tuple of elements whose types correspond to the types of the defaults
-provided, per CSV record. Like `TFRecordDataset` and `TextLineDataset`,
-`CsvDataset` accepts `filenames` as a `tf.Tensor`, so you can parameterize it
-by passing a `tf.placeholder(tf.string)`.
-
-```python
-# Creates a dataset that reads all of the records from two CSV files, each with
-# eight float columns
-filenames = ["/var/data/file1.csv", "/var/data/file2.csv"]
-record_defaults = [tf.float32] * 8 # Eight required float columns
-dataset = tf.contrib.data.CsvDataset(filenames, record_defaults)
-```
-
-If some columns are empty, you can provide defaults instead of types.
-
-```python
-# Creates a dataset that reads all of the records from two CSV files, each with
-# eight float columns, which may have missing values
-record_defaults = [[0.0]] * 8
-dataset = tf.contrib.data.CsvDataset(filenames, record_defaults)
-```
-
-By default, a `CsvDataset` yields *every* column of *every* line of the file,
-which may not be desirable, for example if the file starts with a header line
-that should be ignored, or if some columns are not required in the input.
-These lines and fields can be removed with the `header` and `select_cols`
-arguments respectively.
-
-```python
-# Creates a dataset that reads all of the records from two CSV files with
-# headers, extracting float data from columns 2 and 4.
-record_defaults = [[0.0]] * 2 # Only provide defaults for the selected columns
-dataset = tf.contrib.data.CsvDataset(filenames, record_defaults, header=True, select_cols=[2,4])
-```
-<!--
-TODO(mrry): Add these sections.
-
-### Consuming from a Python generator
--->
-
-## Preprocessing data with `Dataset.map()`
-
-The `Dataset.map(f)` transformation produces a new dataset by applying a given
-function `f` to each element of the input dataset. It is based on
-the
-[`map()` function](https://en.wikipedia.org/wiki/Map_(higher-order_function))
-that is commonly applied to lists (and other structures) in functional
-programming languages. The function `f` takes the `tf.Tensor` objects that
-represent a single element in the input, and returns the `tf.Tensor` objects
-that will represent a single element in the new dataset. Its implementation uses
-standard TensorFlow operations to transform one element into another.
-
-This section covers common examples of how to use `Dataset.map()`.
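-
-As a minimal illustration before the more realistic examples below:
-
-```python
-dataset = tf.data.Dataset.range(5)
-dataset = dataset.map(lambda x: x * x)  # yields 0, 1, 4, 9, 16
-```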
-
-### Parsing `tf.Example` protocol buffer messages
-
-Many input pipelines extract `tf.train.Example` protocol buffer messages from a
-TFRecord-format file (written, for example, using
-`tf.python_io.TFRecordWriter`). Each `tf.train.Example` record contains one or
-more "features", and the input pipeline typically converts these features into
-tensors.
-
-```python
-# Transforms a scalar string `example_proto` into a pair of a scalar string and
-# a scalar integer, representing an image and its label, respectively.
-def _parse_function(example_proto):
- features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
-              "label": tf.FixedLenFeature((), tf.int64, default_value=0)}
- parsed_features = tf.parse_single_example(example_proto, features)
- return parsed_features["image"], parsed_features["label"]
-
-# Creates a dataset that reads all of the examples from two files, and extracts
-# the image and label features.
-filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
-dataset = tf.data.TFRecordDataset(filenames)
-dataset = dataset.map(_parse_function)
-```
-
-### Decoding image data and resizing it
-
-When training a neural network on real-world image data, it is often necessary
-to convert images of different sizes to a common size, so that they may be
-batched into a fixed size.
-
-```python
-# Reads an image from a file, decodes it into a dense tensor, and resizes it
-# to a fixed shape.
-def _parse_function(filename, label):
- image_string = tf.read_file(filename)
- image_decoded = tf.image.decode_jpeg(image_string)
- image_resized = tf.image.resize_images(image_decoded, [28, 28])
- return image_resized, label
-
-# A vector of filenames.
-filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])
-
-# `labels[i]` is the label for the image in `filenames[i]`.
-labels = tf.constant([0, 37, ...])
-
-dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
-dataset = dataset.map(_parse_function)
-```
-
-### Applying arbitrary Python logic with `tf.py_func()`
-
-For performance reasons, we encourage you to use TensorFlow operations for
-preprocessing your data whenever possible. However, it is sometimes useful to
-call upon external Python libraries when parsing your input data. To do so,
-invoke the `tf.py_func()` operation in a `Dataset.map()` transformation.
-
-```python
-import cv2
-
-# Use a custom OpenCV function to read the image, instead of the standard
-# TensorFlow `tf.read_file()` operation.
-def _read_py_function(filename, label):
- image_decoded = cv2.imread(filename.decode(), cv2.IMREAD_GRAYSCALE)
- return image_decoded, label
-
-# Use standard TensorFlow operations to resize the image to a fixed shape.
-def _resize_function(image_decoded, label):
- image_decoded.set_shape([None, None, None])
- image_resized = tf.image.resize_images(image_decoded, [28, 28])
- return image_resized, label
-
-filenames = ["/var/data/image1.jpg", "/var/data/image2.jpg", ...]
-labels = [0, 37, 29, 1, ...]
-
-dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
-dataset = dataset.map(
- lambda filename, label: tuple(tf.py_func(
- _read_py_function, [filename, label], [tf.uint8, label.dtype])))
-dataset = dataset.map(_resize_function)
-```
-
-<!--
-TODO(mrry): Add this section.
-
-### Handling text data with unusual sizes
--->
-
-## Batching dataset elements
-
-### Simple batching
-
-The simplest form of batching stacks `n` consecutive elements of a dataset into
-a single element. The `Dataset.batch()` transformation does exactly this, with
-the same constraints as the `tf.stack()` operator, applied to each component
-of the elements: i.e. for each component *i*, all elements must have a tensor
-of the exact same shape.
-
-```python
-inc_dataset = tf.data.Dataset.range(100)
-dec_dataset = tf.data.Dataset.range(0, -100, -1)
-dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
-batched_dataset = dataset.batch(4)
-
-iterator = batched_dataset.make_one_shot_iterator()
-next_element = iterator.get_next()
-
-print(sess.run(next_element)) # ==> ([0, 1, 2, 3], [ 0, -1, -2, -3])
-print(sess.run(next_element)) # ==> ([4, 5, 6, 7], [-4, -5, -6, -7])
-print(sess.run(next_element)) # ==> ([8, 9, 10, 11], [-8, -9, -10, -11])
-```
-
-### Batching tensors with padding
-
-The above recipe works for tensors that all have the same size. However, many
-models (e.g. sequence models) work with input data that can have varying size
-(e.g. sequences of different lengths). To handle this case, the
-`Dataset.padded_batch()` transformation enables you to batch tensors of
-different shape by specifying one or more dimensions in which they may be
-padded.
-
-```python
-dataset = tf.data.Dataset.range(100)
-dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
-dataset = dataset.padded_batch(4, padded_shapes=[None])
-
-iterator = dataset.make_one_shot_iterator()
-next_element = iterator.get_next()
-
-print(sess.run(next_element)) # ==> [[0, 0, 0], [1, 0, 0], [2, 2, 0], [3, 3, 3]]
-print(sess.run(next_element)) # ==> [[4, 4, 4, 4, 0, 0, 0],
- # [5, 5, 5, 5, 5, 0, 0],
- # [6, 6, 6, 6, 6, 6, 0],
- # [7, 7, 7, 7, 7, 7, 7]]
-```
-
-The `Dataset.padded_batch()` transformation allows you to set different padding
-for each dimension of each component, and it may be variable-length (signified
-by `None` in the example above) or constant-length. It is also possible to
-override the padding value, which defaults to 0.
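-
-For example, a sketch that pads with -1 instead, using the `padding_values`
-argument (the padding value must match the component dtype, `tf.int64` here):
-
-```python
-dataset = tf.data.Dataset.range(100)
-dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
-dataset = dataset.padded_batch(4, padded_shapes=[None],
-                               padding_values=tf.constant(-1, dtype=tf.int64))
-```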
-
-<!--
-TODO(mrry): Add this section.
-
-### Dense ragged -> tf.SparseTensor
--->
-
-## Training workflows
-
-### Processing multiple epochs
-
-The `tf.data` API offers two main ways to process multiple epochs of the same
-data.
-
-The simplest way to iterate over a dataset in multiple epochs is to use the
-`Dataset.repeat()` transformation. For example, to create a dataset that repeats
-its input for 10 epochs:
-
-```python
-filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
-dataset = tf.data.TFRecordDataset(filenames)
-dataset = dataset.map(...)
-dataset = dataset.repeat(10)
-dataset = dataset.batch(32)
-```
-
-Applying the `Dataset.repeat()` transformation with no arguments repeats the
-input indefinitely. The transformation simply concatenates successive epochs,
-without signaling the end of one epoch and the beginning of the next.
-
-If you want to receive a signal at the end of each epoch, you can write a
-training loop that catches the `tf.errors.OutOfRangeError` at the end of a
-dataset. At that point you might collect some statistics (e.g. the validation
-error) for the epoch.
-
-```python
-filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
-dataset = tf.data.TFRecordDataset(filenames)
-dataset = dataset.map(...)
-dataset = dataset.batch(32)
-iterator = dataset.make_initializable_iterator()
-next_element = iterator.get_next()
-
-# Compute for 100 epochs.
-for _ in range(100):
- sess.run(iterator.initializer)
- while True:
- try:
- sess.run(next_element)
- except tf.errors.OutOfRangeError:
- break
-
- # [Perform end-of-epoch calculations here.]
-```
-
-### Randomly shuffling input data
-
-The `Dataset.shuffle()` transformation randomly shuffles the input dataset
-using a similar algorithm to `tf.RandomShuffleQueue`: it maintains a fixed-size
-buffer and chooses the next element uniformly at random from that buffer.
-
-```python
-filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
-dataset = tf.data.TFRecordDataset(filenames)
-dataset = dataset.map(...)
-dataset = dataset.shuffle(buffer_size=10000)
-dataset = dataset.batch(32)
-dataset = dataset.repeat()
-```
-
-### Using high-level APIs
-
-The `tf.train.MonitoredTrainingSession` API simplifies many aspects of running
-TensorFlow in a distributed setting. `MonitoredTrainingSession` uses the
-`tf.errors.OutOfRangeError` to signal that training has completed, so to use it
-with the `tf.data` API, we recommend using
-`Dataset.make_one_shot_iterator()`. For example:
-
-```python
-filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
-dataset = tf.data.TFRecordDataset(filenames)
-dataset = dataset.map(...)
-dataset = dataset.shuffle(buffer_size=10000)
-dataset = dataset.batch(32)
-dataset = dataset.repeat(num_epochs)
-iterator = dataset.make_one_shot_iterator()
-
-next_example, next_label = iterator.get_next()
-loss = model_function(next_example, next_label)
-
-training_op = tf.train.AdagradOptimizer(...).minimize(loss)
-
-with tf.train.MonitoredTrainingSession(...) as sess:
- while not sess.should_stop():
- sess.run(training_op)
-```
-
-To use a `Dataset` in the `input_fn` of a `tf.estimator.Estimator`, we also
-recommend using `Dataset.make_one_shot_iterator()`. For example:
-
-```python
-def dataset_input_fn():
- filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
- dataset = tf.data.TFRecordDataset(filenames)
-
- # Use `tf.parse_single_example()` to extract data from a `tf.Example`
- # protocol buffer, and perform any additional per-record preprocessing.
- def parser(record):
- keys_to_features = {
- "image_data": tf.FixedLenFeature((), tf.string, default_value=""),
-        "date_time": tf.FixedLenFeature((), tf.int64, default_value=0),
- "label": tf.FixedLenFeature((), tf.int64,
- default_value=tf.zeros([], dtype=tf.int64)),
- }
- parsed = tf.parse_single_example(record, keys_to_features)
-
- # Perform additional preprocessing on the parsed data.
- image = tf.image.decode_jpeg(parsed["image_data"])
- image = tf.reshape(image, [299, 299, 1])
- label = tf.cast(parsed["label"], tf.int32)
-
- return {"image_data": image, "date_time": parsed["date_time"]}, label
-
- # Use `Dataset.map()` to build a pair of a feature dictionary and a label
- # tensor for each example.
- dataset = dataset.map(parser)
- dataset = dataset.shuffle(buffer_size=10000)
- dataset = dataset.batch(32)
- dataset = dataset.repeat(num_epochs)
- iterator = dataset.make_one_shot_iterator()
-
- # `features` is a dictionary in which each value is a batch of values for
- # that feature; `labels` is a batch of labels.
- features, labels = iterator.get_next()
- return features, labels
-```
diff --git a/tensorflow/docs_src/guide/datasets_for_estimators.md b/tensorflow/docs_src/guide/datasets_for_estimators.md
deleted file mode 100644
index 969ea579f7..0000000000
--- a/tensorflow/docs_src/guide/datasets_for_estimators.md
+++ /dev/null
@@ -1,387 +0,0 @@
-# Datasets for Estimators
-
-The `tf.data` module contains a collection of classes that allows you to
-easily load data, manipulate it, and pipe it into your model. This document
-introduces the API by walking through two simple examples:
-
-* Reading in-memory data from numpy arrays.
-* Reading lines from a csv file.
-
-<!-- TODO(markdaoust): Add links to an example reading from multiple-files
-(image_retraining), and a from_generator example. -->
-
-## Basic input
-
-Taking slices from an array is the simplest way to get started with `tf.data`.
-
-The @{$premade_estimators$Premade Estimators} chapter describes
-the following `train_input_fn`, from
-[`iris_data.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py),
-to pipe the data into the Estimator:
-
-``` python
-def train_input_fn(features, labels, batch_size):
- """An input function for training"""
- # Convert the inputs to a Dataset.
- dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
-
- # Shuffle, repeat, and batch the examples.
- dataset = dataset.shuffle(1000).repeat().batch(batch_size)
-
- # Return the dataset.
- return dataset
-```
-
-Let's look at this more closely.
-
-### Arguments
-
-This function expects three arguments. Arguments expecting an "array" can
-accept nearly anything that can be converted to an array with `numpy.array`.
-One exception is
-[`tuple`](https://docs.python.org/3/tutorial/datastructures.html#tuples-and-sequences)
-which, as we will see, has special meaning for `Datasets`.
-
-* `features`: A `{'feature_name':array}` dictionary (or
- [`DataFrame`](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html))
- containing the raw input features.
-* `labels` : An array containing the
- [label](https://developers.google.com/machine-learning/glossary/#label)
- for each example.
-* `batch_size` : An integer indicating the desired batch size.
-
-In [`premade_estimator.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/premade_estimator.py)
-we retrieved the Iris data using the `iris_data.load_data()` function.
-You can run it, and unpack the results as follows:
-
-``` python
-import iris_data
-
-# Fetch the data
-train, test = iris_data.load_data()
-features, labels = train
-```
-
-Then we passed this data to the input function, with a line similar to this:
-
-``` python
-batch_size = 100
-iris_data.train_input_fn(features, labels, batch_size)
-```
-
-Let's walk through the `train_input_fn()`.
-
-### Slices
-
-The function starts by using the `tf.data.Dataset.from_tensor_slices` function
-to create a `tf.data.Dataset` representing slices of the array. The array is
-sliced across the first dimension. For example, an array containing the
-MNIST training data has a shape of `(60000, 28, 28)`. Passing this to
-`from_tensor_slices` returns a `Dataset` object containing 60000 slices, each one
-a 28x28 image.
-
-The code that returns this `Dataset` is as follows:
-
-``` python
-train, test = tf.keras.datasets.mnist.load_data()
-mnist_x, mnist_y = train
-
-mnist_ds = tf.data.Dataset.from_tensor_slices(mnist_x)
-print(mnist_ds)
-```
-
-This will print the following line, showing the
-@{$guide/tensors#shapes$shapes} and
-@{$guide/tensors#data_types$types} of the items in
-the dataset. Note that a `Dataset` does not know how many items it contains.
-
-``` None
-<TensorSliceDataset shapes: (28,28), types: tf.uint8>
-```
-
-The `Dataset` above represents a simple collection of arrays, but datasets are
-much more powerful than this. A `Dataset` can transparently handle any nested
-combination of dictionaries or tuples (or
-[`namedtuple`](https://docs.python.org/2/library/collections.html#collections.namedtuple)
-).
-
-For example, after converting the iris `features`
-to a standard Python dictionary, you can convert the dictionary of arrays
-to a `Dataset` of dictionaries as follows:
-
-``` python
-dataset = tf.data.Dataset.from_tensor_slices(dict(features))
-print(dataset)
-```
-``` None
-<TensorSliceDataset
-
- shapes: {
- SepalLength: (), PetalWidth: (),
- PetalLength: (), SepalWidth: ()},
-
- types: {
- SepalLength: tf.float64, PetalWidth: tf.float64,
- PetalLength: tf.float64, SepalWidth: tf.float64}
->
-```
-
-Here we see that when a `Dataset` contains structured elements, the `shapes`
-and `types` of the `Dataset` take on the same structure. This dataset contains
-dictionaries of @{$guide/tensors#rank$scalars}, all of type
-`tf.float64`.
-
-The first line of the iris `train_input_fn` uses the same functionality, but
-adds another level of structure. It creates a dataset containing
-`(features_dict, label)` pairs.
-
-The following code shows that the label is a scalar with type `int64`:
-
-``` python
-# Convert the inputs to a Dataset.
-dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
-print(dataset)
-```
-```
-<TensorSliceDataset
- shapes: (
- {
- SepalLength: (), PetalWidth: (),
- PetalLength: (), SepalWidth: ()},
- ()),
-
- types: (
- {
- SepalLength: tf.float64, PetalWidth: tf.float64,
- PetalLength: tf.float64, SepalWidth: tf.float64},
- tf.int64)>
-```
-
-### Manipulation
-
-Currently the `Dataset` would iterate over the data once, in a fixed order, and
-only produce a single element at a time. It needs further processing before it
-can be used for training. Fortunately, the `tf.data.Dataset` class provides
-methods to better prepare the data for training. The next line of the input
-function takes advantage of several of these methods:
-
-``` python
-# Shuffle, repeat, and batch the examples.
-dataset = dataset.shuffle(1000).repeat().batch(batch_size)
-```
-
-The `tf.data.Dataset.shuffle` method uses a fixed-size buffer to
-shuffle the items as they pass through. In this case the `buffer_size` is
-greater than the number of examples in the `Dataset`, ensuring that the data is
-completely shuffled (the Iris data set contains only 150 examples).
-
-The `tf.data.Dataset.repeat` method restarts the `Dataset` when
-it reaches the end. To limit the number of epochs, set the `count` argument.
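-
-For example, a minimal sketch that stops after 5 epochs:
-
-``` python
-# Produce the data for at most 5 epochs, then raise OutOfRangeError.
-dataset = dataset.repeat(count=5)
-```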
-
-The `tf.data.Dataset.batch` method collects a number of examples and
-stacks them, to create batches. This adds a dimension to their shape. The new
-dimension is added as the first dimension. The following code uses
-the `batch` method on the MNIST `Dataset`, from earlier. This results in a
-`Dataset` containing 3D arrays representing stacks of `(28,28)` images:
-
-``` python
-print(mnist_ds.batch(100))
-```
-
-``` none
-<BatchDataset
- shapes: (?, 28, 28),
- types: tf.uint8>
-```
-Note that the dataset has an unknown batch size because the last batch may
-have fewer elements.
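-
-If your model requires a static batch size, recent versions of TensorFlow also
-accept a `drop_remainder` argument that discards the final, smaller batch (a
-hedged sketch; the printed shape assumes this behavior):
-
-``` python
-print(mnist_ds.batch(100, drop_remainder=True))
-# ==> <BatchDataset shapes: (100, 28, 28), types: tf.uint8>
-```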
-
-In `train_input_fn`, after batching, the `Dataset` contains 1D vectors where
-each element was previously a scalar:
-
-```python
-print(dataset)
-```
-```
-<BatchDataset
- shapes: (
- {
- SepalLength: (?,), PetalWidth: (?,),
- PetalLength: (?,), SepalWidth: (?,)},
- (?,)),
-
- types: (
- {
- SepalLength: tf.float64, PetalWidth: tf.float64,
- PetalLength: tf.float64, SepalWidth: tf.float64},
- tf.int64)>
-```
-
-
-### Return
-
-At this point the `Dataset` contains `(features_dict, labels)` pairs.
-This is the format expected by the `train` and `evaluate` methods, so the
-`input_fn` returns the dataset.
-
-The `labels` can/should be omitted when using the `predict` method.
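-
-For illustration, a minimal sketch of an input function for `predict` (the
-name `predict_input_fn` is hypothetical), reusing the iris `features`
-dictionary from above:
-
-``` python
-def predict_input_fn(features, batch_size):
-    # No labels at prediction time: return a dataset of features only.
-    dataset = tf.data.Dataset.from_tensor_slices(dict(features))
-    return dataset.batch(batch_size)
-```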
-
-<!--
- TODO(markdaoust): link to `input_fn` doc when it exists
--->
-
-
-## Reading a CSV File
-
-The most common real-world use case for the `Dataset` class is to stream data
-from files on disk. The `tf.data` module includes a variety of
-file readers. Let's see how parsing the Iris dataset from the csv file looks
-using a `Dataset`.
-
-The following call to the `iris_data.maybe_download` function downloads the
-data if necessary, and returns the pathnames of the resulting files:
-
-``` python
-import iris_data
-train_path, test_path = iris_data.maybe_download()
-```
-
-The [`iris_data.csv_input_fn`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py)
-function contains an alternative implementation that parses the csv files using
-a `Dataset`.
-
-Let's look at how to build an Estimator-compatible input function that reads
-from the local files.
-
-### Build the `Dataset`
-
-We start by building a `tf.data.TextLineDataset` object to
-read the file one line at a time. Then, we call the
-`tf.data.Dataset.skip` method to skip over the first line of the file, which contains a header, not an example:
-
-``` python
-ds = tf.data.TextLineDataset(train_path).skip(1)
-```
-
-### Build a csv line parser
-
-We start by building a function to parse a single line.
-
-We must parse each of the lines in the dataset in order to generate the
-necessary `(features, label)` pairs. The following `_parse_line` function
-(from `iris_data.py`) calls `tf.decode_csv` to parse a single line into its
-features and the label. Since Estimators require that features be represented
-as a dictionary, we rely on Python's built-in `dict` and `zip` functions to
-build that dictionary. The feature names are the keys of that dictionary.
-We then call the dictionary's `pop` method to remove the label field from
-the features dictionary:
-
-``` python
-# Metadata describing the text columns
-COLUMNS = ['SepalLength', 'SepalWidth',
- 'PetalLength', 'PetalWidth',
- 'label']
-FIELD_DEFAULTS = [[0.0], [0.0], [0.0], [0.0], [0]]
-def _parse_line(line):
- # Decode the line into its fields
- fields = tf.decode_csv(line, FIELD_DEFAULTS)
-
- # Pack the result into a dictionary
-    features = dict(zip(COLUMNS, fields))
-
- # Separate the label from the features
- label = features.pop('label')
-
- return features, label
-```
-
-### Parse the lines
-
-Datasets have many methods for manipulating the data while it is being piped
-to a model. The most heavily-used method is `tf.data.Dataset.map`, which
-applies a transformation to each element of the `Dataset`.
-
-The `map` method takes a `map_func` argument that describes how each item in the
-`Dataset` should be transformed.
-
-<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="../images/datasets/map.png">
-</div>
-<div style="text-align: center">
-The `tf.data.Dataset.map` method applies the `map_func` to
-transform each item in the <code>Dataset</code>.
-</div>
-
-So to parse the lines as they are streamed out of the csv file, we pass our
-`_parse_line` function to the `map` method:
-
-``` python
-ds = ds.map(_parse_line)
-print(ds)
-```
-``` None
-<MapDataset
-shapes: (
- {SepalLength: (), PetalWidth: (), ...},
- ()),
-types: (
- {SepalLength: tf.float32, PetalWidth: tf.float32, ...},
- tf.int32)>
-```
-
-Now instead of simple scalar strings, the dataset contains `(features, label)`
-pairs.
-
-The remainder of the `iris_data.csv_input_fn` function is identical to
-`iris_data.train_input_fn`, which was covered in the
-[Basic input](#basic_input) section.
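-
-For reference, here is a sketch of how `csv_input_fn` might assemble these
-pieces (see the linked `iris_data.py` for the actual implementation):
-
-``` python
-def csv_input_fn(csv_path, batch_size):
-    # Create a dataset containing the lines of the file, skipping the header.
-    dataset = tf.data.TextLineDataset(csv_path).skip(1)
-
-    # Parse each line into a (features, label) pair.
-    dataset = dataset.map(_parse_line)
-
-    # Shuffle, repeat, and batch the examples.
-    dataset = dataset.shuffle(1000).repeat().batch(batch_size)
-    return dataset
-```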
-
-### Try it out
-
-This function can be used as a replacement for
-`iris_data.train_input_fn`, feeding an estimator as follows:
-
-``` python
-train_path, test_path = iris_data.maybe_download()
-
-# All the inputs are numeric
-feature_columns = [
- tf.feature_column.numeric_column(name)
- for name in iris_data.CSV_COLUMN_NAMES[:-1]]
-
-# Build the estimator
-est = tf.estimator.LinearClassifier(feature_columns,
- n_classes=3)
-# Train the estimator
-batch_size = 100
-est.train(
- steps=1000,
-    input_fn=lambda: iris_data.csv_input_fn(train_path, batch_size))
-```
-
-Estimators expect an `input_fn` to take no arguments. To work around this
-restriction, we use `lambda` to capture the arguments and provide the expected
-interface.
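-
-An equivalent approach, if you prefer not to use `lambda`, is
-`functools.partial` from the Python standard library:
-
-``` python
-import functools
-
-# Bind the arguments up front; the resulting callable takes no arguments.
-est.train(
-    steps=1000,
-    input_fn=functools.partial(iris_data.csv_input_fn, train_path, batch_size))
-```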
-
-## Summary
-
-The `tf.data` module provides a collection of classes and functions for easily
-reading data from a variety of sources. Furthermore, `tf.data` has simple,
-powerful methods for applying a wide variety of standard and custom
-transformations.
-
-Now you have the basic idea of how to efficiently load data into an
-Estimator. Consider the following documents next:
-
-
-* @{$custom_estimators}, which demonstrates how to build your own
- custom `Estimator` model.
-* The @{$low_level_intro#datasets$Low Level Introduction}, which demonstrates
- how to experiment directly with `tf.data.Datasets` using TensorFlow's low
- level APIs.
-* @{$guide/datasets} which goes into great detail about additional
- functionality of `Datasets`.
-
diff --git a/tensorflow/docs_src/guide/debugger.md b/tensorflow/docs_src/guide/debugger.md
deleted file mode 100644
index 4c4a04a88a..0000000000
--- a/tensorflow/docs_src/guide/debugger.md
+++ /dev/null
@@ -1,814 +0,0 @@
-# TensorFlow Debugger
-
-<!-- [comment]: TODO(barryr): Links to and from sections on "Graphs" & "Monitoring Learning". -->
-
-[TOC]
-
-`tfdbg` is a specialized debugger for TensorFlow. It lets you view the internal
-structure and states of running TensorFlow graphs during training and inference,
-which is difficult to debug with general-purpose debuggers such as Python's `pdb`
-due to TensorFlow's computation-graph paradigm.
-
-This guide focuses on the command-line interface (CLI) of `tfdbg`. For a guide on
-how to use the graphical user interface (GUI) of tfdbg, i.e., the
-**TensorBoard Debugger Plugin**, please visit
-[its README](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/debugger/README.md).
-
-Note: The TensorFlow debugger uses a
-[curses](https://en.wikipedia.org/wiki/Curses_\(programming_library\))-based text
-user interface. On Mac OS X, the `ncurses` library is required and can be
-installed with `brew install ncurses`. On Windows, curses isn't as
-well supported, so a [readline](https://en.wikipedia.org/wiki/GNU_Readline)-based
-interface can be used with tfdbg by installing `pyreadline` with `pip`. If you
-use Anaconda3, you can install it with a command such as
-`"C:\Program Files\Anaconda3\Scripts\pip.exe" install pyreadline`. Unofficial
-Windows curses packages can be downloaded
-[here](https://www.lfd.uci.edu/~gohlke/pythonlibs/#curses), then subsequently
-installed using `pip install <your_version>.whl`, however curses on Windows may
-not work as reliably as curses on Linux or Mac.
-
-This tutorial demonstrates how to use the **tfdbg** CLI to debug the appearance
-of [`nan`s](https://en.wikipedia.org/wiki/NaN)
-and [`inf`s](https://en.wikipedia.org/wiki/Infinity), a frequently-encountered
-type of bug in TensorFlow model development.
-The following example is for users who use the low-level
-[`Session`](https://www.tensorflow.org/api_docs/python/tf/Session) API of
-TensorFlow. Later sections of this document describe how to use **tfdbg**
-with higher-level APIs of TensorFlow, including `tf.estimator`,
-`tf.keras` / `keras` and `tf.contrib.slim`.
-To *observe* such an issue, run the following command without the debugger (the
-source code can be found
-[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/debug/examples/debug_mnist.py)):
-
-```none
-python -m tensorflow.python.debug.examples.debug_mnist
-```
-
-This code trains a simple neural network for MNIST digit image recognition.
-Notice that the accuracy increases slightly after the first training step, but
-then gets stuck at a low (near-chance) level:
-
-```none
-Accuracy at step 0: 0.1113
-Accuracy at step 1: 0.3183
-Accuracy at step 2: 0.098
-Accuracy at step 3: 0.098
-Accuracy at step 4: 0.098
-```
-
-Wondering what might have gone wrong, you suspect that certain nodes in the
-training graph generated bad numeric values such as `inf`s and `nan`s, because
-this is a common cause of this type of training failure.
-Let's use tfdbg to debug this issue and pinpoint the exact graph node where this
-numeric problem first surfaced.
-
-## Wrapping TensorFlow Sessions with tfdbg
-
-To add support for tfdbg in our example, all that is needed is to add the
-following lines of code and wrap the Session object with a debugger wrapper.
-This code is already added in
-[debug_mnist.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/debug/examples/debug_mnist.py),
-so you can activate tfdbg CLI with the `--debug` flag at the command line.
-
-```python
-# Let your BUILD target depend on "//tensorflow/python/debug:debug_py"
-# (You don't need to worry about the BUILD dependency if you are using a pip
-# install of open-source TensorFlow.)
-from tensorflow.python import debug as tf_debug
-
-sess = tf_debug.LocalCLIDebugWrapperSession(sess)
-```
-
-This wrapper has the same interface as Session, so enabling debugging requires
-no other changes to the code. The wrapper provides additional features,
-including:
-
-* Bringing up a CLI before and after `Session.run()` calls, to let you
-control the execution and inspect the graph's internal state.
-* Allowing you to register special `filters` for tensor values, to facilitate
-the diagnosis of issues.
-
-In this example, we have already registered a tensor filter called
-`tfdbg.has_inf_or_nan`,
-which simply determines if there are any `nan` or `inf` values in any
-intermediate tensors (tensors that are neither inputs nor outputs of the
-`Session.run()` call, but are in the path leading from the inputs to the
-outputs). Detecting `nan`s and `inf`s is a common enough use case that we
-ship this filter with the
-@{$python/tfdbg#Classes_for_debug_dump_data_and_directories$`debug_data`}
-module.
-
-Note: You can also write your own custom filters. See `tfdbg.DebugDumpDir.find`
-for additional information.
-
-## Debugging Model Training with tfdbg
-
-Let's try training the model again, but with the `--debug` flag added this time:
-
-```none
-python -m tensorflow.python.debug.examples.debug_mnist --debug
-```
-
-The debug wrapper session will prompt you when it is about to execute the first
-`Session.run()` call, with information regarding the fetched tensor and feed
-dictionaries displayed on the screen.
-
-![tfdbg run-start UI](https://www.tensorflow.org/images/tfdbg_screenshot_run_start.png)
-
-This is what we refer to as the *run-start CLI*. It lists the feeds and fetches
-to the current `Session.run` call, before executing anything.
-
-If the screen size is too small to display the content of the message in its
-entirety, you can resize it.
-
-Use the **PageUp** / **PageDown** / **Home** / **End** keys to navigate the
-screen output. On keyboards lacking those keys, **Fn + Up** /
-**Fn + Down** / **Fn + Right** / **Fn + Left** will work.
-
-Enter the `run` command (or just `r`) at the command prompt:
-
-```
-tfdbg> run
-```
-
-The `run` command causes tfdbg to execute until the end of the next
-`Session.run()` call, which calculates the model's accuracy using a test data
-set. tfdbg augments the runtime Graph to dump all intermediate tensors.
-After the run ends, tfdbg displays all the dumped tensor values in the
-*run-end CLI*. For example:
-
-![tfdbg run-end UI: accuracy](https://www.tensorflow.org/images/tfdbg_screenshot_run_end_accuracy.png)
-
-This list of tensors can also be obtained by running the command `lt` after you
-have executed `run`.
-
-### tfdbg CLI Frequently-Used Commands
-
-Try the following commands at the `tfdbg>` prompt (referencing the code at
-`tensorflow/python/debug/examples/debug_mnist.py`):
-
-| Command | Syntax or Option | Explanation | Example |
-|:-------------------|:---------------- |:------------ |:------------------------- |
-| **`lt`** | | **List dumped tensors.** | `lt` |
-| | `-n <name_pattern>` | List dumped tensors with names matching given regular-expression pattern. | `lt -n Softmax.*` |
-| | `-t <op_pattern>` | List dumped tensors with op types matching given regular-expression pattern. | `lt -t MatMul` |
-| | `-f <filter_name>` | List only the tensors that pass a registered tensor filter. | `lt -f has_inf_or_nan` |
-| | `-f <filter_name> -fenn <regex>` | List only the tensors that pass a registered tensor filter, excluding nodes with names matching the regular expression. | `lt -f has_inf_or_nan` `-fenn .*Sqrt.*` |
-| | `-s <sort_key>` | Sort the output by given `sort_key`, whose possible values are `timestamp` (default), `dump_size`, `op_type` and `tensor_name`. | `lt -s dump_size` |
-| | `-r` | Sort in reverse order. | `lt -r -s dump_size` |
-| **`pt`** | | **Print value of a dumped tensor.** | |
-| | `pt <tensor>` | Print tensor value. | `pt hidden/Relu:0` |
-| | `pt <tensor>[slicing]` | Print a subarray of tensor, using [numpy](http://www.numpy.org/)-style array slicing. | `pt hidden/Relu:0[0:50,:]` |
-| | `-a` | Print the entirety of a large tensor, without using ellipses. (May take a long time for large tensors.) | `pt -a hidden/Relu:0[0:50,:]` |
-| | `-r <range>` | Highlight elements falling into specified numerical range. Multiple ranges can be used in conjunction. | `pt hidden/Relu:0 -a -r [[-inf,-1],[1,inf]]` |
-| | `-n <number>` | Print dump corresponding to specified 0-based dump number. Required for tensors with multiple dumps. | `pt -n 0 hidden/Relu:0` |
-| | `-s` | Include a summary of the numeric values of the tensor (applicable only to non-empty tensors with Boolean and numeric types such as `int*` and `float*`.) | `pt -s hidden/Relu:0[0:50,:]` |
-| | `-w` | Write the value of the tensor (possibly sliced) to a Numpy file using [`numpy.save()`](https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.save.html) | `pt -s hidden/Relu:0 -w /tmp/relu.npy` |
-| **`@[coordinates]`** | | Navigate to specified element in `pt` output. | `@[10,0]` or `@10,0` |
-| **`/regex`** | | [less](https://linux.die.net/man/1/less)-style search for given regular expression. | `/inf` |
-| **`/`** | | Scroll to the next line with matches to the searched regex (if any). | `/` |
-| **`pf`** | | **Print a value in the feed_dict to `Session.run`.** | |
-| | `pf <feed_tensor_name>` | Print the value of the feed. Also note that the `pf` command has the `-a`, `-r` and `-s` flags (not listed here), which have the same syntax and semantics as the identically-named flags of `pt`. | `pf input_xs:0` |
-| **eval** | | **Evaluate arbitrary Python and numpy expression.** | |
-| | `eval <expression>` | Evaluate a Python / numpy expression, with numpy available as `np` and debug tensor names enclosed in backticks. | ``eval "np.matmul((`output/Identity:0` / `Softmax:0`).T, `Softmax:0`)"`` |
-| | `-a` | Print a large-sized evaluation result in its entirety, i.e., without using ellipses. | ``eval -a 'np.sum(`Softmax:0`, axis=1)'`` |
-| | `-w` | Write the result of the evaluation to a Numpy file using [`numpy.save()`](https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.save.html) | ``eval -a 'np.sum(`Softmax:0`, axis=1)' -w /tmp/softmax_sum.npy`` |
-| **`ni`** | | **Display node information.** | |
-| | `-a` | Include node attributes in the output. | `ni -a hidden/Relu` |
-| | `-d` | List the debug dumps available from the node. | `ni -d hidden/Relu` |
-| | `-t` | Display the Python stack trace of the node's creation. | `ni -t hidden/Relu` |
-| **`li`** | | **List inputs to node** | |
-| | `-r` | List the inputs to node, recursively (the input tree.) | `li -r hidden/Relu:0` |
-| | `-d <max_depth>` | Limit recursion depth under the `-r` mode. | `li -r -d 3 hidden/Relu:0` |
-| | `-c` | Include control inputs. | `li -c -r hidden/Relu:0` |
-| | `-t` | Show op types of input nodes. | `li -t -r hidden/Relu:0` |
-| **`lo`** | | **List output recipients of node** | |
-| | `-r` | List the output recipients of node, recursively (the output tree.) | `lo -r hidden/Relu:0` |
-| | `-d <max_depth>` | Limit recursion depth under the `-r` mode. | `lo -r -d 3 hidden/Relu:0` |
-| | `-c` | Include recipients via control edges. | `lo -c -r hidden/Relu:0` |
-| | `-t` | Show op types of recipient nodes. | `lo -t -r hidden/Relu:0` |
-| **`ls`** | | **List Python source files involved in node creation.** | |
-| | `-p <path_pattern>` | Limit output to source files matching given regular-expression path pattern. | `ls -p .*debug_mnist.*` |
-| | `-n` | Limit output to node names matching given regular-expression pattern. | `ls -n Softmax.*` |
-| **`ps`** | | **Print Python source file.** | |
-| | `ps <file_path>` | Print given Python source file source.py, with the lines annotated with the nodes created at each of them (if any). | `ps /path/to/source.py` |
-| | `-t` | Perform annotation with respect to Tensors, instead of the default, nodes. | `ps -t /path/to/source.py` |
-| | `-b <line_number>` | Annotate source.py beginning at given line. | `ps -b 30 /path/to/source.py` |
-| | `-m <max_elements>` | Limit the number of elements in the annotation for each line. | `ps -m 100 /path/to/source.py` |
-| **`run`** | | **Proceed to the next Session.run()** | `run` |
-| | `-n` | Execute through the next `Session.run` without debugging, and drop to CLI right before the run after that. | `run -n` |
-| | `-t <T>` | Execute `Session.run` `T - 1` times without debugging, followed by a run with debugging. Then drop to CLI right after the debugged run. | `run -t 10` |
-| | `-f <filter_name>` | Continue executing `Session.run` until any intermediate tensor triggers the specified Tensor filter (causes the filter to return `True`). | `run -f has_inf_or_nan` |
-| | `-f <filter_name> -fenn <regex>` | Continue executing `Session.run` until any intermediate tensor whose node names doesn't match the regular expression triggers the specified Tensor filter (causes the filter to return `True`). | `run -f has_inf_or_nan -fenn .*Sqrt.*` |
-| | `--node_name_filter <pattern>` | Execute the next `Session.run`, watching only nodes with names matching the given regular-expression pattern. | `run --node_name_filter Softmax.*` |
-| | `--op_type_filter <pattern>` | Execute the next `Session.run`, watching only nodes with op types matching the given regular-expression pattern. | `run --op_type_filter Variable.*` |
-| | `--tensor_dtype_filter <pattern>` | Execute the next `Session.run`, dumping only Tensors with data types (`dtype`s) matching the given regular-expression pattern. | `run --tensor_dtype_filter int.*` |
-| | `-p` | Execute the next `Session.run` call in profiling mode. | `run -p` |
-| **`ri`** | | **Display information about the current run, including fetches and feeds.** | `ri` |
-| **`config`** | | **Set or show persistent TFDBG UI configuration.** | |
-| | `set` | Set the value of a config item: {`graph_recursion_depth`, `mouse_mode`}. | `config set graph_recursion_depth 3` |
-| | `show` | Show current persistent UI configuration. | `config show` |
-| **`version`** | | **Print the version of TensorFlow and its key dependencies.** | `version` |
-| **`help`** | | **Print general help information** | `help` |
-| | `help <command>` | Print help for given command. | `help lt` |
-
-Note that each time you enter a command, a new screen output
-will appear. This is somewhat analogous to web pages in a browser. You can
-navigate between these screens by clicking the `<--` and
-`-->` text arrows near the top-left corner of the CLI.
-
-### Other Features of the tfdbg CLI
-
-In addition to the commands listed above, the tfdbg CLI provides the following
-additional features:
-
-* To navigate through previous tfdbg commands, type in a few characters
- followed by the Up or Down arrow keys. tfdbg will show you the history of
- commands that started with those characters.
-* To navigate through the history of screen outputs, do either of the
- following:
- * Use the `prev` and `next` commands.
- * Click underlined `<--` and `-->` links near the top left corner of the
- screen.
-* Tab completion of commands and some command arguments.
-* To redirect the screen output to a file instead of the screen, end the
- command with bash-style redirection. For example, the following command
- redirects the output of the pt command to the `/tmp/xent_value_slices.txt`
- file:
-
- ```none
- tfdbg> pt cross_entropy/Log:0[:, 0:10] > /tmp/xent_value_slices.txt
- ```
-
-### Finding `nan`s and `inf`s
-
-In this first `Session.run()` call, there happen to be no problematic numerical
-values. You can move on to the next run by using the command `run` or its
-shorthand `r`.
-
-> TIP: If you enter `run` or `r` repeatedly, you will be able to move through
-> the `Session.run()` calls in a sequential manner.
->
-> You can also use the `-t` flag to move ahead a number of `Session.run()` calls
-> at a time, for example:
->
-> ```
-> tfdbg> run -t 10
-> ```
-
-Instead of entering `run` repeatedly and manually searching for `nan`s and
-`inf`s in the run-end UI after every `Session.run()` call (for example, by using
-the `pt` command shown in the table above), you can use the following
-command to let the debugger repeatedly execute `Session.run()` calls without
-stopping at the run-start or run-end prompt, until the first `nan` or `inf`
-value shows up in the graph. This is analogous to *conditional breakpoints* in
-some procedural-language debuggers:
-
-```none
-tfdbg> run -f has_inf_or_nan
-```
-
-> NOTE: The preceding command works properly because a tensor filter called
-> `has_inf_or_nan` has been registered for you when the wrapped session is
-> created. This filter detects `nan`s and `inf`s (as explained previously).
-> If you have registered any other filters, you can
-> use "run -f" to have tfdbg run until any tensor triggers that filter (cause
-> the filter to return True).
->
-> ``` python
-> def my_filter_callable(datum, tensor):
-> # A filter that detects zero-valued scalars.
-> return len(tensor.shape) == 0 and tensor == 0.0
->
-> sess.add_tensor_filter('my_filter', my_filter_callable)
-> ```
->
-> Then, at the tfdbg run-start prompt, run until your filter is triggered:
->
-> ```
-> tfdbg> run -f my_filter
-> ```
-
-See [this API document](https://www.tensorflow.org/api_docs/python/tfdbg/DebugDumpDir#find)
-for more information on the expected signature and return value of the predicate
-`Callable` used with `add_tensor_filter()`.
-
-![tfdbg run-end UI: infs and nans](https://www.tensorflow.org/images/tfdbg_screenshot_run_end_inf_nan.png)
-
-As the screen display indicates on the first line, the `has_inf_or_nan` filter is first triggered
-during the fourth `Session.run()` call: an
-[Adam optimizer](https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer)
-forward-backward training pass on the graph. In this run, 36 (out of the total
-95) intermediate tensors contain `nan` or `inf` values. These tensors are listed
-in chronological order, with their timestamps displayed on the left. At the top
-of the list, you can see the first tensor in which the bad numerical values
-first surfaced: `cross_entropy/Log:0`.
-
-To view the value of the tensor, click the underlined tensor name
-`cross_entropy/Log:0` or enter the equivalent command:
-
-```none
-tfdbg> pt cross_entropy/Log:0
-```
-
-Scroll down a little and you will notice some scattered `inf` values. If the
-instances of `inf` and `nan` are difficult to spot by eye, you can use the
-following command to perform a regex search and highlight the output:
-
-```none
-tfdbg> /inf
-```
-
-Or, alternatively:
-
-```none
-tfdbg> /(inf|nan)
-```
-
-You can also use the `-s` or `--numeric_summary` flag to get a quick summary
-of the types of numeric values in the tensor:
-
-``` none
-tfdbg> pt -s cross_entropy/Log:0
-```
-
-From the summary, you can see that several of the 1000 elements of the
-`cross_entropy/Log:0` tensor are `-inf`s (negative infinities).
-
-Why did these infinities appear? To further debug, display more information
-about the node `cross_entropy/Log` by clicking the underlined `node_info` menu
-item on the top or entering the equivalent node_info (`ni`) command:
-
-```none
-tfdbg> ni cross_entropy/Log
-```
-
-![tfdbg run-end UI: infs and nans](https://www.tensorflow.org/images/tfdbg_screenshot_run_end_node_info.png)
-
-You can see that this node has the op type `Log`
-and that its input is the node `Softmax`. Run the following command to
-take a closer look at the input tensor:
-
-```none
-tfdbg> pt Softmax:0
-```
-
-Examine the values in the input tensor, searching for zeros:
-
-```none
-tfdbg> /0\.000
-```
-
-Indeed, there are zeros. Now it is clear that the origin of the bad numerical
-values is the node `cross_entropy/Log` taking logs of zeros. To find out the
-culprit line in the Python source code, use the `-t` flag of the `ni` command
-to show the traceback of the node's construction:
-
-```none
-tfdbg> ni -t cross_entropy/Log
-```
-
-If you click "node_info" at the top of the screen, tfdbg automatically shows the
-traceback of the node's construction.
-
-From the traceback, you can see that the op is constructed at the following
-line:
-[`debug_mnist.py`](https://www.tensorflow.org/code/tensorflow/python/debug/examples/debug_mnist.py):
-
-```python
-diff = -(y_ * tf.log(y))
-```
-
-**tfdbg** has a feature that makes it easy to trace Tensors and ops back to
-lines in Python source files. It can annotate lines of a Python file with
-the ops or Tensors created by them. To use this feature,
-simply click the underlined line numbers in the stack trace output of the
-`ni -t <op_name>` commands, or use the `ps` (or `print_source`) command such as:
-`ps /path/to/source.py`. For example, the following screenshot shows the output
-of a `ps` command.
-
-![tfdbg run-end UI: annotated Python source file](https://www.tensorflow.org/images/tfdbg_screenshot_run_end_annotated_source.png)
-
-### Fixing the problem
-
-To fix the problem, edit `debug_mnist.py`, changing the original line:
-
-```python
-diff = -(y_ * tf.log(y))
-```
-
-to the built-in, numerically-stable implementation of softmax cross-entropy:
-
-```python
-diff = tf.losses.softmax_cross_entropy(labels=y_, logits=logits)
-```
-
-Rerun with the `--debug` flag as follows:
-
-```none
-python -m tensorflow.python.debug.examples.debug_mnist --debug
-```
-
-At the `tfdbg>` prompt, enter the following command:
-
-```none
-run -f has_inf_or_nan
-```
-
-Confirm that no tensors are flagged as containing `nan` or `inf` values, and
-accuracy now continues to rise rather than getting stuck. Success!
-
-## Debugging TensorFlow Estimators
-
-This section explains how to debug TensorFlow programs that use the `Estimator`
-APIs. Part of the convenience provided by these APIs is that
-they manage `Session`s internally. This makes the `LocalCLIDebugWrapperSession`
-described in the preceding sections inapplicable. Fortunately, you can still
-debug them by using special `hook`s provided by `tfdbg`.
-
-`tfdbg` can debug the
-`tf.estimator.Estimator.train`,
-`tf.estimator.Estimator.evaluate` and
-`tf.estimator.Estimator.predict`
-methods of `Estimator`s. To debug `Estimator.train()`,
-create a `LocalCLIDebugHook` and supply it in the `hooks` argument. For example:
-
-```python
-# First, let your BUILD target depend on "//tensorflow/python/debug:debug_py"
-# (You don't need to worry about the BUILD dependency if you are using a pip
-# install of open-source TensorFlow.)
-from tensorflow.python import debug as tf_debug
-
-# Create a LocalCLIDebugHook and use it as a hook when calling train().
-hooks = [tf_debug.LocalCLIDebugHook()]
-
-# To debug `train`:
-classifier.train(input_fn,
- steps=1000,
- hooks=hooks)
-```
-
-Similarly, to debug `Estimator.evaluate()` and `Estimator.predict()`, assign
-hooks to the `hooks` parameter, as in the following example:
-
-```python
-# To debug `evaluate`:
-accuracy_score = classifier.evaluate(eval_input_fn,
- hooks=hooks)["accuracy"]
-
-# To debug `predict`:
-predict_results = classifier.predict(predict_input_fn, hooks=hooks)
-```
-
-[debug_tflearn_iris.py](https://www.tensorflow.org/code/tensorflow/python/debug/examples/debug_tflearn_iris.py)
-contains a full example of how to use tfdbg with `Estimator`s.
-To run this example, do:
-
-```none
-python -m tensorflow.python.debug.examples.debug_tflearn_iris --debug
-```
-
-The `LocalCLIDebugHook` also allows you to configure a `watch_fn` that can be
-used to flexibly specify what `Tensor`s to watch on different `Session.run()`
-calls, as a function of the `fetches` and `feed_dict` and other states. See
-`tfdbg.DumpingDebugWrapperSession.__init__`
-for more details.
-
-## Debugging Keras Models with TFDBG
-
-To use TFDBG with
-[tf.keras](https://www.tensorflow.org/api_docs/python/tf/keras),
-let the Keras backend use a TFDBG-wrapped Session object. For example, to use
-the CLI wrapper:
-
-``` python
-import tensorflow as tf
-from tensorflow.python import debug as tf_debug
-
-tf.keras.backend.set_session(tf_debug.LocalCLIDebugWrapperSession(tf.Session()))
-
-# Define your keras model, called "model".
-
-# Calls to `fit()`, `evaluate()` and `predict()` methods will break into the
-# TFDBG CLI.
-model.fit(...)
-model.evaluate(...)
-model.predict(...)
-```
-
-With minor modification, the preceding code example also works for the
-[non-TensorFlow version of Keras](https://keras.io/) running against a
-TensorFlow backend. You just need to replace `tf.keras.backend` with
-`keras.backend`.
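-
-For example, a minimal sketch for standalone Keras (assuming the TensorFlow
-backend is in use):
-
-``` python
-import tensorflow as tf
-import keras
-from tensorflow.python import debug as tf_debug
-
-# Wrap the TensorFlow Session and hand it to the standalone Keras backend.
-keras.backend.set_session(tf_debug.LocalCLIDebugWrapperSession(tf.Session()))
-```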
-
-## Debugging tf-slim with TFDBG
-
-TFDBG supports debugging of training and evaluation with
-[tf-slim](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim).
-As detailed below, training and evaluation require slightly different debugging
-workflows.
-
-### Debugging training in tf-slim
-To debug the training process, provide `LocalCLIDebugWrapperSession` to the
-`session_wrapper` argument of `slim.learning.train()`. For example:
-
-``` python
-import tensorflow as tf
-from tensorflow.python import debug as tf_debug
-
-# ... Code that creates the graph and the train_op ...
-tf.contrib.slim.learning.train(
- train_op,
- logdir,
- number_of_steps=10,
- session_wrapper=tf_debug.LocalCLIDebugWrapperSession)
-```
-
-### Debugging evaluation in tf-slim
-To debug the evaluation process, provide `LocalCLIDebugHook` to the
-`hooks` argument of `slim.evaluation.evaluate_once()`. For example:
-
-``` python
-import tensorflow as tf
-from tensorflow.python import debug as tf_debug
-
-# ... Code that creates the graph and the eval and final ops ...
-tf.contrib.slim.evaluation.evaluate_once(
- '',
- checkpoint_path,
- logdir,
- eval_op=my_eval_op,
- final_op=my_value_op,
- hooks=[tf_debug.LocalCLIDebugHook()])
-```
-
-## Offline Debugging of Remotely-Running Sessions
-
-Often, your model is running on a remote machine or a process that you don't
-have terminal access to. To perform model debugging in such cases, you can use
-the `offline_analyzer` binary of `tfdbg` (described below). It operates on
-dumped data directories. This approach works with both the lower-level
-`Session` API and the higher-level `Estimator` API.
-
-### Debugging Remote tf.Sessions
-
-If you interact directly with the `tf.Session` API in Python, you can
-configure the `RunOptions` proto that you pass to your `Session.run()` call
-by using the method `tfdbg.watch_graph`.
-This will cause the intermediate tensors and runtime graphs to be dumped to a
-shared storage location of your choice when the `Session.run()` call occurs
-(at the cost of slower performance). For example:
-
-```python
-from tensorflow.python import debug as tf_debug
-
-# ... Code where your session and graph are set up...
-
-run_options = tf.RunOptions()
-tf_debug.watch_graph(
- run_options,
- session.graph,
- debug_urls=["file:///shared/storage/location/tfdbg_dumps_1"])
-# Be sure to specify different directories for different run() calls.
-
-session.run(fetches, feed_dict=feeds, options=run_options)
-```
-
-Later, in an environment that you have terminal access to (for example, a local
-computer that can access the shared storage location specified in the code
-above), you can load and inspect the data in the dump directory on the shared
-storage by using the `offline_analyzer` binary of `tfdbg`. For example:
-
-```none
-python -m tensorflow.python.debug.cli.offline_analyzer \
- --dump_dir=/shared/storage/location/tfdbg_dumps_1
-```
-
-The `Session` wrapper `DumpingDebugWrapperSession` offers an easier and more
-flexible way to generate file-system dumps that can be analyzed offline.
-To use it, simply wrap your session in a `tf_debug.DumpingDebugWrapperSession`.
-For example:
-
-```python
-# Let your BUILD target depend on "//tensorflow/python/debug:debug_py"
-# (You don't need to worry about the BUILD dependency if you are using a pip
-# install of open-source TensorFlow.)
-from tensorflow.python import debug as tf_debug
-
-sess = tf_debug.DumpingDebugWrapperSession(
- sess, "/shared/storage/location/tfdbg_dumps_1/", watch_fn=my_watch_fn)
-```
-
-The `watch_fn` argument accepts a `Callable` that allows you to configure what
-`tensor`s to watch on different `Session.run()` calls, as a function of the
-`fetches` and `feed_dict` to the `run()` call and other states.
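-
-For illustration, a minimal `watch_fn` sketch, assuming the
-`tf_debug.WatchOptions` return type and reusing the `tf_debug` import from
-above (see `tfdbg.DumpingDebugWrapperSession.__init__` for the exact contract):
-
-```python
-def my_watch_fn(fetches, feeds):
-  # Dump only tensors from nodes whose names contain "hidden", using the
-  # vanilla DebugIdentity debug op.
-  return tf_debug.WatchOptions(
-      debug_ops=["DebugIdentity"],
-      node_name_regex_whitelist=r".*hidden.*")
-```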
-
-### C++ and other languages
-
-If your model code is written in C++ or other languages, you can also
-modify the `debug_options` field of `RunOptions` to generate debug dumps that
-can be inspected offline. See
-[the proto definition](https://www.tensorflow.org/code/tensorflow/core/protobuf/debug.proto)
-for more details.
-
-### Debugging Remotely-Running Estimators
-
-If your remote TensorFlow server runs `Estimator`s,
-you can use the non-interactive `DumpingDebugHook`. For example:
-
-```python
-# Let your BUILD target depend on "//tensorflow/python/debug:debug_py"
-# (You don't need to worry about the BUILD dependency if you are using a pip
-# install of open-source TensorFlow.)
-from tensorflow.python import debug as tf_debug
-
-hooks = [tf_debug.DumpingDebugHook("/shared/storage/location/tfdbg_dumps_1")]
-```
-
-Then this `hook` can be used in the same way as the `LocalCLIDebugHook` examples
-described earlier in this document.
-As the training, evaluation or prediction happens with `Estimator`,
-tfdbg creates directories having the following name pattern:
-`/shared/storage/location/tfdbg_dumps_1/run_<epoch_timestamp_microsec>_<uuid>`.
-Each directory corresponds to a `Session.run()` call that underlies
-the `train()`, `evaluate()` or `predict()` call. You can load these directories and inspect
-them in a command-line interface in an offline manner using the
-`offline_analyzer` offered by tfdbg. For example:
-
-```bash
-python -m tensorflow.python.debug.cli.offline_analyzer \
- --dump_dir="/shared/storage/location/tfdbg_dumps_1/run_<epoch_timestamp_microsec>_<uuid>"
-```
-
-## Frequently Asked Questions
-
-**Q**: _Do the timestamps on the left side of the `lt` output reflect actual
- performance in a non-debugging session?_
-
-**A**: No. The debugger inserts additional special-purpose debug nodes to the
- graph to record the values of intermediate tensors. These nodes
- slow down the graph execution. If you are interested in profiling your
- model, check out
-
- 1. The profiling mode of tfdbg: `tfdbg> run -p`.
- 2. [tfprof](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/profiler)
- and other profiling tools for TensorFlow.
-
-**Q**: _How do I link tfdbg against my `Session` in Bazel? Why do I see an
- error such as "ImportError: cannot import name debug"?_
-
-**A**: In your BUILD rule, declare dependencies:
- `"//tensorflow:tensorflow_py"` and `"//tensorflow/python/debug:debug_py"`.
- The first is the dependency that you include to use TensorFlow even
- without debugger support; the second enables the debugger.
-       Then, in your Python file, add:
-
-```python
-from tensorflow.python import debug as tf_debug
-
-# Then wrap your TensorFlow Session with the local-CLI wrapper.
-sess = tf_debug.LocalCLIDebugWrapperSession(sess)
-```
-
-**Q**: _Does tfdbg help debug runtime errors such as shape mismatches?_
-
-**A**: Yes. tfdbg intercepts errors generated by ops during runtime and presents
- the errors with some debug instructions to the user in the CLI.
- See examples:
-
-```none
-# Debugging shape mismatch during matrix multiplication.
-python -m tensorflow.python.debug.examples.debug_errors \
- --error shape_mismatch --debug
-
-# Debugging uninitialized variable.
-python -m tensorflow.python.debug.examples.debug_errors \
- --error uninitialized_variable --debug
-```
-
-**Q**: _How can I let my tfdbg-wrapped Sessions or Hooks run the debug mode
-only from the main thread?_
-
-**A**:
-This is a common use case, in which the `Session` object is used from multiple
-threads concurrently. Typically, the child threads take care of background tasks
-such as running enqueue operations. Often, you want to debug only the main
-thread (or less frequently, only one of the child threads). You can use the
-`thread_name_filter` keyword argument of `LocalCLIDebugWrapperSession` to
-achieve this type of thread-selective debugging. For example, to debug from the
-main thread only, construct a wrapped `Session` as follows:
-
-```python
-sess = tf_debug.LocalCLIDebugWrapperSession(sess, thread_name_filter="MainThread$")
-```
-
-The above example relies on the fact that main threads in Python have the
-default name `MainThread`.
-
-**Q**: _The model I am debugging is very large. The data dumped by tfdbg
-fills up the free space of my disk. What can I do?_
-
-**A**:
-You might encounter this problem in any of the following situations:
-
-* models with many intermediate tensors
-* very large intermediate tensors
-* many `tf.while_loop` iterations
-
-There are three possible workarounds or solutions:
-
-* The constructors of `LocalCLIDebugWrapperSession` and `LocalCLIDebugHook`
- provide a keyword argument, `dump_root`, to specify the path
- to which tfdbg dumps the debug data. You can use it to let tfdbg dump the
- debug data on a disk with larger free space. For example:
-
-```python
-# For LocalCLIDebugWrapperSession
-sess = tf_debug.LocalCLIDebugWrapperSession(dump_root="/with/lots/of/space")
-
-# For LocalCLIDebugHook
-hooks = [tf_debug.LocalCLIDebugHook(dump_root="/with/lots/of/space")]
-```
-  Make sure that the directory pointed to by `dump_root` is empty or nonexistent.
- `tfdbg` cleans up the dump directories before exiting.
-
-* Reduce the batch size used during the runs.
-* Use the filtering options of tfdbg's `run` command to watch only specific
- nodes in the graph. For example:
-
- ```
- tfdbg> run --node_name_filter .*hidden.*
- tfdbg> run --op_type_filter Variable.*
- tfdbg> run --tensor_dtype_filter int.*
- ```
-
-  The first command above watches only nodes whose names match the
-  regular-expression pattern `.*hidden.*`. The second command watches only
-  operations whose types match the pattern `Variable.*`. The third one watches
-  only the tensors whose dtypes match the pattern `int.*` (e.g., `int32`).
-
-
-**Q**: _Why can't I select text in the tfdbg CLI?_
-
-**A**: This is because the tfdbg CLI enables mouse events in the terminal by
- default. This [mouse-mask](https://linux.die.net/man/3/mousemask) mode
- overrides default terminal interactions, including text selection. You
- can re-enable text selection by using the command `mouse off` or
- `m off`.
-
-**Q**: _Why does the tfdbg CLI show no dumped tensors when I debug code like the following?_
-
-``` python
-a = tf.ones([10], name="a")
-b = tf.add(a, a, name="b")
-sess = tf.Session()
-sess = tf_debug.LocalCLIDebugWrapperSession(sess)
-sess.run(b)
-```
-
-**A**: You see no dumped data because every node in the
- executed TensorFlow graph is constant-folded by the TensorFlow runtime.
- In this example, `a` is a constant tensor; therefore, the fetched
- tensor `b` is effectively also a constant tensor. TensorFlow's graph
- optimization folds the graph that contains `a` and `b` into a single
- node to speed up future runs of the graph, which is why `tfdbg` does
- not generate any intermediate tensor dumps. However, if `a` were a
- `tf.Variable`, as in the following example:
-
-``` python
-import numpy as np
-
-a = tf.Variable(np.ones(10), name="a")
-b = tf.add(a, a, name="b")
-sess = tf.Session()
-sess.run(tf.global_variables_initializer())
-sess = tf_debug.LocalCLIDebugWrapperSession(sess)
-sess.run(b)
-```
-
-the constant-folding would not occur and `tfdbg` should show the intermediate
-tensor dumps.
-
-
-**Q**: I am debugging a model that generates unwanted infinities or NaNs. But
- there are some nodes in my model that are known to generate infinities
- or NaNs in their output tensors even under completely normal conditions.
- How can I skip those nodes during my `run -f has_inf_or_nan` actions?
-
-**A**: Use the `--filter_exclude_node_names` (`-fenn` for short) flag. For
-       example, if you know you have a node whose name matches the regular
- expression `.*Sqrt.*` that generates infinities or NaNs regardless
- of whether the model is behaving correctly, you can exclude the nodes
- from the infinity/NaN-finding runs with the command
- `run -f has_inf_or_nan -fenn .*Sqrt.*`.
-
-
-**Q**: Is there a GUI for tfdbg?
-
-**A**: Yes, the **TensorBoard Debugger Plugin** is the GUI of tfdbg.
- It offers features such as inspection of the computation graph,
-       real-time visualization of tensor values, continuation to tensors,
-       conditional breakpoints, and tying tensors to their
- graph-construction source code, all in the browser environment.
- To get started, please visit
- [its README](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/debugger/README.md).
diff --git a/tensorflow/docs_src/guide/eager.md b/tensorflow/docs_src/guide/eager.md
deleted file mode 100644
index e47a8b599c..0000000000
--- a/tensorflow/docs_src/guide/eager.md
+++ /dev/null
@@ -1,854 +0,0 @@
-# Eager Execution
-
-TensorFlow's eager execution is an imperative programming environment that
-evaluates operations immediately, without building graphs: operations return
-concrete values instead of constructing a computational graph to run later. This
-makes it easy to get started with TensorFlow and debug models, and it
-reduces boilerplate as well. To follow along with this guide, run the code
-samples below in an interactive `python` interpreter.
-
-Eager execution is a flexible machine learning platform for research and
-experimentation, providing:
-
-* *An intuitive interface*—Structure your code naturally and use Python data
- structures. Quickly iterate on small models and small data.
-* *Easier debugging*—Call ops directly to inspect running models and test
- changes. Use standard Python debugging tools for immediate error reporting.
-* *Natural control flow*—Use Python control flow instead of graph control
- flow, simplifying the specification of dynamic models.
-
-Eager execution supports most TensorFlow operations and GPU acceleration. For a
-collection of examples running in eager execution, see:
-[tensorflow/contrib/eager/python/examples](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples).
-
-Note: Some models may experience increased overhead with eager execution
-enabled. Performance improvements are ongoing, but please
-[file a bug](https://github.com/tensorflow/tensorflow/issues) if you find a
-problem and share your benchmarks.
-
-## Setup and basic usage
-
-Upgrade to the latest version of TensorFlow:
-
-```
-$ pip install --upgrade tensorflow
-```
-
-To start eager execution, add `tf.enable_eager_execution()` to the beginning of
-the program or console session. Do not add this operation to other modules that
-the program calls.
-
-```py
-from __future__ import absolute_import, division, print_function
-
-import tensorflow as tf
-
-tf.enable_eager_execution()
-```
-
-Now you can run TensorFlow operations and the results will return immediately:
-
-```py
-tf.executing_eagerly() # => True
-
-x = [[2.]]
-m = tf.matmul(x, x)
-print("hello, {}".format(m)) # => "hello, [[4.]]"
-```
-
-Enabling eager execution changes how TensorFlow operations behave—now they
-immediately evaluate and return their values to Python. `tf.Tensor` objects
-reference concrete values instead of symbolic handles to nodes in a computational
-graph. Since there isn't a computational graph to build and run later in a
-session, it's easy to inspect results using `print()` or a debugger. Evaluating,
-printing, and checking tensor values does not break the flow for computing
-gradients.
-
-Eager execution works nicely with [NumPy](http://www.numpy.org/). NumPy
-operations accept `tf.Tensor` arguments. TensorFlow
-[math operations](https://www.tensorflow.org/api_guides/python/math_ops) convert
-Python objects and NumPy arrays to `tf.Tensor` objects. The
-`tf.Tensor.numpy` method returns the object's value as a NumPy `ndarray`.
-
-```py
-a = tf.constant([[1, 2],
- [3, 4]])
-print(a)
-# => tf.Tensor([[1 2]
-# [3 4]], shape=(2, 2), dtype=int32)
-
-# Broadcasting support
-b = tf.add(a, 1)
-print(b)
-# => tf.Tensor([[2 3]
-# [4 5]], shape=(2, 2), dtype=int32)
-
-# Operator overloading is supported
-print(a * b)
-# => tf.Tensor([[ 2 6]
-# [12 20]], shape=(2, 2), dtype=int32)
-
-# Use NumPy values
-import numpy as np
-
-c = np.multiply(a, b)
-print(c)
-# => [[ 2 6]
-# [12 20]]
-
-# Obtain numpy value from a tensor:
-print(a.numpy())
-# => [[1 2]
-# [3 4]]
-```
-
-The `tf.contrib.eager` module contains symbols available to both eager and graph execution
-environments and is useful for writing code to [work with graphs](#work_with_graphs):
-
-```py
-tfe = tf.contrib.eager
-```
-
-## Dynamic control flow
-
-A major benefit of eager execution is that all the functionality of the host
-language is available while your model is executing. So, for example,
-it is easy to write [fizzbuzz](https://en.wikipedia.org/wiki/Fizz_buzz):
-
-```py
-def fizzbuzz(max_num):
- counter = tf.constant(0)
- max_num = tf.convert_to_tensor(max_num)
- for num in range(max_num.numpy()):
- num = tf.constant(num)
- if int(num % 3) == 0 and int(num % 5) == 0:
- print('FizzBuzz')
- elif int(num % 3) == 0:
- print('Fizz')
- elif int(num % 5) == 0:
- print('Buzz')
- else:
- print(num)
- counter += 1
- return counter
-```
-
-This function has conditionals that depend on tensor values, and it prints
-these values at runtime.
-
-## Build a model
-
-Many machine learning models are represented by composing layers. When
-using TensorFlow with eager execution you can either write your own layers or
-use a layer provided in the `tf.keras.layers` package.
-
-While you can use any Python object to represent a layer,
-TensorFlow has `tf.keras.layers.Layer` as a convenient base class. Inherit from
-it to implement your own layer:
-
-```py
-class MySimpleLayer(tf.keras.layers.Layer):
- def __init__(self, output_units):
- super(MySimpleLayer, self).__init__()
- self.output_units = output_units
-
- def build(self, input_shape):
- # The build method gets called the first time your layer is used.
- # Creating variables on build() allows you to make their shape depend
- # on the input shape and hence removes the need for the user to specify
- # full shapes. It is possible to create variables during __init__() if
- # you already know their full shapes.
- self.kernel = self.add_variable(
- "kernel", [input_shape[-1], self.output_units])
-
- def call(self, input):
- # Override call() instead of __call__ so we can perform some bookkeeping.
- return tf.matmul(input, self.kernel)
-```
-
-Use the `tf.keras.layers.Dense` layer instead of `MySimpleLayer` above, as it
-provides a superset of its functionality (it can also add a bias).
-
-When composing layers into models you can use `tf.keras.Sequential` to represent
-models which are a linear stack of layers. It is easy to use for basic models:
-
-```py
-model = tf.keras.Sequential([
- tf.keras.layers.Dense(10, input_shape=(784,)), # must declare input shape
- tf.keras.layers.Dense(10)
-])
-```
-
-Alternatively, organize models in classes by inheriting from `tf.keras.Model`.
-This is a container for layers that is a layer itself, allowing `tf.keras.Model`
-objects to contain other `tf.keras.Model` objects.
-
-```py
-class MNISTModel(tf.keras.Model):
- def __init__(self):
- super(MNISTModel, self).__init__()
- self.dense1 = tf.keras.layers.Dense(units=10)
- self.dense2 = tf.keras.layers.Dense(units=10)
-
- def call(self, input):
- """Run the model."""
- result = self.dense1(input)
- result = self.dense2(result)
- result = self.dense2(result) # reuse variables from dense2 layer
- return result
-
-model = MNISTModel()
-```
-
-It's not required to set an input shape for the `tf.keras.Model` class since
-the parameters are set the first time input is passed to the layer.
-
-`tf.keras.layers` classes create and contain their own model variables that
-are tied to the lifetime of their layer objects. To share layer variables, share
-their objects.
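-
-For example, a minimal sketch of sharing one layer object, and hence its
-variables, between two models (the names are illustrative):
-
-```py
-shared_dense = tf.keras.layers.Dense(units=10)
-
-def model_a(x):
-  return shared_dense(x)
-
-def model_b(x):
-  return shared_dense(x)  # reuses the variables created by model_a's call
-```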
-
-
-## Eager training
-
-### Computing gradients
-
-[Automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)
-is useful for implementing machine learning algorithms such as
-[backpropagation](https://en.wikipedia.org/wiki/Backpropagation) for training
-neural networks. During eager execution, use `tf.GradientTape` to trace
-operations for computing gradients later.
-
-`tf.GradientTape` is an opt-in feature, providing maximal performance when you
-are not tracing. Since different operations can occur during each call, all
-forward-pass operations get recorded to a "tape". To compute the gradient, play
-the tape backwards and then discard it. A particular `tf.GradientTape` can
-compute only one gradient; subsequent calls raise a runtime error.
-
-```py
-w = tf.Variable([[1.0]])
-with tf.GradientTape() as tape:
- loss = w * w
-
-grad = tape.gradient(loss, w)
-print(grad) # => tf.Tensor([[ 2.]], shape=(1, 1), dtype=float32)
-```
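-
-By default, the resources held by a tape are released as soon as
-`tape.gradient()` is called. To compute multiple gradients over the same
-computation, pass `persistent=True`; a minimal sketch (the values in the
-comments assume the code runs as written):
-
-```py
-x = tf.constant(3.0)
-with tf.GradientTape(persistent=True) as tape:
-  tape.watch(x)  # x is a plain tensor, so watch it explicitly
-  y = x * x
-  z = y * y
-
-print(tape.gradient(z, x))  # => 108.0 (4 * x**3 at x = 3)
-print(tape.gradient(y, x))  # => 6.0 (2 * x at x = 3)
-del tape  # drop the reference to release tape resources
-```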
-
-Here's an example of `tf.GradientTape` that records forward-pass operations
-to train a simple model:
-
-```py
-# A toy dataset of points around 3 * x + 2
-NUM_EXAMPLES = 1000
-training_inputs = tf.random_normal([NUM_EXAMPLES])
-noise = tf.random_normal([NUM_EXAMPLES])
-training_outputs = training_inputs * 3 + 2 + noise
-
-def prediction(input, weight, bias):
- return input * weight + bias
-
-# A loss function using mean-squared error
-def loss(weights, biases):
- error = prediction(training_inputs, weights, biases) - training_outputs
- return tf.reduce_mean(tf.square(error))
-
-# Return the derivative of loss with respect to weight and bias
-def grad(weights, biases):
- with tf.GradientTape() as tape:
- loss_value = loss(weights, biases)
- return tape.gradient(loss_value, [weights, biases])
-
-train_steps = 200
-learning_rate = 0.01
-# Start with arbitrary values for W and B on the same batch of data
-W = tf.Variable(5.)
-B = tf.Variable(10.)
-
-print("Initial loss: {:.3f}".format(loss(W, B)))
-
-for i in range(train_steps):
- dW, dB = grad(W, B)
- W.assign_sub(dW * learning_rate)
- B.assign_sub(dB * learning_rate)
- if i % 20 == 0:
- print("Loss at step {:03d}: {:.3f}".format(i, loss(W, B)))
-
-print("Final loss: {:.3f}".format(loss(W, B)))
-print("W = {}, B = {}".format(W.numpy(), B.numpy()))
-```
-
-Output (exact numbers may vary):
-
-```
-Initial loss: 71.204
-Loss at step 000: 68.333
-Loss at step 020: 30.222
-Loss at step 040: 13.691
-Loss at step 060: 6.508
-Loss at step 080: 3.382
-Loss at step 100: 2.018
-Loss at step 120: 1.422
-Loss at step 140: 1.161
-Loss at step 160: 1.046
-Loss at step 180: 0.996
-Final loss: 0.974
-W = 3.01582956314, B = 2.1191945076
-```
-
-Replay the `tf.GradientTape` to compute the gradients and apply them in a
-training loop. This is demonstrated in an excerpt from the
-[mnist_eager.py](https://github.com/tensorflow/models/blob/master/official/mnist/mnist_eager.py)
-example:
-
-```py
-dataset = tf.data.Dataset.from_tensor_slices((data.train.images,
- data.train.labels))
-...
-for (batch, (images, labels)) in enumerate(dataset):
- ...
- with tf.GradientTape() as tape:
- logits = model(images, training=True)
- loss_value = loss(logits, labels)
- ...
- grads = tape.gradient(loss_value, model.variables)
- optimizer.apply_gradients(zip(grads, model.variables),
- global_step=tf.train.get_or_create_global_step())
-```
-
-
-The following example creates a multi-layer model that classifies the standard
-MNIST handwritten digits. It demonstrates the optimizer and layer APIs to build
-trainable graphs in an eager execution environment.
-
-### Train a model
-
-Even without training, call the model and inspect the output in eager execution:
-
-```py
-# Create a tensor representing a blank image
-batch = tf.zeros([1, 1, 784])
-print(batch.shape) # => (1, 1, 784)
-
-result = model(batch)
-# => tf.Tensor([[[ 0. 0., ..., 0.]]], shape=(1, 1, 10), dtype=float32)
-```
-
-This example uses the
-[dataset.py module](https://github.com/tensorflow/models/blob/master/official/mnist/dataset.py)
-from the
-[TensorFlow MNIST example](https://github.com/tensorflow/models/tree/master/official/mnist);
-download this file to your local directory. Run the following to download the
-MNIST data files to your working directory and prepare a `tf.data.Dataset`
-for training:
-
-```py
-import dataset # download dataset.py file
-dataset_train = dataset.train('./datasets').shuffle(60000).repeat(4).batch(32)
-```
-
-To train a model, define a loss function to optimize and then calculate
-gradients. Use an optimizer to update the variables:
-
-```py
-def loss(model, x, y):
- prediction = model(x)
- return tf.losses.sparse_softmax_cross_entropy(labels=y, logits=prediction)
-
-def grad(model, inputs, targets):
- with tf.GradientTape() as tape:
- loss_value = loss(model, inputs, targets)
- return tape.gradient(loss_value, model.variables)
-
-optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
-
-x, y = next(iter(dataset_train))  # works in Python 2 and 3
-print("Initial loss: {:.3f}".format(loss(model, x, y)))
-
-# Training loop
-for (i, (x, y)) in enumerate(dataset_train):
- # Calculate derivatives of the input function with respect to its parameters.
- grads = grad(model, x, y)
- # Apply the gradient to the model
- optimizer.apply_gradients(zip(grads, model.variables),
- global_step=tf.train.get_or_create_global_step())
- if i % 200 == 0:
- print("Loss at step {:04d}: {:.3f}".format(i, loss(model, x, y)))
-
-print("Final loss: {:.3f}".format(loss(model, x, y)))
-```
-
-Output (exact numbers may vary):
-
-```
-Initial loss: 2.674
-Loss at step 0000: 2.593
-Loss at step 0200: 2.143
-Loss at step 0400: 2.009
-Loss at step 0600: 2.103
-Loss at step 0800: 1.621
-Loss at step 1000: 1.695
-...
-Loss at step 6600: 0.602
-Loss at step 6800: 0.557
-Loss at step 7000: 0.499
-Loss at step 7200: 0.744
-Loss at step 7400: 0.681
-Final loss: 0.670
-```
-
-And for faster training, move the computation to a GPU:
-
-```py
-with tf.device("/gpu:0"):
- for (i, (x, y)) in enumerate(dataset_train):
- # minimize() is equivalent to the grad() and apply_gradients() calls.
- optimizer.minimize(lambda: loss(model, x, y),
- global_step=tf.train.get_or_create_global_step())
-```
-
-### Variables and optimizers
-
-`tf.Variable` objects store mutable `tf.Tensor` values accessed during
-training to make automatic differentiation easier. The parameters of a model can
-be encapsulated in classes as variables.
-
-Model parameters are better encapsulated by using `tf.Variable` together with
-`tf.GradientTape`. For example, the automatic differentiation example above
-can be rewritten as:
-
-```py
-class Model(tf.keras.Model):
- def __init__(self):
- super(Model, self).__init__()
- self.W = tf.Variable(5., name='weight')
- self.B = tf.Variable(10., name='bias')
- def call(self, inputs):
- return inputs * self.W + self.B
-
-# A toy dataset of points around 3 * x + 2
-NUM_EXAMPLES = 2000
-training_inputs = tf.random_normal([NUM_EXAMPLES])
-noise = tf.random_normal([NUM_EXAMPLES])
-training_outputs = training_inputs * 3 + 2 + noise
-
-# The loss function to be optimized
-def loss(model, inputs, targets):
- error = model(inputs) - targets
- return tf.reduce_mean(tf.square(error))
-
-def grad(model, inputs, targets):
- with tf.GradientTape() as tape:
- loss_value = loss(model, inputs, targets)
- return tape.gradient(loss_value, [model.W, model.B])
-
-# Define:
-# 1. A model.
-# 2. Derivatives of a loss function with respect to model parameters.
-# 3. A strategy for updating the variables based on the derivatives.
-model = Model()
-optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
-
-print("Initial loss: {:.3f}".format(loss(model, training_inputs, training_outputs)))
-
-# Training loop
-for i in range(300):
- grads = grad(model, training_inputs, training_outputs)
- optimizer.apply_gradients(zip(grads, [model.W, model.B]),
- global_step=tf.train.get_or_create_global_step())
- if i % 20 == 0:
- print("Loss at step {:03d}: {:.3f}".format(i, loss(model, training_inputs, training_outputs)))
-
-print("Final loss: {:.3f}".format(loss(model, training_inputs, training_outputs)))
-print("W = {}, B = {}".format(model.W.numpy(), model.B.numpy()))
-```
-
-Output (exact numbers may vary):
-
-```
-Initial loss: 69.066
-Loss at step 000: 66.368
-Loss at step 020: 30.107
-Loss at step 040: 13.959
-Loss at step 060: 6.769
-Loss at step 080: 3.567
-Loss at step 100: 2.141
-Loss at step 120: 1.506
-Loss at step 140: 1.223
-Loss at step 160: 1.097
-Loss at step 180: 1.041
-Loss at step 200: 1.016
-Loss at step 220: 1.005
-Loss at step 240: 1.000
-Loss at step 260: 0.998
-Loss at step 280: 0.997
-Final loss: 0.996
-W = 2.99431324005, B = 2.02129220963
-```
-
-## Use objects for state during eager execution
-
-With graph execution, program state (such as the variables) is stored in global
-collections and their lifetime is managed by the `tf.Session` object. In
-contrast, during eager execution the lifetime of state objects is determined by
-the lifetime of their corresponding Python object.
-
-### Variables are objects
-
-During eager execution, a variable persists until the last reference to the
-object is removed; the variable is then deleted.
-
-```py
-with tf.device("gpu:0"):
- v = tf.Variable(tf.random_normal([1000, 1000]))
- v = None # v no longer takes up GPU memory
-```
-
-### Object-based saving
-
-`tf.train.Checkpoint` can save and restore `tf.Variable`s to and from
-checkpoints:
-
-```py
-x = tf.Variable(10.)
-
-checkpoint = tf.train.Checkpoint(x=x) # save as "x"
-
-x.assign(2.) # Assign a new value to the variables and save.
-save_path = checkpoint.save('./ckpt/')
-
-x.assign(11.) # Change the variable after saving.
-
-# Restore values from the checkpoint
-checkpoint.restore(save_path)
-
-print(x) # => 2.0
-```
-
-To save and load models, `tf.train.Checkpoint` stores the internal state of objects,
-without requiring hidden variables. To record the state of a `model`,
-an `optimizer`, and a global step, pass them to a `tf.train.Checkpoint`:
-
-```py
-import os
-
-model = MyModel()
-optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
-checkpoint_dir = '/path/to/model_dir'
-checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
-root = tf.train.Checkpoint(optimizer=optimizer,
- model=model,
- optimizer_step=tf.train.get_or_create_global_step())
-
-root.save(file_prefix=checkpoint_prefix)
-# or
-root.restore(tf.train.latest_checkpoint(checkpoint_dir))
-```
-
-### Object-oriented metrics
-
-`tfe.metrics` are stored as objects. Update a metric by passing the new data to
-the callable, and retrieve the result using the `tfe.metrics.result` method,
-for example:
-
-```py
-m = tfe.metrics.Mean("loss")
-m(0)
-m(5)
-m.result() # => 2.5
-m([8, 9])
-m.result() # => 5.5
-```
-
-#### Summaries and TensorBoard
-
-@{$summaries_and_tensorboard$TensorBoard} is a visualization tool for
-understanding, debugging and optimizing the model training process. It uses
-summary events that are written while executing the program.
-
-`tf.contrib.summary` is compatible with both eager and graph execution
-environments. Summary operations, such as `tf.contrib.summary.scalar`, are
-inserted during model construction. For example, to record summaries once every
-100 global steps:
-
-```py
-global_step = tf.train.get_or_create_global_step()
-writer = tf.contrib.summary.create_file_writer(logdir)
-writer.set_as_default()
-
-for _ in range(iterations):
- global_step.assign_add(1)
- # Must include a record_summaries method
- with tf.contrib.summary.record_summaries_every_n_global_steps(100):
- # your model code goes here
- tf.contrib.summary.scalar('loss', loss)
- ...
-```
-
-## Advanced automatic differentiation topics
-
-### Dynamic models
-
-`tf.GradientTape` can also be used in dynamic models. This example of a
-[backtracking line search](https://wikipedia.org/wiki/Backtracking_line_search)
-algorithm looks like normal NumPy code, except that it has gradients and is
-differentiable, despite the complex control flow:
-
-```py
-def line_search_step(fn, init_x, rate=1.0):
- with tf.GradientTape() as tape:
- # Variables are automatically recorded, but manually watch a tensor
- tape.watch(init_x)
- value = fn(init_x)
- grad = tape.gradient(value, init_x)
- grad_norm = tf.reduce_sum(grad * grad)
- init_value = value
- while value > init_value - rate * grad_norm:
- x = init_x - rate * grad
- value = fn(x)
- rate /= 2.0
- return x, value
-```
-
-### Additional functions to compute gradients
-
-`tf.GradientTape` is a powerful interface for computing gradients, but there
-is another [Autograd](https://github.com/HIPS/autograd)-style API available for
-automatic differentiation. These functions are useful when writing math code
-with only tensors and gradient functions, and without `tf.Variable`s:
-
-* `tfe.gradients_function` —Returns a function that computes the derivatives
- of its input function parameter with respect to its arguments. The input
- function parameter must return a scalar value. When the returned function is
- invoked, it returns a list of `tf.Tensor` objects: one element for each
- argument of the input function. Since anything of interest must be passed as a
- function parameter, this becomes unwieldy if there's a dependency on many
- trainable parameters.
-* `tfe.value_and_gradients_function` —Similar to
- `tfe.gradients_function`, but when the returned function is invoked, it
- returns the value from the input function in addition to the list of
- derivatives of the input function with respect to its arguments.
-
-In the following example, `tfe.gradients_function` takes the `square`
-function as an argument and returns a function that computes the partial
-derivatives of `square` with respect to its inputs. To calculate the derivative
-of `square` at `3`, `grad(3.0)` returns `[6.0]`.
-
-```py
-def square(x):
- return tf.multiply(x, x)
-
-grad = tfe.gradients_function(square)
-
-square(3.) # => 9.0
-grad(3.) # => [6.0]
-
-# The second-order derivative of square:
-gradgrad = tfe.gradients_function(lambda x: grad(x)[0])
-gradgrad(3.) # => [2.0]
-
-# The third-order derivative is None:
-gradgradgrad = tfe.gradients_function(lambda x: gradgrad(x)[0])
-gradgradgrad(3.) # => [None]
-
-
-# With flow control:
-def abs(x):
- return x if x > 0. else -x
-
-grad = tfe.gradients_function(abs)
-
-grad(3.) # => [1.0]
-grad(-3.) # => [-1.0]
-```
-
-### Custom gradients
-
-Custom gradients are an easy way to override gradients in eager and graph
-execution. Within the forward function, define the gradient with respect to the
-inputs, outputs, or intermediate results. For example, here's an easy way to clip
-the norm of the gradients in the backward pass:
-
-```py
-@tf.custom_gradient
-def clip_gradient_by_norm(x, norm):
- y = tf.identity(x)
- def grad_fn(dresult):
- return [tf.clip_by_norm(dresult, norm), None]
- return y, grad_fn
-```
-
-Custom gradients are commonly used to provide a numerically stable gradient for a
-sequence of operations:
-
-```py
-def log1pexp(x):
- return tf.log(1 + tf.exp(x))
-grad_log1pexp = tfe.gradients_function(log1pexp)
-
-# The gradient computation works fine at x = 0.
-grad_log1pexp(0.) # => [0.5]
-
-# However, x = 100 fails because of numerical instability.
-grad_log1pexp(100.) # => [nan]
-```
-
-Here, the `log1pexp` function can be analytically simplified with a custom
-gradient. The implementation below reuses the value for `tf.exp(x)` that is
-computed during the forward pass—making it more efficient by eliminating
-redundant calculations:
-
-```py
-@tf.custom_gradient
-def log1pexp(x):
- e = tf.exp(x)
- def grad(dy):
- return dy * (1 - 1 / (1 + e))
- return tf.log(1 + e), grad
-
-grad_log1pexp = tfe.gradients_function(log1pexp)
-
-# As before, the gradient computation works fine at x = 0.
-grad_log1pexp(0.) # => [0.5]
-
-# And the gradient computation also works at x = 100.
-grad_log1pexp(100.) # => [1.0]
-```
-
-## Performance
-
-Computation is automatically offloaded to GPUs during eager execution. If you
-want control over where a computation runs you can enclose it in a
-`tf.device('/gpu:0')` block (or the CPU equivalent):
-
-```py
-import time
-
-def measure(x, steps):
- # TensorFlow initializes a GPU the first time it's used; exclude it from timing.
- tf.matmul(x, x)
- start = time.time()
- for i in range(steps):
- x = tf.matmul(x, x)
- # tf.matmul can return before completing the matrix multiplication
- # (e.g., can return after enqueing the operation on a CUDA stream).
- # The x.numpy() call below will ensure that all enqueued operations
- # have completed (and will also copy the result to host memory,
- # so we're including a little more than just the matmul operation
- # time).
- _ = x.numpy()
- end = time.time()
- return end - start
-
-shape = (1000, 1000)
-steps = 200
-print("Time to multiply a {} matrix by itself {} times:".format(shape, steps))
-
-# Run on CPU:
-with tf.device("/cpu:0"):
- print("CPU: {} secs".format(measure(tf.random_normal(shape), steps)))
-
-# Run on GPU, if available:
-if tfe.num_gpus() > 0:
- with tf.device("/gpu:0"):
- print("GPU: {} secs".format(measure(tf.random_normal(shape), steps)))
-else:
- print("GPU: not found")
-```
-
-Output (exact numbers depend on hardware):
-
-```
-Time to multiply a (1000, 1000) matrix by itself 200 times:
-CPU: 1.46628093719 secs
-GPU: 0.0593810081482 secs
-```
-
-A `tf.Tensor` object can be copied to a different device to execute its
-operations:
-
-```py
-x = tf.random_normal([10, 10])
-
-x_gpu0 = x.gpu()
-x_cpu = x.cpu()
-
-_ = tf.matmul(x_cpu, x_cpu) # Runs on CPU
-_ = tf.matmul(x_gpu0, x_gpu0) # Runs on GPU:0
-
-if tfe.num_gpus() > 1:
- x_gpu1 = x.gpu(1)
- _ = tf.matmul(x_gpu1, x_gpu1) # Runs on GPU:1
-```
-
-### Benchmarks
-
-For compute-heavy models, such as
-[ResNet50](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/resnet50)
-training on a GPU, eager execution performance is comparable to graph execution.
-But this gap grows for models with less computation, and work remains to
-optimize hot code paths for models with many small operations.
-
-
-## Work with graphs
-
-While eager execution makes development and debugging more interactive,
-TensorFlow graph execution has advantages for distributed training, performance
-optimizations, and production deployment. However, writing graph code can feel
-different from writing regular Python code and can be more difficult to debug.
-
-For building and training graph-constructed models, the Python program first
-builds a graph representing the computation, then invokes `Session.run` to send
-the graph for execution on the C++-based runtime. This provides:
-
-* Automatic differentiation using static autodiff.
-* Simple deployment to a platform-independent server.
-* Graph-based optimizations (common subexpression elimination, constant-folding, etc.).
-* Compilation and kernel fusion.
-* Automatic distribution and replication (placing nodes on the distributed system).
-
-Deploying code written for eager execution is more difficult: either generate a
-graph from the model, or run the Python runtime and code directly on the server.
-
-### Write compatible code
-
-The same code written for eager execution will also build a graph during graph
-execution. Do this by simply running the same code in a new Python session where
-eager execution is not enabled.
-
-Most TensorFlow operations work during eager execution, but there are some things
-to keep in mind:
-
-* Use `tf.data` for input processing instead of queues. It's faster and easier.
-* Use object-oriented layer APIs—like `tf.keras.layers` and
- `tf.keras.Model`—since they have explicit storage for variables.
-* Most model code works the same during eager and graph execution, but there are
- exceptions. (For example, dynamic models using Python control flow to change the
- computation based on inputs.)
-* Once eager execution is enabled with `tf.enable_eager_execution`, it
- cannot be turned off. Start a new Python session to return to graph execution.
-
-It's best to write code for both eager execution *and* graph execution. This
-gives you eager's interactive experimentation and debuggability with the
-distributed performance benefits of graph execution.
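-
-A minimal sketch of such mode-agnostic code (illustrative, using only ops that
-behave the same way in both modes):
-
-```py
-def double_sum(x):
-  # During eager execution this returns a concrete value immediately;
-  # during graph execution it returns a symbolic tensor to run in a session.
-  return tf.reduce_sum(x) * 2
-```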
-
-Write, debug, and iterate in eager execution, then import the model graph for
-production deployment. Use `tf.train.Checkpoint` to save and restore model
-variables; this allows movement between eager and graph execution environments.
-See the examples in:
-[tensorflow/contrib/eager/python/examples](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples).
-
-### Use eager execution in a graph environment
-
-Selectively enable eager execution in a TensorFlow graph environment using
-`tfe.py_func`. This is used when `tf.enable_eager_execution()` has *not*
-been called.
-
-```py
-def my_py_func(x):
- x = tf.matmul(x, x) # You can use tf ops
- print(x) # but it's eager!
- return x
-
-with tf.Session() as sess:
- x = tf.placeholder(dtype=tf.float32)
- # Call eager function in graph!
- pf = tfe.py_func(my_py_func, [x], tf.float32)
- sess.run(pf, feed_dict={x: [[2.0]]}) # [[4.0]]
-```
diff --git a/tensorflow/docs_src/guide/embedding.md b/tensorflow/docs_src/guide/embedding.md
deleted file mode 100644
index 8a98367dfb..0000000000
--- a/tensorflow/docs_src/guide/embedding.md
+++ /dev/null
@@ -1,262 +0,0 @@
-# Embeddings
-
-This document introduces the concept of embeddings, gives a simple example of
-how to train an embedding in TensorFlow, and explains how to view embeddings
-with the TensorBoard Embedding Projector
-([live example](http://projector.tensorflow.org)). The first two parts target
-newcomers to machine learning or TensorFlow, and the Embedding Projector how-to
-is for users at all levels.
-
-An alternative tutorial on these concepts is available in the
-[Embeddings section of Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/embeddings/video-lecture).
-
-[TOC]
-
-An **embedding** is a mapping from discrete objects, such as words, to vectors
-of real numbers. For example, a 300-dimensional embedding for English words
-could include:
-
-```
-blue: (0.01359, 0.00075997, 0.24608, ..., -0.2524, 1.0048, 0.06259)
-blues: (0.01396, 0.11887, -0.48963, ..., 0.033483, -0.10007, 0.1158)
-orange: (-0.24776, -0.12359, 0.20986, ..., 0.079717, 0.23865, -0.014213)
-oranges: (-0.35609, 0.21854, 0.080944, ..., -0.35413, 0.38511, -0.070976)
-```
-
-The individual dimensions in these vectors typically have no inherent meaning.
-Instead, it's the overall patterns of location and distance between vectors
-that machine learning takes advantage of.
-
-Embeddings are important for input to machine learning. Classifiers, and neural
-networks more generally, work on vectors of real numbers. They train best on
-dense vectors, where all values contribute to define an object. However, many
-important inputs to machine learning, such as words of text, do not have a
-natural vector representation. Embedding functions are the standard and
-effective way to transform such discrete input objects into useful
-continuous vectors.
-
-Embeddings are also valuable as outputs of machine learning. Because embeddings
-map objects to vectors, applications can use similarity in vector space (for
-instance, Euclidean distance or the angle between vectors) as a robust and
-flexible measure of object similarity. One common use is to find nearest
-neighbors. Using the same word embeddings as above, for instance, here are the
-three nearest neighbors for each word and the corresponding angles:
-
-```
-blue: (red, 47.6°), (yellow, 51.9°), (purple, 52.4°)
-blues: (jazz, 53.3°), (folk, 59.1°), (bluegrass, 60.6°)
-orange: (yellow, 53.5°), (colored, 58.0°), (bright, 59.9°)
-oranges: (apples, 45.3°), (lemons, 48.3°), (mangoes, 50.4°)
-```
-
-This would tell an application that apples and oranges are in some way more
-similar (45.3° apart) than lemons and oranges (48.3° apart).
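-
-For instance, a minimal NumPy sketch (illustrative, not tied to any particular
-embedding) of computing the angle between two embedding vectors:
-
-```
-import numpy as np
-
-def angle_degrees(u, v):
-  # Angle between two embedding vectors, in degrees.
-  cos = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
-  return np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))
-```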
-
-## Embeddings in TensorFlow
-
-To create word embeddings in TensorFlow, we first split the text into words
-and then assign an integer to every word in the vocabulary. Let us assume that
-this has already been done, and that `word_ids` is a vector of these integers.
-For example, the sentence “I have a cat.” could be split into
-`["I", "have", "a", "cat", "."]` and then the corresponding `word_ids` tensor
-would have shape `[5]` and consist of 5 integers. To map these word ids
-to vectors, we need to create the embedding variable and use the
-`tf.nn.embedding_lookup` function as follows:
-
-```
-word_embeddings = tf.get_variable("word_embeddings",
- [vocabulary_size, embedding_size])
-embedded_word_ids = tf.nn.embedding_lookup(word_embeddings, word_ids)
-```
-
-After this, the tensor `embedded_word_ids` will have shape `[5, embedding_size]`
-in our example and contain the embeddings (dense vectors) for each of the 5
-words. At the end of training, `word_embeddings` will contain the embeddings
-for all words in the vocabulary.
-
-Embeddings can be trained in many network types, and with various loss
-functions and data sets. For example, one could use a recurrent neural network
-to predict the next word from the previous one given a large corpus of
-sentences, or one could train two networks to do multi-lingual translation.
-These methods are described in the @{$word2vec$Vector Representations of Words}
-tutorial.
-
-## Visualizing Embeddings
-
-TensorBoard includes the **Embedding Projector**, a tool that lets you
-interactively visualize embeddings. This tool can read embeddings from your
-model and render them in two or three dimensions.
-
-The Embedding Projector has three panels:
-
-- *Data panel* on the top left, where you can choose the run, the embedding
- variable and data columns to color and label points by.
-- *Projections panel* on the bottom left, where you can choose the type of
- projection.
-- *Inspector panel* on the right side, where you can search for particular
- points and see a list of nearest neighbors.
-
-### Projections
-The Embedding Projector provides three ways to reduce the dimensionality of a
-data set.
-
-- *[t-SNE](https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding)*:
- a nonlinear nondeterministic algorithm (t-distributed stochastic neighbor
- embedding) that tries to preserve local neighborhoods in the data, often at
- the expense of distorting global structure. You can choose whether to compute
- two- or three-dimensional projections.
-
-- *[PCA](https://en.wikipedia.org/wiki/Principal_component_analysis)*:
- a linear deterministic algorithm (principal component analysis) that tries to
- capture as much of the data variability in as few dimensions as possible. PCA
- tends to highlight large-scale structure in the data, but can distort local
- neighborhoods. The Embedding Projector computes the top 10 principal
- components, from which you can choose two or three to view.
-
-- *Custom*: a linear projection onto horizontal and vertical axes that you
- specify using labels in the data. You define the horizontal axis, for
- instance, by giving text patterns for "Left" and "Right". The Embedding
- Projector finds all points whose label matches the "Left" pattern and
- computes the centroid of that set; similarly for "Right". The line passing
- through these two centroids defines the horizontal axis. The vertical axis is
- likewise computed from the centroids for points matching the "Up" and "Down"
- text patterns.
-
-Further useful articles are
-[How to Use t-SNE Effectively](https://distill.pub/2016/misread-tsne/) and
-[Principal Component Analysis Explained Visually](http://setosa.io/ev/principal-component-analysis/).
-
-### Exploration
-
-You can explore visually by zooming, rotating, and panning using natural
-click-and-drag gestures. Hovering your mouse over a point will show any
-[metadata](#metadata) for that point. You can also inspect nearest-neighbor
-subsets. Clicking on a point causes the right pane to list the nearest
-neighbors, along with distances to the current point. The nearest-neighbor
-points are also highlighted in the projection.
-
-It is sometimes useful to restrict the view to a subset of points and perform
-projections only on those points. To do so, you can select points in multiple
-ways:
-
-- After clicking on a point, its nearest neighbors are also selected.
-- After a search, the points matching the query are selected.
-- Enabling selection, clicking on a point and dragging defines a selection
- sphere.
-
-Then click the "Isolate *nnn* points" button at the top of the Inspector pane
-on the right hand side. The following image shows 101 points selected and ready
-for the user to click "Isolate 101 points":
-
-![Selection of nearest neighbors](https://www.tensorflow.org/images/embedding-nearest-points.png "Selection of nearest neighbors")
-
-*Selection of the nearest neighbors of “important” in a word embedding dataset.*
-
-Advanced tip: filtering with custom projection can be powerful. Below, we
-filtered the 100 nearest neighbors of “politics” and projected them onto the
-“worst” - “best” vector as an x axis. The y axis is random. As a result, one
-finds on the right side “ideas”, “science”, “perspective”, “journalism” but on
-the left “crisis”, “violence” and “conflict”.
-
-<table width="100%;">
- <tr>
- <td style="width: 30%;">
- <img src="https://www.tensorflow.org/images/embedding-custom-controls.png" alt="Custom controls panel" title="Custom controls panel" />
- </td>
- <td style="width: 70%;">
- <img src="https://www.tensorflow.org/images/embedding-custom-projection.png" alt="Custom projection" title="Custom projection" />
- </td>
- </tr>
- <tr>
- <td style="width: 30%;">
- Custom projection controls.
- </td>
- <td style="width: 70%;">
- Custom projection of neighbors of "politics" onto "best" - "worst" vector.
- </td>
- </tr>
-</table>
-
-To share your findings, you can use the bookmark panel in the bottom right
-corner and save the current state (including computed coordinates of any
-projection) as a small file. The Projector can then be pointed to a set of one
-or more of these files, producing the panel below. Other users can then walk
-through a sequence of bookmarks.
-
-<img src="https://www.tensorflow.org/images/embedding-bookmark.png" alt="Bookmark panel" style="width:300px;">
-
-### Metadata
-
-If you are working with an embedding, you'll probably want to attach
-labels/images to the data points. You can do this by generating a metadata file
-containing the labels for each point and clicking "Load data" in the data panel
-of the Embedding Projector.
-
-The metadata can be either labels or images, which are
-stored in a separate file. For labels, the format should
-be a [TSV file](https://en.wikipedia.org/wiki/Tab-separated_values)
-(tab characters shown in red) whose first line contains column headers
-(shown in bold) and subsequent lines contain the metadata values. For example:
-
-<code>
-<b>Word<span style="color:#800;">\t</span>Frequency</b><br/>
- Airplane<span style="color:#800;">\t</span>345<br/>
- Car<span style="color:#800;">\t</span>241<br/>
- ...
-</code>
-
-The order of lines in the metadata file is assumed to match the order of
-vectors in the embedding variable, except for the header. Consequently, the
-(i+1)-th line in the metadata file corresponds to the i-th row of the embedding
-variable. If the TSV metadata file has only a single column, then we don’t
-expect a header row, and assume each row is the label of the embedding. We
-include this exception because it matches the commonly-used "vocab file"
-format.
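-
-A minimal sketch of producing such a file from Python (the output path and
-rows are hypothetical):
-
-```
-with open('metadata.tsv', 'w') as f:
-  f.write('Word\tFrequency\n')  # header row
-  for word, freq in [('Airplane', 345), ('Car', 241)]:
-    f.write('{}\t{}\n'.format(word, freq))
-```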
-
-To use images as metadata, you must produce a single
-[sprite image](https://www.google.com/webhp#q=what+is+a+sprite+image),
-consisting of small thumbnails, one for each vector in the embedding. The
-sprite should store thumbnails in row-first order: the first data point placed
-in the top left and the last data point in the bottom right, though the last
-row doesn't have to be filled, as shown below.
-
-<table style="border: none;">
-<tr style="background-color: transparent;">
- <td style="border: 1px solid black">0</td>
- <td style="border: 1px solid black">1</td>
- <td style="border: 1px solid black">2</td>
-</tr>
-<tr style="background-color: transparent;">
- <td style="border: 1px solid black">3</td>
- <td style="border: 1px solid black">4</td>
- <td style="border: 1px solid black">5</td>
-</tr>
-<tr style="background-color: transparent;">
- <td style="border: 1px solid black">6</td>
- <td style="border: 1px solid black">7</td>
- <td style="border: 1px solid black"></td>
-</tr>
-</table>
-
-Follow [this link](https://www.tensorflow.org/images/embedding-mnist.mp4)
-to see a fun example of thumbnail images in the Embedding Projector.
-
-
-## Mini-FAQ
-
-**Is "embedding" an action or a thing?**
-Both. People talk about embedding words in a vector space (action) and about
-producing word embeddings (things). Common to both is the notion of embedding
-as a mapping from discrete objects to vectors. Creating or applying that
-mapping is an action, but the mapping itself is a thing.
-
-**Are embeddings high-dimensional or low-dimensional?**
-It depends. A 300-dimensional vector space of words and phrases, for instance,
-is often called low-dimensional (and dense) when compared to the millions of
-words and phrases it can contain. But mathematically it is high-dimensional,
-displaying many properties that are dramatically different from what our human
-intuition has learned about 2- and 3-dimensional spaces.
-
-**Is an embedding the same as an embedding layer?**
-No. An *embedding layer* is part of a neural network, but an *embedding* is a
-more general concept.
diff --git a/tensorflow/docs_src/guide/estimators.md b/tensorflow/docs_src/guide/estimators.md
deleted file mode 100644
index 7b54e3de29..0000000000
--- a/tensorflow/docs_src/guide/estimators.md
+++ /dev/null
@@ -1,196 +0,0 @@
-# Estimators
-
-This document introduces `tf.estimator`--a high-level TensorFlow
-API that greatly simplifies machine learning programming. Estimators encapsulate
-the following actions:
-
-* training
-* evaluation
-* prediction
-* export for serving
-
-You may either use the pre-made Estimators we provide or write your
-own custom Estimators. All Estimators--whether pre-made or custom--are
-classes based on the `tf.estimator.Estimator` class.
-
-For a quick example, try the [Estimator tutorials](../tutorials/estimators/linear).
-To see each sub-topic in depth, see the [Estimator guides](premade_estimators).
-
-Note: TensorFlow also includes a deprecated `Estimator` class at
-`tf.contrib.learn.Estimator`, which you should not use.
-
-
-## Advantages of Estimators
-
-Estimators provide the following benefits:
-
-* You can run Estimator-based models on a local host or on a
- distributed multi-server environment without changing your model.
- Furthermore, you can run Estimator-based models on CPUs, GPUs,
- or TPUs without recoding your model.
-* Estimators simplify sharing implementations between model developers.
-* You can develop a state-of-the-art model with high-level intuitive code.
- In short, it is generally much easier to create models with Estimators
- than with the low-level TensorFlow APIs.
-* Estimators are themselves built on `tf.keras.layers`, which
- simplifies customization.
-* Estimators build the graph for you.
-* Estimators provide a safe distributed training loop that controls how and
- when to:
- * build the graph
- * initialize variables
- * load data
- * handle exceptions
- * create checkpoint files and recover from failures
- * save summaries for TensorBoard
-
-When writing an application with Estimators, you must separate the data input
-pipeline from the model. This separation simplifies experiments with
-different data sets.
-
-
-## Pre-made Estimators
-
-Pre-made Estimators enable you to work at a much higher conceptual level
-than the base TensorFlow APIs. You no longer have to worry about creating
-the computational graph or sessions since Estimators handle all
-the "plumbing" for you. That is, pre-made Estimators create and manage
-`tf.Graph` and `tf.Session` objects for you. Furthermore,
-pre-made Estimators let you experiment with different model architectures by
-making only minimal code changes. `tf.estimator.DNNClassifier`,
-for example, is a pre-made Estimator class that trains classification models
-based on dense, feed-forward neural networks.
-
-
-### Structure of a pre-made Estimators program
-
-A TensorFlow program relying on a pre-made Estimator typically consists
-of the following four steps:
-
-1. **Write one or more dataset importing functions.** For example, you might
- create one function to import the training set and another function to
- import the test set. Each dataset importing function must return two
- objects:
-
- * a dictionary in which the keys are feature names and the
- values are Tensors (or SparseTensors) containing the corresponding
- feature data
- * a Tensor containing one or more labels
-
- For example, the following code illustrates the basic skeleton for
- an input function:
-
- def input_fn(dataset):
- ... # manipulate dataset, extracting the feature dict and the label
- return feature_dict, label
-
- (See @{$guide/datasets} for full details.)
-
-2. **Define the feature columns.** Each `tf.feature_column`
- identifies a feature name, its type, and any input pre-processing.
- For example, the following snippet creates three feature
- columns that hold integer or floating-point data. The first two
- feature columns simply identify the feature's name and type. The
- third feature column also specifies a lambda the program will invoke
- to scale the raw data:
-
- # Define three numeric feature columns.
- population = tf.feature_column.numeric_column('population')
- crime_rate = tf.feature_column.numeric_column('crime_rate')
- median_education = tf.feature_column.numeric_column('median_education',
- normalizer_fn=lambda x: x - global_education_mean)
-
-3. **Instantiate the relevant pre-made Estimator.** For example, here's
- a sample instantiation of a pre-made Estimator named `LinearClassifier`:
-
- # Instantiate an estimator, passing the feature columns.
- estimator = tf.estimator.LinearClassifier(
- feature_columns=[population, crime_rate, median_education],
- )
-
-4. **Call a training, evaluation, or inference method.**
- For example, all Estimators provide a `train` method, which trains a model.
-
- # my_training_set is the function created in Step 1
- estimator.train(input_fn=my_training_set, steps=2000)
-
-
-### Benefits of pre-made Estimators
-
-Pre-made Estimators encode best practices, providing the following benefits:
-
-* Best practices for determining where different parts of the computational
- graph should run, implementing strategies on a single machine or on a
- cluster.
-* Best practices for event (summary) writing and universally useful
- summaries.
-
-If you don't use pre-made Estimators, you must implement the preceding
-features yourself.
-
-
-## Custom Estimators
-
-The heart of every Estimator--whether pre-made or custom--is its
-**model function**, which is a method that builds graphs for training,
-evaluation, and prediction. When you are using a pre-made Estimator,
-someone else has already implemented the model function. When relying
-on a custom Estimator, you must write the model function yourself. A
-@{$custom_estimators$companion document}
-explains how to write the model function.
-
-
-## Recommended workflow
-
-We recommend the following workflow:
-
-1. Assuming a suitable pre-made Estimator exists, use it to build your
- first model and use its results to establish a baseline.
-2. Build and test your overall pipeline, including the integrity and
- reliability of your data with this pre-made Estimator.
-3. If suitable alternative pre-made Estimators are available, run
- experiments to determine which pre-made Estimator produces the
- best results.
-4. Possibly, further improve your model by building your own custom Estimator.
-
-
-## Creating Estimators from Keras models
-
-You can convert existing Keras models to Estimators. Doing so enables your Keras
-model to access Estimator's strengths, such as distributed training. Call
-`tf.keras.estimator.model_to_estimator` as in the
-following sample:
-
-```python
-# Instantiate a Keras inception v3 model.
-keras_inception_v3 = tf.keras.applications.inception_v3.InceptionV3(weights=None)
-# Compile model with the optimizer, loss, and metrics you'd like to train with.
-keras_inception_v3.compile(optimizer=tf.keras.optimizers.SGD(lr=0.0001, momentum=0.9),
- loss='categorical_crossentropy',
- metrics=['accuracy'])
-# Create an Estimator from the compiled Keras model. Note the initial model
-# state of the keras model is preserved in the created Estimator.
-est_inception_v3 = tf.keras.estimator.model_to_estimator(keras_model=keras_inception_v3)
-
-# Treat the derived Estimator as you would with any other Estimator.
-# First, recover the input name(s) of Keras model, so we can use them as the
-# feature column name(s) of the Estimator input function:
-keras_inception_v3.input_names # print out: ['input_1']
-# Once we have the input name(s), we can create the input function, for example,
-# for input(s) in the format of numpy ndarray:
-train_input_fn = tf.estimator.inputs.numpy_input_fn(
- x={"input_1": train_data},
- y=train_labels,
- num_epochs=1,
- shuffle=False)
-# To train, we call Estimator's train function:
-est_inception_v3.train(input_fn=train_input_fn, steps=2000)
-```
-Note that the names of feature columns and labels of a Keras estimator come from
-the corresponding compiled Keras model. For example, the input key names for
-`train_input_fn` above can be obtained from `keras_inception_v3.input_names`,
-and similarly, the predicted output names can be obtained from
-`keras_inception_v3.output_names`.
-
-For more details, please refer to the documentation for
-`tf.keras.estimator.model_to_estimator`.
diff --git a/tensorflow/docs_src/guide/faq.md b/tensorflow/docs_src/guide/faq.md
deleted file mode 100644
index 8370097560..0000000000
--- a/tensorflow/docs_src/guide/faq.md
+++ /dev/null
@@ -1,296 +0,0 @@
-# Frequently Asked Questions
-
-This document provides answers to some of the frequently asked questions about
-TensorFlow. If you have a question that is not covered here, you might find an
-answer on one of the TensorFlow @{$about$community resources}.
-
-[TOC]
-
-## Features and Compatibility
-
-#### Can I run distributed training on multiple computers?
-
-Yes! TensorFlow gained
-@{$distributed$support for distributed computation} in
-version 0.8. TensorFlow now supports multiple devices (CPUs and GPUs) in one or
-more computers.
-
-#### Does TensorFlow work with Python 3?
-
-As of the 0.6.0 release (early December 2015), we support Python 3.3+.
-
-## Building a TensorFlow graph
-
-See also the
-@{$python/framework$API documentation on building graphs}.
-
-#### Why does `c = tf.matmul(a, b)` not execute the matrix multiplication immediately?
-
-In the TensorFlow Python API, `a`, `b`, and `c` are
-`tf.Tensor` objects. A `Tensor` object is
-a symbolic handle to the result of an operation, but does not actually hold the
-values of the operation's output. Instead, TensorFlow encourages users to build
-up complicated expressions (such as entire neural networks and their gradients) as
-a dataflow graph. You then offload the computation of the entire dataflow graph
-(or a subgraph of it) to a TensorFlow
-`tf.Session`, which is able to execute the
-whole computation much more efficiently than executing the operations
-one-by-one.
-
-#### How are devices named?
-
-The supported device names are `"/device:CPU:0"` (or `"/cpu:0"`) for the CPU
-device, and `"/device:GPU:i"` (or `"/gpu:i"`) for the *i*th GPU device.
-
-#### How do I place operations on a particular device?
-
-To place a group of operations on a device, create them within a
-`tf.device` context. See
-the how-to documentation on
-@{$using_gpu$using GPUs with TensorFlow} for details of how
-TensorFlow assigns operations to devices, and the
-@{$deep_cnn$CIFAR-10 tutorial} for an example model that
-uses multiple GPUs.
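-
-A minimal sketch (assuming the device exists on your machine):
-
-```python
-# Ops created inside a tf.device context are placed on that device.
-with tf.device("/device:GPU:0"):
-  a = tf.constant([[1.0, 2.0], [3.0, 4.0]])
-  b = tf.matmul(a, a)
-```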
-
-
-## Running a TensorFlow computation
-
-See also the
-@{$python/client$API documentation on running graphs}.
-
-#### What's the deal with feeding and placeholders?
-
-Feeding is a mechanism in the TensorFlow Session API that allows you to
-substitute different values for one or more tensors at run time. The `feed_dict`
-argument to `tf.Session.run` is a
-dictionary that maps `tf.Tensor` objects to
-numpy arrays (and some other types), which will be used as the values of those
-tensors in the execution of a step.
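-
-A minimal sketch of feeding a placeholder at run time:
-
-```python
-x = tf.placeholder(tf.float32, shape=[None])
-y = x * 2.0
-
-with tf.Session() as sess:
-  print(sess.run(y, feed_dict={x: [1.0, 2.0, 3.0]}))  # => [2. 4. 6.]
-```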
-
-#### What is the difference between `Session.run()` and `Tensor.eval()`?
-
-If `t` is a `tf.Tensor` object, `t.eval()` (see `tf.Tensor.eval`) is shorthand
-for `sess.run(t)` (see `tf.Session.run`), where `sess` is the current default
-session (see `tf.get_default_session`). The two following snippets of code are
-equivalent:
-
-```python
-# Using `Session.run()`.
-sess = tf.Session()
-c = tf.constant(5.0)
-print(sess.run(c))
-
-# Using `Tensor.eval()`.
-c = tf.constant(5.0)
-with tf.Session():
- print(c.eval())
-```
-
-In the second example, the session acts as a
-[context manager](https://docs.python.org/2.7/reference/compound_stmts.html#with),
-which has the effect of installing it as the default session for the lifetime of
-the `with` block. The context manager approach can lead to more concise code for
-simple use cases (like unit tests); if your code deals with multiple graphs and
-sessions, it may be more straightforward to make explicit calls to
-`Session.run()`.
-
-#### Do Sessions have a lifetime? What about intermediate tensors?
-
-Sessions can own resources, such as
-`tf.Variable`,
-`tf.QueueBase`, and
-`tf.ReaderBase`. These resources can sometimes use
-a significant amount of memory, and can be released when the session is closed by calling
-`tf.Session.close`.
-
-The intermediate tensors that are created as part of a call to
-@{$python/client$`Session.run()`} will be freed at or before the
-end of the call.
-
-#### Does the runtime parallelize parts of graph execution?
-
-The TensorFlow runtime parallelizes graph execution across many different
-dimensions:
-
-* The individual ops have parallel implementations, using multiple cores in a
- CPU, or multiple threads in a GPU.
-* Independent nodes in a TensorFlow graph can run in parallel on multiple
- devices, which makes it possible to speed up
- @{$deep_cnn$CIFAR-10 training using multiple GPUs}.
-* The Session API allows multiple concurrent steps (i.e. calls to
- `tf.Session.run` in parallel). This
- enables the runtime to get higher throughput, if a single step does not use
- all of the resources in your computer.
-
-#### Which client languages are supported in TensorFlow?
-
-TensorFlow is designed to support multiple client languages.
-Currently, the best-supported client language is [Python](../api_docs/python/index.md). Experimental interfaces for
-executing and constructing graphs are also available for
-[C++](../api_docs/cc/index.md), [Java](../api_docs/java/reference/org/tensorflow/package-summary.html) and [Go](https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go).
-
-TensorFlow also has a
-[C-based client API](https://www.tensorflow.org/code/tensorflow/c/c_api.h)
-to help build support for more client languages. We invite contributions of new
-language bindings.
-
-Bindings for various other languages (such as [C#](https://github.com/migueldeicaza/TensorFlowSharp), [Julia](https://github.com/malmaud/TensorFlow.jl), [Ruby](https://github.com/somaticio/tensorflow.rb) and [Scala](https://github.com/eaplatanios/tensorflow_scala)) created and supported by the open source community build on top of the C API supported by the TensorFlow maintainers.
-
-#### Does TensorFlow make use of all the devices (GPUs and CPUs) available on my machine?
-
-TensorFlow supports multiple GPUs and CPUs. See the how-to documentation on
-@{$using_gpu$using GPUs with TensorFlow} for details of how
-TensorFlow assigns operations to devices, and the
-@{$deep_cnn$CIFAR-10 tutorial} for an example model that
-uses multiple GPUs.
-
-Note that TensorFlow only uses GPU devices with a compute capability greater
-than 3.5.
-
-#### Why does `Session.run()` hang when using a reader or a queue?
-
-The `tf.ReaderBase` and
-`tf.QueueBase` classes provide special operations that
-can *block* until input (or free space in a bounded queue) becomes
-available. These operations allow you to build sophisticated
-@{$reading_data$input pipelines}, at the cost of making the
-TensorFlow computation somewhat more complicated. See the how-to documentation
-for
-@{$reading_data#creating_threads_to_prefetch_using_queuerunner_objects$using `QueueRunner` objects to drive queues and readers}
-for more information on how to use them.
-
-## Variables
-
-See also the how-to documentation on @{$variables$variables} and
-@{$python/state_ops$the API documentation for variables}.
-
-#### What is the lifetime of a variable?
-
-A variable is created when you first run the
-`tf.Variable.initializer`
-operation for that variable in a session. It is destroyed when the session is
-closed with `tf.Session.close`.
-
-#### How do variables behave when they are concurrently accessed?
-
-Variables allow concurrent read and write operations. The value read from a
-variable may change if it is concurrently updated. By default, concurrent
-assignment operations to a variable are allowed to run with no mutual exclusion.
-To acquire a lock when assigning to a variable, pass `use_locking=True` to
-`tf.Variable.assign`.
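-
-A minimal sketch:
-
-```python
-v = tf.Variable(0.0)
-# Acquire a lock for the duration of the assignment.
-update = v.assign(1.0, use_locking=True)
-```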
-
-## Tensor shapes
-
-See also
-`tf.TensorShape`.
-
-#### How can I determine the shape of a tensor in Python?
-
-In TensorFlow, a tensor has both a static (inferred) shape and a dynamic (true)
-shape. The static shape can be read using the
-`tf.Tensor.get_shape`
-method: this shape is inferred from the operations that were used to create the
-tensor, and may be only partially defined (the static shape may contain `None`). If
-the static shape is not fully defined, the dynamic shape of a `tf.Tensor`, `t`
-can be determined using `tf.shape(t)`.
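-
-A minimal sketch of the distinction:
-
-```python
-x = tf.placeholder(tf.float32, shape=[None, 10])
-print(x.get_shape())         # => (?, 10) -- static shape, partially defined
-dynamic_shape = tf.shape(x)  # a Tensor; its value is known only at run time
-```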
-
-#### What is the difference between `x.set_shape()` and `x = tf.reshape(x)`?
-
-The `tf.Tensor.set_shape` method updates
-the static shape of a `Tensor` object, and it is typically used to provide
-additional shape information when this cannot be inferred directly. It does not
-change the dynamic shape of the tensor.
-
-The `tf.reshape` operation creates
-a new tensor with a different dynamic shape.
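-
-A minimal sketch of the difference:
-
-```python
-x = tf.placeholder(tf.float32)
-x.set_shape([None, 3])      # refines the static shape; creates no new tensor
-y = tf.reshape(x, [-1, 1])  # a new tensor with a different dynamic shape
-```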
-
-#### How do I build a graph that works with variable batch sizes?
-
-It is often useful to build a graph that works with variable batch sizes
-so that the same code can be used for (mini-)batch training and
-single-instance inference. The resulting graph can be serialized with
-`tf.Graph.as_graph_def`
-and loaded with
-`tf.import_graph_def`.
-
-When building a variable-size graph, the most important thing to remember is not
-to encode the batch size as a Python constant, but instead to use a symbolic
-`Tensor` to represent it. The following tips may be useful (a short sketch
-combining them follows the list):
-
-* Use [`batch_size = tf.shape(input)[0]`](../api_docs/python/array_ops.md#shape)
- to extract the batch dimension from a `Tensor` called `input`, and store it in
- a `Tensor` called `batch_size`.
-
-* Use `tf.reduce_mean` instead
- of `tf.reduce_sum(...) / batch_size`.
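-
-A minimal sketch combining these tips:
-
-```python
-input = tf.placeholder(tf.float32, shape=[None, 10])
-batch_size = tf.shape(input)[0]       # a symbolic batch size
-mean = tf.reduce_mean(input, axis=0)  # no division by a Python constant
-```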
-
-
-## TensorBoard
-
-#### How can I visualize a TensorFlow graph?
-
-See the @{$graph_viz$graph visualization tutorial}.
-
-#### What is the simplest way to send data to TensorBoard?
-
-Add summary ops to your TensorFlow graph, and write
-these summaries to a log directory. Then, start TensorBoard using
-
- python tensorflow/tensorboard/tensorboard.py --logdir=path/to/log-directory
-
-For more details, see the
-@{$summaries_and_tensorboard$Summaries and TensorBoard tutorial}.
-
-#### Every time I launch TensorBoard, I get a network security popup!
-
-You can change TensorBoard to serve on localhost rather than '0.0.0.0' by
-passing the flag `--host=localhost`. This should quiet any security warnings.
-
-## Extending TensorFlow
-
-See the how-to documentation for
-@{$adding_an_op$adding a new operation to TensorFlow}.
-
-#### My data is in a custom format. How do I read it using TensorFlow?
-
-There are three main options for dealing with data in a custom format.
-
-The easiest option is to write parsing code in Python that transforms the data
-into a numpy array. Then, use `tf.data.Dataset.from_tensor_slices` to
-create an input pipeline from the in-memory data.
-
-If your data doesn't fit in memory, try doing the parsing in the Dataset
-pipeline. Start with an appropriate file reader, like
-`tf.data.TextLineDataset`. Then transform the dataset by using
-`tf.data.Dataset.map` to apply parsing operations to each element.
-Prefer predefined TensorFlow operations such as `tf.decode_raw`,
-`tf.decode_csv`, `tf.parse_example`, or `tf.image.decode_png`.
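-
-A minimal sketch (assuming a file `data.csv` with two floating-point columns)
-of parsing inside the pipeline:
-
-```python
-import tensorflow as tf
-
-# Each line of the file becomes one string element of the dataset.
-dataset = tf.data.TextLineDataset("data.csv")
-
-def parse_line(line):
-  # record_defaults supplies the column types and default values.
-  feature, label = tf.decode_csv(line, record_defaults=[[0.0], [0.0]])
-  return {"feature": feature}, label
-
-# Map the parsing function over every element of the dataset.
-dataset = dataset.map(parse_line)
-```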
-
-If your data is not easily parsable with the built-in TensorFlow operations,
-consider converting it, offline, to a format that is easily parsable, such
-as the `TFRecord` format written by `tf.python_io.TFRecordWriter`.
-
-The most efficient method to customize the parsing behavior is to
-@{$adding_an_op$add a new op written in C++} that parses your
-data format. The @{$new_data_formats$guide to handling new data formats} has
-more information about the steps for doing this.
-
-
-## Miscellaneous
-
-#### What is TensorFlow's coding style convention?
-
-The TensorFlow Python API adheres to the
-[PEP8](https://www.python.org/dev/peps/pep-0008/) conventions.<sup>*</sup> In
-particular, we use `CamelCase` names for classes, and `snake_case` names for
-functions, methods, and properties. We also adhere to the
-[Google Python style guide](https://google.github.io/styleguide/pyguide.html).
-
-The TensorFlow C++ code base adheres to the
-[Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
-
-(<sup>*</sup> With one exception: we use 2-space indentation instead of 4-space
-indentation.)
-
diff --git a/tensorflow/docs_src/guide/feature_columns.md b/tensorflow/docs_src/guide/feature_columns.md
deleted file mode 100644
index b189c4334e..0000000000
--- a/tensorflow/docs_src/guide/feature_columns.md
+++ /dev/null
@@ -1,572 +0,0 @@
-# Feature Columns
-
-This document details feature columns. Think of **feature columns** as the
-intermediaries between raw data and Estimators. Feature columns are very rich,
-enabling you to transform a diverse range of raw data into formats that
-Estimators can use, allowing easy experimentation.
-
-In @{$premade_estimators$Premade Estimators}, we used the premade
-Estimator, `tf.estimator.DNNClassifier` to train a model to
-predict different types of Iris flowers from four input features. That example
-created only numerical feature columns (of type
-`tf.feature_column.numeric_column`). Although numerical feature columns model
-the lengths of petals and sepals effectively, real world data sets contain all
-kinds of features, many of which are non-numerical.
-
-<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="../images/feature_columns/feature_cloud.jpg">
-</div>
-<div style="text-align: center">
-Some real-world features (such as longitude) are numerical, but many are not.
-</div>
-
-## Input to a Deep Neural Network
-
-What kind of data can a deep neural network operate on? The answer
-is, of course, numbers (for example, `tf.float32`). After all, every neuron in
-a neural network performs multiplication and addition operations on weights and
-input data. Real-life input data, however, often contains non-numerical
-(categorical) data. For example, consider a `product_class` feature that can
-contain the following three non-numerical values:
-
-* `kitchenware`
-* `electronics`
-* `sports`
-
-ML models generally represent categorical values as simple vectors in which a
-1 represents the presence of a value and a 0 represents the absence of a value.
-For example, when `product_class` is set to `sports`, an ML model would usually
-represent `product_class` as `[0, 0, 1]`, meaning:
-
-* `0`: `kitchenware` is absent
-* `0`: `electronics` is absent
-* `1`: `sports` is present
-
-So, although raw data can be numerical or categorical, an ML model represents
-all features as numbers.
-
-## Feature Columns
-
-As the following figure suggests, you specify the input to a model through the
-`feature_columns` argument of an Estimator (`DNNClassifier` for Iris).
-Feature Columns bridge input data (as returned by `input_fn`) with your model.
-
-<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="../images/feature_columns/inputs_to_model_bridge.jpg">
-</div>
-<div style="text-align: center">
-Feature columns bridge raw data with the data your model needs.
-</div>
-
-To create feature columns, call functions from the
-`tf.feature_column` module. This document explains nine of the functions in
-that module. As the following figure shows, all nine functions return either a
-Categorical-Column or a Dense-Column object, except `bucketized_column`, which
-inherits from both classes:
-
-<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="../images/feature_columns/some_constructors.jpg">
-</div>
-<div style="text-align: center">
-Feature column methods fall into two main categories and one hybrid category.
-</div>
-
-Let's look at these functions in more detail.
-
-### Numeric column
-
-The Iris classifier calls the `tf.feature_column.numeric_column` function for
-all input features:
-
- * `SepalLength`
- * `SepalWidth`
- * `PetalLength`
- * `PetalWidth`
-
-Although `tf.feature_column.numeric_column` provides optional arguments,
-calling it with only the required `key` argument, as follows, is a fine way to
-specify a numerical value with the default data type (`tf.float32`) as input
-to your model:
-
-```python
-# Defaults to a tf.float32 scalar.
-numeric_feature_column = tf.feature_column.numeric_column(key="SepalLength")
-```
-
-To specify a non-default numerical data type, use the `dtype` argument. For
-example:
-
-``` python
-# Represent a tf.float64 scalar.
-numeric_feature_column = tf.feature_column.numeric_column(key="SepalLength",
- dtype=tf.float64)
-```
-
-By default, a numeric column creates a single value (scalar). Use the shape
-argument to specify another shape. For example:
-
-<!--TODO(markdaoust) link to full example-->
-```python
-# Represent a 10-element vector in which each cell contains a tf.float32.
-vector_feature_column = tf.feature_column.numeric_column(key="Bowling",
- shape=10)
-
-# Represent a 10x5 matrix in which each cell contains a tf.float32.
-matrix_feature_column = tf.feature_column.numeric_column(key="MyMatrix",
- shape=[10,5])
-```
-### Bucketized column
-
-Often, you don't want to feed a number directly into the model, but instead
-split its value into different categories based on numerical ranges. To do so,
-create a `tf.feature_column.bucketized_column`. For
-example, consider raw data that represents the year a house was built. Instead
-of representing that year as a scalar numeric column, we could split the year
-into the following four buckets:
-
-<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="../images/feature_columns/bucketized_column.jpg">
-</div>
-<div style="text-align: center">
-Dividing year data into four buckets.
-</div>
-
-The model will represent the buckets as follows:
-
-|Date Range |Represented as... |
-|:----------|:-----------------|
-|< 1960 | [1, 0, 0, 0] |
-|>= 1960 but < 1980 | [0, 1, 0, 0] |
-|>= 1980 but < 2000 | [0, 0, 1, 0] |
-|>= 2000 | [0, 0, 0, 1] |
-
-Why would you want to split a number—a perfectly valid input to your
-model—into a categorical value? Well, notice that the categorization splits a
-single input number into a four-element vector. Therefore, the model now can
-learn _four individual weights_ rather than just one; four weights create a
-richer model than one weight. More importantly, bucketizing enables the model
-to clearly distinguish between different year categories since only one of the
-elements is set (1) and the other three elements are cleared (0). For example,
-when we just use a single number (a year) as input, a linear model can only
-learn a linear relationship. So, bucketing provides the model with additional
-flexibility that the model can use to learn.
-
-The following code demonstrates how to create a bucketized feature:
-
-<!--TODO(markdaoust) link to full example - housing price grid?-->
-```python
-# First, convert the raw input to a numeric column.
-numeric_feature_column = tf.feature_column.numeric_column("Year")
-
-# Then, bucketize the numeric column on the years 1960, 1980, and 2000.
-bucketized_feature_column = tf.feature_column.bucketized_column(
- source_column = numeric_feature_column,
- boundaries = [1960, 1980, 2000])
-```
-Note that specifying a _three_-element boundaries vector creates a
-_four_-element bucketized vector.
-
-
-### Categorical identity column
-
-**Categorical identity columns** can be seen as a special case of bucketized
-columns. In traditional bucketized columns, each bucket represents a range of
-values (for example, from 1960 to 1979). In a categorical identity column, each
-bucket represents a single, unique integer. For example, let's say you want to
-represent the integer range `[0, 4)`. That is, you want to represent the
-integers 0, 1, 2, or 3. In this case, the categorical identity mapping looks
-like this:
-
-<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="../images/feature_columns/categorical_column_with_identity.jpg">
-</div>
-<div style="text-align: center">
-A categorical identity column mapping. Note that this is a one-hot
-encoding, not a binary numerical encoding.
-</div>
-
-As with bucketized columns, a model can learn a separate weight for each class
-in a categorical identity column. For example, instead of using a string to
-represent the `product_class`, let's represent each class with a unique integer
-value. That is:
-
-* `0="kitchenware"`
-* `1="electronics"`
-* `2="sports"`
-
-Call `tf.feature_column.categorical_column_with_identity` to implement a
-categorical identity column. For example:
-
-``` python
-# Create categorical output for an integer feature named "my_feature_b".
-# The values of my_feature_b must be >= 0 and < num_buckets.
-identity_feature_column = tf.feature_column.categorical_column_with_identity(
- key='my_feature_b',
- num_buckets=4) # Values [0, 4)
-
-# In order for the preceding call to work, the input_fn() must return
-# a dictionary containing 'my_feature_b' as a key. Furthermore, the values
-# assigned to 'my_feature_b' must belong to the set [0, 4).
-def input_fn():
- ...
- return ({ 'my_feature_a':[7, 9, 5, 2], 'my_feature_b':[3, 1, 2, 2] },
- [Label_values])
-```
-
-### Categorical vocabulary column
-
-We cannot input strings directly to a model. Instead, we must first map strings
-to numeric or categorical values. Categorical vocabulary columns provide a good
-way to represent strings as a one-hot vector. For example:
-
-<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="../images/feature_columns/categorical_column_with_vocabulary.jpg">
-</div>
-<div style="text-align: center">
-Mapping string values to vocabulary columns.
-</div>
-
-As you can see, categorical vocabulary columns are kind of an enum version of
-categorical identity columns. TensorFlow provides two different functions to
-create categorical vocabulary columns:
-
-* `tf.feature_column.categorical_column_with_vocabulary_list`
-* `tf.feature_column.categorical_column_with_vocabulary_file`
-
-`categorical_column_with_vocabulary_list` maps each string to an integer based
-on an explicit vocabulary list. For example:
-
-```python
-# Given input "feature_name_from_input_fn" which is a string,
-# create a categorical feature by mapping the input to one of
-# the elements in the vocabulary list.
-vocabulary_feature_column = (
-    tf.feature_column.categorical_column_with_vocabulary_list(
-        key=feature_name_from_input_fn,
-        vocabulary_list=["kitchenware", "electronics", "sports"]))
-```
-
-The preceding function is pretty straightforward, but it has a significant
-drawback. Namely, there's way too much typing when the vocabulary list is long.
-For these cases, call
-`tf.feature_column.categorical_column_with_vocabulary_file` instead, which lets
-you place the vocabulary words in a separate file. For example:
-
-```python
-# Given input "feature_name_from_input_fn" which is a string,
-# create a categorical feature for our model by mapping the input to one of
-# the elements in the vocabulary file.
-vocabulary_feature_column = (
-    tf.feature_column.categorical_column_with_vocabulary_file(
-        key=feature_name_from_input_fn,
-        vocabulary_file="product_class.txt",
-        vocabulary_size=3))
-```
-
-`product_class.txt` should contain one line for each vocabulary element. In our
-case:
-
-```None
-kitchenware
-electronics
-sports
-```
-
-### Hashed Column
-
-So far, we've worked with a conveniently small number of categories. For example,
-our `product_class` example has only 3 categories. Often, though, the number of
-categories can be so big that it's not possible to have individual categories
-for each vocabulary word or integer because that would consume too much memory.
-For these cases, we can instead turn the question around and ask, "How many
-categories am I willing to have for my input?" In fact, the
-`tf.feature_column.categorical_column_with_hash_bucket` function enables you
-to specify the number of categories. For this type of feature column the model
-calculates a hash value of the input, then puts it into one of
-the `hash_bucket_size` categories using the modulo operator, as in the following
-pseudocode:
-
-```python
-# pseudocode
-feature_id = hash(raw_feature) % hash_bucket_size
-```
-
-The code to create the `feature_column` might look something like this:
-
-``` python
-hashed_feature_column = (
-    tf.feature_column.categorical_column_with_hash_bucket(
-        key="some_feature",
-        hash_bucket_size=100))  # The number of categories
-```
-At this point, you might rightfully think: "This is crazy!" After all, we are
-forcing the different input values into a smaller set of categories. This means
-that two probably unrelated inputs may be mapped to the same
-category, and consequently mean the same thing to the neural network. The
-following figure illustrates this dilemma, showing that kitchenware and sports
-both get assigned to category (hash bucket) 12:
-
-<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="../images/feature_columns/hashed_column.jpg">
-</div>
-<div style="text-align: center">
-Representing data with hash buckets.
-</div>
-
-As with many counterintuitive phenomena in machine learning, it turns out that
-hashing often works well in practice. That's because hash categories provide
-the model with some separation. The model can use additional features to further
-separate kitchenware from sports.
-
-### Crossed column
-
-Combining features into a single feature, better known as
-[feature crosses](https://developers.google.com/machine-learning/glossary/#feature_cross),
-enables the model to learn separate weights for each combination of
-features.
-
-More concretely, suppose we want our model to calculate real estate prices in
-Atlanta, GA. Real-estate prices within this city vary greatly depending on
-location. Representing latitude and longitude as separate features isn't very
-useful in identifying real-estate location dependencies; however, crossing
-latitude and longitude into a single feature can pinpoint locations. Suppose we
-represent Atlanta as a grid of 100x100 rectangular sections, identifying each
-of the 10,000 sections by a feature cross of latitude and longitude. This
-feature cross enables the model to train on pricing conditions related to each
-individual section, which is a much stronger signal than latitude and longitude
-alone.
-
-The following figure shows our plan, with the latitude & longitude values for
-the corners of the city in red text:
-
-<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="../images/feature_columns/Atlanta.jpg">
-</div>
-<div style="text-align: center">
-Map of Atlanta. Imagine this map divided into 10,000 sections of
-equal size.
-</div>
-
-For the solution, we used a combination of the `bucketized_column` we looked at
-earlier, with the `tf.feature_column.crossed_column` function.
-
-<!--TODO(markdaoust) link to full example-->
-
-``` python
-def make_dataset(latitude, longitude, labels):
- assert latitude.shape == longitude.shape == labels.shape
-
- features = {'latitude': latitude.flatten(),
- 'longitude': longitude.flatten()}
-  labels = labels.flatten()
-
- return tf.data.Dataset.from_tensor_slices((features, labels))
-
-
-# Bucketize the latitude and longitude using the `edges`
-latitude_bucket_fc = tf.feature_column.bucketized_column(
- tf.feature_column.numeric_column('latitude'),
- list(atlanta.latitude.edges))
-
-longitude_bucket_fc = tf.feature_column.bucketized_column(
- tf.feature_column.numeric_column('longitude'),
- list(atlanta.longitude.edges))
-
-# Cross the bucketized columns, using 5000 hash bins.
-crossed_lat_lon_fc = tf.feature_column.crossed_column(
- [latitude_bucket_fc, longitude_bucket_fc], 5000)
-
-fc = [
- latitude_bucket_fc,
- longitude_bucket_fc,
- crossed_lat_lon_fc]
-
-# Build and train the Estimator.
-est = tf.estimator.LinearRegressor(fc, ...)
-```
-
-You may create a feature cross from either of the following:
-
-* Feature names; that is, names from the `dict` returned from `input_fn`.
-* Any categorical column, except `categorical_column_with_hash_bucket`
- (since `crossed_column` hashes the input).
-
-When the feature columns `latitude_bucket_fc` and `longitude_bucket_fc` are
-crossed, TensorFlow will create `(latitude_fc, longitude_fc)` pairs for each
-example. This would produce a full grid of possibilities as follows:
-
-``` None
- (0,0), (0,1)... (0,99)
- (1,0), (1,1)... (1,99)
- ... ... ...
-(99,0), (99,1)...(99, 99)
-```
-
-However, a full grid would only be tractable for inputs with limited
-vocabularies. Instead of building this potentially huge table of inputs,
-`crossed_column` only builds the number of categories requested by the
-`hash_bucket_size` argument. The feature column assigns an example to an
-index by running a hash function on the tuple of inputs, followed by a
-modulo operation with `hash_bucket_size`.
-
-As discussed earlier, performing the
-hash and modulo function limits the number of categories, but can cause category
-collisions; that is, multiple (latitude, longitude) feature crosses will end
-up in the same hash bucket. In practice though, performing feature crosses
-still adds significant value to the learning capability of your models.
-
-Somewhat counterintuitively, when creating feature crosses, you typically still
-should include the original (uncrossed) features in your model (as in the
-preceding code snippet). The independent latitude and longitude features help the
-model distinguish between examples where a hash collision has occurred in the
-crossed feature.
-
-## Indicator and embedding columns
-
-Indicator columns and embedding columns never work on features directly, but
-instead take categorical columns as input.
-
-When using an indicator column, we're telling TensorFlow to do exactly what
-we've seen in our categorical product_class example. That is, an
-**indicator column** treats each category as an element in a one-hot vector,
-where the matching category has value 1 and the rest have 0s:
-
-<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="../images/feature_columns/categorical_column_with_identity.jpg">
-</div>
-<div style="text-align: center">
-Representing data in indicator columns.
-</div>
-
-Here's how you create an indicator column by calling
-`tf.feature_column.indicator_column`:
-
-``` python
-categorical_column = ... # Create any type of categorical column.
-
-# Represent the categorical column as an indicator column.
-indicator_column = tf.feature_column.indicator_column(categorical_column)
-```
-
-Now, suppose instead of having just three possible classes, we have a million.
-Or maybe a billion. For a number of reasons, as the number of categories grows
-large, it becomes infeasible to train a neural network using indicator columns.
-
-We can use an embedding column to overcome this limitation. Instead of
-representing the data as a one-hot vector of many dimensions, an
-**embedding column** represents that data as a lower-dimensional, ordinary
-vector in which each cell can contain any number, not just 0 or 1. By
-permitting a richer palette of numbers for every cell, an embedding column
-contains far fewer cells than an indicator column.
-
-Let's look at an example comparing indicator and embedding columns. Suppose our
-input examples consist of different words from a limited palette of only 81
-words. Further suppose that the data set provides the following input
-words in 4 separate examples:
-
-* `"dog"`
-* `"spoon"`
-* `"scissors"`
-* `"guitar"`
-
-In that case, the following figure illustrates the processing path for
-embedding columns or indicator columns.
-
-<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="../images/feature_columns/embedding_vs_indicator.jpg">
-</div>
-<div style="text-align: center">
-An embedding column stores categorical data in a lower-dimensional
-vector than an indicator column. (We just placed random numbers into the
-embedding vectors; training determines the actual numbers.)
-</div>
-
-When an example is processed, one of the `categorical_column_with...` functions
-maps the example string to a numerical categorical value. For example, a
-function maps "spoon" to `[32]`. (The 32 comes from our imagination—the actual
-values depend on the mapping function.) You may then represent these numerical
-categorical values in either of the following two ways:
-
-* As an indicator column. A function converts each numeric categorical value
- into an 81-element vector (because our palette consists of 81 words), placing
- a 1 in the index of the categorical value (0, 32, 79, 80) and a 0 in all the
- other positions.
-
-* As an embedding column. A function uses the numerical categorical values
- `(0, 32, 79, 80)` as indices to a lookup table. Each slot in that lookup table
- contains a 3-element vector.
-
-How do the values in the embedding vectors magically get assigned? Actually,
-the assignments happen during training. That is, the model learns the best way
-to map your input numeric categorical values to the embedding vector values in
-order to solve your problem. Embedding columns increase your model's
-capabilities, since the embedding vectors learn new relationships between
-categories from the training data.
-
-Why is the embedding vector size 3 in our example? Well, the following "formula"
-provides a general rule of thumb about the number of embedding dimensions:
-
-```python
-embedding_dimensions = number_of_categories**0.25
-```
-
-That is, the embedding vector dimension should be the 4th root of the number of
-categories. Since our vocabulary size in this example is 81, the recommended
-number of dimensions is 3:
-
-``` python
-embedding_dimensions = 81**0.25  # == 3
-```
-Note that this is just a general guideline; you can set the number of embedding
-dimensions as you please.
-
-Call `tf.feature_column.embedding_column` to create an `embedding_column` as
-suggested by the following snippet:
-
-``` python
-categorical_column = ... # Create any categorical column
-
-# Represent the categorical column as an embedding column.
-# This means creating an embedding vector lookup table with one element for each category.
-embedding_column = tf.feature_column.embedding_column(
- categorical_column=categorical_column,
- dimension=embedding_dimensions)
-```
-
-@{$guide/embedding$Embeddings} are a significant topic within machine
-learning. This section was just to get you started using them as feature
-columns.
-
-## Passing feature columns to Estimators
-
-As the following list indicates, not all Estimators accept all types of
-feature column in their `feature_columns` argument(s):
-
-* `tf.estimator.LinearClassifier` and
- `tf.estimator.LinearRegressor`: Accept all types of
- feature column.
-* `tf.estimator.DNNClassifier` and
- `tf.estimator.DNNRegressor`: Only accept dense columns. Other
- column types must be wrapped in either an `indicator_column` or
- `embedding_column`.
-* `tf.estimator.DNNLinearCombinedClassifier` and
- `tf.estimator.DNNLinearCombinedRegressor`:
- * The `linear_feature_columns` argument accepts any feature column type.
- * The `dnn_feature_columns` argument only accepts dense columns.
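-
-For example, a minimal sketch (the column and hyperparameters here are
-illustrative) of wrapping a categorical column so that a `DNNClassifier`
-accepts it:
-
-``` python
-categorical_column = tf.feature_column.categorical_column_with_vocabulary_list(
-    key="product_class",
-    vocabulary_list=["kitchenware", "electronics", "sports"])
-
-# DNNClassifier only accepts dense columns, so wrap the categorical
-# column in an indicator_column (an embedding_column also works).
-dense_column = tf.feature_column.indicator_column(categorical_column)
-
-classifier = tf.estimator.DNNClassifier(
-    feature_columns=[dense_column],
-    hidden_units=[10, 10],
-    n_classes=3)
-```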
-
-## Other Sources
-
-For more examples on feature columns, view the following:
-
-* The @{$low_level_intro#feature_columns$Low Level Introduction} demonstrates how
-  to experiment directly with `feature_columns` using TensorFlow's low level APIs.
-* The [Estimator wide and deep learning tutorial](https://github.com/tensorflow/models/tree/master/official/wide_deep)
- solves a binary classification problem using `feature_columns` on a variety of
- input data types.
-
-To learn more about embeddings, see the following:
-
-* [Deep Learning, NLP, and representations](http://colah.github.io/posts/2014-07-NLP-RNNs-Representations/)
- (Chris Olah's blog)
-* The TensorFlow [Embedding Projector](http://projector.tensorflow.org)
diff --git a/tensorflow/docs_src/guide/graph_viz.md b/tensorflow/docs_src/guide/graph_viz.md
deleted file mode 100644
index 97b0e2d4de..0000000000
--- a/tensorflow/docs_src/guide/graph_viz.md
+++ /dev/null
@@ -1,317 +0,0 @@
-# TensorBoard: Graph Visualization
-
-TensorFlow computation graphs are powerful but complicated. The graph visualization can help you understand and debug them. Here's an example of the visualization at work.
-
-![Visualization of a TensorFlow graph](https://www.tensorflow.org/images/graph_vis_animation.gif "Visualization of a TensorFlow graph")
-*Visualization of a TensorFlow graph.*
-
-To see your own graph, run TensorBoard, pointing it to the log directory of the job, click on the graph tab in the top pane, and select the appropriate run using the menu at the upper left corner. For in-depth information on how to run TensorBoard and make sure you are logging all the necessary information, see @{$summaries_and_tensorboard$TensorBoard: Visualizing Learning}.
-
-## Name scoping and nodes
-
-Typical TensorFlow graphs can have many thousands of nodes--far too many to see
-easily all at once, or even to lay out using standard graph tools. To simplify,
-variable names can be scoped and the visualization uses this information to
-define a hierarchy on the nodes in the graph. By default, only the top of this
-hierarchy is shown. Here is an example that defines three operations under the
-`hidden` name scope using
-`tf.name_scope`:
-
-```python
-import tensorflow as tf
-
-with tf.name_scope('hidden') as scope:
- a = tf.constant(5, name='alpha')
- W = tf.Variable(tf.random_uniform([1, 2], -1.0, 1.0), name='weights')
- b = tf.Variable(tf.zeros([1]), name='biases')
-```
-
-This results in the following three op names:
-
-* `hidden/alpha`
-* `hidden/weights`
-* `hidden/biases`
-
-By default, the visualization will collapse all three into a node labeled `hidden`.
-The extra detail isn't lost. You can double-click, or click
-on the orange `+` sign in the top right to expand the node, and then you'll see
-three subnodes for `alpha`, `weights` and `biases`.
-
-Here's a real-life example of a more complicated node in its initial and
-expanded states.
-
-<table width="100%;">
- <tr>
- <td style="width: 50%;">
- <img src="https://www.tensorflow.org/images/pool1_collapsed.png" alt="Unexpanded name scope" title="Unexpanded name scope" />
- </td>
- <td style="width: 50%;">
- <img src="https://www.tensorflow.org/images/pool1_expanded.png" alt="Expanded name scope" title="Expanded name scope" />
- </td>
- </tr>
- <tr>
- <td style="width: 50%;">
- Initial view of top-level name scope <code>pool_1</code>. Clicking on the orange <code>+</code> button on the top right or double-clicking on the node itself will expand it.
- </td>
- <td style="width: 50%;">
- Expanded view of <code>pool_1</code> name scope. Clicking on the orange <code>-</code> button on the top right or double-clicking on the node itself will collapse the name scope.
- </td>
- </tr>
-</table>
-
-Grouping nodes by name scopes is critical to making a legible graph. If you're
-building a model, name scopes give you control over the resulting visualization.
-**The better your name scopes, the better your visualization.**
-
-The figure above illustrates a second aspect of the visualization. TensorFlow
-graphs have two kinds of connections: data dependencies and control
-dependencies. Data dependencies show the flow of tensors between two ops and
-are shown as solid arrows, while control dependencies use dotted lines. In the
-expanded view (right side of the figure above) all the connections are data
-dependencies with the exception of the dotted line connecting `CheckNumerics`
-and `control_dependency`.
-
-There's a second trick to simplifying the layout. Most TensorFlow graphs have a
-few nodes with many connections to other nodes. For example, many nodes might
-have a control dependency on an initialization step. Drawing all edges between
-the `init` node and its dependencies would create a very cluttered view.
-
-To reduce clutter, the visualization separates out all high-degree nodes to an
-*auxiliary* area on the right and doesn't draw lines to represent their edges.
-Instead of lines, we draw small *node icons* to indicate the connections.
-Separating out the auxiliary nodes typically doesn't remove critical
-information since these nodes are usually related to bookkeeping functions.
-See [Interaction](#interaction) for how to move nodes between the main graph
-and the auxiliary area.
-
-<table width="100%;">
- <tr>
- <td style="width: 50%;">
- <img src="https://www.tensorflow.org/images/conv_1.png" alt="conv_1 is part of the main graph" title="conv_1 is part of the main graph" />
- </td>
- <td style="width: 50%;">
- <img src="https://www.tensorflow.org/images/save.png" alt="save is extracted as auxiliary node" title="save is extracted as auxiliary node" />
- </td>
- </tr>
- <tr>
- <td style="width: 50%;">
- Node <code>conv_1</code> is connected to <code>save</code>. Note the little <code>save</code> node icon on its right.
- </td>
- <td style="width: 50%;">
- <code>save</code> has a high degree, and will appear as an auxiliary node. The connection with <code>conv_1</code> is shown as a node icon on its left. To further reduce clutter, since <code>save</code> has a lot of connections, we show the first 5 and abbreviate the others as <code>... 12 more</code>.
- </td>
- </tr>
-</table>
-
-One last structural simplification is *series collapsing*. Sequential
-motifs--that is, nodes whose names differ by a number at the end and have
-isomorphic structures--are collapsed into a single *stack* of nodes, as shown
-below. For networks with long sequences, this greatly simplifies the view. As
-with hierarchical nodes, double-clicking expands the series. See
-[Interaction](#interaction) for how to disable/enable series collapsing for a
-specific set of nodes.
-
-<table width="100%;">
- <tr>
- <td style="width: 50%;">
- <img src="https://www.tensorflow.org/images/series.png" alt="Sequence of nodes" title="Sequence of nodes" />
- </td>
- <td style="width: 50%;">
- <img src="https://www.tensorflow.org/images/series_expanded.png" alt="Expanded sequence of nodes" title="Expanded sequence of nodes" />
- </td>
- </tr>
- <tr>
- <td style="width: 50%;">
- A collapsed view of a node sequence.
- </td>
- <td style="width: 50%;">
- A small piece of the expanded view, after double-click.
- </td>
- </tr>
-</table>
-
-Finally, as one last aid to legibility, the visualization uses special icons
-for constants and summary nodes. To summarize, here's a table of node symbols:
-
-Symbol | Meaning
---- | ---
-![Name scope](https://www.tensorflow.org/images/namespace_node.png "Name scope") | *High-level* node representing a name scope. Double-click to expand a high-level node.
-![Sequence of unconnected nodes](https://www.tensorflow.org/images/horizontal_stack.png "Sequence of unconnected nodes") | Sequence of numbered nodes that are not connected to each other.
-![Sequence of connected nodes](https://www.tensorflow.org/images/vertical_stack.png "Sequence of connected nodes") | Sequence of numbered nodes that are connected to each other.
-![Operation node](https://www.tensorflow.org/images/op_node.png "Operation node") | An individual operation node.
-![Constant node](https://www.tensorflow.org/images/constant.png "Constant node") | A constant.
-![Summary node](https://www.tensorflow.org/images/summary.png "Summary node") | A summary node.
-![Data flow edge](https://www.tensorflow.org/images/dataflow_edge.png "Data flow edge") | Edge showing the data flow between operations.
-![Control dependency edge](https://www.tensorflow.org/images/control_edge.png "Control dependency edge") | Edge showing the control dependency between operations.
-![Reference edge](https://www.tensorflow.org/images/reference_edge.png "Reference edge") | A reference edge showing that the outgoing operation node can mutate the incoming tensor.
-
-## Interaction {#interaction}
-
-Navigate the graph by panning and zooming. Click and drag to pan, and use a
-scroll gesture to zoom. Double-click on a node, or click on its `+` button, to
-expand a name scope that represents a group of operations. To easily keep
-track of the current viewpoint when zooming and panning, there is a minimap in
-the bottom right corner.
-
-To close an open node, double-click it again or click its `-` button. You can
-also click once to select a node. It will turn a darker color, and details
-about it and the nodes it connects to will appear in the info card in the
-upper-right corner of the visualization.
-
-<table width="100%;">
- <tr>
- <td style="width: 50%;">
- <img src="https://www.tensorflow.org/images/infocard.png" alt="Info card of a name scope" title="Info card of a name scope" />
- </td>
- <td style="width: 50%;">
- <img src="https://www.tensorflow.org/images/infocard_op.png" alt="Info card of operation node" title="Info card of operation node" />
- </td>
- </tr>
- <tr>
- <td style="width: 50%;">
- Info card showing detailed information for the <code>conv2</code> name scope. The inputs and outputs are combined from the inputs and outputs of the operation nodes inside the name scope. For name scopes no attributes are shown.
- </td>
- <td style="width: 50%;">
- Info card showing detailed information for the <code>DecodeRaw</code> operation node. In addition to inputs and outputs, the card shows the device and the attributes associated with the current operation.
- </td>
- </tr>
-</table>
-
-TensorBoard provides several ways to change the visual layout of the graph. This
-doesn't change the graph's computational semantics, but it can bring some
-clarity to the network's structure. By right clicking on a node or pressing
-buttons on the bottom of that node's info card, you can make the following
-changes to its layout:
-
-* Nodes can be moved between the main graph and the auxiliary area.
-* A series of nodes can be ungrouped so that the nodes in the series do not
-appear grouped together. Ungrouped series can likewise be regrouped.
-
-Selection can also be helpful in understanding high-degree nodes. Select any
-high-degree node, and the corresponding node icons for its other connections
-will be selected as well. This makes it easy, for example, to see which nodes
-are being saved--and which aren't.
-
-Clicking on a node name in the info card will select it. If necessary, the
-viewpoint will automatically pan so that the node is visible.
-
-Finally, you can choose two color schemes for your graph, using the color menu
-above the legend. The default *Structure View* shows structure: when two
-high-level nodes have the same structure, they appear in the same color of the
-rainbow. Uniquely structured nodes are gray. There's a second view, which shows
-what device the different operations run on. Name scopes are colored
-proportionally to the fraction of devices for the operations inside them.
-
-The images below give an illustration for a piece of a real-life graph.
-
-<table width="100%;">
- <tr>
- <td style="width: 50%;">
- <img src="https://www.tensorflow.org/images/colorby_structure.png" alt="Color by structure" title="Color by structure" />
- </td>
- <td style="width: 50%;">
- <img src="https://www.tensorflow.org/images/colorby_device.png" alt="Color by device" title="Color by device" />
- </td>
- </tr>
- <tr>
- <td style="width: 50%;">
- Structure view: The gray nodes have unique structure. The orange <code>conv1</code> and <code>conv2</code> nodes have the same structure, and analogously for nodes with other colors.
- </td>
- <td style="width: 50%;">
-    Device view: Name scopes are colored proportionally to the fraction of devices of the operation nodes inside them. Here, purple means GPU and green means CPU.
- </td>
- </tr>
-</table>
-
-## Tensor shape information
-
-When the serialized `GraphDef` includes tensor shapes, the graph visualizer
-labels edges with tensor dimensions, and edge thickness reflects total tensor
-size. To include tensor shapes in the `GraphDef`, pass the actual graph object
-(as in `sess.graph`) to the `FileWriter` when serializing the graph.
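-
-A minimal sketch (assuming a session `sess` and a log directory `/tmp/logdir`):
-
-```python
-# Passing the graph itself ensures the serialized GraphDef, including
-# tensor shape information, is written to the log directory.
-writer = tf.summary.FileWriter("/tmp/logdir", graph=sess.graph)
-```
-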
-The images below show the CIFAR-10 model with tensor shape information:
-<table width="100%;">
- <tr>
- <td style="width: 100%;">
- <img src="https://www.tensorflow.org/images/tensor_shapes.png" alt="CIFAR-10 model with tensor shape information" title="CIFAR-10 model with tensor shape information" />
- </td>
- </tr>
- <tr>
- <td style="width: 100%;">
- CIFAR-10 model with tensor shape information.
- </td>
- </tr>
-</table>
-
-## Runtime statistics
-
-Often it is useful to collect runtime metadata for a run, such as total memory
-usage, total compute time, and tensor shapes for nodes. The code example below
-is a snippet from the train and test section of a modification of the
-[Estimators MNIST tutorial](../tutorials/estimators/cnn.md), in which we have
-recorded summaries and
-runtime statistics. See the
-@{$summaries_and_tensorboard#serializing-the-data$Summaries Tutorial}
-for details on how to record summaries.
-Full source is [here](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py).
-
-```python
- # Train the model, and also write summaries.
- # Every 10th step, measure test-set accuracy, and write test summaries
- # All other steps, run train_step on training data, & add training summaries
-
- def feed_dict(train):
- """Make a TensorFlow feed_dict: maps data onto Tensor placeholders."""
- if train or FLAGS.fake_data:
- xs, ys = mnist.train.next_batch(100, fake_data=FLAGS.fake_data)
- k = FLAGS.dropout
- else:
- xs, ys = mnist.test.images, mnist.test.labels
- k = 1.0
- return {x: xs, y_: ys, keep_prob: k}
-
- for i in range(FLAGS.max_steps):
- if i % 10 == 0: # Record summaries and test-set accuracy
- summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False))
- test_writer.add_summary(summary, i)
- print('Accuracy at step %s: %s' % (i, acc))
- else: # Record train set summaries, and train
- if i % 100 == 99: # Record execution stats
- run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
- run_metadata = tf.RunMetadata()
- summary, _ = sess.run([merged, train_step],
- feed_dict=feed_dict(True),
- options=run_options,
- run_metadata=run_metadata)
- train_writer.add_run_metadata(run_metadata, 'step%d' % i)
- train_writer.add_summary(summary, i)
- print('Adding run metadata for', i)
- else: # Record a summary
- summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True))
- train_writer.add_summary(summary, i)
-```
-
-This code will emit runtime statistics for every 100th step, starting at step 99.
-
-When you launch TensorBoard and go to the Graph tab, you will now see options
-under "Session runs" which correspond to the steps where run metadata was added.
-Selecting one of these runs will show you the snapshot of the network at that
-step, fading out unused nodes. In the controls on the left hand side, you will
-be able to color the nodes by total memory or total compute time. Additionally,
-clicking on a node will display the exact total memory, compute time, and
-tensor output sizes.
-
-
-<table width="100%;">
- <tr style="height: 380px">
- <td>
- <img src="https://www.tensorflow.org/images/colorby_compute_time.png" alt="Color by compute time" title="Color by compute time"/>
- </td>
- <td>
- <img src="https://www.tensorflow.org/images/run_metadata_graph.png" alt="Run metadata graph" title="Run metadata graph" />
- </td>
- <td>
- <img src="https://www.tensorflow.org/images/run_metadata_infocard.png" alt="Run metadata info card" title="Run metadata info card" />
- </td>
- </tr>
-</table>
diff --git a/tensorflow/docs_src/guide/graphs.md b/tensorflow/docs_src/guide/graphs.md
deleted file mode 100644
index 2bb44fbb32..0000000000
--- a/tensorflow/docs_src/guide/graphs.md
+++ /dev/null
@@ -1,558 +0,0 @@
-# Graphs and Sessions
-
-TensorFlow uses a **dataflow graph** to represent your computation in terms of
-the dependencies between individual operations. This leads to a low-level
-programming model in which you first define the dataflow graph, then create a
-TensorFlow **session** to run parts of the graph across a set of local and
-remote devices.
-
-This guide will be most useful if you intend to use the low-level programming
-model directly. Higher-level APIs such as `tf.estimator.Estimator` and Keras
-hide the details of graphs and sessions from the end user, but this guide may
-also be useful if you want to understand how these APIs are implemented.
-
-## Why dataflow graphs?
-
-![](../images/tensors_flowing.gif)
-
-[Dataflow](https://en.wikipedia.org/wiki/Dataflow_programming) is a common
-programming model for parallel computing. In a dataflow graph, the nodes
-represent units of computation, and the edges represent the data consumed or
-produced by a computation. For example, in a TensorFlow graph, the `tf.matmul`
-operation would correspond to a single node with two incoming edges (the
-matrices to be multiplied) and one outgoing edge (the result of the
-multiplication).
-
-<!-- TODO(barryr): Add a diagram to illustrate the `tf.matmul` graph. -->
-
-Dataflow has several advantages that TensorFlow leverages when executing your
-programs:
-
-* **Parallelism.** By using explicit edges to represent dependencies between
- operations, it is easy for the system to identify operations that can execute
- in parallel.
-
-* **Distributed execution.** By using explicit edges to represent the values
- that flow between operations, it is possible for TensorFlow to partition your
- program across multiple devices (CPUs, GPUs, and TPUs) attached to different
- machines. TensorFlow inserts the necessary communication and coordination
- between devices.
-
-* **Compilation.** TensorFlow's @{$performance/xla$XLA compiler} can
- use the information in your dataflow graph to generate faster code, for
- example, by fusing together adjacent operations.
-
-* **Portability.** The dataflow graph is a language-independent representation
- of the code in your model. You can build a dataflow graph in Python, store it
- in a @{$saved_model$SavedModel}, and restore it in a C++ program for
- low-latency inference.
-
-
-## What is a `tf.Graph`?
-
-A `tf.Graph` contains two relevant kinds of information:
-
-* **Graph structure.** The nodes and edges of the graph, indicating how
- individual operations are composed together, but not prescribing how they
- should be used. The graph structure is like assembly code: inspecting it can
- convey some useful information, but it does not contain all of the useful
- context that source code conveys.
-
-* **Graph collections.** TensorFlow provides a general mechanism for storing
- collections of metadata in a `tf.Graph`. The `tf.add_to_collection` function
- enables you to associate a list of objects with a key (where `tf.GraphKeys`
- defines some of the standard keys), and `tf.get_collection` enables you to
- look up all objects associated with a key. Many parts of the TensorFlow
- library use this facility: for example, when you create a `tf.Variable`, it
- is added by default to collections representing "global variables" and
- "trainable variables". When you later come to create a `tf.train.Saver` or
- `tf.train.Optimizer`, the variables in these collections are used as the
-  default arguments, as in the sketch below.
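-
-A minimal sketch (the collection key and variable are illustrative):
-
-```python
-import tensorflow as tf
-
-v = tf.Variable(0.0, name="v")  # Added to "global variables" by default.
-
-# Associate an arbitrary object with a custom collection key...
-tf.add_to_collection("my_collection", v)
-
-# ...and look it up later by the same key.
-assert tf.get_collection("my_collection") == [v]
-
-# Standard keys are defined in tf.GraphKeys.
-print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
-```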
-
-
-## Building a `tf.Graph`
-
-Most TensorFlow programs start with a dataflow graph construction phase. In this
-phase, you invoke TensorFlow API functions that construct new `tf.Operation`
-(node) and `tf.Tensor` (edge) objects and add them to a `tf.Graph`
-instance. TensorFlow provides a **default graph** that is an implicit argument
-to all API functions in the same context. For example:
-
-* Calling `tf.constant(42.0)` creates a single `tf.Operation` that produces the
- value `42.0`, adds it to the default graph, and returns a `tf.Tensor` that
- represents the value of the constant.
-
-* Calling `tf.matmul(x, y)` creates a single `tf.Operation` that multiplies
- the values of `tf.Tensor` objects `x` and `y`, adds it to the default graph,
- and returns a `tf.Tensor` that represents the result of the multiplication.
-
-* Executing `v = tf.Variable(0)` adds to the graph a `tf.Operation` that will
- store a writeable tensor value that persists between `tf.Session.run` calls.
- The `tf.Variable` object wraps this operation, and can be used [like a
- tensor](#tensor-like_objects), which will read the current value of the
- stored value. The `tf.Variable` object also has methods such as
- `tf.Variable.assign` and `tf.Variable.assign_add` that
- create `tf.Operation` objects that, when executed, update the stored value.
- (See @{$guide/variables} for more information about variables.)
-
-* Calling `tf.train.Optimizer.minimize` will add operations and tensors to the
- default graph that calculates gradients, and return a `tf.Operation` that,
- when run, will apply those gradients to a set of variables.
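-
-A minimal sketch tying these calls together (nothing executes here; each call
-only adds nodes to the default graph):
-
-```python
-import tensorflow as tf
-
-x = tf.constant(42.0)                 # Constant-producing operation.
-v = tf.Variable(0.0)                  # Stateful operation plus wrapper.
-y = tf.matmul(tf.ones([1, 2]), tf.ones([2, 1]))  # Matmul node.
-
-# All three were added to the same implicit default graph.
-assert x.graph is tf.get_default_graph()
-```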
-
-Most programs rely solely on the default graph. However,
-see [Dealing with multiple graphs](#programming_with_multiple_graphs) for more
-advanced use cases. High-level APIs such as the `tf.estimator.Estimator` API
-manage the default graph on your behalf, and--for example--may create different
-graphs for training and evaluation.
-
-Note: Calling most functions in the TensorFlow API merely adds operations
-and tensors to the default graph, but **does not** perform the actual
-computation. Instead, you compose these functions until you have a `tf.Tensor`
-or `tf.Operation` that represents the overall computation--such as performing
-one step of gradient descent--and then pass that object to a `tf.Session` to
-perform the computation. See the section "Executing a graph in a `tf.Session`"
-for more details.
-
-## Naming operations
-
-A `tf.Graph` object defines a **namespace** for the `tf.Operation` objects it
-contains. TensorFlow automatically chooses a unique name for each operation in
-your graph, but giving operations descriptive names can make your program easier
-to read and debug. The TensorFlow API provides two ways to override the name of
-an operation:
-
-* Each API function that creates a new `tf.Operation` or returns a new
- `tf.Tensor` accepts an optional `name` argument. For example,
- `tf.constant(42.0, name="answer")` creates a new `tf.Operation` named
- `"answer"` and returns a `tf.Tensor` named `"answer:0"`. If the default graph
- already contains an operation named `"answer"`, then TensorFlow would append
- `"_1"`, `"_2"`, and so on to the name, in order to make it unique.
-
-* The `tf.name_scope` function makes it possible to add a **name scope** prefix
- to all operations created in a particular context. The current name scope
- prefix is a `"/"`-delimited list of the names of all active `tf.name_scope`
- context managers. If a name scope has already been used in the current
- context, TensorFlow appends `"_1"`, `"_2"`, and so on. For example:
-
- ```python
- c_0 = tf.constant(0, name="c") # => operation named "c"
-
- # Already-used names will be "uniquified".
- c_1 = tf.constant(2, name="c") # => operation named "c_1"
-
- # Name scopes add a prefix to all operations created in the same context.
- with tf.name_scope("outer"):
- c_2 = tf.constant(2, name="c") # => operation named "outer/c"
-
- # Name scopes nest like paths in a hierarchical file system.
- with tf.name_scope("inner"):
- c_3 = tf.constant(3, name="c") # => operation named "outer/inner/c"
-
- # Exiting a name scope context will return to the previous prefix.
- c_4 = tf.constant(4, name="c") # => operation named "outer/c_1"
-
- # Already-used name scopes will be "uniquified".
- with tf.name_scope("inner"):
- c_5 = tf.constant(5, name="c") # => operation named "outer/inner_1/c"
- ```
-
-The graph visualizer uses name scopes to group operations and reduce the visual
-complexity of a graph. See [Visualizing your graph](#visualizing-your-graph) for
-more information.
-
-Note that `tf.Tensor` objects are implicitly named after the `tf.Operation`
-that produces the tensor as output. A tensor name has the form `"<OP_NAME>:<i>"`
-where:
-
-* `"<OP_NAME>"` is the name of the operation that produces it.
-* `"<i>"` is an integer representing the index of that tensor among the
- operation's outputs.
-
-## Placing operations on different devices
-
-If you want your TensorFlow program to use multiple different devices, the
-`tf.device` function provides a convenient way to request that all operations
-created in a particular context are placed on the same device (or type of
-device).
-
-A **device specification** has the following form:
-
-```
-/job:<JOB_NAME>/task:<TASK_INDEX>/device:<DEVICE_TYPE>:<DEVICE_INDEX>
-```
-
-where:
-
-* `<JOB_NAME>` is an alpha-numeric string that does not start with a number.
-* `<TASK_INDEX>` is a non-negative integer representing the index of the task
-  in the job named `<JOB_NAME>`. See `tf.train.ClusterSpec` for an explanation
-  of jobs and tasks.
-* `<DEVICE_TYPE>` is a registered device type (such as `GPU` or `CPU`).
-* `<DEVICE_INDEX>` is a non-negative integer representing the index of the
-  device, for example, to distinguish between different GPU devices used in the
-  same process.
-
-You do not need to specify every part of a device specification. For example,
-if you are running in a single-machine configuration with a single GPU, you
-might use `tf.device` to pin some operations to the CPU and GPU:
-
-```python
-# Operations created outside either context will run on the "best possible"
-# device. For example, if you have a GPU and a CPU available, and the operation
-# has a GPU implementation, TensorFlow will choose the GPU.
-weights = tf.random_normal(...)
-
-with tf.device("/device:CPU:0"):
- # Operations created in this context will be pinned to the CPU.
-  img = tf.image.decode_jpeg(tf.read_file("img.jpg"))
-
-with tf.device("/device:GPU:0"):
- # Operations created in this context will be pinned to the GPU.
- result = tf.matmul(weights, img)
-```
-If you are deploying TensorFlow in a @{$distributed$typical distributed configuration},
-you might specify the job name and task ID to place variables on
-a task in the parameter server job (`"/job:ps"`), and the other operations on
-a task in the worker job (`"/job:worker"`):
-
-```python
-with tf.device("/job:ps/task:0"):
- weights_1 = tf.Variable(tf.truncated_normal([784, 100]))
-  biases_1 = tf.Variable(tf.zeros([100]))
-
-with tf.device("/job:ps/task:1"):
- weights_2 = tf.Variable(tf.truncated_normal([100, 10]))
-  biases_2 = tf.Variable(tf.zeros([10]))
-
-with tf.device("/job:worker"):
- layer_1 = tf.matmul(train_batch, weights_1) + biases_1
- layer_2 = tf.matmul(train_batch, weights_2) + biases_2
-```
-
-`tf.device` gives you a lot of flexibility to choose placements for individual
-operations or broad regions of a TensorFlow graph. In many cases, there are
-simple heuristics that work well. For example, the
-`tf.train.replica_device_setter` API can be used with `tf.device` to place
-operations for **data-parallel distributed training**. For example, the
-following code fragment shows how `tf.train.replica_device_setter` applies
-different placement policies to `tf.Variable` objects and other operations:
-
-```python
-with tf.device(tf.train.replica_device_setter(ps_tasks=3)):
- # tf.Variable objects are, by default, placed on tasks in "/job:ps" in a
- # round-robin fashion.
- w_0 = tf.Variable(...) # placed on "/job:ps/task:0"
- b_0 = tf.Variable(...) # placed on "/job:ps/task:1"
- w_1 = tf.Variable(...) # placed on "/job:ps/task:2"
- b_1 = tf.Variable(...) # placed on "/job:ps/task:0"
-
- input_data = tf.placeholder(tf.float32) # placed on "/job:worker"
- layer_0 = tf.matmul(input_data, w_0) + b_0 # placed on "/job:worker"
- layer_1 = tf.matmul(layer_0, w_1) + b_1 # placed on "/job:worker"
-```
-
-## Tensor-like objects
-
-Many TensorFlow operations take one or more `tf.Tensor` objects as arguments.
-For example, `tf.matmul` takes two `tf.Tensor` objects, and `tf.add_n` takes
-a list of `n` `tf.Tensor` objects. For convenience, these functions will accept
-a **tensor-like object** in place of a `tf.Tensor`, and implicitly convert it
-to a `tf.Tensor` using the `tf.convert_to_tensor` function. Tensor-like objects
-include elements of the following types:
-
-* `tf.Tensor`
-* `tf.Variable`
-* [`numpy.ndarray`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.html)
-* `list` (and lists of tensor-like objects)
-* Scalar Python types: `bool`, `float`, `int`, `str`
-
-You can register additional tensor-like types using
-`tf.register_tensor_conversion_function`.
-
-Note: By default, TensorFlow will create a new `tf.Tensor` each time you use
-the same tensor-like object. If the tensor-like object is large (e.g. a
-`numpy.ndarray` containing a set of training examples) and you use it multiple
-times, you may run out of memory. To avoid this, manually call
-`tf.convert_to_tensor` on the tensor-like object once and use the returned
-`tf.Tensor` instead.
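-
-A minimal sketch (the array size is illustrative) of the recommended pattern:
-
-```python
-import numpy as np
-import tensorflow as tf
-
-big_array = np.random.rand(10000, 256).astype(np.float32)
-
-# Convert once, then reuse the resulting tf.Tensor everywhere.
-big_tensor = tf.convert_to_tensor(big_array)
-
-y = tf.reduce_sum(big_tensor)   # Reuses the single converted tensor.
-z = tf.reduce_mean(big_tensor)  # No second copy is embedded in the graph.
-```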
-
-## Executing a graph in a `tf.Session`
-
-TensorFlow uses the `tf.Session` class to represent a connection between the
-client program---typically a Python program, although a similar interface is
-available in other languages---and the C++ runtime. A `tf.Session` object
-provides access to devices in the local machine, and remote devices using the
-distributed TensorFlow runtime. It also caches information about your
-`tf.Graph` so that you can efficiently run the same computation multiple times.
-
-### Creating a `tf.Session`
-
-If you are using the low-level TensorFlow API, you can create a `tf.Session`
-for the current default graph as follows:
-
-```python
-# Create a default in-process session.
-with tf.Session() as sess:
- # ...
-
-# Create a remote session.
-with tf.Session("grpc://example.org:2222"):
- # ...
-```
-
-Since a `tf.Session` owns physical resources (such as GPUs and
-network connections), it is typically used as a context manager (in a `with`
-block) that automatically closes the session when you exit the block. It is
-also possible to create a session without using a `with` block, but you should
-explicitly call `tf.Session.close` when you are finished with it to free the
-resources.
-
-Note: Higher-level APIs such as `tf.train.MonitoredTrainingSession` or
-`tf.estimator.Estimator` will create and manage a `tf.Session` for you. These
-APIs accept optional `target` and `config` arguments (either directly, or as
-part of a `tf.estimator.RunConfig` object), with the same meaning as
-described below.
-
-`tf.Session.__init__` accepts three optional arguments:
-
-* **`target`.** If this argument is left empty (the default), the session will
- only use devices in the local machine. However, you may also specify a
- `grpc://` URL to specify the address of a TensorFlow server, which gives the
- session access to all devices on machines that this server controls. See
- `tf.train.Server` for details of how to create a TensorFlow
- server. For example, in the common **between-graph replication**
- configuration, the `tf.Session` connects to a `tf.train.Server` in the same
- process as the client. The [distributed TensorFlow](../deploy/distributed.md)
- deployment guide describes other common scenarios.
-
-* **`graph`.** By default, a new `tf.Session` will be bound to---and only able
- to run operations in---the current default graph. If you are using multiple
- graphs in your program (see [Programming with multiple
- graphs](#programming_with_multiple_graphs) for more details), you can specify
- an explicit `tf.Graph` when you construct the session.
-
-* **`config`.** This argument allows you to specify a `tf.ConfigProto` that
-  controls the behavior of the session (a combined sketch follows this list).
-  For example, some of the configuration options include:
-
- * `allow_soft_placement`. Set this to `True` to enable a "soft" device
- placement algorithm, which ignores `tf.device` annotations that attempt
- to place CPU-only operations on a GPU device, and places them on the CPU
- instead.
-
- * `cluster_def`. When using distributed TensorFlow, this option allows you
- to specify what machines to use in the computation, and provide a mapping
- between job names, task indices, and network addresses. See
- `tf.train.ClusterSpec.as_cluster_def` for details.
-
- * `graph_options.optimizer_options`. Provides control over the optimizations
- that TensorFlow performs on your graph before executing it.
-
- * `gpu_options.allow_growth`. Set this to `True` to change the GPU memory
- allocator so that it gradually increases the amount of memory allocated,
- rather than allocating most of the memory at startup.
-
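-A minimal sketch combining two of the options above (the values shown are
-illustrative, not recommendations):
-
-```python
-config = tf.ConfigProto()
-config.allow_soft_placement = True
-config.gpu_options.allow_growth = True
-
-with tf.Session(config=config) as sess:
-  # ... run operations under this configuration ...
-  pass
-```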
-
-### Using `tf.Session.run` to execute operations
-
-The `tf.Session.run` method is the main mechanism for running a `tf.Operation`
-or evaluating a `tf.Tensor`. You can pass one or more `tf.Operation` or
-`tf.Tensor` objects to `tf.Session.run`, and TensorFlow will execute the
-operations that are needed to compute the result.
-
-`tf.Session.run` requires you to specify a list of **fetches**, which determine
-the return values, and may be a `tf.Operation`, a `tf.Tensor`, or
-a [tensor-like type](#tensor-like_objects) such as `tf.Variable`. These fetches
-determine what **subgraph** of the overall `tf.Graph` must be executed to
-produce the result: this is the subgraph that contains all operations named in
-the fetch list, plus all operations whose outputs are used to compute the value
-of the fetches. For example, the following code fragment shows how different
-arguments to `tf.Session.run` cause different subgraphs to be executed:
-
-```python
-x = tf.constant([[37.0, -23.0], [1.0, 4.0]])
-w = tf.Variable(tf.random_uniform([2, 2]))
-y = tf.matmul(x, w)
-output = tf.nn.softmax(y)
-init_op = w.initializer
-
-with tf.Session() as sess:
- # Run the initializer on `w`.
- sess.run(init_op)
-
- # Evaluate `output`. `sess.run(output)` will return a NumPy array containing
- # the result of the computation.
- print(sess.run(output))
-
- # Evaluate `y` and `output`. Note that `y` will only be computed once, and its
- # result used both to return `y_val` and as an input to the `tf.nn.softmax()`
- # op. Both `y_val` and `output_val` will be NumPy arrays.
- y_val, output_val = sess.run([y, output])
-```
-
-`tf.Session.run` also optionally takes a dictionary of **feeds**, which is a
-mapping from `tf.Tensor` objects (typically `tf.placeholder` tensors) to
-values (typically Python scalars, lists, or NumPy arrays) that will be
-substituted for those tensors in the execution. For example:
-
-```python
-# Define a placeholder that expects a vector of three floating-point values,
-# and a computation that depends on it.
-x = tf.placeholder(tf.float32, shape=[3])
-y = tf.square(x)
-
-with tf.Session() as sess:
- # Feeding a value changes the result that is returned when you evaluate `y`.
- print(sess.run(y, {x: [1.0, 2.0, 3.0]})) # => "[1.0, 4.0, 9.0]"
- print(sess.run(y, {x: [0.0, 0.0, 5.0]})) # => "[0.0, 0.0, 25.0]"
-
- # Raises `tf.errors.InvalidArgumentError`, because you must feed a value for
- # a `tf.placeholder()` when evaluating a tensor that depends on it.
- sess.run(y)
-
- # Raises `ValueError`, because the shape of `37.0` does not match the shape
- # of placeholder `x`.
- sess.run(y, {x: 37.0})
-```
-
-`tf.Session.run` also accepts an optional `options` argument that enables you
-to specify options about the call, and an optional `run_metadata` argument that
-enables you to collect metadata about the execution. For example, you can use
-these options together to collect tracing information about the execution:
-
-```python
-y = tf.matmul([[37.0, -23.0], [1.0, 4.0]], tf.random_uniform([2, 2]))
-
-with tf.Session() as sess:
- # Define options for the `sess.run()` call.
- options = tf.RunOptions()
- options.output_partition_graphs = True
- options.trace_level = tf.RunOptions.FULL_TRACE
-
- # Define a container for the returned metadata.
- metadata = tf.RunMetadata()
-
- sess.run(y, options=options, run_metadata=metadata)
-
- # Print the subgraphs that executed on each device.
- print(metadata.partition_graphs)
-
- # Print the timings of each operation that executed.
- print(metadata.step_stats)
-```
-
-
-## Visualizing your graph
-
-TensorFlow includes tools that can help you to understand the code in a graph.
-The **graph visualizer** is a component of TensorBoard that renders the
-structure of your graph visually in a browser. The easiest way to create a
-visualization is to pass a `tf.Graph` when creating the
-`tf.summary.FileWriter`:
-
-```python
-# Build your graph.
-x = tf.constant([[37.0, -23.0], [1.0, 4.0]])
-w = tf.Variable(tf.random_uniform([2, 2]))
-y = tf.matmul(x, w)
-# ...
-loss = ...
-train_op = tf.train.AdagradOptimizer(0.01).minimize(loss)
-
-with tf.Session() as sess:
- # `sess.graph` provides access to the graph used in a `tf.Session`.
- writer = tf.summary.FileWriter("/tmp/log/...", sess.graph)
-
- # Perform your computation...
- for i in range(1000):
- sess.run(train_op)
- # ...
-
- writer.close()
-```
-
-Note: If you are using a `tf.estimator.Estimator`, the graph (and any
-summaries) will be logged automatically to the `model_dir` that you specified
-when creating the estimator.
-
-You can then open the log in `tensorboard`, navigate to the "Graph" tab, and
-see a high-level visualization of your graph's structure. Note that a typical
-TensorFlow graph---especially training graphs with automatically computed
-gradients---has too many nodes to visualize at once. The graph visualizer makes
-use of name scopes to group related operations into "super" nodes. You can
-click on the orange "+" button on any of these super nodes to expand the
-subgraph inside.
-
-![](../images/mnist_deep.png)
-
-For more information about visualizing your TensorFlow application with
-TensorBoard, see the [TensorBoard guide](./summaries_and_tensorboard.md).
-
-## Programming with multiple graphs
-
-Note: When training a model, a common way of organizing your code is to use one
-graph for training your model, and a separate graph for evaluating or performing
-inference with a trained model. In many cases, the inference graph will be
-different from the training graph: for example, techniques like dropout and
-batch normalization use different operations in each case. Furthermore, by
-default utilities like `tf.train.Saver` use the names of `tf.Variable` objects
-(which have names based on an underlying `tf.Operation`) to identify each
-variable in a saved checkpoint. When programming this way, you can either use
-completely separate Python processes to build and execute the graphs, or you can
-use multiple graphs in the same process. This section describes how to use
-multiple graphs in the same process.
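-
-A minimal sketch of this organization, assuming matching variable names in the
-two graphs so that `tf.train.Saver` can pair them with the checkpoint (the
-checkpoint path is illustrative):
-
-```python
-train_graph = tf.Graph()
-with train_graph.as_default():
-  w = tf.get_variable("weights", shape=[2, 2])
-  saver = tf.train.Saver()
-  with tf.Session() as sess:
-    sess.run(w.initializer)
-    # ... training steps ...
-    saver.save(sess, "/tmp/model/ckpt")
-
-eval_graph = tf.Graph()
-with eval_graph.as_default():
-  w = tf.get_variable("weights", shape=[2, 2])
-  saver = tf.train.Saver()
-  with tf.Session() as sess:
-    saver.restore(sess, "/tmp/model/ckpt")
-    # ... evaluation steps ...
-```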
-
-As noted above, TensorFlow provides a "default graph" that is implicitly passed
-to all API functions in the same context. For many applications, a single graph
-is sufficient. However, TensorFlow also provides methods for manipulating
-the default graph, which can be useful in more advanced use cases. For example:
-
-* A `tf.Graph` defines the namespace for `tf.Operation` objects: each
- operation in a single graph must have a unique name. TensorFlow will
- "uniquify" the names of operations by appending `"_1"`, `"_2"`, and so on to
-  their names if the requested name is already taken (see the sketch after
-  this list). Using multiple explicitly created graphs gives you more control
-  over what name is given to each operation.
-
-* The default graph stores information about every `tf.Operation` and
- `tf.Tensor` that was ever added to it. If your program creates a large number
- of unconnected subgraphs, it may be more efficient to use a different
- `tf.Graph` to build each subgraph, so that unrelated state can be garbage
- collected.
-
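-As a sketch of the name "uniquification" mentioned above:
-
-```python
-a = tf.constant(0, name="c")
-b = tf.constant(1, name="c")  # The requested name is taken, so it is renamed.
-
-print(a.op.name)  # => "c"
-print(b.op.name)  # => "c_1"
-```
-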
-You can install a different `tf.Graph` as the default graph, using the
-`tf.Graph.as_default` context manager:
-
-```python
-g_1 = tf.Graph()
-with g_1.as_default():
- # Operations created in this scope will be added to `g_1`.
- c = tf.constant("Node in g_1")
-
- # Sessions created in this scope will run operations from `g_1`.
- sess_1 = tf.Session()
-
-g_2 = tf.Graph()
-with g_2.as_default():
- # Operations created in this scope will be added to `g_2`.
- d = tf.constant("Node in g_2")
-
-# Alternatively, you can pass a graph when constructing a `tf.Session`:
-# `sess_2` will run operations from `g_2`.
-sess_2 = tf.Session(graph=g_2)
-
-assert c.graph is g_1
-assert sess_1.graph is g_1
-
-assert d.graph is g_2
-assert sess_2.graph is g_2
-```
-
-To inspect the current default graph, call `tf.get_default_graph`, which
-returns a `tf.Graph` object:
-
-```python
-# Print all of the operations in the default graph.
-g = tf.get_default_graph()
-print(g.get_operations())
-```
diff --git a/tensorflow/docs_src/guide/index.md b/tensorflow/docs_src/guide/index.md
deleted file mode 100644
index 1c920e7d70..0000000000
--- a/tensorflow/docs_src/guide/index.md
+++ /dev/null
@@ -1,82 +0,0 @@
-# TensorFlow Guide
-
-The documents in this section dive into the details of how TensorFlow
-works. The sections are as follows:
-
-## High Level APIs
-
- * @{$guide/keras}, TensorFlow's high-level API for building and
- training deep learning models.
- * @{$guide/eager}, an API for writing TensorFlow code
- imperatively, like you would use Numpy.
- * @{$guide/datasets}, easy input pipelines to bring your data into
- your TensorFlow program.
- * @{$guide/estimators}, a high-level API that provides
- fully-packaged models ready for large-scale training and production.
-
-## Estimators
-
-* @{$premade_estimators}, the basics of premade Estimators.
-* @{$checkpoints}, save training progress and resume where you left off.
-* @{$feature_columns}, handle a variety of input data types without changes to the model.
-* @{$datasets_for_estimators}, use `tf.data` to input data.
-* @{$custom_estimators}, write your own Estimator.
-
-## Accelerators
-
- * @{$using_gpu} explains how TensorFlow assigns operations to
- devices and how you can change the arrangement manually.
- * @{$using_tpu} explains how to modify `Estimator` programs to run on a TPU.
-
-## Low Level APIs
-
- * @{$guide/low_level_intro}, which introduces the
-   basics of how you can use TensorFlow outside of the high-level APIs.
- * @{$guide/tensors}, which explains how to create,
-   manipulate, and access Tensors--the fundamental objects in TensorFlow.
- * @{$guide/variables}, which details how
- to represent shared, persistent state in your program.
- * @{$guide/graphs}, which explains:
- * dataflow graphs, which are TensorFlow's representation of computations
- as dependencies between operations.
- * sessions, which are TensorFlow's mechanism for running dataflow graphs
- across one or more local or remote devices.
- If you are programming with the low-level TensorFlow API, this unit
- is essential. If you are programming with a high-level TensorFlow API
- such as Estimators or Keras, the high-level API creates and manages
- graphs and sessions for you, but understanding graphs and sessions
- can still be helpful.
- * @{$guide/saved_model}, which
- explains how to save and restore variables and models.
-
-## ML Concepts
-
- * @{$guide/embedding}, which introduces the concept
- of embeddings, provides a simple example of training an embedding in
- TensorFlow, and explains how to view embeddings with the TensorBoard
- Embedding Projector.
-
-## Debugging
-
- * @{$guide/debugger}, which
- explains how to use the TensorFlow debugger (tfdbg).
-
-## TensorBoard
-
-TensorBoard is a utility to visualize different aspects of machine learning.
-The following guides explain how to use TensorBoard:
-
- * @{$guide/summaries_and_tensorboard},
- which introduces TensorBoard.
- * @{$guide/graph_viz}, which
- explains how to visualize the computational graph.
- * @{$guide/tensorboard_histograms}, which demonstrates how to
-   use TensorBoard's histogram dashboard.
-
-
-## Misc
-
- * @{$guide/version_compat},
- which explains backward compatibility guarantees and non-guarantees.
- * @{$guide/faq}, which contains frequently asked
- questions about TensorFlow.
diff --git a/tensorflow/docs_src/guide/keras.md b/tensorflow/docs_src/guide/keras.md
deleted file mode 100644
index 2330fa03c7..0000000000
--- a/tensorflow/docs_src/guide/keras.md
+++ /dev/null
@@ -1,623 +0,0 @@
-# Keras
-
-Keras is a high-level API to build and train deep learning models. It's used for
-fast prototyping, advanced research, and production, with three key advantages:
-
-- *User friendly*<br>
- Keras has a simple, consistent interface optimized for common use cases. It
- provides clear and actionable feedback for user errors.
-- *Modular and composable*<br>
- Keras models are made by connecting configurable building blocks together,
- with few restrictions.
-- *Easy to extend*<br> Write custom building blocks to express new ideas for
- research. Create new layers, loss functions, and develop state-of-the-art
- models.
-
-## Import tf.keras
-
-`tf.keras` is TensorFlow's implementation of the
-[Keras API specification](https://keras.io){:.external}. This is a high-level
-API to build and train models that includes first-class support for
-TensorFlow-specific functionality, such as [eager execution](#eager_execution),
-`tf.data` pipelines, and [Estimators](./estimators.md).
-`tf.keras` makes TensorFlow easier to use without sacrificing flexibility and
-performance.
-
-To get started, import `tf.keras` as part of your TensorFlow program setup:
-
-```python
-import tensorflow as tf
-from tensorflow import keras
-```
-
-`tf.keras` can run any Keras-compatible code, but keep in mind:
-
-* The `tf.keras` version in the latest TensorFlow release might not be the same
-  as the latest `keras` version from PyPI. Check `tf.keras.__version__` (see
-  the example below).
-* When [saving a model's weights](#weights_only), `tf.keras` defaults to the
- [checkpoint format](./checkpoints.md). Pass `save_format='h5'` to
- use HDF5.
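-
-For example, a quick check of the bundled Keras version (the output shown is
-illustrative):
-
-```python
-print(tf.keras.__version__)  # e.g. "2.1.6-tf"
-```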
-
-## Build a simple model
-
-### Sequential model
-
-In Keras, you assemble *layers* to build *models*. A model is (usually) a graph
-of layers. The most common type of model is a stack of layers: the
-`tf.keras.Sequential` model.
-
-To build a simple, fully-connected network (i.e. a multi-layer perceptron):
-
-```python
-model = keras.Sequential()
-# Adds a densely-connected layer with 64 units to the model:
-model.add(keras.layers.Dense(64, activation='relu'))
-# Add another:
-model.add(keras.layers.Dense(64, activation='relu'))
-# Add a softmax layer with 10 output units:
-model.add(keras.layers.Dense(10, activation='softmax'))
-```
-
-### Configure the layers
-
-Many `tf.keras.layers` are available, and they share some common constructor
-parameters:
-
-* `activation`: Set the activation function for the layer. This parameter is
- specified by the name of a built-in function or as a callable object. By
- default, no activation is applied.
-* `kernel_initializer` and `bias_initializer`: The initialization schemes
- that create the layer's weights (kernel and bias). This parameter is a name or
- a callable object. This defaults to the `"Glorot uniform"` initializer.
-* `kernel_regularizer` and `bias_regularizer`: The regularization schemes
-  that apply to the layer's weights (kernel and bias), such as L1 or L2
-  regularization. By default, no regularization is applied.
-
-The following instantiates `tf.keras.layers.Dense` layers using constructor
-arguments:
-
-```python
-# Create a sigmoid layer:
-keras.layers.Dense(64, activation='sigmoid')
-# Or:
-keras.layers.Dense(64, activation=tf.sigmoid)
-
-# A linear layer with L1 regularization of factor 0.01 applied to the kernel matrix:
-keras.layers.Dense(64, kernel_regularizer=keras.regularizers.l1(0.01))
-# A linear layer with L2 regularization of factor 0.01 applied to the bias vector:
-keras.layers.Dense(64, bias_regularizer=keras.regularizers.l2(0.01))
-
-# A linear layer with a kernel initialized to a random orthogonal matrix:
-keras.layers.Dense(64, kernel_initializer='orthogonal')
-# A linear layer with a bias vector initialized to 2.0s:
-keras.layers.Dense(64, bias_initializer=keras.initializers.constant(2.0))
-```
-
-## Train and evaluate
-
-### Set up training
-
-After the model is constructed, configure its learning process by calling the
-`compile` method:
-
-```python
-model.compile(optimizer=tf.train.AdamOptimizer(0.001),
- loss='categorical_crossentropy',
- metrics=['accuracy'])
-```
-
-`tf.keras.Model.compile` takes three important arguments:
-
-* `optimizer`: This object specifies the training procedure. Pass it optimizer
- instances from the `tf.train` module, such as
- [`AdamOptimizer`](/api_docs/python/tf/train/AdamOptimizer),
- [`RMSPropOptimizer`](/api_docs/python/tf/train/RMSPropOptimizer), or
- [`GradientDescentOptimizer`](/api_docs/python/tf/train/GradientDescentOptimizer).
-* `loss`: The function to minimize during optimization. Common choices include
- mean square error (`mse`), `categorical_crossentropy`, and
- `binary_crossentropy`. Loss functions are specified by name or by
- passing a callable object from the `tf.keras.losses` module.
-* `metrics`: Used to monitor training. These are string names or callables from
- the `tf.keras.metrics` module.
-
-The following shows a few examples of configuring a model for training:
-
-```python
-# Configure a model for mean-squared error regression.
-model.compile(optimizer=tf.train.AdamOptimizer(0.01),
- loss='mse', # mean squared error
- metrics=['mae']) # mean absolute error
-
-# Configure a model for categorical classification.
-model.compile(optimizer=tf.train.RMSPropOptimizer(0.01),
- loss=keras.losses.categorical_crossentropy,
- metrics=[keras.metrics.categorical_accuracy])
-```
-
-### Input NumPy data
-
-For small datasets, use in-memory [NumPy](https://www.numpy.org/){:.external}
-arrays to train and evaluate a model. The model is "fit" to the training data
-using the `fit` method:
-
-```python
-import numpy as np
-
-data = np.random.random((1000, 32))
-labels = np.random.random((1000, 10))
-
-model.fit(data, labels, epochs=10, batch_size=32)
-```
-
-`tf.keras.Model.fit` takes three important arguments:
-
-* `epochs`: Training is structured into *epochs*. An epoch is one iteration over
- the entire input data (this is done in smaller batches).
-* `batch_size`: When passed NumPy data, the model slices the data into smaller
- batches and iterates over these batches during training. This integer
- specifies the size of each batch. Be aware that the last batch may be smaller
- if the total number of samples is not divisible by the batch size.
-* `validation_data`: When prototyping a model, you want to easily monitor its
- performance on some validation data. Passing this argument—a tuple of inputs
- and labels—allows the model to display the loss and metrics in inference mode
- for the passed data, at the end of each epoch.
-
-Here's an example using `validation_data`:
-
-```python
-import numpy as np
-
-data = np.random.random((1000, 32))
-labels = np.random.random((1000, 10))
-
-val_data = np.random.random((100, 32))
-val_labels = np.random.random((100, 10))
-
-model.fit(data, labels, epochs=10, batch_size=32,
- validation_data=(val_data, val_labels))
-```
-
-### Input tf.data datasets
-
-Use the [Datasets API](./datasets.md) to scale to large datasets
-or multi-device training. Pass a `tf.data.Dataset` instance to the `fit`
-method:
-
-```python
-# Instantiates a toy dataset:
-dataset = tf.data.Dataset.from_tensor_slices((data, labels))
-dataset = dataset.batch(32)
-dataset = dataset.repeat()
-
-# Don't forget to specify `steps_per_epoch` when calling `fit` on a dataset.
-model.fit(dataset, epochs=10, steps_per_epoch=30)
-```
-
-Here, the `fit` method uses the `steps_per_epoch` argument—this is the number of
-training steps the model runs before it moves to the next epoch. Since the
-`Dataset` yields batches of data, this snippet does not require a `batch_size`.
-
-Datasets can also be used for validation:
-
-```python
-dataset = tf.data.Dataset.from_tensor_slices((data, labels))
-dataset = dataset.batch(32).repeat()
-
-val_dataset = tf.data.Dataset.from_tensor_slices((val_data, val_labels))
-val_dataset = val_dataset.batch(32).repeat()
-
-model.fit(dataset, epochs=10, steps_per_epoch=30,
- validation_data=val_dataset,
- validation_steps=3)
-```
-
-### Evaluate and predict
-
-The `tf.keras.Model.evaluate` and `tf.keras.Model.predict` methods can use NumPy
-data and a `tf.data.Dataset`.
-
-To *evaluate* the inference-mode loss and metrics for the data provided:
-
-```python
-model.evaluate(x, y, batch_size=32)
-
-model.evaluate(dataset, steps=30)
-```
-
-And to *predict* the output of the last layer in inference mode for the
-provided data, as a NumPy array:
-
-```python
-model.predict(x, batch_size=32)
-
-model.predict(dataset, steps=30)
-```
-
-
-## Build advanced models
-
-### Functional API
-
-The `tf.keras.Sequential` model is a simple stack of layers that cannot
-represent arbitrary models. Use the
-[Keras functional API](https://keras.io/getting-started/functional-api-guide/){:.external}
-to build complex model topologies such as:
-
-* Multi-input models,
-* Multi-output models,
-* Models with shared layers (the same layer called several times),
-* Models with non-sequential data flows (e.g. residual connections).
-
-Building a model with the functional API works like this:
-
-1. A layer instance is callable and returns a tensor.
-2. Input tensors and output tensors are used to define a `tf.keras.Model`
- instance.
-3. This model is trained just like the `Sequential` model.
-
-The following example uses the functional API to build a simple, fully-connected
-network:
-
-```python
-inputs = keras.Input(shape=(32,)) # Returns a placeholder tensor
-
-# A layer instance is callable on a tensor, and returns a tensor.
-x = keras.layers.Dense(64, activation='relu')(inputs)
-x = keras.layers.Dense(64, activation='relu')(x)
-predictions = keras.layers.Dense(10, activation='softmax')(x)
-
-# Instantiate the model given inputs and outputs.
-model = keras.Model(inputs=inputs, outputs=predictions)
-
-# The compile step specifies the training configuration.
-model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
- loss='categorical_crossentropy',
- metrics=['accuracy'])
-
-# Trains for 5 epochs
-model.fit(data, labels, batch_size=32, epochs=5)
-```
-
-### Model subclassing
-
-Build a fully-customizable model by subclassing `tf.keras.Model` and defining
-your own forward pass. Create layers in the `__init__` method and set them as
-attributes of the class instance. Define the forward pass in the `call` method.
-
-Model subclassing is particularly useful when
-[eager execution](./eager.md) is enabled since the forward pass
-can be written imperatively.
-
-Key Point: Use the right API for the job. While model subclassing offers
-flexibility, it comes at a cost of greater complexity and more opportunities for
-user errors. If possible, prefer the functional API.
-
-The following example shows a subclassed `tf.keras.Model` using a custom forward
-pass:
-
-```python
-class MyModel(keras.Model):
-
- def __init__(self, num_classes=10):
- super(MyModel, self).__init__(name='my_model')
- self.num_classes = num_classes
- # Define your layers here.
- self.dense_1 = keras.layers.Dense(32, activation='relu')
- self.dense_2 = keras.layers.Dense(num_classes, activation='sigmoid')
-
- def call(self, inputs):
- # Define your forward pass here,
- # using layers you previously defined (in `__init__`).
- x = self.dense_1(inputs)
- return self.dense_2(x)
-
- def compute_output_shape(self, input_shape):
- # You need to override this function if you want to use the subclassed model
- # as part of a functional-style model.
- # Otherwise, this method is optional.
- shape = tf.TensorShape(input_shape).as_list()
- shape[-1] = self.num_classes
- return tf.TensorShape(shape)
-
-
-# Instantiates the subclassed model.
-model = MyModel(num_classes=10)
-
-# The compile step specifies the training configuration.
-model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
- loss='categorical_crossentropy',
- metrics=['accuracy'])
-
-# Trains for 5 epochs.
-model.fit(data, labels, batch_size=32, epochs=5)
-```
-
-
-### Custom layers
-
-Create a custom layer by subclassing `tf.keras.layers.Layer` and implementing
-the following methods:
-
-* `build`: Create the weights of the layer. Add weights with the `add_weight`
- method.
-* `call`: Define the forward pass.
-* `compute_output_shape`: Specify how to compute the output shape of the layer
- given the input shape.
-* Optionally, a layer can be serialized by implementing the `get_config` method
- and the `from_config` class method.
-
-Here's an example of a custom layer that implements a `matmul` of an input with
-a kernel matrix:
-
-```python
-class MyLayer(keras.layers.Layer):
-
- def __init__(self, output_dim, **kwargs):
- self.output_dim = output_dim
- super(MyLayer, self).__init__(**kwargs)
-
- def build(self, input_shape):
- shape = tf.TensorShape((input_shape[1], self.output_dim))
- # Create a trainable weight variable for this layer.
- self.kernel = self.add_weight(name='kernel',
- shape=shape,
- initializer='uniform',
- trainable=True)
- # Be sure to call this at the end
- super(MyLayer, self).build(input_shape)
-
- def call(self, inputs):
- return tf.matmul(inputs, self.kernel)
-
- def compute_output_shape(self, input_shape):
- shape = tf.TensorShape(input_shape).as_list()
- shape[-1] = self.output_dim
- return tf.TensorShape(shape)
-
-  def get_config(self):
-    base_config = super(MyLayer, self).get_config()
-    base_config['output_dim'] = self.output_dim
-    return base_config
-
- @classmethod
- def from_config(cls, config):
- return cls(**config)
-
-
-# Create a model using the custom layer
-model = keras.Sequential([MyLayer(10),
- keras.layers.Activation('softmax')])
-
-# The compile step specifies the training configuration
-model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
- loss='categorical_crossentropy',
- metrics=['accuracy'])
-
-# Trains for 5 epochs.
-model.fit(data, targets, batch_size=32, epochs=5)
-```
-
-
-## Callbacks
-
-A callback is an object passed to a model to customize and extend its behavior
-during training. You can write your own custom callback, or use the built-in
-`tf.keras.callbacks` that include:
-
-* `tf.keras.callbacks.ModelCheckpoint`: Save checkpoints of your model at
- regular intervals.
-* `tf.keras.callbacks.LearningRateScheduler`: Dynamically change the learning
- rate.
-* `tf.keras.callbacks.EarlyStopping`: Interrupt training when validation
- performance has stopped improving.
-* `tf.keras.callbacks.TensorBoard`: Monitor the model's behavior using
- [TensorBoard](./summaries_and_tensorboard.md).
-
-To use a `tf.keras.callbacks.Callback`, pass it to the model's `fit` method:
-
-```python
-callbacks = [
- # Interrupt training if `val_loss` stops improving for over 2 epochs
- keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
- # Write TensorBoard logs to `./logs` directory
- keras.callbacks.TensorBoard(log_dir='./logs')
-]
-model.fit(data, labels, batch_size=32, epochs=5, callbacks=callbacks,
- validation_data=(val_data, val_targets))
-```
-
-
-## Save and restore
-
-### Weights only
-
-Save and load the weights of a model using `tf.keras.Model.save_weights`:
-
-```python
-# Save weights to a TensorFlow Checkpoint file
-model.save_weights('./my_model')
-
-# Restore the model's state;
-# this requires a model with the same architecture.
-model.load_weights('./my_model')
-```
-
-By default, this saves the model's weights in the
-[TensorFlow checkpoint](./checkpoints.md) file format. Weights can
-also be saved to the Keras HDF5 format (the default for the multi-backend
-implementation of Keras):
-
-```python
-# Save weights to an HDF5 file
-model.save_weights('my_model.h5', save_format='h5')
-
-# Restore the model's state
-model.load_weights('my_model.h5')
-```
-
-
-### Configuration only
-
-A model's configuration can be saved—this serializes the model architecture
-without any weights. A saved configuration can recreate and initialize the same
-model, even without the code that defined the original model. Keras supports
-JSON and YAML serialization formats:
-
-```python
-# Serialize a model to JSON format
-json_string = model.to_json()
-
-# Recreate the model (freshly initialized)
-fresh_model = keras.models.model_from_json(json_string)
-
-# Serializes a model to YAML format
-yaml_string = model.to_yaml()
-
-# Recreate the model
-fresh_model = keras.models.model_from_yaml(yaml_string)
-```
-
-Caution: Subclassed models are not serializable because their architecture is
-defined by the Python code in the body of the `call` method.
-
-
-### Entire model
-
-The entire model can be saved to a file that contains the weight values, the
-model's configuration, and even the optimizer's configuration. This allows you
-to checkpoint a model and resume training later—from the exact same
-state—without access to the original code.
-
-```python
-# Create a trivial model
-model = keras.Sequential([
- keras.layers.Dense(10, activation='softmax', input_shape=(32,)),
- keras.layers.Dense(10, activation='softmax')
-])
-model.compile(optimizer='rmsprop',
- loss='categorical_crossentropy',
- metrics=['accuracy'])
-model.fit(data, targets, batch_size=32, epochs=5)
-
-
-# Save the entire model to an HDF5 file
-model.save('my_model.h5')
-
-# Recreate the exact same model, including weights and optimizer.
-model = keras.models.load_model('my_model.h5')
-```
-
-
-## Eager execution
-
-[Eager execution](./eager.md) is an imperative programming
-environment that evaluates operations immediately. This is not required for
-Keras, but is supported by `tf.keras` and useful for inspecting your program and
-debugging.
-
-All of the `tf.keras` model-building APIs are compatible with eager execution.
-And while the `Sequential` and functional APIs can be used, eager execution
-especially benefits *model subclassing* and building *custom layers*—the APIs
-that require you to write the forward pass as code (instead of the APIs that
-create models by assembling existing layers).
-
-See the [eager execution guide](./eager.md#build_a_model) for
-examples of using Keras models with custom training loops and `tf.GradientTape`.
-
-
-## Distribution
-
-### Estimators
-
-The [Estimators](./estimators.md) API is used for training models
-in distributed environments. This targets industry use cases such as
-distributed training on large datasets that can export a model for production.
-
-A `tf.keras.Model` can be trained with the `tf.estimator` API by converting the
-model to a `tf.estimator.Estimator` object with
-`tf.keras.estimator.model_to_estimator`. See
-[Creating Estimators from Keras models](./estimators.md#creating_estimators_from_keras_models).
-
-```python
-model = keras.Sequential([keras.layers.Dense(10, activation='softmax'),
-                          keras.layers.Dense(10, activation='softmax')])
-
-model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
- loss='categorical_crossentropy',
- metrics=['accuracy'])
-
-estimator = keras.estimator.model_to_estimator(model)
-```
-
-Note: Enable [eager execution](./eager.md) for debugging
-[Estimator input functions](./premade_estimators.md#create_input_functions)
-and inspecting data.
-
-### Multiple GPUs
-
-`tf.keras` models can run on multiple GPUs using
-`tf.contrib.distribute.DistributionStrategy`. This API provides distributed
-training on multiple GPUs with almost no changes to existing code.
-
-Currently, `tf.contrib.distribute.MirroredStrategy` is the only supported
-distribution strategy. `MirroredStrategy` does in-graph replication with
-synchronous training using all-reduce on a single machine. To use
-`DistributionStrategy` with Keras, convert the `tf.keras.Model` to a
-`tf.estimator.Estimator` with `tf.keras.estimator.model_to_estimator`, then
-train the estimator.
-
-The following example distributes a `tf.keras.Model` across multiple GPUs on a
-single machine.
-
-First, define a simple model:
-
-```python
-model = keras.Sequential()
-model.add(keras.layers.Dense(16, activation='relu', input_shape=(10,)))
-model.add(keras.layers.Dense(1, activation='sigmoid'))
-
-optimizer = tf.train.GradientDescentOptimizer(0.2)
-
-model.compile(loss='binary_crossentropy', optimizer=optimizer)
-model.summary()
-```
-
-Define an *input pipeline*. The `input_fn` returns a `tf.data.Dataset` object
-used to distribute the data across multiple devices—with each device processing
-a slice of the input batch.
-
-```python
-def input_fn():
- x = np.random.random((1024, 10))
- y = np.random.randint(2, size=(1024, 1))
- x = tf.cast(x, tf.float32)
- dataset = tf.data.Dataset.from_tensor_slices((x, y))
- dataset = dataset.repeat(10)
- dataset = dataset.batch(32)
- return dataset
-```
-
-Next, create a `tf.estimator.RunConfig` and set the `train_distribute` argument
-to the `tf.contrib.distribute.MirroredStrategy` instance. When creating
-`MirroredStrategy`, you can specify a list of devices or set the `num_gpus`
-argument. The default uses all available GPUs, like the following:
-
-```python
-strategy = tf.contrib.distribute.MirroredStrategy()
-config = tf.estimator.RunConfig(train_distribute=strategy)
-```
-
-Convert the Keras model to a `tf.estimator.Estimator` instance:
-
-```python
-keras_estimator = keras.estimator.model_to_estimator(
- keras_model=model,
- config=config,
- model_dir='/tmp/model_dir')
-```
-
-Finally, train the `Estimator` instance by providing the `input_fn` and `steps`
-arguments:
-
-```python
-keras_estimator.train(input_fn=input_fn, steps=10)
-```
diff --git a/tensorflow/docs_src/guide/leftnav_files b/tensorflow/docs_src/guide/leftnav_files
deleted file mode 100644
index 8e227e0c8f..0000000000
--- a/tensorflow/docs_src/guide/leftnav_files
+++ /dev/null
@@ -1,41 +0,0 @@
-index.md
-
-### High Level APIs
-keras.md
-eager.md
-datasets.md
-estimators.md: Introduction to Estimators
-
-### Estimators
-premade_estimators.md
-checkpoints.md
-feature_columns.md
-datasets_for_estimators.md
-custom_estimators.md
-
-### Accelerators
-using_gpu.md
-using_tpu.md
-
-### Low Level APIs
-low_level_intro.md
-tensors.md
-variables.md
-graphs.md
-saved_model.md
-autograph.md: Control flow
-
-### ML Concepts
-embedding.md
-
-### Debugging
-debugger.md
-
-### TensorBoard
-summaries_and_tensorboard.md: Visualizing Learning
-graph_viz.md: Graphs
-tensorboard_histograms.md: Histograms
-
-### Misc
-version_compat.md
-faq.md
diff --git a/tensorflow/docs_src/guide/low_level_intro.md b/tensorflow/docs_src/guide/low_level_intro.md
deleted file mode 100644
index dc6cb9ee0d..0000000000
--- a/tensorflow/docs_src/guide/low_level_intro.md
+++ /dev/null
@@ -1,604 +0,0 @@
-# Introduction
-
-This guide gets you started programming in the low-level TensorFlow APIs
-(TensorFlow Core), showing you how to:
-
- * Manage your own TensorFlow program (a `tf.Graph`) and TensorFlow
- runtime (a `tf.Session`), instead of relying on Estimators to manage them.
- * Run TensorFlow operations, using a `tf.Session`.
- * Use high level components ([datasets](#datasets), [layers](#layers), and
- [feature_columns](#feature_columns)) in this low level environment.
- * Build your own training loop, instead of using the one
- @{$premade_estimators$provided by Estimators}.
-
-We recommend using the higher level APIs to build models when possible.
-Knowing TensorFlow Core is valuable for the following reasons:
-
- * Experimentation and debugging are both more straightforward
-   when you can use low-level TensorFlow operations directly.
- * It gives you a mental model of how things work internally when
- using the higher level APIs.
-
-## Setup
-
-Before using this guide, @{$install$install TensorFlow}.
-
-To get the most out of this guide, you should know the following:
-
-* How to program in Python.
-* At least a little bit about arrays.
-* Ideally, something about machine learning.
-
-Feel free to launch `python` and follow along with this walkthrough.
-Run the following lines to set up your Python environment:
-
-```python
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-```
-
-## Tensor Values
-
-The central unit of data in TensorFlow is the **tensor**. A tensor consists of a
-set of primitive values shaped into an array of any number of dimensions. A
-tensor's **rank** is its number of dimensions, while its **shape** is a tuple
-of integers specifying the array's length along each dimension. Here are some
-examples of tensor values:
-
-```python
-3. # a rank 0 tensor; a scalar with shape [],
-[1., 2., 3.] # a rank 1 tensor; a vector with shape [3]
-[[1., 2., 3.], [4., 5., 6.]] # a rank 2 tensor; a matrix with shape [2, 3]
-[[[1., 2., 3.]], [[7., 8., 9.]]] # a rank 3 tensor with shape [2, 1, 3]
-```
-
-TensorFlow uses NumPy arrays to represent tensor **values**.
-
-## TensorFlow Core Walkthrough
-
-You might think of TensorFlow Core programs as consisting of two discrete
-sections:
-
-1. Building the computational graph (a `tf.Graph`).
-2. Running the computational graph (using a `tf.Session`).
-
-### Graph
-
-A **computational graph** is a series of TensorFlow operations arranged into a
-graph. The graph is composed of two types of objects.
-
- * `tf.Operation` (or "ops"): The nodes of the graph.
- Operations describe calculations that consume and produce tensors.
- * `tf.Tensor`: The edges in the graph. These represent the values
- that will flow through the graph. Most TensorFlow functions return
- `tf.Tensors`.
-
-Important: `tf.Tensors` do not have values; they are just handles to elements
-in the computation graph.
-
-Let's build a simple computational graph. The most basic operation is a
-constant. The Python function that builds the operation takes a tensor value as
-input. The resulting operation takes no inputs. When run, it outputs the
-value that was passed to the constructor. We can create two floating point
-constants `a` and `b` as follows:
-
-```python
-a = tf.constant(3.0, dtype=tf.float32)
-b = tf.constant(4.0) # also tf.float32 implicitly
-total = a + b
-print(a)
-print(b)
-print(total)
-```
-
-The print statements produce:
-
-```
-Tensor("Const:0", shape=(), dtype=float32)
-Tensor("Const_1:0", shape=(), dtype=float32)
-Tensor("add:0", shape=(), dtype=float32)
-```
-
-Notice that printing the tensors does not output the values `3.0`, `4.0`, and
-`7.0` as you might expect. The above statements only build the computation
-graph. These `tf.Tensor` objects just represent the results of the operations
-that will be run.
-
-Each operation in a graph is given a unique name. This name is independent of
-the names the objects are assigned to in Python. Tensors are named after the
-operation that produces them followed by an output index, as in
-`"add:0"` above.
-
-### TensorBoard
-
-TensorFlow provides a utility called TensorBoard. One of TensorBoard's many
-capabilities is visualizing a computation graph. You can easily do this with
-a few simple commands.
-
-First you save the computation graph to a TensorBoard summary file as
-follows:
-
-```python
-writer = tf.summary.FileWriter('.')
-writer.add_graph(tf.get_default_graph())
-```
-
-This will produce an `event` file in the current directory with a name in the
-following format:
-
-```
-events.out.tfevents.{timestamp}.{hostname}
-```
-
-Now, in a new terminal, launch TensorBoard with the following shell command:
-
-```bash
-tensorboard --logdir .
-```
-
-Then open TensorBoard's [graphs page](http://localhost:6006/#graphs) in your
-browser, and you should see a graph similar to the following:
-
-![TensorBoard screenshot](https://www.tensorflow.org/images/getting_started_add.png)
-
-For more about TensorBoard's graph visualization tools see @{$graph_viz}.
-
-### Session
-
-To evaluate tensors, instantiate a `tf.Session` object, informally known as a
-**session**. A session encapsulates the state of the TensorFlow runtime, and
-runs TensorFlow operations. If a `tf.Graph` is like a `.py` file, a `tf.Session`
-is like the `python` executable.
-
-The following code creates a `tf.Session` object and then invokes its `run`
-method to evaluate the `total` tensor we created above:
-
-```python
-sess = tf.Session()
-print(sess.run(total))
-```
-
-When you request the output of a node with `Session.run`, TensorFlow backtracks
-through the graph and runs all the nodes that provide input to the requested
-output node. So this prints the expected value of 7.0:
-
-```
-7.0
-```
-
-You can pass multiple tensors to `tf.Session.run`. The `run` method
-transparently handles any combination of tuples or dictionaries, as in the
-following example:
-
-```python
-print(sess.run({'ab':(a, b), 'total':total}))
-```
-
-which returns the results in a structure of the same layout:
-
-``` None
-{'total': 7.0, 'ab': (3.0, 4.0)}
-```
-
-During a call to `tf.Session.run` any `tf.Tensor` only has a single value.
-For example, the following code calls `tf.random_uniform` to produce a
-`tf.Tensor` that generates a random 3-element vector (with values in `[0,1)`):
-
-```python
-vec = tf.random_uniform(shape=(3,))
-out1 = vec + 1
-out2 = vec + 2
-print(sess.run(vec))
-print(sess.run(vec))
-print(sess.run((out1, out2)))
-```
-
-The result shows a different random value on each call to `run`, but
-a consistent value during a single `run` (`out1` and `out2` receive the same
-random input):
-
-```
-[ 0.52917576 0.64076328 0.68353939]
-[ 0.66192627 0.89126778 0.06254101]
-(
- array([ 1.88408756, 1.87149239, 1.84057522], dtype=float32),
- array([ 2.88408756, 2.87149239, 2.84057522], dtype=float32)
-)
-```
-
-Some TensorFlow functions return `tf.Operations` instead of `tf.Tensors`.
-The result of calling `run` on an Operation is `None`. You run an operation
-to cause a side effect, not to retrieve a value. Examples of this include the
-[initialization](#initializing_layers) and [training](#training) ops
-demonstrated later.
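-
-A minimal sketch of this behavior:
-
-```python
-w = tf.Variable(0.0)
-init_op = tf.global_variables_initializer()  # a `tf.Operation`
-
-result = sess.run(init_op)
-print(result)  # => None; the op ran only for its side effect.
-```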
-
-### Feeding
-
-As it stands, this graph is not especially interesting because it always
-produces a constant result. A graph can be parameterized to accept external
-inputs, known as **placeholders**. A **placeholder** is a promise to provide a
-value later, like a function argument.
-
-```python
-x = tf.placeholder(tf.float32)
-y = tf.placeholder(tf.float32)
-z = x + y
-```
-
-The preceding three lines are a bit like a function in which we
-define two input parameters (`x` and `y`) and then an operation on them. We can
-evaluate this graph with multiple inputs by using the `feed_dict` argument of
-the `tf.Session.run` method to feed concrete values to the placeholders:
-
-```python
-print(sess.run(z, feed_dict={x: 3, y: 4.5}))
-print(sess.run(z, feed_dict={x: [1, 3], y: [2, 4]}))
-```
-This results in the following output:
-
-```
-7.5
-[ 3. 7.]
-```
-
-Also note that the `feed_dict` argument can be used to overwrite any tensor in
-the graph. The only difference between placeholders and other `tf.Tensors` is
-that placeholders throw an error if no value is fed to them.
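-
-A minimal sketch: feeding a value for a constant tensor overrides it for that
-single `run` call:
-
-```python
-c = tf.constant(3.0)
-d = c + 1.0
-
-print(sess.run(d))                       # => "4.0"
-print(sess.run(d, feed_dict={c: 10.0}))  # => "11.0"
-```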
-
-## Datasets
-
-Placeholders work for simple experiments, but `tf.data` is the
-preferred way to stream data into a model.
-
-To get a runnable `tf.Tensor` from a Dataset, you must first convert it to a
-`tf.data.Iterator`, and then call the Iterator's
-`tf.data.Iterator.get_next` method.
-
-The simplest way to create an Iterator is with the
-`tf.data.Dataset.make_one_shot_iterator` method.
-For example, in the following code the `next_item` tensor will return a row from
-the `my_data` array on each `run` call:
-
-``` python
-my_data = [
- [0, 1,],
- [2, 3,],
- [4, 5,],
- [6, 7,],
-]
-slices = tf.data.Dataset.from_tensor_slices(my_data)
-next_item = slices.make_one_shot_iterator().get_next()
-```
-
-Reaching the end of the data stream causes the `Dataset` to throw a
-`tf.errors.OutOfRangeError`. For example, the following code
-reads the `next_item` until there is no more data to read:
-
-``` python
-while True:
- try:
- print(sess.run(next_item))
- except tf.errors.OutOfRangeError:
- break
-```
-
-If the `Dataset` depends on stateful operations, you may need to
-initialize the iterator before using it, as shown below:
-
-``` python
-r = tf.random_normal([10,3])
-dataset = tf.data.Dataset.from_tensor_slices(r)
-iterator = dataset.make_initializable_iterator()
-next_row = iterator.get_next()
-
-sess.run(iterator.initializer)
-while True:
- try:
- print(sess.run(next_row))
- except tf.errors.OutOfRangeError:
- break
-```
-
-For more details on Datasets and Iterators see: @{$guide/datasets}.
-
-## Layers
-
-A trainable model must modify the values in the graph to get new outputs with
-the same input. The `tf.layers` module is the preferred way to add trainable
-parameters to a graph.
-
-Layers package together both the variables and the operations that act
-on them. For example a
-[densely-connected layer](https://developers.google.com/machine-learning/glossary/#fully_connected_layer)
-performs a weighted sum across all inputs
-for each output and applies an optional
-[activation function](https://developers.google.com/machine-learning/glossary/#activation_function).
-The connection weights and biases are managed by the layer object.
-
-### Creating Layers
-
-The following code creates a `tf.layers.Dense` layer that takes a
-batch of input vectors, and produces a single output value for each. To apply a
-layer to an input, call the layer as if it were a function. For example:
-
-```python
-x = tf.placeholder(tf.float32, shape=[None, 3])
-linear_model = tf.layers.Dense(units=1)
-y = linear_model(x)
-```
-
-The layer inspects its input to determine sizes for its internal variables. So
-here we must set the shape of the `x` placeholder so that the layer can
-build a weight matrix of the correct size.
-
-Now that we have defined the calculation of the output, `y`, there is one more
-detail we need to take care of before we run the calculation.
-
-### Initializing Layers
-
-The layer contains variables that must be **initialized** before they can be
-used. While it is possible to initialize variables individually, you can easily
-initialize all the variables in a TensorFlow graph as follows:
-
-```python
-init = tf.global_variables_initializer()
-sess.run(init)
-```
-
-Important: Calling `tf.global_variables_initializer` only
-creates and returns a handle to a TensorFlow operation. That op
-will initialize all the global variables when we run it with `tf.Session.run`.
-
-Also note that this `global_variables_initializer` only initializes variables
-that existed in the graph when the initializer was created. So the initializer
-should be one of the last things added during graph construction.
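-
-A minimal sketch of the ordering pitfall described above:
-
-```python
-init = tf.global_variables_initializer()
-late_var = tf.Variable(1.0)  # created *after* `init` was built
-
-sess.run(init)
-# sess.run(late_var)  # would raise an error: `late_var` is not covered
-#                     # by `init` and is still uninitialized.
-```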
-
-### Executing Layers
-
-Now that the layer is initialized, we can evaluate the `linear_model`'s output
-tensor as we would any other tensor. For example, the following code:
-
-```python
-print(sess.run(y, {x: [[1, 2, 3],[4, 5, 6]]}))
-```
-
-will generate a two-element output vector such as the following:
-
-```
-[[-3.41378999]
- [-9.14999008]]
-```
-
-### Layer Function Shortcuts
-
-For each layer class (like `tf.layers.Dense`) TensorFlow also supplies a
-shortcut function (like `tf.layers.dense`). The only difference is that the
-shortcut function versions create and run the layer in a single call. For
-example, the following code is equivalent to the earlier version:
-
-```python
-x = tf.placeholder(tf.float32, shape=[None, 3])
-y = tf.layers.dense(x, units=1)
-
-init = tf.global_variables_initializer()
-sess.run(init)
-
-print(sess.run(y, {x: [[1, 2, 3], [4, 5, 6]]}))
-```
-
-While convenient, this approach allows no access to the `tf.layers.Layer`
-object. This makes introspection and debugging more difficult,
-and layer reuse impossible.
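-
-For example, reusing one `tf.layers.Dense` object applies the *same* weights
-to two inputs, something the shortcut functions cannot express:
-
-```python
-x1 = tf.placeholder(tf.float32, shape=[None, 3])
-x2 = tf.placeholder(tf.float32, shape=[None, 3])
-
-shared = tf.layers.Dense(units=1)
-y1 = shared(x1)  # Creates the layer's variables.
-y2 = shared(x2)  # Reuses the same kernel and bias.
-```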
-
-## Feature columns
-
-The easiest way to experiment with feature columns is using the
-`tf.feature_column.input_layer` function. This function only accepts
-@{$feature_columns$dense columns} as inputs, so to view the result
-of a categorical column you must wrap it in an
-`tf.feature_column.indicator_column`. For example:
-
-``` python
-features = {
- 'sales' : [[5], [10], [8], [9]],
- 'department': ['sports', 'sports', 'gardening', 'gardening']}
-
-department_column = tf.feature_column.categorical_column_with_vocabulary_list(
- 'department', ['sports', 'gardening'])
-department_column = tf.feature_column.indicator_column(department_column)
-
-columns = [
- tf.feature_column.numeric_column('sales'),
- department_column
-]
-
-inputs = tf.feature_column.input_layer(features, columns)
-```
-
-Running the `inputs` tensor will parse the `features` into a batch of vectors.
-
-Feature columns can have internal state, like layers, so they often need to be
-initialized. Categorical columns use `tf.contrib.lookup`
-internally, and these require a separate initialization op,
-`tf.tables_initializer`.
-
-``` python
-var_init = tf.global_variables_initializer()
-table_init = tf.tables_initializer()
-sess = tf.Session()
-sess.run((var_init, table_init))
-```
-
-Once the internal state has been initialized you can run `inputs` like any
-other `tf.Tensor`:
-
-```python
-print(sess.run(inputs))
-```
-
-This shows how the feature columns have packed the input vectors, with the
-one-hot "department" as the first two indices and "sales" as the third.
-
-```None
-[[ 1. 0. 5.]
- [ 1. 0. 10.]
- [ 0. 1. 8.]
- [ 0. 1. 9.]]
-```
-
-## Training
-
-Now that you're familiar with the basics of core TensorFlow, let's train a
-small regression model manually.
-
-### Define the data
-
-First let's define some inputs, `x`, and the expected output for each input,
-`y_true`:
-
-```python
-x = tf.constant([[1], [2], [3], [4]], dtype=tf.float32)
-y_true = tf.constant([[0], [-1], [-2], [-3]], dtype=tf.float32)
-```
-
-### Define the model
-
-Next, build a simple linear model, with 1 output:
-
-``` python
-linear_model = tf.layers.Dense(units=1)
-
-y_pred = linear_model(x)
-```
-
-You can evaluate the predictions as follows:
-
-``` python
-sess = tf.Session()
-init = tf.global_variables_initializer()
-sess.run(init)
-
-print(sess.run(y_pred))
-```
-
-The model hasn't yet been trained, so the four "predicted" values aren't very
-good. Here's what we got; your own output will almost certainly differ:
-
-``` None
-[[ 0.02631879]
- [ 0.05263758]
- [ 0.07895637]
- [ 0.10527515]]
-```
-
-### Loss
-
-To optimize a model, you first need to define the loss. We'll use the mean
-square error, a standard loss for regression problems.
-
-While you could do this manually with lower level math operations,
-the `tf.losses` module provides a set of common loss functions. You can use it
-to calculate the mean square error as follows:
-
-``` python
-loss = tf.losses.mean_squared_error(labels=y_true, predictions=y_pred)
-
-print(sess.run(loss))
-```
-This will produce a loss value, something like:
-
-``` None
-2.23962
-```
-
-### Training
-
-TensorFlow provides
-[**optimizers**](https://developers.google.com/machine-learning/glossary/#optimizer)
-implementing standard optimization algorithms. These are implemented as
-sub-classes of `tf.train.Optimizer`. They incrementally change each
-variable in order to minimize the loss. The simplest optimization algorithm is
-[**gradient descent**](https://developers.google.com/machine-learning/glossary/#gradient_descent),
-implemented by `tf.train.GradientDescentOptimizer`. It modifies each
-variable according to the magnitude of the derivative of loss with respect to
-that variable. For example:
-
-```python
-optimizer = tf.train.GradientDescentOptimizer(0.01)
-train = optimizer.minimize(loss)
-```
-
-This code builds all the graph components necessary for the optimization, and
-returns a training operation. When run, the training op will update variables
-in the graph. You might run it as follows:
-
-```python
-for i in range(100):
- _, loss_value = sess.run((train, loss))
- print(loss_value)
-```
-
-Since `train` is an op, not a tensor, it doesn't return a value when run.
-To see the progression of the loss during training, we run the loss tensor at
-the same time, producing output like the following:
-
-``` None
-1.35659
-1.00412
-0.759167
-0.588829
-0.470264
-0.387626
-0.329918
-0.289511
-0.261112
-0.241046
-...
-```
-
-### Complete program
-
-```python
-x = tf.constant([[1], [2], [3], [4]], dtype=tf.float32)
-y_true = tf.constant([[0], [-1], [-2], [-3]], dtype=tf.float32)
-
-linear_model = tf.layers.Dense(units=1)
-
-y_pred = linear_model(x)
-loss = tf.losses.mean_squared_error(labels=y_true, predictions=y_pred)
-
-optimizer = tf.train.GradientDescentOptimizer(0.01)
-train = optimizer.minimize(loss)
-
-init = tf.global_variables_initializer()
-
-sess = tf.Session()
-sess.run(init)
-for i in range(100):
- _, loss_value = sess.run((train, loss))
- print(loss_value)
-
-print(sess.run(y_pred))
-```
-
-## Next steps
-
-To learn more about building models with TensorFlow consider the following:
-
-* @{$custom_estimators$Custom Estimators}, to learn how to build
- customized models with TensorFlow. Your knowledge of TensorFlow Core will
- help you understand and debug your own models.
-
-If you want to learn more about the inner workings of TensorFlow consider the
-following documents, which go into more depth on many of the topics discussed
-here:
-
-* @{$graphs}
-* @{$tensors}
-* @{$variables}
-
-
diff --git a/tensorflow/docs_src/guide/premade_estimators.md b/tensorflow/docs_src/guide/premade_estimators.md
deleted file mode 100644
index dc38f0c1d3..0000000000
--- a/tensorflow/docs_src/guide/premade_estimators.md
+++ /dev/null
@@ -1,430 +0,0 @@
-# Premade Estimators
-
-This document introduces the TensorFlow programming environment and shows you
-how to solve the Iris classification problem in TensorFlow.
-
-## Prerequisites
-
-Prior to using the sample code in this document, you'll need to do the
-following:
-
-* @{$install$Install TensorFlow}.
-* If you installed TensorFlow with virtualenv or Anaconda, activate your
- TensorFlow environment.
-* Install or upgrade pandas by issuing the following command:
-
- pip install pandas
-
-## Getting the sample code
-
-Take the following steps to get the sample code we'll be going through:
-
-1. Clone the TensorFlow Models repository from GitHub by entering the following
- command:
-
- git clone https://github.com/tensorflow/models
-
-1. Change directory into the cloned repository, to the location containing the
-   examples used in this document:
-
- cd models/samples/core/get_started/
-
-The program described in this document is
-[`premade_estimator.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/premade_estimator.py).
-This program uses
-[`iris_data.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py)
-to fetch its training data.
-
-### Running the program
-
-You run TensorFlow programs as you would run any Python program. For example:
-
-``` bash
-python premade_estimator.py
-```
-
-The program should output training logs followed by some predictions against
-the test set. For example, the first line in the following output shows that
-the model thinks there is a 99.6% chance that the first example in the test
-set is a Setosa. Since the expected label for that example is Setosa, this
-appears to be a good prediction.
-
-``` None
-...
-Prediction is "Setosa" (99.6%), expected "Setosa"
-
-Prediction is "Versicolor" (99.8%), expected "Versicolor"
-
-Prediction is "Virginica" (97.9%), expected "Virginica"
-```
-
-If the program generates errors instead of answers, ask yourself the following
-questions:
-
-* Did you install TensorFlow properly?
-* Are you using the correct version of TensorFlow?
-* Did you activate the environment you installed TensorFlow in? (This is
- only relevant in certain installation mechanisms.)
-
-## The programming stack
-
-Before getting into the details of the program itself, let's investigate the
-programming environment. As the following illustration shows, TensorFlow
-provides a programming stack consisting of multiple API layers:
-
-<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="../images/tensorflow_programming_environment.png">
-</div>
-
-We strongly recommend writing TensorFlow programs with the following APIs:
-
-* @{$guide/estimators$Estimators}, which represent a complete model.
- The Estimator API provides methods to train the model, to judge the model's
- accuracy, and to generate predictions.
-* @{$guide/datasets_for_estimators}, which build a data input
- pipeline. The Dataset API has methods to load and manipulate data, and feed
- it into your model. The Dataset API meshes well with the Estimators API.
-
-## Classifying irises: an overview
-
-The sample program in this document builds and tests a model that
-classifies Iris flowers into three different species based on the size of their
-[sepals](https://en.wikipedia.org/wiki/Sepal) and
-[petals](https://en.wikipedia.org/wiki/Petal).
-
-<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%"
- alt="Petal geometry compared for three iris species: Iris setosa, Iris virginica, and Iris versicolor"
- src="../images/iris_three_species.jpg">
-</div>
-
-**From left to right,
-[*Iris setosa*](https://commons.wikimedia.org/w/index.php?curid=170298) (by
-[Radomil](https://commons.wikimedia.org/wiki/User:Radomil), CC BY-SA 3.0),
-[*Iris versicolor*](https://commons.wikimedia.org/w/index.php?curid=248095) (by
-[Dlanglois](https://commons.wikimedia.org/wiki/User:Dlanglois), CC BY-SA 3.0),
-and [*Iris virginica*](https://www.flickr.com/photos/33397993@N05/3352169862)
-(by [Frank Mayfield](https://www.flickr.com/photos/33397993@N05), CC BY-SA
-2.0).**
-
-### The data set
-
-The Iris data set contains four features and one
-[label](https://developers.google.com/machine-learning/glossary/#label).
-The four features identify the following botanical characteristics of
-individual Iris flowers:
-
-* sepal length
-* sepal width
-* petal length
-* petal width
-
-Our model will represent these features as `float32` numerical data.
-
-The label identifies the Iris species, which must be one of the following:
-
-* Iris setosa (0)
-* Iris versicolor (1)
-* Iris virginica (2)
-
-Our model will represent the label as `int32` categorical data.
-
-The following table shows three examples in the data set:
-
-|sepal length | sepal width | petal length | petal width| species (label) |
-|------------:|------------:|-------------:|-----------:|:---------------:|
-| 5.1 | 3.3 | 1.7 | 0.5 | 0 (Setosa) |
-| 5.0         | 2.3         | 3.3          | 1.0        | 1 (Versicolor)  |
-| 6.4         | 2.8         | 5.6          | 2.2        | 2 (Virginica)   |
-
-### The algorithm
-
-The program trains a Deep Neural Network classifier model with the following
-topology:
-
-* 2 hidden layers.
-* Each hidden layer contains 10 nodes.
-
-The following figure illustrates the features, hidden layers, and predictions
-(not all of the nodes in the hidden layers are shown):
-
-<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%"
- alt="A diagram of the network architecture: Inputs, 2 hidden layers, and outputs"
- src="../images/custom_estimators/full_network.png">
-</div>
-
-### Inference
-
-Running the trained model on an unlabeled example yields three predictions:
-the likelihood that the flower belongs to each of the three Iris species. The
-sum of those output predictions will be 1.0. For example, the prediction on an
-unlabeled example might be something like the following:
-
-* 0.03 for Iris Setosa
-* 0.95 for Iris Versicolor
-* 0.02 for Iris Virginica
-
-The preceding prediction indicates a 95% probability that the given unlabeled
-example is an Iris Versicolor.
-
-## Overview of programming with Estimators
-
-An Estimator is TensorFlow's high-level representation of a complete model. It
-handles the details of initialization, logging, saving and restoring, and many
-other features so you can concentrate on your model. For more details see
-@{$guide/estimators}.
-
-An Estimator is any class derived from `tf.estimator.Estimator`. TensorFlow
-provides a collection of pre-made Estimators in the `tf.estimator` module
-(for example, `tf.estimator.LinearRegressor`) that implement common ML
-algorithms. Beyond those, you may write your own
-@{$custom_estimators$custom Estimators}.
-We recommend using pre-made Estimators when just getting started.
-
-To write a TensorFlow program based on pre-made Estimators, you must perform the
-following tasks:
-
-* Create one or more input functions.
-* Define the model's feature columns.
-* Instantiate an Estimator, specifying the feature columns and various
- hyperparameters.
-* Call one or more methods on the Estimator object, passing the appropriate
- input function as the source of the data.
-
-Let's see how those tasks are implemented for Iris classification.
-
-## Create input functions
-
-You must create input functions to supply data for training,
-evaluating, and prediction.
-
-An **input function** is a function that returns a `tf.data.Dataset` object
-which outputs the following two-element tuple:
-
-* [`features`](https://developers.google.com/machine-learning/glossary/#feature) - A Python dictionary in which:
-    * Each key is the name of a feature.
-    * Each value is an array containing all of that feature's values.
-* `labels` - An array containing the values of the
- [label](https://developers.google.com/machine-learning/glossary/#label) for
- every example.
-
-Just to demonstrate the format of the input function, here's a simple
-implementation:
-
-```python
-import numpy as np
-
-def input_evaluation_set():
-    features = {'SepalLength': np.array([6.4, 5.0]),
-                'SepalWidth': np.array([2.8, 2.3]),
-                'PetalLength': np.array([5.6, 3.3]),
-                'PetalWidth': np.array([2.2, 1.0])}
-    labels = np.array([2, 1])
-    return features, labels
-```
-
-Your input function may generate the `features` dictionary and `labels` array any
-way you like. However, we recommend using TensorFlow's Dataset API, which can
-parse all sorts of data. At a high level, the Dataset API consists of the
-following classes:
-
-<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%"
- alt="A diagram showing subclasses of the Dataset class"
- src="../images/dataset_classes.png">
-</div>
-
-Where the individual members are:
-
-* `Dataset` - Base class containing methods to create and transform
- datasets. Also allows you to initialize a dataset from data in memory, or from
- a Python generator.
-* `TextLineDataset` - Reads lines from text files.
-* `TFRecordDataset` - Reads records from TFRecord files.
-* `FixedLengthRecordDataset` - Reads fixed size records from binary files.
-* `Iterator` - Provides a way to access one data set element at a time.
-
-The Dataset API can handle a lot of common cases for you. For example,
-using the Dataset API, you can easily read in records from a large collection
-of files in parallel and join them into a single stream.
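-
-As a hedged illustration of that pattern (the file pattern below is
-hypothetical), reading a set of TFRecord files in parallel might look like:
-
-```python
-# Read records from many TFRecord files and interleave them into one stream.
-filenames = tf.data.Dataset.list_files("/path/to/data-*.tfrecord")
-dataset = filenames.interleave(tf.data.TFRecordDataset, cycle_length=4)
-```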
-
-To keep things simple in this example, we are going to load the data with
-[pandas](https://pandas.pydata.org/), and build our input pipeline from this
-in-memory data.
-
-Here is the input function used for training in this program, which is available
-in [`iris_data.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py):
-
-```python
-def train_input_fn(features, labels, batch_size):
-    """An input function for training"""
-    # Convert the inputs to a Dataset.
-    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
-
-    # Shuffle, repeat, and batch the examples.
-    return dataset.shuffle(1000).repeat().batch(batch_size)
-```
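-
-The evaluation and prediction paths use a companion `iris_data.eval_input_fn`.
-The exact implementation lives in `iris_data.py`; a minimal sketch of such a
-function might look like the following:
-
-```python
-def eval_input_fn(features, labels=None, batch_size=None):
-    """An input function for evaluation or prediction."""
-    features = dict(features)
-    # With no labels (the prediction case), pass features alone.
-    inputs = features if labels is None else (features, labels)
-    dataset = tf.data.Dataset.from_tensor_slices(inputs)
-
-    # Batch the examples; no shuffling or repeating for evaluation.
-    return dataset.batch(batch_size)
-```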
-
-## Define the feature columns
-
-A [**feature column**](https://developers.google.com/machine-learning/glossary/#feature_columns)
-is an object describing how the model should use raw input data from the
-features dictionary. When you build an Estimator model, you pass it a list of
-feature columns that describes each of the features you want the model to use.
-The `tf.feature_column` module provides many options for representing data
-to the model.
-
-For Iris, the 4 raw features are numeric values, so we'll build a list of
-feature columns to tell the Estimator model to represent each of the four
-features as 32-bit floating-point values. Therefore, the code to create the
-feature columns is:
-
-```python
-# Feature columns describe how to use the input.
-my_feature_columns = []
-for key in train_x.keys():
-    my_feature_columns.append(tf.feature_column.numeric_column(key=key))
-```
-
-Feature columns can be far more sophisticated than those we're showing here. We
-detail feature columns @{$feature_columns$later on} in our Getting
-Started guide.
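-
-As one small illustration (the bucket boundaries here are made up, not tuned
-for the Iris data), a numeric feature can be bucketized into ranges:
-
-```python
-# A sketch: represent sepal length as three coarse buckets.
-sepal_length = tf.feature_column.numeric_column('SepalLength')
-sepal_length_buckets = tf.feature_column.bucketized_column(
-    sepal_length, boundaries=[5.0, 6.0])
-```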
-
-Now that we have the description of how we want the model to represent the raw
-features, we can build the estimator.
-
-
-## Instantiate an estimator
-
-The Iris problem is a classic classification problem. Fortunately, TensorFlow
-provides several pre-made classifier Estimators, including:
-
-* `tf.estimator.DNNClassifier` for deep models that perform multi-class
- classification.
-* `tf.estimator.DNNLinearCombinedClassifier` for wide & deep models.
-* `tf.estimator.LinearClassifier` for classifiers based on linear models.
-
-For the Iris problem, `tf.estimator.DNNClassifier` seems like the best choice.
-Here's how we instantiate this Estimator:
-
-```python
-# Build a DNN with 2 hidden layers and 10 nodes in each hidden layer.
-classifier = tf.estimator.DNNClassifier(
-    feature_columns=my_feature_columns,
-    # Two hidden layers of 10 nodes each.
-    hidden_units=[10, 10],
-    # The model must choose between 3 classes.
-    n_classes=3)
-```
-
-## Train, Evaluate, and Predict
-
-Now that we have an Estimator object, we can call methods to do the following:
-
-* Train the model.
-* Evaluate the trained model.
-* Use the trained model to make predictions.
-
-### Train the model
-
-Train the model by calling the Estimator's `train` method as follows:
-
-```python
-# Train the Model.
-classifier.train(
-    input_fn=lambda: iris_data.train_input_fn(train_x, train_y, args.batch_size),
-    steps=args.train_steps)
-```
-
-Here we wrap up our `input_fn` call in a
-[`lambda`](https://docs.python.org/3/tutorial/controlflow.html)
-to capture the arguments while providing an input function that takes no
-arguments, as expected by the Estimator. The `steps` argument tells the method
-to stop training after a number of training steps.
-
-### Evaluate the trained model
-
-Now that the model has been trained, we can get some statistics on its
-performance. The following code block evaluates the accuracy of the trained
-model on the test data:
-
-```python
-# Evaluate the model.
-eval_result = classifier.evaluate(
-    input_fn=lambda: iris_data.eval_input_fn(test_x, test_y, args.batch_size))
-
-print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
-```
-
-Unlike our call to the `train` method, we did not pass the `steps`
-argument to `evaluate`. Our `eval_input_fn` only yields a single
-[epoch](https://developers.google.com/machine-learning/glossary/#epoch) of data.
-
-Running this code yields the following output (or something similar):
-
-```none
-Test set accuracy: 0.967
-```
-
-### Making predictions (inferring) from the trained model
-
-We now have a trained model that produces good evaluation results.
-We can now use the trained model to predict the species of an Iris flower
-based on some unlabeled measurements. As with training and evaluation, we make
-predictions using a single function call:
-
-```python
-# Generate predictions from the model
-expected = ['Setosa', 'Versicolor', 'Virginica']
-predict_x = {
-    'SepalLength': [5.1, 5.9, 6.9],
-    'SepalWidth': [3.3, 3.0, 3.1],
-    'PetalLength': [1.7, 4.2, 5.4],
-    'PetalWidth': [0.5, 1.5, 2.1],
-}
-
-predictions = classifier.predict(
-    input_fn=lambda: iris_data.eval_input_fn(predict_x,
-                                             batch_size=args.batch_size))
-```
-
-The `predict` method returns a Python iterable, yielding a dictionary of
-prediction results for each example. The following code prints a few
-predictions and their probabilities:
-
-
-```python
-template = ('\nPrediction is "{}" ({:.1f}%), expected "{}"')
-
-for pred_dict, expec in zip(predictions, expected):
-    class_id = pred_dict['class_ids'][0]
-    probability = pred_dict['probabilities'][class_id]
-
-    print(template.format(iris_data.SPECIES[class_id],
-                          100 * probability, expec))
-```
-
-Running the preceding code yields the following output:
-
-```none
-...
-Prediction is "Setosa" (99.6%), expected "Setosa"
-
-Prediction is "Versicolor" (99.8%), expected "Versicolor"
-
-Prediction is "Virginica" (97.9%), expected "Virginica"
-```
-
-
-## Summary
-
-Pre-made Estimators are an effective way to quickly create standard models.
-
-Now that you've gotten started writing TensorFlow programs, consider the
-following material:
-
-* @{$checkpoints$Checkpoints} to learn how to save and restore models.
-* @{$guide/datasets_for_estimators} to learn more about importing
- data into your model.
-* @{$custom_estimators$Creating Custom Estimators} to learn how to
- write your own Estimator, customized for a particular problem.
diff --git a/tensorflow/docs_src/guide/saved_model.md b/tensorflow/docs_src/guide/saved_model.md
deleted file mode 100644
index c260da7966..0000000000
--- a/tensorflow/docs_src/guide/saved_model.md
+++ /dev/null
@@ -1,999 +0,0 @@
-# Save and Restore
-
-The `tf.train.Saver` class provides methods to save and restore models. The
-`tf.saved_model.simple_save` function is an easy way to build a
-`tf.saved_model` suitable for serving. [Estimators](./estimators)
-automatically save and restore variables in the `model_dir`.
-
-## Save and restore variables
-
-TensorFlow @{$variables} are the best way to represent shared, persistent state
-manipulated by your program. The `tf.train.Saver` constructor adds `save` and
-`restore` ops to the graph for all, or a specified list, of the variables in the
-graph. The `Saver` object provides methods to run these ops, specifying paths
-for the checkpoint files to write to or read from.
-
-`Saver` restores all variables already defined in your model. If you're
-loading a model without knowing how to build its graph (for example, if you're
-writing a generic program to load models), then read the
-[Overview of saving and restoring models](#models) section
-later in this document.
-
-TensorFlow saves variables in binary *checkpoint files* that map variable
-names to tensor values.
-
-Caution: TensorFlow model files are code. Be careful with untrusted code.
-See [Using TensorFlow Securely](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md)
-for details.
-
-### Save variables
-
-Create a `Saver` with `tf.train.Saver()` to manage all variables in the
-model. For example, the following snippet demonstrates how to call the
-`tf.train.Saver.save` method to save variables to checkpoint files:
-
-```python
-# Create some variables.
-v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer)
-v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer)
-
-inc_v1 = v1.assign(v1+1)
-dec_v2 = v2.assign(v2-1)
-
-# Add an op to initialize the variables.
-init_op = tf.global_variables_initializer()
-
-# Add ops to save and restore all the variables.
-saver = tf.train.Saver()
-
-# Later, launch the model, initialize the variables, do some work, and save the
-# variables to disk.
-with tf.Session() as sess:
-  sess.run(init_op)
-  # Do some work with the model.
-  inc_v1.op.run()
-  dec_v2.op.run()
-  # Save the variables to disk.
-  save_path = saver.save(sess, "/tmp/model.ckpt")
-  print("Model saved in path: %s" % save_path)
-```
-
-### Restore variables
-
-The `tf.train.Saver` object not only saves variables to checkpoint files, it
-also restores variables. Note that when you restore variables you do not have
-to initialize them beforehand. For example, the following snippet demonstrates
-how to call the `tf.train.Saver.restore` method to restore variables from the
-checkpoint files:
-
-```python
-tf.reset_default_graph()
-
-# Create some variables.
-v1 = tf.get_variable("v1", shape=[3])
-v2 = tf.get_variable("v2", shape=[5])
-
-# Add ops to save and restore all the variables.
-saver = tf.train.Saver()
-
-# Later, launch the model, use the saver to restore variables from disk, and
-# do some work with the model.
-with tf.Session() as sess:
-  # Restore variables from disk.
-  saver.restore(sess, "/tmp/model.ckpt")
-  print("Model restored.")
-  # Check the values of the variables
-  print("v1 : %s" % v1.eval())
-  print("v2 : %s" % v2.eval())
-```
-
-Note: There is no physical file called `/tmp/model.ckpt`. It is the *prefix* of
-the filenames created for the checkpoint. Users interact with the prefix
-rather than with the individual checkpoint files.
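-
-To see what is actually written for the prefix, you can list the matching
-files; the names below are typical for a single-shard checkpoint:
-
-```python
-import glob
-
-# List the files created for the checkpoint prefix "/tmp/model.ckpt".
-print(glob.glob("/tmp/model.ckpt*"))
-# Typically something like:
-# ['/tmp/model.ckpt.data-00000-of-00001',
-#  '/tmp/model.ckpt.index',
-#  '/tmp/model.ckpt.meta']
-```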
-
-### Choose variables to save and restore
-
-If you do not pass any arguments to `tf.train.Saver()`, the saver handles all
-variables in the graph. Each variable is saved under the name that was passed
-when the variable was created.
-
-It is sometimes useful to explicitly specify names for variables in the
-checkpoint files. For example, you may have trained a model with a variable
-named `"weights"` whose value you want to restore into a variable named
-`"params"`.
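-
-A sketch of that rename (assuming `params` is a variable already defined in
-the current graph):
-
-```python
-# Restore the checkpoint value saved under "weights" into `params`.
-saver = tf.train.Saver({"weights": params})
-```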
-
-It is also sometimes useful to only save or restore a subset of the variables
-used by a model. For example, you may have trained a neural net with five
-layers, and you now want to train a new model with six layers that reuses the
-existing weights of the five trained layers. You can use the saver to restore
-the weights of just the first five layers.
-
-You can easily specify the names and variables to save or load by passing to the
-`tf.train.Saver()` constructor either of the following:
-
-* A list of variables (which will be stored under their own names).
-* A Python dictionary in which keys are the names to use and the values are the
-variables to manage.
-
-Continuing from the save/restore examples shown earlier:
-
-```python
-tf.reset_default_graph()
-# Create some variables.
-v1 = tf.get_variable("v1", [3], initializer=tf.zeros_initializer)
-v2 = tf.get_variable("v2", [5], initializer=tf.zeros_initializer)
-
-# Add ops to save and restore only `v2` using the name "v2"
-saver = tf.train.Saver({"v2": v2})
-
-# Use the saver object normally after that.
-with tf.Session() as sess:
-  # Initialize v1 since the saver will not.
-  v1.initializer.run()
-  saver.restore(sess, "/tmp/model.ckpt")
-
-  print("v1 : %s" % v1.eval())
-  print("v2 : %s" % v2.eval())
-```
-
-Notes:
-
-* You can create as many `Saver` objects as you want if you need to save and
- restore different subsets of the model variables. The same variable can be
- listed in multiple saver objects; its value is only changed when the
- `Saver.restore()` method is run.
-
-* If you only restore a subset of the model variables at the start of a
- session, you have to run an initialize op for the other variables. See
- `tf.variables_initializer` for more information.
-
-* To inspect the variables in a checkpoint, you can use the
- [`inspect_checkpoint`](https://www.tensorflow.org/code/tensorflow/python/tools/inspect_checkpoint.py)
- library, particularly the `print_tensors_in_checkpoint_file` function.
-
-* By default, `Saver` uses the value of the `tf.Variable.name` property
- for each variable. However, when you create a `Saver` object, you may
- optionally choose names for the variables in the checkpoint files.
-
-
-### Inspect variables in a checkpoint
-
-We can quickly inspect variables in a checkpoint with the
-[`inspect_checkpoint`](https://www.tensorflow.org/code/tensorflow/python/tools/inspect_checkpoint.py) library.
-
-Continuing from the save/restore examples shown earlier:
-
-```python
-# import the inspect_checkpoint library
-from tensorflow.python.tools import inspect_checkpoint as chkp
-
-# print all tensors in checkpoint file
-chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='', all_tensors=True)
-
-# tensor_name: v1
-# [ 1. 1. 1.]
-# tensor_name: v2
-# [-1. -1. -1. -1. -1.]
-
-# print only tensor v1 in checkpoint file
-chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='v1', all_tensors=False)
-
-# tensor_name: v1
-# [ 1. 1. 1.]
-
-# print only tensor v2 in checkpoint file
-chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='v2', all_tensors=False)
-
-# tensor_name: v2
-# [-1. -1. -1. -1. -1.]
-```
-
-
-<a name="models"></a>
-## Save and restore models
-
-Use `SavedModel` to save and load your model—variables, the graph, and the
-graph's metadata. This is a language-neutral, recoverable, hermetic
-serialization format that enables higher-level systems and tools to produce,
-consume, and transform TensorFlow models. TensorFlow provides several ways to
-interact with `SavedModel`, including the `tf.saved_model` APIs,
-`tf.estimator.Estimator`, and a command-line interface.
-
-
-## Build and load a SavedModel
-
-### Simple save
-
-The easiest way to create a `SavedModel` is to use the `tf.saved_model.simple_save`
-function:
-
-```python
-tf.saved_model.simple_save(session,
-                           export_dir,
-                           inputs={"x": x, "y": y},
-                           outputs={"z": z})
-```
-
-This configures the `SavedModel` so it can be loaded by
-[TensorFlow serving](/serving/serving_basic) and supports the
-[Predict API](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/predict.proto).
-To access the classify, regress, or multi-inference APIs, use the manual
-`SavedModel` builder APIs or a `tf.estimator.Estimator`.
-
-### Manually build a SavedModel
-
-If your use case isn't covered by `tf.saved_model.simple_save`, use the manual
-`tf.saved_model.builder` to create a `SavedModel`.
-
-The `tf.saved_model.builder.SavedModelBuilder` class provides functionality to
-save multiple `MetaGraphDef`s. A **MetaGraph** is a dataflow graph, plus
-its associated variables, assets, and signatures. A **`MetaGraphDef`**
-is the protocol buffer representation of a MetaGraph. A **signature** is
-the set of inputs to and outputs from a graph.
-
-If assets need to be saved and written or copied to disk, they can be provided
-when the first `MetaGraphDef` is added. If multiple `MetaGraphDef`s are
-associated with an asset of the same name, only the first version is retained.
-
-Each `MetaGraphDef` added to the SavedModel must be annotated with
-user-specified tags. The tags provide a means to identify the specific
-`MetaGraphDef` to load and restore, along with the shared set of variables
-and assets. These tags
-typically annotate a `MetaGraphDef` with its functionality (for example,
-serving or training), and optionally with hardware-specific aspects (for
-example, GPU).
-
-For example, the following code suggests a typical way to use
-`SavedModelBuilder` to build a SavedModel:
-
-```python
-export_dir = ...
-...
-builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
-with tf.Session(graph=tf.Graph()) as sess:
-  ...
-  builder.add_meta_graph_and_variables(
-      sess,
-      [tf.saved_model.tag_constants.TRAINING],
-      signature_def_map=foo_signatures,
-      assets_collection=foo_assets,
-      strip_default_attrs=True)
-...
-# Add a second MetaGraphDef for inference.
-with tf.Session(graph=tf.Graph()) as sess:
-  ...
-  builder.add_meta_graph([tf.saved_model.tag_constants.SERVING],
-                         strip_default_attrs=True)
-...
-builder.save()
-```
-
-<a name="forward_compatibility"></a>
-#### Forward compatibility via `strip_default_attrs=True`
-
-Following the guidance below gives you forward compatibility only if the set of
-Ops has not changed.
-
-The `tf.saved_model.builder.SavedModelBuilder` class allows
-users to control whether default-valued attributes must be stripped from the
-@{$extend/tool_developers#nodes$`NodeDefs`}
-while adding a meta graph to the SavedModel bundle. Both
-`tf.saved_model.builder.SavedModelBuilder.add_meta_graph_and_variables`
-and `tf.saved_model.builder.SavedModelBuilder.add_meta_graph`
-methods accept a Boolean flag `strip_default_attrs` that controls this behavior.
-
-If `strip_default_attrs` is `False`, the exported `tf.MetaGraphDef` will have
-the default valued attributes in all its `tf.NodeDef` instances.
-This can break forward compatibility with a sequence of events such as the
-following:
-
-* An existing Op (`Foo`) is updated to include a new attribute (`T`) with a
- default (`bool`) at version 101.
-* A model producer such as a "trainer binary" picks up this change (version 101)
- to the `OpDef` and re-exports an existing model that uses Op `Foo`.
-* A model consumer (such as [Tensorflow Serving](/serving)) running an older
- binary (version 100) doesn't have attribute `T` for Op `Foo`, but tries to
- import this model. The model consumer doesn't recognize attribute `T` in a
- `NodeDef` that uses Op `Foo` and therefore fails to load the model.
-* By setting `strip_default_attrs` to `True`, the model producers can strip away
- any default valued attributes in the `NodeDefs`. This helps ensure that newly
- added attributes with defaults don't cause older model consumers to fail
- loading models regenerated with newer training binaries.
-
-See [compatibility guidance](./version_compat.md)
-for more information.
-
-### Loading a SavedModel in Python
-
-The Python version of the SavedModel loader, `tf.saved_model.loader`,
-provides load and restore capability for a SavedModel. The `load` operation
-requires the following information:
-requires the following information:
-
-* The session in which to restore the graph definition and variables.
-* The tags used to identify the MetaGraphDef to load.
-* The location (directory) of the SavedModel.
-
-Upon a load, the subset of variables, assets, and signatures supplied as part of
-the specific MetaGraphDef will be restored into the supplied session.
-
-
-```python
-export_dir = ...
-...
-with tf.Session(graph=tf.Graph()) as sess:
-  tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING],
-                             export_dir)
-  ...
-```
-
-
-### Load a SavedModel in C++
-
-The C++ version of the SavedModel
-[loader](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/loader.h)
-provides an API to load a SavedModel from a path, while allowing
-`SessionOptions` and `RunOptions`.
-You have to specify the tags associated with the graph to be loaded.
-The loaded version of SavedModel is referred to as `SavedModelBundle`
-and contains the MetaGraphDef and the session within which it is loaded.
-
-```c++
-const string export_dir = ...
-SavedModelBundle bundle;
-...
-LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagTrain},
-               &bundle);
-```
-
-### Load and serve a SavedModel in TensorFlow serving
-
-You can easily load and serve a SavedModel with the TensorFlow Serving Model
-Server binary. See [instructions](https://www.tensorflow.org/serving/setup#installing_using_apt-get)
-on how to install the server, or build it if you wish.
-
-Once you have the Model Server, run it with:
-
-```
-tensorflow_model_server --port=port_number --model_name=your_model_name --model_base_path=your_model_base_path
-```
-
-Set the `--port` and `--model_name` flags to values of your choosing. The
-`--model_base_path` flag expects a base directory, with each version of your
-model residing in a numerically named subdirectory. For example, suppose the
-base directory is `/tmp/model`. If you have only one version of your model,
-store it in `/tmp/model/0001`; if you have two versions, store the second in
-`/tmp/model/0002`, and so on. Set `--model_base_path` to the base directory
-(`/tmp/model`, in this example). TensorFlow Model Server will serve the model
-in the highest numbered subdirectory of that base directory.
-
-### Standard constants
-
-SavedModel offers the flexibility to build and load TensorFlow graphs for a
-variety of use-cases. For the most common use-cases, SavedModel's APIs
-provide a set of constants in Python and C++ that are easy to
-reuse and share across tools consistently.
-
-#### Standard MetaGraphDef tags
-
-You may use sets of tags to uniquely identify a `MetaGraphDef` saved in a
-SavedModel. A subset of commonly used tags is specified in:
-
-* [Python](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/tag_constants.py)
-* [C++](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/tag_constants.h)
-
-
-#### Standard SignatureDef constants
-
-A [**SignatureDef**](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/meta_graph.proto)
-is a protocol buffer that defines the signature of a computation
-supported by a graph.
-Commonly used input keys, output keys, and method names are
-defined in:
-
-* [Python](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/signature_constants.py)
-* [C++](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/signature_constants.h)
-
-## Using SavedModel with Estimators
-
-After training an `Estimator` model, you may want to create a service
-from that model that takes requests and returns a result. You can run such a
-service locally on your machine or deploy it in the cloud.
-
-To prepare a trained Estimator for serving, you must export it in the standard
-SavedModel format. This section explains how to:
-
-* Specify the output nodes and the corresponding
- [APIs](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/prediction_service.proto)
- that can be served (Classify, Regress, or Predict).
-* Export your model to the SavedModel format.
-* Serve the model from a local server and request predictions.
-
-
-### Prepare serving inputs
-
-During training, an @{$premade_estimators#input_fn$`input_fn()`} ingests data
-and prepares it for use by the model. At serving time, similarly, a
-`serving_input_receiver_fn()` accepts inference requests and prepares them for
-the model. This function has the following purposes:
-
-* To add placeholders to the graph that the serving system will feed
- with inference requests.
-* To add any additional ops needed to convert data from the input format
- into the feature `Tensor`s expected by the model.
-
-The function returns a `tf.estimator.export.ServingInputReceiver` object,
-which packages the placeholders and the resulting feature `Tensor`s together.
-
-A typical pattern is that inference requests arrive in the form of serialized
-`tf.Example`s, so the `serving_input_receiver_fn()` creates a single string
-placeholder to receive them. The `serving_input_receiver_fn()` is then also
-responsible for parsing the `tf.Example`s by adding a `tf.parse_example` op to
-the graph.
-
-When writing such a `serving_input_receiver_fn()`, you must pass a parsing
-specification to `tf.parse_example` to tell the parser what feature names to
-expect and how to map them to `Tensor`s. A parsing specification takes the
-form of a dict from feature names to `tf.FixedLenFeature`, `tf.VarLenFeature`,
-and `tf.SparseFeature`. Note this parsing specification should not include
-any label or weight columns, since those will not be available at serving
-time&mdash;in contrast to a parsing specification used in the `input_fn()` at
-training time.
-
-In combination, then:
-
-```py
-feature_spec = {'foo': tf.FixedLenFeature(...),
-                'bar': tf.VarLenFeature(...)}
-
-def serving_input_receiver_fn():
-  """An input receiver that expects a serialized tf.Example."""
-  serialized_tf_example = tf.placeholder(dtype=tf.string,
-                                         shape=[default_batch_size],
-                                         name='input_example_tensor')
-  receiver_tensors = {'examples': serialized_tf_example}
-  features = tf.parse_example(serialized_tf_example, feature_spec)
-  return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
-```
-
-The `tf.estimator.export.build_parsing_serving_input_receiver_fn` utility
-function provides that input receiver for the common case.
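-
-That is, instead of writing the function above by hand, you can write (with
-`feature_spec` as defined above):
-
-```python
-serving_input_receiver_fn = (
-    tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec))
-```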
-
-> Note: when training a model to be served using the Predict API with a local
-> server, the parsing step is not needed because the model will receive raw
-> feature data.
-
-Even if you require no parsing or other input processing&mdash;that is, if the
-serving system will feed feature `Tensor`s directly&mdash;you must still provide
-a `serving_input_receiver_fn()` that creates placeholders for the feature
-`Tensor`s and passes them through. The
-`tf.estimator.export.build_raw_serving_input_receiver_fn` utility provides for
-this.
-
-If these utilities do not meet your needs, you are free to write your own
-`serving_input_receiver_fn()`. One case where this may be needed is if your
-training `input_fn()` incorporates some preprocessing logic that must be
-recapitulated at serving time. To reduce the risk of training-serving skew, we
-recommend encapsulating such processing in a function which is then called
-from both `input_fn()` and `serving_input_receiver_fn()`.
-
-Note that the `serving_input_receiver_fn()` also determines the *input*
-portion of the signature. That is, when writing a
-`serving_input_receiver_fn()`, you must tell the parser what inputs
-to expect and how to map them to your model's expected features.
-By contrast, the *output* portion of the signature is determined by the model.
-
-<a name="specify_outputs"></a>
-### Specify the outputs of a custom model
-
-When writing a custom `model_fn`, you must populate the `export_outputs` element
-of the `tf.estimator.EstimatorSpec` return value. This is a dict of
-`{name: output}` describing the output signatures to be exported and used during
-serving.
-
-In the usual case of making a single prediction, this dict contains
-one element, and the `name` is immaterial. In a multi-headed model, each head
-is represented by an entry in this dict. In this case the `name` is a string
-of your choice that can be used to request a specific head at serving time.
-
-Each `output` value must be an `ExportOutput` object such as
-`tf.estimator.export.ClassificationOutput`,
-`tf.estimator.export.RegressionOutput`, or
-`tf.estimator.export.PredictOutput`.
-
-These output types map straightforwardly to the
-[TensorFlow Serving APIs](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/prediction_service.proto),
-and so determine which request types will be honored.
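-
-For instance, a custom `model_fn` for a single prediction head might populate
-`export_outputs` like this (a sketch; `probabilities` is assumed to be a
-tensor computed earlier in the `model_fn`):
-
-```python
-export_outputs = {
-    'serving_default': tf.estimator.export.PredictOutput(
-        {'probabilities': probabilities})
-}
-return tf.estimator.EstimatorSpec(
-    mode, predictions={'probabilities': probabilities},
-    export_outputs=export_outputs)
-```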
-
-Note: In the multi-headed case, a `SignatureDef` will be generated for each
-element of the `export_outputs` dict returned from the model_fn, named using
-the same keys. These `SignatureDef`s differ only in their outputs, as
-provided by the corresponding `ExportOutput` entry. The inputs are always
-those provided by the `serving_input_receiver_fn`.
-An inference request may specify the head by name. One head must be named
-using [`signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`](https://www.tensorflow.org/code/tensorflow/python/saved_model/signature_constants.py)
-indicating which `SignatureDef` will be served when an inference request
-does not specify one.
-
-<a name="perform_export"></a>
-### Perform the export
-
-To export your trained Estimator, call
-`tf.estimator.Estimator.export_savedmodel` with the export base path and
-the `serving_input_receiver_fn`.
-
-```py
-estimator.export_savedmodel(export_dir_base, serving_input_receiver_fn,
-                            strip_default_attrs=True)
-```
-
-This method builds a new graph by first calling the
-`serving_input_receiver_fn()` to obtain feature `Tensor`s, and then calling
-this `Estimator`'s `model_fn()` to generate the model graph based on those
-features. It starts a fresh `Session`, and, by default, restores the most recent
-checkpoint into it. (A different checkpoint may be passed, if needed.)
-Finally it creates a time-stamped export directory below the given
-`export_dir_base` (i.e., `export_dir_base/<timestamp>`), and writes a
-SavedModel into it containing a single `MetaGraphDef` saved from this
-Session.
-
-> Note: It is your responsibility to garbage-collect old exports.
-> Otherwise, successive exports will accumulate under `export_dir_base`.
-
-### Serve the exported model locally
-
-For local deployment, you can serve your model using
-[TensorFlow Serving](https://github.com/tensorflow/serving), an open-source project that loads a
-SavedModel and exposes it as a [gRPC](https://www.grpc.io/) service.
-
-First, [install TensorFlow Serving](https://github.com/tensorflow/serving).
-
-Then build and run the local model server, substituting `$export_dir_base` with
-the path to the SavedModel you exported above:
-
-```sh
-bazel build //tensorflow_serving/model_servers:tensorflow_model_server
-bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 --model_base_path=$export_dir_base
-```
-
-Now you have a server listening for inference requests via gRPC on port 9000!
-
-
-### Request predictions from a local server
-
-The server responds to gRPC requests according to the
-[PredictionService](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/prediction_service.proto#L15)
-gRPC API service definition. (The nested protocol buffers are defined in
-various [neighboring files](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis)).
-
-From the API service definition, the gRPC framework generates client libraries
-in various languages providing remote access to the API. In a project using the
-Bazel build tool, these libraries are built automatically and provided via
-dependencies like these (using Python for example):
-
-```build
-deps = [
-    "//tensorflow_serving/apis:classification_proto_py_pb2",
-    "//tensorflow_serving/apis:regression_proto_py_pb2",
-    "//tensorflow_serving/apis:predict_proto_py_pb2",
-    "//tensorflow_serving/apis:prediction_service_proto_py_pb2",
-]
-```
-
-Python client code can then import the libraries thus:
-
-```py
-from tensorflow_serving.apis import classification_pb2
-from tensorflow_serving.apis import regression_pb2
-from tensorflow_serving.apis import predict_pb2
-from tensorflow_serving.apis import prediction_service_pb2
-```
-
-> Note: `prediction_service_pb2` defines the service as a whole and so
-> is always required. However a typical client will need only one of
-> `classification_pb2`, `regression_pb2`, and `predict_pb2`, depending on the
-> type of requests being made.
-
-Sending a gRPC request is then accomplished by assembling a protocol buffer
-containing the request data and passing it to the service stub. Note how the
-request protocol buffer is created empty and then populated via the
-[generated protocol buffer API](https://developers.google.com/protocol-buffers/docs/reference/python-generated).
-
-```py
-from grpc.beta import implementations
-
-channel = implementations.insecure_channel(host, int(port))
-stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
-
-request = classification_pb2.ClassificationRequest()
-example = request.input.example_list.examples.add()
-example.features.feature['x'].float_list.value.extend(image[0].astype(float))
-
-result = stub.Classify(request, 10.0) # 10 secs timeout
-```
-
-The returned result in this example is a `ClassificationResponse` protocol
-buffer.
-
-This is a skeletal example; please see the @{$deploy$TensorFlow Serving}
-documentation and [examples](https://github.com/tensorflow/serving/tree/master/tensorflow_serving/example)
-for more details.
-
-> Note: `ClassificationRequest` and `RegressionRequest` contain a
-> `tensorflow.serving.Input` protocol buffer, which in turn contains a list of
-> `tensorflow.Example` protocol buffers. `PredictRequest`, by contrast,
-> contains a mapping from feature names to values encoded via `TensorProto`.
-> Correspondingly: When using the `Classify` and `Regress` APIs, TensorFlow
-> Serving feeds serialized `tf.Example`s to the graph, so your
-> `serving_input_receiver_fn()` should include a `tf.parse_example()` Op.
-> When using the generic `Predict` API, however, TensorFlow Serving feeds raw
-> feature data to the graph, so a pass through `serving_input_receiver_fn()`
-> should be used.
-
-
-<!-- TODO(soergel): give examples of making requests against this server, using
-the different Tensorflow Serving APIs, selecting the signature by key, etc. -->
-
-<!-- TODO(soergel): document ExportStrategy here once Experiment moves
-from contrib to core. -->
-
-
-
-
-## CLI to inspect and execute SavedModel
-
-You can use the SavedModel Command Line Interface (CLI) to inspect and
-execute a SavedModel.
-For example, you can use the CLI to inspect the model's `SignatureDef`s.
-The CLI enables you to quickly confirm that the input
-@{$tensors$Tensor dtype and shape} match the model. Moreover, if you
-want to test your model, you can use the CLI to do a sanity check by
-passing in sample inputs in various formats (for example, Python
-expressions) and then fetching the output.
-
-
-### Install the SavedModel CLI
-
-Broadly speaking, you can install TensorFlow in either of the following
-two ways:
-
-* By installing a pre-built TensorFlow binary.
-* By building TensorFlow from source code.
-
-If you installed TensorFlow through a pre-built TensorFlow binary,
-then the SavedModel CLI is already installed on your system
-at pathname `bin/saved_model_cli`.
-
-If you built TensorFlow from source code, you must run the following
-additional command to build `saved_model_cli`:
-
-```
-$ bazel build tensorflow/python/tools:saved_model_cli
-```
-
-### Overview of commands
-
-The SavedModel CLI supports the following two commands on a
-`MetaGraphDef` in a SavedModel:
-
-* `show`, which shows the computations available in a SavedModel.
-* `run`, which runs a computation on a `MetaGraphDef`.
-
-
-### `show` command
-
-A SavedModel contains one or more `MetaGraphDef`s, identified by their tag-sets.
-To serve a model, you
-might wonder what kind of `SignatureDef`s are in each model, and what their
-inputs and outputs are. The `show` command lets you examine the contents of the
-SavedModel in hierarchical order. Here's the syntax:
-
-```
-usage: saved_model_cli show [-h] --dir DIR [--all]
-                            [--tag_set TAG_SET]
-                            [--signature_def SIGNATURE_DEF_KEY]
-```
-
-For example, the following command shows all available
-MetaGraphDef tag-sets in the SavedModel:
-
-```
-$ saved_model_cli show --dir /tmp/saved_model_dir
-The given SavedModel contains the following tag-sets:
-serve
-serve, gpu
-```
-
-The following command shows all available `SignatureDef` keys in
-a `MetaGraphDef`:
-
-```
-$ saved_model_cli show --dir /tmp/saved_model_dir --tag_set serve
-The given SavedModel `MetaGraphDef` contains `SignatureDefs` with the
-following keys:
-SignatureDef key: "classify_x2_to_y3"
-SignatureDef key: "classify_x_to_y"
-SignatureDef key: "regress_x2_to_y3"
-SignatureDef key: "regress_x_to_y"
-SignatureDef key: "regress_x_to_y2"
-SignatureDef key: "serving_default"
-```
-
-If a `MetaGraphDef` has *multiple* tags in the tag-set, you must specify
-all tags, each tag separated by a comma. For example:
-
-```none
-$ saved_model_cli show --dir /tmp/saved_model_dir --tag_set serve,gpu
-```
-
-To show all inputs and outputs `TensorInfo` for a specific `SignatureDef`, pass
-the `SignatureDef` key to the `--signature_def` option. This is very useful when
-you want to know the tensor key value, dtype, and shape of the input tensors
-before executing the computation graph. For example:
-
-```
-$ saved_model_cli show --dir \
-/tmp/saved_model_dir --tag_set serve --signature_def serving_default
-The given SavedModel SignatureDef contains the following input(s):
-  inputs['x'] tensor_info:
-      dtype: DT_FLOAT
-      shape: (-1, 1)
-      name: x:0
-The given SavedModel SignatureDef contains the following output(s):
-  outputs['y'] tensor_info:
-      dtype: DT_FLOAT
-      shape: (-1, 1)
-      name: y:0
-Method name is: tensorflow/serving/predict
-```
-
-To show all available information in the SavedModel, use the `--all` option.
-For example:
-
-```none
-$ saved_model_cli show --dir /tmp/saved_model_dir --all
-MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
-
-signature_def['classify_x2_to_y3']:
-  The given SavedModel SignatureDef contains the following input(s):
-    inputs['inputs'] tensor_info:
-        dtype: DT_FLOAT
-        shape: (-1, 1)
-        name: x2:0
-  The given SavedModel SignatureDef contains the following output(s):
-    outputs['scores'] tensor_info:
-        dtype: DT_FLOAT
-        shape: (-1, 1)
-        name: y3:0
-  Method name is: tensorflow/serving/classify
-
-...
-
-signature_def['serving_default']:
-  The given SavedModel SignatureDef contains the following input(s):
-    inputs['x'] tensor_info:
-        dtype: DT_FLOAT
-        shape: (-1, 1)
-        name: x:0
-  The given SavedModel SignatureDef contains the following output(s):
-    outputs['y'] tensor_info:
-        dtype: DT_FLOAT
-        shape: (-1, 1)
-        name: y:0
-  Method name is: tensorflow/serving/predict
-```
-
-
-### `run` command
-
-Invoke the `run` command to run a graph computation, passing
-inputs and then displaying (and optionally saving) the outputs.
-Here's the syntax:
-
-```
-usage: saved_model_cli run [-h] --dir DIR --tag_set TAG_SET --signature_def
-                           SIGNATURE_DEF_KEY [--inputs INPUTS]
-                           [--input_exprs INPUT_EXPRS]
-                           [--input_examples INPUT_EXAMPLES] [--outdir OUTDIR]
-                           [--overwrite] [--tf_debug]
-```
-
-The `run` command provides the following three ways to pass inputs to the model:
-
-* `--inputs` option enables you to pass numpy ndarrays stored in files.
-* `--input_exprs` option enables you to pass Python expressions.
-* `--input_examples` option enables you to pass `tf.train.Example` data.
-
-
-#### `--inputs`
-
-To pass input data in files, specify the `--inputs` option, which takes the
-following general format:
-
-```bash
---inputs <INPUTS>
-```
-
-where *INPUTS* is either of the following formats:
-
-* `<input_key>=<filename>`
-* `<input_key>=<filename>[<variable_name>]`
-
-You may pass multiple *INPUTS*, separating them with semicolons.
-
-`saved_model_cli` uses `numpy.load` to load the *filename*.
-The *filename* may be in any of the following formats:
-
-* `.npy`
-* `.npz`
-* pickle format
-
-A `.npy` file always contains a numpy ndarray. Therefore, when loading from
-a `.npy` file, the content will be directly assigned to the specified input
-tensor. If you specify a *variable_name* with that `.npy` file, the
-*variable_name* will be ignored and a warning will be issued.
-
-When loading from a `.npz` (zip) file, you may optionally specify a
-*variable_name* to identify the variable within the zip file to load for
-the input tensor key. If you don't specify a *variable_name*, the SavedModel
-CLI will check that only one file is included in the zip file and load it
-for the specified input tensor key.
-
-When loading from a pickle file, if no *variable_name* is specified in the
-square brackets, whatever is inside the pickle file will be passed to the
-specified input tensor key. Otherwise, the SavedModel CLI will assume a
-dictionary is stored in the pickle file and the value corresponding to
-the *variable_name* will be used.
-
-
-#### `--input_exprs`
-
-To pass inputs through Python expressions, specify the `--input_exprs` option.
-This can be useful when you don't have data
-files lying around, but still want to sanity check the model with some simple
-inputs that match the dtype and shape of the model's `SignatureDef`s.
-For example:
-
-```bash
-<input_key>=[[1],[2],[3]]
-```
-
-In addition to Python expressions, you may also pass numpy functions. For
-example:
-
-```bash
-<input_key>=np.ones((32,32,3))
-```
-
-(Note that the `numpy` module is already available to you as `np`.)
-
-
-#### `--input_examples`
-
-To pass `tf.train.Example` data as inputs, specify the `--input_examples`
-option. For each input key, it takes a list of dictionaries, where each
-dictionary is an instance of `tf.train.Example`. The dictionary keys are the
-features, and the values are the value lists for each feature.
-For example:
-
-```bash
-<input_key>=[{"age":[22,24],"education":["BS","MS"]}]
-```
-
-#### Save output
-
-By default, the SavedModel CLI writes output to stdout. If a directory is
-passed to the `--outdir` option, the outputs will be saved as `.npy` files
-named after the output tensor keys under the given directory.
-
-Use `--overwrite` to overwrite existing output files.
-
-
-#### TensorFlow debugger (tfdbg) integration
-
-If the `--tf_debug` option is set, the SavedModel CLI will use the
-TensorFlow Debugger (tfdbg) to watch the intermediate Tensors and runtime
-graphs or subgraphs while running the SavedModel.
-
-
-#### Full examples of `run`
-
-Given:
-
-* Your model simply adds `x1` and `x2` to get output `y`.
-* All tensors in the model have shape `(-1, 1)`.
-* You have two `npy` files:
- * `/tmp/my_data1.npy`, which contains a numpy ndarray `[[1], [2], [3]]`.
- * `/tmp/my_data2.npy`, which contains another numpy
- ndarray `[[0.5], [0.5], [0.5]]`.
-
-To run these two `npy` files through the model to get output `y`, issue
-the following command:
-
-```
-$ saved_model_cli run --dir /tmp/saved_model_dir --tag_set serve \
---signature_def x1_x2_to_y --inputs 'x1=/tmp/my_data1.npy;x2=/tmp/my_data2.npy' \
---outdir /tmp/out
-Result for output key y:
-[[ 1.5]
- [ 2.5]
- [ 3.5]]
-```
-
-Let's change the preceding example slightly. This time, instead of two
-`.npy` files, you now have an `.npz` file and a pickle file. Furthermore,
-you want to overwrite any existing output file. Here's the command:
-
-```
-$ saved_model_cli run --dir /tmp/saved_model_dir --tag_set serve \
---signature_def x1_x2_to_y \
---inputs 'x1=/tmp/my_data1.npz[x];x2=/tmp/my_data2.pkl' --outdir /tmp/out \
---overwrite
-Result for output key y:
-[[ 1.5]
- [ 2.5]
- [ 3.5]]
-```
-
-You may specify a Python expression instead of an input file. For example,
-the following command replaces input `x2` with a Python expression:
-
-```
-$ saved_model_cli run --dir /tmp/saved_model_dir --tag_set serve \
---signature_def x1_x2_to_y --inputs x1=/tmp/my_data1.npz[x] \
---input_exprs 'x2=np.ones((3,1))'
-Result for output key y:
-[[ 2]
- [ 3]
- [ 4]]
-```
-
-To run the model with the TensorFlow Debugger on, issue the
-following command:
-
-```
-$ saved_model_cli run --dir /tmp/saved_model_dir --tag_set serve \
---signature_def serving_default --inputs x=/tmp/data.npz[x] --tf_debug
-```
-
-
-<a name="structure"></a>
-## Structure of a SavedModel directory
-
-When you save a model in SavedModel format, TensorFlow creates
-a SavedModel directory consisting of the following subdirectories
-and files:
-
-```none
-assets/
-assets.extra/
-variables/
- variables.data-?????-of-?????
- variables.index
-saved_model.pb|saved_model.pbtxt
-```
-
-where:
-
-* `assets` is a subfolder containing auxiliary (external) files,
- such as vocabularies. Assets are copied to the SavedModel location
- and can be read when loading a specific `MetaGraphDef`.
-* `assets.extra` is a subfolder where higher-level libraries and users can
- add their own assets that co-exist with the model, but are not loaded by
- the graph. This subfolder is not managed by the SavedModel libraries.
-* `variables` is a subfolder that includes output from
- `tf.train.Saver`.
-* `saved_model.pb` or `saved_model.pbtxt` is the SavedModel protocol buffer.
- It includes the graph definitions as `MetaGraphDef` protocol buffers.
-
-A single SavedModel can represent multiple graphs. In this case, all the
-graphs in the SavedModel share a *single* set of checkpoints (variables)
-and assets. For example, the following diagram shows one SavedModel
-containing three `MetaGraphDef`s, all three of which share the same set
-of checkpoints and assets:
-
-![SavedModel represents checkpoints, assets, and one or more MetaGraphDefs](../images/SavedModel.svg)
-
-Each graph is associated with a specific set of tags, which enables
-identification during a load or restore operation.
diff --git a/tensorflow/docs_src/guide/summaries_and_tensorboard.md b/tensorflow/docs_src/guide/summaries_and_tensorboard.md
deleted file mode 100644
index 6177c3393b..0000000000
--- a/tensorflow/docs_src/guide/summaries_and_tensorboard.md
+++ /dev/null
@@ -1,225 +0,0 @@
-# TensorBoard: Visualizing Learning
-
-The computations you'll use TensorFlow for, like training a massive
-deep neural network, can be complex and confusing. To make it easier to
-understand, debug, and optimize TensorFlow programs, we've included a suite of
-visualization tools called TensorBoard. You can use TensorBoard to visualize
-your TensorFlow graph, plot quantitative metrics about the execution of your
-graph, and show additional data like images that pass through it. When
-TensorBoard is fully configured, it looks like this:
-
-![MNIST TensorBoard](https://www.tensorflow.org/images/mnist_tensorboard.png "MNIST TensorBoard")
-
-<div class="video-wrapper">
- <iframe class="devsite-embedded-youtube-video" data-video-id="eBbEDRsCmv4"
- data-autohide="1" data-showinfo="0" frameborder="0" allowfullscreen>
- </iframe>
-</div>
-
-This 30-minute tutorial is intended to get you started with simple TensorBoard
-usage. It assumes a basic understanding of TensorFlow.
-
-There are other resources available as well! The [TensorBoard GitHub](https://github.com/tensorflow/tensorboard)
-has a lot more information on using individual dashboards within TensorBoard
-including tips & tricks and debugging information.
-
-## Setup
-
-[Install TensorFlow](https://www.tensorflow.org/install/). Installing TensorFlow
-via pip should also automatically install TensorBoard.
-
-## Serializing the data
-
-TensorBoard operates by reading TensorFlow events files, which contain summary
-data that you can generate when running TensorFlow. Here's the general
-lifecycle for summary data within TensorBoard.
-
-First, create the TensorFlow graph that you'd like to collect summary
-data from, and decide which nodes you would like to annotate with
-@{$python/summary$summary operations}.
-
-For example, suppose you are training a convolutional neural network for
-recognizing MNIST digits. You'd like to record how the learning rate
-varies over time, and how the objective function is changing. Collect these by
-attaching `tf.summary.scalar` ops
-to the nodes that output the learning rate and loss respectively. Then, give
-each summary op a meaningful `tag`, like `'learning rate'` or `'loss
-function'`.
-
-Perhaps you'd also like to visualize the distributions of activations coming
-off a particular layer, or the distribution of gradients or weights. Collect
-this data by attaching
-`tf.summary.histogram` ops to
-the gradient outputs and to the variable that holds your weights, respectively.
-
-For details on all of the summary operations available, check out the docs on
-@{$python/summary$summary operations}.
-
-Operations in TensorFlow don't do anything until you run them, or an op that
-depends on their output. And the summary nodes that we've just created are
-peripheral to your graph: none of the ops you are currently running depend on
-them. So, to generate summaries, we need to run all of these summary nodes.
-Managing them by hand would be tedious, so use
-`tf.summary.merge_all`
-to combine them into a single op that generates all the summary data.
-
-Then, you can just run the merged summary op, which will generate a serialized
-`Summary` protobuf object with all of your summary data at a given step.
-Finally, to write this summary data to disk, pass the summary protobuf to a
-`tf.summary.FileWriter`.
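-
-Putting those steps together, a minimal sketch (assuming a `sess`, a
-`train_step` op, and a loop counter `step` from the surrounding program, and
-ignoring any feed_dict your placeholders may need) looks like:
-
-```python
-merged = tf.summary.merge_all()
-writer = tf.summary.FileWriter('/tmp/mnist_logs', sess.graph)
-
-# Inside the training loop: run the merged summary op alongside the
-# train op, then write the serialized Summary for this step.
-summary, _ = sess.run([merged, train_step])
-writer.add_summary(summary, global_step=step)
-```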
-
-The `FileWriter` takes a logdir in its constructor. This logdir is quite
-important: it's the directory where all of the events will be written out.
-Also, the `FileWriter` can optionally take a `Graph` in its constructor.
-If it receives a `Graph` object, then TensorBoard will visualize your graph
-along with tensor shape information. This will give you a much better sense of
-what flows through the graph: see
-@{$graph_viz#tensor-shape-information$Tensor shape information}.
-
-Now that you've modified your graph and have a `FileWriter`, you're ready to
-start running your network! If you want, you could run the merged summary op
-every single step, and record a ton of training data. That's likely to be more
-data than you need, though. Instead, consider running the merged summary op
-every `n` steps.
-
-The code example below is a modification of the
-[simple MNIST tutorial](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/mnist.py),
-in which we have added some summary ops, and run them every ten steps. If you
-run this and then launch `tensorboard --logdir=/tmp/tensorflow/mnist`, you'll be able
-to visualize statistics, such as how the weights or accuracy varied during
-training. The code below is an excerpt; full source is
-[here](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py).
-
-```python
-def variable_summaries(var):
-  """Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
-  with tf.name_scope('summaries'):
-    mean = tf.reduce_mean(var)
-    tf.summary.scalar('mean', mean)
-    with tf.name_scope('stddev'):
-      stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
-    tf.summary.scalar('stddev', stddev)
-    tf.summary.scalar('max', tf.reduce_max(var))
-    tf.summary.scalar('min', tf.reduce_min(var))
-    tf.summary.histogram('histogram', var)
-
-def nn_layer(input_tensor, input_dim, output_dim, layer_name, act=tf.nn.relu):
-  """Reusable code for making a simple neural net layer.
-
-  It does a matrix multiply, bias add, and then uses relu to nonlinearize.
-  It also sets up name scoping so that the resultant graph is easy to read,
-  and adds a number of summary ops.
-  """
-  # Adding a name scope ensures logical grouping of the layers in the graph.
-  with tf.name_scope(layer_name):
-    # This Variable will hold the state of the weights for the layer
-    with tf.name_scope('weights'):
-      weights = weight_variable([input_dim, output_dim])
-      variable_summaries(weights)
-    with tf.name_scope('biases'):
-      biases = bias_variable([output_dim])
-      variable_summaries(biases)
-    with tf.name_scope('Wx_plus_b'):
-      preactivate = tf.matmul(input_tensor, weights) + biases
-      tf.summary.histogram('pre_activations', preactivate)
-    activations = act(preactivate, name='activation')
-    tf.summary.histogram('activations', activations)
-    return activations
-
-hidden1 = nn_layer(x, 784, 500, 'layer1')
-
-with tf.name_scope('dropout'):
-  keep_prob = tf.placeholder(tf.float32)
-  tf.summary.scalar('dropout_keep_probability', keep_prob)
-  dropped = tf.nn.dropout(hidden1, keep_prob)
-
-# Do not apply softmax activation yet, see below.
-y = nn_layer(dropped, 500, 10, 'layer2', act=tf.identity)
-
-with tf.name_scope('cross_entropy'):
-  # The raw formulation of cross-entropy,
-  #
-  # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.softmax(y)),
-  #                               reduction_indices=[1]))
-  #
-  # can be numerically unstable.
-  #
-  # So here we use tf.losses.sparse_softmax_cross_entropy on the
-  # raw logit outputs of the nn_layer above.
-  with tf.name_scope('total'):
-    cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)
-tf.summary.scalar('cross_entropy', cross_entropy)
-
-with tf.name_scope('train'):
-  train_step = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(
-      cross_entropy)
-
-with tf.name_scope('accuracy'):
-  with tf.name_scope('correct_prediction'):
-    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
-  with tf.name_scope('accuracy'):
-    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
-tf.summary.scalar('accuracy', accuracy)
-
-# Merge all the summaries and write them out to /tmp/mnist_logs (by default)
-merged = tf.summary.merge_all()
-train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
-                                     sess.graph)
-test_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/test')
-tf.global_variables_initializer().run()
-```
-
-After we've initialized the `FileWriters`, we have to add summaries to the
-`FileWriters` as we train and test the model.
-
-```python
-# Train the model, and also write summaries.
-# Every 10th step, measure test-set accuracy, and write test summaries
-# All other steps, run train_step on training data, & add training summaries
-
-def feed_dict(train):
- """Make a TensorFlow feed_dict: maps data onto Tensor placeholders."""
- if train or FLAGS.fake_data:
- xs, ys = mnist.train.next_batch(100, fake_data=FLAGS.fake_data)
- k = FLAGS.dropout
- else:
- xs, ys = mnist.test.images, mnist.test.labels
- k = 1.0
- return {x: xs, y_: ys, keep_prob: k}
-
-for i in range(FLAGS.max_steps):
- if i % 10 == 0: # Record summaries and test-set accuracy
- summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False))
- test_writer.add_summary(summary, i)
- print('Accuracy at step %s: %s' % (i, acc))
- else: # Record train set summaries, and train
- summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True))
- train_writer.add_summary(summary, i)
-```
-
-You're now all set to visualize this data using TensorBoard.
-
-
-## Launching TensorBoard
-
-To run TensorBoard, use the following command (alternatively `python -m
-tensorboard.main`):
-
-```bash
-tensorboard --logdir=path/to/log-directory
-```
-
-where `logdir` points to the directory where the `FileWriter` serialized its
-data. If this `logdir` directory contains subdirectories which contain
-serialized data from separate runs, then TensorBoard will visualize the data
-from all of those runs. Once TensorBoard is running, navigate your web browser
-to `localhost:6006` to view the TensorBoard.
-
-When looking at TensorBoard, you will see the navigation tabs in the top right
-corner. Each tab represents a set of serialized data that can be visualized.
-
-For in depth information on how to use the *graph* tab to visualize your graph,
-see @{$graph_viz$TensorBoard: Graph Visualization}.
-
-For more usage information on TensorBoard in general, see the
-[TensorBoard GitHub](https://github.com/tensorflow/tensorboard).
diff --git a/tensorflow/docs_src/guide/tensorboard_histograms.md b/tensorflow/docs_src/guide/tensorboard_histograms.md
deleted file mode 100644
index af8f2cadd1..0000000000
--- a/tensorflow/docs_src/guide/tensorboard_histograms.md
+++ /dev/null
@@ -1,245 +0,0 @@
-# TensorBoard Histogram Dashboard
-
-The TensorBoard Histogram Dashboard displays how the distribution of some
-`Tensor` in your TensorFlow graph has changed over time. It does this by showing
-many histogram visualizations of your tensor at different points in time.
-
-## A Basic Example
-
-Let's start with a simple case: a normally-distributed variable, where the mean
-shifts over time.
-TensorFlow has an op
-[`tf.random_normal`](https://www.tensorflow.org/api_docs/python/tf/random_normal)
-which is perfect for this purpose. As is usually the case with TensorBoard, we
-will ingest data using a summary op; in this case,
-[`tf.summary.histogram`](https://www.tensorflow.org/api_docs/python/tf/summary/histogram).
-For a primer on how summaries work, please see the
-[TensorBoard guide](./summaries_and_tensorboard.md).
-
-Here is a code snippet that will generate some histogram summaries containing
-normally distributed data, where the mean of the distribution increases over
-time.
-
-```python
-import tensorflow as tf
-
-k = tf.placeholder(tf.float32)
-
-# Make a normal distribution, with a shifting mean
-mean_moving_normal = tf.random_normal(shape=[1000], mean=(5*k), stddev=1)
-# Record that distribution into a histogram summary
-tf.summary.histogram("normal/moving_mean", mean_moving_normal)
-
-# Setup a session and summary writer
-sess = tf.Session()
-writer = tf.summary.FileWriter("/tmp/histogram_example")
-
-summaries = tf.summary.merge_all()
-
-# Setup a loop and write the summaries to disk
-N = 400
-for step in range(N):
- k_val = step/float(N)
- summ = sess.run(summaries, feed_dict={k: k_val})
- writer.add_summary(summ, global_step=step)
-```
-
-Once that code runs, we can load the data into TensorBoard via the command line:
-
-
-```sh
-tensorboard --logdir=/tmp/histogram_example
-```
-
-Once TensorBoard is running, load it in Chrome or Firefox and navigate to the
-Histogram Dashboard. Then we can see a histogram visualization for our normally
-distributed data.
-
-![](https://www.tensorflow.org/images/tensorboard/histogram_dashboard/1_moving_mean.png)
-
-`tf.summary.histogram` takes an arbitrarily sized and shaped Tensor, and
-compresses it into a histogram data structure consisting of many bins with
-widths and counts. For example, let's say we want to organize the numbers
-`[0.5, 1.1, 1.3, 2.2, 2.9, 2.99]` into bins. We could make three bins:
-
-* a bin containing everything from 0 to 1 (it would contain one element, 0.5),
-* a bin containing everything from 1 to 2 (it would contain two elements, 1.1
-  and 1.3),
-* a bin containing everything from 2 to 3 (it would contain three elements:
-  2.2, 2.9, and 2.99).
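-
-The same bucketing can be sketched with plain numpy (this is only an
-illustration of the idea, not TensorFlow's actual binning code):
-
-```python
-import numpy as np
-
-values = [0.5, 1.1, 1.3, 2.2, 2.9, 2.99]
-counts, edges = np.histogram(values, bins=[0, 1, 2, 3])
-print(counts)  # [1 2 3]
-```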
-
-TensorFlow uses a similar approach to create bins, but unlike in our example, it
-doesn't create integer bins. For large, sparse datasets, that might result in
-many thousands of bins.
-Instead, [the bins are exponentially distributed, with many bins close to 0 and
-comparatively few bins for very large numbers.](https://github.com/tensorflow/tensorflow/blob/c8b59c046895fa5b6d79f73e0b5817330fcfbfc1/tensorflow/core/lib/histogram/histogram.cc#L28)
-However, visualizing exponentially-distributed bins is tricky; if height is used
-to encode count, then wider bins take more space, even if they have the same
-number of elements. Conversely, encoding count in the area makes height
-comparisons impossible. Instead, the histograms [resample the data](https://github.com/tensorflow/tensorflow/blob/17c47804b86e340203d451125a721310033710f1/tensorflow/tensorboard/components/tf_backend/backend.ts#L400)
-into uniform bins. This can lead to unfortunate artifacts in some cases.
-
-Each slice in the histogram visualizer displays a single histogram.
-The slices are organized by step;
-older slices (e.g. step 0) are further "back" and darker, while newer slices
-(e.g. step 400) are close to the foreground, and lighter in color.
-The y-axis on the right shows the step number.
-
-You can mouse over the histogram to see tooltips with some more detailed
-information. For example, in the following image we can see that the histogram
-at timestep 176 has a bin centered at 2.25 with 177 elements in that bin.
-
-![](https://www.tensorflow.org/images/tensorboard/histogram_dashboard/2_moving_mean_tooltip.png)
-
-Also, you may note that the histogram slices are not always evenly spaced in
-step count or time. This is because TensorBoard uses
-[reservoir sampling](https://en.wikipedia.org/wiki/Reservoir_sampling) to keep a
-subset of all the histograms, to save on memory. Reservoir sampling guarantees
-that every sample has an equal likelihood of being included, but because it is
-a randomized algorithm, the samples chosen don't occur at even steps.
-
-## Overlay Mode
-
-There is a control on the left of the dashboard that allows you to toggle the
-histogram mode from "offset" to "overlay":
-
-![](https://www.tensorflow.org/images/tensorboard/histogram_dashboard/3_overlay_offset.png)
-
-In "offset" mode, the visualization rotates 45 degrees, so that the individual
-histogram slices are no longer spread out in time, but instead are all plotted
-on the same y-axis.
-
-![](https://www.tensorflow.org/images/tensorboard/histogram_dashboard/4_overlay.png)
-Now, each slice is a separate line on the chart, and the y-axis shows the item
-count within each bucket. Darker lines are older, earlier steps, and lighter
-lines are more recent, later steps. Once again, you can mouse over the chart to
-see some additional information.
-
-![](https://www.tensorflow.org/images/tensorboard/histogram_dashboard/5_overlay_tooltips.png)
-
-In general, the overlay visualization is useful if you want to directly compare
-the counts of different histograms.
-
-## Multimodal Distributions
-
-The Histogram Dashboard is great for visualizing multimodal
-distributions. Let's construct a simple bimodal distribution by concatenating
-the outputs from two different normal distributions. The code will look like
-this:
-
-```python
-import tensorflow as tf
-
-k = tf.placeholder(tf.float32)
-
-# Make a normal distribution, with a shifting mean
-mean_moving_normal = tf.random_normal(shape=[1000], mean=(5*k), stddev=1)
-# Record that distribution into a histogram summary
-tf.summary.histogram("normal/moving_mean", mean_moving_normal)
-
-# Make a normal distribution with shrinking variance
-variance_shrinking_normal = tf.random_normal(shape=[1000], mean=0, stddev=1-(k))
-# Record that distribution too
-tf.summary.histogram("normal/shrinking_variance", variance_shrinking_normal)
-
-# Let's combine both of those distributions into one dataset
-normal_combined = tf.concat([mean_moving_normal, variance_shrinking_normal], 0)
-# We add another histogram summary to record the combined distribution
-tf.summary.histogram("normal/bimodal", normal_combined)
-
-summaries = tf.summary.merge_all()
-
-# Setup a session and summary writer
-sess = tf.Session()
-writer = tf.summary.FileWriter("/tmp/histogram_example")
-
-# Setup a loop and write the summaries to disk
-N = 400
-for step in range(N):
- k_val = step/float(N)
- summ = sess.run(summaries, feed_dict={k: k_val})
- writer.add_summary(summ, global_step=step)
-```
-
-You may recognize our "moving mean" normal distribution from the example
-above. Now we also have a "shrinking variance" distribution. Side-by-side, they
-look like this:
-![](https://www.tensorflow.org/images/tensorboard/histogram_dashboard/6_two_distributions.png)
-
-When we concatenate them, we get a chart that clearly reveals the divergent,
-bimodal structure:
-![](https://www.tensorflow.org/images/tensorboard/histogram_dashboard/7_bimodal.png)
-
-## Some more distributions
-
-Just for fun, let's generate and visualize a few more distributions, and then
-combine them all into one chart. Here's the code we'll use:
-
-```python
-import tensorflow as tf
-
-k = tf.placeholder(tf.float32)
-
-# Make a normal distribution, with a shifting mean
-mean_moving_normal = tf.random_normal(shape=[1000], mean=(5*k), stddev=1)
-# Record that distribution into a histogram summary
-tf.summary.histogram("normal/moving_mean", mean_moving_normal)
-
-# Make a normal distribution with shrinking variance
-variance_shrinking_normal = tf.random_normal(shape=[1000], mean=0, stddev=1-(k))
-# Record that distribution too
-tf.summary.histogram("normal/shrinking_variance", variance_shrinking_normal)
-
-# Let's combine both of those distributions into one dataset
-normal_combined = tf.concat([mean_moving_normal, variance_shrinking_normal], 0)
-# We add another histogram summary to record the combined distribution
-tf.summary.histogram("normal/bimodal", normal_combined)
-
-# Add a gamma distribution
-gamma = tf.random_gamma(shape=[1000], alpha=k)
-tf.summary.histogram("gamma", gamma)
-
-# And a poisson distribution
-poisson = tf.random_poisson(shape=[1000], lam=k)
-tf.summary.histogram("poisson", poisson)
-
-# And a uniform distribution
-uniform = tf.random_uniform(shape=[1000], maxval=k*10)
-tf.summary.histogram("uniform", uniform)
-
-# Finally, combine everything together!
-all_distributions = [mean_moving_normal, variance_shrinking_normal,
- gamma, poisson, uniform]
-all_combined = tf.concat(all_distributions, 0)
-tf.summary.histogram("all_combined", all_combined)
-
-summaries = tf.summary.merge_all()
-
-# Setup a session and summary writer
-sess = tf.Session()
-writer = tf.summary.FileWriter("/tmp/histogram_example")
-
-# Setup a loop and write the summaries to disk
-N = 400
-for step in range(N):
- k_val = step/float(N)
- summ = sess.run(summaries, feed_dict={k: k_val})
- writer.add_summary(summ, global_step=step)
-```
-### Gamma Distribution
-![](https://www.tensorflow.org/images/tensorboard/histogram_dashboard/8_gamma.png)
-
-### Uniform Distribution
-![](https://www.tensorflow.org/images/tensorboard/histogram_dashboard/9_uniform.png)
-
-### Poisson Distribution
-![](https://www.tensorflow.org/images/tensorboard/histogram_dashboard/10_poisson.png)
-The Poisson distribution is defined over the integers. So, all of the values
-being generated are perfect integers. The histogram compression moves the data
-into floating-point bins, causing the visualization to show little
-bumps over the integer values rather than perfect spikes.
-
-### All Together Now
-Finally, we can concatenate all of the data into one funny-looking curve.
-![](https://www.tensorflow.org/images/tensorboard/histogram_dashboard/11_all_combined.png)
-
diff --git a/tensorflow/docs_src/guide/tensors.md b/tensorflow/docs_src/guide/tensors.md
deleted file mode 100644
index 6b5a110a1c..0000000000
--- a/tensorflow/docs_src/guide/tensors.md
+++ /dev/null
@@ -1,330 +0,0 @@
-# Tensors
-
-TensorFlow, as the name indicates, is a framework to define and run computations
-involving tensors. A **tensor** is a generalization of vectors and matrices to
-potentially higher dimensions. Internally, TensorFlow represents tensors as
-n-dimensional arrays of base datatypes.
-
-When writing a TensorFlow program, the main object you manipulate and pass
-around is the `tf.Tensor`. A `tf.Tensor` object represents a partially defined
-computation that will eventually produce a value. TensorFlow programs work by
-first building a graph of `tf.Tensor` objects, detailing how each tensor is
-computed based on the other available tensors and then by running parts of this
-graph to achieve the desired results.
-
-A `tf.Tensor` has the following properties:
-
- * a data type (`float32`, `int32`, or `string`, for example)
- * a shape
-
-
-Each element in the Tensor has the same data type, and the data type is always
-known. The shape (that is, the number of dimensions it has and the size of each
-dimension) might be only partially known. Most operations produce tensors of
-fully-known shapes if the shapes of their inputs are also fully known, but in
-some cases it's only possible to find the shape of a tensor at graph execution
-time.
-
-Some types of tensors are special, and these will be covered in other
-units of the TensorFlow guide. The main ones are:
-
- * `tf.Variable`
- * `tf.constant`
- * `tf.placeholder`
- * `tf.SparseTensor`
-
-With the exception of `tf.Variable`, the value of a tensor is immutable, which
-means that in the context of a single execution tensors only have a single
-value. However, evaluating the same tensor twice can return different values;
-for example that tensor can be the result of reading data from disk, or
-generating a random number.
-
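-As a small sketch of that last point, a tensor backed by a random op can
-evaluate to a different value on each run:
-
-```python
-import tensorflow as tf
-
-r = tf.random_uniform([])  # a scalar drawn anew on each evaluation
-with tf.Session() as sess:
-  print(sess.run(r))  # one value
-  print(sess.run(r))  # generally a different value
-```
-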
-## Rank
-
-The **rank** of a `tf.Tensor` object is its number of dimensions. Synonyms for
-rank include **order**, **degree**, and **n-dimension**.
-Note that rank in TensorFlow is not the same as matrix rank in mathematics.
-As the following table shows, each rank in TensorFlow corresponds to a
-different mathematical entity:
-
-Rank | Math entity
---- | ---
-0 | Scalar (magnitude only)
-1 | Vector (magnitude and direction)
-2 | Matrix (table of numbers)
-3 | 3-Tensor (cube of numbers)
-n | n-Tensor (you get the idea)
-
-
-### Rank 0
-
-The following snippet demonstrates creating a few rank 0 variables:
-
-```python
-mammal = tf.Variable("Elephant", tf.string)
-ignition = tf.Variable(451, tf.int16)
-floating = tf.Variable(3.14159265359, tf.float64)
-its_complicated = tf.Variable(12.3 - 4.85j, tf.complex64)
-```
-
-Note: A string is treated as a single item in TensorFlow, not as a sequence of
-characters. It is possible to have scalar strings, vectors of strings, etc.
-
-### Rank 1
-
-To create a rank 1 `tf.Tensor` object, you can pass a list of items as the
-initial value. For example:
-
-```python
-mystr = tf.Variable(["Hello"], tf.string)
-cool_numbers = tf.Variable([3.14159, 2.71828], tf.float32)
-first_primes = tf.Variable([2, 3, 5, 7, 11], tf.int32)
-its_very_complicated = tf.Variable([12.3 - 4.85j, 7.5 - 6.23j], tf.complex64)
-```
-
-
-### Higher ranks
-
-A rank 2 `tf.Tensor` object consists of at least one row and at least
-one column:
-
-```python
-mymat = tf.Variable([[7],[11]], tf.int16)
-myxor = tf.Variable([[False, True],[True, False]], tf.bool)
-linear_squares = tf.Variable([[4], [9], [16], [25]], tf.int32)
-squarish_squares = tf.Variable([ [4, 9], [16, 25] ], tf.int32)
-rank_of_squares = tf.rank(squarish_squares)
-mymatC = tf.Variable([[7],[11]], tf.int32)
-```
-
-Higher-rank Tensors, similarly, consist of an n-dimensional array. For example,
-during image processing, many tensors of rank 4 are used, with dimensions
-corresponding to example-in-batch, image width, image height, and color channel.
-
-``` python
-my_image = tf.zeros([10, 299, 299, 3]) # batch x height x width x color
-```
-
-### Getting a `tf.Tensor` object's rank
-
-To determine the rank of a `tf.Tensor` object, call the `tf.rank` method.
-For example, the following method programmatically determines the rank
-of the `tf.Tensor` defined in the previous section:
-
-```python
-r = tf.rank(my_image)
-# After the graph runs, r will hold the value 4.
-```
-
-### Referring to `tf.Tensor` slices
-
-Since a `tf.Tensor` is an n-dimensional array of cells, to access a single cell
-in a `tf.Tensor` you need to specify n indices.
-
-For a rank 0 tensor (a scalar), no indices are necessary, since it is already a
-single number.
-
-For a rank 1 tensor (a vector), passing a single index allows you to access a
-number:
-
-```python
-my_scalar = my_vector[2]
-```
-
-Note that the index passed inside the `[]` can itself be a scalar `tf.Tensor`, if
-you want to dynamically choose an element from the vector.
-
-For tensors of rank 2 or higher, the situation is more interesting. For a
-`tf.Tensor` of rank 2, passing two numbers returns a scalar, as expected:
-
-
-```python
-my_scalar = my_matrix[1, 2]
-```
-
-
-Passing a single number, however, returns a subvector of a matrix, as follows:
-
-
-```python
-my_row_vector = my_matrix[2]
-my_column_vector = my_matrix[:, 3]
-```
-
-The `:` notation is python slicing syntax for "leave this dimension alone". This
-is useful in higher-rank Tensors, as it allows you to access its subvectors,
-submatrices, and even other subtensors.
-
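-As a brief sketch of slicing on a higher-rank tensor (the shapes in the
-comments are illustrative):
-
-```python
-import tensorflow as tf
-
-my_tensor = tf.zeros([4, 3, 2])  # rank 3
-sub_matrix = my_tensor[0]        # rank 2, shape [3, 2]
-sub_vector = my_tensor[0, :, 1]  # rank 1, shape [3]
-```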
-
-## Shape
-
-The **shape** of a tensor is the number of elements in each dimension.
-TensorFlow automatically infers shapes during graph construction. These inferred
-shapes might have known or unknown rank. If the rank is known, the sizes of each
-dimension might be known or unknown.
-
-The TensorFlow documentation uses three notational conventions to describe
-tensor dimensionality: rank, shape, and dimension number. The following table
-shows how these relate to one another:
-
-Rank | Shape | Dimension number | Example
---- | --- | --- | ---
-0 | [] | 0-D | A 0-D tensor. A scalar.
-1 | [D0] | 1-D | A 1-D tensor with shape [5].
-2 | [D0, D1] | 2-D | A 2-D tensor with shape [3, 4].
-3 | [D0, D1, D2] | 3-D | A 3-D tensor with shape [1, 4, 3].
-n | [D0, D1, ... Dn-1] | n-D | A tensor with shape [D0, D1, ... Dn-1].
-
-Shapes can be represented via Python lists / tuples of ints, or with
-`tf.TensorShape` objects.
-
-### Getting a `tf.Tensor` object's shape
-
-There are two ways of accessing the shape of a `tf.Tensor`. While building the
-graph, it is often useful to ask what is already known about a tensor's
-shape. This can be done by reading the `shape` property of a `tf.Tensor` object.
-This method returns a `TensorShape` object, which is a convenient way of
-representing partially-specified shapes (since, when building the graph, not all
-shapes will be fully known).
-
-It is also possible to get a `tf.Tensor` that will represent the fully-defined
-shape of another `tf.Tensor` at runtime. This is done by calling the `tf.shape`
-operation. This way, you can build a graph that manipulates the shapes of
-tensors by building other tensors that depend on the dynamic shape of the input
-`tf.Tensor`.
-
-For example, here is how to make a vector of zeros with the same size as the
-number of columns in a given matrix:
-
-``` python
-zeros = tf.zeros(my_matrix.shape[1])
-```
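-
-A dynamic variant of the same idea, as a sketch: `tf.shape` evaluates the
-shape at graph execution time, so it works even when the static shape of
-`my_matrix` is unknown.
-
-``` python
-zeros = tf.zeros([tf.shape(my_matrix)[1]])
-```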
-
-### Changing the shape of a `tf.Tensor`
-
-The **number of elements** of a tensor is the product of the sizes of all its
-dimensions. The number of elements of a scalar is always `1`. Since there are
-often
-many different shapes that have the same number of elements, it's often
-convenient to be able to change the shape of a `tf.Tensor`, keeping its elements
-fixed. This can be done with `tf.reshape`.
-
-The following examples demonstrate how to reshape tensors:
-
-```python
-rank_three_tensor = tf.ones([3, 4, 5])
-matrix = tf.reshape(rank_three_tensor, [6, 10]) # Reshape existing content into
- # a 6x10 matrix
-matrixB = tf.reshape(matrix, [3, -1]) # Reshape existing content into a 3x20
- # matrix. -1 tells reshape to calculate
- # the size of this dimension.
-matrixAlt = tf.reshape(matrixB, [4, 3, -1]) # Reshape existing content into a
-                                            # 4x3x5 tensor
-
-# Note that the number of elements of the reshaped Tensors has to match the
-# original number of elements. Therefore, the following example generates an
-# error because no possible value for the last dimension will match the number
-# of elements.
-yet_another = tf.reshape(matrixAlt, [13, 2, -1]) # ERROR!
-```
-
-## Data types
-
-In addition to dimensionality, Tensors have a data type. Refer to the
-`tf.DType` page for a complete list of the data types.
-
-It is not possible to have a `tf.Tensor` with more than one data type. It is
-possible, however, to serialize arbitrary data structures as `string`s and store
-those in `tf.Tensor`s.
-
-It is possible to cast `tf.Tensor`s from one datatype to another using
-`tf.cast`:
-
-``` python
-# Cast a constant integer tensor into floating point.
-float_tensor = tf.cast(tf.constant([1, 2, 3]), dtype=tf.float32)
-```
-
-To inspect a `tf.Tensor`'s data type use the `Tensor.dtype` property.
-
-When creating a `tf.Tensor` from a Python object you may optionally specify the
-datatype. If you don't, TensorFlow chooses a datatype that can represent your
-data. TensorFlow converts Python integers to `tf.int32` and Python floating
-point numbers to `tf.float32`. Otherwise TensorFlow uses the same rules numpy
-uses when converting to arrays.
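-
-A quick sketch of those defaults (the comments state the inferred dtypes):
-
-``` python
-import tensorflow as tf
-
-i = tf.constant(1)    # dtype=tf.int32
-f = tf.constant(1.0)  # dtype=tf.float32
-d = tf.constant(1.0, dtype=tf.float64)  # explicit override
-```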
-
-## Evaluating Tensors
-
-Once the computation graph has been built, you can run the computation that
-produces a particular `tf.Tensor` and fetch the value assigned to it. This is
-often useful for debugging as well as being required for much of TensorFlow to
-work.
-
-The simplest way to evaluate a Tensor is using the `Tensor.eval` method. For
-example:
-
-```python
-constant = tf.constant([1, 2, 3])
-tensor = constant * constant
-print(tensor.eval())
-```
-
-The `eval` method only works when a default `tf.Session` is active (see
-Graphs and Sessions for more information).
-
-`Tensor.eval` returns a numpy array with the same contents as the tensor.
-
-Sometimes it is not possible to evaluate a `tf.Tensor` with no context because
-its value might depend on dynamic information that is not available. For
-example, tensors that depend on `placeholder`s can't be evaluated without
-providing a value for the `placeholder`.
-
-``` python
-p = tf.placeholder(tf.float32)
-t = p + 1.0
-t.eval() # This will fail, since the placeholder did not get a value.
-t.eval(feed_dict={p:2.0}) # This will succeed because we're feeding a value
- # to the placeholder.
-```
-
-Note that it is possible to feed any `tf.Tensor`, not just placeholders.
-
-Other model constructs might make evaluating a `tf.Tensor`
-complicated. TensorFlow can't directly evaluate `tf.Tensor`s defined inside
-functions or inside control flow constructs. If a `tf.Tensor` depends on a value
-from a queue, evaluating the `tf.Tensor` will only work once something has been
-enqueued; otherwise, evaluating it will hang. When working with queues, remember
-to call `tf.train.start_queue_runners` before evaluating any `tf.Tensor`s.
-
-## Printing Tensors
-
-For debugging purposes you might want to print the value of a `tf.Tensor`. While
- @{$debugger$tfdbg} provides advanced debugging support, TensorFlow also has an
- operation to directly print the value of a `tf.Tensor`.
-
-Note that you rarely want to use the following pattern when printing a
-`tf.Tensor`:
-
-``` python
-t = <<some tensorflow operation>>
-print(t) # This will print the symbolic tensor when the graph is being built.
- # This tensor does not have a value in this context.
-```
-
-This code prints the `tf.Tensor` object (which represents deferred computation)
-and not its value. Instead, TensorFlow provides the `tf.Print` operation, which
-returns its first tensor argument unchanged while printing the set of
-`tf.Tensor`s it is passed as the second argument.
-
-To use `tf.Print` correctly, you must use its return value, as in the
-example below:
-
-``` python
-t = <<some tensorflow operation>>
-tf.Print(t, [t]) # This does nothing
-t = tf.Print(t, [t]) # Here we are using the value returned by tf.Print
-result = t + 1 # Now when result is evaluated the value of `t` will be printed.
-```
-
-When you evaluate `result` you will evaluate everything `result` depends
-upon. Since `result` depends upon `t`, and evaluating `t` has the side effect of
-printing its input (the old value of `t`), `t` gets printed.
-
diff --git a/tensorflow/docs_src/guide/using_gpu.md b/tensorflow/docs_src/guide/using_gpu.md
deleted file mode 100644
index c0218fd12e..0000000000
--- a/tensorflow/docs_src/guide/using_gpu.md
+++ /dev/null
@@ -1,215 +0,0 @@
-# Using GPUs
-
-## Supported devices
-
-On a typical system, there are multiple computing devices. In TensorFlow, the
-supported device types are `CPU` and `GPU`. They are represented as `strings`.
-For example:
-
-* `"/cpu:0"`: The CPU of your machine.
-* `"/device:GPU:0"`: The GPU of your machine, if you have one.
-* `"/device:GPU:1"`: The second GPU of your machine, etc.
-
-If a TensorFlow operation has both CPU and GPU implementations, the GPU devices
-will be given priority when the operation is assigned to a device. For example,
-`matmul` has both CPU and GPU kernels. On a system with devices `cpu:0` and
-`gpu:0`, `gpu:0` will be selected to run `matmul`.
-
-## Logging Device placement
-
-To find out which devices your operations and tensors are assigned to, create
-the session with `log_device_placement` configuration option set to `True`.
-
-```python
-# Creates a graph.
-a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
-b = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2], name='b')
-c = tf.matmul(a, b)
-# Creates a session with log_device_placement set to True.
-sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
-# Runs the op.
-print(sess.run(c))
-```
-
-You should see the following output:
-
-```
-Device mapping:
-/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: Tesla K40c, pci bus
-id: 0000:05:00.0
-b: /job:localhost/replica:0/task:0/device:GPU:0
-a: /job:localhost/replica:0/task:0/device:GPU:0
-MatMul: /job:localhost/replica:0/task:0/device:GPU:0
-[[ 22. 28.]
- [ 49. 64.]]
-
-```
-
-## Manual device placement
-
-If you would like a particular operation to run on a device of your choice
-instead of what's automatically selected for you, you can use `with tf.device`
-to create a device context such that all the operations within that context will
-have the same device assignment.
-
-```python
-# Creates a graph.
-with tf.device('/cpu:0'):
- a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
- b = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2], name='b')
-c = tf.matmul(a, b)
-# Creates a session with log_device_placement set to True.
-sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
-# Runs the op.
-print(sess.run(c))
-```
-
-You will see that now `a` and `b` are assigned to `cpu:0`. Since a device was
-not explicitly specified for the `MatMul` operation, the TensorFlow runtime will
-choose one based on the operation and available devices (`gpu:0` in this
-example) and automatically copy tensors between devices if required.
-
-```
-Device mapping:
-/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: Tesla K40c, pci bus
-id: 0000:05:00.0
-b: /job:localhost/replica:0/task:0/cpu:0
-a: /job:localhost/replica:0/task:0/cpu:0
-MatMul: /job:localhost/replica:0/task:0/device:GPU:0
-[[ 22. 28.]
- [ 49. 64.]]
-```
-
-## Allowing GPU memory growth
-
-By default, TensorFlow maps nearly all of the GPU memory of all GPUs (subject to
-[`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars))
-visible to the process. This is done to more efficiently use the relatively
-precious GPU memory resources on the devices by reducing [memory
-fragmentation](https://en.wikipedia.org/wiki/Fragmentation_\(computing\)).
-
-In some cases it is desirable for the process to only allocate a subset of the
-available memory, or to only grow the memory usage as is needed by the process.
-TensorFlow provides two Config options on the Session to control this.
-
-The first is the `allow_growth` option, which attempts to allocate only as
-much GPU memory as is needed by runtime allocations: it starts out allocating
-very little memory, and as Sessions get run and more GPU memory is needed, the
-GPU memory region used by the TensorFlow process is extended. Note that memory
-is not released, since that can lead to even worse memory fragmentation. To
-turn this option on, set the option in the ConfigProto by:
-
-```python
-config = tf.ConfigProto()
-config.gpu_options.allow_growth = True
-session = tf.Session(config=config, ...)
-```
-
-The second method is the `per_process_gpu_memory_fraction` option, which
-determines the fraction of the overall amount of memory that each visible GPU
-should be allocated. For example, you can tell TensorFlow to only allocate 40%
-of the total memory of each GPU by:
-
-```python
-config = tf.ConfigProto()
-config.gpu_options.per_process_gpu_memory_fraction = 0.4
-session = tf.Session(config=config, ...)
-```
-
-This is useful if you want to truly bound the amount of GPU memory available to
-the TensorFlow process.
-
-## Using a single GPU on a multi-GPU system
-
-If you have more than one GPU in your system, the GPU with the lowest ID will be
-selected by default. If you would like to run on a different GPU, you will need
-to specify the preference explicitly:
-
-```python
-# Creates a graph.
-with tf.device('/device:GPU:2'):
- a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
- b = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2], name='b')
- c = tf.matmul(a, b)
-# Creates a session with log_device_placement set to True.
-sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
-# Runs the op.
-print(sess.run(c))
-```
-
-If the device you have specified does not exist, you will get
-`InvalidArgumentError`:
-
-```
-InvalidArgumentError: Invalid argument: Cannot assign a device to node 'b':
-Could not satisfy explicit device specification '/device:GPU:2'
- [[{{node b}} = Const[dtype=DT_FLOAT, value=Tensor<type: float shape: [3,2]
- values: 1 2 3...>, _device="/device:GPU:2"]()]]
-```
-
-If you would like TensorFlow to automatically choose an existing and supported
-device to run the operations in case the specified one doesn't exist, you can
-set `allow_soft_placement` to `True` in the configuration option when creating
-the session.
-
-```python
-# Creates a graph.
-with tf.device('/device:GPU:2'):
- a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
- b = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2], name='b')
- c = tf.matmul(a, b)
-# Creates a session with allow_soft_placement and log_device_placement set
-# to True.
-sess = tf.Session(config=tf.ConfigProto(
- allow_soft_placement=True, log_device_placement=True))
-# Runs the op.
-print(sess.run(c))
-```
-
-## Using multiple GPUs
-
-If you would like to run TensorFlow on multiple GPUs, you can construct your
-model in a multi-tower fashion where each tower is assigned to a different GPU.
-For example:
-
-``` python
-# Creates a graph.
-c = []
-for d in ['/device:GPU:2', '/device:GPU:3']:
- with tf.device(d):
- a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3])
- b = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2])
- c.append(tf.matmul(a, b))
-with tf.device('/cpu:0'):
- sum = tf.add_n(c)
-# Creates a session with log_device_placement set to True.
-sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
-# Runs the op.
-print(sess.run(sum))
-```
-
-You will see the following output.
-
-```
-Device mapping:
-/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: Tesla K20m, pci bus
-id: 0000:02:00.0
-/job:localhost/replica:0/task:0/device:GPU:1 -> device: 1, name: Tesla K20m, pci bus
-id: 0000:03:00.0
-/job:localhost/replica:0/task:0/device:GPU:2 -> device: 2, name: Tesla K20m, pci bus
-id: 0000:83:00.0
-/job:localhost/replica:0/task:0/device:GPU:3 -> device: 3, name: Tesla K20m, pci bus
-id: 0000:84:00.0
-Const_3: /job:localhost/replica:0/task:0/device:GPU:3
-Const_2: /job:localhost/replica:0/task:0/device:GPU:3
-MatMul_1: /job:localhost/replica:0/task:0/device:GPU:3
-Const_1: /job:localhost/replica:0/task:0/device:GPU:2
-Const: /job:localhost/replica:0/task:0/device:GPU:2
-MatMul: /job:localhost/replica:0/task:0/device:GPU:2
-AddN: /job:localhost/replica:0/task:0/cpu:0
-[[ 44. 56.]
- [ 98. 128.]]
-```
-
-The @{$deep_cnn$cifar10 tutorial} is a good example
-demonstrating how to do training with multiple GPUs.
diff --git a/tensorflow/docs_src/guide/using_tpu.md b/tensorflow/docs_src/guide/using_tpu.md
deleted file mode 100644
index 90a663b75e..0000000000
--- a/tensorflow/docs_src/guide/using_tpu.md
+++ /dev/null
@@ -1,395 +0,0 @@
-# Using TPUs
-
-This document walks through the principal TensorFlow APIs necessary to make
-effective use of a [Cloud TPU](https://cloud.google.com/tpu/), and highlights
-the differences between regular TensorFlow usage, and usage on a TPU.
-
-This doc is aimed at users who:
-
-* Are familiar with TensorFlow's `Estimator` and `Dataset` APIs.
-* Have maybe [tried out a Cloud TPU](https://cloud.google.com/tpu/docs/quickstart)
-  using an existing model.
-* Have, perhaps, skimmed the code of an example TPU model
-  [[1]](https://github.com/tensorflow/models/blob/master/official/mnist/mnist_tpu.py)
-  [[2]](https://github.com/tensorflow/tpu/tree/master/models).
-* Are interested in porting an existing `Estimator` model to
-  run on Cloud TPUs.
-
-## TPUEstimator
-
-`tf.estimator.Estimator` is TensorFlow's model-level abstraction.
-Standard `Estimator`s can drive models on CPUs and GPUs. You must use
-`tf.contrib.tpu.TPUEstimator` to drive a model on TPUs.
-
-Refer to TensorFlow's Getting Started section for an introduction to the basics
-of using a @{$premade_estimators$pre-made `Estimator`}, and
-@{$custom_estimators$custom `Estimator`s}.
-
-The `TPUEstimator` class differs somewhat from the `Estimator` class.
-
-The simplest way to maintain a model that can be run either on CPU/GPU or on a
-Cloud TPU is to define the model's inference phase (from inputs to predictions)
-outside of the `model_fn`. Then maintain separate implementations of the
-`Estimator` setup and `model_fn`, both wrapping this inference step. For an
-example of this pattern, compare the `mnist.py` and `mnist_tpu.py`
-implementations in
-[tensorflow/models](https://github.com/tensorflow/models/tree/master/official/mnist).
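-
-A hedged sketch of this pattern (the names `inference`, `my_cpu_model_fn`,
-and `my_tpu_model_fn` are illustrative, not taken from those examples):
-
-``` python
-def inference(features):
-  """The shared inputs-to-predictions step."""
-  net = tf.layers.dense(features, 100, activation=tf.nn.relu)
-  return tf.layers.dense(net, 10)  # logits
-
-def my_cpu_model_fn(features, labels, mode):
-  logits = inference(features)
-  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
-  train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
-      loss, global_step=tf.train.get_global_step())
-  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
-
-def my_tpu_model_fn(features, labels, mode, params):
-  logits = inference(features)  # the same inference step
-  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
-  optimizer = tf.contrib.tpu.CrossShardOptimizer(
-      tf.train.GradientDescentOptimizer(0.01))
-  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
-  return tf.contrib.tpu.TPUEstimatorSpec(mode, loss=loss, train_op=train_op)
-```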
-
-### Running a `TPUEstimator` locally
-
-To create a standard `Estimator` you call the constructor, and pass it a
-`model_fn`, for example:
-
-```
-my_estimator = tf.estimator.Estimator(
- model_fn=my_model_fn)
-```
-
-The changes required to use a `tf.contrib.tpu.TPUEstimator` on your local
-machine are relatively minor. The constructor requires two additional arguments.
-You should set the `use_tpu` argument to `False`, and pass a
-`tf.contrib.tpu.RunConfig` as the `config` argument, as shown below:
-
-``` python
-my_tpu_estimator = tf.contrib.tpu.TPUEstimator(
-    model_fn=my_model_fn,
-    config=tf.contrib.tpu.RunConfig(),
-    use_tpu=False)
-```
-
-Just this simple change will allow you to run a `TPUEstimator` locally.
-The majority of example TPU models can be run in this local mode,
-by setting the command line flags as follows:
-
-
-```
-$> python mnist_tpu.py --use_tpu=false --master=''
-```
-
-Note: This `use_tpu=False` argument is useful for trying out the `TPUEstimator`
-API. It is not meant to be a complete TPU compatibility test. Successfully
-running a model locally in a `TPUEstimator` does not guarantee that it will
-work on a TPU.
-
-
-### Building a `tpu.RunConfig`
-
-While the default `RunConfig` is sufficient for local training, these settings
-cannot be ignored in real usage.
-
-A more typical setup for a `RunConfig`, that can be switched to use a Cloud
-TPU, might be as follows:
-
-``` python
-import tempfile
-import subprocess
-
-class FLAGS(object):
- use_tpu=False
- tpu_name=None
- # Use a local temporary path for the `model_dir`
- model_dir = tempfile.mkdtemp()
- # Number of training steps to run on the Cloud TPU before returning control.
- iterations = 50
- # A single Cloud TPU has 8 shards.
- num_shards = 8
-
-if FLAGS.use_tpu:
-    my_project_name = subprocess.check_output([
-        'gcloud','config','get-value','project'])
-    my_zone = subprocess.check_output([
-        'gcloud','config','get-value','compute/zone'])
-    cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
-        tpu_names=[FLAGS.tpu_name],
-        zone=my_zone,
-        project=my_project_name)
-    master = cluster_resolver.get_master()
-else:
-    master = ''
-
-my_tpu_run_config = tf.contrib.tpu.RunConfig(
- master=master,
- evaluation_master=master,
- model_dir=FLAGS.model_dir,
- session_config=tf.ConfigProto(
- allow_soft_placement=True, log_device_placement=True),
- tpu_config=tf.contrib.tpu.TPUConfig(FLAGS.iterations,
- FLAGS.num_shards),
-)
-```
-
-Then you must pass the `tf.contrib.tpu.RunConfig` to the constructor:
-
-``` python
-my_tpu_estimator = tf.contrib.tpu.TPUEstimator(
- model_fn=my_model_fn,
- config = my_tpu_run_config,
- use_tpu=FLAGS.use_tpu)
-```
-
-Typically the `FLAGS` would be set by command line arguments. To switch from
-training locally to training on a cloud TPU you would need to:
-
-* Set `FLAGS.use_tpu` to `True`
-* Set `FLAGS.tpu_name` so the `tf.contrib.cluster_resolver.TPUClusterResolver` can find it
-* Set `FLAGS.model_dir` to a Google Cloud Storage bucket url (`gs://`).
-
-
-## Optimizer
-
-When training on a cloud TPU you **must** wrap the optimizer in a
-`tf.contrib.tpu.CrossShardOptimizer`, which uses an `allreduce` to aggregate
-gradients and broadcast the result to each shard (each TPU core).
-
-The `CrossShardOptimizer` is not compatible with local training. So, to have
-the same code run both locally and on a Cloud TPU, add lines like the following:
-
-``` python
-optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
-if FLAGS.use_tpu:
- optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
-```
-
-If you prefer to avoid a global `FLAGS` variable in your model code, one
-approach is to set the optimizer as one of the `Estimator`'s params,
-as follows:
-
-``` python
-my_tpu_estimator = tf.contrib.tpu.TPUEstimator(
- model_fn=my_model_fn,
- config = my_tpu_run_config,
- use_tpu=FLAGS.use_tpu,
- params={'optimizer':optimizer})
-```
-
-## Model Function
-
-This section details the changes you must make to the model function
-(`model_fn()`) to make it `TPUEstimator` compatible.
-
-### Static shapes
-
-During regular usage TensorFlow attempts to determine the shapes of each
-`tf.Tensor` during graph construction. During execution any unknown shape
-dimensions are determined dynamically,
-see @{$guide/tensors#shape$Tensor Shapes} for more details.
-
-To run on Cloud TPUs TensorFlow models are compiled using @{$xla$XLA}.
-XLA uses a similar system for determining shapes at compile time. XLA requires
-that all tensor dimensions be statically defined at compile time. All shapes
-must evaluate to a constant, and not depend on external data, or stateful
-operations like variables or a random number generator.
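-
-As a short illustration of the distinction (a sketch, not an exhaustive
-account of what XLA accepts):
-
-``` python
-static_ok = tf.zeros([8, 128])  # shape fixed at graph construction
-
-data = tf.constant([1.0, 2.0, 3.0])
-mask = data > 1.5
-dynamic_bad = tf.boolean_mask(data, mask)  # output shape depends on the data
-```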
-
-
-### Summaries
-
-Remove any use of `tf.summary` from your model.
-
-@{$summaries_and_tensorboard$TensorBoard summaries} are a great way to see
-inside your model. A minimal set of basic summaries is automatically recorded
-by the `TPUEstimator`, to `event` files in the `model_dir`. Custom summaries,
-however,
-are currently unsupported when training on a Cloud TPU. So while the
-`TPUEstimator` will still run locally with summaries, it will fail if used on a
-TPU.
-
-### Metrics
-
-Build your evaluation metrics dictionary in a stand-alone `metric_fn`.
-
-<!-- TODO(markdaoust) link to guide/metrics when it exists -->
-
-Evaluation metrics are an essential part of training a model. These are fully
-supported on Cloud TPUs, but with a slightly different syntax.
-
-A standard `tf.metrics` function returns a pair of tensors. The first is the
-running average of the metric value, while the second updates the running
-average and returns the value for this batch:
-
-```
-running_average, current_batch = tf.metrics.accuracy(labels, predictions)
-```
-
-In a standard `Estimator` you create a dictionary of these pairs, and return it
-as part of the `EstimatorSpec`.
-
-```python
-my_metrics = {'accuracy': tf.metrics.accuracy(labels, predictions)}
-
-return tf.estimator.EstimatorSpec(
- ...
- eval_metric_ops=my_metrics
-)
-```
-
-In a `TPUEstimator` you instead pass a function (which returns a metrics
-dictionary) and a list of argument tensors, as shown below:
-
-```python
-def my_metric_fn(labels, predictions):
- return {'accuracy': tf.metrics.accuracy(labels, predictions)}
-
-return tf.contrib.tpu.TPUEstimatorSpec(
- ...
- eval_metrics=(my_metric_fn, [labels, predictions])
-)
-```
-
-### Use `TPUEstimatorSpec`
-
-`TPUEstimatorSpec` does not support hooks, and requires function wrappers for
-some fields.
-
-An `Estimator`'s `model_fn` must return an `EstimatorSpec`. An `EstimatorSpec`
-is a simple structure of named fields containing all the `tf.Tensors` of the
-model that the `Estimator` may need to interact with.
-
-`TPUEstimators` use a `tf.contrib.tpu.TPUEstimatorSpec`. There are a few
-differences between it and a standard `tf.estimator.EstimatorSpec`:
-
-
-* The `eval_metric_ops` must be wrapped in a `metric_fn`; this field is
-  renamed `eval_metrics` ([see above](#metrics)).
-* The `tf.train.SessionRunHook` are unsupported, so these fields are
- omitted.
-* The `tf.train.Scaffold`, if used, must also be wrapped in a
- function. This field is renamed to `scaffold_fn`.
-
-`Scaffold` and `Hooks` are for advanced usage, and can typically be omitted.
-
-## Input functions
-
-Input functions work mostly unchanged, as they run on the host computer rather
-than on the Cloud TPU itself. This section explains the two necessary
-adjustments.
-
-### Params argument
-
-<!-- TODO(markdaoust) link to input_fn doc when it exists -->
-
-The `input_fn` for a standard `Estimator` _can_ include a
-`params` argument; the `input_fn` for a `TPUEstimator` *must* include a
-`params` argument. This is necessary to allow the estimator to set the batch
-size for each replica of the input stream. So the minimum signature for an
-`input_fn` for a `TPUEstimator` is:
-
-```
-def my_input_fn(params):
- pass
-```
-
-where `params['batch_size']` will contain the batch size.
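-
-A minimal sketch of such an `input_fn`, assuming TFRecord input and a
-hypothetical `parse_fn` that decodes each record (the file name is a
-placeholder):
-
-```
-def my_input_fn(params):
-  batch_size = params['batch_size']  # set by the TPUEstimator per replica
-  ds = tf.data.TFRecordDataset(['train.tfrecord'])
-  ds = ds.map(parse_fn).repeat()
-  # batch_and_drop_remainder keeps the batch shape static; see below.
-  ds = ds.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
-  return ds
-```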
-
-### Static shapes and batch size
-
-The input pipeline generated by your `input_fn` is run on CPU. So it is mostly
-free from the strict static shape requirements imposed by the XLA/TPU environment.
-The one requirement is that the batches of data fed from your input pipeline to
-the TPU have a static shape, as determined by the standard TensorFlow shape
-inference algorithm. Intermediate tensors are free to have dynamic shapes.
-If shape inference has failed, but the shape is known, it is possible to
-impose the correct shape using the `tf.Tensor.set_shape` method.
-
-In the example below the shape inference algorithm fails, but the shape is
-imposed correctly using `set_shape`:
-
-```
->>> x = tf.zeros(tf.constant([1,2,3])+1)
->>> x.shape
-
-TensorShape([Dimension(None), Dimension(None), Dimension(None)])
-
->>> x.set_shape([2,3,4])
-```
-
-In many cases the batch size is the only unknown dimension.
-
-A typical input pipeline, using `tf.data`, will usually produce batches of a
-fixed size. The last batch of a finite `Dataset`, however, is typically smaller,
-containing just the remaining elements. Since a `Dataset` does not know its own
-length or finiteness, the standard `tf.data.Dataset.batch` method cannot, on
-its own, determine whether all batches will have a fixed size:
-
-```
->>> params = {'batch_size':32}
->>> ds = tf.data.Dataset.from_tensors([0, 1, 2])
->>> ds = ds.repeat().batch(params['batch_size'])
->>> ds
-
-<BatchDataset shapes: (?, 3), types: tf.int32>
-```
-
-The most straightforward fix is to apply
-`tf.contrib.data.batch_and_drop_remainder` via `tf.data.Dataset.apply`,
-as follows:
-
-```
->>> params = {'batch_size':32}
->>> ds = tf.data.Dataset.from_tensors([0, 1, 2])
->>> ds = ds.repeat().apply(
-...    tf.contrib.data.batch_and_drop_remainder(params['batch_size']))
->>> ds
-
- <_RestructuredDataset shapes: (32, 3), types: tf.int32>
-```
-
-The one downside to this approach is that, as the name implies, this batching
-method throws out any fractional batch at the end of the dataset. This is fine
-for an infinitely repeating dataset being used for training, but could be a
-problem if you want to train for an exact number of epochs.
-
-To do an exact 1-epoch of _evaluation_ you can work around this by manually
-padding the length of the batches, and setting the padding entries to have zero
-weight when creating your `tf.metrics`.
-
-## Datasets
-
-Efficient use of the `tf.data.Dataset` API is critical when using a Cloud
-TPU, as it is impossible to make use of Cloud TPUs unless you can feed them
-data quickly enough. See @{$datasets_performance} for details on dataset
-performance.
-
-For all but the simplest experimentation (using
-`tf.data.Dataset.from_tensor_slices` or other in-graph data) you will need to
-store all data files read by the `TPUEstimator`'s `Dataset` in Google Cloud
-Storage Buckets.
-
-<!--TODO(markdaoust): link to the `TFRecord` doc when it exists.-->
-
-For most use-cases, we recommend converting your data into `TFRecord`
-format and using a `tf.data.TFRecordDataset` to read it. This, however, is not
-a hard requirement and you can use other dataset readers
-(`FixedLengthRecordDataset` or `TextLineDataset`) if you prefer.
-
-Small datasets can be loaded entirely into memory using
-`tf.data.Dataset.cache`.
-
-Regardless of the data format used, it is strongly recommended that you
-@{$performance_guide#use_large_files$use large files}, on the order of
-100MB. This is especially important in this networked setting as the overhead
-of opening a file is significantly higher.
-
-It is also important, regardless of the type of reader used, to enable buffering
-using the `buffer_size` argument to the constructor. This argument is specified
-in bytes. A minimum of a few MB (`buffer_size=8*1024*1024`) is recommended so
-that data is available when needed.
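-
-For example (a sketch; `filenames` is a placeholder for your file list):
-
-```
-ds = tf.data.TFRecordDataset(filenames, buffer_size=8 * 1024 * 1024)
-```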
-
-The TPU-demos repo includes
-[a script](https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py)
-for downloading the imagenet dataset and converting it to an appropriate format.
-This together with the imagenet
-[models](https://github.com/tensorflow/tpu/tree/master/models)
-included in the repo demonstrate all of these best-practices.
-
-
-## What Next
-
-For details on how to actually set up and run a Cloud TPU see:
-
- * [Google Cloud TPU Documentation](https://cloud.google.com/tpu/docs/)
-
-This document is by no means exhaustive. The best source of more detail on how
-to make a Cloud TPU compatible model are the example models published in:
-
- * The [TPU Demos Repository.](https://github.com/tensorflow/tpu)
-
-For more information about tuning TensorFlow code for performance see:
-
- * The @{$performance$Performance Section.}
-
diff --git a/tensorflow/docs_src/guide/variables.md b/tensorflow/docs_src/guide/variables.md
deleted file mode 100644
index 5d5d73394c..0000000000
--- a/tensorflow/docs_src/guide/variables.md
+++ /dev/null
@@ -1,319 +0,0 @@
-# Variables
-
-A TensorFlow **variable** is the best way to represent shared, persistent state
-manipulated by your program.
-
-Variables are manipulated via the `tf.Variable` class. A `tf.Variable`
-represents a tensor whose value can be changed by running ops on it. Unlike
-`tf.Tensor` objects, a `tf.Variable` exists outside the context of a single
-`session.run` call.
-
-Internally, a `tf.Variable` stores a persistent tensor. Specific ops allow you
-to read and modify the values of this tensor. These modifications are visible
-across multiple `tf.Session`s, so multiple workers can see the same values for a
-`tf.Variable`.
-
-## Creating a Variable
-
-The best way to create a variable is to call the `tf.get_variable`
-function. This function requires you to specify the Variable's name. This name
-will be used by other replicas to access the same variable, as well as to name
-this variable's value when checkpointing and exporting models. `tf.get_variable`
-also allows you to reuse a previously created variable of the same name, making it
-easy to define models which reuse layers.
-
-To create a variable with `tf.get_variable`, simply provide the name and shape:
-
-``` python
-my_variable = tf.get_variable("my_variable", [1, 2, 3])
-```
-
-This creates a variable named "my_variable" which is a three-dimensional tensor
-with shape `[1, 2, 3]`. This variable will, by default, have the `dtype`
-`tf.float32` and its initial value will be randomized via
-`tf.glorot_uniform_initializer`.
-
-You may optionally specify the `dtype` and initializer to `tf.get_variable`. For
-example:
-
-``` python
-my_int_variable = tf.get_variable("my_int_variable", [1, 2, 3], dtype=tf.int32,
- initializer=tf.zeros_initializer)
-```
-
-TensorFlow provides many convenient initializers. Alternatively, you may
-initialize a `tf.Variable` to have the value of a `tf.Tensor`. For example:
-
-``` python
-other_variable = tf.get_variable("other_variable", dtype=tf.int32,
- initializer=tf.constant([23, 42]))
-```
-
-Note that when the initializer is a `tf.Tensor` you should not specify the
-variable's shape, as the shape of the initializer tensor will be used.
-
-
-<a name="collections"></a>
-### Variable collections
-
-Because disconnected parts of a TensorFlow program might want to create
-variables, it is sometimes useful to have a single way to access all of
-them. For this reason TensorFlow provides **collections**, which are named lists
-of tensors or other objects, such as `tf.Variable` instances.
-
-By default every `tf.Variable` gets placed in the following two collections:
-
- * `tf.GraphKeys.GLOBAL_VARIABLES` --- variables that can be shared across
- multiple devices,
- * `tf.GraphKeys.TRAINABLE_VARIABLES` --- variables for which TensorFlow will
- calculate gradients.
-
-If you don't want a variable to be trainable, add it to the
-`tf.GraphKeys.LOCAL_VARIABLES` collection instead. For example, the following
-snippet demonstrates how to add a variable named `my_local` to this collection:
-
-``` python
-my_local = tf.get_variable("my_local", shape=(),
-                           collections=[tf.GraphKeys.LOCAL_VARIABLES])
-```
-
-Alternatively, you can specify `trainable=False` as an argument to
-`tf.get_variable`:
-
-``` python
-my_non_trainable = tf.get_variable("my_non_trainable",
- shape=(),
- trainable=False)
-```
-
-
-You can also use your own collections. Any string is a valid collection name,
-and there is no need to explicitly create a collection. To add a variable (or
-any other object) to a collection after creating the variable, call
-`tf.add_to_collection`. For example, the following code adds an existing
-variable named `my_local` to a collection named `my_collection_name`:
-
-``` python
-tf.add_to_collection("my_collection_name", my_local)
-```
-
-And to retrieve a list of all the variables (or other objects) you've placed in
-a collection you can use:
-
-``` python
-tf.get_collection("my_collection_name")
-```
-
-### Device placement
-
-Just like any other TensorFlow operation, you can place variables on particular
-devices. For example, the following snippet creates a variable named `v` and
-places it on the second GPU device:
-
-``` python
-with tf.device("/device:GPU:1"):
- v = tf.get_variable("v", [1])
-```
-
-It is particularly important for variables to be placed on the correct device
-in distributed settings. Accidentally putting variables on workers instead of
-parameter servers, for example, can severely slow down training or, in the worst
-case, let each worker blithely forge ahead with its own independent copy of each
-variable. For this reason we provide `tf.train.replica_device_setter`, which
-can automatically place variables in parameter servers. For example:
-
-``` python
-cluster_spec = {
- "ps": ["ps0:2222", "ps1:2222"],
- "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]}
-with tf.device(tf.train.replica_device_setter(cluster=cluster_spec)):
- v = tf.get_variable("v", shape=[20, 20]) # this variable is placed
- # in the parameter server
- # by the replica_device_setter
-```
-
-## Initializing variables
-
-Before you can use a variable, it must be initialized. If you are programming in
-the low-level TensorFlow API (that is, you are explicitly creating your own
-graphs and sessions), you must explicitly initialize the variables. Most
-high-level frameworks such as `tf.contrib.slim`, `tf.estimator.Estimator` and
-`Keras` automatically initialize variables for you before training a model.
-
-Explicit initialization is otherwise useful because it lets you avoid rerunning
-potentially expensive initializers when reloading a model from a checkpoint,
-and because it provides determinism when randomly initialized variables are
-shared in a distributed setting.
-
-To initialize all variables in one go, before training starts, call
-`tf.global_variables_initializer()`. This function returns a single operation
-responsible for initializing all variables in the
-`tf.GraphKeys.GLOBAL_VARIABLES` collection. Running this operation initializes
-all such variables. For example:
-
-``` python
-session.run(tf.global_variables_initializer())
-# Now all variables are initialized.
-```
-
-If you do need to initialize variables yourself, you can run the variable's
-initializer operation. For example:
-
-``` python
-session.run(my_variable.initializer)
-```
-
-
-You can also ask which variables have not yet been initialized. For example,
-the following code prints the names of all variables which have not yet been
-initialized:
-
-``` python
-print(session.run(tf.report_uninitialized_variables()))
-```
-
-
-Note that by default `tf.global_variables_initializer` does not specify the
-order in which variables are initialized. Therefore, if the initial value of a
-variable depends on another variable's value, it's likely that you'll get an
-error. Any time you use the value of a variable in a context in which not all
-variables are initialized (say, if you use a variable's value while initializing
-another variable), it is best to use `variable.initialized_value()` instead of
-`variable`:
-
-``` python
-v = tf.get_variable("v", shape=(), initializer=tf.zeros_initializer())
-w = tf.get_variable("w", initializer=v.initialized_value() + 1)
-```
-
-## Using variables
-
-To use the value of a `tf.Variable` in a TensorFlow graph, simply treat it like
-a normal `tf.Tensor`:
-
-``` python
-v = tf.get_variable("v", shape=(), initializer=tf.zeros_initializer())
-w = v + 1  # w is a tf.Tensor which is computed based on the value of v.
-           # Any time a variable is used in an expression it gets automatically
-           # converted to a tf.Tensor representing its value.
-```
-
-To assign a value to a variable, use the methods `assign`, `assign_add`, and
-friends in the `tf.Variable` class. For example, here is how you can call these
-methods:
-
-``` python
-v = tf.get_variable("v", shape=(), initializer=tf.zeros_initializer())
-assignment = v.assign_add(1)
-tf.global_variables_initializer().run()
-sess.run(assignment) # or assignment.op.run(), or assignment.eval()
-```
-
-Most TensorFlow optimizers have specialized ops that efficiently update the
-values of variables according to some gradient descent-like algorithm. See
-`tf.train.Optimizer` for an explanation of how to use optimizers.
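-
-For illustration, a minimal sketch (the variable, loss, and learning rate are
-illustrative only) of wiring a variable into an optimizer's update op:
-
-``` python
-v = tf.get_variable("v", shape=(), initializer=tf.zeros_initializer())
-loss = tf.square(v - 5.0)
-train_op = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss)
-# Each sess.run(train_op) applies one gradient-descent update to v.
-```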
-
-Because variables are mutable it's sometimes useful to know what version of a
-variable's value is being used at any point in time. To force a re-read of the
-value of a variable after something has happened, you can use
-`tf.Variable.read_value`. For example:
-
-``` python
-v = tf.get_variable("v", shape=(), initializer=tf.zeros_initializer())
-assignment = v.assign_add(1)
-with tf.control_dependencies([assignment]):
-  w = v.read_value()  # w is guaranteed to reflect v's value after the
-                      # assign_add operation.
-```
-
-
-## Sharing variables
-
-TensorFlow supports two ways of sharing variables:
-
- * Explicitly passing `tf.Variable` objects around.
- * Implicitly wrapping `tf.Variable` objects within `tf.variable_scope` objects.
-
-While code which explicitly passes variables around is very clear, it is
-sometimes convenient to write TensorFlow functions that implicitly use
-variables in their implementations. Most of the functional layers from
-`tf.layers` use this approach, as well as all `tf.metrics`, and a few other
-library utilities.
-
-Variable scopes allow you to control variable reuse when calling functions which
-implicitly create and use variables. They also allow you to name your variables
-in a hierarchical and understandable way.
-
-For example, let's say we write a function to create a convolutional / relu
-layer:
-
-```python
-def conv_relu(input, kernel_shape, bias_shape):
-  # Create variable named "weights".
-  weights = tf.get_variable("weights", kernel_shape,
-                            initializer=tf.random_normal_initializer())
-  # Create variable named "biases".
-  biases = tf.get_variable("biases", bias_shape,
-                           initializer=tf.constant_initializer(0.0))
-  conv = tf.nn.conv2d(input, weights,
-                      strides=[1, 1, 1, 1], padding='SAME')
-  return tf.nn.relu(conv + biases)
-```
-
-This function uses short names `weights` and `biases`, which is good for
-clarity. In a real model, however, we want many such convolutional layers, and
-calling this function repeatedly would not work:
-
-``` python
-input1 = tf.random_normal([1, 10, 10, 32])
-input2 = tf.random_normal([1, 20, 20, 32])
-x = conv_relu(input1, kernel_shape=[5, 5, 32, 32], bias_shape=[32])
-x = conv_relu(x, kernel_shape=[5, 5, 32, 32], bias_shape=[32])  # This fails.
-```
-
-Since the desired behavior is unclear (create new variables or reuse the
-existing ones?), TensorFlow will fail. Calling `conv_relu` in different scopes,
-however, clarifies that we want to create new variables:
-
-```python
-def my_image_filter(input_images):
-  with tf.variable_scope("conv1"):
-    # Variables created here will be named "conv1/weights", "conv1/biases".
-    relu1 = conv_relu(input_images, [5, 5, 32, 32], [32])
-  with tf.variable_scope("conv2"):
-    # Variables created here will be named "conv2/weights", "conv2/biases".
-    return conv_relu(relu1, [5, 5, 32, 32], [32])
-```
-
-If you do want the variables to be shared, you have two options. First, you can
-create a scope with the same name using `reuse=True`:
-
-``` python
-with tf.variable_scope("model"):
- output1 = my_image_filter(input1)
-with tf.variable_scope("model", reuse=True):
- output2 = my_image_filter(input2)
-
-```
-
-You can also call `scope.reuse_variables()` to trigger a reuse:
-
-``` python
-with tf.variable_scope("model") as scope:
- output1 = my_image_filter(input1)
- scope.reuse_variables()
- output2 = my_image_filter(input2)
-
-```
-
-Since depending on exact string names of scopes can feel dangerous, it's also
-possible to initialize a variable scope based on another one:
-
-``` python
-with tf.variable_scope("model") as scope:
- output1 = my_image_filter(input1)
-with tf.variable_scope(scope, reuse=True):
- output2 = my_image_filter(input2)
-
-```
-
diff --git a/tensorflow/docs_src/guide/version_compat.md b/tensorflow/docs_src/guide/version_compat.md
deleted file mode 100644
index 29ac066e6f..0000000000
--- a/tensorflow/docs_src/guide/version_compat.md
+++ /dev/null
@@ -1,322 +0,0 @@
-# TensorFlow Version Compatibility
-
-This document is for users who need backwards compatibility across different
-versions of TensorFlow (either for code or data), and for developers who want
-to modify TensorFlow while preserving compatibility.
-
-## Semantic Versioning 2.0
-
-TensorFlow follows Semantic Versioning 2.0 ([semver](http://semver.org)) for its
-public API. Each release version of TensorFlow has the form `MAJOR.MINOR.PATCH`.
-For example, TensorFlow version 1.2.3 has `MAJOR` version 1, `MINOR` version 2,
-and `PATCH` version 3. Changes to each number have the following meaning:
-
-* **MAJOR**: Potentially backwards incompatible changes. Code and data that
- worked with a previous major release will not necessarily work with the new
- release. However, in some cases existing TensorFlow graphs and checkpoints
- may be migratable to the newer release; see
- [Compatibility of graphs and checkpoints](#compatibility_of_graphs_and_checkpoints)
- for details on data compatibility.
-
-* **MINOR**: Backwards compatible features, speed improvements, etc. Code and
-  data that worked with a previous minor release *and* that depend only on the
-  public API will continue to work unchanged. For details on what is and is
- not the public API, see [What is covered](#what_is_covered).
-
-* **PATCH**: Backwards compatible bug fixes.
-
-For example, release 1.0.0 introduced backwards *incompatible* changes from
-release 0.12.1. However, release 1.1.1 was backwards *compatible* with release
-1.0.0.
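-
-You can check which release you are running from the version string exposed at
-runtime (a minimal sketch):
-
-```python
-import tensorflow as tf
-print(tf.__version__)  # for example: '1.10.0'
-```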
-
-## What is covered
-
-Only the public APIs of TensorFlow are backwards compatible across minor and
-patch versions. The public APIs consist of
-
-* All the documented [Python](../api_docs/python) functions and classes in the
- `tensorflow` module and its submodules, except for
- * functions and classes in `tf.contrib`
- * functions and classes whose names start with `_` (as these are private)
- Note that the code in the `examples/` and `tools/` directories is not
- reachable through the `tensorflow` Python module and is thus not covered by
- the compatibility guarantee.
-
- If a symbol is available through the `tensorflow` Python module or its
- submodules, but is not documented, then it is **not** considered part of the
- public API.
-
-* The [C API](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.h).
-
-* The following protocol buffer files:
- * [`attr_value`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/attr_value.proto)
- * [`config`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/config.proto)
- * [`event`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/event.proto)
- * [`graph`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto)
- * [`op_def`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/op_def.proto)
- * [`reader_base`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/reader_base.proto)
- * [`summary`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/summary.proto)
- * [`tensor`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/tensor.proto)
- * [`tensor_shape`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/tensor_shape.proto)
- * [`types`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/types.proto)
-
-<a name="not_covered"></a>
-## What is *not* covered
-
-Some API functions are explicitly marked as "experimental" and can change in
-backward incompatible ways between minor releases. These include:
-
-* **Experimental APIs**: The `tf.contrib` module and its submodules in Python
-  and any functions in the C API or fields in protocol buffers that are
-  explicitly commented as being experimental. In particular, any field in a
-  protocol buffer that is called "experimental", along with all of its fields
-  and submessages, can change at any time.
-
-* **Other languages**: TensorFlow APIs in languages other than Python and C,
- such as:
-
- - @{$cc/guide$C++} (exposed through header files in
- [`tensorflow/cc`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/cc)).
- - [Java](../api_docs/java/reference/org/tensorflow/package-summary),
- - [Go](https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go)
- - [JavaScript](https://js.tensorflow.org)
-
-* **Details of composite ops:** Many public functions in Python expand to
-  several primitive ops in the graph, and these details will be part of any
-  graphs saved to disk as `GraphDef`s. These details may change for
-  minor releases. In particular, regression tests that check for exact
-  matching between graphs are likely to break across minor releases, even
-  though the behavior of the graph should be unchanged and existing
-  checkpoints will still work.
-
-* **Floating point numerical details:** The specific floating point values
- computed by ops may change at any time. Users should rely only on
- approximate accuracy and numerical stability, not on the specific bits
- computed. Changes to numerical formulas in minor and patch releases should
- result in comparable or improved accuracy, with the caveat that in machine
- learning improved accuracy of specific formulas may result in decreased
- accuracy for the overall system.
-
-* **Random numbers:** The specific random numbers computed by the
- @{$python/constant_op#Random_Tensors$random ops} may change at any time.
- Users should rely only on approximately correct distributions and
- statistical strength, not the specific bits computed. However, we will make
- changes to random bits rarely (or perhaps never) for patch releases. We
- will, of course, document all such changes.
-
-* **Version skew in distributed TensorFlow:** Running two different versions
- of TensorFlow in a single cluster is unsupported. There are no guarantees
- about backwards compatibility of the wire protocol.
-
-* **Bugs:** We reserve the right to make backwards incompatible behavior
- (though not API) changes if the current implementation is clearly broken,
- that is, if it contradicts the documentation or if a well-known and
- well-defined intended behavior is not properly implemented due to a bug.
- For example, if an optimizer claims to implement a well-known optimization
- algorithm but does not match that algorithm due to a bug, then we will fix
- the optimizer. Our fix may break code relying on the wrong behavior for
- convergence. We will note such changes in the release notes.
-
-* **Error messages:** We reserve the right to change the text of error
- messages. In addition, the type of an error may change unless the type is
- specified in the documentation. For example, a function documented to
- raise an `InvalidArgument` exception will continue to
- raise `InvalidArgument`, but the human-readable message contents can change.
-
-## Compatibility of graphs and checkpoints
-
-You'll sometimes need to preserve graphs and checkpoints.
-Graphs describe the data flow of ops to be run during training and
-inference, and checkpoints contain the saved tensor values of variables in a
-graph.
-
-Many TensorFlow users save graphs and trained models to disk for
-later evaluation or additional training, but end up running their saved graphs
-or models on a later release. In compliance with semver, any graph or checkpoint
-written out with one version of TensorFlow can be loaded and evaluated with a
-later version of TensorFlow with the same major release. However, we will
-endeavor to preserve backwards compatibility even across major releases when
-possible, so that the serialized files are usable over long periods of time.
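-
-As a minimal sketch (the variable and checkpoint path are illustrative), a
-checkpoint written by one 1.x release should restore under a later 1.x release:
-
-```python
-v = tf.get_variable("v", shape=(), initializer=tf.zeros_initializer())
-saver = tf.train.Saver()
-with tf.Session() as sess:
-  sess.run(tf.global_variables_initializer())
-  saver.save(sess, "/tmp/model.ckpt")     # written with one 1.x version
-  saver.restore(sess, "/tmp/model.ckpt")  # loadable by a later 1.x version
-```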
-
-
-Graphs are serialized via the `GraphDef` protocol buffer. To facilitate (rare)
-backwards incompatible changes to graphs, each `GraphDef` has a version number
-separate from the TensorFlow version. For example, `GraphDef` version 17
-deprecated the `inv` op in favor of `reciprocal`. The semantics are:
-
-* Each version of TensorFlow supports an interval of `GraphDef` versions. This
- interval will be constant across patch releases, and will only grow across
- minor releases. Dropping support for a `GraphDef` version will only occur
- for a major release of TensorFlow.
-
-* Newly created graphs are assigned the latest `GraphDef` version number.
-
-* If a given version of TensorFlow supports the `GraphDef` version of a graph,
- it will load and evaluate with the same behavior as the TensorFlow version
- used to generate it (except for floating point numerical details and random
- numbers), regardless of the major version of TensorFlow. In particular, all
- checkpoint files will be compatible.
-
-* If the `GraphDef` *upper* bound is increased to X in a (minor) release, there
- will be at least six months before the *lower* bound is increased to X. For
- example (we're using hypothetical version numbers here):
- * TensorFlow 1.2 might support `GraphDef` versions 4 to 7.
- * TensorFlow 1.3 could add `GraphDef` version 8 and support versions 4 to 8.
- * At least six months later, TensorFlow 2.0.0 could drop support for
- versions 4 to 7, leaving version 8 only.
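-
-You can inspect the `GraphDef` version recorded in a graph you build (a
-minimal sketch):
-
-```python
-g = tf.Graph()
-with g.as_default():
-  tf.constant(1.0)
-print(g.as_graph_def().versions.producer)  # the GraphDef version in use
-```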
-
-Finally, when support for a `GraphDef` version is dropped, we will attempt to
-provide tools for automatically converting graphs to a newer supported
-`GraphDef` version.
-
-## Graph and checkpoint compatibility when extending TensorFlow
-
-This section is relevant only when making incompatible changes to the `GraphDef`
-format, such as when adding ops, removing ops, or changing the functionality
-of existing ops. The previous section should suffice for most users.
-
-### Backward and partial forward compatibility
-
-Our versioning scheme has three requirements:
-
-* **Backward compatibility** to support loading graphs and checkpoints
- created with older versions of TensorFlow.
-* **Forward compatibility** to support scenarios where the producer of a
- graph or checkpoint is upgraded to a newer version of TensorFlow before
- the consumer.
-* Enable evolving TensorFlow in incompatible ways. For example, removing ops,
- adding attributes, and removing attributes.
-
-Note that while the `GraphDef` version mechanism is separate from the TensorFlow
-version, backwards incompatible changes to the `GraphDef` format are still
-restricted by Semantic Versioning. This means functionality can only be removed
-or changed between `MAJOR` versions of TensorFlow (such as `1.7` to `2.0`).
-Additionally, forward compatibility is enforced within patch releases (`1.x.1`
-to `1.x.2`, for example).
-
-To achieve backward and forward compatibility and to know when to enforce changes
-in formats, graphs and checkpoints have metadata that describes when they
-were produced. The sections below detail the TensorFlow implementation and
-guidelines for evolving `GraphDef` versions.
-
-### Independent data version schemes
-
-There are different data versions for graphs and checkpoints. The two data
-formats evolve at different rates from each other and also at different rates
-from TensorFlow. Both versioning systems are defined in
-[`core/public/version.h`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/public/version.h).
-Whenever a new version is added, a note is added to the header detailing what
-changed and the date.
-
-### Data, producers, and consumers
-
-We distinguish between the following kinds of data version information:
-
-* **producers**: binaries that produce data. Producers have a version
- (`producer`) and a minimum consumer version that they are compatible with
- (`min_consumer`).
-* **consumers**: binaries that consume data. Consumers have a version
- (`consumer`) and a minimum producer version that they are compatible with
- (`min_producer`).
-
-Each piece of versioned data has a [`VersionDef
-versions`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/versions.proto)
-field which records the `producer` that made the data, the `min_consumer`
-that it is compatible with, and a list of `bad_consumers` versions that are
-disallowed.
-
-By default, when a producer makes some data, the data inherits the producer's
-`producer` and `min_consumer` versions. `bad_consumers` can be set if specific
-consumer versions are known to contain bugs and must be avoided. A consumer can
-accept a piece of data if the following are all true:
-
-* `consumer` >= data's `min_consumer`
-* data's `producer` >= consumer's `min_producer`
-* `consumer` not in data's `bad_consumers`
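-
-In code form, the acceptance rule above amounts to the following check (a
-sketch only; the names are illustrative, not a TensorFlow API):
-
-```python
-def consumer_accepts(consumer, min_producer, data_versions):
-  """data_versions carries the producer, min_consumer, and bad_consumers fields."""
-  return (consumer >= data_versions.min_consumer and
-          data_versions.producer >= min_producer and
-          consumer not in data_versions.bad_consumers)
-```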
-
-Since both producers and consumers come from the same TensorFlow code base,
-[`core/public/version.h`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/public/version.h)
-contains a main data version, which is treated as either `producer` or
-`consumer` depending on context, as well as both `min_consumer` and
-`min_producer` (needed by producers and consumers, respectively). Specifically,
-
-* For `GraphDef` versions, we have `TF_GRAPH_DEF_VERSION`,
- `TF_GRAPH_DEF_VERSION_MIN_CONSUMER`, and
- `TF_GRAPH_DEF_VERSION_MIN_PRODUCER`.
-* For checkpoint versions, we have `TF_CHECKPOINT_VERSION`,
- `TF_CHECKPOINT_VERSION_MIN_CONSUMER`, and
- `TF_CHECKPOINT_VERSION_MIN_PRODUCER`.
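-
-The `GraphDef` constants are also visible from the Python API (a minimal
-sketch, assuming a 1.x build that exposes them):
-
-```python
-import tensorflow as tf
-print(tf.GRAPH_DEF_VERSION)
-print(tf.GRAPH_DEF_VERSION_MIN_CONSUMER)
-print(tf.GRAPH_DEF_VERSION_MIN_PRODUCER)
-```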
-
-### Add a new attribute with default to an existing op
-
-Following the guidance below gives you forward compatibility only if the set of
-ops has not changed:
-
-1. If forward compatibility is desired, set `strip_default_attrs` to `True`
-   while exporting the model, using either the
-   `add_meta_graph_and_variables` and `add_meta_graph`
-   methods of the `tf.saved_model.builder.SavedModelBuilder` class, or
-   `tf.estimator.Estimator.export_savedmodel`.
-2. This strips off the default-valued attributes at the time of
-   producing/exporting the models. This makes sure that the exported
-   `tf.MetaGraphDef` does not contain the new op-attribute when the default
-   value is used.
-3. Having this control could allow out-of-date consumers (for example, serving
- binaries that lag behind training binaries) to continue loading the models
- and prevent interruptions in model serving.
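-
-A minimal sketch of step 1 using `SavedModelBuilder` (the session `sess`, the
-`export_dir` path, and the tag set are assumed from context):
-
-```python
-builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
-builder.add_meta_graph_and_variables(
-    sess, [tf.saved_model.tag_constants.SERVING],
-    strip_default_attrs=True)  # omit default-valued attrs for forward compat
-builder.save()
-```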
-
-### Evolving GraphDef versions
-
-This section explains how to use this versioning mechanism to make different
-types of changes to the `GraphDef` format.
-
-#### Add an op
-
-Add the new op to both consumers and producers at the same time, and do not
-change any `GraphDef` versions. This type of change is automatically
-backward compatible and does not affect the forward compatibility plan, since
-existing producer scripts will not suddenly use the new functionality.
-
-#### Add an op and switch existing Python wrappers to use it
-
-1. Implement new consumer functionality and increment the `GraphDef` version.
-2. If it is possible to make the wrappers use the new functionality only in
- cases that did not work before, the wrappers can be updated now.
-3. Change Python wrappers to use the new functionality. Do not increment
- `min_consumer`, since models that do not use this op should not break.
-
-#### Remove or restrict an op's functionality
-
-1. Fix all producer scripts (not TensorFlow itself) to not use the banned op or
- functionality.
-2. Increment the `GraphDef` version and implement new consumer functionality
-   that bans the removed op or functionality for GraphDefs at the new version
-   and above. If possible, make TensorFlow stop producing `GraphDefs` with the
-   banned functionality. To do so, add a
-   [`REGISTER_OP(...).Deprecated(deprecated_at_version,
-   message)`](https://github.com/tensorflow/tensorflow/blob/b289bc7a50fc0254970c60aaeba01c33de61a728/tensorflow/core/ops/array_ops.cc#L1009) call.
-3. Wait for a major release for backward compatibility purposes.
-4. Increase `min_producer` to the GraphDef version from (2) and remove the
- functionality entirely.
-
-#### Change an op's functionality
-
-1. Add a new similar op named `SomethingV2` or similar and go through the
-   process of adding it and switching existing Python wrappers to use it.
-   To ensure forward compatibility use the checks suggested in
-   [compat.py](https://www.tensorflow.org/code/tensorflow/python/compat/compat.py)
-   when changing the Python wrappers (a sketch follows this list).
-2. Remove the old op (this can only take place with a major version change,
-   due to backward compatibility).
-3. Increase `min_consumer` to rule out consumers with the old op, add back the
- old op as an alias for `SomethingV2`, and go through the process to switch
- existing Python wrappers to use it.
-4. Go through the process to remove `SomethingV2`.
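-
-A sketch of the forward-compatibility gate mentioned in step 1 (`new_op` and
-`old_op` are hypothetical wrappers, and the cutover date is illustrative):
-
-```python
-from tensorflow.python.compat import compat
-
-def my_wrapper(x):
-  # Emit the new op only once consumers are expected to support it.
-  if compat.forward_compatible(2018, 9, 30):
-    return new_op(x)   # hypothetical new kernel
-  return old_op(x)     # hypothetical old kernel
-```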
-
-#### Ban a single unsafe consumer version
-
-1. Bump the `GraphDef` version and add the bad version to `bad_consumers` for
- all new GraphDefs. If possible, add to `bad_consumers` only for GraphDefs
- which contain a certain op or similar.
-2. If existing consumers have the bad version, push them out as soon as
- possible.
diff --git a/tensorflow/docs_src/install/index.md b/tensorflow/docs_src/install/index.md
deleted file mode 100644
index 55481cc400..0000000000
--- a/tensorflow/docs_src/install/index.md
+++ /dev/null
@@ -1,39 +0,0 @@
-# Install TensorFlow
-
-Note: Run the [TensorFlow tutorials](../tutorials) in a pre-configured
-[Colab notebook environment](https://colab.research.google.com/notebooks/welcome.ipynb){: .external},
-without installation.
-
-TensorFlow is built and tested on the following 64-bit operating systems:
-
- * macOS 10.12.6 (Sierra) or later
- * Ubuntu 16.04 or later
- * Windows 7 or later
- * Raspbian 9.0 or later
-
-While TensorFlow may work on other systems, we only support—and fix issues in—the
-systems listed above.
-
-The following guides explain how to install a version of TensorFlow
-that enables you to write applications in Python:
-
- * @{$install_linux$Install TensorFlow on Ubuntu}
- * @{$install_mac$Install TensorFlow on macOS}
- * @{$install_windows$Install TensorFlow on Windows}
- * @{$install_raspbian$Install TensorFlow on a Raspberry Pi}
- * @{$install_sources$Install TensorFlow from source code}
-
-Many aspects of the Python TensorFlow API changed from version 0.n to 1.0.
-The following guide explains how to migrate older TensorFlow applications
-to Version 1.0:
-
- * @{$migration$Transition to TensorFlow 1.0}
-
-The following guides explain how to install TensorFlow libraries for use in
-other programming languages. These APIs are aimed at deploying TensorFlow
-models in applications and are not as extensive as the Python APIs.
-
- * @{$install_java$Install TensorFlow for Java}
- * @{$install_c$Install TensorFlow for C}
- * @{$install_go$Install TensorFlow for Go}
-
diff --git a/tensorflow/docs_src/install/install_c.md b/tensorflow/docs_src/install/install_c.md
deleted file mode 100644
index 4a63f11fca..0000000000
--- a/tensorflow/docs_src/install/install_c.md
+++ /dev/null
@@ -1,118 +0,0 @@
-# Install TensorFlow for C
-
-TensorFlow provides a C API defined in
-[`c_api.h`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.h),
-which is suitable for
-[building bindings for other languages](https://www.tensorflow.org/extend/language_bindings).
-The API leans towards simplicity and uniformity rather than convenience.
-
-
-## Supported Platforms
-
-This guide explains how to install TensorFlow for C. Although these
-instructions might also work on other variants, we have only tested
-(and we only support) these instructions on machines meeting the
-following requirements:
-
- * Linux, 64-bit, x86
- * macOS, version 10.12.6 (Sierra) or higher
-
-
-## Installation
-
-Take the following steps to install the TensorFlow for C library and
-enable TensorFlow for C:
-
- 1. Decide whether you will run TensorFlow for C on CPU(s) only or
- with the help of GPU(s). To help you decide, read the section
- entitled "Determine which TensorFlow to install" in one of the
- following guides:
-
- * @{$install_linux#determine_which_tensorflow_to_install$Installing TensorFlow on Linux}
- * @{$install_mac#determine_which_tensorflow_to_install$Installing TensorFlow on macOS}
-
- 2. Download and extract the TensorFlow C library into `/usr/local/lib` by
- invoking the following shell commands:
-
- TF_TYPE="cpu" # Change to "gpu" for GPU support
- OS="linux" # Change to "darwin" for macOS
- TARGET_DIRECTORY="/usr/local"
- curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.10.0.tar.gz" |
- sudo tar -C $TARGET_DIRECTORY -xz
-
- The `tar` command extracts the TensorFlow C library into the `lib`
- subdirectory of `TARGET_DIRECTORY`. For example, specifying `/usr/local`
- as `TARGET_DIRECTORY` causes `tar` to extract the TensorFlow C library
- into `/usr/local/lib`.
-
- If you'd prefer to extract the library into a different directory,
- adjust `TARGET_DIRECTORY` accordingly.
-
- 3. In Step 2, if you specified a system directory (for example, `/usr/local`)
- as the `TARGET_DIRECTORY`, then run `ldconfig` to configure the linker.
- For example:
-
- <pre><b>sudo ldconfig</b></pre>
-
- If you assigned a `TARGET_DIRECTORY` other than a system
- directory (for example, `~/mydir`), then you must append the extraction
- directory (for example, `~/mydir/lib`) to two environment variables.
- For example:
-
- <pre> <b>export LIBRARY_PATH=$LIBRARY_PATH:~/mydir/lib</b> # For both Linux and macOS
- <b>export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/mydir/lib</b> # For Linux only
- <b>export DYLD_LIBRARY_PATH=$DYLD_LIBRARY_PATH:~/mydir/lib</b> # For macOS only</pre>
-
-
-
-## Validate your installation
-
-After installing TensorFlow for C, enter the following code into a file named
-`hello_tf.c`:
-
-```c
-#include <stdio.h>
-#include <tensorflow/c/c_api.h>
-
-int main() {
-  printf("Hello from TensorFlow C library version %s\n", TF_Version());
-  return 0;
-}
-```
-
-### Build and Run
-
-Build `hello_tf.c` by invoking the following command:
-
-
-<pre><b>gcc hello_tf.c -ltensorflow</b></pre>
-
-
-Running the resulting executable should output the following message:
-
-
-<pre><b>./a.out</b>
-Hello from TensorFlow C library version <i>number</i></pre>
-
-
-### Troubleshooting
-
-If building the program fails, the most likely culprit is that `gcc` cannot
-find the TensorFlow C library. One way to fix this problem is to specify
-the `-I` and `-L` options to `gcc`. For example, if the `TARGET_DIRECTORY`
-was `/usr/local`, you would invoke `gcc` as follows:
-
-<pre><b>gcc -I/usr/local/include -L/usr/local/lib hello_tf.c -ltensorflow</b></pre>
-
-If executing `a.out` fails, ask yourself the following questions:
-
- * Did the program build without error?
- * Have you assigned the correct directory to the environment variables
- noted in Step 3 of [Installation](#installation)?
- * Did you export those environment variables?
-
-If you are still seeing build or execution error messages, search (or post to)
-[StackOverflow](https://stackoverflow.com/questions/tagged/tensorflow) for
-possible solutions.
-
diff --git a/tensorflow/docs_src/install/install_go.md b/tensorflow/docs_src/install/install_go.md
deleted file mode 100644
index f0f8436777..0000000000
--- a/tensorflow/docs_src/install/install_go.md
+++ /dev/null
@@ -1,142 +0,0 @@
-# Install TensorFlow for Go
-
-TensorFlow provides APIs for use in Go programs. These APIs are particularly
-well-suited to loading models created in Python and executing them within
-a Go application. This guide explains how to install and set up the
-[TensorFlow Go package](https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go).
-
-Warning: The TensorFlow Go API is *not* covered by the TensorFlow
-[API stability guarantees](../guide/version_compat.md).
-
-
-## Supported Platforms
-
-This guide explains how to install TensorFlow for Go. Although these
-instructions might also work on other variants, we have only tested
-(and we only support) these instructions on machines meeting the
-following requirements:
-
- * Linux, 64-bit, x86
- * macOS, 10.12.6 (Sierra) or higher
-
-
-## Installation
-
-TensorFlow for Go depends on the TensorFlow C library. Take the following
-steps to install this library and enable TensorFlow for Go:
-
- 1. Decide whether you will run TensorFlow for Go on CPU(s) only or with
- the help of GPU(s). To help you decide, read the section entitled
- "Determine which TensorFlow to install" in one of the following guides:
-
- * @{$install_linux#determine_which_tensorflow_to_install$Installing TensorFlow on Linux}
- * @{$install_mac#determine_which_tensorflow_to_install$Installing TensorFlow on macOS}
-
- 2. Download and extract the TensorFlow C library into `/usr/local/lib` by
- invoking the following shell commands:
-
- TF_TYPE="cpu" # Change to "gpu" for GPU support
- TARGET_DIRECTORY='/usr/local'
- curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.10.0.tar.gz" |
- sudo tar -C $TARGET_DIRECTORY -xz
-
- The `tar` command extracts the TensorFlow C library into the `lib`
- subdirectory of `TARGET_DIRECTORY`. For example, specifying `/usr/local`
- as `TARGET_DIRECTORY` causes `tar` to extract the TensorFlow C library
- into `/usr/local/lib`.
-
- If you'd prefer to extract the library into a different directory,
- adjust `TARGET_DIRECTORY` accordingly.
-
- 3. In Step 2, if you specified a system directory (for example, `/usr/local`)
- as the `TARGET_DIRECTORY`, then run `ldconfig` to configure the linker.
- For example:
-
- <pre><b>sudo ldconfig</b></pre>
-
- If you assigned a `TARGET_DIRECTORY` other than a system
- directory (for example, `~/mydir`), then you must append the extraction
- directory (for example, `~/mydir/lib`) to two environment variables
- as follows:
-
- <pre> <b>export LIBRARY_PATH=$LIBRARY_PATH:~/mydir/lib</b> # For both Linux and macOS
- <b>export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/mydir/lib</b> # For Linux only
- <b>export DYLD_LIBRARY_PATH=$DYLD_LIBRARY_PATH:~/mydir/lib</b> # For macOS only</pre>
-
- 4. Now that the TensorFlow C library is installed, invoke `go get` as follows
- to download the appropriate packages and their dependencies:
-
- <pre><b>go get github.com/tensorflow/tensorflow/tensorflow/go</b></pre>
-
- 5. Invoke `go test` as follows to validate the TensorFlow for Go
- installation:
-
- <pre><b>go test github.com/tensorflow/tensorflow/tensorflow/go</b></pre>
-
-If `go get` or `go test` generate error messages, search (or post to)
-[StackOverflow](https://stackoverflow.com/questions/tagged/tensorflow)
-for possible solutions.
-
-
-## Hello World
-
-After installing TensorFlow for Go, enter the following code into a
-file named `hello_tf.go`:
-
-```go
-package main
-
-import (
-	"fmt"
-
-	tf "github.com/tensorflow/tensorflow/tensorflow/go"
-	"github.com/tensorflow/tensorflow/tensorflow/go/op"
-)
-
-func main() {
-	// Construct a graph with an operation that produces a string constant.
-	s := op.NewScope()
-	c := op.Const(s, "Hello from TensorFlow version " + tf.Version())
-	graph, err := s.Finalize()
-	if err != nil {
-		panic(err)
-	}
-
-	// Execute the graph in a session.
-	sess, err := tf.NewSession(graph, nil)
-	if err != nil {
-		panic(err)
-	}
-	output, err := sess.Run(nil, []tf.Output{c}, nil)
-	if err != nil {
-		panic(err)
-	}
-	fmt.Println(output[0].Value())
-}
-```
-
-For a more advanced example of TensorFlow in Go, look at the
-[example in the API documentation](https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go#ex-package),
-which uses a pre-trained TensorFlow model to label contents of an image.
-
-
-### Running
-
-Run `hello_tf.go` by invoking the following command:
-
-<pre><b>go run hello_tf.go</b>
-Hello from TensorFlow version <i>number</i></pre>
-
-The program might also generate multiple warning messages of the
-following form, which you can ignore:
-
-<pre>W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library
-wasn't compiled to use *Type* instructions, but these are available on your
-machine and could speed up CPU computations.</pre>
-
-
-## Building from source code
-
-TensorFlow is open-source. You may build TensorFlow for Go from the
-TensorFlow source code by following the instructions in a
-[separate document](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/go/README.md).
diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md
deleted file mode 100644
index c131a2ea76..0000000000
--- a/tensorflow/docs_src/install/install_java.md
+++ /dev/null
@@ -1,268 +0,0 @@
-# Install TensorFlow for Java
-
-TensorFlow provides APIs for use in Java programs. These APIs are particularly
-well-suited to loading models created in Python and executing them within a
-Java application. This guide explains how to install
-[TensorFlow for Java](https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/package-summary)
-and use it in a Java application.
-
-Warning: The TensorFlow Java API is *not* covered by the TensorFlow
-[API stability guarantees](../guide/version_compat.md).
-
-
-## Supported Platforms
-
-This guide explains how to install TensorFlow for Java. Although these
-instructions might also work on other variants, we have only tested
-(and we only support) these instructions on machines meeting the
-following requirements:
-
- * Ubuntu 16.04 or higher; 64-bit, x86
- * macOS 10.12.6 (Sierra) or higher
- * Windows 7 or higher; 64-bit, x86
-
-The installation instructions for Android are in a separate
-[Android TensorFlow Support page](https://www.tensorflow.org/code/tensorflow/contrib/android).
-After installation, please see this
-[complete example](https://www.tensorflow.org/code/tensorflow/examples/android)
-of TensorFlow on Android.
-
-## Using TensorFlow with a Maven project
-
-If your project uses [Apache Maven](https://maven.apache.org), then add the
-following to the project's `pom.xml` to use the TensorFlow Java APIs:
-
-```xml
-<dependency>
- <groupId>org.tensorflow</groupId>
- <artifactId>tensorflow</artifactId>
- <version>1.10.0</version>
-</dependency>
-```
-
-That's all.
-
-### Example
-
-As an example, these steps will create a Maven project that uses TensorFlow:
-
- 1. Create the project's `pom.xml`:
-
-
-        <project>
-          <modelVersion>4.0.0</modelVersion>
-          <groupId>org.myorg</groupId>
-          <artifactId>hellotf</artifactId>
-          <version>1.0-SNAPSHOT</version>
-          <properties>
-            <exec.mainClass>HelloTF</exec.mainClass>
-            <!-- The sample code requires at least JDK 1.7. -->
-            <!-- The maven compiler plugin defaults to a lower version -->
-            <maven.compiler.source>1.7</maven.compiler.source>
-            <maven.compiler.target>1.7</maven.compiler.target>
-          </properties>
-          <dependencies>
-            <dependency>
-              <groupId>org.tensorflow</groupId>
-              <artifactId>tensorflow</artifactId>
-              <version>1.10.0</version>
-            </dependency>
-          </dependencies>
-        </project>
-
-
- 2. Create the source file (`src/main/java/HelloTF.java`):
-
-
-        import org.tensorflow.Graph;
-        import org.tensorflow.Session;
-        import org.tensorflow.Tensor;
-        import org.tensorflow.TensorFlow;
-
-        public class HelloTF {
-          public static void main(String[] args) throws Exception {
-            try (Graph g = new Graph()) {
-              final String value = "Hello from " + TensorFlow.version();
-
-              // Construct the computation graph with a single operation, a constant
-              // named "MyConst" with a value "value".
-              try (Tensor t = Tensor.create(value.getBytes("UTF-8"))) {
-                // The Java API doesn't yet include convenience functions for adding operations.
-                g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();
-              }
-
-              // Execute the "MyConst" operation in a Session.
-              try (Session s = new Session(g);
-                  // Generally, there may be multiple output tensors;
-                  // all of them must be closed to prevent resource leaks.
-                  Tensor output = s.runner().fetch("MyConst").run().get(0)) {
-                System.out.println(new String(output.bytesValue(), "UTF-8"));
-              }
-            }
-          }
-        }
-
-
- 3. Compile and execute:
-
- <pre> # Use -q to hide logging from the mvn tool
- <b>mvn -q compile exec:java</b></pre>
-
-
-The preceding command should output <tt>Hello from <i>version</i></tt>. If it
-does, you've successfully set up TensorFlow for Java and are ready to use it in
-Maven projects. If not, check
-[Stack Overflow](http://stackoverflow.com/questions/tagged/tensorflow)
-for possible solutions. You can skip reading the rest of this document.
-
-### GPU support
-
-If your Linux system has an NVIDIA® GPU and your TensorFlow Java program
-requires GPU acceleration, then add the following to the project's `pom.xml`
-instead:
-
-```xml
-<dependency>
- <groupId>org.tensorflow</groupId>
- <artifactId>libtensorflow</artifactId>
- <version>1.10.0</version>
-</dependency>
-<dependency>
- <groupId>org.tensorflow</groupId>
- <artifactId>libtensorflow_jni_gpu</artifactId>
- <version>1.10.0</version>
-</dependency>
-```
-
-GPU acceleration is available via Maven only for Linux and only if your system
-meets the
-@{$install_linux#determine_which_tensorflow_to_install$requirements for GPU}.
-
-## Using TensorFlow with JDK
-
-This section describes how to use TensorFlow using the `java` and `javac`
-commands from a JDK installation. If your project uses Apache Maven, then
-refer to the simpler instructions above instead.
-
-### Install on Linux or macOS
-
-Take the following steps to install TensorFlow for Java on Linux or macOS:
-
- 1. Download
- [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.10.0.jar),
- which is the TensorFlow Java Archive (JAR).
-
- 2. Decide whether you will run TensorFlow for Java on CPU(s) only or with
- the help of GPU(s). To help you decide, read the section entitled
- "Determine which TensorFlow to install" in one of the following guides:
-
- * @{$install_linux#determine_which_tensorflow_to_install$Installing TensorFlow on Linux}
- * @{$install_mac#determine_which_tensorflow_to_install$Installing TensorFlow on macOS}
-
- 3. Download and extract the appropriate Java Native Interface (JNI)
- file for your operating system and processor support by running the
- following shell commands:
-
-
- TF_TYPE="cpu" # Default processor is CPU. If you want GPU, set to "gpu"
- OS=$(uname -s | tr '[:upper:]' '[:lower:]')
- mkdir -p ./jni
- curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.10.0.tar.gz" |
- tar -xz -C ./jni
-
-### Install on Windows
-
-Take the following steps to install TensorFlow for Java on Windows:
-
- 1. Download
- [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.10.0.jar),
- which is the TensorFlow Java Archive (JAR).
- 2. Download the following Java Native Interface (JNI) file appropriate for
- [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.10.0.zip).
- 3. Extract this .zip file.
-
-__Note__: The native library (`tensorflow_jni.dll`) requires `msvcp140.dll` at runtime, which is included in the [Visual C++ 2015 Redistributable](https://www.microsoft.com/en-us/download/details.aspx?id=48145) package.
-
-### Validate the installation
-
-After installing TensorFlow for Java, validate your installation by entering
-the following code into a file named `HelloTF.java`:
-
-```java
-import org.tensorflow.Graph;
-import org.tensorflow.Session;
-import org.tensorflow.Tensor;
-import org.tensorflow.TensorFlow;
-
-public class HelloTF {
-  public static void main(String[] args) throws Exception {
-    try (Graph g = new Graph()) {
-      final String value = "Hello from " + TensorFlow.version();
-
-      // Construct the computation graph with a single operation, a constant
-      // named "MyConst" with a value "value".
-      try (Tensor t = Tensor.create(value.getBytes("UTF-8"))) {
-        // The Java API doesn't yet include convenience functions for adding operations.
-        g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();
-      }
-
-      // Execute the "MyConst" operation in a Session.
-      try (Session s = new Session(g);
-          // Generally, there may be multiple output tensors;
-          // all of them must be closed to prevent resource leaks.
-          Tensor output = s.runner().fetch("MyConst").run().get(0)) {
-        System.out.println(new String(output.bytesValue(), "UTF-8"));
-      }
-    }
-  }
-}
-```
-
-And use the instructions below to compile and run `HelloTF.java`.
-
-
-### Compiling
-
-When compiling a Java program that uses TensorFlow, the downloaded `.jar`
-must be part of your `classpath`. For example, you can include the
-downloaded `.jar` in your `classpath` by using the `-cp` compilation flag
-as follows:
-
-<pre><b>javac -cp libtensorflow-1.10.0.jar HelloTF.java</b></pre>
-
-
-### Running
-
-To execute a Java program that depends on TensorFlow, ensure that the following
-two files are available to the JVM:
-
- * the downloaded `.jar` file
- * the extracted JNI library
-
-For example, the following command line executes the `HelloTF` program on Linux
-and macOS:
-
-<pre><b>java -cp libtensorflow-1.10.0.jar:. -Djava.library.path=./jni HelloTF</b></pre>
-
-And the following command line executes the `HelloTF` program on Windows:
-
-<pre><b>java -cp libtensorflow-1.10.0.jar;. -Djava.library.path=jni HelloTF</b></pre>
-
-If the program prints <tt>Hello from <i>version</i></tt>, you've successfully
-installed TensorFlow for Java and are ready to use the API. If the program
-outputs something else, check
-[Stack Overflow](http://stackoverflow.com/questions/tagged/tensorflow) for
-possible solutions.
-
-
-### Advanced Example
-
-For a more sophisticated example, see
-[LabelImage.java](https://www.tensorflow.org/code/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java),
-which recognizes objects in an image.
-
-
-## Building from source code
-
-TensorFlow is open-source. You may build TensorFlow for Java from the
-TensorFlow source code by following the instructions in a
-[separate document](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/README.md).
diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md
deleted file mode 100644
index 0febdee99f..0000000000
--- a/tensorflow/docs_src/install/install_linux.md
+++ /dev/null
@@ -1,714 +0,0 @@
-# Install TensorFlow on Ubuntu
-
-This guide explains how to install TensorFlow on Ubuntu Linux. While these
-instructions may work on other Linux variants, they are tested and supported
-with the following system requirements:
-
-* 64-bit desktops or laptops
-* Ubuntu 16.04 or higher
-
-## Choose which TensorFlow to install
-
-The following TensorFlow variants are available for installation:
-
-* __TensorFlow with CPU support only__. If your system does not have an
-  NVIDIA®&nbsp;GPU, you must install this version. This version of TensorFlow
- is usually easier to install, so even if you have an NVIDIA GPU, we
- recommend installing this version first.
-* __TensorFlow with GPU support__. TensorFlow programs usually run much faster
-  on a GPU than on a CPU. If you run performance-critical applications and
- your system has an NVIDIA®&nbsp;GPU that meets the prerequisites, you should
- install this version. See [TensorFlow GPU support](#NVIDIARequirements) for
- details.
-
-## How to install TensorFlow
-
-There are a few options to install TensorFlow on your machine:
-
-* [Use pip in a virtual environment](#InstallingVirtualenv) *(recommended)*
-* [Use pip in your system environment](#InstallingNativePip)
-* [Configure a Docker container](#InstallingDocker)
-* [Use pip in Anaconda](#InstallingAnaconda)
-* [Install TensorFlow from source](/install/install_sources)
-
-<a name="InstallingVirtualenv"></a>
-
-### Use `pip` in a virtual environment
-
-Key Point: Using a virtual environment is the recommended install method.
-
-The [Virtualenv](https://virtualenv.pypa.io/en/stable/) tool creates virtual
-Python environments that are isolated from other Python development on the same
-machine. In this scenario, you install TensorFlow and its dependencies within a
-virtual environment that is available when *activated*. Virtualenv provides a
-reliable way to install and run TensorFlow while avoiding conflicts with the
-rest of the system.
-
-##### 1. Install Python, `pip`, and `virtualenv`.
-
-On Ubuntu, Python is automatically installed and `pip` is *usually* installed.
-Confirm the `python` and `pip` versions:
-
-<pre class="prettyprint lang-bsh">
- <code class="devsite-terminal">python -V # or: python3 -V</code>
- <code class="devsite-terminal">pip -V # or: pip3 -V</code>
-</pre>
-
-To install these packages on Ubuntu:
-
-<pre class="prettyprint lang-bsh">
- <code class="devsite-terminal">sudo apt-get install python-pip python-dev python-virtualenv # for Python 2.7</code>
- <code class="devsite-terminal">sudo apt-get install python3-pip python3-dev python-virtualenv # for Python 3.n</code>
-</pre>
-
-We *recommend* using `pip` version 8.1 or higher. If using a release before
-version 8.1, upgrade `pip`:
-
-<pre class="prettyprint lang-bsh">
- <code class="devsite-terminal">pip install --upgrade pip</code>
-</pre>
-
-If not using Ubuntu and [setuptools](https://pypi.org/project/setuptools/) is
-installed, use `easy_install` to install `pip`:
-
-<pre class="prettyprint lang-bsh">
- <code class="devsite-terminal">easy_install -U pip</code>
-</pre>
-
-##### 2. Create a directory for the virtual environment and choose a Python interpreter.
-
-<pre class="prettyprint lang-bsh">
- <code class="devsite-terminal">mkdir ~/tensorflow # somewhere to work out of</code>
- <code class="devsite-terminal">cd ~/tensorflow</code>
- <code># Choose one of the following Python environments for the ./venv directory:</code>
- <code class="devsite-terminal">virtualenv --system-site-packages <var>venv</var> # Use python default (Python 2.7)</code>
- <code class="devsite-terminal">virtualenv --system-site-packages -p python3 <var>venv</var> # Use Python 3.n</code>
-</pre>
-
-##### 3. Activate the Virtualenv environment.
-
-Use one of these shell-specific commands to activate the virtual environment:
-
-<pre class="prettyprint lang-bsh">
- <code class="devsite-terminal">source ~/tensorflow/<var>venv</var>/bin/activate # bash, sh, ksh, or zsh</code>
- <code class="devsite-terminal">source ~/tensorflow/<var>venv</var>/bin/activate.csh # csh or tcsh</code>
- <code class="devsite-terminal">. ~/tensorflow/<var>venv</var>/bin/activate.fish # fish</code>
-</pre>
-
-When the Virtualenv is activated, the shell prompt displays as `(venv) $`.
-
-##### 4. Upgrade `pip` in the virtual environment.
-
-Within the active virtual environment, upgrade `pip`:
-
-<pre class="prettyprint lang-bsh">
-(venv)$ pip install --upgrade pip
-</pre>
-
-You can install other Python packages within the virtual environment without
-affecting packages outside the `virtualenv`.
-
-##### 5. Install TensorFlow in the virtual environment.
-
-Choose one of the available TensorFlow packages for installation:
-
-* `tensorflow` —Current release for CPU
-* `tensorflow-gpu` —Current release with GPU support
-* `tf-nightly` —Nightly build for CPU
-* `tf-nightly-gpu` —Nightly build with GPU support
-
-Within an active Virtualenv environment, use `pip` to install the package:
-
-<pre class="prettyprint lang-bsh">
- <code class="devsite-terminal">pip install --upgrade tensorflow</code>
-</pre>
-
-Use `pip list` to show the packages installed in the virtual environment.
-[Validate the install](#ValidateYourInstallation) and test the version:
-
-<pre class="prettyprint lang-bsh">
-(venv)$ python -c "import tensorflow as tf; print(tf.__version__)"
-</pre>
-
-Success: TensorFlow is now installed.
-
-Use the `deactivate` command to stop the Python virtual environment.
-
-#### Problems
-
-If the above steps failed, try installing the TensorFlow binary using the remote
-URL of the `pip` package:
-
-<pre class="prettyprint lang-bsh">
-(venv)$ pip install --upgrade <var>remote-pkg-URL</var> # Python 2.7
-(venv)$ pip3 install --upgrade <var>remote-pkg-URL</var> # Python 3.n
-</pre>
-
-The <var>remote-pkg-URL</var> depends on the operating system, Python version,
-and GPU support. See [here](#the_url_of_the_tensorflow_python_package) for the
-URL naming scheme and location.
-
-See [Common Installation Problems](#common_installation_problems) if you
-encounter problems.
-
-#### Uninstall TensorFlow
-
-To uninstall TensorFlow, remove the Virtualenv directory you created in step 2:
-
-<pre class="prettyprint lang-bsh">
- <code class="devsite-terminal">deactivate # stop the virtualenv</code>
- <code class="devsite-terminal">rm -r ~/tensorflow/<var>venv</var></code>
-</pre>
-
-<a name="InstallingNativePip"></a>
-
-### Use `pip` in your system environment
-
-Use `pip` to install the TensorFlow package directly on your system without
-using a container or virtual environment for isolation. This method is
-recommended for system administrators who want a TensorFlow installation that
-is available to everyone on a multi-user system.
-
-Since a system install is not isolated, it could interfere with other
-Python-based installations. But if you understand `pip` and your Python
-environment, a system `pip` install is straightforward.
-
-See the
-[REQUIRED_PACKAGES section of setup.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/pip_package/setup.py)
-for a list of packages that TensorFlow installs.
-
-##### 1. Install Python, `pip`, and `virtualenv`.
-
-On Ubuntu, Python is automatically installed and `pip` is *usually* installed.
-Confirm the `python` and `pip` versions:
-
-<pre class="prettyprint lang-bsh">
- <code class="devsite-terminal">python -V # or: python3 -V</code>
- <code class="devsite-terminal">pip -V # or: pip3 -V</code>
-</pre>
-
-To install these packages on Ubuntu:
-
-<pre class="prettyprint lang-bsh">
- <code class="devsite-terminal">sudo apt-get install python-pip python-dev # for Python 2.7</code>
- <code class="devsite-terminal">sudo apt-get install python3-pip python3-dev # for Python 3.n</code>
-</pre>
-
-We *recommend* using `pip` version 8.1 or higher. If using a release before
-version 8.1, upgrade `pip`:
-
-<pre class="prettyprint lang-bsh">
- <code class="devsite-terminal">pip install --upgrade pip</code>
-</pre>
-
-If not using Ubuntu and [setuptools](https://pypi.org/project/setuptools/) is
-installed, use `easy_install` to install `pip`:
-
-<pre class="prettyprint lang-bsh">
- <code class="devsite-terminal">easy_install -U pip</code>
-</pre>
-
-##### 2. Install TensorFlow on system.
-
-Choose one of the available TensorFlow packages for installation:
-
-* `tensorflow` —Current release for CPU
-* `tensorflow-gpu` —Current release with GPU support
-* `tf-nightly` —Nightly build for CPU
-* `tf-nightly-gpu` —Nightly build with GPU support
-
-And use `pip` to install the package for Python 2 or 3:
-
-<pre class="prettyprint lang-bsh">
- <code class="devsite-terminal">pip install --upgrade --user tensorflow # Python 2.7</code>
- <code class="devsite-terminal">pip3 install --upgrade --user tensorflow # Python 3.n</code>
-</pre>
-
-Use `pip list` to show the packages installed on the system.
-[Validate the install](#ValidateYourInstallation) and test the version:
-
-<pre class="prettyprint lang-bsh">
- <code class="devsite-terminal">python -c "import tensorflow as tf; print(tf.__version__)"</code>
-</pre>
-
-Success: TensorFlow is now installed.
-
-#### Problems
-
-If the above steps failed, try installing the TensorFlow binary using the remote
-URL of the `pip` package:
-
-<pre class="prettyprint lang-bsh">
- <code class="devsite-terminal">pip install --user --upgrade <var>remote-pkg-URL</var> # Python 2.7</code>
- <code class="devsite-terminal">pip3 install --user --upgrade <var>remote-pkg-URL</var> # Python 3.n</code>
-</pre>
-
-The <var>remote-pkg-URL</var> depends on the operating system, Python version,
-and GPU support. See [here](#the_url_of_the_tensorflow_python_package) for the
-URL naming scheme and location.
-
-See [Common Installation Problems](#common_installation_problems) if you
-encounter problems.
-
-#### Uninstall TensorFlow
-
-To uninstall TensorFlow on your system, use one of following commands:
-
-<pre class="prettyprint lang-bsh">
- <code class="devsite-terminal">pip uninstall tensorflow # for Python 2.7</code>
- <code class="devsite-terminal">pip3 uninstall tensorflow # for Python 3.n</code>
-</pre>
-
-<a name="InstallingDocker"></a>
-
-### Configure a Docker container
-
-Docker completely isolates the TensorFlow installation from pre-existing
-packages on your machine. The Docker container contains TensorFlow and all its
-dependencies. Note that the Docker image can be quite large (hundreds of MBs).
-You might choose the Docker installation if you are incorporating TensorFlow
-into a larger application architecture that already uses Docker.
-
-Take the following steps to install TensorFlow through Docker:
-
-1. Install Docker on your machine as described in the
- [Docker documentation](http://docs.docker.com/engine/installation/).
-2. Optionally, create a Linux group called <code>docker</code> to allow
- launching containers without sudo as described in the
- [Docker documentation](https://docs.docker.com/engine/installation/linux/linux-postinstall/).
- (If you don't do this step, you'll have to use sudo each time you invoke
- Docker.)
-3. To install a version of TensorFlow that supports GPUs, you must first
-   install [nvidia-docker](https://github.com/NVIDIA/nvidia-docker), which is
-   hosted on GitHub.
-4. Launch a Docker container that contains one of the
- [TensorFlow binary images](https://hub.docker.com/r/tensorflow/tensorflow/tags/).
-
-The remainder of this section explains how to launch a Docker container.
-
-#### CPU-only
-
-To launch a Docker container with CPU-only support (that is, without GPU
-support), enter a command of the following format:
-
-<pre>
-$ docker run -it <i>-p hostPort:containerPort TensorFlowCPUImage</i>
-</pre>
-
-where:
-
-* <tt><i>-p hostPort:containerPort</i></tt> is optional. If you plan to run
- TensorFlow programs from the shell, omit this option. If you plan to run
- TensorFlow programs as Jupyter notebooks, set both <tt><i>hostPort</i></tt>
- and <tt><i>containerPort</i></tt> to <tt>8888</tt>. If you'd like to run
- TensorBoard inside the container, add a second `-p` flag, setting both
- <i>hostPort</i> and <i>containerPort</i> to 6006.
-* <tt><i>TensorFlowCPUImage</i></tt> is required. It identifies the Docker
- container. Specify one of the following values:
-
-    * <tt>tensorflow/tensorflow</tt>, which is the TensorFlow CPU binary
-      image.
-    * <tt>tensorflow/tensorflow:latest-devel</tt>, which is the latest
-      TensorFlow CPU binary image plus source code.
-    * <tt>tensorflow/tensorflow:<i>version</i></tt>, which is the specified
-      version (for example, 1.1.0rc1) of the TensorFlow CPU binary image.
-    * <tt>tensorflow/tensorflow:<i>version</i>-devel</tt>, which is the
-      specified version (for example, 1.1.0rc1) of the TensorFlow CPU binary
-      image plus source code.
-
- TensorFlow images are available at
- [dockerhub](https://hub.docker.com/r/tensorflow/tensorflow/).
-
-For example, the following command launches the latest TensorFlow CPU binary
-image in a Docker container from which you can run TensorFlow programs in a
-shell:
-
-<pre>
-$ <b>docker run -it tensorflow/tensorflow bash</b>
-</pre>
-
-The following command also launches the latest TensorFlow CPU binary image in a
-Docker container. However, in this Docker container, you can run TensorFlow
-programs in a Jupyter notebook:
-
-<pre>
-$ <b>docker run -it -p 8888:8888 tensorflow/tensorflow</b>
-</pre>
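-
-To expose TensorBoard as well, as noted above, map both ports:
-
-<pre>
-$ <b>docker run -it -p 8888:8888 -p 6006:6006 tensorflow/tensorflow</b>
-</pre>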
-
-Docker will download the TensorFlow binary image the first time you launch it.
-
-#### GPU support
-
-To launch a Docker container with NVIDIA GPU support, enter a command of the
-following format (this
-[does not require any local CUDA installation](https://github.com/nvidia/nvidia-docker/wiki/CUDA#requirements)):
-
-<pre>
-$ <b>nvidia-docker run -it</b> <i>-p hostPort:containerPort TensorFlowGPUImage</i>
-</pre>
-
-where:
-
-* <tt><i>-p hostPort:containerPort</i></tt> is optional. If you plan to run
- TensorFlow programs from the shell, omit this option. If you plan to run
- TensorFlow programs as Jupyter notebooks, set both <tt><i>hostPort</i></tt>
-    and <tt><i>containerPort</i></tt> to <tt>8888</tt>.
-* <i>TensorFlowGPUImage</i> specifies the Docker container. You must specify
- one of the following values:
- * <tt>tensorflow/tensorflow:latest-gpu</tt>, which is the latest
- TensorFlow GPU binary image.
-    * <tt>tensorflow/tensorflow:latest-devel-gpu</tt>, which is the latest
-      TensorFlow GPU binary image plus source code.
- * <tt>tensorflow/tensorflow:<i>version</i>-gpu</tt>, which is the
- specified version (for example, 0.12.1) of the TensorFlow GPU binary
- image.
- * <tt>tensorflow/tensorflow:<i>version</i>-devel-gpu</tt>, which is the
- specified version (for example, 0.12.1) of the TensorFlow GPU binary
- image plus source code.
-
-We recommend installing one of the `latest` versions. For example, the following
-command launches the latest TensorFlow GPU binary image in a Docker container
-from which you can run TensorFlow programs in a shell:
-
-<pre>
-$ <b>nvidia-docker run -it tensorflow/tensorflow:latest-gpu bash</b>
-</pre>
-
-The following command also launches the latest TensorFlow GPU binary image in a
-Docker container. In this Docker container, you can run TensorFlow programs in a
-Jupyter notebook:
-
-<pre>
-$ <b>nvidia-docker run -it -p 8888:8888 tensorflow/tensorflow:latest-gpu</b>
-</pre>
-
-The following command installs an older TensorFlow version (0.12.1):
-
-<pre>
-$ <b>nvidia-docker run -it -p 8888:8888 tensorflow/tensorflow:0.12.1-gpu</b>
-</pre>
-
-Docker will download the TensorFlow binary image the first time you launch it.
-For more details see the
-[TensorFlow docker readme](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/docker).
-
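-To quickly confirm that the container can see the GPU, you can run a one-liner
-check (a sketch using the TF 1.x API):
-
-<pre>
-$ <b>nvidia-docker run -it tensorflow/tensorflow:latest-gpu \
-  python -c "import tensorflow as tf; print(tf.test.is_gpu_available())"</b>
-</pre>
-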
-#### Next Steps
-
-You should now [validate your installation](#ValidateYourInstallation).
-
-<a name="InstallingAnaconda"></a>
-
-### Use `pip` in Anaconda
-
-Anaconda provides the `conda` utility to create a virtual environment. However,
-within Anaconda, we recommend installing TensorFlow using the `pip install`
-command and *not* with the `conda install` command.
-
-Caution: `conda` is a community supported package that is not officially
-maintained by the TensorFlow team. Use this package at your own risk since it is
-not tested on new TensorFlow releases.
-
-Take the following steps to install TensorFlow in an Anaconda environment:
-
-1. Follow the instructions on the
- [Anaconda download site](https://www.continuum.io/downloads) to download and
- install Anaconda.
-
-2. Create a conda environment named <tt>tensorflow</tt> to run a version of
- Python by invoking the following command:
-
- <pre>$ <b>conda create -n tensorflow pip python=2.7 # or python=3.3, etc.</b></pre>
-
-3. Activate the conda environment by issuing the following command:
-
- <pre>$ <b>source activate tensorflow</b>
- (tensorflow)$ # Your prompt should change </pre>
-
-4. Issue a command of the following format to install TensorFlow inside your
- conda environment:
-
- <pre>(tensorflow)$ <b>pip install --ignore-installed --upgrade</b> <i>tfBinaryURL</i></pre>
-
- where <code><em>tfBinaryURL</em></code> is the
- [URL of the TensorFlow Python package](#the_url_of_the_tensorflow_python_package).
- For example, the following command installs the CPU-only version of
- TensorFlow for Python 3.4:
-
- <pre>
- (tensorflow)$ <b>pip install --ignore-installed --upgrade \
- https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.10.0-cp34-cp34m-linux_x86_64.whl</b></pre>
-
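-When you finish using TensorFlow, deactivate the environment (a sketch; newer
-conda releases use <code>conda deactivate</code> instead):
-
-<pre>(tensorflow)$ <b>source deactivate</b></pre>
-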
-<a name="ValidateYourInstallation"></a>
-
-## Validate your installation
-
-To validate your TensorFlow installation, do the following:
-
-1. Ensure that your environment is prepared to run TensorFlow programs.
-2. Run a short TensorFlow program.
-
-### Prepare your environment
-
-If you installed with native pip, Virtualenv, or Anaconda, then do the following:
-
-1. Start a terminal.
-2. If you installed with Virtualenv or Anaconda, activate your environment.
-3. If you installed TensorFlow from source code, navigate to any directory *except*
- one containing TensorFlow source code.
-
-If you installed through Docker, start a Docker container from which you can run
-bash. For example:
-
-<pre>
-$ <b>docker run -it tensorflow/tensorflow bash</b>
-</pre>
-
-### Run a short TensorFlow program
-
-Invoke python from your shell as follows:
-
-<pre>$ <b>python</b></pre>
-
-Enter the following short program inside the python interactive shell:
-
-```python
-# Python
-import tensorflow as tf
-hello = tf.constant('Hello, TensorFlow!')
-sess = tf.Session()
-print(sess.run(hello))
-```
-
-If the system outputs the following, then you are ready to begin writing
-TensorFlow programs:
-
-<pre>Hello, TensorFlow!</pre>
-
-If the system outputs an error message instead of a greeting, see
-[Common installation problems](#common_installation_problems).
-
-To learn more, see the [TensorFlow tutorials](../tutorials/).
-
-<a name="NVIDIARequirements"></a>
-
-## TensorFlow GPU support
-
-Note: Due to the number of libraries required, using [Docker](#InstallingDocker)
-is recommended over installing directly on the host system.
-
-The following NVIDIA® <i>hardware</i> must be installed on your system:
-
-* GPU card with CUDA Compute Capability 3.5 or higher. See
- [NVIDIA documentation](https://developer.nvidia.com/cuda-gpus) for a list of
- supported GPU cards.
-
-The following NVIDIA® <i>software</i> must be installed on your system:
-
-* [GPU drivers](http://nvidia.com/driver). CUDA 9.0 requires 384.x or higher.
-* [CUDA Toolkit 9.0](http://nvidia.com/cuda).
-* [cuDNN SDK](http://developer.nvidia.com/cudnn) (>= 7.0). Version 7.1 is
- recommended.
-*   [CUPTI](http://docs.nvidia.com/cuda/cupti/) ships with the CUDA Toolkit, but
-    you also need to append its path to the `LD_LIBRARY_PATH` environment
-    variable (see the sketch after this list): `export
-    LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/extras/CUPTI/lib64`
-* *OPTIONAL*: [NCCL 2.2](https://developer.nvidia.com/nccl) to use TensorFlow
- with multiple GPUs.
-* *OPTIONAL*:
- [TensorRT](http://docs.nvidia.com/deeplearning/sdk/tensorrt-install-guide/index.html)
- which can improve latency and throughput for inference for some models.
-
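-To persist the CUPTI path setting from the list above across shells, you can
-append it to your shell profile (a sketch assuming bash):
-
-```bash
-# Persist the CUPTI library path for future shells.
-echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/extras/CUPTI/lib64' >> ~/.bashrc
-source ~/.bashrc
-```
-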
-To use a GPU with CUDA Compute Capability 3.0, or different versions of the
-preceding NVIDIA libraries, see
-@{$install_sources$installing TensorFlow from Sources}. If you are using Ubuntu
-16.04 or possibly another Debian-based Linux distro, `apt-get` can be used with
-the NVIDIA repository to simplify installation.
-
-```bash
-# Adds NVIDIA package repository.
-sudo apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/7fa2af80.pub
-wget http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/cuda-repo-ubuntu1604_9.1.85-1_amd64.deb
-wget http://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvidia-machine-learning-repo-ubuntu1604_1.0.0-1_amd64.deb
-sudo dpkg -i cuda-repo-ubuntu1604_9.1.85-1_amd64.deb
-sudo dpkg -i nvidia-machine-learning-repo-ubuntu1604_1.0.0-1_amd64.deb
-sudo apt-get update
-# Includes optional NCCL 2.x.
-sudo apt-get install cuda-9-0 cuda-cublas-9-0 cuda-cufft-9-0 cuda-curand-9-0 \
-    cuda-cusolver-9-0 cuda-cusparse-9-0 libcudnn7=7.1.4.18-1+cuda9.0 \
-    libnccl2=2.2.13-1+cuda9.0 cuda-command-line-tools-9-0
-# Optionally install the TensorRT runtime; this must be done after the CUDA install above.
-sudo apt-get update
-sudo apt-get install libnvinfer4=4.1.2-1+cuda9.0
-```
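-
-After the packages install, you can verify the driver and toolkit versions
-(a quick check; the paths assume the default install locations):
-
-```bash
-# The driver ships nvidia-smi; the toolkit ships nvcc.
-nvidia-smi
-/usr/local/cuda-9.0/bin/nvcc --version
-```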
-
-## Common installation problems
-
-We are relying on Stack Overflow to document TensorFlow installation problems
-and their remedies. The following table contains links to Stack Overflow answers
-for some common installation problems. If you encounter an error message or
-other installation problem not listed in the following table, search for it on
-Stack Overflow. If Stack Overflow doesn't show the error message, ask a new
-question about it on Stack Overflow and specify the `tensorflow` tag.
-
-<table>
-<tr> <th>Link to GitHub or Stack&nbsp;Overflow</th> <th>Error Message</th> </tr>
-
-<tr>
- <td><a href="https://stackoverflow.com/q/36159194">36159194</a></td>
- <td><pre>ImportError: libcudart.so.<i>Version</i>: cannot open shared object file:
- No such file or directory</pre></td>
-</tr>
-
-<tr>
- <td><a href="https://stackoverflow.com/q/41991101">41991101</a></td>
- <td><pre>ImportError: libcudnn.<i>Version</i>: cannot open shared object file:
- No such file or directory</pre></td>
-</tr>
-
-<tr>
- <td><a href="http://stackoverflow.com/q/36371137">36371137</a> and
- <a href="#Protobuf31">here</a></td>
- <td><pre>libprotobuf ERROR google/protobuf/src/google/protobuf/io/coded_stream.cc:207] A
- protocol message was rejected because it was too big (more than 67108864 bytes).
- To increase the limit (or to disable these warnings), see
- CodedInputStream::SetTotalBytesLimit() in google/protobuf/io/coded_stream.h.</pre></td>
-</tr>
-
-<tr>
- <td><a href="https://stackoverflow.com/q/35252888">35252888</a></td>
- <td><pre>Error importing tensorflow. Unless you are using bazel, you should
- not try to import tensorflow from its source directory; please exit the
- tensorflow source tree, and relaunch your python interpreter from
- there.</pre></td>
-</tr>
-
-<tr>
- <td><a href="https://stackoverflow.com/q/33623453">33623453</a></td>
-  <td><pre>IOError: [Errno 2] No such file or directory:
-  '/tmp/pip-o6Tpui-build/setup.py'</pre></td>
-</tr>
-
-<tr>
- <td><a href="http://stackoverflow.com/q/42006320">42006320</a></td>
- <td><pre>ImportError: Traceback (most recent call last):
- File ".../tensorflow/core/framework/graph_pb2.py", line 6, in <module>
- from google.protobuf import descriptor as _descriptor
- ImportError: cannot import name 'descriptor'</pre>
- </td>
-</tr>
-
-<tr>
- <td><a href="https://stackoverflow.com/questions/35190574">35190574</a> </td>
- <td><pre>SSLError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify
- failed</pre></td>
-</tr>
-
-<tr>
- <td><a href="http://stackoverflow.com/q/42009190">42009190</a></td>
- <td><pre>
- Installing collected packages: setuptools, protobuf, wheel, numpy, tensorflow
- Found existing installation: setuptools 1.1.6
- Uninstalling setuptools-1.1.6:
- Exception:
- ...
- [Errno 1] Operation not permitted:
- '/tmp/pip-a1DXRT-uninstall/.../lib/python/_markerlib' </pre></td>
-</tr>
-
-<tr>
- <td><a href="http://stackoverflow.com/questions/36933958">36933958</a></td>
- <td><pre>
- ...
- Installing collected packages: setuptools, protobuf, wheel, numpy, tensorflow
- Found existing installation: setuptools 1.1.6
- Uninstalling setuptools-1.1.6:
- Exception:
- ...
- [Errno 1] Operation not permitted:
- '/tmp/pip-a1DXRT-uninstall/System/Library/Frameworks/Python.framework/
- Versions/2.7/Extras/lib/python/_markerlib'</pre>
- </td>
-</tr>
-
-</table>
-
-<a name="TF_PYTHON_URL"></a>
-
-## The URL of the TensorFlow Python package
-
-A few installation mechanisms require the URL of the TensorFlow Python package.
-The value you specify depends on three factors:
-
-* operating system
-* Python version
-* CPU only vs. GPU support
-
-This section documents the relevant values for Linux installations.
-
-### Python 2.7
-
-CPU only:
-
-<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.10.0-cp27-none-linux_x86_64.whl
-</pre>
-
-GPU support:
-
-<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.10.0-cp27-none-linux_x86_64.whl
-</pre>
-
-Note that GPU support requires the NVIDIA hardware and software described in
-[NVIDIA requirements to run TensorFlow with GPU support](#NVIDIARequirements).
-
-### Python 3.4
-
-CPU only:
-
-<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.10.0-cp34-cp34m-linux_x86_64.whl
-</pre>
-
-GPU support:
-
-<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.10.0-cp34-cp34m-linux_x86_64.whl
-</pre>
-
-Note that GPU support requires the NVIDIA hardware and software described in
-[NVIDIA requirements to run TensorFlow with GPU support](#NVIDIARequirements).
-
-### Python 3.5
-
-CPU only:
-
-<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.10.0-cp35-cp35m-linux_x86_64.whl
-</pre>
-
-GPU support:
-
-<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.10.0-cp35-cp35m-linux_x86_64.whl
-</pre>
-
-Note that GPU support requires the NVIDIA hardware and software described in
-[NVIDIA requirements to run TensorFlow with GPU support](#NVIDIARequirements).
-
-### Python 3.6
-
-CPU only:
-
-<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.10.0-cp36-cp36m-linux_x86_64.whl
-</pre>
-
-GPU support:
-
-<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.10.0-cp36-cp36m-linux_x86_64.whl
-</pre>
-
-Note that GPU support requires the NVIDIA hardware and software described in
-[NVIDIA requirements to run TensorFlow with GPU support](#NVIDIARequirements).
diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md
deleted file mode 100644
index c4d63cc107..0000000000
--- a/tensorflow/docs_src/install/install_mac.md
+++ /dev/null
@@ -1,529 +0,0 @@
-# Install TensorFlow on macOS
-
-This guide explains how to install TensorFlow on macOS. Although these
-instructions might also work on other macOS versions, we have only
-tested (and we only support) these instructions on machines meeting the
-following requirements:
-
- * macOS 10.12.6 (Sierra) or higher
-
-Note: There are known, accuracy-affecting numerical issues before macOS 10.12.6
-(Sierra) that are described in
-[GitHub#15933](https://github.com/tensorflow/tensorflow/issues/15933#issuecomment-366331383).
-
-Note: As of version 1.2, TensorFlow no longer provides GPU support on macOS.
-
-## Determine how to install TensorFlow
-
-You must pick the mechanism by which you install TensorFlow. The supported choices are as follows:
-
- * Virtualenv
- * "native" pip
- * Docker
- * installing from sources, which is documented in
- [a separate guide](https://www.tensorflow.org/install/install_sources).
-
-**We recommend the Virtualenv installation.**
-[Virtualenv](https://virtualenv.pypa.io/en/stable)
-is a virtual Python environment isolated from other Python development,
-incapable of interfering with or being affected by other Python programs
-on the same machine. During the Virtualenv installation process,
-you will install not only TensorFlow but also all the packages that
-TensorFlow requires. (This is actually pretty easy.)
-To start working with TensorFlow, you simply need to "activate" the
-virtual environment. All in all, Virtualenv provides a safe and
-reliable mechanism for installing and running TensorFlow.
-
-Native pip installs TensorFlow directly on your system without going through
-any container or virtual environment system. Since a native pip installation
-is not walled-off, the pip installation might interfere with or be influenced
-by other Python-based installations on your system. Furthermore, you might need
-to disable System Integrity Protection (SIP) in order to install through native
-pip. However, if you understand SIP, pip, and your Python environment, a
-native pip installation is relatively easy to perform.
-
-[Docker](http://docker.com) completely isolates the TensorFlow installation
-from pre-existing packages on your machine. The Docker container contains
-TensorFlow and all its dependencies. Note that the Docker image can be quite
-large (hundreds of MBs). You might choose the Docker installation if you are
-incorporating TensorFlow into a larger application architecture that
-already uses Docker.
-
-In Anaconda, you may use conda to create a virtual environment.
-However, within Anaconda, we recommend installing TensorFlow with the
-`pip install` command, not with the `conda install` command.
-
-**NOTE:** The conda package is community supported, not officially supported.
-That is, the TensorFlow team neither tests nor maintains the conda package.
-Use that package at your own risk.
-
-## Installing with Virtualenv
-
-Take the following steps to install TensorFlow with Virtualenv:
-
- 1. Start a terminal (a shell). You'll perform all subsequent steps
- in this shell.
-
- 2. Install pip and Virtualenv by issuing the following commands:
-
- <pre> $ <b>sudo easy_install pip</b>
- $ <b>pip install --upgrade virtualenv</b> </pre>
-
- 3. Create a Virtualenv environment by issuing a command of one
- of the following formats:
-
- <pre> $ <b>virtualenv --system-site-packages</b> <i>targetDirectory</i> # for Python 2.7
- $ <b>virtualenv --system-site-packages -p python3</b> <i>targetDirectory</i> # for Python 3.n
- </pre>
-
- where <i>targetDirectory</i> identifies the top of the Virtualenv tree.
- Our instructions assume that <i>targetDirectory</i>
- is `~/tensorflow`, but you may choose any directory.
-
- 4. Activate the Virtualenv environment by issuing one of the
- following commands:
-
- <pre>$ <b>cd <i>targetDirectory</i></b>
- $ <b>source ./bin/activate</b> # If using bash, sh, ksh, or zsh
- $ <b>source ./bin/activate.csh</b> # If using csh or tcsh </pre>
-
- The preceding `source` command should change your prompt to the following:
-
- <pre> (<i>targetDirectory</i>)$ </pre>
-
- 5. Ensure pip ≥8.1 is installed:
-
- <pre> (<i>targetDirectory</i>)$ <b>easy_install -U pip</b></pre>
-
- 6. Issue one of the following commands to install TensorFlow and all the
-    packages that TensorFlow requires into the active Virtualenv environment:
-
-    <pre> (<i>targetDirectory</i>)$ <b>pip install --upgrade tensorflow</b>      # for Python 2.7
-    (<i>targetDirectory</i>)$ <b>pip3 install --upgrade tensorflow</b>     # for Python 3.n</pre>
-
- 7. Optional. If Step 6 failed (typically because you invoked a pip version
- lower than 8.1), install TensorFlow in the active
- Virtualenv environment by issuing a command of the following format:
-
- <pre> $ <b>pip install --upgrade</b> <i>tfBinaryURL</i> # Python 2.7
- $ <b>pip3 install --upgrade</b> <i>tfBinaryURL</i> # Python 3.n </pre>
-
- where <i>tfBinaryURL</i> identifies the URL
- of the TensorFlow Python package. The appropriate value of
- <i>tfBinaryURL</i> depends on the operating system and
- Python version. Find the appropriate value for
- <i>tfBinaryURL</i> for your system
- [here](#the_url_of_the_tensorflow_python_package).
-     For example, if you are installing TensorFlow for macOS
-     with Python 3, the command to install
-     TensorFlow in the active Virtualenv is as follows:
-
- <pre> $ <b>pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.10.0-py3-none-any.whl</b></pre>
-
-If you encounter installation problems, see
-[Common Installation Problems](#common-installation-problems).
-
-
-### Next Steps
-
-After installing TensorFlow,
-[validate your installation](#ValidateYourInstallation)
-to confirm that the installation worked properly.
-
-Note that you must activate the Virtualenv environment each time you
-use TensorFlow in a new shell. If the Virtualenv environment is not
-currently active (that is, the prompt is not `(<i>targetDirectory</i>)`), invoke
-one of the following commands:
-
-<pre>$ <b>cd <i>targetDirectory</i></b>
-$ <b>source ./bin/activate</b> # If using bash, sh, ksh, or zsh
-$ <b>source ./bin/activate.csh</b> # If using csh or tcsh </pre>
-
-
-Your prompt will change to the following, indicating that your
-tensorflow environment is active:
-
-<pre> (<i>targetDirectory</i>)$ </pre>
-
-When the Virtualenv environment is active, you may run
-TensorFlow programs from this shell.
-
-When you are done using TensorFlow, you may deactivate the
-environment by issuing the following command:
-
-<pre> (<i>targetDirectory</i>)$ <b>deactivate</b> </pre>
-
-The prompt will revert to your default prompt (as defined by `PS1`).
-
-
-### Uninstalling TensorFlow
-
-If you want to uninstall TensorFlow, simply remove the tree you created. For example:
-
-<pre> $ <b>rm -r ~/tensorflow</b> </pre>
-
-
-## Installing with native pip
-
-We have uploaded the TensorFlow binaries to PyPI.
-Therefore, you can install TensorFlow through pip.
-
-The
-[REQUIRED_PACKAGES section of setup.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/pip_package/setup.py)
-lists the packages that pip will install or upgrade.
-
-
-### Prerequisite: Python
-
-In order to install TensorFlow, your system must contain one of the following Python versions:
-
- * Python 2.7
- * Python 3.3+
-
-If your system does not already have one of the preceding Python versions,
-[install](https://wiki.python.org/moin/BeginnersGuide/Download) it now.
-
-When installing Python, you might need to disable
-System Integrity Protection (SIP) to permit any entity other than
-Mac App Store to install software.
-
-
-### Prerequisite: pip
-
-[Pip](https://en.wikipedia.org/wiki/Pip_(package_manager)) installs
-and manages software packages written in Python. If you intend to install
-with native pip, then one of the following flavors of pip must be
-installed on your system:
-
- * `pip`, for Python 2.7
- * `pip3`, for Python 3.n.
-
-`pip` or `pip3` was probably installed on your system when you
-installed Python. To determine whether pip or pip3 is actually
-installed on your system, issue one of the following commands:
-
-<pre>$ <b>pip -V</b> # for Python 2.7
-$ <b>pip3 -V</b> # for Python 3.n </pre>
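-
-Successful output reports the pip version and its install location, similar to
-the following (your version and paths will differ):
-
-<pre>pip 9.0.1 from /Library/Python/2.7/site-packages/pip (python 2.7)</pre>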
-
-We strongly recommend pip or pip3 version 8.1 or higher in order
-to install TensorFlow. If pip or pip3 8.1 or later is not
-installed, issue the following commands to install or upgrade:
-
-<pre>$ <b>sudo easy_install --upgrade pip</b>
-$ <b>sudo easy_install --upgrade six</b> </pre>
-
-
-### Install TensorFlow
-
-Assuming the prerequisite software is installed on your Mac,
-take the following steps:
-
- 1. Install TensorFlow by invoking **one** of the following commands:
-
-    <pre> $ <b>pip install tensorflow</b>      # Python 2.7; CPU support
-    $ <b>pip3 install tensorflow</b>     # Python 3.n; CPU support</pre>
-
- If the preceding command runs to completion, you should now
- [validate your installation](#ValidateYourInstallation).
-
- 2. (Optional.) If Step 1 failed, install the latest version of TensorFlow
- by issuing a command of the following format:
-
- <pre> $ <b>sudo pip install --upgrade</b> <i>tfBinaryURL</i> # Python 2.7
- $ <b>sudo pip3 install --upgrade</b> <i>tfBinaryURL</i> # Python 3.n </pre>
-
- where <i>tfBinaryURL</i> identifies the URL of the TensorFlow Python
- package. The appropriate value of <i>tfBinaryURL</i> depends on the
- operating system and Python version. Find the appropriate
- value for <i>tfBinaryURL</i>
-     [here](#the_url_of_the_tensorflow_python_package). For example, if
-     you are installing TensorFlow for macOS with Python 3,
-     issue the following command:
-
- <pre> $ <b>sudo pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.10.0-py3-none-any.whl</b> </pre>
-
- If the preceding command fails, see
- [installation problems](#common-installation-problems).
-
-
-
-### Next Steps
-
-After installing TensorFlow,
-[validate your installation](#ValidateYourInstallation)
-to confirm that the installation worked properly.
-
-
-### Uninstalling TensorFlow
-
-To uninstall TensorFlow, issue one of the following commands:
-
-<pre>$ <b>pip uninstall tensorflow</b>
-$ <b>pip3 uninstall tensorflow</b> </pre>
-
-
-## Installing with Docker
-
-Follow these steps to install TensorFlow through Docker.
-
- 1. Install Docker on your machine as described in the
- [Docker documentation](https://docs.docker.com/engine/installation/#/on-macos-and-windows).
-
- 2. Launch a Docker container that contains one of the TensorFlow
- binary images.
-
-The remainder of this section explains how to launch a Docker container.
-
-To launch a Docker container that holds the TensorFlow binary image,
-enter a command of the following format:
-
-<pre> $ <b>docker run -it <i>-p hostPort:containerPort</i> TensorFlowImage</b> </pre>
-
-where:
-
- * <i>-p hostPort:containerPort</i> is optional. If you'd like to run
- TensorFlow programs from the shell, omit this option. If you'd like
- to run TensorFlow programs from Jupyter notebook, set both
- <i>hostPort</i> and <i>containerPort</i> to <code>8888</code>.
- If you'd like to run TensorBoard inside the container, add
- a second `-p` flag, setting both <i>hostPort</i> and <i>containerPort</i>
- to 6006.
- * <i>TensorFlowImage</i> is required. It identifies the Docker container.
- You must specify one of the following values:
- * <code>tensorflow/tensorflow</code>: TensorFlow binary image.
- * <code>tensorflow/tensorflow:latest-devel</code>: TensorFlow
-       binary image plus source code.
-
-The TensorFlow images are available at
-[dockerhub](https://hub.docker.com/r/tensorflow/tensorflow/).
-
-For example, the following command launches a TensorFlow CPU binary image
-in a Docker container from which you can run TensorFlow programs in a shell:
-
-<pre>$ <b>docker run -it tensorflow/tensorflow bash</b></pre>
-
-The following command also launches a TensorFlow CPU binary image in a
-Docker container. However, in this Docker container, you can run
-TensorFlow programs in a Jupyter notebook:
-
-<pre>$ <b>docker run -it -p 8888:8888 tensorflow/tensorflow</b></pre>
-
-Docker will download the TensorFlow binary image the first time you launch it.
-
-
-### Next Steps
-
-You should now
-[validate your installation](#ValidateYourInstallation).
-
-
-## Installing with Anaconda
-
-**The Anaconda installation is community supported, not officially supported.**
-
-Take the following steps to install TensorFlow in an Anaconda environment:
-
- 1. Follow the instructions on the
- [Anaconda download site](https://www.continuum.io/downloads)
- to download and install Anaconda.
-
- 2. Create a conda environment named `tensorflow`
- by invoking the following command:
-
- <pre>$ <b>conda create -n tensorflow pip python=2.7 # or python=3.3, etc.</b></pre>
-
- 3. Activate the conda environment by issuing the following command:
-
-    <pre>$ <b>source activate tensorflow</b>
-    (tensorflow)$  # Your prompt should change</pre>
-
- 4. Issue a command of the following format to install
- TensorFlow inside your conda environment:
-
-    <pre>(tensorflow)$ <b>pip install --ignore-installed --upgrade</b> <i>TF_PYTHON_URL</i></pre>
-
- where <i>TF_PYTHON_URL</i> is the
- [URL of the TensorFlow Python package](#the_url_of_the_tensorflow_python_package).
- For example, the following command installs the CPU-only version of
- TensorFlow for Python 2.7:
-
-    <pre> (tensorflow)$ <b>pip install --ignore-installed --upgrade \
-     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.10.0-py2-none-any.whl</b></pre>
-
-
-<a name="ValidateYourInstallation"></a>
-## Validate your installation
-
-To validate your TensorFlow installation, do the following:
-
- 1. Ensure that your environment is prepared to run TensorFlow programs.
- 2. Run a short TensorFlow program.
-
-
-### Prepare your environment
-
-If you installed with native pip, Virtualenv, or Anaconda, then
-do the following:
-
- 1. Start a terminal.
- 2. If you installed with Virtualenv or Anaconda, activate your environment.
- 3. If you installed TensorFlow from source code, navigate to any
- directory *except* one containing TensorFlow source code.
-
-If you installed through Docker, start a Docker container that runs bash.
-For example:
-
-<pre>$ <b>docker run -it tensorflow/tensorflow bash</b></pre>
-
-
-
-### Run a short TensorFlow program
-
-Invoke python from your shell as follows:
-
-<pre>$ <b>python</b></pre>
-
-Enter the following short program inside the python interactive shell:
-
-```python
-# Python
-import tensorflow as tf
-hello = tf.constant('Hello, TensorFlow!')
-sess = tf.Session()
-print(sess.run(hello))
-```
-
-If the system outputs the following, then you are ready to begin
-writing TensorFlow programs:
-
-<pre>Hello, TensorFlow!</pre>
-
-If the system outputs an error message instead of a greeting, see
-[Common installation problems](#common_installation_problems).
-
-To learn more, see the [TensorFlow tutorials](../tutorials/).
-
-## Common installation problems
-
-We are relying on Stack Overflow to document TensorFlow installation problems
-and their remedies. The following table contains links to Stack Overflow
-answers for some common installation problems.
-If you encounter an error message or other
-installation problem not listed in the following table, search for it
-on Stack Overflow. If Stack Overflow doesn't show the error message,
-ask a new question about it on Stack Overflow and specify
-the `tensorflow` tag.
-
-<table>
-<tr> <th>Stack Overflow Link</th> <th>Error Message</th> </tr>
-
-
-<tr>
- <td><a href="http://stackoverflow.com/q/42006320">42006320</a></td>
- <td><pre>ImportError: Traceback (most recent call last):
-File ".../tensorflow/core/framework/graph_pb2.py", line 6, in <module>
-from google.protobuf import descriptor as _descriptor
-ImportError: cannot import name 'descriptor'</pre>
- </td>
-</tr>
-
-<tr>
- <td><a href="https://stackoverflow.com/q/33623453">33623453</a></td>
-  <td><pre>IOError: [Errno 2] No such file or directory:
-  '/tmp/pip-o6Tpui-build/setup.py'</pre></td>
-</tr>
-
-<tr>
- <td><a href="https://stackoverflow.com/questions/35190574">35190574</a> </td>
- <td><pre>SSLError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify
- failed</pre></td>
-</tr>
-
-<tr>
- <td><a href="http://stackoverflow.com/q/42009190">42009190</a></td>
- <td><pre>
- Installing collected packages: setuptools, protobuf, wheel, numpy, tensorflow
- Found existing installation: setuptools 1.1.6
- Uninstalling setuptools-1.1.6:
- Exception:
- ...
- [Errno 1] Operation not permitted:
- '/tmp/pip-a1DXRT-uninstall/.../lib/python/_markerlib' </pre></td>
-</tr>
-
-<tr>
- <td><a href="https://stackoverflow.com/q/33622019">33622019</a></td>
- <td><pre>ImportError: No module named copyreg</pre></td>
-</tr>
-
-<tr>
- <td><a href="http://stackoverflow.com/q/37810228">37810228</a></td>
- <td>During a <tt>pip install</tt> operation, the system returns:
- <pre>OSError: [Errno 1] Operation not permitted</pre>
- </td>
-</tr>
-
-<tr>
- <td><a href="http://stackoverflow.com/q/33622842">33622842</a></td>
- <td>An <tt>import tensorflow</tt> statement triggers an error such as the
- following:<pre>Traceback (most recent call last):
- File "<stdin>", line 1, in <module>
- File "/usr/local/lib/python2.7/site-packages/tensorflow/__init__.py",
- line 4, in <module>
- from tensorflow.python import *
- ...
- File "/usr/local/lib/python2.7/site-packages/tensorflow/core/framework/tensor_shape_pb2.py",
- line 22, in <module>
- serialized_pb=_b('\n,tensorflow/core/framework/tensor_shape.proto\x12\ntensorflow\"d\n\x10TensorShapeProto\x12-\n\x03\x64im\x18\x02
- \x03(\x0b\x32
- .tensorflow.TensorShapeProto.Dim\x1a!\n\x03\x44im\x12\x0c\n\x04size\x18\x01
- \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\tb\x06proto3')
- TypeError: __init__() got an unexpected keyword argument 'syntax'</pre>
- </td>
-</tr>
-
-
-<tr>
- <td><a href="http://stackoverflow.com/q/42075397">42075397</a></td>
- <td>A <tt>pip install</tt> command triggers the following error:
-<pre>...<lots of warnings and errors>
-You have not agreed to the Xcode license agreements, please run
-'xcodebuild -license' (for user-level acceptance) or
-'sudo xcodebuild -license' (for system-wide acceptance) from within a
-Terminal window to review and agree to the Xcode license agreements.
-...<more stack trace output>
- File "numpy/core/setup.py", line 653, in get_mathlib_info
-
- raise RuntimeError("Broken toolchain: cannot link a simple C program")
-
-RuntimeError: Broken toolchain: cannot link a simple C program</pre>
-</td>
-</tr>
-
-
-</table>
-
-
-
-
-<a name="TF_PYTHON_URL"></a>
-## The URL of the TensorFlow Python package
-
-A few installation mechanisms require the URL of the TensorFlow Python package.
-The value you specify depends on your Python version.
-
-### Python 2.7
-
-
-<pre>
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.10.0-py2-none-any.whl
-</pre>
-
-
-### Python 3.4, 3.5, or 3.6
-
-
-<pre>
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.10.0-py3-none-any.whl
-</pre>
diff --git a/tensorflow/docs_src/install/install_raspbian.md b/tensorflow/docs_src/install/install_raspbian.md
deleted file mode 100644
index cf6b6b4f79..0000000000
--- a/tensorflow/docs_src/install/install_raspbian.md
+++ /dev/null
@@ -1,313 +0,0 @@
-# Install TensorFlow on Raspbian
-
-This guide explains how to install TensorFlow on a Raspberry Pi running
-Raspbian. Although these instructions might also work on other Pi variants, we
-have only tested (and we only support) these instructions on machines meeting
-the following requirements:
-
-* Raspberry Pi devices running Raspbian 9.0 or higher
-
-## Determine how to install TensorFlow
-
-You must pick the mechanism by which you install TensorFlow. The supported
-choices are as follows:
-
-* "Native" pip.
-* Cross-compiling from sources.
-
-**We recommend pip installation.**
-
-## Installing with native pip
-
-We have uploaded the TensorFlow binaries to piwheels.org. Therefore, you can
-install TensorFlow through pip.
-
-The [REQUIRED_PACKAGES section of
-setup.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/pip_package/setup.py)
-lists the packages that pip will install or upgrade.
-
-### Prerequisite: Python
-
-In order to install TensorFlow, your system must contain one of the following
-Python versions:
-
-* Python 2.7
-* Python 3.4+
-
-If your system does not already have one of the preceding Python versions,
-[install](https://wiki.python.org/moin/BeginnersGuide/Download) it now. Python
-should already be included with Raspbian, though, so no extra steps should be
-needed.
-
-### Prerequisite: pip
-
-[Pip](https://en.wikipedia.org/wiki/Pip_\(package_manager\)) installs and
-manages software packages written in Python. If you intend to install with
-native pip, then one of the following flavors of pip must be installed on your
-system:
-
-* `pip3`, for Python 3.n (preferred).
-* `pip`, for Python 2.7.
-
-`pip` or `pip3` was probably installed on your system when you installed Python.
-To determine whether pip or pip3 is actually installed on your system, issue one
-of the following commands:
-
-<pre>$ <b>pip3 -V</b> # for Python 3.n
-$ <b>pip -V</b> # for Python 2.7</pre>
-
-If it gives the error "Command not found", then the package has not been
-installed yet. To install it for the first time, run:
-
-<pre>$ sudo apt-get install python3-pip # for Python 3.n
-$ sudo apt-get install python-pip # for Python 2.7</pre>
-
-You can find more help on installing and upgrading pip in
-[the Raspberry Pi documentation](https://www.raspberrypi.org/documentation/linux/software/python.md).
-
-### Prerequisite: Atlas
-
-[Atlas](http://math-atlas.sourceforge.net/) is a linear algebra library that
-numpy depends on, and so needs to be installed before TensorFlow. To add it to
-your system, run the following command:
-
-<pre>$ sudo apt install libatlas-base-dev</pre>
-
-### Install TensorFlow
-
-Assuming the prerequisite software is installed on your Pi, install TensorFlow
-by invoking **one** of the following commands:
-
-<pre>$ <b>pip3 install tensorflow</b> # Python 3.n
-$ <b>pip install tensorflow</b> # Python 2.7</pre>
-
-This can take some time on certain platforms like the Pi Zero, where some Python
-packages like scipy that TensorFlow depends on need to be compiled before the
-installation can complete. The Python 3 version will typically be faster to
-install because piwheels.org has pre-built versions of the dependencies
-available, so this is our recommended option.
-
-### Next Steps
-
-After installing TensorFlow, [validate your
-installation](#ValidateYourInstallation) to confirm that the installation worked
-properly.
-
-### Uninstalling TensorFlow
-
-To uninstall TensorFlow, issue one of the following commands:
-
-<pre>$ <b>pip uninstall tensorflow</b>
-$ <b>pip3 uninstall tensorflow</b> </pre>
-
-## Cross-compiling from sources
-
-Cross-compilation means building on a different machine than the one you'll be
-deploying on. Since Raspberry Pis have only limited RAM and comparatively slow
-processors, and TensorFlow has a large amount of source code to compile, it's
-easier to use a macOS or Linux desktop or laptop to handle the build process.
-Because it can take over 24 hours to build on a Pi, and requires external swap
-space to cope with the memory shortage, we recommend using cross-compilation if
-you do need to compile TensorFlow from source. To make the dependency management
-process easier, we also recommend using Docker to help simplify building.
-
-Note that we provide well-tested, pre-built TensorFlow binaries for Raspbian
-systems. So, don't build a TensorFlow binary yourself unless you are very
-comfortable building complex packages from source and dealing with the
-inevitable aftermath should things not go exactly as documented.
-
-### Prerequisite: Docker
-
-Install Docker on your machine as described in the [Docker
-documentation](https://docs.docker.com/engine/installation/#/on-macos-and-windows).
-
-### Clone the TensorFlow repository
-
-Start the process of building TensorFlow by cloning a TensorFlow repository.
-
-To clone **the latest** TensorFlow repository, issue the following command:
-
-<pre>$ <b>git clone https://github.com/tensorflow/tensorflow</b> </pre>
-
-The preceding <code>git clone</code> command creates a subdirectory named
-`tensorflow`. After cloning, you may optionally build a **specific branch**
-(such as a release branch) by invoking the following commands:
-
-<pre>
-$ <b>cd tensorflow</b>
-$ <b>git checkout</b> <i>Branch</i> # where <i>Branch</i> is the desired branch
-</pre>
-
-For example, to work with the `r1.0` release instead of the master release,
-issue the following command:
-
-<pre>$ <b>git checkout r1.0</b></pre>
-
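-To see which release branches are available before checking one out, you can
-list the remote branches (standard git):
-
-<pre>$ <b>git branch -r</b></pre>
-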
-### Build from source
-
-To compile TensorFlow and produce a binary that pip can install, do the following:
-
-1. Start a terminal.
-2. Navigate to the directory containing the tensorflow source code.
-3. Run a command to cross-compile the library, for example:
-
-<pre>$ CI_DOCKER_EXTRA_PARAMS="-e CI_BUILD_PYTHON=python3 -e CROSSTOOL_PYTHON_INCLUDE_PATH=/usr/include/python3.4" \
-tensorflow/tools/ci_build/ci_build.sh PI-PYTHON3 tensorflow/tools/ci_build/pi/build_raspberry_pi.sh
- </pre>
-
-This will build a pip .whl file for Python 3.4, with ARMv7 instructions that
-will only work on the Pi models 2 or 3. These NEON instructions are required for
-the fastest operation on those devices, but you can build a library that will
-run across all Pi devices by passing `PI_ONE` at the end of the command line.
-You can also target Python 2.7 by omitting the initial docker parameters. Here's
-an example of building for Python 2.7 and Raspberry Pi model Zero or One
-devices:
-
-<pre>$ tensorflow/tools/ci_build/ci_build.sh PI tensorflow/tools/ci_build/pi/build_raspberry_pi.sh PI_ONE</pre>
-
-This will take some time to complete, typically twenty or thirty minutes, and
-should produce a .whl file in an output-artifacts sub-folder inside your source
-tree at the end. This wheel file can be installed through pip or pip3 (depending
-on your Python version) by copying it to a Raspberry Pi and running a terminal
-command like this (with the name of your actual file substituted):
-
-<pre>$ pip3 install tensorflow-1.9.0-cp34-none-linux_armv7l.whl</pre>
-
-### Troubleshooting the build
-
-The build script uses Docker internally to run the compilation inside a Linux
-container. If you have problems running the script, first check
-that you're able to run Docker tests like `docker run hello-world` on your
-system.
-
-If you're building from the latest development branch, try syncing to an older
-version that's known to work, for example release 1.9, with a command like this:
-
-<pre>$ <b>git checkout r1.9</b></pre>
-
-<a name="ValidateYourInstallation"></a>
-
-## Validate your installation
-
-To validate your TensorFlow installation, do the following:
-
-1. Ensure that your environment is prepared to run TensorFlow programs.
-2. Run a short TensorFlow program.
-
-### Prepare your environment
-
-If you installed with native pip, Virtualenv, or Anaconda, then do the following:
-
-1. Start a terminal.
-2. If you installed TensorFlow from source code, navigate to any directory *except*
- one containing TensorFlow source code.
-
-### Run a short TensorFlow program
-
-Invoke python from your shell as follows:
-
-<pre>$ <b>python</b></pre>
-
-Enter the following short program inside the python interactive shell:
-
-```python
-# Python
-import tensorflow as tf
-hello = tf.constant('Hello, TensorFlow!')
-sess = tf.Session()
-print(sess.run(hello))
-```
-
-If the system outputs the following, then you are ready to begin writing
-TensorFlow programs:
-
-<pre>Hello, TensorFlow!</pre>
-
-If you're running with Python 3.5, you may see a warning when you first import
-TensorFlow. This is not an error, and TensorFlow should continue to run with no
-problems, despite the log message.
-
-If the system outputs an error message instead of a greeting, see [Common
-installation problems](#common_installation_problems).
-
-To learn more, see the [TensorFlow tutorials](../tutorials/).
-
-## Common installation problems
-
-We are relying on Stack Overflow to document TensorFlow installation problems
-and their remedies. The following table contains links to Stack Overflow answers
-for some common installation problems. If you encounter an error message or
-other installation problem not listed in the following table, search for it on
-Stack Overflow. If Stack Overflow doesn't show the error message, ask a new
-question about it on Stack Overflow and specify the `tensorflow` tag.
-
-<table>
-<tr> <th>Stack Overflow Link</th> <th>Error Message</th> </tr>
-
-
-<tr>
- <td><a href="http://stackoverflow.com/q/42006320">42006320</a></td>
- <td><pre>ImportError: Traceback (most recent call last):
-File ".../tensorflow/core/framework/graph_pb2.py", line 6, in <module>
-from google.protobuf import descriptor as _descriptor
-ImportError: cannot import name 'descriptor'</pre>
- </td>
-</tr>
-
-<tr>
- <td><a href="https://stackoverflow.com/q/33623453">33623453</a></td>
-  <td><pre>IOError: [Errno 2] No such file or directory:
-  '/tmp/pip-o6Tpui-build/setup.py'</pre></td>
-</tr>
-
-<tr>
- <td><a href="https://stackoverflow.com/questions/35190574">35190574</a> </td>
- <td><pre>SSLError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify
- failed</pre></td>
-</tr>
-
-<tr>
- <td><a href="http://stackoverflow.com/q/42009190">42009190</a></td>
- <td><pre>
- Installing collected packages: setuptools, protobuf, wheel, numpy, tensorflow
- Found existing installation: setuptools 1.1.6
- Uninstalling setuptools-1.1.6:
- Exception:
- ...
- [Errno 1] Operation not permitted:
- '/tmp/pip-a1DXRT-uninstall/.../lib/python/_markerlib' </pre></td>
-</tr>
-
-<tr>
- <td><a href="https://stackoverflow.com/q/33622019">33622019</a></td>
- <td><pre>ImportError: No module named copyreg</pre></td>
-</tr>
-
-<tr>
- <td><a href="http://stackoverflow.com/q/37810228">37810228</a></td>
- <td>During a <tt>pip install</tt> operation, the system returns:
- <pre>OSError: [Errno 1] Operation not permitted</pre>
- </td>
-</tr>
-
-<tr>
- <td><a href="http://stackoverflow.com/q/33622842">33622842</a></td>
- <td>An <tt>import tensorflow</tt> statement triggers an error such as the
- following:<pre>Traceback (most recent call last):
- File "<stdin>", line 1, in <module>
- File "/usr/local/lib/python2.7/site-packages/tensorflow/__init__.py",
- line 4, in <module>
- from tensorflow.python import *
- ...
- File "/usr/local/lib/python2.7/site-packages/tensorflow/core/framework/tensor_shape_pb2.py",
- line 22, in <module>
- serialized_pb=_b('\n,tensorflow/core/framework/tensor_shape.proto\x12\ntensorflow\"d\n\x10TensorShapeProto\x12-\n\x03\x64im\x18\x02
- \x03(\x0b\x32
- .tensorflow.TensorShapeProto.Dim\x1a!\n\x03\x44im\x12\x0c\n\x04size\x18\x01
- \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\tb\x06proto3')
- TypeError: __init__() got an unexpected keyword argument 'syntax'</pre>
- </td>
-</tr>
-
-
-</table>
diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md
deleted file mode 100644
index dfd9fbce4b..0000000000
--- a/tensorflow/docs_src/install/install_sources.md
+++ /dev/null
@@ -1,577 +0,0 @@
-# Install TensorFlow from Sources
-
-This guide explains how to build TensorFlow sources into a TensorFlow binary and
-how to install that TensorFlow binary. Note that we provide well-tested,
-pre-built TensorFlow binaries for Ubuntu, macOS, and Windows systems. In
-addition, there are pre-built TensorFlow
-[docker images](https://hub.docker.com/r/tensorflow/tensorflow/). So, don't
-build a TensorFlow binary yourself unless you are very comfortable building
-complex packages from source and dealing with the inevitable aftermath should
-things not go exactly as documented.
-
-If the last paragraph didn't scare you off, welcome. This guide explains how to
-build TensorFlow on 64-bit desktops and laptops running either of the following
-operating systems:
-
-* Ubuntu
-* macOS
-
-Note: Some users have successfully built and installed TensorFlow from sources
-on non-supported systems. Please remember that we do not fix issues stemming
-from these attempts.
-
-We **do not support** building TensorFlow on Windows. That said, if you'd like
-to try to build TensorFlow on Windows anyway, use either of the following:
-
-* [Bazel on Windows](https://bazel.build/versions/master/docs/windows.html)
-* [TensorFlow CMake build](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/cmake)
-
-Note: Starting with the 1.6 release, our prebuilt binaries use AVX
-instructions. Older CPUs may not be able to execute these binaries.
-
-## Determine which TensorFlow to install
-
-You must choose one of the following types of TensorFlow to build and install:
-
-*   **TensorFlow with CPU support only**. If your system does not have an
-    NVIDIA® GPU, build and install this version. Note that this version of
-    TensorFlow is typically easier to build and install, so even if you have an
-    NVIDIA GPU, we recommend building and installing this version first.
-*   **TensorFlow with GPU support**. TensorFlow programs typically run
-    significantly faster on a GPU than on a CPU. Therefore, if your system has
-    an NVIDIA GPU and you need to run performance-critical applications, you
-    should ultimately build and install this version. Beyond the NVIDIA GPU
-    itself, your system must also fulfill the NVIDIA software requirements
-    described in one of the following documents:
-
-    * @{$install_linux#NVIDIARequirements$Installing TensorFlow on Ubuntu}
-    * @{$install_mac#NVIDIARequirements$Installing TensorFlow on macOS}
-
-## Clone the TensorFlow repository
-
-Start the process of building TensorFlow by cloning a TensorFlow repository.
-
-To clone **the latest** TensorFlow repository, issue the following command:
-
-<pre>$ <b>git clone https://github.com/tensorflow/tensorflow</b> </pre>
-
-The preceding <code>git clone</code> command creates a subdirectory named
-`tensorflow`. After cloning, you may optionally build a **specific branch**
-(such as a release branch) by invoking the following commands:
-
-<pre>
-$ <b>cd tensorflow</b>
-$ <b>git checkout</b> <i>Branch</i> # where <i>Branch</i> is the desired branch
-</pre>
-
-For example, to work with the `r1.0` release instead of the master release,
-issue the following command:
-
-<pre>$ <b>git checkout r1.0</b></pre>
-
-Next, you must prepare your environment for [Linux](#PrepareLinux) or
-[macOS](#PrepareMac).
-
-<a name="PrepareLinux"></a>
-
-## Prepare environment for Linux
-
-Before building TensorFlow on Linux, install the following build tools on your
-system:
-
-* bazel
-* TensorFlow Python dependencies
-* optionally, NVIDIA packages to support TensorFlow for GPU.
-
-### Install Bazel
-
-If bazel is not installed on your system, install it now by following
-[these directions](https://bazel.build/versions/master/docs/install.html).
-
-### Install TensorFlow Python dependencies
-
-To install TensorFlow, you must install the following packages:
-
-* `numpy`, which is a numerical processing package that TensorFlow requires.
-* `dev`, which enables adding extensions to Python.
-* `pip`, which enables you to install and manage certain Python packages.
-* `wheel`, which enables you to manage Python compressed packages in the wheel
- (.whl) format.
-
-To install these packages for Python 2.7, issue the following command:
-
-<pre>
-$ <b>sudo apt-get install python-numpy python-dev python-pip python-wheel</b>
-</pre>
-
-To install these packages for Python 3.n, issue the following command:
-
-<pre>
-$ <b>sudo apt-get install python3-numpy python3-dev python3-pip python3-wheel</b>
-</pre>
-
-### Optional: install TensorFlow for GPU prerequisites
-
-If you are building TensorFlow without GPU support, skip this section.
-
-The following NVIDIA® <i>hardware</i> must be installed on your system:
-
-* GPU card with CUDA Compute Capability 3.5 or higher. See
- [NVIDIA documentation](https://developer.nvidia.com/cuda-gpus) for a list of
- supported GPU cards.
-
-The following NVIDIA® <i>software</i> must be installed on your system:
-
-* [GPU drivers](http://nvidia.com/driver). CUDA 9.0 requires 384.x or higher.
-* [CUDA Toolkit](http://nvidia.com/cuda) (>= 8.0). We recommend version 9.0.
-* [cuDNN SDK](http://developer.nvidia.com/cudnn) (>= 6.0). We recommend
- version 7.1.x.
-* [CUPTI](http://docs.nvidia.com/cuda/cupti/) ships with the CUDA Toolkit, but
- you also need to append its path to the `LD_LIBRARY_PATH` environment
- variable: `export
- LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/extras/CUPTI/lib64`
-* *OPTIONAL*: [NCCL 2.2](https://developer.nvidia.com/nccl) to use TensorFlow
- with multiple GPUs.
-* *OPTIONAL*:
- [TensorRT](http://docs.nvidia.com/deeplearning/sdk/tensorrt-install-guide/index.html)
- which can improve latency and throughput for inference for some models.
-
-While it is possible to install the NVIDIA libraries via `apt-get` from the
-NVIDIA repository, the libraries and headers are installed in locations that
-make it difficult to configure and debug build issues. Downloading and
-installing the libraries manually or using docker
-([latest-devel-gpu](https://hub.docker.com/r/tensorflow/tensorflow/tags/)) is
-recommended.
-
-### Next
-
-After preparing the environment, you must now
-[configure the installation](#ConfigureInstallation).
-
-<a name="PrepareMac"></a>
-
-## Prepare environment for macOS
-
-Before building TensorFlow, you must install the following on your system:
-
-* bazel
-* TensorFlow Python dependencies.
-* optionally, NVIDIA packages to support TensorFlow for GPU.
-
-### Install bazel
-
-If bazel is not installed on your system, install it now by following
-[these directions](https://bazel.build/versions/master/docs/install.html#mac-os-x).
-
-### Install python dependencies
-
-To build TensorFlow, you must install the following packages:
-
-* six
-* mock
-* numpy, which is a numerical processing package that TensorFlow requires.
-* wheel, which enables you to manage Python compressed packages in the wheel
- (.whl) format.
-
-You may install the python dependencies using pip. If you don't have pip on your
-machine, we recommend using homebrew to install Python and pip as
-[documented here](http://docs.python-guide.org/en/latest/starting/install/osx/).
-If you follow these instructions, you will not need to disable SIP.
-
-After installing pip, invoke the following commands:
-
-<pre> $ <b>sudo pip install six numpy wheel mock h5py</b>
- $ <b>sudo pip install keras_applications==1.0.4 --no-deps</b>
- $ <b>sudo pip install keras_preprocessing==1.0.2 --no-deps</b>
-</pre>
-
-Note: These are just the minimum requirements to _build_ tensorflow. Installing
-the pip package will download additional packages required to _run_ it. If you
-plan on executing tasks directly with `bazel`, without the pip installation,
-you may need to install additional python packages. For example, you should
-`pip install mock enum34` before running TensorFlow's tests with bazel.
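-
-For example, a bazel-only test run might look like the following (a sketch; the
-wildcard target pattern is illustrative, and a full run can take a long time):
-
-<pre>
-$ <b>pip install mock enum34</b>
-$ <b>bazel test //tensorflow/python/...</b>
-</pre>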
-
-<a name="ConfigureInstallation"></a>
-
-## Configure the installation
-
-The root of the source tree contains a bash script named <code>configure</code>.
-This script asks you to identify the pathname of all relevant TensorFlow
-dependencies and specify other build configuration options such as compiler
-flags. You must run this script *prior* to creating the pip package and
-installing TensorFlow.
-
-If you wish to build TensorFlow with GPU, `configure` will ask you to specify
-the version numbers of CUDA and cuDNN. If several versions of CUDA or cuDNN are
-installed on your system, explicitly select the desired version instead of
-relying on the default.
-
-One of the questions that `configure` will ask is as follows:
-
-<pre>
-Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native]
-</pre>
-
-This question refers to a later phase in which you'll use bazel to
-[build the pip package](#build-the-pip-package) or the
-[C/Java libraries](#BuildCorJava). We recommend accepting the default
-(`-march=native`), which will optimize the generated code for your local
-machine's CPU type. However, if you are building TensorFlow on one CPU type but
-will run TensorFlow on a different CPU type, then consider specifying a more
-specific optimization flag as described in
-[the gcc documentation](https://gcc.gnu.org/onlinedocs/gcc-4.5.3/gcc/i386-and-x86_002d64-Options.html).
-
-Here is an example execution of the `configure` script. Note that your own input
-will likely differ from our sample input:
-
-<pre>
-$ <b>cd tensorflow</b> # cd to the top-level directory created
-$ <b>./configure</b>
-You have bazel 0.15.0 installed.
-Please specify the location of python. [Default is /usr/bin/python]: <b>/usr/bin/python2.7</b>
-
-
-Found possible Python library paths:
- /usr/local/lib/python2.7/dist-packages
- /usr/lib/python2.7/dist-packages
-Please input the desired Python library path to use. Default is [/usr/lib/python2.7/dist-packages]
-
-Do you wish to build TensorFlow with jemalloc as malloc support? [Y/n]:
-jemalloc as malloc support will be enabled for TensorFlow.
-
-Do you wish to build TensorFlow with Google Cloud Platform support? [Y/n]:
-Google Cloud Platform support will be enabled for TensorFlow.
-
-Do you wish to build TensorFlow with Hadoop File System support? [Y/n]:
-Hadoop File System support will be enabled for TensorFlow.
-
-Do you wish to build TensorFlow with Amazon AWS Platform support? [Y/n]:
-Amazon AWS Platform support will be enabled for TensorFlow.
-
-Do you wish to build TensorFlow with Apache Kafka Platform support? [Y/n]:
-Apache Kafka Platform support will be enabled for TensorFlow.
-
-Do you wish to build TensorFlow with XLA JIT support? [y/N]:
-No XLA JIT support will be enabled for TensorFlow.
-
-Do you wish to build TensorFlow with GDR support? [y/N]:
-No GDR support will be enabled for TensorFlow.
-
-Do you wish to build TensorFlow with VERBS support? [y/N]:
-No VERBS support will be enabled for TensorFlow.
-
-Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]:
-No OpenCL SYCL support will be enabled for TensorFlow.
-
-Do you wish to build TensorFlow with CUDA support? [y/N]: <b>Y</b>
-CUDA support will be enabled for TensorFlow.
-
-Please specify the CUDA SDK version you want to use. [Leave empty to default to CUDA 9.0]: <b>9.0</b>
-
-
-Please specify the location where CUDA 9.0 toolkit is installed. Refer to README.md for more details. [Default is /usr/local/cuda]:
-
-
-Please specify the cuDNN version you want to use. [Leave empty to default to cuDNN 7.0]: <b>7.0</b>
-
-
-Please specify the location where cuDNN 7 library is installed. Refer to README.md for more details. [Default is /usr/local/cuda]:
-
-
-Do you wish to build TensorFlow with TensorRT support? [y/N]:
-No TensorRT support will be enabled for TensorFlow.
-
-Please specify the NCCL version you want to use. If NCCL 2.2 is not installed, then you can use version 1.3 that can be fetched automatically but it may have worse performance with multiple GPUs. [Default is 2.2]: 1.3
-
-
-Please specify a list of comma-separated Cuda compute capabilities you want to build with.
-You can find the compute capability of your device at: https://developer.nvidia.com/cuda-gpus.
-Please note that each additional compute capability significantly increases your
-build time and binary size. [Default is: 3.5,7.0] <b>6.1</b>
-
-
-Do you want to use clang as CUDA compiler? [y/N]:
-nvcc will be used as CUDA compiler.
-
-Please specify which gcc should be used by nvcc as the host compiler. [Default is /usr/bin/gcc]:
-
-
-Do you wish to build TensorFlow with MPI support? [y/N]:
-No MPI support will be enabled for TensorFlow.
-
-Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native]:
-
-
-Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]:
-Not configuring the WORKSPACE for Android builds.
-
-Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See tools/bazel.rc for more details.
- --config=mkl # Build with MKL support.
- --config=monolithic # Config for mostly static monolithic build.
-Configuration finished
-</pre>
-
-If you told `configure` to build for GPU support, then `configure` will create a
-canonical set of symbolic links to the CUDA libraries on your system. Therefore,
-every time you change the CUDA library paths, you must rerun the `configure`
-script before re-invoking the <code>bazel build</code> command.
-
-Note the following:
-
-* Although it is possible to build both CUDA and non-CUDA configs under the
- same source tree, we recommend running `bazel clean` when switching between
- these two configurations in the same source tree.
-* If you don't run the `configure` script *before* running the `bazel build`
- command, the `bazel build` command will fail.
-
-## Build the pip package
-
-Note: If you're only interested in building the libraries for the TensorFlow C
-or Java APIs, see [Build the C or Java libraries](#BuildCorJava); you do not
-need to build the pip package in that case.
-
-### CPU-only support
-
-To build a pip package for TensorFlow with CPU-only support:
-
-<pre>
-$ bazel build --config=opt //tensorflow/tools/pip_package:build_pip_package
-</pre>
-
-To build a pip package for TensorFlow with CPU-only support for the Intel®
-MKL-DNN:
-
-<pre>
-$ bazel build --config=mkl --config=opt //tensorflow/tools/pip_package:build_pip_package
-</pre>
-
-### GPU support
-
-To build a pip package for TensorFlow with GPU support:
-
-<pre>
-$ bazel build --config=opt --config=cuda //tensorflow/tools/pip_package:build_pip_package
-</pre>
-
-**NOTE on gcc 5 or later:** The binary pip packages available on the TensorFlow
-website are built with gcc 4, which uses the older ABI. To make your build
-compatible with the older ABI, you need to add
-`--cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"` to your `bazel build` command. ABI
-compatibility allows custom ops built against the TensorFlow pip package to
-continue to work against your built package.
-
-<b>Tip:</b> By default, building TensorFlow from sources consumes a lot of RAM.
-If RAM is an issue on your system, you may limit RAM usage by specifying
-<code>--local_resources 2048,.5,1.0</code> while invoking `bazel`.
-
-The <code>bazel build</code> command builds a script named `build_pip_package`.
-Running this script as follows will build a `.whl` file within the
-`/tmp/tensorflow_pkg` directory:
-
-<pre>
-$ <b>bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg</b>
-</pre>
-
-## Install the pip package
-
-Invoke `pip install` to install that pip package. The filename of the `.whl`
-file depends on your platform. For example, the following command will install
-the pip package for TensorFlow 1.10.0 on Linux:
-
-<pre>
-$ <b>sudo pip install /tmp/tensorflow_pkg/tensorflow-1.10.0-py2-none-any.whl</b>
-</pre>
-
-## Validate your installation
-
-Validate your TensorFlow installation by doing the following:
-
-Start a terminal.
-
-Change directory (`cd`) to any directory on your system other than the
-`tensorflow` subdirectory from which you invoked the `configure` command.
-
-Invoke python:
-
-<pre>$ <b>python</b></pre>
-
-Enter the following short program inside the python interactive shell:
-
-```python
-# Python
-import tensorflow as tf
-hello = tf.constant('Hello, TensorFlow!')
-sess = tf.Session()
-print(sess.run(hello))
-```
-
-If the system outputs the following, then you are ready to begin writing
-TensorFlow programs:
-
-<pre>Hello, TensorFlow!</pre>
-
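-If you built with GPU support, you can additionally confirm that TensorFlow
-sees your GPU. The following sketch uses `device_lib`, an internal but
-commonly used utility module, so treat it as illustrative:
-
-```python
-# Lists the devices visible to TensorFlow; a successful CUDA build should
-# include one or more entries with device_type: "GPU".
-from tensorflow.python.client import device_lib
-print(device_lib.list_local_devices())
-```
-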
-To learn more, see the [TensorFlow tutorials](../tutorials/).
-
-If the system outputs an error message instead of a greeting, see
-[Common installation problems](#common_installation_problems).
-
-## Common build and installation problems
-
-The build and installation problems you encounter typically depend on the
-operating system. See the "Common installation problems" section of one of the
-following guides:
-
-* @{$install_linux#common_installation_problems$Installing TensorFlow on Linux}
-* @{$install_mac#common_installation_problems$Installing TensorFlow on Mac OS}
-* @{$install_windows#common_installation_problems$Installing TensorFlow on Windows}
-
-Beyond the errors documented in those guides, the following table notes
-additional errors specific to building TensorFlow. Note that we are relying on
-Stack Overflow as the repository for build and installation problems. If you
-encounter an error message not listed in the preceding guides or in the
-following table, search for it on Stack Overflow. If Stack Overflow doesn't show
-the error message, ask a new question on Stack Overflow and specify the
-`tensorflow` tag.
-
-<table>
-<tr> <th>Stack Overflow Link</th> <th>Error Message</th> </tr>
-
-<tr>
- <td><a
- href="https://stackoverflow.com/questions/41293077/how-to-compile-tensorflow-with-sse4-2-and-avx-instructions">41293077</a></td>
- <td><pre>W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow
- library wasn't compiled to use SSE4.1 instructions, but these are available on
- your machine and could speed up CPU computations.</pre></td>
-</tr>
-
-<tr>
- <td><a href="http://stackoverflow.com/q/42013316">42013316</a></td>
- <td><pre>ImportError: libcudart.so.8.0: cannot open shared object file:
- No such file or directory</pre></td>
-</tr>
-
-<tr>
- <td><a href="http://stackoverflow.com/q/42013316">42013316</a></td>
- <td><pre>ImportError: libcudnn.5: cannot open shared object file:
- No such file or directory</pre></td>
-</tr>
-
-<tr>
- <td><a href="http://stackoverflow.com/q/35953210">35953210</a></td>
- <td>Invoking `python` or `ipython` generates the following error:
- <pre>ImportError: cannot import name pywrap_tensorflow</pre></td>
-</tr>
-
-<tr>
- <td><a href="https://stackoverflow.com/questions/45276830">45276830</a></td>
- <td><pre>external/local_config_cc/BUILD:50:5: in apple_cc_toolchain rule
- @local_config_cc//:cc-compiler-darwin_x86_64: Xcode version must be specified
- to use an Apple CROSSTOOL.</pre>
- </td>
-</tr>
-
-<tr>
- <td><a href="https://stackoverflow.com/q/47080760">47080760</a></td>
- <td><pre>undefined reference to `cublasGemmEx@libcublas.so.9.0'</pre></td>
-</tr>
-
-</table>
-
-## Tested source configurations
-
-**Linux**
-<table>
-<tr><th>Version:</th><th>CPU/GPU:</th><th>Python Version:</th><th>Compiler:</th><th>Build Tools:</th><th>cuDNN:</th><th>CUDA:</th></tr>
-<tr><td>tensorflow-1.10.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.15.0</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.10.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.15.0</td><td>7</td><td>9</td></tr>
-<tr><td>tensorflow-1.9.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.11.0</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.9.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.11.0</td><td>7</td><td>9</td></tr>
-<tr><td>tensorflow-1.8.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.10.0</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.8.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.9.0</td><td>7</td><td>9</td></tr>
-<tr><td>tensorflow-1.7.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.10.0</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.7.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.9.0</td><td>7</td><td>9</td></tr>
-<tr><td>tensorflow-1.6.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.9.0</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.6.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.9.0</td><td>7</td><td>9</td></tr>
-<tr><td>tensorflow-1.5.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.8.0</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.5.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.8.0</td><td>7</td><td>9</td></tr>
-<tr><td>tensorflow-1.4.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.5.4</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.4.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.5.4</td><td>6</td><td>8</td></tr>
-<tr><td>tensorflow-1.3.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.5</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.3.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.5</td><td>6</td><td>8</td></tr>
-<tr><td>tensorflow-1.2.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.5</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.2.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.5</td><td>5.1</td><td>8</td></tr>
-<tr><td>tensorflow-1.1.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.2</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.1.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.2</td><td>5.1</td><td>8</td></tr>
-<tr><td>tensorflow-1.0.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.2</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.0.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.2</td><td>5.1</td><td>8</td></tr>
-</table>
-
-**Mac**
-<table>
-<tr><th>Version:</th><th>CPU/GPU:</th><th>Python Version:</th><th>Compiler:</th><th>Build Tools:</th><th>cuDNN:</th><th>CUDA:</th></tr>
-<tr><td>tensorflow-1.10.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.15.0</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow-1.9.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.11.0</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow-1.8.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.10.1</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow-1.7.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.10.1</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow-1.6.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.8.1</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow-1.5.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.8.1</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow-1.4.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.5.4</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow-1.3.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.4.5</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow-1.2.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.4.5</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow-1.1.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.4.2</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.1.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.4.2</td><td>5.1</td><td>8</td></tr>
-<tr><td>tensorflow-1.0.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.4.2</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.0.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.4.2</td><td>5.1</td><td>8</td></tr>
-</table>
-
-**Windows**
-<table>
-<tr><th>Version:</th><th>CPU/GPU:</th><th>Python Version:</th><th>Compiler:</th><th>Build Tools:</th><th>cuDNN:</th><th>CUDA:</th></tr>
-<tr><td>tensorflow-1.10.0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.10.0</td><td>GPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>7</td><td>9</td></tr>
-<tr><td>tensorflow-1.9.0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.9.0</td><td>GPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>7</td><td>9</td></tr>
-<tr><td>tensorflow-1.8.0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.8.0</td><td>GPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>7</td><td>9</td></tr>
-<tr><td>tensorflow-1.7.0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.7.0</td><td>GPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>7</td><td>9</td></tr>
-<tr><td>tensorflow-1.6.0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.6.0</td><td>GPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>7</td><td>9</td></tr>
-<tr><td>tensorflow-1.5.0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.5.0</td><td>GPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>7</td><td>9</td></tr>
-<tr><td>tensorflow-1.4.0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.4.0</td><td>GPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>6</td><td>8</td></tr>
-<tr><td>tensorflow-1.3.0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.3.0</td><td>GPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>6</td><td>8</td></tr>
-<tr><td>tensorflow-1.2.0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.2.0</td><td>GPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>5.1</td><td>8</td></tr>
-<tr><td>tensorflow-1.1.0</td><td>CPU</td><td>3.5</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.1.0</td><td>GPU</td><td>3.5</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>5.1</td><td>8</td></tr>
-<tr><td>tensorflow-1.0.0</td><td>CPU</td><td>3.5</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.0.0</td><td>GPU</td><td>3.5</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>5.1</td><td>8</td></tr>
-</table>
-
-<a name="BuildCorJava"></a>
-
-## Build the C or Java libraries
-
-The instructions above are tailored to building the TensorFlow Python packages.
-
-If you're interested in building the libraries for the TensorFlow C API, do the
-following:
-
-1. Follow the steps up to [Configure the installation](#ConfigureInstallation)
-2. Build the C libraries following instructions in the
- [README](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/lib_package/README.md).
-
-If you're interested in building the libraries for the TensorFlow Java API, do
-the following:
-
-1. Follow the steps up to [Configure the installation](#ConfigureInstallation)
-2. Build the Java library following instructions in the
- [README](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/lib_package/README.md).
diff --git a/tensorflow/docs_src/install/install_sources_windows.md b/tensorflow/docs_src/install/install_sources_windows.md
deleted file mode 100644
index a1da122317..0000000000
--- a/tensorflow/docs_src/install/install_sources_windows.md
+++ /dev/null
@@ -1,320 +0,0 @@
-# Install TensorFlow from Sources on Windows
-
-This guide explains how to build TensorFlow sources into a TensorFlow binary and
-how to install that TensorFlow binary on Windows.
-
-## Determine which TensorFlow to install
-
-You must choose one of the following types of TensorFlow to build and install:
-
-* **TensorFlow with CPU support only**. If your system does not have an
-  NVIDIA® GPU, build and install this version. Note that this version of
-  TensorFlow is typically easier to build and install, so even if you have an
-  NVIDIA GPU, we recommend building and installing this version first.
-* **TensorFlow with GPU support**. TensorFlow programs typically run
-  significantly faster on a GPU than on a CPU. Therefore, if your system has an
-  NVIDIA GPU and you need to run performance-critical applications, you should
- ultimately build and install this version. Beyond the NVIDIA GPU itself,
- your system must also fulfill the NVIDIA software requirements described in
- the following document:
-
- * [Installing TensorFlow on Windows](install_windows.md#NVIDIARequirements)
-
-## Prepare environment for Windows
-
-Before building TensorFlow on Windows, install the following build tools on your
-system:
-
-* [MSYS2](#InstallMSYS2)
-* [Visual C++ build tools](#InstallVCBuildTools)
-* [Bazel for Windows](#InstallBazel)
-* [TensorFlow Python dependencies](#InstallPython)
-* [optionally, NVIDIA packages to support TensorFlow for GPU](#InstallCUDA)
-
-<a name="InstallMSYS2"></a>
-
-### Install MSYS2
-
-The TensorFlow Bazel build requires several Bash utilities, which you can install through [MSYS2](https://www.msys2.org/).
-
-Assuming you installed MSYS2 at `C:\msys64`, add `C:\msys64\usr\bin` to your `%PATH%` environment variable.
-
-To install the necessary Bash tools, issue the following command in `cmd.exe`:
-
-<pre>
-C:\> <b>pacman -S git patch unzip</b>
-</pre>
-
-<a name="InstallVCBuildTools"></a>
-
-### Install Visual C++ Build Tools 2015
-
-To build TensorFlow, you need to install Visual C++ build tools 2015. It is part
-of Visual Studio 2015, but you can also install it separately as follows:
-
- * Open the [official download page](https://visualstudio.microsoft.com/vs/older-downloads/).
- * Go to <b>Redistributables and Build Tools</b> section.
- * Find <b>Microsoft Build Tools 2015 Update 3</b> and click download.
- * Run the installer.
-
-It's possible to build TensorFlow with a newer version of Visual C++ build tools,
-but we only test against Visual Studio 2015 Update 3.
-
-<a name="InstallBazel"></a>
-
-### Install Bazel
-
-If bazel is not installed on your system, install it now by following
-[these instructions](https://docs.bazel.build/versions/master/install-windows.html).
-It is recommended to use a Bazel version >= `0.15.0`.
-
-Add the directory where you installed Bazel to your `%PATH%` environment variable.
-
-<a name="InstallPython"></a>
-
-### Install TensorFlow Python dependencies
-
-If you don't have Python 3.5 or Python 3.6 installed, install it now:
-
- * [Python 3.5.x 64-bit from python.org](https://www.python.org/downloads/release/python-352/)
- * [Python 3.6.x 64-bit from python.org](https://www.python.org/downloads/release/python-362/)
-
-To build and install TensorFlow, you must install the following python packages:
-
-* `six`, which provides simple utilities for wrapping over differences between
- Python 2 and Python 3.
-* `numpy`, which is a numerical processing package that TensorFlow requires.
-* `wheel`, which enables you to manage Python compressed packages in the wheel
- (.whl) format.
-* `keras_applications`, the applications module of the Keras deep learning library.
-* `keras_preprocessing`, the data preprocessing and data augmentation module
- of the Keras deep learning library.
-
-Assuming you already have `pip3` in your `%PATH%`, issue the following commands:
-
-<pre>
-C:\> <b>pip3 install six numpy wheel</b>
-C:\> <b>pip3 install keras_applications==1.0.4 --no-deps</b>
-C:\> <b>pip3 install keras_preprocessing==1.0.2 --no-deps</b>
-</pre>
-
-<a name="InstallCUDA"></a>
-
-### Optional: install TensorFlow for GPU prerequisites
-
-If you are building TensorFlow without GPU support, skip this section.
-
-The following NVIDIA® _hardware_ must be installed on your system:
-
-* GPU card with CUDA Compute Capability 3.5 or higher. See
- [NVIDIA documentation](https://developer.nvidia.com/cuda-gpus) for a list of
- supported GPU cards.
-
-The following NVIDIA® _software_ must be installed on your system:
-
-* [GPU drivers](http://nvidia.com/driver). CUDA 9.0 requires 384.x or higher.
-* [CUDA Toolkit](http://nvidia.com/cuda) (>= 8.0). We recommend version 9.0.
-* [cuDNN SDK](http://developer.nvidia.com/cudnn) (>= 6.0). We recommend
- version 7.1.x.
-* [CUPTI](http://docs.nvidia.com/cuda/cupti/) ships with the CUDA Toolkit, but
-  you also need to append its path to the `%PATH%` environment variable.
-
-Assuming you have the CUDA Toolkit installed at `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0`
-and cuDNN at `C:\tools\cuda`, issue the following commands:
-
-<pre>
-C:\> SET PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0\bin;%PATH%
-C:\> SET PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0\extras\CUPTI\libx64;%PATH%
-C:\> SET PATH=C:\tools\cuda\bin;%PATH%
-</pre>
-
-## Clone the TensorFlow repository
-
-Now clone **the latest** TensorFlow repository. Thanks to MSYS2, `git` is
-already available, so issue the following command:
-
-<pre>C:\> <b>git clone https://github.com/tensorflow/tensorflow.git</b> </pre>
-
-The preceding <code>git clone</code> command creates a subdirectory named
-`tensorflow`. After cloning, you may optionally build a **specific branch**
-(such as a release branch) by invoking the following commands:
-
-<pre>
-C:\> <b>cd tensorflow</b>
-C:\> <b>git checkout</b> <i>Branch</i> # where <i>Branch</i> is the desired branch
-</pre>
-
-For example, to work with the `r1.11` release instead of the master release,
-issue the following command:
-
-<pre>C:\> <b>git checkout r1.11</b></pre>
-
-Next, configure the installation.
-
-## Configure the installation
-
-The root of the source tree contains a python script named <code>configure.py</code>.
-This script asks you to identify the pathname of all relevant TensorFlow
-dependencies and specify other build configuration options such as compiler
-flags. You must run this script *prior* to creating the pip package and
-installing TensorFlow.
-
-If you wish to build TensorFlow with GPU, `configure.py` will ask you to specify
-the version numbers of CUDA and cuDNN. If several versions of CUDA or cuDNN are
-installed on your system, explicitly select the desired version instead of
-relying on the default.
-
-One of the questions that `configure.py` will ask is as follows:
-
-<pre>
-Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is /arch:AVX]:
-</pre>
-
-Here is an example execution of the `configure.py` script. Note that your own input
-will likely differ from our sample input:
-
-<pre>
-C:\> <b>cd tensorflow</b> # cd to the top-level directory created
-C:\tensorflow> <b>python ./configure.py</b>
-Starting local Bazel server and connecting to it...
-................
-You have bazel 0.15.0 installed.
-Please specify the location of python. [Default is C:\python36\python.exe]:
-
-Found possible Python library paths:
- C:\python36\lib\site-packages
-Please input the desired Python library path to use. Default is [C:\python36\lib\site-packages]
-
-Do you wish to build TensorFlow with CUDA support? [y/N]: <b>Y</b>
-CUDA support will be enabled for TensorFlow.
-
-Please specify the CUDA SDK version you want to use. [Leave empty to default to CUDA 9.0]:
-
-Please specify the location where CUDA 9.0 toolkit is installed. Refer to README.md for more details. [Default is C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0]:
-
-Please specify the cuDNN version you want to use. [Leave empty to default to cuDNN 7.0]: <b>7.0</b>
-
-Please specify the location where cuDNN 7 library is installed. Refer to README.md for more details. [Default is C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0]: <b>C:\tools\cuda</b>
-
-Please specify a list of comma-separated Cuda compute capabilities you want to build with.
-You can find the compute capability of your device at: https://developer.nvidia.com/cuda-gpus.
-Please note that each additional compute capability significantly increases your build time and binary size. [Default is: 3.5,7.0]: <b>3.7</b>
-
-Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is /arch:AVX]:
-
-Would you like to override eigen strong inline for some C++ compilation to reduce the compilation time? [Y/n]:
-Eigen strong inline overridden.
-
-Configuration finished
-</pre>
-
-## Build the pip package
-
-### CPU-only support
-
-To build a pip package for TensorFlow with CPU-only support:
-
-<pre>
-C:\tensorflow> <b>bazel build --config=opt //tensorflow/tools/pip_package:build_pip_package</b>
-</pre>
-
-### GPU support
-
-To build a pip package for TensorFlow with GPU support:
-
-<pre>
-C:\tensorflow> <b>bazel build --config=opt --config=cuda //tensorflow/tools/pip_package:build_pip_package</b>
-</pre>
-
-**NOTE:** When building with GPU support, you might want to add `--copt=-nvcc_options=disable-warnings`
-to suppress nvcc warning messages.
-
-The `bazel build` command builds a binary named `build_pip_package`
-(an executable that launches bash and runs a bash script to create the pip package).
-Running this binary as follows will build a `.whl` file within the `C:/tmp/tensorflow_pkg` directory:
-
-<pre>
-C:\tensorflow> <b>bazel-bin\tensorflow\tools\pip_package\build_pip_package C:/tmp/tensorflow_pkg</b>
-</pre>
-
-## Install the pip package
-
-Invoke `pip3 install` to install that pip package. The filename of the `.whl`
-file depends on the TensorFlow version and your platform. For example, the
-following command will install the pip package for TensorFlow 1.11.0rc0:
-
-<pre>
-C:\tensorflow> <b>pip3 install C:/tmp/tensorflow_pkg/tensorflow-1.11.0rc0-cp36-cp36m-win_amd64.whl</b>
-</pre>
-
-## Validate your installation
-
-Validate your TensorFlow installation by doing the following:
-
-Start a terminal.
-
-Change directory (`cd`) to any directory on your system other than the
-`tensorflow` subdirectory from which you invoked the `configure` command.
-
-Invoke python:
-
-<pre>$ <b>python</b></pre>
-
-Enter the following short program inside the python interactive shell:
-
-```python
-# Python
-import tensorflow as tf
-hello = tf.constant('Hello, TensorFlow!')
-sess = tf.Session()
-print(sess.run(hello))
-```
-
-If the system outputs the following, then you are ready to begin writing
-TensorFlow programs:
-
-<pre>Hello, TensorFlow!</pre>
-
-To learn more, see the [TensorFlow tutorials](../tutorials/).
-
-## Build under MSYS shell
-
-The above instructions assume you are building in the Windows native command line (`cmd.exe`), but you can also
-build TensorFlow from the MSYS shell. There are a few things to note:
-
-* Disable the path conversion heuristic in MSYS. MSYS automatically converts arguments that look
-  like Unix paths to Windows paths when running a program, which confuses Bazel.
-  (For example, the Bazel label `//foo/bar:bin` is treated as a Unix absolute path only because it starts with a slash.)
-
-  ```sh
-  $ export MSYS_NO_PATHCONV=1
-  $ export MSYS2_ARG_CONV_EXCL="*"
-  ```
-
-* Add the directory where you installed Bazel to `$PATH`. Assuming Bazel is
-  installed at `C:\tools\bazel.exe`, issue the following command:
-
-  ```sh
-  # `:` is used as the path separator, so we have to convert the path to Unix style.
-  $ export PATH="/c/tools:$PATH"
-  ```
-
-* Add the directory where you installed Python to `$PATH`. Assuming Python is
-  installed at `C:\Python36\python.exe`, issue the following command:
-
-  ```sh
-  $ export PATH="/c/Python36:$PATH"
-  ```
-
-* If Python is in `$PATH`, you can run the configure script simply with
-  `./configure`; a shell script will invoke python for you.
-
-* (For GPU builds only) Add the CUDA and cuDNN bin directories to `$PATH` as follows:
-
-  ```sh
-  $ export PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0/bin:$PATH"
-  $ export PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0/extras/CUPTI/libx64:$PATH"
-  $ export PATH="/c/tools/cuda/bin:$PATH"
-  ```
-
-The remaining steps are the same as when building under `cmd.exe`.
diff --git a/tensorflow/docs_src/install/install_windows.md b/tensorflow/docs_src/install/install_windows.md
deleted file mode 100644
index 0bb0e5aeb9..0000000000
--- a/tensorflow/docs_src/install/install_windows.md
+++ /dev/null
@@ -1,227 +0,0 @@
-# Install TensorFlow on Windows
-
-This guide explains how to install TensorFlow on Windows. Although these
-instructions might also work on other Windows variants, we have only
-tested (and we only support) these instructions on machines meeting the
-following requirements:
-
- * 64-bit, x86 desktops or laptops
- * Windows 7 or later
-
-
-## Determine which TensorFlow to install
-
-You must choose one of the following types of TensorFlow to install:
-
- * **TensorFlow with CPU support only**. If your system does not have an
-   NVIDIA® GPU, you must install this version. Note that this version of
-   TensorFlow is typically much easier to install (often
-   in 5 or 10 minutes), so even if you have an NVIDIA GPU, we recommend
-   installing this version first. Prebuilt binaries will use AVX instructions.
- * **TensorFlow with GPU support**. TensorFlow programs typically run
-   significantly faster on a GPU than on a CPU. Therefore, if your
-   system has an NVIDIA® GPU meeting the prerequisites shown below
-   and you need to run performance-critical applications, you should
-   ultimately install this version.
-
-<a name="NVIDIARequirements"></a>
-
-### Requirements to run TensorFlow with GPU support
-
-If you are installing TensorFlow with GPU support using one of the mechanisms
-described in this guide, then the following NVIDIA software must be
-installed on your system:
-
- * CUDA® Toolkit 9.0. For details, see
-   [NVIDIA's documentation](http://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/).
-   Ensure that you append the relevant CUDA pathnames to the `%PATH%`
-   environment variable as described in the NVIDIA documentation.
- * The NVIDIA drivers associated with CUDA Toolkit 9.0.
- * cuDNN v7.0. For details, see
- [NVIDIA's documentation](https://developer.nvidia.com/cudnn).
- Note that cuDNN is typically installed in a different location from the
- other CUDA DLLs. Ensure that you add the directory where you installed
- the cuDNN DLL to your `%PATH%` environment variable.
- * GPU card with CUDA Compute Capability 3.0 or higher for building
- from source and 3.5 or higher for our binaries. See
- [NVIDIA documentation](https://developer.nvidia.com/cuda-gpus) for a
- list of supported GPU cards.
-
-If you have a different version of one of the preceding packages, please
-install the specified versions. In particular, the cuDNN version
-must match exactly: TensorFlow will not load if it cannot find `cuDNN64_7.dll`.
-To use a different version of cuDNN, you must build from source.
-
-## Determine how to install TensorFlow
-
-You must pick the mechanism by which you install TensorFlow. The
-supported choices are as follows:
-
- * "native" pip
- * Anaconda
-
-Native pip installs TensorFlow directly on your system without going
-through a virtual environment. Since a native pip installation is not
-walled-off in a separate container, the pip installation might interfere
-with other Python-based installations on your system. However, if you
-understand pip and your Python environment, a native pip installation
-often entails only a single command! Furthermore, if you install with
-native pip, users can run TensorFlow programs from any directory on
-the system.
-
-In Anaconda, you may use conda to create a virtual environment.
-However, within Anaconda, we recommend installing TensorFlow with the
-`pip install` command, not with the `conda install` command.
-
-**NOTE:** The conda package is community supported, not officially supported.
-That is, the TensorFlow team neither tests nor maintains this conda package.
-Use that package at your own risk.
-
-
-## Installing with native pip
-
-If one of the following versions of Python is not installed on your machine,
-install it now:
-
- * [Python 3.5.x 64-bit from python.org](https://www.python.org/downloads/release/python-352/)
- * [Python 3.6.x 64-bit from python.org](https://www.python.org/downloads/release/python-362/)
-
-TensorFlow supports Python 3.5.x and 3.6.x on Windows.
-Note that Python 3 comes with the pip3 package manager, which is the
-program you'll use to install TensorFlow.
-
-To install TensorFlow, start a terminal. Then issue the appropriate
-<tt>pip3 install</tt> command in that terminal. To install the CPU-only
-version of TensorFlow, enter the following command:
-
-<pre>C:\> <b>pip3 install --upgrade tensorflow</b></pre>
-
-To install the GPU version of TensorFlow, enter the following command:
-
-<pre>C:\> <b>pip3 install --upgrade tensorflow-gpu</b></pre>
-
-## Installing with Anaconda
-
-**The Anaconda installation is community supported, not officially supported.**
-
-Take the following steps to install TensorFlow in an Anaconda environment:
-
- 1. Follow the instructions on the
- [Anaconda download site](https://www.continuum.io/downloads)
- to download and install Anaconda.
-
- 2. Create a conda environment named <tt>tensorflow</tt>
- by invoking the following command:
-
- <pre>C:\> <b>conda create -n tensorflow pip python=3.5</b> </pre>
-
- 3. Activate the conda environment by issuing the following command:
-
- <pre>C:\> <b>activate tensorflow</b>
- (tensorflow)C:\> # Your prompt should change </pre>
-
- 4. Issue the appropriate command to install TensorFlow inside your conda
- environment. To install the CPU-only version of TensorFlow, enter the
- following command:
-
- <pre>(tensorflow)C:\> <b>pip install --ignore-installed --upgrade tensorflow</b> </pre>
-
- To install the GPU version of TensorFlow, enter the following command
- (on a single line):
-
- <pre>(tensorflow)C:\> <b>pip install --ignore-installed --upgrade tensorflow-gpu</b> </pre>
-
-## Validate your installation
-
-Start a terminal.
-
-If you installed through Anaconda, activate your Anaconda environment.
-
-Invoke python from your shell as follows:
-
-<pre>$ <b>python</b></pre>
-
-Enter the following short program inside the python interactive shell:
-
-```python
->>> import tensorflow as tf
->>> hello = tf.constant('Hello, TensorFlow!')
->>> sess = tf.Session()
->>> print(sess.run(hello))
-```
-
-If the system outputs the following, then you are ready to begin writing
-TensorFlow programs:
-
-<pre>Hello, TensorFlow!</pre>
-
-If the system outputs an error message instead of a greeting, see [Common
-installation problems](#common_installation_problems).
-
-To learn more, see the [TensorFlow tutorials](../tutorials/).
-
-## Common installation problems
-
-We are relying on Stack Overflow to document TensorFlow installation problems
-and their remedies. The following table contains links to Stack Overflow
-answers for some common installation problems.
-If you encounter an error message or other
-installation problem not listed in the following table, search for it
-on Stack Overflow. If Stack Overflow doesn't show the error message,
-ask a new question about it on Stack Overflow and specify
-the `tensorflow` tag.
-
-<table>
-<tr> <th>Stack Overflow Link</th> <th>Error Message</th> </tr>
-
-<tr>
- <td><a href="https://stackoverflow.com/q/41007279">41007279</a></td>
- <td>
- <pre>[...\stream_executor\dso_loader.cc] Couldn't open CUDA library nvcuda.dll</pre>
- </td>
-</tr>
-
-<tr>
- <td><a href="https://stackoverflow.com/q/41007279">41007279</a></td>
- <td>
- <pre>[...\stream_executor\cuda\cuda_dnn.cc] Unable to load cuDNN DSO</pre>
- </td>
-</tr>
-
-<tr>
- <td><a href="http://stackoverflow.com/q/42006320">42006320</a></td>
- <td><pre>ImportError: Traceback (most recent call last):
-File "...\tensorflow\core\framework\graph_pb2.py", line 6, in <module>
-from google.protobuf import descriptor as _descriptor
-ImportError: cannot import name 'descriptor'</pre>
- </td>
-</tr>
-
-<tr>
- <td><a href="https://stackoverflow.com/q/42011070">42011070</a></td>
- <td><pre>No module named "pywrap_tensorflow"</pre></td>
-</tr>
-
-<tr>
- <td><a href="https://stackoverflow.com/q/42217532">42217532</a></td>
- <td>
- <pre>OpKernel ('op: "BestSplits" device_type: "CPU"') for unknown op: BestSplits</pre>
- </td>
-</tr>
-
-<tr>
- <td><a href="https://stackoverflow.com/q/43134753">43134753</a></td>
- <td>
- <pre>The TensorFlow library wasn't compiled to use SSE instructions</pre>
- </td>
-</tr>
-
-<tr>
- <td><a href="https://stackoverflow.com/q/38896424">38896424</a></td>
- <td>
- <pre>Could not find a version that satisfies the requirement tensorflow</pre>
- </td>
-</tr>
-
-</table>
diff --git a/tensorflow/docs_src/install/leftnav_files b/tensorflow/docs_src/install/leftnav_files
deleted file mode 100644
index 59292f7121..0000000000
--- a/tensorflow/docs_src/install/leftnav_files
+++ /dev/null
@@ -1,18 +0,0 @@
-index.md
-
-### Python
-install_linux.md: Ubuntu
-install_mac.md: MacOS
-install_windows.md: Windows
-install_raspbian.md: Raspbian
-install_sources.md: From source
-install_sources_windows.md: From source on Windows
-migration.md
-
-### Other Languages
-install_java.md: Java
-install_go.md: Go
-install_c.md: C
-
-
diff --git a/tensorflow/docs_src/install/migration.md b/tensorflow/docs_src/install/migration.md
deleted file mode 100644
index 19315ace2d..0000000000
--- a/tensorflow/docs_src/install/migration.md
+++ /dev/null
@@ -1,336 +0,0 @@
-# Transition to TensorFlow 1.0
-
-
-The APIs in TensorFlow 1.0 have changed in ways that are not all backwards
-compatible. That is, TensorFlow programs that worked on TensorFlow 0.n won't
-necessarily work on TensorFlow 1.0. We have made these API changes to ensure an
-internally-consistent API, and do not plan to make backwards-breaking changes
-throughout the 1.N lifecycle.
-
-This guide walks you through the major changes in the API and how to
-automatically upgrade your programs for TensorFlow 1.0. This guide not
-only steps you through the changes but also explains why we've made them.
-
-## How to upgrade
-
-If you would like to automatically port your code to 1.0, you can try our
-`tf_upgrade.py` script. While this script handles many cases, manual changes
-are sometimes necessary. Get this script from our
-[GitHub tree](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/compatibility).
-
-To convert a single 0.n TensorFlow source file to 1.0, enter a
-command of the following format:
-
-<pre>
-$ <b>python tf_upgrade.py --infile</b> <i>InputFile</i> <b>--outfile</b> <i>OutputFile</i>
-</pre>
-
-For example, the following command converts a 0.n TensorFlow
-program named `test.py` to a 1.0 TensorFlow program named `test_1.0.py`:
-
-<pre>
-$ <b>python tf_upgrade.py --infile test.py --outfile test_1.0.py</b>
-</pre>
-
-The `tf_upgrade.py` script also generates a file named `report.txt`, which
-details all the changes it performed and makes additional suggestions about
-changes you might need to make manually.
-
-To upgrade a whole directory of 0.n TensorFlow programs to 1.0,
-enter a command having the following format:
-
-<pre>
-$ <b>python tf_upgrade.py --intree</b> <i>InputDir</i> <b>--outtree</b> <i>OutputDir</i>
-</pre>
-
-For example, the following command converts all the 0.n TensorFlow programs
-in the `/home/user/cool` directory, creating their 1.0 equivalents in
-the `/home/user/cool_1.0` directory:
-
-<pre>
-$ <b>python tf_upgrade.py --intree /home/user/cool --outtree /home/user/cool_1.0</b>
-</pre>
-
-### Limitations
-
-There are a few things to watch out for. Specifically:
-
- * You must manually fix any instances of `tf.reverse()`.
- The `tf_upgrade.py` script will warn you about `tf.reverse()` in
- stdout and in the `report.txt` file.
- * On reordered arguments, `tf_upgrade.py` tries to minimally reformat
-   your code, so it cannot automatically change the actual argument order.
-   Instead, `tf_upgrade.py` makes your function invocations order-independent
-   by introducing keyword arguments, as sketched after this list.
- * Constructions like `tf.get_variable_scope().reuse_variables()`
- will likely not work. We recommend deleting those lines and replacing
- them with lines such as the following:
-
- <pre class="prettyprint">
- with tf.variable_scope(tf.get_variable_scope(), reuse=True):
- ...
- </pre>
-
- * Analogously to `tf.pack` and `tf.unpack`, we've renamed
-   `TensorArray.pack` and `TensorArray.unpack` to
-   `TensorArray.stack` and `TensorArray.unstack`. However, `TensorArray.pack`
-   and `TensorArray.unpack` cannot be detected lexically since they are
-   indirectly related to the `tf` namespace, e.g.
-   `foo = tf.TensorArray(); foo.unpack()`.
-
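-For illustration, here is a hypothetical before/after of the keyword-argument
-rewriting described above; the tensors are invented and the exact output of
-`tf_upgrade.py` may differ:
-
-```python
-import tensorflow as tf
-
-t1 = tf.zeros([2, 3])
-t2 = tf.ones([2, 3])
-# Before (TF 0.n): the axis came first in tf.concat.
-#   joined = tf.concat(1, [t1, t2])
-# After tf_upgrade.py: keyword arguments make the call order-independent.
-joined = tf.concat(axis=1, values=[t1, t2])
-```
-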
-## Upgrading your code manually
-
-Instead of running `tf_upgrade.py`, you may manually upgrade your code.
-The remainder of this document provides a comprehensive list of
-all backward incompatible changes made in TensorFlow 1.0.
-
-
-### Variables
-
-Variable functions have been made more consistent and less confusing.
-
-* `tf.VARIABLES`
- * should be renamed to `tf.GLOBAL_VARIABLES`
-* `tf.all_variables`
- * should be renamed to `tf.global_variables`
-* `tf.initialize_all_variables`
- * should be renamed to `tf.global_variables_initializer`
-* `tf.initialize_local_variables`
- * should be renamed to `tf.local_variables_initializer`
-* `tf.initialize_variables`
- * should be renamed to `tf.variables_initializer`
-
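-For example, a minimal TF 1.0 snippet using the renamed initializer (the
-variable here is illustrative):
-
-```python
-import tensorflow as tf
-
-v = tf.Variable(tf.zeros([10]), name='v')
-# TF 0.n: init = tf.initialize_all_variables()
-init = tf.global_variables_initializer()
-with tf.Session() as sess:
-    sess.run(init)
-```
-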
-### Summary functions
-
-Summary functions have been consolidated under the `tf.summary` namespace.
-
-* `tf.audio_summary`
- * should be renamed to `tf.summary.audio`
-* `tf.contrib.deprecated.histogram_summary`
- * should be renamed to `tf.summary.histogram`
-* `tf.contrib.deprecated.scalar_summary`
- * should be renamed to `tf.summary.scalar`
-* `tf.histogram_summary`
- * should be renamed to `tf.summary.histogram`
-* `tf.image_summary`
- * should be renamed to `tf.summary.image`
-* `tf.merge_all_summaries`
- * should be renamed to `tf.summary.merge_all`
-* `tf.merge_summary`
- * should be renamed to `tf.summary.merge`
-* `tf.scalar_summary`
- * should be renamed to `tf.summary.scalar`
-* `tf.train.SummaryWriter`
- * should be renamed to `tf.summary.FileWriter`
-
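-For example, a minimal sketch of the consolidated API (the scalar and log
-directory are illustrative):
-
-```python
-import tensorflow as tf
-
-loss = tf.constant(0.25)
-# TF 0.n: tf.scalar_summary('loss', loss) and tf.train.SummaryWriter(...)
-tf.summary.scalar('loss', loss)
-merged = tf.summary.merge_all()
-writer = tf.summary.FileWriter('/tmp/tf_logs')
-with tf.Session() as sess:
-    writer.add_summary(sess.run(merged))
-```
-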
-### Numeric differences
-
-
-Integer division and `tf.floordiv` now use flooring semantics. This is to
-make the results of `np.divide` and `np.mod` consistent with `tf.divide` and
-`tf.mod`, respectively. In addition, we have changed the rounding algorithm
-used by `tf.round` to match NumPy.
-
-
-* `tf.div`
-
- * The semantics of `tf.divide` have been changed to match Python semantics
-completely. That is, `/` in Python 3 (and in future division mode in Python 2)
-will always produce floating point numbers, and `//` will produce floored
-division. However, even `tf.div` will produce floored integer division.
-To force C-style truncation semantics, you must use `tf.truncatediv`.
-
- * Consider changing your code to use `tf.divide`, which follows Python semantics for promotion.
-
-* `tf.mod`
-
- * The semantics of `tf.mod` have been changed to match Python semantics. In
-particular, flooring semantics are used for integers. If you wish to have
-C-style truncation mod (remainders), you can use `tf.truncatemod`.
-
-
-The old and new behavior of division can be summarized with this table:
-
-| Expr | TF 0.11 (py2) | TF 0.11 (py3) | TF 1.0 (py2) | TF 1.0 (py3) |
-|---------------------|---------------|---------------|--------------|--------------|
-| tf.div(3,4) | 0 | 0 | 0 | 0 |
-| tf.div(-3,4) | 0 | 0 | -1 | -1 |
-| tf.mod(-3,4) | -3 | -3 | 1 | 1 |
-| -3/4 | 0 | -0.75 | -1 | -0.75 |
-| tf.divide(-3,4)     | N/A           | N/A           | -0.75        | -0.75        |
-
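-The following snippet demonstrates the TF 1.0 semantics described above; the
-expected values match the table:
-
-```python
-import tensorflow as tf
-
-with tf.Session() as sess:
-    print(sess.run(tf.div(-3, 4)))          # -1: floored integer division
-    print(sess.run(tf.truncatediv(-3, 4)))  # 0: C-style truncation
-    print(sess.run(tf.mod(-3, 4)))          # 1: flooring semantics
-    print(sess.run(tf.truncatemod(-3, 4)))  # -3: C-style remainder
-    print(sess.run(tf.divide(-3, 4)))       # -0.75: Python-style true division
-```
-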
-The old and new behavior of rounding can be summarized with this table:
-
-| Input | Python | NumPy | C++ round() | TensorFlow 0.11 (floor(x+.5)) | TensorFlow 1.0 |
-|-------|--------|-------|-------------|------------------------------|----------------|
-| -3.5 | -4 | -4 | -4 | -3 | -4 |
-| -2.5 | -2 | -2 | -3 | -2 | -2 |
-| -1.5 | -2 | -2 | -2 | -1 | -2 |
-| -0.5 | 0 | 0 | -1 | 0 | 0 |
-| 0.5 | 0 | 0 | 1 | 1 | 0 |
-| 1.5 | 2 | 2 | 2 | 2 | 2 |
-| 2.5 | 2 | 2 | 3 | 3 | 2 |
-| 3.5 | 4 | 4 | 4 | 4 | 4 |
-
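-A quick check of the new rounding behavior (ties round to the nearest even
-integer):
-
-```python
-import tensorflow as tf
-
-with tf.Session() as sess:
-    print(sess.run(tf.round(tf.constant([0.5, 1.5, 2.5, 3.5]))))
-    # => [0. 2. 2. 4.]
-```
-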
-
-
-### NumPy matching names
-
-
-Many functions have been renamed to match NumPy. This was done to make the
-transition between NumPy and TensorFlow as easy as possible. There are still
-numerous cases where functions do not match, so this is far from a hard and
-fast rule, but we have removed several commonly noticed inconsistencies.
-
-* `tf.inv`
- * should be renamed to `tf.reciprocal`
- * This was done to avoid confusion with NumPy's matrix inverse `np.linalg.inv`
-* `tf.list_diff`
- * should be renamed to `tf.setdiff1d`
-* `tf.listdiff`
- * should be renamed to `tf.setdiff1d`
-* `tf.mul`
- * should be renamed to `tf.multiply`
-* `tf.neg`
- * should be renamed to `tf.negative`
-* `tf.select`
- * should be renamed to `tf.where`
- * `tf.where` now takes 3 arguments or 1 argument, just like `np.where`
-* `tf.sub`
- * should be renamed to `tf.subtract`
-
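-For example, under the new names (tensor values are illustrative):
-
-```python
-import tensorflow as tf
-
-a = tf.constant([1.0, -2.0])
-b = tf.constant([3.0, 4.0])
-prod = tf.multiply(a, b)        # was tf.mul
-diff = tf.subtract(a, b)        # was tf.sub
-neg = tf.negative(a)            # was tf.neg
-recip = tf.reciprocal(a)        # was tf.inv
-picked = tf.where(a > 0, a, b)  # was tf.select; 3-argument form, like np.where
-```
-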
-### NumPy matching arguments
-
-Arguments for certain TensorFlow 1.0 methods now match arguments in certain
-NumPy methods. To achieve this, TensorFlow 1.0 has changed keyword arguments
-and reordered some arguments. Notably, TensorFlow 1.0 now uses `axis` rather
-than `dimension`. TensorFlow 1.0 aims to keep the tensor argument first on
-operations that modify Tensors. (see the `tf.concat` change).
-
-
-* `tf.argmax`
- * keyword argument `dimension` should be renamed to `axis`
-* `tf.argmin`
- * keyword argument `dimension` should be renamed to `axis`
-* `tf.concat`
- * keyword argument `concat_dim` should be renamed to `axis`
- * arguments have been reordered to `tf.concat(values, axis, name='concat')`.
-* `tf.count_nonzero`
- * keyword argument `reduction_indices` should be renamed to `axis`
-* `tf.expand_dims`
- * keyword argument `dim` should be renamed to `axis`
-* `tf.reduce_all`
- * keyword argument `reduction_indices` should be renamed to `axis`
-* `tf.reduce_any`
- * keyword argument `reduction_indices` should be renamed to `axis`
-* `tf.reduce_join`
- * keyword argument `reduction_indices` should be renamed to `axis`
-* `tf.reduce_logsumexp`
- * keyword argument `reduction_indices` should be renamed to `axis`
-* `tf.reduce_max`
- * keyword argument `reduction_indices` should be renamed to `axis`
-* `tf.reduce_mean`
- * keyword argument `reduction_indices` should be renamed to `axis`
-* `tf.reduce_min`
- * keyword argument `reduction_indices` should be renamed to `axis`
-* `tf.reduce_prod`
- * keyword argument `reduction_indices` should be renamed to `axis`
-* `tf.reduce_sum`
- * keyword argument `reduction_indices` should be renamed to `axis`
-* `tf.reverse`
- * `tf.reverse` used to take a 1D `bool` tensor to control which dimensions were reversed. Now we use a Tensor of axis indices.
- * For example `tf.reverse(a, [True, False, True])` now must be `tf.reverse(a, [0, 2])`
-* `tf.reverse_sequence`
- * keyword argument `batch_dim` should be renamed to `batch_axis`
- * keyword argument `seq_dim` should be renamed to `seq_axis`
-* `tf.sparse_concat`
- * keyword argument `concat_dim` should be renamed to `axis`
-* `tf.sparse_reduce_sum`
- * keyword argument `reduction_axes` should be renamed to `axis`
-* `tf.sparse_reduce_sum_sparse`
- * keyword argument `reduction_axes` should be renamed to `axis`
-* `tf.sparse_split`
- * keyword argument `split_dim` should be renamed to `axis`
- * arguments have been reordered to `tf.sparse_split(keyword_required=KeywordRequired(), sp_input=None, num_split=None, axis=None, name=None, split_dim=None)`.
-* `tf.split`
- * keyword argument `split_dim` should be renamed to `axis`
- * keyword argument `num_split` should be renamed to `num_or_size_splits`
- * arguments have been reordered to `tf.split(value, num_or_size_splits, axis=0, num=None, name='split')`.
-* `tf.squeeze`
- * keyword argument `squeeze_dims` should be renamed to `axis`
-* `tf.svd`
- * arguments have been reordered to `tf.svd(tensor, full_matrices=False, compute_uv=True, name=None)`.
-
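-A short sketch showing the new argument names (shapes are illustrative):
-
-```python
-import tensorflow as tf
-
-x = tf.ones([2, 6])
-halves = tf.split(x, num_or_size_splits=2, axis=1)  # was split_dim/num_split
-total = tf.reduce_sum(x, axis=0)                    # was reduction_indices
-expanded = tf.expand_dims(total, axis=0)            # was dim
-rev = tf.reverse(x, [1])                            # axis indices, not a bool mask
-```
-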
-### Simplified math variants
-
-Batched versions of math operations have been removed. Now the functionality is
-contained in the non-batched versions. Similarly, `tf.complex_abs` has had its
-functionality moved to `tf.abs`.
-
-* `tf.batch_band_part`
- * should be renamed to `tf.band_part`
-* `tf.batch_cholesky`
- * should be renamed to `tf.cholesky`
-* `tf.batch_cholesky_solve`
- * should be renamed to `tf.cholesky_solve`
-* `tf.batch_fft`
- * should be renamed to `tf.fft`
-* `tf.batch_fft3d`
- * should be renamed to `tf.fft3d`
-* `tf.batch_ifft`
- * should be renamed to `tf.ifft`
-* `tf.batch_ifft2d`
- * should be renamed to `tf.ifft2d`
-* `tf.batch_ifft3d`
- * should be renamed to `tf.ifft3d`
-* `tf.batch_matmul`
- * should be renamed to `tf.matmul`
-* `tf.batch_matrix_determinant`
- * should be renamed to `tf.matrix_determinant`
-* `tf.batch_matrix_diag`
- * should be renamed to `tf.matrix_diag`
-* `tf.batch_matrix_inverse`
- * should be renamed to `tf.matrix_inverse`
-* `tf.batch_matrix_solve`
- * should be renamed to `tf.matrix_solve`
-* `tf.batch_matrix_solve_ls`
- * should be renamed to `tf.matrix_solve_ls`
-* `tf.batch_matrix_transpose`
- * should be renamed to `tf.matrix_transpose`
-* `tf.batch_matrix_triangular_solve`
- * should be renamed to `tf.matrix_triangular_solve`
-* `tf.batch_self_adjoint_eig`
- * should be renamed to `tf.self_adjoint_eig`
-* `tf.batch_self_adjoint_eigvals`
- * should be renamed to `tf.self_adjoint_eigvals`
-* `tf.batch_set_diag`
- * should be renamed to `tf.set_diag`
-* `tf.batch_svd`
- * should be renamed to `tf.svd`
-* `tf.complex_abs`
- * should be renamed to `tf.abs`
-
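-For example, the non-batched names now accept batched inputs directly (shapes
-are illustrative):
-
-```python
-import tensorflow as tf
-
-x = tf.random_normal([8, 2, 3])
-y = tf.random_normal([8, 3, 4])
-z = tf.matmul(x, y)                 # was tf.batch_matmul; result shape [8, 2, 4]
-mag = tf.abs(tf.complex(3.0, 4.0))  # was tf.complex_abs; evaluates to 5.0
-```
-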
-### Misc Changes
-
-Several other changes have been made, including the following:
-
-* `tf.image.per_image_whitening`
- * should be renamed to `tf.image.per_image_standardization`
-* `tf.nn.sigmoid_cross_entropy_with_logits`
- * arguments have been reordered to `tf.nn.sigmoid_cross_entropy_with_logits(_sentinel=None, labels=None, logits=None, name=None)`.
-* `tf.nn.softmax_cross_entropy_with_logits`
- * arguments have been reordered to `tf.nn.softmax_cross_entropy_with_logits(_sentinel=None, labels=None, logits=None, dim=-1, name=None)`.
-* `tf.nn.sparse_softmax_cross_entropy_with_logits`
- * arguments have been reordered to `tf.nn.sparse_softmax_cross_entropy_with_logits(_sentinel=None, labels=None, logits=None, name=None)`.
-* `tf.ones_initializer`
- * should be changed to a function call i.e. `tf.ones_initializer()`
-* `tf.pack`
- * should be renamed to `tf.stack`
-* `tf.round`
- * The semantics of `tf.round` now match Banker's rounding.
-* `tf.unpack`
- * should be renamed to `tf.unstack`
-* `tf.zeros_initializer`
- * should be changed to a function call i.e. `tf.zeros_initializer()`
-
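-A small sketch covering several of these changes (names and values are
-illustrative):
-
-```python
-import tensorflow as tf
-
-stacked = tf.stack([tf.constant(1), tf.constant(2)])  # was tf.pack
-first, second = tf.unstack(stacked)                   # was tf.unpack
-w = tf.get_variable('w', shape=[3],
-                    initializer=tf.ones_initializer())  # note the call
-loss = tf.nn.sigmoid_cross_entropy_with_logits(
-    labels=tf.constant([1.0]), logits=tf.constant([0.5]))
-```
-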
diff --git a/tensorflow/docs_src/mobile/README.md b/tensorflow/docs_src/mobile/README.md
deleted file mode 100644
index ecf4267265..0000000000
--- a/tensorflow/docs_src/mobile/README.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# TF Lite subsite
-
-This subsite directory lives in [tensorflow/contrib/lite/g3doc](../../contrib/lite/g3doc/).
diff --git a/tensorflow/docs_src/performance/benchmarks.md b/tensorflow/docs_src/performance/benchmarks.md
deleted file mode 100644
index a5fa551dd4..0000000000
--- a/tensorflow/docs_src/performance/benchmarks.md
+++ /dev/null
@@ -1,412 +0,0 @@
-# Benchmarks
-
-## Overview
-
-A selection of image classification models was tested across multiple platforms
-to create a point of reference for the TensorFlow community. The
-[Methodology](#methodology) section details how the tests were executed and has
-links to the scripts used.
-
-## Results for image classification models
-
-InceptionV3 ([arXiv:1512.00567](https://arxiv.org/abs/1512.00567)), ResNet-50
-([arXiv:1512.03385](https://arxiv.org/abs/1512.03385)), ResNet-152
-([arXiv:1512.03385](https://arxiv.org/abs/1512.03385)), VGG16
-([arXiv:1409.1556](https://arxiv.org/abs/1409.1556)), and
-[AlexNet](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf)
-were tested using the [ImageNet](http://www.image-net.org/) data set. Tests were
-run on Google Compute Engine, Amazon Elastic Compute Cloud (Amazon EC2), and an
-NVIDIA® DGX-1™. Most of the tests were run with both synthetic and real data.
-Testing with synthetic data was done by using a `tf.Variable` set to the same
-shape as the data expected by each model for ImageNet. We believe it is
-important to include real data measurements when benchmarking a platform. This
-load tests both the underlying hardware and the framework at preparing data for
-actual training. We start with synthetic data to remove disk I/O as a variable
-and to set a baseline. Real data is then used to verify that the TensorFlow
-input pipeline and the underlying disk I/O are saturating the compute units.
-
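-As a minimal sketch of the synthetic-data approach (the batch size and image
-shape here are illustrative, not the benchmark's exact configuration):
-
-```python
-import tensorflow as tf
-
-batch_size = 64
-# A non-trainable variable shaped like a batch of ImageNet inputs stands in
-# for real data, removing disk I/O from the measurement.
-synthetic_images = tf.Variable(
-    tf.random_normal([batch_size, 224, 224, 3]),
-    trainable=False, name='synthetic_images')
-```
-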
-### Training with NVIDIA® DGX-1™ (NVIDIA® Tesla® P100)
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:80%" src="../images/perf_summary_p100_single_server.png">
-</div>
-
-Details and additional results are in the [Details for NVIDIA® DGX-1™ (NVIDIA®
-Tesla® P100)](#details_for_nvidia_dgx-1tm_nvidia_tesla_p100) section.
-
-### Training with NVIDIA® Tesla® K80
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:80%" src="../images/perf_summary_k80_single_server.png">
-</div>
-
-Details and additional results are in the [Details for Google Compute Engine
-(NVIDIA® Tesla® K80)](#details_for_google_compute_engine_nvidia_tesla_k80) and
-[Details for Amazon EC2 (NVIDIA® Tesla®
-K80)](#details_for_amazon_ec2_nvidia_tesla_k80) sections.
-
-### Distributed training with NVIDIA® Tesla® K80
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:80%" src="../images/perf_summary_k80_aws_distributed.png">
-</div>
-
-Details and additional results are in the [Details for Amazon EC2 Distributed
-(NVIDIA® Tesla® K80)](#details_for_amazon_ec2_distributed_nvidia_tesla_k80)
-section.
-
-### Compare synthetic with real data training
-
-**NVIDIA® Tesla® P100**
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:35%" src="../images/perf_summary_p100_data_compare_inceptionv3.png">
- <img style="width:35%" src="../images/perf_summary_p100_data_compare_resnet50.png">
-</div>
-
-**NVIDIA® Tesla® K80**
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:35%" src="../images/perf_summary_k80_data_compare_inceptionv3.png">
- <img style="width:35%" src="../images/perf_summary_k80_data_compare_resnet50.png">
-</div>
-
-## Details for NVIDIA® DGX-1™ (NVIDIA® Tesla® P100)
-
-### Environment
-
-* **Instance type**: NVIDIA® DGX-1™
-* **GPU:** 8x NVIDIA® Tesla® P100
-* **OS:** Ubuntu 16.04 LTS with tests run via Docker
-* **CUDA / cuDNN:** 8.0 / 5.1
-* **TensorFlow GitHub hash:** b1e174e
-* **Benchmark GitHub hash:** 9165a70
-* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
- //tensorflow/tools/pip_package:build_pip_package`
-* **Disk:** Local SSD
-* **DataSet:** ImageNet
-* **Test Date:** May 2017
-
-Batch size and optimizer used for each model are listed in the table below. In
-addition to the batch sizes listed in the table, InceptionV3, ResNet-50,
-ResNet-152, and VGG16 were tested with a batch size of 32. Those results are in
-the *other results* section.
-
-Options | InceptionV3 | ResNet-50 | ResNet-152 | AlexNet | VGG16
------------------- | ----------- | --------- | ---------- | ------- | -----
-Batch size per GPU | 64 | 64 | 64 | 512 | 64
-Optimizer | sgd | sgd | sgd | sgd | sgd
-
-Configuration used for each model.
-
-Model | variable_update | local_parameter_device
------------ | ---------------------- | ----------------------
-InceptionV3 | parameter_server | cpu
-ResNet-50 | parameter_server | cpu
-ResNet-152 | parameter_server | cpu
-AlexNet | replicated (with NCCL) | n/a
-VGG16 | replicated (with NCCL) | n/a
-
-### Results
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:80%" src="../images/perf_summary_p100_single_server.png">
-</div>
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:35%" src="../images/perf_dgx1_synth_p100_single_server_scaling.png">
- <img style="width:35%" src="../images/perf_dgx1_real_p100_single_server_scaling.png">
-</div>
-
-**Training synthetic data**
-
-GPUs | InceptionV3 | ResNet-50 | ResNet-152 | AlexNet | VGG16
----- | ----------- | --------- | ---------- | ------- | -----
-1 | 142 | 219 | 91.8 | 2987 | 154
-2 | 284 | 422 | 181 | 5658 | 295
-4 | 569 | 852 | 356 | 10509 | 584
-8 | 1131 | 1734 | 716 | 17822 | 1081
-
-**Training real data**
-
-GPUs | InceptionV3 | ResNet-50 | ResNet-152 | AlexNet | VGG16
----- | ----------- | --------- | ---------- | ------- | -----
-1 | 142 | 218 | 91.4 | 2890 | 154
-2 | 278 | 425 | 179 | 4448 | 284
-4 | 551 | 853 | 359 | 7105 | 534
-8 | 1079 | 1630 | 708 | N/A | 898
-
-Training AlexNet with real data on 8 GPUs was excluded from the graph and table
-above because it saturated the input pipeline.
-
-### Other Results
-
-The results below are all with a batch size of 32.
-
-**Training synthetic data**
-
-GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16
----- | ----------- | --------- | ---------- | -----
-1 | 128 | 195 | 82.7 | 144
-2 | 259 | 368 | 160 | 281
-4 | 520 | 768 | 317 | 549
-8 | 995 | 1485 | 632 | 820
-
-**Training real data**
-
-GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16
----- | ----------- | --------- | ---------- | -----
-1 | 130 | 193 | 82.4 | 144
-2 | 257 | 369 | 159 | 253
-4 | 507 | 760 | 317 | 457
-8 | 966 | 1410 | 609 | 690
-
-## Details for Google Compute Engine (NVIDIA® Tesla® K80)
-
-### Environment
-
-* **Instance type**: n1-standard-32-k80x8
-* **GPU:** 8x NVIDIA® Tesla® K80
-* **OS:** Ubuntu 16.04 LTS
-* **CUDA / cuDNN:** 8.0 / 5.1
-* **TensorFlow GitHub hash:** b1e174e
-* **Benchmark GitHub hash:** 9165a70
-* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
- //tensorflow/tools/pip_package:build_pip_package`
-* **Disk:** 1.7 TB Shared SSD persistent disk (800 MB/s)
-* **DataSet:** ImageNet
-* **Test Date:** May 2017
-
-Batch size and optimizer used for each model are listed in the table below. In
-addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were
-tested with a batch size of 32. Those results are in the *other results*
-section.
-
-Options | InceptionV3 | ResNet-50 | ResNet-152 | AlexNet | VGG16
------------------- | ----------- | --------- | ---------- | ------- | -----
-Batch size per GPU | 64 | 64 | 32 | 512 | 32
-Optimizer | sgd | sgd | sgd | sgd | sgd
-
-The configuration used for each model was `variable_update` equal to
-`parameter_server` and `local_parameter_device` equal to `cpu`.
-
-### Results
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:35%" src="../images/perf_gce_synth_k80_single_server_scaling.png">
- <img style="width:35%" src="../images/perf_gce_real_k80_single_server_scaling.png">
-</div>
-
-**Training synthetic data**
-
-GPUs | InceptionV3 | ResNet-50 | ResNet-152 | AlexNet | VGG16
----- | ----------- | --------- | ---------- | ------- | -----
-1 | 30.5 | 51.9 | 20.0 | 656 | 35.4
-2 | 57.8 | 99.0 | 38.2 | 1209 | 64.8
-4 | 116 | 195 | 75.8 | 2328 | 120
-8 | 227 | 387 | 148 | 4640 | 234
-
-**Training real data**
-
-GPUs | InceptionV3 | ResNet-50 | ResNet-152 | AlexNet | VGG16
----- | ----------- | --------- | ---------- | ------- | -----
-1 | 30.6 | 51.2 | 20.0 | 639 | 34.2
-2 | 58.4 | 98.8 | 38.3 | 1136 | 62.9
-4 | 115 | 194 | 75.4 | 2067 | 118
-8 | 225 | 381 | 148 | 4056 | 230
-
-### Other Results
-
-**Training synthetic data**
-
-GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
----- | --------------------------- | -------------------------
-1 | 29.3 | 49.5
-2 | 55.0 | 95.4
-4 | 109 | 183
-8 | 216 | 362
-
-**Training real data**
-
-GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
----- | --------------------------- | -------------------------
-1 | 29.5 | 49.3
-2 | 55.4 | 95.3
-4 | 110 | 186
-8 | 216 | 359
-
-## Details for Amazon EC2 (NVIDIA® Tesla® K80)
-
-### Environment
-
-* **Instance type**: p2.8xlarge
-* **GPU:** 8x NVIDIA® Tesla® K80
-* **OS:** Ubuntu 16.04 LTS
-* **CUDA / cuDNN:** 8.0 / 5.1
-* **TensorFlow GitHub hash:** b1e174e
-* **Benchmark GitHub hash:** 9165a70
-* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
- //tensorflow/tools/pip_package:build_pip_package`
-* **Disk:** 1TB Amazon EFS (burst 100 MiB/sec for 12 hours, continuous 50
- MiB/sec)
-* **DataSet:** ImageNet
-* **Test Date:** May 2017
-
-Batch size and optimizer used for each model are listed in the table below. In
-addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were
-tested with a batch size of 32. Those results are in the *other results*
-section.
-
-Options | InceptionV3 | ResNet-50 | ResNet-152 | AlexNet | VGG16
------------------- | ----------- | --------- | ---------- | ------- | -----
-Batch size per GPU | 64 | 64 | 32 | 512 | 32
-Optimizer | sgd | sgd | sgd | sgd | sgd
-
-Configuration used for each model.
-
-Model | variable_update | local_parameter_device
------------ | ------------------------- | ----------------------
-InceptionV3 | parameter_server | cpu
-ResNet-50 | replicated (without NCCL) | gpu
-ResNet-152 | replicated (without NCCL) | gpu
-AlexNet | parameter_server | gpu
-VGG16 | parameter_server | gpu
-
-### Results
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:35%" src="../images/perf_aws_synth_k80_single_server_scaling.png">
- <img style="width:35%" src="../images/perf_aws_real_k80_single_server_scaling.png">
-</div>
-
-**Training synthetic data**
-
-GPUs | InceptionV3 | ResNet-50 | ResNet-152 | AlexNet | VGG16
----- | ----------- | --------- | ---------- | ------- | -----
-1 | 30.8 | 51.5 | 19.7 | 684 | 36.3
-2 | 58.7 | 98.0 | 37.6 | 1244 | 69.4
-4 | 117 | 195 | 74.9 | 2479 | 141
-8 | 230 | 384 | 149 | 4853 | 260
-
-**Training real data**
-
-GPUs | InceptionV3 | ResNet-50 | ResNet-152 | AlexNet | VGG16
----- | ----------- | --------- | ---------- | ------- | -----
-1 | 30.5 | 51.3 | 19.7 | 674 | 36.3
-2 | 59.0 | 94.9 | 38.2 | 1227 | 67.5
-4 | 118 | 188 | 75.2 | 2201 | 136
-8 | 228 | 373 | 149 | N/A | 242
-
-Training AlexNet with real data on 8 GPUs was excluded from the graph and table
-above because our EFS setup did not provide enough throughput.
-
-### Other Results
-
-**Training synthetic data**
-
-GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
----- | --------------------------- | -------------------------
-1 | 29.9 | 49.0
-2 | 57.5 | 94.1
-4 | 114 | 184
-8 | 216 | 355
-
-**Training real data**
-
-GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
----- | --------------------------- | -------------------------
-1 | 30.0 | 49.1
-2 | 57.5 | 95.1
-4 | 113 | 185
-8 | 212 | 353
-
-## Details for Amazon EC2 Distributed (NVIDIA® Tesla® K80)
-
-### Environment
-
-* **Instance type**: p2.8xlarge
-* **GPU:** 8x NVIDIA® Tesla® K80
-* **OS:** Ubuntu 16.04 LTS
-* **CUDA / cuDNN:** 8.0 / 5.1
-* **TensorFlow GitHub hash:** b1e174e
-* **Benchmark GitHub hash:** 9165a70
-* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
- //tensorflow/tools/pip_package:build_pip_package`
-* **Disk:** 1.0 TB EFS (burst 100 MB/sec for 12 hours, continuous 50 MB/sec)
-* **DataSet:** ImageNet
-* **Test Date:** May 2017
-
-The batch size and optimizer used for the tests are listed in the table. In
-addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were
-tested with a batch size of 32. Those results are in the *other results*
-section.
-
-Options | InceptionV3 | ResNet-50 | ResNet-152
------------------- | ----------- | --------- | ----------
-Batch size per GPU | 64 | 64 | 32
-Optimizer | sgd | sgd | sgd
-
-Configuration used for each model.
-
-Model | variable_update | local_parameter_device | cross_replica_sync
------------ | ---------------------- | ---------------------- | ------------------
-InceptionV3 | distributed_replicated | n/a | True
-ResNet-50 | distributed_replicated | n/a | True
-ResNet-152 | distributed_replicated | n/a | True
-
-To simplify server setup, EC2 instances (p2.8xlarge) running worker servers also
-ran parameter servers (see the sketch after the list below). Equal numbers of
-parameter servers and worker servers were used, with the following exceptions:
-
-* InceptionV3: 8 instances / 6 parameter servers
-* ResNet-50: (batch size 32) 8 instances / 4 parameter servers
-* ResNet-152: 8 instances / 4 parameter servers
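-
-A minimal sketch of this colocated setup using `tf.train.ClusterSpec` (the
-hostnames, ports, and task index are illustrative, not the actual benchmark
-configuration):
-
-```python
-import tensorflow as tf
-
-# Each instance hosts both a worker task and a parameter server task,
-# listening on different ports.
-cluster = tf.train.ClusterSpec({
-    "worker": ["host0:2222", "host1:2222"],
-    "ps": ["host0:2223", "host1:2223"],
-})
-server = tf.train.Server(cluster, job_name="worker", task_index=0)
-```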
-
-### Results
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:80%" src="../images/perf_summary_k80_aws_distributed.png">
-</div>
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:70%" src="../images/perf_aws_synth_k80_distributed_scaling.png">
-</div>
-
-**Training synthetic data**
-
-GPUs | InceptionV3 | ResNet-50 | ResNet-152
----- | ----------- | --------- | ----------
-1 | 29.7 | 52.4 | 19.4
-8 | 229 | 378 | 146
-16 | 459 | 751 | 291
-32 | 902 | 1388 | 565
-64 | 1783 | 2744 | 981
-
-### Other Results
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:50%" src="../images/perf_aws_synth_k80_multi_server_batch32.png">
-</div>
-
-**Training synthetic data**
-
-GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
----- | --------------------------- | -------------------------
-1 | 29.2 | 48.4
-8 | 219 | 333
-16 | 427 | 667
-32 | 820 | 1180
-64 | 1608 | 2315
-
-## Methodology
-
-This
-[script](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks)
-was run on the various platforms to generate the above results.
-
-To make the results as repeatable as possible, each test was run 5 times and
-the times were averaged. GPUs were run in their default state on the given
-platform. For the NVIDIA® Tesla® K80 this means leaving on [GPU
-Boost](https://devblogs.nvidia.com/parallelforall/increase-performance-gpu-boost-k80-autoboost/).
-For each test, 10 warmup steps were run and then the next 100 steps were
-averaged.
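-
-A sketch of the warmup-and-average measurement described above (the `run_step`
-callable stands in for one training step and is not part of the actual
-benchmark script):
-
-```python
-import time
-
-def measure_images_per_sec(run_step, batch_size, warmup=10, iters=100):
-  # Discard warmup steps so one-time startup costs do not skew the average.
-  for _ in range(warmup):
-    run_step()
-  start = time.time()
-  for _ in range(iters):
-    run_step()
-  return iters * batch_size / (time.time() - start)
-```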
diff --git a/tensorflow/docs_src/performance/datasets_performance.md b/tensorflow/docs_src/performance/datasets_performance.md
deleted file mode 100644
index 5d9e4ba392..0000000000
--- a/tensorflow/docs_src/performance/datasets_performance.md
+++ /dev/null
@@ -1,331 +0,0 @@
-# Input Pipeline Performance Guide
-
-GPUs and TPUs can radically reduce the time required to execute a single
-training step. Achieving peak performance requires an efficient input pipeline
-that delivers data for the next step before the current step has finished. The
-`tf.data` API helps to build flexible and efficient input pipelines. This
-document explains the `tf.data` API's features and best practices for building
-high performance TensorFlow input pipelines across a variety of models and
-accelerators.
-
-This guide does the following:
-
-* Illustrates that TensorFlow input pipelines are essentially an
- [ETL](https://en.wikipedia.org/wiki/Extract,_transform,_load) process.
-* Describes common performance optimizations in the context of the `tf.data`
- API.
-* Discusses the performance implications of the order in which you apply
- transformations.
-* Summarizes the best practices for designing performant TensorFlow input
- pipelines.
-
-
-## Input Pipeline Structure
-
-A typical TensorFlow training input pipeline can be framed as an ETL process:
-
-1. **Extract**: Read data from persistent storage -- either local (e.g. HDD or
- SSD) or remote (e.g. [GCS](https://cloud.google.com/storage/) or
- [HDFS](https://en.wikipedia.org/wiki/Apache_Hadoop#Hadoop_distributed_file_system)).
-2. **Transform**: Use CPU cores to parse and perform preprocessing operations
- on the data such as image decompression, data augmentation transformations
- (such as random crop, flips, and color distortions), shuffling, and batching.
-3. **Load**: Load the transformed data onto the accelerator device(s) (for
- example, GPU(s) or TPU(s)) that execute the machine learning model.
-
-This pattern effectively utilizes the CPU, while reserving the accelerator for
-the heavy lifting of training your model. In addition, viewing input pipelines
-as an ETL process provides structure that facilitates the application of
-performance optimizations.
-
-When using the `tf.estimator.Estimator` API, the first two phases (Extract and
-Transform) are captured in the `input_fn` passed to
-`tf.estimator.Estimator.train`. In code, this might look like the following
-(naive, sequential) implementation:
-
-```
-def parse_fn(example):
- "Parse TFExample records and perform simple data augmentation."
- example_fmt = {
- "image": tf.FixedLengthFeature((), tf.string, ""),
- "label": tf.FixedLengthFeature((), tf.int64, -1)
- }
- parsed = tf.parse_single_example(example, example_fmt)
- image = tf.image.decode_image(parsed["image"])
- image = _augment_helper(image) # augments image using slice, reshape, resize_bilinear
- return image, parsed["label"]
-
-def input_fn():
- files = tf.data.Dataset.list_files("/path/to/dataset/train-*.tfrecord")
- dataset = files.interleave(tf.data.TFRecordDataset)
- dataset = dataset.shuffle(buffer_size=FLAGS.shuffle_buffer_size)
- dataset = dataset.map(map_func=parse_fn)
- dataset = dataset.batch(batch_size=FLAGS.batch_size)
- return dataset
-```
-
-The next section builds on this input pipeline, adding performance
-optimizations.
-
-## Optimizing Performance
-
-As new computing devices (such as GPUs and TPUs) make it possible to train
-neural networks at an increasingly fast rate, the CPU processing is prone to
-becoming the bottleneck. The `tf.data` API provides users with building blocks
-to design input pipelines that effectively utilize the CPU, optimizing each step
-of the ETL process.
-
-### Pipelining
-
-To perform a training step, you must first extract and transform the training
-data and then feed it to a model running on an accelerator. However, in a naive
-synchronous implementation, while the CPU is preparing the data, the accelerator
-is sitting idle. Conversely, while the accelerator is training the model, the
-CPU is sitting idle. The training step time is thus the sum of both CPU
-pre-processing time and the accelerator training time.
-
-**Pipelining** overlaps the preprocessing and model execution of a training
-step. While the accelerator is performing training step `N`, the CPU is
-preparing the data for step `N+1`. Doing so reduces the step time to the maximum
-(as opposed to the sum) of the training and the time it takes to extract and
-transform the data.
-
-Without pipelining, the CPU and the GPU/TPU sit idle much of the time:
-
-![without pipelining](/images/datasets_without_pipelining.png)
-
-With pipelining, idle time diminishes significantly:
-
-![with pipelining](/images/datasets_with_pipelining.png)
-
-The `tf.data` API provides a software pipelining mechanism through the
-`tf.data.Dataset.prefetch` transformation, which can be used to decouple the
-time data is produced from the time it is consumed. In particular, the
-transformation uses a background thread and an internal buffer to prefetch
-elements from the input dataset ahead of the time they are requested. Thus, to
-achieve the pipelining effect illustrated above, you can add `prefetch(1)` as
-the final transformation to your dataset pipeline (or `prefetch(n)` if a single
-training step consumes n elements).
-
-To apply this change to our running example, change:
-
-```
-dataset = dataset.batch(batch_size=FLAGS.batch_size)
-return dataset
-```
-
-to:
-
-
-```
-dataset = dataset.batch(batch_size=FLAGS.batch_size)
-dataset = dataset.prefetch(buffer_size=FLAGS.prefetch_buffer_size)
-return dataset
-```
-
-Note that the prefetch transformation will yield benefits any time there is an
-opportunity to overlap the work of a "producer" with the work of a "consumer."
-The preceding recommendation is simply the most common application.
-
-### Parallelize Data Transformation
-
-When preparing a batch, input elements may need to be pre-processed. To this
-end, the `tf.data` API offers the `tf.data.Dataset.map` transformation, which
-applies a user-defined function (for example, `parse_fn` from the running
-example) to each element of the input dataset. Because input elements are
-independent of one another, the pre-processing can be parallelized across
-multiple CPU cores. To make this possible, the `map` transformation provides the
-`num_parallel_calls` argument to specify the level of parallelism. For example,
-the following diagram illustrates the effect of passing `num_parallel_calls=2`
-to the `map` transformation:
-
-![parallel map](/images/datasets_parallel_map.png)
-
-Choosing the best value for the `num_parallel_calls` argument depends on your
-hardware, characteristics of your training data (such as its size and shape),
-the cost of your map function, and what other processing is happening on the
-CPU at the same time; a simple heuristic is to use the number of available CPU
-cores. For instance, if the machine executing the example above had 4 cores, it
-would have been more efficient to set `num_parallel_calls=4`. On the other hand,
-setting `num_parallel_calls` to a value much greater than the number of
-available CPUs can lead to inefficient scheduling, resulting in a slowdown.
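-
-One way to implement the CPU-core heuristic is sketched below (this is
-independent of the `FLAGS`-based change shown next):
-
-```
-import multiprocessing
-
-num_cores = multiprocessing.cpu_count()  # heuristic: one call per CPU core
-dataset = dataset.map(map_func=parse_fn, num_parallel_calls=num_cores)
-```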
-
-To apply this change to our running example, change:
-
-```
-dataset = dataset.map(map_func=parse_fn)
-```
-
-to:
-
-```
-dataset = dataset.map(map_func=parse_fn, num_parallel_calls=FLAGS.num_parallel_calls)
-```
-
-Furthermore, if your batch size is in the hundreds or thousands, your pipeline
-will likely also benefit from parallelizing batch creation. To this
-end, the `tf.data` API provides the `tf.contrib.data.map_and_batch`
-transformation, which effectively "fuses" the map and batch transformations.
-
-To apply this change to our running example, change:
-
-```
-dataset = dataset.map(map_func=parse_fn, num_parallel_calls=FLAGS.num_parallel_calls)
-dataset = dataset.batch(batch_size=FLAGS.batch_size)
-```
-
-to:
-
-```
-dataset = dataset.apply(tf.contrib.data.map_and_batch(
- map_func=parse_fn, batch_size=FLAGS.batch_size))
-```
-
-### Parallelize Data Extraction
-
-In a real-world setting, the input data may be stored remotely (for example,
-GCS or HDFS), either because the input data would not fit locally or because the
-training is distributed and it would not make sense to replicate the input data
-on every machine. A dataset pipeline that works well when reading data locally
-might become bottlenecked on I/O when reading data remotely because of the
-following differences between local and remote storage:
-
-
-* **Time-to-first-byte:** Reading the first byte of a file from remote storage
- can take orders of magnitude longer than from local storage.
-* **Read throughput:** While remote storage typically offers large aggregate
- bandwidth, reading a single file might only be able to utilize a small
- fraction of this bandwidth.
-
-In addition, once the raw bytes are read into memory, it may also be necessary
-to deserialize or decrypt the data
-(e.g. [protobuf](https://developers.google.com/protocol-buffers/)), which adds
-overhead. This overhead is present irrespective of whether the data
-is stored locally or remotely, but can be worse in the remote case if data is
-not prefetched effectively.
-
-To mitigate the impact of the various data extraction overheads, the `tf.data`
-API offers the `tf.contrib.data.parallel_interleave` transformation. Use this
-transformation to parallelize the execution of and interleave the contents of
-other datasets (such as data file readers). The
-number of datasets to overlap can be specified by the `cycle_length` argument.
-
-The following diagram illustrates the effect of supplying `cycle_length=2` to
-the `parallel_interleave` transformation:
-
-![parallel io](/images/datasets_parallel_io.png)
-
-To apply this change to our running example, change:
-
-```
-dataset = files.interleave(tf.data.TFRecordDataset)
-```
-
-to:
-
-```
-dataset = files.apply(tf.contrib.data.parallel_interleave(
- tf.data.TFRecordDataset, cycle_length=FLAGS.num_parallel_readers))
-```
-
-
-The throughput of remote storage systems can vary over time due to load or
-network events. To account for this variance, the `parallel_interleave`
-transformation can optionally use prefetching. (See
-`tf.contrib.data.parallel_interleave` for details).
-
-By default, the `parallel_interleave` transformation provides a deterministic
-ordering of elements to aid reproducibility. As an alternative to prefetching
-(which may be ineffective in some cases), the `parallel_interleave`
-transformation also provides an option that can boost performance at the expense
-of ordering guarantees. In particular, if the `sloppy` argument is set to true,
-the transformation may depart from its otherwise deterministic ordering, by
-temporarily skipping over files whose elements are not available when the next
-element is requested.
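-
-For example, the following sketch enables sloppy ordering for the running
-example, trading determinism for throughput:
-
-```
-dataset = files.apply(tf.contrib.data.parallel_interleave(
-    tf.data.TFRecordDataset,
-    cycle_length=FLAGS.num_parallel_readers,
-    sloppy=True))  # may yield elements out of order
-```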
-
-## Performance Considerations
-
-The `tf.data` API is designed around composable transformations to provide its
-users with flexibility. Although many of these transformations are commutative,
-the ordering of certain transformations has performance implications.
-
-### Map and Batch
-
-Invoking the user-defined function passed into the `map` transformation has
-overhead related to scheduling and executing the user-defined function.
-Normally, this overhead is small compared to the amount of computation performed
-by the function. However, if `map` does little work, this overhead can dominate
-the total cost. In such cases, we recommend vectorizing the user-defined
-function (that is, having it operate over a batch of inputs at once) and
-applying the `batch` transformation _before_ the `map` transformation.
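-
-For example, if the user-defined function merely adds a constant, the
-vectorized form batches first and then maps once per batch rather than once
-per element (a sketch):
-
-```
-# Instead of: dataset.map(lambda x: x + 1).batch(batch_size)
-dataset = dataset.batch(batch_size)
-dataset = dataset.map(lambda batch: batch + 1)  # one call per batch
-```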
-
-### Map and Cache
-
-The `tf.data.Dataset.cache` transformation can cache a dataset, either in
-memory or on local storage. If the user-defined function passed into the `map`
-transformation is expensive, apply the cache transformation after the map
-transformation as long as the resulting dataset can still fit into memory or
-local storage. If the user-defined function increases the space required to
-store the dataset beyond the cache capacity, consider pre-processing your data
-before your training job to reduce resource usage.
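-
-For example (a sketch, assuming `parse_fn` is expensive and the parsed dataset
-fits in memory):
-
-```
-dataset = dataset.map(map_func=parse_fn)  # expensive work happens once...
-dataset = dataset.cache()  # ...and is reused by every subsequent epoch
-```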
-
-### Map and Interleave / Prefetch / Shuffle
-
-A number of transformations, including `interleave`, `prefetch`, and `shuffle`,
-maintain an internal buffer of elements. If the user-defined function passed
-into the `map` transformation changes the size of the elements, then the
-ordering of the map transformation and the transformations that buffer elements
-affects the memory usage. In general, we recommend choosing the order that
-results in the lower memory footprint, unless a different ordering is desirable
-for performance (for example, to enable fusing of the map and batch
-transformations).
-
-### Repeat and Shuffle
-
-The `tf.data.Dataset.repeat` transformation repeats the input data a finite (or
-infinite) number of times; each repetition of the data is typically referred to
-as an _epoch_. The `tf.data.Dataset.shuffle` transformation randomizes the
-order of the dataset's examples.
-
-If the `repeat` transformation is applied before the `shuffle` transformation,
-then the epoch boundaries are blurred. That is, certain elements can be repeated
-before other elements appear even once. On the other hand, if the `shuffle`
-transformation is applied before the `repeat` transformation, then performance
-might slow down at the beginning of each epoch while the internal state of the
-`shuffle` transformation is initialized. In other words, the former
-(`repeat` before `shuffle`) provides better performance, while the latter
-(`shuffle` before `repeat`) provides stronger ordering guarantees.
-
-When possible, we recommend using the fused
-`tf.contrib.data.shuffle_and_repeat` transformation, which combines the best of
-both worlds (good performance and strong ordering guarantees). Otherwise, we
-recommend shuffling before repeating.
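-
-For example (a sketch; the `FLAGS` values are illustrative):
-
-```
-dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(
-    buffer_size=FLAGS.shuffle_buffer_size, count=FLAGS.num_epochs))
-```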
-
-## Summary of Best Practices
-
-Here is a summary of the best practices for designing input pipelines:
-
-* Use the `prefetch` transformation to overlap the work of a producer and
- consumer. In particular, we recommend adding `prefetch(n)` (where `n` is the
- number of elements/batches consumed by a training step) to the end of your
- input pipeline to overlap the transformations performed on the CPU with the
- training done on the accelerator.
-* Parallelize the `map` transformation by setting the `num_parallel_calls`
- argument. We recommend using the number of available CPU cores for its value.
-* If you are combining pre-processed elements into a batch using the `batch`
- transformation, we recommend using the fused `map_and_batch` transformation;
- especially if you are using large batch sizes.
-* If you are working with data stored remotely and/or requiring
- deserialization, we recommend using the `parallel_interleave`
- transformation to overlap the reading (and deserialization) of data from
- different files.
-* Vectorize cheap user-defined functions passed in to the `map` transformation
- to amortize the overhead associated with scheduling and executing the
- function.
-* If your data can fit into memory, use the `cache` transformation to cache it
- in memory during the first epoch, so that subsequent epochs can avoid the
- overhead associated with reading, parsing, and transforming it.
-* If your pre-processing increases the size of your data, we recommend
- applying the `interleave`, `prefetch`, and `shuffle` first (if possible) to
- reduce memory usage.
-* We recommend applying the `shuffle` transformation _before_ the `repeat`
- transformation, ideally using the fused `shuffle_and_repeat` transformation.
diff --git a/tensorflow/docs_src/performance/index.md b/tensorflow/docs_src/performance/index.md
deleted file mode 100644
index 131d28fa3e..0000000000
--- a/tensorflow/docs_src/performance/index.md
+++ /dev/null
@@ -1,52 +0,0 @@
-# Performance
-
-Performance is an important consideration when training machine learning
-models. Good performance speeds up and scales research while also giving end
-users near-instant predictions. This section details the high-level APIs to
-use, along with best practices for building and training high-performance
-models and for quantizing models for the lowest latency and highest throughput
-during inference.
-
- * @{$performance_guide$Performance Guide} contains a collection of best
- practices for optimizing your TensorFlow code.
-
- * @{$datasets_performance$Data input pipeline guide} describes the tf.data
- API for building efficient data input pipelines for TensorFlow.
-
- * @{$performance/benchmarks$Benchmarks} contains a collection of
- benchmark results for a variety of hardware configurations.
-
- * For improving inference efficiency on mobile and
- embedded hardware, see
- @{$quantization$How to Quantize Neural Networks with TensorFlow}, which
- explains how to use quantization to reduce model size, both in storage
- and at runtime.
-
- * For optimizing inference on GPUs, refer to [NVIDIA TensorRT™
- integration with TensorFlow](
- https://medium.com/tensorflow/speed-up-tensorflow-inference-on-gpus-with-tensorrt-13b49f3db3fa).
-
-
-XLA (Accelerated Linear Algebra) is an experimental compiler for linear
-algebra that optimizes TensorFlow computations. The following guides explore
-XLA:
-
- * @{$xla$XLA Overview}, which introduces XLA.
- * @{$broadcasting$Broadcasting Semantics}, which describes XLA's
- broadcasting semantics.
- * @{$developing_new_backend$Developing a new back end for XLA}, which
- explains how to re-target TensorFlow in order to optimize the performance
- of the computational graph for particular hardware.
- * @{$jit$Using JIT Compilation}, which describes the XLA JIT compiler that
- compiles and runs parts of TensorFlow graphs via XLA in order to optimize
- performance.
- * @{$operation_semantics$Operation Semantics}, which is a reference manual
- describing the semantics of operations in the `ComputationBuilder`
- interface.
- * @{$shapes$Shapes and Layout}, which details the `Shape` protocol buffer.
- * @{$tfcompile$Using AOT compilation}, which explains `tfcompile`, a
- standalone tool that compiles TensorFlow graphs into executable code in
- order to optimize performance.
-
-
-
diff --git a/tensorflow/docs_src/performance/leftnav_files b/tensorflow/docs_src/performance/leftnav_files
deleted file mode 100644
index 12e0dbd48a..0000000000
--- a/tensorflow/docs_src/performance/leftnav_files
+++ /dev/null
@@ -1,14 +0,0 @@
-index.md
-performance_guide.md
-datasets_performance.md
-benchmarks.md
-quantization.md
-
-### XLA
-xla/index.md
-xla/broadcasting.md
-xla/developing_new_backend.md
-xla/jit.md
-xla/operation_semantics.md
-xla/shapes.md
-xla/tfcompile.md
diff --git a/tensorflow/docs_src/performance/performance_guide.md b/tensorflow/docs_src/performance/performance_guide.md
deleted file mode 100644
index df70309568..0000000000
--- a/tensorflow/docs_src/performance/performance_guide.md
+++ /dev/null
@@ -1,733 +0,0 @@
-# Performance Guide
-
-This guide contains a collection of best practices for optimizing TensorFlow
-code. The guide is divided into a few sections:
-
-* [General best practices](#general_best_practices) covers topics that are
- common across a variety of model types and hardware.
-* [Optimizing for GPU](#optimizing_for_gpu) details tips specifically relevant
- to GPUs.
-* [Optimizing for CPU](#optimizing_for_cpu) details CPU specific information.
-
-## General best practices
-
-The sections below cover best practices that are relevant to a variety of
-hardware and models. The best practices section is broken down into the
-following sections:
-
-* [Input pipeline optimizations](#input-pipeline-optimization)
-* [Data formats](#data-formats)
-* [Common fused Ops](#common-fused-ops)
-* [RNN Performance](#rnn-performance)
-* [Building and installing from source](#building-and-installing-from-source)
-
-### Input pipeline optimization
-
-Typical models retrieve data from disk and preprocess it before sending the data
-through the network. For example, models that process JPEG images will follow
-this flow: load image from disk, decode JPEG into a tensor, crop and pad,
-possibly flip and distort, and then batch. This flow is referred to as the input
-pipeline. As GPUs and other hardware accelerators get faster, preprocessing of
-data can be a bottleneck.
-
-Determining if the input pipeline is the bottleneck can be complicated. One of
-the most straightforward methods is to reduce the model to a single operation
-(trivial model) after the input pipeline and measure the examples per second. If
-the difference in examples per second for the full model and the trivial model
-is minimal then the input pipeline is likely a bottleneck. Below are some other
-approaches to identifying issues:
-
-* Check if a GPU is underutilized by running `nvidia-smi -l 2`. If GPU
- utilization is not approaching 80-100%, then the input pipeline may be the
- bottleneck.
-* Generate a timeline and look for large blocks of white space (waiting). An
- example of generating a timeline exists as part of the @{$jit$XLA JIT}
- tutorial.
-* Check CPU usage. It is possible to have an optimized input pipeline and lack
- the CPU cycles to process the pipeline.
-* Estimate the throughput needed and verify the disk used is capable of that
- level of throughput. Some cloud solutions have network attached disks that
- start as low as 50 MB/sec, which is slower than spinning disks (150 MB/sec),
- SATA SSDs (500 MB/sec), and PCIe SSDs (2,000+ MB/sec).
-
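-A sketch of the trivial-model measurement described above (`dataset` is the
-input pipeline under test; the names and `FLAGS` are illustrative):
-
-```python
-import time
-import tensorflow as tf
-
-# Replace the real model with a near-zero-cost op after the input pipeline,
-# then compare examples/sec against the full model.
-images, labels = dataset.make_one_shot_iterator().get_next()
-trivial_op = tf.reduce_sum(labels)
-
-with tf.Session() as sess:
-  steps = 100
-  start = time.time()
-  for _ in range(steps):
-    sess.run(trivial_op)
-  print('examples/sec: %.1f' %
-        (steps * FLAGS.batch_size / (time.time() - start)))
-```
-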
-#### Preprocessing on the CPU
-
-Placing input pipeline operations on the CPU can significantly improve
-performance. Utilizing the CPU for the input pipeline frees the GPU to focus on
-training. To ensure preprocessing is on the CPU, wrap the preprocessing
-operations as shown below:
-
-```python
-with tf.device('/cpu:0'):
- # function to get and process images or data.
- distorted_inputs = load_and_distort_images()
-```
-
-If using `tf.estimator.Estimator`, the input function is automatically placed
-on the CPU.
-
-#### Using the tf.data API
-
-The @{$datasets$tf.data API} is replacing `queue_runner` as the recommended API
-for building input pipelines. This
-[ResNet example](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10_estimator/cifar10_main.py)
-([arXiv:1512.03385](https://arxiv.org/abs/1512.03385))
-training CIFAR-10 illustrates the use of the `tf.data` API along with
-`tf.estimator.Estimator`.
-
-The `tf.data` API utilizes C++ multi-threading and has a much lower overhead
-than the Python-based `queue_runner` that is limited by Python's multi-threading
-performance. A detailed performance guide for the `tf.data` API can be found
-@{$datasets_performance$here}.
-
-While feeding data using a `feed_dict` offers a high level of flexibility, in
-general `feed_dict` does not provide a scalable solution. If only a single GPU
-is used, the difference between the `tf.data` API and `feed_dict` performance
-may be negligible. Our recommendation is to avoid using `feed_dict` for all but
-trivial examples. In particular, avoid using `feed_dict` with large inputs:
-
-```python
-# feed_dict often results in suboptimal performance when using large inputs.
-sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
-```
-
-#### Fused decode and crop
-
-If inputs are JPEG images that also require cropping, use fused
-`tf.image.decode_and_crop_jpeg` to speed up preprocessing.
-`tf.image.decode_and_crop_jpeg` only decodes the part of
-the image within the crop window. This significantly speeds up the process if
-the crop window is much smaller than the full image. For ImageNet data, this
-approach could speed up the input pipeline by up to 30%.
-
-Example Usage:
-
-```python
-def _image_preprocess_fn(image_buffer):
- # image_buffer: 1-D string Tensor representing the raw JPEG image buffer.
-
- # Extract image shape from raw JPEG image buffer.
- image_shape = tf.image.extract_jpeg_shape(image_buffer)
-
- # Get a crop window with distorted bounding box.
- sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
- image_shape, ...)
- bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box
-
- # Decode and crop image.
- offset_y, offset_x, _ = tf.unstack(bbox_begin)
- target_height, target_width, _ = tf.unstack(bbox_size)
- crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
- cropped_image = tf.image.decode_and_crop_jpeg(image_buffer, crop_window)
- return cropped_image
-```
-
-`tf.image.decode_and_crop_jpeg` is available on all platforms. However, there is
-no speedup on Windows, which uses `libjpeg`, whereas other platforms use the
-faster `libjpeg-turbo`.
-
-#### Use large files
-
-Reading large numbers of small files significantly impacts I/O performance.
-One approach to getting maximum I/O throughput is to preprocess input data into
-larger (~100MB) `TFRecord` files. For smaller data sets (200MB-1GB), the best
-approach is often to load the entire data set into memory. The document
-[Downloading and converting to TFRecord format](https://github.com/tensorflow/models/tree/master/research/slim#downloading-and-converting-to-tfrecord-format)
-includes information and scripts for creating `TFRecords` and this
-[script](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10_estimator/generate_cifar10_tfrecords.py)
-converts the CIFAR-10 data set into `TFRecords`.
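-
-A sketch of writing examples into a larger `TFRecord` shard (the filename and
-feature keys are illustrative):
-
-```python
-import tensorflow as tf
-
-def _bytes_feature(value):
-  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
-
-def _int64_feature(value):
-  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
-
-# examples: an iterable of (encoded_image_bytes, integer_label) pairs.
-with tf.python_io.TFRecordWriter('train-00000-of-01024.tfrecord') as writer:
-  for image_bytes, label in examples:
-    record = tf.train.Example(features=tf.train.Features(feature={
-        'image': _bytes_feature(image_bytes),
-        'label': _int64_feature(label),
-    }))
-    writer.write(record.SerializeToString())
-```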
-
-### Data formats
-
-Data format refers to the structure of the Tensor passed to a given Op. The
-discussion below is specifically about 4D Tensors representing images. In
-TensorFlow the parts of the 4D tensor are often referred to by the following
-letters:
-
-* N refers to the number of images in a batch.
-* H refers to the number of pixels in the vertical (height) dimension.
-* W refers to the number of pixels in the horizontal (width) dimension.
-* C refers to the channels. For example, 1 for black and white or grayscale
- and 3 for RGB.
-
-Within TensorFlow there are two naming conventions representing the two most
-common data formats:
-
-* `NCHW` or `channels_first`
-* `NHWC` or `channels_last`
-
-`NHWC` is the TensorFlow default and `NCHW` is the optimal format to use when
-training on NVIDIA GPUs using [cuDNN](https://developer.nvidia.com/cudnn).
-
-The best practice is to build models that work with both data formats. This
-simplifies training on GPUs and then running inference on CPUs. If TensorFlow is
-compiled with the [Intel MKL](#tensorflow_with_intel_mkl-dnn) optimizations,
-many operations, especially those related to CNN based models, will be optimized
-and support `NCHW`. If not using the MKL, some operations are not supported on
-CPU when using `NCHW`.
-
-The brief history of these two formats is that TensorFlow started by using
-`NHWC` because it was a little faster on CPUs. In the long term, we are working
-on tools to automatically rewrite graphs to make switching between the formats
-transparent and to take advantage of micro-optimizations, where a GPU Op may be
-faster using `NHWC` than the normally most efficient `NCHW`.
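-
-One simple way to support both formats is to transpose at the model boundary.
-A sketch (the function name is illustrative):
-
-```python
-import tensorflow as tf
-
-def to_data_format(images, data_format):
-  """Converts an NHWC image batch to the requested format."""
-  if data_format == 'channels_first':
-    # NHWC -> NCHW, the layout preferred by cuDNN on NVIDIA GPUs.
-    return tf.transpose(images, [0, 3, 1, 2])
-  return images  # already NHWC (channels_last)
-```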
-
-### Common fused Ops
-
-Fused Ops combine multiple operations into a single kernel for improved
-performance. There are many fused Ops within TensorFlow and @{$xla$XLA} will
-create fused Ops when possible to automatically improve performance. Collected
-below are select fused Ops that can greatly improve performance and may be
-overlooked.
-
-#### Fused batch norm
-
-Fused batch norm combines the multiple operations needed to do batch
-normalization into a single kernel. Batch norm is an expensive process that for
-some models makes up a large percentage of the operation time. Using fused batch
-norm can result in a 12%-30% speedup.
-
-There are two commonly used batch norm implementations and both support fusing.
-The core `tf.layers.batch_normalization` added the `fused` option starting in
-TensorFlow 1.3.
-
-```python
-bn = tf.layers.batch_normalization(
- input_layer, fused=True, data_format='NCHW')
-```
-
-The contrib `tf.contrib.layers.batch_norm` method has had `fused` as an option
-since before TensorFlow 1.0.
-
-```python
-bn = tf.contrib.layers.batch_norm(input_layer, fused=True, data_format='NCHW')
-```
-
-### RNN Performance
-
-There are many ways to specify an RNN computation in TensorFlow and they have
-trade-offs with respect to model flexibility and performance. The
-`tf.nn.rnn_cell.BasicLSTMCell` should be considered a reference implementation
-and used only as a last resort when no other options will work.
-
-When using one of the cells, rather than the fully fused RNN layers, you have a
-choice of whether to use `tf.nn.static_rnn` or `tf.nn.dynamic_rnn`. There
-shouldn't generally be a performance difference at runtime, but large unroll
-amounts can increase the graph size of the `tf.nn.static_rnn` and cause long
-compile times. An additional advantage of `tf.nn.dynamic_rnn` is that it can
-optionally swap memory from the GPU to the CPU to enable training of very long
-sequences. Depending on the model and hardware configuration, this can come at
-a performance cost. It is also possible to run multiple iterations of
-`tf.nn.dynamic_rnn` and the underlying `tf.while_loop` construct in parallel,
-although this is rarely useful with RNN models as they are inherently
-sequential.
-
-On NVIDIA GPUs, the use of `tf.contrib.cudnn_rnn` should always be preferred
-unless you want layer normalization, which it doesn't support. It is often at
-least an order of magnitude faster than `tf.contrib.rnn.BasicLSTMCell` and
-`tf.contrib.rnn.LSTMBlockCell` and uses 3-4x less memory than
-`tf.contrib.rnn.BasicLSTMCell`.
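-
-A sketch of the layer-style `CudnnLSTM` API (this form is available in later
-TensorFlow 1.x releases; the sizes are illustrative and the input is
-time-major):
-
-```python
-import tensorflow as tf
-
-# inputs: [max_time, batch_size, input_depth], float32, time-major.
-inputs = tf.placeholder(tf.float32, [None, None, 128])
-lstm = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=2, num_units=256)
-outputs, (h, c) = lstm(inputs)  # outputs: [max_time, batch_size, 256]
-```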
-
-If you need to run one step of the RNN at a time, as might be the case in
-reinforcement learning with a recurrent policy, then you should use the
-`tf.contrib.rnn.LSTMBlockCell` with your own environment interaction loop
-inside a `tf.while_loop` construct. Running one step of the RNN at a time and
-returning to Python is possible, but it will be slower.
-
-On CPUs, mobile devices, and if `tf.contrib.cudnn_rnn` is not available on
-your GPU, the fastest and most memory efficient option is
-`tf.contrib.rnn.LSTMBlockFusedCell`.
-
-For all of the less common cell types like `tf.contrib.rnn.NASCell`,
-`tf.contrib.rnn.PhasedLSTMCell`, `tf.contrib.rnn.UGRNNCell`,
-`tf.contrib.rnn.GLSTMCell`, `tf.contrib.rnn.Conv1DLSTMCell`,
-`tf.contrib.rnn.Conv2DLSTMCell`, `tf.contrib.rnn.LayerNormBasicLSTMCell`,
-etc., one should be aware that they are implemented in the graph like
-`tf.contrib.rnn.BasicLSTMCell` and as such will suffer from the same poor
-performance and high memory usage. One should consider whether or not those
-trade-offs are worth it before using these cells. For example, while layer
-normalization can speed up convergence, the fastest wall-clock time to
-convergence is usually obtained without it because cuDNN is 20x faster.
-
-
-### Building and installing from source
-
-The default TensorFlow binaries target the broadest range of hardware to make
-TensorFlow accessible to everyone. If using CPUs for training or inference, it
-is recommended to compile TensorFlow with all of the optimizations available for
-the CPU in use. Speedups for training and inference on CPU are documented below
-in [Comparing compiler optimizations](#comparing-compiler-optimizations).
-
-To install the most optimized version of TensorFlow,
-@{$install_sources$build and install} from source. If there is a need to build
-TensorFlow on a platform that has different hardware than the target, then
-cross-compile with the highest optimizations for the target platform. The
-following command is an example of using `bazel` to compile for a specific
-platform:
-
-```bash
-# This command optimizes for Intel’s Broadwell processor
-bazel build -c opt --copt=-march="broadwell" --config=cuda //tensorflow/tools/pip_package:build_pip_package
-```
-
-#### Environment, build, and install tips
-
-* `./configure` asks which compute capability to include in the build. This
- does not impact overall performance but does impact initial startup. After
- running TensorFlow once, the compiled kernels are cached by CUDA. If using
- a Docker container, the data is not cached and the penalty is paid each time
- TensorFlow starts. The best practice is to include the
- [compute capabilities](http://developer.nvidia.com/cuda-gpus)
- of the GPUs that will be used, e.g. P100: 6.0, Titan X (Pascal): 6.1, Titan
- X (Maxwell): 5.2, and K80: 3.7.
-* Use a version of gcc that supports all of the optimizations of the target
- CPU. The recommended minimum gcc version is 4.8.3. On OS X, upgrade to the
- latest Xcode version and use the version of clang that comes with Xcode.
-* Install the latest stable CUDA platform and cuDNN libraries supported by
- TensorFlow.
-
-## Optimizing for GPU
-
-This section contains GPU-specific tips that are not covered in the
-[General best practices](#general-best-practices). Obtaining optimal performance
-on multiple GPUs is a challenge. A common approach is to use data parallelism.
-Scaling through the use of data parallelism involves making multiple copies of
-the model, which are referred to as "towers", and then placing one tower on each
-of the GPUs. Each tower operates on a different mini-batch of data and then
-updates variables, also known as parameters, that need to be shared between
-each of the towers. How each tower gets the updated variables and how the
-gradients are applied has an impact on the performance, scaling, and convergence
-of the model. The rest of this section provides an overview of variable
-placement and the towering of a model on multiple GPUs.
-@{$performance_models$High-Performance Models} gets into more details regarding
-more complex methods that can be used to share and update variables between
-towers.
-
-The best approach to handling variable updates depends on the model, hardware,
-and even how the hardware has been configured. For example, two systems can be
-built with NVIDIA Tesla P100s, but one may be using PCIe and the
-other [NVLink](http://www.nvidia.com/object/nvlink.html). In that scenario, the
-optimal solution for each system may be different. For real world examples, read
-the @{$performance/benchmarks$benchmark} page which details the settings that
-were optimal for a variety of platforms. Below is a summary of what was learned
-from benchmarking various platforms and configurations:
-
-* **Tesla K80**: If the GPUs are on the same PCI Express root complex and are
- able to use [NVIDIA GPUDirect](https://developer.nvidia.com/gpudirect) Peer
- to Peer, then placing the variables equally across the GPUs used for
- training is the best approach. If the GPUs cannot use GPUDirect, then
- placing the variables on the CPU is the best option.
-
-* **Titan X (Maxwell and Pascal), M40, P100, and similar**: For models like
- ResNet and InceptionV3, placing variables on the CPU is the optimal setting,
- but for models with a lot of variables like AlexNet and VGG, using GPUs with
- `NCCL` is better.
-
-A common approach to managing where variables are placed is to create a method
-to determine where each Op is to be placed and use that method in place of a
-specific device name when calling `with tf.device():`. Consider a scenario where
-a model is being trained on 2 GPUs and the variables are to be placed on the
-CPU. There would be a loop for creating and placing the "towers" on each of the
-2 GPUs. A custom device placement method would be created that watches for Ops
-of type `Variable`, `VariableV2`, and `VarHandleOp` and indicates that they are
-to be placed on the CPU. All other Ops would be placed on the target GPU.
-The building of the graph would proceed as follows:
-
-* On the first loop a "tower" of the model would be created for `gpu:0`.
- During the placement of the Ops, the custom device placement method would
- indicate that variables are to be placed on `cpu:0` and all other Ops on
- `gpu:0`.
-
-* On the second loop, `reuse` is set to `True` to indicate that variables are
- to be reused and then the "tower" is created on `gpu:1`. During the
- placement of the Ops associated with the "tower", the variables that were
- placed on `cpu:0` are reused and all other Ops are created and placed on
- `gpu:1`.
-
-The final result is that all of the variables are placed on the CPU, with each
-GPU having a copy of all of the computational Ops associated with the model.
-
-The code snippet below illustrates two different approaches for variable
-placement: one is placing variables on the CPU; the other is placing variables
-equally across the GPUs.
-
-```python
-import operator
-
-class GpuParamServerDeviceSetter(object):
- """Used with tf.device() to place variables on the least loaded GPU.
-
- A common use for this class is to pass a list of GPU devices, e.g. ['gpu:0',
- 'gpu:1', 'gpu:2'], as ps_devices. When each variable is placed, it will be
- placed on the least loaded gpu. All other Ops, which will be the computation
- Ops, will be placed on the worker_device.
- """
-
- def __init__(self, worker_device, ps_devices):
- """Initializer for GpuParamServerDeviceSetter.
- Args:
- worker_device: the device to use for computation Ops.
- ps_devices: a list of devices to use for Variable Ops. Each variable is
- assigned to the least loaded device.
- """
- self.ps_devices = ps_devices
- self.worker_device = worker_device
- self.ps_sizes = [0] * len(self.ps_devices)
-
- def __call__(self, op):
- if op.device:
- return op.device
- if op.type not in ['Variable', 'VariableV2', 'VarHandleOp']:
- return self.worker_device
-
- # Gets the least loaded ps_device
- device_index, _ = min(enumerate(self.ps_sizes), key=operator.itemgetter(1))
- device_name = self.ps_devices[device_index]
- var_size = op.outputs[0].get_shape().num_elements()
- self.ps_sizes[device_index] += var_size
-
- return device_name
-
-def _create_device_setter(is_cpu_ps, worker, num_gpus):
- """Create device setter object."""
- if is_cpu_ps:
- # tf.train.replica_device_setter supports placing variables on the CPU, all
- # on one GPU, or on ps_servers defined in a cluster_spec.
- return tf.train.replica_device_setter(
- worker_device=worker, ps_device='/cpu:0', ps_tasks=1)
- else:
- gpus = ['/gpu:%d' % i for i in range(num_gpus)]
- return GpuParamServerDeviceSetter(worker, gpus)
-
-# The method below is a modified snippet from the full example.
-def _resnet_model_fn():
- # When set to False, variables are placed on the least loaded GPU. If set
- # to True, the variables will be placed on the CPU.
- is_cpu_ps = False
-
- # Loops over the number of GPUs and creates a copy ("tower") of the model on
- # each GPU.
- for i in range(num_gpus):
- worker = '/gpu:%d' % i
- # Creates a device setter used to determine where Ops are to be placed.
- device_setter = _create_device_setter(is_cpu_ps, worker, FLAGS.num_gpus)
- # Creates variables on the first loop. On subsequent loops reuse is set
- # to True, which results in the "towers" sharing variables.
- with tf.variable_scope('resnet', reuse=bool(i != 0)):
- with tf.name_scope('tower_%d' % i) as name_scope:
- # tf.device calls the device_setter for each Op that is created.
- # device_setter returns the device the Op is to be placed on.
- with tf.device(device_setter):
- # Creates the "tower".
- _tower_fn(is_training, weight_decay, tower_features[i],
- tower_labels[i], tower_losses, tower_gradvars,
- tower_preds, False)
-
-```
-
-In the near future the above code will be for illustration purposes only, as
-there will be easy-to-use, high-level methods supporting a wide range of
-popular approaches. This
-[example](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10_estimator)
-will continue to get updated as the API expands and evolves to address multi-GPU
-scenarios.
-
-## Optimizing for CPU
-
-CPUs, which include Intel® Xeon Phi™, achieve optimal performance when
-TensorFlow is @{$install_sources$built from source} with all of the instructions
-supported by the target CPU.
-
-Beyond using the latest instruction sets, Intel® has added support for the
-Intel® Math Kernel Library for Deep Neural Networks (Intel® MKL-DNN) to
-TensorFlow. While the name is not completely accurate, these optimizations are
-often simply referred to as 'MKL' or 'TensorFlow with MKL'. [TensorFlow
-with Intel® MKL-DNN](#tensorflow_with_intel_mkl_dnn) contains details on the
-MKL optimizations.
-
-The two configurations listed below are used to optimize CPU performance by
-adjusting the thread pools.
-
-* `intra_op_parallelism_threads`: Nodes that can use multiple threads to
- parallelize their execution will schedule the individual pieces into this
- pool.
-* `inter_op_parallelism_threads`: All ready nodes are scheduled in this pool.
-
-These configurations are set via the `tf.ConfigProto` and passed to `tf.Session`
-in the `config` attribute as shown in the snippet below. If either option is
-unset or set to 0, it defaults to the number of logical
-CPU cores. Testing has shown that the default is effective for systems ranging
-from one CPU with 4 cores to multiple CPUs with 70+ combined logical cores.
-A common alternative optimization is to set the number of threads in both pools
-equal to the number of physical cores rather than logical cores.
-
-```python
-config = tf.ConfigProto()
-config.intra_op_parallelism_threads = 44
-config.inter_op_parallelism_threads = 44
-tf.Session(config=config)
-```
-
-The [Comparing compiler optimizations](#comparing-compiler-optimizations)
-section contains the results of tests that used different compiler
-optimizations.
-
-### TensorFlow with Intel® MKL-DNN
-
-Intel® has added optimizations to TensorFlow for Intel® Xeon® and Intel® Xeon
-Phi™ through the use of the Intel® Math Kernel Library for Deep Neural Networks
-(Intel® MKL-DNN) optimized primitives. The optimizations also provide speedups
-for the consumer line of processors, e.g., Intel® i5 and i7 processors. The Intel
-published paper
-[TensorFlow* Optimizations on Modern Intel® Architecture](https://software.intel.com/en-us/articles/tensorflow-optimizations-on-modern-intel-architecture)
-contains additional details on the implementation.
-
-> Note: MKL was added as of TensorFlow 1.2 and currently only works on Linux. It
-> also does not work when also using `--config=cuda`.
-
-In addition to providing significant performance improvements for training CNN
-based models, compiling with the MKL creates a binary that is optimized for AVX
-and AVX2. The result is a single binary that is optimized and compatible with
-most modern (post-2011) processors.
-
-TensorFlow can be compiled with the MKL optimizations using the following
-commands, depending on the version of the TensorFlow source used.
-
-For TensorFlow source versions after 1.3.0:
-
-```bash
-./configure
-# Pick the desired options
-bazel build --config=mkl --config=opt //tensorflow/tools/pip_package:build_pip_package
-```
-
-For TensorFlow versions 1.2.0 through 1.3.0:
-
-```bash
-./configure
-Do you wish to build TensorFlow with MKL support? [y/N] Y
-Do you wish to download MKL LIB from the web? [Y/n] Y
-# Select the defaults for the rest of the options.
-
-bazel build --config=mkl --copt="-DEIGEN_USE_VML" -c opt //tensorflow/tools/pip_package:build_pip_package
-```
-
-#### Tuning MKL for the best performance
-
-This section details the different configurations and environment variables that
-can be used to tune the MKL for optimal performance. Before tweaking the
-environment variables, make sure the model is using the `NCHW` (`channels_first`)
-[data format](#data-formats). The MKL is optimized for `NCHW` and Intel is
-working to get near performance parity when using `NHWC`.
-
-MKL uses the following environment variables to tune performance:
-
-* KMP_BLOCKTIME - Sets the time, in milliseconds, that a thread should wait,
- after completing the execution of a parallel region, before sleeping.
-* KMP_AFFINITY - Enables the run-time library to bind threads to physical
- processing units.
-* KMP_SETTINGS - Enables (true) or disables (false) the printing of OpenMP*
- run-time library environment variables during program execution.
-* OMP_NUM_THREADS - Specifies the number of threads to use.
-
-More details on the KMP variables are on
-[Intel's](https://software.intel.com/en-us/node/522775) site, and details on
-the OMP variables are on
-[gnu.org](https://gcc.gnu.org/onlinedocs/libgomp/Environment-Variables.html).
-
-While there can be substantial gains from adjusting the environment variables,
-as discussed below, the simplified advice is to set
-`inter_op_parallelism_threads` equal to the number of physical CPUs and to set
-the following environment variables:
-
-* KMP_BLOCKTIME=0
-* KMP_AFFINITY=granularity=fine,verbose,compact,1,0
-
-Example of setting the MKL variables on the command line:
-
-```bash
-KMP_BLOCKTIME=0 KMP_AFFINITY=granularity=fine,verbose,compact,1,0 \
-KMP_SETTINGS=1 python your_python_script.py
-```
-
-Example of setting the MKL variables with Python's `os.environ` (assuming a
-`FLAGS` object defined via `tf.app.flags`):
-
-```python
-import os
-
-os.environ["KMP_BLOCKTIME"] = str(FLAGS.kmp_blocktime)
-os.environ["KMP_SETTINGS"] = str(FLAGS.kmp_settings)
-os.environ["KMP_AFFINITY"] = FLAGS.kmp_affinity
-if FLAGS.num_intra_threads > 0:
-  os.environ["OMP_NUM_THREADS"] = str(FLAGS.num_intra_threads)
-```
-
-There are models and hardware platforms that benefit from different settings.
-Each variable that impacts performance is discussed below.
-
-* **KMP_BLOCKTIME**: The MKL default is 200ms, which was not optimal in our
-  testing. 0 (0ms) was a good default for the CNN-based models that were
-  tested. The best performance for AlexNet was achieved at 30ms, and both
-  GoogleNet and VGG11 performed best when set to 1ms.
-
-* **KMP_AFFINITY**: The recommended setting is
- `granularity=fine,verbose,compact,1,0`.
-
-* **OMP_NUM_THREADS**: This defaults to the number of physical cores.
- Adjusting this parameter beyond matching the number of cores can have an
- impact when using Intel® Xeon Phi™ (Knights Landing) for some models. See
- [TensorFlow* Optimizations on Modern Intel® Architecture](https://software.intel.com/en-us/articles/tensorflow-optimizations-on-modern-intel-architecture)
- for optimal settings.
-
-* **intra_op_parallelism_threads**: Setting this equal to the number of
-  physical cores is recommended. Setting the value to 0, which is the default,
-  results in the value being set to the number of logical cores; this is an
-  alternate option to try for some architectures. This value and
-  `OMP_NUM_THREADS` should be equal.
-
-* **inter_op_parallelism_threads**: Setting this equal to the number of
-  sockets is recommended. Setting the value to 0, which is the default,
-  results in the value being set to the number of logical cores. Both settings
-  are combined in the sketch after this list.
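-
-As a concrete illustration, below is a minimal sketch combining the session
-settings and MKL environment variables discussed above; the core and socket
-counts are placeholders to adjust for the target machine:
-
-```python
-import os
-
-import tensorflow as tf
-
-physical_cores = 16  # placeholder: physical cores in the machine
-num_sockets = 2      # placeholder: CPU sockets in the machine
-
-os.environ["KMP_BLOCKTIME"] = "0"
-os.environ["KMP_AFFINITY"] = "granularity=fine,verbose,compact,1,0"
-os.environ["OMP_NUM_THREADS"] = str(physical_cores)
-
-config = tf.ConfigProto()
-config.intra_op_parallelism_threads = physical_cores
-config.inter_op_parallelism_threads = num_sockets
-sess = tf.Session(config=config)
-```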
-
-### Comparing compiler optimizations
-
-Collected below are performance results running training and inference on
-different types of CPUs on different platforms with various compiler
-optimizations. The models used were ResNet-50
-([arXiv:1512.03385](https://arxiv.org/abs/1512.03385)) and
-InceptionV3 ([arXiv:1512.00567](https://arxiv.org/abs/1512.00567)).
-
-For each test, when the MKL optimization was used the environment variable
-KMP_BLOCKTIME was set to 0 (0ms) and KMP_AFFINITY to
-`granularity=fine,verbose,compact,1,0`.
-
-#### Inference InceptionV3
-
-**Environment**
-
-* Instance Type: AWS EC2 m4.xlarge
-* CPU: Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz (Broadwell)
-* Dataset: ImageNet
-* TensorFlow Version: 1.2.0 RC2
-* Test Script: [tf_cnn_benchmarks.py](https://github.com/tensorflow/benchmarks/blob/mkl_experiment/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py)
-
-**Batch Size: 1**
-
-Command executed for the MKL test:
-
-```bash
-python tf_cnn_benchmarks.py --forward_only=True --device=cpu --mkl=True \
---kmp_blocktime=0 --nodistortions --model=inception3 --data_format=NCHW \
---batch_size=1 --num_inter_threads=1 --num_intra_threads=4 \
---data_dir=<path to ImageNet TFRecords>
-```
-
-| Optimization | Data Format | Images/Sec (step time) | Intra threads | Inter Threads |
-| ------------ | ----------- | ---------------------- | ------------- | ------------- |
-| AVX2         | NHWC        | 7.0 (142ms)            | 4             | 0             |
-| MKL          | NCHW        | 6.6 (152ms)            | 4             | 1             |
-| AVX          | NHWC        | 5.0 (202ms)            | 4             | 0             |
-| SSE3         | NHWC        | 2.8 (361ms)            | 4             | 0             |
-
-**Batch Size: 32**
-
-Command executed for the MKL test:
-
-```bash
-python tf_cnn_benchmarks.py --forward_only=True --device=cpu --mkl=True \
---kmp_blocktime=0 --nodistortions --model=inception3 --data_format=NCHW \
---batch_size=32 --num_inter_threads=1 --num_intra_threads=4 \
---data_dir=<path to ImageNet TFRecords>
-```
-
-| Optimization | Data Format | Images/Sec (step time) | Intra threads | Inter Threads |
-| ------------ | ----------- | ---------------------- | ------------- | ------------- |
-| MKL          | NCHW        | 10.3 (3,104ms)         | 4             | 1             |
-| AVX2         | NHWC        | 7.5 (4,255ms)          | 4             | 0             |
-| AVX          | NHWC        | 5.1 (6,275ms)          | 4             | 0             |
-| SSE3         | NHWC        | 2.8 (11,428ms)         | 4             | 0             |
-
-#### Inference ResNet-50
-
-**Environment**
-
-* Instance Type: AWS EC2 m4.xlarge
-* CPU: Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz (Broadwell)
-* Dataset: ImageNet
-* TensorFlow Version: 1.2.0 RC2
-* Test Script: [tf_cnn_benchmarks.py](https://github.com/tensorflow/benchmarks/blob/mkl_experiment/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py)
-
-**Batch Size: 1**
-
-Command executed for the MKL test:
-
-```bash
-python tf_cnn_benchmarks.py --forward_only=True --device=cpu --mkl=True \
---kmp_blocktime=0 --nodistortions --model=resnet50 --data_format=NCHW \
---batch_size=1 --num_inter_threads=1 --num_intra_threads=4 \
---data_dir=<path to ImageNet TFRecords>
-```
-
-| Optimization | Data Format | Images/Sec (step time) | Intra threads | Inter Threads |
-| ------------ | ----------- | ---------------------- | ------------- | ------------- |
-| AVX2         | NHWC        | 8.8 (113ms)            | 4             | 0             |
-| MKL          | NCHW        | 8.5 (120ms)            | 4             | 1             |
-| AVX          | NHWC        | 6.4 (157ms)            | 4             | 0             |
-| SSE3         | NHWC        | 3.7 (270ms)            | 4             | 0             |
-
-**Batch Size: 32**
-
-Command executed for the MKL test:
-
-```bash
-python tf_cnn_benchmarks.py --forward_only=True --device=cpu --mkl=True \
---kmp_blocktime=0 --nodistortions --model=resnet50 --data_format=NCHW \
---batch_size=32 --num_inter_threads=1 --num_intra_threads=4 \
---data_dir=<path to ImageNet TFRecords>
-```
-
-| Optimization | Data Format | Images/Sec (step time) | Intra threads | Inter Threads |
-| ------------ | ----------- | ---------------------- | ------------- | ------------- |
-| MKL          | NCHW        | 12.4 (2,590ms)         | 4             | 1             |
-| AVX2         | NHWC        | 10.4 (3,079ms)         | 4             | 0             |
-| AVX          | NHWC        | 7.3 (4,416ms)          | 4             | 0             |
-| SSE3         | NHWC        | 4.0 (8,054ms)          | 4             | 0             |
-
-#### Training InceptionV3
-
-**Environment**
-
-* Instance Type: Dedicated AWS EC2 r4.16xlarge (Broadwell)
-* CPU: Intel Xeon E5-2686 v4 (Broadwell) Processors
-* Dataset: ImageNet
-* TensorFlow Version: 1.2.0 RC2
-* Test Script: [tf_cnn_benchmarks.py](https://github.com/tensorflow/benchmarks/blob/mkl_experiment/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py)
-
-Command executed for the MKL test:
-
-```bash
-python tf_cnn_benchmarks.py --device=cpu --mkl=True --kmp_blocktime=0 \
---nodistortions --model=inception3 --data_format=NCHW --batch_size=32 \
---num_inter_threads=2 --num_intra_threads=36 \
---data_dir=<path to ImageNet TFRecords>
-```
-
-Optimization | Data Format | Images/Sec | Intra threads | Inter Threads
------------- | ----------- | ---------- | ------------- | -------------
-MKL | NCHW | 20.8 | 36 | 2
-AVX2 | NHWC | 6.2 | 36 | 0
-AVX | NHWC | 5.7 | 36 | 0
-SSE3 | NHWC | 4.3 | 36 | 0
-
-ResNet and [AlexNet](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf)
-were also run on this configuration, but only in an ad hoc manner. There were
-not enough runs executed to publish a coherent table of results. The incomplete
-results strongly indicated the final result would be similar to the table
-above, with MKL providing 3x+ gains over AVX2.
diff --git a/tensorflow/docs_src/performance/performance_models.md b/tensorflow/docs_src/performance/performance_models.md
deleted file mode 100644
index 66bf684d5b..0000000000
--- a/tensorflow/docs_src/performance/performance_models.md
+++ /dev/null
@@ -1,422 +0,0 @@
-# High-Performance Models
-
-This document and accompanying
-[scripts](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks)
-detail how to build highly scalable models that target a variety of system types
-and network topologies. The techniques in this document utilize some low-level
-TensorFlow Python primitives. In the future, many of these techniques will be
-incorporated into high-level APIs.
-
-## Input Pipeline
-
-The @{$performance_guide$Performance Guide} explains how to identify possible
-input pipeline issues and best practices. We found that `tf.FIFOQueue` and
-`tf.train.queue_runner` could not saturate multiple current-generation GPUs
-when processing large inputs at high samples per second, such as training
-ImageNet with [AlexNet](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf).
-They are implemented with Python threads, whose overhead is too large.
-
-Another approach, which we have implemented in the
-[scripts](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks),
-is to build an input pipeline using the native parallelism in TensorFlow. Our
-implementation is made up of 3 stages:
-
-* I/O reads: Choose and read image files from disk.
-* Image Processing: Decode image records into images, preprocess, and organize
- into mini-batches.
-* CPU-to-GPU Data Transfer: Transfer images from CPU to GPU.
-
-The dominant part of each stage is executed in parallel with the other stages
-using `data_flow_ops.StagingArea`. `StagingArea` is a queue-like operator
-similar to `tf.FIFOQueue`. The difference is that `StagingArea` does not
-guarantee FIFO ordering, but offers simpler functionality and can be executed
-on both CPU and GPU in parallel with other stages. Breaking the input pipeline
-into 3 stages that operate independently in parallel is scalable and takes full
-advantage of large multi-core environments. The rest of this section details
-the stages followed by details about using `data_flow_ops.StagingArea`.
-
-### Parallelize I/O Reads
-
-`data_flow_ops.RecordInput` is used to parallelize reading from disk. Given a
-list of input files representing TFRecords, `RecordInput` continuously reads
-records using background threads. The records are placed into a large internal
-pool; once the pool has loaded at least half of its capacity, it produces
-output tensors.
-
-This op has its own internal threads that are dominated by I/O time and consume
-minimal CPU, which allows it to run smoothly in parallel with the rest of the
-model.
-
-### Parallelize Image Processing
-
-After images are read from `RecordInput` they are passed as tensors to the image
-processing pipeline. To make the image processing pipeline easier to explain,
-assume that the input pipeline is targeting 8 GPUs with a batch size of 256 (32
-per GPU).
-
-256 records are read and processed individually in parallel. This starts with
-256 independent `RecordInput` read ops in the graph. Each read op is followed by
-an identical set of ops for image preprocessing that are considered independent
-and executed in parallel. The image preprocessing ops include operations such as
-image decoding, distortion, and resizing.
-
-Once the images are through preprocessing, they are concatenated together into 8
-tensors each with a batch-size of 32. Rather than using `tf.concat` for this
-purpose, which is implemented as a single op that waits for all the inputs to be
-ready before concatenating them together, `tf.parallel_stack` is used.
-`tf.parallel_stack` allocates an uninitialized tensor as an output, and each
-input tensor is written to its designated portion of the output tensor as soon
-as the input is available.
-
-When all the input tensors are finished, the output tensor is passed along in
-the graph. This effectively hides all the memory latency with the long tail of
-producing all the input tensors.
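-
-A minimal sketch of this stacking step, assuming `read_ops` is the list of 256
-`RecordInput`-backed read ops and `preprocess` is the per-image preprocessing
-function (both hypothetical names):
-
-```python
-import tensorflow as tf
-
-# read_ops and preprocess are placeholders for the 256 independent read
-# ops and the per-image preprocessing described above.
-images = [preprocess(op) for op in read_ops]
-
-# Split into 8 groups of 32 and stack each group. Unlike tf.concat,
-# tf.parallel_stack writes each input into the output tensor as soon as
-# that input is ready, rather than waiting for all inputs.
-batches = [tf.parallel_stack(images[i * 32:(i + 1) * 32]) for i in range(8)]
-```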
-
-### Parallelize CPU-to-GPU Data Transfer
-
-Continuing with the assumption that the target is 8 GPUs with a batch size of
-256 (32 per GPU): once the input images are processed and concatenated together
-by the CPU, we have 8 tensors, each with a batch size of 32.
-
-TensorFlow enables tensors from one device to be used on any other device
-directly. TensorFlow inserts implicit copies to make the tensors available on
-any devices where they are used. The runtime schedules the copy between devices
-to run before the tensors are actually used. However, if the copy cannot finish
-in time, the computation that needs those tensors will stall and result in
-decreased performance.
-
-In this implementation, `data_flow_ops.StagingArea` is used to explicitly
-schedule the copy in parallel. The end result is that when computation starts on
-the GPU, all the tensors are already available.
-
-### Software Pipelining
-
-With all the stages capable of being driven by different processors,
-`data_flow_ops.StagingArea` is used between them so they run in parallel.
-`StagingArea` is a queue-like operator similar to `tf.FIFOQueue` that offers
-simpler functionality and can be executed on both CPU and GPU.
-
-Before the model starts running all the stages, the input pipeline stages are
-warmed up to prime the staging buffers in between with one set of data.
-During each run step, one set of data is read from the staging buffers at
-the beginning of each stage, and one set is pushed at the end.
-
-For example, suppose there are three stages, A, B, and C, with two staging
-areas in between, S1 and S2. During the warm up, we run:
-
-```
-Warm up:
-Step 1: A0
-Step 2: A1 B0
-
-Actual execution:
-Step 3: A2 B1 C0
-Step 4: A3 B2 C1
-Step 5: A4 B3 C2
-```
-
-After the warm up, S1 and S2 each have one set of data in them. For each step of
-the actual execution, one set of data is consumed from each staging area, and
-one set is added to each.
-
-Benefits of using this scheme:
-
-* All stages are non-blocking, since the staging areas always have one set of
- data after the warm up.
-* Each stage can run in parallel since they can all start immediately.
-* The staging buffers have a fixed memory overhead. They will have at most one
- extra set of data.
-* Only a single `session.run()` call is needed to run all stages of the step,
-  which makes profiling and debugging much easier.
-
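-The warm-up pattern above can be written directly with
-`data_flow_ops.StagingArea`. Below is a minimal sketch, where `build_stage_a`,
-`build_stages_b_and_c`, and `num_steps` are hypothetical placeholders for the
-surrounding stages and the step count:
-
-```python
-import tensorflow as tf
-from tensorflow.python.ops import data_flow_ops
-
-images, labels = build_stage_a()  # hypothetical producing stage
-
-stage = data_flow_ops.StagingArea(dtypes=[tf.float32, tf.int32])
-put_op = stage.put([images, labels])
-staged_images, staged_labels = stage.get()
-train_op = build_stages_b_and_c(staged_images, staged_labels)  # hypothetical
-
-with tf.Session() as sess:
-  sess.run(put_op)  # warm up: prime the staging area with one set of data
-  for _ in range(num_steps):
-    # Each step consumes one staged set and pushes the next one.
-    sess.run([train_op, put_op])
-```
-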
-## Best Practices in Building High-Performance Models
-
-Collected below are a couple of additional best practices that can improve
-performance and increase the flexibility of models.
-
-### Build the model with both NHWC and NCHW
-
-Most TensorFlow operations used by a CNN support both NHWC and NCHW data
-formats. On GPU, NCHW is faster, but on CPU, NHWC is sometimes faster.
-
-Building a model to support both data formats keeps the model flexible and
-capable of operating optimally regardless of platform. The benchmark script was
-written to support both NCHW and NHWC. NCHW should always be used when training
-with GPUs. A flexible model can be trained on GPUs using NCHW, with inference
-done on CPU using NHWC and the weights obtained from training.
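-
-One way to keep a model format-agnostic is to thread a single `data_format`
-argument through the layer definitions. A minimal sketch using
-`tf.layers.conv2d` (the `conv_block` helper is illustrative):
-
-```python
-import tensorflow as tf
-
-def conv_block(inputs, filters, data_format):
-  # data_format is 'channels_first' (NCHW) or 'channels_last' (NHWC),
-  # chosen once when the model is constructed.
-  return tf.layers.conv2d(inputs, filters, kernel_size=3, padding='same',
-                          data_format=data_format)
-```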
-
-### Use Fused Batch-Normalization
-
-The default batch-normalization in TensorFlow is implemented as composite
-operations. This is very general, but often leads to suboptimal performance. An
-alternative is to use fused batch-normalization which often has much better
-performance on GPU. Below is an example of using `tf.contrib.layers.batch_norm`
-to implement fused batch-normalization.
-
-```python
-bn = tf.contrib.layers.batch_norm(
-    input_layer, fused=True, data_format='NCHW',
-    scope=scope)
-```
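-
-For models built on `tf.layers`, a similar effect can be achieved with
-`tf.layers.batch_normalization`, which also exposes a `fused` argument (a
-sketch, assuming a recent TF 1.x release):
-
-```python
-bn = tf.layers.batch_normalization(
-    input_layer, axis=1, fused=True)  # axis=1 corresponds to NCHW
-```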
-
-## Variable Distribution and Gradient Aggregation
-
-During training, variable values are updated using aggregated gradients and
-deltas. In the benchmark script, we demonstrate that with the flexible and
-general-purpose TensorFlow primitives, a diverse range of high-performance
-distribution and aggregation schemes can be built.
-
-Three examples of variable distribution and aggregation were included in the
-script:
-
-* `parameter_server` where each replica of the training model reads the
-  variables from a parameter server and updates them independently.
- When each model needs the variables, they are copied over through the
- standard implicit copies added by the TensorFlow runtime. The example
- [script](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks)
- illustrates using this method for local training, distributed synchronous
- training, and distributed asynchronous training.
-* `replicated` places an identical copy of each training variable on each
- GPU. The forward and backward computation can start immediately as the
- variable data is immediately available. Gradients are accumulated across all
- GPUs, and the aggregated total is applied to each GPU's copy of the
- variables to keep them in sync.
-* `distributed_replicated` places an identical copy of the training parameters
- on each GPU along with a master copy on the parameter servers. The forward
- and backward computation can start immediately as the variable data is
- immediately available. Gradients are accumulated across all GPUs on each
- server and then the per-server aggregated gradients are applied to the
- master copy. After all workers do this, each worker updates its copy of the
- variable from the master copy.
-
-Below are additional details about each approach.
-
-### Parameter Server Variables
-
-The most common way trainable variables are managed in TensorFlow models is
-parameter server mode.
-
-In a distributed system, each worker process runs the same model, and parameter
-server processes own the master copies of the variables. When a worker needs a
-variable from a parameter server, it refers to it directly. The TensorFlow
-runtime adds implicit copies to the graph to make the variable value available
-on the computation device that needs it. When a gradient is computed on a
-worker, it is sent to the parameter server that owns the particular variable,
-and the corresponding optimizer is used to update the variable.
-
-There are some techniques to improve throughput:
-
-* The variables are spread among parameter servers based on their size, for
-  load balancing, as shown in the sketch after this list.
-* When each worker has multiple GPUs, gradients are accumulated across the
- GPUs and a single aggregated gradient is sent to the parameter server. This
- reduces the network bandwidth and the amount of work done by the parameter
- servers.
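-
-A minimal sketch of the size-based spreading, using
-`tf.train.replica_device_setter` with the greedy load-balancing strategy from
-`tf.contrib.training` (`num_ps` is a placeholder for the number of
-parameter-server tasks):
-
-```python
-import tensorflow as tf
-
-num_ps = 4  # placeholder: number of parameter-server tasks
-greedy = tf.contrib.training.GreedyLoadBalancingStrategy(
-    num_tasks=num_ps, load_fn=tf.contrib.training.byte_size_load_fn)
-
-with tf.device(tf.train.replica_device_setter(
-    ps_tasks=num_ps, ps_strategy=greedy)):
-  # Variables created here are assigned to the least-loaded ps task,
-  # as measured by each variable's byte size.
-  w = tf.get_variable("w", shape=[4096, 4096])
-```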
-
-For coordinating between workers, a very common mode is async updates, where
-each worker updates the master copy of the variables without synchronizing with
-other workers. In our model, we demonstrate that it is fairly easy to introduce
-synchronization across workers so updates for all workers are finished in one
-step before the next step can start.
-
-The parameter server method can also be used for local training. In this case,
-instead of spreading the master copies of variables across parameter servers,
-they are either on the CPU or spread across the available GPUs.
-
-Due to the simple nature of this setup, this architecture has gained a lot of
-popularity within the community.
-
-This mode can be used in the script by passing
-`--variable_update=parameter_server`.
-
-<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" alt="parameter_server mode in distributed training"
- src="../images/perf_parameter_server_mode_doc.png">
-</div>
-
-### Replicated Variables
-
-In this design, each GPU on the server has its own copy of each variable. The
-values are kept in sync across GPUs by applying the fully aggregated gradient to
-each GPU's copy of the variable.
-
-The variables and data are available at the start of training, so the forward
-pass of training can start immediately. Gradients are aggregated across the
-devices and the fully aggregated gradient is then applied to each local copy.
-
-Gradient aggregation across the server can be done in different ways:
-
-* Using standard TensorFlow operations to accumulate the total on a single
- device (CPU or GPU) and then copy it back to all GPUs.
-* Using NVIDIA® NCCL, described below in the NCCL section.
-
-This mode can be used in the script by passing `--variable_update=replicated`.
-
-### Replicated Variables in Distributed Training
-
-The replicated method for variables can be extended to distributed training. One
-way to do this is like the replicated mode: aggregate the gradients fully across
-the cluster and apply them to each local copy of the variable. This may be shown
-in a future version of the scripts; the scripts do present a different
-variation, described here.
-
-In this mode, in addition to each GPU's copy of the variables, a master copy is
-stored on the parameter servers. As with the replicated mode, training can start
-immediately using the local copies of the variables.
-
-As the gradients of the weights become available, they are sent back to the
-parameter servers and all local copies are updated:
-
-1. All the gradients from the GPUs on the same worker are aggregated together.
-2. Aggregated gradients from each worker are sent to the parameter server that
- owns the variable, where the specified optimizer is used to update the
- master copy of the variable.
-3. Each worker updates its local copy of the variable from the master. In the
- example model, this is done with a cross-replica barrier that waits for all
- the workers to finish updating the variables, and fetches the new variable
- only after the barrier has been released by all replicas. Once the copy
- finishes for all variables, this marks the end of a training step, and a new
- step can start.
-
-Although this sounds similar to the standard use of parameter servers, the
-performance is often better. This is largely because computation can happen
-without any delay, and much of the copy latency of early gradients can be
-hidden by later computation layers.
-
-This mode can be used in the script by passing
-`--variable_update=distributed_replicated`.
-
-
-<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" alt="distributed_replicated mode"
- src="../images/perf_distributed_replicated_mode_doc.png">
-</div>
-
-#### NCCL
-
-In order to broadcast variables and aggregate gradients across different GPUs
-within the same host machine, we can use the default TensorFlow implicit copy
-mechanism.
-
-However, we can instead use the optional NCCL (`tf.contrib.nccl`) support. NCCL
-is an NVIDIA® library that can efficiently broadcast and aggregate data across
-different GPUs. It schedules a cooperating kernel on each GPU that knows how to
-best utilize the underlying hardware topology; this kernel uses a single SM of
-the GPU.
-
-In our experiment, we demonstrate that although NCCL often leads to much faster
-data aggregation by itself, it doesn't necessarily lead to faster training. Our
-hypothesis is that the implicit copies are essentially free since they go to the
-copy engine on GPU, as long as its latency can be hidden by the main computation
-itself. Although NCCL can transfer data faster, it takes one SM away, and adds
-more pressure to the underlying L2 cache. Our results show that for 8-GPUs, NCCL
-often leads to better performance. However, for fewer GPUs, the implicit copies
-often perform better.
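-
-A minimal sketch of NCCL-based gradient aggregation with `tf.contrib.nccl`,
-where `grads_per_gpu` is a hypothetical list holding one gradient tensor per
-GPU, each already placed on its own device:
-
-```python
-from tensorflow.contrib import nccl
-
-# Each output tensor lives on the same GPU as the corresponding input
-# and holds the sum of the inputs across all GPUs.
-summed_grads = nccl.all_sum(grads_per_gpu)
-```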
-
-#### Staged Variables
-
-We further introduce a staged-variable mode where we use staging areas for both
-the variable reads and their updates. Similar to software pipelining of the
-input pipeline, this can hide the data copy latency. If the computation takes
-longer than the copy and aggregation, the copy itself becomes essentially
-free.
-
-The downside is that all the weights read come from the previous training step,
-so it is a different algorithm from SGD, but it is possible to improve its
-convergence by adjusting the learning rate and other hyperparameters.
-
-## Executing the script
-
-This section lists the core command line arguments and a few basic examples for
-executing the main script
-([tf_cnn_benchmarks.py](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py)).
-
-> Note: `tf_cnn_benchmarks.py` uses the config `force_gpu_compatible`,
-> which was introduced after TensorFlow 1.1. Until TensorFlow 1.2 is released,
-> building from source is advised.
-
-#### Base command line arguments
-
-* **`model`**: Model to use, e.g. `resnet50`, `inception3`, `vgg16`, and
- `alexnet`.
-* **`num_gpus`**: Number of GPUs to use.
-* **`data_dir`**: Path to data to process. If not set, synthetic data is used.
- To use ImageNet data use these
- [instructions](https://github.com/tensorflow/models/tree/master/research/inception#getting-started)
- as a starting point.
-* **`batch_size`**: Batch size for each GPU.
-* **`variable_update`**: The method for managing variables: `parameter_server`,
-  `replicated`, `distributed_replicated`, or `independent`.
-* **`local_parameter_device`**: Device to use as parameter server: `cpu` or
- `gpu`.
-
-#### Single instance examples
-
-```bash
-# VGG16 training ImageNet with 8 GPUs using arguments that optimize for
-# Google Compute Engine.
-python tf_cnn_benchmarks.py --local_parameter_device=cpu --num_gpus=8 \
---batch_size=32 --model=vgg16 --data_dir=/home/ubuntu/imagenet/train \
---variable_update=parameter_server --nodistortions
-
-# VGG16 training synthetic ImageNet data with 8 GPUs using arguments that
-# optimize for the NVIDIA DGX-1.
-python tf_cnn_benchmarks.py --local_parameter_device=gpu --num_gpus=8 \
---batch_size=64 --model=vgg16 --variable_update=replicated --use_nccl=True
-
-# VGG16 training ImageNet data with 8 GPUs using arguments that optimize for
-# Amazon EC2.
-python tf_cnn_benchmarks.py --local_parameter_device=gpu --num_gpus=8 \
---batch_size=64 --model=vgg16 --variable_update=parameter_server
-
-# ResNet-50 training ImageNet data with 8 GPUs using arguments that optimize for
-# Amazon EC2.
-python tf_cnn_benchmarks.py --local_parameter_device=gpu --num_gpus=8 \
---batch_size=64 --model=resnet50 --variable_update=replicated --use_nccl=False
-```
-
-#### Distributed command line arguments
-
-* **`ps_hosts`**: Comma-separated list of hosts to use as parameter servers
-  in the format of `<host>:port`, e.g. `10.0.0.2:50000`.
-* **`worker_hosts`**: Comma-separated list of hosts to use as workers in the
-  format of `<host>:port`, e.g. `10.0.0.2:50001`.
-* **`task_index`**: Index of the host in the list of `ps_hosts` or
-  `worker_hosts` being started.
-* **`job_name`**: Type of job, e.g. `ps` or `worker`.
-
-#### Distributed examples
-
-Below is an example of training ResNet-50 on 2 hosts: host_0 (10.0.0.1) and
-host_1 (10.0.0.2). The example uses synthetic data. To use real data pass the
-`--data_dir` argument.
-
-```bash
-# Run the following commands on host_0 (10.0.0.1):
-python tf_cnn_benchmarks.py --local_parameter_device=gpu --num_gpus=8 \
---batch_size=64 --model=resnet50 --variable_update=distributed_replicated \
---job_name=worker --ps_hosts=10.0.0.1:50000,10.0.0.2:50000 \
---worker_hosts=10.0.0.1:50001,10.0.0.2:50001 --task_index=0
-
-python tf_cnn_benchmarks.py --local_parameter_device=gpu --num_gpus=8 \
---batch_size=64 --model=resnet50 --variable_update=distributed_replicated \
---job_name=ps --ps_hosts=10.0.0.1:50000,10.0.0.2:50000 \
---worker_hosts=10.0.0.1:50001,10.0.0.2:50001 --task_index=0
-
-
-# Run the following commands on host_1 (10.0.0.2):
-python tf_cnn_benchmarks.py --local_parameter_device=gpu --num_gpus=8 \
---batch_size=64 --model=resnet50 --variable_update=distributed_replicated \
---job_name=worker --ps_hosts=10.0.0.1:50000,10.0.0.2:50000 \
---worker_hosts=10.0.0.1:50001,10.0.0.2:50001 --task_index=1
-
-python tf_cnn_benchmarks.py --local_parameter_device=gpu --num_gpus=8 \
---batch_size=64 --model=resnet50 --variable_update=distributed_replicated \
---job_name=ps --ps_hosts=10.0.0.1:50000,10.0.0.2:50000 \
---worker_hosts=10.0.0.1:50001,10.0.0.2:50001 --task_index=1
-```
diff --git a/tensorflow/docs_src/performance/quantization.md b/tensorflow/docs_src/performance/quantization.md
deleted file mode 100644
index 4499f5715c..0000000000
--- a/tensorflow/docs_src/performance/quantization.md
+++ /dev/null
@@ -1,253 +0,0 @@
-# Fixed Point Quantization
-
-Quantization techniques store and calculate numbers in more compact formats.
-[TensorFlow Lite](/mobile/tflite/) adds quantization that uses an 8-bit fixed
-point representation.
-
-Since a challenge for modern neural networks is optimizing for high accuracy,
-the priority has been improving accuracy and speed during training. Using
-floating point arithmetic is an easy way to preserve accuracy, and GPUs are
-designed to accelerate these calculations.
-
-However, as more machine learning models are deployed to mobile devices,
-inference efficiency has become a critical issue. Where the computational
-demand for *training* grows with the number of models trained on different
-architectures, the computational demand for *inference* grows in proportion to
-the number of users.
-
-## Quantization benefits
-
-Using 8-bit calculations helps your models run faster and use less power. This
-is especially important for mobile devices and embedded applications that can't
-run floating point code efficiently, for example, Internet of Things (IoT) and
-robotics devices. There are additional opportunities to extend this support to
-more backends and research lower precision networks.
-
-### Smaller file sizes {: .hide-from-toc}
-
-Neural network models require a lot of space on disk. For example, the original
-AlexNet requires over 200 MB in float format, almost all of it for the model's
-millions of weights. Because the weights are all slightly different floating
-point numbers, simple compression formats like zip perform poorly.
-
-Within each layer, weights tend to be normally distributed within a range.
-Quantization can shrink file sizes by storing the minimum and maximum weight
-for each layer, then compressing each weight's float value to an 8-bit integer
-representing the closest real number in a linear set of 256 within the range.
-
-### Faster inference {: .hide-from-toc}
-
-Since calculations are run entirely on 8-bit inputs and outputs, quantization
-reduces the computational resources needed for inference calculations. This is
-more involved, requiring changes to all floating point calculations, but results
-in a large speed-up for inference time.
-
-### Memory efficiency {: .hide-from-toc}
-
-Since fetching 8-bit values only requires 25% of the memory bandwidth of floats,
-more efficient caches avoid bottlenecks for RAM access. In many cases, the power
-consumption for running a neural network is dominated by memory access. The
-savings from using fixed-point 8-bit weights and activations are significant.
-
-Typically, SIMD operations are available that run more operations per clock
-cycle. In some cases, a DSP chip is available that accelerates 8-bit calculations
-resulting in a massive speedup.
-
-## Fixed point quantization techniques
-
-The goal is to use the same precision for weights and activations during both
-training and inference. But an important difference is that training consists of
-a forward pass and a backward pass, while inference only uses a forward pass.
-When we train the model with quantization in the loop, we ensure that the forward
-pass matches precision for both training and inference.
-
-To minimize the loss in accuracy for fully fixed point models (weights and
-activations), train the model with quantization in the loop. This simulates
-quantization in the forward pass of a model so weights tend towards values that
-perform better during quantized inference. The backward pass uses quantized
-weights and activations and models quantization as a straight through estimator.
-(See Bengio et al., [2013](https://arxiv.org/abs/1308.3432))
-
-Additionally, the minimum and maximum values for activations are determined
-during training. This allows a model trained with quantization in the loop to be
-converted to a fixed point inference model with little effort, eliminating the
-need for a separate calibration step.
-
-## Quantization training with TensorFlow
-
-TensorFlow can train models with quantization in the loop. Because training
-requires small gradient adjustments, floating point values are still used. To
-keep models as floating point while adding the quantization error in the training
-loop, @{$array_ops#Fake_quantization$fake quantization} nodes simulate the
-effect of quantization in the forward and backward passes.
-
-Since it's difficult to add these fake quantization operations to all the
-required locations in the model, there's a function available that rewrites the
-training graph. To create a fake quantized training graph:
-
-```
-# Build forward pass of model.
-loss = tf.losses.get_total_loss()
-
-# Call the training rewrite which rewrites the graph in-place with
-# FakeQuantization nodes and folds batchnorm for training. It is
-# often needed to fine tune a floating point model for quantization
-# with this training tool. When training from scratch, quant_delay
-# can be used to activate quantization after training to converge
-# with the float graph, effectively fine-tuning the model.
-tf.contrib.quantize.create_training_graph(quant_delay=2000000)
-
-# Call backward pass optimizer as usual.
-optimizer = tf.train.GradientDescentOptimizer(learning_rate)
-optimizer.minimize(loss)
-```
-
-The rewritten *eval graph* is non-trivially different from the *training graph*
-since the quantization ops affect the batch normalization step. Because of this,
-we've added a separate rewrite for the *eval graph*:
-
-```
-# Build eval model
-logits = tf.nn.softmax_cross_entropy_with_logits_v2(...)
-
-# Call the eval rewrite which rewrites the graph in-place with
-# FakeQuantization nodes and fold batchnorm for eval.
-tf.contrib.quantize.create_eval_graph()
-
-# Save the checkpoint and eval graph proto to disk for freezing
-# and providing to TFLite.
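-# (Assumption: `g` is the eval graph and `sess` the session built
-# earlier; both are expected to be in scope here.)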
-with open(eval_graph_file, 'w') as f:
- f.write(str(g.as_graph_def()))
-saver = tf.train.Saver()
-saver.save(sess, checkpoint_name)
-```
-
-Methods to rewrite the training and eval graphs are an active area of research
-and experimentation. Although rewrites and quantized training might not work or
-improve performance for all models, we are working to generalize these
-techniques.
-
-## Generating fully quantized models
-
-The previously demonstrated after-rewrite eval graph only *simulates*
-quantization. To generate real fixed point computations from a trained
-quantization model, convert it to a fixed point kernel. TensorFlow Lite
-supports this conversion from the graph resulting from `create_eval_graph`.
-
-First, create a frozen graph that will be the input for the TensorFlow Lite
-toolchain:
-
-```
-bazel build tensorflow/python/tools:freeze_graph && \
- bazel-bin/tensorflow/python/tools/freeze_graph \
- --input_graph=eval_graph_def.pb \
- --input_checkpoint=checkpoint \
- --output_graph=frozen_eval_graph.pb --output_node_names=outputs
-```
-
-Provide this to the TensorFlow Lite Optimizing Converter (TOCO) to get a fully
-quantized TensorFlow Lite model:
-
-```
-bazel build tensorflow/contrib/lite/toco:toco && \
- ./bazel-bin/third_party/tensorflow/contrib/lite/toco/toco \
- --input_file=frozen_eval_graph.pb \
- --output_file=tflite_model.tflite \
- --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \
- --inference_type=QUANTIZED_UINT8 \
- --input_shape="1,224,224,3" \
- --input_array=input \
- --output_array=outputs \
- --std_value=127.5 --mean_value=127.5
-```
-
-See the documentation for `tf.contrib.quantize` and
-[TensorFlow Lite](/mobile/tflite/).
-
-## Quantized accuracy
-
-Fixed point [MobileNet](https://arxiv.org/abs/1704.04861) models are released with
-8-bit weights and activations. Using the rewriters, these models achieve the
-Top-1 accuracies listed in Table 1. For comparison, the floating point accuracies
-are listed for the same models. The code used to generate these models
-[is available](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md)
-along with links to all of the pretrained mobilenet_v1 models.
-
-<figure>
- <table>
- <tr>
- <th>Image Size</th>
- <th>Depth</th>
- <th>Top-1 Accuracy:<br>Floating point</th>
- <th>Top-1 Accuracy:<br>Fixed point: 8 bit weights and activations</th>
- </tr>
- <tr><td>128</td><td>0.25</td><td>0.415</td><td>0.399</td></tr>
- <tr><td>128</td><td>0.5</td><td>0.563</td><td>0.549</td></tr>
- <tr><td>128</td><td>0.75</td><td>0.621</td><td>0.598</td></tr>
- <tr><td>128</td><td>1</td><td>0.652</td><td>0.64</td></tr>
- <tr><td>160</td><td>0.25</td><td>0.455</td><td>0.435</td></tr>
- <tr><td>160</td><td>0.5</td><td>0.591</td><td>0.577</td></tr>
- <tr><td>160</td><td>0.75</td><td>0.653</td><td>0.639</td></tr>
- <tr><td>160</td><td>1</td><td>0.68</td><td>0.673</td></tr>
- <tr><td>192</td><td>0.25</td><td>0.477</td><td>0.458</td></tr>
- <tr><td>192</td><td>0.5</td><td>0.617</td><td>0.604</td></tr>
- <tr><td>192</td><td>0.75</td><td>0.672</td><td>0.662</td></tr>
- <tr><td>192</td><td>1</td><td>0.7</td><td>0.69</td></tr>
- <tr><td>224</td><td>0.25</td><td>0.498</td><td>0.482</td></tr>
- <tr><td>224</td><td>0.5</td><td>0.633</td><td>0.622</td></tr>
- <tr><td>224</td><td>0.75</td><td>0.684</td><td>0.679</td></tr>
- <tr><td>224</td><td>1</td><td>0.709</td><td>0.697</td></tr>
- </table>
- <figcaption>
- <b>Table 1</b>: MobileNet Top-1 accuracy on Imagenet Validation dataset.
- </figcaption>
-</figure>
-
-## Representation for quantized tensors
-
-TensorFlow approaches the conversion of floating-point arrays of numbers into
-8-bit representations as a compression problem. The weights and activation
-tensors in trained neural network models tend to have values that are
-distributed across comparatively small ranges (for example, -15 to +15 for
-weights or -500 to 1000 for image model activations), and neural nets tend to
-be robust at handling noise, so the error introduced by quantizing to a small
-set of values keeps the precision of the overall results within an acceptable
-threshold. A chosen representation must also perform fast calculations,
-especially the large matrix multiplications that comprise the bulk of the
-computations while running a model.
-
-Each quantized array is represented with two floats that store the overall
-minimum and maximum values corresponding to the lowest and highest quantized
-value. Each entry in the quantized array represents a float value in that
-range, distributed linearly between the minimum and maximum. For example, with
-a minimum of -10.0, a maximum of 30.0, and an 8-bit array, the quantized values
-represent the following:
-
-<figure>
- <table>
- <tr><th>Quantized</th><th>Float</th></tr>
- <tr><td>0</td><td>-10.0</td></tr>
- <tr><td>128</td><td>10.0</td></tr>
- <tr><td>255</td><td>30.0</td></tr>
- </table>
- <figcaption>
- <b>Table 2</b>: Example quantized value range
- </figcaption>
-</figure>
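-
-A minimal NumPy sketch of this linear mapping, reproducing Table 2 (the helper
-names are illustrative, not part of TensorFlow):
-
-```
-import numpy as np
-
-def quantize(x, x_min, x_max, bits=8):
-  # Linearly map [x_min, x_max] onto the integers [0, 2**bits - 1].
-  scale = (x_max - x_min) / (2**bits - 1)
-  return np.round((x - x_min) / scale).astype(np.uint8), scale
-
-def dequantize(q, x_min, scale):
-  return x_min + q.astype(np.float32) * scale
-
-q, scale = quantize(np.array([-10.0, 10.0, 30.0]), -10.0, 30.0)
-print(q)  # [  0 128 255], matching Table 2
-```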
-
-The advantages of this representation format are:
-
-* It efficiently represents ranges of arbitrary magnitude.
-* The values don't have to be symmetrical.
-* The format represents both signed and unsigned values.
-* The linear spread makes multiplications straightforward.
-
-Alternative techniques use lower bit depths by non-linearly distributing the
-float values across the representation, but currently are more expensive in terms
-of computation time. (See Han et al.,
-[2016](https://arxiv.org/abs/1510.00149).)
-
-The advantage of having a clear definition of the quantized format is that it's
-always possible to convert back and forth from fixed-point to floating-point for
-operations that aren't quantization-ready, or to inspect the tensors for
-debugging.
diff --git a/tensorflow/docs_src/performance/xla/broadcasting.md b/tensorflow/docs_src/performance/xla/broadcasting.md
deleted file mode 100644
index 7018ded53f..0000000000
--- a/tensorflow/docs_src/performance/xla/broadcasting.md
+++ /dev/null
@@ -1,204 +0,0 @@
-# Broadcasting semantics
-
-This document describes how the broadcasting semantics in XLA work.
-
-## What is broadcasting?
-
-Broadcasting is the process of making arrays with different shapes have
-compatible shapes for arithmetic operations. The terminology is borrowed from
-Numpy
-[(broadcasting)](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
-
-Broadcasting may be required for operations between multi-dimensional arrays of
-different ranks, or between multi-dimensional arrays with different but
-compatible shapes. Consider the addition `X+v` where `X` is a matrix (an array
-of rank 2) and `v` is a vector (an array of rank 1). To perform element-wise
-addition, XLA needs to "broadcast" the vector `v` to the same rank as the
-matrix `X`, by replicating `v` a certain number of times. The vector's length
-has to match at least one of the dimensions of the matrix.
-
-For example:
-
- |1 2 3| + |7 8 9|
- |4 5 6|
-
-The matrix's dimensions are (2,3), the vector's are (3). The vector is broadcast
-by replicating it over rows to get:
-
- |1 2 3| + |7 8 9| = |8 10 12|
- |4 5 6| |7 8 9| |11 13 15|
-
-In Numpy, this is called
-[broadcasting](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
-
-## Principles
-
-The XLA language is as strict and explicit as possible, avoiding implicit and
-"magical" features. Such features may make some computations slightly easier to
-define, at the cost of more assumptions baked into user code that will be
-difficult to change in the long term. If necessary, implicit and magical
-features can be added in client-level wrappers.
-
-With regard to broadcasting, XLA requires explicit broadcasting specifications
-on operations between arrays of different ranks. This is different from Numpy,
-which infers the specification when possible.
-
-## Broadcasting a lower-rank array onto a higher-rank array
-
-*Scalars* can always be broadcast over arrays without an explicit specification
-of broadcasting dimensions. An element-wise binary operation between a scalar
-and an array means applying the operation with the scalar for each element in
-the array. For example, adding a scalar to a matrix means producing a matrix
-each element of which is a sum of the scalar with the corresponding input
-matrix's element.
-
- |1 2 3| + 7 = |8 9 10|
- |4 5 6| |11 12 13|
-
-Most broadcasting needs can be captured by using a tuple of dimensions on a
-binary operation. When the inputs to the operation have different ranks, this
-broadcasting tuple specifies which dimension(s) in the **higher-rank** array to
-match with the **lower-rank** array.
-
-Consider the previous example, instead of adding a scalar to a (2,3) matrix, add
-a vector of dimension (3) to a matrix of dimensions (2,3). *Without specifying
-broadcasting, this operation is invalid.* To correctly request matrix-vector
-addition, specify the broadcasting dimension to be (1), meaning the vector's
-dimension is matched to dimension 1 of the matrix. In 2D, if dimension 0 is
-considered as rows and dimension 1 as columns, this means that each element of
-the vector becomes a column of a size matching the number of rows in the matrix:
-
- |7 8 9| ==> |7 8 9|
- |7 8 9|
-
-As a more complex example, consider adding a 3-element vector (dimension (3)) to
-a 3x3 matrix (dimensions (3,3)). There are two ways broadcasting can happen for
-this example:
-
-(1) A broadcasting dimension of 1 can be used. Each vector element becomes a
-column and the vector is duplicated for each row in the matrix.
-
- |7 8 9| ==> |7 8 9|
- |7 8 9|
- |7 8 9|
-
-(2) A broadcasting dimension of 0 can be used. Each vector element becomes a row
-and the vector is duplicated for each column in the matrix.
-
- |7| ==> |7 7 7|
- |8| |8 8 8|
- |9| |9 9 9|
-
-> Note: when adding a 2x3 matrix to a 3-element vector, a broadcasting dimension
-> of 0 is invalid.
-
-The broadcasting dimensions can be a tuple that describes how a smaller rank
-shape is broadcast into a larger rank shape. For example, given a 2x3x4 cuboid
-and a 3x4 matrix, a broadcasting tuple (1,2) means matching the matrix to
-dimensions 1 and 2 of the cuboid.
-
-This type of broadcast is used in the binary ops in `XlaBuilder`, if the
-`broadcast_dimensions` argument is given. For example, see
-[XlaBuilder::Add](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.cc).
-In the XLA source code, this type of broadcasting is sometimes called "InDim"
-broadcasting.
-
-### Formal definition
-
-The broadcasting attribute allows matching a lower-rank array to a higher-rank
-array, by specifying which dimensions of the higher-rank array to match. For
-example, for an array with dimensions MxNxPxQ, a vector with dimension T can be
-matched as follows:
-
- MxNxPxQ
-
- dim 3: T
- dim 2: T
- dim 1: T
- dim 0: T
-
-In each case, T has to be equal to the matching dimension of the higher-rank
-array. The vector's values are then broadcast from the matched dimension to all
-the other dimensions.
-
-To match a TxV matrix onto the MxNxPxQ array, a pair of broadcasting dimensions
-are used:
-
- MxNxPxQ
- dim 2,3: T V
- dim 1,2: T V
- dim 0,3: T V
- etc...
-
-The order of dimensions in the broadcasting tuple has to be the order in which
-the lower-rank array's dimensions are expected to match the higher-rank array's
-dimensions. The first element in the tuple says which dimension in the
-higher-rank array has to match dimension 0 in the lower-rank array. The second
-element for dimension 1, and so on. The order of broadcast dimensions has to be
-strictly increasing. For example, in the previous example it is illegal to match
-V to N and T to P; it is also illegal to match V to both P and N.
-
-## Broadcasting similar-rank arrays with degenerate dimensions
-
-A related broadcasting problem is broadcasting two arrays that have the same
-rank but different dimension sizes. Similarly to Numpy's rules, this is only
-possible when the arrays are *compatible*. Two arrays are compatible when all
-their dimensions are compatible. Two dimensions are compatible if:
-
-* They are equal, or
-* One of them is 1 (a "degenerate" dimension)
-
-When two compatible arrays are encountered, the result shape has the maximum
-among the two inputs at every dimension index.
-
-Examples:
-
-1. (2,1) and (2,3) broadcast to (2,3).
-2. (1,2,5) and (7,2,5) broadcast to (7,2,5)
-3. (7,2,5) and (7,1,5) broadcast to (7,2,5)
-4. (7,2,5) and (7,2,6) are incompatible and cannot be broadcast.
-
-A special case arises, and is also supported, where each of the input arrays has
-a degenerate dimension at a different index. In this case, the result is an
-"outer operation": (2,1) and (1,3) broadcast to (2,3). For more examples,
-consult the [Numpy documentation on
-broadcasting](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
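-
-Numpy follows the same compatibility rules, so the degenerate-dimension cases
-above can be checked quickly in Numpy (a small illustrative sketch):
-
-    import numpy as np
-
-    a = np.arange(2).reshape(2, 1)  # shape (2,1)
-    b = np.arange(3).reshape(1, 3)  # shape (1,3)
-    # Degenerate dimensions at different indices give an "outer operation".
-    print((a + b).shape)  # (2, 3)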
-
-## Broadcast composition
-
-Broadcasting of a lower-rank array to a higher-rank array **and** broadcasting
-using degenerate dimensions can both be performed in the same binary operation.
-For example, a vector of size 4 and a matrix of size 1x2 can be added together
-using broadcast dimensions value of (0):
-
- |1 2 3 4| + [5 6] // [5 6] is a 1x2 matrix, not a vector.
-
-First the vector is broadcast up to rank 2 (matrix) using the broadcast
-dimensions. The single value (0) in the broadcast dimensions indicates that
-dimension zero of the vector matches to dimension zero of the matrix. This
-produces a matrix of size 4xM where the value M is chosen to match the
-corresponding dimension size in the 1x2 array. Therefore, a 4x2 matrix is
-produced:
-
- |1 1| + [5 6]
- |2 2|
- |3 3|
- |4 4|
-
-Then "degenerate dimension broadcasting" broadcasts dimension zero of the 1x2
-matrix to match the corresponding dimension size of the right hand side:
-
- |1 1| + |5 6| |6 7|
- |2 2| + |5 6| = |7 8|
- |3 3| + |5 6| |8 9|
- |4 4| + |5 6| |9 10|
-
-A more complicated example is a matrix of size 1x2 added to an array of size
-4x3x1 using broadcast dimensions of (1, 2). First the 1x2 matrix is broadcast
-up to rank 3 using the broadcast dimensions, producing an intermediate Mx1x2
-array where the dimension size M is determined by the size of the larger
-operand (the 4x3x1 array), giving a 4x1x2 intermediate array. The M is at
-dimension 0 (left-most dimension) because dimensions 1 and 2 are mapped to the
-dimensions of the original 1x2 matrix, as the broadcast dimensions are (1, 2).
-This intermediate array can be added to the 4x3x1 matrix using broadcasting of
-degenerate dimensions to produce a 4x3x2 array result.
diff --git a/tensorflow/docs_src/performance/xla/developing_new_backend.md b/tensorflow/docs_src/performance/xla/developing_new_backend.md
deleted file mode 100644
index 840f6983c2..0000000000
--- a/tensorflow/docs_src/performance/xla/developing_new_backend.md
+++ /dev/null
@@ -1,77 +0,0 @@
-# Developing a new backend for XLA
-
-This preliminary guide is for early adopters that want to easily retarget
-TensorFlow to their hardware in an efficient manner. The guide is not
-step-by-step and assumes knowledge of [LLVM](http://llvm.org),
-[Bazel](https://bazel.build/), and TensorFlow.
-
-XLA provides an abstract interface that a new architecture or accelerator can
-implement to create a backend to run TensorFlow graphs. Retargeting XLA should
-be significantly simpler and more scalable than implementing every existing
-TensorFlow Op for new hardware.
-
-Most implementations will fall into one of the following scenarios:
-
-1. Existing CPU architecture not yet officially supported by XLA, with or
- without an existing [LLVM](http://llvm.org) backend.
-2. Non-CPU-like hardware with an existing LLVM backend.
-3. Non-CPU-like hardware without an existing LLVM backend.
-
-> Note: An LLVM backend can mean either one of the officially released LLVM
-> backends or a custom LLVM backend developed in-house.
-
-## Scenario 1: Existing CPU architecture not yet officially supported by XLA
-
-In this scenario, start by looking at the existing
-[XLA CPU backend](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/cpu/).
-XLA makes it easy to retarget TensorFlow to different CPUs by using LLVM, since
-the main difference between XLA backends for CPUs is the code generated by LLVM.
-Google tests XLA for x64 and ARM64 architectures.
-
-If the hardware vendor has an LLVM backend for their hardware, it is simple to
-link the backend with the LLVM built with XLA. In JIT mode, the XLA CPU backend
-emits code for the host CPU. For ahead-of-time compilation,
-[`xla::AotCompilationOptions`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/compiler.h)
-can provide an LLVM triple to configure the target architecture.
-
-If there is no existing LLVM backend but another kind of code generator exists,
-it should be possible to reuse most of the existing CPU backend.
-
-## Scenario 2: Non-CPU-like hardware with an existing LLVM backend
-
-It is possible to model a new
-[`xla::Compiler`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/compiler.h)
-implementation on the existing
-[`xla::CPUCompiler`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc)
-and
-[`xla::GPUCompiler`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc)
-classes, since these already emit LLVM IR. Depending on the nature of the
-hardware, it is possible that many of the LLVM IR generation aspects will have
-to be changed, but a lot of code can be shared with the existing backends.
-
-A good example to follow is the
-[GPU backend](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/gpu/)
-of XLA. The GPU backend targets a non-CPU-like ISA, and therefore some aspects
-of its code generation are unique to the GPU domain. Other kinds of hardware,
-e.g. DSPs like Hexagon (which has an upstream LLVM backend), can reuse parts of
-the LLVM IR emission logic, but other parts will be unique.
-
-## Scenario 3: Non-CPU-like hardware without an existing LLVM backend
-
-If it is not possible to utilize LLVM, then the best option is to implement a
-new backend for XLA for the desired hardware. This option requires the most
-effort. The classes that need to be implemented are as follows:
-
-* [`StreamExecutor`](https://www.tensorflow.org/code/tensorflow/stream_executor/stream_executor.h):
- For many devices not all methods of `StreamExecutor` are needed. See
- existing `StreamExecutor` implementations for details.
-* [`xla::Compiler`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/compiler.h):
- This class encapsulates the compilation of an HLO computation into an
- `xla::Executable`.
-* [`xla::Executable`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/executable.h):
- This class is used to launch a compiled computation on the platform.
-* [`xla::TransferManager`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/transfer_manager.h):
- This class enables backends to provide platform-specific mechanisms for
- constructing XLA literal data from given device memory handles. In other
- words, it helps encapsulate the transfer of data from the host to the device
- and back.
diff --git a/tensorflow/docs_src/performance/xla/index.md b/tensorflow/docs_src/performance/xla/index.md
deleted file mode 100644
index 8f5de83ea6..0000000000
--- a/tensorflow/docs_src/performance/xla/index.md
+++ /dev/null
@@ -1,98 +0,0 @@
-# XLA Overview
-
-<div style="width:50%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:50%" src="/images/xlalogo.png">
-</div>
-
-> Note: XLA is experimental and considered alpha. Most use cases will not
-> see improvements in performance (speed or decreased memory usage). We have
-> released XLA early so the Open Source Community can contribute to its
-> development, as well as create a path for integration with hardware
-> accelerators.
-
-XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear
-algebra that optimizes TensorFlow computations. The results are improvements in
-speed, memory usage, and portability on server and mobile platforms. Initially,
-most users will not see large benefits from XLA, but are welcome to experiment
-by using XLA via @{$jit$just-in-time (JIT) compilation} or @{$tfcompile$ahead-of-time (AOT) compilation}. Developers targeting new hardware accelerators are
-especially encouraged to try out XLA.
-
-The XLA framework is experimental and in active development. In particular,
-while it is unlikely that the semantics of existing operations will change, it
-is expected that more operations will be added to cover important use cases. The
-team welcomes feedback from the community about missing functionality and
-community contributions via GitHub.
-
-## Why did we build XLA?
-
-We had several objectives for XLA to work with TensorFlow:
-
-* *Improve execution speed.* Compile subgraphs to reduce the execution time of
- short-lived Ops to eliminate overhead from the TensorFlow runtime, fuse
- pipelined operations to reduce memory overhead, and specialize to known
- tensor shapes to allow for more aggressive constant propagation.
-
-* *Improve memory usage.* Analyze and schedule memory usage, in principle
- eliminating many intermediate storage buffers.
-
-* *Reduce reliance on custom Ops.* Remove the need for many custom Ops by
- improving the performance of automatically fused low-level Ops to match the
- performance of custom Ops that were fused by hand.
-
-* *Reduce mobile footprint.* Eliminate the TensorFlow runtime by ahead-of-time
- compiling the subgraph and emitting an object/header file pair that can be
- linked directly into another application. The results can reduce the
- footprint for mobile inference by several orders of magnitude.
-
-* *Improve portability.* Make it relatively easy to write a new backend for
- novel hardware, at which point a large fraction of TensorFlow programs will
- run unmodified on that hardware. This is in contrast with the approach of
- specializing individual monolithic Ops for new hardware, which requires
- TensorFlow programs to be rewritten to make use of those Ops.
-
-## How does XLA work?
-
-The input language to XLA is called "HLO IR", or just HLO (High Level
-Optimizer). The semantics of HLO are described on the
-@{$operation_semantics$Operation Semantics} page. It
-is most convenient to think of HLO as a [compiler
-IR](https://en.wikipedia.org/wiki/Intermediate_representation).
-
-XLA takes graphs ("computations") defined in HLO and compiles them into machine
-instructions for various architectures. XLA is modular in the sense that it is
-easy to slot in an alternative backend to @{$developing_new_backend$target some novel HW architecture}. The CPU backend for x64 and ARM64 as
-well as the NVIDIA GPU backend are in the TensorFlow source tree.
-
-The following diagram shows the compilation process in XLA:
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img src="https://www.tensorflow.org/images/how-does-xla-work.png">
-</div>
-
-XLA comes with several optimizations and analysis passes that are
-target-independent, such as
-[CSE](https://en.wikipedia.org/wiki/Common_subexpression_elimination),
-target-independent operation fusion, and buffer analysis for allocating runtime
-memory for the computation.
-
-After the target-independent step, XLA sends the HLO computation to a backend.
-The backend can perform further HLO-level optimizations, this time with
-target-specific information and needs in mind. For example, the XLA GPU backend may
-perform operation fusion beneficial specifically for the GPU programming model
-and determine how to partition the computation into streams. At this stage,
-backends may also pattern-match certain operations or combinations thereof to
-optimized library calls.
-
-The next step is target-specific code generation. The CPU and GPU backends
-included with XLA use [LLVM](http://llvm.org) for low-level IR, optimization,
-and code-generation. These backends emit the LLVM IR necessary to represent the
-XLA HLO computation in an efficient manner, and then invoke LLVM to emit native
-code from this LLVM IR.
-
-The GPU backend currently supports NVIDIA GPUs via the LLVM NVPTX backend; the
-CPU backend supports multiple CPU ISAs.
-
-## Supported Platforms
-
-XLA currently supports @{$jit$JIT compilation} on x86-64 and NVIDIA GPUs; and
-@{$tfcompile$AOT compilation} for x86-64 and ARM.
diff --git a/tensorflow/docs_src/performance/xla/jit.md b/tensorflow/docs_src/performance/xla/jit.md
deleted file mode 100644
index 7202ef47f7..0000000000
--- a/tensorflow/docs_src/performance/xla/jit.md
+++ /dev/null
@@ -1,169 +0,0 @@
-# Using JIT Compilation
-
-> Note: TensorFlow must be compiled from source to include XLA.
-
-## Why use just-in-time (JIT) compilation?
-
-The TensorFlow/XLA JIT compiler compiles and runs parts of TensorFlow graphs via
-XLA. The benefit of this over the standard TensorFlow implementation is that XLA
-can fuse multiple operators (kernel fusion) into a small number of compiled
-kernels. Fusing operators can reduce memory bandwidth requirements and improve
-performance compared to executing operators one-at-a-time, as the TensorFlow
-executor does.
-
-## Running TensorFlow graphs via XLA
-
-There are two ways to run TensorFlow computations via XLA, either by
-JIT-compiling operators placed on a CPU or GPU device, or by placing operators
-on the `XLA_CPU` or `XLA_GPU` TensorFlow devices. Placing operators directly on
-a TensorFlow XLA device forces the operator to run on that device and is mainly
-used for testing.
-
-> Note: The XLA CPU backend supports intra-op parallelism (i.e. it can shard a
-> single operation across multiple cores) but it does not support inter-op
-> parallelism (i.e. it cannot execute independent operations concurrently across
-> multiple cores). The XLA GPU backend is competitive with the standard
-> TensorFlow implementation, sometimes faster, sometimes slower.
-
-### Turning on JIT compilation
-
-JIT compilation can be turned on at the session level or manually for select
-operations. Both of these approaches are zero-copy --- data does not need to be
-copied when passed between a compiled XLA kernel and a TensorFlow operator
-placed on the same device.
-
-#### Session
-
-Turning on JIT compilation at the session level will result in all possible
-operators being greedily compiled into XLA computations. Each XLA computation
-will be compiled into one or more kernels for the underlying device.
-
-Subject to a few constraints, if there are two adjacent operators in the graph
-that both have XLA implementations, then they will be compiled into a single XLA
-computation.
-
-JIT compilation is turned on at the session level by setting the
-`global_jit_level` config to `tf.OptimizerOptions.ON_1` and passing the config
-during session initialization.
-
-```python
-# Config to turn on JIT compilation
-config = tf.ConfigProto()
-config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
-
-sess = tf.Session(config=config)
-```
-
-> Note: Turning on JIT at the session level will not result in operations being
-> compiled for the CPU. JIT compilation for CPU operations must be done via
-> the manual method documented below.
-
-#### Manual
-
-JIT compilation can also be turned on manually for one or more operators. This
-is done by tagging the operators to compile with the attribute
-`_XlaCompile=true`. The simplest way to do this is via the
-`tf.contrib.compiler.jit.experimental_jit_scope()` scope defined in
-[`tensorflow/contrib/compiler/jit.py`](https://www.tensorflow.org/code/tensorflow/contrib/compiler/jit.py).
-Example usage:
-
-```python
- jit_scope = tf.contrib.compiler.jit.experimental_jit_scope
-
- x = tf.placeholder(np.float32)
- with jit_scope():
- y = tf.add(x, x) # The "add" will be compiled with XLA.
-```
-
-The `_XlaCompile` attribute is currently supported on a best-effort basis. If an
-operator cannot be compiled, TensorFlow will silently fall back to the normal
-implementation.
-
-### Placing operators on XLA devices
-
-Another way to run computations via XLA is to place an operator on a specific
-XLA device. This method is normally only used for testing. Valid targets are
-`XLA_CPU` or `XLA_GPU`.
-
-```python
-with tf.device("/job:localhost/replica:0/task:0/device:XLA_GPU:0"):
- output = tf.add(input1, input2)
-```
-
-Unlike JIT compilation on the standard CPU and GPU devices, these devices make a
-copy of data when it is transferred on and off the device. The extra copy makes
-it expensive to mix XLA and TensorFlow operators in the same graph.
-
-## Tutorial
-
-This tutorial covers training a simple version of MNIST softmax with JIT turned
-on. Currently, session-level JIT, which is what this tutorial uses, only
-supports GPU.
-
-Before starting the tutorial, verify that the LD_LIBRARY_PATH environment
-variable or ldconfig contains `$CUDA_ROOT/extras/CUPTI/lib64`, which contains libraries for
-the CUDA Profiling Tools Interface [(CUPTI)](http://docs.nvidia.com/cuda/cupti/index.html).
-TensorFlow uses CUPTI to pull tracing information from the GPU.
-
-### Step #1: Prepare sample script
-
-Download or move
-[mnist_softmax_xla.py](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/mnist_softmax_xla.py)
-into a folder outside of the TensorFlow source tree.
-
-### Step #2: Run without XLA
-
-Execute the Python script to train the model without XLA.
-
-```shell
-python mnist_softmax_xla.py --xla=''
-```
-
-Using the Chrome Trace Event Profiler (browse to chrome://tracing),
-open the timeline file created when the script finishes: `timeline.ctf.json`.
-The rendered timeline should look similar to the picture below with multiple
-green boxes labeled `MatMul`, possibly across multiple CPUs.
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="https://www.tensorflow.org/images/jit_timeline_gpu.png">
-</div>
-
-### Step #3: Run with XLA
-
-Execute the Python script to train the model with XLA and turn on a debugging
-feature of XLA via an environment variable that outputs the XLA graph.
-
-```shell
-TF_XLA_FLAGS=--xla_generate_hlo_graph=.* python mnist_softmax_xla.py
-```
-
-Open the timeline file created (`timeline.ctf.json`). The rendered timeline
-should look similar to the picture below with one long bar labeled `XlaLaunch`.
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="https://www.tensorflow.org/images/jit_timeline_gpu_xla.png">
-</div>
-
-To understand what is happening in `XlaLaunch`, look at the console output for
-statements similar to the following:
-
-```shell
-computation cluster_0[_XlaCompiledKernel=true,_XlaNumConstantArgs=1].v82 [CPU:
-pipeline start, before inline]: /tmp/hlo_graph_0.dot
-```
-
-The console statements point to the location of `hlo_graph_xx.dot` files that
-contain information about the graph created by XLA. The process that XLA takes
-to fuse Ops can be seen by starting at `hlo_graph_0.dot` and viewing each
-diagram in succession.
-
-To render a `.dot` file into a PNG, install
-[GraphViz](https://www.graphviz.org/download/) and run:
-
-```shell
-dot -Tpng hlo_graph_80.dot -o hlo_graph_80.png
-```
-
-The result will look like the following:
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="https://www.tensorflow.org/images/jit_gpu_xla_graph.png">
-</div>
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
deleted file mode 100644
index 8c9d26fcbb..0000000000
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ /dev/null
@@ -1,2433 +0,0 @@
-# Operation Semantics
-
-The following describes the semantics of operations defined in the
-[`XlaBuilder`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
-interface. Typically, these operations map one-to-one to operations defined in
-the RPC interface in
-[`xla_data.proto`](https://www.tensorflow.org/code/tensorflow/compiler/xla/xla_data.proto).
-
-A note on nomenclature: the generalized data type XLA deals with is an
-N-dimensional array holding elements of some uniform type (such as 32-bit
-float). Throughout the documentation, *array* is used to denote an
-arbitrary-dimensional array. For convenience, special cases have more specific
-and familiar names; for example a *vector* is a 1-dimensional array and a
-*matrix* is a 2-dimensional array.
-
-## AllToAll
-
-See also
-[`XlaBuilder::AllToAll`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Alltoall is a collective operation that sends data from all cores to all cores.
-It has two phases:
-
-1. the scatter phase. On each core, the operand is split into `split_count`
-   blocks along the `split_dimension`, and the blocks are scattered
-   to all cores, e.g., the ith block is sent to the ith core.
-2. the gather phase. Each core concatenates the received blocks along the
- `concat_dimension`.
-
-The participating cores can be configured by:
-
-- `replica_groups`: each ReplicaGroup contains a list of replica ids. If empty,
-  all replicas belong to one group in the order of 0 - (n-1). Alltoall will be
-  applied within subgroups in the specified order. For example, replica
-  groups = {{1,2,3},{4,5,0}} means that an Alltoall will be applied within
-  replicas 1, 2, 3, and in the gather phase, the received blocks will be
-  concatenated in the order of 1, 2, 3; another Alltoall will be applied
-  within replicas 4, 5, 0, and the concatenation order is 4, 5, 0.
-
-Prerequisites:
-
-- The dimension size of the operand on the `split_dimension` is divisible by
-  `split_count`.
-- The operand's shape is not a tuple.
-
-<b> `AllToAll(operand, split_dimension, concat_dimension, split_count,
-replica_groups)` </b>
-
-
-| Arguments | Type | Semantics |
-| ------------------ | --------------------- | ------------------------------- |
-| `operand` | `XlaOp` | n dimensional input array |
-| `split_dimension` | `int64` | A value in the interval `[0, |
-: : : n)` that names the dimension :
-: : : along which the operand is :
-: : : split :
-| `concat_dimension` | `int64` | a value in the interval `[0, |
-: : : n)` that names the dimension :
-: : : along which the split blocks :
-: : : are concatenated :
-| `split_count`      | `int64`               | the number of cores that        |
-:                    :                       : participate in this operation.  :
-:                    :                       : If `replica_groups` is empty,   :
-:                    :                       : this should be the number of    :
-:                    :                       : replicas; otherwise, this       :
-:                    :                       : should be equal to the number   :
-:                    :                       : of replicas in each group.      :
-| `replica_groups`   | `ReplicaGroup` vector | each group contains a list of   |
-:                    :                       : replica ids.                    :
-
-The following shows an example of Alltoall.
-
-```
-XlaBuilder b("alltoall");
-auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
-AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4);
-```
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="../../images/xla/ops_alltoall.png">
-</div>
-
-In this example, there are 4 cores participating in the Alltoall. On each core,
-the operand is split into 4 parts along dimension 1, so each part has shape
-f32[4,4]. The 4 parts are scattered to all cores. Then each core concatenates
-the received parts along dimension 0, in the order of cores 0-3. So the output
-on each core has shape f32[16,4].
-
-## BatchNormGrad
-
-See also
-[`XlaBuilder::BatchNormGrad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
-and [the original batch normalization paper](https://arxiv.org/abs/1502.03167)
-for a detailed description of the algorithm.
-
-Calculates gradients of batch norm.
-
-<b> `BatchNormGrad(operand, scale, mean, variance, grad_output, epsilon, feature_index)` </b>
-
-| Arguments | Type | Semantics |
-| --------------- | ----------------------- | -------------------------------- |
-| `operand` | `XlaOp` | n dimensional array to be |
-: : : normalized (x) :
-| `scale` | `XlaOp` | 1 dimensional array |
-: : : (\\(\gamma\\)) :
-| `mean` | `XlaOp` | 1 dimensional array (\\(\mu\\)) |
-| `variance` | `XlaOp` | 1 dimensional array |
-: : : (\\(\sigma^2\\)) :
-| `grad_output` | `XlaOp` | Gradients passed to |
-: : : `BatchNormTraining` :
-: : : (\\( \nabla y\\)) :
-| `epsilon` | `float` | Epsilon value (\\(\epsilon\\)) |
-| `feature_index` | `int64` | Index to feature dimension in |
-: : : `operand` :
-
-For each feature in the feature dimension (`feature_index` is the index for the
-feature dimension in `operand`), the operation calculates the gradients with
-respect to `operand`, `offset` and `scale` across all the other dimensions. The
-`feature_index` must be a valid index for the feature dimension in `operand`.
-
-The three gradients are defined by the following formulas (assuming a
-4-dimensional tensor as `operand` and with feature dimension index \\(l\\),
-batch size `m` and spatial sizes `w` and `h`):
-
-\\[ \begin{split} c_l&=
-\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h
-\left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sigma^2_l+\epsilon} \right)
-\\\\
-\nabla x_{ijkl} &= \frac{\gamma_{l}}{\sqrt{\sigma^2_{l}+\epsilon}}
-\left( \nabla y_{ijkl} - \mathrm{mean}(\nabla y) - c_l (x_{ijkl} - \mu_{l})
-\right)
-\\\\
-\nabla \gamma_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl}
-\frac{x_{ijkl} - \mu_l}{\sqrt{\sigma^2_{l}+\epsilon}} \right)
-\\\\
-\nabla \beta_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl}
-\end{split} \\]
-
-The inputs `mean` and `variance` represent moment values
-across the batch and spatial dimensions.
-
-The output type is a tuple of three handles:
-
-| Outputs | Type | Semantics |
-| ------------- | ----------------------- | --------------------------------- |
-| `grad_operand` | `XlaOp` | gradient with respect to input |
-: : : `operand` (\\( \nabla x\\)) :
-| `grad_scale` | `XlaOp` | gradient with respect to input |
-: : : `scale` (\\( \nabla \gamma\\)) :
-| `grad_offset` | `XlaOp` | gradient with respect to input |
-:               :                         : `offset` (\\( \nabla \beta\\))   :
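-
-A minimal `XlaBuilder` sketch of invoking this op (the shapes here are
-illustrative assumptions, not part of the specification):
-
-```
-XlaBuilder b("batch_norm_grad");
-// A 4-dimensional input with the feature dimension last (feature_index=3).
-auto x     = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 4, 4, 16}), "x");
-auto scale = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16}), "scale");
-auto mean  = Parameter(&b, 2, ShapeUtil::MakeShape(F32, {16}), "mean");
-auto var   = Parameter(&b, 3, ShapeUtil::MakeShape(F32, {16}), "var");
-auto dy    = Parameter(&b, 4, ShapeUtil::MakeShape(F32, {8, 4, 4, 16}), "dy");
-// Returns a tuple (grad_operand, grad_scale, grad_offset).
-BatchNormGrad(x, scale, mean, var, dy, /*epsilon=*/0.001, /*feature_index=*/3);
-```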
-
-## BatchNormInference
-
-See also
-[`XlaBuilder::BatchNormInference`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
-and [the original batch normalization paper](https://arxiv.org/abs/1502.03167)
-for a detailed description of the algorithm.
-
-Normalizes an array across batch and spatial dimensions.
-
-<b> `BatchNormInference(operand, scale, offset, mean, variance, epsilon, feature_index)` </b>
-
-Arguments | Type | Semantics
---------------- | ------- | ---------------------------------------
-`operand` | `XlaOp` | n dimensional array to be normalized
-`scale` | `XlaOp` | 1 dimensional array
-`offset` | `XlaOp` | 1 dimensional array
-`mean` | `XlaOp` | 1 dimensional array
-`variance` | `XlaOp` | 1 dimensional array
-`epsilon` | `float` | Epsilon value
-`feature_index` | `int64` | Index to feature dimension in `operand`
-
-For each feature in the feature dimension (`feature_index` is the index for the
-feature dimension in `operand`), the operation calculates the mean and variance
-across all the other dimensions and uses the mean and variance to normalize each
-element in `operand`. The `feature_index` must be a valid index for the feature
-dimension in `operand`.
-
-`BatchNormInference` is equivalent to calling `BatchNormTraining` without
-computing `mean` and `variance` for each batch. It uses the input `mean` and
-`variance` instead as estimated values. The purpose of this op is to reduce
-latency in inference, hence the name `BatchNormInference`.
-
-The output is an n-dimensional, normalized array with the same shape as input
-`operand`.
-
-## BatchNormTraining
-
-See also
-[`XlaBuilder::BatchNormTraining`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
-and [`the original batch normalization paper`](https://arxiv.org/abs/1502.03167)
-for a detailed description of the algorithm.
-
-Normalizes an array across batch and spatial dimensions.
-
-<b> `BatchNormTraining(operand, scale, offset, epsilon, feature_index)` </b>
-
-Arguments | Type | Semantics
---------------- | ------- | ----------------------------------------
-`operand` | `XlaOp` | n dimensional array to be normalized (x)
-`scale` | `XlaOp` | 1 dimensional array (\\(\gamma\\))
-`offset` | `XlaOp` | 1 dimensional array (\\(\beta\\))
-`epsilon` | `float` | Epsilon value (\\(\epsilon\\))
-`feature_index` | `int64` | Index to feature dimension in `operand`
-
-For each feature in the feature dimension (`feature_index` is the index for the
-feature dimension in `operand`), the operation calculates the mean and variance
-across all the other dimensions and uses the mean and variance to normalize each
-element in `operand`. The `feature_index` must be a valid index for the feature
-dimension in `operand`.
-
-The algorithm goes as follows for each batch in `operand` \\(x\\) that
-contains `m` elements with `w` and `h` as the size of spatial dimensions
-(assuming `operand` is a 4 dimensional array):
-
-- Calculates batch mean \\(\mu_l\\) for each feature `l` in feature dimension:
-\\(\mu_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h x_{ijkl}\\)
-
-- Calculates batch variance \\(\sigma^2_l\\):
-\\(\sigma^2_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h (x_{ijkl} - \mu_l)^2\\)
-
-- Normalizes, scales and shifts:
-\\(y_{ijkl}=\frac{\gamma_l(x_{ijkl}-\mu_l)}{\sqrt[2]{\sigma^2_l+\epsilon}}+\beta_l\\)
-
-The epsilon value, usually a small number, is added to avoid divide-by-zero errors.
-
-The output type is a tuple of three `XlaOp`s:
-
-| Outputs | Type | Semantics |
-| ------------ | ----------------------- | -------------------------------------|
-| `output` | `XlaOp` | n dimensional array with the same |
-: : : shape as input `operand` (y) :
-| `batch_mean` | `XlaOp` | 1 dimensional array (\\(\mu\\)) |
-| `batch_var` | `XlaOp` | 1 dimensional array (\\(\sigma^2\\)) |
-
-The `batch_mean` and `batch_var` are moments calculated across the batch and
-spatial dimensions using the formulas above.
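-
-A minimal `XlaBuilder` sketch (the shapes are illustrative assumptions):
-
-```
-XlaBuilder b("batch_norm_training");
-// NHWC-style input: batch=8, spatial=28x28, features=64 (feature_index=3).
-auto x      = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 28, 28, 64}), "x");
-auto scale  = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {64}), "scale");
-auto offset = Parameter(&b, 2, ShapeUtil::MakeShape(F32, {64}), "offset");
-// Returns a tuple (output, batch_mean, batch_var).
-BatchNormTraining(x, scale, offset, /*epsilon=*/0.001, /*feature_index=*/3);
-```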
-
-## BitcastConvertType
-
-See also
-[`XlaBuilder::BitcastConvertType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Similar to a `tf.bitcast` in TensorFlow, performs an element-wise bitcast
-operation from a data shape to a target shape. The dimensions must match, and
-the conversion is an element-wise one; e.g. `s32` elements become `f32` elements
-via bitcast routine. Bitcast is implemented as a low-level cast, so machines
-with different floating-point representations will give different results.
-
-<b> `BitcastConvertType(operand, new_element_type)` </b>
-
-Arguments | Type | Semantics
------------------- | --------------- | ---------------------------
-`operand` | `XlaOp` | array of type T with dims D
-`new_element_type` | `PrimitiveType` | type U
-
-The dimensions of the operand and the target shape must match. The bit-width of
-the source and destination element types must be equal. The source
-and destination element types must not be tuples.
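-
-For illustration, here is a sketch that bitcasts `s32` IEEE 754 bit patterns
-to `f32` (the constants are chosen so the expected result is easy to verify):
-
-```
-XlaBuilder b("bitcast");
-// 0x3F800000, 0x40000000, 0x40400000 are the f32 bit patterns of 1.0, 2.0, 3.0.
-auto x = ConstantR1<int32>(&b, {1065353216, 1073741824, 1077936128});
-BitcastConvertType(x, F32);  // yields f32[3] {1.0, 2.0, 3.0}
-```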
-
-## Broadcast
-
-See also
-[`XlaBuilder::Broadcast`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Adds dimensions to an array by duplicating the data in the array.
-
-<b> `Broadcast(operand, broadcast_sizes)` </b>
-
-Arguments | Type | Semantics
------------------ | ------------------- | -------------------------------
-`operand` | `XlaOp` | The array to duplicate
-`broadcast_sizes` | `ArraySlice<int64>` | The sizes of the new dimensions
-
-The new dimensions are inserted on the left, i.e. if `broadcast_sizes` has
-values `{a0, ..., aN}` and the operand shape has dimensions `{b0, ..., bM}` then
-the shape of the output has dimensions `{a0, ..., aN, b0, ..., bM}`.
-
-The new dimensions index into copies of the operand, i.e.
-
-```
-output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
-```
-
-For example, if `operand` is a scalar `f32` with value `2.0f`, and
-`broadcast_sizes` is `{2, 3}`, then the result will be an array with shape
-`f32[2, 3]` and all the values in the result will be `2.0f`.
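-
-The scalar example above, written as a minimal `XlaBuilder` sketch:
-
-```
-XlaBuilder b("broadcast");
-auto two = ConstantR0<float>(&b, 2.0f);
-// Inserts two new major dimensions: f32[2, 3], every element 2.0f.
-Broadcast(two, /*broadcast_sizes=*/{2, 3});
-```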
-
-## Call
-
-See also
-[`XlaBuilder::Call`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Invokes a computation with the given arguments.
-
-<b> `Call(computation, args...)` </b>
-
-| Arguments | Type | Semantics |
-| ------------- | ---------------------- | ----------------------------------- |
-| `computation` | `XlaComputation` | computation of type `T_0, T_1, ..., |
-: : : T_N -> S` with N parameters of :
-: : : arbitrary type :
-| `args` | sequence of N `XlaOp`s | N arguments of arbitrary type |
-
-The arity and types of the `args` must match the parameters of the
-`computation`. It is allowed to have no `args`.
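-
-A minimal sketch of building a computation and calling it (the `add_one`
-computation is a made-up example):
-
-```
-XlaBuilder b("call");
-XlaComputation add_one;
-{
-  XlaBuilder sub("add_one");  // computation of type f32 -> f32
-  auto p = Parameter(&sub, 0, ShapeUtil::MakeShape(F32, {}), "p");
-  Add(p, ConstantR0<float>(&sub, 1.0f));
-  add_one = sub.Build().ValueOrDie();
-}
-Call(&b, add_one, {ConstantR0<float>(&b, 41.0f)});  // yields f32 42.0
-```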
-
-## Clamp
-
-See also
-[`XlaBuilder::Clamp`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Clamps an operand to within the range between a minimum and maximum value.
-
-<b> `Clamp(min, operand, max)` </b>
-
-Arguments | Type | Semantics
---------- | ------- | ---------------
-`min` | `XlaOp` | array of type T
-`operand` | `XlaOp` | array of type T
-`max` | `XlaOp` | array of type T
-
-Given an operand and minimum and maximum values, returns the operand if it is in
-the range between the minimum and maximum, else returns the minimum value if the
-operand is below this range or the maximum value if the operand is above this
-range. That is, `clamp(a, x, b) = min(max(a, x), b)`.
-
-All three arrays must be the same shape. Alternatively, as a restricted form of
-[broadcasting](broadcasting.md), `min` and/or `max` can be a scalar of type `T`.
-
-Example with scalar `min` and `max`:
-
-```
-let operand: s32[3] = {-1, 5, 9};
-let min: s32 = 0;
-let max: s32 = 6;
-==>
-Clamp(min, operand, max) = s32[3]{0, 5, 6};
-```
-
-## Collapse
-
-See also
-[`XlaBuilder::Collapse`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
-and the `tf.reshape` operation.
-
-Collapses dimensions of an array into one dimension.
-
-<b> `Collapse(operand, dimensions)` </b>
-
-Arguments | Type | Semantics
------------- | -------------- | -----------------------------------------------
-`operand` | `XlaOp` | array of type T
-`dimensions` | `int64` vector | in-order, consecutive subset of T's dimensions.
-
-Collapse replaces the given subset of the operand's dimensions by a single
-dimension. The input arguments are an arbitrary array of type T and a
-compile-time-constant vector of dimension indices. The dimension indices must be
-an in-order (low to high dimension numbers), consecutive subset of T's
-dimensions. Thus, {0, 1, 2}, {0, 1}, or {1, 2} are all valid dimension sets, but
-{1, 0} or {0, 2} are not. They are replaced by a single new dimension, in the
-same position in the dimension sequence as those they replace, with the new
-dimension size equal to the product of original dimension sizes. The lowest
-dimension number in `dimensions` is the slowest varying dimension (most major)
-in the loop nest which collapses these dimension, and the highest dimension
-number is fastest varying (most minor). See the `tf.reshape` operator
-if more general collapse ordering is needed.
-
-For example, let v be an array of 24 elements:
-
-```
-let v = f32[4x2x3] {{{10, 11, 12}, {15, 16, 17}},
- {{20, 21, 22}, {25, 26, 27}},
- {{30, 31, 32}, {35, 36, 37}},
- {{40, 41, 42}, {45, 46, 47}}};
-
-// Collapse to a single dimension, leaving one dimension.
-let v012 = Collapse(v, {0,1,2});
-then v012 == f32[24] {10, 11, 12, 15, 16, 17,
- 20, 21, 22, 25, 26, 27,
- 30, 31, 32, 35, 36, 37,
- 40, 41, 42, 45, 46, 47};
-
-// Collapse the two lower dimensions, leaving two dimensions.
-let v01 = Collapse(v, {0,1});
-then v01 == f32[4x6] {{10, 11, 12, 15, 16, 17},
- {20, 21, 22, 25, 26, 27},
- {30, 31, 32, 35, 36, 37},
- {40, 41, 42, 45, 46, 47}};
-
-// Collapse the two higher dimensions, leaving two dimensions.
-let v12 = Collapse(v, {1,2});
-then v12 == f32[8x3] {{10, 11, 12},
- {15, 16, 17},
- {20, 21, 22},
- {25, 26, 27},
- {30, 31, 32},
- {35, 36, 37},
- {40, 41, 42},
- {45, 46, 47}};
-
-```
-
-## Concatenate
-
-See also
-[`XlaBuilder::ConcatInDim`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Concatenate composes an array from multiple array operands. The array is of the
-same rank as each of the input array operands (which must be of the same rank as
-each other) and contains the arguments in the order that they were specified.
-
-<b> `Concatenate(operands..., dimension)` </b>
-
-| Arguments | Type | Semantics |
-| ----------- | --------------------- | -------------------------------------- |
-| `operands` | sequence of N `XlaOp` | N arrays of type T with dimensions |
-: : : [L0, L1, ...]. Requires N >= 1. :
-| `dimension` | `int64` | A value in the interval `[0, N)` that |
-: : : names the dimension to be concatenated :
-: : : between the `operands`. :
-
-With the exception of `dimension` all dimensions must be the same. This is
-because XLA does not support "ragged" arrays. Also note that rank-0 values
-cannot be concatenated (as it's impossible to name the dimension along which the
-concatenation occurs).
-
-1-dimensional example:
-
-```
-Concat({{2, 3}, {4, 5}, {6, 7}}, 0)
->>> {2, 3, 4, 5, 6, 7}
-```
-
-2-dimensional example:
-
-```
-let a = {
- {1, 2},
- {3, 4},
- {5, 6},
-};
-let b = {
- {7, 8},
-};
-Concat({a, b}, 0)
->>> {
- {1, 2},
- {3, 4},
- {5, 6},
- {7, 8},
-}
-```
-
-Diagram:
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="https://www.tensorflow.org/images/ops_concatenate.png">
-</div>
-
-## Conditional
-
-See also
-[`XlaBuilder::Conditional`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-<b> `Conditional(pred, true_operand, true_computation, false_operand,
-false_computation)` </b>
-
-Arguments | Type | Semantics
-------------------- | ---------------- | ---------------------------------
-`pred` | `XlaOp` | Scalar of type `PRED`
-`true_operand` | `XlaOp` | Argument of type `T_0`
-`true_computation` | `XlaComputation` | XlaComputation of type `T_0 -> S`
-`false_operand` | `XlaOp` | Argument of type `T_1`
-`false_computation` | `XlaComputation` | XlaComputation of type `T_1 -> S`
-
-Executes `true_computation` if `pred` is `true`, `false_computation` if `pred`
-is `false`, and returns the result.
-
-The `true_computation` must take in a single argument of type `T_0` and will be
-invoked with `true_operand` which must be of the same type. The
-`false_computation` must take in a single argument of type `T_1` and will be
-invoked with `false_operand` which must be of the same type. The type of the
-returned value of `true_computation` and `false_computation` must be the same.
-
-Note that only one of `true_computation` and `false_computation` will be
-executed depending on the value of `pred`.
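-
-A minimal sketch, assuming scalar `f32` operands (the two branch computations
-are made-up examples):
-
-```
-XlaBuilder b("conditional");
-XlaComputation double_it, negate_it;
-{
-  XlaBuilder tb("double_it");  // T_0 -> S
-  auto p = Parameter(&tb, 0, ShapeUtil::MakeShape(F32, {}), "p");
-  Add(p, p);
-  double_it = tb.Build().ValueOrDie();
-}
-{
-  XlaBuilder fb("negate_it");  // T_1 -> S
-  auto p = Parameter(&fb, 0, ShapeUtil::MakeShape(F32, {}), "p");
-  Neg(p);
-  negate_it = fb.Build().ValueOrDie();
-}
-auto pred = ConstantR0<bool>(&b, true);
-auto arg  = ConstantR0<float>(&b, 3.0f);
-Conditional(pred, arg, double_it, arg, negate_it);  // yields f32 6.0
-```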
-
-## Conv (convolution)
-
-See also
-[`XlaBuilder::Conv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-As ConvWithGeneralPadding, but the padding is specified in a short-hand way as
-either SAME or VALID. SAME padding pads the input (`lhs`) with zeroes so that
-the output has the same shape as the input when not taking striding into
-account. VALID padding simply means no padding.
-
-## ConvWithGeneralPadding (convolution)
-
-See also
-[`XlaBuilder::ConvWithGeneralPadding`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Computes a convolution of the kind used in neural networks. Here, a convolution
-can be thought of as a n-dimensional window moving across a n-dimensional base
-area and a computation is performed for each possible position of the window.
-
-| Arguments | Type | Semantics |
-| --------------------- | -------------------- | ----------------------------- |
-| `lhs` | `XlaOp` | rank n+2 array of inputs |
-| `rhs` | `XlaOp` | rank n+2 array of kernel |
-: : : weights :
-| `window_strides` | `ArraySlice<int64>` | n-d array of kernel strides |
-| `padding` | `ArraySlice< | n-d array of (low, high) |
-: : pair<int64, int64>>` : padding :
-| `lhs_dilation` | `ArraySlice<int64>` | n-d lhs dilation factor array |
-| `rhs_dilation` | `ArraySlice<int64>` | n-d rhs dilation factor array |
-| `feature_group_count` | int64 | the number of feature groups |
-
-Let n be the number of spatial dimensions. The `lhs` argument is a rank n+2
-array describing the base area. This is called the input, even though of course
-the rhs is also an input. In a neural network, these are the input activations.
-The n+2 dimensions are, in this order:
-
-* `batch`: Each coordinate in this dimension represents an independent input
- for which convolution is carried out.
-* `z/depth/features`: Each (y,x) position in the base area has a vector
- associated to it, which goes into this dimension.
-* `spatial_dims`: Describes the `n` spatial dimensions that define the base
- area that the window moves across.
-
-The `rhs` argument is a rank n+2 array describing the convolutional
-filter/kernel/window. The dimensions are, in this order:
-
-* `output-z`: The `z` dimension of the output.
-* `input-z`: The size of this dimension times `feature_group_count` should
- equal the size of the `z` dimension in lhs.
-* `spatial_dims`: Describes the `n` spatial dimensions that define the n-d
- window that moves across the base area.
-
-The `window_strides` argument specifies the stride of the convolutional window
-in the spatial dimensions. For example, if the stride in the first spatial
-dimension is 3, then the window can only be placed at coordinates where the
-first spatial index is divisible by 3.
-
-The `padding` argument specifies the amount of zero padding to be applied to the
-base area. The amount of padding can be negative -- the absolute value of
-negative padding indicates the number of elements to remove from the specified
-dimension before doing the convolution. `padding[0]` specifies the padding for
-dimension `y` and `padding[1]` specifies the padding for dimension `x`. Each
-pair has the low padding as the first element and the high padding as the second
-element. The low padding is applied in the direction of lower indices while the
-high padding is applied in the direction of higher indices. For example, if
-`padding[1]` is `(2,3)` then there will be a padding by 2 zeroes on the left and
-by 3 zeroes on the right in the second spatial dimension. Using padding is
-equivalent to inserting those same zero values into the input (`lhs`) before
-doing the convolution.
-
-The `lhs_dilation` and `rhs_dilation` arguments specify the dilation factor to
-be applied to the lhs and rhs, respectively, in each spatial dimension. If the
-dilation factor in a spatial dimension is d, then d-1 holes are implicitly
-placed between each of the entries in that dimension, increasing the size of the
-array. The holes are filled with a no-op value, which for convolution means
-zeroes.
-
-Dilation of the rhs is also called atrous convolution. For more details, see
-`tf.nn.atrous_conv2d`. Dilation of the lhs is also called transposed
-convolution. For more details, see `tf.nn.conv2d_transpose`.
-
-The `feature_group_count` argument (default value 1) can be used for grouped
-convolutions. `feature_group_count` needs to be a divisor of both the input and
-the output feature dimension. If `feature_group_count` is greater than 1, it
-means that conceptually the `lhs` input feature dimension and the `rhs`
-output feature dimension are split evenly into `feature_group_count` many
-groups, each group consisting of a consecutive subsequence of features. The
-input feature dimension of `rhs` needs to be equal to the `lhs` input feature
-dimension divided by `feature_group_count` (so it already has the size of a
-group of input features). For each `i`, the i-th groups are used together to
-compute one of the `feature_group_count` many separate convolutions. The
-results of these convolutions are concatenated together in the output feature
-dimension.
-
-For depthwise convolution the `feature_group_count` argument would be set to the
-input feature dimension, and the filter would be reshaped from
-`[filter_height, filter_width, in_channels, channel_multiplier]` to
-`[filter_height, filter_width, 1, in_channels * channel_multiplier]`. For more
-details, see `tf.nn.depthwise_conv2d`.
-
-The output shape has these dimensions, in this order:
-
-* `batch`: Same size as `batch` on the input (`lhs`).
-* `z`: Same size as `output-z` on the kernel (`rhs`).
-* `spatial_dims`: One value for each valid placement of the convolutional
- window.
-
-The valid placements of the convolutional window are determined by the strides
-and the size of the base area after padding.
-
-To describe what a convolution does, consider a 2d convolution, and pick some
-fixed `batch`, `z`, `y`, `x` coordinates in the output. Then `(y,x)` is a
-position of a corner of the window within the base area (e.g. the upper left
-corner, depending on how you interpret the spatial dimensions). We now have a 2d
-window, taken from the base area, where each 2d point is associated to a 1d
-vector, so we get a 3d box. From the convolutional kernel, since we fixed the
-output coordinate `z`, we also have a 3d box. The two boxes have the same
-dimensions, so we can take the sum of the element-wise products between the two
-boxes (similar to a dot product). That is the output value.
-
-Note that if `output-z` is e.g., 5, then each position of the window produces 5
-values in the output into the `z` dimension of the output. These values differ
-in what part of the convolutional kernel is used - there is a separate 3d box of
-values used for each `output-z` coordinate. So you could think of it as 5
-separate convolutions with a different filter for each of them.
-
-Here is pseudo-code for a 2d convolution with padding and striding:
-
-```
-for (b, oz, oy, ox) { // output coordinates
- value = 0;
- for (iz, ky, kx) { // kernel coordinates and input z
- iy = oy*stride_y + ky - pad_low_y;
- ix = ox*stride_x + kx - pad_low_x;
- if ((iy, ix) inside the base area considered without padding) {
- value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx);
- }
- }
- output(b, oz, oy, ox) = value;
-}
-```
-
-## ConvertElementType
-
-See also
-[`XlaBuilder::ConvertElementType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Similar to an element-wise `static_cast` in C++, performs an element-wise
-conversion operation from a data shape to a target shape. The dimensions must
-match, and the conversion is an element-wise one; e.g. `s32` elements become
-`f32` elements via an `s32`-to-`f32` conversion routine.
-
-<b> `ConvertElementType(operand, new_element_type)` </b>
-
-Arguments | Type | Semantics
------------------- | --------------- | ---------------------------
-`operand` | `XlaOp` | array of type T with dims D
-`new_element_type` | `PrimitiveType` | type U
-
-The dimensions of the operand and the target shape must match. The source and
-destination element types must not be tuples.
-
-A conversion such as `T=s32` to `U=f32` will perform a normalizing int-to-float
-conversion routine such as round-to-nearest-even.
-
-> Note: The precise float-to-int and vice-versa conversions are currently
-> unspecified, but may become additional arguments to the convert operation in
-> the future. Not all possible conversions have been implemented for all
-> targets.
-
-```
-let a: s32[3] = {0, 1, 2};
-let b: f32[3] = convert(a, f32);
-then b == f32[3]{0.0, 1.0, 2.0}
-```
-
-## CrossReplicaSum
-
-See also
-[`XlaBuilder::CrossReplicaSum`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Computes a sum across replicas.
-
-<b> `CrossReplicaSum(operand, replica_group_ids)` </b>
-
-Arguments           | Type           | Semantics
-------------------- | -------------- | -----------------------------
-`operand`           | `XlaOp`        | Array to sum across replicas.
-`replica_group_ids` | `int64` vector | Group ID for each replica.
-
-The output shape is the same as the input shape. For example, if there are two
-replicas and the operand has the value `(1.0, 2.5)` and `(3.0, 5.25)`
-respectively on the two replicas, then the output value from this op will be
-`(4.0, 7.75)` on both replicas.
-
-`replica_group_ids` identifies the group ID of each replica. The group ID must
-either be empty (all replicas belong to a single group), or contain the same
-number of elements as the number of replicas. For example, if
-`replica_group_ids` = {0, 1, 2, 3, 0, 1, 2, 3} over eight replicas, there are
-four subgroups of replica IDs: {0, 4}, {1, 5}, {2, 6}, and {3, 7}. The size of
-each subgroup *must* be identical, so, for example, using:
-`replica_group_ids` = {0, 1, 2, 0} for four replicas is invalid.
-
-Computing the result of CrossReplicaSum requires having one input from each
-replica, so if one replica executes a CrossReplicaSum node more times than
-another, then the former replica will wait forever. Since the replicas are all
-running the same program, there are not a lot of ways for that to happen, but it
-is possible when a while loop's condition depends on data from infeed and the
-data that is fed in causes the while loop to iterate more times on one replica
-than another.
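-
-A minimal sketch that sums a parameter across all replicas:
-
-```
-XlaBuilder b("cross_replica_sum");
-auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2}), "x");
-CrossReplicaSum(x);  // element-wise sum of x over all replicas
-```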
-
-## CustomCall
-
-See also
-[`XlaBuilder::CustomCall`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Call a user-provided function within a computation.
-
-<b> `CustomCall(target_name, args..., shape)` </b>
-
-| Arguments | Type | Semantics |
-| ------------- | ---------------------- | --------------------------------- |
-| `target_name` | `string` | Name of the function. A call |
-: : : instruction will be emitted which :
-: : : targets this symbol name. :
-| `args` | sequence of N `XlaOp`s | N arguments of arbitrary type, |
-: : : which will be passed to the :
-: : : function. :
-| `shape` | `Shape` | Output shape of the function |
-
-The function signature is the same, regardless of the arity or type of args:
-
-```
-extern "C" void target_name(void* out, void** in);
-```
-
-For example, if CustomCall is used as follows:
-
-```
-let x = f32[2] {1,2};
-let y = f32[2x3] {{10, 20, 30}, {40, 50, 60}};
-
-CustomCall("myfunc", {x, y}, f32[3x3])
-```
-
-Here is an example of an implementation of `myfunc`:
-
-```
-extern "C" void myfunc(void* out, void** in) {
- float (&x)[2] = *static_cast<float(*)[2]>(in[0]);
- float (&y)[2][3] = *static_cast<float(*)[2][3]>(in[1]);
- EXPECT_EQ(1, x[0]);
- EXPECT_EQ(2, x[1]);
- EXPECT_EQ(10, y[0][0]);
- EXPECT_EQ(20, y[0][1]);
- EXPECT_EQ(30, y[0][2]);
- EXPECT_EQ(40, y[1][0]);
- EXPECT_EQ(50, y[1][1]);
- EXPECT_EQ(60, y[1][2]);
- float (&z)[3][3] = *static_cast<float(*)[3][3]>(out);
- z[0][0] = x[1] + y[1][0];
- // ...
-}
-```
-
-The user-provided function must not have side-effects and its execution must be
-idempotent.
-
-> Note: The opaque nature of the user-provided function restricts optimization
-> opportunities for the compiler. Try to express your computation in terms of
-> native XLA ops whenever possible; only use CustomCall as a last resort.
-
-## Dot
-
-See also
-[`XlaBuilder::Dot`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-<b> `Dot(lhs, rhs)` </b>
-
-Arguments | Type | Semantics
---------- | ------- | ---------------
-`lhs` | `XlaOp` | array of type T
-`rhs` | `XlaOp` | array of type T
-
-The exact semantics of this operation depend on the ranks of the operands:
-
-| Input | Output | Semantics |
-| ----------------------- | --------------------- | ----------------------- |
-| vector [n] `dot` vector | scalar | vector dot product |
-: [n] : : :
-| matrix [m x k] `dot` | vector [m] | matrix-vector |
-: vector [k] : : multiplication :
-| matrix [m x k] `dot` | matrix [m x n] | matrix-matrix |
-: matrix [k x n] : : multiplication :
-
-The operation performs sum of products over the last dimension of `lhs` and the
-one-before-last dimension of `rhs`. These are the "contracted" dimensions. The
-contracted dimensions of `lhs` and `rhs` must be of the same size. In practice,
-it can be used to perform dot products between vectors, vector/matrix
-multiplications or matrix/matrix multiplications.
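-
-A minimal sketch of the matrix-vector case:
-
-```
-XlaBuilder b("dot");
-auto lhs = ConstantR2<float>(&b, {{1, 2, 3}, {4, 5, 6}});  // f32[2,3]
-auto rhs = ConstantR1<float>(&b, {10, 20, 30});            // f32[3]
-Dot(lhs, rhs);  // matrix-vector product: f32[2] {140, 320}
-```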
-
-## DotGeneral
-
-See also
-[`XlaBuilder::DotGeneral`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-<b> `DotGeneral(lhs, rhs, dimension_numbers)` </b>
-
-Arguments | Type | Semantics
-------------------- | --------------------- | ---------------
-`lhs` | `XlaOp` | array of type T
-`rhs` | `XlaOp` | array of type T
-`dimension_numbers` | `DotDimensionNumbers` | contracting and batch dimension numbers
-
-As Dot, but allows contracting and batch dimension numbers to be specified for
-both the 'lhs' and 'rhs'.
-
-| DotDimensionNumbers Fields | Type | Semantics
-| --------- | ----------------------- | ---------------
-| 'lhs_contracting_dimensions' | repeated int64 | 'lhs' contracting dimension numbers |
-| 'rhs_contracting_dimensions' | repeated int64 | 'rhs' contracting dimension numbers |
-| 'lhs_batch_dimensions' | repeated int64 | 'lhs' batch dimension numbers |
-| 'rhs_batch_dimensions' | repeated int64 | 'rhs' batch dimension numbers |
-
-DotGeneral performs the sum of products over contracting dimensions specified
-in 'dimension_numbers'.
-
-Associated contracting dimension numbers from the 'lhs' and 'rhs' do not need
-to be the same, but must be listed in the same order in both
-'lhs/rhs_contracting_dimensions' arrays and have the same dimension sizes.
-There must be exactly one contracting dimension on both 'lhs' and 'rhs'.
-
-Example with contracting dimension numbers:
-
-```
-lhs = { {1.0, 2.0, 3.0},
- {4.0, 5.0, 6.0} }
-
-rhs = { {1.0, 1.0, 1.0},
- {2.0, 2.0, 2.0} }
-
-DotDimensionNumbers dnums;
-dnums.add_lhs_contracting_dimensions(1);
-dnums.add_rhs_contracting_dimensions(1);
-
-DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0},
- {15.0, 30.0} }
-```
-
-Associated batch dimension numbers from the 'lhs' and 'rhs' must have the same
-dimension number, must be listed in the same order in both arrays, must
-have the same dimension sizes, and must be ordered before contracting and
-non-contracting/non-batch dimension numbers.
-
-Example with batch dimension numbers (batch size 2, 2x2 matrices):
-
-```
-lhs = { { {1.0, 2.0},
- {3.0, 4.0} },
- { {5.0, 6.0},
- {7.0, 8.0} } }
-
-rhs = { { {1.0, 0.0},
- {0.0, 1.0} },
- { {1.0, 0.0},
- {0.0, 1.0} } }
-
-DotDimensionNumbers dnums;
-dnums.add_lhs_contracting_dimensions(2);
-dnums.add_rhs_contracting_dimensions(1);
-dnums.add_lhs_batch_dimensions(0);
-dnums.add_rhs_batch_dimensions(0);
-
-DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0},
- {3.0, 4.0} },
- { {5.0, 6.0},
- {7.0, 8.0} } }
-```
-
-| Input | Output | Semantics |
-| ----------------------------------- | ----------------- | ---------------- |
-| [b0, m, k] `dot` [b0, k, n] | [b0, m, n] | batch matmul |
-| [b0, b1, m, k] `dot` [b0, b1, k, n] | [b0, b1, m, n] | batch matmul |
-
-It follows that the resulting dimension number starts with the batch dimension,
-then the 'lhs' non-contracting/non-batch dimension, and finally the 'rhs'
-non-contracting/non-batch dimension.
-
-## DynamicSlice
-
-See also
-[`XlaBuilder::DynamicSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-DynamicSlice extracts a sub-array from the input array at dynamic
-`start_indices`. The size of the slice in each dimension is passed in
-`size_indices`, which specify the end point of exclusive slice intervals in each
-dimension: [start, start + size). The shape of `start_indices` must be rank ==
-1, with dimension size equal to the rank of `operand`.
-
-<b> `DynamicSlice(operand, start_indices, size_indices)` </b>
-
-| Arguments | Type | Semantics |
-| --------------- | ------------------- | ----------------------------------- |
-| `operand` | `XlaOp` | N dimensional array of type T |
-| `start_indices` | `XlaOp` | Rank 1 array of N integers |
-: : : containing the starting indices of :
-: : : the slice for each dimension. Value :
-: : : must be greater than or equal to :
-: : : zero. :
-| `size_indices` | `ArraySlice<int64>` | List of N integers containing the |
-: : : slice size for each dimension. Each :
-: : : value must be strictly greater than :
-: : : zero, and start + size must be less :
-: : : than or equal to the size of the :
-: : : dimension to avoid wrapping modulo :
-: : : dimension size. :
-
-The effective slice indices are computed by applying the following
-transformation for each index `i` in `[0, N)` before performing the slice:
-
-```
-start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i])
-```
-
-This ensures that the extracted slice is always in-bounds with respect to the
-operand array. If the slice is in-bounds before the transformation is applied,
-the transformation has no effect.
-
-1-dimensional example:
-
-```
-let a = {0.0, 1.0, 2.0, 3.0, 4.0}
-let s = {2}
-
-DynamicSlice(a, s, {2}) produces:
- {2.0, 3.0}
-```
-
-2-dimensional example:
-
-```
-let b =
- { {0.0, 1.0, 2.0},
- {3.0, 4.0, 5.0},
- {6.0, 7.0, 8.0},
- {9.0, 10.0, 11.0} }
-let s = {2, 1}
-
-DynamicSlice(b, s, {2, 2}) produces:
- { { 7.0, 8.0},
- {10.0, 11.0} }
-```
-
-## DynamicUpdateSlice
-
-See also
-[`XlaBuilder::DynamicUpdateSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-DynamicUpdateSlice generates a result which is the value of the input array
-`operand`, with a slice `update` overwritten at `start_indices`.
-The shape of `update` determines the shape of the sub-array of the result which
-is updated.
-The shape of `start_indices` must be rank == 1, with dimension size equal to
-the rank of `operand`.
-
-<b> `DynamicUpdateSlice(operand, update, start_indices)` </b>
-
-| Arguments | Type | Semantics |
-| --------------- | ------- | ------------------------------------------------ |
-| `operand` | `XlaOp` | N dimensional array of type T |
-| `update` | `XlaOp` | N dimensional array of type T containing the |
-: : : slice update. Each dimension of update shape :
-: : : must be strictly greater than zero, and start + :
-: : : update must be less than or equal to the operand :
-: : : size for each dimension to avoid generating :
-: : : out-of-bounds update indices. :
-| `start_indices` | `XlaOp` | Rank 1 array of N integers containing the |
-: : : starting indices of the slice for each :
-: : : dimension. Value must be greater than or equal :
-: : : to zero. :
-
-The effective slice indices are computed by applying the following
-transformation for each index `i` in `[0, N)` before performing the slice:
-
-```
-start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - update.dimension_size[i])
-```
-
-This ensures that the updated slice is always in-bounds with respect to the
-operand array. If the slice is in-bounds before the transformation is applied,
-the transformation has no effect.
-
-1-dimensional example:
-
-```
-let a = {0.0, 1.0, 2.0, 3.0, 4.0}
-let u = {5.0, 6.0}
-let s = {2}
-
-DynamicUpdateSlice(a, u, s) produces:
- {0.0, 1.0, 5.0, 6.0, 4.0}
-```
-
-2-dimensional example:
-
-```
-let b =
- { {0.0, 1.0, 2.0},
- {3.0, 4.0, 5.0},
- {6.0, 7.0, 8.0},
- {9.0, 10.0, 11.0} }
-let u =
- { {12.0, 13.0},
- {14.0, 15.0},
- {16.0, 17.0} }
-
-let s = {1, 1}
-
-DynamicUpdateSlice(b, u, s) produces:
- { {0.0, 1.0, 2.0},
- {3.0, 12.0, 13.0},
- {6.0, 14.0, 15.0},
- {9.0, 16.0, 17.0} }
-```
-
-## Element-wise binary arithmetic operations
-
-See also
-[`XlaBuilder::Add`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-A set of element-wise binary arithmetic operations is supported.
-
-<b> `Op(lhs, rhs)` </b>
-
-Where `Op` is one of `Add` (addition), `Sub` (subtraction), `Mul`
-(multiplication), `Div` (division), `Rem` (remainder), `Max` (maximum), `Min`
-(minimum), `LogicalAnd` (logical AND), or `LogicalOr` (logical OR).
-
-Arguments | Type | Semantics
---------- | ------- | ----------------------------------------
-`lhs` | `XlaOp` | left-hand-side operand: array of type T
-`rhs` | `XlaOp` | right-hand-side operand: array of type T
-
-The arguments' shapes have to be either similar or compatible. See the
-@{$broadcasting$broadcasting} documentation about what it means for shapes to
-be compatible. The result of an operation has a shape which is the result of
-broadcasting the two input arrays. In this variant, operations between arrays of
-different ranks are *not* supported, unless one of the operands is a scalar.
-
-When `Op` is `Rem`, the sign of the result is taken from the dividend, and the
-absolute value of the result is always less than the divisor's absolute value.
-
-An alternative variant with different-rank broadcasting support exists for these
-operations:
-
-<b> `Op(lhs, rhs, broadcast_dimensions)` </b>
-
-Where `Op` is the same as above. This variant of the operation should be used
-for arithmetic operations between arrays of different ranks (such as adding a
-matrix to a vector).
-
-The additional `broadcast_dimensions` operand is a slice of integers used to
-expand the rank of the lower-rank operand up to the rank of the higher-rank
-operand. `broadcast_dimensions` maps the dimensions of the lower-rank shape to
-the dimensions of the higher-rank shape. The unmapped dimensions of the expanded
-shape are filled with dimensions of size one. Degenerate-dimension broadcasting
-then broadcasts the shapes along these degenerate dimensions to equalize the
-shapes of both operands. The semantics are described in detail on the
-@{$broadcasting$broadcasting page}.
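-
-For example, a sketch of adding a vector to each row of a matrix via
-`broadcast_dimensions`:
-
-```
-XlaBuilder b("add_broadcast");
-auto m = ConstantR2<float>(&b, {{1, 2}, {3, 4}});  // f32[2,2]
-auto v = ConstantR1<float>(&b, {10, 100});         // f32[2]
-// Maps v's dimension 0 to m's dimension 1, so v is added to each row:
-// {{11, 102}, {13, 104}}.
-Add(m, v, /*broadcast_dimensions=*/{1});
-```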
-
-## Element-wise comparison operations
-
-See also
-[`XlaBuilder::Eq`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-A set of standard element-wise binary comparison operations is supported. Note
-that standard IEEE 754 floating-point comparison semantics apply when comparing
-floating-point types.
-
-<b> `Op(lhs, rhs)` </b>
-
-Where `Op` is one of `Eq` (equal-to), `Ne` (not equal-to), `Ge`
-(greater-or-equal-than), `Gt` (greater-than), `Le` (less-or-equal-than), `Lt`
-(less-than).
-
-Arguments | Type | Semantics
---------- | ------- | ----------------------------------------
-`lhs` | `XlaOp` | left-hand-side operand: array of type T
-`rhs` | `XlaOp` | right-hand-side operand: array of type T
-
-The arguments' shapes have to be either similar or compatible. See the
-@{$broadcasting$broadcasting} documentation about what it means for shapes to
-be compatible. The result of an operation has a shape which is the result of
-broadcasting the two input arrays with the element type `PRED`. In this variant,
-operations between arrays of different ranks are *not* supported, unless one of
-the operands is a scalar.
-
-An alternative variant with different-rank broadcasting support exists for these
-operations:
-
-<b> `Op(lhs, rhs, broadcast_dimensions)` </b>
-
-Where `Op` is the same as above. This variant of the operation should be used
-for comparison operations between arrays of different ranks (such as comparing
-a matrix to a vector).
-
-The additional `broadcast_dimensions` operand is a slice of integers specifying
-the dimensions to use for broadcasting the operands. The semantics are described
-in detail on the @{$broadcasting$broadcasting page}.
-
-## Element-wise unary functions
-
-XlaBuilder supports these element-wise unary functions:
-
-<b>`Abs(operand)`</b> Element-wise abs `x -> |x|`.
-
-<b>`Ceil(operand)`</b> Element-wise ceil `x -> ⌈x⌉`.
-
-<b>`Cos(operand)`</b> Element-wise cosine `x -> cos(x)`.
-
-<b>`Exp(operand)`</b> Element-wise natural exponential `x -> e^x`.
-
-<b>`Floor(operand)`</b> Element-wise floor `x -> ⌊x⌋`.
-
-<b>`IsFinite(operand)`</b> Tests whether each element of `operand` is finite,
-i.e., is not positive or negative infinity, and is not `NaN`. Returns an array
-of `PRED` values with the same shape as the input, where each element is `true`
-if and only if the corresponding input element is finite.
-
-<b>`Log(operand)`</b> Element-wise natural logarithm `x -> ln(x)`.
-
-<b>`LogicalNot(operand)`</b> Element-wise logical not `x -> !(x)`.
-
-<b>`Neg(operand)`</b> Element-wise negation `x -> -x`.
-
-<b>`Sign(operand)`</b> Element-wise sign operation `x -> sgn(x)` where
-
-$$\text{sgn}(x) = \begin{cases} -1 & x < 0\\ 0 & x = 0\\ 1 & x > 0 \end{cases}$$
-
-using the comparison operator of the element type of `operand`.
-
-<b>`Tanh(operand)`</b> Element-wise hyperbolic tangent `x -> tanh(x)`.
-
-
-Arguments | Type | Semantics
---------- | ------- | ---------------------------
-`operand` | `XlaOp` | The operand to the function
-
-The function is applied to each element in the `operand` array, resulting in an
-array with the same shape. It is allowed for `operand` to be a scalar (rank 0).
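-
-For example, a minimal sketch applying `Sign` element-wise:
-
-```
-XlaBuilder b("sign");
-auto x = ConstantR1<float>(&b, {-2.5f, 0.0f, 3.0f});
-Sign(x);  // yields f32[3] {-1.0, 0.0, 1.0}
-```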
-
-## Gather
-
-The XLA gather operation stitches together several slices (each slice at a
-potentially different runtime offset) of an input tensor into an output tensor.
-
-### General Semantics
-
-See also
-[`XlaBuilder::Gather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-For a more intuitive description, see the "Informal Description" section below.
-
-<b> `gather(operand, gather_indices, index_vector_dim, output_window_dims, elided_window_dims, window_bounds, gather_dims_to_operand_dims)` </b>
-
-|Arguments | Type | Semantics |
-|----------------- | ----------------------- | --------------------------------|
-|`operand` | `XlaOp` | The tensor we’re gathering |
-: : : from. :
-|`gather_indices` | `XlaOp` | Tensor containing the starting |
-: : : indices of the slices we're :
-: : : stitching together into the :
-: : : output tensor. :
-|`index_vector_dim` | `int64` | The dimension in |
-: : : `gather_indices` that contains :
-: : : the starting indices. :
-|`output_window_dims` | `ArraySlice<int64>` | The set of dimensions in the |
-: : : output shape that are _window :
-: : : dimensions_ (defined below). :
-: : : Not all window dimensions may :
-: : : be present in the output shape. :
-|`elided_window_dims` | `ArraySlice<int64>` | The set of _window dimensions_ |
-: : : that are not present in the output shape. :
-: : : `window_bounds[i]` must be `1` for all `i` :
-: : : in `elided_window_dims`. :
-|`window_bounds` | `ArraySlice<int64>` | `window_bounds[i]` is the bounds |
-: : : for window dimension `i`. This includes :
-: : : both the window dimensions that are :
-: : : explicitly part of the output shape (via :
-: : : `output_window_dims`) and the window :
-: : : dimensions that are elided (via :
-: : : `elided_window_dims`). :
-|`gather_dims_to_operand_dims` | `ArraySlice<int64>` | A dimension map (the |
-: : : array is interpreted as mapping `i` to :
-: : : `gather_dims_to_operand_dims[i]`) from :
-: : : the gather indices in `gather_indices` to :
-: : : the operand index space. It has to be :
-: : : one-to-one and total. :
-
-For every index `Out` in the output tensor, we compute two things (more
-precisely described later):
-
- - An index into `gather_indices.rank` - `1` dimensions of `gather_indices`,
- which gives us a starting index of a slice, _operand slice_, in the operand
- tensor. These `gather_indices.rank` - `1` dimensions are all the dimensions
- in `gather_indices` except `index_vector_dim`.
-
- - A _window index_ that has the same rank as the operand. This index is
- composed of the values in `Out` at dimensions `output_window_dims`, embedded
- with zeroes according to `elided_window_dims`.
-
-The _window index_ is the relative index of the element in _operand slice_ that
-should be present in the output at index `Out`.
-
-The output is a tensor of rank `output_window_dims.size` +
-`gather_indices.rank` - `1`. Additionally, as a shorthand, we define
-`output_gather_dims` of type `ArraySlice<int64>` as the set of dimensions in
-the output shape but not in `output_window_dims`, in ascending order. E.g. if
-the output tensor has rank `5` and `output_window_dims` is {`2`, `4`}, then
-`output_gather_dims` is {`0`, `1`, `3`}.
-
-If `index_vector_dim` is equal to `gather_indices.rank` we implicitly
-consider `gather_indices` to have a trailing `1` dimension (i.e. if
-`gather_indices` was of shape `[6,7]` and `index_vector_dim` is `2` then
-we implicitly consider the shape of `gather_indices` to be `[6,7,1]`).
-
-The bounds for the output tensor along dimension `i` are computed as follows:
-
- 1. If `i` is present in `output_gather_dims` (i.e. is equal to
- `output_gather_dims[k]` for some `k`) then we pick the corresponding
- dimension bounds out of `gather_indices.shape`, skipping
- `index_vector_dim` (i.e. pick `gather_indices.shape.dims`[`k`] if `k`
- < `index_vector_dim` and `gather_indices.shape.dims`[`k`+`1`]
- otherwise).
- 2. If `i` is present in `output_window_dims` (i.e. equal to
- `output_window_dims`[`k`] for some `k`) then we pick the corresponding
- bound out of `window_bounds` after accounting for `elided_window_dims`
- (i.e. we pick `adjusted_window_bounds`[`k`] where `adjusted_window_bounds`
- is `window_bounds` with the bounds at indices `elided_window_dims`
- removed).
-
-The operand index `In` corresponding to an output index `Out` is computed as
-follows:
-
- 1. Let `G` = { `Out`[`k`] for `k` in `output_gather_dims` }. Use `G` to slice
- out vector `S` such that `S`[`i`] = `gather_indices`[Combine(`G`, `i`)]
- where Combine(A, b) inserts b at position `index_vector_dim` into A.
- Note that this is well defined even if `G` is empty -- if `G` is empty then
- `S` = `gather_indices`.
- 2. Create an index, `S`<sub>`in`</sub>, into `operand` using `S` by
- scattering `S` using the `gather_dims_to_operand_dims` map
- (`S`<sub>`in`</sub> is the starting indices for _operand slice_ mentioned
- above). More precisely:
- 1. `S`<sub>`in`</sub>[`gather_dims_to_operand_dims`[`k`]] = `S`[`k`] if `k` <
- `gather_dims_to_operand_dims.size`.
- 2. `S`<sub>`in`</sub>[`_`] = `0` otherwise.
- 3. Create an index `W`<sub>`in`</sub> into `operand` by scattering the indices
- at the output window dimensions in `Out` according to
- the `elided_window_dims` set (`W`<sub>`in`</sub> is the _window index_
- mentioned above). More precisely:
- 1. `W`<sub>`in`</sub>[`window_dims_to_operand_dims`(`k`)] = `Out`[`k`] if
- `k` < `output_window_dims.size` (`window_dims_to_operand_dims` is
- defined below).
- 2. `W`<sub>`in`</sub>[`_`] = `0` otherwise.
- 4. `In` is `W`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
- addition.
-
-`window_dims_to_operand_dims` is the monotonic function with domain [`0`,
-`output_window_dims.size`) and range [`0`, `operand.rank`) \
-`elided_window_dims`. So if, e.g., `output_window_dims.size` is `4`,
-`operand.rank` is `6` and `elided_window_dims` is {`0`, `2`} then
-`window_dims_to_operand_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}.
-
-### Informal Description and Examples
-
-`index_vector_dim` is set to `gather_indices.rank` - `1` in all of the
-examples that follow. More interesting values for `index_vector_dim`
-do not change the operation fundamentally, but make the visual representation
-more cumbersome.
-
-To get an intuition on how all of the above fits together, let's look at an
-example that gathers 5 slices of shape `[8,6]` from a `[16,11]` tensor. The
-position of a slice into the `[16,11]` tensor can be represented as an index
-vector of shape `S64[2]`, so the set of 5 positions can be represented as an
-`S64[5,2]` tensor.
-
-The behavior of the gather operation can then be depicted as an index
-transformation that takes [`G`,`W`<sub>`0`</sub>,`W`<sub>`1`</sub>], an index in
-the output shape, and maps it to an element in the input tensor in the following
-way:
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="../../images/ops_xla_gather_0.svg">
-</div>
-
-We first select an (`X`,`Y`) vector from the gather indices tensor using `G`.
-The element in the output tensor at index
-[`G`,`W`<sub>`0`</sub>,`W`<sub>`1`</sub>] is then the element in the input
-tensor at index [`X`+`W`<sub>`0`</sub>,`Y`+`W`<sub>`1`</sub>].
-
-`window_bounds` is `[8,6]`, which decides the range of W<sub>`0`</sub> and
-W<sub>`1`</sub>, and this in turn decides the bounds of the slice.
-
-This gather operation acts as a batch dynamic slice with `G` as the batch
-dimension.
-
-The gather indices may be multidimensional. For instance, a more general
-version of the example above using a "gather indices" tensor of shape `[4,5,2]`
-would translate indices like this:
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="../../images/ops_xla_gather_1.svg">
-</div>
-
-Again, this acts as a batch dynamic slice, with `G`<sub>`0`</sub> and
-`G`<sub>`1`</sub> as the batch dimensions. The window bounds are still `[8,6]`.
-
-The gather operation in XLA generalizes the informal semantics outlined above in
-the following ways:
-
- 1. We can configure which dimensions in the output shape are the window
- dimensions (dimensions containing `W`<sub>`0`</sub>, `W`<sub>`1`</sub> in
- the last example). The output gather dimensions (dimensions containing
- `G`<sub>`0`</sub>, `G`<sub>`1`</sub> in the last example) are defined to be
- the output dimensions that are not window dimensions.
-
- 2. The number of output window dimensions explicitly present in the output
- shape may be smaller than the input rank. These "missing" dimensions, which
- are listed explicitly as `elided_window_dims`, must have a window bound of
- `1`. Since they have a window bound of `1` the only valid index for them is
- `0` and eliding them does not introduce ambiguity.
-
- 3. The slice extracted from the "Gather Indices" tensor ((`X`, `Y`) in the last
- example) may have fewer elements than the input tensor rank, and an explicit
- mapping dictates how the index should be expanded to have the same rank as
- the input.
-
-As a final example, we use (2) and (3) to implement `tf.gather_nd`:
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="../../images/ops_xla_gather_2.svg">
-</div>
-
-`G`<sub>`0`</sub> and `G`<sub>`1`</sub> are used to slice out a starting index
-from the gather indices tensor as usual, except the starting index has only one
-element, `X`. Similarly, there is only one output window index with the value
-`W`<sub>`0`</sub>. However, before being used as indices into the input tensor,
-these are expanded in accordance with "Gather Index Mapping"
-(`gather_dims_to_operand_dims` in the formal description) and "Window Mapping"
-(`window_dims_to_operand_dims` in the formal description) into
-[`X`,`0`] and [`0`,`W`<sub>`0`</sub>] respectively, adding up to
-[`X`,`W`<sub>`0`</sub>]. In other words, the output index
-[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`W`<sub>`0`</sub>] maps to the input index
-[`GatherIndices`[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`0`],`X`] which gives us
-the semantics for `tf.gather_nd`.
-
-`window_bounds` for this case is `[1,11]`. Intuitively this means that every
-index `X` in the gather indices tensor picks an entire row and the result is the
-concatenation of all these rows.
-
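-As a minimal runnable sketch of the first example above (illustrative variable
-names and index values; a loop nest for clarity, not how XLA executes gather):
-
-```python
-import numpy as np
-
-# Gather 5 slices of shape [8,6] from a [16,11] operand. g walks the output
-# gather dimension; (x, y) is the starting index of each operand slice, and
-# (w0, w1) is the window index.
-operand = np.arange(16 * 11).reshape(16, 11)
-gather_indices = np.array([[0, 0], [2, 1], [4, 2], [6, 3], [8, 5]])
-window_bounds = (8, 6)
-output = np.empty((5,) + window_bounds, dtype=operand.dtype)
-for g in range(5):
-    x, y = gather_indices[g]
-    for w0 in range(window_bounds[0]):
-        for w1 in range(window_bounds[1]):
-            output[g, w0, w1] = operand[x + w0, y + w1]
-```
-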
-## GetTupleElement
-
-See also
-[`XlaBuilder::GetTupleElement`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Indexes into a tuple with a compile-time-constant value.
-
-The value must be a compile-time-constant so that shape inference can determine
-the type of the resulting value.
-
-This is analogous to `std::get<int N>(t)` in C++. Conceptually:
-
-```
-let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
-let s: s32 = 5;
-let t: (f32[10], s32) = tuple(v, s);
-let element_1: s32 = gettupleelement(t, 1); // Inferred shape matches s32.
-```
-
-See also `tf.tuple`.
-
-## Infeed
-
-See also
-[`XlaBuilder::Infeed`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-<b> `Infeed(shape)` </b>
-
-| Argument | Type | Semantics |
-| -------- | ------- | ----------------------------------------------------- |
-| `shape` | `Shape` | Shape of the data read from the Infeed interface. The |
-: : : layout field of the shape must be set to match the :
-: : : layout of the data sent to the device; otherwise its :
-: : : behavior is undefined. :
-
-Reads a single data item from the implicit Infeed streaming interface of the
-device, interpreting the data as the given shape and its layout, and returns an
-`XlaOp` of the data. Multiple Infeed operations are allowed in a
-computation, but there must be a total order among the Infeed operations. For
-example, two Infeeds in the code below have a total order since there is a
-dependency between the while loops.
-
-```
-result1 = while (condition, init = init_value) {
- Infeed(shape)
-}
-
-result2 = while (condition, init = result1) {
- Infeed(shape)
-}
-```
-
-Nested tuple shapes are not supported. For an empty tuple shape, the Infeed
-operation is effectively a no-op and proceeds without reading any data from the
-Infeed of the device.
-
-> Note: We plan to allow multiple Infeed operations without a total order, in
-> which case the compiler will provide information about how the Infeed
-> operations are serialized in the compiled program.
-
-## Iota
-
-<b> `Iota()` </b>
-
-Builds a constant literal on device rather than performing a potentially large
-host transfer. Creates a rank 1 tensor of values starting at zero and
-incrementing by one.
-
-Arguments | Type | Semantics
------------------- | --------------- | ---------------------------
-`type` | `PrimitiveType` | The element type `U` of the produced tensor.
-`size` | `int64` | The number of elements in the tensor.
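-
-For instance, a host-side numpy analogue of the hypothetical invocation
-`Iota(S32, 4)` (shown only to illustrate the semantics):
-
-```python
-import numpy as np
-
-# Rank-1 tensor of 4 elements, starting at zero and incrementing by one.
-iota = np.arange(4, dtype=np.int32)   # -> [0 1 2 3]
-```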
-
-## Map
-
-See also
-[`XlaBuilder::Map`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-<b> `Map(operands..., computation, dimensions)` </b>
-
-| Arguments | Type | Semantics |
-| ----------------- | ---------------------- | ------------------------------ |
-| `operands` | sequence of N `XlaOp`s | N arrays of types T_0..T_{N-1} |
-| `computation` | `XlaComputation` | computation of type `T_0, T_1, |
-: : : ..., T_{N + M -1} -> S` with N :
-: : : parameters of type T and M of :
-: : : arbitrary type :
-| `dimensions` | `int64` array | array of map dimensions |
-
-Applies a scalar function over the given `operands` arrays, producing an array
-of the same dimensions where each element is the result of the mapped function
-applied to the corresponding elements in the input arrays.
-
-The mapped function is an arbitrary computation with the restriction that it has
-N inputs of scalar type `T` and a single output with type `S`. The output has
-the same dimensions as the operands except that the element type T is replaced
-with S.
-
-For example: `Map(op1, op2, op3, computation, par1)` maps `elem_out <-
-computation(elem1, elem2, elem3, par1)` at each (multi-dimensional) index in the
-input arrays to produce the output array.
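-
-As a rough host-side sketch (illustrative names; assuming a scalar computation
-`f(a, b) = a * b + 1` and two equally-shaped operands):
-
-```python
-import numpy as np
-
-op1 = np.array([1.0, 2.0, 3.0])
-op2 = np.array([10.0, 20.0, 30.0])
-# Map applies f elementwise, so the result has the operands' shape.
-mapped = op1 * op2 + 1.0   # -> [11. 41. 91.]
-```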
-
-## Pad
-
-See also
-[`XlaBuilder::Pad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-<b> `Pad(operand, padding_value, padding_config)` </b>
-
-| Arguments | Type | Semantics |
-| ---------------- | --------------- | --------------------------------------- |
-| `operand` | `XlaOp` | array of type `T` |
-| `padding_value` | `XlaOp` | scalar of type `T` to fill in the added |
-: : : padding :
-| `padding_config` | `PaddingConfig` | padding amount on both edges (low, |
-: : : high) and between the elements of each :
-: : : dimension :
-
-Expands the given `operand` array by padding around the array as well as between
-the elements of the array with the given `padding_value`. `padding_config`
-specifies the amount of edge padding and the interior padding for each
-dimension.
-
-`PaddingConfig` is a repeated field of `PaddingConfigDimension`, which contains
-three fields for each dimension: `edge_padding_low`, `edge_padding_high`, and
-`interior_padding`. `edge_padding_low` and `edge_padding_high` specify the
-amount of padding added at the low-end (next to index 0) and the high-end (next
-to the highest index) of each dimension respectively. The amount of edge padding
-can be negative -- the absolute value of negative padding indicates the number
-of elements to remove from the specified dimension. `interior_padding` specifies
-the amount of padding added between any two elements in each dimension. Interior
-padding occurs logically before edge padding, so in the case of negative edge
-padding elements are removed from the interior-padded operand. This operation is
-a no-op if the edge padding pairs are all (0, 0) and the interior padding values
-are all 0. The figure below shows examples of different `edge_padding` and
-`interior_padding` values for a two-dimensional array.
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="https://www.tensorflow.org/images/ops_pad.png">
-</div>
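-
-A minimal 1-D sketch of these semantics (an illustrative helper, not an XLA
-API; negative edge padding is omitted for brevity):
-
-```python
-import numpy as np
-
-def pad_1d(operand, padding_value, low, high, interior):
-    out = []
-    for i, v in enumerate(operand):
-        if i > 0:
-            out.extend([padding_value] * interior)  # interior padding first
-        out.append(v)
-    return np.array([padding_value] * low + out + [padding_value] * high)
-
-pad_1d([1.0, 2.0, 3.0], 0.0, 1, 2, 1)
-# -> [0. 1. 0. 2. 0. 3. 0. 0.]
-```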
-
-## Recv
-
-See also
-[`XlaBuilder::Recv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-<b> `Recv(shape, channel_handle)` </b>
-
-| Arguments | Type | Semantics |
-| ---------------- | --------------- | ------------------------------------ |
-| `shape` | `Shape` | shape of the data to receive |
-| `channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair |
-
-Receives data of the given shape from a `Send` instruction in another
-computation that shares the same channel handle. Returns an
-`XlaOp` for the received data.
-
-The client API of the `Recv` operation represents synchronous communication.
-However, the instruction is internally decomposed into 2 HLO instructions
-(`Recv` and `RecvDone`) to enable asynchronous data transfers. See also
-[`HloInstruction::CreateRecv` and `HloInstruction::CreateRecvDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h).
-
-<b>`Recv(const Shape& shape, int64 channel_id)`</b>
-
-Allocates resources required to receive data from a `Send` instruction with the
-same channel_id. Returns a context for the allocated resources, which is used
-by a following `RecvDone` instruction to wait for the completion of the data
-transfer. The context is a tuple of {receive buffer (shape), request identifier
-(U32)} and it can only be used by a `RecvDone` instruction.
-
-<b> `RecvDone(HloInstruction context)` </b>
-
-Given a context created by a `Recv` instruction, waits for the data transfer to
-complete and returns the received data.
-
-## Reduce
-
-See also
-[`XlaBuilder::Reduce`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Applies a reduction function to one or more arrays in parallel.
-
-<b> `Reduce(operands..., init_values..., computation, dimensions)` </b>
-
-Arguments | Type | Semantics
-------------- | --------------------- | ---------------------------------------
-`operands` | Sequence of N `XlaOp` | N arrays of types `T_0, ..., T_N`.
-`init_values` | Sequence of N `XlaOp` | N scalars of types `T_0, ..., T_N`.
-`computation` | `XlaComputation` | computation of type
- : : `T_0, ..., T_N, T_0, ..., T_N -> Collate(T_0, ..., T_N)`
-`dimensions` | `int64` array | unordered array of dimensions to reduce
-
-Where:
-* N is required to be greater than or equal to 1.
-* All input arrays must have the same dimensions.
-* If `N = 1`, `Collate(T)` is `T`.
-* If `N > 1`, `Collate(T_0, ..., T_N)` is a tuple of `N` elements with the
-  corresponding types `T_0, ..., T_N`.
-
-The output of the op is `Collate(Q_0, ..., Q_N)` where `Q_i` is an array of type
-`T_i`, the dimensions of which are described below.
-
-This operation reduces one or more dimensions of each input array into scalars.
-The rank of each returned array is `rank(operand) - len(dimensions)`.
-`init_value` is the initial value used for every reduction and may be inserted
-anywhere during computation by the back-end. In most cases, `init_value` is an
-identity of the reduction function (for example, 0 for addition). The applied
-`computation` is always passed the `init_value` on the left-hand side.
-
-The evaluation order of the reduction function is arbitrary and may be
-non-deterministic. Therefore, the reduction function should not be overly
-sensitive to reassociation.
-
-Some reduction functions like addition are not strictly associative for floats.
-However, if the range of the data is limited, floating-point addition is close
-enough to being associative for most practical uses. It is possible to conceive
-of some completely non-associative reductions, however, and these will produce
-incorrect or unpredictable results in XLA reductions.
-
-As an example, when reducing across the one dimension of a single 1D array with
-values [10, 11, 12, 13] using reduction function `f` (this is `computation`),
-the result could be computed as
-
-`f(10, f(11, f(12, f(init_value, 13))))`
-
-but there are also many other possibilities, e.g.
-
-`f(init_value, f(f(10, f(init_value, 11)), f(f(init_value, 12), f(init_value, 13))))`
-
-The following is a rough pseudo-code example of how reduction could be
-implemented, using summation as the reduction computation with an initial value
-of 0.
-
-```python
-result_shape <- remove all dims in dimensions from operand_shape
-
-# Iterate over all elements in result_shape. The number of r's here is equal
-# to the rank of the result
-for r0 in range(result_shape[0]), r1 in range(result_shape[1]), ...:
- # Initialize this result element
- result[r0, r1...] <- 0
-
- # Iterate over all the reduction dimensions
- for d0 in range(dimensions[0]), d1 in range(dimensions[1]), ...:
- # Increment the result element with the value of the operand's element.
- # The index of the operand's element is constructed from all ri's and di's
- # in the right order (by construction ri's and di's together index over the
- # whole operand shape).
- result[r0, r1...] += operand[ri... di]
-```
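-
-A runnable transcription of this pseudocode (an illustrative helper, not an
-XLA API):
-
-```python
-import numpy as np
-from itertools import product
-
-def reduce_sum(operand, dimensions):
-    kept = [d for d in range(operand.ndim) if d not in dimensions]
-    result = np.zeros([operand.shape[d] for d in kept], operand.dtype)
-    # Iterate over the result shape, then accumulate over the reduced dims.
-    for r in product(*[range(operand.shape[d]) for d in kept]):
-        for c in product(*[range(operand.shape[d]) for d in dimensions]):
-            index = [0] * operand.ndim
-            for axis, i in zip(kept, r):
-                index[axis] = i
-            for axis, i in zip(dimensions, c):
-                index[axis] = i
-            result[r] += operand[tuple(index)]
-    return result
-```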
-
-Here's an example of reducing a 2D array (matrix). The shape has rank 2,
-dimension 0 of size 2 and dimension 1 of size 3:
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:35%" src="https://www.tensorflow.org/images/ops_2d_matrix.png">
-</div>
-
-Results of reducing dimensions 0 or 1 with an "add" function:
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:35%" src="https://www.tensorflow.org/images/ops_reduce_from_2d_matrix.png">
-</div>
-
-Note that both reduction results are 1D arrays. The diagram shows one as a
-column and the other as a row merely for visual convenience.
-
-For a more complex example, here is a 3D array. Its rank is 3, dimension 0 of
-size 4, dimension 1 of size 2 and dimension 2 of size 3. For simplicity, the
-values 1 to 6 are replicated across dimension 0.
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:35%" src="https://www.tensorflow.org/images/ops_reduce_from_3d_matrix.png">
-</div>
-
-Similarly to the 2D example, we can reduce just one dimension. If we reduce
-dimension 0, for example, we get a rank-2 array where all values across
-dimension 0 were folded into a scalar:
-
-```text
-| 4 8 12 |
-| 16 20 24 |
-```
-
-If we reduce dimension 2, we also get a rank-2 array where all values across
-dimension 2 were folded into a scalar:
-
-```text
-| 6 15 |
-| 6 15 |
-| 6 15 |
-| 6 15 |
-```
-
-Note that the relative order between the remaining dimensions in the input is
-preserved in the output, but some dimensions may get assigned new numbers (since
-the rank changes).
-
-We can also reduce multiple dimensions. Add-reducing dimensions 0 and 1 produces
-the 1D array `| 20 28 36 |`.
-
-Reducing the 3D array over all its dimensions produces the scalar `84`.
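-
-These worked examples can be cross-checked with numpy (rebuilding the
-replicated 3D array from above):
-
-```python
-import numpy as np
-
-operand = np.tile(np.array([[1, 2, 3], [4, 5, 6]]), (4, 1, 1))  # shape [4,2,3]
-np.sum(operand, axis=0)        # [[ 4  8 12] [16 20 24]]
-np.sum(operand, axis=2)        # [[6 15]] repeated 4 times
-np.sum(operand, axis=(0, 1))   # [20 28 36]
-np.sum(operand)                # 84
-```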
-
-When `N > 1`, reduce function application is slightly more complex, as it is
-applied simultaneously to all inputs. For example, consider the following
-reduction function, which can be used to compute the max and the argmax of a
-1-D tensor in parallel:
-
-```
-f: (Float, Int, Float, Int) -> Float, Int
-f(max, argmax, value, index):
-  if value >= max:
- return (value, index)
- else:
- return (max, argmax)
-```
-
-For 1-D Input arrays `V = Float[N], K = Int[N]`, and init values
-`I_V = Float, I_K = Int`, the result `f_(N-1)` of reducing across the only
-input dimension is equivalent to the following recursive application:
-```
-f_0 = f(I_V, I_K, V_0, K_0)
-f_1 = f(f_0.first, f_0.second, V_1, K_1)
-...
-f_(N-1) = f(f_(N-2).first, f_(N-2).second, V_(N-1), K_(N-1))
-```
-
-Applying this reduction to an array of values, and an array of sequential
-indices (i.e. iota), will co-iterate over the arrays, and return a tuple
-containing the maximal value and the matching index.
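-
-A sequential host-side sketch of this variadic reduction (illustrative values
-only):
-
-```python
-# f as defined above, comparing against the running max.
-def f(mx, amx, value, index):
-    return (value, index) if value >= mx else (mx, amx)
-
-V = [3.0, 9.0, 1.0, 9.5]                 # values
-acc = (float('-inf'), -1)                # (I_V, I_K) init values
-for k, v in enumerate(V):                # K is an iota over the indices
-    acc = f(acc[0], acc[1], v, k)
-# acc == (9.5, 3): the maximal value and its index
-```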
-
-## ReducePrecision
-
-See also
-[`XlaBuilder::ReducePrecision`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Models the effect of converting floating-point values to a lower-precision
-format (such as IEEE-FP16) and back to the original format. The number of
-exponent and mantissa bits in the lower-precision format can be specified
-arbitrarily, although not all bit sizes are supported on all hardware
-implementations.
-
-<b> `ReducePrecision(operand, exponent_bits, mantissa_bits)` </b>
-
-Arguments | Type | Semantics
---------------- | ------- | -------------------------------------------------
-`operand` | `XlaOp` | array of floating-point type `T`.
-`exponent_bits` | `int32` | number of exponent bits in lower-precision format
-`mantissa_bits` | `int32` | number of mantissa bits in lower-precision format
-
-The result is an array of type `T`. The input values are rounded to the nearest
-value representable with the given number of mantissa bits (using "ties to even"
-semantics), and any values that exceed the range specified by the number of
-exponent bits are clamped to positive or negative infinity. `NaN` values are
-retained, although they may be converted to canonical `NaN` values.
-
-The lower-precision format must have at least one exponent bit (in order to
-distinguish a zero value from an infinity, since both have a zero mantissa), and
-must have a non-negative number of mantissa bits. The number of exponent or
-mantissa bits may exceed the corresponding value for type `T`; the corresponding
-portion of the conversion is then simply a no-op.
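-
-For instance, casting F32 values to IEEE-FP16 and back in numpy approximates
-`ReducePrecision` with 5 exponent and 10 mantissa bits (a host-side
-illustration, not the XLA implementation):
-
-```python
-import numpy as np
-
-x = np.array([1.0, 1.0 + 2.0**-10, 65520.0], dtype=np.float32)
-y = x.astype(np.float16).astype(np.float32)
-# y == [1.0, 1.0009766, inf]: the middle value is exactly representable in
-# fp16, while 65520.0 exceeds the fp16 range and is clamped to infinity.
-```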
-
-## ReduceWindow
-
-See also
-[`XlaBuilder::ReduceWindow`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Applies a reduction function to all elements in each window of the input
-multi-dimensional array, producing an output multi-dimensional array with the
-same number of elements as the number of valid positions of the window. A
-pooling layer can be expressed as a `ReduceWindow`. Similar to
-[`Reduce`](#reduce), the applied `computation` is always passed the `init_value`
-on the left-hand side.
-
-<b> `ReduceWindow(operand, init_value, computation, window_dimensions,
-window_strides, padding)` </b>
-
-| Arguments | Type | Semantics |
-| ------------------- | ------------------- | -------------------------------- |
-| `operand` | `XlaOp` | N dimensional array containing |
-: : : elements of type T. This is the :
-: : : base area on which the window is :
-: : : placed. :
-| `init_value` | `XlaOp` | Starting value for the |
-: : : reduction. See [Reduce](#reduce) :
-: : : for details. :
-| `computation` | `XlaComputation` | Reduction function of type `T, T |
-: : : -> T`, to apply to all elements :
-: : : in each window :
-| `window_dimensions` | `ArraySlice<int64>` | array of integers for window |
-: : : dimension values :
-| `window_strides` | `ArraySlice<int64>` | array of integers for window |
-: : : stride values :
-| `padding` | `Padding` | padding type for window |
-: : : (Padding\:\:kSame or :
-: : : Padding\:\:kValid) :
-
-The code and figure below show an example of using `ReduceWindow`. The input is
-a matrix of size [4x6], and both window_dimensions and window_stride_dimensions
-are [2x3].
-
-```
-// Create a computation for the reduction (maximum).
-XlaComputation max;
-{
- XlaBuilder builder(client_, "max");
- auto y = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y");
- auto x = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "x");
- builder.Max(y, x);
- max = builder.Build().ConsumeValueOrDie();
-}
-
-// Create a ReduceWindow computation with the max reduction computation.
-XlaBuilder builder(client_, "reduce_window_2x3");
-auto shape = ShapeUtil::MakeShape(F32, {4, 6});
-auto input = builder.Parameter(0, shape, "input");
-builder.ReduceWindow(
- input, *max,
- /*init_val=*/builder.ConstantLiteral(LiteralUtil::MinValue(F32)),
- /*window_dimensions=*/{2, 3},
- /*window_stride_dimensions=*/{2, 3},
- Padding::kValid);
-```
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:35%" src="https://www.tensorflow.org/images/ops_reduce_window.png">
-</div>
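-
-Because the stride equals the window size here, the windows do not overlap and
-the same result can be cross-checked in numpy (an illustration only):
-
-```python
-import numpy as np
-
-inp = np.arange(24, dtype=np.float32).reshape(4, 6)
-# Split into non-overlapping 2x3 windows, then take the max of each window.
-out = inp.reshape(2, 2, 2, 3).max(axis=(1, 3))   # shape [2, 2]
-```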
-
-A stride of 1 in a dimension specifies that the position of a window in the
-dimension is 1 element away from its adjacent window. To specify that
-no windows overlap with each other, window_stride_dimensions should be equal to
-window_dimensions. The figure below illustrates the use of two different stride
-values. Padding is applied to each dimension of the input and the calculations
-are the same as though the input came in with the dimensions it has after
-padding.
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:75%" src="https://www.tensorflow.org/images/ops_reduce_window_stride.png">
-</div>
-
-The evaluation order of the reduction function is arbitrary and may be
-non-deterministic. Therefore, the reduction function should not be overly
-sensitive to reassociation. See the discussion about associativity in the
-context of [`Reduce`](#reduce) for more details.
-
-## Reshape
-
-See also
-[`XlaBuilder::Reshape`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
-and the [`Collapse`](#collapse) operation.
-
-Reshapes the dimensions of an array into a new configuration.
-
-<b> `Reshape(operand, new_sizes)` </b>
-<b> `Reshape(operand, dimensions, new_sizes)` </b>
-
-Arguments | Type | Semantics
------------- | -------------- | ---------------------------------------
-`operand` | `XlaOp` | array of type T
-`dimensions` | `int64` vector | order in which dimensions are collapsed
-`new_sizes` | `int64` vector | vector of sizes of new dimensions
-
-Conceptually, reshape first flattens an array into a one-dimensional vector of
-data values, and then refines this vector into a new shape. The input arguments
-are an arbitrary array of type T, a compile-time-constant vector of dimension
-indices, and a compile-time-constant vector of dimension sizes for the result.
-The values in the `dimensions` vector, if given, must be a permutation of all of
-T's dimensions; the default if not given is `{0, ..., rank - 1}`. The order of
-the dimensions in `dimensions` is from slowest-varying dimension (most major) to
-fastest-varying dimension (most minor) in the loop nest which collapses the
-input array into a single dimension. The `new_sizes` vector determines the size
-of the output array. The value at index 0 in `new_sizes` is the size of
-dimension 0, the value at index 1 is the size of dimension 1, and so on. The
-product of the `new_sizes` dimensions must equal the product of the operand's
-dimension sizes. When refining the collapsed array into the multidimensional
-array defined by `new_sizes`, the dimensions in `new_sizes` are ordered from
-slowest varying (most major) to fastest varying (most minor).
-
-For example, let v be an array of 24 elements:
-
-```
-let v = f32[4x2x3] {{{10, 11, 12}, {15, 16, 17}},
- {{20, 21, 22}, {25, 26, 27}},
- {{30, 31, 32}, {35, 36, 37}},
- {{40, 41, 42}, {45, 46, 47}}};
-
-In-order collapse:
-let v012_24 = Reshape(v, {0,1,2}, {24});
-then v012_24 == f32[24] {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
- 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47};
-
-let v012_83 = Reshape(v, {0,1,2}, {8,3});
-then v012_83 == f32[8x3] {{10, 11, 12}, {15, 16, 17},
- {20, 21, 22}, {25, 26, 27},
- {30, 31, 32}, {35, 36, 37},
- {40, 41, 42}, {45, 46, 47}};
-
-Out-of-order collapse:
-let v021_24 = Reshape(v, {1,2,0}, {24});
-then v021_24 == f32[24] {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
- 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47};
-
-let v021_83 = Reshape(v, {1,2,0}, {8,3});
-then v021_83 == f32[8x3] {{10, 20, 30}, {40, 11, 21},
- {31, 41, 12}, {22, 32, 42},
- {15, 25, 35}, {45, 16, 26},
- {36, 46, 17}, {27, 37, 47}};
-
-
-let v021_262 = Reshape(v, {1,2,0}, {2,6,2});
-then v021_262 == f32[2x6x2] {{{10, 20}, {30, 40},
- {11, 21}, {31, 41},
- {12, 22}, {32, 42}},
- {{15, 25}, {35, 45},
- {16, 26}, {36, 46},
- {17, 27}, {37, 47}}};
-```
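-
-In numpy terms (a sketch assuming row-major flattening), these examples reduce
-to a transpose followed by a reshape:
-
-```python
-import numpy as np
-
-def xla_reshape(operand, dimensions, new_sizes):
-    # Collapse in the order given by `dimensions`, then refine row-major.
-    return operand.transpose(dimensions).reshape(new_sizes)
-
-v = np.array([[[10, 11, 12], [15, 16, 17]],
-              [[20, 21, 22], [25, 26, 27]],
-              [[30, 31, 32], [35, 36, 37]],
-              [[40, 41, 42], [45, 46, 47]]])
-xla_reshape(v, (1, 2, 0), (24,))
-# -> [10 20 30 40 11 21 31 41 12 22 32 42 15 25 35 45 16 26 36 46 17 27 37 47]
-```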
-
-As a special case, reshape can transform a single-element array to a scalar and
-vice versa. For example,
-
-```
-Reshape(f32[1x1] {{5}}, {0,1}, {}) == 5;
-Reshape(5, {}, {1,1}) == f32[1x1] {{5}};
-```
-
-## Rev (reverse)
-
-See also
-[`XlaBuilder::Rev`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-<b>`Rev(operand, dimensions)`</b>
-
-Arguments | Type | Semantics
------------- | ------------------- | ---------------------
-`operand` | `XlaOp` | array of type T
-`dimensions` | `ArraySlice<int64>` | dimensions to reverse
-
-Reverses the order of elements in the `operand` array along the specified
-`dimensions`, generating an output array of the same shape. Each element of the
-operand array at a multidimensional index is stored into the output array at a
-transformed index. The multidimensional index is transformed by reversing the
-index in each dimension to be reversed (i.e., if a dimension of size N is one of
-the reversing dimensions, its index i is transformed into N - 1 - i).
-
-One use for the `Rev` operation is to reverse the convolution weight array along
-the two window dimensions during the gradient computation in neural networks.
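-
-In numpy terms (an illustration only), `Rev` corresponds to flipping the
-chosen axes:
-
-```python
-import numpy as np
-
-x = np.arange(6).reshape(2, 3)
-np.flip(x, axis=(0, 1))
-# [[5 4 3]
-#  [2 1 0]]
-```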
-
-## RngNormal
-
-See also
-[`XlaBuilder::RngNormal`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Constructs an output of a given shape with random numbers generated following
-the $$N(\mu, \sigma)$$ normal distribution. The parameters $$\mu$$ and
-$$\sigma$$, and the output shape, must have a floating-point element type. The
-parameters must furthermore be scalar valued.
-
-<b>`RngNormal(mu, sigma, shape)`</b>
-
-| Arguments | Type | Semantics |
-| --------- | ------- | --------------------------------------------------- |
-| `mu` | `XlaOp` | Scalar of type T specifying mean of generated |
-: : : numbers :
-| `sigma` | `XlaOp` | Scalar of type T specifying standard deviation of |
-: : : generated numbers :
-| `shape` | `Shape` | Output shape of type T |
-
-## RngUniform
-
-See also
-[`XlaBuilder::RngUniform`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Constructs an output of a given shape with random numbers generated following
-the uniform distribution over the interval $$[a,b)$$. The parameters and the
-output element type must be a boolean type, an integral type, or a floating
-point type, and the types must be consistent. The CPU and GPU backends
-currently only support F64, F32, F16, BF16, S64, U64, S32 and U32. Furthermore,
-the parameters need to be scalar valued. If $$b \leq a$$, the result is
-implementation-defined.
-
-<b>`RngUniform(a, b, shape)`</b>
-
-| Arguments | Type | Semantics |
-| --------- | ----------------------- | --------------------------------- |
-| `a` | `XlaOp` | Scalar of type T specifying lower |
-: : : limit of interval :
-| `b` | `XlaOp` | Scalar of type T specifying upper |
-: : : limit of interval :
-| `shape` | `Shape` | Output shape of type T |
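-
-A rough host-side analogue (XLA generates the values on device; this is shown
-only to pin down the interval convention):
-
-```python
-import numpy as np
-
-sample = np.random.uniform(0.0, 10.0, size=(2, 3)).astype(np.float32)
-# Values lie in [0.0, 10.0), matching the half-open interval above.
-```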
-
-## Scatter
-
-The XLA scatter operation generates a result which is the value of the input
-tensor `operand`, with several slices (at indices specified by
-`scatter_indices`) updated with the values in `updates` using
-`update_computation`.
-
-See also
-[`XlaBuilder::Scatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-<b> `scatter(operand, scatter_indices, updates, update_computation, index_vector_dim, update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)` </b>
-
-|Arguments | Type | Semantics |
-|------------------|------------------------|----------------------------------|
-|`operand` | `XlaOp` | Tensor to be scattered into. |
-|`scatter_indices` | `XlaOp` | Tensor containing the starting |
-: : : indices of the slices that must :
-: : : be scattered to. :
-|`updates` | `XlaOp` | Tensor containing the values that|
-: : : must be used for scattering. :
-|`update_computation`| `XlaComputation` | Computation to be used for |
-: : : combining the existing values in :
-: : : the input tensor and the updates :
-: : : during scatter. This computation :
-: : : should be of type `T, T -> T`. :
-|`index_vector_dim`| `int64` | The dimension in |
-: : : `scatter_indices` that contains :
-: : : the starting indices. :
-|`update_window_dims`| `ArraySlice<int64>` | The set of dimensions in |
-: : : `updates` shape that are _window :
-: : : dimensions_. :
-|`inserted_window_dims`| `ArraySlice<int64>`| The set of _window dimensions_ |
-: : : that must be inserted into :
-: : : `updates` shape. :
-|`scatter_dims_to_operand_dims`| `ArraySlice<int64>` | A dimensions map from |
-: : : the scatter indices to the :
-: : : operand index space. This array :
-: : : is interpreted as mapping `i` to :
-: : : `scatter_dims_to_operand_dims[i]`:
-: : : . It has to be one-to-one and :
-: : : total. :
-
-If `index_vector_dim` is equal to `scatter_indices.rank` we implicitly consider
-`scatter_indices` to have a trailing `1` dimension.
-
-We define `update_scatter_dims` of type `ArraySlice<int64>` as the set of
-dimensions in `updates` shape that are not in `update_window_dims`, in ascending
-order.
-
-The arguments of scatter should follow these constraints:
-
- - `updates` tensor must be of rank `update_window_dims.size +
- scatter_indices.rank - 1`.
-
- - Bounds of dimension `i` in `updates` must conform to the following:
- - If `i` is present in `update_window_dims` (i.e. equal to
- `update_window_dims`[`k`] for some `k`), then the bound of dimension
- `i` in `updates` must not exceed the corresponding bound of `operand`
- after accounting for the `inserted_window_dims` (i.e.
- `adjusted_window_bounds`[`k`], where `adjusted_window_bounds` contains
- the bounds of `operand` with the bounds at indices
- `inserted_window_dims` removed).
- - If `i` is present in `update_scatter_dims` (i.e. equal to
- `update_scatter_dims`[`k`] for some `k`), then the bound of dimension
- `i` in `updates` must be equal to the corresponding bound of
- `scatter_indices`, skipping `index_vector_dim` (i.e.
- `scatter_indices.shape.dims`[`k`], if `k` < `index_vector_dim` and
- `scatter_indices.shape.dims`[`k+1`] otherwise).
-
- - `update_window_dims` must be in ascending order, not have any repeating
- dimension numbers, and be in the range `[0, updates.rank)`.
-
- - `inserted_window_dims` must be in ascending order, not have any
- repeating dimension numbers, and be in the range `[0, operand.rank)`.
-
- - `scatter_dims_to_operand_dims.size` must be equal to
-   `scatter_indices.shape.dims`[`index_vector_dim`], and its values must be in
-   the range `[0, operand.rank)`.
-
-For a given index `U` in the `updates` tensor, the corresponding index `I` in
-the `operand` tensor into which this update has to be applied is computed as
-follows:
-
- 1. Let `G` = { `U`[`k`] for `k` in `update_scatter_dims` }. Use `G` to look up
- an index vector `S` in the `scatter_indices` tensor such that `S`[`i`] =
- `scatter_indices`[Combine(`G`, `i`)] where Combine(A, b) inserts b at
-    position `index_vector_dim` into A.
- 2. Create an index `S`<sub>`in`</sub> into `operand` using `S` by scattering
- `S` using the `scatter_dims_to_operand_dims` map. More formally:
- 1. `S`<sub>`in`</sub>[`scatter_dims_to_operand_dims`[`k`]] = `S`[`k`] if
- `k` < `scatter_dims_to_operand_dims.size`.
- 2. `S`<sub>`in`</sub>[`_`] = `0` otherwise.
- 3. Create an index `W`<sub>`in`</sub> into `operand` by scattering the indices
- at `update_window_dims` in `U` according to `inserted_window_dims`.
- More formally:
- 1. `W`<sub>`in`</sub>[`window_dims_to_operand_dims`(`k`)] = `U`[`k`] if
- `k` < `update_window_dims.size`, where `window_dims_to_operand_dims`
- is the monotonic function with domain [`0`, `update_window_dims.size`)
- and range [`0`, `operand.rank`) \\ `inserted_window_dims`. (For
- example, if `update_window_dims.size` is `4`, `operand.rank` is `6`,
- and `inserted_window_dims` is {`0`, `2`} then
- `window_dims_to_operand_dims` is {`0`→`1`, `1`→`3`, `2`→`4`,
- `3`→`5`}).
- 2. `W`<sub>`in`</sub>[`_`] = `0` otherwise.
- 4. `I` is `W`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
- addition.
-
-In summary, the scatter operation can be defined as follows.
-
- - Initialize `output` with `operand`, i.e. for all indices `O` in the
- `operand` tensor:\
- `output`[`O`] = `operand`[`O`]
- - For every index `U` in the `updates` tensor and the corresponding index `O`
- in the `operand` tensor:\
- `output`[`O`] = `update_computation`(`output`[`O`], `updates`[`U`])
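-
-A minimal runnable sketch of this summary for the simplest configuration
-(rank-1 operand, scalar updates, `update_computation` = addition; an
-illustrative helper, not an XLA API):
-
-```python
-import numpy as np
-
-def scatter_add_1d(operand, scatter_indices, updates):
-    output = operand.copy()
-    for u in range(len(updates)):
-        o = scatter_indices[u]
-        output[o] = output[o] + updates[u]
-    return output
-
-scatter_add_1d(np.zeros(5), np.array([1, 3, 1]), np.array([10., 20., 30.]))
-# -> [ 0. 40.  0. 20.  0.]  (index 1 is updated twice)
-```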
-
-The order in which updates are applied is non-deterministic. So, when multiple
-indices in `updates` refer to the same index in `operand`, the corresponding
-value in `output` will be non-deterministic.
-
-Note that the first parameter that is passed into the `update_computation` will
-always be the current value from the `output` tensor and the second parameter
-will always be the value from the `updates` tensor. This is important
-specifically for cases when the `update_computation` is _not commutative_.
-
-Informally, the scatter op can be viewed as an _inverse_ of the gather op, i.e.
-the scatter op updates the elements in the input that are extracted by the
-corresponding gather op.
-
-For a detailed informal description and examples, refer to the
-"Informal Description" section under `Gather`.
-
-## Select
-
-See also
-[`XlaBuilder::Select`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Constructs an output array from elements of two input arrays, based on the
-values of a predicate array.
-
-<b> `Select(pred, on_true, on_false)` </b>
-
-Arguments | Type | Semantics
----------- | ------- | ------------------
-`pred` | `XlaOp` | array of type PRED
-`on_true` | `XlaOp` | array of type T
-`on_false` | `XlaOp` | array of type T
-
-The arrays `on_true` and `on_false` must have the same shape. This is also the
-shape of the output array. The array `pred` must have the same dimensionality as
-`on_true` and `on_false`, with the `PRED` element type.
-
-For each element `P` of `pred`, the corresponding element of the output array is
-taken from `on_true` if the value of `P` is `true`, and from `on_false` if the
-value of `P` is `false`. As a restricted form of
-[broadcasting](broadcasting.md), `pred` can be a scalar of type `PRED`. In
-this case, the
-output array is taken wholly from `on_true` if `pred` is `true`, and from
-`on_false` if `pred` is `false`.
-
-Example with non-scalar `pred`:
-
-```
-let pred: PRED[4] = {true, false, false, true};
-let v1: s32[4] = {1, 2, 3, 4};
-let v2: s32[4] = {100, 200, 300, 400};
-==>
-Select(pred, v1, v2) = s32[4]{1, 200, 300, 4};
-```
-
-Example with scalar `pred`:
-
-```
-let pred: PRED = true;
-let v1: s32[4] = {1, 2, 3, 4};
-let v2: s32[4] = {100, 200, 300, 400};
-==>
-Select(pred, v1, v2) = s32[4]{1, 2, 3, 4};
-```
-
-Selections between tuples are supported. Tuples are considered to be scalar
-types for this purpose. If `on_true` and `on_false` are tuples (which must have
-the same shape!) then `pred` has to be a scalar of type `PRED`.
-
-## SelectAndScatter
-
-See also
-[`XlaBuilder::SelectAndScatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-This operation can be considered as a composite operation that first computes
-`ReduceWindow` on the `operand` array to select an element from each window, and
-then scatters the `source` array to the indices of the selected elements to
-construct an output array with the same shape as the operand array. The binary
-`select` function is used to select an element from each window by applying it
-across each window, and it is called with the property that the first
-parameter's index vector is lexicographically less than the second parameter's
-index vector. The `select` function returns `true` if the first parameter is
-selected and returns `false` if the second parameter is selected, and the
-function must be transitive (i.e., if `select(a, b)` and `select(b, c)` are
-`true`, then `select(a, c)` is also `true`) so that the selected element does
-not depend on the order of the elements traversed for a given window.
-
-The function `scatter` is applied at each selected index in the output array. It
-takes two scalar parameters:
-
-1. Current value at the selected index in the output array
-2. The scatter value from `source` that applies to the selected index
-
-It combines the two parameters and returns a scalar value that's used to update
-the value at the selected index in the output array. Initially, all indices of
-the output array are set to `init_value`.
-
-The output array has the same shape as the `operand` array and the `source`
-array must have the same shape as the result of applying a `ReduceWindow`
-operation on the `operand` array. `SelectAndScatter` can be used to
-backpropagate the gradient values for a pooling layer in a neural network.
-
-<b>`SelectAndScatter(operand, select, window_dimensions, window_strides,
-padding, source, init_value, scatter)`</b>
-
-| Arguments | Type | Semantics |
-| ------------------- | ------------------- | -------------------------------- |
-| `operand` | `XlaOp` | array of type T over which the |
-: : : windows slide :
-| `select` | `XlaComputation` | binary computation of type `T, T |
-: : : -> PRED`, to apply to all :
-: : : elements in each window; returns :
-: : : `true` if the first parameter is :
-: : : selected and returns `false` if :
-: : : the second parameter is selected :
-| `window_dimensions` | `ArraySlice<int64>` | array of integers for window |
-: : : dimension values :
-| `window_strides` | `ArraySlice<int64>` | array of integers for window |
-: : : stride values :
-| `padding` | `Padding` | padding type for window |
-: : : (Padding\:\:kSame or :
-: : : Padding\:\:kValid) :
-| `source` | `XlaOp` | array of type T with the values |
-: : : to scatter :
-| `init_value` | `XlaOp` | scalar value of type T for the |
-: : : initial value of the output :
-: : : array :
-| `scatter` | `XlaComputation` | binary computation of type `T, T |
-: : : -> T`, to apply each scatter :
-: : : source element with its :
-: : : destination element :
-
-The figure below shows examples of using `SelectAndScatter`, with the `select`
-function computing the maximal value among its parameters. Note that when the
-windows overlap, as in the figure (2) below, an index of the `operand` array may
-be selected multiple times by different windows. In the figure, the element of
-value 9 is selected by both of the top windows (blue and red) and the binary
-addition `scatter` function produces the output element of value 8 (2 + 6).
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%"
- src="https://www.tensorflow.org/images/ops_scatter_to_selected_window_element.png">
-</div>
-
-The evaluation order of the `scatter` function is arbitrary and may be
-non-deterministic. Therefore, the `scatter` function should not be overly
-sensitive to reassociation. See the discussion about associativity in the
-context of [`Reduce`](#reduce) for more details.
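-
-A sketch of the non-overlapping case (stride equals window size) with `select`
-= max and `scatter` = add (an illustrative helper, not an XLA API):
-
-```python
-import numpy as np
-
-def select_and_scatter_max_add(operand, source, win=(2, 2)):
-    output = np.zeros_like(operand)          # init_value = 0
-    for i in range(source.shape[0]):
-        for j in range(source.shape[1]):
-            w = operand[i*win[0]:(i+1)*win[0], j*win[1]:(j+1)*win[1]]
-            di, dj = np.unravel_index(np.argmax(w), w.shape)
-            output[i*win[0] + di, j*win[1] + dj] += source[i, j]
-    return output
-```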
-
-## Send
-
-See also
-[`XlaBuilder::Send`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-<b> `Send(operand, channel_handle)` </b>
-
-Arguments | Type | Semantics
----------------- | --------------- | -----------------------------------------
-`operand` | `XlaOp` | data to send (array of type T)
-`channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair
-
-Sends the given operand data to a `Recv` instruction in another computation
-that shares the same channel handle. Does not return any data.
-
-Similar to the `Recv` operation, the client API of the `Send` operation represents
-synchronous communication, and is internally decomposed into 2 HLO instructions
-(`Send` and `SendDone`) to enable asynchronous data transfers. See also
-[`HloInstruction::CreateSend` and `HloInstruction::CreateSendDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h).
-
-<b>`Send(HloInstruction operand, int64 channel_id)`</b>
-
-Initiates an asynchronous transfer of the operand to the resources allocated by
-the `Recv` instruction with the same channel id. Returns a context, which is
-used by a following `SendDone` instruction to wait for the completion of the
-data transfer. The context is a tuple of {operand (shape), request identifier
-(U32)} and it can only be used by a `SendDone` instruction.
-
-<b> `SendDone(HloInstruction context)` </b>
-
-Given a context created by a `Send` instruction, waits for the data transfer to
-complete. The instruction does not return any data.
-
-<b> Scheduling of channel instructions </b>
-
-The execution order of the 4 instructions for each channel (`Recv`, `RecvDone`,
-`Send`, `SendDone`) is as follows.
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:70%" src="../../images/send_recv_order.png">
-</div>
-
-* `Recv` happens before `Send`
-* `Send` happens before `RecvDone`
-* `Recv` happens before `RecvDone`
-* `Send` happens before `SendDone`
-
-When the backend compilers generate a linear schedule for each computation that
-communicates via channel instructions, there must not be cycles across the
-computations. For example, the schedules below lead to deadlocks.
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="../../images/send_recv_schedule.png">
-</div>
-
-## Slice
-
-See also
-[`XlaBuilder::Slice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Slicing extracts a sub-array from the input array. The sub-array is of the same
-rank as the input and contains the values inside a bounding box within the input
-array where the dimensions and indices of the bounding box are given as
-arguments to the slice operation.
-
-<b> `Slice(operand, start_indices, limit_indices)` </b>
-
-| Arguments | Type | Semantics |
-| --------------- | ------------------- | ------------------------------------ |
-| `operand` | `XlaOp` | N dimensional array of type T |
-| `start_indices` | `ArraySlice<int64>` | List of N integers containing the |
-: : : starting indices of the slice for :
-: : : each dimension. Values must be :
-: : : greater than or equal to zero. :
-| `limit_indices` | `ArraySlice<int64>` | List of N integers containing the |
-: : : ending indices (exclusive) for the :
-: : : slice for each dimension. Each value :
-: : : must be strictly greater than the :
-: : : respective `start_indices` value for :
-: : : the dimension and less than or equal :
-: : : to the size of the dimension. :
-
-1-dimensional example:
-
-```
-let a = {0.0, 1.0, 2.0, 3.0, 4.0}
-Slice(a, {2}, {4}) produces:
- {2.0, 3.0}
-```
-
-2-dimensional example:
-
-```
-let b =
- { {0.0, 1.0, 2.0},
- {3.0, 4.0, 5.0},
- {6.0, 7.0, 8.0},
- {9.0, 10.0, 11.0} }
-
-Slice(b, {2, 1}, {4, 3}) produces:
- { { 7.0, 8.0},
- {10.0, 11.0} }
-```
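-
-In numpy slicing terms (an illustration only), the 2-dimensional example is:
-
-```python
-import numpy as np
-
-b = np.arange(12, dtype=np.float32).reshape(4, 3)
-b[2:4, 1:3]   # start_indices {2, 1}, limit_indices {4, 3} (exclusive)
-# -> [[ 7.  8.]
-#     [10. 11.]]
-```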
-
-## Sort
-
-See also
-[`XlaBuilder::Sort`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-There are two versions of the Sort instruction: a single-operand and a
-two-operand version.
-
-<b>`Sort(operand, dimension)`</b>
-
-Arguments | Type | Semantics
------------ | ------- | --------------------
-`operand` | `XlaOp` | The operand to sort.
-`dimension` | `int64` | The dimension along which to sort.
-
-Sorts the elements in the operand in ascending order along the provided
-dimension. For example, for a rank-2 (matrix) operand, a `dimension` value of 0
-will sort each column independently, and a `dimension` value of 1 will sort each
-row independently. If the operand's elements have floating point type, and the
-operand contains NaN elements, the order of elements in the output is
-implementation-defined.
-
-<b>`Sort(keys, values, dimension)`</b>
-
-Sorts both the key and the value operands. The keys are sorted as in the
-single-operand version. The values are sorted according to the order of their
-corresponding keys. For example, if the inputs are `keys = [3, 1]` and
-`values = [42, 50]`, then the output of the sort is the tuple
-`{[1, 3], [50, 42]}`.
-
-The sort is not guaranteed to be stable, that is, if the keys array contains
-duplicates, the order of their corresponding values may not be preserved.
-
-Arguments | Type | Semantics
------------ | ------- | -------------------
-`keys` | `XlaOp` | The sort keys.
-`values` | `XlaOp` | The values to sort.
-`dimension` | `int64` | The dimension along which to sort.
-
-The `keys` and `values` must have the same dimensions, but may have different
-element types.
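-
-The two-operand example above can be sketched with numpy's `argsort` (an
-illustration only; unlike this sketch, XLA's sort need not be stable):
-
-```python
-import numpy as np
-
-keys = np.array([3, 1])
-values = np.array([42, 50])
-order = np.argsort(keys)
-keys[order], values[order]   # -> ([1, 3], [50, 42])
-```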
-
-## Transpose
-
-See also the [`Reshape`](#reshape) operation.
-
-<b>`Transpose(operand)`</b>
-
-Arguments | Type | Semantics
-------------- | ------------------- | ------------------------------
-`operand` | `XlaOp` | The operand to transpose.
-`permutation` | `ArraySlice<int64>` | How to permute the dimensions.
-
-
-Permutes the operand dimensions with the given permutation, so
-`∀ i . 0 ≤ i < rank ⇒ input_dimensions[permutation[i]] = output_dimensions[i]`.
-
-This is the same as
-`Reshape(operand, permutation, Permute(permutation, operand.shape.dimensions))`.
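-
-In numpy terms (an illustration only):
-
-```python
-import numpy as np
-
-x = np.arange(6).reshape(2, 3)
-y = np.transpose(x, (1, 0))   # output dim i gets input dim permutation[i]
-# y.shape == (3, 2)
-```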
-
-## Tuple
-
-See also
-[`XlaBuilder::Tuple`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-A tuple containing a variable number of data handles, each of which has its own
-shape.
-
-This is analogous to `std::tuple` in C++. Conceptually:
-
-```
-let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
-let s: s32 = 5;
-let t: (f32[10], s32) = tuple(v, s);
-```
-
-Tuples can be deconstructed (accessed) via the
-[`GetTupleElement`](#gettupleelement) operation.
-
-## While
-
-See also
-[`XlaBuilder::While`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-<b> `While(condition, body, init)` </b>
-
-| Arguments | Type | Semantics |
-| ----------- | ---------------- | ---------------------------------------- |
-| `condition` | `XlaComputation` | XlaComputation of type `T -> PRED` which |
-: : : defines the termination condition of the :
-: : : loop. :
-| `body` | `XlaComputation` | XlaComputation of type `T -> T` which |
-: : : defines the body of the loop. :
-| `init` | `T` | Initial value for the parameter of |
-: : : `condition` and `body`. :
-
-Sequentially executes the `body` until the `condition` fails. This is similar to
-a typical while loop in many other languages except for the differences and
-restrictions listed below.
-
-* A `While` node returns a value of type `T`, which is the result from the
- last execution of the `body`.
-* The shape of the type `T` is statically determined and must be the same
- across all iterations.
-
-The `T` parameters of the computations are initialized with the `init` value in
-the first iteration and are automatically updated to the new result from `body`
-in each subsequent iteration.
-
-One main use case of the `While` node is to implement the repeated execution of
-training in neural networks. Simplified pseudocode is shown below with a graph
-that represents the computation. The code can be found in
-[`while_test.cc`](https://www.tensorflow.org/code/tensorflow/compiler/xla/tests/while_test.cc).
-The type `T` in this example is a `Tuple` consisting of an `int32` for the
-iteration count and a `vector[10]` for the accumulator. For 1000 iterations, the
-loop keeps adding a constant vector to the accumulator.
-
-```
-// Pseudocode for the computation.
-init = {0, zero_vector[10]} // Tuple of int32 and float[10].
-result = init;
-while (result(0) < 1000) {
- iteration = result(0) + 1;
- new_vector = result(1) + constant_vector[10];
- result = {iteration, new_vector};
-}
-```
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="https://www.tensorflow.org/images/ops_while.png">
-</div>
diff --git a/tensorflow/docs_src/performance/xla/shapes.md b/tensorflow/docs_src/performance/xla/shapes.md
deleted file mode 100644
index 39e74ff307..0000000000
--- a/tensorflow/docs_src/performance/xla/shapes.md
+++ /dev/null
@@ -1,150 +0,0 @@
-# Shapes and Layout
-
-The XLA `Shape` proto
-([xla_data.proto](https://www.tensorflow.org/code/tensorflow/compiler/xla/xla_data.proto))
-describes the rank, size, and data type of an N-dimensional array (*array* for
-short).
-
-## Terminology, Notation, and Conventions
-
-* The rank of an array is equal to the number of dimensions. The *true rank*
- of an array is the number of dimensions which have a size greater than 1.
-
-* Dimensions are numbered from `0` up to `N-1` for an `N` dimensional array.
- The dimension numbers are arbitrary labels for convenience. The order of
- these dimension numbers does not imply a particular minor/major ordering in
- the layout of the shape. The layout is determined by the `Layout` proto.
-
-* By convention, dimensions are listed in increasing order of dimension
- number. For example, for a 3-dimensional array of size `[A x B x C]`,
- dimension 0 has size `A`, dimension 1 has size `B` and dimension 2 has size
- `C`.
-
- Some utilities in XLA also support negative indexing, similarly to Python;
- dimension -1 is the last dimension (equivalent to `N-1` for an `N`
- dimensional array). For example, for the 3-dimensional array described
- above, dimension -1 has size `C`, dimension -2 has size `B` and so on.
-
-* Two, three, and four dimensional arrays often have specific letters
- associated with dimensions. For example, for a 2D array:
-
- * dimension 0: `y`
- * dimension 1: `x`
-
- For a 3D array:
-
- * dimension 0: `z`
- * dimension 1: `y`
- * dimension 2: `x`
-
- For a 4D array:
-
- * dimension 0: `p`
- * dimension 1: `z`
- * dimension 2: `y`
- * dimension 3: `x`
-
-* Functions in the XLA API which take dimensions do so in increasing order of
- dimension number. This matches the ordering used when passing dimensions as
- an `initializer_list`; e.g.
-
- `ShapeUtil::MakeShape(F32, {A, B, C, D})`
-
-    This will create a shape whose dimension size array consists of the
-    sequence `[A, B, C, D]`.
-
-## Layout
-
-The `Layout` proto describes how an array is represented in memory. The `Layout`
-proto includes the following fields:
-
-```
-message Layout {
- repeated int64 minor_to_major = 1;
- repeated int64 padded_dimensions = 2;
- optional PaddingValue padding_value = 3;
-}
-```
-
-### Minor-to-major dimension ordering
-
-The only required field is `minor_to_major`. This field describes the
-minor-to-major ordering of the dimensions within a shape. Values in
-`minor_to_major` are an ordering of the dimensions of the array (`0` to `N-1`
-for an `N` dimensional array) with the first value being the most-minor
-dimension up to the last value which is the most-major dimension. The most-minor
-dimension is the dimension which changes most rapidly when stepping through the
-elements of the array laid out in linear memory.
-
-For example, consider the following 2D array of size `[2 x 3]`:
-
-```
-a b c
-d e f
-```
-
-Here dimension `0` is size 2, and dimension `1` is size 3. If the
-`minor_to_major` field in the layout is `[0, 1]` then dimension `0` is the
-most-minor dimension and dimension `1` is the most-major dimension. This
-corresponds to the following layout in linear memory:
-
-```
-a d b e c f
-```
-
-This minor-to-major dimension order of `0` up to `N-1` is akin to *column-major*
-(at rank 2). Assuming a monotonic ordering of dimensions, another name we may
-use to refer to this layout in the code is simply "dim 0 is minor".
-
-On the other hand, if the `minor_to_major` field in the layout is `[1, 0]` then
-the layout in linear memory is:
-
-```
-a b c d e f
-```
-
-A minor-to-major dimension order of `N-1` down to `0` for an `N` dimensional
-array is akin to *row-major* (at rank 2). Assuming a monotonic ordering of
-dimensions, another name we may use to refer to this layout in the code is
-simply "dim 0 is major".
-
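-As a concrete illustration, the following Python sketch (an illustration only,
-not XLA code) flattens the `[2 x 3]` example above under both orderings:
-
-```python
-import numpy as np
-
-a = np.array([["a", "b", "c"],
-              ["d", "e", "f"]])
-
-def linearize(arr, minor_to_major):
-    # Transpose so that axes run from most-major to most-minor, then read the
-    # elements off in C order (last axis varies fastest).
-    return np.transpose(arr, list(reversed(minor_to_major))).ravel().tolist()
-
-print(linearize(a, [0, 1]))  # ['a', 'd', 'b', 'e', 'c', 'f'] ("dim 0 is minor")
-print(linearize(a, [1, 0]))  # ['a', 'b', 'c', 'd', 'e', 'f'] ("dim 0 is major")
-```
-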
-#### Default minor-to-major ordering
-
-The default layout for newly created Shapes is "dimension order is
-major-to-minor" (akin to row-major at rank 2).
-
-### Padding
-
-Padding is defined in the optional `padded_dimensions` and `padding_value`
-fields. The field `padded_dimensions` describes the sizes (widths) to which each
-dimension is padded. If present, the number of elements in `padded_dimensions`
-must equal the rank of the shape.
-
-For example, given the `[2 x 3]` array defined above, if `padded_dimensions` is
-`[3, 5]` then dimension 0 is padded to a width of 3 and dimension 1 is padded to
-a width of 5. The layout in linear memory (assuming a padding value of 0 and
-column-major layout) is:
-
-```
-a d 0 b e 0 c f 0 0 0 0 0 0 0
-```
-
-This is equivalent to the layout of the following array with the same
-minor-to-major dimension order:
-
-```
-a b c 0 0
-d e f 0 0
-0 0 0 0 0
-```
-
-### Indexing into arrays
-
-The class `IndexUtil` in
-[index_util.h](https://www.tensorflow.org/code/tensorflow/compiler/xla/index_util.h)
-provides utilities for converting between multidimensional indices and linear
-indices given a shape and layout. Multidimensional indices include a `int64`
-index for each dimension. Linear indices are a single `int64` value which
-indexes into the buffer holding the array. See `shape_util.h` and
-`layout_util.h` in the same directory for utilities that simplify creation and
-manipulation of shapes and layouts.
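-
-The linearization itself is simple to express. The following Python sketch (an
-illustration of the formula, not the actual `IndexUtil` code) converts a
-multidimensional index to a linear index for a given shape and `minor_to_major`
-ordering:
-
-```python
-def multi_index_to_linear(index, dims, minor_to_major):
-    """Maps a multidimensional index to a linear index.
-
-    index:          one entry per dimension, e.g. (1, 1).
-    dims:           the size of each dimension, e.g. (2, 3).
-    minor_to_major: dimension numbers from most-minor to most-major.
-    """
-    linear, stride = 0, 1
-    for dim in minor_to_major:  # walk the dimensions from minor to major
-        linear += index[dim] * stride
-        stride *= dims[dim]
-    return linear
-
-# For the [2 x 3] example with minor_to_major = [0, 1], element "e" sits at
-# multidimensional index (1, 1) and linear index 3 in "a d b e c f".
-assert multi_index_to_linear((1, 1), (2, 3), [0, 1]) == 3
-```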
diff --git a/tensorflow/docs_src/performance/xla/tfcompile.md b/tensorflow/docs_src/performance/xla/tfcompile.md
deleted file mode 100644
index e4b803164f..0000000000
--- a/tensorflow/docs_src/performance/xla/tfcompile.md
+++ /dev/null
@@ -1,281 +0,0 @@
-# Using AOT compilation
-
-## What is tfcompile?
-
-`tfcompile` is a standalone tool that ahead-of-time (AOT) compiles TensorFlow
-graphs into executable code. It can reduce total binary size, and also avoid
-some runtime overheads. A typical use-case of `tfcompile` is to compile an
-inference graph into executable code for mobile devices.
-
-The TensorFlow graph is normally executed by the TensorFlow runtime. This incurs
-some runtime overhead for execution of each node in the graph. This also leads
-to a larger total binary size, since the code for the TensorFlow runtime needs
-to be available, in addition to the graph itself. The executable code produced
-by `tfcompile` does not use the TensorFlow runtime, and only has dependencies on
-kernels that are actually used in the computation.
-
-The compiler is built on top of the XLA framework. The code bridging TensorFlow
-to the XLA framework resides under
-[tensorflow/compiler](https://www.tensorflow.org/code/tensorflow/compiler/),
-which also includes support for @{$jit$just-in-time (JIT) compilation} of
-TensorFlow graphs.
-
-## What does tfcompile do?
-
-`tfcompile` takes a subgraph, identified by the TensorFlow concepts of
-feeds and fetches, and generates a function that implements that subgraph.
-The `feeds` are the input arguments for the function, and the `fetches` are the
-output arguments for the function. All inputs must be fully specified by the
-feeds; the resulting pruned subgraph cannot contain Placeholder or Variable
-nodes. It is common to specify all Placeholders and Variables as feeds, which
-ensures the resulting subgraph no longer contains these nodes. The generated
-function is packaged as a `cc_library`, with a header file exporting the
-function signature, and an object file containing the implementation. The user
-writes code to invoke the generated function as appropriate.
-
-## Using tfcompile
-
-This section details the high-level steps for generating an executable binary
-with
-`tfcompile` from a TensorFlow subgraph. The steps are:
-
-* Step 1: Configure the subgraph to compile
-* Step 2: Use the `tf_library` build macro to compile the subgraph
-* Step 3: Write code to invoke the subgraph
-* Step 4: Create the final binary
-
-### Step 1: Configure the subgraph to compile
-
-Identify the feeds and fetches that correspond to the input and output
-arguments for the generated function. Then configure the `feeds` and `fetches`
-in a [`tensorflow.tf2xla.Config`](https://www.tensorflow.org/code/tensorflow/compiler/tf2xla/tf2xla.proto)
-proto.
-
-```textproto
-# Each feed is a positional input argument for the generated function. The order
-# of each entry matches the order of each input argument. Here "x_hold" and "y_hold"
-# refer to the names of placeholder nodes defined in the graph.
-feed {
- id { node_name: "x_hold" }
- shape {
- dim { size: 2 }
- dim { size: 3 }
- }
-}
-feed {
- id { node_name: "y_hold" }
- shape {
- dim { size: 3 }
- dim { size: 2 }
- }
-}
-
-# Each fetch is a positional output argument for the generated function. The order
-# of each entry matches the order of each output argument. Here "x_y_prod"
-# refers to the name of a matmul node defined in the graph.
-fetch {
- id { node_name: "x_y_prod" }
-}
-```
-
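-For reference, a graph with matching node names could be produced by a Python
-snippet along these lines (a minimal sketch using TensorFlow 1.x APIs; the
-script actually used for the examples is `make_test_graphs.py`, referenced
-below):
-
-```python
-import tensorflow as tf
-
-x = tf.placeholder(tf.float32, name="x_hold")
-y = tf.placeholder(tf.float32, name="y_hold")
-tf.matmul(x, y, name="x_y_prod")
-
-# Serialize the GraphDef in binary format for tfcompile.
-with tf.Session() as sess:
-  tf.train.write_graph(sess.graph_def, ".", "test_graph_tfmatmul.pb",
-                       as_text=False)
-```
-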
-### Step 2: Use tf_library build macro to compile the subgraph
-
-This step converts the graph into a `cc_library` using the `tf_library` build
-macro. The `cc_library` consists of an object file containing the code generated
-from the graph, along with a header file that gives access to the generated
-code. `tf_library` utilizes `tfcompile` to compile the TensorFlow graph into
-executable code.
-
-```build
-load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
-
-# Use the tf_library macro to compile your graph into executable code.
-tf_library(
- # name is used to generate the following underlying build rules:
- # <name> : cc_library packaging the generated header and object files
- # <name>_test : cc_test containing a simple test and benchmark
- # <name>_benchmark : cc_binary containing a stand-alone benchmark with minimal deps;
- # can be run on a mobile device
- name = "test_graph_tfmatmul",
- # cpp_class specifies the name of the generated C++ class, with namespaces allowed.
- # The class will be generated in the given namespace(s), or if no namespaces are
- # given, within the global namespace.
- cpp_class = "foo::bar::MatMulComp",
- # graph is the input GraphDef proto, by default expected in binary format. To
-# use the text format instead, just use the '.pbtxt' suffix. A subgraph will be
- # created from this input graph, with feeds as inputs and fetches as outputs.
- # No Placeholder or Variable ops may exist in this subgraph.
- graph = "test_graph_tfmatmul.pb",
- # config is the input Config proto, by default expected in binary format. To
-# use the text format instead, use the '.pbtxt' suffix. This is where the
- # feeds and fetches were specified above, in the previous step.
- config = "test_graph_tfmatmul.config.pbtxt",
-)
-```
-
-> To generate the GraphDef proto (test_graph_tfmatmul.pb) for this example, run
-> [make_test_graphs.py](https://www.tensorflow.org/code/tensorflow/compiler/aot/tests/make_test_graphs.py)
-> and specify the output location with the `--out_dir` flag.
-
-Typical graphs contain @{$python/state_ops$`Variables`}
-representing the weights that are learned via training, but `tfcompile` cannot
-compile a subgraph that contains `Variables`. The
-[freeze_graph.py](https://www.tensorflow.org/code/tensorflow/python/tools/freeze_graph.py)
-tool converts variables into constants, using values stored in a checkpoint
-file. As a convenience, the `tf_library` macro supports the `freeze_checkpoint`
-argument, which runs the tool. For more examples see
-[tensorflow/compiler/aot/tests/BUILD](https://www.tensorflow.org/code/tensorflow/compiler/aot/tests/BUILD).
-
-> Constants that show up in the compiled subgraph are compiled directly into the
-> generated code. To pass the constants into the generated function, rather than
-> having them compiled-in, simply pass them in as feeds.
-
-For details on the `tf_library` build macro, see
-[tfcompile.bzl](https://www.tensorflow.org/code/tensorflow/compiler/aot/tfcompile.bzl).
-
-For details on the underlying `tfcompile` tool, see
-[tfcompile_main.cc](https://www.tensorflow.org/code/tensorflow/compiler/aot/tfcompile_main.cc).
-
-### Step 3: Write code to invoke the subgraph
-
-This step uses the header file (`test_graph_tfmatmul.h`) generated by the
-`tf_library` build macro in the previous step to invoke the generated code. The
-header file is located in the `bazel-genfiles` directory corresponding to the
-build package, and is named based on the name attribute set on the `tf_library`
-build macro. For example, the header generated for `test_graph_tfmatmul` would
-be `test_graph_tfmatmul.h`. Below is an abbreviated version of what is
-generated. The generated file, in `bazel-genfiles`, contains additional useful
-comments.
-
-```c++
-namespace foo {
-namespace bar {
-
-// MatMulComp represents a computation previously specified in a
-// TensorFlow graph, now compiled into executable code.
-class MatMulComp {
- public:
- // AllocMode controls the buffer allocation mode.
- enum class AllocMode {
- ARGS_RESULTS_AND_TEMPS, // Allocate arg, result and temp buffers
- RESULTS_AND_TEMPS_ONLY, // Only allocate result and temp buffers
- };
-
- MatMulComp(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS);
- ~MatMulComp();
-
- // Runs the computation, with inputs read from arg buffers, and outputs
- // written to result buffers. Returns true on success and false on failure.
- bool Run();
-
- // Arg methods for managing input buffers. Buffers are in row-major order.
- // There is a set of methods for each positional argument.
- void** args();
-
- void set_arg0_data(float* data);
- float* arg0_data();
- float& arg0(size_t dim0, size_t dim1);
-
- void set_arg1_data(float* data);
- float* arg1_data();
- float& arg1(size_t dim0, size_t dim1);
-
- // Result methods for managing output buffers. Buffers are in row-major order.
- // Must only be called after a successful Run call. There is a set of methods
- // for each positional result.
- void** results();
-
- float* result0_data();
- float& result0(size_t dim0, size_t dim1);
-};
-
-} // end namespace bar
-} // end namespace foo
-```
-
-The generated C++ class is called `MatMulComp` in the `foo::bar` namespace,
-because that was the `cpp_class` specified in the `tf_library` macro. All
-generated classes have a similar API, with the only difference being the methods
-to handle arg and result buffers. Those methods differ based on the number and
-types of the buffers, which were specified by the `feed` and `fetch` arguments
-to the `tf_library` macro.
-
-There are three types of buffers managed within the generated class: `args`
-representing the inputs, `results` representing the outputs, and `temps`
-representing temporary buffers used internally to perform the computation. By
-default, each instance of the generated class allocates and manages all of these
-buffers for you. The `AllocMode` constructor argument may be used to change this
-behavior. All buffers are aligned to 64-byte boundaries.
-
-The generated C++ class is just a wrapper around the low-level code generated by
-XLA.
-
-Example of invoking the generated function based on
-[`tfcompile_test.cc`](https://www.tensorflow.org/code/tensorflow/compiler/aot/tests/tfcompile_test.cc):
-
-```c++
-#define EIGEN_USE_THREADS
-#define EIGEN_USE_CUSTOM_THREAD_POOL
-
-#include <iostream>
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h" // generated
-
-int main(int argc, char** argv) {
- Eigen::ThreadPool tp(2); // Size the thread pool as appropriate.
- Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
-
- foo::bar::MatMulComp matmul;
- matmul.set_thread_pool(&device);
-
- // Set up args and run the computation.
- const float args[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
- std::copy(args + 0, args + 6, matmul.arg0_data());
- std::copy(args + 6, args + 12, matmul.arg1_data());
- matmul.Run();
-
- // Check result
- if (matmul.result0(0, 0) == 58) {
- std::cout << "Success" << std::endl;
- } else {
- std::cout << "Failed. Expected value 58 at 0,0. Got:"
- << matmul.result0(0, 0) << std::endl;
- }
-
- return 0;
-}
-```
-
-### Step 4: Create the final binary
-
-This step combines the library generated by `tf_library` in step 2 and the code
-written in step 3 to create a final binary. Below is an example `bazel` BUILD
-file.
-
-```build
-# Example of linking your binary
-# Also see //tensorflow/compiler/aot/tests/BUILD
-load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
-
-# The same tf_library call from step 2 above.
-tf_library(
- name = "test_graph_tfmatmul",
- ...
-)
-
-# The executable code generated by tf_library can then be linked into your code.
-cc_binary(
- name = "my_binary",
- srcs = [
- "my_code.cc", # include test_graph_tfmatmul.h to access the generated header
- ],
- deps = [
- ":test_graph_tfmatmul", # link in the generated object file
- "//third_party/eigen3",
- ],
- linkopts = [
- "-lpthread",
- ]
-)
-```
diff --git a/tensorflow/docs_src/tutorials/_index.yaml b/tensorflow/docs_src/tutorials/_index.yaml
deleted file mode 100644
index 9534114689..0000000000
--- a/tensorflow/docs_src/tutorials/_index.yaml
+++ /dev/null
@@ -1,202 +0,0 @@
-project_path: /_project.yaml
-book_path: /_book.yaml
-description: <!--no description-->
-landing_page:
- custom_css_path: /site-assets/css/style.css
- show_side_navs: True
- rows:
- - description: >
- <h1 class="hide-from-toc">Get Started with TensorFlow</h1>
- <p>
- TensorFlow is an open-source machine learning library for research and
- production. TensorFlow offers APIs for beginners and experts to develop
- for desktop, mobile, web, and cloud. See the sections below to get
- started.
- </p>
- items:
- - custom_html: >
- <div class="devsite-landing-row-item-description">
- <h3 class="hide-from-toc">Learn and use ML</h3>
- <div class="devsite-landing-row-item-description-content">
- <p>
- The high-level Keras API provides building blocks to create and
- train deep learning models. Start with these beginner-friendly
- notebook examples, then read the
- <a href="/guide/keras">TensorFlow Keras guide</a>.
- </p>
- <ol style="padding-left:20px;">
- <li><a href="./keras/basic_classification">Basic classification</a></li>
- <li><a href="./keras/basic_text_classification">Text classification</a></li>
- <li><a href="./keras/basic_regression">Regression</a></li>
- <li><a href="./keras/overfit_and_underfit">Overfitting and underfitting</a></li>
- <li><a href="./keras/save_and_restore_models">Save and load</a></li>
- </ol>
- </div>
- <div class="devsite-landing-row-item-buttons" style="margin-top:0;">
- <a class="button button-primary tfo-button-primary" href="/guide/keras">Read the Keras guide</a>
- </div>
- </div>
- - classname: tfo-landing-row-item-code-block
- code_block: |
- <pre class="prettyprint">
- import tensorflow as tf
- mnist = tf.keras.datasets.mnist
-
- (x_train, y_train),(x_test, y_test) = mnist.load_data()
- x_train, x_test = x_train / 255.0, x_test / 255.0
-
- model = tf.keras.models.Sequential([
- tf.keras.layers.Flatten(),
- tf.keras.layers.Dense(512, activation=tf.nn.relu),
- tf.keras.layers.Dropout(0.2),
- tf.keras.layers.Dense(10, activation=tf.nn.softmax)
- ])
- model.compile(optimizer='adam',
- loss='sparse_categorical_crossentropy',
- metrics=['accuracy'])
-
- model.fit(x_train, y_train, epochs=5)
- model.evaluate(x_test, y_test)
- </pre>
- {% dynamic if request.tld != 'cn' %}
- <a class="colab-button" target="_blank" href="https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/get_started/_index.ipynb">Run in a <span>Notebook</span></a>
- {% dynamic endif %}
-
- - items:
- - custom_html: >
- <div class="devsite-landing-row-item-description" style="border-right: 2px solid #eee;">
- <h3 class="hide-from-toc">Research and experimentation</h3>
- <div class="devsite-landing-row-item-description-content">
- <p>
-            Eager execution provides an imperative, define-by-run interface for advanced operations. Write custom layers, forward passes, and training loops with auto-differentiation. Start with
- these notebooks, then read the <a href="/guide/eager">eager execution guide</a>.
- </p>
- <ol style="padding-left:20px;">
- <li>
- {% dynamic if request.tld == 'cn' %}
- <a href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb" class="external">Eager execution basics</a>
- {% dynamic else %}
- <a href="https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb" class="external">Eager execution basics</a>
- {% dynamic endif %}
- </li>
- <li>
- {% dynamic if request.tld == 'cn' %}
- <a href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb" class="external">Automatic differentiation and gradient tape</a>
- {% dynamic else %}
- <a href="https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb" class="external">Automatic differentiation and gradient tape</a>
- {% dynamic endif %}
- </li>
- <li>
- {% dynamic if request.tld == 'cn' %}
- <a href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb" class="external">Custom training: basics</a>
- {% dynamic else %}
- <a href="https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb" class="external">Custom training: basics</a>
- {% dynamic endif %}
- </li>
- <li>
- {% dynamic if request.tld == 'cn' %}
- <a href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb" class="external">Custom layers</a>
- {% dynamic else %}
- <a href="https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb" class="external">Custom layers</a>
- {% dynamic endif %}
- </li>
- <li><a href="./eager/custom_training_walkthrough">Custom training: walkthrough</a></li>
- <li>
- {% dynamic if request.tld == 'cn' %}
- <a href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb" class="external">Example: Neural machine translation w/ attention</a>
- {% dynamic else %}
- <a href="https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb" class="external">Example: Neural machine translation w/ attention</a>
- {% dynamic endif %}
- </li>
- </ol>
- </div>
- <div class="devsite-landing-row-item-buttons">
- <a class="button button-primary tfo-button-primary" href="/guide/eager">Read the eager execution guide</a>
- </div>
- </div>
- - custom_html: >
- <div class="devsite-landing-row-item-description">
- <h3 class="hide-from-toc">ML at production scale</h3>
- <div class="devsite-landing-row-item-description-content">
- <p>
- Estimators can train large models on multiple machines in a
- production environment. TensorFlow provides a collection of
- pre-made Estimators to implement common ML algorithms. See the
- <a href="/guide/estimators">Estimators guide</a>.
- </p>
- <ol style="padding-left: 20px;">
- <li><a href="/tutorials/estimators/linear">Build a linear model with Estimators</a></li>
- <li><a href="https://github.com/tensorflow/models/tree/master/official/wide_deep" class="external">Wide and deep learning with Estimators</a></li>
- <li><a href="https://github.com/tensorflow/models/tree/master/official/boosted_trees" class="external">Boosted trees</a></li>
- <li><a href="/hub/tutorials/text_classification_with_tf_hub">How to build a simple text classifier with TF-Hub</a></li>
- <li><a href="/tutorials/estimators/cnn">Build a Convolutional Neural Network using Estimators</a></li>
- </ol>
- </div>
- <div class="devsite-landing-row-item-buttons">
- <a class="button button-primary tfo-button-primary" href="/guide/estimators">Read the Estimators guide</a>
- </div>
- </div>
-
- - description: >
- <h2 class="hide-from-toc">Google Colab&#58; An easy way to learn and use TensorFlow</h2>
- <p>
- <a href="https://colab.research.google.com/notebooks/welcome.ipynb" class="external">Colaboratory</a>
- is a Google research project created to help disseminate machine learning
- education and research. It's a Jupyter notebook environment that requires
- no setup to use and runs entirely in the cloud.
- <a href="https://medium.com/tensorflow/colab-an-easy-way-to-learn-and-use-tensorflow-d74d1686e309" class="external">Read the blog post</a>.
- </p>
-
- - description: >
- <h2 class="hide-from-toc">Build your first ML app</h2>
- <p>Create and deploy TensorFlow models on web and mobile.</p>
- background: grey
- items:
- - custom_html: >
- <div class="devsite-landing-row-item-description" style="background: #fff; padding:32px;">
- <a href="https://js.tensorflow.org">
- <h3 class="hide-from-toc">Web developers</h3>
- </a>
- <div class="devsite-landing-row-item-description-content">
- TensorFlow.js is a WebGL accelerated, JavaScript library to train and
- deploy ML models in the browser and for Node.js.
- </div>
- </div>
- - custom_html: >
- <div class="devsite-landing-row-item-description" style="background: #fff; padding:32px;">
- <a href="/mobile/tflite/">
- <h3 class="hide-from-toc">Mobile developers</h3>
- </a>
- <div class="devsite-landing-row-item-description-content">
-            TensorFlow Lite is a lightweight solution for mobile and embedded devices.
- </div>
- </div>
-
- - description: >
- <h2 class="hide-from-toc">Videos and updates</h2>
- <p>
- Subscribe to the TensorFlow
- <a href="https://www.youtube.com/tensorflow" class="external">YouTube channel</a>
- and <a href="https://blog.tensorflow.org" class="external">blog</a> for
- the latest videos and updates.
- </p>
- items:
- - description: >
- <h3 class="hide-from-toc">Get started with TensorFlow's High-Level APIs</h3>
- youtube_id: tjsHSIG8I08
- buttons:
- - label: Watch the video
- path: https://www.youtube.com/watch?v=tjsHSIG8I08
- - description: >
- <h3 class="hide-from-toc">Eager execution</h3>
- youtube_id: T8AW0fKP0Hs
- background: grey
- buttons:
- - label: Watch the video
- path: https://www.youtube.com/watch?v=T8AW0fKP0Hs
- - description: >
- <h3 class="hide-from-toc">tf.data: Fast, flexible, and easy-to-use input pipelines</h3>
- youtube_id: uIcqeP7MFH0
- buttons:
- - label: Watch the video
- path: https://www.youtube.com/watch?v=uIcqeP7MFH0
diff --git a/tensorflow/docs_src/tutorials/_toc.yaml b/tensorflow/docs_src/tutorials/_toc.yaml
deleted file mode 100644
index 0e25208a00..0000000000
--- a/tensorflow/docs_src/tutorials/_toc.yaml
+++ /dev/null
@@ -1,124 +0,0 @@
-toc:
-- title: Get started with TensorFlow
- path: /tutorials/
-
-- title: Learn and use ML
- style: accordion
- section:
- - title: Overview
- path: /tutorials/keras/
- - title: Basic classification
- path: /tutorials/keras/basic_classification
- - title: Text classification
- path: /tutorials/keras/basic_text_classification
- - title: Regression
- path: /tutorials/keras/basic_regression
- - title: Overfitting and underfitting
- path: /tutorials/keras/overfit_and_underfit
- - title: Save and restore models
- path: /tutorials/keras/save_and_restore_models
-
-- title: Research and experimentation
- style: accordion
- section:
- - title: Overview
- path: /tutorials/eager/
- - title: Eager execution
- path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb
- status: external
- - title: Automatic differentiation
- path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb
- status: external
- - title: "Custom training: basics"
- path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb
- status: external
- - title: Custom layers
- path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb
- status: external
- - title: "Custom training: walkthrough"
- path: /tutorials/eager/custom_training_walkthrough
- - title: Text generation
- path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
- status: external
- - title: Translation with attention
- path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
- status: external
- - title: Image captioning
- path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
- status: external
- - title: Neural Style Transfer
- path: https://github.com/tensorflow/models/blob/master/research/nst_blogpost/4_Neural_Style_Transfer_with_Eager_Execution.ipynb
- status: external
- - title: DCGAN
- path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb
- status: external
- - title: VAE
- path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb
- status: external
- - title: Pix2Pix
- path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb
- status: external
- - title: Image Segmentation
- path: https://github.com/tensorflow/models/blob/master/samples/outreach/blogs/segmentation_blogpost/image_segmentation.ipynb
- status: external
-
-- title: ML at production scale
- style: accordion
- section:
- - title: Linear model with Estimators
- path: /tutorials/estimators/linear
- - title: Wide and deep learning
- path: https://github.com/tensorflow/models/tree/master/official/wide_deep
- status: external
- - title: Boosted trees
- path: https://github.com/tensorflow/models/tree/master/official/boosted_trees
- status: external
- - title: Text classifier with TF-Hub
- path: /hub/tutorials/text_classification_with_tf_hub
- - title: Build a CNN using Estimators
- path: /tutorials/estimators/cnn
-
-- title: Images
- style: accordion
- section:
- - title: Image recognition
- path: /tutorials/images/image_recognition
- - title: Image retraining
- path: /hub/tutorials/image_retraining
- - title: Advanced CNN
- path: /tutorials/images/deep_cnn
-
-- title: Sequences
- style: accordion
- section:
- - title: Recurrent neural network
- path: /tutorials/sequences/recurrent
- - title: Drawing classification
- path: /tutorials/sequences/recurrent_quickdraw
- - title: Simple audio recognition
- path: /tutorials/sequences/audio_recognition
- - title: Neural machine translation
- path: https://github.com/tensorflow/nmt
- status: external
-
-- title: Data representation
- style: accordion
- section:
- - title: Vector representations of words
- path: /tutorials/representation/word2vec
- - title: Kernel methods
- path: /tutorials/representation/kernel_methods
- - title: Large-scale linear models
- path: /tutorials/representation/linear
-
-- title: Non-ML
- style: accordion
- section:
- - title: Mandelbrot set
- path: /tutorials/non-ml/mandelbrot
- - title: Partial differential equations
- path: /tutorials/non-ml/pdes
-
-- break: True
-- title: Next steps
- path: /tutorials/next_steps
diff --git a/tensorflow/docs_src/tutorials/eager/custom_training_walkthrough.md b/tensorflow/docs_src/tutorials/eager/custom_training_walkthrough.md
deleted file mode 100644
index b564a27ecf..0000000000
--- a/tensorflow/docs_src/tutorials/eager/custom_training_walkthrough.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# Custom training: walkthrough
-
-[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/tutorials/eager/custom_training_walkthrough.ipynb)
diff --git a/tensorflow/docs_src/tutorials/eager/index.md b/tensorflow/docs_src/tutorials/eager/index.md
deleted file mode 100644
index a13b396094..0000000000
--- a/tensorflow/docs_src/tutorials/eager/index.md
+++ /dev/null
@@ -1,13 +0,0 @@
-# Research and experimentation
-
-Eager execution provides an imperative, define-by-run interface for advanced
-operations. Write custom layers, forward passes, and training loops with
-auto&nbsp;differentiation. Start with these notebooks, then read the
-[eager execution guide](../../guide/eager).
-
-1. <span>[Eager execution](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb){:.external}</span>
-2. <span>[Automatic differentiation and gradient tape](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb){:.external}</span>
-3. <span>[Custom training: basics](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb){:.external}</span>
-4. <span>[Custom layers](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb){:.external}</span>
-5. [Custom training: walkthrough](/tutorials/eager/custom_training_walkthrough)
-6. <span>[Advanced example: Neural machine translation with attention](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb){:.external}</span>
diff --git a/tensorflow/docs_src/tutorials/estimators/cnn.md b/tensorflow/docs_src/tutorials/estimators/cnn.md
deleted file mode 100644
index 100f501cc2..0000000000
--- a/tensorflow/docs_src/tutorials/estimators/cnn.md
+++ /dev/null
@@ -1,694 +0,0 @@
-# Build a Convolutional Neural Network using Estimators
-
-The `tf.layers` module provides a high-level API that makes
-it easy to construct a neural network. It provides methods that facilitate the
-creation of dense (fully connected) layers and convolutional layers, adding
-activation functions, and applying dropout regularization. In this tutorial,
-you'll learn how to use `layers` to build a convolutional neural network model
-to recognize the handwritten digits in the MNIST data set.
-
-![handwritten digits 0–9 from the MNIST data set](https://www.tensorflow.org/images/mnist_0-9.png)
-
-**The [MNIST dataset](http://yann.lecun.com/exdb/mnist/) comprises 60,000
-training examples and 10,000 test examples of the handwritten digits 0–9,
-formatted as 28x28-pixel monochrome images.**
-
-## Getting Started
-
-Let's set up the skeleton for our TensorFlow program. Create a file called
-`cnn_mnist.py`, and add the following code:
-
-```python
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Imports
-import numpy as np
-import tensorflow as tf
-
-tf.logging.set_verbosity(tf.logging.INFO)
-
-# Our application logic will be added here
-
-if __name__ == "__main__":
- tf.app.run()
-```
-
-As you work through the tutorial, you'll add code to construct, train, and
-evaluate the convolutional neural network. The complete, final code can be
-[found here](https://www.tensorflow.org/code/tensorflow/examples/tutorials/layers/cnn_mnist.py).
-
-## Intro to Convolutional Neural Networks
-
-Convolutional neural networks (CNNs) are the current state-of-the-art model
-architecture for image classification tasks. CNNs apply a series of filters to
-the raw pixel data of an image to extract and learn higher-level features, which
-the model can then use for classification. CNNs contain three components:
-
-* **Convolutional layers**, which apply a specified number of convolution
- filters to the image. For each subregion, the layer performs a set of
- mathematical operations to produce a single value in the output feature map.
- Convolutional layers then typically apply a
- [ReLU activation function](https://en.wikipedia.org/wiki/Rectifier_\(neural_networks\)) to
- the output to introduce nonlinearities into the model.
-
-* **Pooling layers**, which
- [downsample the image data](https://en.wikipedia.org/wiki/Convolutional_neural_network#Pooling_layer)
- extracted by the convolutional layers to reduce the dimensionality of the
- feature map in order to decrease processing time. A commonly used pooling
- algorithm is max pooling, which extracts subregions of the feature map
- (e.g., 2x2-pixel tiles), keeps their maximum value, and discards all other
- values.
-
-* **Dense (fully connected) layers**, which perform classification on the
- features extracted by the convolutional layers and downsampled by the
- pooling layers. In a dense layer, every node in the layer is connected to
- every node in the preceding layer.
-
-Typically, a CNN is composed of a stack of convolutional modules that perform
-feature extraction. Each module consists of a convolutional layer followed by a
-pooling layer. The last convolutional module is followed by one or more dense
-layers that perform classification. The final dense layer in a CNN contains a
-single node for each target class in the model (all the possible classes the
-model may predict), with a
-[softmax](https://en.wikipedia.org/wiki/Softmax_function) activation function to
-generate a value between 0–1 for each node (the sum of all these softmax values
-is equal to 1). We can interpret the softmax values for a given image as
-relative measurements of how likely it is that the image falls into each target
-class.
-
-> Note: For a more comprehensive walkthrough of CNN architecture, see Stanford
-> University's <a href="https://cs231n.github.io/convolutional-networks/">
-> Convolutional Neural Networks for Visual Recognition course materials</a>.
-
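-To make the softmax step concrete, here is a minimal NumPy sketch (the logits
-values are made up for illustration) showing how raw scores become
-probabilities that sum to 1:
-
-```python
-import numpy as np
-
-logits = np.array([2.0, 1.0, 0.1])  # raw scores for three classes
-probs = np.exp(logits) / np.sum(np.exp(logits))
-
-print(probs)        # [0.659 0.242 0.099] (approximately)
-print(probs.sum())  # 1.0
-```
-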
-## Building the CNN MNIST Classifier {#building-the-cnn-mnist-classifier}
-
-Let's build a model to classify the images in the MNIST dataset using the
-following CNN architecture:
-
-1. **Convolutional Layer #1**: Applies 32 5x5 filters (extracting 5x5-pixel
- subregions), with ReLU activation function
-2. **Pooling Layer #1**: Performs max pooling with a 2x2 filter and stride of 2
- (which specifies that pooled regions do not overlap)
-3. **Convolutional Layer #2**: Applies 64 5x5 filters, with ReLU activation
- function
-4. **Pooling Layer #2**: Again, performs max pooling with a 2x2 filter and
- stride of 2
-5. **Dense Layer #1**: 1,024 neurons, with dropout regularization rate of 0.4
- (probability of 0.4 that any given element will be dropped during training)
-6. **Dense Layer #2 (Logits Layer)**: 10 neurons, one for each digit target
- class (0–9).
-
-The `tf.layers` module contains methods to create each of the three layer types
-above:
-
-* `conv2d()`. Constructs a two-dimensional convolutional layer. Takes number
- of filters, filter kernel size, padding, and activation function as
- arguments.
-* `max_pooling2d()`. Constructs a two-dimensional pooling layer using the
- max-pooling algorithm. Takes pooling filter size and stride as arguments.
-* `dense()`. Constructs a dense layer. Takes number of neurons and activation
- function as arguments.
-
-Each of these methods accepts a tensor as input and returns a transformed tensor
-as output. This makes it easy to connect one layer to another: just take the
-output from one layer-creation method and supply it as input to another.
-
-Open `cnn_mnist.py` and add the following `cnn_model_fn` function, which
-conforms to the interface expected by TensorFlow's Estimator API (more on this
-later in [Create the Estimator](#create-the-estimator)). `cnn_model_fn` takes
-MNIST feature data, labels, and mode (from
-`tf.estimator.ModeKeys`: `TRAIN`, `EVAL`, `PREDICT`) as arguments;
-configures the CNN; and returns predictions, loss, and a training operation:
-
-```python
-def cnn_model_fn(features, labels, mode):
- """Model function for CNN."""
- # Input Layer
- input_layer = tf.reshape(features["x"], [-1, 28, 28, 1])
-
- # Convolutional Layer #1
- conv1 = tf.layers.conv2d(
- inputs=input_layer,
- filters=32,
- kernel_size=[5, 5],
- padding="same",
- activation=tf.nn.relu)
-
- # Pooling Layer #1
- pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)
-
- # Convolutional Layer #2 and Pooling Layer #2
- conv2 = tf.layers.conv2d(
- inputs=pool1,
- filters=64,
- kernel_size=[5, 5],
- padding="same",
- activation=tf.nn.relu)
- pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)
-
- # Dense Layer
- pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])
- dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)
- dropout = tf.layers.dropout(
- inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)
-
- # Logits Layer
- logits = tf.layers.dense(inputs=dropout, units=10)
-
- predictions = {
- # Generate predictions (for PREDICT and EVAL mode)
- "classes": tf.argmax(input=logits, axis=1),
- # Add `softmax_tensor` to the graph. It is used for PREDICT and by the
- # `logging_hook`.
- "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
- }
-
- if mode == tf.estimator.ModeKeys.PREDICT:
- return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
-
- # Calculate Loss (for both TRAIN and EVAL modes)
- loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
-
- # Configure the Training Op (for TRAIN mode)
- if mode == tf.estimator.ModeKeys.TRAIN:
- optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
- train_op = optimizer.minimize(
- loss=loss,
- global_step=tf.train.get_global_step())
- return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
-
- # Add evaluation metrics (for EVAL mode)
- eval_metric_ops = {
- "accuracy": tf.metrics.accuracy(
- labels=labels, predictions=predictions["classes"])}
- return tf.estimator.EstimatorSpec(
- mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
-```
-
-The following sections (with headings corresponding to each code block above)
-dive deeper into the `tf.layers` code used to create each layer, as well as how
-to calculate loss, configure the training op, and generate predictions. If
-you're already experienced with CNNs and @{$custom_estimators$TensorFlow `Estimator`s},
-and find the above code intuitive, you may want to skim these sections or just
-skip ahead to ["Training and Evaluating the CNN MNIST Classifier"](#train_eval_mnist).
-
-### Input Layer
-
-The methods in the `layers` module for creating convolutional and pooling layers
-for two-dimensional image data expect input tensors to have a shape of
-<code>[<em>batch_size</em>, <em>image_height</em>, <em>image_width</em>,
-<em>channels</em>]</code> by default. This behavior can be changed using the
-<code><em>data_format</em></code> parameter. These terms are defined as follows:
-
-* _`batch_size`_. Size of the subset of examples to use when performing
- gradient descent during training.
-* _`image_height`_. Height of the example images.
-* _`image_width`_. Width of the example images.
-* _`channels`_. Number of color channels in the example images. For color
- images, the number of channels is 3 (red, green, blue). For monochrome
- images, there is just 1 channel (black).
-* _`data_format`_. A string, one of `channels_last` (default) or `channels_first`.
- `channels_last` corresponds to inputs with shape
- `(batch, ..., channels)` while `channels_first` corresponds to
- inputs with shape `(batch, channels, ...)`.
-
-Here, our MNIST dataset is composed of monochrome 28x28 pixel images, so the
-desired shape for our input layer is <code>[<em>batch_size</em>, 28, 28,
-1]</code>.
-
-To convert our input feature map (`features`) to this shape, we can perform the
-following `reshape` operation:
-
-```python
-input_layer = tf.reshape(features["x"], [-1, 28, 28, 1])
-```
-
-Note that we've indicated `-1` for batch size, which specifies that this
-dimension should be dynamically computed based on the number of input values in
-`features["x"]`, holding the size of all other dimensions constant. This allows
-us to treat `batch_size` as a hyperparameter that we can tune. For example, if
-we feed examples into our model in batches of 5, `features["x"]` will contain
-3,920 values (one value for each pixel in each image), and `input_layer` will
-have a shape of `[5, 28, 28, 1]`. Similarly, if we feed examples in batches of
-100, `features["x"]` will contain 78,400 values, and `input_layer` will have a
-shape of `[100, 28, 28, 1]`.
-
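-The following NumPy sketch (an illustration only) shows the same `-1`
-inference for a batch of 5 images:
-
-```python
-import numpy as np
-
-batch = np.zeros(3920, dtype=np.float32)  # 5 images * 784 pixels each
-reshaped = batch.reshape(-1, 28, 28, 1)   # the -1 is inferred as 5
-print(reshaped.shape)                     # (5, 28, 28, 1)
-```
-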
-### Convolutional Layer #1
-
-In our first convolutional layer, we want to apply 32 5x5 filters to the input
-layer, with a ReLU activation function. We can use the `conv2d()` method in the
-`layers` module to create this layer as follows:
-
-```python
-conv1 = tf.layers.conv2d(
- inputs=input_layer,
- filters=32,
- kernel_size=[5, 5],
- padding="same",
- activation=tf.nn.relu)
-```
-
-The `inputs` argument specifies our input tensor, which must have the shape
-<code>[<em>batch_size</em>, <em>image_height</em>, <em>image_width</em>,
-<em>channels</em>]</code>. Here, we're connecting our first convolutional layer
-to `input_layer`, which has the shape <code>[<em>batch_size</em>, 28, 28,
-1]</code>.
-
-> Note: <code>conv2d()</code> will instead accept a shape of
-> <code>[<em>batch_size</em>, <em>channels</em>, <em>image_height</em>, <em>image_width</em>]</code> when passed the argument
-> <code>data_format="channels_first"</code>.
-
-The `filters` argument specifies the number of filters to apply (here, 32), and
-`kernel_size` specifies the dimensions of the filters as <code>[<em>height</em>,
-<em>width</em>]</code> (here, <code>[5, 5]</code>).
-
-<p class="tip"><b>TIP:</b> If filter height and width have the same value, you can instead specify a
-single integer for <code>kernel_size</code>—e.g., <code>kernel_size=5</code>.</p>
-
-The `padding` argument specifies one of two enumerated values
-(case-insensitive): `valid` (default value) or `same`. To specify that the
-output tensor should have the same height and width values as the input tensor,
-we set `padding="same"` here, which instructs TensorFlow to add 0 values to the
-edges of the input tensor to preserve height and width of 28. (Without padding,
-a 5x5 convolution over a 28x28 tensor will produce a 24x24 tensor, as there are
-24x24 locations to extract a 5x5 tile from a 28x28 grid.)
-
-The `activation` argument specifies the activation function to apply to the
-output of the convolution. Here, we specify ReLU activation with
-`tf.nn.relu`.
-
-Our output tensor produced by `conv2d()` has a shape of
-<code>[<em>batch_size</em>, 28, 28, 32]</code>: the same height and width
-dimensions as the input, but now with 32 channels holding the output from each
-of the filters.
-
-### Pooling Layer #1
-
-Next, we connect our first pooling layer to the convolutional layer we just
-created. We can use the `max_pooling2d()` method in `layers` to construct a
-layer that performs max pooling with a 2x2 filter and stride of 2:
-
-```python
-pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)
-```
-
-Again, `inputs` specifies the input tensor, with a shape of
-<code>[<em>batch_size</em>, <em>image_height</em>, <em>image_width</em>,
-<em>channels</em>]</code>. Here, our input tensor is `conv1`, the output from
-the first convolutional layer, which has a shape of <code>[<em>batch_size</em>,
-28, 28, 32]</code>.
-
-> Note: As with <code>conv2d()</code>, <code>max_pooling2d()</code> will instead
-> accept a shape of <code>[<em>batch_size</em>, <em>channels</em>,
-> <em>image_height</em>, <em>image_width</em>]</code> when passed the argument
-> <code>data_format="channels_first"</code>.
-
-The `pool_size` argument specifies the size of the max pooling filter as
-<code>[<em>height</em>, <em>width</em>]</code> (here, `[2, 2]`). If both
-dimensions have the same value, you can instead specify a single integer (e.g.,
-`pool_size=2`).
-
-The `strides` argument specifies the size of the stride. Here, we set a stride
-of 2, which indicates that the subregions extracted by the filter should be
-separated by 2 pixels in both the height and width dimensions (for a 2x2 filter,
-this means that none of the regions extracted will overlap). If you want to set
-different stride values for height and width, you can instead specify a tuple or
-list (e.g., `strides=[3, 6]`).
-
-Our output tensor produced by `max_pooling2d()` (`pool1`) has a shape of
-<code>[<em>batch_size</em>, 14, 14, 32]</code>: the 2x2 filter reduces height and width by 50% each.
-
-### Convolutional Layer #2 and Pooling Layer #2
-
-We can connect a second convolutional and pooling layer to our CNN using
-`conv2d()` and `max_pooling2d()` as before. For convolutional layer #2, we
-configure 64 5x5 filters with ReLU activation, and for pooling layer #2, we use
-the same specs as pooling layer #1 (a 2x2 max pooling filter with stride of 2):
-
-```python
-conv2 = tf.layers.conv2d(
- inputs=pool1,
- filters=64,
- kernel_size=[5, 5],
- padding="same",
- activation=tf.nn.relu)
-
-pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)
-```
-
-Note that convolutional layer #2 takes the output tensor of our first pooling
-layer (`pool1`) as input, and produces the tensor `conv2` as output. `conv2`
-has a shape of <code>[<em>batch_size</em>, 14, 14, 64]</code>, the same height and width as `pool1` (due to `padding="same"`), and 64 channels for the 64
-filters applied.
-
-Pooling layer #2 takes `conv2` as input, producing `pool2` as output. `pool2`
-has shape <code>[<em>batch_size</em>, 7, 7, 64]</code> (50% reduction of height and width from `conv2`).
-
-### Dense Layer
-
-Next, we want to add a dense layer (with 1,024 neurons and ReLU activation) to
-our CNN to perform classification on the features extracted by the
-convolution/pooling layers. Before we connect the layer, however, we'll flatten
-our feature map (`pool2`) to shape <code>[<em>batch_size</em>,
-<em>features</em>]</code>, so that our tensor has only two dimensions:
-
-```python
-pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])
-```
-
-In the `reshape()` operation above, the `-1` signifies that the *`batch_size`*
-dimension will be dynamically calculated based on the number of examples in our
-input data. Each example has 7 (`pool2` height) * 7 (`pool2` width) * 64
-(`pool2` channels) features, so we want the `features` dimension to have a value
-of 7 * 7 * 64 (3136 in total). The output tensor, `pool2_flat`, has shape
-<code>[<em>batch_size</em>, 3136]</code>.
-
-Now, we can use the `dense()` method in `layers` to connect our dense layer as
-follows:
-
-```python
-dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)
-```
-
-The `inputs` argument specifies the input tensor: our flattened feature map,
-`pool2_flat`. The `units` argument specifies the number of neurons in the dense
-layer (1,024). The `activation` argument takes the activation function; again,
-we'll use `tf.nn.relu` to add ReLU activation.
-
-To help improve the results of our model, we also apply dropout regularization
-to our dense layer, using the `dropout` method in `layers`:
-
-```python
-dropout = tf.layers.dropout(
- inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)
-```
-
-Again, `inputs` specifies the input tensor, which is the output tensor from our
-dense layer (`dense`).
-
-The `rate` argument specifies the dropout rate; here, we use `0.4`, which means
-40% of the elements will be randomly dropped out during training.
-
-The `training` argument takes a boolean specifying whether or not the model is
-currently being run in training mode; dropout will only be performed if
-`training` is `True`. Here, we check if the `mode` passed to our model function
-`cnn_model_fn` is `TRAIN` mode.
-
-Our output tensor `dropout` has shape <code>[<em>batch_size</em>, 1024]</code>.
-
-### Logits Layer
-
-The final layer in our neural network is the logits layer, which will return the
-raw values for our predictions. We create a dense layer with 10 neurons (one for
-each target class 0–9), with linear activation (the default):
-
-```python
-logits = tf.layers.dense(inputs=dropout, units=10)
-```
-
-Our final output tensor of the CNN, `logits`, has shape
-<code>[<em>batch_size</em>, 10]</code>.
-
-### Generate Predictions {#generate_predictions}
-
-The logits layer of our model returns our predictions as raw values in a
-<code>[<em>batch_size</em>, 10]</code>-dimensional tensor. Let's convert these
-raw values into two different formats that our model function can return:
-
-* The **predicted class** for each example: a digit from 0–9.
-* The **probabilities** for each possible target class for each example: the
- probability that the example is a 0, is a 1, is a 2, etc.
-
-For a given example, our predicted class is the element in the corresponding row
-of the logits tensor with the highest raw value. We can find the index of this
-element using the `tf.argmax`
-function:
-
-```python
-tf.argmax(input=logits, axis=1)
-```
-
-The `input` argument specifies the tensor from which to extract maximum
-values—here `logits`. The `axis` argument specifies the axis of the `input`
-tensor along which to find the greatest value. Here, we want to find the largest
-value along the dimension with index of 1, which corresponds to our predictions
-(recall that our logits tensor has shape <code>[<em>batch_size</em>,
-10]</code>).
-
-We can derive probabilities from our logits layer by applying softmax activation
-using `tf.nn.softmax`:
-
-```python
-tf.nn.softmax(logits, name="softmax_tensor")
-```
-
-> Note: We use the `name` argument to explicitly name this operation
-> `softmax_tensor`, so we can reference it later. (We'll set up logging for the
-> softmax values in ["Set Up a Logging Hook"](#set-up-a-logging-hook)).
-
-We compile our predictions in a dict, and return an `EstimatorSpec` object:
-
-```python
-predictions = {
- "classes": tf.argmax(input=logits, axis=1),
- "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
-}
-if mode == tf.estimator.ModeKeys.PREDICT:
- return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
-```
-
-### Calculate Loss {#calculating-loss}
-
-For both training and evaluation, we need to define a
-[loss function](https://en.wikipedia.org/wiki/Loss_function)
-that measures how closely the model's predictions match the target classes. For
-multiclass classification problems like MNIST,
-[cross entropy](https://en.wikipedia.org/wiki/Cross_entropy) is typically used
-as the loss metric. The following code calculates cross entropy when the model
-runs in either `TRAIN` or `EVAL` mode:
-
-```python
-loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
-```
-
-Let's take a closer look at what's happening above.
-
-Our `labels` tensor contains a list of class indices for our examples, e.g. `[1,
-9, ...]`. `logits` contains the linear outputs of our last layer.
-
-`tf.losses.sparse_softmax_cross_entropy` calculates the softmax cross-entropy
-(also known as categorical cross-entropy or negative log-likelihood) from these
-two inputs in an efficient, numerically stable way.
-
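-As a sanity check on what this loss computes, here is a small NumPy sketch
-(an illustration only, without the numerical-stability tricks the TensorFlow op
-uses internally):
-
-```python
-import numpy as np
-
-def sparse_softmax_cross_entropy(labels, logits):
-    # Softmax over the class dimension, then the negative log-likelihood of
-    # the true class, averaged over the batch.
-    exp = np.exp(logits)
-    probs = exp / exp.sum(axis=1, keepdims=True)
-    return -np.log(probs[np.arange(len(labels)), labels]).mean()
-
-labels = np.array([1, 9])        # true classes for two examples
-logits = np.random.randn(2, 10)  # raw outputs of the logits layer
-print(sparse_softmax_cross_entropy(labels, logits))
-```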
-
-### Configure the Training Op
-
-In the previous section, we defined loss for our CNN as the softmax
-cross-entropy of the logits layer and our labels. Let's configure our model to
-optimize this loss value during training. We'll use a learning rate of 0.001 and
-[stochastic gradient descent](https://en.wikipedia.org/wiki/Stochastic_gradient_descent)
-as the optimization algorithm:
-
-```python
-if mode == tf.estimator.ModeKeys.TRAIN:
- optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
- train_op = optimizer.minimize(
- loss=loss,
- global_step=tf.train.get_global_step())
- return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
-```
-
-> Note: For a more in-depth look at configuring training ops for Estimator model
-> functions, see @{$custom_estimators#defining-the-training-op-for-the-model$"Defining the training op for the model"}
-> in the @{$custom_estimators$"Creating Estimators in tf.estimator"} tutorial.
-
-
-### Add evaluation metrics
-
-To add an accuracy metric to our model, we define an `eval_metric_ops` dict in
-`EVAL` mode as follows:
-
-```python
-eval_metric_ops = {
- "accuracy": tf.metrics.accuracy(
- labels=labels, predictions=predictions["classes"])}
-return tf.estimator.EstimatorSpec(
- mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
-```
-
-<a id="train_eval_mnist"></a>
-## Training and Evaluating the CNN MNIST Classifier
-
-We've coded our MNIST CNN model function; now we're ready to train and evaluate
-it.
-
-### Load Training and Test Data
-
-First, let's load our training and test data. Add a `main()` function to
-`cnn_mnist.py` with the following code:
-
-```python
-def main(unused_argv):
- # Load training and eval data
- mnist = tf.contrib.learn.datasets.load_dataset("mnist")
- train_data = mnist.train.images # Returns np.array
- train_labels = np.asarray(mnist.train.labels, dtype=np.int32)
- eval_data = mnist.test.images # Returns np.array
- eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)
-```
-
-We store the training feature data (the raw pixel values for 55,000 images of
-hand-drawn digits) and training labels (the corresponding value from 0–9 for
-each image) as [numpy
-arrays](https://docs.scipy.org/doc/numpy/reference/generated/numpy.array.html)
-in `train_data` and `train_labels`, respectively. Similarly, we store the
-evaluation feature data (10,000 images) and evaluation labels in `eval_data`
-and `eval_labels`, respectively.
-
-### Create the Estimator {#create-the-estimator}
-
-Next, let's create an `Estimator` (a TensorFlow class for performing high-level
-model training, evaluation, and inference) for our model. Add the following code
-to `main()`:
-
-```python
-# Create the Estimator
-mnist_classifier = tf.estimator.Estimator(
- model_fn=cnn_model_fn, model_dir="/tmp/mnist_convnet_model")
-```
-
-The `model_fn` argument specifies the model function to use for training,
-evaluation, and prediction; we pass it the `cnn_model_fn` we created in
-["Building the CNN MNIST Classifier."](#building-the-cnn-mnist-classifier) The
-`model_dir` argument specifies the directory where model data (checkpoints) will
-be saved (here, we specify the temp directory `/tmp/mnist_convnet_model`, but
-feel free to change to another directory of your choice).
-
-> Note: For an in-depth walkthrough of the TensorFlow `Estimator` API, see the
-> tutorial @{$custom_estimators$"Creating Estimators in tf.estimator."}
-
-### Set Up a Logging Hook {#set-up-a-logging-hook}
-
-Since CNNs can take a while to train, let's set up some logging so we can track
-progress during training. We can use a `tf.train.LoggingTensorHook` (a type of
-TensorFlow's `tf.train.SessionRunHook`) to log the probability values from the
-softmax layer of our CNN. Add the
-following to `main()`:
-
-```python
-# Set up logging for predictions
-tensors_to_log = {"probabilities": "softmax_tensor"}
-logging_hook = tf.train.LoggingTensorHook(
- tensors=tensors_to_log, every_n_iter=50)
-```
-
-We store a dict of the tensors we want to log in `tensors_to_log`. Each key is a
-label of our choice that will be printed in the log output, and the
-corresponding value is the name of a `Tensor` in the TensorFlow graph. Here, our
-`probabilities` can be found in `softmax_tensor`, the name we gave our softmax
-operation earlier when we generated the probabilities in `cnn_model_fn`.
-
-> Note: If you don't explicitly assign a name to an operation via the `name`
-> argument, TensorFlow will assign a default name. A couple of easy ways to
-> discover the names applied to operations are to visualize your graph on
-> @{$graph_viz$TensorBoard} or to enable the
-> @{$guide/debugger$TensorFlow Debugger (tfdbg)}.
-
-Next, we create the `LoggingTensorHook`, passing `tensors_to_log` to the
-`tensors` argument. We set `every_n_iter=50`, which specifies that probabilities
-should be logged after every 50 steps of training.
-
-### Train the Model
-
-Now we're ready to train our model, which we can do by creating `train_input_fn`
-and calling `train()` on `mnist_classifier`. Add the following to `main()`:
-
-```python
-# Train the model
-train_input_fn = tf.estimator.inputs.numpy_input_fn(
- x={"x": train_data},
- y=train_labels,
- batch_size=100,
- num_epochs=None,
- shuffle=True)
-mnist_classifier.train(
- input_fn=train_input_fn,
- steps=20000,
- hooks=[logging_hook])
-```
-
-In the `numpy_input_fn` call, we pass the training feature data and labels to
-`x` (as a dict) and `y`, respectively. We set a `batch_size` of `100` (which
-means that the model will train on minibatches of 100 examples at each step).
-`num_epochs=None` means that the model will train until the specified number of
-steps is reached. We also set `shuffle=True` to shuffle the training data.
-In the `train` call, we set `steps=20000`
-(which means the model will train for 20,000 steps total). We pass our
-`logging_hook` to the `hooks` argument, so that it will be triggered during
-training.
-
-### Evaluate the Model
-
-Once training is complete, we want to evaluate our model to determine its
-accuracy on the MNIST test set. We call the `evaluate` method, which evaluates
-the metrics we specified in the `eval_metric_ops` argument of the `model_fn`.
-Add the following to `main()`:
-
-```python
-# Evaluate the model and print results
-eval_input_fn = tf.estimator.inputs.numpy_input_fn(
- x={"x": eval_data},
- y=eval_labels,
- num_epochs=1,
- shuffle=False)
-eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
-print(eval_results)
-```
-
-To create `eval_input_fn`, we set `num_epochs=1`, so that the model evaluates
-the metrics over one epoch of data and returns the result. We also set
-`shuffle=False` to iterate through the data sequentially.
-
-### Run the Model
-
-We've coded the CNN model function, `Estimator`, and the training/evaluation
-logic; now let's see the results. Run `cnn_mnist.py`.
-
-> Note: Training CNNs is quite computationally intensive. Estimated completion
-> time of `cnn_mnist.py` will vary depending on your processor, but will likely
-> be upwards of 1 hour on CPU. To train more quickly, you can decrease the
-> number of `steps` passed to `train()`, but note that this will affect accuracy.
-
-As the model trains, you'll see log output like the following:
-
-```python
-INFO:tensorflow:loss = 2.36026, step = 1
-INFO:tensorflow:probabilities = [[ 0.07722801 0.08618255 0.09256398, ...]]
-...
-INFO:tensorflow:loss = 2.13119, step = 101
-INFO:tensorflow:global_step/sec: 5.44132
-...
-INFO:tensorflow:Loss for final step: 0.553216.
-
-INFO:tensorflow:Restored model from /tmp/mnist_convnet_model
-INFO:tensorflow:Eval steps [0,inf) for training step 20000.
-INFO:tensorflow:Input iterator is exhausted.
-INFO:tensorflow:Saving evaluation summary for step 20000: accuracy = 0.9733, loss = 0.0902271
-{'loss': 0.090227105, 'global_step': 20000, 'accuracy': 0.97329998}
-```
-
-Here, we've achieved an accuracy of 97.3% on our test data set.
-
-## Additional Resources
-
-To learn more about TensorFlow Estimators and CNNs in TensorFlow, see the
-following resources:
-
-* @{$custom_estimators$Creating Estimators in tf.estimator}
- provides an introduction to the TensorFlow Estimator API. It walks through
- configuring an Estimator, writing a model function, calculating loss, and
- defining a training op.
-* @{$deep_cnn} walks through how to build an MNIST CNN classification model
-  *without estimators* using lower-level TensorFlow operations.
diff --git a/tensorflow/docs_src/tutorials/estimators/linear.md b/tensorflow/docs_src/tutorials/estimators/linear.md
deleted file mode 100644
index 067a33ac03..0000000000
--- a/tensorflow/docs_src/tutorials/estimators/linear.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# Build a linear model with Estimators
-
-[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/tutorials/estimators/linear.ipynb)
diff --git a/tensorflow/docs_src/tutorials/images/deep_cnn.md b/tensorflow/docs_src/tutorials/images/deep_cnn.md
deleted file mode 100644
index 42ad484bbf..0000000000
--- a/tensorflow/docs_src/tutorials/images/deep_cnn.md
+++ /dev/null
@@ -1,446 +0,0 @@
-# Advanced Convolutional Neural Networks
-
-## Overview
-
-CIFAR-10 classification is a common benchmark problem in machine learning. The
-problem is to classify RGB 32x32 pixel images across 10 categories:
-```
-airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck.
-```
-
-For more details refer to the [CIFAR-10 page](https://www.cs.toronto.edu/~kriz/cifar.html)
-and a [Tech Report](https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf)
-by Alex Krizhevsky.
-
-### Goals
-
-The goal of this tutorial is to build a relatively small [convolutional neural
-network](https://en.wikipedia.org/wiki/Convolutional_neural_network) (CNN) for
-recognizing images. In the process, this tutorial:
-
-1. Highlights a canonical organization for network architecture,
-training and evaluation.
-2. Provides a template for constructing larger and more sophisticated models.
-
-CIFAR-10 was selected because it is complex enough to exercise much of
-TensorFlow's ability to scale to large models. At the same time, the model is
-small enough to train quickly, which is ideal for trying out new ideas and
-experimenting with new techniques.
-
-### Highlights of the Tutorial
-The CIFAR-10 tutorial demonstrates several important constructs for
-designing larger and more sophisticated models in TensorFlow:
-
-* Core mathematical components including `tf.nn.conv2d`
-([wiki](https://en.wikipedia.org/wiki/Convolution)),
-`tf.nn.relu`
-([wiki](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))),
-`tf.nn.max_pool`
-([wiki](https://en.wikipedia.org/wiki/Convolutional_neural_network#Pooling_layer))
-and `tf.nn.local_response_normalization`
-(Chapter 3.3 in
-[AlexNet paper](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf)).
-* @{$summaries_and_tensorboard$Visualization}
-of network activities during training, including input images,
-losses and distributions of activations and gradients.
-* Routines for calculating the
-`tf.train.ExponentialMovingAverage`
-of learned parameters and using these averages
-during evaluation to boost predictive performance.
-* A learning rate schedule based on
-`tf.train.exponential_decay`
-that systematically decays the rate over time.
-* Prefetching queues built with `tf.train.shuffle_batch` for input
-data, to isolate the model from disk latency and expensive image pre-processing.
-
-We also provide a [multi-GPU version](#training-a-model-using-multiple-gpu-cards)
-of the model which demonstrates:
-
-* Configuring a model to train across multiple GPU cards in parallel.
-* Sharing and updating variables among multiple GPUs.
-
-We hope that this tutorial provides a launch point for building larger CNNs for
-vision tasks on TensorFlow.
-
-### Model Architecture
-
-The model in this CIFAR-10 tutorial is a multi-layer architecture consisting of
-alternating convolutions and nonlinearities. These layers are followed by fully
-connected layers leading into a softmax classifier. The model follows the
-architecture described by
-[Alex Krizhevsky](https://code.google.com/p/cuda-convnet/), with a few
-differences in the top few layers.
-
-This model achieves a peak performance of about 86% accuracy within a few hours
-of training time on a GPU. Please see [below](#evaluating-a-model) and the code
-for details. It consists of 1,068,298 learnable parameters and requires about
-19.5M multiply-add operations to compute inference on a single image.
-
-## Code Organization
-
-The code for this tutorial resides in
-[`models/tutorials/image/cifar10/`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/).
-
-File | Purpose
---- | ---
-[`cifar10_input.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10_input.py) | Reads the native CIFAR-10 binary file format.
-[`cifar10.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10.py) | Builds the CIFAR-10 model.
-[`cifar10_train.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10_train.py) | Trains a CIFAR-10 model on a CPU or GPU.
-[`cifar10_multi_gpu_train.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py) | Trains a CIFAR-10 model on multiple GPUs.
-[`cifar10_eval.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10_eval.py) | Evaluates the predictive performance of a CIFAR-10 model.
-
-
-## CIFAR-10 Model
-
-The CIFAR-10 network is largely contained in
-[`cifar10.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10.py).
-The complete training
-graph contains roughly 765 operations. We find that we can make the code most
-reusable by constructing the graph with the following modules:
-
-1. [**Model inputs:**](#model-inputs) `inputs()` and `distorted_inputs()` add
-operations that read and preprocess CIFAR images for evaluation and training,
-respectively.
-1. [**Model prediction:**](#model-prediction) `inference()`
-adds operations that perform inference, i.e. classification, on supplied images.
-1. [**Model training:**](#model-training) `loss()` and `train()`
-add operations that compute the loss,
-gradients, variable updates and visualization summaries.
-
-### Model Inputs
-
-The input part of the model is built by the functions `inputs()` and
-`distorted_inputs()` which read images from the CIFAR-10 binary data files.
-These files contain fixed byte length records, so we use
-`tf.FixedLengthRecordReader`.
-See @{$reading_data#reading-from-files$Reading Data} to
-learn more about how the `Reader` class works.
-
-The images are processed as follows:
-
-* They are cropped to 24 x 24 pixels, centrally for evaluation or with
-  `tf.random_crop` for training.
-* They are approximately whitened with `tf.image.per_image_standardization`
-  to make the model insensitive to dynamic range.
-
-For training, we additionally apply a series of random distortions to
-artificially increase the data set size (see the sketch after this list):
-
-* Randomly flip the image from left to right with `tf.image.random_flip_left_right`.
-* Randomly distort the image brightness with `tf.image.random_brightness`.
-* Randomly distort the image contrast with `tf.image.random_contrast`.
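-
-Putting these steps together, here is a minimal sketch (not the actual
-`cifar10_input.py` code; the distortion parameters are illustrative) of the
-per-image training pipeline described above, assuming `image` is an
-already-decoded `uint8` tensor of shape `[32, 32, 3]`:
-
-```python
-import tensorflow as tf
-
-def preprocess_for_training(image):
-  image = tf.cast(image, tf.float32)
-  image = tf.random_crop(image, [24, 24, 3])             # random 24x24 crop
-  image = tf.image.random_flip_left_right(image)
-  image = tf.image.random_brightness(image, max_delta=63)
-  image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
-  return tf.image.per_image_standardization(image)       # zero mean, unit norm
-```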
-
-Please see the @{$python/image$Images} page for the list of
-available distortions. We also attach an
-`tf.summary.image` to the images
-so that we may visualize them in @{$summaries_and_tensorboard$TensorBoard}.
-This is a good practice to verify that inputs are built correctly.
-
-<div style="width:50%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:70%" src="https://www.tensorflow.org/images/cifar_image_summary.png">
-</div>
-
-Reading images from disk and distorting them can use a non-trivial amount of
-processing time. To prevent these operations from slowing down training, we run
-them inside 16 separate threads which continuously fill a TensorFlow queue via
-`tf.train.shuffle_batch`.
-
-### Model Prediction
-
-The prediction part of the model is constructed by the `inference()` function
-which adds operations to compute the *logits* of the predictions. That part of
-the model is organized as follows:
-
-Layer Name | Description
---- | ---
-`conv1` | `tf.nn.conv2d` and `tf.nn.relu` activation.
-`pool1` | `tf.nn.max_pool`.
-`norm1` | `tf.nn.local_response_normalization`.
-`conv2` | `tf.nn.conv2d` and `tf.nn.relu` activation.
-`norm2` | `tf.nn.local_response_normalization`.
-`pool2` | `tf.nn.max_pool`.
-`local3` | @{$python/nn$fully connected layer with rectified linear activation}.
-`local4` | @{$python/nn$fully connected layer with rectified linear activation}.
-`softmax_linear` | linear transformation to produce logits.
-
-Here is a graph generated from TensorBoard describing the inference operation:
-
-<div style="width:15%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="https://www.tensorflow.org/images/cifar_graph.png">
-</div>
-
-> **EXERCISE**: The output of `inference` is un-normalized logits. Try editing
-the network architecture to return normalized predictions using
-`tf.nn.softmax`.
-
-The `inputs()` and `inference()` functions provide all the components
-necessary to perform an evaluation of a model. We now shift our focus towards
-building operations for training a model.
-
-> **EXERCISE:** The model architecture in `inference()` differs slightly from
-the CIFAR-10 model specified in
-[cuda-convnet](https://code.google.com/p/cuda-convnet/). In particular, the top
-layers of Alex's original model are locally connected and not fully connected.
-Try editing the architecture to exactly reproduce the locally connected
-architecture in the top layer.
-
-### Model Training
-
-The usual method for training a network to perform N-way classification is
-[multinomial logistic regression](https://en.wikipedia.org/wiki/Multinomial_logistic_regression),
-a.k.a. *softmax regression*. Softmax regression applies a
-`tf.nn.softmax` nonlinearity to the
-output of the network and calculates the
-`tf.nn.sparse_softmax_cross_entropy_with_logits`
-between the normalized predictions and the label index.
-For regularization, we also apply the usual
-`tf.nn.l2_loss` losses to all learned
-variables. The objective function for the model is the sum of the cross entropy
-loss and all these weight decay terms, as returned by the `loss()` function.
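-
-As a rough sketch of that objective (assuming `logits` and integer `labels`
-tensors, and that each weight-decay term was added to the `'losses'`
-collection when the corresponding variable was created):
-
-```python
-cross_entropy = tf.reduce_mean(
-    tf.nn.sparse_softmax_cross_entropy_with_logits(
-        labels=labels, logits=logits))
-tf.add_to_collection('losses', cross_entropy)
-# Total loss = cross entropy + all L2 weight decay terms.
-total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
-```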
-
-We visualize it in TensorBoard with a `tf.summary.scalar`:
-
-![CIFAR-10 Loss](https://www.tensorflow.org/images/cifar_loss.png "CIFAR-10 Total Loss")
-
-We train the model using the standard
-[gradient descent](https://en.wikipedia.org/wiki/Gradient_descent)
-algorithm (see @{$python/train$Training} for other methods)
-with a learning rate that decays exponentially over time via
-`tf.train.exponential_decay`.
-
-![CIFAR-10 Learning Rate Decay](https://www.tensorflow.org/images/cifar_lr_decay.png "CIFAR-10 Learning Rate Decay")
-
-The `train()` function adds the operations needed to minimize the objective by
-calculating the gradient and updating the learned variables (see
-`tf.train.GradientDescentOptimizer`
-for details). It returns an operation that executes all the calculations
-needed to train and update the model for one batch of images.
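-
-A minimal sketch of this combination (the decay parameters here are
-illustrative, not the values used by `cifar10.py`; `total_loss` is assumed
-to be defined as above):
-
-```python
-global_step = tf.train.get_or_create_global_step()
-learning_rate = tf.train.exponential_decay(
-    learning_rate=0.1,       # initial learning rate
-    global_step=global_step,
-    decay_steps=100000,      # decay every 100k steps...
-    decay_rate=0.1,          # ...by a factor of 10
-    staircase=True)
-train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(
-    total_loss, global_step=global_step)
-```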
-
-## Launching and Training the Model
-
-We have built the model; let's now launch it and run the training operation
-with the script `cifar10_train.py`.
-
-```shell
-python cifar10_train.py
-```
-
-> **NOTE:** The first time you run any target in the CIFAR-10 tutorial,
-the CIFAR-10 dataset is automatically downloaded. The data set is ~160MB
-so you may want to grab a quick cup of coffee for your first run.
-
-You should see the output:
-
-```shell
-Filling queue with 20000 CIFAR images before starting to train. This will take a few minutes.
-2015-11-04 11:45:45.927302: step 0, loss = 4.68 (2.0 examples/sec; 64.221 sec/batch)
-2015-11-04 11:45:49.133065: step 10, loss = 4.66 (533.8 examples/sec; 0.240 sec/batch)
-2015-11-04 11:45:51.397710: step 20, loss = 4.64 (597.4 examples/sec; 0.214 sec/batch)
-2015-11-04 11:45:54.446850: step 30, loss = 4.62 (391.0 examples/sec; 0.327 sec/batch)
-2015-11-04 11:45:57.152676: step 40, loss = 4.61 (430.2 examples/sec; 0.298 sec/batch)
-2015-11-04 11:46:00.437717: step 50, loss = 4.59 (406.4 examples/sec; 0.315 sec/batch)
-...
-```
-
-The script reports the total loss every 10 steps as well as the speed at which
-the last batch of data was processed. A few comments:
-
-* The first batch of data can be inordinately slow (e.g. several minutes) as the
-preprocessing threads fill up the shuffling queue with 20,000 processed CIFAR
-images.
-
-* The reported loss is the average loss of the most recent batch. Remember that
-this loss is the sum of the cross entropy and all weight decay terms.
-
-* Keep an eye on the processing speed of a batch. The numbers shown above were
-obtained on a Tesla K40c. If you are running on a CPU, expect slower performance.
-
-
-> **EXERCISE:** When experimenting, it is sometimes annoying that the first
-training step can take so long. Try decreasing the number of images that
-initially fill up the queue. Search for `min_fraction_of_examples_in_queue`
-in `cifar10_input.py`.
-
-`cifar10_train.py` periodically uses a `tf.train.Saver` to save
-all model parameters in
-@{$guide/saved_model$checkpoint files}
-but it does *not* evaluate the model. The checkpoint file
-will be used by `cifar10_eval.py` to measure the predictive
-performance (see [Evaluating a Model](#evaluating-a-model) below).
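-
-A minimal sketch of this kind of periodic checkpointing (`train_op` and
-`max_steps` are assumed to be defined elsewhere, and the checkpoint path is
-illustrative):
-
-```python
-saver = tf.train.Saver()
-with tf.Session() as sess:
-  sess.run(tf.global_variables_initializer())
-  for step in range(max_steps):
-    sess.run(train_op)
-    if step % 1000 == 0:
-      # Writes files like /tmp/cifar10_train/model.ckpt-1000.
-      saver.save(sess, '/tmp/cifar10_train/model.ckpt', global_step=step)
-```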
-
-
-If you followed the previous steps, then you have now started training
-a CIFAR-10 model. [Congratulations!](https://www.youtube.com/watch?v=9bZkp7q19f0)
-
-The terminal text returned from `cifar10_train.py` provides minimal insight into
-how the model is training. We want more insight into the model during training:
-
-* Is the loss *really* decreasing or is that just noise?
-* Is the model being provided appropriate images?
-* Are the gradients, activations and weights reasonable?
-* What is the learning rate currently at?
-
-@{$summaries_and_tensorboard$TensorBoard} provides this
-functionality, displaying data exported periodically from `cifar10_train.py` via
-a
-`tf.summary.FileWriter`.
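-
-A minimal sketch of exporting summaries this way (assuming the graph already
-has `tf.summary.*` ops attached, as `cifar10.py` does, and an illustrative
-log directory):
-
-```python
-merged = tf.summary.merge_all()
-writer = tf.summary.FileWriter('/tmp/cifar10_train', tf.get_default_graph())
-with tf.Session() as sess:
-  sess.run(tf.global_variables_initializer())
-  summary = sess.run(merged)
-  writer.add_summary(summary, global_step=0)
-```
-
-Pointing TensorBoard at the same directory
-(`tensorboard --logdir=/tmp/cifar10_train`) then displays the exported data.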
-
-For instance, we can watch how the distribution of activations and degree of
-sparsity in `local3` features evolve during training:
-
-<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px; display: flex; flex-direction: row">
- <img style="flex-grow:1; flex-shrink:1;" src="https://www.tensorflow.org/images/cifar_sparsity.png">
- <img style="flex-grow:1; flex-shrink:1;" src="https://www.tensorflow.org/images/cifar_activations.png">
-</div>
-
-Individual loss functions, as well as the total loss, are particularly
-interesting to track over time. However, the loss exhibits a considerable amount
-of noise due to the small batch size employed by training. In practice we find
-it extremely useful to visualize their moving averages in addition to their raw
-values. See how the scripts use
-`tf.train.ExponentialMovingAverage`
-for this purpose.
-
-## Evaluating a Model
-
-Let us now evaluate how well the trained model performs on a hold-out data set.
-The model is evaluated by the script `cifar10_eval.py`. It constructs the model
-with the `inference()` function and uses all 10,000 images in the evaluation set
-of CIFAR-10. It calculates the *precision at 1:* how often the top prediction
-matches the true label of the image.
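-
-In TensorFlow terms, precision at 1 can be sketched as follows (assuming
-`logits` and integer `labels` tensors):
-
-```python
-top_k_op = tf.nn.in_top_k(logits, labels, k=1)   # bool: top guess == label?
-precision_at_1 = tf.reduce_mean(tf.cast(top_k_op, tf.float32))
-```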
-
-To monitor how the model improves during training, the evaluation script runs
-periodically on the latest checkpoint files created by `cifar10_train.py`.
-
-```shell
-python cifar10_eval.py
-```
-
-> Be careful not to run the evaluation and training binary on the same GPU or
-else you might run out of memory. Consider running the evaluation on
-a separate GPU if available or suspending the training binary while running
-the evaluation on the same GPU.
-
-You should see the output:
-
-```shell
-2015-11-06 08:30:44.391206: precision @ 1 = 0.860
-...
-```
-
-The script merely returns the precision @ 1 periodically -- in this case
-it returned 86% accuracy. `cifar10_eval.py` also
-exports summaries that may be visualized in TensorBoard. These summaries
-provide additional insight into the model during evaluation.
-
-The training script calculates the
-`tf.train.ExponentialMovingAverage` of all learned variables.
-The evaluation script substitutes
-all learned model parameters with the moving average version. This
-substitution boosts model performance at evaluation time.
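-
-A sketch of this substitution (the decay value and checkpoint path here are
-assumptions for illustration):
-
-```python
-ema = tf.train.ExponentialMovingAverage(decay=0.9999)
-# Maps each variable name to its shadow (moving-average) counterpart.
-saver = tf.train.Saver(ema.variables_to_restore())
-with tf.Session() as sess:
-  saver.restore(sess, '/tmp/cifar10_train/model.ckpt-20000')
-```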
-
-> **EXERCISE:** Employing averaged parameters may boost predictive performance
-by about 3% as measured by precision @ 1. Edit `cifar10_eval.py` to not employ
-the averaged parameters for the model and verify that the predictive performance
-drops.
-
-
-## Training a Model Using Multiple GPU Cards
-
-Modern workstations may contain multiple GPUs for scientific computation.
-TensorFlow can leverage this environment to run the training operation
-concurrently across multiple cards.
-
-Training a model in a parallel, distributed fashion requires
-coordinating training processes. In what follows, we use the term *model
-replica* for one copy of the model training on a subset of the data.
-
-Naively employing asynchronous updates of model parameters
-leads to sub-optimal training performance
-because an individual model replica might be trained on a stale
-copy of the model parameters. Conversely, employing fully synchronous
-updates will be as slow as the slowest model replica.
-
-In a workstation with multiple GPU cards, each GPU will have similar speed
-and contain enough memory to run an entire CIFAR-10 model. Thus, we opt to
-design our training system in the following manner:
-
-* Place an individual model replica on each GPU.
-* Update model parameters synchronously by waiting for all GPUs to finish
-processing a batch of data.
-
-Here is a diagram of this model:
-
-<div style="width:40%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="https://www.tensorflow.org/images/Parallelism.png">
-</div>
-
-Note that each GPU computes inference as well as the gradients for a unique
-batch of data. This setup effectively permits dividing up a larger batch
-of data across the GPUs.
-
-This setup requires that all GPUs share the model parameters. A well-known
-fact is that transferring data to and from GPUs is quite slow. For this
-reason, we decide to store and update all model parameters on the CPU (see
-green box). A fresh set of model parameters is transferred to the GPU
-when a new batch of data is processed by all GPUs.
-
-The GPUs are synchronized in operation. All gradients are accumulated from
-the GPUs and averaged (see green box). The model parameters are updated with
-the gradients averaged across all model replicas.
-
-### Placing Variables and Operations on Devices
-
-Placing operations and variables on devices requires some special
-abstractions.
-
-The first abstraction we require is a function for computing inference and
-gradients for a single model replica. In the code we term this abstraction
-a "tower". We must set two attributes for each tower:
-
-* A unique name for all operations within a tower.
-`tf.name_scope` provides
-this unique name by prepending a scope. For instance, all operations in
-the first tower are prepended with `tower_0`, e.g. `tower_0/conv1/Conv2D`.
-
-* A preferred hardware device to run the operation within a tower.
-`tf.device` specifies this. For
-instance, all operations in the first tower reside within `device('/device:GPU:0')`
-scope indicating that they should be run on the first GPU.
-
-All variables are pinned to the CPU and accessed via
-`tf.get_variable`
-in order to share them in a multi-GPU version.
-See how-to on @{$variables$Sharing Variables}.
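-
-A simplified sketch of this tower pattern (the real implementation lives in
-`cifar10_multi_gpu_train.py`; `num_gpus`, `optimizer` and `tower_loss()` are
-assumed to be defined elsewhere):
-
-```python
-tower_grads = []
-for i in range(num_gpus):
-  with tf.device('/device:GPU:%d' % i):
-    with tf.name_scope('tower_%d' % i) as scope:
-      loss = tower_loss(scope)                    # builds one model replica
-      tower_grads.append(optimizer.compute_gradients(loss))
-# The per-tower gradients are then averaged and applied on the CPU,
-# where the shared variables are pinned.
-```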
-
-### Launching and Training the Model on Multiple GPU Cards
-
-If you have several GPU cards installed on your machine you can use them to
-train the model faster with the `cifar10_multi_gpu_train.py` script. This
-version of the training script parallelizes the model across multiple GPU cards.
-
-```shell
-python cifar10_multi_gpu_train.py --num_gpus=2
-```
-
-Note that the number of GPU cards used defaults to 1. Additionally, if only 1
-GPU is available on your machine, all computations will be placed on it, even if
-you ask for more.
-
-> **EXERCISE:** The default setting for `cifar10_train.py` is a batch size of
-128. Try running `cifar10_multi_gpu_train.py` on 2 GPUs
-with a batch size of 64 and compare the training speed.
-
-## Next Steps
-
-If you are now interested in developing and training your own image
-classification system, we recommend forking this tutorial and replacing
-components to address your image classification problem.
-
-
-> **EXERCISE:** Download the
-[Street View House Numbers (SVHN)](http://ufldl.stanford.edu/housenumbers/) data set.
-Fork the CIFAR-10 tutorial and swap in the SVHN as the input data. Try adapting
-the network architecture to improve predictive performance.
diff --git a/tensorflow/docs_src/tutorials/images/image_recognition.md b/tensorflow/docs_src/tutorials/images/image_recognition.md
deleted file mode 100644
index 83a8d97cf0..0000000000
--- a/tensorflow/docs_src/tutorials/images/image_recognition.md
+++ /dev/null
@@ -1,455 +0,0 @@
-# Image Recognition
-
-Our brains make vision seem easy. It doesn't take any effort for humans to
-tell apart a lion and a jaguar, read a sign, or recognize a human's face.
-But these are actually hard problems to solve with a computer: they only
-seem easy because our brains are incredibly good at understanding images.
-
-In the last few years, the field of machine learning has made tremendous
-progress on addressing these difficult problems. In particular, we've
-found that a kind of model called a deep
-[convolutional neural network](https://colah.github.io/posts/2014-07-Conv-Nets-Modular/)
-can achieve reasonable performance on hard visual recognition tasks --
-matching or exceeding human performance in some domains.
-
-Researchers have demonstrated steady progress
-in computer vision by validating their work against
-[ImageNet](http://www.image-net.org) -- an academic benchmark for computer vision.
-Successive models continue to show improvements, each time achieving
-a new state-of-the-art result:
-[QuocNet], [AlexNet], [Inception (GoogLeNet)], [BN-Inception-v2].
-Researchers both internal and external to Google have published papers
-describing all these models, but the results are still hard to reproduce.
-We're now taking the next step by releasing code for running image recognition
-on our latest model, [Inception-v3].
-
-[QuocNet]: https://static.googleusercontent.com/media/research.google.com/en//archive/unsupervised_icml2012.pdf
-[AlexNet]: https://www.cs.toronto.edu/~fritz/absps/imagenet.pdf
-[Inception (GoogLeNet)]: https://arxiv.org/abs/1409.4842
-[BN-Inception-v2]: https://arxiv.org/abs/1502.03167
-[Inception-v3]: https://arxiv.org/abs/1512.00567
-
-Inception-v3 is trained for the [ImageNet] Large Visual Recognition Challenge
-using the data from 2012. This is a standard task in computer vision,
-where models try to classify entire
-images into [1000 classes], like "Zebra", "Dalmatian", and "Dishwasher".
-For example, here are the results from [AlexNet] classifying some images:
-
-<div style="width:50%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="https://www.tensorflow.org/images/AlexClassification.png">
-</div>
-
-To compare models, we examine how often the model fails to predict the
-correct answer as one of its top 5 guesses -- termed "top-5 error rate".
-[AlexNet] set a top-5 error rate of 15.3% on the 2012
-validation data set; [Inception (GoogLeNet)] achieved 6.67%;
-[BN-Inception-v2] achieved 4.9%; [Inception-v3] reaches 3.46%.
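-
-As a hypothetical NumPy sketch, the top-5 error rate over a batch of
-predictions could be computed like this:
-
-```python
-import numpy as np
-
-def top5_error_rate(probs, labels):
-  """probs: [N, 1000] class probabilities; labels: [N] true class ids."""
-  top5 = np.argsort(probs, axis=1)[:, -5:]         # 5 highest-scoring classes
-  hits = np.any(top5 == labels[:, None], axis=1)   # true label among them?
-  return 1.0 - np.mean(hits)
-```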
-
-> How well do humans do on ImageNet Challenge? There's a [blog post] by
-Andrej Karpathy who attempted to measure his own performance. He reached
-5.1% top-5 error rate.
-
-[ImageNet]: http://image-net.org/
-[1000 classes]: http://image-net.org/challenges/LSVRC/2014/browse-synsets
-[blog post]: https://karpathy.github.io/2014/09/02/what-i-learned-from-competing-against-a-convnet-on-imagenet/
-
-This tutorial will teach you how to use [Inception-v3]. You'll learn how to
-classify images into [1000 classes] in Python or C++. We'll also discuss how to
-extract higher level features from this model which may be reused for other
-vision tasks.
-
-We're excited to see what the community will do with this model.
-
-
-## Usage with the Python API
-
-`classify_image.py` downloads the trained model from `tensorflow.org`
-when the program is run for the first time. You'll need about 200MB of free space
-available on your hard disk.
-
-Start by cloning the [TensorFlow models repo](https://github.com/tensorflow/models) from GitHub. Run the following commands:
-
- cd models/tutorials/image/imagenet
- python classify_image.py
-
-The above command will classify a supplied image of a panda bear.
-
-<div style="width:15%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="https://www.tensorflow.org/images/cropped_panda.jpg">
-</div>
-
-If the model runs correctly, the script will produce the following output:
-
- giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca (score = 0.88493)
- indri, indris, Indri indri, Indri brevicaudatus (score = 0.00878)
- lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens (score = 0.00317)
- custard apple (score = 0.00149)
- earthstar (score = 0.00127)
-
-If you wish to supply other JPEG images, you may do so by editing
-the `--image_file` argument.
-
-> If you download the model data to a different directory, you
-will need to point `--model_dir` to the directory used.
-
-## Usage with the C++ API
-
-You can run the same [Inception-v3] model in C++ for use in production
-environments. You can download the archive containing the GraphDef that defines
-the model like this (running from the root directory of the TensorFlow
-repository):
-
-```bash
-curl -L "https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz" |
- tar -C tensorflow/examples/label_image/data -xz
-```
-
-Next, we need to compile the C++ binary that includes the code to load and run the graph.
-If you've followed
-@{$install_sources$the instructions to download the source installation of TensorFlow}
-for your platform, you should be able to build the example by
-running this command from your shell terminal:
-
-```bash
-bazel build tensorflow/examples/label_image/...
-```
-
-That should create a binary executable that you can then run like this:
-
-```bash
-bazel-bin/tensorflow/examples/label_image/label_image
-```
-
-This uses the default example image that ships with the framework, and should
-output something similar to this:
-
-```
-I tensorflow/examples/label_image/main.cc:206] military uniform (653): 0.834306
-I tensorflow/examples/label_image/main.cc:206] mortarboard (668): 0.0218692
-I tensorflow/examples/label_image/main.cc:206] academic gown (401): 0.0103579
-I tensorflow/examples/label_image/main.cc:206] pickelhaube (716): 0.00800814
-I tensorflow/examples/label_image/main.cc:206] bulletproof vest (466): 0.00535088
-```
-In this case, we're using the default image of
-[Admiral Grace Hopper](https://en.wikipedia.org/wiki/Grace_Hopper), and you can
-see the network correctly identifies she's wearing a military uniform, with a high
-score of 0.8.
-
-
-<div style="width:45%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="https://www.tensorflow.org/images/grace_hopper.jpg">
-</div>
-
-Next, try it out on your own images by supplying the `--image=` argument, e.g.
-
-```bash
-bazel-bin/tensorflow/examples/label_image/label_image --image=my_image.png
-```
-
-If you look inside the [`tensorflow/examples/label_image/main.cc`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/label_image/main.cc)
-file, you can find out
-how it works. We hope this code will help you integrate TensorFlow into
-your own applications, so we will walk step by step through the main functions:
-
-The command line flags control where the files are loaded from, and properties of the input images.
-The model expects to get square 299x299 RGB images, so those are the `input_width`
-and `input_height` flags. We also need to scale the pixel values from integers that
-are between 0 and 255 to the floating point values that the graph operates on.
-We control the scaling with the `input_mean` and `input_std` flags: we first subtract
-`input_mean` from each pixel value, then divide it by `input_std`.
-
-These values probably look somewhat magical, but they are just defined by the
-original model author based on the images they wanted to use as input for
-training. If you have a graph that you've trained yourself, you'll just need
-to adjust the values to match whatever you used during your training process.
-
-You can see how they're applied to an image in the
-[`ReadTensorFromImageFile()`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/label_image/main.cc#L88)
-function.
-
-```C++
-// Given an image file name, read in the data, try to decode it as an image,
-// resize it to the requested size, and then scale the values as desired.
-Status ReadTensorFromImageFile(string file_name, const int input_height,
- const int input_width, const float input_mean,
- const float input_std,
- std::vector<Tensor>* out_tensors) {
- tensorflow::GraphDefBuilder b;
-```
-We start by creating a `GraphDefBuilder`, which is an object we can use to
-specify a model to run or load.
-
-```C++
- string input_name = "file_reader";
- string output_name = "normalized";
- tensorflow::Node* file_reader =
- tensorflow::ops::ReadFile(tensorflow::ops::Const(file_name, b.opts()),
- b.opts().WithName(input_name));
-```
-We then start creating nodes for the small model we want to run
-to load, resize, and scale the pixel values to get the result the main model
-expects as its input. The first node we create is just a `Const` op that holds a
-tensor with the file name of the image we want to load. That's then passed as the
-first input to the `ReadFile` op. You might notice we're passing `b.opts()` as the last
-argument to all the op creation functions. The argument ensures that the node is added to
-the model definition held in the `GraphDefBuilder`. We also name the `ReadFile`
-operator by making the `WithName()` call to `b.opts()`. This gives a name to the node,
-which isn't strictly necessary since an automatic name will be assigned if you don't
-do this, but it does make debugging a bit easier.
-
-```C++
- // Now try to figure out what kind of file it is and decode it.
- const int wanted_channels = 3;
- tensorflow::Node* image_reader;
- if (tensorflow::StringPiece(file_name).ends_with(".png")) {
- image_reader = tensorflow::ops::DecodePng(
- file_reader,
- b.opts().WithAttr("channels", wanted_channels).WithName("png_reader"));
- } else {
- // Assume if it's not a PNG then it must be a JPEG.
- image_reader = tensorflow::ops::DecodeJpeg(
- file_reader,
- b.opts().WithAttr("channels", wanted_channels).WithName("jpeg_reader"));
- }
- // Now cast the image data to float so we can do normal math on it.
- tensorflow::Node* float_caster = tensorflow::ops::Cast(
- image_reader, tensorflow::DT_FLOAT, b.opts().WithName("float_caster"));
- // The convention for image ops in TensorFlow is that all images are expected
- // to be in batches, so that they're four-dimensional arrays with indices of
- // [batch, height, width, channel]. Because we only have a single image, we
- // have to add a batch dimension of 1 to the start with ExpandDims().
- tensorflow::Node* dims_expander = tensorflow::ops::ExpandDims(
- float_caster, tensorflow::ops::Const(0, b.opts()), b.opts());
- // Bilinearly resize the image to fit the required dimensions.
- tensorflow::Node* resized = tensorflow::ops::ResizeBilinear(
- dims_expander, tensorflow::ops::Const({input_height, input_width},
- b.opts().WithName("size")),
- b.opts());
- // Subtract the mean and divide by the scale.
- tensorflow::ops::Div(
- tensorflow::ops::Sub(
- resized, tensorflow::ops::Const({input_mean}, b.opts()), b.opts()),
- tensorflow::ops::Const({input_std}, b.opts()),
- b.opts().WithName(output_name));
-```
-We then keep adding more nodes, to decode the file data as an image, to cast the
-integers into floating point values, to resize it, and then finally to run the
-subtraction and division operations on the pixel values.
-
-```C++
- // This runs the GraphDef network definition that we've just constructed, and
- // returns the results in the output tensor.
- tensorflow::GraphDef graph;
- TF_RETURN_IF_ERROR(b.ToGraphDef(&graph));
-```
-At the end of this we have
-a model definition stored in the `b` variable, which we turn into a full graph
-definition with the `ToGraphDef()` function.
-
-```C++
- std::unique_ptr<tensorflow::Session> session(
- tensorflow::NewSession(tensorflow::SessionOptions()));
- TF_RETURN_IF_ERROR(session->Create(graph));
- TF_RETURN_IF_ERROR(session->Run({}, {output_name}, {}, out_tensors));
- return Status::OK();
-```
-Then we create a `tf.Session`
-object, which is the interface to actually running the graph, and run it,
-specifying which node we want to get the output from, and where to put the
-output data.
-
-This gives us a vector of `Tensor` objects, which in this case we know will only be a
-single object long. You can think of a `Tensor` as a multi-dimensional array in this
-context, and it holds a 299 pixel high, 299 pixel wide, 3 channel image as float
-values. If you have your own image-processing framework in your product already, you
-should be able to use that instead, as long as you apply the same transformations
-before you feed images into the main graph.
-
-This is a simple example of creating a small TensorFlow graph dynamically in C++,
-but for the pre-trained Inception model we want to load a much larger definition from
-a file. You can see how we do that in the `LoadGraph()` function.
-
-```C++
-// Reads a model graph definition from disk, and creates a session object you
-// can use to run it.
-Status LoadGraph(string graph_file_name,
- std::unique_ptr<tensorflow::Session>* session) {
- tensorflow::GraphDef graph_def;
- Status load_graph_status =
- ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
- if (!load_graph_status.ok()) {
- return tensorflow::errors::NotFound("Failed to load compute graph at '",
- graph_file_name, "'");
- }
-```
-If you've looked through the image loading code, a lot of the terms should seem familiar. Rather than
-using a `GraphDefBuilder` to produce a `GraphDef` object, we load a protobuf file that
-directly contains the `GraphDef`.
-
-```C++
- session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
- Status session_create_status = (*session)->Create(graph_def);
- if (!session_create_status.ok()) {
- return session_create_status;
- }
- return Status::OK();
-}
-```
-Then we create a Session object from that `GraphDef` and
-pass it back to the caller so that they can run it at a later time.
-
-The `GetTopLabels()` function is a lot like the image loading, except that in this case
-we want to take the results of running the main graph, and turn it into a sorted list
-of the highest-scoring labels. Just like the image loader, it creates a
-`GraphDefBuilder`, adds a couple of nodes to it, and then runs the short graph to get a
-pair of output tensors. In this case they represent the sorted scores and index
-positions of the highest results.
-
-```C++
-// Analyzes the output of the Inception graph to retrieve the highest scores and
-// their positions in the tensor, which correspond to categories.
-Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
- Tensor* indices, Tensor* scores) {
- tensorflow::GraphDefBuilder b;
- string output_name = "top_k";
- tensorflow::ops::TopK(tensorflow::ops::Const(outputs[0], b.opts()),
- how_many_labels, b.opts().WithName(output_name));
- // This runs the GraphDef network definition that we've just constructed, and
- // returns the results in the output tensors.
- tensorflow::GraphDef graph;
- TF_RETURN_IF_ERROR(b.ToGraphDef(&graph));
- std::unique_ptr<tensorflow::Session> session(
- tensorflow::NewSession(tensorflow::SessionOptions()));
- TF_RETURN_IF_ERROR(session->Create(graph));
- // The TopK node returns two outputs, the scores and their original indices,
- // so we have to append :0 and :1 to specify them both.
- std::vector<Tensor> out_tensors;
- TF_RETURN_IF_ERROR(session->Run({}, {output_name + ":0", output_name + ":1"},
- {}, &out_tensors));
- *scores = out_tensors[0];
- *indices = out_tensors[1];
- return Status::OK();
-}
-```
-The `PrintTopLabels()` function takes those sorted results, and prints them out in a
-friendly way. The `CheckTopLabel()` function is very similar, but just makes sure that
-the top label is the one we expect, for debugging purposes.
-
-At the end, [`main()`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/label_image/main.cc#L252)
-ties together all of these calls.
-
-```C++
-int main(int argc, char* argv[]) {
- // We need to call this to set up global state for TensorFlow.
- tensorflow::port::InitMain(argv[0], &argc, &argv);
- Status s = tensorflow::ParseCommandLineFlags(&argc, argv);
- if (!s.ok()) {
- LOG(ERROR) << "Error parsing command line flags: " << s.ToString();
- return -1;
- }
-
- // First we load and initialize the model.
- std::unique_ptr<tensorflow::Session> session;
- string graph_path = tensorflow::io::JoinPath(FLAGS_root_dir, FLAGS_graph);
- Status load_graph_status = LoadGraph(graph_path, &session);
- if (!load_graph_status.ok()) {
- LOG(ERROR) << load_graph_status;
- return -1;
- }
-```
-We load the main graph.
-
-```C++
- // Get the image from disk as a float array of numbers, resized and normalized
- // to the specifications the main graph expects.
- std::vector<Tensor> resized_tensors;
- string image_path = tensorflow::io::JoinPath(FLAGS_root_dir, FLAGS_image);
- Status read_tensor_status = ReadTensorFromImageFile(
- image_path, FLAGS_input_height, FLAGS_input_width, FLAGS_input_mean,
- FLAGS_input_std, &resized_tensors);
- if (!read_tensor_status.ok()) {
- LOG(ERROR) << read_tensor_status;
- return -1;
- }
- const Tensor& resized_tensor = resized_tensors[0];
-```
-Load, resize, and process the input image.
-
-```C++
- // Actually run the image through the model.
- std::vector<Tensor> outputs;
- Status run_status = session->Run({{FLAGS_input_layer, resized_tensor}},
- {FLAGS_output_layer}, {}, &outputs);
- if (!run_status.ok()) {
- LOG(ERROR) << "Running model failed: " << run_status;
- return -1;
- }
-```
-Here we run the loaded graph with the image as an input.
-
-```C++
- // This is for automated testing to make sure we get the expected result with
- // the default settings. We know that label 866 (military uniform) should be
- // the top label for the Admiral Hopper image.
- if (FLAGS_self_test) {
- bool expected_matches;
- Status check_status = CheckTopLabel(outputs, 866, &expected_matches);
- if (!check_status.ok()) {
- LOG(ERROR) << "Running check failed: " << check_status;
- return -1;
- }
- if (!expected_matches) {
- LOG(ERROR) << "Self-test failed!";
- return -1;
- }
- }
-```
-For testing purposes we can check to make sure we get the output we expect here.
-
-```C++
- // Do something interesting with the results we've generated.
- Status print_status = PrintTopLabels(outputs, FLAGS_labels);
-```
-Finally we print the labels we found.
-
-```C++
- if (!print_status.ok()) {
- LOG(ERROR) << "Running print failed: " << print_status;
- return -1;
- }
-```
-
-The error handling here is using TensorFlow's `Status`
-object, which is very convenient because it lets you know whether any error has
-occurred with the `ok()` checker, and then can be printed out to give a readable error
-message.
-
-In this case we are demonstrating object recognition, but you should be able to
-use very similar code on other models you've found or trained yourself, across
-all sorts of domains. We hope this small example gives you some ideas on how to
-use TensorFlow within your own products.
-
-> **EXERCISE**: Transfer learning is the idea that, if you know how to solve a task well, you
-should be able to transfer some of that understanding to solving related
-problems. One way to perform transfer learning is to remove the final
-classification layer of the network and extract
-the [next-to-last layer of the CNN](https://arxiv.org/abs/1310.1531), in this case a 2048 dimensional vector.
-
-
-## Resources for Learning More
-
-To learn about neural networks in general, Michael Nielsen's
-[free online book](http://neuralnetworksanddeeplearning.com/chap1.html)
-is an excellent resource. For convolutional neural networks in particular,
-Chris Olah has some
-[nice blog posts](https://colah.github.io/posts/2014-07-Conv-Nets-Modular/),
-and Michael Nielsen's book has a
-[great chapter](http://neuralnetworksanddeeplearning.com/chap6.html)
-covering them.
-
-To find out more about implementing convolutional neural networks, you can jump
-to the TensorFlow @{$deep_cnn$deep convolutional networks tutorial},
-or start a bit more gently with our [Estimator MNIST tutorial](../estimators/cnn.md).
-Finally, if you want to get up to speed on research in this area, you can
-read the recent work of all the papers referenced in this tutorial.
-
diff --git a/tensorflow/docs_src/tutorials/keras/basic_classification.md b/tensorflow/docs_src/tutorials/keras/basic_classification.md
deleted file mode 100644
index e028af99b9..0000000000
--- a/tensorflow/docs_src/tutorials/keras/basic_classification.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# Basic Classification
-
-[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/tutorials/keras/basic_classification.ipynb)
diff --git a/tensorflow/docs_src/tutorials/keras/basic_regression.md b/tensorflow/docs_src/tutorials/keras/basic_regression.md
deleted file mode 100644
index 8721b7aca1..0000000000
--- a/tensorflow/docs_src/tutorials/keras/basic_regression.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# Basic Regression
-
-[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/tutorials/keras/basic_regression.ipynb)
diff --git a/tensorflow/docs_src/tutorials/keras/basic_text_classification.md b/tensorflow/docs_src/tutorials/keras/basic_text_classification.md
deleted file mode 100644
index c2a16bdd20..0000000000
--- a/tensorflow/docs_src/tutorials/keras/basic_text_classification.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# Basic Text Classification
-
-[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/tutorials/keras/basic_text_classification.ipynb)
diff --git a/tensorflow/docs_src/tutorials/keras/index.md b/tensorflow/docs_src/tutorials/keras/index.md
deleted file mode 100644
index 9d42281c8f..0000000000
--- a/tensorflow/docs_src/tutorials/keras/index.md
+++ /dev/null
@@ -1,22 +0,0 @@
-# Learn and use machine learning
-
-This notebook collection is inspired by the book
-*[Deep Learning with Python](https://books.google.com/books?id=Yo3CAQAACAAJ)*.
-These tutorials use `tf.keras`, TensorFlow's high-level Python API for building
-and training deep learning models. To learn more about using Keras with
-TensorFlow, see the [TensorFlow Keras Guide](../../guide/keras).
-
-Publisher's note: *Deep Learning with Python* introduces the field of deep
-learning using the Python language and the powerful Keras library. Written by
-Keras creator and Google AI researcher François Chollet, this book builds your
-understanding through intuitive explanations and practical examples.
-
-To learn about machine learning fundamentals and concepts, consider taking the
-[Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/).
-Additional TensorFlow and machine learning resources are listed in [next steps](../next_steps).
-
-1. [Basic classification](./basic_classification)
-2. [Text classification](./basic_text_classification)
-3. [Regression](./basic_regression)
-4. [Overfitting and underfitting](./overfit_and_underfit)
-5. [Save and restore models](./save_and_restore_models)
diff --git a/tensorflow/docs_src/tutorials/keras/overfit_and_underfit.md b/tensorflow/docs_src/tutorials/keras/overfit_and_underfit.md
deleted file mode 100644
index f07f3addd8..0000000000
--- a/tensorflow/docs_src/tutorials/keras/overfit_and_underfit.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# Overfitting and Underfitting
-
-[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/tutorials/keras/overfit_and_underfit.ipynb)
diff --git a/tensorflow/docs_src/tutorials/keras/save_and_restore_models.md b/tensorflow/docs_src/tutorials/keras/save_and_restore_models.md
deleted file mode 100644
index a799b379a0..0000000000
--- a/tensorflow/docs_src/tutorials/keras/save_and_restore_models.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# Save and restore Models
-
-[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/tutorials/keras/save_and_restore_models.ipynb)
diff --git a/tensorflow/docs_src/tutorials/next_steps.md b/tensorflow/docs_src/tutorials/next_steps.md
deleted file mode 100644
index 01c9f7204a..0000000000
--- a/tensorflow/docs_src/tutorials/next_steps.md
+++ /dev/null
@@ -1,36 +0,0 @@
-# Next steps
-
-## Learn more about TensorFlow
-
-* The [TensorFlow Guide](/guide) includes usage guides for the
- high-level APIs, as well as advanced TensorFlow operations.
-* [Premade Estimators](/guide/premade_estimators) are designed to
- get results out of the box. Use TensorFlow without building your own models.
-* [TensorFlow.js](https://js.tensorflow.org/) allows web developers to train and
- deploy ML models in the browser and using Node.js.
-* [TFLite](/mobile/tflite) allows mobile developers to do inference efficiently
- on mobile devices.
-* [TensorFlow Serving](/serving) is an open-source project that can put
- TensorFlow models in production quickly.
-* The [ecosystem](/ecosystem) contains more projects, including
- [Magenta](https://magenta.tensorflow.org/), [TFX](/tfx),
- [Swift for TensorFlow](https://github.com/tensorflow/swift), and more.
-
-## Learn more about machine learning
-
-Recommended resources include:
-
-* [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/),
- a course from Google that introduces machine learning concepts.
-* [CS 20: Tensorflow for Deep Learning Research](http://web.stanford.edu/class/cs20si/),
- notes from an intro course from Stanford.
-* [CS231n: Convolutional Neural Networks for Visual Recognition](http://cs231n.stanford.edu/),
- a course that teaches how convolutional networks work.
-* [Machine Learning Recipes](https://www.youtube.com/watch?v=cKxRvEZd3Mw&list=PLOU2XLYxmsIIuiBfYad6rFYQU_jL2ryal),
- a video series that introduces basic machine learning concepts with few prerequisites.
-* [Deep Learning with Python](https://www.manning.com/books/deep-learning-with-python),
-  a book by François Chollet about the Keras API, as well as an excellent
-  hands-on introduction to deep learning.
-* [Hands-on Machine Learning with Scikit-Learn and TensorFlow](https://github.com/ageron/handson-ml),
-  a book by Aurélien Géron that is a clear getting-started guide to data
-  science and deep learning.
-* [Deep Learning](https://www.deeplearningbook.org/), a book by Ian Goodfellow et al.
-  that provides a technical deep dive into machine learning.
diff --git a/tensorflow/docs_src/tutorials/non-ml/mandelbrot.md b/tensorflow/docs_src/tutorials/non-ml/mandelbrot.md
deleted file mode 100644
index 1c0a548129..0000000000
--- a/tensorflow/docs_src/tutorials/non-ml/mandelbrot.md
+++ /dev/null
@@ -1,116 +0,0 @@
-# Mandelbrot Set
-
-Visualizing the [Mandelbrot set](https://en.wikipedia.org/wiki/Mandelbrot_set)
-doesn't have anything to do with machine learning, but it makes for a fun
-example of how one can use TensorFlow for general mathematics. This is
-actually a pretty naive implementation of the visualization, but it makes the
-point. (We may end up providing a more elaborate implementation down the line
-to produce more truly beautiful images.)
-
-
-## Basic Setup
-
-We'll need a few imports to get started.
-
-```python
-# Import libraries for simulation
-import tensorflow as tf
-import numpy as np
-
-# Imports for visualization
-import PIL.Image
-from io import BytesIO
-from IPython.display import Image, display
-```
-
-Now we'll define a function to actually display the image once we have
-iteration counts.
-
-```python
-def DisplayFractal(a, fmt='jpeg'):
- """Display an array of iteration counts as a
- colorful picture of a fractal."""
- a_cyclic = (6.28*a/20.0).reshape(list(a.shape)+[1])
- img = np.concatenate([10+20*np.cos(a_cyclic),
- 30+50*np.sin(a_cyclic),
- 155-80*np.cos(a_cyclic)], 2)
- img[a==a.max()] = 0
- a = img
- a = np.uint8(np.clip(a, 0, 255))
- f = BytesIO()
- PIL.Image.fromarray(a).save(f, fmt)
- display(Image(data=f.getvalue()))
-```
-
-## Session and Variable Initialization
-
-For playing around like this, we often use an interactive session, but a regular
-session would work as well.
-
-```python
-sess = tf.InteractiveSession()
-```
-
-It's handy that we can freely mix NumPy and TensorFlow.
-
-```python
-# Use NumPy to create a 2D array of complex numbers
-
-Y, X = np.mgrid[-1.3:1.3:0.005, -2:1:0.005]
-Z = X+1j*Y
-```
-
-Now we define and initialize TensorFlow tensors.
-
-```python
-xs = tf.constant(Z.astype(np.complex64))
-zs = tf.Variable(xs)
-ns = tf.Variable(tf.zeros_like(xs, tf.float32))
-```
-
-TensorFlow requires that you explicitly initialize variables before using them.
-
-```python
-tf.global_variables_initializer().run()
-```
-
-## Defining and Running the Computation
-
-Now we specify more of the computation...
-
-```python
-# Compute the new values of z: z^2 + x
-zs_ = zs*zs + xs
-
-# Have we diverged with this new value?
-not_diverged = tf.abs(zs_) < 4
-
-# Operation to update the zs and the iteration count.
-#
-# Note: We keep computing zs after they diverge! This
-# is very wasteful! There are better, if a little
-# less simple, ways to do this.
-#
-step = tf.group(
- zs.assign(zs_),
- ns.assign_add(tf.cast(not_diverged, tf.float32))
- )
-```
-
-... and run it for a couple hundred steps.
-
-```python
-for i in range(200): step.run()
-```
-
-Let's see what we've got.
-
-```python
-DisplayFractal(ns.eval())
-```
-
-![jpeg](https://www.tensorflow.org/images/mandelbrot_output.jpg)
-
-Not bad!
-
-
diff --git a/tensorflow/docs_src/tutorials/non-ml/pdes.md b/tensorflow/docs_src/tutorials/non-ml/pdes.md
deleted file mode 100644
index b5a0fa834a..0000000000
--- a/tensorflow/docs_src/tutorials/non-ml/pdes.md
+++ /dev/null
@@ -1,140 +0,0 @@
-# Partial Differential Equations
-
-TensorFlow isn't just for machine learning. Here we give a (somewhat
-pedestrian) example of using TensorFlow for simulating the behavior of a
-[partial differential equation](
-https://en.wikipedia.org/wiki/Partial_differential_equation).
-We'll simulate the surface of a square pond as a few raindrops land on it.
-
-
-## Basic Setup
-
-A few imports we'll need.
-
-```python
-#Import libraries for simulation
-import tensorflow as tf
-import numpy as np
-
-#Imports for visualization
-import PIL.Image
-from io import BytesIO
-from IPython.display import clear_output, Image, display
-```
-
-A function for displaying the state of the pond's surface as an image.
-
-```python
-def DisplayArray(a, fmt='jpeg', rng=[0,1]):
- """Display an array as a picture."""
- a = (a - rng[0])/float(rng[1] - rng[0])*255
- a = np.uint8(np.clip(a, 0, 255))
- f = BytesIO()
- PIL.Image.fromarray(a).save(f, fmt)
- clear_output(wait = True)
- display(Image(data=f.getvalue()))
-```
-
-Here we start an interactive TensorFlow session for convenience in playing
-around. A regular session would work as well if we were doing this in an
-executable .py file.
-
-```python
-sess = tf.InteractiveSession()
-```
-
-## Computational Convenience Functions
-
-
-```python
-def make_kernel(a):
- """Transform a 2D array into a convolution kernel"""
- a = np.asarray(a)
- a = a.reshape(list(a.shape) + [1,1])
- return tf.constant(a, dtype=1)
-
-def simple_conv(x, k):
- """A simplified 2D convolution operation"""
- x = tf.expand_dims(tf.expand_dims(x, 0), -1)
- y = tf.nn.depthwise_conv2d(x, k, [1, 1, 1, 1], padding='SAME')
- return y[0, :, :, 0]
-
-def laplace(x):
- """Compute the 2D laplacian of an array"""
- laplace_k = make_kernel([[0.5, 1.0, 0.5],
- [1.0, -6., 1.0],
- [0.5, 1.0, 0.5]])
- return simple_conv(x, laplace_k)
-```
-
-## Define the PDE
-
-Our pond is a perfect 500 x 500 square, as is the case for most ponds found in
-nature.
-
-```python
-N = 500
-```
-
-Here we create our pond and hit it with some rain drops.
-
-```python
-# Initial Conditions -- some rain drops hit a pond
-
-# Set everything to zero
-u_init = np.zeros([N, N], dtype=np.float32)
-ut_init = np.zeros([N, N], dtype=np.float32)
-
-# Some rain drops hit a pond at random points
-for n in range(40):
- a,b = np.random.randint(0, N, 2)
- u_init[a,b] = np.random.uniform()
-
-DisplayArray(u_init, rng=[-0.1, 0.1])
-```
-
-![jpeg](https://www.tensorflow.org/images/pde_output_1.jpg)
-
-
-Now let's specify the details of the differential equation.
-
-
-```python
-# Parameters:
-# eps -- time resolution
-# damping -- wave damping
-eps = tf.placeholder(tf.float32, shape=())
-damping = tf.placeholder(tf.float32, shape=())
-
-# Create variables for simulation state
-U = tf.Variable(u_init)
-Ut = tf.Variable(ut_init)
-
-# Discretized PDE update rules
-U_ = U + eps * Ut
-Ut_ = Ut + eps * (laplace(U) - damping * Ut)
-
-# Operation to update the state
-step = tf.group(
- U.assign(U_),
- Ut.assign(Ut_))
-```
-
-## Run The Simulation
-
-This is where it gets fun -- running time forward with a simple for loop.
-
-```python
-# Initialize state to initial conditions
-tf.global_variables_initializer().run()
-
-# Run 1000 steps of PDE
-for i in range(1000):
- # Step simulation
- step.run({eps: 0.03, damping: 0.04})
- DisplayArray(U.eval(), rng=[-0.1, 0.1])
-```
-
-![jpeg](https://www.tensorflow.org/images/pde_output_2.jpg)
-
-Look! Ripples!
diff --git a/tensorflow/docs_src/tutorials/representation/kernel_methods.md b/tensorflow/docs_src/tutorials/representation/kernel_methods.md
deleted file mode 100644
index 71e87f4d3e..0000000000
--- a/tensorflow/docs_src/tutorials/representation/kernel_methods.md
+++ /dev/null
@@ -1,303 +0,0 @@
-# Improving Linear Models Using Explicit Kernel Methods
-
-Note: This document uses a deprecated version of `tf.estimator`,
-`tf.contrib.learn.Estimator`, which has a different interface. It also uses
-other `contrib` methods whose @{$version_compat#not_covered$API may not be stable}.
-
-In this tutorial, we demonstrate how combining (explicit) kernel methods with
-linear models can drastically improve the latter's prediction quality
-without significantly increasing training and inference times. Unlike dual
-kernel methods, explicit (primal) kernel methods scale well with the size of the
-training dataset both in terms of training/inference times and in terms of
-memory requirements.
-
-**Intended audience:** Even though we provide a high-level overview of concepts
-related to explicit kernel methods, this tutorial primarily targets readers who
-already have at least basic knowledge of kernel methods and Support Vector
-Machines (SVMs). If you are new to kernel methods, refer to either of the
-following sources for an introduction:
-
-* If you have a strong mathematical background:
-[Kernel Methods in Machine Learning](https://arxiv.org/pdf/math/0701907.pdf)
-* [Kernel method wikipedia page](https://en.wikipedia.org/wiki/Kernel_method)
-
-Currently, TensorFlow supports explicit kernel mappings for dense features only;
-support for sparse features will come in a later release.
-
-This tutorial uses [tf.contrib.learn](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn)
-(TensorFlow's high-level Machine Learning API) Estimators for our ML models.
-If you are not familiar with this API, the [Estimator guide](../../guide/estimators.md)
-is a good place to start. We will use the MNIST dataset. The tutorial consists
-of the following steps:
-
-* Load and prepare MNIST data for classification.
-* Construct a simple linear model, train it, and evaluate it on the eval data.
-* Replace the linear model with a kernelized linear model, re-train, and
-re-evaluate.
-
-## Load and prepare MNIST data for classification
-Run the following utility command to load the MNIST dataset:
-
-```python
-data = tf.contrib.learn.datasets.mnist.load_mnist()
-```
-The preceding method loads the entire MNIST dataset (containing 70K samples) and
-splits it into train, validation, and test data with 55K, 5K, and 10K samples
-respectively. Each split contains one numpy array for images (with shape
-[sample_size, 784]) and one for labels (with shape [sample_size, 1]). In this
-tutorial, we only use the train and validation splits to train and evaluate our
-models respectively.
-
-In order to feed data to a `tf.contrib.learn` `Estimator`, it is helpful to
-convert it to Tensors. For this, we will use an *input function*, which adds
-Ops to the
-TensorFlow graph that, when executed, create mini-batches of Tensors to be used
-downstream. For more background on input functions, check
-@{$premade_estimators#create_input_functions$this section on input functions}.
-In this example, we will use the `tf.train.shuffle_batch` Op which, besides
-converting numpy arrays to Tensors, allows us to specify the `batch_size` and
-whether to randomize the input every time the `input_fn` Ops are executed
-(randomization typically expedites convergence during training). The full code
-for loading and preparing the data is shown in the snippet below. In this
-example, we use mini-batches of size 256 for training and the entire sample
-(5K entries) for evaluation. Feel free to experiment with different batch sizes.
-
-```python
-import numpy as np
-import tensorflow as tf
-
-def get_input_fn(dataset_split, batch_size, capacity=10000, min_after_dequeue=3000):
-
- def _input_fn():
- images_batch, labels_batch = tf.train.shuffle_batch(
- tensors=[dataset_split.images, dataset_split.labels.astype(np.int32)],
- batch_size=batch_size,
- capacity=capacity,
- min_after_dequeue=min_after_dequeue,
- enqueue_many=True,
- num_threads=4)
- features_map = {'images': images_batch}
- return features_map, labels_batch
-
- return _input_fn
-
-data = tf.contrib.learn.datasets.mnist.load_mnist()
-
-train_input_fn = get_input_fn(data.train, batch_size=256)
-eval_input_fn = get_input_fn(data.validation, batch_size=5000)
-
-```
-
-## Training a simple linear model
-We can now train a linear model over the MNIST dataset. We will use the
-`tf.contrib.learn.LinearClassifier` estimator with 10 classes representing the
-10 digits. The input features form a 784-dimensional dense vector which can
-be specified as follows:
-
-```python
-image_column = tf.contrib.layers.real_valued_column('images', dimension=784)
-```
-
-The full code for constructing, training and evaluating a LinearClassifier
-estimator is as follows:
-
-```python
-import time
-
-# Specify the feature(s) to be used by the estimator.
-image_column = tf.contrib.layers.real_valued_column('images', dimension=784)
-estimator = tf.contrib.learn.LinearClassifier(feature_columns=[image_column], n_classes=10)
-
-# Train.
-start = time.time()
-estimator.fit(input_fn=train_input_fn, steps=2000)
-end = time.time()
-print('Elapsed time: {} seconds'.format(end - start))
-
-# Evaluate and report metrics.
-eval_metrics = estimator.evaluate(input_fn=eval_input_fn, steps=1)
-print(eval_metrics)
-```
-The following table summarizes the results on the eval data.
-
-metric | value
-:------------ | :------------
-loss | 0.25 to 0.30
-accuracy | 92.5%
-training time | ~25 seconds on my machine
-
-Note: Metrics will vary from run to run, depending on factors such as random
-initialization and data shuffling.
-
-In addition to experimenting with the (training) batch size and the number of
-training steps, there are a couple of other parameters that can be tuned as well.
-For instance, you can change the optimization method used to minimize the loss
-by explicitly selecting another optimizer from the collection of
-[available optimizers](https://www.tensorflow.org/code/tensorflow/python/training).
-As an example, the following code constructs a LinearClassifier estimator that
-uses the Follow-The-Regularized-Leader (FTRL) optimization strategy with a
-specific learning rate and L2-regularization.
-
-
-```python
-optimizer = tf.train.FtrlOptimizer(learning_rate=5.0, l2_regularization_strength=1.0)
-estimator = tf.contrib.learn.LinearClassifier(
- feature_columns=[image_column], n_classes=10, optimizer=optimizer)
-```
-
-Regardless of the values of the parameters, the maximum accuracy a linear model
-can achieve on this dataset tops out at around **93%**.
-
-## Using explicit kernel mappings with the linear model
-The relatively high error (~7%) of the linear model over MNIST indicates that
-the input data is not linearly separable. We will use explicit kernel mappings
-to reduce the classification error.
-
-**Intuition:** The high-level idea is to use a non-linear map to transform the
-input space to another feature space (of possibly higher dimension) where the
-(transformed) features are (almost) linearly separable and then apply a linear
-model on the mapped features. This is shown in the following figure:
-
-<div style="text-align:center">
-<img src="https://www.tensorflow.org/versions/master/images/kernel_mapping.png" />
-</div>
-
-
-### Technical details
-In this example we will use **Random Fourier Features**, introduced in the
-["Random Features for Large-Scale Kernel Machines"](https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf)
-paper by Rahimi and Recht, to map the input data. Random Fourier Features map a
-vector \\(\mathbf{x} \in \mathbb{R}^d\\) to \\(\mathbf{x'} \in \mathbb{R}^D\\)
-via the following mapping:
-
-$$
-RFFM(\cdot): \mathbb{R}^d \to \mathbb{R}^D, \quad
-RFFM(\mathbf{x}) = \cos(\mathbf{\Omega} \cdot \mathbf{x}+ \mathbf{b})
-$$
-
-where \\(\mathbf{\Omega} \in \mathbb{R}^{D \times d}\\),
-\\(\mathbf{x} \in \mathbb{R}^d,\\) \\(\mathbf{b} \in \mathbb{R}^D\\) and the
-cosine is applied element-wise.
-
-In this example, the entries of \\(\mathbf{\Omega}\\) and \\(\mathbf{b}\\) are
-sampled from distributions such that the mapping satisfies the following
-property:
-
-$$
-RFFM(\mathbf{x})^T \cdot RFFM(\mathbf{y}) \approx
-e^{-\frac{\|\mathbf{x} - \mathbf{y}\|^2}{2 \sigma^2}}
-$$
-
-The right-hand-side quantity of the expression above is known as the RBF (or
-Gaussian) kernel function. This function is one of the most-widely used kernel
-functions in Machine Learning and implicitly measures similarity in a different,
-much higher dimensional space than the original one. See
-[Radial basis function kernel](https://en.wikipedia.org/wiki/Radial_basis_function_kernel)
-for more details.
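-
-To make the approximation concrete, here is a minimal NumPy sketch (not the
-`tf.contrib` implementation; a `sqrt(2/D)` normalization is folded in so that
-the inner product matches the kernel value):
-
-```python
-import numpy as np
-
-d, D, sigma = 4, 2000, 5.0
-rng = np.random.RandomState(0)
-omega = rng.normal(scale=1.0 / sigma, size=(D, d))  # entries ~ N(0, 1/sigma^2)
-b = rng.uniform(0, 2 * np.pi, size=D)
-
-def rffm(x):
-  return np.sqrt(2.0 / D) * np.cos(omega.dot(x) + b)
-
-x, y = rng.randn(d), rng.randn(d)
-print(rffm(x).dot(rffm(y)))                                # approximate kernel
-print(np.exp(-np.linalg.norm(x - y)**2 / (2 * sigma**2)))  # exact RBF kernel
-```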
-
-### Kernel classifier
-`tf.contrib.kernel_methods.KernelLinearClassifier` is a pre-packaged
-`tf.contrib.learn` estimator that combines the power of explicit kernel mappings
-with linear models. Its constructor is almost identical to that of the
-LinearClassifier estimator with the additional option to specify a list of
-explicit kernel mappings to be applied to each feature the classifier uses. The
-following code snippet demonstrates how to replace LinearClassifier with
-KernelLinearClassifier.
-
-
-```python
-# Specify the feature(s) to be used by the estimator. This is identical to the
-# code used for the LinearClassifier.
-image_column = tf.contrib.layers.real_valued_column('images', dimension=784)
-optimizer = tf.train.FtrlOptimizer(
- learning_rate=50.0, l2_regularization_strength=0.001)
-
-
-kernel_mapper = tf.contrib.kernel_methods.RandomFourierFeatureMapper(
- input_dim=784, output_dim=2000, stddev=5.0, name='rffm')
-kernel_mappers = {image_column: [kernel_mapper]}
-estimator = tf.contrib.kernel_methods.KernelLinearClassifier(
- n_classes=10, optimizer=optimizer, kernel_mappers=kernel_mappers)
-
-# Train.
-start = time.time()
-estimator.fit(input_fn=train_input_fn, steps=2000)
-end = time.time()
-print('Elapsed time: {} seconds'.format(end - start))
-
-# Evaluate and report metrics.
-eval_metrics = estimator.evaluate(input_fn=eval_input_fn, steps=1)
-print(eval_metrics)
-```
-The only additional parameter passed to `KernelLinearClassifier` is a dictionary
-mapping feature columns to lists of kernel mappings to be applied to the
-corresponding feature column. The following lines instruct the classifier to
-first map the initial 784-dimensional images to 2000-dimensional vectors using
-random Fourier features and then learn a linear model on the transformed
-vectors:
-
-```python
-kernel_mapper = tf.contrib.kernel_methods.RandomFourierFeatureMapper(
- input_dim=784, output_dim=2000, stddev=5.0, name='rffm')
-kernel_mappers = {image_column: [kernel_mapper]}
-estimator = tf.contrib.kernel_methods.KernelLinearClassifier(
- n_classes=10, optimizer=optimizer, kernel_mappers=kernel_mappers)
-```
-Notice the `stddev` parameter. This is the standard deviation (\\(\sigma\\)) of
-the approximated RBF kernel and controls the similarity measure used in
-classification. `stddev` is typically determined via hyperparameter tuning.
-
-The results of running the preceding code are summarized in the following table.
-We can further increase the accuracy by increasing the output dimension of the
-mapping and tuning the standard deviation.
-
-metric | value
-:------------ | :------------
-loss | 0.10
-accuracy | 97%
-training time | ~35 seconds on my machine
-
-
-### stddev
-The classification quality is very sensitive to the value of `stddev`. The
-following table shows the accuracy of the classifier on the eval data for
-different values of `stddev`. The optimal value is `stddev=5.0`. Notice how
-`stddev` values that are too small or too high can dramatically decrease the
-accuracy of the classification.
-
-stddev | eval accuracy
-:----- | :------------
-1.0 | 0.1362
-2.0 | 0.4764
-4.0 | 0.9654
-5.0 | 0.9766
-8.0 | 0.9714
-16.0 | 0.8878
-
-### Output dimension
-Intuitively, the larger the output dimension of the mapping, the more closely
-the inner product of two mapped vectors approximates the kernel, which typically
-translates to better classification accuracy. Another way to think about this is
-that the output dimension equals the number of weights of the linear model; the
-larger this dimension, the larger the "degrees of freedom" of the model.
-However, after a certain threshold, higher output dimensions increase the
-accuracy by very little, while making training take more time. This is shown in
-the following two figures, which depict the eval accuracy as a function of the
-output dimension and the training time, respectively.
-
-![image](https://www.tensorflow.org/versions/master/images/acc_vs_outdim.png)
-![image](https://www.tensorflow.org/versions/master/images/acc-vs-trn_time.png)
-
-
-## Summary
-Explicit kernel mappings combine the predictive power of nonlinear models with
-the scalability of linear models. Unlike traditional dual kernel methods,
-explicit kernel methods can scale to millions or hundreds of millions of
-samples. When using explicit kernel mappings, consider the following tips:
-
-* Random Fourier Features can be particularly effective for datasets with dense
-features.
-* The parameters of the kernel mapping are often data-dependent. Model quality
-can be very sensitive to these parameters. Use hyperparameter tuning to find the
-optimal values.
-* If you have multiple numerical features, concatenate them into a single
-multi-dimensional feature and apply the kernel mapping to the concatenated
-vector.
diff --git a/tensorflow/docs_src/tutorials/representation/linear.md b/tensorflow/docs_src/tutorials/representation/linear.md
deleted file mode 100644
index 014409c617..0000000000
--- a/tensorflow/docs_src/tutorials/representation/linear.md
+++ /dev/null
@@ -1,239 +0,0 @@
-# Large-scale Linear Models with TensorFlow
-
-`tf.estimator` provides (among other things) a rich set of tools for
-working with linear models in TensorFlow. This document provides an overview of
-those tools. It explains:
-
- * What a linear model is.
- * Why you might want to use a linear model.
- * How Estimators make it easy to build linear models in TensorFlow.
- * How you can use Estimators to combine linear models with deep learning
-   to get the advantages of both.
-
-Read this overview to decide whether the Estimator's linear model tools might
-be useful to you. Then work through the
-[Estimator wide and deep learning tutorial](https://github.com/tensorflow/models/tree/master/official/wide_deep)
-to give it a try. This overview uses code samples from the tutorial, but the
-tutorial walks through the code in greater detail.
-
-To understand this overview it will help to have some familiarity
-with basic machine learning concepts, and also with
-@{$premade_estimators$Estimators}.
-
-[TOC]
-
-## What is a linear model?
-
-A **linear model** uses a single weighted sum of features to make a prediction.
-For example, if you have [data](https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.names)
-on age, years of education, and weekly hours of
-work for a population, a model can learn weights for each of those numbers so that
-their weighted sum estimates a person's salary. You can also use linear models
-for classification.
-
-Some linear models transform the weighted sum into a more convenient form. For
-example, [**logistic regression**](https://developers.google.com/machine-learning/glossary/#logistic_regression) plugs the weighted sum into the logistic
-function to turn the output into a value between 0 and 1. But you still just
-have one weight for each input feature.
-
-## Why would you want to use a linear model?
-
-Why would you want to use so simple a model when recent research has
-demonstrated the power of more complex neural networks with many layers?
-
-Linear models:
-
- * train quickly, compared to deep neural nets.
- * can work well on very large feature sets.
- * can be trained with algorithms that don't require a lot of fiddling
- with learning rates, etc.
- * can be interpreted and debugged more easily than neural nets.
- You can examine the weights assigned to each feature to figure out what's
- having the biggest impact on a prediction.
- * provide an excellent starting point for learning about machine learning.
- * are widely used in industry.
-
-## How do Estimators help you build linear models?
-
-You can build a linear model from scratch in TensorFlow without the help of a
-special API. But Estimators provides some tools that make it easier to build
-effective large-scale linear models.
-
-### Feature columns and transformations
-
-Much of the work of designing a linear model consists of transforming raw data
-into suitable input features. TensorFlow uses the `FeatureColumn` abstraction to
-enable these transformations.
-
-A `FeatureColumn` represents a single feature in your data. A `FeatureColumn`
-may represent a quantity like 'height', or it may represent a category like
-'eye_color' where the value is drawn from a set of discrete possibilities like
-{'blue', 'brown', 'green'}.
-
-In the case of both *continuous features* like 'height' and *categorical
-features* like 'eye_color', a single value in the data might get transformed
-into a sequence of numbers before it is input into the model. The
-`FeatureColumn` abstraction lets you manipulate the feature as a single
-semantic unit in spite of this fact. You can specify transformations and
-select features to include without dealing with specific indices in the
-tensors you feed into the model.
-
-#### Sparse columns
-
-Categorical features in linear models are typically translated into a sparse
-vector in which each possible value has a corresponding index or id. For
-example, if there are only three possible eye colors you can represent
-'eye_color' as a length 3 vector: 'brown' would become [1, 0, 0], 'blue' would
-become [0, 1, 0] and 'green' would become [0, 0, 1]. These vectors are called
-"sparse" because they may be very long, with many zeros, when the set of
-possible values is very large (such as all English words).
-
-While you don't need to use categorical columns to use the linear model tools
-provided by Estimators, one of the strengths of linear models is their ability
-to deal with large sparse vectors. Sparse features are a primary use case for
-the linear model tools provided by Estimators.
-
-##### Encoding sparse columns
-
-`FeatureColumn` handles the conversion of categorical values into vectors
-automatically, with code like this:
-
-```python
-eye_color = tf.feature_column.categorical_column_with_vocabulary_list(
- "eye_color", vocabulary_list=["blue", "brown", "green"])
-```
-
-where `eye_color` is the name of a column in your source data.
-
-You can also generate `FeatureColumn`s for categorical features for which you
-don't know all possible values. For this case you would use
-`categorical_column_with_hash_bucket()`, which uses a hash function to assign
-indices to feature values.
-
-```python
-education = tf.feature_column.categorical_column_with_hash_bucket(
- "education", hash_bucket_size=1000)
-```
-
-##### Feature Crosses
-
-Because linear models assign independent weights to separate features, they
-can't learn the relative importance of specific combinations of feature
-values. If you have a feature 'favorite_sport' and a feature 'home_city' and
-you're trying to predict whether a person likes to wear red, your linear model
-won't be able to learn that baseball fans from St. Louis especially like to
-wear red.
-
-You can get around this limitation by creating a new feature
-'favorite_sport_x_home_city'. The value of this feature for a given person is
-just the concatenation of the values of the two source features:
-'baseball_x_stlouis', for example. This sort of combination feature is called
-a *feature cross*.
-
-The `crossed_column()` method makes it easy to set up feature crosses:
-
-```python
-sport_x_city = tf.feature_column.crossed_column(
- ["sport", "city"], hash_bucket_size=int(1e4))
-```
-
-#### Continuous columns
-
-You can specify a continuous feature like so:
-
-```python
-age = tf.feature_column.numeric_column("age")
-```
-
-Although, as a single real number, a continuous feature can often be input
-directly into the model, TensorFlow offers useful transformations for this sort
-of column as well.
-
-##### Bucketization
-
-*Bucketization* turns a continuous column into a categorical column. This
-transformation lets you use continuous features in feature crosses, or learn
-cases where specific value ranges have particular importance.
-
-Bucketization divides the range of possible values into subranges called
-buckets:
-
-```python
-age_buckets = tf.feature_column.bucketized_column(
- age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
-```
-
-The bucket into which a value falls becomes the categorical label for
-that value.
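-
-Once bucketized, the column behaves like a categorical column, so it can also
-participate in a feature cross. For example (a sketch; `"education"` refers to
-the raw feature used earlier):
-
-```python
-age_buckets_x_education = tf.feature_column.crossed_column(
-    [age_buckets, "education"], hash_bucket_size=int(1e4))
-```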
-
-#### Input function
-
-`FeatureColumn`s provide a specification for the input data for your model,
-indicating how to represent and transform the data. But they do not provide
-the data itself. You provide the data through an input function.
-
-The input function must return a dictionary of tensors. Each key corresponds to
-the name of a `FeatureColumn`. Each key's value is a tensor containing the
-values of that feature for all data instances. See
-@{$premade_estimators#input_fn} for a
-more comprehensive look at input functions, and `input_fn` in the
-[wide and deep learning tutorial](https://github.com/tensorflow/models/tree/master/official/wide_deep)
-for an example implementation of an input function.
-
-The input function is passed to the `train()` and `evaluate()` calls that
-initiate training and testing, as described in the next section.
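-
-As an illustrative sketch (toy in-memory data; a real input function would
-typically read from files), an input function serving the `age` and
-`education` columns above might look like:
-
-```python
-def input_fn_train():
-  features = {
-      "age": tf.constant([23, 35, 58]),
-      "education": tf.constant(["Bachelors", "HS-grad", "Masters"]),
-  }
-  labels = tf.constant([0, 1, 1])
-  return features, labels
-```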
-
-### Linear estimators
-
-TensorFlow estimator classes provide a unified training and evaluation harness
-for regression and classification models. They take care of the details of the
-training and evaluation loops and allow the user to focus on model inputs and
-architecture.
-
-To build a linear estimator, you can use either the
-`tf.estimator.LinearClassifier` estimator or the
-`tf.estimator.LinearRegressor` estimator, for classification and
-regression respectively.
-
-As with all TensorFlow estimators, to run the estimator you just:
-
- 1. Instantiate the estimator class. For the two linear estimator classes,
- you pass a list of `FeatureColumn`s to the constructor.
- 2. Call the estimator's `train()` method to train it.
- 3. Call the estimator's `evaluate()` method to see how it does.
-
-For example:
-
-```python
-e = tf.estimator.LinearClassifier(
- feature_columns=[
- native_country, education, occupation, workclass, marital_status,
- race, age_buckets, education_x_occupation,
- age_buckets_x_race_x_occupation],
- model_dir=YOUR_MODEL_DIRECTORY)
-e.train(input_fn=input_fn_train, steps=200)
-# Evaluate for one step (one pass through the test data).
-results = e.evaluate(input_fn=input_fn_test)
-
-# Print the stats for the evaluation.
-for key in sorted(results):
- print("%s: %s" % (key, results[key]))
-```
-
-### Wide and deep learning
-
-The `tf.estimator` module also provides an estimator class that lets you jointly
-train a linear model and a deep neural network. This novel approach combines the
-ability of linear models to "memorize" key features with the generalization
-ability of neural nets. Use `tf.estimator.DNNLinearCombinedClassifier` to
-create this sort of "wide and deep" model:
-
-```python
-e = tf.estimator.DNNLinearCombinedClassifier(
- model_dir=YOUR_MODEL_DIR,
- linear_feature_columns=wide_columns,
- dnn_feature_columns=deep_columns,
- dnn_hidden_units=[100, 50])
-```
-For more information, see the
-[wide and deep learning tutorial](https://github.com/tensorflow/models/tree/master/official/wide_deep).
diff --git a/tensorflow/docs_src/tutorials/representation/word2vec.md b/tensorflow/docs_src/tutorials/representation/word2vec.md
deleted file mode 100644
index 7964650e19..0000000000
--- a/tensorflow/docs_src/tutorials/representation/word2vec.md
+++ /dev/null
@@ -1,405 +0,0 @@
-# Vector Representations of Words
-
-In this tutorial we look at the word2vec model by
-[Mikolov et al.](https://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf)
-This model is used for learning vector representations of words, called "word
-embeddings".
-
-## Highlights
-
-This tutorial is meant to highlight the interesting, substantive parts of
-building a word2vec model in TensorFlow.
-
-* We start by giving the motivation for why we would want to
-represent words as vectors.
-* We look at the intuition behind the model and how it is trained
-(with a splash of math for good measure).
-* We also show a simple implementation of the model in TensorFlow.
-* Finally, we look at ways to make the naive version scale better.
-
-We walk through the code later during the tutorial, but if you'd prefer to dive
-straight in, feel free to look at the minimalistic implementation in
-[tensorflow/examples/tutorials/word2vec/word2vec_basic.py](https://www.tensorflow.org/code/tensorflow/examples/tutorials/word2vec/word2vec_basic.py).
-This basic example contains the code needed to download some data, train on it a
-bit and visualize the result. Once you get comfortable with reading and running
-the basic version, you can graduate to
-[models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py)
-which is a more serious implementation that showcases some more advanced
-TensorFlow principles about how to efficiently use threads to move data into a
-text model, how to checkpoint during training, etc.
-
-But first, let's look at why we would want to learn word embeddings in the first
-place. Feel free to skip this section if you're an Embedding Pro and you'd just
-like to get your hands dirty with the details.
-
-## Motivation: Why Learn Word Embeddings?
-
-Image and audio processing systems work with rich, high-dimensional datasets
-encoded as vectors of the individual raw pixel-intensities for image data, or
-e.g. power spectral density coefficients for audio data. For tasks like object
-or speech recognition we know that all the information required to successfully
-perform the task is encoded in the data (because humans can perform these tasks
-from the raw data). However, natural language processing systems traditionally
-treat words as discrete atomic symbols, and therefore 'cat' may be represented
-as `Id537` and 'dog' as `Id143`. These encodings are arbitrary, and provide
-no useful information to the system regarding the relationships that may exist
-between the individual symbols. This means that the model can leverage
-very little of what it has learned about 'cats' when it is processing data about
-'dogs' (such as the fact that both are four-legged animals, pets, etc.). Representing
-words as unique, discrete ids furthermore leads to data sparsity, and usually
-means that we may need more data in order to successfully train statistical
-models. Using vector representations can overcome some of these obstacles.
-
-<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="https://www.tensorflow.org/images/audio-image-text.png" alt>
-</div>
-
-[Vector space models](https://en.wikipedia.org/wiki/Vector_space_model) (VSMs)
-represent (embed) words in a continuous vector space where semantically
-similar words are mapped to nearby points ('are embedded nearby each other').
-VSMs have a long, rich history in NLP, but all methods depend in some way or
-another on the
-[Distributional Hypothesis](https://en.wikipedia.org/wiki/Distributional_semantics#Distributional_Hypothesis),
-which states that words that appear in the same contexts share
-semantic meaning. The different approaches that leverage this principle can be
-divided into two categories: *count-based methods* (e.g.
-[Latent Semantic Analysis](https://en.wikipedia.org/wiki/Latent_semantic_analysis)),
-and *predictive methods* (e.g.
-[neural probabilistic language models](http://www.scholarpedia.org/article/Neural_net_language_models)).
-
-This distinction is elaborated in much more detail by
-[Baroni et al.](http://clic.cimec.unitn.it/marco/publications/acl2014/baroni-etal-countpredict-acl2014.pdf),
-but in a nutshell: Count-based methods compute the statistics of
-how often some word co-occurs with its neighbor words in a large text corpus,
-and then map these count-statistics down to a small, dense vector for each word.
-Predictive models directly try to predict a word from its neighbors in terms of
-learned small, dense *embedding vectors* (considered parameters of the
-model).
-
-Word2vec is a particularly computationally-efficient predictive model for
-learning word embeddings from raw text. It comes in two flavors, the Continuous
-Bag-of-Words model (CBOW) and the Skip-Gram model (Sections 3.1 and 3.2 in [Mikolov et al.](https://arxiv.org/pdf/1301.3781.pdf)). Algorithmically, these
-models are similar, except that CBOW predicts target words (e.g. 'mat') from
-source context words ('the cat sits on the'), while the skip-gram model does the
-inverse and predicts source context-words from the target words. This inversion
-might seem like an arbitrary choice, but statistically it has the effect that
-CBOW smoothes over a lot of the distributional information (by treating an
-entire context as one observation). For the most part, this turns out to be a
-useful thing for smaller datasets. However, skip-gram treats each context-target
-pair as a new observation, and this tends to do better when we have larger
-datasets. We will focus on the skip-gram model in the rest of this tutorial.
-
-
-## Scaling up with Noise-Contrastive Training
-
-Neural probabilistic language models are traditionally trained using the
-[maximum likelihood](https://en.wikipedia.org/wiki/Maximum_likelihood) (ML)
-principle to maximize the probability of the next word \\(w_t\\) (for "target")
-given the previous words \\(h\\) (for "history") in terms of a
-[*softmax* function](https://en.wikipedia.org/wiki/Softmax_function),
-
-$$
-\begin{align}
-P(w_t | h) &= \text{softmax}(\text{score}(w_t, h)) \\
- &= \frac{\exp \{ \text{score}(w_t, h) \} }
- {\sum_\text{Word w' in Vocab} \exp \{ \text{score}(w', h) \} }
-\end{align}
-$$
-
-where \\(\text{score}(w_t, h)\\) computes the compatibility of word \\(w_t\\)
-with the context \\(h\\) (a dot product is commonly used). We train this model
-by maximizing its [log-likelihood](https://en.wikipedia.org/wiki/Likelihood_function)
-on the training set, i.e. by maximizing
-
-$$
-\begin{align}
- J_\text{ML} &= \log P(w_t | h) \\
- &= \text{score}(w_t, h) -
- \log \left( \sum_\text{Word w' in Vocab} \exp \{ \text{score}(w', h) \} \right).
-\end{align}
-$$
-
-This yields a properly normalized probabilistic model for language modeling.
-However this is very expensive, because we need to compute and normalize each
-probability using the score for all other \\(V\\) words \\(w'\\) in the current
-context \\(h\\), *at every training step*.
-
-<div style="width:60%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="https://www.tensorflow.org/images/softmax-nplm.png" alt>
-</div>
-
-On the other hand, for feature learning in word2vec we do not need a full
-probabilistic model. The CBOW and skip-gram models are instead trained using a
-binary classification objective ([logistic regression](https://en.wikipedia.org/wiki/Logistic_regression))
-to discriminate the real target words \\(w_t\\) from \\(k\\) imaginary (noise) words \\(\tilde w\\), in the
-same context. We illustrate this below for a CBOW model. For skip-gram the
-direction is simply inverted.
-
-<div style="width:60%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="https://www.tensorflow.org/images/nce-nplm.png" alt>
-</div>
-
-Mathematically, the objective (for each example) is to maximize
-
-$$J_\text{NEG} = \log Q_\theta(D=1 |w_t, h) +
- k \mathop{\mathbb{E}}_{\tilde w \sim P_\text{noise}}
- \left[ \log Q_\theta(D = 0 |\tilde w, h) \right]$$
-
-where \\(Q_\theta(D=1 | w, h)\\) is the binary logistic regression probability
-under the model of seeing the word \\(w\\) in the context \\(h\\) in the dataset
-\\(D\\), calculated in terms of the learned embedding vectors \\(\theta\\). In
-practice we approximate the expectation by drawing \\(k\\) contrastive words
-from the noise distribution (i.e. we compute a
-[Monte Carlo average](https://en.wikipedia.org/wiki/Monte_Carlo_integration)).
-
-This objective is maximized when the model assigns high probabilities
-to the real words, and low probabilities to noise words. Technically, this is
-called
-[Negative Sampling](https://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf),
-and there is good mathematical motivation for using this loss function:
-The updates it proposes approximate the updates of the softmax function in the
-limit. But computationally it is especially appealing because computing the
-loss function now scales only with the number of *noise words* that we
-select (\\(k\\)), and not *all words* in the vocabulary (\\(V\\)). This makes it
-much faster to train. We will actually make use of the very similar
-[noise-contrastive estimation (NCE)](https://papers.nips.cc/paper/5165-learning-word-embeddings-efficiently-with-noise-contrastive-estimation.pdf)
-loss, for which TensorFlow has a handy helper function `tf.nn.nce_loss()`.
-
-Let's get an intuitive feel for how this would work in practice!
-
-## The Skip-gram Model
-
-As an example, let's consider the dataset
-
-`the quick brown fox jumped over the lazy dog`
-
-We first form a dataset of words and the contexts in which they appear. We
-could define 'context' in any way that makes sense, and in fact people have
-looked at syntactic contexts (i.e. the syntactic dependents of the current
-target word, see e.g.
-[Levy et al.](https://levyomer.files.wordpress.com/2014/04/dependency-based-word-embeddings-acl-2014.pdf)),
-words-to-the-left of the target, words-to-the-right of the target, etc. For now,
-let's stick to the vanilla definition and define 'context' as the window
-of words to the left and to the right of a target word. Using a window
-size of 1, we then have the dataset
-
-`([the, brown], quick), ([quick, fox], brown), ([brown, jumped], fox), ...`
-
-of `(context, target)` pairs. Recall that skip-gram inverts contexts and
-targets, and tries to predict each context word from its target word, so the
-task becomes to predict 'the' and 'brown' from 'quick', 'quick' and 'fox' from
-'brown', etc. Therefore our dataset becomes
-
-`(quick, the), (quick, brown), (brown, quick), (brown, fox), ...`
-
-of `(input, output)` pairs. The objective function is defined over the entire
-dataset, but we typically optimize this with
-[stochastic gradient descent](https://en.wikipedia.org/wiki/Stochastic_gradient_descent)
-(SGD) using one example at a time (or a 'minibatch' of `batch_size` examples,
-where typically `16 <= batch_size <= 512`). So let's look at one step of
-this process.
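-
-(As a concrete aside, here is a tiny plain-Python sketch of generating such
-pairs with a window size of 1; it is illustrative only, not the tutorial's
-actual data pipeline.)
-
-```python
-words = "the quick brown fox jumped over the lazy dog".split()
-pairs = []
-for i, target in enumerate(words):
-  for j in range(max(0, i - 1), min(len(words), i + 2)):
-    if j != i:
-      pairs.append((target, words[j]))
-print(pairs[:3])  # [('the', 'quick'), ('quick', 'the'), ('quick', 'brown')]
-```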
-
-Let's imagine at training step \\(t\\) we observe the first training case above,
-where the goal is to predict `the` from `quick`. We select `num_noise` number
-of noisy (contrastive) examples by drawing from some noise distribution,
-typically the unigram distribution, \\(P(w)\\). For simplicity let's say
-`num_noise=1` and we select `sheep` as a noisy example. Next we compute the
-loss for this pair of observed and noisy examples, i.e. the objective at time
-step \\(t\\) becomes
-
-$$J^{(t)}_\text{NEG} = \log Q_\theta(D=1 | \text{the, quick}) +
- \log(Q_\theta(D=0 | \text{sheep, quick}))$$
-
-The goal is to make an update to the embedding parameters \\(\theta\\) to improve
-(in this case, maximize) this objective function. We do this by deriving the
-gradient of the loss with respect to the embedding parameters \\(\theta\\), i.e.
-\\(\frac{\partial}{\partial \theta} J_\text{NEG}\\) (luckily TensorFlow provides
-easy helper functions for doing this!). We then perform an update to the
-embeddings by taking a small step in the direction of the gradient. When this
-process is repeated over the entire training set, this has the effect of
-'moving' the embedding vectors around for each word until the model is
-successful at discriminating real words from noise words.
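-
-(In graph code this is a one-liner; for example, given the `loss`,
-`embeddings`, `nce_weights` and `nce_biases` nodes built in the implementation
-section below, a sketch of recovering the raw gradients would be:)
-
-```python
-grads = tf.gradients(loss, [embeddings, nce_weights, nce_biases])
-```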
-
-We can visualize the learned vectors by projecting them down to 2 dimensions
-using for instance something like the
-[t-SNE dimensionality reduction technique](https://lvdmaaten.github.io/tsne/).
-When we inspect these visualizations it becomes apparent that the vectors
-capture some general, and in fact quite useful, semantic information about
-words and their relationships to one another. It was very interesting when we
-first discovered that certain directions in the induced vector space specialize
-towards certain semantic relationships, e.g. *male-female*, *verb tense* and
-even *country-capital* relationships between words, as illustrated in the figure
-below (see also for example
-[Mikolov et al., 2013](https://www.aclweb.org/anthology/N13-1090)).
-
-<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="https://www.tensorflow.org/images/linear-relationships.png" alt>
-</div>
-
-This explains why these vectors are also useful as features for many canonical
-NLP prediction tasks, such as part-of-speech tagging or named entity recognition
-(see for example the original work by
-[Collobert et al., 2011](https://arxiv.org/abs/1103.0398)
-([pdf](https://arxiv.org/pdf/1103.0398.pdf)), or follow-up work by
-[Turian et al., 2010](https://www.aclweb.org/anthology/P10-1040)).
-
-But for now, let's just use them to draw pretty pictures!
-
-## Building the Graph
-
-This is all about embeddings, so let's define our embedding matrix.
-This is just a big random matrix to start. We'll initialize the values to be
-uniform in the unit cube.
-
-```python
-embeddings = tf.Variable(
- tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
-```
-
-The noise-contrastive estimation loss is defined in terms of a logistic regression
-model. For this, we need to define the weights and biases for each word in the
-vocabulary (also called the *output weights*, as opposed to the *input
-embeddings*). So let's define them.
-
-```python
-nce_weights = tf.Variable(
- tf.truncated_normal([vocabulary_size, embedding_size],
- stddev=1.0 / math.sqrt(embedding_size)))
-nce_biases = tf.Variable(tf.zeros([vocabulary_size]))
-```
-
-Now that we have the parameters in place, we can define our skip-gram model
-graph. For simplicity, let's suppose we've already integerized our text corpus
-with a vocabulary so that each word is represented as an integer (see
-[tensorflow/examples/tutorials/word2vec/word2vec_basic.py](https://www.tensorflow.org/code/tensorflow/examples/tutorials/word2vec/word2vec_basic.py)
-for the details). The skip-gram model takes two inputs. One is a batch full of
-integers representing the source context words, the other is for the target
-words. Let's create placeholder nodes for these inputs, so that we can feed in
-data later.
-
-```python
-# Placeholders for inputs
-train_inputs = tf.placeholder(tf.int32, shape=[batch_size])
-train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
-```
-
-Now what we need to do is look up the vector for each of the source words in
-the batch. TensorFlow has handy helpers that make this easy.
-
-```python
-embed = tf.nn.embedding_lookup(embeddings, train_inputs)
-```
-
-OK, now that we have the embeddings for each word, we'd like to try to predict
-the target word using the noise-contrastive training objective.
-
-```python
-# Compute the NCE loss, using a sample of the negative labels each time.
-loss = tf.reduce_mean(
- tf.nn.nce_loss(weights=nce_weights,
- biases=nce_biases,
- labels=train_labels,
- inputs=embed,
- num_sampled=num_sampled,
- num_classes=vocabulary_size))
-```
-
-Now that we have a loss node, we need to add the nodes required to compute
-gradients and update the parameters, etc. For this we will use stochastic
-gradient descent, and TensorFlow has handy helpers to make this easy as well.
-
-```python
-# We use the SGD optimizer.
-optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0).minimize(loss)
-```
-
-## Training the Model
-
-Training the model is then as simple as using a `feed_dict` to push data into
-the placeholders and calling
-`tf.Session.run` with this new data
-in a loop.
-
-```python
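-# Note (assumption): `session` is an active tf.Session with variables already
-# initialized, and `generate_batch` yields (inputs, labels) numpy arrays; see
-# the full example linked below.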
-for inputs, labels in generate_batch(...):
- feed_dict = {train_inputs: inputs, train_labels: labels}
- _, cur_loss = session.run([optimizer, loss], feed_dict=feed_dict)
-```
-
-See the full example code in
-[tensorflow/examples/tutorials/word2vec/word2vec_basic.py](https://www.tensorflow.org/code/tensorflow/examples/tutorials/word2vec/word2vec_basic.py).
-
-## Visualizing the Learned Embeddings
-
-After training has finished we can visualize the learned embeddings using
-t-SNE.
-
-<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="https://www.tensorflow.org/images/tsne.png" alt>
-</div>
-
-Et voilà! As expected, words that are similar end up clustering near each
-other. For a more heavyweight implementation of word2vec that showcases more of
-the advanced features of TensorFlow, see the implementation in
-[models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py).
-
-## Evaluating Embeddings: Analogical Reasoning
-
-Embeddings are useful for a wide variety of prediction tasks in NLP. Short of
-training a full-blown part-of-speech model or named-entity model, one simple way
-to evaluate embeddings is to directly use them to predict syntactic and semantic
-relationships like `king is to queen as father is to ?`. This is called
-*analogical reasoning* and the task was introduced by
-[Mikolov and colleagues
-](https://www.aclweb.org/anthology/N13-1090).
-Download the dataset for this task from
-[download.tensorflow.org](http://download.tensorflow.org/data/questions-words.txt).
-
-To see how we do this evaluation, have a look at the `build_eval_graph()` and
-`eval()` functions in
-[models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py).
-
-The choice of hyperparameters can strongly influence the accuracy on this task.
-To achieve state-of-the-art performance on this task requires training over a
-very large dataset, carefully tuning the hyperparameters and making use of
-tricks like subsampling the data, which is beyond the scope of this tutorial.
-
-
-## Optimizing the Implementation
-
-Our vanilla implementation showcases the flexibility of TensorFlow. For
-example, changing the training objective is as simple as swapping out the call
-to `tf.nn.nce_loss()` for an off-the-shelf alternative such as
-`tf.nn.sampled_softmax_loss()`. If you have a new idea for a loss function, you
-can manually write an expression for the new objective in TensorFlow and let
-the optimizer compute its derivatives. This flexibility is invaluable in the
-exploratory phase of machine learning model development, where we are trying
-out several different ideas and iterating quickly.
-
-Once you have a model structure you're satisfied with, it may be worth
-optimizing your implementation to run more efficiently (and cover more data in
-less time). For example, the naive code we used in this tutorial would run
-slowly, because we use Python for reading and feeding data items, each of
-which requires very little work on the TensorFlow back-end. If you find
-your model is seriously bottlenecked on input data, you may want to implement a
-custom data reader for your problem, as described in
-@{$new_data_formats$New Data Formats}. For the case of Skip-Gram
-modeling, we've actually already done this for you as an example in
-[models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py).
-
-If your model is no longer I/O bound but you still want more performance, you
-can take things further by writing your own TensorFlow Ops, as described in
-@{$adding_an_op$Adding a New Op}. Again we've provided an
-example of this for the Skip-Gram case in
-[models/tutorials/embedding/word2vec_optimized.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec_optimized.py).
-Feel free to benchmark these against each other to measure performance
-improvements at each stage.
-
-## Conclusion
-
-In this tutorial we covered the word2vec model, a computationally efficient
-model for learning word embeddings. We motivated why embeddings are useful,
-discussed efficient training techniques and showed how to implement all of this
-in TensorFlow. Overall, we hope that this has showcased how TensorFlow affords
-you the flexibility you need for early experimentation, and the control you
-later need for bespoke optimized implementations.
diff --git a/tensorflow/docs_src/tutorials/sequences/audio_recognition.md b/tensorflow/docs_src/tutorials/sequences/audio_recognition.md
deleted file mode 100644
index d7a8da6f96..0000000000
--- a/tensorflow/docs_src/tutorials/sequences/audio_recognition.md
+++ /dev/null
@@ -1,631 +0,0 @@
-# Simple Audio Recognition
-
-This tutorial will show you how to build a basic speech recognition network that
-recognizes ten different words. It's important to know that real speech and
-audio recognition systems are much more complex, but like MNIST for images, it
-should give you a basic understanding of the techniques involved. Once you've
-completed this tutorial, you'll have a model that tries to classify a one second
-audio clip as either silence, an unknown word, "yes", "no", "up", "down",
-"left", "right", "on", "off", "stop", or "go". You'll also be able to take this
-model and run it in an Android application.
-
-## Preparation
-
-You should make sure you have TensorFlow installed, and since the script
-downloads over 1GB of training data, you'll need a good internet connection and
-enough free space on your machine. The training process itself can take several
-hours, so make sure you have a machine available for that long.
-
-## Training
-
-To begin the training process, go to the TensorFlow source tree and run:
-
-```bash
-python tensorflow/examples/speech_commands/train.py
-```
-
-The script will start off by downloading the [Speech Commands
-dataset](https://storage.cloud.google.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz),
-which consists of over 105,000 WAVE audio files of people saying thirty
-different words. This data was collected by Google and released under a CC BY
-license, and you can help improve it by [contributing five minutes of your own
-voice](https://aiyprojects.withgoogle.com/open_speech_recording). The archive is
-over 2GB, so this part may take a while, but you should see progress logs, and
-once it's been downloaded once you won't need to do this step again. You can
-find more information about this dataset in this
-[Speech Commands paper](https://arxiv.org/abs/1804.03209).
-
-Once the downloading has completed, you'll see logging information that looks
-like this:
-
-```
-I0730 16:53:44.766740 55030 train.py:176] Training from step: 1
-I0730 16:53:47.289078 55030 train.py:217] Step #1: rate 0.001000, accuracy 7.0%, cross entropy 2.611571
-```
-
-This shows that the initialization process is done and the training loop has
-begun. You'll see that it outputs information for every training step. Here's a
-breakdown of what it means:
-
-`Step #1` shows that we're on the first step of the training loop. In this case
-there are going to be 18,000 steps in total, so you can look at the step number
-to get an idea of how close it is to finishing.
-
-`rate 0.001000` is the learning rate that's controlling the speed of the
-network's weight updates. Early on this is a comparatively high number (0.001),
-but for later training cycles it will be reduced 10x, to 0.0001.
-
-`accuracy 7.0%` shows the proportion of clips that were correctly labeled on
-this training step. This value will often fluctuate a lot, but should increase on
-average as training progresses. The model outputs an array of numbers, one for
-each label, and each number is the predicted likelihood of the input being that
-class. The predicted label is picked by choosing the entry with the highest
-score. The scores are always between zero and one, with higher values
-representing more confidence in the result.
-
-`cross entropy 2.611571` is the result of the loss function that we're using to
-guide the training process. This is a score that's obtained by comparing the
-vector of scores from the current training run to the correct labels, and this
-should trend downwards during training.
-
-After a hundred steps, you should see a line like this:
-
-`I0730 16:54:41.813438 55030 train.py:252] Saving to
-"/tmp/speech_commands_train/conv.ckpt-100"`
-
-This is saving out the current trained weights to a checkpoint file. If your
-training script gets interrupted, you can look for the last saved checkpoint and
-then restart the script with
-`--start_checkpoint=/tmp/speech_commands_train/conv.ckpt-100` as a command line
-argument to start from that point.
-
-## Confusion Matrix
-
-After four hundred steps, this information will be logged:
-
-```
-I0730 16:57:38.073667 55030 train.py:243] Confusion Matrix:
- [[258 0 0 0 0 0 0 0 0 0 0 0]
- [ 7 6 26 94 7 49 1 15 40 2 0 11]
- [ 10 1 107 80 13 22 0 13 10 1 0 4]
- [ 1 3 16 163 6 48 0 5 10 1 0 17]
- [ 15 1 17 114 55 13 0 9 22 5 0 9]
- [ 1 1 6 97 3 87 1 12 46 0 0 10]
- [ 8 6 86 84 13 24 1 9 9 1 0 6]
- [ 9 3 32 112 9 26 1 36 19 0 0 9]
- [ 8 2 12 94 9 52 0 6 72 0 0 2]
- [ 16 1 39 74 29 42 0 6 37 9 0 3]
- [ 15 6 17 71 50 37 0 6 32 2 1 9]
- [ 11 1 6 151 5 42 0 8 16 0 0 20]]
-```
-
-The first section is a [confusion
-matrix](https://www.tensorflow.org/api_docs/python/tf/confusion_matrix). To
-understand what it means, you first need to know the labels being used, which in
-this case are "_silence_", "_unknown_", "yes", "no", "up", "down", "left",
-"right", "on", "off", "stop", and "go". Each column represents a set of samples
-that were predicted to be each label, so the first column represents all the
-clips that were predicted to be silence, the second all those that were
-predicted to be unknown words, the third "yes", and so on.
-
-Each row represents clips by their correct, ground truth labels. The first row
-is all the clips that were silence, the second clips that were unknown words,
-the third "yes", etc.
-
-This matrix can be more useful than just a single accuracy score because it
-gives a good summary of what mistakes the network is making. In this example you
-can see that all of the entries in the first row are zero, apart from the
-initial one. Because the first row is all the clips that are actually silence,
-this means that none of them were mistakenly labeled as words, so we have no
-false negatives for silence. This shows the network is already getting pretty
-good at distinguishing silence from words.
-
-If we look down the first column though, we see a lot of non-zero values. The
-column represents all the clips that were predicted to be silence, so positive
-numbers outside of the first cell are errors. This means that some clips of real
-spoken words are actually being predicted to be silence, so we do have quite a
-few false positives.
-
-A perfect model would produce a confusion matrix where all of the entries were
-zero apart from a diagonal line through the center. Spotting deviations from
-that pattern can help you figure out how the model is most easily confused, and
-once you've identified the problems you can address them by adding more data or
-cleaning up categories.
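-
-If you want to compute such a matrix for your own batches of results, a rough
-sketch (the label and prediction indices here are made up):
-
-```python
-import tensorflow as tf
-
-labels = [0, 3, 2, 5]       # ground-truth label indices
-predictions = [0, 3, 3, 5]  # the model's predicted indices
-matrix = tf.confusion_matrix(labels, predictions, num_classes=12)
-with tf.Session() as sess:
-  print(sess.run(matrix))
-```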
-
-## Validation
-
-After the confusion matrix, you should see a line like this:
-
-`I0730 16:57:38.073777 55030 train.py:245] Step 400: Validation accuracy = 26.3%
-(N=3093)`
-
-It's good practice to separate your data set into three categories. The largest
-(in this case roughly 80% of the data) is used for training the network, a
-smaller set (10% here, known as "validation") is reserved for evaluation of the
-accuracy during training, and another set (the last 10%, "testing") is used to
-evaluate the accuracy once after the training is complete.
-
-The reason for this split is that there's always a danger that networks will
-start memorizing their inputs during training. By keeping the validation set
-separate, you can ensure that the model works with data it's never seen before.
-The testing set is an additional safeguard to make sure that you haven't just
-been tweaking your model in a way that happens to work for both the training and
-validation sets, but not a broader range of inputs.
-
-The training script automatically separates the data set into these three
-categories, and the logging line above shows the accuracy of the model when run on
-the validation set. Ideally, this should stick fairly close to the training
-accuracy. If the training accuracy increases but the validation doesn't, that's
-a sign that overfitting is occurring, and your model is only learning things
-about the training clips, not broader patterns that generalize.
-
-## TensorBoard
-
-A good way to visualize how the training is progressing is with TensorBoard. By
-default, the script saves out events to `/tmp/retrain_logs`, and you can load
-these by running:
-
-`tensorboard --logdir /tmp/retrain_logs`
-
-Then navigate to [http://localhost:6006](http://localhost:6006) in your browser,
-and you'll see charts and graphs showing your model's progress.
-
-<div style="width:50%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="https://storage.googleapis.com/download.tensorflow.org/example_images/speech_commands_tensorflow.png"/>
-</div>
-
-## Training Finished
-
-After a few hours of training (depending on your machine's speed), the script
-should have completed all 18,000 steps. It will print out a final confusion
-matrix, along with an accuracy score, all run on the testing set. With the
-default settings, you should see an accuracy of between 85% and 90%.
-
-Because audio recognition is particularly useful on mobile devices, next we'll
-export it to a compact format that's easy to work with on those platforms. To do
-that, run this command line:
-
-```
-python tensorflow/examples/speech_commands/freeze.py \
---start_checkpoint=/tmp/speech_commands_train/conv.ckpt-18000 \
---output_file=/tmp/my_frozen_graph.pb
-```
-
-Once the frozen model has been created, you can test it with the `label_wav.py`
-script, like this:
-
-```
-python tensorflow/examples/speech_commands/label_wav.py \
---graph=/tmp/my_frozen_graph.pb \
---labels=/tmp/speech_commands_train/conv_labels.txt \
---wav=/tmp/speech_dataset/left/a5d485dc_nohash_0.wav
-```
-
-This should print out three labels:
-
-```
-left (score = 0.81477)
-right (score = 0.14139)
-_unknown_ (score = 0.03808)
-```
-
-Hopefully "left" is the top score since that's the correct label, but since the
-training is random it may not be for the first file you try. Experiment with some
-of the other .wav files in that same folder to see how well it does.
-
-The scores are between zero and one, and higher values mean the model is more
-confident in its prediction.
-
-## Running the Model in an Android App
-
-The easiest way to see how this model works in a real application is to download
-[the prebuilt Android demo
-applications](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#prebuilt-components)
-and install them on your phone. You'll see 'TF Speech' appear in your app list,
-and opening it will show you the same list of action words we've just trained
-our model on, starting with "Yes" and "No". Once you've given the app permission
-to use the microphone, you should be able to try saying those words and see them
-highlighted in the UI when the model recognizes one of them.
-
-You can also build this application yourself, since it's open source and
-[available as part of the TensorFlow repository on
-github](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#building-in-android-studio-using-the-tensorflow-aar-from-jcenter).
-By default it downloads [a pretrained model from
-tensorflow.org](http://download.tensorflow.org/models/speech_commands_v0.02.zip),
-but you can easily [replace it with a model you've trained
-yourself](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-model-files-optional).
-If you do this, you'll need to make sure that the constants in [the main
-SpeechActivity Java source
-file](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java)
-like `SAMPLE_RATE` and `SAMPLE_DURATION` match any changes you've made to the
-defaults while training. You'll also see that there's a [Java version of the
-RecognizeCommands
-module](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android/src/org/tensorflow/demo/RecognizeCommands.java)
-that's very similar to the C++ version in this tutorial. If you've tweaked
-parameters for that, you can also update them in SpeechActivity to get the same
-results as in your server testing.
-
-The demo app updates its UI list of results automatically based on the labels
-text file you copy into assets alongside your frozen graph, which means you can
-easily try out different models without needing to make any code changes. If
-you change the paths, though, you will need to update `LABEL_FILENAME` and
-`MODEL_FILENAME` to point to the files you've added.
-
-## How does this Model Work?
-
-The architecture used in this tutorial is based on some described in the paper
-[Convolutional Neural Networks for Small-footprint Keyword
-Spotting](http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf).
-It was chosen because it's comparatively simple, quick to train, and easy to
-understand, rather than being state of the art. There are lots of different
-approaches to building neural network models to work with audio, including
-[recurrent networks](https://svds.com/tensorflow-rnn-tutorial/) or [dilated
-(atrous)
-convolutions](https://deepmind.com/blog/wavenet-generative-model-raw-audio/).
-This tutorial is based on the kind of convolutional network that will feel very
-familiar to anyone who's worked with image recognition. That may seem surprising
-at first though, since audio is inherently a one-dimensional continuous signal
-across time, not a 2D spatial problem.
-
-We solve that issue by defining a window of time we believe our spoken words
-should fit into, and converting the audio signal in that window into an image.
-This is done by grouping the incoming audio samples into short segments, just a
-few milliseconds long, and calculating the strength of the frequencies across a
-set of bands. Each set of frequency strengths from a segment is treated as a
-vector of numbers, and those vectors are arranged in time order to form a
-two-dimensional array. This array of values can then be treated like a
-single-channel image, and is known as a
-[spectrogram](https://en.wikipedia.org/wiki/Spectrogram). If you want to view
-what kind of image an audio sample produces, you can run the
-`wav_to_spectrogram` tool:
-
-```
-bazel run tensorflow/examples/wav_to_spectrogram:wav_to_spectrogram -- \
---input_wav=/tmp/speech_dataset/happy/ab00c4b2_nohash_0.wav \
---output_image=/tmp/spectrogram.png
-```
-
-If you open up `/tmp/spectrogram.png` you should see something like this:
-
-<div style="width:50%; margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="https://storage.googleapis.com/download.tensorflow.org/example_images/spectrogram.png"/>
-</div>
-
-Because of TensorFlow's memory order, time in this image is increasing from top
-to bottom, with frequencies going from left to right, unlike the usual
-convention for spectrograms where time is left to right. You should be able to
-see a couple of distinct parts, with the first syllable "Ha" clearly separate
-from "ppy".
-
-Because the human ear is more sensitive to some frequencies than others, it's
-been traditional in speech recognition to do further processing to this
-representation to turn it into a set of [Mel-Frequency Cepstral
-Coefficients](https://en.wikipedia.org/wiki/Mel-frequency_cepstrum), or MFCCs
-for short. This is also a two-dimensional, one-channel representation so it can
-be treated like an image too. If you're targeting general sounds rather than
-speech you may find you can skip this step and operate directly on the
-spectrograms.
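-
-To make the windowing idea concrete, here is a minimal numpy sketch of a
-spectrogram computation. This is not the actual op the training script uses;
-the 30ms window and 10ms stride mirror the script's defaults, but the Hann
-window and plain FFT magnitudes here are illustrative assumptions:
-
-```python
-import numpy as np
-
-def simple_spectrogram(samples, window_size=480, stride=160):
-  """Sketch: magnitude spectrogram of 16kHz audio (30ms windows, 10ms stride)."""
-  window = np.hanning(window_size)
-  rows = []
-  for start in range(0, len(samples) - window_size + 1, stride):
-    frame = samples[start:start + window_size] * window
-    # Strength of each frequency band for this short segment.
-    rows.append(np.abs(np.fft.rfft(frame)))
-  # Rows are time steps, columns are frequency bands: a single-channel "image".
-  return np.array(rows)
-```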
-
-The image that's produced by these processing steps is then fed into a
-multi-layer convolutional neural network, with a fully-connected layer followed
-by a softmax at the end. You can see the definition of this portion in
-[tensorflow/examples/speech_commands/models.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/models.py).
-
-## Streaming Accuracy
-
-Most audio recognition applications need to run on a continuous stream of audio,
-rather than on individual clips. A typical way to use a model in this
-environment is to apply it repeatedly at different offsets in time and average
-the results over a short window to produce a smoothed prediction. If you think
-of the input as an image, it's continuously scrolling along the time axis. The
-words we want to recognize can start at any time, so we need to take a series of
-snapshots to have a chance of having an alignment that captures most of the
-utterance in the time window we feed into the model. If we sample at a high
-enough rate, then we have a good chance of capturing the word in multiple
-windows, so averaging the results improves the overall confidence of the
-prediction.
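-
-As a rough sketch of that smoothing step (the deque length and the
-`run_model` helper named in the comment are hypothetical, not part of the
-tutorial code):
-
-```python
-import collections
-import numpy as np
-
-# Keep the last few score vectors, e.g. from run_model() on successive windows.
-recent_scores = collections.deque(maxlen=5)
-
-def smoothed_prediction(new_scores):
-  recent_scores.append(new_scores)
-  # Averaging across overlapping windows damps one-off misfires.
-  return np.mean(recent_scores, axis=0)
-```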
-
-For an example of how you can use your model on streaming data, you can look at
-[test_streaming_accuracy.cc](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/).
-This uses the
-[RecognizeCommands](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/recognize_commands.h)
-class to run through long-form input audio, try to spot words, and compare
-those predictions against a ground truth list of labels and times. This makes it
-a good example of applying a model to a stream of audio signals over time.
-
-You'll need a long audio file to test it against, along with labels showing
-where each word was spoken. If you don't want to record one yourself, you can
-generate some synthetic test data using the `generate_streaming_test_wav`
-utility. By default this will create a ten minute .wav file with words roughly
-every three seconds, and a text file containing the ground truth of when each
-word was spoken. These words are pulled from the test portion of your current
-dataset, mixed in with background noise. To run it, use:
-
-```
-bazel run tensorflow/examples/speech_commands:generate_streaming_test_wav
-```
-
-This will save a .wav file to `/tmp/speech_commands_train/streaming_test.wav`,
-and a text file listing the labels to
-`/tmp/speech_commands_train/streaming_test_labels.txt`. You can then run
-accuracy testing with:
-
-```
-bazel run tensorflow/examples/speech_commands:test_streaming_accuracy -- \
---graph=/tmp/my_frozen_graph.pb \
---labels=/tmp/speech_commands_train/conv_labels.txt \
---wav=/tmp/speech_commands_train/streaming_test.wav \
---ground_truth=/tmp/speech_commands_train/streaming_test_labels.txt \
---verbose
-```
-
-This will output information about the number of words correctly matched, how
-many were given the wrong labels, and how many times the model triggered when
-there was no real word spoken. There are various parameters that control how the
-signal averaging works, including `--average_window_ms` which sets the length of
-time to average results over, `--clip_stride_ms` which is the time between
-applications of the model, `--suppression_ms` which stops subsequent word
-detections from triggering for a certain time after an initial one is found, and
-`--detection_threshold`, which controls how high the average score must be
-before it's considered a solid result.
-
-You'll see that the streaming accuracy outputs three numbers, rather than just
-the one metric used in training. This is because different applications have
-varying requirements, with some being able to tolerate frequent incorrect
-results as long as real words are found (high recall), while others very focused
-on ensuring the predicted labels are highly likely to be correct even if some
-aren't detected (high precision). The numbers from the tool give you an idea of
-how your model will perform in an application, and you can try tweaking the
-signal averaging parameters to tune it to give the kind of performance you want.
-To understand what the right parameters are for your application, you can look
-at generating an [ROC
-curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) to help
-you understand the tradeoffs.
-
-## RecognizeCommands
-
-The streaming accuracy tool uses a simple decoder contained in a small C++ class
-called
-[RecognizeCommands](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/recognize_commands.h).
-This class is fed the output of running the TensorFlow model over time; it
-averages the signals and returns information about a label when it has enough
-evidence to think that a recognized word has been found. The implementation is
-fairly small, just keeping track of the last few predictions and averaging them,
-so it's easy to port to other platforms and languages as needed. For example,
-it's convenient to do something similar at the Java level on Android, or Python
-on the Raspberry Pi. As long as these implementations share the same logic, you
-can tune the parameters that control the averaging using the streaming test
-tool, and then transfer them over to your application to get similar results.
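-
-As a loose Python rendering of that averaging-and-threshold logic (the window
-length, threshold, and suppression values here are illustrative assumptions;
-the C++ class also applies a few extra sanity checks):
-
-```python
-import collections
-import numpy as np
-
-class SimpleRecognizeCommands:
-  """Sketch of the averaging decoder described above."""
-
-  def __init__(self, labels, average_window=5, detection_threshold=0.7,
-               suppression_ms=1500):
-    self.labels = labels
-    self.recent = collections.deque(maxlen=average_window)
-    self.detection_threshold = detection_threshold
-    self.suppression_ms = suppression_ms
-    self.last_detection_ms = -suppression_ms
-
-  def process(self, scores, current_time_ms):
-    self.recent.append(scores)
-    averaged = np.mean(self.recent, axis=0)
-    best = int(np.argmax(averaged))
-    # Report a word only with enough averaged evidence, and not too soon
-    # after the previous detection.
-    if (averaged[best] > self.detection_threshold and
-        current_time_ms - self.last_detection_ms > self.suppression_ms):
-      self.last_detection_ms = current_time_ms
-      return self.labels[best], float(averaged[best])
-    return None, 0.0
-```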
-
-## Advanced Training
-
-The defaults for the training script are designed to produce good end-to-end
-results in a comparatively small file, but there are a lot of options you can
-change to customize the results for your own requirements.
-
-### Custom Training Data
-
-By default the script will download the [Speech Commands
-dataset](https://download.tensorflow.org/data/speech_commands_v0.01.tgz), but
-you can also supply your own training data. To train on your own data, you
-should make sure that you have at least several hundred recordings of each sound
-you would like to recognize, and arrange them into folders by class. For
-example, if you were trying to recognize dog barks from cat miaows, you would
-create a root folder called `animal_sounds`, and then within that two
-sub-folders called `bark` and `miaow`. You would then organize your audio files
-into the appropriate folders.
-
-To point the script to your new audio files, you'll need to set `--data_url=` to
-disable downloading of the Speech Commands dataset, and
-`--data_dir=/your/data/folder/` to find the files you've just created.
-
-The files themselves should be 16-bit little-endian PCM-encoded WAVE format. The
-sample rate defaults to 16,000, but as long as all your audio is consistently
-the same rate (the script doesn't support resampling) you can change this with
-the `--sample_rate` argument. The clips should also all be roughly the same
-duration. The default expected duration is one second, but you can set this with
-the `--clip_duration_ms` flag. If you have clips with variable amounts of
-silence at the start, you can look at word alignment tools to standardize them
-([here's a quick and dirty approach you can use
-too](https://petewarden.com/2017/07/17/a-quick-hack-to-align-single-word-audio-recordings/)).
-
-One issue to watch out for is that you may have very similar repetitions of the
-same sounds in your dataset, and these can give misleading metrics if they're
-spread across your training, validation, and test sets. For example, the Speech
-Commands set has people repeating the same word multiple times. Each one of
-those repetitions is likely to be pretty close to the others, so if training was
-overfitting and memorizing one, it could perform unrealistically well when it
-saw a very similar copy in the test set. To avoid this danger, Speech Commands
-tries to ensure that all clips featuring the same word spoken by a single person
-are put into the same partition. Clips are assigned to training, test, or
-validation sets based on a hash of their filename, to ensure that the
-assignments remain steady even as new clips are added and avoid any training
-samples migrating into the other sets. To make sure that all of a given speaker's
-words are in the same bucket, [the hashing
-function](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/input_data.py)
-ignores anything in a filename after '_nohash_' when calculating the
-assignments. This means that if you have file names like `pete_nohash_0.wav` and
-`pete_nohash_1.wav`, they're guaranteed to be in the same set.
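-
-For reference, here is a condensed Python rendering of that hashing logic,
-close to the `which_set` function in input_data.py (treat details such as the
-modulus constant as a paraphrase rather than the authoritative source):
-
-```python
-import hashlib
-import os
-import re
-
-MAX_NUM_WAVS_PER_CLASS = 2**27 - 1  # keeps the hash arithmetic well-behaved
-
-def which_set(filename, validation_percentage=10, testing_percentage=10):
-  base_name = os.path.basename(filename)
-  # Ignore everything after '_nohash_' so one speaker's clips stay together.
-  hash_name = re.sub(r'_nohash_.*$', '', base_name)
-  hash_value = int(hashlib.sha1(hash_name.encode()).hexdigest(), 16)
-  percentage_hash = ((hash_value % (MAX_NUM_WAVS_PER_CLASS + 1)) *
-                     (100.0 / MAX_NUM_WAVS_PER_CLASS))
-  if percentage_hash < validation_percentage:
-    return 'validation'
-  elif percentage_hash < validation_percentage + testing_percentage:
-    return 'testing'
-  return 'training'
-```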
-
-### Unknown Class
-
-It's likely that your application will hear sounds that aren't in your training
-set, and you'll want the model to indicate that it doesn't recognize the noise
-in those cases. To help the network learn what sounds to ignore, you need to
-provide some clips of audio that are neither of your classes. To do this, you'd
-create `quack`, `oink`, and `moo` subfolders and populate them with noises from
-other animals your users might encounter. The `--wanted_words` argument to the
-script defines which classes you care about; all the others mentioned in
-subfolder names will be used to populate an `_unknown_` class during training.
-The Speech Commands dataset has twenty words in its unknown classes, including
-the digits zero through nine and random names like "Sheila".
-
-By default 10% of the training examples are picked from the unknown classes, but
-you can control this with the `--unknown_percentage` flag. Increasing this will
-make the model less likely to mistake unknown words for wanted ones, but making
-it too large can backfire as the model might decide it's safest to categorize
-all words as unknown!
-
-### Background Noise
-
-Real applications have to recognize audio even when there are other irrelevant
-sounds happening in the environment. To build a model that's robust to this kind
-of interference, we need to train against recorded audio with similar
-properties. The files in the Speech Commands dataset were captured on a variety
-of devices by users in many different environments, not in a studio, so that
-helps add some realism to the training. To add even more, you can mix in random
-segments of environmental audio to the training inputs. In the Speech Commands
-set there's a special folder called `_background_noise_` which contains
-minute-long WAVE files with white noise and recordings of machinery and everyday
-household activity.
-
-Small snippets of these files are chosen at random and mixed at a low volume
-into clips during training. The loudness is also chosen randomly, and controlled
-by the `--background_volume` argument as a proportion where 0 is silence, and 1
-is full volume. Not all clips have background added, so the
-`--background_frequency` flag controls what proportion have them mixed in.
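-
-A minimal numpy sketch of that mixing step (the function and clipping range
-are assumptions for illustration; the real pipeline does the equivalent inside
-the TensorFlow graph):
-
-```python
-import numpy as np
-
-def mix_background(clip, background, volume=0.1):
-  """Sketch: add a random background snippet to a clip at a given volume."""
-  offset = np.random.randint(0, len(background) - len(clip))
-  snippet = background[offset:offset + len(clip)]
-  # Keep samples in the valid [-1, 1] range after mixing.
-  return np.clip(clip + volume * snippet, -1.0, 1.0)
-```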
-
-Your own application might operate in its own environment with different
-background noise patterns than these defaults, so you can supply your own audio
-clips in the `_background_noise_` folder. These should be the same sample rate
-as your main dataset, but much longer in duration so that a good set of random
-segments can be selected from them.
-
-### Silence
-
-In most cases the sounds you care about will be intermittent and so it's
-important to know when there's no matching audio. To support this, there's a
-special `_silence_` label that indicates when the model detects nothing
-interesting. Because there's never complete silence in real environments, we
-actually have to supply examples with quiet and irrelevant audio. For this, we
-reuse the `_background_noise_` folder that's also mixed in to real clips,
-pulling short sections of the audio data and feeding those in with the ground
-truth class of `_silence_`. By default 10% of the training data is supplied like
-this, but the `--silence_percentage` can be used to control the proportion. As
-with unknown words, setting this higher can weight the model results in favor of
-true positives for silence, at the expense of false negatives for words, but too
-large a proportion can cause it to fall into the trap of always guessing
-silence.
-
-### Time Shifting
-
-Adding in background noise is one way of distorting the training data in a
-realistic way to effectively increase the size of the dataset, and so increase
-overall accuracy, and time shifting is another. This involves a random offset in
-time of the training sample data, so that a small part of the start or end is
-cut off and the opposite section is padded with zeroes. This mimics the natural
-variations in starting time in the training data, and is controlled with the
-`--time_shift_ms` flag, which defaults to 100ms. Increasing this value will
-provide more variation, but at the risk of cutting off important parts of the
-audio. A related way of augmenting the data with realistic distortions is by
-using [time stretching and pitch
-scaling](https://en.wikipedia.org/wiki/Audio_time_stretching_and_pitch_scaling),
-but that's outside the scope of this tutorial.
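-
-A minimal numpy sketch of such a shift (the zero-padding detail follows the
-description above; the script applies an equivalent transformation inside its
-input pipeline):
-
-```python
-import numpy as np
-
-def time_shift(samples, shift_ms=100, sample_rate=16000):
-  """Sketch: randomly offset audio in time, zero-padding the gap."""
-  max_shift = int(sample_rate * shift_ms / 1000)
-  shift = np.random.randint(-max_shift, max_shift + 1)
-  if shift > 0:  # cut the end, pad the start with zeros
-    return np.pad(samples, (shift, 0))[:len(samples)]
-  return np.pad(samples, (0, -shift))[-len(samples):]
-```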
-
-## Customizing the Model
-
-The default model used for this script is pretty large, taking over 800 million
-FLOPs for each inference and using 940,000 weight parameters. This runs at
-usable speeds on desktop machines or modern phones, but it involves too many
-calculations to run at interactive speeds on devices with more limited
-resources. To support these use cases, there are a couple of alternatives
-available:
-
-**low_latency_conv**
-Based on the 'cnn-one-fstride4' topology described in the [Convolutional
-Neural Networks for Small-footprint Keyword Spotting
-paper](http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf).
-The accuracy is slightly lower than 'conv' but the number of weight parameters
-is about the same, and it only needs 11 million FLOPs to run one prediction,
-making it much faster.
-
-To use this model, you specify `--model_architecture=low_latency_conv` on
-the command line. You'll also need to update the training rates and the number
-of steps, so the full command will look like:
-
-```
-python tensorflow/examples/speech_commands/train.py \
---model_architecture=low_latency_conv \
---how_many_training_steps=20000,6000 \
---learning_rate=0.01,0.001
-```
-
-This asks the script to train with a learning rate of 0.01 for 20,000 steps, and
-then do a fine-tuning pass of 6,000 steps with a 10x smaller rate.
-
-**low_latency_svdf**
-Based on the topology presented in the [Compressing Deep Neural Networks using a
-Rank-Constrained Topology paper](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/43813.pdf).
-The accuracy is also lower than 'conv' but it only uses about 750 thousand
-parameters, and most significantly, it allows for an optimized execution at
-test time (i.e. when you will actually use it in your application), resulting
-in 750 thousand FLOPs.
-
-To use this model, you specify `--model_architecture=low_latency_svdf` on
-the command line, and update the training rates and the number
-of steps, so the full command will look like:
-
-```
-python tensorflow/examples/speech_commands/train.py \
---model_architecture=low_latency_svdf \
---how_many_training_steps=100000,35000 \
---learning_rate=0.01,0.005
-```
-
-Note that despite requiring a larger number of steps than the previous two
-topologies, the reduced number of computations means that training should take
-about the same time, and at the end reach an accuracy of around 85%.
-You can also further tune the topology fairly easily for computation and
-accuracy by changing these parameters in the SVDF layer:
-
-* rank - The rank of the approximation (higher is typically better, but
-  results in more computation).
-* num_units - Similar to other layer types, specifies the number of nodes in
-  the layer (more nodes give better quality, at the cost of more computation).
-
-Regarding runtime, since the layer allows optimizations by caching some of the
-internal neural network activations, you need to make sure to use a consistent
-stride (e.g. the `--clip_stride_ms` flag) both when you freeze the graph and
-when executing the model in streaming mode (e.g. in test_streaming_accuracy.cc).
-
-**Other parameters to customize**
-If you want to experiment with customizing models, a good place to start is by
-tweaking the spectrogram creation parameters. This has the effect of altering
-the size of the input image to the model, and the creation code in
-[models.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/models.py)
-will adjust the number of computations and weights automatically to fit with
-different dimensions. If you make the input smaller, the model will need fewer
-computations to process it, so it can be a great way to trade off some accuracy
-for improved latency. The `--window_stride_ms` flag controls how far apart each
-frequency analysis sample is from the previous. If you increase this value, then
-fewer samples will be taken for a given duration, and the time axis of the input
-will shrink. The `--dct_coefficient_count` flag controls how many buckets are
-used for the frequency counting, so reducing this will shrink the input in the
-other dimension. The `--window_size_ms` argument doesn't affect the size, but
-does control how wide the area used to calculate the frequencies is for each
-sample. Reducing the duration of the training samples, controlled by
-`--clip_duration_ms`, can also help if the sounds you're looking for are short,
-since that also reduces the time dimension of the input. You'll need to make
-sure that all your training data contains the right audio in the initial portion
-of the clip though.
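-
-As a rough worked example of how these flags shape the input image (using the
-default values as assumptions; the exact time-step count also depends on
-`--window_size_ms`):
-
-```python
-clip_duration_ms = 1000     # --clip_duration_ms
-window_stride_ms = 10       # --window_stride_ms
-dct_coefficient_count = 40  # --dct_coefficient_count
-
-# Roughly one spectrogram row per stride step across the clip.
-time_steps = clip_duration_ms // window_stride_ms
-print('input is about %d x %d' % (time_steps, dct_coefficient_count))
-# -> input is about 100 x 40; halving either flag halves that dimension.
-```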
-
-If you have an entirely different model in mind for your problem, you may find
-that you can plug it into
-[models.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/models.py)
-and have the rest of the script handle all of the preprocessing and training
-mechanics. You would add a new clause to `create_model`, looking for the name of
-your architecture and then calling a model creation function. This function is
-given the size of the spectrogram input, along with other model information, and
-is expected to create TensorFlow ops to read that in and produce an output
-prediction vector, and a placeholder to control the dropout rate. The rest of
-the script will handle integrating this model into a larger graph doing the
-input calculations and applying softmax and a loss function to train it.
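-
-As a sketch of the shape such a model creation function takes, loosely modeled
-on the simplest architecture already in models.py (the `create_my_model` name
-and the single fully-connected layer are placeholders for your own design):
-
-```python
-import tensorflow as tf
-
-def create_my_model(fingerprint_input, model_settings, is_training):
-  """Placeholder architecture: one fully-connected layer producing logits."""
-  if is_training:
-    dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
-  fingerprint_size = model_settings['fingerprint_size']
-  label_count = model_settings['label_count']
-  weights = tf.Variable(
-      tf.truncated_normal([fingerprint_size, label_count], stddev=0.001))
-  bias = tf.Variable(tf.zeros([label_count]))
-  logits = tf.matmul(fingerprint_input, weights) + bias
-  if is_training:
-    return logits, dropout_prob
-  return logits
-```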
-
-One common problem when you're adjusting models and training hyper-parameters is
-that not-a-number values can creep in, thanks to numerical precision issues. In
-general you can solve these by reducing the magnitude of things like learning
-rates and weight initialization functions, but if they're persistent you can
-enable the `--check_nans` flag to track down the source of the errors. This will
-insert check ops between most regular operations in TensorFlow, and abort the
-training process with a useful error message when they're encountered.
diff --git a/tensorflow/docs_src/tutorials/sequences/recurrent.md b/tensorflow/docs_src/tutorials/sequences/recurrent.md
deleted file mode 100644
index 10d60f7966..0000000000
--- a/tensorflow/docs_src/tutorials/sequences/recurrent.md
+++ /dev/null
@@ -1,230 +0,0 @@
-# Recurrent Neural Networks
-
-## Introduction
-
-See [Understanding LSTM Networks](https://colah.github.io/posts/2015-08-Understanding-LSTMs/){:.external}
-for an introduction to recurrent neural networks and LSTMs.
-
-## Language Modeling
-
-In this tutorial we will show how to train a recurrent neural network on
-the challenging task of language modeling. The goal of the problem is to fit a
-probabilistic model which assigns probabilities to sentences. It does so by
-predicting next words in a text given a history of previous words. For this
-purpose we will use the [Penn Tree Bank](https://catalog.ldc.upenn.edu/ldc99t42)
-(PTB) dataset, which is a popular benchmark for measuring the quality of these
-models, whilst being small and relatively fast to train.
-
-Language modeling is key to many interesting problems such as speech
-recognition, machine translation, or image captioning. It is also fun --
-take a look [here](https://karpathy.github.io/2015/05/21/rnn-effectiveness/).
-
-For the purpose of this tutorial, we will reproduce the results from
-[Zaremba et al., 2014](https://arxiv.org/abs/1409.2329)
-([pdf](https://arxiv.org/pdf/1409.2329.pdf)), which achieves very good quality
-on the PTB dataset.
-
-## Tutorial Files
-
-This tutorial references the following files from `models/tutorials/rnn/ptb` in the [TensorFlow models repo](https://github.com/tensorflow/models):
-
-File | Purpose
---- | ---
-`ptb_word_lm.py` | The code to train a language model on the PTB dataset.
-`reader.py` | The code to read the dataset.
-
-## Download and Prepare the Data
-
-The data required for this tutorial is in the `data/` directory of the
-[PTB dataset from Tomas Mikolov's webpage](http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz).
-
-The dataset is already preprocessed and contains 10,000 different words overall,
-including the end-of-sentence marker and a special symbol (\<unk\>) for rare
-words. In `reader.py`, we convert each word to a unique integer identifier,
-in order to make it easy for the neural network to process the data.
-
-## The Model
-
-### LSTM
-
-The core of the model consists of an LSTM cell that processes one word at a
-time and computes probabilities of the possible values for the next word in the
-sentence. The memory state of the network is initialized with a vector of zeros
-and gets updated after reading each word. For computational reasons, we will
-process data in mini-batches of size `batch_size`. In this example, it is
-important to note that `current_batch_of_words` does not correspond to a
-"sentence" of words. Every word in a batch should correspond to a time t.
-TensorFlow will automatically sum the gradients of each batch for you.
-
-For example:
-
-```
- t=0 t=1 t=2 t=3 t=4
-[The, brown, fox, is, quick]
-[The, red, fox, jumped, high]
-
-words_in_dataset[0] = [The, The]
-words_in_dataset[1] = [brown, red]
-words_in_dataset[2] = [fox, fox]
-words_in_dataset[3] = [is, jumped]
-words_in_dataset[4] = [quick, high]
-batch_size = 2, time_steps = 5
-```
-
-The basic pseudocode is as follows:
-
-```python
-words_in_dataset = tf.placeholder(tf.float32, [time_steps, batch_size, num_features])
-lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)
-# Initial state of the LSTM memory.
-state = lstm.zero_state(batch_size, dtype=tf.float32)
-probabilities = []
-loss = 0.0
-for current_batch_of_words in words_in_dataset:
- # The value of state is updated after processing each batch of words.
- output, state = lstm(current_batch_of_words, state)
-
- # The LSTM output can be used to make next word predictions
- logits = tf.matmul(output, softmax_w) + softmax_b
- probabilities.append(tf.nn.softmax(logits))
- loss += loss_function(probabilities, target_words)
-```
-
-### Truncated Backpropagation
-
-By design, the output of a recurrent neural network (RNN) depends on arbitrarily
-distant inputs. Unfortunately, this makes backpropagation computation difficult.
-In order to make the learning process tractable, it is common practice to create
-an "unrolled" version of the network, which contains a fixed number
-(`num_steps`) of LSTM inputs and outputs. The model is then trained on this
-finite approximation of the RNN. This can be implemented by feeding inputs of
-length `num_steps` at a time and performing a backward pass after each
-such input block.
-
-Here is a simplified block of code for creating a graph which performs
-truncated backpropagation:
-
-```python
-# Placeholder for the inputs in a given iteration.
-words = tf.placeholder(tf.int32, [batch_size, num_steps])
-
-lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)
-# Initial state of the LSTM memory.
-initial_state = state = lstm.zero_state(batch_size, dtype=tf.float32)
-
-for i in range(num_steps):
- # The value of state is updated after processing each batch of words.
- output, state = lstm(words[:, i], state)
-
- # The rest of the code.
- # ...
-
-final_state = state
-```
-
-And this is how to implement an iteration over the whole dataset:
-
-```python
-# A numpy array holding the state of LSTM after each batch of words.
-numpy_state = initial_state.eval()
-total_loss = 0.0
-for current_batch_of_words in words_in_dataset:
- numpy_state, current_loss = session.run([final_state, loss],
- # Initialize the LSTM state from the previous iteration.
- feed_dict={initial_state: numpy_state, words: current_batch_of_words})
- total_loss += current_loss
-```
-
-### Inputs
-
-The word IDs will be embedded into a dense representation (see the
-@{$word2vec$Vector Representations Tutorial}) before feeding to
-the LSTM. This allows the model to efficiently represent the knowledge about
-particular words. It is also easy to write:
-
-```python
-# embedding_matrix is a tensor of shape [vocabulary_size, embedding_size]
-word_embeddings = tf.nn.embedding_lookup(embedding_matrix, word_ids)
-```
-
-The embedding matrix will be initialized randomly and the model will learn to
-differentiate the meaning of words just by looking at the data.
-
-### Loss Function
-
-We want to minimize the average negative log probability of the target words:
-
-$$ \text{loss} = -\frac{1}{N}\sum_{i=1}^{N} \ln p_{\text{target}_i} $$
-
-It is not very difficult to implement but the function
-`sequence_loss_by_example` is already available, so we can just use it here.
-
-The typical measure reported in the papers is average per-word perplexity (often
-just called perplexity), which is equal to
-
-$$e^{-\frac{1}{N}\sum_{i=1}^{N} \ln p_{\text{target}_i}} = e^{\text{loss}} $$
-
-and we will monitor its value throughout the training process.
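-
-As a quick numeric check of that relationship (the loss value here is made
-up):
-
-```python
-import math
-
-loss = 4.6  # hypothetical average per-word negative log probability
-print(math.exp(loss))  # ~99.5: as uncertain as a ~100-way uniform choice
-```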
-
-### Stacking multiple LSTMs
-
-To give the model more expressive power, we can add multiple layers of LSTMs
-to process the data. The output of the first layer will become the input of
-the second and so on.
-
-We have a class called `MultiRNNCell` that makes the implementation seamless:
-
-```python
-def lstm_cell():
- return tf.contrib.rnn.BasicLSTMCell(lstm_size)
-stacked_lstm = tf.contrib.rnn.MultiRNNCell(
- [lstm_cell() for _ in range(number_of_layers)])
-
-initial_state = state = stacked_lstm.zero_state(batch_size, tf.float32)
-for i in range(num_steps):
- # The value of state is updated after processing each batch of words.
- output, state = stacked_lstm(words[:, i], state)
-
- # The rest of the code.
- # ...
-
-final_state = state
-```
-
-## Run the Code
-
-Before running the code, download the PTB dataset, as discussed at the beginning
-of this tutorial. Then, extract the PTB dataset underneath your home directory
-as follows:
-
-```bash
-tar xvfz simple-examples.tgz -C $HOME
-```
-_(Note: On Windows, you may need to use
-[other tools](https://wiki.haskell.org/How_to_unpack_a_tar_file_in_Windows).)_
-
-Now, clone the [TensorFlow models repo](https://github.com/tensorflow/models)
-from GitHub. Run the following commands:
-
-```bash
-cd models/tutorials/rnn/ptb
-python ptb_word_lm.py --data_path=$HOME/simple-examples/data/ --model=small
-```
-
-There are 3 supported model configurations in the tutorial code: "small",
-"medium" and "large". The difference between them is in size of the LSTMs and
-the set of hyperparameters used for training.
-
-The larger the model, the better results it should get. The `small` model should
-be able to reach perplexity below 120 on the test set and the `large` one below
-80, though it might take several hours to train.
-
-## What Next?
-
-There are several tricks that we haven't mentioned that make the model better,
-including:
-
-* a decreasing learning rate schedule,
-* dropout between the LSTM layers (both sketched below).
-
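-As a rough sketch of those two tricks, reusing the free variable names from
-the pseudocode above (`keep_prob`, `initial_lr`, and `lr_decay` are
-placeholders you would define yourself):
-
-```python
-import tensorflow as tf
-
-def lstm_cell():
-  cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)
-  # Dropout on each layer's output; use keep_prob = 1.0 at evaluation time.
-  return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)
-
-stacked_lstm = tf.contrib.rnn.MultiRNNCell(
-    [lstm_cell() for _ in range(number_of_layers)])
-
-# A decreasing learning rate schedule: shrink the rate as training progresses.
-lr = tf.Variable(initial_lr, trainable=False)
-lr_update = tf.assign(lr, lr * lr_decay)  # run this op e.g. once per epoch
-```
-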
-Study the code and modify it to improve the model even further.
diff --git a/tensorflow/docs_src/tutorials/sequences/recurrent_quickdraw.md b/tensorflow/docs_src/tutorials/sequences/recurrent_quickdraw.md
deleted file mode 100644
index 37bce5b76d..0000000000
--- a/tensorflow/docs_src/tutorials/sequences/recurrent_quickdraw.md
+++ /dev/null
@@ -1,411 +0,0 @@
-# Recurrent Neural Networks for Drawing Classification
-
-[Quick, Draw!]: http://quickdraw.withgoogle.com
-
-[Quick, Draw!] is a game where a player is challenged to draw a number of
-objects and see if a computer can recognize the drawing.
-
-The recognition in [Quick, Draw!] is performed by a classifier that takes the
-user input, given as a sequence of strokes of points in x and y, and recognizes
-the object category that the user tried to draw.
-
-In this tutorial we'll show how to build an RNN-based recognizer for this
-problem. The model will use a combination of convolutional layers, LSTM layers,
-and a softmax output layer to classify the drawings:
-
-<center> ![RNN model structure](../../images/quickdraw_model.png) </center>
-
-The figure above shows the structure of the model that we will build in this
-tutorial. The input is a drawing that is encoded as a sequence of strokes of
-points in x, y, and n, where n indicates whether the point is the first point
-in a new stroke.
-
-First, a series of 1-dimensional convolutions is applied. Then LSTM layers are
-applied and the sum of the outputs of all LSTM steps is fed into a softmax layer
-to make a classification decision among the classes of drawings that we know.
-
-This tutorial uses the data from actual [Quick, Draw!] games [that is publicly
-available](https://quickdraw.withgoogle.com/data). This dataset contains 50M
-drawings in 345 categories.
-
-## Run the tutorial code
-
-To try the code for this tutorial:
-
-1. @{$install$Install TensorFlow} if you haven't already.
-1. Download the [tutorial code](https://github.com/tensorflow/models/tree/master/tutorials/rnn/quickdraw/train_model.py).
-1. [Download the data](#download-the-data) in `TFRecord` format from
- [here](http://download.tensorflow.org/data/quickdraw_tutorial_dataset_v1.tar.gz) and unzip it. More details about [how to
- obtain the original Quick, Draw!
- data](#optional_download_the_full_quick_draw_data) and [how to convert that
-   to `TFRecord` files](#optional_converting_the_data) are available below.
-
-1. Execute the tutorial code with the following command to train the RNN-based
- model described in this tutorial. Make sure to adjust the paths to point to
- the unzipped data from the download in step 3.
-
-```shell
- python train_model.py \
- --training_data=rnn_tutorial_data/training.tfrecord-?????-of-????? \
- --eval_data=rnn_tutorial_data/eval.tfrecord-?????-of-????? \
- --classes_file=rnn_tutorial_data/training.tfrecord.classes
-```
-
-## Tutorial details
-
-### Download the data
-
-We make the data that we use in this tutorial available as `TFRecord` files
-containing `TFExamples`. You can download the data from here:
-
-http://download.tensorflow.org/data/quickdraw_tutorial_dataset_v1.tar.gz
-
-Alternatively you can download the original data in `ndjson` format from the
-Google Cloud and convert it to the `TFRecord` files containing `TFExamples`
-yourself as described in the next section.
-
-### Optional: Download the full Quick Draw Data
-
-The full [Quick, Draw!](https://quickdraw.withgoogle.com)
-[dataset](https://quickdraw.withgoogle.com/data) is available on Google Cloud
-Storage as [ndjson](http://ndjson.org/) files separated by category. You can
-[browse the list of files in Cloud
-Console](https://console.cloud.google.com/storage/quickdraw_dataset).
-
-To download the entire dataset we recommend using
-[gsutil](https://cloud.google.com/storage/docs/gsutil_install#install). Note
-that the original .ndjson files amount to roughly 22GB of download.
-
-Then use the following command to check that your gsutil installation works and
-that you can access the data bucket:
-
-```shell
-gsutil ls -r "gs://quickdraw_dataset/full/simplified/*"
-```
-
-which will output a long list of files like the following:
-
-```shell
-gs://quickdraw_dataset/full/simplified/The Eiffel Tower.ndjson
-gs://quickdraw_dataset/full/simplified/The Great Wall of China.ndjson
-gs://quickdraw_dataset/full/simplified/The Mona Lisa.ndjson
-gs://quickdraw_dataset/full/simplified/aircraft carrier.ndjson
-...
-```
-
-Then create a folder and download the dataset there.
-
-```shell
-mkdir rnn_tutorial_data
-cd rnn_tutorial_data
-gsutil -m cp "gs://quickdraw_dataset/full/simplified/*" .
-```
-
-This will take a while and fetch a bit more than 23GB of data.
-
-### Optional: Converting the data
-
-To convert the `ndjson` files to
-@{$python/python_io#TFRecords_Format_Details$TFRecord} files containing
-[`tf.train.Example`](https://www.tensorflow.org/code/tensorflow/core/example/example.proto)
-protos, run the following command:
-
-```shell
- python create_dataset.py --ndjson_path rnn_tutorial_data \
- --output_path rnn_tutorial_data
-```
-
-This will store the data in 10 shards of
-@{$python/python_io#TFRecords_Format_Details$TFRecord} files with 10000 items
-per class for the training data and 1000 items per class as eval data.
-
-This conversion process is described in more detail below.
-
-The original QuickDraw data is formatted as `ndjson` files where each line
-contains a JSON object like the following:
-
-```json
-{"word":"cat",
- "countrycode":"VE",
- "timestamp":"2017-03-02 23:25:10.07453 UTC",
- "recognized":true,
- "key_id":"5201136883597312",
- "drawing":[
- [
- [130,113,99,109,76,64,55,48,48,51,59,86,133,154,170,203,214,217,215,208,186,176,162,157,132],
- [72,40,27,79,82,88,100,120,134,152,165,184,189,186,179,152,131,114,100,89,76,0,31,65,70]
- ],[
- [76,28,7],
- [136,128,128]
- ],[
- [76,23,0],
- [160,164,175]
- ],[
- [87,52,37],
- [175,191,204]
- ],[
- [174,220,246,251],
- [134,132,136,139]
- ],[
- [175,255],
- [147,168]
- ],[
- [171,208,215],
- [164,198,210]
- ],[
- [130,110,108,111,130,139,139,119],
- [129,134,137,144,148,144,136,130]
- ],[
- [107,106],
- [96,113]
- ]
- ]
-}
-```
-
-For our purpose of building a classifier we only care about the fields "`word`"
-and "`drawing`". While parsing the ndjson files, we process them line by line
-using a function that converts the strokes from the `drawing` field into a
-tensor of size `[number of points, 3]` containing the differences of consecutive
-points. This function also returns the class name as a string.
-
-```python
-def parse_line(ndjson_line):
- """Parse an ndjson line and return ink (as np array) and classname."""
- sample = json.loads(ndjson_line)
- class_name = sample["word"]
- inkarray = sample["drawing"]
- stroke_lengths = [len(stroke[0]) for stroke in inkarray]
- total_points = sum(stroke_lengths)
- np_ink = np.zeros((total_points, 3), dtype=np.float32)
- current_t = 0
- for stroke in inkarray:
- for i in [0, 1]:
- np_ink[current_t:(current_t + len(stroke[0])), i] = stroke[i]
- current_t += len(stroke[0])
- np_ink[current_t - 1, 2] = 1 # stroke_end
- # Preprocessing.
- # 1. Size normalization.
- lower = np.min(np_ink[:, 0:2], axis=0)
- upper = np.max(np_ink[:, 0:2], axis=0)
- scale = upper - lower
- scale[scale == 0] = 1
- np_ink[:, 0:2] = (np_ink[:, 0:2] - lower) / scale
- # 2. Compute deltas.
-  np_ink[1:, 0:2] -= np_ink[0:-1, 0:2]
-  np_ink = np_ink[1:, :]
- return np_ink, class_name
-```
-
-Since we want the data to be shuffled for writing we read from each of the
-category files in random order and write to a random shard.
-
-For the training data we read the first 10000 items for each class and for the
-eval data we read the next 1000 items for each class.
-
-This data is then reformatted into a tensor of shape `[num_training_samples,
-max_length, 3]`. Then we determine the bounding box of the original drawing in
-screen coordinates and normalize the size such that the drawing has unit height.
-
-<center> ![Size normalization](../../images/quickdraw_sizenormalization.png) </center>
-
-Finally, we compute the differences between consecutive points and store these
-as a `VarLenFeature` in a
-[tensorflow.Example](https://www.tensorflow.org/code/tensorflow/core/example/example.proto)
-under the key `ink`. In addition we store the `class_index` as a single entry
-`FixedLengthFeature` and the `shape` of the `ink` as a `FixedLengthFeature` of
-length 2.
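-
-A sketch of how one of these `Example` protos might be assembled (the helper
-name is hypothetical; the feature keys follow the description above):
-
-```python
-import tensorflow as tf
-
-def make_example(np_ink, class_index):
-  """Sketch: wrap one preprocessed drawing as a tf.train.Example."""
-  features = {
-      'ink': tf.train.Feature(
-          float_list=tf.train.FloatList(value=np_ink.flatten())),
-      'class_index': tf.train.Feature(
-          int64_list=tf.train.Int64List(value=[class_index])),
-      'shape': tf.train.Feature(
-          int64_list=tf.train.Int64List(value=list(np_ink.shape))),
-  }
-  return tf.train.Example(features=tf.train.Features(feature=features))
-```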
-
-### Defining the model
-
-To define the model we create a new `Estimator`. If you want to read more about
-estimators, we recommend @{$custom_estimators$this tutorial}.
-
-To build the model, we:
-
-1. reshape the input back into the original shape, where the mini-batch is
- padded to the maximal length of its contents. In addition to the ink data we
- also have the lengths for each example and the target class. This happens in
- the function [`_get_input_tensors`](#-get-input-tensors).
-
-1. pass the input through to a series of convolution layers in
- [`_add_conv_layers`](#-add-conv-layers).
-
-1. pass the output of the convolutions into a series of bidirectional LSTM
- layers in [`_add_rnn_layers`](#-add-rnn-layers). At the end of that, the
- outputs for each time step are summed up to have a compact, fixed length
- embedding of the input.
-
-1. classify this embedding using a softmax layer in
- [`_add_fc_layers`](#-add-fc-layers).
-
-In code this looks like:
-
-```python
-inks, lengths, targets = _get_input_tensors(features, targets)
-convolved = _add_conv_layers(inks)
-final_state = _add_rnn_layers(convolved, lengths)
-logits = _add_fc_layers(final_state)
-```
-
-### _get_input_tensors
-
-To obtain the input features we first obtain the shape from the features dict
-and then create a 1D tensor of size `[batch_size]` containing the lengths of the
-input sequences. The ink is stored as a SparseTensor in the features dict which
-we convert into a dense tensor and then reshape to be `[batch_size, ?, 3]`. And
-finally, if targets were passed in we make sure they are stored as a 1D tensor
-of size `[batch_size]`.
-
-In code this looks like this:
-
-```python
-shapes = features["shape"]
-lengths = tf.squeeze(
- tf.slice(shapes, begin=[0, 0], size=[params["batch_size"], 1]))
-inks = tf.reshape(
- tf.sparse_tensor_to_dense(features["ink"]),
- [params["batch_size"], -1, 3])
-if targets is not None:
- targets = tf.squeeze(targets)
-```
-
-### _add_conv_layers
-
-The desired number of convolution layers and the lengths of the filters are
-configured through the parameters `num_conv` and `conv_len` in the `params`
-dict.
-
-The input is a sequence where each point has dimensionality 3. We are going to
-use 1D convolutions where we treat the 3 input features as channels. That means
-that the input is a `[batch_size, length, 3]` tensor and the output will be a
-`[batch_size, length, number_of_filters]` tensor.
-
-```python
-convolved = inks
-for i in range(len(params.num_conv)):
- convolved_input = convolved
- if params.batch_norm:
- convolved_input = tf.layers.batch_normalization(
- convolved_input,
- training=(mode == tf.estimator.ModeKeys.TRAIN))
- # Add dropout layer if enabled and not first convolution layer.
- if i > 0 and params.dropout:
- convolved_input = tf.layers.dropout(
- convolved_input,
- rate=params.dropout,
- training=(mode == tf.estimator.ModeKeys.TRAIN))
- convolved = tf.layers.conv1d(
- convolved_input,
- filters=params.num_conv[i],
- kernel_size=params.conv_len[i],
- activation=None,
- strides=1,
- padding="same",
- name="conv1d_%d" % i)
-return convolved, lengths
-```
-
-### _add_rnn_layers
-
-We pass the output from the convolutions into bidirectional LSTM layers for
-which we use a helper function from contrib.
-
-```python
-outputs, _, _ = contrib_rnn.stack_bidirectional_dynamic_rnn(
- cells_fw=[cell(params.num_nodes) for _ in range(params.num_layers)],
- cells_bw=[cell(params.num_nodes) for _ in range(params.num_layers)],
- inputs=convolved,
- sequence_length=lengths,
- dtype=tf.float32,
- scope="rnn_classification")
-```
-
-See the code for more details and for how to use `CUDA`-accelerated
-implementations.
-
-To create a compact, fixed-length embedding, we sum up the output of the LSTMs.
-We first zero out the regions of the batch where the sequences have no data.
-
-```python
-mask = tf.tile(
- tf.expand_dims(tf.sequence_mask(lengths, tf.shape(outputs)[1]), 2),
- [1, 1, tf.shape(outputs)[2]])
-zero_outside = tf.where(mask, outputs, tf.zeros_like(outputs))
-outputs = tf.reduce_sum(zero_outside, axis=1)
-```
-
-### _add_fc_layers
-
-The embedding of the input is passed into a fully connected layer, which we then
-use as a softmax layer.
-
-```python
-tf.layers.dense(final_state, params.num_classes)
-```
-
-### Loss, predictions, and optimizer
-
-Finally, we need to add a loss, a training op, and predictions to create the
-`ModelFn`:
-
-```python
-cross_entropy = tf.reduce_mean(
- tf.nn.sparse_softmax_cross_entropy_with_logits(
- labels=targets, logits=logits))
-# Add the optimizer.
-train_op = tf.contrib.layers.optimize_loss(
- loss=cross_entropy,
- global_step=tf.train.get_global_step(),
- learning_rate=params.learning_rate,
- optimizer="Adam",
- # some gradient clipping stabilizes training in the beginning.
- clip_gradients=params.gradient_clipping_norm,
- summaries=["learning_rate", "loss", "gradients", "gradient_norm"])
-predictions = tf.argmax(logits, axis=1)
-return model_fn_lib.ModelFnOps(
- mode=mode,
- predictions={"logits": logits,
- "predictions": predictions},
- loss=cross_entropy,
- train_op=train_op,
- eval_metric_ops={"accuracy": tf.metrics.accuracy(targets, predictions)})
-```
-
-### Training and evaluating the model
-
-To train and evaluate the model we can rely on the functionality of the
-`Estimator` APIs and easily run training and evaluation with the `Experiment`
-APIs:
-
-```python
- estimator = tf.estimator.Estimator(
- model_fn=model_fn,
- model_dir=output_dir,
- config=config,
- params=model_params)
- # Train the model.
- tf.contrib.learn.Experiment(
- estimator=estimator,
- train_input_fn=get_input_fn(
- mode=tf.contrib.learn.ModeKeys.TRAIN,
- tfrecord_pattern=FLAGS.training_data,
- batch_size=FLAGS.batch_size),
- train_steps=FLAGS.steps,
- eval_input_fn=get_input_fn(
- mode=tf.contrib.learn.ModeKeys.EVAL,
- tfrecord_pattern=FLAGS.eval_data,
- batch_size=FLAGS.batch_size),
- min_eval_frequency=1000)
-```
-
-Note that this tutorial is just a quick example on a relatively small dataset to
-get you familiar with the APIs of recurrent neural networks and estimators. Such
-models can be even more powerful if you try them on a large dataset.
-
-When training the model for 1M steps you can expect to get an accuracy of
-approximately 70% on the top-1 candidate. Note that this accuracy is sufficient
-to build the quickdraw game, because the game dynamics let the user adjust their
-drawing until it is ready. Also, the game does not use the top-1 candidate only,
-but accepts a drawing as correct if the target category shows up with a score
-better than a fixed threshold.
diff --git a/tensorflow/examples/adding_an_op/cuda_op_test.py b/tensorflow/examples/adding_an_op/cuda_op_test.py
index 07390bc3bf..a9aaa81e3f 100644
--- a/tensorflow/examples/adding_an_op/cuda_op_test.py
+++ b/tensorflow/examples/adding_an_op/cuda_op_test.py
@@ -26,7 +26,7 @@ class AddOneTest(tf.test.TestCase):
def test(self):
if tf.test.is_built_with_cuda():
- with self.test_session():
+ with self.cached_session():
result = cuda_op.add_one([5, 4, 3, 2, 1])
self.assertAllEqual(result.eval(), [6, 5, 4, 3, 2])
diff --git a/tensorflow/examples/adding_an_op/fact_test.py b/tensorflow/examples/adding_an_op/fact_test.py
index f7f17e5180..11163e7ba5 100644
--- a/tensorflow/examples/adding_an_op/fact_test.py
+++ b/tensorflow/examples/adding_an_op/fact_test.py
@@ -24,7 +24,7 @@ import tensorflow as tf
class FactTest(tf.test.TestCase):
def test(self):
- with self.test_session():
+ with self.cached_session():
print(tf.user_ops.my_fact().eval())
diff --git a/tensorflow/examples/adding_an_op/zero_out_1_test.py b/tensorflow/examples/adding_an_op/zero_out_1_test.py
index fac486100d..342d3a020c 100644
--- a/tensorflow/examples/adding_an_op/zero_out_1_test.py
+++ b/tensorflow/examples/adding_an_op/zero_out_1_test.py
@@ -28,7 +28,7 @@ from tensorflow.examples.adding_an_op import zero_out_op_1
class ZeroOut1Test(tf.test.TestCase):
def test(self):
- with self.test_session():
+ with self.cached_session():
result = zero_out_op_1.zero_out([5, 4, 3, 2, 1])
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
diff --git a/tensorflow/examples/adding_an_op/zero_out_2_test.py b/tensorflow/examples/adding_an_op/zero_out_2_test.py
index 217bbbcffa..4504597817 100644
--- a/tensorflow/examples/adding_an_op/zero_out_2_test.py
+++ b/tensorflow/examples/adding_an_op/zero_out_2_test.py
@@ -29,17 +29,17 @@ from tensorflow.examples.adding_an_op import zero_out_op_2
class ZeroOut2Test(tf.test.TestCase):
def test(self):
- with self.test_session():
+ with self.cached_session():
result = zero_out_op_2.zero_out([5, 4, 3, 2, 1])
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
def test_2d(self):
- with self.test_session():
+ with self.cached_session():
result = zero_out_op_2.zero_out([[6, 5, 4], [3, 2, 1]])
self.assertAllEqual(result.eval(), [[6, 0, 0], [0, 0, 0]])
def test_grad(self):
- with self.test_session():
+ with self.cached_session():
shape = (5,)
x = tf.constant([5, 4, 3, 2, 1], dtype=tf.float32)
y = zero_out_op_2.zero_out(x)
@@ -47,7 +47,7 @@ class ZeroOut2Test(tf.test.TestCase):
self.assertLess(err, 1e-4)
def test_grad_2d(self):
- with self.test_session():
+ with self.cached_session():
shape = (2, 3)
x = tf.constant([[6, 5, 4], [3, 2, 1]], dtype=tf.float32)
y = zero_out_op_2.zero_out(x)
diff --git a/tensorflow/examples/adding_an_op/zero_out_3_test.py b/tensorflow/examples/adding_an_op/zero_out_3_test.py
index 01280caf49..15d62495aa 100644
--- a/tensorflow/examples/adding_an_op/zero_out_3_test.py
+++ b/tensorflow/examples/adding_an_op/zero_out_3_test.py
@@ -26,23 +26,23 @@ from tensorflow.examples.adding_an_op import zero_out_op_3
class ZeroOut3Test(tf.test.TestCase):
def test(self):
- with self.test_session():
+ with self.cached_session():
result = zero_out_op_3.zero_out([5, 4, 3, 2, 1])
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
def testAttr(self):
- with self.test_session():
+ with self.cached_session():
result = zero_out_op_3.zero_out([5, 4, 3, 2, 1], preserve_index=3)
self.assertAllEqual(result.eval(), [0, 0, 0, 2, 0])
def testNegative(self):
- with self.test_session():
+ with self.cached_session():
result = zero_out_op_3.zero_out([5, 4, 3, 2, 1], preserve_index=-1)
with self.assertRaisesOpError("Need preserve_index >= 0, got -1"):
result.eval()
def testLarge(self):
- with self.test_session():
+ with self.cached_session():
result = zero_out_op_3.zero_out([5, 4, 3, 2, 1], preserve_index=17)
with self.assertRaisesOpError("preserve_index out of range"):
result.eval()
diff --git a/tensorflow/examples/android/README.md b/tensorflow/examples/android/README.md
index dac9b7ab82..82bc3ffda9 100644
--- a/tensorflow/examples/android/README.md
+++ b/tensorflow/examples/android/README.md
@@ -121,10 +121,6 @@ the Android NDK and SDK must be installed on your system.
2. The Android NDK is required to build the native (C/C++) TensorFlow code. The
current recommended version is 14b, which may be found
[here](https://developer.android.com/ndk/downloads/older_releases.html#ndk-14b-downloads).
-
- * NDK 16, the revision released in November 2017, is **incompatible** with
- Bazel. See [here](https://github.com/tensorflow/tensorflow/issues/14918).
-
3. The Android SDK and build tools may be obtained
[here](https://developer.android.com/tools/revisions/build-tools.html), or
alternatively as part of [Android
@@ -132,10 +128,6 @@ the Android NDK and SDK must be installed on your system.
23 is required to build the TF Android demo (though it will run on API >= 21
devices).
- - The Android Studio SDK Manager's NDK installer will install the latest
- revision of the NDK, which is **incompatible** with Bazel. You'll need
- to download an older version manually, as (2) suggests.
-
##### Edit WORKSPACE
NOTE: As long as you have the SDK and NDK installed, the `./configure` script
diff --git a/tensorflow/examples/android/jni/object_tracking/jni_utils.h b/tensorflow/examples/android/jni/object_tracking/jni_utils.h
index b81d9e0c12..06048ecfd3 100644
--- a/tensorflow/examples/android/jni/object_tracking/jni_utils.h
+++ b/tensorflow/examples/android/jni/object_tracking/jni_utils.h
@@ -60,4 +60,4 @@ class JniLongField {
jfieldID field_ID_;
};
-#endif
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/logging.h b/tensorflow/examples/android/jni/object_tracking/logging.h
index 852a749399..24d05e3398 100644
--- a/tensorflow/examples/android/jni/object_tracking/logging.h
+++ b/tensorflow/examples/android/jni/object_tracking/logging.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_
-#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOGGING_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOGGING_H_
#include <android/log.h>
#include <string.h>
@@ -118,4 +118,4 @@ void LogPrintF(const int severity, const char* format, ...);
#endif
-#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOGGING_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/object_model.h b/tensorflow/examples/android/jni/object_tracking/object_model.h
index 5e81c49080..4bc4d5bc9e 100644
--- a/tensorflow/examples/android/jni/object_tracking/object_model.h
+++ b/tensorflow/examples/android/jni/object_tracking/object_model.h
@@ -19,8 +19,8 @@ limitations under the License.
// Contains ObjectModelBase declaration.
-#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_
-#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_MODEL_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_MODEL_H_
#ifdef __RENDER_OPENGL__
#include <GLES/gl.h>
@@ -99,4 +99,4 @@ class ObjectModel : public ObjectModelBase {
} // namespace tf_tracking
-#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_MODEL_H_
diff --git a/tensorflow/examples/android/jni/rgb2yuv.h b/tensorflow/examples/android/jni/rgb2yuv.h
index 13ac4148f3..ff720fda7d 100755
--- a/tensorflow/examples/android/jni/rgb2yuv.h
+++ b/tensorflow/examples/android/jni/rgb2yuv.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef ORG_TENSORFLOW_JNI_IMAGEUTILS_RGB2YUV_H_
-#define ORG_TENSORFLOW_JNI_IMAGEUTILS_RGB2YUV_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_RGB2YUV_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_RGB2YUV_H_
#include <stdint.h>
@@ -32,4 +32,4 @@ void ConvertRGB565ToYUV420SP(const uint16_t* const input, uint8_t* const output,
}
#endif
-#endif // ORG_TENSORFLOW_JNI_IMAGEUTILS_RGB2YUV_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_RGB2YUV_H_
diff --git a/tensorflow/examples/android/jni/yuv2rgb.h b/tensorflow/examples/android/jni/yuv2rgb.h
index 7d2b8ab7f4..fab462f0e1 100644
--- a/tensorflow/examples/android/jni/yuv2rgb.h
+++ b/tensorflow/examples/android/jni/yuv2rgb.h
@@ -16,8 +16,8 @@ limitations under the License.
// This is a collection of routines which converts various YUV image formats
// to (A)RGB.
-#ifndef ORG_TENSORFLOW_JNI_IMAGEUTILS_YUV2RGB_H_
-#define ORG_TENSORFLOW_JNI_IMAGEUTILS_YUV2RGB_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_YUV2RGB_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_YUV2RGB_H_
#include <stdint.h>
@@ -54,4 +54,4 @@ void ConvertYUV420SPToRGB565(const uint8_t* const input, uint16_t* const output,
}
#endif
-#endif // ORG_TENSORFLOW_JNI_IMAGEUTILS_YUV2RGB_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_YUV2RGB_H_
diff --git a/tensorflow/examples/ios/benchmark/ios_image_load.h b/tensorflow/examples/ios/benchmark/ios_image_load.h
index 78eaded8d7..3f94984692 100644
--- a/tensorflow/examples/ios/benchmark/ios_image_load.h
+++ b/tensorflow/examples/ios/benchmark/ios_image_load.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_
-#define TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_
+#ifndef TENSORFLOW_EXAMPLES_IOS_BENCHMARK_IOS_IMAGE_LOAD_H_
+#define TENSORFLOW_EXAMPLES_IOS_BENCHMARK_IOS_IMAGE_LOAD_H_
#include <vector>
@@ -24,4 +24,4 @@ std::vector<tensorflow::uint8> LoadImageFromFile(const char* file_name,
int* out_height,
int* out_channels);
-#endif // TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_
+#endif // TENSORFLOW_EXAMPLES_IOS_BENCHMARK_IOS_IMAGE_LOAD_H_
diff --git a/tensorflow/examples/ios/camera/ios_image_load.h b/tensorflow/examples/ios/camera/ios_image_load.h
index 87a847e145..f10b0b983a 100644
--- a/tensorflow/examples/ios/camera/ios_image_load.h
+++ b/tensorflow/examples/ios/camera/ios_image_load.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_IMAGE_LOAD_H_
-#define TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_IMAGE_LOAD_H_
+#ifndef TENSORFLOW_EXAMPLES_IOS_CAMERA_IOS_IMAGE_LOAD_H_
+#define TENSORFLOW_EXAMPLES_IOS_CAMERA_IOS_IMAGE_LOAD_H_
#include <vector>
@@ -24,4 +24,4 @@ std::vector<tensorflow::uint8> LoadImageFromFile(const char* file_name,
int* out_height,
int* out_channels);
-#endif // TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_IMAGE_LOAD_H_
+#endif // TENSORFLOW_EXAMPLES_IOS_CAMERA_IOS_IMAGE_LOAD_H_
diff --git a/tensorflow/examples/label_image/main.cc b/tensorflow/examples/label_image/main.cc
index baa65d3243..ee2927d0a5 100644
--- a/tensorflow/examples/label_image/main.cc
+++ b/tensorflow/examples/label_image/main.cc
@@ -106,7 +106,7 @@ static Status ReadEntireFile(tensorflow::Env* env, const string& filename,
"' expected ", file_size, " got ",
data.size());
}
- output->scalar<string>()() = data.ToString();
+ output->scalar<string>()() = string(data);
return Status::OK();
}
diff --git a/tensorflow/examples/speech_commands/models.py b/tensorflow/examples/speech_commands/models.py
index 4d1454be0d..c63d4c3c7d 100644
--- a/tensorflow/examples/speech_commands/models.py
+++ b/tensorflow/examples/speech_commands/models.py
@@ -634,7 +634,7 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings,
label_count = model_settings['label_count']
final_fc_weights = tf.get_variable(
name='final_fc_weights',
- initializer=tf.truncated_normal(stddev=0.01),
+ initializer=tf.truncated_normal_initializer(stddev=0.01),
shape=[second_fc_output_channels, label_count])
final_fc_bias = tf.get_variable(
name='final_fc_bias',
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 3775af4c77..e755c37039 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -3355,56 +3355,105 @@ func BitwiseXor(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
return op.Output(0)
}
-// Computes the mean along sparse segments of a tensor.
+// Computes element-wise population count (a.k.a. popcount, bitsum, bitcount).
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// For each entry in `x`, calculates the number of `1` (on) bits in the binary
+// representation of that entry.
//
-// Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
-// dimension, selecting a subset of dimension 0, specified by `indices`.
+// **NOTE**: It is more efficient to first `tf.bitcast` your tensors into
+// `int32` or `int64` and perform the bitcount on the result, than to feed in
+// 8- or 16-bit inputs and then aggregate the resulting counts.
+func PopulationCount(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "PopulationCount",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
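
The NOTE above is easier to see in use. Below is a minimal, editor-added sketch (not part of the generated wrappers) that builds and runs a `PopulationCount` graph through the TensorFlow Go API; the input values are arbitrary.

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	root := op.NewScope()
	// Per the NOTE, prefer int32/int64 inputs over 8- or 16-bit ones.
	x := op.Const(root, []int64{255, 15})
	y := op.PopulationCount(root, x)

	graph, err := root.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	out, err := sess.Run(nil, []tf.Output{y}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // [8 4]: 255 has eight 1-bits, 15 has four
}
```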
+
+// Calculates the prior from the training data (the bias) and fills in the first node with the logits' prior. Returns a boolean indicating whether to continue centering.
//
// Arguments:
+// tree_ensemble_handle: Handle to the tree ensemble.
+// mean_gradients: A tensor with shape=[logits_dimension] containing the mean of gradients for the first node.
+// mean_hessians: A tensor with shape=[logits_dimension] containing the mean of hessians for the first node.
+// l1: l1 regularization factor on leaf weights, per instance based.
+// l2: l2 regularization factor on leaf weights, per instance based.
//
-// indices: A 1-D tensor. Has same rank as `segment_ids`.
-// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
-//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `k`, the number of segments.
-func SparseSegmentMean(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) {
+// Returns Bool, whether to continue bias centering.
+func BoostedTreesCenterBias(scope *Scope, tree_ensemble_handle tf.Output, mean_gradients tf.Output, mean_hessians tf.Output, l1 tf.Output, l2 tf.Output) (continue_centering tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "SparseSegmentMean",
+ Type: "BoostedTreesCenterBias",
Input: []tf.Input{
- data, indices, segment_ids,
+ tree_ensemble_handle, mean_gradients, mean_hessians, l1, l2,
},
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
-// Pop the element at the top of the stack.
+// Runs multiple additive regression ensemble predictors on input instances and
+//
+// computes the update to cached logits. It is designed to be used during training.
+// It traverses the trees starting from the cached tree id and cached node id and
+// calculates the updates to be pushed to the cache.
//
// Arguments:
-// handle: The handle to a stack.
-// elem_type: The type of the elem that is popped.
//
-// Returns The tensor that is popped from the top of the stack.
-func StackPopV2(scope *Scope, handle tf.Output, elem_type tf.DataType) (elem tf.Output) {
+// cached_tree_ids: Rank 1 Tensor containing cached tree ids which is the starting
+// tree of prediction.
+// cached_node_ids: Rank 1 Tensor containing cached node id which is the starting
+// node of prediction.
+// bucketized_features: A list of rank 1 Tensors containing bucket id for each
+// feature.
+// logits_dimension: scalar, dimension of the logits, to be used for partial logits
+// shape.
+//
+// Returns Rank 2 Tensor containing logits update (with respect to cached
+// values stored) for each example. Rank 1 Tensor containing new tree ids for
+// each example. Rank 1 Tensor containing new node ids in the new tree_ids.
+func BoostedTreesTrainingPredict(scope *Scope, tree_ensemble_handle tf.Output, cached_tree_ids tf.Output, cached_node_ids tf.Output, bucketized_features []tf.Output, logits_dimension int64) (partial_logits tf.Output, tree_ids tf.Output, node_ids tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{"elem_type": elem_type}
+ attrs := map[string]interface{}{"logits_dimension": logits_dimension}
opspec := tf.OpSpec{
- Type: "StackPopV2",
+ Type: "BoostedTreesTrainingPredict",
Input: []tf.Input{
- handle,
+ tree_ensemble_handle, cached_tree_ids, cached_node_ids, tf.OutputList(bucketized_features),
},
Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
+// Serializes the tree ensemble to a proto.
+//
+// Arguments:
+// tree_ensemble_handle: Handle to the tree ensemble.
+//
+// Returns Stamp token of the tree ensemble resource. Serialized proto of the ensemble.
+func BoostedTreesSerializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, tree_ensemble_serialized tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesSerializeEnsemble",
+ Input: []tf.Input{
+ tree_ensemble_handle,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
}
// Computes the sum along sparse segments of a tensor.
@@ -4037,78 +4086,6 @@ func SlideDataset(scope *Scope, input_dataset tf.Output, window_size tf.Output,
return op.Output(0)
}
-// FusedBatchNormAttr is an optional argument to FusedBatchNorm.
-type FusedBatchNormAttr func(optionalAttr)
-
-// FusedBatchNormEpsilon sets the optional epsilon attribute to value.
-//
-// value: A small float number added to the variance of x.
-// If not specified, defaults to 0.0001
-func FusedBatchNormEpsilon(value float32) FusedBatchNormAttr {
- return func(m optionalAttr) {
- m["epsilon"] = value
- }
-}
-
-// FusedBatchNormDataFormat sets the optional data_format attribute to value.
-//
-// value: The data format for x and y. Either "NHWC" (default) or "NCHW".
-// If not specified, defaults to "NHWC"
-func FusedBatchNormDataFormat(value string) FusedBatchNormAttr {
- return func(m optionalAttr) {
- m["data_format"] = value
- }
-}
-
-// FusedBatchNormIsTraining sets the optional is_training attribute to value.
-//
-// value: A bool value to indicate the operation is for training (default)
-// or inference.
-// If not specified, defaults to true
-func FusedBatchNormIsTraining(value bool) FusedBatchNormAttr {
- return func(m optionalAttr) {
- m["is_training"] = value
- }
-}
-
-// Batch normalization.
-//
-// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
-// The size of 1D Tensors matches the dimension C of the 4D Tensors.
-//
-// Arguments:
-// x: A 4D Tensor for input data.
-// scale: A 1D Tensor for scaling factor, to scale the normalized x.
-// offset: A 1D Tensor for offset, to shift to the normalized x.
-// mean: A 1D Tensor for population mean. Used for inference only;
-// must be empty for training.
-// variance: A 1D Tensor for population variance. Used for inference only;
-// must be empty for training.
-//
-// Returns A 4D Tensor for output data.A 1D Tensor for the computed batch mean, to be used by TensorFlow
-// to compute the running mean.A 1D Tensor for the computed batch variance, to be used by
-// TensorFlow to compute the running variance.A 1D Tensor for the computed batch mean, to be reused
-// in the gradient computation.A 1D Tensor for the computed batch variance (inverted variance
-// in the cuDNN case), to be reused in the gradient computation.
-func FusedBatchNorm(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output, mean tf.Output, variance tf.Output, optional ...FusedBatchNormAttr) (y tf.Output, batch_mean tf.Output, batch_variance tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "FusedBatchNorm",
- Input: []tf.Input{
- x, scale, offset, mean, variance,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
-}
-
// ApproximateEqualAttr is an optional argument to ApproximateEqual.
type ApproximateEqualAttr func(optionalAttr)
@@ -4700,51 +4677,6 @@ func CholeskyGrad(scope *Scope, l tf.Output, grad tf.Output) (output tf.Output)
return op.Output(0)
}
-// Computes the mean along sparse segments of a tensor.
-//
-// Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is
-// misisng, the `output` tensor at that position will be zeroed.
-//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
-//
-// Arguments:
-//
-// indices: A 1-D tensor. Has same rank as `segment_ids`.
-// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
-// num_segments: Should equal the number of distinct segment IDs.
-//
-// Returns Has same shape as data, except for dimension 0 which has size
-// `num_segments`.
-func SparseSegmentMeanWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SparseSegmentMeanWithNumSegments",
- Input: []tf.Input{
- data, indices, segment_ids, num_segments,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes hyperbolic cosine of x element-wise.
-func Cosh(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Cosh",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Creates a dataset that emits each dim-0 slice of `components` once.
func TensorSliceDataset(scope *Scope, components []tf.Output, output_shapes []tf.Shape) (handle tf.Output) {
if scope.Err() != nil {
@@ -8230,47 +8162,6 @@ func DecodeRaw(scope *Scope, bytes tf.Output, out_type tf.DataType, optional ...
return op.Output(0)
}
-// RandomPoissonAttr is an optional argument to RandomPoisson.
-type RandomPoissonAttr func(optionalAttr)
-
-// RandomPoissonSeed sets the optional seed attribute to value.
-// If not specified, defaults to 0
-func RandomPoissonSeed(value int64) RandomPoissonAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// RandomPoissonSeed2 sets the optional seed2 attribute to value.
-// If not specified, defaults to 0
-func RandomPoissonSeed2(value int64) RandomPoissonAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Use RandomPoissonV2 instead.
-//
-// DEPRECATED at GraphDef version 25: Replaced by RandomPoissonV2
-func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "RandomPoisson",
- Input: []tf.Input{
- shape, rate,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Returns the element-wise sum of a list of tensors.
//
// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not
@@ -8419,6 +8310,377 @@ func OrderedMapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...Or
return op.Output(0)
}
+// Returns the truth value of (x > y) element-wise.
+//
+// *NOTE*: `Greater` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func Greater(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Greater",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp.
+type ResourceSparseApplyRMSPropAttr func(optionalAttr)
+
+// ResourceSparseApplyRMSPropUseLocking sets the optional use_locking attribute to value.
+//
+// value: If `True`, updating of the var, ms, and mom tensors is protected
+// by a lock; otherwise the behavior is undefined, but may exhibit less
+// contention.
+// If not specified, defaults to false
+func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSPropAttr {
+ return func(m optionalAttr) {
+ m["use_locking"] = value
+ }
+}
+
+// Update '*var' according to the RMSProp algorithm.
+//
+// Note that in the dense implementation of this algorithm, ms and mom will
+// update even if the grad is zero, but in this sparse implementation, ms
+// and mom will not update in iterations during which the grad is zero.
+//
+// mean_square = decay * mean_square + (1-decay) * gradient ** 2
+// Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
+//
+// ms <- rho * ms_{t-1} + (1-rho) * grad * grad
+// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
+// var <- var - mom
+//
+// Arguments:
+// var_: Should be from a Variable().
+// ms: Should be from a Variable().
+// mom: Should be from a Variable().
+// lr: Scaling factor. Must be a scalar.
+// rho: Decay rate. Must be a scalar.
+//
+// epsilon: Ridge term. Must be a scalar.
+// grad: The gradient.
+// indices: A vector of indices into the first dimension of var, ms and mom.
+//
+// Returns the created operation.
+func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyRMSPropAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResourceSparseApplyRMSProp",
+ Input: []tf.Input{
+ var_, ms, mom, lr, rho, momentum, epsilon, grad, indices,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
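
Editor's note: the wrapper comment above states the update rule twice in slightly different notation (`mean_square`/`Delta`, then `ms`/`mom`/`var`). Consolidated into one set of symbols, with decay rate $\rho$, momentum $\mu$, learning rate $\eta$, and gradient $g$, applied only at the rows named by `indices`:

$$
\begin{aligned}
ms &\leftarrow \rho \cdot ms + (1-\rho)\, g^2 \\
mom &\leftarrow \mu \cdot mom + \eta\, g \,/\, \sqrt{ms + \epsilon} \\
var &\leftarrow var - mom
\end{aligned}
$$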
+
+// SampleDistortedBoundingBoxAttr is an optional argument to SampleDistortedBoundingBox.
+type SampleDistortedBoundingBoxAttr func(optionalAttr)
+
+// SampleDistortedBoundingBoxSeed sets the optional seed attribute to value.
+//
+// value: If either `seed` or `seed2` is set to non-zero, the random number
+// generator is seeded by the given `seed`. Otherwise, it is seeded by a random
+// seed.
+// If not specified, defaults to 0
+func SampleDistortedBoundingBoxSeed(value int64) SampleDistortedBoundingBoxAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// SampleDistortedBoundingBoxSeed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func SampleDistortedBoundingBoxSeed2(value int64) SampleDistortedBoundingBoxAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// SampleDistortedBoundingBoxMinObjectCovered sets the optional min_object_covered attribute to value.
+//
+// value: The cropped area of the image must contain at least this
+// fraction of any bounding box supplied. The value of this parameter should be
+// non-negative. In the case of 0, the cropped area does not need to overlap
+// any of the bounding boxes supplied.
+// If not specified, defaults to 0.1
+func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBoundingBoxAttr {
+ return func(m optionalAttr) {
+ m["min_object_covered"] = value
+ }
+}
+
+// SampleDistortedBoundingBoxAspectRatioRange sets the optional aspect_ratio_range attribute to value.
+//
+// value: The cropped area of the image must have an aspect ratio =
+// width / height within this range.
+// If not specified, defaults to <f:0.75 f:1.33 >
+func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr {
+ return func(m optionalAttr) {
+ m["aspect_ratio_range"] = value
+ }
+}
+
+// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value.
+//
+// value: The cropped area of the image must contain a fraction of the
+// supplied image within this range.
+// If not specified, defaults to <f:0.05 f:1 >
+func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr {
+ return func(m optionalAttr) {
+ m["area_range"] = value
+ }
+}
+
+// SampleDistortedBoundingBoxMaxAttempts sets the optional max_attempts attribute to value.
+//
+// value: Number of attempts at generating a cropped region of the image
+// satisfying the specified constraints. After `max_attempts` failures, return
+// the entire image.
+// If not specified, defaults to 100
+func SampleDistortedBoundingBoxMaxAttempts(value int64) SampleDistortedBoundingBoxAttr {
+ return func(m optionalAttr) {
+ m["max_attempts"] = value
+ }
+}
+
+// SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value.
+//
+// value: Controls behavior if no bounding boxes are supplied.
+// If true, assume an implicit bounding box covering the whole input. If false,
+// raise an error.
+// If not specified, defaults to false
+func SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxAttr {
+ return func(m optionalAttr) {
+ m["use_image_if_no_bounding_boxes"] = value
+ }
+}
+
+// Generate a single randomly distorted bounding box for an image.
+//
+// Bounding box annotations are often supplied in addition to ground-truth labels
+// in image recognition or object localization tasks. A common technique for
+// training such a system is to randomly distort an image while preserving
+// its content, i.e. *data augmentation*. This Op outputs a randomly distorted
+// localization of an object, i.e. bounding box, given an `image_size`,
+// `bounding_boxes` and a series of constraints.
+//
+// The output of this Op is a single bounding box that may be used to crop the
+// original image. The output is returned as 3 tensors: `begin`, `size` and
+// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the
+// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize
+// what the bounding box looks like.
+//
+// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The
+// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and
+// height of the underlying image.
+//
+// For example,
+//
+// ```python
+// # Generate a single distorted bounding box.
+// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box(
+// tf.shape(image),
+// bounding_boxes=bounding_boxes)
+//
+// # Draw the bounding box in an image summary.
+// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
+// bbox_for_draw)
+// tf.summary.image('images_with_box', image_with_box)
+//
+// # Employ the bounding box to distort the image.
+// distorted_image = tf.slice(image, begin, size)
+// ```
+//
+// Note that if no bounding box information is available, setting
+// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit
+// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is
+// false and no bounding boxes are supplied, an error is raised.
+//
+// Arguments:
+// image_size: 1-D, containing `[height, width, channels]`.
+// bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes
+// associated with the image.
+//
+// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input
+// to `tf.slice`. 1-D, containing `[target_height, target_width, -1]`. Provide
+// as input to `tf.slice`. 3-D with shape `[1, 1, 4]` containing the distorted
+// bounding box. Provide as input to `tf.image.draw_bounding_boxes`.
+func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, optional ...SampleDistortedBoundingBoxAttr) (begin tf.Output, size tf.Output, bboxes tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "SampleDistortedBoundingBox",
+ Input: []tf.Input{
+ image_size, bounding_boxes,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
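
For Go callers, here is an editor-added counterpart to the Python example above; it is a sketch rather than generated documentation. The image size and box coordinates are invented, and `begin`/`size` would feed a `Slice` op to perform the crop.

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	root := op.NewScope()
	imageSize := op.Const(root, []int32{480, 640, 3})
	// One image (batch=1) with one box, as [y_min, x_min, y_max, x_max].
	boxes := op.Const(root, [][][]float32{{{0.1, 0.2, 0.8, 0.9}}})
	begin, size, bbox := op.SampleDistortedBoundingBox(root, imageSize, boxes,
		op.SampleDistortedBoundingBoxMinObjectCovered(0.5))

	graph, err := root.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	out, err := sess.Run(nil, []tf.Output{begin, size, bbox}, nil)
	if err != nil {
		panic(err)
	}
	// begin and size could be fed to a Slice op to crop the image.
	fmt.Println(out[0].Value(), out[1].Value(), out[2].Value())
}
```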
+
+// Computes sigmoid of `x` element-wise.
+//
+// Specifically, `y = 1 / (1 + exp(-x))`.
+func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Sigmoid",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// FusedBatchNormAttr is an optional argument to FusedBatchNorm.
+type FusedBatchNormAttr func(optionalAttr)
+
+// FusedBatchNormEpsilon sets the optional epsilon attribute to value.
+//
+// value: A small float number added to the variance of x.
+// If not specified, defaults to 0.0001
+func FusedBatchNormEpsilon(value float32) FusedBatchNormAttr {
+ return func(m optionalAttr) {
+ m["epsilon"] = value
+ }
+}
+
+// FusedBatchNormDataFormat sets the optional data_format attribute to value.
+//
+// value: The data format for x and y. Either "NHWC" (default) or "NCHW".
+// If not specified, defaults to "NHWC"
+func FusedBatchNormDataFormat(value string) FusedBatchNormAttr {
+ return func(m optionalAttr) {
+ m["data_format"] = value
+ }
+}
+
+// FusedBatchNormIsTraining sets the optional is_training attribute to value.
+//
+// value: A bool value to indicate the operation is for training (default)
+// or inference.
+// If not specified, defaults to true
+func FusedBatchNormIsTraining(value bool) FusedBatchNormAttr {
+ return func(m optionalAttr) {
+ m["is_training"] = value
+ }
+}
+
+// Batch normalization.
+//
+// Note that the sizes of the 4D Tensors are defined by either "NHWC" or "NCHW".
+// The size of the 1D Tensors matches dimension C of the 4D Tensors.
+//
+// Arguments:
+// x: A 4D Tensor for input data.
+// scale: A 1D Tensor for scaling factor, to scale the normalized x.
+// offset: A 1D Tensor for offset, to shift to the normalized x.
+// mean: A 1D Tensor for population mean. Used for inference only;
+// must be empty for training.
+// variance: A 1D Tensor for population variance. Used for inference only;
+// must be empty for training.
+//
+// Returns A 4D Tensor for output data. A 1D Tensor for the computed batch mean,
+// to be used by TensorFlow to compute the running mean. A 1D Tensor for the
+// computed batch variance, to be used by TensorFlow to compute the running
+// variance. A 1D Tensor for the computed batch mean, to be reused in the
+// gradient computation. A 1D Tensor for the computed batch variance (inverted
+// variance in the cuDNN case), to be reused in the gradient computation.
+func FusedBatchNorm(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output, mean tf.Output, variance tf.Output, optional ...FusedBatchNormAttr) (y tf.Output, batch_mean tf.Output, batch_variance tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "FusedBatchNorm",
+ Input: []tf.Input{
+ x, scale, offset, mean, variance,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
+}
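
As an editor-added illustration of the "must be empty for training" constraint above, this sketch feeds zero-length mean/variance tensors while `is_training` is set (explicitly) to its default of true; the toy input is NHWC `[1, 2, 2, 1]`.

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	root := op.NewScope()
	x := op.Const(root, [][][][]float32{{{{1}, {2}}, {{3}, {4}}}}) // NHWC [1,2,2,1]
	scale := op.Const(root, []float32{1})
	offset := op.Const(root, []float32{0})
	empty := op.Const(root, []float32{}) // mean/variance must be empty for training
	y, batchMean, batchVariance, _, _ := op.FusedBatchNorm(root,
		x, scale, offset, empty, empty, op.FusedBatchNormIsTraining(true))
	_ = y

	graph, err := root.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	out, err := sess.Run(nil, []tf.Output{batchMean, batchVariance}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value(), out[1].Value()) // batch mean [2.5], variance [1.25]
}
```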
+
+// RandomStandardNormalAttr is an optional argument to RandomStandardNormal.
+type RandomStandardNormalAttr func(optionalAttr)
+
+// RandomStandardNormalSeed sets the optional seed attribute to value.
+//
+// value: If either `seed` or `seed2` is set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func RandomStandardNormalSeed(value int64) RandomStandardNormalAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// RandomStandardNormalSeed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func RandomStandardNormalSeed2(value int64) RandomStandardNormalAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Outputs random values from a normal distribution.
+//
+// The generated values will have mean 0 and standard deviation 1.
+//
+// Arguments:
+// shape: The shape of the output tensor.
+// dtype: The type of the output.
+//
+// Returns A tensor of the specified shape filled with random normal values.
+func RandomStandardNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomStandardNormalAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"dtype": dtype}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "RandomStandardNormal",
+ Input: []tf.Input{
+ shape,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
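
An editor-added sketch of the functional-options pattern these `*Attr` helpers implement: optional attributes are passed as trailing arguments to the wrapper. The seed values here are arbitrary.

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	root := op.NewScope()
	shape := op.Const(root, []int32{2, 3})
	// Fixing both seeds makes the op deterministic across runs.
	normal := op.RandomStandardNormal(root, shape, tf.Float,
		op.RandomStandardNormalSeed(42),
		op.RandomStandardNormalSeed2(7))

	graph, err := root.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	out, err := sess.Run(nil, []tf.Output{normal}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // a 2x3 matrix of N(0,1) samples
}
```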
+
// ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl.
type ResourceApplyFtrlAttr func(optionalAttr)
@@ -8661,28 +8923,6 @@ func Assert(scope *Scope, condition tf.Output, data []tf.Output, optional ...Ass
return scope.AddOperation(opspec)
}
-// Computes element-wise population count (a.k.a. popcount, bitsum, bitcount).
-//
-// For each entry in `x`, calculates the number of `1` (on) bits in the binary
-// representation of that entry.
-//
-// **NOTE**: It is more efficient to first `tf.bitcast` your tensors into
-// `int32` or `int64` and perform the bitcount on the result, than to feed in
-// 8- or 16-bit inputs and then aggregate the resulting counts.
-func PopulationCount(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "PopulationCount",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Broadcasts a tensor value to one or more other devices.
func CollectiveBcastSend(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) {
if scope.Err() != nil {
@@ -8993,21 +9233,6 @@ func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value
return op.Output(0)
}
-// Computes tan of x element-wise.
-func Tan(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Tan",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Updates the tree ensemble by either adding a layer to the last tree being grown
//
// or by starting a new tree.
@@ -9048,6 +9273,21 @@ func BoostedTreesUpdateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, fe
return scope.AddOperation(opspec)
}
+// Computes tan of x element-wise.
+func Tan(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Tan",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// EncodeJpegAttr is an optional argument to EncodeJpeg.
type EncodeJpegAttr func(optionalAttr)
@@ -11427,6 +11667,85 @@ func FakeQuantWithMinMaxVars(scope *Scope, inputs tf.Output, min tf.Output, max
return op.Output(0)
}
+// ResourceScatterNdUpdateAttr is an optional argument to ResourceScatterNdUpdate.
+type ResourceScatterNdUpdateAttr func(optionalAttr)
+
+// ResourceScatterNdUpdateUseLocking sets the optional use_locking attribute to value.
+//
+// value: An optional bool. Defaults to True. If True, the assignment will
+// be protected by a lock; otherwise the behavior is undefined,
+// but may exhibit less contention.
+// If not specified, defaults to true
+func ResourceScatterNdUpdateUseLocking(value bool) ResourceScatterNdUpdateAttr {
+ return func(m optionalAttr) {
+ m["use_locking"] = value
+ }
+}
+
+// Applies sparse `updates` to individual values or slices within a given
+//
+// variable according to `indices`.
+//
+// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+//
+// `indices` must be an integer tensor, containing indices into `ref`.
+// It must be of shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+//
+// The innermost dimension of `indices` (with length `K`) corresponds to
+// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+// dimension of `ref`.
+//
+// `updates` is a `Tensor` of rank `Q-1+P-K` with shape:
+//
+// ```
+// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+// ```
+//
+// For example, say we want to update 4 scattered elements of a rank-1 tensor
+// with 8 elements. In Python, that update would look like this:
+//
+// ```python
+// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+// indices = tf.constant([[4], [3], [1] ,[7]])
+// updates = tf.constant([9, 10, 11, 12])
+// update = tf.scatter_nd_update(ref, indices, updates)
+// with tf.Session() as sess:
+// print(sess.run(update))
+// ```
+//
+// The resulting update to ref would look like this:
+//
+// [1, 11, 3, 10, 9, 6, 7, 12]
+//
+// See @{tf.scatter_nd} for more details about how to make updates to
+// slices.
+//
+// Arguments:
+// ref: A resource handle. Must be from a VarHandleOp.
+// indices: A Tensor. Must be one of the following types: int32, int64.
+// A tensor of indices into ref.
+// updates: A Tensor. Must have the same type as ref. A tensor of updated
+// values to add to ref.
+//
+// Returns the created operation.
+func ResourceScatterNdUpdate(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdUpdateAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResourceScatterNdUpdate",
+ Input: []tf.Input{
+ ref, indices, updates,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
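
A Go counterpart to the Python example above, added by the editor as a sketch. It leans on the `VarHandleOp`, `AssignVariableOp`, and `ReadVariableOp` wrappers defined elsewhere in this file, and runs the init and update targets in separate `Run` calls to enforce ordering.

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	root := op.NewScope()
	ref := op.VarHandleOp(root, tf.Int32, tf.MakeShape(8))
	initOp := op.AssignVariableOp(root, ref,
		op.Const(root, []int32{1, 2, 3, 4, 5, 6, 7, 8}))
	indices := op.Const(root, [][]int32{{4}, {3}, {1}, {7}})
	updates := op.Const(root, []int32{9, 10, 11, 12})
	updateOp := op.ResourceScatterNdUpdate(root, ref, indices, updates)
	read := op.ReadVariableOp(root, ref, tf.Int32)

	graph, err := root.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	// Initialize, then update, then read the variable back.
	if _, err := sess.Run(nil, nil, []*tf.Operation{initOp}); err != nil {
		panic(err)
	}
	if _, err := sess.Run(nil, nil, []*tf.Operation{updateOp}); err != nil {
		panic(err)
	}
	out, err := sess.Run(nil, []tf.Output{read}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // [1 11 3 10 9 6 7 12]
}
```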
+
// Applies softmax to a batched N-D `SparseTensor`.
//
// The inputs represent an N-D SparseTensor with logical shape `[..., B, C]`
@@ -12371,263 +12690,6 @@ func OrderedMapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.
return values
}
-// Inverse fast Fourier transform.
-//
-// Computes the inverse 1-dimensional discrete Fourier transform over the
-// inner-most dimension of `input`.
-//
-// Arguments:
-// input: A complex64 tensor.
-//
-// Returns A complex64 tensor of the same shape as `input`. The inner-most
-// dimension of `input` is replaced with its inverse 1D Fourier transform.
-//
-// @compatibility(numpy)
-// Equivalent to np.fft.ifft
-// @end_compatibility
-func IFFT(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "IFFT",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp.
-type ResourceSparseApplyRMSPropAttr func(optionalAttr)
-
-// ResourceSparseApplyRMSPropUseLocking sets the optional use_locking attribute to value.
-//
-// value: If `True`, updating of the var, ms, and mom tensors is protected
-// by a lock; otherwise the behavior is undefined, but may exhibit less
-// contention.
-// If not specified, defaults to false
-func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSPropAttr {
- return func(m optionalAttr) {
- m["use_locking"] = value
- }
-}
-
-// Update '*var' according to the RMSProp algorithm.
-//
-// Note that in dense implementation of this algorithm, ms and mom will
-// update even if the grad is zero, but in this sparse implementation, ms
-// and mom will not update in iterations during which the grad is zero.
-//
-// mean_square = decay * mean_square + (1-decay) * gradient ** 2
-// Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
-//
-// ms <- rho * ms_{t-1} + (1-rho) * grad * grad
-// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
-// var <- var - mom
-//
-// Arguments:
-// var_: Should be from a Variable().
-// ms: Should be from a Variable().
-// mom: Should be from a Variable().
-// lr: Scaling factor. Must be a scalar.
-// rho: Decay rate. Must be a scalar.
-//
-// epsilon: Ridge term. Must be a scalar.
-// grad: The gradient.
-// indices: A vector of indices into the first dimension of var, ms and mom.
-//
-// Returns the created operation.
-func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyRMSPropAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResourceSparseApplyRMSProp",
- Input: []tf.Input{
- var_, ms, mom, lr, rho, momentum, epsilon, grad, indices,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
-// Returns the truth value of (x > y) element-wise.
-//
-// *NOTE*: `Greater` supports broadcasting. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func Greater(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Greater",
- Input: []tf.Input{
- x, y,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// SampleDistortedBoundingBoxAttr is an optional argument to SampleDistortedBoundingBox.
-type SampleDistortedBoundingBoxAttr func(optionalAttr)
-
-// SampleDistortedBoundingBoxSeed sets the optional seed attribute to value.
-//
-// value: If either `seed` or `seed2` are set to non-zero, the random number
-// generator is seeded by the given `seed`. Otherwise, it is seeded by a random
-// seed.
-// If not specified, defaults to 0
-func SampleDistortedBoundingBoxSeed(value int64) SampleDistortedBoundingBoxAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// SampleDistortedBoundingBoxSeed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func SampleDistortedBoundingBoxSeed2(value int64) SampleDistortedBoundingBoxAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// SampleDistortedBoundingBoxMinObjectCovered sets the optional min_object_covered attribute to value.
-//
-// value: The cropped area of the image must contain at least this
-// fraction of any bounding box supplied. The value of this parameter should be
-// non-negative. In the case of 0, the cropped area does not need to overlap
-// any of the bounding boxes supplied.
-// If not specified, defaults to 0.1
-func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBoundingBoxAttr {
- return func(m optionalAttr) {
- m["min_object_covered"] = value
- }
-}
-
-// SampleDistortedBoundingBoxAspectRatioRange sets the optional aspect_ratio_range attribute to value.
-//
-// value: The cropped area of the image must have an aspect ratio =
-// width / height within this range.
-// If not specified, defaults to <f:0.75 f:1.33 >
-func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr {
- return func(m optionalAttr) {
- m["aspect_ratio_range"] = value
- }
-}
-
-// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value.
-//
-// value: The cropped area of the image must contain a fraction of the
-// supplied image within this range.
-// If not specified, defaults to <f:0.05 f:1 >
-func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr {
- return func(m optionalAttr) {
- m["area_range"] = value
- }
-}
-
-// SampleDistortedBoundingBoxMaxAttempts sets the optional max_attempts attribute to value.
-//
-// value: Number of attempts at generating a cropped region of the image
-// of the specified constraints. After `max_attempts` failures, return the entire
-// image.
-// If not specified, defaults to 100
-func SampleDistortedBoundingBoxMaxAttempts(value int64) SampleDistortedBoundingBoxAttr {
- return func(m optionalAttr) {
- m["max_attempts"] = value
- }
-}
-
-// SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value.
-//
-// value: Controls behavior if no bounding boxes supplied.
-// If true, assume an implicit bounding box covering the whole input. If false,
-// raise an error.
-// If not specified, defaults to false
-func SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxAttr {
- return func(m optionalAttr) {
- m["use_image_if_no_bounding_boxes"] = value
- }
-}
-
-// Generate a single randomly distorted bounding box for an image.
-//
-// Bounding box annotations are often supplied in addition to ground-truth labels
-// in image recognition or object localization tasks. A common technique for
-// training such a system is to randomly distort an image while preserving
-// its content, i.e. *data augmentation*. This Op outputs a randomly distorted
-// localization of an object, i.e. bounding box, given an `image_size`,
-// `bounding_boxes` and a series of constraints.
-//
-// The output of this Op is a single bounding box that may be used to crop the
-// original image. The output is returned as 3 tensors: `begin`, `size` and
-// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the
-// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize
-// what the bounding box looks like.
-//
-// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The
-// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and
-// height of the underlying image.
-//
-// For example,
-//
-// ```python
-// # Generate a single distorted bounding box.
-// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box(
-// tf.shape(image),
-// bounding_boxes=bounding_boxes)
-//
-// # Draw the bounding box in an image summary.
-// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
-// bbox_for_draw)
-// tf.summary.image('images_with_box', image_with_box)
-//
-// # Employ the bounding box to distort the image.
-// distorted_image = tf.slice(image, begin, size)
-// ```
-//
-// Note that if no bounding box information is available, setting
-// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit
-// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is
-// false and no bounding boxes are supplied, an error is raised.
-//
-// Arguments:
-// image_size: 1-D, containing `[height, width, channels]`.
-// bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes
-// associated with the image.
-//
-// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to
-// `tf.slice`.1-D, containing `[target_height, target_width, -1]`. Provide as input to
-// `tf.slice`.3-D with shape `[1, 1, 4]` containing the distorted bounding box.
-// Provide as input to `tf.image.draw_bounding_boxes`.
-func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, optional ...SampleDistortedBoundingBoxAttr) (begin tf.Output, size tf.Output, bboxes tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "SampleDistortedBoundingBox",
- Input: []tf.Input{
- image_size, bounding_boxes,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// LRNAttr is an optional argument to LRN.
type LRNAttr func(optionalAttr)
@@ -12977,85 +13039,6 @@ func DeserializeSparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataT
return op.Output(0), op.Output(1), op.Output(2)
}
-// ResourceScatterNdUpdateAttr is an optional argument to ResourceScatterNdUpdate.
-type ResourceScatterNdUpdateAttr func(optionalAttr)
-
-// ResourceScatterNdUpdateUseLocking sets the optional use_locking attribute to value.
-//
-// value: An optional bool. Defaults to True. If True, the assignment will
-// be protected by a lock; otherwise the behavior is undefined,
-// but may exhibit less contention.
-// If not specified, defaults to true
-func ResourceScatterNdUpdateUseLocking(value bool) ResourceScatterNdUpdateAttr {
- return func(m optionalAttr) {
- m["use_locking"] = value
- }
-}
-
-// Applies sparse `updates` to individual values or slices within a given
-//
-// variable according to `indices`.
-//
-// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
-//
-// `indices` must be integer tensor, containing indices into `ref`.
-// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
-//
-// The innermost dimension of `indices` (with length `K`) corresponds to
-// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
-// dimension of `ref`.
-//
-// `updates` is `Tensor` of rank `Q-1+P-K` with shape:
-//
-// ```
-// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
-// ```
-//
-// For example, say we want to update 4 scattered elements to a rank-1 tensor to
-// 8 elements. In Python, that update would look like this:
-//
-// ```python
-// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
-// indices = tf.constant([[4], [3], [1] ,[7]])
-// updates = tf.constant([9, 10, 11, 12])
-// update = tf.scatter_nd_update(ref, indices, updates)
-// with tf.Session() as sess:
-// print sess.run(update)
-// ```
-//
-// The resulting update to ref would look like this:
-//
-// [1, 11, 3, 10, 9, 6, 7, 12]
-//
-// See @{tf.scatter_nd} for more details about how to make updates to
-// slices.
-//
-// Arguments:
-// ref: A resource handle. Must be from a VarHandleOp.
-// indices: A Tensor. Must be one of the following types: int32, int64.
-// A tensor of indices into ref.
-// updates: A Tensor. Must have the same type as ref. A tensor of updated
-// values to add to ref.
-//
-// Returns the created operation.
-func ResourceScatterNdUpdate(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdUpdateAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResourceScatterNdUpdate",
- Input: []tf.Input{
- ref, indices, updates,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
// SqueezeAttr is an optional argument to Squeeze.
type SqueezeAttr func(optionalAttr)
@@ -14517,6 +14500,47 @@ func Sub(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
return op.Output(0)
}
+// RandomPoissonAttr is an optional argument to RandomPoisson.
+type RandomPoissonAttr func(optionalAttr)
+
+// RandomPoissonSeed sets the optional seed attribute to value.
+// If not specified, defaults to 0
+func RandomPoissonSeed(value int64) RandomPoissonAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// RandomPoissonSeed2 sets the optional seed2 attribute to value.
+// If not specified, defaults to 0
+func RandomPoissonSeed2(value int64) RandomPoissonAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Use RandomPoissonV2 instead.
+//
+// DEPRECATED at GraphDef version 25: Replaced by RandomPoissonV2
+func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "RandomPoisson",
+ Input: []tf.Input{
+ shape, rate,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// LogUniformCandidateSamplerAttr is an optional argument to LogUniformCandidateSampler.
type LogUniformCandidateSamplerAttr func(optionalAttr)
@@ -16257,76 +16281,6 @@ func ResourceScatterMul(scope *Scope, resource tf.Output, indices tf.Output, upd
return scope.AddOperation(opspec)
}
-// Computes sigmoid of `x` element-wise.
-//
-// Specifically, `y = 1 / (1 + exp(-x))`.
-func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Sigmoid",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// RandomStandardNormalAttr is an optional argument to RandomStandardNormal.
-type RandomStandardNormalAttr func(optionalAttr)
-
-// RandomStandardNormalSeed sets the optional seed attribute to value.
-//
-// value: If either `seed` or `seed2` are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func RandomStandardNormalSeed(value int64) RandomStandardNormalAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// RandomStandardNormalSeed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func RandomStandardNormalSeed2(value int64) RandomStandardNormalAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Outputs random values from a normal distribution.
-//
-// The generated values will have mean 0 and standard deviation 1.
-//
-// Arguments:
-// shape: The shape of the output tensor.
-// dtype: The type of the output.
-//
-// Returns A tensor of the specified shape filled with random normal values.
-func RandomStandardNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomStandardNormalAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"dtype": dtype}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "RandomStandardNormal",
- Input: []tf.Input{
- shape,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Component-wise divides a SparseTensor by a dense Tensor.
//
// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not
@@ -16678,30 +16632,6 @@ func OrderedMapUnstageNoKey(scope *Scope, indices tf.Output, dtypes []tf.DataTyp
return key, values
}
-// Calculates the prior from the training data (the bias) and fills in the first node with the logits' prior. Returns a boolean indicating whether to continue centering.
-//
-// Arguments:
-// tree_ensemble_handle: Handle to the tree ensemble.
-// mean_gradients: A tensor with shape=[logits_dimension] with mean of gradients for a first node.
-// mean_hessians: A tensor with shape=[logits_dimension] mean of hessians for a first node.
-// l1: l1 regularization factor on leaf weights, per instance based.
-// l2: l2 regularization factor on leaf weights, per instance based.
-//
-// Returns Bool, whether to continue bias centering.
-func BoostedTreesCenterBias(scope *Scope, tree_ensemble_handle tf.Output, mean_gradients tf.Output, mean_hessians tf.Output, l1 tf.Output, l2 tf.Output) (continue_centering tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BoostedTreesCenterBias",
- Input: []tf.Input{
- tree_ensemble_handle, mean_gradients, mean_hessians, l1, l2,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// SerializeManySparseAttr is an optional argument to SerializeManySparse.
type SerializeManySparseAttr func(optionalAttr)
@@ -17181,6 +17111,34 @@ func MutableDenseHashTableV2(scope *Scope, empty_key tf.Output, value_dtype tf.D
return op.Output(0)
}
+// Inverse fast Fourier transform.
+//
+// Computes the inverse 1-dimensional discrete Fourier transform over the
+// inner-most dimension of `input`.
+//
+// Arguments:
+// input: A complex64 tensor.
+//
+// Returns A complex64 tensor of the same shape as `input`. The inner-most
+// dimension of `input` is replaced with its inverse 1D Fourier transform.
+//
+// @compatibility(numpy)
+// Equivalent to np.fft.ifft
+// @end_compatibility
+func IFFT(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "IFFT",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// 2D fast Fourier transform.
//
// Computes the 2-dimensional discrete Fourier transform over the inner-most
@@ -17472,26 +17430,6 @@ func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (i
return op.Output(0)
}
-// Serializes the tree ensemble to a proto.
-//
-// Arguments:
-// tree_ensemble_handle: Handle to the tree ensemble.
-//
-// Returns Stamp token of the tree ensemble resource.Serialized proto of the ensemble.
-func BoostedTreesSerializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, tree_ensemble_serialized tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BoostedTreesSerializeEnsemble",
- Input: []tf.Input{
- tree_ensemble_handle,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
-
// StageSizeAttr is an optional argument to StageSize.
type StageSizeAttr func(optionalAttr)
@@ -17689,123 +17627,6 @@ func TextLineDataset(scope *Scope, filenames tf.Output, compression_type tf.Outp
return op.Output(0)
}
-// CudnnRNNParamsSizeAttr is an optional argument to CudnnRNNParamsSize.
-type CudnnRNNParamsSizeAttr func(optionalAttr)
-
-// CudnnRNNParamsSizeRnnMode sets the optional rnn_mode attribute to value.
-// If not specified, defaults to "lstm"
-func CudnnRNNParamsSizeRnnMode(value string) CudnnRNNParamsSizeAttr {
- return func(m optionalAttr) {
- m["rnn_mode"] = value
- }
-}
-
-// CudnnRNNParamsSizeInputMode sets the optional input_mode attribute to value.
-// If not specified, defaults to "linear_input"
-func CudnnRNNParamsSizeInputMode(value string) CudnnRNNParamsSizeAttr {
- return func(m optionalAttr) {
- m["input_mode"] = value
- }
-}
-
-// CudnnRNNParamsSizeDirection sets the optional direction attribute to value.
-// If not specified, defaults to "unidirectional"
-func CudnnRNNParamsSizeDirection(value string) CudnnRNNParamsSizeAttr {
- return func(m optionalAttr) {
- m["direction"] = value
- }
-}
-
-// CudnnRNNParamsSizeDropout sets the optional dropout attribute to value.
-// If not specified, defaults to 0
-func CudnnRNNParamsSizeDropout(value float32) CudnnRNNParamsSizeAttr {
- return func(m optionalAttr) {
- m["dropout"] = value
- }
-}
-
-// CudnnRNNParamsSizeSeed sets the optional seed attribute to value.
-// If not specified, defaults to 0
-func CudnnRNNParamsSizeSeed(value int64) CudnnRNNParamsSizeAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// CudnnRNNParamsSizeSeed2 sets the optional seed2 attribute to value.
-// If not specified, defaults to 0
-func CudnnRNNParamsSizeSeed2(value int64) CudnnRNNParamsSizeAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Computes size of weights that can be used by a Cudnn RNN model.
-//
-// Return the params size that can be used by the Cudnn RNN model. Subsequent
-// weight allocation and initialization should use this size.
-//
-// num_layers: Specifies the number of layers in the RNN model.
-// num_units: Specifies the size of the hidden state.
-// input_size: Specifies the size of the input state.
-// rnn_mode: Indicates the type of the RNN model.
-// input_mode: Indicate whether there is a linear projection between the input and
-// The actual computation before the first layer. 'skip_input' is only allowed
-// when input_size == num_units; 'auto_select' implies 'skip_input' when
-// input_size == num_units; otherwise, it implies 'linear_input'.
-// direction: Indicates whether a bidirectional model will be used.
-// dir = (direction == bidirectional) ? 2 : 1
-// dropout: dropout probability. When set to 0., dropout is disabled.
-// seed: the 1st part of a seed to initialize dropout.
-// seed2: the 2nd part of a seed to initialize dropout.
-// params_size: The size of the params buffer that should be allocated and
-// initialized for this RNN model. Note that this params buffer may not be
-// compatible across GPUs. Please use CudnnRNNParamsWeights and
-// CudnnRNNParamsBiases to save and restore them in a way that is compatible
-// across different runs.
-func CudnnRNNParamsSize(scope *Scope, num_layers tf.Output, num_units tf.Output, input_size tf.Output, T tf.DataType, S tf.DataType, optional ...CudnnRNNParamsSizeAttr) (params_size tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"T": T, "S": S}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "CudnnRNNParamsSize",
- Input: []tf.Input{
- num_layers, num_units, input_size,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes gradients for SparseSegmentMean.
-//
-// Returns tensor "output" with same shape as grad, except for dimension 0 whose
-// value is output_dim0.
-//
-// Arguments:
-// grad: gradient propagated to the SparseSegmentMean op.
-// indices: indices passed to the corresponding SparseSegmentMean op.
-// segment_ids: segment_ids passed to the corresponding SparseSegmentMean op.
-// output_dim0: dimension 0 of "data" passed to SparseSegmentMean op.
-func SparseSegmentMeanGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SparseSegmentMeanGrad",
- Input: []tf.Input{
- grad, indices, segment_ids, output_dim0,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Returns the set of files matching one or more glob patterns.
//
// Note that this routine only supports wildcard characters in the
@@ -20538,6 +20359,220 @@ func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf
return op.Output(0)
}
+// Computes the mean along sparse segments of a tensor.
+//
+// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
+// segments.
+//
+// Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
+// dimension, selecting a subset of dimension 0, specified by `indices`.
+//
+// Arguments:
+//
+// indices: A 1-D tensor. Has same rank as `segment_ids`.
+// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+//
+// Returns Has same shape as data, except for dimension 0 which
+// has size `k`, the number of segments.
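+//
+// A minimal usage sketch (hedged: the scope wiring and input values here are
+// hypothetical, not part of this generated file):
+//
+//	s := op.NewScope()
+//	data := op.Placeholder(s.SubScope("data"), tf.Float)
+//	indices := op.Const(s.SubScope("i"), []int32{0, 1, 3})
+//	segmentIDs := op.Const(s.SubScope("s"), []int32{0, 0, 1})
+//	mean := op.SparseSegmentMean(s, data, indices, segmentIDs)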
+func SparseSegmentMean(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseSegmentMean",
+ Input: []tf.Input{
+ data, indices, segment_ids,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Pop the element at the top of the stack.
+//
+// Arguments:
+// handle: The handle to a stack.
+// elem_type: The type of the elem that is popped.
+//
+// Returns The tensor that is popped from the top of the stack.
+func StackPopV2(scope *Scope, handle tf.Output, elem_type tf.DataType) (elem tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"elem_type": elem_type}
+ opspec := tf.OpSpec{
+ Type: "StackPopV2",
+ Input: []tf.Input{
+ handle,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes hyperbolic cosine of x element-wise.
+func Cosh(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Cosh",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes the mean along sparse segments of a tensor.
+//
+// Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is
+// missing, the `output` tensor at that position will be zeroed.
+//
+// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
+// segments.
+//
+// Arguments:
+//
+// indices: A 1-D tensor. Has same rank as `segment_ids`.
+// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+// num_segments: Should equal the number of distinct segment IDs.
+//
+// Returns Has same shape as data, except for dimension 0 which has size
+// `num_segments`.
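+//
+// Worked micro-example: with segment_ids = [0, 2] and num_segments = 4,
+// output rows 1 and 3 are zero-filled.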
+func SparseSegmentMeanWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseSegmentMeanWithNumSegments",
+ Input: []tf.Input{
+ data, indices, segment_ids, num_segments,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// CudnnRNNParamsSizeAttr is an optional argument to CudnnRNNParamsSize.
+type CudnnRNNParamsSizeAttr func(optionalAttr)
+
+// CudnnRNNParamsSizeRnnMode sets the optional rnn_mode attribute to value.
+// If not specified, defaults to "lstm"
+func CudnnRNNParamsSizeRnnMode(value string) CudnnRNNParamsSizeAttr {
+ return func(m optionalAttr) {
+ m["rnn_mode"] = value
+ }
+}
+
+// CudnnRNNParamsSizeInputMode sets the optional input_mode attribute to value.
+// If not specified, defaults to "linear_input"
+func CudnnRNNParamsSizeInputMode(value string) CudnnRNNParamsSizeAttr {
+ return func(m optionalAttr) {
+ m["input_mode"] = value
+ }
+}
+
+// CudnnRNNParamsSizeDirection sets the optional direction attribute to value.
+// If not specified, defaults to "unidirectional"
+func CudnnRNNParamsSizeDirection(value string) CudnnRNNParamsSizeAttr {
+ return func(m optionalAttr) {
+ m["direction"] = value
+ }
+}
+
+// CudnnRNNParamsSizeDropout sets the optional dropout attribute to value.
+// If not specified, defaults to 0
+func CudnnRNNParamsSizeDropout(value float32) CudnnRNNParamsSizeAttr {
+ return func(m optionalAttr) {
+ m["dropout"] = value
+ }
+}
+
+// CudnnRNNParamsSizeSeed sets the optional seed attribute to value.
+// If not specified, defaults to 0
+func CudnnRNNParamsSizeSeed(value int64) CudnnRNNParamsSizeAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// CudnnRNNParamsSizeSeed2 sets the optional seed2 attribute to value.
+// If not specified, defaults to 0
+func CudnnRNNParamsSizeSeed2(value int64) CudnnRNNParamsSizeAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Computes size of weights that can be used by a Cudnn RNN model.
+//
+// Return the params size that can be used by the Cudnn RNN model. Subsequent
+// weight allocation and initialization should use this size.
+//
+// num_layers: Specifies the number of layers in the RNN model.
+// num_units: Specifies the size of the hidden state.
+// input_size: Specifies the size of the input state.
+// rnn_mode: Indicates the type of the RNN model.
+// input_mode: Indicates whether there is a linear projection between the input and
+// the actual computation before the first layer. 'skip_input' is only allowed
+// when input_size == num_units; 'auto_select' implies 'skip_input' when
+// input_size == num_units; otherwise, it implies 'linear_input'.
+// direction: Indicates whether a bidirectional model will be used.
+// dir = (direction == bidirectional) ? 2 : 1
+// dropout: dropout probability. When set to 0., dropout is disabled.
+// seed: the 1st part of a seed to initialize dropout.
+// seed2: the 2nd part of a seed to initialize dropout.
+// params_size: The size of the params buffer that should be allocated and
+// initialized for this RNN model. Note that this params buffer may not be
+// compatible across GPUs. Please use CudnnRNNParamsWeights and
+// CudnnRNNParamsBiases to save and restore them in a way that is compatible
+// across different runs.
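+//
+// A sketch of the functional-options pattern used for the optional
+// attributes above (inputs hypothetical):
+//
+//	sz := op.CudnnRNNParamsSize(s, numLayers, numUnits, inputSize,
+//		tf.Float, tf.Int32,
+//		op.CudnnRNNParamsSizeRnnMode("gru"),
+//		op.CudnnRNNParamsSizeDropout(0.2))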
+func CudnnRNNParamsSize(scope *Scope, num_layers tf.Output, num_units tf.Output, input_size tf.Output, T tf.DataType, S tf.DataType, optional ...CudnnRNNParamsSizeAttr) (params_size tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"T": T, "S": S}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "CudnnRNNParamsSize",
+ Input: []tf.Input{
+ num_layers, num_units, input_size,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes gradients for SparseSegmentMean.
+//
+// Returns tensor "output" with same shape as grad, except for dimension 0 whose
+// value is output_dim0.
+//
+// Arguments:
+// grad: gradient propagated to the SparseSegmentMean op.
+// indices: indices passed to the corresponding SparseSegmentMean op.
+// segment_ids: segment_ids passed to the corresponding SparseSegmentMean op.
+// output_dim0: dimension 0 of "data" passed to SparseSegmentMean op.
+func SparseSegmentMeanGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseSegmentMeanGrad",
+ Input: []tf.Input{
+ grad, indices, segment_ids, output_dim0,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the sum along sparse segments of a tensor divided by the sqrt of N.
//
// N is the size of the segment being reduced.
@@ -23396,6 +23431,8 @@ func TensorListSetItem(scope *Scope, input_handle tf.Output, index tf.Output, it
// Computes the matrix exponential of one or more square matrices:
//
+// DEPRECATED at GraphDef version 27: Use Python implementation tf.linalg.matrix_exponential instead.
+//
// \\(exp(A) = \sum_{n=0}^\infty A^n/n!\\)
//
// The exponential is computed using a combination of the scaling and squaring
@@ -26669,41 +26706,6 @@ func LatencyStatsDataset(scope *Scope, input_dataset tf.Output, tag tf.Output, o
return op.Output(0)
}
-// Runs multiple additive regression ensemble predictors on input instances and
-//
-// computes the update to cached logits. It is designed to be used during training.
-// It traverses the trees starting from cached tree id and cached node id and
-// calculates the updates to be pushed to the cache.
-//
-// Arguments:
-//
-// cached_tree_ids: Rank 1 Tensor containing cached tree ids which is the starting
-// tree of prediction.
-// cached_node_ids: Rank 1 Tensor containing cached node id which is the starting
-// node of prediction.
-// bucketized_features: A list of rank 1 Tensors containing bucket id for each
-// feature.
-// logits_dimension: scalar, dimension of the logits, to be used for partial logits
-// shape.
-//
-// Returns Rank 2 Tensor containing logits update (with respect to cached
-// values stored) for each example.Rank 1 Tensor containing new tree ids for each example.Rank 1 Tensor containing new node ids in the new tree_ids.
-func BoostedTreesTrainingPredict(scope *Scope, tree_ensemble_handle tf.Output, cached_tree_ids tf.Output, cached_node_ids tf.Output, bucketized_features []tf.Output, logits_dimension int64) (partial_logits tf.Output, tree_ids tf.Output, node_ids tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"logits_dimension": logits_dimension}
- opspec := tf.OpSpec{
- Type: "BoostedTreesTrainingPredict",
- Input: []tf.Input{
- tree_ensemble_handle, cached_tree_ids, cached_node_ids, tf.OutputList(bucketized_features),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// MapSizeAttr is an optional argument to MapSize.
type MapSizeAttr func(optionalAttr)
@@ -31776,54 +31778,6 @@ func FixedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true
return op.Output(0), op.Output(1), op.Output(2)
}
-// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2.
-type WholeFileReaderV2Attr func(optionalAttr)
-
-// WholeFileReaderV2Container sets the optional container attribute to value.
-//
-// value: If non-empty, this reader is placed in the given container.
-// Otherwise, a default container is used.
-// If not specified, defaults to ""
-func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// WholeFileReaderV2SharedName sets the optional shared_name attribute to value.
-//
-// value: If non-empty, this reader is named in the given bucket
-// with this shared_name. Otherwise, the node name is used instead.
-// If not specified, defaults to ""
-func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// A Reader that outputs the entire contents of a file as a value.
-//
-// To use, enqueue filenames in a Queue. The output of ReaderRead will
-// be a filename (key) and the contents of that file (value).
-//
-// Returns The handle to reference the Reader.
-func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "WholeFileReaderV2",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Transforms a tf.Example proto (as a string) into typed tensors.
//
// Arguments:
@@ -31894,6 +31848,54 @@ func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf.
return sparse_indices, sparse_values, sparse_shapes, dense_values
}
+// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2.
+type WholeFileReaderV2Attr func(optionalAttr)
+
+// WholeFileReaderV2Container sets the optional container attribute to value.
+//
+// value: If non-empty, this reader is placed in the given container.
+// Otherwise, a default container is used.
+// If not specified, defaults to ""
+func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// WholeFileReaderV2SharedName sets the optional shared_name attribute to value.
+//
+// value: If non-empty, this reader is named in the given bucket
+// with this shared_name. Otherwise, the node name is used instead.
+// If not specified, defaults to ""
+func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// A Reader that outputs the entire contents of a file as a value.
+//
+// To use, enqueue filenames in a Queue. The output of ReaderRead will
+// be a filename (key) and the contents of that file (value).
+//
+// Returns The handle to reference the Reader.
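+//
+// A minimal sketch of the enqueue/read wiring described above (queue setup
+// omitted; `filenameQueue` is hypothetical):
+//
+//	reader := op.WholeFileReaderV2(s)
+//	key, value := op.ReaderReadV2(s, reader, filenameQueue)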
+func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "WholeFileReaderV2",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Deserializes a serialized tree ensemble config and replaces current tree
//
// ensemble.
diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml
index 035077e1e0..e1bf2c7dba 100644
--- a/tensorflow/java/maven/pom.xml
+++ b/tensorflow/java/maven/pom.xml
@@ -32,8 +32,8 @@
<module>libtensorflow_jni_gpu</module>
<module>tensorflow</module>
<module>proto</module>
- <module>hadoop</module>
- <module>spark-connector</module>
+ <module>tensorflow-hadoop</module>
+ <module>spark-tensorflow-connector</module>
</modules>
<!-- Two profiles are used:
diff --git a/tensorflow/java/maven/run_inside_container.sh b/tensorflow/java/maven/run_inside_container.sh
index 8c4c9d498c..75c6cff529 100644
--- a/tensorflow/java/maven/run_inside_container.sh
+++ b/tensorflow/java/maven/run_inside_container.sh
@@ -41,7 +41,7 @@ clean() {
mvn -q clean
rm -rf libtensorflow_jni/src libtensorflow_jni/target libtensorflow_jni_gpu/src libtensorflow_jni_gpu/target \
libtensorflow/src libtensorflow/target tensorflow-android/target proto/src proto/target \
- hadoop/src hadoop/target spark-connector/src spark-connector/target
+ tensorflow-hadoop/src tensorflow-hadoop/target spark-tensorflow-connector/src spark-tensorflow-connector/target
}
update_version_in_pom() {
@@ -170,8 +170,8 @@ generate_java_protos() {
# is updated for each module.
download_tf_ecosystem() {
ECOSYSTEM_DIR="/tmp/tensorflow-ecosystem"
- HADOOP_DIR="${DIR}/hadoop"
- SPARK_DIR="${DIR}/spark-connector"
+ HADOOP_DIR="${DIR}/tensorflow-hadoop"
+ SPARK_DIR="${DIR}/spark-tensorflow-connector"
# Clean any previous attempts
rm -rf "${ECOSYSTEM_DIR}"
diff --git a/tensorflow/java/maven/spark-connector/pom.xml b/tensorflow/java/maven/spark-tensorflow-connector/pom.xml
index 31e39c588a..1b7995be2c 100644
--- a/tensorflow/java/maven/spark-connector/pom.xml
+++ b/tensorflow/java/maven/spark-tensorflow-connector/pom.xml
@@ -4,7 +4,7 @@
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.tensorflow</groupId>
- <artifactId>spark-connector_2.11</artifactId>
+ <artifactId>spark-tensorflow-connector_2.11</artifactId>
<packaging>jar</packaging>
<version>1.10.0</version>
<name>spark-tensorflow-connector</name>
@@ -120,7 +120,7 @@
<artifactSet>
<includes>
<include>com.google.protobuf:protobuf-java</include>
- <include>org.tensorflow:hadoop</include>
+ <include>org.tensorflow:tensorflow-hadoop</include>
<include>org.tensorflow:proto</include>
</includes>
</artifactSet>
@@ -305,7 +305,7 @@
<dependencies>
<dependency>
<groupId>org.tensorflow</groupId>
- <artifactId>hadoop</artifactId>
+ <artifactId>tensorflow-hadoop</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
diff --git a/tensorflow/java/maven/hadoop/pom.xml b/tensorflow/java/maven/tensorflow-hadoop/pom.xml
index e0409fa41b..0fe6f4dce4 100644
--- a/tensorflow/java/maven/hadoop/pom.xml
+++ b/tensorflow/java/maven/tensorflow-hadoop/pom.xml
@@ -3,7 +3,7 @@
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.tensorflow</groupId>
- <artifactId>hadoop</artifactId>
+ <artifactId>tensorflow-hadoop</artifactId>
<packaging>jar</packaging>
<version>1.10.0</version>
<name>tensorflow-hadoop</name>
@@ -15,7 +15,7 @@
<maven.compiler.source>1.6</maven.compiler.source>
<maven.compiler.target>1.6</maven.compiler.target>
<hadoop.version>2.6.0</hadoop.version>
- <protobuf.version>3.3.1</protobuf.version>
+ <protobuf.version>3.5.1</protobuf.version>
<junit.version>4.11</junit.version>
</properties>
diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h
index d9d6f8adc8..d39653ef41 100644
--- a/tensorflow/java/src/gen/cc/java_defs.h
+++ b/tensorflow/java/src/gen/cc/java_defs.h
@@ -21,6 +21,8 @@ limitations under the License.
#include <string>
#include <utility>
+#include "tensorflow/core/framework/types.h"
+
namespace tensorflow {
namespace java {
@@ -95,6 +97,34 @@ class Type {
static Type IterableOf(const Type& type) {
return Interface("Iterable").add_parameter(type);
}
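+  // Maps a TensorFlow DataType to its Java wrapper type, e.g.
+  // ForDataType(DT_INT64) yields Class("Long"); types with no Java
+  // equivalent (DT_COMPLEX64, quantized types, ...) fall back to Wildcard().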
+ static Type ForDataType(DataType data_type) {
+ switch (data_type) {
+ case DataType::DT_BOOL:
+ return Class("Boolean");
+ case DataType::DT_STRING:
+ return Class("String");
+ case DataType::DT_FLOAT:
+ return Class("Float");
+ case DataType::DT_DOUBLE:
+ return Class("Double");
+ case DataType::DT_UINT8:
+ return Class("UInt8", "org.tensorflow.types");
+ case DataType::DT_INT32:
+ return Class("Integer");
+ case DataType::DT_INT64:
+ return Class("Long");
+ case DataType::DT_RESOURCE:
+ // TODO(karllessard) create a Resource utility class that could be
+ // used to store a resource and its type (passed in a second argument).
+ // For now, we need to force a wildcard and we will unfortunately lose
+ // track of the resource type.
+ // Falling through...
+ default:
+        // Any other datatype does not have an equivalent in Java and must
+ // remain a wildcard (e.g. DT_COMPLEX64, DT_QINT8, ...)
+ return Wildcard();
+ }
+ }
const Kind& kind() const { return kind_; }
const string& name() const { return name_; }
const string& package() const { return package_; }
diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc
index d5bd99bdd9..5d6387e88e 100644
--- a/tensorflow/java/src/gen/cc/op_generator.cc
+++ b/tensorflow/java/src/gen/cc/op_generator.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <set>
#include <string>
+#include <utility>
#include <vector>
#include "tensorflow/core/framework/op_gen_lib.h"
@@ -100,6 +101,10 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode,
for (const AttributeSpec& attribute : op.attributes()) {
out->push_back(attribute.var().type());
out->push_back(attribute.jni_type());
+ if (attribute.has_default_value() &&
+ attribute.type().kind() == Type::GENERIC) {
+ out->push_back(Type::ForDataType(attribute.default_value()->type()));
+ }
}
for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
out->push_back(optional_attribute.var().type());
@@ -139,6 +144,60 @@ void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
}
}
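+// Renders a secondary "create" factory in which defaulted generic type
+// attributes are passed implicitly. For a hypothetical op Foo whose type
+// attribute T defaults to DT_FLOAT, the generated Java looks roughly like:
+//
+//   public static Foo<Float> create(Scope scope, Operand<?> input,
+//       Options... options) {
+//     return create(scope, input, Float.class, options);
+//   }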
+void RenderSecondaryFactoryMethod(const OpSpec& op, const Type& op_class,
+ std::map<string, Type> default_types,
+ SourceWriter* writer) {
+ // Build the return type for the secondary factory, replacing generic
+ // parameters with their default value if any
+ Type return_type = Type::Class(op_class.name(), op_class.package());
+ for (const Type& parameter : op_class.parameters()) {
+ if (parameter.kind() == Type::GENERIC &&
+ default_types.find(parameter.name()) != default_types.end()) {
+ return_type.add_parameter(default_types.at(parameter.name()));
+ } else {
+ return_type.add_parameter(parameter);
+ }
+ }
+ Method factory = Method::Create("create", return_type);
+ Javadoc factory_doc = Javadoc::Create(
+ "Factory method to create a class to wrap a new " + op_class.name() +
+      " operation in the graph, using "
+ "default output types.");
+ Variable scope =
+ Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
+ AddArgument(scope, "current graph scope", &factory, &factory_doc);
+ std::stringstream factory_statement;
+ factory_statement << "return create(scope";
+ for (const ArgumentSpec& input : op.inputs()) {
+ AddArgument(input.var(), input.description(), &factory, &factory_doc);
+ factory_statement << ", " << input.var().name();
+ }
+ for (const AttributeSpec& attr : op.attributes()) {
+    // Type attributes with a default value are passed implicitly using that
+    // default; all other attributes are added to the secondary factory's
+    // signature.
+ factory_statement << ", ";
+ if (attr.type().kind() == Type::GENERIC &&
+ default_types.find(attr.type().name()) != default_types.end()) {
+ factory_statement << default_types.at(attr.type().name()).name()
+ << ".class";
+ } else {
+ AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
+ factory_statement << attr.var().name();
+ }
+ }
+ if (!op.optional_attributes().empty()) {
+ Variable options_var = Variable::Varargs("options", Type::Class("Options"));
+ AddArgument(options_var, "carries optional attributes values", &factory,
+ &factory_doc);
+ factory_statement << ", " << options_var.name();
+ }
+ factory_doc.add_tag("return", "a new instance of " + op_class.name());
+
+ writer->BeginMethod(factory, PUBLIC | STATIC, &factory_doc);
+ writer->Append(factory_statement.str().c_str()).Append(");").EndLine();
+ writer->EndMethod();
+}
+
void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
SourceWriter* writer) {
Method factory = Method::Create("create", op_class);
@@ -151,8 +210,17 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
for (const ArgumentSpec& input : op.inputs()) {
AddArgument(input.var(), input.description(), &factory, &factory_doc);
}
+ std::map<string, Type> default_types;
for (const AttributeSpec& attr : op.attributes()) {
AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
+ // If this attribute is a type with a default value, save its value
+ // for passing it implicitly in a secondary factory method
+ if (attr.has_default_value() && attr.type().kind() == Type::GENERIC) {
+ Type default_type = Type::ForDataType(attr.default_value()->type());
+ if (!default_type.wildcard()) {
+ default_types.insert(std::make_pair(attr.type().name(), default_type));
+ }
+ }
}
if (!op.optional_attributes().empty()) {
AddArgument(Variable::Varargs("options", Type::Class("Options")),
@@ -194,6 +262,12 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
.Append("(opBuilder.build());")
.EndLine();
writer->EndMethod();
+
+ // If this operation has type attributes with a default value, create a
+ // second factory method that infers those values implicitly
+ if (!default_types.empty()) {
+ RenderSecondaryFactoryMethod(op, op_class, default_types, writer);
+ }
}
void RenderConstructor(const OpSpec& op, const Type& op_class,
diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc
index 941ab2699c..4f5a491d25 100644
--- a/tensorflow/java/src/gen/cc/op_specs.cc
+++ b/tensorflow/java/src/gen/cc/op_specs.cc
@@ -96,43 +96,10 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out) {
*iterable_out = true;
visited_attrs_.insert(std::make_pair(arg_def.number_attr(), Type::Int()));
}
-
Type type = Type::Wildcard();
if (arg_def.type() != DataType::DT_INVALID) {
- // resolve type from DataType
- switch (arg_def.type()) {
- case DataType::DT_BOOL:
- type = Type::Class("Boolean");
- break;
- case DataType::DT_STRING:
- type = Type::Class("String");
- break;
- case DataType::DT_FLOAT:
- type = Type::Class("Float");
- break;
- case DataType::DT_DOUBLE:
- type = Type::Class("Double");
- break;
- case DataType::DT_UINT8:
- type = Type::Class("UInt8", "org.tensorflow.types");
- break;
- case DataType::DT_INT32:
- type = Type::Class("Integer");
- break;
- case DataType::DT_INT64:
- type = Type::Class("Long");
- break;
- case DataType::DT_RESOURCE:
- // TODO(karllessard) create a Resource utility class that could be
- // used to store a resource and its type (passed in a second argument).
- // For now, we need to force a wildcard and we will unfortunately lose
- // track of the resource type.
- break;
- default:
- // Any other datatypes does not have a equivalent in Java and must
- // remain a wildcard (e.g. DT_COMPLEX64, DT_QINT8, ...)
- break;
- }
+ type = Type::ForDataType(arg_def.type());
} else if (!arg_def.type_attr().empty()) {
// resolve type from attribute (if already visited, retrieve its type)
if (IsAttributeVisited(arg_def.type_attr())) {
@@ -337,7 +304,7 @@ AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
bool iterable = false;
std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable);
Type var_type = types.first.kind() == Type::GENERIC
- ? Type::Class("Class").add_parameter(types.first)
+ ? Type::ClassOf(types.first)
: types.first;
if (iterable) {
var_type = Type::ListOf(var_type);
@@ -346,7 +313,8 @@ AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
attr_api_def.name(),
Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type),
types.first, types.second, ParseDocumentation(attr_api_def.description()),
- iterable, attr_api_def.has_default_value());
+ iterable,
+ attr_def.has_default_value() ? &attr_def.default_value() : nullptr);
}
ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def,
diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h
index 30ecb8ce53..4adcfca96a 100644
--- a/tensorflow/java/src/gen/cc/op_specs.h
+++ b/tensorflow/java/src/gen/cc/op_specs.h
@@ -94,18 +94,21 @@ class AttributeSpec {
// jni_type: the type of this attribute in JNI layer (see OperationBuilder)
// description: a description of this attribute, in javadoc
// iterable: true if this attribute is a list
- // has_default_value: true if this attribute has a default value if not set
+ // default_value: default value for this attribute or nullptr if none. Any
+  //                value referenced by this pointer must outlive the
+  //                AttributeSpec. This is guaranteed if the value comes
+  //                from an OpDef in the global OpRegistry.
AttributeSpec(const string& op_def_name, const Variable& var,
const Type& type, const Type& jni_type,
const string& description, bool iterable,
- bool has_default_value)
+ const AttrValue* default_value)
: op_def_name_(op_def_name),
var_(var),
type_(type),
description_(description),
iterable_(iterable),
jni_type_(jni_type),
- has_default_value_(has_default_value) {}
+ default_value_(default_value) {}
const string& op_def_name() const { return op_def_name_; }
const Variable& var() const { return var_; }
@@ -113,7 +116,8 @@ class AttributeSpec {
const string& description() const { return description_; }
bool iterable() const { return iterable_; }
const Type& jni_type() const { return jni_type_; }
- bool has_default_value() const { return has_default_value_; }
+ bool has_default_value() const { return default_value_ != nullptr; }
+ const AttrValue* default_value() const { return default_value_; }
private:
const string op_def_name_;
@@ -122,7 +126,7 @@ class AttributeSpec {
const string description_;
const bool iterable_;
const Type jni_type_;
- const bool has_default_value_;
+ const AttrValue* default_value_;
};
class OpSpec {
diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc
index 8e5fba7e32..a71b367691 100644
--- a/tensorflow/java/src/gen/cc/source_writer.cc
+++ b/tensorflow/java/src/gen/cc/source_writer.cc
@@ -16,7 +16,6 @@ limitations under the License.
#include <string>
#include <algorithm>
#include <list>
-#include <string>
#include "tensorflow/java/src/gen/cc/source_writer.h"
diff --git a/tensorflow/java/src/main/native/exception_jni.h b/tensorflow/java/src/main/native/exception_jni.h
index 28f26d7ebf..465281f804 100644
--- a/tensorflow/java/src/main/native/exception_jni.h
+++ b/tensorflow/java/src/main/native/exception_jni.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_JAVA_EXCEPTION_JNI_H_
-#define TENSORFLOW_JAVA_EXCEPTION_JNI_H_
+#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EXCEPTION_JNI_H_
+#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EXCEPTION_JNI_H_
#include <jni.h>
@@ -39,4 +39,4 @@ bool throwExceptionIfNotOK(JNIEnv* env, const TF_Status* status);
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
-#endif // TENSORFLOW_JAVA_EXCEPTION_JNI_H_
+#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_EXCEPTION_JNI_H_
diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h
index 215695cdfd..efed23f83b 100644
--- a/tensorflow/java/src/main/native/graph_jni.h
+++ b/tensorflow/java/src/main/native/graph_jni.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_JAVA_GRAPH_JNI_H_
-#define TENSORFLOW_JAVA_GRAPH_JNI_H_
+#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_GRAPH_JNI_H_
+#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_GRAPH_JNI_H_
#include <jni.h>
@@ -85,4 +85,4 @@ JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
-#endif // TENSORFLOW_JAVA_GRAPH_JNI_H_
+#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_GRAPH_JNI_H_
diff --git a/tensorflow/java/src/main/native/operation_builder_jni.h b/tensorflow/java/src/main/native/operation_builder_jni.h
index cf0abe4829..1cda7acea8 100644
--- a/tensorflow/java/src/main/native/operation_builder_jni.h
+++ b/tensorflow/java/src/main/native/operation_builder_jni.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_JAVA_OPERATION_BUILDER_JNI_H_
-#define TENSORFLOW_JAVA_OPERATION_BUILDER_JNI_H_
+#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_OPERATION_BUILDER_JNI_H_
+#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_OPERATION_BUILDER_JNI_H_
#include <jni.h>
@@ -188,4 +188,4 @@ JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrStringList(
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
-#endif // TENSORFLOW_JAVA_OPERATION_BUILDER_JNI_H_
+#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_OPERATION_BUILDER_JNI_H_
diff --git a/tensorflow/java/src/main/native/operation_jni.h b/tensorflow/java/src/main/native/operation_jni.h
index 6f379256d2..56da2ebaee 100644
--- a/tensorflow/java/src/main/native/operation_jni.h
+++ b/tensorflow/java/src/main/native/operation_jni.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_JAVA_OPERATION_JNI_H_
-#define TENSORFLOW_JAVA_OPERATION_JNI_H_
+#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_OPERATION_JNI_H_
+#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_OPERATION_JNI_H_
#include <jni.h>
@@ -87,4 +87,4 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_inputListLength(JNIEnv *,
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
-#endif // TENSORFLOW_JAVA_OPERATION_JNI_H_
+#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_OPERATION_JNI_H_
diff --git a/tensorflow/java/src/main/native/saved_model_bundle_jni.h b/tensorflow/java/src/main/native/saved_model_bundle_jni.h
index a4b05d0409..e8f28dd670 100644
--- a/tensorflow/java/src/main/native/saved_model_bundle_jni.h
+++ b/tensorflow/java/src/main/native/saved_model_bundle_jni.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_JAVA_SAVEDMODELBUNDLE_JNI_H_
-#define TENSORFLOW_JAVA_SAVEDMODELBUNDLE_JNI_H_
+#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_SAVED_MODEL_BUNDLE_JNI_H_
+#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_SAVED_MODEL_BUNDLE_JNI_H_
#include <jni.h>
@@ -34,4 +34,4 @@ JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load(
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
-#endif // TENSORFLOW_JAVA_SAVEDMODELBUNDLE_JNI_H_
+#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_SAVED_MODEL_BUNDLE_JNI_H_
diff --git a/tensorflow/java/src/main/native/session_jni.h b/tensorflow/java/src/main/native/session_jni.h
index 54c9c0aa4d..1cc196bdc8 100644
--- a/tensorflow/java/src/main/native/session_jni.h
+++ b/tensorflow/java/src/main/native/session_jni.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_JAVA_SESSION_JNI_H_
-#define TENSORFLOW_JAVA_SESSION_JNI_H_
+#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_SESSION_JNI_H_
+#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_SESSION_JNI_H_
#include <jni.h>
@@ -59,4 +59,4 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Session_run(
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
-#endif // TENSORFLOW_JAVA_SESSION_JNI_H_
+#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_SESSION_JNI_H_
diff --git a/tensorflow/java/src/main/native/tensor_jni.h b/tensorflow/java/src/main/native/tensor_jni.h
index a300936884..4cf682548e 100644
--- a/tensorflow/java/src/main/native/tensor_jni.h
+++ b/tensorflow/java/src/main/native/tensor_jni.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_JAVA_TENSOR_JNI_H_
-#define TENSORFLOW_JAVA_TENSOR_JNI_H_
+#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_TENSOR_JNI_H_
+#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_TENSOR_JNI_H_
#include <jni.h>
@@ -153,4 +153,4 @@ JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_readNDArray(JNIEnv *, jclass,
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
-#endif // TENSORFLOW_JAVA_TENSOR_JNI_H_
+#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_TENSOR_JNI_H_
diff --git a/tensorflow/java/src/main/native/tensorflow_jni.h b/tensorflow/java/src/main/native/tensorflow_jni.h
index c0c9322020..d7c44fb0e2 100644
--- a/tensorflow/java/src/main/native/tensorflow_jni.h
+++ b/tensorflow/java/src/main/native/tensorflow_jni.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_JAVA_TENSORFLOW_JNI_H_
-#define TENSORFLOW_JAVA_TENSORFLOW_JNI_H_
+#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_TENSORFLOW_JNI_H_
+#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_TENSORFLOW_JNI_H_
#include <jni.h>
@@ -67,4 +67,4 @@ Java_org_tensorflow_TensorFlow_libraryOpList(JNIEnv *, jclass, jlong);
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
-#endif // TENSORFLOW_JAVA_TENSORFLOW_JNI_H_
+#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_TENSORFLOW_JNI_H_
diff --git a/tensorflow/java/src/main/native/utils_jni.h b/tensorflow/java/src/main/native/utils_jni.h
index 352298e7de..d1e1b93878 100644
--- a/tensorflow/java/src/main/native/utils_jni.h
+++ b/tensorflow/java/src/main/native/utils_jni.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_JAVA_UTILS_JNI_H_
-#define TENSORFLOW_JAVA_UTILS_JNI_H_
+#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_UTILS_JNI_H_
+#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_UTILS_JNI_H_
#include <jni.h>
@@ -30,4 +30,4 @@ void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op,
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
-#endif /* TENSORFLOW_JAVA_UTILS_JNI_H_ */
+#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_UTILS_JNI_H_
diff --git a/tensorflow/js/BUILD b/tensorflow/js/BUILD
new file mode 100644
index 0000000000..ad0dc44f54
--- /dev/null
+++ b/tensorflow/js/BUILD
@@ -0,0 +1,52 @@
+# Description:
+# JavaScript/TypeScript code generation for TensorFlow.js
+
+visibility = [
+ "//tensorflow:internal",
+]
+
+package(default_visibility = visibility)
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
+cc_library(
+ name = "ts_op_gen",
+ srcs = [
+ "ops/ts_op_gen.cc",
+ ],
+ hdrs = [
+ "ops/ts_op_gen.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:op_gen_lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+tf_cc_test(
+ name = "ts_op_gen_test",
+ srcs = [
+ "ops/ts_op_gen.cc",
+ "ops/ts_op_gen.h",
+ "ops/ts_op_gen_test.cc",
+ ],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:op_gen_lib",
+ "//tensorflow/core:proto_text",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/js/ops/ts_op_gen.cc b/tensorflow/js/ops/ts_op_gen.cc
new file mode 100644
index 0000000000..fb93bb6d8e
--- /dev/null
+++ b/tensorflow/js/ops/ts_op_gen.cc
@@ -0,0 +1,290 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/js/ops/ts_op_gen.h"
+#include <unordered_map>
+
+#include "tensorflow/core/framework/api_def.pb.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+namespace {
+
+static bool IsListAttr(const OpDef_ArgDef& arg) {
+ return !arg.type_list_attr().empty() || !arg.number_attr().empty();
+}
+
+// Struct pairing an OpDef::ArgDef with its ApiDef::Arg for a given Op argument:
+struct ArgDefs {
+ ArgDefs(const OpDef::ArgDef& op_def_arg, const ApiDef::Arg& api_def_arg)
+ : op_def_arg(op_def_arg), api_def_arg(api_def_arg) {}
+
+ const OpDef::ArgDef& op_def_arg;
+ const ApiDef::Arg& api_def_arg;
+};
+
+// Struct pairing an OpDef::AttrDef with its ApiDef::Attr for an Op.
+struct OpAttrs {
+ OpAttrs(const OpDef::AttrDef& op_def_attr, const ApiDef::Attr& api_def_attr)
+ : op_def_attr(op_def_attr), api_def_attr(api_def_attr) {}
+
+ const OpDef::AttrDef& op_def_attr;
+ const ApiDef::Attr& api_def_attr;
+};
+
+// Helper class to generate TypeScript code for a given OpDef:
+class GenTypeScriptOp {
+ public:
+ GenTypeScriptOp(const OpDef& op_def, const ApiDef& api_def);
+ ~GenTypeScriptOp();
+
+ // Returns the generated code as a string:
+ string Code();
+
+ private:
+ void ProcessArgs();
+ void ProcessAttrs();
+ void AddAttrForArg(const string& attr, int arg_index);
+ string InputForAttr(const OpDef::AttrDef& op_def_attr);
+
+ void AddMethodSignature();
+ void AddOpAttrs();
+ void AddMethodReturnAndClose();
+
+ const OpDef& op_def_;
+ const ApiDef& api_def_;
+
+ // Placeholder string for all generated code:
+ string result_;
+
+ // Holds in-order vector of Op inputs:
+ std::vector<ArgDefs> input_op_args_;
+
+ // Holds in-order vector of Op attributes:
+ std::vector<OpAttrs> op_attrs_;
+
+ // Stores attributes-to-arguments by name:
+ typedef std::unordered_map<string, std::vector<int>> AttrArgIdxMap;
+ AttrArgIdxMap attr_arg_idx_map_;
+
+ // Holds number of outputs:
+ int num_outputs_;
+};
+
+GenTypeScriptOp::GenTypeScriptOp(const OpDef& op_def, const ApiDef& api_def)
+ : op_def_(op_def), api_def_(api_def), num_outputs_(0) {}
+
+GenTypeScriptOp::~GenTypeScriptOp() {}
+
+string GenTypeScriptOp::Code() {
+ ProcessArgs();
+ ProcessAttrs();
+
+ // Generate exported function for Op:
+ AddMethodSignature();
+ AddOpAttrs();
+ AddMethodReturnAndClose();
+
+ strings::StrAppend(&result_, "\n");
+ return result_;
+}
+
+void GenTypeScriptOp::ProcessArgs() {
+ for (int i = 0; i < api_def_.arg_order_size(); i++) {
+ auto op_def_arg = FindInputArg(api_def_.arg_order(i), op_def_);
+ if (op_def_arg == nullptr) {
+ LOG(WARNING) << "Could not find OpDef::ArgDef for "
+ << api_def_.arg_order(i);
+ continue;
+ }
+ auto api_def_arg = FindInputArg(api_def_.arg_order(i), api_def_);
+ if (api_def_arg == nullptr) {
+ LOG(WARNING) << "Could not find ApiDef::Arg for "
+ << api_def_.arg_order(i);
+ continue;
+ }
+
+ // Map attr names to arg indexes:
+ if (!op_def_arg->type_attr().empty()) {
+ AddAttrForArg(op_def_arg->type_attr(), i);
+ } else if (!op_def_arg->type_list_attr().empty()) {
+ AddAttrForArg(op_def_arg->type_list_attr(), i);
+ }
+ if (!op_def_arg->number_attr().empty()) {
+ AddAttrForArg(op_def_arg->number_attr(), i);
+ }
+
+ input_op_args_.push_back(ArgDefs(*op_def_arg, *api_def_arg));
+ }
+
+ num_outputs_ = api_def_.out_arg_size();
+}
+
+void GenTypeScriptOp::ProcessAttrs() {
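+  // Assumes op_def_ and api_def_ list attributes in the same order, as is
+  // the case for an ApiDefMap initialized from the corresponding OpList.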
+ for (int i = 0; i < op_def_.attr_size(); i++) {
+ op_attrs_.push_back(OpAttrs(op_def_.attr(i), api_def_.attr(i)));
+ }
+}
+
+void GenTypeScriptOp::AddAttrForArg(const string& attr, int arg_index) {
+ // Keep track of attributes-to-arguments by name. These will be used for
+  // constructing Op attributes that require information about the inputs.
+ auto iter = attr_arg_idx_map_.find(attr);
+ if (iter == attr_arg_idx_map_.end()) {
+ attr_arg_idx_map_.insert(AttrArgIdxMap::value_type(attr, {arg_index}));
+ } else {
+ iter->second.push_back(arg_index);
+ }
+}
+
+string GenTypeScriptOp::InputForAttr(const OpDef::AttrDef& op_def_attr) {
+ string inputs;
+ auto arg_list = attr_arg_idx_map_.find(op_def_attr.name());
+ if (arg_list != attr_arg_idx_map_.end()) {
+ for (auto iter = arg_list->second.begin(); iter != arg_list->second.end();
+ ++iter) {
+ strings::StrAppend(&inputs, input_op_args_[*iter].op_def_arg.name());
+ }
+ }
+ return inputs;
+}
+
+void GenTypeScriptOp::AddMethodSignature() {
+ strings::StrAppend(&result_, "export function ", api_def_.endpoint(0).name(),
+ "(");
+
+ bool is_first = true;
+ for (auto& in_arg : input_op_args_) {
+ if (is_first) {
+ is_first = false;
+ } else {
+ strings::StrAppend(&result_, ", ");
+ }
+
+ auto op_def_arg = in_arg.op_def_arg;
+
+ strings::StrAppend(&result_, op_def_arg.name(), ": ");
+ if (IsListAttr(op_def_arg)) {
+ strings::StrAppend(&result_, "tfc.Tensor[]");
+ } else {
+ strings::StrAppend(&result_, "tfc.Tensor");
+ }
+ }
+
+ if (num_outputs_ == 1) {
+ strings::StrAppend(&result_, "): tfc.Tensor {\n");
+ } else {
+ strings::StrAppend(&result_, "): tfc.Tensor[] {\n");
+ }
+}
+
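+// Renders the opAttrs array. Only "type" and "int" attributes are handled
+// below; any other attribute kind currently emits its separator and
+// indentation but no entry.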
+void GenTypeScriptOp::AddOpAttrs() {
+ strings::StrAppend(&result_, " const opAttrs = [\n");
+
+ bool is_first = true;
+ for (auto& attr : op_attrs_) {
+ if (is_first) {
+ is_first = false;
+ } else {
+ strings::StrAppend(&result_, ",\n");
+ }
+
+ // Append 4 spaces to start:
+ strings::StrAppend(&result_, " ");
+
+ if (attr.op_def_attr.type() == "type") {
+ // Type OpAttributes can be generated from a helper function:
+ strings::StrAppend(&result_, "createTensorsTypeOpAttr('",
+ attr.op_def_attr.name(), "', ",
+ InputForAttr(attr.op_def_attr), ")");
+ } else if (attr.op_def_attr.type() == "int") {
+ strings::StrAppend(&result_, "{name: '", attr.op_def_attr.name(), "', ");
+ strings::StrAppend(&result_, "type: nodeBackend().binding.TF_ATTR_INT, ");
+ strings::StrAppend(&result_, "value: ", InputForAttr(attr.op_def_attr),
+ ".length}");
+ }
+ }
+ strings::StrAppend(&result_, "\n ];\n");
+}
+
+void GenTypeScriptOp::AddMethodReturnAndClose() {
+ strings::StrAppend(&result_, " return null;\n}\n");
+}
+
+void WriteTSOp(const OpDef& op_def, const ApiDef& api_def, WritableFile* ts) {
+ GenTypeScriptOp ts_op(op_def, api_def);
+  TF_CHECK_OK(ts->Append(ts_op.Code()));
+}
+
+void StartFile(WritableFile* ts_file) {
+ const string header =
+ R"header(/**
+ * @license
+ * Copyright 2018 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+// This file is MACHINE GENERATED! Do not edit
+
+import * as tfc from '@tensorflow/tfjs-core';
+import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';
+
+)header";
+
+ TF_CHECK_OK(ts_file->Append(header));
+}
+
+} // namespace
+
+void WriteTSOps(const OpList& ops, const ApiDefMap& api_def_map,
+ const string& ts_filename) {
+ Env* env = Env::Default();
+
+ std::unique_ptr<WritableFile> ts_file = nullptr;
+ TF_CHECK_OK(env->NewWritableFile(ts_filename, &ts_file));
+
+ StartFile(ts_file.get());
+
+ for (const auto& op_def : ops.op()) {
+ // Skip deprecated ops
+ if (op_def.has_deprecation() &&
+ op_def.deprecation().version() <= TF_GRAPH_DEF_VERSION) {
+ continue;
+ }
+
+ const auto* api_def = api_def_map.GetApiDef(op_def.name());
+ if (api_def->visibility() == ApiDef::VISIBLE) {
+ WriteTSOp(op_def, *api_def, ts_file.get());
+ }
+ }
+
+ TF_CHECK_OK(ts_file->Close());
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/warn_about_ints.h b/tensorflow/js/ops/ts_op_gen.h
index 20666b230e..fcd46a17a7 100644
--- a/tensorflow/core/kernels/warn_about_ints.h
+++ b/tensorflow/js/ops/ts_op_gen.h
@@ -1,4 +1,4 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,17 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_WARN_ABOUT_INTS_H_
-#define TENSORFLOW_KERNELS_WARN_ABOUT_INTS_H_
+#ifndef TENSORFLOW_JS_OPS_TS_OP_GEN_H_
+#define TENSORFLOW_JS_OPS_TS_OP_GEN_H_
-#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_gen_lib.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
-// Warn if a kernel is being created using ints
-// TODO(irving): Remove in TF 2.0 along with the bad op registrations.
-void WarnAboutInts(OpKernelConstruction* context);
+// Generated code is written to the file ts_filename:
+void WriteTSOps(const OpList& ops, const ApiDefMap& api_def_map,
+ const string& ts_filename);
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_WARN_ABOUT_INTS_H_
+#endif // TENSORFLOW_JS_OPS_TS_OP_GEN_H_
diff --git a/tensorflow/js/ops/ts_op_gen_test.cc b/tensorflow/js/ops/ts_op_gen_test.cc
new file mode 100644
index 0000000000..03241689b5
--- /dev/null
+++ b/tensorflow/js/ops/ts_op_gen_test.cc
@@ -0,0 +1,246 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/js/ops/ts_op_gen.h"
+
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_gen_lib.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+void ExpectContainsStr(StringPiece s, StringPiece expected) {
+ EXPECT_TRUE(str_util::StrContains(s, expected))
+ << "'" << s << "' does not contain '" << expected << "'";
+}
+
+void ExpectDoesNotContainStr(StringPiece s, StringPiece expected) {
+ EXPECT_FALSE(str_util::StrContains(s, expected))
+      << "'" << s << "' unexpectedly contains '" << expected << "'";
+}
+
+constexpr char kBaseOpDef[] = R"(
+op {
+ name: "Foo"
+ input_arg {
+ name: "images"
+ type_attr: "T"
+ number_attr: "N"
+ description: "Images to process."
+ }
+ input_arg {
+ name: "dim"
+ description: "Description for dim."
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "output"
+ description: "Description for output."
+ type: DT_FLOAT
+ }
+ attr {
+ name: "T"
+ type: "type"
+ description: "Type for images"
+ allowed_values {
+ list {
+ type: DT_UINT8
+ type: DT_INT8
+ }
+ }
+ default_value {
+ i: 1
+ }
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+ summary: "Summary for op Foo."
+ description: "Description for op Foo."
+}
+)";
+
+// Generate TypeScript code
+void GenerateTsOpFileText(const string& op_def_str, const string& api_def_str,
+ string* ts_file_text) {
+ Env* env = Env::Default();
+ OpList op_defs;
+ protobuf::TextFormat::ParseFromString(
+ op_def_str.empty() ? kBaseOpDef : op_def_str, &op_defs);
+ ApiDefMap api_def_map(op_defs);
+
+ if (!api_def_str.empty()) {
+ TF_ASSERT_OK(api_def_map.LoadApiDef(api_def_str));
+ }
+
+ const string& tmpdir = testing::TmpDir();
+ const auto ts_file_path = io::JoinPath(tmpdir, "test.ts");
+
+ WriteTSOps(op_defs, api_def_map, ts_file_path);
+ TF_ASSERT_OK(ReadFileToString(env, ts_file_path, ts_file_text));
+}
+
+TEST(TsOpGenTest, TestImports) {
+ string ts_file_text;
+ GenerateTsOpFileText("", "", &ts_file_text);
+
+ const string expected = R"(
+import * as tfc from '@tensorflow/tfjs-core';
+import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';
+)";
+ ExpectContainsStr(ts_file_text, expected);
+}
+
+TEST(TsOpGenTest, InputSingleAndList) {
+ const string api_def = R"(
+op {
+ name: "Foo"
+ input_arg {
+ name: "images"
+ type_attr: "T"
+ number_attr: "N"
+ }
+}
+)";
+
+ string ts_file_text;
+ GenerateTsOpFileText("", api_def, &ts_file_text);
+
+ const string expected = R"(
+export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
+)";
+ ExpectContainsStr(ts_file_text, expected);
+}
+
+TEST(TsOpGenTest, TestVisibility) {
+ const string api_def = R"(
+op {
+ graph_op_name: "Foo"
+ visibility: HIDDEN
+}
+)";
+
+ string ts_file_text;
+ GenerateTsOpFileText("", api_def, &ts_file_text);
+
+ const string expected = R"(
+export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
+)";
+ ExpectDoesNotContainStr(ts_file_text, expected);
+}
+
+TEST(TsOpGenTest, SkipDeprecated) {
+ const string op_def = R"(
+op {
+ name: "DeprecatedFoo"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ description: "Description for input."
+ }
+ output_arg {
+ name: "output"
+ description: "Description for output."
+ type: DT_FLOAT
+ }
+ attr {
+ name: "T"
+ type: "type"
+ description: "Type for input"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ deprecation {
+ explanation: "Deprecated."
+ }
+}
+)";
+
+ string ts_file_text;
+ GenerateTsOpFileText(op_def, "", &ts_file_text);
+
+ ExpectDoesNotContainStr(ts_file_text, "DeprecatedFoo");
+}
+
+TEST(TsOpGenTest, MultiOutput) {
+ const string op_def = R"(
+op {
+ name: "MultiOutputFoo"
+ input_arg {
+ name: "input"
+ description: "Description for input."
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output1"
+ description: "Description for output 1."
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "output2"
+ description: "Description for output 2."
+ type: DT_FLOAT
+ }
+ attr {
+ name: "T"
+ type: "type"
+ description: "Type for input"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ summary: "Summary for op MultiOutputFoo."
+ description: "Description for op MultiOutputFoo."
+}
+)";
+
+ string ts_file_text;
+ GenerateTsOpFileText(op_def, "", &ts_file_text);
+
+ const string expected = R"(
+export function MultiOutputFoo(input: tfc.Tensor): tfc.Tensor[] {
+)";
+ ExpectContainsStr(ts_file_text, expected);
+}
+
+TEST(TsOpGenTest, OpAttrs) {
+ string ts_file_text;
+ GenerateTsOpFileText("", "", &ts_file_text);
+
+ const string expectedFooAttrs = R"(
+ const opAttrs = [
+ createTensorsTypeOpAttr('T', images),
+ {name: 'N', type: nodeBackend().binding.TF_ATTR_INT, value: images.length}
+ ];
+)";
+
+ ExpectContainsStr(ts_file_text, expectedFooAttrs);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 7c2e8b003b..19729813a1 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -44,6 +44,10 @@ load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_mpi_deps")
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_gdr_deps")
load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+load(
+ "//third_party/ngraph:build_defs.bzl",
+ "if_ngraph",
+)
py_library(
name = "python",
@@ -74,6 +78,7 @@ py_library(
"//tensorflow:__pkg__",
"//tensorflow/python/tools:__pkg__",
"//tensorflow/python/tools/api/generator:__pkg__",
+ "//tensorflow/tools/api/tests:__pkg__",
],
deps = [
":array_ops",
@@ -130,6 +135,7 @@ py_library(
"//tensorflow/core:protos_all_py",
"//tensorflow/python/compat",
"//tensorflow/python/data",
+ "//tensorflow/python/distribute:estimator_training",
"//tensorflow/python/feature_column:feature_column_py",
"//tensorflow/python/keras",
"//tensorflow/python/ops/distributions",
@@ -138,6 +144,8 @@ py_library(
"//tensorflow/python/ops/parallel_for",
"//tensorflow/python/profiler",
"//tensorflow/python/saved_model",
+ "//tensorflow/python/tools:component_api_helper",
+ "//tensorflow/python/tools/api/generator:create_python_api",
"//third_party/py/numpy",
],
)
@@ -717,7 +725,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":array_ops",
- ":cond_v2_impl",
":dtypes",
":framework_ops",
":graph_to_function_def",
@@ -1342,6 +1349,19 @@ py_test(
)
py_test(
+ name = "framework_ops_enable_eager_test",
+ size = "small",
+ srcs = ["framework/ops_enable_eager_test.py"],
+ main = "framework/ops_enable_eager_test.py",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":framework",
+ ":platform_test",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
+py_test(
name = "framework_tensor_shape_test",
size = "small",
srcs = ["framework/tensor_shape_test.py"],
@@ -2071,6 +2091,18 @@ py_library(
srcs = [
"ops/custom_gradient.py",
"ops/gradients.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":gradients_impl",
+ "//tensorflow/python/eager:function",
+ "//tensorflow/python/eager:tape",
+ ],
+)
+
+py_library(
+ name = "gradients_impl",
+ srcs = [
"ops/gradients_impl.py",
],
srcs_version = "PY2AND3",
@@ -2608,6 +2640,19 @@ py_library(
],
)
+py_test(
+ name = "sparse_ops_test",
+ srcs = ["ops/sparse_ops_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":constant_op",
+ ":dtypes",
+ ":framework_test_lib",
+ ":sparse_ops",
+ ":sparse_tensor",
+ ],
+)
+
py_library(
name = "spectral_grad",
srcs = ["ops/spectral_grad.py"],
@@ -2779,11 +2824,13 @@ py_library(
srcs = ["ops/state_ops.py"],
srcs_version = "PY2AND3",
deps = [
+ ":array_ops",
":framework_ops",
+ ":math_ops_gen",
":resource_variable_ops_gen",
":state_ops_gen",
":tensor_shape",
- "//tensorflow/python/eager:context",
+ ":util",
],
)
@@ -3226,7 +3273,6 @@ py_library(
),
srcs_version = "PY2AND3",
deps = [
- "saver",
":array_ops",
":array_ops_gen",
":checkpoint_management",
@@ -3250,6 +3296,7 @@ py_library(
":random_ops",
":resource_variable_ops",
":resources",
+ ":saver",
":sdca_ops",
":session",
":sparse_ops",
@@ -3265,6 +3312,7 @@ py_library(
"@six_archive//:six",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/distribute:distribute_coordinator_context",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
# `layers` dependency only exists due to the use of a small utility.
@@ -3742,6 +3790,7 @@ tf_py_wrap_cc(
"framework/python_op_gen.i",
"grappler/cluster.i",
"grappler/cost_analyzer.i",
+ "grappler/graph_analyzer.i",
"grappler/item.i",
"grappler/model_analyzer.i",
"grappler/tf_optimizer.i",
@@ -3800,6 +3849,7 @@ tf_py_wrap_cc(
"//tensorflow/core/grappler/clusters:single_machine",
"//tensorflow/core/grappler/clusters:virtual_cluster",
"//tensorflow/core/grappler/costs:graph_memory",
+ "//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool",
"//tensorflow/core/grappler/optimizers:meta_optimizer",
"//tensorflow/core:lib",
"//tensorflow/core:reader_base",
@@ -3812,7 +3862,9 @@ tf_py_wrap_cc(
tf_additional_plugin_deps() +
tf_additional_verbs_deps() +
tf_additional_mpi_deps() +
- tf_additional_gdr_deps()),
+ tf_additional_gdr_deps()) + if_ngraph([
+ "@ngraph_tf//:ngraph_tf",
+ ]),
)
# ** Targets for Windows build (start) **
@@ -4342,6 +4394,7 @@ cuda_py_tests(
"training/ftrl_test.py",
"training/gradient_descent_test.py",
"training/learning_rate_decay_test.py",
+ "training/learning_rate_decay_v2_test.py",
"training/momentum_test.py",
"training/optimizer_test.py",
"training/proximal_adagrad_test.py",
@@ -4658,7 +4711,10 @@ py_test(
size = "medium",
srcs = ["training/monitored_session_test.py"],
srcs_version = "PY2AND3",
- tags = ["notsan"], # b/67945581
+ tags = [
+ "no_pip",
+ "notsan", # b/67945581
+ ],
deps = [
":array_ops",
":checkpoint_management",
@@ -4676,6 +4732,7 @@ py_test(
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/testing:testing_py",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python/distribute:distribute_coordinator",
],
)
@@ -5495,6 +5552,18 @@ py_test(
],
)
+py_binary(
+ name = "graph_analyzer",
+ srcs = [
+ "grappler/graph_analyzer.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":framework_for_generated_wrappers",
+ ":pywrap_tensorflow_internal",
+ ],
+)
+
pyx_library(
name = "framework_fast_tensor_util",
srcs = ["framework/fast_tensor_util.pyx"],
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index a2ab63bb48..4921ecc43c 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -48,6 +48,13 @@ import numpy as np
from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.tools import component_api_helper
+component_api_helper.package_hook(
+ parent_package_str='tensorflow.python',
+ child_package_str=(
+ 'tensorflow_estimator.python.estimator'))
+del component_api_helper
+
# Protocol buffers
from tensorflow.core.framework.graph_pb2 import *
from tensorflow.core.framework.node_def_pb2 import *
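The `package_hook` call above grafts the externally built `tensorflow_estimator` package into the `tensorflow.python` namespace at import time. A rough, hypothetical sketch of the effect (the real helper in `tensorflow/python/tools/component_api_helper.py` also guards against partially installed packages):

    import importlib
    import sys

    def package_hook(parent_package_str, child_package_str):
      try:
        child = importlib.import_module(child_package_str)
      except ImportError:
        return  # component not installed; leave the parent namespace untouched
      parent = sys.modules[parent_package_str]
      # Expose e.g. tensorflow_estimator.python.estimator as an attribute of
      # tensorflow.python, so estimator lookups resolve through the parent.
      setattr(parent, child_package_str.split(".")[-1], child)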
diff --git a/tensorflow/python/client/client_lib.py b/tensorflow/python/client/client_lib.py
index c94767a03c..80a256bf7a 100644
--- a/tensorflow/python/client/client_lib.py
+++ b/tensorflow/python/client/client_lib.py
@@ -15,7 +15,7 @@
"""Support for launching graphs and executing operations.
-See the @{$python/client} guide.
+See the [Client](https://tensorflow.org/api_guides/python/client) guide.
"""
from __future__ import absolute_import
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 28f26ad27e..ae0ad27f15 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -1132,7 +1132,7 @@ class BaseSession(SessionInterface):
for details of the allowable fetch types.
feed_list: (Optional.) A list of `feed_dict` keys. See
`tf.Session.run` for details of the allowable feed key types.
- accept_options: (Optional.) Iff `True`, the returned `Callable` will be
+ accept_options: (Optional.) If `True`, the returned `Callable` will be
able to accept `tf.RunOptions` and `tf.RunMetadata` as optional
keyword arguments `options` and `run_metadata`, respectively, with
the same syntax and semantics as `tf.Session.run`, which is useful
@@ -1302,9 +1302,7 @@ class BaseSession(SessionInterface):
node_def = op.node_def
except KeyError:
pass
- if (self._config is not None and
- self._config.experimental.client_handles_error_formatting):
- message = error_interpolation.interpolate(message, self._graph)
+ message = error_interpolation.interpolate(message, self._graph)
raise type(e)(node_def, op, message)
def _extend_graph(self):
@@ -1500,7 +1498,7 @@ class Session(BaseSession):
Args:
target: (Optional.) The execution engine to connect to.
Defaults to using an in-process engine. See
- @{$distributed$Distributed TensorFlow}
+ [Distributed TensorFlow](https://tensorflow.org/deploy/distributed)
for more examples.
graph: (Optional.) The `Graph` to be launched (described above).
config: (Optional.) A
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 052be68385..4afc6399d5 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -49,6 +49,8 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_control_flow_ops
+# Import gradients to resolve circular imports
+from tensorflow.python.ops import gradients # pylint: disable=unused-import
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
# Import resource_variable_ops for the variables-to-tensor implicit conversion.
@@ -1760,7 +1762,7 @@ class SessionTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
session.register_session_run_conversion_functions(SquaredTensor, fetch_fn,
feed_fn1, feed_fn2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
np1 = np.array([1.0, 1.5, 2.0, 2.5])
np2 = np.array([3.0, 3.5, 4.0, 4.5])
squared_tensor = SquaredTensor(np2)
@@ -1920,7 +1922,7 @@ class SessionTest(test_util.TensorFlowTestCase):
pass
def testAutoConvertAndCheckData(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = array_ops.placeholder(dtype=dtypes.string)
with self.assertRaisesRegexp(
TypeError, 'Type of feed value 1 with type <(\w+) \'int\'> is not'):
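The `test_session()` to `cached_session()` substitution recurs through the rest of this commit. Both helpers come from `TensorFlowTestCase`: `cached_session()` lazily creates one session and reuses it across `with` blocks in the same test, while `session(graph=g)` remains in use where a specific graph must be installed. A small usage sketch of the reuse semantics these tests rely on:

    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import test_util
    from tensorflow.python.platform import test

    class CachedSessionExample(test_util.TensorFlowTestCase):

      def testReuse(self):
        with self.cached_session() as sess1:
          self.assertEqual(1, sess1.run(constant_op.constant(1)))
        with self.cached_session() as sess2:  # same underlying session
          self.assertEqual(2, sess2.run(constant_op.constant(2)))

    if __name__ == "__main__":
      test.main()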
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 8d7063ff6e..60ebae19ab 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -14,8 +14,8 @@
# ==============================================================================
"""Utilities for API compatibility between TensorFlow release versions.
-See
-@{$guide/version_compat#backward_and_partial_forward_compatibility}
+See [Version
+Compatibility](https://tensorflow.org/guide/version_compat#backward_forward)
"""
from __future__ import absolute_import
@@ -26,14 +26,15 @@ import datetime
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 8, 16)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 11)
@tf_export("compat.forward_compatible")
def forward_compatible(year, month, day):
"""Return true if the forward compatibility window has expired.
- See @{$guide/version_compat#backward_and_partial_forward_compatibility}.
+ See [Version
+ compatibility](https://tensorflow.org/guide/version_compat#backward_forward).
Forward-compatibility refers to scenarios where the producer of a TensorFlow
model (a GraphDef or SavedModel) is compiled against a version of the
@@ -91,7 +92,8 @@ def forward_compatible(year, month, day):
def forward_compatibility_horizon(year, month, day):
"""Context manager for testing forward compatibility of generated graphs.
- See @{$guide/version_compat#backward_and_partial_forward_compatibility}.
+ See [Version
+ compatibility](https://tensorflow.org/guide/version_compat#backward_forward).
To ensure forward compatibility of generated graphs (see `forward_compatible`)
with older binaries, new features can be gated with:
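A short usage sketch of the gating pattern this docstring describes; `new_path` and `old_path` are placeholder names, not real ops:

    from tensorflow.python.compat import compat

    def build(x):
      if compat.forward_compatible(2018, 9, 11):
        return new_path(x)  # placeholder: emits the new graph construct
      return old_path(x)    # placeholder: the backwards-compatible fallback

    # In tests, the post-horizon branch can be forced:
    with compat.forward_compatibility_horizon(2018, 9, 12):
      pass  # graphs built here behave as if the horizon has passed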
diff --git a/tensorflow/python/data/__init__.py b/tensorflow/python/data/__init__.py
index 3b9bf2469e..f8b561205e 100644
--- a/tensorflow/python/data/__init__.py
+++ b/tensorflow/python/data/__init__.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""`tf.data.Dataset` API for input pipelines.
-See @{$guide/datasets$Importing Data} for an overview.
+See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
"""
from __future__ import absolute_import
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 23c98247bf..631b87a718 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -137,6 +137,8 @@ tf_py_test(
size = "small",
srcs = ["interleave_dataset_op_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
@@ -154,6 +156,7 @@ tf_py_test(
size = "small",
srcs = ["map_dataset_op_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
index 89de55dd4f..c48708a2b9 100644
--- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
@@ -82,7 +82,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([[dim0] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -111,7 +111,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = (dataset_ops.Dataset.range(10).batch(0).make_one_shot_iterator())
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
@@ -131,7 +131,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(2):
actual = sess.run(get_next)
@@ -158,7 +158,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(2):
actual = sess.run(get_next)
@@ -188,7 +188,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
actual = sess.run(get_next)
expected = sparse_tensor.SparseTensorValue(
@@ -214,7 +214,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
.make_initializable_iterator())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -262,7 +262,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -307,7 +307,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
batch_size=4, padded_shapes=[5]).make_one_shot_iterator())
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.DataLossError):
sess.run(get_next)
@@ -318,7 +318,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
batch_size=4, padded_shapes=[-1]).make_one_shot_iterator())
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = sess.run(get_next)
self.assertAllEqual([[], [], [], []], result)
with self.assertRaises(errors.OutOfRangeError):
@@ -342,7 +342,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test with random sequence lengths, and max padding.
random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
sess.run(
@@ -381,7 +381,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
(tensor_shape.TensorShape([None]), tensor_shape.TensorShape([None])))
padded_dataset = dataset.padded_batch(
2, padded_shapes=([None], [None]), padding_values=('', 0))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
next_element = padded_dataset.make_one_shot_iterator().get_next()
sess.run(next_element)
diff --git a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
index 4f7fd3566e..d5f5b2fe05 100644
--- a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
@@ -68,7 +68,7 @@ class FileCacheDatasetTest(test.TestCase):
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# First run without caching to collect the "ground truth".
sess.run(init_fifo_op)
elements = []
@@ -132,7 +132,7 @@ class FileCacheDatasetTest(test.TestCase):
get_next1 = iterator1.get_next()
get_next2 = iterator2.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix})
sess.run(get_next1) # this should succeed
@@ -162,7 +162,7 @@ class FileCacheDatasetTest(test.TestCase):
get_next1 = iterator1.get_next()
get_next2 = iterator2.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix})
elements = []
@@ -217,7 +217,7 @@ class MemoryCacheDatasetTest(test.TestCase):
uncached_iterator = uncached_dataset.make_initializable_iterator()
uncached_next = uncached_iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(repeat_count.initializer)
sess.run(cached_iterator.initializer)
@@ -261,7 +261,7 @@ class MemoryCacheDatasetTest(test.TestCase):
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initialize with an empty upstream and a missing cache file (should
# throw errors.OutOfRangeError immediately).
sess.run(init_cache_op, feed_dict={count_placeholder: 0})
@@ -278,7 +278,7 @@ class MemoryCacheDatasetTest(test.TestCase):
i1 = d1.make_initializable_iterator()
i2 = d2.make_initializable_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(i1.initializer)
self.assertEqual(1, sess.run(i1.get_next()))
@@ -304,7 +304,7 @@ class MemoryCacheDatasetTest(test.TestCase):
expected_values = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i, expected in enumerate(expected_values):
self.assertEqual(expected, sess.run(n),
"Unexpected value at index %s" % i)
diff --git a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
index e16aa82d4d..5dfb84f28e 100644
--- a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
@@ -49,7 +49,7 @@ class ConcatenateDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(9):
result = sess.run(get_next)
@@ -83,7 +83,7 @@ class ConcatenateDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(9):
result = sess.run(get_next)
@@ -110,8 +110,24 @@ class ConcatenateDatasetTest(test.TestCase):
dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices(
to_concatenate_components)
- with self.assertRaisesRegexp(ValueError,
- "don't have the same number of elements"):
+ with self.assertRaisesRegexp(TypeError, "have different types"):
+ input_dataset.concatenate(dataset_to_concatenate)
+
+ def testConcatenateDatasetDifferentKeys(self):
+ input_components = {
+ "foo": np.array([[1], [2], [3], [4]]),
+ "bar": np.array([[12], [13], [14], [15]])
+ }
+ to_concatenate_components = {
+ "foo": np.array([[1], [2], [3], [4]]),
+ "baz": np.array([[5], [6], [7], [8]])
+ }
+
+ input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components)
+ dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices(
+ to_concatenate_components)
+
+ with self.assertRaisesRegexp(TypeError, "have different types"):
input_dataset.concatenate(dataset_to_concatenate)
def testConcatenateDatasetDifferentType(self):
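An illustration of the behavior the new `testConcatenateDatasetDifferentKeys` test pins down: a mismatch in dict structure now fails fast with `TypeError` rather than the old element-count error:

    from tensorflow.python.data.ops import dataset_ops

    ds_a = dataset_ops.Dataset.from_tensor_slices({"foo": [1], "bar": [2]})
    ds_b = dataset_ops.Dataset.from_tensor_slices({"foo": [3], "baz": [4]})
    ds_a.concatenate(ds_b)  # raises TypeError: ... have different types ...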
diff --git a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
index ea5b41e5d8..e43564a2eb 100644
--- a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
@@ -50,7 +50,7 @@ class DatasetConstructorTest(test.TestCase):
self.assertEqual([c.shape for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
results = sess.run(get_next)
for component, result_component in zip(components, results):
@@ -84,7 +84,7 @@ class DatasetConstructorTest(test.TestCase):
[tensor_shape.TensorShape(c.dense_shape) for c in components],
[shape for shape in iterator.output_shapes])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
results = sess.run(get_next)
for component, result_component in zip(components, results):
@@ -115,7 +115,7 @@ class DatasetConstructorTest(test.TestCase):
if sparse_tensor.is_sparse(c) else c.shape for c in components
], [shape for shape in iterator.output_shapes])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
results = sess.run(get_next)
for component, result_component in zip(components, results):
@@ -142,7 +142,7 @@ class DatasetConstructorTest(test.TestCase):
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(4):
results = sess.run(get_next)
@@ -172,7 +172,7 @@ class DatasetConstructorTest(test.TestCase):
[tensor_shape.TensorShape(c.dense_shape[1:]) for c in components],
[shape for shape in iterator.output_shapes])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
expected = [
(sparse_tensor.SparseTensorValue(
@@ -232,7 +232,7 @@ class DatasetConstructorTest(test.TestCase):
if sparse_tensor.is_sparse(c) else c.shape[1:] for c in components
], [shape for shape in iterator.output_shapes])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
expected = [
(sparse_tensor.SparseTensorValue(
@@ -283,7 +283,7 @@ class DatasetConstructorTest(test.TestCase):
self.assertEqual((), iterator.output_shapes["foo"])
self.assertEqual((1,), iterator.output_shapes["bar"])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(3):
results = sess.run(get_next)
@@ -300,7 +300,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = sparse_tensor.SparseTensor(*iterator.get_next())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []]
# Test with sparse tensor in the appropriate order.
diff --git a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
index fb55ae1400..cd0c1ddf1e 100644
--- a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
@@ -44,7 +44,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(2): # Run twice to test reinitialization.
sess.run(init_op)
for _ in range(num_repeats):
@@ -61,7 +61,7 @@ class DatasetConstructorTest(test.TestCase):
.make_one_shot_iterator())
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(num_repeats):
for elem in elem_sequence:
self.assertAllEqual(elem, sess.run(get_next))
@@ -131,7 +131,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(num_inner_repeats * num_outer_repeats):
for elem in input_list:
@@ -190,7 +190,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for elem in [0, 1]:
for _ in range(num_parallel_iterators):
@@ -213,7 +213,7 @@ class DatasetConstructorTest(test.TestCase):
self.assertEqual(dtype, get_next.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for expected in [[1], [2], [3]]:
next_val = sess.run(get_next)
@@ -234,7 +234,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for expected in [b"foo", b"bar", b"baz"]:
next_val = sess.run(get_next)
@@ -255,7 +255,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual([1, 2, 3], sess.run(get_next))
self.assertAllEqual([4, 5, 6], sess.run(get_next))
@@ -278,7 +278,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual([1, 2, 3], sess.run(get_next))
self.assertAllEqual([4, 5, 6], sess.run(get_next))
@@ -302,7 +302,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertEqual((1, 2), sess.run(get_next))
self.assertEqual((3, 4), sess.run(get_next))
@@ -327,7 +327,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual(1, sess.run(get_next))
self.assertAllEqual([2, 3], sess.run(get_next))
@@ -347,7 +347,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual(0, sess.run(get_next))
self.assertAllEqual(1, sess.run(get_next))
@@ -405,7 +405,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
expected = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4]
for x in expected:
@@ -434,7 +434,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
expected = [(0, b"Hi!"),
(0, b"Hi!"), (1, b"Hi!"),
@@ -468,7 +468,7 @@ class DatasetConstructorTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual(37, sess.run(get_next))
self.assertAllEqual(37, sess.run(get_next))
diff --git a/tensorflow/python/data/kernel_tests/dataset_ops_test.py b/tensorflow/python/data/kernel_tests/dataset_ops_test.py
index 2c4c11e132..239aa85175 100644
--- a/tensorflow/python/data/kernel_tests/dataset_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_ops_test.py
@@ -27,7 +27,7 @@ class DatasetOpsTest(test.TestCase):
def testAsSerializedGraph(self):
dataset = dataset_ops.Dataset.range(10)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
graph = graph_pb2.GraphDef().FromString(
sess.run(dataset._as_serialized_graph()))
self.assertTrue(any([node.op != "RangeDataset" for node in graph.node]))
diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
index 4f2216f0a3..19944d389f 100644
--- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
@@ -59,7 +59,7 @@ class FilterDatasetTest(test.TestCase):
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test that we can dynamically feed a different modulus value for each
# iterator.
def do_test(count_val, modulus_val):
@@ -84,7 +84,7 @@ class FilterDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(0, sess.run(get_next))
self.assertEqual(1, sess.run(get_next))
self.assertEqual(3, sess.run(get_next))
@@ -98,7 +98,7 @@ class FilterDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
if (i ** 2) % 2 == 0:
@@ -123,7 +123,7 @@ class FilterDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual(input_data[0], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
@@ -151,7 +151,7 @@ class FilterDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(5):
actual = sess.run(get_next)
@@ -169,7 +169,7 @@ class FilterDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
self.assertEqual((i, True), sess.run(get_next))
@@ -181,7 +181,7 @@ class FilterDatasetTest(test.TestCase):
lambda x: math_ops.equal(x % 2, 0))
iterators = [dataset.make_one_shot_iterator() for _ in range(10)]
next_elements = [iterator.get_next() for iterator in iterators]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual([0 for _ in range(10)], sess.run(next_elements))
diff --git a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
index 350234a839..1123cbff62 100644
--- a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
@@ -43,7 +43,7 @@ class FlatMapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in repeats:
for _ in range(i):
@@ -62,7 +62,7 @@ class FlatMapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for row in repeats:
for i in row:
@@ -113,7 +113,7 @@ class FlatMapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
for _ in range(i ** 2):
@@ -137,7 +137,7 @@ class FlatMapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
for j in range(2):
diff --git a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
index 7dbf7268d7..a35cee594a 100644
--- a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
@@ -19,8 +19,10 @@ from __future__ import print_function
import itertools
+from absl.testing import parameterized
+import numpy as np
+
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
@@ -28,7 +30,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
-class InterleaveDatasetTest(test.TestCase):
+class InterleaveDatasetTest(test.TestCase, parameterized.TestCase):
def _interleave(self, lists, cycle_length, block_length):
num_open = 0
@@ -97,84 +99,85 @@ class InterleaveDatasetTest(test.TestCase):
expected_elements, self._interleave(input_lists, 7, 2)):
self.assertEqual(expected, produced)
- def testInterleaveDataset(self):
- input_values = array_ops.placeholder(dtypes.int64, shape=[None])
- cycle_length = array_ops.placeholder(dtypes.int64, shape=[])
- block_length = array_ops.placeholder(dtypes.int64, shape=[])
-
- repeat_count = 2
-
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(input_values)
- .repeat(repeat_count)
- .interleave(lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
- cycle_length, block_length))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- next_element = iterator.get_next()
+ @parameterized.named_parameters(
+ ("1", np.int64([4, 5, 6]), 1, 3, None),
+ ("2", np.int64([4, 5, 6]), 1, 3, 1),
+ ("3", np.int64([4, 5, 6]), 2, 1, None),
+ ("4", np.int64([4, 5, 6]), 2, 1, 1),
+ ("5", np.int64([4, 5, 6]), 2, 1, 2),
+ ("6", np.int64([4, 5, 6]), 2, 3, None),
+ ("7", np.int64([4, 5, 6]), 2, 3, 1),
+ ("8", np.int64([4, 5, 6]), 2, 3, 2),
+ ("9", np.int64([4, 5, 6]), 7, 2, None),
+ ("10", np.int64([4, 5, 6]), 7, 2, 1),
+ ("11", np.int64([4, 5, 6]), 7, 2, 3),
+ ("12", np.int64([4, 5, 6]), 7, 2, 5),
+ ("13", np.int64([4, 5, 6]), 7, 2, 7),
+ ("14", np.int64([]), 2, 3, None),
+ ("15", np.int64([0, 0, 0]), 2, 3, None),
+ ("16", np.int64([4, 0, 6]), 2, 3, None),
+ ("17", np.int64([4, 0, 6]), 2, 3, 1),
+ ("18", np.int64([4, 0, 6]), 2, 3, 2),
+ )
+ def testInterleaveDataset(self, input_values, cycle_length, block_length,
+ num_parallel_calls):
+ count = 2
+ dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
+ count).interleave(
+ lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
+ cycle_length, block_length, num_parallel_calls)
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ def repeat(values, count):
+ result = []
+ for value in values:
+ result.append([value] * value)
+ return result * count
with self.test_session() as sess:
- # Cycle length 1 acts like `Dataset.flat_map()`.
- sess.run(init_op, feed_dict={input_values: [4, 5, 6],
- cycle_length: 1, block_length: 3})
-
- for expected_element in self._interleave(
- [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 1, 3):
- self.assertEqual(expected_element, sess.run(next_element))
-
- # Cycle length > 1.
- # expected: [4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5,
- # 6, 5, 6, 5, 6, 5, 6, 5]
- sess.run(init_op, feed_dict={input_values: [4, 5, 6],
- cycle_length: 2, block_length: 1})
for expected_element in self._interleave(
- [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 1):
- self.assertEqual(expected_element, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- # Cycle length > 1 and block length > 1.
- # expected: [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5,
- # 5, 5, 6, 6, 6, 5, 5, 6, 6, 6]
- sess.run(init_op, feed_dict={input_values: [4, 5, 6],
- cycle_length: 2, block_length: 3})
- for expected_element in self._interleave(
- [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 3):
- self.assertEqual(expected_element, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- # Cycle length > len(input_values) * repeat_count.
- # expected: [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4,
- # 5, 5, 6, 6, 5, 6, 6, 5, 6, 6]
- sess.run(init_op, feed_dict={input_values: [4, 5, 6],
- cycle_length: 7, block_length: 2})
- for expected_element in self._interleave(
- [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 7, 2):
- self.assertEqual(expected_element, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- # Empty input.
- sess.run(init_op, feed_dict={input_values: [],
- cycle_length: 2, block_length: 3})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ repeat(input_values, count), cycle_length, block_length):
+ self.assertEqual(expected_element, sess.run(get_next))
+
+ for _ in range(2):
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ @parameterized.named_parameters(
+ ("1", np.float32([1., np.nan, 2., np.nan, 3.]), 1, 3, None),
+ ("2", np.float32([1., np.nan, 2., np.nan, 3.]), 1, 3, 1),
+ ("3", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, None),
+ ("4", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, 1),
+ ("5", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, 2),
+ ("6", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, None),
+ ("7", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, 1),
+ ("8", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, 2),
+ ("9", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, None),
+ ("10", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 1),
+ ("11", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 3),
+ ("12", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 5),
+ ("13", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 7),
+ )
+ def testInterleaveErrorDataset(self,
+ input_values,
+ cycle_length,
+ block_length,
+ num_parallel_calls):
+ dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map(
+ lambda x: array_ops.check_numerics(x, "message")).interleave(
+ dataset_ops.Dataset.from_tensors, cycle_length, block_length,
+ num_parallel_calls)
+ get_next = dataset.make_one_shot_iterator().get_next()
- # Non-empty input leading to empty output.
- sess.run(init_op, feed_dict={input_values: [0, 0, 0],
- cycle_length: 2, block_length: 3})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- # Mixture of non-empty and empty interleaved datasets.
- sess.run(init_op, feed_dict={input_values: [4, 0, 6],
- cycle_length: 2, block_length: 3})
- for expected_element in self._interleave(
- [[4] * 4, [], [6] * 6] * repeat_count, 2, 3):
- self.assertEqual(expected_element, sess.run(next_element))
+ with self.test_session() as sess:
+ for value in input_values:
+ if np.isnan(value):
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+ else:
+ self.assertEqual(value, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ sess.run(get_next)
def testSparse(self):
@@ -201,20 +204,6 @@ class InterleaveDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testEmptyInput(self):
- iterator = (
- dataset_ops.Dataset.from_tensor_slices([])
- .repeat(None)
- .interleave(dataset_ops.Dataset.from_tensors, cycle_length=2)
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- sess.run(init_op)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
if __name__ == "__main__":
test.main()
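A worked example of the ordering `_interleave` models, using the same modules as the test above: each input value x expands to x copies of itself, and `cycle_length=2, block_length=1` draws one element from each open iterator in turn, so inputs [4, 5] yield [4, 5, 4, 5, 4, 5, 4, 5, 5]:

    import numpy as np
    from tensorflow.python.data.ops import dataset_ops

    dataset = dataset_ops.Dataset.from_tensor_slices(
        np.int64([4, 5])).interleave(
            lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
            cycle_length=2, block_length=1)
    # One-shot iteration over `dataset` produces: 4, 5, 4, 5, 4, 5, 4, 5, 5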
diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_test.py
index 352424514e..671e5d4812 100644
--- a/tensorflow/python/data/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/iterator_ops_test.py
@@ -91,7 +91,7 @@ class IteratorTest(test.TestCase):
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(14):
for i in range(7):
result = sess.run(get_next)
@@ -117,7 +117,7 @@ class IteratorTest(test.TestCase):
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(14):
for i in range(7):
result = sess.run(get_next)
@@ -208,7 +208,7 @@ class IteratorTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
sess.run(next_element)
@@ -216,7 +216,7 @@ class IteratorTest(test.TestCase):
with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
sess.run(next_element)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
def consumer_thread():
with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
@@ -287,7 +287,7 @@ class IteratorTest(test.TestCase):
.make_initializable_iterator())
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(errors.FailedPreconditionError,
"iterator has not been initialized"):
sess.run(get_next)
@@ -308,7 +308,7 @@ class IteratorTest(test.TestCase):
self.assertEqual(dataset_4.output_types, iterator.output_types)
self.assertEqual([None], iterator.output_shapes.as_list())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The iterator is initially uninitialized.
with self.assertRaises(errors.FailedPreconditionError):
sess.run(get_next)
@@ -380,7 +380,7 @@ class IteratorTest(test.TestCase):
self.assertEqual(dataset_4.output_types, feedable_iterator.output_types)
self.assertEqual([], feedable_iterator.output_shapes)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
iterator_3_handle = sess.run(iterator_3.string_handle())
iterator_4_handle = sess.run(iterator_4.string_handle())
@@ -436,7 +436,7 @@ class IteratorTest(test.TestCase):
self.assertEqual(dataset_4.output_types, feedable_iterator.output_types)
self.assertEqual([], feedable_iterator.output_shapes)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
iterator_3_handle = sess.run(iterator_3.string_handle())
iterator_4_handle = sess.run(iterator_4.string_handle())
@@ -524,7 +524,7 @@ class IteratorTest(test.TestCase):
feedable_int_any = iterator_ops.Iterator.from_string_handle(
handle_placeholder, dtypes.int32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
handle_int_scalar = sess.run(
dataset_int_scalar.make_one_shot_iterator().string_handle())
handle_float_vector = sess.run(
@@ -687,7 +687,7 @@ class IteratorTest(test.TestCase):
f=_remote_fn,
target=target_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
elem = sess.run(
remote_op,
feed_dict={
@@ -756,7 +756,7 @@ class IteratorTest(test.TestCase):
# Saving iterator for RangeDataset graph.
with ops.Graph().as_default() as g:
init_op, _, save_op, _ = _build_range_dataset_graph()
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(init_op)
sess.run(save_op)
@@ -767,7 +767,7 @@ class IteratorTest(test.TestCase):
# IteratorResource::set_iterator.
with ops.Graph().as_default() as g:
_, _, _, restore_op = _build_reader_dataset_graph()
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(restore_op)
@@ -803,16 +803,15 @@ class IteratorCheckpointingTest(test.TestCase):
get_next = iterator.get_next if context.executing_eagerly(
) else functools.partial(self.evaluate, iterator.get_next())
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
- with self.test_session() as sess:
- self.assertAllEqual([1, 4], get_next())
- save_path = checkpoint.save(checkpoint_prefix)
- self.assertAllEqual([9, 16], get_next())
- self.assertAllEqual([25, 36], get_next())
- checkpoint.restore(save_path).run_restore_ops(sess)
- self.assertAllEqual([9, 16], get_next())
- self.assertAllEqual([25, 36], get_next())
- with self.assertRaises(errors.OutOfRangeError):
- get_next()
+ self.assertAllEqual([1, 4], get_next())
+ save_path = checkpoint.save(checkpoint_prefix)
+ self.assertAllEqual([9, 16], get_next())
+ self.assertAllEqual([25, 36], get_next())
+ checkpoint.restore(save_path).run_restore_ops()
+ self.assertAllEqual([9, 16], get_next())
+ self.assertAllEqual([25, 36], get_next())
+ with self.assertRaises(errors.OutOfRangeError):
+ get_next()
@test_util.run_in_graph_and_eager_modes
def testSaveRestoreMultipleIterator(self):
@@ -833,19 +832,18 @@ class IteratorCheckpointingTest(test.TestCase):
) else functools.partial(self.evaluate, iterator_3.get_next())
checkpoint = checkpointable_utils.Checkpoint(
iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
- with self.test_session() as sess:
- self.assertAllEqual([1, 4], get_next_1())
- self.assertAllEqual(0, get_next_3())
- self.assertAllEqual(1, get_next_3())
- self.assertAllEqual(2, get_next_3())
- save_path = checkpoint.save(checkpoint_prefix)
- self.assertAllEqual([1, 4], get_next_2())
- self.assertAllEqual([9, 16], get_next_2())
- self.assertAllEqual(3, get_next_3())
- checkpoint.restore(save_path).run_restore_ops(sess)
- self.assertAllEqual([9, 16], get_next_1())
- self.assertAllEqual([1, 4], get_next_2())
- self.assertAllEqual(3, get_next_3())
+ self.assertAllEqual([1, 4], get_next_1())
+ self.assertAllEqual(0, get_next_3())
+ self.assertAllEqual(1, get_next_3())
+ self.assertAllEqual(2, get_next_3())
+ save_path = checkpoint.save(checkpoint_prefix)
+ self.assertAllEqual([1, 4], get_next_2())
+ self.assertAllEqual([9, 16], get_next_2())
+ self.assertAllEqual(3, get_next_3())
+ checkpoint.restore(save_path).run_restore_ops()
+ self.assertAllEqual([9, 16], get_next_1())
+ self.assertAllEqual([1, 4], get_next_2())
+ self.assertAllEqual(3, get_next_3())
@test_util.run_in_graph_and_eager_modes
def testRestoreExhaustedIterator(self):
@@ -856,17 +854,16 @@ class IteratorCheckpointingTest(test.TestCase):
get_next = iterator.get_next if context.executing_eagerly(
) else functools.partial(self.evaluate, iterator.get_next())
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
- with self.test_session() as sess:
- self.assertAllEqual(0, get_next())
- self.assertAllEqual(1, get_next())
- save_path = checkpoint.save(checkpoint_prefix)
- self.assertAllEqual(2, get_next())
- checkpoint.restore(save_path).run_restore_ops(sess)
- self.assertAllEqual(2, get_next())
- save_path = checkpoint.save(checkpoint_prefix)
- checkpoint.restore(save_path).run_restore_ops(sess)
- with self.assertRaises(errors.OutOfRangeError):
- get_next()
+ self.assertAllEqual(0, get_next())
+ self.assertAllEqual(1, get_next())
+ save_path = checkpoint.save(checkpoint_prefix)
+ self.assertAllEqual(2, get_next())
+ checkpoint.restore(save_path).run_restore_ops()
+ self.assertAllEqual(2, get_next())
+ save_path = checkpoint.save(checkpoint_prefix)
+ checkpoint.restore(save_path).run_restore_ops()
+ with self.assertRaises(errors.OutOfRangeError):
+ get_next()
def testRestoreInReconstructedIteratorInitializable(self):
checkpoint_directory = self.get_temp_dir()
@@ -876,7 +873,7 @@ class IteratorCheckpointingTest(test.TestCase):
get_next = iterator.get_next()
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
for i in range(5):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
checkpoint.restore(checkpoint_management.latest_checkpoint(
checkpoint_directory)).initialize_or_restore(sess)
for j in range(2):
diff --git a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
index 579096f880..c4b338a58f 100644
--- a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
@@ -44,7 +44,7 @@ class ListFilesDatasetOpTest(test.TestCase):
def testEmptyDirectory(self):
dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
itr = dataset.make_one_shot_iterator()
next_element = itr.get_next()
with self.assertRaises(errors.OutOfRangeError):
@@ -55,7 +55,7 @@ class ListFilesDatasetOpTest(test.TestCase):
self._touchTempFiles(filenames)
dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
itr = dataset.make_one_shot_iterator()
next_element = itr.get_next()
@@ -75,7 +75,7 @@ class ListFilesDatasetOpTest(test.TestCase):
dataset = dataset_ops.Dataset.list_files(
path.join(self.tmp_dir, '*'), shuffle=False)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
itr = dataset.make_one_shot_iterator()
next_element = itr.get_next()
@@ -91,7 +91,7 @@ class ListFilesDatasetOpTest(test.TestCase):
dataset = dataset_ops.Dataset.list_files(
path.join(self.tmp_dir, '*'), shuffle=True, seed=37)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
itr = dataset.make_initializable_iterator()
next_element = itr.get_next()
@@ -121,7 +121,7 @@ class ListFilesDatasetOpTest(test.TestCase):
filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
dataset = dataset_ops.Dataset.list_files(filename_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
itr = dataset.make_initializable_iterator()
with self.assertRaisesRegexp(
errors.InvalidArgumentError, 'No files matched pattern: '):
@@ -136,7 +136,7 @@ class ListFilesDatasetOpTest(test.TestCase):
filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
dataset = dataset_ops.Dataset.list_files(filename_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
itr = dataset.make_initializable_iterator()
next_element = itr.get_next()
sess.run(
@@ -162,7 +162,7 @@ class ListFilesDatasetOpTest(test.TestCase):
filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
dataset = dataset_ops.Dataset.list_files(filename_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
itr = dataset.make_initializable_iterator()
next_element = itr.get_next()
sess.run(
@@ -187,7 +187,7 @@ class ListFilesDatasetOpTest(test.TestCase):
filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
dataset = dataset_ops.Dataset.list_files(filename_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
itr = dataset.make_initializable_iterator()
next_element = itr.get_next()
sess.run(
@@ -221,7 +221,7 @@ class ListFilesDatasetOpTest(test.TestCase):
# more meaningful.
dataset = dataset_ops.Dataset.list_files(
path.join(self.tmp_dir, '*'), shuffle=False).repeat(2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
itr = dataset.make_one_shot_iterator()
next_element = itr.get_next()
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index 637bde9ae4..7685d8dbdc 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -22,8 +22,10 @@ import threading
import time
import warnings
+from absl.testing import parameterized
import numpy as np
+from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
@@ -31,6 +33,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
@@ -44,7 +47,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
-class MapDatasetTest(test.TestCase):
+class MapDatasetTest(test.TestCase, parameterized.TestCase):
def _buildMapDataset(self, components, count):
def _map_fn(x, y, z):
@@ -69,7 +72,7 @@ class MapDatasetTest(test.TestCase):
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test single-threaded access to the iterator.
sess.run(init_op, feed_dict={count: 14})
for _ in range(14):
@@ -135,7 +138,8 @@ class MapDatasetTest(test.TestCase):
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
+
def do_test(num_parallel_calls_val, output_buffer_size_val):
# Test single-threaded access to the iterator.
sess.run(init_op, feed_dict={
@@ -200,7 +204,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(3):
sess.run(get_next)
@@ -215,7 +219,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(3):
sess.run(get_next)
@@ -230,7 +234,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(3):
sess.run(get_next)
@@ -251,7 +255,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(3):
sess.run(get_next)
@@ -282,7 +286,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(table.init)
sess.run(init_op)
sess.run(get_next)
@@ -300,7 +304,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(enqueue_op)
sess.run(close_op)
sess.run(init_op)
@@ -325,7 +329,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(enqueue_op)
sess.run(close_op)
sess.run(init_op)
@@ -344,7 +348,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(counter_var.initializer)
sess.run(init_op)
for i in range(10):
@@ -364,7 +368,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.NotFoundError):
sess.run(get_next)
@@ -376,7 +380,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
random_values = []
with self.assertRaises(errors.OutOfRangeError):
@@ -401,7 +405,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
self.assertEqual(i * 2 + i ** 2, sess.run(get_next))
@@ -433,7 +437,7 @@ class MapDatasetTest(test.TestCase):
next_namedtuple = dataset_namedtuple.make_one_shot_iterator().get_next()
# make sure both datasets contain the same data
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(count):
tuple_, namedtuple_ = sess.run([next_tuple, next_namedtuple])
self.assertEqual(tuple_, namedtuple_)
@@ -451,7 +455,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual(row ** 2, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
@@ -482,7 +486,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Simple test that prefetch yields the expected values in the
# expected order.
for buffer_size in [1, 10, 100, 1000]:
@@ -520,7 +524,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
self.assertEqual((i, 37.0), sess.run(get_next))
@@ -541,7 +545,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
self.assertEqual((i, 37.0), sess.run(get_next))
@@ -567,7 +571,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
actual = sess.run(get_next)
@@ -594,7 +598,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
actual = sess.run(get_next)
@@ -618,7 +622,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(100):
self.assertEqual(i, sess.run(get_next))
@@ -632,7 +636,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
self.assertEqual((i, b"hello", 10), sess.run(get_next))
@@ -673,63 +677,139 @@ class MapDatasetTest(test.TestCase):
r"Dataset.map\(\): None."):
_ = dataset.map(lambda x: None)
+ def testBrokenFunctionErrorOnInitialization(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0, 3.0])
+
+ def broken_function(_):
+ """A function deliberately designed to fail on instantiation."""
+ value = []
+ tensor_value = attr_value_pb2.AttrValue()
+ tensor_value.tensor.CopyFrom(
+ tensor_util.make_tensor_proto(
+ value, dtype=dtypes.float32, shape=[0], verify_shape=False))
+ dtype_value = attr_value_pb2.AttrValue(type=dtypes.int32.as_datatype_enum)
+
+ # Create a "Const" op with a `tf.float32` value and a `tf.int32` type
+ # attr.
+ const_tensor = ops.get_default_graph().create_op(
+ "Const", [], [dtypes.int32],
+ attrs={
+ "value": tensor_value,
+ "dtype": dtype_value
+ },
+ name="BrokenConst").outputs[0]
+ return const_tensor
+
+ dataset = dataset.map(broken_function)
+ iterator = dataset.make_initializable_iterator()
+
+ with self.cached_session() as sess:
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "BrokenConst"):
+ sess.run(iterator.initializer)
+
+# pylint: disable=g-long-lambda
+ @parameterized.named_parameters(
+ ("Map", lambda dataset, func:
+ dataset_ops.MapDataset(dataset, func, use_inter_op_parallelism=False)),
+ ("ParallelMap", lambda dataset, func:
+ dataset_ops.ParallelMapDataset(dataset, func, num_parallel_calls=1,
+ use_inter_op_parallelism=False)),
+ )
+ def testNoInterOpParallelism(self, make_dataset_fn):
+ dataset = dataset_ops.Dataset.from_tensors(0)
+
+ def _get_tid():
+ return np.int64(threading.current_thread().ident)
+
+ def _map_fn(_):
+ tids = []
+ for _ in range(10):
+ tids.append(script_ops.py_func(_get_tid, [], dtypes.int64))
+ return tids
+
+ dataset = make_dataset_fn(dataset, _map_fn)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ tids = sess.run(get_next)
+ self.assertTrue(all(tids[0] == tid for tid in tids))
+# pylint: enable=g-long-lambda
+
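A minimal usage sketch of the flag these tests exercise, assuming the internal
`dataset_ops` constructors shown in this diff (they are not exposed through the
public `Dataset.map()` API):

    from tensorflow.python.data.ops import dataset_ops

    dataset = dataset_ops.Dataset.from_tensors(0)
    # With use_inter_op_parallelism=False, the ops inside the map function run
    # on the calling thread, which is what the py_func thread-id check above
    # verifies.
    dataset = dataset_ops.MapDataset(
        dataset, lambda x: x + 1, use_inter_op_parallelism=False)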
class MapDatasetBenchmark(test.Benchmark):
def benchmarkChainOfMaps(self):
chain_lengths = [0, 1, 2, 5, 10, 20, 50]
for chain_length in chain_lengths:
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
- for _ in range(chain_length):
- dataset = dataset.map(lambda x: x)
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for _ in range(5):
- sess.run(next_element.op)
- deltas = []
- for _ in range(100):
- start = time.time()
- for _ in range(100):
+ for use_inter_op_parallelism in [False, True]:
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
+ for _ in range(chain_length):
+ dataset = dataset_ops.MapDataset(
+ dataset,
+ lambda x: x,
+ use_inter_op_parallelism=use_inter_op_parallelism)
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for _ in range(5):
sess.run(next_element.op)
- end = time.time()
- deltas.append(end - start)
-
- median_wall_time = np.median(deltas) / 100
- print("Map dataset chain length: %d Median wall time: %f"
- % (chain_length, median_wall_time))
- self.report_benchmark(
- iters=1000, wall_time=median_wall_time,
- name="benchmark_map_dataset_chain_latency_%d" % chain_length)
+ deltas = []
+ for _ in range(100):
+ start = time.time()
+ for _ in range(100):
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ median_wall_time = np.median(deltas) / 100
+ print("Map dataset chain length%s: %d Median wall time: %f" %
+ (" (single threaded mode)" if not use_inter_op_parallelism
+ else "", chain_length, median_wall_time))
+ self.report_benchmark(
+ iters=1000,
+ wall_time=median_wall_time,
+ name="benchmark_map_dataset_chain_latency_%d%s" %
+ (chain_length, "_single_threaded"
+ if not use_inter_op_parallelism else ""))
def benchmarkMapFanOut(self):
fan_outs = [1, 2, 5, 10, 20, 50, 100]
for fan_out in fan_outs:
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors(
- tuple(0 for _ in range(fan_out))).repeat(None).map(lambda *xs: xs)
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for _ in range(5):
- sess.run(next_element[0].op)
- deltas = []
- for _ in range(100):
- start = time.time()
- for _ in range(100):
+ for use_inter_op_parallelism in [False, True]:
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors(
+ tuple(0 for _ in range(fan_out))).repeat(None)
+ dataset = dataset_ops.MapDataset(
+ dataset,
+ lambda *xs: xs,
+ use_inter_op_parallelism=use_inter_op_parallelism)
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for _ in range(5):
sess.run(next_element[0].op)
- end = time.time()
- deltas.append(end - start)
-
- median_wall_time = np.median(deltas) / 100
- print("Map dataset fan out: %d Median wall time: %f"
- % (fan_out, median_wall_time))
- self.report_benchmark(
- iters=1000, wall_time=median_wall_time,
- name="benchmark_map_dataset_fan_out_%d" % fan_out)
+ deltas = []
+ for _ in range(100):
+ start = time.time()
+ for _ in range(100):
+ sess.run(next_element[0].op)
+ end = time.time()
+ deltas.append(end - start)
+
+ median_wall_time = np.median(deltas) / 100
+ print("Map dataset fan out%s: %d Median wall time: %f" %
+ (" (single threaded mode)" if not use_inter_op_parallelism
+ else "", fan_out, median_wall_time))
+ self.report_benchmark(
+ iters=1000,
+ wall_time=median_wall_time,
+ name="benchmark_map_dataset_fan_out_%d%s" %
+ (fan_out, "_single_threaded"
+ if not use_inter_op_parallelism else ""))
if __name__ == "__main__":
diff --git a/tensorflow/python/data/kernel_tests/optional_ops_test.py b/tensorflow/python/data/kernel_tests/optional_ops_test.py
index a32527af8d..c344513e71 100644
--- a/tensorflow/python/data/kernel_tests/optional_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/optional_ops_test.py
@@ -158,7 +158,7 @@ class OptionalTest(test.TestCase):
self.assertEqual(ds.output_classes, next_elem.output_classes)
elem_has_value_t = next_elem.has_value()
elem_value_t = next_elem.get_value()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Before initializing the iterator, evaluating the optional fails with
# a FailedPreconditionError.
with self.assertRaises(errors.FailedPreconditionError):
diff --git a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
index 63a0830272..cc97bac609 100644
--- a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
@@ -36,7 +36,7 @@ class PrefetchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, feed_dict={buffer_size_t: buffer_size})
for m in range(10):
self.assertEqual(m, sess.run(get_next))
@@ -51,7 +51,7 @@ class PrefetchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
with self.assertRaisesRegexp(errors.InvalidArgumentError, "buffer_size"):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, feed_dict={buffer_size_t: buffer_size})
diff --git a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
index 0c530522b8..51e90785e7 100644
--- a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
@@ -49,7 +49,7 @@ class RangeDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, feed_dict={stop: 5})
for i in range(5):
self.assertEqual(i, sess.run(get_next))
@@ -64,7 +64,7 @@ class RangeDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, feed_dict={start: 2, stop: 5})
for i in range(2, 5):
self.assertEqual(i, sess.run(get_next))
@@ -80,7 +80,7 @@ class RangeDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, feed_dict={start: 2, stop: 10, step: 2})
for i in range(2, 10, 2):
self.assertEqual(i, sess.run(get_next))
@@ -95,7 +95,7 @@ class RangeDatasetTest(test.TestCase):
step).make_initializable_iterator()
init_op = iterator.initializer
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={start: 2, stop: 10, step: 0})
@@ -108,7 +108,7 @@ class RangeDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, feed_dict={start: 2, stop: 10, step: -1})
# This for loop is a no-op but will ensure that the implementation is
# consistent with range if it ever changes.
@@ -125,7 +125,7 @@ class RangeDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, feed_dict={start: 10, stop: 2})
# This for loop is a no-op but will ensure that the implementation is
# consistent with range if it ever changes.
@@ -143,7 +143,7 @@ class RangeDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, feed_dict={start: 10, stop: 2, step: 2})
# This for loop is a no-op but will ensure that the implementation is
# consistent with range if it ever changes.
@@ -161,7 +161,7 @@ class RangeDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, feed_dict={start: 10, stop: 2, step: -1})
for i in range(10, 2, -1):
self.assertEqual(i, sess.run(get_next))
@@ -203,7 +203,7 @@ class RangeDatasetTest(test.TestCase):
break_point = 5
with ops.Graph().as_default() as g:
init_op, get_next, save_op, _ = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(variables.global_variables_initializer())
sess.run(init_op)
for i in range(start, break_point):
@@ -212,7 +212,7 @@ class RangeDatasetTest(test.TestCase):
with ops.Graph().as_default() as g:
init_op, get_next, _, restore_op = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(init_op)
sess.run(restore_op)
for i in range(break_point, stop):
@@ -223,7 +223,7 @@ class RangeDatasetTest(test.TestCase):
# Saving and restoring in same session.
with ops.Graph().as_default() as g:
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(variables.global_variables_initializer())
sess.run(init_op)
for i in range(start, break_point):
@@ -254,7 +254,7 @@ class RangeDatasetTest(test.TestCase):
break_epoch = 3
with ops.Graph().as_default() as g:
init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(variables.global_variables_initializer())
sess.run(init_op)
for _ in range(break_epoch):
@@ -272,7 +272,7 @@ class RangeDatasetTest(test.TestCase):
output_shapes)
restore_op = self._restore_op(iterator._iterator_resource)
get_next = iterator.get_next()
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(restore_op)
for i in range(break_point, stop):
self.assertEqual(i, sess.run(get_next))
@@ -300,7 +300,7 @@ class RangeDatasetTest(test.TestCase):
break_point = 5
with ops.Graph().as_default() as g:
init_op, get_next, save_op, _ = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(variables.global_variables_initializer())
sess.run(init_op)
for i in range(start, break_point):
@@ -311,7 +311,7 @@ class RangeDatasetTest(test.TestCase):
# Intentionally build a graph with a different value for stop to make sure
# the original dataset graph is actually getting loaded.
init_op, get_next, _, restore_op = _build_graph(start, stop_1)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(restore_op)
for i in range(break_point, stop):
self.assertEqual(i, sess.run(get_next))
@@ -338,7 +338,7 @@ class RangeDatasetTest(test.TestCase):
break_point = 5
with ops.Graph().as_default() as g:
init_op, get_next, save_op, _ = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(variables.global_variables_initializer())
sess.run(init_op)
for i in range(start, break_point):
@@ -347,7 +347,7 @@ class RangeDatasetTest(test.TestCase):
with ops.Graph().as_default() as g:
init_op, get_next, _, restore_op = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(init_op)
sess.run(restore_op)
for i in range(break_point, stop):
@@ -373,7 +373,7 @@ class RangeDatasetTest(test.TestCase):
with ops.Graph().as_default() as g:
init_op, get_next, save_op, _ = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(variables.global_variables_initializer())
sess.run(init_op)
for i in range(start, break_point1):
@@ -382,7 +382,7 @@ class RangeDatasetTest(test.TestCase):
with ops.Graph().as_default() as g:
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(restore_op)
for i in range(break_point1, break_point2):
self.assertEqual(i, sess.run(get_next))
@@ -391,7 +391,7 @@ class RangeDatasetTest(test.TestCase):
break_point2 = 7
with ops.Graph().as_default() as g:
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(restore_op)
for i in range(break_point2, stop):
self.assertEqual(i, sess.run(get_next))
@@ -417,7 +417,7 @@ class RangeDatasetTest(test.TestCase):
with ops.Graph().as_default() as g:
init_op, get_next, save_op, restore_op = _build_graph(
start, stop, num_epochs)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(variables.global_variables_initializer())
sess.run(init_op)
# Note: There is no checkpoint saved currently so a NotFoundError is
@@ -433,7 +433,7 @@ class RangeDatasetTest(test.TestCase):
with ops.Graph().as_default() as g:
init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(restore_op)
for i in range(break_range, stop):
self.assertEqual(i, sess.run(get_next))
@@ -460,7 +460,7 @@ class RangeDatasetTest(test.TestCase):
with ops.Graph().as_default() as g:
init_op, get_next, save_op, restore_op = _build_graph(
start, stop, num_epochs)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(variables.global_variables_initializer())
sess.run(init_op)
# Note: There is no checkpoint saved currently so a NotFoundError is
@@ -476,7 +476,7 @@ class RangeDatasetTest(test.TestCase):
with ops.Graph().as_default() as g:
init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(restore_op)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
diff --git a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
index e99f0a203b..aa3636364d 100644
--- a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
@@ -100,7 +100,7 @@ class TextLineDatasetTest(test.TestCase):
init_batch_op = iterator.make_initializer(batch_dataset)
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Basic test: read from file 0.
sess.run(
init_op, feed_dict={filenames: [test_filenames[0]],
@@ -163,7 +163,7 @@ class TextLineDatasetTest(test.TestCase):
repeat_dataset = readers.TextLineDataset(test_filenames, buffer_size=10)
iterator = repeat_dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for j in range(2):
for i in range(5):
self.assertEqual(self._lineText(j, i), sess.run(iterator.get_next()))
@@ -240,7 +240,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
init_batch_op = iterator.make_initializer(batch_dataset)
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Basic test: read from file 0.
sess.run(
init_op, feed_dict={filenames: [test_filenames[0]],
@@ -302,7 +302,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
buffer_size=10)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for j in range(self._num_files):
for i in range(self._num_records):
self.assertEqual(self._record(j, i), sess.run(iterator.get_next()))
@@ -319,7 +319,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
buffer_size=10)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r"Excluding the header \(5 bytes\) and footer \(2 bytes\), input "
@@ -374,7 +374,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
with ops.Graph().as_default() as g:
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(init_op)
# Note: There is no checkpoint saved currently so a NotFoundError is
# raised.
@@ -401,7 +401,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
with ops.Graph().as_default() as g:
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(restore_op)
for epoch in range(num_epochs):
for f in range(self._num_files):
@@ -427,7 +427,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
with ops.Graph().as_default() as g:
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(init_op)
# Note: There is no checkpoint saved currently so a NotFoundError is
# raised.
@@ -454,7 +454,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
with ops.Graph().as_default() as g:
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(init_op)
sess.run(restore_op)
for epoch in range(num_epochs):
@@ -479,7 +479,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
with ops.Graph().as_default() as g:
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(init_op)
# Note: There is no checkpoint saved currently so a NotFoundError is
# raised.
@@ -506,7 +506,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
with ops.Graph().as_default() as g:
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
num_epochs=num_epochs_1)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(restore_op)
for epoch in range(num_epochs):
for f in range(self._num_files):
@@ -529,7 +529,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
with ops.Graph().as_default() as g:
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(init_op)
# Note: There is no checkpoint saved currently so a NotFoundError is
# raised.
@@ -555,7 +555,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
with ops.Graph().as_default() as g:
restore_op, get_next_op = self._restore_iterator()
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(restore_op)
for epoch in range(num_epochs):
for f in range(self._num_files):
@@ -574,7 +574,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
with ops.Graph().as_default() as g:
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(init_op)
# Note: There is no checkpoint saved currently so a NotFoundError is
# raised.
@@ -585,7 +585,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
with ops.Graph().as_default() as g:
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(restore_op)
for _ in range(num_epochs * self._num_files * self._num_records):
sess.run(get_next_op)
@@ -598,7 +598,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
with ops.Graph().as_default() as g:
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(init_op)
# Note: There is no checkpoint saved currently so a NotFoundError is
# raised.
@@ -615,7 +615,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
with ops.Graph().as_default() as g:
init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(restore_op)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
@@ -661,7 +661,7 @@ class TFRecordDatasetTest(test.TestCase):
return filenames
def testReadOneEpoch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Basic test: read from file 0.
sess.run(
self.init_op,
@@ -698,7 +698,7 @@ class TFRecordDatasetTest(test.TestCase):
sess.run(self.get_next)
def testReadTenEpochs(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.init_op,
feed_dict={self.filenames: self.test_filenames,
@@ -711,7 +711,7 @@ class TFRecordDatasetTest(test.TestCase):
sess.run(self.get_next)
def testReadTenEpochsOfBatches(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.init_batch_op,
feed_dict={
@@ -738,7 +738,7 @@ class TFRecordDatasetTest(test.TestCase):
f.write(cdata)
zlib_files.append(zfn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.init_op,
feed_dict={self.filenames: zlib_files,
@@ -758,7 +758,7 @@ class TFRecordDatasetTest(test.TestCase):
gzf.write(f.read())
gzip_files.append(gzfn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.init_op,
feed_dict={self.filenames: gzip_files,
@@ -774,7 +774,7 @@ class TFRecordDatasetTest(test.TestCase):
d = readers.TFRecordDataset(self.test_filenames, buffer_size=one_mebibyte)
iterator = d.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for j in range(self._num_files):
for i in range(self._num_records):
self.assertAllEqual(self._record(j, i), sess.run(next_element))
@@ -786,7 +786,7 @@ class TFRecordDatasetTest(test.TestCase):
d = readers.TFRecordDataset(files)
iterator = d.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for j in range(self._num_files):
for i in range(self._num_records):
self.assertAllEqual(self._record(j, i), sess.run(next_element))
@@ -801,7 +801,7 @@ class TFRecordDatasetTest(test.TestCase):
next_element = iterator.get_next()
expected = []
actual = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(10):
for j in range(self._num_files):
for i in range(self._num_records):
diff --git a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
index 1d27b036eb..37e2333560 100644
--- a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
@@ -44,7 +44,7 @@ class SequenceDatasetTest(test.TestCase):
self.assertEqual([c.shape for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test a finite repetition.
sess.run(init_op, feed_dict={count_placeholder: 3})
for _ in range(3):
@@ -90,7 +90,7 @@ class SequenceDatasetTest(test.TestCase):
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Take fewer than input size
sess.run(init_op, feed_dict={count_placeholder: 4})
for i in range(4):
@@ -136,7 +136,7 @@ class SequenceDatasetTest(test.TestCase):
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Skip fewer than input size, we should skip
# the first 4 elements and then read the rest.
sess.run(init_op, feed_dict={count_placeholder: 4})
@@ -183,7 +183,7 @@ class SequenceDatasetTest(test.TestCase):
self.assertEqual([c.shape for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, feed_dict={inner_count: 7, outer_count: 14})
for _ in range(7 * 14):
results = sess.run(get_next)
@@ -199,7 +199,7 @@ class SequenceDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
diff --git a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
index cefe872d0f..137f6341ce 100644
--- a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
@@ -28,7 +28,7 @@ class ShardDatasetOpTest(test.TestCase):
dataset = dataset_ops.Dataset.range(10).shard(5, 2)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(2, sess.run(iterator.get_next()))
self.assertEqual(7, sess.run(iterator.get_next()))
with self.assertRaises(errors.OutOfRangeError):
@@ -40,7 +40,7 @@ class ShardDatasetOpTest(test.TestCase):
dataset = dataset_ops.Dataset.zip((dataset_a, dataset_b)).shard(5, 2)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual((2, 8), sess.run(iterator.get_next()))
self.assertEqual((7, 3), sess.run(iterator.get_next()))
with self.assertRaises(errors.OutOfRangeError):
@@ -50,7 +50,7 @@ class ShardDatasetOpTest(test.TestCase):
dataset = dataset_ops.Dataset.range(10).shard(5, 0)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(0, sess.run(iterator.get_next()))
self.assertEqual(5, sess.run(iterator.get_next()))
with self.assertRaises(errors.OutOfRangeError):
@@ -76,14 +76,14 @@ class ShardDatasetOpTest(test.TestCase):
dataset = dataset_ops.Dataset.range(1).shard(5, 2)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.OutOfRangeError):
sess.run(iterator.get_next())
def testLargerWorkerPool(self):
dataset = dataset_ops.Dataset.range(10).shard(7, 5)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(5, sess.run(iterator.get_next()))
with self.assertRaises(errors.OutOfRangeError):
sess.run(iterator.get_next())
@@ -91,7 +91,7 @@ class ShardDatasetOpTest(test.TestCase):
def testIndexEqualsNumShards(self):
dataset = dataset_ops.Dataset.range(10).shard(5, 4)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(4, sess.run(iterator.get_next()))
self.assertEqual(9, sess.run(iterator.get_next()))
with self.assertRaises(errors.OutOfRangeError):
@@ -100,7 +100,7 @@ class ShardDatasetOpTest(test.TestCase):
def testIndexEqualsNumShards2(self):
dataset = dataset_ops.Dataset.range(10).shard(4, 3)
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(3, sess.run(iterator.get_next()))
self.assertEqual(7, sess.run(iterator.get_next()))
with self.assertRaises(errors.OutOfRangeError):
diff --git a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
index 5fcc48831f..f294840706 100644
--- a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
@@ -60,7 +60,7 @@ class ShuffleDatasetTest(test.TestCase):
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# First run without shuffling to collect the "ground truth".
sess.run(init_fifo_op)
unshuffled_elements = []
@@ -140,7 +140,7 @@ class ShuffleDatasetTest(test.TestCase):
get_next = iterator.get_next()
elems = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(10):
elems.append(sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
@@ -152,7 +152,7 @@ class ShuffleDatasetTest(test.TestCase):
.make_initializable_iterator())
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer, feed_dict={seed_placeholder: 0})
for elem in elems:
self.assertEqual(elem, sess.run(get_next))
@@ -166,7 +166,7 @@ class ShuffleDatasetTest(test.TestCase):
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
counts = collections.defaultdict(lambda: 0)
for _ in range(10):
for _ in range(5):
@@ -183,7 +183,7 @@ class ShuffleDatasetTest(test.TestCase):
.make_one_shot_iterator())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initial_permutation = sess.run(next_element)
self.assertAllEqual(initial_permutation, sess.run(next_element))
self.assertAllEqual(initial_permutation, sess.run(next_element))
@@ -198,7 +198,7 @@ class ShuffleDatasetTest(test.TestCase):
.make_one_shot_iterator())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initial_permutation = list(sess.run(next_element))
for _ in range(2):
next_permutation = list(sess.run(next_element))
diff --git a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
index 55933118b9..3106effbd3 100644
--- a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
@@ -45,7 +45,7 @@ class ZipDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
equal_length_components = [
np.tile(np.array([[1], [2], [3], [4]]), 20),
np.tile(np.array([[12], [13], [14], [15]]), 22),
@@ -93,7 +93,7 @@ class ZipDatasetTest(test.TestCase):
self.assertEqual([22], get_next[1][0].shape)
self.assertEqual([], get_next[1][1].shape)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
equal_length_components = [
np.tile(np.array([[1], [2], [3], [4]]), 20),
np.tile(np.array([[12], [13], [14], [15]]), 22),
diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD
index 50ba5f403e..57517afae8 100644
--- a/tensorflow/python/data/ops/BUILD
+++ b/tensorflow/python/data/ops/BUILD
@@ -27,6 +27,7 @@ py_library(
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:random_seed",
"//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/util:structure",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 8ba98cb88d..c985e00dd1 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -225,7 +225,7 @@ class Dataset(object):
`tf.constant` operations. For large datasets (> 1 GB), this can waste
memory and run into byte limits of graph serialization. If tensors contains
one or more large NumPy arrays, consider the alternative described in
- @{$guide/datasets#consuming_numpy_arrays$this guide}.
+ [this guide](https://tensorflow.org/guide/datasets#consuming_numpy_arrays).
Args:
tensors: A nested structure of tensors.
@@ -244,7 +244,7 @@ class Dataset(object):
`tf.constant` operations. For large datasets (> 1 GB), this can waste
memory and run into byte limits of graph serialization. If tensors contains
one or more large NumPy arrays, consider the alternative described in
- @{$guide/datasets#consuming_numpy_arrays$this guide}.
+ [this guide](https://tensorflow.org/guide/datasets#consuming_numpy_arrays).
Args:
tensors: A nested structure of tensors, each having the same size in the
@@ -1019,7 +1019,11 @@ class Dataset(object):
"""
return FlatMapDataset(self, map_func)
- def interleave(self, map_func, cycle_length, block_length=1):
+ def interleave(self,
+ map_func,
+ cycle_length,
+ block_length=1,
+ num_parallel_calls=None):
"""Maps `map_func` across this dataset, and interleaves the results.
For example, you can use `Dataset.interleave()` to process many input files
@@ -1082,11 +1086,19 @@ class Dataset(object):
processed concurrently.
block_length: The number of consecutive elements to produce from each
input element before cycling to another input element.
+ num_parallel_calls: (Optional.) If specified, the implementation creates
+ a threadpool, which is used to fetch inputs from cycle elements
+ asynchronously and in parallel. The default behavior is to fetch inputs
+ from cycle elements synchronously with no parallelism.
Returns:
Dataset: A `Dataset`.
"""
- return InterleaveDataset(self, map_func, cycle_length, block_length)
+ if num_parallel_calls is None:
+ return InterleaveDataset(self, map_func, cycle_length, block_length)
+ else:
+ return ParallelInterleaveDataset(self, map_func, cycle_length,
+ block_length, num_parallel_calls)
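A minimal sketch of the new argument in use; the file names are hypothetical:

    import tensorflow as tf

    filenames = tf.data.Dataset.from_tensor_slices(["a.txt", "b.txt"])
    # num_parallel_calls=None (the default) keeps the old synchronous behavior;
    # any other value routes construction to ParallelInterleaveDataset.
    lines = filenames.interleave(
        lambda f: tf.data.TextLineDataset(f),
        cycle_length=2,
        block_length=1,
        num_parallel_calls=2)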
def filter(self, predicate):
"""Filters this dataset according to `predicate`.
@@ -1684,15 +1696,14 @@ class ConcatenateDataset(Dataset):
super(ConcatenateDataset, self).__init__()
self._input_dataset = input_dataset
self._dataset_to_concatenate = dataset_to_concatenate
- nest.assert_same_structure(input_dataset.output_types,
- dataset_to_concatenate.output_types)
- for a, b in zip(
- nest.flatten(input_dataset.output_types),
- nest.flatten(dataset_to_concatenate.output_types)):
- if a != b:
- raise TypeError(
- "Two datasets to concatenate have different types %s and %s" %
- (input_dataset.output_types, dataset_to_concatenate.output_types))
+ if input_dataset.output_types != dataset_to_concatenate.output_types:
+ raise TypeError(
+ "Two datasets to concatenate have different types %s and %s" %
+ (input_dataset.output_types, dataset_to_concatenate.output_types))
+ if input_dataset.output_classes != dataset_to_concatenate.output_classes:
+ raise TypeError(
+ "Two datasets to concatenate have different classes %s and %s" %
+ (input_dataset.output_classes, dataset_to_concatenate.output_classes))
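A minimal sketch of the stricter construction-time check; the datasets are
illustrative:

    import tensorflow as tf

    ints = tf.data.Dataset.from_tensors(0)      # output_types: tf.int32
    floats = tf.data.Dataset.from_tensors(0.0)  # output_types: tf.float32
    try:
      ints.concatenate(floats)
    except TypeError as e:
      print(e)  # Two datasets to concatenate have different types ...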
def _as_variant_tensor(self):
# pylint: disable=protected-access
@@ -2208,10 +2219,11 @@ def _warn_if_collections(transformation_name):
class MapDataset(Dataset):
"""A `Dataset` that maps a function over elements in its input."""
- def __init__(self, input_dataset, map_func):
+ def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
"""See `Dataset.map()` for details."""
super(MapDataset, self).__init__()
self._input_dataset = input_dataset
+ self._use_inter_op_parallelism = use_inter_op_parallelism
wrapped_func = StructuredFunctionWrapper(
map_func, "Dataset.map()", input_dataset)
@@ -2226,6 +2238,7 @@ class MapDataset(Dataset):
input_t,
self._map_func.captured_inputs,
f=self._map_func,
+ use_inter_op_parallelism=self._use_inter_op_parallelism,
**flat_structure(self))
@property
@@ -2244,9 +2257,14 @@ class MapDataset(Dataset):
class ParallelMapDataset(MapDataset):
"""A `Dataset` that maps a function over elements in its input in parallel."""
- def __init__(self, input_dataset, map_func, num_parallel_calls):
+ def __init__(self,
+ input_dataset,
+ map_func,
+ num_parallel_calls,
+ use_inter_op_parallelism=True):
"""See `Dataset.map()` for details."""
- super(ParallelMapDataset, self).__init__(input_dataset, map_func)
+ super(ParallelMapDataset, self).__init__(input_dataset, map_func,
+ use_inter_op_parallelism)
self._num_parallel_calls = ops.convert_to_tensor(
num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls")
@@ -2259,6 +2277,7 @@ class ParallelMapDataset(MapDataset):
self._map_func.captured_inputs,
f=self._map_func,
num_parallel_calls=self._num_parallel_calls,
+ use_inter_op_parallelism=self._use_inter_op_parallelism,
**flat_structure(self))
# pylint: enable=protected-access
@@ -2329,6 +2348,36 @@ class InterleaveDataset(FlatMapDataset):
return "Dataset.interleave()"
+class ParallelInterleaveDataset(FlatMapDataset):
+ """A `Dataset` that maps a function over its input and interleaves the result.
+
+ """
+
+ def __init__(self, input_dataset, map_func, cycle_length, block_length,
+ num_parallel_calls):
+ """See `Dataset.interleave()` for details."""
+ super(ParallelInterleaveDataset, self).__init__(input_dataset, map_func)
+ self._cycle_length = ops.convert_to_tensor(
+ cycle_length, dtype=dtypes.int64, name="cycle_length")
+ self._block_length = ops.convert_to_tensor(
+ block_length, dtype=dtypes.int64, name="block_length")
+ self._num_parallel_calls = ops.convert_to_tensor(
+ num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.parallel_interleave_dataset_v2(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._map_func.captured_inputs, # pylint: disable=protected-access
+ self._cycle_length,
+ self._block_length,
+ self._num_parallel_calls,
+ f=self._map_func, # pylint: disable=protected-access
+ **flat_structure(self))
+
+ def _transformation_name(self):
+ return "Dataset.interleave()"
+
+
class FilterDataset(Dataset):
"""A `Dataset` that filters its input according to a predicate function."""
diff --git a/tensorflow/python/data/util/BUILD b/tensorflow/python/data/util/BUILD
index 5fcc62b60b..39082ce370 100644
--- a/tensorflow/python/data/util/BUILD
+++ b/tensorflow/python/data/util/BUILD
@@ -63,6 +63,41 @@ py_test(
)
py_library(
+ name = "structure",
+ srcs = ["structure.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":nest",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:ops",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:tensor_util",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_test(
+ name = "structure_test",
+ size = "small",
+ srcs = ["structure_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":nest",
+ ":structure",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:variables",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_library(
name = "convert",
srcs = ["convert.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/python/data/util/convert_test.py b/tensorflow/python/data/util/convert_test.py
index 6a67093e48..89c3afb296 100644
--- a/tensorflow/python/data/util/convert_test.py
+++ b/tensorflow/python/data/util/convert_test.py
@@ -30,28 +30,28 @@ class ConvertTest(test.TestCase):
def testInteger(self):
resp = convert.optional_param_to_tensor("foo", 3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(3, sess.run(resp))
def testIntegerDefault(self):
resp = convert.optional_param_to_tensor("foo", None)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(0, sess.run(resp))
def testStringDefault(self):
resp = convert.optional_param_to_tensor("bar", None, "default",
dtypes.string)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(compat.as_bytes("default"), sess.run(resp))
def testString(self):
resp = convert.optional_param_to_tensor("bar", "value", "default",
dtypes.string)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(compat.as_bytes("value"), sess.run(resp))
def testPartialShapeToTensorKnownDimension(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor(
tensor_shape.TensorShape([1]))))
self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor((1,))))
@@ -60,7 +60,7 @@ class ConvertTest(test.TestCase):
constant_op.constant([1], dtype=dtypes.int64))))
def testPartialShapeToTensorUnknownDimension(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
tensor_shape.TensorShape([None]))))
self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
@@ -84,7 +84,7 @@ class ConvertTest(test.TestCase):
convert.partial_shape_to_tensor(constant_op.constant([1., 1.]))
def testPartialShapeToTensorMultipleDimensions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
tensor_shape.TensorShape([3, 6]))))
self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
@@ -113,7 +113,7 @@ class ConvertTest(test.TestCase):
constant_op.constant([-1, -1], dtype=dtypes.int64))))
def testPartialShapeToTensorScalar(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(
tensor_shape.TensorShape([]))))
self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(())))
diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py
index 1b596bdfc0..9d621fcd30 100644
--- a/tensorflow/python/data/util/nest.py
+++ b/tensorflow/python/data/util/nest.py
@@ -129,35 +129,18 @@ def flatten(nest):
return _pywrap_tensorflow.FlattenForData(nest)
-def _recursive_assert_same_structure(nest1, nest2, check_types):
- is_sequence_nest1 = is_sequence(nest1)
- if is_sequence_nest1 != is_sequence(nest2):
- raise ValueError(
- "The two structures don't have the same nested structure. "
- "First structure: %s, second structure: %s." % (nest1, nest2))
-
- if is_sequence_nest1:
- type_nest1 = type(nest1)
- type_nest2 = type(nest2)
- if check_types and type_nest1 != type_nest2:
- raise TypeError(
- "The two structures don't have the same sequence type. First "
- "structure has type %s, while second structure has type %s."
- % (type_nest1, type_nest2))
-
- for n1, n2 in zip(_yield_value(nest1), _yield_value(nest2)):
- _recursive_assert_same_structure(n1, n2, check_types)
-
-
def assert_same_structure(nest1, nest2, check_types=True):
"""Asserts that two structures are nested in the same way.
Args:
nest1: an arbitrarily nested structure.
nest2: an arbitrarily nested structure.
- check_types: if `True` (default) types of sequences are checked as
- well. If set to `False`, for example a list and a tuple of objects will
- look same if they have the same size.
+ check_types: if `True` (default), the types of sequences must match as
+ well. For dictionaries, the "type" is considered to include the set of
+ keys, so two dictionaries with different keys have different "types".
+ If set to `False`, two iterables are considered the same as long as
+ the elements they yield have the same structure.
Raises:
ValueError: If the two structures do not have the same number of elements or
@@ -165,13 +148,7 @@ def assert_same_structure(nest1, nest2, check_types=True):
TypeError: If the two structures differ in the type of sequence in any of
their substructures. Only possible if `check_types` is `True`.
"""
- len_nest1 = len(flatten(nest1)) if is_sequence(nest1) else 1
- len_nest2 = len(flatten(nest2)) if is_sequence(nest2) else 1
- if len_nest1 != len_nest2:
- raise ValueError("The two structures don't have the same number of "
- "elements. First structure: %s, second structure: %s."
- % (nest1, nest2))
- _recursive_assert_same_structure(nest1, nest2, check_types)
+ _pywrap_tensorflow.AssertSameStructureForData(nest1, nest2, check_types)
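A minimal sketch of the resulting behavior, mirroring the nest_test.py cases
below:

    from tensorflow.python.data.util import nest

    a = {"x": 1, "y": 2}
    b = {"x": 1, "z": 2}
    # nest.assert_same_structure(a, b) raises ValueError: the two structures
    # don't have the same set of keys.
    # With check_types=False only the structure of the yielded values matters:
    nest.assert_same_structure(a, b, check_types=False)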
def _packed_nest_with_indices(structure, flat, index):
diff --git a/tensorflow/python/data/util/nest_test.py b/tensorflow/python/data/util/nest_test.py
index ff380815a4..616aa9f551 100644
--- a/tensorflow/python/data/util/nest_test.py
+++ b/tensorflow/python/data/util/nest_test.py
@@ -163,21 +163,30 @@ class NestTest(test.TestCase):
structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
structure_different_num_elements = ("spam", "eggs")
structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
+ structure_dictionary = {"foo": 2, "bar": 4, "baz": {"foo": 5, "bar": 6}}
+ structure_dictionary_diff_nested = {
+ "foo": 2,
+ "bar": 4,
+ "baz": {
+ "foo": 5,
+ "baz": 6
+ }
+ }
nest.assert_same_structure(structure1, structure2)
nest.assert_same_structure("abc", 1.0)
nest.assert_same_structure("abc", np.array([0, 1]))
nest.assert_same_structure("abc", constant_op.constant([0, 1]))
with self.assertRaisesRegexp(ValueError,
- "don't have the same number of elements"):
+ "don't have the same nested structure"):
nest.assert_same_structure(structure1, structure_different_num_elements)
with self.assertRaisesRegexp(ValueError,
- "don't have the same number of elements"):
+ "don't have the same nested structure"):
nest.assert_same_structure((0, 1), np.array([0, 1]))
with self.assertRaisesRegexp(ValueError,
- "don't have the same number of elements"):
+ "don't have the same nested structure"):
nest.assert_same_structure(0, (0, 1))
with self.assertRaisesRegexp(ValueError,
@@ -203,11 +212,23 @@ class NestTest(test.TestCase):
nest.assert_same_structure(((3,), 4), (3, (4,)))
structure1_list = {"a": ((1, 2), 3), "b": 4, "c": (5, 6)}
+ structure2_list = {"a": ((1, 2), 3), "b": 4, "d": (5, 6)}
with self.assertRaisesRegexp(TypeError,
"don't have the same sequence type"):
nest.assert_same_structure(structure1, structure1_list)
nest.assert_same_structure(structure1, structure2, check_types=False)
nest.assert_same_structure(structure1, structure1_list, check_types=False)
+ with self.assertRaisesRegexp(ValueError, "don't have the same set of keys"):
+ nest.assert_same_structure(structure1_list, structure2_list)
+ with self.assertRaisesRegexp(ValueError, "don't have the same set of keys"):
+ nest.assert_same_structure(structure_dictionary,
+ structure_dictionary_diff_nested)
+ nest.assert_same_structure(
+ structure_dictionary,
+ structure_dictionary_diff_nested,
+ check_types=False)
+ nest.assert_same_structure(
+ structure1_list, structure2_list, check_types=False)
def testMapStructure(self):
structure1 = (((1, 2), 3), 4, (5, 6))
diff --git a/tensorflow/python/data/util/sparse_test.py b/tensorflow/python/data/util/sparse_test.py
index d49b3ff34b..056b32480f 100644
--- a/tensorflow/python/data/util/sparse_test.py
+++ b/tensorflow/python/data/util/sparse_test.py
@@ -291,7 +291,7 @@ class SparseTest(test.TestCase):
self.assertEqual(a, b)
return
self.assertTrue(isinstance(b, sparse_tensor.SparseTensor))
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(a.eval().indices, b.eval().indices)
self.assertAllEqual(a.eval().values, b.eval().values)
self.assertAllEqual(a.eval().dense_shape, b.eval().dense_shape)
diff --git a/tensorflow/python/data/util/structure.py b/tensorflow/python/data/util/structure.py
new file mode 100644
index 0000000000..c5764b8dfe
--- /dev/null
+++ b/tensorflow/python/data/util/structure.py
@@ -0,0 +1,315 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for describing the structure of a `tf.data` type."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import sparse_ops
+
+
+class Structure(object):
+ """Represents structural information, such as type and shape, about a value.
+
+ A `Structure` generalizes the `tf.Tensor.dtype` and `tf.Tensor.shape`
+ properties, so that we can define generic containers of objects including:
+
+ * `tf.Tensor`
+ * `tf.SparseTensor`
+ * Nested structures of the above.
+
+ TODO(b/110122868): In the future, a single `Structure` will replace the
+ `tf.data.Dataset.output_types`, `tf.data.Dataset.output_shapes`,
+ and `tf.data.Dataset.output_classes`, and similar properties and arguments in
+ the `tf.data.Iterator` and `Optional` classes.
+ """
+ __metaclass__ = abc.ABCMeta
+
+ @abc.abstractproperty
+ def _flat_shapes(self):
+ """A list of shapes matching the shapes of `self._to_tensor_list()`.
+
+ Returns:
+ A list of `tf.TensorShape` objects.
+ """
+ raise NotImplementedError("Structure._flat_shapes")
+
+ @abc.abstractproperty
+ def _flat_types(self):
+ """A list of types matching the types of `self._to_tensor_list()`.
+
+ Returns:
+ A list of `tf.DType` objects.
+ """
+ raise NotImplementedError("Structure._flat_shapes")
+
+ @abc.abstractmethod
+ def is_compatible_with(self, value):
+ """Returns `True` if `value` is compatible with this structure.
+
+ A value `value` is compatible with a structure `s` if
+ `Structure.from_value(value)` would return a structure `t` that is a
+ "subtype" of `s`. A structure `t` is a "subtype" of `s` if:
+
+ * `s` and `t` are instances of the same `Structure` subclass.
+ * The nested structures (if any) of `s` and `t` are the same, according to
+ `tf.contrib.framework.nest.assert_same_structure`, and each nested
+ structure of `t` is a "subtype" of the corresponding nested structure of
+ `s`.
+ * Any `tf.DType` components of `t` are the same as the corresponding
+ components in `s`.
+ * Any `tf.TensorShape` components of `t` are compatible with the
+ corresponding components in `s`, according to
+ `tf.TensorShape.is_compatible_with`.
+
+ Args:
+ value: A potentially structured value.
+
+ Returns:
+ `True` if `value` matches this structure, otherwise `False`.
+ """
+ raise NotImplementedError("Structure.is_compatible_with()")
+
+ @abc.abstractmethod
+ def _to_tensor_list(self, value):
+ """Returns a flat list of `tf.Tensor` representing `value`.
+
+ This method can be used, along with `self._flat_shapes` and
+ `self._flat_types` to represent structured values in lower level APIs
+ (such as plain TensorFlow operations) that do not understand structure.
+
+ Requires: `self.is_compatible_with(value)`.
+
+ Args:
+ value: A value with compatible structure.
+
+ Returns:
+ A flat list of `tf.Tensor` representing `value`.
+ """
+ raise NotImplementedError("Structure._to_tensor_list()")
+
+ @abc.abstractmethod
+ def _from_tensor_list(self, flat_value):
+ """Builds a flat list of `tf.Tensor` into a value matching this structure.
+
+ Requires: The shapes and types of the tensors in `flat_value` must be
+ compatible with `self._flat_shapes` and `self._flat_types` respectively.
+
+ Args:
+ flat_value: A list of `tf.Tensor` with compatible flat structure.
+
+ Returns:
+ A structured object matching this structure.
+ """
+ raise NotImplementedError("Structure._from_tensor_list()")
+
+ @staticmethod
+ def from_value(value):
+ """Returns a `Structure` that represents the given `value`.
+
+ Args:
+ value: A potentially structured value.
+
+ Returns:
+ A `Structure` that is compatible with `value`.
+
+ Raises:
+ TypeError: If a structure cannot be built for `value`, because its type
+ or one of its component types is not supported.
+ """
+
+ # TODO(b/110122868): Add support for custom types, Dataset, and Optional
+ # to this method.
+ if isinstance(
+ value,
+ (sparse_tensor_lib.SparseTensor, sparse_tensor_lib.SparseTensorValue)):
+ return SparseTensorStructure.from_value(value)
+ elif isinstance(value, (tuple, dict)):
+ return NestedStructure.from_value(value)
+ else:
+ try:
+ tensor = ops.convert_to_tensor(value)
+ except (ValueError, TypeError):
+ raise TypeError("Could not build a structure for %r" % value)
+ return TensorStructure.from_value(tensor)
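A minimal sketch of the dispatch, in TF 1.x graph mode:

    import tensorflow as tf
    from tensorflow.python.data.util import structure

    structure.Structure.from_value(tf.constant([1, 2, 3]))  # TensorStructure
    structure.Structure.from_value(
        tf.SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]))
    # -> SparseTensorStructure
    structure.Structure.from_value((tf.constant(1.0), {"a": tf.constant(2)}))
    # -> NestedStructure with TensorStructure leaves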
+
+
+# NOTE(mrry): The following classes make extensive use of non-public methods of
+# their base class, so we disable the protected-access lint warning once here.
+# pylint: disable=protected-access
+class NestedStructure(Structure):
+ """Represents a nested structure in which each leaf is a `Structure`."""
+
+ def __init__(self, nested_structure):
+ self._nested_structure = nested_structure
+ self._flat_shapes_list = []
+ self._flat_types_list = []
+ for s in nest.flatten(nested_structure):
+ if not isinstance(s, Structure):
+ raise TypeError("nested_structure must be a (potentially nested) tuple "
+ "or dictionary of Structure objects.")
+ self._flat_shapes_list.extend(s._flat_shapes)
+ self._flat_types_list.extend(s._flat_types)
+
+ @property
+ def _flat_shapes(self):
+ return self._flat_shapes_list
+
+ @property
+ def _flat_types(self):
+ return self._flat_types_list
+
+ def is_compatible_with(self, value):
+ try:
+ nest.assert_shallow_structure(self._nested_structure, value)
+ except (ValueError, TypeError):
+ return False
+
+ return all(
+ s.is_compatible_with(v) for s, v in zip(
+ nest.flatten(self._nested_structure),
+ nest.flatten_up_to(self._nested_structure, value)))
+
+ def _to_tensor_list(self, value):
+ ret = []
+
+ try:
+ flat_value = nest.flatten_up_to(self._nested_structure, value)
+ except (ValueError, TypeError):
+ raise ValueError("The value %r is not compatible with the nested "
+ "structure %r." % (value, self._nested_structure))
+
+ for sub_value, structure in zip(flat_value,
+ nest.flatten(self._nested_structure)):
+ if not structure.is_compatible_with(sub_value):
+ raise ValueError("Component value %r is not compatible with the nested "
+ "structure %r." % (sub_value, structure))
+ ret.extend(structure._to_tensor_list(sub_value))
+ return ret
+
+ def _from_tensor_list(self, flat_value):
+ if len(flat_value) != len(self._flat_types):
+ raise ValueError("Expected %d flat values in NestedStructure but got %d."
+ % (len(self._flat_types), len(flat_value)))
+
+ flat_ret = []
+ for sub_value, structure in zip(flat_value,
+ nest.flatten(self._nested_structure)):
+ flat_ret.append(structure._from_tensor_list([sub_value]))
+
+ return nest.pack_sequence_as(self._nested_structure, flat_ret)
+
+ @staticmethod
+ def from_value(value):
+ flat_nested_structure = [
+ Structure.from_value(sub_value) for sub_value in nest.flatten(value)
+ ]
+ return NestedStructure(nest.pack_sequence_as(value, flat_nested_structure))
+
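The two private conversion methods are inverses whenever the value matches the
structure; a sketch, under the same assumptions as the example above:

    value = {"a": constant_op.constant(1.0), "b": constant_op.constant([2, 3])}
    s = structure.Structure.from_value(value)
    flat = s._to_tensor_list(value)       # a flat list of two tf.Tensor objects
    restored = s._from_tensor_list(flat)  # a dict with keys "a" and "b" again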
+
+class TensorStructure(Structure):
+ """Represents structural information about a `tf.Tensor`."""
+
+ def __init__(self, dtype, shape):
+ self._dtype = dtypes.as_dtype(dtype)
+ self._shape = tensor_shape.as_shape(shape)
+
+ @property
+ def _flat_shapes(self):
+ return [self._shape]
+
+ @property
+ def _flat_types(self):
+ return [self._dtype]
+
+ def is_compatible_with(self, value):
+ try:
+ value = ops.convert_to_tensor(value, dtype=self._dtype)
+ except (ValueError, TypeError):
+ return False
+
+ return (self._dtype.is_compatible_with(value.dtype) and
+ self._shape.is_compatible_with(value.shape))
+
+ def _to_tensor_list(self, value):
+ if not self.is_compatible_with(value):
+ raise ValueError("Value %r is not convertible to a tensor with dtype %s "
+ "and shape %s." % (value, self._dtype, self._shape))
+ return [value]
+
+ def _from_tensor_list(self, flat_value):
+ if len(flat_value) != 1:
+ raise ValueError("TensorStructure corresponds to a single tf.Tensor.")
+ if not self.is_compatible_with(flat_value[0]):
+ raise ValueError("Cannot convert %r to a tensor with dtype %s and shape "
+ "%s." % (flat_value[0], self._dtype, self._shape))
+ return flat_value[0]
+
+ @staticmethod
+ def from_value(value):
+ return TensorStructure(value.dtype, value.shape)
+
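Compatibility here means the dtypes agree and the `TensorShape`s are
compatible, so a partially defined shape accepts any concrete shape it can
merge with. A sketch, continuing the assumed imports from above (graph mode is
assumed for the placeholder):

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import array_ops

    ts = structure.TensorStructure(dtypes.float32, [None, 3])
    ts.is_compatible_with(array_ops.placeholder(dtypes.float32, [2, 3]))  # True
    ts.is_compatible_with(constant_op.constant([1.0, 2.0, 3.0]))  # False: rank 1
    ts.is_compatible_with(constant_op.constant([[1, 2, 3]]))      # False: int32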
+
+class SparseTensorStructure(Structure):
+ """Represents structural information about a `tf.SparseTensor`."""
+
+ def __init__(self, dtype, dense_shape):
+ self._dtype = dtypes.as_dtype(dtype)
+ self._dense_shape = tensor_shape.as_shape(dense_shape)
+
+ @property
+ def _flat_shapes(self):
+ return [tensor_shape.vector(3)]
+
+ @property
+ def _flat_types(self):
+ return [dtypes.variant]
+
+ def is_compatible_with(self, value):
+ try:
+ value = sparse_tensor_lib.SparseTensor.from_value(value)
+ except TypeError:
+ return False
+ return (isinstance(value, (sparse_tensor_lib.SparseTensor,
+ sparse_tensor_lib.SparseTensorValue)) and
+ self._dtype.is_compatible_with(value.dtype) and
+ self._dense_shape.is_compatible_with(
+ tensor_util.constant_value_as_shape(value.dense_shape)))
+
+ def _to_tensor_list(self, value):
+ return [sparse_ops.serialize_sparse(value, out_type=dtypes.variant)]
+
+ def _from_tensor_list(self, flat_value):
+ if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or
+ not flat_value[0].shape.is_compatible_with(tensor_shape.vector(3))):
+ raise ValueError("SparseTensorStructure corresponds to a single "
+ "tf.variant vector of length 3.")
+ return sparse_ops.deserialize_sparse(
+ flat_value[0], dtype=self._dtype, rank=self._dense_shape.ndims)
+
+ @staticmethod
+ def from_value(value):
+ sparse_tensor = sparse_tensor_lib.SparseTensor.from_value(value)
+ return SparseTensorStructure(
+ sparse_tensor.dtype,
+ tensor_util.constant_value_as_shape(sparse_tensor.dense_shape))
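The `[3]` flat shape reflects `serialize_sparse`: with `out_type=dtypes.variant`
it packs a `SparseTensor`'s indices, values and dense_shape into a single
length-3 vector of variants. A sketch of the round trip, under the same
assumptions as the examples above:

    st = sparse_tensor.SparseTensor(
        indices=[[3, 4]], values=[-1], dense_shape=[4, 5])
    s = structure.SparseTensorStructure.from_value(st)
    flat = s._to_tensor_list(st)     # [<tf.Tensor shape=(3,) dtype=variant>]
    st2 = s._from_tensor_list(flat)  # a SparseTensor equivalent to `st`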
diff --git a/tensorflow/python/data/util/structure_test.py b/tensorflow/python/data/util/structure_test.py
new file mode 100644
index 0000000000..d0c7df67ae
--- /dev/null
+++ b/tensorflow/python/data/util/structure_test.py
@@ -0,0 +1,327 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for utilities working with arbitrarily nested structures."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import structure
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class StructureTest(test.TestCase, parameterized.TestCase):
+  # pylint: disable=protected-access
+
+  @parameterized.parameters(
+      (constant_op.constant(37.0),
+       structure.TensorStructure, [dtypes.float32], [[]]),
+      (sparse_tensor.SparseTensor(
+          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
+       structure.SparseTensorStructure, [dtypes.variant], [[3]]),
+      ((constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
+       structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
+      ({"a": constant_op.constant(37.0),
+        "b": constant_op.constant([1, 2, 3])},
+       structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
+      ({"a": constant_op.constant(37.0),
+        "b": (sparse_tensor.SparseTensor(
+                  indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
+              sparse_tensor.SparseTensor(
+                  indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))},
+       structure.NestedStructure,
+       [dtypes.float32, dtypes.variant, dtypes.variant], [[], [3], [3]]))
+ def testFlatStructure(self, value, expected_structure, expected_types,
+ expected_shapes):
+ s = structure.Structure.from_value(value)
+ self.assertIsInstance(s, expected_structure)
+ self.assertEqual(expected_types, s._flat_types)
+ self.assertEqual(expected_shapes, s._flat_shapes)
+
+ @parameterized.parameters(
+ (constant_op.constant(37.0), [
+ constant_op.constant(38.0),
+ array_ops.placeholder(dtypes.float32),
+ variables.Variable(100.0), 42.0,
+ np.array(42.0, dtype=np.float32)
+ ], [constant_op.constant([1.0, 2.0]),
+ constant_op.constant(37)]),
+ (sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
+ [
+ sparse_tensor.SparseTensor(
+ indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]),
+ sparse_tensor.SparseTensorValue(
+ indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]),
+ array_ops.sparse_placeholder(dtype=dtypes.int32),
+ array_ops.sparse_placeholder(dtype=dtypes.int32, shape=[None, None])
+ ], [
+ constant_op.constant(37, shape=[4, 5]),
+ sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
+ array_ops.sparse_placeholder(
+ dtype=dtypes.int32, shape=[None, None, None]),
+ sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5])
+ ]),
+ ({
+ "a": constant_op.constant(37.0),
+ "b": constant_op.constant([1, 2, 3])
+ }, [{
+ "a": constant_op.constant(15.0),
+ "b": constant_op.constant([4, 5, 6])
+ }], [{
+ "a": constant_op.constant(15.0),
+ "b": constant_op.constant([4, 5, 6, 7])
+ }, {
+ "a": constant_op.constant(15),
+ "b": constant_op.constant([4, 5, 6])
+ }, {
+ "a":
+ constant_op.constant(15),
+ "b":
+ sparse_tensor.SparseTensor(
+ indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
+ }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
+ )
+ def testIsCompatibleWith(self, original_value, compatible_values,
+ incompatible_values):
+ s = structure.Structure.from_value(original_value)
+ for compatible_value in compatible_values:
+ self.assertTrue(s.is_compatible_with(compatible_value))
+ for incompatible_value in incompatible_values:
+ self.assertFalse(s.is_compatible_with(incompatible_value))
+
+ # NOTE(mrry): The arguments must be lifted into lambdas because otherwise they
+ # will be executed before the (eager- or graph-mode) test environment has been
+ # set up.
+ # pylint: disable=g-long-lambda
+ @parameterized.parameters(
+ (lambda: constant_op.constant(37.0),),
+ (lambda: sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),),
+ (lambda: {"a": constant_op.constant(37.0),
+ "b": constant_op.constant([1, 2, 3])},),
+ (lambda: {"a": constant_op.constant(37.0),
+ "b": (sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
+ sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
+ },),
+ )
+ def testRoundTripConversion(self, value_fn):
+ value = value_fn()
+ s = structure.Structure.from_value(value)
+ before = self.evaluate(value)
+ after = self.evaluate(s._from_tensor_list(s._to_tensor_list(value)))
+
+ flat_before = nest.flatten(before)
+ flat_after = nest.flatten(after)
+ for b, a in zip(flat_before, flat_after):
+ if isinstance(b, sparse_tensor.SparseTensorValue):
+ self.assertAllEqual(b.indices, a.indices)
+ self.assertAllEqual(b.values, a.values)
+ self.assertAllEqual(b.dense_shape, a.dense_shape)
+ else:
+ self.assertAllEqual(b, a)
+ # pylint: enable=g-long-lambda
+
+ def testIncompatibleStructure(self):
+ # Define three mutually incompatible values/structures, and assert that:
+ # 1. Using one structure to flatten a value with an incompatible structure
+ # fails.
+  #    2. Using one structure to restructure a flattened value with an
+ # incompatible structure fails.
+ value_tensor = constant_op.constant(42.0)
+ s_tensor = structure.Structure.from_value(value_tensor)
+ flat_tensor = s_tensor._to_tensor_list(value_tensor)
+
+ value_sparse_tensor = sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1])
+ s_sparse_tensor = structure.Structure.from_value(value_sparse_tensor)
+ flat_sparse_tensor = s_sparse_tensor._to_tensor_list(value_sparse_tensor)
+
+ value_nest = {
+ "a": constant_op.constant(37.0),
+ "b": constant_op.constant([1, 2, 3])
+ }
+ s_nest = structure.Structure.from_value(value_nest)
+ flat_nest = s_nest._to_tensor_list(value_nest)
+
+ with self.assertRaisesRegexp(
+ ValueError, r"SparseTensor.* is not convertible to a tensor with "
+ r"dtype.*float32.* and shape \(\)"):
+ s_tensor._to_tensor_list(value_sparse_tensor)
+ with self.assertRaisesRegexp(
+ ValueError, r"Value \{.*\} is not convertible to a tensor with "
+ r"dtype.*float32.* and shape \(\)"):
+ s_tensor._to_tensor_list(value_nest)
+
+ with self.assertRaisesRegexp(TypeError, "Input must be a SparseTensor"):
+ s_sparse_tensor._to_tensor_list(value_tensor)
+
+ with self.assertRaisesRegexp(TypeError, "Input must be a SparseTensor"):
+ s_sparse_tensor._to_tensor_list(value_nest)
+
+ with self.assertRaisesRegexp(
+ ValueError, "Tensor.* not compatible with the nested structure "
+ ".*TensorStructure.*TensorStructure"):
+ s_nest._to_tensor_list(value_tensor)
+
+ with self.assertRaisesRegexp(
+ ValueError, "SparseTensor.* not compatible with the nested structure "
+ ".*TensorStructure.*TensorStructure"):
+ s_nest._to_tensor_list(value_sparse_tensor)
+
+ with self.assertRaisesRegexp(
+ ValueError, r"Cannot convert.*with dtype.*float32.* and shape \(\)"):
+ s_tensor._from_tensor_list(flat_sparse_tensor)
+
+ with self.assertRaisesRegexp(
+ ValueError, "TensorStructure corresponds to a single tf.Tensor."):
+ s_tensor._from_tensor_list(flat_nest)
+
+ with self.assertRaisesRegexp(
+ ValueError, "SparseTensorStructure corresponds to a single tf.variant "
+ "vector of length 3."):
+ s_sparse_tensor._from_tensor_list(flat_tensor)
+
+ with self.assertRaisesRegexp(
+ ValueError, "SparseTensorStructure corresponds to a single tf.variant "
+ "vector of length 3."):
+ s_sparse_tensor._from_tensor_list(flat_nest)
+
+ with self.assertRaisesRegexp(
+ ValueError, "Expected 2 flat values in NestedStructure but got 1."):
+ s_nest._from_tensor_list(flat_tensor)
+
+ with self.assertRaisesRegexp(
+ ValueError, "Expected 2 flat values in NestedStructure but got 1."):
+ s_nest._from_tensor_list(flat_sparse_tensor)
+
+ def testIncompatibleNestedStructure(self):
+ # Define three mutually incompatible nested values/structures, and assert
+ # that:
+ # 1. Using one structure to flatten a value with an incompatible structure
+ # fails.
+  #    2. Using one structure to restructure a flattened value with an
+ # incompatible structure fails.
+
+ value_0 = {
+ "a": constant_op.constant(37.0),
+ "b": constant_op.constant([1, 2, 3])
+ }
+ s_0 = structure.Structure.from_value(value_0)
+ flat_s_0 = s_0._to_tensor_list(value_0)
+
+ # `value_1` has compatible nested structure with `value_0`, but different
+ # classes.
+ value_1 = {
+ "a":
+ constant_op.constant(37.0),
+ "b":
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1])
+ }
+ s_1 = structure.Structure.from_value(value_1)
+ flat_s_1 = s_1._to_tensor_list(value_1)
+
+ # `value_2` has incompatible nested structure with `value_0` and `value_1`.
+ value_2 = {
+ "a":
+ constant_op.constant(37.0),
+ "b": (sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
+ sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
+ }
+ s_2 = structure.Structure.from_value(value_2)
+ flat_s_2 = s_2._to_tensor_list(value_2)
+
+ with self.assertRaisesRegexp(
+ ValueError, "SparseTensor.* not compatible with the nested structure "
+ ".*TensorStructure"):
+ s_0._to_tensor_list(value_1)
+
+ with self.assertRaisesRegexp(
+ ValueError, "SparseTensor.*SparseTensor.* not compatible with the "
+ "nested structure .*TensorStructure"):
+ s_0._to_tensor_list(value_2)
+
+ with self.assertRaisesRegexp(
+ ValueError, "Tensor.* not compatible with the nested structure "
+ ".*SparseTensorStructure"):
+ s_1._to_tensor_list(value_0)
+
+ # NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp
+ # needs to account for "a" coming before or after "b". It might be worth
+ # adding a deterministic repr for these error messages (among other
+ # improvements).
+ with self.assertRaisesRegexp(
+ ValueError, "Tensor.*Tensor.* not compatible with the nested structure "
+ ".*(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
+ "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"):
+ s_2._to_tensor_list(value_0)
+
+ with self.assertRaisesRegexp(
+ ValueError, "(Tensor.*SparseTensor|SparseTensor.*Tensor).* "
+ "not compatible with the nested structure .*"
+ "(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
+ "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"):
+ s_2._to_tensor_list(value_1)
+
+ with self.assertRaisesRegexp(
+ ValueError, r"Cannot convert.*with dtype.*int32.* and shape \(3,\)"):
+ s_0._from_tensor_list(flat_s_1)
+
+ with self.assertRaisesRegexp(
+ ValueError, "Expected 2 flat values in NestedStructure but got 3."):
+ s_0._from_tensor_list(flat_s_2)
+
+ with self.assertRaisesRegexp(
+ ValueError, "SparseTensorStructure corresponds to a single tf.variant "
+ "vector of length 3."):
+ s_1._from_tensor_list(flat_s_0)
+
+ with self.assertRaisesRegexp(
+ ValueError, "Expected 2 flat values in NestedStructure but got 3."):
+ s_1._from_tensor_list(flat_s_2)
+
+ with self.assertRaisesRegexp(
+ ValueError, "Expected 3 flat values in NestedStructure but got 2."):
+ s_2._from_tensor_list(flat_s_0)
+
+ with self.assertRaisesRegexp(
+ ValueError, "Expected 3 flat values in NestedStructure but got 2."):
+ s_2._from_tensor_list(flat_s_1)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 8a4ac6aaef..849d165bfa 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -576,7 +576,6 @@ py_test(
srcs_version = "PY2AND3",
tags = [
"no_windows",
- "nomac",
"oss_serial",
],
deps = [
@@ -1047,7 +1046,6 @@ cuda_py_test(
tags = [
"no_oss", # Incompatible with bazel_pip.
"no_windows",
- "nomac", # TODO(cais): Install of futures and grpcio on all macs.
"notsan",
],
)
@@ -1102,6 +1100,23 @@ py_test(
],
)
+py_test(
+ name = "disk_usage_test",
+ size = "small",
+ srcs = ["wrappers/disk_usage_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dumping_wrapper",
+ ":hooks",
+ "//tensorflow/python:client",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ ],
+)
+
sh_test(
name = "examples_test",
size = "medium",
diff --git a/tensorflow/python/debug/__init__.py b/tensorflow/python/debug/__init__.py
index 34da44b60d..242215dccb 100644
--- a/tensorflow/python/debug/__init__.py
+++ b/tensorflow/python/debug/__init__.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""Public Python API of TensorFlow Debugger (tfdbg).
-See the @{$python/tfdbg} guide.
+See the [TFDBG](https://tensorflow.org/api_guides/python/tfdbg) guide.
@@add_debug_tensor_watch
@@watch_graph
diff --git a/tensorflow/python/debug/examples/debug_tflearn_iris.py b/tensorflow/python/debug/examples/debug_tflearn_iris.py
index 7cbaae46b4..019f13c450 100644
--- a/tensorflow/python/debug/examples/debug_tflearn_iris.py
+++ b/tensorflow/python/debug/examples/debug_tflearn_iris.py
@@ -113,17 +113,16 @@ def main(_):
n_classes=3,
model_dir=model_dir)
- hooks = None
if FLAGS.debug and FLAGS.tensorboard_debug_address:
raise ValueError(
"The --debug and --tensorboard_debug_address flags are mutually "
"exclusive.")
+ hooks = []
if FLAGS.debug:
- debug_hook = tf_debug.LocalCLIDebugHook(ui_type=FLAGS.ui_type,
- dump_root=FLAGS.dump_root)
+ hooks.append(tf_debug.LocalCLIDebugHook(ui_type=FLAGS.ui_type,
+ dump_root=FLAGS.dump_root))
elif FLAGS.tensorboard_debug_address:
- debug_hook = tf_debug.TensorBoardDebugHook(FLAGS.tensorboard_debug_address)
- hooks = [debug_hook]
+ hooks.append(tf_debug.TensorBoardDebugHook(FLAGS.tensorboard_debug_address))
# Train model, using tfdbg hook.
classifier.train(training_input_fn,
diff --git a/tensorflow/python/debug/lib/debug_utils.py b/tensorflow/python/debug/lib/debug_utils.py
index f1e972940b..f2a43a6152 100644
--- a/tensorflow/python/debug/lib/debug_utils.py
+++ b/tensorflow/python/debug/lib/debug_utils.py
@@ -87,7 +87,8 @@ def watch_graph(run_options,
op_type_regex_whitelist=None,
tensor_dtype_regex_whitelist=None,
tolerate_debug_op_creation_failures=False,
- global_step=-1):
+ global_step=-1,
+ reset_disk_byte_usage=False):
"""Add debug watches to `RunOptions` for a TensorFlow graph.
To watch all `Tensor`s on the graph, let both `node_name_regex_whitelist`
@@ -130,6 +131,8 @@ def watch_graph(run_options,
throwing exceptions.
global_step: (`int`) Optional global_step count for this debug tensor
watch.
+ reset_disk_byte_usage: (`bool`) whether to reset the tracked disk byte
+ usage to zero (default: `False`).
"""
if isinstance(debug_ops, str):
@@ -170,6 +173,7 @@ def watch_graph(run_options,
tolerate_debug_op_creation_failures=(
tolerate_debug_op_creation_failures),
global_step=global_step)
+ run_options.debug_options.reset_disk_byte_usage = reset_disk_byte_usage
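For callers, opting into the new flag looks like the following hedged sketch;
`sess`, `fetches` and the dump directory are placeholders rather than values
prescribed by this patch:

    from tensorflow.core.protobuf import config_pb2
    from tensorflow.python.debug.lib import debug_utils

    run_options = config_pb2.RunOptions()
    debug_utils.watch_graph(
        run_options,
        sess.graph,
        debug_urls=["file:///tmp/tfdbg_dumps"],
        reset_disk_byte_usage=True)  # zero the tracked byte count for this run
    sess.run(fetches, options=run_options)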
def watch_graph_with_blacklists(run_options,
@@ -180,7 +184,8 @@ def watch_graph_with_blacklists(run_options,
op_type_regex_blacklist=None,
tensor_dtype_regex_blacklist=None,
tolerate_debug_op_creation_failures=False,
- global_step=-1):
+ global_step=-1,
+ reset_disk_byte_usage=False):
"""Add debug tensor watches, blacklisting nodes and op types.
This is similar to `watch_graph()`, but the node names and op types are
@@ -219,6 +224,8 @@ def watch_graph_with_blacklists(run_options,
throwing exceptions.
global_step: (`int`) Optional global_step count for this debug tensor
watch.
+ reset_disk_byte_usage: (`bool`) whether to reset the tracked disk byte
+ usage to zero (default: `False`).
"""
if isinstance(debug_ops, str):
@@ -259,3 +266,4 @@ def watch_graph_with_blacklists(run_options,
tolerate_debug_op_creation_failures=(
tolerate_debug_op_creation_failures),
global_step=global_step)
+ run_options.debug_options.reset_disk_byte_usage = reset_disk_byte_usage
diff --git a/tensorflow/python/debug/wrappers/disk_usage_test.py b/tensorflow/python/debug/wrappers/disk_usage_test.py
new file mode 100644
index 0000000000..0874525966
--- /dev/null
+++ b/tensorflow/python/debug/wrappers/disk_usage_test.py
@@ -0,0 +1,109 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Debugger Wrapper Session Consisting of a Local Curses-based CLI."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tempfile
+
+from tensorflow.python.client import session
+from tensorflow.python.debug.wrappers import dumping_wrapper
+from tensorflow.python.debug.wrappers import hooks
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.training import monitored_session
+
+
+class DumpingDebugWrapperDiskUsageLimitTest(test_util.TensorFlowTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ # For efficient testing, set the disk usage bytes limit to a small
+ # number (10).
+ os.environ["TFDBG_DISK_BYTES_LIMIT"] = "10"
+
+ def setUp(self):
+ self.session_root = tempfile.mkdtemp()
+
+ self.v = variables.Variable(10.0, dtype=dtypes.float32, name="v")
+ self.delta = constant_op.constant(1.0, dtype=dtypes.float32, name="delta")
+ self.eta = constant_op.constant(-1.4, dtype=dtypes.float32, name="eta")
+ self.inc_v = state_ops.assign_add(self.v, self.delta, name="inc_v")
+ self.dec_v = state_ops.assign_add(self.v, self.eta, name="dec_v")
+
+ self.sess = session.Session()
+ self.sess.run(self.v.initializer)
+
+ def testWrapperSessionNotExceedingLimit(self):
+ def _watch_fn(fetches, feeds):
+ del fetches, feeds
+ return "DebugIdentity", r"(.*delta.*|.*inc_v.*)", r".*"
+ sess = dumping_wrapper.DumpingDebugWrapperSession(
+ self.sess, session_root=self.session_root,
+ watch_fn=_watch_fn, log_usage=False)
+ sess.run(self.inc_v)
+
+ def testWrapperSessionExceedingLimit(self):
+ def _watch_fn(fetches, feeds):
+ del fetches, feeds
+ return "DebugIdentity", r".*delta.*", r".*"
+ sess = dumping_wrapper.DumpingDebugWrapperSession(
+ self.sess, session_root=self.session_root,
+ watch_fn=_watch_fn, log_usage=False)
+    # Due to the watch function, each run should dump only one tensor: the
+    # scalar float32 'delta:0', whose dump occupies 4 bytes.
+ # 1st run should pass, after which the disk usage is at 4 bytes.
+ sess.run(self.inc_v)
+ # 2nd run should also pass, after which 8 bytes are used.
+ sess.run(self.inc_v)
+    # 3rd run should fail, because the total byte count (12) exceeds the
+    # limit (10).
+ with self.assertRaises(ValueError):
+ sess.run(self.inc_v)
+
+ def testHookNotExceedingLimit(self):
+ def _watch_fn(fetches, feeds):
+ del fetches, feeds
+ return "DebugIdentity", r".*delta.*", r".*"
+ dumping_hook = hooks.DumpingDebugHook(
+ self.session_root, watch_fn=_watch_fn, log_usage=False)
+ mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
+ mon_sess.run(self.inc_v)
+
+ def testHookExceedingLimit(self):
+ def _watch_fn(fetches, feeds):
+ del fetches, feeds
+ return "DebugIdentity", r".*delta.*", r".*"
+ dumping_hook = hooks.DumpingDebugHook(
+ self.session_root, watch_fn=_watch_fn, log_usage=False)
+ mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
+ # Like in `testWrapperSessionExceedingLimit`, the first two calls
+ # should be within the byte limit, but the third one should error
+ # out due to exceeding the limit.
+ mon_sess.run(self.inc_v)
+ mon_sess.run(self.inc_v)
+ with self.assertRaises(ValueError):
+ mon_sess.run(self.inc_v)
+
+
+if __name__ == "__main__":
+ googletest.main()
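These tests pin the limit through the `TFDBG_DISK_BYTES_LIMIT` environment
variable, which must be set before any dumping session runs. The same knob
applies outside the test harness, as in this hedged sketch (`inc_v` and the
wrapped session mirror the fixtures above and are otherwise placeholders):

    import os
    import tempfile

    os.environ["TFDBG_DISK_BYTES_LIMIT"] = "10"  # before any dump is written

    wrapped = dumping_wrapper.DumpingDebugWrapperSession(
        session.Session(), session_root=tempfile.mkdtemp(), log_usage=False)
    wrapped.run(inc_v)  # ValueError once cumulative dumps exceed 10 bytes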
diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py
index b9524ce649..afda1fdc0d 100644
--- a/tensorflow/python/debug/wrappers/framework.py
+++ b/tensorflow/python/debug/wrappers/framework.py
@@ -447,13 +447,16 @@ class BaseDebugWrapperSession(session.SessionInterface):
"callable_runner and callable_options are mutually exclusive, but "
"are both specified in this call to BaseDebugWrapperSession.run().")
- if not (callable_runner or callable_options):
- self.increment_run_call_count()
- elif callable_runner and (fetches or feed_dict):
+ if callable_runner and (fetches or feed_dict):
raise ValueError(
"callable_runner and fetches/feed_dict are mutually exclusive, "
"but are used simultaneously.")
+ elif callable_options and (fetches or feed_dict):
+ raise ValueError(
+ "callable_options and fetches/feed_dict are mutually exclusive, "
+ "but are used simultaneously.")
+ self.increment_run_call_count()
empty_fetches = not nest.flatten(fetches)
if empty_fetches:
tf_logging.info(
@@ -649,6 +652,18 @@ class BaseDebugWrapperSession(session.SessionInterface):
def increment_run_call_count(self):
self._run_call_count += 1
+ def _is_disk_usage_reset_each_run(self):
+ """Indicates whether disk usage is reset after each Session.run.
+
+ Subclasses that clean up the disk usage after every run should
+ override this protected method.
+
+ Returns:
+ (`bool`) Whether the disk usage amount is reset to zero after
+ each Session.run.
+ """
+ return False
+
def _decorate_run_options_for_debug(
self,
run_options,
@@ -686,7 +701,9 @@ class BaseDebugWrapperSession(session.SessionInterface):
node_name_regex_whitelist=node_name_regex_whitelist,
op_type_regex_whitelist=op_type_regex_whitelist,
tensor_dtype_regex_whitelist=tensor_dtype_regex_whitelist,
- tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures)
+ tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures,
+ reset_disk_byte_usage=(self._run_call_count == 1 or
+ self._is_disk_usage_reset_each_run()))
def _decorate_run_options_for_profile(self, run_options):
"""Modify a RunOptions object for profiling TensorFlow graph execution.
diff --git a/tensorflow/python/debug/wrappers/hooks.py b/tensorflow/python/debug/wrappers/hooks.py
index 5e4604fda4..872b675506 100644
--- a/tensorflow/python/debug/wrappers/hooks.py
+++ b/tensorflow/python/debug/wrappers/hooks.py
@@ -188,6 +188,7 @@ class DumpingDebugHook(session_run_hook.SessionRunHook):
pass
def before_run(self, run_context):
+ reset_disk_byte_usage = False
if not self._session_wrapper:
self._session_wrapper = dumping_wrapper.DumpingDebugWrapperSession(
run_context.session,
@@ -195,6 +196,7 @@ class DumpingDebugHook(session_run_hook.SessionRunHook):
watch_fn=self._watch_fn,
thread_name_filter=self._thread_name_filter,
log_usage=self._log_usage)
+ reset_disk_byte_usage = True
self._session_wrapper.increment_run_call_count()
@@ -212,7 +214,8 @@ class DumpingDebugHook(session_run_hook.SessionRunHook):
op_type_regex_whitelist=watch_options.op_type_regex_whitelist,
tensor_dtype_regex_whitelist=watch_options.tensor_dtype_regex_whitelist,
tolerate_debug_op_creation_failures=(
- watch_options.tolerate_debug_op_creation_failures))
+ watch_options.tolerate_debug_op_creation_failures),
+ reset_disk_byte_usage=reset_disk_byte_usage)
run_args = session_run_hook.SessionRunArgs(
None, feed_dict=None, options=run_options)
diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper.py b/tensorflow/python/debug/wrappers/local_cli_wrapper.py
index 668ffb57f1..a3ce4d388b 100644
--- a/tensorflow/python/debug/wrappers/local_cli_wrapper.py
+++ b/tensorflow/python/debug/wrappers/local_cli_wrapper.py
@@ -124,6 +124,11 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
self._ui_type = ui_type
+ def _is_disk_usage_reset_each_run(self):
+ # The dumped tensors are all cleaned up after every Session.run
+ # in a command-line wrapper.
+ return True
+
def _initialize_argparsers(self):
self._argparsers = {}
ap = argparse.ArgumentParser(
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 68d8b8d13b..bdc869c643 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -9,13 +9,36 @@ exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
py_library(
+ name = "distribute",
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":distribute_config",
+ ":distribute_coordinator",
+ ":distribute_coordinator_context",
+ ],
+)
+
+py_library(
+ name = "distribute_config",
+ srcs = [
+ "distribute_config.py",
+ ],
+ deps = [],
+)
+
+py_library(
name = "distribute_coordinator",
srcs = [
"distribute_coordinator.py",
],
srcs_version = "PY2AND3",
deps = [
+ ":distribute_coordinator_context",
+ ":multi_worker_util",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:session",
"//tensorflow/python:training",
],
)
@@ -25,7 +48,11 @@ py_test(
size = "large",
srcs = ["distribute_coordinator_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "manual",
+ "no_pip",
+ "notap",
+ ],
deps = [
":distribute_coordinator",
"//tensorflow/core:protos_all_py",
@@ -41,3 +68,57 @@ py_test(
"//tensorflow/python:variables",
],
)
+
+py_library(
+ name = "distribute_coordinator_context",
+ srcs = [
+ "distribute_coordinator_context.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [],
+)
+
+py_library(
+ name = "multi_worker_util",
+ srcs = [
+ "multi_worker_util.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:training",
+ ],
+)
+
+py_test(
+ name = "multi_worker_util_test",
+ srcs = ["multi_worker_util_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":multi_worker_util",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python/eager:test",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+# Used only by estimator.
+py_library(
+ name = "estimator_training",
+ srcs = [
+ "estimator_training.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":distribute_coordinator",
+ ":distribute_coordinator_context",
+ "//tensorflow/python:training",
+ ],
+)
diff --git a/tensorflow/python/distribute/distribute_config.py b/tensorflow/python/distribute/distribute_config.py
new file mode 100644
index 0000000000..fac35742fe
--- /dev/null
+++ b/tensorflow/python/distribute/distribute_config.py
@@ -0,0 +1,45 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A configure tuple for high-level APIs for running distribution strategies."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+
+class DistributeConfig(
+ collections.namedtuple(
+ 'DistributeConfig',
+ ['train_distribute', 'eval_distribute', 'remote_cluster'])):
+ """A config tuple for distribution strategies.
+
+ Attributes:
+ train_distribute: a `DistributionStrategy` object for training.
+ eval_distribute: an optional `DistributionStrategy` object for
+ evaluation.
+ remote_cluster: a dict, `ClusterDef` or `ClusterSpec` object specifying
+ the cluster configurations. If this is given, the `train_and_evaluate`
+ method will be running as a standalone client which connects to the
+ cluster for training.
+ """
+
+ def __new__(cls,
+ train_distribute=None,
+ eval_distribute=None,
+ remote_cluster=None):
+ return super(DistributeConfig, cls).__new__(cls, train_distribute,
+ eval_distribute, remote_cluster)
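A sketch of constructing the tuple; the strategy object and cluster hosts are
placeholders rather than values prescribed by this patch:

    config = DistributeConfig(
        train_distribute=my_train_strategy,  # hypothetical DistributionStrategy
        remote_cluster={"worker": ["host1:2222", "host2:2222"]})
    print(config.eval_distribute)  # None; all three fields default to None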
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index fc9ca4ac4a..bd3562f1ff 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""A unified and split coordinator for distributed TensorFlow."""
+"""A component for running distributed TensorFlow."""
from __future__ import absolute_import
from __future__ import division
@@ -22,8 +22,14 @@ import copy
import json
import os
import threading
-
-from tensorflow.core.protobuf import cluster_pb2
+import time
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.distribute import distribute_coordinator_context
+from tensorflow.python.distribute import multi_worker_util
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib
@@ -43,23 +49,12 @@ class CoordinatorMode(object):
# client and connects to remote servers for training. Each remote server can
# use the distribute coordinator binary with task_type set correctly which
# will then turn into standard servers.
- SPLIT_CLIENT = 0
+ STANDALONE_CLIENT = "standalone_client"
# The distribute coordinator runs on each worker. It will run a standard
# server on each worker and optionally run the `worker_fn` that is configured
# to talk to its standard server.
- INDEPENDENT_WORKER = 1
-
-
-_worker_context = threading.local()
-
-
-def get_current_worker_context():
- """Returns the current task context."""
- try:
- return _worker_context.current
- except AttributeError:
- return None
+ INDEPENDENT_WORKER = "independent_worker"
class _Barrier(object):
@@ -113,14 +108,17 @@ class _WorkerContext(object):
"""
def __init__(self,
+ strategy,
cluster_spec,
task_type,
task_id,
+ session_config=None,
rpc_layer="grpc",
worker_barrier=None):
"""Initialize the worker context object.
Args:
+ strategy: a `DistributionStrategy` object.
cluster_spec: a ClusterSpec object. It can be empty or None in the local
training case.
task_type: a string indicating the role of the corresponding task, such as
@@ -128,14 +126,17 @@ class _WorkerContext(object):
replicated training.
task_id: an integer indicating id of the corresponding task. It can be
None if it is local training or in-graph replicated training.
+      session_config: an optional `tf.ConfigProto` object.
rpc_layer: optional string specifying the RPC protocol for communication
with worker masters. If None or empty, hosts in the `cluster_spec` will
be used directly.
worker_barrier: optional, the barrier object for worker synchronization.
"""
+ self._strategy = strategy
self._cluster_spec = cluster_spec
self._task_type = task_type
self._task_id = task_id
+ self._session_config = session_config
self._worker_barrier = worker_barrier
self._rpc_layer = rpc_layer
self._master_target = self._get_master_target()
@@ -143,26 +144,31 @@ class _WorkerContext(object):
self._is_chief_node = self._is_chief()
def _debug_message(self):
- return "[cluster_spec: %r, task_type: %r, task_id: %r]" % (
- self._cluster_spec, self.task_type, self.task_id)
+ if self._cluster_spec:
+ return "[cluster_spec: %r, task_type: %r, task_id: %r]" % (
+ self._cluster_spec, self.task_type, self.task_id)
+ else:
+ return "[local]"
def __enter__(self):
- old_context = get_current_worker_context()
+ old_context = distribute_coordinator_context.get_current_worker_context()
if old_context:
raise ValueError(
"You cannot run distribute coordinator in a `worker_fn`.\t" +
self._debug_message())
- _worker_context.current = self
+ # pylint: disable=protected-access
+ distribute_coordinator_context._worker_context.current = self
def __exit__(self, unused_exception_type, unused_exception_value,
unused_traceback):
- _worker_context.current = None
+ # pylint: disable=protected-access
+ distribute_coordinator_context._worker_context.current = None
def _get_master_target(self):
"""Return the master target for a task."""
# If cluster_spec is None or empty, we use local master.
if not self._cluster_spec:
- return "local"
+ return ""
# If task_type is None, then it is in-graph replicated training. In this
# case we use the chief or first worker's master target.
@@ -207,6 +213,54 @@ class _WorkerContext(object):
self._debug_message())
self._worker_barrier.wait()
+ def session_creator(self,
+ scaffold=None,
+ config=None,
+ checkpoint_dir=None,
+ checkpoint_filename_with_path=None,
+ max_wait_secs=7200):
+ """Returns a session creator.
+
+ The returned session creator will be configured with the correct master
+ target and session configs. It will also run either init ops or ready ops
+ by querying the `strategy` object when `create_session` is called on it.
+
+ Args:
+      scaffold: A `Scaffold` used for gathering or building supportive ops. If
+        not specified, a default one is created. It's used to finalize the
+        graph.
+ config: `ConfigProto` proto used to configure the session.
+ checkpoint_dir: A string. Optional path to a directory where to restore
+ variables.
+ checkpoint_filename_with_path: Full file name path to the checkpoint file.
+ Only one of `checkpoint_dir` or `checkpoint_filename_with_path` can be
+ specified.
+ max_wait_secs: Maximum time to wait for the session to become available.
+
+ Returns:
+ a descendant of SessionCreator.
+ """
+ if config:
+ session_config = copy.deepcopy(config)
+ session_config.MergeFrom(self._session_config)
+ else:
+ session_config = self._session_config
+
+ if not self._strategy or self._strategy.should_init:
+ logging.info("Creating chief session creator with config: %r", config)
+ return monitored_session.ChiefSessionCreator(
+ scaffold,
+ master=self.master_target,
+ config=session_config,
+ checkpoint_dir=checkpoint_dir,
+ checkpoint_filename_with_path=checkpoint_filename_with_path)
+ else:
+ logging.info("Creating worker session creator with config: %r", config)
+ return monitored_session.WorkerSessionCreator(
+ scaffold,
+ master=self.master_target,
+ config=session_config,
+ max_wait_secs=max_wait_secs)
+
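Inside a `worker_fn`, the returned creator plugs directly into a
`MonitoredSession`; a hedged sketch in which `train_op` is an assumed,
already-built op:

    def worker_fn(strategy):
      context = distribute_coordinator_context.get_current_worker_context()
      with monitored_session.MonitoredSession(
          session_creator=context.session_creator()) as sess:
        while not sess.should_stop():
          sess.run(train_op)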
@property
def has_barrier(self):
"""Whether the barrier is set or not."""
@@ -247,46 +301,125 @@ class _WorkerContext(object):
"""Returns number of workers in the cluster, including chief."""
return self._num_workers
+ @property
+ def should_checkpoint(self):
+ """Whether to save checkpoint."""
+ return self._strategy.should_checkpoint
+
+ @property
+ def should_save_summary(self):
+ """Whether to save summaries."""
+ return self._strategy.should_save_summary
+
def _run_single_worker(worker_fn,
+ strategy,
cluster_spec,
task_type,
task_id,
- rpc_layer,
+ session_config,
+ rpc_layer="",
worker_barrier=None):
"""Runs a single worker by calling `worker_fn` under context."""
- with _WorkerContext(
+ session_config = copy.deepcopy(session_config)
+ strategy = copy.deepcopy(strategy)
+ # If there is an EVALUATOR task, we run single-machine eval on that task.
+ if task_type == _TaskType.EVALUATOR:
+ # It is possible to not have a strategy object for EVALUATOR task.
+ if strategy:
+ strategy.configure(session_config)
+ else:
+ assert strategy
+ strategy.configure(session_config, cluster_spec, task_type, task_id)
+
+ context = _WorkerContext(
+ strategy,
cluster_spec,
task_type,
task_id,
+ session_config=session_config,
rpc_layer=rpc_layer,
- worker_barrier=worker_barrier):
- worker_fn()
+ worker_barrier=worker_barrier)
+ with context:
+ worker_fn(strategy)
+
+
+def _split_cluster_for_evaluator(cluster_spec, task_type):
+ """Split the cluster for evaluator since it needn't talk to other tasks."""
+  # Splitting the cluster is important to prevent the evaluator from talking to
+  # other tasks in the cluster. Since the evaluator is allowed not to use a
+  # distribution strategy, ops in the evaluator task may have unspecified
+  # devices, and those ops could end up on other tasks if we don't split the
+  # cluster.
+ new_cluster_spec = multi_worker_util.normalize_cluster_spec(
+ cluster_spec).as_dict()
+ if task_type == _TaskType.EVALUATOR:
+ assert _TaskType.EVALUATOR in new_cluster_spec
+ new_cluster_spec = {
+ _TaskType.EVALUATOR: new_cluster_spec[_TaskType.EVALUATOR]
+ }
+ else:
+ new_cluster_spec.pop(_TaskType.EVALUATOR, None)
+ return multi_worker_util.normalize_cluster_spec(new_cluster_spec)
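Concretely, the helper keeps only the job the caller needs; a sketch of the two
outcomes for a small placeholder cluster:

    full = {"worker": ["w0:2222", "w1:2222"], "evaluator": ["e0:2222"]}
    _split_cluster_for_evaluator(full, "evaluator").as_dict()
    # => {"evaluator": ["e0:2222"]}
    _split_cluster_for_evaluator(full, "worker").as_dict()
    # => {"worker": ["w0:2222", "w1:2222"]}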
def _run_std_server(cluster_spec=None,
task_type=None,
task_id=None,
session_config=None,
- rpc_layer=None):
+ rpc_layer=None,
+ environment=None):
"""Runs a standard server."""
- server = server_lib.Server(
- cluster_spec,
- job_name=task_type,
- task_index=task_id,
- config=session_config,
- protocol=rpc_layer)
- server.start()
- return server
-
-
-def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer):
+ assert cluster_spec
+ target = cluster_spec.task_address(task_type, task_id)
+ if rpc_layer:
+ target = rpc_layer + "://" + target
+
+ class _FakeServer(object):
+ """A fake server that runs a master session."""
+
+ def start(self):
+ # A tensorflow server starts when a remote session is created.
+ logging.info(
+ "Creating a remote session to start a TensorFlow server, "
+ "target = %r, session_config=%r", target, session_config)
+ session.Session(target=target, config=session_config)
+
+ def join(self):
+ while True:
+ time.sleep(5)
+
+ if environment == "google":
+ server = _FakeServer()
+ server.start()
+ return server
+ else:
+ if session_config:
+ logging.info(
+ "Starting standard TensorFlow server, target = %r, session_config= "
+ "%r", target, session_config)
+ else:
+ logging.info("Starting standard TensorFlow server, target = %r", target)
+ cluster_spec = _split_cluster_for_evaluator(cluster_spec, task_type)
+ server = server_lib.Server(
+ cluster_spec,
+ job_name=task_type,
+ task_index=task_id,
+ config=session_config,
+ protocol=rpc_layer)
+ server.start()
+ return server
+
+
+def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
+ cluster_spec, session_config, rpc_layer):
"""Runs a standalone client for between-graph replication."""
eval_thread = None
if _TaskType.EVALUATOR in cluster_spec.jobs:
eval_thread = threading.Thread(
target=_run_single_worker,
- args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0),
+ args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0,
+ session_config),
kwargs={
"rpc_layer": rpc_layer,
})
@@ -298,7 +431,8 @@ def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer):
for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
t = threading.Thread(
target=_run_single_worker,
- args=(worker_fn, cluster_spec, task_type, task_id),
+ args=(worker_fn, strategy, cluster_spec, task_type, task_id,
+ session_config),
kwargs={
"rpc_layer": rpc_layer,
"worker_barrier": worker_barrier
@@ -315,43 +449,155 @@ def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer):
eval_thread.join()
-def _run_in_graph_client(worker_fn, cluster_spec, rpc_layer):
+def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
+ cluster_spec, session_config, rpc_layer):
"""Runs a standalone client for in-graph replication."""
eval_thread = None
if _TaskType.EVALUATOR in cluster_spec.jobs:
eval_thread = threading.Thread(
target=_run_single_worker,
- args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0),
+ args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0,
+ session_config),
kwargs={
"rpc_layer": rpc_layer,
})
eval_thread.start()
- _run_single_worker(worker_fn, cluster_spec, None, None, rpc_layer)
+ _run_single_worker(
+ worker_fn,
+ strategy,
+ cluster_spec,
+ None,
+ None,
+ session_config,
+ rpc_layer=rpc_layer)
if eval_thread:
eval_thread.join()
-# TODO(yuefengz): propagate cluster_spec in the SPLIT_CLIENT mode.
+def _configure_session_config_for_std_servers(
+ strategy, eval_strategy, session_config, cluster_spec, task_type, task_id):
+ # pylint: disable=g-doc-args
+ """Call strategy's `configure` to mutate the session_config.
+
+ The session_config is currently needed as default config for a TensorFlow
+ server. In the future, we should be able to remove this method and only pass
+ the session config to a client session.
+ """
+ if task_type == _TaskType.EVALUATOR:
+ if eval_strategy:
+ eval_strategy.configure(session_config=session_config)
+ else:
+ # The strategy may be shared in standalone client mode.
+ strategy = copy.deepcopy(strategy)
+ strategy.configure(
+ session_config=session_config,
+ cluster_spec=cluster_spec,
+ task_type=task_type,
+ task_id=task_id)
+ # Remove the device filters specific to the strategy, so that the
+ # TensorFlow server brought up with one strategy can be used by other
+ # strategies. The device filters can be set in the client side as well.
+ del session_config.device_filters[:]
+
+
+def run_standard_tensorflow_server(session_config=None):
+ """Starts a standard TensorFlow server.
+
+ This method parses configurations from "TF_CONFIG" environment variable and
+ starts a TensorFlow server. The "TF_CONFIG" is typically a json string and
+ must have information of the cluster and the role of the server in the
+ cluster. One example is:
+
+ TF_CONFIG='{
+ "cluster": {
+ "worker": ["host1:2222", "host2:2222", "host3:2222"],
+ "ps": ["host4:2222", "host5:2222"]
+ },
+ "task": {"type": "worker", "index": 1}
+ }'
+
+ This "TF_CONFIG" specifies there are 3 workers and 2 ps tasks in the cluster
+ and the current role is worker 1.
+
+  Valid task types are "chief", "worker", "ps" and "evaluator", and you can
+  have at most one "chief" and at most one "evaluator".
+
+  An optional key-value pair that can be specified is "rpc_layer"; its default
+  value is "grpc".
+
+ Args:
+ session_config: an optional `tf.ConfigProto` object. Users can pass in
+ the session config object to configure server-local devices.
+
+ Returns:
+ a `tf.train.Server` object which has already been started.
+
+ Raises:
+    ValueError: if the "TF_CONFIG" environment variable is not complete.
+ """
+ tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
+ if "cluster" not in tf_config:
+ raise ValueError("\"cluster\" is not found in TF_CONFIG.")
+ cluster_spec = multi_worker_util.normalize_cluster_spec(tf_config["cluster"])
+ if "task" not in tf_config:
+ raise ValueError("\"task\" is not found in TF_CONFIG.")
+ task_env = tf_config["task"]
+ if "type" not in task_env:
+    raise ValueError(
+        "\"type\" is not found in the `task` part of TF_CONFIG.")
+ task_type = task_env["type"]
+ task_id = int(task_env.get("index", 0))
+
+ rpc_layer = tf_config.get("rpc_layer", "grpc")
+
+ session_config = session_config or config_pb2.ConfigProto()
+ # Set the collective group leader for collective ops to initialize collective
+ # ops when server starts.
+ if "chief" in cluster_spec.jobs:
+ session_config.experimental.collective_group_leader = (
+ "/job:chief/replica:0/task:0")
+ else:
+ if "worker" not in cluster_spec.jobs:
+ raise ValueError(
+ "You must have `chief` or `worker` jobs in the `cluster_spec`.")
+ session_config.experimental.collective_group_leader = (
+ "/job:worker/replica:0/task:0")
+
+ server = _run_std_server(
+ cluster_spec=cluster_spec,
+ task_type=task_type,
+ task_id=task_id,
+ session_config=session_config,
+ rpc_layer=rpc_layer)
+ server.start()
+ return server
+
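A sketch of launching one such server process with this helper; the host list
is a placeholder:

    import json
    import os

    os.environ["TF_CONFIG"] = json.dumps({
        "cluster": {"worker": ["host1:2222", "host2:2222"]},
        "task": {"type": "worker", "index": 0},
    })
    server = run_standard_tensorflow_server()
    server.join()  # serve until the process is terminated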
+
+# TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode.
# TODO(yuefengz): we may need a smart way to figure out whether the current task
# is the special task when we support cluster_spec propagation.
def run_distribute_coordinator(worker_fn,
- mode=CoordinatorMode.SPLIT_CLIENT,
+ strategy,
+ eval_fn=None,
+ eval_strategy=None,
+ mode=CoordinatorMode.STANDALONE_CLIENT,
cluster_spec=None,
task_type=None,
task_id=None,
- between_graph=False,
+ session_config=None,
rpc_layer="grpc"):
"""Runs the coordinator for distributed TensorFlow.
This function runs a split coordinator for distributed TensorFlow in its
- default mode, i.e the SPLIT_CLIENT mode. Given a `cluster_spec` specifying
- server addresses and their roles in a cluster, this coordinator will figure
- out how to set them up, give the underlying function the right targets for
- master sessions via a scope object and coordinate their training. The cluster
- consisting of standard servers needs to be brought up either with the standard
- server binary or with a binary running distribute coordinator with `task_type`
- set to non-client type which will then turn into standard servers.
+  default mode, i.e. the STANDALONE_CLIENT mode. Given a `cluster_spec`
+ specifying server addresses and their roles in a cluster, this coordinator
+ will figure out how to set them up, give the underlying function the right
+ targets for master sessions via a scope object and coordinate their training.
+ The cluster consisting of standard servers needs to be brought up either with
+ the standard server binary or with a binary running distribute coordinator
+ with `task_type` set to non-client type which will then turn into standard
+ servers.
   In addition to being the distribute coordinator, this is also the source of
configurations for each job in the distributed training. As there are multiple
@@ -370,6 +616,14 @@ def run_distribute_coordinator(worker_fn,
`worker_fn` depending whether it is between-graph training or in-graph
replicated training.
+  The `strategy` object is expected to be a DistributionStrategy object which
+  has implemented the methods needed by the distribute coordinator, such as
+  `configure(session_config, cluster_spec, task_type, task_id)`, which
+  configures the strategy object for a specific task, and a `should_init`
+  property, which instructs the distribute coordinator whether to run init
+  ops for a task. The
+ distribute coordinator will make a copy of the `strategy` object, call its
+ `configure` method and pass it to `worker_fn` as an argument.
+
   The `worker_fn` defines the training logic and is called under its own
   worker context which can be accessed via `get_current_worker_context`. A
worker context provides access to configurations for each task, e.g. the
@@ -407,22 +661,32 @@ def run_distribute_coordinator(worker_fn,
If `cluster_spec` is not given in any format, it becomes local training and
this coordinator will connect to a local session.
- For evaluation, if "evaluator" exist in the cluster_spec, a separate thread
- will be created with its `task_type` set to "evaluator". If "evaluator" is not
- set in the cluster_spec, it entirely depends on the `worker_fn` for how to do
- evaluation.
+  For evaluation, if "evaluator" exists in the cluster_spec, a separate thread
+  will be created to call `eval_fn` with its `task_type` set to "evaluator". If
+  `eval_fn` is not defined, the coordinator falls back to `worker_fn`. This
+  implies that evaluation will be done on a single machine if there is an
+  "evaluator" task. If "evaluator" doesn't exist in the cluster_spec, it
+  entirely depends on the `worker_fn` for how to do evaluation.
Args:
- worker_fn: the function to be called and given the access to a coordinator
- context object.
+ worker_fn: the function to be called. The function should accept a
+ `strategy` object and will be given access to a context object via a
+ context manager scope.
+    strategy: a DistributionStrategy object which specifies whether it should
+      run between-graph replicated training or not, whether to run init ops,
+      etc. This object will also be configured given `session_config`,
+      `cluster_spec`, `task_type` and `task_id`.
+    eval_fn: optional function for the "evaluator" task. If `eval_fn` is not
+      passed in but an "evaluator" task is found in the `cluster_spec`, the
+      `worker_fn` will be used for this task.
+ eval_strategy: optional DistributionStrategy object for "evaluator" task.
mode: in which mode this distribute coordinator runs.
cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
in a cluster. If not set or empty, fall back to local training.
task_type: the current task type, optional if this is a client.
task_id: the current task id, optional if this is a client.
- between_graph: a boolean. It is only useful when `cluster_spec` is set and
- not empty. If true, it will use between-graph replicated training;
- otherwise it will use in-graph replicated training.
+    session_config: an optional `tf.ConfigProto` object which will be passed
+ to `strategy`'s `configure` method and used to create a session.
rpc_layer: optional string, the protocol for RPC, e.g. "grpc".
Raises:
@@ -438,52 +702,105 @@ def run_distribute_coordinator(worker_fn,
task_id = int(task_env.get("index", task_id))
if cluster_spec:
- if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
- cluster_spec = server_lib.ClusterSpec(cluster_spec)
- elif not isinstance(cluster_spec, server_lib.ClusterSpec):
- raise ValueError(
- "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
- "`tf.train.ClusterDef` object")
+ cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
# TODO(yuefengz): validate cluster_spec.
+ rpc_layer = tf_config.get("rpc_layer", rpc_layer)
+ environment = tf_config.get("environment", None)
+
+  # Setting the session config is necessary for some strategies such as
+  # CollectiveAllReduceStrategy.
+ session_config = session_config or config_pb2.ConfigProto(
+ allow_soft_placement=True)
+
+ if cluster_spec:
+ logging.info(
+ "Running Distribute Coordinator with mode = %r, cluster_spec = %r, "
+ "task_type = %r, task_id = %r, environment = %r, rpc_layer = %r", mode,
+ cluster_spec.as_dict(), task_type, task_id, environment, rpc_layer)
+
if not cluster_spec:
# `mode` is ignored in the local case.
- _run_single_worker(worker_fn, None, None, None, rpc_layer)
- elif mode == CoordinatorMode.SPLIT_CLIENT:
+ logging.info("Running local Distribute Coordinator.")
+ _run_single_worker(worker_fn, strategy, None, None, None, session_config,
+ rpc_layer)
+ if eval_fn:
+ _run_single_worker(eval_fn, eval_strategy, None, None, None,
+ session_config, rpc_layer)
+ else:
+ logging.warning("Skipped evaluation since `eval_fn` is not passed in.")
+ elif mode == CoordinatorMode.STANDALONE_CLIENT:
+ if not eval_fn:
+ logging.warning("`eval_fn` is not passed in. The `worker_fn` will be "
+ "used if an \"evaluator\" task exists in the cluster.")
+ eval_fn = eval_fn or worker_fn
+ if not eval_strategy:
+ logging.warning("`eval_strategy` is not passed in. No distribution "
+ "strategy will be used for evaluation.")
+
# The client must know the cluster but servers in the cluster don't have to
# know the client.
if task_type in [_TaskType.CLIENT, None]:
- if between_graph:
- _run_between_graph_client(worker_fn, cluster_spec, rpc_layer)
+ if strategy.between_graph:
+ _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
+ cluster_spec, session_config, rpc_layer)
else:
- _run_in_graph_client(worker_fn, cluster_spec, rpc_layer)
+ _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
+ cluster_spec, session_config, rpc_layer)
else:
# If not a client job, run the standard server.
+ _configure_session_config_for_std_servers(strategy, eval_strategy,
+ session_config, cluster_spec,
+ task_type, task_id)
server = _run_std_server(
- cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
+ cluster_spec=cluster_spec,
+ task_type=task_type,
+ task_id=task_id,
+ session_config=session_config,
+ rpc_layer=rpc_layer,
+ environment=environment)
server.join()
else:
if mode != CoordinatorMode.INDEPENDENT_WORKER:
raise ValueError("Unexpected coordinator mode: %r" % mode)
- # Every one starts a standard server.
+ if not eval_fn:
+ logging.warning("`eval_fn` is not passed in. The `worker_fn` will be "
+ "used if an \"evaluator\" task exists in the cluster.")
+ eval_fn = eval_fn or worker_fn
+ if not eval_strategy:
+ logging.warning("`eval_strategy` is not passed in. No distribution "
+ "strategy will be used for evaluation.")
+
+ # Everyone starts a standard server; the session config comes from the
+ # strategy's `configure` method.
+ _configure_session_config_for_std_servers(strategy, eval_strategy,
+ session_config, cluster_spec,
+ task_type, task_id)
server = _run_std_server(
- cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
+ cluster_spec=cluster_spec,
+ task_type=task_type,
+ task_id=task_id,
+ session_config=session_config,
+ rpc_layer=rpc_layer,
+ environment=environment)
if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
- if between_graph:
+ if strategy.between_graph:
# All jobs run `worker_fn` if between-graph.
- _run_single_worker(worker_fn, cluster_spec, task_type, task_id,
- rpc_layer)
+ _run_single_worker(worker_fn, strategy, cluster_spec, task_type,
+ task_id, session_config, rpc_layer)
else:
# Only one node runs `worker_fn` if in-graph.
- context = _WorkerContext(cluster_spec, task_type, task_id, rpc_layer)
+ context = _WorkerContext(strategy, cluster_spec, task_type, task_id)
if context.is_chief:
- _run_single_worker(worker_fn, cluster_spec, None, None, rpc_layer)
+ _run_single_worker(worker_fn, strategy, cluster_spec, None, None,
+ session_config, rpc_layer)
else:
server.join()
elif task_type == _TaskType.EVALUATOR:
- _run_single_worker(worker_fn, cluster_spec, task_type, task_id, rpc_layer)
+ _run_single_worker(eval_fn, eval_strategy, cluster_spec, task_type,
+ task_id, session_config, rpc_layer)
else:
if task_type != _TaskType.PS:
raise ValueError("Unexpected task_type: %r" % task_type)
diff --git a/tensorflow/contrib/kfac/python/ops/op_queue_lib.py b/tensorflow/python/distribute/distribute_coordinator_context.py
index 09c9a4ab33..dee65ce883 100644
--- a/tensorflow/contrib/kfac/python/ops/op_queue_lib.py
+++ b/tensorflow/python/distribute/distribute_coordinator_context.py
@@ -1,4 +1,4 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,19 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Helper for choosing which op to run next in a distributed setting."""
+"""The context retrieval method for distribute coordinator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.op_queue import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
+import threading
-_allowed_symbols = [
- 'OpQueue',
-]
+_worker_context = threading.local()
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
+
+def get_current_worker_context():
+ """Returns the current task context."""
+ try:
+ return _worker_context.current
+ except AttributeError:
+ return None
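Since the context is stored in a `threading.local`, each thread the coordinator spawns sees only the context installed for it. A self-contained sketch of the install/retrieve pattern this module supports; `_FakeContext` is illustrative, as the real context manager lives in distribute_coordinator.py:

```python
# Sketch of the thread-local install/retrieve pattern (not the actual
# _WorkerContext implementation).
import threading

_worker_context = threading.local()

class _FakeContext(object):
  def __enter__(self):
    _worker_context.current = self  # install for this thread only
    return self

  def __exit__(self, *exc_info):
    del _worker_context.current  # uninstall when the scope exits

def get_current_worker_context():
  try:
    return _worker_context.current
  except AttributeError:
    return None

with _FakeContext():
  assert get_current_worker_context() is not None
assert get_current_worker_context() is None
```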
diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py
index 319c29ba2f..b07308a1b5 100644
--- a/tensorflow/python/distribute/distribute_coordinator_test.py
+++ b/tensorflow/python/distribute/distribute_coordinator_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for distribute coordinator."""
+"""Tests for Distribute Coordinator."""
from __future__ import absolute_import
from __future__ import division
@@ -20,23 +20,25 @@ from __future__ import print_function
import contextlib
import copy
+import json
import os
import sys
import threading
+import time
import six
-# pylint: disable=invalid-name
_portpicker_import_error = None
try:
import portpicker # pylint: disable=g-import-not-at-top
-except ImportError as _error:
+except ImportError as _error: # pylint: disable=invalid-name
_portpicker_import_error = _error
portpicker = None
-# pylint: enable=invalid-name
+# pylint: disable=g-import-not-at-top
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator
+from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
@@ -44,20 +46,22 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.training import monitored_session
+
CHIEF = distribute_coordinator._TaskType.CHIEF
WORKER = distribute_coordinator._TaskType.WORKER
PS = distribute_coordinator._TaskType.PS
EVALUATOR = distribute_coordinator._TaskType.EVALUATOR
-SPLIT_CLIENT = distribute_coordinator.CoordinatorMode.SPLIT_CLIENT
+STANDALONE_CLIENT = distribute_coordinator.CoordinatorMode.STANDALONE_CLIENT
INDEPENDENT_WORKER = distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER
-RUN_STD_SERVER_METHOD = "tensorflow.python.distribute.distribute_coordinator._run_std_server"
-
NUM_WORKERS = 3
NUM_PS = 2
+original_sys_exit = sys.exit
+
def _bytes_to_str(maybe_bytes):
if isinstance(maybe_bytes, six.string_types):
@@ -74,10 +78,75 @@ def _strip_protocol(target):
return target
+class MockStrategy(object):
+
+ def __init__(self,
+ between_graph=False,
+ should_init=None,
+ should_checkpoint=None,
+ should_save_summary=None):
+ self._between_graph = between_graph
+ self._should_init = should_init
+ self._should_checkpoint = should_checkpoint
+ self._should_save_summary = should_save_summary
+
+ @property
+ def between_graph(self):
+ return self._between_graph
+
+ def configure(self,
+ session_config=None,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None):
+ if self._should_init is None:
+ if task_id == 0:
+ self._should_init = True
+ else:
+ self._should_init = False
+ if self._should_checkpoint is None:
+ if task_id == 0:
+ self._should_checkpoint = True
+ else:
+ self._should_checkpoint = False
+ if self._should_save_summary is None:
+ if task_id == 0:
+ self._should_save_summary = True
+ else:
+ self._should_save_summary = False
+
+ if session_config:
+ if (cluster_spec and task_type and task_id is not None and
+ self._between_graph):
+ session_config.intra_op_parallelism_threads += 1
+ if task_type in ["chief", "worker"]:
+ session_config.device_filters.extend(
+ ["/job:%s/task:%d" % (task_type, task_id), "/job:ps"])
+ else:
+ session_config.inter_op_parallelism_threads += 1
+ session_config.device_filters.append("/job:somejob")
+
+ @property
+ def should_init(self):
+ return self._should_init
+
+ @property
+ def should_checkpoint(self):
+ return self._should_checkpoint
+
+ @property
+ def should_save_summary(self):
+ return self._should_save_summary
+
+
class MockServer(object):
def __init__(self):
self._joined = False
+ self._started = False
+
+ def start(self):
+ self._started = True
def join(self):
assert not self._joined
@@ -87,6 +156,10 @@ class MockServer(object):
def joined(self):
return self._joined
+ @property
+ def started(self):
+ return self._started
+
class DistributeCoordinatorTestBase(test.TestCase):
@@ -95,6 +168,7 @@ class DistributeCoordinatorTestBase(test.TestCase):
# We have to create a global in-process cluster because once an in-process
# tensorflow server is created, there is no way to terminate it. Please see
# multi_worker_test_base.py for more details.
+ # TODO(yuefengz): use the utility from multi_worker_test_base.
cls._workers, cls._ps = test_util.create_local_cluster(
NUM_WORKERS, num_ps=NUM_PS)
cls._cluster_spec = {
@@ -108,6 +182,7 @@ class DistributeCoordinatorTestBase(test.TestCase):
self._result_correct = 0
self._lock = threading.Lock()
self._worker_context = {}
+ self._strategy_property = {}
self._std_servers = {}
self._barrier = distribute_coordinator._Barrier(NUM_WORKERS)
@@ -118,6 +193,7 @@ class DistributeCoordinatorTestBase(test.TestCase):
with session.Session(graph=None, config=config, target=target) as sess:
yield sess
+ # TODO(yuefengz): use the utility from multi_worker_test_base.
def _create_cluster_spec(self,
has_chief=False,
num_workers=1,
@@ -142,8 +218,8 @@ class DistributeCoordinatorTestBase(test.TestCase):
cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()]
return cluster_spec
- def _in_graph_worker_fn(self):
- context = distribute_coordinator.get_current_worker_context()
+ def _in_graph_worker_fn(self, strategy):
+ context = distribute_coordinator_context.get_current_worker_context()
self.assertTrue(context is not None)
with self._test_session(target=context.master_target) as sess:
xs = []
@@ -164,22 +240,23 @@ class DistributeCoordinatorTestBase(test.TestCase):
if result_value == expected:
self._result_correct += 1
- def _run_coordinator_in_thread(self, worker_fn, **kwargs):
+ def _run_coordinator_in_thread(self, worker_fn, strategy, **kwargs):
t = threading.Thread(
target=distribute_coordinator.run_distribute_coordinator,
- args=(worker_fn,),
+ args=(worker_fn, strategy),
kwargs=kwargs)
t.start()
return t
- def _run_multiple_coordinator_in_threads(self, worker_fn, cluster_spec,
- **kwargs):
+ def _run_multiple_coordinator_in_threads(self, worker_fn, strategy,
+ cluster_spec, **kwargs):
threads = {}
for task_type in cluster_spec.keys():
threads[task_type] = []
for task_id in range(len(cluster_spec[task_type])):
t = self._run_coordinator_in_thread(
worker_fn,
+ strategy,
cluster_spec=cluster_spec,
task_type=task_type,
task_id=task_id,
@@ -187,8 +264,8 @@ class DistributeCoordinatorTestBase(test.TestCase):
threads[task_type].append(t)
return threads
- def _between_graph_worker_fn(self):
- context = distribute_coordinator.get_current_worker_context()
+ def _between_graph_worker_fn(self, strategy):
+ context = distribute_coordinator_context.get_current_worker_context()
self.assertTrue(context is not None)
with self._test_session(target=context.master_target) as sess:
with ops.device("/job:ps/task:0"):
@@ -234,14 +311,50 @@ class DistributeCoordinatorTestBase(test.TestCase):
with self._lock:
self._result_correct += 1
- def _dump_worker_context(self):
+ def _between_graph_with_monitored_session(self, strategy):
+ context = distribute_coordinator_context.get_current_worker_context()
+ self.assertTrue(context is not None)
+ with ops.device("/job:ps/task:0"):
+ # TODO(yuefengz): investigate why not using resource variables makes
+ # the test flaky.
+ x = variable_scope.get_variable("x", initializer=10.0, use_resource=True)
+ with ops.device("/job:ps/task:1"):
+ y = variable_scope.get_variable("y", initializer=20.0, use_resource=True)
+
+ x_add = x.assign_add(2.0)
+ y_sub = y.assign_sub(2.0)
+ train_op = control_flow_ops.group([x_add, y_sub])
+
+ # The monitored session will run init or ready ops.
+ with monitored_session.MonitoredSession() as sess:
+ sess.run(train_op)
+
+ # Synchronize workers after one step to make sure they all have finished
+ # training.
+ if context.has_barrier:
+ context.wait_for_other_workers()
+ else:
+ self._barrier.wait()
+
+ x_val, y_val = sess.run([x, y])
+
+ self.assertEqual(x_val, 16.0)
+ self.assertEqual(y_val, 14.0)
+ if x_val == 16.0 and y_val == 14.0:
+ with self._lock:
+ self._result_correct += 1
+
+ def _dump_worker_context(self, strategy):
"""Dumps the propoerties of each worker context.
It dumps the context properties to a dict mapping from task_type to a list
of tuples of master_target, num_workers, is_chief and distributed_mode, where
the list is indexed by the task_id.
+
+ Args:
+ strategy: a `DistributionStrategy` object.
"""
- context = distribute_coordinator.get_current_worker_context()
+ context = distribute_coordinator_context.get_current_worker_context()
self.assertTrue(context is not None)
task_type = str(context.task_type)
task_id = context.task_id or 0
@@ -255,12 +368,32 @@ class DistributeCoordinatorTestBase(test.TestCase):
context.is_chief,
context.distributed_mode)
+ def _dump_strategy_property(self, strategy):
+ context = distribute_coordinator_context.get_current_worker_context()
+ self.assertTrue(context is not None)
+
+ self.assertEqual(context._strategy.should_init, strategy.should_init)
+ self.assertEqual(context.should_checkpoint, strategy.should_checkpoint)
+ self.assertEqual(context.should_save_summary, strategy.should_save_summary)
+
+ task_type = str(context.task_type)
+ task_id = context.task_id or 0
+ with self._lock:
+ if task_type not in self._strategy_property:
+ self._strategy_property[task_type] = []
+ while len(self._strategy_property[task_type]) <= task_id:
+ self._strategy_property[task_type].append(None)
+ self._strategy_property[task_type][task_id] = (
+ context._strategy.should_init, context.should_checkpoint,
+ context.should_save_summary)
+
def _run_mock_std_server(self,
session_config=None,
cluster_spec=None,
task_type=None,
task_id=None,
- rpc_layer=None):
+ rpc_layer=None,
+ environment=None):
task_type = str(task_type)
task_id = task_id or 0
with self._lock:
@@ -274,22 +407,32 @@ class DistributeCoordinatorTestBase(test.TestCase):
return server
-class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
+class DistributeCoordinatorTestStandaloneMode(DistributeCoordinatorTestBase):
- def testInGraphSplitMode(self):
- """Test it runs in-graph replication in split client mode."""
+ def testInGraphStandaloneMode(self):
+ """Test it runs in-graph replication in standalone client mode."""
distribute_coordinator.run_distribute_coordinator(
self._in_graph_worker_fn,
- cluster_spec=self._cluster_spec,
- between_graph=False)
+ MockStrategy(between_graph=False),
+ cluster_spec=self._cluster_spec)
self.assertEqual(self._result_correct, 1)
def testBetweenGraph(self):
- """Test it runs between-graph replication in split client mode."""
+ """Test it runs between-graph replication in standalone client mode."""
distribute_coordinator.run_distribute_coordinator(
self._between_graph_worker_fn,
- cluster_spec=self._cluster_spec,
- between_graph=True)
+ MockStrategy(between_graph=True),
+ cluster_spec=self._cluster_spec)
+
+ # Each finished worker will increment self._result_correct.
+ self.assertEqual(self._result_correct, NUM_WORKERS)
+
+ def testBetweenGraphWithMonitoredSession(self):
+ """Test monitored session in standalone client mode."""
+ distribute_coordinator.run_distribute_coordinator(
+ self._between_graph_with_monitored_session,
+ MockStrategy(between_graph=True),
+ cluster_spec=self._cluster_spec)
# Each finished worker will increment self._result_correct.
self.assertEqual(self._result_correct, NUM_WORKERS)
@@ -298,8 +441,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context,
- cluster_spec=self._cluster_spec,
- between_graph=True)
+ MockStrategy(between_graph=True),
+ cluster_spec=self._cluster_spec)
# There is only one type of task and there are three such tasks.
self.assertEqual(len(self._worker_context), 1)
@@ -318,12 +461,30 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
self._worker_context[WORKER][2],
(_bytes_to_str(self._workers[2].target), NUM_WORKERS, False, True))
+ def testBetweenGraphStrategyProperties(self):
+ # Dumps properties of the strategy objects.
+ distribute_coordinator.run_distribute_coordinator(
+ self._dump_strategy_property,
+ MockStrategy(between_graph=True, should_init=True),
+ cluster_spec=self._cluster_spec)
+
+ # There is only one type of task and there are three such tasks.
+ self.assertEqual(len(self._strategy_property), 1)
+ self.assertTrue(WORKER in self._strategy_property)
+ self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)
+
+ # Check whether each task has the right properties of should_init,
+ # should_checkpoint and should_save_summary.
+ self.assertEqual(self._strategy_property[WORKER][0], (True, True, True))
+ self.assertEqual(self._strategy_property[WORKER][1], (True, False, False))
+ self.assertEqual(self._strategy_property[WORKER][2], (True, False, False))
+
def testInGraphContext(self):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context,
- cluster_spec=self._cluster_spec,
- between_graph=False)
+ MockStrategy(between_graph=False),
+ cluster_spec=self._cluster_spec)
# There is only a "None" task in the dumped task context.
self.assertEqual(len(self._worker_context), 1)
@@ -339,7 +500,9 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
def testLocalContext(self):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
- self._dump_worker_context, cluster_spec=None, between_graph=True)
+ self._dump_worker_context,
+ MockStrategy(between_graph=False),
+ cluster_spec=None)
# There is only a "None" task.
self.assertEqual(len(self._worker_context), 1)
@@ -348,7 +511,7 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
- self.assertEqual(self._worker_context["None"][0], ("local", 0, True, False))
+ self.assertEqual(self._worker_context["None"][0], ("", 0, True, False))
def testBetweenGraphContextWithChief(self):
# Adds a chief node, so there are NUM_WORKERS + 1 workers in total.
@@ -358,8 +521,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context,
+ MockStrategy(between_graph=True),
cluster_spec=cluster_spec,
- between_graph=True,
rpc_layer="grpc")
# There is one CHIEF and there are three workers.
@@ -391,8 +554,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context,
+ MockStrategy(between_graph=False),
cluster_spec=cluster_spec,
- between_graph=False,
rpc_layer=None)
# There are one "None" task and one EVALUATOR task.
@@ -417,8 +580,8 @@ class DistributeCoordinatorTestInpendentWorkerMode(
cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
threads = self._run_multiple_coordinator_in_threads(
self._in_graph_worker_fn,
+ MockStrategy(between_graph=False),
cluster_spec,
- between_graph=False,
mode=INDEPENDENT_WORKER)
threads[WORKER][0].join()
self.assertEqual(self._result_correct, 1)
@@ -428,8 +591,22 @@ class DistributeCoordinatorTestInpendentWorkerMode(
num_workers=NUM_WORKERS, num_ps=NUM_PS)
threads = self._run_multiple_coordinator_in_threads(
self._between_graph_worker_fn,
+ MockStrategy(between_graph=True),
+ cluster_spec,
+ mode=INDEPENDENT_WORKER)
+ for task_id in range(NUM_WORKERS):
+ threads[WORKER][task_id].join()
+
+ # Each finished worker will increment self._result_correct.
+ self.assertEqual(self._result_correct, NUM_WORKERS)
+
+ def testBetweenGraphWithMonitoredSession(self):
+ cluster_spec = self._create_cluster_spec(
+ num_workers=NUM_WORKERS, num_ps=NUM_PS)
+ threads = self._run_multiple_coordinator_in_threads(
+ self._between_graph_with_monitored_session,
+ MockStrategy(between_graph=True),
cluster_spec,
- between_graph=True,
mode=INDEPENDENT_WORKER)
for task_id in range(NUM_WORKERS):
threads[WORKER][task_id].join()
@@ -444,9 +621,9 @@ class DistributeCoordinatorTestInpendentWorkerMode(
self._run_mock_std_server):
threads = self._run_multiple_coordinator_in_threads(
self._dump_worker_context,
+ MockStrategy(between_graph=True),
cluster_spec,
mode=INDEPENDENT_WORKER,
- between_graph=True,
rpc_layer=None)
for task_id in range(NUM_WORKERS):
threads[WORKER][task_id].join()
@@ -476,6 +653,31 @@ class DistributeCoordinatorTestInpendentWorkerMode(
self.assertFalse(self._std_servers[WORKER][1].joined)
self.assertFalse(self._std_servers[WORKER][2].joined)
+ def testBetweenGraphStrategyProperties(self):
+ cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
+ # Dumps properties of the strategy objects.
+ with test.mock.patch.object(distribute_coordinator, "_run_std_server",
+ self._run_mock_std_server):
+ threads = self._run_multiple_coordinator_in_threads(
+ self._dump_strategy_property,
+ MockStrategy(between_graph=True, should_init=True),
+ cluster_spec,
+ mode=INDEPENDENT_WORKER,
+ rpc_layer=None)
+ for task_id in range(NUM_WORKERS):
+ threads[WORKER][task_id].join()
+
+ # There is only one type of task and there are three such tasks.
+ self.assertEqual(len(self._strategy_property), 1)
+ self.assertTrue(WORKER in self._strategy_property)
+ self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)
+
+ # Check whether each task has the right properties of should_init,
+ # should_checkpoint and should_save_summary.
+ self.assertEqual(self._strategy_property[WORKER][0], (True, True, True))
+ self.assertEqual(self._strategy_property[WORKER][1], (True, False, False))
+ self.assertEqual(self._strategy_property[WORKER][2], (True, False, False))
+
def testInGraphContext(self):
cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
# Dumps the task contexts and std server arguments.
@@ -483,9 +685,9 @@ class DistributeCoordinatorTestInpendentWorkerMode(
self._run_mock_std_server):
threads = self._run_multiple_coordinator_in_threads(
self._dump_worker_context,
+ MockStrategy(between_graph=False),
cluster_spec,
mode=INDEPENDENT_WORKER,
- between_graph=False,
rpc_layer=None)
for task_id in range(NUM_WORKERS):
threads[WORKER][task_id].join()
@@ -519,9 +721,9 @@ class DistributeCoordinatorTestInpendentWorkerMode(
self._run_mock_std_server):
threads = self._run_multiple_coordinator_in_threads(
self._dump_worker_context,
+ MockStrategy(between_graph=False),
cluster_spec,
mode=INDEPENDENT_WORKER,
- between_graph=False,
rpc_layer=None)
for task_id in range(NUM_WORKERS):
threads[WORKER][task_id].join()
@@ -552,6 +754,178 @@ class DistributeCoordinatorTestInpendentWorkerMode(
self.assertTrue(self._std_servers[WORKER][2].joined)
self.assertFalse(self._std_servers[EVALUATOR][0].joined)
+ def testRunStdServerInGoogleEnvironment(self):
+ cluster_spec = {"worker": ["fake_worker"], "ps": ["localhost:0"]}
+ tf_config = {"cluster": cluster_spec, "environment": "google"}
+
+ joined = [False]
+
+ def _fake_sleep(_):
+ joined[0] = True
+ original_sys_exit(0)
+
+ def _thread_fn(cluster_spec):
+ distribute_coordinator.run_distribute_coordinator(
+ None,
+ MockStrategy(between_graph=True),
+ mode=INDEPENDENT_WORKER,
+ cluster_spec=cluster_spec,
+ task_type="ps",
+ task_id=0)
+
+ with test.mock.patch.dict(
+ "os.environ",
+ {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
+ time, "sleep", _fake_sleep):
+ t = threading.Thread(target=_thread_fn, args=(cluster_spec,))
+ t.start()
+ t.join()
+ self.assertTrue(joined[0])
+
+ def testRpcLayerEnvironmentVariable(self):
+ cluster_spec = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
+ tf_config = {"cluster": cluster_spec, "rpc_layer": "cake"}
+
+ rpc_layer_from_coordinator = [None]
+
+ def _run_mock_server(cluster_spec=None,
+ task_type=None,
+ task_id=None,
+ session_config=None,
+ rpc_layer=None,
+ environment=None):
+ del cluster_spec, task_type, task_id, session_config, environment
+ rpc_layer_from_coordinator[0] = rpc_layer
+ return MockServer()
+
+ with test.mock.patch.dict(
+ "os.environ",
+ {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
+ distribute_coordinator, "_run_std_server", _run_mock_server):
+ distribute_coordinator.run_distribute_coordinator(
+ None,
+ MockStrategy(between_graph=True),
+ mode=INDEPENDENT_WORKER,
+ cluster_spec=cluster_spec,
+ task_type="ps",
+ task_id=0)
+ self.assertEqual(rpc_layer_from_coordinator[0], "cake")
+
+
+class StrategyConfigureTest(test.TestCase):
+
+ def setUp(self):
+ self._device_filters = []
+ self._intra_op_parallelism_threads = None
+ self._inter_op_parallelism_threads = None
+ super(StrategyConfigureTest, self).setUp()
+
+ def _dump_device_filters(self, *args, **kwargs):
+ session_config = kwargs.get("session_config", None)
+ self._device_filters.extend(session_config.device_filters)
+ self._intra_op_parallelism_threads = (
+ session_config.intra_op_parallelism_threads)
+ self._inter_op_parallelism_threads = (
+ session_config.inter_op_parallelism_threads)
+ return MockServer()
+
+ def _worker_fn(self, strategy):
+ worker_context = distribute_coordinator_context.get_current_worker_context()
+ session_config = worker_context._session_config
+ self._device_filters.extend(session_config.device_filters)
+ self._intra_op_parallelism_threads = (
+ session_config.intra_op_parallelism_threads)
+ self._inter_op_parallelism_threads = (
+ session_config.inter_op_parallelism_threads)
+ return MockServer()
+
+ def test_session_config_in_std_server(self):
+ cluster_spec = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
+ tf_config = {"cluster": cluster_spec}
+
+ with test.mock.patch.dict(
+ "os.environ",
+ {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
+ distribute_coordinator, "_run_std_server",
+ self._dump_device_filters):
+ distribute_coordinator.run_distribute_coordinator(
+ lambda _: None,
+ MockStrategy(between_graph=True),
+ mode=INDEPENDENT_WORKER,
+ cluster_spec=cluster_spec,
+ task_type="worker",
+ task_id=0)
+ self.assertEqual(self._intra_op_parallelism_threads, 1)
+ self.assertEqual(self._inter_op_parallelism_threads, 0)
+
+ def test_session_config_in_session_creator(self):
+ cluster_spec = {"worker": ["localhost:0"]}
+ tf_config = {"cluster": cluster_spec}
+
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(tf_config)}):
+ distribute_coordinator.run_distribute_coordinator(
+ self._worker_fn,
+ MockStrategy(between_graph=True),
+ mode=INDEPENDENT_WORKER,
+ cluster_spec=cluster_spec,
+ task_type="worker",
+ task_id=0)
+ self.assertEqual(self._device_filters, ["/job:worker/task:0", "/job:ps"])
+ self.assertEqual(self._intra_op_parallelism_threads, 2)
+ self.assertEqual(self._inter_op_parallelism_threads, 0)
+
+ def test_eval_strategy_configure(self):
+ cluster_spec = {"evaluator": ["localhost:0"]}
+ tf_config = {"cluster": cluster_spec}
+
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(tf_config)}):
+ distribute_coordinator.run_distribute_coordinator(
+ lambda _: None,
+ MockStrategy(between_graph=False),
+ eval_fn=self._worker_fn,
+ eval_strategy=MockStrategy(between_graph=True),
+ mode=INDEPENDENT_WORKER,
+ cluster_spec=cluster_spec,
+ task_type="evaluator",
+ task_id=0)
+ self.assertEqual(self._device_filters, ["/job:somejob"])
+ self.assertEqual(self._intra_op_parallelism_threads, 0)
+ self.assertEqual(self._inter_op_parallelism_threads, 2)
+
+
+class RunStandardTensorflowServerTest(test.TestCase):
+
+ def test_std_server_arguments(self):
+ cs = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
+ tf_config = {"cluster": cs, "task": {"type": "ps", "id": 0}}
+
+ def _mock_run_std_server(cluster_spec=None,
+ task_type=None,
+ task_id=None,
+ session_config=None,
+ rpc_layer=None):
+ self.assertEqual(cluster_spec.as_dict(), cs)
+ self.assertEqual(task_type, "ps")
+ self.assertEqual(task_id, 0)
+ self.assertEqual(session_config.experimental.collective_group_leader,
+ "/job:worker/replica:0/task:0")
+ self.assertEqual(session_config.intra_op_parallelism_threads, 1)
+ self.assertEqual(rpc_layer, "grpc")
+
+ return MockServer()
+
+ with test.mock.patch.dict(
+ "os.environ",
+ {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
+ distribute_coordinator, "_run_std_server", _mock_run_std_server):
+ session_config = config_pb2.ConfigProto()
+ session_config.intra_op_parallelism_threads = 1
+ mock_server = distribute_coordinator.run_standard_tensorflow_server(
+ session_config)
+ self.assertTrue(mock_server.started)
+
if __name__ == "__main__":
# TODO(yuefengz): find a smart way to terminate std server threads.
diff --git a/tensorflow/python/distribute/estimator_training.py b/tensorflow/python/distribute/estimator_training.py
new file mode 100644
index 0000000000..e17a598123
--- /dev/null
+++ b/tensorflow/python/distribute/estimator_training.py
@@ -0,0 +1,264 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Training utilities for Estimator to use Distribute Coordinator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+
+import six
+
+from tensorflow.python.distribute import distribute_coordinator as dc
+from tensorflow.python.distribute import distribute_coordinator_context as dc_context
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import server_lib
+
+# pylint: disable=protected-access
+CHIEF = dc._TaskType.CHIEF
+EVALUATOR = dc._TaskType.EVALUATOR
+PS = dc._TaskType.PS
+WORKER = dc._TaskType.WORKER
+
+# pylint: enable=protected-access
+
+
+def _count_ps(cluster_spec):
+ """Counts the number of parameter servers in cluster_spec."""
+ if not cluster_spec:
+ raise RuntimeError(
+ 'Internal error: `_count_ps` does not expect empty cluster_spec.')
+
+ return len(cluster_spec.as_dict().get(PS, []))
+
+
+def _count_worker(cluster_spec, chief_task_type):
+ """Counts the number of workers (including chief) in cluster_spec."""
+ if not cluster_spec:
+ raise RuntimeError(
+ 'Internal error: `_count_worker` does not expect empty cluster_spec.')
+
+ return (len(cluster_spec.as_dict().get(WORKER, [])) + len(
+ cluster_spec.as_dict().get(chief_task_type, [])))
+
+
+def _get_global_id(cluster_spec, task_type, task_id, chief_task_type):
+ """Returns the global id of the given task type in a cluster."""
+ if not task_type:
+ return 0
+
+ # Sort task names in cluster by "chief"/"master", "evaluator", "worker"
+ # and "ps". More details can be found at the documentation of
+ # @{tf.estimator.RunConfig.global_id_in_cluster}.
+ task_type_ordered_list = []
+ if chief_task_type in cluster_spec.jobs:
+ task_type_ordered_list = [chief_task_type]
+ task_type_ordered_list.extend([
+ t for t in sorted(cluster_spec.jobs) if t != chief_task_type and t != PS
+ ])
+ if PS in cluster_spec.jobs:
+ task_type_ordered_list.append(PS)
+
+ # Find the right global_id for the current task.
+ next_global_id = 0
+ for t in task_type_ordered_list:
+ if t == task_type:
+ return next_global_id + task_id
+ # `cluster_spec.job_tasks` returns all task addresses of type `t`.
+ next_global_id += len(cluster_spec.job_tasks(t))
+
+ # Reaching this point means `task_type` was not found in
+ # `task_type_ordered_list`, which is unexpected.
+ raise RuntimeError('Internal Error: `task_type` ({}) is not in '
+ 'cluster_spec ({}).'.format(task_type, cluster_spec))
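For concreteness, with the ordering rule above a cluster of one chief, two workers and two ps tasks yields global ids chief:0, worker:1-2, ps:3-4. A standalone sketch of the same rule; the helper name and addresses are illustrative, not part of this file:

```python
# Standalone sketch of the global-id ordering rule: chief first, then the
# other non-ps job names in sorted order, ps last. Illustrative only.
def _sketch_global_id(cluster, task_type, task_id, chief_task_type="chief"):
  ordered = []
  if chief_task_type in cluster:
    ordered.append(chief_task_type)
  ordered.extend(
      t for t in sorted(cluster) if t not in (chief_task_type, "ps"))
  if "ps" in cluster:
    ordered.append("ps")
  offset = 0
  for t in ordered:
    if t == task_type:
      return offset + task_id
    offset += len(cluster[t])  # skip over all tasks of type `t`
  raise RuntimeError("task_type not in cluster")

cluster = {"chief": ["c:0"], "worker": ["w:0", "w:1"], "ps": ["p:0", "p:1"]}
assert _sketch_global_id(cluster, "chief", 0) == 0
assert _sketch_global_id(cluster, "worker", 1) == 2
assert _sketch_global_id(cluster, "ps", 0) == 3
```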
+
+
+def _init_run_config_from_worker_context(config, worker_context):
+ """Initializes run config from distribute coordinator's worker context."""
+
+ # pylint: disable=protected-access
+ config._service = None
+ config._cluster_spec = worker_context.cluster_spec
+ config._task_type = worker_context.task_type
+ config._task_id = worker_context.task_id
+ config._evaluation_master = worker_context.master_target
+ config._master = worker_context.master_target
+ config._is_chief = worker_context.is_chief
+
+ if config._cluster_spec:
+ # Distributed mode.
+ if config._task_type != EVALUATOR:
+
+ config._num_ps_replicas = _count_ps(config._cluster_spec)
+ config._num_worker_replicas = _count_worker(
+ config._cluster_spec, chief_task_type=CHIEF)
+ config._global_id_in_cluster = _get_global_id(
+ config._cluster_spec,
+ config._task_type,
+ config._task_id,
+ chief_task_type=CHIEF)
+ else:
+ # Evaluator task should not be aware of the other tasks.
+ config._cluster_spec = server_lib.ClusterSpec({})
+ config._num_ps_replicas = 0
+ config._num_worker_replicas = 0
+ config._global_id_in_cluster = None # undefined
+ else:
+ # Local mode.
+ config._global_id_in_cluster = 0
+ config._num_ps_replicas = 0
+ config._num_worker_replicas = 1
+
+
+def init_run_config(config, tf_config):
+ """Initializes RunConfig for distribution strategies."""
+ # pylint: disable=protected-access
+ if (config._experimental_distribute and
+ config._experimental_distribute.train_distribute):
+ if config._train_distribute:
+ raise ValueError('Either `train_distribute` or '
+ '`experimental_distribute.train_distribute` can be set.')
+ config._train_distribute = config._experimental_distribute.train_distribute
+
+ if (config._experimental_distribute and
+ config._experimental_distribute.eval_distribute):
+ if config._eval_distribute:
+ raise ValueError('Either `eval_distribute` or '
+ '`experimental_distribute.eval_distribute` can be set.')
+ config._eval_distribute = config._experimental_distribute.eval_distribute
+
+ cluster_spec = server_lib.ClusterSpec(tf_config.get('cluster', {}))
+ config._init_distributed_setting_from_environment_var({})
+
+ # Use distribute coordinator with STANDALONE_CLIENT mode if
+ # `experimental_distribute.remote_cluster` is set.
+ if (config._train_distribute and config._experimental_distribute and
+ config._experimental_distribute.remote_cluster):
+ if cluster_spec:
+ raise ValueError('Cannot set both "cluster_spec" of TF_CONFIG and '
+ '`experimental_distribute.remote_cluster`')
+ config._distribute_coordinator_mode = dc.CoordinatorMode.STANDALONE_CLIENT
+ config._cluster_spec = config._experimental_distribute.remote_cluster
+ logging.info('RunConfig initialized for Distribute Coordinator with '
+ 'STANDALONE_CLIENT mode')
+ return
+
+ # Don't use distribute coordinator if it is local training or cluster has a
+ # MASTER job or `train_distribute` is not specified.
+ if (not tf_config or 'master' in cluster_spec.jobs or
+ not config._train_distribute):
+ config._distribute_coordinator_mode = None
+ config._init_distributed_setting_from_environment_var(tf_config)
+ config._maybe_overwrite_session_config_for_distributed_training()
+ logging.info('Not using Distribute Coordinator.')
+ return
+
+ # Use distribute coordinator with INDEPENDENT_WORKER mode otherwise.
+ assert tf_config
+
+ # Set the cluster_spec only since the distributed setting will come from
+ # distribute coordinator.
+ config._cluster_spec = cluster_spec
+ config._distribute_coordinator_mode = dc.CoordinatorMode.INDEPENDENT_WORKER
+ logging.info('RunConfig initialized for Distribute Coordinator with '
+ 'INDEPENDENT_WORKER mode')
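To summarize the branching above with a hedged example: the mode is driven by the run config and the TF_CONFIG environment variable. The values below are illustrative only:

```python
# Illustrative TF_CONFIG that would select INDEPENDENT_WORKER mode above
# (assuming `train_distribute` is set, there is no "master" job, and
# `experimental_distribute.remote_cluster` is unset; a remote_cluster would
# select STANDALONE_CLIENT instead, and no TF_CONFIG disables the coordinator).
import json
import os

os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {"worker": ["host0:2222", "host1:2222"],
                "ps": ["host2:2222"]},
    "task": {"type": "worker", "index": 0},
})
```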
+
+
+def should_run_distribute_coordinator(config):
+ """Checks the config to see whether to run distribute coordinator."""
+ # pylint: disable=protected-access
+ if (not hasattr(config, '_distribute_coordinator_mode') or
+ config._distribute_coordinator_mode is None):
+ return False
+ if (not isinstance(config._distribute_coordinator_mode, six.string_types) or
+ config._distribute_coordinator_mode not in [
+ dc.CoordinatorMode.STANDALONE_CLIENT,
+ dc.CoordinatorMode.INDEPENDENT_WORKER
+ ]):
+ logging.warning('Unexpected distribute_coordinator_mode: %r',
+ config._distribute_coordinator_mode)
+ return False
+ if not config.cluster_spec:
+ logging.warning('Running `train_and_evaluate` locally, ignoring '
+ '`experimental_distribute_coordinator_mode`.')
+ return False
+ return True
+
+
+def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
+ """Run distribute coordinator for Estimator's `train_and_evaluate`.
+
+ Args:
+ estimator: An `Estimator` instance to train and evaluate.
+ train_spec: A `TrainSpec` instance to specify the training specification.
+ eval_spec: A `EvalSpec` instance to specify the evaluation and export
+ specification.
+ executor_cls: the evaluation executor class of Estimator.
+
+ Raises:
+ ValueError: if `distribute_coordinator_mode` is None in RunConfig.
+ """
+ run_config = estimator.config
+ if not run_config._distribute_coordinator_mode: # pylint: disable=protected-access
+ raise ValueError(
+ 'Distribute coordinator mode is not specified in `RunConfig`.')
+
+ def _worker_fn(strategy):
+ """Function for worker task."""
+ local_estimator = copy.deepcopy(estimator)
+ # pylint: disable=protected-access
+ local_estimator._config._train_distribute = strategy
+ _init_run_config_from_worker_context(
+ local_estimator._config, dc_context.get_current_worker_context())
+ local_estimator._train_distribution = strategy
+ # pylint: enable=protected-access
+
+ local_estimator.train(
+ input_fn=train_spec.input_fn,
+ max_steps=train_spec.max_steps,
+ hooks=list(train_spec.hooks))
+
+ def _eval_fn(strategy):
+ """Function for evaluator task."""
+ local_estimator = copy.deepcopy(estimator)
+ # pylint: disable=protected-access
+ local_estimator._config._eval_distribute = strategy
+ _init_run_config_from_worker_context(
+ local_estimator._config, dc_context.get_current_worker_context())
+ local_estimator._eval_distribution = strategy
+
+ executor = executor_cls(local_estimator, train_spec, eval_spec)
+ executor._start_continuous_evaluation()
+ # pylint: enable=protected-access
+
+ # pylint: disable=protected-access
+ if (run_config._distribute_coordinator_mode ==
+ dc.CoordinatorMode.STANDALONE_CLIENT):
+ cluster_spec = run_config.cluster_spec
+ assert cluster_spec
+ else:
+ # In INDEPENDENT_WORKER mode the cluster_spec comes from the TF_CONFIG
+ # environment variable.
+ cluster_spec = None
+
+ dc.run_distribute_coordinator(
+ _worker_fn,
+ run_config.train_distribute,
+ _eval_fn,
+ run_config.eval_distribute,
+ mode=run_config._distribute_coordinator_mode,
+ cluster_spec=cluster_spec,
+ session_config=run_config.session_config)
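A hedged sketch of how a training entry point might dispatch to this module; `estimator`, the specs and `executor_cls` are placeholders supplied by Estimator internals, not defined here:

```python
# Sketch only: the names below are placeholders for objects created by
# Estimator's train_and_evaluate machinery, not part of this change.
from tensorflow.python.distribute import estimator_training

if estimator_training.should_run_distribute_coordinator(estimator.config):
  # Run workers and the evaluator through the distribute coordinator.
  estimator_training.train_and_evaluate(
      estimator, train_spec, eval_spec, executor_cls)
else:
  # Fall back to the regular (non-coordinator) execution path.
  executor_cls(estimator, train_spec, eval_spec).run()
```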
diff --git a/tensorflow/python/distribute/multi_worker_util.py b/tensorflow/python/distribute/multi_worker_util.py
new file mode 100644
index 0000000000..360733eff6
--- /dev/null
+++ b/tensorflow/python/distribute/multi_worker_util.py
@@ -0,0 +1,80 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for multi-worker distribution strategies."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.training import server_lib
+
+
+def normalize_cluster_spec(cluster_spec):
+ """Makes `cluster_spec` into a `ClusterSpec` object.
+
+ Args:
+ cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
+ cluster configurations.
+
+ Returns:
+ a `ClusterSpec` object.
+
+ Raises:
+ ValueError: if `cluster_spec` is not a dict or a `ClusterSpec` or a
+ `ClusterDef`.
+ """
+ if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
+ return server_lib.ClusterSpec(cluster_spec)
+ elif not isinstance(cluster_spec, server_lib.ClusterSpec):
+ raise ValueError(
+ "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
+ "`tf.train.ClusterDef` object")
+ return cluster_spec
+
+
+def is_chief(cluster_spec, task_type, task_id):
+ """Returns whether the given task is chief in the cluster.
+
+ Args:
+ cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
+ cluster configurations.
+ task_type: the task type in the cluster.
+ task_id: the task id in the cluster.
+
+ Returns:
+ a boolean indicating whether the given task is chief.
+
+ Raises:
+ ValueError: if `task_type` is not in the `cluster_spec` or `task_id` exceeds
+ the maximum id of the `task_type`.
+ """
+ cluster_spec = normalize_cluster_spec(cluster_spec)
+ if task_type not in cluster_spec.jobs:
+ raise ValueError(
+ "The task_type \"%s\" is not in the `cluster_spec`." % task_type)
+ if task_id >= cluster_spec.num_tasks(task_type):
+ raise ValueError("The `task_id` %d exceeds the maximum id of %s." % (
+ task_id, task_type))
+
+ if task_type == "chief":
+ return True
+
+ # If "chief" is not in the cluster_spec, use the first worker as the chief.
+ # This is common in CollectiveAllReduceStrategy.
+ if ("chief" not in cluster_spec.jobs and task_type == "worker" and
+ task_id == 0):
+ return True
+ return False
diff --git a/tensorflow/python/distribute/multi_worker_util_test.py b/tensorflow/python/distribute/multi_worker_util_test.py
new file mode 100644
index 0000000000..bdc49725c7
--- /dev/null
+++ b/tensorflow/python/distribute/multi_worker_util_test.py
@@ -0,0 +1,107 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for multi_worker_util."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.distribute import multi_worker_util
+from tensorflow.python.eager import test
+from tensorflow.python.training import server_lib
+
+
+class NormalizeClusterSpecTest(test.TestCase):
+
+ def assert_same_cluster(self, lhs, rhs):
+ self.assertEqual(
+ server_lib.ClusterSpec(lhs).as_dict(),
+ server_lib.ClusterSpec(rhs).as_dict())
+
+ def testDictAsInput(self):
+ cluster_spec = {
+ "chief": ["127.0.0.1:1234"],
+ "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+ "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+ }
+ self.assert_same_cluster(
+ cluster_spec, multi_worker_util.normalize_cluster_spec(cluster_spec))
+
+ def testClusterDefAsInput(self):
+ cluster_def = cluster_pb2.ClusterDef()
+ job = cluster_def.job.add()
+ job.name = "chief"
+ job.tasks[0] = "127.0.0.1:1234"
+
+ job = cluster_def.job.add()
+ job.name = "worker"
+ job.tasks[0] = "127.0.0.1:8964"
+ job.tasks[1] = "127.0.0.1:2333"
+
+ job = cluster_def.job.add()
+ job.name = "ps"
+ job.tasks[0] = "127.0.0.1:1926"
+ job.tasks[1] = "127.0.0.1:3141"
+
+ self.assert_same_cluster(
+ cluster_def, multi_worker_util.normalize_cluster_spec(cluster_def))
+
+ def testClusterSpecAsInput(self):
+ cluster_spec = server_lib.ClusterSpec({
+ "chief": ["127.0.0.1:1234"],
+ "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+ "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+ })
+ self.assert_same_cluster(
+ cluster_spec, multi_worker_util.normalize_cluster_spec(cluster_spec))
+
+ def testUnexpectedInput(self):
+ cluster_spec = ["127.0.0.1:8964", "127.0.0.1:2333"]
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
+ "`tf.train.ClusterDef` object"):
+ multi_worker_util.normalize_cluster_spec(cluster_spec)
+
+
+class IsChiefTest(test.TestCase):
+
+ def testClusterWithChief(self):
+ cluster_spec = {
+ "chief": ["127.0.0.1:1234"],
+ "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
+ "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
+ }
+ self.assertTrue(multi_worker_util.is_chief(cluster_spec, "chief", 0))
+ self.assertFalse(multi_worker_util.is_chief(cluster_spec, "worker", 0))
+
+ def testClusterWithoutChief(self):
+ cluster_spec = {"worker": ["127.0.0.1:8964", "127.0.0.1:2333"]}
+ self.assertTrue(multi_worker_util.is_chief(cluster_spec, "worker", 0))
+ self.assertFalse(multi_worker_util.is_chief(cluster_spec, "worker", 1))
+
+ with self.assertRaisesRegexp(
+ ValueError, "The task_type \"chief\" is not in the `cluster_spec`."):
+ multi_worker_util.is_chief(cluster_spec, "chief", 0)
+
+ with self.assertRaisesRegexp(
+ ValueError, "The `task_id` 2 exceeds the maximum id of worker."):
+ multi_worker_util.is_chief(cluster_spec, "worker", 2)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index de93b1e2e1..85da1baaf0 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -47,7 +47,6 @@ py_library(
":core",
":execute",
":function",
- ":graph_callable",
":graph_only_ops",
":tape",
":test",
@@ -238,10 +237,11 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
":graph_only_ops",
+ "//tensorflow/python:cond_v2_impl",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
- "//tensorflow/python:gradients",
+ "//tensorflow/python:gradients_impl",
"//tensorflow/python:graph_to_function_def",
"//tensorflow/python:util",
"//tensorflow/python/eager:context",
@@ -254,41 +254,6 @@ py_library(
)
py_library(
- name = "graph_callable",
- srcs = ["graph_callable.py"],
- srcs_version = "PY2AND3",
- visibility = ["//tensorflow:internal"],
- deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:resource_variable_ops",
- "//tensorflow/python:util",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/eager:context",
- "//tensorflow/python/eager:function",
- "//tensorflow/python/eager:tape",
- ],
-)
-
-py_test(
- name = "graph_callable_test",
- srcs = ["graph_callable_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":backprop",
- ":graph_callable",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:function",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/eager:test",
- ],
-)
-
-py_library(
name = "backprop",
srcs = ["backprop.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 553f761a14..be392c7a0f 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -34,6 +34,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
@@ -180,10 +181,10 @@ def implicit_val_and_grad(f):
```
Args:
- f: function to be differentiated. If `f` returns a scalar, this scalar will
- be differentiated. If `f` returns a tensor or list of tensors, by default
- a scalar will be computed by adding all their values to produce a single
- scalar.
+ f: function to be differentiated. If `f` returns a scalar, this scalar will
+ be differentiated. If `f` returns a tensor or list of tensors, by default
+ a scalar will be computed by adding all their values to produce a single
+ scalar.
Returns:
A function which, when called, returns a tuple pair.
@@ -215,9 +216,7 @@ def implicit_val_and_grad(f):
"function was being computed.")
sources = [v.handle for v in variables]
- grad = imperative_grad.imperative_grad(_default_vspace,
- this_tape,
- nest.flatten(end_node),
+ grad = imperative_grad.imperative_grad(this_tape, nest.flatten(end_node),
sources)
return end_node, list(zip(grad, variables))
@@ -255,10 +254,10 @@ def implicit_grad(f):
```
Args:
- f: function to be differentiated. If `f` returns a scalar, this scalar will
- be differentiated. If `f` returns a tensor or list of tensors, by default
- a scalar will be computed by adding all their values to produce a single
- scalar.
+ f: function to be differentiated. If `f` returns a scalar, this scalar will
+ be differentiated. If `f` returns a tensor or list of tensors, by default
+ a scalar will be computed by adding all their values to produce a single
+ scalar.
Returns:
A function which, when called, returns a list of (gradient, variable) pairs.
@@ -343,24 +342,24 @@ def gradients_function(f, params=None):
Note that only tensors with real or complex dtypes are differentiable.
Args:
- f: function to be differentiated. If `f` returns a scalar, this scalar will
- be differentiated. If `f` returns a tensor or list of tensors, by default
- a scalar will be computed by adding all their values to produce a single
- scalar. If desired, the tensors can be elementwise multiplied by the
- tensors passed as the `dy` keyword argument to the returned gradient
- function.
- params: list of parameter names of f or list of integers indexing the
- parameters with respect to which we'll differentiate. Passing None
- differentiates with respect to all parameters.
+ f: function to be differentiated. If `f` returns a scalar, this scalar will
+ be differentiated. If `f` returns a tensor or list of tensors, by default
+ a scalar will be computed by adding all their values to produce a single
+ scalar. If desired, the tensors can be elementwise multiplied by the
+ tensors passed as the `dy` keyword argument to the returned gradient
+ function.
+ params: list of parameter names of f or list of integers indexing the
+ parameters with respect to which we'll differentiate. Passing None
+ differentiates with respect to all parameters.
Returns:
function which, when called, returns the value of f and the gradient
- of f with respect to all of `params`. The function takes an extra optional
- keyword argument "dy". Setting it allows computation of vector jacobian
+ of `f` with respect to all of `params`. The function takes an extra optional
+ keyword argument `dy`. Setting it allows computation of vector jacobian
products for vectors other than the vector of ones.
Raises:
- ValueError: if the params are not all strings or all integers.
+ ValueError: if the params are not all strings or all integers.
"""
def decorated(*args, **kwds):
@@ -440,23 +439,24 @@ def val_and_grad_function(f, params=None):
```
Args:
- f: function to be differentiated. If `f` returns a scalar, this scalar will
- be differentiated. If `f` returns a tensor or list of tensors, by default
- a scalar will be computed by adding all their values to produce a single
- scalar. If desired, the tensors can be elementwise multiplied by the
- tensors passed as the `dy` keyword argument to the returned gradient
- function.
- params: list of parameter names of f or list of integers indexing the
- parameters with respect to which we'll differentiate. Passing `None`
- differentiates with respect to all parameters.
-
- Returns: function which, when called, returns the value of f and the gradient
- of f with respect to all of `params`. The function takes an extra optional
- keyword argument "dy". Setting it allows computation of vector jacobian
- products for vectors other than the vector of ones.
+ f: function to be differentiated. If `f` returns a scalar, this scalar will
+ be differentiated. If `f` returns a tensor or list of tensors, by default
+ a scalar will be computed by adding all their values to produce a single
+ scalar. If desired, the tensors can be elementwise multiplied by the
+ tensors passed as the `dy` keyword argument to the returned gradient
+ function.
+ params: list of parameter names of f or list of integers indexing the
+ parameters with respect to which we'll differentiate. Passing `None`
+ differentiates with respect to all parameters.
+
+ Returns:
+ function which, when called, returns the value of f and the gradient
+ of f with respect to all of `params`. The function takes an extra optional
+ keyword argument "dy". Setting it allows computation of vector jacobian
+ products for vectors other than the vector of ones.
Raises:
- ValueError: if the params are not all strings or all integers.
+ ValueError: if the params are not all strings or all integers.
"""
def decorated(*args, **kwds):
@@ -520,7 +520,7 @@ def make_vjp(f, params=None, persistent=True):
args = _ensure_unique_tensor_objects(parameter_positions, args)
for i in parameter_positions:
sources.append(args[i])
- tape.watch(args[i])
+ tape.watch(this_tape, args[i])
result = f(*args)
if result is None:
raise ValueError("Cannot differentiate a function that returns None; "
@@ -535,8 +535,8 @@ def make_vjp(f, params=None, persistent=True):
if dy is not None:
dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)]
return imperative_grad.imperative_grad(
- _default_vspace, this_tape, nest.flatten(result), sources,
- output_gradients=dy)
+ this_tape, nest.flatten(result), sources, output_gradients=dy)
+
return result, vjp
return decorated
@@ -557,7 +557,7 @@ def _aggregate_grads(gradients):
if len(gradients) == 1:
return gradients[0]
if all([isinstance(g, ops.Tensor) for g in gradients]):
- return math_ops.add_n(gradients)
+ return gen_math_ops.add_n(gradients)
else:
assert all([isinstance(g, (ops.Tensor, ops.IndexedSlices))
for g in gradients])
@@ -592,7 +592,9 @@ def _num_elements(grad):
def _fast_fill(value, shape, dtype):
- return array_ops.fill(shape, constant_op.constant(value, dtype=dtype))
+ return array_ops.fill(
+ constant_op.constant(shape, dtype=dtypes.int32),
+ constant_op.constant(value, dtype=dtype))
def _zeros(shape, dtype):
@@ -627,9 +629,9 @@ def _ones(shape, dtype):
_default_vspace = imperative_grad.VSpace(
num_elements_fn=_num_elements,
aggregate_fn=_aggregate_grads,
- tensor_id=ops.tensor_id,
zeros=_zeros,
ones=_ones)
+pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace)
def _handle_or_self(x):
@@ -691,19 +693,57 @@ class GradientTape(object):
del g # Drop the reference to the tape
```
+ By default GradientTape will automatically watch any trainable variables that
+ are accessed inside the context. If you want fine grained control over which
+ variables are watched you can disable automatic tracking by passing
+ `watch_accessed_variables=False` to the tape constructor:
+
+ ```python
+ with tf.GradientTape(watch_accessed_variables=False) as tape:
+ tape.watch(variable_a)
+ y = variable_a ** 2 # Gradients will be available for `variable_a`.
+ z = variable_b ** 3 # No gradients will be available since `variable_b` is
+ # not being watched.
+ ```
+
+  Note that when using models you should ensure that your variables exist when
+  using `watch_accessed_variables=False`. Otherwise it is easy for the first
+  iteration to end up with no gradients:
+
+ ```python
+ a = tf.keras.layers.Dense(32)
+ b = tf.keras.layers.Dense(32)
+
+ with tf.GradientTape(watch_accessed_variables=False) as tape:
+ tape.watch(a.variables) # Since `a.build` has not been called at this point
+ # `a.variables` will return an empty list and the
+ # tape will not be watching anything.
+ result = b(a(inputs))
+ tape.gradient(result, a.variables) # The result of this computation will be
+ # a list of `None`s since a's variables
+ # are not being watched.
+ ```
+
Note that only tensors with real or complex dtypes are differentiable.
"""
- def __init__(self, persistent=False):
+ def __init__(self, persistent=False, watch_accessed_variables=True):
"""Creates a new GradientTape.
Args:
persistent: Boolean controlling whether a persistent gradient tape
is created. False by default, which means at most one call can
be made to the gradient() method on this object.
+ watch_accessed_variables: Boolean controlling whether the tape will
+ automatically `watch` any (trainable) variables accessed while the tape
+        is active. Defaults to True, meaning gradients can be requested from
+        any result computed in the tape that derives from reading a trainable
+        `Variable`. If False, users must explicitly `watch` any `Variable`s
+        they want to request gradients from.
"""
self._tape = None
self._persistent = persistent
+ self._watch_accessed_variables = watch_accessed_variables
self._recording = False
context.context().start_step()
@@ -717,15 +757,15 @@ class GradientTape(object):
if self._recording:
self._pop_tape()
- def _push_tape(self, existing_tape=False):
+ def _push_tape(self):
if self._recording:
raise ValueError("Tape is already recording.")
- if existing_tape:
- if self._tape is None:
- raise ValueError("There is no existing tape.")
- tape.push_tape(self._tape)
+ if self._tape is None:
+ self._tape = tape.push_new_tape(
+ persistent=self._persistent,
+ watch_accessed_variables=self._watch_accessed_variables)
else:
- self._tape = tape.push_new_tape(persistent=self._persistent)
+ tape.push_tape(self._tape)
self._recording = True
def _pop_tape(self):
@@ -744,7 +784,13 @@ class GradientTape(object):
tensor: a Tensor or list of Tensors.
"""
for t in nest.flatten(tensor):
- tape.watch(_handle_or_self(t))
+ if hasattr(t, "handle"):
+        # There are many variable-like objects, all of which currently have a
+        # `handle` attribute that points to a tensor. If this changes, the
+        # internals of watch_variable need to change as well.
+ tape.watch_variable(self._tape, t)
+ else:
+ tape.watch(self._tape, t)
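
A brief usage sketch for the `watch` method above (standard public behavior; plain tensors are never tracked automatically, so they must be watched explicitly):

```python
import tensorflow as tf

tf.enable_eager_execution()  # TF 1.x-era; eager is the default in TF 2.

x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)  # constants are not watched automatically
  y = x * x
print(tape.gradient(y, x))  # 6.0
```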
@tf_contextlib.contextmanager
def stop_recording(self):
@@ -776,7 +822,7 @@ class GradientTape(object):
try:
yield
finally:
- self._push_tape(existing_tape=True)
+ self._push_tape()
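
And for `stop_recording`, whose re-entry logic this hunk simplifies, a hedged usage sketch: operations run inside the block are not traced, which saves memory for values that need no gradient.

```python
import tensorflow as tf

tf.enable_eager_execution()  # TF 1.x-era; eager is the default in TF 2.

x = tf.constant(4.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = x * x
  with tape.stop_recording():
    z = y * y  # computed, but no gradient path is recorded through here
print(tape.gradient(y, x))  # 8.0
```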
def reset(self):
"""Clears all information stored in this tape.
@@ -810,6 +856,7 @@ class GradientTape(object):
```
"""
self._pop_tape()
+ self._tape = None
self._push_tape()
def watched_variables(self):
@@ -861,7 +908,9 @@ class GradientTape(object):
for x in nest.flatten(output_gradients)]
flat_grad = imperative_grad.imperative_grad(
- _default_vspace, self._tape, nest.flatten(target), flat_sources,
+ self._tape,
+ nest.flatten(target),
+ flat_sources,
output_gradients=output_gradients)
if not self._persistent:
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 3d3f54b9c4..f938ed5df8 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -23,7 +23,6 @@ import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
-from tensorflow.python.eager import tape
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -65,7 +64,7 @@ class BackpropTest(test.TestCase):
grad = backprop.gradients_function(fn, [0])(var)[0]
grad = self.evaluate(ops.convert_to_tensor(grad))
- with context.graph_mode(), self.test_session():
+ with context.graph_mode():
tf_var = array_ops.constant(var_np, dtypes.float32)
tf_ind1 = array_ops.constant([0, 1])
tf_ind2 = array_ops.constant([2, 3])
@@ -80,14 +79,13 @@ class BackpropTest(test.TestCase):
tf_dense_grad = math_ops.unsorted_segment_sum(
tf_grad.values, tf_grad.indices, tf_grad.dense_shape[0])
- self.assertAllClose(grad, tf_dense_grad.eval())
+ self.assertAllClose(grad, self.evaluate(tf_dense_grad))
def testImplicitGradWithResourceVariable(self):
x = resource_variable_ops.ResourceVariable(
initial_value=constant_op.constant(1.0), name='x')
def fn():
- tape.watch_variable(x)
b = constant_op.constant(2.0)
c = math_ops.add(x.value(), b)
return math_ops.add(c, constant_op.constant(3.0))
@@ -194,14 +192,13 @@ class BackpropTest(test.TestCase):
initial_value=random_init, dtype=dtypes.float32, name='embedding')
def f():
- tape.watch_variable(embedding)
embedded_x = embedding_ops.embedding_lookup(embedding, x)
return constant_op.constant(1.0, dtypes.float32) - embedded_x
grad = backprop.implicit_grad(f)()[0][0]
opt = training.GradientDescentOptimizer(lrn_rate)
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
tf_x = array_ops.ones((batch_size), dtypes.int64)
# TODO(ashankar,apassos): Change to ResourceVariable.
tf_embedding = variables.Variable(
@@ -316,6 +313,24 @@ class BackpropTest(test.TestCase):
grad = backprop.gradients_function(second, [0])(f)[0]
self.assertAllEqual([[0.0]], grad)
+ @test_util.run_in_graph_and_eager_modes
+ def testWatchingIsTapeLocal(self):
+ x1 = resource_variable_ops.ResourceVariable(2.0, trainable=False)
+ x2 = resource_variable_ops.ResourceVariable(2.0, trainable=False)
+
+ with backprop.GradientTape() as tape1:
+ with backprop.GradientTape() as tape2:
+ tape1.watch(x1)
+ tape2.watch([x1, x2])
+ y = x1 ** 3
+ z = x2 ** 2
+ dy, dz = tape2.gradient([y, z], [x1, x2])
+ d2y, d2z = tape1.gradient([dy, dz], [x1, x2])
+
+ self.evaluate([x1.initializer, x2.initializer])
+ self.assertEqual(self.evaluate(d2y), 12.0)
+ self.assertIsNone(d2z)
+
@test_util.assert_no_new_tensors
def testMakeVJP(self):
@@ -404,7 +419,6 @@ class BackpropTest(test.TestCase):
def f():
with context.device('gpu:0'):
- tape.watch_variable(v)
return v.read_value()
self.assertEqual(
@@ -460,6 +474,18 @@ class BackpropTest(test.TestCase):
self.assertEqual(backprop.implicit_grad(f)()[0][0], None)
@test_util.assert_no_new_tensors
+ def testGradientTapeReEnterContext(self):
+ g = backprop.GradientTape()
+ with g:
+ x = constant_op.constant(3.0)
+ g.watch(x)
+ y = 2*x
+ with g:
+ z = 2*y
+ grad = g.gradient(target=z, sources=[x])
+ self.assertEqual(self.evaluate(grad), [4.0])
+
+ @test_util.assert_no_new_tensors
@test_util.run_in_graph_and_eager_modes
def testGradientTapeRepeatedSource(self):
with backprop.GradientTape(persistent=False) as g:
@@ -784,7 +810,6 @@ class BackpropTest(test.TestCase):
initial_value=array_ops.constant([1.0]), name='x')
def fn():
- tape.watch_variable(x)
a = math_ops.add(x.value(), 1.0)
# Make sure convert_to_tensor works correctly with list of TensorNodes.
b = array_ops.stack([a, a], axis=0)
@@ -928,21 +953,75 @@ class BackpropTest(test.TestCase):
def testZerosCacheDoesntLeakAcrossGraphs(self):
with context.graph_mode():
def get_grad():
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
t = constant_op.constant(1, dtype=dtypes.float32, shape=(10, 4))
x = constant_op.constant(2, dtype=dtypes.float32, shape=(10, 4))
- with backprop.GradientTape() as gt:
+ with backprop.GradientTape() as tape:
tape.watch(x)
x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1)
y1 = x1**2
y = array_ops.concat([y1, t], axis=1)
- return self.evaluate(gt.gradient(y, x))
+ return self.evaluate(tape.gradient(y, x))
grad1 = get_grad()
grad2 = get_grad()
self.assertAllEqual(grad1, grad2)
+ @test_util.run_in_graph_and_eager_modes
+ def testSelectivelyWatchVariables(self):
+ x1 = resource_variable_ops.ResourceVariable(1.0)
+ x2 = resource_variable_ops.ResourceVariable(1.0)
+ with backprop.GradientTape(watch_accessed_variables=False) as tape:
+ tape.watch(x2)
+ y = x1**2
+ z = x2**3
+ self.assertTupleEqual(tape.watched_variables(), (x2,))
+ dy, dz = tape.gradient([y, z], [x1, x2])
+ self.evaluate([x1.initializer, x2.initializer])
+ self.assertIsNone(dy)
+ self.assertEqual(self.evaluate(dz), 3.0)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDifferentiatingScalarCache(self):
+    # In the following test, if x2 = x1 (i.e. the objects are exactly the same),
+    # then y is essentially 2*x1, and dy/dx1 = 2.
+ # When we had a pure scalar cache in eager, this would be the case. This
+ # test prevents us from going back to that case.
+ with backprop.GradientTape(persistent=False) as g:
+ x1 = constant_op.constant(3.0)
+ x2 = constant_op.constant(3.0)
+ g.watch(x1)
+ g.watch(x2)
+ y = x1 + x2
+ grad = g.gradient(target=y, sources=[x1])
+ self.assertEqual(self.evaluate(grad), [1.0])
+
+ def testVariablesAndConstantsProduceTheSameGradients(self):
+
+ # In the following test, differentiating [y, z] against [a, b] gives:
+ # (dy/da + dz/da, dy/db + dz/db).
+ # If a and b are the same constant, dz/da will not be 0 (which it should
+ # be).
+    # This is solved by using variables, since calling read_value on a variable
+    # will produce a new tensor and corresponding TensorHandle, and not reuse the
+ # same tensor (which would happen if we are using a cache and reusing
+ # EagerTensor objects).
+ def get_grads(a, b):
+ with backprop.GradientTape() as tape:
+ tape.watch([a, b])
+ y = a**3
+ z = b**2
+ return tape.gradient([y, z], [a, b])
+
+ gradients_constants = get_grads(
+ constant_op.constant(2.0), constant_op.constant(2.0))
+ gradients_variables = get_grads(
+ resource_variable_ops.ResourceVariable(2.0),
+ resource_variable_ops.ResourceVariable(2.0))
+ self.assertAllEqual(gradients_constants, gradients_variables)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index e2b1890c2f..3fe79ef244 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -42,6 +42,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
@@ -175,6 +176,11 @@ class MicroBenchmarks(test.Benchmark):
self._run(func, 30000)
+ def benchmark_create_constant(self):
+ func = lambda: constant_op.constant(3.0)
+
+ self._run(func, 30000)
+
def benchmark_create_float_tensor_from_list_CPU(self):
self._benchmark_create_tensor([[3.0]], dtypes.float32.as_datatype_enum, CPU)
@@ -350,6 +356,21 @@ class MicroBenchmarks(test.Benchmark):
func = lambda: f(m, m, transpose_b)
self._run(func, num_iters, execution_mode=execution_mode)
+ def _benchmark_defun_matmul_forward_backward(self,
+ m,
+ transpose_b,
+ num_iters,
+ execution_mode=None):
+ f = function.defun(math_ops.matmul)
+
+ def func():
+ with backprop.GradientTape() as gt:
+ gt.watch(m)
+ y = f(m, m, transpose_b)
+ _ = gt.gradient(y, m)
+
+ self._run(func, num_iters, execution_mode=execution_mode)
+
def _benchmark_read_variable(self, m, num_iters):
self._run(m.value, num_iters)
@@ -421,6 +442,21 @@ class MicroBenchmarks(test.Benchmark):
num_iters=self._num_iters_2_by_2,
execution_mode=context.ASYNC)
+ def benchmark_defun_matmul_forward_backward_2_by_2_CPU(self):
+ with context.device(CPU):
+ m = self._m_2_by_2.cpu()
+ self._benchmark_defun_matmul_forward_backward(
+ m, transpose_b=False, num_iters=self._num_iters_2_by_2)
+
+ def benchmark_defun_matmul_forward_backward_2_by_2_CPU_async(self):
+ with context.device(CPU):
+ m = self._m_2_by_2.cpu()
+ self._benchmark_defun_matmul_forward_backward(
+ m,
+ transpose_b=False,
+ num_iters=self._num_iters_2_by_2,
+ execution_mode=context.ASYNC)
+
def benchmark_tf_matmul_2_by_2_GPU(self):
if not context.num_gpus():
return
@@ -682,6 +718,25 @@ class MicroBenchmarks(test.Benchmark):
assert np.equal(func(), make_keras_model()(data)).all()
self._run(func, 30000)
+ def benchmarkScan(self):
+ elems = math_ops.range(1600)
+
+ def scan():
+ return functional_ops.scan(
+ lambda a, x: a + x, elems, parallel_iterations=1)
+
+ self._run(scan, 100)
+
+ def benchmarkScanDefun(self):
+ elems = math_ops.range(1600)
+
+ @function.defun
+ def scan():
+ return functional_ops.scan(
+ lambda a, x: a + x, elems, parallel_iterations=1)
+
+ self._run(scan, 100)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 6a327bd010..778ff85342 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -37,7 +37,7 @@ GRAPH_MODE = 0
EAGER_MODE = 1
# Default execution mode.
-_default_mode = GRAPH_MODE
+default_execution_mode = GRAPH_MODE
# Cache from (old_device_name, partial_new_device_name) -> (new_device_name,
# new_device_spec).
@@ -56,14 +56,18 @@ SYNC = 0
ASYNC = 1
-class _TensorCache(object):
+class _EagerTensorCache(object):
"""Simple cache which evicts items based on length in a FIFO manner."""
- def __init__(self, max_items=256):
+ def __init__(self, max_items=256, max_tensor_size=10000):
self._data = collections.OrderedDict()
- self._max_items = max_items if max_items else 256
+ self._max_items = max_items
+ self._max_tensor_size = max_tensor_size
def put(self, key, value):
+ if value._num_elements() > self._max_tensor_size: # pylint: disable=protected-access
+ return
+
self._data[key] = value
if len(self._data) > self._max_items:
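
For context, a minimal standalone sketch of the eviction policy implemented here (names hypothetical; the real class stores EagerTensors and derives the size from `value._num_elements()`):

```python
import collections

class FifoCache(object):
  """Keeps at most max_items entries; evicts the oldest entry first."""

  def __init__(self, max_items=256, max_value_size=10000):
    self._data = collections.OrderedDict()
    self._max_items = max_items
    self._max_value_size = max_value_size

  def put(self, key, value, size):
    if size > self._max_value_size:
      return  # skip values too large to be worth caching
    self._data[key] = value
    if len(self._data) > self._max_items:
      self._data.popitem(last=False)  # FIFO eviction of the oldest entry

  def get(self, key):
    return self._data.get(key, None)
```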
@@ -84,14 +88,14 @@ class _EagerContext(threading.local):
super(_EagerContext, self).__init__()
self.device_spec = pydev.DeviceSpec.from_string("")
self.device_name = self.device_spec.to_string()
- self.mode = _default_mode
- self.is_eager = _default_mode == EAGER_MODE
+ self.mode = default_execution_mode
+ self.is_eager = default_execution_mode == EAGER_MODE
self.scope_name = ""
self.recording_summaries = False
self.summary_writer_resource = None
self.scalar_cache = {}
- self.ones_rank_cache = _TensorCache()
- self.zeros_cache = _TensorCache()
+ self.ones_rank_cache = _EagerTensorCache()
+ self.zeros_cache = _EagerTensorCache()
self.execution_mode = None
@@ -111,8 +115,8 @@ class _ContextSwitchStack(threading.local):
# Initialize the stack with a pointer to enter the eager context; this
# ensures that the fact that eager execution was enabled is propagated
# across threads, since (1) `enable_eager_execution` modifies a
- # process-level flag (`_default_mode`) and (2) `__init__` is called each
- # time a threading.local object is used in a separate thread.
+ # process-level flag (`default_execution_mode`) and (2) `__init__` is
+ # called each time a threading.local object is used in a separate thread.
self.push(is_building_function=False, enter_context_fn=eager_mode)
def push(self, is_building_function, enter_context_fn):
@@ -504,9 +508,7 @@ class Context(object):
Args:
fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
"""
- pywrap_tensorflow.TFE_ContextAddFunction(
- self._handle, # pylint: disable=protected-access
- fn)
+ pywrap_tensorflow.TFE_ContextAddFunction(self._handle, fn)
def add_function_def(self, fdef):
"""Add a function definition to the context.
@@ -519,9 +521,7 @@ class Context(object):
"""
fdef_string = fdef.SerializeToString()
pywrap_tensorflow.TFE_ContextAddFunctionDef(
- self._handle, # pylint: disable=protected-access
- fdef_string,
- len(fdef_string))
+ self._handle, fdef_string, len(fdef_string))
def add_post_execution_callback(self, callback):
"""Add a post-execution callback to the context.
@@ -633,14 +633,7 @@ def context():
def context_safe():
- return _context
-
-
-# TODO(agarwal): remove this.
-def get_default_context():
- """Same as context."""
- if _context is None:
- _initialize_context()
+ """Returns current context (or None if one hasn't been initialized)."""
return _context
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index cc765725a4..fb5442b646 100644
--- a/tensorflow/python/eager/core_test.py
+++ b/tensorflow/python/eager/core_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+import pickle
import threading
import numpy as np
@@ -185,6 +187,17 @@ class TFETest(test_util.TensorFlowTestCase):
device_count={'GPU': 0}))
self.assertEquals(0, ctx.num_gpus())
+ def testPickle(self):
+ tmp_dir = self.get_temp_dir()
+ fname = os.path.join(tmp_dir, 't.pickle')
+ with open(fname, 'wb') as f:
+ t = constant_op.constant(10.0)
+ pickle.dump(t, f)
+
+ with open(fname, 'rb') as f:
+ t = pickle.load(f)
+ self.assertAllEqual(t.numpy(), 10.0)
+
def testTensorPlacement(self):
if not context.context().num_gpus():
self.skipTest('No GPUs found')
@@ -676,5 +689,16 @@ class SendRecvTest(test_util.TensorFlowTestCase):
2.0)
+class EagerTensorCacheTest(test_util.TensorFlowTestCase):
+
+ def testCacheSkipsTensorsTooLarge(self):
+ cache = context._EagerTensorCache(max_items=100, max_tensor_size=3)
+ cache.put('1', array_ops.zeros((2, 2)))
+ self.assertEqual(cache.get('1'), None)
+
+ cache.put('2', array_ops.zeros((2)))
+ self.assertNotEqual(cache.get('2'), None)
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/eager/execution_callbacks.py b/tensorflow/python/eager/execution_callbacks.py
index 9a08259653..80ff4459d6 100644
--- a/tensorflow/python/eager/execution_callbacks.py
+++ b/tensorflow/python/eager/execution_callbacks.py
@@ -146,7 +146,7 @@ def inf_nan_callback(op_type,
"""
del attrs, inputs # Not used.
- ctx = context.get_default_context()
+ ctx = context.context()
for index, output in enumerate(outputs):
if not output.dtype.is_numpy_compatible:
@@ -263,12 +263,12 @@ def add_execution_callback(callback):
Return value(s) from the callback are ignored.
"""
execute.execute = execute.execute_with_callbacks
- context.get_default_context().add_post_execution_callback(callback)
+ context.context().add_post_execution_callback(callback)
def clear_execution_callbacks():
"""Clear all execution callbacks from the default eager context."""
- context.get_default_context().clear_post_execution_callbacks()
+ context.context().clear_post_execution_callbacks()
def seterr(inf_or_nan=None):
@@ -309,7 +309,7 @@ def seterr(inf_or_nan=None):
"Valid actions are %s." % (inf_or_nan, _VALID_CALLBACK_ACTIONS))
old_settings = {"inf_or_nan": "ignore"}
- default_context = context.get_default_context()
+ default_context = context.context()
carryover_callbacks = []
for callback in default_context.post_execution_callbacks:
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 5afba466bc..03f12139f6 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -21,12 +21,12 @@ from __future__ import print_function
import collections
import functools
+import sys
import threading
import numpy as np
import six
-from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
@@ -34,10 +34,12 @@ from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import c_api_util
+from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import cond_v2_impl
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
@@ -49,8 +51,15 @@ from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
+# This is to avoid a circular dependency with cond_v2_impl
+# (function -> gradients_impl -> control_flow_ops -> cond_v2_impl).
+cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-access
-def create_substitute_placeholder(value, name, dtype=None):
+# This is to avoid a circular dependency with gradients_impl
+gradients_impl._function = sys.modules[__name__] # pylint: disable=protected-access
+
+
+def _create_substitute_placeholder(value, name, dtype=None):
"""Creates a placeholder for `value` and propagates shape info to it."""
# Note: setting ops.control_dependencies(None) ensures we always put
# capturing placeholders outside of any control flow context.
@@ -82,82 +91,131 @@ def create_substitute_placeholder(value, name, dtype=None):
return placeholder
-def capture_value(tensor_map, value, dtype, name):
- """Capture a value from outside the function, to pass in as an extra arg."""
- captured_value = tensor_map.get(value, None)
- if captured_value is None:
- captured_value = create_substitute_placeholder(value, name=name,
- dtype=dtype)
- tensor_map[value] = captured_value
- tape.record_operation("captured_value", [captured_value], [value],
- lambda x: [x])
- return captured_value
+def _get_device_functions(ctx, graph):
+ """Returns a tuple of device functions representing the device stack."""
+ if ctx.executing_eagerly():
+ return (pydev.merge_device(ctx.device_name),)
+ else:
+ return tuple(graph._device_functions_outer_to_inner) # pylint: disable=protected-access
-class CapturingGraph(ops.Graph):
- """Graph that can capture tensors from other graphs.
+class FuncGraph(ops.Graph):
+ """Graph representing a function body.
Attributes:
- captures: Maps external tensor -> internal tensor (e.g. input placeholder).
+ name: The name of the function.
+ inputs: Placeholder tensors representing the inputs to this function. The
+ tensors are in this FuncGraph. This represents "regular" inputs as well as
+ captured inputs (i.e. the values of self.captures), with the regular
+ inputs coming first.
+ outputs: Tensors that will be returned by this function. The tensors are in
+ this FuncGraph.
+ structured_outputs: A possibly-nested python object which will be returned
+ by this function. The Tensors in this structure are the same as those of
+ self.outputs. Note that this structure might contain Python `None`s.
+ variables: Variables that should be watched during function execution.
+ outer_graph: The graph this function is defined in. May be another FuncGraph
+ or the global default Graph.
+ captures: Maps external tensor -> internal tensor (i.e. input placeholder).
The entries are in the order they were captured.
+ seed: The graph-level random seed.
"""
- def __init__(self):
- super(CapturingGraph, self).__init__()
+ def __init__(self, name):
+ """Construct a new FuncGraph.
+
+ The graph will inherit its graph key, collections, seed, device stack, and
+ distribution strategy stack from the current context or graph.
+ Args:
+ name: the name of the function.
+ """
+ super(FuncGraph, self).__init__()
+
+ self.name = name
+ self.inputs = []
+ self.outputs = []
+ self.structured_outputs = None
+ self.variables = []
+ self.outer_graph = ops.get_default_graph()
self.captures = collections.OrderedDict()
- self._building_function = True
+ self._building_function = True
# Map from resource tensor name to last op (in program order) which uses
# this tensor. Used to enforce that execution order matches program order
# for resource tensors.
self._last_op_using_resource_tensor = {}
- # TODO(apassos) remove once the C API is used by default.
- def _use_c_api_hack(self):
- return True
-
- def clear_resource_control_flow_state(self):
- self._last_op_using_resource_tensor = {}
+ graph = self.outer_graph
- # TODO(skyewm): get rid of name and use the name of `tensor`.
- def capture(self, tensor, name=None):
- """Capture `tensor` if it's external to this graph.
-
- If `tensor` is from a different graph, returns a placeholder for it.
- `tensor` and the placeholder will also appears in self.captures. Multiple
- calls to this method with the same `tensor` argument will return the same
- placeholder. If `tensor` is from this graph, returns `tensor`.
-
- Args:
- tensor: Tensor. May be from this FuncGraph or a different graph.
- name: Optional name if a placeholder is created.
-
- Returns:
- Tensor from this FuncGraph.
- """
- if isinstance(tensor, ops.EagerTensor):
- if name is None:
- name = str(ops.uid())
- return capture_value(self.captures, tensor, tensor.dtype, name)
- if tensor.graph is not self:
- if name is None:
- name = tensor.op.name
- return capture_value(self.captures, tensor, tensor.dtype, name)
- return tensor
+ if context.executing_eagerly():
+ self.seed = context.global_seed()
+ self._xla_compile = (context.context().device_spec.device_type == "TPU")
+ self._add_device_to_stack(context.context().device_name)
+ else:
+ self.seed = graph.seed
+ self._xla_compile = getattr(graph, "_xla_compile", False)
+ self._device_function_stack = graph._device_function_stack.copy() # pylint: disable=protected-access
+ self._colocation_stack = graph._colocation_stack.copy() # pylint: disable=protected-access
+
+ # TODO(b/112165328, b/112906995): summaries depend on inheriting collections
+ # from the default graph even in eager mode. It'd be nice to not have a
+ # default graph with eager execution, so hopefully this will go away when we
+ # remove collections.
+ # pylint: disable=protected-access
+ self._collections = graph._collections
+ # TODO(b/112906995): distribution strategy depends on inheriting this stack
+ # from the default graph even in eager mode. Maybe it should be part of the
+ # eager context?
+ self._distribution_strategy_stack = graph._distribution_strategy_stack
+ # Inherit the graph key, since this is used for matching variables in
+ # optimizers.
+ self._graph_key = graph._graph_key
+ # pylint: enable=protected-access
def create_op(
self,
op_type,
inputs,
- dtypes, # pylint: disable=redefined-outer-name
+ dtypes,
input_types=None,
name=None,
attrs=None,
op_def=None,
compute_shapes=True,
compute_device=True):
- """Captures an external inputs before calling Graph.capture_op."""
+ """Like Graph.create_op, except handles external input tensors.
+
+ This overload adds functionality to create_op to "capture" any external
+ input tensors, i.e. tensors from the eager context or outer function graphs
+ if this is a nested function. See `capture` for more information.
+
+ Args:
+ op_type: The `Operation` type to create. This corresponds to the
+ `OpDef.name` field for the proto that defines the operation.
+ inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
+ dtypes: A list of `DType` objects that will be the types of the tensors
+ that the operation produces.
+ input_types: (Optional.) A list of `DType`s that will be the types of
+ the tensors that the operation consumes. By default, uses the base
+ `DType` of each input in `inputs`. Operations that expect
+ reference-typed inputs must specify `input_types` explicitly.
+ name: (Optional.) A string name for the operation. If not specified, a
+ name is generated based on `op_type`.
+ attrs: (Optional.) A dictionary where the key is the attribute name (a
+ string) and the value is the respective `attr` attribute of the
+ `NodeDef` proto that will represent the operation (an `AttrValue`
+ proto).
+ op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
+ the operation will have.
+ compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
+ computed).
+ compute_device: (Optional.) If True, device functions will be executed
+ to compute the device property of the Operation.
+
+ Returns:
+ An `Operation` object.
+ """
# This capturing logic interacts poorly with control flow contexts which
# want to replace inputs of ops far too late in the process. This can lead
# the context to get confused and try to create an Enter for an Enter. We
@@ -171,80 +229,61 @@ class CapturingGraph(ops.Graph):
# to capture the inputs.
ctxt = ops.get_default_graph()._control_flow_context # pylint: disable=protected-access
for i, inp in enumerate(inputs):
+ # TPU Estimator defines a control flow context with no AddValue method.
if ctxt is not None and hasattr(ctxt, "AddValue"):
inp = ctxt.AddValue(inp)
inp = self.capture(inp)
inputs[i] = inp
- return super(CapturingGraph, self).create_op(
+ return super(FuncGraph, self).create_op(
op_type, inputs, dtypes, input_types, name, attrs, op_def,
compute_device=compute_device)
+ def capture(self, tensor, name=None):
+ """Captures `tensor` if it's external to this graph.
-class FuncGraph(CapturingGraph):
- """Graph representing a function body.
-
- Attributes:
- name: The name of the function.
-
- inputs: Placeholder tensors representing the inputs to this function. The
- tensors are in this FuncGraph. This represents "regular" inputs as well as
- captured inputs (i.e. the values of self.captures), with the regular
- inputs coming first.
- outputs: Tensors that will be returned by this function. The tensors are in
- this FuncGraph.
- structured_outputs: A possibly-nested python object which will be returned
- by this function. The Tensors in this structure are the same as those of
- self.outputs. Note that this structure might contain Python `None`s.
- variables: Variables that should be watched during function execution.
- seed: The graph-level random seed.
- """
-
- def __init__(self, name, graph=None):
- """Construct a new FuncGraph.
+ If `tensor` is from a different graph, returns a placeholder for it.
+ `tensor` and the placeholder will appear in self.captures, and the
+ placeholder will appear in self.inputs. Multiple calls to this method with
+ the same `tensor` argument will return the same placeholder. If `tensor` is
+ from this graph, returns `tensor`.
Args:
- name: the name of the function.
- graph: if specified, this FuncGraph will inherit its graph key,
- collections, and seed from `graph`.
- """
- super(FuncGraph, self).__init__()
-
- self.name = name
- self.inputs = []
- self.outputs = []
- self.structured_outputs = None
- self.variables = []
-
- if graph is not None:
- # Inherit the graph key, since this is used for matching variables in
- # optimizers.
- self._graph_key = graph._graph_key # pylint: disable=protected-access
-
- # Copy the graph collections to ensure summaries and other things work.
- # This lets the function access (but not mutate) collections of the
- # containing graph, such as the global step and the summary writer
- # collections.
- for collection in graph.collections:
- self.get_collection_ref(collection)[:] = graph.get_collection(
- collection)
-
- # Copy distribution strategy scope from the containing graph as well.
- self._distribution_strategy_stack = graph._distribution_strategy_stack # pylint: disable=protected-access
+ tensor: Tensor. May be from this FuncGraph or a different graph.
+ name: Optional name if a placeholder is created.
- if context.executing_eagerly():
- self.seed = context.global_seed()
- else:
- self.seed = graph.seed
+ Returns:
+ Tensor from this FuncGraph.
+ """
+ if isinstance(tensor, ops.EagerTensor):
+ if name is None:
+ name = str(ops.uid())
+ return self._capture_helper(tensor, name)
+ if tensor.graph is not self:
+ if name is None:
+ name = tensor.op.name
+ return self._capture_helper(tensor, name)
+ return tensor
- def capture(self, tensor, name=None):
- """Calls CapturingGraph.capture and updates self.inputs if necessary."""
- new_capture = tensor not in self.captures
- internal_tensor = super(FuncGraph, self).capture(tensor, name)
+ def _capture_helper(self, tensor, name):
+ captured_tensor = self.captures.get(tensor, None)
+ if captured_tensor is None:
+ captured_tensor = _create_substitute_placeholder(tensor, name=name,
+ dtype=tensor.dtype)
+ self.captures[tensor] = captured_tensor
+ self.inputs.append(captured_tensor)
+ tape.record_operation("captured_value", [captured_tensor], [tensor],
+ lambda x: [x])
+ return captured_tensor
- if new_capture and tensor is not internal_tensor:
- self.inputs.append(internal_tensor)
+ @property
+ def external_captures(self):
+ """External tensors captured by this function."""
+ return list(self.captures.keys())
- return internal_tensor
+ @property
+ def internal_captures(self):
+ """Placeholders in this function corresponding captured tensors."""
+ return list(self.captures.values())
def _forward_name(n):
@@ -267,9 +306,6 @@ def _register(fn):
context.context().add_function(fn)
-_xla_compile_attr = "_XlaCompile"
-
-
# TODO(apassos) get rid of this by splitting framework.function._DefinedFunction
# so it doesn't have the definition-generating logic and is just a container for
# an already-defined function.
@@ -282,18 +318,20 @@ class _EagerDefinedFunction(object):
class may be provided as the value of these `func` attributes.
"""
- def __init__(self, name, graph, operations, inputs, outputs, attrs):
+ def __init__(self, name, graph, inputs, outputs, attrs):
"""Initializes an eager defined function.
Args:
name: str, the name for the created function.
graph: Graph, the graph containing the operations in the function
- operations: list of Operation; the subset of operations in the graph
- which will be in the function
inputs: the tensors in the graph to be used as inputs to the function
outputs: the tensors in the graph which will be outputs to the function
attrs: dict mapping names of attributes to their AttrValue values
"""
+ operations = [
+ op for op in graph.get_operations()
+ if op not in set(arg.op for arg in inputs)
+ ]
fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
graph._c_graph, # pylint: disable=protected-access
compat.as_str(name),
@@ -311,7 +349,6 @@ class _EagerDefinedFunction(object):
# It might be worth creating a convenient way to re-use status.
pywrap_tensorflow.TF_FunctionSetAttrValueProto(
fn, compat.as_str(name), serialized)
- self._xla_compile = _xla_compile_attr in attrs
     # TODO(apassos) avoid creating a FunctionDef (especially to grab the
     # signature, but also in general it's nice not to depend on it).
@@ -327,6 +364,7 @@ class _EagerDefinedFunction(object):
self.signature = function_def.signature
self._num_outputs = len(self.signature.output_arg)
self._output_types = [o.type for o in self.signature.output_arg]
+ self._output_shapes = [o.shape for o in outputs]
self.grad_func_name = None
self.python_grad_func = None
self._c_func = c_api_util.ScopedTFFunction(fn)
@@ -347,7 +385,7 @@ class _EagerDefinedFunction(object):
def stateful_ops(self):
return self._stateful_ops
- def call(self, ctx, args, output_shapes):
+ def call(self, ctx, args):
"""Calls this function with `args` as inputs.
Function execution respects device annotations only if the function won't
@@ -356,8 +394,6 @@ class _EagerDefinedFunction(object):
Args:
ctx: a Context object
args: a list of arguments to supply this function with.
- output_shapes: shapes to which outputs should be set; ignored when
- executing eagerly.
Returns:
The outputs of the function call.
@@ -365,10 +401,7 @@ class _EagerDefinedFunction(object):
executing_eagerly = ctx.executing_eagerly()
- xla_compile = self._xla_compile or (executing_eagerly and
- ctx.device_spec.device_type == "TPU")
-
- if xla_compile:
+ if self._graph._xla_compile: # pylint: disable=protected-access
# XLA compilation relies upon a custom kernel creator to run functions.
signature = self.signature
if executing_eagerly:
@@ -406,7 +439,7 @@ class _EagerDefinedFunction(object):
if executing_eagerly:
return outputs
else:
- for i, shape in enumerate(output_shapes):
+ for i, shape in enumerate(self._output_shapes):
outputs[i].set_shape(shape)
return outputs
@@ -427,179 +460,117 @@ def _flatten(sequence):
return outputs
-# TODO(akshayka): Perhaps rename to something more appropriate.
-class GraphModeFunction(object):
+class Function(object):
"""Callable object encapsulating a function definition and its gradient.
- `GraphModeFunction` is a callable that encapsulates a function definition and
+ `Function` is a callable that encapsulates a function definition and
is differentiable under `tf.GradientTape` objects.
"""
- def __init__(self,
- name,
- input_placeholders,
- extra_inputs,
- graph,
- operations,
- outputs,
- python_func_outputs,
- output_shapes,
- variables=None,
- attrs=None):
- """Initialize a GraphModeFunction.
+ def __init__(self, func_graph, attrs=None):
+ """Initialize a Function.
Args:
- name: str the name of the created function
- input_placeholders: list of placeholder values (tensors) to feed when
- calling the wrapped function.
- extra_inputs: Tensor inputs this function definition closed over which
- are passed as arguments. Need to track so gradients are supported
- correctly.
- graph: the Graph from which the operations will be pulled. Used as
- a context when computing gradients.
- operations: the subset of Operations in the graph used in the function
- definition.
- outputs: a flat list of the Tensors in the graph used as outputs to the
- function
- python_func_outputs: a possibly nested python object which will be
- returned by this function. The Tensors in this structure will be
- replaced by their corresponding values in outputs. Note that this
- structure might contain Python `None`s.
- output_shapes: List of shapes of all tensors in outputs
- variables: (optional) List of variables to watch during function
- execution.
+ func_graph: An instance of FuncGraph: the function body to wrap.
attrs: (optional) dict mapping names of attributes to their AttrValue
values. Attributes in `attrs` will be included in this function's
definition.
+
+ Raises:
+ ValueError: If number of input_placeholders is not equal to the number
+ of function inputs.
"""
+ self._func_graph = func_graph
+ self._captured_inputs = list(self._func_graph.captures.keys())
+ self._num_outputs = len(self._func_graph.outputs)
+ self._output_shapes = tuple(
+ output.shape for output in self._func_graph.outputs)
self._attrs = attrs or {}
- defined_function = _EagerDefinedFunction(
- name, graph, operations, input_placeholders, outputs, self._attrs)
- if len(input_placeholders) != len(defined_function.signature.input_arg):
- raise ValueError("Internal error: invalid lengths. %s %s" % (
- len(input_placeholders), len(defined_function.signature.input_arg)))
- self._input_placeholders = input_placeholders
- self._extra_inputs = list(extra_inputs)
- self._graph = graph
- self._backward_function = None
- self._func_name = name
- self._function_def = defined_function
- self._num_outputs = len(defined_function.signature.output_arg)
- self._python_func_outputs = python_func_outputs
- self._python_returns = [python_func_outputs] if isinstance(
- python_func_outputs,
- (ops.Tensor, type(None))) else _flatten(python_func_outputs)
- self._output_shapes = output_shapes
- self._variables = variables if variables is not None else []
-
- # Find the variables that are components of something distributed and
- # put them into a {handle_tensor -> distributed variable object} map.
+ self._device_functions = tuple(
+ self._func_graph._device_functions_outer_to_inner) # pylint: disable=protected-access
+
+ self._inference_function = _EagerDefinedFunction(
+ _inference_name(self._func_graph.name), self._func_graph,
+ self._func_graph.inputs, self._func_graph.outputs, self._attrs)
+ self._backward_graph_function = None
+
+ # Map holding distributed variables, keyed by resource handle tensors.
self._distributed_variables = {}
strategy = distribution_strategy_context.get_distribution_strategy()
- for variable in self._variables:
+ for variable in self._func_graph.variables:
# If variable is not distributed, unwrap returns [variable].
component_variables = strategy.unwrap(variable)
- # Only add to the dictionary when the variable is actually distributed,
- # i.e. more than one component or the component is different from the
- # variable itself. component_variables cannot be empty.
+ # Only update the dictionary when the variable is actually distributed.
if (len(component_variables) > 1 or component_variables[0] != variable):
for component_variable in component_variables:
self._distributed_variables[component_variable.handle] = variable
- @property
- def variables(self):
- return self._variables
+ def __call__(self, *args):
+ """Executes the wrapped function."""
+ ctx = context.context()
+ device_functions = _get_device_functions(ctx, ops.get_default_graph())
+ if device_functions != self._device_functions:
+ raise ValueError(
+ "The current device stack does not match the device stack under "
+ "which the TensorFlow function '%s' was created.\n"
+ "Current device stack: %s\n%s device stack: %s" %
+ (self._inference_function.name, device_functions,
+ self._inference_function.name, self._device_functions))
+
+ for v in self._func_graph.variables:
+ if v.trainable:
+ tape.variable_accessed(v)
- def _construct_backprop_function(self):
- """Constructs the backprop function object for this function."""
- filtered_outputs = [x for x in self._python_returns if x is not None]
- # TODO(skyewm): use FuncGraph
- backwards_graph = CapturingGraph()
- backwards_graph._graph_key = self._graph._graph_key # pylint: disable=protected-access
- for collection in self._graph.collections:
- backwards_graph.get_collection_ref(
- collection)[:] = self._graph.get_collection(collection)
- backwards_graph.seed = self._graph.seed
- with backwards_graph.as_default():
- self._out_grad_placeholders = [
- graph_placeholder(x.dtype, x.shape) for x in filtered_outputs]
- in_gradients = gradients_impl._GradientsHelper( # pylint: disable=protected-access
- filtered_outputs,
- self._input_placeholders,
- grad_ys=self._out_grad_placeholders,
- src_graph=self._graph)
-
- backward_outputs = tuple(
- grad for grad in _flatten(in_gradients) if grad is not None)
- output_shapes = tuple(grad.shape for grad in backward_outputs)
-
- extra_inputs = backwards_graph.captures.keys()
- extra_placeholders = backwards_graph.captures.values()
-
- forward_name = _forward_name(self._func_name)
- # Note: we cannot have placeholder ops in the graph or the TPU compilation
- # pass fails.
- placeholder_ops = set([y.op for y in self._input_placeholders])
- function_ops = [x for x in self._graph.get_operations()
- if x not in placeholder_ops]
- self._forward_fdef = _EagerDefinedFunction(
- forward_name, self._graph, function_ops,
- self._input_placeholders, filtered_outputs + list(extra_inputs),
- self._attrs)
- all_inputs = self._out_grad_placeholders + list(extra_placeholders)
- # Excluding input ops from the body as we do not intend to execute these
- # operations when the function is executed.
- all_ignored_ops = frozenset(x.op for x in all_inputs)
- # Enforce a deterministic order of operations in the generated graph. This
- # means rerunning the function-defining code will always define the same
- # function, which is useful if we serialize this etc.
- function_def_ops = tuple(x
- for x in sorted(backwards_graph.get_operations(),
- key=lambda x: x.name)
- if x not in all_ignored_ops)
- bname = _backward_name(self._func_name)
- self._backward_function = GraphModeFunction(
- bname, all_inputs, [], backwards_graph, function_def_ops,
- backward_outputs, in_gradients, output_shapes, attrs=self._attrs)
+ captures = self._resolve_captured_inputs()
+ tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
+ args = tensor_inputs + captures
- def _backprop_call(self, args):
- """Calls the wrapped function and records the result on a tape.
+ if tape.should_record(tensor_inputs) or tape.should_record(captures):
+ return self._backprop_call(args)
- (Only records results on a tape if the function has outputs)
+ outputs = self._inference_function.call(ctx, args)
+ return self._build_call_outputs(outputs)
- Args:
- args: All inputs to the function, including resolved extra inputs
- Returns:
- The call output.
- """
- ctx = context.context()
- outputs = self._forward_fdef.call(ctx, args, self._output_shapes)
- if isinstance(outputs, ops.Operation) or outputs is None:
- return outputs
+ @property
+ def graph(self):
+ """Returns the graph from which this function was constructed."""
+ return self._func_graph
- # `real_outputs` are the actual outputs of the inference graph function;
- # `side_outputs` are the intermediate Tensors that were added as outputs to
- # the forward graph function so that we can compute its gradient.
- real_outputs = outputs[:self._num_outputs]
- side_outputs = outputs[self._num_outputs:]
+ @property
+ def variables(self):
+ """Returns all variables touched by this function."""
+ return self._func_graph.variables
- def backward_function(*args):
- return self._backward_function(*(list(args) + side_outputs)) # pylint: disable=not-callable
+ @property
+ def inputs(self):
+ """Returns tensors in `self.graph` corresponding to arguments."""
+ return self._func_graph.inputs
- tape.record_operation(
- self._forward_fdef.signature.name,
- real_outputs,
- args,
- backward_function)
+ @property
+ def outputs(self):
+ """Returns tensors in `self.graph` corresponding to return values."""
+ return self._func_graph.outputs
- return self._build_call_outputs(real_outputs)
+ @property
+ def captured_inputs(self):
+ """Returns external Tensors captured by this function.
+
+ self.__call__(*args) passes `args + self.captured_inputs` to the function.
+ """
+ return self._captured_inputs
+
+ @property
+ def function_def(self):
+ """Returns a `FunctionDef` object representing this function."""
+ return self._inference_function.definition
@property
def output_shapes(self):
"""The function's output shapes."""
# TODO(ebrevdo): Should we only keep the output shapes associated
# with len(self._python_returns) outputs?
- outputs_list = nest.flatten(self._python_func_outputs)
+ # TODO(akshayka): Consider removing this.
+ outputs_list = nest.flatten(self._func_graph.structured_outputs)
j = 0
for i, o in enumerate(outputs_list):
if o is not None:
@@ -613,23 +584,80 @@ class GraphModeFunction(object):
else:
outputs_list[i] = self._output_shapes[j]
j += 1
- return nest.pack_sequence_as(self._python_func_outputs, outputs_list)
+ return nest.pack_sequence_as(self._func_graph.structured_outputs,
+ outputs_list)
@property
def output_dtypes(self):
- return nest.map_structure(
- lambda x: x.dtype if x is not None else None, self._python_func_outputs)
+ # TODO(akshayka): Consider removing this.
+ return nest.map_structure(lambda x: x.dtype if x is not None else None,
+ self._func_graph.structured_outputs)
- @property
- def captured_inputs(self):
- return self._extra_inputs
+ def _construct_backprop_function(self):
+ """Constructs the backprop function object for this function."""
+ backwards_graph = FuncGraph(_backward_name(self._func_graph.name))
+ with backwards_graph.as_default():
+ gradients_wrt_outputs = [
+ graph_placeholder(x.dtype, x.shape) for x in self._func_graph.outputs
+ ]
+ gradients_wrt_inputs = gradients_impl._GradientsHelper( # pylint: disable=protected-access
+ self._func_graph.outputs,
+ self._func_graph.inputs,
+ grad_ys=gradients_wrt_outputs,
+ src_graph=self._func_graph)
+
+ self._forward_function = _EagerDefinedFunction(
+ _forward_name(
+ self._func_graph.name), self._func_graph, self._func_graph.inputs,
+ self._func_graph.outputs + list(backwards_graph.captures.keys()),
+ self._attrs)
- @property
- def name(self):
- """Returns the name of the function in Eager-compatible format."""
- return self._function_def.name.encode("utf-8")
+ # The ordering of `backwards_graph.inputs` is important: inputs of
+ # `self._backward_graph_function` correspond to outputs of
+ # `self._forward_function`.
+ backwards_graph.inputs = gradients_wrt_outputs + list(
+ backwards_graph.captures.values())
+ # Clear captures, since we pass them in as inputs.
+ backwards_graph.captures = {}
+ backwards_graph.outputs.extend(
+ grad for grad in _flatten(gradients_wrt_inputs) if grad is not None)
+ backwards_graph.structured_outputs = gradients_wrt_inputs
+ self._backward_graph_function = Function(
+ backwards_graph, attrs=self._attrs)
+
+ def _backprop_call(self, args):
+ """Calls the forward function and records the result on a tape.
+
+ (Only records results on a tape if the function has outputs)
- def _resolve_extra_inputs(self):
+ Args:
+ args: All inputs to the function, including resolved captured inputs
+
+ Returns:
+ The call output.
+ """
+ if self._backward_graph_function is None:
+ self._construct_backprop_function()
+
+ ctx = context.context()
+ outputs = self._forward_function.call(ctx, args)
+ if isinstance(outputs, ops.Operation) or outputs is None:
+ return outputs
+
+ # `real_outputs` are the actual outputs of the inference graph function;
+ # `side_outputs` are the intermediate Tensors that were added as outputs to
+ # the forward graph function so that we can compute its gradient.
+ real_outputs = outputs[:self._num_outputs]
+ side_outputs = outputs[self._num_outputs:]
+
+ def backward_function(*args):
+ return self._backward_graph_function(*(list(args) + side_outputs)) # pylint: disable=not-callable
+
+ tape.record_operation(self._forward_function.signature.name, real_outputs,
+ args, backward_function)
+ return self._build_call_outputs(real_outputs)
+
+ def _resolve_captured_inputs(self):
"""Resolve captured distributed variables to their current values.
Some inputs can be distributed variables. Such variables yield a different
@@ -637,44 +665,23 @@ class GraphModeFunction(object):
execution.
Returns:
- a list of resolved extra input tensors.
+ a list of resolved captured input tensors.
"""
if self._distributed_variables:
- # Loop over each extra_inputs and check if it corresponds to something
+ # Loop over each captured input and check if it corresponds to something
# distributed. If so, get its _distributed_container and fetch the
# component appropriate for the current execution context.
- resolved_extra_inputs = self._extra_inputs[:]
- for i, extra_input in enumerate(self._extra_inputs):
- distributed_var = self._distributed_variables.get(extra_input, None)
+ resolved_captured_inputs = self._captured_inputs[:]
+ for i, captured_input in enumerate(self._captured_inputs):
+ distributed_var = self._distributed_variables.get(captured_input, None)
if distributed_var is not None:
# distributed variables override __getattr__ and substitute the
# right component variable. In here, `distributed_var.handle`
# actually does the equivalent of
# distributed_var.get_current_component_var().handle.
- resolved_extra_inputs[i] = distributed_var.handle
- return resolved_extra_inputs
-
- return self._extra_inputs
-
- def __call__(self, *args):
- """Executes the passed function in eager mode."""
- for v in self._variables:
- if v.trainable:
- tape.watch_variable(v)
-
- resolved_extra_inputs = self._resolve_extra_inputs()
-
- tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
- args = tensor_inputs + resolved_extra_inputs
- if tape.should_record(tensor_inputs) or tape.should_record(
- resolved_extra_inputs):
- if self._backward_function is None:
- self._construct_backprop_function()
- return self._backprop_call(args)
-
- ctx = context.context()
- outputs = self._function_def.call(ctx, args, self._output_shapes)
- return self._build_call_outputs(outputs)
+ resolved_captured_inputs[i] = distributed_var.handle
+ return resolved_captured_inputs
+ return self._captured_inputs
def _build_call_outputs(self, result):
"""Maps the fdef output list to actual output structure.
@@ -684,12 +691,12 @@ class GraphModeFunction(object):
Returns:
The actual call output.
"""
- if self._python_func_outputs is None:
+ if self._func_graph.structured_outputs is None:
return result
# Use `nest.flatten` instead of `_flatten` in order to preserve any
- # IndexedSlices in `self._python_func_outputs`.
- outputs_list = nest.flatten(self._python_func_outputs)
+ # IndexedSlices in `self._func_graph.structured_outputs`.
+ outputs_list = nest.flatten(self._func_graph.structured_outputs)
j = 0
for i, o in enumerate(outputs_list):
if o is not None:
@@ -703,13 +710,13 @@ class GraphModeFunction(object):
j += 3
else:
outputs_list[i] = ops.IndexedSlices(
- values=result[j],
- indices=result[j + 1])
+ values=result[j], indices=result[j + 1])
j += 2
else:
outputs_list[i] = result[j]
j += 1
- ret = nest.pack_sequence_as(self._python_func_outputs, outputs_list)
+ ret = nest.pack_sequence_as(self._func_graph.structured_outputs,
+ outputs_list)
return ret
@@ -725,20 +732,18 @@ def _get_defun_inputs_from_signature(signature):
def _get_defun_inputs_from_args(args):
"""Maps python function args to graph-construction inputs."""
function_inputs = [
- graph_placeholder(arg.dtype, arg.shape) if isinstance(arg, ops.Tensor)
- else arg for arg in nest.flatten(args)
+ graph_placeholder(arg.dtype, arg.shape)
+ if isinstance(arg, ops.Tensor) else arg for arg in nest.flatten(args)
]
return nest.pack_sequence_as(args, function_inputs)
-def _trace_and_define_function(name, python_func, compiled, args, kwds,
- signature=None):
- """Defines and returns graph-mode version of `python_func`.
+def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
+ """Returns a `FuncGraph` generated from `python_func`.
Args:
name: an identifier for the function.
python_func: the Python function to trace.
- compiled: whether the graph function should be compiled through XLA.
args: the positional args with which the Python function should be called;
ignored if a signature is provided.
kwds: the keyword args with which the Python function should be called;
@@ -750,14 +755,13 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds,
inputs.
Returns:
- A GraphModeFunction.
+ A FuncGraph.
Raises:
TypeError: If any of `python_func`'s return values is neither `None` nor a
`Tensor`.
"""
- func_graph = FuncGraph(_inference_name(name), graph=ops.get_default_graph())
-
+ func_graph = FuncGraph(name)
with func_graph.as_default(), AutomaticControlDependencies() as a:
variable_scope.get_variable_scope().set_use_resource(True)
@@ -771,8 +775,7 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds,
# Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
func_graph.inputs.extend(
x for x in nest.flatten(func_args) + nest.flatten(func_kwds)
- if isinstance(x, ops.Tensor)
- )
+ if isinstance(x, ops.Tensor))
# Variables to help check whether mutation happens in calling the function
# Copy the recursive list, tuple and map structure, but not base objects
@@ -797,6 +800,7 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds,
this_tape = tape.push_new_tape()
try:
func_outputs = python_func(*func_args, **func_kwds)
+ # invariant: `func_outputs` contains only Tensors and `None`s.
func_outputs = nest.map_structure(convert, func_outputs)
def check_mutation(n1, n2):
@@ -816,53 +820,34 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds,
check_mutation(func_args_before, func_args)
check_mutation(func_kwds_before, func_kwds)
-
finally:
tape.pop_tape(this_tape)
+
func_graph.structured_outputs = func_outputs
+ # Returning a closed-over tensor does not trigger convert_to_tensor.
+ func_graph.outputs.extend(
+ func_graph.capture(x)
+ for x in _flatten(func_graph.structured_outputs)
+ if x is not None)
+
+ # Some captured variables might be components of DistributedValues.
+ # Instead of storing non-distributed component variables, we
+ # store their distributed containers so we can retrieve the correct
+ # component variables at call-time.
variables = list(this_tape.watched_variables())
-
- # Some variables captured by the tape can come from a DistributedValue.
- # At call time, DistributedValue can return another variable (e.g. if
- # the function is run on a different device). Thus, instead of storing
- # the specific captured variable, we replace it with its distributed
- # container.
strategy = distribution_strategy_context.get_distribution_strategy()
for i, variable in enumerate(variables):
# If variable is not distributed value_container returns itself.
variables[i] = strategy.value_container(variable)
-
func_graph.variables = variables
- # Returning a closed-over tensor as an output does not trigger a
- # call to convert_to_tensor, so we manually capture all such tensors.
- func_graph.outputs.extend(
- func_graph.capture(x) for x in _flatten(func_graph.structured_outputs)
- if x is not None
- )
-
- output_shapes = tuple(
- x.shape if isinstance(x, ops.Tensor) else None
- for x in func_graph.outputs)
-
- all_ignored_ops = frozenset(x.op for x in func_graph.inputs)
- operations = tuple(x for x in func_graph.get_operations()
- if x not in all_ignored_ops)
- # Register any other functions defined in the graph
- # TODO(ashankar): Oh lord, forgive me for this lint travesty.
+ # Register any other functions defined in the graph.
if context.executing_eagerly():
for f in func_graph._functions.values(): # pylint: disable=protected-access
# TODO(ashankar): What about the gradient registry?
_register(f._c_func.func) # pylint: disable=protected-access
- attrs = {}
- if compiled:
- attrs[_xla_compile_attr] = attr_value_pb2.AttrValue(b=True)
-
- return GraphModeFunction(
- func_graph.name, func_graph.inputs, func_graph.captures.keys(),
- func_graph, operations, func_graph.outputs, func_graph.structured_outputs,
- output_shapes, func_graph.variables, attrs)
+ return func_graph
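The new `func_graph_from_py_func` works by symbolic tracing: the Python function is called once with graph placeholders, and every operation it performs is recorded into the `FuncGraph`. Below is a minimal, framework-free sketch of that idea; every name in it (`SymbolicTensor`, `TraceGraph`, `trace`) is illustrative, not a TensorFlow API.

```python
class SymbolicTensor(object):
  """Stands in for a graph placeholder while tracing."""

  def __init__(self, graph, name):
    self.graph = graph
    self.name = name

  def __add__(self, other):
    return self.graph.record("add", self, other)


class TraceGraph(object):
  """Records the operations a Python function performs on its inputs."""

  def __init__(self, name):
    self.name = name
    self.ops = []
    self.inputs = []
    self.outputs = []

  def placeholder(self, name):
    t = SymbolicTensor(self, name)
    self.inputs.append(t)
    return t

  def record(self, op_name, *args):
    out = SymbolicTensor(self, "%s_%d" % (op_name, len(self.ops)))
    self.ops.append((op_name, args, out))
    return out


def trace(name, python_func, num_args):
  g = TraceGraph(name)
  args = [g.placeholder("arg%d" % i) for i in range(num_args)]
  g.outputs = [python_func(*args)]  # A single call defines the whole graph.
  return g


g = trace("add_fn", lambda x, y: x + y, 2)
assert [op for op, _, _ in g.ops] == ["add"]
```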
_TensorType = collections.namedtuple("_TensorType", ["dtype", "shape"])
@@ -911,13 +896,13 @@ def _deterministic_dict_values(dictionary):
return tuple(dictionary[key] for key in sorted(dictionary))
-class _PolymorphicFunction(object):
+class PolymorphicFunction(object):
"""Wrapper class for the graph functions defined for a Python function.
See the documentation for `defun` for more information on the semantics of
defined functions.
- _PolymorphicFunction class is thread-compatible meaning that minimal
+ The PolymorphicFunction class is thread-compatible, meaning that minimal
usage of defuns (defining and calling) is thread-safe, but if users call other
methods or invoke the base `python_function` themselves, external
synchronization is necessary.
@@ -926,8 +911,7 @@ class _PolymorphicFunction(object):
def __init__(self,
python_function,
name,
- input_signature=None,
- compiled=False):
+ input_signature=None):
"""Initializes a polymorphic function.
Args:
@@ -936,14 +920,10 @@ class _PolymorphicFunction(object):
input_signature: a possibly nested sequence of `TensorSpec` objects
specifying the input signature of this function. If `None`, a separate
function is instantiated for each inferred input signature.
- compiled: if True, the framework will attempt to compile func with XLA.
Raises:
ValueError: if `input_signature` is not None and the `python_function`'s
argspec has keyword arguments.
- TypeError: if `input_signature` contains anything other than
- `TensorSpec` objects, or (if not None) is anything other than a tuple or
- list.
"""
if isinstance(python_function, functools.partial):
@@ -955,8 +935,7 @@ class _PolymorphicFunction(object):
self._args_to_prepend = tuple()
self._kwds_to_include = {}
self._name = name
- self._compiled = compiled
- self._arguments_to_functions = {}
+ self._function_cache = collections.OrderedDict()
self._variables = []
self._lock = threading.Lock()
@@ -991,15 +970,40 @@ class _PolymorphicFunction(object):
self._input_signature = tuple(input_signature)
self._flat_input_signature = tuple(nest.flatten(input_signature))
- if any(not isinstance(arg, tensor_spec.TensorSpec)
- for arg in self._flat_input_signature):
- raise TypeError("Invalid input_signature %s; input_signature must be "
- "a possibly nested sequence of TensorSpec objects.")
+
+ def __call__(self, *args, **kwds):
+ """Calls a graph function specialized to the inputs."""
+ graph_function, inputs = self._maybe_define_function(*args, **kwds)
+ return graph_function(*inputs)
+
+ @property
+ def python_function(self):
+ """Returns the wrapped Python function."""
+ return self._python_function
+
+ # TODO(akshayka): Remove this property.
+ @property
+ def variables(self):
+ """Returns the union of all variables referenced by cached `Function`s`."""
+ return self._variables
+
+ def get_concrete_function(self, *args, **kwargs):
+ """Returns a `Function` object specialized to inputs and execution context.
+
+ `args` and `kwargs` are ignored if this `PolymorphicFunction` was created
+ with an `input_signature`.
+
+ Args:
+ *args: inputs to specialize on.
+ **kwargs: inputs to specialize on.
+ """
+ graph_function, _ = self._maybe_define_function(*args, **kwargs)
+ return graph_function
def __get__(self, instance, owner):
"""Makes it possible to defun instance methods."""
del owner
- # `instance` here is the instance that this `_PolymorphicFunction` was
+ # `instance` here is the instance that this `PolymorphicFunction` was
# accessed through; e.g., for
#
# class Foo(object):
@@ -1009,29 +1013,42 @@ class _PolymorphicFunction(object):
# ...
#
# foo = Foo()
- # foo.bar() # `foo.bar` is a `_PolymorphicFunction` instance
+ # foo.bar() # `foo.bar` is a `PolymorphicFunction` instance
#
# then `instance` will be `foo` (and `owner` will be `Foo`).
return functools.partial(self.__call__, instance)
- def _cache_key(self, args, kwds):
- """Computes the cache key given inputs."""
+ def _cache_key(self, args, kwds, ctx, graph):
+ """Computes the cache key given inputs and execution context."""
if self._input_signature is None:
inputs = (args, kwds) if kwds else args
cache_key = tuple(_encode_arg(arg) for arg in inputs)
else:
del args, kwds
cache_key = self._flat_input_signature
+
# The graph, or whether we're executing eagerly, should be a part of the
# cache key so we don't improperly capture tensors such as variables.
- return cache_key + (context.executing_eagerly() or ops.get_default_graph(),)
+ executing_eagerly = ctx.executing_eagerly()
+ execution_context = executing_eagerly or graph
+
+ # Putting the device in the cache key ensures that call-site device
+ # annotations are respected.
+ device_functions = _get_device_functions(ctx, graph)
+
+ # `ops.colocate_with` directives translate into `ops.device` directives when
+ # eager execution is enabled.
+ colocation_stack = (None if executing_eagerly else
+ tuple(graph._colocation_stack.peek_objs())) # pylint: disable=protected-access
+
+ return cache_key + (execution_context, device_functions, colocation_stack)
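A standalone sketch of how `_cache_key` assembles its result from the pieces above; `sketch_cache_key` and its arguments are illustrative stand-ins, not TensorFlow code.

```python
def sketch_cache_key(encoded_args, executing_eagerly, graph,
                     device_functions, colocation_stack):
  # Eagerness and graph identity are mutually exclusive context markers.
  execution_context = executing_eagerly or graph
  # Colocation directives only matter when building a graph.
  colocation = None if executing_eagerly else colocation_stack
  return encoded_args + (execution_context, device_functions, colocation)


# Identical arguments under different device stacks produce distinct keys,
# so each call-site device gets its own traced function.
key_cpu0 = sketch_cache_key((1, 2), True, None, ('/cpu:0',), ())
key_cpu3 = sketch_cache_key((1, 2), True, None, ('/cpu:3',), ())
assert key_cpu0 != key_cpu3
```

Keying on the device stack is what enables the `testDeviceAnnotationsRespected` change later in this diff, where each device scope yields its own cache entry.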
def _canonicalize_function_inputs(self, *args, **kwds):
"""Canonicalizes `args` and `kwds`.
Canonicalize the inputs to the Python function using its fullargspec. In
particular, we parse the varargs and kwargs that this
- `_PolymorphicFunction` was called with into a tuple corresponding to the
+ `PolymorphicFunction` was called with into a tuple corresponding to the
Python function's positional (named) arguments and a dictionary
corresponding to its kwargs.
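A self-contained sketch of that canonicalization, assuming only the behavior described above rather than TensorFlow's actual implementation; `canonicalize` and `func` are hypothetical names.

```python
import inspect


def canonicalize(func, *args, **kwds):
  spec = inspect.getfullargspec(func)
  defaults = spec.defaults or ()
  default_map = dict(zip(spec.args[len(spec.args) - len(defaults):], defaults))
  bound = list(args)
  for name in spec.args[len(args):]:
    bound.append(kwds.pop(name) if name in kwds else default_map[name])
  return tuple(bound), kwds  # Positional tuple plus any leftover kwargs.


def func(a, bar=1, baz=2):
  return a


# Positional, keyword, and default-filled calls collapse to one key,
# so they all reuse the same traced function.
assert canonicalize(func, 0, baz=20) == ((0, 1, 20), {})
assert canonicalize(func, 1, 2, 3) == canonicalize(func, 1, baz=3, bar=2)
```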
@@ -1085,8 +1102,9 @@ class _PolymorphicFunction(object):
if any(not isinstance(arg, ops.Tensor) for arg in flat_inputs):
raise ValueError("When input_signature is provided, all inputs to "
"the Python function must be Tensors.")
- tensor_specs = [tensor_spec.TensorSpec.from_tensor(tensor)
- for tensor in flat_inputs]
+ tensor_specs = [
+ tensor_spec.TensorSpec.from_tensor(tensor) for tensor in flat_inputs
+ ]
if any(not spec.is_compatible_with(other)
for spec, other in zip(self._flat_input_signature, tensor_specs)):
raise ValueError("Python inputs incompatible with input_signature: "
@@ -1111,42 +1129,33 @@ class _PolymorphicFunction(object):
"""
args, kwds = self._canonicalize_function_inputs(*args, **kwds)
- cache_key = self._cache_key(args, kwds)
+ cache_key = self._cache_key(args, kwds, context.context(),
+ ops.get_default_graph())
with self._lock:
try:
- graph_function = self._arguments_to_functions.get(cache_key, None)
+ graph_function = self._function_cache.get(cache_key, None)
except TypeError:
raise TypeError("Arguments supplied to `defun`-generated functions "
"must be hashable.")
if graph_function is None:
- graph_function = _trace_and_define_function(
- self._name, self._python_function, self._compiled, args, kwds,
- self._input_signature)
+ graph_function = Function(
+ func_graph_from_py_func(self._name, self._python_function, args,
+ kwds, self._input_signature))
self._variables.extend(
[v for v in graph_function.variables if v not in self._variables])
- self._arguments_to_functions[cache_key] = graph_function
+ self._function_cache[cache_key] = graph_function
return graph_function, (args, kwds)
- def __call__(self, *args, **kwds):
- """Calls a graph function specialized for this input signature."""
- graph_function, inputs = self._maybe_define_function(*args, **kwds)
- return graph_function(*inputs)
-
- def call_python_function(self, *args, **kwargs):
- """Directly calls the wrapped python function."""
- return self._python_function(*args, **kwargs)
- @property
- def variables(self):
- """Returns a list of variables used in any of the defined functions."""
- return self._variables
+def _validate_signature(signature):
+ if any(not isinstance(arg, tensor_spec.TensorSpec)
+ for arg in nest.flatten(signature)):
+ raise TypeError("Invalid input_signature %s; input_signature must be "
+ "a possibly nested sequence of TensorSpec objects.")
-# TODO(akshayka): Remove the `compiled` flag and create a separate
-# API for xla compilation (`defun` is already complicated enough
-# as it is, and the keyword argument makes 'compiled' an overloaded concept)
-def defun(func=None, input_signature=None, compiled=False):
+def defun(func=None, input_signature=None):
"""Compiles a Python function into a callable TensorFlow graph.
`defun` (short for "define function") trace-compiles a Python function
@@ -1221,6 +1230,7 @@ def defun(func=None, input_signature=None, compiled=False):
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
self.keep_probability = keep_probability
+ @tf.contrib.eager.defun
def call(self, inputs, training=True):
x = self.dense2(self.dense1(inputs))
if training:
@@ -1229,7 +1239,6 @@ def defun(func=None, input_signature=None, compiled=False):
return x
model = MyModel()
- model.call = tf.contrib.eager.defun(model.call)
model(x, training=True) # executes a graph, with dropout
model(x, training=False) # executes a graph, without dropout
@@ -1437,9 +1446,10 @@ def defun(func=None, input_signature=None, compiled=False):
func: function to be compiled. If `func` is None, returns a
decorator that can be invoked with a single argument - `func`. The
end result is equivalent to providing all the arguments up front.
- In other words, defun(compiled=True)(func) is equivalent to
- defun(func, compiled=True). The former allows the following use case:
- @tf.contrib.eager.defun(compiled=True)
+ In other words, defun(input_signature=...)(func) is equivalent to
+ defun(func, input_signature=...). The former allows
+ the following use case:
+ @tf.contrib.eager.defun(input_signature=...)
def foo(...):
...
@@ -1450,17 +1460,20 @@ def defun(func=None, input_signature=None, compiled=False):
signature is specified, every input to `func` must be a `Tensor`, and
`func` cannot accept `**kwargs`.
- compiled: If True, an attempt to compile `func` with XLA will be made.
- If it fails, function will be run normally. Experimental. Currently
- supported only for execution on TPUs. For the vast majority of users,
- this argument should be False.
-
Returns:
If `func` is not None, returns a callable that will execute the compiled
function (and return zero or more `tf.Tensor` objects).
If `func` is None, returns a decorator that, when invoked with a single
`func` argument, returns a callable equivalent to the case above.
+
+ Raises:
+ TypeError: If `input_signature` is neither `None` nor a sequence of
+ `tf.contrib.eager.TensorSpec` objects.
"""
+
+ if input_signature is not None:
+ _validate_signature(input_signature)
+
# TODO(apassos): deal with captured global state. Deal with control flow.
def decorated(function):
try:
@@ -1469,8 +1482,7 @@ def defun(func=None, input_signature=None, compiled=False):
name = "function"
return tf_decorator.make_decorator(
function,
- _PolymorphicFunction(
- function, name, input_signature=input_signature, compiled=compiled))
+ PolymorphicFunction(function, name, input_signature=input_signature))
# This code path is for the `foo = tfe.defun(foo, ...)` use case
if func is not None:
@@ -1486,51 +1498,6 @@ def defun(func=None, input_signature=None, compiled=False):
return decorated
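A sketch of the two equivalent invocation forms described in the docstring above, assuming the 2018-era `tf.contrib.eager` API; `add_one` and `also_add_one` are illustrative names.

```python
import tensorflow as tf

tf.enable_eager_execution()
tfe = tf.contrib.eager

signature = [tfe.TensorSpec(shape=[None], dtype=tf.float32)]

@tfe.defun(input_signature=signature)  # Decorator form.
def add_one(x):
  return x + 1.0

# Direct form; equivalent to the decorator form above.
also_add_one = tfe.defun(lambda x: x + 1.0, input_signature=signature)

print(add_one(tf.constant([1.0, 2.0])))
print(also_add_one(tf.constant([1.0, 2.0])))
```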
-def make_defun_op(func, *args, **kwds):
- """Compile func into graph_mode, assuming func arguments are *args, **kwargs.
-
- `make_defun_op` converts a function that constructs a TensorFlow graph into
- a function object and attaches it to the graph. The resulting function
- object can be queried for its properties, and called directly with different
- inputs to execute.
-
- More details on use cases and limitations are available in the
- documentation for `defun`.
-
- Example:
- ```python
- def f(x, y):
- return tf.reduce_mean(tf.multiply(x ** 2, 3) + y)
-
- def g(x, y):
- return tf.reduce_mean(tf.multiply(x ** 2, 3) + y)
-
- z = tf.constant([[0.0, 0.0]])
- g_op = make_defun_op(g, z, z)
-
- assert g_op.output_shapes == tf.TensorShape([])
- assert g_op.output_types == tf.float32
-
- x = tf.constant([[2.0, 3.0]])
- y = tf.constant([[3.0, -2.0]])
-
- # The plain function and defun-compiled function should return the same value.
- assert f(x, y).numpy() == g_op(x, y).numpy()
- ```
-
- Args:
- func: function to be compiled.
- *args: List arguments to pass to `func` when attaching to the graph.
- **kwds: Keyword arguments to pass to `func` when attaching to the graph.
-
- Returns:
- A wrapper object which can be queried for its output properties,
- and which can be called directly the way a `@defun` wrapped function
- can.
- """
- return _trace_and_define_function(func.__name__, func, False, args, kwds)
-
-
class AutomaticControlDependencies(object):
"""Context manager to automatically add control dependencies.
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 380bcf763f..92254a2c00 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -27,7 +27,6 @@ from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import function
-from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -105,7 +104,7 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(step(), 2.0)
def testGraphGradientVariable(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
@function.defun
@@ -130,16 +129,16 @@ class FunctionTest(test.TestCase):
with ops.Graph().as_default():
self.assertEqual(f().shape, ())
- def testBasicDefunOpGraphMode(self):
+ def testBasicGraphFunction(self):
matmul = function.defun(math_ops.matmul)
+ @function.defun
def sq(a):
return matmul(a, a)
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
- sq_op = function.make_defun_op(sq, t)
-
+ sq_op = sq.get_concrete_function(t)
self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
out = sq_op(t)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
@@ -211,33 +210,44 @@ class FunctionTest(test.TestCase):
random_seed.set_random_seed(1)
self.assertAllEqual(f(), x)
- def testNestedInputsDefunOpGraphMode(self):
+ def testSymGradGatherNd(self):
+ with ops.Graph().as_default(), self.cached_session() as sess:
+
+ @function.defun
+ def f(x):
+ return array_ops.gather_nd(x, [[0]])
+
+ c = constant_op.constant([[2.]])
+ f_c = f(c)
+ g, = gradients_impl.gradients(f_c, c)
+ self.assertAllEqual(sess.run(g), [[1.0]])
+
+ def testNestedInputsGraphFunction(self):
matmul = function.defun(math_ops.matmul)
pair = collections.namedtuple('pair', ['a', 'b'])
+ @function.defun
def a_times_b(inputs):
return matmul(inputs.a['a'], inputs.b['b'])
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
-
inputs = pair({'a': t}, {'b': t})
- sq_op = function.make_defun_op(a_times_b, inputs)
-
+ sq_op = a_times_b.get_concrete_function(inputs)
self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
out = sq_op(inputs)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
- def testNestedOutputDefunOpGraphMode(self):
+ def testNestedOutputGraphFunction(self):
matmul = function.defun(math_ops.matmul)
+ @function.defun
def sq(a):
return (matmul(a, a), {'b': constant_op.constant(1.0)})
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
- sq_op = function.make_defun_op(sq, t)
-
+ sq_op = sq.get_concrete_function(t)
self.assertEqual(sq_op.output_shapes,
(tensor_shape.TensorShape([2, 2]),
{'b': tensor_shape.TensorShape([])}))
@@ -247,28 +257,28 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(a, math_ops.matmul(t, t).numpy())
self.assertAllEqual(b['b'].numpy(), 1.0)
- def testDefunOpGraphModeWithGradients(self):
+ def testGraphFunctionWithGradients(self):
v = resource_variable_ops.ResourceVariable(1.0, name='v')
+ @function.defun
def step():
def inner():
return v * v
return backprop.implicit_grad(inner)()[0][0]
- step_op = function.make_defun_op(step)
-
+ step_op = step.get_concrete_function()
self.assertEqual(step_op.output_dtypes, dtypes.float32)
self.assertEqual(step_op.output_shapes, tensor_shape.TensorShape([]))
self.assertAllEqual(step_op(), 2.0)
- def testDefunOpGraphModeNoneOutput(self):
+ def testGraphFunctionNoneOutput(self):
+ @function.defun
def fn(unused_a, unused_b):
return None
x = constant_op.constant(1)
- fn_op = function.make_defun_op(fn, x, x)
-
+ fn_op = fn.get_concrete_function(x, x)
self.assertEqual(fn_op.output_dtypes, None)
self.assertEqual(fn_op.output_shapes, None)
self.assertAllEqual(fn_op(x, x), None)
@@ -309,13 +319,13 @@ class FunctionTest(test.TestCase):
x = random_ops.random_uniform([2, 2]).numpy()
defined = function.defun(f)
defined(x)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
x = random_ops.random_uniform([2, 2]).numpy()
defined(x)
# A NumPy array with different values but the same shape and dtype
# shouldn't trigger another function definition.
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
def testDefunCapturedInt32(self):
x = constant_op.constant(1, dtype=dtypes.int32)
@@ -346,6 +356,47 @@ class FunctionTest(test.TestCase):
self.assertEqual(3.0, float(test_assign_add()))
+ @test_util.run_in_graph_and_eager_modes
+ def testTensorInitializationInFunctionRaisesError(self):
+ error_msg = ('Tensor-typed variable initializers must either be '
+ 'wrapped in an init_scope or callable.*')
+
+ @function.defun
+ def tensor_init():
+ with self.assertRaisesRegexp(ValueError, error_msg):
+ resource_variable_ops.ResourceVariable(constant_op.constant(2.0))
+
+ tensor_init()
+
+ @test_util.run_in_graph_and_eager_modes
+ def testCallableTensorInitializationInFunction(self):
+
+ @function.defun
+ def tensor_init():
+ v = resource_variable_ops.ResourceVariable(
+ lambda: constant_op.constant(2.0))
+ return v.read_value()
+
+ value = tensor_init()
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEqual(self.evaluate(value), 2.0)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testInitScopeTensorInitializationInFunction(self):
+
+ @function.defun
+ def tensor_init():
+ with ops.init_scope():
+ const = constant_op.constant(2.0)
+ v = resource_variable_ops.ResourceVariable(const)
+ return v.read_value()
+
+ value = tensor_init()
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEqual(self.evaluate(value), 2.0)
+
def testDefunShapeInferenceWithCapturedResourceVariable(self):
v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
@@ -430,7 +481,7 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0)
def testGraphModeCaptureVariable(self):
- with context.graph_mode(), self.test_session() as sess:
+ with context.graph_mode(), self.cached_session() as sess:
class HasAVar(object):
@@ -458,12 +509,12 @@ class FunctionTest(test.TestCase):
x = constant_op.constant(1.0)
l = f(x, v)
_, dv = gradients_impl.gradients(l, [x, v])
- with self.test_session():
+ with self.cached_session():
v.initializer.run()
self.assertAllEqual(dv.eval(), 0.0)
def testGraphModeManyFunctions(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
@function.defun
def f(x):
@@ -564,7 +615,6 @@ class FunctionTest(test.TestCase):
@function.defun
def g(x):
- tape.watch_variable(x)
y = math_ops.add(x, three)
f(y)
@@ -578,7 +628,6 @@ class FunctionTest(test.TestCase):
return math_ops.add(x, three)
def g(x):
- tape.watch_variable(three)
return f(x)
g = backprop.implicit_grad(g)(constant_op.constant(1.0))[0][0]
@@ -633,17 +682,19 @@ class FunctionTest(test.TestCase):
def testReturningIndexedSlicesWithDefun(self):
def validate(indexed_slice):
+ @function.defun
def f():
return indexed_slice
- output = function.defun(f)()
+ output = f()
self.assertTrue(isinstance(output, ops.IndexedSlices))
self.assertAllEqual(indexed_slice.values, output.values)
self.assertAllEqual(indexed_slice.indices, output.indices)
self.assertAllEqual(indexed_slice.dense_shape, output.dense_shape)
self.assertEqual(
- function.make_defun_op(f).output_shapes, indexed_slice.values.shape)
+ f.get_concrete_function().output_shapes,
+ indexed_slice.values.shape)
arg = ops.IndexedSlices(
values=constant_op.constant([1, 2]),
@@ -883,7 +934,7 @@ class FunctionTest(test.TestCase):
self.assertEqual(1, int(read()))
def testReturnCapturedGraphTensor(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
t = constant_op.constant(1)
@function.defun
@@ -966,39 +1017,109 @@ class FunctionTest(test.TestCase):
config=config_pb2.ConfigProto(device_count={'CPU': 4}))
def testDeviceAnnotationsRespected(self):
- @function.defun
def multi_device_fn():
with ops.device('/cpu:0'):
- s1 = iterator_ops.Iterator.from_structure(
+ s0 = iterator_ops.Iterator.from_structure(
(dtypes.float32,)).string_handle()
with ops.device('/cpu:1'):
- s2 = iterator_ops.Iterator.from_structure(
+ s1 = iterator_ops.Iterator.from_structure(
(dtypes.float32,)).string_handle()
with ops.device('/cpu:2'):
- s3 = iterator_ops.Iterator.from_structure(
- (dtypes.float32,)).string_handle()
- with ops.device(''):
- # TODO(akshayka): This is unfortunate and brittle. It prevents
- # `Iterator.from_structure` from assigning the iterator op to 'cpu:0'.
- # Remove this hack once we have a way of obtaining metadata about
- # function execution.
- s4 = iterator_ops.Iterator.from_structure(
+ s2 = iterator_ops.Iterator.from_structure(
(dtypes.float32,)).string_handle()
- return s1, s2, s3, s4
+ s3 = iterator_ops.Iterator.from_structure(
+ (dtypes.float32,)).string_handle()
+ return s0, s1, s2, s3
- with ops.device('/cpu:3'):
- outputs = self.evaluate(multi_device_fn())
+ defined = function.defun(multi_device_fn)
+ outputs = self.evaluate(defined())
+ self.assertEqual(len(defined._function_cache), 1)
self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
- self.assertIn(compat.as_bytes('CPU:3'), outputs[3])
- with ops.device('/cpu:0'):
- outputs = self.evaluate(multi_device_fn())
+ with ops.device('/cpu:3'):
+ outputs = self.evaluate(defined())
+ self.assertEqual(len(defined._function_cache), 2)
self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
- self.assertIn(compat.as_bytes('CPU:0'), outputs[3])
+ self.assertIn(compat.as_bytes('CPU:3'), outputs[3])
+
+ # This should retrieve the call-site-device-agnostic function.
+ defined()
+ self.assertEqual(len(defined._function_cache), 2)
+
+ # And this should retrieve the function created for '/cpu:3'.
+ with ops.device('/cpu:3'):
+ defined()
+ self.assertEqual(len(defined._function_cache), 2)
+
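The behavior exercised above, as a hedged standalone sketch; like the test, it peeks at the internal `_function_cache` attribute and assumes the 2018-era `tf.contrib.eager` API.

```python
import tensorflow as tf

tf.enable_eager_execution()

defined = tf.contrib.eager.defun(lambda: tf.constant(0))
defined()                        # Traced under the default device stack.
with tf.device('/cpu:0'):
  defined()                      # A second trace for this device stack.
print(len(defined._function_cache))  # Expected 2: one per device stack.
```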
+ @test_util.run_in_graph_and_eager_modes(
+ config=config_pb2.ConfigProto(device_count={'CPU': 2}))
+ def testCallingGraphFunctionOnIncompatibleDeviceRaisesError(self):
+
+ def func():
+ return constant_op.constant(0)
+
+ defined = function.defun(func)
+ with ops.device('cpu:0'):
+ cpu_graph_function = defined.get_concrete_function()
+
+ with ops.device('cpu:0'):
+ self.assertEqual(
+ self.evaluate(cpu_graph_function()), self.evaluate(func()))
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'The current device stack does not match the device stack under '
+ 'which the TensorFlow function \'.*func.*\' was created.\n'
+ 'Current device stack: .*\n.*func.* device stack.*'):
+ with ops.device('cpu:1'):
+ cpu_graph_function()
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'The current device stack does not match the device stack under '
+ 'which the TensorFlow function \'.*func.*\' was created.\n'
+ 'Current device stack: .*\n.*func.* device stack.*'):
+ with ops.device(None):
+ cpu_graph_function()
+
+ default_graph_function = defined.get_concrete_function()
+ self.assertEqual(
+ self.evaluate(default_graph_function()), self.evaluate(func()))
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'The current device stack does not match the device stack under '
+ 'which the TensorFlow function \'.*func.*\' was created.\n'
+ 'Current device stack: .*\n.*func.* device stack.*'):
+ with ops.device('cpu:1'):
+ default_graph_function()
+
+ @test_util.run_in_graph_and_eager_modes
+ def testColocateWithRespected(self):
+ # TODO(b/113291792): Use multiple CPUs instead of a GPU.
+ if not context.context().num_gpus():
+ self.skipTest('No GPUs found.')
+
+ with ops.device('cpu:0'):
+ x = constant_op.constant(1.0)
+
+ with ops.device('gpu:0'):
+ y = constant_op.constant(1.0)
+
+ @function.defun
+ def foo():
+ return iterator_ops.Iterator.from_structure(
+ (dtypes.float32,)).string_handle()
+
+ with ops.colocate_with(x):
+ self.assertIn(compat.as_bytes('CPU:0'), self.evaluate(foo()))
+
+ with ops.colocate_with(y):
+ self.assertIn(compat.as_bytes('GPU:0'), self.evaluate(foo()))
def testVariablesAreTracked(self):
v = resource_variable_ops.ResourceVariable(1.0)
@@ -1027,26 +1148,31 @@ class FunctionTest(test.TestCase):
defined = function.defun(func)
defined(0, baz=20)
+
+ def cache_keys():
+ """Sanitizes cache keys of non-input metadata."""
+ return tuple(key[:3] for key in defined._function_cache)
+
# The full cache key also encodes the execution context; `cache_keys()` drops it.
- self.assertIn((0, 1, 20, True), defined._arguments_to_functions)
+ self.assertIn((0, 1, 20), cache_keys())
defined(1) # bar=1, baz=2
- self.assertIn((1, 1, 2, True), defined._arguments_to_functions)
+ self.assertIn((1, 1, 2), cache_keys())
# This matches the previous call.
defined(foo=1)
- self.assertEqual(len(defined._arguments_to_functions), 2)
+ self.assertEqual(len(defined._function_cache), 2)
defined(1, 2, 3)
- self.assertIn((1, 2, 3, True), defined._arguments_to_functions)
+ self.assertIn((1, 2, 3), cache_keys())
# This matches the previous call.
defined(1, bar=2, baz=3)
- self.assertEqual(len(defined._arguments_to_functions), 3)
+ self.assertEqual(len(defined._function_cache), 3)
# This matches the previous call.
defined(1, baz=3, bar=2)
- self.assertEqual(len(defined._arguments_to_functions), 3)
+ self.assertEqual(len(defined._function_cache), 3)
def testFunctoolsPartialUnwrappedCorrectly(self):
@@ -1072,7 +1198,7 @@ class FunctionTest(test.TestCase):
defined = function.defun(foo, input_signature=signature)
a = array_ops.ones([2])
out = defined(a)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
self.assertAllEqual(out, a)
def bar(a):
@@ -1083,13 +1209,13 @@ class FunctionTest(test.TestCase):
defined = function.defun(bar, input_signature=signature)
a = array_ops.ones([2, 1])
out = defined(a)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
self.assertAllEqual(out, a)
# Changing the second dimension shouldn't create a new function.
b = array_ops.ones([2, 3])
out = defined(b)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
self.assertAllEqual(out, b)
def testNestedInputSignatures(self):
@@ -1106,7 +1232,7 @@ class FunctionTest(test.TestCase):
a = array_ops.ones([2, 1])
b = array_ops.ones([1])
out = defined([a, a], b)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
nest.assert_same_structure(out, [[a, a], b])
self.assertAllEqual(out[0][0], a)
self.assertAllEqual(out[0][1], a)
@@ -1117,7 +1243,7 @@ class FunctionTest(test.TestCase):
b = array_ops.ones([2, 5])
c = array_ops.ones([1])
out = defined([a, b], c)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
nest.assert_same_structure(out, [[a, b], c])
self.assertAllEqual(out[0][0], a)
self.assertAllEqual(out[0][1], b)
@@ -1153,13 +1279,13 @@ class FunctionTest(test.TestCase):
# Signatures must consist exclusively of `TensorSpec` objects.
signature = [(2, 3), tensor_spec.TensorSpec([2, 3], dtypes.float32)]
with self.assertRaisesRegexp(TypeError, 'Invalid input_signature.*'):
- function.defun(foo, input_signature=signature)(1, 2)
+ function.defun(foo, input_signature=signature)
# Signatures must be either lists or tuples on their outermost levels.
signature = {'t1': tensor_spec.TensorSpec([], dtypes.float32)}
with self.assertRaisesRegexp(TypeError, 'input_signature must be either a '
'tuple or a list.*'):
- function.defun(foo, input_signature=signature)(1, 2)
+ function.defun(foo, input_signature=signature)
def testInputsIncompatibleWithSignatureRaisesError(self):
@@ -1213,22 +1339,22 @@ class FunctionTest(test.TestCase):
integer = constant_op.constant(2, dtypes.int64)
out1, out2 = foo(flt, integer)
- self.assertEqual(len(foo._arguments_to_functions), 1)
+ self.assertEqual(len(foo._function_cache), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
out1, out2 = foo(flt=flt, integer=integer)
- self.assertEqual(len(foo._arguments_to_functions), 1)
+ self.assertEqual(len(foo._function_cache), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
out1, out2 = foo(integer=integer, flt=flt)
- self.assertEqual(len(foo._arguments_to_functions), 1)
+ self.assertEqual(len(foo._function_cache), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
out1, out2 = foo(flt, integer=integer)
- self.assertEqual(len(foo._arguments_to_functions), 1)
+ self.assertEqual(len(foo._function_cache), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
@@ -1258,27 +1384,27 @@ class FunctionTest(test.TestCase):
a = constant_op.constant(2.0)
b = constant_op.constant([1.0, 2.0])
one = defined(a, b)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
two = defined(a=a, b=b)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
three = defined(b=b, a=a)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
four = defined(a, b=b)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
# The next call corresponds to a new input signature, hence
# we expect another function to be defined.
five = defined(b, a)
- self.assertEqual(len(defined._arguments_to_functions), 2)
+ self.assertEqual(len(defined._function_cache), 2)
six = defined(a=b, b=a)
- self.assertEqual(len(defined._arguments_to_functions), 2)
+ self.assertEqual(len(defined._function_cache), 2)
seven = defined(b=a, a=b)
- self.assertEqual(len(defined._arguments_to_functions), 2)
+ self.assertEqual(len(defined._function_cache), 2)
self.assertAllEqual(one, [1.0, 2.0])
self.assertAllEqual(two, [1.0, 2.0])
@@ -1298,14 +1424,14 @@ class FunctionTest(test.TestCase):
grad_t, = backprop.gradients_function(sq, [0])(t)
self.assertAllEqual(grad_t, [[6, 6], [14, 14]])
- with backprop.GradientTape(persistent=True) as gtape:
- gtape.watch(t)
+ with backprop.GradientTape(persistent=True) as tape:
+ tape.watch(t)
one = matmul(t, b=t, transpose_a=True)
two = matmul(b=t, a=t, transpose_a=True)
three = matmul(a=t, b=t, transpose_a=True)
for output in [one, two, three]:
- self.assertAllEqual(gtape.gradient(output, t), [[6, 6], [14, 14]])
+ self.assertAllEqual(tape.gradient(output, t), [[6, 6], [14, 14]])
def testGradientInFunctionWithKeywordArguments(self):
@@ -1363,7 +1489,7 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(state, [0])
# Whereas calling the python function directly should create a side-effect.
- side_effecting_function.call_python_function()
+ side_effecting_function.python_function()
self.assertAllEqual(state, [0, 0])
@@ -1371,7 +1497,7 @@ class FunctionTest(test.TestCase):
class AutomaticControlDependenciesTest(test.TestCase):
def testBasic(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
with function.AutomaticControlDependencies() as c:
@@ -1382,7 +1508,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(val.eval(), 4.0)
def testCondMustRun(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1403,7 +1529,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(val.eval(feed_dict={p: True}), 6.0)
def testCondMustRunSeparateRead(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1426,7 +1552,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(v.read_value().eval(), 6.0)
def testCondNested(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1460,7 +1586,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(val.eval(feed_dict={p: True, q: False}), 8.0)
def testCondOneBranch(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1480,7 +1606,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(val.eval(feed_dict={p: True}), 5.0)
def testCondOneBranchUpdateBefore(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1501,7 +1627,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(val.eval(feed_dict={p: True}), 12.0)
def testCondOneBranchUpdateAfter(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1537,7 +1663,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertAllEqual(out, [3, 4, 5])
def testDecorator(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py
deleted file mode 100644
index 7105d2e399..0000000000
--- a/tensorflow/python/eager/graph_callable.py
+++ /dev/null
@@ -1,435 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Decorator that produces a callable object that executes a TensorFlow graph.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import contextlib
-
-from tensorflow.python.eager import context
-from tensorflow.python.eager import function
-from tensorflow.python.eager import tape
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops as tf_ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.util import nest
-from tensorflow.python.util import tf_decorator
-from tensorflow.python.util import tf_inspect
-
-
-def _default_initializer(name, shape, dtype):
- """The default initializer for variables."""
- # pylint: disable=protected-access
- store = variable_scope._get_default_variable_store()
- initializer = store._get_default_initializer(name, shape=shape, dtype=dtype)
- # pylint: enable=protected-access
- return initializer[0]
-
-
-class _CapturedVariable(object):
- """Variable captured by graph_callable.
-
- Internal to the implementation of graph_callable. Created only by
- _VariableCapturingScope and used only to read the variable values when calling
- the function after the variables are initialized.
- """
-
- def __init__(self, name, initializer, shape, dtype, trainable):
- self.name = name
- if initializer is None:
- initializer = _default_initializer(name, shape, dtype)
- initial_value = lambda: initializer(shape, dtype=dtype)
-
- with context.eager_mode():
- self.variable = resource_variable_ops.ResourceVariable(
- initial_value=initial_value, name=name, dtype=dtype,
- trainable=trainable)
- self.shape = shape
- self.dtype = dtype
- self.placeholder = None
- self.trainable = trainable
-
- def read(self, want_gradients=True):
- if want_gradients and self.trainable:
- v = tape.watch_variable(self.variable)
- else:
- v = self.variable
- return v.read_value()
-
-
-class _VariableCapturingScope(object):
- """Variable-scope-like object which captures tf.get_variable calls.
-
- This is responsible for the main difference between the initialization version
- of a function object and the calling version of a function object.
-
- capturing_scope replaces calls to tf.get_variable with placeholder tensors to
- be fed the variable's current value. TODO(apassos): these placeholders should
- instead be objects implementing a similar API to tf.Variable, for full
- compatibility.
-
- initializing_scope replaces calls to tf.get_variable with creation of
- variables and initialization of their values. This allows eventual support of
- initialized_value and friends.
-
- TODO(apassos): once the eager mode layers API is implemented support eager
- func-to-object as well.
- """
-
- def __init__(self):
- self.variables = {}
- self.tf_variables = {}
-
- @contextlib.contextmanager
- def capturing_scope(self):
- """Context manager to capture variable creations.
-
- Replaces variable accesses with placeholders.
-
- Yields:
- nothing
- """
- # TODO(apassos) ignoring the regularizer and partitioner here; figure out
- # how to deal with these.
- def _custom_getter( # pylint: disable=missing-docstring
- getter=None,
- name=None,
- shape=None,
- dtype=dtypes.float32,
- initializer=None,
- regularizer=None,
- reuse=None,
- trainable=None,
- collections=None,
- caching_device=None, # pylint: disable=redefined-outer-name
- partitioner=None,
- validate_shape=True,
- use_resource=None,
- aggregation=variable_scope.VariableAggregation.NONE,
- synchronization=variable_scope.VariableSynchronization.AUTO):
- del getter, regularizer, partitioner, validate_shape, use_resource, dtype
- del collections, initializer, trainable, reuse, caching_device, shape
- del aggregation, synchronization
- assert name in self.variables
- v = self.variables[name]
- return v.variable
-
- scope = variable_scope.get_variable_scope()
- with variable_scope.variable_scope(scope, custom_getter=_custom_getter):
- yield
-
- @contextlib.contextmanager
- def initializing_scope(self):
- """Context manager to capture variable creations.
-
- Forcibly initializes all created variables.
-
- Yields:
- nothing
- """
- # TODO(apassos) ignoring the regularizer and partitioner here; figure out
- # how to deal with these.
- def _custom_getter( # pylint: disable=missing-docstring
- getter=None,
- name=None,
- shape=None,
- dtype=dtypes.float32,
- initializer=None,
- regularizer=None,
- reuse=None,
- trainable=None,
- collections=None,
- caching_device=None, # pylint: disable=redefined-outer-name
- partitioner=None,
- validate_shape=True,
- use_resource=None,
- aggregation=variable_scope.VariableAggregation.NONE,
- synchronization=variable_scope.VariableSynchronization.AUTO):
- del getter, regularizer, collections, caching_device, partitioner
- del use_resource, validate_shape, aggregation, synchronization
- if name in self.tf_variables:
- if reuse:
- return self.tf_variables[name].initialized_value()
- else:
- raise ValueError("Specified reuse=%s but tried to reuse variables."
- % reuse)
- # TODO(apassos): ensure this is on the same device as above
- v = _CapturedVariable(name, initializer, shape, dtype, trainable)
- self.variables[name] = v
-
- graph_mode_resource = v.variable.handle
- if initializer is None:
- initializer = _default_initializer(name, shape, dtype)
- resource_variable_ops.shape_safe_assign_variable_handle(
- graph_mode_resource, v.variable.shape, initializer(shape, dtype))
- return v.variable
-
- scope = variable_scope.get_variable_scope()
- with variable_scope.variable_scope(scope, custom_getter=_custom_getter):
- yield
-
-
-class _InitializingFunctionObject(object):
- """Responsible for deciding which version of func-to-object to call.
-
- call_fn is the version which calls the function with the current values of the
- variables and init_fn is the version which calls the function to initialize
- all variables.
-
- TODO(apassos): figure out a way to support initializing only _some_
- variables. This requires a way to pull out a variable's initialization code
- from the graph, which might not be possible in general.
- """
-
- def __init__(self, call_fn, init_fn, shape_and_dtypes):
- self._init_fn = init_fn
- self._call_fn = call_fn
- self.shape_and_dtypes = shape_and_dtypes
- self.flattened_shapes = [tensor_shape.as_shape(sd.shape) for sd in
- nest.flatten(self.shape_and_dtypes)]
-
- @property
- def variables(self):
- return self._call_fn.variables
-
- def __call__(self, *args):
- nest.assert_same_structure(self.shape_and_dtypes, args, check_types=False)
- if not all([
- shape.is_compatible_with(arg.shape)
- for shape, arg in zip(self.flattened_shapes, nest.flatten(args))
- ]):
- raise ValueError(
- "Declared shapes do not match argument shapes: Expected %s, found %s."
- % (self.flattened_shapes, [arg.shape for arg in nest.flatten(args)]))
-
- initialized = [resource_variable_ops.var_is_initialized_op(
- v.handle).numpy() for v in self._call_fn.variables]
- if all(x for x in initialized):
- for v in self._call_fn.variables:
- if v.trainable:
- tape.watch_variable(v)
- return self._call_fn(*args)
- elif all(not x for x in initialized):
- return self._init_fn(*args)
- else:
- raise ValueError("Some, but not all, variables are initialized.")
-
-
-def _get_graph_callable_inputs(shape_and_dtypes):
- """Maps specified shape_and_dtypes to graph inputs."""
- ret = []
- for x in shape_and_dtypes:
- if isinstance(x, ShapeAndDtype):
- ret.append(array_ops.placeholder(x.dtype, x.shape))
- elif isinstance(x, (tuple, list)):
- ret.append(_get_graph_callable_inputs(x))
- else:
- raise errors.InvalidArgumentError(
- None, None, "Expected the argument to @graph_callable to be a "
- "(possibly nested) list or tuple of ShapeAndDtype objects, "
- "but got an object of type: %s" % type(x))
-
- return tuple(ret) if isinstance(shape_and_dtypes, tuple) else ret
-
-
-def _graph_callable_internal(func, shape_and_dtypes):
- """Defines and returns a template version of func.
-
- Under the hood we make two function objects, each wrapping a different version
- of the graph-mode code. One version immediately runs variable initialization
- before making the variable's Tensors available for use, while the other
- version replaces the Variables with placeholders which become function
- arguments and get the current variable's value.
-
- Limitations in (2) and (4) are because this does not implement a graph-mode
- Variable class which has a convert_to_tensor(as_ref=True) method and a
- initialized_value method. This is fixable.
-
- Args:
- func: The tfe Python function to compile.
- shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects.
-
- Raises:
- ValueError: If any one of func's outputs is not a Tensor.
-
- Returns:
- Callable graph object.
- """
- container = tf_ops.get_default_graph()._container # pylint: disable=protected-access
- graph_key = tf_ops.get_default_graph()._graph_key # pylint: disable=protected-access
- with context.graph_mode():
- # This graph will store both the initialization and the call version of the
- # wrapped function. It will later be used by the backprop code to build the
- # backprop graph, if necessary.
- tmp_graph = function.CapturingGraph()
- # Inherit the graph key from the original graph to ensure optimizers don't
- # misbehave.
- tmp_graph._container = container # pylint: disable=protected-access
- tmp_graph._graph_key = graph_key # pylint: disable=protected-access
- with tmp_graph.as_default():
- # Placeholders for the non-variable inputs.
- func_inputs = _get_graph_callable_inputs(shape_and_dtypes)
- func_num_args = len(tf_inspect.getfullargspec(func).args)
- if len(func_inputs) != func_num_args:
- raise TypeError("The number of arguments accepted by the decorated "
- "function `%s` (%d) must match the number of "
- "ShapeAndDtype objects passed to the graph_callable() "
- "decorator (%d)." %
- (func.__name__, func_num_args, len(func_inputs)))
-
- # First call the function to generate a graph which can initialize all
- # variables. As a side-effect this will populate the variable capturing
- # scope's view of which variables exist.
- variable_captures = _VariableCapturingScope()
- with variable_captures.initializing_scope(
- ), function.AutomaticControlDependencies() as a:
- func_outputs = func(*func_inputs)
- outputs_list = nest.flatten(func_outputs)
- for i, x in enumerate(outputs_list):
- if x is not None:
- outputs_list[i] = a.mark_as_return(x)
- if len(outputs_list) == 1 and outputs_list[0] is None:
- outputs_list = []
- output_shapes = [x.shape for x in outputs_list]
- if not all(isinstance(x, tf_ops.Tensor) for x in outputs_list):
- raise ValueError("Found non-tensor output in %s" % str(outputs_list))
- initializing_operations = tmp_graph.get_operations()
-
- # Call the function again, now replacing usages of variables with
- # placeholders. This assumes the variable capturing scope created above
- # knows about all variables.
- tmp_graph.clear_resource_control_flow_state()
- with variable_captures.capturing_scope(
- ), function.AutomaticControlDependencies() as a:
- captured_outputs = func(*func_inputs)
- captured_outlist = nest.flatten(captured_outputs)
- for i, x in enumerate(captured_outlist):
- if x is not None:
- captured_outlist[i] = a.mark_as_return(x)
- capturing_operations = tmp_graph.get_operations()[
- len(initializing_operations):]
-
- sorted_variables = sorted(variable_captures.variables.values(),
- key=lambda x: x.name)
-
- extra_inputs = tmp_graph.captures.keys()
- extra_placeholders = tmp_graph.captures.values()
-
- flat_inputs = [x for x in nest.flatten(func_inputs)
- if isinstance(x, tf_ops.Tensor)]
- placeholder_inputs = flat_inputs+ list(extra_placeholders)
-
- func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)]
- initialization_name = function._inference_name(func.__name__) # pylint: disable=protected-access
- # TODO(ashankar): Oh lord, forgive me for this lint travesty.
- # Also, what about the gradient registry of these functions? Those need to be
- # addressed as well.
- for f in tmp_graph._functions.values(): # pylint: disable=protected-access
- function._register(f._c_func.func) # pylint: disable=protected-access
- initializer_function = function.GraphModeFunction(
- initialization_name,
- placeholder_inputs,
- extra_inputs,
- tmp_graph,
- initializing_operations,
- func_def_outputs,
- func_outputs,
- output_shapes)
-
- capture_func_def_outputs = [
- x for x in captured_outlist if isinstance(x, tf_ops.Tensor)]
- captured_function_name = function._inference_name(func.__name__) # pylint: disable=protected-access
- captured_function = function.GraphModeFunction(
- captured_function_name,
- placeholder_inputs,
- extra_inputs,
- tmp_graph,
- capturing_operations,
- capture_func_def_outputs,
- captured_outputs,
- output_shapes,
- variables=[x.variable for x in sorted_variables])
-
- return _InitializingFunctionObject(captured_function, initializer_function,
- shape_and_dtypes)
-
-
-class ShapeAndDtype(object):
- """Data type that packages together shape and type information.
-
- Used for arguments to graph callables. See graph_callable() for an example.
- """
-
- def __init__(self, shape, dtype):
- self.shape = shape
- self.dtype = dtype
-
-
-def graph_callable(shape_and_dtypes):
- """Decorator that produces a callable that executes a TensorFlow graph.
-
- When applied on a function that constructs a TensorFlow graph, this decorator
- produces a callable object that:
-
- 1. Executes the graph when invoked. The first call will initialize any
- variables defined in the graph.
-
- 2. Provides a .variables() method to return the list of TensorFlow variables
- defined in the graph.
-
- Note that the wrapped function is not allowed to change the values of the
- variables, just use them.
-
- The return value of the wrapped function must be one of the following:
- (1) None, (2) a Tensor, or (3) a possibly nested sequence of Tensors.
-
- Example:
-
- ```python
- @tfe.graph_callable([tfe.ShapeAndDtype(shape(), dtype=dtypes.float32)])
- def foo(x):
- v = tf.get_variable('v', initializer=tf.ones_initializer(), shape=())
- return v + x
-
- ret = foo(tfe.Tensor(2.0)) # `ret` here is a Tensor with value 3.0.
-
- foo.variables[0].assign(7.0) # Modify the value of variable `v`.
- ret = foo(tfe.Tensor(2.0)) # `ret` here now is a Tensor with value 9.0.
- ```
- Args:
- shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects
- that specifies shape and type information for each of the callable's
- arguments. The length of this list must be equal to the number of
- arguments accepted by the wrapped function.
-
- Returns:
- A callable graph object.
- """
- # TODO(alive,apassos): support initialized_value and friends from tf.Variable.
- assert context.executing_eagerly(), (
- "graph_callable can only be used when Eager execution is enabled.")
- def decorator(func):
- return tf_decorator.make_decorator(func,
- _graph_callable_internal(
- func, shape_and_dtypes))
-
- return decorator
diff --git a/tensorflow/python/eager/graph_callable_test.py b/tensorflow/python/eager/graph_callable_test.py
deleted file mode 100644
index b9e6ca2a93..0000000000
--- a/tensorflow/python/eager/graph_callable_test.py
+++ /dev/null
@@ -1,249 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.eager import backprop
-from tensorflow.python.eager import graph_callable
-from tensorflow.python.eager import test
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import variable_scope
-
-
-class GraphCallableTest(test.TestCase):
-
- def testBasic(self):
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def my_function(x):
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=())
- return v + x
-
- self.assertEqual(
- 2, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())
-
- my_function.variables[0].assign(1.)
- self.assertEqual(
- 3, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())
-
- def testFunctionWithoutReturnValue(self):
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def my_function(x):
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=())
- v.assign(x)
-
- my_function(constant_op.constant(4, dtype=dtypes.float32))
- self.assertAllEqual(4, my_function.variables[0].read_value())
-
- def testFunctionWithoutReturnValueAndArgs(self):
-
- @graph_callable.graph_callable([])
- def my_function():
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=())
- v.assign(4)
-
- my_function()
- self.assertAllEqual(4, my_function.variables[0].read_value())
-
- def testVariableAPI(self):
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def my_function(x):
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=())
- return v.read_value() + x
-
- self.assertEqual(
- 2, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())
-
- my_function.variables[0].assign(1.)
- self.assertEqual(
- 3, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())
-
- def testTensorShape(self):
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(1), dtype=dtypes.float32)])
- def my_function(x):
- _ = x.get_shape()
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=[x.shape[0]])
- self.assertEqual(v.shape[0], x.shape[0])
- return v + x
-
- self.assertEqual([2.],
- my_function(
- constant_op.constant([2.],
- dtype=dtypes.float32)).numpy())
-
- def testUpdatesAreOrdered(self):
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def my_function(x):
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=())
- v.assign(x + 1)
- v.assign(v * x)
- return v.read_value()
-
- self.assertAllEqual(my_function(constant_op.constant(2.0)), 6.0)
-
- def testEmptyInitializer(self):
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(1), dtype=dtypes.float32)])
- def my_function(x):
- v = variable_scope.get_variable("v", shape=[1])
- return x + 0 * v
-
- self.assertEqual([2.],
- my_function(
- constant_op.constant([2.],
- dtype=dtypes.float32)).numpy())
-
- def testMismatchingNumArgs(self):
- # pylint: disable=anomalous-backslash-in-string
- with self.assertRaisesRegexp(TypeError,
- "The number of arguments accepted by the "
- "decorated function `my_function` \(2\) must "
- "match the number of ShapeAndDtype objects "
- "passed to the graph_callable\(\) decorator "
- "\(1\)."):
- @graph_callable.graph_callable([
- graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def my_function(x, y): # pylint: disable=unused-variable
- return x + y
- # pylint: enable=anomalous-backslash-in-string
-
- def testPureFunction(self):
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
- def f(x):
- return math_ops.add(x, constant_op.constant(3))
-
- self.assertAllEqual(5, f(constant_op.constant(2)))
-
- def testNestedFunction(self):
- # TensorFlow function (which is what would be used in TensorFlow graph
- # construction).
- @function.Defun(dtypes.int32, dtypes.int32)
- def add(a, b):
- return math_ops.add(a, b)
-
- # A graph_callable that will invoke the TensorFlow function.
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
- def add_one(x):
- return add(x, 1)
-
- self.assertAllEqual(3, add_one(constant_op.constant(2)))
-
- # TODO(ashankar): Make this work.
- # The problem is that the two graph_callables (for add_one and add_two)
- # are both trying to register the FunctionDef corresponding to "add".
- def DISABLED_testRepeatedUseOfSubFunction(self):
-
- @function.Defun(dtypes.int32, dtypes.int32)
- def add(a, b):
- return math_ops.add(a, b)
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
- def add_one(x):
- return add(x, 1)
-
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
- def add_two(x):
- return add(x, 2)
-
- two = constant_op.constant(2)
- self.assertAllEqual(3, add_one(two))
- self.assertAllEqual(4, add_two(two))
-
- def testNestedSequenceInputs(self):
- sd = graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)
- @graph_callable.graph_callable([[sd, tuple([sd, sd]), sd]])
- def my_op(inputs):
- a, b, c = inputs
- e, f = b
- v = variable_scope.get_variable(
- "my_v", initializer=init_ops.zeros_initializer(), shape=())
- return [a + a + v, tuple([e + e, f + f]), c + c], a + e + f + c + v
-
- inputs = [constant_op.constant(1.),
- [constant_op.constant(2.), constant_op.constant(3.)],
- constant_op.constant(4.)]
- ret = my_op(inputs)
- self.assertEqual(len(ret), 2.)
- self.assertAllEqual(ret[1], 10.)
-
- my_op.variables[0].assign(1.)
- ret = my_op(inputs)
- self.assertAllEqual(ret[1], 11.)
-
- def testVariableShapeIsTensorShape(self):
- @graph_callable.graph_callable([])
- def my_function():
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=())
- self.assertIsInstance(v.get_shape(), tensor_shape.TensorShape)
-
- my_function()
-
- def testIncorrectlyShapedInputs(self):
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(3), dtype=dtypes.float32)])
- def my_function(x):
- v = variable_scope.get_variable(
- "v", initializer=init_ops.zeros_initializer(), shape=())
- return v + x
-
- with self.assertRaises(ValueError):
- my_function([1, 2])
-
- self.assertTrue(([1, 2, 3] == my_function(
- constant_op.constant([1, 2, 3], dtype=dtypes.float32)).numpy()).all())
-
- def testGradients(self):
- @graph_callable.graph_callable([])
- def my_function():
- v = variable_scope.get_variable(
- "v", initializer=init_ops.constant_initializer(3.), shape=())
- return v * v
-
- grad_fn = backprop.implicit_grad(my_function)
- grads_and_vars = list(zip(*grad_fn()))
- self.assertAllEqual(6., grads_and_vars[0][0])
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/eager/graph_only_ops_test.py b/tensorflow/python/eager/graph_only_ops_test.py
index d2a2b4e223..3cf3a61a62 100644
--- a/tensorflow/python/eager/graph_only_ops_test.py
+++ b/tensorflow/python/eager/graph_only_ops_test.py
@@ -32,13 +32,13 @@ class GraphOnlyOpsTest(test_util.TensorFlowTestCase):
def testGraphZerosLike(self):
x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
z_tf = graph_only_ops.graph_zeros_like(x)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(np.zeros((2, 3)), z_tf.eval())
def testGraphPlaceholder(self):
x_tf = graph_only_ops.graph_placeholder(dtypes.int32, shape=(1,))
y_tf = math_ops.square(x_tf)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = np.array([42])
y = sess.run(y_tf, feed_dict={x_tf: np.array([42])})
self.assertAllClose(np.square(x), y)
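
The test_session() to cached_session() migration above recurs throughout this change set. A minimal sketch of the new idiom, assuming a 2018-era TF 1.x build where tf.test.TestCase provides cached_session() (the session is created once and reused across calls within a test):

    import numpy as np
    import tensorflow as tf  # assumed: TF 1.x with TestCase.cached_session()


    class GraphZerosLikeTest(tf.test.TestCase):

      def testZerosLike(self):
        x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
        z = tf.zeros_like(x)
        # cached_session() makes the session reuse explicit: one session is
        # cached and shared by every `with` block inside this test.
        with self.cached_session():
          self.assertAllClose(np.zeros((2, 3)), z.eval())


    if __name__ == "__main__":
      tf.test.main()
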
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index 000152855d..5f027d107c 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -24,12 +24,10 @@ from tensorflow.python import pywrap_tensorflow
VSpace = collections.namedtuple(
- "VSpace",
- ["aggregate_fn", "num_elements_fn", "tensor_id", "zeros", "ones"])
+ "VSpace", ["aggregate_fn", "num_elements_fn", "zeros", "ones"])
def imperative_grad(
- vspace,
tape,
target,
sources,
@@ -41,7 +39,6 @@ def imperative_grad(
gradients for all sources.
Args:
- vspace: the vector space in which to differentiate.
tape: the gradient tape which stores the trace.
target: either a Tensor or list of Tensors to be differentiated.
sources: list of Tensors for which we want gradients
@@ -60,4 +57,7 @@ def imperative_grad(
computation of target.
"""
return pywrap_tensorflow.TFE_Py_TapeGradient(
- tape._tape, vspace, target, sources, output_gradients) # pylint: disable=protected-access
+ tape._tape, # pylint: disable=protected-access
+ target,
+ sources,
+ output_gradients)
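
The hunk above removes the per-call `vspace` argument; the VSpace is now registered once through TFE_Py_RegisterVSpace (added to the C layer below) and cached as a global. A pure-Python sketch of that registry pattern, with illustrative names only:

    import collections

    VSpace = collections.namedtuple(
        "VSpace", ["aggregate_fn", "num_elements_fn", "zeros", "ones"])

    _vspace = None  # module-level singleton, mirroring the C++ `py_vspace` global


    def register_vspace(vspace):
      """Stands in for pywrap_tensorflow.TFE_Py_RegisterVSpace."""
      global _vspace
      _vspace = vspace


    def aggregate(gradients):
      """Callers no longer pass a vspace; the registered one is looked up."""
      if _vspace is None:
        raise RuntimeError("register_vspace() must be called first")
      return _vspace.aggregate_fn(gradients)


    register_vspace(VSpace(aggregate_fn=sum, num_elements_fn=len,
                           zeros=lambda shape: [0.0] * shape,
                           ones=lambda shape: [1.0] * shape))
    assert aggregate([1.0, 2.0, 3.0]) == 6.0
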
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index 15d2ccf9d2..f34ce6af79 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -27,6 +27,8 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
+#include "structmember.h" // NOLINT // For PyMemberDef
+
// forward declare
struct EagerTensor;
@@ -263,6 +265,14 @@ typedef struct EagerTensor {
TF_Status* status;
PyObject* weakreflist; /* List of weak references */
+
+ // Per-instance attribute dictionary, to support monkey patching
+ // (e.g. EagerTensor.assign when slicing variables). This dictionary is
+ // created by CPython the first time an attribute is assigned, pointed to by
+ // tp_dictoffset. Note that garbage collection is not enabled for
+ // EagerTensors, so assigning objects to EagerTensor attributes which require
+ // garbage collection is likely to cause issues.
+ PyObject* dict;
} EagerTensor;
namespace {
@@ -311,17 +321,42 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
Py_INCREF(Py_None);
self->tensor_shape = Py_None;
self->status = TF_NewStatus();
+ self->dict = nullptr;
self->weakreflist = nullptr;
PyObject* value;
PyObject* context = nullptr;
PyObject* device = nullptr;
PyObject* dtype = Py_None;
- const char* kwlist[] = {"value", "context", "device", "dtype", nullptr};
- if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|O",
+ PyObject* other_value = nullptr;
+ const char* kwlist[] = {"value", "context", "device",
+ "dtype", "other_value", nullptr};
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|OO",
const_cast<char**>(kwlist), &value, &context,
- &device, &dtype)) {
+ &device, &dtype, &other_value)) {
return -1;
}
+
+ if (other_value != nullptr) {
+ if (!EagerTensor_CheckExact(other_value)) {
+ PyErr_SetString(PyExc_TypeError,
+ tensorflow::strings::StrCat(
+ "Expecting an EagerTensor for other_value, got ",
+ Py_TYPE(other_value)->tp_name)
+ .c_str());
+
+ return -1;
+ }
+ EagerTensor* other = reinterpret_cast<EagerTensor*>(other_value);
+ self->handle =
+ TFE_TensorHandleCopySharingTensor(other->handle, self->status);
+
+ if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+ return -1;
+ }
+
+ return 0;
+ }
+
// Extract dtype
int desired_dtype = -1;
if (dtype != Py_None) {
@@ -410,6 +445,10 @@ void EagerTensor_dealloc(EagerTensor* self) {
Py_DECREF(self->handle_data);
Py_DECREF(self->keras_mask);
Py_DECREF(self->tensor_shape);
+ // If an attribute dictionary has been created, release it. Note that this
+ // is only ever created by CPython's attribute setting methods; we don't
+ // create it ourselves.
+ Py_CLEAR(self->dict);
if (self->handle != nullptr) {
TFE_DeleteTensorHandle(self->handle);
self->handle = nullptr;
@@ -474,6 +513,30 @@ static PyObject* EagerTensor_rank(EagerTensor* self) {
#endif
}
+// Getter for `_num_elements`.
+static PyObject* EagerTensor_num_elements(EagerTensor* self) {
+ auto handle = self->handle;
+ int n = TFE_TensorHandleNumDims(handle, self->status);
+ if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+ // Cleanup self->status before returning.
+ TF_SetStatus(self->status, TF_OK, "");
+ return nullptr;
+ }
+ tensorflow::int64 value = 1;
+ if (PyErr_Occurred()) return nullptr;
+ for (int i = 0; i < n; ++i) {
+ int64_t dim = TFE_TensorHandleDim(handle, i, self->status);
+ if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+ // Cleanup self->status before returning.
+ TF_SetStatus(self->status, TF_OK, "");
+ PyErr_SetString(PyExc_RuntimeError, "Error while iterating dimensions");
+ return nullptr;
+ }
+ value *= dim;
+ }
+ return PyLong_FromLongLong(value);
+}
+
static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) {
Py_INCREF(self->handle_data);
return self->handle_data;
@@ -582,6 +645,15 @@ static PyGetSetDef EagerTensor_getseters[] = {
{nullptr} /* Sentinel */
};
+#if PY_MAJOR_VERSION < 3
+// Only used for Python2 since Python3 seems to set the __dict__ correctly.
+static PyMemberDef EagerTensor_members[] = {
+ {const_cast<char*>("__dict__"), T_OBJECT, offsetof(EagerTensor, dict),
+ READONLY},
+ {nullptr},
+};
+#endif
+
static PyMethodDef EagerTensor_methods[] = {
{"_numpy", (PyCFunction)EagerTensor_numpy, METH_NOARGS,
PyDoc_STR("_numpy")},
@@ -592,6 +664,8 @@ static PyMethodDef EagerTensor_methods[] = {
{"_rank", (PyCFunction)EagerTensor_rank, METH_NOARGS, PyDoc_STR("_rank")},
{"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device,
METH_VARARGS | METH_KEYWORDS, PyDoc_STR("_copy_to_device")},
+ {"_num_elements", (PyCFunction)EagerTensor_num_elements, METH_NOARGS,
+ PyDoc_STR("_num_elements")},
{nullptr, nullptr},
};
@@ -654,13 +728,13 @@ static PyTypeObject _EagerTensorType = {
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
EagerTensor_methods, /* tp_methods */
- nullptr, /* tp_members */
+ EagerTensor_members, /* tp_members */
EagerTensor_getseters, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
- 0, /* tp_dictoffset */
+ offsetof(EagerTensor, dict), /* tp_dictoffset */
(initproc)EagerTensor_init, /* tp_init */
nullptr, /* tp_alloc */
nullptr, /* tp_new */
@@ -788,8 +862,9 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
PyErr_SetString(PyExc_RuntimeError, "Error while creating EagerTensorType");
return nullptr;
}
+ EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict);
#else
- _EagerTensorType.tp_base = reinterpret_cast<PyTypeObject*>(base_class);
+ _EagerTensorType.tp_base = base_class_type;
if (PyType_Ready(&_EagerTensorType) < 0) {
if (PyErr_Occurred()) return nullptr;
@@ -800,9 +875,6 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
EagerTensorType = &_EagerTensorType;
Py_INCREF(EagerTensorType);
#endif
- // We disable instance based attribute lookup. Its not clear if these
- // dictionaries are correctly initialized in the first place.
- EagerTensorType->tp_dictoffset = 0;
return reinterpret_cast<PyObject*>(EagerTensorType);
}
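
The tp_dictoffset change above gives each EagerTensor a per-instance attribute dictionary, which is what makes monkey patching (e.g. attaching `assign` to a sliced variable) possible. A pure-Python analogue of the before/after layouts, using __slots__ to stand in for the C struct:

    class OldStyleTensor(object):
      # No per-instance __dict__, like EagerTensor with tp_dictoffset == 0:
      # attribute assignment fails.
      __slots__ = ("handle",)


    class NewStyleTensor(object):
      # "__dict__" here plays the role of the new `dict` field that
      # tp_dictoffset now points to.
      __slots__ = ("handle", "__dict__")


    t = NewStyleTensor()
    t.test_attr = "Test"          # monkey patching now works...
    assert "test_attr" in dir(t)  # ...and shows up in dir(), as the new test checks

    old = OldStyleTensor()
    try:
      old.test_attr = "Test"
    except AttributeError:
      pass  # no instance dictionary to store the attribute in
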
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index a916a75f00..f1b4042ec9 100644..100755
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -59,6 +59,10 @@ PyObject* TFE_Py_RegisterExceptionClass(PyObject* e);
// This function is not thread-safe.
PyObject* TFE_Py_RegisterResourceVariableType(PyObject* e);
+// Registers `e` as the VSpace to use.
+// `e` must be an imperative_grad.py:VSpace named tuple.
+PyObject* TFE_Py_RegisterVSpace(PyObject* e);
+
// Registers e as the Exception to be raised when the conditions of
// TFE_Py_FastPathExecute_C have not been met. When this exception is set, it
// is a signal to the calling code that it should fall back to the safer (and
@@ -89,7 +93,7 @@ int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
PyObject* exception);
// Returns the string associated with the passed-in python object.
-char* TFE_GetPythonString(PyObject* o);
+const char* TFE_GetPythonString(PyObject* o);
// Returns a unique id on each call.
int64_t get_uid();
@@ -124,9 +128,10 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class);
// To unset the profiler, pass Py_None as the value of `profiler`.
PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler);
-// Creates a new tape and adds it to the active set. `persistent` must be a
-// PyBool_Type, i.e either Py_True or Py_False
-PyObject* TFE_Py_TapeSetNew(PyObject* persistent);
+// Creates a new tape and adds it to the active set. `persistent` and
+// `watch_accessed_variables` must be `PyBool_Type` (`Py_True` or `Py_False`).
+PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
+ PyObject* watch_accessed_variables);
// Removes the passed tape from the set of active tapes.
void TFE_Py_TapeSetRemove(PyObject* tape);
@@ -138,7 +143,7 @@ void TFE_Py_TapeSetAdd(PyObject* tape);
PyObject* TFE_Py_TapeSetIsEmpty();
PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors);
-void TFE_Py_TapeSetWatch(PyObject* tensor);
+void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor);
void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id);
// Stops any gradient recording on the current thread.
@@ -158,18 +163,20 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
PyObject* input_tensor_ids,
PyObject* backward_function);
+// Notifies all tapes that a variable has been accessed.
+void TFE_Py_TapeVariableAccessed(PyObject* variable);
+
// Watches the given variable object on the given tape.
-void TFE_Py_TapeSetWatchVariable(PyObject* variable);
+void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable);
// Computes a gradient based on information recorded on the tape.`tape` must
-// have been produced by TFE_Py_NewTape. `vspace` must be a
-// imperative_grad.py:VSpace named tuple. `target` and `sources` must be python
+// have been produced by TFE_Py_NewTape. `target` and `sources` must be python
// lists of Tensor objects. `output_gradients` is either None or a python list
// of either Tensor or None, and if not None should have the same length as
// target.
-PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
- PyObject* target, PyObject* sources,
- PyObject* output_gradients, TF_Status* status);
+PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
+ PyObject* sources, PyObject* output_gradients,
+ TF_Status* status);
// Execute a tensorflow operation assuming that all provided inputs are
// correctly formatted (i.e. EagerTensors). If it doesn't find EagerTensors,
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 18fafd0de1..46dcf7c8a8 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -216,7 +216,7 @@ bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status,
#if PY_MAJOR_VERSION >= 3
if (PyUnicode_Check(py_value)) {
Py_ssize_t size = 0;
- char* buf = PyUnicode_AsUTF8AndSize(py_value, &size);
+ const char* buf = PyUnicode_AsUTF8AndSize(py_value, &size);
if (buf == nullptr) return false;
*value = tensorflow::StringPiece(buf, size);
return true;
@@ -825,7 +825,7 @@ int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
return -1;
}
-char* TFE_GetPythonString(PyObject* o) {
+const char* TFE_GetPythonString(PyObject* o) {
if (PyBytes_Check(o)) {
return PyBytes_AsString(o);
}
@@ -892,9 +892,10 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
class GradientTape
: public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> {
public:
- explicit GradientTape(bool persistent)
+ explicit GradientTape(bool persistent, bool watch_accessed_variables)
: tensorflow::eager::GradientTape<PyObject, PyBackwardFunction>(
- persistent) {}
+ persistent),
+ watch_accessed_variables_(watch_accessed_variables) {}
virtual ~GradientTape() {
for (const IdAndVariable& v : watched_variables_) {
@@ -902,6 +903,12 @@ class GradientTape
}
}
+ void VariableAccessed(PyObject* v) {
+ if (watch_accessed_variables_) {
+ WatchVariable(v);
+ }
+ }
+
void WatchVariable(PyObject* v) {
tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
if (handle == nullptr) {
@@ -951,6 +958,7 @@ class GradientTape
}
};
+ bool watch_accessed_variables_;
tensorflow::mutex watched_variables_mu_;
std::set<IdAndVariable, CompareById> watched_variables_
GUARDED_BY(watched_variables_mu_);
@@ -1056,11 +1064,13 @@ void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }
void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }
-PyObject* TFE_Py_TapeSetNew(PyObject* persistent) {
+PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
+ PyObject* watch_accessed_variables) {
TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
- tape->tape = new GradientTape(persistent == Py_True);
+ tape->tape = new GradientTape(persistent == Py_True,
+ watch_accessed_variables == Py_True);
Py_INCREF(tape);
GetTapeSet()->insert(reinterpret_cast<TFE_Py_Tape*>(tape));
return reinterpret_cast<PyObject*>(tape);
@@ -1154,7 +1164,7 @@ PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
Py_RETURN_FALSE;
}
-void TFE_Py_TapeSetWatch(PyObject* tensor) {
+void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
if (*ThreadTapeIsStopped()) {
return;
}
@@ -1162,9 +1172,7 @@ void TFE_Py_TapeSetWatch(PyObject* tensor) {
if (PyErr_Occurred()) {
return;
}
- for (TFE_Py_Tape* tape : *GetTapeSet()) {
- tape->tape->Watch(tensor_id);
- }
+ reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
}
static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
@@ -1235,15 +1243,22 @@ std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
return list;
}
-void TFE_Py_TapeSetWatchVariable(PyObject* variable) {
+void TFE_Py_TapeVariableAccessed(PyObject* variable) {
if (*ThreadTapeIsStopped()) {
return;
}
for (TFE_Py_Tape* tape : SafeTapeSet()) {
- tape->tape->WatchVariable(variable);
+ tape->tape->VariableAccessed(variable);
}
}
+void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) {
+ if (*ThreadTapeIsStopped()) {
+ return;
+ }
+ reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable);
+}
+
PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
}
@@ -1350,7 +1365,9 @@ void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
class PyVSpace
: public tensorflow::eager::VSpace<PyObject, PyBackwardFunction> {
public:
- explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {}
+ explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
+ Py_INCREF(py_vspace_);
+ }
tensorflow::Status Initialize() {
num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
@@ -1378,6 +1395,8 @@ class PyVSpace
Py_XDECREF(aggregate_fn_);
Py_XDECREF(zeros_);
Py_XDECREF(ones_);
+
+ Py_DECREF(py_vspace_);
}
tensorflow::int64 NumElements(PyObject* tensor) const final {
@@ -1493,6 +1512,22 @@ class PyVSpace
PyObject* zeros_;
PyObject* ones_;
};
+PyVSpace* py_vspace = nullptr;
+
+PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
+ if (py_vspace != nullptr) {
+ delete py_vspace;
+ }
+
+ py_vspace = new PyVSpace(e);
+ auto status = py_vspace->Initialize();
+ if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
+ delete py_vspace;
+ return nullptr;
+ }
+
+ Py_RETURN_NONE;
+}
std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
@@ -1509,9 +1544,9 @@ std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
return list;
}
-PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
- PyObject* target, PyObject* sources,
- PyObject* output_gradients, TF_Status* status) {
+PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
+ PyObject* sources, PyObject* output_gradients,
+ TF_Status* status) {
TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
if (!tape_obj->tape->IsPersistent()) {
auto* tape_set = GetTapeSet();
@@ -1526,10 +1561,6 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
return nullptr;
}
}
- PyVSpace c_vspace(vspace);
- if (!c_vspace.Initialize().ok()) {
- return nullptr;
- }
std::vector<tensorflow::int64> target_vec = MakeTensorIDList(target);
if (PyErr_Occurred()) {
@@ -1553,7 +1584,7 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
}
std::vector<PyObject*> result;
status->status = tape_obj->tape->ComputeGradient(
- c_vspace, target_vec, sources_vec, outgrad_vec, &result);
+ *py_vspace, target_vec, sources_vec, outgrad_vec, &result);
if (!status->status.ok()) {
if (PyErr_Occurred()) {
// Do not propagate the erroneous status as that would swallow the
@@ -1709,118 +1740,169 @@ PyObject* MaybeGetDTypeForAttr(const string& attr,
Py_RETURN_NONE;
}
-bool OpDoesntRequireOutput(const string& op_name) {
- static tensorflow::gtl::FlatSet<string>* ops_that_dont_require_outputs =
- new tensorflow::gtl::FlatSet<string>({
- "Identity",
- "MatMul",
- "Conv2DBackpropInput",
- "Conv2DBackpropFilter",
- "Conv3D",
- "Conv3DBackpropInputV2",
- "AvgPool3D",
- "AvgPool3DGrad",
- "MaxPool3D",
- "MaxPool3DGrad",
- "MaxPool3DGradGrad",
- "BiasAdd",
- "BiasAddV1",
- "BiasAddGrad",
- "Softplus",
- "SoftplusGrad",
- "Softsign",
- "ReluGrad",
- "LeakyRelu",
- "LeakyReluGrad",
- "Conv2D",
- "DepthwiseConv2dNative",
- "Dilation2D",
- "AvgPool",
- "AvgPoolGrad",
- "BatchNormWithGlobalNormalization",
- "L2Loss",
- "Sum",
- "Prod",
- "SegmentSum",
- "SegmentMean",
- "SparseSegmentSum",
- "SparseSegmentMean",
- "SparseSegmentSqrtN",
- "SegmentMin",
- "SegmentMax",
- "UnsortedSegmentSum",
- "UnsortedSegmentMax",
- "Abs",
- "Neg",
- "ReciprocalGrad",
- "Square",
- "Expm1",
- "Log",
- "Log1p",
- "TanhGrad",
- "SigmoidGrad",
- "Sign",
- "Sin",
- "Cos",
- "Tan",
- "Add",
- "Sub",
- "Mul",
- "Div",
- "RealDiv",
- "Maximum",
- "Minimum",
- "SquaredDifference",
- "Select",
- "SparseMatMul",
- "BatchMatMul",
- "Complex",
- "Real",
- "Imag",
- "Angle",
- "Conj",
- "Cast",
- "Cross",
- "Cumsum",
- "Cumprod",
- "ReadVariableOp",
- "VarHandleOp",
- "Shape",
+// Looks up whether the gradient computation for `op_name` can ignore outputs.
+// On a hit, sets `*output` to a pair whose first value indicates whether all
+// outputs are unused; if the first value is false, the second value is a set
+// that identifies which of the output indices are unused.
+bool OpGradientDoesntRequireOutputIndices(
+ const string& op_name,
+ std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) {
+ static tensorflow::gtl::FlatMap<
+ string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m =
+ new tensorflow::gtl::FlatMap<
+ string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({
+ // Ops that don't require any outputs.
+ {"Identity", {true, {}}},
+ {"MatMul", {true, {}}},
+ {"Conv2DBackpropInput", {true, {}}},
+ {"Conv2DBackpropFilter", {true, {}}},
+ {"Conv3D", {true, {}}},
+ {"Conv3DBackpropInputV2", {true, {}}},
+ {"AvgPool3D", {true, {}}},
+ {"AvgPool3DGrad", {true, {}}},
+ {"MaxPool3D", {true, {}}},
+ {"MaxPool3DGrad", {true, {}}},
+ {"MaxPool3DGradGrad", {true, {}}},
+ {"BiasAdd", {true, {}}},
+ {"BiasAddV1", {true, {}}},
+ {"BiasAddGrad", {true, {}}},
+ {"Softplus", {true, {}}},
+ {"SoftplusGrad", {true, {}}},
+ {"Softsign", {true, {}}},
+ {"ReluGrad", {true, {}}},
+ {"LeakyRelu", {true, {}}},
+ {"LeakyReluGrad", {true, {}}},
+ {"Conv2D", {true, {}}},
+ {"DepthwiseConv2dNative", {true, {}}},
+ {"Dilation2D", {true, {}}},
+ {"AvgPool", {true, {}}},
+ {"AvgPoolGrad", {true, {}}},
+ {"BatchNormWithGlobalNormalization", {true, {}}},
+ {"L2Loss", {true, {}}},
+ {"Sum", {true, {}}},
+ {"Prod", {true, {}}},
+ {"SegmentSum", {true, {}}},
+ {"SegmentMean", {true, {}}},
+ {"SparseSegmentSum", {true, {}}},
+ {"SparseSegmentMean", {true, {}}},
+ {"SparseSegmentSqrtN", {true, {}}},
+ {"SegmentMin", {true, {}}},
+ {"SegmentMax", {true, {}}},
+ {"UnsortedSegmentSum", {true, {}}},
+ {"UnsortedSegmentMax", {true, {}}},
+ {"Abs", {true, {}}},
+ {"Neg", {true, {}}},
+ {"ReciprocalGrad", {true, {}}},
+ {"Square", {true, {}}},
+ {"Expm1", {true, {}}},
+ {"Log", {true, {}}},
+ {"Log1p", {true, {}}},
+ {"TanhGrad", {true, {}}},
+ {"SigmoidGrad", {true, {}}},
+ {"Sign", {true, {}}},
+ {"Sin", {true, {}}},
+ {"Cos", {true, {}}},
+ {"Tan", {true, {}}},
+ {"Add", {true, {}}},
+ {"Sub", {true, {}}},
+ {"Mul", {true, {}}},
+ {"Div", {true, {}}},
+ {"RealDiv", {true, {}}},
+ {"Maximum", {true, {}}},
+ {"Minimum", {true, {}}},
+ {"SquaredDifference", {true, {}}},
+ {"Select", {true, {}}},
+ {"SparseMatMul", {true, {}}},
+ {"BatchMatMul", {true, {}}},
+ {"Complex", {true, {}}},
+ {"Real", {true, {}}},
+ {"Imag", {true, {}}},
+ {"Angle", {true, {}}},
+ {"Conj", {true, {}}},
+ {"Cast", {true, {}}},
+ {"Cross", {true, {}}},
+ {"Cumsum", {true, {}}},
+ {"Cumprod", {true, {}}},
+ {"ReadVariableOp", {true, {}}},
+ {"VarHandleOp", {true, {}}},
+ {"Shape", {true, {}}},
+ {"StridedSlice", {true, {}}},
+ {"Fill", {true, {}}},
+
+          // Ops that don't require some of their outputs (the listed indices).
+ {"FusedBatchNorm", {false, {0, 1, 2}}},
});
- return ops_that_dont_require_outputs->find(op_name) !=
- ops_that_dont_require_outputs->end();
-}
-
-bool OpDoesntRequireInput(const string& op_name) {
- static tensorflow::gtl::FlatSet<string>* ops_that_dont_require_inputs =
- new tensorflow::gtl::FlatSet<string>({
- "Identity",
- "Softmax",
- "LogSoftmax",
- "BiasAdd",
- "Relu",
- "Relu6",
- "Elu",
- "Selu",
- "SparseSoftmaxCrossEntropyWithLogits",
- "Neg",
- "Inv",
- "Reciprocal",
- "Sqrt",
- "Exp",
- "Tanh",
- "Sigmoid",
- "Real",
- "Imag",
- "Conj",
- "ReadVariableOp",
- "VarHandleOp",
- "Shape",
+ auto it = m->find(op_name);
+
+ if (it == m->end()) return false;
+
+ *output = &it->second;
+ return true;
+}
+
+// Looks up whether the gradient computation for `op_name` can ignore inputs.
+// On a hit, sets `*output` to a pair whose first value indicates whether all
+// inputs are unused; if the first value is false, the second value is a set
+// that identifies which of the input indices are unused.
+bool OpGradientDoesntRequireInputIndices(
+ const string& op_name,
+ std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) {
+ static tensorflow::gtl::FlatMap<
+ string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m =
+ new tensorflow::gtl::FlatMap<
+ string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({
+ // Ops that don't require any inputs.
+ {"Identity", {true, {}}},
+ {"Softmax", {true, {}}},
+ {"LogSoftmax", {true, {}}},
+ {"BiasAdd", {true, {}}},
+ {"Relu", {true, {}}},
+ {"Relu6", {true, {}}},
+ {"Elu", {true, {}}},
+ {"Selu", {true, {}}},
+ {"SparseSoftmaxCrossEntropyWithLogits", {true, {}}},
+ {"Neg", {true, {}}},
+ {"Inv", {true, {}}},
+ {"Reciprocal", {true, {}}},
+ {"Sqrt", {true, {}}},
+ {"Exp", {true, {}}},
+ {"Tanh", {true, {}}},
+ {"Sigmoid", {true, {}}},
+ {"Real", {true, {}}},
+ {"Imag", {true, {}}},
+ {"Conj", {true, {}}},
+ {"ReadVariableOp", {true, {}}},
+ {"VarHandleOp", {true, {}}},
+ {"Shape", {true, {}}},
+ {"Fill", {true, {}}},
+
+          // Ops that don't require some of their inputs (the listed indices).
+ {"FusedBatchNorm", {false, {2}}},
});
- return ops_that_dont_require_inputs->find(op_name) !=
- ops_that_dont_require_inputs->end();
+ auto it = m->find(op_name);
+
+ if (it == m->end()) return false;
+
+ *output = &it->second;
+ return true;
+}
+
+PyObject* CopySequenceSettingIndicesToNull(
+ PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) {
+ tensorflow::Safe_PyObjectPtr fast_seq(
+ PySequence_Fast(seq, "unable to allocate"));
+ PyObject* result = PyTuple_New(PySequence_Fast_GET_SIZE(fast_seq.get()));
+ for (int i = 0; i < PySequence_Fast_GET_SIZE(fast_seq.get()); i++) {
+ PyObject* item;
+ if (indices.find(i) != indices.end()) {
+ item = Py_None;
+ } else {
+ item = PySequence_Fast_GET_ITEM(fast_seq.get(), i);
+ }
+ Py_INCREF(item);
+ PyTuple_SET_ITEM(result, i, item);
+ }
+ return result;
}
PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
@@ -1840,16 +1922,35 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
if (!should_record) Py_RETURN_NONE;
string c_op_name = TFE_GetPythonString(op_name);
+
PyObject* op_outputs;
- if (OpDoesntRequireOutput(c_op_name)) {
- op_outputs = Py_None;
+ bool op_outputs_tuple_created = false;
+ std::pair<bool, tensorflow::gtl::FlatSet<int>>* outputs_not_required;
+
+ if (OpGradientDoesntRequireOutputIndices(c_op_name, &outputs_not_required)) {
+ if (outputs_not_required->first) {
+ op_outputs = Py_None;
+ } else {
+ op_outputs_tuple_created = true;
+ op_outputs = CopySequenceSettingIndicesToNull(
+ results, outputs_not_required->second);
+ }
} else {
op_outputs = results;
}
PyObject* op_inputs;
- if (OpDoesntRequireInput(c_op_name)) {
- op_inputs = Py_None;
+ bool op_inputs_tuple_created = false;
+ std::pair<bool, tensorflow::gtl::FlatSet<int>>* inputs_not_required;
+
+ if (OpGradientDoesntRequireInputIndices(c_op_name, &inputs_not_required)) {
+ if (inputs_not_required->first) {
+ op_inputs = Py_None;
+ } else {
+ op_inputs_tuple_created = true;
+ op_inputs =
+ CopySequenceSettingIndicesToNull(inputs, inputs_not_required->second);
+ }
} else {
op_inputs = inputs;
}
@@ -1892,18 +1993,20 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
});
Py_DECREF(num_inputs);
+ if (op_outputs_tuple_created) Py_DECREF(op_outputs);
+ if (op_inputs_tuple_created) Py_DECREF(op_inputs);
Py_RETURN_NONE;
}
-void MaybeWatchVariable(PyObject* input) {
+void MaybeNotifyVariableAccessed(PyObject* input) {
DCHECK(CheckResourceVariable(input));
DCHECK(PyObject_HasAttrString(input, "_trainable"));
tensorflow::Safe_PyObjectPtr trainable(
PyObject_GetAttrString(input, "_trainable"));
if (trainable.get() == Py_False) return;
- TFE_Py_TapeSetWatchVariable(input);
+ TFE_Py_TapeVariableAccessed(input);
}
bool CastTensor(const FastPathOpExecInfo& op_exec_info,
@@ -1934,7 +2037,7 @@ bool CastTensor(const FastPathOpExecInfo& op_exec_info,
bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
PyObject* input, tensorflow::Safe_PyObjectPtr* output,
TF_Status* status) {
- MaybeWatchVariable(input);
+ MaybeNotifyVariableAccessed(input);
TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status);
auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); });
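
The FlatSet-to-FlatMap rewrite above generalizes "this op's gradient needs no outputs" to "this op's gradient needs none of the outputs at these indices", and CopySequenceSettingIndicesToNull nulls out exactly those slots so the corresponding tensors can be released early. A Python sketch of the lookup-and-prune logic, with hypothetical tensor names:

    # Maps op name -> (all_unused, unused_indices), mirroring the C++ FlatMap.
    GRADIENT_UNUSED_OUTPUTS = {
        "Identity": (True, frozenset()),
        "FusedBatchNorm": (False, frozenset({0, 1, 2})),
    }


    def prune_outputs_for_gradient(op_name, outputs):
      """Replaces outputs the gradient never reads with None."""
      entry = GRADIENT_UNUSED_OUTPUTS.get(op_name)
      if entry is None:
        return tuple(outputs)  # op not in the map: keep everything
      all_unused, unused_indices = entry
      if all_unused:
        return None            # the old FlatSet behavior: keep nothing
      return tuple(None if i in unused_indices else out
                   for i, out in enumerate(outputs))


    assert prune_outputs_for_gradient("Identity", ("y",)) is None
    assert prune_outputs_for_gradient(
        "FusedBatchNorm", ("y", "mean", "var", "r1", "r2")) == (
            None, None, None, "r1", "r2")
    assert prune_outputs_for_gradient("SomeOtherOp", ("y",)) == ("y",)
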
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index caa217b70c..399d90223c 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -33,9 +33,10 @@ class Tape(object):
return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(self._tape)
-def push_new_tape(persistent=False):
+def push_new_tape(persistent=False, watch_accessed_variables=True):
"""Pushes a new tape onto the tape stack."""
- tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent)
+ tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent,
+ watch_accessed_variables)
return Tape(tape)
@@ -44,22 +45,19 @@ def push_tape(tape):
pywrap_tensorflow.TFE_Py_TapeSetAdd(tape._tape) # pylint: disable=protected-access
-def watch(tensor):
- """Marks this tensor to be watched by all tapes in the stack.
+def watch(tape, tensor):
+ """Marks this tensor to be watched by the given tape."""
+ pywrap_tensorflow.TFE_Py_TapeWatch(tape._tape, tensor) # pylint: disable=protected-access
- Args:
- tensor: tensor to be watched.
- """
- pywrap_tensorflow.TFE_Py_TapeSetWatch(tensor)
+def watch_variable(tape, variable):
+ """Marks this variable to be watched by the given tape."""
+ pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, variable) # pylint: disable=protected-access
-def watch_variable(variable):
- """Marks this variable to be watched by all tapes in the stack.
- Args:
- variable: variable to be watched.
- """
- pywrap_tensorflow.TFE_Py_TapeSetWatchVariable(variable)
+def variable_accessed(variable):
+ """Notifies all tapes in the stack that a variable has been accessed."""
+ pywrap_tensorflow.TFE_Py_TapeVariableAccessed(variable)
def pop_tape(tape):
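
These tape.py changes are the Python surface of per-tape watching: watch() and watch_variable() now target a single tape, while variable_accessed() notifies every active tape built with watch_accessed_variables=True. A sketch of the intended behavior at the public GradientTape level, assuming a build that already includes this change:

    import tensorflow as tf

    tf.enable_eager_execution()  # 2018-era TF 1.x style

    x = tf.constant(3.0)
    v = tf.Variable(2.0)

    # With watch_accessed_variables=False, nothing is recorded unless this
    # tape is told to watch it (TFE_Py_TapeWatch / TFE_Py_TapeWatchVariable).
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      tape.watch(v)
      y = v * x  # x was never watched, so its gradient comes back as None

    dv, dx = tape.gradient(y, [v, x])
    print(dv.numpy())  # 3.0 (dy/dv == x)
    print(dx)          # None
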
diff --git a/tensorflow/python/eager/tape_test.py b/tensorflow/python/eager/tape_test.py
index 4326d5efa3..acd0e569f1 100644
--- a/tensorflow/python/eager/tape_test.py
+++ b/tensorflow/python/eager/tape_test.py
@@ -72,7 +72,7 @@ class TapeTest(test.TestCase):
a = constant_op.constant([[1., 0.], [0., 1.]])
b = constant_op.constant([[1., 2.], [3., 4.]])
da, db = backprop.gradients_function(fn, [0, 1])(a, b)
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
tf_a = constant_op.constant([[1, 0], [0, 1]], dtype=dtypes.float32)
tf_b = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32)
tf_c = tf_a + tf_b
@@ -135,7 +135,7 @@ class TapeTest(test.TestCase):
a = constant_op.constant([[1., 0.], [0., 1.]])
b = constant_op.constant([[1., 2.], [3., 4.]])
da, db = backprop.gradients_function(fn, [0, 1])(a, b)
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
tf_a = constant_op.constant([[1, 0], [0, 1]], dtype=dtypes.float32)
tf_b = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32)
tf_mm = math_ops.matmul(tf_a, tf_b)
diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py
index 871136e2c8..344a9b25bd 100644
--- a/tensorflow/python/eager/tensor_test.py
+++ b/tensorflow/python/eager/tensor_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
def _create_tensor(value, device=None, dtype=None):
@@ -295,6 +296,7 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
def testFloatTensor(self):
self.assertEqual(dtypes.float64, _create_tensor(np.float64()).dtype)
self.assertEqual(dtypes.float32, _create_tensor(np.float32()).dtype)
+ self.assertEqual(dtypes.float16, _create_tensor(np.float16()).dtype)
self.assertEqual(dtypes.float32, _create_tensor(0.0).dtype)
def testSliceDimOutOfRange(self):
@@ -332,6 +334,19 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
"but tensor at index 2 has rank 0"):
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2, t1, t3], 0)
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testTensorDir(self):
+ t = array_ops.zeros(1)
+ t.test_attr = "Test"
+
+ instance_dir = dir(t)
+ type_dir = dir(ops.EagerTensor)
+
+ # Monkey patched attributes should show up in dir(t)
+ self.assertIn("test_attr", instance_dir)
+ instance_dir.remove("test_attr")
+ self.assertEqual(instance_dir, type_dir)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 817c8e6848..4001ffdd6b 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -211,6 +211,9 @@ py_test(
shard_count = 2,
srcs_version = "PY2AND3",
tags = [
+ "manual",
+ "no_oss",
+ "notap",
"optonly",
],
deps = [
@@ -682,7 +685,6 @@ py_test(
srcs_version = "PY2AND3",
tags = [
"no_windows",
- "notsan",
],
deps = [
":keras",
diff --git a/tensorflow/python/estimator/canned/baseline_test.py b/tensorflow/python/estimator/canned/baseline_test.py
index e46a3a156d..1df7216ba6 100644
--- a/tensorflow/python/estimator/canned/baseline_test.py
+++ b/tensorflow/python/estimator/canned/baseline_test.py
@@ -42,13 +42,13 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import checkpoint_utils
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import optimizer
from tensorflow.python.training import queue_runner
@@ -490,7 +490,7 @@ class BaselineRegressorTrainingTest(test.TestCase):
self.assertEquals(0, loss.shape.ndims)
if expected_loss is None:
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
assert_loss = assert_close(
math_ops.to_float(expected_loss, name='expected'),
@@ -498,7 +498,7 @@ class BaselineRegressorTrainingTest(test.TestCase):
name='assert_loss')
with ops.control_dependencies((assert_loss,)):
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
mock_optimizer = test.mock.NonCallableMock(
@@ -693,13 +693,13 @@ class BaselineClassifierTrainingTest(test.TestCase):
# Verify loss. We can't check the value directly, so we add an assert op.
self.assertEquals(0, loss.shape.ndims)
if expected_loss is None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
assert_loss = assert_close(
math_ops.to_float(expected_loss, name='expected'),
loss,
name='assert_loss')
with ops.control_dependencies((assert_loss,)):
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
mock_optimizer = test.mock.NonCallableMock(
spec=optimizer.Optimizer,
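
The removed distribute_lib.increment_var() helper is replaced by a plain assign_add on the global step, here and in boosted_trees.py below. A minimal TF 1.x graph-mode sketch of the replacement:

    import tensorflow as tf  # assumed: TF 1.x graph mode

    global_step = tf.train.get_or_create_global_step()
    # Replacement for the removed distribute_lib.increment_var(global_step):
    increment_global = tf.assign_add(global_step, 1).op

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(increment_global)
      print(sess.run(global_step))  # 1
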
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 16928ca4b7..19f18015e4 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -38,7 +38,6 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.util.tf_export import estimator_export
@@ -404,18 +403,21 @@ class _EnsembleGrower(object):
training_ops.append(grow_op)
"""
- def __init__(self, tree_ensemble, tree_hparams):
+ def __init__(self, tree_ensemble, tree_hparams, feature_ids_list):
"""Initializes a grower object.
Args:
tree_ensemble: A TreeEnsemble variable.
tree_hparams: TODO. collections.namedtuple for hyper parameters.
+ feature_ids_list: a list of lists of feature ids for each bucket size.
+
Raises:
ValueError: when pruning mode is invalid or pruning is used and no tree
complexity is set.
"""
self._tree_ensemble = tree_ensemble
self._tree_hparams = tree_hparams
+ self._feature_ids_list = feature_ids_list
# pylint: disable=protected-access
self._pruning_mode_parsed = boosted_trees_ops.PruningMode.from_str(
tree_hparams.pruning_mode)
@@ -440,14 +442,12 @@ class _EnsembleGrower(object):
"""
@abc.abstractmethod
- def grow_tree(self, stats_summaries_list, feature_ids_list,
- last_layer_nodes_range):
+ def grow_tree(self, stats_summaries_list, last_layer_nodes_range):
"""Grows a tree, if ready, based on provided statistics.
Args:
stats_summaries_list: List of stats summary tensors, representing sums of
gradients and hessians for each feature bucket.
- feature_ids_list: a list of lists of feature ids for each bucket size.
last_layer_nodes_range: A tensor representing ids of the nodes in the
current layer, to be split.
@@ -455,6 +455,10 @@ class _EnsembleGrower(object):
An op for growing a tree.
"""
+ def chief_init_op(self):
+ """Ops that chief needs to run to initialize the state."""
+ return control_flow_ops.no_op()
+
# ============= Helper methods ===========
def _center_bias_fn(self, center_bias_var, mean_gradients, mean_hessians):
@@ -468,7 +472,7 @@ class _EnsembleGrower(object):
return center_bias_var.assign(continue_centering)
def _grow_tree_from_stats_summaries(self, stats_summaries_list,
- feature_ids_list, last_layer_nodes_range):
+ last_layer_nodes_range):
"""Updates ensemble based on the best gains from stats summaries."""
node_ids_per_feature = []
gains_list = []
@@ -476,11 +480,11 @@ class _EnsembleGrower(object):
left_node_contribs_list = []
right_node_contribs_list = []
all_feature_ids = []
- assert len(stats_summaries_list) == len(feature_ids_list)
+ assert len(stats_summaries_list) == len(self._feature_ids_list)
max_splits = _get_max_splits(self._tree_hparams)
- for i, feature_ids in enumerate(feature_ids_list):
+ for i, feature_ids in enumerate(self._feature_ids_list):
(numeric_node_ids_per_feature, numeric_gains_list,
numeric_thresholds_list, numeric_left_node_contribs_list,
numeric_right_node_contribs_list) = (
@@ -516,12 +520,13 @@ class _EnsembleGrower(object):
class _InMemoryEnsembleGrower(_EnsembleGrower):
- """A base class for ensemble growers."""
+ """An in-memory ensemble grower."""
- def __init__(self, tree_ensemble, tree_hparams):
+ def __init__(self, tree_ensemble, tree_hparams, feature_ids_list):
super(_InMemoryEnsembleGrower, self).__init__(
- tree_ensemble=tree_ensemble, tree_hparams=tree_hparams)
+ tree_ensemble=tree_ensemble, tree_hparams=tree_hparams,
+ feature_ids_list=feature_ids_list)
def center_bias(self, center_bias_var, gradients, hessians):
# For in memory, we already have a full batch of gradients and hessians,
@@ -531,83 +536,98 @@ class _InMemoryEnsembleGrower(_EnsembleGrower):
mean_heassians = array_ops.expand_dims(math_ops.reduce_mean(hessians, 0), 0)
return self._center_bias_fn(center_bias_var, mean_gradients, mean_heassians)
- def grow_tree(self, stats_summaries_list, feature_ids_list,
- last_layer_nodes_range):
+ def grow_tree(self, stats_summaries_list, last_layer_nodes_range):
# For in memory, we already have full data in one batch, so we can grow the
# tree immediately.
return self._grow_tree_from_stats_summaries(
- stats_summaries_list, feature_ids_list, last_layer_nodes_range)
+ stats_summaries_list, last_layer_nodes_range)
class _AccumulatorEnsembleGrower(_EnsembleGrower):
- """A base class for ensemble growers."""
+ """An accumulator based ensemble grower."""
def __init__(self, tree_ensemble, tree_hparams, stamp_token,
- n_batches_per_layer, bucket_size_list, is_chief):
+ n_batches_per_layer, bucket_size_list, is_chief, center_bias,
+ feature_ids_list):
super(_AccumulatorEnsembleGrower, self).__init__(
- tree_ensemble=tree_ensemble, tree_hparams=tree_hparams)
+ tree_ensemble=tree_ensemble, tree_hparams=tree_hparams,
+ feature_ids_list=feature_ids_list)
self._stamp_token = stamp_token
self._n_batches_per_layer = n_batches_per_layer
self._bucket_size_list = bucket_size_list
self._is_chief = is_chief
+ self._growing_accumulators = []
+ self._chief_init_ops = []
+ max_splits = _get_max_splits(self._tree_hparams)
+ for i, feature_ids in enumerate(self._feature_ids_list):
+ accumulator = data_flow_ops.ConditionalAccumulator(
+ dtype=dtypes.float32,
+ # The stats consist of grads and hessians (the last dimension).
+ shape=[len(feature_ids), max_splits, self._bucket_size_list[i], 2],
+ shared_name='numeric_stats_summary_accumulator_' + str(i))
+ self._chief_init_ops.append(
+ accumulator.set_global_step(self._stamp_token))
+ self._growing_accumulators.append(accumulator)
+ self._center_bias = center_bias
+ if center_bias:
+ self._bias_accumulator = data_flow_ops.ConditionalAccumulator(
+ dtype=dtypes.float32,
+ # The stats consist of grads and hessians means only.
+ # TODO(nponomareva): this will change for a multiclass
+ shape=[2, 1],
+ shared_name='bias_accumulator')
+ self._chief_init_ops.append(
+ self._bias_accumulator.set_global_step(self._stamp_token))
def center_bias(self, center_bias_var, gradients, hessians):
     # For the not-in-memory situation, we need to accumulate enough batches
     # first before proceeding with centering the bias.
# Create an accumulator.
+ if not self._center_bias:
+ raise RuntimeError('center_bias called but bias centering is disabled.')
bias_dependencies = []
- bias_accumulator = data_flow_ops.ConditionalAccumulator(
- dtype=dtypes.float32,
- # The stats consist of grads and hessians means only.
- # TODO(nponomareva): this will change for a multiclass
- shape=[2, 1],
- shared_name='bias_accumulator')
-
grads_and_hess = array_ops.stack([gradients, hessians], axis=0)
grads_and_hess = math_ops.reduce_mean(grads_and_hess, axis=1)
- apply_grad = bias_accumulator.apply_grad(grads_and_hess, self._stamp_token)
+ apply_grad = self._bias_accumulator.apply_grad(
+ grads_and_hess, self._stamp_token)
bias_dependencies.append(apply_grad)
# Center bias if enough batches were processed.
with ops.control_dependencies(bias_dependencies):
if not self._is_chief:
return control_flow_ops.no_op()
+ def _set_accumulators_stamp():
+ return control_flow_ops.group(
+ [acc.set_global_step(self._stamp_token + 1) for acc in
+ self._growing_accumulators])
def center_bias_from_accumulator():
- accumulated = array_ops.unstack(bias_accumulator.take_grad(1), axis=0)
- return self._center_bias_fn(center_bias_var,
- array_ops.expand_dims(accumulated[0], 0),
- array_ops.expand_dims(accumulated[1], 0))
+ accumulated = array_ops.unstack(self._bias_accumulator.take_grad(1),
+ axis=0)
+ center_bias_op = self._center_bias_fn(
+ center_bias_var,
+ array_ops.expand_dims(accumulated[0], 0),
+ array_ops.expand_dims(accumulated[1], 0))
+ with ops.control_dependencies([center_bias_op]):
+ return control_flow_ops.cond(center_bias_var,
+ control_flow_ops.no_op,
+ _set_accumulators_stamp)
center_bias_op = control_flow_ops.cond(
- math_ops.greater_equal(bias_accumulator.num_accumulated(),
+ math_ops.greater_equal(self._bias_accumulator.num_accumulated(),
self._n_batches_per_layer),
center_bias_from_accumulator,
control_flow_ops.no_op,
name='wait_until_n_batches_for_bias_accumulated')
return center_bias_op
- def grow_tree(self, stats_summaries_list, feature_ids_list,
- last_layer_nodes_range):
- # For not in memory situation, we need to accumulate enough of batches first
- # before proceeding with building a tree layer.
- max_splits = _get_max_splits(self._tree_hparams)
-
- # Prepare accumulators.
- accumulators = []
+ def grow_tree(self, stats_summaries_list, last_layer_nodes_range):
dependencies = []
- for i, feature_ids in enumerate(feature_ids_list):
+ for i in range(len(self._feature_ids_list)):
stats_summaries = stats_summaries_list[i]
- accumulator = data_flow_ops.ConditionalAccumulator(
- dtype=dtypes.float32,
- # The stats consist of grads and hessians (the last dimension).
- shape=[len(feature_ids), max_splits, self._bucket_size_list[i], 2],
- shared_name='numeric_stats_summary_accumulator_' + str(i))
- accumulators.append(accumulator)
-
- apply_grad = accumulator.apply_grad(
+ apply_grad = self._growing_accumulators[i].apply_grad(
array_ops.stack(stats_summaries, axis=0), self._stamp_token)
dependencies.append(apply_grad)
@@ -617,7 +637,8 @@ class _AccumulatorEnsembleGrower(_EnsembleGrower):
return control_flow_ops.no_op()
min_accumulated = math_ops.reduce_min(
- array_ops.stack([acc.num_accumulated() for acc in accumulators]))
+ array_ops.stack([acc.num_accumulated() for acc in
+ self._growing_accumulators]))
def grow_tree_from_accumulated_summaries_fn():
"""Updates tree with the best layer from accumulated summaries."""
@@ -625,10 +646,11 @@ class _AccumulatorEnsembleGrower(_EnsembleGrower):
stats_summaries_list = []
stats_summaries_list = [
array_ops.unstack(accumulator.take_grad(1), axis=0)
- for accumulator in accumulators
+ for accumulator in self._growing_accumulators
]
grow_op = self._grow_tree_from_stats_summaries(
- stats_summaries_list, feature_ids_list, last_layer_nodes_range)
+          stats_summaries_list, last_layer_nodes_range)
return grow_op
grow_model = control_flow_ops.cond(
@@ -638,6 +660,10 @@ class _AccumulatorEnsembleGrower(_EnsembleGrower):
name='wait_until_n_batches_accumulated')
return grow_model
+ def chief_init_op(self):
+ """Ops that chief needs to run to initialize the state."""
+ return control_flow_ops.group(self._chief_init_ops)
+
def _bt_model_fn(
features,
@@ -683,21 +709,7 @@ def _bt_model_fn(
Raises:
ValueError: mode or params are invalid, or features has the wrong type.
"""
- is_single_machine = (config.num_worker_replicas <= 1)
sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name)
- center_bias = tree_hparams.center_bias
-
- if train_in_memory:
- assert n_batches_per_layer == 1, (
- 'When train_in_memory is enabled, input_fn should return the entire '
- 'dataset as a single batch, and n_batches_per_layer should be set as '
- '1.')
- if (not config.is_chief or config.num_worker_replicas > 1 or
- config.num_ps_replicas > 0):
- raise ValueError('train_in_memory is supported only for '
- 'non-distributed training.')
- worker_device = control_flow_ops.no_op().device
- train_op = []
with ops.name_scope(name) as name:
# Prepare.
global_step = training_util.get_or_create_global_step()
@@ -724,6 +736,20 @@ def _bt_model_fn(
logits=logits)
# ============== Training graph ==============
+ center_bias = tree_hparams.center_bias
+ is_single_machine = (config.num_worker_replicas <= 1)
+
+ if train_in_memory:
+ assert n_batches_per_layer == 1, (
+ 'When train_in_memory is enabled, input_fn should return the entire '
+ 'dataset as a single batch, and n_batches_per_layer should be set as '
+ '1.')
+ if (not config.is_chief or config.num_worker_replicas > 1 or
+ config.num_ps_replicas > 0):
+ raise ValueError('train_in_memory is supported only for '
+ 'non-distributed training.')
+ worker_device = control_flow_ops.no_op().device
+ train_op = []
# Extract input features and set up cache for training.
training_state_cache = None
if train_in_memory:
@@ -742,22 +768,6 @@ def _bt_model_fn(
example_ids = features[example_id_column_name]
training_state_cache = _CacheTrainingStatesUsingHashTable(
example_ids, head.logits_dimension)
-
- # Variable that determines whether bias centering is needed.
- center_bias_var = variable_scope.variable(
- initial_value=center_bias, name='center_bias_needed', trainable=False)
- if is_single_machine:
- local_tree_ensemble = tree_ensemble
- ensemble_reload = control_flow_ops.no_op()
- else:
- # Have a local copy of ensemble for the distributed setting.
- with ops.device(worker_device):
- local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
- name=name + '_local', is_local=True)
- # TODO(soroush): Do partial updates if this becomes a bottleneck.
- ensemble_reload = local_tree_ensemble.deserialize(
- *tree_ensemble.serialize())
-
if training_state_cache:
cached_tree_ids, cached_node_ids, cached_logits = (
training_state_cache.lookup())
@@ -770,21 +780,46 @@ def _bt_model_fn(
array_ops.zeros(
[batch_size, head.logits_dimension], dtype=dtypes.float32))
+ if is_single_machine:
+ local_tree_ensemble = tree_ensemble
+ ensemble_reload = control_flow_ops.no_op()
+ else:
+ # Have a local copy of ensemble for the distributed setting.
+ with ops.device(worker_device):
+ local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ name=name + '_local', is_local=True)
+ # TODO(soroush): Do partial updates if this becomes a bottleneck.
+ ensemble_reload = local_tree_ensemble.deserialize(
+ *tree_ensemble.serialize())
with ops.control_dependencies([ensemble_reload]):
(stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
last_layer_nodes_range) = local_tree_ensemble.get_states()
- summary.scalar('ensemble/num_trees', num_trees)
- summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
- summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)
-
partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
tree_ensemble_handle=local_tree_ensemble.resource_handle,
cached_tree_ids=cached_tree_ids,
cached_node_ids=cached_node_ids,
bucketized_features=input_feature_list,
logits_dimension=head.logits_dimension)
- logits = cached_logits + partial_logits
+ logits = cached_logits + partial_logits
+
+ if train_in_memory:
+ grower = _InMemoryEnsembleGrower(tree_ensemble, tree_hparams,
+ feature_ids_list=feature_ids_list)
+ else:
+ grower = _AccumulatorEnsembleGrower(tree_ensemble, tree_hparams,
+ stamp_token, n_batches_per_layer,
+ bucket_size_list, config.is_chief,
+ center_bias=center_bias,
+ feature_ids_list=feature_ids_list)
+
+ summary.scalar('ensemble/num_trees', num_trees)
+ summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
+ summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)
+ # Variable that determines whether bias centering is needed.
+ center_bias_var = variable_scope.variable(
+ initial_value=center_bias, name='center_bias_needed', trainable=False,
+ use_resource=True)
# Create training graph.
def _train_op_fn(loss):
"""Run one training iteration."""
@@ -823,28 +858,24 @@ def _bt_model_fn(
axis=0) for f in feature_ids
]
stats_summaries_list.append(summaries)
-
- if train_in_memory and is_single_machine:
- grower = _InMemoryEnsembleGrower(tree_ensemble, tree_hparams)
+ if center_bias:
+ update_model = control_flow_ops.cond(
+ center_bias_var,
+ functools.partial(
+ grower.center_bias,
+ center_bias_var,
+ gradients,
+ hessians,
+ ),
+ functools.partial(grower.grow_tree, stats_summaries_list,
+ last_layer_nodes_range))
else:
- grower = _AccumulatorEnsembleGrower(tree_ensemble, tree_hparams,
- stamp_token, n_batches_per_layer,
- bucket_size_list, config.is_chief)
-
- update_model = control_flow_ops.cond(
- center_bias_var,
- functools.partial(
- grower.center_bias,
- center_bias_var,
- gradients,
- hessians,
- ),
- functools.partial(grower.grow_tree, stats_summaries_list,
- feature_ids_list, last_layer_nodes_range))
+ update_model = grower.grow_tree(stats_summaries_list,
+ last_layer_nodes_range)
train_op.append(update_model)
with ops.control_dependencies([update_model]):
- increment_global = distribute_lib.increment_var(global_step)
+ increment_global = state_ops.assign_add(global_step, 1).op
train_op.append(increment_global)
return control_flow_ops.group(train_op, name='train_op')
@@ -859,10 +890,22 @@ def _bt_model_fn(
estimator_spec = estimator_spec._replace(
training_hooks=estimator_spec.training_hooks +
(_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers,
- tree_hparams.n_trees, tree_hparams.max_depth),))
+ tree_hparams.n_trees, tree_hparams.max_depth),),
+ training_chief_hooks=[GrowerInitializationHook(grower.chief_init_op())] +
+ list(estimator_spec.training_chief_hooks))
return estimator_spec
+class GrowerInitializationHook(session_run_hook.SessionRunHook):
+  """A SessionRunHook that handles initialization of an `_EnsembleGrower`."""
+
+ def __init__(self, init_op):
+ self._init_op = init_op
+
+ def after_create_session(self, session, coord):
+ session.run(self._init_op)
+
+
def _create_classification_head(n_classes,
weight_column=None,
label_vocabulary=None):
@@ -957,8 +1000,11 @@ class BoostedTreesClassifier(estimator.Estimator):
bucketized_feature_2 = bucketized_column(
numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
+  # Need to see a large portion of the data before we can build a layer; for
+  # example, to use half of the data per layer:
+  # n_batches_per_layer = 0.5 * NUM_EXAMPLES / BATCH_SIZE
classifier = estimator.BoostedTreesClassifier(
feature_columns=[bucketized_feature_1, bucketized_feature_2],
+ n_batches_per_layer=n_batches_per_layer,
n_trees=100,
... <some other params>
)
@@ -981,7 +1027,8 @@ class BoostedTreesClassifier(estimator.Estimator):
the model. All items in the set should be instances of classes derived
from `FeatureColumn`.
n_batches_per_layer: the number of batches to collect statistics per
- layer.
+        layer. The total number of batches is the total number of examples
+        divided by the batch size.
       model_dir: Directory to save model parameters, graph, etc. This can
         also be used to load checkpoints from the directory into an estimator
         to continue training a previously saved model.
@@ -1095,8 +1142,11 @@ class BoostedTreesRegressor(estimator.Estimator):
bucketized_feature_2 = bucketized_column(
numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
+  # Need to see a large portion of the data before we can build a layer; for
+  # example, to use half of the data per layer:
+  # n_batches_per_layer = 0.5 * NUM_EXAMPLES / BATCH_SIZE
regressor = estimator.BoostedTreesRegressor(
feature_columns=[bucketized_feature_1, bucketized_feature_2],
+ n_batches_per_layer=n_batches_per_layer,
n_trees=100,
... <some other params>
)
@@ -1119,7 +1169,8 @@ class BoostedTreesRegressor(estimator.Estimator):
the model. All items in the set should be instances of classes derived
from `FeatureColumn`.
n_batches_per_layer: the number of batches to collect statistics per
- layer.
+ layer. The total number of batches is the total number of data points
+ divided by the batch size.
model_dir: Directory to save model parameters, graph, etc. This can
also be used to load checkpoints from the directory into an estimator
to continue training a previously saved model.
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index ec597e4686..6e28c72151 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -173,6 +173,26 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
eval_res = est.evaluate(input_fn=input_fn, steps=1)
self.assertAllClose(eval_res['accuracy'], 1.0)
+ def testTrainTwiceAndEvaluateBinaryClassifier(self):
+ input_fn = _make_train_input_fn(is_classification=True)
+
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=5,
+ max_depth=10)
+
+ num_steps = 2
+ # Train for a few steps, and validate final checkpoint.
+ est.train(input_fn, steps=num_steps)
+ est.train(input_fn, steps=num_steps)
+
+ self._assert_checkpoint(
+ est.model_dir, global_step=num_steps * 2,
+ finalized_trees=0, attempted_layers=4)
+ eval_res = est.evaluate(input_fn=input_fn, steps=1)
+ self.assertAllClose(eval_res['accuracy'], 1.0)
+
def testInferBinaryClassifier(self):
train_input_fn = _make_train_input_fn(is_classification=True)
predict_input_fn = numpy_io.numpy_input_fn(
@@ -1540,7 +1560,7 @@ class ModelFnTests(test_util.TensorFlowTestCase):
ops.reset_default_graph()
expected_first, expected_second, expected_third = (
self._get_expected_ensembles_for_classification())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Train with train_in_memory mode.
with sess.graph.as_default():
train_op, ensemble_serialized = self._get_train_op_and_ensemble(
@@ -1573,7 +1593,7 @@ class ModelFnTests(test_util.TensorFlowTestCase):
expected_first, expected_second, expected_third, expected_forth = (
self._get_expected_ensembles_for_classification_with_bias())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with sess.graph.as_default():
train_op, ensemble_serialized = self._get_train_op_and_ensemble(
boosted_trees._create_classification_head(n_classes=2),
@@ -1613,7 +1633,7 @@ class ModelFnTests(test_util.TensorFlowTestCase):
ops.reset_default_graph()
expected_first, expected_second, expected_third = (
self._get_expected_ensembles_for_classification())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Train without train_in_memory mode.
with sess.graph.as_default():
train_op, ensemble_serialized = self._get_train_op_and_ensemble(
@@ -1646,7 +1666,7 @@ class ModelFnTests(test_util.TensorFlowTestCase):
expected_first, expected_second, expected_third, expected_forth = (
self._get_expected_ensembles_for_classification_with_bias())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with sess.graph.as_default():
train_op, ensemble_serialized = self._get_train_op_and_ensemble(
boosted_trees._create_classification_head(n_classes=2),
@@ -1684,7 +1704,7 @@ class ModelFnTests(test_util.TensorFlowTestCase):
ops.reset_default_graph()
expected_first, expected_second, expected_third = (
self._get_expected_ensembles_for_regression())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Train with train_in_memory mode.
with sess.graph.as_default():
train_op, ensemble_serialized = self._get_train_op_and_ensemble(
@@ -1714,7 +1734,7 @@ class ModelFnTests(test_util.TensorFlowTestCase):
ops.reset_default_graph()
expected_first, expected_second, expected_third, expected_forth = (
self._get_expected_ensembles_for_regression_with_bias())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Train with train_in_memory mode.
with sess.graph.as_default():
train_op, ensemble_serialized = self._get_train_op_and_ensemble(
@@ -1754,7 +1774,7 @@ class ModelFnTests(test_util.TensorFlowTestCase):
ops.reset_default_graph()
expected_first, expected_second, expected_third = (
self._get_expected_ensembles_for_regression())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Train without train_in_memory mode.
with sess.graph.as_default():
train_op, ensemble_serialized = self._get_train_op_and_ensemble(
@@ -1784,7 +1804,7 @@ class ModelFnTests(test_util.TensorFlowTestCase):
ops.reset_default_graph()
expected_first, expected_second, expected_third, expected_forth = (
self._get_expected_ensembles_for_regression_with_bias())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Train with train_in_memory mode.
with sess.graph.as_default():
train_op, ensemble_serialized = self._get_train_op_and_ensemble(
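
The repeated substitution in this test file is mechanical: `cached_session()` is the explicit replacement for the deprecated `test_session()`, returning one session cached per test case rather than implying a fresh session per `with` block. A minimal sketch of how such a test reads (hypothetical test, TF 1.x test API assumed):

    import tensorflow as tf

    class CachedSessionExample(tf.test.TestCase):

      def testReusesOneSession(self):
        y = tf.constant(2.0) * 3.0
        # Both with-blocks hand back the same cached session, so the
        # rename from test_session() is behavior-preserving here.
        with self.cached_session() as sess:
          self.assertAllClose(6.0, sess.run(y))
        with self.cached_session() as sess:
          self.assertAllClose(6.0, sess.run(y))

    if __name__ == '__main__':
      tf.test.main()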
diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py
index c08cf61220..1c0c4581c0 100644
--- a/tensorflow/python/estimator/canned/dnn.py
+++ b/tensorflow/python/estimator/canned/dnn.py
@@ -142,7 +142,7 @@ def _dnn_model_fn(features,
dropout=None,
input_layer_partitioner=None,
config=None,
- tpu_estimator_spec=False,
+ use_tpu=False,
batch_norm=False):
"""Deep Neural Net model_fn.
@@ -164,8 +164,8 @@ def _dnn_model_fn(features,
input_layer_partitioner: Partitioner for input layer. Defaults
to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
config: `RunConfig` object to configure the runtime settings.
- tpu_estimator_spec: Whether to return a `_TPUEstimatorSpec` or
- or `model_fn.EstimatorSpec` instance.
+ use_tpu: Whether to build a DNN model that can run on TPU. If `True`, the
+ function returns a `_TPUEstimatorSpec` instance and disables variable
+ partitioning.
batch_norm: Whether to use batch normalization after each hidden layer.
Returns:
@@ -182,13 +182,15 @@ def _dnn_model_fn(features,
optimizer, learning_rate=_LEARNING_RATE)
num_ps_replicas = config.num_ps_replicas if config else 0
- partitioner = partitioned_variables.min_max_variable_partitioner(
- max_partitions=num_ps_replicas)
+ partitioner = (None if use_tpu else
+ partitioned_variables.min_max_variable_partitioner(
+ max_partitions=num_ps_replicas))
with variable_scope.variable_scope(
'dnn',
values=tuple(six.itervalues(features)),
partitioner=partitioner):
input_layer_partitioner = input_layer_partitioner or (
+ None if use_tpu else
partitioned_variables.min_max_variable_partitioner(
max_partitions=num_ps_replicas,
min_slice_size=64 << 20))
@@ -203,7 +205,7 @@ def _dnn_model_fn(features,
batch_norm=batch_norm)
logits = logit_fn(features=features, mode=mode)
- if tpu_estimator_spec:
+ if use_tpu:
return head._create_tpu_estimator_spec( # pylint: disable=protected-access
features=features,
mode=mode,
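
The dnn.py change above selects the partitioner with `None if use_tpu else ...`, so TPU runs get no variable partitioning while CPU/GPU runs keep sharding across parameter servers. A hedged standalone sketch of the selection (scope name and shapes are illustrative):

    import tensorflow as tf

    def _choose_partitioner(use_tpu, num_ps_replicas):
      # Mirrors the hunk above: return None under TPU so variable_scope
      # performs no partitioning at all.
      if use_tpu:
        return None
      return tf.min_max_variable_partitioner(max_partitions=num_ps_replicas)

    with tf.variable_scope(
        'dnn_sketch',
        partitioner=_choose_partitioner(use_tpu=False, num_ps_replicas=2)):
      w = tf.get_variable('w', shape=[4, 4])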
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py
index 4945c3ba11..9799cf9e98 100644
--- a/tensorflow/python/estimator/canned/dnn_linear_combined.py
+++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py
@@ -31,10 +31,10 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util
from tensorflow.python.util.tf_export import estimator_export
@@ -161,8 +161,8 @@ def _dnn_linear_combined_model_fn(features,
with variable_scope.variable_scope(
dnn_parent_scope,
values=tuple(six.itervalues(features)),
- partitioner=dnn_partitioner):
-
+ partitioner=dnn_partitioner) as scope:
+ dnn_absolute_scope = scope.name
dnn_logit_fn = dnn._dnn_logit_fn_builder( # pylint: disable=protected-access
units=head.logits_dimension,
hidden_units=dnn_hidden_units,
@@ -186,6 +186,7 @@ def _dnn_linear_combined_model_fn(features,
linear_parent_scope,
values=tuple(six.itervalues(features)),
partitioner=input_layer_partitioner) as scope:
+ linear_absolute_scope = scope.name
logit_fn = linear._linear_logit_fn_builder( # pylint: disable=protected-access
units=head.logits_dimension,
feature_columns=linear_feature_columns,
@@ -211,18 +212,18 @@ def _dnn_linear_combined_model_fn(features,
loss,
var_list=ops.get_collection(
ops.GraphKeys.TRAINABLE_VARIABLES,
- scope=dnn_parent_scope)))
+ scope=dnn_absolute_scope)))
if linear_logits is not None:
train_ops.append(
linear_optimizer.minimize(
loss,
var_list=ops.get_collection(
ops.GraphKeys.TRAINABLE_VARIABLES,
- scope=linear_parent_scope)))
+ scope=linear_absolute_scope)))
train_op = control_flow_ops.group(*train_ops)
with ops.control_dependencies([train_op]):
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return head.create_estimator_spec(
features=features,
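
The scope fix above matters because `ops.get_collection` matches a regex against full variable names: when the model_fn itself runs inside an outer scope, the bare `dnn_parent_scope` string no longer anchors at the start of those names, and the optimizer would minimize over an empty variable list. Capturing `scope.name` yields the absolute path. A small sketch of the failure and the fix (names made up):

    import tensorflow as tf

    with tf.variable_scope('outer'):
      with tf.variable_scope('dnn') as scope:
        dnn_absolute_scope = scope.name  # 'outer/dnn', not just 'dnn'
        v = tf.get_variable('w', shape=[2])

    # get_collection anchors the pattern at the start of each name, so the
    # relative string misses the variable while the absolute one finds it.
    assert not tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='dnn')
    assert tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                             scope=dnn_absolute_scope) == [v]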
diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py
index de226ed0ef..11f1e93630 100644
--- a/tensorflow/python/estimator/canned/dnn_testing_utils.py
+++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py
@@ -44,13 +44,13 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.summary import summary as summary_lib
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import checkpoint_utils
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import monitored_session
from tensorflow.python.training import optimizer as optimizer_lib
@@ -222,7 +222,7 @@ def mock_optimizer(testcase, hidden_units, expected_loss=None):
testcase.assertEquals(0, loss.shape.ndims)
if expected_loss is None:
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
assert_loss = assert_close(
math_ops.to_float(expected_loss, name='expected'),
@@ -230,7 +230,7 @@ def mock_optimizer(testcase, hidden_units, expected_loss=None):
name='assert_loss')
with ops.control_dependencies((assert_loss,)):
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
optimizer_mock = test.mock.NonCallableMagicMock(
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index da9a64c2bc..06593f9520 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -335,8 +335,8 @@ def _check_dense_labels_match_logits_and_reshape(
'Expected labels dimension=%s. Received %s. '
'Suggested Fix:'
'If your classifier expects one-hot encoding label,'
- 'check your n_classes argument to the estimator'
- 'and/or the shape of your label.'
+ 'check your n_classes argument to the estimator '
+ 'and/or the shape of your label. '
'Otherwise, check the shape of your label.' %
(expected_labels_dimension, dim1))
expected_labels_shape = array_ops.concat(
diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py
index bd2e0ae943..de9c84d2ef 100644
--- a/tensorflow/python/estimator/canned/head_test.py
+++ b/tensorflow/python/estimator/canned/head_test.py
@@ -260,7 +260,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
features={'x': np.array(((30.,), (42.,),))},
mode=model_fn.ModeKeys.PREDICT,
logits=logits_placeholder)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors.OpError, 'logits shape'):
spec.predictions[prediction_keys.PredictionKeys.PROBABILITIES].eval({
logits_placeholder: logits_2x2
@@ -293,7 +293,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits_placeholder,
labels=labels_placeholder)[0]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'\[expected_labels_shape: \] \[2 1\] \[labels_shape: \] \[2 2\]'):
@@ -347,14 +347,14 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits_placeholder,
labels=labels_placeholder)[0]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError('Labels must <= n_classes - 1'):
training_loss.eval({
labels_placeholder: labels_2x1_with_large_id,
logits_placeholder: logits_2x3
})
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError('Labels must >= 0'):
training_loss.eval({
labels_placeholder: labels_2x1_with_negative_id,
@@ -413,7 +413,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits_placeholder,
labels=labels_placeholder)[0]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'\[expected_labels_shape: \] \[2 1\] \[labels_shape: \] \[3 1\]'):
@@ -449,7 +449,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
spec.export_outputs.keys())
# Assert predictions and export_outputs.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
predictions = sess.run(spec.predictions)
@@ -484,7 +484,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
mode=model_fn.ModeKeys.PREDICT,
logits=logits)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertAllEqual(
expected_classes,
@@ -510,7 +510,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
mode=model_fn.ModeKeys.PREDICT,
logits=logits)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
predictions = sess.run(spec.predictions)
self.assertAllClose(logits,
@@ -534,7 +534,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
@@ -561,7 +561,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits_input,
labels=labels_input)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(np.sum(loss), actual_training_loss.eval())
@@ -581,7 +581,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -632,7 +632,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
# Assert predictions, loss, and metrics.
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -698,7 +698,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
# Assert predictions, loss, and metrics.
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -727,7 +727,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
@@ -755,7 +755,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
}
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
@@ -804,7 +804,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
# Assert loss, and metrics.
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -837,7 +837,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
logits=logits,
labels=labels)
tol = 1e-2
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)
@@ -866,7 +866,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
logits=logits,
labels=labels)
tol = 1e-2
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)
@@ -921,7 +921,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
# Assert predictions, loss, train_op, and summaries.
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -962,7 +962,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
optimizer=_Optimizer())
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss, train_result = sess.run((spec.loss, spec.train_op))
self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
@@ -992,7 +992,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
labels=np.array(((1,), (1,)), dtype=np.int64),
train_op_fn=_train_op_fn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
sess.run(spec.train_op)
w_value, t_value = sess.run([w, t])
@@ -1023,7 +1023,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
# Assert summaries.
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
summary_str = sess.run(spec.scaffold.summary_op)
@@ -1064,7 +1064,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
# Assert predictions, loss, train_op, and summaries.
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -1104,7 +1104,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
logits=logits,
labels=labels_rank_1)
tol = 1e-2
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)
@@ -1153,7 +1153,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
# Assert predictions, loss, train_op, and summaries.
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -1183,7 +1183,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
@@ -1211,7 +1211,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
train_op_fn=_train_op_fn)
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss = sess.run(spec.loss)
self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
@@ -1253,7 +1253,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
# Assert predictions, loss, train_op, and summaries.
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -1292,7 +1292,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
logits=logits,
labels=labels)
tol = 1e-2
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)
@@ -1327,7 +1327,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
# Assert predictions, loss, train_op, and summaries.
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss, train_result = sess.run((spec.loss, spec.train_op))
self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
@@ -1353,7 +1353,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
logits=logits,
labels=labels,
train_op_fn=_no_op_train_fn)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -1380,7 +1380,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
logits=logits,
labels=labels,
train_op_fn=_no_op_train_fn)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -1413,7 +1413,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
# Assert predictions, loss, and metrics.
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
@@ -1506,7 +1506,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
features={'x': np.array(((42.,),))},
mode=model_fn.ModeKeys.PREDICT,
logits=logits_placeholder)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors.OpError, 'logits shape'):
spec.predictions[prediction_keys.PredictionKeys.PROBABILITIES].eval({
logits_placeholder: logits_2x2
@@ -1536,7 +1536,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits_placeholder,
labels=labels_placeholder)[0]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'\[expected_labels_shape: \] \[2 1\] \[labels_shape: \] \[2 2\]'):
@@ -1577,7 +1577,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits_placeholder,
labels=labels_placeholder)[0]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'\[expected_labels_shape: \] \[3 1\] \[labels_shape: \] \[2 1\]'):
@@ -1585,7 +1585,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
labels_placeholder: values_2x1,
logits_placeholder: values_3x1
})
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'\[expected_labels_shape: \] \[2 1\] \[labels_shape: \] \[3 1\]'):
@@ -1624,7 +1624,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
_assert_no_hooks(self, spec)
# Assert predictions.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
predictions = sess.run(spec.predictions)
@@ -1660,7 +1660,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.PREDICT,
logits=logits)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertAllEqual(
expected_classes,
@@ -1680,7 +1680,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
@@ -1733,7 +1733,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
_assert_no_hooks(self, spec)
# Assert predictions, loss, and metrics.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -1808,7 +1808,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
}
# Assert predictions, loss, and metrics.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -1832,7 +1832,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(41., training_loss.eval())
@@ -1849,7 +1849,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
logits=logits,
labels=labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -1877,7 +1877,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
@@ -1924,7 +1924,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
}
self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -1957,7 +1957,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(expected_training_loss, training_loss.eval())
self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())
@@ -1983,7 +1983,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(expected_training_loss, training_loss.eval())
self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())
@@ -2011,7 +2011,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits_input,
labels=labels_input)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(np.sum(loss), actual_training_loss.eval())
@@ -2031,7 +2031,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -2086,7 +2086,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
_assert_no_hooks(self, spec)
# Assert predictions, loss, train_op, and summaries.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -2126,7 +2126,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
labels=labels,
optimizer=_Optimizer())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss, train_result = sess.run((spec.loss, spec.train_op))
self.assertAllClose(expected_loss, loss)
@@ -2153,7 +2153,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
labels=np.array(((1,), (1,),), dtype=np.float64),
train_op_fn=_train_op_fn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
sess.run(spec.train_op)
w_value, t_value = sess.run([w, t])
@@ -2182,7 +2182,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
labels=labels,
train_op_fn=_train_op_fn)
# Assert summaries.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
summary_str = sess.run(spec.scaffold.summary_op)
@@ -2227,7 +2227,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
regularization_losses=regularization_losses)
# Assert predictions, loss, train_op, and summaries.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -2254,7 +2254,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'Labels must <= n_classes - 1'):
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
training_loss.eval()
@@ -2277,7 +2277,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
@@ -2309,7 +2309,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
train_op_fn=_train_op_fn)
# Assert predictions, loss, train_op, and summaries.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss, train_result = sess.run((spec.loss, spec.train_op))
self.assertAlmostEqual(expected_loss, loss, delta=1.e-5)
@@ -2334,7 +2334,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
@@ -2360,7 +2360,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
expected_loss = 1.2484322
# Assert loss.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
@@ -2385,7 +2385,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
logits=logits)
# Assert predictions, loss, and metrics.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
predictions = sess.run(spec.predictions)
self.assertAllClose(
@@ -2447,7 +2447,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
# Assert predictions, loss, and metrics.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
@@ -2483,7 +2483,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels_rank_1)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(),
@@ -2531,7 +2531,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
self.assertIsNotNone(spec.train_op)
# Assert predictions, loss, and metrics.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((
@@ -2577,7 +2577,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
self.assertIsNotNone(spec.train_op)
# Assert predictions, loss, and metrics.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((
@@ -2612,7 +2612,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
logits=logits,
labels=labels)
tol = 1e-2
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(),
@@ -2649,7 +2649,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
# Assert predictions, loss, train_op, and summaries.
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss, train_result = sess.run((spec.loss, spec.train_op))
self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
@@ -2675,7 +2675,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
logits=logits,
labels=labels,
train_op_fn=_no_op_train_fn)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -2700,7 +2700,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
logits=logits,
labels=labels,
train_op_fn=_no_op_train_fn)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -2744,7 +2744,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
}
tol = 1e-2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
@@ -2825,7 +2825,7 @@ class RegressionHead(test.TestCase):
features={'x': np.array(((42.,),))},
mode=model_fn.ModeKeys.PREDICT,
logits=logits_placeholder)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors.OpError, 'logits shape'):
spec.predictions[prediction_keys.PredictionKeys.PREDICTIONS].eval({
logits_placeholder: logits_1d
@@ -2857,7 +2857,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits_placeholder,
labels=labels_placeholder)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors.OpError, 'logits shape'):
spec.loss.eval({
labels_placeholder: values_3d,
@@ -2868,7 +2868,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits_placeholder,
labels=labels_placeholder)[0]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'\[expected_labels_shape: \] \[2 3\] \[labels_shape: \] \[2 1\]'):
@@ -2908,7 +2908,7 @@ class RegressionHead(test.TestCase):
logits=logits_placeholder,
labels=labels_placeholder,
train_op_fn=lambda x: x)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors.OpError, 'logits shape'):
spec.loss.eval({
labels_placeholder: values_3d,
@@ -2919,7 +2919,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits_placeholder,
labels=labels_placeholder)[0]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
r'\[expected_labels_shape: \] \[2 3\] \[labels_shape: \] \[2 1\]'):
@@ -2957,7 +2957,7 @@ class RegressionHead(test.TestCase):
_assert_no_hooks(self, spec)
# Assert predictions.
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, spec.scaffold)
self.assertAllClose(logits, spec.predictions[prediction_key].eval())
self.assertAllClose(
@@ -2992,7 +2992,7 @@ class RegressionHead(test.TestCase):
spec.export_outputs.keys())
# Assert predictions.
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, spec.scaffold)
self.assertAllClose(
expected_predictions, spec.predictions[keys.PREDICTIONS].eval())
@@ -3019,7 +3019,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
# loss = [(43-45)^2, (44-41)^2] = [4, 9]
self.assertAllClose(13., training_loss.eval())
@@ -3045,7 +3045,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits_input,
labels=labels_input)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(np.sum(loss), actual_training_loss.eval())
@@ -3064,7 +3064,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -3112,7 +3112,7 @@ class RegressionHead(test.TestCase):
_assert_no_hooks(self, spec)
# Assert predictions, loss, and metrics.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
loss_mean_value_op, loss_mean_update_op = spec.eval_metric_ops[
@@ -3180,7 +3180,7 @@ class RegressionHead(test.TestCase):
}
# Assert predictions, loss, and metrics.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -3212,7 +3212,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(expected_training_loss, training_loss.eval())
self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())
@@ -3237,7 +3237,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(expected_training_loss, training_loss.eval())
self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())
@@ -3294,7 +3294,7 @@ class RegressionHead(test.TestCase):
_assert_no_hooks(self, spec)
# Assert predictions, loss, train_op, and summaries.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
predictions, loss, train_result, summary_str = sess.run((
@@ -3337,7 +3337,7 @@ class RegressionHead(test.TestCase):
labels=labels,
optimizer=_Optimizer())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss, train_result = sess.run((spec.loss, spec.train_op))
self.assertAllClose(expected_loss, loss)
@@ -3364,7 +3364,7 @@ class RegressionHead(test.TestCase):
labels=np.array(((43.,), (44.,),), dtype=np.float64),
train_op_fn=_train_op_fn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
sess.run(spec.train_op)
w_value, t_value = sess.run([w, t])
@@ -3394,7 +3394,7 @@ class RegressionHead(test.TestCase):
train_op_fn=_train_op_fn)
# Assert summaries.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
summary_str = sess.run(spec.scaffold.summary_op)
@@ -3441,7 +3441,7 @@ class RegressionHead(test.TestCase):
regularization_losses=regularization_losses)
# Assert predictions, loss, train_op, and summaries.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
prediction_key = prediction_keys.PredictionKeys.PREDICTIONS
@@ -3487,7 +3487,7 @@ class RegressionHead(test.TestCase):
_assert_no_hooks(self, spec)
# Assert predictions, loss, and metrics.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
loss_mean_value_op, loss_mean_update_op = spec.eval_metric_ops[
@@ -3523,7 +3523,7 @@ class RegressionHead(test.TestCase):
labels=np.array(((35,), (42,), (45,)), dtype=np.int32))
# Assert loss.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss = sess.run(spec.loss)
# loss = 1*(35-45)^2 + .1*(42-41)^2 + 1.5*(45-44)^2 = 100+.1+1.5 = 101.6
@@ -3565,7 +3565,7 @@ class RegressionHead(test.TestCase):
_assert_no_hooks(self, spec)
# Assert predictions, loss, train_op, and summaries.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
predictions, loss, train_result, summary_str = sess.run((
@@ -3600,7 +3600,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels_rank_1)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(expected_training_loss, training_loss.eval())
self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())
@@ -3648,7 +3648,7 @@ class RegressionHead(test.TestCase):
_assert_no_hooks(self, spec)
# Assert predictions, loss, train_op, and summaries.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
predictions, loss, train_result, summary_str = sess.run((
@@ -3679,7 +3679,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
# loss = [(35-45)^2, (42-41)^2, (45-44)^2] = [100, 1, 1].
# weighted sum loss = 1 * 100 + .1 * 1 + 1.5 * 1 = 101.6
@@ -3718,7 +3718,7 @@ class RegressionHead(test.TestCase):
_assert_no_hooks(self, spec)
# Assert predictions, loss, and metrics.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
loss_mean_value_op, loss_mean_update_op = spec.eval_metric_ops[
@@ -3750,7 +3750,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
# loss = [(35-45)^2, (42-41)^2, (45-44)^2] = [100, 1, 1].
# weighted sum loss = 1 * 100 + .1 * 1 + 1.5 * 1 = 101.6
@@ -3796,7 +3796,7 @@ class RegressionHead(test.TestCase):
_assert_no_hooks(self, spec)
# Evaluate predictions, loss, train_op, and summaries.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
predictions, loss, train_result, summary_str = sess.run((
@@ -3857,7 +3857,7 @@ class RegressionHead(test.TestCase):
self.assertIsNone(spec.train_op)
_assert_no_hooks(self, spec)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Finalize graph and initialize variables.
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
@@ -3915,7 +3915,7 @@ class RegressionHead(test.TestCase):
self.assertEqual(dtypes.float32, spec.loss.dtype)
self.assertIsNotNone(spec.train_op)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Finalize graph and initialize variables.
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
@@ -3955,7 +3955,7 @@ class RegressionHead(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(expected_training_loss, training_loss.eval())
self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())
@@ -3988,7 +3988,7 @@ class RegressionHead(test.TestCase):
logits=logits,
labels=labels,
train_op_fn=_train_op_fn)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(expected_loss, spec.loss.eval())
@@ -4013,7 +4013,7 @@ class RegressionHead(test.TestCase):
logits=logits,
labels=labels,
train_op_fn=_no_op_train_fn)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -4042,7 +4042,7 @@ class RegressionHead(test.TestCase):
logits=logits,
labels=labels,
train_op_fn=_no_op_train_fn)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py
index c3934c7a80..65cdd50061 100644
--- a/tensorflow/python/estimator/canned/linear_testing_utils.py
+++ b/tensorflow/python/estimator/canned/linear_testing_utils.py
@@ -48,13 +48,13 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import checkpoint_utils
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import optimizer as optimizer_lib
@@ -756,7 +756,7 @@ class BaseLinearRegressorTrainingTest(object):
self.assertEquals(0, loss.shape.ndims)
if expected_loss is None:
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
assert_loss = assert_close(
math_ops.to_float(expected_loss, name='expected'),
@@ -764,7 +764,7 @@ class BaseLinearRegressorTrainingTest(object):
name='assert_loss')
with ops.control_dependencies((assert_loss,)):
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
mock_optimizer = test.mock.NonCallableMock(
@@ -979,13 +979,13 @@ class BaseLinearClassifierTrainingTest(object):
# Verify loss. We can't check the value directly, so we add an assert op.
self.assertEquals(0, loss.shape.ndims)
if expected_loss is None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
assert_loss = assert_close(
math_ops.to_float(expected_loss, name='expected'),
loss,
name='assert_loss')
with ops.control_dependencies((assert_loss,)):
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
mock_optimizer = test.mock.NonCallableMock(
spec=optimizer_lib.Optimizer,
diff --git a/tensorflow/python/estimator/canned/prediction_keys.py b/tensorflow/python/estimator/canned/prediction_keys.py
index 16890ec09a..daa275b46b 100644
--- a/tensorflow/python/estimator/canned/prediction_keys.py
+++ b/tensorflow/python/estimator/canned/prediction_keys.py
@@ -32,3 +32,4 @@ class PredictionKeys(object):
LOGITS = 'logits'
PREDICTIONS = 'predictions'
PROBABILITIES = 'probabilities'
+ TOP_K = 'top_k'
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 52002fd79b..0f20acefdf 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -35,17 +35,16 @@ from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.estimator.export import export as export_helpers
-from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_util
+from tensorflow.python.keras import metrics
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import metrics as metrics_lib
-from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
@@ -120,7 +119,10 @@ class Estimator(object):
warm_start_from=None):
"""Constructs an `Estimator` instance.
- See @{$estimators} for more information. To warm-start an `Estimator`:
+ See [estimators](https://tensorflow.org/guide/estimators) for more
+ information.
+
+ To warm-start an `Estimator`:
```python
estimator = tf.estimator.DNNClassifier(
@@ -152,9 +154,9 @@ class Estimator(object):
* `params`: Optional `dict` of hyperparameters. Will receive what
is passed to Estimator in `params` parameter. This allows
Estimators to be configured from hyper-parameter tuning.
- * `config`: Optional configuration object. Will receive what is passed
- to Estimator in `config` parameter, or the default `config`.
- Allows updating things in your `model_fn` based on
+ * `config`: Optional `estimator.RunConfig` object. Will receive what
+ is passed to Estimator as its `config` parameter, or a default
+ value. Allows setting up things in your `model_fn` based on
configuration such as `num_ps_replicas`, or `model_dir`.
* Returns:
@@ -166,7 +168,7 @@ class Estimator(object):
path will be resolved. If `None`, the model_dir in `config` will be used
if set. If both are set, they must be the same. If both are `None`, a
temporary directory will be used.
- config: Configuration object.
+ config: `estimator.RunConfig` configuration object.
params: `dict` of hyper parameters that will be passed into `model_fn`.
Keys are names of parameters, values are basic python types.
warm_start_from: Optional string filepath to a checkpoint or SavedModel to
@@ -184,8 +186,8 @@ class Estimator(object):
"""
Estimator._assert_members_are_not_overridden(self)
- config = maybe_overwrite_model_dir_and_session_config(config, model_dir)
- self._config = config
+ self._config = maybe_overwrite_model_dir_and_session_config(config,
+ model_dir)
# The distribute field contains an instance of DistributionStrategy.
self._train_distribution = self._config.train_distribute
@@ -285,8 +287,10 @@ class Estimator(object):
Args:
input_fn: A function that provides input data for training as minibatches.
- See @{$premade_estimators#create_input_functions} for more information.
- The function should construct and return one of the following: * A
+ See [Premade Estimators](
+ https://tensorflow.org/guide/premade_estimators#create_input_functions)
+ for more information. The function should construct and return one of
+ the following: * A
`tf.data.Dataset` object: Outputs of `Dataset` object must be a tuple
`(features, labels)` with same constraints as below. * A tuple
`(features, labels)`: Where `features` is a `tf.Tensor` or a dictionary
@@ -321,6 +325,14 @@ class Estimator(object):
ValueError: If both `steps` and `max_steps` are not `None`.
ValueError: If either `steps` or `max_steps` is `<= 0`.
"""
+ if self.config.task_type in (run_config.TaskType.EVALUATOR,
+ run_config.TaskType.PS):
+ raise ValueError(
+ 'Train has been called with wrong configuration. Please use '
+ 'tf.estimator.train_and_evaluate which calls the proper API according '
+ 'to the given configuration. Current configuration: {}.'.format(
+ self.config))
+
with context.graph_mode():
if (steps is not None) and (max_steps is not None):
raise ValueError('Can not provide both steps and max_steps.')
@@ -394,7 +406,9 @@ class Estimator(object):
Args:
input_fn: A function that constructs the input data for evaluation. See
- @{$premade_estimators#create_input_functions} for more information. The
+ [Premade Estimators](
+ https://tensorflow.org/guide/premade_estimators#create_input_functions)
+ for more information. The
function should construct and return one of the following: * A
`tf.data.Dataset` object: Outputs of `Dataset` object must be a tuple
`(features, labels)` with same constraints as below. * A tuple
@@ -419,7 +433,11 @@ class Estimator(object):
Returns:
A dict containing the evaluation metrics specified in `model_fn` keyed by
name, as well as an entry `global_step` which contains the value of the
- global step for which this evaluation was performed.
+ global step for which this evaluation was performed. For canned
+ estimators, the dict contains the `loss` (mean loss per mini-batch) and
+ the `average_loss` (mean loss per sample). Canned classifiers also return
+ the `accuracy`. Canned regressors also return the `label/mean` and the
+ `prediction/mean`.
Raises:
ValueError: If `steps <= 0`.
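
A short sketch of reading the documented return value, assuming a canned
classifier such as `tf.estimator.DNNClassifier` (the estimator and input_fn are
placeholders):

    eval_metrics = classifier.evaluate(input_fn=eval_input_fn, steps=100)
    print(eval_metrics['loss'])          # mean loss per mini-batch
    print(eval_metrics['average_loss'])  # mean loss per sample
    print(eval_metrics['accuracy'])      # canned classifiers also report this
    print(eval_metrics['global_step'])   # step at which evaluation ran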
@@ -450,9 +468,7 @@ class Estimator(object):
output_dir=self.eval_dir(name))
with ops.Graph().as_default():
- # TODO(priyag): Support distributed eval on TPUs.
- if (self._eval_distribution
- and self._eval_distribution.__class__.__name__ != 'TPUStrategy'):
+ if self._eval_distribution:
with self._eval_distribution.scope():
return _evaluate()
else:
@@ -478,8 +494,9 @@ class Estimator(object):
input_fn: A function that constructs the features. Prediction continues
until `input_fn` raises an end-of-input exception
(`tf.errors.OutOfRangeError` or `StopIteration`).
- See @{$premade_estimators#create_input_functions} for more
- information. The function should construct and return one of
+ See [Premade Estimators](
+ https://tensorflow.org/guide/premade_estimators#create_input_functions)
+ for more information. The function should construct and return one of
the following:
* A `tf.data.Dataset` object: Outputs of `Dataset` object must have
@@ -591,12 +608,43 @@ class Estimator(object):
as_text=False,
checkpoint_path=None,
strip_default_attrs=False):
+ # pylint: disable=line-too-long,g-doc-args,g-doc-return-or-yield
+ """Exports inference graph as a `SavedModel` into the given dir.
+
+    Note that `export_savedmodel` will be renamed to `export_saved_model`
+    in TensorFlow 2.0. At that time, `export_savedmodel` without the
+    additional underscore will be available only through tf.compat.v1.
+
+ Please see `tf.estimator.Estimator.export_saved_model` for more information.
+
+    There is one additional argument compared to the new method:
+ strip_default_attrs: This parameter is going away in TF 2.0, and
+ the new behavior will automatically strip all default attributes.
+ Boolean. If `True`, default-valued attributes will be
+ removed from the `NodeDef`s. For a detailed guide, see [Stripping
+ Default-Valued Attributes](
+ https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ """
+ # pylint: enable=line-too-long,g-doc-args,g-doc-return-or-yield
+ return self._export_saved_model_for_mode(
+ export_dir_base,
+ serving_input_receiver_fn,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ checkpoint_path=checkpoint_path,
+ strip_default_attrs=strip_default_attrs,
+ mode=model_fn_lib.ModeKeys.PREDICT)
+
+ def export_saved_model(
+ self, export_dir_base, serving_input_receiver_fn,
+ assets_extra=None,
+ as_text=False,
+ checkpoint_path=None):
# pylint: disable=line-too-long
"""Exports inference graph as a `SavedModel` into the given dir.
For a detailed guide, see
- @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with
- Estimators}.
+ [Using SavedModel with Estimators](https://tensorflow.org/guide/saved_model#using_savedmodel_with_estimators).
This method builds a new graph by first calling the
`serving_input_receiver_fn` to obtain feature `Tensor`s, and then calling
@@ -638,28 +686,25 @@ class Estimator(object):
as_text: whether to write the `SavedModel` proto in text format.
checkpoint_path: The checkpoint path to export. If `None` (the default),
the most recent checkpoint found within the model directory is chosen.
- strip_default_attrs: Boolean. If `True`, default-valued attributes will be
- removed from the `NodeDef`s. For a detailed guide, see [Stripping
- Default-Valued
- Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
Returns:
The string path to the exported directory.
Raises:
ValueError: if no `serving_input_receiver_fn` is provided, no
- `export_outputs`
- are provided, or no checkpoint can be found.
+ `export_outputs` are provided, or no checkpoint can be found.
"""
# pylint: enable=line-too-long
- return self._export_saved_model_for_mode(
+    # TODO(b/111442174): `export_savedmodel` will be renamed to
+    # `export_saved_model` in TensorFlow 2.0. This function is a wrapper
+ # while staging the new version; do not add any logic here.
+ return self.export_savedmodel(
export_dir_base,
serving_input_receiver_fn,
assets_extra=assets_extra,
as_text=as_text,
checkpoint_path=checkpoint_path,
- strip_default_attrs=strip_default_attrs,
- mode=model_fn_lib.ModeKeys.PREDICT)
+ strip_default_attrs=True)
def _export_saved_model_for_mode(
self, export_dir_base, input_receiver_fn,
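
As a usage illustration for the export pair above, a hedged sketch (the feature
spec and export directory are assumptions;
`tf.estimator.export.build_parsing_serving_input_receiver_fn` is the existing
helper):

    import tensorflow as tf

    feature_spec = {'x': tf.FixedLenFeature([1], tf.float32)}
    serving_input_receiver_fn = (
        tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec))
    # The new name always strips default-valued attributes; the old
    # `export_savedmodel` still exposes `strip_default_attrs`.
    export_dir = est.export_saved_model('/tmp/export', serving_input_receiver_fn)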
@@ -911,7 +956,12 @@ class Estimator(object):
mode=mode,
config=self.config)
- export_outputs = self._get_export_outputs_for_spec(estimator_spec)
+ export_outputs = model_fn_lib.export_outputs_for_mode(
+ mode=estimator_spec.mode,
+ serving_export_outputs=estimator_spec.export_outputs,
+ predictions=estimator_spec.predictions,
+ loss=estimator_spec.loss,
+ metrics=estimator_spec.eval_metric_ops)
# Build the SignatureDefs from receivers and all outputs
signature_def_map = export_helpers.build_all_signature_defs(
@@ -968,45 +1018,6 @@ class Estimator(object):
else:
builder.add_meta_graph(**meta_graph_kwargs)
- def _get_export_outputs_for_spec(self, estimator_spec):
- """Given an `EstimatorSpec`, determine what our export outputs should be.
-
- `EstimatorSpecs` contains `export_outputs` that are used for serving, but
- for
- training and eval graphs, we must wrap the tensors of interest in
- appropriate `tf.estimator.export.ExportOutput` objects.
-
- Args:
- estimator_spec: `tf.estimator.EstimatorSpec` object that will be exported.
-
- Returns:
- a dict mapping `export_output_name` to `tf.estimator.export.ExportOutput`
- object.
-
- Raises:
- ValueError: if an appropriate `ExportOutput` cannot be found for the
- passed `EstimatorSpec.mode`
- """
- mode = estimator_spec.mode
- if mode == model_fn_lib.ModeKeys.PREDICT:
- outputs = estimator_spec.export_outputs
- else:
- if mode == model_fn_lib.ModeKeys.TRAIN:
- output_class = export_output.TrainOutput
- elif mode == model_fn_lib.ModeKeys.EVAL:
- output_class = export_output.EvalOutput
- else:
- raise ValueError(
- 'Export output type not found for mode: {}'.format(mode))
-
- export_out = output_class(
- loss=estimator_spec.loss,
- predictions=estimator_spec.predictions,
- metrics=estimator_spec.eval_metric_ops)
- outputs = {mode: export_out}
-
- return outputs
-
def _get_features_from_input_fn(self, input_fn, mode):
"""Extracts the `features` from return values of `input_fn`."""
result = self._call_input_fn(input_fn, mode)
@@ -1020,16 +1031,21 @@ class Estimator(object):
'QueueRunner. That means predict yields forever. '
'This is probably a mistake.')
- def _get_features_and_labels_from_input_fn(self, input_fn, mode,
- distribution=None):
- """Extracts the `features` and labels from return values of `input_fn`."""
+ def _get_iterator_from_input_fn(self, input_fn, mode, distribution=None):
if distribution is not None:
result = distribution.distribute_dataset(
lambda: self._call_input_fn(input_fn, mode))
else:
result = self._call_input_fn(input_fn, mode)
- return estimator_util.parse_input_fn_result(result)
+ iterator = result.make_initializable_iterator()
+ input_hooks = [estimator_util._DatasetInitializerHook(iterator)] # pylint: disable=protected-access
+ return iterator, input_hooks
+
+ def _get_features_and_labels_from_input_fn(self, input_fn, mode):
+ """Extracts the `features` and labels from return values of `input_fn`."""
+ return estimator_util.parse_input_fn_result(
+ self._call_input_fn(input_fn, mode))
def _extract_batch_length(self, preds_evaluated):
"""Extracts batch length of predictions."""
@@ -1222,30 +1238,23 @@ class Estimator(object):
# We want to create the iterations variable outside the distribution scope
# as that is just stored on the host and mainly used to drive the loop
# and doesn't need to be a Mirrored/Device variable.
- steps_per_run_variable = training.get_or_create_steps_per_run_variable()
+ if is_tpu_strategy:
+ steps_per_run_variable = training.get_or_create_steps_per_run_variable()
with self._train_distribution.scope():
random_seed.set_random_seed(self._config.tf_random_seed)
+ iterator, input_hooks = self._get_iterator_from_input_fn(
+ input_fn, model_fn_lib.ModeKeys.TRAIN, self._train_distribution)
+ worker_hooks.extend(input_hooks)
+ global_step_tensor = self._create_and_assert_global_step(g)
+      # We want to add to the global collection in the main thread, not the
+      # tower threads.
+ ops.add_to_collection(
+ training_util.GLOBAL_STEP_READ_KEY,
+ self._train_distribution.read_var(global_step_tensor))
if is_tpu_strategy:
- # Create the iterator for run_on_dataset function
- # TODO(sourabhbajaj): refactor this out to call a function on the
- # strategy
- dataset = self._train_distribution.distribute_dataset(
- lambda: self._call_input_fn(input_fn, # pylint: disable=g-long-lambda
- model_fn_lib.ModeKeys.TRAIN))
- iterator = dataset.make_initializable_iterator()
- worker_hooks.append(
- estimator_util._DatasetInitializerHook(iterator)) # pylint: disable=protected-access
-
- global_step_tensor = self._create_and_assert_global_step(g)
- # we want to add to the global collection in the main thread not the
- # tower threads.
- ops.add_to_collection(
- training_util.GLOBAL_STEP_READ_KEY,
- self._train_distribution.read_var(global_step_tensor))
-
# Create a step_fn from the train_op of grouped_estimator_spec
- def step_fn(ctx, features, labels):
+ def step_fn(ctx, features, labels=None):
"""A single step that is passed to run_on_dataset."""
estimator_spec = self._train_distribution.call_for_each_tower(
self._call_model_fn,
@@ -1267,54 +1276,42 @@ class Estimator(object):
step_fn, iterator, iterations=steps_per_run_variable,
initial_loop_values={'loss': initial_training_loss})
distributed_train_op = ctx.run_op
- tpu_result = ctx.last_step_outputs
+ loss = ctx.last_step_outputs['loss']
grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec']
else:
- features, labels, input_hooks = (
- self._get_features_and_labels_from_input_fn(
- input_fn, model_fn_lib.ModeKeys.TRAIN,
- self._train_distribution))
- worker_hooks.extend(input_hooks)
- global_step_tensor = self._create_and_assert_global_step(g)
- # we want to add to the global collection in the main thread not the
- # tower threads.
- ops.add_to_collection(
- training_util.GLOBAL_STEP_READ_KEY,
- self._train_distribution.read_var(global_step_tensor))
+ features, labels = estimator_util.parse_iterator_result(
+ iterator.get_next())
grouped_estimator_spec = self._train_distribution.call_for_each_tower(
self._call_model_fn,
features,
labels,  # may be None if the input_fn returns no labels
model_fn_lib.ModeKeys.TRAIN,
self.config)
+ loss = self._train_distribution.unwrap(
+ self._train_distribution.reduce(
+ distribute_lib.get_loss_reduction(),
+ grouped_estimator_spec.loss,
+ destinations='/device:CPU:0'))[0]
+ distributed_train_op = grouped_estimator_spec.train_op
scaffold = _combine_distributed_scaffold(
grouped_estimator_spec.scaffold, self._train_distribution)
+ # TODO(yuefengz): add a test for unwrapping per_device_hooks.
def get_hooks_from_the_first_device(per_device_hooks):
- hooks_list = self._train_distribution.unwrap(per_device_hooks)
- assert hooks_list
- return hooks_list[0]
+ return [
+          self._train_distribution.unwrap(per_device_hook)[0]
+ for per_device_hook in per_device_hooks
+ ]
training_hooks = get_hooks_from_the_first_device(
grouped_estimator_spec.training_hooks)
training_chief_hooks = get_hooks_from_the_first_device(
grouped_estimator_spec.training_chief_hooks)
-
- # TODO(sourabhbajaj): Merge the two code paths and clean up the code
- if is_tpu_strategy:
- loss = tpu_result['loss']
- worker_hooks.append(
- estimator_util.StrategyInitFinalizeHook(
- self._train_distribution.initialize,
- self._train_distribution.finalize))
- else:
- loss = self._train_distribution.unwrap(
- self._train_distribution.reduce(
- distribute_lib.get_loss_reduction(),
- grouped_estimator_spec.loss,
- destinations='/device:CPU:0'))[0]
- distributed_train_op = grouped_estimator_spec.train_op
+ worker_hooks.append(
+ estimator_util.StrategyInitFinalizeHook(
+ self._train_distribution.initialize,
+ self._train_distribution.finalize))
estimator_spec = model_fn_lib.EstimatorSpec(
mode=grouped_estimator_spec.mode,
@@ -1415,31 +1412,18 @@ class Estimator(object):
"""Builds the graph and related hooks to run evaluation."""
random_seed.set_random_seed(self._config.tf_random_seed)
self._create_and_assert_global_step(ops.get_default_graph())
- features, labels, input_hooks = (
- self._get_features_and_labels_from_input_fn(
- input_fn, model_fn_lib.ModeKeys.EVAL, self._eval_distribution))
if self._eval_distribution:
- (loss_metric, scaffold, evaluation_hooks, eval_metric_ops) = (
- self._call_model_fn_eval_distributed(features, labels, self.config))
+ (scaffold, evaluation_hooks, input_hooks, update_op, eval_dict) = (
+ self._call_model_fn_eval_distributed(input_fn, self.config))
else:
- (loss_metric, scaffold, evaluation_hooks, eval_metric_ops) = (
- self._call_model_fn_eval(features, labels, self.config))
+ (scaffold, evaluation_hooks, input_hooks, update_op, eval_dict) = (
+ self._call_model_fn_eval(input_fn, self.config))
global_step_tensor = training_util.get_global_step(ops.get_default_graph())
# Call to warm_start has to be after model_fn is called.
self._maybe_warm_start(checkpoint_path)
- if model_fn_lib.LOSS_METRIC_KEY in eval_metric_ops:
- raise ValueError(
- 'Metric with name "%s" is not allowed, because Estimator ' %
- (model_fn_lib.LOSS_METRIC_KEY) +
- 'already defines a default metric with the same name.')
- eval_metric_ops[model_fn_lib.LOSS_METRIC_KEY] = loss_metric
-
- update_op, eval_dict = _extract_metric_update_ops(eval_metric_ops,
- self._eval_distribution)
-
if ops.GraphKeys.GLOBAL_STEP in eval_dict:
raise ValueError(
'Metric with name `global_step` is not allowed, because Estimator '
@@ -1464,26 +1448,71 @@ class Estimator(object):
return scaffold, update_op, eval_dict, all_hooks
- def _call_model_fn_eval(self, features, labels, config):
+ def _call_model_fn_eval(self, input_fn, config):
+ """Call model_fn for evaluation and handle return values."""
+ features, labels, input_hooks = self._get_features_and_labels_from_input_fn(
+ input_fn, model_fn_lib.ModeKeys.EVAL)
+
estimator_spec = self._call_model_fn(
features, labels, model_fn_lib.ModeKeys.EVAL, config)
- loss_metric = metrics_lib.mean(estimator_spec.loss)
- return (loss_metric, estimator_spec.scaffold,
- estimator_spec.evaluation_hooks, estimator_spec.eval_metric_ops)
+ eval_metric_ops = _verify_and_create_loss_metric(
+ estimator_spec.eval_metric_ops, estimator_spec.loss)
+ update_op, eval_dict = _extract_metric_update_ops(eval_metric_ops)
+ return (estimator_spec.scaffold, estimator_spec.evaluation_hooks,
+ input_hooks, update_op, eval_dict)
- def _call_model_fn_eval_distributed(self, features, labels, config):
+ def _call_model_fn_eval_distributed(self, input_fn, config):
"""Call model_fn in distribution mode and handle return values."""
- grouped_estimator_spec = self._eval_distribution.call_for_each_tower(
- self._call_model_fn, features, labels,
- model_fn_lib.ModeKeys.EVAL, config)
+
+ iterator, input_hooks = self._get_iterator_from_input_fn(
+ input_fn, model_fn_lib.ModeKeys.EVAL, self._eval_distribution)
+
+ is_tpu_strategy = (
+ self._eval_distribution.__class__.__name__ == 'TPUStrategy')
+
+ if is_tpu_strategy:
+ def step_fn(ctx, features, labels=None):
+ """Runs one step of the eval computation and captures outputs."""
+ estimator_spec = self._eval_distribution.call_for_each_tower(
+ self._call_model_fn, features, labels, model_fn_lib.ModeKeys.EVAL,
+ config)
+ eval_metric_ops = _verify_and_create_loss_metric(
+ estimator_spec.eval_metric_ops, estimator_spec.loss,
+ self._eval_distribution)
+ update_op, eval_dict = _extract_metric_update_ops(
+ eval_metric_ops, self._eval_distribution)
+ ctx.set_non_tensor_output(name='estimator_spec', output=estimator_spec)
+ ctx.set_non_tensor_output(name='eval_dict', output=eval_dict)
+ return update_op
+
+ # TODO(priyag): Fix eval step hook to account for steps_per_run.
+ ctx = self._eval_distribution.run_steps_on_dataset(
+ step_fn, iterator, iterations=self._eval_distribution.steps_per_run)
+ update_op = ctx.run_op
+ eval_dict = ctx.non_tensor_outputs['eval_dict']
+ grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec']
+ else:
+ features, labels = estimator_util.parse_iterator_result(
+ iterator.get_next())
+ grouped_estimator_spec = self._eval_distribution.call_for_each_tower(
+ self._call_model_fn, features, labels,
+ model_fn_lib.ModeKeys.EVAL, config)
+ eval_metric_ops = _verify_and_create_loss_metric(
+ grouped_estimator_spec.eval_metric_ops, grouped_estimator_spec.loss,
+ self._eval_distribution)
+ update_op, eval_dict = _extract_metric_update_ops(
+ eval_metric_ops, self._eval_distribution)
+
scaffold = _combine_distributed_scaffold(
grouped_estimator_spec.scaffold, self._eval_distribution)
evaluation_hooks = self._eval_distribution.unwrap(
grouped_estimator_spec.evaluation_hooks)[0]
- loss_metric = self._eval_distribution.call_for_each_tower(
- metrics_lib.mean, grouped_estimator_spec.loss)
- return (loss_metric, scaffold,
- evaluation_hooks, grouped_estimator_spec.eval_metric_ops)
+ evaluation_hooks = evaluation_hooks + (
+ estimator_util.StrategyInitFinalizeHook(
+ self._eval_distribution.initialize,
+ self._eval_distribution.finalize),)
+
+ return (scaffold, evaluation_hooks, input_hooks, update_op, eval_dict)
def _evaluate_run(self, checkpoint_path, scaffold, update_op, eval_dict,
all_hooks, output_dir):
@@ -1519,6 +1548,23 @@ class Estimator(object):
warm_starting_util.warm_start(*self._warm_start_settings)
+def _verify_and_create_loss_metric(eval_metric_ops, loss, distribution=None):
+ """Creates a metric for loss and throws an error if one already exists."""
+ if model_fn_lib.LOSS_METRIC_KEY in eval_metric_ops:
+ raise ValueError(
+ 'Metric with name "%s" is not allowed, because Estimator ' %
+ (model_fn_lib.LOSS_METRIC_KEY) +
+ 'already defines a default metric with the same name.')
+
+ if distribution is None:
+ loss_metric = metrics_lib.mean(loss)
+ else:
+ loss_metric = distribution.call_for_each_tower(
+ metrics_lib.mean, loss)
+ eval_metric_ops[model_fn_lib.LOSS_METRIC_KEY] = loss_metric
+ return eval_metric_ops
+
+
def maybe_overwrite_model_dir_and_session_config(config, model_dir):
"""Overwrite estimator config by `model_dir` and `session_config` if needed.
@@ -1562,21 +1608,6 @@ def maybe_overwrite_model_dir_and_session_config(config, model_dir):
return config
-def create_per_tower_ready_op(scaffold):
- """Create a `tf.train.Scaffold.ready_op` inside a tower."""
- if scaffold.ready_op:
- return scaffold.ready_op
-
- def default_ready_op():
- return array_ops.concat([
- variables.report_uninitialized_variables(),
- resources.report_uninitialized_resources()
- ], 0)
-
- return monitored_session.Scaffold.get_or_default(
- 'ready_op', ops.GraphKeys.READY_OP, default_ready_op)
-
-
def create_per_tower_ready_for_local_init_op(scaffold):
"""Create a `tf.train.Scaffold.ready_for_local_init_op` inside a tower."""
if scaffold.ready_for_local_init_op:
@@ -1626,11 +1657,9 @@ def _combine_distributed_scaffold(grouped_scaffold, distribution):
return value[0]
ready_op = distribution.call_for_each_tower(
- create_per_tower_ready_op, grouped_scaffold)
+ lambda scaffold: scaffold.ready_op, grouped_scaffold)
if ready_op is not None:
ready_op = _unwrap_and_concat(ready_op)
- else:
- ready_op = None
ready_for_local_init_op = distribution.call_for_each_tower(
create_per_tower_ready_for_local_init_op, grouped_scaffold)
@@ -1758,19 +1787,21 @@ def _extract_metric_update_ops(eval_dict, distribution=None):
update_ops = []
value_ops = {}
# Sort metrics lexicographically so graph is identical every time.
- for name, metric_ops in sorted(six.iteritems(eval_dict)):
- value_ops[name] = metric_ops[0]
- if distribution:
- update_op = distribution.group(metric_ops[1])
+ for name, value in sorted(six.iteritems(eval_dict)):
+ if isinstance(value, metrics.Metric):
+ metric_result = value.result()
+ # We expect only one update op for every metric when there is no
+ # distribution strategy.
+ metric_update = value.updates if distribution else value.updates[0]
else:
- update_op = metric_ops[1]
- update_ops.append(update_op)
+ metric_result = value[0]
+ metric_update = value[1]
- if update_ops:
- update_op = control_flow_ops.group(*update_ops)
- else:
- update_op = None
+ value_ops[name] = metric_result
+ update_ops.append(
+ distribution.group(metric_update) if distribution else metric_update)
+ update_op = control_flow_ops.group(*update_ops) if update_ops else None
return update_op, value_ops
@@ -2025,7 +2056,7 @@ class WarmStartSettings(
var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
`tf.estimator.VocabInfo`. The variable names should be "full" variables,
not the names of the partitions. If not explicitly provided, the variable
- is assumed to have no vocabulary.
+ is assumed to have no (changes to) vocabulary.
var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
name of the previously-trained variable in `ckpt_to_initialize_from`. If
not explicitly provided, the name of the variable is assumed to be same
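
A brief sketch of the settings described above (the checkpoint path and
variable pattern are hypothetical):

    import tensorflow as tf

    ws = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from='/tmp/prev_model',
        vars_to_warm_start='.*embedding.*',
        # var_name_to_vocab_info is omitted, so variables are assumed to
        # have no (changes to) vocabulary, per the documentation above.
        var_name_to_prev_var_name=None)
    est = tf.estimator.Estimator(model_fn=my_model_fn, warm_start_from=ws)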
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 05d1a04d2f..1ed5e30b0e 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -43,6 +43,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
+from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.layers import layers
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
@@ -955,22 +956,44 @@ class EstimatorTrainTest(test.TestCase):
est = estimator.Estimator(model_fn=_model_fn)
est.train(dummy_input_fn, steps=1)
+ def test_config_should_not_be_evaluator_or_ps(self):
+
+ class FakeEvaluatorConfig(run_config.RunConfig):
+
+ @property
+ def task_type(self):
+ return run_config.TaskType.EVALUATOR
+
+ est = estimator.Estimator(
+ model_fn=dummy_model_fn, config=FakeEvaluatorConfig())
+ with self.assertRaisesRegexp(ValueError, 'train_and_evaluate'):
+ est.train(dummy_input_fn, steps=1)
+
def _model_fn_with_eval_metric_ops(features, labels, mode, params):
_, _ = features, labels
- metric_name = params.get('metric_name') or 'metric'
- metric_value = params.get('metric_value') or 2.
global_step = training.get_global_step()
loss = constant_op.constant(1.)
+ metric_name_1 = params.get('metric_name') or 'metric'
+ metric_value_1 = params.get('metric_value') or 2.
+ metric_name_2 = params.get('metric_name_2') or 'metric2'
+ metric_value_2 = params.get('metric_value_2') or 2.
+
metric_update_op = loss.op
metric_tensor = control_flow_ops.with_dependencies(
- [metric_update_op], constant_op.constant(metric_value))
+ [metric_update_op], constant_op.constant(metric_value_1))
+
+ mean = metrics_module.Mean()
+ mean.update_state(metric_value_2)
return model_fn_lib.EstimatorSpec(
mode,
loss=loss,
predictions={'predictions': constant_op.constant(1.)},
train_op=state_ops.assign_add(global_step, 1),
- eval_metric_ops={metric_name: (metric_tensor, metric_update_op)})
+ eval_metric_ops={
+ metric_name_1: (metric_tensor, metric_update_op),
+ metric_name_2: mean,
+ })
class _StepCounterHook(session_run_hook.SessionRunHook):
@@ -1154,16 +1177,22 @@ class EstimatorEvaluateTest(test.TestCase):
def test_no_checkpoint_uses_init(self):
def _model_fn(features, labels, mode, params):
del features, labels, params
+ mean = metrics_module.Mean()
+ mean.update_state(variables.Variable(2.) + 1)
return model_fn_lib.EstimatorSpec(
mode,
loss=constant_op.constant(1.),
- eval_metric_ops={'metric': metrics_lib.mean(
- variables.Variable(2.) + 1)})
+ eval_metric_ops={
+ 'mean1': mean,
+ 'mean2': metrics_lib.mean(variables.Variable(2.) + 1)
+ })
+
est = estimator.Estimator(model_fn=_model_fn)
- metrics = est.evaluate(dummy_input_fn, steps=1)
+ scores = est.evaluate(dummy_input_fn, steps=1)
# Metric value here is set to 1 + the value of the Variable that is newly
# initialized (since there is no checkpoint).
- self.assertEqual(3., metrics['metric'])
+ self.assertEqual(3., scores['mean1'])
+ self.assertEqual(3., scores['mean2'])
def test_no_checkpoint_uses_init_with_warm_starting(self):
def _make_model_fn(x):
@@ -1171,14 +1200,24 @@ class EstimatorEvaluateTest(test.TestCase):
_, _ = features, labels
x_var = variable_scope.get_variable('x', initializer=x)
global_step = training.get_global_step()
+ mean = metrics_module.Mean()
+ mean.update_state(x_var + 1)
return model_fn_lib.EstimatorSpec(
mode,
predictions={'y': constant_op.constant(1.0)},
loss=constant_op.constant(1.),
- eval_metric_ops={'metric': metrics_lib.mean(x_var + 1)},
+ eval_metric_ops={
+ 'mean1': mean,
+ 'mean2': metrics_lib.mean(x_var + 1)
+ },
train_op=state_ops.assign_add(global_step, 1),
- export_outputs={'test': export_output.ClassificationOutput(
- constant_op.constant([4.2]), constant_op.constant(['label']))})
+ export_outputs={
+ 'test':
+ export_output.ClassificationOutput(
+ constant_op.constant([4.2]),
+ constant_op.constant(['label']))
+ })
+
return _variable_creating_and_export_model_fn
first_est = estimator.Estimator(model_fn=_make_model_fn(42.))
@@ -1197,30 +1236,37 @@ class EstimatorEvaluateTest(test.TestCase):
# or an exported SavedModel.
est = estimator.Estimator(model_fn=_make_model_fn(52.),
warm_start_from=exported_path)
- metrics = est.evaluate(dummy_input_fn, steps=1)
+ eval_metrics = est.evaluate(dummy_input_fn, steps=1)
# Metric value here is set to 1 + the value of the Variable that is
# warm-started from the SavedModel of the first model (42.), as opposed to
# the initialization in the new model_fn (52.).
- self.assertEqual(43., metrics['metric'])
+ self.assertEqual(43., eval_metrics['mean1'])
+ self.assertEqual(43., eval_metrics['mean2'])
est = estimator.Estimator(model_fn=_make_model_fn(62.),
warm_start_from=first_est.model_dir)
- metrics = est.evaluate(dummy_input_fn, steps=1)
+ eval_metrics = est.evaluate(dummy_input_fn, steps=1)
# Metric value here is set to 1 + the value of the Variable that is
# warm-started from a checkpoint of the first model (42.), as opposed to
# the initialization in the new model_fn (52.).
- self.assertEqual(43., metrics['metric'])
+ self.assertEqual(43., eval_metrics['mean1'])
+ self.assertEqual(43., eval_metrics['mean2'])
def test_scores(self):
est = estimator.Estimator(
model_fn=_model_fn_with_eval_metric_ops,
params={
'metric_name': 'metric',
- 'metric_value': 2.})
+ 'metric_value': 2.,
+ 'metric_name_2': 'metric2',
+ 'metric_value_2': 3.,
+ })
est.train(dummy_input_fn, steps=5)
scores = est.evaluate(dummy_input_fn, steps=1)
self.assertIn('metric', scores)
self.assertAlmostEqual(2., scores['metric'])
+ self.assertIn('metric2', scores)
+ self.assertAlmostEqual(3., scores['metric2'])
def test_tuple_metrics(self):
def _model_fn(features, labels, mode):
@@ -1271,8 +1317,12 @@ class EstimatorEvaluateTest(test.TestCase):
def test_global_step_is_reported(self):
est = estimator.Estimator(
model_fn=_model_fn_with_eval_metric_ops,
- params={'metric_name': 'metric',
- 'metric_value': 2.})
+ params={
+ 'metric_name': 'metric',
+ 'metric_value': 2.,
+ 'metric_name_2': 'metric2',
+ 'metric_value_2': 3.,
+ })
est.train(dummy_input_fn, steps=5)
scores = est.evaluate(dummy_input_fn, steps=1)
self.assertIn('global_step', scores)
@@ -1315,7 +1365,10 @@ class EstimatorEvaluateTest(test.TestCase):
def test_evaluate_from_checkpoint(self):
params = {
'metric_name': 'metric',
- 'metric_value': 2.}
+ 'metric_value': 2.,
+ 'metric_name_2': 'metric2',
+ 'metric_value_2': 3.,
+ }
est1 = estimator.Estimator(
model_fn=_model_fn_with_eval_metric_ops,
params=params)
@@ -2014,8 +2067,15 @@ def _model_fn_with_x_y(features, labels, mode):
multiplied = math_ops.multiply(
features['x'], features['y'], name='{}multiplied'.format(prefix))
- metrics = {'mean': metrics_lib.mean(features['x'] - features['y'],
- name='{}mean'.format(prefix))}
+ mean = metrics_module.Mean(name='{}mean'.format(prefix))
+ mean.update_state(features['x'] - features['y'])
+ eval_metrics = {
+ 'mean1':
+ mean,
+ 'mean2':
+ metrics_lib.mean(
+ features['x'] - features['y'], name='{}mean'.format(prefix))
+ }
variables.Variable(1., name='later_var')
variables.Variable(3., name='name_collision')
return model_fn_lib.EstimatorSpec(
@@ -2023,7 +2083,7 @@ def _model_fn_with_x_y(features, labels, mode):
predictions=multiplied,
loss=constant_op.constant(1.),
train_op=state_ops.assign_add(training.get_global_step(), 1),
- eval_metric_ops=metrics)
+ eval_metric_ops=eval_metrics)
def _model_fn_with_saveables_for_export_tests(features, labels, mode):
@@ -2382,14 +2442,18 @@ class EstimatorExportTest(test.TestCase):
def _model_fn(features, labels, mode):
del features, labels # Unused
- metrics = {'metrics': (constant_op.constant([0]),
- control_flow_ops.no_op())}
+ metric_obj = metrics_module.Mean()
+ metric_obj.update_state(constant_op.constant([0]))
+ eval_metrics = {
+ 'metrics1': (constant_op.constant([0]), control_flow_ops.no_op()),
+ 'metrics2': metric_obj,
+ }
return model_fn_lib.EstimatorSpec(
mode,
predictions=constant_op.constant(10.),
loss=constant_op.constant(1.),
train_op=state_ops.assign_add(training.get_global_step(), 1),
- eval_metric_ops=metrics)
+ eval_metric_ops=eval_metrics)
tmpdir = tempfile.mkdtemp()
est = estimator.Estimator(model_fn=_model_fn)
@@ -2411,8 +2475,10 @@ class EstimatorExportTest(test.TestCase):
meta_graph = loader.load(sess, [tag_constants.EVAL], export_dir)
sig_outputs = meta_graph.signature_def[
model_fn_lib.ModeKeys.EVAL].outputs
- self.assertEqual(
- sig_outputs['metrics/update_op'].name, 'metric_op_wrapper:0')
+ self.assertTrue(sig_outputs['metrics1/update_op'].name.startswith(
+ 'metric_op_wrapper'))
+ self.assertTrue(sig_outputs['metrics2/update_op'].name.startswith(
+ 'metric_op_wrapper'))
def test_export_savedmodel_with_saveables_proto_roundtrip(self):
tmpdir = tempfile.mkdtemp()
@@ -3067,9 +3133,13 @@ class EstimatorIntegrationTest(test.TestCase):
loss = losses.mean_squared_error(labels, predictions)
train_op = training.GradientDescentOptimizer(learning_rate=0.5).minimize(
loss, training.get_global_step())
+ mean = metrics_module.Mean()
+ mean.update_state(loss)
eval_metric_ops = {
- 'absolute_error': metrics_lib.mean_absolute_error(
- labels, predictions)
+ 'absolute_error':
+ metrics_lib.mean_absolute_error(labels, predictions),
+ 'mean':
+ mean,
}
return model_fn_lib.EstimatorSpec(
@@ -3089,12 +3159,13 @@ class EstimatorIntegrationTest(test.TestCase):
x={'x': data}, y=data, batch_size=50, num_epochs=None, shuffle=True)
est.train(train_input_fn, steps=200)
- # EVALUTE
+ # EVALUATE
eval_input_fn = numpy_io.numpy_input_fn(
x={'x': data}, y=data, batch_size=50, num_epochs=1, shuffle=True)
scores = est.evaluate(eval_input_fn)
self.assertEqual(200, scores['global_step'])
self.assertGreater(0.1, scores['absolute_error'])
+ self.assertAlmostEqual(4.4e-14, scores['mean'], places=2)
# PREDICT
predict_input_fn = numpy_io.numpy_input_fn(
diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py
index 3d171f7811..55aace5fa9 100644
--- a/tensorflow/python/estimator/export/export.py
+++ b/tensorflow/python/estimator/export/export.py
@@ -217,6 +217,29 @@ class TensorServingInputReceiver(
receiver_tensors_alternatives=receiver.receiver_tensors_alternatives)
+class UnsupervisedInputReceiver(ServingInputReceiver):
+ """A return type for a training_input_receiver_fn or eval_input_receiver_fn.
+
+ This differs from SupervisedInputReceiver in that it does not require a set
+ of labels.
+
+ The expected return values are:
+ features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or
+ `SparseTensor`, specifying the features to be passed to the model.
+ receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor`
+ or `SparseTensor`, specifying input nodes where this receiver expects to
+ be fed by default. Typically, this is a single placeholder expecting
+ serialized `tf.Example` protos.
+ """
+
+ def __new__(cls, features, receiver_tensors):
+ return super(UnsupervisedInputReceiver, cls).__new__(
+ cls,
+ features=features,
+ receiver_tensors=receiver_tensors,
+ receiver_tensors_alternatives=None)
+
+
class SupervisedInputReceiver(
collections.namedtuple('SupervisedInputReceiver',
['features', 'labels', 'receiver_tensors'])):
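
A hedged sketch of returning the new `UnsupervisedInputReceiver` from a
receiver fn (the feature spec and tensor names are assumptions):

    import tensorflow as tf
    from tensorflow.python.estimator.export import export

    def train_input_receiver_fn():
      serialized = tf.placeholder(tf.string, shape=[None], name='examples')
      features = tf.parse_example(
          serialized, {'x': tf.FixedLenFeature([1], tf.float32)})
      return export.UnsupervisedInputReceiver(
          features=features, receiver_tensors={'examples': serialized})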
@@ -288,13 +311,33 @@ def build_parsing_serving_input_receiver_fn(feature_spec,
def _placeholder_from_tensor(t, default_batch_size=None):
+ """Creates a placeholder that matches the dtype and shape of passed tensor.
+
+ Args:
+ t: Tensor or EagerTensor
+ default_batch_size: the number of query examples expected per batch.
+ Leave unset for variable batch size (recommended).
+
+ Returns:
+ Placeholder that matches the passed tensor.
+ """
batch_shape = tensor_shape.TensorShape([default_batch_size])
shape = batch_shape.concatenate(t.get_shape()[1:])
# Reuse the feature tensor's op name (t.op.name) for the placeholder,
# excluding the index from the tensor's name (t.name):
# t.name = "%s:%d" % (t.op.name, t._value_index)
- return array_ops.placeholder(dtype=t.dtype, shape=shape, name=t.op.name)
+ try:
+ name = t.op.name
+ except AttributeError:
+ # In Eager mode, tensors don't have ops or names, and while they do have
+ # IDs, those are not maintained across runs. The name here is used
+ # primarily for debugging, and is not critical to the placeholder.
+    # So, in order to make this Eager-compatible, fall back to a `None`
+    # name (TensorFlow will pick a default) if none is available.
+ name = None
+
+ return array_ops.placeholder(dtype=t.dtype, shape=shape, name=name)
def _placeholders_from_receiver_tensors_dict(input_vals,
diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py
index 20382a58d8..c17fc08f21 100644
--- a/tensorflow/python/estimator/export/export_output.py
+++ b/tensorflow/python/estimator/export/export_output.py
@@ -26,6 +26,7 @@ import six
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.util.tf_export import estimator_export
@@ -259,7 +260,10 @@ class _SupervisedOutput(ExportOutput):
loss: dict of Tensors or single Tensor representing calculated loss.
predictions: dict of Tensors or single Tensor representing model
predictions.
- metrics: dict of (metric_value, update_op) tuples, or a single tuple.
+ metrics: Dict of metric results keyed by name.
+ The values of the dict can be one of the following:
+        (1) an instance of the `Metric` class.
+        (2) a (metric_value, update_op) tuple, or a single such tuple.
metric_value must be a Tensor, and update_op must be a Tensor or Op.
Raises:
@@ -311,7 +315,11 @@ class _SupervisedOutput(ExportOutput):
Here, we separate out the tuples and create a dict with names to tensors.
Args:
- metrics: dict of (metric_value, update_op) tuples, or a single tuple.
+ metrics: Dict of metric results keyed by name.
+ The values of the dict can be one of the following:
+        (1) an instance of the `Metric` class.
+        (2) a (metric_value, update_op) tuple, or a single such tuple.
+ metric_value must be a Tensor, and update_op must be a Tensor or Op.
Returns:
dict of output_names to tensors
@@ -324,7 +332,13 @@ class _SupervisedOutput(ExportOutput):
metrics = {self.METRICS_NAME: metrics}
outputs = {}
- for key, (metric_val, metric_op) in metrics.items():
+ for key, value in metrics.items():
+ if isinstance(value, metrics_module.Metric):
+ metric_val = value.result()
+ assert len(value.updates) == 1 # We expect only one update op.
+ metric_op = value.updates[0]
+ else:
+ metric_val, metric_op = value
key = self._check_output_key(key, self.METRICS_NAME)
key = self._prefix_key(key, self.METRICS_NAME)
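
To make the two accepted metric formats concrete, a sketch using the
`metrics_module` import added above (the tensors are placeholders for
illustration):

    import tensorflow as tf

    mean_obj = metrics_module.Mean()            # (1) a `Metric` instance
    mean_obj.update_state(tf.constant([1.0]))
    metrics = {
        'mean_loss': mean_obj,
        'zeros': (tf.constant([0]), tf.no_op()),  # (2) (value, update_op)
    }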
@@ -397,7 +411,3 @@ class EvalOutput(_SupervisedOutput):
def _get_signature_def_fn(self):
return signature_def_utils.supervised_eval_signature_def
-
-
-
-
diff --git a/tensorflow/python/estimator/export/export_output_test.py b/tensorflow/python/estimator/export/export_output_test.py
index d94c764fd7..96ce0e580d 100644
--- a/tensorflow/python/estimator/export/export_output_test.py
+++ b/tensorflow/python/estimator/export/export_output_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
@@ -240,16 +241,19 @@ class SupervisedOutputTest(test.TestCase):
"""Tests that no errors are raised when provided outputs are valid."""
loss = {"my_loss": constant_op.constant([0])}
predictions = {u"output1": constant_op.constant(["foo"])}
- metrics = {"metrics": (constant_op.constant([0]),
- constant_op.constant([10])),
- "metrics2": (constant_op.constant([0]),
- constant_op.constant([10]))}
+ metric_obj = metrics_module.Mean()
+ metric_obj.update_state(constant_op.constant([0]))
+ metrics = {
+ "metrics": metric_obj,
+ "metrics2": (constant_op.constant([0]), constant_op.constant([10]))
+ }
outputter = MockSupervisedOutput(loss, predictions, metrics)
self.assertEqual(outputter.loss["loss/my_loss"], loss["my_loss"])
self.assertEqual(
outputter.predictions["predictions/output1"], predictions["output1"])
- self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0])
+ self.assertEqual(outputter.metrics["metrics/update_op"].name,
+ "metric_op_wrapper:0")
self.assertEqual(
outputter.metrics["metrics2/update_op"], metrics["metrics2"][1])
@@ -259,7 +263,8 @@ class SupervisedOutputTest(test.TestCase):
self.assertEqual(outputter.loss, {"loss": loss["my_loss"]})
self.assertEqual(
outputter.predictions, {"predictions": predictions["output1"]})
- self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0])
+ self.assertEqual(outputter.metrics["metrics/update_op"].name,
+ "metric_op_wrapper_1:0")
def test_supervised_outputs_none(self):
outputter = MockSupervisedOutput(
@@ -282,34 +287,56 @@ class SupervisedOutputTest(test.TestCase):
"""Tests that no errors are raised when provided outputs are valid."""
loss = {("my", "loss"): constant_op.constant([0])}
predictions = {(u"output1", "2"): constant_op.constant(["foo"])}
- metrics = {("metrics", "twice"): (constant_op.constant([0]),
- constant_op.constant([10]))}
+ metric_obj = metrics_module.Mean()
+ metric_obj.update_state(constant_op.constant([0]))
+ metrics = {
+ ("metrics", "1"):
+ metric_obj,
+ ("metrics", "2"): (constant_op.constant([0]),
+ constant_op.constant([10]))
+ }
outputter = MockSupervisedOutput(loss, predictions, metrics)
self.assertEqual(set(outputter.loss.keys()), set(["loss/my/loss"]))
self.assertEqual(set(outputter.predictions.keys()),
set(["predictions/output1/2"]))
- self.assertEqual(set(outputter.metrics.keys()),
- set(["metrics/twice/value", "metrics/twice/update_op"]))
+ self.assertEqual(
+ set(outputter.metrics.keys()),
+ set([
+ "metrics/1/value", "metrics/1/update_op", "metrics/2/value",
+ "metrics/2/update_op"
+ ]))
def test_supervised_outputs_no_prepend(self):
"""Tests that no errors are raised when provided outputs are valid."""
loss = {"loss": constant_op.constant([0])}
predictions = {u"predictions": constant_op.constant(["foo"])}
- metrics = {u"metrics": (constant_op.constant([0]),
- constant_op.constant([10]))}
+ metric_obj = metrics_module.Mean()
+ metric_obj.update_state(constant_op.constant([0]))
+ metrics = {
+ "metrics_1": metric_obj,
+ "metrics_2": (constant_op.constant([0]), constant_op.constant([10]))
+ }
outputter = MockSupervisedOutput(loss, predictions, metrics)
self.assertEqual(set(outputter.loss.keys()), set(["loss"]))
self.assertEqual(set(outputter.predictions.keys()), set(["predictions"]))
- self.assertEqual(set(outputter.metrics.keys()),
- set(["metrics/value", "metrics/update_op"]))
+ self.assertEqual(
+ set(outputter.metrics.keys()),
+ set([
+ "metrics_1/value", "metrics_1/update_op", "metrics_2/update_op",
+ "metrics_2/value"
+ ]))
def test_train_signature_def(self):
loss = {"my_loss": constant_op.constant([0])}
predictions = {u"output1": constant_op.constant(["foo"])}
- metrics = {"metrics": (constant_op.constant([0]),
- constant_op.constant([10]))}
+ metric_obj = metrics_module.Mean()
+ metric_obj.update_state(constant_op.constant([0]))
+ metrics = {
+ "metrics_1": metric_obj,
+ "metrics_2": (constant_op.constant([0]), constant_op.constant([10]))
+ }
outputter = export_output_lib.TrainOutput(loss, predictions, metrics)
@@ -318,7 +345,8 @@ class SupervisedOutputTest(test.TestCase):
sig_def = outputter.as_signature_def(receiver)
self.assertTrue("loss/my_loss" in sig_def.outputs)
- self.assertTrue("metrics/value" in sig_def.outputs)
+ self.assertTrue("metrics_1/value" in sig_def.outputs)
+ self.assertTrue("metrics_2/value" in sig_def.outputs)
self.assertTrue("predictions/output1" in sig_def.outputs)
self.assertTrue("features" in sig_def.inputs)
@@ -337,18 +365,33 @@ class SupervisedOutputTest(test.TestCase):
self.assertTrue("predictions/output1" in sig_def.outputs)
self.assertTrue("features" in sig_def.inputs)
- def test_metric_op_is_operation(self):
+ def test_metric_op_is_tensor(self):
"""Tests that ops.Operation is wrapped by a tensor for metric_ops."""
loss = {"my_loss": constant_op.constant([0])}
predictions = {u"output1": constant_op.constant(["foo"])}
- metrics = {"metrics": (constant_op.constant([0]), control_flow_ops.no_op())}
+ metric_obj = metrics_module.Mean()
+ metric_obj.update_state(constant_op.constant([0]))
+ metrics = {
+ "metrics_1": metric_obj,
+ "metrics_2": (constant_op.constant([0]), control_flow_ops.no_op())
+ }
outputter = MockSupervisedOutput(loss, predictions, metrics)
- self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0])
- self.assertEqual(
- outputter.metrics["metrics/update_op"].name, "metric_op_wrapper:0")
+
+ self.assertTrue(outputter.metrics["metrics_1/update_op"].name.startswith(
+ "metric_op_wrapper"))
+ self.assertTrue(
+ isinstance(outputter.metrics["metrics_1/update_op"], ops.Tensor))
self.assertTrue(
- isinstance(outputter.metrics["metrics/update_op"], ops.Tensor))
+ isinstance(outputter.metrics["metrics_1/value"], ops.Tensor))
+
+ self.assertEqual(outputter.metrics["metrics_2/value"],
+ metrics["metrics_2"][0])
+ self.assertTrue(outputter.metrics["metrics_2/update_op"].name.startswith(
+ "metric_op_wrapper"))
+ self.assertTrue(
+ isinstance(outputter.metrics["metrics_2/update_op"], ops.Tensor))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py
index 1d475adb43..3eed1ab163 100644
--- a/tensorflow/python/estimator/export/export_test.py
+++ b/tensorflow/python/estimator/export/export_test.py
@@ -163,6 +163,29 @@ class ServingInputReceiverTest(test_util.TensorFlowTestCase):
_ = export.ServingInputReceiver(feature, receiver_tensor)
+class UnsupervisedInputReceiverTest(test_util.TensorFlowTestCase):
+
+ # Since this is basically a wrapper around ServingInputReceiver, we only
+ # have a simple sanity check to ensure that it works.
+
+ def test_unsupervised_input_receiver_constructor(self):
+ """Tests that no errors are raised when input is expected."""
+ features = {
+ "feature0":
+ constant_op.constant([0]),
+ u"feature1":
+ constant_op.constant([1]),
+ "feature2":
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
+ }
+ receiver_tensors = {
+ "example0": array_ops.placeholder(dtypes.string, name="example0"),
+ u"example1": array_ops.placeholder(dtypes.string, name="example1"),
+ }
+ export.UnsupervisedInputReceiver(features, receiver_tensors)
+
+
class SupervisedInputReceiverTest(test_util.TensorFlowTestCase):
def test_input_receiver_constructor(self):
@@ -393,6 +416,7 @@ class ExportTest(test_util.TensorFlowTestCase):
tensor_shape.unknown_shape(),
v.receiver_tensors["feature_2"].shape)
+ @test_util.run_in_graph_and_eager_modes
def test_build_raw_serving_input_receiver_fn(self):
features = {"feature_1": constant_op.constant(["hello"]),
"feature_2": constant_op.constant([42])}
@@ -411,6 +435,7 @@ class ExportTest(test_util.TensorFlowTestCase):
dtypes.int32,
serving_input_receiver.receiver_tensors["feature_2"].dtype)
+ @test_util.run_in_graph_and_eager_modes
def test_build_raw_supervised_input_receiver_fn(self):
features = {"feature_1": constant_op.constant(["hello"]),
"feature_2": constant_op.constant([42])}
@@ -431,6 +456,7 @@ class ExportTest(test_util.TensorFlowTestCase):
self.assertEqual(
dtypes.int32, input_receiver.receiver_tensors["feature_2"].dtype)
+ @test_util.run_in_graph_and_eager_modes
def test_build_raw_supervised_input_receiver_fn_raw_tensors(self):
features = {"feature_1": constant_op.constant(["hello"]),
"feature_2": constant_op.constant([42])}
@@ -454,6 +480,7 @@ class ExportTest(test_util.TensorFlowTestCase):
self.assertEqual(set(["input", "label"]),
set(input_receiver.receiver_tensors.keys()))
+ @test_util.run_in_graph_and_eager_modes
def test_build_raw_supervised_input_receiver_fn_batch_size(self):
features = {"feature_1": constant_op.constant(["hello"]),
"feature_2": constant_op.constant([42])}
@@ -466,6 +493,7 @@ class ExportTest(test_util.TensorFlowTestCase):
self.assertEqual([10], input_receiver.receiver_tensors["feature_1"].shape)
self.assertEqual([10], input_receiver.features["feature_1"].shape)
+ @test_util.run_in_graph_and_eager_modes
def test_build_raw_supervised_input_receiver_fn_overlapping_keys(self):
features = {"feature_1": constant_op.constant(["hello"]),
"feature_2": constant_op.constant([42])}
@@ -474,6 +502,7 @@ class ExportTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
export.build_raw_supervised_input_receiver_fn(features, labels)
+ @test_util.run_in_graph_and_eager_modes
def test_build_supervised_input_receiver_fn_from_input_fn(self):
def dummy_input_fn():
return ({"x": constant_op.constant([[1], [1]]),
@@ -491,6 +520,7 @@ class ExportTest(test_util.TensorFlowTestCase):
self.assertEqual(set(["x", "y", "label"]),
set(input_receiver.receiver_tensors.keys()))
+ @test_util.run_in_graph_and_eager_modes
def test_build_supervised_input_receiver_fn_from_input_fn_args(self):
def dummy_input_fn(feature_key="x"):
return ({feature_key: constant_op.constant([[1], [1]]),
diff --git a/tensorflow/python/estimator/exporter_test.py b/tensorflow/python/estimator/exporter_test.py
index c4b006955c..fcccfbde7a 100644
--- a/tensorflow/python/estimator/exporter_test.py
+++ b/tensorflow/python/estimator/exporter_test.py
@@ -323,6 +323,43 @@ class LatestExporterTest(test.TestCase):
self.assertTrue(gfile.Exists(export_dir_3))
self.assertTrue(gfile.Exists(export_dir_4))
+ def test_garbage_collect_exports_with_trailing_delimiter(self):
+ export_dir_base = tempfile.mkdtemp() + "export/"
+ gfile.MkDir(export_dir_base)
+ export_dir_1 = _create_test_export_dir(export_dir_base)
+ export_dir_2 = _create_test_export_dir(export_dir_base)
+ export_dir_3 = _create_test_export_dir(export_dir_base)
+ export_dir_4 = _create_test_export_dir(export_dir_base)
+
+ self.assertTrue(gfile.Exists(export_dir_1))
+ self.assertTrue(gfile.Exists(export_dir_2))
+ self.assertTrue(gfile.Exists(export_dir_3))
+ self.assertTrue(gfile.Exists(export_dir_4))
+
+ def _serving_input_receiver_fn():
+ return array_ops.constant([1]), None
+
+ exporter = exporter_lib.LatestExporter(
+ name="latest_exporter",
+ serving_input_receiver_fn=_serving_input_receiver_fn,
+ exports_to_keep=1)
+ estimator = test.mock.Mock(spec=estimator_lib.Estimator)
+    # Garbage collect all but the most recent export (exports_to_keep=1),
+    # where recency is determined based on the timestamp directory names.
+ with test.mock.patch.object(gfile, "ListDirectory") as mock_list_directory:
+ mock_list_directory.return_value = [
+          os.path.basename(export_dir_1) + "/",
+          os.path.basename(export_dir_2) + "/",
+          os.path.basename(export_dir_3) + "/",
+          os.path.basename(export_dir_4) + "/",
+ ]
+ exporter.export(estimator, export_dir_base, None, None, False)
+
+ self.assertFalse(gfile.Exists(export_dir_1))
+ self.assertFalse(gfile.Exists(export_dir_2))
+ self.assertFalse(gfile.Exists(export_dir_3))
+ self.assertTrue(gfile.Exists(export_dir_4))
+
def _create_test_export_dir(export_dir_base):
export_dir = _get_timestamped_export_dir(export_dir_base)
diff --git a/tensorflow/python/estimator/gc.py b/tensorflow/python/estimator/gc.py
index 9f8a463ec1..03ad33dd6b 100644
--- a/tensorflow/python/estimator/gc.py
+++ b/tensorflow/python/estimator/gc.py
@@ -201,9 +201,11 @@ def _get_paths(base_dir, parser):
raw_paths = gfile.ListDirectory(base_dir)
paths = []
for r in raw_paths:
- p = parser(Path(os.path.join(compat.as_str_any(base_dir),
- compat.as_str_any(r)),
- None))
+    # ListDirectory() returns paths with a trailing "/" if base_dir is a GCS URL.
+    r = compat.as_str_any(r)
+    if r[-1] == '/':
+      r = r[:-1]
+ p = parser(Path(os.path.join(compat.as_str_any(base_dir), r), None))
if p:
paths.append(p)
return sorted(paths)
diff --git a/tensorflow/python/estimator/gc_test.py b/tensorflow/python/estimator/gc_test.py
index 2cbdd511d1..53c3d4ca2a 100644
--- a/tensorflow/python/estimator/gc_test.py
+++ b/tensorflow/python/estimator/gc_test.py
@@ -140,6 +140,17 @@ class GcTest(test_util.TensorFlowTestCase):
gfile.MakeDirs(os.path.join(compat.as_str_any(base_dir), "42"))
gc._get_paths(base_dir, _create_parser(base_dir))
+ def testGcsDirWithSeparator(self):
+ base_dir = "gs://bucket/foo"
+ with test.mock.patch.object(gfile, "ListDirectory") as mock_list_directory:
+      # gfile.ListDirectory returns directory names with a trailing separator '/'.
+ mock_list_directory.return_value = ["0/", "1/"]
+ self.assertEqual(
+ gc._get_paths(base_dir, _create_parser(base_dir)),
+ [
+ gc.Path(os.path.join(base_dir, "0"), 0),
+ gc.Path(os.path.join(base_dir, "1"), 1)
+ ])
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/estimator/inputs/numpy_io_test.py b/tensorflow/python/estimator/inputs/numpy_io_test.py
index 4e7b00b307..632908415f 100644
--- a/tensorflow/python/estimator/inputs/numpy_io_test.py
+++ b/tensorflow/python/estimator/inputs/numpy_io_test.py
@@ -42,7 +42,7 @@ class NumpyIoTest(test.TestCase):
x = {'a': a, 'b': b}
y = np.arange(-32, -28)
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
features, target = input_fn()
@@ -68,7 +68,7 @@ class NumpyIoTest(test.TestCase):
x = {'a': a, 'b': b}
y = np.arange(-32, -30)
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=128, shuffle=False, num_epochs=2)
features, target = input_fn()
@@ -93,7 +93,7 @@ class NumpyIoTest(test.TestCase):
x = {'a': a, 'b': b}
y = np.arange(-32, -28)
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=0)
features, target = input_fn()
@@ -114,7 +114,7 @@ class NumpyIoTest(test.TestCase):
x = {'a': a, 'b': b}
y = np.arange(-32, -27)
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=batch_size, shuffle=False, num_epochs=1)
features, target = input_fn()
@@ -150,7 +150,7 @@ class NumpyIoTest(test.TestCase):
x = {'a': a, 'b': b}
y = np.arange(-32, -29)
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=batch_size, shuffle=False, num_epochs=3)
features, target = input_fn()
@@ -196,7 +196,7 @@ class NumpyIoTest(test.TestCase):
x = {'a': a, 'b': b}
y = np.arange(-32, -28)
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=batch_size, shuffle=False, num_epochs=1)
features, target = input_fn()
@@ -221,7 +221,7 @@ class NumpyIoTest(test.TestCase):
x = {'a': a, 'b': b}
y = np.arange(-32, -30)
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
features, target = input_fn()
@@ -240,7 +240,7 @@ class NumpyIoTest(test.TestCase):
def testNumpyInputFnWithXAsNonDict(self):
x = list(range(32, 36))
y = np.arange(4)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'x must be a dict or array'):
failing_input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -249,7 +249,7 @@ class NumpyIoTest(test.TestCase):
def testNumpyInputFnWithXIsEmptyDict(self):
x = {}
y = np.arange(4)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, 'x cannot be an empty'):
failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
failing_input_fn()
@@ -257,7 +257,7 @@ class NumpyIoTest(test.TestCase):
def testNumpyInputFnWithXIsEmptyArray(self):
x = np.array([[], []])
y = np.arange(4)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, 'x cannot be an empty'):
failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
failing_input_fn()
@@ -268,7 +268,7 @@ class NumpyIoTest(test.TestCase):
x = {'a': a, 'b': b}
y = None
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
features_tensor = input_fn()
@@ -291,7 +291,7 @@ class NumpyIoTest(test.TestCase):
def testNumpyInputFnWithNonBoolShuffle(self):
x = np.arange(32, 36)
y = np.arange(4)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError,
'shuffle must be provided and explicitly '
'set as boolean'):
@@ -303,7 +303,7 @@ class NumpyIoTest(test.TestCase):
x = {'__target_key__': array}
y = np.arange(4)
- with self.test_session():
+ with self.cached_session():
input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
input_fn()
@@ -318,7 +318,7 @@ class NumpyIoTest(test.TestCase):
x_mismatch_length = {'a': np.arange(1), 'b': b}
y_longer_length = np.arange(10)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
ValueError, 'Length of tensors in x and y is mismatched.'):
failing_input_fn = numpy_io.numpy_input_fn(
@@ -341,7 +341,7 @@ class NumpyIoTest(test.TestCase):
x = {'a': a, 'b': b}
y = {'y1': np.arange(-32, -28), 'y2': np.arange(32, 28, -1)}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
features_tensor, targets_tensor = input_fn()
@@ -369,7 +369,7 @@ class NumpyIoTest(test.TestCase):
b = np.arange(32, 36)
x = {'a': a, 'b': b}
y = {}
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, 'y cannot be empty'):
failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
failing_input_fn()
@@ -379,7 +379,7 @@ class NumpyIoTest(test.TestCase):
b = np.arange(32, 36)
x = {'a': a, 'b': b}
y = {'y1': np.arange(-32, -28), 'a': a, 'y2': np.arange(32, 28, -1), 'b': b}
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
ValueError, '2 duplicate keys are found in both x and y'):
failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
diff --git a/tensorflow/python/estimator/inputs/pandas_io_test.py b/tensorflow/python/estimator/inputs/pandas_io_test.py
index 6f13bc95d2..9e69fc72dc 100644
--- a/tensorflow/python/estimator/inputs/pandas_io_test.py
+++ b/tensorflow/python/estimator/inputs/pandas_io_test.py
@@ -102,7 +102,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ProducesExpectedOutputs(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -116,7 +116,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFnWhenYIsDataFrame_ProducesExpectedOutput(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrameWithYAsDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -131,7 +131,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFnYIsDataFrame_HandlesOverlappingColumns(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrameWithYAsDataFrame()
y = y.rename(columns={'a_target': 'a', 'b_target': 'b'})
input_fn = pandas_io.pandas_input_fn(
@@ -147,7 +147,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFnYIsDataFrame_HandlesOverlappingColumnsInTargets(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrameWithYAsDataFrame()
y = y.rename(columns={'a_target': 'a', 'b_target': 'a_n'})
input_fn = pandas_io.pandas_input_fn(
@@ -163,7 +163,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
index = np.arange(100, 102)
a = np.arange(2)
b = np.arange(32, 34)
@@ -191,7 +191,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ProducesOutputsWhenDataSizeNotDividedByBatchSize(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
index = np.arange(100, 105)
a = np.arange(5)
b = np.arange(32, 37)
@@ -230,7 +230,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_OnlyX(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, _ = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y=None, batch_size=2, shuffle=False, num_epochs=1)
@@ -243,7 +243,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ExcludesIndex(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -266,7 +266,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_RespectsEpoch_NoShuffle(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=4, shuffle=False, num_epochs=1)
@@ -276,7 +276,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_RespectsEpoch_WithShuffle(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=4, shuffle=True, num_epochs=1)
@@ -286,7 +286,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_RespectsEpoch_WithShuffleAutosize(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=True, queue_capacity=None, num_epochs=2)
@@ -297,7 +297,7 @@ class PandasIoTest(test.TestCase):
if not HAS_PANDAS:
return
x, y = self.makeTestDataFrame()
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=3, shuffle=False, num_epochs=1)
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index e4ce5339d0..6b2765be82 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -33,9 +33,6 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import models
from tensorflow.python.keras import optimizers
-from tensorflow.python.keras.engine.base_layer import Layer
-from tensorflow.python.keras.engine.network import Network
-from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_module
@@ -47,8 +44,6 @@ from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
-from tensorflow.python.training.checkpointable import base as checkpointable
-from tensorflow.python.training.checkpointable import data_structures
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
@@ -92,184 +87,78 @@ def _any_weight_initialized(keras_model):
return False
-def _create_ordered_io(keras_model, estimator_io, is_input=True):
- """Create a list of tensors from IO dictionary based on Keras IO order.
+def _convert_estimator_io_to_keras(keras_model, features, labels):
+ """Converts estimator features and labels to keras input and target tensors.
Args:
- keras_model: An instance of compiled keras model.
- estimator_io: The features or labels (dict or plain array) from model_fn.
- is_input: True if dictionary is for inputs.
+ keras_model: a compiled `tf.keras.Model` instance, used to determine the
+ order of the returned lists.
+ features: Dict of tensors or `None`.
+ labels: Dict of tensors, a single tensor, or `None`.
Returns:
- A list of tensors based on Keras IO order.
-
- Raises:
- ValueError: if dictionary keys cannot be found in Keras model input_names
- or output_names.
- """
- if isinstance(estimator_io, (list, tuple)):
- # Case currently not supported by most built-in input_fn,
- # but it's good to have for sanity
- return [_convert_tensor(x) for x in estimator_io]
- elif isinstance(estimator_io, dict):
- if is_input:
- if keras_model._is_graph_network:
- keras_io_names = keras_model.input_names
- else:
- keras_io_names = [
- 'input_%d' % i for i in range(1, len(estimator_io) + 1)]
- else:
- if keras_model._is_graph_network:
- keras_io_names = keras_model.output_names
- else:
- keras_io_names = [
- 'output_%d' % i for i in range(1, len(estimator_io) + 1)]
-
- for key in estimator_io:
- if key not in keras_io_names:
- raise ValueError(
- 'Cannot find %s with name "%s" in Keras Model. '
- 'It needs to match one '
- 'of the following: %s' % ('input' if is_input else 'output', key,
- ', '.join(keras_io_names)))
- tensors = [_convert_tensor(estimator_io[io_name])
- for io_name in keras_io_names]
- return tensors
- else:
- # Plain array.
- return _convert_tensor(estimator_io)
-
-
-def _in_place_subclassed_model_reset(model):
- """Substitute for model cloning that works for subclassed models.
-
- Subclassed models cannot be cloned because their topology is not serializable.
- To "instantiate" an identical model in a new TF graph, we reuse the original
- model object, but we clear its state.
-
- After calling this function on a model instance, you can use the model
- instance as if it were a model clone (in particular you can use it in a new
- graph).
-
- This method clears the state of the input model. It is thus destructive.
- However the original state can be restored fully by calling
- `_in_place_subclassed_model_state_restoration`.
-
- Args:
- model: Instance of a Keras model created via subclassing.
-
- Raises:
- ValueError: In case the model uses a subclassed model as inner layer.
+ Tuple of (
+ list of input tensors or `None`,
+ list of target tensors or `None`)
+ The order of tensors is determined by the order set in the keras model.
"""
- assert not model._is_graph_network # Only makes sense for subclassed networks
- # Retrieve all layers tracked by the model as well as their attribute names
- attributes_cache = {}
- for name in dir(model):
- try:
- value = getattr(model, name)
- except (AttributeError, ValueError, TypeError):
- continue
- if isinstance(value, Layer):
- attributes_cache[name] = value
- assert value in model._layers
- elif isinstance(value, (list, tuple)) and name not in ('layers', '_layers'):
- # Handle case: list/tuple of layers (also tracked by the Network API).
- if value and all(isinstance(val, Layer) for val in value):
- raise ValueError('We do not support the use of list-of-layers '
- 'attributes in subclassed models used with '
- '`model_to_estimator` at this time. Found list '
- 'model: %s' % name)
-
- # Replace layers on the model with fresh layers
- layers_to_names = {value: key for key, value in attributes_cache.items()}
- original_layers = model._layers[:]
- model._layers = data_structures.NoDependency([])
- for layer in original_layers: # We preserve layer order.
- config = layer.get_config()
- # This will not work for nested subclassed models used as layers.
- # This would be theoretically possible to support, but would add complexity.
- # Only do it if users complain.
- if isinstance(layer, Network) and not layer._is_graph_network:
- raise ValueError('We do not support the use of nested subclassed models '
- 'in `model_to_estimator` at this time. Found nested '
- 'model: %s' % layer)
- fresh_layer = layer.__class__.from_config(config)
- name = layers_to_names[layer]
- setattr(model, name, fresh_layer)
-
- # Cache original model build attributes (in addition to layers)
- if (not hasattr(model, '_original_attributes_cache') or
- model._original_attributes_cache is None):
- if model.built:
- attributes_to_cache = [
- 'inputs',
- 'outputs',
- '_feed_outputs',
- '_feed_output_names',
- '_feed_output_shapes',
- '_feed_loss_fns',
- 'loss_weights_list',
- 'targets',
- '_feed_targets',
- 'sample_weight_modes',
- 'weighted_metrics',
- 'metrics_names',
- 'metrics_tensors',
- 'metrics_updates',
- 'stateful_metric_names',
- 'total_loss',
- 'sample_weights',
- '_feed_sample_weights',
- 'train_function',
- 'test_function',
- 'predict_function',
- '_collected_trainable_weights',
- '_feed_inputs',
- '_feed_input_names',
- '_feed_input_shapes',
- 'optimizer',
- ]
- for name in attributes_to_cache:
- attributes_cache[name] = getattr(model, name)
- model._original_attributes_cache = data_structures.NoDependency(
- attributes_cache)
- # Reset built state
- model.built = False
- model.inputs = None
- model.outputs = None
-
-
-def _in_place_subclassed_model_state_restoration(model):
- """Restores the original state of a model after it was "reset".
-
- This undoes this action of `_in_place_subclassed_model_reset`.
- Args:
- model: Instance of a Keras model created via subclassing, on which
- `_in_place_subclassed_model_reset` was previously called.
- """
- assert not model._is_graph_network
- # Restore layers and build attributes
- if (hasattr(model, '_original_attributes_cache') and
- model._original_attributes_cache is not None):
- # Models have sticky attribute assignment, so we want to be careful to add
- # back the previous attributes and track Layers by their original names
- # without adding dependencies on "utility" attributes which Models exempt
- # when they're constructed.
- model._layers = data_structures.NoDependency([])
- for name, value in model._original_attributes_cache.items():
- if not isinstance(value, checkpointable.CheckpointableBase):
- # If this value is not already checkpointable, it's probably that way
- # for a reason; we don't want to start tracking data structures that the
- # original Model didn't.
- value = data_structures.NoDependency(value)
- setattr(model, name, value)
- model._original_attributes_cache = None
- else:
- # Restore to the state of a never-called model.
- model.built = False
- model.inputs = None
- model.outputs = None
+ def _to_ordered_tensor_list(obj, key_order, obj_name, order_name):
+ """Convert obj to an ordered list of tensors.
+
+ Args:
+ obj: List, dict, or single tensor. May be `None`.
+ key_order: List of strings with the order to return (used if obj is a
+ dict).
+      obj_name: String name of the object (e.g. "features" or "labels").
+      order_name: String name of the key order (e.g. "inputs" or "outputs").
+
+ Returns:
+ List of tensors, or `None`
+
+ Raises:
+ KeyError: If obj has invalid keys.
+ """
+ if obj is None:
+ return None
+ elif isinstance(obj, (list, tuple)):
+ return [_convert_tensor(x) for x in obj]
+ elif isinstance(obj, dict):
+ # Ensure that the obj keys and keys in key_order are exactly the same.
+ different_keys = set(obj.keys()) ^ set(key_order)
+
+ if different_keys:
+ raise KeyError(
+ 'The dictionary passed into {obj_name} does not have the expected '
+ '{order_name} keys defined in the keras model.'
+ '\n\tExpected keys: {order_keys}'
+ '\n\t{obj_name} keys: {obj_keys}'
+ '\n\tDifference: {different_keys}'.format(
+ order_name=order_name, order_keys=set(key_order),
+ obj_name=obj_name, obj_keys=set(obj.keys()),
+ different_keys=different_keys))
+
+ return [_convert_tensor(obj[key]) for key in key_order]
+ else: # Assume obj is a tensor.
+ return [_convert_tensor(obj)]
+
+ input_names = None
+ output_names = None
+ if isinstance(features, dict):
+ input_names = (
+ keras_model.input_names if keras_model._is_graph_network else
+ ['input_%d' % i for i in range(1, len(features) + 1)])
+ if isinstance(labels, dict):
+ output_names = (
+ keras_model.output_names if keras_model._is_graph_network else
+ ['output_%d' % i for i in range(1, len(labels) + 1)])
+
+ input_tensors = _to_ordered_tensor_list(
+ features, input_names, 'features', 'inputs')
+ target_tensors = _to_ordered_tensor_list(
+ labels, output_names, 'labels', 'outputs')
+
+ return input_tensors, target_tensors
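
The dictionary branch above hinges on a set symmetric difference: the keys passed in must match the model's input (or output) names exactly, and whatever falls on either side of the mismatch is surfaced in the `KeyError`. A standalone sketch of just that check, with hypothetical names:

    def check_keys(obj, key_order, obj_name, order_name):
      # Symmetric difference: keys missing from either side are an error.
      different_keys = set(obj) ^ set(key_order)
      if different_keys:
        raise KeyError(
            'The dictionary passed into {} does not have the expected {} '
            'keys. Difference: {}'.format(obj_name, order_name,
                                          different_keys))

    check_keys({'input_1': 0, 'input_2': 1},
               ['input_1', 'input_2'], 'features', 'inputs')  # passes
    try:
      check_keys({'input_1': 0, 'bogus': 1},
                 ['input_1', 'input_2'], 'features', 'inputs')
    except KeyError as e:
      print(e)  # Difference: {'bogus', 'input_2'} (set order may vary)
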
def _clone_and_build_model(mode,
@@ -289,61 +178,62 @@ def _clone_and_build_model(mode,
Returns:
The newly built model.
"""
- # Set to True during training, False for inference.
+ # Set to True during training, False for inference or testing.
K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN)
+ input_tensors, target_tensors = _convert_estimator_io_to_keras(
+ keras_model, features, labels)
- # Get list of inputs.
- if features is None:
- input_tensors = None
- else:
- input_tensors = _create_ordered_io(keras_model,
- estimator_io=features,
- is_input=True)
- # Get list of outputs.
- if labels is None:
- target_tensors = None
- elif isinstance(labels, dict):
- target_tensors = _create_ordered_io(keras_model,
- estimator_io=labels,
- is_input=False)
- else:
- target_tensors = [
- _convert_tensor(labels)
- ]
+ compile_clone = (mode != model_fn_lib.ModeKeys.PREDICT)
- if keras_model._is_graph_network:
- if custom_objects:
- with CustomObjectScope(custom_objects):
- model = models.clone_model(keras_model, input_tensors=input_tensors)
- else:
- model = models.clone_model(keras_model, input_tensors=input_tensors)
- else:
- model = keras_model
- _in_place_subclassed_model_reset(model)
- if input_tensors is not None:
- model._set_inputs(input_tensors)
-
- # Compile/Build model
- if mode is model_fn_lib.ModeKeys.PREDICT:
- if isinstance(model, models.Sequential):
- model.build()
- else:
- if isinstance(keras_model.optimizer, optimizers.TFOptimizer):
- optimizer = keras_model.optimizer
- else:
- optimizer_config = keras_model.optimizer.get_config()
- optimizer = keras_model.optimizer.__class__.from_config(optimizer_config)
- optimizer.iterations = training_util.get_or_create_global_step()
+ global_step = None
+ if compile_clone:
+    # Set iterations to the global step created by tf.train.create_global_step(),
+    # which is run automatically by the estimator framework.
+ global_step = training_util.get_or_create_global_step()
+ K.track_variable(global_step)
- model.compile(
- optimizer,
- keras_model.loss,
- metrics=keras_model.metrics,
- loss_weights=keras_model.loss_weights,
- sample_weight_mode=keras_model.sample_weight_mode,
- weighted_metrics=keras_model.weighted_metrics,
- target_tensors=target_tensors)
- return model
+ clone = models.clone_and_build_model(
+ keras_model, input_tensors, target_tensors, custom_objects,
+ compile_clone=compile_clone,
+ in_place_reset=(not keras_model._is_graph_network),
+ optimizer_iterations=global_step)
+
+ return clone
+
+
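
The `optimizer_iterations=global_step` argument hands `models.clone_and_build_model` the same wiring the removed block above did inline: the rebuilt optimizer's per-batch iteration counter is aliased to the estimator-managed global step, so executing the train op advances both together. For reference, the inline version this replaces was:

    from tensorflow.python.training import training_util

    # Rebuild the user's optimizer from its config, then point its
    # iteration counter at the estimator's global step.
    optimizer_config = keras_model.optimizer.get_config()
    optimizer = keras_model.optimizer.__class__.from_config(optimizer_config)
    optimizer.iterations = training_util.get_or_create_global_step()
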
+def _convert_keras_metrics_to_estimator(model):
+ """Convert metrics from a Keras model to ops used by the Estimator framework.
+
+ Args:
+ model: A `tf.keras.Model` object.
+
+ Returns:
+ Dictionary mapping metric names to tuples of (value, update) ops. May return
+ `None` if the model does not contain any metrics.
+ """
+ if not getattr(model, 'metrics', None):
+ return None
+
+ # TODO(psv/fchollet): support stateful metrics
+ eval_metric_ops = {}
+ # When each metric maps to an output
+ if isinstance(model.metrics, dict):
+ for i, output_name in enumerate(model.metrics.keys()):
+ metric_name = model.metrics[output_name]
+ if callable(metric_name):
+ metric_name = metric_name.__name__
+ # When some outputs use the same metric
+ if list(model.metrics.values()).count(metric_name) > 1:
+ metric_name += '_' + output_name
+ eval_metric_ops[metric_name] = metrics_module.mean(
+ model.metrics_tensors[i - len(model.metrics)])
+ else:
+ for i, metric_name in enumerate(model.metrics):
+ if callable(metric_name):
+ metric_name = metric_name.__name__
+ eval_metric_ops[metric_name] = metrics_module.mean(
+ model.metrics_tensors[i])
+ return eval_metric_ops
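
When `model.metrics` is a dict, several outputs can name the same metric; the helper keeps the resulting `eval_metric_ops` keys unique by suffixing the output name. A pure-Python sketch of just the naming rule (the real helper also wraps each value in `metrics_module.mean` over `model.metrics_tensors`):

    model_metrics = {'dense_2': 'mse', 'dense_3': 'mse', 'dense_4': 'mae'}
    names = []
    for output_name, metric_name in model_metrics.items():
      # Only duplicated metric names get the disambiguating suffix.
      if list(model_metrics.values()).count(metric_name) > 1:
        metric_name += '_' + output_name
      names.append(metric_name)
    print(sorted(names))  # ['mae', 'mse_dense_2', 'mse_dense_3']
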
def _create_keras_model_fn(keras_model, custom_objects=None):
@@ -395,26 +285,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
model._make_test_function() # pylint: disable=protected-access
loss = model.total_loss
- if model.metrics:
- # TODO(psv/fchollet): support stateful metrics
- eval_metric_ops = {}
- # When each metric maps to an output
- if isinstance(model.metrics, dict):
- for i, output_name in enumerate(model.metrics.keys()):
- metric_name = model.metrics[output_name]
- if callable(metric_name):
- metric_name = metric_name.__name__
- # When some outputs use the same metric
- if list(model.metrics.values()).count(metric_name) > 1:
- metric_name += '_' + output_name
- eval_metric_ops[metric_name] = metrics_module.mean(
- model.metrics_tensors[i - len(model.metrics)])
- else:
- for i, metric_name in enumerate(model.metrics):
- if callable(metric_name):
- metric_name = metric_name.__name__
- eval_metric_ops[metric_name] = metrics_module.mean(
- model.metrics_tensors[i])
+ eval_metric_ops = _convert_keras_metrics_to_estimator(model)
# Set train_op only during train.
if mode is model_fn_lib.ModeKeys.TRAIN:
@@ -423,7 +294,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
if not model._is_graph_network:
# Reset model state to original state,
# to avoid `model_fn` being destructive for the initial model argument.
- _in_place_subclassed_model_state_restoration(keras_model)
+ models.in_place_subclassed_model_state_restoration(keras_model)
return model_fn_lib.EstimatorSpec(
mode=mode,
predictions=predictions,
@@ -487,8 +358,9 @@ def model_to_estimator(keras_model=None,
config=None):
"""Constructs an `Estimator` instance from given keras model.
- For usage example, please see
- @{$guide/estimators$creating_estimators_from_keras_models}.
+ For usage example, please see:
+ [Creating estimators from Keras
+ Models](https://tensorflow.org/guide/estimators#model_to_estimator).
Args:
keras_model: A compiled Keras model object. This argument is mutually
diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py
index 332e385726..7e5a0c80a7 100644
--- a/tensorflow/python/estimator/keras_test.py
+++ b/tensorflow/python/estimator/keras_test.py
@@ -26,20 +26,23 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import keras
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import keras as keras_lib
+from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config as run_config_lib
-from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.optimizers import SGD
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
from tensorflow.python.ops.parsing_ops import gen_parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import rmsprop
from tensorflow.python.training import session_run_hook
+from tensorflow.python.training import training_util
try:
@@ -90,6 +93,15 @@ def simple_subclassed_model():
return SimpleModel()
+def gen_input_fn(x, y=None, batch_size=128, num_epochs=1, shuffle=False):
+ def input_fn():
+ ds = dataset_ops.Dataset.from_tensor_slices((x, y) if y is not None else x)
+ if shuffle:
+ ds = ds.shuffle(1000)
+ return ds.repeat(num_epochs).batch(batch_size)
+ return input_fn
+
+
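
A usage sketch for the `gen_input_fn` helper above, which stands in for `numpy_io.numpy_input_fn` in these tests with a `tf.data`-backed input_fn; the array shapes here are illustrative:

    import numpy as np

    x = np.random.random((100, 14)).astype(np.float32)
    y = np.random.random((100, 2)).astype(np.float32)
    # num_epochs=None repeats indefinitely, as with numpy_input_fn.
    train_input_fn = gen_input_fn(x, y, batch_size=16, num_epochs=None)
    dataset = train_input_fn()  # tf.data.Dataset of (features, labels) batches
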
def get_resource_for_simple_model(model_type='sequential',
is_evaluate=False,):
if model_type == 'sequential':
@@ -117,19 +129,19 @@ def get_resource_for_simple_model(model_type='sequential',
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)
- train_input_fn = numpy_io.numpy_input_fn(
+ train_input_fn = gen_input_fn(
x=randomize_io_type(x_train, input_name),
y=randomize_io_type(y_train, output_name),
shuffle=False,
num_epochs=None,
batch_size=16)
- evaluate_input_fn = numpy_io.numpy_input_fn(
+ evaluate_input_fn = gen_input_fn(
x=randomize_io_type(x_test, input_name),
y=randomize_io_type(y_test, output_name),
num_epochs=1, shuffle=False)
- predict_input_fn = numpy_io.numpy_input_fn(
+ predict_input_fn = gen_input_fn(
x=randomize_io_type(x_test, input_name), num_epochs=1, shuffle=False)
inference_input_fn = evaluate_input_fn if is_evaluate else predict_input_fn
@@ -184,12 +196,14 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
gfile.MakeDirs(self._base_dir)
self._config = run_config_lib.RunConfig(
tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir)
+ super(TestKerasEstimator, self).setUp()
def tearDown(self):
# Make sure nothing is stuck in limbo.
writer_cache.FileWriterCache.clear()
if os.path.isdir(self._base_dir):
gfile.DeleteRecursively(self._base_dir)
+ super(TestKerasEstimator, self).tearDown()
def test_train(self):
for model_type in ['sequential', 'functional']:
@@ -201,7 +215,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer='rmsprop',
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
before_eval_results = est_keras.evaluate(
@@ -226,7 +240,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
metrics=['mse', keras.metrics.categorical_accuracy])
my_hook = MyHook()
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
before_eval_results = est_keras.evaluate(
@@ -250,7 +264,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer=rmsprop.RMSPropOptimizer(1e-3),
metrics=['mse', keras.metrics.categorical_accuracy])
my_hook = MyHook()
- with self.test_session():
+ with self.cached_session():
keras_model.fit(x_train, y_train, epochs=1)
keras_est = keras_lib.model_to_estimator(
@@ -272,7 +286,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer=rmsprop.RMSPropOptimizer(1e-3),
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model,
config=self._config)
@@ -295,7 +309,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer=rmsprop.RMSPropOptimizer(1e-3),
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
@@ -314,7 +328,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer=rmsprop.RMSPropOptimizer(1e-3),
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
# Create state
keras_model.train_on_batch(np.random.random((10,) + _INPUT_SIZE),
np.random.random((10, _NUM_CLASS)))
@@ -341,7 +355,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
x_test, y_test), _, eval_input_fn = get_resource_for_simple_model(
model_type='functional', is_evaluate=True)
- with self.test_session():
+ with self.cached_session():
metrics = [
'binary_accuracy', 'binary_crossentropy', 'categorical_accuracy',
'categorical_crossentropy', 'cosine_proximity', 'hinge',
@@ -355,7 +369,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.fit(x_train, y_train, epochs=1)
keras_eval = keras_model.evaluate(x_test, y_test, batch_size=32)
- with self.test_session():
+ with self.cached_session():
keras_est = keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
est_eval = keras_est.evaluate(input_fn=eval_input_fn)
@@ -383,7 +397,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
x_test, _), _, pred_input_fn = get_resource_for_simple_model(
model_type='sequential', is_evaluate=False)
- with self.test_session():
+ with self.cached_session():
keras_model.compile(
loss='categorical_crossentropy',
optimizer='adam',
@@ -391,7 +405,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.fit(x_train, y_train, epochs=1)
keras_pred = [np.argmax(y) for y in keras_model.predict(x_test)]
- with self.test_session():
+ with self.cached_session():
keras_est = keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
est_pred = [
@@ -437,7 +451,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
output_dict = {'dense_2': c_test, 'dense_3': d_test}
return input_dict, output_dict
- with self.test_session():
+ with self.cached_session():
model = multi_inputs_multi_outputs_model()
est_keras = keras_lib.model_to_estimator(
keras_model=model, config=self._config)
@@ -454,7 +468,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
x_test, _), _, pred_input_fn = get_resource_for_simple_model(
model_type='functional', is_evaluate=False)
- with self.test_session():
+ with self.cached_session():
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
@@ -464,7 +478,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
fname = os.path.join(self._base_dir, 'keras_model.h5')
keras.models.save_model(keras_model, fname)
- with self.test_session():
+ with self.cached_session():
keras_est = keras_lib.model_to_estimator(
keras_model_path=fname, config=self._config)
est_pred = [
@@ -477,19 +491,19 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(ValueError, 'Either'):
keras_lib.model_to_estimator()
- with self.test_session():
+ with self.cached_session():
keras_model = simple_sequential_model()
with self.assertRaisesRegexp(ValueError, 'not both'):
keras_lib.model_to_estimator(
keras_model=keras_model,
keras_model_path=tempfile.mkdtemp(dir=self._base_dir))
- with self.test_session():
+ with self.cached_session():
keras_model = simple_sequential_model()
with self.assertRaisesRegexp(ValueError, 'compiled'):
keras_lib.model_to_estimator(keras_model=keras_model)
- with self.test_session():
+ with self.cached_session():
keras_model = simple_sequential_model()
with self.assertRaisesRegexp(ValueError, 'not a local path'):
keras_lib.model_to_estimator(
@@ -511,19 +525,19 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
input_dict = {'input_1': x_train}
output_dict = {'invalid_output_name': y_train}
return input_dict, output_dict
-
model = simple_functional_model()
model.compile(
loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=model, config=self._config)
-
- with self.test_session():
- with self.assertRaises(ValueError):
+ with self.cached_session():
+ with self.assertRaisesRegexp(KeyError,
+ 'Difference: .*invalid_input_name'):
est_keras.train(input_fn=invald_input_name_input_fn, steps=100)
- with self.assertRaises(ValueError):
+ with self.assertRaisesRegexp(KeyError,
+ 'Difference: .*invalid_output_name'):
est_keras.train(input_fn=invald_output_name_input_fn, steps=100)
def test_custom_objects(self):
@@ -545,20 +559,20 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
y_train = keras.utils.to_categorical(y_train, 2)
input_name = keras_model.input_names[0]
output_name = keras_model.output_names[0]
- train_input_fn = numpy_io.numpy_input_fn(
+ train_input_fn = gen_input_fn(
x=randomize_io_type(x_train, input_name),
y=randomize_io_type(y_train, output_name),
shuffle=False,
num_epochs=None,
batch_size=16)
with self.assertRaisesRegexp(ValueError, 'relu6'):
- with self.test_session():
+ with self.cached_session():
est = keras_lib.model_to_estimator(
keras_model=keras_model,
model_dir=tempfile.mkdtemp(dir=self._base_dir))
est.train(input_fn=train_input_fn, steps=1)
- with self.test_session():
+ with self.cached_session():
est = keras_lib.model_to_estimator(
keras_model=keras_model,
model_dir=tempfile.mkdtemp(dir=self._base_dir),
@@ -584,7 +598,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
}
})
with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
- with self.test_session():
+ with self.cached_session():
keras_lib.model_to_estimator(
keras_model=keras_model,
model_dir=tempfile.mkdtemp(dir=self._base_dir))
@@ -600,7 +614,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.3)
sess_config = config_pb2.ConfigProto(gpu_options=gpu_options)
self._config._session_config = sess_config
- with self.test_session():
+ with self.cached_session():
keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
self.assertEqual(
@@ -616,7 +630,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer='rmsprop',
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, model_dir=self._base_dir,
config=run_config_lib.RunConfig())
@@ -627,7 +641,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
self.assertEqual(self._base_dir, est_keras._config.model_dir)
self.assertEqual(self._base_dir, est_keras._model_dir)
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, model_dir=self._base_dir,
config=None)
@@ -646,7 +660,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer='rmsprop',
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR):
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model,
@@ -661,7 +675,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer='rmsprop',
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, '`model_dir` are set both in '
'constructor and `RunConfig`'):
keras_lib.model_to_estimator(
@@ -674,7 +688,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
keras_model.train_on_batch(
np.random.random((10,) + _INPUT_SIZE),
np.random.random((10, _NUM_CLASS)))
@@ -688,6 +702,32 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
+ def assert_increasing_global_step(self, optimizer):
+ keras_model, _, _, train_input_fn, _ = get_resource_for_simple_model(
+ model_type='sequential', is_evaluate=True)
+ keras_model.compile(
+ loss='categorical_crossentropy',
+ optimizer=optimizer,
+ metrics=['mse', keras.metrics.categorical_accuracy])
+ with self.cached_session() as sess:
+ keras_model_fn = keras_lib._create_keras_model_fn(keras_model)
+ global_step = training_util.create_global_step()
+ features, labels = train_input_fn().make_one_shot_iterator().get_next()
+ spec = keras_model_fn(features, labels, mode=model_fn_lib.ModeKeys.TRAIN)
+
+ sess.run(variables.global_variables_initializer())
+ sess.run(variables.local_variables_initializer())
+
+ self.assertEqual(global_step.eval(), 0) # Sanity check
+ sess.run(spec.train_op)
+ self.assertEqual(global_step.eval(), 1)
+
+ def test_model_fn_increments_global_step_tf_optimizer(self):
+ self.assert_increasing_global_step(rmsprop.RMSPropOptimizer(1e-3))
+
+ def test_model_fn_increments_global_step_keras_optimizer(self):
+ self.assert_increasing_global_step('rmsprop')
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index 9db9ccd01d..439cc2e3a4 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -26,6 +26,7 @@ import six
from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras.metrics import Metric
from tensorflow.python.ops import array_ops
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
@@ -142,12 +143,14 @@ class EstimatorSpec(
predictions: Predictions `Tensor` or dict of `Tensor`.
loss: Training loss `Tensor`. Must be either scalar, or with shape `[1]`.
train_op: Op for the training step.
- eval_metric_ops: Dict of metric results keyed by name. The values of the
- dict are the results of calling a metric function, namely a
- `(metric_tensor, update_op)` tuple. `metric_tensor` should be evaluated
- without any impact on state (typically is a pure computation results
- based on variables.). For example, it should not trigger the `update_op`
- or requires any input fetching.
+ eval_metric_ops: Dict of metric results keyed by name.
+ The values of the dict can be one of the following:
+ (1) instance of `Metric` class.
+ (2) Results of calling a metric function, namely a
+ `(metric_tensor, update_op)` tuple. `metric_tensor` should be
+        evaluated without any impact on state (typically it is a pure
+        computation based on variables). For example, it should not trigger
+        the `update_op` or require any input fetching.
export_outputs: Describes the output signatures to be exported to
`SavedModel` and used during serving.
A dict `{name: output}` where:
@@ -218,21 +221,27 @@ class EstimatorSpec(
if not isinstance(eval_metric_ops, dict):
raise TypeError(
'eval_metric_ops must be a dict, given: {}'.format(eval_metric_ops))
- for key, metric_value_and_update in six.iteritems(eval_metric_ops):
- if (not isinstance(metric_value_and_update, tuple) or
- len(metric_value_and_update) != 2):
- raise TypeError(
- 'Values of eval_metric_ops must be (metric_value, update_op) '
- 'tuples, given: {} for key: {}'.format(
- metric_value_and_update, key))
- metric_value, metric_update = metric_value_and_update
- for metric_value_member in nest.flatten(metric_value):
- # Allow (possibly nested) tuples for metric values, but require that
- # each of them be Tensors or Operations.
- _check_is_tensor_or_operation(metric_value_member,
+ for key, value in six.iteritems(eval_metric_ops):
+ # TODO(psv): When we deprecate the old metrics, throw an error here if
+ # the value is not an instance of `Metric` class.
+ if isinstance(value, Metric):
+ if not value.updates: # Check if metrics updates are available.
+ raise ValueError(
+ 'Please call update_state(...) on the "{metric_name}" metric'
+ .format(metric_name=value.name))
+ else:
+ if not isinstance(value, tuple) or len(value) != 2:
+ raise TypeError(
+ 'Values of eval_metric_ops must be (metric_value, update_op) '
+ 'tuples, given: {} for key: {}'.format(value, key))
+ metric_value, metric_update = value
+ for metric_value_member in nest.flatten(metric_value):
+ # Allow (possibly nested) tuples for metric values, but require that
+ # each of them be Tensors or Operations.
+ _check_is_tensor_or_operation(metric_value_member,
+ 'eval_metric_ops[{}]'.format(key))
+ _check_is_tensor_or_operation(metric_update,
'eval_metric_ops[{}]'.format(key))
- _check_is_tensor_or_operation(metric_update,
- 'eval_metric_ops[{}]'.format(key))
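
Per the validation above, `eval_metric_ops` values may now be either the legacy `(metric_tensor, update_op)` tuple or a keras `Metric` instance on which `update_state(...)` has already been called. A TF-1.x-style sketch of both forms side by side, mirroring the tests later in this diff:

    import tensorflow as tf
    from tensorflow.python.keras import metrics as keras_metrics

    loss = tf.constant(1.)
    value, update_op = tf.metrics.mean(loss)  # legacy tuple form

    metric_obj = keras_metrics.Mean()         # new Metric-object form
    metric_obj.update_state(loss)  # skipping this makes EstimatorSpec raise

    eval_metric_ops = {
        'legacy_mean': (value, update_op),
        'mean': metric_obj,
    }
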
# Validate the passed export outputs, or generate defaults.
if mode == ModeKeys.PREDICT:
@@ -267,8 +276,12 @@ class EstimatorSpec(
if train_op is not None and train_op.graph is not default_graph:
raise ValueError(error_message_template.format('train_op', train_op.name))
for key, value in list(six.iteritems(eval_metric_ops)):
- values = nest.flatten(value)
- for val in values:
+ if isinstance(value, Metric):
+ values_to_check = value.updates[:]
+ values_to_check.append(value.result())
+ else:
+ values_to_check = nest.flatten(value)
+ for val in values_to_check:
if val.graph is not default_graph:
raise ValueError(error_message_template.format(
'eval_metric_ops',
@@ -287,6 +300,19 @@ class EstimatorSpec(
'All hooks must be SessionRunHook instances, given: {}'.format(
hook))
+ # Add metric variables to the `LOCAL_VARIABLES` collection. Metric variables
+ # are by default not added to any collections. We are doing this here, so
+ # that metric variables get initialized.
+ local_vars = set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
+ vars_to_add = set()
+ for key, value in six.iteritems(eval_metric_ops):
+ if isinstance(value, Metric):
+ vars_to_add.update(value.variables)
+ # Remove variables that are in the local variables collection already.
+ vars_to_add = vars_to_add.difference(local_vars)
+ for v in vars_to_add:
+ ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, v)
+
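
The collection bookkeeping above exists because `Metric` objects create their variables outside the usual collections; registering them under `LOCAL_VARIABLES` lets the estimator's standard `tf.local_variables_initializer()` path initialize them before evaluation. The effect in isolation:

    import tensorflow as tf
    from tensorflow.python.keras import metrics as keras_metrics

    metric_obj = keras_metrics.Mean()
    metric_obj.update_state(tf.constant(1.))
    for v in metric_obj.variables:
      # Mirrors what EstimatorSpec now does for Metric-valued entries.
      tf.add_to_collection(tf.GraphKeys.LOCAL_VARIABLES, v)
    init_op = tf.local_variables_initializer()  # now covers total and count
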
scaffold = scaffold or monitored_session.Scaffold()
# Validate scaffold.
if not isinstance(scaffold, monitored_session.Scaffold):
@@ -449,3 +475,44 @@ def _check_is_tensor(x, tensor_name):
if not isinstance(x, ops.Tensor):
raise TypeError('{} must be Tensor, given: {}'.format(tensor_name, x))
return x
+
+
+def export_outputs_for_mode(
+ mode, serving_export_outputs=None, predictions=None, loss=None,
+ metrics=None):
+ """Util function for constructing a `ExportOutput` dict given a mode.
+
+ The returned dict can be directly passed to `build_all_signature_defs` helper
+ function as the `export_outputs` argument, used for generating a SignatureDef
+ map.
+
+ Args:
+ mode: A `ModeKeys` specifying the mode.
+ serving_export_outputs: Describes the output signatures to be exported to
+ `SavedModel` and used during serving. Should be a dict or None.
+ predictions: A dict of Tensors or single Tensor representing model
+ predictions. This argument is only used if serving_export_outputs is not
+ set.
+ loss: A dict of Tensors or single Tensor representing calculated loss.
+ metrics: A dict of (metric_value, update_op) tuples, or a single tuple.
+      metric_value must be a Tensor, and update_op must be a Tensor or Op.
+
+ Returns:
+    A dictionary mapping a key to a `tf.estimator.export.ExportOutput`
+    object. The key is the expected SignatureDef key for the mode.
+
+ Raises:
+ ValueError: if an appropriate ExportOutput cannot be found for the mode.
+ """
+  # TODO(b/113185250): move all model export helper functions into a util file.
+ if mode == ModeKeys.PREDICT:
+ return _get_export_outputs(serving_export_outputs, predictions)
+ elif mode == ModeKeys.TRAIN:
+ return {mode: export_output_lib.TrainOutput(
+ loss=loss, predictions=predictions, metrics=metrics)}
+ elif mode == ModeKeys.EVAL:
+ return {mode: export_output_lib.EvalOutput(
+ loss=loss, predictions=predictions, metrics=metrics)}
+ else:
+ raise ValueError(
+ 'Export output type not found for mode: {}'.format(mode))
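
A hedged usage sketch for the new `export_outputs_for_mode` util, assuming the TF-internal import path this file uses; the returned dict is keyed by the mode string and can be passed straight to the `build_all_signature_defs` helper:

    import tensorflow as tf
    from tensorflow.python.estimator import model_fn as model_fn_lib

    loss = tf.constant(1.)
    predictions = {'scores': tf.constant([0.2, 0.8])}
    export_outputs = model_fn_lib.export_outputs_for_mode(
        mode=model_fn_lib.ModeKeys.EVAL, predictions=predictions, loss=loss)
    # -> {'eval': EvalOutput(loss=..., predictions=..., metrics=None)}
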
diff --git a/tensorflow/python/estimator/model_fn_test.py b/tensorflow/python/estimator/model_fn_test.py
index 08e41fd414..8a3a9f3f51 100644
--- a/tensorflow/python/estimator/model_fn_test.py
+++ b/tensorflow/python/estimator/model_fn_test.py
@@ -24,6 +24,7 @@ from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.keras import metrics
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
@@ -48,7 +49,7 @@ class EstimatorSpecTrainTest(test.TestCase):
def testRequiredArgumentsSet(self):
"""Tests that no errors are raised when all required arguments are set."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN,
loss=constant_op.constant(1.),
@@ -56,16 +57,21 @@ class EstimatorSpecTrainTest(test.TestCase):
def testAllArgumentsSet(self):
"""Tests that no errors are raised when all arguments are set."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
predictions = {'loss': loss}
classes = constant_op.constant('hello')
+ metric_obj = metrics.Mean()
+ metric_obj.update_state(loss)
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN,
predictions=predictions,
loss=loss,
train_op=control_flow_ops.no_op(),
- eval_metric_ops={'loss': (control_flow_ops.no_op(), loss)},
+ eval_metric_ops={
+ 'loss': (control_flow_ops.no_op(), loss),
+ 'mean': metric_obj,
+ },
export_outputs={
'head_name': export_output.ClassificationOutput(classes=classes)
},
@@ -77,7 +83,7 @@ class EstimatorSpecTrainTest(test.TestCase):
def testLossNumber(self):
"""Tests that error is raised when loss is a number (not Tensor)."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(TypeError, 'loss must be Tensor'):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN,
@@ -86,20 +92,20 @@ class EstimatorSpecTrainTest(test.TestCase):
def testLoss1DTensor(self):
"""Tests that no errors are raised when loss is 1D tensor."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN,
loss=constant_op.constant([1.]),
train_op=control_flow_ops.no_op())
def testLossMissing(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(ValueError, 'Missing loss'):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN, train_op=control_flow_ops.no_op())
def testLossNotScalar(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(ValueError, 'Loss must be scalar'):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN,
@@ -107,7 +113,7 @@ class EstimatorSpecTrainTest(test.TestCase):
train_op=control_flow_ops.no_op())
def testLossSparseTensor(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = sparse_tensor.SparseTensor(
indices=[[0]],
values=[0.],
@@ -121,7 +127,7 @@ class EstimatorSpecTrainTest(test.TestCase):
def testLossFromDifferentGraph(self):
with ops.Graph().as_default():
loss = constant_op.constant(1.)
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
ValueError, 'must be from the default graph'):
model_fn.EstimatorSpec(
@@ -130,13 +136,13 @@ class EstimatorSpecTrainTest(test.TestCase):
train_op=control_flow_ops.no_op())
def testTrainOpMissing(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(ValueError, 'Missing train_op'):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN, loss=constant_op.constant(1.))
def testTrainOpNotOperationAndTensor(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(TypeError,
'train_op must be Operation or Tensor'):
model_fn.EstimatorSpec(
@@ -147,7 +153,7 @@ class EstimatorSpecTrainTest(test.TestCase):
def testTrainOpFromDifferentGraph(self):
with ops.Graph().as_default():
train_op = control_flow_ops.no_op()
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
ValueError, 'must be from the default graph'):
model_fn.EstimatorSpec(
@@ -156,7 +162,7 @@ class EstimatorSpecTrainTest(test.TestCase):
train_op=train_op)
def testTrainingChiefHookInvalid(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
TypeError, 'All hooks must be SessionRunHook instances'):
model_fn.EstimatorSpec(
@@ -166,7 +172,7 @@ class EstimatorSpecTrainTest(test.TestCase):
training_chief_hooks=[_InvalidHook()])
def testTrainingHookInvalid(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
TypeError, 'All hooks must be SessionRunHook instances'):
model_fn.EstimatorSpec(
@@ -176,7 +182,7 @@ class EstimatorSpecTrainTest(test.TestCase):
training_hooks=[_InvalidHook()])
def testScaffoldInvalid(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
TypeError, r'scaffold must be tf\.train\.Scaffold'):
model_fn.EstimatorSpec(
@@ -186,7 +192,7 @@ class EstimatorSpecTrainTest(test.TestCase):
scaffold=_InvalidScaffold())
def testReturnDefaultScaffold(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
estimator_spec = model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN,
loss=constant_op.constant(1.),
@@ -199,7 +205,7 @@ class EstimatorSpecEvalTest(test.TestCase):
def testRequiredArgumentsSet(self):
"""Tests that no errors are raised when all required arguments are set."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
@@ -208,16 +214,21 @@ class EstimatorSpecEvalTest(test.TestCase):
def testAllArgumentsSet(self):
"""Tests that no errors are raised when all arguments are set."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
predictions = {'loss': loss}
classes = constant_op.constant('hello')
+ metric_obj = metrics.Mean()
+ metric_obj.update_state(loss)
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
predictions=predictions,
loss=loss,
train_op=control_flow_ops.no_op(),
- eval_metric_ops={'loss': (control_flow_ops.no_op(), loss)},
+ eval_metric_ops={
+ 'loss': (control_flow_ops.no_op(), loss),
+ 'mean': metric_obj,
+ },
export_outputs={
'head_name': export_output.ClassificationOutput(classes=classes)
},
@@ -227,7 +238,7 @@ class EstimatorSpecEvalTest(test.TestCase):
evaluation_hooks=[_FakeHook()])
def testEvaluationHookInvalid(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
TypeError, 'All hooks must be SessionRunHook instances'):
model_fn.EstimatorSpec(
@@ -237,7 +248,7 @@ class EstimatorSpecEvalTest(test.TestCase):
def testTupleMetric(self):
"""Tests that no errors are raised when a metric is tuple-valued."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
@@ -248,7 +259,7 @@ class EstimatorSpecEvalTest(test.TestCase):
def testLoss1DTensor(self):
"""Tests that no errors are raised when loss is 1D tensor."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant([1.])
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
@@ -257,7 +268,7 @@ class EstimatorSpecEvalTest(test.TestCase):
def testLossNumber(self):
"""Tests that error is raised when loss is a number (not Tensor)."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(TypeError, 'loss must be Tensor'):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
@@ -265,14 +276,14 @@ class EstimatorSpecEvalTest(test.TestCase):
loss=1.)
def testLossMissing(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(ValueError, 'Missing loss'):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
predictions={'loss': constant_op.constant(1.)})
def testLossNotScalar(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant([1., 2.])
with self.assertRaisesRegexp(ValueError, 'Loss must be scalar'):
model_fn.EstimatorSpec(
@@ -281,7 +292,7 @@ class EstimatorSpecEvalTest(test.TestCase):
loss=loss)
def testLossSparseTensor(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = sparse_tensor.SparseTensor(
indices=[[0]],
values=[0.],
@@ -296,7 +307,7 @@ class EstimatorSpecEvalTest(test.TestCase):
def testLossFromDifferentGraph(self):
with ops.Graph().as_default():
loss = constant_op.constant(1.)
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
ValueError, 'must be from the default graph'):
model_fn.EstimatorSpec(
@@ -305,7 +316,7 @@ class EstimatorSpecEvalTest(test.TestCase):
loss=loss)
def testReplaceRaisesConstructorChecks(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
spec = model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss)
@@ -313,7 +324,7 @@ class EstimatorSpecEvalTest(test.TestCase):
spec._replace(loss=constant_op.constant([1., 2.]))
def testReplaceDoesReplace(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
spec = model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss)
@@ -321,7 +332,7 @@ class EstimatorSpecEvalTest(test.TestCase):
self.assertEqual(['m'], list(new_spec.predictions.keys()))
def testReplaceNotAllowModeChange(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
spec = model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss)
@@ -331,13 +342,13 @@ class EstimatorSpecEvalTest(test.TestCase):
spec._replace(mode=model_fn.ModeKeys.TRAIN)
def testPredictionsMissingIsOkay(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL, loss=constant_op.constant(1.))
def testPredictionsTensor(self):
"""Tests that no error is raised when predictions is Tensor (not dict)."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
@@ -345,7 +356,7 @@ class EstimatorSpecEvalTest(test.TestCase):
loss=loss)
def testPredictionsNumber(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
TypeError, r'predictions\[number\] must be Tensor'):
model_fn.EstimatorSpec(
@@ -354,7 +365,7 @@ class EstimatorSpecEvalTest(test.TestCase):
loss=constant_op.constant(1.))
def testPredictionsSparseTensor(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {
'sparse': sparse_tensor.SparseTensor(
indices=[[0]],
@@ -370,7 +381,7 @@ class EstimatorSpecEvalTest(test.TestCase):
def testPredictionsFromDifferentGraph(self):
with ops.Graph().as_default():
predictions = {'loss': constant_op.constant(1.)}
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
ValueError, 'must be from the default graph'):
model_fn.EstimatorSpec(
@@ -379,7 +390,7 @@ class EstimatorSpecEvalTest(test.TestCase):
loss=constant_op.constant(1.))
def testEvalMetricOpsNoDict(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
with self.assertRaisesRegexp(
TypeError, 'eval_metric_ops must be a dict'):
@@ -390,7 +401,7 @@ class EstimatorSpecEvalTest(test.TestCase):
eval_metric_ops=loss)
def testEvalMetricOpsNoTuple(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
with self.assertRaisesRegexp(
TypeError,
@@ -403,7 +414,7 @@ class EstimatorSpecEvalTest(test.TestCase):
eval_metric_ops={'loss': loss})
def testEvalMetricOpsNoTensorOrOperation(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
with self.assertRaisesRegexp(TypeError, 'must be Operation or Tensor'):
model_fn.EstimatorSpec(
@@ -413,7 +424,7 @@ class EstimatorSpecEvalTest(test.TestCase):
eval_metric_ops={'loss': ('NonTensor', loss)})
def testEvalMetricNestedNoTensorOrOperation(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
with self.assertRaisesRegexp(TypeError, 'must be Operation or Tensor'):
model_fn.EstimatorSpec(
@@ -423,11 +434,26 @@ class EstimatorSpecEvalTest(test.TestCase):
eval_metric_ops={'loss': ((('NonTensor',),),
control_flow_ops.no_op())})
- def testEvalMetricOpsFromDifferentGraph(self):
+ def testEvalMetricOpsFromDifferentGraphWithMetricTuple(self):
with ops.Graph().as_default():
eval_metric_ops = {
'loss': (control_flow_ops.no_op(), constant_op.constant(1.))}
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
+ loss = constant_op.constant(1.)
+ with self.assertRaisesRegexp(
+ ValueError, 'must be from the default graph'):
+ model_fn.EstimatorSpec(
+ mode=model_fn.ModeKeys.EVAL,
+ predictions={'loss': loss},
+ loss=loss,
+ eval_metric_ops=eval_metric_ops)
+
+ def testEvalMetricOpsFromDifferentGraphWithMetricObject(self):
+ with ops.Graph().as_default():
+ metric_obj = metrics.Mean()
+ metric_obj.update_state(constant_op.constant(1.))
+ eval_metric_ops = {'metric': metric_obj}
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
with self.assertRaisesRegexp(
ValueError, 'must be from the default graph'):
@@ -437,29 +463,46 @@ class EstimatorSpecEvalTest(test.TestCase):
loss=loss,
eval_metric_ops=eval_metric_ops)
+ def testEvalMetricOpsWithoutUpdates(self):
+ with ops.Graph().as_default():
+ eval_metric_ops = {'mean': metrics.Mean()}
+ with ops.Graph().as_default(), self.cached_session():
+ loss = constant_op.constant(1.)
+      with self.assertRaisesRegexp(ValueError, r'Please call update_state\(\.\.\.\)'):
+ model_fn.EstimatorSpec(
+ mode=model_fn.ModeKeys.EVAL,
+ predictions={'loss': loss},
+ loss=loss,
+ eval_metric_ops=eval_metric_ops)
+
class EstimatorSpecInferTest(test.TestCase):
"""Tests EstimatorSpec in infer mode."""
def testRequiredArgumentsSet(self):
"""Tests that no errors are raised when all required arguments are set."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.PREDICT,
predictions={'loss': constant_op.constant(1.)})
def testAllArgumentsSet(self):
"""Tests that no errors are raised when all arguments are set."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
predictions = {'loss': loss}
classes = constant_op.constant('hello')
+ metric_obj = metrics.Mean()
+ metric_obj.update_state(loss)
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
loss=loss,
train_op=control_flow_ops.no_op(),
- eval_metric_ops={'loss': (control_flow_ops.no_op(), loss)},
+ eval_metric_ops={
+ 'loss': (control_flow_ops.no_op(), loss),
+ 'mean': metric_obj,
+ },
export_outputs={
'head_name': export_output.ClassificationOutput(classes=classes)
},
@@ -470,7 +513,7 @@ class EstimatorSpecInferTest(test.TestCase):
prediction_hooks=[_FakeHook()])
def testPredictionHookInvalid(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
TypeError, 'All hooks must be SessionRunHook instances'):
model_fn.EstimatorSpec(
@@ -479,25 +522,25 @@ class EstimatorSpecInferTest(test.TestCase):
prediction_hooks=[_InvalidHook()])
def testPredictionsMissing(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(ValueError, 'Missing predictions'):
model_fn.EstimatorSpec(mode=model_fn.ModeKeys.PREDICT)
def testPredictionsTensor(self):
"""Tests that no error is raised when predictions is Tensor (not dict)."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.PREDICT, predictions=constant_op.constant(1.))
def testPredictionsNumber(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
TypeError, r'predictions\[number\] must be Tensor'):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.PREDICT, predictions={'number': 1.})
def testPredictionsSparseTensor(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {
'sparse': sparse_tensor.SparseTensor(
indices=[[0]],
@@ -509,7 +552,7 @@ class EstimatorSpecInferTest(test.TestCase):
mode=model_fn.ModeKeys.PREDICT, predictions=predictions)
def testExportOutputsNoDict(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {'loss': constant_op.constant(1.)}
classes = constant_op.constant('hello')
with self.assertRaisesRegexp(
@@ -520,7 +563,7 @@ class EstimatorSpecInferTest(test.TestCase):
export_outputs=export_output.ClassificationOutput(classes=classes))
def testExportOutputsValueNotExportOutput(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {'loss': constant_op.constant(1.)}
with self.assertRaisesRegexp(
TypeError,
@@ -533,7 +576,7 @@ class EstimatorSpecInferTest(test.TestCase):
export_outputs={'head_name': predictions})
def testExportOutputsSingleheadMissingDefault(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {'loss': constant_op.constant(1.)}
output_1 = constant_op.constant([1.])
regression_output = export_output.RegressionOutput(value=output_1)
@@ -552,7 +595,7 @@ class EstimatorSpecInferTest(test.TestCase):
self.assertEqual(expected_export_outputs, estimator_spec.export_outputs)
def testExportOutputsMultiheadWithDefault(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {'loss': constant_op.constant(1.)}
output_1 = constant_op.constant([1.])
output_2 = constant_op.constant(['2'])
@@ -571,7 +614,7 @@ class EstimatorSpecInferTest(test.TestCase):
self.assertEqual(export_outputs, estimator_spec.export_outputs)
def testExportOutputsMultiheadMissingDefault(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {'loss': constant_op.constant(1.)}
output_1 = constant_op.constant([1.])
output_2 = constant_op.constant(['2'])
@@ -594,13 +637,13 @@ class EstimatorSpecInferTest(test.TestCase):
def testDefaultExportOutputCreated(self):
"""Ensure that a default PredictOutput is created for export."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = constant_op.constant(1.)
self._assertDefaultExportOutputForPredictions(predictions)
def testDefaultExportOutputCreatedDict(self):
"""Ensure that a default PredictOutput is created for export for dicts."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {'loss': constant_op.constant(1.),
'score': constant_op.constant(10.)}
self._assertDefaultExportOutputForPredictions(predictions)
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index 220c3e58ca..3773810a04 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -26,6 +26,7 @@ import six
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.distribute import estimator_training as distribute_coordinator_training
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat_internal
@@ -51,6 +52,7 @@ _DEFAULT_REPLACEABLE_LIST = [
'device_fn',
'protocol',
'eval_distribute',
+ 'experimental_distribute',
]
_SAVE_CKPT_ERR = (
@@ -331,7 +333,8 @@ class RunConfig(object):
train_distribute=None,
device_fn=None,
protocol=None,
- eval_distribute=None):
+ eval_distribute=None,
+ experimental_distribute=None):
"""Constructs a RunConfig.
All distributed training related properties `cluster_spec`, `is_chief`,
@@ -458,7 +461,8 @@ class RunConfig(object):
train_distribute: An optional instance of
`tf.contrib.distribute.DistributionStrategy`. If specified,
then Estimator will distribute the user's model during training,
- according to the policy specified by that strategy.
+ according to the policy specified by that strategy. Setting
+ `experimental_distribute.train_distribute` is preferred.
device_fn: A callable invoked for every `Operation` that takes the
`Operation` and returns the device string. If `None`, defaults to
the device function returned by `tf.train.replica_device_setter`
@@ -468,7 +472,13 @@ class RunConfig(object):
eval_distribute: An optional instance of
`tf.contrib.distribute.DistributionStrategy`. If specified,
then Estimator will distribute the user's model during evaluation,
- according to the policy specified by that strategy.
+ according to the policy specified by that strategy. Setting
+ `experimental_distribute.eval_distribute` is preferred.
+      experimental_distribute: An optional
+        `tf.contrib.distribute.DistributeConfig` object specifying
+        DistributionStrategy-related configuration. `train_distribute` and
+        `eval_distribute` can be passed as parameters to `RunConfig` or set in
+        `experimental_distribute`, but not both.
Raises:
ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs`
@@ -508,11 +518,20 @@ class RunConfig(object):
train_distribute=train_distribute,
device_fn=device_fn,
protocol=protocol,
- eval_distribute=eval_distribute)
-
- self._init_distributed_setting_from_environment_var(tf_config)
-
- self._maybe_overwrite_session_config_for_distributed_training()
+ eval_distribute=eval_distribute,
+ experimental_distribute=experimental_distribute)
+
+ # TODO(frankchn,priyag): Eventually use distributed coordinator for TPUs.
+ if ((train_distribute and
+ train_distribute.__class__.__name__ != 'TPUStrategy') or
+ (eval_distribute and
+ eval_distribute.__class__.__name__ != 'TPUStrategy') or
+ experimental_distribute):
+ logging.info('Initializing RunConfig with distribution strategies.')
+ distribute_coordinator_training.init_run_config(self, tf_config)
+ else:
+ self._init_distributed_setting_from_environment_var(tf_config)
+ self._maybe_overwrite_session_config_for_distributed_training()
def _maybe_overwrite_session_config_for_distributed_training(self):
"""Overwrites the session_config for distributed training.
@@ -810,6 +829,7 @@ class RunConfig(object):
- `device_fn`,
- `protocol`.
- `eval_distribute`,
+ - `experimental_distribute`,
In addition, either `save_checkpoints_steps` or `save_checkpoints_secs`
can be set (should not be both).
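To make the new `experimental_distribute` plumbing concrete, here is a hedged
usage sketch; `DistributeConfig` and `MirroredStrategy` are assumed to be the
`tf.contrib.distribute` symbols the docstring above refers to:

```python
# Hedged sketch: distribution settings routed through experimental_distribute
# rather than the top-level train_distribute/eval_distribute arguments
# (setting both is rejected, per the docstring above).
from tensorflow.contrib.distribute import DistributeConfig, MirroredStrategy
from tensorflow.python.estimator import run_config as run_config_lib

config = run_config_lib.RunConfig(
    experimental_distribute=DistributeConfig(
        train_distribute=MirroredStrategy(),
        eval_distribute=MirroredStrategy()))
```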
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index bb1305767f..240be5dabe 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -26,6 +26,7 @@ import time
import six
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.distribute import estimator_training as distribute_coordinator_training
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import exporter as exporter_lib
from tensorflow.python.estimator import run_config as run_config_lib
@@ -129,8 +130,8 @@ class TrainSpec(
Args:
input_fn: A function that provides input data for training as minibatches.
- See @{$premade_estimators#create_input_functions} for more
- information. The function should construct and return one of
+ See [Premade Estimators](https://tensorflow.org/guide/premade_estimators#create_input_functions)
+ for more information. The function should construct and return one of
the following:
* A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
tuple (features, labels) with same constraints as below.
@@ -193,8 +194,8 @@ class EvalSpec(
Args:
input_fn: A function that constructs the input data for evaluation.
- See @{$premade_estimators#create_input_functions} for more
- information. The function should construct and return one of
+      See [Premade Estimators](https://tensorflow.org/guide/premade_estimators#create_input_functions)
+ for more information. The function should construct and return one of
the following:
* A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
tuple (features, labels) with same constraints as below.
@@ -274,8 +275,10 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
evaluation `input_fn`, steps, etc.
This utility function provides consistent behavior for both local
- (non-distributed) and distributed configurations. Currently, the only
- supported distributed training configuration is between-graph replication.
+ (non-distributed) and distributed configurations. The default distribution
+ configuration is parameter server-based between-graph replication. For other
+ types of distribution configurations such as all-reduce training, please use
+ [DistributionStrategies](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute). # pylint: disable=line-too-long
Overfitting: In order to avoid overfitting, it is recommended to set up the
training `input_fn` to shuffle the training data properly.
@@ -426,6 +429,11 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
}'
```
+  When `distribute` or `experimental_distribute.train_distribute` and
+  `experimental_distribute.remote_cluster` are set, this method will start a
+  client running on the current host, which connects to the `remote_cluster`
+  for training and evaluation.
+
Args:
estimator: An `Estimator` instance to train and evaluate.
train_spec: A `TrainSpec` instance to specify the training specification.
@@ -444,8 +452,16 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
executor = _TrainingExecutor(
estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)
-
config = estimator.config
+
+ # If `distribute_coordinator_mode` is set and running in distributed
+ # environment, we run `train_and_evaluate` via distribute coordinator.
+ if distribute_coordinator_training.should_run_distribute_coordinator(config):
+ logging.info('Running `train_and_evaluate` with Distribute Coordinator.')
+ distribute_coordinator_training.train_and_evaluate(
+ estimator, train_spec, eval_spec, _TrainingExecutor)
+ return
+
if (config.task_type == run_config_lib.TaskType.EVALUATOR and
config.task_id > 0):
raise ValueError(
@@ -837,6 +853,13 @@ class _TrainingExecutor(object):
if difference > 0:
logging.info('Waiting %f secs before starting next eval run.', difference)
time.sleep(difference)
+    elif (throttle_secs == 0 and
+          eval_result.status != _EvalStatus.EVALUATED):
+      # Prints a user-actionable warning to avoid unnecessary load on the
+      # evaluator.
+      logging.warning(
+          'EvalSpec.throttle_secs is set to 0. This might overload the job '
+          'while waiting for the next new checkpoint. Please consider '
+          'increasing it.')
return (eval_result, should_early_stop)
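The new branch only fires when `throttle_secs` is zero and the last run found
no new checkpoint to evaluate, i.e. exactly the case where the evaluator would
otherwise re-poll in a tight loop. A small sketch of the recommended setting,
using the same arguments the tests below use:

```python
# Hedged sketch: a nonzero throttle_secs spaces out checkpoint polling; with
# throttle_secs=0 the code above now logs a warning instead of spinning
# silently.
from tensorflow.python.estimator import training

eval_spec = training.EvalSpec(
    input_fn=lambda: 1,   # placeholder input_fn, as in the tests below
    start_delay_secs=0,
    throttle_secs=2)      # wait at least 2s between evaluation attempts
```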
diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py
index dc106c7d3b..7d46917a6f 100644
--- a/tensorflow/python/estimator/training_test.py
+++ b/tensorflow/python/estimator/training_test.py
@@ -83,6 +83,9 @@ _INVALID_EVAL_LISTENER_MSG = 'must have type `_ContinuousEvalListener`'
_INVALID_CONFIG_FOR_STD_SERVER_MSG = 'Could not start server; .*TF_CONFIG'
_INVALID_LOCAL_TASK_WITH_CLUSTER = '`task.type` in TF_CONFIG cannot be `local`'
_INVALID_TASK_TYPE = '`estimator.config` must have task_type set.'
+_IMPROPER_THROTTLE_SECS = (
+    'EvalSpec.throttle_secs is set to 0.*Please consider increasing')
+
# The message should NOT contain the word 'local'. Since (?!word) is a
# lookahead, the $ (ending) check is required; otherwise, the pattern will
# match partially and return successful.
@@ -1281,7 +1284,7 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
]
eval_spec = training.EvalSpec(
- input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)
+ input_fn=lambda: 1, start_delay_secs=0, throttle_secs=2)
executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
with test.mock.patch.object(logging, 'warning') as mock_log:
@@ -1295,6 +1298,34 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
    # successful evaluation)
self.assertEqual(2, mock_log.call_count)
+ def test_warning_if_throttle_secs_is_zero(self):
+ training_max_step = 200
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_est.evaluate.side_effect = [
+ {_GLOBAL_STEP_KEY: training_max_step}
+ ]
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ mock_train_spec.max_steps = training_max_step
+
+ self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)
+
+ # We need to make the first one invalid, so it will check the
+ # throttle_secs=0.
+ mock_est.latest_checkpoint.side_effect = [None, 'path']
+
+ eval_spec = training.EvalSpec(
+ input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)
+
+ executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
+ with test.mock.patch.object(logging, 'warning') as mock_log:
+ executor.run_evaluator()
+
+ # First ckpt is invalid.
+ self.assertEqual(2, mock_est.latest_checkpoint.call_count)
+ self.assertEqual(1, mock_est.evaluate.call_count)
+
+    self.assertRegexpMatches(str(mock_log.call_args), _IMPROPER_THROTTLE_SECS)
+
def test_continuous_eval_listener_eval_result(self):
training_max_step = 200
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py
index d4a75478d5..31e4778e72 100644
--- a/tensorflow/python/estimator/util.py
+++ b/tensorflow/python/estimator/util.py
@@ -109,13 +109,17 @@ def parse_input_fn_result(result):
else:
input_hooks.append(_DatasetInitializerHook(iterator))
result = iterator.get_next()
+ return parse_iterator_result(result) + (input_hooks,)
+
+def parse_iterator_result(result):
+ """Gets features, labels from result."""
if isinstance(result, (list, tuple)):
if len(result) != 2:
raise ValueError(
'input_fn should return (features, labels) as a len 2 tuple.')
- return result[0], result[1], input_hooks
- return result, None, input_hooks
+ return result[0], result[1]
+ return result, None
class _DatasetInitializerHook(training.SessionRunHook):
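The refactoring above extracts the tuple handling into `parse_iterator_result`
so it can be reused without threading hooks through. Its contract, sketched
directly from the code:

```python
# Behavior sketch of the extracted helper: a length-2 list/tuple is split into
# (features, labels); anything else is treated as features-only.
from tensorflow.python.estimator import util

util.parse_iterator_result(('x', 'y'))  # -> ('x', 'y')
util.parse_iterator_result('x')         # -> ('x', None)
util.parse_iterator_result(('x',))      # raises ValueError: len 2 required
```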
diff --git a/tensorflow/python/estimator/util_test.py b/tensorflow/python/estimator/util_test.py
index d7e0610779..d440c454dc 100644
--- a/tensorflow/python/estimator/util_test.py
+++ b/tensorflow/python/estimator/util_test.py
@@ -39,7 +39,7 @@ class UtilTest(test.TestCase):
features, labels, hooks = util.parse_input_fn_result(_input_fn())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
vals = sess.run([features, labels])
self.assertAllEqual(vals[0], np.arange(100))
@@ -67,7 +67,7 @@ class UtilTest(test.TestCase):
features, labels, hooks = util.parse_input_fn_result(_input_fn())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
vals = sess.run([features])
self.assertAllEqual(vals[0], np.arange(100))
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index 1017d4ba47..ac53a84eef 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -12,6 +12,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":feature_column",
+ ":feature_column_v2",
"//tensorflow/python:util",
],
)
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 6be930be87..9b482237ab 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -262,7 +262,7 @@ class NumericColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([price]))
self.assertIn('price', features)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[20., 110.]], features['price'].eval())
def test_parse_example_with_default_value(self):
@@ -284,7 +284,7 @@ class NumericColumnTest(test.TestCase):
no_data.SerializeToString()],
features=fc.make_parse_example_spec([price]))
self.assertIn('price', features)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[20., 110.], [11., 11.]], features['price'].eval())
def test_normalizer_fn_must_be_callable(self):
@@ -298,7 +298,7 @@ class NumericColumnTest(test.TestCase):
price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two)
output = _transform_features({'price': [[1., 2.], [5., 6.]]}, [price])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[3., 4.], [7., 8.]], output[price].eval())
def test_get_dense_tensor(self):
@@ -433,7 +433,7 @@ class BucketizedColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([bucketized_price]))
self.assertIn('price', features)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[20., 110.]], features['price'].eval())
def test_transform_feature(self):
@@ -700,7 +700,7 @@ class HashedCategoricalColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -719,7 +719,7 @@ class HashedCategoricalColumnTest(test.TestCase):
output = outputs[hashed_sparse]
# Check exact hashed output. If hashing changes this test will break.
expected_values = [6, 4, 1]
- with self.test_session():
+ with self.cached_session():
self.assertEqual(dtypes.int64, output.values.dtype)
self.assertAllEqual(expected_values, output.values.eval())
self.assertAllEqual(wire_tensor.indices.eval(), output.indices.eval())
@@ -775,7 +775,7 @@ class HashedCategoricalColumnTest(test.TestCase):
output = builder.get(hashed_sparse)
# Check exact hashed output. If hashing changes this test will break.
expected_values = [3, 7, 5]
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_values, output.values.eval())
def test_int32_64_is_compatible(self):
@@ -789,7 +789,7 @@ class HashedCategoricalColumnTest(test.TestCase):
output = builder.get(hashed_sparse)
# Check exact hashed output. If hashing changes this test will break.
expected_values = [3, 7, 5]
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_values, output.values.eval())
def test_get_sparse_tensors(self):
@@ -984,7 +984,7 @@ class CrossedColumnTest(test.TestCase):
features=fc.make_parse_example_spec([price_cross_wire]))
self.assertIn('price', features)
self.assertIn('wire', features)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[20., 110.]], features['price'].eval())
wire_sparse = features['wire']
self.assertAllEqual([[0, 0], [0, 1]], wire_sparse.indices.eval())
@@ -1007,7 +1007,7 @@ class CrossedColumnTest(test.TestCase):
}
outputs = _transform_features(features, [price_cross_wire])
output = outputs[price_cross_wire]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output_val = sess.run(output)
self.assertAllEqual(
[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], output_val.indices)
@@ -2747,6 +2747,62 @@ class FunctionalInputLayerTest(test.TestCase):
variables_lib.Variable)
self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [5, 10])
+ def test_fills_cols_to_vars_shared_embedding(self):
+    # Provide 5 DenseColumns to input_layer: a NumericColumn, a
+    # BucketizedColumn, an EmbeddingColumn, and two SharedEmbeddingColumns.
+    # The EmbeddingColumn creates a Variable and the two
+    # SharedEmbeddingColumns share one variable.
+ price1 = fc.numeric_column('price1')
+ dense_feature = fc.numeric_column('dense_feature')
+ dense_feature_bucketized = fc.bucketized_column(
+ dense_feature, boundaries=[0.])
+ some_sparse_column = fc.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5)
+ some_embedding_column = fc.embedding_column(
+ some_sparse_column, dimension=10)
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ shared_embedding_a, shared_embedding_b = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b], dimension=2)
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[3.], [4.]],
+ 'dense_feature': [[-1.], [4.]],
+ 'sparse_feature': [['a'], ['x']],
+ 'aaa':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2)),
+ 'bbb':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 1),
+ dense_shape=(2, 2)),
+ }
+ cols_to_vars = {}
+ all_cols = [
+ price1, dense_feature_bucketized, some_embedding_column,
+ shared_embedding_a, shared_embedding_b
+ ]
+ fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
+ self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
+ self.assertEqual(0, len(cols_to_vars[price1]))
+ self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
+ self.assertEqual(1, len(cols_to_vars[some_embedding_column]))
+ self.assertEqual(1, len(cols_to_vars[shared_embedding_a]))
+ # This is a bug in the current implementation and should be fixed in the
+ # new one.
+ self.assertEqual(0, len(cols_to_vars[shared_embedding_b]))
+ self.assertIsInstance(cols_to_vars[some_embedding_column][0],
+ variables_lib.Variable)
+ self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [5, 10])
+ self.assertIsInstance(cols_to_vars[shared_embedding_a][0],
+ variables_lib.Variable)
+ self.assertAllEqual(cols_to_vars[shared_embedding_a][0].shape, [3, 2])
+
def test_fills_cols_to_vars_partitioned_variables(self):
price1 = fc.numeric_column('price1')
dense_feature = fc.numeric_column('dense_feature')
@@ -2772,6 +2828,10 @@ class FunctionalInputLayerTest(test.TestCase):
self.assertEqual(0, len(cols_to_vars[price1]))
self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
self.assertEqual(3, len(cols_to_vars[some_embedding_column]))
+ self.assertEqual(
+ 'input_from_feature_columns/input_layer/sparse_feature_embedding/'
+ 'embedding_weights/part_0:0',
+ cols_to_vars[some_embedding_column][0].name)
self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [2, 10])
self.assertAllEqual(cols_to_vars[some_embedding_column][1].shape, [2, 10])
self.assertAllEqual(cols_to_vars[some_embedding_column][2].shape, [1, 10])
@@ -3262,7 +3322,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
dense_shape=(2, 2))
column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
with self.assertRaisesRegexp(errors.OpError, 'file_does_not_exist'):
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
def test_invalid_vocabulary_size(self):
@@ -3286,7 +3346,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
dense_shape=(2, 2))
column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
with self.assertRaisesRegexp(errors.OpError, 'Invalid vocab_size'):
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
def test_invalid_num_oov_buckets(self):
@@ -3350,7 +3410,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -3775,7 +3835,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -3797,7 +3857,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -4096,7 +4156,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -4365,7 +4425,7 @@ class IndicatorColumnTest(test.TestCase):
fc.categorical_column_with_hash_bucket('animal', 4))
builder = _LazyBuilder({'animal': ['fox', 'fox']})
output = builder.get(animal)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[0., 0., 1., 0.], [0., 0., 1., 0.]], output.eval())
def test_2D_shape_succeeds(self):
@@ -4380,7 +4440,7 @@ class IndicatorColumnTest(test.TestCase):
dense_shape=[2, 1])
})
output = builder.get(animal)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[0., 0., 1., 0.], [0., 0., 1., 0.]], output.eval())
def test_multi_hot(self):
@@ -4393,7 +4453,7 @@ class IndicatorColumnTest(test.TestCase):
indices=[[0, 0], [0, 1]], values=[1, 1], dense_shape=[1, 2])
})
output = builder.get(animal)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[0., 2., 0., 0.]], output.eval())
def test_multi_hot2(self):
@@ -4405,7 +4465,7 @@ class IndicatorColumnTest(test.TestCase):
indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
})
output = builder.get(animal)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[0., 1., 1., 0.]], output.eval())
def test_deep_copy(self):
@@ -4430,7 +4490,7 @@ class IndicatorColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a_indicator]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -4641,7 +4701,7 @@ class EmbeddingColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a_embedded]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -5407,7 +5467,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
features=fc.make_parse_example_spec([a_embedded, b_embedded]))
self.assertIn('aaa', features)
self.assertIn('bbb', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -5544,20 +5604,6 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertIsNone(partition_info)
return embedding_values
- # Expected lookup result, using combiner='mean'.
- expected_lookups_a = (
- # example 0:
- (7., 11.), # ids [2], embedding = [7, 11]
- # example 1:
- (2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
- )
- expected_lookups_b = (
- # example 0:
- (1., 2.), # ids [0], embedding = [1, 2]
- # example 1:
- (0., 0.), # ids [], embedding = [0, 0]
- )
-
# Build columns.
categorical_column_a = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
@@ -5990,7 +6036,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
features=fc.make_parse_example_spec([a_weighted]))
self.assertIn('aaa', features)
self.assertIn('weights', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index b6bf516286..28c5c82d2c 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -142,6 +142,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine import training
+from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.layers import base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
@@ -155,7 +156,6 @@ from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
-from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
@@ -164,67 +164,148 @@ from tensorflow.python.training import checkpoint_utils
from tensorflow.python.util import nest
-def _internal_input_layer(features,
- feature_columns,
- weight_collections=None,
- trainable=True,
- cols_to_vars=None,
- scope=None):
- """See input_layer. `scope` is a name or variable scope to use."""
+class StateManager(object):
+ """Manages the state associated with FeatureColumns.
- feature_columns = fc_old._normalize_feature_columns(feature_columns) # pylint: disable=protected-access
- for column in feature_columns:
- if not isinstance(column, fc_old._DenseColumn): # pylint: disable=protected-access
- raise ValueError(
- 'Items of feature_columns must be a _DenseColumn. '
- 'You can wrap a categorical column with an '
- 'embedding_column or indicator_column. Given: {}'.format(column))
- weight_collections = list(weight_collections or [])
- if ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections:
- weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
- if ops.GraphKeys.MODEL_VARIABLES not in weight_collections:
- weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
-
- # a non-None `scope` can allow for variable reuse, when, e.g., this function
- # is wrapped by a `make_template`.
- with variable_scope.variable_scope(
- scope, default_name='input_layer', values=features.values()):
- builder = fc_old._LazyBuilder(features) # pylint: disable=protected-access
- output_tensors = []
- ordered_columns = []
- for column in sorted(feature_columns, key=lambda x: x.name):
- ordered_columns.append(column)
- with variable_scope.variable_scope(
- None, default_name=column._var_scope_name): # pylint: disable=protected-access
- tensor = column._get_dense_tensor( # pylint: disable=protected-access
- builder,
- weight_collections=weight_collections,
- trainable=trainable)
- num_elements = column._variable_shape.num_elements() # pylint: disable=protected-access
- batch_size = array_ops.shape(tensor)[0]
- output_tensors.append(
- array_ops.reshape(tensor, shape=(batch_size, num_elements)))
- if cols_to_vars is not None:
- # Retrieve any variables created (some _DenseColumn's don't create
- # variables, in which case an empty list is returned).
- cols_to_vars[column] = ops.get_collection(
- ops.GraphKeys.GLOBAL_VARIABLES,
- scope=variable_scope.get_variable_scope().name)
- _verify_static_batch_size_equality(output_tensors, ordered_columns)
- return array_ops.concat(output_tensors, 1)
+ Some `FeatureColumn`s create variables or resources to assist their
+ computation. The `StateManager` is responsible for creating and storing these
+ objects since `FeatureColumn`s are supposed to be stateless configuration
+ only.
+ """
+
+ def create_variable(self,
+ feature_column,
+ name,
+ shape,
+ dtype=None,
+ trainable=True,
+ initializer=None):
+ """Creates a new variable.
+
+ Args:
+ feature_column: A `FeatureColumn` object this variable corresponds to.
+ name: variable name.
+ shape: variable shape.
+ dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
+ trainable: Whether this variable is trainable or not.
+ initializer: initializer instance (callable).
+ Returns:
+ The created variable.
+ """
+ del feature_column, name, shape, dtype, trainable, initializer
+ raise NotImplementedError('StateManager.create_variable')
+
+ def add_variable(self, feature_column, var):
+ """Adds an existing variable to the state.
+
+ Args:
+ feature_column: A `FeatureColumn` object to associate this variable with.
+ var: The variable.
+ """
+ del feature_column, var
+ raise NotImplementedError('StateManager.add_variable')
+
+ def get_variable(self, feature_column, name):
+ """Returns an existing variable.
+
+ Args:
+ feature_column: A `FeatureColumn` object this variable corresponds to.
+ name: variable name.
+ """
+ del feature_column, name
+    raise NotImplementedError('StateManager.get_variable')
-def input_layer(features,
- feature_columns,
- weight_collections=None,
- trainable=True,
- cols_to_vars=None):
- """Returns a dense `Tensor` as input layer based on given `feature_columns`.
+ def add_resource(self, feature_column, name, resource):
+    """Adds a new resource to the state.
+
+ Resources can be things such as tables etc.
+
+ Args:
+ feature_column: A `FeatureColumn` object this resource corresponds to.
+ name: Name of the resource.
+ resource: The resource.
+
+ Returns:
+      The added resource.
+ """
+ del feature_column, name, resource
+ raise NotImplementedError('StateManager.add_resource')
+
+ def get_resource(self, feature_column, name):
+ """Returns an already created resource.
+
+ Resources can be things such as tables etc.
+
+ Args:
+ feature_column: A `FeatureColumn` object this variable corresponds to.
+ name: Name of the resource.
+ """
+ del feature_column, name
+ raise NotImplementedError('StateManager.get_resource')
+
+
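Before the concrete managers below, a toy implementation may help fix the
contract; this dict-backed sketch is illustrative only and not part of the
patch (it skips the resource and add_variable methods):

```python
# Illustrative only: the simplest possible StateManager, keyed by
# (feature_column, name). The real managers below add variable scoping and
# cross-layer sharing on top of this idea.
class _DictStateManager(StateManager):

  def __init__(self):
    self._vars = {}

  def create_variable(self, feature_column, name, shape, dtype=None,
                      trainable=True, initializer=None):
    if (feature_column, name) in self._vars:
      raise ValueError('Variable already exists.')
    var = variable_scope.get_variable(
        name, shape=shape, dtype=dtype, trainable=trainable,
        initializer=initializer)
    self._vars[(feature_column, name)] = var
    return var

  def get_variable(self, feature_column, name):
    if (feature_column, name) not in self._vars:
      raise ValueError('Variable does not exist.')
    return self._vars[(feature_column, name)]
```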
+class _InputLayerStateManager(StateManager):
+ """Manages the state of InputLayer."""
+
+ def __init__(self, layer, feature_columns, trainable):
+ """Creates an _InputLayerStateManager object.
+
+ Args:
+ layer: The input layer this state manager is associated with.
+ feature_columns: List of feature columns for the input layer
+ trainable: Whether by default, variables created are trainable or not.
+ """
+ self._trainable = trainable
+ self._layer = layer
+ self._cols_to_vars_map = {}
+ self._cols_to_names_map = {}
+ for column in sorted(feature_columns, key=lambda x: x.name):
+ self._cols_to_vars_map[column] = {}
+ base_name = column.name
+ if isinstance(column, SharedEmbeddingColumn):
+ base_name = column.shared_collection_name
+ with variable_scope.variable_scope(base_name) as vs:
+ self._cols_to_names_map[column] = _strip_leading_slashes(vs.name)
+
+ def create_variable(self,
+ feature_column,
+ name,
+ shape,
+ dtype=None,
+ trainable=True,
+ initializer=None):
+ if name in self._cols_to_vars_map[feature_column]:
+ raise ValueError('Variable already exists.')
+ with variable_scope.variable_scope(self._cols_to_names_map[feature_column]):
+ var = self._layer.add_variable(
+ name=name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ trainable=self._trainable and trainable,
+ # TODO(rohanj): Get rid of this hack once we have a mechanism for
+ # specifying a default partitioner for an entire layer. In that case,
+ # the default getter for Layers should work.
+ getter=variable_scope.get_variable)
+ self._cols_to_vars_map[feature_column][name] = var
+ return var
+
+ def get_variable(self, feature_column, name):
+ if name in self._cols_to_vars_map[feature_column]:
+ return self._cols_to_vars_map[feature_column][name]
+ raise ValueError('Variable does not exist.')
+
+
+class FeatureLayer(Layer):
+ """A layer that produces a dense `Tensor` based on given `feature_columns`.
Generally a single example in training data is described with FeatureColumns.
At the first layer of the model, this column-oriented data should be converted
to a single `Tensor`.
+ This layer can be called multiple times with different features.
+
Example:
```python
@@ -233,105 +314,138 @@ def input_layer(features,
categorical_column_with_hash_bucket("keywords", 10000), dimension=16)
columns = [price, keywords_embedded, ...]
features = tf.parse_example(..., features=make_parse_example_spec(columns))
- dense_tensor = input_layer(features, columns)
+ feature_layer = FeatureLayer(columns)
+ dense_tensor = feature_layer(features)
for units in [128, 64, 32]:
dense_tensor = tf.layers.dense(dense_tensor, units, tf.nn.relu)
- prediction = tf.layers.dense(dense_tensor, 1)
- ```
-
- Args:
- features: A mapping from key to tensors. `_FeatureColumn`s look up via these
- keys. For example `numeric_column('price')` will look at 'price' key in
- this dict. Values can be a `SparseTensor` or a `Tensor` depends on
- corresponding `_FeatureColumn`.
- feature_columns: An iterable containing the FeatureColumns to use as inputs
- to your model. All items should be instances of classes derived from
- `_DenseColumn` such as `numeric_column`, `embedding_column`,
- `bucketized_column`, `indicator_column`. If you have categorical features,
- you can wrap them with an `embedding_column` or `indicator_column`.
- weight_collections: A list of collection names to which the Variable will be
- added. Note that variables will also be added to collections
- `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
- trainable: If `True` also add the variable to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- cols_to_vars: If not `None`, must be a dictionary that will be filled with a
- mapping from `_FeatureColumn` to list of `Variable`s. For example, after
- the call, we might have cols_to_vars =
- {_EmbeddingColumn(
- categorical_column=_HashedCategoricalColumn(
- key='sparse_feature', hash_bucket_size=5, dtype=tf.string),
- dimension=10): [<tf.Variable 'some_variable:0' shape=(5, 10),
- <tf.Variable 'some_variable:1' shape=(5, 10)]}
- If a column creates no variables, its value will be an empty list.
-
- Returns:
- A `Tensor` which represents input layer of a model. Its shape
- is (batch_size, first_layer_dimension) and its dtype is `float32`.
- first_layer_dimension is determined based on given `feature_columns`.
-
- Raises:
- ValueError: if an item in `feature_columns` is not a `_DenseColumn`.
- """
- return _internal_input_layer(features, feature_columns, weight_collections,
- trainable, cols_to_vars)
-
-
-# TODO(akshayka): InputLayer should be a subclass of Layer, and it
-# should implement the logic in input_layer using Layer's build-and-call
-# paradigm; input_layer should create an instance of InputLayer and
-# return the result of invoking its apply method, just as functional layers do.
-class InputLayer(object):
- """An object-oriented version of `input_layer` that reuses variables."""
+      prediction = tf.layers.dense(dense_tensor, 1)
+  ```
+  """
def __init__(self,
feature_columns,
- weight_collections=None,
trainable=True,
- cols_to_vars=None):
- """See `input_layer`."""
+ name=None,
+ shared_state_manager=None,
+ **kwargs):
+ """Constructs a FeatureLayer.
- self._feature_columns = feature_columns
- self._weight_collections = weight_collections
- self._trainable = trainable
- self._cols_to_vars = cols_to_vars
- self._input_layer_template = template.make_template(
- 'feature_column_input_layer',
- _internal_input_layer,
- create_scope_now_=True)
- self._scope = self._input_layer_template.variable_scope
-
- def __call__(self, features):
- return self._input_layer_template(
- features=features,
- feature_columns=self._feature_columns,
- weight_collections=self._weight_collections,
- trainable=self._trainable,
- cols_to_vars=None,
- scope=self._scope)
+ Args:
+ feature_columns: An iterable containing the FeatureColumns to use as
+ inputs to your model. All items should be instances of classes derived
+ from `DenseColumn` such as `numeric_column`, `embedding_column`,
+ `bucketized_column`, `indicator_column`. If you have categorical
+ features, you can wrap them with an `embedding_column` or
+ `indicator_column`.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ name: Name to give to the FeatureLayer.
+ shared_state_manager: SharedEmbeddingStateManager that manages the state
+        of SharedEmbeddingColumns. The state of SharedEmbeddingColumns, unlike
+        that of regular embedding columns, cannot be owned by the FeatureLayer
+        itself, since SharedEmbeddingColumns can be shared across different
+        FeatureLayers. As a result, users are expected to create a
+        SharedEmbeddingStateManager object which is responsible for managing
+        the shared state and can be passed into different FeatureLayer objects
+        to share state. For example,
+
+ ```python
+      sc_1, sc_2 = shared_embedding_columns_v2(...)
+      sc_3, sc_4 = shared_embedding_columns_v2(...)
+ ssm = SharedEmbeddingStateManager()
+ feature_layer1 = FeatureLayer([sc_1, sc_3], ...,
+ shared_state_manager=ssm)
+ feature_layer2 = FeatureLayer([sc_2, sc_4], ...,
+ shared_state_manager=ssm)
+ ```
+        now feature_layer1 and feature_layer2 will share variables. If sharing
+        is not desired, one can create two separate
+        SharedEmbeddingStateManager objects
+
+ ```python
+ ssm1 = SharedEmbeddingStateManager()
+ ssm2 = SharedEmbeddingStateManager()
+ feature_layer1 = FeatureLayer([sc_1, sc_3], ...,
+ shared_state_manager=ssm1)
+ feature_layer2 = FeatureLayer([sc_2, sc_4], ...,
+ shared_state_manager=ssm2)
+ ```
+ **kwargs: Keyword arguments to construct a layer.
- @property
- def non_trainable_variables(self):
- return self._input_layer_template.non_trainable_variables
+ Raises:
+ ValueError: if an item in `feature_columns` is not a `DenseColumn`.
+ """
+ super(FeatureLayer, self).__init__(name=name, trainable=trainable, **kwargs)
- @property
- def non_trainable_weights(self):
- return self._input_layer_template.non_trainable_weights
+ self._feature_columns = _normalize_feature_columns(feature_columns)
+ self._state_manager = _InputLayerStateManager(self, self._feature_columns,
+ self.trainable)
+ self._shared_state_manager = shared_state_manager
+ for column in sorted(self._feature_columns, key=lambda x: x.name):
+ if not isinstance(column, DenseColumn):
+ raise ValueError(
+ 'Items of feature_columns must be a DenseColumn. '
+ 'You can wrap a categorical column with an '
+ 'embedding_column or indicator_column. Given: {}'.format(column))
@property
- def trainable_variables(self):
- return self._input_layer_template.trainable_variables
+ def _is_feature_layer(self):
+ return True
- @property
- def trainable_weights(self):
- return self._input_layer_template.trainable_weights
+ def build(self, _):
+ for column in sorted(self._feature_columns, key=lambda x: x.name):
+ if isinstance(column, SharedEmbeddingColumn):
+ column.create_state(self._shared_state_manager)
+ else:
+ with variable_scope.variable_scope(None, default_name=self.name):
+ column.create_state(self._state_manager)
+ super(FeatureLayer, self).build(None)
- @property
- def variables(self):
- return self._input_layer_template.variables
+ def call(self, features, cols_to_output_tensors=None):
+ """Returns a dense tensor corresponding to the `feature_columns`.
- @property
- def weights(self):
- return self._input_layer_template.weights
+ Args:
+      features: A mapping from key to tensors. `FeatureColumn`s look up via
+        these keys. For example `numeric_column('price')` will look at 'price'
+        key in this dict. Values can be a `SparseTensor` or a `Tensor`,
+        depending on the corresponding `FeatureColumn`.
+      cols_to_output_tensors: If not `None`, this will be filled with a dict
+        mapping feature columns to the output tensors created.
+
+ Returns:
+ A `Tensor` which represents input layer of a model. Its shape
+ is (batch_size, first_layer_dimension) and its dtype is `float32`.
+ first_layer_dimension is determined based on given `feature_columns`.
+
+ Raises:
+ ValueError: If features are not a dictionary.
+ """
+ if not isinstance(features, dict):
+      raise ValueError(
+          'We expected a dictionary here. Instead we got: {}'.format(features))
+ transformation_cache = FeatureTransformationCache(features)
+ output_tensors = []
+ ordered_columns = []
+ for column in sorted(self._feature_columns, key=lambda x: x.name):
+ ordered_columns.append(column)
+ if isinstance(column, SharedEmbeddingColumn):
+ tensor = column.get_dense_tensor(transformation_cache,
+ self._shared_state_manager)
+ else:
+ tensor = column.get_dense_tensor(transformation_cache,
+ self._state_manager)
+ num_elements = column.variable_shape.num_elements()
+ batch_size = array_ops.shape(tensor)[0]
+ tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
+ output_tensors.append(tensor)
+ if cols_to_output_tensors is not None:
+ cols_to_output_tensors[column] = tensor
+
+ _verify_static_batch_size_equality(output_tensors, ordered_columns)
+ return array_ops.concat(output_tensors, 1)
+
+ def compute_output_shape(self, input_shape):
+ total_elements = 0
+ for column in sorted(self._feature_columns, key=lambda x: x.name):
+ total_elements += column.variable_shape.num_elements()
+ return (input_shape[0], total_elements)
def linear_model(features,
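As a usage note for the new layer, here is a hedged end-to-end sketch that
combines the class docstring example with `cols_to_output_tensors`; `features`
is assumed to be the dict produced by `tf.parse_example`:

```python
# Hedged sketch: FeatureLayer as the first layer of a model, also collecting
# the per-column output tensors.
price = numeric_column('price')
keywords = embedding_column(
    categorical_column_with_hash_bucket('keywords', 10000), dimension=16)

feature_layer = FeatureLayer([price, keywords])
cols_to_output = {}
dense_tensor = feature_layer(features, cols_to_output_tensors=cols_to_output)
# cols_to_output[price] and cols_to_output[keywords] now hold the reshaped
# (batch_size, num_elements) tensors that were concatenated into dense_tensor.
```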
@@ -565,12 +679,15 @@ class _BiasLayer(base.Layer):
return self._bias_variable
-def _get_expanded_variable_list(variable):
- if (isinstance(variable, variables.Variable) or
- resource_variable_ops.is_resource_variable(variable)):
- return [variable] # Single variable case.
- else: # Must be a PartitionedVariable, so convert into a list.
- return list(variable)
+def _get_expanded_variable_list(var_list):
+ returned_list = []
+ for variable in var_list:
+ if (isinstance(variable, variables.Variable) or
+ resource_variable_ops.is_resource_variable(variable)):
+ returned_list.append(variable) # Single variable case.
+ else: # Must be a PartitionedVariable, so convert into a list.
+ returned_list.extend(list(variable))
+ return returned_list
def _strip_leading_slashes(name):
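The signature change above turns `_get_expanded_variable_list` into a
flattener over a list, expanding any `PartitionedVariable` into its shards. A
small behavior sketch with hypothetical variables:

```python
# Hypothetical inputs: plain variables pass through, a PartitionedVariable is
# expanded into its underlying shard variables.
from tensorflow.python.ops import partitioned_variables, variable_scope, variables

v = variables.Variable([0.])
p = variable_scope.get_variable(
    'w', shape=[4],
    partitioner=partitioned_variables.fixed_size_partitioner(2))
flat = _get_expanded_variable_list([v, p])  # -> [v, w/part_0, w/part_1]
```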
@@ -661,7 +778,7 @@ class _LinearModel(training.Model):
scope=variable_scope.get_variable_scope()), # pylint: disable=not-callable
name='weighted_sum')
bias = self._bias_layer.variables[0]
- self._cols_to_vars['bias'] = _get_expanded_variable_list(bias)
+ self._cols_to_vars['bias'] = _get_expanded_variable_list([bias])
return predictions
def _add_layers(self, layers):
@@ -877,10 +994,15 @@ def embedding_column(
trainable=trainable)
-def shared_embedding_columns(
- categorical_columns, dimension, combiner='mean', initializer=None,
- shared_embedding_collection_name=None, ckpt_to_load_from=None,
- tensor_name_in_ckpt=None, max_norm=None, trainable=True):
+def shared_embedding_columns_v2(categorical_columns,
+ dimension,
+ combiner='mean',
+ initializer=None,
+ shared_embedding_collection_name=None,
+ ckpt_to_load_from=None,
+ tensor_name_in_ckpt=None,
+ max_norm=None,
+ trainable=True):
"""List of dense columns that convert from sparse, categorical input.
This is similar to `embedding_column`, except that it produces a list of
@@ -1803,51 +1925,6 @@ def crossed_column(keys, hash_bucket_size, hash_key=None):
keys=tuple(keys), hash_bucket_size=hash_bucket_size, hash_key=hash_key)
-class StateManager(object):
- """Manages the state associated with FeatureColumns.
-
- Some `FeatureColumn`s create variables or resources to assist their
- computation. The `StateManager` is responsible for creating and storing these
- objects since `FeatureColumn`s are supposed to be stateless configuration
- only.
- """
-
- def get_variable(self,
- feature_column,
- name,
- shape,
- dtype=None,
- initializer=None):
- """Creates a new variable or returns an existing one.
-
- Args:
- feature_column: A `FeatureColumn` object this variable corresponds to.
- name: variable name.
- shape: variable shape.
- dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
- initializer: initializer instance (callable).
-
- Returns:
- The variable.
- """
- raise NotImplementedError('StateManager.get_variable')
-
- def get_resource(self, feature_column, name, resource_creator):
- """Creates a new resource or returns an existing one.
-
- Resources can be things such as tables etc.
-
- Args:
- feature_column: A `FeatureColumn` object this variable corresponds to.
- name: Name of the resource.
- resource_creator: A callable that can create the resource.
-
- Returns:
- The resource.
- """
- raise NotImplementedError('StateManager.get_resource')
-
-
class FeatureColumn(object):
"""Represents a feature column abstraction.
@@ -2550,6 +2627,17 @@ class EmbeddingColumn(
"""See `DenseColumn` base class."""
return tensor_shape.vector(self.dimension)
+ def create_state(self, state_manager):
+ """Creates the embedding lookup variable."""
+ embedding_shape = (self.categorical_column.num_buckets, self.dimension)
+ state_manager.create_variable(
+ self,
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ trainable=self.trainable,
+ initializer=self.initializer)
+
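The hunks around here split variable creation out of the lookup path:
`create_state` builds `embedding_weights` up front, and the dense-tensor path
merely fetches it via `get_variable`. A hedged sketch of the two-phase flow,
where `manager` stands in for any concrete `StateManager`:

```python
# Hedged sketch of the new two-phase flow (column constructors are
# module-level functions in this file; `manager` is assumed).
column = embedding_column(
    categorical_column_with_identity(key='aaa', num_buckets=3), dimension=2)
column.create_state(manager)  # creates the 'embedding_weights' variable
weights = manager.get_variable(column, name='embedding_weights')
```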
def _get_dense_tensor_internal(self, transformation_cache, state_manager):
"""Private method that follows the signature of _get_dense_tensor."""
# Get sparse IDs and weights.
@@ -2558,13 +2646,8 @@ class EmbeddingColumn(
sparse_ids = sparse_tensors.id_tensor
sparse_weights = sparse_tensors.weight_tensor
- embedding_shape = (self.categorical_column.num_buckets, self.dimension)
embedding_weights = state_manager.get_variable(
- self,
- name='embedding_weights',
- shape=embedding_shape,
- dtype=dtypes.float32,
- initializer=self.initializer)
+ self, name='embedding_weights')
if self.ckpt_to_load_from is not None:
to_restore = embedding_weights
@@ -2637,6 +2720,68 @@ def _get_graph_for_variable(var):
return var.graph
+class SharedEmbeddingStateManager(Layer):
+  """A state manager that handles the state of shared embedding columns.
+
+ This can handle multiple sets of columns that share variables."""
+
+ def __init__(self, trainable=True, name=None, **kwargs):
+ """Constructs a `SharedEmbeddingStateManager`.
+
+ Args:
+ trainable: If true, variables created are trainable.
+ name: Name of the State Manager.
+ **kwargs: Keyword arguments.
+ """
+ super(SharedEmbeddingStateManager, self).__init__(
+ name=name, trainable=trainable, **kwargs)
+ self._var_dict = {}
+
+ def create_variable(self,
+ name,
+ shape,
+ dtype=None,
+ trainable=True,
+ initializer=None):
+ """Creates a variable.
+
+    Makes sure only one variable is created per name. Callers are expected to
+    pass the group's `shared_collection_name` as `name`, so that all columns
+    in a shared embedding group end up sharing a single variable.
+
+    Args:
+      name: Name of the variable; the `shared_collection_name` of the group.
+ shape: Variable shape.
+ dtype: Variable type.
+ trainable: If variable created should be trainable or not.
+ initializer: Variable initializer.
+
+ Returns:
+ A variable or partitioned variable.
+ """
+ if name in self._var_dict:
+ var = self._var_dict[name]
+ return var
+ with variable_scope.variable_scope(
+ self.name, reuse=variable_scope.AUTO_REUSE):
+ var = self.add_variable(
+ name=name,
+ shape=shape,
+ dtype=dtype,
+ trainable=self.trainable and trainable,
+ initializer=initializer,
+ # TODO(rohanj): Get rid of this hack once we have a mechanism for
+ # specifying a default partitioner for an entire layer. In that case,
+ # the default getter for Layers should work.
+ getter=variable_scope.get_variable)
+ self._var_dict[name] = var
+ return var
+
+ def get_variable(self, feature_column, name):
+ if name not in self._var_dict:
+ raise ValueError('Variable name: {} not recognized.'.format(name))
+ return self._var_dict[name]
+
+
class SharedEmbeddingColumn(
DenseColumn, SequenceDenseColumn,
collections.namedtuple(
@@ -2675,6 +2820,16 @@ class SharedEmbeddingColumn(
"""See `DenseColumn` base class."""
return tensor_shape.vector(self.dimension)
+ def create_state(self, state_manager):
+ """Creates the shared embedding lookup variable."""
+ embedding_shape = (self.categorical_column.num_buckets, self.dimension)
+ state_manager.create_variable(
+ name=self.shared_collection_name,
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ trainable=self.trainable,
+ initializer=self.initializer)
+
def _get_dense_tensor_internal(self, transformation_cache, state_manager):
"""Private method that follows the signature of _get_dense_tensor."""
# This method is called from a variable_scope with name _var_scope_name,
@@ -2687,13 +2842,8 @@ class SharedEmbeddingColumn(
sparse_ids = sparse_tensors.id_tensor
sparse_weights = sparse_tensors.weight_tensor
- embedding_shape = (self.categorical_column.num_buckets, self.dimension)
embedding_weights = state_manager.get_variable(
- self,
- name='embedding_weights',
- shape=embedding_shape,
- dtype=dtypes.float32,
- initializer=self.initializer)
+ self, name=self.shared_collection_name)
if self.ckpt_to_load_from is not None:
to_restore = embedding_weights
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
index 80a9d5d40e..58168e0f9e 100644
--- a/tensorflow/python/feature_column/feature_column_v2_test.py
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -33,12 +33,12 @@ from tensorflow.python.eager import context
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.feature_column import feature_column_v2 as fc
+from tensorflow.python.feature_column.feature_column_v2 import _LinearModel
+from tensorflow.python.feature_column.feature_column_v2 import _transform_features
from tensorflow.python.feature_column.feature_column_v2 import FeatureColumn
+from tensorflow.python.feature_column.feature_column_v2 import FeatureLayer
from tensorflow.python.feature_column.feature_column_v2 import FeatureTransformationCache
-from tensorflow.python.feature_column.feature_column_v2 import InputLayer
from tensorflow.python.feature_column.feature_column_v2 import StateManager
-from tensorflow.python.feature_column.feature_column_v2 import _LinearModel
-from tensorflow.python.feature_column.feature_column_v2 import _transform_features
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -269,7 +269,7 @@ class NumericColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([price]))
self.assertIn('price', features)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[20., 110.]], features['price'].eval())
def test_parse_example_with_default_value(self):
@@ -291,7 +291,7 @@ class NumericColumnTest(test.TestCase):
no_data.SerializeToString()],
features=fc.make_parse_example_spec([price]))
self.assertIn('price', features)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[20., 110.], [11., 11.]], features['price'].eval())
def test_normalizer_fn_must_be_callable(self):
@@ -305,7 +305,7 @@ class NumericColumnTest(test.TestCase):
price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two)
output = _transform_features({'price': [[1., 2.], [5., 6.]]}, [price], None)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[3., 4.], [7., 8.]], output[price].eval())
def test_get_dense_tensor(self):
@@ -439,7 +439,7 @@ class BucketizedColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([bucketized_price]))
self.assertIn('price', features)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[20., 110.]], features['price'].eval())
def test_transform_feature(self):
@@ -717,7 +717,7 @@ class HashedCategoricalColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -736,7 +736,7 @@ class HashedCategoricalColumnTest(test.TestCase):
output = outputs[hashed_sparse]
# Check exact hashed output. If hashing changes this test will break.
expected_values = [6, 4, 1]
- with self.test_session():
+ with self.cached_session():
self.assertEqual(dtypes.int64, output.values.dtype)
self.assertAllEqual(expected_values, output.values.eval())
self.assertAllEqual(wire_tensor.indices.eval(), output.indices.eval())
@@ -792,7 +792,7 @@ class HashedCategoricalColumnTest(test.TestCase):
output = transformation_cache.get(hashed_sparse, None)
# Check exact hashed output. If hashing changes this test will break.
expected_values = [3, 7, 5]
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_values, output.values.eval())
def test_int32_64_is_compatible(self):
@@ -806,7 +806,7 @@ class HashedCategoricalColumnTest(test.TestCase):
output = transformation_cache.get(hashed_sparse, None)
# Check exact hashed output. If hashing changes this test will break.
expected_values = [3, 7, 5]
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_values, output.values.eval())
def test_get_sparse_tensors(self):
@@ -824,22 +824,6 @@ class HashedCategoricalColumnTest(test.TestCase):
self.assertEqual(
transformation_cache.get(hashed_sparse, None), id_weight_pair.id_tensor)
- def DISABLED_test_get_sparse_tensors_weight_collections(self):
- column = fc.categorical_column_with_hash_bucket('aaa', 10)
- inputs = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'],
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- column._get_sparse_tensors(
- FeatureTransformationCache({
- 'aaa': inputs
- }),
- weight_collections=('my_weights',))
-
- self.assertItemsEqual(
- [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
- self.assertItemsEqual([], ops.get_collection('my_weights'))
-
def test_get_sparse_tensors_dense_input(self):
hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
transformation_cache = FeatureTransformationCache({
@@ -1000,7 +984,7 @@ class CrossedColumnTest(test.TestCase):
features=fc.make_parse_example_spec([price_cross_wire]))
self.assertIn('price', features)
self.assertIn('wire', features)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[20., 110.]], features['price'].eval())
wire_sparse = features['wire']
self.assertAllEqual([[0, 0], [0, 1]], wire_sparse.indices.eval())
@@ -1023,7 +1007,7 @@ class CrossedColumnTest(test.TestCase):
}
outputs = _transform_features(features, [price_cross_wire], None)
output = outputs[price_cross_wire]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output_val = sess.run(output)
self.assertAllEqual(
[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], output_val.indices)
@@ -2640,13 +2624,13 @@ class _LinearModelTest(test.TestCase):
sess.run(net, feed_dict={features['price']: np.array(1)})
-class InputLayerTest(test.TestCase):
+class FeatureLayerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def test_retrieving_input(self):
features = {'a': [0.]}
- input_layer = InputLayer(fc_old.numeric_column('a'))
- inputs = self.evaluate(input_layer(features))
+ feature_layer = FeatureLayer(fc.numeric_column('a'))
+ inputs = self.evaluate(feature_layer(features))
self.assertAllClose([[0.]], inputs)
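[Editor's aside, not part of the diff: the renamed tests below exercise the object-style FeatureLayer, which builds its variables once and reuses them on later calls, unlike the functional fc.input_layer it replaces. A rough stand-in class, purely illustrative, shows the reuse semantics being asserted:]

class FeatureLayerSketch(object):
    """Not the TF class; just the create-once/reuse behavior under test."""

    def __init__(self, columns):
        self._columns = list(columns)
        self.variables = []

    def __call__(self, features):
        if not self.variables:            # state is created on first call only
            self.variables.append('embedding_weights')
        return [features[c] for c in self._columns]

layer = FeatureLayerSketch(['a'])
layer({'a': [0.]})
layer({'a': [0.]})
assert len(layer.variables) == 1          # second call reused the first state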
def test_reuses_variables(self):
@@ -2657,7 +2641,7 @@ class InputLayerTest(test.TestCase):
dense_shape=(3, 3))
# Create feature columns (categorical and embedding).
- categorical_column = fc_old.categorical_column_with_identity(
+ categorical_column = fc.categorical_column_with_identity(
key='a', num_buckets=3)
embedding_dimension = 2
def _embedding_column_initializer(shape, dtype, partition_info):
@@ -2670,16 +2654,16 @@ class InputLayerTest(test.TestCase):
(1, 1)) # id 2
return embedding_values
- embedding_column = fc_old.embedding_column(
+ embedding_column = fc.embedding_column(
categorical_column,
dimension=embedding_dimension,
initializer=_embedding_column_initializer)
- input_layer = InputLayer([embedding_column])
+ feature_layer = FeatureLayer([embedding_column])
features = {'a': sparse_input}
- inputs = input_layer(features)
- variables = input_layer.variables
+ inputs = feature_layer(features)
+ variables = feature_layer.variables
# Sanity check: test that the inputs are correct.
self.assertAllEqual([[1, 0], [0, 1], [1, 1]], inputs)
@@ -2687,13 +2671,13 @@ class InputLayerTest(test.TestCase):
# Check that only one variable was created.
self.assertEqual(1, len(variables))
- # Check that invoking input_layer on the same features does not create
+ # Check that invoking feature_layer on the same features does not create
# additional variables
- _ = input_layer(features)
+ _ = feature_layer(features)
self.assertEqual(1, len(variables))
- self.assertEqual(variables[0], input_layer.variables[0])
+ self.assertEqual(variables[0], feature_layer.variables[0])
- def test_feature_column_input_layer_gradient(self):
+ def test_feature_column_feature_layer_gradient(self):
with context.eager_mode():
sparse_input = sparse_tensor.SparseTensor(
indices=((0, 0), (1, 0), (2, 0)),
@@ -2701,7 +2685,7 @@ class InputLayerTest(test.TestCase):
dense_shape=(3, 3))
# Create feature columns (categorical and embedding).
- categorical_column = fc_old.categorical_column_with_identity(
+ categorical_column = fc.categorical_column_with_identity(
key='a', num_buckets=3)
embedding_dimension = 2
@@ -2715,16 +2699,16 @@ class InputLayerTest(test.TestCase):
(1, 1)) # id 2
return embedding_values
- embedding_column = fc_old.embedding_column(
+ embedding_column = fc.embedding_column(
categorical_column,
dimension=embedding_dimension,
initializer=_embedding_column_initializer)
- input_layer = InputLayer([embedding_column])
+ feature_layer = FeatureLayer([embedding_column])
features = {'a': sparse_input}
def scale_matrix():
- matrix = input_layer(features)
+ matrix = feature_layer(features)
return 2 * matrix
# Sanity check: Verify that scale_matrix returns the correct output.
@@ -2739,185 +2723,154 @@ class InputLayerTest(test.TestCase):
self.assertAllEqual([0, 1, 2], indexed_slice.indices)
self.assertAllEqual([[2, 2], [2, 2], [2, 2]], gradient)
-
-class FunctionalInputLayerTest(test.TestCase):
-
def test_raises_if_empty_feature_columns(self):
with self.assertRaisesRegexp(ValueError,
'feature_columns must not be empty'):
- fc.input_layer(features={}, feature_columns=[])
+ FeatureLayer(feature_columns=[])(features={})
def test_should_be_dense_column(self):
- with self.assertRaisesRegexp(ValueError, 'must be a _DenseColumn'):
- fc.input_layer(
- features={'a': [[0]]},
- feature_columns=[
- fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- ])
+ with self.assertRaisesRegexp(ValueError, 'must be a DenseColumn'):
+ FeatureLayer(feature_columns=[
+ fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ ])(
+ features={
+ 'a': [[0]]
+ })
def test_does_not_support_dict_columns(self):
with self.assertRaisesRegexp(
ValueError, 'Expected feature_columns to be iterable, found dict.'):
- fc.input_layer(
- features={'a': [[0]]},
- feature_columns={'a': fc_old.numeric_column('a')})
+ FeatureLayer(feature_columns={'a': fc.numeric_column('a')})(
+ features={
+ 'a': [[0]]
+ })
def test_bare_column(self):
with ops.Graph().as_default():
      features = {'a': [0.]}
- net = fc.input_layer(features, fc_old.numeric_column('a'))
+ net = FeatureLayer(fc.numeric_column('a'))(features)
with _initialized_session():
self.assertAllClose([[0.]], net.eval())
def test_column_generator(self):
with ops.Graph().as_default():
      features = {'a': [0.], 'b': [1.]}
- columns = (fc_old.numeric_column(key) for key in features)
- net = fc.input_layer(features, columns)
+ columns = (fc.numeric_column(key) for key in features)
+ net = FeatureLayer(columns)(features)
with _initialized_session():
self.assertAllClose([[0., 1.]], net.eval())
def test_raises_if_duplicate_name(self):
with self.assertRaisesRegexp(
ValueError, 'Duplicate feature column name found for columns'):
- fc.input_layer(
- features={'a': [[0]]},
- feature_columns=[
- fc_old.numeric_column('a'),
- fc_old.numeric_column('a')
- ])
+ FeatureLayer(
+ feature_columns=[fc.numeric_column('a'),
+ fc.numeric_column('a')])(
+ features={
+ 'a': [[0]]
+ })
def test_one_column(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
- net = fc.input_layer(features, [price])
+ net = FeatureLayer([price])(features)
with _initialized_session():
self.assertAllClose([[1.], [5.]], net.eval())
def test_multi_dimension(self):
- price = fc_old.numeric_column('price', shape=2)
+ price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1., 2.], [5., 6.]]}
- net = fc.input_layer(features, [price])
+ net = FeatureLayer([price])(features)
with _initialized_session():
self.assertAllClose([[1., 2.], [5., 6.]], net.eval())
+ def test_compute_output_shape(self):
+ price1 = fc.numeric_column('price1', shape=2)
+ price2 = fc.numeric_column('price2', shape=4)
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[1., 2.], [5., 6.]],
+ 'price2': [[3., 4., 5., 6.], [7., 8., 9., 10.]]
+ }
+ feature_layer = FeatureLayer([price1, price2])
+ self.assertEqual((None, 6), feature_layer.compute_output_shape((None,)))
+ net = feature_layer(features)
+ with _initialized_session():
+ self.assertAllClose(
+ [[1., 2., 3., 4., 5., 6.], [5., 6., 7., 8., 9., 10.]], net.eval())
+
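[Editor's aside, not part of the diff: the shape arithmetic behind the new test_compute_output_shape is just a sum of the columns' dense dimensions, with the batch dimension passed through. In plain Python:]

dims = {'price1': 2, 'price2': 4}
batch = None                        # unknown batch size propagates unchanged
print((batch, sum(dims.values())))  # -> (None, 6), matching the assertion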
def test_raises_if_shape_mismatch(self):
- price = fc_old.numeric_column('price', shape=2)
+ price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
with self.assertRaisesRegexp(
Exception,
r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
- fc.input_layer(features, [price])
+ FeatureLayer([price])(features)
def test_reshaping(self):
- price = fc_old.numeric_column('price', shape=[1, 2])
+ price = fc.numeric_column('price', shape=[1, 2])
with ops.Graph().as_default():
features = {'price': [[[1., 2.]], [[5., 6.]]]}
- net = fc.input_layer(features, [price])
+ net = FeatureLayer([price])(features)
with _initialized_session():
self.assertAllClose([[1., 2.], [5., 6.]], net.eval())
def test_multi_column(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1', shape=2)
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': [[1., 2.], [5., 6.]],
'price2': [[3.], [4.]]
}
- net = fc.input_layer(features, [price1, price2])
+ net = FeatureLayer([price1, price2])(features)
with _initialized_session():
self.assertAllClose([[1., 2., 3.], [5., 6., 4.]], net.eval())
- def test_fills_cols_to_vars(self):
- # Provide three _DenseColumn's to input_layer: a _NumericColumn, a
- # _BucketizedColumn, and an _EmbeddingColumn. Only the _EmbeddingColumn
- # creates a Variable.
- price1 = fc_old.numeric_column('price1')
- dense_feature = fc_old.numeric_column('dense_feature')
- dense_feature_bucketized = fc_old.bucketized_column(
- dense_feature, boundaries=[0.])
- some_sparse_column = fc_old.categorical_column_with_hash_bucket(
- 'sparse_feature', hash_bucket_size=5)
- some_embedding_column = fc_old.embedding_column(
- some_sparse_column, dimension=10)
+ def test_cols_to_output_tensors(self):
+ price1 = fc.numeric_column('price1', shape=2)
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
- features = {
- 'price1': [[3.], [4.]],
- 'dense_feature': [[-1.], [4.]],
- 'sparse_feature': [['a'], ['x']],
- }
- cols_to_vars = {}
- all_cols = [price1, dense_feature_bucketized, some_embedding_column]
- fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
- self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
- self.assertEqual(0, len(cols_to_vars[price1]))
- self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
- self.assertEqual(1, len(cols_to_vars[some_embedding_column]))
- self.assertIsInstance(cols_to_vars[some_embedding_column][0],
- variables_lib.Variable)
- self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [5, 10])
-
- def test_fills_cols_to_vars_partitioned_variables(self):
- price1 = fc_old.numeric_column('price1')
- dense_feature = fc_old.numeric_column('dense_feature')
- dense_feature_bucketized = fc_old.bucketized_column(
- dense_feature, boundaries=[0.])
- some_sparse_column = fc_old.categorical_column_with_hash_bucket(
- 'sparse_feature', hash_bucket_size=5)
- some_embedding_column = fc_old.embedding_column(
- some_sparse_column, dimension=10)
- with ops.Graph().as_default():
- features = {
- 'price1': [[3.], [4.]],
- 'dense_feature': [[-1.], [4.]],
- 'sparse_feature': [['a'], ['x']],
- }
- cols_to_vars = {}
- all_cols = [price1, dense_feature_bucketized, some_embedding_column]
- with variable_scope.variable_scope(
- 'input_from_feature_columns',
- partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0)):
- fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
- self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
- self.assertEqual(0, len(cols_to_vars[price1]))
- self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
- self.assertEqual(3, len(cols_to_vars[some_embedding_column]))
- self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [2, 10])
- self.assertAllEqual(cols_to_vars[some_embedding_column][1].shape, [2, 10])
- self.assertAllEqual(cols_to_vars[some_embedding_column][2].shape, [1, 10])
+ cols_dict = {}
+ features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
+ feature_layer = FeatureLayer([price1, price2])
+ net = feature_layer(features, cols_dict)
+ with _initialized_session():
+ self.assertAllClose([[1., 2.], [5., 6.]], cols_dict[price1].eval())
+ self.assertAllClose([[3.], [4.]], cols_dict[price2].eval())
+ self.assertAllClose([[1., 2., 3.], [5., 6., 4.]], net.eval())
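[Editor's aside, not part of the diff: the cols_dict contract exercised above is that the layer records each column's output tensor and returns their concatenation. The equivalent NumPy arithmetic:]

import numpy as np

outputs = {
    'price1': np.array([[1., 2.], [5., 6.]]),
    'price2': np.array([[3.], [4.]]),
}
net = np.concatenate([outputs['price1'], outputs['price2']], axis=1)
print(net)  # [[1. 2. 3.] [5. 6. 4.]], as asserted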
def test_column_order(self):
- price_a = fc_old.numeric_column('price_a')
- price_b = fc_old.numeric_column('price_b')
+ price_a = fc.numeric_column('price_a')
+ price_b = fc.numeric_column('price_b')
with ops.Graph().as_default():
features = {
'price_a': [[1.]],
'price_b': [[3.]],
}
- net1 = fc.input_layer(features, [price_a, price_b])
- net2 = fc.input_layer(features, [price_b, price_a])
+ net1 = FeatureLayer([price_a, price_b])(features)
+ net2 = FeatureLayer([price_b, price_a])(features)
with _initialized_session():
self.assertAllClose([[1., 3.]], net1.eval())
self.assertAllClose([[1., 3.]], net2.eval())
def test_fails_for_categorical_column(self):
- animal = fc_old.categorical_column_with_identity('animal', num_buckets=4)
+ animal = fc.categorical_column_with_identity('animal', num_buckets=4)
with ops.Graph().as_default():
features = {
'animal':
sparse_tensor.SparseTensor(
indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
}
- with self.assertRaisesRegexp(Exception, 'must be a _DenseColumn'):
- fc.input_layer(features, [animal])
+ with self.assertRaisesRegexp(Exception, 'must be a DenseColumn'):
+ FeatureLayer([animal])(features)
def test_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': [[1.], [5.], [7.]], # batchsize = 3
@@ -2926,12 +2879,12 @@ class FunctionalInputLayerTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- fc.input_layer(features, [price1, price2])
+ FeatureLayer([price1, price2])(features)
def test_subset_of_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- price3 = fc_old.numeric_column('price3')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ price3 = fc.numeric_column('price3')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
@@ -2941,31 +2894,31 @@ class FunctionalInputLayerTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- fc.input_layer(features, [price1, price2, price3])
+ FeatureLayer([price1, price2, price3])(features)
def test_runtime_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
'price2': [[3.], [4.]] # batchsize = 2
}
- net = fc.input_layer(features, [price1, price2])
+ net = FeatureLayer([price1, price2])(features)
with _initialized_session() as sess:
with self.assertRaisesRegexp(errors.OpError,
'Dimensions of inputs should match'):
sess.run(net, feed_dict={features['price1']: [[1.], [5.], [7.]]})
def test_runtime_batch_size_matches(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
}
- net = fc.input_layer(features, [price1, price2])
+ net = FeatureLayer([price1, price2])(features)
with _initialized_session() as sess:
sess.run(
net,
@@ -2975,9 +2928,9 @@ class FunctionalInputLayerTest(test.TestCase):
})
def test_multiple_layers_with_same_embedding_column(self):
- some_sparse_column = fc_old.categorical_column_with_hash_bucket(
+ some_sparse_column = fc.categorical_column_with_hash_bucket(
'sparse_feature', hash_bucket_size=5)
- some_embedding_column = fc_old.embedding_column(
+ some_embedding_column = fc.embedding_column(
some_sparse_column, dimension=10)
with ops.Graph().as_default():
@@ -2985,28 +2938,30 @@ class FunctionalInputLayerTest(test.TestCase):
'sparse_feature': [['a'], ['x']],
}
all_cols = [some_embedding_column]
- fc.input_layer(features, all_cols)
- fc.input_layer(features, all_cols)
+ FeatureLayer(all_cols)(features)
+ FeatureLayer(all_cols)(features)
# Make sure that 2 variables get created in this case.
self.assertEqual(2, len(
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
expected_var_names = [
- 'input_layer/sparse_feature_embedding/embedding_weights:0',
- 'input_layer_1/sparse_feature_embedding/embedding_weights:0'
+ 'feature_layer/sparse_feature_embedding/embedding_weights:0',
+ 'feature_layer_1/sparse_feature_embedding/embedding_weights:0'
]
self.assertItemsEqual(
expected_var_names,
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
def test_multiple_layers_with_same_shared_embedding_column(self):
- categorical_column_a = fc_old.categorical_column_with_identity(
+ categorical_column_a = fc.categorical_column_with_identity(
key='aaa', num_buckets=3)
- categorical_column_b = fc_old.categorical_column_with_identity(
+ categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=3)
embedding_dimension = 2
- embedding_column_b, embedding_column_a = fc_old.shared_embedding_columns(
+ embedding_column_b, embedding_column_a = fc.shared_embedding_columns_v2(
[categorical_column_b, categorical_column_a],
dimension=embedding_dimension)
+ shared_state_manager = fc.SharedEmbeddingStateManager(
+ name='shared_feature_layer')
with ops.Graph().as_default():
features = {
@@ -3022,27 +2977,33 @@ class FunctionalInputLayerTest(test.TestCase):
dense_shape=(2, 2)),
}
all_cols = [embedding_column_a, embedding_column_b]
- fc.input_layer(features, all_cols)
- fc.input_layer(features, all_cols)
+ FeatureLayer(
+ all_cols, shared_state_manager=shared_state_manager)(
+ features)
+ FeatureLayer(
+ all_cols, shared_state_manager=shared_state_manager)(
+ features)
# Make sure that only 1 variable gets created in this case.
self.assertEqual(1, len(
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
self.assertItemsEqual(
- ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
+ ['shared_feature_layer/aaa_bbb_shared_embedding:0'],
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
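[Editor's aside, not part of the diff: the assertion above relies on both FeatureLayer instances being handed the same SharedEmbeddingStateManager, so the second layer finds the existing 'aaa_bbb_shared_embedding' entry instead of creating a new variable. A toy dict-backed version of that dedup:]

class ManagerSketch(object):
    def __init__(self):
        self._vars = {}

    def create_variable(self, name, shape):
        # setdefault makes creation idempotent per collection name.
        return self._vars.setdefault(name, ('var', name, tuple(shape)))

manager = ManagerSketch()
a = manager.create_variable('aaa_bbb_shared_embedding', (3, 2))
b = manager.create_variable('aaa_bbb_shared_embedding', (3, 2))
assert a is b and len(manager._vars) == 1  # one variable across both layers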
def test_multiple_layers_with_same_shared_embedding_column_diff_graphs(self):
- categorical_column_a = fc_old.categorical_column_with_identity(
+ categorical_column_a = fc.categorical_column_with_identity(
key='aaa', num_buckets=3)
- categorical_column_b = fc_old.categorical_column_with_identity(
+ categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=3)
embedding_dimension = 2
- embedding_column_b, embedding_column_a = fc_old.shared_embedding_columns(
+ embedding_column_b, embedding_column_a = fc.shared_embedding_columns_v2(
[categorical_column_b, categorical_column_a],
dimension=embedding_dimension)
all_cols = [embedding_column_a, embedding_column_b]
with ops.Graph().as_default():
+ shared_state_manager1 = fc.SharedEmbeddingStateManager(
+ name='shared_feature_layer')
features = {
'aaa':
sparse_tensor.SparseTensor(
@@ -3055,12 +3016,16 @@ class FunctionalInputLayerTest(test.TestCase):
values=(1, 2, 1),
dense_shape=(2, 2)),
}
- fc.input_layer(features, all_cols)
+ FeatureLayer(
+ all_cols, shared_state_manager=shared_state_manager1)(
+ features)
# Make sure that only 1 variable gets created in this case.
self.assertEqual(1, len(
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
with ops.Graph().as_default():
+ shared_state_manager2 = fc.SharedEmbeddingStateManager(
+ name='shared_feature_layer')
features1 = {
'aaa':
sparse_tensor.SparseTensor(
@@ -3074,12 +3039,14 @@ class FunctionalInputLayerTest(test.TestCase):
dense_shape=(2, 2)),
}
- fc.input_layer(features1, all_cols)
+ FeatureLayer(
+ all_cols, shared_state_manager=shared_state_manager2)(
+ features1)
# Make sure that only 1 variable gets created in this case.
self.assertEqual(1, len(
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
self.assertItemsEqual(
- ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
+ ['shared_feature_layer/aaa_bbb_shared_embedding:0'],
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
def test_with_numpy_input_fn(self):
@@ -3092,14 +3059,14 @@ class FunctionalInputLayerTest(test.TestCase):
del shape, dtype, partition_info
return embedding_values
- # price has 1 dimension in input_layer
- price = fc_old.numeric_column('price')
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ # price has 1 dimension in feature_layer
+ price = fc.numeric_column('price')
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
- # one_hot_body_style has 3 dims in input_layer.
- one_hot_body_style = fc_old.indicator_column(body_style)
- # embedded_body_style has 5 dims in input_layer.
- embedded_body_style = fc_old.embedding_column(
+ # one_hot_body_style has 3 dims in feature_layer.
+ one_hot_body_style = fc.indicator_column(body_style)
+ # embedded_body_style has 5 dims in feature_layer.
+ embedded_body_style = fc.embedding_column(
body_style, dimension=5, initializer=_initializer)
input_fn = numpy_io.numpy_input_fn(
@@ -3110,8 +3077,8 @@ class FunctionalInputLayerTest(test.TestCase):
batch_size=2,
shuffle=False)
features = input_fn()
- net = fc.input_layer(features,
- [price, one_hot_body_style, embedded_body_style])
+ net = FeatureLayer([price, one_hot_body_style, embedded_body_style])(
+ features)
self.assertEqual(1 + 3 + 5, net.shape[1])
with _initialized_session() as sess:
coord = coordinator.Coordinator()
@@ -3137,18 +3104,18 @@ class FunctionalInputLayerTest(test.TestCase):
del shape, dtype, partition_info
return embedding_values
- # price has 1 dimension in input_layer
- price = fc_old.numeric_column('price')
+ # price has 1 dimension in feature_layer
+ price = fc.numeric_column('price')
- # one_hot_body_style has 3 dims in input_layer.
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ # one_hot_body_style has 3 dims in feature_layer.
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
- one_hot_body_style = fc_old.indicator_column(body_style)
+ one_hot_body_style = fc.indicator_column(body_style)
- # embedded_body_style has 5 dims in input_layer.
- country = fc_old.categorical_column_with_vocabulary_list(
+ # embedded_body_style has 5 dims in feature_layer.
+ country = fc.categorical_column_with_vocabulary_list(
'country', vocabulary_list=['US', 'JP', 'CA'])
- embedded_country = fc_old.embedding_column(
+ embedded_country = fc.embedding_column(
country, dimension=5, initializer=_initializer)
# Provides 1-dim tensor and dense tensor.
@@ -3165,8 +3132,7 @@ class FunctionalInputLayerTest(test.TestCase):
self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
self.assertEqual(1, features['country'].shape.ndims)
- net = fc.input_layer(features,
- [price, one_hot_body_style, embedded_country])
+ net = FeatureLayer([price, one_hot_body_style, embedded_country])(features)
self.assertEqual(1 + 3 + 5, net.shape[1])
with _initialized_session() as sess:
@@ -3187,18 +3153,18 @@ class FunctionalInputLayerTest(test.TestCase):
del shape, dtype, partition_info
return embedding_values
- # price has 1 dimension in input_layer
- price = fc_old.numeric_column('price')
+ # price has 1 dimension in feature_layer
+ price = fc.numeric_column('price')
- # one_hot_body_style has 3 dims in input_layer.
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ # one_hot_body_style has 3 dims in feature_layer.
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
- one_hot_body_style = fc_old.indicator_column(body_style)
+ one_hot_body_style = fc.indicator_column(body_style)
- # embedded_body_style has 5 dims in input_layer.
- country = fc_old.categorical_column_with_vocabulary_list(
+ # embedded_body_style has 5 dims in feature_layer.
+ country = fc.categorical_column_with_vocabulary_list(
'country', vocabulary_list=['US', 'JP', 'CA'])
- embedded_country = fc_old.embedding_column(
+ embedded_country = fc.embedding_column(
country, dimension=2, initializer=_initializer)
# Provides 1-dim tensor and dense tensor.
@@ -3219,8 +3185,7 @@ class FunctionalInputLayerTest(test.TestCase):
dense_shape=(2,))
country_data = np.array([['US'], ['CA']])
- net = fc.input_layer(features,
- [price, one_hot_body_style, embedded_country])
+ net = FeatureLayer([price, one_hot_body_style, embedded_country])(features)
self.assertEqual(1 + 3 + 2, net.shape[1])
with _initialized_session() as sess:
@@ -3237,8 +3202,8 @@ class FunctionalInputLayerTest(test.TestCase):
}))
def test_with_rank_0_feature(self):
- # price has 1 dimension in input_layer
- price = fc_old.numeric_column('price')
+ # price has 1 dimension in feature_layer
+ price = fc.numeric_column('price')
features = {
'price': constant_op.constant(0),
}
@@ -3246,13 +3211,13 @@ class FunctionalInputLayerTest(test.TestCase):
# Static rank 0 should fail
with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
- fc.input_layer(features, [price])
+ FeatureLayer([price])(features)
# Dynamic rank 0 should fail
features = {
'price': array_ops.placeholder(dtypes.float32),
}
- net = fc.input_layer(features, [price])
+ net = FeatureLayer([price])(features)
self.assertEqual(1, net.shape[1])
with _initialized_session() as sess:
with self.assertRaisesOpError('Feature .* cannot have rank 0'):
@@ -3267,7 +3232,7 @@ class MakeParseExampleSpecTest(test.TestCase):
@property
def name(self):
- return "_TestFeatureColumn"
+ return '_TestFeatureColumn'
def transform_feature(self, transformation_cache, state_manager):
pass
@@ -3427,7 +3392,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
dense_shape=(2, 2))
column.get_sparse_tensors(FeatureTransformationCache({'aaa': inputs}), None)
with self.assertRaisesRegexp(errors.OpError, 'file_does_not_exist'):
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
def test_invalid_vocabulary_size(self):
@@ -3451,7 +3416,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
dense_shape=(2, 2))
column.get_sparse_tensors(FeatureTransformationCache({'aaa': inputs}), None)
with self.assertRaisesRegexp(errors.OpError, 'Invalid vocab_size'):
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
def test_invalid_num_oov_buckets(self):
@@ -3521,7 +3486,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -3593,25 +3558,6 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
dense_shape=inputs.dense_shape),
id_tensor.eval())
- def DISABLED_test_get_sparse_tensors_weight_collections(self):
- column = fc.categorical_column_with_vocabulary_file(
- key='aaa',
- vocabulary_file=self._wire_vocabulary_file_name,
- vocabulary_size=self._wire_vocabulary_size)
- inputs = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'],
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- column.get_sparse_tensors(
- FeatureTransformationCache({
- 'aaa': inputs
- }),
- weight_collections=('my_weights',))
-
- self.assertItemsEqual(
- [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
- self.assertItemsEqual([], ops.get_collection('my_weights'))
-
def test_get_sparse_tensors_dense_input(self):
column = fc.categorical_column_with_vocabulary_file(
key='aaa',
@@ -3972,7 +3918,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -3994,7 +3940,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -4043,24 +3989,6 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
dense_shape=inputs.dense_shape),
id_tensor.eval())
- def DISABLED_test_get_sparse_tensors_weight_collections(self):
- column = fc.categorical_column_with_vocabulary_list(
- key='aaa',
- vocabulary_list=('omar', 'stringer', 'marlo'))
- inputs = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'],
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- column.get_sparse_tensors(
- FeatureTransformationCache({
- 'aaa': inputs
- }),
- weight_collections=('my_weights',))
-
- self.assertItemsEqual(
- [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
- self.assertItemsEqual([], ops.get_collection('my_weights'))
-
def test_get_sparse_tensors_dense_input(self):
column = fc.categorical_column_with_vocabulary_list(
key='aaa',
@@ -4311,7 +4239,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -4356,22 +4284,6 @@ class IdentityCategoricalColumnTest(test.TestCase):
dense_shape=inputs.dense_shape),
id_tensor.eval())
- def DISABLED_test_get_sparse_tensors_weight_collections(self):
- column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
- inputs = sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 1, 0),
- dense_shape=(2, 2))
- column.get_sparse_tensors(
- FeatureTransformationCache({
- 'aaa': inputs
- }),
- weight_collections=('my_weights',))
-
- self.assertItemsEqual(
- [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
- self.assertItemsEqual([], ops.get_collection('my_weights'))
-
def test_get_sparse_tensors_dense_input(self):
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
id_weight_pair = column.get_sparse_tensors(
@@ -4595,7 +4507,7 @@ class IndicatorColumnTest(test.TestCase):
'animal': ['fox', 'fox']
})
output = transformation_cache.get(animal, None)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[0., 0., 1., 0.], [0., 0., 1., 0.]], output.eval())
def test_2D_shape_succeeds(self):
@@ -4610,7 +4522,7 @@ class IndicatorColumnTest(test.TestCase):
dense_shape=[2, 1])
})
output = transformation_cache.get(animal, None)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[0., 0., 1., 0.], [0., 0., 1., 0.]], output.eval())
def test_multi_hot(self):
@@ -4623,7 +4535,7 @@ class IndicatorColumnTest(test.TestCase):
indices=[[0, 0], [0, 1]], values=[1, 1], dense_shape=[1, 2])
})
output = transformation_cache.get(animal, None)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[0., 2., 0., 0.]], output.eval())
def test_multi_hot2(self):
@@ -4635,7 +4547,7 @@ class IndicatorColumnTest(test.TestCase):
indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
})
output = transformation_cache.get(animal, None)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[0., 1., 1., 0.]], output.eval())
def test_deep_copy(self):
@@ -4660,7 +4572,7 @@ class IndicatorColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a_indicator]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -4765,16 +4677,16 @@ class IndicatorColumnTest(test.TestCase):
weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
self.assertAllClose([[2. + 3.]], predictions.eval())
- def test_input_layer(self):
- animal = fc_old.indicator_column(
- fc_old.categorical_column_with_identity('animal', num_buckets=4))
+ def test_feature_layer(self):
+ animal = fc.indicator_column(
+ fc.categorical_column_with_identity('animal', num_buckets=4))
with ops.Graph().as_default():
features = {
'animal':
sparse_tensor.SparseTensor(
indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
}
- net = fc.input_layer(features, [animal])
+ net = FeatureLayer([animal])(features)
with _initialized_session():
self.assertAllClose([[0., 1., 1., 0.]], net.eval())
@@ -4786,12 +4698,13 @@ class _TestStateManager(StateManager):
self._all_variables = {}
self._trainable = trainable
- def get_variable(self,
- feature_column,
- name,
- shape,
- dtype=None,
- initializer=None):
+ def create_variable(self,
+ feature_column,
+ name,
+ shape,
+ dtype=None,
+ trainable=True,
+ initializer=None):
if feature_column not in self._all_variables:
self._all_variables[feature_column] = {}
var_dict = self._all_variables[feature_column]
@@ -4801,11 +4714,19 @@ class _TestStateManager(StateManager):
var = variable_scope.get_variable(
name=name,
shape=shape,
- initializer=initializer,
- trainable=self._trainable)
+ dtype=dtype,
+ trainable=self._trainable and trainable,
+ initializer=initializer)
var_dict[name] = var
return var
+ def get_variable(self, feature_column, name):
+ if feature_column not in self._all_variables:
+ raise ValueError('Do not recognize FeatureColumn.')
+ if name in self._all_variables[feature_column]:
+ return self._all_variables[feature_column][name]
+ raise ValueError('Could not find variable.')
+
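[Editor's aside, not part of the diff: the rewritten _TestStateManager separates creation from lookup and combines the manager-level and per-variable trainable flags. A compact pure-Python mock of that behavior, mirroring the error messages added above:]

class TestStateManagerSketch(object):
    def __init__(self, trainable=True):
        self._all_variables = {}
        self._trainable = trainable

    def create_variable(self, feature_column, name, value, trainable=True):
        var_dict = self._all_variables.setdefault(feature_column, {})
        # Effective trainability is the AND of manager and variable flags.
        var_dict.setdefault(name, (value, self._trainable and trainable))
        return var_dict[name]

    def get_variable(self, feature_column, name):
        if feature_column not in self._all_variables:
            raise ValueError('Do not recognize FeatureColumn.')
        if name in self._all_variables[feature_column]:
            return self._all_variables[feature_column][name]
        raise ValueError('Could not find variable.')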
class EmbeddingColumnTest(test.TestCase):
@@ -4898,7 +4819,7 @@ class EmbeddingColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a_embedded]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -4967,6 +4888,7 @@ class EmbeddingColumnTest(test.TestCase):
categorical_column, dimension=embedding_dimension,
initializer=_initializer)
state_manager = _TestStateManager()
+ embedding_column.create_state(state_manager)
# Provide sparse input and get dense result.
embedding_lookup = embedding_column.get_dense_tensor(
@@ -5028,6 +4950,7 @@ class EmbeddingColumnTest(test.TestCase):
categorical_column, dimension=embedding_dimension,
initializer=_initializer)
state_manager = _TestStateManager()
+ embedding_column.create_state(state_manager)
# Provide sparse input and get dense result.
embedding_lookup = embedding_column.get_dense_tensor(
@@ -5043,36 +4966,6 @@ class EmbeddingColumnTest(test.TestCase):
self.assertAllEqual(embedding_values, global_vars[0].eval())
self.assertAllEqual(expected_lookups, embedding_lookup.eval())
- def DISABLED_test_get_dense_tensor_weight_collections(self):
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, ids [2]
- # example 1, ids [0, 1]
- # example 2, ids []
- # example 3, ids [1]
- indices=((0, 0), (1, 0), (1, 4), (3, 0)),
- values=(2, 0, 1, 1),
- dense_shape=(4, 5))
-
- # Build columns.
- categorical_column = fc.categorical_column_with_identity(
- key='aaa', num_buckets=3)
- embedding_column = fc.embedding_column(categorical_column, dimension=2)
-
- # Provide sparse input and get dense result.
- embedding_column.get_dense_tensor(
- FeatureTransformationCache({
- 'aaa': sparse_input
- }),
- weight_collections=('my_vars',))
-
- # Assert expected embedding variable and lookups.
- global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(('embedding_weights:0',),
- tuple([v.name for v in global_vars]))
- my_vars = ops.get_collection('my_vars')
- self.assertItemsEqual(
- ('embedding_weights:0',), tuple([v.name for v in my_vars]))
-
def test_get_dense_tensor_placeholder_inputs(self):
# Inputs.
vocabulary_size = 3
@@ -5117,6 +5010,7 @@ class EmbeddingColumnTest(test.TestCase):
categorical_column, dimension=embedding_dimension,
initializer=_initializer)
state_manager = _TestStateManager()
+ embedding_column.create_state(state_manager)
# Provide sparse input and get dense result.
input_indices = array_ops.placeholder(dtype=dtypes.int64)
@@ -5187,6 +5081,7 @@ class EmbeddingColumnTest(test.TestCase):
ckpt_to_load_from=ckpt_path,
tensor_name_in_ckpt=ckpt_tensor)
state_manager = _TestStateManager()
+ embedding_column.create_state(state_manager)
# Provide sparse input and get dense result.
embedding_lookup = embedding_column.get_dense_tensor(
@@ -5354,7 +5249,7 @@ class EmbeddingColumnTest(test.TestCase):
# = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
- def test_input_layer(self):
+ def test_feature_layer(self):
# Inputs.
vocabulary_size = 3
sparse_input = sparse_tensor.SparseTensorValue(
@@ -5392,30 +5287,29 @@ class EmbeddingColumnTest(test.TestCase):
)
# Build columns.
- categorical_column = fc_old.categorical_column_with_identity(
+ categorical_column = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- embedding_column = fc_old.embedding_column(
+ embedding_column = fc.embedding_column(
categorical_column,
dimension=embedding_dimension,
initializer=_initializer)
# Provide sparse input and get dense result.
- input_layer = fc.input_layer({'aaa': sparse_input}, (embedding_column,))
+ l = FeatureLayer((embedding_column,))
+ feature_layer = l({'aaa': sparse_input})
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(
- ('input_layer/aaa_embedding/embedding_weights:0',),
- tuple([v.name for v in global_vars]))
+ self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- self.assertItemsEqual(
- ('input_layer/aaa_embedding/embedding_weights:0',),
- tuple([v.name for v in trainable_vars]))
+ self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',),
+ tuple([v.name for v in trainable_vars]))
with _initialized_session():
self.assertAllEqual(embedding_values, trainable_vars[0].eval())
- self.assertAllEqual(expected_lookups, input_layer.eval())
+ self.assertAllEqual(expected_lookups, feature_layer.eval())
- def test_input_layer_not_trainable(self):
+ def test_feature_layer_not_trainable(self):
# Inputs.
vocabulary_size = 3
sparse_input = sparse_tensor.SparseTensorValue(
@@ -5453,65 +5347,26 @@ class EmbeddingColumnTest(test.TestCase):
)
# Build columns.
- categorical_column = fc_old.categorical_column_with_identity(
+ categorical_column = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- embedding_column = fc_old.embedding_column(
+ embedding_column = fc.embedding_column(
categorical_column,
dimension=embedding_dimension,
initializer=_initializer,
trainable=False)
# Provide sparse input and get dense result.
- input_layer = fc.input_layer({'aaa': sparse_input}, (embedding_column,))
+ feature_layer = FeatureLayer((embedding_column,))({'aaa': sparse_input})
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(
- ('input_layer/aaa_embedding/embedding_weights:0',),
- tuple([v.name for v in global_vars]))
+ self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
self.assertItemsEqual(
[], ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
with _initialized_session():
self.assertAllEqual(embedding_values, global_vars[0].eval())
- self.assertAllEqual(expected_lookups, input_layer.eval())
-
-
-class _TestSharedEmbeddingStateManager(StateManager):
- """Manages the state for shared embedding columns.
-
- This can handle multiple groups of shared embedding columns.
- """
-
- def __init__(self, trainable=True):
- # Dict of shared_embedding_collection_name to a dict of variables.
- self._all_variables = {}
- self._trainable = trainable
-
- def get_variable(self,
- feature_column,
- name,
- shape,
- dtype=None,
- initializer=None):
- if not isinstance(feature_column, fc.SharedEmbeddingColumn):
- raise ValueError(
- 'SharedEmbeddingStateManager can only handle SharedEmbeddingColumns. '
- 'Given type: {} '.format(type(feature_column)))
-
- collection_name = feature_column.shared_collection_name
- if collection_name not in self._all_variables:
- self._all_variables[collection_name] = {}
- var_dict = self._all_variables[collection_name]
- if name in var_dict:
- return var_dict[name]
- else:
- var = variable_scope.get_variable(
- name=name,
- shape=shape,
- initializer=initializer,
- trainable=self._trainable)
- var_dict[name] = var
- return var
+ self.assertAllEqual(expected_lookups, feature_layer.eval())
class SharedEmbeddingColumnTest(test.TestCase):
@@ -5522,7 +5377,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=3)
embedding_dimension = 2
- embedding_column_b, embedding_column_a = fc.shared_embedding_columns(
+ embedding_column_b, embedding_column_a = fc.shared_embedding_columns_v2(
[categorical_column_b, categorical_column_a],
dimension=embedding_dimension)
self.assertIs(categorical_column_a, embedding_column_a.categorical_column)
@@ -5560,7 +5415,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=3)
embedding_dimension = 2
- embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension,
combiner='my_combiner',
@@ -5605,7 +5460,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=3)
embedding_dimension = 2
- original_a, _ = fc.shared_embedding_columns(
+ original_a, _ = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension,
combiner='my_combiner',
@@ -5613,7 +5468,8 @@ class SharedEmbeddingColumnTest(test.TestCase):
shared_embedding_collection_name='shared_embedding_collection_name',
ckpt_to_load_from='my_ckpt',
tensor_name_in_ckpt='my_ckpt_tensor',
- max_norm=42., trainable=False)
+ max_norm=42.,
+ trainable=False)
for embedding_column_a in (original_a, copy.deepcopy(original_a)):
self.assertEqual('aaa', embedding_column_a.categorical_column.name)
self.assertEqual(3, embedding_column_a.categorical_column.num_buckets)
@@ -5642,8 +5498,9 @@ class SharedEmbeddingColumnTest(test.TestCase):
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=3)
with self.assertRaisesRegexp(ValueError, 'initializer must be callable'):
- fc.shared_embedding_columns(
- [categorical_column_a, categorical_column_b], dimension=2,
+ fc.shared_embedding_columns_v2(
+ [categorical_column_a, categorical_column_b],
+ dimension=2,
initializer='not_fn')
def test_incompatible_column_type(self):
@@ -5656,7 +5513,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError, 'all categorical_columns must have the same type.*'
'IdentityCategoricalColumn.*HashedCategoricalColumn'):
- fc.shared_embedding_columns(
+ fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b, categorical_column_c],
dimension=2)
@@ -5669,11 +5526,11 @@ class SharedEmbeddingColumnTest(test.TestCase):
key='bbb', num_buckets=3)
weighted_categorical_column_b = fc.weighted_categorical_column(
categorical_column_b, weight_feature_key='bbb_weights')
- fc.shared_embedding_columns(
+ fc.shared_embedding_columns_v2(
[weighted_categorical_column_a, categorical_column_b], dimension=2)
- fc.shared_embedding_columns(
+ fc.shared_embedding_columns_v2(
[categorical_column_a, weighted_categorical_column_b], dimension=2)
- fc.shared_embedding_columns(
+ fc.shared_embedding_columns_v2(
[weighted_categorical_column_a, weighted_categorical_column_b],
dimension=2)
@@ -5682,8 +5539,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
b = fc.categorical_column_with_vocabulary_list(
key='bbb', vocabulary_list=('omar', 'stringer', 'marlo'))
- a_embedded, b_embedded = fc.shared_embedding_columns(
- [a, b], dimension=2)
+ a_embedded, b_embedded = fc.shared_embedding_columns_v2([a, b], dimension=2)
data = example_pb2.Example(features=feature_pb2.Features(
feature={
'aaa':
@@ -5698,7 +5554,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
features=fc.make_parse_example_spec([a_embedded, b_embedded]))
self.assertIn('aaa', features)
self.assertIn('bbb', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -5717,8 +5573,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
def test_transform_feature(self):
a = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
b = fc.categorical_column_with_identity(key='bbb', num_buckets=3)
- a_embedded, b_embedded = fc.shared_embedding_columns(
- [a, b], dimension=2)
+ a_embedded, b_embedded = fc.shared_embedding_columns_v2([a, b], dimension=2)
features = {
'aaa': sparse_tensor.SparseTensor(
indices=((0, 0), (1, 0), (1, 1)),
@@ -5788,10 +5643,13 @@ class SharedEmbeddingColumnTest(test.TestCase):
key='aaa', num_buckets=vocabulary_size)
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
- dimension=embedding_dimension, initializer=_initializer)
- state_manager = _TestSharedEmbeddingStateManager()
+ dimension=embedding_dimension,
+ initializer=_initializer)
+ state_manager = fc.SharedEmbeddingStateManager(name='shared_feature_layer')
+ embedding_column_a.create_state(state_manager)
+ embedding_column_b.create_state(state_manager)
# Provide sparse input and get dense result.
embedding_lookup_a = embedding_column_a.get_dense_tensor(
@@ -5801,7 +5659,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(('embedding_weights:0',),
+ self.assertItemsEqual(('shared_feature_layer/aaa_bbb_shared_embedding:0',),
tuple([v.name for v in global_vars]))
embedding_var = global_vars[0]
with _initialized_session():
@@ -5809,58 +5667,6 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval())
self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval())
- def DISABLED_test_get_dense_tensor_weight_collections(self):
- # Inputs.
- vocabulary_size = 3
- # -1 values are ignored.
- input_a = np.array([
- [2, -1, -1], # example 0, ids [2]
- [0, 1, -1]
- ]) # example 1, ids [0, 1]
- input_b = np.array([
- [0, -1, -1], # example 0, ids [0]
- [-1, -1, -1]
- ]) # example 1, ids []
- input_features = {'aaa': input_a, 'bbb': input_b}
-
- # Embedding variable.
- embedding_dimension = 2
- embedding_values = (
- (1., 2.), # id 0
- (3., 5.), # id 1
- (7., 11.) # id 2
- )
-
- def _initializer(shape, dtype, partition_info):
- self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
- self.assertEqual(dtypes.float32, dtype)
- self.assertIsNone(partition_info)
- return embedding_values
-
- # Build columns.
- categorical_column_a = fc.categorical_column_with_identity(
- key='aaa', num_buckets=vocabulary_size)
- categorical_column_b = fc.categorical_column_with_identity(
- key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
- [categorical_column_a, categorical_column_b],
- dimension=embedding_dimension,
- initializer=_initializer)
-
- fc.input_layer(
- input_features, [embedding_column_a, embedding_column_b],
- weight_collections=('my_vars',))
-
- # Assert expected embedding variable and lookups.
- global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(
- ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
- tuple(v.name for v in global_vars))
- my_vars = ops.get_collection('my_vars')
- self.assertItemsEqual(
- ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
- tuple(v.name for v in my_vars))
-
def test_get_dense_tensor_placeholder_inputs(self):
# Inputs.
vocabulary_size = 3
@@ -5903,10 +5709,13 @@ class SharedEmbeddingColumnTest(test.TestCase):
key='aaa', num_buckets=vocabulary_size)
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
- dimension=embedding_dimension, initializer=_initializer)
- state_manager = _TestSharedEmbeddingStateManager()
+ dimension=embedding_dimension,
+ initializer=_initializer)
+ state_manager = fc.SharedEmbeddingStateManager()
+ embedding_column_a.create_state(state_manager)
+ embedding_column_b.create_state(state_manager)
# Provide sparse input and get dense result.
embedding_lookup_a = embedding_column_a.get_dense_tensor(
@@ -6096,7 +5905,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
# = [3*1 + 5*2, 3*0 +5*0] = [13, 0]
self.assertAllClose([[94. + 13.], [29.]], predictions.eval())
- def _test_input_layer(self, trainable=True):
+ def _test_feature_layer(self, trainable=True):
# Inputs.
vocabulary_size = 3
sparse_input_a = sparse_tensor.SparseTensorValue(
@@ -6111,6 +5920,18 @@ class SharedEmbeddingColumnTest(test.TestCase):
indices=((0, 0),),
values=(0,),
dense_shape=(2, 5))
+ sparse_input_c = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 1), (1, 1), (1, 3)),
+ values=(2, 0, 1),
+ dense_shape=(2, 5))
+ sparse_input_d = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids []
+ indices=((0, 1),),
+ values=(2,),
+ dense_shape=(2, 5))
# Embedding variable.
embedding_dimension = 2
@@ -6130,51 +5951,127 @@ class SharedEmbeddingColumnTest(test.TestCase):
# example 0:
# A ids [2], embedding = [7, 11]
# B ids [0], embedding = [1, 2]
- (7., 11., 1., 2.),
+ # C ids [2], embedding = [7, 11]
+ # D ids [2], embedding = [7, 11]
+ (7., 11., 1., 2., 7., 11., 7., 11.),
# example 1:
# A ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
# B ids [], embedding = [0, 0]
- (2., 3.5, 0., 0.),
+ # C ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ # D ids [], embedding = [0, 0]
+ (2., 3.5, 0., 0., 2., 3.5, 0., 0.),
)
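[Editor's aside, not part of the diff: the expected_lookups comments above assume the default 'mean' combiner. A quick NumPy check of that arithmetic:]

import numpy as np

embedding = np.array([[1., 2.], [3., 5.], [7., 11.]])  # rows for ids 0, 1, 2
print(embedding[[0, 1]].mean(axis=0))  # ids [0, 1] -> [2.  3.5]
print(np.zeros(2))                     # empty id list -> [0. 0.]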
# Build columns.
- categorical_column_a = fc_old.categorical_column_with_identity(
+ categorical_column_a = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- categorical_column_b = fc_old.categorical_column_with_identity(
+ categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns(
+ categorical_column_c = fc.categorical_column_with_identity(
+ key='ccc', num_buckets=vocabulary_size)
+ categorical_column_d = fc.categorical_column_with_identity(
+ key='ddd', num_buckets=vocabulary_size)
+
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension,
initializer=_initializer,
trainable=trainable)
+ embedding_column_c, embedding_column_d = fc.shared_embedding_columns_v2(
+ [categorical_column_c, categorical_column_d],
+ dimension=embedding_dimension,
+ initializer=_initializer,
+ trainable=trainable)
+ shared_state_manager = fc.SharedEmbeddingStateManager(
+ name='shared_feature_layer')
+
+ features = {
+ 'aaa': sparse_input_a,
+ 'bbb': sparse_input_b,
+ 'ccc': sparse_input_c,
+ 'ddd': sparse_input_d
+ }
# Provide sparse input and get dense result.
- input_layer = fc.input_layer(
- features={'aaa': sparse_input_a, 'bbb': sparse_input_b},
- feature_columns=(embedding_column_b, embedding_column_a))
+ feature_layer = FeatureLayer(
+ feature_columns=(embedding_column_b, embedding_column_a,
+ embedding_column_c, embedding_column_d),
+ shared_state_manager=shared_state_manager)(
+ features)
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(
- ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
- tuple([v.name for v in global_vars]))
+ self.assertItemsEqual([
+ 'shared_feature_layer/aaa_bbb_shared_embedding:0',
+ 'shared_feature_layer/ccc_ddd_shared_embedding:0'
+ ], tuple([v.name for v in global_vars]))
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
if trainable:
- self.assertItemsEqual(
- ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
- tuple([v.name for v in trainable_vars]))
+ self.assertItemsEqual([
+ 'shared_feature_layer/aaa_bbb_shared_embedding:0',
+ 'shared_feature_layer/ccc_ddd_shared_embedding:0'
+ ], tuple([v.name for v in trainable_vars]))
else:
self.assertItemsEqual([], tuple([v.name for v in trainable_vars]))
shared_embedding_vars = global_vars
with _initialized_session():
self.assertAllEqual(embedding_values, shared_embedding_vars[0].eval())
- self.assertAllEqual(expected_lookups, input_layer.eval())
+ self.assertAllEqual(expected_lookups, feature_layer.eval())
+
-  def test_input_layer(self):
-    self._test_input_layer()
-  def test_input_layer_no_trainable(self):
-    self._test_input_layer(trainable=False)
+  def test_feature_layer(self):
+    self._test_feature_layer()
+
+  def test_feature_layer_no_trainable(self):
+    self._test_feature_layer(trainable=False)
+
+class SharedEmbeddingStateManagerTest(test.TestCase):
+ def test_basic(self):
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ fc.shared_embedding_columns_v2(
+ [categorical_column_a, categorical_column_b], dimension=2)
+ shared_state_manager = fc.SharedEmbeddingStateManager(
+ name='shared_feature_layer')
+ var_a = shared_state_manager.create_variable('aaa_bbb_shared_embedding',
+ [5, 10])
+ var_b = shared_state_manager.create_variable('aaa_bbb_shared_embedding',
+ [5, 10])
+ self.assertEqual(var_a, var_b)
+ self.assertEqual('shared_feature_layer/aaa_bbb_shared_embedding:0',
+ var_a.name)
+ self.assertIsInstance(var_a, variables_lib.Variable)
+
+ def test_multiple_sets(self):
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ categorical_column_c = fc.categorical_column_with_identity(
+ key='ccc', num_buckets=3)
+ categorical_column_d = fc.categorical_column_with_identity(
+ key='ddd', num_buckets=3)
+
+ fc.shared_embedding_columns_v2(
+ [categorical_column_a, categorical_column_b], dimension=2)
+ fc.shared_embedding_columns_v2(
+ [categorical_column_c, categorical_column_d], dimension=2)
+ shared_state_manager = fc.SharedEmbeddingStateManager(
+ name='shared_feature_layer')
+ var_a = shared_state_manager.create_variable('aaa_bbb_shared_embedding',
+ [5, 10])
+ var_c = shared_state_manager.create_variable('ccc_ddd_shared_embedding',
+ [5, 10])
+ self.assertIsInstance(var_a, variables_lib.Variable)
+ self.assertIsInstance(var_c, variables_lib.Variable)
+    self.assertNotEqual(var_a, var_c)
+ self.assertEqual('shared_feature_layer/aaa_bbb_shared_embedding:0',
+ var_a.name)
+ self.assertEqual('shared_feature_layer/ccc_ddd_shared_embedding:0',
+ var_c.name)
class WeightedCategoricalColumnTest(test.TestCase):
@@ -6271,7 +6168,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
features=fc.make_parse_example_spec([a_weighted]))
self.assertIn('aaa', features)
self.assertIn('weights', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py
index b3eb57d067..4b2706d4cf 100644
--- a/tensorflow/python/framework/constant_op.py
+++ b/tensorflow/python/framework/constant_op.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""Operations that generate constants.
-See the @{$python/constant_op$constants guide}.
+See the [constants guide](https://tensorflow.org/api_guides/python/constant_op).
"""
# Must be separate from array_ops to avoid a cyclic dependency.
@@ -105,7 +105,8 @@ def convert_to_eager_tensor(value, ctx, dtype=None):
scalar_cache = ctx.scalar_cache()
tensor = scalar_cache.get(cache_key, None)
if tensor is not None:
- return tensor
+ return ops.EagerTensor(
+ value, context=handle, device=device, dtype=dtype, other_value=tensor)
t = ops.EagerTensor(value, context=handle, device=device, dtype=dtype)
scalar_cache[cache_key] = t
return t
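The cache-hit path above now wraps the cached handle in a fresh `EagerTensor` rather than handing back the cached object itself. A minimal sketch of the observable behavior, assuming eager mode and a scalar Python value:

```python
import tensorflow as tf

tf.enable_eager_execution()

# Two conversions of the same scalar share one cached device copy, but
# each call may return a distinct Tensor object wrapping that copy.
a = tf.convert_to_tensor(1.0)
b = tf.convert_to_tensor(1.0)
print(a.numpy() == b.numpy())  # True; the value comes from the shared cache
```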
@@ -145,6 +146,17 @@ def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
[-1. -1. -1.]]
```
+ `tf.constant` differs from `tf.fill` in a few ways:
+
+  * `tf.constant` supports arbitrary constant values, not just tensors
+    filled with a single uniform scalar the way `tf.fill` produces.
+ * `tf.constant` creates a `Const` node in the computation graph with the
+ exact value at graph construction time. On the other hand, `tf.fill`
+ creates an Op in the graph that is expanded at runtime.
+ * Because `tf.constant` only embeds constant values in the graph, it does
+ not support dynamic shapes based on other runtime Tensors, whereas
+ `tf.fill` does.
+
Args:
value: A constant value (or list) of output type `dtype`.
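A minimal sketch contrasting the two ops under the 1.x graph API; the placeholder and session usage here are illustrative:

```python
import tensorflow as tf

# tf.constant embeds the full value in the graph at construction time.
c = tf.constant(7.0, shape=[2, 3])      # a Const node holding six floats

# tf.fill builds a Fill op whose dims can come from a runtime tensor,
# which tf.constant cannot do.
n = tf.placeholder(tf.int32, shape=[])  # number of rows, fed at run time
f = tf.fill([n, 3], 7.0)

with tf.Session() as sess:
    print(sess.run(f, feed_dict={n: 2}))  # [[7. 7. 7.] [7. 7. 7.]]
```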
diff --git a/tensorflow/python/framework/device.py b/tensorflow/python/framework/device.py
index ab06a2babf..06c653097a 100644
--- a/tensorflow/python/framework/device.py
+++ b/tensorflow/python/framework/device.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import copy
+import threading
from tensorflow.python.util.tf_export import tf_export
@@ -229,6 +230,12 @@ class DeviceSpec(object):
"""
return DeviceSpec().parse_from_string(spec)
+ def __eq__(self, other):
+ return self.to_string() == other.to_string()
+
+ def __hash__(self):
+ return hash(self.to_string())
+
def check_valid(spec):
"""Check that a device spec is valid.
@@ -254,6 +261,14 @@ def canonical_name(device):
return device.to_string()
+# Cache from DeviceSpec objects to their corresponding device functions.
+# This cache is maintained for correctness, not performance: it makes it
+# possible to compare the device function stacks belonging to different
+# graphs in a meaningful way.
+_cached_device_functions = {}
+_cache_lock = threading.Lock()
+
+
def merge_device(spec):
"""Returns a device function that merges devices specifications.
@@ -280,11 +295,18 @@ def merge_device(spec):
Raises:
ValueError: if the spec was not valid.
"""
- if not isinstance(spec, DeviceSpec):
- spec = DeviceSpec.from_string(spec or "")
- def _device_function(node_def):
- current_device = DeviceSpec.from_string(node_def.device or "")
- copy_spec = copy.copy(spec)
- copy_spec.merge_from(current_device) # current_device takes precedence.
- return copy_spec
- return _device_function
+ with _cache_lock:
+ if not isinstance(spec, DeviceSpec):
+ spec = DeviceSpec.from_string(spec or "")
+ cached_function = _cached_device_functions.get(spec, None)
+ if cached_function is not None:
+ return cached_function
+
+ def _device_function(node_def):
+ current_device = DeviceSpec.from_string(node_def.device or "")
+ copy_spec = copy.copy(spec)
+ copy_spec.merge_from(current_device) # current_device takes precedence.
+ return copy_spec
+
+ _cached_device_functions[spec] = _device_function
+ return _device_function
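With `__eq__`/`__hash__` defined on `DeviceSpec` and the cache above, equal specs map to one function object. A small sketch of the intended property, using the module path from this diff:

```python
from tensorflow.python.framework.device import merge_device

# Equal device specs now return the *same* cached device function, so
# device-function stacks built in different graphs compare meaningfully.
fn_a = merge_device("/job:worker/device:GPU:0")
fn_b = merge_device("/job:worker/device:GPU:0")
assert fn_a is fn_b
```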
diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py
index 6e844e14b9..bc3c81b2a2 100644
--- a/tensorflow/python/framework/error_interpolation.py
+++ b/tensorflow/python/framework/error_interpolation.py
@@ -15,7 +15,7 @@
"""Function for interpolating formatted errors from the TensorFlow runtime.
Exposes the function `interpolate` to interpolate messages with tags of the form
-^^type:name:format^^.
+{{type name}}.
"""
from __future__ import absolute_import
@@ -26,21 +26,17 @@ import collections
import itertools
import os
import re
-import string
import six
from tensorflow.python.util import tf_stack
-
_NAME_REGEX = r"[A-Za-z0-9.][A-Za-z0-9_.\-/]*?"
-_FORMAT_REGEX = r"[A-Za-z0-9_.\-/${}:]+"
-_TAG_REGEX = r"\^\^({name}):({name}):({fmt})\^\^".format(
- name=_NAME_REGEX, fmt=_FORMAT_REGEX)
+_TAG_REGEX = r"{{{{({name}) ({name})}}}}".format(name=_NAME_REGEX)
_INTERPOLATION_REGEX = r"^(.*?)({tag})".format(tag=_TAG_REGEX)
-_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX)
+_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX, re.DOTALL)
-_ParseTag = collections.namedtuple("_ParseTag", ["type", "name", "format"])
+_ParseTag = collections.namedtuple("_ParseTag", ["type", "name"])
_BAD_FILE_SUBSTRINGS = [
os.path.join("tensorflow", "python"),
@@ -52,16 +48,9 @@ def _parse_message(message):
"""Parses the message.
Splits the message into separators and tags. Tags are named tuples
- representing the string ^^type:name:format^^ and they are separated by
- separators. For example, in
- "123^^node:Foo:${file}^^456^^node:Bar:${line}^^789", there are two tags and
- three separators. The separators are the numeric characters.
-
- Supported tags after node:<node_name>
- file: Replaced with the filename in which the node was defined.
- line: Replaced by the line number at which the node was defined.
- colocations: Replaced by a multi-line message describing the file and
- line numbers at which this node was colocated with other nodes.
+ representing the string {{type name}} and they are separated by
+ separators. For example, in "123{{node Foo}}456{{node Bar}}789", there are
+ two tags and three separators. The separators are the numeric characters.
Args:
message: String to parse
@@ -69,8 +58,8 @@ def _parse_message(message):
Returns:
(list of separator strings, list of _ParseTags).
- For example, if message is "123^^node:Foo:${file}^^456" then this function
- returns (["123", "456"], [_ParseTag("node", "Foo", "${file}")])
+ For example, if message is "123{{node Foo}}456" then this function
+ returns (["123", "456"], [_ParseTag("node", "Foo")])
"""
seps = []
tags = []
@@ -79,7 +68,7 @@ def _parse_message(message):
match = re.match(_INTERPOLATION_PATTERN, message[pos:])
if match:
seps.append(match.group(1))
- tags.append(_ParseTag(match.group(3), match.group(4), match.group(5)))
+ tags.append(_ParseTag(match.group(3), match.group(4)))
pos += match.end()
else:
break
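A self-contained sketch of the new tag grammar; note that the `re.DOTALL` flag added above lets tags be found after newlines in the message:

```python
import re

_NAME_REGEX = r"[A-Za-z0-9.][A-Za-z0-9_.\-/]*?"
_TAG_REGEX = r"{{{{({name}) ({name})}}}}".format(name=_NAME_REGEX)

# "{{node Foo}}" carries two captured fields: the tag type and the node name.
match = re.search(_TAG_REGEX, "123{{node Foo}}456")
print(match.groups())  # ('node', 'Foo')
```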
@@ -111,12 +100,12 @@ def _compute_device_summary_from_list(name, device_assignment_list, prefix=""):
return prefix + message
str_list = []
- str_list.append("%sDevice assignments active during op '%s' creation:"
- % (prefix, name))
+ str_list.append(
+ "%sDevice assignments active during op '%s' creation:" % (prefix, name))
for traceable_obj in device_assignment_list:
- location_summary = "<{file}:{line}>".format(file=traceable_obj.filename,
- line=traceable_obj.lineno)
+ location_summary = "<{file}:{line}>".format(
+ file=traceable_obj.filename, line=traceable_obj.lineno)
subs = {
"prefix": prefix,
"indent": " ",
@@ -160,12 +149,12 @@ def _compute_colocation_summary_from_dict(name, colocation_dict, prefix=""):
return prefix + message
str_list = []
- str_list.append("%sNode-device colocations active during op '%s' creation:"
- % (prefix, name))
+ str_list.append("%sNode-device colocations active during op '%s' creation:" %
+ (prefix, name))
for coloc_name, location in colocation_dict.items():
- location_summary = "<{file}:{line}>".format(file=location.filename,
- line=location.lineno)
+ location_summary = "<{file}:{line}>".format(
+ file=location.filename, line=location.lineno)
subs = {
"prefix": prefix,
"indent": " ",
@@ -180,8 +169,10 @@ def _compute_colocation_summary_from_dict(name, colocation_dict, prefix=""):
def _compute_colocation_summary_from_op(op, prefix=""):
"""Fetch colocation file, line, and nesting and return a summary string."""
- return _compute_colocation_summary_from_dict(
- op.name, op._colocation_dict, prefix) # pylint: disable=protected-access
+ # pylint: disable=protected-access
+ return _compute_colocation_summary_from_dict(op.name, op._colocation_dict,
+ prefix)
+ # pylint: enable=protected-access
def _find_index_of_defining_frame_for_op(op):
@@ -276,7 +267,7 @@ def compute_field_dict(op):
def interpolate(error_message, graph):
"""Interpolates an error message.
- The error message can contain tags of the form ^^type:name:format^^ which will
+  The error message can contain tags of the form {{type name}} which will
be replaced.
Args:
@@ -285,29 +276,29 @@ def interpolate(error_message, graph):
message.
Returns:
- The string with tags of the form ^^type:name:format^^ interpolated.
+ The string with tags of the form {{type name}} interpolated.
"""
seps, tags = _parse_message(error_message)
+ subs = []
+ end_msg = ""
- node_name_to_substitution_dict = {}
- for name in [t.name for t in tags]:
- if name in node_name_to_substitution_dict:
- continue
+ for t in tags:
try:
- op = graph.get_operation_by_name(name)
+ op = graph.get_operation_by_name(t.name)
except KeyError:
op = None
+ msg = "{{%s %s}}" % (t.type, t.name)
if op is not None:
field_dict = compute_field_dict(op)
- else:
- msg = "<NA>"
- field_dict = collections.defaultdict(lambda s=msg: s)
- node_name_to_substitution_dict[name] = field_dict
-
- subs = [
- string.Template(tag.format).safe_substitute(
- node_name_to_substitution_dict[tag.name]) for tag in tags
- ]
+ if t.type == "node":
+ msg = "node %s%s " % (t.name, field_dict["defined_at"])
+ elif t.type == "colocation_node":
+ msg = "node %s%s having device %s " % (t.name, field_dict["defined_at"],
+ field_dict["devices"])
+ end_msg += "\n\n" + field_dict["devs_and_colocs"]
+ subs.append(msg)
+ subs.append(end_msg)
+
return "".join(
itertools.chain(*six.moves.zip_longest(seps, subs, fillvalue="")))
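A usage sketch of the rewritten `interpolate`; the graph and node name are illustrative, and the exact "defined at" suffix depends on where the op was created:

```python
import tensorflow as tf
from tensorflow.python.framework import error_interpolation

graph = tf.Graph()
with graph.as_default():
    tf.constant(1.0, name="Foo")

# {{node Foo}} is replaced with the node name plus its definition site;
# a tag naming an unknown node is left as the literal {{...}} text.
print(error_interpolation.interpolate("Bad thing at {{node Foo}}!", graph))
# -> "Bad thing at node Foo (defined at <your_file>:<line>) !"
```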
diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py
index 0427156b2b..1b77548592 100644
--- a/tensorflow/python/framework/error_interpolation_test.py
+++ b/tensorflow/python/framework/error_interpolation_test.py
@@ -50,9 +50,9 @@ def _modify_op_stack_with_filenames(op, num_user_frames, user_filename,
stack = []
for idx in range(0, num_outer_frames):
stack.append(op._traceback[idx])
- for idx in range(len(stack), len(stack)+num_user_frames):
+ for idx in range(len(stack), len(stack) + num_user_frames):
stack.append(_make_frame_with_filename(op, idx, user_filename % idx))
- for idx in range(len(stack), len(stack)+num_inner_tf_frames):
+ for idx in range(len(stack), len(stack) + num_inner_tf_frames):
stack.append(_make_frame_with_filename(op, idx, tf_filename % idx))
op._traceback = stack
@@ -62,13 +62,11 @@ class ComputeDeviceSummaryFromOpTest(test.TestCase):
def testCorrectFormatWithActiveDeviceAssignments(self):
assignments = []
assignments.append(
- traceable_stack.TraceableObject("/cpu:0",
- filename="hope.py",
- lineno=24))
+ traceable_stack.TraceableObject(
+ "/cpu:0", filename="hope.py", lineno=24))
assignments.append(
- traceable_stack.TraceableObject("/gpu:2",
- filename="please.py",
- lineno=42))
+ traceable_stack.TraceableObject(
+ "/gpu:2", filename="please.py", lineno=42))
summary = error_interpolation._compute_device_summary_from_list(
"nodename", assignments, prefix=" ")
@@ -90,12 +88,10 @@ class ComputeDeviceSummaryFromOpTest(test.TestCase):
class ComputeColocationSummaryFromOpTest(test.TestCase):
def testCorrectFormatWithActiveColocations(self):
- t_obj_1 = traceable_stack.TraceableObject(None,
- filename="test_1.py",
- lineno=27)
- t_obj_2 = traceable_stack.TraceableObject(None,
- filename="test_2.py",
- lineno=38)
+ t_obj_1 = traceable_stack.TraceableObject(
+ None, filename="test_1.py", lineno=27)
+ t_obj_2 = traceable_stack.TraceableObject(
+ None, filename="test_2.py", lineno=38)
colocation_dict = {
"test_node_1": t_obj_1,
"test_node_2": t_obj_2,
@@ -140,10 +136,11 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
def testFindIndexOfDefiningFrameForOp(self):
local_op = constant_op.constant(42).op
user_filename = "hope.py"
- _modify_op_stack_with_filenames(local_op,
- num_user_frames=3,
- user_filename=user_filename,
- num_inner_tf_frames=5)
+ _modify_op_stack_with_filenames(
+ local_op,
+ num_user_frames=3,
+ user_filename=user_filename,
+ num_inner_tf_frames=5)
idx = error_interpolation._find_index_of_defining_frame_for_op(local_op)
# Expected frame is 6th from the end because there are 5 inner frames with
# TF filenames.
@@ -155,46 +152,46 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
# Truncate stack to known length.
local_op._traceback = local_op._traceback[:7]
# Ensure all frames look like TF frames.
- _modify_op_stack_with_filenames(local_op,
- num_user_frames=0,
- user_filename="user_file.py",
- num_inner_tf_frames=7)
+ _modify_op_stack_with_filenames(
+ local_op,
+ num_user_frames=0,
+ user_filename="user_file.py",
+ num_inner_tf_frames=7)
idx = error_interpolation._find_index_of_defining_frame_for_op(local_op)
self.assertEqual(0, idx)
def testNothingToDo(self):
normal_string = "This is just a normal string"
- interpolated_string = error_interpolation.interpolate(normal_string,
- self.graph)
+ interpolated_string = error_interpolation.interpolate(
+ normal_string, self.graph)
self.assertEqual(interpolated_string, normal_string)
- def testOneTag(self):
- one_tag_string = "^^node:Two:${file}^^"
- interpolated_string = error_interpolation.interpolate(one_tag_string,
- self.graph)
- self.assertTrue(interpolated_string.endswith("constant_op.py"),
- "interpolated_string '%s' did not end with constant_op.py"
- % interpolated_string)
-
def testOneTagWithAFakeNameResultsInPlaceholders(self):
- one_tag_string = "^^node:MinusOne:${file}^^"
- interpolated_string = error_interpolation.interpolate(one_tag_string,
- self.graph)
- self.assertEqual("<NA>", interpolated_string)
+ one_tag_string = "{{node MinusOne}}"
+ interpolated_string = error_interpolation.interpolate(
+ one_tag_string, self.graph)
+ self.assertEqual(one_tag_string, interpolated_string)
def testTwoTagsNoSeps(self):
- two_tags_no_seps = "^^node:One:${file}^^^^node:Three:${line}^^"
- interpolated_string = error_interpolation.interpolate(two_tags_no_seps,
- self.graph)
- self.assertRegexpMatches(interpolated_string, "constant_op.py[0-9]+")
+ two_tags_no_seps = "{{node One}}{{node Three}}"
+ interpolated_string = error_interpolation.interpolate(
+ two_tags_no_seps, self.graph)
+ self.assertRegexpMatches(interpolated_string,
+ "constant_op.py:[0-9]+.*constant_op.py:[0-9]+")
def testTwoTagsWithSeps(self):
- two_tags_with_seps = ";;;^^node:Two:${file}^^,,,^^node:Three:${line}^^;;;"
- interpolated_string = error_interpolation.interpolate(two_tags_with_seps,
- self.graph)
- expected_regex = "^;;;.*constant_op.py,,,[0-9]*;;;$"
+ two_tags_with_seps = ";;;{{node Two}},,,{{node Three}};;;"
+ interpolated_string = error_interpolation.interpolate(
+ two_tags_with_seps, self.graph)
+ expected_regex = (
+ r"^;;;.*constant_op.py:[0-9]+\) ,,,.*constant_op.py:[0-9]+\) ;;;$")
self.assertRegexpMatches(interpolated_string, expected_regex)
+ def testNewLine(self):
+ newline = "\n\n{{node One}}"
+ interpolated_string = error_interpolation.interpolate(newline, self.graph)
+ self.assertRegexpMatches(interpolated_string, "constant_op.py:[0-9]+.*")
+
class InterpolateDeviceSummaryTest(test.TestCase):
@@ -214,30 +211,26 @@ class InterpolateDeviceSummaryTest(test.TestCase):
self.graph = self.three.graph
def testNodeZeroHasNoDeviceSummaryInfo(self):
- message = "^^node:zero:${devices}^^"
+ message = "{{colocation_node zero}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("No device assignments were active", result)
def testNodeOneHasExactlyOneInterpolatedDevice(self):
- message = "^^node:one:${devices}^^"
+ message = "{{colocation_node one}}"
result = error_interpolation.interpolate(message, self.graph)
- num_devices = result.count("tf.device")
- self.assertEqual(1, num_devices)
- self.assertIn("tf.device(/cpu)", result)
+ self.assertEqual(2, result.count("tf.device(/cpu)"))
def testNodeTwoHasTwoInterpolatedDevice(self):
- message = "^^node:two:${devices}^^"
+ message = "{{colocation_node two}}"
result = error_interpolation.interpolate(message, self.graph)
- num_devices = result.count("tf.device")
- self.assertEqual(2, num_devices)
- self.assertIn("tf.device(/cpu)", result)
- self.assertIn("tf.device(/cpu:0)", result)
+ self.assertEqual(2, result.count("tf.device(/cpu)"))
+ self.assertEqual(2, result.count("tf.device(/cpu:0)"))
def testNodeThreeHasFancyFunctionDisplayNameForInterpolatedDevice(self):
- message = "^^node:three:${devices}^^"
+ message = "{{colocation_node three}}"
result = error_interpolation.interpolate(message, self.graph)
num_devices = result.count("tf.device")
- self.assertEqual(1, num_devices)
+ self.assertEqual(2, num_devices)
name_re = r"_fancy_device_function<.*error_interpolation_test.py, [0-9]+>"
expected_re = r"with tf.device\(.*%s\)" % name_re
self.assertRegexpMatches(result, expected_re)
@@ -268,27 +261,26 @@ class InterpolateColocationSummaryTest(test.TestCase):
self.graph = node_three.graph
def testNodeThreeHasColocationInterpolation(self):
- message = "^^node:Three_with_one:${colocations}^^"
+ message = "{{colocation_node Three_with_one}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("colocate_with(One)", result)
def testNodeFourHasColocationInterpolationForNodeThreeOnly(self):
- message = "^^node:Four_with_three:${colocations}^^"
+ message = "{{colocation_node Four_with_three}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("colocate_with(Three_with_one)", result)
self.assertNotIn(
"One", result,
- "Node One should not appear in Four_with_three's summary:\n%s"
- % result)
+ "Node One should not appear in Four_with_three's summary:\n%s" % result)
def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self):
- message = "^^node:Five_with_one_with_two:${colocations}^^"
+ message = "{{colocation_node Five_with_one_with_two}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("colocate_with(One)", result)
self.assertIn("colocate_with(Two)", result)
def testColocationInterpolationForNodeLackingColocation(self):
- message = "^^node:One:${colocations}^^"
+ message = "{{colocation_node One}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("No node-device colocations", result)
self.assertNotIn("Two", result)
diff --git a/tensorflow/python/framework/errors_impl.py b/tensorflow/python/framework/errors_impl.py
index 9f973de400..5af71f2cfb 100644
--- a/tensorflow/python/framework/errors_impl.py
+++ b/tensorflow/python/framework/errors_impl.py
@@ -25,6 +25,7 @@ from tensorflow.core.lib.core import error_codes_pb2
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.framework import c_api_util
from tensorflow.python.util import compat
+from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
@@ -47,11 +48,17 @@ class OpError(Exception):
error_code: The `error_codes_pb2.Code` describing the error.
"""
super(OpError, self).__init__()
- self._message = message
self._node_def = node_def
self._op = op
+ self._message = message
self._error_code = error_code
+ def __reduce__(self):
+ # Allow the subclasses to accept less arguments in their __init__.
+ init_argspec = tf_inspect.getargspec(self.__class__.__init__)
+ args = tuple(getattr(self, arg) for arg in init_argspec.args[1:])
+ return self.__class__, args
+
@property
def message(self):
"""The error message that describes the error."""
diff --git a/tensorflow/python/framework/errors_test.py b/tensorflow/python/framework/errors_test.py
index 62f8ab030c..574b126cae 100644
--- a/tensorflow/python/framework/errors_test.py
+++ b/tensorflow/python/framework/errors_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import gc
+import pickle
import warnings
from tensorflow.core.lib.core import error_codes_pb2
@@ -107,6 +108,34 @@ class ErrorsTest(test.TestCase):
gc.collect()
self.assertEqual(0, self._CountReferences(c_api_util.ScopedTFStatus))
+ def testPickleable(self):
+ for error_code in [
+ errors.CANCELLED,
+ errors.UNKNOWN,
+ errors.INVALID_ARGUMENT,
+ errors.DEADLINE_EXCEEDED,
+ errors.NOT_FOUND,
+ errors.ALREADY_EXISTS,
+ errors.PERMISSION_DENIED,
+ errors.UNAUTHENTICATED,
+ errors.RESOURCE_EXHAUSTED,
+ errors.FAILED_PRECONDITION,
+ errors.ABORTED,
+ errors.OUT_OF_RANGE,
+ errors.UNIMPLEMENTED,
+ errors.INTERNAL,
+ errors.UNAVAILABLE,
+ errors.DATA_LOSS,
+ ]:
+ # pylint: disable=protected-access
+ exc = errors_impl._make_specific_exception(None, None, None, error_code)
+ # pylint: enable=protected-access
+ unpickled = pickle.loads(pickle.dumps(exc))
+ self.assertEqual(exc.node_def, unpickled.node_def)
+ self.assertEqual(exc.op, unpickled.op)
+ self.assertEqual(exc.message, unpickled.message)
+ self.assertEqual(exc.error_code, unpickled.error_code)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/framework/file_system_test.py b/tensorflow/python/framework/file_system_test.py
index 5eb59141a2..6901715e5d 100644
--- a/tensorflow/python/framework/file_system_test.py
+++ b/tensorflow/python/framework/file_system_test.py
@@ -37,7 +37,7 @@ class FileSystemTest(test.TestCase):
load_library.load_file_system_library(file_system_library)
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.WholeFileReader("test_reader")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
queue.enqueue_many([["test://foo"]]).run()
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index f47c0d8a5e..a8aef3a009 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -23,7 +23,6 @@ from __future__ import print_function
import collections
import hashlib
-import sys
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
@@ -34,7 +33,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import cond_v2_impl
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import compat
@@ -42,9 +40,6 @@ from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect
-# This is to avoid a circular dependency with cond_v2_impl.
-cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-access
-
class Defun(object):
"""Decorator used to define TensorFlow functions.
@@ -1029,20 +1024,10 @@ def _from_definition(fdef, grad_func=None):
result = _DefinedFunction(func, argnames, input_types, func_name, grad_func,
python_grad_func, out_names)
# pylint: disable=protected-access
- if ops._USE_C_API:
- serialized = fdef.SerializeToString()
- c_func = c_api.TF_FunctionImportFunctionDef(serialized)
- result._c_func = c_api_util.ScopedTFFunction(c_func)
- result._extra_inputs = []
- else:
- result._definition = fdef
- # Captured inputs are added as regular inputs to a function when it's
- # serialized, i.e. any extra inputs from the original function are now
- # included in `result`._args
- result._extra_inputs = []
- result._hash_str = result._create_hash_str(
- result._definition.signature.input_arg,
- result._definition.signature.output_arg, result._definition.node_def)
+ serialized = fdef.SerializeToString()
+ c_func = c_api.TF_FunctionImportFunctionDef(serialized)
+ result._c_func = c_api_util.ScopedTFFunction(c_func)
+ result._extra_inputs = []
# pylint: enable=protected-access
return result
diff --git a/tensorflow/python/framework/function_def_to_graph.py b/tensorflow/python/framework/function_def_to_graph.py
index 1b09506662..a04fa369ae 100644
--- a/tensorflow/python/framework/function_def_to_graph.py
+++ b/tensorflow/python/framework/function_def_to_graph.py
@@ -23,7 +23,7 @@ import sys
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.framework import versions_pb2
-from tensorflow.python.framework import function
+from tensorflow.python.eager import function
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import versions
@@ -34,13 +34,13 @@ cond_v2_impl._function_def_to_graph = sys.modules[__name__] # pylint: disable=p
def function_def_to_graph(fdef, input_shapes=None):
- """Converts a FunctionDef to a function._FuncGraph (sub-class Graph).
+ """Converts a FunctionDef to a function.FuncGraph (sub-class Graph).
- The returned _FuncGraph's `name`, `inputs` and `outputs` fields will be set.
+ The returned FuncGraph's `name`, `inputs` and `outputs` fields will be set.
The input tensors are represented as placeholders.
- Note: `_FuncGraph.inputs` and `_FuncGraph._captured` are not set and may be
- set by the caller.
+ Note: `FuncGraph.inputs` and `FuncGraph.captures` are not set and may be set
+ by the caller.
Args:
fdef: FunctionDef.
@@ -50,9 +50,9 @@ def function_def_to_graph(fdef, input_shapes=None):
placeholder will have unknown shape.
Returns:
- A _FuncGraph.
+ A FuncGraph.
"""
- func_graph = function._FuncGraph(fdef.signature.name, capture_by_value=False) # pylint: disable=protected-access
+ func_graph = function.FuncGraph(fdef.signature.name)
graph_def, nested_to_flat_tensor_name = function_def_to_graph_def(
fdef, input_shapes)
@@ -60,7 +60,7 @@ def function_def_to_graph(fdef, input_shapes=None):
# Add all function nodes to the graph.
importer.import_graph_def(graph_def, name="")
- # Initialize fields specific to _FuncGraph.
+ # Initialize fields specific to FuncGraph.
# inputs
input_tensor_names = [
@@ -144,6 +144,8 @@ def function_def_to_graph_def(fdef, input_shapes=None):
for arg_def in fdef.signature.input_arg:
nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name)
+ control_name = "^" + arg_def.name
+ nested_to_flat_tensor_name[control_name] = control_name
for node_def in fdef.node_def:
op_def = ops.get_default_graph()._get_op_def(node_def.op) # pylint: disable=protected-access
@@ -172,6 +174,8 @@ def function_def_to_graph_def(fdef, input_shapes=None):
flat_name = "{}:{}".format(node_def.name, flattened_index)
nested_to_flat_tensor_name[nested_name] = flat_name
flattened_index += 1
+ control_name = "^" + node_def.name
+ nested_to_flat_tensor_name[control_name] = control_name
# Update inputs of all nodes in graph.
for node_def in graph_def.node:
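Control inputs are spelled `^node` in a GraphDef and carry no output index, so this hunk maps them to themselves while data outputs get flattened names. A toy illustration of the mapping being built:

```python
# Data outputs are flattened; control inputs pass through unchanged.
nested_to_flat = {"foo_1:d:0": "foo_1:0", "^foo_1": "^foo_1"}
node_inputs = ["foo_1:d:0", "^foo_1"]
print([nested_to_flat[name] for name in node_inputs])
# ['foo_1:0', '^foo_1']
```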
diff --git a/tensorflow/python/framework/function_def_to_graph_test.py b/tensorflow/python/framework/function_def_to_graph_test.py
index cd2a16ed5a..e013fb6e4d 100644
--- a/tensorflow/python/framework/function_def_to_graph_test.py
+++ b/tensorflow/python/framework/function_def_to_graph_test.py
@@ -18,9 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
@@ -56,7 +56,7 @@ class FunctionDefToGraphTest(test.TestCase):
fdef = self._build_function_def()
g = function_def_to_graph.function_def_to_graph(fdef)
self.assertEqual(g.name, "_whats_in_a_name")
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
inputs = sess.run(g.inputs, feed_dict={"x:0": 2, "y:0": 3})
self.assertSequenceEqual(inputs, [2.0, 3.0])
outputs = sess.run(g.outputs, feed_dict={"x:0": 2, "y:0": 3})
@@ -154,14 +154,20 @@ class FunctionDefToGraphDefTest(test.TestCase):
self.assertDictEqual(
tensor_name_map, {
"x": "x:0",
+ "^x": "^x",
"y": "y:0",
+ "^y": "^y",
"z": "z:0",
+ "^z": "^z",
"foo_1:d:0": "foo_1:0",
"foo_1:e:0": "foo_1:1",
+ "^foo_1": "^foo_1",
"list_output:a:0": "list_output:0",
"list_output:a:1": "list_output:1",
+ "^list_output": "^list_output",
"foo_2:d:0": "foo_2:0",
"foo_2:e:0": "foo_2:1",
+ "^foo_2": "^foo_2",
})
def testShapes(self):
@@ -184,33 +190,56 @@ class FunctionDefToGraphDefTest(test.TestCase):
x = constant_op.constant(5.0)
y = constant_op.constant(10.0)
- @function.Defun()
+ @function.defun
def fn():
- @function.Defun()
+ @function.defun
def inner_fn():
return x + y
return inner_fn()
- # Instantiate the function in this graph so that
- # `function_def_to_graph` can find it.
- fn()
-
+ @function.defun
def fn2():
return 2 * fn()
- fdef = function._DefinedFunction(fn2, [], []).definition
+ fn2_defun = fn2.get_concrete_function()
+
+ # Call `fn2` to make sure `fn` is correctly instantiated so
+ # `function_def_to_graph` can find it.
+ fn2_defun()
+
+ fdef = fn2_defun._inference_function.definition
func_graph = function_def_to_graph.function_def_to_graph(fdef)
with func_graph.as_default():
x_ph, y_ph = func_graph.inputs
- with self.test_session(graph=func_graph) as sess:
+ with self.session(graph=func_graph) as sess:
self.assertEqual(
sess.run(func_graph.outputs[0], feed_dict={
x_ph: 5.0,
y_ph: 10.0
}), 30.0)
+ def testControlDependencies(self):
+
+ @function.defun
+ def fn(inp):
+ x = constant_op.constant(2.0, name="x")
+ # TODO(b/79881896): Test external control dependency once that's
+ # supported.
+ with ops.control_dependencies([x, inp]):
+ constant_op.constant(3.0, name="y")
+ return 4.0
+
+ inp = constant_op.constant(1.0)
+ fdef = fn.get_concrete_function(inp).function_def
+ func_graph = function_def_to_graph.function_def_to_graph(fdef)
+
+ op = func_graph.get_operation_by_name("y")
+ self.assertEqual(len(op.control_inputs), 2)
+ self.assertEqual(op.control_inputs[0].name, "x")
+ self.assertEqual(op.control_inputs[1].name, "placeholder")
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 1707f929b8..903768a039 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -347,7 +347,7 @@ class FunctionTest(test.TestCase):
do_function_inlining=True,
do_constant_folding=True)))
- with self.test_session(graph=g, config=cfg):
+ with self.session(graph=g, config=cfg):
self.assertAllClose(y.eval(), 6.)
self.assertAllClose(dx.eval(), 2.)
@@ -419,7 +419,7 @@ class FunctionTest(test.TestCase):
with ops.control_dependencies([z]):
return x * 2
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
z = Foo(constant_op.constant(3.0))
self.assertAllEqual(z.eval(), 6.0)
@@ -434,7 +434,7 @@ class FunctionTest(test.TestCase):
# Foo contains a stateful op (Assert).
self.assertEqual([("Assert", "Assert")], Foo.stateful_ops)
g = ops.Graph()
- with g.as_default(), self.test_session():
+ with g.as_default(), self.cached_session():
self.assertAllEqual(Foo(constant_op.constant(3.0)).eval(), 6.0)
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"assertion failed.*-3"):
@@ -448,7 +448,7 @@ class FunctionTest(test.TestCase):
[control_flow_ops.Assert(math_ops.less_equal(x, 10.0), [x])]):
return array_ops.identity(x)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(1.0, MyFn(1.0).eval())
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"assertion"):
@@ -530,7 +530,7 @@ class FunctionTest(test.TestCase):
v = variables.Variable(constant_op.constant(10.0))
z = Foo(v)
- with self.test_session(graph=g):
+ with self.session(graph=g):
variables.global_variables_initializer().run()
self.assertAllEqual(z.eval(), 101.)
@@ -552,7 +552,7 @@ class FunctionTest(test.TestCase):
expected_val = v.value()
actual_val, actual_shape = Foo()
- with self.test_session(graph=g):
+ with self.session(graph=g):
v.initializer.run()
self.assertAllEqual(expected_val.eval(), actual_val.eval())
self.assertAllEqual(expected_shape, actual_shape.eval())
@@ -667,7 +667,7 @@ class FunctionTest(test.TestCase):
with ops.Graph().as_default():
z = CubeXPlusY(3.0, -2.0)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(z.eval(), 25.0)
def testNestedDefinedFunction(self):
@@ -683,7 +683,7 @@ class FunctionTest(test.TestCase):
with ops.Graph().as_default():
z = CubeXPlusY(3.0, -2.0)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(z.eval(), 25.0)
def testUnusedFunction(self):
@@ -732,7 +732,7 @@ class FunctionTest(test.TestCase):
dx1, = gradients_impl.gradients([y1], [x])
# Both should produce the same result and gradient.
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
vals = sess.run([y0, y1, dx0, dx1], {x: np.random.uniform(size=(3, 7))})
self.assertAllClose(vals[0], vals[1])
self.assertAllClose(vals[2], vals[3])
@@ -762,7 +762,7 @@ class FunctionTest(test.TestCase):
z = Bar()
- with self.test_session(graph=g):
+ with self.session(graph=g):
variables.global_variables_initializer().run()
self.assertAllEqual(y.eval(), [[12.0]])
self.assertAllEqual(z.eval(), [[1.0]])
@@ -795,7 +795,7 @@ class FunctionTest(test.TestCase):
y = Foo()
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
self.assertEqual(sess.run(y), 10)
def testCaptureInCond(self):
@@ -810,7 +810,7 @@ class FunctionTest(test.TestCase):
y = Foo(True)
z = Foo(False)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
self.assertEqual(sess.run(y), 1)
self.assertEqual(sess.run(z), 2)
@@ -855,7 +855,7 @@ class FunctionTest(test.TestCase):
y = Foo(x)
z = Bar(x)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
v0, v1 = sess.run([y, z])
self.assertAllEqual(v0, 20.)
self.assertAllEqual(v1, 20.)
@@ -1128,7 +1128,7 @@ class FunctionTest(test.TestCase):
y2 = PartThree(x2)
dx2, = gradients_impl.gradients(ys=[y2], xs=[x2])
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
v0, v1, v2 = sess.run([dx0, dx1, dx2])
self.assertAllEqual(v0, 2.)
@@ -1353,7 +1353,7 @@ class FunctionOverloadTest(test.TestCase):
x = Sinh(constant_op.constant(0.25, dtypes.float32))
y = Sinh(constant_op.constant(0.25, dtypes.float64))
- with self.test_session(graph=g):
+ with self.session(graph=g):
self.assertAllClose(x.eval(), np.sinh(0.25))
self.assertAllClose(y.eval(), np.sinh(0.25))
@@ -1374,7 +1374,7 @@ class FunctionOverloadTest(test.TestCase):
y = F(x)
dx, = gradients_impl.gradients(y, x)
- with self.test_session(graph=g):
+ with self.session(graph=g):
self.assertAllClose(dx.eval(), 0.25)
def testDocString(self):
@@ -1418,7 +1418,7 @@ class FunctionCaptureByValueTest(test.TestCase):
self.assertEqual(0, len(Foo.captured_inputs))
- with self.test_session(graph=g):
+ with self.session(graph=g):
self.assertAllEqual(y.eval(), [[12.0]])
@@ -1701,7 +1701,7 @@ class VariableHoistingTest(test.TestCase):
self.assertEqual("Foo/w", w.op.name)
self.assertEqual("Foo/b", b.op.name)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(variables.global_variables_initializer())
w, b, x, y0, loss, dw, db = sess.run([w, b, x, y0, loss, dw, db])
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py
index 7182c28666..2b4d8e7299 100644
--- a/tensorflow/python/framework/importer_test.py
+++ b/tensorflow/python/framework/importer_test.py
@@ -396,7 +396,7 @@ class ImportGraphDefTest(test.TestCase):
# Run the imported graph.
# TODO(b/76173421): make this work (currently DCHECKS)
- # with self.test_session() as sess:
+ # with self.cached_session() as sess:
# sess.run(imported_init)
# self.assertEqual(sess.run(imported_var), 1.0)
# self.assertEqual(sess.run(imported_assign), 2.0)
@@ -417,7 +417,7 @@ class ImportGraphDefTest(test.TestCase):
imported_r, = importer.import_graph_def(graph_def,
return_elements=[r.name])
self.assertEqual(imported_r.name, "import/" + r.name)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(imported_r), 10)
def testImportWhileLoopInCond(self):
@@ -436,7 +436,7 @@ class ImportGraphDefTest(test.TestCase):
pred = array_ops.placeholder(dtypes.bool)
out = control_flow_ops.cond(pred, ImportFn,
lambda: constant_op.constant(1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(out, {pred: True}), 10)
self.assertEqual(sess.run(out, {pred: False}), 1)
@@ -457,7 +457,7 @@ class ImportGraphDefTest(test.TestCase):
out = control_flow_ops.while_loop(
lambda i: i < 2, ImportFn, [0],
shape_invariants=[tensor_shape.TensorShape(None)])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(out), 10)
def testTypeMismatchInGraphDef(self):
@@ -929,7 +929,7 @@ class ImportGraphDefTest(test.TestCase):
input_map={"a:0": constant_op.constant(5.0)},
name="",
return_elements=["id:0"])
- with self.test_session():
+ with self.cached_session():
self.assertEqual(5.0, t.eval())
def testInvalidInputForReturnOperations(self):
@@ -958,7 +958,7 @@ class ImportGraphDefTest(test.TestCase):
array_ops.stack([c, c], name="pack")
gdef = g.as_graph_def()
- with self.test_session():
+ with self.cached_session():
pack, = importer.import_graph_def(gdef, return_elements=["pack"])
self.assertAllEqual(pack.outputs[0].eval(), [5.0, 5.0])
@@ -1063,7 +1063,7 @@ class ImportGraphDefTest(test.TestCase):
self.assertEqual([10], biases_grad.get_shape())
def testLargeGraph(self):
- with self.test_session():
+ with self.cached_session():
# The default message byte limit is 64M. Ours is 2G with a warning at 512M.
# Adding a 130M entries float32 tensor should exceed the warning, but not
# the hard limit.
@@ -1205,7 +1205,7 @@ class ImportGraphDefTest(test.TestCase):
gdef, return_elements=["p1:0", "p2:0", "f:0", "f:1"], name="")
grad = gradients_impl.gradients([a], [p1, p2])
- with self.test_session(graph=g2) as sess:
+ with self.session(graph=g2) as sess:
feed_dict = {p1: 1, p2: 2}
a_val, b_val, grad_val = sess.run([a, b, grad], feed_dict=feed_dict)
self.assertEqual(a_val, 3.0)
@@ -1225,7 +1225,7 @@ class ImportGraphDefTest(test.TestCase):
# functions created in g2).
grad = gradients_impl.gradients([a], [p1, p2])
- with self.test_session(graph=g3) as sess:
+ with self.session(graph=g3) as sess:
feed_dict = {p1: 1, p2: 2}
a_val, b_val, grad_val = sess.run([a, b, grad], feed_dict=feed_dict)
self.assertEqual(a_val, 3.0)
@@ -1254,7 +1254,7 @@ class ImportGraphDefTest(test.TestCase):
z = TestFunc()
- with self.test_session():
+ with self.cached_session():
z_val = z.eval()
self.assertEqual(z_val, -2.0)
@@ -1284,7 +1284,7 @@ class ImportGraphDefTest(test.TestCase):
z2 = importer.import_graph_def(gdef, return_elements=["z:0"],
input_map=input_map)[0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
z1_val, z2_val = sess.run((z1, z2))
self.assertAllEqual(z1_val, z2_val)
diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py
index 5cf8697210..fc98b91a01 100644
--- a/tensorflow/python/framework/meta_graph_test.py
+++ b/tensorflow/python/framework/meta_graph_test.py
@@ -70,7 +70,7 @@ class SimpleMetaGraphTest(test.TestCase):
input_feed_value = -10 # Arbitrary input value for feed_dict.
orig_graph = ops.Graph()
- with self.test_session(graph=orig_graph) as sess:
+ with self.session(graph=orig_graph) as sess:
# Create a minimal graph with zero variables.
input_tensor = array_ops.placeholder(
dtypes.float32, shape=[], name="input")
@@ -98,7 +98,7 @@ class SimpleMetaGraphTest(test.TestCase):
# Create a clean graph and import the MetaGraphDef nodes.
new_graph = ops.Graph()
- with self.test_session(graph=new_graph) as sess:
+ with self.session(graph=new_graph) as sess:
# Import the previously export meta graph.
meta_graph.import_scoped_meta_graph(filename)
@@ -117,7 +117,7 @@ class SimpleMetaGraphTest(test.TestCase):
self.assertEqual(new_output_value, output_value)
def testStrippedOpListNestedFunctions(self):
- with self.test_session():
+ with self.cached_session():
# Square two levels deep
@function.Defun(dtypes.int32)
def f0(x):
@@ -169,7 +169,7 @@ class SimpleMetaGraphTest(test.TestCase):
# and "Tout" maps to complex64. Since these attr values map to their
# defaults, they must be stripped unless stripping of default attrs is
# disabled.
- with self.test_session():
+ with self.cached_session():
real_num = constant_op.constant(1.0, dtype=dtypes.float32, name="real")
imag_num = constant_op.constant(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
@@ -197,7 +197,7 @@ class SimpleMetaGraphTest(test.TestCase):
# When inputs to the Complex Op are float64 instances, "T" maps to float64
# and "Tout" maps to complex128. Since these attr values don't map to their
# defaults, they must not be stripped.
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
real_num = constant_op.constant(1.0, dtype=dtypes.float64, name="real")
imag_num = constant_op.constant(2.0, dtype=dtypes.float64, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
@@ -212,7 +212,8 @@ class SimpleMetaGraphTest(test.TestCase):
def testDefaultAttrStrippingNestedFunctions(self):
"""Verifies that default attributes are stripped from function node defs."""
- with self.test_session():
+ with self.cached_session():
+
@function.Defun(dtypes.float32, dtypes.float32)
def f0(i, j):
return math_ops.complex(i, j, name="double_nested_complex")
@@ -251,7 +252,7 @@ class SimpleMetaGraphTest(test.TestCase):
meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()
meta_info_def.stripped_op_list.op.add()
- with self.test_session():
+ with self.cached_session():
meta_graph_def = meta_graph.create_meta_graph_def(
meta_info_def=meta_info_def, graph_def=graph_def,
strip_default_attrs=True)
@@ -855,7 +856,7 @@ class MetaGraphWithVariableScopeTest(test.TestCase):
_TestDir("metrics_export"), "meta_graph.pb")
graph = ops.Graph()
- with self.test_session(graph=graph) as sess:
+ with self.session(graph=graph) as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -876,7 +877,7 @@ class MetaGraphWithVariableScopeTest(test.TestCase):
# Verifies that importing a meta_graph with LOCAL_VARIABLES collection
# works correctly.
graph = ops.Graph()
- with self.test_session(graph=graph) as sess:
+ with self.session(graph=graph) as sess:
meta_graph.import_scoped_meta_graph(meta_graph_filename)
initializer = variables.local_variables_initializer()
sess.run(initializer)
@@ -885,7 +886,7 @@ class MetaGraphWithVariableScopeTest(test.TestCase):
# collection is of node_list type works, but cannot build initializer
# with the collection.
graph = ops.Graph()
- with self.test_session(graph=graph) as sess:
+ with self.session(graph=graph) as sess:
meta_graph.import_scoped_meta_graph(
test.test_src_dir_path(
"python/framework/testdata/metrics_export_meta_graph.pb"))
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 5527f52860..75678cbc01 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -20,7 +20,6 @@ from __future__ import print_function
import collections
import copy
-import os
import re
import sys
import threading
@@ -56,6 +55,7 @@ from tensorflow.python.platform import app
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import decorator_utils
+from tensorflow.python.util import deprecation
from tensorflow.python.util import function_utils
from tensorflow.python.util import lock_util
from tensorflow.python.util import tf_contextlib
@@ -67,7 +67,7 @@ from tensorflow.python.util.tf_export import tf_export
# Temporary global switches determining if we should enable the work-in-progress
# calls to the C API. These will be removed once all functionality is supported.
_USE_C_API = True
-_USE_C_SHAPES = os.getenv("TF_C_API_GRAPH_CONSTRUCTION_SHAPES", "1") != "0"
+_USE_C_SHAPES = True
def tensor_id(tensor):
@@ -516,6 +516,11 @@ class Tensor(_TensorLike):
==> TensorShape([Dimension(28), Dimension(28), Dimension(3)])
```
+ NOTE: This shape is not enforced at runtime. Setting incorrect shapes can
+ result in inconsistencies between the statically-known graph and the runtime
+ value of tensors. For runtime validation of the shape, use `tf.ensure_shape`
+ instead.
+
Args:
shape: A `TensorShape` representing the shape of this tensor, a
`TensorShapeProto`, a list, a tuple, or None.
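A quick sketch of the distinction the NOTE draws, under the 1.x graph API:

```python
import tensorflow as tf

x = tf.placeholder(tf.float32)     # no static shape information
y = tf.ensure_shape(x, [None, 3])  # checked when the graph actually runs

with tf.Session() as sess:
    print(sess.run(y, feed_dict={x: [[1., 2., 3.]]}))  # passes the check
    # Feeding [[1., 2.]] instead would raise InvalidArgumentError here,
    # whereas a set_shape annotation alone performs no runtime check.
```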
@@ -753,6 +758,9 @@ class _EagerTensorBase(Tensor):
def __format__(self, format_spec):
return self.numpy().__format__(format_spec)
+ def __reduce__(self):
+ return (convert_to_tensor, (self.numpy(),))
+
def _numpy(self):
raise NotImplementedError()
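A sketch of the round-trip that `__reduce__` provides for eager Tensors:

```python
import pickle

import tensorflow as tf

tf.enable_eager_execution()

t = tf.constant([1, 2, 3])
# __reduce__ serializes the numpy value and rebuilds via convert_to_tensor.
restored = pickle.loads(pickle.dumps(t))
print(restored.numpy())  # [1 2 3]
```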
@@ -795,6 +803,19 @@ class _EagerTensorBase(Tensor):
"""
raise NotImplementedError()
+ def _num_elements(self):
+ """Number of elements of this Tensor.
+
+ Unlike regular Tensors, the number of elements is always known for
+ EagerTensors.
+
+    This is more performant than `tensor.shape.num_elements`.
+
+    Returns:
+      Long - the number of elements in this tensor.
+ """
+ raise NotImplementedError()
+
def _copy_to_device(self, context, device): # pylint: disable=redefined-outer-name
raise NotImplementedError()
@@ -2856,19 +2877,11 @@ class Graph(object):
# TODO(skyewm): fold as much of the above as possible into the C
# implementation
- if self._use_c_api_hack():
- self._scoped_c_graph = c_api_util.ScopedTFGraph()
- # The C API requires all ops to have shape functions. Disable this
- # requirement (many custom ops do not have shape functions, and we don't
- # want to break these existing cases).
- c_api.SetRequireShapeInferenceFns(self._c_graph, False)
- else:
- self._scoped_c_graph = None
-
- # TODO(apassos) remove once the C API is used by default.
- def _use_c_api_hack(self):
- """Temporary hack; can be overridden to force C API usage."""
- return _USE_C_API
+ self._scoped_c_graph = c_api_util.ScopedTFGraph()
+ # The C API requires all ops to have shape functions. Disable this
+ # requirement (many custom ops do not have shape functions, and we don't
+ # want to break these existing cases).
+ c_api.SetRequireShapeInferenceFns(self._c_graph, False)
# Note: this method is private because the API of tf.Graph() is public and
# frozen, and this functionality is still not ready for public visibility.
@@ -3118,7 +3131,7 @@ class Graph(object):
Returns:
bool indicating whether or not 'name' is registered in function library.
"""
- return name in self._functions
+ return compat.as_str(name) in self._functions
def _get_function(self, name):
"""Returns the function definition for 'name'.
@@ -3128,7 +3141,7 @@ class Graph(object):
Returns:
The function def proto.
"""
- return self._functions.get(name, None)
+ return self._functions.get(compat.as_str(name), None)
def _add_function(self, function):
"""Adds a function to the graph.
@@ -3164,7 +3177,7 @@ class Graph(object):
c_api.TF_GraphCopyFunction(self._c_graph, function._c_func.func, gradient)
# pylint: enable=protected-access
- self._functions[name] = function
+ self._functions[compat.as_str(name)] = function
# Need a new-enough consumer to support the functions we add to the graph.
if self._graph_def_versions.min_consumer < 12:
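The `compat.as_str` normalization matters because function names may arrive as either `bytes` or `str`. A tiny illustration:

```python
from tensorflow.python.util import compat

functions = {}
functions[compat.as_str(b"my_fn")] = "fdef"
# bytes and unicode spellings of the same name now hit the same entry.
print(compat.as_str(u"my_fn") in functions)  # True
```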
@@ -5223,6 +5236,7 @@ _default_graph_stack = _DefaultGraphStack()
# pylint: disable=g-doc-return-or-yield,line-too-long
+@tf_export("init_scope")
@tf_contextlib.contextmanager
def init_scope():
"""A context manager that lifts ops out of control-flow scopes and function-building graphs.
@@ -5252,6 +5266,23 @@ def init_scope():
(3) The gradient tape is paused while the scope is active.
+ When eager execution is enabled, code inside an init_scope block runs with
+ eager execution enabled even when defining graph functions via
+ tf.contrib.eager.defun. For example:
+
+ ```python
+ tf.enable_eager_execution()
+
+ @tf.contrib.eager.defun
+ def func():
+      # A defun-decorated function constructs TensorFlow graphs;
+ # it does not execute eagerly.
+ assert not tf.executing_eagerly()
+ with tf.init_scope():
+ # Initialization runs with eager execution enabled
+ assert tf.executing_eagerly()
+ ```
+
Raises:
RuntimeError: if graph state is incompatible with this initialization.
"""
@@ -5333,6 +5364,7 @@ def enable_eager_execution(config=None,
computational graph).
For example:
+
```python
tf.enable_eager_execution()
@@ -5382,11 +5414,12 @@ def enable_eager_execution(config=None,
TensorFlow graph, or if options provided conflict with a previous call
to this function.
"""
- return enable_eager_execution_internal(
- config=config,
- device_policy=device_policy,
- execution_mode=execution_mode,
- server_def=None)
+ if context.default_execution_mode != context.EAGER_MODE:
+ return enable_eager_execution_internal(
+ config=config,
+ device_policy=device_policy,
+ execution_mode=execution_mode,
+ server_def=None)
def enable_eager_execution_internal(config=None,
@@ -5424,15 +5457,15 @@ def enable_eager_execution_internal(config=None,
raise ValueError(
"execution_mode must be one of None, tf.contrib.eager.SYNC, "
"tf.contrib.eager.ASYNC")
- # pylint: disable=protected-access
- if context._default_mode == context.GRAPH_MODE:
+ if context.default_execution_mode == context.GRAPH_MODE:
graph_mode_has_been_used = (
_default_session_stack.stack
or len(get_default_graph().get_operations()) > 0) # pylint: disable=g-explicit-length-test
if graph_mode_has_been_used:
raise ValueError(
"tf.enable_eager_execution must be called at program startup.")
- context._default_mode = context.EAGER_MODE
+ context.default_execution_mode = context.EAGER_MODE
+ # pylint: disable=protected-access
if context._context is None:
context._context = context.Context(
config=config,
@@ -5776,14 +5809,43 @@ class GraphKeys(object):
_STREAMING_MODEL_PORTS = "streaming_model_ports"
@decorator_utils.classproperty
+ @deprecation.deprecated(None, "Use `tf.GraphKeys.GLOBAL_VARIABLES` instead.")
def VARIABLES(cls): # pylint: disable=no-self-argument
- logging.log_first_n(logging.WARN,
- "VARIABLES collection name is deprecated, please use "
- "GLOBAL_VARIABLES instead; VARIABLES will be removed "
- "after 2017-03-02.", 1)
return cls.GLOBAL_VARIABLES
+def dismantle_graph(graph):
+ """Cleans up reference cycles from a `Graph`.
+
+ Helpful for making sure the garbage collector doesn't need to run after a
+ temporary `Graph` is no longer needed.
+
+ Args:
+ graph: A `Graph` object to destroy. Neither it nor any of its ops are usable
+ after this function runs.
+ """
+ # pylint: disable=protected-access
+ # OrderedDict, constructed on Graph creation, makes a simple reference loop
+ # and hides it in an __attribute in some Python versions. We don't need to
+ # throw an error if we can't find it, but if we do find it we can break the
+ # loop to avoid creating work for the garbage collector.
+ graph_operations = graph.get_operations()
+ problematic_cycle = graph._functions.__dict__.get("_OrderedDict__root", None)
+ # pylint: enable=protected-access
+ if problematic_cycle:
+ try:
+ del problematic_cycle[0][:]
+ except TypeError:
+ # This is probably not one of the problematic Python versions. Continue
+ # with the rest of our cleanup.
+ pass
+ # Now clean up Operation<->Graph reference cycles by clearing all of the
+ # attributes for the Graph and its ops.
+ for op in graph_operations:
+ op.__dict__ = {}
+ graph.__dict__ = {}
+
+
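A usage sketch for `dismantle_graph`; the constant is illustrative:

```python
import tensorflow as tf
from tensorflow.python.framework import ops

g = tf.Graph()
with g.as_default():
    tf.constant(1.0, name="tmp")

# Break the Graph<->Operation reference cycles eagerly instead of
# leaving them for the garbage collector.
ops.dismantle_graph(g)  # g and its ops are unusable afterwards
```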
@tf_export("add_to_collection")
def add_to_collection(name, value):
"""Wrapper for `Graph.add_to_collection()` using the default graph.
diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_test.py b/tensorflow/python/framework/ops_enable_eager_test.py
index a02fc24c79..99d06f1c2d 100644
--- a/tensorflow/contrib/eager/python/examples/scan/scan_test.py
+++ b/tensorflow/python/framework/ops_enable_eager_test.py
@@ -12,43 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Unit test for tf.scan under eager execution."""
+"""Tests enabling eager execution at process level."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import time
-
-import numpy as np
-import tensorflow as tf
-
-
-class ScanBenchmark(tf.test.Benchmark):
-
- def runScan(self, n):
- elems = np.arange(n)
- start_time = time.time()
- _ = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1)
- wall_time = time.time() - start_time
-
- self.report_benchmark(
- name='scan',
- iters=n,
- wall_time=wall_time)
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import googletest
- def benchmarkScan16000(self):
- self.runScan(16000)
- def benchmarkScan32000(self):
- self.runScan(32000)
+class OpsEnableEagerTest(googletest.TestCase):
- def benchmarkScan64000(self):
- self.runScan(64000)
+ def test_enable_eager_execution_multiple_times(self):
+ ops.enable_eager_execution()
+ self.assertTrue(context.executing_eagerly())
- def benchmarkScan128000(self):
- self.runScan(128000)
+ # Calling enable eager execution a second time should not cause an error.
+ ops.enable_eager_execution()
+ self.assertTrue(context.executing_eagerly())
if __name__ == '__main__':
- tf.enable_eager_execution()
- tf.test.main()
+ googletest.main()
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 318387c61b..d59adf3d48 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -58,12 +58,12 @@ ops._set_call_cpp_shape_fn(common_shapes.call_cpp_shape_fn)
class ResourceTest(test_util.TensorFlowTestCase):
def testBuildGraph(self):
- with self.test_session():
+ with self.cached_session():
pt = test_ops.stub_resource_handle_op(container="a", shared_name="b")
test_ops.resource_create_op(pt).run()
def testInitialize(self):
- with self.test_session():
+ with self.cached_session():
handle = test_ops.stub_resource_handle_op(container="a", shared_name="b")
resources.register_resource(
handle=handle,
@@ -100,35 +100,35 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
pass
def testAddShape(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.zeros([2, 3])
b = array_ops.ones([1, 3])
c = a + b
self.assertEqual([2, 3], c.shape)
def testUnknownDim(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
c = a + b
self.assertEqual([2, None, 3], c.shape.as_list())
def testUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
b = array_ops.ones([1, 3])
c = a + b
self.assertEqual(tensor_shape.unknown_shape(), c.shape)
def testScalarShape(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.placeholder(dtype=dtypes.float32, shape=[])
b = array_ops.ones([])
c = a + b
self.assertEqual(tensor_shape.scalar(), c.shape)
def testShapeFunctionError(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.ones([1, 2, 3])
b = array_ops.ones([4, 5, 6])
with self.assertRaisesRegexp(
@@ -141,7 +141,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
class IndexedSlicesTest(test_util.TensorFlowTestCase):
def testToTensor(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
indices = constant_op.constant([0, 2])
dense_shape = constant_op.constant([3, 2])
@@ -150,7 +150,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase):
self.assertAllEqual(tensor.eval(), [[2, 3], [0, 0], [5, 7]])
def testNegation(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
indices = constant_op.constant([0, 2])
x = -ops.IndexedSlices(values, indices)
@@ -158,7 +158,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase):
self.assertAllEqual(x.indices.eval(), [0, 2])
def testScalarMul(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
indices = constant_op.constant([0, 2])
x = math_ops.scalar_mul(-2, ops.IndexedSlices(values, indices))
@@ -307,14 +307,14 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertEqual(tensor_shape.unknown_shape(), op.get_shape())
def testConvertToTensorNestedArray(self):
- with self.test_session():
+ with self.cached_session():
values = [[2], [3], [5], [7]]
tensor = ops.convert_to_tensor(values)
self.assertAllEqual((4, 1), tensor.get_shape().as_list())
self.assertAllEqual(values, tensor.eval())
def testShapeTuple(self):
- with self.test_session():
+ with self.cached_session():
c = constant_op.constant(1)
self.assertEqual(c._shape_tuple(), ()) # pylint: disable=protected-access
@@ -328,14 +328,14 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertTrue(isinstance(converted, ops.EagerTensor))
def testConvertToTensorNestedTuple(self):
- with self.test_session():
+ with self.cached_session():
values = ((2,), (3,), (5,), (7,))
tensor = ops.convert_to_tensor(values)
self.assertAllEqual((4, 1), tensor.get_shape().as_list())
self.assertAllEqual(values, ops.convert_to_tensor(values).eval())
def testConvertToTensorNestedTensors(self):
- with self.test_session():
+ with self.cached_session():
values = ((2,), (3,), (5,), (7,))
tensor = ops.convert_to_tensor(
[constant_op.constant(row) for row in values])
@@ -347,25 +347,25 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertAllEqual(values, tensor.eval())
def testConvertToTensorNestedMix(self):
- with self.test_session():
+ with self.cached_session():
values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7]))
tensor = ops.convert_to_tensor(values)
self.assertAllEqual((4, 1), tensor.get_shape().as_list())
self.assertAllEqual(((2,), (3,), (5,), (7,)), tensor.eval())
def testConvertToTensorPreferred(self):
- with self.test_session():
+ with self.cached_session():
values = [2, 3, 5, 7]
tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32)
self.assertEqual(dtypes.float32, tensor.dtype)
- with self.test_session():
+ with self.cached_session():
# Convert empty tensor to anything.
values = []
tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
self.assertEqual(dtypes.int64, tensor.dtype)
- with self.test_session():
+ with self.cached_session():
# The preferred dtype is a type error and will convert to
# float32 instead.
values = [1.23]
@@ -493,7 +493,7 @@ class OperationTest(test_util.TensorFlowTestCase):
y.op._add_control_input(z.op) # pylint: disable=protected-access
y.op._add_control_input(x.op) # pylint: disable=protected-access
x.op._add_control_input(y.op) # pylint: disable=protected-access
- with self.test_session(graph=graph) as sess:
+ with self.session(graph=graph) as sess:
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Graph is invalid, contains a cycle with 2 nodes"):
@@ -941,7 +941,7 @@ class NameStackTest(test_util.TensorFlowTestCase):
self.assertEqual("bar_2", g.unique_name("bar"))
def testNameAndVariableScope(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with sess.graph.name_scope("l0"):
with variable_scope.variable_scope("l1"):
with sess.graph.name_scope("l1") as scope:
@@ -1614,6 +1614,33 @@ class CollectionTest(test_util.TensorFlowTestCase):
# Collections are ordered.
self.assertEqual([90, 100], ops.get_collection("key"))
+ def test_defun(self):
+ with context.eager_mode():
+
+ @eager_function.defun
+ def defun():
+ ops.add_to_collection("int", 1)
+ ops.add_to_collection("tensor", constant_op.constant(2))
+
+ @eager_function.defun
+ def inner_defun():
+ self.assertEqual(ops.get_collection("int"), [1])
+ three = ops.get_collection("tensor")[0] + ops.get_collection("int")[0]
+ ops.add_to_collection("int", 2)
+ self.assertEqual(ops.get_collection("int"), [1, 2])
+ ops.add_to_collection("foo", "bar")
+ self.assertEqual(ops.get_collection("foo"), ["bar"])
+ return three
+
+ self.assertEqual(ops.get_collection("int"), [1])
+ three = inner_defun()
+ self.assertEqual(ops.get_collection("int"), [1, 2])
+ self.assertEqual(ops.get_collection("foo"), ["bar"])
+ return three
+
+ three = defun()
+ self.assertEqual(three.numpy(), 3)
+
ops.NotDifferentiable("FloatOutput")
@@ -2137,7 +2164,7 @@ class InitScopeTest(test_util.TensorFlowTestCase):
g = ops.Graph()
with g.as_default():
- with self.test_session():
+ with self.cached_session():
# First ensure that graphs that are not building functions are
# not escaped.
function_with_variables("foo")
@@ -2389,11 +2416,11 @@ class AttrScopeTest(test_util.TensorFlowTestCase):
return (a, b)
def testNoLabel(self):
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual((None, None), self._get_test_attrs())
def testLabelMap(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a1 = self._get_test_attrs()
with sess.graph._attr_scope({
"_A": attr_value_pb2.AttrValue(s=compat.as_bytes("foo"))
@@ -2427,12 +2454,12 @@ ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape)
class KernelLabelTest(test_util.TensorFlowTestCase):
def testNoLabel(self):
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(b"My label is: default",
test_ops.kernel_label().eval())
def testLabelMap(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_1 = test_ops.kernel_label()
# pylint: disable=protected-access
with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}):
@@ -2459,7 +2486,7 @@ class AsGraphDefTest(test_util.TensorFlowTestCase):
"""Test that the graphdef version is plumbed through to kernels."""
with ops.Graph().as_default() as g:
version = g.graph_def_versions.producer
- with self.test_session(graph=g):
+ with self.session(graph=g):
v = test_ops.graph_def_version().eval()
self.assertEqual(version, v)
@@ -2757,7 +2784,7 @@ class DeprecatedTest(test_util.TensorFlowTestCase):
with ops.Graph().as_default() as g:
test_util.set_producer_version(g, 7)
old = test_ops.old()
- with self.test_session(graph=g):
+ with self.session(graph=g):
old.run()
def _error(self):
@@ -2873,7 +2900,7 @@ class NameScopeTest(test_util.TensorFlowTestCase):
class TracebackTest(test_util.TensorFlowTestCase):
def testTracebackWithStartLines(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = constant_op.constant(2.0)
sess.run(
a,
diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc
index 76d4c2017c..2022fbcbaa 100644
--- a/tensorflow/python/framework/python_op_gen.cc
+++ b/tensorflow/python/framework/python_op_gen.cc
@@ -102,15 +102,6 @@ string TensorPBString(const TensorProto& pb) {
return strings::StrCat("\"\"\"", ProtoShortDebugString(pb), "\"\"\"");
}
-const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
- for (int i = 0; i < api_def.in_arg_size(); ++i) {
- if (api_def.in_arg(i).name() == name) {
- return &api_def.in_arg(i);
- }
- }
- return nullptr;
-}
-
class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
public:
GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
diff --git a/tensorflow/python/framework/python_op_gen_internal.cc b/tensorflow/python/framework/python_op_gen_internal.cc
index 031b4a384e..f6aef5bc50 100644
--- a/tensorflow/python/framework/python_op_gen_internal.cc
+++ b/tensorflow/python/framework/python_op_gen_internal.cc
@@ -15,18 +15,20 @@ limitations under the License.
#include "tensorflow/python/framework/python_op_gen_internal.h"
+#include <float.h>
#include <stdio.h>
+#include <iomanip>
#include <sstream>
#include <unordered_map>
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_def.pb_text.h"
#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb_text.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_gen_lib.h"
-#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
@@ -435,7 +437,12 @@ string AttrValueToPython(const string& type, const AttrValue& value,
if (std::isnan(value.f()) || std::isinf(value.f())) {
return strings::StrCat("float('", value.f(), "')");
} else {
- return strings::StrCat(value.f());
+ // Use locale-independent conversion.
+ static_assert(FLT_DIG < 10, "FLT_DIG is too big");
+ std::ostringstream s;
+ s.imbue(std::locale::classic());
+ s << std::setprecision(FLT_DIG) << value.f();
+ return s.str();
}
} else if (type == "bool") {
return value.b() ? "True" : "False";
@@ -483,15 +490,6 @@ const ApiDef::Attr* FindAttr(StringPiece name, const ApiDef& api_def) {
return nullptr;
}
-const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
- for (int i = 0; i < api_def.in_arg_size(); ++i) {
- if (api_def.in_arg(i).name() == name) {
- return &api_def.in_arg(i);
- }
- }
- return nullptr;
-}
-
GenPythonOp::GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
const string& function_name)
: op_def_(op_def),
diff --git a/tensorflow/python/framework/python_op_gen_main.cc b/tensorflow/python/framework/python_op_gen_main.cc
index 8eb943b960..e20ad5fd33 100644
--- a/tensorflow/python/framework/python_op_gen_main.cc
+++ b/tensorflow/python/framework/python_op_gen_main.cc
@@ -52,7 +52,7 @@ Status ReadOpListFromFile(const string& filename,
if (scanner.One(strings::Scanner::LETTER_DIGIT_DOT)
.Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
.GetResult(nullptr, &op_name)) {
- op_list->emplace_back(op_name.ToString());
+ op_list->emplace_back(op_name);
}
s = input_buffer->ReadLine(&line_contents);
}
diff --git a/tensorflow/python/framework/smart_cond.py b/tensorflow/python/framework/smart_cond.py
index 48a834392b..7ee2b5b347 100644
--- a/tensorflow/python/framework/smart_cond.py
+++ b/tensorflow/python/framework/smart_cond.py
@@ -77,11 +77,9 @@ def smart_constant_value(pred):
pred_value = pred
elif isinstance(pred, ops.Tensor):
pred_value = tensor_util.constant_value(pred)
- # TODO(skyewm): consider folding this into tensor_util.constant_value when
- # _USE_C_API is removed (there may be performance and correctness bugs, so I
- # wanted to limit the change hidden behind _USE_C_API).
+ # TODO(skyewm): consider folding this into tensor_util.constant_value.
# pylint: disable=protected-access
- if pred_value is None and ops._USE_C_API:
+ if pred_value is None:
pred_value = c_api.TF_TryEvaluateConstant_wrapper(pred.graph._c_graph,
pred._as_tf_output())
# pylint: enable=protected-access
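With `_USE_C_API` gone, `smart_constant_value` now always falls back to the C-API constant evaluator when `tensor_util.constant_value` gives up. A hedged sketch of what that buys (graph mode assumed; exactly which ops fold is an implementation detail):

```python
# tensor_util.constant_value does not fold Greater, so this exercises the
# TF_TryEvaluateConstant fallback shown above and should print True.
import tensorflow as tf
from tensorflow.python.framework import smart_cond

pred = tf.constant(2) > tf.constant(1)
print(smart_cond.smart_constant_value(pred))
```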
diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py
index a45581190f..d1bdd9b80a 100644
--- a/tensorflow/python/framework/sparse_tensor.py
+++ b/tensorflow/python/framework/sparse_tensor.py
@@ -112,8 +112,6 @@ class SparseTensor(_TensorLike):
values: A 1-D tensor of any type and shape `[N]`.
dense_shape: A 1-D int64 tensor of shape `[ndims]`.
- Returns:
- A `SparseTensor`.
"""
with ops.name_scope(None, "SparseTensor",
[indices, values, dense_shape]):
@@ -184,10 +182,31 @@ class SparseTensor(_TensorLike):
return self._dense_shape
@property
+ def shape(self):
+ """Get the `TensorShape` representing the shape of the dense tensor.
+
+ Returns:
+ A `TensorShape` object.
+ """
+ return tensor_util.constant_value_as_shape(self._dense_shape)
+
+ @property
def graph(self):
"""The `Graph` that contains the index, value, and dense_shape tensors."""
return self._indices.graph
+ def consumers(self):
+ """Returns a list of `Operation`s that consume this `SparseTensor`.
+
+ Returns:
+ A list of `Operation`s.
+ """
+ values_consumers = set(self._values.consumers())
+ indices_consumers = set(self._indices.consumers())
+ dense_shape_consumers = set(self._dense_shape.consumers())
+ return list(values_consumers.union(indices_consumers,
+ dense_shape_consumers))
+
def __str__(self):
return "SparseTensor(indices=%s, values=%s, dense_shape=%s)" % (
self._indices, self._values, self._dense_shape)
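A short sketch of the two additions (graph mode assumed; `sparse_tensor_to_dense` is just one convenient consumer):

```python
# sp.shape gives a static TensorShape derived from the constant dense_shape;
# sp.consumers() unions the consumers of indices, values and dense_shape.
import tensorflow as tf

sp = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1.0, 3.0],
                     dense_shape=[3, 4])
print(sp.shape)            # (3, 4)
dense = tf.sparse_tensor_to_dense(sp)
print(sp.consumers())      # [<the SparseToDense op>]
```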
diff --git a/tensorflow/python/framework/sparse_tensor_test.py b/tensorflow/python/framework/sparse_tensor_test.py
index c001fed3b0..22423c4f58 100644
--- a/tensorflow/python/framework/sparse_tensor_test.py
+++ b/tensorflow/python/framework/sparse_tensor_test.py
@@ -21,8 +21,10 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import googletest
@@ -43,7 +45,7 @@ class SparseTensorTest(test_util.TensorFlowTestCase):
self.assertEqual(sp.dense_shape.dtype, dtypes.int64)
self.assertEqual(sp.get_shape(), (4, 5))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
value = sp.eval()
self.assertAllEqual(indices, value.indices)
self.assertAllEqual(values, value.values)
@@ -63,18 +65,30 @@ class SparseTensorTest(test_util.TensorFlowTestCase):
sparse_tensor.is_sparse(
sparse_tensor.SparseTensorValue([[0]], [0], [1])))
+ def testConsumers(self):
+ sp = sparse_tensor.SparseTensor([[0, 0], [1, 2]], [1.0, 3.0], [3, 4])
+ w = ops.convert_to_tensor(np.ones([4, 1], np.float32))
+ out = sparse_ops.sparse_tensor_dense_matmul(sp, w)
+ self.assertEqual(len(sp.consumers()), 1)
+ self.assertEqual(sp.consumers()[0], out.op)
+
+ dense = sparse_ops.sparse_tensor_to_dense(sp)
+ self.assertEqual(len(sp.consumers()), 2)
+ self.assertTrue(dense.op in sp.consumers())
+ self.assertTrue(out.op in sp.consumers())
+
class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase):
def test_convert_dense(self):
- with self.test_session():
+ with self.cached_session():
value = [42, 43]
from_value = sparse_tensor.convert_to_tensor_or_sparse_tensor(
value)
self.assertAllEqual(value, from_value.eval())
def test_convert_sparse(self):
- with self.test_session():
+ with self.cached_session():
indices = [[0, 1], [1, 0]]
values = [42, 43]
shape = [2, 2]
diff --git a/tensorflow/python/framework/subscribe.py b/tensorflow/python/framework/subscribe.py
index cee7398974..00759eb611 100644
--- a/tensorflow/python/framework/subscribe.py
+++ b/tensorflow/python/framework/subscribe.py
@@ -137,12 +137,7 @@ def _subscribe_new(tensor, side_effects, control_cache):
# are subscribed at the same time, we remove the control dependency from
# the original op only once and we add the dependencies to all the
# new identities.
- if ops._USE_C_API: # pylint: disable=protected-access
- new_control_inputs = consumer_op.control_inputs
- else:
- # Make a copy so we don't modify the actual control inputs (this is fixed
- # in the C API).
- new_control_inputs = list(consumer_op.control_inputs)
+ new_control_inputs = consumer_op.control_inputs
if tensor.op in new_control_inputs:
new_control_inputs.remove(tensor.op)
new_control_inputs.append(out.op)
diff --git a/tensorflow/python/framework/subscribe_test.py b/tensorflow/python/framework/subscribe_test.py
index d6de45fdc4..1d594e4078 100644
--- a/tensorflow/python/framework/subscribe_test.py
+++ b/tensorflow/python/framework/subscribe_test.py
@@ -65,7 +65,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
self.assertFalse(c0.op in d.op.control_inputs)
self.assertTrue(c.op in d.op.control_inputs)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c_out = sess.run([c])
n_out = sess.run([n])
d_out = sess.run([d])
@@ -144,7 +144,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
b = subscribe.subscribe(b,
lambda t: script_ops.py_func(sub, [t], [t.dtype]))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c_out = sess.run([c])
d_out = sess.run([d])
@@ -204,7 +204,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
self.assertIs(c_sub, c_sub3)
# Expect the three side effect graphs to have been evaluated.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([c_sub])
self.assertIn('graph1', shared)
self.assertIn('graph2', shared)
@@ -227,7 +227,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
v1, lambda t: script_ops.py_func(sub, [t], [t.dtype]))
self.assertTrue(subscribe._is_subscribed_identity(v1_sub))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initialize the variables first.
sess.run([v1.initializer])
sess.run([v2.initializer])
@@ -272,7 +272,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
self.assertIs(tensor_array_sub, tensor_array.handle)
self.assertFalse(subscribe._is_subscribed_identity(tensor_array.handle))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([reader])
self.assertEqual(0, len(shared))
@@ -303,7 +303,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
subscribe.subscribe(sparse_add.op.outputs,
lambda t: script_ops.py_func(sub, [t], [t.dtype]))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([neg])
# All three ops have been processed.
@@ -374,7 +374,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
# Verify that sub(x1) and sub(branch) are not.
self.assertIsNot(context(subscriptions[0]), context(subscriptions[1]))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(cond)
self.assertEqual(3, len(results))
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py
index bd0f691a61..3c2a736fb9 100644
--- a/tensorflow/python/framework/tensor_shape.py
+++ b/tensorflow/python/framework/tensor_shape.py
@@ -498,7 +498,8 @@ class TensorShape(object):
If a tensor is produced by an operation of type `"Foo"`, its shape
may be inferred if there is a registered shape function for
- `"Foo"`. See @{$adding_an_op#shape-functions-in-c$`Shape functions in C++`}
+ `"Foo"`. See [Shape
+ functions](https://tensorflow.org/extend/adding_an_op#shape_functions_in_c)
for details of shape functions and how to register them. Alternatively,
the shape may be set explicitly using `tf.Tensor.set_shape`.
"""
@@ -605,8 +606,8 @@ class TensorShape(object):
slice.
Raises:
- ValueError: If `key` is a slice, and any of its elements are negative, or
- if `self` is completely unknown and the step is set.
+ ValueError: If `key` is a slice and `self` is completely unknown and
+ the step is set.
"""
if self._dims is not None:
if isinstance(key, slice):
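The relaxed `Raises` clause above now only promises a `ValueError` for a step on a fully-unknown shape; a small sketch of that contract:

```python
# Sketch of the documented slicing contract at this revision.
from tensorflow.python.framework import tensor_shape

s = tensor_shape.TensorShape([2, None, 3])
print(s[1:])    # (?, 3)
unknown = tensor_shape.TensorShape(None)
unknown[1:]     # fine: a plain slice of an unknown shape
# unknown[::2]  # would raise ValueError: step set on an unknown shape
```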
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index b14290c203..26170b000d 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -367,7 +367,7 @@ def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False):
A `TensorProto`. Depending on the type, it may contain data in the
"tensor_content" attribute, which is not directly useful to Python programs.
To access the values you should convert the proto back to a numpy ndarray
- with `tensor_util.MakeNdarray(proto)`.
+ with `tf.make_ndarray(proto)`.
If `values` is a `TensorProto`, it is immediately returned; `dtype` and
`shape` are ignored.
diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py
index 395cf43b3f..bdf759f220 100644
--- a/tensorflow/python/framework/tensor_util_test.py
+++ b/tensorflow/python/framework/tensor_util_test.py
@@ -768,7 +768,7 @@ class TensorUtilTest(test.TestCase):
def __array__(self, dtype=None):
return np.asarray(self.array, dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ma = MockArray(np.array([10, 20, 30]))
t = ops.convert_to_tensor(ma)
a = sess.run(t)
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 9be6391b04..d63abd7f01 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -63,6 +63,7 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
@@ -369,6 +370,7 @@ def enable_c_shapes(fn):
fn(*args, **kwargs)
finally:
ops._USE_C_SHAPES = prev_value
+
# pylint: enable=protected-access
return wrapper
@@ -398,6 +400,53 @@ def with_c_shapes(cls):
return cls
+def enable_cond_v2(fn):
+ """Decorator for enabling CondV2 on a test.
+
+ Note this enables using CondV2 after running the test class's setup/teardown
+ methods.
+
+ Args:
+ fn: the function to be wrapped
+
+ Returns:
+ The wrapped function
+ """
+
+ # pylint: disable=protected-access
+ def wrapper(*args, **kwargs):
+ prev_value = control_flow_ops._ENABLE_COND_V2
+ control_flow_ops._ENABLE_COND_V2 = True
+ try:
+ fn(*args, **kwargs)
+ finally:
+ control_flow_ops._ENABLE_COND_V2 = prev_value
+ # pylint: enable=protected-access
+
+ return wrapper
+
+
+def with_cond_v2(cls):
+ """Adds methods that call original methods but with CondV2 enabled.
+
+ Note this enables CondV2 in new methods after running the test class's
+ setup method.
+
+ Args:
+ cls: class to decorate
+
+ Returns:
+ cls with new test methods added
+ """
+ if control_flow_ops._ENABLE_COND_V2:
+ return cls
+
+ for name, value in cls.__dict__.copy().items():
+ if callable(value) and name.startswith("test"):
+ setattr(cls, name + "WithCondV2", enable_cond_v2(value))
+ return cls
+
+
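A hypothetical usage sketch of the decorators just added (the class and test names are made up): `with_cond_v2` clones every `test*` method as `test*WithCondV2`, with the flag flipped for the clone's duration.

```python
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest


@test_util.with_cond_v2
class MyCondTest(test_util.TensorFlowTestCase):

  def testTakesBranch(self):
    pass  # runs once as-is and once more as testTakesBranchWithCondV2


if __name__ == "__main__":
  googletest.main()
```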
def assert_no_new_pyobjects_executing_eagerly(f):
"""Decorator for asserting that no new Python objects persist after a test.
@@ -416,28 +465,31 @@ def assert_no_new_pyobjects_executing_eagerly(f):
f(self, **kwargs)
gc.collect()
previous_count = len(gc.get_objects())
- collection_sizes_before = {
- collection: len(ops.get_collection(collection))
- for collection in ops.get_default_graph().collections}
+ if ops.has_default_graph():
+ collection_sizes_before = {
+ collection: len(ops.get_collection(collection))
+ for collection in ops.get_default_graph().collections
+ }
for _ in range(3):
f(self, **kwargs)
# Note that gc.get_objects misses anything that isn't subject to garbage
# collection (C types). Collections are a common source of leaks, so we
# test for collection sizes explicitly.
- for collection_key in ops.get_default_graph().collections:
- collection = ops.get_collection(collection_key)
- size_before = collection_sizes_before.get(collection_key, 0)
- if len(collection) > size_before:
- raise AssertionError(
- ("Collection %s increased in size from "
- "%d to %d (current items %s).")
- % (collection_key, size_before, len(collection), collection))
- # Make sure our collection checks don't show up as leaked memory by
- # removing references to temporary variables.
- del collection
- del collection_key
- del size_before
- del collection_sizes_before
+ if ops.has_default_graph():
+ for collection_key in ops.get_default_graph().collections:
+ collection = ops.get_collection(collection_key)
+ size_before = collection_sizes_before.get(collection_key, 0)
+ if len(collection) > size_before:
+ raise AssertionError(
+ ("Collection %s increased in size from "
+ "%d to %d (current items %s).") %
+ (collection_key, size_before, len(collection), collection))
+ # Make sure our collection checks don't show up as leaked memory by
+ # removing references to temporary variables.
+ del collection
+ del collection_key
+ del size_before
+ del collection_sizes_before
gc.collect()
# There should be no new Python objects hanging around.
new_count = len(gc.get_objects())
@@ -446,8 +498,8 @@ def assert_no_new_pyobjects_executing_eagerly(f):
# Using plain assert because not all classes using this decorator
# have assertLessEqual
assert new_count <= previous_count, (
- "new_count(%d) is not less than or equal to previous_count(%d)" % (
- new_count, previous_count))
+ "new_count(%d) is not less than or equal to previous_count(%d)" %
+ (new_count, previous_count))
gc.enable()
return decorator
@@ -485,19 +537,20 @@ def assert_no_new_tensors(f):
tensors_before = set(
id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj))
- if context.executing_eagerly():
- f(self, **kwargs)
- ops.reset_default_graph()
- else:
- # Run the test in a new graph so that collections get cleared when it's
- # done, but inherit the graph key so optimizers behave.
- outside_graph_key = ops.get_default_graph()._graph_key
- with ops.Graph().as_default():
- ops.get_default_graph()._graph_key = outside_graph_key
+ outside_executed_eagerly = context.executing_eagerly()
+ # Run the test in a new graph so that collections get cleared when it's
+ # done, but inherit the graph key so optimizers behave.
+ outside_graph_key = ops.get_default_graph()._graph_key
+ with ops.Graph().as_default():
+ ops.get_default_graph()._graph_key = outside_graph_key
+ if outside_executed_eagerly:
+ with context.eager_mode():
+ f(self, **kwargs)
+ else:
f(self, **kwargs)
# Make an effort to clear caches, which would otherwise look like leaked
# Tensors.
- context.get_default_context()._clear_caches() # pylint: disable=protected-access
+ context.context()._clear_caches() # pylint: disable=protected-access
gc.collect()
tensors_after = [
obj for obj in gc.get_objects()
@@ -547,10 +600,12 @@ def assert_no_garbage_created(f):
return "<%s %d>" % (obj.__class__.__name__, id(obj))
logging.error(" Object type: %s", _safe_object_str(obj))
- logging.error(" Referrer types: %s", ", ".join(
- [_safe_object_str(ref) for ref in gc.get_referrers(obj)]))
- logging.error(" Referent types: %s", ", ".join(
- [_safe_object_str(ref) for ref in gc.get_referents(obj)]))
+ logging.error(
+ " Referrer types: %s", ", ".join(
+ [_safe_object_str(ref) for ref in gc.get_referrers(obj)]))
+ logging.error(
+ " Referent types: %s", ", ".join(
+ [_safe_object_str(ref) for ref in gc.get_referents(obj)]))
logging.error(" Object attribute names: %s", dir(obj))
logging.error(" Object __str__:")
logging.error(obj)
@@ -629,9 +684,8 @@ def generate_combinations_with_testcase_name(**kwargs):
for combination in combinations:
assert isinstance(combination, OrderedDict)
name = "".join([
- "_{}_{}".format(
- "".join(filter(str.isalnum, key)),
- "".join(filter(str.isalnum, str(value))))
+ "_{}_{}".format("".join(filter(str.isalnum, key)), "".join(
+ filter(str.isalnum, str(value))))
for key, value in combination.items()
])
named_combinations.append(
@@ -736,15 +790,19 @@ def run_in_graph_and_eager_modes(func=None,
run_eagerly = assert_no_new_tensors(
assert_no_garbage_created(run_eagerly))
- with context.eager_mode():
+ if reset_test:
+ # This decorator runs the wrapped test twice.
+ # Reset the test environment between runs.
+ self.tearDown()
+ self._tempdir = None
+ # Create a new graph for the eagerly executed version of this test for
+ # better isolation.
+ graph_for_eager_test = ops.Graph()
+ with graph_for_eager_test.as_default(), context.eager_mode():
if reset_test:
- # This decorator runs the wrapped test twice.
- # Reset the test environment between runs.
- self.tearDown()
- self._tempdir = None
self.setUp()
-
run_eagerly(self, **kwargs)
+ ops.dismantle_graph(graph_for_eager_test)
return decorated
@@ -811,6 +869,18 @@ def device(use_gpu):
yield
+class ErrorLoggingSession(session.Session):
+ """Wrapper around a Session that logs errors in run().
+ """
+
+ def run(self, *args, **kwargs):
+ try:
+ return super(ErrorLoggingSession, self).run(*args, **kwargs)
+ except Exception as e: # pylint: disable=broad-except
+ logging.error(str(e))
+ raise
+
+
@tf_export("test.TestCase")
class TensorFlowTestCase(googletest.TestCase):
"""Base class for tests that need to test TensorFlow.
@@ -967,21 +1037,60 @@ class TensorFlowTestCase(googletest.TestCase):
# pylint: disable=g-doc-return-or-yield
@contextlib.contextmanager
- def test_session(self,
- graph=None,
- config=None,
- use_gpu=False,
- force_gpu=False):
+ def session(self, graph=None, config=None, use_gpu=False, force_gpu=False):
"""Returns a TensorFlow Session for use in executing tests.
- This method should be used for all functional tests.
+ Note that this will set this session and the graph as global defaults.
- This method behaves different than session.Session: for performance reasons
- `test_session` will by default (if `graph` is None) reuse the same session
- across tests. This means you may want to either call the function
- `reset_default_graph()` before tests, or if creating an explicit new graph,
- pass it here (simply setting it with `as_default()` won't do it), which will
- trigger the creation of a new session.
+ Use the `use_gpu` and `force_gpu` options to control where ops are run. If
+ `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
+ `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
+ possible. If both `force_gpu` and `use_gpu` are False, all ops are pinned to
+ the CPU.
+
+ Example:
+ ```python
+ class MyOperatorTest(test_util.TensorFlowTestCase):
+ def testMyOperator(self):
+ with self.session(use_gpu=True):
+ valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
+ result = MyOperator(valid_input).eval()
+ self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0])
+ invalid_input = [-1.0, 2.0, 7.0]
+ with self.assertRaisesOpError("negative input not supported"):
+ MyOperator(invalid_input).eval()
+ ```
+
+ Args:
+ graph: Optional graph to use during the returned session.
+ config: An optional config_pb2.ConfigProto to use to configure the
+ session.
+ use_gpu: If True, attempt to run as many ops as possible on GPU.
+ force_gpu: If True, pin all ops to `/device:GPU:0`.
+
+ Yields:
+ A Session object that should be used as a context manager to surround
+ the graph building and execution code in a test case.
+ """
+ if context.executing_eagerly():
+ yield None
+ else:
+ with self._create_session(graph, config, force_gpu) as sess:
+ with self._constrain_devices_and_set_default(sess, use_gpu, force_gpu):
+ yield sess
+
+ @contextlib.contextmanager
+ def cached_session(self,
+ graph=None,
+ config=None,
+ use_gpu=False,
+ force_gpu=False):
+ """Returns a TensorFlow Session for use in executing tests.
+
+ This method behaves differently from `self.session()`: for performance reasons
+ `cached_session` will by default reuse the same session within the same
+ test. The session returned by this function will only be closed at the end
+ of the test (in the tearDown method).
Use the `use_gpu` and `force_gpu` options to control where ops are run. If
`force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
@@ -993,7 +1102,7 @@ class TensorFlowTestCase(googletest.TestCase):
```python
class MyOperatorTest(test_util.TensorFlowTestCase):
def testMyOperator(self):
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True) as sess:
valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
result = MyOperator(valid_input).eval()
self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0])
@@ -1009,74 +1118,41 @@ class TensorFlowTestCase(googletest.TestCase):
use_gpu: If True, attempt to run as many ops as possible on GPU.
force_gpu: If True, pin all ops to `/device:GPU:0`.
- Returns:
+ Yields:
A Session object that should be used as a context manager to surround
the graph building and execution code in a test case.
"""
+ if context.executing_eagerly():
+ yield None
+ else:
+ sess = self._get_cached_session(
+ graph, config, force_gpu, crash_if_inconsistent_args=True)
+ with self._constrain_devices_and_set_default(sess, use_gpu,
+ force_gpu) as cached:
+ yield cached
+
+ @contextlib.contextmanager
+ def test_session(self,
+ graph=None,
+ config=None,
+ use_gpu=False,
+ force_gpu=False):
+ """Use cached_session instead."""
if self.id().endswith(".test_session"):
self.skipTest("Not a test.")
- def prepare_config(config):
- """Returns a config for sessions.
-
- Args:
- config: An optional config_pb2.ConfigProto to use to configure the
- session.
- Returns:
- A config_pb2.ConfigProto object.
- """
- if config is None:
- config = config_pb2.ConfigProto()
- config.allow_soft_placement = not force_gpu
- config.gpu_options.per_process_gpu_memory_fraction = 0.3
- elif force_gpu and config.allow_soft_placement:
- config = config_pb2.ConfigProto().CopyFrom(config)
- config.allow_soft_placement = False
- # Don't perform optimizations for tests so we don't inadvertently run
- # gpu ops on cpu
- config.graph_options.optimizer_options.opt_level = -1
- config.graph_options.rewrite_options.constant_folding = (
- rewriter_config_pb2.RewriterConfig.OFF)
- config.graph_options.rewrite_options.arithmetic_optimization = (
- rewriter_config_pb2.RewriterConfig.OFF)
- return config
-
if context.executing_eagerly():
yield None
- elif graph is None:
- if self._cached_session is None:
- self._cached_session = session.Session(
- graph=None, config=prepare_config(config))
- sess = self._cached_session
- with sess.graph.as_default(), sess.as_default():
- if force_gpu:
- # Use the name of an actual device if one is detected, or '/device:GPU:0'
- # otherwise
- gpu_name = gpu_device_name()
- if not gpu_name:
- gpu_name = "/device:GPU:0"
- with sess.graph.device(gpu_name):
- yield sess
- elif use_gpu:
- yield sess
- else:
- with sess.graph.device("/cpu:0"):
- yield sess
else:
- with session.Session(graph=graph, config=prepare_config(config)) as sess:
- if force_gpu:
- # Use the name of an actual device if one is detected, or '/device:GPU:0'
- # otherwise
- gpu_name = gpu_device_name()
- if not gpu_name:
- gpu_name = "/device:GPU:0"
- with sess.graph.device(gpu_name):
- yield sess
- elif use_gpu:
+ if graph is None:
+ sess = self._get_cached_session(
+ graph, config, force_gpu, crash_if_inconsistent_args=False)
+ with self._constrain_devices_and_set_default(sess, use_gpu,
+ force_gpu) as cached:
+ yield cached
+ else:
+ with self.session(graph, config, use_gpu, force_gpu) as sess:
yield sess
- else:
- with sess.graph.device("/cpu:0"):
- yield sess
# pylint: enable=g-doc-return-or-yield
@@ -1202,9 +1278,10 @@ class TensorFlowTestCase(googletest.TestCase):
msg: An optional string message to append to the failure message.
"""
# f1 == f2 is needed here as we might have: f1, f2 = inf, inf
- self.assertTrue(f1 == f2 or math.fabs(f1 - f2) <= err,
- "%f != %f +/- %f%s" % (f1, f2, err, " (%s)" % msg
- if msg is not None else ""))
+ self.assertTrue(
+ f1 == f2 or math.fabs(f1 - f2) <= err,
+ "%f != %f +/- %f%s" % (f1, f2, err, " (%s)" % msg
+ if msg is not None else ""))
def assertArrayNear(self, farray1, farray2, err, msg=None):
"""Asserts that two float arrays are near each other.
@@ -1250,8 +1327,17 @@ class TensorFlowTestCase(googletest.TestCase):
def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
a = self._GetNdArray(a)
b = self._GetNdArray(b)
- self.assertEqual(a.shape, b.shape, "Shape mismatch: expected %s, got %s." %
- (a.shape, b.shape))
+ # Print the array contents only when its rank is small; NumPy array
+ # printing is implemented using inefficient recursion, so printing a large
+ # array can cause tests to time out.
+ if a.shape != b.shape and (b.ndim <= 3 or b.size < 500):
+ shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents "
+ "%s.") % (a.shape, b.shape, b)
+ else:
+ shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape,
+ b.shape)
+ self.assertEqual(a.shape, b.shape, shape_mismatch_msg)
+
if not np.allclose(a, b, rtol=rtol, atol=atol):
# Prints more details than np.testing.assert_allclose.
#
@@ -1453,8 +1539,9 @@ class TensorFlowTestCase(googletest.TestCase):
msg = msg if msg else ""
a = self._GetNdArray(a)
b = self._GetNdArray(b)
- self.assertEqual(a.shape, b.shape, "Shape mismatch: expected %s, got %s."
- " %s" % (a.shape, b.shape, msg))
+ self.assertEqual(
+ a.shape, b.shape, "Shape mismatch: expected %s, got %s."
+ " %s" % (a.shape, b.shape, msg))
same = (a == b)
if (a.dtype in [
@@ -1682,8 +1769,8 @@ class TensorFlowTestCase(googletest.TestCase):
self.fail(exception_type.__name__ + " not raised")
except Exception as e: # pylint: disable=broad-except
if not isinstance(e, exception_type) or not predicate(e):
- raise AssertionError("Exception of type %s: %s" % (str(type(e)),
- str(e)))
+ raise AssertionError(
+ "Exception of type %s: %s" % (str(type(e)), str(e)))
# pylint: enable=g-doc-return-or-yield
@@ -1719,8 +1806,9 @@ class TensorFlowTestCase(googletest.TestCase):
"""
device1 = pydev.canonical_name(device1)
device2 = pydev.canonical_name(device2)
- self.assertEqual(device1, device2, "Devices %s and %s are not equal. %s" %
- (device1, device2, msg))
+ self.assertEqual(
+ device1, device2,
+ "Devices %s and %s are not equal. %s" % (device1, device2, msg))
# Fix Python 3 compatibility issues
if six.PY3:
@@ -1734,6 +1822,91 @@ class TensorFlowTestCase(googletest.TestCase):
# pylint: enable=invalid-name
+ @contextlib.contextmanager
+ def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
+ """Set the session and its graph to global default and constrain devices."""
+ if context.executing_eagerly():
+ yield None
+ else:
+ with sess.graph.as_default(), sess.as_default():
+ if force_gpu:
+ # Use the name of an actual device if one is detected, or
+ # '/device:GPU:0' otherwise
+ gpu_name = gpu_device_name()
+ if not gpu_name:
+ gpu_name = "/device:GPU:0"
+ with sess.graph.device(gpu_name):
+ yield sess
+ elif use_gpu:
+ yield sess
+ else:
+ with sess.graph.device("/cpu:0"):
+ yield sess
+
+ def _create_session(self, graph, config, force_gpu):
+ """See session() for details."""
+ def prepare_config(config):
+ """Returns a config for sessions.
+
+ Args:
+ config: An optional config_pb2.ConfigProto to use to configure the
+ session.
+
+ Returns:
+ A config_pb2.ConfigProto object.
+ """
+ if config is None:
+ config = config_pb2.ConfigProto()
+ config.allow_soft_placement = not force_gpu
+ config.gpu_options.per_process_gpu_memory_fraction = 0.3
+ elif force_gpu and config.allow_soft_placement:
+ # Message.CopyFrom() returns None, so copy into a fresh proto first.
+ copied_config = config_pb2.ConfigProto()
+ copied_config.CopyFrom(config)
+ config = copied_config
+ config.allow_soft_placement = False
+ # Don't perform optimizations for tests so we don't inadvertently run
+ # gpu ops on cpu
+ config.graph_options.optimizer_options.opt_level = -1
+ config.graph_options.rewrite_options.constant_folding = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ config.graph_options.rewrite_options.arithmetic_optimization = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ return config
+
+ return ErrorLoggingSession(graph=graph, config=prepare_config(config))
+
+ def _get_cached_session(self,
+ graph=None,
+ config=None,
+ force_gpu=False,
+ crash_if_inconsistent_args=True):
+ """See cached_session() for documentation."""
+ if self._cached_session is None:
+ sess = self._create_session(
+ graph=graph, config=config, force_gpu=force_gpu)
+ self._cached_session = sess
+ self._cached_graph = graph
+ self._cached_config = config
+ self._cached_force_gpu = force_gpu
+ return sess
+ else:
+ if crash_if_inconsistent_args and self._cached_graph is not graph:
+ raise ValueError("The graph used to get the cached session is "
+ "different than the one that was used to create the "
+ "session. Maybe create a new session with "
+ "self.session()")
+ if crash_if_inconsistent_args and self._cached_config is not config:
+ raise ValueError("The config used to get the cached session is "
+ "different than the one that was used to create the "
+ "session. Maybe create a new session with "
+ "self.session()")
+ if crash_if_inconsistent_args and (self._cached_force_gpu is
+ not force_gpu):
+ raise ValueError(
+ "The force_gpu value used to get the cached session is "
+ "different than the one that was used to create the "
+ "session. Maybe create a new session with "
+ "self.session()")
+ return self._cached_session
+
@tf_export("test.create_local_cluster")
def create_local_cluster(num_workers,
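Taken together, the helpers above give `session()` and `cached_session()` distinct contracts. A hedged, self-contained sketch (class and method names invented; the patch adds a similar regression test in test_util_test.py below):

```python
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest


class SessionHelpersTest(test_util.TensorFlowTestCase):

  def testContract(self):
    with self.cached_session() as sess1:
      pass
    with self.cached_session() as sess2:
      self.assertIs(sess2, sess1)      # cached: one session per test
    with self.session() as sess3:
      self.assertIsNot(sess3, sess1)   # always fresh; closed on exit


if __name__ == "__main__":
  googletest.main()
```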
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index f983cbef04..c4f8fa9108 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -22,6 +22,7 @@ import collections
import copy
import random
import threading
+import weakref
import numpy as np
@@ -40,6 +41,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -57,6 +59,30 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertRaises(ValueError, test_util.assert_ops_in_graph,
{"hello": "Variable"}, ops.get_default_graph())
+ def test_session_functions(self):
+ with self.test_session() as sess:
+ sess_ref = weakref.ref(sess)
+ with self.cached_session(graph=None, config=None) as sess2:
+ # We make sure that sess2 is sess.
+ assert sess2 is sess
+ # We make sure we raise an exception if we use cached_session with
+ # different values.
+ with self.assertRaises(ValueError):
+ with self.cached_session(graph=ops.Graph()) as sess2:
+ pass
+ with self.assertRaises(ValueError):
+ with self.cached_session(force_gpu=True) as sess2:
+ pass
+ # We make sure that test_session will cache the session even after the
+ # with scope.
+ assert not sess_ref()._closed
+ with self.session() as unique_sess:
+ unique_sess_ref = weakref.ref(unique_sess)
+ with self.session() as sess2:
+ assert sess2 is not unique_sess
+ # We make sure the session is closed when we leave the with statement.
+ assert unique_sess_ref()._closed
+
def test_assert_equal_graph_def(self):
with ops.Graph().as_default() as g:
def_empty = g.as_graph_def()
@@ -92,6 +118,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
else:
print("MKL is disabled")
+ @test_util.run_in_graph_and_eager_modes
def testAssertProtoEqualsStr(self):
graph_str = "node { name: 'w1' op: 'params' }"
@@ -104,6 +131,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
# test original comparison
self.assertProtoEquals(graph_def, graph_def)
+ @test_util.run_in_graph_and_eager_modes
def testAssertProtoEqualsAny(self):
# Test assertProtoEquals with a protobuf.Any field.
meta_graph_def_str = """
@@ -132,6 +160,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
r'meta_graph_version: "inner"'):
self.assertProtoEquals("", meta_graph_def_outer)
+ @test_util.run_in_graph_and_eager_modes
def testNDArrayNear(self):
a1 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
a2 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
@@ -139,6 +168,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertTrue(self._NDArrayNear(a1, a2, 1e-5))
self.assertFalse(self._NDArrayNear(a1, a3, 1e-5))
+ @test_util.run_in_graph_and_eager_modes
def testCheckedThreadSucceeds(self):
def noop(ev):
@@ -152,6 +182,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
t.join()
self.assertTrue(event_arg.is_set())
+ @test_util.run_in_graph_and_eager_modes
def testCheckedThreadFails(self):
def err_func():
@@ -163,6 +194,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
t.join()
self.assertTrue("integer division or modulo by zero" in str(fe.exception))
+ @test_util.run_in_graph_and_eager_modes
def testCheckedThreadWithWrongAssertionFails(self):
x = 37
@@ -175,6 +207,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
t.join()
self.assertTrue("False is not true" in str(fe.exception))
+ @test_util.run_in_graph_and_eager_modes
def testMultipleThreadsWithOneFailure(self):
def err_func(i):
@@ -203,6 +236,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
original_op=op_orig)
raise errors.UnauthenticatedError(node_def, op, "true_err")
+ @test_util.run_in_graph_and_eager_modes
def testAssertRaisesOpErrorDoesNotPassMessageDueToLeakedStack(self):
with self.assertRaises(AssertionError):
self._WeMustGoDeeper("this_is_not_the_error_you_are_looking_for")
@@ -211,6 +245,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self._WeMustGoDeeper("name")
self._WeMustGoDeeper("orig")
+ @test_util.run_in_graph_and_eager_modes
def testAllCloseTensors(self):
a_raw_data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
a = constant_op.constant(a_raw_data)
@@ -226,17 +261,20 @@ class TestUtilTest(test_util.TensorFlowTestCase):
y_list = [a_raw_data, b]
self.assertAllClose(x_list, y_list)
+ @test_util.run_in_graph_and_eager_modes
def testAllCloseScalars(self):
self.assertAllClose(7, 7 + 1e-8)
with self.assertRaisesRegexp(AssertionError, r"Not equal to tolerance"):
self.assertAllClose(7, 7 + 1e-5)
+ @test_util.run_in_graph_and_eager_modes
def testAllCloseDictToNonDict(self):
with self.assertRaisesRegexp(ValueError, r"Can't compare dict to non-dict"):
self.assertAllClose(1, {"a": 1})
with self.assertRaisesRegexp(ValueError, r"Can't compare dict to non-dict"):
self.assertAllClose({"a": 1}, 1)
+ @test_util.run_in_graph_and_eager_modes
def testAllCloseNamedtuples(self):
a = 7
b = (2., 3.)
@@ -249,6 +287,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertAllClose(
my_named_tuple(a=a, b=b, c=c), my_named_tuple(a=a, b=b, c=c))
+ @test_util.run_in_graph_and_eager_modes
def testAllCloseDicts(self):
a = 7
b = (2., 3.)
@@ -276,6 +315,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(AssertionError, r"Not equal to tolerance"):
self.assertAllClose(expected, {"a": a, "b": b, "c": c_copy})
+ @test_util.run_in_graph_and_eager_modes
def testAllCloseListOfNamedtuples(self):
my_named_tuple = collections.namedtuple("MyNamedTuple", ["x", "y"])
l1 = [
@@ -288,6 +328,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
]
self.assertAllClose(l1, l2)
+ @test_util.run_in_graph_and_eager_modes
def testAllCloseNestedStructure(self):
a = {"x": np.ones((3, 2, 4)) * 7, "y": (2, [{"nested": {"m": 3, "n": 4}}])}
self.assertAllClose(a, a)
@@ -301,6 +342,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
r"\[y\]\[1\]\[0\]\[nested\]\[n\]"):
self.assertAllClose(a, b)
+ @test_util.run_in_graph_and_eager_modes
def testArrayNear(self):
a = [1, 2]
b = [1, 2, 5]
@@ -323,6 +365,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
y = [15]
control_flow_ops.Assert(x, y).run()
+ @test_util.run_in_graph_and_eager_modes
def testAssertAllCloseAccordingToType(self):
# test plain int
self.assertAllCloseAccordingToType(1, 1, rtol=1e-8, atol=1e-8)
@@ -399,6 +442,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
half_rtol=1e-4, half_atol=1e-4
)
+ @test_util.run_in_graph_and_eager_modes
def testAssertAllEqual(self):
i = variables.Variable([100] * 3, dtype=dtypes.int32, name="i")
j = constant_op.constant([20] * 3, dtype=dtypes.int32, name="j")
@@ -408,6 +452,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertAllEqual([120] * 3, k)
self.assertAllEqual([20] * 3, j)
+ @test_util.run_in_graph_and_eager_modes
def testAssertNotAllClose(self):
# Test with arrays
self.assertNotAllClose([0.1], [0.2])
@@ -424,6 +469,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
with self.assertRaises(AssertionError):
self.assertNotAllClose([1.0, 1.0], x)
+ @test_util.run_in_graph_and_eager_modes
def testAssertNotAllCloseRTol(self):
# Test with arrays
with self.assertRaises(AssertionError):
@@ -438,6 +484,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
with self.assertRaises(AssertionError):
self.assertNotAllClose([0.9, 1.0], x, rtol=0.2)
+ @test_util.run_in_graph_and_eager_modes
def testAssertNotAllCloseATol(self):
# Test with arrays
with self.assertRaises(AssertionError):
@@ -452,6 +499,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
with self.assertRaises(AssertionError):
self.assertNotAllClose([0.9, 1.0], x, atol=0.2)
+ @test_util.run_in_graph_and_eager_modes
def testAssertAllGreaterLess(self):
x = constant_op.constant([100.0, 110.0, 120.0], dtype=dtypes.float32)
y = constant_op.constant([10.0] * 3, dtype=dtypes.float32)
@@ -472,6 +520,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
with self.assertRaises(AssertionError):
self.assertAllLess(x, 95.0)
+ @test_util.run_in_graph_and_eager_modes
def testAssertAllGreaterLessEqual(self):
x = constant_op.constant([100.0, 110.0, 120.0], dtype=dtypes.float32)
y = constant_op.constant([10.0] * 3, dtype=dtypes.float32)
@@ -504,6 +553,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
with self.assertRaises(AssertionError):
self.assertAllInRange(b, 0, 1)
+ @test_util.run_in_graph_and_eager_modes
def testAssertAllInRange(self):
x = constant_op.constant([10.0, 15.0], name="x")
self.assertAllInRange(x, 10, 15)
@@ -516,24 +566,28 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertAllInRange(
x, 10, 15, open_lower_bound=True, open_upper_bound=True)
+ @test_util.run_in_graph_and_eager_modes
def testAssertAllInRangeErrorMessageEllipses(self):
x_init = np.array([[10.0, 15.0]] * 12)
x = constant_op.constant(x_init, name="x")
with self.assertRaises(AssertionError):
self.assertAllInRange(x, 5, 10)
+ @test_util.run_in_graph_and_eager_modes
def testAssertAllInRangeDetectsNaNs(self):
x = constant_op.constant(
[[np.nan, 0.0], [np.nan, np.inf], [np.inf, np.nan]], name="x")
with self.assertRaises(AssertionError):
self.assertAllInRange(x, 0.0, 2.0)
+ @test_util.run_in_graph_and_eager_modes
def testAssertAllInRangeWithInfinities(self):
x = constant_op.constant([10.0, np.inf], name="x")
self.assertAllInRange(x, 10, np.inf)
with self.assertRaises(AssertionError):
self.assertAllInRange(x, 10, np.inf, open_upper_bound=True)
+ @test_util.run_in_graph_and_eager_modes
def testAssertAllInSet(self):
b = constant_op.constant([True, False], name="b")
x = constant_op.constant([13, 37], name="x")
@@ -666,6 +720,22 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertEqual(modes[2:], ["setup_eager", "run_eager"])
+# A separate test case to reproduce variable-sharing issues that only pop up
+# when setUp() is overridden and super() is not called.
+class GraphAndEagerNoVariableSharing(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ pass # Intentionally does not call TensorFlowTestCase's super()
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_no_variable_sharing(self):
+ variable_scope.get_variable(
+ name="step_size",
+ initializer=np.array(1e-5, np.float32),
+ use_resource=True,
+ trainable=False)
+
+
class GarbageCollectionTest(test_util.TensorFlowTestCase):
def test_no_reference_cycle_decorator(self):
diff --git a/tensorflow/python/grappler/cost_analyzer.h b/tensorflow/python/grappler/cost_analyzer.h
index b5364aa37a..d15858c1ee 100644
--- a/tensorflow/python/grappler/cost_analyzer.h
+++ b/tensorflow/python/grappler/cost_analyzer.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ANALYZER_H_
-#define TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ANALYZER_H_
+#ifndef TENSORFLOW_PYTHON_GRAPPLER_COST_ANALYZER_H_
+#define TENSORFLOW_PYTHON_GRAPPLER_COST_ANALYZER_H_
#include <iostream>
#include "tensorflow/core/framework/cost_graph.pb.h"
@@ -80,4 +80,4 @@ class CostAnalyzer {
} // end namespace grappler
} // end namespace tensorflow
-#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ANALYZER_H_
+#endif // TENSORFLOW_PYTHON_GRAPPLER_COST_ANALYZER_H_
diff --git a/tensorflow/core/lib/gtl/optional.cc b/tensorflow/python/grappler/graph_analyzer.i
index 8dea073788..cc7b5358eb 100644
--- a/tensorflow/core/lib/gtl/optional.cc
+++ b/tensorflow/python/grappler/graph_analyzer.i
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,13 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/lib/gtl/optional.h"
+%{
+#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h"
+%}
-namespace tensorflow {
-namespace gtl {
+%{
+void GraphAnalyzer(const string& file_path, int n) {
+ tensorflow::grappler::graph_analyzer::GraphAnalyzerTool(file_path, n);
+}
+%}
-nullopt_t::init_t nullopt_t::init;
-extern const nullopt_t nullopt{nullopt_t::init};
-
-} // namespace gtl
-} // namespace tensorflow
+void GraphAnalyzer(const string& file_path, int n);
diff --git a/tensorflow/python/grappler/graph_analyzer.py b/tensorflow/python/grappler/graph_analyzer.py
new file mode 100644
index 0000000000..ec5544e38e
--- /dev/null
+++ b/tensorflow/python/grappler/graph_analyzer.py
@@ -0,0 +1,46 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""A tool that finds all subgraphs of a given size in a TF graph.
+
+The subgraph patterns are sorted by occurrence, and only the transitive fanin
+part of the graph with regard to the fetch nodes is considered.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+from tensorflow.python import pywrap_tensorflow as tf_wrap
+from tensorflow.python.platform import app
+
+
+def main(_):
+ tf_wrap.GraphAnalyzer(FLAGS.input, FLAGS.n)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--input",
+ type=str,
+ default=None,
+ help="Input file path for a TensorFlow MetaGraphDef.")
+ parser.add_argument(
+ "--n", type=int, default=None, help="The size of the subgraphs.")
+ FLAGS, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
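As a usage sketch for the tool added above (the MetaGraphDef path and subgraph size are illustrative, not part of the change), it can be driven from the command line or through the same SWIG-wrapped symbol the script calls:

    # Hypothetical command-line invocation:
    #   python graph_analyzer.py --input /tmp/model.metagraph --n 3
    # Programmatic equivalent via the wrapped entry point:
    from tensorflow.python import pywrap_tensorflow as tf_wrap

    tf_wrap.GraphAnalyzer('/tmp/model.metagraph', 3)  # prints patterns sorted by occurrence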
diff --git a/tensorflow/python/grappler/model_analyzer.h b/tensorflow/python/grappler/model_analyzer.h
index 97ffafabe1..9764a75b29 100644
--- a/tensorflow/python/grappler/model_analyzer.h
+++ b/tensorflow/python/grappler/model_analyzer.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_MODEL_ANALYZER_H_
-#define TENSORFLOW_CORE_GRAPPLER_COSTS_MODEL_ANALYZER_H_
+#ifndef TENSORFLOW_PYTHON_GRAPPLER_MODEL_ANALYZER_H_
+#define TENSORFLOW_PYTHON_GRAPPLER_MODEL_ANALYZER_H_
#include <iostream>
#include "tensorflow/core/framework/node_def.pb.h"
@@ -43,4 +43,4 @@ class ModelAnalyzer {
} // end namespace grappler
} // end namespace tensorflow
-#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_MODEL_ANALYZER_H_
+#endif // TENSORFLOW_PYTHON_GRAPPLER_MODEL_ANALYZER_H_
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index fa1ec51aa7..290e182a79 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -102,7 +102,6 @@ py_library(
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
- "//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
)
@@ -140,6 +139,7 @@ py_library(
":backend",
"//tensorflow/python/data",
"//tensorflow/python/training/checkpointable:data_structures",
+ "//tensorflow/tools/docs:doc_controls",
"@six_archive//:six",
],
)
@@ -388,7 +388,7 @@ py_test(
py_test(
name = "embeddings_test",
- size = "small",
+ size = "medium",
srcs = ["layers/embeddings_test.py"],
srcs_version = "PY2AND3",
deps = [
@@ -688,7 +688,7 @@ py_test(
py_test(
name = "training_test",
- size = "large",
+ size = "enormous",
srcs = ["engine/training_test.py"],
srcs_version = "PY2AND3",
tags = ["notsan"],
@@ -700,6 +700,20 @@ py_test(
)
py_test(
+ name = "feature_columns_integration_test",
+ size = "small",
+ srcs = ["engine/feature_columns_integration_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["notsan"],
+ deps = [
+ ":keras",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/feature_column:feature_column_py",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "training_eager_test",
size = "medium",
srcs = ["engine/training_eager_test.py"],
diff --git a/tensorflow/python/keras/activations_test.py b/tensorflow/python/keras/activations_test.py
index 5cff1f8f9c..dd0bbcff39 100644
--- a/tensorflow/python/keras/activations_test.py
+++ b/tensorflow/python/keras/activations_test.py
@@ -45,7 +45,7 @@ class KerasActivationsTest(test.TestCase):
assert fn == ref_fn
def test_softmax(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.softmax(x)])
test_values = np.random.random((2, 5))
@@ -59,7 +59,7 @@ class KerasActivationsTest(test.TestCase):
keras.activations.softmax(x)
def test_temporal_softmax(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.placeholder(shape=(2, 2, 3))
f = keras.backend.function([x], [keras.activations.softmax(x)])
test_values = np.random.random((2, 2, 3)) * 10
@@ -73,7 +73,7 @@ class KerasActivationsTest(test.TestCase):
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
- with self.test_session():
+ with self.cached_session():
positive_values = np.array([[1, 2]], dtype=keras.backend.floatx())
result = f([positive_values])[0]
self.assertAllClose(result, positive_values * scale, rtol=1e-05)
@@ -87,7 +87,7 @@ class KerasActivationsTest(test.TestCase):
def softplus(x):
return np.log(np.ones_like(x) + np.exp(x))
- with self.test_session():
+ with self.cached_session():
x = keras.backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.softplus(x)])
test_values = np.random.random((2, 5))
@@ -99,7 +99,7 @@ class KerasActivationsTest(test.TestCase):
def softsign(x):
return np.divide(x, np.ones_like(x) + np.absolute(x))
- with self.test_session():
+ with self.cached_session():
x = keras.backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.softsign(x)])
test_values = np.random.random((2, 5))
@@ -116,7 +116,7 @@ class KerasActivationsTest(test.TestCase):
return z / (1 + z)
sigmoid = np.vectorize(ref_sigmoid)
- with self.test_session():
+ with self.cached_session():
x = keras.backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.sigmoid(x)])
test_values = np.random.random((2, 5))
@@ -130,7 +130,7 @@ class KerasActivationsTest(test.TestCase):
z = 0.0 if x <= 0 else (1.0 if x >= 1 else x)
return z
hard_sigmoid = np.vectorize(ref_hard_sigmoid)
- with self.test_session():
+ with self.cached_session():
x = keras.backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.hard_sigmoid(x)])
test_values = np.random.random((2, 5))
@@ -139,7 +139,7 @@ class KerasActivationsTest(test.TestCase):
self.assertAllClose(result, expected, rtol=1e-05)
def test_relu(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.relu(x)])
test_values = np.random.random((2, 5))
@@ -148,7 +148,7 @@ class KerasActivationsTest(test.TestCase):
self.assertAllClose(result, test_values, rtol=1e-05)
def test_elu(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.elu(x, 0.5)])
test_values = np.random.random((2, 5))
@@ -160,7 +160,7 @@ class KerasActivationsTest(test.TestCase):
self.assertAllClose(result, true_result)
def test_tanh(self):
- with self.test_session():
+ with self.cached_session():
test_values = np.random.random((2, 5))
x = keras.backend.placeholder(ndim=2)
exp = keras.activations.tanh(x)
diff --git a/tensorflow/python/keras/applications/__init__.py b/tensorflow/python/keras/applications/__init__.py
index cd9462d6b5..a8b6d55e41 100644
--- a/tensorflow/python/keras/applications/__init__.py
+++ b/tensorflow/python/keras/applications/__init__.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""Keras Applications are canned architectures with pre-trained weights."""
# pylint: disable=g-import-not-at-top
+# pylint: disable=g-bad-import-order
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -25,13 +26,49 @@ from tensorflow.python.keras import engine
from tensorflow.python.keras import layers
from tensorflow.python.keras import models
from tensorflow.python.keras import utils
+from tensorflow.python.util import tf_inspect
+
+# `get_submodules_from_kwargs` was introduced in 1.0.5, but we would
+# like to be able to handle prior versions. Note that prior to 1.0.5,
+# `keras_applications` did not expose a `__version__` attribute.
+if not hasattr(keras_applications, 'get_submodules_from_kwargs'):
+
+ if 'engine' in tf_inspect.getfullargspec(
+ keras_applications.set_keras_submodules)[0]:
+ keras_applications.set_keras_submodules(
+ backend=backend,
+ layers=layers,
+ models=models,
+ utils=utils,
+ engine=engine)
+ else:
+ keras_applications.set_keras_submodules(
+ backend=backend,
+ layers=layers,
+ models=models,
+ utils=utils)
+
+
+def keras_modules_injection(base_fun):
+ """Decorator injecting tf.keras replacements for Keras modules.
+
+ Arguments:
+ base_fun: Application function to decorate (e.g. `MobileNet`).
+
+ Returns:
+    Decorated function that injects keyword arguments for the tf.keras
+ modules required by the Applications.
+ """
+
+ def wrapper(*args, **kwargs):
+ if hasattr(keras_applications, 'get_submodules_from_kwargs'):
+ kwargs['backend'] = backend
+ kwargs['layers'] = layers
+ kwargs['models'] = models
+ kwargs['utils'] = utils
+ return base_fun(*args, **kwargs)
+ return wrapper
-keras_applications.set_keras_submodules(
- backend=backend,
- engine=engine,
- layers=layers,
- models=models,
- utils=utils)
from tensorflow.python.keras.applications.densenet import DenseNet121
from tensorflow.python.keras.applications.densenet import DenseNet169
@@ -39,7 +76,7 @@ from tensorflow.python.keras.applications.densenet import DenseNet201
from tensorflow.python.keras.applications.inception_resnet_v2 import InceptionResNetV2
from tensorflow.python.keras.applications.inception_v3 import InceptionV3
from tensorflow.python.keras.applications.mobilenet import MobileNet
-# TODO(fchollet): enable MobileNetV2 in next version.
+from tensorflow.python.keras.applications.mobilenet_v2 import MobileNetV2
from tensorflow.python.keras.applications.nasnet import NASNetLarge
from tensorflow.python.keras.applications.nasnet import NASNetMobile
from tensorflow.python.keras.applications.resnet50 import ResNet50
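A minimal sketch of what `keras_modules_injection` does for a wrapped application (the `demo_app` function is hypothetical; on keras_applications >= 1.0.5 the submodule kwargs are injected, on older versions they pass through untouched):

    from tensorflow.python.keras import applications

    # Hypothetical toy "application" that reports which tf.keras submodules
    # the decorator handed it, the same way DenseNet121 and friends receive
    # backend/layers/models/utils via kwargs.
    @applications.keras_modules_injection
    def demo_app(*args, **kwargs):
      return sorted(k for k in kwargs
                    if k in ('backend', 'layers', 'models', 'utils'))

    print(demo_app())  # ['backend', 'layers', 'models', 'utils'] on >= 1.0.5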
diff --git a/tensorflow/python/keras/applications/applications_test.py b/tensorflow/python/keras/applications/applications_test.py
index ef3198a937..b15ca5990a 100644
--- a/tensorflow/python/keras/applications/applications_test.py
+++ b/tensorflow/python/keras/applications/applications_test.py
@@ -32,7 +32,8 @@ MODEL_LIST = [
(applications.InceptionV3, 2048),
(applications.InceptionResNetV2, 1536),
(applications.MobileNet, 1024),
- # TODO(fchollet): enable MobileNetV2 in next version.
+ # TODO(fchollet): enable MobileNetV2 tests when a new TensorFlow test image
+ # is released with keras_applications upgraded to 1.0.5 or above.
(applications.DenseNet121, 1024),
(applications.DenseNet169, 1664),
(applications.DenseNet201, 1920),
@@ -44,11 +45,6 @@ MODEL_LIST = [
class ApplicationsTest(test.TestCase, parameterized.TestCase):
@parameterized.parameters(*MODEL_LIST)
- def test_classification_model(self, model_fn, _):
- model = model_fn(classes=1000, weights=None)
- self.assertEqual(model.output_shape[-1], 1000)
-
- @parameterized.parameters(*MODEL_LIST)
   def test_feature_extraction_model(self, model_fn, output_dim):
model = model_fn(include_top=False, weights=None)
self.assertEqual(model.output_shape, (None, None, None, output_dim))
diff --git a/tensorflow/python/keras/applications/densenet.py b/tensorflow/python/keras/applications/densenet.py
index fbdcc66d2d..172848bbdb 100644
--- a/tensorflow/python/keras/applications/densenet.py
+++ b/tensorflow/python/keras/applications/densenet.py
@@ -20,18 +20,39 @@ from __future__ import division
from __future__ import print_function
from keras_applications import densenet
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-DenseNet121 = densenet.DenseNet121
-DenseNet169 = densenet.DenseNet169
-DenseNet201 = densenet.DenseNet201
-decode_predictions = densenet.decode_predictions
-preprocess_input = densenet.preprocess_input
-
-tf_export('keras.applications.densenet.DenseNet121',
- 'keras.applications.DenseNet121')(DenseNet121)
-tf_export('keras.applications.densenet.DenseNet169',
- 'keras.applications.DenseNet169')(DenseNet169)
-tf_export('keras.applications.densenet.DenseNet201',
- 'keras.applications.DenseNet201')(DenseNet201)
-tf_export('keras.applications.densenet.preprocess_input')(preprocess_input)
+
+@tf_export('keras.applications.densenet.DenseNet121',
+ 'keras.applications.DenseNet121')
+@keras_modules_injection
+def DenseNet121(*args, **kwargs):
+ return densenet.DenseNet121(*args, **kwargs)
+
+
+@tf_export('keras.applications.densenet.DenseNet169',
+ 'keras.applications.DenseNet169')
+@keras_modules_injection
+def DenseNet169(*args, **kwargs):
+ return densenet.DenseNet169(*args, **kwargs)
+
+
+@tf_export('keras.applications.densenet.DenseNet201',
+ 'keras.applications.DenseNet201')
+@keras_modules_injection
+def DenseNet201(*args, **kwargs):
+ return densenet.DenseNet201(*args, **kwargs)
+
+
+@tf_export('keras.applications.densenet.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return densenet.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.densenet.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return densenet.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/imagenet_utils.py b/tensorflow/python/keras/applications/imagenet_utils.py
index 70f8f6fb32..c25b5c2bdd 100644
--- a/tensorflow/python/keras/applications/imagenet_utils.py
+++ b/tensorflow/python/keras/applications/imagenet_utils.py
@@ -19,27 +19,18 @@ from __future__ import division
from __future__ import print_function
from keras_applications import imagenet_utils
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-decode_predictions = imagenet_utils.decode_predictions
-preprocess_input = imagenet_utils.preprocess_input
-tf_export(
- 'keras.applications.imagenet_utils.decode_predictions',
- 'keras.applications.densenet.decode_predictions',
- 'keras.applications.inception_resnet_v2.decode_predictions',
- 'keras.applications.inception_v3.decode_predictions',
- 'keras.applications.mobilenet.decode_predictions',
- 'keras.applications.mobilenet_v2.decode_predictions',
- 'keras.applications.nasnet.decode_predictions',
- 'keras.applications.resnet50.decode_predictions',
- 'keras.applications.vgg16.decode_predictions',
- 'keras.applications.vgg19.decode_predictions',
- 'keras.applications.xception.decode_predictions',
-)(decode_predictions)
-tf_export(
- 'keras.applications.imagenet_utils.preprocess_input',
- 'keras.applications.resnet50.preprocess_input',
- 'keras.applications.vgg16.preprocess_input',
- 'keras.applications.vgg19.preprocess_input',
-)(preprocess_input)
+@tf_export('keras.applications.imagenet_utils.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return imagenet_utils.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.imagenet_utils.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return imagenet_utils.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/inception_resnet_v2.py b/tensorflow/python/keras/applications/inception_resnet_v2.py
index 63debb4e0d..0b9ef371fa 100644
--- a/tensorflow/python/keras/applications/inception_resnet_v2.py
+++ b/tensorflow/python/keras/applications/inception_resnet_v2.py
@@ -20,13 +20,25 @@ from __future__ import division
from __future__ import print_function
from keras_applications import inception_resnet_v2
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-InceptionResNetV2 = inception_resnet_v2.InceptionResNetV2
-decode_predictions = inception_resnet_v2.decode_predictions
-preprocess_input = inception_resnet_v2.preprocess_input
-tf_export('keras.applications.inception_resnet_v2.InceptionResNetV2',
- 'keras.applications.InceptionResNetV2')(InceptionResNetV2)
-tf_export(
- 'keras.applications.inception_resnet_v2.preprocess_input')(preprocess_input)
+@tf_export('keras.applications.inception_resnet_v2.InceptionResNetV2',
+ 'keras.applications.InceptionResNetV2')
+@keras_modules_injection
+def InceptionResNetV2(*args, **kwargs):
+ return inception_resnet_v2.InceptionResNetV2(*args, **kwargs)
+
+
+@tf_export('keras.applications.inception_resnet_v2.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return inception_resnet_v2.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.inception_resnet_v2.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return inception_resnet_v2.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/inception_v3.py b/tensorflow/python/keras/applications/inception_v3.py
index 87534086c8..ab76826e17 100644
--- a/tensorflow/python/keras/applications/inception_v3.py
+++ b/tensorflow/python/keras/applications/inception_v3.py
@@ -20,12 +20,25 @@ from __future__ import division
from __future__ import print_function
from keras_applications import inception_v3
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-InceptionV3 = inception_v3.InceptionV3
-decode_predictions = inception_v3.decode_predictions
-preprocess_input = inception_v3.preprocess_input
-tf_export('keras.applications.inception_v3.InceptionV3',
- 'keras.applications.InceptionV3')(InceptionV3)
-tf_export('keras.applications.inception_v3.preprocess_input')(preprocess_input)
+@tf_export('keras.applications.inception_v3.InceptionV3',
+ 'keras.applications.InceptionV3')
+@keras_modules_injection
+def InceptionV3(*args, **kwargs):
+ return inception_v3.InceptionV3(*args, **kwargs)
+
+
+@tf_export('keras.applications.inception_v3.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return inception_v3.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.inception_v3.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return inception_v3.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/mobilenet.py b/tensorflow/python/keras/applications/mobilenet.py
index 3528f027b3..1f71a5ae99 100644
--- a/tensorflow/python/keras/applications/mobilenet.py
+++ b/tensorflow/python/keras/applications/mobilenet.py
@@ -20,12 +20,25 @@ from __future__ import division
from __future__ import print_function
from keras_applications import mobilenet
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-MobileNet = mobilenet.MobileNet
-decode_predictions = mobilenet.decode_predictions
-preprocess_input = mobilenet.preprocess_input
-tf_export('keras.applications.mobilenet.MobileNet',
- 'keras.applications.MobileNet')(MobileNet)
-tf_export('keras.applications.mobilenet.preprocess_input')(preprocess_input)
+@tf_export('keras.applications.mobilenet.MobileNet',
+ 'keras.applications.MobileNet')
+@keras_modules_injection
+def MobileNet(*args, **kwargs):
+ return mobilenet.MobileNet(*args, **kwargs)
+
+
+@tf_export('keras.applications.mobilenet.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return mobilenet.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.mobilenet.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return mobilenet.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/mobilenet_v2.py b/tensorflow/python/keras/applications/mobilenet_v2.py
index 9194c3ee14..52ac5959ad 100644
--- a/tensorflow/python/keras/applications/mobilenet_v2.py
+++ b/tensorflow/python/keras/applications/mobilenet_v2.py
@@ -19,4 +19,26 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# TODO(fchollet): export MobileNetV2 as part of the public API in next version.
+from keras_applications import mobilenet_v2
+
+from tensorflow.python.keras.applications import keras_modules_injection
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export('keras.applications.mobilenet_v2.MobileNetV2',
+ 'keras.applications.MobileNetV2')
+@keras_modules_injection
+def MobileNetV2(*args, **kwargs):
+ return mobilenet_v2.MobileNetV2(*args, **kwargs)
+
+
+@tf_export('keras.applications.mobilenet_v2.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return mobilenet_v2.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.mobilenet_v2.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return mobilenet_v2.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/nasnet.py b/tensorflow/python/keras/applications/nasnet.py
index 26ff5db53f..44fc329d57 100644
--- a/tensorflow/python/keras/applications/nasnet.py
+++ b/tensorflow/python/keras/applications/nasnet.py
@@ -20,15 +20,32 @@ from __future__ import division
from __future__ import print_function
from keras_applications import nasnet
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-NASNetMobile = nasnet.NASNetMobile
-NASNetLarge = nasnet.NASNetLarge
-decode_predictions = nasnet.decode_predictions
-preprocess_input = nasnet.preprocess_input
-tf_export('keras.applications.nasnet.NASNetMobile',
- 'keras.applications.NASNetMobile')(NASNetMobile)
-tf_export('keras.applications.nasnet.NASNetLarge',
- 'keras.applications.NASNetLarge')(NASNetLarge)
-tf_export('keras.applications.nasnet.preprocess_input')(preprocess_input)
+@tf_export('keras.applications.nasnet.NASNetMobile',
+ 'keras.applications.NASNetMobile')
+@keras_modules_injection
+def NASNetMobile(*args, **kwargs):
+ return nasnet.NASNetMobile(*args, **kwargs)
+
+
+@tf_export('keras.applications.nasnet.NASNetLarge',
+ 'keras.applications.NASNetLarge')
+@keras_modules_injection
+def NASNetLarge(*args, **kwargs):
+ return nasnet.NASNetLarge(*args, **kwargs)
+
+
+@tf_export('keras.applications.nasnet.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return nasnet.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.nasnet.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return nasnet.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/resnet50.py b/tensorflow/python/keras/applications/resnet50.py
index 4d804a3c44..80d3f9044f 100644
--- a/tensorflow/python/keras/applications/resnet50.py
+++ b/tensorflow/python/keras/applications/resnet50.py
@@ -20,11 +20,25 @@ from __future__ import division
from __future__ import print_function
from keras_applications import resnet50
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-ResNet50 = resnet50.ResNet50
-decode_predictions = resnet50.decode_predictions
-preprocess_input = resnet50.preprocess_input
-tf_export('keras.applications.resnet50.ResNet50',
- 'keras.applications.ResNet50')(ResNet50)
+@tf_export('keras.applications.resnet50.ResNet50',
+ 'keras.applications.ResNet50')
+@keras_modules_injection
+def ResNet50(*args, **kwargs):
+ return resnet50.ResNet50(*args, **kwargs)
+
+
+@tf_export('keras.applications.resnet50.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return resnet50.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.resnet50.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return resnet50.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/vgg16.py b/tensorflow/python/keras/applications/vgg16.py
index c420d9b81e..8557d26931 100644
--- a/tensorflow/python/keras/applications/vgg16.py
+++ b/tensorflow/python/keras/applications/vgg16.py
@@ -20,11 +20,25 @@ from __future__ import division
from __future__ import print_function
from keras_applications import vgg16
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-VGG16 = vgg16.VGG16
-decode_predictions = vgg16.decode_predictions
-preprocess_input = vgg16.preprocess_input
-tf_export('keras.applications.vgg16.VGG16',
- 'keras.applications.VGG16')(VGG16)
+@tf_export('keras.applications.vgg16.VGG16',
+ 'keras.applications.VGG16')
+@keras_modules_injection
+def VGG16(*args, **kwargs):
+ return vgg16.VGG16(*args, **kwargs)
+
+
+@tf_export('keras.applications.vgg16.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return vgg16.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.vgg16.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return vgg16.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/vgg19.py b/tensorflow/python/keras/applications/vgg19.py
index 73d3d1d1c3..8fc04413a0 100644
--- a/tensorflow/python/keras/applications/vgg19.py
+++ b/tensorflow/python/keras/applications/vgg19.py
@@ -20,11 +20,25 @@ from __future__ import division
from __future__ import print_function
from keras_applications import vgg19
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-VGG19 = vgg19.VGG19
-decode_predictions = vgg19.decode_predictions
-preprocess_input = vgg19.preprocess_input
-tf_export('keras.applications.vgg19.VGG19',
- 'keras.applications.VGG19')(VGG19)
+@tf_export('keras.applications.vgg19.VGG19',
+ 'keras.applications.VGG19')
+@keras_modules_injection
+def VGG19(*args, **kwargs):
+ return vgg19.VGG19(*args, **kwargs)
+
+
+@tf_export('keras.applications.vgg19.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return vgg19.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.vgg19.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return vgg19.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/xception.py b/tensorflow/python/keras/applications/xception.py
index 5b221ac8e0..960e6dec69 100644
--- a/tensorflow/python/keras/applications/xception.py
+++ b/tensorflow/python/keras/applications/xception.py
@@ -20,12 +20,25 @@ from __future__ import division
from __future__ import print_function
from keras_applications import xception
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-Xception = xception.Xception
-decode_predictions = xception.decode_predictions
-preprocess_input = xception.preprocess_input
-tf_export('keras.applications.xception.Xception',
- 'keras.applications.Xception')(Xception)
-tf_export('keras.applications.xception.preprocess_input')(preprocess_input)
+@tf_export('keras.applications.xception.Xception',
+ 'keras.applications.Xception')
+@keras_modules_injection
+def Xception(*args, **kwargs):
+ return xception.Xception(*args, **kwargs)
+
+
+@tf_export('keras.applications.xception.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return xception.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.xception.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return xception.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 418586b85f..7768caeaf0 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -94,6 +94,14 @@ _IMAGE_DATA_FORMAT = 'channels_last'
# We assume our devices don't change henceforth.
_LOCAL_DEVICES = None
+# This dictionary holds a mapping between a graph and variables to initialize
+# in the graph.
+_GRAPH_VARIABLES = {}
+
+# This dictionary holds a mapping between a graph and TF optimizers created in
+# the graph.
+_GRAPH_TF_OPTIMIZERS = {}
+
@tf_export('keras.backend.backend')
def backend():
@@ -309,6 +317,8 @@ def clear_session():
"""
global _SESSION
global _GRAPH_LEARNING_PHASES # pylint: disable=global-variable-not-assigned
+ global _GRAPH_VARIABLES # pylint: disable=global-variable-not-assigned
+ global _GRAPH_TF_OPTIMIZERS # pylint: disable=global-variable-not-assigned
ops.reset_default_graph()
reset_uids()
_SESSION = None
@@ -316,6 +326,8 @@ def clear_session():
False, shape=(), name='keras_learning_phase')
_GRAPH_LEARNING_PHASES = {}
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = phase
+ _GRAPH_VARIABLES.pop(ops.get_default_graph(), None)
+ _GRAPH_TF_OPTIMIZERS.pop(ops.get_default_graph(), None)
@tf_export('keras.backend.manual_variable_initialization')
@@ -431,13 +443,7 @@ def get_session():
session = default_session
else:
if _SESSION is None:
- if not os.environ.get('OMP_NUM_THREADS'):
- config = config_pb2.ConfigProto(allow_soft_placement=True)
- else:
- num_thread = int(os.environ.get('OMP_NUM_THREADS'))
- config = config_pb2.ConfigProto(
- intra_op_parallelism_threads=num_thread, allow_soft_placement=True)
- _SESSION = session_module.Session(config=config)
+ _SESSION = session_module.Session(config=get_default_session_config())
session = _SESSION
if not _MANUAL_VAR_INIT:
with session.graph.as_default():
@@ -456,6 +462,16 @@ def set_session(session):
_SESSION = session
+def get_default_session_config():
+ if not os.environ.get('OMP_NUM_THREADS'):
+ config = config_pb2.ConfigProto(allow_soft_placement=True)
+ else:
+ num_thread = int(os.environ.get('OMP_NUM_THREADS'))
+ config = config_pb2.ConfigProto(
+ intra_op_parallelism_threads=num_thread, allow_soft_placement=True)
+ return config
+
+
# DEVICE MANIPULATION
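A short sketch of the new helper's behavior (the environment value is illustrative; `get_default_session_config` is a module-level helper in backend.py, not an exported tf.keras API):

    import os
    from tensorflow.python.keras import backend

    # Without OMP_NUM_THREADS, only soft placement is requested.
    config = backend.get_default_session_config()
    assert config.allow_soft_placement

    # With OMP_NUM_THREADS set, intra-op parallelism is pinned to it.
    os.environ['OMP_NUM_THREADS'] = '4'  # illustrative value
    config = backend.get_default_session_config()
    assert config.intra_op_parallelism_threads == 4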
@@ -651,12 +667,42 @@ def variable(value, dtype=None, name=None, constraint=None):
elif hasattr(value, 'shape'):
v._keras_shape = int_shape(value)
v._uses_learning_phase = False
+ track_variable(v)
return v
+def track_tf_optimizer(tf_optimizer):
+ """Tracks the given TF optimizer for initialization of its variables."""
+ if context.executing_eagerly():
+ return
+ graph = ops.get_default_graph()
+ if graph not in _GRAPH_TF_OPTIMIZERS:
+ _GRAPH_TF_OPTIMIZERS[graph] = set()
+ _GRAPH_TF_OPTIMIZERS[graph].add(tf_optimizer)
+
+
+def track_variable(v):
+ """Tracks the given variable for initialization."""
+ if context.executing_eagerly():
+ return
+ graph = v.graph if hasattr(v, 'graph') else ops.get_default_graph()
+ if graph not in _GRAPH_VARIABLES:
+ _GRAPH_VARIABLES[graph] = set()
+ _GRAPH_VARIABLES[graph].add(v)
+
+
+def _get_variables(graph=None):
+ """Returns variables corresponding to the given graph for initialization."""
+ assert not context.executing_eagerly()
+ variables = _GRAPH_VARIABLES.get(graph, set())
+ for opt in _GRAPH_TF_OPTIMIZERS.get(graph, set()):
+ variables.update(opt.optimizer.variables())
+ return variables
+
+
def _initialize_variables(session):
"""Utility to initialize uninitialized variables on the fly."""
- variables = variables_module.global_variables()
+ variables = _get_variables(ops.get_default_graph())
candidate_vars = []
for v in variables:
if not getattr(v, '_keras_initialized', False):
@@ -974,6 +1020,7 @@ def zeros(shape, dtype=None, name=None):
v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name)
if py_all(v.shape.as_list()):
return variable(v, dtype=dtype, name=name)
+ track_variable(v)
return v
@@ -1008,6 +1055,7 @@ def ones(shape, dtype=None, name=None):
v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name)
if py_all(v.shape.as_list()):
return variable(v, dtype=dtype, name=name)
+ track_variable(v)
return v
@@ -2766,7 +2814,8 @@ class Function(object):
outputs: Output tensors to fetch.
updates: Additional update ops to be run at function call.
name: A name to help users identify what this function does.
- session_kwargs: Arguments to `tf.Session.run()`: `fetches`, `feed_dict`.
+ session_kwargs: Arguments to `tf.Session.run()`:
+ `fetches`, `feed_dict`, `options`, `run_metadata`.
"""
def __init__(self, inputs, outputs, updates=None, name=None,
@@ -2800,6 +2849,8 @@ class Function(object):
self.fetches = session_kwargs.pop('fetches', [])
if not isinstance(self.fetches, list):
self.fetches = [self.fetches]
+ self.run_options = session_kwargs.pop('options', None)
+ self.run_metadata = session_kwargs.pop('run_metadata', None)
# The main use case of `fetches` being passed to a model is the ability
# to run custom updates
# This requires us to wrap fetches in `identity` ops.
@@ -2857,6 +2908,9 @@ class Function(object):
callable_opts.fetch.append(x.name)
# Handle updates.
callable_opts.target.append(self.updates_op.name)
+ # Handle run_options.
+ if self.run_options:
+ callable_opts.run_options.CopyFrom(self.run_options)
# Create callable.
callable_fn = session._make_callable_from_options(callable_opts)
# Cache parameters corresponding to the generated callable, so that
@@ -2915,7 +2969,8 @@ class Function(object):
session != self._session):
self._make_callable(feed_arrays, feed_symbols, symbol_vals, session)
- fetched = self._callable_fn(*array_vals)
+ fetched = self._callable_fn(*array_vals,
+ run_metadata=self.run_metadata)
self._call_fetch_callbacks(fetched[-len(self._fetches):])
return fetched[:len(self.outputs)]
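A sketch of the per-graph variable bookkeeping these hunks introduce (`_get_variables` is private, so this mirrors internal behavior rather than a supported API):

    from tensorflow.python.framework import ops
    from tensorflow.python.keras import backend

    graph = ops.Graph()
    with graph.as_default():
      v = backend.variable(1.0)  # variable() now ends with track_variable(v)
      # _initialize_variables() consults this per-graph set instead of
      # variables_module.global_variables(), so only Keras-tracked variables
      # get initialized on the fly.
      assert v in backend._get_variables(graph)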
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index 40e7910061..266af56611 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -21,6 +21,7 @@ from absl.testing import parameterized
import numpy as np
import scipy.sparse
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python import keras
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -118,7 +119,7 @@ class BackendUtilsTest(test.TestCase):
self.assertEqual(keras.backend.get_uid('foo'), 1)
def test_learning_phase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
keras.backend.set_learning_phase(1)
self.assertEqual(keras.backend.learning_phase(), 1)
with self.assertRaises(ValueError):
@@ -132,7 +133,7 @@ class BackendUtilsTest(test.TestCase):
sess.run(y, feed_dict={x: np.random.random((2, 3))})
def test_learning_phase_scope(self):
- with self.test_session():
+ with self.cached_session():
initial_learning_phase = keras.backend.learning_phase()
with keras.backend.learning_phase_scope(1) as lp:
self.assertEqual(lp, 1)
@@ -155,7 +156,7 @@ class BackendUtilsTest(test.TestCase):
self.assertEqual(keras.backend.int_shape(x), (None, 4))
def test_in_train_phase(self):
- with self.test_session():
+ with self.cached_session():
y1 = keras.backend.variable(1)
y2 = keras.backend.variable(2)
y = keras.backend.in_train_phase(y1, y2)
@@ -193,7 +194,7 @@ class BackendUtilsTest(test.TestCase):
self.assertEqual(y.op.name[:12], 'StopGradient')
def test_function_tf_feed_symbols(self):
- with self.test_session():
+ with self.cached_session():
# Test feeding a resource variable to `function`.
x1 = keras.backend.placeholder(shape=())
x2 = keras.backend.placeholder(shape=())
@@ -231,7 +232,7 @@ class BackendUtilsTest(test.TestCase):
# keras.backend.function() these do not have control dependency on `outputs`
# so they can run in parallel. Also they should not contribute to output of
# keras.backend.function().
- with self.test_session():
+ with self.cached_session():
x = keras.backend.variable(0.)
y = keras.backend.variable(0.)
x_placeholder = keras.backend.placeholder(shape=())
@@ -252,7 +253,7 @@ class BackendUtilsTest(test.TestCase):
# constructor but we can modify the values in the dictionary. Through
# this feed_dict we can provide additional substitutions besides Keras
# inputs.
- with self.test_session():
+ with self.cached_session():
x = keras.backend.variable(0.)
y = keras.backend.variable(0.)
x_placeholder = keras.backend.placeholder(shape=())
@@ -277,6 +278,29 @@ class BackendUtilsTest(test.TestCase):
self.assertEqual(
keras.backend.get_session().run(fetches=[x, y]), [30., 40.])
+ def test_function_tf_run_options_with_run_metadata(self):
+ with self.test_session():
+ x_placeholder = keras.backend.placeholder(shape=())
+ y_placeholder = keras.backend.placeholder(shape=())
+
+ run_options = config_pb2.RunOptions(output_partition_graphs=True)
+ run_metadata = config_pb2.RunMetadata()
+ # enable run_options.
+ f = keras.backend.function(inputs=[x_placeholder, y_placeholder],
+ outputs=[x_placeholder + y_placeholder],
+ options=run_options,
+ run_metadata=run_metadata)
+ output = f([10., 20.])
+ self.assertEqual(output, [30.])
+ self.assertGreater(len(run_metadata.partition_graphs), 0)
+ # disable run_options.
+ f1 = keras.backend.function(inputs=[x_placeholder, y_placeholder],
+ outputs=[x_placeholder + y_placeholder],
+ run_metadata=run_metadata)
+ output1 = f1([10., 20.])
+ self.assertEqual(output1, [30.])
+ self.assertEqual(len(run_metadata.partition_graphs), 0)
+
def test_function_fetch_callbacks(self):
class CallbackStub(object):
@@ -289,7 +313,7 @@ class BackendUtilsTest(test.TestCase):
self.times_called += 1
self.callback_result = result
- with self.test_session():
+ with self.cached_session():
callback = CallbackStub()
x_placeholder = keras.backend.placeholder(shape=())
y_placeholder = keras.backend.placeholder(shape=())
@@ -311,39 +335,39 @@ class BackendUtilsTest(test.TestCase):
class BackendVariableTest(test.TestCase):
def test_zeros(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.zeros((3, 4))
val = keras.backend.eval(x)
self.assertAllClose(val, np.zeros((3, 4)))
def test_ones(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.ones((3, 4))
val = keras.backend.eval(x)
self.assertAllClose(val, np.ones((3, 4)))
def test_eye(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.eye(4)
val = keras.backend.eval(x)
self.assertAllClose(val, np.eye(4))
def test_zeros_like(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.zeros((3, 4))
y = keras.backend.zeros_like(x)
val = keras.backend.eval(y)
self.assertAllClose(val, np.zeros((3, 4)))
def test_ones_like(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.zeros((3, 4))
y = keras.backend.ones_like(x)
val = keras.backend.eval(y)
self.assertAllClose(val, np.ones((3, 4)))
def test_random_uniform_variable(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.random_uniform_variable((30, 20), low=1, high=2, seed=0)
val = keras.backend.eval(x)
self.assertAllClose(val.mean(), 1.5, atol=1e-1)
@@ -351,7 +375,7 @@ class BackendVariableTest(test.TestCase):
self.assertAllClose(val.min(), 1., atol=1e-1)
def test_random_normal_variable(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.random_normal_variable((30, 20), 1., 0.5,
seed=0)
val = keras.backend.eval(x)
@@ -359,20 +383,20 @@ class BackendVariableTest(test.TestCase):
self.assertAllClose(val.std(), 0.5, atol=1e-1)
def test_count_params(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.zeros((4, 5))
val = keras.backend.count_params(x)
self.assertAllClose(val, 20)
def test_constant(self):
- with self.test_session():
+ with self.cached_session():
ref_val = np.random.random((3, 4)).astype('float32')
x = keras.backend.constant(ref_val)
val = keras.backend.eval(x)
self.assertAllClose(val, ref_val)
def test_sparse_variable(self):
- with self.test_session():
+ with self.cached_session():
val = scipy.sparse.eye(10)
x = keras.backend.variable(val)
self.assertTrue(isinstance(x, sparse_tensor.SparseTensor))
@@ -421,7 +445,7 @@ class BackendLinearAlgebraTest(test.TestCase):
(keras.backend.argmax, np.argmax),
]
for keras_op, np_op in ops_to_test:
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(keras_op, np_op, input_shape=(4, 7, 5),
keras_kwargs={'axis': 1},
np_kwargs={'axis': 1})
@@ -447,7 +471,7 @@ class BackendLinearAlgebraTest(test.TestCase):
(keras.backend.exp, np.exp),
]
for keras_op, np_op in ops_to_test:
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(keras_op, np_op, input_shape=(4, 7))
ops_to_test = [
@@ -455,19 +479,19 @@ class BackendLinearAlgebraTest(test.TestCase):
(keras.backend.log, np.log),
]
for keras_op, np_op in ops_to_test:
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(keras_op, np_op,
input_shape=(4, 7),
negative_values=False)
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(
keras.backend.clip, np.clip,
input_shape=(6, 4),
keras_kwargs={'min_value': 0.1, 'max_value': 2.4},
np_kwargs={'a_min': 0.1, 'a_max': 1.4})
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(
keras.backend.pow, np.power,
input_shape=(6, 4),
@@ -486,14 +510,14 @@ class BackendLinearAlgebraTest(test.TestCase):
(keras.backend.minimum, np.minimum),
]
for keras_op, np_op in ops_to_test:
- with self.test_session():
+ with self.cached_session():
compare_two_inputs_op_to_numpy(keras_op, np_op,
input_shape_a=(4, 7),
input_shape_b=(4, 7))
def test_relu(self):
x = ops.convert_to_tensor([[-4, 0], [2, 7]], 'float32')
- with self.test_session():
+ with self.cached_session():
# standard relu
relu_op = keras.backend.relu(x)
self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [2, 7]])
@@ -555,7 +579,7 @@ class BackendLinearAlgebraTest(test.TestCase):
class BackendShapeOpsTest(test.TestCase):
def test_reshape(self):
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(keras.backend.reshape, np.reshape,
input_shape=(4, 7),
keras_args=[(2, 14)],
@@ -568,7 +592,7 @@ class BackendShapeOpsTest(test.TestCase):
self.assertEqual(y.get_shape().as_list(), [1, 2, 5])
def test_permute_dimensions(self):
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(keras.backend.permute_dimensions,
np.transpose,
input_shape=(4, 7),
@@ -647,14 +671,14 @@ class BackendShapeOpsTest(test.TestCase):
self.assertEqual(y.get_shape().as_list(), [1, 2, 3])
def test_flatten(self):
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(keras.backend.flatten,
np.reshape,
input_shape=(4, 7, 6),
np_args=[(4 * 7 * 6,)])
def test_batch_flatten(self):
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(keras.backend.batch_flatten,
np.reshape,
input_shape=(4, 7, 6),
@@ -669,7 +693,7 @@ class BackendShapeOpsTest(test.TestCase):
y[:, padding[0]:-padding[1], :] = x
return y
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(keras.backend.temporal_padding,
ref_op,
input_shape=(4, 7, 6),
@@ -692,7 +716,7 @@ class BackendShapeOpsTest(test.TestCase):
y[:, :, padding[0][0]:-padding[0][1], padding[1][0]:-padding[1][1]] = x
return y
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(
keras.backend.spatial_2d_padding,
ref_op,
@@ -735,7 +759,7 @@ class BackendShapeOpsTest(test.TestCase):
padding[2][0]:-padding[2][1]] = x
return y
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(
keras.backend.spatial_3d_padding,
ref_op,
@@ -757,7 +781,7 @@ class BackendShapeOpsTest(test.TestCase):
class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
def test_bias_add(self):
- with self.test_session():
+ with self.cached_session():
keras_op = keras.backend.bias_add
np_op = np.add
compare_two_inputs_op_to_numpy(keras_op, np_op,
@@ -783,7 +807,8 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
keras.backend.bias_add(x, b, data_format='unknown')
def test_bias_add_channels_first(self):
- with self.test_session():
+ with self.cached_session():
+
def keras_op(x, b):
return keras.backend.bias_add(x, b, data_format='channels_first')
@@ -959,7 +984,7 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
strides,
output_shape,
'channels_last')
- with self.test_session():
+ with self.cached_session():
conv_cf = keras.backend.eval(conv_cf)
conv_cl = keras.backend.eval(conv_cl)
@@ -1009,7 +1034,7 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
output_shape,
'channels_last')
- with self.test_session():
+ with self.cached_session():
local_conv = keras.backend.eval(local_conv)
local_conv_dim = keras.backend.eval(local_conv_dim)
@@ -1167,7 +1192,7 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
{'go_backwards': False, 'mask': mask},
{'go_backwards': False, 'mask': mask, 'unroll': True},
]
- with self.test_session():
+ with self.cached_session():
for i, kwargs in enumerate(kwargs_list):
last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs,
initial_states,
@@ -1263,7 +1288,7 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
{'go_backwards': False, 'mask': mask},
{'go_backwards': False, 'mask': mask, 'unroll': True},
]
- with self.test_session():
+ with self.cached_session():
for i, kwargs in enumerate(kwargs_list):
last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs,
initial_states,
@@ -1359,7 +1384,7 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
class TestCTC(test.TestCase):
def test_ctc_decode(self):
- with self.test_session():
+ with self.cached_session():
depth = 6
seq_len_0 = 5
input_prob_matrix_0 = np.asarray(
@@ -1384,8 +1409,8 @@ class TestCTC(test.TestCase):
np.array([seq_len_0], dtype=np.int32))
-      # batch_size length vector of negative log probabilities
+      # batch_size length vector of log probabilities
log_prob_truth = np.array([
- 0.584855, # output beam 0
- 0.389139 # output beam 1
+ -3.5821197, # output beam 0
+ -3.777835 # output beam 1
], np.float32)[np.newaxis, :]
decode_truth = [np.array([1, 0]), np.array([0, 1, 0])]
@@ -1408,7 +1433,7 @@ class TestCTC(test.TestCase):
self.assertAllClose(log_prob_truth, log_prob_pred)
def test_ctc_batch_cost(self):
- with self.test_session():
+ with self.cached_session():
label_lens = np.expand_dims(np.asarray([5, 4]), 1)
input_lens = np.expand_dims(np.asarray([5, 5]), 1) # number of timesteps
loss_log_probs = [3.34211, 5.42262]
@@ -1464,13 +1489,13 @@ class TestCTC(test.TestCase):
class TestRandomOps(test.TestCase):
def test_random_binomial(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(123)
x = keras.backend.random_binomial((1000, 1000), p=0.5)
self.assertAllClose(np.mean(keras.backend.eval(x)), 0.5, atol=0.1)
def test_truncated_normal(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(123)
x = keras.backend.truncated_normal((1000, 1000), mean=0.0, stddev=1.0)
y = keras.backend.eval(x)
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index e84e023384..7675a6586f 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -235,11 +235,8 @@ class KerasCallbacksTest(test.TestCase):
num_classes=NUM_CLASSES)
y_test = keras.utils.to_categorical(y_test)
y_train = keras.utils.to_categorical(y_train)
- model = keras.models.Sequential()
- model.add(
- keras.layers.Dense(
- NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
- model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
@@ -298,9 +295,8 @@ class KerasCallbacksTest(test.TestCase):
test_samples=50,
input_shape=(1,),
num_classes=NUM_CLASSES)
- model = keras.models.Sequential((keras.layers.Dense(
- 1, input_dim=1, activation='relu'), keras.layers.Dense(
- 1, activation='sigmoid'),))
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=1, num_classes=1, input_dim=1)
model.compile(
optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy'])
@@ -334,11 +330,8 @@ class KerasCallbacksTest(test.TestCase):
num_classes=NUM_CLASSES)
y_test = keras.utils.to_categorical(y_test)
y_train = keras.utils.to_categorical(y_train)
- model = keras.models.Sequential()
- model.add(
- keras.layers.Dense(
- NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
- model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
model.compile(
loss='categorical_crossentropy',
optimizer='sgd',
@@ -388,12 +381,8 @@ class KerasCallbacksTest(test.TestCase):
def make_model():
random_seed.set_random_seed(1234)
np.random.seed(1337)
- model = keras.models.Sequential()
- model.add(
- keras.layers.Dense(
- NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
- model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
-
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
model.compile(
loss='categorical_crossentropy',
optimizer=keras.optimizers.SGD(lr=0.1),
@@ -498,12 +487,8 @@ class KerasCallbacksTest(test.TestCase):
def make_model():
np.random.seed(1337)
- model = keras.models.Sequential()
- model.add(
- keras.layers.Dense(
- NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
- model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
-
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
model.compile(
loss='categorical_crossentropy',
optimizer=keras.optimizers.SGD(lr=0.1),
@@ -985,9 +970,8 @@ class KerasCallbacksTest(test.TestCase):
yield x, y
with self.test_session():
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(10, input_dim=100, activation='relu'))
- model.add(keras.layers.Dense(10, activation='softmax'))
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=10, num_classes=10, input_dim=100)
model.compile(
loss='categorical_crossentropy',
optimizer='sgd',
@@ -1083,11 +1067,8 @@ class KerasCallbacksTest(test.TestCase):
y_test = keras.utils.to_categorical(y_test)
y_train = keras.utils.to_categorical(y_train)
- model = keras.models.Sequential()
- model.add(
- keras.layers.Dense(
- NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
- model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
model.compile(
loss='binary_crossentropy', optimizer='sgd', metrics=['accuracy'])
@@ -1179,40 +1160,36 @@ class KerasCallbacksTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_Tensorboard_eager(self):
- with self.test_session():
- temp_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
- self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
-
- (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
- train_samples=TRAIN_SAMPLES,
- test_samples=TEST_SAMPLES,
- input_shape=(INPUT_DIM,),
- num_classes=NUM_CLASSES)
- y_test = keras.utils.to_categorical(y_test)
- y_train = keras.utils.to_categorical(y_train)
-
- model = keras.models.Sequential()
- model.add(
- keras.layers.Dense(
- NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
- model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
- model.compile(
- loss='binary_crossentropy',
- optimizer=adam.AdamOptimizer(0.01),
- metrics=['accuracy'])
-
- cbks = [keras.callbacks.TensorBoard(log_dir=temp_dir)]
+ temp_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
+ self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
- model.fit(
- x_train,
- y_train,
- batch_size=BATCH_SIZE,
- validation_data=(x_test, y_test),
- callbacks=cbks,
- epochs=2,
- verbose=0)
+ (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+ train_samples=TRAIN_SAMPLES,
+ test_samples=TEST_SAMPLES,
+ input_shape=(INPUT_DIM,),
+ num_classes=NUM_CLASSES)
+ y_test = keras.utils.to_categorical(y_test)
+ y_train = keras.utils.to_categorical(y_train)
- self.assertTrue(os.path.exists(temp_dir))
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
+ model.compile(
+ loss='binary_crossentropy',
+ optimizer=adam.AdamOptimizer(0.01),
+ metrics=['accuracy'])
+
+ cbks = [keras.callbacks.TensorBoard(log_dir=temp_dir)]
+
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=BATCH_SIZE,
+ validation_data=(x_test, y_test),
+ callbacks=cbks,
+ epochs=2,
+ verbose=0)
+
+ self.assertTrue(os.path.exists(temp_dir))
def test_RemoteMonitorWithJsonPayload(self):
if requests is None:
diff --git a/tensorflow/python/keras/constraints_test.py b/tensorflow/python/keras/constraints_test.py
index 84e2db1033..4f674ea7c5 100644
--- a/tensorflow/python/keras/constraints_test.py
+++ b/tensorflow/python/keras/constraints_test.py
@@ -49,7 +49,7 @@ class KerasConstraintsTest(test.TestCase):
assert fn.__class__ == ref_fn.__class__
def test_max_norm(self):
- with self.test_session():
+ with self.cached_session():
array = get_example_array()
for m in get_test_values():
norm_instance = keras.constraints.max_norm(m)
@@ -69,13 +69,13 @@ class KerasConstraintsTest(test.TestCase):
self.assertAllClose(x_normed_actual, x_normed_target, rtol=1e-05)
def test_non_neg(self):
- with self.test_session():
+ with self.cached_session():
non_neg_instance = keras.constraints.non_neg()
normed = non_neg_instance(keras.backend.variable(get_example_array()))
assert np.all(np.min(keras.backend.eval(normed), axis=1) == 0.)
def test_unit_norm(self):
- with self.test_session():
+ with self.cached_session():
unit_norm_instance = keras.constraints.unit_norm()
normalized = unit_norm_instance(
keras.backend.variable(get_example_array()))
@@ -87,7 +87,7 @@ class KerasConstraintsTest(test.TestCase):
assert np.abs(largest_difference) < 10e-5
def test_min_max_norm(self):
- with self.test_session():
+ with self.cached_session():
array = get_example_array()
for m in get_test_values():
norm_instance = keras.constraints.min_max_norm(min_value=m,
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index d6d3db21fb..cb19a412a2 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import collections
+import collections as collections_lib
import enum # pylint: disable=g-bad-import-order
import inspect # Necessary supplement to tf_inspect to deal with variadic args.
@@ -42,7 +42,6 @@ from tensorflow.python.keras.utils.generic_utils import to_snake_case # pylint:
from tensorflow.python.keras.utils.tf_utils import is_tensor_or_tensor_list # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import function_utils
@@ -50,6 +49,7 @@ from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
+from tensorflow.tools.docs import doc_controls
class CallConvention(enum.Enum):
@@ -79,6 +79,7 @@ class Layer(checkpointable.CheckpointableBase):
Users will just instantiate a layer and then treat it as a callable.
We recommend that descendants of `Layer` implement the following methods:
+
* `__init__()`: Save configuration in member variables
* `build()`: Called once from `__call__`, when we know the shapes of inputs
and `dtype`. Should have the calls to `add_weight()`, and then
@@ -272,6 +273,7 @@ class Layer(checkpointable.CheckpointableBase):
return []
return self._updates
+ @doc_controls.for_subclass_implementers
def add_update(self, updates, inputs=None):
"""Add update op(s), potentially dependent on layer inputs.
@@ -372,6 +374,7 @@ class Layer(checkpointable.CheckpointableBase):
else:
return self._losses
+ @doc_controls.for_subclass_implementers
def add_loss(self, losses, inputs=None):
"""Add loss tensor(s), potentially dependent on layer inputs.
@@ -463,10 +466,12 @@ class Layer(checkpointable.CheckpointableBase):
"""Creates the variables of the layer."""
self.built = True
+ @doc_controls.for_subclass_implementers
def add_variable(self, *args, **kwargs):
"""Alias for `add_weight`."""
return self.add_weight(*args, **kwargs)
+ @doc_controls.for_subclass_implementers
def add_weight(self,
name,
shape,
@@ -477,9 +482,9 @@ class Layer(checkpointable.CheckpointableBase):
constraint=None,
partitioner=None,
use_resource=None,
- synchronization=vs.VariableSynchronization.AUTO,
- aggregation=vs.VariableAggregation.NONE,
- getter=None):
+ synchronization=tf_variables.VariableSynchronization.AUTO,
+ aggregation=tf_variables.VariableAggregation.NONE,
+ **kwargs):
"""Adds a new variable to the layer, or gets an existing one; returns it.
Arguments:
@@ -507,7 +512,8 @@ class Layer(checkpointable.CheckpointableBase):
aggregation: Indicates how a distributed variable will be aggregated.
Accepted values are constants defined in the class
`tf.VariableAggregation`.
- getter: Variable getter argument to be passed to the `Checkpointable` API.
+ **kwargs: Additional keyword arguments. Accepted values are `getter` and
+ `collections`.
Returns:
The created variable. Usually either a `Variable` or `ResourceVariable`
@@ -520,6 +526,13 @@ class Layer(checkpointable.CheckpointableBase):
ValueError: When giving unsupported dtype and no initializer or when
trainable has been set to True with synchronization set as `ON_READ`.
"""
+ # Validate optional keyword arguments.
+ for kwarg in kwargs:
+ if kwarg not in ['getter', 'collections']:
+ raise TypeError('Unknown keyword argument: %s' % kwarg)
+ getter = kwargs.pop('getter', None)
+ collections = kwargs.pop('collections', None)
+
if dtype is None:
dtype = self.dtype or backend.floatx()
dtype = dtypes.as_dtype(dtype)
@@ -527,7 +540,7 @@ class Layer(checkpointable.CheckpointableBase):
regularizer = regularizers.get(regularizer)
constraint = constraints.get(constraint)
- if synchronization == vs.VariableSynchronization.ON_READ:
+ if synchronization == tf_variables.VariableSynchronization.ON_READ:
if trainable:
raise ValueError(
'Synchronization value can be set to '
@@ -568,8 +581,10 @@ class Layer(checkpointable.CheckpointableBase):
trainable=trainable and self.trainable,
partitioner=partitioner,
use_resource=use_resource,
+ collections=collections,
synchronization=synchronization,
aggregation=aggregation)
+ backend.track_variable(variable)
if regularizer is not None:
# TODO(fchollet): in the future, this should be handled at the
@@ -646,6 +661,7 @@ class Layer(checkpointable.CheckpointableBase):
activity_regularization = self._activity_regularizer(output)
self.add_loss(activity_regularization, inputs=inputs)
+ @doc_controls.for_subclass_implementers
def call(self, inputs, **kwargs): # pylint: disable=unused-argument
"""This is where the layer's logic lives.
@@ -985,7 +1001,7 @@ class Layer(checkpointable.CheckpointableBase):
self.build(input_shape)
with context.graph_mode():
- graph = eager_function.CapturingGraph()
+ graph = eager_function.FuncGraph('graph')
with graph.as_default():
if isinstance(input_shape, list):
inputs = [generate_placeholders_from_shape(shape)
@@ -1412,11 +1428,13 @@ class Layer(checkpointable.CheckpointableBase):
'instead.' % self.name)
@property
+ @doc_controls.do_not_doc_inheritable
def inbound_nodes(self):
"""Deprecated, do NOT use! Only for compatibility with external Keras."""
return self._inbound_nodes
@property
+ @doc_controls.do_not_doc_inheritable
def outbound_nodes(self):
"""Deprecated, do NOT use! Only for compatibility with external Keras."""
return self._outbound_nodes
@@ -1871,7 +1889,7 @@ def get_default_graph_uid_map():
graph = ops.get_default_graph()
name_uid_map = backend.PER_GRAPH_LAYER_NAME_UIDS.get(graph, None)
if name_uid_map is None:
- name_uid_map = collections.defaultdict(int)
+ name_uid_map = collections_lib.defaultdict(int)
backend.PER_GRAPH_LAYER_NAME_UIDS[graph] = name_uid_map
return name_uid_map
@@ -1886,8 +1904,9 @@ def make_variable(name,
validate_shape=True,
constraint=None,
use_resource=None,
- synchronization=vs.VariableSynchronization.AUTO,
- aggregation=vs.VariableAggregation.NONE,
+ collections=None,
+ synchronization=tf_variables.VariableSynchronization.AUTO,
+ aggregation=tf_variables.VariableAggregation.NONE,
partitioner=None): # pylint: disable=unused-argument
"""Temporary util to create a variable (relies on `variable_scope.variable`).
@@ -1915,10 +1934,12 @@ def make_variable(name,
then this parameter is ignored and any added variables are also
marked as non-trainable. `trainable` defaults to `True` unless
`synchronization` is set to `ON_READ`.
- caching_device: Passed to `vs.variable`.
- validate_shape: Passed to `vs.variable`.
+ caching_device: Passed to `tf.Variable`.
+ validate_shape: Passed to `tf.Variable`.
constraint: Constraint instance (callable).
use_resource: Whether to use a `ResourceVariable`.
+ collections: List of graph collections keys. The new variable is added to
+ these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
synchronization: Indicates when a distributed variable will be
aggregated. Accepted values are constants defined in the class
`tf.VariableSynchronization`. By default the synchronization is set to
@@ -1951,7 +1972,7 @@ def make_variable(name,
if use_resource is None:
use_resource = True
- v = vs.variable(
+ v = tf_variables.Variable(
initial_value=init_val,
name=name,
trainable=trainable,
@@ -1960,6 +1981,7 @@ def make_variable(name,
validate_shape=validate_shape,
constraint=constraint,
use_resource=use_resource,
+ collections=collections,
synchronization=synchronization,
aggregation=aggregation)
return v
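Note on the `add_weight` change above: `getter` moved into `**kwargs` alongside the new `collections` key, and both are now validated explicitly. A minimal usage sketch follows; the layer name and collection choice are illustrative, not part of this patch:

```
# Hypothetical subclassed layer exercising the new `collections` kwarg,
# which add_weight() now accepts via **kwargs and validates against
# ['getter', 'collections'].
from tensorflow.python.framework import ops
from tensorflow.python.keras.engine import base_layer

class MyDense(base_layer.Layer):

  def build(self, input_shape):
    self.kernel = self.add_weight(
        name='kernel',
        shape=(int(input_shape[-1]), 4),
        initializer='glorot_uniform',
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    self.built = True
```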
diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py
index fcb073322c..b28df75493 100644
--- a/tensorflow/python/keras/engine/distributed_training_utils.py
+++ b/tensorflow/python/keras/engine/distributed_training_utils.py
@@ -17,8 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.client import session as session_module
+from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import tensor_util
-from tensorflow.python.keras import backend
+from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import distribute as distribute_lib
@@ -46,7 +48,7 @@ def set_weights(distribution_strategy, dist_model, weights):
assign_ops.append(distribution_strategy.unwrap(sw.assign(w)))
weights = weights[num_param:]
- backend.get_session().run(assign_ops)
+ K.get_session().run(assign_ops)
def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs,
@@ -211,7 +213,10 @@ def validate_distributed_dataset_inputs(distribution_strategy, x, y):
# validate the input and targets.
x_values_list = validate_per_device_inputs(distribution_strategy, x)
- y_values_list = validate_per_device_inputs(distribution_strategy, y)
+ if y is not None:
+ y_values_list = validate_per_device_inputs(distribution_strategy, y)
+ else:
+ y_values_list = None
# Return the unwrapped values to avoid calling `unwrap` a second time.
return x_values_list, y_values_list
@@ -269,3 +274,91 @@ def validate_all_tensor_shapes(x, x_values):
if x_shape != x_values[i].get_shape().as_list():
raise ValueError('Input tensor shapes do not match for distributed tensor'
' inputs {}'.format(x))
+
+
+def configure_and_create_session(distribution_strategy):
+ """Configure session config and create a session with it."""
+ # TODO(priyag): Throw error if a session already exists.
+ session_config = K.get_default_session_config()
+ distribution_strategy.configure(session_config)
+
+ if distribution_strategy.__class__.__name__ == 'TPUStrategy':
+ # TODO(priyag): Remove this workaround when Distributed Coordinator is
+ # integrated with keras and we can create a session from there.
+ master = distribution_strategy._tpu_cluster_resolver.master() # pylint: disable=protected-access
+ session = session_module.Session(config=session_config, target=master)
+ else:
+ session = session_module.Session(config=session_config)
+
+ K.set_session(session)
+
+
+def validate_inputs(x, y):
+ """Validate inputs when using DistributionStrategy.
+
+ Args:
+ x: Model Inputs.
+ y: Model Targets.
+
+ Raises:
+ ValueError: if input is not a Dataset or a numpy array.
+ """
+ if isinstance(x, list) or isinstance(y, list):
+ raise ValueError('DistributionStrategy does not support lists of numpy '
+ 'arrays. You must pass a Dataset object or a numpy array '
+ 'as input.')
+
+ if isinstance(x, dict) or isinstance(y, dict):
+ raise ValueError('DistributionStrategy does not support inputs of type '
+ 'dict. You must pass a Dataset object or a numpy array as '
+ 'input.')
+
+ if (isinstance(x, iterator_ops.Iterator) or
+ isinstance(y, iterator_ops.Iterator)):
+ raise ValueError('DistributionStrategy does not support inputs of type '
+ 'Iterator. You must pass a Dataset object or a numpy '
+ 'array as input.')
+
+
+def get_input_batch_params(first_x_value, batch_size, current_strategy):
+ """Calculate the number of batches and steps/steps_per_epoch.
+
+ Args:
+ first_x_value: This is the first input numpy array that is passed in as the
+ model input.
+ batch_size: The specified batch_size or the default batch_size of 32.
+ current_strategy: The current DistributionStrategy used to compile the
+ model.
+
+ Returns:
+ The steps or steps_per_epoch argument depending on if a user is
+ calling `fit`, `evaluate` or `predict`.
+
+ Raises:
+ ValueError: If the number of batches or steps evaluates to 0.
+
+ """
+ num_batches = first_x_value.shape[0] // batch_size
+ if not num_batches:
+ raise ValueError('Please specify a batch_size that is smaller than '
+ 'the number of input samples %d.' % first_x_value.shape[0])
+ # TODO(anjalisridhar): TPU currently supports using the num_towers property.
+ # We might want to look into implementing worker_devices. In multi worker
+ # strategy, perhaps num_towers works better?
+ steps = num_batches // current_strategy.num_towers
+ if not steps:
+ # TODO(anjalisridhar): Number of towers in the error message may not convey
+ # what we want to the user. Is there another terminology that we can use
+ # that is consistent across different strategies?
+ raise ValueError('The number of batches %d is smaller than the number '
+ 'of towers %d used for DistributionStrategy. ' %
+ (num_batches, current_strategy.num_towers))
+ return steps
+
+
+def get_batch_dimension(iterator):
+ shapes = nest.flatten(iterator.output_shapes)
+ # Take the batch size from the first element, as it should be the same for
+ # all.
+ dims = shapes[0].dims
+ return dims[0] if dims else None
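The batching arithmetic in `get_input_batch_params` is easiest to see with concrete numbers; the values below are assumed for illustration:

```
# Assumed: 100 input samples, batch_size of 32, a 2-tower strategy.
num_samples, batch_size, num_towers = 100, 32, 2
num_batches = num_samples // batch_size  # 3 full batches; remainder dropped
steps = num_batches // num_towers        # 1 step, consuming 2 batches
assert (num_batches, steps) == (3, 1)
# Either quantity evaluating to 0 triggers the ValueErrors raised above.
```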
diff --git a/tensorflow/python/keras/engine/feature_columns_integration_test.py b/tensorflow/python/keras/engine/feature_columns_integration_test.py
new file mode 100644
index 0000000000..e0478ee357
--- /dev/null
+++ b/tensorflow/python/keras/engine/feature_columns_integration_test.py
@@ -0,0 +1,237 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests specific to Feature Columns integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python import keras
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.feature_column import feature_column_v2 as fc
+from tensorflow.python.framework import test_util as tf_test_util
+from tensorflow.python.keras import metrics as metrics_module
+from tensorflow.python.platform import test
+from tensorflow.python.training import rmsprop
+
+
+class TestDNNModel(keras.models.Model):
+
+ def __init__(self, feature_columns, units, name=None, **kwargs):
+ super(TestDNNModel, self).__init__(name=name, **kwargs)
+ self._input_layer = fc.FeatureLayer(feature_columns, name='input_layer')
+ self._dense_layer = keras.layers.Dense(units, name='dense_layer')
+
+ def call(self, features):
+ net = self._input_layer(features)
+ net = self._dense_layer(net)
+ return net
+
+
+class FeatureColumnsIntegrationTest(test.TestCase):
+ """Most Sequential model API tests are covered in `training_test.py`.
+
+ """
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_sequential_model(self):
+ columns = [fc.numeric_column('a')]
+ model = keras.models.Sequential([
+ fc.FeatureLayer(columns),
+ keras.layers.Dense(64, activation='relu'),
+ keras.layers.Dense(20, activation='softmax')
+ ])
+ model.compile(
+ optimizer=rmsprop.RMSPropOptimizer(1e-3),
+ loss='categorical_crossentropy',
+ metrics=['accuracy'])
+
+ x = {'a': np.random.random((10, 1))}
+ y = np.random.randint(20, size=(10, 1))
+ y = keras.utils.to_categorical(y, num_classes=20)
+ model.fit(x, y, epochs=1, batch_size=5)
+ model.fit(x, y, epochs=1, batch_size=5)
+ model.evaluate(x, y, batch_size=5)
+ model.predict(x, batch_size=5)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_sequential_model_with_ds_input(self):
+ columns = [fc.numeric_column('a')]
+ model = keras.models.Sequential([
+ fc.FeatureLayer(columns),
+ keras.layers.Dense(64, activation='relu'),
+ keras.layers.Dense(20, activation='softmax')
+ ])
+ model.compile(
+ optimizer=rmsprop.RMSPropOptimizer(1e-3),
+ loss='categorical_crossentropy',
+ metrics=['accuracy'])
+
+ y = np.random.randint(20, size=(100, 1))
+ y = keras.utils.to_categorical(y, num_classes=20)
+ x = {'a': np.random.random((100, 1))}
+ ds1 = dataset_ops.Dataset.from_tensor_slices(x)
+ ds2 = dataset_ops.Dataset.from_tensor_slices(y)
+ ds = dataset_ops.Dataset.zip((ds1, ds2)).batch(5)
+ model.fit(ds, steps_per_epoch=1)
+ model.fit(ds, steps_per_epoch=1)
+ model.evaluate(ds, steps=1)
+ model.predict(ds, steps=1)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_subclassed_model_with_feature_columns(self):
+ col_a = fc.numeric_column('a')
+ col_b = fc.numeric_column('b')
+
+ dnn_model = TestDNNModel([col_a, col_b], 20)
+
+ dnn_model.compile(
+ optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.001),
+ loss='categorical_crossentropy',
+ metrics=['accuracy'])
+
+ x = {'a': np.random.random((10, 1)), 'b': np.random.random((10, 1))}
+ y = np.random.randint(20, size=(10, 1))
+ y = keras.utils.to_categorical(y, num_classes=20)
+ dnn_model.fit(x=x, y=y, epochs=1, batch_size=5)
+ dnn_model.fit(x=x, y=y, epochs=1, batch_size=5)
+ dnn_model.evaluate(x=x, y=y, batch_size=5)
+ dnn_model.predict(x=x, batch_size=5)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_subclassed_model_with_feature_columns_with_ds_input(self):
+ col_a = fc.numeric_column('a')
+ col_b = fc.numeric_column('b')
+
+ dnn_model = TestDNNModel([col_a, col_b], 20)
+
+ dnn_model.compile(
+ optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.001),
+ loss='categorical_crossentropy',
+ metrics=['accuracy'])
+
+ y = np.random.randint(20, size=(100, 1))
+ y = keras.utils.to_categorical(y, num_classes=20)
+ x = {'a': np.random.random((100, 1)), 'b': np.random.random((100, 1))}
+ ds1 = dataset_ops.Dataset.from_tensor_slices(x)
+ ds2 = dataset_ops.Dataset.from_tensor_slices(y)
+ ds = dataset_ops.Dataset.zip((ds1, ds2)).batch(5)
+ dnn_model.fit(ds, steps_per_epoch=1)
+ dnn_model.fit(ds, steps_per_epoch=1)
+ dnn_model.evaluate(ds, steps=1)
+ dnn_model.predict(ds, steps=1)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def DISABLED_test_function_model_feature_layer_input(self):
+ col_a = fc.numeric_column('a')
+ col_b = fc.numeric_column('b')
+
+ feature_layer = fc.FeatureLayer([col_a, col_b], name='fc')
+ dense = keras.layers.Dense(4)
+
+ # This seems problematic... We probably need something for FeatureLayer
+ # the way Input is for InputLayer.
+ output = dense(feature_layer)
+
+ model = keras.models.Model([feature_layer], [output])
+
+ optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ loss_weights = [1., 0.5]
+ model.compile(
+ optimizer,
+ loss,
+ metrics=[metrics_module.CategoricalAccuracy(), 'mae'],
+ loss_weights=loss_weights)
+
+ data = ({'a': np.arange(10), 'b': np.arange(10)}, np.arange(10, 20))
+ print(model.fit(*data, epochs=1))
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def DISABLED_test_function_model_multiple_feature_layer_inputs(self):
+ col_a = fc.numeric_column('a')
+ col_b = fc.numeric_column('b')
+ col_c = fc.numeric_column('c')
+
+ fc1 = fc.FeatureLayer([col_a, col_b], name='fc1')
+ fc2 = fc.FeatureLayer([col_b, col_c], name='fc2')
+ dense = keras.layers.Dense(4)
+
+ # This seems problematic... We probably need something for FeatureLayer
+ # the way Input is for InputLayer.
+ output = dense(fc1) + dense(fc2)
+
+ model = keras.models.Model([fc1, fc2], [output])
+
+ optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ loss_weights = [1., 0.5]
+ model.compile(
+ optimizer,
+ loss,
+ metrics=[metrics_module.CategoricalAccuracy(), 'mae'],
+ loss_weights=loss_weights)
+
+ data_list = ([{
+ 'a': np.arange(10),
+ 'b': np.arange(10)
+ }, {
+ 'b': np.arange(10),
+ 'c': np.arange(10)
+ }], np.arange(10, 100))
+ print(model.fit(*data_list, epochs=1))
+
+ data_bloated_list = ([{
+ 'a': np.arange(10),
+ 'b': np.arange(10),
+ 'c': np.arange(10)
+ }, {
+ 'a': np.arange(10),
+ 'b': np.arange(10),
+ 'c': np.arange(10)
+ }], np.arange(10, 100))
+ print(model.fit(*data_bloated_list, epochs=1))
+
+ data_dict = ({
+ 'fc1': {
+ 'a': np.arange(10),
+ 'b': np.arange(10)
+ },
+ 'fc2': {
+ 'b': np.arange(10),
+ 'c': np.arange(10)
+ }
+ }, np.arange(10, 100))
+ print(model.fit(*data_dict, epochs=1))
+
+ data_bloated_dict = ({
+ 'fc1': {
+ 'a': np.arange(10),
+ 'b': np.arange(10),
+ 'c': np.arange(10)
+ },
+ 'fc2': {
+ 'a': np.arange(10),
+ 'b': np.arange(10),
+ 'c': np.arange(10)
+ }
+ }, np.arange(10, 100))
+ print(model.fit(*data_bloated_dict, epochs=1))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 708fa1c807..5ef8d13487 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -394,10 +394,10 @@ class Network(base_layer.Layer):
no_dependency = isinstance(value, data_structures.NoDependency)
value = data_structures.sticky_attribute_assignment(
checkpointable=self, value=value, name=name)
- if isinstance(value, (
- base_layer.Layer,
- Network,
- data_structures.CheckpointableDataStructure)):
+ if (isinstance(value, (base_layer.Layer,
+ Network,
+ data_structures.CheckpointableDataStructure))
+ or checkpointable_layer_utils.has_weights(value)):
try:
is_graph_network = self._is_graph_network
except AttributeError:
@@ -689,14 +689,14 @@ class Network(base_layer.Layer):
def trainable_weights(self):
return checkpointable_layer_utils.gather_trainable_weights(
trainable=self.trainable,
- sub_layers=self.layers,
+ sub_layers=self._layers,
extra_variables=self._extra_variables)
@property
def non_trainable_weights(self):
return checkpointable_layer_utils.gather_non_trainable_weights(
trainable=self.trainable,
- sub_layers=self.layers,
+ sub_layers=self._layers,
extra_variables=self._extra_variables)
@property
@@ -770,7 +770,7 @@ class Network(base_layer.Layer):
# and graph building, the variables created after building the model in
# a Graph are still valid when executing eagerly.
with context.graph_mode():
- graph = eager_function.CapturingGraph()
+ graph = eager_function.FuncGraph('graph')
with graph.as_default():
if isinstance(input_shape, list):
x = [base_layer.generate_placeholders_from_shape(shape)
@@ -1355,7 +1355,9 @@ class Network(base_layer.Layer):
```
"""
if not self._is_graph_network:
- raise NotImplementedError
+ raise NotImplementedError(
+ 'Currently `save` requires the model to be a graph network. Consider '
+ 'using `save_weights` in order to save the weights of the model.')
from tensorflow.python.keras.models import save_model # pylint: disable=g-import-not-at-top
save_model(self, filepath, overwrite, include_optimizer)
@@ -1574,7 +1576,10 @@ class Network(base_layer.Layer):
def get_json_type(obj):
# If obj is any numpy type
if type(obj).__module__ == np.__name__:
- return obj.item()
+ if isinstance(obj, np.ndarray):
+ return obj.tolist()
+ else:
+ return obj.item()
# If obj is a python 'type'
if type(obj).__name__ == type.__name__:
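The `get_json_type` fix above can be sanity-checked in isolation; this standalone sketch reproduces just the numpy branch:

```
# Standalone sketch of the numpy branch: scalars keep using .item(),
# ndarrays now serialize via .tolist() so nested values stay JSON-friendly.
import json
import numpy as np

def get_json_type(obj):
  if type(obj).__module__ == np.__name__:
    if isinstance(obj, np.ndarray):
      return obj.tolist()
    return obj.item()
  raise TypeError('Not JSON serializable: %r' % obj)

print(json.dumps({'w': np.ones((2, 2)), 's': np.float32(0.5)},
                 default=get_json_type))
# -> {"w": [[1.0, 1.0], [1.0, 1.0]], "s": 0.5}
```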
diff --git a/tensorflow/python/keras/engine/saving.py b/tensorflow/python/keras/engine/saving.py
index a2eed7cb46..a2f31fda8f 100644
--- a/tensorflow/python/keras/engine/saving.py
+++ b/tensorflow/python/keras/engine/saving.py
@@ -248,7 +248,7 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=
loss = convert_custom_objects(training_config['loss'])
metrics = convert_custom_objects(training_config['metrics'])
weighted_metrics = convert_custom_objects(
- training_config['weighted_metrics'])
+ training_config.get('weighted_metrics', None))
sample_weight_mode = training_config['sample_weight_mode']
loss_weights = training_config['loss_weights']
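The switch to `training_config.get('weighted_metrics', None)` makes loading tolerant of HDF5 files written before `weighted_metrics` was part of the saved config; a minimal illustration with an assumed old-format dict:

```
# Config as written by an older Keras version: no 'weighted_metrics' key.
training_config = {'loss': 'mse', 'metrics': ['acc'],
                   'sample_weight_mode': None, 'loss_weights': None}
# Old code: training_config['weighted_metrics'] -> KeyError at load time.
# New code degrades gracefully to None:
weighted_metrics = training_config.get('weighted_metrics', None)
assert weighted_metrics is None
```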
diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py
index b7c2e9cb53..148dd23be7 100644
--- a/tensorflow/python/keras/engine/saving_test.py
+++ b/tensorflow/python/keras/engine/saving_test.py
@@ -48,7 +48,7 @@ except ImportError:
class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
def test_weight_loading(self):
- with self.test_session():
+ with self.cached_session():
a = keras.layers.Input(shape=(2,))
x = keras.layers.Dense(3)(a)
b = keras.layers.Dense(1)(x)
@@ -208,7 +208,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
}))
def test_preprocess_weights_for_loading_rnn_should_be_idempotent(
self, layer_class, layer_args):
- with self.test_session():
+ with self.cached_session():
layer = layer_class(**layer_args)
layer.build(input_shape=layer_args.get('input_shape'))
weights1 = layer.get_weights()
@@ -232,7 +232,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
batch_size = 5
num_classes = 2
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
model.add(keras.layers.Dense(num_classes))
@@ -261,7 +261,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
num_hidden = 5
input_dim = 3
num_classes = 2
- with self.test_session():
+ with self.cached_session():
ref_model = keras.models.Sequential()
ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim,
name='d1'))
@@ -298,7 +298,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
num_hidden = 5
input_dim = 3
num_classes = 2
- with self.test_session():
+ with self.cached_session():
ref_model = keras.models.Sequential()
ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim,
name='d1'))
@@ -333,7 +333,7 @@ class TestWholeModelSaving(test.TestCase):
if h5py is None:
self.skipTest('h5py required to run this test')
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.RepeatVector(3))
@@ -378,7 +378,7 @@ class TestWholeModelSaving(test.TestCase):
if h5py is None:
self.skipTest('h5py required to run this test')
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.RepeatVector(3))
@@ -402,7 +402,7 @@ class TestWholeModelSaving(test.TestCase):
if h5py is None:
self.skipTest('h5py required to run this test')
- with self.test_session():
+ with self.cached_session():
# test with custom optimizer, loss
class CustomOp(keras.optimizers.RMSprop):
@@ -438,7 +438,7 @@ class TestWholeModelSaving(test.TestCase):
if h5py is None:
self.skipTest('h5py required to run this test')
- with self.test_session():
+ with self.cached_session():
inputs = keras.layers.Input(shape=(3,))
x = keras.layers.Dense(2)(inputs)
output = keras.layers.Dense(3)(x)
@@ -474,7 +474,7 @@ class TestWholeModelSaving(test.TestCase):
if h5py is None:
self.skipTest('h5py required to run this test')
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.Dense(3))
@@ -490,7 +490,7 @@ class TestWholeModelSaving(test.TestCase):
if h5py is None:
self.skipTest('h5py required to run this test')
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.Dense(3))
@@ -508,7 +508,7 @@ class TestWholeModelSaving(test.TestCase):
if h5py is None:
self.skipTest('h5py required to run this test')
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.Dense(3))
@@ -522,7 +522,7 @@ class TestWholeModelSaving(test.TestCase):
os.remove(fname)
def test_saving_lambda_numpy_array_arguments(self):
- with self.test_session():
+ with self.cached_session():
if h5py is None:
self.skipTest('h5py required to run this test')
@@ -548,7 +548,7 @@ class TestWholeModelSaving(test.TestCase):
if h5py is None:
self.skipTest('h5py required to run this test')
- with self.test_session():
+ with self.cached_session():
# This layer name will make the `layers_name` HDF5 attribute blow
# out of proportion. Note that it fits into the internal HDF5
# attribute memory limit on its own but because h5py converts
@@ -589,7 +589,7 @@ class TestWholeModelSaving(test.TestCase):
if h5py is None:
self.skipTest('h5py required to run this test')
- with self.test_session():
+ with self.cached_session():
x = keras.Input(shape=(2,), name='nested_model_input')
f = x
for i in range(4):
@@ -634,7 +634,7 @@ class TestWholeModelSaving(test.TestCase):
if h5py is None:
self.skipTest('h5py required to run this test')
- with self.test_session():
+ with self.cached_session():
inputs = keras.Input(shape=(3,))
x = keras.layers.Dense(2)(inputs)
outputs = keras.layers.Dense(3)(x)
@@ -687,7 +687,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
def test_keras_optimizer_warning(self):
graph = ops.Graph()
- with graph.as_default(), self.test_session(graph):
+ with graph.as_default(), self.session(graph):
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.Dense(3))
@@ -703,7 +703,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_tensorflow_format_overwrite(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
model = SubclassedModel()
temp_dir = self.get_temp_dir()
prefix = os.path.join(temp_dir, 'ckpt')
@@ -741,7 +741,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
def test_no_graph_pollution(self):
with context.graph_mode():
graph = ops.Graph()
- with graph.as_default(), self.test_session(graph) as session:
+ with graph.as_default(), self.session(graph) as session:
model = SubclassedModel()
temp_dir = self.get_temp_dir()
prefix = os.path.join(temp_dir, 'ckpt')
@@ -760,7 +760,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
self.assertEqual(len(graph.get_operations()), op_count)
def _weight_loading_test_template(self, make_model_fn):
- with self.test_session():
+ with self.cached_session():
model = make_model_fn()
model.compile(
loss='mse',
@@ -822,7 +822,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
def _new_layer_weight_loading_test_template(
self, first_model_fn, second_model_fn, restore_init_fn):
- with self.test_session() as session:
+ with self.cached_session() as session:
model = first_model_fn()
temp_dir = self.get_temp_dir()
prefix = os.path.join(temp_dir, 'ckpt')
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py
index 415b15fde1..9f4019e29c 100644
--- a/tensorflow/python/keras/engine/sequential.py
+++ b/tensorflow/python/keras/engine/sequential.py
@@ -239,9 +239,9 @@ class Sequential(Model):
x = inputs
for layer in self.layers:
kwargs = {}
- if 'mask' in tf_inspect.getargspec(layer.call).args:
+ if 'mask' in tf_inspect.getfullargspec(layer.call).args:
kwargs['mask'] = mask
- if 'training' in tf_inspect.getargspec(layer.call).args:
+ if 'training' in tf_inspect.getfullargspec(layer.call).args:
kwargs['training'] = training
if isinstance(layer, Network) and layer._compute_output_and_mask_jointly:
@@ -332,6 +332,7 @@ class Sequential(Model):
else:
name = None
build_input_shape = None
+ layer_configs = config
model = cls(name=name)
for layer_config in layer_configs:
layer = layer_module.deserialize(layer_config,
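On the `getargspec` to `getfullargspec` switch above: `getargspec` is deprecated in Python 3 and raises on callables with keyword-only arguments, while `getfullargspec` reports them. A stdlib-only sketch (`tf_inspect` mirrors this behavior):

```
import inspect

def call(inputs, training=None, *, mask=None):  # `mask` is keyword-only
  return inputs

spec = inspect.getfullargspec(call)
print('training' in spec.args)    # True
print('mask' in spec.kwonlyargs)  # True
# inspect.getargspec(call) would raise here because of the
# keyword-only `mask` parameter.
```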
diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py
index 3f8e120df0..9d615c9b0c 100644
--- a/tensorflow/python/keras/engine/sequential_test.py
+++ b/tensorflow/python/keras/engine/sequential_test.py
@@ -25,22 +25,12 @@ from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import function
from tensorflow.python.framework import test_util as tf_test_util
+from tensorflow.python.keras import testing_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.training import rmsprop
-def _get_small_mlp(num_hidden, num_classes, input_dim=None):
- model = keras.models.Sequential()
- if input_dim:
- model.add(keras.layers.Dense(num_hidden, activation='relu',
- input_dim=input_dim))
- else:
- model.add(keras.layers.Dense(num_hidden, activation='relu'))
- model.add(keras.layers.Dense(num_classes, activation='softmax'))
- return model
-
-
class TestSequential(test.TestCase, parameterized.TestCase):
"""Most Sequential model API tests are covered in `training_test.py`.
"""
@@ -63,7 +53,8 @@ class TestSequential(test.TestCase, parameterized.TestCase):
batch_size = 5
num_classes = 2
- model = _get_small_mlp(num_hidden, num_classes, input_dim)
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden, num_classes, input_dim)
model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3))
x = np.random.random((batch_size, input_dim))
y = np.random.random((batch_size, num_classes))
@@ -94,7 +85,7 @@ class TestSequential(test.TestCase, parameterized.TestCase):
batch_size = 5
num_classes = 2
- model = _get_small_mlp(num_hidden, num_classes)
+ model = testing_utils.get_small_sequential_mlp(num_hidden, num_classes)
model.compile(
loss='mse',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
@@ -118,7 +109,7 @@ class TestSequential(test.TestCase, parameterized.TestCase):
num_samples = 50
steps_per_epoch = 10
- model = _get_small_mlp(num_hidden, num_classes)
+ model = testing_utils.get_small_sequential_mlp(num_hidden, num_classes)
model.compile(
loss='mse',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
@@ -141,13 +132,13 @@ class TestSequential(test.TestCase, parameterized.TestCase):
@parameterized.parameters((True,), (False,))
def test_training_and_eval_methods_on_symbolic_tensors(self, deferred):
- with self.test_session():
+ with self.cached_session():
def get_model():
if deferred:
- model = _get_small_mlp(10, 4)
+ model = testing_utils.get_small_sequential_mlp(10, 4)
else:
- model = _get_small_mlp(10, 4, input_dim=3)
+ model = testing_utils.get_small_sequential_mlp(10, 4, input_dim=3)
model.compile(
optimizer=rmsprop.RMSPropOptimizer(1e-3),
loss='categorical_crossentropy',
@@ -231,7 +222,7 @@ class TestSequential(test.TestCase, parameterized.TestCase):
val_a = np.random.random((10, 4))
val_out = np.random.random((10, 4))
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.BatchNormalization(input_shape=(4,)))
assert model.updates
@@ -262,7 +253,7 @@ class TestSequential(test.TestCase, parameterized.TestCase):
batch_size = 5
num_classes = 2
- model = _get_small_mlp(num_hidden, num_classes)
+ model = testing_utils.get_small_sequential_mlp(num_hidden, num_classes)
model.compile(
loss='mse',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
@@ -284,21 +275,21 @@ class TestSequential(test.TestCase, parameterized.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_sequential_shape_inference_deferred(self):
- model = _get_small_mlp(4, 5)
+ model = testing_utils.get_small_sequential_mlp(4, 5)
output_shape = model.compute_output_shape((None, 7))
self.assertEqual(tuple(output_shape.as_list()), (None, 5))
@tf_test_util.run_in_graph_and_eager_modes
def test_sequential_build_deferred(self):
- model = _get_small_mlp(4, 5)
+ model = testing_utils.get_small_sequential_mlp(4, 5)
model.build((None, 10))
self.assertTrue(model.built)
self.assertEqual(len(model.weights), 4)
# Test with nested model
- model = _get_small_mlp(4, 3)
- inner_model = _get_small_mlp(4, 5)
+ model = testing_utils.get_small_sequential_mlp(4, 3)
+ inner_model = testing_utils.get_small_sequential_mlp(4, 5)
model.add(inner_model)
model.build((None, 10))
@@ -308,8 +299,8 @@ class TestSequential(test.TestCase, parameterized.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_sequential_nesting(self):
- model = _get_small_mlp(4, 3)
- inner_model = _get_small_mlp(4, 5)
+ model = testing_utils.get_small_sequential_mlp(4, 3)
+ inner_model = testing_utils.get_small_sequential_mlp(4, 5)
model.add(inner_model)
model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3))
@@ -353,7 +344,7 @@ class TestSequentialEagerIntegration(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_build_before_fit(self):
# Fix for b/112433577
- model = _get_small_mlp(4, 5)
+ model = testing_utils.get_small_sequential_mlp(4, 5)
model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3))
model.build((None, 6))
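The deleted local `_get_small_mlp` was promoted to a shared helper; based on the removed lines above, `testing_utils.get_small_sequential_mlp` presumably looks like this (the canonical version lives in tensorflow/python/keras/testing_utils.py):

```
from tensorflow.python import keras

def get_small_sequential_mlp(num_hidden, num_classes, input_dim=None):
  model = keras.models.Sequential()
  if input_dim:
    model.add(keras.layers.Dense(num_hidden, activation='relu',
                                 input_dim=input_dim))
  else:
    model.add(keras.layers.Dense(num_hidden, activation='relu'))
  model.add(keras.layers.Dense(num_classes, activation='softmax'))
  return model
```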
diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py
index 079c8dae71..061db8ee34 100644
--- a/tensorflow/python/keras/engine/topology_test.py
+++ b/tensorflow/python/keras/engine/topology_test.py
@@ -342,7 +342,7 @@ class TopologyConstructionTest(test.TestCase):
self.assertListEqual(model.non_trainable_weights, weights)
def test_learning_phase(self):
- with self.test_session():
+ with self.cached_session():
a = keras.layers.Input(shape=(32,), name='input_a')
b = keras.layers.Input(shape=(32,), name='input_b')
@@ -458,7 +458,7 @@ class TopologyConstructionTest(test.TestCase):
self.assertEqual(dense.get_output_mask_at(1), None)
def test_multi_input_layer(self):
- with self.test_session():
+ with self.cached_session():
# test multi-input layer
a = keras.layers.Input(shape=(32,), name='input_a')
b = keras.layers.Input(shape=(32,), name='input_b')
@@ -530,7 +530,7 @@ class TopologyConstructionTest(test.TestCase):
self.assertListEqual([x.shape for x in fn_outputs], [(10, 64), (10, 5)])
def test_recursion(self):
- with self.test_session():
+ with self.cached_session():
a = keras.layers.Input(shape=(32,), name='input_a')
b = keras.layers.Input(shape=(32,), name='input_b')
@@ -591,7 +591,7 @@ class TopologyConstructionTest(test.TestCase):
self.assertListEqual([x.shape for x in fn_outputs], [(10, 7), (10, 64)])
def test_multi_input_multi_output_recursion(self):
- with self.test_session():
+ with self.cached_session():
# test multi-input multi-output
a = keras.layers.Input(shape=(32,), name='input_a')
b = keras.layers.Input(shape=(32,), name='input_b')
@@ -816,7 +816,7 @@ class TopologyConstructionTest(test.TestCase):
self.assertEqual(loss, 4.)
def test_layer_sharing_at_heterogenous_depth(self):
- with self.test_session():
+ with self.cached_session():
x_val = np.random.random((10, 5))
x = input_layer_lib.Input(shape=(5,))
@@ -837,7 +837,7 @@ class TopologyConstructionTest(test.TestCase):
self.assertAllClose(output_val, output_val_2, atol=1e-6)
def test_layer_sharing_at_heterogenous_depth_with_concat(self):
- with self.test_session():
+ with self.cached_session():
input_shape = (16, 9, 3)
input_layer = input_layer_lib.Input(shape=input_shape)
@@ -864,7 +864,7 @@ class TopologyConstructionTest(test.TestCase):
self.assertAllClose(output_val, output_val_2, atol=1e-6)
def test_explicit_training_argument(self):
- with self.test_session():
+ with self.cached_session():
a = keras.layers.Input(shape=(2,))
b = keras.layers.Dropout(0.5)(a)
base_model = keras.models.Model(a, b)
@@ -887,7 +887,8 @@ class TopologyConstructionTest(test.TestCase):
def test_multi_output_model_with_none_masking(self):
- with self.test_session():
+ with self.cached_session():
+
def func(x):
return [x * 0.2, x * 0.3]
@@ -912,6 +913,23 @@ class TopologyConstructionTest(test.TestCase):
assert out.shape == (4, 3, 2, 1)
self.assertAllClose(out, x * 0.2 + x * 0.3, atol=1e-4)
+ def test_constant_initializer_with_numpy(self):
+
+ with self.cached_session():
+ initializer = keras.initializers.Constant(np.ones((3, 2)))
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,),
+ kernel_initializer=initializer))
+ model.add(keras.layers.Dense(3))
+ model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
+
+ json_str = model.to_json()
+ keras.models.model_from_json(json_str)
+
+ if yaml is not None:
+ yaml_str = model.to_yaml()
+ keras.models.model_from_yaml(yaml_str)
+
class DeferredModeTest(test.TestCase):
@@ -1169,7 +1187,7 @@ class GraphUtilsTest(test.TestCase):
def testGetReachableFromInputs(self):
- with self.test_session():
+ with self.cached_session():
pl_1 = array_ops.placeholder(shape=None, dtype='float32')
pl_2 = array_ops.placeholder(shape=None, dtype='float32')
pl_3 = array_ops.placeholder(shape=None, dtype='float32')
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index f71388cadb..49b25e307e 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -20,9 +20,11 @@ from __future__ import print_function
import weakref
import numpy as np
+import six
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.ops.dataset_ops import Dataset
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -45,6 +47,7 @@ from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@@ -405,20 +408,9 @@ class Model(Network):
# Set DistributionStrategy specific parameters.
self._distribution_strategy = distribute
if self._distribution_strategy is not None:
- self._grouped_model = self._compile_distributed_model(
+ self._grouped_model = None
+ distributed_training_utils.configure_and_create_session(
self._distribution_strategy)
- with self._distribution_strategy.scope():
- first_replicated_model = self._distribution_strategy.unwrap(
- self._grouped_model)[0]
- # If the specified metrics in `compile` are stateful, raise an error
- # since we currently don't support stateful metrics.
- if first_replicated_model.stateful_metric_names:
- raise NotImplementedError('Stateful metrics are not supported with '
- 'DistributionStrategy.')
-
- # We initialize the callback model with the first replicated model.
- self._replicated_model = DistributedCallbackModel(first_replicated_model)
- self._replicated_model.set_original_model(self)
if not self.built:
# Model is not compilable because it does not know its number of inputs
# and outputs, nor their shapes and names. We will compile after the first
@@ -636,6 +628,12 @@ class Model(Network):
skip_target_indices=skip_target_indices,
sample_weights=self.sample_weights)
+ # If using distribution strategy and stateful_metrics, raise an error
+ # since we currently don't support stateful metrics.
+ if self._distribution_strategy is not None and self.stateful_metric_names:
+ raise NotImplementedError('Stateful metrics are not supported with '
+ 'DistributionStrategy.')
+
# Prepare gradient updates and state updates.
self.total_loss = total_loss
@@ -652,19 +650,6 @@ class Model(Network):
trainable_weights = self.trainable_weights
self._collected_trainable_weights = trainable_weights
- def _compile_distributed_model(self, distribution_strategy):
- # TODO(anjalisridhar): Can we move the clone_and_build_model to outside the
- # model?
- def _clone_model_per_tower(model):
- new_model = training_distributed.clone_and_build_model(model)
- return new_model
-
- with distribution_strategy.scope():
- # Create a copy of this model on each of the devices.
- grouped_models = distribution_strategy.call_for_each_tower(
- _clone_model_per_tower, self)
- return grouped_models
-
def _check_trainable_weights_consistency(self):
"""Check trainable weights count consistency.
@@ -771,9 +756,8 @@ class Model(Network):
the model.
Args:
- x: Input data. A `tf.data` dataset.
- y: Since `x` is a dataset, `y` should not be specified
- (since targets will be obtained from the iterator).
+ x: Input data. A numpy array or `tf.data` dataset.
+ y: Target data. A numpy array or None if x is a `tf.data` dataset.
sample_weight: An optional sample-weight array passed by the user to
weight the importance of each sample in `x`.
class_weight: An optional class-weight array by the user to
@@ -790,28 +774,64 @@ class Model(Network):
Fraction of the training data to be used as validation data.
Returns:
- A tuple of 3 lists: input arrays, target arrays, sample-weight arrays.
- If the model's input and targets are symbolic, these lists are empty
- (since the model takes no user-provided data, instead the data comes
- from the symbolic inputs/targets).
+ Iterator for reading the dataset `x`.
Raises:
ValueError: In case of invalid user-provided data.
RuntimeError: If the model was never compiled.
"""
if sample_weight is not None and sample_weight.all():
- raise NotImplementedError('sample_weight is currently not supported when '
- 'using DistributionStrategy.')
+ raise NotImplementedError('`sample_weight` is currently not supported '
+ 'when using DistributionStrategy.')
if class_weight:
- raise NotImplementedError('class_weight is currently not supported when '
- 'using DistributionStrategy.')
+ raise NotImplementedError('`class_weight` is currently not supported '
+ 'when using DistributionStrategy.')
+
+ # Validates `steps` argument right at the beginning since we use it to
+ # construct the dataset object.
+ # TODO(anjalisridhar): This may not be a valid error since we now accept
+ # numpy array inputs. We still want to assert that we have a populated steps
+ # parameter.
+ if check_steps:
+ if steps is None:
+ raise ValueError('When using DistributionStrategy, '
+ 'you should specify the `{steps_name}` argument.'
+ .format(steps_name=steps_name))
+
+ first_x_value = nest.flatten(x)[0]
+ if isinstance(first_x_value, np.ndarray):
+ x_shape = first_x_value.shape
+ x_dtype = first_x_value.dtype
+ if batch_size is None:
+ batch_size = x_shape[0] // steps
+ if y is not None:
+ first_y_value = nest.flatten(y)[0]
+ x = Dataset.from_generator(lambda x=x, y=y: six.moves.zip(x, y),
+ output_types=(x_dtype, first_y_value.dtype),
+ output_shapes=(x_shape[1:],
+ first_y_value.shape[1:]))
+ # TODO(anjalisridhar): What should the buffer size be?
+ x = x.shuffle(10000)
+ x = x.repeat()
+ x = x.batch(batch_size)
+ y = None
+ else:
+ # This case is for the predict call where the dataset only contains
+ # inputs and no targets, i.e. it does not return a tuple.
+ # TODO(anjalisridhar): Raise an error if we are not able to process
+ # all the predict samples. This can happen if the number of batches is
+ # not evenly divisible by the number of worker devices.
+ x = Dataset.from_generator(lambda x=x: x,
+ output_types=x_dtype,
+ output_shapes=x_shape[1:])
+ x = x.repeat()
+ x = x.batch(batch_size)
# TODO(anjalisridhar): Can we use the iterator and getnext op cache?
# We require users to pass Datasets since we distribute the dataset across
# multiple devices.
- if not isinstance(x, dataset_ops.Dataset):
- raise ValueError('When using DistributionStrategy you must specify a '
- 'Dataset object instead of a %s.' % type(x))
+ assert isinstance(x, dataset_ops.Dataset)
+
# TODO(anjalisridhar): We want distribute_dataset() to accept a Dataset or a
# function which returns a Dataset. Currently distribute_dataset() only
# accepts a function that returns a Dataset. Once we add support for being
@@ -819,38 +839,10 @@ class Model(Network):
result = self._distribution_strategy.distribute_dataset(lambda: x)
iterator = result.make_initializable_iterator()
K.get_session().run(iterator.initializer)
- # Validates `steps` argument based on x's type.
- if check_steps:
- if steps is None:
- raise ValueError('When using a Dataset instance as input to a model, '
- 'you should specify the `{steps_name}` argument.'
- .format(steps_name=steps_name))
training_utils.validate_iterator_input(x, y, sample_weight,
validation_split)
- # x an y may be PerDevice objects with an input and output tensor
- # corresponding to each device. For example, x could be
- # PerDevice:{device: get_next tensor,...}.
- next_element = iterator.get_next()
-
- if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
- raise ValueError('Please provide data as a list or tuple of 2 elements '
- ' - input and target pair. Received %s' % next_element)
- x, y = next_element
- # Validate that all the elements in x and y are of the same type and shape.
- # We can then pass the first element of x and y to `_standardize_weights`
- # below and be confident of the output. We need to reopen the scope since
- # we unwrap values when we validate x and y.
- with self._distribution_strategy.scope():
- x_values, y_values = distributed_training_utils.\
- validate_distributed_dataset_inputs(self._distribution_strategy, x, y)
-
- _, _, sample_weights = self._standardize_weights(x_values,
- y_values,
- sample_weight,
- class_weight,
- batch_size)
- return x, y, sample_weights
+ return iterator
def _standardize_user_data(self,
x,
@@ -905,7 +897,8 @@ class Model(Network):
Fraction of the training data to be used as validation data.
Returns:
- A tuple of 3 lists: input arrays, target arrays, sample-weight arrays.
+ A tuple of 3: inputs (arrays or dicts, depending on whether `x` was a dict
+ or not), target arrays, sample-weight arrays.
If the model's input and targets are symbolic, these lists are empty
(since the model takes no user-provided data, instead the data comes
from the symbolic inputs/targets).
@@ -915,7 +908,7 @@ class Model(Network):
RuntimeError: If the model was never compiled.
"""
if self._distribution_strategy:
- return self._distribution_standardize_user_data(
+ iterator = self._distribution_standardize_user_data(
x,
y,
sample_weight=sample_weight,
@@ -925,6 +918,7 @@ class Model(Network):
steps_name=steps_name,
steps=steps,
validation_split=validation_split)
+ return iterator, None, None
if isinstance(x, dataset_ops.Dataset):
if context.executing_eagerly():
@@ -970,20 +964,32 @@ class Model(Network):
'Make sure that your dataset can generate '
'required number of samples.')
- if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
- raise ValueError('Please provide data as a list or tuple of 2 elements '
- ' - input and target pair. Received %s' % next_element)
- x, y = next_element
+ if (not isinstance(next_element, (list, tuple)) or
+ len(next_element) not in [2, 3]):
+ raise ValueError(
+ 'Please provide model inputs as a list or tuple of 2 or 3 '
+ 'elements: (input, target) or (input, target, sample_weights). '
+ 'Received %s' % next_element)
+ if len(next_element) == 2:
+ x, y = next_element
+ else:
+ x, y, sample_weight = next_element
x, y, sample_weights = self._standardize_weights(x, y, sample_weight,
class_weight, batch_size)
return x, y, sample_weights
def _standardize_weights(self, x, y, sample_weight=None, class_weight=None,
batch_size=None,):
+ # TODO(sourabhbajaj): Split input validation from weight standardization.
+ if sample_weight is not None and class_weight is not None:
+ logging.warning(
+ 'Received both a `sample_weight` and `class_weight` argument. '
+ 'The `class_weight` argument will be ignored.')
# First, we build/compile the model on the fly if necessary.
all_inputs = []
is_build_called = False
is_compile_called = False
+ dict_inputs = False
if not self.inputs:
# We need to use `x` to set the model inputs.
# We type-check that `x` and `y` are either single arrays
@@ -995,7 +1001,9 @@ class Model(Network):
'array or a list of arrays. You passed: x=' + str(x))
all_inputs += list(x)
elif isinstance(x, dict):
- raise ValueError('Please do not pass a dictionary as model inputs.')
+ dict_inputs = True
+ keys = sorted(x.keys())
+ all_inputs = [x[k] for k in keys]
else:
if not isinstance(x, np.ndarray) and not tensor_util.is_tensor(x):
raise ValueError('Please provide as model inputs either a single '
@@ -1008,6 +1016,8 @@ class Model(Network):
if not self.inputs:
is_build_called = True
self._set_inputs(x)
+ else:
+ dict_inputs = isinstance(self.inputs, dict)
if y is not None:
if not self.optimizer:
@@ -1160,6 +1170,10 @@ class Model(Network):
'a number of samples that can be '
'divided by the batch size. Found: ' +
str(x[0].shape[0]) + ' samples')
+
+ # If dictionary inputs were provided, we return a dictionary as well.
+ if dict_inputs:
+ x = dict(zip(feed_input_names, x))
return x, y, sample_weights
@checkpointable.no_automatic_dependency_tracking
@@ -1182,6 +1196,9 @@ class Model(Network):
training: Boolean or None. Only relevant in symbolic mode. Specifies
whether to build the model's graph in inference mode (False), training
mode (True), or using the Keras learning phase (None).
+ Raises:
+ ValueError: If dict inputs are passed to a Sequential Model where the
+ first layer isn't FeatureLayer.
"""
call_convention = getattr(
self,
@@ -1198,6 +1215,14 @@ class Model(Network):
if tensor_util.is_tensor(inputs):
input_shape = (None,) + tuple(inputs.get_shape().as_list()[1:])
self.build(input_shape=input_shape)
+ elif isinstance(inputs, dict):
+ # We assert that the first layer is a FeatureLayer.
+ if not training_utils.is_feature_layer(self.layers[0]):
+ raise ValueError('Passing a dictionary input to a Sequential Model '
+ 'which does not have FeatureLayer as the first layer '
+ 'is an error.')
+ input_shape = (None,)
+ self.build(input_shape=input_shape)
else:
input_shape = (None,) + inputs.shape[1:]
self.build(input_shape=input_shape)
@@ -1225,36 +1250,22 @@ class Model(Network):
assert context.executing_eagerly()
if self.inputs:
raise ValueError('Model inputs are already set.')
+
# On-the-fly setting of model inputs/outputs as DeferredTensors,
# to keep track of number of inputs and outputs and their ndim.
- if isinstance(inputs, (list, tuple)):
- if tensor_util.is_tensor(inputs[0]):
- dummy_output_values = self.call(
- training_utils.cast_if_floating_dtype(inputs))
- else:
- dummy_output_values = self.call(
- [ops.convert_to_tensor(v, dtype=K.floatx()) for v in inputs])
- dummy_input_values = list(inputs)
- else:
- if tensor_util.is_tensor(inputs):
- dummy_output_values = self.call(
- training_utils.cast_if_floating_dtype(inputs))
- else:
- dummy_output_values = self.call(
- ops.convert_to_tensor(inputs, dtype=K.floatx()))
- dummy_input_values = [inputs]
- if isinstance(dummy_output_values, (list, tuple)):
- dummy_output_values = list(dummy_output_values)
- else:
- dummy_output_values = [dummy_output_values]
+ model_inputs = training_utils.ModelInputs(inputs)
+ dummy_input_values = model_inputs.get_input_values()
+ dummy_output_values = self.call(dummy_input_values)
+
+ self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True)
+ self.input_names = model_inputs.get_input_names()
+
+ dummy_output_values = nest.flatten(dummy_output_values)
self.outputs = [
- base_layer.DeferredTensor(shape=(None for _ in v.shape),
- dtype=v.dtype) for v in dummy_output_values]
- self.inputs = [
- base_layer.DeferredTensor(shape=(None for _ in v.shape),
- dtype=v.dtype) for v in dummy_input_values]
- self.input_names = [
- 'input_%d' % (i + 1) for i in range(len(dummy_input_values))]
+ base_layer.DeferredTensor(shape=(None
+ for _ in v.shape), dtype=v.dtype)
+ for v in dummy_output_values
+ ]
self.output_names = [
'output_%d' % (i + 1) for i in range(len(dummy_output_values))]
self.built = True
@@ -1284,58 +1295,29 @@ class Model(Network):
# On-the-fly setting of symbolic model inputs (either by using the tensor
# provided, or by creating a placeholder if Numpy data was provided).
- self.inputs = []
- self.input_names = []
+ model_inputs = training_utils.ModelInputs(inputs)
+ dummy_input_values = model_inputs.get_symbolic_inputs()
+ self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True)
+ self.input_names = model_inputs.get_input_names()
+
self._feed_inputs = []
self._feed_input_names = []
self._feed_input_shapes = []
- if isinstance(inputs, (list, tuple)):
- inputs = list(inputs)
- else:
- inputs = [inputs]
-
- for i, v in enumerate(inputs):
- name = 'input_%d' % (i + 1)
- self.input_names.append(name)
- if isinstance(v, list):
- v = np.asarray(v)
- if v.ndim == 1:
- v = np.expand_dims(v, 1)
- if isinstance(v, (np.ndarray)):
- # We fix the placeholder shape except the batch size.
- # This is suboptimal, but it is the best we can do with the info
- # we have. The user should call `model._set_inputs(placeholders)`
- # to specify custom placeholders if the need arises.
- shape = (None,) + v.shape[1:]
- placeholder = K.placeholder(shape=shape, name=name)
- self.inputs.append(placeholder)
- self._feed_inputs.append(placeholder)
- self._feed_input_names.append(name)
- self._feed_input_shapes.append(shape)
- else:
- # Assumed tensor - TODO(fchollet) additional type check?
- self.inputs.append(v)
- if K.is_placeholder(v):
- self._feed_inputs.append(v)
- self._feed_input_names.append(name)
- self._feed_input_shapes.append(K.int_shape(v))
+
+ for k, v in model_inputs.as_dict():
+ if K.is_placeholder(v):
+ self._feed_inputs.append(v)
+ self._feed_input_names.append(k)
+ self._feed_input_shapes.append(K.int_shape(v))
if outputs is None:
# Obtain symbolic outputs by calling the model.
- if len(self.inputs) == 1:
- if self._expects_training_arg:
- outputs = self.call(self.inputs[0], training=training)
- else:
- outputs = self.call(self.inputs[0])
+ if self._expects_training_arg:
+ outputs = self.call(dummy_input_values, training=training)
else:
- if self._expects_training_arg:
- outputs = self.call(self.inputs, training=training)
- else:
- outputs = self.call(self.inputs)
- if isinstance(outputs, (list, tuple)):
- outputs = list(outputs)
- else:
- outputs = [outputs]
+ outputs = self.call(dummy_input_values)
+
+ outputs = nest.flatten(outputs)
self.outputs = outputs
self.output_names = [
'output_%d' % (i + 1) for i in range(len(self.outputs))]
@@ -1367,7 +1349,8 @@ class Model(Network):
(in case the model has multiple inputs).
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- - A `tf.data` dataset or a dataset iterator.
+ - A `tf.data` dataset or a dataset iterator. Should return a tuple
+ of either (inputs, targets) or (inputs, targets, sample_weights).
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
@@ -1432,7 +1415,8 @@ class Model(Network):
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
`sample_weight_mode="temporal"` in `compile()`. This argument is not
- supported when `x` is a dataset or a dataset iterator.
+ supported when `x` is a dataset or a dataset iterator; instead,
+ provide the sample_weights as the third element of `x`.
initial_epoch: Integer.
Epoch at which to start training
(useful for resuming a previous training run).
@@ -1478,6 +1462,13 @@ class Model(Network):
if self._distribution_strategy:
distributed_training_utils.validate_callbacks(callbacks)
+ distributed_training_utils.validate_inputs(x, y)
+
+ first_x_value = nest.flatten(x)[0]
+ if not steps_per_epoch and isinstance(first_x_value, np.ndarray):
+ steps_per_epoch = distributed_training_utils.get_input_batch_params(
+ first_x_value, batch_size, self._distribution_strategy)
+
x, y, sample_weights = self._standardize_user_data(
x,
y,
@@ -1512,6 +1503,13 @@ class Model(Network):
'However we received `validation_data=%s`' % validation_data)
# Validate and standardize validation data.
+ if self._distribution_strategy:
+ distributed_training_utils.validate_inputs(val_x, val_y)
+ first_valx_value = nest.flatten(val_x)[0]
+ if not validation_steps and isinstance(first_valx_value, np.ndarray):
+ validation_steps = distributed_training_utils.get_input_batch_params(
+ first_valx_value, batch_size, self._distribution_strategy)
+
val_x, val_y, val_sample_weights = self._standardize_user_data(
val_x,
val_y,
@@ -1560,12 +1558,11 @@ class Model(Network):
validation_steps=validation_steps)
elif self._distribution_strategy:
return training_distributed.fit_loop(
- self, x, y,
+ self, x,
epochs=epochs,
verbose=verbose,
callbacks=callbacks,
- val_inputs=val_x,
- val_targets=val_y,
+ val_iterator=val_x,
initial_epoch=initial_epoch,
steps_per_epoch=steps_per_epoch,
validation_steps=validation_steps)
@@ -1650,6 +1647,13 @@ class Model(Network):
batch_size = 32
# Validate and standardize user data.
+ if self._distribution_strategy:
+ distributed_training_utils.validate_inputs(x, y)
+ first_x_value = nest.flatten(x)[0]
+ if isinstance(first_x_value, np.ndarray) and not steps:
+ steps = distributed_training_utils.get_input_batch_params(
+ first_x_value, batch_size, self._distribution_strategy)
+
x, y, sample_weights = self._standardize_user_data(
x,
y,
@@ -1671,8 +1675,7 @@ class Model(Network):
elif self._distribution_strategy:
return training_distributed.test_loop(
self,
- inputs=x,
- targets=y,
+ iterator=x,
verbose=verbose,
steps=steps)
else:
@@ -1721,7 +1724,22 @@ class Model(Network):
if batch_size is None and steps is None:
batch_size = 32
+ if self._distribution_strategy:
+ # Turn off prefetching since this is currently not deterministic. Once
+ # b/112498930 is fixed we can turn it back on.
+ # `_prefetch_on_device` is currently a property of only
+ # `MirroredStrategy`.
+ if hasattr(self._distribution_strategy, '_prefetch_on_device'):
+ self._distribution_strategy._prefetch_on_device = False # pylint: disable=protected-access
+ distributed_training_utils.validate_inputs(x, None)
+ first_x_value = nest.flatten(x)[0]
+ if isinstance(first_x_value, np.ndarray) and not steps:
+ steps = distributed_training_utils.get_input_batch_params(
+ first_x_value, batch_size, self._distribution_strategy)
+
# Validate and standardize user data.
+ # TODO(anjalisridhar): We don't pass batch_size here for some reason. This
+    # means that we end up calculating it twice, which we should avoid.
x, _, _ = self._standardize_user_data(
x, check_steps=True, steps_name='steps', steps=steps)
@@ -1729,8 +1747,12 @@ class Model(Network):
return training_eager.predict_loop(
self, x, batch_size=batch_size, verbose=verbose, steps=steps)
elif self._distribution_strategy:
- return training_distributed.predict_loop(
+ results = training_distributed.predict_loop(
self, x, verbose=verbose, steps=steps)
+ # Turn prefetching back on since we turned it off previously.
+ if hasattr(self._distribution_strategy, '_prefetch_on_device'):
+ self._distribution_strategy._prefetch_on_device = True # pylint: disable=protected-access
+ return results
else:
return training_arrays.predict_loop(
self, x, batch_size=batch_size, verbose=verbose, steps=steps)
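
The predict path above disables `_prefetch_on_device` and re-enables it manually after the loop; if `predict_loop` raised, the flag would stay off. A try/finally guard makes the restoration unconditional. A self-contained sketch of that alternative (hypothetical helper, not part of this patch):

class _PrefetchDisabled(object):
  # Context manager: turn `_prefetch_on_device` off, restore it on exit.

  def __init__(self, strategy):
    self._strategy = strategy

  def __enter__(self):
    if hasattr(self._strategy, '_prefetch_on_device'):
      self._strategy._prefetch_on_device = False  # pylint: disable=protected-access
    return self._strategy

  def __exit__(self, *exc_info):
    if hasattr(self._strategy, '_prefetch_on_device'):
      self._strategy._prefetch_on_device = True  # pylint: disable=protected-access
    return False  # never swallow exceptions
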
@@ -2171,6 +2193,13 @@ class Model(Network):
return self.callback_model
return self
+ def _make_callback_model(self):
+ first_replicated_model = self._distribution_strategy.unwrap(
+ self._grouped_model)[0]
+ # We initialize the callback model with the first replicated model.
+ self._replicated_model = DistributedCallbackModel(first_replicated_model)
+ self._replicated_model.set_original_model(self)
+
class DistributedCallbackModel(Model):
"""Model that is used for callbacks with DistributionStrategy."""
@@ -2208,6 +2237,6 @@ class DistributedCallbackModel(Model):
    # Whitelisted attributes of the model that can be accessed by the user
# during a callback.
if item not in ['_setattr_tracking']:
- logging.warning('You are accessing attribute ' + item + 'of the'
- 'DistributedCallbackModel that may not have been set'
+      logging.warning('You are accessing attribute ' + item + ' of the '
+ 'DistributedCallbackModel that may not have been set '
'correctly.')
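
The warning above is emitted from a `__getattr__`-style whitelist: reads of attributes that are not explicitly whitelisted fall through to the wrapped model with a warning. A self-contained sketch of the pattern (toy class, not the actual `DistributedCallbackModel`):

import logging

class _AttrProxy(object):
  # Forwards attribute reads to a wrapped object, warning on anything
  # outside the whitelist.

  _whitelisted = ('inputs', 'outputs', '_setattr_tracking')

  def __init__(self, wrapped):
    self._wrapped = wrapped

  def __getattr__(self, item):
    if item not in self._whitelisted:
      logging.warning('You are accessing attribute %s of the proxy that '
                      'may not have been set correctly.', item)
    return getattr(self._wrapped, item)
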
diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py
index e2c458c65f..95b864bef0 100644
--- a/tensorflow/python/keras/engine/training_arrays.py
+++ b/tensorflow/python/keras/engine/training_arrays.py
@@ -55,7 +55,7 @@ def fit_loop(model,
Arguments:
model: Keras Model instance.
- inputs: List of input arrays.
+ inputs: Either a list of arrays or a dictionary.
targets: List of target arrays.
sample_weights: Optional list of sample weight arrays.
batch_size: Integer batch size or None if unknown.
@@ -88,6 +88,7 @@ def fit_loop(model,
sample_weights = sample_weights or []
val_sample_weights = val_sample_weights or []
+ inputs = training_utils.ModelInputs(inputs).as_list()
if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
ins = inputs + targets + sample_weights + [1]
else:
@@ -262,6 +263,7 @@ def predict_loop(model, inputs, batch_size=32, verbose=0, steps=None):
model._make_predict_function()
f = model.predict_function
+ inputs = training_utils.ModelInputs(inputs).as_list()
if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
ins = inputs + [0]
else:
@@ -368,6 +370,7 @@ def test_loop(model,
f = model.test_function
sample_weights = sample_weights or []
+ inputs = training_utils.ModelInputs(inputs).as_list()
if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
ins = inputs + targets + sample_weights + [0]
else:
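
`training_utils.ModelInputs` (introduced by this patch and used in all three loops above) normalizes a single array, a list, or a dict of named inputs into a flat list, and can also iterate name/value pairs. A toy behavioral sketch under that assumption (not the real class):

import numpy as np

class ModelInputsSketch(object):
  # Toy stand-in for training_utils.ModelInputs.

  def __init__(self, inputs):
    if isinstance(inputs, dict):
      self._names = sorted(inputs)          # deterministic ordering
      self._values = [inputs[k] for k in self._names]
    elif isinstance(inputs, (list, tuple)):
      self._names = ['input_%d' % (i + 1) for i in range(len(inputs))]
      self._values = list(inputs)
    else:
      self._names = ['input_1']
      self._values = [inputs]

  def as_list(self):
    return self._values

  def as_dict(self):                        # yields (name, value) pairs
    return zip(self._names, self._values)

ins = ModelInputsSketch({'a': np.zeros((2, 3)), 'b': np.ones((2, 1))})
assert len(ins.as_list()) == 2
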
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 5feedc43a5..53291c3956 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -19,38 +19,41 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import distributed_training_utils
from tensorflow.python.keras.utils.generic_utils import Progbar
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import distribute as distribute_lib
+
+
+# TODO(priyag, sourabhbajaj): Refactor this file to address code duplication.
def fit_loop(
model,
- inputs,
- targets,
+ iterator,
epochs=100,
verbose=1,
callbacks=None,
- val_inputs=None,
- val_targets=None,
+ val_iterator=None,
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None):
- """fit function when using DistributionStrategy for training.
+ """Fit loop for training with DistributionStrategy.
Arguments:
model: Keras Model instance.
- inputs: List of input arrays.
- targets: List of target arrays.
+ iterator: Iterator for input data.
epochs: Number of times to iterate over the data
- verbose: Verbosity mode, 0, 1 or 2
+      verbose: Integer. Verbosity mode: 0, 1, or 2.
callbacks: List of callbacks to be called during training
- val_inputs: List of input arrays.
- val_targets: List of target arrays.
+ val_iterator: Iterator for validation data.
initial_epoch: Epoch at which to start training
(useful for resuming a previous training run)
steps_per_epoch: Total number of steps (batches of samples)
@@ -67,6 +70,16 @@ def fit_loop(
ValueError: in case of invalid arguments.
"""
current_strategy = model._distribution_strategy
+
+ # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
+ if current_strategy.__class__.__name__ == 'TPUStrategy':
+ return _experimental_fit_loop(
+ model, iterator, epochs, verbose, callbacks, initial_epoch,
+ steps_per_epoch)
+
+ if not model._grouped_model:
+ clone_model_on_towers(model, current_strategy, make_callback_model=True)
+
def _per_device_train_function(model):
model._make_train_function()
return (model.train_function.inputs,
@@ -74,6 +87,7 @@ def fit_loop(
model.train_function.updates_op,
model.train_function.session_kwargs)
+ inputs, targets = _get_input_from_iterator(iterator, model)
with current_strategy.scope():
# Create train ops on each of the devices when we call
# `_per_device_train_function`.
@@ -115,11 +129,6 @@ def fit_loop(
do_validation = False
if validation_steps:
do_validation = True
- if steps_per_epoch is None:
- raise ValueError('Can only use `validation_steps` '
- 'when doing step-wise '
- 'training, i.e. `steps_per_epoch` '
- 'must be set.')
# Copy the weights from the original model to each of the replicated models.
orig_model_weights = model.get_weights()
@@ -139,45 +148,46 @@ def fit_loop(
verbose=verbose)
out_labels = model.metrics_names or []
callbacks.on_train_begin()
+
+ assert steps_per_epoch is not None
+
for epoch in range(initial_epoch, epochs):
callbacks.on_epoch_begin(epoch)
- if steps_per_epoch is not None:
- epoch_logs = {}
- for step_index in range(steps_per_epoch):
- batch_logs = {'batch': step_index, 'size': 1}
- callbacks.on_batch_begin(step_index, batch_logs)
- try:
- outs = distributed_train_function(ins)
- except errors.OutOfRangeError:
- logging.warning('Your dataset iterator ran out of data; '
- 'interrupting training. Make sure that your dataset '
- 'can generate at least `steps_per_epoch * epochs` '
- 'batches (in this case, %d batches).' %
- steps_per_epoch * epochs)
- break
-
- if not isinstance(outs, list):
- outs = [outs]
-
- outs = _aggregate_metrics_across_towers(
- len(current_strategy._devices), out_labels, outs)
- for l, o in zip(out_labels, outs):
- batch_logs[l] = o
- callbacks.on_batch_end(step_index, batch_logs)
- if callbacks.model.stop_training:
- break
- if do_validation:
- val_outs = test_loop(
- model,
- val_inputs,
- val_targets,
- steps=validation_steps,
- verbose=0)
- if not isinstance(val_outs, list):
- val_outs = [val_outs]
- # Same labels assumed.
- for l, o in zip(out_labels, val_outs):
- epoch_logs['val_' + l] = o
+ epoch_logs = {}
+ for step_index in range(steps_per_epoch):
+ batch_logs = {'batch': step_index, 'size': 1}
+ callbacks.on_batch_begin(step_index, batch_logs)
+ try:
+ outs = distributed_train_function(ins)
+ except errors.OutOfRangeError:
+ logging.warning('Your dataset iterator ran out of data; '
+ 'interrupting training. Make sure that your dataset '
+ 'can generate at least `steps_per_epoch * epochs` '
+                        'batches (in this case, %d batches).' %
+                        (steps_per_epoch * epochs))
+ break
+
+ if not isinstance(outs, list):
+ outs = [outs]
+
+ outs = _aggregate_metrics_across_towers(
+ current_strategy.num_towers, out_labels, outs)
+ for l, o in zip(out_labels, outs):
+ batch_logs[l] = o
+ callbacks.on_batch_end(step_index, batch_logs)
+ if callbacks.model.stop_training:
+ break
+ if do_validation:
+ val_outs = test_loop(
+ model,
+ val_iterator,
+ steps=validation_steps,
+ verbose=0)
+ if not isinstance(val_outs, list):
+ val_outs = [val_outs]
+ # Same labels assumed.
+ for l, o in zip(out_labels, val_outs):
+ epoch_logs['val_' + l] = o
callbacks.on_epoch_end(epoch, epoch_logs)
if callbacks.model.stop_training:
@@ -192,14 +202,178 @@ def fit_loop(
return model.history
-def test_loop(model, inputs, targets, verbose=0, steps=None):
- """evaluate method to validate a model that uses DistributionStrategy.
+def _experimental_fit_loop(
+ model,
+ iterator,
+ epochs=100,
+ verbose=1,
+ callbacks=None,
+ initial_epoch=0,
+ steps_per_epoch=None):
+ """Fit loop for training with TPU DistributionStrategy.
Arguments:
model: Keras Model instance.
- inputs: List of input arrays.
- targets: List of target arrays.
- verbose: verbosity mode.
+ iterator: Iterator that returns inputs and targets
+ epochs: Number of times to iterate over the data
+      verbose: Integer. Verbosity mode: 0, 1, or 2.
+ callbacks: List of callbacks to be called during training
+ initial_epoch: Epoch at which to start training
+ (useful for resuming a previous training run)
+ steps_per_epoch: Total number of steps (batches of samples)
+ before declaring one epoch finished and starting the
+ next epoch. Ignored with the default value of `None`.
+
+ Returns:
+      `model.history`.
+
+ Raises:
+ ValueError: in case of invalid arguments.
+ """
+ current_strategy = model._distribution_strategy
+
+ # TODO(priyag): Add validation that shapes are fully defined for TPU case.
+
+ K.get_session().run(current_strategy.initialize())
+
+ def _per_device_train_function(model):
+ model._make_train_function()
+ return (model.train_function.inputs,
+ model.train_function.outputs,
+ model.train_function.updates_op,
+ model.train_function.session_kwargs)
+
+ # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
+ K.set_learning_phase(1)
+
+ def step_fn(ctx, inputs, targets):
+ """Clones the model and calls make_train_function."""
+ # TODO(priyag, sourabhbajaj): The model gets cloned every time
+ # fit/test/predict is called. We should look into caching this keyed on
+ # input shapes.
+ clone_model_on_towers(
+ model,
+ current_strategy,
+ make_callback_model=True,
+ inputs=inputs,
+ targets=targets)
+
+ (grouped_inputs, grouped_outputs, grouped_updates,
+ grouped_session_args) = current_strategy.call_for_each_tower(
+ _per_device_train_function, model._grouped_model)
+ (all_inputs, all_outputs, all_updates,
+ all_session_args) = distributed_training_utils.unwrap_values(
+ current_strategy, grouped_inputs, grouped_outputs,
+ grouped_updates, grouped_session_args)
+ combined_fn = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_train_function',
+ **all_session_args)
+
+ out_labels = model.metrics_names or []
+ for label, output in zip(out_labels, combined_fn.outputs):
+ if label == 'loss':
+ aggregation = distribute_lib.get_loss_reduction()
+ else:
+        # We aggregate all other metrics using mean for now. This is a
+        # temporary workaround until new metrics are in place.
+ aggregation = variable_scope.VariableAggregation.MEAN
+ ctx.set_last_step_output(label, output, aggregation)
+
+ # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
+ # feed_dict, session kwargs, run options, run_metadata for now. These should
+ # be handled appropriately
+ return combined_fn.updates_op
+
+ # Add initial dummy values for loss and other metric tensors.
+ initial_loop_values = {}
+ initial_loop_values['loss'] = constant_op.constant(1e7)
+ for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
+ initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
+
+ with current_strategy.scope():
+ # TODO(priyag, sourabhbajaj): Adjust steps_per_run appropriately based on
+ # steps_per_epoch and number of epochs.
+ ctx = current_strategy.run_steps_on_dataset(
+ step_fn, iterator, iterations=current_strategy.steps_per_run,
+ initial_loop_values=initial_loop_values)
+
+ train_op = ctx.run_op
+ output_tensors = ctx.last_step_outputs
+
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
+ with current_strategy.scope():
+ distributed_model = current_strategy.unwrap(model._grouped_model)[0]
+ distributed_training_utils.set_weights(
+ current_strategy, distributed_model, orig_model_weights)
+
+ assert steps_per_epoch is not None
+
+ # TODO(sourabhbajaj): Convert this into a proper validation function
+ if callbacks:
+ raise NotImplementedError(
+ 'Callbacks are not supported with TPUStrategy right now.')
+
+ callbacks = cbks.configure_callbacks(
+ callbacks,
+ model,
+ do_validation=False,
+ val_inputs=None,
+ val_targets=None,
+ epochs=epochs,
+ steps_per_epoch=steps_per_epoch,
+ verbose=verbose)
+ # TODO(priyag, sourabhbajaj): Add callbacks support for per step callback
+ # TODO(priyag, sourabhbajaj): Fix the number of steps run with steps_per_run
+ # TODO(priyag, sourabhbajaj): Add validation.
+ callbacks.on_train_begin()
+ for epoch in range(initial_epoch, epochs):
+ callbacks.on_epoch_begin(epoch)
+ epoch_logs = {}
+ for step_index in range(0, steps_per_epoch, current_strategy.steps_per_run):
+ # TODO(sourabhbajaj): Replace size with a combination of steps_per_run
+ # and batch_size
+ batch_logs = {'batch': step_index, 'size': 1}
+ callbacks.on_batch_begin(step_index, batch_logs)
+ try:
+ _, outputs = K.get_session().run([train_op, output_tensors])
+ except errors.OutOfRangeError:
+ logging.warning('Your dataset iterator ran out of data; '
+ 'interrupting training. Make sure that your dataset '
+ 'can generate at least `steps_per_epoch * epochs` '
+                        'batches (in this case, %d batches).' %
+                        (steps_per_epoch * epochs))
+ break
+
+ batch_logs.update(outputs)
+ callbacks.on_batch_end(step_index, batch_logs)
+ if callbacks.model.stop_training:
+ break
+
+ callbacks.on_epoch_end(epoch, epoch_logs)
+ if callbacks.model.stop_training:
+ break
+ callbacks.on_train_end()
+
+ # Copy the weights back from the replicated model to the original model.
+ with current_strategy.scope():
+ updated_weights = current_strategy.unwrap(
+ model._grouped_model)[0].get_weights()
+ model.set_weights(updated_weights)
+
+ K.get_session().run(current_strategy.finalize())
+ return model.history
+
+
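+
The epoch loop above advances `step_index` by `current_strategy.steps_per_run`, and each `session.run` executes that many steps, so an epoch overshoots whenever `steps_per_epoch` is not a multiple of `steps_per_run`; this is what the `steps_per_run` TODO above refers to. A worked example:

steps_per_epoch = 100
steps_per_run = 8   # illustrative value

host_loops = len(range(0, steps_per_epoch, steps_per_run))
steps_executed = host_loops * steps_per_run
assert host_loops == 13
assert steps_executed == 104  # 4 more steps than requested
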
+def test_loop(model, iterator, verbose=0, steps=None):
+ """Test loop for evaluating with DistributionStrategy.
+
+ Arguments:
+ model: Keras Model instance.
+ iterator: Iterator for input data.
+      verbose: Integer. Verbosity mode: 0 or 1.
steps: Total number of steps (batches of samples)
before declaring predictions finished.
Ignored with the default value of `None`.
@@ -208,9 +382,17 @@ def test_loop(model, inputs, targets, verbose=0, steps=None):
Scalar loss (if the model has a single output and no metrics)
or list of scalars (if the model has multiple outputs
and/or metrics). The attribute `model.metrics_names` will give you
- the display labels for the scalar outputs.
+ the display labels for the outputs.
"""
current_strategy = model._distribution_strategy
+
+ # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
+ if current_strategy.__class__.__name__ == 'TPUStrategy':
+ return _experimental_test_loop(model, iterator, verbose, steps)
+
+ if not model._grouped_model:
+ clone_model_on_towers(model, current_strategy)
+
def _per_device_test_function(model):
model._make_test_function()
return (model.test_function.inputs,
@@ -218,6 +400,7 @@ def test_loop(model, inputs, targets, verbose=0, steps=None):
model.test_function.updates_op,
model.test_function.session_kwargs)
+ inputs, targets = _get_input_from_iterator(iterator, model)
with current_strategy.scope():
(grouped_inputs, grouped_outputs, grouped_updates,
grouped_session_args) = current_strategy.call_for_each_tower(
@@ -259,38 +442,149 @@ def test_loop(model, inputs, targets, verbose=0, steps=None):
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
- if steps is not None:
- for step in range(steps):
- batch_outs = distributed_test_function(ins)
- batch_outs = _aggregate_metrics_across_towers(
- len(current_strategy._devices), model.metrics_names, batch_outs)
- if isinstance(batch_outs, list):
- if step == 0:
- for _ in enumerate(batch_outs):
- outs.append(0.)
- for i, batch_out in enumerate(batch_outs):
- outs[i] += batch_out
+ assert steps is not None
+ for step in range(steps):
+ batch_outs = distributed_test_function(ins)
+ batch_outs = _aggregate_metrics_across_towers(
+ current_strategy.num_towers, model.metrics_names, batch_outs)
+ if isinstance(batch_outs, list):
+ if step == 0:
+ outs = [0.] * len(batch_outs)
+ for i, batch_out in enumerate(batch_outs):
+ outs[i] += batch_out
+ else:
+ if step == 0:
+ outs.append(0.)
+ outs[0] += batch_outs
+ if verbose >= 1:
+ progbar.update(step + 1)
+ for i in range(len(outs)):
+ outs[i] /= steps
+
+ if len(outs) == 1:
+ return outs[0]
+ return outs
+
+
+def _experimental_test_loop(model, iterator, verbose=0, steps=None):
+ """Test loop for evaluating with TPU DistributionStrategy.
+
+ Arguments:
+ model: Keras Model instance.
+ iterator: Iterator for input data.
+      verbose: Integer. Verbosity mode: 0 or 1.
+ steps: Total number of steps (batches of samples)
+ before declaring predictions finished.
+ Ignored with the default value of `None`.
+
+ Returns:
+ Scalar loss (if the model has a single output and no metrics)
+ or list of scalars (if the model has multiple outputs
+ and/or metrics). The attribute `model.metrics_names` will give you
+ the display labels for the outputs.
+ """
+ current_strategy = model._distribution_strategy
+ K.get_session().run(current_strategy.initialize())
+
+ def _per_device_test_function(model):
+ model._make_test_function()
+ return (model.test_function.inputs,
+ model.test_function.outputs,
+ model.test_function.updates_op,
+ model.test_function.session_kwargs)
+
+ # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
+ K.set_learning_phase(0)
+
+ def step_fn(ctx, inputs, targets):
+ """Clones the model and calls make_test_function."""
+ # TODO(priyag, sourabhbajaj): The model gets cloned every time
+ # fit/test/predict is called. We should look into caching this keyed on
+ # input shapes.
+ clone_model_on_towers(
+ model,
+ current_strategy,
+ make_callback_model=False,
+ inputs=inputs,
+ targets=targets)
+
+ (grouped_inputs, grouped_outputs, grouped_updates,
+ grouped_session_args) = current_strategy.call_for_each_tower(
+ _per_device_test_function, model._grouped_model)
+
+ (all_inputs, all_outputs, all_updates,
+ all_session_args) = distributed_training_utils.unwrap_values(
+ current_strategy, grouped_inputs, grouped_outputs, grouped_updates,
+ grouped_session_args)
+
+ combined_fn = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_test_function',
+ **all_session_args)
+
+ for label, output in zip(model.metrics_names, combined_fn.outputs):
+ if label == 'loss':
+ aggregation = distribute_lib.get_loss_reduction()
else:
- if step == 0:
- outs.append(0.)
- outs[0] += batch_outs
- if verbose == 1:
- progbar.update(step + 1)
- for i in range(len(outs)):
- outs[i] /= steps
+      # We aggregate all other metrics using mean for now. This is a
+      # temporary workaround until new metrics are in place.
+ aggregation = variable_scope.VariableAggregation.MEAN
+ ctx.set_last_step_output(label, output, aggregation)
+
+ return combined_fn.updates_op
+
+ # Add initial dummy values for loss and other metric tensors.
+ initial_loop_values = {}
+ initial_loop_values['loss'] = constant_op.constant(1e7)
+ for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
+ initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
+
+ with current_strategy.scope():
+ # TODO(priyag): Use steps_per_run when we use new metrics as they will
+ # allow handling metric computation at each step using variables.
+ ctx = current_strategy.run_steps_on_dataset(
+ step_fn, iterator, iterations=1,
+ initial_loop_values=initial_loop_values)
+
+ test_op = ctx.run_op
+ output_tensors = ctx.last_step_outputs
+
+ if verbose == 1:
+ progbar = Progbar(target=steps)
+
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
+ with current_strategy.scope():
+ distributed_model = current_strategy.unwrap(model._grouped_model)[0]
+ distributed_training_utils.set_weights(
+ current_strategy, distributed_model, orig_model_weights)
+
+ assert steps is not None
+ outs = [0.] * len(model.metrics_names)
+ for step in range(steps):
+ _, batch_outs = K.get_session().run([test_op, output_tensors])
+ for i, label in enumerate(model.metrics_names):
+ outs[i] += batch_outs[label]
+ if verbose >= 1:
+ progbar.update(step + 1)
+ for i in range(len(outs)):
+    outs[i] /= steps
+
+ K.get_session().run(current_strategy.finalize())
if len(outs) == 1:
return outs[0]
return outs
-def predict_loop(model, inputs, verbose=0, steps=None):
- """Abstract method to loop over some data in batches.
+def predict_loop(model, iterator, verbose=0, steps=None):
+ """Predict loop for predicting with DistributionStrategy.
Arguments:
model: Keras Model instance.
- inputs: list of tensors to be fed to `f`.
- verbose: verbosity mode.
+ iterator: Iterator for input data.
+      verbose: Integer. Verbosity mode: 0 or 1.
steps: Total number of steps (batches of samples)
before declaring `_predict_loop` finished.
Ignored with the default value of `None`.
@@ -301,6 +595,14 @@ def predict_loop(model, inputs, verbose=0, steps=None):
(if the model has multiple outputs).
"""
current_strategy = model._distribution_strategy
+
+ # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
+ if current_strategy.__class__.__name__ == 'TPUStrategy':
+ return _experimental_predict_loop(model, iterator, verbose, steps)
+
+ if not model._grouped_model:
+ clone_model_on_towers(model, current_strategy)
+
def _per_device_predict_function(model):
model._make_predict_function()
return (model.predict_function.inputs,
@@ -308,6 +610,7 @@ def predict_loop(model, inputs, verbose=0, steps=None):
model.predict_function.updates_op,
model.predict_function.session_kwargs)
+ inputs, _ = _get_input_from_iterator(iterator, model)
with current_strategy.scope():
(grouped_inputs, grouped_outputs, grouped_updates,
grouped_session_args) = current_strategy.call_for_each_tower(
@@ -354,9 +657,11 @@ def predict_loop(model, inputs, verbose=0, steps=None):
if step == 0:
for _ in batch_outs:
unconcatenated_outs.append([])
+ # TODO(anjalisridhar): Should combine the outputs from multiple towers
+ # correctly here.
for i, batch_out in enumerate(batch_outs):
unconcatenated_outs[i].append(batch_out)
- if verbose == 1:
+ if verbose >= 1:
progbar.update(step + 1)
if len(unconcatenated_outs) == 1:
return np.concatenate(unconcatenated_outs[0], axis=0)
@@ -366,12 +671,128 @@ def predict_loop(model, inputs, verbose=0, steps=None):
]
-def clone_and_build_model(model):
+def _experimental_predict_loop(model, iterator, verbose=0, steps=None):
+ """Predict loop for predicting with TPU DistributionStrategy.
+
+ Arguments:
+ model: Keras Model instance.
+ iterator: Iterator for input data.
+      verbose: Integer. Verbosity mode: 0 or 1.
+ steps: Total number of steps (batches of samples)
+ before declaring `_predict_loop` finished.
+ Ignored with the default value of `None`.
+
+ Returns:
+ Array of predictions (if the model has a single output)
+ or list of arrays of predictions
+ (if the model has multiple outputs).
+ """
+ current_strategy = model._distribution_strategy
+ K.get_session().run(current_strategy.initialize())
+
+ # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
+ K.set_learning_phase(0)
+
+ def _per_device_predict_function(model):
+ model._make_predict_function()
+ return (model.predict_function.inputs,
+ model.predict_function.outputs,
+ model.predict_function.updates_op,
+ model.predict_function.session_kwargs)
+
+ def step_fn(ctx, inputs, targets):
+ """Clones the model and calls make_predict_function."""
+
+ # TODO(anjalisridhar): Support predict input correctly as it will not
+ # contain targets, only inputs.
+ del targets
+
+ # TODO(priyag, sourabhbajaj): The model gets cloned every time
+ # fit/test/predict is called. We should look into caching this keyed on
+ # input shapes.
+ clone_model_on_towers(
+ model,
+ current_strategy,
+ make_callback_model=False,
+ inputs=inputs)
+
+ (grouped_inputs, grouped_outputs, grouped_updates,
+ grouped_session_args) = current_strategy.call_for_each_tower(
+ _per_device_predict_function, model._grouped_model)
+
+ (all_inputs, all_outputs, all_updates,
+ all_session_args) = distributed_training_utils.unwrap_values(
+ current_strategy, grouped_inputs, grouped_outputs, grouped_updates,
+ grouped_session_args)
+
+ combined_fn = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_predict_function',
+ **all_session_args)
+
+ for label, output in zip(model.output_names, combined_fn.outputs):
+ ctx.set_last_step_output(label, output)
+
+ return combined_fn.updates_op
+
+ # Add initial dummy values for outputs.
+ initial_loop_values = {}
+ batch_dimension = distributed_training_utils.get_batch_dimension(iterator)
+ for name, tensor in zip(model.output_names, model.outputs):
+ # TODO(priyag): This is a workaround as we do not know the batch dimension
+ # of the model's output at this point.
+ tensor.shape.dims = [batch_dimension] + tensor.shape.dims[1:]
+ initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
+
+ with current_strategy.scope():
+ # TODO(priyag, sourabhbajaj): Support steps_per_run if/when we add outfeed.
+ ctx = current_strategy.run_steps_on_dataset(
+ step_fn, iterator, iterations=1,
+ initial_loop_values=initial_loop_values)
+
+ predict_op = ctx.run_op
+ output_tensors = ctx.last_step_outputs
+
+ if verbose == 1:
+ progbar = Progbar(target=steps)
+
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
+ with current_strategy.scope():
+ distributed_model = current_strategy.unwrap(model._grouped_model)[0]
+ distributed_training_utils.set_weights(
+ current_strategy, distributed_model, orig_model_weights)
+
+ assert steps is not None
+ # Since we do not know how many samples we will see, we cannot pre-allocate
+ # the returned Numpy arrays. Instead, we store one array per batch seen
+ # and concatenate them upon returning.
+ unconcatenated_outs = [[] for _ in model.outputs]
+ for step in range(steps):
+ _, batch_outs = K.get_session().run([predict_op, output_tensors])
+ # TODO(priyag): maybe need to unwrap the outputs first for MirroredStrategy.
+ for i, label in enumerate(model.output_names):
+ unconcatenated_outs[i].extend(batch_outs[label])
+ if verbose >= 1:
+ progbar.update(step + 1)
+
+ K.get_session().run(current_strategy.finalize())
+
+ if len(unconcatenated_outs) == 1:
+ return np.concatenate(unconcatenated_outs[0], axis=0)
+ return [
+ np.concatenate(unconcatenated_outs[i], axis=0)
+ for i in range(len(unconcatenated_outs))
+ ]
+
+
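+
The per-batch outputs collected above are stitched together along the batch axis at the end; a minimal numeric example of that final concatenation:

import numpy as np

# Two predict steps for a single-output model, batch size 4 each.
unconcatenated_outs = [[np.zeros((4, 10)), np.ones((4, 10))]]
result = np.concatenate(unconcatenated_outs[0], axis=0)
assert result.shape == (8, 10)
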
+def _clone_and_build_model(model, inputs=None, targets=None):
"""Clone and build the given keras_model."""
# We need to set the import here since we run into a circular dependency
# error.
from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top
- cloned_model = models.clone_model(model, input_tensors=None)
+ cloned_model = models.clone_model(model, input_tensors=inputs)
# Compile and build model.
if isinstance(model.optimizer, optimizers.TFOptimizer):
@@ -380,16 +801,32 @@ def clone_and_build_model(model):
optimizer_config = model.optimizer.get_config()
optimizer = model.optimizer.__class__.from_config(optimizer_config)
+ # TODO(priyag): Is there a cleaner way to do this? The API doc suggests a
+ # single tensor should be OK but it throws an error in that case.
+ if (targets is not None and not isinstance(targets, list) and
+ not isinstance(targets, dict)):
+ targets = [targets]
cloned_model.compile(
optimizer,
model.loss,
metrics=model.metrics,
loss_weights=model.loss_weights,
sample_weight_mode=model.sample_weight_mode,
- weighted_metrics=model.weighted_metrics)
+ weighted_metrics=model.weighted_metrics,
+ target_tensors=targets)
return cloned_model
+def clone_model_on_towers(
+ model, strategy, make_callback_model=False, inputs=None, targets=None):
+ """Create a cloned model on each tower."""
+ with strategy.scope():
+ model._grouped_model = strategy.call_for_each_tower(
+ _clone_and_build_model, model, inputs, targets)
+ if make_callback_model:
+ model._make_callback_model()
+
+
def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
"""Aggregate metrics values across all towers.
@@ -419,3 +856,23 @@ def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
merged_output.append(m)
current_index += num_devices
return merged_output
+
+
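+
`_aggregate_metrics_across_towers` walks `outs` label-major with a stride of `num_devices`, i.e. each label owns a contiguous block of per-tower values (visible in the `current_index += num_devices` step above). A worked sketch of that layout, assuming per-label values are averaged (the merge rule itself lies outside this hunk):

import numpy as np

num_devices = 2
out_labels = ['loss', 'acc']
# Label-major layout: num_devices values per label.
outs = [0.9, 1.1,    # per-tower loss
        0.70, 0.80]  # per-tower acc

merged_output = []
current_index = 0
for _ in out_labels:
  block = outs[current_index:current_index + num_devices]
  merged_output.append(np.mean(block))
  current_index += num_devices
assert np.allclose(merged_output, [1.0, 0.75])
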
+def _get_input_from_iterator(iterator, model):
+ """Get elements from the iterator and verify the input shape and type."""
+ next_element = iterator.get_next()
+
+ if isinstance(next_element, tuple):
+ x, y = next_element
+ else:
+ x = next_element
+ y = None
+ # Validate that all the elements in x and y are of the same type and shape.
+ # We can then pass the first element of x and y to `_standardize_weights`
+ # below and be confident of the output.
+  x_values, y_values = (
+      distributed_training_utils.validate_distributed_dataset_inputs(
+          model._distribution_strategy, x, y))
+ # TODO(sourabhbajaj): Add support for sample weights in distribution
+ # strategy.
+ model._standardize_weights(x_values, y_values)
+ return x, y
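
`_get_input_from_iterator` expects `iterator.get_next()` to yield either a bare inputs tensor or an `(inputs, targets)` tuple. A minimal sketch of constructing such an iterator with the TF 1.x-style API used in this file:

import numpy as np
import tensorflow as tf

x = np.random.random((8, 3)).astype(np.float32)
y = np.random.random((8, 1)).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(4)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()   # an (inputs, targets) tuple
assert isinstance(next_element, tuple)
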
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index 1e377149b6..939a7f2356 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -67,7 +67,8 @@ def _model_loss(model, inputs, targets, sample_weights=None, training=False):
Arguments:
model: The model on which metrics are being calculated.
- inputs: List of input arrays.
+ inputs: Either a dictionary of inputs to the model or a list of input
+ arrays.
targets: List of target arrays.
sample_weights: Optional list of sample weight arrays.
training: Whether the model should be run in inference or training mode.
@@ -82,7 +83,7 @@ def _model_loss(model, inputs, targets, sample_weights=None, training=False):
kwargs = {}
if model._expects_training_arg:
kwargs['training'] = training
- if len(inputs) == 1:
+ if len(inputs) == 1 and not isinstance(inputs, dict):
inputs = inputs[0]
if model._compute_output_and_mask_jointly:
@@ -369,6 +370,8 @@ def iterator_test_loop(model, inputs, steps, verbose=0):
# Get current step size.
if isinstance(x, list):
step_size = x[0].get_shape().as_list()[0]
+ elif isinstance(x, dict):
+ step_size = list(x.values())[0].get_shape().as_list()[0]
else:
step_size = x.get_shape().as_list()[0]
@@ -417,11 +420,12 @@ def iterator_predict_loop(model, inputs, steps, verbose=0):
"""
assert isinstance(inputs, iterator_ops.EagerIterator)
if not isinstance(inputs.output_shapes,
- (list, tuple)) or len(inputs.output_shapes) > 2:
+ (list, tuple)) or len(inputs.output_shapes) > 3:
raise ValueError(
- 'Please provide data as a list or tuple of 1 or 2 elements '
- ' - input or input and target pair. Received %s. We do not use the '
- '`target` value here.' % inputs.output_shapes)
+        'Please provide data as a list or tuple of 1, 2, or 3 elements '
+        '- `(input)`, or `(input, target)`, or `(input, target, '
+        'sample_weights)`. Received %s. We do not use the `target` or '
+        '`sample_weights` value here.' % inputs.output_shapes)
outs = []
if verbose == 1:
progbar = generic_utils.Progbar(target=steps)
@@ -444,10 +448,13 @@ def iterator_predict_loop(model, inputs, steps, verbose=0):
x, _, _ = model._standardize_user_data(x)
x = training_utils.cast_if_floating_dtype(x)
+ if isinstance(x, list) and len(x) == 1:
+ x = x[0]
+
if model._expects_training_arg:
- batch_outs = model.call(x[0] if len(x) == 1 else x, training=False)
+ batch_outs = model.call(x, training=False)
else:
- batch_outs = model.call(x[0] if len(x) == 1 else x)
+ batch_outs = model.call(x)
if not isinstance(batch_outs, list):
batch_outs = [batch_outs]
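
Per the relaxed check above, predict now accepts iterators whose elements are 1-, 2-, or 3-tuples; the `target` and `sample_weights` components are simply ignored. Illustrative dataset structures (names are arbitrary):

import numpy as np
import tensorflow as tf

x = np.random.random((8, 3)).astype(np.float32)
y = np.random.random((8, 1)).astype(np.float32)
w = np.ones((8,), dtype=np.float32)

# All three element structures pass the output_shapes check above.
ds_inputs_only = tf.data.Dataset.from_tensor_slices((x,)).batch(4)
ds_with_targets = tf.data.Dataset.from_tensor_slices((x, y)).batch(4)
ds_with_weights = tf.data.Dataset.from_tensor_slices((x, y, w)).batch(4)
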
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 15e7d725de..8938333b1a 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -35,6 +35,8 @@ from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine.training_utils import weighted_masked_objective
from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.rmsprop import RMSPropOptimizer
@@ -49,289 +51,287 @@ class TrainingTest(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_fit_on_arrays(self):
- with self.test_session():
- a = keras.layers.Input(shape=(3,), name='input_a')
- b = keras.layers.Input(shape=(3,), name='input_b')
-
- dense = keras.layers.Dense(4, name='dense')
- c = dense(a)
- d = dense(b)
- e = keras.layers.Dropout(0.5, name='dropout')(c)
-
- model = keras.models.Model([a, b], [d, e])
-
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- loss = 'mse'
- loss_weights = [1., 0.5]
- model.compile(
- optimizer,
- loss,
- metrics=[metrics_module.CategoricalAccuracy(), 'mae'],
- loss_weights=loss_weights)
-
- input_a_np = np.random.random((10, 3))
- input_b_np = np.random.random((10, 3))
-
- output_d_np = np.random.random((10, 4))
- output_e_np = np.random.random((10, 4))
-
- # Test fit at different verbosity
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- epochs=1,
- batch_size=5,
- verbose=0)
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- epochs=1,
- batch_size=5,
- verbose=1)
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- epochs=2,
- batch_size=5,
- verbose=2)
- model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])
-
- # Test model with input data as a list of lists
- model.fit(
- [np.ndarray.tolist(input_a_np), np.ndarray.tolist(input_b_np)],
- [output_d_np, output_e_np],
- epochs=2,
- batch_size=5,
- verbose=2)
+ a = keras.layers.Input(shape=(3,), name='input_a')
+ b = keras.layers.Input(shape=(3,), name='input_b')
- # Test with validation data
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- validation_data=([input_a_np, input_b_np], [output_d_np,
- output_e_np]),
- epochs=1,
- batch_size=5,
- verbose=0)
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- validation_data=([input_a_np, input_b_np], [output_d_np,
- output_e_np]),
- epochs=2,
- batch_size=5,
- verbose=1)
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- validation_data=([input_a_np, input_b_np], [output_d_np,
- output_e_np]),
- epochs=2,
- batch_size=5,
- verbose=2)
- # Test with validation split
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- epochs=2,
- batch_size=5,
- verbose=0,
- validation_split=0.2)
+ dense = keras.layers.Dense(4, name='dense')
+ c = dense(a)
+ d = dense(b)
+ e = keras.layers.Dropout(0.5, name='dropout')(c)
- # Test with dictionary inputs
- model.fit(
- {
- 'input_a': input_a_np,
- 'input_b': input_b_np
- }, {
- 'dense': output_d_np,
- 'dropout': output_e_np
- },
- epochs=1,
- batch_size=5,
- verbose=0)
- model.fit(
- {
- 'input_a': input_a_np,
- 'input_b': input_b_np
- }, {
- 'dense': output_d_np,
- 'dropout': output_e_np
- },
- epochs=1,
- batch_size=5,
- verbose=1)
- model.fit(
- {
- 'input_a': input_a_np,
- 'input_b': input_b_np
- }, {
- 'dense': output_d_np,
- 'dropout': output_e_np
- },
- validation_data=({
- 'input_a': input_a_np,
- 'input_b': input_b_np
- }, {
- 'dense': output_d_np,
- 'dropout': output_e_np
- }),
- epochs=1,
- batch_size=5,
- verbose=0)
- model.train_on_batch({
- 'input_a': input_a_np,
- 'input_b': input_b_np
- }, {
- 'dense': output_d_np,
- 'dropout': output_e_np
- })
-
- # Test with lists for loss, metrics
- loss = ['mae', 'mse']
- model.compile(
- optimizer,
- loss,
- metrics=[metrics_module.CategoricalAccuracy(), 'mae'])
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- epochs=1,
- batch_size=5,
- verbose=0)
+ model = keras.models.Model([a, b], [d, e])
- # Test with dictionaries for loss, metrics, loss weights
- loss = {'dense': 'mse', 'dropout': 'mae'}
- loss_weights = {'dense': 1., 'dropout': 0.5}
- metrics = {
- 'dense': 'mse',
- 'dropout': metrics_module.CategoricalAccuracy()
- }
- model.compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights)
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ loss_weights = [1., 0.5]
+ model.compile(
+ optimizer,
+ loss,
+ metrics=[metrics_module.CategoricalAccuracy(), 'mae'],
+ loss_weights=loss_weights)
+
+ input_a_np = np.random.random((10, 3))
+ input_b_np = np.random.random((10, 3))
+
+ output_d_np = np.random.random((10, 4))
+ output_e_np = np.random.random((10, 4))
+
+ # Test fit at different verbosity
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=1,
+ batch_size=5,
+ verbose=0)
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=1,
+ batch_size=5,
+ verbose=1)
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=2,
+ batch_size=5,
+ verbose=2)
+ model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])
+
+ # Test model with input data as a list of lists
+ model.fit(
+ [np.ndarray.tolist(input_a_np), np.ndarray.tolist(input_b_np)],
+ [output_d_np, output_e_np],
+ epochs=2,
+ batch_size=5,
+ verbose=2)
+
+ # Test with validation data
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ validation_data=([input_a_np, input_b_np], [output_d_np,
+ output_e_np]),
+ epochs=1,
+ batch_size=5,
+ verbose=0)
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ validation_data=([input_a_np, input_b_np], [output_d_np,
+ output_e_np]),
+ epochs=2,
+ batch_size=5,
+ verbose=1)
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ validation_data=([input_a_np, input_b_np], [output_d_np,
+ output_e_np]),
+ epochs=2,
+ batch_size=5,
+ verbose=2)
+ # Test with validation split
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=2,
+ batch_size=5,
+ verbose=0,
+ validation_split=0.2)
+
+ # Test with dictionary inputs
+ model.fit(
+ {
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ },
+ epochs=1,
+ batch_size=5,
+ verbose=0)
+ model.fit(
+ {
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ },
+ epochs=1,
+ batch_size=5,
+ verbose=1)
+ model.fit(
+ {
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ },
+ validation_data=({
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ }),
+ epochs=1,
+ batch_size=5,
+ verbose=0)
+ model.train_on_batch({
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ })
+
+ # Test with lists for loss, metrics
+ loss = ['mae', 'mse']
+ model.compile(
+ optimizer,
+ loss,
+ metrics=[metrics_module.CategoricalAccuracy(), 'mae'])
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=1,
+ batch_size=5,
+ verbose=0)
+
+ # Test with dictionaries for loss, metrics, loss weights
+ loss = {'dense': 'mse', 'dropout': 'mae'}
+ loss_weights = {'dense': 1., 'dropout': 0.5}
+ metrics = {
+ 'dense': 'mse',
+ 'dropout': metrics_module.CategoricalAccuracy()
+ }
+ model.compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights)
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=1,
+ batch_size=5,
+ verbose=0)
+
+ # Invalid use cases
+ with self.assertRaises(ValueError):
+ model.train_on_batch({'input_a': input_a_np},
+ [output_d_np, output_e_np])
+ with self.assertRaises(AttributeError):
model.fit(
[input_a_np, input_b_np], [output_d_np, output_e_np],
epochs=1,
- batch_size=5,
+ validation_data=([input_a_np, input_b_np], 0, 0),
verbose=0)
+ with self.assertRaises(ValueError):
+ model.train_on_batch([input_a_np], [output_d_np, output_e_np])
+ with self.assertRaises(AttributeError):
+ model.train_on_batch(1, [output_d_np, output_e_np])
+ with self.assertRaises(ValueError):
+ model.train_on_batch(input_a_np, [output_d_np, output_e_np])
+ with self.assertRaises(ValueError):
+ bad_input = np.random.random((11, 3))
+ model.train_on_batch([bad_input, input_b_np],
+ [output_d_np, output_e_np])
+ with self.assertRaises(ValueError):
+ bad_target = np.random.random((11, 4))
+ model.train_on_batch([input_a_np, input_b_np],
+ [bad_target, output_e_np])
+
+ # Build single-input model
+ x = keras.layers.Input(shape=(3,), name='input_a')
+ y = keras.layers.Dense(4)(x)
+ model = keras.models.Model(x, y)
+ model.compile(optimizer, loss='mse')
+ # This will work
+ model.fit([input_a_np], output_d_np, epochs=1)
+ with self.assertRaises(ValueError):
+ model.fit([input_a_np, input_a_np], output_d_np, epochs=1)
- # Invalid use cases
- with self.assertRaises(ValueError):
- model.train_on_batch({'input_a': input_a_np},
- [output_d_np, output_e_np])
- with self.assertRaises(AttributeError):
- model.fit(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- epochs=1,
- validation_data=([input_a_np, input_b_np], 0, 0),
- verbose=0)
- with self.assertRaises(ValueError):
- model.train_on_batch([input_a_np], [output_d_np, output_e_np])
- with self.assertRaises(AttributeError):
- model.train_on_batch(1, [output_d_np, output_e_np])
- with self.assertRaises(ValueError):
- model.train_on_batch(input_a_np, [output_d_np, output_e_np])
- with self.assertRaises(ValueError):
- bad_input = np.random.random((11, 3))
- model.train_on_batch([bad_input, input_b_np],
- [output_d_np, output_e_np])
- with self.assertRaises(ValueError):
- bad_target = np.random.random((11, 4))
- model.train_on_batch([input_a_np, input_b_np],
- [bad_target, output_e_np])
-
- # Build single-input model
- x = keras.layers.Input(shape=(3,), name='input_a')
- y = keras.layers.Dense(4)(x)
- model = keras.models.Model(x, y)
- model.compile(optimizer, loss='mse')
- # This will work
- model.fit([input_a_np], output_d_np, epochs=1)
- with self.assertRaises(ValueError):
- model.fit([input_a_np, input_a_np], output_d_np, epochs=1)
-
- # Test model on a list of floats
- input_a_np = np.random.random((10, 3))
- input_b_np = np.random.random((10, 4))
+ # Test model on a list of floats
+ input_a_np = np.random.random((10, 3))
+ input_b_np = np.random.random((10, 4))
- model.fit([np.ndarray.tolist(input_a_np)],
- [np.ndarray.tolist(input_b_np)],
- epochs=2,
- batch_size=5,
- verbose=2)
+ model.fit([np.ndarray.tolist(input_a_np)],
+ [np.ndarray.tolist(input_b_np)],
+ epochs=2,
+ batch_size=5,
+ verbose=2)
@tf_test_util.run_in_graph_and_eager_modes
def test_evaluate_predict_on_arrays(self):
- with self.test_session():
- a = keras.layers.Input(shape=(3,), name='input_a')
- b = keras.layers.Input(shape=(3,), name='input_b')
-
- dense = keras.layers.Dense(4, name='dense')
- c = dense(a)
- d = dense(b)
- e = keras.layers.Dropout(0.5, name='dropout')(c)
-
- model = keras.models.Model([a, b], [d, e])
-
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- loss = 'mse'
- loss_weights = [1., 0.5]
- model.compile(
- optimizer,
- loss,
- metrics=['mae', metrics_module.CategoricalAccuracy()],
- loss_weights=loss_weights,
- sample_weight_mode=None)
+ a = keras.layers.Input(shape=(3,), name='input_a')
+ b = keras.layers.Input(shape=(3,), name='input_b')
- input_a_np = np.random.random((10, 3))
- input_b_np = np.random.random((10, 3))
+ dense = keras.layers.Dense(4, name='dense')
+ c = dense(a)
+ d = dense(b)
+ e = keras.layers.Dropout(0.5, name='dropout')(c)
- output_d_np = np.random.random((10, 4))
- output_e_np = np.random.random((10, 4))
+ model = keras.models.Model([a, b], [d, e])
- # Test evaluate at different verbosity
- out = model.evaluate(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- batch_size=5,
- verbose=0)
- self.assertEqual(len(out), 7)
- out = model.evaluate(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- batch_size=5,
- verbose=1)
- self.assertEqual(len(out), 7)
- out = model.evaluate(
- [input_a_np, input_b_np], [output_d_np, output_e_np],
- batch_size=5,
- verbose=2)
- self.assertEqual(len(out), 7)
- out = model.test_on_batch([input_a_np, input_b_np],
- [output_d_np, output_e_np])
- self.assertEqual(len(out), 7)
-
- # Test evaluate with dictionary inputs
- model.evaluate(
- {
- 'input_a': input_a_np,
- 'input_b': input_b_np
- }, {
- 'dense': output_d_np,
- 'dropout': output_e_np
- },
- batch_size=5,
- verbose=0)
- model.evaluate(
- {
- 'input_a': input_a_np,
- 'input_b': input_b_np
- }, {
- 'dense': output_d_np,
- 'dropout': output_e_np
- },
- batch_size=5,
- verbose=1)
-
- # Test predict
- out = model.predict([input_a_np, input_b_np], batch_size=5)
- self.assertEqual(len(out), 2)
- out = model.predict({'input_a': input_a_np, 'input_b': input_b_np})
- self.assertEqual(len(out), 2)
- out = model.predict_on_batch({
- 'input_a': input_a_np,
- 'input_b': input_b_np
- })
- self.assertEqual(len(out), 2)
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ loss_weights = [1., 0.5]
+ model.compile(
+ optimizer,
+ loss,
+ metrics=['mae', metrics_module.CategoricalAccuracy()],
+ loss_weights=loss_weights,
+ sample_weight_mode=None)
+
+ input_a_np = np.random.random((10, 3))
+ input_b_np = np.random.random((10, 3))
+
+ output_d_np = np.random.random((10, 4))
+ output_e_np = np.random.random((10, 4))
+
+ # Test evaluate at different verbosity
+ out = model.evaluate(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ batch_size=5,
+ verbose=0)
+ self.assertEqual(len(out), 7)
+ out = model.evaluate(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ batch_size=5,
+ verbose=1)
+ self.assertEqual(len(out), 7)
+ out = model.evaluate(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ batch_size=5,
+ verbose=2)
+ self.assertEqual(len(out), 7)
+ out = model.test_on_batch([input_a_np, input_b_np],
+ [output_d_np, output_e_np])
+ self.assertEqual(len(out), 7)
+
+ # Test evaluate with dictionary inputs
+ model.evaluate(
+ {
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ },
+ batch_size=5,
+ verbose=0)
+ model.evaluate(
+ {
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ },
+ batch_size=5,
+ verbose=1)
+
+ # Test predict
+ out = model.predict([input_a_np, input_b_np], batch_size=5)
+ self.assertEqual(len(out), 2)
+ out = model.predict({'input_a': input_a_np, 'input_b': input_b_np})
+ self.assertEqual(len(out), 2)
+ out = model.predict_on_batch({
+ 'input_a': input_a_np,
+ 'input_b': input_b_np
+ })
+ self.assertEqual(len(out), 2)
@tf_test_util.run_in_graph_and_eager_modes
def test_invalid_loss(self):
@@ -340,37 +340,33 @@ class TrainingTest(test.TestCase):
test_samples = 1000
input_dim = 5
- with self.test_session():
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(10, input_shape=(input_dim,)))
- model.add(keras.layers.Activation('relu'))
- model.add(keras.layers.Dense(num_classes))
- model.add(keras.layers.Activation('softmax'))
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- model.compile(optimizer, loss='categorical_crossentropy')
- np.random.seed(1337)
- (x_train, y_train), (_, _) = testing_utils.get_test_data(
- train_samples=train_samples,
- test_samples=test_samples,
- input_shape=(input_dim,),
- num_classes=num_classes)
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=10, num_classes=num_classes, input_dim=input_dim)
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ model.compile(optimizer, loss='categorical_crossentropy')
+ np.random.seed(1337)
+ (x_train, y_train), (_, _) = testing_utils.get_test_data(
+ train_samples=train_samples,
+ test_samples=test_samples,
+ input_shape=(input_dim,),
+ num_classes=num_classes)
- with self.assertRaises(ValueError):
- model.fit(x_train, np.concatenate([y_train, y_train], axis=-1))
+ with self.assertRaises(ValueError):
+ model.fit(x_train, np.concatenate([y_train, y_train], axis=-1))
- if not context.executing_eagerly():
- # TODO(psv): Investigate these use cases in eager mode.
- with self.assertRaises(ValueError):
- model.fit(x_train, y_train)
+ if not context.executing_eagerly():
+ # TODO(psv): Investigate these use cases in eager mode.
+ with self.assertRaises(ValueError):
+ model.fit(x_train, y_train)
- with self.assertRaises(ValueError):
- model.compile(optimizer, loss=None)
+ with self.assertRaises(ValueError):
+ model.compile(optimizer, loss=None)
def test_training_on_sparse_data_with_dense_placeholders(self):
if scipy_sparse is None:
return
- with self.test_session():
+ with self.cached_session():
test_inputs = [
scipy_sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)
]
@@ -392,11 +388,24 @@ class TrainingTest(test.TestCase):
epochs=1, batch_size=2, validation_split=0.5)
model.evaluate(test_inputs, test_outputs, batch_size=2)
+ def test_compile_with_sparse_placeholders(self):
+ with self.cached_session():
+ input_layer = keras.layers.Input(shape=(10,), sparse=True)
+ weights = variables_lib.Variable(
+ np.ones((10, 1)).astype(np.float32), name='weights')
+ weights_mult = lambda x: sparse_ops.sparse_tensor_dense_matmul(x, weights)
+ output_layer = keras.layers.Lambda(weights_mult)(input_layer)
+ model = keras.Model([input_layer], output_layer)
+ model.compile(
+ loss='binary_crossentropy',
+ optimizer=keras.optimizers.Adam(lr=0.0001),
+ metrics=['accuracy'])
+
def test_that_trainable_disables_updates(self):
val_a = np.random.random((10, 4))
val_out = np.random.random((10, 4))
- with self.test_session():
+ with self.cached_session():
a = keras.layers.Input(shape=(4,))
layer = keras.layers.BatchNormalization(input_shape=(4,))
b = layer(a)
@@ -432,7 +441,7 @@ class TrainingTest(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_compile_warning_for_loss_missing_output(self):
- with self.test_session():
+ with self.cached_session():
inp = keras.layers.Input(shape=(16,), name='input_a')
out_1 = keras.layers.Dense(8, name='dense_1')(inp)
out_2 = keras.layers.Dense(3, activation='softmax', name='dense_2')(out_1)
@@ -468,67 +477,82 @@ class LossWeightingTest(test.TestCase):
input_dim = 5
learning_rate = 0.001
- with self.test_session():
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(10, input_shape=(input_dim,)))
- model.add(keras.layers.Activation('relu'))
- model.add(keras.layers.Dense(num_classes))
- model.add(keras.layers.Activation('softmax'))
- model.compile(
- loss='categorical_crossentropy',
- metrics=['acc'],
- weighted_metrics=['mae'],
- optimizer=RMSPropOptimizer(learning_rate=learning_rate))
-
- np.random.seed(1337)
- (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
- train_samples=train_samples,
- test_samples=test_samples,
- input_shape=(input_dim,),
- num_classes=num_classes)
- int_y_test = y_test.copy()
- int_y_train = y_train.copy()
- # convert class vectors to binary class matrices
- y_train = keras.utils.to_categorical(y_train, num_classes)
- y_test = keras.utils.to_categorical(y_test, num_classes)
- test_ids = np.where(int_y_test == np.array(weighted_class))[0]
-
- class_weight = dict([(i, 1.) for i in range(num_classes)])
- class_weight[weighted_class] = 2.
-
- sample_weight = np.ones((y_train.shape[0]))
- sample_weight[int_y_train == weighted_class] = 2.
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=10, num_classes=num_classes, input_dim=input_dim)
+ model.compile(
+ loss='categorical_crossentropy',
+ metrics=['acc', metrics_module.CategoricalAccuracy()],
+ weighted_metrics=['mae', metrics_module.CategoricalAccuracy()],
+ optimizer=RMSPropOptimizer(learning_rate=learning_rate))
+
+ np.random.seed(1337)
+ (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+ train_samples=train_samples,
+ test_samples=test_samples,
+ input_shape=(input_dim,),
+ num_classes=num_classes)
+ int_y_test = y_test.copy()
+ int_y_train = y_train.copy()
+ # convert class vectors to binary class matrices
+ y_train = keras.utils.to_categorical(y_train, num_classes)
+ y_test = keras.utils.to_categorical(y_test, num_classes)
+ test_ids = np.where(int_y_test == np.array(weighted_class))[0]
+
+ class_weight = dict([(i, 1.) for i in range(num_classes)])
+ class_weight[weighted_class] = 2.
+
+ sample_weight = np.ones((y_train.shape[0]))
+ sample_weight[int_y_train == weighted_class] = 2.
+
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=batch_size,
+ epochs=epochs // 3,
+ verbose=0,
+ class_weight=class_weight,
+ validation_data=(x_train, y_train, sample_weight))
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=batch_size,
+ epochs=epochs // 2,
+ verbose=0,
+ class_weight=class_weight)
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=batch_size,
+ epochs=epochs // 2,
+ verbose=0,
+ class_weight=class_weight,
+ validation_split=0.1)
+
+ model.train_on_batch(
+ x_train[:batch_size], y_train[:batch_size], class_weight=class_weight)
+ ref_score = model.evaluate(x_test, y_test, verbose=0)
+ score = model.evaluate(
+ x_test[test_ids, :], y_test[test_ids, :], verbose=0)
+ self.assertLess(score[0], ref_score[0])
- model.fit(
- x_train,
- y_train,
- batch_size=batch_size,
- epochs=epochs // 3,
- verbose=0,
- class_weight=class_weight,
- validation_data=(x_train, y_train, sample_weight))
- model.fit(
- x_train,
- y_train,
- batch_size=batch_size,
- epochs=epochs // 2,
- verbose=0,
- class_weight=class_weight)
- model.fit(
- x_train,
- y_train,
- batch_size=batch_size,
- epochs=epochs // 2,
- verbose=0,
- class_weight=class_weight,
- validation_split=0.1)
-
- model.train_on_batch(
- x_train[:batch_size], y_train[:batch_size], class_weight=class_weight)
- ref_score = model.evaluate(x_test, y_test, verbose=0)
- score = model.evaluate(
- x_test[test_ids, :], y_test[test_ids, :], verbose=0)
- self.assertLess(score[0], ref_score[0])
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_sequential_model_fails_with_dict_inputs(self):
+ num_classes = 5
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=10, num_classes=num_classes)
+ model.compile(
+ RMSPropOptimizer(learning_rate=0.001),
+ metrics=['acc'],
+ weighted_metrics=['mae'],
+ loss='categorical_crossentropy')
+
+ x = {'dense_input': np.random.random((10, 1))}
+ y = np.random.randint(num_classes, size=(10, 1))
+
+ with self.assertRaisesRegexp(
+ ValueError, 'Passing a dictionary input to a Sequential Model which '
+ 'doesnt have FeatureLayer as the first layer is an error'):
+ model.fit(x, y, batch_size=5, epochs=1)
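+
+    # For context: Sequential models only accept dict inputs when the first
+    # layer is a FeatureLayer; a functional model whose named Input matches
+    # the dict key is the working alternative. A minimal sketch, illustrative
+    # only and not part of this patch (layer sizes are made up):
+    #
+    #   inp = keras.layers.Input(shape=(1,), name='dense_input')
+    #   out = keras.layers.Dense(5, activation='softmax')(inp)
+    #   fn_model = keras.models.Model(inp, out)
+    #   fn_model.compile(optimizer='rmsprop',
+    #                    loss='sparse_categorical_crossentropy')
+    #   fn_model.fit({'dense_input': np.random.random((10, 1))},
+    #                np.random.randint(5, size=(10, 1)),
+    #                batch_size=5, epochs=1)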
@tf_test_util.run_in_graph_and_eager_modes
def test_sample_weights(self):
@@ -541,63 +565,82 @@ class LossWeightingTest(test.TestCase):
input_dim = 5
learning_rate = 0.001
- with self.test_session():
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(10, input_shape=(input_dim,)))
- model.add(keras.layers.Activation('relu'))
- model.add(keras.layers.Dense(num_classes))
- model.add(keras.layers.Activation('softmax'))
- model.compile(
- RMSPropOptimizer(learning_rate=learning_rate),
- metrics=['acc'],
- weighted_metrics=['mae'],
- loss='categorical_crossentropy')
-
- np.random.seed(43)
- (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
- train_samples=train_samples,
- test_samples=test_samples,
- input_shape=(input_dim,),
- num_classes=num_classes)
- int_y_test = y_test.copy()
- int_y_train = y_train.copy()
- # convert class vectors to binary class matrices
- y_train = keras.utils.to_categorical(y_train, num_classes)
- y_test = keras.utils.to_categorical(y_test, num_classes)
- test_ids = np.where(int_y_test == np.array(weighted_class))[0]
-
- sample_weight = np.ones((y_train.shape[0]))
- sample_weight[int_y_train == weighted_class] = 2.
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=10, num_classes=num_classes, input_dim=input_dim)
+ model.compile(
+ RMSPropOptimizer(learning_rate=learning_rate),
+ metrics=['acc', metrics_module.CategoricalAccuracy()],
+ weighted_metrics=['mae', metrics_module.CategoricalAccuracy()],
+ loss='categorical_crossentropy')
+
+ np.random.seed(43)
+ (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+ train_samples=train_samples,
+ test_samples=test_samples,
+ input_shape=(input_dim,),
+ num_classes=num_classes)
+ int_y_test = y_test.copy()
+ int_y_train = y_train.copy()
+ # convert class vectors to binary class matrices
+ y_train = keras.utils.to_categorical(y_train, num_classes)
+ y_test = keras.utils.to_categorical(y_test, num_classes)
+ test_ids = np.where(int_y_test == np.array(weighted_class))[0]
+
+ sample_weight = np.ones((y_train.shape[0]))
+ sample_weight[int_y_train == weighted_class] = 2.
+
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=batch_size,
+ epochs=epochs // 3,
+ verbose=0,
+ sample_weight=sample_weight)
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=batch_size,
+ epochs=epochs // 3,
+ verbose=0,
+ sample_weight=sample_weight,
+ validation_split=0.1)
+
+ model.train_on_batch(
+ x_train[:batch_size],
+ y_train[:batch_size],
+ sample_weight=sample_weight[:batch_size])
+ model.test_on_batch(
+ x_train[:batch_size],
+ y_train[:batch_size],
+ sample_weight=sample_weight[:batch_size])
+ ref_score = model.evaluate(x_test, y_test, verbose=0)
+ if not context.executing_eagerly():
+ score = model.evaluate(
+ x_test[test_ids, :], y_test[test_ids, :], verbose=0)
+ self.assertLess(score[0], ref_score[0])
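+
+    # The assertion above holds because `sample_weight` scales each example's
+    # contribution to the loss, so upweighted examples are fit more closely.
+    # A back-of-the-envelope sketch of the weighted reduction (simplified,
+    # not the exact Keras internals):
+    #
+    #   losses = np.array([0.2, 0.8, 0.5])   # per-sample losses l_i
+    #   weights = np.array([1., 2., 1.])     # per-sample weights w_i
+    #   np.sum(weights * losses) / np.sum(weights)  # 0.575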
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_warning_for_concurrent_sample_and_class_weights(self):
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(10, input_shape=(3,)))
+ model.compile(
+ loss='mse',
+ optimizer=RMSPropOptimizer(learning_rate=0.01))
+ x_train = np.random.random((10, 3))
+ y_train = np.random.random((10, 10))
+ sample_weight = np.ones((y_train.shape[0]))
+ class_weight = {0: 1., 1: 1.}
+
+ with test.mock.patch.object(logging, 'warning') as mock_log:
model.fit(
x_train,
y_train,
- batch_size=batch_size,
- epochs=epochs // 3,
- verbose=0,
- sample_weight=sample_weight)
- model.fit(
- x_train,
- y_train,
- batch_size=batch_size,
- epochs=epochs // 3,
+ epochs=1,
verbose=0,
sample_weight=sample_weight,
- validation_split=0.1)
-
- model.train_on_batch(
- x_train[:batch_size],
- y_train[:batch_size],
- sample_weight=sample_weight[:batch_size])
- model.test_on_batch(
- x_train[:batch_size],
- y_train[:batch_size],
- sample_weight=sample_weight[:batch_size])
- ref_score = model.evaluate(x_test, y_test, verbose=0)
- if not context.executing_eagerly():
- score = model.evaluate(
- x_test[test_ids, :], y_test[test_ids, :], verbose=0)
- self.assertLess(score[0], ref_score[0])
+ class_weight=class_weight)
+      msg = 'The `class_weight` argument will be ignored.'
+ self.assertRegexpMatches(str(mock_log.call_args), msg)
@tf_test_util.run_in_graph_and_eager_modes
def test_temporal_sample_weights(self):
@@ -611,7 +654,7 @@ class LossWeightingTest(test.TestCase):
timesteps = 3
learning_rate = 0.001
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.TimeDistributed(
@@ -655,8 +698,8 @@ class LossWeightingTest(test.TestCase):
model.compile(
RMSPropOptimizer(learning_rate=learning_rate),
loss='binary_crossentropy',
- metrics=['acc'],
- weighted_metrics=['mae'],
+ metrics=['acc', metrics_module.CategoricalAccuracy()],
+ weighted_metrics=['mae', metrics_module.CategoricalAccuracy()],
sample_weight_mode='temporal')
model.fit(
@@ -698,7 +741,7 @@ class LossWeightingTest(test.TestCase):
timesteps = 3
learning_rate = 0.001
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.TimeDistributed(
@@ -767,7 +810,7 @@ class LossWeightingTest(test.TestCase):
timesteps = 3
learning_rate = 0.001
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.TimeDistributed(
@@ -811,7 +854,7 @@ class LossMaskingTest(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_masking_graph_sequential(self):
- with self.test_session():
+ with self.cached_session():
x = np.array([[[1], [1]], [[0], [0]]])
model = keras.models.Sequential()
model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1)))
@@ -825,7 +868,7 @@ class LossMaskingTest(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_masking_deferred_sequential(self):
- with self.test_session():
+ with self.cached_session():
x = np.array([[[1], [1]], [[0], [0]]])
model = keras.models.Sequential()
model.add(keras.layers.Masking(mask_value=0))
@@ -839,7 +882,7 @@ class LossMaskingTest(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_masking_functional(self):
- with self.test_session():
+ with self.cached_session():
x = np.array([[[1], [1]], [[0], [0]]])
inputs = keras.layers.Input((2, 1))
outputs = keras.layers.Masking(mask_value=0)(inputs)
@@ -869,7 +912,7 @@ class LossMaskingTest(test.TestCase):
def compute_output_shape(self, input_shape):
return input_shape
- with self.test_session():
+ with self.cached_session():
x = np.random.random((5, 3))
inputs = keras.layers.Input((3,))
masked = keras.layers.Masking(mask_value=0)(inputs)
@@ -881,7 +924,7 @@ class LossMaskingTest(test.TestCase):
model.train_on_batch(x, y)
def test_loss_masking(self):
- with self.test_session():
+ with self.cached_session():
weighted_loss = weighted_masked_objective(keras.losses.get('mae'))
shape = (3, 4, 2)
x = np.arange(24).reshape(shape)
@@ -902,12 +945,12 @@ class LossMaskingTest(test.TestCase):
class LearningPhaseTest(test.TestCase):
def test_empty_model_no_learning_phase(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
self.assertFalse(model.uses_learning_phase)
def test_dropout_has_learning_phase(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_dim=3))
model.add(keras.layers.Dropout(0.5))
@@ -918,7 +961,7 @@ class LearningPhaseTest(test.TestCase):
class TestDynamicTrainability(test.TestCase):
def test_trainable_warning(self):
- with self.test_session():
+ with self.cached_session():
x = np.random.random((5, 3))
y = np.random.random((5, 2))
@@ -931,7 +974,7 @@ class TestDynamicTrainability(test.TestCase):
self.assertRaises(Warning)
def test_trainable_argument(self):
- with self.test_session():
+ with self.cached_session():
x = np.random.random((5, 3))
y = np.random.random((5, 2))
@@ -954,7 +997,7 @@ class TestDynamicTrainability(test.TestCase):
self.assertAllClose(out, out_2)
def test_layer_trainability_switch(self):
- with self.test_session():
+ with self.cached_session():
# with constructor argument, in Sequential
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, trainable=False, input_dim=1))
@@ -984,7 +1027,7 @@ class TestDynamicTrainability(test.TestCase):
self.assertListEqual(model.trainable_weights, [])
def test_model_trainability_switch(self):
- with self.test_session():
+ with self.cached_session():
# a non-trainable model has no trainable weights
x = keras.layers.Input(shape=(1,))
y = keras.layers.Dense(2)(x)
@@ -999,7 +1042,7 @@ class TestDynamicTrainability(test.TestCase):
self.assertListEqual(model.trainable_weights, [])
def test_nested_model_trainability(self):
- with self.test_session():
+ with self.cached_session():
# a Sequential inside a Model
inner_model = keras.models.Sequential()
inner_model.add(keras.layers.Dense(2, input_dim=1))
@@ -1078,7 +1121,7 @@ class TestGeneratorMethods(test.TestCase):
y = arr_labels[start: end]
yield x, y
- with self.test_session():
+ with self.cached_session():
x = keras.Input((2,))
y = keras.layers.Dense(1)(x)
fn_model = keras.models.Model(x, y)
@@ -1164,7 +1207,7 @@ class TestGeneratorMethods(test.TestCase):
w = arr_sample_weights[start: end]
yield x, y, w
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(1, input_shape=(2,)))
model.compile(
@@ -1201,7 +1244,7 @@ class TestGeneratorMethods(test.TestCase):
while 1:
yield 0
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(1, input_shape=(2,)))
model.compile(loss='mse', optimizer='sgd')
@@ -1259,7 +1302,7 @@ class TestGeneratorMethods(test.TestCase):
w = arr_sample_weights[start: end]
yield x, y, w
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(1, input_shape=(2,)))
model.compile(loss='mse', optimizer='sgd')
@@ -1317,7 +1360,7 @@ class TestTrainingUtils(test.TestCase):
class TestTrainingWithDataTensors(test.TestCase):
def test_training_and_eval_methods_on_symbolic_tensors_single_io(self):
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
@@ -1357,7 +1400,7 @@ class TestTrainingWithDataTensors(test.TestCase):
validation_data=(inputs, targets), validation_steps=2)
def test_training_and_eval_methods_on_symbolic_tensors_multi_io(self):
- with self.test_session():
+ with self.cached_session():
a = keras.layers.Input(shape=(3,), name='input_a')
b = keras.layers.Input(shape=(3,), name='input_b')
@@ -1458,16 +1501,17 @@ class TestTrainingWithDataTensors(test.TestCase):
by only passing them data for the placeholder inputs
in the model.
"""
- with self.test_session():
+ with self.cached_session():
input_a_np = np.random.random((10, 3))
input_b_np = np.random.random((10, 3))
output_a_np = np.random.random((10, 4))
output_b_np = np.random.random((10, 3))
- a = keras.Input(
- tensor=keras.backend.variables_module.Variable(input_a_np,
- dtype='float32'))
+ input_v = keras.backend.variables_module.Variable(
+ input_a_np, dtype='float32')
+ self.evaluate(variables_lib.variables_initializer([input_v]))
+ a = keras.Input(tensor=input_v)
b = keras.Input(shape=(3,), name='input_b')
a_2 = keras.layers.Dense(4, name='dense_1')(a)
@@ -1512,9 +1556,8 @@ class TestTrainingWithDataTensors(test.TestCase):
# Now test a model with a single input
# i.e. we don't pass any data to fit the model.
- a = keras.Input(
- tensor=keras.backend.variables_module.Variable(input_a_np,
- dtype='float32'))
+ self.evaluate(variables_lib.variables_initializer([input_v]))
+ a = keras.Input(tensor=input_v)
a_2 = keras.layers.Dense(4, name='dense_1')(a)
a_2 = keras.layers.Dropout(0.5, name='dropout')(a_2)
model = keras.models.Model(a, a_2)
@@ -1552,9 +1595,8 @@ class TestTrainingWithDataTensors(test.TestCase):
# Same, without learning phase
# i.e. we don't pass any data to fit the model.
- a = keras.Input(
- tensor=keras.backend.variables_module.Variable(input_a_np,
- dtype='float32'))
+ self.evaluate(variables_lib.variables_initializer([input_v]))
+ a = keras.Input(tensor=input_v)
a_2 = keras.layers.Dense(4, name='dense_1')(a)
model = keras.models.Model(a, a_2)
model.summary()
@@ -1590,7 +1632,7 @@ class TestTrainingWithDataTensors(test.TestCase):
self.assertEqual(out.shape, (10 * 3, 4))
def test_model_with_partial_loss(self):
- with self.test_session():
+ with self.cached_session():
a = keras.Input(shape=(3,), name='input_a')
a_2 = keras.layers.Dense(4, name='dense_1')(a)
dp = keras.layers.Dropout(0.5, name='dropout')
@@ -1631,7 +1673,7 @@ class TestTrainingWithDataTensors(test.TestCase):
_ = model.evaluate(input_a_np, [output_a_np])
def test_model_with_external_loss(self):
- with self.test_session():
+ with self.cached_session():
# None loss, only regularization loss.
a = keras.Input(shape=(3,), name='input_a')
a_2 = keras.layers.Dense(4, name='dense_1',
@@ -1677,9 +1719,10 @@ class TestTrainingWithDataTensors(test.TestCase):
out = model.evaluate(input_a_np, None)
# Test model with no external data at all.
- a = keras.Input(
- tensor=keras.backend.variables_module.Variable(input_a_np,
- dtype='float32'))
+ input_v = keras.backend.variables_module.Variable(
+ input_a_np, dtype='float32')
+ self.evaluate(variables_lib.variables_initializer([input_v]))
+ a = keras.Input(tensor=input_v)
a_2 = keras.layers.Dense(4, name='dense_1')(a)
a_2 = keras.layers.Dropout(0.5, name='dropout')(a_2)
model = keras.models.Model(a, a_2)
@@ -1720,9 +1763,8 @@ class TestTrainingWithDataTensors(test.TestCase):
self.assertEqual(out.shape, (10 * 3, 4))
# Test multi-output model with no external data at all.
- a = keras.Input(
- tensor=keras.backend.variables_module.Variable(input_a_np,
- dtype='float32'))
+ self.evaluate(variables_lib.variables_initializer([input_v]))
+ a = keras.Input(tensor=input_v)
a_1 = keras.layers.Dense(4, name='dense_1')(a)
a_2 = keras.layers.Dropout(0.5, name='dropout')(a_1)
model = keras.models.Model(a, [a_1, a_2])
@@ -1761,7 +1803,7 @@ class TestTrainingWithDataTensors(test.TestCase):
self.assertEqual(out[1].shape, (10 * 3, 4))
def test_target_tensors(self):
- with self.test_session():
+ with self.cached_session():
# single-output, as list
model = keras.models.Sequential()
model.add(keras.layers.Dense(4, input_shape=(4,), name='dense'))
@@ -1822,7 +1864,7 @@ class TestTrainingWithDataTensors(test.TestCase):
sample_weight={'dense_a': np.random.random((10,))})
def test_model_custom_target_tensors(self):
- with self.test_session():
+ with self.cached_session():
a = keras.Input(shape=(3,), name='input_a')
b = keras.Input(shape=(3,), name='input_b')
@@ -1886,223 +1928,235 @@ class TestTrainingWithDatasetIterators(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_training_and_eval_methods_on_iterators_single_io(self):
- with self.test_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
-
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- loss = 'mse'
- metrics = ['mae', metrics_module.CategoricalAccuracy()]
- model.compile(optimizer, loss, metrics=metrics)
-
- inputs = np.zeros((10, 3))
- targets = np.zeros((10, 4))
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
- iterator = dataset.make_one_shot_iterator()
-
- model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1)
- model.evaluate(iterator, steps=2, verbose=1)
- model.predict(iterator, steps=2)
- model.train_on_batch(iterator)
- model.test_on_batch(iterator)
- model.predict_on_batch(iterator)
-
- # Test with validation data
+ model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ metrics = ['mae', metrics_module.CategoricalAccuracy()]
+ model.compile(optimizer, loss, metrics=metrics)
+
+ inputs = np.zeros((10, 3))
+ targets = np.zeros((10, 4))
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+ iterator = dataset.make_one_shot_iterator()
+
+ model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1)
+ model.evaluate(iterator, steps=2, verbose=1)
+ model.predict(iterator, steps=2)
+ model.train_on_batch(iterator)
+ model.test_on_batch(iterator)
+ model.predict_on_batch(iterator)
+
+ # Test with validation data
+ model.fit(iterator,
+ epochs=1, steps_per_epoch=2, verbose=0,
+ validation_data=iterator, validation_steps=2)
+ # Test with validation split
+ with self.assertRaisesRegexp(
+ ValueError, '`validation_split` argument is not supported '
+ 'when input `x` is a dataset or a dataset iterator'):
model.fit(iterator,
epochs=1, steps_per_epoch=2, verbose=0,
- validation_data=iterator, validation_steps=2)
- # Test with validation split
- with self.assertRaisesRegexp(
- ValueError, '`validation_split` argument is not supported '
- 'when input `x` is a dataset or a dataset iterator'):
- model.fit(iterator,
- epochs=1, steps_per_epoch=2, verbose=0,
- validation_split=0.5, validation_steps=2)
-
- # Test with sample weight.
- sample_weight = np.random.random((10,))
- with self.assertRaisesRegexp(
- ValueError, '`sample_weight` argument is not supported '
- 'when input `x` is a dataset or a dataset iterator'):
- model.fit(
- iterator,
- epochs=1,
- steps_per_epoch=2,
- verbose=0,
- sample_weight=sample_weight)
+ validation_split=0.5, validation_steps=2)
- # Test invalid usage
- with self.assertRaisesRegexp(ValueError,
- 'you should not specify a target'):
- model.fit(iterator, iterator,
- epochs=1, steps_per_epoch=2, verbose=0)
+ # Test with sample weight.
+ sample_weight = np.random.random((10,))
+ with self.assertRaisesRegexp(
+ ValueError, '`sample_weight` argument is not supported '
+ 'when input `x` is a dataset or a dataset iterator'):
+ model.fit(
+ iterator,
+ epochs=1,
+ steps_per_epoch=2,
+ verbose=0,
+ sample_weight=sample_weight)
- with self.assertRaisesRegexp(
- ValueError, 'you should specify the `steps_per_epoch` argument'):
- model.fit(iterator, epochs=1, verbose=0)
- with self.assertRaisesRegexp(ValueError,
- 'you should specify the `steps` argument'):
- model.evaluate(iterator, verbose=0)
- with self.assertRaisesRegexp(ValueError,
- 'you should specify the `steps` argument'):
- model.predict(iterator, verbose=0)
+ # Test invalid usage
+ with self.assertRaisesRegexp(ValueError,
+ 'you should not specify a target'):
+ model.fit(iterator, iterator,
+ epochs=1, steps_per_epoch=2, verbose=0)
+
+ with self.assertRaisesRegexp(
+ ValueError, 'you should specify the `steps_per_epoch` argument'):
+ model.fit(iterator, epochs=1, verbose=0)
+ with self.assertRaisesRegexp(ValueError,
+ 'you should specify the `steps` argument'):
+ model.evaluate(iterator, verbose=0)
+ with self.assertRaisesRegexp(ValueError,
+ 'you should specify the `steps` argument'):
+ model.predict(iterator, verbose=0)
@tf_test_util.run_in_graph_and_eager_modes
def test_get_next_op_created_once(self):
- with self.test_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
-
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- loss = 'mse'
- metrics = ['mae']
- model.compile(optimizer, loss, metrics=metrics)
-
- inputs = np.zeros((10, 3))
- targets = np.zeros((10, 4))
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
- iterator = dataset.make_one_shot_iterator()
-
- model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1)
- # Finalize graph to make sure we are not appending another iterator
- # get_next op in the graph.
- ops.get_default_graph().finalize()
- model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1)
+ model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ metrics = ['mae']
+ model.compile(optimizer, loss, metrics=metrics)
+
+ inputs = np.zeros((10, 3))
+ targets = np.zeros((10, 4))
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+ iterator = dataset.make_one_shot_iterator()
+
+ model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1)
+ # Finalize graph to make sure we are not appending another iterator
+ # get_next op in the graph.
+ ops.get_default_graph().finalize()
+ model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1)
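+
+    # finalize() is what gives this test its teeth: a finalized graph rejects
+    # new ops, so the second fit() can only succeed if the get_next op from
+    # the first call was reused. A standalone sketch of the mechanism:
+    #
+    #   g = ops.Graph()
+    #   with g.as_default():
+    #     constant_op.constant(1)
+    #     g.finalize()
+    #     constant_op.constant(2)  # raises RuntimeError: graph is finalized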
@tf_test_util.run_in_graph_and_eager_modes
def test_iterators_running_out_of_data(self):
- with self.test_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ metrics = ['mae']
+ model.compile(optimizer, loss, metrics=metrics)
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- loss = 'mse'
- metrics = ['mae']
- model.compile(optimizer, loss, metrics=metrics)
+ inputs = np.zeros((10, 3))
+ targets = np.zeros((10, 4))
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(2)
+ dataset = dataset.batch(10)
+ iterator = dataset.make_one_shot_iterator()
- inputs = np.zeros((10, 3))
- targets = np.zeros((10, 4))
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(2)
- dataset = dataset.batch(10)
- iterator = dataset.make_one_shot_iterator()
-
- with test.mock.patch.object(logging, 'warning') as mock_log:
- model.fit(iterator, epochs=1, steps_per_epoch=3, verbose=0)
- self.assertRegexpMatches(
- str(mock_log.call_args),
- 'dataset iterator ran out of data')
+ with test.mock.patch.object(logging, 'warning') as mock_log:
+ model.fit(iterator, epochs=1, steps_per_epoch=3, verbose=0)
+ self.assertRegexpMatches(
+ str(mock_log.call_args),
+ 'dataset iterator ran out of data')
class TestTrainingWithDataset(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_calling_model_on_same_dataset(self):
- with self.test_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
-
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- loss = 'mse'
- metrics = ['mae']
- model.compile(optimizer, loss, metrics=metrics)
-
- inputs = np.zeros((10, 3))
- targets = np.zeros((10, 4))
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
-
- # Call fit with validation data
- model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
- validation_data=dataset, validation_steps=2)
- # Finalize the graph to make sure new ops aren't added when calling on the
- # same dataset
- ops.get_default_graph().finalize()
- model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
- validation_data=dataset, validation_steps=2)
+ model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ metrics = ['mae']
+ model.compile(optimizer, loss, metrics=metrics)
+
+ inputs = np.zeros((10, 3))
+ targets = np.zeros((10, 4))
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+
+ # Call fit with validation data
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+ validation_data=dataset, validation_steps=2)
+ # Finalize the graph to make sure new ops aren't added when calling on the
+ # same dataset
+ ops.get_default_graph().finalize()
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+ validation_data=dataset, validation_steps=2)
@tf_test_util.run_in_graph_and_eager_modes
def test_training_and_eval_methods_on_dataset(self):
- with self.test_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
-
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- loss = 'mse'
- metrics = ['mae', metrics_module.CategoricalAccuracy()]
- model.compile(optimizer, loss, metrics=metrics)
+ model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ metrics = ['mae', metrics_module.CategoricalAccuracy()]
+ model.compile(optimizer, loss, metrics=metrics)
+
+ inputs = np.zeros((10, 3))
+ targets = np.zeros((10, 4))
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
+ model.evaluate(dataset, steps=2, verbose=1)
+ model.predict(dataset, steps=2)
+ model.train_on_batch(dataset)
+ model.predict_on_batch(dataset)
+
+ # Test with validation data
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+ validation_data=dataset, validation_steps=2)
+
+ # Test with validation split
+ with self.assertRaisesRegexp(
+ ValueError, '`validation_split` argument is not supported '
+ 'when input `x` is a dataset or a dataset iterator'):
+ model.fit(dataset,
+ epochs=1, steps_per_epoch=2, verbose=0,
+ validation_split=0.5, validation_steps=2)
- inputs = np.zeros((10, 3))
- targets = np.zeros((10, 4))
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
+ # Test with sample weight.
+ sample_weight = np.random.random((10,))
+ with self.assertRaisesRegexp(
+ ValueError, '`sample_weight` argument is not supported '
+ 'when input `x` is a dataset or a dataset iterator'):
+ model.fit(
+ dataset,
+ epochs=1,
+ steps_per_epoch=2,
+ verbose=0,
+ sample_weight=sample_weight)
- model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
- model.evaluate(dataset, steps=2, verbose=1)
- model.predict(dataset, steps=2)
- model.train_on_batch(dataset)
- model.predict_on_batch(dataset)
+ # Test invalid usage
+ with self.assertRaisesRegexp(ValueError,
+ 'you should not specify a target'):
+ model.fit(dataset, dataset,
+ epochs=1, steps_per_epoch=2, verbose=0)
+
+ with self.assertRaisesRegexp(
+ ValueError, 'you should specify the `steps_per_epoch` argument'):
+ model.fit(dataset, epochs=1, verbose=0)
+ with self.assertRaisesRegexp(ValueError,
+ 'you should specify the `steps` argument'):
+ model.evaluate(dataset, verbose=0)
+ with self.assertRaisesRegexp(ValueError,
+ 'you should specify the `steps` argument'):
+ model.predict(dataset, verbose=0)
- # Test with validation data
- model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
- validation_data=dataset, validation_steps=2)
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_dataset_with_sample_weights(self):
+ model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ metrics = ['mae', metrics_module.CategoricalAccuracy()]
+ model.compile(optimizer, loss, metrics=metrics)
+
+ inputs = np.zeros((10, 3), np.float32)
+ targets = np.zeros((10, 4), np.float32)
+ sample_weights = np.ones((10), np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets,
+ sample_weights))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
+ model.evaluate(dataset, steps=2, verbose=1)
+ model.predict(dataset, steps=2)
+ model.train_on_batch(dataset)
+ model.predict_on_batch(dataset)
- # Test with validation split
- with self.assertRaisesRegexp(
- ValueError, '`validation_split` argument is not supported '
- 'when input `x` is a dataset or a dataset iterator'):
- model.fit(dataset,
- epochs=1, steps_per_epoch=2, verbose=0,
- validation_split=0.5, validation_steps=2)
-
- # Test with sample weight.
- sample_weight = np.random.random((10,))
- with self.assertRaisesRegexp(
- ValueError, '`sample_weight` argument is not supported '
- 'when input `x` is a dataset or a dataset iterator'):
- model.fit(
- dataset,
- epochs=1,
- steps_per_epoch=2,
- verbose=0,
- sample_weight=sample_weight)
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_dataset_with_sparse_labels(self):
+ model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'sparse_categorical_crossentropy'
+ model.compile(optimizer, loss)
- # Test invalid usage
- with self.assertRaisesRegexp(ValueError,
- 'you should not specify a target'):
- model.fit(dataset, dataset,
- epochs=1, steps_per_epoch=2, verbose=0)
+ inputs = np.zeros((10, 3))
+ targets = np.random.randint(0, 4, size=10, dtype=np.int32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
- with self.assertRaisesRegexp(
- ValueError, 'you should specify the `steps_per_epoch` argument'):
- model.fit(dataset, epochs=1, verbose=0)
- with self.assertRaisesRegexp(ValueError,
- 'you should specify the `steps` argument'):
- model.evaluate(dataset, verbose=0)
- with self.assertRaisesRegexp(ValueError,
- 'you should specify the `steps` argument'):
- model.predict(dataset, verbose=0)
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
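+
+    # This works because sparse_categorical_crossentropy consumes integer
+    # class ids directly; the dense categorical_crossentropy used elsewhere
+    # in this file needs one-hot targets. A quick sketch of the two layouts:
+    #
+    #   int_targets = np.random.randint(0, 4, size=10)        # shape (10,)
+    #   one_hot = keras.utils.to_categorical(int_targets, 4)  # shape (10, 4)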
def test_dataset_input_shape_validation(self):
- with self.test_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
-
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- loss = 'mse'
- model.compile(optimizer, loss)
+ with self.cached_session():
+ model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+ model.compile(optimizer=RMSPropOptimizer(learning_rate=0.001), loss='mse')
# User forgets to batch the dataset
inputs = np.zeros((10, 3))
@@ -2110,8 +2164,10 @@ class TestTrainingWithDataset(test.TestCase):
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
- with self.assertRaisesRegexp(ValueError,
- 'expected input to have 2 dimensions'):
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'expected (.*?) to have shape \(3,\) but got array with shape \(1,\)'
+ ):
model.train_on_batch(dataset)
# Wrong input shape
@@ -2122,7 +2178,7 @@ class TestTrainingWithDataset(test.TestCase):
dataset = dataset.batch(10)
with self.assertRaisesRegexp(ValueError,
- 'expected input to have shape'):
+ r'expected (.*?) to have shape \(3,\)'):
model.train_on_batch(dataset)
@@ -2153,138 +2209,131 @@ class TestTrainingWithMetrics(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_metrics_correctness(self):
- with self.test_session():
- model = keras.Sequential()
- model.add(
- keras.layers.Dense(
- 3, activation='relu', input_dim=4, kernel_initializer='ones'))
- model.add(
- keras.layers.Dense(
- 1, activation='sigmoid', kernel_initializer='ones'))
- model.compile(
- loss='mae',
- metrics=['accuracy', metrics_module.BinaryAccuracy()],
- optimizer=RMSPropOptimizer(learning_rate=0.001))
-
- # verify correctness of stateful and stateless metrics.
- x = np.ones((100, 4))
- y = np.ones((100, 1))
- outs = model.evaluate(x, y)
- self.assertEqual(outs[1], 1.)
- self.assertEqual(outs[2], 1.)
-
- y = np.zeros((100, 1))
- outs = model.evaluate(x, y)
- self.assertEqual(outs[1], 0.)
- self.assertEqual(outs[2], 0.)
+ model = keras.Sequential()
+ model.add(
+ keras.layers.Dense(
+ 3, activation='relu', input_dim=4, kernel_initializer='ones'))
+ model.add(
+ keras.layers.Dense(
+ 1, activation='sigmoid', kernel_initializer='ones'))
+ model.compile(
+ loss='mae',
+ metrics=['accuracy', metrics_module.BinaryAccuracy()],
+ optimizer=RMSPropOptimizer(learning_rate=0.001))
+
+ # verify correctness of stateful and stateless metrics.
+ x = np.ones((100, 4))
+ y = np.ones((100, 1))
+ outs = model.evaluate(x, y)
+ self.assertEqual(outs[1], 1.)
+ self.assertEqual(outs[2], 1.)
+
+ y = np.zeros((100, 1))
+ outs = model.evaluate(x, y)
+ self.assertEqual(outs[1], 0.)
+ self.assertEqual(outs[2], 0.)
@tf_test_util.run_in_graph_and_eager_modes
def test_metrics_correctness_with_iterator(self):
- with self.test_session():
- model = keras.Sequential()
- model.add(
- keras.layers.Dense(
- 8, activation='relu', input_dim=4, kernel_initializer='ones'))
- model.add(
- keras.layers.Dense(
- 1, activation='sigmoid', kernel_initializer='ones'))
- model.compile(
- loss='binary_crossentropy',
- metrics=['accuracy', metrics_module.BinaryAccuracy()],
- optimizer=RMSPropOptimizer(learning_rate=0.001))
-
- np.random.seed(123)
- x = np.random.randint(10, size=(100, 4)).astype(np.float32)
- y = np.random.randint(2, size=(100, 1)).astype(np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
- dataset = dataset.batch(10)
- iterator = dataset.make_one_shot_iterator()
- outs = model.evaluate(iterator, steps=10)
- self.assertEqual(np.around(outs[1], decimals=1), 0.5)
- self.assertEqual(np.around(outs[2], decimals=1), 0.5)
-
- y = np.zeros((100, 1), dtype=np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
- dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
- iterator = dataset.make_one_shot_iterator()
- outs = model.evaluate(iterator, steps=10)
- self.assertEqual(outs[1], 0.)
- self.assertEqual(outs[2], 0.)
+ model = keras.Sequential()
+ model.add(
+ keras.layers.Dense(
+ 8, activation='relu', input_dim=4, kernel_initializer='ones'))
+ model.add(
+ keras.layers.Dense(
+ 1, activation='sigmoid', kernel_initializer='ones'))
+ model.compile(
+ loss='binary_crossentropy',
+ metrics=['accuracy', metrics_module.BinaryAccuracy()],
+ optimizer=RMSPropOptimizer(learning_rate=0.001))
+
+ np.random.seed(123)
+ x = np.random.randint(10, size=(100, 4)).astype(np.float32)
+ y = np.random.randint(2, size=(100, 1)).astype(np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
+ dataset = dataset.batch(10)
+ iterator = dataset.make_one_shot_iterator()
+ outs = model.evaluate(iterator, steps=10)
+ self.assertEqual(np.around(outs[1], decimals=1), 0.5)
+ self.assertEqual(np.around(outs[2], decimals=1), 0.5)
+
+ y = np.zeros((100, 1), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+ iterator = dataset.make_one_shot_iterator()
+ outs = model.evaluate(iterator, steps=10)
+ self.assertEqual(outs[1], 0.)
+ self.assertEqual(outs[2], 0.)
@tf_test_util.run_in_graph_and_eager_modes
def test_metrics_correctness_with_weighted_metrics(self):
- with self.test_session():
- np.random.seed(1337)
- x = np.array([[[1.], [1.]], [[0.], [0.]]])
- model = keras.models.Sequential()
- model.add(
- keras.layers.TimeDistributed(
- keras.layers.Dense(1, kernel_initializer='ones'),
- input_shape=(2, 1)))
- model.compile(
- RMSPropOptimizer(learning_rate=0.001),
- loss='mse',
- sample_weight_mode='temporal',
- weighted_metrics=['accuracy',
- metrics_module.BinaryAccuracy()])
- y = np.array([[[1.], [1.]], [[1.], [1.]]])
-
- outs = model.evaluate(x, y)
- self.assertEqual(outs, [0.5, 0.5, 0.5])
-
- w = np.array([[0., 0.], [0., 0.]])
- outs = model.evaluate(x, y, sample_weight=w)
- self.assertEqual(outs, [0., 0., 0.])
-
- w = np.array([[3., 4.], [1., 2.]])
- outs = model.evaluate(x, y, sample_weight=w)
- self.assertArrayNear(outs, [0.3, 0.7, 0.7], .001)
+ np.random.seed(1337)
+ x = np.array([[[1.], [1.]], [[0.], [0.]]])
+ model = keras.models.Sequential()
+ model.add(
+ keras.layers.TimeDistributed(
+ keras.layers.Dense(1, kernel_initializer='ones'),
+ input_shape=(2, 1)))
+ model.compile(
+ RMSPropOptimizer(learning_rate=0.001),
+ loss='mse',
+ sample_weight_mode='temporal',
+ weighted_metrics=['accuracy',
+ metrics_module.BinaryAccuracy()])
+ y = np.array([[[1.], [1.]], [[1.], [1.]]])
+
+ outs = model.evaluate(x, y)
+ self.assertEqual(outs, [0.5, 0.5, 0.5])
+
+ w = np.array([[0., 0.], [0., 0.]])
+ outs = model.evaluate(x, y, sample_weight=w)
+ self.assertEqual(outs, [0., 0., 0.])
+
+ w = np.array([[3., 4.], [1., 2.]])
+ outs = model.evaluate(x, y, sample_weight=w)
+ self.assertArrayNear(outs, [0.3, 0.7, 0.7], .001)
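+
+    # The expected [0.3, 0.7, 0.7] can be checked by hand: with a ones kernel
+    # the model predicts [[1, 1], [0, 0]], the targets are all ones, and the
+    # loss and weighted metrics reduce as sum(w * v) / sum(w). Under that
+    # assumption:
+    #
+    #   pred = np.array([[1., 1.], [0., 0.]])
+    #   target = np.ones((2, 2))
+    #   sq_err = (pred - target) ** 2            # [[0, 0], [1, 1]]
+    #   acc = (pred == target).astype(float)     # [[1, 1], [0, 0]]
+    #   np.sum(w * sq_err) / np.sum(w)           # 0.3 -> weighted mse loss
+    #   np.sum(w * acc) / np.sum(w)              # 0.7 -> weighted accuracies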
@tf_test_util.run_in_graph_and_eager_modes
def test_metric_state_reset_between_fit_and_evaluate(self):
- with self.test_session():
- model = keras.Sequential()
- model.add(keras.layers.Dense(3, activation='relu', input_dim=4))
- model.add(keras.layers.Dense(1, activation='sigmoid'))
- acc_obj = metrics_module.BinaryAccuracy()
- model.compile(
- loss='mae',
- metrics=[acc_obj],
- optimizer=RMSPropOptimizer(learning_rate=0.001))
-
- x_train = np.random.random((100, 4))
- y_train = np.random.random((100, 1))
- model.fit(x_train, y_train, batch_size=5, epochs=2)
- self.assertEqual(self.evaluate(acc_obj.count), 100)
-
- x_test = np.random.random((10, 4))
- y_test = np.random.random((10, 1))
- model.evaluate(x_test, y_test, batch_size=5)
- self.assertEqual(self.evaluate(acc_obj.count), 10)
+ model = keras.Sequential()
+ model.add(keras.layers.Dense(3, activation='relu', input_dim=4))
+ model.add(keras.layers.Dense(1, activation='sigmoid'))
+ acc_obj = metrics_module.BinaryAccuracy()
+ model.compile(
+ loss='mae',
+ metrics=[acc_obj],
+ optimizer=RMSPropOptimizer(learning_rate=0.001))
+
+ x_train = np.random.random((100, 4))
+ y_train = np.random.random((100, 1))
+ model.fit(x_train, y_train, batch_size=5, epochs=2)
+ self.assertEqual(self.evaluate(acc_obj.count), 100)
+
+ x_test = np.random.random((10, 4))
+ y_test = np.random.random((10, 1))
+ model.evaluate(x_test, y_test, batch_size=5)
+ self.assertEqual(self.evaluate(acc_obj.count), 10)
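+
+    # The count checks rely on stateful metric objects accumulating across
+    # batches until they are reset at the fit/evaluate boundary. A small
+    # sketch of that lifecycle (assuming eager execution so result() prints
+    # directly):
+    #
+    #   acc = metrics_module.BinaryAccuracy()
+    #   acc.update_state([1., 1.], [0.9, 0.2])  # one correct of two
+    #   acc.result()                            # 0.5, count is now 2
+    #   acc.reset_states()                      # count back to 0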
@tf_test_util.run_in_graph_and_eager_modes
def test_invalid_metrics(self):
num_classes = 5
input_dim = 5
- with self.test_session():
- model = keras.models.Sequential()
- model.add(
- keras.layers.Dense(10, activation='relu', input_shape=(input_dim,)))
- model.add(keras.layers.Dense(num_classes, activation='softmax'))
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=10, num_classes=num_classes, input_dim=input_dim)
- with self.assertRaisesRegexp(
- TypeError, 'Type of `metrics` argument not understood. '
- 'Expected a list or dictionary, found: '):
- model.compile(
- RMSPropOptimizer(learning_rate=0.001),
- loss='categorical_crossentropy',
- metrics=metrics_module.CategoricalAccuracy())
+ with self.assertRaisesRegexp(
+ TypeError, 'Type of `metrics` argument not understood. '
+ 'Expected a list or dictionary, found: '):
+ model.compile(
+ RMSPropOptimizer(learning_rate=0.001),
+ loss='categorical_crossentropy',
+ metrics=metrics_module.CategoricalAccuracy())
@tf_test_util.run_in_graph_and_eager_modes
def test_metrics_masking(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
model = keras.models.Sequential()
model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1)))
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index f94697c913..898e9223cb 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -22,18 +22,22 @@ import copy
import math
import numpy as np
+import six
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import losses
from tensorflow.python.keras import metrics as metrics_module
+from tensorflow.python.keras.engine import base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import weights_broadcast_ops
+from tensorflow.python.util import nest
def _map_nested(data, func):
@@ -210,10 +214,11 @@ def check_num_samples(ins,
def standardize_single_array(x):
if x is None:
return None
- elif tensor_util.is_tensor(x):
- return x
- elif x.ndim == 1:
- x = np.expand_dims(x, 1)
+ if x.shape is not None and len(x.shape) == 1:
+ if tensor_util.is_tensor(x):
+ return array_ops.expand_dims(x, axis=1)
+ else:
+ return np.expand_dims(x, 1)
return x
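 
 # The rewritten branch gives rank-1 inputs a trailing feature axis for both
 # numpy arrays and symbolic tensors, instead of returning tensors untouched.
 # A sketch of the resulting behavior (illustrative only, not part of this
 # patch):
 #
 #   np.expand_dims(np.arange(4), 1).shape              # (4,) -> (4, 1)
 #   array_ops.expand_dims(math_ops.range(4), axis=1)   # (4,) -> (4, 1)
 #   # rank-2 and higher inputs pass through unchanged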
@@ -245,7 +250,8 @@ def standardize_input_data(data,
ValueError: in case of improperly formatted user-provided data.
"""
if not names:
- if data is not None and hasattr(data, '__len__') and len(data):
+ if (data is not None and hasattr(data, '__len__') and len(data) and
+ not isinstance(data, dict)):
raise ValueError('Error when checking model ' + exception_prefix + ': '
'expected no data, but got:', data)
return []
@@ -341,7 +347,7 @@ def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
Raises:
ValueError: In case of invalid user-provided argument.
"""
- if x_weight is None or len(x_weight) == 0: # pylint: disable=g-explicit-length-test
+ if x_weight is None or (isinstance(x_weight, list) and len(x_weight) == 0): # pylint: disable=g-explicit-length-test
return [None for _ in output_names]
if len(output_names) == 1:
if isinstance(x_weight, list) and len(x_weight) == 1:
@@ -675,7 +681,8 @@ def standardize_weights(y,
'Expected sample_weight with rank '
'less than or equal to ' + str(len(y.shape)))
- if y.shape[:sample_weight.ndim] != sample_weight.shape:
+ if (not tensor_util.is_tensor(sample_weight) and
+ y.shape[:sample_weight.ndim] != sample_weight.shape):
raise ValueError(
'Found a sample_weight array with shape ' + str(sample_weight.shape) +
' for an input with shape ' + str(y.shape) + '. '
@@ -717,6 +724,8 @@ def has_symbolic_tensors(ls):
def has_tensors(ls):
if isinstance(ls, (list, tuple)):
return any(tensor_util.is_tensor(v) for v in ls)
+ if isinstance(ls, dict):
+ return any(tensor_util.is_tensor(v) for _, v in six.iteritems(ls))
return tensor_util.is_tensor(ls)
@@ -777,7 +786,9 @@ def validate_iterator_input(x, y, sample_weight, validation_split=None):
'Received: %s' % (x, y))
if sample_weight is not None:
raise ValueError('`sample_weight` argument is not supported when input '
- '`x` is a dataset or a dataset iterator. '
+                     '`x` is a dataset or a dataset iterator. Instead, you '
+                     'can provide sample_weight as the third element of your '
+ 'dataset, i.e. (inputs, targets, sample_weight). '
'Received: x=%s, sample_weight=%s' % (x, sample_weight))
if validation_split is not None and validation_split != 0.0:
raise ValueError(
@@ -825,6 +836,12 @@ def check_steps_argument(input_data, steps, steps_name):
return False
+def cast_single_tensor(x):
+ if tensor_util.is_tensor(x) and x.dtype.is_floating:
+ return math_ops.cast(x, dtype=K.floatx())
+ return x
+
+
def cast_if_floating_dtype(x):
"""Casts the given data tensors to the default floating point type.
@@ -842,13 +859,7 @@ def cast_if_floating_dtype(x):
raise RuntimeError(
'Please provide tensors for casting, got: {x}'.format(x=x))
- if isinstance(x, (list, tuple)):
- return [
- math_ops.cast(val, dtype=K.floatx())
- if tensor_util.is_tensor(val) and val.dtype.is_floating else val
- for val in x
- ]
- return math_ops.cast(x, dtype=K.floatx()) if x.dtype.is_floating else x
+ return nest.map_structure(cast_single_tensor, x)
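+
+# nest.map_structure applies the cast to every leaf of an arbitrarily nested
+# structure, which is why the explicit list/tuple handling above can be
+# dropped. A sketch of the idea, where cast_single stands in for the
+# cast_single_tensor helper (hedged: using the public tf.nest / tf.is_tensor
+# names, assuming a TF build that exposes them):
+#
+#   def cast_single(x):
+#     if tf.is_tensor(x) and x.dtype.is_floating:
+#       return tf.cast(x, tf.float32)
+#     return x  # ints, strings, numpy arrays pass through
+#
+#   data = {'a': tf.constant([1.], dtype=tf.float64),
+#           'b': [tf.constant([2], dtype=tf.int32)]}
+#   casted = tf.nest.map_structure(cast_single, data)
+#   # casted['a'].dtype == float32, casted['b'][0].dtype == int32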
def get_output_sample_weight_and_mode(skip_target_weighing_indices,
@@ -929,3 +940,103 @@ def prepare_sample_weights(output_names, sample_weight_mode,
sample_weights.append(weight)
sample_weight_modes.append(mode)
return sample_weights, sample_weight_modes
+
+
+# TODO(rohanj): This is a hack to avoid depending on feature_column and
+# thereby creating a cyclical dependency. Figure out a cleaner solution.
+def is_feature_layer(layer):
+ """Returns whether `layer` is a FeatureLayer or not."""
+ return getattr(layer, '_is_feature_layer', False)
+
+
+class ModelInputs(object):
+ """Encapsulates model inputs.
+
+ Allows for transforming model inputs while keeping the same structure.
+ """
+
+ def __init__(self, inputs):
+ self._inputs = inputs
+ self._is_dict = isinstance(self._inputs, dict)
+ self._is_single_input = not isinstance(self._inputs, (list, tuple, dict))
+ self._flattened_inputs = []
+ self._input_names = []
+ if isinstance(self._inputs, dict):
+ for k in sorted(self._inputs.keys()):
+ self._flattened_inputs.append(self._inputs[k])
+ self._input_names.append(k)
+ else:
+ self._flattened_inputs = nest.flatten(self._inputs)
+ self._input_names = [
+ 'input_%d' % (i + 1) for i in range(len(self._flattened_inputs))
+ ]
+ assert len(self._input_names) == len(self._flattened_inputs)
+
+ def get_input_names(self):
+ """Returns keys to name inputs by.
+
+    If the inputs were provided as a list, tuple or single entry, we make up
+    keys of the form 'input_%d'. For the dictionary case, we return a sorted
+    list of keys.
+ """
+ return self._input_names
+
+ def _get(self, return_single_as_list=False):
+ """Returns provided inputs, potentially transformed.
+
+    Inputs are returned in the same format they were provided, i.e. lists
+ are returned as lists, single entries as single entries (unless
+ `return_single_as_list` is true), dictionaries as dictionaries.
+
+ Args:
+ return_single_as_list: Returns a list of size 1 for single entry case.
+ """
+ if self._is_dict:
+ return dict(zip(self._input_names, self._flattened_inputs))
+ if self._is_single_input and not return_single_as_list:
+ return self._flattened_inputs[0]
+ return self._flattened_inputs
+
+ def get_input_values(self):
+ """Returns input values passed in."""
+ if context.executing_eagerly():
+ for i in range(len(self._flattened_inputs)):
+ v = self._flattened_inputs[i]
+ if tensor_util.is_tensor(v):
+ v = cast_single_tensor(v)
+ else:
+ v = ops.convert_to_tensor(v, dtype=K.floatx())
+ self._flattened_inputs[i] = v
+ return self._get(return_single_as_list=False)
+
+ def get_symbolic_inputs(self, return_single_as_list=False):
+ """Returns inputs to be set as self.inputs for a model."""
+ for i in range(len(self._flattened_inputs)):
+ k = self._input_names[i]
+ v = self._flattened_inputs[i]
+ if context.executing_eagerly():
+ v = base_layer.DeferredTensor(
+ shape=(None for _ in v.shape), dtype=v.dtype)
+ else:
+ if isinstance(v, list):
+ v = np.asarray(v)
+ if v.ndim == 1:
+ v = np.expand_dims(v, 1)
+ if isinstance(v, (np.ndarray)):
+ # We fix the placeholder shape except the batch size.
+ # This is suboptimal, but it is the best we can do with the info
+ # we have. The user should call `model._set_inputs(placeholders)`
+ # to specify custom placeholders if the need arises.
+ shape = (None,) + v.shape[1:]
+ v = K.placeholder(shape=shape, name=k)
+ self._flattened_inputs[i] = v
+
+ return self._get(return_single_as_list)
+
+ def as_dict(self):
+ """An iterable over a dictionary version of inputs."""
+ for i in range(len(self._flattened_inputs)):
+ yield self._input_names[i], self._flattened_inputs[i]
+
+ def as_list(self):
+ """Returning the inputs as a list."""
+ return self._flattened_inputs
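+
+# Taken together, ModelInputs gives training code a single path for
+# single-array, list and dict inputs: flattened values plus stable names.
+# A hypothetical usage sketch, mirroring the unit tests added below:
+#
+#   mi = training_utils.ModelInputs({'b': np.ones(10), 'a': np.ones(20)})
+#   mi.get_input_names()      # ['a', 'b'] -- dict keys in sorted order
+#   mi.get_input_values()     # dict keyed by those names
+#   mi.get_symbolic_inputs()  # graph mode: placeholders with batch dim None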
diff --git a/tensorflow/python/keras/engine/training_utils_test.py b/tensorflow/python/keras/engine/training_utils_test.py
index 297a1ae494..e777cb6db3 100644
--- a/tensorflow/python/keras/engine/training_utils_test.py
+++ b/tensorflow/python/keras/engine/training_utils_test.py
@@ -20,8 +20,11 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.eager import context
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.platform import test
@@ -146,5 +149,91 @@ class TrainingUtilTest(test.TestCase):
self.assertEquals(any_true, False)
+class ModelInputsTest(test.TestCase):
+
+ def test_single_thing(self):
+ a = np.ones(10)
+ model_inputs = training_utils.ModelInputs(a)
+ self.assertEquals(['input_1'], model_inputs.get_input_names())
+ vals = model_inputs.get_input_values()
+ self.assertAllEqual(np.ones(10), vals)
+ self.assertFalse(tensor_util.is_tensor(vals))
+ vals = model_inputs.get_symbolic_inputs()
+ self.assertTrue(tensor_util.is_tensor(vals))
+ vals = model_inputs.get_symbolic_inputs(return_single_as_list=True)
+ self.assertEquals(1, len(vals))
+ self.assertTrue(tensor_util.is_tensor(vals[0]))
+
+ def test_single_thing_eager(self):
+ with context.eager_mode():
+ a = np.ones(10)
+ model_inputs = training_utils.ModelInputs(a)
+ self.assertEquals(['input_1'], model_inputs.get_input_names())
+ vals = model_inputs.get_input_values()
+ self.assertAllEqual(np.ones(10), vals)
+ self.assertTrue(tensor_util.is_tensor(vals))
+ vals = model_inputs.get_symbolic_inputs()
+ self.assertTrue(isinstance(vals, base_layer.DeferredTensor))
+ vals = model_inputs.get_symbolic_inputs(return_single_as_list=True)
+ self.assertEquals(1, len(vals))
+ self.assertTrue(isinstance(vals[0], base_layer.DeferredTensor))
+
+ def test_list(self):
+ a = [np.ones(10), np.ones(20)]
+ model_inputs = training_utils.ModelInputs(a)
+ self.assertEquals(['input_1', 'input_2'], model_inputs.get_input_names())
+ vals = model_inputs.get_input_values()
+ self.assertEqual(2, len(vals))
+ self.assertAllEqual(np.ones(10), vals[0])
+ self.assertAllEqual(np.ones(20), vals[1])
+ self.assertFalse(tensor_util.is_tensor(vals[0]))
+ self.assertFalse(tensor_util.is_tensor(vals[1]))
+ vals = model_inputs.get_symbolic_inputs()
+ self.assertTrue(tensor_util.is_tensor(vals[0]))
+ self.assertTrue(tensor_util.is_tensor(vals[1]))
+
+ def test_list_eager(self):
+ with context.eager_mode():
+ a = [np.ones(10), np.ones(20)]
+ model_inputs = training_utils.ModelInputs(a)
+ self.assertEquals(['input_1', 'input_2'], model_inputs.get_input_names())
+ vals = model_inputs.get_input_values()
+ self.assertEqual(2, len(vals))
+ self.assertAllEqual(np.ones(10), vals[0])
+ self.assertAllEqual(np.ones(20), vals[1])
+ self.assertTrue(tensor_util.is_tensor(vals[0]))
+ self.assertTrue(tensor_util.is_tensor(vals[1]))
+ vals = model_inputs.get_symbolic_inputs()
+ self.assertTrue(isinstance(vals[0], base_layer.DeferredTensor))
+ self.assertTrue(isinstance(vals[1], base_layer.DeferredTensor))
+
+ def test_dict(self):
+ a = {'b': np.ones(10), 'a': np.ones(20)}
+ model_inputs = training_utils.ModelInputs(a)
+ self.assertEquals(['a', 'b'], model_inputs.get_input_names())
+ vals = model_inputs.get_input_values()
+ self.assertAllEqual(np.ones(20), vals['a'])
+ self.assertAllEqual(np.ones(10), vals['b'])
+ self.assertFalse(tensor_util.is_tensor(vals['a']))
+ self.assertFalse(tensor_util.is_tensor(vals['b']))
+ vals = model_inputs.get_symbolic_inputs()
+ self.assertTrue(tensor_util.is_tensor(vals['a']))
+ self.assertTrue(tensor_util.is_tensor(vals['b']))
+
+ def test_dict_eager(self):
+ with context.eager_mode():
+ a = {'b': np.ones(10), 'a': np.ones(20)}
+ model_inputs = training_utils.ModelInputs(a)
+ self.assertEquals(['a', 'b'], model_inputs.get_input_names())
+ vals = model_inputs.get_input_values()
+ self.assertAllEqual(np.ones(20), vals['a'])
+ self.assertAllEqual(np.ones(10), vals['b'])
+ self.assertTrue(tensor_util.is_tensor(vals['a']))
+ self.assertTrue(tensor_util.is_tensor(vals['b']))
+ vals = model_inputs.get_symbolic_inputs()
+ self.assertTrue(isinstance(vals['a'], base_layer.DeferredTensor))
+ self.assertTrue(isinstance(vals['b'], base_layer.DeferredTensor))
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/initializers.py b/tensorflow/python/keras/initializers.py
index b9d856efa8..cac78c44ca 100644
--- a/tensorflow/python/keras/initializers.py
+++ b/tensorflow/python/keras/initializers.py
@@ -20,14 +20,15 @@ from __future__ import print_function
import six
+from tensorflow.python.framework import dtypes
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
# These imports are brought in so that keras.initializers.deserialize
# has them available in module_objects.
from tensorflow.python.ops.init_ops import Constant
-from tensorflow.python.ops.init_ops import glorot_normal_initializer
-from tensorflow.python.ops.init_ops import glorot_uniform_initializer
+from tensorflow.python.ops.init_ops import GlorotNormal
+from tensorflow.python.ops.init_ops import GlorotUniform
from tensorflow.python.ops.init_ops import he_normal # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import he_uniform # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Identity
@@ -36,15 +37,84 @@ from tensorflow.python.ops.init_ops import lecun_normal # pylint: disable=unuse
from tensorflow.python.ops.init_ops import lecun_uniform # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Ones
from tensorflow.python.ops.init_ops import Orthogonal
-from tensorflow.python.ops.init_ops import RandomNormal
-from tensorflow.python.ops.init_ops import RandomUniform
-from tensorflow.python.ops.init_ops import TruncatedNormal
+from tensorflow.python.ops.init_ops import RandomNormal as TFRandomNormal
+from tensorflow.python.ops.init_ops import RandomUniform as TFRandomUniform
+from tensorflow.python.ops.init_ops import TruncatedNormal as TFTruncatedNormal
from tensorflow.python.ops.init_ops import VarianceScaling # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Zeros
from tensorflow.python.util.tf_export import tf_export
+@tf_export('keras.initializers.TruncatedNormal',
+ 'keras.initializers.truncated_normal')
+class TruncatedNormal(TFTruncatedNormal):
+ """Initializer that generates a truncated normal distribution.
+
+ These values are similar to values from a `random_normal_initializer`
+ except that values more than two standard deviations from the mean
+ are discarded and re-drawn. This is the recommended initializer for
+ neural network weights and filters.
+
+ Args:
+ mean: a python scalar or a scalar tensor. Mean of the random values to
+ generate. Defaults to 0.
+ stddev: a python scalar or a scalar tensor. Standard deviation of the random
+ values to generate. Defaults to 0.05.
+ seed: A Python integer. Used to create random seeds. See
+ `tf.set_random_seed` for behavior.
+ dtype: The data type. Only floating point types are supported.
+ """
+
+ def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32):
+ super(TruncatedNormal, self).__init__(
+ mean=mean, stddev=stddev, seed=seed, dtype=dtype)
+
+
+@tf_export('keras.initializers.RandomUniform', 'keras.initializers.uniform',
+ 'keras.initializers.random_uniform')
+class RandomUniform(TFRandomUniform):
+ """Initializer that generates tensors with a uniform distribution.
+
+ Args:
+ minval: A python scalar or a scalar tensor. Lower bound of the range of
+ random values to generate. Defaults to -0.05.
+ maxval: A python scalar or a scalar tensor. Upper bound of the range of
+ random values to generate. Defaults to 0.05.
+ seed: A Python integer. Used to create random seeds. See
+ `tf.set_random_seed` for behavior.
+ dtype: The data type.
+ """
+
+ def __init__(self, minval=-0.05, maxval=0.05, seed=None,
+ dtype=dtypes.float32):
+ super(RandomUniform, self).__init__(
+ minval=minval, maxval=maxval, seed=seed, dtype=dtype)
+
+
+@tf_export('keras.initializers.RandomNormal', 'keras.initializers.normal',
+ 'keras.initializers.random_normal')
+class RandomNormal(TFRandomNormal):
+ """Initializer that generates tensors with a normal distribution.
+
+ Args:
+ mean: a Python scalar or a scalar tensor. Mean of the random values to
+ generate. Defaults to 0.
+ stddev: a Python scalar or a scalar tensor. Standard deviation of the random
+ values to generate. Defaults to 0.05.
+ seed: A Python integer. Used to create random seeds. See
+ `tf.set_random_seed` for behavior.
+ dtype: The data type. Only floating point types are supported.
+
+ Returns:
+ A RandomNormal instance.
+ """
+
+ def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32):
+ super(RandomNormal, self).__init__(
+ mean=mean, stddev=stddev, seed=seed, dtype=dtype)
+
+
# Compatibility aliases
# pylint: disable=invalid-name
@@ -56,10 +126,9 @@ normal = random_normal = RandomNormal
truncated_normal = TruncatedNormal
identity = Identity
orthogonal = Orthogonal
-glorot_normal = glorot_normal_initializer
-glorot_uniform = glorot_uniform_initializer
+glorot_normal = GlorotNormal
+glorot_uniform = GlorotUniform
-# pylint: enable=invalid-name
# Utility functions
@@ -92,3 +161,6 @@ def get(identifier):
else:
raise ValueError('Could not interpret initializer identifier: ' +
str(identifier))
+
+
+# pylint: enable=invalid-name
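Taken together, the initializers.py hunks above mean the tf.keras classes now carry Keras' own defaults (e.g. stddev=0.05) instead of deferring to the core TF initializers. A minimal sketch of the resulting behavior, mirroring the new tests below and assuming the in-tree import path:

    from tensorflow.python import keras

    # The string aliases resolve to the new wrapper classes, so the
    # returned instances already carry the Keras defaults.
    ru = keras.initializers.get('uniform')
    rn = keras.initializers.get('normal')
    tn = keras.initializers.get('truncated_normal')

    print(ru.minval, ru.maxval)   # -0.05 0.05
    print(rn.mean, rn.stddev)     # 0.0 0.05
    print(tn.mean, tn.stddev)     # 0.0 0.05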
diff --git a/tensorflow/python/keras/initializers_test.py b/tensorflow/python/keras/initializers_test.py
index 51725e03f2..2b758a98f3 100644
--- a/tensorflow/python/keras/initializers_test.py
+++ b/tensorflow/python/keras/initializers_test.py
@@ -40,7 +40,7 @@ class KerasInitializersTest(test.TestCase):
def test_uniform(self):
tensor_shape = (9, 6, 7)
- with self.test_session():
+ with self.cached_session():
self._runner(keras.initializers.RandomUniform(minval=-1,
maxval=1,
seed=124),
@@ -49,14 +49,14 @@ class KerasInitializersTest(test.TestCase):
def test_normal(self):
tensor_shape = (8, 12, 99)
- with self.test_session():
+ with self.cached_session():
self._runner(keras.initializers.RandomNormal(mean=0, stddev=1, seed=153),
tensor_shape,
target_mean=0., target_std=1)
def test_truncated_normal(self):
tensor_shape = (12, 99, 7)
- with self.test_session():
+ with self.cached_session():
self._runner(keras.initializers.TruncatedNormal(mean=0,
stddev=1,
seed=126),
@@ -65,13 +65,13 @@ class KerasInitializersTest(test.TestCase):
def test_constant(self):
tensor_shape = (5, 6, 4)
- with self.test_session():
+ with self.cached_session():
self._runner(keras.initializers.Constant(2), tensor_shape,
target_mean=2, target_max=2, target_min=2)
def test_lecun_uniform(self):
tensor_shape = (5, 6, 4, 2)
- with self.test_session():
+ with self.cached_session():
fan_in, _ = init_ops._compute_fans(tensor_shape)
std = np.sqrt(1. / fan_in)
self._runner(keras.initializers.lecun_uniform(seed=123), tensor_shape,
@@ -79,7 +79,7 @@ class KerasInitializersTest(test.TestCase):
def test_glorot_uniform(self):
tensor_shape = (5, 6, 4, 2)
- with self.test_session():
+ with self.cached_session():
fan_in, fan_out = init_ops._compute_fans(tensor_shape)
std = np.sqrt(2. / (fan_in + fan_out))
self._runner(keras.initializers.glorot_uniform(seed=123), tensor_shape,
@@ -87,7 +87,7 @@ class KerasInitializersTest(test.TestCase):
def test_he_uniform(self):
tensor_shape = (5, 6, 4, 2)
- with self.test_session():
+ with self.cached_session():
fan_in, _ = init_ops._compute_fans(tensor_shape)
std = np.sqrt(2. / fan_in)
self._runner(keras.initializers.he_uniform(seed=123), tensor_shape,
@@ -95,7 +95,7 @@ class KerasInitializersTest(test.TestCase):
def test_lecun_normal(self):
tensor_shape = (5, 6, 4, 2)
- with self.test_session():
+ with self.cached_session():
fan_in, _ = init_ops._compute_fans(tensor_shape)
std = np.sqrt(1. / fan_in)
self._runner(keras.initializers.lecun_normal(seed=123), tensor_shape,
@@ -103,7 +103,7 @@ class KerasInitializersTest(test.TestCase):
def test_glorot_normal(self):
tensor_shape = (5, 6, 4, 2)
- with self.test_session():
+ with self.cached_session():
fan_in, fan_out = init_ops._compute_fans(tensor_shape)
std = np.sqrt(2. / (fan_in + fan_out))
self._runner(keras.initializers.glorot_normal(seed=123), tensor_shape,
@@ -111,7 +111,7 @@ class KerasInitializersTest(test.TestCase):
def test_he_normal(self):
tensor_shape = (5, 6, 4, 2)
- with self.test_session():
+ with self.cached_session():
fan_in, _ = init_ops._compute_fans(tensor_shape)
std = np.sqrt(2. / fan_in)
self._runner(keras.initializers.he_normal(seed=123), tensor_shape,
@@ -119,12 +119,12 @@ class KerasInitializersTest(test.TestCase):
def test_orthogonal(self):
tensor_shape = (20, 20)
- with self.test_session():
+ with self.cached_session():
self._runner(keras.initializers.orthogonal(seed=123), tensor_shape,
target_mean=0.)
def test_identity(self):
- with self.test_session():
+ with self.cached_session():
tensor_shape = (3, 4, 5)
with self.assertRaises(ValueError):
self._runner(keras.initializers.identity(), tensor_shape,
@@ -136,16 +136,31 @@ class KerasInitializersTest(test.TestCase):
def test_zero(self):
tensor_shape = (4, 5)
- with self.test_session():
+ with self.cached_session():
self._runner(keras.initializers.zeros(), tensor_shape,
target_mean=0., target_max=0.)
def test_one(self):
tensor_shape = (4, 5)
- with self.test_session():
+ with self.cached_session():
self._runner(keras.initializers.ones(), tensor_shape,
target_mean=1., target_max=1.)
+ def test_default_random_uniform(self):
+ ru = keras.initializers.get('uniform')
+ self.assertEqual(ru.minval, -0.05)
+ self.assertEqual(ru.maxval, 0.05)
+
+ def test_default_random_normal(self):
+ rn = keras.initializers.get('normal')
+ self.assertEqual(rn.mean, 0.0)
+ self.assertEqual(rn.stddev, 0.05)
+
+ def test_default_truncated_normal(self):
+ tn = keras.initializers.get('truncated_normal')
+ self.assertEqual(tn.mean, 0.0)
+ self.assertEqual(tn.stddev, 0.05)
+
if __name__ == '__main__':
test.main()
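The test_session() -> cached_session() substitutions that recur through the test diffs below all follow one pattern; a minimal sketch of it, assuming the tf.test.TestCase API of this era, where cached_session() creates the session once and reuses it across calls within the same test:

    from tensorflow.python.framework import constant_op
    from tensorflow.python.platform import test


    class CachedSessionExampleTest(test.TestCase):
      """Hypothetical test showing the cached_session() pattern."""

      def test_add(self):
        # Unlike repeated test_session() calls, cached_session() hands back
        # the same underlying session on every call in this test method.
        with self.cached_session():
          x = constant_op.constant(2.0)
          y = constant_op.constant(3.0)
          self.assertAllClose(5.0, self.evaluate(x + y))


    if __name__ == '__main__':
      test.main()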
diff --git a/tensorflow/python/keras/integration_test.py b/tensorflow/python/keras/integration_test.py
index a103b9fbf2..3c0f73b1c3 100644
--- a/tensorflow/python/keras/integration_test.py
+++ b/tensorflow/python/keras/integration_test.py
@@ -35,7 +35,7 @@ class KerasIntegrationTest(test.TestCase):
self.assertTrue(keras.__version__.endswith('-tf'))
def test_vector_classification_sequential(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=100,
@@ -60,7 +60,7 @@ class KerasIntegrationTest(test.TestCase):
self.assertGreater(history.history['val_acc'][-1], 0.7)
def test_vector_classification_functional(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=100,
@@ -84,7 +84,7 @@ class KerasIntegrationTest(test.TestCase):
self.assertGreater(history.history['val_acc'][-1], 0.7)
def test_temporal_classification_sequential(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=100,
@@ -106,7 +106,7 @@ class KerasIntegrationTest(test.TestCase):
self.assertGreater(history.history['val_acc'][-1], 0.7)
def test_temporal_classification_sequential_tf_rnn(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=100,
@@ -130,7 +130,7 @@ class KerasIntegrationTest(test.TestCase):
self.assertGreater(history.history['val_acc'][-1], 0.7)
def test_image_classification_sequential(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=100,
@@ -164,7 +164,7 @@ class KerasIntegrationTest(test.TestCase):
self.assertGreater(history.history['val_acc'][-1], 0.7)
def test_video_classification_functional(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=100,
@@ -194,7 +194,7 @@ class KerasIntegrationTest(test.TestCase):
def test_vector_classification_shared_sequential(self):
# Test that Sequential models that feature internal updates
# and internal losses can be shared.
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=100,
@@ -228,7 +228,7 @@ class KerasIntegrationTest(test.TestCase):
def test_vector_classification_shared_model(self):
# Test that functional models that feature internal updates
# and internal losses can be shared.
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=100,
@@ -259,14 +259,14 @@ class KerasIntegrationTest(test.TestCase):
self.assertGreater(history.history['val_acc'][-1], 0.7)
def test_embedding_with_clipnorm(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Embedding(input_dim=1, output_dim=1))
model.compile(optimizer=keras.optimizers.SGD(clipnorm=0.1), loss='mse')
model.fit(np.array([[0]]), np.array([[[0.5]]]), epochs=1)
def test_using_tf_layers_in_keras_sequential_model(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=100,
@@ -289,7 +289,7 @@ class KerasIntegrationTest(test.TestCase):
self.assertGreater(history.history['val_acc'][-1], 0.7)
def test_using_tf_layers_in_keras_functional_model(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=100,
diff --git a/tensorflow/python/keras/layers/advanced_activations_test.py b/tensorflow/python/keras/layers/advanced_activations_test.py
index 53c1baa2bb..b020b6e730 100644
--- a/tensorflow/python/keras/layers/advanced_activations_test.py
+++ b/tensorflow/python/keras/layers/advanced_activations_test.py
@@ -26,44 +26,44 @@ from tensorflow.python.platform import test
class AdvancedActivationsTest(test.TestCase):
def test_leaky_relu(self):
- with self.test_session():
+ with self.cached_session():
for alpha in [0., .5, -1.]:
testing_utils.layer_test(keras.layers.LeakyReLU,
kwargs={'alpha': alpha},
input_shape=(2, 3, 4))
def test_prelu(self):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(keras.layers.PReLU, kwargs={},
input_shape=(2, 3, 4))
def test_prelu_share(self):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(keras.layers.PReLU,
kwargs={'shared_axes': 1},
input_shape=(2, 3, 4))
def test_elu(self):
- with self.test_session():
+ with self.cached_session():
for alpha in [0., .5, -1.]:
testing_utils.layer_test(keras.layers.ELU,
kwargs={'alpha': alpha},
input_shape=(2, 3, 4))
def test_thresholded_relu(self):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(keras.layers.ThresholdedReLU,
kwargs={'theta': 0.5},
input_shape=(2, 3, 4))
def test_softmax(self):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(keras.layers.Softmax,
kwargs={'axis': 1},
input_shape=(2, 3, 4))
def test_relu(self):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(keras.layers.ReLU,
kwargs={'max_value': 10},
input_shape=(2, 3, 4))
@@ -71,14 +71,14 @@ class AdvancedActivationsTest(test.TestCase):
def test_relu_with_invalid_arg(self):
with self.assertRaisesRegexp(
ValueError, 'max_value of Relu layer cannot be negative value: -10'):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(keras.layers.ReLU,
kwargs={'max_value': -10},
input_shape=(2, 3, 4))
with self.assertRaisesRegexp(
ValueError,
'negative_slope of Relu layer cannot be negative value: -2'):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.ReLU,
kwargs={'negative_slope': -2},
diff --git a/tensorflow/python/keras/layers/convolutional_recurrent_test.py b/tensorflow/python/keras/layers/convolutional_recurrent_test.py
index 4b8f6f2a14..4a75793884 100644
--- a/tensorflow/python/keras/layers/convolutional_recurrent_test.py
+++ b/tensorflow/python/keras/layers/convolutional_recurrent_test.py
@@ -47,7 +47,7 @@ class ConvLSTMTest(test.TestCase):
input_channel)
for return_sequences in [True, False]:
- with self.test_session():
+ with self.cached_session():
# test for return state:
x = keras.Input(batch_shape=inputs.shape)
kwargs = {'data_format': data_format,
@@ -92,7 +92,7 @@ class ConvLSTMTest(test.TestCase):
input_num_row, input_num_col,
input_channel)
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
kwargs = {'data_format': 'channels_last',
'return_sequences': False,
@@ -144,7 +144,7 @@ class ConvLSTMTest(test.TestCase):
input_num_row, input_num_col,
input_channel)
- with self.test_session():
+ with self.cached_session():
kwargs = {'data_format': 'channels_last',
'return_sequences': False,
'kernel_size': (num_row, num_col),
@@ -168,7 +168,7 @@ class ConvLSTMTest(test.TestCase):
def test_conv_lstm_dropout(self):
# check dropout
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.ConvLSTM2D,
kwargs={'data_format': 'channels_last',
@@ -181,7 +181,7 @@ class ConvLSTMTest(test.TestCase):
input_shape=(1, 2, 5, 5, 2))
def test_conv_lstm_cloning(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.ConvLSTM2D(5, 3, input_shape=(None, 5, 5, 3)))
@@ -190,7 +190,7 @@ class ConvLSTMTest(test.TestCase):
weights = model.get_weights()
# Use a new graph to clone the model
- with self.test_session():
+ with self.cached_session():
clone = keras.models.clone_model(model)
clone.set_weights(weights)
diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py
index 49ca68ee9e..1df1d575b1 100644
--- a/tensorflow/python/keras/layers/core_test.py
+++ b/tensorflow/python/keras/layers/core_test.py
@@ -30,16 +30,16 @@ from tensorflow.python.platform import test
class CoreLayersTest(test.TestCase):
def test_masking(self):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.Masking, kwargs={}, input_shape=(3, 2, 3))
def test_dropout(self):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.Dropout, kwargs={'rate': 0.5}, input_shape=(3, 2))
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.Dropout,
kwargs={'rate': 0.5,
@@ -47,7 +47,7 @@ class CoreLayersTest(test.TestCase):
input_shape=(3, 2))
# https://github.com/tensorflow/tensorflow/issues/14819
- with self.test_session():
+ with self.cached_session():
dropout = keras.layers.Dropout(0.5)
self.assertEqual(True, dropout.supports_masking)
@@ -210,7 +210,7 @@ class CoreLayersTest(test.TestCase):
keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 4, 5, 2))
def test_dense_regularization(self):
- with self.test_session():
+ with self.cached_session():
layer = keras.layers.Dense(
3,
kernel_regularizer=keras.regularizers.l1(0.01),
@@ -221,7 +221,7 @@ class CoreLayersTest(test.TestCase):
self.assertEqual(3, len(layer.losses))
def test_dense_constraints(self):
- with self.test_session():
+ with self.cached_session():
k_constraint = keras.constraints.max_norm(0.01)
b_constraint = keras.constraints.max_norm(0.01)
layer = keras.layers.Dense(
@@ -231,14 +231,14 @@ class CoreLayersTest(test.TestCase):
self.assertEqual(layer.bias.constraint, b_constraint)
def test_activity_regularization(self):
- with self.test_session():
+ with self.cached_session():
layer = keras.layers.ActivityRegularization(l1=0.1)
layer(keras.backend.variable(np.ones((2, 4))))
self.assertEqual(1, len(layer.losses))
_ = layer.get_config()
def test_lambda_output_shape(self):
- with self.test_session():
+ with self.cached_session():
l = keras.layers.Lambda(lambda x: x + 1, output_shape=(1, 1))
l(keras.backend.variable(np.ones((1, 1))))
self.assertEqual((1, 1), l.get_config()['output_shape'])
@@ -247,13 +247,13 @@ class CoreLayersTest(test.TestCase):
def get_output_shape(input_shape):
return 1 * input_shape
- with self.test_session():
+ with self.cached_session():
l = keras.layers.Lambda(lambda x: x + 1, output_shape=get_output_shape)
l(keras.backend.variable(np.ones((1, 1))))
self.assertEqual('lambda', l.get_config()['output_shape_type'])
def test_lambda_config_serialization(self):
- with self.test_session():
+ with self.cached_session():
# test serialization with output_shape and output_shape_type
layer = keras.layers.Lambda(lambda x: x + 1, output_shape=(1, 1))
layer(keras.backend.variable(np.ones((1, 1))))
diff --git a/tensorflow/python/keras/layers/embeddings_test.py b/tensorflow/python/keras/layers/embeddings_test.py
index fff1c5ef98..cab176ee34 100644
--- a/tensorflow/python/keras/layers/embeddings_test.py
+++ b/tensorflow/python/keras/layers/embeddings_test.py
@@ -68,7 +68,7 @@ class EmbeddingTest(test.TestCase):
expected_output_dtype='float32')
def test_embedding_correctness(self):
- with self.test_session():
+ with self.cached_session():
layer = keras.layers.Embedding(output_dim=2, input_dim=2)
layer.build((None, 2))
matrix = np.array([[1, 1], [2, 2]])
diff --git a/tensorflow/python/keras/layers/gru_test.py b/tensorflow/python/keras/layers/gru_test.py
index afef997b00..9988c9fae5 100644
--- a/tensorflow/python/keras/layers/gru_test.py
+++ b/tensorflow/python/keras/layers/gru_test.py
@@ -87,7 +87,7 @@ class GRULayerTest(test.TestCase):
embedding_dim = 4
units = 2
layer_class = keras.layers.GRU
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.Embedding(
@@ -146,7 +146,7 @@ class GRULayerTest(test.TestCase):
def test_regularizers_GRU(self):
embedding_dim = 4
layer_class = keras.layers.GRU
- with self.test_session():
+ with self.cached_session():
layer = layer_class(
5,
return_sequences=False,
@@ -166,7 +166,7 @@ class GRULayerTest(test.TestCase):
def test_constraints_GRU(self):
embedding_dim = 4
layer_class = keras.layers.GRU
- with self.test_session():
+ with self.cached_session():
k_constraint = keras.constraints.max_norm(0.01)
r_constraint = keras.constraints.max_norm(0.01)
b_constraint = keras.constraints.max_norm(0.01)
@@ -186,7 +186,7 @@ class GRULayerTest(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_with_masking_layer_GRU(self):
layer_class = keras.layers.GRU
- with self.test_session():
+ with self.cached_session():
inputs = np.random.random((2, 3, 4))
targets = np.abs(np.random.random((2, 3, 5)))
targets /= targets.sum(axis=-1, keepdims=True)
diff --git a/tensorflow/python/keras/layers/local_test.py b/tensorflow/python/keras/layers/local_test.py
index 4781bcae07..8589b32b3c 100644
--- a/tensorflow/python/keras/layers/local_test.py
+++ b/tensorflow/python/keras/layers/local_test.py
@@ -87,7 +87,7 @@ class LocallyConnectedLayersTest(test.TestCase):
keras.layers.LocallyConnected1D,
**kwargs)
else:
- with self.test_session():
+ with self.cached_session():
layer = keras.layers.LocallyConnected1D(**kwargs)
layer.build((num_samples, num_steps, input_dim))
self.assertEqual(len(layer.losses), 2)
@@ -105,7 +105,7 @@ class LocallyConnectedLayersTest(test.TestCase):
'kernel_constraint': k_constraint,
'bias_constraint': b_constraint,
}
- with self.test_session():
+ with self.cached_session():
layer = keras.layers.LocallyConnected1D(**kwargs)
layer.build((num_samples, num_steps, input_dim))
self.assertEqual(layer.kernel.constraint, k_constraint)
@@ -197,7 +197,7 @@ class LocallyConnectedLayersTest(test.TestCase):
keras.layers.LocallyConnected2D,
**kwargs)
else:
- with self.test_session():
+ with self.cached_session():
layer = keras.layers.LocallyConnected2D(**kwargs)
layer.build((num_samples, num_row, num_col, stack_size))
self.assertEqual(len(layer.losses), 2)
@@ -214,7 +214,7 @@ class LocallyConnectedLayersTest(test.TestCase):
'kernel_constraint': k_constraint,
'bias_constraint': b_constraint,
}
- with self.test_session():
+ with self.cached_session():
layer = keras.layers.LocallyConnected2D(**kwargs)
layer.build((num_samples, num_row, num_col, stack_size))
self.assertEqual(layer.kernel.constraint, k_constraint)
diff --git a/tensorflow/python/keras/layers/lstm_test.py b/tensorflow/python/keras/layers/lstm_test.py
index 9802820fd0..f536915324 100644
--- a/tensorflow/python/keras/layers/lstm_test.py
+++ b/tensorflow/python/keras/layers/lstm_test.py
@@ -102,7 +102,7 @@ class LSTMLayerTest(test.TestCase):
embedding_dim = 4
units = 2
layer_class = keras.layers.LSTM
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.Embedding(
@@ -161,7 +161,7 @@ class LSTMLayerTest(test.TestCase):
def test_regularizers_LSTM(self):
embedding_dim = 4
layer_class = keras.layers.LSTM
- with self.test_session():
+ with self.cached_session():
layer = layer_class(
5,
return_sequences=False,
@@ -180,7 +180,7 @@ class LSTMLayerTest(test.TestCase):
def test_constraints_LSTM(self):
embedding_dim = 4
layer_class = keras.layers.LSTM
- with self.test_session():
+ with self.cached_session():
k_constraint = keras.constraints.max_norm(0.01)
r_constraint = keras.constraints.max_norm(0.01)
b_constraint = keras.constraints.max_norm(0.01)
@@ -200,7 +200,7 @@ class LSTMLayerTest(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_with_masking_layer_LSTM(self):
layer_class = keras.layers.LSTM
- with self.test_session():
+ with self.cached_session():
inputs = np.random.random((2, 3, 4))
targets = np.abs(np.random.random((2, 3, 5)))
targets /= targets.sum(axis=-1, keepdims=True)
@@ -225,7 +225,7 @@ class LSTMLayerTest(test.TestCase):
units = 3
num_samples = 2
- with self.test_session():
+ with self.cached_session():
# Test with Keras tensor
inputs = keras.Input((timesteps, embedding_dim))
initial_state = [keras.Input((units,)) for _ in range(num_states)]
@@ -252,7 +252,7 @@ class LSTMLayerTest(test.TestCase):
units = 3
num_samples = 2
- with self.test_session():
+ with self.cached_session():
# Test with non-Keras tensor
inputs = keras.Input((timesteps, embedding_dim))
initial_state = [keras.backend.random_normal_variable(
@@ -275,7 +275,7 @@ class LSTMLayerTest(test.TestCase):
units = 3
num_samples = 2
- with self.test_session():
+ with self.cached_session():
layer = keras.layers.LSTM(units, stateful=True)
layer.build((num_samples, timesteps, embedding_dim))
layer.reset_states()
@@ -306,7 +306,7 @@ class LSTMLayerTest(test.TestCase):
units = 3
num_samples = 2
- with self.test_session():
+ with self.cached_session():
inputs = keras.Input((timesteps, embedding_dim))
_ = keras.layers.Masking()(inputs)
initial_state = [keras.Input((units,)) for _ in range(num_states)]
@@ -329,7 +329,7 @@ class LSTMLayerTest(test.TestCase):
units = 3
num_samples = 2
- with self.test_session():
+ with self.cached_session():
inputs = keras.Input(batch_shape=(num_samples, timesteps, embedding_dim))
layer = keras.layers.LSTM(units, return_state=True, stateful=True)
outputs = layer(inputs)
@@ -347,7 +347,7 @@ class LSTMLayerTest(test.TestCase):
units = 3
num_samples = 2
- with self.test_session():
+ with self.cached_session():
inputs = keras.Input(batch_shape=(num_samples, timesteps, embedding_dim))
layer = keras.layers.LSTM(units, return_state=True, return_sequences=True)
outputs = layer(inputs)
@@ -366,7 +366,7 @@ class LSTMLayerTest(test.TestCase):
num_states = 2
layer_class = keras.layers.LSTM
- with self.test_session():
+ with self.cached_session():
# Test with Keras tensor
main_inputs = keras.Input((timesteps, embedding_dim))
initial_state = [keras.Input((units,)) for _ in range(num_states)]
diff --git a/tensorflow/python/keras/layers/merge_test.py b/tensorflow/python/keras/layers/merge_test.py
index 39bc98d039..7bcfcaeddb 100644
--- a/tensorflow/python/keras/layers/merge_test.py
+++ b/tensorflow/python/keras/layers/merge_test.py
@@ -46,7 +46,7 @@ class MergeLayersTest(test.TestCase):
self.assertAllClose(out, x1 + x2 + x3, atol=1e-4)
def test_merge_add_masking(self):
- with self.test_session():
+ with self.cached_session():
i1 = keras.layers.Input(shape=(4, 5))
i2 = keras.layers.Input(shape=(4, 5))
m1 = keras.layers.Masking()(i1)
@@ -57,7 +57,7 @@ class MergeLayersTest(test.TestCase):
self.assertListEqual(mask.get_shape().as_list(), [None, 4])
def test_merge_add_dynamic_shape(self):
- with self.test_session():
+ with self.cached_session():
i1 = array_ops.placeholder(shape=(4, None), dtype='float32')
i2 = array_ops.placeholder(shape=(4, 5), dtype='float32')
layer = keras.layers.Add()
@@ -149,7 +149,7 @@ class MergeLayersTest(test.TestCase):
self.assertAllClose(out, np.concatenate([x1, x2], axis=1), atol=1e-4)
def test_merge_concatenate_masking(self):
- with self.test_session():
+ with self.cached_session():
i1 = keras.layers.Input(shape=(4, 5))
i2 = keras.layers.Input(shape=(4, 5))
m1 = keras.layers.Masking()(i1)
diff --git a/tensorflow/python/keras/layers/noise_test.py b/tensorflow/python/keras/layers/noise_test.py
index aa2be62390..cea304680b 100644
--- a/tensorflow/python/keras/layers/noise_test.py
+++ b/tensorflow/python/keras/layers/noise_test.py
@@ -27,14 +27,14 @@ from tensorflow.python.platform import test
class NoiseLayersTest(test.TestCase):
def test_GaussianNoise(self):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.GaussianNoise,
kwargs={'stddev': 1.},
input_shape=(3, 2, 3))
def test_GaussianDropout(self):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.GaussianDropout,
kwargs={'rate': 0.5},
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index cd26e04c39..013d572088 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -34,7 +34,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util.tf_export import tf_export
@@ -313,18 +313,18 @@ class BatchNormalization(Layer):
shape=param_shape,
dtype=param_dtype,
initializer=self.moving_mean_initializer,
- synchronization=variable_scope.VariableSynchronization.ON_READ,
+ synchronization=tf_variables.VariableSynchronization.ON_READ,
trainable=False,
- aggregation=variable_scope.VariableAggregation.MEAN)
+ aggregation=tf_variables.VariableAggregation.MEAN)
self.moving_variance = self.add_weight(
name='moving_variance',
shape=param_shape,
dtype=param_dtype,
initializer=self.moving_variance_initializer,
- synchronization=variable_scope.VariableSynchronization.ON_READ,
+ synchronization=tf_variables.VariableSynchronization.ON_READ,
trainable=False,
- aggregation=variable_scope.VariableAggregation.MEAN)
+ aggregation=tf_variables.VariableAggregation.MEAN)
if self.renorm:
# Create variables to maintain the moving mean and standard deviation.
@@ -340,9 +340,9 @@ class BatchNormalization(Layer):
shape=shape,
dtype=param_dtype,
initializer=init_ops.zeros_initializer(),
- synchronization=variable_scope.VariableSynchronization.ON_READ,
+ synchronization=tf_variables.VariableSynchronization.ON_READ,
trainable=False,
- aggregation=variable_scope.VariableAggregation.MEAN)
+ aggregation=tf_variables.VariableAggregation.MEAN)
return var
with distribution_strategy_context.get_distribution_strategy(
diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py
index a97b4cac46..2844b84799 100644
--- a/tensorflow/python/keras/layers/normalization_test.py
+++ b/tensorflow/python/keras/layers/normalization_test.py
@@ -28,7 +28,7 @@ from tensorflow.python.platform import test
class NormalizationLayersTest(test.TestCase):
def test_basic_batchnorm(self):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.BatchNormalization,
kwargs={
@@ -54,7 +54,7 @@ class NormalizationLayersTest(test.TestCase):
input_shape=(3, 3))
def test_batchnorm_weights(self):
- with self.test_session():
+ with self.cached_session():
layer = keras.layers.BatchNormalization(scale=False, center=False)
layer.build((None, 3, 4))
self.assertEqual(len(layer.trainable_weights), 0)
@@ -66,7 +66,7 @@ class NormalizationLayersTest(test.TestCase):
self.assertEqual(len(layer.weights), 4)
def test_batchnorm_regularization(self):
- with self.test_session():
+ with self.cached_session():
layer = keras.layers.BatchNormalization(
gamma_regularizer='l1', beta_regularizer='l1')
layer.build((None, 3, 4))
@@ -79,7 +79,7 @@ class NormalizationLayersTest(test.TestCase):
self.assertEqual(layer.beta.constraint, max_norm)
def test_batchnorm_correctness(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8)
model.add(norm)
@@ -96,7 +96,7 @@ class NormalizationLayersTest(test.TestCase):
np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
def test_batchnorm_mixed_precision(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8)
model.add(norm)
@@ -133,7 +133,7 @@ class NormalizationLayersTest(test.TestCase):
np.testing.assert_allclose(np.std(out, axis=(0, 2, 3)), 1.0, atol=1e-1)
def test_batchnorm_convnet_channel_last(self):
- with self.test_session():
+ with self.cached_session():
# keras.backend.set_learning_phase(True)
model = keras.models.Sequential()
@@ -155,7 +155,7 @@ class NormalizationLayersTest(test.TestCase):
def test_shared_batchnorm(self):
"""Test that a BN layer can be shared across different data streams.
"""
- with self.test_session():
+ with self.cached_session():
# Test single layer reuse
bn = keras.layers.BatchNormalization()
x1 = keras.layers.Input(shape=(10,))
@@ -187,7 +187,7 @@ class NormalizationLayersTest(test.TestCase):
new_model.train_on_batch(x, x)
def test_that_trainable_disables_updates(self):
- with self.test_session():
+ with self.cached_session():
val_a = np.random.random((10, 4))
val_out = np.random.random((10, 4))
@@ -230,7 +230,7 @@ class NormalizationLayersTest(test.TestCase):
Computes mean and std for current inputs then
applies batch normalization using them.
"""
- with self.test_session():
+ with self.cached_session():
bn_mean = 0.5
bn_std = 10.
val_a = np.expand_dims(np.arange(10.), axis=1)
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index 12c82a53f6..ba7498e7e6 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -33,7 +33,6 @@ from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
@@ -74,19 +73,27 @@ class StackedRNNCells(Layer):
'`state_size` attribute. '
'received cells:', cells)
self.cells = cells
+ # reverse_state_order determines whether the state sizes are listed in the
+ # reverse order of the cells' states. Users might want to set this to True
+ # to keep the existing behavior. This is only useful when using
+ # RNN(return_state=True), since the states are returned in the same order
+ # as state_size.
+ self.reverse_state_order = kwargs.pop('reverse_state_order', False)
+ if self.reverse_state_order:
+ logging.warning('reverse_state_order=True in StackedRNNCells will soon '
+ 'be deprecated. Please update the code to work with the '
+ 'natural order of states if you rely on the RNN states, '
+ 'e.g. RNN(return_state=True).')
super(StackedRNNCells, self).__init__(**kwargs)
@property
def state_size(self):
- # States are a flat list
- # in reverse order of the cell stack.
- # This allows to preserve the requirement
- # `stack.state_size[0] == output_dim`.
- # e.g. states of a 2-layer LSTM would be
- # `[h2, c2, h1, c1]`
+ # States are a flat list of the individual cell state sizes.
+ # e.g. states of a 2-layer LSTM would be `[h1, c1, h2, c2]`.
# (assuming one LSTM has states [h, c])
+ # In the case of reverse_state_order=True, the state_size will be
+ # [h2, c2, h1, c1].
state_size = []
- for cell in self.cells[::-1]:
+ for cell in self.cells[::-1] if self.reverse_state_order else self.cells:
if _is_multiple_state(cell.state_size):
state_size += list(cell.state_size)
else:
@@ -95,22 +102,40 @@ class StackedRNNCells(Layer):
@property
def output_size(self):
- if hasattr(self.cells[-1], 'output_size'):
+ if getattr(self.cells[-1], 'output_size', None) is not None:
return self.cells[-1].output_size
+ elif _is_multiple_state(self.cells[-1].state_size):
+ return self.cells[-1].state_size[0]
else:
- return self.state_size[0]
+ return self.cells[-1].state_size
+
+ def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
+ # The initial states are flattened into a list because state_size is a
+ # flat list.
+ initial_states = []
+ for cell in self.cells[::-1] if self.reverse_state_order else self.cells:
+ get_initial_state_fn = getattr(cell, 'get_initial_state', None)
+ if get_initial_state_fn:
+ initial_states.append(get_initial_state_fn(
+ inputs=inputs, batch_size=batch_size, dtype=dtype))
+ else:
+ initial_states.append(_generate_zero_filled_state_for_cell(
+ cell, inputs, batch_size, dtype))
+
+ return nest.flatten(initial_states)
def call(self, inputs, states, constants=None, **kwargs):
# Recover per-cell states.
nested_states = []
- for cell in self.cells[::-1]:
+ for cell in self.cells[::-1] if self.reverse_state_order else self.cells:
if _is_multiple_state(cell.state_size):
nested_states.append(states[:len(cell.state_size)])
states = states[len(cell.state_size):]
else:
nested_states.append([states[0]])
states = states[1:]
- nested_states = nested_states[::-1]
+ if self.reverse_state_order:
+ nested_states = nested_states[::-1]
# Call the cells in order and store the returned states.
new_nested_states = []
@@ -124,11 +149,12 @@ class StackedRNNCells(Layer):
new_nested_states.append(states)
# Format the new states as a flat list
- # in reverse cell order.
- states = []
- for cell_states in new_nested_states[::-1]:
- states += cell_states
- return inputs, states
+ new_states = []
+ if self.reverse_state_order:
+ new_nested_states = new_nested_states[::-1]
+ for cell_states in new_nested_states:
+ new_states += cell_states
+ return inputs, new_states
@tf_utils.shape_type_conversion
def build(self, input_shape):
@@ -141,7 +167,9 @@ class StackedRNNCells(Layer):
cell.build([input_shape] + constants_shape)
else:
cell.build(input_shape)
- if _is_multiple_state(cell.state_size):
+ if getattr(cell, 'output_size', None) is not None:
+ output_dim = cell.output_size
+ elif _is_multiple_state(cell.state_size):
output_dim = cell.state_size[0]
else:
output_dim = cell.state_size
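The practical effect of the ordering change above, exercised later in recurrent_test.py: state_size (and the states returned by RNN(return_state=True)) now follow the natural cell order, while reverse_state_order=True restores the legacy reversed layout. A hedged sketch:

    from tensorflow.python import keras

    cells = [keras.layers.LSTMCell(3), keras.layers.LSTMCell(6)]

    # Natural order (new default): [h1, c1, h2, c2].
    stacked = keras.layers.StackedRNNCells(cells)
    print(stacked.state_size)   # (3, 3, 6, 6)

    # Legacy order, behind the soon-to-be-deprecated flag: [h2, c2, h1, c1].
    legacy = keras.layers.StackedRNNCells(cells, reverse_state_order=True)
    print(legacy.state_size)    # (6, 6, 3, 3)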
@@ -261,6 +289,22 @@ class RNN(Layer):
compatible reason, if this attribute is not available for the
cell, the value will be inferred by the first element of the
`state_size`.
+ - a `get_initial_state(inputs=None, batch_size=None, dtype=None)`
+ method that creates a tensor meant to be fed to `call()` as the
+ initial state, if the user didn't specify any initial state via
+ other means. The returned initial state should have shape
+ [batch, cell.state_size]. The cell might choose to return a
+ zero-filled tensor, or a tensor with other values, depending on
+ the cell's implementation. `inputs` is the input tensor to the
+ RNN layer, whose shape[0] carries the batch size and whose dtype
+ carries the dtype; note that shape[0] might be None during graph
+ construction. Either `inputs` or the pair of `batch_size` and
+ `dtype` is provided. `batch_size` is a scalar tensor representing
+ the batch size of the input, and `dtype` is a `tf.DType`
+ representing the dtype of the input (see the sketch below this
+ list for a cell implementing this hook).
+ For backward compatibility, if this method is not implemented
+ by the cell, the RNN layer will create zero-filled tensors with
+ shape [batch, cell.state_size].
In the case that `cell` is a list of RNN cell instances, the cells
will be stacked one after the other in the RNN, implementing an
efficient stacked RNN.
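As referenced in the docstring above, a sketch of a cell opting into the new hook; ZeroStateCell and all of its internals are hypothetical, with only the get_initial_state signature taken from this diff:

    from tensorflow.python import keras
    from tensorflow.python.ops import array_ops


    class ZeroStateCell(keras.layers.Layer):
      """Hypothetical single-state cell implementing get_initial_state."""

      def __init__(self, units, **kwargs):
        super(ZeroStateCell, self).__init__(**kwargs)
        self.units = units
        self.state_size = units

      def build(self, input_shape):
        self.kernel = self.add_weight(
            'kernel', shape=(int(input_shape[-1]), self.units))
        self.built = True

      def call(self, inputs, states):
        output = keras.backend.dot(inputs, self.kernel) + states[0]
        return output, [output]

      def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        # The RNN layer passes either `inputs` or the (batch_size, dtype)
        # pair; recover the missing pieces from `inputs` when present.
        if inputs is not None:
          batch_size = array_ops.shape(inputs)[0]
          dtype = inputs.dtype
        return array_ops.zeros([batch_size, self.units], dtype=dtype)

Wrapping it as keras.layers.RNN(ZeroStateCell(4)) makes the layer call this hook instead of building the zero state itself.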
@@ -453,7 +497,7 @@ class RNN(Layer):
else:
state_size = [self.cell.state_size]
- if hasattr(self.cell, 'output_size'):
+ if getattr(self.cell, 'output_size', None) is not None:
output_dim = tensor_shape.as_shape(self.cell.output_size).as_list()
else:
# Note that state_size[0] could be a tensor_shape or int.
@@ -553,26 +597,18 @@ class RNN(Layer):
raise validation_error
def get_initial_state(self, inputs):
- # build an all-zero tensor of shape (batch, cell.state_size)
- initial_state = array_ops.zeros_like(inputs)
- # shape of initial_state = (batch, timesteps, ...)
- initial_state = math_ops.reduce_sum(
- initial_state, axis=list(range(1, len(inputs.shape))))
- # shape of initial_state = (batch,)
- if _is_multiple_state(self.cell.state_size):
- states = []
- for dims in self.cell.state_size:
- state = initial_state
- flat_dims = tensor_shape.as_shape(dims).as_list()
- # reshape the state to (batch, 1, 1, ....) and then expand each state.
- state = array_ops.reshape(state, [-1,] + [1] * len(flat_dims))
- states.append(K.tile(state, [1] + flat_dims))
- return states
+ get_initial_state_fn = getattr(self.cell, 'get_initial_state', None)
+ if get_initial_state_fn:
+ init_state = get_initial_state_fn(
+ inputs=inputs, batch_size=None, dtype=None)
else:
- flat_dims = tensor_shape.as_shape(self.cell.state_size).as_list()
- initial_state = array_ops.reshape(
- initial_state, [-1] + [1] * len(flat_dims))
- return [K.tile(initial_state, [1] + flat_dims)]
+ init_state = _generate_zero_filled_state(
+ array_ops.shape(inputs)[0], self.cell.state_size, inputs.dtype)
+ # Keras RNN expects the states in a list, even if it's a single state tensor.
+ if not nest.is_sequence(init_state):
+ init_state = [init_state]
+ # Force the state to be a list in case it is a namedtuple, e.g. LSTMStateTuple.
+ return list(init_state)
def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
inputs, initial_state, constants = _standardize_args(inputs,
@@ -636,6 +672,14 @@ class RNN(Layer):
# note that the .build() method of subclasses MUST define
# self.input_spec and self.state_spec with complete input shapes.
if isinstance(inputs, list):
+ # Get initial_state from the full input spec, since the inputs could
+ # be copied to multiple GPUs.
+ if self._num_constants is None:
+ initial_state = inputs[1:]
+ else:
+ initial_state = inputs[1:-self._num_constants]
+ if len(initial_state) == 0:
+ initial_state = None
inputs = inputs[0]
if initial_state is not None:
pass
@@ -986,6 +1030,9 @@ class SimpleRNNCell(Layer):
output._uses_learning_phase = True
return output, [output]
+ def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
+ return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
+
def get_config(self):
config = {
'units':
@@ -1517,6 +1564,9 @@ class GRUCell(Layer):
base_config = super(GRUCell, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+ def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
+ return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
+
@tf_export('keras.layers.GRU')
class GRU(RNN):
@@ -2042,11 +2092,17 @@ class LSTMCell(Layer):
base_config = super(LSTMCell, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+ def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
+ return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
+
@tf_export('keras.layers.LSTM')
class LSTM(RNN):
"""Long Short-Term Memory layer - Hochreiter 1997.
+ Note that this layer is not optimized for performance on GPU. Please use
+ `tf.keras.layers.CuDNNLSTM` for better performance on GPU.
+
Arguments:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
@@ -2142,6 +2198,10 @@ class LSTM(RNN):
logging.warning('`implementation=0` has been deprecated, '
'and now defaults to `implementation=1`.'
'Please update your layer call.')
+ if context.executing_eagerly() and context.num_gpus() > 0:
+ logging.warn('%s: Note that this layer is not optimized for performance. '
+ 'Please use tf.keras.layers.CuDNNLSTM for better '
+ 'performance on GPU.', self)
cell = LSTMCell(
units,
activation=activation,
@@ -2354,3 +2414,30 @@ def _is_multiple_state(state_size):
"""Check whether the state_size contains multiple states."""
return (hasattr(state_size, '__len__') and
not isinstance(state_size, tensor_shape.TensorShape))
+
+
+def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype):
+ if inputs is not None:
+ batch_size = array_ops.shape(inputs)[0]
+ dtype = inputs.dtype
+ return _generate_zero_filled_state(batch_size, cell.state_size, dtype)
+
+
+def _generate_zero_filled_state(batch_size_tensor, state_size, dtype):
+ """Generate a zero filled tensor with shape [batch_size, state_size]."""
+ if None in [batch_size_tensor, dtype]:
+ raise ValueError(
+ 'batch_size and dtype cannot be None while constructing initial state: '
+ 'batch_size={}, dtype={}'.format(batch_size_tensor, dtype))
+ if _is_multiple_state(state_size):
+ states = []
+ for dims in state_size:
+ flat_dims = tensor_shape.as_shape(dims).as_list()
+ init_state_size = [batch_size_tensor] + flat_dims
+ init_state = array_ops.zeros(init_state_size, dtype=dtype)
+ states.append(init_state)
+ return states
+ else:
+ flat_dims = tensor_shape.as_shape(state_size).as_list()
+ init_state_size = [batch_size_tensor] + flat_dims
+ return array_ops.zeros(init_state_size, dtype=dtype)
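For a multi-state cell, the helper above fans out one zero tensor per state entry. A quick illustration (this pokes a private helper, so it is for exposition only):

    from tensorflow.python import keras

    cell = keras.layers.LSTMCell(4)   # state_size == [4, 4]
    states = keras.layers.recurrent._generate_zero_filled_state_for_cell(
        cell, inputs=None, batch_size=2, dtype='float32')
    print([s.shape.as_list() for s in states])   # [[2, 4], [2, 4]]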
diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py
index 13bd070528..a3861e44d5 100644
--- a/tensorflow/python/keras/layers/recurrent_test.py
+++ b/tensorflow/python/keras/layers/recurrent_test.py
@@ -50,7 +50,7 @@ class RNNTest(test.TestCase):
output = keras.backend.dot(inputs, self.kernel) + prev_output
return output, [output]
- with self.test_session():
+ with self.cached_session():
# Basic test case.
cell = MinimalRNNCell(32, 5)
x = keras.Input((None, 5))
@@ -88,7 +88,7 @@ class RNNTest(test.TestCase):
output -= prev_output_2
return output, [output * 2, output * 3]
- with self.test_session():
+ with self.cached_session():
# Basic test case.
cell = MinimalRNNCell(32, 5)
x = keras.Input((None, 5))
@@ -103,7 +103,8 @@ class RNNTest(test.TestCase):
MinimalRNNCell(16, 8),
MinimalRNNCell(32, 16)]
layer = keras.layers.RNN(cells)
- assert layer.cell.state_size == (32, 32, 16, 16, 8, 8)
+ self.assertEqual(layer.cell.state_size, (8, 8, 16, 16, 32, 32))
+ self.assertEqual(layer.cell.output_size, 32)
y = layer(x)
model = keras.models.Model(x, y)
model.compile(optimizer='rmsprop', loss='mse')
@@ -139,7 +140,7 @@ class RNNTest(test.TestCase):
base_config = super(MinimalRNNCell, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
- with self.test_session():
+ with self.cached_session():
# Test basic case.
x = keras.Input((None, 5))
cell = MinimalRNNCell(32)
@@ -228,7 +229,7 @@ class RNNTest(test.TestCase):
base_config = super(RNNCellWithConstants, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
- with self.test_session():
+ with self.cached_session():
# Test basic case.
x = keras.Input((None, 5))
c = keras.Input((3,))
@@ -243,7 +244,7 @@ class RNNTest(test.TestCase):
np.zeros((6, 32))
)
- with self.test_session():
+ with self.cached_session():
# Test basic case serialization.
x_np = np.random.random((6, 5, 5))
c_np = np.random.random((6, 3))
@@ -259,7 +260,7 @@ class RNNTest(test.TestCase):
y_np_2 = model.predict([x_np, c_np])
self.assertAllClose(y_np, y_np_2, atol=1e-4)
- with self.test_session():
+ with self.cached_session():
# test flat list inputs.
with keras.utils.CustomObjectScope(custom_objects):
layer = keras.layers.RNN.from_config(config.copy())
@@ -269,7 +270,7 @@ class RNNTest(test.TestCase):
y_np_3 = model.predict([x_np, c_np])
self.assertAllClose(y_np, y_np_3, atol=1e-4)
- with self.test_session():
+ with self.cached_session():
# Test stacking.
cells = [keras.layers.recurrent.GRUCell(8),
RNNCellWithConstants(12),
@@ -283,7 +284,7 @@ class RNNTest(test.TestCase):
np.zeros((6, 32))
)
- with self.test_session():
+ with self.cached_session():
# Test GRUCell reset_after property.
x = keras.Input((None, 5))
c = keras.Input((3,))
@@ -297,7 +298,7 @@ class RNNTest(test.TestCase):
np.zeros((6, 32))
)
- with self.test_session():
+ with self.cached_session():
# Test stacked RNN serialization
x_np = np.random.random((6, 5, 5))
c_np = np.random.random((6, 3))
@@ -355,7 +356,7 @@ class RNNTest(test.TestCase):
base_config = super(RNNCellWithConstants, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
- with self.test_session():
+ with self.cached_session():
# Test basic case.
x = keras.Input((None, 5))
c = keras.Input((3,))
@@ -370,7 +371,7 @@ class RNNTest(test.TestCase):
np.zeros((6, 32))
)
- with self.test_session():
+ with self.cached_session():
# Test basic case serialization.
x_np = np.random.random((6, 5, 5))
s_np = np.random.random((6, 32))
@@ -392,7 +393,7 @@ class RNNTest(test.TestCase):
with self.assertRaises(AssertionError):
self.assertAllClose(y_np, y_np_2_different_s, atol=1e-4)
- with self.test_session():
+ with self.cached_session():
# test flat list inputs
with keras.utils.CustomObjectScope(custom_objects):
layer = keras.layers.RNN.from_config(config.copy())
@@ -467,7 +468,7 @@ class RNNTest(test.TestCase):
timesteps = 2
num_samples = 2
- with self.test_session():
+ with self.cached_session():
input1 = keras.Input(batch_shape=(num_samples, timesteps, embedding_dim))
layer = layer_class(units,
return_state=True,
@@ -487,7 +488,7 @@ class RNNTest(test.TestCase):
for cell_class in [keras.layers.SimpleRNNCell,
keras.layers.GRUCell,
keras.layers.LSTMCell]:
- with self.test_session():
+ with self.cached_session():
# Test basic case.
x = keras.Input((None, 5))
cell = cell_class(32)
@@ -534,7 +535,7 @@ class RNNTest(test.TestCase):
keras.layers.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1)]
layer = keras.layers.RNN(cells)
- with self.test_session():
+ with self.cached_session():
x = keras.Input((None, 5))
y = layer(x)
model = keras.models.Model(x, y)
@@ -551,6 +552,21 @@ class RNNTest(test.TestCase):
layer = keras.layers.RNN(cells, return_state=True, return_sequences=True)
output_shape = layer.compute_output_shape((None, timesteps, embedding_dim))
expected_output_shape = [(None, timesteps, 6),
+ (None, 3),
+ (None, 3),
+ (None, 6),
+ (None, 6)]
+ self.assertEqual(
+ [tuple(o.as_list()) for o in output_shape],
+ expected_output_shape)
+
+ # Test reverse_state_order = True for stacked cell.
+ stacked_cell = keras.layers.StackedRNNCells(
+ cells, reverse_state_order=True)
+ layer = keras.layers.RNN(
+ stacked_cell, return_state=True, return_sequences=True)
+ output_shape = layer.compute_output_shape((None, timesteps, embedding_dim))
+ expected_output_shape = [(None, timesteps, 6),
(None, 6),
(None, 6),
(None, 3),
@@ -561,7 +577,7 @@ class RNNTest(test.TestCase):
def test_checkpointable_dependencies(self):
rnn = keras.layers.SimpleRNN
- with self.test_session():
+ with self.cached_session():
x = np.random.random((2, 2, 2))
y = np.random.random((2, 2))
model = keras.models.Sequential()
@@ -576,7 +592,7 @@ class RNNTest(test.TestCase):
self.assertIn(v, checkpointed_objects)
def test_high_dimension_RNN(self):
- with self.test_session():
+ with self.cached_session():
# Basic test case.
unit_a = 10
unit_b = 20
@@ -626,7 +642,7 @@ class RNNTest(test.TestCase):
batch = 32
time_step = 4
- with self.test_session():
+ with self.cached_session():
# Basic test case.
cell = Minimal2DRNNCell(unit_a, unit_b)
x = keras.Input((None, input_a, input_b))
@@ -642,7 +658,7 @@ class RNNTest(test.TestCase):
], np.zeros((batch, unit_a, unit_b)))
self.assertEqual(model.output_shape, (None, unit_a, unit_b))
- with self.test_session():
+ with self.cached_session():
# Bad init state shape.
bad_shape_a = unit_a * 2
bad_shape_b = unit_b * 2
@@ -655,7 +671,7 @@ class RNNTest(test.TestCase):
layer(x, initial_state=s)
def test_inconsistent_output_state_size(self):
- with self.test_session():
+ with self.cached_session():
batch = 32
time_step = 4
state_size = 5
@@ -678,6 +694,23 @@ class RNNTest(test.TestCase):
np.zeros((batch, input_size)))
self.assertEqual(model.output_shape, (None, input_size))
+ def test_get_initial_state(self):
+ cell = keras.layers.SimpleRNNCell(5)
+ with self.assertRaisesRegexp(ValueError,
+ 'batch_size and dtype cannot be None'):
+ cell.get_initial_state(None, None, None)
+
+ inputs = keras.Input((None, 2, 10))
+ initial_state = cell.get_initial_state(inputs, None, None)
+ self.assertEqual(initial_state.shape.as_list(), [None, 5])
+ self.assertEqual(initial_state.dtype, inputs.dtype)
+
+ batch = array_ops.shape(inputs)[0]
+ dtype = inputs.dtype
+ initial_state = cell.get_initial_state(None, batch, dtype)
+ self.assertEqual(initial_state.shape.as_list(), [None, 5])
+ self.assertEqual(initial_state.dtype, inputs.dtype)
+
class Minimal2DRNNCell(keras.layers.Layer):
"""The minimal 2D RNN cell is a simple combination of 2 1-D RNN cell.
diff --git a/tensorflow/python/keras/layers/simplernn_test.py b/tensorflow/python/keras/layers/simplernn_test.py
index 1429537648..2f2295a793 100644
--- a/tensorflow/python/keras/layers/simplernn_test.py
+++ b/tensorflow/python/keras/layers/simplernn_test.py
@@ -87,7 +87,7 @@ class SimpleRNNLayerTest(test.TestCase):
embedding_dim = 4
units = 2
layer_class = keras.layers.SimpleRNN
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.Embedding(
@@ -146,7 +146,7 @@ class SimpleRNNLayerTest(test.TestCase):
def test_regularizers_SimpleRNN(self):
embedding_dim = 4
layer_class = keras.layers.SimpleRNN
- with self.test_session():
+ with self.cached_session():
layer = layer_class(
5,
return_sequences=False,
@@ -166,7 +166,7 @@ class SimpleRNNLayerTest(test.TestCase):
def test_constraints_SimpleRNN(self):
embedding_dim = 4
layer_class = keras.layers.SimpleRNN
- with self.test_session():
+ with self.cached_session():
k_constraint = keras.constraints.max_norm(0.01)
r_constraint = keras.constraints.max_norm(0.01)
b_constraint = keras.constraints.max_norm(0.01)
@@ -186,7 +186,7 @@ class SimpleRNNLayerTest(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_with_masking_layer_SimpleRNN(self):
layer_class = keras.layers.SimpleRNN
- with self.test_session():
+ with self.cached_session():
inputs = np.random.random((2, 3, 4))
targets = np.abs(np.random.random((2, 3, 5)))
targets /= targets.sum(axis=-1, keepdims=True)
diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py
index 9b8d5fc5cc..a1933c11b0 100644
--- a/tensorflow/python/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/layers/wrappers.py
@@ -545,11 +545,27 @@ class Bidirectional(Wrapper):
if initial_state is not None and generic_utils.has_arg(
self.layer.call, 'initial_state'):
- forward_state = initial_state[:len(initial_state) // 2]
- backward_state = initial_state[len(initial_state) // 2:]
- y = self.forward_layer.call(inputs, initial_state=forward_state, **kwargs)
- y_rev = self.backward_layer.call(
- inputs, initial_state=backward_state, **kwargs)
+ forward_inputs = [inputs[0]]
+ backward_inputs = [inputs[0]]
+ pivot = len(initial_state) // 2 + 1
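+ # `pivot` points one past the forward states: inputs[0] is the data
+ # tensor, so the forward states occupy inputs[1:pivot].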
+ # add forward initial state
+ forward_state = inputs[1:pivot]
+ forward_inputs += forward_state
+ if self._num_constants is None:
+ # add backward initial state
+ backward_state = inputs[pivot:]
+ backward_inputs += backward_state
+ else:
+ # add backward initial state
+ backward_state = inputs[pivot:-self._num_constants]
+ backward_inputs += backward_state
+ # add constants for forward and backward layers
+ forward_inputs += inputs[-self._num_constants:]
+ backward_inputs += inputs[-self._num_constants:]
+ y = self.forward_layer.call(forward_inputs,
+ initial_state=forward_state, **kwargs)
+ y_rev = self.backward_layer.call(backward_inputs,
+ initial_state=backward_state, **kwargs)
else:
y = self.forward_layer.call(inputs, **kwargs)
y_rev = self.backward_layer.call(inputs, **kwargs)
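The reworked call path above matters when a Bidirectional wrapper receives explicit initial states: the first half of the list seeds the forward layer, the second half the backward layer, and any constants now ride along in both per-direction input lists. A hedged usage sketch in the style of wrappers_test.py:

    from tensorflow.python import keras

    timesteps, dim, units = 3, 2, 3
    inputs = keras.Input((timesteps, dim))
    # Four state tensors for an LSTM: [h_fwd, c_fwd, h_bwd, c_bwd].
    initial_state = [keras.Input((units,)) for _ in range(4)]
    layer = keras.layers.Bidirectional(keras.layers.LSTM(units))
    output = layer(inputs, initial_state=initial_state)
    model = keras.models.Model([inputs] + initial_state, output)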
diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py
index 0cd774ef0f..965960917c 100644
--- a/tensorflow/python/keras/layers/wrappers_test.py
+++ b/tensorflow/python/keras/layers/wrappers_test.py
@@ -113,7 +113,7 @@ class TimeDistributedTest(test.TestCase):
keras.layers.TimeDistributed(x)
def test_timedistributed_conv2d(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.TimeDistributed(
@@ -128,7 +128,7 @@ class TimeDistributedTest(test.TestCase):
model.summary()
def test_timedistributed_stacked(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.TimeDistributed(
@@ -144,7 +144,7 @@ class TimeDistributedTest(test.TestCase):
batch_size=10)
def test_regularizers(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.TimeDistributed(
@@ -155,7 +155,7 @@ class TimeDistributedTest(test.TestCase):
self.assertEqual(len(model.losses), 1)
def test_TimeDistributed_learning_phase(self):
- with self.test_session():
+ with self.cached_session():
# test layers that need learning_phase to be set
np.random.seed(1234)
x = keras.layers.Input(shape=(3, 2))
@@ -166,7 +166,7 @@ class TimeDistributedTest(test.TestCase):
self.assertAllClose(np.mean(y), 0., atol=1e-1, rtol=1e-1)
def test_TimeDistributed_batchnorm(self):
- with self.test_session():
+ with self.cached_session():
# test that wrapped BN updates still work.
model = keras.models.Sequential()
model.add(keras.layers.TimeDistributed(
@@ -202,7 +202,7 @@ class TimeDistributedTest(test.TestCase):
assert len(layer.trainable_weights) == 2
def test_TimeDistributed_with_masked_embedding_and_unspecified_shape(self):
- with self.test_session():
+ with self.cached_session():
# test with unspecified shape and Embeddings with mask_zero
model = keras.models.Sequential()
model.add(keras.layers.TimeDistributed(
@@ -234,7 +234,7 @@ class TimeDistributedTest(test.TestCase):
self.assertIs(mask_outputs[-1], None) # final layer
def test_TimeDistributed_with_masking_layer(self):
- with self.test_session():
+ with self.cached_session():
# test with Masking layer
model = keras.models.Sequential()
model.add(keras.layers.TimeDistributed(keras.layers.Masking(
@@ -266,7 +266,7 @@ class BidirectionalTest(test.TestCase):
dim = 2
timesteps = 2
output_dim = 2
- with self.test_session():
+ with self.cached_session():
for mode in ['sum', 'concat', 'ave', 'mul']:
x = np.random.random((samples, timesteps, dim))
target_dim = 2 * output_dim if mode == 'concat' else output_dim
@@ -310,7 +310,7 @@ class BidirectionalTest(test.TestCase):
dim = 2
timesteps = 2
output_dim = 2
- with self.test_session():
+ with self.cached_session():
x = np.random.random((samples, timesteps, dim))
model = keras.models.Sequential()
model.add(
@@ -331,7 +331,7 @@ class BidirectionalTest(test.TestCase):
output_dim = 2
mode = 'sum'
- with self.test_session():
+ with self.cached_session():
x = np.random.random((samples, timesteps, dim))
target_dim = 2 * output_dim if mode == 'concat' else output_dim
y = np.random.random((samples, target_dim))
@@ -363,7 +363,7 @@ class BidirectionalTest(test.TestCase):
output_dim = 2
mode = 'sum'
- with self.test_session():
+ with self.cached_session():
x = np.random.random((samples, timesteps, dim))
target_dim = 2 * output_dim if mode == 'concat' else output_dim
y = np.random.random((samples, target_dim))
@@ -383,7 +383,7 @@ class BidirectionalTest(test.TestCase):
units = 3
x = [np.random.rand(samples, timesteps, dim)]
- with self.test_session():
+ with self.cached_session():
for merge_mode in ['sum', 'mul', 'ave', 'concat', None]:
if merge_mode == 'sum':
merge_func = lambda y, y_rev: y + y_rev
@@ -447,7 +447,7 @@ class BidirectionalTest(test.TestCase):
merge_mode = 'sum'
x = [np.random.rand(samples, timesteps, dim)]
- with self.test_session():
+ with self.cached_session():
inputs = keras.Input((timesteps, dim))
wrapped = keras.layers.Bidirectional(
rnn(units, dropout=0.2, recurrent_dropout=0.2), merge_mode=merge_mode)
@@ -474,7 +474,7 @@ class BidirectionalTest(test.TestCase):
timesteps = 3
units = 3
- with self.test_session():
+ with self.cached_session():
input1 = keras.layers.Input((timesteps, dim))
layer = keras.layers.Bidirectional(
rnn(units, return_state=True, return_sequences=True))
@@ -498,7 +498,7 @@ class BidirectionalTest(test.TestCase):
def test_Bidirectional_trainable(self):
# test layers that need learning_phase to be set
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(3, 2))
layer = keras.layers.Bidirectional(keras.layers.SimpleRNN(3))
_ = layer(x)
@@ -509,7 +509,7 @@ class BidirectionalTest(test.TestCase):
assert len(layer.trainable_weights) == 6
def test_Bidirectional_updates(self):
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(3, 2))
x_reachable_update = x * x
layer = keras.layers.Bidirectional(keras.layers.SimpleRNN(3))
@@ -526,7 +526,7 @@ class BidirectionalTest(test.TestCase):
assert len(layer.get_updates_for(x)) == 2
def test_Bidirectional_losses(self):
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(3, 2))
x_reachable_loss = x * x
layer = keras.layers.Bidirectional(
@@ -545,7 +545,7 @@ class BidirectionalTest(test.TestCase):
assert len(layer.get_losses_for(x)) == 2
def test_Bidirectional_with_constants(self):
- with self.test_session():
+ with self.cached_session():
# Test basic case.
x = keras.Input((5, 5))
c = keras.Input((3,))
@@ -586,7 +586,7 @@ class BidirectionalTest(test.TestCase):
self.assertAllClose(y_np, y_np_3, atol=1e-4)
def test_Bidirectional_with_constants_layer_passing_initial_state(self):
- with self.test_session():
+ with self.cached_session():
# Test basic case.
x = keras.Input((5, 5))
c = keras.Input((3,))
diff --git a/tensorflow/python/keras/losses_test.py b/tensorflow/python/keras/losses_test.py
index 3098a6d071..c7015270ac 100644
--- a/tensorflow/python/keras/losses_test.py
+++ b/tensorflow/python/keras/losses_test.py
@@ -63,7 +63,7 @@ class _MSEMAELoss(object):
class KerasLossesTest(test.TestCase):
def test_objective_shapes_3d(self):
- with self.test_session():
+ with self.cached_session():
y_a = keras.backend.variable(np.random.random((5, 6, 7)))
y_b = keras.backend.variable(np.random.random((5, 6, 7)))
for obj in ALL_LOSSES:
@@ -71,7 +71,7 @@ class KerasLossesTest(test.TestCase):
self.assertListEqual(objective_output.get_shape().as_list(), [5, 6])
def test_objective_shapes_2d(self):
- with self.test_session():
+ with self.cached_session():
y_a = keras.backend.variable(np.random.random((6, 7)))
y_b = keras.backend.variable(np.random.random((6, 7)))
for obj in ALL_LOSSES:
@@ -79,7 +79,7 @@ class KerasLossesTest(test.TestCase):
self.assertListEqual(objective_output.get_shape().as_list(), [6,])
def test_cce_one_hot(self):
- with self.test_session():
+ with self.cached_session():
y_a = keras.backend.variable(np.random.randint(0, 7, (5, 6)))
y_b = keras.backend.variable(np.random.random((5, 6, 7)))
objective_output = keras.losses.sparse_categorical_crossentropy(y_a, y_b)
@@ -119,7 +119,7 @@ class KerasLossesTest(test.TestCase):
self.addCleanup(shutil.rmtree, tmpdir)
model_filename = os.path.join(tmpdir, 'custom_loss.h5')
- with self.test_session():
+ with self.cached_session():
with keras.utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}):
loss = _MSEMAELoss(0.3)
inputs = keras.layers.Input((2,))
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index 9b87170ebe..473d8cd95b 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -22,7 +22,10 @@ from __future__ import print_function
from abc import ABCMeta
from abc import abstractmethod
+import functools
+import sys
import types
+import weakref
import six
from tensorflow.python.eager import context
@@ -53,11 +56,12 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export
+from tensorflow.tools.docs import doc_controls
def check_is_tensor_or_operation(x, name):
@@ -136,6 +140,21 @@ def result_wrapper(result_fn):
return tf_decorator.make_decorator(result_fn, decorated)
+def weakmethod(method):
+ """Creates a weak reference to the bound method."""
+
+ cls = method.im_class
+ func = method.im_func
+ instance_ref = weakref.ref(method.im_self)
+
+ @functools.wraps(method)
+ def inner(*args, **kwargs):
+ return func.__get__(instance_ref(), cls)(*args, **kwargs)
+
+ del method
+ return inner
+
+
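The helper above relies on Python 2 bound-method attributes (`im_class`, `im_func`, `im_self`). For illustration only, here is a rough Python 3 analogue of the same cycle-breaking idea using the standard library's `weakref.WeakMethod`; this is a sketch of the technique, not the code in this patch:

```python
# Illustrative sketch: break the instance -> bound-method -> instance cycle
# by storing a weak reference to the bound method instead of the method.
import weakref


class Counter(object):

  def __init__(self):
    self.total = 0
    # Storing `self._update` directly would create a reference cycle;
    # WeakMethod does not keep `self` alive.
    self.update = weakref.WeakMethod(self._update)

  def _update(self, value):
    self.total += value


c = Counter()
c.update()(5)  # Dereference the weak method, then call it.
assert c.total == 5
```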
def safe_div(numerator, denominator):
"""Divides two tensors element-wise, returning 0 if the denominator is <= 0.
@@ -241,7 +260,7 @@ class Metric(Layer):
```python
m = SomeMetric(...)
- init_op = tf.global_variables_initializer() # Initialize variables
+ init_op = tf.variables_initializer(m.variables) # Initialize variables
with tf.Session() as sess:
sess.run(init_op)
for input in ...:
@@ -317,14 +336,27 @@ class Metric(Layer):
def __new__(cls, *args, **kwargs):
obj = super(Metric, cls).__new__(cls)
- # TODO(psv): Fix reference cycle issue here.
-
- # Converting update_state_fn() into a graph function, so that
- # we can return a single op that performs all of the variable updates.
- defuned_update_state_fn = function.defun(obj.update_state)
- obj.update_state = types.MethodType(
- update_state_wrapper(defuned_update_state_fn), obj)
- obj.result = types.MethodType(result_wrapper(obj.result), obj)
+
+ if sys.version_info < (3,):
+ # Wrap methods in `weakmethod` function to remove binding and create a
+      # weak reference. This removes the reference cycle that is otherwise
+      # created here. The cycle is not an issue in Python 3.
+ if context.executing_eagerly():
+ update_state = weakmethod(obj.update_state)
+ else:
+ update_state = function.defun(obj.update_state)
+ obj.update_state = weakmethod(
+ types.MethodType(update_state_wrapper(update_state), obj))
+ result = weakmethod(obj.result)
+ obj.result = weakmethod(types.MethodType(result_wrapper(result), obj))
+ else:
+ # Converting update_state_fn() into a graph function, so that
+ # we can return a single op that performs all of the variable updates.
+ defuned_update_state_fn = function.defun(obj.update_state)
+ obj.update_state = types.MethodType(
+ update_state_wrapper(defuned_update_state_fn), obj)
+ obj.result = types.MethodType(result_wrapper(obj.result), obj)
+
return obj
def __call__(self, *args, **kwargs):
@@ -388,11 +420,12 @@ class Metric(Layer):
return cls(**config)
### For use by subclasses ###
+ @doc_controls.for_subclass_implementers
def add_weight(self,
name,
shape=(),
- aggregation=vs.VariableAggregation.SUM,
- synchronization=vs.VariableSynchronization.ON_READ,
+ aggregation=tf_variables.VariableAggregation.SUM,
+ synchronization=tf_variables.VariableSynchronization.ON_READ,
initializer=None):
"""Adds state variable. Only for use by subclasses."""
return super(Metric, self).add_weight(
@@ -401,6 +434,7 @@ class Metric(Layer):
dtype=self._dtype,
trainable=False,
initializer=initializer,
+ collections=[],
synchronization=synchronization,
aggregation=aggregation)
@@ -582,11 +616,15 @@ def categorical_accuracy(y_true, y_pred):
def sparse_categorical_accuracy(y_true, y_pred):
- return math_ops.cast(
- math_ops.equal(
- math_ops.reduce_max(y_true, axis=-1),
- math_ops.cast(math_ops.argmax(y_pred, axis=-1), K.floatx())),
- K.floatx())
+ y_true = math_ops.reduce_max(y_true, axis=-1)
+ y_pred = math_ops.argmax(y_pred, axis=-1)
+
+  # If the expected labels are floats, cast the integer indices returned by
+  # argmax to float so that the equality comparison is valid.
+ if K.dtype(y_true) == K.floatx():
+ y_pred = math_ops.cast(y_pred, K.floatx())
+
+ return math_ops.cast(math_ops.equal(y_true, y_pred), K.floatx())
@tf_export('keras.metrics.top_k_categorical_accuracy')
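The effect of the rework above: `argmax` produces integer indices, so a cast is needed only when the labels themselves are floats. A NumPy sketch of the new comparison logic:

```python
# NumPy sketch of the reworked sparse_categorical_accuracy comparison.
import numpy as np

y_true = np.array([2., 0., 1.])           # float labels
y_pred = np.array([[.1, .2, .7],
                   [.8, .1, .1],
                   [.3, .4, .3]])
pred_ids = np.argmax(y_pred, axis=-1)     # integer dtype
if y_true.dtype == np.float64:
  pred_ids = pred_ids.astype(np.float64)  # cast only for float labels
acc = (y_true == pred_ids).astype(np.float64)
assert list(acc) == [1., 1., 1.]
```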
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py
index 2ac74219d4..4195ea18ad 100644
--- a/tensorflow/python/keras/metrics_test.py
+++ b/tensorflow/python/keras/metrics_test.py
@@ -40,7 +40,7 @@ from tensorflow.python.training.checkpointable import util as checkpointable_uti
class KerasMetricsTest(test.TestCase):
def test_metrics(self):
- with self.test_session():
+ with self.cached_session():
y_a = K.variable(np.random.random((6, 7)))
y_b = K.variable(np.random.random((6, 7)))
for metric in [metrics.binary_accuracy, metrics.categorical_accuracy]:
@@ -48,14 +48,37 @@ class KerasMetricsTest(test.TestCase):
self.assertEqual(K.eval(output).shape, (6,))
def test_sparse_categorical_accuracy(self):
- with self.test_session():
+ with self.cached_session():
metric = metrics.sparse_categorical_accuracy
- y_a = K.variable(np.random.randint(0, 7, (6,)))
- y_b = K.variable(np.random.random((6, 7)))
- self.assertEqual(K.eval(metric(y_a, y_b)).shape, (6,))
+ y_true = K.variable(np.random.randint(0, 7, (6,)))
+ y_pred = K.variable(np.random.random((6, 7)))
+ self.assertEqual(K.eval(metric(y_true, y_pred)).shape, (6,))
+
+ def test_sparse_categorical_accuracy_float(self):
+ with self.cached_session():
+ metric = metrics.sparse_categorical_accuracy
+ y_true = K.variable(np.random.random((6,)))
+ y_pred = K.variable(np.random.random((6, 7)))
+ self.assertEqual(K.eval(metric(y_true, y_pred)).shape, (6,))
+
+ def test_sparse_categorical_accuracy_eager(self):
+ """Tests that ints passed in via Eager return results. See b/113504761."""
+ with context.eager_mode():
+ metric = metrics.sparse_categorical_accuracy
+ y_true = np.arange(6).reshape([6, 1])
+ y_pred = np.arange(36).reshape([6, 6])
+ self.assertAllEqual(metric(y_true, y_pred), [0., 0., 0., 0., 0., 1.])
+
+ def test_sparse_categorical_accuracy_float_eager(self):
+ """Tests that floats passed in via Eager return results. See b/113504761."""
+ with context.eager_mode():
+ metric = metrics.sparse_categorical_accuracy
+ y_true = np.arange(6, dtype=np.float32).reshape([6, 1])
+ y_pred = np.arange(36).reshape([6, 6])
+ self.assertAllEqual(metric(y_true, y_pred), [0., 0., 0., 0., 0., 1.])
def test_sparse_top_k_categorical_accuracy(self):
- with self.test_session():
+ with self.cached_session():
y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
y_true = K.variable(np.array([[1], [0]]))
result = K.eval(
@@ -69,7 +92,7 @@ class KerasMetricsTest(test.TestCase):
self.assertEqual(result, 0.)
def test_top_k_categorical_accuracy(self):
- with self.test_session():
+ with self.cached_session():
y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
y_true = K.variable(np.array([[0, 1, 0], [1, 0, 0]]))
result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=3))
@@ -80,7 +103,7 @@ class KerasMetricsTest(test.TestCase):
self.assertEqual(result, 0.)
def test_stateful_metrics(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1334)
class BinaryTruePositives(layers.Layer):
@@ -189,7 +212,7 @@ class KerasMetricsTest(test.TestCase):
self.assertAllClose(
val_outs[2], history.history['val_true_positives'][-1], atol=1e-5)
- @test_util.run_in_graph_and_eager_modes
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def test_mean(self):
m = metrics.Mean(name='my_mean')
@@ -198,7 +221,7 @@ class KerasMetricsTest(test.TestCase):
self.assertTrue(m.stateful)
self.assertEqual(m.dtype, dtypes.float32)
self.assertEqual(len(m.variables), 2)
- self.evaluate(variables.global_variables_initializer())
+ self.evaluate(variables.variables_initializer(m.variables))
# check initial state
self.assertEqual(self.evaluate(m.total), 0)
@@ -225,7 +248,7 @@ class KerasMetricsTest(test.TestCase):
def test_mean_with_sample_weight(self):
m = metrics.Mean(dtype=dtypes.float64)
self.assertEqual(m.dtype, dtypes.float64)
- self.evaluate(variables.global_variables_initializer())
+ self.evaluate(variables.variables_initializer(m.variables))
# check scalar weight
result_t = m(100, sample_weight=0.5)
@@ -266,11 +289,11 @@ class KerasMetricsTest(test.TestCase):
self.assertEqual(np.round(self.evaluate(m.count), decimals=2), 5.6)
def test_mean_graph_with_placeholder(self):
- with context.graph_mode(), self.test_session() as sess:
+ with context.graph_mode(), self.cached_session() as sess:
m = metrics.Mean()
v = array_ops.placeholder(dtypes.float32)
w = array_ops.placeholder(dtypes.float32)
- sess.run(variables.global_variables_initializer())
+ sess.run(variables.variables_initializer(m.variables))
# check __call__()
result_t = m(v, sample_weight=w)
@@ -291,7 +314,7 @@ class KerasMetricsTest(test.TestCase):
checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
m = metrics.Mean()
checkpoint = checkpointable_utils.Checkpoint(mean=m)
- self.evaluate(variables.global_variables_initializer())
+ self.evaluate(variables.variables_initializer(m.variables))
# update state
self.evaluate(m(100.))
@@ -325,7 +348,7 @@ class KerasMetricsTest(test.TestCase):
self.assertTrue(acc_obj.stateful)
self.assertEqual(len(acc_obj.variables), 2)
self.assertEqual(acc_obj.dtype, dtypes.float32)
- self.evaluate(variables.global_variables_initializer())
+ self.evaluate(variables.variables_initializer(acc_obj.variables))
# verify that correct value is returned
update_op = acc_obj.update_state([[1], [0]], [[1], [0]])
@@ -357,7 +380,7 @@ class KerasMetricsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_binary_accuracy_threshold(self):
acc_obj = metrics.BinaryAccuracy(threshold=0.7)
- self.evaluate(variables.global_variables_initializer())
+ self.evaluate(variables.variables_initializer(acc_obj.variables))
result_t = acc_obj([[1], [1], [0], [0]], [[0.9], [0.6], [0.4], [0.8]])
result = self.evaluate(result_t)
self.assertAlmostEqual(result, 0.5, 2)
@@ -371,7 +394,7 @@ class KerasMetricsTest(test.TestCase):
self.assertTrue(acc_obj.stateful)
self.assertEqual(len(acc_obj.variables), 2)
self.assertEqual(acc_obj.dtype, dtypes.float32)
- self.evaluate(variables.global_variables_initializer())
+ self.evaluate(variables.variables_initializer(acc_obj.variables))
# verify that correct value is returned
update_op = acc_obj.update_state([[0, 0, 1], [0, 1, 0]],
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index 0bd6620220..41c5e3cccf 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -20,13 +20,19 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import saving
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.engine import training
+from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_layer import Input
from tensorflow.python.keras.engine.input_layer import InputLayer
+from tensorflow.python.keras.engine.network import Network
from tensorflow.python.keras.utils import generic_utils
-
+from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
+from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.checkpointable import data_structures
+from tensorflow.python.util.tf_export import tf_export
# API entries importable from `keras.models`:
Model = training.Model # pylint: disable=invalid-name
@@ -220,6 +226,7 @@ def _clone_sequential_model(model, input_tensors=None):
return Sequential(layers=[input_layer] + layers, name=model.name)
+@tf_export('keras.models.clone_model')
def clone_model(model, input_tensors=None):
"""Clone any `Model` instance.
@@ -246,3 +253,223 @@ def clone_model(model, input_tensors=None):
return _clone_sequential_model(model, input_tensors=input_tensors)
else:
return _clone_functional_model(model, input_tensors=input_tensors)
+
+
+# "Clone" a subclassed model by resetting all of its attributes.
+
+
+def _in_place_subclassed_model_reset(model):
+ """Substitute for model cloning that works for subclassed models.
+
+ Subclassed models cannot be cloned because their topology is not serializable.
+ To "instantiate" an identical model in a new TF graph, we reuse the original
+ model object, but we clear its state.
+
+ After calling this function on a model instance, you can use the model
+ instance as if it were a model clone (in particular you can use it in a new
+ graph).
+
+ This method clears the state of the input model. It is thus destructive.
+  However, the original state can be fully restored by calling
+ `_in_place_subclassed_model_state_restoration`.
+
+ Args:
+ model: Instance of a Keras model created via subclassing.
+
+ Raises:
+ ValueError: In case the model uses a subclassed model as inner layer.
+ """
+ assert not model._is_graph_network # Only makes sense for subclassed networks
+ # Retrieve all layers tracked by the model as well as their attribute names
+ attributes_cache = {}
+ for name in dir(model):
+ try:
+ value = getattr(model, name)
+ except (AttributeError, ValueError, TypeError):
+ continue
+ if isinstance(value, Layer):
+ attributes_cache[name] = value
+ assert value in model._layers
+ elif isinstance(value, (list, tuple)) and name not in ('layers', '_layers'):
+ # Handle case: list/tuple of layers (also tracked by the Network API).
+ if value and all(isinstance(val, Layer) for val in value):
+ raise ValueError('We do not support the use of list-of-layers '
+ 'attributes in subclassed models used with '
+                         '`model_to_estimator` at this time. Found '
+                         'list-of-layers attribute: %s' % name)
+
+ # Replace layers on the model with fresh layers
+ layers_to_names = {value: key for key, value in attributes_cache.items()}
+ original_layers = model._layers[:]
+ model._layers = data_structures.NoDependency([])
+ for layer in original_layers: # We preserve layer order.
+ config = layer.get_config()
+ # This will not work for nested subclassed models used as layers.
+ # This would be theoretically possible to support, but would add complexity.
+ # Only do it if users complain.
+ if isinstance(layer, Network) and not layer._is_graph_network:
+ raise ValueError('We do not support the use of nested subclassed models '
+ 'in `model_to_estimator` at this time. Found nested '
+ 'model: %s' % layer)
+ fresh_layer = layer.__class__.from_config(config)
+ name = layers_to_names[layer]
+ setattr(model, name, fresh_layer)
+
+ # Cache original model build attributes (in addition to layers)
+ if (not hasattr(model, '_original_attributes_cache') or
+ model._original_attributes_cache is None):
+ if model.built:
+ attributes_to_cache = [
+ 'inputs',
+ 'outputs',
+ '_feed_outputs',
+ '_feed_output_names',
+ '_feed_output_shapes',
+ '_feed_loss_fns',
+ 'loss_weights_list',
+ 'targets',
+ '_feed_targets',
+ 'sample_weight_modes',
+ 'weighted_metrics',
+ 'metrics_names',
+ 'metrics_tensors',
+ 'metrics_updates',
+ 'stateful_metric_names',
+ 'total_loss',
+ 'sample_weights',
+ '_feed_sample_weights',
+ 'train_function',
+ 'test_function',
+ 'predict_function',
+ '_collected_trainable_weights',
+ '_feed_inputs',
+ '_feed_input_names',
+ '_feed_input_shapes',
+ 'optimizer',
+ ]
+ for name in attributes_to_cache:
+ attributes_cache[name] = getattr(model, name)
+ model._original_attributes_cache = data_structures.NoDependency(
+ attributes_cache)
+ # Reset built state
+ model.built = False
+ model.inputs = None
+ model.outputs = None
+
+
+def in_place_subclassed_model_state_restoration(model):
+ """Restores the original state of a model after it was "reset".
+
+  This undoes the action of `_in_place_subclassed_model_reset`, which is called
+ in `clone_and_build_model` if `in_place_reset` is set to True.
+
+ Args:
+ model: Instance of a Keras model created via subclassing, on which
+ `_in_place_subclassed_model_reset` was previously called.
+ """
+ assert not model._is_graph_network
+ # Restore layers and build attributes
+ if (hasattr(model, '_original_attributes_cache') and
+ model._original_attributes_cache is not None):
+ # Models have sticky attribute assignment, so we want to be careful to add
+ # back the previous attributes and track Layers by their original names
+ # without adding dependencies on "utility" attributes which Models exempt
+ # when they're constructed.
+ model._layers = data_structures.NoDependency([])
+ for name, value in model._original_attributes_cache.items():
+ if not isinstance(value, checkpointable.CheckpointableBase):
+ # If this value is not already checkpointable, it's probably that way
+ # for a reason; we don't want to start tracking data structures that the
+ # original Model didn't.
+ value = data_structures.NoDependency(value)
+ setattr(model, name, value)
+ model._original_attributes_cache = None
+ else:
+ # Restore to the state of a never-called model.
+ model.built = False
+ model.inputs = None
+ model.outputs = None
+
+
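A hypothetical round trip with the two helpers above (assuming this module is importable as `tensorflow.python.keras.models`; `TinyModel` is an invented stand-in for a user's subclassed model):

```python
# Sketch: destructively reset a subclassed model, then restore its state.
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import models as keras_models


class TinyModel(tf.keras.Model):

  def __init__(self):
    super(TinyModel, self).__init__()
    self.dense = tf.keras.layers.Dense(2)

  def call(self, x):
    return self.dense(x)


m = TinyModel()
m.compile('rmsprop', 'mse')
m.train_on_batch(np.ones((4, 3)), np.ones((4, 2)))      # builds the model
keras_models._in_place_subclassed_model_reset(m)        # destructive reset
keras_models.in_place_subclassed_model_state_restoration(m)  # undo
```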
+def clone_and_build_model(
+ model, input_tensors=None, target_tensors=None, custom_objects=None,
+ compile_clone=True, in_place_reset=False, optimizer_iterations=None):
+ """Clone a `Model` and build/compile it with the same settings used before.
+
+  This function can be run in the same graph or in a separate graph from the
+ model. When using a separate graph, `in_place_reset` must be `False`.
+
+ Args:
+ model: `tf.keras.Model` object. Can be Functional, Sequential, or
+ sub-classed.
+ input_tensors: Optional list of input tensors to build the model upon. If
+ not provided, placeholders will be created.
+ target_tensors: Optional list of target tensors for compiling the model. If
+ not provided, placeholders will be created.
+ custom_objects: Optional dictionary mapping string names to custom classes
+ or functions.
+ compile_clone: Boolean, whether to compile model clone (default `True`).
+ in_place_reset: Boolean, whether to reset the model in place. Only used if
+ the model is not a graph network. If the model is a subclassed model, then
+ this argument must be set to `True` (default `False`). To restore the
+ original model, use the function
+ `in_place_subclassed_model_state_restoration(model)`.
+ optimizer_iterations: An iterations variable that will be incremented by the
+ optimizer if the clone is compiled. This argument is used when a Keras
+ model is cloned into an Estimator model function, because Estimators
+ create their own global step variable.
+
+ Returns:
+ Clone of the model.
+
+ Raises:
+ ValueError: if trying to clone a subclassed model, and `in_place_reset` is
+ set to False.
+ """
+ if model._is_graph_network:
+ if custom_objects:
+ with CustomObjectScope(custom_objects):
+ clone = clone_model(model, input_tensors=input_tensors)
+ else:
+ clone = clone_model(model, input_tensors=input_tensors)
+ else:
+ if not in_place_reset:
+ raise ValueError(
+ 'Model is not a graph network (usually means that it is a subclassed '
+ 'model). The model cannot be cloned, but there is a workaround where '
+ 'the model is reset in-place. To use this, please set the argument '
+ '`in_place_reset` to `True`. This will reset the attributes in the '
+ 'original model. To restore the attributes, call '
+ '`in_place_subclassed_model_state_restoration(model)`.')
+ clone = model
+ _in_place_subclassed_model_reset(clone)
+ if input_tensors is not None:
+ if isinstance(input_tensors, (list, tuple)) and len(input_tensors) == 1:
+ input_tensors = input_tensors[0]
+ clone._set_inputs(input_tensors)
+
+ # Compile/Build model
+ if not compile_clone:
+ if isinstance(clone, Sequential):
+ clone.build()
+ elif model.optimizer:
+ if isinstance(model.optimizer, optimizers.TFOptimizer):
+ optimizer = optimizers.TFOptimizer(
+ model.optimizer.optimizer, optimizer_iterations)
+ K.track_tf_optimizer(optimizer)
+ else:
+ optimizer_config = model.optimizer.get_config()
+ optimizer = model.optimizer.__class__.from_config(optimizer_config)
+ if optimizer_iterations is not None:
+ optimizer.iterations = optimizer_iterations
+
+ clone.compile(
+ optimizer,
+ model.loss,
+ metrics=model.metrics,
+ loss_weights=model.loss_weights,
+ sample_weight_mode=model.sample_weight_mode,
+ weighted_metrics=model.weighted_metrics,
+ target_tensors=target_tensors)
+
+ return clone
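A short usage sketch mirroring the tests below: clone a compiled functional model into a fresh session with its compile settings preserved (import path as used in this patch):

```python
# Sketch: clone_and_build_model on a compiled functional model.
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import models as keras_models

inp = tf.keras.Input(shape=(4,))
out = tf.keras.layers.Dense(4)(inp)
model = tf.keras.Model(inp, out)
model.compile('rmsprop', 'mse')

# Everything should work in a new session, as in the tests.
tf.keras.backend.clear_session()
clone = keras_models.clone_and_build_model(model, compile_clone=True)
clone.train_on_batch(np.random.random((10, 4)),
                     np.random.random((10, 4)))
```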
diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py
index 1385ad5390..c550caeb80 100644
--- a/tensorflow/python/keras/models_test.py
+++ b/tensorflow/python/keras/models_test.py
@@ -18,20 +18,42 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import copy
import os
import numpy as np
from tensorflow.python import keras
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
+from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import metrics
+from tensorflow.python.keras import models
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
from tensorflow.python.training import adam
+class TestModel(keras.Model):
+ """A model subclass."""
+
+ def __init__(self, n_outputs=4, trainable=True):
+ """A test class with one dense layer and number of outputs as a variable."""
+ super(TestModel, self).__init__()
+ self.layer1 = keras.layers.Dense(n_outputs)
+ self.n_outputs = resource_variable_ops.ResourceVariable(
+ n_outputs, trainable=trainable)
+
+ def call(self, x):
+ return self.layer1(x)
+
+
class TestModelCloning(test.TestCase):
def test_clone_sequential_model(self):
- with self.test_session():
+ with self.cached_session():
val_a = np.random.random((10, 4))
val_out = np.random.random((10, 4))
@@ -44,7 +66,7 @@ class TestModelCloning(test.TestCase):
# Everything should work in a new session.
keras.backend.clear_session()
- with self.test_session():
+ with self.cached_session():
# With placeholder creation
new_model = keras.models.clone_model(model)
# update ops from batch norm needs to be included
@@ -69,7 +91,7 @@ class TestModelCloning(test.TestCase):
new_model.train_on_batch(None, val_out)
def test_clone_functional_model(self):
- with self.test_session():
+ with self.cached_session():
val_a = np.random.random((10, 4))
val_b = np.random.random((10, 4))
val_out = np.random.random((10, 4))
@@ -90,7 +112,7 @@ class TestModelCloning(test.TestCase):
# Everything should work in a new session.
keras.backend.clear_session()
- with self.test_session():
+ with self.cached_session():
# With placeholder creation
new_model = keras.models.clone_model(model)
self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2)
@@ -117,7 +139,7 @@ class TestModelCloning(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_clone_functional_model_with_masking(self):
- with self.test_session():
+ with self.cached_session():
x = np.array([[[1], [1]], [[0], [0]]])
inputs = keras.Input((2, 1))
outputs = keras.layers.Masking(mask_value=0)(inputs)
@@ -169,6 +191,7 @@ class CheckpointingTests(test.TestCase):
model.load_weights(save_prefix)
self.assertEqual(12., self.evaluate(beta1_power))
+
class TestModelBackend(test.TestCase):
def test_model_backend_float64_use_cases(self):
@@ -183,5 +206,196 @@ class TestModelBackend(test.TestCase):
keras.backend.set_floatx(floatx)
+
+class TestModelDeepCopy(test.TestCase):
+
+ def test_deep_copy_eager_mode_trainable(self):
+ with context.eager_mode():
+ x = random_ops.random_normal((32, 4))
+ model = TestModel(trainable=True)
+ model(x) # Initialize Variables.
+ model_copy = copy.deepcopy(model)
+ self.assertEqual(len(model_copy.trainable_variables), 3)
+ model_copy.n_outputs.assign(1200)
+ self.assertFalse(
+ np.allclose(model_copy.n_outputs.numpy(),
+ model.n_outputs.numpy()))
+
+ def test_deep_copy_eager_mode_not_trainable(self):
+ with context.eager_mode():
+ x = random_ops.random_normal((32, 4))
+ model = TestModel(trainable=False)
+ model(x)
+ model_copy = copy.deepcopy(model)
+ self.assertEqual(len(model_copy.trainable_variables), 2)
+
+ weights = model_copy.get_weights()
+ weights = [w * 4 for w in weights]
+ model_copy.set_weights(weights)
+ self.assertFalse(
+ np.allclose(model.get_weights()[0],
+ model_copy.get_weights()[0]))
+
+
+class TestCloneAndBuildModel(test.TestCase):
+
+ def test_clone_and_build_non_compiled_model(self):
+ with self.cached_session():
+ inp = np.random.random((10, 4))
+ out = np.random.random((10, 4))
+
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(4, input_shape=(4,)))
+ model.add(keras.layers.BatchNormalization())
+ model.add(keras.layers.Dropout(0.5))
+ model.add(keras.layers.Dense(4))
+
+ # Everything should work in a new session.
+ keras.backend.clear_session()
+
+ with self.cached_session():
+ # With placeholder creation
+ new_model = models.clone_and_build_model(model, compile_clone=True)
+ with self.assertRaisesRegexp(RuntimeError, 'must compile'):
+ new_model.evaluate(inp, out)
+ with self.assertRaisesRegexp(RuntimeError, 'must compile'):
+ new_model.train_on_batch(inp, out)
+ new_model.compile('rmsprop', 'mse')
+ new_model.train_on_batch(inp, out)
+
+ # Create new tensors for inputs and targets
+ input_a = keras.Input(shape=(4,))
+ target_a = keras.Input(shape=(4,))
+ new_model = models.clone_and_build_model(model, input_tensors=input_a,
+ target_tensors=[target_a],
+ compile_clone=True)
+ with self.assertRaisesRegexp(RuntimeError, 'must compile'):
+ new_model.evaluate(inp, out)
+ with self.assertRaisesRegexp(RuntimeError, 'must compile'):
+ new_model.train_on_batch(inp, out)
+ new_model.compile('rmsprop', 'mse')
+ new_model.train_on_batch(inp, out)
+
+ def _assert_same_compile_params(self, model):
+    """Assert that the model was compiled with the expected parameters."""
+
+ self.assertEqual('mse', model.loss)
+ self.assertTrue(
+ isinstance(model.optimizer, keras.optimizers.RMSprop))
+ self.assertEqual(['acc', metrics.categorical_accuracy], model.metrics)
+
+ def _clone_and_build_test_helper(self, model, is_subclassed=False):
+ inp = np.random.random((10, 4))
+ out = np.random.random((10, 4))
+
+ # Everything should work in a new session.
+ keras.backend.clear_session()
+
+ with self.cached_session():
+ # With placeholder creation
+ new_model = models.clone_and_build_model(
+ model, compile_clone=True, in_place_reset=is_subclassed)
+
+ self._assert_same_compile_params(new_model)
+ new_model.train_on_batch(inp, out)
+ new_model.evaluate(inp, out)
+
+ # Create new tensors for inputs and targets
+ input_a = keras.Input(shape=(4,), name='a')
+ new_model = models.clone_and_build_model(
+ model, input_tensors=input_a, compile_clone=True,
+ in_place_reset=is_subclassed)
+ self._assert_same_compile_params(new_model)
+ new_model.train_on_batch(inp, out)
+ new_model.evaluate(inp, out)
+
+ target_a = keras.Input(shape=(4,), name='b')
+ new_model = models.clone_and_build_model(
+ model, input_tensors=input_a, target_tensors=[target_a],
+ compile_clone=True, in_place_reset=is_subclassed)
+ self._assert_same_compile_params(new_model)
+ new_model.train_on_batch(inp, out)
+ new_model.evaluate(inp, out)
+
+ def test_clone_and_build_compiled_sequential_model(self):
+ with self.cached_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(4, input_shape=(4,)))
+ model.add(keras.layers.BatchNormalization())
+ model.add(keras.layers.Dropout(0.5))
+ model.add(keras.layers.Dense(4))
+ model.compile('rmsprop', 'mse',
+ metrics=['acc', metrics.categorical_accuracy])
+
+ self._clone_and_build_test_helper(model)
+
+ def test_clone_and_build_functional_model(self):
+ with self.cached_session():
+ input_a = keras.Input(shape=(4,))
+ dense_1 = keras.layers.Dense(4,)
+ dense_2 = keras.layers.Dense(4,)
+
+ x_a = dense_1(input_a)
+ x_a = keras.layers.Dropout(0.5)(x_a)
+ x_a = keras.layers.BatchNormalization()(x_a)
+ x_a = dense_2(x_a)
+ model = keras.models.Model(input_a, x_a)
+ model.compile('rmsprop', 'mse',
+ metrics=['acc', metrics.categorical_accuracy])
+
+ self._clone_and_build_test_helper(model)
+
+ def test_clone_and_build_subclassed_model(self):
+ class SubclassedModel(keras.Model):
+
+ def __init__(self):
+ super(SubclassedModel, self).__init__()
+ self.layer1 = keras.layers.Dense(4)
+ self.layer2 = keras.layers.Dense(4)
+
+ def call(self, inp):
+ out = self.layer1(inp)
+ out = keras.layers.BatchNormalization()(out)
+ out = keras.layers.Dropout(0.5)(out)
+ out = self.layer2(out)
+ return out
+
+ with self.cached_session():
+ model = SubclassedModel()
+ model.compile('rmsprop', 'mse',
+ metrics=['acc', metrics.categorical_accuracy])
+ self._clone_and_build_test_helper(model, True)
+
+ def assert_optimizer_iterations_increases(self, optimizer):
+ with self.cached_session():
+ input_a = keras.Input(shape=(4,))
+ dense_1 = keras.layers.Dense(4,)
+ dense_2 = keras.layers.Dense(4,)
+
+ x_a = dense_1(input_a)
+ x_a = keras.layers.Dropout(0.5)(x_a)
+ x_a = keras.layers.BatchNormalization()(x_a)
+ x_a = dense_2(x_a)
+ model = keras.models.Model(input_a, x_a)
+ model.compile(optimizer, 'mse',
+ metrics=['acc', metrics.categorical_accuracy])
+
+ global_step = keras.backend.variable(123, dtype=dtypes.int64)
+ clone_model = models.clone_and_build_model(
+ model, compile_clone=True, optimizer_iterations=global_step)
+
+ inp = np.random.random((10, 4))
+ out = np.random.random((10, 4))
+ clone_model.train_on_batch(inp, out)
+
+ self.assertEqual(K.eval(global_step), 124)
+
+ def test_replace_tf_optimizer_iterations_variable(self):
+ self.assert_optimizer_iterations_increases(adam.AdamOptimizer(0.01))
+
+ def test_replace_keras_optimizer_iterations_variable(self):
+ self.assert_optimizer_iterations_increases('adam')
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/optimizers.py b/tensorflow/python/keras/optimizers.py
index f339a7e047..ab13e5c632 100644
--- a/tensorflow/python/keras/optimizers.py
+++ b/tensorflow/python/keras/optimizers.py
@@ -692,14 +692,18 @@ class TFOptimizer(Optimizer, checkpointable.CheckpointableBase):
"""Wrapper class for native TensorFlow optimizers.
"""
- def __init__(self, optimizer): # pylint: disable=super-init-not-called
+ def __init__(self, optimizer, iterations=None): # pylint: disable=super-init-not-called
self.optimizer = optimizer
self._track_checkpointable(optimizer, name='optimizer')
- with K.name_scope(self.__class__.__name__):
- self.iterations = K.variable(0, dtype='int64', name='iterations')
+ if iterations is None:
+ with K.name_scope(self.__class__.__name__):
+ self.iterations = K.variable(0, dtype='int64', name='iterations')
+ else:
+ self.iterations = iterations
+ self._track_checkpointable(self.iterations, name='global_step')
def apply_gradients(self, grads):
- self.optimizer.apply_gradients(grads)
+ self.optimizer.apply_gradients(grads, global_step=self.iterations)
def get_grads(self, loss, params):
return self.optimizer.compute_gradients(loss, params)
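A sketch of the new `iterations` hook: a caller-owned step counter is handed to the wrapper and incremented by `apply_gradients` via `global_step` (imports as used elsewhere in this patch):

```python
# Sketch: share an external step variable with a wrapped native optimizer.
from tensorflow.python import keras
from tensorflow.python.training.adam import AdamOptimizer

global_step = keras.backend.variable(0, dtype='int64', name='global_step')
opt = keras.optimizers.TFOptimizer(AdamOptimizer(0.01), global_step)
keras.backend.track_tf_optimizer(opt)  # keep the wrapper tracked, as the tests do
```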
@@ -813,7 +817,9 @@ def get(identifier):
"""
# Wrap TF optimizer instances
if isinstance(identifier, tf_optimizer_module.Optimizer):
- return TFOptimizer(identifier)
+ opt = TFOptimizer(identifier)
+ K.track_tf_optimizer(opt)
+ return opt
if isinstance(identifier, dict):
return deserialize(identifier)
elif isinstance(identifier, six.string_types):
diff --git a/tensorflow/python/keras/optimizers_test.py b/tensorflow/python/keras/optimizers_test.py
index 4d295351f5..9a68fc0e35 100644
--- a/tensorflow/python/keras/optimizers_test.py
+++ b/tensorflow/python/keras/optimizers_test.py
@@ -21,6 +21,8 @@ from __future__ import print_function
import numpy as np
from tensorflow.python import keras
+from tensorflow.python.eager import context
+from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
from tensorflow.python.training.adam import AdamOptimizer
@@ -140,6 +142,7 @@ class KerasOptimizersTest(test.TestCase):
2, input_shape=(3,), kernel_constraint=keras.constraints.MaxNorm(1)))
# This is possible
model.compile(loss='mean_squared_error', optimizer=optimizer)
+ keras.backend.track_tf_optimizer(optimizer)
model.fit(np.random.random((5, 3)),
np.random.random((5, 2)),
epochs=1,
@@ -153,6 +156,7 @@ class KerasOptimizersTest(test.TestCase):
with self.assertRaises(NotImplementedError):
optimizer.from_config(None)
+ @test_util.run_in_graph_and_eager_modes
def test_tfoptimizer_iterations(self):
with self.test_session():
optimizer = keras.optimizers.TFOptimizer(AdamOptimizer(0.01))
@@ -160,6 +164,7 @@ class KerasOptimizersTest(test.TestCase):
model.add(keras.layers.Dense(
2, input_shape=(3,), kernel_constraint=keras.constraints.MaxNorm(1)))
model.compile(loss='mean_squared_error', optimizer=optimizer)
+ keras.backend.track_tf_optimizer(optimizer)
self.assertEqual(keras.backend.get_value(model.optimizer.iterations), 0)
model.fit(np.random.random((55, 3)),
@@ -169,11 +174,15 @@ class KerasOptimizersTest(test.TestCase):
verbose=0)
self.assertEqual(keras.backend.get_value(model.optimizer.iterations), 11)
- model.fit(np.random.random((20, 3)),
- np.random.random((20, 2)),
- steps_per_epoch=8,
- verbose=0)
- self.assertEqual(keras.backend.get_value(model.optimizer.iterations), 19)
+ if not context.executing_eagerly():
+ # TODO(kathywu): investigate why training with an array input and
+ # setting the argument steps_per_epoch does not work in eager mode.
+ model.fit(np.random.random((20, 3)),
+ np.random.random((20, 2)),
+ steps_per_epoch=8,
+ verbose=0)
+ self.assertEqual(
+ keras.backend.get_value(model.optimizer.iterations), 19)
def test_negative_clipvalue_or_clipnorm(self):
with self.assertRaises(ValueError):
diff --git a/tensorflow/python/keras/preprocessing/__init__.py b/tensorflow/python/keras/preprocessing/__init__.py
index 2f08f88600..0860eed3cf 100644
--- a/tensorflow/python/keras/preprocessing/__init__.py
+++ b/tensorflow/python/keras/preprocessing/__init__.py
@@ -23,6 +23,8 @@ import keras_preprocessing
from tensorflow.python.keras import backend
from tensorflow.python.keras import utils
+# This exists for compatibility with prior versions of keras_preprocessing.
+# TODO(fchollet): remove in the future.
keras_preprocessing.set_keras_submodules(backend=backend, utils=utils)
from tensorflow.python.keras.preprocessing import image
diff --git a/tensorflow/python/keras/preprocessing/image.py b/tensorflow/python/keras/preprocessing/image.py
index ba227385ef..e33993950d 100644
--- a/tensorflow/python/keras/preprocessing/image.py
+++ b/tensorflow/python/keras/preprocessing/image.py
@@ -27,6 +27,9 @@ try:
except ImportError:
pass
+from tensorflow.python.keras import backend
+from tensorflow.python.keras import utils
+from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
random_rotation = image.random_rotation
@@ -38,14 +41,482 @@ random_channel_shift = image.random_channel_shift
apply_brightness_shift = image.apply_brightness_shift
random_brightness = image.random_brightness
apply_affine_transform = image.apply_affine_transform
-array_to_img = image.array_to_img
-img_to_array = image.img_to_array
-save_img = image.save_img
load_img = image.load_img
-ImageDataGenerator = image.ImageDataGenerator
-Iterator = image.Iterator
-NumpyArrayIterator = image.NumpyArrayIterator
-DirectoryIterator = image.DirectoryIterator
+
+
+@tf_export('keras.preprocessing.image.array_to_img')
+def array_to_img(x, data_format=None, scale=True, dtype=None):
+ """Converts a 3D Numpy array to a PIL Image instance.
+
+ Arguments:
+ x: Input Numpy array.
+ data_format: Image data format.
+ either "channels_first" or "channels_last".
+ scale: Whether to rescale image values
+ to be within `[0, 255]`.
+ dtype: Dtype to use.
+
+ Returns:
+ A PIL Image instance.
+
+ Raises:
+ ImportError: if PIL is not available.
+ ValueError: if invalid `x` or `data_format` is passed.
+ """
+
+ if data_format is None:
+ data_format = backend.image_data_format()
+ kwargs = {}
+ if 'dtype' in tf_inspect.getfullargspec(image.array_to_img)[0]:
+ if dtype is None:
+ dtype = backend.floatx()
+ kwargs['dtype'] = dtype
+ return image.array_to_img(x, data_format=data_format, scale=scale, **kwargs)
+
+
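The `tf_inspect.getfullargspec` check above implements a feature-detection shim: `dtype` is forwarded only when the installed keras_preprocessing accepts it. A standalone sketch of the pattern using the standard library's `inspect` (an analogue, not this module's code):

```python
# Forward a keyword argument only if the wrapped callable accepts it.
import inspect

def forward_dtype(fn, *args, dtype=None, **kwargs):
  if dtype is not None and 'dtype' in inspect.getfullargspec(fn).args:
    kwargs['dtype'] = dtype
  return fn(*args, **kwargs)

def old_api(x):              # pre-dtype signature
  return x

def new_api(x, dtype=None):  # dtype-aware signature
  return (x, dtype)

assert forward_dtype(old_api, 1, dtype='float32') == 1
assert forward_dtype(new_api, 1, dtype='float32') == (1, 'float32')
```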
+@tf_export('keras.preprocessing.image.img_to_array')
+def img_to_array(img, data_format=None, dtype=None):
+ """Converts a PIL Image instance to a Numpy array.
+
+ Arguments:
+ img: PIL Image instance.
+ data_format: Image data format,
+ either "channels_first" or "channels_last".
+ dtype: Dtype to use for the returned array.
+
+ Returns:
+ A 3D Numpy array.
+
+ Raises:
+ ValueError: if invalid `img` or `data_format` is passed.
+ """
+
+ if data_format is None:
+ data_format = backend.image_data_format()
+ kwargs = {}
+ if 'dtype' in tf_inspect.getfullargspec(image.img_to_array)[0]:
+ if dtype is None:
+ dtype = backend.floatx()
+ kwargs['dtype'] = dtype
+ return image.img_to_array(img, data_format=data_format, **kwargs)
+
+
+@tf_export('keras.preprocessing.image.save_img')
+def save_img(path,
+ x,
+ data_format=None,
+ file_format=None,
+ scale=True,
+ **kwargs):
+ """Saves an image stored as a Numpy array to a path or file object.
+
+ Arguments:
+ path: Path or file object.
+ x: Numpy array.
+ data_format: Image data format,
+ either "channels_first" or "channels_last".
+ file_format: Optional file format override. If omitted, the
+ format to use is determined from the filename extension.
+ If a file object was used instead of a filename, this
+ parameter should always be used.
+ scale: Whether to rescale image values to be within `[0, 255]`.
+ **kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
+ """
+ if data_format is None:
+ data_format = backend.image_data_format()
+ image.save_img(path,
+ x,
+ data_format=data_format,
+ file_format=file_format,
+ scale=scale, **kwargs)
+
+
+@tf_export('keras.preprocessing.image.Iterator')
+class Iterator(image.Iterator, utils.Sequence):
+ pass
+
+
+@tf_export('keras.preprocessing.image.DirectoryIterator')
+class DirectoryIterator(image.DirectoryIterator, Iterator):
+ """Iterator capable of reading images from a directory on disk.
+
+ Arguments:
+ directory: Path to the directory to read images from.
+ Each subdirectory in this directory will be
+ considered to contain images from one class,
+ or alternatively you could specify class subdirectories
+ via the `classes` argument.
+ image_data_generator: Instance of `ImageDataGenerator`
+ to use for random transformations and normalization.
+ target_size: tuple of integers, dimensions to resize input images to.
+ color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`.
+ Color mode to read images.
+ classes: Optional list of strings, names of subdirectories
+ containing images from each class (e.g. `["dogs", "cats"]`).
+ It will be computed automatically if not set.
+ class_mode: Mode for yielding the targets:
+ `"binary"`: binary targets (if there are only two classes),
+ `"categorical"`: categorical targets,
+ `"sparse"`: integer targets,
+ `"input"`: targets are images identical to input images (mainly
+ used to work with autoencoders),
+ `None`: no targets get yielded (only input images are yielded).
+ batch_size: Integer, size of a batch.
+ shuffle: Boolean, whether to shuffle the data between epochs.
+ seed: Random seed for data shuffling.
+ data_format: String, one of `channels_first`, `channels_last`.
+ save_to_dir: Optional directory where to save the pictures
+ being yielded, in a viewable format. This is useful
+ for visualizing the random transformations being
+ applied, for debugging purposes.
+ save_prefix: String prefix to use for saving sample
+ images (if `save_to_dir` is set).
+ save_format: Format to use for saving sample images
+ (if `save_to_dir` is set).
+ subset: Subset of data (`"training"` or `"validation"`) if
+ validation_split is set in ImageDataGenerator.
+ interpolation: Interpolation method used to resample the image if the
+ target size is different from that of the loaded image.
+ Supported methods are "nearest", "bilinear", and "bicubic".
+ If PIL version 1.1.3 or newer is installed, "lanczos" is also
+ supported. If PIL version 3.4.0 or newer is installed, "box" and
+ "hamming" are also supported. By default, "nearest" is used.
+ dtype: Dtype to use for generated arrays.
+ """
+
+ def __init__(self, directory, image_data_generator,
+ target_size=(256, 256),
+ color_mode='rgb',
+ classes=None,
+ class_mode='categorical',
+ batch_size=32,
+ shuffle=True,
+ seed=None,
+ data_format=None,
+ save_to_dir=None,
+ save_prefix='',
+ save_format='png',
+ follow_links=False,
+ subset=None,
+ interpolation='nearest',
+ dtype=None):
+ if data_format is None:
+ data_format = backend.image_data_format()
+ kwargs = {}
+ if 'dtype' in tf_inspect.getfullargspec(
+ image.ImageDataGenerator.__init__)[0]:
+ if dtype is None:
+ dtype = backend.floatx()
+ kwargs['dtype'] = dtype
+ super(DirectoryIterator, self).__init__(
+ directory, image_data_generator,
+ target_size=target_size,
+ color_mode=color_mode,
+ classes=classes,
+ class_mode=class_mode,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ seed=seed,
+ data_format=data_format,
+ save_to_dir=save_to_dir,
+ save_prefix=save_prefix,
+ save_format=save_format,
+ follow_links=follow_links,
+ subset=subset,
+ interpolation=interpolation,
+ **kwargs)
+
+
+@tf_export('keras.preprocessing.image.NumpyArrayIterator')
+class NumpyArrayIterator(image.NumpyArrayIterator, Iterator):
+ """Iterator yielding data from a Numpy array.
+
+ Arguments:
+ x: Numpy array of input data or tuple.
+ If tuple, the second elements is either
+ another numpy array or a list of numpy arrays,
+ each of which gets passed
+ through as an output without any modifications.
+ y: Numpy array of targets data.
+ image_data_generator: Instance of `ImageDataGenerator`
+ to use for random transformations and normalization.
+ batch_size: Integer, size of a batch.
+ shuffle: Boolean, whether to shuffle the data between epochs.
+ sample_weight: Numpy array of sample weights.
+ seed: Random seed for data shuffling.
+ data_format: String, one of `channels_first`, `channels_last`.
+ save_to_dir: Optional directory where to save the pictures
+ being yielded, in a viewable format. This is useful
+ for visualizing the random transformations being
+ applied, for debugging purposes.
+ save_prefix: String prefix to use for saving sample
+ images (if `save_to_dir` is set).
+ save_format: Format to use for saving sample images
+ (if `save_to_dir` is set).
+ subset: Subset of data (`"training"` or `"validation"`) if
+ validation_split is set in ImageDataGenerator.
+ dtype: Dtype to use for the generated arrays.
+ """
+
+ def __init__(self, x, y, image_data_generator,
+ batch_size=32,
+ shuffle=False,
+ sample_weight=None,
+ seed=None,
+ data_format=None,
+ save_to_dir=None,
+ save_prefix='',
+ save_format='png',
+ subset=None,
+ dtype=None):
+ if data_format is None:
+ data_format = backend.image_data_format()
+ kwargs = {}
+ if 'dtype' in tf_inspect.getfullargspec(
+ image.NumpyArrayIterator.__init__)[0]:
+ if dtype is None:
+ dtype = backend.floatx()
+ kwargs['dtype'] = dtype
+ super(NumpyArrayIterator, self).__init__(
+ x, y, image_data_generator,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ sample_weight=sample_weight,
+ seed=seed,
+ data_format=data_format,
+ save_to_dir=save_to_dir,
+ save_prefix=save_prefix,
+ save_format=save_format,
+ subset=subset,
+ **kwargs)
+
+
+@tf_export('keras.preprocessing.image.ImageDataGenerator')
+class ImageDataGenerator(image.ImageDataGenerator):
+ """Generate batches of tensor image data with real-time data augmentation.
+
+ The data will be looped over (in batches).
+
+ Arguments:
+ featurewise_center: Boolean.
+ Set input mean to 0 over the dataset, feature-wise.
+ samplewise_center: Boolean. Set each sample mean to 0.
+ featurewise_std_normalization: Boolean.
+ Divide inputs by std of the dataset, feature-wise.
+ samplewise_std_normalization: Boolean. Divide each input by its std.
+ zca_epsilon: epsilon for ZCA whitening. Default is 1e-6.
+ zca_whitening: Boolean. Apply ZCA whitening.
+ rotation_range: Int. Degree range for random rotations.
+ width_shift_range: Float, 1-D array-like or int
+ - float: fraction of total width, if < 1, or pixels if >= 1.
+ - 1-D array-like: random elements from the array.
+ - int: integer number of pixels from interval
+ `(-width_shift_range, +width_shift_range)`
+ - With `width_shift_range=2` possible values
+ are integers `[-1, 0, +1]`,
+ same as with `width_shift_range=[-1, 0, +1]`,
+ while with `width_shift_range=1.0` possible values are floats
+ in the interval [-1.0, +1.0).
+ height_shift_range: Float, 1-D array-like or int
+ - float: fraction of total height, if < 1, or pixels if >= 1.
+ - 1-D array-like: random elements from the array.
+ - int: integer number of pixels from interval
+ `(-height_shift_range, +height_shift_range)`
+ - With `height_shift_range=2` possible values
+ are integers `[-1, 0, +1]`,
+ same as with `height_shift_range=[-1, 0, +1]`,
+ while with `height_shift_range=1.0` possible values are floats
+ in the interval [-1.0, +1.0).
+ brightness_range: Tuple or list of two floats. Range for picking
+ a brightness shift value from.
+      shear_range: Float. Shear intensity
+          (shear angle in counter-clockwise direction, in degrees).
+ zoom_range: Float or [lower, upper]. Range for random zoom.
+ If a float, `[lower, upper] = [1-zoom_range, 1+zoom_range]`.
+ channel_shift_range: Float. Range for random channel shifts.
+ fill_mode: One of {"constant", "nearest", "reflect" or "wrap"}.
+ Default is 'nearest'.
+ Points outside the boundaries of the input are filled
+ according to the given mode:
+ - 'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k)
+ - 'nearest': aaaaaaaa|abcd|dddddddd
+ - 'reflect': abcddcba|abcd|dcbaabcd
+ - 'wrap': abcdabcd|abcd|abcdabcd
+ cval: Float or Int.
+ Value used for points outside the boundaries
+ when `fill_mode = "constant"`.
+ horizontal_flip: Boolean. Randomly flip inputs horizontally.
+ vertical_flip: Boolean. Randomly flip inputs vertically.
+ rescale: rescaling factor. Defaults to None.
+ If None or 0, no rescaling is applied,
+ otherwise we multiply the data by the value provided
+ (after applying all other transformations).
+      preprocessing_function: function that will be applied to each input.
+ The function will run after the image is resized and augmented.
+ The function should take one argument:
+ one image (Numpy tensor with rank 3),
+ and should output a Numpy tensor with the same shape.
+ data_format: Image data format,
+ either "channels_first" or "channels_last".
+ "channels_last" mode means that the images should have shape
+ `(samples, height, width, channels)`,
+ "channels_first" mode means that the images should have shape
+ `(samples, channels, height, width)`.
+ It defaults to the `image_data_format` value found in your
+ Keras config file at `~/.keras/keras.json`.
+ If you never set it, then it will be "channels_last".
+ validation_split: Float. Fraction of images reserved for validation
+ (strictly between 0 and 1).
+ dtype: Dtype to use for the generated arrays.
+
+ Examples:
+
+ Example of using `.flow(x, y)`:
+
+ ```python
+ (x_train, y_train), (x_test, y_test) = cifar10.load_data()
+ y_train = np_utils.to_categorical(y_train, num_classes)
+ y_test = np_utils.to_categorical(y_test, num_classes)
+ datagen = ImageDataGenerator(
+ featurewise_center=True,
+ featurewise_std_normalization=True,
+ rotation_range=20,
+ width_shift_range=0.2,
+ height_shift_range=0.2,
+ horizontal_flip=True)
+ # compute quantities required for featurewise normalization
+ # (std, mean, and principal components if ZCA whitening is applied)
+ datagen.fit(x_train)
+ # fits the model on batches with real-time data augmentation:
+ model.fit_generator(datagen.flow(x_train, y_train, batch_size=32),
+ steps_per_epoch=len(x_train) / 32, epochs=epochs)
+ # here's a more "manual" example
+ for e in range(epochs):
+ print('Epoch', e)
+ batches = 0
+ for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32):
+ model.fit(x_batch, y_batch)
+ batches += 1
+ if batches >= len(x_train) / 32:
+ # we need to break the loop by hand because
+ # the generator loops indefinitely
+ break
+ ```
+
+ Example of using `.flow_from_directory(directory)`:
+
+ ```python
+ train_datagen = ImageDataGenerator(
+ rescale=1./255,
+ shear_range=0.2,
+ zoom_range=0.2,
+ horizontal_flip=True)
+ test_datagen = ImageDataGenerator(rescale=1./255)
+ train_generator = train_datagen.flow_from_directory(
+ 'data/train',
+ target_size=(150, 150),
+ batch_size=32,
+ class_mode='binary')
+ validation_generator = test_datagen.flow_from_directory(
+ 'data/validation',
+ target_size=(150, 150),
+ batch_size=32,
+ class_mode='binary')
+ model.fit_generator(
+ train_generator,
+ steps_per_epoch=2000,
+ epochs=50,
+ validation_data=validation_generator,
+ validation_steps=800)
+ ```
+
+ Example of transforming images and masks together.
+
+ ```python
+ # we create two instances with the same arguments
+ data_gen_args = dict(featurewise_center=True,
+ featurewise_std_normalization=True,
+ rotation_range=90,
+ width_shift_range=0.1,
+ height_shift_range=0.1,
+ zoom_range=0.2)
+ image_datagen = ImageDataGenerator(**data_gen_args)
+ mask_datagen = ImageDataGenerator(**data_gen_args)
+ # Provide the same seed and keyword arguments to the fit and flow methods
+ seed = 1
+ image_datagen.fit(images, augment=True, seed=seed)
+ mask_datagen.fit(masks, augment=True, seed=seed)
+ image_generator = image_datagen.flow_from_directory(
+ 'data/images',
+ class_mode=None,
+ seed=seed)
+ mask_generator = mask_datagen.flow_from_directory(
+ 'data/masks',
+ class_mode=None,
+ seed=seed)
+ # combine generators into one which yields image and masks
+ train_generator = zip(image_generator, mask_generator)
+ model.fit_generator(
+ train_generator,
+ steps_per_epoch=2000,
+ epochs=50)
+ ```
+ """
+
+ def __init__(self,
+ featurewise_center=False,
+ samplewise_center=False,
+ featurewise_std_normalization=False,
+ samplewise_std_normalization=False,
+ zca_whitening=False,
+ zca_epsilon=1e-6,
+ rotation_range=0,
+ width_shift_range=0.,
+ height_shift_range=0.,
+ brightness_range=None,
+ shear_range=0.,
+ zoom_range=0.,
+ channel_shift_range=0.,
+ fill_mode='nearest',
+ cval=0.,
+ horizontal_flip=False,
+ vertical_flip=False,
+ rescale=None,
+ preprocessing_function=None,
+ data_format=None,
+ validation_split=0.0,
+ dtype=None):
+ if data_format is None:
+ data_format = backend.image_data_format()
+ kwargs = {}
+ if 'dtype' in tf_inspect.getfullargspec(
+ image.ImageDataGenerator.__init__)[0]:
+ if dtype is None:
+ dtype = backend.floatx()
+ kwargs['dtype'] = dtype
+ super(ImageDataGenerator, self).__init__(
+ featurewise_center=featurewise_center,
+ samplewise_center=samplewise_center,
+ featurewise_std_normalization=featurewise_std_normalization,
+ samplewise_std_normalization=samplewise_std_normalization,
+ zca_whitening=zca_whitening,
+ zca_epsilon=zca_epsilon,
+ rotation_range=rotation_range,
+ width_shift_range=width_shift_range,
+ height_shift_range=height_shift_range,
+ brightness_range=brightness_range,
+ shear_range=shear_range,
+ zoom_range=zoom_range,
+ channel_shift_range=channel_shift_range,
+ fill_mode=fill_mode,
+ cval=cval,
+ horizontal_flip=horizontal_flip,
+ vertical_flip=vertical_flip,
+ rescale=rescale,
+ preprocessing_function=preprocessing_function,
+ data_format=data_format,
+ validation_split=validation_split,
+ **kwargs)
tf_export('keras.preprocessing.image.random_rotation')(random_rotation)
tf_export('keras.preprocessing.image.random_shift')(random_shift)
@@ -59,11 +530,4 @@ tf_export(
tf_export('keras.preprocessing.image.random_brightness')(random_brightness)
tf_export(
'keras.preprocessing.image.apply_affine_transform')(apply_affine_transform)
-tf_export('keras.preprocessing.image.array_to_img')(array_to_img)
-tf_export('keras.preprocessing.image.img_to_array')(img_to_array)
-tf_export('keras.preprocessing.image.save_img')(save_img)
tf_export('keras.preprocessing.image.load_img')(load_img)
-tf_export('keras.preprocessing.image.ImageDataGenerator')(ImageDataGenerator)
-tf_export('keras.preprocessing.image.Iterator')(Iterator)
-tf_export('keras.preprocessing.image.NumpyArrayIterator')(NumpyArrayIterator)
-tf_export('keras.preprocessing.image.DirectoryIterator')(DirectoryIterator)
diff --git a/tensorflow/python/keras/preprocessing/sequence.py b/tensorflow/python/keras/preprocessing/sequence.py
index 116d3108d9..f014668909 100644
--- a/tensorflow/python/keras/preprocessing/sequence.py
+++ b/tensorflow/python/keras/preprocessing/sequence.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from keras_preprocessing import sequence
+from tensorflow.python.keras import utils
from tensorflow.python.util.tf_export import tf_export
pad_sequences = sequence.pad_sequences
@@ -28,11 +29,67 @@ make_sampling_table = sequence.make_sampling_table
skipgrams = sequence.skipgrams
# TODO(fchollet): consider making `_remove_long_seq` public.
_remove_long_seq = sequence._remove_long_seq # pylint: disable=protected-access
-TimeseriesGenerator = sequence.TimeseriesGenerator
+
+
+@tf_export('keras.preprocessing.sequence.TimeseriesGenerator')
+class TimeseriesGenerator(sequence.TimeseriesGenerator, utils.Sequence):
+ """Utility class for generating batches of temporal data.
+ This class takes in a sequence of data-points gathered at
+ equal intervals, along with time series parameters such as
+ stride, length of history, etc., to produce batches for
+ training/validation.
+ # Arguments
+ data: Indexable generator (such as list or Numpy array)
+ containing consecutive data points (timesteps).
+          The data should be 2D, and axis 0 is expected
+ to be the time dimension.
+ targets: Targets corresponding to timesteps in `data`.
+          It should have the same length as `data`.
+ length: Length of the output sequences (in number of timesteps).
+ sampling_rate: Period between successive individual timesteps
+ within sequences. For rate `r`, timesteps
+ `data[i]`, `data[i-r]`, ... `data[i - length]`
+          are used to create a sample sequence.
+ stride: Period between successive output sequences.
+ For stride `s`, consecutive output samples would
+ be centered around `data[i]`, `data[i+s]`, `data[i+2*s]`, etc.
+ start_index: Data points earlier than `start_index` will not be used
+ in the output sequences. This is useful to reserve part of the
+ data for test or validation.
+ end_index: Data points later than `end_index` will not be used
+ in the output sequences. This is useful to reserve part of the
+ data for test or validation.
+ shuffle: Whether to shuffle output samples,
+ or instead draw them in chronological order.
+      reverse: Boolean; if `True`, timesteps in each output sample will be
+ in reverse chronological order.
+ batch_size: Number of timeseries samples in each batch
+ (except maybe the last one).
+ # Returns
+ A [Sequence](/utils/#sequence) instance.
+ # Examples
+ ```python
+ from keras.preprocessing.sequence import TimeseriesGenerator
+ import numpy as np
+ data = np.array([[i] for i in range(50)])
+ targets = np.array([[i] for i in range(50)])
+ data_gen = TimeseriesGenerator(data, targets,
+ length=10, sampling_rate=2,
+ batch_size=2)
+ assert len(data_gen) == 20
+ batch_0 = data_gen[0]
+ x, y = batch_0
+ assert np.array_equal(x,
+ np.array([[[0], [2], [4], [6], [8]],
+ [[1], [3], [5], [7], [9]]]))
+ assert np.array_equal(y,
+ np.array([[10], [11]]))
+ ```
+ """
+ pass
+
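
The window arithmetic described in the docstring can be spot-checked without the class itself. A small numpy sketch, assuming the keras_preprocessing semantics above (first target sits `length` steps in), that reproduces the first sample of the docstring example:

```python
import numpy as np

data = np.array([[i] for i in range(50)])
length, sampling_rate, stride = 10, 2, 1

def sample(j, start_index=length):
  # Sample j covers the `length` timesteps ending just before
  # start_index + j * stride, subsampled every `sampling_rate` steps;
  # the target is the row at the end of that window.
  end = start_index + j * stride
  return data[np.arange(end - length, end, sampling_rate)], data[end]

x0, y0 = sample(0)
assert np.array_equal(x0, [[0], [2], [4], [6], [8]])
assert y0[0] == 10
```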
tf_export('keras.preprocessing.sequence.pad_sequences')(pad_sequences)
tf_export(
'keras.preprocessing.sequence.make_sampling_table')(make_sampling_table)
tf_export('keras.preprocessing.sequence.skipgrams')(skipgrams)
-tf_export(
- 'keras.preprocessing.sequence.TimeseriesGenerator')(TimeseriesGenerator)
diff --git a/tensorflow/python/keras/regularizers_test.py b/tensorflow/python/keras/regularizers_test.py
index e2075785d8..bba4ebb287 100644
--- a/tensorflow/python/keras/regularizers_test.py
+++ b/tensorflow/python/keras/regularizers_test.py
@@ -50,7 +50,7 @@ def create_model(kernel_regularizer=None, activity_regularizer=None):
class KerasRegularizersTest(test.TestCase):
def test_kernel_regularization(self):
- with self.test_session():
+ with self.cached_session():
(x_train, y_train), _ = get_data()
for reg in [keras.regularizers.l1(),
keras.regularizers.l2(),
@@ -62,7 +62,7 @@ class KerasRegularizersTest(test.TestCase):
epochs=1, verbose=0)
def test_activity_regularization(self):
- with self.test_session():
+ with self.cached_session():
(x_train, y_train), _ = get_data()
for reg in [keras.regularizers.l1(), keras.regularizers.l2()]:
model = create_model(activity_regularizer=reg)
diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py
index 6e8ee06ff5..58405c550b 100644
--- a/tensorflow/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/testing_utils.py
@@ -184,3 +184,22 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None,
# for further checks in the caller function
return actual_output
+
+def get_small_sequential_mlp(num_hidden, num_classes, input_dim=None):
+ model = keras.models.Sequential()
+ if input_dim:
+ model.add(keras.layers.Dense(num_hidden, activation='relu',
+ input_dim=input_dim))
+ else:
+ model.add(keras.layers.Dense(num_hidden, activation='relu'))
+ activation = 'sigmoid' if num_classes == 1 else 'softmax'
+ model.add(keras.layers.Dense(num_classes, activation=activation))
+ return model
+
+
+def get_small_functional_mlp(num_hidden, num_classes, input_dim):
+ inputs = keras.Input(shape=(input_dim,))
+ outputs = keras.layers.Dense(num_hidden, activation='relu')(inputs)
+ activation = 'sigmoid' if num_classes == 1 else 'softmax'
+ outputs = keras.layers.Dense(num_classes, activation=activation)(outputs)
+ return keras.Model(inputs, outputs)
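
Both helpers are meant as tiny test fixtures. A hedged usage sketch with arbitrary sizes and random data, relying on this module's `keras` import (the sizes and optimizer are illustrative choices, not part of the source):

```python
import numpy as np

model = get_small_sequential_mlp(num_hidden=8, num_classes=3, input_dim=4)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
x = np.random.random((16, 4)).astype('float32')
y = keras.utils.to_categorical(np.random.randint(0, 3, size=(16,)), 3)
model.fit(x, y, epochs=1, verbose=0)  # one throwaway epoch, as in the tests
```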
diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py
index c1ee34ae46..d93a7b6afc 100644
--- a/tensorflow/python/keras/utils/data_utils.py
+++ b/tensorflow/python/keras/utils/data_utils.py
@@ -494,6 +494,7 @@ class SequenceEnqueuer(object):
raise NotImplementedError
+@tf_export('keras.utils.OrderedEnqueuer')
class OrderedEnqueuer(SequenceEnqueuer):
"""Builds a Enqueuer from a Sequence.
diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py
index 1f28c59ea4..158a9a5e76 100644
--- a/tensorflow/python/keras/utils/layer_utils.py
+++ b/tensorflow/python/keras/utils/layer_utils.py
@@ -26,6 +26,7 @@ from tensorflow.python.keras.utils.conv_utils import convert_kernel
from tensorflow.python.util.tf_export import tf_export
+@tf_export('keras.utils.get_source_inputs')
def get_source_inputs(tensor, layer=None, node_index=None):
"""Returns the list of input tensors necessary to compute `tensor`.
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils_test.py b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
index 77792d14f5..c7e94998b4 100644
--- a/tensorflow/python/keras/utils/multi_gpu_utils_test.py
+++ b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
@@ -180,6 +180,23 @@ class TestMultiGPUModel(test.TestCase):
target_tensors=[targets])
parallel_model.fit(epochs=1, steps_per_epoch=3)
+ def test_multi_gpu_with_multi_input_layers(self):
+ gpus = 2
+
+ if not check_if_compatible_devices(gpus=gpus):
+ return
+
+ with self.test_session():
+ inputs = keras.Input((4, 3))
+ init_state = keras.Input((3,))
+ outputs = keras.layers.SimpleRNN(
+ 3, return_sequences=True)(inputs, initial_state=init_state)
+ x = [np.random.randn(2, 4, 3), np.random.randn(2, 3)]
+ y = np.random.randn(2, 4, 3)
+ model = keras.Model([inputs, init_state], outputs)
+ parallel_model = keras.utils.multi_gpu_model(model, gpus=gpus)
+ parallel_model.compile(loss='mean_squared_error', optimizer='adam')
+ parallel_model.train_on_batch(x, y)
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/utils/tf_utils.py b/tensorflow/python/keras/utils/tf_utils.py
index 162e5b2cd6..cfdb3de2aa 100644
--- a/tensorflow/python/keras/utils/tf_utils.py
+++ b/tensorflow/python/keras/utils/tf_utils.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond as smart_module
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.util import nest
@@ -109,10 +110,10 @@ def get_reachable_from_inputs(inputs, targets=None):
if isinstance(x, ops.Operation):
outputs = x.outputs[:] or []
outputs += x._control_outputs # pylint: disable=protected-access
- elif isinstance(x, ops.Tensor):
- outputs = x.consumers()
elif isinstance(x, variables.Variable):
outputs = [x.op]
+ elif tensor_util.is_tensor(x):
+ outputs = x.consumers()
else:
raise TypeError('Expected Operation, Variable, or Tensor, got ' + str(x))
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index c84ed9d485..0403211d92 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -84,6 +84,25 @@ tf_py_test(
)
tf_py_test(
+ name = "batch_scatter_ops_test",
+ srcs = ["batch_scatter_ops_test.py"],
+ additional_deps = [
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:variables",
+ ],
+)
+
+tf_py_test(
name = "bcast_ops_test",
size = "small",
srcs = ["bcast_ops_test.py"],
@@ -582,7 +601,7 @@ tf_py_test(
tf_py_test(
name = "matrix_logarithm_op_test",
- size = "small",
+ size = "medium",
srcs = ["matrix_logarithm_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -603,6 +622,7 @@ cuda_py_test(
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
],
+ tags = ["notap"],
)
cuda_py_test(
@@ -645,7 +665,7 @@ cuda_py_test(
cuda_py_test(
name = "parameterized_truncated_normal_op_test",
- size = "small",
+ size = "medium",
srcs = ["parameterized_truncated_normal_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -747,6 +767,7 @@ tf_py_test(
size = "small",
srcs = ["regex_replace_op_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
@@ -759,6 +780,7 @@ tf_py_test(
size = "small",
srcs = ["regex_full_match_op_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
@@ -1101,6 +1123,7 @@ tf_py_test(
"//tensorflow/python:variable_scope",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:state_ops",
+ "//tensorflow/python:util",
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:function",
@@ -1368,6 +1391,8 @@ cuda_py_test(
"//tensorflow/python/eager:context",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
@@ -1420,6 +1445,7 @@ cuda_py_test(
"//tensorflow/python:array_ops_gen",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:cond_v2",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:data_flow_ops_gen",
@@ -1610,6 +1636,7 @@ cuda_py_test(
srcs = ["functional_ops_test.py"],
additional_deps = [
"//third_party/py/numpy",
+ "//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 81442d12e9..a164682227 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -559,6 +559,22 @@ class StridedSliceTest(test_util.TensorFlowTestCase):
s = array_ops.strided_slice(x, begin, end, strides)
self.assertAllEqual([3.], self.evaluate(s))
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ @test_util.assert_no_garbage_created
+ def testTensorSliceEagerMemory(self):
+ with context.eager_mode():
+ inputs = constant_op.constant(
+ [[[1], [2], [3], [4]]], dtype=dtypes.float32)
+ # Tests that slicing an EagerTensor doesn't leak memory
+ inputs[0] # pylint: disable=pointless-statement
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ @test_util.assert_no_garbage_created
+ def testVariableSliceEagerMemory(self):
+ with context.eager_mode():
+ v = variables.Variable([1., 2.])
+ v[0] # pylint: disable=pointless-statement
+
def testDegenerateSlices(self):
with self.test_session(use_gpu=True):
checker = StridedSliceChecker(self, StridedSliceChecker.REF_TENSOR)
@@ -1145,7 +1161,7 @@ class IdentityTest(test_util.TensorFlowTestCase):
def testEagerIdentity(self):
with context.eager_mode():
- ctx = context.get_default_context()
+ ctx = context.context()
if not ctx.num_gpus():
self.skipTest("No GPUs found")
diff --git a/tensorflow/python/kernel_tests/batch_scatter_ops_test.py b/tensorflow/python/kernel_tests/batch_scatter_ops_test.py
new file mode 100644
index 0000000000..0d41a7e3b3
--- /dev/null
+++ b/tensorflow/python/kernel_tests/batch_scatter_ops_test.py
@@ -0,0 +1,129 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.tf.scatter."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def _AsType(v, vtype):
+ return v.astype(vtype) if isinstance(v, np.ndarray) else vtype(v)
+
+
+def _NumpyUpdate(ref, indices, updates):
+ for i, indx in np.ndenumerate(indices):
+ indx = i[:-1] + (indx,)
+ ref[indx] = updates[i]
+
+
+_TF_OPS_TO_NUMPY = {
+ state_ops.batch_scatter_update: _NumpyUpdate,
+}
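
`_NumpyUpdate` treats the trailing axis of `indices` as positions within each batch row and every leading axis as a batch dimension. A tiny self-contained illustration of that reference semantics (plain numpy, no TensorFlow involved):

```python
import numpy as np

ref = np.zeros((2, 3))
indices = np.array([[0, 2], [1, 1]])          # positions within each row
updates = np.array([[10., 20.], [30., 40.]])
for i, idx in np.ndenumerate(indices):
  ref[i[:-1] + (idx,)] = updates[i]           # same rule as _NumpyUpdate
# Row 0 writes columns 0 and 2; row 1 writes column 1 twice (last wins).
assert np.array_equal(ref, [[10., 0., 20.], [0., 40., 0.]])
```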
+
+
+class ScatterTest(test.TestCase):
+
+ def _VariableRankTest(self,
+ tf_scatter,
+ vtype,
+ itype,
+ repeat_indices=False,
+ updates_are_scalar=False):
+ np.random.seed(8)
+ with self.test_session(use_gpu=False):
+ for indices_shape in (2,), (3, 7), (3, 4, 7):
+ for extra_shape in (), (5,), (5, 9):
+          # Generate random indices (duplicates possible) for comparison
+          # against the sequential numpy reference.
+ sparse_dim = len(indices_shape) - 1
+ indices = np.random.randint(
+ indices_shape[sparse_dim], size=indices_shape, dtype=itype)
+ updates = _AsType(
+ np.random.randn(*(indices_shape + extra_shape)), vtype)
+
+ old = _AsType(np.random.randn(*(indices_shape + extra_shape)), vtype)
+
+ # Scatter via numpy
+ new = old.copy()
+ np_scatter = _TF_OPS_TO_NUMPY[tf_scatter]
+ np_scatter(new, indices, updates)
+ # Scatter via tensorflow
+ ref = variables.Variable(old)
+ ref.initializer.run()
+ tf_scatter(ref, indices, updates).eval()
+ self.assertAllClose(ref.eval(), new)
+
+ def _VariableRankTests(self,
+ tf_scatter):
+ vtypes = [np.float32, np.float64]
+ if tf_scatter != state_ops.scatter_div:
+ vtypes.append(np.int32)
+
+ for vtype in vtypes:
+ for itype in (np.int32, np.int64):
+ self._VariableRankTest(tf_scatter, vtype, itype)
+
+ def testVariableRankUpdate(self):
+ vtypes = [np.float32, np.float64]
+ for vtype in vtypes:
+ for itype in (np.int32, np.int64):
+ self._VariableRankTest(
+ state_ops.batch_scatter_update, vtype, itype)
+
+ def testBooleanScatterUpdate(self):
+ with self.test_session(use_gpu=False) as session:
+ var = variables.Variable([True, False])
+ update0 = state_ops.batch_scatter_update(var, [1], [True])
+ update1 = state_ops.batch_scatter_update(
+ var, constant_op.constant(
+ [0], dtype=dtypes.int64), [False])
+ var.initializer.run()
+
+ session.run([update0, update1])
+
+ self.assertAllEqual([False, True], var.eval())
+
+ def testScatterOutOfRange(self):
+ params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
+ updates = np.array([-3, -4, -5]).astype(np.float32)
+ with self.test_session(use_gpu=False):
+ ref = variables.Variable(params)
+ ref.initializer.run()
+
+ # Indices all in range, no problem.
+ indices = np.array([2, 0, 5])
+ state_ops.batch_scatter_update(ref, indices, updates).eval()
+
+ # Test some out of range errors.
+ indices = np.array([-1, 0, 5])
+ with self.assertRaisesOpError(
+ r'indices\[0\] = \[-1\] does not index into shape \[6\]'):
+ state_ops.batch_scatter_update(ref, indices, updates).eval()
+
+ indices = np.array([2, 0, 6])
+ with self.assertRaisesOpError(r'indices\[2\] = \[6\] does not index into '
+ r'shape \[6\]'):
+ state_ops.batch_scatter_update(ref, indices, updates).eval()
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
index 6a1bd958ba..bd2339f31d 100644
--- a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
+++ b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
@@ -21,8 +21,10 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradient_checker
from tensorflow.python.platform import test as test_lib
@@ -81,5 +83,47 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
# check shape inference when shape input is constant
self.assertAllEqual(shape, v_np.shape)
+ def testGradientForScalar(self):
+ # TODO(alextp): There is a bug with broadcast_to on GPU from scalars,
+ # hence we make this test cpu-only.
+ with ops.device("cpu:0"):
+ x = constant_op.constant(1, dtype=dtypes.float32)
+ v = array_ops.broadcast_to(x, [2, 4, 3])
+ out = 2 * v
+ with self.test_session():
+ err = gradient_checker.compute_gradient_error(x, x.get_shape(),
+ out, out.get_shape())
+ self.assertLess(err, 1e-4)
+
+ def testGradientWithSameRank(self):
+ x = constant_op.constant(np.reshape(np.arange(6), (2, 1, 3)),
+ dtype=dtypes.float32)
+ v = array_ops.broadcast_to(x, [2, 5, 3])
+ out = 2 * v
+ with self.test_session():
+ err = gradient_checker.compute_gradient_error(x, x.get_shape(),
+ out, out.get_shape())
+ self.assertLess(err, 1e-4)
+
+ def testGradientWithIncreasingRank(self):
+ x = constant_op.constant([[1], [2]],
+ dtype=dtypes.float32)
+ v = array_ops.broadcast_to(x, [5, 2, 3])
+ out = 2 * v
+ with self.test_session():
+ err = gradient_checker.compute_gradient_error(x, x.get_shape(),
+ out, out.get_shape())
+ self.assertLess(err, 1e-4)
+
+ def testGradientWithBroadcastAllDimensions(self):
+ x = constant_op.constant([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32)
+ v = array_ops.broadcast_to(x, [5, 4, 6])
+ out = 2 * v
+ with self.test_session():
+ err = gradient_checker.compute_gradient_error(x, x.get_shape(),
+ out, out.get_shape())
+ self.assertLess(err, 1e-4)
+
+
if __name__ == "__main__":
test_lib.main()
diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py
index bda6ca5ca9..27a674e223 100644
--- a/tensorflow/python/kernel_tests/check_ops_test.py
+++ b/tensorflow/python/kernel_tests/check_ops_test.py
@@ -18,8 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import time
import numpy as np
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -29,6 +33,9 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
@@ -110,7 +117,7 @@ class AssertEqualTest(test.TestCase):
check_ops.assert_equal(static_big, static_small, message="fail")
def test_raises_when_greater_dynamic(self):
- with self.test_session():
+ with self.cached_session():
small = array_ops.placeholder(dtypes.int32, name="small")
big = array_ops.placeholder(dtypes.int32, name="big")
with ops.control_dependencies(
@@ -188,7 +195,7 @@ First 2 elements of y:
check_ops.assert_equal(static_big, static_small, message="fail")
def test_raises_when_less_dynamic(self):
- with self.test_session():
+ with self.cached_session():
small = array_ops.placeholder(dtypes.int32, name="small")
big = array_ops.placeholder(dtypes.int32, name="big")
with ops.control_dependencies([check_ops.assert_equal(small, big)]):
@@ -265,30 +272,28 @@ class AssertNoneEqualTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_raises_when_not_equal_but_non_broadcastable_shapes(self):
- with self.test_session():
- small = constant_op.constant([1, 1, 1], name="small")
- big = constant_op.constant([10, 10], name="big")
- # The exception in eager and non-eager mode is different because
- # eager mode relies on shape check done as part of the C++ op, while
- # graph mode does shape checks when creating the `Operation` instance.
- with self.assertRaisesRegexp(
- (ValueError, errors.InvalidArgumentError),
- (r"Incompatible shapes: \[3\] vs. \[2\]|"
- r"Dimensions must be equal, but are 3 and 2")):
- with ops.control_dependencies(
- [check_ops.assert_none_equal(small, big)]):
- out = array_ops.identity(small)
- self.evaluate(out)
+ small = constant_op.constant([1, 1, 1], name="small")
+ big = constant_op.constant([10, 10], name="big")
+ # The exception in eager and non-eager mode is different because
+ # eager mode relies on shape check done as part of the C++ op, while
+ # graph mode does shape checks when creating the `Operation` instance.
+ with self.assertRaisesRegexp(
+ (ValueError, errors.InvalidArgumentError),
+ (r"Incompatible shapes: \[3\] vs. \[2\]|"
+ r"Dimensions must be equal, but are 3 and 2")):
+ with ops.control_dependencies(
+ [check_ops.assert_none_equal(small, big)]):
+ out = array_ops.identity(small)
+ self.evaluate(out)
@test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_both_empty(self):
- with self.test_session():
- larry = constant_op.constant([])
- curly = constant_op.constant([])
- with ops.control_dependencies(
- [check_ops.assert_none_equal(larry, curly)]):
- out = array_ops.identity(larry)
- self.evaluate(out)
+ larry = constant_op.constant([])
+ curly = constant_op.constant([])
+ with ops.control_dependencies(
+ [check_ops.assert_none_equal(larry, curly)]):
+ out = array_ops.identity(larry)
+ self.evaluate(out)
def test_returns_none_with_eager(self):
with context.eager_mode():
@@ -745,6 +750,158 @@ class AssertPositiveTest(test.TestCase):
self.evaluate(out)
+class EnsureShapeTest(test.TestCase):
+
+ # Static shape inference
+ def testStaticShape(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ ensure_shape_op = check_ops.ensure_shape(placeholder, (3, 3, 3))
+ self.assertEqual(ensure_shape_op.get_shape(), (3, 3, 3))
+
+ def testStaticShape_MergesShapes(self):
+ placeholder = array_ops.placeholder(dtypes.int32, shape=(None, None, 3))
+ ensure_shape_op = check_ops.ensure_shape(placeholder, (5, 4, None))
+ self.assertEqual(ensure_shape_op.get_shape(), (5, 4, 3))
+
+ def testStaticShape_RaisesErrorWhenRankIncompatible(self):
+ placeholder = array_ops.placeholder(dtypes.int32, shape=(None, None, 3))
+ with self.assertRaises(ValueError):
+ check_ops.ensure_shape(placeholder, (2, 3))
+
+ def testStaticShape_RaisesErrorWhenDimIncompatible(self):
+ placeholder = array_ops.placeholder(dtypes.int32, shape=(None, None, 3))
+ with self.assertRaises(ValueError):
+ check_ops.ensure_shape(placeholder, (2, 2, 4))
+
+ def testStaticShape_CanSetUnknownShape(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ derived = placeholder / 3
+ ensure_shape_op = check_ops.ensure_shape(derived, None)
+ self.assertEqual(ensure_shape_op.get_shape(), None)
+
+ # Dynamic shape check
+ def testEnsuresDynamicShape_RaisesError(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ derived = math_ops.divide(placeholder, 3, name="MyDivide")
+ derived = check_ops.ensure_shape(derived, (3, 3, 3))
+ feed_val = [[1], [2]]
+ with self.test_session() as sess:
+ with self.assertRaisesWithPredicateMatch(
+ errors.InvalidArgumentError,
+ r"Shape of tensor MyDivide \[2,1\] is not compatible with "
+ r"expected shape \[3,3,3\]."):
+ sess.run(derived, feed_dict={placeholder: feed_val})
+
+ def testEnsuresDynamicShape_RaisesErrorDimUnknown(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ derived = placeholder / 3
+ derived = check_ops.ensure_shape(derived, (None, None, 3))
+ feed_val = [[1], [2]]
+ with self.test_session() as sess:
+ with self.assertRaisesWithPredicateMatch(
+ errors.InvalidArgumentError,
+ r"Shape of tensor [A-Za-z_]* \[2,1\] is not compatible with "
+ r"expected shape \[\?,\?,3\]."):
+ sess.run(derived, feed_dict={placeholder: feed_val})
+
+ def testEnsuresDynamicShape(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ derived = placeholder / 3
+ derived = check_ops.ensure_shape(derived, (2, 1))
+ feed_val = [[1], [2]]
+ with self.test_session() as sess:
+ sess.run(derived, feed_dict={placeholder: feed_val})
+
+ def testEnsuresDynamicShape_WithUnknownDims(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ derived = placeholder / 3
+ derived = check_ops.ensure_shape(derived, (None, None))
+ feed_val = [[1], [2]]
+ with self.test_session() as sess:
+ sess.run(derived, feed_dict={placeholder: feed_val})
+
+ def testGradient(self):
+ placeholder = array_ops.placeholder(dtypes.float32)
+ derived = check_ops.ensure_shape(placeholder, (None, None))
+ gradient = gradients.gradients(derived, placeholder)
+
+ feed_val = [[4.0], [-1.0]]
+ with self.test_session() as sess:
+ gradient_values, = sess.run(gradient, feed_dict={placeholder: feed_val})
+
+ expected = [[1.0], [1.0]]
+ self.assertAllEqual(gradient_values, expected)
+
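
Together these tests pin down `ensure_shape`'s contract: shapes merge statically where known, a runtime check fires otherwise, and the op is an identity for gradients. A minimal numpy model of that contract (illustrative only, not the op's implementation):

```python
import numpy as np

def ensure_shape_model(x, expected):
  # `None` means "fully unknown"; a per-dimension None means "unchecked".
  if expected is not None:
    if x.ndim != len(expected) or any(
        e is not None and e != d for e, d in zip(expected, x.shape)):
      raise ValueError('Shape %s is not compatible with expected shape %s'
                       % (x.shape, expected))
  return x  # identity, so the gradient passes through unchanged

ensure_shape_model(np.zeros((2, 1)), (2, 1))        # compatible
ensure_shape_model(np.zeros((2, 1)), (None, None))  # unknown dims accepted
try:
  ensure_shape_model(np.zeros((2, 1)), (3, 3, 3))   # rank mismatch
except ValueError:
  pass  # mirrors the InvalidArgumentError raised in the dynamic tests
```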
+
+class EnsureShapeBenchmark(test.Benchmark):
+
+ def _grappler_all_off_config(self):
+ config = config_pb2.ConfigProto()
+ off = rewriter_config_pb2.RewriterConfig.OFF
+ config.graph_options.optimizer_options.opt_level = -1
+ config.graph_options.rewrite_options.disable_model_pruning = 1
+ config.graph_options.rewrite_options.constant_folding = off
+ config.graph_options.rewrite_options.layout_optimizer = off
+ config.graph_options.rewrite_options.arithmetic_optimization = off
+ config.graph_options.rewrite_options.dependency_optimization = off
+ return config
+
+ def _run(self, op, feed_dict=None, num_iters=5000, name=None, **kwargs):
+ config = self._grappler_all_off_config()
+ with session.Session(config=config) as sess:
+ deltas = []
+ # Warm up the session
+ for _ in range(5):
+ sess.run(op, feed_dict=feed_dict)
+ for _ in range(num_iters):
+ start = time.time()
+ sess.run(op, feed_dict=feed_dict)
+ end = time.time()
+ deltas.append(end - start)
+      # Report the median iteration time (in microseconds) to damp outliers.
+      median_us = np.median(deltas) * 1e6
+      self.report_benchmark(
+          name=name,
+          wall_time=median_us,
+ extras=kwargs,
+ )
+
+ def benchmark_const_op(self):
+ # In this case, we expect that the overhead of a `session.run` call
+ # far outweighs the time taken to execute the op...
+ shape = (3, 3, 100)
+ input_op = random_ops.random_normal(shape)
+ self._run(array_ops.identity(input_op), name="SingleConstOp")
+
+ def benchmark_single_ensure_op(self):
+ # In this case, we expect that the overhead of a `session.run` call
+ # far outweighs the time taken to execute the op...
+ shape = (3, 3, 100)
+ input_op = random_ops.random_normal(shape)
+ ensure_shape_op = check_ops.ensure_shape(input_op, shape)
+ self._run(ensure_shape_op, name="SingleEnsureShapeOp")
+
+ def _apply_n_times(self, op, target, n=1000):
+ for _ in range(n):
+ target = op(target)
+ return target
+
+ def benchmark_n_ops(self):
+ shape = (1000,)
+ input_op = random_ops.random_normal(shape)
+ n_ops = self._apply_n_times(array_ops.identity, input_op)
+ self._run(n_ops, name="NIdentityOps_1000")
+
+ def benchmark_n_ensure_ops(self):
+ shape = (1000,)
+ input_op = random_ops.random_normal(shape)
+ n_ensure_ops = self._apply_n_times(
+ lambda x: check_ops.ensure_shape(array_ops.identity(x), shape),
+ input_op)
+ self._run(n_ensure_ops, name="NEnsureShapeAndIdentityOps_1000")
+
+
class AssertRankTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
@@ -759,7 +916,7 @@ class AssertRankTest(test.TestCase):
self.evaluate(array_ops.identity(tensor))
def test_rank_zero_tensor_raises_if_rank_too_small_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
desired_rank = 1
with ops.control_dependencies(
@@ -777,7 +934,7 @@ class AssertRankTest(test.TestCase):
self.evaluate(array_ops.identity(tensor))
def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
desired_rank = 0
with ops.control_dependencies(
@@ -794,7 +951,7 @@ class AssertRankTest(test.TestCase):
self.evaluate(array_ops.identity(tensor))
def test_rank_one_tensor_raises_if_rank_too_large_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
desired_rank = 0
with ops.control_dependencies(
@@ -811,7 +968,7 @@ class AssertRankTest(test.TestCase):
self.evaluate(array_ops.identity(tensor))
def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
desired_rank = 1
with ops.control_dependencies(
@@ -828,7 +985,7 @@ class AssertRankTest(test.TestCase):
self.evaluate(array_ops.identity(tensor))
def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
desired_rank = 2
with ops.control_dependencies(
@@ -843,7 +1000,7 @@ class AssertRankTest(test.TestCase):
check_ops.assert_rank(tensor, np.array([], dtype=np.int32))
def test_raises_if_rank_is_not_scalar_dynamic(self):
- with self.test_session():
+ with self.cached_session():
tensor = constant_op.constant(
[1, 2], dtype=dtypes.float32, name="my_tensor")
rank_tensor = array_ops.placeholder(dtypes.int32, name="rank_tensor")
@@ -860,7 +1017,7 @@ class AssertRankTest(test.TestCase):
check_ops.assert_rank(tensor, .5)
def test_raises_if_rank_is_not_integer_dynamic(self):
- with self.test_session():
+ with self.cached_session():
tensor = constant_op.constant(
[1, 2], dtype=dtypes.float32, name="my_tensor")
rank_tensor = array_ops.placeholder(dtypes.float32, name="rank_tensor")
@@ -883,7 +1040,7 @@ class AssertRankInTest(test.TestCase):
self.evaluate(array_ops.identity(tensor_rank0))
def test_rank_zero_tensor_raises_if_rank_mismatch_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor_rank0 = array_ops.placeholder(dtypes.float32, name="my_tensor")
with ops.control_dependencies([
check_ops.assert_rank_in(tensor_rank0, (1, 2), message="fail")]):
@@ -899,7 +1056,7 @@ class AssertRankInTest(test.TestCase):
self.evaluate(array_ops.identity(tensor_rank0))
def test_rank_zero_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor_rank0 = array_ops.placeholder(dtypes.float32, name="my_tensor")
for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)):
with ops.control_dependencies([
@@ -915,7 +1072,7 @@ class AssertRankInTest(test.TestCase):
self.evaluate(array_ops.identity(tensor_rank1))
def test_rank_one_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor_rank1 = array_ops.placeholder(dtypes.float32, name="my_tensor")
for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)):
with ops.control_dependencies([
@@ -933,7 +1090,7 @@ class AssertRankInTest(test.TestCase):
self.evaluate(array_ops.identity(tensor_rank1))
def test_rank_one_tensor_raises_if_rank_mismatches_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor_rank1 = array_ops.placeholder(dtypes.float32, name="my_tensor")
with ops.control_dependencies([
check_ops.assert_rank_in(tensor_rank1, (0, 2))]):
@@ -952,7 +1109,7 @@ class AssertRankInTest(test.TestCase):
check_ops.assert_rank_in(tensor, desired_ranks)
def test_raises_if_rank_is_not_scalar_dynamic(self):
- with self.test_session():
+ with self.cached_session():
tensor = constant_op.constant(
(42, 43), dtype=dtypes.float32, name="my_tensor")
desired_ranks = (
@@ -974,7 +1131,7 @@ class AssertRankInTest(test.TestCase):
check_ops.assert_rank_in(tensor, (1, .5,))
def test_raises_if_rank_is_not_integer_dynamic(self):
- with self.test_session():
+ with self.cached_session():
tensor = constant_op.constant(
(42, 43), dtype=dtypes.float32, name="my_tensor")
rank_tensor = array_ops.placeholder(dtypes.float32, name="rank_tensor")
@@ -997,7 +1154,7 @@ class AssertRankAtLeastTest(test.TestCase):
self.evaluate(array_ops.identity(tensor))
def test_rank_zero_tensor_raises_if_rank_too_small_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
desired_rank = 1
with ops.control_dependencies(
@@ -1014,7 +1171,7 @@ class AssertRankAtLeastTest(test.TestCase):
self.evaluate(array_ops.identity(tensor))
def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
desired_rank = 0
with ops.control_dependencies(
@@ -1030,7 +1187,7 @@ class AssertRankAtLeastTest(test.TestCase):
self.evaluate(array_ops.identity(tensor))
def test_rank_one_ten_doesnt_raise_if_rank_too_large_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
desired_rank = 0
with ops.control_dependencies(
@@ -1046,7 +1203,7 @@ class AssertRankAtLeastTest(test.TestCase):
self.evaluate(array_ops.identity(tensor))
def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
desired_rank = 1
with ops.control_dependencies(
@@ -1063,7 +1220,7 @@ class AssertRankAtLeastTest(test.TestCase):
self.evaluate(array_ops.identity(tensor))
def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self):
- with self.test_session():
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
desired_rank = 2
with ops.control_dependencies(
diff --git a/tensorflow/python/kernel_tests/clip_ops_test.py b/tensorflow/python/kernel_tests/clip_ops_test.py
index 400d38b936..de52a70cc0 100644
--- a/tensorflow/python/kernel_tests/clip_ops_test.py
+++ b/tensorflow/python/kernel_tests/clip_ops_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradients_impl
from tensorflow.python.platform import test
@@ -158,13 +159,19 @@ class ClipTest(test.TestCase):
ans = clip_ops.clip_by_norm(x, clip_norm)
tf_ans = ans.eval()
- clip_tensor = constant_op.constant(4.0)
ans = clip_ops.clip_by_norm(x, clip_norm)
tf_ans_tensor = ans.eval()
self.assertAllClose(np_ans, tf_ans)
self.assertAllClose(np_ans, tf_ans_tensor)
+ def testClipByNormGradientZeros(self):
+ with self.test_session(use_gpu=True):
+ x = array_ops.zeros([3])
+ b = clip_ops.clip_by_norm(x, 1.)
+ grad, = gradients_impl.gradients(b, x)
+ self.assertAllEqual(grad.eval(), [1., 1., 1.])
+
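
The expected all-ones gradient follows from `clip_by_norm(x, c)` acting as the identity whenever `||x|| <= c`, including at `x = 0`, where a naive `x * c / ||x||` form would divide by zero. A quick finite-difference check of the same claim in plain numpy (an illustration, not the registered gradient):

```python
import numpy as np

def clip_by_norm_ref(x, c):
  n = np.linalg.norm(x)
  return x if n <= c else x * (c / n)

eps, x = 1e-6, np.zeros(3)
# Perturb each coordinate of the zero vector; the output moves one-for-one.
grad = [(clip_by_norm_ref(x + eps * e, 1.)[i] - clip_by_norm_ref(x, 1.)[i])
        / eps for i, e in enumerate(np.eye(3))]
assert np.allclose(grad, [1., 1., 1.])
```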
def testClipByNormBadShape(self):
with self.test_session(use_gpu=True):
x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3, 1])
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index b9910133d8..0dc3c53bc0 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -20,9 +20,9 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2
@@ -158,7 +158,7 @@ class CondV2Test(test.TestCase):
def true_fn():
- @function.Defun()
+ @function.defun
def fn():
return x * y * 2.0
@@ -172,6 +172,8 @@ class CondV2Test(test.TestCase):
self._testCond(true_fn, false_fn, [y])
def testNestedDefunInCond(self):
+ self.skipTest("b/110550782")
+
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -180,10 +182,10 @@ class CondV2Test(test.TestCase):
def false_fn():
- @function.Defun()
+ @function.defun
def fn():
- @function.Defun()
+ @function.defun
def nested_fn():
return x * y * 2.0
@@ -196,18 +198,20 @@ class CondV2Test(test.TestCase):
self._testCond(true_fn, false_fn, [y])
def testDoubleNestedDefunInCond(self):
+ self.skipTest("b/110550782")
+
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
def true_fn():
- @function.Defun()
+ @function.defun
def fn():
- @function.Defun()
+ @function.defun
def nested_fn():
- @function.Defun()
+ @function.defun
def nested_nested_fn():
return x * y * 2.0
@@ -368,7 +372,7 @@ class CondV2Test(test.TestCase):
pred_outer, true_fn, false_fn, name="outer_cond")
# Compute grads inside a Defun.
- @function.Defun()
+ @function.defun
def nesting_fn():
return gradients_impl.gradients(cond_outer, [x, y])
@@ -426,10 +430,10 @@ class CondV2Test(test.TestCase):
pred_outer, true_fn, false_fn, name="outer_cond")
# Compute grads inside a Defun.
- @function.Defun()
+ @function.defun
def nesting_fn():
- @function.Defun()
+ @function.defun
def inner_nesting_fn():
return gradients_impl.gradients(cond_outer, [x, y])
@@ -464,6 +468,7 @@ class CondV2Test(test.TestCase):
}), [5., 0.])
def testBuildCondAndGradientInsideDefun(self):
+ self.skipTest("b/110550782")
def build_graph():
pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer")
@@ -472,7 +477,7 @@ class CondV2Test(test.TestCase):
y = constant_op.constant(2.0, name="y")
# Build cond and its gradient inside a Defun.
- @function.Defun()
+ @function.defun
def fn():
def true_fn():
@@ -718,6 +723,7 @@ class CondV2ContainerTest(test.TestCase):
Make sure the containers are set correctly for both variable creation
(tested by variables.Variable) and for stateful ops (tested by FIFOQueue)
"""
+ self.skipTest("b/113048653")
with ops.Graph().as_default() as g:
with self.test_session(graph=g):
@@ -795,6 +801,7 @@ class CondV2ContainerTest(test.TestCase):
class CondV2ColocationGroupAndDeviceTest(test.TestCase):
def testColocateWithBeforeCond(self):
+ self.skipTest("b/112414483")
with ops.Graph().as_default() as g:
with self.test_session(graph=g):
@@ -819,6 +826,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
def testColocateWithInAndOutOfCond(self):
+ self.skipTest("b/112414483")
with ops.Graph().as_default() as g:
with self.test_session(graph=g):
@@ -866,6 +874,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
self.assertTrue(len(run_metadata.partition_graphs) >= 2)
def testDeviceBeforeCond(self):
+ self.skipTest("b/112166045")
with ops.Graph().as_default() as g:
with self.test_session(graph=g):
def fn():
diff --git a/tensorflow/python/kernel_tests/conditional_accumulator_test.py b/tensorflow/python/kernel_tests/conditional_accumulator_test.py
index 7570523495..86802664d1 100644
--- a/tensorflow/python/kernel_tests/conditional_accumulator_test.py
+++ b/tensorflow/python/kernel_tests/conditional_accumulator_test.py
@@ -42,14 +42,22 @@ class ConditionalAccumulatorTest(test.TestCase):
with ops.Graph().as_default():
q = data_flow_ops.ConditionalAccumulator(dtypes_lib.float32, name="Q")
self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
- self.assertProtoEquals("""
+ self.assertProtoEquals(
+ """
name:'Q' op:'ConditionalAccumulator'
attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'shape' value { shape { unknown_rank: true} } }
attr { key: 'container' value { s: '' } }
attr { key: 'shared_name' value { s: '' } }
+ attr { key: 'reduction_type' value {s: 'MEAN'} }
""", q.accumulator_ref.op.node_def)
+ def testConstructorWithInvalidArg(self):
+ with ops.Graph().as_default():
+ with self.assertRaises(ValueError):
+ data_flow_ops.ConditionalAccumulator(
+ dtypes_lib.float32, name="Q", reduction_type="Invalid")
+
def testConstructorWithShape(self):
with ops.Graph().as_default():
q = data_flow_ops.ConditionalAccumulator(
@@ -57,7 +65,8 @@ class ConditionalAccumulatorTest(test.TestCase):
name="Q",
shape=tensor_shape.TensorShape([1, 5, 2, 8]))
self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
- self.assertProtoEquals("""
+ self.assertProtoEquals(
+ """
name:'Q' op:'ConditionalAccumulator'
attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'shape' value { shape { dim {size: 1 }
@@ -67,6 +76,7 @@ class ConditionalAccumulatorTest(test.TestCase):
} } }
attr { key: 'container' value { s: '' } }
attr { key: 'shared_name' value { s: '' } }
+ attr { key: 'reduction_type' value {s: 'MEAN'} }
""", q.accumulator_ref.op.node_def)
def testAccumulatorSizeEmpty(self):
@@ -237,12 +247,11 @@ class ConditionalAccumulatorTest(test.TestCase):
extract_t.op.run()
self.assertEqual(q.num_accumulated().eval(), 0)
- def testAccumulatorTakeGrad(self):
+ def testAccumulatorTakeGradMean(self):
with self.test_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
elems = [10.0, 20.0]
- elems_ave = sum(elems) / len(elems)
accum_ops = [q.apply_grad((x,), local_step=0) for x in elems]
takeg_t = q.take_grad(1)
@@ -251,7 +260,7 @@ class ConditionalAccumulatorTest(test.TestCase):
accum_op.run()
val = takeg_t.eval()
- self.assertEqual(elems_ave, val)
+ self.assertEqual(15.0, val)
accum_ops = [q.apply_grad((x,), local_step=1) for x in elems]
takeg_t = q.take_grad(constant_op.constant(1))
@@ -260,7 +269,42 @@ class ConditionalAccumulatorTest(test.TestCase):
accum_op.run()
val = takeg_t.eval()
- self.assertEqual(elems_ave, val)
+ self.assertEqual(15.0, val)
+
+ def testAccumulatorTakeGradSum(self):
+ with self.test_session():
+ q = data_flow_ops.ConditionalAccumulator(
+ dtypes_lib.float32,
+ name="Q",
+ shape=tensor_shape.TensorShape([1]),
+ reduction_type="SUM")
+ elems = [10.0, 20.0]
+
+ accum_ops = [q.apply_grad((x,), local_step=0) for x in elems]
+ takeg_t = q.take_grad(1)
+
+ for accum_op in accum_ops:
+ accum_op.run()
+
+ val = takeg_t.eval()
+ self.assertEqual(30.0, val)
+
+ accum_ops = [q.apply_grad((x,), local_step=1) for x in elems]
+ takeg_t = q.take_grad(constant_op.constant(1))
+
+ for accum_op in accum_ops:
+ accum_op.run()
+
+ val = takeg_t.eval()
+ self.assertEqual(30.0, val)
+
+ def testAccumulatorTakeGradInvalidReductionType(self):
+ with self.assertRaises(ValueError):
+ data_flow_ops.ConditionalAccumulator(
+ dtypes_lib.float32,
+ name="Q",
+ shape=tensor_shape.TensorShape([1]),
+ reduction_type="Invalid")
def testAccumulatorInvalidTakeGrad(self):
with self.test_session():
@@ -277,7 +321,7 @@ class ConditionalAccumulatorTest(test.TestCase):
with self.assertRaises(errors_impl.InvalidArgumentError):
takeg_t.eval()
- def testAccumulatorRepeatedTakeGrad(self):
+ def testAccumulatorRepeatedTakeGradMean(self):
with self.test_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
@@ -304,6 +348,36 @@ class ConditionalAccumulatorTest(test.TestCase):
val = takeg_t.eval()
self.assertEqual(elems_ave + 0.0, val)
+ def testAccumulatorRepeatedTakeGradSum(self):
+ with self.test_session():
+ q = data_flow_ops.ConditionalAccumulator(
+ dtypes_lib.float32,
+ name="Q",
+ shape=tensor_shape.TensorShape([1]),
+ reduction_type="SUM")
+
+ elems = [10.0, 20.0]
+ elems_sum = 30.0
+ accum_ops = [q.apply_grad((x,), local_step=0) for x in elems]
+ takeg_t = q.take_grad(1)
+
+ for accum_op in accum_ops:
+ accum_op.run()
+
+ val = takeg_t.eval()
+ self.assertEqual(elems_sum, val)
+
+ elems = [20.0, 30.0]
+ elems_sum = 50.0
+ accum_ops = [q.apply_grad((x,), local_step=1) for x in elems]
+ takeg_t = q.take_grad(1)
+
+ for accum_op in accum_ops:
+ accum_op.run()
+
+ val = takeg_t.eval()
+ self.assertEqual(elems_sum, val)
+
def testAccumulatorIncrementGlobalStep(self):
with self.test_session():
q = data_flow_ops.ConditionalAccumulator(
diff --git a/tensorflow/python/kernel_tests/constant_op_eager_test.py b/tensorflow/python/kernel_tests/constant_op_eager_test.py
index a0d5557b92..cc788219ef 100644
--- a/tensorflow/python/kernel_tests/constant_op_eager_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_eager_test.py
@@ -523,7 +523,7 @@ class OnesLikeTest(test.TestCase):
class FillTest(test.TestCase):
def _compare(self, dims, val, np_ans, use_gpu):
- ctx = context.get_default_context()
+ ctx = context.context()
device = "GPU:0" if (use_gpu and ctx.num_gpus()) else "CPU:0"
with ops.device(device):
tf_ans = array_ops.fill(dims, val, name="fill")
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 1a29d0816d..eac97af4ed 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -23,6 +23,7 @@ from __future__ import print_function
import collections
import math
import time
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -38,7 +39,9 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import cond_v2 # pylint: disable=unused-import
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
@@ -122,6 +125,7 @@ def isum(s, maximum_iterations=None):
return r_s
+@test_util.with_cond_v2
class ControlFlowTest(test.TestCase):
def testRefIdentity(self):
@@ -329,6 +333,9 @@ class ControlFlowTest(test.TestCase):
res.eval(feed_dict={data: 1.0})
def testCondBool(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/113296297")
+
values = constant_op.constant(10)
fn1 = lambda: math_ops.add(values, 1)
fn2 = lambda: math_ops.subtract(values, 1)
@@ -377,6 +384,9 @@ class ControlFlowTest(test.TestCase):
sess.run(r, feed_dict={t: 3})
def testCondIndexedSlices(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/113296180")
+
with self.test_session():
values = constant_op.constant(10)
indices = constant_op.constant(0)
@@ -392,6 +402,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(0, ind)
def testCondSparseTensor(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/113296161 (SparseTensors)")
+
with self.test_session():
values = constant_op.constant([2.0, 4.0], name="values")
indices = constant_op.constant(
@@ -409,6 +422,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(r.values.get_shape(), (2,))
def testCondResource(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/111124878 (don't return tuple)")
+
with self.test_session():
rv = resource_variable_ops.ResourceVariable(True)
variables.global_variables_initializer().run()
@@ -422,6 +438,9 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(1.0, control_flow_ops.cond(rv, case, lambda: t).eval())
def testCondIndexedSlicesDifferentTypes(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/113293074")
+
with self.test_session():
values = constant_op.constant(10)
i_32 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int32)
@@ -465,10 +484,16 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(11, result)
def testCond_1(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/111124878 (don't return tuple)")
+
self._testCond_1(use_gpu=False)
self._testCond_1(use_gpu=True)
def testCond_2(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/111124878 (don't return tuple)")
+
with self.test_session():
x = constant_op.constant(10)
r = control_flow_ops.cond(
@@ -478,6 +503,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(9, result)
def testCond_3(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/111124878 (don't return tuple)")
+
with self.test_session():
x = constant_op.constant(10)
pred = math_ops.less(1, 2)
@@ -490,6 +518,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(12, result)
def testCond_4(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/113324949 (ref vars)")
+
with self.test_session():
v1 = variables.Variable(7)
v2 = variables.Variable(7)
@@ -525,6 +556,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(4, count.eval())
def testCond_6(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/111124878 (don't return tuple)")
+
with self.test_session():
v1 = variables.Variable([7])
@@ -549,6 +583,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual([11, 12], sess.run(r))
def testCondRef(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/111124878 (don't return tuple)")
+
with self.test_session():
x = gen_state_ops.variable(
shape=[1],
@@ -562,6 +599,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual([2.0], r.eval())
def testCondWithControl(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/79881896")
+
with self.test_session() as sess:
control_holder = array_ops.placeholder(dtypes.float32, shape=())
a = constant_op.constant(3)
@@ -601,6 +641,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual([1.0], sess.run(merged_op.output))
def testCondSwitchIdentity(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/112477618 (Operation returned from cond)")
+
# Make sure the recv identity is not removed by optimization.
with session.Session(config=opt_cfg()) as sess:
pred = constant_op.constant(True)
@@ -615,6 +658,9 @@ class ControlFlowTest(test.TestCase):
sess.run(r)
def testCondRecvIdentity(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/112477618 (Operation returned from cond)")
+
# Make sure the switch identity is not removed by optimization.
with session.Session(config=opt_cfg()) as sess:
with ops.device(test.gpu_device_name()):
@@ -631,6 +677,9 @@ class ControlFlowTest(test.TestCase):
sess.run(r)
def testCondGrad_1(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/113346829 (gpu failure)")
+
graph = ops.Graph()
with graph.as_default():
x = constant_op.constant(10.0, name="x")
@@ -642,13 +691,6 @@ class ControlFlowTest(test.TestCase):
grad = gradients_impl.gradients(r, [x])[0]
with self.test_session():
self.assertAllEqual(1.0, grad.eval())
- # The gradients computation creates a tensor with zeros by broadcasting a
- # zeros constant to the required shape. Verify that the zero constant
- # feeding into the fill is dominated by a Switch.
- zero = graph.get_operation_by_name("gradients/zeros/Const")
- self.assertEqual(len(zero.control_inputs), 1)
- self.assertEqual(zero.control_inputs[0].type, "Identity")
- self.assertEqual(zero.control_inputs[0].inputs[0].op.type, "Switch")
def testCondGrad_2(self):
with self.test_session():
@@ -664,6 +706,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3}))
def testCondGrad_3(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/110550782 (gradient w.r.t external variable)")
+
with self.test_session():
c = array_ops.placeholder(dtypes.int32, shape=[])
ox = constant_op.constant(10.0)
@@ -696,6 +741,9 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(1.0, result.eval())
def testCondGrad_Gather(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/113327884")
+
with self.test_session() as sess:
v1 = variables.Variable([1.0, 42.0])
c = array_ops.placeholder(dtypes.int32, shape=[])
@@ -868,6 +916,9 @@ class ControlFlowTest(test.TestCase):
_ = gradients_impl.gradients(loop_with_maxiter, v)
def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/113294340 (enable while_v2)")
+
v = constant_op.constant(1.0)
def create_while_loop():
@@ -1324,6 +1375,9 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10, sess.run(r, {b: True}))
def testWhileCondWithControl(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/113294377 (unknown shape)")
+
# Ensure that no control edges by an outer control dependency context are
# added to nodes inside cond/while contexts.
with self.test_session() as sess:
@@ -1338,6 +1392,9 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(0, sess.run(loop))
def testWhileCondWithControl_1(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/113324949 (ref vars)")
+
with self.test_session():
v = variable_scope.get_variable(
"v", [], initializer=init_ops.constant_initializer(2))
@@ -1360,6 +1417,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(65536.0, v.eval())
def testWhileCondExitControl(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/113294340 (enable while_v2)")
+
with self.test_session():
v = variables.Variable(1)
@@ -1383,6 +1443,9 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(99, v.eval())
def testCondWhile_1(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/111124878 (don't return tuple)")
+
with self.test_session():
n = ops.convert_to_tensor(0, name="n")
c = lambda x: math_ops.less(x, 10)
@@ -1393,6 +1456,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testCondWhile_2(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/111124878 (don't return tuple)")
+
with self.test_session():
n = ops.convert_to_tensor(0)
c = lambda x: math_ops.less(x, 10)
@@ -1403,6 +1469,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def _testCondWhile_3(self, use_gpu):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/113294340 (enable while_v2)")
+
with self.test_session(use_gpu=use_gpu) as sess:
p = array_ops.placeholder(dtypes.bool)
n = constant_op.constant(0.0)
@@ -1429,6 +1498,9 @@ class ControlFlowTest(test.TestCase):
self._testCondWhile_3(use_gpu=True)
def testWhileCond_1(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/113294377 (unknown shape)")
+
with self.test_session():
i = ops.convert_to_tensor(0, name="i")
n = ops.convert_to_tensor(10, name="n")
@@ -1444,6 +1516,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testWhileCond_2(self):
+ if control_flow_ops._ENABLE_COND_V2:
+      self.skipTest("b/113294377 (unknown shape)")
+
with self.test_session():
n = ops.convert_to_tensor(0, name="n")
c = lambda x: math_ops.less(x, 10)
@@ -1452,6 +1527,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testWhileCond_3(self):
+ if control_flow_ops._ENABLE_COND_V2:
+ return unittest.skip("b/113294377 (unknown shape)")
+
with self.test_session():
n = ops.convert_to_tensor(0)
c = lambda x: math_ops.less(x, 10)
@@ -1794,6 +1872,9 @@ class ControlFlowTest(test.TestCase):
self._testWhileGrad_Mul(use_gpu=True, p_iters=10)
def _testNestedWhileCondWhileGrad(self, use_gpu):
+ if control_flow_ops._ENABLE_COND_V2:
+ return unittest.skip("b/113294377 (unknown shape)")
+
with self.test_session(use_gpu=use_gpu):
v = constant_op.constant(1.0)
@@ -1832,6 +1913,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(216.0, r[0].eval())
def testWhileGradInCond(self):
+ if control_flow_ops._ENABLE_COND_V2:
+ return unittest.skip("b/110550782 (gradient w.r.t external variable)")
+
with self.test_session():
n = ops.convert_to_tensor(1.0, name="n")
x = array_ops.placeholder(dtypes.float32, shape=None)
@@ -1880,6 +1964,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
def testCondGradInNestedWhiles(self):
+ if control_flow_ops._ENABLE_COND_V2:
+ return unittest.skip("b/113346829 (gpu failure)")
+
def outer_body(i, x):
_, x = control_flow_ops.while_loop(
lambda j, x: j < 3, inner_body, [0, 0.0])
@@ -2193,6 +2280,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(1024.0, r.eval())
def testWhileCondGrad_Simple(self):
+ if control_flow_ops._ENABLE_COND_V2:
+ return unittest.skip("b/113294377 (unknown shape)")
+
self._testWhileCondGrad_Simple(use_gpu=False)
self._testWhileCondGrad_Simple(use_gpu=True)
@@ -2543,6 +2633,9 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(5.0, result.eval())
def testOneValueCond(self):
+ if control_flow_ops._ENABLE_COND_V2:
+ return unittest.skip("b/111124878 (don't return tuple)")
+
with self.test_session():
c = array_ops.placeholder(dtypes.int32, shape=[])
one = ops.convert_to_tensor(1, name="one")
@@ -2558,6 +2651,9 @@ class ControlFlowTest(test.TestCase):
self.assertEqual([2], i.eval(feed_dict={c: 0}))
def testExampleCond(self):
+ if control_flow_ops._ENABLE_COND_V2:
+ return unittest.skip("b/111124878 (don't return tuple)")
+
with self.test_session():
x = ops.convert_to_tensor([-2.0, 2.0], name="x")
d = array_ops.placeholder(dtypes.int32, shape=[])
@@ -2573,6 +2669,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2}))
def testCase(self):
+ if control_flow_ops._ENABLE_COND_V2:
+ return unittest.skip("b/112477618 (Operation returned from cond)")
+
with self.test_session():
x = constant_op.constant(1)
y = constant_op.constant(2)
@@ -2625,6 +2724,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(r6.eval(), 0)
def testCaseSideEffects(self):
+ if control_flow_ops._ENABLE_COND_V2:
+ return unittest.skip("b/112477618 (Operation returned from cond)")
+
with self.test_session() as sess:
v0 = variables.Variable(-1)
v1 = variables.Variable(-1)
@@ -2660,6 +2762,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(sess.run([v0, v1, v2]), [0, -1, -1])
def testOneOpCond(self):
+ if control_flow_ops._ENABLE_COND_V2:
+ return unittest.skip("b/113324949 (ref vars)")
+
with self.test_session():
v = variables.Variable(0)
c = ops.convert_to_tensor(0)
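Note: every control-flow hunk above guards a legacy test behind the
control_flow_ops._ENABLE_COND_V2 flag with an early `return unittest.skip(...)`.
A minimal, self-contained sketch of that pattern follows; the flag name and bug
id are placeholders, not taken from the diff. One caveat worth knowing:
unittest.skip(reason) returns a decorator, so returning it from a test body
ends the test as a quiet pass, whereas self.skipTest(reason) records an
explicit skip.

import unittest

_ENABLE_FEATURE_V2 = True  # stand-in for control_flow_ops._ENABLE_COND_V2


class LegacyBehaviorTest(unittest.TestCase):

  def testLegacyOnlyBehavior(self):
    if _ENABLE_FEATURE_V2:
      # Mirrors the diff's early-return style: the decorator returned by
      # unittest.skip() is discarded and the test simply ends here.
      return unittest.skip("b/000000 (placeholder bug id)")
    self.assertEqual(2, 1 + 1)

  def testReportedSkip(self):
    if _ENABLE_FEATURE_V2:
      self.skipTest("b/000000 (placeholder bug id)")  # reported as skipped
    self.assertEqual(2, 1 + 1)


if __name__ == "__main__":
  unittest.main()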
diff --git a/tensorflow/python/kernel_tests/ctc_decoder_ops_test.py b/tensorflow/python/kernel_tests/ctc_decoder_ops_test.py
index e1920eb568..41ae0b456f 100644
--- a/tensorflow/python/kernel_tests/ctc_decoder_ops_test.py
+++ b/tensorflow/python/kernel_tests/ctc_decoder_ops_test.py
@@ -188,11 +188,11 @@ class CTCGreedyDecoderTest(test.TestCase):
],
dtype=np.float32)
# Add arbitrary offset - this is fine
- input_log_prob_matrix_0 = np.log(input_prob_matrix_0) + 2.0
+ input_prob_matrix_0 = input_prob_matrix_0 + 2.0
# len max_time_steps array of batch_size x depth matrices
inputs = ([
- input_log_prob_matrix_0[t, :][np.newaxis, :] for t in range(seq_len_0)
+ input_prob_matrix_0[t, :][np.newaxis, :] for t in range(seq_len_0)
] # Pad to max_time_steps = 8
+ 2 * [np.zeros(
(1, depth), dtype=np.float32)])
@@ -200,11 +200,11 @@ class CTCGreedyDecoderTest(test.TestCase):
# batch_size length vector of sequence_lengths
seq_lens = np.array([seq_len_0], dtype=np.int32)
- # batch_size length vector of negative log probabilities
+ # batch_size length vector of log probabilities
log_prob_truth = np.array(
[
- 0.584855, # output beam 0
- 0.389139 # output beam 1
+ -5.811451, # output beam 0
+ -6.63339 # output beam 1
],
np.float32)[np.newaxis, :]
@@ -215,11 +215,11 @@ class CTCGreedyDecoderTest(test.TestCase):
[[0, 0], [0, 1]], dtype=np.int64), np.array(
[1, 0], dtype=np.int64), np.array(
[1, 2], dtype=np.int64)),
- # beam 1, batch 0, three outputs decoded
+ # beam 1, batch 0, one output decoded
(np.array(
- [[0, 0], [0, 1], [0, 2]], dtype=np.int64), np.array(
- [0, 1, 0], dtype=np.int64), np.array(
- [1, 3], dtype=np.int64)),
+ [[0, 0]], dtype=np.int64), np.array(
+ [1], dtype=np.int64), np.array(
+ [1, 1], dtype=np.int64)),
]
# Test correct decoding.
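Note: the expected values above change convention, from negated scores to
plain (negative) log probabilities, and the arbitrary +2.0 offset now lands on
the raw scores instead of their logs. A NumPy-only sketch of why the offset is
harmless and of the sign convention; the score matrix is made up for
illustration:

import numpy as np


def log_softmax(scores, axis=-1):
  # Numerically stable log-softmax.
  shifted = scores - scores.max(axis=axis, keepdims=True)
  return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


# time x classes scores (illustrative values, not from the test).
scores = np.array([[0.1, 0.6, 0.3],
                   [0.2, 0.2, 0.6]], dtype=np.float32)

# A constant offset before normalization leaves the distribution, and hence
# every decoded beam's log probability, unchanged.
np.testing.assert_allclose(
    log_softmax(scores), log_softmax(scores + 2.0), rtol=1e-6)

# A beam's score is a sum of per-step log probabilities, so under the new
# convention it is negative (the old truth values carried the opposite sign).
greedy_log_prob = log_softmax(scores).max(axis=-1).sum()
assert greedy_log_prob < 0.0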
diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
index 58845552db..5741f2ec64 100644
--- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
+++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
@@ -205,6 +205,19 @@ class DepthwiseConv2DTest(test.TestCase):
use_gpu=True,
grouped_conv=True)
+ def testDepthwiseConv2DWithUnknownShape(self):
+ # GitHub issue 22110.
+ if not test.is_gpu_available():
+ return
+ with self.test_session(use_gpu=True):
+ x = array_ops.placeholder(dtypes.float32)
+ f = np.ones([1, 1, 1, 1], np.float32)
+ v = nn_impl.depthwise_conv2d(
+ x, f, [1, 1, 1, 1], "VALID", rate=[2, 1], data_format="NCHW")
+ self.assertAllEqual(
+ np.ones([1, 1, 1, 1], np.float32),
+ v.eval(feed_dict={x: np.ones([1, 1, 1, 1], np.float32)}))
+
def testDepthwiseConv2DFormat(self):
if not test.is_gpu_available():
return
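Note: the new testDepthwiseConv2DWithUnknownShape covers a crash (GitHub issue
22110) hit when a fully unknown-shape input meets dilation plus NCHW layout. A
TF 1.x-style repro sketch under the same assumptions, i.e. a GPU is available
(NCHW depthwise convolution is GPU-only here); it goes through the public
tf.nn wrapper rather than the internal nn_impl module:

import numpy as np
import tensorflow as tf  # 1.x-era API, matching this diff

x = tf.placeholder(tf.float32)           # shape left fully unknown on purpose
f = np.ones([1, 1, 1, 1], np.float32)    # trivial 1x1 depthwise filter
y = tf.nn.depthwise_conv2d(
    x, f, strides=[1, 1, 1, 1], padding="VALID",
    rate=[2, 1], data_format="NCHW")

with tf.Session() as sess:
  # A 1x1x1x1 all-ones input makes the dilated conv an identity on one
  # element, matching the assertAllEqual in the test above.
  print(sess.run(y, feed_dict={x: np.ones([1, 1, 1, 1], np.float32)}))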
diff --git a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
index 9ad77a54cb..26d013bccb 100644
--- a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
+++ b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
@@ -62,59 +62,50 @@ class BernoulliTest(test.TestCase):
def testP(self):
p = [0.2, 0.4]
dist = bernoulli.Bernoulli(probs=p)
- with self.test_session():
- self.assertAllClose(p, self.evaluate(dist.probs))
+ self.assertAllClose(p, self.evaluate(dist.probs))
@test_util.run_in_graph_and_eager_modes
def testLogits(self):
logits = [-42., 42.]
dist = bernoulli.Bernoulli(logits=logits)
- with self.test_session():
- self.assertAllClose(logits, self.evaluate(dist.logits))
+ self.assertAllClose(logits, self.evaluate(dist.logits))
if not special:
return
- with self.test_session():
- self.assertAllClose(special.expit(logits), self.evaluate(dist.probs))
+ self.assertAllClose(special.expit(logits), self.evaluate(dist.probs))
p = [0.01, 0.99, 0.42]
dist = bernoulli.Bernoulli(probs=p)
- with self.test_session():
- self.assertAllClose(special.logit(p), self.evaluate(dist.logits))
+ self.assertAllClose(special.logit(p), self.evaluate(dist.logits))
@test_util.run_in_graph_and_eager_modes
def testInvalidP(self):
invalid_ps = [1.01, 2.]
for p in invalid_ps:
- with self.test_session():
- with self.assertRaisesOpError("probs has components greater than 1"):
- dist = bernoulli.Bernoulli(probs=p, validate_args=True)
- self.evaluate(dist.probs)
+ with self.assertRaisesOpError("probs has components greater than 1"):
+ dist = bernoulli.Bernoulli(probs=p, validate_args=True)
+ self.evaluate(dist.probs)
invalid_ps = [-0.01, -3.]
for p in invalid_ps:
- with self.test_session():
- with self.assertRaisesOpError("Condition x >= 0"):
- dist = bernoulli.Bernoulli(probs=p, validate_args=True)
- self.evaluate(dist.probs)
+ with self.assertRaisesOpError("Condition x >= 0"):
+ dist = bernoulli.Bernoulli(probs=p, validate_args=True)
+ self.evaluate(dist.probs)
valid_ps = [0.0, 0.5, 1.0]
for p in valid_ps:
- with self.test_session():
- dist = bernoulli.Bernoulli(probs=p)
- self.assertEqual(p, self.evaluate(dist.probs)) # Should not fail
+ dist = bernoulli.Bernoulli(probs=p)
+ self.assertEqual(p, self.evaluate(dist.probs)) # Should not fail
@test_util.run_in_graph_and_eager_modes
def testShapes(self):
- with self.test_session():
- for batch_shape in ([], [1], [2, 3, 4]):
- dist = make_bernoulli(batch_shape)
- self.assertAllEqual(batch_shape, dist.batch_shape.as_list())
- self.assertAllEqual(batch_shape,
- self.evaluate(dist.batch_shape_tensor()))
- self.assertAllEqual([], dist.event_shape.as_list())
- self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
+ for batch_shape in ([], [1], [2, 3, 4]):
+ dist = make_bernoulli(batch_shape)
+ self.assertAllEqual(batch_shape, dist.batch_shape.as_list())
+ self.assertAllEqual(batch_shape, self.evaluate(dist.batch_shape_tensor()))
+ self.assertAllEqual([], dist.event_shape.as_list())
+ self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
@test_util.run_in_graph_and_eager_modes
def testDtype(self):
@@ -137,31 +128,29 @@ class BernoulliTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def _testPmf(self, **kwargs):
dist = bernoulli.Bernoulli(**kwargs)
- with self.test_session():
- # pylint: disable=bad-continuation
- xs = [
- 0,
- [1],
- [1, 0],
- [[1, 0]],
- [[1, 0], [1, 1]],
- ]
- expected_pmfs = [
- [[0.8, 0.6], [0.7, 0.4]],
- [[0.2, 0.4], [0.3, 0.6]],
- [[0.2, 0.6], [0.3, 0.4]],
- [[0.2, 0.6], [0.3, 0.4]],
- [[0.2, 0.6], [0.3, 0.6]],
- ]
- # pylint: enable=bad-continuation
-
- for x, expected_pmf in zip(xs, expected_pmfs):
- self.assertAllClose(self.evaluate(dist.prob(x)), expected_pmf)
- self.assertAllClose(
- self.evaluate(dist.log_prob(x)), np.log(expected_pmf))
+ # pylint: disable=bad-continuation
+ xs = [
+ 0,
+ [1],
+ [1, 0],
+ [[1, 0]],
+ [[1, 0], [1, 1]],
+ ]
+ expected_pmfs = [
+ [[0.8, 0.6], [0.7, 0.4]],
+ [[0.2, 0.4], [0.3, 0.6]],
+ [[0.2, 0.6], [0.3, 0.4]],
+ [[0.2, 0.6], [0.3, 0.4]],
+ [[0.2, 0.6], [0.3, 0.6]],
+ ]
+ # pylint: enable=bad-continuation
+
+ for x, expected_pmf in zip(xs, expected_pmfs):
+ self.assertAllClose(self.evaluate(dist.prob(x)), expected_pmf)
+ self.assertAllClose(self.evaluate(dist.log_prob(x)), np.log(expected_pmf))
def testPmfCorrectBroadcastDynamicShape(self):
- with self.test_session():
+ with self.cached_session():
p = array_ops.placeholder(dtype=dtypes.float32)
dist = bernoulli.Bernoulli(probs=p)
event1 = [1, 0, 1]
@@ -178,12 +167,11 @@ class BernoulliTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testPmfInvalid(self):
p = [0.1, 0.2, 0.7]
- with self.test_session():
- dist = bernoulli.Bernoulli(probs=p, validate_args=True)
- with self.assertRaisesOpError("must be non-negative."):
- self.evaluate(dist.prob([1, 1, -1]))
- with self.assertRaisesOpError("Elements cannot exceed 1."):
- self.evaluate(dist.prob([2, 0, 1]))
+ dist = bernoulli.Bernoulli(probs=p, validate_args=True)
+ with self.assertRaisesOpError("must be non-negative."):
+ self.evaluate(dist.prob([1, 1, -1]))
+ with self.assertRaisesOpError("Elements cannot exceed 1."):
+ self.evaluate(dist.prob([2, 0, 1]))
@test_util.run_in_graph_and_eager_modes
def testPmfWithP(self):
@@ -194,7 +182,7 @@ class BernoulliTest(test.TestCase):
self._testPmf(logits=special.logit(p))
def testBroadcasting(self):
- with self.test_session():
+ with self.cached_session():
p = array_ops.placeholder(dtypes.float32)
dist = bernoulli.Bernoulli(probs=p)
self.assertAllClose(np.log(0.5), dist.log_prob(1).eval({p: 0.5}))
@@ -208,70 +196,63 @@ class BernoulliTest(test.TestCase):
}))
def testPmfShapes(self):
- with self.test_session():
+ with self.cached_session():
p = array_ops.placeholder(dtypes.float32, shape=[None, 1])
dist = bernoulli.Bernoulli(probs=p)
self.assertEqual(2, len(dist.log_prob(1).eval({p: [[0.5], [0.5]]}).shape))
- with self.test_session():
dist = bernoulli.Bernoulli(probs=0.5)
self.assertEqual(2, len(self.evaluate(dist.log_prob([[1], [1]])).shape))
- with self.test_session():
dist = bernoulli.Bernoulli(probs=0.5)
self.assertEqual((), dist.log_prob(1).get_shape())
self.assertEqual((1), dist.log_prob([1]).get_shape())
self.assertEqual((2, 1), dist.log_prob([[1], [1]]).get_shape())
- with self.test_session():
dist = bernoulli.Bernoulli(probs=[[0.5], [0.5]])
self.assertEqual((2, 1), dist.log_prob(1).get_shape())
@test_util.run_in_graph_and_eager_modes
def testBoundaryConditions(self):
- with self.test_session():
- dist = bernoulli.Bernoulli(probs=1.0)
- self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0)))
- self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))])
+ dist = bernoulli.Bernoulli(probs=1.0)
+ self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0)))
+ self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))])
@test_util.run_in_graph_and_eager_modes
def testEntropyNoBatch(self):
p = 0.2
dist = bernoulli.Bernoulli(probs=p)
- with self.test_session():
- self.assertAllClose(self.evaluate(dist.entropy()), entropy(p))
+ self.assertAllClose(self.evaluate(dist.entropy()), entropy(p))
@test_util.run_in_graph_and_eager_modes
def testEntropyWithBatch(self):
p = [[0.1, 0.7], [0.2, 0.6]]
dist = bernoulli.Bernoulli(probs=p, validate_args=False)
- with self.test_session():
- self.assertAllClose(
- self.evaluate(dist.entropy()),
- [[entropy(0.1), entropy(0.7)], [entropy(0.2),
- entropy(0.6)]])
+ self.assertAllClose(
+ self.evaluate(dist.entropy()),
+ [[entropy(0.1), entropy(0.7)], [entropy(0.2),
+ entropy(0.6)]])
@test_util.run_in_graph_and_eager_modes
def testSampleN(self):
- with self.test_session():
- p = [0.2, 0.6]
- dist = bernoulli.Bernoulli(probs=p)
- n = 100000
- samples = dist.sample(n)
- samples.set_shape([n, 2])
- self.assertEqual(samples.dtype, dtypes.int32)
- sample_values = self.evaluate(samples)
- self.assertTrue(np.all(sample_values >= 0))
- self.assertTrue(np.all(sample_values <= 1))
- # Note that the standard error for the sample mean is ~ sqrt(p * (1 - p) /
- # n). This means that the tolerance is very sensitive to the value of p
- # as well as n.
- self.assertAllClose(p, np.mean(sample_values, axis=0), atol=1e-2)
- self.assertEqual(set([0, 1]), set(sample_values.flatten()))
- # In this test we're just interested in verifying there isn't a crash
- # owing to mismatched types. b/30940152
- dist = bernoulli.Bernoulli(np.log([.2, .4]))
- self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list())
+ p = [0.2, 0.6]
+ dist = bernoulli.Bernoulli(probs=p)
+ n = 100000
+ samples = dist.sample(n)
+ samples.set_shape([n, 2])
+ self.assertEqual(samples.dtype, dtypes.int32)
+ sample_values = self.evaluate(samples)
+ self.assertTrue(np.all(sample_values >= 0))
+ self.assertTrue(np.all(sample_values <= 1))
+ # Note that the standard error for the sample mean is ~ sqrt(p * (1 - p) /
+ # n). This means that the tolerance is very sensitive to the value of p
+ # as well as n.
+ self.assertAllClose(p, np.mean(sample_values, axis=0), atol=1e-2)
+ self.assertEqual(set([0, 1]), set(sample_values.flatten()))
+ # In this test we're just interested in verifying there isn't a crash
+ # owing to mismatched types. b/30940152
+ dist = bernoulli.Bernoulli(np.log([.2, .4]))
+ self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list())
@test_util.run_in_graph_and_eager_modes
def testNotReparameterized(self):
@@ -284,7 +265,7 @@ class BernoulliTest(test.TestCase):
self.assertIsNone(grad_p)
def testSampleActsLikeSampleN(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p = [0.2, 0.6]
dist = bernoulli.Bernoulli(probs=p)
n = 1000
@@ -299,27 +280,24 @@ class BernoulliTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testMean(self):
- with self.test_session():
- p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32)
- dist = bernoulli.Bernoulli(probs=p)
- self.assertAllEqual(self.evaluate(dist.mean()), p)
+ p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32)
+ dist = bernoulli.Bernoulli(probs=p)
+ self.assertAllEqual(self.evaluate(dist.mean()), p)
@test_util.run_in_graph_and_eager_modes
def testVarianceAndStd(self):
var = lambda p: p * (1. - p)
- with self.test_session():
- p = [[0.2, 0.7], [0.5, 0.4]]
- dist = bernoulli.Bernoulli(probs=p)
- self.assertAllClose(
- self.evaluate(dist.variance()),
- np.array(
- [[var(0.2), var(0.7)], [var(0.5), var(0.4)]], dtype=np.float32))
- self.assertAllClose(
- self.evaluate(dist.stddev()),
- np.array(
- [[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
- [np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
- dtype=np.float32))
+ p = [[0.2, 0.7], [0.5, 0.4]]
+ dist = bernoulli.Bernoulli(probs=p)
+ self.assertAllClose(
+ self.evaluate(dist.variance()),
+ np.array([[var(0.2), var(0.7)], [var(0.5), var(0.4)]],
+ dtype=np.float32))
+ self.assertAllClose(
+ self.evaluate(dist.stddev()),
+ np.array([[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
+ [np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
+ dtype=np.float32))
@test_util.run_in_graph_and_eager_modes
def testBernoulliBernoulliKL(self):
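Note: two mechanical rewrites run through this file. Tests decorated with
@test_util.run_in_graph_and_eager_modes drop the `with self.test_session():`
wrapper entirely and rely on self.evaluate(), which works in both modes;
graph-only tests that feed placeholders switch to self.cached_session(), which
reuses one session per test instead of opening a fresh one per block. A
minimal sketch of both halves against the same internal test APIs this file
already imports:

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class MigrationStyleTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testEvaluateWorksInBothModes(self):
    # No session block needed: evaluate() handles graph and eager alike.
    x = array_ops.ones([2, 2])
    self.assertAllClose(np.ones([2, 2]), self.evaluate(x))

  def testFeedingStillNeedsAGraphSession(self):
    # Placeholders exist only in graph mode, so keep a (cached) session.
    with self.cached_session():
      p = array_ops.placeholder(dtypes.float32)
      self.assertAllClose(3.0, (p + 1.0).eval({p: 2.0}))


if __name__ == "__main__":
  test.main()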
diff --git a/tensorflow/python/kernel_tests/distributions/beta_test.py b/tensorflow/python/kernel_tests/distributions/beta_test.py
index 36f3ffc333..d580a415dd 100644
--- a/tensorflow/python/kernel_tests/distributions/beta_test.py
+++ b/tensorflow/python/kernel_tests/distributions/beta_test.py
@@ -20,7 +20,6 @@ import importlib
import numpy as np
-from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import random_seed
@@ -51,237 +50,215 @@ stats = try_import("scipy.stats")
class BetaTest(test.TestCase):
def testSimpleShapes(self):
- with self.test_session():
- a = np.random.rand(3)
- b = np.random.rand(3)
- dist = beta_lib.Beta(a, b)
- self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
- self.assertAllEqual([3], self.evaluate(dist.batch_shape_tensor()))
- self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
- self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape)
+ a = np.random.rand(3)
+ b = np.random.rand(3)
+ dist = beta_lib.Beta(a, b)
+ self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
+ self.assertAllEqual([3], self.evaluate(dist.batch_shape_tensor()))
+ self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
+ self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape)
def testComplexShapes(self):
- with self.test_session():
- a = np.random.rand(3, 2, 2)
- b = np.random.rand(3, 2, 2)
- dist = beta_lib.Beta(a, b)
- self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
- self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
- self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
- self.assertEqual(
- tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
+ a = np.random.rand(3, 2, 2)
+ b = np.random.rand(3, 2, 2)
+ dist = beta_lib.Beta(a, b)
+ self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
+ self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
+ self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
+ self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
def testComplexShapesBroadcast(self):
- with self.test_session():
- a = np.random.rand(3, 2, 2)
- b = np.random.rand(2, 2)
- dist = beta_lib.Beta(a, b)
- self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
- self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
- self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
- self.assertEqual(
- tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
+ a = np.random.rand(3, 2, 2)
+ b = np.random.rand(2, 2)
+ dist = beta_lib.Beta(a, b)
+ self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
+ self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
+ self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
+ self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
def testAlphaProperty(self):
a = [[1., 2, 3]]
b = [[2., 4, 3]]
- with self.test_session():
- dist = beta_lib.Beta(a, b)
- self.assertEqual([1, 3], dist.concentration1.get_shape())
- self.assertAllClose(a, self.evaluate(dist.concentration1))
+ dist = beta_lib.Beta(a, b)
+ self.assertEqual([1, 3], dist.concentration1.get_shape())
+ self.assertAllClose(a, self.evaluate(dist.concentration1))
def testBetaProperty(self):
a = [[1., 2, 3]]
b = [[2., 4, 3]]
- with self.test_session():
- dist = beta_lib.Beta(a, b)
- self.assertEqual([1, 3], dist.concentration0.get_shape())
- self.assertAllClose(b, self.evaluate(dist.concentration0))
+ dist = beta_lib.Beta(a, b)
+ self.assertEqual([1, 3], dist.concentration0.get_shape())
+ self.assertAllClose(b, self.evaluate(dist.concentration0))
def testPdfXProper(self):
a = [[1., 2, 3]]
b = [[2., 4, 3]]
- with self.test_session():
- dist = beta_lib.Beta(a, b, validate_args=True)
- self.evaluate(dist.prob([.1, .3, .6]))
- self.evaluate(dist.prob([.2, .3, .5]))
- # Either condition can trigger.
- with self.assertRaisesOpError("sample must be positive"):
- self.evaluate(dist.prob([-1., 0.1, 0.5]))
- with self.assertRaisesOpError("sample must be positive"):
- self.evaluate(dist.prob([0., 0.1, 0.5]))
- with self.assertRaisesOpError("sample must be less than `1`"):
- self.evaluate(dist.prob([.1, .2, 1.2]))
- with self.assertRaisesOpError("sample must be less than `1`"):
- self.evaluate(dist.prob([.1, .2, 1.0]))
+ dist = beta_lib.Beta(a, b, validate_args=True)
+ self.evaluate(dist.prob([.1, .3, .6]))
+ self.evaluate(dist.prob([.2, .3, .5]))
+ # Either condition can trigger.
+ with self.assertRaisesOpError("sample must be positive"):
+ self.evaluate(dist.prob([-1., 0.1, 0.5]))
+ with self.assertRaisesOpError("sample must be positive"):
+ self.evaluate(dist.prob([0., 0.1, 0.5]))
+ with self.assertRaisesOpError("sample must be less than `1`"):
+ self.evaluate(dist.prob([.1, .2, 1.2]))
+ with self.assertRaisesOpError("sample must be less than `1`"):
+ self.evaluate(dist.prob([.1, .2, 1.0]))
def testPdfTwoBatches(self):
- with self.test_session():
- a = [1., 2]
- b = [1., 2]
- x = [.5, .5]
- dist = beta_lib.Beta(a, b)
- pdf = dist.prob(x)
- self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
- self.assertEqual((2,), pdf.get_shape())
+ a = [1., 2]
+ b = [1., 2]
+ x = [.5, .5]
+ dist = beta_lib.Beta(a, b)
+ pdf = dist.prob(x)
+ self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
+ self.assertEqual((2,), pdf.get_shape())
def testPdfTwoBatchesNontrivialX(self):
- with self.test_session():
- a = [1., 2]
- b = [1., 2]
- x = [.3, .7]
- dist = beta_lib.Beta(a, b)
- pdf = dist.prob(x)
- self.assertAllClose([1, 63. / 50], self.evaluate(pdf))
- self.assertEqual((2,), pdf.get_shape())
+ a = [1., 2]
+ b = [1., 2]
+ x = [.3, .7]
+ dist = beta_lib.Beta(a, b)
+ pdf = dist.prob(x)
+ self.assertAllClose([1, 63. / 50], self.evaluate(pdf))
+ self.assertEqual((2,), pdf.get_shape())
def testPdfUniformZeroBatch(self):
- with self.test_session():
- # This is equivalent to a uniform distribution
- a = 1.
- b = 1.
- x = np.array([.1, .2, .3, .5, .8], dtype=np.float32)
- dist = beta_lib.Beta(a, b)
- pdf = dist.prob(x)
- self.assertAllClose([1.] * 5, self.evaluate(pdf))
- self.assertEqual((5,), pdf.get_shape())
+ # This is equivalent to a uniform distribution
+ a = 1.
+ b = 1.
+ x = np.array([.1, .2, .3, .5, .8], dtype=np.float32)
+ dist = beta_lib.Beta(a, b)
+ pdf = dist.prob(x)
+ self.assertAllClose([1.] * 5, self.evaluate(pdf))
+ self.assertEqual((5,), pdf.get_shape())
def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
- with self.test_session():
- a = [[1., 2]]
- b = [[1., 2]]
- x = [[.5, .5], [.3, .7]]
- dist = beta_lib.Beta(a, b)
- pdf = dist.prob(x)
- self.assertAllClose([[1., 3. / 2], [1., 63. / 50]], self.evaluate(pdf))
- self.assertEqual((2, 2), pdf.get_shape())
+ a = [[1., 2]]
+ b = [[1., 2]]
+ x = [[.5, .5], [.3, .7]]
+ dist = beta_lib.Beta(a, b)
+ pdf = dist.prob(x)
+ self.assertAllClose([[1., 3. / 2], [1., 63. / 50]], self.evaluate(pdf))
+ self.assertEqual((2, 2), pdf.get_shape())
def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
- with self.test_session():
- a = [1., 2]
- b = [1., 2]
- x = [[.5, .5], [.2, .8]]
- pdf = beta_lib.Beta(a, b).prob(x)
- self.assertAllClose([[1., 3. / 2], [1., 24. / 25]], self.evaluate(pdf))
- self.assertEqual((2, 2), pdf.get_shape())
+ a = [1., 2]
+ b = [1., 2]
+ x = [[.5, .5], [.2, .8]]
+ pdf = beta_lib.Beta(a, b).prob(x)
+ self.assertAllClose([[1., 3. / 2], [1., 24. / 25]], self.evaluate(pdf))
+ self.assertEqual((2, 2), pdf.get_shape())
def testPdfXStretchedInBroadcastWhenSameRank(self):
- with self.test_session():
- a = [[1., 2], [2., 3]]
- b = [[1., 2], [2., 3]]
- x = [[.5, .5]]
- pdf = beta_lib.Beta(a, b).prob(x)
- self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf))
- self.assertEqual((2, 2), pdf.get_shape())
+ a = [[1., 2], [2., 3]]
+ b = [[1., 2], [2., 3]]
+ x = [[.5, .5]]
+ pdf = beta_lib.Beta(a, b).prob(x)
+ self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf))
+ self.assertEqual((2, 2), pdf.get_shape())
def testPdfXStretchedInBroadcastWhenLowerRank(self):
- with self.test_session():
- a = [[1., 2], [2., 3]]
- b = [[1., 2], [2., 3]]
- x = [.5, .5]
- pdf = beta_lib.Beta(a, b).prob(x)
- self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf))
- self.assertEqual((2, 2), pdf.get_shape())
+ a = [[1., 2], [2., 3]]
+ b = [[1., 2], [2., 3]]
+ x = [.5, .5]
+ pdf = beta_lib.Beta(a, b).prob(x)
+ self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf))
+ self.assertEqual((2, 2), pdf.get_shape())
def testBetaMean(self):
- with session.Session():
- a = [1., 2, 3]
- b = [2., 4, 1.2]
- dist = beta_lib.Beta(a, b)
- self.assertEqual(dist.mean().get_shape(), (3,))
- if not stats:
- return
- expected_mean = stats.beta.mean(a, b)
- self.assertAllClose(expected_mean, self.evaluate(dist.mean()))
+ a = [1., 2, 3]
+ b = [2., 4, 1.2]
+ dist = beta_lib.Beta(a, b)
+ self.assertEqual(dist.mean().get_shape(), (3,))
+ if not stats:
+ return
+ expected_mean = stats.beta.mean(a, b)
+ self.assertAllClose(expected_mean, self.evaluate(dist.mean()))
def testBetaVariance(self):
- with session.Session():
- a = [1., 2, 3]
- b = [2., 4, 1.2]
- dist = beta_lib.Beta(a, b)
- self.assertEqual(dist.variance().get_shape(), (3,))
- if not stats:
- return
- expected_variance = stats.beta.var(a, b)
- self.assertAllClose(expected_variance, self.evaluate(dist.variance()))
+ a = [1., 2, 3]
+ b = [2., 4, 1.2]
+ dist = beta_lib.Beta(a, b)
+ self.assertEqual(dist.variance().get_shape(), (3,))
+ if not stats:
+ return
+ expected_variance = stats.beta.var(a, b)
+ self.assertAllClose(expected_variance, self.evaluate(dist.variance()))
def testBetaMode(self):
- with session.Session():
- a = np.array([1.1, 2, 3])
- b = np.array([2., 4, 1.2])
- expected_mode = (a - 1) / (a + b - 2)
- dist = beta_lib.Beta(a, b)
- self.assertEqual(dist.mode().get_shape(), (3,))
- self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
+ a = np.array([1.1, 2, 3])
+ b = np.array([2., 4, 1.2])
+ expected_mode = (a - 1) / (a + b - 2)
+ dist = beta_lib.Beta(a, b)
+ self.assertEqual(dist.mode().get_shape(), (3,))
+ self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
def testBetaModeInvalid(self):
- with session.Session():
- a = np.array([1., 2, 3])
- b = np.array([2., 4, 1.2])
- dist = beta_lib.Beta(a, b, allow_nan_stats=False)
- with self.assertRaisesOpError("Condition x < y.*"):
- self.evaluate(dist.mode())
-
- a = np.array([2., 2, 3])
- b = np.array([1., 4, 1.2])
- dist = beta_lib.Beta(a, b, allow_nan_stats=False)
- with self.assertRaisesOpError("Condition x < y.*"):
- self.evaluate(dist.mode())
+ a = np.array([1., 2, 3])
+ b = np.array([2., 4, 1.2])
+ dist = beta_lib.Beta(a, b, allow_nan_stats=False)
+ with self.assertRaisesOpError("Condition x < y.*"):
+ self.evaluate(dist.mode())
+
+ a = np.array([2., 2, 3])
+ b = np.array([1., 4, 1.2])
+ dist = beta_lib.Beta(a, b, allow_nan_stats=False)
+ with self.assertRaisesOpError("Condition x < y.*"):
+ self.evaluate(dist.mode())
def testBetaModeEnableAllowNanStats(self):
- with session.Session():
- a = np.array([1., 2, 3])
- b = np.array([2., 4, 1.2])
- dist = beta_lib.Beta(a, b, allow_nan_stats=True)
+ a = np.array([1., 2, 3])
+ b = np.array([2., 4, 1.2])
+ dist = beta_lib.Beta(a, b, allow_nan_stats=True)
- expected_mode = (a - 1) / (a + b - 2)
- expected_mode[0] = np.nan
- self.assertEqual((3,), dist.mode().get_shape())
- self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
+ expected_mode = (a - 1) / (a + b - 2)
+ expected_mode[0] = np.nan
+ self.assertEqual((3,), dist.mode().get_shape())
+ self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
- a = np.array([2., 2, 3])
- b = np.array([1., 4, 1.2])
- dist = beta_lib.Beta(a, b, allow_nan_stats=True)
+ a = np.array([2., 2, 3])
+ b = np.array([1., 4, 1.2])
+ dist = beta_lib.Beta(a, b, allow_nan_stats=True)
- expected_mode = (a - 1) / (a + b - 2)
- expected_mode[0] = np.nan
- self.assertEqual((3,), dist.mode().get_shape())
- self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
+ expected_mode = (a - 1) / (a + b - 2)
+ expected_mode[0] = np.nan
+ self.assertEqual((3,), dist.mode().get_shape())
+ self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
def testBetaEntropy(self):
- with session.Session():
- a = [1., 2, 3]
- b = [2., 4, 1.2]
- dist = beta_lib.Beta(a, b)
- self.assertEqual(dist.entropy().get_shape(), (3,))
- if not stats:
- return
- expected_entropy = stats.beta.entropy(a, b)
- self.assertAllClose(expected_entropy, self.evaluate(dist.entropy()))
+ a = [1., 2, 3]
+ b = [2., 4, 1.2]
+ dist = beta_lib.Beta(a, b)
+ self.assertEqual(dist.entropy().get_shape(), (3,))
+ if not stats:
+ return
+ expected_entropy = stats.beta.entropy(a, b)
+ self.assertAllClose(expected_entropy, self.evaluate(dist.entropy()))
def testBetaSample(self):
- with self.test_session():
- a = 1.
- b = 2.
- beta = beta_lib.Beta(a, b)
- n = constant_op.constant(100000)
- samples = beta.sample(n)
- sample_values = self.evaluate(samples)
- self.assertEqual(sample_values.shape, (100000,))
- self.assertFalse(np.any(sample_values < 0.0))
- if not stats:
- return
- self.assertLess(
- stats.kstest(
- # Beta is a univariate distribution.
- sample_values,
- stats.beta(a=1., b=2.).cdf)[0],
- 0.01)
- # The standard error of the sample mean is 1 / (sqrt(18 * n))
- self.assertAllClose(
- sample_values.mean(axis=0), stats.beta.mean(a, b), atol=1e-2)
- self.assertAllClose(
- np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1)
+ a = 1.
+ b = 2.
+ beta = beta_lib.Beta(a, b)
+ n = constant_op.constant(100000)
+ samples = beta.sample(n)
+ sample_values = self.evaluate(samples)
+ self.assertEqual(sample_values.shape, (100000,))
+ self.assertFalse(np.any(sample_values < 0.0))
+ if not stats:
+ return
+ self.assertLess(
+ stats.kstest(
+ # Beta is a univariate distribution.
+ sample_values,
+ stats.beta(a=1., b=2.).cdf)[0],
+ 0.01)
+ # The standard error of the sample mean is 1 / (sqrt(18 * n))
+ self.assertAllClose(
+ sample_values.mean(axis=0), stats.beta.mean(a, b), atol=1e-2)
+ self.assertAllClose(
+ np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1)
def testBetaFullyReparameterized(self):
a = constant_op.constant(1.0)
@@ -297,78 +274,71 @@ class BetaTest(test.TestCase):
# Test that sampling with the same seed twice gives the same results.
def testBetaSampleMultipleTimes(self):
- with self.test_session():
- a_val = 1.
- b_val = 2.
- n_val = 100
+ a_val = 1.
+ b_val = 2.
+ n_val = 100
- random_seed.set_random_seed(654321)
- beta1 = beta_lib.Beta(concentration1=a_val,
- concentration0=b_val,
- name="beta1")
- samples1 = self.evaluate(beta1.sample(n_val, seed=123456))
+ random_seed.set_random_seed(654321)
+ beta1 = beta_lib.Beta(
+ concentration1=a_val, concentration0=b_val, name="beta1")
+ samples1 = self.evaluate(beta1.sample(n_val, seed=123456))
- random_seed.set_random_seed(654321)
- beta2 = beta_lib.Beta(concentration1=a_val,
- concentration0=b_val,
- name="beta2")
- samples2 = self.evaluate(beta2.sample(n_val, seed=123456))
+ random_seed.set_random_seed(654321)
+ beta2 = beta_lib.Beta(
+ concentration1=a_val, concentration0=b_val, name="beta2")
+ samples2 = self.evaluate(beta2.sample(n_val, seed=123456))
- self.assertAllClose(samples1, samples2)
+ self.assertAllClose(samples1, samples2)
def testBetaSampleMultidimensional(self):
- with self.test_session():
- a = np.random.rand(3, 2, 2).astype(np.float32)
- b = np.random.rand(3, 2, 2).astype(np.float32)
- beta = beta_lib.Beta(a, b)
- n = constant_op.constant(100000)
- samples = beta.sample(n)
- sample_values = self.evaluate(samples)
- self.assertEqual(sample_values.shape, (100000, 3, 2, 2))
- self.assertFalse(np.any(sample_values < 0.0))
- if not stats:
- return
- self.assertAllClose(
- sample_values[:, 1, :].mean(axis=0),
- stats.beta.mean(a, b)[1, :],
- atol=1e-1)
+ a = np.random.rand(3, 2, 2).astype(np.float32)
+ b = np.random.rand(3, 2, 2).astype(np.float32)
+ beta = beta_lib.Beta(a, b)
+ n = constant_op.constant(100000)
+ samples = beta.sample(n)
+ sample_values = self.evaluate(samples)
+ self.assertEqual(sample_values.shape, (100000, 3, 2, 2))
+ self.assertFalse(np.any(sample_values < 0.0))
+ if not stats:
+ return
+ self.assertAllClose(
+ sample_values[:, 1, :].mean(axis=0),
+ stats.beta.mean(a, b)[1, :],
+ atol=1e-1)
def testBetaCdf(self):
- with self.test_session():
- shape = (30, 40, 50)
- for dt in (np.float32, np.float64):
- a = 10. * np.random.random(shape).astype(dt)
- b = 10. * np.random.random(shape).astype(dt)
- x = np.random.random(shape).astype(dt)
- actual = self.evaluate(beta_lib.Beta(a, b).cdf(x))
- self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
- self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
- if not stats:
- return
- self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
+ shape = (30, 40, 50)
+ for dt in (np.float32, np.float64):
+ a = 10. * np.random.random(shape).astype(dt)
+ b = 10. * np.random.random(shape).astype(dt)
+ x = np.random.random(shape).astype(dt)
+ actual = self.evaluate(beta_lib.Beta(a, b).cdf(x))
+ self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
+ self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
+ if not stats:
+ return
+ self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
def testBetaLogCdf(self):
- with self.test_session():
- shape = (30, 40, 50)
- for dt in (np.float32, np.float64):
- a = 10. * np.random.random(shape).astype(dt)
- b = 10. * np.random.random(shape).astype(dt)
- x = np.random.random(shape).astype(dt)
- actual = self.evaluate(math_ops.exp(beta_lib.Beta(a, b).log_cdf(x)))
- self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
- self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
- if not stats:
- return
- self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
+ shape = (30, 40, 50)
+ for dt in (np.float32, np.float64):
+ a = 10. * np.random.random(shape).astype(dt)
+ b = 10. * np.random.random(shape).astype(dt)
+ x = np.random.random(shape).astype(dt)
+ actual = self.evaluate(math_ops.exp(beta_lib.Beta(a, b).log_cdf(x)))
+ self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
+ self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
+ if not stats:
+ return
+ self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
def testBetaWithSoftplusConcentration(self):
- with self.test_session():
- a, b = -4.2, -9.1
- dist = beta_lib.BetaWithSoftplusConcentration(a, b)
- self.assertAllClose(
- self.evaluate(nn_ops.softplus(a)), self.evaluate(dist.concentration1))
- self.assertAllClose(
- self.evaluate(nn_ops.softplus(b)), self.evaluate(dist.concentration0))
+ a, b = -4.2, -9.1
+ dist = beta_lib.BetaWithSoftplusConcentration(a, b)
+ self.assertAllClose(
+ self.evaluate(nn_ops.softplus(a)), self.evaluate(dist.concentration1))
+ self.assertAllClose(
+ self.evaluate(nn_ops.softplus(b)), self.evaluate(dist.concentration0))
def testBetaBetaKL(self):
for shape in [(10,), (4, 5)]:
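Note: testBetaSample pins its tolerances to the standard error of the
estimators; for Beta(1, 2) the variance is a*b / ((a+b)^2 * (a+b+1)) = 1/18,
so the sample mean's standard error is 1/sqrt(18*n), exactly as the comment in
the hunk says. A NumPy check of that arithmetic (sample size and seed are
chosen here for speed, not taken from the test):

import numpy as np

a, b = 1.0, 2.0
var = a * b / ((a + b) ** 2 * (a + b + 1.0))    # = 1/18 for Beta(1, 2)
assert np.isclose(var, 1.0 / 18.0)

n = 100000
samples = np.random.RandomState(0).beta(a, b, size=n)

se = np.sqrt(var / n)                           # = 1 / sqrt(18 * n)
assert np.isclose(se, 1.0 / np.sqrt(18.0 * n))

# The sample mean should land within a few standard errors of a/(a+b).
assert abs(samples.mean() - a / (a + b)) < 4.0 * se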
diff --git a/tensorflow/python/kernel_tests/distributions/bijector_test.py b/tensorflow/python/kernel_tests/distributions/bijector_test.py
index 8b11556330..e20f59f48a 100644
--- a/tensorflow/python/kernel_tests/distributions/bijector_test.py
+++ b/tensorflow/python/kernel_tests/distributions/bijector_test.py
@@ -36,11 +36,10 @@ class BaseBijectorTest(test.TestCase):
"""Tests properties of the Bijector base-class."""
def testIsAbstract(self):
- with self.test_session():
- with self.assertRaisesRegexp(TypeError,
- ("Can't instantiate abstract class Bijector "
- "with abstract methods __init__")):
- bijector.Bijector() # pylint: disable=abstract-class-instantiated
+ with self.assertRaisesRegexp(TypeError,
+ ("Can't instantiate abstract class Bijector "
+ "with abstract methods __init__")):
+ bijector.Bijector() # pylint: disable=abstract-class-instantiated
def testDefaults(self):
class _BareBonesBijector(bijector.Bijector):
@@ -136,7 +135,7 @@ class BijectorTestEventNdims(test.TestCase):
def testBijectorDynamicEventNdims(self):
bij = BrokenBijector(validate_args=True)
event_ndims = array_ops.placeholder(dtype=np.int32, shape=None)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Expected scalar"):
bij.forward_log_det_jacobian(1., event_ndims=event_ndims).eval({
event_ndims: (1, 2)})
@@ -308,7 +307,7 @@ class BijectorReduceEventDimsTest(test.TestCase):
event_ndims = array_ops.placeholder(dtype=np.int32, shape=[])
bij = ExpOnlyJacobian(forward_min_event_ndims=1)
bij.inverse_log_det_jacobian(x, event_ndims=event_ndims)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ildj = sess.run(bij.inverse_log_det_jacobian(x, event_ndims=event_ndims),
feed_dict={event_ndims: 1})
self.assertAllClose(-np.log(x_), ildj)
diff --git a/tensorflow/python/kernel_tests/distributions/categorical_test.py b/tensorflow/python/kernel_tests/distributions/categorical_test.py
index d8939433ce..c6bb06eab3 100644
--- a/tensorflow/python/kernel_tests/distributions/categorical_test.py
+++ b/tensorflow/python/kernel_tests/distributions/categorical_test.py
@@ -47,7 +47,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
def testP(self):
p = [0.2, 0.8]
dist = categorical.Categorical(probs=p)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(p, dist.probs.eval())
self.assertAllEqual([2], dist.logits.get_shape())
@@ -55,14 +55,14 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
p = np.array([0.2, 0.8], dtype=np.float32)
logits = np.log(p) - 50.
dist = categorical.Categorical(logits=logits)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([2], dist.probs.get_shape())
self.assertAllEqual([2], dist.logits.get_shape())
self.assertAllClose(dist.probs.eval(), p)
self.assertAllClose(dist.logits.eval(), logits)
def testShapes(self):
- with self.test_session():
+ with self.cached_session():
for batch_shape in ([], [1], [2, 3, 4]):
dist = make_categorical(batch_shape, 10)
self.assertAllEqual(batch_shape, dist.batch_shape)
@@ -108,7 +108,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
self.assertEqual(dist.dtype, dist.sample(5).dtype)
def testUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
logits = array_ops.placeholder(dtype=dtypes.float32)
dist = categorical.Categorical(logits)
sample = dist.sample()
@@ -124,13 +124,13 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
def testPMFWithBatch(self):
histograms = [[0.2, 0.8], [0.6, 0.4]]
dist = categorical.Categorical(math_ops.log(histograms) - 50.)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(dist.prob([0, 1]).eval(), [0.2, 0.4])
def testPMFNoBatch(self):
histograms = [0.2, 0.8]
dist = categorical.Categorical(math_ops.log(histograms) - 50.)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(dist.prob(0).eval(), 0.2)
def testCDFWithDynamicEventShapeKnownNdims(self):
@@ -162,7 +162,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
event: event_feed_two
}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual_cdf_one = sess.run(cdf_op, feed_dict=feed_dict_one)
actual_cdf_two = sess.run(cdf_op, feed_dict=feed_dict_two)
@@ -192,7 +192,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
dist = categorical.Categorical(probs=histograms)
cdf_op = dist.cdf(event)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(cdf_op.eval(), expected_cdf)
def testCDFNoBatch(self):
@@ -202,7 +202,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
dist = categorical.Categorical(probs=histogram)
cdf_op = dist.cdf(event)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(cdf_op.eval(), expected_cdf)
def testCDFBroadcasting(self):
@@ -228,7 +228,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
expected_cdf_result[2, 0] = 0.3
expected_cdf_result[2, 1] = 0.75
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(dist.cdf(devent).eval(), expected_cdf_result)
def testBroadcastWithBatchParamsAndBiggerEvent(self):
@@ -286,7 +286,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
"norm_log_cdf": norm.log_cdf(real_event_tf),
}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_result = sess.run(to_run)
self.assertAllEqual(run_result["cat_prob"].shape,
@@ -301,28 +301,28 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
def testLogPMF(self):
logits = np.log([[0.2, 0.8], [0.6, 0.4]]) - 50.
dist = categorical.Categorical(logits)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(dist.log_prob([0, 1]).eval(), np.log([0.2, 0.4]))
self.assertAllClose(dist.log_prob([0.0, 1.0]).eval(), np.log([0.2, 0.4]))
def testEntropyNoBatch(self):
logits = np.log([0.2, 0.8]) - 50.
dist = categorical.Categorical(logits)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(dist.entropy().eval(),
-(0.2 * np.log(0.2) + 0.8 * np.log(0.8)))
def testEntropyWithBatch(self):
logits = np.log([[0.2, 0.8], [0.6, 0.4]]) - 50.
dist = categorical.Categorical(logits)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(dist.entropy().eval(), [
-(0.2 * np.log(0.2) + 0.8 * np.log(0.8)),
-(0.6 * np.log(0.6) + 0.4 * np.log(0.4))
])
def testEntropyGradient(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
logits = constant_op.constant([[1., 2., 3.], [2., 5., 1.]])
probabilities = nn_ops.softmax(logits)
@@ -348,7 +348,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
res["categorical_entropy_g"])
def testSample(self):
- with self.test_session():
+ with self.cached_session():
histograms = [[[0.2, 0.8], [0.4, 0.6]]]
dist = categorical.Categorical(math_ops.log(histograms) - 50.)
n = 10000
@@ -366,7 +366,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
sample_values == 1, axis=0), atol=1e-2)
def testSampleWithSampleShape(self):
- with self.test_session():
+ with self.cached_session():
histograms = [[[0.2, 0.8], [0.4, 0.6]]]
dist = categorical.Categorical(math_ops.log(histograms) - 50.)
samples = dist.sample((100, 100), seed=123)
@@ -387,7 +387,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
self.assertIsNone(grad_p)
def testLogPMFBroadcasting(self):
- with self.test_session():
+ with self.cached_session():
# 1 x 2 x 2
histograms = [[[0.2, 0.8], [0.4, 0.6]]]
dist = categorical.Categorical(math_ops.log(histograms) - 50.)
@@ -415,7 +415,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
prob.eval())
def testLogPMFShape(self):
- with self.test_session():
+ with self.cached_session():
# shape [1, 2, 2]
histograms = [[[0.2, 0.8], [0.4, 0.6]]]
dist = categorical.Categorical(math_ops.log(histograms))
@@ -441,7 +441,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual([2, 2, 2], log_prob.get_shape())
def testMode(self):
- with self.test_session():
+ with self.cached_session():
histograms = [[[0.2, 0.8], [0.6, 0.4]]]
dist = categorical.Categorical(math_ops.log(histograms) - 50.)
self.assertAllEqual(dist.mode().eval(), [[1, 0]])
@@ -452,7 +452,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
exp_logits = np.exp(logits)
return exp_logits / exp_logits.sum(axis=-1, keepdims=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for categories in [2, 4]:
for batch_size in [1, 10]:
a_logits = np.random.randn(batch_size, categories)
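Note: the entropy tests in this file assert the closed form
H = -sum_k p_k * log(p_k), and the logits are deliberately built as
log(p) - 50; the shift is harmless because softmax renormalizes any constant
offset away. A NumPy restatement of that assertion:

import numpy as np

p = np.array([0.2, 0.8])
logits = np.log(p) - 50.0                # the shift used in the tests

# Softmax of the shifted log-probabilities recovers p exactly.
q = np.exp(logits - logits.max())
q /= q.sum()
np.testing.assert_allclose(q, p, rtol=1e-12)

entropy = -(q * np.log(q)).sum()
expected = -(0.2 * np.log(0.2) + 0.8 * np.log(0.8))
np.testing.assert_allclose(entropy, expected, rtol=1e-12)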
diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py
index 1b9edcc85a..d558ca09cc 100644
--- a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py
+++ b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py
@@ -37,7 +37,7 @@ class DirichletMultinomialTest(test.TestCase):
self._rng = np.random.RandomState(42)
def testSimpleShapes(self):
- with self.test_session():
+ with self.cached_session():
alpha = np.random.rand(3)
dist = ds.DirichletMultinomial(1., alpha)
self.assertEqual(3, dist.event_shape_tensor().eval())
@@ -46,7 +46,7 @@ class DirichletMultinomialTest(test.TestCase):
self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)
def testComplexShapes(self):
- with self.test_session():
+ with self.cached_session():
alpha = np.random.rand(3, 2, 2)
n = [[3., 2], [4, 5], [6, 7]]
dist = ds.DirichletMultinomial(n, alpha)
@@ -58,14 +58,14 @@ class DirichletMultinomialTest(test.TestCase):
def testNproperty(self):
alpha = [[1., 2, 3]]
n = [[5.]]
- with self.test_session():
+ with self.cached_session():
dist = ds.DirichletMultinomial(n, alpha)
self.assertEqual([1, 1], dist.total_count.get_shape())
self.assertAllClose(n, dist.total_count.eval())
def testAlphaProperty(self):
alpha = [[1., 2, 3]]
- with self.test_session():
+ with self.cached_session():
dist = ds.DirichletMultinomial(1, alpha)
self.assertEqual([1, 3], dist.concentration.get_shape())
self.assertAllClose(alpha, dist.concentration.eval())
@@ -73,7 +73,7 @@ class DirichletMultinomialTest(test.TestCase):
def testPmfNandCountsAgree(self):
alpha = [[1., 2, 3]]
n = [[5.]]
- with self.test_session():
+ with self.cached_session():
dist = ds.DirichletMultinomial(n, alpha, validate_args=True)
dist.prob([2., 3, 0]).eval()
dist.prob([3., 0, 2]).eval()
@@ -86,7 +86,7 @@ class DirichletMultinomialTest(test.TestCase):
def testPmfNonIntegerCounts(self):
alpha = [[1., 2, 3]]
n = [[5.]]
- with self.test_session():
+ with self.cached_session():
dist = ds.DirichletMultinomial(n, alpha, validate_args=True)
dist.prob([2., 3, 0]).eval()
dist.prob([3., 0, 2]).eval()
@@ -104,7 +104,7 @@ class DirichletMultinomialTest(test.TestCase):
def testPmfBothZeroBatches(self):
# The probability of one vote falling into class k is the mean for class
# k.
- with self.test_session():
+ with self.cached_session():
# Both zero-batches. No broadcast
alpha = [1., 2]
counts = [1., 0]
@@ -116,7 +116,7 @@ class DirichletMultinomialTest(test.TestCase):
def testPmfBothZeroBatchesNontrivialN(self):
# The probability of one vote falling into class k is the mean for class
# k.
- with self.test_session():
+ with self.cached_session():
# Both zero-batches. No broadcast
alpha = [1., 2]
counts = [3., 2]
@@ -128,7 +128,7 @@ class DirichletMultinomialTest(test.TestCase):
def testPmfBothZeroBatchesMultidimensionalN(self):
# The probability of one vote falling into class k is the mean for class
# k.
- with self.test_session():
+ with self.cached_session():
alpha = [1., 2]
counts = [3., 2]
n = np.full([4, 3], 5., dtype=np.float32)
@@ -140,7 +140,7 @@ class DirichletMultinomialTest(test.TestCase):
def testPmfAlphaStretchedInBroadcastWhenSameRank(self):
# The probability of one vote falling into class k is the mean for class
# k.
- with self.test_session():
+ with self.cached_session():
alpha = [[1., 2]]
counts = [[1., 0], [0., 1]]
dist = ds.DirichletMultinomial([1.], alpha)
@@ -151,7 +151,7 @@ class DirichletMultinomialTest(test.TestCase):
def testPmfAlphaStretchedInBroadcastWhenLowerRank(self):
# The probability of one vote falling into class k is the mean for class
# k.
- with self.test_session():
+ with self.cached_session():
alpha = [1., 2]
counts = [[1., 0], [0., 1]]
pmf = ds.DirichletMultinomial(1., alpha).prob(counts)
@@ -161,7 +161,7 @@ class DirichletMultinomialTest(test.TestCase):
def testPmfCountsStretchedInBroadcastWhenSameRank(self):
# The probability of one vote falling into class k is the mean for class
# k.
- with self.test_session():
+ with self.cached_session():
alpha = [[1., 2], [2., 3]]
counts = [[1., 0]]
pmf = ds.DirichletMultinomial([1., 1.], alpha).prob(counts)
@@ -171,7 +171,7 @@ class DirichletMultinomialTest(test.TestCase):
def testPmfCountsStretchedInBroadcastWhenLowerRank(self):
# The probability of one vote falling into class k is the mean for class
# k.
- with self.test_session():
+ with self.cached_session():
alpha = [[1., 2], [2., 3]]
counts = [1., 0]
pmf = ds.DirichletMultinomial(1., alpha).prob(counts)
@@ -182,7 +182,7 @@ class DirichletMultinomialTest(test.TestCase):
# The probability of one vote falling into class k is the mean for class
# k.
alpha = [1., 2, 3]
- with self.test_session():
+ with self.cached_session():
for class_num in range(3):
counts = np.zeros([3], dtype=np.float32)
counts[class_num] = 1
@@ -199,7 +199,7 @@ class DirichletMultinomialTest(test.TestCase):
# DirichletMultinomial(2, alpha) is twice as much as the probability of one
# vote falling into class k for DirichletMultinomial(1, alpha)
alpha = [1., 2, 3]
- with self.test_session():
+ with self.cached_session():
for class_num in range(3):
counts_one = np.zeros([3], dtype=np.float32)
counts_one[class_num] = 1.
@@ -223,7 +223,7 @@ class DirichletMultinomialTest(test.TestCase):
# Ideally we'd be able to test broadcasting, but the multinomial sampler
# doesn't support different total counts.
n = np.float32(5)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# batch_shape=[2], event_shape=[3]
dist = ds.DirichletMultinomial(n, alpha)
x = dist.sample(int(250e3), seed=1)
@@ -281,7 +281,7 @@ class DirichletMultinomialTest(test.TestCase):
variance_entry(alpha[1], alpha_0)
]])
- with self.test_session():
+ with self.cached_session():
for n in ns:
# n is shape [] and alpha is shape [2].
dist = ds.DirichletMultinomial(n, alpha)
@@ -319,7 +319,7 @@ class DirichletMultinomialTest(test.TestCase):
]]],
dtype=np.float32)
- with self.test_session():
+ with self.cached_session():
# ns is shape [4, 1], and alpha is shape [4, 3].
dist = ds.DirichletMultinomial(ns, alpha)
covariance = dist.covariance()
@@ -336,7 +336,7 @@ class DirichletMultinomialTest(test.TestCase):
ns = np.random.randint(low=1, high=11, size=[3, 5, 1]).astype(np.float32)
ns2 = np.random.randint(low=1, high=11, size=[6, 1, 1]).astype(np.float32)
- with self.test_session():
+ with self.cached_session():
dist = ds.DirichletMultinomial(ns, alpha)
dist2 = ds.DirichletMultinomial(ns2, alpha2)
@@ -350,7 +350,7 @@ class DirichletMultinomialTest(test.TestCase):
# probability 1.
alpha = [5, 0.5]
counts = [0., 0]
- with self.test_session():
+ with self.cached_session():
dist = ds.DirichletMultinomial(0., alpha)
pmf = dist.prob(counts)
self.assertAllClose(1.0, pmf.eval())
@@ -365,7 +365,7 @@ class DirichletMultinomialTest(test.TestCase):
# One (three sided) coin flip. Prob[coin 3] = 0.8.
# Note that since it was one flip, the value of tau didn't matter.
counts = [0., 0, 1]
- with self.test_session():
+ with self.cached_session():
dist = ds.DirichletMultinomial(1., alpha)
pmf = dist.prob(counts)
self.assertAllClose(0.8, pmf.eval(), atol=1e-4)
@@ -373,7 +373,7 @@ class DirichletMultinomialTest(test.TestCase):
# Two (three sided) coin flips. Prob[coin 3] = 0.8.
counts = [0., 0, 2]
- with self.test_session():
+ with self.cached_session():
dist = ds.DirichletMultinomial(2., alpha)
pmf = dist.prob(counts)
self.assertAllClose(0.8**2, pmf.eval(), atol=1e-2)
@@ -381,7 +381,7 @@ class DirichletMultinomialTest(test.TestCase):
# Three (three sided) coin flips.
counts = [1., 0, 2]
- with self.test_session():
+ with self.cached_session():
dist = ds.DirichletMultinomial(3., alpha)
pmf = dist.prob(counts)
self.assertAllClose(3 * 0.1 * 0.8 * 0.8, pmf.eval(), atol=1e-2)
@@ -396,7 +396,7 @@ class DirichletMultinomialTest(test.TestCase):
# If there is only one draw, it is still a coin flip, even with small tau.
counts = [1., 0]
- with self.test_session():
+ with self.cached_session():
dist = ds.DirichletMultinomial(1., alpha)
pmf = dist.prob(counts)
self.assertAllClose(0.5, pmf.eval())
@@ -405,7 +405,7 @@ class DirichletMultinomialTest(test.TestCase):
# If there are two draws, it is much more likely that they are the same.
counts_same = [2., 0]
counts_different = [1, 1.]
- with self.test_session():
+ with self.cached_session():
dist = ds.DirichletMultinomial(2., alpha)
pmf_same = dist.prob(counts_same)
pmf_different = dist.prob(counts_different)
@@ -414,7 +414,7 @@ class DirichletMultinomialTest(test.TestCase):
def testNonStrictTurnsOffAllChecks(self):
# Make totally invalid input.
- with self.test_session():
+ with self.cached_session():
alpha = [[-1., 2]] # alpha should be positive.
counts = [[1., 0], [0., -1]] # counts should be non-negative.
n = [-5.3]  # n should be a non-negative integer equal to counts.sum.
@@ -422,7 +422,7 @@ class DirichletMultinomialTest(test.TestCase):
dist.prob(counts).eval() # Should not raise.
def testSampleUnbiasedNonScalarBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dist = ds.DirichletMultinomial(
total_count=5.,
concentration=1. + 2. * self._rng.rand(4, 3, 2).astype(np.float32))
@@ -451,7 +451,7 @@ class DirichletMultinomialTest(test.TestCase):
actual_covariance_, sample_covariance_, atol=0., rtol=0.20)
def testSampleUnbiasedScalarBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dist = ds.DirichletMultinomial(
total_count=5.,
concentration=1. + 2. * self._rng.rand(4).astype(np.float32))
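Note: the unbiased-sampling tests above compare Monte Carlo moments with the
closed forms: writing alpha0 = sum(alpha) and p = alpha/alpha0, a
DirichletMultinomial(n, alpha) has mean n*p and per-class variance
n * p_k * (1 - p_k) * (n + alpha0) / (1 + alpha0). A NumPy sketch of the same
check via compound sampling (Dirichlet, then Multinomial); the parameters,
sample size, and tolerances are illustrative:

import numpy as np

rng = np.random.RandomState(42)
n, alpha = 5, np.array([1.0, 2.0, 3.0])
alpha0 = alpha.sum()
p = alpha / alpha0

mean = n * p
var = n * p * (1.0 - p) * (n + alpha0) / (1.0 + alpha0)

# Compound sampling: draw class probabilities, then counts.
num_samples = 100000
probs = rng.dirichlet(alpha, size=num_samples)
counts = np.array([rng.multinomial(n, q) for q in probs], np.float64)

np.testing.assert_allclose(counts.mean(axis=0), mean, rtol=0.02)
np.testing.assert_allclose(counts.var(axis=0), var, rtol=0.05)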
diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
index 67ed0447ed..cace5b3ba2 100644
--- a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
+++ b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
@@ -49,115 +49,102 @@ stats = try_import("scipy.stats")
class DirichletTest(test.TestCase):
def testSimpleShapes(self):
- with self.test_session():
- alpha = np.random.rand(3)
- dist = dirichlet_lib.Dirichlet(alpha)
- self.assertEqual(3, self.evaluate(dist.event_shape_tensor()))
- self.assertAllEqual([], self.evaluate(dist.batch_shape_tensor()))
- self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape)
- self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)
+ alpha = np.random.rand(3)
+ dist = dirichlet_lib.Dirichlet(alpha)
+ self.assertEqual(3, self.evaluate(dist.event_shape_tensor()))
+ self.assertAllEqual([], self.evaluate(dist.batch_shape_tensor()))
+ self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape)
+ self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)
def testComplexShapes(self):
- with self.test_session():
- alpha = np.random.rand(3, 2, 2)
- dist = dirichlet_lib.Dirichlet(alpha)
- self.assertEqual(2, self.evaluate(dist.event_shape_tensor()))
- self.assertAllEqual([3, 2], self.evaluate(dist.batch_shape_tensor()))
- self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape)
- self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape)
+ alpha = np.random.rand(3, 2, 2)
+ dist = dirichlet_lib.Dirichlet(alpha)
+ self.assertEqual(2, self.evaluate(dist.event_shape_tensor()))
+ self.assertAllEqual([3, 2], self.evaluate(dist.batch_shape_tensor()))
+ self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape)
+ self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape)
def testConcentrationProperty(self):
alpha = [[1., 2, 3]]
- with self.test_session():
- dist = dirichlet_lib.Dirichlet(alpha)
- self.assertEqual([1, 3], dist.concentration.get_shape())
- self.assertAllClose(alpha, self.evaluate(dist.concentration))
+ dist = dirichlet_lib.Dirichlet(alpha)
+ self.assertEqual([1, 3], dist.concentration.get_shape())
+ self.assertAllClose(alpha, self.evaluate(dist.concentration))
def testPdfXProper(self):
alpha = [[1., 2, 3]]
- with self.test_session():
- dist = dirichlet_lib.Dirichlet(alpha, validate_args=True)
- self.evaluate(dist.prob([.1, .3, .6]))
- self.evaluate(dist.prob([.2, .3, .5]))
- # Either condition can trigger.
- with self.assertRaisesOpError("samples must be positive"):
- self.evaluate(dist.prob([-1., 1.5, 0.5]))
- with self.assertRaisesOpError("samples must be positive"):
- self.evaluate(dist.prob([0., .1, .9]))
- with self.assertRaisesOpError(
- "sample last-dimension must sum to `1`"):
- self.evaluate(dist.prob([.1, .2, .8]))
+ dist = dirichlet_lib.Dirichlet(alpha, validate_args=True)
+ self.evaluate(dist.prob([.1, .3, .6]))
+ self.evaluate(dist.prob([.2, .3, .5]))
+ # Either condition can trigger.
+ with self.assertRaisesOpError("samples must be positive"):
+ self.evaluate(dist.prob([-1., 1.5, 0.5]))
+ with self.assertRaisesOpError("samples must be positive"):
+ self.evaluate(dist.prob([0., .1, .9]))
+ with self.assertRaisesOpError("sample last-dimension must sum to `1`"):
+ self.evaluate(dist.prob([.1, .2, .8]))
def testPdfZeroBatches(self):
- with self.test_session():
- alpha = [1., 2]
- x = [.5, .5]
- dist = dirichlet_lib.Dirichlet(alpha)
- pdf = dist.prob(x)
- self.assertAllClose(1., self.evaluate(pdf))
- self.assertEqual((), pdf.get_shape())
+ alpha = [1., 2]
+ x = [.5, .5]
+ dist = dirichlet_lib.Dirichlet(alpha)
+ pdf = dist.prob(x)
+ self.assertAllClose(1., self.evaluate(pdf))
+ self.assertEqual((), pdf.get_shape())
def testPdfZeroBatchesNontrivialX(self):
- with self.test_session():
- alpha = [1., 2]
- x = [.3, .7]
- dist = dirichlet_lib.Dirichlet(alpha)
- pdf = dist.prob(x)
- self.assertAllClose(7. / 5, self.evaluate(pdf))
- self.assertEqual((), pdf.get_shape())
+ alpha = [1., 2]
+ x = [.3, .7]
+ dist = dirichlet_lib.Dirichlet(alpha)
+ pdf = dist.prob(x)
+ self.assertAllClose(7. / 5, self.evaluate(pdf))
+ self.assertEqual((), pdf.get_shape())
def testPdfUniformZeroBatches(self):
- with self.test_session():
- # Corresponds to a uniform distribution
- alpha = [1., 1, 1]
- x = [[.2, .5, .3], [.3, .4, .3]]
- dist = dirichlet_lib.Dirichlet(alpha)
- pdf = dist.prob(x)
- self.assertAllClose([2., 2.], self.evaluate(pdf))
- self.assertEqual((2), pdf.get_shape())
+ # Corresponds to a uniform distribution
+ alpha = [1., 1, 1]
+ x = [[.2, .5, .3], [.3, .4, .3]]
+ dist = dirichlet_lib.Dirichlet(alpha)
+ pdf = dist.prob(x)
+ self.assertAllClose([2., 2.], self.evaluate(pdf))
+ self.assertEqual((2), pdf.get_shape())
def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
- with self.test_session():
- alpha = [[1., 2]]
- x = [[.5, .5], [.3, .7]]
- dist = dirichlet_lib.Dirichlet(alpha)
- pdf = dist.prob(x)
- self.assertAllClose([1., 7. / 5], self.evaluate(pdf))
- self.assertEqual((2), pdf.get_shape())
+ alpha = [[1., 2]]
+ x = [[.5, .5], [.3, .7]]
+ dist = dirichlet_lib.Dirichlet(alpha)
+ pdf = dist.prob(x)
+ self.assertAllClose([1., 7. / 5], self.evaluate(pdf))
+ self.assertEqual((2), pdf.get_shape())
def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
- with self.test_session():
- alpha = [1., 2]
- x = [[.5, .5], [.2, .8]]
- pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
- self.assertAllClose([1., 8. / 5], self.evaluate(pdf))
- self.assertEqual((2), pdf.get_shape())
+ alpha = [1., 2]
+ x = [[.5, .5], [.2, .8]]
+ pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
+ self.assertAllClose([1., 8. / 5], self.evaluate(pdf))
+ self.assertEqual((2), pdf.get_shape())
def testPdfXStretchedInBroadcastWhenSameRank(self):
- with self.test_session():
- alpha = [[1., 2], [2., 3]]
- x = [[.5, .5]]
- pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
- self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
- self.assertEqual((2), pdf.get_shape())
+ alpha = [[1., 2], [2., 3]]
+ x = [[.5, .5]]
+ pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
+ self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
+ self.assertEqual((2), pdf.get_shape())
def testPdfXStretchedInBroadcastWhenLowerRank(self):
- with self.test_session():
- alpha = [[1., 2], [2., 3]]
- x = [.5, .5]
- pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
- self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
- self.assertEqual((2), pdf.get_shape())
+ alpha = [[1., 2], [2., 3]]
+ x = [.5, .5]
+ pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
+ self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
+ self.assertEqual((2), pdf.get_shape())
def testMean(self):
- with self.test_session():
- alpha = [1., 2, 3]
- dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
- self.assertEqual(dirichlet.mean().get_shape(), [3])
- if not stats:
- return
- expected_mean = stats.dirichlet.mean(alpha)
- self.assertAllClose(self.evaluate(dirichlet.mean()), expected_mean)
+ alpha = [1., 2, 3]
+ dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
+ self.assertEqual(dirichlet.mean().get_shape(), [3])
+ if not stats:
+ return
+ expected_mean = stats.dirichlet.mean(alpha)
+ self.assertAllClose(self.evaluate(dirichlet.mean()), expected_mean)
def testCovarianceFromSampling(self):
alpha = np.array([[1., 2, 3],
@@ -197,73 +184,66 @@ class DirichletTest(test.TestCase):
self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.)
def testVariance(self):
- with self.test_session():
- alpha = [1., 2, 3]
- denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1)
- dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
- self.assertEqual(dirichlet.covariance().get_shape(), (3, 3))
- if not stats:
- return
- expected_covariance = np.diag(stats.dirichlet.var(alpha))
- expected_covariance += [[0., -2, -3], [-2, 0, -6],
- [-3, -6, 0]] / denominator
- self.assertAllClose(
- self.evaluate(dirichlet.covariance()), expected_covariance)
+ alpha = [1., 2, 3]
+ denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1)
+ dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
+ self.assertEqual(dirichlet.covariance().get_shape(), (3, 3))
+ if not stats:
+ return
+ expected_covariance = np.diag(stats.dirichlet.var(alpha))
+ expected_covariance += [[0., -2, -3], [-2, 0, -6], [-3, -6, 0]
+ ] / denominator
+ self.assertAllClose(
+ self.evaluate(dirichlet.covariance()), expected_covariance)
def testMode(self):
- with self.test_session():
- alpha = np.array([1.1, 2, 3])
- expected_mode = (alpha - 1) / (np.sum(alpha) - 3)
- dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
- self.assertEqual(dirichlet.mode().get_shape(), [3])
- self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)
+ alpha = np.array([1.1, 2, 3])
+ expected_mode = (alpha - 1) / (np.sum(alpha) - 3)
+ dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
+ self.assertEqual(dirichlet.mode().get_shape(), [3])
+ self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)
def testModeInvalid(self):
- with self.test_session():
- alpha = np.array([1., 2, 3])
- dirichlet = dirichlet_lib.Dirichlet(concentration=alpha,
- allow_nan_stats=False)
- with self.assertRaisesOpError("Condition x < y.*"):
- self.evaluate(dirichlet.mode())
+ alpha = np.array([1., 2, 3])
+ dirichlet = dirichlet_lib.Dirichlet(
+ concentration=alpha, allow_nan_stats=False)
+ with self.assertRaisesOpError("Condition x < y.*"):
+ self.evaluate(dirichlet.mode())
def testModeEnableAllowNanStats(self):
- with self.test_session():
- alpha = np.array([1., 2, 3])
- dirichlet = dirichlet_lib.Dirichlet(concentration=alpha,
- allow_nan_stats=True)
- expected_mode = np.zeros_like(alpha) + np.nan
+ alpha = np.array([1., 2, 3])
+ dirichlet = dirichlet_lib.Dirichlet(
+ concentration=alpha, allow_nan_stats=True)
+ expected_mode = np.zeros_like(alpha) + np.nan
- self.assertEqual(dirichlet.mode().get_shape(), [3])
- self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)
+ self.assertEqual(dirichlet.mode().get_shape(), [3])
+ self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)
def testEntropy(self):
- with self.test_session():
- alpha = [1., 2, 3]
- dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
- self.assertEqual(dirichlet.entropy().get_shape(), ())
- if not stats:
- return
- expected_entropy = stats.dirichlet.entropy(alpha)
- self.assertAllClose(self.evaluate(dirichlet.entropy()), expected_entropy)
+ alpha = [1., 2, 3]
+ dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
+ self.assertEqual(dirichlet.entropy().get_shape(), ())
+ if not stats:
+ return
+ expected_entropy = stats.dirichlet.entropy(alpha)
+ self.assertAllClose(self.evaluate(dirichlet.entropy()), expected_entropy)
def testSample(self):
- with self.test_session():
- alpha = [1., 2]
- dirichlet = dirichlet_lib.Dirichlet(alpha)
- n = constant_op.constant(100000)
- samples = dirichlet.sample(n)
- sample_values = self.evaluate(samples)
- self.assertEqual(sample_values.shape, (100000, 2))
- self.assertTrue(np.all(sample_values > 0.0))
- if not stats:
- return
- self.assertLess(
- stats.kstest(
- # Beta is a univariate distribution.
- sample_values[:, 0],
- stats.beta(
- a=1., b=2.).cdf)[0],
- 0.01)
+ alpha = [1., 2]
+ dirichlet = dirichlet_lib.Dirichlet(alpha)
+ n = constant_op.constant(100000)
+ samples = dirichlet.sample(n)
+ sample_values = self.evaluate(samples)
+ self.assertEqual(sample_values.shape, (100000, 2))
+ self.assertTrue(np.all(sample_values > 0.0))
+ if not stats:
+ return
+ self.assertLess(
+ stats.kstest(
+ # Beta is a univariate distribution.
+ sample_values[:, 0],
+ stats.beta(a=1., b=2.).cdf)[0],
+ 0.01)
def testDirichletFullyReparameterized(self):
alpha = constant_op.constant([1.0, 2.0, 3.0])
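
Throughout these rewrites the ordering inside each test body is preserved deliberately: shape assertions run first, and the "if not stats: return" guard only skips the numerical comparisons that need scipy. The guard relies on the try_import helper these files define; a rough reconstruction of that pattern (illustrative only, not copied from the diff):

    import importlib

    def try_import(name):
      # Returns the module if available, else None, so the tests degrade
      # gracefully on machines without scipy installed.
      try:
        return importlib.import_module(name)
      except ImportError:
        return None

    stats = try_import("scipy.stats")
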
diff --git a/tensorflow/python/kernel_tests/distributions/exponential_test.py b/tensorflow/python/kernel_tests/distributions/exponential_test.py
index 850da3e969..27d1291912 100644
--- a/tensorflow/python/kernel_tests/distributions/exponential_test.py
+++ b/tensorflow/python/kernel_tests/distributions/exponential_test.py
@@ -22,7 +22,6 @@ import importlib
import numpy as np
-from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
@@ -48,121 +47,108 @@ stats = try_import("scipy.stats")
class ExponentialTest(test.TestCase):
def testExponentialLogPDF(self):
- with session.Session():
- batch_size = 6
- lam = constant_op.constant([2.0] * batch_size)
- lam_v = 2.0
- x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
- exponential = exponential_lib.Exponential(rate=lam)
+ batch_size = 6
+ lam = constant_op.constant([2.0] * batch_size)
+ lam_v = 2.0
+ x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+ exponential = exponential_lib.Exponential(rate=lam)
- log_pdf = exponential.log_prob(x)
- self.assertEqual(log_pdf.get_shape(), (6,))
+ log_pdf = exponential.log_prob(x)
+ self.assertEqual(log_pdf.get_shape(), (6,))
- pdf = exponential.prob(x)
- self.assertEqual(pdf.get_shape(), (6,))
+ pdf = exponential.prob(x)
+ self.assertEqual(pdf.get_shape(), (6,))
- if not stats:
- return
- expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v)
- self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
- self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
+ if not stats:
+ return
+ expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v)
+ self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
+ self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
def testExponentialCDF(self):
- with session.Session():
- batch_size = 6
- lam = constant_op.constant([2.0] * batch_size)
- lam_v = 2.0
- x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+ batch_size = 6
+ lam = constant_op.constant([2.0] * batch_size)
+ lam_v = 2.0
+ x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
- exponential = exponential_lib.Exponential(rate=lam)
+ exponential = exponential_lib.Exponential(rate=lam)
- cdf = exponential.cdf(x)
- self.assertEqual(cdf.get_shape(), (6,))
+ cdf = exponential.cdf(x)
+ self.assertEqual(cdf.get_shape(), (6,))
- if not stats:
- return
- expected_cdf = stats.expon.cdf(x, scale=1 / lam_v)
- self.assertAllClose(self.evaluate(cdf), expected_cdf)
+ if not stats:
+ return
+ expected_cdf = stats.expon.cdf(x, scale=1 / lam_v)
+ self.assertAllClose(self.evaluate(cdf), expected_cdf)
def testExponentialMean(self):
- with session.Session():
- lam_v = np.array([1.0, 4.0, 2.5])
- exponential = exponential_lib.Exponential(rate=lam_v)
- self.assertEqual(exponential.mean().get_shape(), (3,))
- if not stats:
- return
- expected_mean = stats.expon.mean(scale=1 / lam_v)
- self.assertAllClose(self.evaluate(exponential.mean()), expected_mean)
+ lam_v = np.array([1.0, 4.0, 2.5])
+ exponential = exponential_lib.Exponential(rate=lam_v)
+ self.assertEqual(exponential.mean().get_shape(), (3,))
+ if not stats:
+ return
+ expected_mean = stats.expon.mean(scale=1 / lam_v)
+ self.assertAllClose(self.evaluate(exponential.mean()), expected_mean)
def testExponentialVariance(self):
- with session.Session():
- lam_v = np.array([1.0, 4.0, 2.5])
- exponential = exponential_lib.Exponential(rate=lam_v)
- self.assertEqual(exponential.variance().get_shape(), (3,))
- if not stats:
- return
- expected_variance = stats.expon.var(scale=1 / lam_v)
- self.assertAllClose(
- self.evaluate(exponential.variance()), expected_variance)
+ lam_v = np.array([1.0, 4.0, 2.5])
+ exponential = exponential_lib.Exponential(rate=lam_v)
+ self.assertEqual(exponential.variance().get_shape(), (3,))
+ if not stats:
+ return
+ expected_variance = stats.expon.var(scale=1 / lam_v)
+ self.assertAllClose(
+ self.evaluate(exponential.variance()), expected_variance)
def testExponentialEntropy(self):
- with session.Session():
- lam_v = np.array([1.0, 4.0, 2.5])
- exponential = exponential_lib.Exponential(rate=lam_v)
- self.assertEqual(exponential.entropy().get_shape(), (3,))
- if not stats:
- return
- expected_entropy = stats.expon.entropy(scale=1 / lam_v)
- self.assertAllClose(
- self.evaluate(exponential.entropy()), expected_entropy)
+ lam_v = np.array([1.0, 4.0, 2.5])
+ exponential = exponential_lib.Exponential(rate=lam_v)
+ self.assertEqual(exponential.entropy().get_shape(), (3,))
+ if not stats:
+ return
+ expected_entropy = stats.expon.entropy(scale=1 / lam_v)
+ self.assertAllClose(self.evaluate(exponential.entropy()), expected_entropy)
def testExponentialSample(self):
- with self.test_session():
- lam = constant_op.constant([3.0, 4.0])
- lam_v = [3.0, 4.0]
- n = constant_op.constant(100000)
- exponential = exponential_lib.Exponential(rate=lam)
-
- samples = exponential.sample(n, seed=137)
- sample_values = self.evaluate(samples)
- self.assertEqual(sample_values.shape, (100000, 2))
- self.assertFalse(np.any(sample_values < 0.0))
- if not stats:
- return
- for i in range(2):
- self.assertLess(
- stats.kstest(
- sample_values[:, i], stats.expon(scale=1.0 / lam_v[i]).cdf)[0],
- 0.01)
+ lam = constant_op.constant([3.0, 4.0])
+ lam_v = [3.0, 4.0]
+ n = constant_op.constant(100000)
+ exponential = exponential_lib.Exponential(rate=lam)
+
+ samples = exponential.sample(n, seed=137)
+ sample_values = self.evaluate(samples)
+ self.assertEqual(sample_values.shape, (100000, 2))
+ self.assertFalse(np.any(sample_values < 0.0))
+ if not stats:
+ return
+ for i in range(2):
+ self.assertLess(
+ stats.kstest(sample_values[:, i],
+ stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)
def testExponentialSampleMultiDimensional(self):
- with self.test_session():
- batch_size = 2
- lam_v = [3.0, 22.0]
- lam = constant_op.constant([lam_v] * batch_size)
+ batch_size = 2
+ lam_v = [3.0, 22.0]
+ lam = constant_op.constant([lam_v] * batch_size)
- exponential = exponential_lib.Exponential(rate=lam)
+ exponential = exponential_lib.Exponential(rate=lam)
+
+ n = 100000
+ samples = exponential.sample(n, seed=138)
+ self.assertEqual(samples.get_shape(), (n, batch_size, 2))
+
+ sample_values = self.evaluate(samples)
- n = 100000
- samples = exponential.sample(n, seed=138)
- self.assertEqual(samples.get_shape(), (n, batch_size, 2))
-
- sample_values = self.evaluate(samples)
-
- self.assertFalse(np.any(sample_values < 0.0))
- if not stats:
- return
- for i in range(2):
- self.assertLess(
- stats.kstest(
- sample_values[:, 0, i],
- stats.expon(scale=1.0 / lam_v[i]).cdf)[0],
- 0.01)
- self.assertLess(
- stats.kstest(
- sample_values[:, 1, i],
- stats.expon(scale=1.0 / lam_v[i]).cdf)[0],
- 0.01)
+ self.assertFalse(np.any(sample_values < 0.0))
+ if not stats:
+ return
+ for i in range(2):
+ self.assertLess(
+ stats.kstest(sample_values[:, 0, i],
+ stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)
+ self.assertLess(
+ stats.kstest(sample_values[:, 1, i],
+ stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)
def testFullyReparameterized(self):
lam = constant_op.constant([0.1, 1.0])
@@ -174,11 +160,10 @@ class ExponentialTest(test.TestCase):
self.assertIsNotNone(grad_lam)
def testExponentialWithSoftplusRate(self):
- with self.test_session():
- lam = [-2.2, -3.4]
- exponential = exponential_lib.ExponentialWithSoftplusRate(rate=lam)
- self.assertAllClose(
- self.evaluate(nn_ops.softplus(lam)), self.evaluate(exponential.rate))
+ lam = [-2.2, -3.4]
+ exponential = exponential_lib.ExponentialWithSoftplusRate(rate=lam)
+ self.assertAllClose(
+ self.evaluate(nn_ops.softplus(lam)), self.evaluate(exponential.rate))
if __name__ == "__main__":
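
Note that the "from tensorflow.python.client import session" import is dropped above because, after the rewrite, no test opens a raw session.Session() anymore. The sampling tests also keep their statistical check: they bound the Kolmogorov-Smirnov statistic of the drawn samples against the analytic CDF rather than asserting exact values. A standalone sketch of that check, assuming scipy is available (rate and sample count are illustrative):

    import numpy as np
    from scipy import stats

    lam = 3.0
    samples = np.random.exponential(scale=1.0 / lam, size=100000)
    # kstest returns (statistic, p-value); the tests bound the statistic,
    # which shrinks like O(1/sqrt(n)) when samples match the reference CDF.
    ks_stat, _ = stats.kstest(samples, stats.expon(scale=1.0 / lam).cdf)
    assert ks_stat < 0.01
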
diff --git a/tensorflow/python/kernel_tests/distributions/gamma_test.py b/tensorflow/python/kernel_tests/distributions/gamma_test.py
index 297e20264c..4eff40b029 100644
--- a/tensorflow/python/kernel_tests/distributions/gamma_test.py
+++ b/tensorflow/python/kernel_tests/distributions/gamma_test.py
@@ -50,221 +50,203 @@ stats = try_import("scipy.stats")
class GammaTest(test.TestCase):
def testGammaShape(self):
- with self.test_session():
- alpha = constant_op.constant([3.0] * 5)
- beta = constant_op.constant(11.0)
- gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+ alpha = constant_op.constant([3.0] * 5)
+ beta = constant_op.constant(11.0)
+ gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
- self.assertEqual(self.evaluate(gamma.batch_shape_tensor()), (5,))
- self.assertEqual(gamma.batch_shape, tensor_shape.TensorShape([5]))
- self.assertAllEqual(self.evaluate(gamma.event_shape_tensor()), [])
- self.assertEqual(gamma.event_shape, tensor_shape.TensorShape([]))
+ self.assertEqual(self.evaluate(gamma.batch_shape_tensor()), (5,))
+ self.assertEqual(gamma.batch_shape, tensor_shape.TensorShape([5]))
+ self.assertAllEqual(self.evaluate(gamma.event_shape_tensor()), [])
+ self.assertEqual(gamma.event_shape, tensor_shape.TensorShape([]))
def testGammaLogPDF(self):
- with self.test_session():
- batch_size = 6
- alpha = constant_op.constant([2.0] * batch_size)
- beta = constant_op.constant([3.0] * batch_size)
- alpha_v = 2.0
- beta_v = 3.0
- x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
- gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
- log_pdf = gamma.log_prob(x)
- self.assertEqual(log_pdf.get_shape(), (6,))
- pdf = gamma.prob(x)
- self.assertEqual(pdf.get_shape(), (6,))
- if not stats:
- return
- expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
- self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
- self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
+ batch_size = 6
+ alpha = constant_op.constant([2.0] * batch_size)
+ beta = constant_op.constant([3.0] * batch_size)
+ alpha_v = 2.0
+ beta_v = 3.0
+ x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+ gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+ log_pdf = gamma.log_prob(x)
+ self.assertEqual(log_pdf.get_shape(), (6,))
+ pdf = gamma.prob(x)
+ self.assertEqual(pdf.get_shape(), (6,))
+ if not stats:
+ return
+ expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
+ self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
+ self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
def testGammaLogPDFMultidimensional(self):
- with self.test_session():
- batch_size = 6
- alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
- beta = constant_op.constant([[3.0, 4.0]] * batch_size)
- alpha_v = np.array([2.0, 4.0])
- beta_v = np.array([3.0, 4.0])
- x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
- gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
- log_pdf = gamma.log_prob(x)
- log_pdf_values = self.evaluate(log_pdf)
- self.assertEqual(log_pdf.get_shape(), (6, 2))
- pdf = gamma.prob(x)
- pdf_values = self.evaluate(pdf)
- self.assertEqual(pdf.get_shape(), (6, 2))
- if not stats:
- return
- expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
- self.assertAllClose(log_pdf_values, expected_log_pdf)
- self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
+ batch_size = 6
+ alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
+ beta = constant_op.constant([[3.0, 4.0]] * batch_size)
+ alpha_v = np.array([2.0, 4.0])
+ beta_v = np.array([3.0, 4.0])
+ x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
+ gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+ log_pdf = gamma.log_prob(x)
+ log_pdf_values = self.evaluate(log_pdf)
+ self.assertEqual(log_pdf.get_shape(), (6, 2))
+ pdf = gamma.prob(x)
+ pdf_values = self.evaluate(pdf)
+ self.assertEqual(pdf.get_shape(), (6, 2))
+ if not stats:
+ return
+ expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
+ self.assertAllClose(log_pdf_values, expected_log_pdf)
+ self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
def testGammaLogPDFMultidimensionalBroadcasting(self):
- with self.test_session():
- batch_size = 6
- alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
- beta = constant_op.constant(3.0)
- alpha_v = np.array([2.0, 4.0])
- beta_v = 3.0
- x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
- gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
- log_pdf = gamma.log_prob(x)
- log_pdf_values = self.evaluate(log_pdf)
- self.assertEqual(log_pdf.get_shape(), (6, 2))
- pdf = gamma.prob(x)
- pdf_values = self.evaluate(pdf)
- self.assertEqual(pdf.get_shape(), (6, 2))
-
- if not stats:
- return
- expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
- self.assertAllClose(log_pdf_values, expected_log_pdf)
- self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
+ batch_size = 6
+ alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
+ beta = constant_op.constant(3.0)
+ alpha_v = np.array([2.0, 4.0])
+ beta_v = 3.0
+ x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
+ gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+ log_pdf = gamma.log_prob(x)
+ log_pdf_values = self.evaluate(log_pdf)
+ self.assertEqual(log_pdf.get_shape(), (6, 2))
+ pdf = gamma.prob(x)
+ pdf_values = self.evaluate(pdf)
+ self.assertEqual(pdf.get_shape(), (6, 2))
- def testGammaCDF(self):
- with self.test_session():
- batch_size = 6
- alpha = constant_op.constant([2.0] * batch_size)
- beta = constant_op.constant([3.0] * batch_size)
- alpha_v = 2.0
- beta_v = 3.0
- x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+ if not stats:
+ return
+ expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
+ self.assertAllClose(log_pdf_values, expected_log_pdf)
+ self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
- gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
- cdf = gamma.cdf(x)
- self.assertEqual(cdf.get_shape(), (6,))
- if not stats:
- return
- expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v)
- self.assertAllClose(self.evaluate(cdf), expected_cdf)
+ def testGammaCDF(self):
+ batch_size = 6
+ alpha = constant_op.constant([2.0] * batch_size)
+ beta = constant_op.constant([3.0] * batch_size)
+ alpha_v = 2.0
+ beta_v = 3.0
+ x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+
+ gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+ cdf = gamma.cdf(x)
+ self.assertEqual(cdf.get_shape(), (6,))
+ if not stats:
+ return
+ expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v)
+ self.assertAllClose(self.evaluate(cdf), expected_cdf)
def testGammaMean(self):
- with self.test_session():
- alpha_v = np.array([1.0, 3.0, 2.5])
- beta_v = np.array([1.0, 4.0, 5.0])
- gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
- self.assertEqual(gamma.mean().get_shape(), (3,))
- if not stats:
- return
- expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v)
- self.assertAllClose(self.evaluate(gamma.mean()), expected_means)
+ alpha_v = np.array([1.0, 3.0, 2.5])
+ beta_v = np.array([1.0, 4.0, 5.0])
+ gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+ self.assertEqual(gamma.mean().get_shape(), (3,))
+ if not stats:
+ return
+ expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v)
+ self.assertAllClose(self.evaluate(gamma.mean()), expected_means)
def testGammaModeAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
- with self.test_session():
- alpha_v = np.array([5.5, 3.0, 2.5])
- beta_v = np.array([1.0, 4.0, 5.0])
- gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
- expected_modes = (alpha_v - 1) / beta_v
- self.assertEqual(gamma.mode().get_shape(), (3,))
- self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)
+ alpha_v = np.array([5.5, 3.0, 2.5])
+ beta_v = np.array([1.0, 4.0, 5.0])
+ gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+ expected_modes = (alpha_v - 1) / beta_v
+ self.assertEqual(gamma.mode().get_shape(), (3,))
+ self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)
def testGammaModeAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
- with self.test_session():
- # Mode will not be defined for the first entry.
- alpha_v = np.array([0.5, 3.0, 2.5])
- beta_v = np.array([1.0, 4.0, 5.0])
- gamma = gamma_lib.Gamma(concentration=alpha_v,
- rate=beta_v,
- allow_nan_stats=False)
- with self.assertRaisesOpError("x < y"):
- self.evaluate(gamma.mode())
+ # Mode will not be defined for the first entry.
+ alpha_v = np.array([0.5, 3.0, 2.5])
+ beta_v = np.array([1.0, 4.0, 5.0])
+ gamma = gamma_lib.Gamma(
+ concentration=alpha_v, rate=beta_v, allow_nan_stats=False)
+ with self.assertRaisesOpError("x < y"):
+ self.evaluate(gamma.mode())
def testGammaModeAllowNanStatsIsTrueReturnsNaNforUndefinedBatchMembers(self):
- with self.test_session():
- # Mode will not be defined for the first entry.
- alpha_v = np.array([0.5, 3.0, 2.5])
- beta_v = np.array([1.0, 4.0, 5.0])
- gamma = gamma_lib.Gamma(concentration=alpha_v,
- rate=beta_v,
- allow_nan_stats=True)
- expected_modes = (alpha_v - 1) / beta_v
- expected_modes[0] = np.nan
- self.assertEqual(gamma.mode().get_shape(), (3,))
- self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)
+ # Mode will not be defined for the first entry.
+ alpha_v = np.array([0.5, 3.0, 2.5])
+ beta_v = np.array([1.0, 4.0, 5.0])
+ gamma = gamma_lib.Gamma(
+ concentration=alpha_v, rate=beta_v, allow_nan_stats=True)
+ expected_modes = (alpha_v - 1) / beta_v
+ expected_modes[0] = np.nan
+ self.assertEqual(gamma.mode().get_shape(), (3,))
+ self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)
def testGammaVariance(self):
- with self.test_session():
- alpha_v = np.array([1.0, 3.0, 2.5])
- beta_v = np.array([1.0, 4.0, 5.0])
- gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
- self.assertEqual(gamma.variance().get_shape(), (3,))
- if not stats:
- return
- expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v)
- self.assertAllClose(self.evaluate(gamma.variance()), expected_variances)
+ alpha_v = np.array([1.0, 3.0, 2.5])
+ beta_v = np.array([1.0, 4.0, 5.0])
+ gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+ self.assertEqual(gamma.variance().get_shape(), (3,))
+ if not stats:
+ return
+ expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v)
+ self.assertAllClose(self.evaluate(gamma.variance()), expected_variances)
def testGammaStd(self):
- with self.test_session():
- alpha_v = np.array([1.0, 3.0, 2.5])
- beta_v = np.array([1.0, 4.0, 5.0])
- gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
- self.assertEqual(gamma.stddev().get_shape(), (3,))
- if not stats:
- return
- expected_stddev = stats.gamma.std(alpha_v, scale=1. / beta_v)
- self.assertAllClose(self.evaluate(gamma.stddev()), expected_stddev)
+ alpha_v = np.array([1.0, 3.0, 2.5])
+ beta_v = np.array([1.0, 4.0, 5.0])
+ gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+ self.assertEqual(gamma.stddev().get_shape(), (3,))
+ if not stats:
+ return
+ expected_stddev = stats.gamma.std(alpha_v, scale=1. / beta_v)
+ self.assertAllClose(self.evaluate(gamma.stddev()), expected_stddev)
def testGammaEntropy(self):
- with self.test_session():
- alpha_v = np.array([1.0, 3.0, 2.5])
- beta_v = np.array([1.0, 4.0, 5.0])
- gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
- self.assertEqual(gamma.entropy().get_shape(), (3,))
- if not stats:
- return
- expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v)
- self.assertAllClose(self.evaluate(gamma.entropy()), expected_entropy)
+ alpha_v = np.array([1.0, 3.0, 2.5])
+ beta_v = np.array([1.0, 4.0, 5.0])
+ gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+ self.assertEqual(gamma.entropy().get_shape(), (3,))
+ if not stats:
+ return
+ expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v)
+ self.assertAllClose(self.evaluate(gamma.entropy()), expected_entropy)
def testGammaSampleSmallAlpha(self):
- with self.test_session():
- alpha_v = 0.05
- beta_v = 1.0
- alpha = constant_op.constant(alpha_v)
- beta = constant_op.constant(beta_v)
- n = 100000
- gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
- samples = gamma.sample(n, seed=137)
- sample_values = self.evaluate(samples)
- self.assertEqual(samples.get_shape(), (n,))
- self.assertEqual(sample_values.shape, (n,))
- self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
- if not stats:
- return
- self.assertAllClose(
- sample_values.mean(),
- stats.gamma.mean(
- alpha_v, scale=1 / beta_v),
- atol=.01)
- self.assertAllClose(
- sample_values.var(),
- stats.gamma.var(alpha_v, scale=1 / beta_v),
- atol=.15)
+ alpha_v = 0.05
+ beta_v = 1.0
+ alpha = constant_op.constant(alpha_v)
+ beta = constant_op.constant(beta_v)
+ n = 100000
+ gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+ samples = gamma.sample(n, seed=137)
+ sample_values = self.evaluate(samples)
+ self.assertEqual(samples.get_shape(), (n,))
+ self.assertEqual(sample_values.shape, (n,))
+ self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
+ if not stats:
+ return
+ self.assertAllClose(
+ sample_values.mean(),
+ stats.gamma.mean(alpha_v, scale=1 / beta_v),
+ atol=.01)
+ self.assertAllClose(
+ sample_values.var(),
+ stats.gamma.var(alpha_v, scale=1 / beta_v),
+ atol=.15)
def testGammaSample(self):
- with self.test_session():
- alpha_v = 4.0
- beta_v = 3.0
- alpha = constant_op.constant(alpha_v)
- beta = constant_op.constant(beta_v)
- n = 100000
- gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
- samples = gamma.sample(n, seed=137)
- sample_values = self.evaluate(samples)
- self.assertEqual(samples.get_shape(), (n,))
- self.assertEqual(sample_values.shape, (n,))
- self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
- if not stats:
- return
- self.assertAllClose(
- sample_values.mean(),
- stats.gamma.mean(
- alpha_v, scale=1 / beta_v),
- atol=.01)
- self.assertAllClose(
- sample_values.var(),
- stats.gamma.var(alpha_v, scale=1 / beta_v),
- atol=.15)
+ alpha_v = 4.0
+ beta_v = 3.0
+ alpha = constant_op.constant(alpha_v)
+ beta = constant_op.constant(beta_v)
+ n = 100000
+ gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+ samples = gamma.sample(n, seed=137)
+ sample_values = self.evaluate(samples)
+ self.assertEqual(samples.get_shape(), (n,))
+ self.assertEqual(sample_values.shape, (n,))
+ self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
+ if not stats:
+ return
+ self.assertAllClose(
+ sample_values.mean(),
+ stats.gamma.mean(alpha_v, scale=1 / beta_v),
+ atol=.01)
+ self.assertAllClose(
+ sample_values.var(),
+ stats.gamma.var(alpha_v, scale=1 / beta_v),
+ atol=.15)
def testGammaFullyReparameterized(self):
alpha = constant_op.constant(4.0)
@@ -279,37 +261,37 @@ class GammaTest(test.TestCase):
self.assertIsNotNone(grad_beta)
def testGammaSampleMultiDimensional(self):
- with self.test_session():
- alpha_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100
- beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1
- gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
- n = 10000
- samples = gamma.sample(n, seed=137)
- sample_values = self.evaluate(samples)
- self.assertEqual(samples.get_shape(), (n, 10, 100))
- self.assertEqual(sample_values.shape, (n, 10, 100))
- zeros = np.zeros_like(alpha_v + beta_v) # 10 x 100
- alpha_bc = alpha_v + zeros
- beta_bc = beta_v + zeros
- if not stats:
- return
- self.assertAllClose(
- sample_values.mean(axis=0),
- stats.gamma.mean(
- alpha_bc, scale=1 / beta_bc),
- atol=0., rtol=.05)
- self.assertAllClose(
- sample_values.var(axis=0),
- stats.gamma.var(alpha_bc, scale=1 / beta_bc),
- atol=10.0, rtol=0.)
- fails = 0
- trials = 0
- for ai, a in enumerate(np.reshape(alpha_v, [-1])):
- for bi, b in enumerate(np.reshape(beta_v, [-1])):
- s = sample_values[:, bi, ai]
- trials += 1
- fails += 0 if self._kstest(a, b, s) else 1
- self.assertLess(fails, trials * 0.03)
+ alpha_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100
+ beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1
+ gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+ n = 10000
+ samples = gamma.sample(n, seed=137)
+ sample_values = self.evaluate(samples)
+ self.assertEqual(samples.get_shape(), (n, 10, 100))
+ self.assertEqual(sample_values.shape, (n, 10, 100))
+ zeros = np.zeros_like(alpha_v + beta_v) # 10 x 100
+ alpha_bc = alpha_v + zeros
+ beta_bc = beta_v + zeros
+ if not stats:
+ return
+ self.assertAllClose(
+ sample_values.mean(axis=0),
+ stats.gamma.mean(alpha_bc, scale=1 / beta_bc),
+ atol=0.,
+ rtol=.05)
+ self.assertAllClose(
+ sample_values.var(axis=0),
+ stats.gamma.var(alpha_bc, scale=1 / beta_bc),
+ atol=10.0,
+ rtol=0.)
+ fails = 0
+ trials = 0
+ for ai, a in enumerate(np.reshape(alpha_v, [-1])):
+ for bi, b in enumerate(np.reshape(beta_v, [-1])):
+ s = sample_values[:, bi, ai]
+ trials += 1
+ fails += 0 if self._kstest(a, b, s) else 1
+ self.assertLess(fails, trials * 0.03)
def _kstest(self, alpha, beta, samples):
# Uses the Kolmogorov-Smirnov test for goodness of fit.
@@ -320,30 +302,29 @@ class GammaTest(test.TestCase):
return ks < 0.02
def testGammaPdfOfSampleMultiDims(self):
- with self.test_session():
- gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]])
- num = 50000
- samples = gamma.sample(num, seed=137)
- pdfs = gamma.prob(samples)
- sample_vals, pdf_vals = self.evaluate([samples, pdfs])
- self.assertEqual(samples.get_shape(), (num, 2, 2))
- self.assertEqual(pdfs.get_shape(), (num, 2, 2))
- self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
- self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
- self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
- self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
- if not stats:
- return
- self.assertAllClose(
- stats.gamma.mean(
- [[7., 11.], [7., 11.]], scale=1 / np.array([[5., 5.], [6., 6.]])),
- sample_vals.mean(axis=0),
- atol=.1)
- self.assertAllClose(
- stats.gamma.var([[7., 11.], [7., 11.]],
- scale=1 / np.array([[5., 5.], [6., 6.]])),
- sample_vals.var(axis=0),
- atol=.1)
+ gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]])
+ num = 50000
+ samples = gamma.sample(num, seed=137)
+ pdfs = gamma.prob(samples)
+ sample_vals, pdf_vals = self.evaluate([samples, pdfs])
+ self.assertEqual(samples.get_shape(), (num, 2, 2))
+ self.assertEqual(pdfs.get_shape(), (num, 2, 2))
+ self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
+ self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
+ self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
+ self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
+ if not stats:
+ return
+ self.assertAllClose(
+ stats.gamma.mean([[7., 11.], [7., 11.]],
+ scale=1 / np.array([[5., 5.], [6., 6.]])),
+ sample_vals.mean(axis=0),
+ atol=.1)
+ self.assertAllClose(
+ stats.gamma.var([[7., 11.], [7., 11.]],
+ scale=1 / np.array([[5., 5.], [6., 6.]])),
+ sample_vals.var(axis=0),
+ atol=.1)
def _assertIntegral(self, sample_vals, pdf_vals, err=1e-3):
s_p = zip(sample_vals, pdf_vals)
@@ -356,32 +337,29 @@ class GammaTest(test.TestCase):
self.assertNear(1., total, err=err)
def testGammaNonPositiveInitializationParamsRaises(self):
- with self.test_session():
- alpha_v = constant_op.constant(0.0, name="alpha")
- beta_v = constant_op.constant(1.0, name="beta")
- with self.assertRaisesOpError("x > 0"):
- gamma = gamma_lib.Gamma(concentration=alpha_v,
- rate=beta_v,
- validate_args=True)
- self.evaluate(gamma.mean())
- alpha_v = constant_op.constant(1.0, name="alpha")
- beta_v = constant_op.constant(0.0, name="beta")
- with self.assertRaisesOpError("x > 0"):
- gamma = gamma_lib.Gamma(concentration=alpha_v,
- rate=beta_v,
- validate_args=True)
- self.evaluate(gamma.mean())
+ alpha_v = constant_op.constant(0.0, name="alpha")
+ beta_v = constant_op.constant(1.0, name="beta")
+ with self.assertRaisesOpError("x > 0"):
+ gamma = gamma_lib.Gamma(
+ concentration=alpha_v, rate=beta_v, validate_args=True)
+ self.evaluate(gamma.mean())
+ alpha_v = constant_op.constant(1.0, name="alpha")
+ beta_v = constant_op.constant(0.0, name="beta")
+ with self.assertRaisesOpError("x > 0"):
+ gamma = gamma_lib.Gamma(
+ concentration=alpha_v, rate=beta_v, validate_args=True)
+ self.evaluate(gamma.mean())
def testGammaWithSoftplusConcentrationRate(self):
- with self.test_session():
- alpha_v = constant_op.constant([0.0, -2.1], name="alpha")
- beta_v = constant_op.constant([1.0, -3.6], name="beta")
- gamma = gamma_lib.GammaWithSoftplusConcentrationRate(
- concentration=alpha_v, rate=beta_v)
- self.assertAllEqual(self.evaluate(nn_ops.softplus(alpha_v)),
- self.evaluate(gamma.concentration))
- self.assertAllEqual(self.evaluate(nn_ops.softplus(beta_v)),
- self.evaluate(gamma.rate))
+ alpha_v = constant_op.constant([0.0, -2.1], name="alpha")
+ beta_v = constant_op.constant([1.0, -3.6], name="beta")
+ gamma = gamma_lib.GammaWithSoftplusConcentrationRate(
+ concentration=alpha_v, rate=beta_v)
+ self.assertAllEqual(
+ self.evaluate(nn_ops.softplus(alpha_v)),
+ self.evaluate(gamma.concentration))
+ self.assertAllEqual(
+ self.evaluate(nn_ops.softplus(beta_v)), self.evaluate(gamma.rate))
def testGammaGammaKL(self):
alpha0 = np.array([3.])
@@ -391,15 +369,14 @@ class GammaTest(test.TestCase):
beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.])
# Build graph.
- with self.test_session():
- g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0)
- g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1)
- x = g0.sample(int(1e4), seed=0)
- kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0)
- kl_actual = kullback_leibler.kl_divergence(g0, g1)
-
- # Execute graph.
- [kl_sample_, kl_actual_] = self.evaluate([kl_sample, kl_actual])
+ g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0)
+ g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1)
+ x = g0.sample(int(1e4), seed=0)
+ kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0)
+ kl_actual = kullback_leibler.kl_divergence(g0, g1)
+
+ # Execute graph.
+ [kl_sample_, kl_actual_] = self.evaluate([kl_sample, kl_actual])
self.assertEqual(beta0.shape, kl_actual.get_shape())
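
The testGammaGammaKL rewrite above keeps the same cross-check: the analytic kullback_leibler.kl_divergence(g0, g1) is validated against a Monte-Carlo estimate, the mean log-density ratio under samples from g0. A numpy/scipy sketch of that idea using the closed form for rate-parameterized Gammas (parameter values are illustrative, not taken from the test):

    import numpy as np
    from scipy import stats
    from scipy.special import digamma, gammaln

    a0, b0 = 3.0, 2.0  # concentration / rate of g0
    a1, b1 = 4.0, 1.5  # concentration / rate of g1

    # Monte-Carlo estimate: E_{x ~ g0}[log g0(x) - log g1(x)].
    x = stats.gamma(a0, scale=1.0 / b0).rvs(int(1e5), random_state=0)
    kl_sample = np.mean(stats.gamma(a0, scale=1.0 / b0).logpdf(x)
                        - stats.gamma(a1, scale=1.0 / b1).logpdf(x))

    # Closed form: KL(Gamma(a0, b0) || Gamma(a1, b1)).
    kl_actual = ((a0 - a1) * digamma(a0) - gammaln(a0) + gammaln(a1)
                 + a1 * (np.log(b0) - np.log(b1)) + a0 * (b1 - b0) / b0)

    assert abs(kl_sample - kl_actual) < 0.05
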
diff --git a/tensorflow/python/kernel_tests/distributions/identity_bijector_test.py b/tensorflow/python/kernel_tests/distributions/identity_bijector_test.py
index b347c20db2..e35a8e1cdd 100644
--- a/tensorflow/python/kernel_tests/distributions/identity_bijector_test.py
+++ b/tensorflow/python/kernel_tests/distributions/identity_bijector_test.py
@@ -42,7 +42,7 @@ class IdentityBijectorTest(test.TestCase):
bijector.forward_log_det_jacobian(x, event_ndims=3)))
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
bijector = identity_bijector.Identity()
bijector_test_util.assert_scalar_congruency(
bijector, lower_x=-2., upper_x=2.)
diff --git a/tensorflow/python/kernel_tests/distributions/kullback_leibler_test.py b/tensorflow/python/kernel_tests/distributions/kullback_leibler_test.py
index d0fa1fe989..e77e1117d4 100644
--- a/tensorflow/python/kernel_tests/distributions/kullback_leibler_test.py
+++ b/tensorflow/python/kernel_tests/distributions/kullback_leibler_test.py
@@ -58,7 +58,7 @@ class KLTest(test.TestCase):
# pylint: disable=unused-argument,unused-variable
- with self.test_session():
+ with self.cached_session():
a = MyDistException(loc=0.0, scale=1.0, allow_nan_stats=False)
kl = kullback_leibler.kl_divergence(a, a, allow_nan_stats=False)
with self.assertRaisesOpError(
diff --git a/tensorflow/python/kernel_tests/distributions/laplace_test.py b/tensorflow/python/kernel_tests/distributions/laplace_test.py
index 24b243f647..630c2cb424 100644
--- a/tensorflow/python/kernel_tests/distributions/laplace_test.py
+++ b/tensorflow/python/kernel_tests/distributions/laplace_test.py
@@ -21,7 +21,6 @@ import importlib
import numpy as np
-from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_shape
@@ -49,212 +48,198 @@ stats = try_import("scipy.stats")
class LaplaceTest(test.TestCase):
def testLaplaceShape(self):
- with self.test_session():
- loc = constant_op.constant([3.0] * 5)
- scale = constant_op.constant(11.0)
- laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+ loc = constant_op.constant([3.0] * 5)
+ scale = constant_op.constant(11.0)
+ laplace = laplace_lib.Laplace(loc=loc, scale=scale)
- self.assertEqual(self.evaluate(laplace.batch_shape_tensor()), (5,))
- self.assertEqual(laplace.batch_shape, tensor_shape.TensorShape([5]))
- self.assertAllEqual(self.evaluate(laplace.event_shape_tensor()), [])
- self.assertEqual(laplace.event_shape, tensor_shape.TensorShape([]))
+ self.assertEqual(self.evaluate(laplace.batch_shape_tensor()), (5,))
+ self.assertEqual(laplace.batch_shape, tensor_shape.TensorShape([5]))
+ self.assertAllEqual(self.evaluate(laplace.event_shape_tensor()), [])
+ self.assertEqual(laplace.event_shape, tensor_shape.TensorShape([]))
def testLaplaceLogPDF(self):
- with self.test_session():
- batch_size = 6
- loc = constant_op.constant([2.0] * batch_size)
- scale = constant_op.constant([3.0] * batch_size)
- loc_v = 2.0
- scale_v = 3.0
- x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
- laplace = laplace_lib.Laplace(loc=loc, scale=scale)
- log_pdf = laplace.log_prob(x)
- self.assertEqual(log_pdf.get_shape(), (6,))
- if not stats:
- return
- expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
- self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
+ batch_size = 6
+ loc = constant_op.constant([2.0] * batch_size)
+ scale = constant_op.constant([3.0] * batch_size)
+ loc_v = 2.0
+ scale_v = 3.0
+ x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+ laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+ log_pdf = laplace.log_prob(x)
+ self.assertEqual(log_pdf.get_shape(), (6,))
+ if not stats:
+ return
+ expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
+ self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
- pdf = laplace.prob(x)
- self.assertEqual(pdf.get_shape(), (6,))
- self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
+ pdf = laplace.prob(x)
+ self.assertEqual(pdf.get_shape(), (6,))
+ self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
def testLaplaceLogPDFMultidimensional(self):
- with self.test_session():
- batch_size = 6
- loc = constant_op.constant([[2.0, 4.0]] * batch_size)
- scale = constant_op.constant([[3.0, 4.0]] * batch_size)
- loc_v = np.array([2.0, 4.0])
- scale_v = np.array([3.0, 4.0])
- x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
- laplace = laplace_lib.Laplace(loc=loc, scale=scale)
- log_pdf = laplace.log_prob(x)
- log_pdf_values = self.evaluate(log_pdf)
- self.assertEqual(log_pdf.get_shape(), (6, 2))
-
- pdf = laplace.prob(x)
- pdf_values = self.evaluate(pdf)
- self.assertEqual(pdf.get_shape(), (6, 2))
- if not stats:
- return
- expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
- self.assertAllClose(log_pdf_values, expected_log_pdf)
- self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
+ batch_size = 6
+ loc = constant_op.constant([[2.0, 4.0]] * batch_size)
+ scale = constant_op.constant([[3.0, 4.0]] * batch_size)
+ loc_v = np.array([2.0, 4.0])
+ scale_v = np.array([3.0, 4.0])
+ x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
+ laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+ log_pdf = laplace.log_prob(x)
+ log_pdf_values = self.evaluate(log_pdf)
+ self.assertEqual(log_pdf.get_shape(), (6, 2))
+
+ pdf = laplace.prob(x)
+ pdf_values = self.evaluate(pdf)
+ self.assertEqual(pdf.get_shape(), (6, 2))
+ if not stats:
+ return
+ expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
+ self.assertAllClose(log_pdf_values, expected_log_pdf)
+ self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
def testLaplaceLogPDFMultidimensionalBroadcasting(self):
- with self.test_session():
- batch_size = 6
- loc = constant_op.constant([[2.0, 4.0]] * batch_size)
- scale = constant_op.constant(3.0)
- loc_v = np.array([2.0, 4.0])
- scale_v = 3.0
- x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
- laplace = laplace_lib.Laplace(loc=loc, scale=scale)
- log_pdf = laplace.log_prob(x)
- log_pdf_values = self.evaluate(log_pdf)
- self.assertEqual(log_pdf.get_shape(), (6, 2))
-
- pdf = laplace.prob(x)
- pdf_values = self.evaluate(pdf)
- self.assertEqual(pdf.get_shape(), (6, 2))
- if not stats:
- return
- expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
- self.assertAllClose(log_pdf_values, expected_log_pdf)
- self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
+ batch_size = 6
+ loc = constant_op.constant([[2.0, 4.0]] * batch_size)
+ scale = constant_op.constant(3.0)
+ loc_v = np.array([2.0, 4.0])
+ scale_v = 3.0
+ x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
+ laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+ log_pdf = laplace.log_prob(x)
+ log_pdf_values = self.evaluate(log_pdf)
+ self.assertEqual(log_pdf.get_shape(), (6, 2))
+
+ pdf = laplace.prob(x)
+ pdf_values = self.evaluate(pdf)
+ self.assertEqual(pdf.get_shape(), (6, 2))
+ if not stats:
+ return
+ expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
+ self.assertAllClose(log_pdf_values, expected_log_pdf)
+ self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
def testLaplaceCDF(self):
- with self.test_session():
- batch_size = 6
- loc = constant_op.constant([2.0] * batch_size)
- scale = constant_op.constant([3.0] * batch_size)
- loc_v = 2.0
- scale_v = 3.0
- x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+ batch_size = 6
+ loc = constant_op.constant([2.0] * batch_size)
+ scale = constant_op.constant([3.0] * batch_size)
+ loc_v = 2.0
+ scale_v = 3.0
+ x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
- laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+ laplace = laplace_lib.Laplace(loc=loc, scale=scale)
- cdf = laplace.cdf(x)
- self.assertEqual(cdf.get_shape(), (6,))
- if not stats:
- return
- expected_cdf = stats.laplace.cdf(x, loc_v, scale=scale_v)
- self.assertAllClose(self.evaluate(cdf), expected_cdf)
+ cdf = laplace.cdf(x)
+ self.assertEqual(cdf.get_shape(), (6,))
+ if not stats:
+ return
+ expected_cdf = stats.laplace.cdf(x, loc_v, scale=scale_v)
+ self.assertAllClose(self.evaluate(cdf), expected_cdf)
def testLaplaceLogCDF(self):
- with self.test_session():
- batch_size = 6
- loc = constant_op.constant([2.0] * batch_size)
- scale = constant_op.constant([3.0] * batch_size)
- loc_v = 2.0
- scale_v = 3.0
- x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+ batch_size = 6
+ loc = constant_op.constant([2.0] * batch_size)
+ scale = constant_op.constant([3.0] * batch_size)
+ loc_v = 2.0
+ scale_v = 3.0
+ x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32)
- laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+ laplace = laplace_lib.Laplace(loc=loc, scale=scale)
- cdf = laplace.log_cdf(x)
- self.assertEqual(cdf.get_shape(), (6,))
- if not stats:
- return
- expected_cdf = stats.laplace.logcdf(x, loc_v, scale=scale_v)
- self.assertAllClose(self.evaluate(cdf), expected_cdf)
+ cdf = laplace.log_cdf(x)
+ self.assertEqual(cdf.get_shape(), (6,))
+ if not stats:
+ return
+ expected_cdf = stats.laplace.logcdf(x, loc_v, scale=scale_v)
+ self.assertAllClose(self.evaluate(cdf), expected_cdf)
def testLaplaceLogSurvivalFunction(self):
- with self.test_session():
- batch_size = 6
- loc = constant_op.constant([2.0] * batch_size)
- scale = constant_op.constant([3.0] * batch_size)
- loc_v = 2.0
- scale_v = 3.0
- x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+ batch_size = 6
+ loc = constant_op.constant([2.0] * batch_size)
+ scale = constant_op.constant([3.0] * batch_size)
+ loc_v = 2.0
+ scale_v = 3.0
+ x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32)
- laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+ laplace = laplace_lib.Laplace(loc=loc, scale=scale)
- sf = laplace.log_survival_function(x)
- self.assertEqual(sf.get_shape(), (6,))
- if not stats:
- return
- expected_sf = stats.laplace.logsf(x, loc_v, scale=scale_v)
- self.assertAllClose(self.evaluate(sf), expected_sf)
+ sf = laplace.log_survival_function(x)
+ self.assertEqual(sf.get_shape(), (6,))
+ if not stats:
+ return
+ expected_sf = stats.laplace.logsf(x, loc_v, scale=scale_v)
+ self.assertAllClose(self.evaluate(sf), expected_sf)
def testLaplaceMean(self):
- with self.test_session():
- loc_v = np.array([1.0, 3.0, 2.5])
- scale_v = np.array([1.0, 4.0, 5.0])
- laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
- self.assertEqual(laplace.mean().get_shape(), (3,))
- if not stats:
- return
- expected_means = stats.laplace.mean(loc_v, scale=scale_v)
- self.assertAllClose(self.evaluate(laplace.mean()), expected_means)
+ loc_v = np.array([1.0, 3.0, 2.5])
+ scale_v = np.array([1.0, 4.0, 5.0])
+ laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+ self.assertEqual(laplace.mean().get_shape(), (3,))
+ if not stats:
+ return
+ expected_means = stats.laplace.mean(loc_v, scale=scale_v)
+ self.assertAllClose(self.evaluate(laplace.mean()), expected_means)
def testLaplaceMode(self):
- with self.test_session():
- loc_v = np.array([0.5, 3.0, 2.5])
- scale_v = np.array([1.0, 4.0, 5.0])
- laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
- self.assertEqual(laplace.mode().get_shape(), (3,))
- self.assertAllClose(self.evaluate(laplace.mode()), loc_v)
+ loc_v = np.array([0.5, 3.0, 2.5])
+ scale_v = np.array([1.0, 4.0, 5.0])
+ laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+ self.assertEqual(laplace.mode().get_shape(), (3,))
+ self.assertAllClose(self.evaluate(laplace.mode()), loc_v)
def testLaplaceVariance(self):
- with self.test_session():
- loc_v = np.array([1.0, 3.0, 2.5])
- scale_v = np.array([1.0, 4.0, 5.0])
- laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
- self.assertEqual(laplace.variance().get_shape(), (3,))
- if not stats:
- return
- expected_variances = stats.laplace.var(loc_v, scale=scale_v)
- self.assertAllClose(self.evaluate(laplace.variance()), expected_variances)
+ loc_v = np.array([1.0, 3.0, 2.5])
+ scale_v = np.array([1.0, 4.0, 5.0])
+ laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+ self.assertEqual(laplace.variance().get_shape(), (3,))
+ if not stats:
+ return
+ expected_variances = stats.laplace.var(loc_v, scale=scale_v)
+ self.assertAllClose(self.evaluate(laplace.variance()), expected_variances)
def testLaplaceStd(self):
- with self.test_session():
- loc_v = np.array([1.0, 3.0, 2.5])
- scale_v = np.array([1.0, 4.0, 5.0])
- laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
- self.assertEqual(laplace.stddev().get_shape(), (3,))
- if not stats:
- return
- expected_stddev = stats.laplace.std(loc_v, scale=scale_v)
- self.assertAllClose(self.evaluate(laplace.stddev()), expected_stddev)
+ loc_v = np.array([1.0, 3.0, 2.5])
+ scale_v = np.array([1.0, 4.0, 5.0])
+ laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+ self.assertEqual(laplace.stddev().get_shape(), (3,))
+ if not stats:
+ return
+ expected_stddev = stats.laplace.std(loc_v, scale=scale_v)
+ self.assertAllClose(self.evaluate(laplace.stddev()), expected_stddev)
def testLaplaceEntropy(self):
- with self.test_session():
- loc_v = np.array([1.0, 3.0, 2.5])
- scale_v = np.array([1.0, 4.0, 5.0])
- laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
- self.assertEqual(laplace.entropy().get_shape(), (3,))
- if not stats:
- return
- expected_entropy = stats.laplace.entropy(loc_v, scale=scale_v)
- self.assertAllClose(self.evaluate(laplace.entropy()), expected_entropy)
+ loc_v = np.array([1.0, 3.0, 2.5])
+ scale_v = np.array([1.0, 4.0, 5.0])
+ laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+ self.assertEqual(laplace.entropy().get_shape(), (3,))
+ if not stats:
+ return
+ expected_entropy = stats.laplace.entropy(loc_v, scale=scale_v)
+ self.assertAllClose(self.evaluate(laplace.entropy()), expected_entropy)
def testLaplaceSample(self):
- with session.Session():
- loc_v = 4.0
- scale_v = 3.0
- loc = constant_op.constant(loc_v)
- scale = constant_op.constant(scale_v)
- n = 100000
- laplace = laplace_lib.Laplace(loc=loc, scale=scale)
- samples = laplace.sample(n, seed=137)
- sample_values = self.evaluate(samples)
- self.assertEqual(samples.get_shape(), (n,))
- self.assertEqual(sample_values.shape, (n,))
- if not stats:
- return
- self.assertAllClose(
- sample_values.mean(),
- stats.laplace.mean(
- loc_v, scale=scale_v),
- rtol=0.05,
- atol=0.)
- self.assertAllClose(
- sample_values.var(),
- stats.laplace.var(loc_v, scale=scale_v),
- rtol=0.05,
- atol=0.)
- self.assertTrue(self._kstest(loc_v, scale_v, sample_values))
+ loc_v = 4.0
+ scale_v = 3.0
+ loc = constant_op.constant(loc_v)
+ scale = constant_op.constant(scale_v)
+ n = 100000
+ laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+ samples = laplace.sample(n, seed=137)
+ sample_values = self.evaluate(samples)
+ self.assertEqual(samples.get_shape(), (n,))
+ self.assertEqual(sample_values.shape, (n,))
+ if not stats:
+ return
+ self.assertAllClose(
+ sample_values.mean(),
+ stats.laplace.mean(loc_v, scale=scale_v),
+ rtol=0.05,
+ atol=0.)
+ self.assertAllClose(
+ sample_values.var(),
+ stats.laplace.var(loc_v, scale=scale_v),
+ rtol=0.05,
+ atol=0.)
+ self.assertTrue(self._kstest(loc_v, scale_v, sample_values))
def testLaplaceFullyReparameterized(self):
loc = constant_op.constant(4.0)
@@ -269,39 +254,37 @@ class LaplaceTest(test.TestCase):
self.assertIsNotNone(grad_scale)
def testLaplaceSampleMultiDimensional(self):
- with session.Session():
- loc_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100
- scale_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1
- laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
- n = 10000
- samples = laplace.sample(n, seed=137)
- sample_values = self.evaluate(samples)
- self.assertEqual(samples.get_shape(), (n, 10, 100))
- self.assertEqual(sample_values.shape, (n, 10, 100))
- zeros = np.zeros_like(loc_v + scale_v) # 10 x 100
- loc_bc = loc_v + zeros
- scale_bc = scale_v + zeros
- if not stats:
- return
- self.assertAllClose(
- sample_values.mean(axis=0),
- stats.laplace.mean(
- loc_bc, scale=scale_bc),
- rtol=0.35,
- atol=0.)
- self.assertAllClose(
- sample_values.var(axis=0),
- stats.laplace.var(loc_bc, scale=scale_bc),
- rtol=0.10,
- atol=0.)
- fails = 0
- trials = 0
- for ai, a in enumerate(np.reshape(loc_v, [-1])):
- for bi, b in enumerate(np.reshape(scale_v, [-1])):
- s = sample_values[:, bi, ai]
- trials += 1
- fails += 0 if self._kstest(a, b, s) else 1
- self.assertLess(fails, trials * 0.03)
+ loc_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100
+ scale_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1
+ laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+ n = 10000
+ samples = laplace.sample(n, seed=137)
+ sample_values = self.evaluate(samples)
+ self.assertEqual(samples.get_shape(), (n, 10, 100))
+ self.assertEqual(sample_values.shape, (n, 10, 100))
+ zeros = np.zeros_like(loc_v + scale_v) # 10 x 100
+ loc_bc = loc_v + zeros
+ scale_bc = scale_v + zeros
+ if not stats:
+ return
+ self.assertAllClose(
+ sample_values.mean(axis=0),
+ stats.laplace.mean(loc_bc, scale=scale_bc),
+ rtol=0.35,
+ atol=0.)
+ self.assertAllClose(
+ sample_values.var(axis=0),
+ stats.laplace.var(loc_bc, scale=scale_bc),
+ rtol=0.10,
+ atol=0.)
+ fails = 0
+ trials = 0
+ for ai, a in enumerate(np.reshape(loc_v, [-1])):
+ for bi, b in enumerate(np.reshape(scale_v, [-1])):
+ s = sample_values[:, bi, ai]
+ trials += 1
+ fails += 0 if self._kstest(a, b, s) else 1
+ self.assertLess(fails, trials * 0.03)
def _kstest(self, loc, scale, samples):
# Uses the Kolmogorov-Smirnov test for goodness of fit.
@@ -349,30 +332,26 @@ class LaplaceTest(test.TestCase):
self.assertNear(1., total, err=err)
def testLaplaceNonPositiveInitializationParamsRaises(self):
- with self.test_session():
- loc_v = constant_op.constant(0.0, name="loc")
- scale_v = constant_op.constant(-1.0, name="scale")
- with self.assertRaisesOpError(
- "Condition x > 0 did not hold element-wise"):
- laplace = laplace_lib.Laplace(
- loc=loc_v, scale=scale_v, validate_args=True)
- self.evaluate(laplace.mean())
- loc_v = constant_op.constant(1.0, name="loc")
- scale_v = constant_op.constant(0.0, name="scale")
- with self.assertRaisesOpError(
- "Condition x > 0 did not hold element-wise"):
- laplace = laplace_lib.Laplace(
- loc=loc_v, scale=scale_v, validate_args=True)
- self.evaluate(laplace.mean())
+ loc_v = constant_op.constant(0.0, name="loc")
+ scale_v = constant_op.constant(-1.0, name="scale")
+ with self.assertRaisesOpError("Condition x > 0 did not hold element-wise"):
+ laplace = laplace_lib.Laplace(
+ loc=loc_v, scale=scale_v, validate_args=True)
+ self.evaluate(laplace.mean())
+ loc_v = constant_op.constant(1.0, name="loc")
+ scale_v = constant_op.constant(0.0, name="scale")
+ with self.assertRaisesOpError("Condition x > 0 did not hold element-wise"):
+ laplace = laplace_lib.Laplace(
+ loc=loc_v, scale=scale_v, validate_args=True)
+ self.evaluate(laplace.mean())
def testLaplaceWithSoftplusScale(self):
- with self.test_session():
- loc_v = constant_op.constant([0.0, 1.0], name="loc")
- scale_v = constant_op.constant([-1.0, 2.0], name="scale")
- laplace = laplace_lib.LaplaceWithSoftplusScale(loc=loc_v, scale=scale_v)
- self.assertAllClose(
- self.evaluate(nn_ops.softplus(scale_v)), self.evaluate(laplace.scale))
- self.assertAllClose(self.evaluate(loc_v), self.evaluate(laplace.loc))
+ loc_v = constant_op.constant([0.0, 1.0], name="loc")
+ scale_v = constant_op.constant([-1.0, 2.0], name="scale")
+ laplace = laplace_lib.LaplaceWithSoftplusScale(loc=loc_v, scale=scale_v)
+ self.assertAllClose(
+ self.evaluate(nn_ops.softplus(scale_v)), self.evaluate(laplace.scale))
+ self.assertAllClose(self.evaluate(loc_v), self.evaluate(laplace.loc))
if __name__ == "__main__":
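Note: every Laplace hunk above applies one migration: the explicit
"with session.Session():" / "with self.test_session():" wrappers are dropped
and tensors are fetched through self.evaluate(), which runs under both graph
and eager execution. A minimal before/after sketch, assuming a standard
tf.test.TestCase subclass; the tensor and its value are illustrative:

# Sketch only, not part of the diff; same import style as the tests above.
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test

class MigrationSketchTest(test.TestCase):

  def testOldStyle(self):
    with self.test_session():          # graph mode only; opens a session
      x = constant_op.constant(3.0)
      self.assertEqual(3.0, x.eval())  # .eval() needs the ambient session

  def testNewStyle(self):
    x = constant_op.constant(3.0)      # no session wrapper needed
    self.assertEqual(3.0, self.evaluate(x))  # graph and eager compatible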
diff --git a/tensorflow/python/kernel_tests/distributions/multinomial_test.py b/tensorflow/python/kernel_tests/distributions/multinomial_test.py
index bfd40ba2b7..3840d7331c 100644
--- a/tensorflow/python/kernel_tests/distributions/multinomial_test.py
+++ b/tensorflow/python/kernel_tests/distributions/multinomial_test.py
@@ -34,7 +34,7 @@ class MultinomialTest(test.TestCase):
self._rng = np.random.RandomState(42)
def testSimpleShapes(self):
- with self.test_session():
+ with self.cached_session():
p = [.1, .3, .6]
dist = multinomial.Multinomial(total_count=1., probs=p)
self.assertEqual(3, dist.event_shape_tensor().eval())
@@ -43,7 +43,7 @@ class MultinomialTest(test.TestCase):
self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)
def testComplexShapes(self):
- with self.test_session():
+ with self.cached_session():
p = 0.5 * np.ones([3, 2, 2], dtype=np.float32)
n = [[3., 2], [4, 5], [6, 7]]
dist = multinomial.Multinomial(total_count=n, probs=p)
@@ -55,14 +55,14 @@ class MultinomialTest(test.TestCase):
def testN(self):
p = [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]]
n = [[3.], [4]]
- with self.test_session():
+ with self.cached_session():
dist = multinomial.Multinomial(total_count=n, probs=p)
self.assertEqual((2, 1), dist.total_count.get_shape())
self.assertAllClose(n, dist.total_count.eval())
def testP(self):
p = [[0.1, 0.2, 0.7]]
- with self.test_session():
+ with self.cached_session():
dist = multinomial.Multinomial(total_count=3., probs=p)
self.assertEqual((1, 3), dist.probs.get_shape())
self.assertEqual((1, 3), dist.logits.get_shape())
@@ -71,7 +71,7 @@ class MultinomialTest(test.TestCase):
def testLogits(self):
p = np.array([[0.1, 0.2, 0.7]], dtype=np.float32)
logits = np.log(p) - 50.
- with self.test_session():
+ with self.cached_session():
multinom = multinomial.Multinomial(total_count=3., logits=logits)
self.assertEqual((1, 3), multinom.probs.get_shape())
self.assertEqual((1, 3), multinom.logits.get_shape())
@@ -80,7 +80,7 @@ class MultinomialTest(test.TestCase):
def testPmfUnderflow(self):
logits = np.array([[-200, 0]], dtype=np.float32)
- with self.test_session():
+ with self.cached_session():
dist = multinomial.Multinomial(total_count=1., logits=logits)
lp = dist.log_prob([1., 0.]).eval()[0]
self.assertAllClose(-200, lp, atol=0, rtol=1e-6)
@@ -88,7 +88,7 @@ class MultinomialTest(test.TestCase):
def testPmfandCountsAgree(self):
p = [[0.1, 0.2, 0.7]]
n = [[5.]]
- with self.test_session():
+ with self.cached_session():
dist = multinomial.Multinomial(total_count=n, probs=p, validate_args=True)
dist.prob([2., 3, 0]).eval()
dist.prob([3., 0, 2]).eval()
@@ -100,7 +100,7 @@ class MultinomialTest(test.TestCase):
def testPmfNonIntegerCounts(self):
p = [[0.1, 0.2, 0.7]]
n = [[5.]]
- with self.test_session():
+ with self.cached_session():
# No errors with integer n.
multinom = multinomial.Multinomial(
total_count=n, probs=p, validate_args=True)
@@ -122,7 +122,7 @@ class MultinomialTest(test.TestCase):
multinom.prob([1.0, 2.5, 1.5]).eval()
def testPmfBothZeroBatches(self):
- with self.test_session():
+ with self.cached_session():
# Both zero-batches. No broadcast
p = [0.5, 0.5]
counts = [1., 0]
@@ -131,7 +131,7 @@ class MultinomialTest(test.TestCase):
self.assertEqual((), pmf.get_shape())
def testPmfBothZeroBatchesNontrivialN(self):
- with self.test_session():
+ with self.cached_session():
# Both zero-batches. No broadcast
p = [0.1, 0.9]
counts = [3., 2]
@@ -142,7 +142,7 @@ class MultinomialTest(test.TestCase):
self.assertEqual((), pmf.get_shape())
def testPmfPStretchedInBroadcastWhenSameRank(self):
- with self.test_session():
+ with self.cached_session():
p = [[0.1, 0.9]]
counts = [[1., 0], [0, 1]]
pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
@@ -150,7 +150,7 @@ class MultinomialTest(test.TestCase):
self.assertEqual((2), pmf.get_shape())
def testPmfPStretchedInBroadcastWhenLowerRank(self):
- with self.test_session():
+ with self.cached_session():
p = [0.1, 0.9]
counts = [[1., 0], [0, 1]]
pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
@@ -158,7 +158,7 @@ class MultinomialTest(test.TestCase):
self.assertEqual((2), pmf.get_shape())
def testPmfCountsStretchedInBroadcastWhenSameRank(self):
- with self.test_session():
+ with self.cached_session():
p = [[0.1, 0.9], [0.7, 0.3]]
counts = [[1., 0]]
pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
@@ -166,7 +166,7 @@ class MultinomialTest(test.TestCase):
self.assertEqual((2), pmf.get_shape())
def testPmfCountsStretchedInBroadcastWhenLowerRank(self):
- with self.test_session():
+ with self.cached_session():
p = [[0.1, 0.9], [0.7, 0.3]]
counts = [1., 0]
pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
@@ -174,7 +174,7 @@ class MultinomialTest(test.TestCase):
self.assertEqual(pmf.get_shape(), (2))
def testPmfShapeCountsStretchedN(self):
- with self.test_session():
+ with self.cached_session():
# [2, 2, 2]
p = [[[0.1, 0.9], [0.1, 0.9]], [[0.7, 0.3], [0.7, 0.3]]]
# [2, 2]
@@ -186,7 +186,7 @@ class MultinomialTest(test.TestCase):
self.assertEqual(pmf.get_shape(), (2, 2))
def testPmfShapeCountsPStretchedN(self):
- with self.test_session():
+ with self.cached_session():
p = [0.1, 0.9]
counts = [3., 2]
n = np.full([4, 3], 5., dtype=np.float32)
@@ -195,7 +195,7 @@ class MultinomialTest(test.TestCase):
self.assertEqual((4, 3), pmf.get_shape())
def testMultinomialMean(self):
- with self.test_session():
+ with self.cached_session():
n = 5.
p = [0.1, 0.2, 0.7]
dist = multinomial.Multinomial(total_count=n, probs=p)
@@ -204,7 +204,7 @@ class MultinomialTest(test.TestCase):
self.assertAllClose(expected_means, dist.mean().eval())
def testMultinomialCovariance(self):
- with self.test_session():
+ with self.cached_session():
n = 5.
p = [0.1, 0.2, 0.7]
dist = multinomial.Multinomial(total_count=n, probs=p)
@@ -215,7 +215,7 @@ class MultinomialTest(test.TestCase):
self.assertAllClose(expected_covariances, dist.covariance().eval())
def testMultinomialCovarianceBatch(self):
- with self.test_session():
+ with self.cached_session():
# Shape [2]
n = [5.] * 2
# Shape [4, 1, 2]
@@ -237,7 +237,7 @@ class MultinomialTest(test.TestCase):
ns = np.random.randint(low=1, high=11, size=[3, 5]).astype(np.float32)
ns2 = np.random.randint(low=1, high=11, size=[6, 1]).astype(np.float32)
- with self.test_session():
+ with self.cached_session():
dist = multinomial.Multinomial(ns, p)
dist2 = multinomial.Multinomial(ns2, p2)
@@ -253,7 +253,7 @@ class MultinomialTest(test.TestCase):
[2.5, 4, 0.01]], dtype=np.float32)
theta /= np.sum(theta, 1)[..., array_ops.newaxis]
n = np.array([[10., 9.], [8., 7.], [6., 5.]], dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# batch_shape=[3, 2], event_shape=[3]
dist = multinomial.Multinomial(n, theta)
x = dist.sample(int(1000e3), seed=1)
@@ -289,7 +289,7 @@ class MultinomialTest(test.TestCase):
self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.01, rtol=0.01)
def testSampleUnbiasedNonScalarBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dist = multinomial.Multinomial(
total_count=[7., 6., 5.],
logits=math_ops.log(2. * self._rng.rand(4, 3, 2).astype(np.float32)))
@@ -318,7 +318,7 @@ class MultinomialTest(test.TestCase):
actual_covariance_, sample_covariance_, atol=0., rtol=0.20)
def testSampleUnbiasedScalarBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dist = multinomial.Multinomial(
total_count=5.,
logits=math_ops.log(2. * self._rng.rand(4).astype(np.float32)))
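Note: multinomial_test.py keeps its session wrappers but swaps
self.test_session() for self.cached_session(), which returns one session that
is cached and reused for the lifetime of the test, so the .eval() calls inside
the block keep working without a fresh session per with-block. A sketch of the
pattern inside a tf.test.TestCase subclass; the distribution values are
illustrative:

# Sketch only, not part of the diff.
from tensorflow.python.ops.distributions import multinomial

def testPSketch(self):
  with self.cached_session():  # one session, reused within this test
    dist = multinomial.Multinomial(total_count=3., probs=[[0.1, 0.2, 0.7]])
    self.assertEqual((1, 3), dist.probs.get_shape())
    self.assertAllClose([[0.1, 0.2, 0.7]], dist.probs.eval())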
diff --git a/tensorflow/python/kernel_tests/distributions/normal_test.py b/tensorflow/python/kernel_tests/distributions/normal_test.py
index 7ff48c0c10..de73a40b23 100644
--- a/tensorflow/python/kernel_tests/distributions/normal_test.py
+++ b/tensorflow/python/kernel_tests/distributions/normal_test.py
@@ -61,16 +61,15 @@ class NormalTest(test.TestCase):
self.assertAllEqual(all_true, is_finite)
def _testParamShapes(self, sample_shape, expected):
- with self.test_session():
- param_shapes = normal_lib.Normal.param_shapes(sample_shape)
- mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"]
- self.assertAllEqual(expected, self.evaluate(mu_shape))
- self.assertAllEqual(expected, self.evaluate(sigma_shape))
- mu = array_ops.zeros(mu_shape)
- sigma = array_ops.ones(sigma_shape)
- self.assertAllEqual(
- expected,
- self.evaluate(array_ops.shape(normal_lib.Normal(mu, sigma).sample())))
+ param_shapes = normal_lib.Normal.param_shapes(sample_shape)
+ mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"]
+ self.assertAllEqual(expected, self.evaluate(mu_shape))
+ self.assertAllEqual(expected, self.evaluate(sigma_shape))
+ mu = array_ops.zeros(mu_shape)
+ sigma = array_ops.ones(sigma_shape)
+ self.assertAllEqual(
+ expected,
+ self.evaluate(array_ops.shape(normal_lib.Normal(mu, sigma).sample())))
def _testParamStaticShapes(self, sample_shape, expected):
param_shapes = normal_lib.Normal.param_static_shapes(sample_shape)
@@ -91,156 +90,150 @@ class NormalTest(test.TestCase):
self._testParamStaticShapes(
tensor_shape.TensorShape(sample_shape), sample_shape)
- @test_util.run_in_graph_and_eager_modes
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testNormalWithSoftplusScale(self):
- with self.test_session():
- mu = array_ops.zeros((10, 3))
- rho = array_ops.ones((10, 3)) * -2.
- normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho)
- self.assertAllEqual(self.evaluate(mu), self.evaluate(normal.loc))
- self.assertAllEqual(
- self.evaluate(nn_ops.softplus(rho)), self.evaluate(normal.scale))
+ mu = array_ops.zeros((10, 3))
+ rho = array_ops.ones((10, 3)) * -2.
+ normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho)
+ self.assertAllEqual(self.evaluate(mu), self.evaluate(normal.loc))
+ self.assertAllEqual(
+ self.evaluate(nn_ops.softplus(rho)), self.evaluate(normal.scale))
@test_util.run_in_graph_and_eager_modes
def testNormalLogPDF(self):
- with self.test_session():
- batch_size = 6
- mu = constant_op.constant([3.0] * batch_size)
- sigma = constant_op.constant([math.sqrt(10.0)] * batch_size)
- x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
- normal = normal_lib.Normal(loc=mu, scale=sigma)
-
- log_pdf = normal.log_prob(x)
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()),
- self.evaluate(log_pdf).shape)
- self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
- self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)
+ batch_size = 6
+ mu = constant_op.constant([3.0] * batch_size)
+ sigma = constant_op.constant([math.sqrt(10.0)] * batch_size)
+ x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
- pdf = normal.prob(x)
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()),
- self.evaluate(pdf).shape)
- self.assertAllEqual(normal.batch_shape, pdf.get_shape())
- self.assertAllEqual(normal.batch_shape, self.evaluate(pdf).shape)
-
- if not stats:
- return
- expected_log_pdf = stats.norm(self.evaluate(mu),
- self.evaluate(sigma)).logpdf(x)
- self.assertAllClose(expected_log_pdf, self.evaluate(log_pdf))
- self.assertAllClose(np.exp(expected_log_pdf), self.evaluate(pdf))
+ log_pdf = normal.log_prob(x)
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()),
+ self.evaluate(log_pdf).shape)
+ self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
+ self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)
+
+ pdf = normal.prob(x)
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()),
+ self.evaluate(pdf).shape)
+ self.assertAllEqual(normal.batch_shape, pdf.get_shape())
+ self.assertAllEqual(normal.batch_shape, self.evaluate(pdf).shape)
+
+ if not stats:
+ return
+ expected_log_pdf = stats.norm(self.evaluate(mu),
+ self.evaluate(sigma)).logpdf(x)
+ self.assertAllClose(expected_log_pdf, self.evaluate(log_pdf))
+ self.assertAllClose(np.exp(expected_log_pdf), self.evaluate(pdf))
@test_util.run_in_graph_and_eager_modes
def testNormalLogPDFMultidimensional(self):
- with self.test_session():
- batch_size = 6
- mu = constant_op.constant([[3.0, -3.0]] * batch_size)
- sigma = constant_op.constant([[math.sqrt(10.0), math.sqrt(15.0)]] *
- batch_size)
- x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
- normal = normal_lib.Normal(loc=mu, scale=sigma)
-
- log_pdf = normal.log_prob(x)
- log_pdf_values = self.evaluate(log_pdf)
- self.assertEqual(log_pdf.get_shape(), (6, 2))
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()),
- self.evaluate(log_pdf).shape)
- self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
- self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)
-
- pdf = normal.prob(x)
- pdf_values = self.evaluate(pdf)
- self.assertEqual(pdf.get_shape(), (6, 2))
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), pdf_values.shape)
- self.assertAllEqual(normal.batch_shape, pdf.get_shape())
- self.assertAllEqual(normal.batch_shape, pdf_values.shape)
+ batch_size = 6
+ mu = constant_op.constant([[3.0, -3.0]] * batch_size)
+ sigma = constant_op.constant(
+ [[math.sqrt(10.0), math.sqrt(15.0)]] * batch_size)
+ x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
- if not stats:
- return
- expected_log_pdf = stats.norm(self.evaluate(mu),
- self.evaluate(sigma)).logpdf(x)
- self.assertAllClose(expected_log_pdf, log_pdf_values)
- self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
+ log_pdf = normal.log_prob(x)
+ log_pdf_values = self.evaluate(log_pdf)
+ self.assertEqual(log_pdf.get_shape(), (6, 2))
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()),
+ self.evaluate(log_pdf).shape)
+ self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
+ self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)
+
+ pdf = normal.prob(x)
+ pdf_values = self.evaluate(pdf)
+ self.assertEqual(pdf.get_shape(), (6, 2))
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), pdf_values.shape)
+ self.assertAllEqual(normal.batch_shape, pdf.get_shape())
+ self.assertAllEqual(normal.batch_shape, pdf_values.shape)
+
+ if not stats:
+ return
+ expected_log_pdf = stats.norm(self.evaluate(mu),
+ self.evaluate(sigma)).logpdf(x)
+ self.assertAllClose(expected_log_pdf, log_pdf_values)
+ self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
@test_util.run_in_graph_and_eager_modes
def testNormalCDF(self):
- with self.test_session():
- batch_size = 50
- mu = self._rng.randn(batch_size)
- sigma = self._rng.rand(batch_size) + 1.0
- x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
+ batch_size = 50
+ mu = self._rng.randn(batch_size)
+ sigma = self._rng.rand(batch_size) + 1.0
+ x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
- normal = normal_lib.Normal(loc=mu, scale=sigma)
- cdf = normal.cdf(x)
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()),
- self.evaluate(cdf).shape)
- self.assertAllEqual(normal.batch_shape, cdf.get_shape())
- self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)
- if not stats:
- return
- expected_cdf = stats.norm(mu, sigma).cdf(x)
- self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
+ cdf = normal.cdf(x)
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()),
+ self.evaluate(cdf).shape)
+ self.assertAllEqual(normal.batch_shape, cdf.get_shape())
+ self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)
+ if not stats:
+ return
+ expected_cdf = stats.norm(mu, sigma).cdf(x)
+ self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0)
@test_util.run_in_graph_and_eager_modes
def testNormalSurvivalFunction(self):
- with self.test_session():
- batch_size = 50
- mu = self._rng.randn(batch_size)
- sigma = self._rng.rand(batch_size) + 1.0
- x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
+ batch_size = 50
+ mu = self._rng.randn(batch_size)
+ sigma = self._rng.rand(batch_size) + 1.0
+ x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
- normal = normal_lib.Normal(loc=mu, scale=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
- sf = normal.survival_function(x)
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()),
- self.evaluate(sf).shape)
- self.assertAllEqual(normal.batch_shape, sf.get_shape())
- self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)
- if not stats:
- return
- expected_sf = stats.norm(mu, sigma).sf(x)
- self.assertAllClose(expected_sf, self.evaluate(sf), atol=0)
+ sf = normal.survival_function(x)
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()),
+ self.evaluate(sf).shape)
+ self.assertAllEqual(normal.batch_shape, sf.get_shape())
+ self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)
+ if not stats:
+ return
+ expected_sf = stats.norm(mu, sigma).sf(x)
+ self.assertAllClose(expected_sf, self.evaluate(sf), atol=0)
@test_util.run_in_graph_and_eager_modes
def testNormalLogCDF(self):
- with self.test_session():
- batch_size = 50
- mu = self._rng.randn(batch_size)
- sigma = self._rng.rand(batch_size) + 1.0
- x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64)
+ batch_size = 50
+ mu = self._rng.randn(batch_size)
+ sigma = self._rng.rand(batch_size) + 1.0
+ x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64)
- normal = normal_lib.Normal(loc=mu, scale=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
- cdf = normal.log_cdf(x)
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()),
- self.evaluate(cdf).shape)
- self.assertAllEqual(normal.batch_shape, cdf.get_shape())
- self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)
+ cdf = normal.log_cdf(x)
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()),
+ self.evaluate(cdf).shape)
+ self.assertAllEqual(normal.batch_shape, cdf.get_shape())
+ self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)
- if not stats:
- return
- expected_cdf = stats.norm(mu, sigma).logcdf(x)
- self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-3)
+ if not stats:
+ return
+ expected_cdf = stats.norm(mu, sigma).logcdf(x)
+ self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-3)
def testFiniteGradientAtDifficultPoints(self):
for dtype in [np.float32, np.float64]:
@@ -256,7 +249,7 @@ class NormalTest(test.TestCase):
]:
value = func(x)
grads = gradients_impl.gradients(value, [mu, sigma])
- with self.test_session(graph=g):
+ with self.session(graph=g):
variables.global_variables_initializer().run()
self.assertAllFinite(value)
self.assertAllFinite(grads[0])
@@ -264,112 +257,106 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testNormalLogSurvivalFunction(self):
- with self.test_session():
- batch_size = 50
- mu = self._rng.randn(batch_size)
- sigma = self._rng.rand(batch_size) + 1.0
- x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64)
+ batch_size = 50
+ mu = self._rng.randn(batch_size)
+ sigma = self._rng.rand(batch_size) + 1.0
+ x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64)
- normal = normal_lib.Normal(loc=mu, scale=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
- sf = normal.log_survival_function(x)
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()),
- self.evaluate(sf).shape)
- self.assertAllEqual(normal.batch_shape, sf.get_shape())
- self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)
+ sf = normal.log_survival_function(x)
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()),
+ self.evaluate(sf).shape)
+ self.assertAllEqual(normal.batch_shape, sf.get_shape())
+ self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)
- if not stats:
- return
- expected_sf = stats.norm(mu, sigma).logsf(x)
- self.assertAllClose(expected_sf, self.evaluate(sf), atol=0, rtol=1e-5)
+ if not stats:
+ return
+ expected_sf = stats.norm(mu, sigma).logsf(x)
+ self.assertAllClose(expected_sf, self.evaluate(sf), atol=0, rtol=1e-5)
@test_util.run_in_graph_and_eager_modes
def testNormalEntropyWithScalarInputs(self):
# Scipy.stats.norm cannot deal with the shapes in the other test.
- with self.test_session():
- mu_v = 2.34
- sigma_v = 4.56
- normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)
-
- entropy = normal.entropy()
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()),
- self.evaluate(entropy).shape)
- self.assertAllEqual(normal.batch_shape, entropy.get_shape())
- self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)
- # scipy.stats.norm cannot deal with these shapes.
- if not stats:
- return
- expected_entropy = stats.norm(mu_v, sigma_v).entropy()
- self.assertAllClose(expected_entropy, self.evaluate(entropy))
+ mu_v = 2.34
+ sigma_v = 4.56
+ normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)
+
+ entropy = normal.entropy()
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()),
+ self.evaluate(entropy).shape)
+ self.assertAllEqual(normal.batch_shape, entropy.get_shape())
+ self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)
+ # Skip the scipy-based cross-check when scipy is unavailable.
+ if not stats:
+ return
+ expected_entropy = stats.norm(mu_v, sigma_v).entropy()
+ self.assertAllClose(expected_entropy, self.evaluate(entropy))
@test_util.run_in_graph_and_eager_modes
def testNormalEntropy(self):
- with self.test_session():
- mu_v = np.array([1.0, 1.0, 1.0])
- sigma_v = np.array([[1.0, 2.0, 3.0]]).T
- normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)
-
- # scipy.stats.norm cannot deal with these shapes.
- sigma_broadcast = mu_v * sigma_v
- expected_entropy = 0.5 * np.log(2 * np.pi * np.exp(1) * sigma_broadcast**
- 2)
- entropy = normal.entropy()
- np.testing.assert_allclose(expected_entropy, self.evaluate(entropy))
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()),
- self.evaluate(entropy).shape)
- self.assertAllEqual(normal.batch_shape, entropy.get_shape())
- self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)
-
- @test_util.run_in_graph_and_eager_modes
+ mu_v = np.array([1.0, 1.0, 1.0])
+ sigma_v = np.array([[1.0, 2.0, 3.0]]).T
+ normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)
+
+ # scipy.stats.norm cannot deal with these shapes.
+ sigma_broadcast = mu_v * sigma_v
+ expected_entropy = 0.5 * np.log(2 * np.pi * np.exp(1) * sigma_broadcast**2)
+ entropy = normal.entropy()
+ np.testing.assert_allclose(expected_entropy, self.evaluate(entropy))
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()),
+ self.evaluate(entropy).shape)
+ self.assertAllEqual(normal.batch_shape, entropy.get_shape())
+ self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)
+
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testNormalMeanAndMode(self):
- with self.test_session():
- # Mu will be broadcast to [7, 7, 7].
- mu = [7.]
- sigma = [11., 12., 13.]
+ # Mu will be broadcast to [7, 7, 7].
+ mu = [7.]
+ sigma = [11., 12., 13.]
- normal = normal_lib.Normal(loc=mu, scale=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
- self.assertAllEqual((3,), normal.mean().get_shape())
- self.assertAllEqual([7., 7, 7], self.evaluate(normal.mean()))
+ self.assertAllEqual((3,), normal.mean().get_shape())
+ self.assertAllEqual([7., 7, 7], self.evaluate(normal.mean()))
- self.assertAllEqual((3,), normal.mode().get_shape())
- self.assertAllEqual([7., 7, 7], self.evaluate(normal.mode()))
+ self.assertAllEqual((3,), normal.mode().get_shape())
+ self.assertAllEqual([7., 7, 7], self.evaluate(normal.mode()))
@test_util.run_in_graph_and_eager_modes
def testNormalQuantile(self):
- with self.test_session():
- batch_size = 52
- mu = self._rng.randn(batch_size)
- sigma = self._rng.rand(batch_size) + 1.0
- p = np.linspace(0., 1.0, batch_size - 2).astype(np.float64)
- # Quantile performs piecewise rational approximation so adding some
- # special input values to make sure we hit all the pieces.
- p = np.hstack((p, np.exp(-33), 1. - np.exp(-33)))
+ batch_size = 52
+ mu = self._rng.randn(batch_size)
+ sigma = self._rng.rand(batch_size) + 1.0
+ p = np.linspace(0., 1.0, batch_size - 2).astype(np.float64)
+ # Quantile performs a piecewise rational approximation, so add some
+ # special input values to make sure we hit all the pieces.
+ p = np.hstack((p, np.exp(-33), 1. - np.exp(-33)))
- normal = normal_lib.Normal(loc=mu, scale=sigma)
- x = normal.quantile(p)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
+ x = normal.quantile(p)
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()), x.get_shape())
- self.assertAllEqual(
- self.evaluate(normal.batch_shape_tensor()),
- self.evaluate(x).shape)
- self.assertAllEqual(normal.batch_shape, x.get_shape())
- self.assertAllEqual(normal.batch_shape, self.evaluate(x).shape)
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()), x.get_shape())
+ self.assertAllEqual(
+ self.evaluate(normal.batch_shape_tensor()),
+ self.evaluate(x).shape)
+ self.assertAllEqual(normal.batch_shape, x.get_shape())
+ self.assertAllEqual(normal.batch_shape, self.evaluate(x).shape)
- if not stats:
- return
- expected_x = stats.norm(mu, sigma).ppf(p)
- self.assertAllClose(expected_x, self.evaluate(x), atol=0.)
+ if not stats:
+ return
+ expected_x = stats.norm(mu, sigma).ppf(p)
+ self.assertAllClose(expected_x, self.evaluate(x), atol=0.)
def _baseQuantileFiniteGradientAtDifficultPoints(self, dtype):
g = ops.Graph()
@@ -385,7 +372,7 @@ class NormalTest(test.TestCase):
value = dist.quantile(p)
grads = gradients_impl.gradients(value, [mu, p])
- with self.test_session(graph=g):
+ with self.cached_session(graph=g):
variables.global_variables_initializer().run()
self.assertAllFinite(grads[0])
self.assertAllFinite(grads[1])
@@ -398,61 +385,58 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testNormalVariance(self):
- with self.test_session():
- # sigma will be broadcast to [7, 7, 7]
- mu = [1., 2., 3.]
- sigma = [7.]
+ # sigma will be broadcast to [7, 7, 7]
+ mu = [1., 2., 3.]
+ sigma = [7.]
- normal = normal_lib.Normal(loc=mu, scale=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
- self.assertAllEqual((3,), normal.variance().get_shape())
- self.assertAllEqual([49., 49, 49], self.evaluate(normal.variance()))
+ self.assertAllEqual((3,), normal.variance().get_shape())
+ self.assertAllEqual([49., 49, 49], self.evaluate(normal.variance()))
@test_util.run_in_graph_and_eager_modes
def testNormalStandardDeviation(self):
- with self.test_session():
- # sigma will be broadcast to [7, 7, 7]
- mu = [1., 2., 3.]
- sigma = [7.]
+ # sigma will be broadcast to [7, 7, 7]
+ mu = [1., 2., 3.]
+ sigma = [7.]
- normal = normal_lib.Normal(loc=mu, scale=sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
- self.assertAllEqual((3,), normal.stddev().get_shape())
- self.assertAllEqual([7., 7, 7], self.evaluate(normal.stddev()))
+ self.assertAllEqual((3,), normal.stddev().get_shape())
+ self.assertAllEqual([7., 7, 7], self.evaluate(normal.stddev()))
@test_util.run_in_graph_and_eager_modes
def testNormalSample(self):
- with self.test_session():
- mu = constant_op.constant(3.0)
- sigma = constant_op.constant(math.sqrt(3.0))
- mu_v = 3.0
- sigma_v = np.sqrt(3.0)
- n = constant_op.constant(100000)
- normal = normal_lib.Normal(loc=mu, scale=sigma)
- samples = normal.sample(n)
- sample_values = self.evaluate(samples)
- # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
- # The sample variance similarly is dependent on sigma and n.
- # Thus, the tolerances below are very sensitive to number of samples
- # as well as the variances chosen.
- self.assertEqual(sample_values.shape, (100000,))
- self.assertAllClose(sample_values.mean(), mu_v, atol=1e-1)
- self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1)
-
- expected_samples_shape = tensor_shape.TensorShape(
- [self.evaluate(n)]).concatenate(
- tensor_shape.TensorShape(
- self.evaluate(normal.batch_shape_tensor())))
-
- self.assertAllEqual(expected_samples_shape, samples.get_shape())
- self.assertAllEqual(expected_samples_shape, sample_values.shape)
-
- expected_samples_shape = (
- tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
- normal.batch_shape))
-
- self.assertAllEqual(expected_samples_shape, samples.get_shape())
- self.assertAllEqual(expected_samples_shape, sample_values.shape)
+ mu = constant_op.constant(3.0)
+ sigma = constant_op.constant(math.sqrt(3.0))
+ mu_v = 3.0
+ sigma_v = np.sqrt(3.0)
+ n = constant_op.constant(100000)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
+ samples = normal.sample(n)
+ sample_values = self.evaluate(samples)
+ # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
+ # The sample variance similarly is dependent on sigma and n.
+ # Thus, the tolerances below are very sensitive to number of samples
+ # as well as the variances chosen.
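+ # E.g. sigma = sqrt(3.) and n = 100000 give a standard error of the mean
+ # of about sqrt(3.) / sqrt(100000) ~= 0.0055, well inside atol=1e-1.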
+ self.assertEqual(sample_values.shape, (100000,))
+ self.assertAllClose(sample_values.mean(), mu_v, atol=1e-1)
+ self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1)
+
+ expected_samples_shape = tensor_shape.TensorShape(
+ [self.evaluate(n)]).concatenate(
+ tensor_shape.TensorShape(
+ self.evaluate(normal.batch_shape_tensor())))
+
+ self.assertAllEqual(expected_samples_shape, samples.get_shape())
+ self.assertAllEqual(expected_samples_shape, sample_values.shape)
+
+ expected_samples_shape = (
+ tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
+ normal.batch_shape))
+
+ self.assertAllEqual(expected_samples_shape, samples.get_shape())
+ self.assertAllEqual(expected_samples_shape, sample_values.shape)
def testNormalFullyReparameterized(self):
mu = constant_op.constant(4.0)
@@ -468,66 +452,63 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testNormalSampleMultiDimensional(self):
- with self.test_session():
- batch_size = 2
- mu = constant_op.constant([[3.0, -3.0]] * batch_size)
- sigma = constant_op.constant([[math.sqrt(2.0), math.sqrt(3.0)]] *
- batch_size)
- mu_v = [3.0, -3.0]
- sigma_v = [np.sqrt(2.0), np.sqrt(3.0)]
- n = constant_op.constant(100000)
- normal = normal_lib.Normal(loc=mu, scale=sigma)
- samples = normal.sample(n)
- sample_values = self.evaluate(samples)
- # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
- # The sample variance similarly is dependent on sigma and n.
- # Thus, the tolerances below are very sensitive to number of samples
- # as well as the variances chosen.
- self.assertEqual(samples.get_shape(), (100000, batch_size, 2))
- self.assertAllClose(sample_values[:, 0, 0].mean(), mu_v[0], atol=1e-1)
- self.assertAllClose(sample_values[:, 0, 0].std(), sigma_v[0], atol=1e-1)
- self.assertAllClose(sample_values[:, 0, 1].mean(), mu_v[1], atol=1e-1)
- self.assertAllClose(sample_values[:, 0, 1].std(), sigma_v[1], atol=1e-1)
-
- expected_samples_shape = tensor_shape.TensorShape(
- [self.evaluate(n)]).concatenate(
- tensor_shape.TensorShape(
- self.evaluate(normal.batch_shape_tensor())))
- self.assertAllEqual(expected_samples_shape, samples.get_shape())
- self.assertAllEqual(expected_samples_shape, sample_values.shape)
-
- expected_samples_shape = (
- tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
- normal.batch_shape))
- self.assertAllEqual(expected_samples_shape, samples.get_shape())
- self.assertAllEqual(expected_samples_shape, sample_values.shape)
+ batch_size = 2
+ mu = constant_op.constant([[3.0, -3.0]] * batch_size)
+ sigma = constant_op.constant(
+ [[math.sqrt(2.0), math.sqrt(3.0)]] * batch_size)
+ mu_v = [3.0, -3.0]
+ sigma_v = [np.sqrt(2.0), np.sqrt(3.0)]
+ n = constant_op.constant(100000)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
+ samples = normal.sample(n)
+ sample_values = self.evaluate(samples)
+ # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
+ # The sample variance similarly is dependent on sigma and n.
+ # Thus, the tolerances below are very sensitive to number of samples
+ # as well as the variances chosen.
+ self.assertEqual(samples.get_shape(), (100000, batch_size, 2))
+ self.assertAllClose(sample_values[:, 0, 0].mean(), mu_v[0], atol=1e-1)
+ self.assertAllClose(sample_values[:, 0, 0].std(), sigma_v[0], atol=1e-1)
+ self.assertAllClose(sample_values[:, 0, 1].mean(), mu_v[1], atol=1e-1)
+ self.assertAllClose(sample_values[:, 0, 1].std(), sigma_v[1], atol=1e-1)
+
+ expected_samples_shape = tensor_shape.TensorShape(
+ [self.evaluate(n)]).concatenate(
+ tensor_shape.TensorShape(
+ self.evaluate(normal.batch_shape_tensor())))
+ self.assertAllEqual(expected_samples_shape, samples.get_shape())
+ self.assertAllEqual(expected_samples_shape, sample_values.shape)
+
+ expected_samples_shape = (
+ tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
+ normal.batch_shape))
+ self.assertAllEqual(expected_samples_shape, samples.get_shape())
+ self.assertAllEqual(expected_samples_shape, sample_values.shape)
@test_util.run_in_graph_and_eager_modes
def testNegativeSigmaFails(self):
- with self.test_session():
- with self.assertRaisesOpError("Condition x > 0 did not hold"):
- normal = normal_lib.Normal(
- loc=[1.], scale=[-5.], validate_args=True, name="G")
- self.evaluate(normal.mean())
+ with self.assertRaisesOpError("Condition x > 0 did not hold"):
+ normal = normal_lib.Normal(
+ loc=[1.], scale=[-5.], validate_args=True, name="G")
+ self.evaluate(normal.mean())
@test_util.run_in_graph_and_eager_modes
def testNormalShape(self):
- with self.test_session():
- mu = constant_op.constant([-3.0] * 5)
- sigma = constant_op.constant(11.0)
- normal = normal_lib.Normal(loc=mu, scale=sigma)
+ mu = constant_op.constant([-3.0] * 5)
+ sigma = constant_op.constant(11.0)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
- self.assertEqual(self.evaluate(normal.batch_shape_tensor()), [5])
- self.assertEqual(normal.batch_shape, tensor_shape.TensorShape([5]))
- self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), [])
- self.assertEqual(normal.event_shape, tensor_shape.TensorShape([]))
+ self.assertEqual(self.evaluate(normal.batch_shape_tensor()), [5])
+ self.assertEqual(normal.batch_shape, tensor_shape.TensorShape([5]))
+ self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), [])
+ self.assertEqual(normal.event_shape, tensor_shape.TensorShape([]))
def testNormalShapeWithPlaceholders(self):
mu = array_ops.placeholder(dtype=dtypes.float32)
sigma = array_ops.placeholder(dtype=dtypes.float32)
normal = normal_lib.Normal(loc=mu, scale=sigma)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# get_batch_shape should return an "<unknown>" tensor.
self.assertEqual(normal.batch_shape, tensor_shape.TensorShape(None))
self.assertEqual(normal.event_shape, ())
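Note: two of the normal_test.py hunks above also tighten the decorator to
test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True). The
decorator runs the test body once under a graph and once eagerly; the extra
flag additionally asserts that the eager run creates no uncollectable Python
garbage. A sketch with an illustrative body:

# Sketch only, not part of the diff.
from tensorflow.python.framework import test_util
from tensorflow.python.ops.distributions import normal as normal_lib

@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testMeanSketch(self):
  normal = normal_lib.Normal(loc=[7.], scale=[11., 12., 13.])  # loc broadcasts
  self.assertAllEqual([7., 7., 7.], self.evaluate(normal.mean()))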
diff --git a/tensorflow/python/kernel_tests/distributions/special_math_test.py b/tensorflow/python/kernel_tests/distributions/special_math_test.py
index a634194ce5..cc43e12168 100644
--- a/tensorflow/python/kernel_tests/distributions/special_math_test.py
+++ b/tensorflow/python/kernel_tests/distributions/special_math_test.py
@@ -92,22 +92,21 @@ class NdtriTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testNdtri(self):
"""Verifies that ndtri computation is correct."""
- with self.test_session():
- if not special:
- return
+ if not special:
+ return
- p = np.linspace(0., 1.0, 50).astype(np.float64)
- # Quantile performs piecewise rational approximation so adding some
- # special input values to make sure we hit all the pieces.
- p = np.hstack((p, np.exp(-32), 1. - np.exp(-32),
- np.exp(-2), 1. - np.exp(-2)))
- expected_x = special.ndtri(p)
- x = special_math.ndtri(p)
- self.assertAllClose(expected_x, self.evaluate(x), atol=0.)
+ p = np.linspace(0., 1.0, 50).astype(np.float64)
+ # Quantile performs a piecewise rational approximation, so add some
+ # special input values to make sure we hit all the pieces.
+ p = np.hstack((p, np.exp(-32), 1. - np.exp(-32), np.exp(-2),
+ 1. - np.exp(-2)))
+ expected_x = special.ndtri(p)
+ x = special_math.ndtri(p)
+ self.assertAllClose(expected_x, self.evaluate(x), atol=0.)
def testNdtriDynamicShape(self):
"""Verifies that ndtri computation is correct."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if not special:
return
@@ -286,7 +285,7 @@ class NdtrGradientTest(test.TestCase):
def _test_grad_accuracy(self, dtype, grid_spec, error_spec):
raw_grid = _make_grid(dtype, grid_spec)
grid = ops.convert_to_tensor(raw_grid)
- with self.test_session():
+ with self.cached_session():
fn = sm.log_ndtr if self._use_log else sm.ndtr
# If there are N points in the grid,
@@ -355,7 +354,7 @@ class LogNdtrGradientTest(NdtrGradientTest):
class ErfInvTest(test.TestCase):
def testErfInvValues(self):
- with self.test_session():
+ with self.cached_session():
if not special:
return
@@ -366,7 +365,7 @@ class ErfInvTest(test.TestCase):
self.assertAllClose(expected_x, x.eval(), atol=0.)
def testErfInvIntegerInput(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
x = np.array([1, 2, 3]).astype(np.int32)
@@ -397,7 +396,7 @@ class LogCDFLaplaceTest(test.TestCase):
self.assertAllEqual(np.ones_like(x, dtype=np.bool), x)
def _test_grid_log(self, dtype, scipy_dtype, grid_spec, error_spec):
- with self.test_session():
+ with self.cached_session():
grid = _make_grid(dtype, grid_spec)
actual = sm.log_cdf_laplace(grid).eval()
@@ -439,7 +438,7 @@ class LogCDFLaplaceTest(test.TestCase):
ErrorSpec(rtol=0.05, atol=0))
def test_float32_extreme_values_result_and_gradient_finite_and_nonzero(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# On the lower branch, log_cdf_laplace(x) = x, so we know this will be
# fine, but test to -200 anyways.
grid = _make_grid(
@@ -458,7 +457,7 @@ class LogCDFLaplaceTest(test.TestCase):
self.assertFalse(np.any(grad_ == 0))
def test_float64_extreme_values_result_and_gradient_finite_and_nonzero(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# On the lower branch, log_cdf_laplace(x) = x, so we know this will be
# fine, but test to -200 anyways.
grid = _make_grid(
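Note: throughout these files the SciPy cross-checks sit behind guards such as
"if not special: return" and "if not stats: return"; try_import (the helper
defined at the top of these test modules) returns None when the import fails,
so the shape and finiteness assertions still run on builds without SciPy. A
sketch of the guard; the Laplace variance check is illustrative:

# Sketch only, not part of the diff; try_import as in the modules above.
from tensorflow.python.ops.distributions import laplace as laplace_lib

stats = try_import("scipy.stats")

def testVarianceSketch(self):
  laplace = laplace_lib.Laplace(loc=0., scale=2.)
  var = self.evaluate(laplace.variance())  # always checked
  if not stats:  # SciPy unavailable; skip the numeric cross-check
    return
  self.assertAllClose(stats.laplace.var(0., scale=2.), var)  # 2 * b**2 = 8.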
diff --git a/tensorflow/python/kernel_tests/distributions/student_t_test.py b/tensorflow/python/kernel_tests/distributions/student_t_test.py
index 05590542ef..b34b538160 100644
--- a/tensorflow/python/kernel_tests/distributions/student_t_test.py
+++ b/tensorflow/python/kernel_tests/distributions/student_t_test.py
@@ -50,100 +50,96 @@ stats = try_import("scipy.stats")
class StudentTTest(test.TestCase):
def testStudentPDFAndLogPDF(self):
- with self.test_session():
- batch_size = 6
- df = constant_op.constant([3.] * batch_size)
- mu = constant_op.constant([7.] * batch_size)
- sigma = constant_op.constant([8.] * batch_size)
- df_v = 3.
- mu_v = 7.
- sigma_v = 8.
- t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
- student = student_t.StudentT(df, loc=mu, scale=-sigma)
-
- log_pdf = student.log_prob(t)
- self.assertEquals(log_pdf.get_shape(), (6,))
- log_pdf_values = self.evaluate(log_pdf)
- pdf = student.prob(t)
- self.assertEquals(pdf.get_shape(), (6,))
- pdf_values = self.evaluate(pdf)
-
- if not stats:
- return
-
- expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
- expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
- self.assertAllClose(expected_log_pdf, log_pdf_values)
- self.assertAllClose(np.log(expected_pdf), log_pdf_values)
- self.assertAllClose(expected_pdf, pdf_values)
- self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
+ batch_size = 6
+ df = constant_op.constant([3.] * batch_size)
+ mu = constant_op.constant([7.] * batch_size)
+ sigma = constant_op.constant([8.] * batch_size)
+ df_v = 3.
+ mu_v = 7.
+ sigma_v = 8.
+ t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
+ student = student_t.StudentT(df, loc=mu, scale=-sigma)
+
+ log_pdf = student.log_prob(t)
+ self.assertEqual(log_pdf.get_shape(), (6,))
+ log_pdf_values = self.evaluate(log_pdf)
+ pdf = student.prob(t)
+ self.assertEqual(pdf.get_shape(), (6,))
+ pdf_values = self.evaluate(pdf)
+
+ if not stats:
+ return
+
+ expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
+ expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
+ self.assertAllClose(expected_log_pdf, log_pdf_values)
+ self.assertAllClose(np.log(expected_pdf), log_pdf_values)
+ self.assertAllClose(expected_pdf, pdf_values)
+ self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
def testStudentLogPDFMultidimensional(self):
- with self.test_session():
- batch_size = 6
- df = constant_op.constant([[1.5, 7.2]] * batch_size)
- mu = constant_op.constant([[3., -3.]] * batch_size)
- sigma = constant_op.constant([[-math.sqrt(10.), math.sqrt(15.)]] *
- batch_size)
- df_v = np.array([1.5, 7.2])
- mu_v = np.array([3., -3.])
- sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)])
- t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T
- student = student_t.StudentT(df, loc=mu, scale=sigma)
- log_pdf = student.log_prob(t)
- log_pdf_values = self.evaluate(log_pdf)
- self.assertEqual(log_pdf.get_shape(), (6, 2))
- pdf = student.prob(t)
- pdf_values = self.evaluate(pdf)
- self.assertEqual(pdf.get_shape(), (6, 2))
-
- if not stats:
- return
- expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
- expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
- self.assertAllClose(expected_log_pdf, log_pdf_values)
- self.assertAllClose(np.log(expected_pdf), log_pdf_values)
- self.assertAllClose(expected_pdf, pdf_values)
- self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
+ batch_size = 6
+ df = constant_op.constant([[1.5, 7.2]] * batch_size)
+ mu = constant_op.constant([[3., -3.]] * batch_size)
+ sigma = constant_op.constant(
+ [[-math.sqrt(10.), math.sqrt(15.)]] * batch_size)
+ df_v = np.array([1.5, 7.2])
+ mu_v = np.array([3., -3.])
+ sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)])
+ t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T
+ student = student_t.StudentT(df, loc=mu, scale=sigma)
+ log_pdf = student.log_prob(t)
+ log_pdf_values = self.evaluate(log_pdf)
+ self.assertEqual(log_pdf.get_shape(), (6, 2))
+ pdf = student.prob(t)
+ pdf_values = self.evaluate(pdf)
+ self.assertEqual(pdf.get_shape(), (6, 2))
+
+ if not stats:
+ return
+ expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
+ expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
+ self.assertAllClose(expected_log_pdf, log_pdf_values)
+ self.assertAllClose(np.log(expected_pdf), log_pdf_values)
+ self.assertAllClose(expected_pdf, pdf_values)
+ self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
def testStudentCDFAndLogCDF(self):
- with self.test_session():
- batch_size = 6
- df = constant_op.constant([3.] * batch_size)
- mu = constant_op.constant([7.] * batch_size)
- sigma = constant_op.constant([-8.] * batch_size)
- df_v = 3.
- mu_v = 7.
- sigma_v = 8.
- t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
- student = student_t.StudentT(df, loc=mu, scale=sigma)
-
- log_cdf = student.log_cdf(t)
- self.assertEquals(log_cdf.get_shape(), (6,))
- log_cdf_values = self.evaluate(log_cdf)
- cdf = student.cdf(t)
- self.assertEquals(cdf.get_shape(), (6,))
- cdf_values = self.evaluate(cdf)
-
- if not stats:
- return
- expected_log_cdf = stats.t.logcdf(t, df_v, loc=mu_v, scale=sigma_v)
- expected_cdf = stats.t.cdf(t, df_v, loc=mu_v, scale=sigma_v)
- self.assertAllClose(expected_log_cdf, log_cdf_values, atol=0., rtol=1e-5)
- self.assertAllClose(
- np.log(expected_cdf), log_cdf_values, atol=0., rtol=1e-5)
- self.assertAllClose(expected_cdf, cdf_values, atol=0., rtol=1e-5)
- self.assertAllClose(
- np.exp(expected_log_cdf), cdf_values, atol=0., rtol=1e-5)
+ batch_size = 6
+ df = constant_op.constant([3.] * batch_size)
+ mu = constant_op.constant([7.] * batch_size)
+ sigma = constant_op.constant([-8.] * batch_size)
+ df_v = 3.
+ mu_v = 7.
+ sigma_v = 8.
+ t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
+ student = student_t.StudentT(df, loc=mu, scale=sigma)
+
+ log_cdf = student.log_cdf(t)
+ self.assertEqual(log_cdf.get_shape(), (6,))
+ log_cdf_values = self.evaluate(log_cdf)
+ cdf = student.cdf(t)
+ self.assertEqual(cdf.get_shape(), (6,))
+ cdf_values = self.evaluate(cdf)
+
+ if not stats:
+ return
+ expected_log_cdf = stats.t.logcdf(t, df_v, loc=mu_v, scale=sigma_v)
+ expected_cdf = stats.t.cdf(t, df_v, loc=mu_v, scale=sigma_v)
+ self.assertAllClose(expected_log_cdf, log_cdf_values, atol=0., rtol=1e-5)
+ self.assertAllClose(
+ np.log(expected_cdf), log_cdf_values, atol=0., rtol=1e-5)
+ self.assertAllClose(expected_cdf, cdf_values, atol=0., rtol=1e-5)
+ self.assertAllClose(
+ np.exp(expected_log_cdf), cdf_values, atol=0., rtol=1e-5)
def testStudentEntropy(self):
df_v = np.array([[2., 3., 7.]]) # 1x3
mu_v = np.array([[1., -1, 0]]) # 1x3
sigma_v = np.array([[1., -2., 3.]]).T # transposed => 3x1
- with self.test_session():
- student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v)
- ent = student.entropy()
- ent_values = self.evaluate(ent)
+ student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v)
+ ent = student.entropy()
+ ent_values = self.evaluate(ent)
# Help scipy broadcast to 3x3
ones = np.array([[1, 1, 1]])
@@ -160,90 +156,81 @@ class StudentTTest(test.TestCase):
self.assertAllClose(expected_entropy, ent_values)
def testStudentSample(self):
- with self.test_session():
- df = constant_op.constant(4.)
- mu = constant_op.constant(3.)
- sigma = constant_op.constant(-math.sqrt(10.))
- df_v = 4.
- mu_v = 3.
- sigma_v = np.sqrt(10.)
- n = constant_op.constant(200000)
- student = student_t.StudentT(df=df, loc=mu, scale=sigma)
- samples = student.sample(n, seed=123456)
- sample_values = self.evaluate(samples)
- n_val = 200000
- self.assertEqual(sample_values.shape, (n_val,))
- self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0)
- self.assertAllClose(
- sample_values.var(),
- sigma_v**2 * df_v / (df_v - 2),
- rtol=0.1,
- atol=0)
- self._checkKLApprox(df_v, mu_v, sigma_v, sample_values)
+ df = constant_op.constant(4.)
+ mu = constant_op.constant(3.)
+ sigma = constant_op.constant(-math.sqrt(10.))
+ df_v = 4.
+ mu_v = 3.
+ sigma_v = np.sqrt(10.)
+ n = constant_op.constant(200000)
+ student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+ samples = student.sample(n, seed=123456)
+ sample_values = self.evaluate(samples)
+ n_val = 200000
+ self.assertEqual(sample_values.shape, (n_val,))
+ self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0)
+ self.assertAllClose(
+ sample_values.var(), sigma_v**2 * df_v / (df_v - 2), rtol=0.1, atol=0)
+ self._checkKLApprox(df_v, mu_v, sigma_v, sample_values)
# Test that sampling with the same seed twice gives the same results.
def testStudentSampleMultipleTimes(self):
- with self.test_session():
- df = constant_op.constant(4.)
- mu = constant_op.constant(3.)
- sigma = constant_op.constant(math.sqrt(10.))
- n = constant_op.constant(100)
+ df = constant_op.constant(4.)
+ mu = constant_op.constant(3.)
+ sigma = constant_op.constant(math.sqrt(10.))
+ n = constant_op.constant(100)
- random_seed.set_random_seed(654321)
- student = student_t.StudentT(
- df=df, loc=mu, scale=sigma, name="student_t1")
- samples1 = self.evaluate(student.sample(n, seed=123456))
+ random_seed.set_random_seed(654321)
+ student = student_t.StudentT(df=df, loc=mu, scale=sigma, name="student_t1")
+ samples1 = self.evaluate(student.sample(n, seed=123456))
- random_seed.set_random_seed(654321)
- student2 = student_t.StudentT(
- df=df, loc=mu, scale=sigma, name="student_t2")
- samples2 = self.evaluate(student2.sample(n, seed=123456))
+ random_seed.set_random_seed(654321)
+ student2 = student_t.StudentT(df=df, loc=mu, scale=sigma, name="student_t2")
+ samples2 = self.evaluate(student2.sample(n, seed=123456))
- self.assertAllClose(samples1, samples2)
+ self.assertAllClose(samples1, samples2)
def testStudentSampleSmallDfNoNan(self):
- with self.test_session():
- df_v = [1e-1, 1e-5, 1e-10, 1e-20]
- df = constant_op.constant(df_v)
- n = constant_op.constant(200000)
- student = student_t.StudentT(df=df, loc=1., scale=1.)
- samples = student.sample(n, seed=123456)
- sample_values = self.evaluate(samples)
- n_val = 200000
- self.assertEqual(sample_values.shape, (n_val, 4))
- self.assertTrue(np.all(np.logical_not(np.isnan(sample_values))))
+ df_v = [1e-1, 1e-5, 1e-10, 1e-20]
+ df = constant_op.constant(df_v)
+ n = constant_op.constant(200000)
+ student = student_t.StudentT(df=df, loc=1., scale=1.)
+ samples = student.sample(n, seed=123456)
+ sample_values = self.evaluate(samples)
+ n_val = 200000
+ self.assertEqual(sample_values.shape, (n_val, 4))
+ self.assertTrue(np.all(np.logical_not(np.isnan(sample_values))))
def testStudentSampleMultiDimensional(self):
- with self.test_session():
- batch_size = 7
- df = constant_op.constant([[5., 7.]] * batch_size)
- mu = constant_op.constant([[3., -3.]] * batch_size)
- sigma = constant_op.constant([[math.sqrt(10.), math.sqrt(15.)]] *
- batch_size)
- df_v = [5., 7.]
- mu_v = [3., -3.]
- sigma_v = [np.sqrt(10.), np.sqrt(15.)]
- n = constant_op.constant(200000)
- student = student_t.StudentT(df=df, loc=mu, scale=sigma)
- samples = student.sample(n, seed=123456)
- sample_values = self.evaluate(samples)
- self.assertEqual(samples.get_shape(), (200000, batch_size, 2))
- self.assertAllClose(
- sample_values[:, 0, 0].mean(), mu_v[0], rtol=0.1, atol=0)
- self.assertAllClose(
- sample_values[:, 0, 0].var(),
- sigma_v[0]**2 * df_v[0] / (df_v[0] - 2),
- rtol=0.2,
- atol=0)
- self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0])
- self.assertAllClose(
- sample_values[:, 0, 1].mean(), mu_v[1], rtol=0.1, atol=0)
- self.assertAllClose(
- sample_values[:, 0, 1].var(),
- sigma_v[1]**2 * df_v[1] / (df_v[1] - 2),
- rtol=0.2,
- atol=0)
- self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1])
+ batch_size = 7
+ df = constant_op.constant([[5., 7.]] * batch_size)
+ mu = constant_op.constant([[3., -3.]] * batch_size)
+ sigma = constant_op.constant(
+ [[math.sqrt(10.), math.sqrt(15.)]] * batch_size)
+ df_v = [5., 7.]
+ mu_v = [3., -3.]
+ sigma_v = [np.sqrt(10.), np.sqrt(15.)]
+ n = constant_op.constant(200000)
+ student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+ samples = student.sample(n, seed=123456)
+ sample_values = self.evaluate(samples)
+ self.assertEqual(samples.get_shape(), (200000, batch_size, 2))
+ self.assertAllClose(
+ sample_values[:, 0, 0].mean(), mu_v[0], rtol=0.1, atol=0)
+ self.assertAllClose(
+ sample_values[:, 0, 0].var(),
+ sigma_v[0]**2 * df_v[0] / (df_v[0] - 2),
+ rtol=0.2,
+ atol=0)
+ self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0])
+ self.assertAllClose(
+ sample_values[:, 0, 1].mean(), mu_v[1], rtol=0.1, atol=0)
+ self.assertAllClose(
+ sample_values[:, 0, 1].var(),
+ sigma_v[1]**2 * df_v[1] / (df_v[1] - 2),
+ rtol=0.2,
+ atol=0)
+ self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1])
def _checkKLApprox(self, df, mu, sigma, samples):
n = samples.size
@@ -325,114 +312,102 @@ class StudentTTest(test.TestCase):
_check2d_rows(student_t.StudentT(df=7., loc=3., scale=[[2.], [3.], [4.]]))
def testMeanAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
- with self.test_session():
- mu = [1., 3.3, 4.4]
- student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.])
- mean = self.evaluate(student.mean())
- self.assertAllClose([1., 3.3, 4.4], mean)
+ mu = [1., 3.3, 4.4]
+ student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.])
+ mean = self.evaluate(student.mean())
+ self.assertAllClose([1., 3.3, 4.4], mean)
def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self):
- with self.test_session():
- mu = [1., 3.3, 4.4]
- student = student_t.StudentT(
- df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.],
- allow_nan_stats=False)
- with self.assertRaisesOpError("x < y"):
- self.evaluate(student.mean())
+ mu = [1., 3.3, 4.4]
+ student = student_t.StudentT(
+ df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.], allow_nan_stats=False)
+ with self.assertRaisesOpError("x < y"):
+ self.evaluate(student.mean())
def testMeanAllowNanStatsIsTrueReturnsNaNForUndefinedBatchMembers(self):
- with self.test_session():
- mu = [-2, 0., 1., 3.3, 4.4]
- sigma = [5., 4., 3., 2., 1.]
- student = student_t.StudentT(
- df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma,
- allow_nan_stats=True)
- mean = self.evaluate(student.mean())
- self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean)
+ mu = [-2, 0., 1., 3.3, 4.4]
+ sigma = [5., 4., 3., 2., 1.]
+ student = student_t.StudentT(
+ df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma, allow_nan_stats=True)
+ mean = self.evaluate(student.mean())
+ self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean)
def testVarianceAllowNanStatsTrueReturnsNaNforUndefinedBatchMembers(self):
- with self.test_session():
- # df = 0.5 ==> undefined mean ==> undefined variance.
- # df = 1.5 ==> infinite variance.
- df = [0.5, 1.5, 3., 5., 7.]
- mu = [-2, 0., 1., 3.3, 4.4]
- sigma = [5., 4., 3., 2., 1.]
- student = student_t.StudentT(
- df=df, loc=mu, scale=sigma, allow_nan_stats=True)
- var = self.evaluate(student.variance())
- ## scipy uses inf for variance when the mean is undefined. When mean is
- # undefined we say variance is undefined as well. So test the first
- # member of var, making sure it is NaN, then replace with inf and compare
- # to scipy.
- self.assertTrue(np.isnan(var[0]))
- var[0] = np.inf
-
- if not stats:
- return
- expected_var = [
- stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
- ]
- self.assertAllClose(expected_var, var)
+ # df = 0.5 ==> undefined mean ==> undefined variance.
+ # df = 1.5 ==> infinite variance.
+ df = [0.5, 1.5, 3., 5., 7.]
+ mu = [-2, 0., 1., 3.3, 4.4]
+ sigma = [5., 4., 3., 2., 1.]
+ student = student_t.StudentT(
+ df=df, loc=mu, scale=sigma, allow_nan_stats=True)
+ var = self.evaluate(student.variance())
+ # SciPy uses inf for the variance when the mean is undefined. When the
+ # mean is undefined, we treat the variance as undefined as well. So test
+ # the first member of var, making sure it is NaN, then replace it with
+ # inf and compare to SciPy.
+ self.assertTrue(np.isnan(var[0]))
+ var[0] = np.inf
+
+ if not stats:
+ return
+ expected_var = [
+ stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
+ ]
+ self.assertAllClose(expected_var, var)
def testVarianceAllowNanStatsFalseGivesCorrectValueForDefinedBatchMembers(
self):
- with self.test_session():
- # df = 1.5 ==> infinite variance.
- df = [1.5, 3., 5., 7.]
- mu = [0., 1., 3.3, 4.4]
- sigma = [4., 3., 2., 1.]
- student = student_t.StudentT(df=df, loc=mu, scale=sigma)
- var = self.evaluate(student.variance())
+ # df = 1.5 ==> infinite variance.
+ df = [1.5, 3., 5., 7.]
+ mu = [0., 1., 3.3, 4.4]
+ sigma = [4., 3., 2., 1.]
+ student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+ var = self.evaluate(student.variance())
- if not stats:
- return
- expected_var = [
- stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
- ]
- self.assertAllClose(expected_var, var)
+ if not stats:
+ return
+ expected_var = [
+ stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
+ ]
+ self.assertAllClose(expected_var, var)
def testVarianceAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
- with self.test_session():
- # df <= 1 ==> variance not defined
- student = student_t.StudentT(
- df=1., loc=0., scale=1., allow_nan_stats=False)
- with self.assertRaisesOpError("x < y"):
- self.evaluate(student.variance())
+ # df <= 1 ==> variance not defined
+ student = student_t.StudentT(df=1., loc=0., scale=1., allow_nan_stats=False)
+ with self.assertRaisesOpError("x < y"):
+ self.evaluate(student.variance())
- with self.test_session():
- # df <= 1 ==> variance not defined
- student = student_t.StudentT(
- df=0.5, loc=0., scale=1., allow_nan_stats=False)
- with self.assertRaisesOpError("x < y"):
- self.evaluate(student.variance())
+ # df <= 1 ==> variance not defined
+ student = student_t.StudentT(
+ df=0.5, loc=0., scale=1., allow_nan_stats=False)
+ with self.assertRaisesOpError("x < y"):
+ self.evaluate(student.variance())
def testStd(self):
- with self.test_session():
- # Defined for all batch members.
- df = [3.5, 5., 3., 5., 7.]
- mu = [-2.2]
- sigma = [5., 4., 3., 2., 1.]
- student = student_t.StudentT(df=df, loc=mu, scale=sigma)
- # Test broadcast of mu across shape of df/sigma
- stddev = self.evaluate(student.stddev())
- mu *= len(df)
+ # Defined for all batch members.
+ df = [3.5, 5., 3., 5., 7.]
+ mu = [-2.2]
+ sigma = [5., 4., 3., 2., 1.]
+ student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+ # Test broadcast of mu across shape of df/sigma
+ stddev = self.evaluate(student.stddev())
+ mu *= len(df)
- if not stats:
- return
- expected_stddev = [
- stats.t.std(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
- ]
- self.assertAllClose(expected_stddev, stddev)
+ if not stats:
+ return
+ expected_stddev = [
+ stats.t.std(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
+ ]
+ self.assertAllClose(expected_stddev, stddev)
def testMode(self):
- with self.test_session():
- df = [0.5, 1., 3]
- mu = [-1, 0., 1]
- sigma = [5., 4., 3.]
- student = student_t.StudentT(df=df, loc=mu, scale=sigma)
- # Test broadcast of mu across shape of df/sigma
- mode = self.evaluate(student.mode())
- self.assertAllClose([-1., 0, 1], mode)
+ df = [0.5, 1., 3]
+ mu = [-1, 0., 1]
+ sigma = [5., 4., 3.]
+ student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+ # Test broadcast of mu across shape of df/sigma
+ mode = self.evaluate(student.mode())
+ self.assertAllClose([-1., 0, 1], mode)
def testPdfOfSample(self):
student = student_t.StudentT(df=3., loc=np.pi, scale=1.)
@@ -510,25 +485,23 @@ class StudentTTest(test.TestCase):
self.assertNear(1., total, err=err)
def testNegativeDofFails(self):
- with self.test_session():
- with self.assertRaisesOpError(r"Condition x > 0 did not hold"):
- student = student_t.StudentT(
- df=[2, -5.], loc=0., scale=1., validate_args=True, name="S")
- self.evaluate(student.mean())
+ with self.assertRaisesOpError(r"Condition x > 0 did not hold"):
+ student = student_t.StudentT(
+ df=[2, -5.], loc=0., scale=1., validate_args=True, name="S")
+ self.evaluate(student.mean())
def testStudentTWithAbsDfSoftplusScale(self):
- with self.test_session():
- df = constant_op.constant([-3.2, -4.6])
- mu = constant_op.constant([-4.2, 3.4])
- sigma = constant_op.constant([-6.4, -8.8])
- student = student_t.StudentTWithAbsDfSoftplusScale(
- df=df, loc=mu, scale=sigma)
- self.assertAllClose(
- math_ops.floor(self.evaluate(math_ops.abs(df))),
- self.evaluate(student.df))
- self.assertAllClose(self.evaluate(mu), self.evaluate(student.loc))
- self.assertAllClose(
- self.evaluate(nn_ops.softplus(sigma)), self.evaluate(student.scale))
+ df = constant_op.constant([-3.2, -4.6])
+ mu = constant_op.constant([-4.2, 3.4])
+ sigma = constant_op.constant([-6.4, -8.8])
+ student = student_t.StudentTWithAbsDfSoftplusScale(
+ df=df, loc=mu, scale=sigma)
+ self.assertAllClose(
+ math_ops.floor(self.evaluate(math_ops.abs(df))),
+ self.evaluate(student.df))
+ self.assertAllClose(self.evaluate(mu), self.evaluate(student.loc))
+ self.assertAllClose(
+ self.evaluate(nn_ops.softplus(sigma)), self.evaluate(student.scale))
if __name__ == "__main__":
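The hunks above all apply one pattern: tests decorated with @test_util.run_in_graph_and_eager_modes drop their "with self.test_session():" wrapper, because self.evaluate() resolves tensors in both graph and eager mode and needs no session scope. A minimal before/after sketch of the pattern (the test body here is illustrative, not taken from this diff):

    # Before: graph-only style, tied to an implicit session.
    def testMean(self):
      with self.test_session():
        student = student_t.StudentT(df=3., loc=1., scale=2.)
        self.assertAllClose(1., student.mean().eval())

    # After: self.evaluate() works under graph and eager execution alike,
    # so the session scope disappears.
    @test_util.run_in_graph_and_eager_modes
    def testMean(self):
      student = student_t.StudentT(df=3., loc=1., scale=2.)
      self.assertAllClose(1., self.evaluate(student.mean()))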
diff --git a/tensorflow/python/kernel_tests/distributions/uniform_test.py b/tensorflow/python/kernel_tests/distributions/uniform_test.py
index bc9c267b9a..9cdcd369c1 100644
--- a/tensorflow/python/kernel_tests/distributions/uniform_test.py
+++ b/tensorflow/python/kernel_tests/distributions/uniform_test.py
@@ -50,255 +50,239 @@ class UniformTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testUniformRange(self):
- with self.test_session():
- a = 3.0
- b = 10.0
- uniform = uniform_lib.Uniform(low=a, high=b)
- self.assertAllClose(a, self.evaluate(uniform.low))
- self.assertAllClose(b, self.evaluate(uniform.high))
- self.assertAllClose(b - a, self.evaluate(uniform.range()))
+ a = 3.0
+ b = 10.0
+ uniform = uniform_lib.Uniform(low=a, high=b)
+ self.assertAllClose(a, self.evaluate(uniform.low))
+ self.assertAllClose(b, self.evaluate(uniform.high))
+ self.assertAllClose(b - a, self.evaluate(uniform.range()))
@test_util.run_in_graph_and_eager_modes
def testUniformPDF(self):
- with self.test_session():
- a = constant_op.constant([-3.0] * 5 + [15.0])
- b = constant_op.constant([11.0] * 5 + [20.0])
- uniform = uniform_lib.Uniform(low=a, high=b)
+ a = constant_op.constant([-3.0] * 5 + [15.0])
+ b = constant_op.constant([11.0] * 5 + [20.0])
+ uniform = uniform_lib.Uniform(low=a, high=b)
- a_v = -3.0
- b_v = 11.0
- x = np.array([-10.5, 4.0, 0.0, 10.99, 11.3, 17.0], dtype=np.float32)
+ a_v = -3.0
+ b_v = 11.0
+ x = np.array([-10.5, 4.0, 0.0, 10.99, 11.3, 17.0], dtype=np.float32)
- def _expected_pdf():
- pdf = np.zeros_like(x) + 1.0 / (b_v - a_v)
- pdf[x > b_v] = 0.0
- pdf[x < a_v] = 0.0
- pdf[5] = 1.0 / (20.0 - 15.0)
- return pdf
+ def _expected_pdf():
+ pdf = np.zeros_like(x) + 1.0 / (b_v - a_v)
+ pdf[x > b_v] = 0.0
+ pdf[x < a_v] = 0.0
+ pdf[5] = 1.0 / (20.0 - 15.0)
+ return pdf
- expected_pdf = _expected_pdf()
+ expected_pdf = _expected_pdf()
- pdf = uniform.prob(x)
- self.assertAllClose(expected_pdf, self.evaluate(pdf))
+ pdf = uniform.prob(x)
+ self.assertAllClose(expected_pdf, self.evaluate(pdf))
- log_pdf = uniform.log_prob(x)
- self.assertAllClose(np.log(expected_pdf), self.evaluate(log_pdf))
+ log_pdf = uniform.log_prob(x)
+ self.assertAllClose(np.log(expected_pdf), self.evaluate(log_pdf))
@test_util.run_in_graph_and_eager_modes
def testUniformShape(self):
- with self.test_session():
- a = constant_op.constant([-3.0] * 5)
- b = constant_op.constant(11.0)
- uniform = uniform_lib.Uniform(low=a, high=b)
+ a = constant_op.constant([-3.0] * 5)
+ b = constant_op.constant(11.0)
+ uniform = uniform_lib.Uniform(low=a, high=b)
- self.assertEqual(self.evaluate(uniform.batch_shape_tensor()), (5,))
- self.assertEqual(uniform.batch_shape, tensor_shape.TensorShape([5]))
- self.assertAllEqual(self.evaluate(uniform.event_shape_tensor()), [])
- self.assertEqual(uniform.event_shape, tensor_shape.TensorShape([]))
+ self.assertEqual(self.evaluate(uniform.batch_shape_tensor()), (5,))
+ self.assertEqual(uniform.batch_shape, tensor_shape.TensorShape([5]))
+ self.assertAllEqual(self.evaluate(uniform.event_shape_tensor()), [])
+ self.assertEqual(uniform.event_shape, tensor_shape.TensorShape([]))
@test_util.run_in_graph_and_eager_modes
def testUniformPDFWithScalarEndpoint(self):
- with self.test_session():
- a = constant_op.constant([0.0, 5.0])
- b = constant_op.constant(10.0)
- uniform = uniform_lib.Uniform(low=a, high=b)
+ a = constant_op.constant([0.0, 5.0])
+ b = constant_op.constant(10.0)
+ uniform = uniform_lib.Uniform(low=a, high=b)
- x = np.array([0.0, 8.0], dtype=np.float32)
- expected_pdf = np.array([1.0 / (10.0 - 0.0), 1.0 / (10.0 - 5.0)])
+ x = np.array([0.0, 8.0], dtype=np.float32)
+ expected_pdf = np.array([1.0 / (10.0 - 0.0), 1.0 / (10.0 - 5.0)])
- pdf = uniform.prob(x)
- self.assertAllClose(expected_pdf, self.evaluate(pdf))
+ pdf = uniform.prob(x)
+ self.assertAllClose(expected_pdf, self.evaluate(pdf))
@test_util.run_in_graph_and_eager_modes
def testUniformCDF(self):
- with self.test_session():
- batch_size = 6
- a = constant_op.constant([1.0] * batch_size)
- b = constant_op.constant([11.0] * batch_size)
- a_v = 1.0
- b_v = 11.0
- x = np.array([-2.5, 2.5, 4.0, 0.0, 10.99, 12.0], dtype=np.float32)
+ batch_size = 6
+ a = constant_op.constant([1.0] * batch_size)
+ b = constant_op.constant([11.0] * batch_size)
+ a_v = 1.0
+ b_v = 11.0
+ x = np.array([-2.5, 2.5, 4.0, 0.0, 10.99, 12.0], dtype=np.float32)
- uniform = uniform_lib.Uniform(low=a, high=b)
+ uniform = uniform_lib.Uniform(low=a, high=b)
- def _expected_cdf():
- cdf = (x - a_v) / (b_v - a_v)
- cdf[x >= b_v] = 1
- cdf[x < a_v] = 0
- return cdf
+ def _expected_cdf():
+ cdf = (x - a_v) / (b_v - a_v)
+ cdf[x >= b_v] = 1
+ cdf[x < a_v] = 0
+ return cdf
- cdf = uniform.cdf(x)
- self.assertAllClose(_expected_cdf(), self.evaluate(cdf))
+ cdf = uniform.cdf(x)
+ self.assertAllClose(_expected_cdf(), self.evaluate(cdf))
- log_cdf = uniform.log_cdf(x)
- self.assertAllClose(np.log(_expected_cdf()), self.evaluate(log_cdf))
+ log_cdf = uniform.log_cdf(x)
+ self.assertAllClose(np.log(_expected_cdf()), self.evaluate(log_cdf))
@test_util.run_in_graph_and_eager_modes
def testUniformEntropy(self):
- with self.test_session():
- a_v = np.array([1.0, 1.0, 1.0])
- b_v = np.array([[1.5, 2.0, 3.0]])
- uniform = uniform_lib.Uniform(low=a_v, high=b_v)
+ a_v = np.array([1.0, 1.0, 1.0])
+ b_v = np.array([[1.5, 2.0, 3.0]])
+ uniform = uniform_lib.Uniform(low=a_v, high=b_v)
- expected_entropy = np.log(b_v - a_v)
- self.assertAllClose(expected_entropy, self.evaluate(uniform.entropy()))
+ expected_entropy = np.log(b_v - a_v)
+ self.assertAllClose(expected_entropy, self.evaluate(uniform.entropy()))
@test_util.run_in_graph_and_eager_modes
def testUniformAssertMaxGtMin(self):
- with self.test_session():
- a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32)
- b_v = np.array([1.0, 2.0, 3.0], dtype=np.float32)
+ a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32)
+ b_v = np.array([1.0, 2.0, 3.0], dtype=np.float32)
- with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- "x < y"):
- uniform = uniform_lib.Uniform(low=a_v, high=b_v, validate_args=True)
- self.evaluate(uniform.low)
+ with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ "x < y"):
+ uniform = uniform_lib.Uniform(low=a_v, high=b_v, validate_args=True)
+ self.evaluate(uniform.low)
@test_util.run_in_graph_and_eager_modes
def testUniformSample(self):
- with self.test_session():
- a = constant_op.constant([3.0, 4.0])
- b = constant_op.constant(13.0)
- a1_v = 3.0
- a2_v = 4.0
- b_v = 13.0
- n = constant_op.constant(100000)
- uniform = uniform_lib.Uniform(low=a, high=b)
-
- samples = uniform.sample(n, seed=137)
- sample_values = self.evaluate(samples)
- self.assertEqual(sample_values.shape, (100000, 2))
- self.assertAllClose(
- sample_values[::, 0].mean(), (b_v + a1_v) / 2, atol=1e-1, rtol=0.)
- self.assertAllClose(
- sample_values[::, 1].mean(), (b_v + a2_v) / 2, atol=1e-1, rtol=0.)
- self.assertFalse(
- np.any(sample_values[::, 0] < a1_v) or np.any(sample_values >= b_v))
- self.assertFalse(
- np.any(sample_values[::, 1] < a2_v) or np.any(sample_values >= b_v))
+ a = constant_op.constant([3.0, 4.0])
+ b = constant_op.constant(13.0)
+ a1_v = 3.0
+ a2_v = 4.0
+ b_v = 13.0
+ n = constant_op.constant(100000)
+ uniform = uniform_lib.Uniform(low=a, high=b)
+
+ samples = uniform.sample(n, seed=137)
+ sample_values = self.evaluate(samples)
+ self.assertEqual(sample_values.shape, (100000, 2))
+ self.assertAllClose(
+ sample_values[::, 0].mean(), (b_v + a1_v) / 2, atol=1e-1, rtol=0.)
+ self.assertAllClose(
+ sample_values[::, 1].mean(), (b_v + a2_v) / 2, atol=1e-1, rtol=0.)
+ self.assertFalse(
+ np.any(sample_values[::, 0] < a1_v) or np.any(sample_values >= b_v))
+ self.assertFalse(
+ np.any(sample_values[::, 1] < a2_v) or np.any(sample_values >= b_v))
@test_util.run_in_graph_and_eager_modes
def _testUniformSampleMultiDimensional(self):
# DISABLED: Please enable this test once b/issues/30149644 is resolved.
- with self.test_session():
- batch_size = 2
- a_v = [3.0, 22.0]
- b_v = [13.0, 35.0]
- a = constant_op.constant([a_v] * batch_size)
- b = constant_op.constant([b_v] * batch_size)
-
- uniform = uniform_lib.Uniform(low=a, high=b)
-
- n_v = 100000
- n = constant_op.constant(n_v)
- samples = uniform.sample(n)
- self.assertEqual(samples.get_shape(), (n_v, batch_size, 2))
-
- sample_values = self.evaluate(samples)
-
- self.assertFalse(
- np.any(sample_values[:, 0, 0] < a_v[0]) or
- np.any(sample_values[:, 0, 0] >= b_v[0]))
- self.assertFalse(
- np.any(sample_values[:, 0, 1] < a_v[1]) or
- np.any(sample_values[:, 0, 1] >= b_v[1]))
-
- self.assertAllClose(
- sample_values[:, 0, 0].mean(), (a_v[0] + b_v[0]) / 2, atol=1e-2)
- self.assertAllClose(
- sample_values[:, 0, 1].mean(), (a_v[1] + b_v[1]) / 2, atol=1e-2)
+ batch_size = 2
+ a_v = [3.0, 22.0]
+ b_v = [13.0, 35.0]
+ a = constant_op.constant([a_v] * batch_size)
+ b = constant_op.constant([b_v] * batch_size)
+
+ uniform = uniform_lib.Uniform(low=a, high=b)
+
+ n_v = 100000
+ n = constant_op.constant(n_v)
+ samples = uniform.sample(n)
+ self.assertEqual(samples.get_shape(), (n_v, batch_size, 2))
+
+ sample_values = self.evaluate(samples)
+
+ self.assertFalse(
+ np.any(sample_values[:, 0, 0] < a_v[0]) or
+ np.any(sample_values[:, 0, 0] >= b_v[0]))
+ self.assertFalse(
+ np.any(sample_values[:, 0, 1] < a_v[1]) or
+ np.any(sample_values[:, 0, 1] >= b_v[1]))
+
+ self.assertAllClose(
+ sample_values[:, 0, 0].mean(), (a_v[0] + b_v[0]) / 2, atol=1e-2)
+ self.assertAllClose(
+ sample_values[:, 0, 1].mean(), (a_v[1] + b_v[1]) / 2, atol=1e-2)
@test_util.run_in_graph_and_eager_modes
def testUniformMean(self):
- with self.test_session():
- a = 10.0
- b = 100.0
- uniform = uniform_lib.Uniform(low=a, high=b)
- if not stats:
- return
- s_uniform = stats.uniform(loc=a, scale=b - a)
- self.assertAllClose(self.evaluate(uniform.mean()), s_uniform.mean())
+ a = 10.0
+ b = 100.0
+ uniform = uniform_lib.Uniform(low=a, high=b)
+ if not stats:
+ return
+ s_uniform = stats.uniform(loc=a, scale=b - a)
+ self.assertAllClose(self.evaluate(uniform.mean()), s_uniform.mean())
@test_util.run_in_graph_and_eager_modes
def testUniformVariance(self):
- with self.test_session():
- a = 10.0
- b = 100.0
- uniform = uniform_lib.Uniform(low=a, high=b)
- if not stats:
- return
- s_uniform = stats.uniform(loc=a, scale=b - a)
- self.assertAllClose(self.evaluate(uniform.variance()), s_uniform.var())
+ a = 10.0
+ b = 100.0
+ uniform = uniform_lib.Uniform(low=a, high=b)
+ if not stats:
+ return
+ s_uniform = stats.uniform(loc=a, scale=b - a)
+ self.assertAllClose(self.evaluate(uniform.variance()), s_uniform.var())
@test_util.run_in_graph_and_eager_modes
def testUniformStd(self):
- with self.test_session():
- a = 10.0
- b = 100.0
- uniform = uniform_lib.Uniform(low=a, high=b)
- if not stats:
- return
- s_uniform = stats.uniform(loc=a, scale=b - a)
- self.assertAllClose(self.evaluate(uniform.stddev()), s_uniform.std())
+ a = 10.0
+ b = 100.0
+ uniform = uniform_lib.Uniform(low=a, high=b)
+ if not stats:
+ return
+ s_uniform = stats.uniform(loc=a, scale=b - a)
+ self.assertAllClose(self.evaluate(uniform.stddev()), s_uniform.std())
@test_util.run_in_graph_and_eager_modes
def testUniformNans(self):
- with self.test_session():
- a = 10.0
- b = [11.0, 100.0]
- uniform = uniform_lib.Uniform(low=a, high=b)
+ a = 10.0
+ b = [11.0, 100.0]
+ uniform = uniform_lib.Uniform(low=a, high=b)
- no_nans = constant_op.constant(1.0)
- nans = constant_op.constant(0.0) / constant_op.constant(0.0)
- self.assertTrue(self.evaluate(math_ops.is_nan(nans)))
- with_nans = array_ops.stack([no_nans, nans])
+ no_nans = constant_op.constant(1.0)
+ nans = constant_op.constant(0.0) / constant_op.constant(0.0)
+ self.assertTrue(self.evaluate(math_ops.is_nan(nans)))
+ with_nans = array_ops.stack([no_nans, nans])
- pdf = uniform.prob(with_nans)
+ pdf = uniform.prob(with_nans)
- is_nan = self.evaluate(math_ops.is_nan(pdf))
- self.assertFalse(is_nan[0])
- self.assertTrue(is_nan[1])
+ is_nan = self.evaluate(math_ops.is_nan(pdf))
+ self.assertFalse(is_nan[0])
+ self.assertTrue(is_nan[1])
@test_util.run_in_graph_and_eager_modes
def testUniformSamplePdf(self):
- with self.test_session():
- a = 10.0
- b = [11.0, 100.0]
- uniform = uniform_lib.Uniform(a, b)
- self.assertTrue(
- self.evaluate(
- math_ops.reduce_all(uniform.prob(uniform.sample(10)) > 0)))
+ a = 10.0
+ b = [11.0, 100.0]
+ uniform = uniform_lib.Uniform(a, b)
+ self.assertTrue(
+ self.evaluate(
+ math_ops.reduce_all(uniform.prob(uniform.sample(10)) > 0)))
@test_util.run_in_graph_and_eager_modes
def testUniformBroadcasting(self):
- with self.test_session():
- a = 10.0
- b = [11.0, 20.0]
- uniform = uniform_lib.Uniform(a, b)
+ a = 10.0
+ b = [11.0, 20.0]
+ uniform = uniform_lib.Uniform(a, b)
- pdf = uniform.prob([[10.5, 11.5], [9.0, 19.0], [10.5, 21.0]])
- expected_pdf = np.array([[1.0, 0.1], [0.0, 0.1], [1.0, 0.0]])
- self.assertAllClose(expected_pdf, self.evaluate(pdf))
+ pdf = uniform.prob([[10.5, 11.5], [9.0, 19.0], [10.5, 21.0]])
+ expected_pdf = np.array([[1.0, 0.1], [0.0, 0.1], [1.0, 0.0]])
+ self.assertAllClose(expected_pdf, self.evaluate(pdf))
@test_util.run_in_graph_and_eager_modes
def testUniformSampleWithShape(self):
- with self.test_session():
- a = 10.0
- b = [11.0, 20.0]
- uniform = uniform_lib.Uniform(a, b)
-
- pdf = uniform.prob(uniform.sample((2, 3)))
- # pylint: disable=bad-continuation
- expected_pdf = [
- [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]],
- [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]],
- ]
- # pylint: enable=bad-continuation
- self.assertAllClose(expected_pdf, self.evaluate(pdf))
-
- pdf = uniform.prob(uniform.sample())
- expected_pdf = [1.0, 0.1]
- self.assertAllClose(expected_pdf, self.evaluate(pdf))
+ a = 10.0
+ b = [11.0, 20.0]
+ uniform = uniform_lib.Uniform(a, b)
+
+ pdf = uniform.prob(uniform.sample((2, 3)))
+ # pylint: disable=bad-continuation
+ expected_pdf = [
+ [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]],
+ [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]],
+ ]
+ # pylint: enable=bad-continuation
+ self.assertAllClose(expected_pdf, self.evaluate(pdf))
+
+ pdf = uniform.prob(uniform.sample())
+ expected_pdf = [1.0, 0.1]
+ self.assertAllClose(expected_pdf, self.evaluate(pdf))
def testFullyReparameterized(self):
a = constant_op.constant(0.1)
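The sampling checks in testUniformSample compare empirical moments against closed forms: for Uniform(a, b) the mean is (a + b) / 2 and the variance is (b - a)**2 / 12. A standalone sketch of the same check, assuming the TF 1.x tf.distributions API that uniform_lib wraps (values illustrative):

    import numpy as np
    import tensorflow as tf

    uniform = tf.distributions.Uniform(low=3.0, high=13.0)
    samples = uniform.sample(100000, seed=137)
    with tf.Session() as sess:
      sample_values = sess.run(samples)
    # Mean approaches (3 + 13) / 2 = 8; variance approaches (13 - 3)**2 / 12.
    np.testing.assert_allclose(sample_values.mean(), 8.0, atol=0.1)
    np.testing.assert_allclose(sample_values.var(), 100.0 / 12.0, rtol=0.05)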
diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py
index 61faa8466e..27d652c2c6 100644
--- a/tensorflow/python/kernel_tests/distributions/util_test.py
+++ b/tensorflow/python/kernel_tests/distributions/util_test.py
@@ -69,7 +69,7 @@ class AssertCloseTest(test.TestCase):
w = array_ops.placeholder(dtypes.float32)
feed_dict = {x: [1., 5, 10, 15, 20], y: [1.1, 5, 10, 15, 20],
z: [1.0001, 5, 10, 15, 20], w: [1e-8, 5, 10, 15, 20]}
- with self.test_session():
+ with self.cached_session():
with ops.control_dependencies([du.assert_integer_form(x)]):
array_ops.identity(x).eval(feed_dict=feed_dict)
@@ -122,58 +122,52 @@ class GetLogitsAndProbsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testImproperArguments(self):
- with self.test_session():
- with self.assertRaises(ValueError):
- du.get_logits_and_probs(logits=None, probs=None)
+ with self.assertRaises(ValueError):
+ du.get_logits_and_probs(logits=None, probs=None)
- with self.assertRaises(ValueError):
- du.get_logits_and_probs(logits=[0.1], probs=[0.1])
+ with self.assertRaises(ValueError):
+ du.get_logits_and_probs(logits=[0.1], probs=[0.1])
@test_util.run_in_graph_and_eager_modes
def testLogits(self):
p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)
logits = _logit(p)
- with self.test_session():
- new_logits, new_p = du.get_logits_and_probs(
- logits=logits, validate_args=True)
+ new_logits, new_p = du.get_logits_and_probs(
+ logits=logits, validate_args=True)
- self.assertAllClose(p, self.evaluate(new_p), rtol=1e-5, atol=0.)
- self.assertAllClose(logits, self.evaluate(new_logits), rtol=1e-5, atol=0.)
+ self.assertAllClose(p, self.evaluate(new_p), rtol=1e-5, atol=0.)
+ self.assertAllClose(logits, self.evaluate(new_logits), rtol=1e-5, atol=0.)
@test_util.run_in_graph_and_eager_modes
def testLogitsMultidimensional(self):
p = np.array([0.2, 0.3, 0.5], dtype=np.float32)
logits = np.log(p)
- with self.test_session():
- new_logits, new_p = du.get_logits_and_probs(
- logits=logits, multidimensional=True, validate_args=True)
+ new_logits, new_p = du.get_logits_and_probs(
+ logits=logits, multidimensional=True, validate_args=True)
- self.assertAllClose(self.evaluate(new_p), p)
- self.assertAllClose(self.evaluate(new_logits), logits)
+ self.assertAllClose(self.evaluate(new_p), p)
+ self.assertAllClose(self.evaluate(new_logits), logits)
@test_util.run_in_graph_and_eager_modes
def testProbability(self):
p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)
- with self.test_session():
- new_logits, new_p = du.get_logits_and_probs(
- probs=p, validate_args=True)
+ new_logits, new_p = du.get_logits_and_probs(probs=p, validate_args=True)
- self.assertAllClose(_logit(p), self.evaluate(new_logits))
- self.assertAllClose(p, self.evaluate(new_p))
+ self.assertAllClose(_logit(p), self.evaluate(new_logits))
+ self.assertAllClose(p, self.evaluate(new_p))
@test_util.run_in_graph_and_eager_modes
def testProbabilityMultidimensional(self):
p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32)
- with self.test_session():
- new_logits, new_p = du.get_logits_and_probs(
- probs=p, multidimensional=True, validate_args=True)
+ new_logits, new_p = du.get_logits_and_probs(
+ probs=p, multidimensional=True, validate_args=True)
- self.assertAllClose(np.log(p), self.evaluate(new_logits))
- self.assertAllClose(p, self.evaluate(new_p))
+ self.assertAllClose(np.log(p), self.evaluate(new_logits))
+ self.assertAllClose(p, self.evaluate(new_p))
@test_util.run_in_graph_and_eager_modes
def testProbabilityValidateArgs(self):
@@ -183,29 +177,23 @@ class GetLogitsAndProbsTest(test.TestCase):
# Component greater than 1.
p3 = [2, 0.2, 0.5, 0.3, .2]
- with self.test_session():
- _, prob = du.get_logits_and_probs(
- probs=p, validate_args=True)
- self.evaluate(prob)
-
- with self.assertRaisesOpError("Condition x >= 0"):
- _, prob = du.get_logits_and_probs(
- probs=p2, validate_args=True)
- self.evaluate(prob)
+ _, prob = du.get_logits_and_probs(probs=p, validate_args=True)
+ self.evaluate(prob)
- _, prob = du.get_logits_and_probs(
- probs=p2, validate_args=False)
+ with self.assertRaisesOpError("Condition x >= 0"):
+ _, prob = du.get_logits_and_probs(probs=p2, validate_args=True)
self.evaluate(prob)
- with self.assertRaisesOpError("probs has components greater than 1"):
- _, prob = du.get_logits_and_probs(
- probs=p3, validate_args=True)
- self.evaluate(prob)
+ _, prob = du.get_logits_and_probs(probs=p2, validate_args=False)
+ self.evaluate(prob)
- _, prob = du.get_logits_and_probs(
- probs=p3, validate_args=False)
+ with self.assertRaisesOpError("probs has components greater than 1"):
+ _, prob = du.get_logits_and_probs(probs=p3, validate_args=True)
self.evaluate(prob)
+ _, prob = du.get_logits_and_probs(probs=p3, validate_args=False)
+ self.evaluate(prob)
+
@test_util.run_in_graph_and_eager_modes
def testProbabilityValidateArgsMultidimensional(self):
p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32)
@@ -216,41 +204,39 @@ class GetLogitsAndProbsTest(test.TestCase):
# Does not sum to 1.
p4 = np.array([[1.1, 0.3, 0.4], [0.1, 0.5, 0.4]], dtype=np.float32)
- with self.test_session():
- _, prob = du.get_logits_and_probs(
- probs=p, multidimensional=True)
- self.evaluate(prob)
-
- with self.assertRaisesOpError("Condition x >= 0"):
- _, prob = du.get_logits_and_probs(
- probs=p2, multidimensional=True, validate_args=True)
- self.evaluate(prob)
+ _, prob = du.get_logits_and_probs(probs=p, multidimensional=True)
+ self.evaluate(prob)
+ with self.assertRaisesOpError("Condition x >= 0"):
_, prob = du.get_logits_and_probs(
- probs=p2, multidimensional=True, validate_args=False)
+ probs=p2, multidimensional=True, validate_args=True)
self.evaluate(prob)
- with self.assertRaisesOpError(
- "(probs has components greater than 1|probs does not sum to 1)"):
- _, prob = du.get_logits_and_probs(
- probs=p3, multidimensional=True, validate_args=True)
- self.evaluate(prob)
+ _, prob = du.get_logits_and_probs(
+ probs=p2, multidimensional=True, validate_args=False)
+ self.evaluate(prob)
+ with self.assertRaisesOpError(
+ "(probs has components greater than 1|probs does not sum to 1)"):
_, prob = du.get_logits_and_probs(
- probs=p3, multidimensional=True, validate_args=False)
+ probs=p3, multidimensional=True, validate_args=True)
self.evaluate(prob)
- with self.assertRaisesOpError("probs does not sum to 1"):
- _, prob = du.get_logits_and_probs(
- probs=p4, multidimensional=True, validate_args=True)
- self.evaluate(prob)
+ _, prob = du.get_logits_and_probs(
+ probs=p3, multidimensional=True, validate_args=False)
+ self.evaluate(prob)
+ with self.assertRaisesOpError("probs does not sum to 1"):
_, prob = du.get_logits_and_probs(
- probs=p4, multidimensional=True, validate_args=False)
+ probs=p4, multidimensional=True, validate_args=True)
self.evaluate(prob)
+ _, prob = du.get_logits_and_probs(
+ probs=p4, multidimensional=True, validate_args=False)
+ self.evaluate(prob)
+
def testProbsMultidimShape(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
p = array_ops.ones([int(2**11+1)], dtype=np.float16)
du.get_logits_and_probs(
@@ -264,7 +250,7 @@ class GetLogitsAndProbsTest(test.TestCase):
prob.eval(feed_dict={p: np.ones([int(2**11+1)])})
def testLogitsMultidimShape(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
l = array_ops.ones([int(2**11+1)], dtype=np.float16)
du.get_logits_and_probs(
@@ -281,7 +267,7 @@ class GetLogitsAndProbsTest(test.TestCase):
class EmbedCheckCategoricalEventShapeTest(test.TestCase):
def testTooSmall(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
param = array_ops.ones([1], dtype=np.float16)
checked_param = du.embed_check_categorical_event_shape(
@@ -295,7 +281,7 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase):
checked_param.eval(feed_dict={param: np.ones([1])})
def testTooLarge(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
param = array_ops.ones([int(2**11+1)], dtype=dtypes.float16)
checked_param = du.embed_check_categorical_event_shape(
@@ -310,18 +296,17 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testUnsupportedDtype(self):
- with self.test_session():
- param = ops.convert_to_tensor(
- np.ones([2**11 + 1]).astype(dtypes.qint16.as_numpy_dtype),
- dtype=dtypes.qint16)
- with self.assertRaises(TypeError):
- du.embed_check_categorical_event_shape(param)
+ param = ops.convert_to_tensor(
+ np.ones([2**11 + 1]).astype(dtypes.qint16.as_numpy_dtype),
+ dtype=dtypes.qint16)
+ with self.assertRaises(TypeError):
+ du.embed_check_categorical_event_shape(param)
class EmbedCheckIntegerCastingClosedTest(test.TestCase):
def testCorrectlyAssertsNonnegative(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Elements must be non-negative"):
x = array_ops.placeholder(dtype=dtypes.float16)
x_checked = du.embed_check_integer_casting_closed(
@@ -329,7 +314,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase):
x_checked.eval(feed_dict={x: np.array([1, -1], dtype=np.float16)})
def testCorrectlyAssersIntegerForm(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Elements must be int16-equivalent."):
x = array_ops.placeholder(dtype=dtypes.float16)
x_checked = du.embed_check_integer_casting_closed(
@@ -337,7 +322,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase):
x_checked.eval(feed_dict={x: np.array([1, 1.5], dtype=np.float16)})
def testCorrectlyAssertsLargestPossibleInteger(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Elements cannot exceed 32767."):
x = array_ops.placeholder(dtype=dtypes.int32)
x_checked = du.embed_check_integer_casting_closed(
@@ -345,7 +330,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase):
x_checked.eval(feed_dict={x: np.array([1, 2**15], dtype=np.int32)})
def testCorrectlyAssertsSmallestPossibleInteger(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Elements cannot be smaller than 0."):
x = array_ops.placeholder(dtype=dtypes.int32)
x_checked = du.embed_check_integer_casting_closed(
@@ -365,29 +350,27 @@ class LogCombinationsTest(test.TestCase):
log_combs = np.log(special.binom(n, k))
- with self.test_session():
- n = np.array(n, dtype=np.float32)
- counts = [[1., 1], [2., 3], [4., 8], [11, 4]]
- log_binom = du.log_combinations(n, counts)
- self.assertEqual([4], log_binom.get_shape())
- self.assertAllClose(log_combs, self.evaluate(log_binom))
+ n = np.array(n, dtype=np.float32)
+ counts = [[1., 1], [2., 3], [4., 8], [11, 4]]
+ log_binom = du.log_combinations(n, counts)
+ self.assertEqual([4], log_binom.get_shape())
+ self.assertAllClose(log_combs, self.evaluate(log_binom))
def testLogCombinationsShape(self):
# Shape [2, 2]
n = [[2, 5], [12, 15]]
- with self.test_session():
- n = np.array(n, dtype=np.float32)
- # Shape [2, 2, 4]
- counts = [[[1., 1, 0, 0], [2., 2, 1, 0]], [[4., 4, 1, 3], [10, 1, 1, 4]]]
- log_binom = du.log_combinations(n, counts)
- self.assertEqual([2, 2], log_binom.get_shape())
+ n = np.array(n, dtype=np.float32)
+ # Shape [2, 2, 4]
+ counts = [[[1., 1, 0, 0], [2., 2, 1, 0]], [[4., 4, 1, 3], [10, 1, 1, 4]]]
+ log_binom = du.log_combinations(n, counts)
+ self.assertEqual([2, 2], log_binom.get_shape())
class DynamicShapeTest(test.TestCase):
def testSameDynamicShape(self):
- with self.test_session():
+ with self.cached_session():
scalar = constant_op.constant(2.0)
scalar1 = array_ops.placeholder(dtype=dtypes.float32)
@@ -497,22 +480,21 @@ class RotateTransposeTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testRollStatic(self):
- with self.test_session():
- if context.executing_eagerly():
- error_message = r"Attempt to convert a value \(None\)"
- else:
- error_message = "None values not supported."
- with self.assertRaisesRegexp(ValueError, error_message):
- du.rotate_transpose(None, 1)
- for x in (np.ones(1), np.ones((2, 1)), np.ones((3, 2, 1))):
- for shift in np.arange(-5, 5):
- y = du.rotate_transpose(x, shift)
- self.assertAllEqual(
- self._np_rotate_transpose(x, shift), self.evaluate(y))
- self.assertAllEqual(np.roll(x.shape, shift), y.get_shape().as_list())
+ if context.executing_eagerly():
+ error_message = r"Attempt to convert a value \(None\)"
+ else:
+ error_message = "None values not supported."
+ with self.assertRaisesRegexp(ValueError, error_message):
+ du.rotate_transpose(None, 1)
+ for x in (np.ones(1), np.ones((2, 1)), np.ones((3, 2, 1))):
+ for shift in np.arange(-5, 5):
+ y = du.rotate_transpose(x, shift)
+ self.assertAllEqual(
+ self._np_rotate_transpose(x, shift), self.evaluate(y))
+ self.assertAllEqual(np.roll(x.shape, shift), y.get_shape().as_list())
def testRollDynamic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32)
shift = array_ops.placeholder(dtypes.int32)
for x_value in (np.ones(
@@ -530,7 +512,7 @@ class RotateTransposeTest(test.TestCase):
class PickVectorTest(test.TestCase):
def testCorrectlyPicksVector(self):
- with self.test_session():
+ with self.cached_session():
x = np.arange(10, 12)
y = np.arange(15, 18)
self.assertAllEqual(
@@ -568,19 +550,19 @@ class PreferStaticRankTest(test.TestCase):
def testDynamicRankEndsUpBeingNonEmpty(self):
x = array_ops.placeholder(np.float64, shape=None)
rank = du.prefer_static_rank(x)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(2, rank.eval(feed_dict={x: np.zeros((2, 3))}))
def testDynamicRankEndsUpBeingEmpty(self):
x = array_ops.placeholder(np.int32, shape=None)
rank = du.prefer_static_rank(x)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(1, rank.eval(feed_dict={x: []}))
def testDynamicRankEndsUpBeingScalar(self):
x = array_ops.placeholder(np.int32, shape=None)
rank = du.prefer_static_rank(x)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(0, rank.eval(feed_dict={x: 1}))
@@ -607,19 +589,19 @@ class PreferStaticShapeTest(test.TestCase):
def testDynamicShapeEndsUpBeingNonEmpty(self):
x = array_ops.placeholder(np.float64, shape=None)
shape = du.prefer_static_shape(x)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual((2, 3), shape.eval(feed_dict={x: np.zeros((2, 3))}))
def testDynamicShapeEndsUpBeingEmpty(self):
x = array_ops.placeholder(np.int32, shape=None)
shape = du.prefer_static_shape(x)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(np.array([0]), shape.eval(feed_dict={x: []}))
def testDynamicShapeEndsUpBeingScalar(self):
x = array_ops.placeholder(np.int32, shape=None)
shape = du.prefer_static_shape(x)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(np.array([]), shape.eval(feed_dict={x: 1}))
@@ -646,20 +628,20 @@ class PreferStaticValueTest(test.TestCase):
def testDynamicValueEndsUpBeingNonEmpty(self):
x = array_ops.placeholder(np.float64, shape=None)
value = du.prefer_static_value(x)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(np.zeros((2, 3)),
value.eval(feed_dict={x: np.zeros((2, 3))}))
def testDynamicValueEndsUpBeingEmpty(self):
x = array_ops.placeholder(np.int32, shape=None)
value = du.prefer_static_value(x)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(np.array([]), value.eval(feed_dict={x: []}))
def testDynamicValueEndsUpBeingScalar(self):
x = array_ops.placeholder(np.int32, shape=None)
value = du.prefer_static_value(x)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(np.array(1), value.eval(feed_dict={x: 1}))
@@ -691,7 +673,7 @@ class FillTriangularTest(test.TestCase):
def _run_test(self, x_, use_deferred_shape=False, **kwargs):
x_ = np.asarray(x_)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
static_shape = None if use_deferred_shape else x_.shape
x_pl = array_ops.placeholder_with_default(x_, shape=static_shape)
# Add `zeros_like(x)` such that x's value and gradient are identical. We
@@ -761,7 +743,7 @@ class FillTriangularInverseTest(FillTriangularTest):
def _run_test(self, x_, use_deferred_shape=False, **kwargs):
x_ = np.asarray(x_)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
static_shape = None if use_deferred_shape else x_.shape
x_pl = array_ops.placeholder_with_default(x_, shape=static_shape)
zeros_like_x_pl = (x_pl * array_ops.stop_gradient(x_pl - 1.)
@@ -795,7 +777,7 @@ class ReduceWeightedLogSumExp(test.TestCase):
logx_ = np.array([[0., -1, 1000.],
[0, 1, -1000.],
[-5, 0, 5]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
logx = constant_op.constant(logx_)
expected = math_ops.reduce_logsumexp(logx, axis=-1)
grad_expected = gradients_impl.gradients(expected, logx)[0]
@@ -818,7 +800,7 @@ class ReduceWeightedLogSumExp(test.TestCase):
[1, -2, 1],
[1, 0, 1]])
expected, _ = self._reduce_weighted_logsumexp(logx_, w_, axis=-1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
logx = constant_op.constant(logx_)
w = constant_op.constant(w_)
actual, actual_sgn = du.reduce_weighted_logsumexp(
@@ -836,7 +818,7 @@ class ReduceWeightedLogSumExp(test.TestCase):
[1, 0, 1]])
expected, _ = self._reduce_weighted_logsumexp(
logx_, w_, axis=-1, keep_dims=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
logx = constant_op.constant(logx_)
w = constant_op.constant(w_)
actual, actual_sgn = du.reduce_weighted_logsumexp(
@@ -848,7 +830,7 @@ class ReduceWeightedLogSumExp(test.TestCase):
def testDocString(self):
"""This test verifies the correctness of the docstring examples."""
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant([[0., 0, 0],
[0, 0, 0]])
@@ -952,7 +934,7 @@ class SoftplusTest(test.TestCase):
use_gpu=True)
def testGradient(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -968,7 +950,7 @@ class SoftplusTest(test.TestCase):
self.assertLess(err, 1e-4)
def testInverseSoftplusGradientNeverNan(self):
- with self.test_session():
+ with self.cached_session():
# Note that this range contains both zero and inf.
x = constant_op.constant(np.logspace(-8, 6).astype(np.float16))
y = du.softplus_inverse(x)
@@ -977,7 +959,7 @@ class SoftplusTest(test.TestCase):
self.assertAllEqual(np.zeros_like(grads).astype(np.bool), np.isnan(grads))
def testInverseSoftplusGradientFinite(self):
- with self.test_session():
+ with self.cached_session():
# This range of x is all finite, and so is 1 / x. So the
# gradient and its approximations should be finite as well.
x = constant_op.constant(np.logspace(-4.8, 4.5).astype(np.float16))
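Not every test can shed its session: where placeholders are fed through feed_dict, self.evaluate() does not apply, so the remaining hunks in util_test.py switch those tests from self.test_session() to self.cached_session(), which reuses a single session per test method. A sketch of the pattern, mirroring the placeholder-driven cases above (hypothetical test name):

    def testDynamicRankSketch(self):
      x = array_ops.placeholder(np.float64, shape=None)
      rank = du.prefer_static_rank(x)
      with self.cached_session():
        # Placeholders need an explicit session and feed_dict; this path is
        # graph-mode only, hence cached_session() rather than evaluate().
        self.assertAllEqual(2, rank.eval(feed_dict={x: np.zeros((2, 3))}))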
diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
index c4d4ce780b..49b9569e2b 100644
--- a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
@@ -104,6 +104,27 @@ class DynamicStitchTestBase(object):
# Dimension 0 is max(flatten(indices))+1.
self.assertEqual([8, 2], stitched_t.get_shape().as_list())
+ def testZeroSizeTensor(self):
+ with self.test_session(use_gpu=True):
+ indices = [
+ constant_op.constant([0, 4, 7]),
+ constant_op.constant([1, 6]),
+ constant_op.constant([2, 3, 5]),
+ array_ops.zeros([0], dtype=dtypes.int32)
+ ]
+ data = [
+ constant_op.constant([[0, 1], [40, 41], [70, 71]]),
+ constant_op.constant([[10, 11], [60, 61]]),
+ constant_op.constant([[20, 21], [30, 31], [50, 51]]),
+ array_ops.zeros([0, 2], dtype=dtypes.int32)
+ ]
+ stitched_t = self.stitch_op(indices, data)
+ stitched_val = stitched_t.eval()
+ self.assertAllEqual([[0, 1], [10, 11], [20, 21], [30, 31], [40, 41],
+ [50, 51], [60, 61], [70, 71]], stitched_val)
+ # Dimension 0 is max(flatten(indices))+1.
+ self.assertEqual([8, 2], stitched_t.get_shape().as_list())
+
def testHigherRank(self):
with self.test_session(use_gpu=True) as sess:
indices = [
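The new testZeroSizeTensor exercises dynamic_stitch when one indices/data pair is empty; the empty slice should contribute nothing to the stitched result. A smaller standalone sketch under the TF 1.x API (shapes reduced from the test's):

    import tensorflow as tf

    indices = [tf.constant([0, 1]), tf.zeros([0], dtype=tf.int32)]
    data = [tf.constant([[1, 1], [3, 3]]), tf.zeros([0, 2], dtype=tf.int32)]
    stitched = tf.dynamic_stitch(indices, data)
    with tf.Session() as sess:
      # The zero-size pair is ignored: prints [[1 1] [3 3]].
      print(sess.run(stitched))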
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py
index 55d75cb474..dcd435e1ff 100644
--- a/tensorflow/python/kernel_tests/embedding_ops_test.py
+++ b/tensorflow/python/kernel_tests/embedding_ops_test.py
@@ -480,7 +480,7 @@ class EmbeddingLookupTest(test.TestCase):
id_vals, shape=ids_shape, dtype=dtypes.int32)
x, params, _ = _EmbeddingParams(num_shards, vocab_size, shape=[2])
y = embedding_ops.embedding_lookup(x, ids)
- y_shape = [num_ids] + list(params[_PName(0) + ":0"].shape[1:])
+ y_shape = ids_shape + tuple(params[_PName(0) + ":0"].shape[1:])
x_name = [_PName(i) for i in range(num_shards)]
x_init_value = [params[x_n + ":0"] for x_n in x_name]
x_shape = [i.shape for i in x_init_value]
@@ -663,8 +663,9 @@ class EmbeddingLookupSparseTest(test.TestCase):
np.ones(np.sum(vals_per_batch_entry)), vals_per_batch_entry)
for num_shards, combiner, dtype, ignore_weights in itertools.product(
- [1, 5], ["sum", "mean", "sqrtn"], [dtypes.float32,
- dtypes.float64], [True, False]):
+ [1, 5], ["sum", "mean", "sqrtn"],
+ [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64],
+ [True, False]):
with self.test_session():
p, params, feed_dict = _EmbeddingParams(
@@ -677,6 +678,10 @@ class EmbeddingLookupSparseTest(test.TestCase):
self.assertEqual(embedding_sum.get_shape().as_list(),
expected_lookup_result_shape)
+ if dtype in (dtypes.float16, dtypes.bfloat16):
+ self.assertEqual(embedding_sum.dtype, dtypes.float32)
+ else:
+ self.assertEqual(embedding_sum.dtype, dtype)
tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict)
@@ -692,7 +697,14 @@ class EmbeddingLookupSparseTest(test.TestCase):
if combiner == "sqrtn":
np_embedding_sum /= np.reshape(
np.sqrt(np_weight_sq_sum), (batch_size, 1, 1))
- self.assertAllClose(np_embedding_sum, tf_embedding_sum)
+
+ rtol = 1e-6
+ if dtype == dtypes.bfloat16:
+ rtol = 1e-2
+ elif dtype == dtypes.float16:
+ rtol = 1e-3
+ atol = rtol
+ self.assertAllClose(np_embedding_sum, tf_embedding_sum, rtol, atol)
def testGradientsEmbeddingLookupSparse(self):
vocab_size = 12
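The embedding_ops_test.py hunk widens dtype coverage to float16 and bfloat16 and asserts that embedding_lookup_sparse returns a float32 sum for those inputs. The relaxed rtol values appear to track each format's precision (an inference, not stated in the diff): float16 carries ~11 significand bits and bfloat16 ~8, so one unit of relative error is about 2**-10 and 2**-7 respectively. A quick check:

    import numpy as np

    # float16 eps = 2**-10 ≈ 9.8e-4, just under the rtol=1e-3 chosen above.
    print(np.finfo(np.float16).eps)  # 0.000977
    # NumPy has no native bfloat16; its eps is 2**-7 ≈ 7.8e-3, just under
    # the looser rtol=1e-2.
    print(2.0**-7)  # 0.0078125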
diff --git a/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py b/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
index 60090a1510..e1f5a6b620 100644
--- a/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
+++ b/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
@@ -25,6 +25,8 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed as random_seed_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
@@ -100,6 +102,24 @@ class ExtractImagePatchesGradTest(test.TestCase):
print('extract_image_patches gradient err: %.4e' % err)
self.assertLess(err, 1e-4)
+ def testConstructGradientWithLargeImages(self):
+ batch_size = 4
+ height = 1024
+ width = 1024
+ ksize = 5
+ images = variable_scope.get_variable('inputs',
+ (batch_size, height, width, 1))
+ patches = array_ops.extract_image_patches(images,
+ ksizes=[1, ksize, ksize, 1],
+ strides=[1, 1, 1, 1],
+ rates=[1, 1, 1, 1],
+ padding='SAME')
+ # GitHub issue #20146: the tf.extract_image_patches() gradient was very
+ # slow to build at graph-construction time.
+ gradients = gradients_impl.gradients(patches, images)
+ # Won't time out.
+ self.assertIsNotNone(gradients)
+
if __name__ == '__main__':
test.main()
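The new testConstructGradientWithLargeImages only builds the gradient graph and never runs it: per the issue it cites, the regression was graph-construction time, not execution time. A minimal standalone reproduction, assuming the TF 1.x API (same shapes as the test):

    import tensorflow as tf

    images = tf.get_variable('inputs', (4, 1024, 1024, 1))
    patches = tf.extract_image_patches(
        images, ksizes=[1, 5, 5, 1], strides=[1, 1, 1, 1],
        rates=[1, 1, 1, 1], padding='SAME')
    # Constructing the gradient should return promptly; nothing executes.
    grads = tf.gradients(patches, images)
    assert grads[0] is not None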
diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py
index 5db2e9821d..e39daf1371 100644
--- a/tensorflow/python/kernel_tests/functional_ops_test.py
+++ b/tensorflow/python/kernel_tests/functional_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import iterator_ops
@@ -59,42 +60,48 @@ class FunctionalOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testFoldl_Simple(self):
- with self.test_session():
- elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
+ elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
- r = functional_ops.foldl(
- lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
- elems)
- self.assertAllEqual(208, self.evaluate(r))
+ r = functional_ops.foldl(
+ lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
+ elems)
+ self.assertAllEqual(208, self.evaluate(r))
- r = functional_ops.foldl(
- lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
- elems,
- initializer=10)
- self.assertAllEqual(880, self.evaluate(r))
+ r = functional_ops.foldl(
+ lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
+ elems,
+ initializer=10)
+ self.assertAllEqual(880, self.evaluate(r))
@test_util.run_in_graph_and_eager_modes
def testFoldl_SingleInputMultiOutput(self):
- with self.test_session():
- elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
- initializer = np.array([1, -1.0])
- r = functional_ops.foldl(lambda a, x: a + x, elems, initializer)
- r_value = self.evaluate(r)
+ elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+ initializer = np.array([1, -1.0])
+ r = functional_ops.foldl(lambda a, x: a + x, elems, initializer)
+ r_value = self.evaluate(r)
- self.assertAllEqual(22, r_value[0])
- self.assertAllEqual(20, r_value[1])
+ self.assertAllEqual(22, r_value[0])
+ self.assertAllEqual(20, r_value[1])
@test_util.run_in_graph_and_eager_modes
def testFoldl_MultiInputSingleOutput(self):
- with self.test_session():
- elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
- initializer = np.array(1.0)
- r = functional_ops.foldl(lambda a, x: a + x[0] + x[1], (elems, -elems),
- initializer)
- self.assertAllEqual(1, self.evaluate(r))
+ elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+ initializer = np.array(1.0)
+ r = functional_ops.foldl(lambda a, x: a + x[0] + x[1], (elems, -elems),
+ initializer)
+ self.assertAllEqual(1, self.evaluate(r))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testFoldl_MultiInputDifferentDimsSingleOutput(self):
+ elems = np.array([[1.0, 1.0, 1.0], [2.0, 3.0, 4.0]])
+ other_elems = np.array([-1.0, 1.0])
+ initializer = np.array([0.0, 0.0, 0.0])
+ r = functional_ops.foldl(lambda a, x: a + x[0] * x[1],
+ (elems, other_elems), initializer)
+ self.assertAllEqual([1.0, 2.0, 3.0], self.evaluate(r))
def testFoldl_Scoped(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope("root") as varscope:
elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
@@ -114,42 +121,39 @@ class FunctionalOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testFoldr_Simple(self):
- with self.test_session():
- elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
+ elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
- r = functional_ops.foldr(
- lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
- elems)
- self.assertAllEqual(450, self.evaluate(r))
+ r = functional_ops.foldr(
+ lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
+ elems)
+ self.assertAllEqual(450, self.evaluate(r))
- r = functional_ops.foldr(
- lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
- elems,
- initializer=10)
- self.assertAllEqual(1282, self.evaluate(r))
+ r = functional_ops.foldr(
+ lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
+ elems,
+ initializer=10)
+ self.assertAllEqual(1282, self.evaluate(r))
@test_util.run_in_graph_and_eager_modes
def testFoldr_SingleInputMultiOutput(self):
- with self.test_session():
- elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
- initializer = np.array([1, -1.0])
- r = functional_ops.foldr(lambda a, x: a + x, elems, initializer)
- r_value = self.evaluate(r)
+ elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+ initializer = np.array([1, -1.0])
+ r = functional_ops.foldr(lambda a, x: a + x, elems, initializer)
+ r_value = self.evaluate(r)
- self.assertAllEqual(22, r_value[0])
- self.assertAllEqual(20, r_value[1])
+ self.assertAllEqual(22, r_value[0])
+ self.assertAllEqual(20, r_value[1])
@test_util.run_in_graph_and_eager_modes
def testFoldr_MultiInputSingleOutput(self):
- with self.test_session():
- elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
- initializer = np.array(1.0)
- r = functional_ops.foldr(lambda a, x: a + x[0] + x[1], (elems, -elems),
- initializer)
- self.assertAllEqual(1, self.evaluate(r))
+ elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+ initializer = np.array(1.0)
+ r = functional_ops.foldr(lambda a, x: a + x[0] + x[1], (elems, -elems),
+ initializer)
+ self.assertAllEqual(1, self.evaluate(r))
def testFoldr_Scoped(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope("root") as varscope:
elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
@@ -169,7 +173,7 @@ class FunctionalOpsTest(test.TestCase):
# pylint: disable=unnecessary-lambda
def testFold_Grad(self):
- with self.test_session():
+ with self.cached_session():
elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
v = constant_op.constant(2.0, name="v")
r = functional_ops.foldl(
@@ -185,16 +189,15 @@ class FunctionalOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testMap_Simple(self):
- with self.test_session():
- nums = [1, 2, 3, 4, 5, 6]
- elems = constant_op.constant(nums, name="data")
- r = functional_ops.map_fn(
- lambda x: math_ops.multiply(math_ops.add(x, 3), 2), elems)
- self.assertAllEqual(
- np.array([(x + 3) * 2 for x in nums]), self.evaluate(r))
+ nums = [1, 2, 3, 4, 5, 6]
+ elems = constant_op.constant(nums, name="data")
+ r = functional_ops.map_fn(
+ lambda x: math_ops.multiply(math_ops.add(x, 3), 2), elems)
+ self.assertAllEqual(
+ np.array([(x + 3) * 2 for x in nums]), self.evaluate(r))
def testMapSparseTensor(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
functional_ops.map_fn(
lambda x: x,
@@ -211,7 +214,7 @@ class FunctionalOpsTest(test.TestCase):
functional_ops.map_fn(lambda x: x, 1)
def testMap_Scoped(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
def double_scoped(x):
"""2x with a dummy 2 that is scoped."""
@@ -242,7 +245,7 @@ class FunctionalOpsTest(test.TestCase):
self.assertAllEqual(doubles, self.evaluate(r))
def testMap_Grad(self):
- with self.test_session():
+ with self.cached_session():
param = constant_op.constant(2.0)
elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
y = functional_ops.map_fn(
@@ -254,142 +257,131 @@ class FunctionalOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testMap_SimpleNotTensor(self):
- with self.test_session():
- nums = np.array([1, 2, 3, 4, 5, 6])
- r = functional_ops.map_fn(
- lambda x: math_ops.multiply(math_ops.add(x, 3), 2), nums)
- self.assertAllEqual(
- np.array([(x + 3) * 2 for x in nums]), self.evaluate(r))
+ nums = np.array([1, 2, 3, 4, 5, 6])
+ r = functional_ops.map_fn(
+ lambda x: math_ops.multiply(math_ops.add(x, 3), 2), nums)
+ self.assertAllEqual(
+ np.array([(x + 3) * 2 for x in nums]), self.evaluate(r))
@test_util.run_in_graph_and_eager_modes
def testMap_SingleInputMultiOutput(self):
- with self.test_session():
- nums = np.array([1, 2, 3, 4, 5, 6])
- r = functional_ops.map_fn(
- lambda x: ((x + 3) * 2, -(x + 3) * 2),
- nums,
- dtype=(dtypes.int64, dtypes.int64))
- self.assertEqual(2, len(r))
- self.assertEqual((6,), r[0].get_shape())
- self.assertEqual((6,), r[1].get_shape())
- received = self.evaluate(r)
- self.assertAllEqual((nums + 3) * 2, received[0])
- self.assertAllEqual(-(nums + 3) * 2, received[1])
+ nums = np.array([1, 2, 3, 4, 5, 6])
+ r = functional_ops.map_fn(
+ lambda x: ((x + 3) * 2, -(x + 3) * 2),
+ nums,
+ dtype=(dtypes.int64, dtypes.int64))
+ self.assertEqual(2, len(r))
+ self.assertEqual((6,), r[0].get_shape())
+ self.assertEqual((6,), r[1].get_shape())
+ received = self.evaluate(r)
+ self.assertAllEqual((nums + 3) * 2, received[0])
+ self.assertAllEqual(-(nums + 3) * 2, received[1])
@test_util.run_in_graph_and_eager_modes
def testMap_MultiOutputMismatchedDtype(self):
- with self.test_session():
- nums = np.array([1, 2, 3, 4, 5, 6])
- with self.assertRaisesRegexp(
- TypeError, r"two structures don't have the same nested structure"):
- # lambda emits tuple, but dtype is a list
- functional_ops.map_fn(
- lambda x: ((x + 3) * 2, -(x + 3) * 2),
- nums,
- dtype=[dtypes.int64, dtypes.int64])
+ nums = np.array([1, 2, 3, 4, 5, 6])
+ with self.assertRaisesRegexp(
+ TypeError, r"two structures don't have the same nested structure"):
+ # lambda emits tuple, but dtype is a list
+ functional_ops.map_fn(
+ lambda x: ((x + 3) * 2, -(x + 3) * 2),
+ nums,
+ dtype=[dtypes.int64, dtypes.int64])
@test_util.run_in_graph_and_eager_modes
def testMap_MultiInputSingleOutput(self):
- with self.test_session():
- nums = np.array([1, 2, 3, 4, 5, 6])
- r = functional_ops.map_fn(
- lambda x: x[0] * x[1][0] + x[1][1], (nums, (nums, -nums)),
- dtype=dtypes.int64)
- self.assertEqual((6,), r.get_shape())
- received = self.evaluate(r)
- self.assertAllEqual(nums * nums + (-nums), received)
+ nums = np.array([1, 2, 3, 4, 5, 6])
+ r = functional_ops.map_fn(
+ lambda x: x[0] * x[1][0] + x[1][1], (nums, (nums, -nums)),
+ dtype=dtypes.int64)
+ self.assertEqual((6,), r.get_shape())
+ received = self.evaluate(r)
+ self.assertAllEqual(nums * nums + (-nums), received)
@test_util.run_in_graph_and_eager_modes
def testMap_MultiInputSameStructureOutput(self):
- with self.test_session():
- nums = np.array([1, 2, 3, 4, 5, 6])
- r = functional_ops.map_fn(lambda x: (x[1][0], (x[1][1], x[0])),
- (nums, (2 * nums, -nums)))
- r = [r[0], r[1][0], r[1][1]]
- self.assertEqual((6,), r[0].get_shape())
- self.assertEqual((6,), r[1].get_shape())
- self.assertEqual((6,), r[2].get_shape())
- received = self.evaluate(r)
- self.assertAllEqual(2 * nums, received[0])
- self.assertAllEqual(-nums, received[1])
- self.assertAllEqual(nums, received[2])
+ nums = np.array([1, 2, 3, 4, 5, 6])
+ r = functional_ops.map_fn(lambda x: (x[1][0], (x[1][1], x[0])),
+ (nums, (2 * nums, -nums)))
+ r = [r[0], r[1][0], r[1][1]]
+ self.assertEqual((6,), r[0].get_shape())
+ self.assertEqual((6,), r[1].get_shape())
+ self.assertEqual((6,), r[2].get_shape())
+ received = self.evaluate(r)
+ self.assertAllEqual(2 * nums, received[0])
+ self.assertAllEqual(-nums, received[1])
+ self.assertAllEqual(nums, received[2])
@test_util.run_in_graph_and_eager_modes
def testScan_Simple(self):
- with self.test_session():
- elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
- v = constant_op.constant(2.0, name="v")
+ elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
+ v = constant_op.constant(2.0, name="v")
- # pylint: disable=unnecessary-lambda
- r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems)
- self.assertAllEqual([1., 2., 6., 24., 120., 720.], self.evaluate(r))
+ # pylint: disable=unnecessary-lambda
+ r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems)
+ self.assertAllEqual([1., 2., 6., 24., 120., 720.], self.evaluate(r))
- r = functional_ops.scan(
- lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
- self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r))
- # pylint: enable=unnecessary-lambda
+ r = functional_ops.scan(
+ lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
+ self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r))
+ # pylint: enable=unnecessary-lambda
@test_util.run_in_graph_and_eager_modes
def testScan_Reverse(self):
- with self.test_session():
- elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
- v = constant_op.constant(2.0, name="v")
-
- # pylint: disable=unnecessary-lambda
- r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems,
- reverse=True)
- self.assertAllEqual([720., 720., 360., 120., 30., 6.], self.evaluate(r))
- r = functional_ops.scan(
- lambda a, x: math_ops.multiply(a, x), elems, initializer=v,
- reverse=True)
- self.assertAllEqual([1440., 1440., 720., 240., 60., 12.],
- self.evaluate(r))
- # pylint: enable=unnecessary-lambda
+ elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
+ v = constant_op.constant(2.0, name="v")
+
+ # pylint: disable=unnecessary-lambda
+ r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems,
+ reverse=True)
+ self.assertAllEqual([720., 720., 360., 120., 30., 6.], self.evaluate(r))
+ r = functional_ops.scan(
+ lambda a, x: math_ops.multiply(a, x), elems, initializer=v,
+ reverse=True)
+ self.assertAllEqual([1440., 1440., 720., 240., 60., 12.],
+ self.evaluate(r))
+ # pylint: enable=unnecessary-lambda
@test_util.run_in_graph_and_eager_modes
def testScan_SingleInputMultiOutput(self):
- with self.test_session():
- elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
- initializer = (np.array(1.0), np.array(-1.0))
- r = functional_ops.scan(lambda a, x: (a[0] * x, -a[1] * x), elems,
- initializer)
- r_value = self.evaluate(r)
+ elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+ initializer = (np.array(1.0), np.array(-1.0))
+ r = functional_ops.scan(lambda a, x: (a[0] * x, -a[1] * x), elems,
+ initializer)
+ r_value = self.evaluate(r)
- self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0])
- self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1])
+ self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0])
+ self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1])
@test_util.run_in_graph_and_eager_modes
def testScan_MultiInputSingleOutput(self):
- with self.test_session():
- elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
- initializer = np.array(1.0)
- # Multiply a * 1 each time
- r = functional_ops.scan(lambda a, x: a * (x[0] + x[1]),
- (elems + 1, -elems), initializer)
- self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], self.evaluate(r))
+ elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+ initializer = np.array(1.0)
+ # Multiply a * 1 each time
+ r = functional_ops.scan(lambda a, x: a * (x[0] + x[1]),
+ (elems + 1, -elems), initializer)
+ self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], self.evaluate(r))
@test_util.run_in_graph_and_eager_modes
def testScan_MultiInputSameTypeOutput(self):
- with self.test_session():
- elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
- r = functional_ops.scan(lambda a, x: (a[0] + x[0], a[1] + x[1]),
- (elems, -elems))
- r_value = self.evaluate(r)
- self.assertAllEqual(np.cumsum(elems), r_value[0])
- self.assertAllEqual(np.cumsum(-elems), r_value[1])
+ elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+ r = functional_ops.scan(lambda a, x: (a[0] + x[0], a[1] + x[1]),
+ (elems, -elems))
+ r_value = self.evaluate(r)
+ self.assertAllEqual(np.cumsum(elems), r_value[0])
+ self.assertAllEqual(np.cumsum(-elems), r_value[1])
@test_util.run_in_graph_and_eager_modes
def testScan_MultiOutputMismatchedInitializer(self):
- with self.test_session():
- elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
- initializer = np.array(1.0)
- # Multiply a * 1 each time
- with self.assertRaisesRegexp(
- ValueError, "two structures don't have the same nested structure"):
- functional_ops.scan(lambda a, x: (a, -a), elems, initializer)
+ elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+ initializer = np.array(1.0)
+    # fn emits a tuple but the initializer is a scalar, so the nest
+    # structures mismatch and scan should raise
+ with self.assertRaisesRegexp(
+ ValueError, "two structures don't have the same nested structure"):
+ functional_ops.scan(lambda a, x: (a, -a), elems, initializer)
def testScan_Scoped(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope("root") as varscope:
elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
@@ -411,30 +403,29 @@ class FunctionalOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testScanFoldl_Nested(self):
- with self.test_session():
- elems = constant_op.constant([1.0, 2.0, 3.0, 4.0], name="data")
- inner_elems = constant_op.constant([0.5, 0.5], name="data")
-
- def r_inner(a, x):
- return functional_ops.foldl(
- lambda b, y: b * y * x, inner_elems, initializer=a)
-
- r = functional_ops.scan(r_inner, elems)
-
- # t == 0 (returns 1)
- # t == 1, a == 1, x == 2 (returns 1)
- # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1
- # t_1 == 1, b == 1, y == 0.5, returns b * y * x = 1
- # t == 2, a == 1, x == 3 (returns 1.5*1.5 == 2.25)
- # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1.5
- # t_1 == 1, b == 1.5, y == 0.5, returns b * y * x = 1.5*1.5
- # t == 3, a == 2.25, x == 4 (returns 9)
- # t_0 == 0, b == a == 2.25, y == 0.5, returns b * y * x = 4.5
- # t_1 == 1, b == 4.5, y == 0.5, returns b * y * x = 9
- self.assertAllClose([1., 1., 2.25, 9.], self.evaluate(r))
+ elems = constant_op.constant([1.0, 2.0, 3.0, 4.0], name="data")
+ inner_elems = constant_op.constant([0.5, 0.5], name="data")
+
+ def r_inner(a, x):
+ return functional_ops.foldl(
+ lambda b, y: b * y * x, inner_elems, initializer=a)
+
+ r = functional_ops.scan(r_inner, elems)
+
+ # t == 0 (returns 1)
+ # t == 1, a == 1, x == 2 (returns 1)
+ # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1
+ # t_1 == 1, b == 1, y == 0.5, returns b * y * x = 1
+ # t == 2, a == 1, x == 3 (returns 1.5*1.5 == 2.25)
+ # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1.5
+ # t_1 == 1, b == 1.5, y == 0.5, returns b * y * x = 1.5*1.5
+ # t == 3, a == 2.25, x == 4 (returns 9)
+ # t_0 == 0, b == a == 2.25, y == 0.5, returns b * y * x = 4.5
+ # t_1 == 1, b == 4.5, y == 0.5, returns b * y * x = 9
+ self.assertAllClose([1., 1., 2.25, 9.], self.evaluate(r))
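The traced values are easy to reproduce outside TF; a plain-Python sketch of the nested scan/foldl recurrence, using the same elems and inner_elems as the test:

    def foldl(fn, xs, init):
        acc = init
        for x in xs:
            acc = fn(acc, x)
        return acc

    def scan(fn, xs):
        acc, out = xs[0], [xs[0]]  # no initializer: first element seeds acc
        for x in xs[1:]:
            acc = fn(acc, x)
            out.append(acc)
        return out

    elems, inner_elems = [1.0, 2.0, 3.0, 4.0], [0.5, 0.5]
    r = scan(lambda a, x: foldl(lambda b, y: b * y * x, inner_elems, a), elems)
    print(r)  # [1.0, 1.0, 2.25, 9.0]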
def testScan_Control(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
s = array_ops.placeholder(dtypes.float32, shape=[None])
b = array_ops.placeholder(dtypes.bool)
@@ -445,7 +436,7 @@ class FunctionalOpsTest(test.TestCase):
b: True}))
def testScan_Grad(self):
- with self.test_session():
+ with self.cached_session():
elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
v = constant_op.constant(2.0, name="v")
@@ -470,22 +461,20 @@ class FunctionalOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testFoldShape(self):
- with self.test_session():
- x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
+ x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
- def fn(_, current_input):
- return current_input
+ def fn(_, current_input):
+ return current_input
- initializer = constant_op.constant([0, 0, 0])
- y = functional_ops.foldl(fn, x, initializer=initializer)
- self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
+ initializer = constant_op.constant([0, 0, 0])
+ y = functional_ops.foldl(fn, x, initializer=initializer)
+ self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
@test_util.run_in_graph_and_eager_modes
def testMapShape(self):
- with self.test_session():
- x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
- y = functional_ops.map_fn(lambda e: e, x)
- self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
+ x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
+ y = functional_ops.map_fn(lambda e: e, x)
+ self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
def testMapUnknownShape(self):
x = array_ops.placeholder(dtypes.float32)
@@ -494,15 +483,14 @@ class FunctionalOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testMapEmptyScalar(self):
- with self.test_session():
- map_return = functional_ops.map_fn(lambda x: 1, constant_op.constant([]))
- self.assertAllEqual([0], map_return.get_shape().dims)
- self.assertAllEqual([0], self.evaluate(map_return).shape)
+ map_return = functional_ops.map_fn(lambda x: 1, constant_op.constant([]))
+ self.assertAllEqual([0], map_return.get_shape().dims)
+ self.assertAllEqual([0], self.evaluate(map_return).shape)
  # TODO(akshayka): this test fails in eager: the iterable is of length 0,
  # so the body of the while loop never executes
def testMapEmptyTensor(self):
- with self.test_session():
+ with self.cached_session():
map_return = functional_ops.map_fn(lambda x: array_ops.zeros([3, 2]),
constant_op.constant([]))
self.assertAllEqual([0, 3, 2], map_return.get_shape().dims)
@@ -510,20 +498,19 @@ class FunctionalOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testScanShape(self):
- with self.test_session():
- x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
+ x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
- def fn(_, current_input):
- return current_input
+ def fn(_, current_input):
+ return current_input
- initializer = constant_op.constant([0, 0, 0])
- y = functional_ops.scan(fn, x, initializer=initializer)
- self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
+ initializer = constant_op.constant([0, 0, 0])
+ y = functional_ops.scan(fn, x, initializer=initializer)
+ self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
  # TODO(akshayka): this test fails in eager: the iterable is of length 0,
  # so the body of the while loop never executes
def testScanEmptyTensor(self):
- with self.test_session():
+ with self.cached_session():
x = functional_ops.scan(
lambda x, _: x, math_ops.range(0), initializer=array_ops.ones([2, 4]))
self.assertAllEqual([0, 2, 4], x.get_shape())
@@ -540,7 +527,7 @@ class FunctionalOpsTest(test.TestCase):
self.assertIs(None, y.get_shape().dims)
def testScanVaryingShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 2])
x_t = array_ops.transpose(x)
# scan over dimension 0 (with shape None)
@@ -619,7 +606,7 @@ class FunctionalOpsTest(test.TestCase):
remote_op = functional_ops.remote_call(
args=[a, b], Tout=[dtypes.int32], f=_remote_fn, target="/cpu:0")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
mul = sess.run(remote_op)
self.assertEqual(mul, [6])
@@ -643,7 +630,7 @@ class FunctionalOpsTest(test.TestCase):
f=_remote_fn,
target="/job:localhost/replica:0/task:0/device:GPU:0")[0] + 3.0
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
mul = sess.run(remote_op)
self.assertEqual(mul, 9.0)
@@ -667,7 +654,7 @@ class FunctionalOpsTest(test.TestCase):
f=_remote_fn,
target="/job:localhost/replica:0/task:0/cpu:0")[0] + 3.0
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
mul = sess.run(remote_op)
self.assertEqual(mul, 9.0)
@@ -686,7 +673,7 @@ class FunctionalOpsTest(test.TestCase):
remote_op = functional_ops.remote_call(
args=[a], Tout=[dtypes.string], f=_remote_fn, target="/cpu:0")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ret = sess.run(remote_op)
self.assertAllEqual(ret, [b"a"])
@@ -752,6 +739,40 @@ class FunctionalOpsTest(test.TestCase):
self.assertAllEqual(Run(sess, 20.), 210.)
self.assertAllEqual(Run(sess, 100.), 5050.)
+ def testWhileLowering(self):
+
+ def Run(n, fetch_by_name):
+ for use_gpu in (True, False):
+ with ops.Graph().as_default() as g:
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Cond(n, unused_x):
+ return n > 0
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Body(n, x):
+ return n - 1, x + n
+
+ # outputs: [0, n*(n+1)/2]
+ outputs = functional_ops.While([n, 0.], Cond, Body, name="my_while")
+
+ # `outputs` is the list of output tensors of the While op. We
+ # arbitrarily choose the 0th tensor to get the While op and set the
+ # lowering attribute on it.
+ outputs[0].op._set_attr("_lower_using_switch_merge",
+ attr_value_pb2.AttrValue(b=True))
+ if not fetch_by_name:
+ fetch = outputs[1]
+ else:
+ fetch = "my_while:1"
+ with self.test_session(graph=g, use_gpu=use_gpu) as sess:
+ return sess.run(fetch)
+
+ self.assertAllEqual(Run(20., False), 210.)
+ self.assertAllEqual(Run(20., True), 210.)
+ self.assertAllEqual(Run(100., False), 5050.)
+ self.assertAllEqual(Run(100., True), 5050.)
+
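For reference, the lowered While computes the same triangular sum either way; the `_lower_using_switch_merge` attribute changes how the graph executes, not the result. A plain-Python equivalent:

    def while_loop(n):
        x = 0.0
        while n > 0:             # Cond
            n, x = n - 1, x + n  # Body
        return n, x              # outputs: [0, n * (n + 1) / 2]

    assert while_loop(20.0) == (0.0, 210.0)
    assert while_loop(100.0) == (0.0, 5050.0)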
def testWhileError(self):
for use_gpu in (True, False):
with ops.Graph().as_default() as g:
@@ -1075,30 +1096,13 @@ class PartitionedCallTest(test.TestCase):
with ops.device("/cpu:2"):
s3 = iterator_ops.Iterator.from_structure(
(dtypes.float32,)).string_handle()
- with ops.device(""):
- # TODO(akshayka): This is unfortunate and brittle. It prevents
- # `Iterator.from_structure` from assigning the iterator op to 'cpu:0'.
- # Remove this hack once we have a way of obtaining metadata about
- # function execution.
- s4 = iterator_ops.Iterator.from_structure(
- (dtypes.float32,)).string_handle()
- return s1, s2, s3, s4
+ return s1, s2, s3
with self.test_session(config=config, use_gpu=True) as sess:
- with ops.device("/cpu:3"):
- outputs = sess.run(functional_ops.partitioned_call(args=[], f=Body))
- self.assertIn(compat.as_bytes("CPU:0"), outputs[0])
- self.assertIn(compat.as_bytes("CPU:1"), outputs[1])
- self.assertIn(compat.as_bytes("CPU:2"), outputs[2])
- self.assertIn(compat.as_bytes("CPU:3"), outputs[3])
-
- with self.test_session(config=config, use_gpu=True):
- with ops.device("/cpu:0"):
- outputs = sess.run(functional_ops.partitioned_call(args=[], f=Body))
+ outputs = sess.run(functional_ops.partitioned_call(args=[], f=Body))
self.assertIn(compat.as_bytes("CPU:0"), outputs[0])
self.assertIn(compat.as_bytes("CPU:1"), outputs[1])
self.assertIn(compat.as_bytes("CPU:2"), outputs[2])
- self.assertIn(compat.as_bytes("CPU:0"), outputs[3])
def testAssignAddResourceVariable(self):
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py
index 612a50bcec..99497914f2 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py
@@ -191,7 +191,7 @@ class NonSquareLinearOperatorCompositionTest(
linalg.LinearOperatorFullMatrix(rng.rand(2, 4, 5))
]
operator = linalg.LinearOperatorComposition(operators)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual((2, 3, 5), operator.shape_tensor().eval())
def test_shape_tensors_when_only_dynamically_available(self):
@@ -206,7 +206,7 @@ class NonSquareLinearOperatorCompositionTest(
linalg.LinearOperatorFullMatrix(mat_ph_2)
]
operator = linalg.LinearOperatorComposition(operators)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
(1, 2, 3, 5), operator.shape_tensor().eval(feed_dict=feed_dict))
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py
index 83cc8c483f..52861ae84a 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py
@@ -52,7 +52,7 @@ class LinearOperatorDiagTest(
def test_assert_positive_definite_raises_for_zero_eigenvalue(self):
# Matrix with one positive eigenvalue and one zero eigenvalue.
- with self.test_session():
+ with self.cached_session():
diag = [1.0, 0.0]
operator = linalg.LinearOperatorDiag(diag)
@@ -62,7 +62,7 @@ class LinearOperatorDiagTest(
operator.assert_positive_definite().run()
def test_assert_positive_definite_raises_for_negative_real_eigvalues(self):
- with self.test_session():
+ with self.cached_session():
diag_x = [1.0, -2.0]
diag_y = [0., 0.] # Imaginary eigenvalues should not matter.
diag = math_ops.complex(diag_x, diag_y)
@@ -74,7 +74,7 @@ class LinearOperatorDiagTest(
operator.assert_positive_definite().run()
def test_assert_positive_definite_does_not_raise_if_pd_and_complex(self):
- with self.test_session():
+ with self.cached_session():
x = [1., 2.]
y = [1., 0.]
diag = math_ops.complex(x, y) # Re[diag] > 0.
@@ -83,14 +83,14 @@ class LinearOperatorDiagTest(
def test_assert_non_singular_raises_if_zero_eigenvalue(self):
    # Singular matrix with one positive eigenvalue and one zero eigenvalue.
- with self.test_session():
+ with self.cached_session():
diag = [1.0, 0.0]
operator = linalg.LinearOperatorDiag(diag, is_self_adjoint=True)
with self.assertRaisesOpError("Singular operator"):
operator.assert_non_singular().run()
def test_assert_non_singular_does_not_raise_for_complex_nonsingular(self):
- with self.test_session():
+ with self.cached_session():
x = [1., 0.]
y = [0., 1.]
diag = math_ops.complex(x, y)
@@ -98,7 +98,7 @@ class LinearOperatorDiagTest(
linalg.LinearOperatorDiag(diag).assert_non_singular().run()
def test_assert_self_adjoint_raises_if_diag_has_complex_part(self):
- with self.test_session():
+ with self.cached_session():
x = [1., 0.]
y = [0., 1.]
diag = math_ops.complex(x, y)
@@ -107,7 +107,7 @@ class LinearOperatorDiagTest(
operator.assert_self_adjoint().run()
def test_assert_self_adjoint_does_not_raise_for_diag_with_zero_imag(self):
- with self.test_session():
+ with self.cached_session():
x = [1., 0.]
y = [0., 0.]
diag = math_ops.complex(x, y)
@@ -123,7 +123,7 @@ class LinearOperatorDiagTest(
# These cannot be done in the automated (base test class) tests since they
# test shapes that tf.matmul cannot handle.
# In particular, tf.matmul does not broadcast.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = random_ops.random_normal(shape=(2, 2, 3, 4))
# This LinearOperatorDiag will be broadcast to (2, 2, 3, 3) during solve
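The "tf.matmul does not broadcast" remark recurs throughout these linalg tests; a NumPy sketch of the batch broadcast the operators handle manually (shapes here are illustrative assumptions mirroring the comment above):

    import numpy as np

    x = np.random.randn(2, 2, 3, 4)  # batch shape (2, 2)
    diag = np.random.randn(2, 3)     # operator batch shape (2,), dimension 3
    # Diag matmul is row-wise scaling, so NumPy broadcasting stands in for
    # the batch broadcast that tf.matmul of this era refused to perform.
    y = diag[:, :, None] * x
    print(y.shape)                   # (2, 2, 3, 4)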
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py
index 1a40a29ec6..8373b5263f 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py
@@ -65,7 +65,7 @@ class SquareLinearOperatorFullMatrixTest(
self.assertTrue(operator.is_square)
def test_assert_non_singular_raises_if_cond_too_big_but_finite(self):
- with self.test_session():
+ with self.cached_session():
tril = linear_operator_test_util.random_tril_matrix(
shape=(50, 50), dtype=np.float32)
diag = np.logspace(-2, 2, 50).astype(np.float32)
@@ -80,7 +80,7 @@ class SquareLinearOperatorFullMatrixTest(
operator.assert_non_singular().run()
def test_assert_non_singular_raises_if_cond_infinite(self):
- with self.test_session():
+ with self.cached_session():
matrix = [[1., 1.], [1., 1.]]
# We don't pass the is_self_adjoint hint here, which means we take the
# generic code path.
@@ -91,14 +91,14 @@ class SquareLinearOperatorFullMatrixTest(
def test_assert_self_adjoint(self):
matrix = [[0., 1.], [0., 1.]]
operator = linalg.LinearOperatorFullMatrix(matrix)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("not equal to its adjoint"):
operator.assert_self_adjoint().run()
def test_assert_positive_definite(self):
matrix = [[1., 1.], [1., 1.]]
operator = linalg.LinearOperatorFullMatrix(matrix, is_self_adjoint=True)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Cholesky decomposition was not success"):
operator.assert_positive_definite().run()
@@ -158,7 +158,7 @@ class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest(
matrix = [[1., 1.], [1., 1.]]
operator = linalg.LinearOperatorFullMatrix(
matrix, is_self_adjoint=True, is_positive_definite=True)
- with self.test_session():
+ with self.cached_session():
# Cholesky decomposition may fail, so the error is not specific to
# non-singular.
with self.assertRaisesOpError(""):
@@ -168,7 +168,7 @@ class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest(
matrix = [[0., 1.], [0., 1.]]
operator = linalg.LinearOperatorFullMatrix(
matrix, is_self_adjoint=True, is_positive_definite=True)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("not equal to its adjoint"):
operator.assert_self_adjoint().run()
@@ -176,7 +176,7 @@ class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest(
matrix = [[1., 1.], [1., 1.]]
operator = linalg.LinearOperatorFullMatrix(
matrix, is_self_adjoint=True, is_positive_definite=True)
- with self.test_session():
+ with self.cached_session():
# Cholesky decomposition may fail, so the error is not specific to
# non-singular.
with self.assertRaisesOpError(""):
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py
index 35dcf4417c..0c3c6b390f 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py
@@ -57,24 +57,24 @@ class LinearOperatorIdentityTest(
return operator, mat
def test_assert_positive_definite(self):
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
operator.assert_positive_definite().run() # Should not fail
def test_assert_non_singular(self):
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
operator.assert_non_singular().run() # Should not fail
def test_assert_self_adjoint(self):
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
operator.assert_self_adjoint().run() # Should not fail
def test_float16_matmul(self):
# float16 cannot be tested by base test class because tf.matrix_solve does
# not work with float16.
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorIdentity(
num_rows=2, dtype=dtypes.float16)
x = rng.randn(2, 3).astype(np.float16)
@@ -106,7 +106,7 @@ class LinearOperatorIdentityTest(
linalg_lib.LinearOperatorIdentity(num_rows=2, batch_shape=[-2])
def test_non_scalar_num_rows_raises_dynamic(self):
- with self.test_session():
+ with self.cached_session():
num_rows = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorIdentity(
num_rows, assert_proper_shapes=True)
@@ -114,7 +114,7 @@ class LinearOperatorIdentityTest(
operator.to_dense().eval(feed_dict={num_rows: [2]})
def test_negative_num_rows_raises_dynamic(self):
- with self.test_session():
+ with self.cached_session():
num_rows = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorIdentity(
num_rows, assert_proper_shapes=True)
@@ -122,7 +122,7 @@ class LinearOperatorIdentityTest(
operator.to_dense().eval(feed_dict={num_rows: -2})
def test_non_1d_batch_shape_raises_dynamic(self):
- with self.test_session():
+ with self.cached_session():
batch_shape = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorIdentity(
num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
@@ -130,7 +130,7 @@ class LinearOperatorIdentityTest(
operator.to_dense().eval(feed_dict={batch_shape: 2})
def test_negative_batch_shape_raises_dynamic(self):
- with self.test_session():
+ with self.cached_session():
batch_shape = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorIdentity(
num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
@@ -147,7 +147,7 @@ class LinearOperatorIdentityTest(
num_rows = array_ops.placeholder(dtypes.int32)
x = array_ops.placeholder(dtypes.float32)
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorIdentity(
num_rows, assert_proper_shapes=True)
y = operator.matmul(x)
@@ -158,7 +158,7 @@ class LinearOperatorIdentityTest(
# These cannot be done in the automated (base test class) tests since they
# test shapes that tf.batch_matmul cannot handle.
# In particular, tf.batch_matmul does not broadcast.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = random_ops.random_normal(shape=(1, 2, 3, 4))
operator = linalg_lib.LinearOperatorIdentity(num_rows=3, dtype=x.dtype)
@@ -172,7 +172,7 @@ class LinearOperatorIdentityTest(
# These cannot be done in the automated (base test class) tests since they
# test shapes that tf.batch_matmul cannot handle.
# In particular, tf.batch_matmul does not broadcast.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32)
operator = linalg_lib.LinearOperatorIdentity(num_rows=3, dtype=x.dtype)
@@ -188,7 +188,7 @@ class LinearOperatorIdentityTest(
# These cannot be done in the automated (base test class) tests since they
# test shapes that tf.batch_matmul cannot handle.
# In particular, tf.batch_matmul does not broadcast.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Given this x and LinearOperatorIdentity shape of (2, 1, 3, 3), the
# broadcast shape of operator and 'x' is (2, 2, 3, 4)
x = random_ops.random_normal(shape=(1, 2, 3, 4))
@@ -209,7 +209,7 @@ class LinearOperatorIdentityTest(
# These cannot be done in the automated (base test class) tests since they
# test shapes that tf.batch_matmul cannot handle.
# In particular, tf.batch_matmul does not broadcast.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Given this x and LinearOperatorIdentity shape of (2, 1, 3, 3), the
# broadcast shape of operator and 'x' is (2, 2, 3, 4)
x = array_ops.placeholder(dtypes.float32)
@@ -287,39 +287,39 @@ class LinearOperatorScaledIdentityTest(
return operator, matrix
def test_assert_positive_definite_does_not_raise_when_positive(self):
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows=2, multiplier=1.)
operator.assert_positive_definite().run() # Should not fail
def test_assert_positive_definite_raises_when_negative(self):
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows=2, multiplier=-1.)
with self.assertRaisesOpError("not positive definite"):
operator.assert_positive_definite().run()
def test_assert_non_singular_does_not_raise_when_non_singular(self):
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows=2, multiplier=[1., 2., 3.])
operator.assert_non_singular().run() # Should not fail
def test_assert_non_singular_raises_when_singular(self):
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows=2, multiplier=[1., 2., 0.])
with self.assertRaisesOpError("was singular"):
operator.assert_non_singular().run()
def test_assert_self_adjoint_does_not_raise_when_self_adjoint(self):
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows=2, multiplier=[1. + 0J])
operator.assert_self_adjoint().run() # Should not fail
def test_assert_self_adjoint_raises_when_not_self_adjoint(self):
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows=2, multiplier=[1. + 1J])
with self.assertRaisesOpError("not self-adjoint"):
@@ -328,7 +328,7 @@ class LinearOperatorScaledIdentityTest(
def test_float16_matmul(self):
# float16 cannot be tested by base test class because tf.matrix_solve does
# not work with float16.
- with self.test_session():
+ with self.cached_session():
multiplier = rng.rand(3).astype(np.float16)
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows=2, multiplier=multiplier)
@@ -353,7 +353,7 @@ class LinearOperatorScaledIdentityTest(
num_rows = array_ops.placeholder(dtypes.int32)
x = array_ops.placeholder(dtypes.float32)
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows, multiplier=[1., 2], assert_proper_shapes=True)
y = operator.matmul(x)
@@ -364,7 +364,7 @@ class LinearOperatorScaledIdentityTest(
# These cannot be done in the automated (base test class) tests since they
# test shapes that tf.batch_matmul cannot handle.
# In particular, tf.batch_matmul does not broadcast.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Given this x and LinearOperatorScaledIdentity shape of (2, 1, 3, 3), the
# broadcast shape of operator and 'x' is (2, 2, 3, 4)
x = random_ops.random_normal(shape=(1, 2, 3, 4))
@@ -392,7 +392,7 @@ class LinearOperatorScaledIdentityTest(
# These cannot be done in the automated (base test class) tests since they
# test shapes that tf.batch_matmul cannot handle.
# In particular, tf.batch_matmul does not broadcast.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Given this x and LinearOperatorScaledIdentity shape of (3, 3), the
# broadcast shape of operator and 'x' is (1, 2, 3, 4), which is the same
# shape as x.
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py
index e26b946151..7e81c9c6c4 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py
@@ -70,7 +70,7 @@ class KroneckerDenseTest(test.TestCase):
[10., 15., -2., -3.],
[5., 10., -1., -2.]], dtype=dtypes.float32)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(_kronecker_dense([x, y]).eval(), z.eval())
self.assertAllClose(_kronecker_dense([y, x]).eval(), w.eval())
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py
index 0e38dbd48d..61268607a4 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py
@@ -256,7 +256,7 @@ class LinearOpearatorLowRankUpdateBroadcastsShape(test.TestCase):
# domain_dimension is 3
self.assertAllEqual([2, 3, 3], operator.shape)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([2, 3, 3], operator.to_dense().eval().shape)
def test_dynamic_shape_broadcasts_up_from_operator_to_other_args(self):
@@ -274,7 +274,7 @@ class LinearOpearatorLowRankUpdateBroadcastsShape(test.TestCase):
u_shape_ph: [2, 3, 2], # batch_shape = [2]
}
- with self.test_session():
+ with self.cached_session():
shape_tensor = operator.shape_tensor().eval(feed_dict=feed_dict)
self.assertAllEqual([2, 3, 3], shape_tensor)
dense = operator.to_dense().eval(feed_dict=feed_dict)
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py
index b389e0cbdf..eb4bff915b 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py
@@ -51,7 +51,7 @@ class LinearOperatorLowerTriangularTest(
def test_assert_non_singular(self):
    # Singular matrix with one positive eigenvalue and one zero eigenvalue.
- with self.test_session():
+ with self.cached_session():
tril = [[1., 0.], [1., 0.]]
operator = linalg.LinearOperatorLowerTriangular(tril)
with self.assertRaisesOpError("Singular operator"):
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_test.py
index 8e9f0150a2..819347343b 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_test.py
@@ -108,7 +108,7 @@ class LinearOperatorTest(test.TestCase):
self.assertAllEqual(3, operator.range_dimension)
def test_all_shape_methods_defined_by_the_one_method_shape(self):
- with self.test_session():
+ with self.cached_session():
shape = (1, 2, 3, 4)
operator = LinearOperatorShape(shape)
@@ -131,7 +131,7 @@ class LinearOperatorTest(test.TestCase):
def test_generic_to_dense_method_non_square_matrix_static(self):
matrix = rng.randn(2, 3, 4)
operator = LinearOperatorMatmulSolve(matrix)
- with self.test_session():
+ with self.cached_session():
operator_dense = operator.to_dense()
self.assertAllEqual((2, 3, 4), operator_dense.get_shape())
self.assertAllClose(matrix, operator_dense.eval())
@@ -140,7 +140,7 @@ class LinearOperatorTest(test.TestCase):
matrix = rng.randn(2, 3, 4)
matrix_ph = array_ops.placeholder(dtypes.float64)
operator = LinearOperatorMatmulSolve(matrix_ph)
- with self.test_session():
+ with self.cached_session():
operator_dense = operator.to_dense()
self.assertAllClose(
matrix, operator_dense.eval(feed_dict={matrix_ph: matrix}))
@@ -149,7 +149,7 @@ class LinearOperatorTest(test.TestCase):
matrix = [[1., 0], [0., 2.]]
operator = LinearOperatorMatmulSolve(matrix)
x = [1., 1.]
- with self.test_session():
+ with self.cached_session():
y = operator.matvec(x)
self.assertAllEqual((2,), y.get_shape())
self.assertAllClose([1., 2.], y.eval())
@@ -158,7 +158,7 @@ class LinearOperatorTest(test.TestCase):
matrix = [[1., 0], [0., 2.]]
operator = LinearOperatorMatmulSolve(matrix)
y = [1., 1.]
- with self.test_session():
+ with self.cached_session():
x = operator.solvevec(y)
self.assertAllEqual((2,), x.get_shape())
self.assertAllClose([1., 1 / 2.], x.eval())
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py
index 7b291e29de..86847d38c2 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py
@@ -36,7 +36,7 @@ class AssertZeroImagPartTest(test.TestCase):
def test_real_tensor_doesnt_raise(self):
x = ops.convert_to_tensor([0., 2, 3])
- with self.test_session():
+ with self.cached_session():
# Should not raise.
linear_operator_util.assert_zero_imag_part(x, message="ABC123").run()
@@ -44,7 +44,7 @@ class AssertZeroImagPartTest(test.TestCase):
x = ops.convert_to_tensor([1., 0, 3])
y = ops.convert_to_tensor([0., 0, 0])
z = math_ops.complex(x, y)
- with self.test_session():
+ with self.cached_session():
# Should not raise.
linear_operator_util.assert_zero_imag_part(z, message="ABC123").run()
@@ -52,7 +52,7 @@ class AssertZeroImagPartTest(test.TestCase):
x = ops.convert_to_tensor([1., 2, 0])
y = ops.convert_to_tensor([1., 2, 0])
z = math_ops.complex(x, y)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("ABC123"):
linear_operator_util.assert_zero_imag_part(z, message="ABC123").run()
@@ -61,7 +61,7 @@ class AssertNoEntriesWithModulusZeroTest(test.TestCase):
def test_nonzero_real_tensor_doesnt_raise(self):
x = ops.convert_to_tensor([1., 2, 3])
- with self.test_session():
+ with self.cached_session():
# Should not raise.
linear_operator_util.assert_no_entries_with_modulus_zero(
x, message="ABC123").run()
@@ -70,14 +70,14 @@ class AssertNoEntriesWithModulusZeroTest(test.TestCase):
x = ops.convert_to_tensor([1., 0, 3])
y = ops.convert_to_tensor([1., 2, 0])
z = math_ops.complex(x, y)
- with self.test_session():
+ with self.cached_session():
# Should not raise.
linear_operator_util.assert_no_entries_with_modulus_zero(
z, message="ABC123").run()
def test_zero_real_tensor_raises(self):
x = ops.convert_to_tensor([1., 0, 3])
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("ABC123"):
linear_operator_util.assert_no_entries_with_modulus_zero(
x, message="ABC123").run()
@@ -86,7 +86,7 @@ class AssertNoEntriesWithModulusZeroTest(test.TestCase):
x = ops.convert_to_tensor([1., 2, 0])
y = ops.convert_to_tensor([1., 2, 0])
z = math_ops.complex(x, y)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("ABC123"):
linear_operator_util.assert_no_entries_with_modulus_zero(
z, message="ABC123").run()
@@ -103,7 +103,7 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
tensor, = linear_operator_util.broadcast_matrix_batch_dims([arr])
self.assertTrue(isinstance(tensor, ops.Tensor))
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(arr, tensor.eval())
def test_static_dims_broadcast(self):
@@ -118,7 +118,7 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
x_bc_, y_bc_ = sess.run([x_bc, y_bc])
@@ -137,7 +137,7 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
x_bc_, y_bc_ = sess.run([x_bc, y_bc])
@@ -159,7 +159,7 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x_ph, y_ph])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x_bc_, y_bc_ = sess.run([x_bc, y_bc], feed_dict={x_ph: x, y_ph: y})
self.assertAllClose(x_bc_expected, x_bc_)
self.assertAllClose(y_bc_expected, y_bc_)
@@ -179,7 +179,7 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x_ph, y_ph])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x_bc_, y_bc_ = sess.run([x_bc, y_bc], feed_dict={x_ph: x, y_ph: y})
self.assertAllClose(x_bc_expected, x_bc_)
self.assertAllClose(y_bc_expected, y_bc_)
@@ -203,7 +203,7 @@ class CholeskySolveWithBroadcastTest(test.TestCase):
rhs = rng.rand(2, 3, 7)
chol_broadcast = chol + np.zeros((2, 1, 1))
- with self.test_session():
+ with self.cached_session():
result = linear_operator_util.cholesky_solve_with_broadcast(chol, rhs)
self.assertAllEqual((2, 3, 7), result.get_shape())
expected = linalg_ops.cholesky_solve(chol_broadcast, rhs)
@@ -219,7 +219,7 @@ class CholeskySolveWithBroadcastTest(test.TestCase):
chol_ph = array_ops.placeholder(dtypes.float64)
rhs_ph = array_ops.placeholder(dtypes.float64)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result, expected = sess.run(
[
linear_operator_util.cholesky_solve_with_broadcast(
@@ -242,7 +242,7 @@ class MatmulWithBroadcastTest(test.TestCase):
y = rng.rand(3, 7)
y_broadcast = y + np.zeros((2, 1, 1))
- with self.test_session():
+ with self.cached_session():
result = linear_operator_util.matmul_with_broadcast(x, y)
self.assertAllEqual((2, 1, 7), result.get_shape())
expected = math_ops.matmul(x, y_broadcast)
@@ -258,7 +258,7 @@ class MatmulWithBroadcastTest(test.TestCase):
x_ph = array_ops.placeholder(dtypes.float64)
y_ph = array_ops.placeholder(dtypes.float64)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result, expected = sess.run(
[
linear_operator_util.matmul_with_broadcast(x_ph, y_ph),
@@ -279,7 +279,7 @@ class MatrixSolveWithBroadcastTest(test.TestCase):
rhs = rng.rand(2, 3, 7)
matrix_broadcast = matrix + np.zeros((2, 1, 1))
- with self.test_session():
+ with self.cached_session():
result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
self.assertAllEqual((2, 3, 7), result.get_shape())
expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
@@ -295,7 +295,7 @@ class MatrixSolveWithBroadcastTest(test.TestCase):
matrix_ph = array_ops.placeholder(dtypes.float64)
rhs_ph = array_ops.placeholder(dtypes.float64)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result, expected = sess.run(
[
linear_operator_util.matrix_solve_with_broadcast(
@@ -317,7 +317,7 @@ class MatrixTriangularSolveWithBroadcastTest(test.TestCase):
rhs = rng.rand(3, 7)
rhs_broadcast = rhs + np.zeros((2, 1, 1))
- with self.test_session():
+ with self.cached_session():
result = linear_operator_util.matrix_triangular_solve_with_broadcast(
matrix, rhs)
self.assertAllEqual((2, 3, 7), result.get_shape())
@@ -333,7 +333,7 @@ class MatrixTriangularSolveWithBroadcastTest(test.TestCase):
matrix_ph = array_ops.placeholder(dtypes.float64)
rhs_ph = array_ops.placeholder(dtypes.float64)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result, expected = sess.run(
[
linear_operator_util.matrix_triangular_solve_with_broadcast(
@@ -359,7 +359,7 @@ class DomainDimensionStubOperator(object):
class AssertCompatibleMatrixDimensionsTest(test.TestCase):
def test_compatible_dimensions_do_not_raise(self):
- with self.test_session():
+ with self.cached_session():
x = ops.convert_to_tensor(rng.rand(2, 3, 4))
operator = DomainDimensionStubOperator(3)
# Should not raise
@@ -367,7 +367,7 @@ class AssertCompatibleMatrixDimensionsTest(test.TestCase):
operator, x).run() # pyformat: disable
def test_incompatible_dimensions_raise(self):
- with self.test_session():
+ with self.cached_session():
x = ops.convert_to_tensor(rng.rand(2, 4, 4))
operator = DomainDimensionStubOperator(3)
with self.assertRaisesOpError("Incompatible matrix dimensions"):
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_zeros_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_zeros_test.py
index 8f60b55e0a..f0556304ad 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_zeros_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_zeros_test.py
@@ -73,7 +73,7 @@ class LinearOperatorZerosTest(
operator.assert_non_singular()
def test_assert_self_adjoint(self):
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorZeros(num_rows=2)
operator.assert_self_adjoint().run() # Should not fail
@@ -108,7 +108,7 @@ class LinearOperatorZerosTest(
linalg_lib.LinearOperatorZeros(num_rows=2, batch_shape=[-2])
def test_non_scalar_num_rows_raises_dynamic(self):
- with self.test_session():
+ with self.cached_session():
num_rows = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorZeros(
num_rows, assert_proper_shapes=True)
@@ -116,7 +116,7 @@ class LinearOperatorZerosTest(
operator.to_dense().eval(feed_dict={num_rows: [2]})
def test_negative_num_rows_raises_dynamic(self):
- with self.test_session():
+ with self.cached_session():
n = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorZeros(
num_rows=n, assert_proper_shapes=True)
@@ -129,7 +129,7 @@ class LinearOperatorZerosTest(
operator.to_dense().eval(feed_dict={n: -2})
def test_non_1d_batch_shape_raises_dynamic(self):
- with self.test_session():
+ with self.cached_session():
batch_shape = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorZeros(
num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
@@ -137,7 +137,7 @@ class LinearOperatorZerosTest(
operator.to_dense().eval(feed_dict={batch_shape: 2})
def test_negative_batch_shape_raises_dynamic(self):
- with self.test_session():
+ with self.cached_session():
batch_shape = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorZeros(
num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
@@ -154,7 +154,7 @@ class LinearOperatorZerosTest(
num_rows = array_ops.placeholder(dtypes.int32)
x = array_ops.placeholder(dtypes.float32)
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorZeros(
num_rows, assert_proper_shapes=True)
y = operator.matmul(x)
diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py
index bf82e08551..0f5607712b 100644
--- a/tensorflow/python/kernel_tests/list_ops_test.py
+++ b/tensorflow/python/kernel_tests/list_ops_test.py
@@ -71,6 +71,36 @@ class ListOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual(self.evaluate(t), [1.0, 2.0])
@test_util.run_in_graph_and_eager_modes
+ def testGatherGrad(self):
+ with backprop.GradientTape() as tape:
+ l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
+ element_shape=scalar_shape())
+ c0 = constant_op.constant(1.0)
+ tape.watch(c0)
+ l = list_ops.tensor_list_push_back(l, c0)
+ l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
+ t = list_ops.tensor_list_gather(l, [1, 0], element_dtype=dtypes.float32)
+ self.assertAllEqual(self.evaluate(t), [2.0, 1.0])
+ s = (t[0] + t[1]) * (t[0] + t[1])
+ dt = tape.gradient(s, c0)
+ self.assertAllEqual(self.evaluate(dt), 6.0)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testScatterGrad(self):
+ with backprop.GradientTape() as tape:
+ c0 = constant_op.constant([1.0, 2.0])
+ tape.watch(c0)
+ l = list_ops.tensor_list_scatter(
+ c0, [1, 0], ops.convert_to_tensor([], dtype=dtypes.int32))
+ t0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
+ t1 = list_ops.tensor_list_get_item(l, 1, element_dtype=dtypes.float32)
+ self.assertAllEqual(self.evaluate(t0), 2.0)
+ self.assertAllEqual(self.evaluate(t1), 1.0)
+ loss = t0 * t0 + t1 * t1
+ dt = tape.gradient(loss, c0)
+ self.assertAllEqual(self.evaluate(dt), [2., 4.])
+
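Both gradient assertions above check out by hand; a quick sketch independent of the tape:

    # Gather: t = [2.0, 1.0], s = (t[0] + t[1])**2 = 9.0; c0 enters once
    # (as t[1], via indices [1, 0]), so ds/dc0 = 2 * (t[0] + t[1]) = 6.0.
    print(2 * (2.0 + 1.0))     # 6.0
    # Scatter: indices [1, 0] give t0 = c0[1] and t1 = c0[0]; for
    # loss = t0**2 + t1**2, dloss/dc0 = [2 * t1, 2 * t0] = [2.0, 4.0].
    print([2 * 1.0, 2 * 2.0])  # [2.0, 4.0]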
+ @test_util.run_in_graph_and_eager_modes
def testStackGPU(self):
if not context.num_gpus():
return
@@ -140,9 +170,8 @@ class ListOpsTest(test_util.TensorFlowTestCase):
list_ops.tensor_list_pop_back(
l_cpu, element_dtype=dtypes.float32)[1]), 2.0)
- @test_util.run_in_graph_and_eager_modes
def testGraphStack(self):
- with context.graph_mode(), self.test_session():
+ with self.cached_session():
tl = list_ops.empty_tensor_list(
element_shape=constant_op.constant([1], dtype=dtypes.int32),
element_dtype=dtypes.int32)
@@ -152,9 +181,8 @@ class ListOpsTest(test_util.TensorFlowTestCase):
list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32)),
[[1]])
- @test_util.run_in_graph_and_eager_modes
def testGraphStackInLoop(self):
- with context.graph_mode(), self.test_session():
+ with self.cached_session():
t1 = list_ops.empty_tensor_list(
element_shape=constant_op.constant([], dtype=dtypes.int32),
element_dtype=dtypes.int32)
@@ -170,9 +198,8 @@ class ListOpsTest(test_util.TensorFlowTestCase):
s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.int32)
self.assertAllEqual(self.evaluate(s1), [0, 1, 2, 3])
- @test_util.run_in_graph_and_eager_modes
def testGraphStackSwitchDtype(self):
- with context.graph_mode(), self.test_session():
+ with self.cached_session():
list_ = list_ops.empty_tensor_list(
element_shape=constant_op.constant([], dtype=dtypes.int32),
element_dtype=dtypes.int32)
@@ -192,9 +219,8 @@ class ListOpsTest(test_util.TensorFlowTestCase):
np_s1 = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)
self.assertAllEqual(self.evaluate(s1), np_s1)
- @test_util.run_in_graph_and_eager_modes
def testGraphStackInLoopSwitchDtype(self):
- with context.graph_mode(), self.test_session():
+ with self.cached_session():
t1 = list_ops.empty_tensor_list(
element_shape=constant_op.constant([], dtype=dtypes.int32),
element_dtype=dtypes.int32)
@@ -421,6 +447,72 @@ class ListOpsTest(test_util.TensorFlowTestCase):
"Invalid data type at index 0"):
self.evaluate(list_ops.tensor_list_push_back_batch(l_batch, [3, 4]))
+ @test_util.run_in_graph_and_eager_modes
+ def testZerosLike(self):
+ for dtype in (dtypes.uint8, dtypes.uint16, dtypes.int8, dtypes.int16,
+ dtypes.int32, dtypes.int64, dtypes.float16, dtypes.float32,
+ dtypes.float64, dtypes.complex64, dtypes.complex128,
+ dtypes.bool):
+ l_empty = list_ops.empty_tensor_list(
+ element_dtype=dtype, element_shape=scalar_shape())
+ l_empty_zeros = array_ops.zeros_like(l_empty)
+ t_empty_zeros = list_ops.tensor_list_stack(
+ l_empty_zeros, element_dtype=dtype)
+
+ l_full = list_ops.tensor_list_push_back(l_empty,
+ math_ops.cast(0, dtype=dtype))
+ l_full = list_ops.tensor_list_push_back(l_full,
+ math_ops.cast(1, dtype=dtype))
+ l_full_zeros = array_ops.zeros_like(l_full)
+ t_full_zeros = list_ops.tensor_list_stack(
+ l_full_zeros, element_dtype=dtype)
+
+ self.assertAllEqual(self.evaluate(t_empty_zeros), [])
+ self.assertAllEqual(
+ self.evaluate(t_full_zeros), np.zeros(
+ (2,), dtype=dtype.as_numpy_dtype))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testZerosLikeVariant(self):
+ for dtype in (dtypes.uint8, dtypes.uint16, dtypes.int8, dtypes.int16,
+ dtypes.int32, dtypes.int64, dtypes.float16, dtypes.float32,
+ dtypes.float64, dtypes.complex64, dtypes.complex128,
+ dtypes.bool):
+ l = list_ops.empty_tensor_list(
+ element_dtype=dtypes.variant, element_shape=scalar_shape())
+
+ sub_l = list_ops.empty_tensor_list(
+ element_dtype=dtype, element_shape=scalar_shape())
+ l = list_ops.tensor_list_push_back(l, sub_l)
+ sub_l = list_ops.tensor_list_push_back(sub_l, math_ops.cast(
+ 1, dtype=dtype))
+ l = list_ops.tensor_list_push_back(l, sub_l)
+ sub_l = list_ops.tensor_list_push_back(sub_l, math_ops.cast(
+ 2, dtype=dtype))
+ l = list_ops.tensor_list_push_back(l, sub_l)
+
+ # l : [[],
+ # [1],
+ # [1, 2]]
+ #
+ # l_zeros : [[],
+ # [0],
+ # [0, 0]]
+ l_zeros = array_ops.zeros_like(l)
+
+ outputs = []
+ for _ in range(3):
+ l_zeros, out = list_ops.tensor_list_pop_back(
+ l_zeros, element_dtype=dtypes.variant)
+ outputs.append(list_ops.tensor_list_stack(out, element_dtype=dtype))
+
+ # Note: `outputs` contains popped values so the order is reversed.
+ self.assertAllEqual(self.evaluate(outputs[2]), [])
+ self.assertAllEqual(
+ self.evaluate(outputs[1]), np.zeros((1,), dtype=dtype.as_numpy_dtype))
+ self.assertAllEqual(
+ self.evaluate(outputs[0]), np.zeros((2,), dtype=dtype.as_numpy_dtype))
+
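In plain-Python terms, zeros_like on a variant list of lists keeps the ragged structure and zeroes every leaf, which is what the popped sublists verify:

    l = [[], [1], [1, 2]]
    l_zeros = [[0 * v for v in sub] for sub in l]
    print(l_zeros)  # [[], [0], [0, 0]]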
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py b/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py
index 24edc4f59f..723a15fbd1 100644
--- a/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
+from tensorflow.python.ops.linalg import linalg_impl
from tensorflow.python.platform import test
@@ -39,7 +40,7 @@ class LogarithmOpTest(test.TestCase):
inp = x.astype(np_type)
with self.test_session(use_gpu=True):
# Verify that expm(logm(A)) == A.
- tf_ans = gen_linalg_ops.matrix_exponential(
+ tf_ans = linalg_impl.matrix_exponential(
gen_linalg_ops.matrix_logarithm(inp))
out = tf_ans.eval()
self.assertAllClose(inp, out, rtol=1e-4, atol=1e-3)
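The expm(logm(A)) round trip also holds in SciPy, which makes for a quick sanity check outside TF (assuming SciPy is available; this is not part of the test):

    import numpy as np
    from scipy.linalg import expm, logm

    a = np.array([[2.0, 1.0], [0.0, 3.0]])  # positive eigenvalues, real logm
    np.testing.assert_allclose(expm(logm(a)), a, rtol=1e-4, atol=1e-3)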
@@ -98,16 +99,25 @@ class LogarithmOpTest(test.TestCase):
self._verifyLogarithmComplex(np.empty([0, 2, 2], dtype=np.complex64))
self._verifyLogarithmComplex(np.empty([2, 0, 0], dtype=np.complex64))
- def testRandomSmallAndLarge(self):
+ def testRandomSmallAndLargeComplex64(self):
np.random.seed(42)
- for dtype in np.complex64, np.complex128:
- for batch_dims in [(), (1,), (3,), (2, 2)]:
- for size in 8, 31, 32:
- shape = batch_dims + (size, size)
- matrix = np.random.uniform(
- low=-1.0, high=1.0,
- size=np.prod(shape)).reshape(shape).astype(dtype)
- self._verifyLogarithmComplex(matrix)
+ for batch_dims in [(), (1,), (3,), (2, 2)]:
+ for size in 8, 31, 32:
+ shape = batch_dims + (size, size)
+ matrix = np.random.uniform(
+ low=-1.0, high=1.0,
+ size=np.prod(shape)).reshape(shape).astype(np.complex64)
+ self._verifyLogarithmComplex(matrix)
+
+ def testRandomSmallAndLargeComplex128(self):
+ np.random.seed(42)
+ for batch_dims in [(), (1,), (3,), (2, 2)]:
+ for size in 8, 31, 32:
+ shape = batch_dims + (size, size)
+ matrix = np.random.uniform(
+ low=-1.0, high=1.0,
+ size=np.prod(shape)).reshape(shape).astype(np.complex128)
+ self._verifyLogarithmComplex(matrix)
def testConcurrentExecutesWithoutError(self):
with self.test_session(use_gpu=True) as sess:
diff --git a/tensorflow/python/kernel_tests/parameterized_truncated_normal_op_test.py b/tensorflow/python/kernel_tests/parameterized_truncated_normal_op_test.py
index dd67919f69..e14894cf56 100644
--- a/tensorflow/python/kernel_tests/parameterized_truncated_normal_op_test.py
+++ b/tensorflow/python/kernel_tests/parameterized_truncated_normal_op_test.py
@@ -182,6 +182,19 @@ class ParameterizedTruncatedNormalTest(test.TestCase):
def testSmallStddev(self):
self.validateKolmogorovSmirnov([10**5], 0.0, 0.1, 0.05, 0.10)
+ def testSamplingWithSmallStdDevFarFromBound(self):
+ sample_op = random_ops.parameterized_truncated_normal(
+ shape=(int(1e5),), means=0.8, stddevs=0.05, minvals=-1., maxvals=1.)
+
+ with self.test_session(use_gpu=True) as sess:
+ samples = sess.run(sample_op)
+ # 0. is more than 16 standard deviations from the mean, and
+ # should have a likelihood < 1e-57.
+      # TODO(jjhunt) Sampler is still numerically unstable in this case;
+      # numbers less than 0 should never be observed.
+ no_neg_samples = np.sum(samples < 0.)
+ self.assertLess(no_neg_samples, 2.)
+
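The "< 1e-57" figure in the comment follows from the standard Gaussian tail bound P(Z < -t) <= exp(-t^2/2) / (t * sqrt(2*pi)); a back-of-envelope check:

    import math

    t = (0.8 - 0.0) / 0.05  # the bound at 0.0 sits 16 stddevs below the mean
    bound = math.exp(-t * t / 2) / (t * math.sqrt(2 * math.pi))
    print(bound)            # ~6e-58, below the 1e-57 quoted above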
# Benchmarking code
def parameterized_vs_naive(shape, num_iters, use_gpu=False):
diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py
index 59b3ee2013..7dff4501cc 100644
--- a/tensorflow/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/python/kernel_tests/parsing_ops_test.py
@@ -60,8 +60,9 @@ def flatten(list_of_lists):
def flatten_values_tensors_or_sparse(tensors_list):
"""Flatten each SparseTensor object into 3 Tensors for session.run()."""
return list(
- flatten([[v.indices, v.values, v.dense_shape] if isinstance(
- v, sparse_tensor.SparseTensor) else [v] for v in tensors_list]))
+ flatten([[v.indices, v.values, v.dense_shape]
+ if isinstance(v, sparse_tensor.SparseTensor) else [v]
+ for v in tensors_list]))
def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
@@ -106,8 +107,9 @@ class ParseExampleTest(test.TestCase):
# Check shapes; if serialized is a Tensor we need its size to
# properly check.
serialized = kwargs["serialized"]
- batch_size = (serialized.eval().size if isinstance(serialized, ops.Tensor)
- else np.asarray(serialized).size)
+ batch_size = (
+ serialized.eval().size if isinstance(serialized, ops.Tensor) else
+ np.asarray(serialized).size)
for k, f in kwargs["features"].items():
if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None:
self.assertEqual(
@@ -129,12 +131,9 @@ class ParseExampleTest(test.TestCase):
c_default = np.random.rand(2).astype(np.float32)
expected_st_a = ( # indices, values, shape
- np.empty(
- (0, 2), dtype=np.int64), # indices
- np.empty(
- (0,), dtype=np.int64), # sp_a is DT_INT64
- np.array(
- [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
+ np.empty((0, 2), dtype=np.int64), # indices
+ np.empty((0,), dtype=np.int64), # sp_a is DT_INT64
+ np.array([2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
expected_output = {
sparse_name: expected_st_a,
@@ -143,28 +142,23 @@ class ParseExampleTest(test.TestCase):
c_name: np.array(2 * [c_default]),
}
- self._test(
- {
- "example_names":
- np.empty(
- (0,), dtype=bytes),
- "serialized":
- ops.convert_to_tensor(["", ""]),
- "features": {
- sparse_name:
- parsing_ops.VarLenFeature(dtypes.int64),
- a_name:
- parsing_ops.FixedLenFeature(
- (1, 3), dtypes.int64, default_value=a_default),
- b_name:
- parsing_ops.FixedLenFeature(
- (3, 3), dtypes.string, default_value=b_default),
- c_name:
- parsing_ops.FixedLenFeature(
- (2,), dtypes.float32, default_value=c_default),
- }
- },
- expected_output)
+ self._test({
+ "example_names": np.empty((0,), dtype=bytes),
+ "serialized": ops.convert_to_tensor(["", ""]),
+ "features": {
+ sparse_name:
+ parsing_ops.VarLenFeature(dtypes.int64),
+ a_name:
+ parsing_ops.FixedLenFeature(
+ (1, 3), dtypes.int64, default_value=a_default),
+ b_name:
+ parsing_ops.FixedLenFeature(
+ (3, 3), dtypes.string, default_value=b_default),
+ c_name:
+ parsing_ops.FixedLenFeature(
+ (2,), dtypes.float32, default_value=c_default),
+ }
+ }, expected_output)
def testEmptySerializedWithoutDefaultsShouldFail(self):
input_features = {
@@ -180,8 +174,7 @@ class ParseExampleTest(test.TestCase):
default_value=np.random.rand(3, 3).astype(bytes)),
        # Feature "c" is missing a default; this gap will cause failure.
"c":
- parsing_ops.FixedLenFeature(
- (2,), dtype=dtypes.float32),
+ parsing_ops.FixedLenFeature((2,), dtype=dtypes.float32),
}
# Edge case where the key is there but the feature value is empty
@@ -211,7 +204,8 @@ class ParseExampleTest(test.TestCase):
original = [
example(features=features({
"a": float_feature([1, 1, 3]),
- })), example(features=features({
+ })),
+ example(features=features({
"a": float_feature([-1, -1]),
}))
]
@@ -231,7 +225,11 @@ class ParseExampleTest(test.TestCase):
"Name: failing, Key: a, Index: 1. Number of float val"))
def testDenseDefaultNoShapeShouldFail(self):
- original = [example(features=features({"a": float_feature([1, 1, 3]),})),]
+ original = [
+ example(features=features({
+ "a": float_feature([1, 1, 3]),
+ })),
+ ]
serialized = [m.SerializeToString() for m in original]
@@ -250,31 +248,31 @@ class ParseExampleTest(test.TestCase):
example(features=features({
"st_c": float_feature([3, 4])
})),
- example(features=features({
- "st_c": float_feature([]), # empty float list
- })),
- example(features=features({
- "st_d": feature(), # feature with nothing in it
- })),
- example(features=features({
- "st_c": float_feature([1, 2, -1]),
- "st_d": bytes_feature([b"hi"])
- }))
+ example(
+ features=features({
+ "st_c": float_feature([]), # empty float list
+ })),
+ example(
+ features=features({
+ "st_d": feature(), # feature with nothing in it
+ })),
+ example(
+ features=features({
+ "st_c": float_feature([1, 2, -1]),
+ "st_d": bytes_feature([b"hi"])
+ }))
]
serialized = [m.SerializeToString() for m in original]
expected_st_c = ( # indices, values, shape
- np.array(
- [[0, 0], [0, 1], [3, 0], [3, 1], [3, 2]], dtype=np.int64), np.array(
- [3.0, 4.0, 1.0, 2.0, -1.0], dtype=np.float32), np.array(
- [4, 3], dtype=np.int64)) # batch == 2, max_elems = 3
+ np.array([[0, 0], [0, 1], [3, 0], [3, 1], [3, 2]], dtype=np.int64),
+ np.array([3.0, 4.0, 1.0, 2.0, -1.0], dtype=np.float32),
+        np.array([4, 3], dtype=np.int64)) # batch == 4, max_elems = 3
expected_st_d = ( # indices, values, shape
- np.array(
- [[3, 0]], dtype=np.int64), np.array(
- ["hi"], dtype=bytes), np.array(
- [4, 1], dtype=np.int64)) # batch == 2, max_elems = 1
+ np.array([[3, 0]], dtype=np.int64), np.array(["hi"], dtype=bytes),
+        np.array([4, 1], dtype=np.int64)) # batch == 4, max_elems = 1
expected_output = {
"st_c": expected_st_c,
@@ -291,70 +289,74 @@ class ParseExampleTest(test.TestCase):
def testSerializedContainingSparseFeature(self):
original = [
- example(features=features({
- "val": float_feature([3, 4]),
- "idx": int64_feature([5, 10])
- })),
- example(features=features({
- "val": float_feature([]), # empty float list
- "idx": int64_feature([])
- })),
- example(features=features({
- "val": feature(), # feature with nothing in it
- # missing idx feature
- })),
- example(features=features({
- "val": float_feature([1, 2, -1]),
- "idx":
- int64_feature([0, 9, 3]) # unsorted
- }))
+ example(
+ features=features({
+ "val": float_feature([3, 4]),
+ "idx": int64_feature([5, 10])
+ })),
+ example(
+ features=features({
+ "val": float_feature([]), # empty float list
+ "idx": int64_feature([])
+ })),
+ example(
+ features=features({
+ "val": feature(), # feature with nothing in it
+ # missing idx feature
+ })),
+ example(
+ features=features({
+ "val": float_feature([1, 2, -1]),
+ "idx":
+ int64_feature([0, 9, 3]) # unsorted
+ }))
]
serialized = [m.SerializeToString() for m in original]
expected_sp = ( # indices, values, shape
- np.array(
- [[0, 5], [0, 10], [3, 0], [3, 3], [3, 9]], dtype=np.int64),
- np.array(
- [3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32), np.array(
- [4, 13], dtype=np.int64)) # batch == 4, max_elems = 13
+ np.array([[0, 5], [0, 10], [3, 0], [3, 3], [3, 9]], dtype=np.int64),
+ np.array([3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32),
+ np.array([4, 13], dtype=np.int64)) # batch == 4, max_elems = 13
- expected_output = {"sp": expected_sp,}
+ expected_output = {
+ "sp": expected_sp,
+ }
self._test({
"serialized": ops.convert_to_tensor(serialized),
"features": {
- "sp": parsing_ops.SparseFeature(
- ["idx"], "val", dtypes.float32, [13])
+ "sp":
+ parsing_ops.SparseFeature(["idx"], "val", dtypes.float32, [13])
}
}, expected_output)
def testSerializedContainingSparseFeatureReuse(self):
original = [
- example(features=features({
- "val1": float_feature([3, 4]),
- "val2": float_feature([5, 6]),
- "idx": int64_feature([5, 10])
- })),
- example(features=features({
- "val1": float_feature([]), # empty float list
- "idx": int64_feature([])
- })),
+ example(
+ features=features({
+ "val1": float_feature([3, 4]),
+ "val2": float_feature([5, 6]),
+ "idx": int64_feature([5, 10])
+ })),
+ example(
+ features=features({
+ "val1": float_feature([]), # empty float list
+ "idx": int64_feature([])
+ })),
]
serialized = [m.SerializeToString() for m in original]
expected_sp1 = ( # indices, values, shape
- np.array(
- [[0, 5], [0, 10]], dtype=np.int64), np.array(
- [3.0, 4.0], dtype=np.float32), np.array(
- [2, 13], dtype=np.int64)) # batch == 2, max_elems = 13
+ np.array([[0, 5], [0, 10]], dtype=np.int64),
+ np.array([3.0, 4.0], dtype=np.float32), np.array(
+ [2, 13], dtype=np.int64)) # batch == 2, max_elems = 13
expected_sp2 = ( # indices, values, shape
- np.array(
- [[0, 5], [0, 10]], dtype=np.int64), np.array(
- [5.0, 6.0], dtype=np.float32), np.array(
- [2, 7], dtype=np.int64)) # batch == 2, max_elems = 13
+ np.array([[0, 5], [0, 10]], dtype=np.int64),
+ np.array([5.0, 6.0], dtype=np.float32), np.array(
+            [2, 7], dtype=np.int64)) # batch == 2, max_elems = 7
expected_output = {
"sp1": expected_sp1,
@@ -374,25 +376,29 @@ class ParseExampleTest(test.TestCase):
def testSerializedContaining3DSparseFeature(self):
original = [
- example(features=features({
- "val": float_feature([3, 4]),
- "idx0": int64_feature([5, 10]),
- "idx1": int64_feature([0, 2]),
- })),
- example(features=features({
- "val": float_feature([]), # empty float list
- "idx0": int64_feature([]),
- "idx1": int64_feature([]),
- })),
- example(features=features({
- "val": feature(), # feature with nothing in it
- # missing idx feature
- })),
- example(features=features({
- "val": float_feature([1, 2, -1]),
- "idx0": int64_feature([0, 9, 3]), # unsorted
- "idx1": int64_feature([1, 0, 2]),
- }))
+ example(
+ features=features({
+ "val": float_feature([3, 4]),
+ "idx0": int64_feature([5, 10]),
+ "idx1": int64_feature([0, 2]),
+ })),
+ example(
+ features=features({
+ "val": float_feature([]), # empty float list
+ "idx0": int64_feature([]),
+ "idx1": int64_feature([]),
+ })),
+ example(
+ features=features({
+ "val": feature(), # feature with nothing in it
+ # missing idx feature
+ })),
+ example(
+ features=features({
+ "val": float_feature([1, 2, -1]),
+ "idx0": int64_feature([0, 9, 3]), # unsorted
+ "idx1": int64_feature([1, 0, 2]),
+ }))
]
serialized = [m.SerializeToString() for m in original]
@@ -407,13 +413,16 @@ class ParseExampleTest(test.TestCase):
# shape batch == 4, max_elems = 13
np.array([4, 13, 3], dtype=np.int64))
- expected_output = {"sp": expected_sp,}
+ expected_output = {
+ "sp": expected_sp,
+ }
self._test({
"serialized": ops.convert_to_tensor(serialized),
"features": {
- "sp": parsing_ops.SparseFeature(
- ["idx0", "idx1"], "val", dtypes.float32, [13, 3])
+ "sp":
+ parsing_ops.SparseFeature(["idx0", "idx1"], "val",
+ dtypes.float32, [13, 3])
}
}, expected_output)
@@ -421,41 +430,37 @@ class ParseExampleTest(test.TestCase):
aname = "a"
bname = "b*has+a:tricky_name"
original = [
- example(features=features({
- aname: float_feature([1, 1]),
- bname: bytes_feature([b"b0_str"]),
- })), example(features=features({
- aname: float_feature([-1, -1]),
- bname: bytes_feature([b""]),
- }))
+ example(
+ features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str"]),
+ })),
+ example(
+ features=features({
+ aname: float_feature([-1, -1]),
+ bname: bytes_feature([b""]),
+ }))
]
serialized = [m.SerializeToString() for m in original]
expected_output = {
aname:
- np.array(
- [[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
+ np.array([[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
bname:
- np.array(
- ["b0_str", ""], dtype=bytes).reshape(2, 1, 1, 1, 1),
+ np.array(["b0_str", ""], dtype=bytes).reshape(2, 1, 1, 1, 1),
}
# No defaults, values required
- self._test(
- {
- "serialized":
- ops.convert_to_tensor(serialized),
- "features": {
- aname:
- parsing_ops.FixedLenFeature(
- (1, 2, 1), dtype=dtypes.float32),
- bname:
- parsing_ops.FixedLenFeature(
- (1, 1, 1, 1), dtype=dtypes.string),
- }
- },
- expected_output)
+ self._test({
+ "serialized": ops.convert_to_tensor(serialized),
+ "features": {
+ aname:
+ parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32),
+ bname:
+ parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string),
+ }
+ }, expected_output)
     # This test is identical to the previous one except
# for the creation of 'serialized'.
@@ -466,18 +471,22 @@ class ParseExampleTest(test.TestCase):
original = [
(example(features=features({
aname: float_feature([10, 10]),
- })), example(features=features({
- aname: float_feature([1, 1]),
- bname: bytes_feature([b"b0_str"]),
- }))),
+ })),
+ example(
+ features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str"]),
+ }))),
(
example(features=features({
bname: bytes_feature([b"b100"]),
})),
- example(features=features({
- aname: float_feature([-1, -1]),
- bname: bytes_feature([b"b1"]),
- })),),
+ example(
+ features=features({
+ aname: float_feature([-1, -1]),
+ bname: bytes_feature([b"b1"]),
+ })),
+ ),
]
serialized = [
@@ -486,55 +495,45 @@ class ParseExampleTest(test.TestCase):
expected_output = {
aname:
- np.array(
- [[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
+ np.array([[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
bname:
- np.array(
- ["b0_str", "b1"], dtype=bytes).reshape(2, 1, 1, 1, 1),
+ np.array(["b0_str", "b1"], dtype=bytes).reshape(2, 1, 1, 1, 1),
}
# No defaults, values required
- self._test(
- {
- "serialized":
- ops.convert_to_tensor(serialized),
- "features": {
- aname:
- parsing_ops.FixedLenFeature(
- (1, 2, 1), dtype=dtypes.float32),
- bname:
- parsing_ops.FixedLenFeature(
- (1, 1, 1, 1), dtype=dtypes.string),
- }
- },
- expected_output)
+ self._test({
+ "serialized": ops.convert_to_tensor(serialized),
+ "features": {
+ aname:
+ parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32),
+ bname:
+ parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string),
+ }
+ }, expected_output)
def testSerializedContainingDenseScalar(self):
original = [
example(features=features({
"a": float_feature([1]),
- })), example(features=features({}))
+ })),
+ example(features=features({}))
]
serialized = [m.SerializeToString() for m in original]
expected_output = {
"a":
- np.array(
- [[1], [-1]], dtype=np.float32) # 2x1 (column vector)
+ np.array([[1], [-1]], dtype=np.float32) # 2x1 (column vector)
}
- self._test(
- {
- "serialized":
- ops.convert_to_tensor(serialized),
- "features": {
- "a":
- parsing_ops.FixedLenFeature(
- (1,), dtype=dtypes.float32, default_value=-1),
- }
- },
- expected_output)
+ self._test({
+ "serialized": ops.convert_to_tensor(serialized),
+ "features": {
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1,), dtype=dtypes.float32, default_value=-1),
+ }
+ }, expected_output)
def testSerializedContainingDenseWithDefaults(self):
original = [
@@ -553,58 +552,48 @@ class ParseExampleTest(test.TestCase):
expected_output = {
"a":
- np.array(
- [[1, 1], [3, -3], [3, -3]], dtype=np.float32).reshape(3, 1, 2,
- 1),
+ np.array([[1, 1], [3, -3], [3, -3]], dtype=np.float32).reshape(
+ 3, 1, 2, 1),
"b":
- np.array(
- ["tmp_str", "b1", "tmp_str"], dtype=bytes).reshape(3, 1, 1, 1,
- 1),
+ np.array(["tmp_str", "b1", "tmp_str"], dtype=bytes).reshape(
+ 3, 1, 1, 1, 1),
}
- self._test(
- {
- "serialized":
- ops.convert_to_tensor(serialized),
- "features": {
- "a":
- parsing_ops.FixedLenFeature(
- (1, 2, 1),
- dtype=dtypes.float32,
- default_value=[3.0, -3.0]),
- "b":
- parsing_ops.FixedLenFeature(
- (1, 1, 1, 1),
- dtype=dtypes.string,
- default_value="tmp_str"),
- }
- },
- expected_output)
+ self._test({
+ "serialized": ops.convert_to_tensor(serialized),
+ "features": {
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1, 2, 1), dtype=dtypes.float32, default_value=[3.0, -3.0]),
+ "b":
+ parsing_ops.FixedLenFeature(
+ (1, 1, 1, 1), dtype=dtypes.string, default_value="tmp_str"),
+ }
+ }, expected_output)
def testSerializedContainingSparseAndSparseFeatureAndDenseWithNoDefault(self):
expected_st_a = ( # indices, values, shape
- np.empty(
- (0, 2), dtype=np.int64), # indices
- np.empty(
- (0,), dtype=np.int64), # sp_a is DT_INT64
- np.array(
- [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
+ np.empty((0, 2), dtype=np.int64), # indices
+ np.empty((0,), dtype=np.int64), # sp_a is DT_INT64
+ np.array([2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
expected_sp = ( # indices, values, shape
- np.array(
- [[0, 0], [0, 3], [1, 7]], dtype=np.int64), np.array(
- ["a", "b", "c"], dtype="|S"), np.array(
- [2, 13], dtype=np.int64)) # batch == 4, max_elems = 13
+ np.array([[0, 0], [0, 3], [1, 7]], dtype=np.int64),
+ np.array(["a", "b", "c"], dtype="|S"), np.array(
+            [2, 13], dtype=np.int64)) # batch == 2, max_elems = 13
original = [
- example(features=features({
- "c": float_feature([3, 4]),
- "val": bytes_feature([b"a", b"b"]),
- "idx": int64_feature([0, 3])
- })), example(features=features({
- "c": float_feature([1, 2]),
- "val": bytes_feature([b"c"]),
- "idx": int64_feature([7])
- }))
+ example(
+ features=features({
+ "c": float_feature([3, 4]),
+ "val": bytes_feature([b"a", b"b"]),
+ "idx": int64_feature([0, 3])
+ })),
+ example(
+ features=features({
+ "c": float_feature([1, 2]),
+ "val": bytes_feature([b"c"]),
+ "idx": int64_feature([7])
+ }))
]
names = ["in1", "in2"]
@@ -617,16 +606,13 @@ class ParseExampleTest(test.TestCase):
"sp": expected_sp,
"a": np.array(2 * [[a_default]]),
"b": np.array(2 * [b_default]),
- "c": np.array(
- [[3, 4], [1, 2]], dtype=np.float32),
+ "c": np.array([[3, 4], [1, 2]], dtype=np.float32),
}
self._test(
{
- "example_names":
- names,
- "serialized":
- ops.convert_to_tensor(serialized),
+ "example_names": names,
+ "serialized": ops.convert_to_tensor(serialized),
"features": {
"st_a":
parsing_ops.VarLenFeature(dtypes.int64),
@@ -647,25 +633,26 @@ class ParseExampleTest(test.TestCase):
def testSerializedContainingSparseAndSparseFeatureWithReuse(self):
expected_idx = ( # indices, values, shape
- np.array(
- [[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64),
- np.array([0, 3, 7, 1]), np.array(
- [2, 2], dtype=np.int64)) # batch == 4, max_elems = 2
+ np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64),
+ np.array([0, 3, 7, 1]),
+        np.array([2, 2], dtype=np.int64)) # batch == 2, max_elems = 2
expected_sp = ( # indices, values, shape
- np.array(
- [[0, 0], [0, 3], [1, 1], [1, 7]], dtype=np.int64), np.array(
- ["a", "b", "d", "c"], dtype="|S"), np.array(
- [2, 13], dtype=np.int64)) # batch == 4, max_elems = 13
+ np.array([[0, 0], [0, 3], [1, 1], [1, 7]], dtype=np.int64),
+ np.array(["a", "b", "d", "c"], dtype="|S"),
+        np.array([2, 13], dtype=np.int64)) # batch == 2, max_elems = 13
original = [
- example(features=features({
- "val": bytes_feature([b"a", b"b"]),
- "idx": int64_feature([0, 3])
- })), example(features=features({
- "val": bytes_feature([b"c", b"d"]),
- "idx": int64_feature([7, 1])
- }))
+ example(
+ features=features({
+ "val": bytes_feature([b"a", b"b"]),
+ "idx": int64_feature([0, 3])
+ })),
+ example(
+ features=features({
+ "val": bytes_feature([b"c", b"d"]),
+ "idx": int64_feature([7, 1])
+ }))
]
names = ["in1", "in2"]
@@ -680,9 +667,10 @@ class ParseExampleTest(test.TestCase):
"example_names": names,
"serialized": ops.convert_to_tensor(serialized),
"features": {
- "idx": parsing_ops.VarLenFeature(dtypes.int64),
- "sp": parsing_ops.SparseFeature(
- ["idx"], "val", dtypes.string, [13]),
+ "idx":
+ parsing_ops.VarLenFeature(dtypes.int64),
+ "sp":
+ parsing_ops.SparseFeature(["idx"], "val", dtypes.string, [13]),
}
}, expected_output)
@@ -720,10 +708,11 @@ class ParseExampleTest(test.TestCase):
}
original = [
- example(features=features(
- {"a": int64_feature([truth_int[i]]),
- "b": bytes_feature(truth_str[i])}))
- for i in range(batch_size)
+ example(
+ features=features({
+ "a": int64_feature([truth_int[i]]),
+ "b": bytes_feature(truth_str[i])
+ })) for i in range(batch_size)
]
serialized = [m.SerializeToString() for m in original]
@@ -731,12 +720,18 @@ class ParseExampleTest(test.TestCase):
self._test({
"serialized": ops.convert_to_tensor(serialized, dtype=dtypes.string),
"features": {
- "a": parsing_ops.FixedLenSequenceFeature(
- shape=(), dtype=dtypes.int64, allow_missing=True,
- default_value=-1),
- "b": parsing_ops.FixedLenSequenceFeature(
- shape=[], dtype=dtypes.string, allow_missing=True,
- default_value="default"),
+ "a":
+ parsing_ops.FixedLenSequenceFeature(
+ shape=(),
+ dtype=dtypes.int64,
+ allow_missing=True,
+ default_value=-1),
+ "b":
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[],
+ dtype=dtypes.string,
+ allow_missing=True,
+ default_value="default"),
}
}, expected_output)
@@ -755,18 +750,21 @@ class ParseExampleTest(test.TestCase):
example(features=features({
cname: int64_feature([2]),
})),
- example(features=features({
- aname: float_feature([1, 1]),
- bname: bytes_feature([b"b0_str", b"b1_str"]),
- })),
- example(features=features({
- aname: float_feature([-1, -1, 2, 2]),
- bname: bytes_feature([b"b1"]),
- })),
- example(features=features({
- aname: float_feature([]),
- cname: int64_feature([3]),
- })),
+ example(
+ features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str", b"b1_str"]),
+ })),
+ example(
+ features=features({
+ aname: float_feature([-1, -1, 2, 2]),
+ bname: bytes_feature([b"b1"]),
+ })),
+ example(
+ features=features({
+ aname: float_feature([]),
+ cname: int64_feature([3]),
+ })),
]
serialized = [m.SerializeToString() for m in original]
@@ -827,7 +825,9 @@ class ParseExampleTest(test.TestCase):
"features": {
aname:
parsing_ops.FixedLenSequenceFeature(
- (2, 1), dtype=dtypes.float32, allow_missing=True,
+ (2, 1),
+ dtype=dtypes.float32,
+ allow_missing=True,
default_value=-2.0),
bname:
parsing_ops.FixedLenSequenceFeature(
@@ -867,7 +867,9 @@ class ParseExampleTest(test.TestCase):
"features": {
aname:
parsing_ops.FixedLenSequenceFeature(
- (2, 1), dtype=dtypes.float32, allow_missing=True,
+ (2, 1),
+ dtype=dtypes.float32,
+ allow_missing=True,
default_value=[]),
bname:
parsing_ops.FixedLenSequenceFeature(
@@ -908,26 +910,28 @@ class ParseExampleTest(test.TestCase):
"All dimensions of shape for feature c need to be known "
r"but received \(1, None\)."))
- self._test({
- "example_names": example_names,
- "serialized": ops.convert_to_tensor(serialized),
- "features": {
- aname:
- parsing_ops.FixedLenSequenceFeature(
- (2, 1), dtype=dtypes.float32, allow_missing=True),
- bname:
- parsing_ops.FixedLenSequenceFeature(
- (1, 1, 1), dtype=dtypes.string, allow_missing=True),
- cname:
- parsing_ops.FixedLenSequenceFeature(
- shape=[], dtype=dtypes.int64, allow_missing=False),
- dname:
- parsing_ops.FixedLenSequenceFeature(
- shape=[], dtype=dtypes.string, allow_missing=True),
- }
- }, expected_err=(ValueError,
- "Unsupported: FixedLenSequenceFeature requires "
- "allow_missing to be True."))
+ self._test(
+ {
+ "example_names": example_names,
+ "serialized": ops.convert_to_tensor(serialized),
+ "features": {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1), dtype=dtypes.float32, allow_missing=True),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (1, 1, 1), dtype=dtypes.string, allow_missing=True),
+ cname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.int64, allow_missing=False),
+ dname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.string, allow_missing=True),
+ }
+ },
+ expected_err=(ValueError,
+ "Unsupported: FixedLenSequenceFeature requires "
+ "allow_missing to be True."))
class ParseSingleExampleTest(test.TestCase):
@@ -949,8 +953,8 @@ class ParseSingleExampleTest(test.TestCase):
# Check shapes.
for k, f in kwargs["features"].items():
if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None:
- self.assertEqual(tuple(out[k].get_shape()),
- tensor_shape.as_shape(f.shape))
+ self.assertEqual(
+ tuple(out[k].get_shape()), tensor_shape.as_shape(f.shape))
elif isinstance(f, parsing_ops.VarLenFeature):
self.assertEqual(
tuple(out[k].indices.get_shape().as_list()), (None, 1))
@@ -959,29 +963,25 @@ class ParseSingleExampleTest(test.TestCase):
tuple(out[k].dense_shape.get_shape().as_list()), (1,))
def testSingleExampleWithSparseAndSparseFeatureAndDense(self):
- original = example(features=features({
- "c": float_feature([3, 4]),
- "d": float_feature([0.0, 1.0]),
- "val": bytes_feature([b"a", b"b"]),
- "idx": int64_feature([0, 3]),
- "st_a": float_feature([3.0, 4.0])
- }))
+ original = example(
+ features=features({
+ "c": float_feature([3, 4]),
+ "d": float_feature([0.0, 1.0]),
+ "val": bytes_feature([b"a", b"b"]),
+ "idx": int64_feature([0, 3]),
+ "st_a": float_feature([3.0, 4.0])
+ }))
serialized = original.SerializeToString()
expected_st_a = (
- np.array(
- [[0], [1]], dtype=np.int64), # indices
- np.array(
- [3.0, 4.0], dtype=np.float32), # values
- np.array(
- [2], dtype=np.int64)) # shape: max_values = 2
+ np.array([[0], [1]], dtype=np.int64), # indices
+ np.array([3.0, 4.0], dtype=np.float32), # values
+ np.array([2], dtype=np.int64)) # shape: max_values = 2
expected_sp = ( # indices, values, shape
- np.array(
- [[0], [3]], dtype=np.int64), np.array(
- ["a", "b"], dtype="|S"), np.array(
- [13], dtype=np.int64)) # max_values = 13
+ np.array([[0], [3]], dtype=np.int64), np.array(["a", "b"], dtype="|S"),
+ np.array([13], dtype=np.int64)) # max_values = 13
a_default = [1, 2, 3]
b_default = np.random.rand(3, 3).astype(bytes)
@@ -996,16 +996,14 @@ class ParseSingleExampleTest(test.TestCase):
self._test(
{
- "example_names":
- ops.convert_to_tensor("in1"),
- "serialized":
- ops.convert_to_tensor(serialized),
+ "example_names": ops.convert_to_tensor("in1"),
+ "serialized": ops.convert_to_tensor(serialized),
"features": {
"st_a":
parsing_ops.VarLenFeature(dtypes.float32),
"sp":
- parsing_ops.SparseFeature(
- ["idx"], "val", dtypes.string, [13]),
+ parsing_ops.SparseFeature(["idx"], "val", dtypes.string,
+ [13]),
"a":
parsing_ops.FixedLenFeature(
(1, 3), dtypes.int64, default_value=a_default),
@@ -1016,9 +1014,8 @@ class ParseSingleExampleTest(test.TestCase):
"c":
parsing_ops.FixedLenFeature(2, dtypes.float32),
"d":
- parsing_ops.FixedLenSequenceFeature([],
- dtypes.float32,
- allow_missing=True)
+ parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True)
}
},
expected_output)
@@ -1050,43 +1047,71 @@ class ParseSequenceExampleTest(test.TestCase):
kwargs,
expected_context_values=None,
expected_feat_list_values=None,
- expected_err=None):
+ expected_length_values=None,
+ expected_err=None,
+ batch=False):
expected_context_values = expected_context_values or {}
expected_feat_list_values = expected_feat_list_values or {}
+ expected_length_values = expected_length_values or {}
with self.test_session() as sess:
if expected_err:
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):
- c_out, fl_out = parsing_ops.parse_single_sequence_example(**kwargs)
+ if batch:
+ c_out, fl_out, _ = parsing_ops.parse_sequence_example(**kwargs)
+ else:
+ c_out, fl_out = parsing_ops.parse_single_sequence_example(**kwargs)
if c_out:
sess.run(flatten_values_tensors_or_sparse(c_out.values()))
if fl_out:
sess.run(flatten_values_tensors_or_sparse(fl_out.values()))
else:
# Returns dicts w/ Tensors and SparseTensors.
- context_out, feat_list_out = parsing_ops.parse_single_sequence_example(
- **kwargs)
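+          # The batch API also returns a dict of per-key sequence lengths;
+          # the single-example API does not, so default it to empty.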
+ if batch:
+ (context_out, feat_list_out,
+ lengths_out) = parsing_ops.parse_sequence_example(**kwargs)
+ else:
+ (context_out,
+ feat_list_out) = parsing_ops.parse_single_sequence_example(**kwargs)
+ lengths_out = {}
+
context_result = sess.run(
- flatten_values_tensors_or_sparse(context_out.values(
- ))) if context_out else []
+ flatten_values_tensors_or_sparse(
+ context_out.values())) if context_out else []
feat_list_result = sess.run(
- flatten_values_tensors_or_sparse(feat_list_out.values(
- ))) if feat_list_out else []
+ flatten_values_tensors_or_sparse(
+ feat_list_out.values())) if feat_list_out else []
+ lengths_result = sess.run(
+ flatten_values_tensors_or_sparse(
+ lengths_out.values())) if lengths_out else []
# Check values.
_compare_output_to_expected(self, context_out, expected_context_values,
context_result)
_compare_output_to_expected(self, feat_list_out,
expected_feat_list_values, feat_list_result)
+ _compare_output_to_expected(self, lengths_out, expected_length_values,
+ lengths_result)
# Check shapes; if serialized is a Tensor we need its size to
# properly check.
if "context_features" in kwargs:
for k, f in kwargs["context_features"].items():
if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None:
+ if batch:
+ self.assertEqual(
+ tuple(context_out[k].get_shape().as_list()[1:]), f.shape)
+ else:
+ self.assertEqual(
+ tuple(context_out[k].get_shape().as_list()), f.shape)
+ elif isinstance(f, parsing_ops.VarLenFeature) and batch:
self.assertEqual(
- tuple(context_out[k].get_shape().as_list()), f.shape)
- elif isinstance(f, parsing_ops.VarLenFeature):
+ tuple(context_out[k].indices.get_shape().as_list()), (None, 2))
+ self.assertEqual(
+ tuple(context_out[k].values.get_shape().as_list()), (None,))
+ self.assertEqual(
+ tuple(context_out[k].dense_shape.get_shape().as_list()), (2,))
+ elif isinstance(f, parsing_ops.VarLenFeature) and not batch:
self.assertEqual(
tuple(context_out[k].indices.get_shape().as_list()), (None, 1))
self.assertEqual(
@@ -1094,38 +1119,94 @@ class ParseSequenceExampleTest(test.TestCase):
self.assertEqual(
tuple(context_out[k].dense_shape.get_shape().as_list()), (1,))
+ def _testBoth(self,
+ kwargs,
+ expected_context_values=None,
+ expected_feat_list_values=None,
+ expected_err=None):
+ # Test using tf.parse_single_sequence_example
+ self._test(
+ kwargs,
+ expected_context_values=expected_context_values,
+ expected_feat_list_values=expected_feat_list_values,
+ expected_err=expected_err,
+ batch=False)
+
+ # Convert the input to a batch of size 1, and test using
+ # tf.parse_sequence_example.
+
+ # Some replacements are needed for the batch version.
+ kwargs["serialized"] = [kwargs.pop("serialized")]
+ kwargs["example_names"] = [kwargs.pop("example_name")
+ ] if "example_name" in kwargs else None
+ # Disable error string matching; it's not consistent for batch mode.
+ if expected_err:
+ expected_err = (expected_err[0], "")
+
+ # Add a batch dimension to expected output
+ if expected_context_values:
+ new_values = {}
+ for k in expected_context_values:
+ v = expected_context_values[k]
+ if isinstance(kwargs["context_features"][k],
+ parsing_ops.FixedLenFeature):
+ new_values[k] = np.expand_dims(v, axis=0)
+ else:
+ # Sparse tensor.
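+          # Prepend a batch-index column of zeros to the indices and a
+          # leading 1 to the dense shape.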
+ new_values[k] = (np.insert(v[0], 0, 0, axis=1), v[1],
+ np.insert(v[2], 0, 1))
+ expected_context_values = new_values
+
+ expected_length_values = {}
+ if expected_feat_list_values:
+ new_values = {}
+ for k in expected_feat_list_values:
+ v = expected_feat_list_values[k]
+ if isinstance(kwargs["sequence_features"][k],
+ parsing_ops.FixedLenSequenceFeature):
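+          # The expected sequence length is the size of the outermost
+          # (time) dimension.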
+ expected_length_values[k] = [np.shape(v)[0]]
+ new_values[k] = np.expand_dims(v, axis=0)
+ else:
+ # Sparse tensor.
+ new_values[k] = (np.insert(v[0], 0, 0, axis=1), v[1],
+ np.insert(v[2], 0, 1))
+ expected_feat_list_values = new_values
+
+ self._test(
+ kwargs,
+ expected_context_values=expected_context_values,
+ expected_feat_list_values=expected_feat_list_values,
+ expected_length_values=expected_length_values,
+ expected_err=expected_err,
+ batch=True)
+
def testSequenceExampleWithSparseAndDenseContext(self):
- original = sequence_example(context=features({
- "c": float_feature([3, 4]),
- "st_a": float_feature([3.0, 4.0])
- }))
+ original = sequence_example(
+ context=features({
+ "c": float_feature([3, 4]),
+ "st_a": float_feature([3.0, 4.0])
+ }))
serialized = original.SerializeToString()
expected_st_a = (
- np.array(
- [[0], [1]], dtype=np.int64), # indices
- np.array(
- [3.0, 4.0], dtype=np.float32), # values
- np.array(
- [2], dtype=np.int64)) # shape: num_features = 2
+ np.array([[0], [1]], dtype=np.int64), # indices
+ np.array([3.0, 4.0], dtype=np.float32), # values
+ np.array([2], dtype=np.int64)) # shape: num_features = 2
- a_default = [1, 2, 3]
+ a_default = [[1, 2, 3]]
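+    # a_default now carries the (1, 3) shape declared for feature "a".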
b_default = np.random.rand(3, 3).astype(bytes)
expected_context_output = {
"st_a": expected_st_a,
- "a": [a_default],
+ "a": a_default,
"b": b_default,
- "c": np.array(
- [3, 4], dtype=np.float32),
+ "c": np.array([3, 4], dtype=np.float32),
}
- self._test(
+ self._testBoth(
{
- "example_name":
- "in1",
- "serialized":
- ops.convert_to_tensor(serialized),
+ "example_name": "in1",
+ "serialized": ops.convert_to_tensor(serialized),
"context_features": {
"st_a":
parsing_ops.VarLenFeature(dtypes.float32),
@@ -1143,51 +1224,54 @@ class ParseSequenceExampleTest(test.TestCase):
expected_context_values=expected_context_output)
def testSequenceExampleWithMultipleSizeFeatureLists(self):
- original = sequence_example(feature_lists=feature_lists({
- "a":
- feature_list([
- int64_feature([-1, 0, 1]),
- int64_feature([2, 3, 4]),
- int64_feature([5, 6, 7]),
- int64_feature([8, 9, 10]),
- ]),
- "b":
- feature_list([bytes_feature([b"r00", b"r01", b"r10", b"r11"])]),
- "c":
- feature_list([float_feature([3, 4]), float_feature([-1, 2])]),
- }))
+ original = sequence_example(
+ feature_lists=feature_lists({
+ "a":
+ feature_list([
+ int64_feature([-1, 0, 1]),
+ int64_feature([2, 3, 4]),
+ int64_feature([5, 6, 7]),
+ int64_feature([8, 9, 10]),
+ ]),
+ "b":
+ feature_list([bytes_feature([b"r00", b"r01", b"r10", b"r11"])]),
+ "c":
+ feature_list([float_feature([3, 4]),
+ float_feature([-1, 2])]),
+ }))
serialized = original.SerializeToString()
expected_feature_list_output = {
- "a": np.array(
- [ # outer dimension is time.
- [[-1, 0, 1]], # inside are 1x3 matrices
- [[2, 3, 4]],
- [[5, 6, 7]],
- [[8, 9, 10]]
- ],
- dtype=np.int64),
- "b": np.array(
- [ # outer dimension is time, inside are 2x2 matrices
- [[b"r00", b"r01"], [b"r10", b"r11"]]
- ],
- dtype=bytes),
- "c": np.array(
- [ # outer dimension is time, inside are 2-vectors
- [3, 4], [-1, 2]
- ],
- dtype=np.float32),
- "d": np.empty(
- shape=(0, 5), dtype=np.float32), # empty_allowed_missing
+ "a":
+ np.array(
+ [ # outer dimension is time.
+ [[-1, 0, 1]], # inside are 1x3 matrices
+ [[2, 3, 4]],
+ [[5, 6, 7]],
+ [[8, 9, 10]]
+ ],
+ dtype=np.int64),
+ "b":
+ np.array(
+ [ # outer dimension is time, inside are 2x2 matrices
+ [[b"r00", b"r01"], [b"r10", b"r11"]]
+ ],
+ dtype=bytes),
+ "c":
+ np.array(
+ [ # outer dimension is time, inside are 2-vectors
+ [3, 4], [-1, 2]
+ ],
+ dtype=np.float32),
+ "d":
+ np.empty(shape=(0, 5), dtype=np.float32), # empty_allowed_missing
}
- self._test(
+ self._testBoth(
{
- "example_name":
- "in1",
- "serialized":
- ops.convert_to_tensor(serialized),
+ "example_name": "in1",
+ "serialized": ops.convert_to_tensor(serialized),
"sequence_features": {
"a":
parsing_ops.FixedLenSequenceFeature((1, 3), dtypes.int64),
@@ -1203,56 +1287,51 @@ class ParseSequenceExampleTest(test.TestCase):
expected_feat_list_values=expected_feature_list_output)
def testSequenceExampleWithoutDebugName(self):
- original = sequence_example(feature_lists=feature_lists({
- "a":
- feature_list([int64_feature([3, 4]), int64_feature([1, 0])]),
- "st_a":
- feature_list([
- float_feature([3.0, 4.0]), float_feature([5.0]),
- float_feature([])
- ]),
- "st_b":
- feature_list([
- bytes_feature([b"a"]), bytes_feature([]), bytes_feature([]),
- bytes_feature([b"b", b"c"])
- ])
- }))
+ original = sequence_example(
+ feature_lists=feature_lists({
+ "a":
+ feature_list([int64_feature([3, 4]),
+ int64_feature([1, 0])]),
+ "st_a":
+ feature_list([
+ float_feature([3.0, 4.0]),
+ float_feature([5.0]),
+ float_feature([])
+ ]),
+ "st_b":
+ feature_list([
+ bytes_feature([b"a"]),
+ bytes_feature([]),
+ bytes_feature([]),
+ bytes_feature([b"b", b"c"])
+ ])
+ }))
serialized = original.SerializeToString()
expected_st_a = (
- np.array(
- [[0, 0], [0, 1], [1, 0]], dtype=np.int64), # indices
- np.array(
- [3.0, 4.0, 5.0], dtype=np.float32), # values
- np.array(
- [3, 2], dtype=np.int64)) # shape: num_time = 3, max_feat = 2
+ np.array([[0, 0], [0, 1], [1, 0]], dtype=np.int64), # indices
+ np.array([3.0, 4.0, 5.0], dtype=np.float32), # values
+ np.array([3, 2], dtype=np.int64)) # shape: num_time = 3, max_feat = 2
expected_st_b = (
- np.array(
- [[0, 0], [3, 0], [3, 1]], dtype=np.int64), # indices
- np.array(
- ["a", "b", "c"], dtype="|S"), # values
- np.array(
- [4, 2], dtype=np.int64)) # shape: num_time = 4, max_feat = 2
+ np.array([[0, 0], [3, 0], [3, 1]], dtype=np.int64), # indices
+ np.array(["a", "b", "c"], dtype="|S"), # values
+ np.array([4, 2], dtype=np.int64)) # shape: num_time = 4, max_feat = 2
expected_st_c = (
- np.empty(
- (0, 2), dtype=np.int64), # indices
- np.empty(
- (0,), dtype=np.int64), # values
- np.array(
- [0, 0], dtype=np.int64)) # shape: num_time = 0, max_feat = 0
+ np.empty((0, 2), dtype=np.int64), # indices
+ np.empty((0,), dtype=np.int64), # values
+ np.array([0, 0], dtype=np.int64)) # shape: num_time = 0, max_feat = 0
expected_feature_list_output = {
- "a": np.array(
- [[3, 4], [1, 0]], dtype=np.int64),
+ "a": np.array([[3, 4], [1, 0]], dtype=np.int64),
"st_a": expected_st_a,
"st_b": expected_st_b,
"st_c": expected_st_c,
}
- self._test(
+ self._testBoth(
{
"serialized": ops.convert_to_tensor(serialized),
"sequence_features": {
@@ -1265,56 +1344,51 @@ class ParseSequenceExampleTest(test.TestCase):
expected_feat_list_values=expected_feature_list_output)
def testSequenceExampleWithSparseAndDenseFeatureLists(self):
- original = sequence_example(feature_lists=feature_lists({
- "a":
- feature_list([int64_feature([3, 4]), int64_feature([1, 0])]),
- "st_a":
- feature_list([
- float_feature([3.0, 4.0]), float_feature([5.0]),
- float_feature([])
- ]),
- "st_b":
- feature_list([
- bytes_feature([b"a"]), bytes_feature([]), bytes_feature([]),
- bytes_feature([b"b", b"c"])
- ])
- }))
+ original = sequence_example(
+ feature_lists=feature_lists({
+ "a":
+ feature_list([int64_feature([3, 4]),
+ int64_feature([1, 0])]),
+ "st_a":
+ feature_list([
+ float_feature([3.0, 4.0]),
+ float_feature([5.0]),
+ float_feature([])
+ ]),
+ "st_b":
+ feature_list([
+ bytes_feature([b"a"]),
+ bytes_feature([]),
+ bytes_feature([]),
+ bytes_feature([b"b", b"c"])
+ ])
+ }))
serialized = original.SerializeToString()
expected_st_a = (
- np.array(
- [[0, 0], [0, 1], [1, 0]], dtype=np.int64), # indices
- np.array(
- [3.0, 4.0, 5.0], dtype=np.float32), # values
- np.array(
- [3, 2], dtype=np.int64)) # shape: num_time = 3, max_feat = 2
+ np.array([[0, 0], [0, 1], [1, 0]], dtype=np.int64), # indices
+ np.array([3.0, 4.0, 5.0], dtype=np.float32), # values
+ np.array([3, 2], dtype=np.int64)) # shape: num_time = 3, max_feat = 2
expected_st_b = (
- np.array(
- [[0, 0], [3, 0], [3, 1]], dtype=np.int64), # indices
- np.array(
- ["a", "b", "c"], dtype="|S"), # values
- np.array(
- [4, 2], dtype=np.int64)) # shape: num_time = 4, max_feat = 2
+ np.array([[0, 0], [3, 0], [3, 1]], dtype=np.int64), # indices
+ np.array(["a", "b", "c"], dtype="|S"), # values
+ np.array([4, 2], dtype=np.int64)) # shape: num_time = 4, max_feat = 2
expected_st_c = (
- np.empty(
- (0, 2), dtype=np.int64), # indices
- np.empty(
- (0,), dtype=np.int64), # values
- np.array(
- [0, 0], dtype=np.int64)) # shape: num_time = 0, max_feat = 0
+ np.empty((0, 2), dtype=np.int64), # indices
+ np.empty((0,), dtype=np.int64), # values
+ np.array([0, 0], dtype=np.int64)) # shape: num_time = 0, max_feat = 0
expected_feature_list_output = {
- "a": np.array(
- [[3, 4], [1, 0]], dtype=np.int64),
+ "a": np.array([[3, 4], [1, 0]], dtype=np.int64),
"st_a": expected_st_a,
"st_b": expected_st_b,
"st_c": expected_st_c,
}
- self._test(
+ self._testBoth(
{
"example_name": "in1",
"serialized": ops.convert_to_tensor(serialized),
@@ -1328,30 +1402,28 @@ class ParseSequenceExampleTest(test.TestCase):
expected_feat_list_values=expected_feature_list_output)
def testSequenceExampleWithEmptyFeatureInFeatureLists(self):
- original = sequence_example(feature_lists=feature_lists({
- "st_a":
- feature_list([
- float_feature([3.0, 4.0]),
- feature(),
- float_feature([5.0]),
- ]),
- }))
+ original = sequence_example(
+ feature_lists=feature_lists({
+ "st_a":
+ feature_list([
+ float_feature([3.0, 4.0]),
+ feature(),
+ float_feature([5.0]),
+ ]),
+ }))
serialized = original.SerializeToString()
expected_st_a = (
- np.array(
- [[0, 0], [0, 1], [2, 0]], dtype=np.int64), # indices
- np.array(
- [3.0, 4.0, 5.0], dtype=np.float32), # values
- np.array(
- [3, 2], dtype=np.int64)) # shape: num_time = 3, max_feat = 2
+ np.array([[0, 0], [0, 1], [2, 0]], dtype=np.int64), # indices
+ np.array([3.0, 4.0, 5.0], dtype=np.float32), # values
+ np.array([3, 2], dtype=np.int64)) # shape: num_time = 3, max_feat = 2
expected_feature_list_output = {
"st_a": expected_st_a,
}
- self._test(
+ self._testBoth(
{
"example_name": "in1",
"serialized": ops.convert_to_tensor(serialized),
@@ -1362,13 +1434,15 @@ class ParseSequenceExampleTest(test.TestCase):
expected_feat_list_values=expected_feature_list_output)
def testSequenceExampleListWithInconsistentDataFails(self):
- original = sequence_example(feature_lists=feature_lists({
- "a": feature_list([int64_feature([-1, 0]), float_feature([2, 3])])
- }))
+ original = sequence_example(
+ feature_lists=feature_lists({
+ "a": feature_list([int64_feature([-1, 0]),
+ float_feature([2, 3])])
+ }))
serialized = original.SerializeToString()
- self._test(
+ self._testBoth(
{
"example_name": "in1",
"serialized": ops.convert_to_tensor(serialized),
@@ -1380,13 +1454,14 @@ class ParseSequenceExampleTest(test.TestCase):
" Data types don't match. Expected type: int64"))
def testSequenceExampleListWithWrongDataTypeFails(self):
- original = sequence_example(feature_lists=feature_lists({
- "a": feature_list([float_feature([2, 3])])
- }))
+ original = sequence_example(
+ feature_lists=feature_lists({
+ "a": feature_list([float_feature([2, 3])])
+ }))
serialized = original.SerializeToString()
- self._test(
+ self._testBoth(
{
"example_name": "in1",
"serialized": ops.convert_to_tensor(serialized),
@@ -1399,17 +1474,19 @@ class ParseSequenceExampleTest(test.TestCase):
" Expected type: int64"))
def testSequenceExampleListWithWrongSparseDataTypeFails(self):
- original = sequence_example(feature_lists=feature_lists({
- "a":
- feature_list([
- int64_feature([3, 4]), int64_feature([1, 2]),
- float_feature([2.0, 3.0])
- ])
- }))
+ original = sequence_example(
+ feature_lists=feature_lists({
+ "a":
+ feature_list([
+ int64_feature([3, 4]),
+ int64_feature([1, 2]),
+ float_feature([2.0, 3.0])
+ ])
+ }))
serialized = original.SerializeToString()
- self._test(
+ self._testBoth(
{
"example_name": "in1",
"serialized": ops.convert_to_tensor(serialized),
@@ -1423,13 +1500,16 @@ class ParseSequenceExampleTest(test.TestCase):
" Feature is: float_list"))
def testSequenceExampleListWithWrongShapeFails(self):
- original = sequence_example(feature_lists=feature_lists({
- "a": feature_list([int64_feature([2, 3]), int64_feature([2, 3, 4])]),
- }))
+ original = sequence_example(
+ feature_lists=feature_lists({
+ "a":
+ feature_list([int64_feature([2, 3]),
+ int64_feature([2, 3, 4])]),
+ }))
serialized = original.SerializeToString()
- self._test(
+ self._testBoth(
{
"example_name": "in1",
"serialized": ops.convert_to_tensor(serialized),
@@ -1446,7 +1526,7 @@ class ParseSequenceExampleTest(test.TestCase):
# Test fails because we didn't add:
# feature_list_dense_defaults = {"a": None}
- self._test(
+ self._testBoth(
{
"example_name": "in1",
"serialized": ops.convert_to_tensor(original.SerializeToString()),
@@ -1461,6 +1541,67 @@ class ParseSequenceExampleTest(test.TestCase):
" feature_list_dense_missing_assumed_empty or"
" feature_list_dense_defaults?"))
+ def testSequenceExampleBatch(self):
+ first = sequence_example(
+ feature_lists=feature_lists({
+ "a":
+ feature_list([
+ int64_feature([-1, 0, 1]),
+ int64_feature([2, 3, 4]),
+ int64_feature([5, 6, 7]),
+ int64_feature([8, 9, 10]),
+ ])
+ }))
+ second = sequence_example(
+ feature_lists=feature_lists({
+ "a": feature_list([
+ int64_feature([21, 2, 11]),
+ ])
+ }))
+
+ serialized = [first.SerializeToString(), second.SerializeToString()]
+
+ expected_feature_list_output = {
+ "a":
+ np.array(
+ [ # outermost dimension is example id
+ [ # middle dimension is time.
+ [[-1, 0, 1]], # inside are 1x3 matrices
+ [[2, 3, 4]],
+ [[5, 6, 7]],
+ [[8, 9, 10]]
+ ],
+ [ # middle dimension is time.
+ [[21, 2, 11]], # inside are 1x3 matrices
+ [[0, 0, 0]], # additional entries are padded with 0
+ [[0, 0, 0]],
+ [[0, 0, 0]]
+ ]
+ ],
+ dtype=np.int64),
+ "d":
+ np.empty(shape=(2, 0, 5), dtype=np.float32), # allowed_missing
+ }
+
+ self._test(
+ {
+ "example_names": ops.convert_to_tensor(["in1", "in2"]),
+ "serialized": ops.convert_to_tensor(serialized),
+ "sequence_features": {
+ "a":
+ parsing_ops.FixedLenSequenceFeature((1, 3), dtypes.int64),
+ "d":
+ parsing_ops.FixedLenSequenceFeature(
+ (5,), dtypes.float32, allow_missing=True),
+ }
+ },
+ expected_feat_list_values=expected_feature_list_output,
+ expected_length_values={
+ "a": [4, 1],
+ "d": [0, 0]
+ },
+ batch=True)
+
class DecodeJSONExampleTest(test.TestCase):
@@ -1531,24 +1672,27 @@ class DecodeJSONExampleTest(test.TestCase):
example(features=features({
"st_d": feature()
})),
- example(features=features({
- "st_c": float_feature([1, 2, -1]),
- "st_d": bytes_feature([b"hi"])
- })),
+ example(
+ features=features({
+ "st_c": float_feature([1, 2, -1]),
+ "st_d": bytes_feature([b"hi"])
+ })),
])
def testSerializedContainingBytes(self):
aname = "a"
bname = "b*has+a:tricky_name"
self._testRoundTrip([
- example(features=features({
- aname: float_feature([1, 1]),
- bname: bytes_feature([b"b0_str"])
- })),
- example(features=features({
- aname: float_feature([-1, -1]),
- bname: bytes_feature([b"b1"])
- })),
+ example(
+ features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str"])
+ })),
+ example(
+ features=features({
+ aname: float_feature([-1, -1]),
+ bname: bytes_feature([b"b1"])
+ })),
])
def testInvalidSyntax(self):
diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py
index ba9359d923..15d5702252 100644
--- a/tensorflow/python/kernel_tests/partitioned_variables_test.py
+++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -25,15 +27,13 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import saver as saver_lib
class PartitionerCreatorsTest(test.TestCase):
@@ -546,6 +546,32 @@ class PartitionedVariablesTestCase(test.TestCase):
partitioned_variables.create_partitioned_variables(
[10, 43], [1, 50], rnd.initialized_value())
+ def testControlDepsNone(self):
+ with self.test_session() as session:
+ c = constant_op.constant(1.0)
+ with ops.control_dependencies([c]):
+        # d gets the control dependency.
+ d = constant_op.constant(2.0)
+ # Partitioned variables do not.
+ var_x = variable_scope.get_variable(
+ "x",
+ shape=[2],
+ initializer=init_ops.ones_initializer(),
+ partitioner=partitioned_variables.variable_axis_size_partitioner(4))
+
+ ops_before_read = session.graph.get_operations()
+ var_x.as_tensor() # Caches the ops for subsequent reads.
+ reading_ops = [
+ op for op in session.graph.get_operations()
+ if op not in ops_before_read
+ ]
+
+ self.assertEqual([c.op], d.op.control_inputs)
+      # Tests that no control dependencies are added when reading a
+      # partitioned variable, just as when reading a regular variable.
+ for op in reading_ops:
+ self.assertEqual([], op.control_inputs)
+
def testConcat(self):
with self.test_session() as session:
var_x = variable_scope.get_variable(
@@ -571,57 +597,38 @@ class PartitionedVariablesTestCase(test.TestCase):
variables.global_variables_initializer().run()
self.assertAllClose(value.eval(), var_x.as_tensor().eval())
- def testVariableCreationInALoop(self):
- """Tests the variable created inside a loop can be used outside the loop."""
- with self.test_session():
- with variable_scope.variable_scope("ascope") as scope:
- def Body(i, _):
- var_x = variable_scope.get_variable(
- "x",
- shape=[2],
- initializer=init_ops.ones_initializer(),
- partitioner=partitioned_variables.variable_axis_size_partitioner(
- 4))
- return (i + 1, var_x.as_tensor())
-
- cond = lambda i, _: i < 2
- _, x = control_flow_ops.while_loop(
- cond, Body, (0, constant_op.constant([7, 8], dtypes.float32)))
+ def testMetaGraphSaveLoad(self):
+ save_prefix = os.path.join(self.get_temp_dir(), "ckpt")
+ save_graph = ops.Graph()
+ with save_graph.as_default(), self.test_session(
+ graph=save_graph) as session:
+ partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0)
+ with variable_scope.variable_scope("root", partitioner=partitioner):
+ v0 = variable_scope.get_variable(
+ "v0", dtype=dtypes.float32, shape=(10, 10))
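+        # fixed_size_partitioner(5, axis=0) splits the (10, 10) variable
+        # into 5 shards along axis 0.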
+ v0_list = v0._get_variable_list()
+ v0_part = v0._get_partitions()
+ self.assertEqual(len(v0_list), 5)
+ self.assertAllEqual(v0_part, (5, 1))
variables.global_variables_initializer().run()
- self.assertAllClose([1.0, 1.0], x.eval())
-
- scope.reuse_variables()
- var_x = variable_scope.get_variable(
- "x",
- shape=[2],
- initializer=init_ops.ones_initializer(),
- partitioner=partitioned_variables.variable_axis_size_partitioner(4))
-
- self.assertAllClose([1.0, 1.0], var_x.as_tensor().eval())
-
- def testReadInWhileLoop(self):
- """Tests the value is current (not cached) when read within a loop."""
- with self.test_session():
- var_x = variable_scope.get_variable(
- "x",
- shape=[2],
- initializer=init_ops.ones_initializer(),
- partitioner=partitioned_variables.variable_axis_size_partitioner(4))
-
- def Body(i, _):
- # Use a SGD step to update the variable's value.
- loss = math_ops.reduce_sum(var_x)
- optimizer = gradient_descent.GradientDescentOptimizer(1.0)
- minimize = optimizer.minimize(loss * 0.7)
- with ops.control_dependencies([minimize]):
- return (i + 1, var_x.as_tensor())
-
- cond = lambda i, _: i < 2
- _, x = control_flow_ops.while_loop(
- cond, Body, (0, constant_op.constant([7, 8], dtypes.float32)))
- variables.global_variables_initializer().run()
- self.assertAllClose([-0.4, -0.4], x.eval())
+ save_graph.get_collection_ref("partvar").append(v0)
+ saver = saver_lib.Saver()
+ save_graph.finalize()
+ save_path = saver.save(sess=session, save_path=save_prefix)
+ previous_value = session.run(
+ save_graph.get_tensor_by_name(v0.name + ":0"))
+
+ restore_graph = ops.Graph()
+ with restore_graph.as_default(), self.test_session(
+ graph=restore_graph) as session:
+ saver = saver_lib.import_meta_graph(save_path + ".meta")
+ saver.restore(sess=session, save_path=save_path)
+ v0, = save_graph.get_collection_ref("partvar")
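+      # The original graph's collection still holds the PartitionedVariable;
+      # its name is used to look up the restored tensor.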
+ self.assertIsInstance(v0, variables.PartitionedVariable)
+ self.assertAllEqual(
+ previous_value,
+ session.run(restore_graph.get_tensor_by_name(v0.name + ":0")))
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py
index 50154a45a8..5f5e24bd63 100644
--- a/tensorflow/python/kernel_tests/py_func_test.py
+++ b/tensorflow/python/kernel_tests/py_func_test.py
@@ -61,7 +61,7 @@ class PyFuncTest(test.TestCase):
for dtype in [dtypes.float16, dtypes.float32, dtypes.float64,
dtypes.uint8, dtypes.int8, dtypes.uint16, dtypes.int16,
dtypes.int32, dtypes.int64]:
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(1, dtype=dtype)
y = constant_op.constant(2, dtype=dtype)
z = self.evaluate(script_ops.py_func(sum_func, [x, y], dtype))
@@ -71,7 +71,7 @@ class PyFuncTest(test.TestCase):
def sub_func(x, y):
return x - y
for dtype in [dtypes.complex64, dtypes.complex128]:
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(1 + 1j, dtype=dtype)
y = constant_op.constant(2 - 2j, dtype=dtype)
z = self.evaluate(script_ops.py_func(sub_func, [x, y], dtype))
@@ -81,21 +81,21 @@ class PyFuncTest(test.TestCase):
def and_func(x, y):
return x and y
dtype = dtypes.bool
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(True, dtype=dtype)
y = constant_op.constant(False, dtype=dtype)
z = self.evaluate(script_ops.py_func(and_func, [x, y], dtype))
self.assertEqual(z, False)
def testSingleType(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(1.0, dtypes.float32)
y = constant_op.constant(2.0, dtypes.float32)
z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.float32))
self.assertEqual(z, np_func(1.0, 2.0).astype(np.float32))
def testScalar(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(1.0, dtypes.float32)
y = constant_op.constant(2.0, dtypes.float32)
z = self.evaluate(
@@ -103,7 +103,7 @@ class PyFuncTest(test.TestCase):
self.assertEqual(z[0], np_func(1.0, 2.0).astype(np.float32))
def testArray(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant([1.0, 2.0], dtypes.float64)
y = constant_op.constant([2.0, 3.0], dtypes.float64)
z = self.evaluate(script_ops.py_func(np_func, [x, y], [dtypes.float64]))
@@ -111,14 +111,14 @@ class PyFuncTest(test.TestCase):
np_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64))
def testComplexType(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(1 + 2j, dtypes.complex64)
y = constant_op.constant(3 + 4j, dtypes.complex64)
z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.complex64))
self.assertAllClose(z, np_func(1 + 2j, 3 + 4j))
def testRFFT(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant([1., 2., 3., 4.], dtypes.float32)
def rfft(x):
@@ -128,7 +128,7 @@ class PyFuncTest(test.TestCase):
self.assertAllClose(y, np.fft.rfft([1., 2., 3., 4.]))
def testPythonLiteral(self):
- with self.test_session():
+ with self.cached_session():
def literal(x):
return 1.0 if float(x) == 0.0 else 0.0
@@ -138,7 +138,7 @@ class PyFuncTest(test.TestCase):
self.assertAllClose(y, 1.0)
def testList(self):
- with self.test_session():
+ with self.cached_session():
def list_func(x):
return [x, x + 1]
@@ -150,7 +150,7 @@ class PyFuncTest(test.TestCase):
def testTuple(self):
# returns a tuple
- with self.test_session():
+ with self.cached_session():
def tuple_func(x):
return x, x + 1
@@ -161,7 +161,7 @@ class PyFuncTest(test.TestCase):
self.assertAllClose(y, [0.0, 1.0])
# returns a tuple, Tout and inp a tuple
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(0.0, dtypes.float64)
y = self.evaluate(
script_ops.py_func(tuple_func, (x,),
@@ -176,7 +176,7 @@ class PyFuncTest(test.TestCase):
def read_and_return_strings(x, y):
return x + y
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant([b"hello", b"hi"], dtypes.string)
y = self.evaluate(
script_ops.py_func(read_fixed_length_numpy_strings, [],
@@ -193,7 +193,7 @@ class PyFuncTest(test.TestCase):
def read_and_return_strings(x, y):
return x + y
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(["hello", "hi"], dtypes.string)
y = self.evaluate(
script_ops.py_func(read_fixed_length_numpy_strings, [],
@@ -210,7 +210,7 @@ class PyFuncTest(test.TestCase):
def read_and_return_strings(x, y):
return x + y
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(["hello", "hi"], dtypes.string)
y, = script_ops.py_func(read_object_array, [],
[dtypes.string])
@@ -219,19 +219,19 @@ class PyFuncTest(test.TestCase):
def testStringPadding(self):
correct = [b"this", b"is", b"a", b"test"]
- with self.test_session():
+ with self.cached_session():
s, = script_ops.py_func(lambda: [correct], [], [dtypes.string])
self.assertAllEqual(s.eval(), correct)
def testStringPaddingAreConvertedToBytes(self):
inp = ["this", "is", "a", "test"]
correct = [b"this", b"is", b"a", b"test"]
- with self.test_session():
+ with self.cached_session():
s, = script_ops.py_func(lambda: [inp], [], [dtypes.string])
self.assertAllEqual(s.eval(), correct)
def testLarge(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.zeros([1000000], dtype=np.float32)
y = script_ops.py_func(lambda x: x + 1, [x], [dtypes.float32])
z = script_ops.py_func(lambda x: x * 2, [x], [dtypes.float32])
@@ -239,12 +239,12 @@ class PyFuncTest(test.TestCase):
sess.run([y[0].op, z[0].op])
def testNoInput(self):
- with self.test_session():
+ with self.cached_session():
x = self.evaluate(script_ops.py_func(lambda: 42.0, [], dtypes.float64))
self.assertAllClose(x, 42.0)
def testAlias(self):
- with self.test_session():
+ with self.cached_session():
np_array = np.array([1.0, 2.0], dtype=np.float32)
tf_array = script_ops.py_func(lambda: np_array, [], [dtypes.float32])
value = tf_array + constant_op.constant([2.0, 3.0], dtype=dtypes.float32)
@@ -252,7 +252,7 @@ class PyFuncTest(test.TestCase):
self.assertAllEqual(np_array, [1.0, 2.0])
def testReturnUnicodeString(self):
- with self.test_session():
+ with self.cached_session():
correct = u"你好 世界"
def unicode_string():
@@ -262,7 +262,7 @@ class PyFuncTest(test.TestCase):
self.assertEqual(z.eval(), correct.encode("utf8"))
def testBadNumpyReturnType(self):
- with self.test_session():
+ with self.cached_session():
def bad():
# Structured numpy arrays aren't supported.
@@ -275,7 +275,7 @@ class PyFuncTest(test.TestCase):
y.eval()
def testBadReturnType(self):
- with self.test_session():
+ with self.cached_session():
def bad():
# Non-string python objects aren't supported.
@@ -288,7 +288,7 @@ class PyFuncTest(test.TestCase):
z.eval()
def testReturnInput(self):
- with self.test_session():
+ with self.cached_session():
def ident(x):
return x[0]
@@ -303,7 +303,7 @@ class PyFuncTest(test.TestCase):
self.assertEqual(0.0, z.eval(feed_dict={p: [0.0]}))
def testStateful(self):
- # Not using self.test_session(), which disables optimization.
+ # Not using self.cached_session(), which disables optimization.
with session_lib.Session() as sess:
producer = iter(range(3))
x, = script_ops.py_func(lambda: next(producer), [], [dtypes.int64])
@@ -312,7 +312,7 @@ class PyFuncTest(test.TestCase):
self.assertEqual(sess.run(x), 2)
def testStateless(self):
- # Not using self.test_session(), which disables optimization.
+ # Not using self.cached_session(), which disables optimization.
with session_lib.Session() as sess:
producer = iter(range(3))
x, = script_ops.py_func(
@@ -331,7 +331,7 @@ class PyFuncTest(test.TestCase):
self.assertEqual(None, ops.get_gradient_function(y.op))
def testCOrder(self):
- with self.test_session():
+ with self.cached_session():
val = [[1, 2], [3, 4]]
x, = script_ops.py_func(lambda: np.array(val, order="F"), [],
[dtypes.int64])
@@ -339,7 +339,7 @@ class PyFuncTest(test.TestCase):
def testParallel(self):
    # Tests that tf.py_func ops can run in parallel if they release the GIL.
- with self.test_session() as session:
+ with self.cached_session() as session:
q = queue.Queue(1)
def blocking_put():
@@ -375,7 +375,7 @@ class PyFuncTest(test.TestCase):
def value(self):
return self._value
- with self.test_session():
+ with self.cached_session():
s = State()
op = s.increment(constant_op.constant(2, dtypes.int64))
ret = self.evaluate(op)
@@ -389,7 +389,7 @@ class PyFuncTest(test.TestCase):
f = script_ops.py_func(
do_nothing, [constant_op.constant(3, dtypes.int64)], [], stateful=False)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(f), [])
def _testExceptionHandling(self, py_exp, tf_exp, eager=False):
@@ -417,21 +417,22 @@ class PyFuncTest(test.TestCase):
else:
f = script_ops.py_func(raise_exception, [], [])
- with self.test_session():
- with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
- self.evaluate(f)
+ with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
+ self.evaluate(f)
def testExceptionHandling(self):
- self._testExceptionHandling(ValueError, errors.InvalidArgumentError)
- self._testExceptionHandling(TypeError, errors.InvalidArgumentError)
- self._testExceptionHandling(StopIteration, errors.OutOfRangeError)
- self._testExceptionHandling(MemoryError, errors.ResourceExhaustedError)
- self._testExceptionHandling(NotImplementedError, errors.UnimplementedError)
+ with self.cached_session():
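+      # Run all of the exception cases below inside one cached session.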
+ self._testExceptionHandling(ValueError, errors.InvalidArgumentError)
+ self._testExceptionHandling(TypeError, errors.InvalidArgumentError)
+ self._testExceptionHandling(StopIteration, errors.OutOfRangeError)
+ self._testExceptionHandling(MemoryError, errors.ResourceExhaustedError)
+ self._testExceptionHandling(NotImplementedError,
+ errors.UnimplementedError)
- class WeirdError(Exception):
- pass
+ class WeirdError(Exception):
+ pass
- self._testExceptionHandling(WeirdError, errors.UnknownError)
+ self._testExceptionHandling(WeirdError, errors.UnknownError)
# ----- Tests shared by py_func and eager_py_func -----
def testCleanup(self):
@@ -452,7 +453,7 @@ class PyFuncTest(test.TestCase):
# (see #18292)
_ = script_ops.py_func(lambda x: x + c.shape[0], [c], [dtypes.float32])
_ = script_ops.eager_py_func(lambda x: x + c.shape[0], [c], [dtypes.float32])
-
+
# Call garbage collector to enforce deletion.
make_graphs()
ops.reset_default_graph()
@@ -565,6 +566,18 @@ class PyFuncTest(test.TestCase):
dy_dx = gradients_impl.gradients(y, x)[0]
self.assertEqual(self.evaluate(dy_dx), 6.0)
+ def testEagerGradientGraphTwoOutputs(self):
+
+ def f(x, y):
+ return x * y, x / y
+
+ x = constant_op.constant(3.0)
+ y = constant_op.constant(2.0)
+ fa, fb = script_ops.eager_py_func(f, inp=[x, y],
+ Tout=[dtypes.float32, dtypes.float32])
+ dy_dx = gradients_impl.gradients(fa + fb, x)[0]
+ self.assertEqual(self.evaluate(dy_dx), 2.5)
+
@test_util.run_in_graph_and_eager_modes
def testEagerGradientTapeMultipleArgs(self):
@@ -610,7 +623,7 @@ class PyFuncTest(test.TestCase):
func=log_huber, inp=[x, m], Tout=dtypes.float32)
dy_dx = gradients_impl.gradients(y, x)[0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Takes the first branch of log_huber.
y, dy_dx = sess.run([y, dy_dx], feed_dict={x: 1.0, m: 2.0})
self.assertEqual(y, 1.0)
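
A side note on the new testEagerGradientGraphTwoOutputs hunk above: it checks that gradients flow through a multi-output eager_py_func built in graph mode. A minimal standalone sketch of the same arithmetic, assuming the internal modules this test file already imports (script_ops, gradients_impl, constant_op, dtypes):

    from tensorflow.python.client import session as session_lib
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import gradients_impl
    from tensorflow.python.ops import script_ops

    def f(x, y):
      return x * y, x / y

    x = constant_op.constant(3.0)
    y = constant_op.constant(2.0)
    fa, fb = script_ops.eager_py_func(f, inp=[x, y],
                                      Tout=[dtypes.float32, dtypes.float32])
    # d(x*y)/dx = y = 2.0 and d(x/y)/dx = 1/y = 0.5, so the sum is 2.5.
    dy_dx = gradients_impl.gradients(fa + fb, x)[0]
    with session_lib.Session() as sess:
      assert sess.run(dy_dx) == 2.5
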
diff --git a/tensorflow/python/kernel_tests/random/random_crop_test.py b/tensorflow/python/kernel_tests/random/random_crop_test.py
index 6028be1228..8ded522320 100644
--- a/tensorflow/python/kernel_tests/random/random_crop_test.py
+++ b/tensorflow/python/kernel_tests/random/random_crop_test.py
@@ -30,12 +30,12 @@ class RandomCropTest(test.TestCase):
# No random cropping is performed since the size is value.shape.
for shape in (2, 1, 1), (2, 1, 3), (4, 5, 3):
value = np.arange(0, np.prod(shape), dtype=np.int32).reshape(shape)
- with self.test_session():
+ with self.cached_session():
crop = random_ops.random_crop(value, shape).eval()
self.assertAllEqual(crop, value)
def testContains(self):
- with self.test_session():
+ with self.cached_session():
shape = (3, 5, 7)
target = (2, 3, 4)
value = np.random.randint(1000000, size=shape)
@@ -57,7 +57,7 @@ class RandomCropTest(test.TestCase):
single = [1, 1, 1]
value = np.arange(size).reshape(shape)
- with self.test_session():
+ with self.cached_session():
crop = random_ops.random_crop(value, single, seed=7)
counts = np.zeros(size, dtype=np.int32)
for _ in range(num_samples):
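
The identity property asserted above (cropping to the input's full shape leaves it unchanged, because only one offset is valid) can be reproduced with the public TF 1.x API; a minimal sketch:

    import numpy as np
    import tensorflow as tf

    value = np.arange(6, dtype=np.int32).reshape(2, 1, 3)
    with tf.Session() as sess:
      # A crop whose size equals value.shape has exactly one valid offset.
      crop = sess.run(tf.random_crop(value, [2, 1, 3]))
    assert (crop == value).all()
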
diff --git a/tensorflow/python/kernel_tests/random/random_gamma_test.py b/tensorflow/python/kernel_tests/random/random_gamma_test.py
index aa40228dc1..d969944493 100644
--- a/tensorflow/python/kernel_tests/random/random_gamma_test.py
+++ b/tensorflow/python/kernel_tests/random/random_gamma_test.py
@@ -256,7 +256,7 @@ class RandomGammaTest(test.TestCase):
def testPositive(self):
n = int(10e3)
for dt in [dtypes.float16, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
x = random_ops.random_gamma(shape=[n], alpha=0.001, dtype=dt, seed=0)
self.assertEqual(0, math_ops.reduce_sum(math_ops.cast(
math_ops.less_equal(x, 0.), dtype=dtypes.int64)).eval())
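
The positivity check above reads as: count the samples that are less than or equal to zero and require that count to be zero. The same check as a public-API sketch (TF 1.x assumed; float64 shown):

    import tensorflow as tf

    x = tf.random_gamma(shape=[10000], alpha=0.001, dtype=tf.float64, seed=0)
    # Gamma samples are strictly positive, even for tiny alpha where the
    # mass concentrates near zero.
    bad = tf.reduce_sum(tf.cast(tf.less_equal(x, 0.), tf.int64))
    with tf.Session() as sess:
      assert sess.run(bad) == 0
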
diff --git a/tensorflow/python/kernel_tests/random/random_grad_test.py b/tensorflow/python/kernel_tests/random/random_grad_test.py
index c1d455b785..d89056c485 100644
--- a/tensorflow/python/kernel_tests/random/random_grad_test.py
+++ b/tensorflow/python/kernel_tests/random/random_grad_test.py
@@ -49,7 +49,7 @@ class AddLeadingUnitDimensionsTest(test.TestCase):
x = array_ops.placeholder(dtypes.float32)
num_dimensions = array_ops.placeholder(dtypes.int32)
ret = random_grad.add_leading_unit_dimensions(x, num_dimensions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ret_val = sess.run(ret, {x: np.ones([2, 2]), num_dimensions: 2})
self.assertAllEqual(ret_val.shape, [1, 1, 2, 2])
@@ -99,7 +99,7 @@ class RandomGammaGradTest(test.TestCase):
alpha_val = np.ones([1, 2])
beta_val = np.ones([2, 1])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
grads_alpha_val, grads_beta_val = sess.run(
[grads_alpha, grads_beta],
{alpha: alpha_val, beta: beta_val, shape: [2, 1]})
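
For intuition, add_leading_unit_dimensions (exercised above) pads a tensor's shape with leading 1s; the asserted shapes follow from a plain reshape. A NumPy sketch of the same contract:

    import numpy as np

    def add_leading_unit_dimensions(x, num_dimensions):
      # Prepend `num_dimensions` axes of size 1: ones([2, 2]) with
      # num_dimensions=2 becomes shape (1, 1, 2, 2), as the test asserts.
      return np.reshape(x, (1,) * num_dimensions + x.shape)

    assert add_leading_unit_dimensions(np.ones([2, 2]), 2).shape == (1, 1, 2, 2)
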
diff --git a/tensorflow/python/kernel_tests/random/random_poisson_test.py b/tensorflow/python/kernel_tests/random/random_poisson_test.py
index afdf71e652..15ab95cdb7 100644
--- a/tensorflow/python/kernel_tests/random/random_poisson_test.py
+++ b/tensorflow/python/kernel_tests/random/random_poisson_test.py
@@ -137,7 +137,7 @@ class RandomPoissonTest(test.TestCase):
self.assertGreaterEqual(np.linalg.norm(diff.eval()), 1)
def testZeroShape(self):
- with self.test_session():
+ with self.cached_session():
rnd = random_ops.random_poisson([], [], seed=12345)
self.assertEqual([0], rnd.get_shape().as_list())
self.assertAllClose(np.array([], dtype=np.float32), rnd.eval())
@@ -186,7 +186,7 @@ class RandomPoissonTest(test.TestCase):
def testDTypeCombinationsV2(self):
"""Tests random_poisson_v2() for all supported dtype combinations."""
- with self.test_session():
+ with self.cached_session():
for lam_dt in _SUPPORTED_DTYPES:
for out_dt in _SUPPORTED_DTYPES:
random_ops.random_poisson(
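
The zero-shape test above pins down an edge case: with an empty lam, random_poisson has a statically known shape of [0] and yields an empty sample. As a public-API sketch (TF 1.x):

    import tensorflow as tf

    rnd = tf.random_poisson([], [], seed=12345)  # empty lam, scalar sample shape
    assert rnd.get_shape().as_list() == [0]
    with tf.Session() as sess:
      assert sess.run(rnd).tolist() == []
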
diff --git a/tensorflow/python/kernel_tests/random/random_shuffle_queue_test.py b/tensorflow/python/kernel_tests/random/random_shuffle_queue_test.py
index b7a79f239c..0d85a072d4 100644
--- a/tensorflow/python/kernel_tests/random/random_shuffle_queue_test.py
+++ b/tensorflow/python/kernel_tests/random/random_shuffle_queue_test.py
@@ -46,7 +46,7 @@ class RandomShuffleQueueTest(test.TestCase):
tf_logging.error("Finished: %s", self._testMethodName)
def testEnqueue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.RandomShuffleQueue(10, 5, dtypes_lib.float32)
enqueue_op = q.enqueue((10.0,))
self.assertAllEqual(0, q.size().eval())
@@ -54,7 +54,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertAllEqual(1, q.size().eval())
def testEnqueueWithShape(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.RandomShuffleQueue(
10, 5, dtypes_lib.float32, shapes=tensor_shape.TensorShape([3, 2]))
enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
@@ -64,7 +64,7 @@ class RandomShuffleQueueTest(test.TestCase):
q.enqueue(([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],))
def testEnqueueManyWithShape(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.RandomShuffleQueue(
10, 5, [dtypes_lib.int32, dtypes_lib.int32], shapes=[(), (2,)])
q.enqueue_many([[1, 2, 3, 4], [[1, 1], [2, 2], [3, 3], [4, 4]]]).run()
@@ -76,7 +76,7 @@ class RandomShuffleQueueTest(test.TestCase):
q2.enqueue_many(([[1, 2, 3]],))
def testScalarShapes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(
10, 0, [dtypes_lib.int32, dtypes_lib.int32], shapes=[(), (1,)])
q.enqueue_many([[1, 2, 3, 4], [[5], [6], [7], [8]]]).run()
@@ -93,7 +93,7 @@ class RandomShuffleQueueTest(test.TestCase):
results)
def testParallelEnqueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -119,7 +119,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(elems, results)
def testParallelDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -143,7 +143,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(elems, results)
def testDequeue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -156,7 +156,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(elems, vals)
def testEnqueueAndBlockingDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(3, 0, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -185,7 +185,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(elems, results)
def testMultiEnqueueAndDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(
10, 0, (dtypes_lib.int32, dtypes_lib.float32))
elems = [(5, 10.0), (10, 20.0), (15, 30.0)]
@@ -202,12 +202,12 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(elems, results)
def testQueueSizeEmpty(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.RandomShuffleQueue(10, 5, dtypes_lib.float32)
self.assertEqual(0, q.size().eval())
def testQueueSizeAfterEnqueueAndDequeue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32)
enqueue_op = q.enqueue((10.0,))
dequeued_t = q.dequeue()
@@ -220,7 +220,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertEqual([0], size.eval())
def testEnqueueMany(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -234,7 +234,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(elems + elems, results)
def testEmptyEnqueueMany(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.RandomShuffleQueue(10, 5, dtypes_lib.float32)
empty_t = constant_op.constant(
[], dtype=dtypes_lib.float32, shape=[0, 2, 3])
@@ -246,7 +246,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertEqual(0, size_t.eval())
def testEmptyDequeueMany(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32, shapes=())
enqueue_op = q.enqueue((10.0,))
dequeued_t = q.dequeue_many(0)
@@ -256,7 +256,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertEqual([], dequeued_t.eval().tolist())
def testEmptyDequeueUpTo(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32, shapes=())
enqueue_op = q.enqueue((10.0,))
dequeued_t = q.dequeue_up_to(0)
@@ -266,7 +266,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertEqual([], dequeued_t.eval().tolist())
def testEmptyDequeueManyWithNoShape(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32)
enqueue_op = q.enqueue((constant_op.constant(
[10.0, 20.0], shape=(1, 2)),))
@@ -287,7 +287,7 @@ class RandomShuffleQueueTest(test.TestCase):
dequeued_t.eval()
def testEmptyDequeueUpToWithNoShape(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32)
enqueue_op = q.enqueue((constant_op.constant(
[10.0, 20.0], shape=(1, 2)),))
@@ -308,7 +308,7 @@ class RandomShuffleQueueTest(test.TestCase):
dequeued_t.eval()
def testMultiEnqueueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(
10, 0, (dtypes_lib.float32, dtypes_lib.int32))
float_elems = [10.0, 20.0, 30.0, 40.0]
@@ -327,7 +327,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(expected, results)
def testDequeueMany(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_op = q.enqueue_many((elems,))
@@ -340,7 +340,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(elems, results)
def testDequeueUpToNoBlocking(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_op = q.enqueue_many((elems,))
@@ -353,7 +353,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(elems, results)
def testMultiDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(
10, 0, (dtypes_lib.float32, dtypes_lib.int32), shapes=((), (2,)))
float_elems = [
@@ -387,7 +387,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(zip(float_elems, int_elems), results)
def testMultiDequeueUpToNoBlocking(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(
10, 0, (dtypes_lib.float32, dtypes_lib.int32), shapes=((), (2,)))
float_elems = [
@@ -422,7 +422,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(zip(float_elems, int_elems), results)
def testHighDimension(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.int32, (
(4, 4, 4, 4)))
elems = np.array([[[[[x] * 4] * 4] * 4] * 4 for x in range(10)], np.int32)
@@ -433,7 +433,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(dequeued_t.eval().tolist(), elems.tolist())
def testParallelEnqueueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(
1000, 0, dtypes_lib.float32, shapes=())
elems = [10.0 * x for x in range(100)]
@@ -453,7 +453,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(dequeued_t.eval(), elems * 10)
def testParallelDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(
1000, 0, dtypes_lib.float32, shapes=())
elems = [10.0 * x for x in range(1000)]
@@ -476,7 +476,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(elems, dequeued_elems)
def testParallelDequeueUpTo(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(
1000, 0, dtypes_lib.float32, shapes=())
elems = [10.0 * x for x in range(1000)]
@@ -499,7 +499,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(elems, dequeued_elems)
def testParallelDequeueUpToRandomPartition(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dequeue_sizes = [random.randint(50, 150) for _ in xrange(10)]
total_elements = sum(dequeue_sizes)
q = data_flow_ops.RandomShuffleQueue(
@@ -527,7 +527,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(elems, dequeued_elems)
def testBlockingDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -554,7 +554,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(elems, dequeued_elems)
def testBlockingDequeueUpTo(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -581,7 +581,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(elems, dequeued_elems)
def testDequeueManyWithTensorParameter(self):
- with self.test_session():
+ with self.cached_session():
# Define a first queue that contains integer counts.
dequeue_counts = [random.randint(1, 10) for _ in range(100)]
count_q = data_flow_ops.RandomShuffleQueue(100, 0, dtypes_lib.int32)
@@ -607,7 +607,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(elems, dequeued_elems)
def testDequeueUpToWithTensorParameter(self):
- with self.test_session():
+ with self.cached_session():
# Define a first queue that contains integer counts.
dequeue_counts = [random.randint(1, 10) for _ in range(100)]
count_q = data_flow_ops.RandomShuffleQueue(100, 0, dtypes_lib.int32)
@@ -633,7 +633,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(elems, dequeued_elems)
def testDequeueFromClosedQueue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.RandomShuffleQueue(10, 2, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -652,7 +652,7 @@ class RandomShuffleQueueTest(test.TestCase):
dequeued_t.eval()
def testBlockingDequeueFromClosedQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
min_size = 2
q = data_flow_ops.RandomShuffleQueue(10, min_size, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0]
@@ -690,7 +690,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertEqual(len(results), 4)
def testBlockingDequeueFromClosedEmptyQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32)
close_op = q.close()
dequeued_t = q.dequeue()
@@ -715,7 +715,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertEqual(len(finished), 1)
def testBlockingDequeueManyFromClosedQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -751,7 +751,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertEqual(len(progress), 2)
def testBlockingDequeueUpToFromClosedQueueReturnsRemainder(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -778,7 +778,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(results, elems)
def testBlockingDequeueUpToSmallerThanMinAfterDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(
capacity=10,
min_after_dequeue=2,
@@ -811,7 +811,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(results, elems)
def testBlockingDequeueManyFromClosedQueueWithElementsRemaining(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -845,7 +845,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertEqual(len(results), 4)
def testBlockingDequeueManyFromClosedEmptyQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(10, 5, dtypes_lib.float32, ((),))
close_op = q.close()
dequeued_t = q.dequeue_many(4)
@@ -865,7 +865,7 @@ class RandomShuffleQueueTest(test.TestCase):
dequeue_thread.join()
def testBlockingDequeueUpToFromClosedEmptyQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(10, 5, dtypes_lib.float32, ((),))
close_op = q.close()
dequeued_t = q.dequeue_up_to(4)
@@ -885,7 +885,7 @@ class RandomShuffleQueueTest(test.TestCase):
dequeue_thread.join()
def testEnqueueToClosedQueue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.RandomShuffleQueue(10, 4, dtypes_lib.float32)
enqueue_op = q.enqueue((10.0,))
close_op = q.close()
@@ -898,7 +898,7 @@ class RandomShuffleQueueTest(test.TestCase):
enqueue_op.run()
def testEnqueueManyToClosedQueue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.RandomShuffleQueue(10, 5, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -912,7 +912,7 @@ class RandomShuffleQueueTest(test.TestCase):
enqueue_op.run()
def testBlockingEnqueueToFullQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(4, 0, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -940,7 +940,7 @@ class RandomShuffleQueueTest(test.TestCase):
thread.join()
def testBlockingEnqueueManyToFullQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(4, 0, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -974,7 +974,7 @@ class RandomShuffleQueueTest(test.TestCase):
thread.join()
def testBlockingEnqueueToClosedQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(4, 0, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1019,7 +1019,7 @@ class RandomShuffleQueueTest(test.TestCase):
thread1.join()
def testBlockingEnqueueManyToClosedQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(4, 0, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1067,7 +1067,7 @@ class RandomShuffleQueueTest(test.TestCase):
sess.run(blocking_enqueue_op)
def testSharedQueueSameSession(self):
- with self.test_session():
+ with self.cached_session():
q1 = data_flow_ops.RandomShuffleQueue(
1, 0, dtypes_lib.float32, ((),), shared_name="shared_queue")
q1.enqueue((10.0,)).run()
@@ -1104,7 +1104,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertEqual(q2_size_t.eval(), 0)
def testSharedQueueSameSessionGraphSeedNone(self):
- with self.test_session():
+ with self.cached_session():
q1 = data_flow_ops.RandomShuffleQueue(
1,
0,
@@ -1127,7 +1127,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertEqual(q2_size_t.eval(), 1)
def testIncompatibleSharedQueueErrors(self):
- with self.test_session():
+ with self.cached_session():
q_a_1 = data_flow_ops.RandomShuffleQueue(
10, 5, dtypes_lib.float32, shared_name="q_a")
q_a_2 = data_flow_ops.RandomShuffleQueue(
@@ -1193,7 +1193,7 @@ class RandomShuffleQueueTest(test.TestCase):
q_h_2.queue_ref.op.run()
def testSelectQueue(self):
- with self.test_session():
+ with self.cached_session():
num_queues = 10
qlist = list()
for _ in xrange(num_queues):
@@ -1207,7 +1207,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertEqual(q.dequeue().eval(), 10.0)
def testSelectQueueOutOfRange(self):
- with self.test_session():
+ with self.cached_session():
q1 = data_flow_ops.RandomShuffleQueue(10, 0, dtypes_lib.float32)
q2 = data_flow_ops.RandomShuffleQueue(15, 0, dtypes_lib.float32)
enq_q = data_flow_ops.RandomShuffleQueue.from_list(3, [q1, q2])
@@ -1235,7 +1235,7 @@ class RandomShuffleQueueTest(test.TestCase):
sess.run(enqueue_many_op)
def testResetOfBlockingOperation(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q_empty = data_flow_ops.RandomShuffleQueue(5, 0, dtypes_lib.float32, (
(),))
dequeue_op = q_empty.dequeue()
@@ -1267,7 +1267,7 @@ class RandomShuffleQueueTest(test.TestCase):
t.join()
def testDequeueManyInDifferentOrders(self):
- with self.test_session():
+ with self.cached_session():
# Specify seeds to make the test deterministic
# (https://en.wikipedia.org/wiki/Taxicab_number).
q1 = data_flow_ops.RandomShuffleQueue(
@@ -1301,7 +1301,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertNotEqual(results[i], results[j])
def testDequeueUpToInDifferentOrders(self):
- with self.test_session():
+ with self.cached_session():
# Specify seeds to make the test deterministic
# (https://en.wikipedia.org/wiki/Taxicab_number).
q1 = data_flow_ops.RandomShuffleQueue(
@@ -1335,7 +1335,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertNotEqual(results[i], results[j])
def testDequeueInDifferentOrders(self):
- with self.test_session():
+ with self.cached_session():
# Specify seeds to make the test deterministic
# (https://en.wikipedia.org/wiki/Taxicab_number).
q1 = data_flow_ops.RandomShuffleQueue(
@@ -1371,7 +1371,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertNotEqual(results[i], results[j])
def testBigEnqueueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(5, 0, dtypes_lib.int32, ((),))
elem = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
enq = q.enqueue_many((elem,))
@@ -1416,7 +1416,7 @@ class RandomShuffleQueueTest(test.TestCase):
self.assertItemsEqual(elem, results)
def testBigDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.RandomShuffleQueue(2, 0, dtypes_lib.int32, ((),))
elem = np.arange(4, dtype=np.int32)
enq_list = [q.enqueue((e,)) for e in elem]
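
For readers skimming this long section: RandomShuffleQueue(capacity, min_after_dequeue, dtypes) dequeues uniformly random elements and blocks whenever fewer than min_after_dequeue elements would remain, which is why so many of the tests above drive it from helper threads. A minimal public-API sketch of the non-blocking case:

    import tensorflow as tf

    q = tf.RandomShuffleQueue(capacity=10, min_after_dequeue=0,
                              dtypes=tf.float32)
    enqueue = q.enqueue_many(([10.0, 20.0, 30.0],))
    dequeue = q.dequeue()
    with tf.Session() as sess:
      sess.run(enqueue)
      # With min_after_dequeue=0 the queue can be fully drained; the order
      # is random, so compare as a multiset.
      vals = sorted(sess.run(dequeue) for _ in range(3))
    assert vals == [10.0, 20.0, 30.0]
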
diff --git a/tensorflow/python/kernel_tests/regex_full_match_op_test.py b/tensorflow/python/kernel_tests/regex_full_match_op_test.py
index 5daae1b79b..7bd8c3ca27 100644
--- a/tensorflow/python/kernel_tests/regex_full_match_op_test.py
+++ b/tensorflow/python/kernel_tests/regex_full_match_op_test.py
@@ -18,37 +18,77 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
+
+from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
-class RegexFullMatchOpTest(test.TestCase):
+@parameterized.parameters(
+ (gen_string_ops.regex_full_match),
+ (gen_string_ops.static_regex_full_match))
+class RegexFullMatchOpVariantsTest(test.TestCase, parameterized.TestCase):
- def testRegexFullMatch(self):
+ def testRegexFullMatch(self, op):
values = ["abaaba", "abcdabcde"]
with self.test_session():
- input_vector = constant_op.constant(values, dtypes.string)
- matched = string_ops.regex_full_match(input_vector, "a.*a").eval()
+ input_tensor = constant_op.constant(values, dtypes.string)
+ matched = op(input_tensor, "a.*a").eval()
self.assertAllEqual([True, False], matched)
- def testEmptyMatch(self):
+ def testRegexFullMatchTwoDims(self, op):
+ values = [["abaaba", "abcdabcde"], ["acdcba", "ebcda"]]
+ with self.test_session():
+ input_tensor = constant_op.constant(values, dtypes.string)
+ matched = op(input_tensor, "a.*a").eval()
+ self.assertAllEqual([[True, False], [True, False]], matched)
+
+ def testEmptyMatch(self, op):
values = ["abc", "1"]
with self.test_session():
- input_vector = constant_op.constant(values, dtypes.string)
- matched = string_ops.regex_full_match(input_vector, "").eval()
+ input_tensor = constant_op.constant(values, dtypes.string)
+ matched = op(input_tensor, "").eval()
self.assertAllEqual([False, False], matched)
- def testInvalidPattern(self):
+ def testInvalidPattern(self, op):
values = ["abc", "1"]
with self.test_session():
- input_vector = constant_op.constant(values, dtypes.string)
+ input_tensor = constant_op.constant(values, dtypes.string)
invalid_pattern = "A["
- matched = string_ops.regex_full_match(input_vector, invalid_pattern)
+ matched = op(input_tensor, invalid_pattern)
with self.assertRaisesOpError("Invalid pattern"):
matched.eval()
+class RegexFullMatchOpTest(test.TestCase):
+
+ def testRegexFullMatchDelegation(self):
+ with compat.forward_compatibility_horizon(2018, 11, 1):
+ with self.test_session():
+ input_tensor = constant_op.constant("foo", dtypes.string)
+ pattern = "[a-z]"
+ op = string_ops.regex_full_match(input_tensor, pattern)
+ self.assertTrue(op.name.startswith("RegexFullMatch"), op.name)
+
+ pattern_tensor = constant_op.constant("[a-z]*", dtypes.string)
+ op_tensor = string_ops.regex_full_match(input_tensor, pattern_tensor)
+ self.assertTrue(op_tensor.name.startswith("RegexFullMatch"), op_tensor.name)
+
+ def testStaticRegexFullMatchDelegation(self):
+ with compat.forward_compatibility_horizon(2018, 11, 20):
+ with self.test_session():
+ input_tensor = constant_op.constant("foo", dtypes.string)
+ pattern = "[a-z]*"
+ op = string_ops.regex_full_match(input_tensor, pattern)
+ self.assertTrue(op.name.startswith("StaticRegexFullMatch"), op.name)
+
+ pattern_tensor = constant_op.constant("[a-z]*", dtypes.string)
+ op_vec = string_ops.regex_full_match(input_tensor, pattern_tensor)
+ self.assertTrue(op_vec.name.startswith("RegexFullMatch"), op_vec.name)
+
if __name__ == "__main__":
test.main()
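
The delegation tests above hinge on whether the pattern is a Python string at graph-construction time (lowered to StaticRegexFullMatch once the forward-compatibility window opens) or a Tensor (always the dynamic RegexFullMatch kernel); either way the op performs a full-string match, not a search. A hedged sketch reusing this file's internal imports:

    from tensorflow.python.client import session as session_lib
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import string_ops

    x = constant_op.constant(["abaaba", "abcdabcde"], dtypes.string)
    matched = string_ops.regex_full_match(x, "a.*a")  # full match, not search
    with session_lib.Session() as sess:
      # "abaaba" is consumed entirely by a.*a; "abcdabcde" ends in "e".
      assert list(sess.run(matched)) == [True, False]
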
diff --git a/tensorflow/python/kernel_tests/regex_replace_op_test.py b/tensorflow/python/kernel_tests/regex_replace_op_test.py
index 6739ac3224..f0e84b8fca 100644
--- a/tensorflow/python/kernel_tests/regex_replace_op_test.py
+++ b/tensorflow/python/kernel_tests/regex_replace_op_test.py
@@ -18,54 +18,104 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
+
+from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
-class RegexReplaceOpTest(test.TestCase):
+@parameterized.parameters(
+ (gen_string_ops.regex_replace),
+ (gen_string_ops.static_regex_replace))
+class RegexReplaceOpVariantsTest(test.TestCase, parameterized.TestCase):
+
+ def testForwarding(self, op):
+ with self.test_session():
+ # Generate an input that is uniquely consumed by the regex op.
+ # This exercises code paths which are optimized for this case
+ # (e.g., using forwarding).
+ inp = string_ops.substr(
+ constant_op.constant(["AbCdEfG",
+ "HiJkLmN"], dtypes.string),
+ pos=0,
+ len=5)
+ stripped = op(inp, "\\p{Ll}", ".").eval()
+ self.assertAllEqual([b"A.C.E", b"H.J.L"], stripped)
- def testRemovePrefix(self):
+ def testRemovePrefix(self, op):
values = ["a:foo", "a:bar", "a:foo", "b:baz", "b:qux", "ca:b"]
with self.test_session():
input_vector = constant_op.constant(values, dtypes.string)
- stripped = string_ops.regex_replace(
- input_vector, "^(a:|b:)", "", replace_global=False).eval()
+ stripped = op(input_vector, "^(a:|b:)", "", replace_global=False).eval()
self.assertAllEqual([b"foo", b"bar", b"foo", b"baz", b"qux", b"ca:b"],
stripped)
- def testRegexReplace(self):
+ def testRegexReplace(self, op):
values = ["aba\naba", "abcdabcde"]
with self.test_session():
input_vector = constant_op.constant(values, dtypes.string)
- stripped = string_ops.regex_replace(input_vector, "a.*a", "(\\0)").eval()
+ stripped = op(input_vector, "a.*a", "(\\0)").eval()
self.assertAllEqual([b"(aba)\n(aba)", b"(abcda)bcde"], stripped)
- def testEmptyMatch(self):
+ def testEmptyMatch(self, op):
values = ["abc", "1"]
with self.test_session():
input_vector = constant_op.constant(values, dtypes.string)
- stripped = string_ops.regex_replace(input_vector, "", "x").eval()
+ stripped = op(input_vector, "", "x").eval()
self.assertAllEqual([b"xaxbxcx", b"x1x"], stripped)
- def testInvalidPattern(self):
+ def testInvalidPattern(self, op):
values = ["abc", "1"]
with self.test_session():
input_vector = constant_op.constant(values, dtypes.string)
invalid_pattern = "A["
- replace = string_ops.regex_replace(input_vector, invalid_pattern, "x")
+ replace = op(input_vector, invalid_pattern, "x")
with self.assertRaisesOpError("Invalid pattern"):
replace.eval()
- def testGlobal(self):
+ def testGlobal(self, op):
values = ["ababababab", "abcabcabc", ""]
with self.test_session():
input_vector = constant_op.constant(values, dtypes.string)
- stripped = string_ops.regex_replace(input_vector, "ab", "abc",
- True).eval()
+ stripped = op(input_vector, "ab", "abc", True).eval()
self.assertAllEqual([b"abcabcabcabcabc", b"abccabccabcc", b""], stripped)
+def as_string(s):
+ return s
+
+
+def as_tensor(s):
+ return constant_op.constant(s, dtypes.string)
+
+
+class RegexReplaceTest(test.TestCase, parameterized.TestCase):
+
+ @parameterized.parameters(
+ (as_string, as_tensor),
+ (as_tensor, as_string),
+ (as_tensor, as_tensor))
+ def testRegexReplaceDelegation(self, pattern_fn, rewrite_fn):
+ with compat.forward_compatibility_horizon(2018, 10, 11):
+ with self.test_session():
+ input_vector = constant_op.constant("foo", dtypes.string)
+ pattern = pattern_fn("[a-z]")
+ replace = rewrite_fn(".")
+ op = string_ops.regex_replace(input_vector, pattern, replace)
+ self.assertTrue(op.name.startswith("RegexReplace"))
+
+ def testStaticRegexReplaceDelegation(self):
+ with compat.forward_compatibility_horizon(2018, 10, 11):
+ with self.test_session():
+ input_vector = constant_op.constant("foo", dtypes.string)
+ pattern = "[a-z]"
+ replace = "."
+ op = string_ops.regex_replace(input_vector, pattern, replace)
+ self.assertTrue(op.name.startswith("StaticRegexReplace"))
+
if __name__ == "__main__":
test.main()
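
One behavior worth calling out from the hunks above: replace_global=False rewrites only the first match per element, and an anchored pattern such as ^(a:|b:) therefore strips only a true prefix. A standalone sketch with the same internal modules:

    from tensorflow.python.client import session as session_lib
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import string_ops

    inp = constant_op.constant(["a:foo", "ca:b"], dtypes.string)
    out = string_ops.regex_replace(inp, "^(a:|b:)", "", replace_global=False)
    with session_lib.Session() as sess:
      # "ca:b" does not start with "a:" or "b:", so it is left untouched.
      assert list(sess.run(out)) == [b"foo", b"ca:b"]
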
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index 3e24b8a2c4..86d9c90e83 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -24,6 +24,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
@@ -72,6 +73,35 @@ class ReluTest(test.TestCase):
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
use_gpu=True)
+ def _testReluInt8x4(self, np_inputs):
+ if not test.is_gpu_available(cuda_only=True):
+ return
+ np_relu = self._npRelu(np_inputs)
+ with self.test_session(use_gpu=True):
+ relu = nn_ops.relu(constant_op.constant(np_inputs, dtypes.qint8))
+ if np_inputs.size % 4 == 0:
+ tf_relu = relu.eval()
+ self.assertAllClose(np_relu, tf_relu)
+ self.assertShapeEqual(np_relu, relu)
+ else:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Tensor size must be a multiple of 4 for Relu<qint8>. Got %d" %
+ np_inputs.size):
+ tf_relu = relu.eval()
+
+ def testReluInt8x4GoodShape(self):
+ self._testReluInt8x4(np.array([[-50, 7, 23, 0], [-1, -5, 6, 11]]))
+
+ def testReluInt8x4BadShape(self):
+ np_inputs = np.array([[-50, 7, 23], [0, 1, -5], [6, -2, 11]])
+ self.assertEqual(np_inputs.size, 9)
+ self._testReluInt8x4(np_inputs)
+ np_inputs = np.array(
+ [1, -2, 3, -4, 5, -6, 7, -8, 9, -8, 7, -6, 5, -4, 3, -2, 1])
+ self.assertEqual(np_inputs.size, 17)
+ self._testReluInt8x4(np_inputs)
+
# The gradient test for ReLU is a bit tricky because the derivative is not
# well defined around zero, so we choose input values that stay away from it.
def testGradientFloat32(self):
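
The new qint8 tests above encode a kernel-level constraint: the fused Relu<qint8> path processes four int8 values per 32-bit word, so the flattened tensor size must be divisible by 4. A tiny NumPy sketch of just that precondition (the helper name is illustrative, not TF API):

    import numpy as np

    def relu_int8x4_size_ok(x):
      # Mirrors the error asserted above:
      # "Tensor size must be a multiple of 4 for Relu<qint8>. Got %d" % x.size
      return x.size % 4 == 0

    assert relu_int8x4_size_ok(np.zeros((2, 4), np.int8))      # 8 elements
    assert not relu_int8x4_size_ok(np.zeros((3, 3), np.int8))  # 9 elements
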
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index b1ef46f2a1..f90545f84c 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -17,7 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import copy
import gc
+import os
+import pickle
import numpy as np
@@ -51,7 +54,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.assertEqual(0, len(gc.garbage))
def testHandleDtypeShapeMatch(self):
- with self.test_session():
+ with self.cached_session():
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
with self.assertRaises(ValueError):
resource_variable_ops.assign_variable_op(
@@ -106,6 +109,34 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
v = resource_variable_ops.ResourceVariable(False, name="bool_test")
self.assertAllEqual(bool(v), False)
+ def testEagerDeepCopy(self):
+ with context.eager_mode():
+ init_value = np.ones((4, 4, 4))
+ variable = resource_variable_ops.ResourceVariable(init_value,
+ name="init")
+
+ copied_variable = copy.deepcopy(variable)
+ copied_variable.assign(4 * np.ones((4, 4, 4)))
+
+ # Copying the variable should create a new underlying tensor with distinct
+ # values.
+ self.assertFalse(np.allclose(variable.numpy(), copied_variable.numpy()))
+
+ def testGraphDeepCopy(self):
+ with self.cached_session():
+ init_value = np.ones((4, 4, 4))
+ variable = resource_variable_ops.ResourceVariable(init_value,
+ name="init")
+ with self.assertRaises(NotImplementedError):
+ copy.deepcopy(variable)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testStridedSliceAssign(self):
+ v = resource_variable_ops.ResourceVariable([1.0, 2.0])
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(v[0].assign(2.0))
+ self.assertAllEqual(self.evaluate(v), [2.0, 2.0])
+
def testDifferentAssignGraph(self):
with ops.Graph().as_default():
v = resource_variable_ops.ResourceVariable(1.0)
@@ -114,13 +145,13 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
# variable graph.
def testFetchHandle(self):
- with self.test_session():
+ with self.cached_session():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1], name="foo")
self.assertGreater(len(handle.eval()), 0)
def testCachedValueReadBeforeWrite(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = resource_variable_ops.ResourceVariable(0.0, caching_device="cpu:0")
sess.run(v.initializer)
value, _ = sess.run([v, v.assign_add(1.0)])
@@ -233,6 +264,18 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[5]])
+ def testEagerPickle(self):
+ with context.eager_mode():
+ tmp_dir = self.get_temp_dir()
+ fname = os.path.join(tmp_dir, "var.pickle")
+ with open(fname, "wb") as f:
+ v = resource_variable_ops.ResourceVariable(10.0)
+ pickle.dump(v, f)
+
+ with open(fname, "rb") as f:
+ v = pickle.load(f)
+ self.assertAllEqual(v.numpy(), 10.0)
+
@test_util.run_in_graph_and_eager_modes
def testScatterDiv(self):
handle = resource_variable_ops.var_handle_op(
@@ -449,7 +492,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
# TODO(alive): how should this work in Eager mode?
def testInitFn(self):
- with self.test_session():
+ with self.cached_session():
v = resource_variable_ops.ResourceVariable(
initial_value=lambda: 1, dtype=dtypes.float32)
self.assertEqual(v.handle.op.colocation_groups(),
@@ -526,11 +569,11 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.assertEqual(2.0, self.evaluate(v.value()))
def testVariableDefInitializedInstances(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v_def = resource_variable_ops.ResourceVariable(
initial_value=constant_op.constant(3.0)).to_proto()
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
# v describes a VariableDef-based variable without an initial value.
v = resource_variable_ops.ResourceVariable(variable_def=v_def)
self.assertEqual(3.0, sess.run(v.initialized_value()))
@@ -541,7 +584,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.assertEqual(1.0, v.initialized_value().eval())
v_def.ClearField("initial_value_name")
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
# Restoring a legacy VariableDef proto that does not have
# initial_value_name set should still work.
v = resource_variable_ops.ResourceVariable(variable_def=v_def)
@@ -572,17 +615,16 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
@test_util.run_in_graph_and_eager_modes
def testSparseRead(self):
- with self.test_session():
- init_value = np.reshape(np.arange(np.power(4, 3)), (4, 4, 4))
- v = resource_variable_ops.ResourceVariable(
- constant_op.constant(init_value, dtype=dtypes.int32), name="var3")
- self.evaluate(variables.global_variables_initializer())
+ init_value = np.reshape(np.arange(np.power(4, 3)), (4, 4, 4))
+ v = resource_variable_ops.ResourceVariable(
+ constant_op.constant(init_value, dtype=dtypes.int32), name="var3")
+ self.evaluate(variables.global_variables_initializer())
- value = self.evaluate(v.sparse_read([0, 3, 1, 2]))
- self.assertAllEqual(init_value[[0, 3, 1, 2], ...], value)
+ value = self.evaluate(v.sparse_read([0, 3, 1, 2]))
+ self.assertAllEqual(init_value[[0, 3, 1, 2], ...], value)
def testToFromProto(self):
- with self.test_session():
+ with self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
@@ -643,7 +685,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
handle, ignore_lookup_error=True))
def testAssignDifferentShapes(self):
- with self.test_session() as sess, variable_scope.variable_scope(
+ with self.cached_session() as sess, variable_scope.variable_scope(
"foo", use_resource=True):
var = variable_scope.get_variable("x", shape=[1, 1], dtype=dtypes.float32)
placeholder = array_ops.placeholder(dtypes.float32)
@@ -685,7 +727,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
_ = w.value().op.get_attr("_class")
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
v = resource_variable_ops.ResourceVariable(300.0, name="var4")
variables.global_variables_initializer().run()
@@ -703,7 +745,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
resource_variable_ops.read_variable_op(x, v.dtype.base_dtype).eval()
def testSharedNameWithNamescope(self):
- with self.test_session():
+ with self.cached_session():
with ops.name_scope("foo"):
v = resource_variable_ops.ResourceVariable(300.0, name="var6")
self.assertEqual("foo/var6", v._shared_name) # pylint: disable=protected-access
@@ -731,7 +773,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
str(v.sparse_read(array_ops.placeholder(dtypes.int32)).shape))
def testSetInitialValue(self):
- with self.test_session():
+ with self.cached_session():
# Initialize variable with a value different from the initial value passed
# in the constructor.
v = resource_variable_ops.ResourceVariable(2.0)
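
The new deepcopy and pickle tests above pin down eager-only semantics: deepcopy of an eager ResourceVariable clones the underlying buffer, while in graph mode it raises NotImplementedError. A hedged sketch of the eager case, assuming a build that already includes these changes:

    import copy
    import numpy as np
    from tensorflow.python.eager import context
    from tensorflow.python.ops import resource_variable_ops

    with context.eager_mode():
      v = resource_variable_ops.ResourceVariable(np.ones((2, 2)), name="init")
      w = copy.deepcopy(v)
      w.assign(4 * np.ones((2, 2)))
      # The copy owns its own tensor, so mutating it leaves `v` unchanged.
      assert not np.allclose(v.numpy(), w.numpy())
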
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index c72ada11da..a28cdc3b26 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -35,6 +35,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
+from tensorflow.python.keras import testing_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
@@ -44,11 +45,13 @@ from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables as variables_lib
import tensorflow.python.ops.data_flow_grad # pylint: disable=unused-import
+from tensorflow.python.ops.losses import losses
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
import tensorflow.python.ops.sparse_grad # pylint: disable=unused-import
import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.training import saver
+from tensorflow.python.training import training
class Plus1RNNCell(rnn_cell_impl.RNNCell):
@@ -194,7 +197,7 @@ class RNNTest(test.TestCase):
else:
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
- with self.test_session() as sess:
+ with self.cached_session(use_gpu=True) as sess:
outputs, state = rnn.dynamic_rnn(
cell, inputs, dtype=dtypes.float32, sequence_length=[4])
if not in_eager_mode:
@@ -214,7 +217,7 @@ class RNNTest(test.TestCase):
else:
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
- with self.test_session() as sess:
+ with self.cached_session(use_gpu=True) as sess:
outputs, state = rnn.dynamic_rnn(
cell, inputs, dtype=dtypes.float32, sequence_length=[4])
if not in_eager_mode:
@@ -226,6 +229,13 @@ class RNNTest(test.TestCase):
self.assertAllEqual([[[1, 1], [2, 2], [3, 3], [4, 4]]], outputs[1])
self.assertAllEqual(4, state)
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testEagerMemory(self):
+ with context.eager_mode():
+ cell = TensorArrayStateRNNCell()
+ inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32)
+ rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32, sequence_length=[4])
+
@test_util.run_in_graph_and_eager_modes
def testTensorArrayStateIsAccepted(self):
cell = TensorArrayStateRNNCell()
@@ -236,7 +246,7 @@ class RNNTest(test.TestCase):
else:
inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
- with self.test_session() as sess:
+ with self.cached_session(use_gpu=True) as sess:
outputs, state = rnn.dynamic_rnn(
cell, inputs, dtype=dtypes.float32, sequence_length=[4])
state = (state[0], state[1].stack())
@@ -250,12 +260,44 @@ class RNNTest(test.TestCase):
self.assertAllEqual(4, state[0])
self.assertAllEqual([[[1]], [[2]], [[3]], [[4]]], state[1])
+ def testCellGetInitialState(self):
+ cell = rnn_cell_impl.BasicRNNCell(5)
+ with self.assertRaisesRegexp(
+ ValueError, "batch_size and dtype cannot be None"):
+ cell.get_initial_state(None, None, None)
+
+ inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 1))
+ with self.assertRaisesRegexp(
+ ValueError, "batch size from input tensor is different from"):
+ cell.get_initial_state(inputs=inputs, batch_size=50, dtype=None)
+
+ with self.assertRaisesRegexp(
+ ValueError, "batch size from input tensor is different from"):
+ cell.get_initial_state(
+ inputs=inputs, batch_size=constant_op.constant(50), dtype=None)
+
+ with self.assertRaisesRegexp(
+ ValueError, "dtype from input tensor is different from"):
+ cell.get_initial_state(inputs=inputs, batch_size=None, dtype=dtypes.int16)
+
+ initial_state = cell.get_initial_state(
+ inputs=inputs, batch_size=None, dtype=None)
+ self.assertEqual(initial_state.shape.as_list(), [None, 5])
+ self.assertEqual(initial_state.dtype, inputs.dtype)
+
+ batch = array_ops.shape(inputs)[0]
+ dtype = inputs.dtype
+ initial_state = cell.get_initial_state(None, batch, dtype)
+ self.assertEqual(initial_state.shape.as_list(), [None, 5])
+ self.assertEqual(initial_state.dtype, inputs.dtype)
+
def _assert_cell_builds(self, cell_class, dtype, batch_size, in_size,
out_size):
cell = cell_class(out_size, dtype=dtype)
in_shape = tensor_shape.TensorShape((batch_size, in_size))
cell.build(in_shape)
- state_output = cell.zero_state(batch_size, dtype)
+ state_output = cell.get_initial_state(
+ inputs=None, batch_size=batch_size, dtype=dtype)
cell_output, _ = cell(array_ops.zeros(in_shape, dtype), state_output)
self.assertAllEqual([batch_size, out_size], cell_output.shape.as_list())
@@ -278,12 +320,228 @@ class RNNTest(test.TestCase):
self._assert_cell_builds(contrib_rnn.IndyLSTMCell, f32, 5, 7, 3)
self._assert_cell_builds(contrib_rnn.IndyLSTMCell, f64, 5, 7, 3)
+ def testRNNWithKerasSimpleRNNCell(self):
+ with self.cached_session() as sess:
+ input_shape = 10
+ output_shape = 5
+ timestep = 4
+ batch = 100
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=batch,
+ test_samples=0,
+ input_shape=(timestep, input_shape),
+ num_classes=output_shape)
+ y_train = keras.utils.to_categorical(y_train)
+ cell = keras.layers.SimpleRNNCell(output_shape)
+
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ predict = array_ops.placeholder(
+ dtypes.float32, shape=(None, output_shape))
+
+ outputs, state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
+ self.assertEqual(state.shape.as_list(), [None, output_shape])
+ loss = losses.softmax_cross_entropy(predict, state)
+ train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
+
+ sess.run([variables_lib.global_variables_initializer()])
+ _, outputs, state = sess.run(
+ [train_op, outputs, state], {inputs: x_train, predict: y_train})
+
+ self.assertEqual(len(outputs), batch)
+ self.assertEqual(len(state), batch)
+
+ def testRNNWithKerasGRUCell(self):
+ with self.cached_session() as sess:
+ input_shape = 10
+ output_shape = 5
+ timestep = 4
+ batch = 100
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=batch,
+ test_samples=0,
+ input_shape=(timestep, input_shape),
+ num_classes=output_shape)
+ y_train = keras.utils.to_categorical(y_train)
+ cell = keras.layers.GRUCell(output_shape)
+
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ predict = array_ops.placeholder(
+ dtypes.float32, shape=(None, output_shape))
+
+ outputs, state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
+ self.assertEqual(state.shape.as_list(), [None, output_shape])
+ loss = losses.softmax_cross_entropy(predict, state)
+ train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
+
+ sess.run([variables_lib.global_variables_initializer()])
+ _, outputs, state = sess.run(
+ [train_op, outputs, state], {inputs: x_train, predict: y_train})
+
+ self.assertEqual(len(outputs), batch)
+ self.assertEqual(len(state), batch)
+
+ def testRNNWithKerasLSTMCell(self):
+ with self.cached_session() as sess:
+ input_shape = 10
+ output_shape = 5
+ timestep = 4
+ batch = 100
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=batch,
+ test_samples=0,
+ input_shape=(timestep, input_shape),
+ num_classes=output_shape)
+ y_train = keras.utils.to_categorical(y_train)
+ cell = keras.layers.LSTMCell(output_shape)
+
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ predict = array_ops.placeholder(
+ dtypes.float32, shape=(None, output_shape))
+
+ outputs, state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
+ self.assertEqual(len(state), 2)
+ self.assertEqual(state[0].shape.as_list(), [None, output_shape])
+ self.assertEqual(state[1].shape.as_list(), [None, output_shape])
+ loss = losses.softmax_cross_entropy(predict, state[0])
+ train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
+
+ sess.run([variables_lib.global_variables_initializer()])
+ _, outputs, state = sess.run(
+ [train_op, outputs, state], {inputs: x_train, predict: y_train})
+
+ self.assertEqual(len(outputs), batch)
+ self.assertEqual(len(state), 2)
+ self.assertEqual(len(state[0]), batch)
+ self.assertEqual(len(state[1]), batch)
+
+ def testRNNWithStackKerasCell(self):
+ with self.cached_session() as sess:
+ input_shape = 10
+ output_shape = 5
+ timestep = 4
+ batch = 100
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=batch,
+ test_samples=0,
+ input_shape=(timestep, input_shape),
+ num_classes=output_shape)
+ y_train = keras.utils.to_categorical(y_train)
+ cell = keras.layers.StackedRNNCells(
+ [keras.layers.LSTMCell(2 * output_shape),
+ keras.layers.LSTMCell(output_shape)])
+
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ predict = array_ops.placeholder(
+ dtypes.float32, shape=(None, output_shape))
+
+ outputs, state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
+ self.assertEqual(len(state), 4)
+ self.assertEqual(state[0].shape.as_list(), [None, 2 * output_shape])
+ self.assertEqual(state[1].shape.as_list(), [None, 2 * output_shape])
+ self.assertEqual(state[2].shape.as_list(), [None, output_shape])
+ self.assertEqual(state[3].shape.as_list(), [None, output_shape])
+ loss = losses.softmax_cross_entropy(predict, state[2])
+ train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
+
+ sess.run([variables_lib.global_variables_initializer()])
+ _, outputs, state = sess.run(
+ [train_op, outputs, state], {inputs: x_train, predict: y_train})
+
+ self.assertEqual(len(outputs), batch)
+ self.assertEqual(len(state), 4)
+ for s in state:
+ self.assertEqual(len(s), batch)
+
+ def testStaticRNNWithKerasSimpleRNNCell(self):
+ with self.cached_session() as sess:
+ input_shape = 10
+ output_shape = 5
+ timestep = 4
+ batch = 100
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=batch,
+ test_samples=0,
+ input_shape=(timestep, input_shape),
+ num_classes=output_shape)
+ x_train = np.transpose(x_train, (1, 0, 2))
+ y_train = keras.utils.to_categorical(y_train)
+ cell = keras.layers.SimpleRNNCell(output_shape)
+
+ inputs = [array_ops.placeholder(
+ dtypes.float32, shape=(None, input_shape))] * timestep
+ predict = array_ops.placeholder(
+ dtypes.float32, shape=(None, output_shape))
+
+ outputs, state = rnn.static_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ self.assertEqual(len(outputs), timestep)
+ self.assertEqual(outputs[0].shape.as_list(), [None, output_shape])
+ self.assertEqual(state.shape.as_list(), [None, output_shape])
+ loss = losses.softmax_cross_entropy(predict, state)
+ train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
+
+ sess.run([variables_lib.global_variables_initializer()])
+ feed_dict = {i: d for i, d in zip(inputs, x_train)}
+ feed_dict[predict] = y_train
+ _, outputs, state = sess.run(
+ [train_op, outputs, state], feed_dict)
+
+ self.assertEqual(len(outputs), timestep)
+ self.assertEqual(len(outputs[0]), batch)
+ self.assertEqual(len(state), batch)
+
+ def testKerasAndTFRNNLayerOutputComparison(self):
+ input_shape = 10
+ output_shape = 5
+ timestep = 4
+ batch = 20
+ (x_train, _), _ = testing_utils.get_test_data(
+ train_samples=batch,
+ test_samples=0,
+ input_shape=(timestep, input_shape),
+ num_classes=output_shape)
+ fix_weights_generator = keras.layers.SimpleRNNCell(output_shape)
+ fix_weights_generator.build((None, input_shape))
+ weights = fix_weights_generator.get_weights()
+
+ with self.test_session(graph=ops_lib.Graph()) as sess:
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ cell = keras.layers.SimpleRNNCell(output_shape)
+ tf_out, tf_state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ cell.set_weights(weights)
+ [tf_out, tf_state] = sess.run([tf_out, tf_state], {inputs: x_train})
+ with self.test_session(graph=ops_lib.Graph()) as sess:
+ k_input = keras.Input(shape=(timestep, input_shape),
+ dtype=dtypes.float32)
+ cell = keras.layers.SimpleRNNCell(output_shape)
+ layer = keras.layers.RNN(cell, return_sequences=True, return_state=True)
+ keras_out = layer(k_input)
+ cell.set_weights(weights)
+ k_out, k_state = sess.run(keras_out, {k_input: x_train})
+ self.assertAllClose(tf_out, k_out)
+ self.assertAllClose(tf_state, k_state)
+
def testBasicLSTMCellInterchangeWithLSTMCell(self):
with self.test_session(graph=ops_lib.Graph()) as sess:
basic_cell = rnn_cell_impl.BasicLSTMCell(1)
basic_cell(array_ops.ones([1, 1]),
- state=basic_cell.zero_state(batch_size=1,
- dtype=dtypes.float32))
+ state=basic_cell.get_initial_state(inputs=None,
+ batch_size=1,
+ dtype=dtypes.float32))
self.evaluate([v.initializer for v in basic_cell.variables])
self.evaluate(basic_cell._bias.assign([10.] * 4))
save = saver.Saver()
@@ -293,8 +551,9 @@ class RNNTest(test.TestCase):
with self.test_session(graph=ops_lib.Graph()) as sess:
lstm_cell = rnn_cell_impl.LSTMCell(1, name="basic_lstm_cell")
lstm_cell(array_ops.ones([1, 1]),
- state=lstm_cell.zero_state(batch_size=1,
- dtype=dtypes.float32))
+ state=lstm_cell.get_initial_state(inputs=None,
+ batch_size=1,
+ dtype=dtypes.float32))
self.evaluate([v.initializer for v in lstm_cell.variables])
save = saver.Saver()
save.restore(sess, save_path)
@@ -308,7 +567,7 @@ class RNNTest(test.TestCase):
rnn_cell_impl.GRUCell(
32, kernel_initializer="ones", dtype=dtypes.float32)
]:
- with self.test_session():
+ with self.cached_session():
x = keras.Input((None, 5))
layer = keras.layers.RNN(cell)
y = layer(x)
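
A recurring change above is zero_state(batch_size, dtype) giving way to get_initial_state(inputs, batch_size, dtype), which can infer both batch size and dtype from an input tensor. A minimal sketch against the internal rnn_cell_impl module, assuming a build with the get_initial_state API these tests exercise (graph mode assumed):

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import rnn_cell_impl

    cell = rnn_cell_impl.BasicRNNCell(5)
    inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 1))
    # batch_size and dtype are read off `inputs`; the state shape is [None, 5].
    state = cell.get_initial_state(inputs=inputs, batch_size=None, dtype=None)
    assert state.shape.as_list() == [None, 5]
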
diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py
index 4a1fc1d9a9..40d384c623 100644
--- a/tensorflow/python/kernel_tests/slice_op_test.py
+++ b/tensorflow/python/kernel_tests/slice_op_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -260,6 +261,21 @@ class SliceTest(test.TestCase):
grad_actual = gradients_impl.gradients(out, inp)[0].eval()
self.assertAllClose([0., 1., 1.], grad_actual)
+ def _testGradientVariableSize2D(self):
+ # Regression test: a low-level bug in Eigen caused incorrect results for
+ # negative indices in multi-dimensional tensors. See b/114318298.
+ with self.test_session(use_gpu=True) as sess:
+ x = constant_op.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 7]])
+ loss1 = math_ops.reduce_sum(x[:-1, :-1] * 1.0)
+ loss2 = math_ops.reduce_sum(x[:-1][:, :-1])
+
+ g1 = gradients_impl.gradients(loss1, x)[0]
+ g2 = gradients_impl.gradients(loss2, x)[0]
+
+ g1_val, g2_val = sess.run([g1, g2])
+ self.assertAllEqual(g1_val, g2_val)
+
def testGradientsAll(self):
# Slice the middle square out of a 4x4 input
self._testGradientSlice([4, 4], [1, 1], [2, 2])
@@ -276,6 +292,9 @@ class SliceTest(test.TestCase):
# Use -1 as a slice dimension.
self._testGradientVariableSize()
+ # Use -1 as a slice dimension on a 2D tensor.
+ self._testGradientVariableSize2D()
+
def testNotIterable(self):
# NOTE(mrry): If we register __getitem__ as an overloaded
# operator, Python will valiantly attempt to iterate over the
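
The added `_testGradientVariableSize2D` pins down that two equivalent negative-index slicings must produce identical gradients. The same check as a standalone snippet (public TF 1.x API, outside the test harness):

    # Both slicing forms drop the last row and column, so their gradients
    # w.r.t. x must agree: ones on the 2x2 top-left block, zeros elsewhere.
    import tensorflow as tf

    x = tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
    loss1 = tf.reduce_sum(x[:-1, :-1])     # single negative-index slice
    loss2 = tf.reduce_sum(x[:-1][:, :-1])  # same region via chained slices
    g1 = tf.gradients(loss1, x)[0]
    g2 = tf.gradients(loss2, x)[0]
    with tf.Session() as sess:
        g1_val, g2_val = sess.run([g1, g2])
        assert (g1_val == g2_val).all()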
diff --git a/tensorflow/python/kernel_tests/softplus_op_test.py b/tensorflow/python/kernel_tests/softplus_op_test.py
index b8e7c50a37..c0269db9ae 100644
--- a/tensorflow/python/kernel_tests/softplus_op_test.py
+++ b/tensorflow/python/kernel_tests/softplus_op_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import errors
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops
@@ -121,9 +122,12 @@ class SoftplusTest(test.TestCase):
print("softplus (float) third-order gradient err = ", err)
self.assertLess(err, 5e-5)
- def testWarnInts(self):
- # Running the op triggers address sanitizer errors, so we just make it
- nn_ops.softplus(constant_op.constant(7))
+ def testNoInts(self):
+ with self.test_session():
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "No OpKernel was registered to support Op 'Softplus'"):
+ nn_ops.softplus(constant_op.constant(7)).eval()
if __name__ == "__main__":
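
The rewritten test asserts that integer inputs are now rejected at kernel dispatch rather than merely warned about; the softsign change below is analogous. A sketch against the public API, assuming (as the test does) a build where the op's type attr still admits integers but no integer kernel is registered:

    # Graph construction succeeds, but running fails at kernel lookup.
    import tensorflow as tf

    y = tf.nn.softplus(tf.constant(7))  # int32 input: no integer kernel
    with tf.Session() as sess:
        try:
            sess.run(y)
        except tf.errors.InvalidArgumentError as e:
            print(e)  # "No OpKernel was registered to support Op 'Softplus' ..."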
diff --git a/tensorflow/python/kernel_tests/softsign_op_test.py b/tensorflow/python/kernel_tests/softsign_op_test.py
index 371f86ff15..a5247ce08d 100644
--- a/tensorflow/python/kernel_tests/softsign_op_test.py
+++ b/tensorflow/python/kernel_tests/softsign_op_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import errors
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
@@ -65,11 +66,12 @@ class SoftsignTest(test.TestCase):
print("softsign (float) gradient err = ", err)
self.assertLess(err, 1e-4)
- def testWarnInts(self):
- # NOTE(irving): Actually I don't know how to intercept the warning, but
- # let's make sure it runs. I promised I've looked, and there was a warning.
+ def testNoInts(self):
with self.test_session():
- nn_ops.softsign(constant_op.constant(7)).eval()
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "No OpKernel was registered to support Op 'Softsign'"):
+ nn_ops.softsign(constant_op.constant(7)).eval()
if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
index d749843410..3bb5e899fe 100644
--- a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
+++ b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
@@ -61,14 +61,22 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q")
self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
- self.assertProtoEquals("""
+ self.assertProtoEquals(
+ """
name:'Q' op:'SparseConditionalAccumulator'
attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'shape' value { shape { unknown_rank: true} } }
attr { key: 'container' value { s: '' } }
attr { key: 'shared_name' value { s: '' } }
+ attr { key: 'reduction_type' value {s: 'MEAN'} }
""", q.accumulator_ref.op.node_def)
+ def testConstructorWithInvalidArg(self):
+ with ops.Graph().as_default():
+ with self.assertRaises(ValueError):
+ data_flow_ops.SparseConditionalAccumulator(
+ dtypes_lib.float32, name="Q", reduction_type="Invalid")
+
def testConstructorWithShape(self):
with ops.Graph().as_default():
q = data_flow_ops.SparseConditionalAccumulator(
@@ -76,7 +84,8 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
name="Q",
shape=tensor_shape.TensorShape([1, 5, 2, 8]))
self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
- self.assertProtoEquals("""
+ self.assertProtoEquals(
+ """
name:'Q' op:'SparseConditionalAccumulator'
attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'shape' value { shape { dim {size: 1 }
@@ -86,6 +95,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
} } }
attr { key: 'container' value { s: '' } }
attr { key: 'shared_name' value { s: '' } }
+ attr { key: 'reduction_type' value {s: 'MEAN'} }
""", q.accumulator_ref.op.node_def)
def testAccumulatorSizeEmpty(self):
@@ -164,7 +174,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
result = sess.run(accums[i].take_indexed_slices_grad(1))
self._assertEqual_indexedslices(expected_tensors[i], result)
- def testAccumulatorTakeGrad(self):
+ def testAccumulatorTakeGradMean(self):
with self.test_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=())
@@ -180,9 +190,34 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
takeg_t = q.take_indexed_slices_grad(1)
val = sess.run(takeg_t)
- self.assertAllEqual(val.indices, [0, 1, 2])
- self.assertAllEqual(val.values, [[0.5, 0.5], [0, 2], [3, 0]])
- self.assertAllEqual(val.dense_shape, [-1, 2])
+ self.assertAllEqual([0, 1, 2], val.indices)
+ self.assertAllEqual([[0.5, 0.5], [0, 2], [3, 0]], val.values)
+ self.assertAllEqual([-1, 2], val.dense_shape)
+
+ def testAccumulatorTakeGradSum(self):
+ with self.test_session() as sess:
+ q = data_flow_ops.SparseConditionalAccumulator(
+ dtypes_lib.float32, name="Q", shape=(), reduction_type="SUM")
+
+ grad_indexed_slices = ops.IndexedSlices(
+ indices=[0, 1], values=np.array([[1, 0], [0, 2]]).astype(np.float32))
+ accum_op = q.apply_indexed_slices_grad(grad_indexed_slices)
+ accum_op.run()
+ accum_op = q.apply_grad([0, 2],
+ np.array([[0, 1], [3, 0]]).astype(np.float32),
+ [3, 2])
+ accum_op.run()
+
+ takeg_t = q.take_indexed_slices_grad(1)
+ val = sess.run(takeg_t)
+ self.assertAllEqual([0, 1, 2], val.indices)
+ self.assertAllEqual([[1, 1], [0, 2], [3, 0]], val.values)
+ self.assertAllEqual([-1, 2], val.dense_shape)
+
+ def testAccumulatorTakeGradInvalidReductionType(self):
+ with self.assertRaises(ValueError):
+ data_flow_ops.SparseConditionalAccumulator(
+ dtypes_lib.float32, name="Q", shape=(), reduction_type="Invalid")
def testAccumulatorRepeatedTakeGrad(self):
with self.test_session() as sess:
@@ -222,7 +257,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
self.assertAllEqual(val.values, [[5, 5], [0, 20], [30, 0]])
self.assertAllEqual(val.dense_shape, [-1, 2])
- def testParallelApplyGrad(self):
+ def testParallelApplyGradMean(self):
with self.test_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
@@ -253,6 +288,40 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
np.array([[expected_val, 0], [0, expected_val]]).astype(np.float32),
val, sess)
+ def testParallelApplyGradSum(self):
+ with self.test_session() as sess:
+ q = data_flow_ops.SparseConditionalAccumulator(
+ dtypes_lib.float32,
+ name="Q",
+ shape=tensor_shape.TensorShape([2, 2]),
+ reduction_type="SUM")
+ elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+ accum_ops = []
+ for x in elems:
+ x = _indexedslice(np.array([[x, 0], [0, x]]).astype(np.float32))
+ accum_ops.append(q.apply_indexed_slices_grad(x, local_step=0))
+ takeg_t = q.take_indexed_slices_grad(1)
+
+ def apply_indexed_slices_grad(accum_op):
+ sess.run(accum_op)
+
+ threads = [
+ self.checkedThread(target=apply_indexed_slices_grad, args=(o,))
+ for o in accum_ops
+ ]
+
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+
+ val = sess.run(takeg_t)
+
+ expected_val = 550.0
+ self._assertEqual_nparray(
+ np.array([[expected_val, 0], [0, expected_val]]).astype(np.float32),
+ val, sess)
+
def testParallelTakeGrad(self):
with self.test_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
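
The accumulator hunks above add a `reduction_type` constructor argument ('MEAN' remains the default, per the updated proto checks). A minimal sketch of the 'SUM' behavior using the public `tf.SparseConditionalAccumulator` alias:

    # With reduction_type="SUM" the accumulator returns summed gradients
    # instead of the mean over all applies.
    import numpy as np
    import tensorflow as tf

    q = tf.SparseConditionalAccumulator(
        tf.float32, shape=(), reduction_type="SUM")
    with tf.Session() as sess:
        sess.run(q.apply_grad(
            [0, 1], np.array([[1., 0.], [0., 2.]], np.float32), [3, 2]))
        sess.run(q.apply_grad(
            [0, 2], np.array([[0., 1.], [3., 0.]], np.float32), [3, 2]))
        val = sess.run(q.take_indexed_slices_grad(1))
        print(val.indices)  # [0 1 2]
        print(val.values)   # [[1. 1.] [0. 2.] [3. 0.]] -- summed, not averaged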
diff --git a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
index 4935ed6ca5..f50e39d6d5 100644
--- a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
@@ -157,7 +157,7 @@ class MatMulGradientTest(test.TestCase):
m, [3, 4],
x_init_value=b.eval(),
delta=delta))
- self.assertLess(err, delta / 2.)
+ self.assertLessEqual(err, delta / 2.)
def testGradientInput(self):
for tr_a in [True, False]:
diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py
index cb5a66312f..fc39de150e 100644
--- a/tensorflow/python/kernel_tests/sparse_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_ops_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@@ -205,6 +206,22 @@ class SparseMergeTest(test_util.TensorFlowTestCase):
output = sess.run(sp_output)
self._AssertResultsNotSorted(output, vocab_size)
+ def testShouldSetLastDimensionInDynamicShape(self):
+ with ops.Graph().as_default():
+ shape = constant_op.constant([2, 2], dtype=dtypes.int64)
+ dynamic_shape = array_ops.placeholder_with_default(shape, shape=[2])
+ ids = sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]],
+ values=[1, 3],
+ dense_shape=dynamic_shape)
+ values = sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]],
+ values=[0.4, 0.7],
+ dense_shape=dynamic_shape)
+ merged = sparse_ops.sparse_merge(
+ sp_ids=ids, sp_values=values, vocab_size=5)
+ self.assertEqual(5, merged.get_shape()[1])
+
class SparseMergeHighDimTest(test_util.TensorFlowTestCase):
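
The new `testShouldSetLastDimensionInDynamicShape` shows that `sparse_merge` now fixes the last static dimension from `vocab_size` even when the inputs carry a dynamic `dense_shape`. The same behavior via the public API:

    # Even with a placeholder-backed dense_shape, the merged tensor's last
    # static dimension is pinned to vocab_size.
    import tensorflow as tf

    shape = tf.constant([2, 2], dtype=tf.int64)
    dynamic_shape = tf.placeholder_with_default(shape, shape=[2])
    ids = tf.SparseTensor(indices=[[0, 0], [0, 1]], values=[1, 3],
                          dense_shape=dynamic_shape)
    vals = tf.SparseTensor(indices=[[0, 0], [0, 1]], values=[0.4, 0.7],
                           dense_shape=dynamic_shape)
    merged = tf.sparse_merge(sp_ids=ids, sp_values=vals, vocab_size=5)
    print(merged.get_shape())  # (?, 5) -- last dim statically known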
diff --git a/tensorflow/python/kernel_tests/stack_op_test.py b/tensorflow/python/kernel_tests/stack_op_test.py
index 2f27d1839b..2a33c594a4 100644
--- a/tensorflow/python/kernel_tests/stack_op_test.py
+++ b/tensorflow/python/kernel_tests/stack_op_test.py
@@ -277,6 +277,18 @@ class AutomaticStackingTest(test.TestCase):
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], dtype=dtypes.float64)
self.assertEqual(dtypes.float64, t_2.dtype)
+ t_3 = ops.convert_to_tensor(
+ [[0., 0., 0.],
+ constant_op.constant([0., 0., 0.], dtype=dtypes.float64), [0., 0., 0.]
+ ],
+ dtype=dtypes.float32)
+ self.assertEqual(dtypes.float32, t_3.dtype)
+
+ t_4 = ops.convert_to_tensor(
+ [constant_op.constant([0., 0., 0.], dtype=dtypes.float64)],
+ dtype=dtypes.float32)
+ self.assertEqual(dtypes.float32, t_4.dtype)
+
with self.assertRaises(TypeError):
ops.convert_to_tensor([
constant_op.constant(
@@ -284,17 +296,15 @@ class AutomaticStackingTest(test.TestCase):
[0., 0., 0.], dtype=dtypes.float64), [0., 0., 0.]
])
- with self.assertRaises(TypeError):
- ops.convert_to_tensor(
- [[0., 0., 0.], constant_op.constant(
- [0., 0., 0.], dtype=dtypes.float64), [0., 0., 0.]],
- dtype=dtypes.float32)
+ def testDtypeConversionWhenTensorDtypeMismatch(self):
+ t_0 = ops.convert_to_tensor([0., 0., 0.])
+ self.assertEqual(dtypes.float32, t_0.dtype)
- with self.assertRaises(TypeError):
- ops.convert_to_tensor(
- [constant_op.constant(
- [0., 0., 0.], dtype=dtypes.float64)],
- dtype=dtypes.float32)
+ t_1 = ops.convert_to_tensor([0, 0, 0])
+ self.assertEqual(dtypes.int32, t_1.dtype)
+
+ t_2 = ops.convert_to_tensor([t_0, t_0, t_1], dtype=dtypes.float64)
+ self.assertEqual(dtypes.float64, t_2.dtype)
def testPlaceholder(self):
with self.test_session(use_gpu=True):
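
The rewritten stacking tests reflect a behavior change: `convert_to_tensor` now casts nested tensor elements to the requested dtype instead of raising `TypeError` (see `_cast_nested_seqs_to_dtype` further down in this diff). A minimal sketch:

    import tensorflow as tf

    t64 = tf.constant([0., 0., 0.], dtype=tf.float64)
    t = tf.convert_to_tensor([[0., 0., 0.], t64, [0., 0., 0.]],
                             dtype=tf.float32)
    print(t.dtype)  # float32: the float64 element is cast, not rejected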
diff --git a/tensorflow/python/kernel_tests/string_split_op_test.py b/tensorflow/python/kernel_tests/string_split_op_test.py
index e20daccb28..b6a0f45adc 100644
--- a/tensorflow/python/kernel_tests/string_split_op_test.py
+++ b/tensorflow/python/kernel_tests/string_split_op_test.py
@@ -58,14 +58,28 @@ class StringSplitOpTest(test.TestCase):
self.assertAllEqual(shape, [3, 5])
def testStringSplitEmptyToken(self):
- strings = [" hello ", "", "world "]
+ strings = ["", " a", "b ", " c", " ", " d ", " e", "f ", " g ", " "]
with self.test_session() as sess:
tokens = string_ops.string_split(strings)
indices, values, shape = sess.run(tokens)
- self.assertAllEqual(indices, [[0, 0], [2, 0]])
- self.assertAllEqual(values, [b"hello", b"world"])
- self.assertAllEqual(shape, [3, 1])
+ self.assertAllEqual(
+ indices,
+ [[1, 0], [2, 0], [3, 0], [5, 0], [6, 0], [7, 0], [8, 0]])
+ self.assertAllEqual(values, [b"a", b"b", b"c", b"d", b"e", b"f", b"g"])
+ self.assertAllEqual(shape, [10, 1])
+
+ def testStringSplitOnSetEmptyToken(self):
+ strings = ["", " a", "b ", " c", " ", " d ", ". e", "f .", " .g. ", " ."]
+
+ with self.test_session() as sess:
+ tokens = string_ops.string_split(strings, delimiter=" .")
+ indices, values, shape = sess.run(tokens)
+ self.assertAllEqual(
+ indices,
+ [[1, 0], [2, 0], [3, 0], [5, 0], [6, 0], [7, 0], [8, 0]])
+ self.assertAllEqual(values, [b"a", b"b", b"c", b"d", b"e", b"f", b"g"])
+ self.assertAllEqual(shape, [10, 1])
def testStringSplitWithDelimiter(self):
strings = ["hello|world", "hello world"]
diff --git a/tensorflow/python/kernel_tests/template_test.py b/tensorflow/python/kernel_tests/template_test.py
index 0b3a396d6b..9dcdaa61ed 100644
--- a/tensorflow/python/kernel_tests/template_test.py
+++ b/tensorflow/python/kernel_tests/template_test.py
@@ -25,6 +25,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@@ -359,6 +360,23 @@ class TemplateTest(test.TestCase):
self.assertEqual(2, len(tmpl1._checkpoint_dependencies))
self.assertEqual("nested", tmpl1._checkpoint_dependencies[0].name)
self.assertEqual("nested_1", tmpl1._checkpoint_dependencies[1].name)
+ model = training.Model()
+ model.template = tmpl1
+ self.assertEqual(model.variables, [v1, v2])
+ self.assertEqual(model.trainable_variables, [v1, v2])
+ self.assertEqual(len(model.non_trainable_variables), 0)
+ model.templates = [tmpl2]
+ self.assertEqual(model.variables, [v1, v2, v5, v6])
+ self.assertEqual(model.trainable_variables, [v1, v2, v5, v6])
+ self.assertEqual(len(model.non_trainable_variables), 0)
+ # Make sure losses, layers, and updates aren't broken by having a Template
+ # in the mix, which does not expose any updates or losses.
+ self.assertEqual([], model.layers)
+ self.assertEqual([], model.updates)
+ self.assertEqual([], model.losses)
+ self.assertEqual([], model.templates.layers)
+ self.assertEqual([], model.templates.updates)
+ self.assertEqual([], model.templates.losses)
@test_util.run_in_graph_and_eager_modes
def test_nested_templates_with_defun(self):
diff --git a/tensorflow/python/kernel_tests/topk_op_test.py b/tensorflow/python/kernel_tests/topk_op_test.py
index fa7c6a0f8a..d5f0726106 100644
--- a/tensorflow/python/kernel_tests/topk_op_test.py
+++ b/tensorflow/python/kernel_tests/topk_op_test.py
@@ -76,7 +76,7 @@ class TopKTest(test.TestCase):
for result_index, src_index in np.ndenumerate(indices):
value = values[result_index]
expected_value = np_inputs[result_index[0], src_index]
- np.testing.utils.assert_almost_equal(value, expected_value)
+ np.testing.assert_almost_equal(value, expected_value)
# Check that if two elements are equal, the lower-index element appears
# first.
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index ae2a0ab29a..d57b79cb90 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -41,6 +41,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.util import compat
+from tensorflow.python.util import tf_inspect
class VariableScopeTest(test.TestCase):
@@ -335,7 +336,7 @@ class VariableScopeTest(test.TestCase):
# reuse=True is for now only supported when eager execution is disabled.
if not context.executing_eagerly():
v = variable_scope.get_variable("v",
- []) # "v" is alredy there, reused
+ []) # "v" is already there, reused
losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(3, len(losses)) # No new loss added.
@@ -389,6 +390,18 @@ class VariableScopeTest(test.TestCase):
sess.run(v0.initializer)
sess.run(add)
+ def testEnableResourceVariables(self):
+ old = variable_scope._DEFAULT_USE_RESOURCE
+ try:
+ variable_scope.enable_resource_variables()
+ self.assertTrue(isinstance(variables_lib.Variable(1.0),
+ resource_variable_ops.ResourceVariable))
+ variable_scope.disable_resource_variables()
+ self.assertFalse(isinstance(variables_lib.Variable(1.0),
+ resource_variable_ops.ResourceVariable))
+ finally:
+ variable_scope._DEFAULT_USE_RESOURCE = old
+
def testControlFlow(self):
with self.test_session() as sess:
v0 = variable_scope.get_variable(
@@ -983,6 +996,13 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(
variable_scope.get_local_variable("w", []).name, "outer/w:0")
+ def testSignatureGetVarVsGetLocalVar(self):
+ """get_{local,}variable() must take the same list of args."""
+ arg_names = tf_inspect.getargspec(variable_scope.get_variable)[0]
+ local_arg_names = tf_inspect.getargspec(
+ variable_scope.get_local_variable)[0]
+ self.assertEqual(arg_names, local_arg_names)
+
def testGetVarWithDevice(self):
g = ops.Graph()
varname_type = []
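
The new `testEnableResourceVariables` flips the process-wide default variable kind. A sketch using the public switch, assuming it is exported at the top level as `tf.enable_resource_variables` (the test itself reaches into `variable_scope` directly):

    import tensorflow as tf

    tf.enable_resource_variables()  # new tf.Variable()s become resource vars
    v = tf.Variable(1.0)
    print(type(v).__name__)         # ResourceVariable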
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index ab08865532..3ba880d7a1 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -262,11 +262,13 @@ class Layer(base_layer.Layer):
use_resource = (use_resource or
self._use_resource_variables or
scope.use_resource)
+ if initializer is None:
+ initializer = scope.initializer
variable = super(Layer, self).add_weight(
name,
shape,
dtype=dtypes.as_dtype(dtype),
- initializer=initializer or scope.initializer,
+ initializer=initializer,
trainable=trainable,
constraint=constraint,
partitioner=partitioner,
diff --git a/tensorflow/python/layers/convolutional_test.py b/tensorflow/python/layers/convolutional_test.py
index 625320b48b..d61d3b6dba 100644
--- a/tensorflow/python/layers/convolutional_test.py
+++ b/tensorflow/python/layers/convolutional_test.py
@@ -264,7 +264,7 @@ class ConvTest(test.TestCase):
self.assertEqual(len(variables.trainable_variables()), 2)
def testFunctionalConv2DInitializerFromScope(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'scope', initializer=init_ops.ones_initializer()):
height, width = 7, 9
@@ -647,7 +647,7 @@ class SeparableConv2DTest(test.TestCase):
self.assertEqual(len(variables.trainable_variables()), 3)
def testFunctionalConv2DInitializerFromScope(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'scope', initializer=init_ops.ones_initializer()):
height, width = 7, 9
@@ -882,7 +882,7 @@ class Conv2DTransposeTest(test.TestCase):
self.assertEqual(len(variables.trainable_variables()), 2)
def testFunctionalConv2DTransposeInitializerFromScope(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'scope', initializer=init_ops.ones_initializer()):
height, width = 7, 9
@@ -1061,7 +1061,7 @@ class Conv3DTransposeTest(test.TestCase):
self.assertEqual(len(variables.trainable_variables()), 2)
def testFunctionalConv3DTransposeInitializerFromScope(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'scope', initializer=init_ops.ones_initializer()):
depth, height, width = 5, 7, 9
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py
index 040c1cddc0..46009a30ac 100644
--- a/tensorflow/python/layers/core_test.py
+++ b/tensorflow/python/layers/core_test.py
@@ -60,7 +60,7 @@ class DenseTest(test.TestCase):
self.assertEqual(dense.name, 'dense_2')
def testVariableInput(self):
- with self.test_session():
+ with self.cached_session():
v = variable_scope.get_variable(
'X', initializer=init_ops.zeros_initializer(), shape=(1, 1))
x = core_layers.Dense(1)(v)
@@ -221,7 +221,7 @@ class DenseTest(test.TestCase):
self.assertListEqual(dense.losses, loss_keys)
def testFunctionalDense(self):
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((5, 3), seed=1)
outputs = core_layers.dense(
inputs, 2, activation=nn_ops.relu, name='my_dense')
@@ -240,7 +240,7 @@ class DenseTest(test.TestCase):
# TODO(alive): get this to work in eager mode.
def testFunctionalDenseTwiceReuse(self):
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((5, 3), seed=1)
core_layers.dense(inputs, 2, name='my_dense')
vars1 = variables.trainable_variables()
@@ -250,7 +250,7 @@ class DenseTest(test.TestCase):
# TODO(alive): get this to work in eager mode.
def testFunctionalDenseTwiceReuseFromScope(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('scope'):
inputs = random_ops.random_uniform((5, 3), seed=1)
core_layers.dense(inputs, 2, name='my_dense')
@@ -262,7 +262,8 @@ class DenseTest(test.TestCase):
def testFunctionalDenseInitializerFromScope(self):
with variable_scope.variable_scope(
- 'scope', initializer=init_ops.ones_initializer()), self.test_session():
+ 'scope',
+ initializer=init_ops.ones_initializer()), self.cached_session():
inputs = random_ops.random_uniform((5, 3), seed=1)
core_layers.dense(inputs, 2)
variables.global_variables_initializer().run()
@@ -305,7 +306,7 @@ class DenseTest(test.TestCase):
self.assertEqual(called[0], 2)
def testFunctionalDenseInScope(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('test'):
inputs = random_ops.random_uniform((5, 3), seed=1)
core_layers.dense(inputs, 2, name='my_dense')
@@ -391,7 +392,7 @@ class DropoutTest(test.TestCase):
self.assertAllClose(np.ones((5, 3)), np_output)
def testDynamicLearningPhase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dp = core_layers.Dropout(0.5, seed=1)
inputs = array_ops.ones((5, 5))
training = array_ops.placeholder(dtype='bool')
@@ -424,7 +425,7 @@ class DropoutTest(test.TestCase):
self.assertAllClose(np_output[:, 0, :], np_output[:, 1, :])
def testFunctionalDropout(self):
- with self.test_session():
+ with self.cached_session():
inputs = array_ops.ones((5, 5))
dropped = core_layers.dropout(inputs, 0.5, training=True, seed=1)
variables.global_variables_initializer().run()
@@ -435,7 +436,7 @@ class DropoutTest(test.TestCase):
self.assertAllClose(np.ones((5, 5)), np_output)
def testDynamicRate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
rate = array_ops.placeholder(dtype='float32', name='rate')
dp = core_layers.Dropout(rate, name='dropout')
inputs = array_ops.ones((5, 5))
@@ -450,7 +451,7 @@ class DropoutTest(test.TestCase):
class FlattenTest(test.TestCase):
def testCreateFlatten(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(shape=(None, 2, 3), dtype='float32')
y = core_layers.Flatten()(x)
np_output = sess.run(y, feed_dict={x: np.zeros((3, 2, 3))})
@@ -484,7 +485,7 @@ class FlattenTest(test.TestCase):
core_layers.Flatten()(x)
def testFlattenUnknownAxes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(shape=(5, None, None), dtype='float32')
y = core_layers.Flatten()(x)
np_output = sess.run(y, feed_dict={x: np.zeros((5, 2, 3))})
diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py
index e147f348b0..a72d147a0b 100644
--- a/tensorflow/python/layers/normalization_test.py
+++ b/tensorflow/python/layers/normalization_test.py
@@ -72,7 +72,7 @@ class BNTest(test.TestCase):
dtype=dtypes.float32):
ops.reset_default_graph()
graph = ops.get_default_graph()
- with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
+ with self.session(graph=graph, use_gpu=use_gpu) as sess:
image = array_ops.placeholder(dtype=dtype, shape=shape)
loss, train_op, saver = self._simple_model(image, is_fused, freeze_mode)
if restore:
@@ -94,7 +94,7 @@ class BNTest(test.TestCase):
dtype = image_val.dtype
ops.reset_default_graph()
graph = ops.get_default_graph()
- with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
+ with self.session(graph=graph, use_gpu=use_gpu) as sess:
image = array_ops.placeholder(dtype=dtype, shape=shape)
loss, _, saver = self._simple_model(image, is_fused, True)
saver.restore(sess, checkpoint_path)
@@ -319,7 +319,7 @@ class BNTest(test.TestCase):
training = array_ops.placeholder(dtype='bool')
outputs = bn.apply(inputs, training=training)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test training with placeholder learning phase.
sess.run(variables.global_variables_initializer())
@@ -361,7 +361,7 @@ class BNTest(test.TestCase):
training = array_ops.placeholder(dtype='bool')
outputs = bn.apply(inputs, training=training)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test training with placeholder learning phase.
sess.run(variables.global_variables_initializer())
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
@@ -442,7 +442,7 @@ class BNTest(test.TestCase):
training = array_ops.placeholder(dtype='bool')
outputs = bn.apply(inputs, training=training)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test training with placeholder learning phase.
sess.run(variables.global_variables_initializer())
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
@@ -482,7 +482,7 @@ class BNTest(test.TestCase):
training = array_ops.placeholder(dtype='bool')
outputs = bn.apply(inputs, training=training)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test training with placeholder learning phase.
sess.run(variables.global_variables_initializer())
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
@@ -522,7 +522,7 @@ class BNTest(test.TestCase):
training = array_ops.placeholder(dtype='bool')
outputs = bn.apply(inputs, training=training)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test training with placeholder learning phase.
sess.run(variables.global_variables_initializer())
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
@@ -563,7 +563,7 @@ class BNTest(test.TestCase):
training = array_ops.placeholder(dtype='bool')
outputs = bn.apply(inputs, training=training)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test training with placeholder learning phase.
sess.run(variables.global_variables_initializer())
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
@@ -603,7 +603,7 @@ class BNTest(test.TestCase):
training = array_ops.placeholder(dtype='bool')
outputs = bn.apply(inputs, training=training)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test training with placeholder learning phase.
sess.run(variables.global_variables_initializer())
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
@@ -644,7 +644,7 @@ class BNTest(test.TestCase):
outputs_training = bn.apply(inputs, training=True)
outputs_infer = bn.apply(inputs, training=False)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test training with placeholder learning phase.
sess.run(variables.global_variables_initializer())
np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
@@ -694,7 +694,7 @@ class BNTest(test.TestCase):
beta = all_vars['bn/beta:0']
gamma = all_vars['bn/gamma:0']
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test training with placeholder learning phase.
sess.run(variables.global_variables_initializer())
np_gamma, np_beta = sess.run([gamma, beta])
@@ -756,7 +756,7 @@ class BNTest(test.TestCase):
beta = all_vars['bn/beta:0']
gamma = all_vars['bn/gamma:0']
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test training with placeholder learning phase.
sess.run(variables.global_variables_initializer())
for _ in range(100):
@@ -1254,7 +1254,7 @@ class BNTest(test.TestCase):
training = array_ops.placeholder(dtype='bool')
outputs = bn.apply(inputs, training=training)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test training with placeholder learning phase.
sess.run(variables.global_variables_initializer())
@@ -1294,7 +1294,7 @@ class BNTest(test.TestCase):
training = array_ops.placeholder(dtype='bool')
outputs = bn.apply(inputs, training=training)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test training with placeholder learning phase.
sess.run(variables.global_variables_initializer())
diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc
index fc02d6de0e..6189503d8f 100644
--- a/tensorflow/python/lib/core/py_func.cc
+++ b/tensorflow/python/lib/core/py_func.cc
@@ -398,7 +398,7 @@ Status ConvertNdarrayToTensor(PyObject* obj, Tensor* ret) {
TF_RETURN_IF_ERROR(NumericNpDTypeToTfDType(PyArray_TYPE(input), &dtype));
CHECK(DataTypeCanUseMemcpy(dtype));
if (reinterpret_cast<intptr_t>(PyArray_DATA(input)) %
- EIGEN_MAX_ALIGN_BYTES !=
+ std::max(1, EIGEN_MAX_ALIGN_BYTES) !=
0) {
Tensor t(dtype, shape);
StringPiece p = t.tensor_data();
diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc
index 3b4f12ae31..269142a7c2 100644
--- a/tensorflow/python/lib/core/py_seq_tensor.cc
+++ b/tensorflow/python/lib/core/py_seq_tensor.cc
@@ -55,6 +55,10 @@ bool IsPyDouble(PyObject* obj) {
return PyIsInstance(obj, &PyDoubleArrType_Type); // NumPy double type.
}
+bool IsNumpyHalf(PyObject* obj) {
+ return PyIsInstance(obj, &PyHalfArrType_Type);
+}
+
bool IsPyFloat(PyObject* obj) {
return PyFloat_Check(obj) ||
PyIsInstance(obj, &PyFloatingArrType_Type); // NumPy float types
@@ -156,6 +160,8 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
}
} else if (IsPyDouble(obj)) {
*dtype = DT_DOUBLE;
+ } else if (IsNumpyHalf(obj)) {
+ *dtype = DT_HALF;
} else if (IsPyFloat(obj)) {
*dtype = DT_FLOAT;
} else if (PyBool_Check(obj) || PyIsInstance(obj, &PyBoolArrType_Type)) {
@@ -357,6 +363,17 @@ const char* ConvertOneFloat(PyObject* v, T* out) {
DEFINE_HELPER(ConvertDouble, double, DT_DOUBLE, ConvertOneFloat<double>);
DEFINE_HELPER(ConvertFloat, float, DT_FLOAT, ConvertOneFloat<float>);
+const char* ConvertOneNumpyHalf(PyObject* v, Eigen::half* out) {
+ // NOTE(nareshmodi): Is there a way to convert to C double without the
+ // intermediate Python double? This will help with ConvertOneFloat as well.
+ Safe_PyObjectPtr as_float = make_safe(PyNumber_Float(v));
+ double v_double = PyFloat_AS_DOUBLE(as_float.get());
+ *out = Eigen::half(v_double);
+
+ return nullptr;
+}
+DEFINE_HELPER(ConvertNumpyHalf, Eigen::half, DT_HALF, ConvertOneNumpyHalf);
+
// String support
const char* ConvertOneString(PyObject* v, string* out) {
@@ -452,6 +469,9 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) {
if (ConvertDouble(obj, shape, ret) == nullptr) return Status::OK();
break;
+ case DT_HALF:
+ RETURN_STRING_AS_STATUS(ConvertNumpyHalf(obj, shape, ret));
+
case DT_INT64:
if (ConvertInt64(obj, shape, ret) == nullptr) return Status::OK();
break;
@@ -489,8 +509,13 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) {
// final type.
RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret));
}
+
case DT_DOUBLE:
RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret));
+
+ case DT_HALF:
+ RETURN_STRING_AS_STATUS(ConvertNumpyHalf(obj, shape, ret));
+
case DT_INT64:
if (requested_dtype == DT_INVALID) {
const char* error = ConvertInt32(obj, shape, ret);
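
The `py_seq_tensor.cc` changes above teach shape/type inference and conversion about NumPy half floats. The Python-visible effect, sketched:

    # numpy float16 values now convert without an explicit dtype argument.
    import numpy as np
    import tensorflow as tf

    t = tf.convert_to_tensor([np.float16(1.5), np.float16(2.5)])
    print(t.dtype)  # <dtype: 'float16'> -- inferred as DT_HALF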
diff --git a/tensorflow/python/lib/core/py_util.h b/tensorflow/python/lib/core/py_util.h
index 44dfe7ba21..a9f39d3946 100644
--- a/tensorflow/python/lib/core/py_util.h
+++ b/tensorflow/python/lib/core/py_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PYTHON_LIB_CORE_UTIL_H_
-#define TENSORFLOW_PYTHON_LIB_CORE_UTIL_H_
+#ifndef TENSORFLOW_PYTHON_LIB_CORE_PY_UTIL_H_
+#define TENSORFLOW_PYTHON_LIB_CORE_PY_UTIL_H_
#include "tensorflow/core/platform/types.h"
@@ -24,4 +24,4 @@ namespace tensorflow {
string PyExceptionFetch();
} // end namespace tensorflow
-#endif // TENSORFLOW_PYTHON_LIB_CORE_UTIL_H_
+#endif // TENSORFLOW_PYTHON_LIB_CORE_PY_UTIL_H_
diff --git a/tensorflow/python/lib/io/file_io.i b/tensorflow/python/lib/io/file_io.i
index 891a7b0fd0..0aa08ea3d1 100644
--- a/tensorflow/python/lib/io/file_io.i
+++ b/tensorflow/python/lib/io/file_io.i
@@ -42,7 +42,7 @@ inline void FileExists(const string& filename, TF_Status* out_status) {
inline void FileExists(const tensorflow::StringPiece& filename,
TF_Status* out_status) {
tensorflow::Status status =
- tensorflow::Env::Default()->FileExists(filename.ToString());
+ tensorflow::Env::Default()->FileExists(string(filename));
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
}
diff --git a/tensorflow/python/lib/io/py_record_reader.cc b/tensorflow/python/lib/io/py_record_reader.cc
index 9500fc6a7c..07ce071845 100644
--- a/tensorflow/python/lib/io/py_record_reader.cc
+++ b/tensorflow/python/lib/io/py_record_reader.cc
@@ -30,6 +30,8 @@ namespace io {
PyRecordReader::PyRecordReader() {}
+// NOTE(sethtroisi): At this time PyRecordReader doesn't benefit from taking
+// RecordReaderOptions, if this changes the API can be updated at that time.
PyRecordReader* PyRecordReader::New(const string& filename, uint64 start_offset,
const string& compression_type_string,
TF_Status* out_status) {
diff --git a/tensorflow/python/lib/io/py_record_writer.cc b/tensorflow/python/lib/io/py_record_writer.cc
index 3c64813735..faf20df868 100644
--- a/tensorflow/python/lib/io/py_record_writer.cc
+++ b/tensorflow/python/lib/io/py_record_writer.cc
@@ -28,7 +28,7 @@ namespace io {
PyRecordWriter::PyRecordWriter() {}
PyRecordWriter* PyRecordWriter::New(const string& filename,
- const string& compression_type_string,
+ const io::RecordWriterOptions& options,
TF_Status* out_status) {
std::unique_ptr<WritableFile> file;
Status s = Env::Default()->NewWritableFile(filename, &file);
@@ -38,10 +38,6 @@ PyRecordWriter* PyRecordWriter::New(const string& filename,
}
PyRecordWriter* writer = new PyRecordWriter;
writer->file_ = std::move(file);
-
- RecordWriterOptions options =
- RecordWriterOptions::CreateRecordWriterOptions(compression_type_string);
-
writer->writer_.reset(new RecordWriter(writer->file_.get(), options));
return writer;
}
@@ -52,10 +48,17 @@ PyRecordWriter::~PyRecordWriter() {
file_.reset();
}
-bool PyRecordWriter::WriteRecord(tensorflow::StringPiece record) {
- if (writer_ == nullptr) return false;
+void PyRecordWriter::WriteRecord(tensorflow::StringPiece record,
+ TF_Status* out_status) {
+ if (writer_ == nullptr) {
+ TF_SetStatus(out_status, TF_FAILED_PRECONDITION,
+ "Writer not initialized or previously closed");
+ return;
+ }
Status s = writer_->WriteRecord(record);
- return s.ok();
+ if (!s.ok()) {
+ Set_TF_Status_from_Status(out_status, s);
+ }
}
void PyRecordWriter::Flush(TF_Status* out_status) {
diff --git a/tensorflow/python/lib/io/py_record_writer.h b/tensorflow/python/lib/io/py_record_writer.h
index 9d66c031d4..9b0792c6db 100644
--- a/tensorflow/python/lib/io/py_record_writer.h
+++ b/tensorflow/python/lib/io/py_record_writer.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -36,14 +37,12 @@ class RecordWriter;
// by multiple threads.
class PyRecordWriter {
public:
- // TODO(vrv): make this take a shared proto to configure
- // the compression options.
static PyRecordWriter* New(const string& filename,
- const string& compression_type_string,
+ const io::RecordWriterOptions& compression_options,
TF_Status* out_status);
~PyRecordWriter();
- bool WriteRecord(tensorflow::StringPiece record);
+ void WriteRecord(tensorflow::StringPiece record, TF_Status* out_status);
void Flush(TF_Status* out_status);
void Close(TF_Status* out_status);
diff --git a/tensorflow/python/lib/io/py_record_writer.i b/tensorflow/python/lib/io/py_record_writer.i
index 3181c9afce..b2c2bda5dd 100644
--- a/tensorflow/python/lib/io/py_record_writer.i
+++ b/tensorflow/python/lib/io/py_record_writer.i
@@ -18,6 +18,11 @@ limitations under the License.
%include "tensorflow/python/platform/base.i"
%include "tensorflow/python/lib/core/strings.i"
+// Define int8_t explicitly instead of including "stdint.i", since "stdint.h"
+// and "stdint.i" disagree on the definition of int64_t.
+typedef signed char int8;
+%{ typedef signed char int8; %}
+
%feature("except") tensorflow::io::PyRecordWriter::New {
// Let other threads run while we write
Py_BEGIN_ALLOW_THREADS
@@ -26,6 +31,7 @@ limitations under the License.
}
%newobject tensorflow::io::PyRecordWriter::New;
+%newobject tensorflow::io::RecordWriterOptions::CreateRecordWriterOptions;
%feature("except") tensorflow::io::PyRecordWriter::WriteRecord {
// Let other threads run while we write
@@ -35,6 +41,8 @@ limitations under the License.
}
%{
+#include "tensorflow/core/lib/io/record_writer.h"
+#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/python/lib/io/py_record_writer.h"
%}
@@ -48,7 +56,21 @@ limitations under the License.
%unignore tensorflow::io::PyRecordWriter::Flush;
%unignore tensorflow::io::PyRecordWriter::Close;
%unignore tensorflow::io::PyRecordWriter::New;
+%unignore tensorflow::io::ZlibCompressionOptions;
+%unignore tensorflow::io::ZlibCompressionOptions::flush_mode;
+%unignore tensorflow::io::ZlibCompressionOptions::input_buffer_size;
+%unignore tensorflow::io::ZlibCompressionOptions::output_buffer_size;
+%unignore tensorflow::io::ZlibCompressionOptions::window_bits;
+%unignore tensorflow::io::ZlibCompressionOptions::compression_level;
+%unignore tensorflow::io::ZlibCompressionOptions::compression_method;
+%unignore tensorflow::io::ZlibCompressionOptions::mem_level;
+%unignore tensorflow::io::ZlibCompressionOptions::compression_strategy;
+%unignore tensorflow::io::RecordWriterOptions;
+%unignore tensorflow::io::RecordWriterOptions::CreateRecordWriterOptions;
+%unignore tensorflow::io::RecordWriterOptions::zlib_options;
+%include "tensorflow/core/lib/io/record_writer.h"
+%include "tensorflow/core/lib/io/zlib_compression_options.h"
%include "tensorflow/python/lib/io/py_record_writer.h"
%unignoreall
diff --git a/tensorflow/python/lib/io/python_io.py b/tensorflow/python/lib/io/python_io.py
index aec12ab3ea..404423ce07 100644
--- a/tensorflow/python/lib/io/python_io.py
+++ b/tensorflow/python/lib/io/python_io.py
@@ -15,7 +15,7 @@
"""Python functions for directly manipulating TFRecord-formatted files.
-See the @{$python/python_io} guide.
+See the [Python IO](https://tensorflow.org/api_guides/python/python_io) guide.
"""
from __future__ import absolute_import
diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py
index 941d6cd67c..cce71a2bab 100644
--- a/tensorflow/python/lib/io/tf_record.py
+++ b/tensorflow/python/lib/io/tf_record.py
@@ -33,8 +33,6 @@ class TFRecordCompressionType(object):
GZIP = 2
-# NOTE(vrv): This will eventually be converted into a proto. to match
-# the interface used by the C++ RecordWriter.
@tf_export("python_io.TFRecordOptions")
class TFRecordOptions(object):
"""Options used for manipulating TFRecord files."""
@@ -44,14 +42,105 @@ class TFRecordOptions(object):
TFRecordCompressionType.NONE: ""
}
- def __init__(self, compression_type):
+ def __init__(self,
+ compression_type=None,
+ flush_mode=None,
+ input_buffer_size=None,
+ output_buffer_size=None,
+ window_bits=None,
+ compression_level=None,
+ compression_method=None,
+ mem_level=None,
+ compression_strategy=None):
+ # pylint: disable=line-too-long
+ """Creates a `TFRecordOptions` instance.
+
+ Options only affect TFRecordWriter when compression_type is not `None`.
+ Documentation, details, and defaults can be found in
+ [`zlib_compression_options.h`](https://www.tensorflow.org/code/tensorflow/core/lib/io/zlib_compression_options.h)
+ and in the [zlib manual](http://www.zlib.net/manual.html).
+ Leaving an option as `None` allows C++ to set a reasonable default.
+
+ Args:
+ compression_type: `TFRecordCompressionType` or `None`.
+ flush_mode: flush mode or `None`. Default: Z_NO_FLUSH.
+ input_buffer_size: int or `None`.
+ output_buffer_size: int or `None`.
+ window_bits: int or `None`.
+ compression_level: 0 to 9, or `None`.
+ compression_method: compression method or `None`.
+ mem_level: 1 to 9, or `None`.
+ compression_strategy: strategy or `None`. Default: Z_DEFAULT_STRATEGY.
+
+ Returns:
+ A `TFRecordOptions` object.
+
+ Raises:
+ ValueError: If compression_type is invalid.
+ """
+ # pylint: enable=line-too-long
+ # Check compression_type is valid, but for backwards compatibility don't
+ # immediately convert to a string.
+ self.get_compression_type_string(compression_type)
self.compression_type = compression_type
+ self.flush_mode = flush_mode
+ self.input_buffer_size = input_buffer_size
+ self.output_buffer_size = output_buffer_size
+ self.window_bits = window_bits
+ self.compression_level = compression_level
+ self.compression_method = compression_method
+ self.mem_level = mem_level
+ self.compression_strategy = compression_strategy
@classmethod
def get_compression_type_string(cls, options):
+ """Convert various option types to a unified string.
+
+ Args:
+ options: `TFRecordOptions`, `TFRecordCompressionType`, or string.
+
+ Returns:
+ Compression type as string (e.g. `'ZLIB'`, `'GZIP'`, or `''`).
+
+ Raises:
+ ValueError: If compression_type is invalid.
+ """
if not options:
return ""
- return cls.compression_type_map[options.compression_type]
+ elif isinstance(options, TFRecordOptions):
+ return cls.get_compression_type_string(options.compression_type)
+ elif isinstance(options, TFRecordCompressionType):
+ return cls.compression_type_map[options]
+ elif options in TFRecordOptions.compression_type_map:
+ return cls.compression_type_map[options]
+ elif options in TFRecordOptions.compression_type_map.values():
+ return options
+ else:
+ raise ValueError('Not a valid compression_type: "{}"'.format(options))
+
+ def _as_record_writer_options(self):
+ """Convert to RecordWriterOptions for use with PyRecordWriter."""
+ options = pywrap_tensorflow.RecordWriterOptions_CreateRecordWriterOptions(
+ compat.as_bytes(
+ self.get_compression_type_string(self.compression_type)))
+
+ if self.flush_mode is not None:
+ options.zlib_options.flush_mode = self.flush_mode
+ if self.input_buffer_size is not None:
+ options.zlib_options.input_buffer_size = self.input_buffer_size
+ if self.output_buffer_size is not None:
+ options.zlib_options.output_buffer_size = self.output_buffer_size
+ if self.window_bits is not None:
+ options.zlib_options.window_bits = self.window_bits
+ if self.compression_level is not None:
+ options.zlib_options.compression_level = self.compression_level
+ if self.compression_method is not None:
+ options.zlib_options.compression_method = self.compression_method
+ if self.mem_level is not None:
+ options.zlib_options.mem_level = self.mem_level
+ if self.compression_strategy is not None:
+ options.zlib_options.compression_strategy = self.compression_strategy
+ return options
@tf_export("python_io.tf_record_iterator")
@@ -100,16 +189,21 @@ class TFRecordWriter(object):
Args:
path: The path to the TFRecords file.
- options: (optional) A TFRecordOptions object.
+ options: (optional) String specifying compression type,
+ `TFRecordCompressionType`, or `TFRecordOptions` object.
Raises:
IOError: If `path` cannot be opened for writing.
+ ValueError: If valid compression_type can't be determined from `options`.
"""
- compression_type = TFRecordOptions.get_compression_type_string(options)
+ if not isinstance(options, TFRecordOptions):
+ options = TFRecordOptions(compression_type=options)
with errors.raise_exception_on_not_ok_status() as status:
+ # pylint: disable=protected-access
self._writer = pywrap_tensorflow.PyRecordWriter_New(
- compat.as_bytes(path), compat.as_bytes(compression_type), status)
+ compat.as_bytes(path), options._as_record_writer_options(), status)
+ # pylint: enable=protected-access
def __enter__(self):
"""Enter a `with` block."""
@@ -125,8 +219,8 @@ class TFRecordWriter(object):
Args:
record: str
"""
- # TODO(sethtroisi): Failures are currently swallowed, change that.
- self._writer.WriteRecord(record)
+ with errors.raise_exception_on_not_ok_status() as status:
+ self._writer.WriteRecord(record, status)
def flush(self):
"""Flush the file."""
diff --git a/tensorflow/python/lib/io/tf_record_test.py b/tensorflow/python/lib/io/tf_record_test.py
index 4743c037ec..def8fe23e5 100644
--- a/tensorflow/python/lib/io/tf_record_test.py
+++ b/tensorflow/python/lib/io/tf_record_test.py
@@ -20,6 +20,8 @@ from __future__ import print_function
import gzip
import os
+import random
+import string
import zlib
import six
@@ -131,9 +133,6 @@ class TFCompressionTestCase(test.TestCase):
class TFRecordWriterTest(TFCompressionTestCase):
- def setUp(self):
- super(TFRecordWriterTest, self).setUp()
-
def _AssertFilesEqual(self, a, b, equal):
for an, bn in zip(a, b):
with open(an, "rb") as af, open(bn, "rb") as bf:
@@ -142,6 +141,37 @@ class TFRecordWriterTest(TFCompressionTestCase):
else:
self.assertNotEqual(af.read(), bf.read())
+ def _CompressionSizeDelta(self, records, options_a, options_b):
+ """Validate compression with options_a and options_b and return size delta.
+
+ Compress records with options_a and options_b. Uncompress both compressed
+ files and assert that the contents match the original records. Finally
+ compare the sizes of the two compressed files.
+
+ Args:
+ records: The records to compress
+ options_a: First set of options to compress with, the baseline for size.
+ options_b: Second set of options to compress with.
+
+ Returns:
+ The difference in file size when using options_a vs options_b. A negative
+ value means options_a compressed better (produced the smaller file); a
+ positive value means options_b compressed better.
+
+ """
+
+ fn_a = self._WriteRecordsToFile(records, "tfrecord_a", options=options_a)
+ test_a = list(tf_record.tf_record_iterator(fn_a, options=options_a))
+ self.assertEqual(records, test_a, options_a)
+
+ fn_b = self._WriteRecordsToFile(records, "tfrecord_b", options=options_b)
+ test_b = list(tf_record.tf_record_iterator(fn_b, options=options_b))
+ self.assertEqual(records, test_b, options_b)
+
+ # Negative number => options_a compressed better than options_b.
+ return os.path.getsize(fn_a) - os.path.getsize(fn_b)
+
def testWriteReadZLibFiles(self):
# Write uncompressed then compress manually.
options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE)
@@ -188,6 +218,76 @@ class TFRecordWriterTest(TFCompressionTestCase):
]
self._AssertFilesEqual(uncompressed_files, files, True)
+ def testNoCompressionType(self):
+ self.assertEqual(
+ "",
+ tf_record.TFRecordOptions.get_compression_type_string(
+ tf_record.TFRecordOptions()))
+
+ self.assertEqual(
+ "",
+ tf_record.TFRecordOptions.get_compression_type_string(
+ tf_record.TFRecordOptions("")))
+
+ with self.assertRaises(ValueError):
+ tf_record.TFRecordOptions(5)
+
+ with self.assertRaises(ValueError):
+ tf_record.TFRecordOptions("BZ2")
+
+ def testZlibCompressionType(self):
+ zlib_t = tf_record.TFRecordCompressionType.ZLIB
+
+ self.assertEqual(
+ "ZLIB",
+ tf_record.TFRecordOptions.get_compression_type_string(
+ tf_record.TFRecordOptions("ZLIB")))
+
+ self.assertEqual(
+ "ZLIB",
+ tf_record.TFRecordOptions.get_compression_type_string(
+ tf_record.TFRecordOptions(zlib_t)))
+
+ self.assertEqual(
+ "ZLIB",
+ tf_record.TFRecordOptions.get_compression_type_string(
+ tf_record.TFRecordOptions(tf_record.TFRecordOptions(zlib_t))))
+
+ def testCompressionOptions(self):
+ # Create records with a mix of random and repeated data to exercise
+ # compression.
+ rnd = random.Random(123)
+ random_record = compat.as_bytes(
+ "".join(rnd.choice(string.digits) for _ in range(10000)))
+ repeated_record = compat.as_bytes(_TEXT)
+ for _ in range(10000):
+ start_i = rnd.randint(0, len(_TEXT))
+ length = rnd.randint(10, 200)
+ repeated_record += _TEXT[start_i:start_i + length]
+ records = [random_record, repeated_record, random_record]
+
+ tests = [
+ ("compression_level", 2, -1), # Lower compression is worse.
+ ("compression_level", 6, 0), # Default compression_level is equal.
+ ("flush_mode", zlib.Z_FULL_FLUSH, 1), # A few less bytes.
+ ("flush_mode", zlib.Z_NO_FLUSH, 0), # NO_FLUSH is the default.
+ ("input_buffer_size", 4096, 0), # Increases time not size.
+ ("output_buffer_size", 4096, 0), # Increases time not size.
+ ("window_bits", 8, -1), # Smaller than default window increases size.
+ ("compression_strategy", zlib.Z_HUFFMAN_ONLY, -1), # Worse.
+ ("compression_strategy", zlib.Z_FILTERED, -1), # Worse.
+ ]
+
+ compression_type = tf_record.TFRecordCompressionType.ZLIB
+ options_a = tf_record.TFRecordOptions(compression_type)
+ for prop, value, delta_sign in tests:
+ options_b = tf_record.TFRecordOptions(
+ compression_type=compression_type, **{prop: value})
+ delta = self._CompressionSizeDelta(records, options_a, options_b)
+ self.assertTrue(
+ delta == 0 if delta_sign == 0 else delta // delta_sign > 0,
+ "Setting {} = {}, file was {} smaller didn't match sign of {}".format(
+ prop, value, delta, delta_sign))
+
class TFRecordWriterZlibTest(TFCompressionTestCase):
@@ -318,6 +418,7 @@ class TFRecordIteratorTest(TFCompressionTestCase):
for _ in tf_record.tf_record_iterator(fn_truncated):
pass
+
class TFRecordWriterCloseAndFlushTests(test.TestCase):
def setUp(self, compression_type=TFRecordCompressionType.NONE):
@@ -358,12 +459,12 @@ class TFRecordWriterCloseAndFlushTests(test.TestCase):
with self.assertRaises(errors_impl.FailedPreconditionError):
self._writer.flush()
- def testWriteAfterClose(self):
+ def testWriteAfterCloseIsError(self):
self._writer.write(self._Record(0))
self._writer.close()
- # TODO(sethtroisi): No way to know this failed, changed that.
- self._writer.write(self._Record(1))
+ with self.assertRaises(errors_impl.FailedPreconditionError):
+ self._writer.write(self._Record(1))
class TFRecordWriterCloseAndFlushGzipTests(TFRecordWriterCloseAndFlushTests):
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index a2b5f77f91..ade86e85bf 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from math import ceil
-
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -734,7 +732,6 @@ def _QuantizeAndDequantizeV3Grad(_, grad):
@ops.RegisterGradient("ExtractImagePatches")
def _ExtractImagePatchesGrad(op, grad):
-
batch_size, rows_in, cols_in, channels = [
dim.value for dim in op.inputs[0].get_shape()
]
@@ -742,28 +739,45 @@ def _ExtractImagePatchesGrad(op, grad):
batch_size = input_bhwc[0]
channels = input_bhwc[3]
+ # Create indices matrix for input tensor.
+ # Note that 0 is preserved for padding location,
+ # so valid input indices run from 1 to rows_in * cols_in (inclusive).
+ input_indices_num = 1 + rows_in * cols_in
+ input_idx = array_ops.reshape(math_ops.range(1, input_indices_num,
+ dtype=ops.dtypes.int64),
+ (1, rows_in, cols_in, 1))
+ input_idx_patched = gen_array_ops.extract_image_patches(
+ input_idx,
+ op.get_attr("ksizes"),
+ op.get_attr("strides"),
+ op.get_attr("rates"),
+ op.get_attr("padding"))
+
+ # Create indices matrix for output tensor.
_, rows_out, cols_out, _ = [dim.value for dim in op.outputs[0].get_shape()]
_, ksize_r, ksize_c, _ = op.get_attr("ksizes")
- _, stride_r, stride_h, _ = op.get_attr("strides")
- _, rate_r, rate_c, _ = op.get_attr("rates")
- padding = op.get_attr("padding")
-
- ksize_r_eff = ksize_r + (ksize_r - 1) * (rate_r - 1)
- ksize_c_eff = ksize_c + (ksize_c - 1) * (rate_c - 1)
-
- if padding == b"SAME":
- rows_out = int(ceil(rows_in / stride_r))
- cols_out = int(ceil(cols_in / stride_h))
- pad_rows = ((rows_out - 1) * stride_r + ksize_r_eff - rows_in) // 2
- pad_cols = ((cols_out - 1) * stride_h + ksize_c_eff - cols_in) // 2
-
- elif padding == b"VALID":
- rows_out = int(ceil((rows_in - ksize_r_eff + 1) / stride_r))
- cols_out = int(ceil((cols_in - ksize_c_eff + 1) / stride_h))
- pad_rows = (rows_out - 1) * stride_r + ksize_r_eff - rows_in
- pad_cols = (cols_out - 1) * stride_h + ksize_c_eff - cols_in
-
- pad_rows, pad_cols = max(0, pad_rows), max(0, pad_cols)
+ # Indices for output start from 0.
+ output_indices_num = rows_out * cols_out * ksize_r * ksize_c
+ output_idx = array_ops.reshape(math_ops.range(output_indices_num,
+ dtype=ops.dtypes.int64),
+ (1, rows_out, cols_out, ksize_r * ksize_c))
+
+ # Construct mapping table for indices: (input -> output).
+ idx_matrix = array_ops.concat(
+ [array_ops.expand_dims(input_idx_patched, axis=-1),
+ array_ops.expand_dims(output_idx, axis=-1)],
+ axis=-1)
+ idx_map = array_ops.reshape(idx_matrix, (-1, 2))
+
+ sp_shape = (input_indices_num, output_indices_num)
+ sp_mat_full = sparse_tensor.SparseTensor(
+ idx_map,
+ array_ops.ones([output_indices_num], dtype=grad.dtype),
+ sp_shape)
+ # Remove all padding locations [0, :].
+ sp_mat = sparse_ops.sparse_slice(sp_mat_full,
+ (1, 0),
+ (input_indices_num - 1, output_indices_num))
grad_expanded = array_ops.transpose(
array_ops.reshape(
@@ -771,27 +785,6 @@ def _ExtractImagePatchesGrad(op, grad):
(1, 2, 3, 4, 0, 5))
grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))
- row_steps = range(0, rows_out * stride_r, stride_r)
- col_steps = range(0, cols_out * stride_h, stride_h)
-
- idx = []
- for i in range(rows_out):
- for j in range(cols_out):
- r_low, c_low = row_steps[i] - pad_rows, col_steps[j] - pad_cols
- r_high, c_high = r_low + ksize_r_eff, c_low + ksize_c_eff
-
- idx.extend([(r * (cols_in) + c, i * (cols_out * ksize_r * ksize_c) + j *
- (ksize_r * ksize_c) + ri * (ksize_c) + ci)
- for (ri, r) in enumerate(range(r_low, r_high, rate_r))
- for (ci, c) in enumerate(range(c_low, c_high, rate_c))
- if 0 <= r and r < rows_in and 0 <= c and c < cols_in])
-
- sp_shape = (rows_in * cols_in, rows_out * cols_out * ksize_r * ksize_c)
-
- sp_mat = sparse_tensor.SparseTensor(
- array_ops.constant(idx, dtype=ops.dtypes.int64),
- array_ops.ones((len(idx),), dtype=grad.dtype), sp_shape)
-
jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)
grad_out = array_ops.reshape(jac, (rows_in, cols_in, batch_size, channels))
@@ -812,3 +805,22 @@ def _ScatterNdNonAliasingAddGrad(op, grad):
indices = op.inputs[1]
updates_grad = array_ops.gather_nd(grad, indices)
return [grad, None, updates_grad]
+
+
+@ops.RegisterGradient("BroadcastTo")
+def _BroadcastToGrad(op, grad):
+ input_value = op.inputs[0]
+ broadcast_shape = op.inputs[1]
+ # Assign ids for each position in input_value.
+ input_value_shape = array_ops.shape(input_value)
+ input_value_size = array_ops.size(input_value)
+ ids = array_ops.reshape(math_ops.range(input_value_size), input_value_shape)
+ broadcast_ids = array_ops.broadcast_to(ids, broadcast_shape)
+  # Group the gradient entries by id and sum each group.
+ grad_flatten = array_ops.reshape(grad, [-1])
+ broadcast_ids_flatten = array_ops.reshape(broadcast_ids, [-1])
+ updates_grad_flatten = math_ops.unsorted_segment_sum(grad_flatten,
+ broadcast_ids_flatten,
+ input_value_size)
+ updates_grad = array_ops.reshape(updates_grad_flatten, input_value_shape)
+ return [updates_grad, None]
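
The `BroadcastTo` gradient above tags every input position with an id, broadcasts the ids alongside the data, and then sums the upstream gradient within each id group. A minimal sketch of the same idea, assuming a TF 1.x runtime with `tf.broadcast_to` exported; values and shapes are illustrative:

```python
import tensorflow as tf

x = tf.constant([[1.], [2.]])             # shape (2, 1)
ids = tf.reshape(tf.range(tf.size(x)), tf.shape(x))
bcast_ids = tf.broadcast_to(ids, [2, 3])  # each copy keeps its source id
grad = tf.ones([2, 3])                    # upstream gradient
dx = tf.reshape(
    tf.unsorted_segment_sum(tf.reshape(grad, [-1]),
                            tf.reshape(bcast_ids, [-1]),
                            tf.size(x)),
    tf.shape(x))
with tf.Session() as sess:
  print(sess.run(dx))  # [[3.], [3.]] -- each input position got 3 copies
```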
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 4b096cb73d..c8b883350d 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -15,7 +15,7 @@
# Tests for this file live in python/kernel_tests/array_ops_test.py
"""Support for manipulating tensors.
-See the @{$python/array_ops} guide.
+See the [Array Ops](https://tensorflow.org/api_guides/python/array_ops) guide.
"""
from __future__ import absolute_import
@@ -43,6 +43,7 @@ from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops.gen_array_ops import *
from tensorflow.python.ops.gen_array_ops import reverse_v2 as reverse # pylint: disable=unused-import
from tensorflow.python.util import deprecation
+from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
@@ -691,30 +692,28 @@ def strided_slice(input_,
parent_name = name
- def assign(val, name=None):
- """Closure that holds all the arguments to create an assignment."""
+ if not (var is None and isinstance(op, ops.EagerTensor)):
+ def assign(val, name=None):
+ """Closure that holds all the arguments to create an assignment."""
- if var is None:
- raise ValueError("Sliced assignment is only supported for variables")
+ if var is None:
+ raise ValueError("Sliced assignment is only supported for variables")
- if name is None:
- name = parent_name + "_assign"
+ if name is None:
+ name = parent_name + "_assign"
- return var._strided_slice_assign(
- begin=begin,
- end=end,
- strides=strides,
- value=val,
- name=name,
- begin_mask=begin_mask,
- end_mask=end_mask,
- ellipsis_mask=ellipsis_mask,
- new_axis_mask=new_axis_mask,
- shrink_axis_mask=shrink_axis_mask)
+ return var._strided_slice_assign(
+ begin=begin,
+ end=end,
+ strides=strides,
+ value=val,
+ name=name,
+ begin_mask=begin_mask,
+ end_mask=end_mask,
+ ellipsis_mask=ellipsis_mask,
+ new_axis_mask=new_axis_mask,
+ shrink_axis_mask=shrink_axis_mask)
- if not context.executing_eagerly():
- # TODO(apassos) In eager mode assignment will be done by overriding
- # __setitem__ instead.
op.assign = assign
return op
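
For reference, the `assign` closure above is what backs sliced assignment on variables: slicing a variable yields a tensor whose `.assign` writes back through `_strided_slice_assign`. A minimal sketch of that path, assuming a TF 1.x graph-mode runtime:

```python
import tensorflow as tf

v = tf.Variable([1.0, 2.0, 3.0])
update = v[1:3].assign([20.0, 30.0])  # backed by the closure built above
with tf.Session() as sess:
  sess.run(v.initializer)
  sess.run(update)
  print(sess.run(v))  # [ 1. 20. 30.]
```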
@@ -947,6 +946,15 @@ def _get_dtype_from_nested_lists(list_or_tuple):
return None
+def _cast_nested_seqs_to_dtype(dtype):
+ def _maybe_cast(elem):
+ if ops.is_dense_tensor_like(elem):
+ if dtype != elem.dtype.base_dtype:
+ elem = gen_math_ops.cast(elem, dtype)
+ return elem
+ return _maybe_cast
+
+
def _autopacking_conversion_function(v, dtype=None, name=None, as_ref=False):
"""Tensor conversion function that automatically packs arguments."""
if as_ref:
@@ -956,9 +964,11 @@ def _autopacking_conversion_function(v, dtype=None, name=None, as_ref=False):
# We did not find any tensor-like objects in the nested lists, so defer to
# other conversion functions.
return NotImplemented
- if dtype is not None and dtype != inferred_dtype:
- return NotImplemented
- return _autopacking_helper(v, inferred_dtype, name or "packed")
+ if dtype is None:
+ dtype = inferred_dtype
+ elif dtype != inferred_dtype:
+ v = nest.map_structure(_cast_nested_seqs_to_dtype(dtype), v)
+ return _autopacking_helper(v, dtype, name or "packed")
# pylint: enable=invalid-name
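
With this change, an explicit `dtype` that disagrees with the inferred one no longer causes the autopacking converter to bail out (which previously made conversion fail on nested lists containing tensor leaves); instead the tensor leaves are cast via `nest.map_structure` before packing. A minimal sketch, assuming a TF 1.x runtime:

```python
import tensorflow as tf

t = tf.constant(1.0)  # float32 leaf inside a nested list
packed = tf.convert_to_tensor([[t, 2.0], [3.0, 4.0]], dtype=tf.float64)
print(packed.dtype)   # <dtype: 'float64'> -- the leaf was cast, not rejected
with tf.Session() as sess:
  print(sess.run(packed))
```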
@@ -1265,7 +1275,7 @@ unique_with_counts.__doc__ = gen_array_ops.unique_with_counts.__doc__
def split(value, num_or_size_splits, axis=0, num=None, name="split"):
"""Splits a tensor into sub tensors.
- If `num_or_size_splits` is an integer type, `num_split`, then splits `value`
+ If `num_or_size_splits` is an integer type, then `value` is split
along dimension `axis` into `num_split` smaller tensors.
Requires that `num_split` evenly divides `value.shape[axis]`.
@@ -1714,7 +1724,7 @@ def placeholder(dtype, shape=None, name=None):
@compatibility(eager)
Placeholders are not compatible with eager execution.
@end_compatibility
-
+
Args:
dtype: The type of elements in the tensor to be fed.
shape: The shape of the tensor to be fed (optional). If the shape is not
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index 375a5ec2c3..c3cf6e61f2 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -15,7 +15,8 @@
# pylint: disable=g-short-docstring-punctuation
"""Asserts and Boolean Checks.
-See the @{$python/check_ops} guide.
+See the [Asserts and
+checks](https://tensorflow.org/api_guides/python/check_ops) guide.
"""
from __future__ import absolute_import
@@ -29,6 +30,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -1242,3 +1244,57 @@ def assert_scalar(tensor, name=None):
raise ValueError('Expected scalar shape for %s, saw shape: %s.'
% (tensor.name, shape))
return tensor
+
+
+@tf_export('ensure_shape')
+def ensure_shape(x, shape, name=None):
+ """Updates the shape of a tensor and checks at runtime that the shape holds.
+
+ For example:
+ ```python
+ x = tf.placeholder(tf.int32)
+ print(x.shape)
+ ==> TensorShape(None)
+ y = x * 2
+ print(y.shape)
+ ==> TensorShape(None)
+
+ y = tf.ensure_shape(y, (None, 3, 3))
+ print(y.shape)
+ ==> TensorShape([Dimension(None), Dimension(3), Dimension(3)])
+
+ with tf.Session() as sess:
+ # Raises tf.errors.InvalidArgumentError, because the shape (3,) is not
+ # compatible with the shape (None, 3, 3)
+ sess.run(y, feed_dict={x: [1, 2, 3]})
+
+ ```
+
+ NOTE: This differs from `Tensor.set_shape` in that it sets the static shape
+ of the resulting tensor and enforces it at runtime, raising an error if the
+ tensor's runtime shape is incompatible with the specified shape.
+ `Tensor.set_shape` sets the static shape of the tensor without enforcing it
+ at runtime, which may result in inconsistencies between the statically-known
+ shape of tensors and the runtime value of tensors.
+
+ Args:
+ x: A `Tensor`.
+ shape: A `TensorShape` representing the shape of this tensor, a
+ `TensorShapeProto`, a list, a tuple, or None.
+ name: A name for this operation (optional). Defaults to "EnsureShape".
+
+ Returns:
+ A `Tensor`. Has the same type and contents as `x`. At runtime, raises a
+ `tf.errors.InvalidArgumentError` if `shape` is incompatible with the shape
+ of `x`.
+ """
+ if not isinstance(shape, tensor_shape.TensorShape):
+ shape = tensor_shape.TensorShape(shape)
+
+ return array_ops.ensure_shape(x, shape, name=name)
+
+
+@ops.RegisterGradient('EnsureShape')
+def _ensure_shape_grad(op, grad):
+ del op # Unused.
+ return grad
diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py
index 78b395a6c1..29468431b3 100644
--- a/tensorflow/python/ops/clip_ops.py
+++ b/tensorflow/python/ops/clip_ops.py
@@ -144,7 +144,11 @@ def clip_by_norm(t, clip_norm, axes=None, name=None):
t = ops.convert_to_tensor(t, name="t")
# Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
- l2norm = math_ops.sqrt(math_ops.reduce_sum(t * t, axes, keepdims=True))
+ l2sum = math_ops.reduce_sum(t * t, axes, keepdims=True)
+ pred = l2sum > 0
+ # Two-tap tf.where trick to bypass NaN gradients
+ l2sum_safe = array_ops.where(pred, l2sum, array_ops.ones_like(l2sum))
+ l2norm = array_ops.where(pred, math_ops.sqrt(l2sum_safe), l2sum)
intermediate = t * clip_norm
# Assert that the shape is compatible with the initial shape,
# to prevent unintentional broadcasting.
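
The two-tap `tf.where` above avoids the classic `sqrt(0)` pitfall: the gradient of `sqrt` is infinite at zero, and even with a single `tf.where` the unselected branch still backpropagates `0 * inf = NaN`. Substituting a safe input before the `sqrt` and selecting again afterwards keeps every gradient finite. A minimal sketch of the guarded computation, assuming a TF 1.x runtime:

```python
import tensorflow as tf

t = tf.zeros([3])  # the all-zero case that used to yield NaN gradients
l2sum = tf.reduce_sum(t * t, keepdims=True)
pred = l2sum > 0
l2sum_safe = tf.where(pred, l2sum, tf.ones_like(l2sum))  # sqrt never sees 0
l2norm = tf.where(pred, tf.sqrt(l2sum_safe), l2sum)
grad = tf.gradients(l2norm, t)[0]
with tf.Session() as sess:
  print(sess.run([l2norm, grad]))  # finite everywhere, no NaNs
```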
diff --git a/tensorflow/python/ops/clip_ops_test.py b/tensorflow/python/ops/clip_ops_test.py
index 7d8dc90491..444cd0f62c 100644
--- a/tensorflow/python/ops/clip_ops_test.py
+++ b/tensorflow/python/ops/clip_ops_test.py
@@ -30,7 +30,7 @@ class ClipOpsTest(test.TestCase):
super(ClipOpsTest, self).__init__(method_name)
def _testClipByNorm(self, inputs, max_norm, expected):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
input_op = constant_op.constant(inputs)
clipped = clip_ops.clip_by_norm(input_op, max_norm)
check_op = numerics.add_check_numerics_ops()
diff --git a/tensorflow/python/ops/collective_ops_test.py b/tensorflow/python/ops/collective_ops_test.py
index 9cc64ef9f6..78c4b4bfe0 100644
--- a/tensorflow/python/ops/collective_ops_test.py
+++ b/tensorflow/python/ops/collective_ops_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
class CollectiveOpTest(test.TestCase):
- def _testCollectiveReduce(self, t0, t1, expected):
+ def _testCollectiveReduce(self, t0, t1, expected, set_graph_key):
group_key = 1
instance_key = 1
with self.test_session(
@@ -43,7 +43,8 @@ class CollectiveOpTest(test.TestCase):
colred1 = collective_ops.all_reduce(in1, 2, group_key, instance_key,
'Add', 'Div')
run_options = config_pb2.RunOptions()
- run_options.experimental.collective_graph_key = 1
+ if set_graph_key:
+ run_options.experimental.collective_graph_key = 1
results = sess.run([colred0, colred1], options=run_options)
self.assertAllClose(results[0], expected, rtol=1e-5, atol=1e-5)
self.assertAllClose(results[1], expected, rtol=1e-5, atol=1e-5)
@@ -51,7 +52,15 @@ class CollectiveOpTest(test.TestCase):
def testCollectiveReduce(self):
self._testCollectiveReduce([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
[0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
- [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2])
+ [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2], True)
+
+ def testCollectiveAutoGraphKey(self):
+ self._testCollectiveReduce([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
+ [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
+ [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2], False)
+
+ def testCollectiveReduceScalar(self):
+ self._testCollectiveReduce(0.1, 0.3, 0.2, True)
def _testCollectiveBroadcast(self, t0):
group_key = 1
diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py
index 76173e0f30..75a1a53eb7 100644
--- a/tensorflow/python/ops/cond_v2.py
+++ b/tensorflow/python/ops/cond_v2.py
@@ -24,7 +24,7 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
-from tensorflow.python.framework import function
+from tensorflow.python.eager import function
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.ops import gradients_impl
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py
index b3dacff6d6..c6a6b2a7fa 100644
--- a/tensorflow/python/ops/cond_v2_impl.py
+++ b/tensorflow/python/ops/cond_v2_impl.py
@@ -27,14 +27,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
+
from tensorflow.core.framework import attr_value_pb2
-from tensorflow.python import pywrap_tensorflow as c_api
-from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_functional_ops
-from tensorflow.python.util import compat
# The following modules cannot be imported directly because they cause circular
@@ -57,46 +56,27 @@ def cond_v2(pred, true_fn, false_fn, name="cond"):
name = "cond"
with ops.name_scope(name) as scope:
- # Identify if there is a caller device, & get the innermost if possible.
- # pylint: disable=protected-access
- device_funcs = ops.get_default_graph()._device_functions_outer_to_inner
- caller_device = device_funcs[-1] if device_funcs else None
-
- caller_colocation_stack = ops.get_default_graph()._colocation_stack
- caller_container = ops.get_default_graph()._container
- caller_collection_ref = ops.get_default_graph()._collections
-
with ops.name_scope(None):
# Find the outer most graph for uniquing function names.
# TODO(jpienaar): Make this work in eager mode.
graph = ops.get_default_graph()
- while isinstance(graph, _function._FuncGraph):
- graph = graph._outer_graph
+ while isinstance(graph, _function.FuncGraph):
+ graph = graph.outer_graph
true_name = graph.unique_name(("%strue" % scope).replace("/", "_"))
false_name = graph.unique_name(("%sfalse" % scope).replace("/", "_"))
- # pylint: enable=protected-access
+
true_graph = _function.func_graph_from_py_func(
- true_fn, [], [],
- name=true_name,
- device=caller_device,
- colocation_stack=caller_colocation_stack,
- collections_ref=caller_collection_ref,
- container=caller_container)
+ true_name, true_fn, [], {})
false_graph = _function.func_graph_from_py_func(
- false_fn, [], [],
- name=false_name,
- device=caller_device,
- colocation_stack=caller_colocation_stack,
- collections_ref=caller_collection_ref,
- container=caller_container)
+ false_name, false_fn, [], {})
_check_same_outputs(true_graph, false_graph)
# Add inputs to true_graph and false_graph to make them match. Note that
# this modifies true_graph and false_graph.
cond_inputs = _make_inputs_match(true_graph, false_graph,
- true_graph.extra_inputs,
- false_graph.extra_inputs)
+ true_graph.external_captures,
+ false_graph.external_captures)
# Add all intermediate tensors as function outputs so they're available for
# the gradient computation.
@@ -148,8 +128,8 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
true_graph, false_graph = _get_func_graphs(op)
# Note: op.graph != ops.get_default_graph() when we are computing the gradient
# of a nested cond.
- assert true_graph._outer_graph == op.graph
- assert false_graph._outer_graph == op.graph
+ assert true_graph.outer_graph == op.graph
+ assert false_graph.outer_graph == op.graph
# Create grad functions that compute the gradient of the true/false forward
# graphs. These functions will capture tensors from the forward pass
@@ -164,14 +144,13 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
# Resolve references to forward graph tensors in grad graphs and ensure
# they are in-scope, i.e., belong to one of outer graphs of the grad graph.
- true_grad_extra_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
- false_grad_extra_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)
+ true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
+ false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)
# Make the inputs to true_grad_graph and false_grad_graph match. Note that
# this modifies true_grad_graph and false_grad_graph.
grad_inputs = _make_inputs_match(true_grad_graph, false_grad_graph,
- true_grad_extra_inputs,
- false_grad_extra_inputs)
+ true_grad_inputs, false_grad_inputs)
# Add all intermediate tensors as function outputs so they're available for
# higher-order gradient computations.
@@ -201,32 +180,31 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
def _get_func_graphs(if_op):
- """Returns `_FuncGraph`s for the input op branches.
+ """Returns `FuncGraph`s for the input op branches.
Args:
if_op: The _If Operation.
Returns:
- A 2-tuple of the `_FuncGraph`s of the then_branch and else_branch.
+ A 2-tuple of the `FuncGraph`s of the then_branch and else_branch.
"""
def _get_func_graph_for_branch(branch_name):
- """Generates and returns a _FuncGraph for the given branch."""
- extra_inputs = if_op.inputs[1:] # First input is pred.
- input_shapes = [t.shape for t in extra_inputs]
+ """Generates and returns a FuncGraph for the given branch."""
+ inputs = if_op.inputs[1:] # First input is pred.
+ input_shapes = [t.shape for t in inputs]
func_name = if_op.get_attr(branch_name).name
fdef = if_op.graph._get_function(func_name).definition
# `if_op.graph` may not be the same as `ops.get_default_graph()` e.g.
# in the case of nested if ops or when the gradient is being computed
# from inside a Defun. We build the `func_graph` with `if_op.graph` as its
- # `outer_graph`. This resembles how the `_FuncGraph` was built in the
+ # `outer_graph`. This resembles how the `FuncGraph` was built in the
# forward pass. We need this so that we can resolve references to tensors
# in `func_graph` from its gradient graph in `_resolve_grad_inputs`.
with if_op.graph.as_default():
func_graph = _function_def_to_graph.function_def_to_graph(
fdef, input_shapes)
- func_graph.extra_inputs = extra_inputs
- func_graph.extra_args = func_graph.inputs
- func_graph._captured = dict(zip(extra_inputs, func_graph.inputs))
+ func_graph.captures = collections.OrderedDict(zip(inputs,
+ func_graph.inputs))
# Set the if op so that the gradient code can use it.
func_graph._if = if_op
return func_graph
@@ -243,7 +221,7 @@ def _grad_fn(func_graph, grads):
func_graph's outputs w.r.t. its inputs.
Args:
- func_graph: function._FuncGraph. The corresponding forward-pass function.
+ func_graph: function.FuncGraph. The corresponding forward-pass function.
grads: The list of input gradient Tensors.
Returns:
@@ -281,13 +259,13 @@ def _grad_fn(func_graph, grads):
def _create_grad_func(func_graph, grads, name):
- """Returns the _FuncGraph representation of _grad_fn."""
- return _function.func_graph_from_py_func(lambda: _grad_fn(func_graph, grads),
- [], [], name)
+ """Returns the FuncGraph representation of _grad_fn."""
+ return _function.func_graph_from_py_func(
+ name, lambda: _grad_fn(func_graph, grads), [], {})
def _resolve_grad_inputs(cond_graph, grad_graph):
- """Returns the tensors to pass as `extra_inputs` to `grad_graph`.
+ """Returns the tensors to pass as inputs to `grad_graph`.
The `grad_graph` may have external references to
1. Its outer graph containing the input gradients. These references are kept
@@ -299,16 +277,16 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
functions, this is always possible.
Args:
- cond_graph: function._FuncGraph. The forward-pass function.
- grad_graph: function._FuncGraph. The gradients function.
+ cond_graph: function.FuncGraph. The forward-pass function.
+ grad_graph: function.FuncGraph. The gradients function.
Returns:
    A list of input tensors to be passed to grad_graph.
"""
- new_extra_inputs = []
+ new_inputs = []
- for t in grad_graph.extra_inputs:
- if t.graph != grad_graph._outer_graph:
+ for t in grad_graph.external_captures:
+ if t.graph != grad_graph.outer_graph:
# `t` is a tensor in `cond_graph` or one of its ancestors. We bubble this
# tensor to the least common ancestor of the `cond_graph` and
# `grad_graph` so that it is "in-scope" for `grad_graph`.
@@ -316,50 +294,33 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
# common ancestor once and re-use.
assert _is_ancestor(cond_graph, t.graph)
while not _is_ancestor(grad_graph, t.graph):
- assert isinstance(t.graph, _function._FuncGraph)
- if t in t.graph.extra_args:
- # TODO(srbs): Consider building a map of extra_args -> extra_inputs.
- # instead of searching for `t` twice.
- t = t.graph.extra_inputs[t.graph.extra_args.index(t)]
+ assert isinstance(t.graph, _function.FuncGraph)
+ if t in t.graph.internal_captures:
+ # TODO(srbs): Consider building a map of internal_captures ->
+ # external_captures instead of searching for `t` twice.
+ t = t.graph.external_captures[t.graph.internal_captures.index(t)]
else:
# Note: All intermediate tensors are output by the If op.
# TODO(srbs): .index() calls may be expensive. Optimize.
t = t.graph._if.outputs[t.graph.outputs.index(t)]
assert _is_ancestor(grad_graph, t.graph)
- new_extra_inputs.append(t)
+ new_inputs.append(t)
- return new_extra_inputs
+ return new_inputs
def _create_new_tf_function(func_graph):
"""Converts func_graph to a TF_Function and adds it to the current graph.
Args:
- func_graph: function._FuncGraph
+ func_graph: function.FuncGraph
Returns:
The name of the new TF_Function.
"""
- c_func = c_api.TF_GraphToFunction_wrapper(
- func_graph._c_graph,
- compat.as_str(func_graph.name),
- False, # append_hash_to_fn_name
- None, # opers
- [t._as_tf_output() for t in func_graph.inputs],
- [t._as_tf_output() for t in func_graph.outputs],
- [],
- None, # opts
- None) # description
- _ = c_api_util.ScopedTFFunction(c_func)
-
- # TODO(b/109833212): this sucks, we're serializing the TF_Function*,
- # deserializing it into a Python FunctionDef, then reserializing it to create
- # a new TF_Function that we add to the graph.
- fdef = _function.function_def_from_tf_function(c_func)
- defined_func = _function._from_definition(fdef)
- defined_func._sub_functions = func_graph._functions
- defined_func.add_to_graph(func_graph._outer_graph)
-
+ func = _function._EagerDefinedFunction(
+ func_graph.name, func_graph, func_graph.inputs, func_graph.outputs, {})
+ func.add_to_graph(func_graph.outer_graph)
return func_graph.name
@@ -404,8 +365,8 @@ def _pad_params(true_graph, false_graph, true_params, false_params):
There is no merging of params.
Args:
- true_graph: function._FuncGraph
- false_graph: function._FuncGraph
+ true_graph: function.FuncGraph
+ false_graph: function.FuncGraph
true_params: a list of Tensors from true_graph
false_params: a list of Tensors from false_graph
@@ -421,21 +382,20 @@ def _pad_params(true_graph, false_graph, true_params, false_params):
return new_true_params, new_false_inputs
-def _make_inputs_match(true_graph, false_graph, true_extra_inputs,
- false_extra_inputs):
+def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs):
"""Modifies true_graph and false_graph so they have the same input signature.
This method reorders and/or adds parameters to true_graph and false_graph so
- they have the same input signature, and updates the 'inputs', 'extra_inputs',
- and '_captured' fields of both graphs accordingly. It uses the input tensors
- from the outer graph to avoid duplicating shared arguments.
+ they have the same input signature, and updates the 'inputs' and 'captured'
+ fields of both graphs accordingly. It uses the input tensors from the outer
+ graph to avoid duplicating shared arguments.
Args:
- true_graph: function._FuncGraph
- false_graph: function._FuncGraph
- true_extra_inputs: a list of Tensors in the outer graph. The inputs for
+ true_graph: function.FuncGraph
+ false_graph: function.FuncGraph
+ true_inputs: a list of Tensors in the outer graph. The inputs for
true_graph.
- false_extra_inputs: a list of Tensors in the outer graph. The inputs for
+ false_inputs: a list of Tensors in the outer graph. The inputs for
false_graph.
Returns:
@@ -444,12 +404,12 @@ def _make_inputs_match(true_graph, false_graph, true_extra_inputs,
false_inputs.
"""
shared_inputs, true_only_inputs, false_only_inputs = _separate_unique_inputs(
- true_extra_inputs, false_extra_inputs)
+ true_inputs, false_inputs)
new_inputs = shared_inputs + true_only_inputs + false_only_inputs
- true_input_to_param = dict(zip(true_extra_inputs, true_graph.inputs))
- false_input_to_param = dict(zip(false_extra_inputs, false_graph.inputs))
+ true_input_to_param = dict(zip(true_inputs, true_graph.inputs))
+ false_input_to_param = dict(zip(false_inputs, false_graph.inputs))
true_graph.inputs = (
[true_input_to_param[t] for t in shared_inputs] +
@@ -461,15 +421,11 @@ def _make_inputs_match(true_graph, false_graph, true_extra_inputs,
_create_dummy_params(false_graph, true_only_inputs) +
[false_input_to_param[t] for t in false_only_inputs])
- # Rewrite the _FuncGraphs' state to reflect the new inputs.
- true_graph.extra_inputs = new_inputs
- false_graph.extra_inputs = new_inputs
-
- true_graph.extra_args = true_graph.inputs
- false_graph.extra_args = false_graph.inputs
-
- true_graph._captured = dict(zip(new_inputs, true_graph.inputs))
- false_graph._captured = dict(zip(new_inputs, false_graph.inputs))
+ # Rewrite the FuncGraphs' state to reflect the new inputs.
+ true_graph.captures = collections.OrderedDict(zip(new_inputs,
+ true_graph.inputs))
+ false_graph.captures = collections.OrderedDict(zip(new_inputs,
+ false_graph.inputs))
return new_inputs
@@ -478,7 +434,7 @@ def _create_dummy_params(func_graph, template_tensors):
"""Creates tensors in func_graph to represent template_tensors.
Args:
- func_graph: function._FuncGraph.
+ func_graph: function.FuncGraph.
template_tensors: a list of tensors in the outer graph.
Returns:
@@ -495,27 +451,16 @@ def _get_grad_fn_name(func_graph):
Ensures this name is unique in the entire hierarchy.
Args:
- func_graph: The _FuncGraph.
+ func_graph: The FuncGraph.
Returns:
A string, the name to use for the gradient function.
"""
name = "%s_grad" % func_graph.name
-
- base_name = name
- counter = 1
- has_conflict = True
- while has_conflict:
- curr_graph = func_graph._outer_graph
- has_conflict = curr_graph._is_function(name)
- while not has_conflict and isinstance(curr_graph, _function._FuncGraph):
- curr_graph = curr_graph._outer_graph
- has_conflict = curr_graph._is_function(name)
- if has_conflict:
- name = "%s_%s" % (base_name, counter)
- counter += 1
-
- return name
+ outer_most_graph = func_graph
+ while isinstance(outer_most_graph, _function.FuncGraph):
+ outer_most_graph = outer_most_graph.outer_graph
+ return outer_most_graph.unique_name(name)
def _check_same_outputs(true_graph, false_graph):
@@ -534,6 +479,6 @@ def _check_same_outputs(true_graph, false_graph):
def _is_ancestor(graph, maybe_ancestor):
if maybe_ancestor == graph:
return True
- if isinstance(graph, _function._FuncGraph):
- return _is_ancestor(graph._outer_graph, maybe_ancestor)
+ if isinstance(graph, _function.FuncGraph):
+ return _is_ancestor(graph.outer_graph, maybe_ancestor)
return False
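
The migration above collapses `_FuncGraph`'s three parallel capture structures (`extra_inputs`, `extra_args`, `_captured`) into `FuncGraph`'s single ordered `captures` map, from which the external and internal views are derived. A sketch of that bookkeeping with a hypothetical stand-in class, not the real `tensorflow.python.eager.function.FuncGraph`:

```python
import collections

class FakeFuncGraph(object):
  """Hypothetical stand-in illustrating the capture bookkeeping."""

  def __init__(self):
    # One ordered map: outer-graph tensor -> inner placeholder.
    self.captures = collections.OrderedDict()

  @property
  def external_captures(self):  # plays the role of the old extra_inputs
    return list(self.captures.keys())

  @property
  def internal_captures(self):  # plays the role of the old extra_args
    return list(self.captures.values())

g = FakeFuncGraph()
g.captures["outer_t"] = "inner_placeholder"
assert g.external_captures == ["outer_t"]
assert g.internal_captures == ["inner_placeholder"]
```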
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index f84ff4ddf0..e3c1aa3d5a 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -14,7 +14,8 @@
# ==============================================================================
"""Control Flow Operations.
-See the @{$python/control_flow_ops} guide.
+See the [Control
+Flow](https://tensorflow.org/api_guides/python/control_flow_ops) guide.
"""
# pylint: disable=g-bad-name
from __future__ import absolute_import
@@ -1965,8 +1966,12 @@ def cond(pred,
`true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
`false_fn` must have the same non-zero number and type of outputs.
- Note that the conditional execution applies only to the operations defined in
- `true_fn` and `false_fn`. Consider the following simple program:
+ **WARNING**: Any Tensors or Operations created outside of `true_fn` and
+ `false_fn` will be executed regardless of which branch is selected at runtime.
+
+ Although this behavior is consistent with the dataflow model of TensorFlow,
+ it has frequently surprised users who expected a lazier semantics.
+ Consider the following simple program:
```python
z = tf.multiply(a, b)
@@ -1977,8 +1982,6 @@ def cond(pred,
operation will not be executed. Since `z` is needed for at least one
branch of the `cond`, the `tf.multiply` operation is always executed,
unconditionally.
- Although this behavior is consistent with the dataflow model of TensorFlow,
- it has occasionally surprised some users who expected a lazier semantics.
Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
call to `cond`, and not at all during `Session.run()`). `cond`
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index 153548ae92..2c42176158 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -153,7 +153,7 @@ class WithDependenciesTestCase(test_util.TensorFlowTestCase):
const_with_dep = control_flow_ops.with_dependencies(
(increment_counter, constant_op.constant(42)),
constant_op.constant(7))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertEquals(0, counter.eval())
self.assertEquals(7, const_with_dep.eval())
@@ -167,7 +167,7 @@ class WithDependenciesTestCase(test_util.TensorFlowTestCase):
const_with_dep = control_flow_ops.with_dependencies(
[increment_counter, constant_op.constant(42)],
constant_op.constant(7))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertEquals(0, counter.eval())
self.assertEquals(7, const_with_dep.eval())
@@ -177,7 +177,7 @@ class WithDependenciesTestCase(test_util.TensorFlowTestCase):
class SwitchTestCase(test_util.TensorFlowTestCase):
def testIndexedSlicesWithDenseShape(self):
- with self.test_session():
+ with self.cached_session():
data = ops.IndexedSlices(
constant_op.constant([1, 2, 3]),
constant_op.constant([0, 1]),
@@ -208,7 +208,7 @@ class SwitchTestCase(test_util.TensorFlowTestCase):
constant_op.constant(0.0)])
optimizer = momentum.MomentumOptimizer(0.1, 0.9)
train_op = optimizer.minimize(cost)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
for _ in range(10):
sess.run([train_op])
@@ -231,7 +231,7 @@ class SwitchTestCase(test_util.TensorFlowTestCase):
_, cost = control_flow_ops.while_loop(
cond, body, [constant_op.constant(0),
constant_op.constant(0.0)])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
self.assertAllEqual(10.0, cost.eval())
@@ -268,7 +268,7 @@ class SwitchTestCase(test_util.TensorFlowTestCase):
static_grads = math_ops.segment_sum(static_grads.values,
static_grads.indices)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
self.assertAllEqual(*sess.run([static_grads, dynamic_grads]))
@@ -280,7 +280,7 @@ class SwitchTestCase(test_util.TensorFlowTestCase):
def testIndexedSlicesWithShapeGradientInWhileLoop(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_steps = 9
inputs = array_ops.placeholder(dtype=dtype, shape=[num_steps])
@@ -309,7 +309,7 @@ class SwitchTestCase(test_util.TensorFlowTestCase):
def testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = array_ops.placeholder(dtype=dtype)
initial_outputs = tensor_array_ops.TensorArray(
dtype=dtype, dynamic_size=True, size=1)
@@ -335,7 +335,7 @@ class SwitchTestCase(test_util.TensorFlowTestCase):
self.assertAllEqual(grad, [1] * 3)
def testGradientThroughSingleBranchOutsideOfContext(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(2.)
s = constant_op.constant(True)
x_false, x_true = control_flow_ops.switch(x, s)
@@ -434,7 +434,7 @@ class CondTest(test_util.TensorFlowTestCase):
class ContextTest(test_util.TensorFlowTestCase):
def testCondContext(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = constant_op.constant(2)
y = constant_op.constant(5)
control_flow_ops.cond(
@@ -448,7 +448,7 @@ class ContextTest(test_util.TensorFlowTestCase):
control_flow_ops.CondContext.from_proto(c.to_proto()).to_proto())
def _testWhileContextHelper(self, maximum_iterations=None):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
i = constant_op.constant(0)
c = lambda i: math_ops.less(i, 10)
b = lambda i: math_ops.add(i, 1)
@@ -469,7 +469,7 @@ class ContextTest(test_util.TensorFlowTestCase):
self._testWhileContextHelper(maximum_iterations=10)
def testControlContextImportScope(self):
- with self.test_session():
+ with self.cached_session():
constant_op.constant(0, name="a")
constant_op.constant(2, name="test_scope/a")
b1 = constant_op.constant(1, name="b")
@@ -562,7 +562,7 @@ class DataTypesTest(test_util.TensorFlowTestCase):
output_case = control_flow_ops.case([(condition, fn_true)], fn_false,
strict=strict)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
true_feed_dict = {condition: True}
true_feed_dict.update(feed_dict)
@@ -884,7 +884,7 @@ class CaseTest(test_util.TensorFlowTestCase):
(math_ops.equal(x, 2), lambda: constant_op.constant(4))]
default = lambda: constant_op.constant(6)
output = control_flow_ops.case(conditions, default, exclusive=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
self.assertEqual(sess.run(output, feed_dict={x: 2}), 4)
self.assertEqual(sess.run(output, feed_dict={x: 3}), 6)
@@ -896,7 +896,7 @@ class CaseTest(test_util.TensorFlowTestCase):
(math_ops.equal(x, 2), lambda: constant_op.constant(6))]
default = lambda: constant_op.constant(8)
output = control_flow_ops.case(conditions, default, exclusive=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
self.assertEqual(sess.run(output, feed_dict={x: 3}), 8)
with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"):
@@ -909,7 +909,7 @@ class CaseTest(test_util.TensorFlowTestCase):
(math_ops.equal(x, 2), lambda: constant_op.constant(6))]
default = lambda: constant_op.constant(8)
output = control_flow_ops.case(conditions, default, exclusive=False)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
self.assertEqual(sess.run(output, feed_dict={x: 2}), 4)
self.assertEqual(sess.run(output, feed_dict={x: 3}), 8)
@@ -920,7 +920,7 @@ class CaseTest(test_util.TensorFlowTestCase):
(math_ops.equal(x, 2), lambda: constant_op.constant(4)),
(math_ops.equal(x, 3), lambda: constant_op.constant(6))]
output = control_flow_ops.case(conditions, exclusive=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
self.assertEqual(sess.run(output, feed_dict={x: 2}), 4)
self.assertEqual(sess.run(output, feed_dict={x: 3}), 6)
@@ -931,7 +931,7 @@ class CaseTest(test_util.TensorFlowTestCase):
x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2))]
output = control_flow_ops.case(conditions, exclusive=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"):
sess.run(output, feed_dict={x: 4})
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index 871f236f78..d7834ba350 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -82,11 +82,10 @@ def custom_gradient(f):
scope must be using `ResourceVariable`s.
Args:
- f: function `f(x)` that returns a tuple `(y, grad_fn)` where:
- - `x` is a `Tensor` or sequence of `Tensor` inputs to the function.
+ f: function `f(*x)` that returns a tuple `(y, grad_fn)` where:
+ - `x` is a sequence of `Tensor` inputs to the function.
- `y` is a `Tensor` or sequence of `Tensor` outputs of applying
- TensorFlow
- operations in `f` to `x`.
+ TensorFlow operations in `f` to `x`.
- `grad_fn` is a function with the signature `g(*grad_ys)` which returns
a list of `Tensor`s - the derivatives of `Tensor`s in `y` with respect
to the `Tensor`s in `x`. `grad_ys` is a `Tensor` or sequence of
@@ -96,7 +95,8 @@ def custom_gradient(f):
signature `g(*grad_ys, variables=None)`, where `variables` is a list of
the `Variable`s, and return a 2-tuple `(grad_xs, grad_vars)`, where
`grad_xs` is the same as above, and `grad_vars` is a `list<Tensor>`
- with the derivatives of `Tensor`s in `y` with respect to the variables.
+ with the derivatives of `Tensor`s in `y` with respect to the variables
+       (that is, `grad_vars` has one `Tensor` per variable in `variables`).
Returns:
A function `h(x)` which returns the same value as `f(x)[0]` and whose
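
A minimal sketch of the variables-aware signature described above, assuming a TF 1.x-era `tf.custom_gradient` with resource variables; the names are illustrative:

```python
import tensorflow as tf

v = tf.get_variable("v", initializer=2.0, use_resource=True)

@tf.custom_gradient
def scale(x):
  y = v * x
  def grad(dy, variables=None):  # `variables` is the captured-variable list
    grad_xs = dy * v             # derivative of y w.r.t. x
    grad_vars = [dy * x]         # one Tensor per variable in `variables`
    return grad_xs, grad_vars
  return y, grad

y = scale(tf.constant(3.0))
```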
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 7af2ca56be..69c0fcbbee 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -1229,7 +1229,8 @@ class ConditionalAccumulator(ConditionalAccumulatorBase):
dtype,
shape=None,
shared_name=None,
- name="conditional_accumulator"):
+ name="conditional_accumulator",
+ reduction_type="MEAN"):
"""Creates a new ConditionalAccumulator.
Args:
@@ -1238,9 +1239,14 @@ class ConditionalAccumulator(ConditionalAccumulatorBase):
shared_name: Optional. If non-empty, this accumulator will be shared under
the given name across multiple sessions.
name: Optional name for the accumulator.
+ reduction_type: Reduction type to use when taking the gradient.
"""
accumulator_ref = gen_data_flow_ops.conditional_accumulator(
- dtype=dtype, shape=shape, shared_name=shared_name, name=name)
+ dtype=dtype,
+ shape=shape,
+ shared_name=shared_name,
+ name=name,
+ reduction_type=reduction_type)
super(ConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref)
def apply_grad(self, grad, local_step=0, name=None):
@@ -1312,15 +1318,21 @@ class SparseConditionalAccumulator(ConditionalAccumulatorBase):
shared_name: Optional. If non-empty, this accumulator will be shared under
the given name across multiple sessions.
name: Optional name for the accumulator.
+ reduction_type: Reduction type to use when taking the gradient.
"""
def __init__(self,
dtype,
shape=None,
shared_name=None,
- name="sparse_conditional_accumulator"):
+ name="sparse_conditional_accumulator",
+ reduction_type="MEAN"):
accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator(
- dtype=dtype, shape=shape, shared_name=shared_name, name=name)
+ dtype=dtype,
+ shape=shape,
+ shared_name=shared_name,
+ name=name,
+ reduction_type=reduction_type)
super(SparseConditionalAccumulator, self).__init__(dtype, shape,
accumulator_ref)
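
A minimal usage sketch for the new `reduction_type` argument, assuming a TF 1.x runtime whose accumulator kernels accept it:

```python
import tensorflow as tf

acc = tf.ConditionalAccumulator(
    dtype=tf.float32, shape=(), reduction_type="SUM")
apply_ops = [acc.apply_grad(tf.constant(g), local_step=0)
             for g in (1.0, 3.0)]
take = acc.take_grad(num_required=2)
with tf.Session() as sess:
  sess.run(apply_ops)
  print(sess.run(take))  # 4.0 under SUM (the default MEAN would give 2.0)
```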
diff --git a/tensorflow/python/ops/dequantize_op_test.py b/tensorflow/python/ops/dequantize_op_test.py
index 31338db0dd..13e50273d8 100644
--- a/tensorflow/python/ops/dequantize_op_test.py
+++ b/tensorflow/python/ops/dequantize_op_test.py
@@ -32,7 +32,7 @@ class DequantizeOpTest(test.TestCase):
super(DequantizeOpTest, self).__init__(method_name)
def _testDequantizeOp(self, inputs, min_range, max_range, dtype):
- with self.test_session():
+ with self.cached_session():
input_op = constant_op.constant(inputs, shape=[len(inputs)], dtype=dtype)
dequantized = array_ops.dequantize(input_op, min_range, max_range)
tf_ans = dequantized.eval()
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py
index ddf9442cd2..578e7b7dd2 100644
--- a/tensorflow/python/ops/distributions/distribution.py
+++ b/tensorflow/python/ops/distributions/distribution.py
@@ -446,6 +446,24 @@ class Distribution(_BaseDistribution):
self._graph_parents = graph_parents
self._name = name
+ @property
+ def _parameters(self):
+ return self._parameter_dict
+
+ @_parameters.setter
+ def _parameters(self, value):
+ """Intercept assignments to self._parameters to avoid reference cycles.
+
+    Parameters are often created using locals(), so we need to clean out any
+    references to `self` before assigning the dictionary to an attribute.
+
+ Args:
+ value: A dictionary of parameters to assign to the `_parameters` property.
+ """
+ if "self" in value:
+ del value["self"]
+ self._parameter_dict = value
+
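
The setter above guards the common `parameters = dict(locals())` idiom in distribution constructors, which would otherwise capture `self` and pin the object in a reference cycle. A sketch with a hypothetical stand-in class, not the real `Distribution`:

```python
class FakeDistribution(object):
  """Hypothetical stand-in, not the real Distribution class."""

  def __init__(self, loc=0.0, scale=1.0):
    parameters = dict(locals())  # picks up 'self' along with the args
    self._parameters = parameters

  @property
  def _parameters(self):
    return self._parameter_dict

  @_parameters.setter
  def _parameters(self, value):
    if "self" in value:          # break the cycle before storing
      del value["self"]
    self._parameter_dict = value

d = FakeDistribution(loc=1.0)
assert "self" not in d._parameters
```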
@classmethod
def param_shapes(cls, sample_shape, name="DistributionParamShapes"):
"""Shapes of parameters given the desired shape of a call to `sample()`.
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index 7b9e7de145..6263041b8d 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -134,7 +134,10 @@ def _embedding_lookup_and_transform(params,
ids, max_norm)
if transform_fn:
result = transform_fn(result)
- return result
+      # Make sure the final result does not have colocation constraints on the
+      # params. This is similar to the np > 1 case, where parallel_dynamic_stitch
+      # is created outside the scope of all the ops.colocate_with(params[p]) blocks.
+ return array_ops.identity(result)
else:
# Flatten the ids. There are two cases where we need to do this.
# - There is more than one params tensor.
@@ -427,6 +430,8 @@ def embedding_lookup_sparse(params,
embeddings = embedding_lookup(
params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
+ if embeddings.dtype in (dtypes.float16, dtypes.bfloat16):
+ embeddings = math_ops.to_float(embeddings)
if not ignore_weights:
weights = sp_weights.values
if weights.dtype != embeddings.dtype:
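
The `identity` above matters because an op created inside `ops.colocate_with(params[p])` carries that colocation constraint to wherever it is consumed, while an identity created outside the scope starts its own group. A minimal sketch, assuming the TF 1.x `tf.colocate_with` export:

```python
import tensorflow as tf

param = tf.constant([1.0, 2.0], name="param")
with tf.colocate_with(param):
  gathered = tf.gather(param, [0, 1])
result = tf.identity(gathered)  # created outside the colocation scope

print(gathered.op.colocation_groups())  # [b'loc:@param']
print(result.op.colocation_groups())    # [b'loc:@Identity'] -- its own group
```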
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index 4ecc74675a..a4e7c84ae4 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -15,7 +15,8 @@
"""Functional operations.
-See the @{$python/functional_ops} guide.
+See the [Higher Order
+Functions](https://tensorflow.org/api_guides/python/functional_ops) guide.
"""
from __future__ import absolute_import
@@ -90,7 +91,7 @@ def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
Example:
```python
- elems = [1, 2, 3, 4, 5, 6]
+ elems = tf.constant([1, 2, 3, 4, 5, 6])
sum = foldl(lambda a, x: a + x, elems)
# sum == 21
```
diff --git a/tensorflow/python/ops/gradient_checker_test.py b/tensorflow/python/ops/gradient_checker_test.py
index b0ecdc6a50..fbb84b9018 100644
--- a/tensorflow/python/ops/gradient_checker_test.py
+++ b/tensorflow/python/ops/gradient_checker_test.py
@@ -76,7 +76,7 @@ class GradientCheckerTest(test.TestCase):
def testAddCustomized(self):
np.random.seed(3) # Fix seed to avoid flakiness
- with self.test_session():
+ with self.cached_session():
# a test case for Add operation
size = (2, 3)
x1 = constant_op.constant(
@@ -94,7 +94,7 @@ class GradientCheckerTest(test.TestCase):
def testGather(self):
np.random.seed(4) # Fix seed to avoid flakiness
- with self.test_session():
+ with self.cached_session():
p_shape = (4, 2)
p_size = 8
index_values = [1, 3]
@@ -111,7 +111,7 @@ class GradientCheckerTest(test.TestCase):
def testNestedGather(self):
np.random.seed(5) # Fix seed to avoid flakiness
- with self.test_session():
+ with self.cached_session():
p_shape = (8, 2)
p_size = 16
index_values = [1, 3, 5, 6]
@@ -131,7 +131,7 @@ class GradientCheckerTest(test.TestCase):
assert error < 1e-4
def testComplexMul(self):
- with self.test_session():
+ with self.cached_session():
size = ()
c = constant_op.constant(5 + 7j, dtype=dtypes.complex64)
x = constant_op.constant(11 - 13j, dtype=dtypes.complex64)
@@ -145,7 +145,7 @@ class GradientCheckerTest(test.TestCase):
gradient_checker.compute_gradient_error(x, size, y, size), 2e-4)
def testComplexConj(self):
- with self.test_session():
+ with self.cached_session():
size = ()
x = constant_op.constant(11 - 13j, dtype=dtypes.complex64)
y = math_ops.conj(x)
@@ -158,7 +158,7 @@ class GradientCheckerTest(test.TestCase):
gradient_checker.compute_gradient_error(x, size, y, size), 2e-5)
def testEmptySucceeds(self):
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(dtypes.float32)
y = array_ops.identity(x)
for grad in gradient_checker.compute_gradient(x, (0, 3), y, (0, 3)):
@@ -168,7 +168,7 @@ class GradientCheckerTest(test.TestCase):
def testEmptyFails(self):
with ops.Graph().as_default() as g:
- with self.test_session(graph=g):
+ with self.session(graph=g):
x = array_ops.placeholder(dtypes.float32)
with g.gradient_override_map({"Identity": "BadGrad"}):
y = array_ops.identity(x)
@@ -180,7 +180,7 @@ class GradientCheckerTest(test.TestCase):
def testNaNGradFails(self):
with ops.Graph().as_default() as g:
- with self.test_session(graph=g):
+ with self.session(graph=g):
x = array_ops.placeholder(dtypes.float32)
with g.gradient_override_map({"Identity": "NaNGrad"}):
y = array_ops.identity(x)
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py
index 9fa8e27d5c..1dc666e78b 100644
--- a/tensorflow/python/ops/gradients.py
+++ b/tensorflow/python/ops/gradients.py
@@ -19,10 +19,10 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
+from tensorflow.python.eager import function
from tensorflow.python.eager.backprop import GradientTape
from tensorflow.python.ops.custom_gradient import custom_gradient
from tensorflow.python.ops.gradients_impl import AggregationMethod
from tensorflow.python.ops.gradients_impl import gradients
from tensorflow.python.ops.gradients_impl import hessians
# pylint: enable=unused-import
-
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index a68f680224..3268b38b86 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -31,7 +31,7 @@ from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
+from tensorflow.python.framework import function as framework_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
@@ -58,6 +58,10 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
+# This is to avoid a circular dependency (eager.function depends on
+# gradients_impl). This is set in eager/function.py.
+_function = None
+
# This is to avoid a circular dependency with cond_v2_impl.
cond_v2_impl._gradients_impl = sys.modules[__name__] # pylint: disable=protected-access
@@ -121,7 +125,7 @@ def _MarkReachedOps(from_ops, reached_ops, func_graphs):
Args:
from_ops: list of Operations.
reached_ops: set of Operations.
- func_graphs: list of function._FuncGraphs. This method will traverse through
+ func_graphs: list of _function.FuncGraphs. This method will traverse through
these functions if they capture from_ops or any reachable ops.
"""
queue = collections.deque()
@@ -146,7 +150,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
to_ops: list of Operations.
from_ops: list of Operations.
colocate_gradients_with_ops: Python bool. See docstring of gradients().
- func_graphs: list of function._FuncGraphs. This method will traverse through
+ func_graphs: list of _function.FuncGraphs. This method will traverse through
these functions if they capture from_ops or any reachable ops. This is
useful if to_ops occur in a function and from_ops are in an outer function
or graph.
@@ -441,6 +445,19 @@ def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs):
% target_op.name)
+def _IsFunction(graph):
+ return (isinstance(graph, _function.FuncGraph) or
+ isinstance(graph, framework_function._FuncGraph)) # pylint: disable=protected-access
+
+
+def _Captures(func_graph):
+ if isinstance(func_graph, _function.FuncGraph):
+ return func_graph.captures
+ else:
+ assert isinstance(func_graph, framework_function._FuncGraph) # pylint: disable=protected-access
+ return func_graph._captured # pylint: disable=protected-access
+
+
def _MaybeCaptured(t):
"""If t is a captured value placeholder, returns the original captured value.
@@ -448,11 +465,11 @@ def _MaybeCaptured(t):
t: Tensor
Returns:
- A tensor, potentially from a different Graph/function._FuncGraph.
+ A tensor, potentially from a different Graph/_function.FuncGraph.
"""
# pylint: disable=protected-access
- if isinstance(t.op.graph, function._FuncGraph) and t.op.type == "Placeholder":
- for input_t, placeholder_t in t.op.graph._captured.items():
+ if _IsFunction(t.op.graph) and t.op.type == "Placeholder":
+ for input_t, placeholder_t in _Captures(t.op.graph).items():
if t == placeholder_t:
return _MaybeCaptured(input_t)
# pylint: enable=protected-access
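
`_MaybeCaptured` unwinds capture chains recursively: a placeholder in a nested function graph maps back to the outer tensor it captured, possibly through several levels of nesting. A sketch of that recursion over hypothetical dict-based capture maps, not the real graph objects:

```python
def maybe_captured(t, graph_captures):
  """Hypothetical sketch; graph_captures maps each graph name to its
  {outer_tensor: inner_placeholder} capture dict."""
  for captures in graph_captures.values():
    for outer_t, placeholder in captures.items():
      if t == placeholder:
        return maybe_captured(outer_t, graph_captures)
  return t

caps = {
    "inner": {"mid_t": "inner_ph"},  # inner graph captured mid_t
    "mid": {"outer_t": "mid_t"},     # mid graph captured outer_t
}
assert maybe_captured("inner_ph", caps) == "outer_t"
```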
@@ -470,10 +487,10 @@ def _Inputs(op, xs):
Returns:
A list of tensors. The tensors may be from multiple
- Graph/function._FuncGraphs if op is in a function._FuncGraph and has
+ Graph/_function.FuncGraphs if op is in a _function.FuncGraph and has
captured inputs.
"""
- if isinstance(op.graph, function._FuncGraph): # pylint: disable=protected-access
+ if _IsFunction(op.graph): # pylint: disable=protected-access
# If we're differentiating w.r.t. `t`, do not attempt to traverse through it
# to a captured value. The algorithm needs to "see" `t` in this case, even
# if it's a function input for a captured value, whereas usually we'd like
@@ -489,7 +506,7 @@ def _Consumers(t, func_graphs):
Args:
t: Tensor
- func_graphs: a list of function._FuncGraphs that may have captured t.
+ func_graphs: a list of _function.FuncGraphs that may have captured t.
Returns:
A list of tensors. The tensors will be from the current graph and/or
@@ -497,7 +514,7 @@ def _Consumers(t, func_graphs):
"""
consumers = t.consumers()
for func in func_graphs:
- for input_t, placeholder in func._captured.items(): # pylint: disable=protected-access
+ for input_t, placeholder in _Captures(func).items():
if input_t == t:
consumers.extend(_Consumers(placeholder, func_graphs))
return consumers
@@ -616,9 +633,13 @@ def _GradientsHelper(ys,
# ancestor graphs. This is necessary for correctly handling captured values.
func_graphs = []
curr_graph = src_graph
- while isinstance(curr_graph, function._FuncGraph): # pylint: disable=protected-access
+ while _IsFunction(curr_graph):
func_graphs.append(curr_graph)
- curr_graph = curr_graph._outer_graph # pylint: disable=protected-access
+ if isinstance(curr_graph, _function.FuncGraph):
+ curr_graph = curr_graph.outer_graph
+ else:
+ assert isinstance(curr_graph, framework_function._FuncGraph) # pylint: disable=protected-access
+ curr_graph = curr_graph._outer_graph # pylint: disable=protected-access
ys = _AsList(ys)
xs = _AsList(xs)
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index d02fcf4ee2..3759d8a543 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -26,9 +26,10 @@ import numpy as np
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
+from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
+from tensorflow.python.framework import function as framework_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
@@ -159,7 +160,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
def testBoundaryContinue(self):
# Test that we differentiate both 'x' and 'y' correctly when x is a
# predecessor of y.
- with self.test_session():
+ with self.cached_session():
x = constant(1.0)
y = x * 2.0
z = y * 3.0
@@ -168,7 +169,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
self.assertEqual(6.0, grads[0].eval())
def testAggregationMethodAccumulateN(self):
- with self.test_session():
+ with self.cached_session():
x = constant(1.0)
y = x * 2.0
z = y + y + y + y + y + y + y + y + y + y
@@ -181,7 +182,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
self.assertEqual(10.0, grads[1].eval())
def testAggregationMethodAddN(self):
- with self.test_session():
+ with self.cached_session():
x = constant(1.0)
y = x * 2.0
z = y + y + y + y + y + y + y + y + y + y
@@ -192,7 +193,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
self.assertEqual(10.0, grads[1].eval())
def testAggregationMethodTree(self):
- with self.test_session():
+ with self.cached_session():
x = constant(1.0)
y = x * 2.0
z = y + y + y + y + y + y + y + y + y + y
@@ -232,7 +233,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
array_ops.placeholder(dtypes.int32))
dx, = gradients.gradients(y, x, grad_ys=dy)
# The IndexedSlices gradient of tf.identity is the identity map.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
vdx, vdy = sess.run(
[dx, dy], feed_dict={x: [1.0], dy.indices: [0], dy.values: [2.0]})
self.assertEqual(vdx, vdy)
@@ -276,7 +277,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
self.assertIsNotNone(gradient)
def testDependentYs(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(3.0)
y = math_ops.square(x)
y1 = math_ops.square(y)
@@ -291,7 +292,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
self.assertAllClose(17502.0, g[0].eval())
def testPartialDerivatives(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(1.)
y = 2 * x
z = x + y
@@ -341,7 +342,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
constants=constants, variables=variables_))
# evaluate all tensors in one call to session.run for speed
- with self.test_session() as sess:
+ with self.cached_session() as sess:
results = sess.run([(case["grad1"], case["grad2"]) for case in cases])
for (npgrad1, npgrad2), case in zip(results, cases):
@@ -369,8 +370,8 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase):
@classmethod
def _GetFunc(cls, **kwargs):
- return function.Defun(dtypes.float32, dtypes.float32, **
- kwargs)(cls.XSquarePlusB)
+ return framework_function.Defun(dtypes.float32, dtypes.float32, **
+ kwargs)(cls.XSquarePlusB)
def _GetFuncGradients(self, f, x_value, b_value):
x = constant_op.constant(x_value, name="x")
@@ -378,7 +379,7 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase):
y = f(x, b)
grads = gradients.gradients(y, [x, b])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
return sess.run(grads)
def testFunctionGradientsBasic(self):
@@ -401,15 +402,16 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase):
# Build gradient graph (should add SymbolicGradient node for function).
grads = gradients.gradients(y, [x, b1])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([40.0], sess.run(grads)[0])
self.assertAllEqual([10.0], sess.run(grads)[1])
def testFunctionGradientsWithGradFunc(self):
g = ops.Graph()
with g.as_default():
- grad_func = function.Defun(dtypes.float32, dtypes.float32,
- dtypes.float32)(self.XSquarePlusBGradient)
+ grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
+ dtypes.float32)(
+ self.XSquarePlusBGradient)
f = self._GetFunc(grad_func=grad_func)
# Get gradients (should add SymbolicGradient node for function, which
# uses the grad_func above, which multiplies all gradients by 2).
@@ -430,8 +432,9 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase):
def testFunctionGradientWithGradFuncAndRegistration(self):
g = ops.Graph()
with g.as_default():
- grad_func = function.Defun(dtypes.float32, dtypes.float32,
- dtypes.float32)(self.XSquarePlusBGradient)
+ grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
+ dtypes.float32)(
+ self.XSquarePlusBGradient)
with self.assertRaisesRegexp(ValueError, "Gradient defined twice"):
f = self._GetFunc(
grad_func=grad_func, python_grad_func=self._PythonGradient)
@@ -441,14 +444,14 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase):
with ops.Graph().as_default():
x = constant_op.constant(1.0, name="x")
- @function.Defun()
+ @function.defun()
def Foo():
y = math_ops.multiply(x, 2.0, name="y")
g = gradients_impl.gradients(y, x)
return g[0]
f = Foo()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(f), 2.0)
def testGradientOfCaptured(self):
@@ -456,27 +459,27 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase):
x = constant_op.constant(1.0, name="x")
y = math_ops.multiply(x, 2.0, name="y")
- @function.Defun()
+ @framework_function.Defun()
def Foo():
g = gradients_impl.gradients(y, x)
return g[0]
f = Foo()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(f), 2.0)
def testCapturedResourceVariable(self):
with ops.Graph().as_default():
var = resource_variable_ops.ResourceVariable(1.0, name="var")
- @function.Defun()
+ @function.defun()
def Foo():
y = math_ops.multiply(var, 2.0, name="y")
g = gradients_impl.gradients(y, var)
return g[0]
f = Foo()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
self.assertEqual(sess.run(f), 2.0)
@@ -486,11 +489,11 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase):
x2 = constant_op.constant(2.0, name="x2")
x3 = math_ops.multiply(x1, x2, name="x3")
- @function.Defun()
+ @function.defun()
def Outer():
outer1 = array_ops.identity(x1, name="outer1")
- @function.Defun()
+ @function.defun()
def Inner():
inner1 = array_ops.identity(outer1, name="inner1")
inner2 = array_ops.identity(x2, name="inner2")
@@ -501,7 +504,7 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase):
return Inner()
x1_grad, x2_grad = Outer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# 1.0 + None + 2.0 + 1.0 = 4.0
self.assertEqual(sess.run(x1_grad), 4.0)
# None + 1.0 + 1.0 + None = 2.0
@@ -511,11 +514,11 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase):
with ops.Graph().as_default():
x = constant_op.constant(1.0, name="x")
- @function.Defun()
+ @function.defun()
def Outer():
y = math_ops.multiply(x, 2.0, name="y")
- @function.Defun()
+ @function.defun()
def Inner():
z = math_ops.multiply(y, 3.0, name="z")
g = gradients_impl.gradients(z, y)
@@ -524,7 +527,7 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase):
return Inner()
z_grad = Outer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(z_grad), 3.0)
@@ -667,7 +670,7 @@ class HessianTest(test_util.TensorFlowTestCase):
class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):
def testIndexedSlicesToTensor(self):
- with self.test_session():
+ with self.cached_session():
np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
c = constant_op.constant(np_val)
c_sparse = math_ops._as_indexed_slices(c)
@@ -676,7 +679,7 @@ class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):
self.assertAllClose(np_val, c_dense.eval())
def testIndexedSlicesToTensorList(self):
- with self.test_session():
+ with self.cached_session():
numpy_list = []
dense_list = []
sparse_list = []
@@ -692,7 +695,7 @@ class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):
self.assertAllClose(packed_dense.eval(), packed_sparse.eval())
def testInt64Indices(self):
- with self.test_session():
+ with self.cached_session():
np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
c = constant_op.constant(np_val)
c_sparse = math_ops._as_indexed_slices(c)
@@ -938,7 +941,7 @@ class CustomGradientTest(test_util.TensorFlowTestCase):
F(x)
def testRVGradientsDynamicCond(self):
- with self.test_session():
+ with self.cached_session():
alpha = resource_variable_ops.ResourceVariable(
np.random.random((1,)),
dtype="float32")
diff --git a/tensorflow/python/ops/histogram_ops.py b/tensorflow/python/ops/histogram_ops.py
index e86a8e5a5b..7291e05685 100644
--- a/tensorflow/python/ops/histogram_ops.py
+++ b/tensorflow/python/ops/histogram_ops.py
@@ -14,8 +14,6 @@
# ==============================================================================
# pylint: disable=g-short-docstring-punctuation
"""Histograms.
-
-Please see @{$python/histogram_ops} guide.
"""
from __future__ import absolute_import
diff --git a/tensorflow/python/ops/histogram_ops_test.py b/tensorflow/python/ops/histogram_ops_test.py
index 2e57ae8a2d..1ba805dbb4 100644
--- a/tensorflow/python/ops/histogram_ops_test.py
+++ b/tensorflow/python/ops/histogram_ops_test.py
@@ -35,7 +35,7 @@ class BinValuesFixedWidth(test.TestCase):
value_range = [0.0, 5.0]
values = []
expected_bins = []
- with self.test_session():
+ with self.cached_session():
bins = histogram_ops.histogram_fixed_width_bins(
values, value_range, nbins=5)
self.assertEqual(dtypes.int32, bins.dtype)
@@ -47,7 +47,7 @@ class BinValuesFixedWidth(test.TestCase):
value_range = [0.0, 5.0]
values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15]
expected_bins = [0, 0, 1, 2, 4, 4]
- with self.test_session():
+ with self.cached_session():
bins = histogram_ops.histogram_fixed_width_bins(
values, value_range, nbins=5, dtype=dtypes.int64)
self.assertEqual(dtypes.int32, bins.dtype)
@@ -59,7 +59,7 @@ class BinValuesFixedWidth(test.TestCase):
value_range = np.float64([0.0, 5.0])
values = np.float64([-1.0, 0.0, 1.5, 2.0, 5.0, 15])
expected_bins = [0, 0, 1, 2, 4, 4]
- with self.test_session():
+ with self.cached_session():
bins = histogram_ops.histogram_fixed_width_bins(
values, value_range, nbins=5)
self.assertEqual(dtypes.int32, bins.dtype)
@@ -72,7 +72,7 @@ class BinValuesFixedWidth(test.TestCase):
values = constant_op.constant(
[[-1.0, 0.0, 1.5], [2.0, 5.0, 15]], shape=(2, 3))
expected_bins = [[0, 0, 1], [2, 4, 4]]
- with self.test_session():
+ with self.cached_session():
bins = histogram_ops.histogram_fixed_width_bins(
values, value_range, nbins=5)
self.assertEqual(dtypes.int32, bins.dtype)
diff --git a/tensorflow/python/ops/image_grad_test.py b/tensorflow/python/ops/image_grad_test.py
index 75d00c8ed1..fddde75f6b 100644
--- a/tensorflow/python/ops/image_grad_test.py
+++ b/tensorflow/python/ops/image_grad_test.py
@@ -108,7 +108,7 @@ class ResizeBilinearOpTest(test.TestCase):
x = np.arange(0, 4).reshape(in_shape).astype(np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
self.assertEqual(out_shape, list(resize_out.get_shape()))
@@ -122,7 +122,7 @@ class ResizeBilinearOpTest(test.TestCase):
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
- with self.test_session():
+ with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
err = gradient_checker.compute_gradient_error(
@@ -135,7 +135,7 @@ class ResizeBilinearOpTest(test.TestCase):
x = np.arange(0, 24).reshape(in_shape).astype(np.float32)
- with self.test_session():
+ with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
err = gradient_checker.compute_gradient_error(
@@ -165,7 +165,7 @@ class ResizeBilinearOpTest(test.TestCase):
out_shape = [1, 2, 3, 1]
x = np.arange(0, 24).reshape(in_shape)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for dtype in [np.float16, np.float32, np.float64]:
input_tensor = constant_op.constant(x.astype(dtype), shape=in_shape)
resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
@@ -190,7 +190,7 @@ class ResizeBicubicOpTest(test.TestCase):
x = np.arange(0, 4).reshape(in_shape).astype(np.float32)
for align_corners in [True, False]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3],
align_corners=align_corners)
@@ -206,7 +206,7 @@ class ResizeBicubicOpTest(test.TestCase):
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
for align_corners in [True, False]:
- with self.test_session():
+ with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3],
align_corners=align_corners)
@@ -221,7 +221,7 @@ class ResizeBicubicOpTest(test.TestCase):
x = np.arange(0, 24).reshape(in_shape).astype(np.float32)
for align_corners in [True, False]:
- with self.test_session():
+ with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3],
align_corners=align_corners)
@@ -235,7 +235,7 @@ class ResizeBicubicOpTest(test.TestCase):
x = np.arange(0, 24).reshape(in_shape).astype(np.uint8)
- with self.test_session():
+ with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3])
grad = gradients_impl.gradients(input_tensor, [resize_out])
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index 343531ac55..3de46e7cf3 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -16,7 +16,7 @@
# pylint: disable=g-short-docstring-punctuation
"""Image processing and decoding ops.
-See the @{$python/image} guide.
+See the [Images](https://tensorflow.org/api_guides/python/image) guide.
"""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 12356944f8..de260f3140 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -330,6 +330,8 @@ def _random_flip(image, flip_index, seed, scope_name):
lambda: image,
name=scope
)
+ if isinstance(result, tuple):
+ result = result[0] # TODO(b/111124878) remove this logic (CondV2).
return fix_image_flip_shape(image, result)
elif shape.ndims == 4:
uniform_random = random_ops.random_uniform(
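The two lines added above guard against the experimental CondV2 rewrite (tracked in TODO b/111124878), under which `cond` can return a single-element tuple instead of a bare tensor. A minimal sketch of the same defensive unwrap; `unwrap_cond_output` is a hypothetical name, not part of this diff:

    def unwrap_cond_output(result):
      # CondV2 may wrap a lone output in a tuple; normalize to the tensor.
      if isinstance(result, tuple):
        return result[0]
      return result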
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 2c61bb232a..795e6bbc3e 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -238,7 +238,7 @@ class AdjustGamma(test_util.TensorFlowTestCase):
def test_adjust_gamma_one(self):
"""Same image should be returned for gamma equal to one"""
- with self.test_session():
+ with self.cached_session():
x_data = np.random.uniform(0, 255, (8, 8))
x_np = np.array(x_data, dtype=np.float32)
@@ -252,7 +252,7 @@ class AdjustGamma(test_util.TensorFlowTestCase):
def test_adjust_gamma_less_zero(self):
"""White image should be returned for gamma equal to zero"""
- with self.test_session():
+ with self.cached_session():
x_data = np.random.uniform(0, 255, (8, 8))
x_np = np.array(x_data, dtype=np.float32)
@@ -270,7 +270,7 @@ class AdjustGamma(test_util.TensorFlowTestCase):
def test_adjust_gamma_less_zero_tensor(self):
"""White image should be returned for gamma equal to zero"""
- with self.test_session():
+ with self.cached_session():
x_data = np.random.uniform(0, 255, (8, 8))
x_np = np.array(x_data, dtype=np.float32)
@@ -290,7 +290,7 @@ class AdjustGamma(test_util.TensorFlowTestCase):
def test_adjust_gamma_zero(self):
"""White image should be returned for gamma equal to zero"""
- with self.test_session():
+ with self.cached_session():
x_data = np.random.uniform(0, 255, (8, 8))
x_np = np.array(x_data, dtype=np.float32)
@@ -308,7 +308,7 @@ class AdjustGamma(test_util.TensorFlowTestCase):
def test_adjust_gamma_less_one(self):
"""Verifying the output with expected results for gamma
correction with gamma equal to half"""
- with self.test_session():
+ with self.cached_session():
x_np = np.arange(0, 255, 4, np.uint8).reshape(8, 8)
y = image_ops.adjust_gamma(x_np, gamma=0.5)
y_tf = np.trunc(y.eval())
@@ -329,7 +329,7 @@ class AdjustGamma(test_util.TensorFlowTestCase):
def test_adjust_gamma_greater_one(self):
"""Verifying the output with expected results for gamma
correction with gamma equal to two"""
- with self.test_session():
+ with self.cached_session():
x_np = np.arange(0, 255, 4, np.uint8).reshape(8, 8)
y = image_ops.adjust_gamma(x_np, gamma=2)
y_tf = np.trunc(y.eval())
@@ -2367,7 +2367,7 @@ class ResizeImagesTest(test_util.TensorFlowTestCase):
img_np = np.array(data, dtype=np.uint8).reshape(img_shape)
for opt in self.OPTIONS:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
image = constant_op.constant(img_np, shape=img_shape)
y = image_ops.resize_images(image, [height, width], opt)
yshape = array_ops.shape(y)
@@ -3076,7 +3076,7 @@ class JpegTest(test_util.TensorFlowTestCase):
self.assertLess(error, 4)
def testCropAndDecodeJpeg(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Encode it, then decode it, then encode it
base = "tensorflow/core/lib/jpeg/testdata"
jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
@@ -3102,7 +3102,7 @@ class JpegTest(test_util.TensorFlowTestCase):
self.assertAllEqual(image1_crop, image2)
def testCropAndDecodeJpegWithInvalidCropWindow(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Encode it, then decode it, then encode it
base = "tensorflow/core/lib/jpeg/testdata"
jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
@@ -3577,7 +3577,7 @@ class FormatTest(test_util.TensorFlowTestCase):
"png": functools.partial(image_ops.decode_png, channels=3),
"gif": lambda s: array_ops.squeeze(image_ops.decode_gif(s), axis=0),
}
- with self.test_session():
+ with self.cached_session():
for path in paths:
contents = io_ops.read_file(os.path.join(prefix, path)).eval()
images = {}
@@ -3592,7 +3592,7 @@ class FormatTest(test_util.TensorFlowTestCase):
def testError(self):
path = "tensorflow/core/lib/gif/testdata/scan.gif"
- with self.test_session():
+ with self.cached_session():
for decode in image_ops.decode_jpeg, image_ops.decode_png:
with self.assertRaisesOpError(r"Got 12 frames"):
decode(io_ops.read_file(path)).eval()
@@ -3606,7 +3606,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
scores_np = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]
max_output_size_np = 3
iou_threshold_np = 0.5
- with self.test_session():
+ with self.cached_session():
boxes = constant_op.constant(boxes_np)
scores = constant_op.constant(scores_np)
max_output_size = constant_op.constant(max_output_size_np)
@@ -3657,6 +3657,47 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
scores = constant_op.constant([0.9])
image_ops.non_max_suppression(boxes, scores, 3, [[0.5]])
+ def testDataTypes(self):
+ # Test case for GitHub issue 20199.
+ boxes_np = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
+ [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
+ scores_np = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]
+ max_output_size_np = 3
+ iou_threshold_np = 0.5
+ # Note: There are multiple versions of non_max_suppression: v2, v3, and v4.
+ # gen_image_ops.non_max_suppression_v2:
+ for dtype in [np.float16, np.float32]:
+ with self.cached_session():
+ boxes = constant_op.constant(boxes_np, dtype=dtype)
+ scores = constant_op.constant(scores_np, dtype=dtype)
+ max_output_size = constant_op.constant(max_output_size_np)
+ iou_threshold = constant_op.constant(iou_threshold_np)
+ selected_indices = gen_image_ops.non_max_suppression_v2(
+ boxes, scores, max_output_size, iou_threshold).eval()
+ self.assertAllClose(selected_indices, [3, 0, 5])
+ # image_ops.non_max_suppression = gen_image_ops.non_max_suppression_v3.
+ for dtype in [np.float16, np.float32]:
+ with self.cached_session():
+ boxes = constant_op.constant(boxes_np, dtype=dtype)
+ scores = constant_op.constant(scores_np, dtype=dtype)
+ max_output_size = constant_op.constant(max_output_size_np)
+ iou_threshold = constant_op.constant(iou_threshold_np)
+ selected_indices = image_ops.non_max_suppression(
+ boxes, scores, max_output_size, iou_threshold).eval()
+ self.assertAllClose(selected_indices, [3, 0, 5])
+ # gen_image_ops.non_max_suppression_v4.
+ score_threshold = float('-inf')
+ for dtype in [np.float16, np.float32]:
+ with self.cached_session():
+ boxes = constant_op.constant(boxes_np, dtype=dtype)
+ scores = constant_op.constant(scores_np, dtype=dtype)
+ max_output_size = constant_op.constant(max_output_size_np)
+ iou_threshold = constant_op.constant(iou_threshold_np)
+ selected_indices, _ = gen_image_ops.non_max_suppression_v4(
+ boxes, scores, max_output_size, iou_threshold, score_threshold)
+ selected_indices = selected_indices.eval()
+ self.assertAllClose(selected_indices, [3, 0, 5])
+
class NonMaxSuppressionPaddedTest(test_util.TensorFlowTestCase):
@@ -3686,7 +3727,7 @@ class NonMaxSuppressionPaddedTest(test_util.TensorFlowTestCase):
# The output shape of the padded operation must be fully defined.
self.assertEqual(selected_indices_padded.shape.is_fully_defined(), True)
self.assertEqual(selected_indices.shape.is_fully_defined(), False)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(selected_indices_padded.eval(), [3, 0, 5, 0, 0])
self.assertEqual(num_valid_padded.eval(), 3)
self.assertAllClose(selected_indices.eval(), [3, 0, 5])
@@ -4035,7 +4076,7 @@ class ImageGradientsTest(test_util.TensorFlowTestCase):
expected_dx = np.reshape([[2, 1, -2, 0], [-1, -2, 1, 0]], shape)
dy, dx = image_ops.image_gradients(img)
- with self.test_session():
+ with self.cached_session():
actual_dy = dy.eval()
actual_dx = dx.eval()
self.assertAllClose(expected_dy, actual_dy)
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index 4d75ee3974..fff3d9b930 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -39,12 +39,12 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import linalg_ops_impl
from tensorflow.python.ops import gen_linalg_ops
+from tensorflow.python.ops import linalg_ops_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
-from tensorflow.python.util.deprecation import (
- deprecated, deprecated_arg_values)
+from tensorflow.python.util.deprecation import deprecated
+from tensorflow.python.util.deprecation import deprecated_arg_values
from tensorflow.python.util.tf_export import tf_export
@@ -226,9 +226,7 @@ class Constant(Initializer):
return {"value": self.value, "dtype": self.dtype.name}
-@tf_export("keras.initializers.RandomUniform", "initializers.random_uniform",
- "random_uniform_initializer", "keras.initializers.uniform",
- "keras.initializers.random_uniform")
+@tf_export("initializers.random_uniform", "random_uniform_initializer")
class RandomUniform(Initializer):
"""Initializer that generates tensors with a uniform distribution.
@@ -264,9 +262,7 @@ class RandomUniform(Initializer):
}
-@tf_export("keras.initializers.RandomNormal", "initializers.random_normal",
- "random_normal_initializer", "keras.initializers.normal",
- "keras.initializers.random_normal")
+@tf_export("initializers.random_normal", "random_normal_initializer")
class RandomNormal(Initializer):
"""Initializer that generates tensors with a normal distribution.
@@ -302,9 +298,7 @@ class RandomNormal(Initializer):
}
-@tf_export("keras.initializers.TruncatedNormal",
- "initializers.truncated_normal", "truncated_normal_initializer",
- "keras.initializers.truncated_normal")
+@tf_export("initializers.truncated_normal", "truncated_normal_initializer")
class TruncatedNormal(Initializer):
"""Initializer that generates a truncated normal distribution.
@@ -1116,29 +1110,10 @@ class Identity(Initializer):
def get_config(self):
return {"gain": self.gain, "dtype": self.dtype.name}
-# Aliases.
-
-# pylint: disable=invalid-name
-zeros_initializer = Zeros
-ones_initializer = Ones
-constant_initializer = Constant
-random_uniform_initializer = RandomUniform
-random_normal_initializer = RandomNormal
-truncated_normal_initializer = TruncatedNormal
-uniform_unit_scaling_initializer = UniformUnitScaling
-variance_scaling_initializer = VarianceScaling
-orthogonal_initializer = Orthogonal
-identity_initializer = Identity
-convolutional_delta_orthogonal = ConvolutionDeltaOrthogonal
-convolutional_orthogonal_1d = ConvolutionOrthogonal1D
-convolutional_orthogonal_2d = ConvolutionOrthogonal2D
-convolutional_orthogonal_3d = ConvolutionOrthogonal3D
-# pylint: enable=invalid-name
-
@tf_export("glorot_uniform_initializer", "keras.initializers.glorot_uniform",
"initializers.glorot_uniform")
-def glorot_uniform_initializer(seed=None, dtype=dtypes.float32):
+class GlorotUniform(VarianceScaling):
"""The Glorot uniform initializer, also called Xavier uniform initializer.
It draws samples from a uniform distribution within [-limit, limit]
@@ -1153,17 +1128,28 @@ def glorot_uniform_initializer(seed=None, dtype=dtypes.float32):
`tf.set_random_seed`
for behavior.
dtype: The data type. Only floating point types are supported.
-
- Returns:
- An initializer.
"""
- return variance_scaling_initializer(
- scale=1.0, mode="fan_avg", distribution="uniform", seed=seed, dtype=dtype)
+
+ def __init__(self,
+ seed=None,
+ dtype=dtypes.float32):
+ super(GlorotUniform, self).__init__(
+ scale=1.0,
+ mode="fan_avg",
+ distribution="uniform",
+ seed=seed,
+ dtype=dtype)
+
+ def get_config(self):
+ return {
+ "seed": self.seed,
+ "dtype": self.dtype.name
+ }
@tf_export("glorot_normal_initializer", "keras.initializers.glorot_normal",
"initializers.glorot_normal")
-def glorot_normal_initializer(seed=None, dtype=dtypes.float32):
+class GlorotNormal(VarianceScaling):
"""The Glorot normal initializer, also called Xavier normal initializer.
It draws samples from a truncated normal distribution centered on 0
@@ -1178,16 +1164,45 @@ def glorot_normal_initializer(seed=None, dtype=dtypes.float32):
`tf.set_random_seed`
for behavior.
dtype: The data type. Only floating point types are supported.
-
- Returns:
- An initializer.
"""
- return variance_scaling_initializer(
- scale=1.0,
- mode="fan_avg",
- distribution="truncated_normal",
- seed=seed,
- dtype=dtype)
+
+ def __init__(self,
+ seed=None,
+ dtype=dtypes.float32):
+ super(GlorotNormal, self).__init__(
+ scale=1.0,
+ mode="fan_avg",
+ distribution="truncated_normal",
+ seed=seed,
+ dtype=dtype)
+
+ def get_config(self):
+ return {
+ "seed": self.seed,
+ "dtype": self.dtype.name
+ }
+
+
+# Aliases.
+
+# pylint: disable=invalid-name
+zeros_initializer = Zeros
+ones_initializer = Ones
+constant_initializer = Constant
+random_uniform_initializer = RandomUniform
+random_normal_initializer = RandomNormal
+truncated_normal_initializer = TruncatedNormal
+uniform_unit_scaling_initializer = UniformUnitScaling
+variance_scaling_initializer = VarianceScaling
+glorot_uniform_initializer = GlorotUniform
+glorot_normal_initializer = GlorotNormal
+orthogonal_initializer = Orthogonal
+identity_initializer = Identity
+convolutional_delta_orthogonal = ConvolutionDeltaOrthogonal
+convolutional_orthogonal_1d = ConvolutionOrthogonal1D
+convolutional_orthogonal_2d = ConvolutionOrthogonal2D
+convolutional_orthogonal_3d = ConvolutionOrthogonal3D
+# pylint: enable=invalid-name
@tf_export("keras.initializers.lecun_normal", "initializers.lecun_normal")
diff --git a/tensorflow/python/ops/init_ops_test.py b/tensorflow/python/ops/init_ops_test.py
index f6fffa9079..5693c3caaf 100644
--- a/tensorflow/python/ops/init_ops_test.py
+++ b/tensorflow/python/ops/init_ops_test.py
@@ -20,10 +20,14 @@ from __future__ import print_function
import numpy as np
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -55,7 +59,7 @@ class InitializersTest(test.TestCase):
def test_uniform(self):
tensor_shape = (9, 6, 7)
- with self.test_session():
+ with self.cached_session():
self._runner(
init_ops.RandomUniform(minval=-1, maxval=1, seed=124),
tensor_shape,
@@ -65,7 +69,7 @@ class InitializersTest(test.TestCase):
def test_normal(self):
tensor_shape = (8, 12, 99)
- with self.test_session():
+ with self.cached_session():
self._runner(
init_ops.RandomNormal(mean=0, stddev=1, seed=153),
tensor_shape,
@@ -74,7 +78,7 @@ class InitializersTest(test.TestCase):
def test_truncated_normal(self):
tensor_shape = (12, 99, 7)
- with self.test_session():
+ with self.cached_session():
self._runner(
init_ops.TruncatedNormal(mean=0, stddev=1, seed=126),
tensor_shape,
@@ -84,7 +88,7 @@ class InitializersTest(test.TestCase):
def test_constant(self):
tensor_shape = (5, 6, 4)
- with self.test_session():
+ with self.cached_session():
self._runner(
init_ops.Constant(2),
tensor_shape,
@@ -94,7 +98,7 @@ class InitializersTest(test.TestCase):
def test_lecun_uniform(self):
tensor_shape = (5, 6, 4, 2)
- with self.test_session():
+ with self.cached_session():
fan_in, _ = init_ops._compute_fans(tensor_shape)
std = np.sqrt(1. / fan_in)
self._runner(
@@ -105,7 +109,7 @@ class InitializersTest(test.TestCase):
def test_glorot_uniform_initializer(self):
tensor_shape = (5, 6, 4, 2)
- with self.test_session():
+ with self.cached_session():
fan_in, fan_out = init_ops._compute_fans(tensor_shape)
std = np.sqrt(2. / (fan_in + fan_out))
self._runner(
@@ -116,7 +120,7 @@ class InitializersTest(test.TestCase):
def test_he_uniform(self):
tensor_shape = (5, 6, 4, 2)
- with self.test_session():
+ with self.cached_session():
fan_in, _ = init_ops._compute_fans(tensor_shape)
std = np.sqrt(2. / fan_in)
self._runner(
@@ -127,7 +131,7 @@ class InitializersTest(test.TestCase):
def test_lecun_normal(self):
tensor_shape = (5, 6, 4, 2)
- with self.test_session():
+ with self.cached_session():
fan_in, _ = init_ops._compute_fans(tensor_shape)
std = np.sqrt(1. / fan_in)
self._runner(
@@ -138,7 +142,7 @@ class InitializersTest(test.TestCase):
def test_glorot_normal_initializer(self):
tensor_shape = (5, 6, 4, 2)
- with self.test_session():
+ with self.cached_session():
fan_in, fan_out = init_ops._compute_fans(tensor_shape)
std = np.sqrt(2. / (fan_in + fan_out))
self._runner(
@@ -149,7 +153,7 @@ class InitializersTest(test.TestCase):
def test_he_normal(self):
tensor_shape = (5, 6, 4, 2)
- with self.test_session():
+ with self.cached_session():
fan_in, _ = init_ops._compute_fans(tensor_shape)
std = np.sqrt(2. / fan_in)
self._runner(
@@ -160,11 +164,45 @@ class InitializersTest(test.TestCase):
def test_Orthogonal(self):
tensor_shape = (20, 20)
- with self.test_session():
+ with self.cached_session():
self._runner(init_ops.Orthogonal(seed=123), tensor_shape, target_mean=0.)
+ def testVariablePlacementWithOrthogonalInitializer(self):
+ if not context.context().num_gpus():
+ self.skipTest('No devices other than CPUs found')
+ with ops.Graph().as_default() as g:
+ with ops.device('gpu:0'):
+ variable_scope.get_variable(
+ name='v', shape=[8, 2], initializer=init_ops.Orthogonal)
+ variable_scope.get_variable(
+ name='w', shape=[8, 2], initializer=init_ops.RandomNormal)
+ run_metadata = config_pb2.RunMetadata()
+ run_options = config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE)
+ config = config_pb2.ConfigProto(
+ allow_soft_placement=False, log_device_placement=True)
+
+ # Note: allow_soft_placement=False will fail whenever we cannot satisfy
+ # the colocation constraints.
+ with session.Session(config=config, graph=g) as sess:
+ sess.run(
+ variables.global_variables_initializer(),
+ options=run_options,
+ run_metadata=run_metadata)
+
+ def test_eager_orthogonal_gpu(self):
+ if not context.context().num_gpus():
+ self.skipTest('No devices other than CPUs found')
+ with context.eager_mode():
+ v = variable_scope.get_variable(
+ name='v', shape=[8, 2], initializer=init_ops.Orthogonal)
+ w = variable_scope.get_variable(
+ name='w', shape=[8, 2], initializer=init_ops.RandomNormal)
+ self.assertTrue('GPU' in v.handle.device)
+ self.assertTrue('GPU' in w.handle.device)
+
def test_Identity(self):
- with self.test_session():
+ with self.cached_session():
tensor_shape = (3, 4, 5)
with self.assertRaises(ValueError):
self._runner(
@@ -182,13 +220,13 @@ class InitializersTest(test.TestCase):
def test_Zeros(self):
tensor_shape = (4, 5)
- with self.test_session():
+ with self.cached_session():
self._runner(
init_ops.Zeros(), tensor_shape, target_mean=0., target_max=0.)
def test_Ones(self):
tensor_shape = (4, 5)
- with self.test_session():
+ with self.cached_session():
self._runner(init_ops.Ones(), tensor_shape, target_mean=1., target_max=1.)
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index b5274ef2ed..f84785df2c 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -16,7 +16,8 @@
# pylint: disable=line-too-long
"""Inputs and Readers.
-See the @{$python/io_ops} guide.
+See the [Inputs and
+Readers](https://tensorflow.org/api_guides/python/io_ops) guide.
"""
from __future__ import absolute_import
@@ -32,8 +33,9 @@ from tensorflow.python.ops import gen_io_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_io_ops import *
-from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
+from tensorflow.python.util import deprecation
+from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access
@@ -94,7 +96,7 @@ def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type,
preferred_shard, name=name)
-@tf_export("ReaderBase")
+@tf_export(v1=["ReaderBase"])
class ReaderBase(object):
"""Base class for different Reader types, that produce a record every step.
@@ -308,7 +310,7 @@ ops.NotDifferentiable("ReaderRestoreState")
ops.NotDifferentiable("ReaderReset")
-@tf_export("WholeFileReader")
+@tf_export(v1=["WholeFileReader"])
class WholeFileReader(ReaderBase):
"""A Reader that outputs the entire contents of a file as a value.
@@ -323,6 +325,9 @@ class WholeFileReader(ReaderBase):
@end_compatibility
"""
+ @deprecation.deprecated(
+ None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+ "`tf.data.Dataset.map(tf.read_file)`.")
def __init__(self, name=None):
"""Create a WholeFileReader.
@@ -336,7 +341,7 @@ class WholeFileReader(ReaderBase):
ops.NotDifferentiable("WholeFileReader")
-@tf_export("TextLineReader")
+@tf_export(v1=["TextLineReader"])
class TextLineReader(ReaderBase):
"""A Reader that outputs the lines of a file delimited by newlines.
@@ -350,6 +355,9 @@ class TextLineReader(ReaderBase):
"""
# TODO(josh11b): Support serializing and restoring state.
+ @deprecation.deprecated(
+ None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+ "`tf.data.TextLineDataset`.")
def __init__(self, skip_header_lines=None, name=None):
"""Create a TextLineReader.
@@ -366,7 +374,7 @@ class TextLineReader(ReaderBase):
ops.NotDifferentiable("TextLineReader")
-@tf_export("FixedLengthRecordReader")
+@tf_export(v1=["FixedLengthRecordReader"])
class FixedLengthRecordReader(ReaderBase):
"""A Reader that outputs fixed-length records from a file.
@@ -379,6 +387,9 @@ class FixedLengthRecordReader(ReaderBase):
"""
# TODO(josh11b): Support serializing and restoring state.
+ @deprecation.deprecated(
+ None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+ "`tf.data.FixedLengthRecordDataset`.")
def __init__(self,
record_bytes,
header_bytes=None,
@@ -409,7 +420,7 @@ class FixedLengthRecordReader(ReaderBase):
ops.NotDifferentiable("FixedLengthRecordReader")
-@tf_export("TFRecordReader")
+@tf_export(v1=["TFRecordReader"])
class TFRecordReader(ReaderBase):
"""A Reader that outputs the records from a TFRecords file.
@@ -422,6 +433,9 @@ class TFRecordReader(ReaderBase):
"""
# TODO(josh11b): Support serializing and restoring state.
+ @deprecation.deprecated(
+ None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+ "`tf.data.TFRecordDataset`.")
def __init__(self, name=None, options=None):
"""Create a TFRecordReader.
@@ -440,7 +454,7 @@ class TFRecordReader(ReaderBase):
ops.NotDifferentiable("TFRecordReader")
-@tf_export("LMDBReader")
+@tf_export(v1=["LMDBReader"])
class LMDBReader(ReaderBase):
"""A Reader that outputs the records from a LMDB file.
@@ -451,6 +465,10 @@ class LMDBReader(ReaderBase):
use `tf.data` to get data into your model.
@end_compatibility
"""
+
+ @deprecation.deprecated(
+ None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+ "`tf.contrib.data.LMDBDataset`.")
def __init__(self, name=None, options=None):
"""Create a LMDBReader.
@@ -458,6 +476,7 @@ class LMDBReader(ReaderBase):
name: A name for the operation (optional).
options: A LMDBRecordOptions object (optional).
"""
+ del options
rr = gen_io_ops.lmdb_reader(name=name)
super(LMDBReader, self).__init__(rr)
@@ -465,7 +484,7 @@ class LMDBReader(ReaderBase):
ops.NotDifferentiable("LMDBReader")
-@tf_export("IdentityReader")
+@tf_export(v1=["IdentityReader"])
class IdentityReader(ReaderBase):
"""A Reader that outputs the queued work as both the key and value.
@@ -480,6 +499,9 @@ class IdentityReader(ReaderBase):
@end_compatibility
"""
+ @deprecation.deprecated(
+ None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+ "`tf.data.Dataset.map(...)`.")
def __init__(self, name=None):
"""Create a IdentityReader.
diff --git a/tensorflow/python/ops/list_ops.py b/tensorflow/python/ops/list_ops.py
index d9ede87530..145a5f358c 100644
--- a/tensorflow/python/ops/list_ops.py
+++ b/tensorflow/python/ops/list_ops.py
@@ -97,3 +97,18 @@ def _TensorListSetItemGrad(op, dlist):
element_grad = gen_list_ops.tensor_list_get_item(
dlist, index, element_dtype=item.dtype)
return list_grad, index_grad, element_grad
+
+
+@ops.RegisterGradient("TensorListGather")
+def _TensorListGatherGrad(op, dtensor):
+ _, indices = op.inputs
+ return gen_list_ops.tensor_list_scatter(
+ tensor=dtensor, indices=indices,
+ element_shape=ops.convert_to_tensor(-1, dtype=dtypes.int32)), None
+
+
+@ops.RegisterGradient("TensorListScatter")
+def _TensorListScatterGrad(op, dlist):
+ t, indices, _ = op.inputs
+ return gen_list_ops.tensor_list_gather(
+ dlist, indices, element_dtype=t.dtype), None
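The two gradient registrations above rely on gather and scatter being adjoint: the gradient of reading elements at `indices` is scattering the incoming gradients back to those same `indices`, and vice versa. An illustrative NumPy sketch of that pairing (plain arrays, not TensorFlow list ops):

    import numpy as np

    params = np.arange(12.0).reshape(4, 3)
    indices = np.array([2, 0])

    gathered = params[indices]                  # forward "gather"

    upstream = np.ones_like(gathered)           # incoming gradients
    grad_params = np.zeros_like(params)
    np.add.at(grad_params, indices, upstream)   # backward "scatter"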
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
index fb51fbc626..561a341cf3 100644
--- a/tensorflow/python/ops/lookup_ops.py
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -22,6 +22,7 @@ import collections
import functools
import six
+from tensorflow.python.compat import compat as fwd_compat
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -299,6 +300,7 @@ class HashTable(InitializableLookupTableBase):
self._value_shape))
return exported_keys, exported_values
+
class TableInitializerBase(object):
"""Base class for lookup table initializers."""
@@ -370,8 +372,13 @@ class KeyValueTensorInitializer(TableInitializerBase):
# Ensure a unique name when eager execution is enabled to avoid spurious
# sharing issues.
scope += str(ops.uid())
- init_op = gen_lookup_ops.initialize_table_v2(
- table.table_ref, self._keys, self._values, name=scope)
+ if fwd_compat.forward_compatible(2018, 9, 19):
+ init_op = gen_lookup_ops.lookup_table_import_v2(
+ table.table_ref, self._keys, self._values, name=scope)
+ else:
+ # To maintain forward compatibility, use the old implementation.
+ init_op = gen_lookup_ops.initialize_table_v2(
+ table.table_ref, self._keys, self._values, name=scope)
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
return init_op
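The gate added here uses TensorFlow's forward-compatibility window: `forward_compatible(2018, 9, 19)` stays `False` until the horizon passes that date, so newly built binaries keep emitting the old `initialize_table_v2` op long enough for consumers built before the cutover to load their graphs. A minimal sketch of the gating pattern:

    from tensorflow.python.compat import compat

    if compat.forward_compatible(2018, 9, 19):
      op_name = "LookupTableImportV2"   # new op: consumers now understand it
    else:
      op_name = "InitializeTableV2"     # old op: kept during the window
    print("emitting", op_name)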
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 2a7a2fd51f..8e11c4bce1 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -972,9 +972,9 @@ def _RealDivGrad(op, grad):
grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), sy))
-@ops.RegisterGradient("UnsafeDiv")
-def _UnsafeDivGrad(op, grad):
- """UnsafeDiv op gradient."""
+@ops.RegisterGradient("DivNoNan")
+def _DivNoNanGrad(op, grad):
+ """DivNoNan op gradient."""
x = op.inputs[0]
y = op.inputs[1]
sx = array_ops.shape(x)
@@ -983,10 +983,10 @@ def _UnsafeDivGrad(op, grad):
x = math_ops.conj(x)
y = math_ops.conj(y)
return (array_ops.reshape(
- math_ops.reduce_sum(math_ops.unsafe_div(grad, y), rx), sx),
+ math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx),
array_ops.reshape(
math_ops.reduce_sum(
- grad * math_ops.unsafe_div(math_ops.unsafe_div(-x, y), y),
+ grad * math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y),
ry), sy))
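Away from y == 0, the registered gradient above reduces elementwise to the usual quotient-rule partials for z = x / y, namely dz/dx = 1/y and dz/dy = -x/y^2; wherever y == 0, `div_no_nan` makes both partials exactly 0 (the test change further down checks precisely that). A hand check at one point:

    # Hand check of the DivNoNan partials at (x, y) = (6.0, 2.0).
    x, y = 6.0, 2.0
    dz_dx = 1.0 / y          # 0.5
    dz_dy = -x / (y * y)     # -1.5
    print(dz_dx, dz_dy)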
diff --git a/tensorflow/python/ops/math_grad_test.py b/tensorflow/python/ops/math_grad_test.py
index f9bb60e7fe..7110e0958c 100644
--- a/tensorflow/python/ops/math_grad_test.py
+++ b/tensorflow/python/ops/math_grad_test.py
@@ -102,14 +102,14 @@ class MinOrMaxGradientTest(test.TestCase):
def testMinGradient(self):
inputs = constant_op.constant([1.0], dtype=dtypes.float32)
outputs = math_ops.reduce_min(array_ops.concat([inputs, inputs], 0))
- with self.test_session():
+ with self.cached_session():
error = gradient_checker.compute_gradient_error(inputs, [1], outputs, [])
self.assertLess(error, 1e-4)
def testMaxGradient(self):
inputs = constant_op.constant([1.0], dtype=dtypes.float32)
outputs = math_ops.reduce_max(array_ops.concat([inputs, inputs], 0))
- with self.test_session():
+ with self.cached_session():
error = gradient_checker.compute_gradient_error(inputs, [1], outputs, [])
self.assertLess(error, 1e-4)
@@ -119,14 +119,14 @@ class MaximumOrMinimumGradientTest(test.TestCase):
def testMaximumGradient(self):
inputs = constant_op.constant([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32)
outputs = math_ops.maximum(inputs, 3.0)
- with self.test_session():
+ with self.cached_session():
error = gradient_checker.compute_gradient_error(inputs, [4], outputs, [4])
self.assertLess(error, 1e-4)
def testMinimumGradient(self):
inputs = constant_op.constant([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32)
outputs = math_ops.minimum(inputs, 2.0)
- with self.test_session():
+ with self.cached_session():
error = gradient_checker.compute_gradient_error(inputs, [4], outputs, [4])
self.assertLess(error, 1e-4)
@@ -137,7 +137,7 @@ class ProdGradientTest(test.TestCase):
inputs = constant_op.constant([[1., 2.], [3., 4.]],
dtype=dtypes.float32)
outputs = math_ops.reduce_prod(inputs)
- with self.test_session():
+ with self.cached_session():
error = gradient_checker.compute_gradient_error(
inputs, inputs.get_shape().as_list(),
outputs, outputs.get_shape().as_list())
@@ -147,7 +147,7 @@ class ProdGradientTest(test.TestCase):
inputs = constant_op.constant([[1., 2.], [3., 4.]],
dtype=dtypes.float32)
outputs = math_ops.reduce_prod(inputs, -1)
- with self.test_session():
+ with self.cached_session():
error = gradient_checker.compute_gradient_error(
inputs, inputs.get_shape().as_list(),
outputs, outputs.get_shape().as_list())
@@ -158,7 +158,7 @@ class ProdGradientTest(test.TestCase):
inputs = constant_op.constant([[1 + 3j, 2 - 1j], [3j, 4]],
dtype=dtype)
outputs = math_ops.reduce_prod(inputs)
- with self.test_session():
+ with self.cached_session():
error = gradient_checker.compute_gradient_error(
inputs, inputs.get_shape().as_list(),
outputs, outputs.get_shape().as_list())
@@ -169,7 +169,7 @@ class ProdGradientTest(test.TestCase):
inputs = constant_op.constant([[1 + 3j, 2 - 1j], [3j, 4]],
dtype=dtype)
outputs = math_ops.reduce_prod(inputs, -1)
- with self.test_session():
+ with self.cached_session():
error = gradient_checker.compute_gradient_error(
inputs, inputs.get_shape().as_list(),
outputs, outputs.get_shape().as_list())
@@ -182,7 +182,7 @@ class SegmentMinOrMaxGradientTest(test.TestCase):
data = constant_op.constant([1.0, 2.0, 3.0], dtype=dtypes.float32)
segment_ids = constant_op.constant([0, 0, 1], dtype=dtypes.int64)
segment_min = math_ops.segment_min(data, segment_ids)
- with self.test_session():
+ with self.cached_session():
error = gradient_checker.compute_gradient_error(data, [3], segment_min,
[2])
self.assertLess(error, 1e-4)
@@ -191,7 +191,7 @@ class SegmentMinOrMaxGradientTest(test.TestCase):
data = constant_op.constant([1.0, 2.0, 3.0], dtype=dtypes.float32)
segment_ids = constant_op.constant([0, 0, 1], dtype=dtypes.int64)
segment_max = math_ops.segment_max(data, segment_ids)
- with self.test_session():
+ with self.cached_session():
error = gradient_checker.compute_gradient_error(data, [3], segment_max,
[2])
self.assertLess(error, 1e-4)
@@ -201,7 +201,7 @@ class SegmentMinOrMaxGradientTest(test.TestCase):
data = array_ops.concat([inputs, inputs], 0)
segment_ids = constant_op.constant([0, 0], dtype=dtypes.int64)
segment_min = math_ops.segment_min(data, segment_ids)
- with self.test_session():
+ with self.cached_session():
error = gradient_checker.compute_gradient_error(inputs, [1], segment_min,
[1])
self.assertLess(error, 1e-4)
@@ -211,7 +211,7 @@ class SegmentMinOrMaxGradientTest(test.TestCase):
data = array_ops.concat([inputs, inputs], 0)
segment_ids = constant_op.constant([0, 0], dtype=dtypes.int64)
segment_max = math_ops.segment_max(data, segment_ids)
- with self.test_session():
+ with self.cached_session():
error = gradient_checker.compute_gradient_error(inputs, [1], segment_max,
[1])
self.assertLess(error, 1e-4)
@@ -225,18 +225,19 @@ class FloorModGradientTest(test.TestCase):
ns = constant_op.constant([17.], dtype=dtypes.float32)
inputs = constant_op.constant([131.], dtype=dtypes.float32)
floor_mod = math_ops.floormod(inputs, ns)
- with self.test_session():
+ with self.cached_session():
error = gradient_checker.compute_gradient_error(inputs, [1],
floor_mod, [1])
self.assertLess(error, 1e-4)
-class UnsafeDivGradientTest(test.TestCase):
+class DivNoNanGradientTest(test.TestCase):
def testBasicGradient(self):
- inputs = constant_op.constant(np.arange(-3, 3), dtype=dtypes.float32)
- outputs = math_ops.unsafe_div(inputs, 1 + math_ops.abs(inputs))
- with self.test_session():
+ inputs = constant_op.constant(np.arange(-3, 3),
+ dtype=dtypes.float32)
+ outputs = math_ops.div_no_nan(inputs, 1 + math_ops.abs(inputs))
+ with self.cached_session():
error = gradient_checker.compute_gradient_error(
inputs,
inputs.get_shape().as_list(), outputs,
@@ -244,10 +245,12 @@ class UnsafeDivGradientTest(test.TestCase):
self.assertLess(error, 1e-4)
def testGradientWithDenominatorIsZero(self):
- x = constant_op.constant(np.arange(-3, 3), dtype=dtypes.float32)
- y = array_ops.zeros_like(x, dtype=dtypes.float32)
- outputs = math_ops.unsafe_div(x, y)
- with self.test_session():
+ x = constant_op.constant(np.arange(-3, 3),
+ dtype=dtypes.float32)
+ y = array_ops.zeros_like(x,
+ dtype=dtypes.float32)
+ outputs = math_ops.div_no_nan(x, y)
+ with self.cached_session():
dx, dy = gradients.gradients(outputs, [x, y])
self.assertAllClose(dx.eval(), np.zeros(x.shape.as_list()))
self.assertAllClose(dy.eval(), np.zeros(y.shape.as_list()))
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index c9da1a0bba..33e7a5533b 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""Basic arithmetic operators.
-See the @{$python/math_ops} guide.
+See the [Math](https://tensorflow.org/api_guides/python/math_ops) guide.
"""
from __future__ import absolute_import
from __future__ import division
@@ -618,7 +618,7 @@ def cast(x, dtype, name=None):
"""Casts a tensor to a new type.
The operation casts `x` (in case of `Tensor`) or `x.values`
- (in case of `SparseTensor`) to `dtype`.
+ (in case of `SparseTensor` or `IndexedSlices`) to `dtype`.
For example:
@@ -637,15 +637,16 @@ def cast(x, dtype, name=None):
behavior of numpy.
Args:
- x: A `Tensor` or `SparseTensor` of numeric type. It could be
- `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, `int64`,
- `float16`, `float32`, `float64`, `complex64`, `complex128`, `bfloat16`.
- dtype: The destination type. The list of supported dtypes is the same
- as `x`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices` of numeric type. It could
+ be `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`,
+ `int64`, `float16`, `float32`, `float64`, `complex64`, `complex128`,
+ `bfloat16`.
+ dtype: The destination type. The list of supported dtypes is the same as
+ `x`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` and
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` and
same type as `dtype`.
Raises:
@@ -659,6 +660,9 @@ def cast(x, dtype, name=None):
if isinstance(x, sparse_tensor.SparseTensor):
values_cast = cast(x.values, base_type, name=name)
x = sparse_tensor.SparseTensor(x.indices, values_cast, x.dense_shape)
+ elif isinstance(x, ops.IndexedSlices):
+ values_cast = cast(x.values, base_type, name=name)
+ x = ops.IndexedSlices(values_cast, x.indices, x.dense_shape)
else:
# TODO(josh11b): If x is not already a Tensor, we could return
# ops.convert_to_tensor(x, dtype=dtype, ...) here, but that
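The new `elif` branch lets `cast` accept an `IndexedSlices` by casting only its `values` and rebuilding the slices around them. A short sketch of the resulting behavior (TF 1.x style):

    import tensorflow as tf

    slices = tf.IndexedSlices(
        values=tf.constant([[1.5, 2.5]], dtype=tf.float64),
        indices=tf.constant([0], dtype=tf.int64))
    casted = tf.cast(slices, tf.float32)
    print(type(casted).__name__, casted.values.dtype)  # IndexedSlices, float32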
@@ -711,11 +715,12 @@ def to_float(x, name="ToFloat"):
"""Casts a tensor to type `float32`.
Args:
- x: A `Tensor` or `SparseTensor`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` with type `float32`.
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
+ type `float32`.
Raises:
TypeError: If `x` cannot be cast to the `float32`.
@@ -728,11 +733,12 @@ def to_double(x, name="ToDouble"):
"""Casts a tensor to type `float64`.
Args:
- x: A `Tensor` or `SparseTensor`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` with type `float64`.
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
+ type `float64`.
Raises:
TypeError: If `x` cannot be cast to the `float64`.
@@ -745,11 +751,12 @@ def to_int32(x, name="ToInt32"):
"""Casts a tensor to type `int32`.
Args:
- x: A `Tensor` or `SparseTensor`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` with type `int32`.
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
+ type `int32`.
Raises:
TypeError: If `x` cannot be cast to the `int32`.
@@ -762,11 +769,12 @@ def to_int64(x, name="ToInt64"):
"""Casts a tensor to type `int64`.
Args:
- x: A `Tensor` or `SparseTensor`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` with type `int64`.
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
+ type `int64`.
Raises:
TypeError: If `x` cannot be cast to the `int64`.
@@ -779,11 +787,12 @@ def to_bfloat16(x, name="ToBFloat16"):
"""Casts a tensor to type `bfloat16`.
Args:
- x: A `Tensor` or `SparseTensor`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` with type `bfloat16`.
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
+ type `bfloat16`.
Raises:
TypeError: If `x` cannot be cast to the `bfloat16`.
@@ -796,11 +805,12 @@ def to_complex64(x, name="ToComplex64"):
"""Casts a tensor to type `complex64`.
Args:
- x: A `Tensor` or `SparseTensor`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` with type `complex64`.
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
+ type `complex64`.
Raises:
TypeError: If `x` cannot be cast to the `complex64`.
@@ -813,11 +823,12 @@ def to_complex128(x, name="ToComplex128"):
"""Casts a tensor to type `complex128`.
Args:
- x: A `Tensor` or `SparseTensor`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` with type `complex128`.
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
+ type `complex128`.
Raises:
TypeError: If `x` cannot be cast to the `complex128`.
@@ -1038,29 +1049,27 @@ def div(x, y, name=None):
return _div_python2(x, y, name)
-def unsafe_div(x, y, name=None):
+@tf_export("div_no_nan")
+def div_no_nan(x, y, name=None):
"""Computes an unsafe divide which returns 0 if the y is zero.
- Note that the function uses Python 3 division operator semantics.
-
Args:
- x: A `Tensor`. Must be one of the following types:
- `float32`, `float64`, `int16`, `int32`, `int64`.
+ x: A `Tensor`. Must be one of the following types: `float32`, `float64`.
y: A `Tensor` whose dtype is compatible with `x`.
name: A name for the operation (optional).
Returns:
The element-wise value of the x divided by y.
"""
- with ops.name_scope(name, "unsafe_div", [x, y]) as name:
+ with ops.name_scope(name, "div_no_nan", [x, y]) as name:
x = ops.convert_to_tensor(x, name="x")
y = ops.convert_to_tensor(y, name="y", dtype=x.dtype.base_dtype)
x_dtype = x.dtype.base_dtype
y_dtype = y.dtype.base_dtype
if x_dtype != y_dtype:
- raise TypeError(
- "x and y must have the same dtype, got %r != %r" % (x_dtype, y_dtype))
- return gen_math_ops.unsafe_div(x, y, name=name)
+ raise TypeError("x and y must have the same dtype, got %r != %r" %
+ (x_dtype, y_dtype))
+ return gen_math_ops.div_no_nan(x, y, name=name)
# TODO(aselle): This should be removed
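With the `@tf_export("div_no_nan")` decorator above, the op is now reachable as `tf.div_no_nan`. A quick sketch of the zero-denominator behavior it documents:

    import tensorflow as tf

    x = tf.constant([1.0, 2.0, 3.0])
    y = tf.constant([2.0, 0.0, 4.0])
    with tf.Session() as sess:
      print(sess.run(tf.div_no_nan(x, y)))  # [0.5  0.    0.75]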
@@ -2562,27 +2571,37 @@ def _unsorted_segment_N(data, segment_ids, num_segments):
@tf_export("unsorted_segment_mean")
def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
- r""" Computes the mean along segments of a tensor.
+ r"""Computes the mean along segments of a tensor.
- Read @{$math_ops#segmentation$the section on segmentation} for an explanation
- of segments.
+ Read [the section on
+ segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+ for an explanation of segments.
This operator is similar to the unsorted segment sum operator found
[here](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
Instead of computing the sum over segments, it computes the mean of all
entries belonging to a segment such that:
- \\(output_i = 1/N_i \sum data_j\\) where the sum is over `j` such
- that `segment_ids[j] == i` with \\N_i\\ being the number of occurrences
- of id \\i\\.
+ \\(output_i = 1/N_i \sum_{j...} data[j...]\\) where the sum is over tuples
+ `j...` such that `segment_ids[j...] == i` with \\(N_i\\) being the number of
+ occurrences of id \\(i\\).
If there is no entry for a given segment ID `i`, it outputs 0.
- segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
- first dimension.
+ If the given segment ID `i` is negative, the value is dropped and will not
+ be added to the sum of the segment.
- output: Has same shape as data, except for dimension 0 which
- has size `num_segments`.
+ Args:
+ data: A `Tensor` with floating point or complex dtype.
+ segment_ids: An integer tensor whose shape is a prefix of `data.shape`.
+ num_segments: An integer scalar `Tensor`. The number of distinct
+ segment IDs.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor`. Has same shape as data, except for the first `segment_ids.rank`
+ dimensions, which are replaced with a single dimension which has size
+ `num_segments`.
"""
with ops.name_scope(name, "UnsortedSegmentMean"):
data = ops.convert_to_tensor(data)
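The rewritten docstring generalizes `segment_ids` from a 1-D vector to any prefix of `data.shape`. A small NumPy check of the 1-D case of the formula \\(output_i = 1/N_i \sum_{j} data[j]\\), including the documented 0 for an empty segment:

    import numpy as np

    data = np.array([1.0, 2.0, 3.0, 4.0])
    segment_ids = np.array([0, 0, 1, 1])
    num_segments = 3

    sums = np.zeros(num_segments)
    counts = np.zeros(num_segments)
    np.add.at(sums, segment_ids, data)
    np.add.at(counts, segment_ids, 1.0)
    means = np.where(counts > 0, sums / np.maximum(counts, 1.0), 0.0)
    print(means)  # [1.5 3.5 0. ] -- empty segment 2 outputs 0, as documented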
@@ -2596,28 +2615,38 @@ def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None):
r"""Computes the sum along segments of a tensor divided by the sqrt(N).
- Read @{$math_ops#segmentation$the section on segmentation} for an explanation
- of segments.
+ Read [the section on
+ segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+ for an explanation of segments.
This operator is similar to the unsorted segment sum operator found
[here](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
Additionally to computing the sum over segments, it divides the results by
sqrt(N).
- \\(output_i = 1/sqrt(N_i) \sum data_j\\) where the sum is over `j` such
- that `segment_ids[j] == i` with \\N_i\\ being the number of occurrences
- of id \\i\\.
+ \\(output_i = 1/sqrt(N_i) \sum_{j...} data[j...]\\) where the sum is over
+ tuples `j...` such that `segment_ids[j...] == i` with \\(N_i\\) being the
+ number of occurrences of id \\(i\\).
If there is no entry for a given segment ID `i`, it outputs 0.
Note that this op only supports floating point and complex dtypes,
due to tf.sqrt only supporting these types.
- segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
- first dimension.
+ If the given segment ID `i` is negative, the value is dropped and will not
+ be added to the sum of the segment.
+
+ Args:
+ data: A `Tensor` with floating point or complex dtype.
+ segment_ids: An integer tensor whose shape is a prefix of `data.shape`.
+ num_segments: An integer scalar `Tensor`. The number of distinct
+ segment IDs.
+ name: A name for the operation (optional).
- output: Has same shape as data, except for dimension 0 which
- has size `num_segments`.
+ Returns:
+ A `Tensor`. Has same shape as data, except for the first `segment_ids.rank`
+ dimensions, which are replaced with a single dimension which has size
+ `num_segments`.
"""
with ops.name_scope(name, "UnsortedSegmentSqrtN"):
data = ops.convert_to_tensor(data)
@@ -2632,8 +2661,9 @@ def sparse_segment_sum(data, indices, segment_ids, name=None,
num_segments=None):
r"""Computes the sum along sparse segments of a tensor.
- Read @{$math_ops#Segmentation$the section on segmentation} for an explanation
- of segments.
+ Read [the section on
+ segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+ for an explanation of segments.
Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first
dimension, selecting a subset of dimension 0, specified by `indices`.
@@ -2707,8 +2737,9 @@ def sparse_segment_mean(data,
num_segments=None):
r"""Computes the mean along sparse segments of a tensor.
- Read @{$math_ops#Segmentation$the section on segmentation} for an explanation
- of segments.
+ Read [the section on
+ segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+ for an explanation of segments.
Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
dimension, selecting a subset of dimension 0, specified by `indices`.
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 5fe7bbca11..1b01d1d37f 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -373,7 +373,7 @@ class DivAndModTest(test_util.TensorFlowTestCase):
def testFloorModInt(self):
nums, divs = self.intTestData()
- with self.test_session():
+ with self.cached_session():
# TODO(aselle): Change test to use % after switch
# tf_result = math_ops.floor_mod(nums, divs).eval()
tf_result = math_ops.floormod(nums, divs).eval()
@@ -382,7 +382,7 @@ class DivAndModTest(test_util.TensorFlowTestCase):
def testFloorModFloat(self):
nums, divs = self.floatTestData()
- with self.test_session():
+ with self.cached_session():
tf_result = math_ops.floormod(nums, divs).eval()
np_result = nums % divs
self.assertAllEqual(tf_result, np_result)
@@ -393,21 +393,21 @@ class DivAndModTest(test_util.TensorFlowTestCase):
def testTruncateModInt(self):
nums, divs = self.intTestData()
- with self.test_session():
+ with self.cached_session():
tf_result = math_ops.truncatemod(nums, divs).eval()
np_result = np.fmod(nums, divs)
self.assertAllEqual(tf_result, np_result)
def testTruncateModFloat(self):
nums, divs = self.floatTestData()
- with self.test_session():
+ with self.cached_session():
tf_result = math_ops.truncatemod(nums, divs).eval()
np_result = np.fmod(nums, divs)
self.assertAllEqual(tf_result, np_result)
def testDivideInt(self):
nums, divs = self.intTestData()
- with self.test_session():
+ with self.cached_session():
tf_result = math_ops.floor_div(nums, divs).eval()
np_result = nums // divs
self.assertAllEqual(tf_result, np_result)
@@ -417,29 +417,29 @@ class DivAndModTest(test_util.TensorFlowTestCase):
# self.assertAllEqual(tf2_result, tf_result)
def testDivideName(self):
- with self.test_session():
+ with self.cached_session():
op = math_ops.divide(
array_ops.constant(3), array_ops.constant(4), name="my_cool_divide")
self.assertEqual(op.name, "my_cool_divide:0")
def testRealDiv(self):
nums, divs = self.floatTestData()
- with self.test_session():
+ with self.cached_session():
tf_result = math_ops.realdiv(nums, divs).eval()
np_result = np.divide(nums, divs)
self.assertAllEqual(tf_result, np_result)
def testComplexDiv(self):
foo = array_ops.constant([1. + 3.j])
- with self.test_session():
+ with self.cached_session():
_ = math_ops.divide(foo, 1.).eval()
_ = math_ops.div(foo, 2.).eval()
def testFloorDivGrad(self):
- with self.test_session():
+ with self.cached_session():
a = variables.Variable(2.)
b = variables.Variable(4.)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
c_grad = gradients.gradients(math_ops.divide(a, b), [a, b])
self.assertAllEqual([x.eval() for x in c_grad], [.25, -.125])
@@ -451,7 +451,7 @@ class DivAndModTest(test_util.TensorFlowTestCase):
def testConsistent(self):
nums, divs = self.intTestData()
- with self.test_session():
+ with self.cached_session():
tf_result = (math_ops.floor_div(nums, divs) * divs + math_ops.floormod(
nums, divs)).eval()
tf_nums = array_ops.constant(nums)
@@ -473,18 +473,19 @@ class DivAndModTest(test_util.TensorFlowTestCase):
self.assertAllEqual(tf_result, expanded_nums)
-class UnsafeDivTest(test_util.TensorFlowTestCase):
+class DivNoNanTest(test_util.TensorFlowTestCase):
def testBasic(self):
- nums = np.arange(-10, 10, .25).reshape(80, 1)
- divs = np.arange(-3, 3, .25).reshape(1, 24)
+ for dtype in [np.float32, np.float64]:
+ nums = np.arange(-10, 10, .25, dtype=dtype).reshape(80, 1)
+ divs = np.arange(-3, 3, .25, dtype=dtype).reshape(1, 24)
- np_result = np.true_divide(nums, divs)
- np_result[:, divs[0] == 0] = 0
+ np_result = np.true_divide(nums, divs)
+ np_result[:, divs[0] == 0] = 0
- with self.test_session():
- tf_result = math_ops.unsafe_div(nums, divs).eval()
- self.assertAllEqual(tf_result, np_result)
+ with self.cached_session(use_gpu=True):
+ tf_result = math_ops.div_no_nan(nums, divs).eval()
+ self.assertAllEqual(tf_result, np_result)
if __name__ == "__main__":
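For reference, a minimal sketch of the renamed op's contract (assuming a build
where `math_ops.div_no_nan` is exported as `tf.div_no_nan`): the quotient is
forced to 0 wherever the divisor is 0, rather than producing inf/nan.

    import tensorflow as tf

    nums = tf.constant([1.0, -2.0, 3.0])
    divs = tf.constant([2.0, 0.0, 0.0])
    safe = tf.div_no_nan(nums, divs)  # 0 where divs == 0
    with tf.Session() as sess:
        print(sess.run(safe))  # [0.5 0.  0. ]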
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 9461a01515..763877c2d2 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -301,6 +301,40 @@ def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None):
return total_cm, update_op
+def _aggregate_across_towers(metrics_collections, metric_value_fn, *args):
+ """Aggregate metric value across towers."""
+ def fn(distribution, *a):
+ """Call `metric_value_fn` in the correct control flow context."""
+ if hasattr(distribution, '_outer_control_flow_context'):
+ # If there was an outer context captured before this method was called,
+ # then we enter that context to create the metric value op. If the
+      # captured context is `None`, ops.control_dependencies(None) gives the
+ # desired behavior. Else we use `Enter` and `Exit` to enter and exit the
+ # captured context.
+ # This special handling is needed because sometimes the metric is created
+ # inside a while_loop (and perhaps a TPU rewrite context). But we don't
+ # want the value op to be evaluated every step or on the TPU. So we
+ # create it outside so that it can be evaluated at the end on the host,
+      # once the update ops have been evaluated.
+
+ # pylint: disable=protected-access
+ if distribution._outer_control_flow_context is None:
+ with ops.control_dependencies(None):
+ metric_value = metric_value_fn(distribution, *a)
+ else:
+ distribution._outer_control_flow_context.Enter()
+ metric_value = metric_value_fn(distribution, *a)
+ distribution._outer_control_flow_context.Exit()
+ # pylint: enable=protected-access
+ else:
+ metric_value = metric_value_fn(distribution, *a)
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, metric_value)
+ return metric_value
+
+ return distribution_strategy_context.get_tower_context().merge_call(fn, *args)
+
+
@tf_export('metrics.mean')
def mean(values,
weights=None,
@@ -368,14 +402,10 @@ def mean(values,
with ops.control_dependencies([values]):
update_count_op = state_ops.assign_add(count, num_values)
- def aggregate_across_towers(_, t, c):
- mean_t = _safe_div(t, c, 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_t)
- return mean_t
+ compute_mean = lambda _, t, c: _safe_div(t, c, 'value')
- mean_t = distribution_strategy_context.get_tower_context().merge_call(
- aggregate_across_towers, total, count)
+ mean_t = _aggregate_across_towers(
+ metrics_collections, compute_mean, total, count)
update_op = _safe_div(update_total_op, update_count_op, 'update_op')
if updates_collections:
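For orientation, the usual way a metric's paired ops are consumed (a sketch;
`values`, `num_batches`, and `sess` are assumed to exist):

    value_op, update_op = tf.metrics.mean(values)
    sess.run(tf.local_variables_initializer())
    for _ in range(num_batches):
      sess.run(update_op)            # accumulates total and count per batch
    mean_value = sess.run(value_op)  # reads the aggregated mean once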
@@ -612,14 +642,8 @@ def _confusion_matrix_at_thresholds(labels,
def _aggregate_variable(v, collections):
-
- def f(distribution, value):
- value = distribution.read_var(value)
- if collections:
- ops.add_to_collections(collections, value)
- return value
-
- return distribution_strategy_context.get_tower_context().merge_call(f, v)
+ f = lambda distribution, value: distribution.read_var(value)
+ return _aggregate_across_towers(collections, f, v)
@tf_export('metrics.auc')
@@ -807,15 +831,12 @@ def auc(labels,
raise ValueError('Invalid summation_method: %s' % summation_method)
# sum up the areas of all the trapeziums
- def aggregate_auc(_, values):
- auc_value = compute_auc(values['tp'], values['fn'], values['tn'],
- values['fp'], 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, auc_value)
- return auc_value
-
- auc_value = distribution_strategy_context.get_tower_context().merge_call(
- aggregate_auc, values)
+ def compute_auc_value(_, values):
+ return compute_auc(values['tp'], values['fn'], values['tn'], values['fp'],
+ 'value')
+
+ auc_value = _aggregate_across_towers(
+ metrics_collections, compute_auc_value, values)
update_op = compute_auc(update_ops['tp'], update_ops['fn'],
update_ops['tn'], update_ops['fp'], 'update_op')
@@ -1046,16 +1067,14 @@ def mean_per_class_accuracy(labels,
update_total_op = state_ops.scatter_add(total, labels, ones)
update_count_op = state_ops.scatter_add(count, labels, is_correct)
- def aggregate_mean_accuracy(_, count, total):
+ def compute_mean_accuracy(_, count, total):
per_class_accuracy = _safe_div(count, total, None)
mean_accuracy_v = math_ops.reduce_mean(
per_class_accuracy, name='mean_accuracy')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_accuracy_v)
return mean_accuracy_v
- mean_accuracy_v = distribution_strategy_context.get_tower_context(
- ).merge_call(aggregate_mean_accuracy, count, total)
+ mean_accuracy_v = _aggregate_across_towers(
+ metrics_collections, compute_mean_accuracy, count, total)
update_op = _safe_div(update_count_op, update_total_op, name='update_op')
if updates_collections:
@@ -1128,7 +1147,7 @@ def mean_iou(labels,
total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
num_classes, weights)
- def compute_mean_iou(total_cm, name):
+ def compute_mean_iou(_, total_cm):
"""Compute the mean intersection-over-union via the confusion matrix."""
sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0))
sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1))
@@ -1152,17 +1171,12 @@ def mean_iou(labels,
# If the number of valid entries is 0 (no classes) we return 0.
result = array_ops.where(
math_ops.greater(num_valid_entries, 0),
- math_ops.reduce_sum(iou, name=name) / num_valid_entries, 0)
+ math_ops.reduce_sum(iou, name='mean_iou') / num_valid_entries, 0)
return result
- def mean_iou_across_towers(_, v):
- mean_iou_v = compute_mean_iou(v, 'mean_iou')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_iou_v)
- return mean_iou_v
-
- mean_iou_v = distribution_strategy_context.get_tower_context().merge_call(
- mean_iou_across_towers, total_cm)
+ # TODO(priyag): Use outside_compilation if in TPU context.
+ mean_iou_v = _aggregate_across_towers(
+ metrics_collections, compute_mean_iou, total_cm)
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -1371,14 +1385,10 @@ def mean_tensor(values,
with ops.control_dependencies([values]):
update_count_op = state_ops.assign_add(count, num_values)
- def aggregate_across_towers(_, t, c):
- mean_t = _safe_div(t, c, 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_t)
- return mean_t
+ compute_mean = lambda _, t, c: _safe_div(t, c, 'value')
- mean_t = distribution_strategy_context.get_tower_context().merge_call(
- aggregate_across_towers, total, count)
+ mean_t = _aggregate_across_towers(
+ metrics_collections, compute_mean, total, count)
update_op = _safe_div(update_total_op, update_count_op, 'update_op')
if updates_collections:
@@ -2004,13 +2014,10 @@ def precision(labels,
math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name)
def once_across_towers(_, true_p, false_p):
- p = compute_precision(true_p, false_p, 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, p)
- return p
+ return compute_precision(true_p, false_p, 'value')
- p = distribution_strategy_context.get_tower_context().merge_call(
- once_across_towers, true_p, false_p)
+ p = _aggregate_across_towers(metrics_collections, once_across_towers,
+ true_p, false_p)
update_op = compute_precision(true_positives_update_op,
false_positives_update_op, 'update_op')
@@ -2088,13 +2095,10 @@ def precision_at_thresholds(labels,
return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name)
def precision_across_towers(_, values):
- prec = compute_precision(values['tp'], values['fp'], 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, prec)
- return prec
+ return compute_precision(values['tp'], values['fp'], 'value')
- prec = distribution_strategy_context.get_tower_context().merge_call(
- precision_across_towers, values)
+ prec = _aggregate_across_towers(
+ metrics_collections, precision_across_towers, values)
update_op = compute_precision(update_ops['tp'], update_ops['fp'],
'update_op')
@@ -2184,13 +2188,10 @@ def recall(labels,
math_ops.div(true_p, true_p + false_n), 0, name)
def once_across_towers(_, true_p, false_n):
- rec = compute_recall(true_p, false_n, 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, rec)
- return rec
+ return compute_recall(true_p, false_n, 'value')
- rec = distribution_strategy_context.get_tower_context().merge_call(
- once_across_towers, true_p, false_n)
+ rec = _aggregate_across_towers(
+ metrics_collections, once_across_towers, true_p, false_n)
update_op = compute_recall(true_positives_update_op,
false_negatives_update_op, 'update_op')
@@ -2622,14 +2623,11 @@ def recall_at_top_k(labels,
class_id=class_id,
weights=weights)
- def aggregate_across_towers(_, tp, fn):
- metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
- if metrics_collections:
- ops.add_to_collections(metrics_collections, metric)
- return metric
+ def compute_recall(_, tp, fn):
+ return math_ops.div(tp, math_ops.add(tp, fn), name=scope)
- metric = distribution_strategy_context.get_tower_context().merge_call(
- aggregate_across_towers, tp, fn)
+ metric = _aggregate_across_towers(
+ metrics_collections, compute_recall, tp, fn)
update = math_ops.div(
tp_update, math_ops.add(tp_update, fn_update), name='update')
@@ -2704,13 +2702,10 @@ def recall_at_thresholds(labels,
return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
def recall_across_towers(_, values):
- rec = compute_recall(values['tp'], values['fn'], 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, rec)
- return rec
+ return compute_recall(values['tp'], values['fn'], 'value')
- rec = distribution_strategy_context.get_tower_context().merge_call(
- recall_across_towers, values)
+ rec = _aggregate_across_towers(
+ metrics_collections, recall_across_towers, values)
update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
if updates_collections:
@@ -2778,14 +2773,9 @@ def root_mean_squared_error(labels,
mse, update_mse_op = mean_squared_error(labels, predictions, weights, None,
None, name or
'root_mean_squared_error')
- def once_across_towers(_, mse):
- rmse = math_ops.sqrt(mse)
- if metrics_collections:
- ops.add_to_collections(metrics_collections, rmse)
- return rmse
- rmse = distribution_strategy_context.get_tower_context().merge_call(
- once_across_towers, mse)
+ once_across_towers = lambda _, mse: math_ops.sqrt(mse)
+ rmse = _aggregate_across_towers(metrics_collections, once_across_towers, mse)
update_rmse_op = math_ops.sqrt(update_mse_op)
if updates_collections:
@@ -2880,15 +2870,12 @@ def sensitivity_at_specificity(labels,
return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon,
name)
- def aggregate_across_towers(_, values):
- sensitivity = compute_sensitivity_at_specificity(
+ def sensitivity_across_towers(_, values):
+ return compute_sensitivity_at_specificity(
values['tp'], values['tn'], values['fp'], values['fn'], 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, sensitivity)
- return sensitivity
- sensitivity = distribution_strategy_context.get_tower_context().merge_call(
- aggregate_across_towers, values)
+ sensitivity = _aggregate_across_towers(
+ metrics_collections, sensitivity_across_towers, values)
update_op = compute_sensitivity_at_specificity(
update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
@@ -3157,14 +3144,11 @@ def _streaming_sparse_average_precision_at_top_k(labels,
total_update = state_ops.assign_add(total_var, batch_total, name='update')
# Divide total by max to get mean, for both vars and the update ops.
- def aggregate_across_towers(_, total_var, max_var):
- mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_average_precision)
- return mean_average_precision
+ def precision_across_towers(_, total_var, max_var):
+ return _safe_scalar_div(total_var, max_var, name='mean')
- mean_average_precision = distribution_strategy_context.get_tower_context(
- ).merge_call(aggregate_across_towers, total_var, max_var)
+ mean_average_precision = _aggregate_across_towers(
+ metrics_collections, precision_across_towers, total_var, max_var)
update = _safe_scalar_div(total_update, max_update, name=scope)
if updates_collections:
@@ -3443,14 +3427,11 @@ def precision_at_top_k(labels,
class_id=class_id,
weights=weights)
- def aggregate_across_towers(_, tp, fp):
- metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
- if metrics_collections:
- ops.add_to_collections(metrics_collections, metric)
- return metric
+ def precision_across_towers(_, tp, fp):
+ return math_ops.div(tp, math_ops.add(tp, fp), name=scope)
- metric = distribution_strategy_context.get_tower_context().merge_call(
- aggregate_across_towers, tp, fp)
+ metric = _aggregate_across_towers(
+ metrics_collections, precision_across_towers, tp, fp)
update = math_ops.div(
tp_update, math_ops.add(tp_update, fp_update), name='update')
@@ -3681,15 +3662,12 @@ def specificity_at_sensitivity(labels,
return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon,
name)
- def aggregate_across_towers(_, values):
- specificity = compute_specificity_at_sensitivity(
+ def specificity_across_towers(_, values):
+ return compute_specificity_at_sensitivity(
values['tp'], values['tn'], values['fp'], values['fn'], 'value')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, specificity)
- return specificity
- specificity = distribution_strategy_context.get_tower_context().merge_call(
- aggregate_across_towers, values)
+ specificity = _aggregate_across_towers(
+ metrics_collections, specificity_across_towers, values)
update_op = compute_specificity_at_sensitivity(
update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 339684122e..4b73fc830e 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -16,7 +16,7 @@
# pylint: disable=unused-import,g-bad-import-order
"""Neural network support.
-See the @{$python/nn} guide.
+See the [Neural network](https://tensorflow.org/api_guides/python/nn) guide.
"""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/python/ops/nn_batchnorm_test.py b/tensorflow/python/ops/nn_batchnorm_test.py
index 7d6dd3fb02..a7467aa943 100644
--- a/tensorflow/python/ops/nn_batchnorm_test.py
+++ b/tensorflow/python/ops/nn_batchnorm_test.py
@@ -129,7 +129,7 @@ class BatchNormalizationTest(test.TestCase):
v_val = np.random.random_sample(param_shape).astype(np.float64)
beta_val = np.random.random_sample(param_shape).astype(np.float64)
gamma_val = np.random.random_sample(param_shape).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(x_val, name="x")
m = constant_op.constant(m_val, name="m")
v = constant_op.constant(v_val, name="v")
@@ -455,7 +455,7 @@ class MomentsTest(test.TestCase):
return nn_impl.moments(x, axes, keep_dims=keep_dims)
def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
- with self.test_session():
+ with self.cached_session():
# shape = [batch, width, height, depth]
assert len(shape) == 4
@@ -482,7 +482,7 @@ class MomentsTest(test.TestCase):
expected_variance, var.eval(feed_dict={x: x_numpy}))
def RunMomentTest(self, shape, axes, keep_dims, dtype):
- with self.test_session():
+ with self.cached_session():
# shape = [batch, width, height, depth]
assert len(shape) == 4
@@ -547,7 +547,7 @@ class MomentsTest(test.TestCase):
dtype=dtype)
def _testGlobalGradient(self, from_y="mean"):
- with self.test_session():
+ with self.cached_session():
x_shape = [3, 5, 4, 2]
x_val = np.random.random_sample(x_shape).astype(np.float64)
x = constant_op.constant(x_val)
@@ -644,7 +644,7 @@ class WeightedMomentsTest(MomentsTest):
keep_dims,
dtype,
dynshapes=False):
- with self.test_session() as s:
+ with self.cached_session() as s:
x_numpy = np.random.normal(size=shape).astype(np.float32)
weights_numpy = np.absolute( # weights must be positive
np.random.normal(
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index c2dd58bdf0..902653befc 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -27,7 +27,6 @@ from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import sparse_ops
@ops.RegisterGradient("Conv2DBackpropInput")
@@ -486,7 +485,9 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
softmax = nn_ops.softmax(logits)
grad += ((grad_grad - array_ops.squeeze(
- math_ops.matmul(grad_grad[:, None, :], softmax[:, :, None]), axis=1)) *
+ math_ops.matmul(array_ops.expand_dims(grad_grad, 1),
+ array_ops.expand_dims(softmax, 2)),
+ axis=1)) *
softmax)
return grad, _BroadcastMul(grad_loss, -nn_ops.log_softmax(logits))
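The rewrite trades `None`-indexing for `array_ops.expand_dims`; a quick sketch
of the equivalence (illustrative shapes):

    g = tf.constant([[1., 2.], [3., 4.]])  # shape [2, 2]
    a = tf.expand_dims(g, 1)  # shape [2, 1, 2], same as g[:, None, :]
    b = tf.expand_dims(g, 2)  # shape [2, 2, 1], same as g[:, :, None]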
@@ -990,25 +991,30 @@ def _TopKGrad(op, grad, _):
in_shape = array_ops.shape(op.inputs[0])
ind_shape = array_ops.shape(op.outputs[1])
- ind_lastdim = array_ops.gather(ind_shape, array_ops.size(ind_shape) - 1)
+  # int32 is not supported on GPU, hence the up-cast to int64.
+ ind_lastdim = array_ops.gather(math_ops.cast(
+ ind_shape, dtypes.int64), array_ops.size(ind_shape) - 1)
# Flatten indices to 2D.
ind_2d = array_ops.reshape(op.outputs[1], array_ops.stack([-1, ind_lastdim]))
- in_lastdim = array_ops.gather(in_shape, array_ops.size(in_shape) - 1)
+ in_lastdim = array_ops.gather(math_ops.cast(
+ in_shape, dtypes.int64), array_ops.size(in_shape) - 1)
outerdim = array_ops.shape(ind_2d)[0]
# Compute linear indices (flattened to 1D).
- ind = array_ops.reshape(ind_2d + array_ops.expand_dims(
- math_ops.range(0, outerdim * in_lastdim, in_lastdim), -1), [-1])
+ ind = array_ops.reshape(ind_2d + math_ops.cast(array_ops.expand_dims(
+ math_ops.range(0, math_ops.cast(outerdim, dtypes.int64)
+ * in_lastdim, in_lastdim), -1), dtypes.int32), [-1])
# Substitute grad to appropriate locations and fill the rest with zeros,
# finally reshaping it to the original input shape.
return [
array_ops.reshape(
- sparse_ops.sparse_to_dense(
- ind,
- array_ops.reshape(math_ops.reduce_prod(in_shape), [1]),
+ array_ops.scatter_nd(
+ array_ops.expand_dims(ind, -1),
array_ops.reshape(grad, [-1]),
- validate_indices=False), in_shape),
+ [math_ops.reduce_prod(in_shape)]
+ ),
+ in_shape),
array_ops.zeros([], dtype=dtypes.int32)
]
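The gradient now materializes its dense result with `array_ops.scatter_nd`
instead of `sparse_ops.sparse_to_dense`; a 1-D sketch of the substitution:

    indices = tf.constant([[1], [3]])             # positions to fill
    updates = tf.constant([10.0, 20.0])           # values to scatter
    dense = tf.scatter_nd(indices, updates, [5])  # [0., 10., 0., 20., 0.]
    # Matches sparse_to_dense([1, 3], [5], [10.0, 20.0]) with default 0.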
diff --git a/tensorflow/python/ops/nn_grad_test.py b/tensorflow/python/ops/nn_grad_test.py
index 49d54beb20..8065df4b16 100644
--- a/tensorflow/python/ops/nn_grad_test.py
+++ b/tensorflow/python/ops/nn_grad_test.py
@@ -37,7 +37,7 @@ class Relu6OpTest(test.TestCase):
x_init_value = np.array([[-3.5, -1.5, 2, 4], [4.5, 7.5, 8.5, 11]])
r = nn_ops.relu6(inputs)
r_g = gradients_impl.gradients(r, inputs)[0]
- with self.test_session():
+ with self.cached_session():
error = gradient_checker.compute_gradient_error(
inputs,
inputs.get_shape().as_list(),
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 51f812b395..2a1919e66f 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -1210,7 +1210,9 @@ def nce_loss(weights,
num_true]`. The target classes.
inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
activations of the input network.
- num_sampled: An `int`. The number of classes to randomly sample per batch.
+ num_sampled: An `int`. The number of negative classes to randomly sample
+ per batch. This single sample of negative classes is evaluated for each
+ element in the batch.
num_classes: An `int`. The number of possible classes.
num_true: An `int`. The number of target classes per training example.
sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 52ea202636..d646245ce3 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -510,7 +510,7 @@ class _WithSpaceToBatch(object):
# Recover channel information for output shape if channels are not last.
if self.data_format is not None and self.data_format.startswith("NC"):
- if not result_converted.shape[1].value:
+ if not result_converted.shape[1].value and filter is not None:
output_shape = result_converted.shape.as_list()
output_shape[1] = filter.shape[-1]
result_converted.set_shape(output_shape)
@@ -698,7 +698,7 @@ def convolution(
`padded_input` is obtained by zero padding the input using an effective
spatial filter shape of `(spatial_filter_shape-1) * dilation_rate + 1` and
output striding `strides` as described in the
- @{$python/nn#Convolution$comment here}.
+ [comment here](https://tensorflow.org/api_guides/python/nn#Convolution).
In the case that `data_format` does start with `"NC"`, the `input` and output
(but not the `filter`) are simply transposed as follows:
@@ -1586,7 +1586,7 @@ def leaky_relu(features, alpha=0.2, name=None):
"Rectifier Nonlinearities Improve Neural Network Acoustic Models"
AL Maas, AY Hannun, AY Ng - Proc. ICML, 2013
- http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf
+ https://ai.stanford.edu/~amaas/papers/relu_hybrid_icml2013_final.pdf
Args:
features: A `Tensor` representing preactivation values. Must be one of
@@ -1838,8 +1838,9 @@ def softmax_cross_entropy_with_logits_v2(
name: A name for the operation (optional).
Returns:
- A `Tensor` of the same shape as `labels` and of the same type as `logits`
- with the softmax cross entropy loss.
+ A `Tensor` that contains the softmax cross entropy loss. Its type is the
+ same as `logits` and its shape is the same as `labels` except that it does
+ not have the last dimension of `labels`.
"""
_ensure_xent_args("softmax_cross_entropy_with_logits", _sentinel, labels,
logits)
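A shape sketch of the corrected return contract (illustrative values):

    labels = tf.constant([[1., 0., 0.], [0., 1., 0.]])  # shape [2, 3]
    logits = tf.constant([[2., 1., 0.], [0., 2., 1.]])  # shape [2, 3]
    loss = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=labels, logits=logits)
    # loss has shape [2]: the trailing class dimension of `labels` is reduced.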
@@ -1964,8 +1965,9 @@ def softmax_cross_entropy_with_logits(
name: A name for the operation (optional).
Returns:
- A `Tensor` of the same shape as `labels` and of the same type as `logits`
- with the softmax cross entropy loss.
+ A `Tensor` that contains the softmax cross entropy loss. Its type is the
+ same as `logits` and its shape is the same as `labels` except that it does
+ not have the last dimension of `labels`.
"""
_ensure_xent_args("softmax_cross_entropy_with_logits", _sentinel, labels,
logits)
@@ -2454,7 +2456,7 @@ def conv1d(value,
returned to the caller.
Args:
- value: A 3D `Tensor`. Must be of type `float16` or `float32`.
+ value: A 3D `Tensor`. Must be of type `float16`, `float32`, or `float64`.
filters: A 3D `Tensor`. Must have the same type as `value`.
stride: An `integer`. The number of entries by which
the filter is moved right at each step.
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index ce0db6b264..2fabb2e966 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -53,7 +53,7 @@ class ZeroFractionTest(test_lib.TestCase):
x_shape = [5, 17]
x_np = np.random.randint(0, 2, size=x_shape).astype(np.float32)
y_np = self._ZeroFraction(x_np)
- with self.test_session():
+ with self.cached_session():
x_tf = constant_op.constant(x_np)
x_tf.set_shape(x_shape)
y_tf = nn_impl.zero_fraction(x_tf)
@@ -62,7 +62,7 @@ class ZeroFractionTest(test_lib.TestCase):
self.assertAllClose(y_tf_np, y_np, eps)
def testZeroFractionEmpty(self):
- with self.test_session():
+ with self.cached_session():
x = np.zeros(0)
y = nn_impl.zero_fraction(x).eval()
self.assertTrue(np.isnan(y))
@@ -106,7 +106,7 @@ class SoftmaxTest(test_lib.TestCase, parameterized.TestCase):
@parameterized.parameters(((5, 10),), ((2, 3, 4),))
def testGradient(self, x_shape):
x_np = np.random.randn(*x_shape).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
x_tf = constant_op.constant(x_np)
y_tf = nn_ops.softmax(x_tf)
err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf,
@@ -143,7 +143,7 @@ class LogPoissonLossTest(test_lib.TestCase):
x_shape = [5, 10]
x_np = np.random.randn(*x_shape).astype(np.float64)
z_np = np.random.randint(0, 5, size=x_shape).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
x_tf = constant_op.constant(x_np)
y_tf = nn_impl.log_poisson_loss(z_np, x_tf, compute_full_loss=False)
y_tf_stirling = nn_impl.log_poisson_loss(
@@ -191,7 +191,7 @@ class LogSoftmaxTest(test_lib.TestCase, parameterized.TestCase):
@parameterized.parameters(((5, 10),), ((2, 3, 4),))
def testGradient(self, x_shape):
x_np = np.random.randn(*x_shape).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
x_tf = constant_op.constant(x_np)
y_tf = nn_ops.log_softmax(x_tf)
err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf,
@@ -215,7 +215,7 @@ class L2LossTest(test_lib.TestCase):
x_shape = [20, 7, 3]
np.random.seed(1) # Make it reproducible.
x_val = np.random.random_sample(x_shape).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(x_val, name="x")
output = nn_ops.l2_loss(x)
err = gradient_checker.compute_gradient_error(x, x_shape, output, [1])
@@ -263,7 +263,7 @@ class L2NormalizeTest(test_lib.TestCase):
np.random.seed(1)
x_np = np.random.random_sample(x_shape).astype(np.float64)
for dim in range(len(x_shape)):
- with self.test_session():
+ with self.cached_session():
x_tf = constant_op.constant(x_np, name="x")
y_tf = nn_impl.l2_normalize(x_tf, dim)
err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf,
@@ -282,7 +282,7 @@ class DropoutTest(test_lib.TestCase):
y_dim = 30
num_iter = 10
for keep_prob in [0.1, 0.5, 0.8]:
- with self.test_session():
+ with self.cached_session():
t = constant_op.constant(
1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
dropout = nn_ops.dropout(t, keep_prob)
@@ -310,7 +310,7 @@ class DropoutTest(test_lib.TestCase):
y_dim = 3
num_iter = 10
for keep_prob in [0.1, 0.5, 0.8]:
- with self.test_session():
+ with self.cached_session():
t = constant_op.constant(
1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
dropout = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim, 1])
@@ -335,7 +335,7 @@ class DropoutTest(test_lib.TestCase):
y_dim = 30
num_iter = 10
for keep_prob in [0.1, 0.5, 0.8]:
- with self.test_session():
+ with self.cached_session():
t = constant_op.constant(
1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
dropout = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim, 1])
@@ -355,7 +355,7 @@ class DropoutTest(test_lib.TestCase):
y_dim = 30
num_iter = 10
for keep_prob in [0.1, 0.5, 0.8]:
- with self.test_session():
+ with self.cached_session():
t = constant_op.constant(
1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
keep_prob_placeholder = array_ops.placeholder(dtypes.float32)
@@ -389,7 +389,7 @@ class DropoutTest(test_lib.TestCase):
y_dim = 3
num_iter = 10
for keep_prob in [0.1, 0.5, 0.8]:
- with self.test_session():
+ with self.cached_session():
t = constant_op.constant(
1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
# Set noise_shape=[None, 1] which means [x_dim, 1].
@@ -541,7 +541,7 @@ class ComputeSampledLogitsTest(test_lib.TestCase):
"b",
partitioner=partitioned_variables.fixed_size_partitioner(num_shards),
initializer=constant_op.constant(biases))
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
variables.global_variables_initializer().run()
return sess.run([list(sharded_weights), list(sharded_biases)])
@@ -549,7 +549,7 @@ class ComputeSampledLogitsTest(test_lib.TestCase):
np.random.seed(0)
num_classes = 5
batch_size = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for num_true in range(1, 5):
labels = np.random.randint(
low=0, high=num_classes, size=batch_size * num_true)
@@ -585,7 +585,7 @@ class ComputeSampledLogitsTest(test_lib.TestCase):
np.random.seed(0)
num_classes = 5
batch_size = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for num_true in range(1, 5):
labels = np.random.randint(
low=0, high=num_classes, size=batch_size * num_true)
@@ -622,7 +622,7 @@ class ComputeSampledLogitsTest(test_lib.TestCase):
num_classes = 5
batch_size = 3
sampled = [1, 0, 2, 3]
- with self.test_session():
+ with self.cached_session():
for num_true in range(1, 5):
labels = np.random.randint(
low=0, high=num_classes, size=batch_size * num_true)
@@ -666,7 +666,7 @@ class ComputeSampledLogitsTest(test_lib.TestCase):
np.random.seed(0)
num_classes = 5
batch_size = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for num_true in range(1, 5):
labels = np.random.randint(
low=0, high=num_classes, size=batch_size * num_true)
@@ -702,7 +702,7 @@ class ComputeSampledLogitsTest(test_lib.TestCase):
np.random.seed(0)
num_classes = 5
batch_size = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for num_true in range(1, 5):
labels = np.random.randint(
low=0, high=num_classes, size=batch_size * num_true)
@@ -762,7 +762,7 @@ class ComputeSampledLogitsTest(test_lib.TestCase):
exp_nce_loss = np.sum(
_SigmoidCrossEntropyWithLogits(exp_logits, exp_labels), 1)
- with self.test_session():
+ with self.cached_session():
got_nce_loss = nn_impl.nce_loss(
weights=constant_op.constant(weights),
biases=constant_op.constant(biases),
@@ -819,7 +819,7 @@ class ComputeSampledLogitsTest(test_lib.TestCase):
exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits(
exp_logits, exp_labels)
- with self.test_session():
+ with self.cached_session():
got_sampled_softmax_loss = nn_impl.sampled_softmax_loss(
weights=constant_op.constant(weights),
biases=constant_op.constant(biases),
@@ -880,7 +880,7 @@ class ComputeSampledLogitsTest(test_lib.TestCase):
exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits(
exp_logits, exp_labels)
- with self.test_session():
+ with self.cached_session():
true_exp_bf16 = np.full(
[batch_size, 1], fill_value=0.5, dtype=dtypes.bfloat16.as_numpy_dtype)
sampled_exp_bf16 = np.full(
@@ -911,7 +911,7 @@ class CReluTest(test_lib.TestCase):
np.random.seed(1) # Make it reproducible.
x = np.random.randn(3, 4).astype(np.float32)
y = np.concatenate([x * (x > 0), -x * (x < 0)], axis=1)
- with self.test_session():
+ with self.cached_session():
z = nn_ops.crelu(constant_op.constant(x)).eval()
self.assertAllClose(y, z, 1e-4)
@@ -922,7 +922,7 @@ class ReluTest(test_lib.TestCase):
np.random.seed(1) # Make it reproducible.
x = np.random.randn(3, 4).astype(np.float32)
y = np.maximum(x, 0.0)
- with self.test_session():
+ with self.cached_session():
z = nn_ops.relu(constant_op.constant(x)).eval()
self.assertAllEqual(y, z)
@@ -930,7 +930,7 @@ class ReluTest(test_lib.TestCase):
# Test that relu(nan) = nan for various sizes.
for i in range(18):
x = np.zeros(i) + np.nan
- with self.test_session():
+ with self.cached_session():
z = nn_ops.relu(constant_op.constant(x)).eval()
self.assertTrue(np.isnan(z).all())
@@ -947,7 +947,7 @@ class LeakyReluTest(test_lib.TestCase):
outputs = nn_ops.leaky_relu(inputs)
self.assertEquals(inputs.shape, outputs.shape)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs, outputs = sess.run([inputs, outputs])
self.assertGreaterEqual(outputs.min(), 0.0)
self.assertLessEqual(outputs.max(), 1.0)
@@ -957,7 +957,7 @@ class LeakyReluTest(test_lib.TestCase):
for dtype in [np.int32, np.int64, np.float16, np.float32, np.float64]:
np_values = np.array([-2, -1, 0, 1, 2], dtype=dtype)
outputs = nn_ops.leaky_relu(constant_op.constant(np_values))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
outputs = sess.run(outputs)
tol = 2e-3 if dtype == np.float16 else 1e-6
self.assertAllClose(
@@ -984,7 +984,7 @@ class SwishTest(test_lib.TestCase):
tf_values = constant_op.constant(np_values)
actual_tf_outputs = nn_impl.swish(tf_values)
expected_tf_outputs = tf_values * math_ops.sigmoid(tf_values)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual_outputs, expected_outputs = sess.run(
[actual_tf_outputs, expected_tf_outputs])
self.assertAllClose(actual_outputs, expected_outputs)
@@ -995,7 +995,7 @@ class SwishTest(test_lib.TestCase):
input_values = np.random.randn(*shape) * sigma
x_tf = constant_op.constant(input_values)
y_tf = nn_impl.swish(x_tf)
- with self.test_session():
+ with self.cached_session():
err = gradient_checker.compute_gradient_error(x_tf, shape, y_tf, shape)
self.assertLess(err, 1e-4)
@@ -1016,7 +1016,7 @@ class MomentsTest(test_lib.TestCase):
expected_var = np.var(
input_values, axis=moments_axes, keepdims=keep_dims)
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
inputs = constant_op.constant(
input_values, shape=input_shape, dtype=dtypes.float32)
mean, variance = nn_impl.moments(
diff --git a/tensorflow/python/ops/nn_xent_test.py b/tensorflow/python/ops/nn_xent_test.py
index 90f4b40770..54a0e26bfb 100644
--- a/tensorflow/python/ops/nn_xent_test.py
+++ b/tensorflow/python/ops/nn_xent_test.py
@@ -54,7 +54,7 @@ class SigmoidCrossEntropyWithLogitsTest(test.TestCase):
return logits, targets, losses
def testConstructionNamed(self):
- with self.test_session():
+ with self.cached_session():
logits, targets, _ = self._Inputs()
loss = nn_impl.sigmoid_cross_entropy_with_logits(
labels=targets, logits=logits, name="mylogistic")
@@ -84,7 +84,7 @@ class SigmoidCrossEntropyWithLogitsTest(test.TestCase):
def testGradient(self):
sizes = [4, 2]
- with self.test_session():
+ with self.cached_session():
logits, targets, _ = self._Inputs(sizes=sizes)
loss = nn_impl.sigmoid_cross_entropy_with_logits(
labels=targets, logits=logits)
@@ -93,7 +93,7 @@ class SigmoidCrossEntropyWithLogitsTest(test.TestCase):
self.assertLess(err, 1e-7)
def testGradientAtZero(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([0.0, 0.0], dtype=dtypes.float64)
targets = constant_op.constant([0.0, 1.0], dtype=dtypes.float64)
loss = nn_impl.sigmoid_cross_entropy_with_logits(
@@ -130,7 +130,7 @@ class WeightedCrossEntropyTest(test.TestCase):
return logits, targets, q, losses
def testConstructionNamed(self):
- with self.test_session():
+ with self.cached_session():
logits, targets, pos_weight, _ = self._Inputs()
loss = nn_impl.weighted_cross_entropy_with_logits(
targets=targets, logits=logits, pos_weight=pos_weight, name="mybce")
@@ -159,7 +159,7 @@ class WeightedCrossEntropyTest(test.TestCase):
def testGradient(self):
sizes = [4, 2]
- with self.test_session():
+ with self.cached_session():
logits, targets, pos_weight, _ = self._Inputs(sizes=sizes)
loss = nn_impl.weighted_cross_entropy_with_logits(
targets=targets, logits=logits, pos_weight=pos_weight)
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops.py b/tensorflow/python/ops/parallel_for/control_flow_ops.py
index ccf2eb8214..ead7ae5478 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py
@@ -46,6 +46,7 @@ def for_loop(loop_fn, loop_fn_dtypes, iters):
"""
flat_loop_fn_dtypes = nest.flatten(loop_fn_dtypes)
+ is_none_list = []
def while_body(i, *ta_list):
"""Body of while loop."""
@@ -56,10 +57,13 @@ def for_loop(loop_fn, loop_fn_dtypes, iters):
"actual outputs, %d, from loop_fn" % (len(flat_loop_fn_dtypes),
len(fn_output)))
outputs = []
+ del is_none_list[:]
+ is_none_list.extend([x is None for x in fn_output])
for out, ta in zip(fn_output, ta_list):
# TODO(agarwal): support returning Operation objects from loop_fn.
- assert isinstance(out, ops.Tensor)
- outputs.append(ta.write(i, array_ops.expand_dims(out, 0)))
+ if out is not None:
+ ta = ta.write(i, array_ops.expand_dims(out, 0))
+ outputs.append(ta)
return tuple([i + 1] + outputs)
ta_list = control_flow_ops.while_loop(
@@ -69,7 +73,10 @@ def for_loop(loop_fn, loop_fn_dtypes, iters):
])[1:]
# TODO(rachelim): enable this for sparse tensors
- return nest.pack_sequence_as(loop_fn_dtypes, [ta.concat() for ta in ta_list])
+
+ output = [None if is_none else ta.concat()
+ for ta, is_none in zip(ta_list, is_none_list)]
+ return nest.pack_sequence_as(loop_fn_dtypes, output)
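A sketch of what the relaxed contract permits (hypothetical `loop_fn`; a dtype
is still supplied for every output slot, but a slot whose output is `None` now
propagates `None` instead of tripping the old assertion):

    def loop_fn(i):
      # The second output intentionally has no value.
      return math_ops.cast(i, dtypes.float32) * 2.0, None

    doubled, missing = for_loop(loop_fn, [dtypes.float32, dtypes.float32],
                                iters=3)
    # `doubled` stacks the per-iteration tensors; `missing` is None.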
def pfor(loop_fn, iters):
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
index c0e66cb0b8..d403b0c61a 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -1259,7 +1259,7 @@ class SparseTest(PForTest):
[3]) # [0, 2, 0]
pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(pfor, feed_dict={num_iters: 3})
def test_sparse_result_none_stacked(self):
diff --git a/tensorflow/python/ops/parallel_for/gradients.py b/tensorflow/python/ops/parallel_for/gradients.py
index ee3d5c9b86..460de0a97f 100644
--- a/tensorflow/python/ops/parallel_for/gradients.py
+++ b/tensorflow/python/ops/parallel_for/gradients.py
@@ -61,9 +61,10 @@ def jacobian(output, inputs, use_pfor=True):
loop_fn, [output.dtype] * len(flat_inputs), output_size)
for i, out in enumerate(pfor_outputs):
- new_shape = array_ops.concat(
- [output_shape, array_ops.shape(out)[1:]], axis=0)
- out = array_ops.reshape(out, new_shape)
+ if out is not None:
+ new_shape = array_ops.concat(
+ [output_shape, array_ops.shape(out)[1:]], axis=0)
+ out = array_ops.reshape(out, new_shape)
pfor_outputs[i] = out
return nest.pack_sequence_as(inputs, pfor_outputs)
@@ -119,6 +120,8 @@ def batch_jacobian(output, inp, use_pfor=True):
else:
pfor_output = control_flow_ops.for_loop(loop_fn, output.dtype,
output_row_size)
+ if pfor_output is None:
+ return None
pfor_output = array_ops.reshape(pfor_output,
[output_row_size, batch_size, -1])
output = array_ops.transpose(pfor_output, [1, 0, 2])
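With this guard in place, a target with no path to the input yields `None`
(mirroring the new `test_no_path` case below); a minimal sketch using the
surrounding module names:

    x = constant_op.constant([[1.0]])
    y = constant_op.constant([[2.0]])  # independent of x
    assert jacobian(y, x) is None
    assert batch_jacobian(y, x) is None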
diff --git a/tensorflow/python/ops/parallel_for/gradients_test.py b/tensorflow/python/ops/parallel_for/gradients_test.py
index 3a6d9149ad..628c6764cd 100644
--- a/tensorflow/python/ops/parallel_for/gradients_test.py
+++ b/tensorflow/python/ops/parallel_for/gradients_test.py
@@ -333,6 +333,13 @@ class GradientsTest(test.TestCase):
for i in range(n):
self.assertAllClose(outputs[i], outputs[i + n], rtol=rtol, atol=atol)
+ def test_no_path(self):
+ for grad_func in [gradients.jacobian, gradients.batch_jacobian]:
+ for use_pfor in [True, False]:
+ x = constant_op.constant([[1.0]])
+ y = constant_op.constant([[2.0]])
+ self.assertIsNone(grad_func(y, x, use_pfor=use_pfor))
+
def test_jacobian_fixed_shape(self):
x = random_ops.random_uniform([2, 2])
y = math_ops.matmul(x, x, transpose_a=True)
@@ -349,7 +356,7 @@ class GradientsTest(test.TestCase):
self.run_and_assert_equal(answer, jacobian_while)
def test_jacobian_unknown_shape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32, shape=[None, None])
y = math_ops.matmul(x, x, transpose_a=True)
jacobian_pfor = gradients.jacobian(y, x, use_pfor=True)
@@ -374,7 +381,7 @@ class GradientsTest(test.TestCase):
gradients.batch_jacobian(y, x, use_pfor=True)
def test_batch_jacobian_bad_unknown_shapes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32)
y = array_ops.concat([x, x], axis=0)
jacobian = gradients.batch_jacobian(y, x)
@@ -395,7 +402,7 @@ class GradientsTest(test.TestCase):
self.run_and_assert_equal(answer, batch_jacobian_while)
def test_batch_jacobian_unknown_shape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32)
y = x * x
batch_jacobian_pfor = gradients.batch_jacobian(y, x, use_pfor=True)
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
index 2e4b2fd64e..f9153b6d7d 100644
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -21,8 +21,6 @@ from __future__ import print_function
import collections
-from absl import flags
-
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -41,6 +39,7 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
@@ -1070,6 +1069,8 @@ class PFor(object):
If y does not need to be converted, it returns y as is. Else it returns
the "converted value" corresponding to y.
"""
+ if y is None:
+ return None
if isinstance(y, sparse_tensor.SparseTensor):
return self._convert_sparse(y)
output = self._convert_helper(y)
@@ -2011,6 +2012,7 @@ def _convert_biasaddgrad(pfor_input):
@RegisterPForWithArgs("ReluGrad")
@RegisterPForWithArgs("TanhGrad")
@RegisterPForWithArgs("SigmoidGrad")
+@RegisterPForWithArgs("SoftplusGrad")
def _convert_grads(pfor_input, op_type, *args, **kw_args):
del args
del kw_args
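Registering `SoftplusGrad` lets pfor vectorize loop bodies that differentiate
through `softplus`; a sketch (assumes the public `pfor` entry point from this
package and TF 1.x graph mode):

    x = tf.random_uniform([8, 4])

    def loop_fn(i):
      xi = tf.gather(x, i)
      y = tf.nn.softplus(xi)
      # tf.gradients emits a SoftplusGrad op, which pfor can now convert.
      return tf.gradients(y, xi)[0]

    per_row_grads = pfor(loop_fn, 8)  # stacked gradients, shape [8, 4]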
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index d8d9af545f..8224097ac4 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -629,76 +629,12 @@ def _parse_example_raw(serialized,
Returns:
A `dict` mapping keys to `Tensor`s and `SparseTensor`s.
- Raises:
- ValueError: If sparse and dense key sets intersect, or input lengths do not
- match up.
"""
with ops.name_scope(name, "ParseExample", [serialized, names]):
- names = [] if names is None else names
- dense_defaults = collections.OrderedDict(
- ) if dense_defaults is None else dense_defaults
- sparse_keys = [] if sparse_keys is None else sparse_keys
- sparse_types = [] if sparse_types is None else sparse_types
- dense_keys = [] if dense_keys is None else dense_keys
- dense_types = [] if dense_types is None else dense_types
- dense_shapes = (
- [[]] * len(dense_keys) if dense_shapes is None else dense_shapes)
-
- num_dense = len(dense_keys)
- num_sparse = len(sparse_keys)
-
- if len(dense_shapes) != num_dense:
- raise ValueError("len(dense_shapes) != len(dense_keys): %d vs. %d"
- % (len(dense_shapes), num_dense))
- if len(dense_types) != num_dense:
- raise ValueError("len(dense_types) != len(num_dense): %d vs. %d"
- % (len(dense_types), num_dense))
- if len(sparse_types) != num_sparse:
- raise ValueError("len(sparse_types) != len(sparse_keys): %d vs. %d"
- % (len(sparse_types), num_sparse))
- if num_dense + num_sparse == 0:
- raise ValueError("Must provide at least one sparse key or dense key")
- if not set(dense_keys).isdisjoint(set(sparse_keys)):
- raise ValueError(
- "Dense and sparse keys must not intersect; intersection: %s" %
- set(dense_keys).intersection(set(sparse_keys)))
-
- # Convert dense_shapes to TensorShape object.
- dense_shapes = [tensor_shape.as_shape(shape) for shape in dense_shapes]
-
- dense_defaults_vec = []
- for i, key in enumerate(dense_keys):
- default_value = dense_defaults.get(key)
- dense_shape = dense_shapes[i]
- if (dense_shape.ndims is not None and dense_shape.ndims > 0 and
- dense_shape[0].value is None):
- # Variable stride dense shape, the default value should be a
- # scalar padding value
- if default_value is None:
- default_value = ops.convert_to_tensor(
- "" if dense_types[i] == dtypes.string else 0,
- dtype=dense_types[i])
- else:
- # Reshape to a scalar to ensure user gets an error if they
- # provide a tensor that's not intended to be a padding value
- # (0 or 2+ elements).
- key_name = "padding_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
- default_value = ops.convert_to_tensor(
- default_value, dtype=dense_types[i], name=key_name)
- default_value = array_ops.reshape(default_value, [])
- else:
- if default_value is None:
- default_value = constant_op.constant([], dtype=dense_types[i])
- elif not isinstance(default_value, ops.Tensor):
- key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
- default_value = ops.convert_to_tensor(
- default_value, dtype=dense_types[i], name=key_name)
- default_value = array_ops.reshape(default_value, dense_shape)
-
- dense_defaults_vec.append(default_value)
-
- # Finally, convert dense_shapes to TensorShapeProto
- dense_shapes = [shape.as_proto() for shape in dense_shapes]
+ (names, dense_defaults_vec, sparse_keys, sparse_types,
+ dense_keys, dense_shapes, _) = _process_raw_parameters(
+ names, dense_defaults, sparse_keys, sparse_types, dense_keys,
+ dense_types, dense_shapes)
outputs = gen_parsing_ops.parse_example(
serialized=serialized,
@@ -719,6 +655,112 @@ def _parse_example_raw(serialized,
return dict(zip(sparse_keys + dense_keys, sparse_tensors + dense_values))
+def _process_raw_parameters(names, dense_defaults, sparse_keys, sparse_types,
+ dense_keys, dense_types, dense_shapes):
+  """Process raw parameters into the format expected by `gen_parsing_ops`.
+
+ Args:
+ names: A vector (1-D Tensor) of strings (optional), the names of
+ the serialized protos.
+ dense_defaults: A dict mapping string keys to `Tensor`s.
+ The keys of the dict must match the dense_keys of the feature.
+ sparse_keys: A list of string keys in the examples' features.
+ The results for these keys will be returned as `SparseTensor` objects.
+ sparse_types: A list of `DTypes` of the same length as `sparse_keys`.
+ Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
+ and `tf.string` (`BytesList`) are supported.
+ dense_keys: A list of string keys in the examples' features.
+      The results for these keys will be returned as `Tensor`s.
+ dense_types: A list of DTypes of the same length as `dense_keys`.
+ Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
+ and `tf.string` (`BytesList`) are supported.
+ dense_shapes: A list of tuples with the same length as `dense_keys`.
+ The shape of the data for each dense feature referenced by `dense_keys`.
+ Required for any input tensors identified by `dense_keys`. Must be
+ either fully defined, or may contain an unknown first dimension.
+ An unknown first dimension means the feature is treated as having
+ a variable number of blocks, and the output shape along this dimension
+ is considered unknown at graph build time. Padding is applied for
+ minibatch elements smaller than the maximum number of blocks for the
+ given feature along this dimension.
+
+ Returns:
+    Tuple of `names`, `dense_defaults_vec`, `sparse_keys`, `sparse_types`,
+    `dense_keys`, `dense_shapes_as_proto` and `dense_shapes`.
+
+ Raises:
+ ValueError: If sparse and dense key sets intersect, or input lengths do not
+ match up.
+ """
+ names = [] if names is None else names
+ dense_defaults = collections.OrderedDict(
+ ) if dense_defaults is None else dense_defaults
+ sparse_keys = [] if sparse_keys is None else sparse_keys
+ sparse_types = [] if sparse_types is None else sparse_types
+ dense_keys = [] if dense_keys is None else dense_keys
+ dense_types = [] if dense_types is None else dense_types
+ dense_shapes = ([[]] * len(dense_keys)
+ if dense_shapes is None else dense_shapes)
+
+ num_dense = len(dense_keys)
+ num_sparse = len(sparse_keys)
+
+ if len(dense_shapes) != num_dense:
+ raise ValueError("len(dense_shapes) != len(dense_keys): %d vs. %d" %
+ (len(dense_shapes), num_dense))
+ if len(dense_types) != num_dense:
+    raise ValueError("len(dense_types) != len(dense_keys): %d vs. %d" %
+ (len(dense_types), num_dense))
+ if len(sparse_types) != num_sparse:
+ raise ValueError("len(sparse_types) != len(sparse_keys): %d vs. %d" %
+ (len(sparse_types), num_sparse))
+ if num_dense + num_sparse == 0:
+ raise ValueError("Must provide at least one sparse key or dense key")
+ if not set(dense_keys).isdisjoint(set(sparse_keys)):
+ raise ValueError(
+ "Dense and sparse keys must not intersect; intersection: %s" %
+ set(dense_keys).intersection(set(sparse_keys)))
+
+ # Convert dense_shapes to TensorShape object.
+ dense_shapes = [tensor_shape.as_shape(shape) for shape in dense_shapes]
+
+ dense_defaults_vec = []
+ for i, key in enumerate(dense_keys):
+ default_value = dense_defaults.get(key)
+ dense_shape = dense_shapes[i]
+ if (dense_shape.ndims is not None and dense_shape.ndims > 0 and
+ dense_shape[0].value is None):
+ # Variable stride dense shape, the default value should be a
+ # scalar padding value
+ if default_value is None:
+ default_value = ops.convert_to_tensor(
+ "" if dense_types[i] == dtypes.string else 0, dtype=dense_types[i])
+ else:
+ # Reshape to a scalar to ensure user gets an error if they
+ # provide a tensor that's not intended to be a padding value
+ # (0 or 2+ elements).
+ key_name = "padding_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
+ default_value = ops.convert_to_tensor(
+ default_value, dtype=dense_types[i], name=key_name)
+ default_value = array_ops.reshape(default_value, [])
+ else:
+ if default_value is None:
+ default_value = constant_op.constant([], dtype=dense_types[i])
+ elif not isinstance(default_value, ops.Tensor):
+ key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
+ default_value = ops.convert_to_tensor(
+ default_value, dtype=dense_types[i], name=key_name)
+ default_value = array_ops.reshape(default_value, dense_shape)
+
+ dense_defaults_vec.append(default_value)
+
+ # Finally, convert dense_shapes to TensorShapeProto
+ dense_shapes_as_proto = [shape.as_proto() for shape in dense_shapes]
+
+ return (names, dense_defaults_vec, sparse_keys, sparse_types, dense_keys,
+ dense_shapes_as_proto, dense_shapes)
+
+
@tf_export("parse_single_example")
def parse_single_example(serialized, features, name=None, example_names=None):
"""Parses a single `Example` proto.
@@ -855,6 +897,352 @@ def _parse_single_example_raw(serialized,
return outputs
+@tf_export("io.parse_sequence_example")
+def parse_sequence_example(serialized,
+ context_features=None,
+ sequence_features=None,
+ example_names=None,
+ name=None):
+ # pylint: disable=line-too-long
+ """Parses a batch of `SequenceExample` protos.
+
+ Parses a vector of serialized
+ [`SequenceExample`](https://www.tensorflow.org/code/tensorflow/core/example/example.proto)
+ protos given in `serialized`.
+
+ This op parses serialized sequence examples into a tuple of dictionaries
+ mapping keys to `Tensor` and `SparseTensor` objects respectively.
+ The first dictionary contains mappings for keys appearing in
+ `context_features`, and the second dictionary contains mappings for keys
+ appearing in `sequence_features`.
+
+ At least one of `context_features` and `sequence_features` must be provided
+ and non-empty.
+
+ The `context_features` keys are associated with a `SequenceExample` as a
+ whole, independent of time / frame. In contrast, the `sequence_features` keys
+ provide a way to access variable-length data within the `FeatureList` section
+ of the `SequenceExample` proto. While the shapes of `context_features` values
+ are fixed with respect to frame, the frame dimension (the first dimension)
+ of `sequence_features` values may vary between `SequenceExample` protos,
+ and even between `feature_list` keys within the same `SequenceExample`.
+
+ `context_features` contains `VarLenFeature` and `FixedLenFeature` objects.
+ Each `VarLenFeature` is mapped to a `SparseTensor`, and each `FixedLenFeature`
+ is mapped to a `Tensor`, of the specified type, shape, and default value.
+
+ `sequence_features` contains `VarLenFeature` and `FixedLenSequenceFeature`
+ objects. Each `VarLenFeature` is mapped to a `SparseTensor`, and each
+ `FixedLenSequenceFeature` is mapped to a `Tensor`, each of the specified type.
+ The shape will be `(B,T,) + df.dense_shape` for `FixedLenSequenceFeature`
+ `df`, where `B` is the batch size, and `T` is the length of the associated
+ `FeatureList` in the `SequenceExample`. For instance,
+  `FixedLenSequenceFeature([])` yields a 2-D `Tensor` of scalars, with static
+  shape `[None, None]` and dynamic shape `[B, T]`, while
+  `FixedLenSequenceFeature([k])` (for `int k >= 1`) yields a 3-D `Tensor` with
+  static shape `[None, None, k]` and dynamic shape `[B, T, k]`.
+
+ Like the input, the resulting output tensors have a batch dimension. This
+ means that the original per-example shapes of `VarLenFeature`s and
+ `FixedLenSequenceFeature`s can be lost. To handle that situation, this op also
+ provides dicts of shape tensors as part of the output. There is one dict for
+ the context features, and one for the feature_list features. Context features
+ of type `FixedLenFeature`s will not be present, since their shapes are already
+  known by the caller. In situations where the input
+  `FixedLenSequenceFeature`s are of different lengths across examples, the
+  shorter examples will be padded with
+ default datatype values: 0 for numeric types, and the empty string for string
+ types.
+
+ Each `SparseTensor` corresponding to `sequence_features` represents a ragged
+ vector. Its indices are `[time, index]`, where `time` is the `FeatureList`
+ entry and `index` is the value's index in the list of values associated with
+ that time.
+
+ `FixedLenFeature` entries with a `default_value` and `FixedLenSequenceFeature`
+ entries with `allow_missing=True` are optional; otherwise, we will fail if
+ that `Feature` or `FeatureList` is missing from any example in `serialized`.
+
+  `example_names` may contain a descriptive name for each corresponding
+  serialized proto. This may be useful for debugging purposes, but it has no
+  effect on the output. If not `None`, `example_names` must be a vector.
+
+ Args:
+ serialized: A vector (1-D Tensor) of type string containing binary
+ serialized `SequenceExample` protos.
+ context_features: A `dict` mapping feature keys to `FixedLenFeature` or
+ `VarLenFeature` values. These features are associated with a
+ `SequenceExample` as a whole.
+ sequence_features: A `dict` mapping feature keys to
+ `FixedLenSequenceFeature` or `VarLenFeature` values. These features are
+ associated with data within the `FeatureList` section of the
+ `SequenceExample` proto.
+ example_names: A vector (1-D Tensor) of strings (optional), the name of the
+ serialized protos.
+ name: A name for this operation (optional).
+
+ Returns:
+    A tuple of three `dict`s, each mapping keys to `Tensor`s and
+    `SparseTensor`s: the first dict contains the context key/values, the
+    second dict contains the feature_list key/values, and the final dict
+    contains the lengths of any dense feature_list features.
+
+ Raises:
+ ValueError: if any feature is invalid.
+ """
+ if not (context_features or sequence_features):
+ raise ValueError("Missing features.")
+ (context_sparse_keys, context_sparse_types, context_dense_keys,
+ context_dense_types,
+ context_dense_defaults, context_dense_shapes) = _features_to_raw_params(
+ context_features, [VarLenFeature, FixedLenFeature])
+ (feature_list_sparse_keys, feature_list_sparse_types, feature_list_dense_keys,
+ feature_list_dense_types, feature_list_dense_defaults,
+ feature_list_dense_shapes) = _features_to_raw_params(
+ sequence_features, [VarLenFeature, FixedLenSequenceFeature])
+ return _parse_sequence_example_raw(
+ serialized, example_names, context_sparse_keys, context_sparse_types,
+ context_dense_keys, context_dense_types, context_dense_defaults,
+ context_dense_shapes, feature_list_sparse_keys, feature_list_sparse_types,
+ feature_list_dense_keys, feature_list_dense_types,
+ feature_list_dense_shapes, feature_list_dense_defaults, name)
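A hypothetical usage sketch of the new endpoint (the feature names and the
`serialized_batch` tensor are illustrative):

    context_features = {"length": tf.FixedLenFeature([], dtype=tf.int64)}
    sequence_features = {
        "tokens": tf.FixedLenSequenceFeature([], dtype=tf.int64),
        "scores": tf.VarLenFeature(dtype=tf.float32),
    }
    parsed = tf.io.parse_sequence_example(
        serialized_batch,
        context_features=context_features,
        sequence_features=sequence_features)
    # Per the docstrings, `parsed` bundles the context key/values, the
    # feature_list key/values, and the dense feature_list lengths.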
+
+
+def _parse_sequence_example_raw(serialized,
+ debug_name=None,
+ context_sparse_keys=None,
+ context_sparse_types=None,
+ context_dense_keys=None,
+ context_dense_types=None,
+ context_dense_defaults=None,
+ context_dense_shapes=None,
+ feature_list_sparse_keys=None,
+ feature_list_sparse_types=None,
+ feature_list_dense_keys=None,
+ feature_list_dense_types=None,
+ feature_list_dense_shapes=None,
+ feature_list_dense_defaults=None,
+ name=None):
+ """Parses a vector of `SequenceExample` protos.
+
+ Args:
+ serialized: A vector (1-D Tensor) of type string, containing binary
+ serialized `SequenceExample` protos.
+ debug_name: A vector (1-D Tensor) of strings (optional), the names of the
+ serialized protos.
+ context_sparse_keys: A list of string keys in the `SequenceExample`'s
+ features. The results for these keys will be returned as `SparseTensor`
+ objects.
+ context_sparse_types: A list of `DTypes`, the same length as
+ `context_sparse_keys`.
+ Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`), and `tf.string`
+ (`BytesList`) are supported.
+ context_dense_keys: A list of string keys in the examples' features. The
+ results for these keys will be returned as `Tensor`s.
+ context_dense_types: A list of DTypes, same length as `context_dense_keys`.
+ Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`), and `tf.string`
+ (`BytesList`) are supported.
+ context_dense_defaults: A dict mapping string keys to `Tensor`s. The keys of
+ the dict must match the context_dense_keys of the feature.
+ context_dense_shapes: A list of tuples, same length as `context_dense_keys`.
+ The shape of the data for each context_dense feature referenced by
+ `context_dense_keys`. Required for any input tensors identified by
+ `context_dense_keys` whose shapes are anything other than `[]` or `[1]`.
+ feature_list_sparse_keys: A list of string keys in the `SequenceExample`'s
+ feature_lists. The results for these keys will be returned as
+ `SparseTensor` objects.
+ feature_list_sparse_types: A list of `DTypes`, same length as
+ `feature_list_sparse_keys`.
+ Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`), and `tf.string`
+ (`BytesList`) are supported.
+ feature_list_dense_keys: A list of string keys in the `SequenceExample`'s
+ feature_lists. The results for these keys will be returned as `Tensor`s.
+ feature_list_dense_types: A list of `DTypes`, same length as
+ `feature_list_dense_keys`. Only `tf.float32` (`FloatList`), `tf.int64`
+ (`Int64List`), and `tf.string` (`BytesList`) are supported.
+ feature_list_dense_shapes: A list of tuples, same length as
+ `feature_list_dense_keys`. The shape of the data for each `FeatureList`
+ feature referenced by `feature_list_dense_keys`.
+ feature_list_dense_defaults: A dict mapping key strings to values. The only
+ currently allowed value is `None`. Any key appearing in this dict with
+ value `None` is allowed to be missing from the `SequenceExample`. If
+ missing, the key is treated as zero-length.
+ name: A name for this operation (optional).
+
+ Returns:
+ A tuple of three `dict`s, each mapping keys to `Tensor`s and
+ `SparseTensor`s. The first dict contains the context key/values,
+ the second dict contains the feature_list key/values, and the final dict
+ contains the lengths of any dense feature_list features.
+
+ Raises:
+ ValueError: If context_sparse and context_dense key sets intersect,
+ if feature_list_sparse and feature_list_dense key sets intersect,
+ if input lengths do not match up, or if a value in
+ feature_list_dense_defaults is not None.
+ TypeError: if feature_list_dense_defaults is not either None or a dict.
+ """
+ with ops.name_scope(name, "ParseSequenceExample", [serialized]):
+ context_dense_defaults = ({} if context_dense_defaults is None else
+ context_dense_defaults)
+ context_sparse_keys = ([] if context_sparse_keys is None else
+ context_sparse_keys)
+ context_sparse_types = ([] if context_sparse_types is None else
+ context_sparse_types)
+ context_dense_keys = ([]
+ if context_dense_keys is None else context_dense_keys)
+ context_dense_types = ([] if context_dense_types is None else
+ context_dense_types)
+ context_dense_shapes = ([[]] * len(context_dense_keys)
+ if context_dense_shapes is None else
+ context_dense_shapes)
+ feature_list_sparse_keys = ([] if feature_list_sparse_keys is None else
+ feature_list_sparse_keys)
+ feature_list_sparse_types = ([] if feature_list_sparse_types is None else
+ feature_list_sparse_types)
+ feature_list_dense_keys = ([] if feature_list_dense_keys is None else
+ feature_list_dense_keys)
+ feature_list_dense_types = ([] if feature_list_dense_types is None else
+ feature_list_dense_types)
+ feature_list_dense_shapes = ([[]] * len(feature_list_dense_keys)
+ if feature_list_dense_shapes is None else
+ feature_list_dense_shapes)
+ feature_list_dense_defaults = (
+ dict()
+ if feature_list_dense_defaults is None else feature_list_dense_defaults)
+ debug_name = [] if debug_name is None else debug_name
+
+ # Internal
+ feature_list_dense_missing_assumed_empty = []
+
+ num_context_dense = len(context_dense_keys)
+ num_feature_list_dense = len(feature_list_dense_keys)
+ num_context_sparse = len(context_sparse_keys)
+ num_feature_list_sparse = len(feature_list_sparse_keys)
+
+ if len(context_dense_shapes) != num_context_dense:
+ raise ValueError(
+ "len(context_dense_shapes) != len(context_dense_keys): %d vs. %d" %
+ (len(context_dense_shapes), num_context_dense))
+ if len(context_dense_types) != num_context_dense:
+ raise ValueError(
+ "len(context_dense_types) != len(num_context_dense): %d vs. %d" %
+ (len(context_dense_types), num_context_dense))
+ if len(feature_list_dense_shapes) != num_feature_list_dense:
+ raise ValueError(
+ "len(feature_list_dense_shapes) != len(feature_list_dense_keys): "
+ "%d vs. %d" % (len(feature_list_dense_shapes),
+ num_feature_list_dense))
+ if len(feature_list_dense_types) != num_feature_list_dense:
+ raise ValueError(
+ "len(feature_list_dense_types) != len(num_feature_list_dense):"
+ "%d vs. %d" % (len(feature_list_dense_types), num_feature_list_dense))
+ if len(context_sparse_types) != num_context_sparse:
+ raise ValueError(
+ "len(context_sparse_types) != len(context_sparse_keys): %d vs. %d" %
+ (len(context_sparse_types), num_context_sparse))
+ if len(feature_list_sparse_types) != num_feature_list_sparse:
+ raise ValueError(
+ "len(feature_list_sparse_types) != len(feature_list_sparse_keys): "
+ "%d vs. %d" % (len(feature_list_sparse_types),
+ num_feature_list_sparse))
+ if (num_context_dense + num_context_sparse + num_feature_list_dense +
+ num_feature_list_sparse) == 0:
+ raise ValueError(
+ "Must provide at least one context_sparse key, context_dense key, "
+ ", feature_list_sparse key, or feature_list_dense key")
+ if not set(context_dense_keys).isdisjoint(set(context_sparse_keys)):
+ raise ValueError(
+ "context_dense and context_sparse keys must not intersect; "
+ "intersection: %s" % set(context_dense_keys).intersection(
+ set(context_sparse_keys)))
+ if not set(feature_list_dense_keys).isdisjoint(
+ set(feature_list_sparse_keys)):
+ raise ValueError(
+ "feature_list_dense and feature_list_sparse keys must not intersect; "
+ "intersection: %s" % set(feature_list_dense_keys).intersection(
+ set(feature_list_sparse_keys)))
+ if not isinstance(feature_list_dense_defaults, dict):
+ raise TypeError("feature_list_dense_defaults must be a dict")
+ for k, v in feature_list_dense_defaults.items():
+ if v is not None:
+ raise ValueError(
+ "Value feature_list_dense_defaults[%s] must be None" % k)
+ feature_list_dense_missing_assumed_empty.append(k)
+
+ context_dense_defaults_vec = []
+ for i, key in enumerate(context_dense_keys):
+ default_value = context_dense_defaults.get(key)
+ if default_value is None:
+ default_value = constant_op.constant([], dtype=context_dense_types[i])
+ elif not isinstance(default_value, ops.Tensor):
+ key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
+ default_value = ops.convert_to_tensor(
+ default_value, dtype=context_dense_types[i], name=key_name)
+
+ context_dense_defaults_vec.append(default_value)
+
+ context_dense_shapes = [
+ tensor_shape.as_shape(shape).as_proto()
+ for shape in context_dense_shapes
+ ]
+ feature_list_dense_shapes = [
+ tensor_shape.as_shape(shape).as_proto()
+ for shape in feature_list_dense_shapes
+ ]
+
+ # pylint: disable=protected-access
+ outputs = gen_parsing_ops.parse_sequence_example(
+ serialized=serialized,
+ debug_name=debug_name,
+ Ncontext_sparse=num_context_sparse,
+ Ncontext_dense=num_context_dense,
+ Nfeature_list_sparse=num_feature_list_sparse,
+ Nfeature_list_dense=num_feature_list_dense,
+ context_dense_defaults=context_dense_defaults_vec,
+ context_sparse_keys=context_sparse_keys,
+ context_sparse_types=context_sparse_types,
+ context_dense_keys=context_dense_keys,
+ context_dense_shapes=context_dense_shapes,
+ feature_list_sparse_keys=feature_list_sparse_keys,
+ feature_list_sparse_types=feature_list_sparse_types,
+ feature_list_dense_keys=feature_list_dense_keys,
+ feature_list_dense_types=feature_list_dense_types,
+ feature_list_dense_shapes=feature_list_dense_shapes,
+ feature_list_dense_missing_assumed_empty=(
+ feature_list_dense_missing_assumed_empty),
+ name=name)
+ # pylint: enable=protected-access
+
+ (context_sparse_indices, context_sparse_values, context_sparse_shapes,
+ context_dense_values, feature_list_sparse_indices,
+ feature_list_sparse_values, feature_list_sparse_shapes,
+ feature_list_dense_values, feature_list_dense_lengths) = outputs
+
+ context_sparse_tensors = [
+ sparse_tensor.SparseTensor(ix, val, shape)
+ for (ix, val,
+ shape) in zip(context_sparse_indices, context_sparse_values,
+ context_sparse_shapes)
+ ]
+
+ feature_list_sparse_tensors = [
+ sparse_tensor.SparseTensor(ix, val, shape)
+ for (ix, val, shape
+ ) in zip(feature_list_sparse_indices, feature_list_sparse_values,
+ feature_list_sparse_shapes)
+ ]
+
+ context_output = dict(
+ zip(context_sparse_keys + context_dense_keys,
+ context_sparse_tensors + context_dense_values))
+ feature_list_output = dict(
+ zip(feature_list_sparse_keys + feature_list_dense_keys,
+ feature_list_sparse_tensors + feature_list_dense_values))
+ feature_list_lengths = dict(
+ zip(feature_list_dense_keys, feature_list_dense_lengths))
+
+ return (context_output, feature_list_output, feature_list_lengths)
+
+
+# TODO(sundberg): rewrite this method to call the batch version, which is more
+# efficient especially for large inputs.
@tf_export("parse_single_sequence_example")
def parse_single_sequence_example(
serialized, context_features=None, sequence_features=None,
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index d533731c07..55c2eb5fa4 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -94,26 +94,8 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
ops.set_shape_and_handle_data_for_outputs(h.op)
handle._handle_data = h._handle_data
# pylint: enable=protected-access
-
- # Clean up our reference cycles to avoid making the garbage collector run.
- # pylint: disable=protected-access
- # OrderedDict, constructed on Graph creation, makes a simple reference loop
- # and hides it in an __attribute in some Python versions. We don't need to
- # throw an error if we can't find it, but if we do find it we can break the
- # loop to avoid creating work for the garbage collector.
- problematic_cycle = graph._functions.__dict__.get("_OrderedDict__root", None)
- # pylint: enable=protected-access
- if problematic_cycle:
- try:
- del problematic_cycle[0][:]
- except TypeError:
- # This is probably not one of the problematic Python versions. Continue
- # with the rest of our cleanup.
- pass
- # Now clean up our own reference cycles by clearing all of the attributes for
- # the Graph and op we created.
- h.__dict__ = {}
- graph.__dict__ = {}
+ # Clean up op->graph->op reference cycles.
+ ops.dismantle_graph(graph)
return handle
@@ -185,7 +167,8 @@ def shape_safe_assign_variable_handle(handle, shape, value, name=None):
class ResourceVariable(variables.RefVariable):
"""Variable based on resource handles.
- See the @{$variables$Variables How To} for a high level overview.
+ See the [Variables How To](https://tensorflow.org/guide/variables)
+ for a high level overview.
A `ResourceVariable` allows you to maintain state across subsequent calls to
session.run.
@@ -372,6 +355,15 @@ class ResourceVariable(variables.RefVariable):
raise ValueError("initial_value must be specified.")
init_from_fn = callable(initial_value)
+ if isinstance(initial_value, ops.Tensor) and hasattr(
+ initial_value, "graph") and initial_value.graph.building_function:
+ raise ValueError("Tensor-typed variable initializers must either be "
+ "wrapped in an init_scope or callable "
+ "(e.g., `tf.Variable(lambda : "
+ "tf.truncated_normal([10, 40]))`) when building "
+ "functions. Please file a feature request if this "
+ "restriction inconveniences you.")
+
if collections is None:
collections = [ops.GraphKeys.GLOBAL_VARIABLES]
if not isinstance(collections, (list, tuple, set)):
@@ -603,6 +595,22 @@ class ResourceVariable(variables.RefVariable):
def __bool__(self):
return bool(self.read_value())
+ def __copy__(self):
+ return self
+
+ def __deepcopy__(self, memo):
+ if not context.executing_eagerly():
+ raise NotImplementedError(
+ "__deepcopy__() is only available when eager execution is enabled.")
+ copied_variable = ResourceVariable(
+ initial_value=self.read_value(),
+ trainable=self._trainable,
+ constraint=self._constraint,
+ dtype=self._dtype,
+ name=self._shared_name + "_copy")
+ memo[self._unique_id] = copied_variable
+ return copied_variable
+
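For illustration, a minimal sketch (assuming TF 1.x with eager execution enabled, per the guard above):

```python
# Sketch: deep copies of ResourceVariables are only supported eagerly.
import copy
import tensorflow as tf

tf.enable_eager_execution()
v = tf.Variable([1.0, 2.0])
w = copy.deepcopy(v)   # independent variable initialized from v's current value
w.assign([5.0, 6.0])   # leaves v unchanged
```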
@property
def dtype(self):
"""The dtype of this variable."""
@@ -742,7 +750,7 @@ class ResourceVariable(variables.RefVariable):
def _read_variable_op(self):
if self.trainable:
- tape.watch_variable(self)
+ tape.variable_accessed(self)
result = gen_resource_variable_ops.read_variable_op(self._handle,
self._dtype)
if not context.executing_eagerly():
@@ -773,7 +781,7 @@ class ResourceVariable(variables.RefVariable):
"""Reads the value of this variable sparsely, using `gather`."""
with ops.name_scope("Gather" if name is None else name) as name:
if self.trainable:
- tape.watch_variable(self)
+ tape.variable_accessed(self)
value = gen_resource_variable_ops.resource_gather(
self._handle, indices, dtype=self._dtype, name=name)
return array_ops.identity(value)
@@ -941,12 +949,12 @@ class ResourceVariable(variables.RefVariable):
def _lazy_read(self, op):
if self.trainable:
- tape.watch_variable(self)
+ tape.variable_accessed(self)
return _UnreadVariable(
handle=self._handle, dtype=self.dtype, shape=self._shape,
in_graph_mode=self._in_graph_mode,
deleter=self._handle_deleter if not self._in_graph_mode else None,
- parent_op=op, parent_name=self._handle_name, unique_id=self._unique_id)
+ parent_op=op, unique_id=self._unique_id)
def assign(self, value, use_locking=None, name=None, read_value=True):
"""Assigns a new value to this variable.
@@ -975,6 +983,231 @@ class ResourceVariable(variables.RefVariable):
return self._lazy_read(assign_op)
return assign_op
+ def __reduce__(self):
+ return (ResourceVariable, (self.numpy(),))
+
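A hedged sketch of what `__reduce__` enables; pickling requires eager execution since `numpy()` is used:

```python
# Sketch: pickling round-trips the variable by value.
import pickle
import tensorflow as tf

tf.enable_eager_execution()
v = tf.Variable([1.0, 2.0])
restored = pickle.loads(pickle.dumps(v))
# `restored` is a new ResourceVariable holding v's numpy() contents.
```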
+ def scatter_sub(self, sparse_delta, use_locking=False, name=None):
+ """Subtracts `IndexedSlices` from this variable.
+
+ Args:
+ sparse_delta: `IndexedSlices` to be subtracted from this variable.
+ use_locking: If `True`, use locking during the operation.
+ name: the name of the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+ the scattered subtraction has completed.
+
+ Raises:
+ ValueError: if `sparse_delta` is not an `IndexedSlices`.
+ """
+ if not isinstance(sparse_delta, ops.IndexedSlices):
+ raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
+ return self._lazy_read(gen_resource_variable_ops.resource_scatter_sub(
+ self.handle, sparse_delta.indices,
+ ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
+
+ def scatter_add(self, sparse_delta, use_locking=False, name=None):
+ """Adds `IndexedSlices` from this variable.
+
+ Args:
+ sparse_delta: `IndexedSlices` to be added to this variable.
+ use_locking: If `True`, use locking during the operation.
+ name: the name of the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+ the scattered addition has completed.
+
+ Raises:
+ ValueError: if `sparse_delta` is not an `IndexedSlices`.
+ """
+ if not isinstance(sparse_delta, ops.IndexedSlices):
+ raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
+ return self._lazy_read(gen_resource_variable_ops.resource_scatter_add(
+ self.handle, sparse_delta.indices,
+ ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
+
+ def scatter_update(self, sparse_delta, use_locking=False, name=None):
+ """Assigns `IndexedSlices` to this variable.
+
+ Args:
+ sparse_delta: `IndexedSlices` to be assigned to this variable.
+ use_locking: If `True`, use locking during the operation.
+ name: the name of the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+ the scattered assignment has completed.
+
+ Raises:
+ ValueError: if `sparse_delta` is not an `IndexedSlices`.
+ """
+ if not isinstance(sparse_delta, ops.IndexedSlices):
+ raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
+ return self._lazy_read(gen_resource_variable_ops.resource_scatter_update(
+ self.handle, sparse_delta.indices,
+ ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
+
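For illustration, a minimal sketch of the three scatter methods above (assuming TF 1.x with eager execution enabled so the updates apply immediately):

```python
import tensorflow as tf

tf.enable_eager_execution()
v = tf.Variable([1.0, 2.0, 3.0, 4.0])
delta = tf.IndexedSlices(values=tf.constant([10.0, 20.0]),
                         indices=tf.constant([0, 2]))
v.scatter_add(delta)     # -> [11.0, 2.0, 23.0, 4.0]
v.scatter_sub(delta)     # -> [1.0, 2.0, 3.0, 4.0]
v.scatter_update(delta)  # -> [10.0, 2.0, 20.0, 4.0]
```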
+ def scatter_nd_sub(self, indices, updates, name=None):
+ """Applies sparse subtraction to individual values or slices in a Variable.
+
+ `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+ `indices` must be integer tensor, containing indices into `ref`.
+ It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+ The innermost dimension of `indices` (with length `K`) corresponds to
+ indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+ dimension of `ref`.
+
+ `updates` is `Tensor` of rank `Q-1+P-K` with shape:
+
+ ```
+ [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+ ```
+
+ For example, say we want to subtract 4 scattered elements from a rank-1 tensor
+ with 8 elements. In Python, that update would look like this:
+
+ ```python
+ ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1], [7]])
+ updates = tf.constant([9, 10, 11, 12])
+ op = ref.scatter_nd_sub(indices, updates)
+ with tf.Session() as sess:
+ print(sess.run(op))
+ ```
+
+ The resulting update to ref would look like this:
+
+ [1, -9, 3, -6, -4, 6, 7, -4]
+
+ See `tf.scatter_nd` for more details about how to make updates to
+ slices.
+
+ Args:
+ indices: The indices to be used in the operation.
+ updates: The values to be used in the operation.
+ name: the name of the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+ the scattered subtraction has completed.
+ """
+ return self._lazy_read(gen_state_ops.resource_scatter_nd_sub(
+ self.handle, indices, ops.convert_to_tensor(updates, self.dtype),
+ name=name))
+
+ def scatter_nd_add(self, indices, updates, name=None):
+ """Applies sparse addition to individual values or slices in a Variable.
+
+ `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+ `indices` must be integer tensor, containing indices into `ref`.
+ It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+ The innermost dimension of `indices` (with length `K`) corresponds to
+ indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+ dimension of `ref`.
+
+ `updates` is `Tensor` of rank `Q-1+P-K` with shape:
+
+ ```
+ [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+ ```
+
+ For example, say we want to add 4 scattered elements to a rank-1 tensor with
+ 8 elements. In Python, that update would look like this:
+
+ ```python
+ ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1], [7]])
+ updates = tf.constant([9, 10, 11, 12])
+ add = ref.scatter_nd_add(indices, updates)
+ with tf.Session() as sess:
+ print(sess.run(add))
+ ```
+
+ The resulting update to ref would look like this:
+
+ [1, 13, 3, 14, 14, 6, 7, 20]
+
+ See `tf.scatter_nd` for more details about how to make updates to
+ slices.
+
+ Args:
+ indices: The indices to be used in the operation.
+ updates: The values to be used in the operation.
+ name: the name of the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+ the scattered addition has completed.
+ """
+ return self._lazy_read(gen_state_ops.resource_scatter_nd_add(
+ self.handle, indices, ops.convert_to_tensor(updates, self.dtype),
+ name=name))
+
+ def scatter_nd_update(self, indices, updates, name=None):
+ """Applies sparse assignment to individual values or slices in a Variable.
+
+ `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+ `indices` must be integer tensor, containing indices into `ref`.
+ It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+ The innermost dimension of `indices` (with length `K`) corresponds to
+ indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+ dimension of `ref`.
+
+ `updates` is `Tensor` of rank `Q-1+P-K` with shape:
+
+ ```
+ [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+ ```
+
+ For example, say we want to update 4 scattered elements in a rank-1 tensor
+ with 8 elements. In Python, that update would look like this:
+
+ ```python
+ ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1], [7]])
+ updates = tf.constant([9, 10, 11, 12])
+ op = ref.scatter_nd_update(indices, updates)
+ with tf.Session() as sess:
+ print(sess.run(op))
+ ```
+
+ The resulting update to ref would look like this:
+
+ [1, 11, 3, 10, 9, 6, 7, 12]
+
+ See `tf.scatter_nd` for more details about how to make updates to
+ slices.
+
+ Args:
+ indices: The indices to be used in the operation.
+ updates: The values to be used in the operation.
+ name: the name of the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+ the scattered assignment has completed.
+ """
+ return self._lazy_read(gen_state_ops.resource_scatter_nd_update(
+ self.handle, indices, ops.convert_to_tensor(updates, self.dtype),
+ name=name))
+
def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
end_mask, ellipsis_mask, new_axis_mask,
shrink_axis_mask):
@@ -1060,8 +1293,7 @@ class _UnreadVariable(ResourceVariable):
"""
def __init__(self, handle, dtype, # pylint: disable=super-init-not-called
- shape, in_graph_mode, deleter, parent_op, parent_name,
- unique_id):
+ shape, in_graph_mode, deleter, parent_op, unique_id):
# We do not call super init on purpose.
self._trainable = False
self._save_slice_info = None
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 7b6ab20975..5c00d929bf 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
+from tensorflow.python.keras.engine import base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
@@ -144,6 +145,28 @@ def _should_cache():
return control_flow_util.GetContainingWhileContext(ctxt) is None
+def _is_keras_rnn_cell(rnn_cell):
+ """Check whether the cell is a Keras RNN cell.
+
+ Keras RNN cells accept the state as a list even when the state is a single
+ tensor, whereas TF RNN cells do not wrap a single state tensor in a list.
+ This behavioral difference should be unified in a future version.
+
+ Args:
+ rnn_cell: An RNN cell instance that follows either the Keras interface or
+ the TF RNN interface.
+ Returns:
+ Boolean, whether the cell is a Keras RNN cell.
+ """
+ # A type check alone is not strict enough, since other libraries (e.g.,
+ # DeepMind) create cells that do not inherit from tf.nn.rnn_cell.RNNCell.
+ # Keras cells never had a zero_state method, which comes from the original
+ # TF RNN cell interface.
+ return (not isinstance(rnn_cell, rnn_cell_impl.RNNCell)
+ and isinstance(rnn_cell, base_layer.Layer)
+ and getattr(rnn_cell, "zero_state", None) is None)
+
+
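To illustrate the kinds of cells this heuristic separates (a sketch; `_is_keras_rnn_cell` itself is private, so the comments show the expected classification):

```python
import tensorflow as tf

keras_cell = tf.keras.layers.LSTMCell(4)   # Layer subclass, no zero_state
tf_cell = tf.nn.rnn_cell.LSTMCell(4)       # RNNCell subclass, has zero_state
# The check above would return True for keras_cell and False for tf_cell.
```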
# pylint: disable=unused-argument
def _rnn_step(
time, sequence_length, min_sequence_length, max_sequence_length,
@@ -608,7 +631,11 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
else:
if not dtype:
raise ValueError("If there is no initial_state, you must give a dtype.")
- state = cell.zero_state(batch_size, dtype)
+ if getattr(cell, "get_initial_state", None) is not None:
+ state = cell.get_initial_state(
+ inputs=None, batch_size=batch_size, dtype=dtype)
+ else:
+ state = cell.zero_state(batch_size, dtype)
def _assert_has_shape(x, shape):
x_shape = array_ops.shape(x)
@@ -788,6 +815,10 @@ def _dynamic_rnn_loop(cell,
input_t = tuple(ta[time.numpy()] for ta in input_ta)
input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t)
+ # Keras RNN cells only accept state as list, even if it's a single tensor.
+ is_keras_rnn_cell = _is_keras_rnn_cell(cell)
+ if is_keras_rnn_cell and not nest.is_sequence(state):
+ state = [state]
call_cell = lambda: cell(input_t, state)
if sequence_length is not None:
@@ -804,6 +835,9 @@ def _dynamic_rnn_loop(cell,
else:
(output, new_state) = call_cell()
+ # Keras cells always wrap state as list, even if it's a single tensor.
+ if is_keras_rnn_cell and len(new_state) == 1:
+ new_state = new_state[0]
# Pack state if using state tuples
output = nest.flatten(output)
@@ -1286,7 +1320,11 @@ def static_rnn(cell,
if not dtype:
raise ValueError("If no initial_state is provided, "
"dtype must be specified")
- state = cell.zero_state(batch_size, dtype)
+ if getattr(cell, "get_initial_state", None) is not None:
+ state = cell.get_initial_state(
+ inputs=None, batch_size=batch_size, dtype=dtype)
+ else:
+ state = cell.zero_state(batch_size, dtype)
if sequence_length is not None: # Prepare variables
sequence_length = ops.convert_to_tensor(
@@ -1315,6 +1353,10 @@ def static_rnn(cell,
min_sequence_length = math_ops.reduce_min(sequence_length)
max_sequence_length = math_ops.reduce_max(sequence_length)
+ # Keras RNN cells only accept state as list, even if it's a single tensor.
+ is_keras_rnn_cell = _is_keras_rnn_cell(cell)
+ if is_keras_rnn_cell and not nest.is_sequence(state):
+ state = [state]
for time, input_ in enumerate(inputs):
if time > 0:
varscope.reuse_variables()
@@ -1333,8 +1375,10 @@ def static_rnn(cell,
state_size=cell.state_size)
else:
(output, state) = call_cell()
-
outputs.append(output)
+ # Keras RNN cells only return state as list, even if it's a single tensor.
+ if is_keras_rnn_cell and len(state) == 1:
+ state = state[0]
return (outputs, state)
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 85a6a2233c..c11c9ccaae 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -80,13 +80,13 @@ def assert_like_rnncell(cell_name, cell):
conditions = [
hasattr(cell, "output_size"),
hasattr(cell, "state_size"),
- hasattr(cell, "zero_state"),
+ hasattr(cell, "get_initial_state") or hasattr(cell, "zero_state"),
callable(cell),
]
errors = [
"'output_size' property is missing",
"'state_size' property is missing",
- "'zero_state' method is missing",
+ "either 'zero_state' or 'get_initial_state' method is required",
"is not callable"
]
@@ -266,6 +266,36 @@ class RNNCell(base_layer.Layer):
# self.add_variable() inside the call() method.
pass
+ def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
+ """Returns the initial (zero) state, inferring batch size and dtype.
+
+ Either `inputs` or both `batch_size` and `dtype` must be provided; when
+ `inputs` is given, any explicit `batch_size` or `dtype` is validated
+ against it.
+ """
+ if inputs is not None:
+ # Validate the given batch_size and dtype against inputs if provided.
+ inputs = ops.convert_to_tensor(inputs, name="inputs")
+ if batch_size is not None:
+ if tensor_util.is_tensor(batch_size):
+ static_batch_size = tensor_util.constant_value(
+ batch_size, partial=True)
+ else:
+ static_batch_size = batch_size
+ if inputs.shape[0].value != static_batch_size:
+ raise ValueError(
+ "batch size from input tensor is different from the "
+ "input param. Input tensor batch: {}, batch_size: {}".format(
+ inputs.shape[0].value, batch_size))
+
+ if dtype is not None and inputs.dtype != dtype:
+ raise ValueError(
+ "dtype from input tensor is different from the "
+ "input param. Input tensor dtype: {}, dtype: {}".format(
+ inputs.dtype, dtype))
+
+ batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
+ dtype = inputs.dtype
+ if None in [batch_size, dtype]:
+ raise ValueError(
+ "batch_size and dtype cannot be None while constructing initial "
+ "state: batch_size={}, dtype={}".format(batch_size, dtype))
+ return self.zero_state(batch_size, dtype)
+
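A brief sketch of the fallback path this method provides (TF 1.x names):

```python
import tensorflow as tf

cell = tf.nn.rnn_cell.BasicRNNCell(8)
x = tf.zeros([32, 16])                    # [batch_size, input_depth]
state = cell.get_initial_state(inputs=x)  # == cell.zero_state(32, tf.float32)
# state: zeros of shape [32, 8]
```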
def zero_state(self, batch_size, dtype):
"""Return zero-filled state tensor(s).
@@ -344,6 +374,9 @@ class LayerRNNCell(RNNCell):
class BasicRNNCell(LayerRNNCell):
"""The most basic RNN cell.
+ Note that this cell is not optimized for performance. Please use
+ `tf.contrib.cudnn_rnn.CudnnRNNTanh` for better performance on GPU.
+
Args:
num_units: int, The number of units in the RNN cell.
activation: Nonlinearity to use. Default: `tanh`. It could also be string
@@ -369,6 +402,10 @@ class BasicRNNCell(LayerRNNCell):
**kwargs):
super(BasicRNNCell, self).__init__(
_reuse=reuse, name=name, dtype=dtype, **kwargs)
+ if context.executing_eagerly() and context.num_gpus() > 0:
+ logging.warn("%s: Note that this cell is not optimized for performance. "
+ "Please use tf.contrib.cudnn_rnn.CudnnRNNTanh for better "
+ "performance on GPU.", self)
# Inputs must be 2-dimensional.
self.input_spec = base_layer.InputSpec(ndim=2)
@@ -391,7 +428,7 @@ class BasicRNNCell(LayerRNNCell):
def build(self, inputs_shape):
if inputs_shape[-1] is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
- % inputs_shape)
+ % str(inputs_shape))
input_depth = inputs_shape[-1]
self._kernel = self.add_variable(
@@ -427,6 +464,10 @@ class BasicRNNCell(LayerRNNCell):
class GRUCell(LayerRNNCell):
"""Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).
+ Note that this cell is not optimized for performance. Please use
+ `tf.contrib.cudnn_rnn.CudnnGRU` for better performance on GPU, or
+ `tf.contrib.rnn.GRUBlockCellV2` for better performance on CPU.
+
Args:
num_units: int, The number of units in the GRU cell.
activation: Nonlinearity to use. Default: `tanh`.
@@ -457,6 +498,10 @@ class GRUCell(LayerRNNCell):
super(GRUCell, self).__init__(
_reuse=reuse, name=name, dtype=dtype, **kwargs)
+ if context.executing_eagerly() and context.num_gpus() > 0:
+ logging.warn("%s: Note that this cell is not optimized for performance. "
+ "Please use tf.contrib.cudnn_rnn.CudnnGRU for better "
+ "performance on GPU.", self)
# Inputs must be 2-dimensional.
self.input_spec = base_layer.InputSpec(ndim=2)
@@ -480,7 +525,7 @@ class GRUCell(LayerRNNCell):
def build(self, inputs_shape):
if inputs_shape[-1] is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
- % inputs_shape)
+ % str(inputs_shape))
input_depth = inputs_shape[-1]
self._gate_kernel = self.add_variable(
@@ -580,6 +625,11 @@ class BasicLSTMCell(LayerRNNCell):
For advanced models, please use the full `tf.nn.rnn_cell.LSTMCell`
that follows.
+
+ Note that this cell is not optimized for performance. Please use
+ `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or
+ `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for
+ better performance on CPU.
"""
@deprecated(None, "This class is deprecated, please use "
@@ -626,6 +676,10 @@ class BasicLSTMCell(LayerRNNCell):
if not state_is_tuple:
logging.warn("%s: Using a concatenated state is slower and will soon be "
"deprecated. Use state_is_tuple=True.", self)
+ if context.executing_eagerly() and context.num_gpus() > 0:
+ logging.warn("%s: Note that this cell is not optimized for performance. "
+ "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better "
+ "performance on GPU.", self)
# Inputs must be 2-dimensional.
self.input_spec = base_layer.InputSpec(ndim=2)
@@ -651,7 +705,7 @@ class BasicLSTMCell(LayerRNNCell):
def build(self, inputs_shape):
if inputs_shape[-1] is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
- % inputs_shape)
+ % str(inputs_shape))
input_depth = inputs_shape[-1]
h_depth = self._num_units
@@ -729,10 +783,10 @@ class LSTMCell(LayerRNNCell):
The default non-peephole implementation is based on:
- http://www.bioinf.jku.at/publications/older/2604.pdf
+ https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
- S. Hochreiter and J. Schmidhuber.
- "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+ Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
+ "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
The peephole implementation is based on:
@@ -744,6 +798,11 @@ class LSTMCell(LayerRNNCell):
The class uses optional peep-hole connections, optional cell clipping, and
an optional projection layer.
+
+ Note that this cell is not optimized for performance. Please use
+ `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or
+ `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for
+ better performance on CPU.
"""
def __init__(self, num_units,
@@ -803,6 +862,10 @@ class LSTMCell(LayerRNNCell):
"%s: The num_unit_shards and proj_unit_shards parameters are "
"deprecated and will be removed in Jan 2017. "
"Use a variable scope with a partitioner instead.", self)
+ if context.executing_eagerly() and context.num_gpus() > 0:
+ logging.warn("%s: Note that this cell is not optimized for performance. "
+ "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better "
+ "performance on GPU.", self)
# Inputs must be 2-dimensional.
self.input_spec = base_layer.InputSpec(ndim=2)
@@ -845,7 +908,7 @@ class LSTMCell(LayerRNNCell):
def build(self, inputs_shape):
if inputs_shape[-1] is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
- % inputs_shape)
+ % str(inputs_shape))
input_depth = inputs_shape[-1]
h_depth = self._num_units if self._num_proj is None else self._num_proj
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index d11e446dbf..2ec4b540fb 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Script Language Operators. See the @{$python/script_ops} guide."""
+"""Script Language Operators."""
# pylint: disable=g-bad-name
from __future__ import absolute_import
@@ -287,19 +287,19 @@ def _internal_py_func(func,
# TODO(akshayka): Implement higher-order derivatives.
@ops.RegisterGradient("EagerPyFunc")
-def _EagerPyFuncGrad(op, dy):
+def _EagerPyFuncGrad(op, *dy):
"""Computes the gradient of an EagerPyFunc."""
token = op.get_attr("token")
- def eagerly_executed_grad(dy):
+ def eagerly_executed_grad(*dy):
tape, eager_inputs, eager_outputs = tape_cache.pop(compat.as_bytes(token))
return tape.gradient(eager_outputs, eager_inputs, output_gradients=dy)
with ops.control_dependencies(op.outputs):
return _internal_py_func(
func=eagerly_executed_grad,
- inp=[dy] if isinstance(dy, ops.Tensor) else dy,
+ inp=dy,
Tout=[tensor.dtype for tensor in op.inputs],
eager=True,
is_grad_func=True)
@@ -343,7 +343,8 @@ def eager_py_func(func, inp, Tout, name=None):
or print statements as desired, and wrap those functions in
`tf.contrib.eager.py_func`.
- For more information on eager execution, see @{$guide/eager}.
+ For more information on eager execution, see the
+ [Eager guide](https://tensorflow.org/guide/eager).
`tf.contrib.eager.py_func` is similar in spirit to `tf.py_func`, but unlike
the latter, the former lets you use TensorFlow operations in the wrapped
diff --git a/tensorflow/python/ops/session_ops.py b/tensorflow/python/ops/session_ops.py
index dee84bab0c..e229501c10 100644
--- a/tensorflow/python/ops/session_ops.py
+++ b/tensorflow/python/ops/session_ops.py
@@ -13,7 +13,11 @@
# limitations under the License.
# ==============================================================================
-"""Tensor Handle Operations. See the @{$python/session_ops} guide."""
+"""Tensor Handle Operations.
+
+See the [Session Ops](https://tensorflow.org/api_guides/python/session_ops)
+guide.
+"""
# pylint: disable=g-bad-name
from __future__ import absolute_import
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index 44b66a3049..400a42a3c0 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -14,7 +14,10 @@
# ==============================================================================
# pylint: disable=g-short-docstring-punctuation
-"""Sparse Tensor Representation. See the @{$python/sparse_ops} guide."""
+"""Sparse Tensor Representation.
+
+See the [Sparse Ops](https://tensorflow.org/api_guides/python/sparse_ops) guide.
+"""
from __future__ import absolute_import
from __future__ import division
@@ -38,6 +41,7 @@ from tensorflow.python.ops import math_ops
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_sparse_ops import *
# pylint: enable=wildcard-import
+from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -82,6 +86,104 @@ def _convert_to_sparse_tensors(sp_inputs):
raise TypeError("Inputs must be a list or tuple.")
+def _make_int64_tensor(value, name):
+ if isinstance(value, compat.integral_types):
+ return ops.convert_to_tensor(value, name=name, dtype=dtypes.int64)
+ if not isinstance(value, ops.Tensor):
+ raise TypeError("{} must be an integer value".format(name))
+ if value.dtype == dtypes.int64:
+ return value
+ return math_ops.cast(value, dtypes.int64)
+
+
+@tf_export("sparse.expand_dims")
+def sparse_expand_dims(sp_input, axis=None, name=None):
+ """Inserts a dimension of 1 into a tensor's shape.
+
+ Given a tensor `sp_input`, this operation inserts a dimension of 1 at the
+ dimension index `axis` of `sp_input`'s shape. The dimension index `axis`
+ starts at zero; if you specify a negative number for `axis` it is counted
+ backwards from the end.
+
+ Args:
+ sp_input: A `SparseTensor`.
+ axis: 0-D (scalar). Specifies the dimension index at which to expand the
+ shape of `sp_input`. Must be in the range `[-rank(sp_input) - 1,
+ rank(sp_input)]`.
+ name: The name of the output `SparseTensor`.
+
+ Returns:
+ A `SparseTensor` with the same data as `sp_input`, but its shape has an
+ additional dimension of size 1 added.
+ """
+ rank = sp_input.dense_shape.get_shape()[0]
+ axis = -1 if axis is None else axis
+
+ with ops.name_scope(name, default_name="expand_dims", values=[sp_input]):
+ if isinstance(axis, compat.integral_types):
+ axis = ops.convert_to_tensor(axis, name="axis", dtype=dtypes.int32)
+ elif not isinstance(axis, ops.Tensor):
+ raise TypeError("axis must be an integer value in range [-rank(sp_input)"
+ " - 1, rank(sp_input)]")
+
+ # Convert axis to a positive value if it is negative.
+ axis = array_ops.where(axis >= 0, axis, axis + rank + 1)
+
+ # Create the new column of indices for the sparse tensor by slicing
+ # the indices and inserting a new column of indices for the new dimension.
+ column_size = array_ops.shape(sp_input.indices)[0]
+ new_index = array_ops.zeros([column_size, 1], dtype=dtypes.int64)
+ indices_before = array_ops.slice(sp_input.indices, [0, 0], [-1, axis])
+ indices_after = array_ops.slice(sp_input.indices, [0, axis], [-1, -1])
+ indices = array_ops.concat(
+ [indices_before, new_index, indices_after], axis=1)
+
+ # Create the new dense shape by splicing the tensor [1] in the correct
+ # dimension of the existing shape.
+ shape_before = array_ops.slice(sp_input.dense_shape, [0], [axis])
+ shape_after = array_ops.slice(sp_input.dense_shape, [axis], [-1])
+ new_shape = ops.convert_to_tensor([1], name="new_shape", dtype=dtypes.int64)
+ shape = array_ops.concat([shape_before, new_shape, shape_after], axis=0)
+
+ # Create the output sparse tensor.
+ return sparse_tensor.SparseTensor(
+ indices=indices, values=sp_input.values, dense_shape=shape)
+
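A usage sketch, assuming the op is exported as `tf.sparse.expand_dims` per the `tf_export` decorator above:

```python
import tensorflow as tf

sp = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1.0, 2.0],
                     dense_shape=[2, 3])
expanded = tf.sparse.expand_dims(sp, axis=1)
# expanded.dense_shape == [2, 1, 3]; values and their ordering are unchanged.
```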
+
+@tf_export("sparse.eye")
+def sparse_eye(num_rows,
+ num_columns=None,
+ dtype=dtypes.float32,
+ name=None):
+ """Creates a two-dimensional sparse tensor with ones along the diagonal.
+
+ Args:
+ num_rows: Non-negative integer or `int32` scalar `tensor` giving the number
+ of rows in the resulting matrix.
+ num_columns: Optional non-negative integer or `int32` scalar `tensor` giving
+ the number of columns in the resulting matrix. Defaults to `num_rows`.
+ dtype: The type of element in the resulting `Tensor`.
+ name: A name for this `Op`. Defaults to "eye".
+
+ Returns:
+ A `SparseTensor` of shape [num_rows, num_columns] with ones along the
+ diagonal.
+ """
+ with ops.name_scope(name, default_name="eye", values=[num_rows, num_columns]):
+ num_rows = _make_int64_tensor(num_rows, "num_rows")
+ num_columns = num_rows if num_columns is None else _make_int64_tensor(
+ num_columns, "num_columns")
+
+ # Create the sparse tensor.
+ diag_size = math_ops.minimum(num_rows, num_columns)
+ diag_range = math_ops.range(diag_size, dtype=dtypes.int64)
+
+ return sparse_tensor.SparseTensor(
+ indices=array_ops.stack([diag_range, diag_range], axis=1),
+ values=array_ops.ones(diag_size, dtype=dtype),
+ dense_shape=[num_rows, num_columns])
+
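A usage sketch, assuming the `tf.sparse.eye` export above; `tf.sparse_tensor_to_dense` is used only to show the result:

```python
import tensorflow as tf

s = tf.sparse.eye(3, num_columns=4)   # 3x4 with ones on the main diagonal
dense = tf.sparse_tensor_to_dense(s)
# [[1., 0., 0., 0.],
#  [0., 1., 0., 0.],
#  [0., 0., 1., 0.]]
```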
+
# pylint: disable=protected-access
@tf_export("sparse_concat")
@deprecation.deprecated_args(
@@ -787,6 +889,9 @@ def sparse_reduce_max(sp_input, axis=None, keepdims=None,
`tf.reduce_max()`. In particular, this Op also returns a dense `Tensor`
instead of a sparse one.
+ Note: A gradient is not defined for this function, so it can't be used
+ in training models that need gradient descent.
+
Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless
`keepdims` is true, the rank of the tensor is reduced by 1 for each entry in
`reduction_axes`. If `keepdims` is true, the reduced dimensions are retained
@@ -854,6 +959,9 @@ def sparse_reduce_max_sparse(sp_input,
`tf.reduce_max()`. In contrast to SparseReduceSum, this Op returns a
SparseTensor.
+ Note: A gradient is not defined for this function, so it can't be used
+ in training models that need gradient descent.
+
Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless
`keepdims` is true, the rank of the tensor is reduced by 1 for each entry in
`reduction_axes`. If `keepdims` is true, the reduced dimensions are retained
@@ -955,6 +1063,9 @@ def sparse_reduce_sum_sparse(sp_input,
`tf.reduce_sum()`. In contrast to SparseReduceSum, this Op returns a
SparseTensor.
+ Note: A gradient is not defined for this function, so it can't be used
+ in training models that need gradient descent.
+
Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless
`keepdims` is true, the rank of the tensor is reduced by 1 for each entry in
`reduction_axes`. If `keepdims` is true, the reduced dimensions are retained
@@ -1240,7 +1351,11 @@ def sparse_merge(sp_ids, sp_values, vocab_size, name=None,
new_shape = array_ops.concat([sp_ids[0].dense_shape[:-1], vocab_size], 0)
result = sparse_tensor.SparseTensor(new_indices, new_values, new_shape)
- return result if already_sorted else sparse_reorder(result)
+ if already_sorted:
+ return result
+ sorted_result = sparse_reorder(result)
+ return sparse_tensor.SparseTensor(
+ sorted_result.indices, sorted_result.values, new_shape)
@tf_export("sparse_retain")
diff --git a/tensorflow/python/ops/sparse_ops_test.py b/tensorflow/python/ops/sparse_ops_test.py
new file mode 100644
index 0000000000..4ee1569249
--- /dev/null
+++ b/tensorflow/python/ops/sparse_ops_test.py
@@ -0,0 +1,81 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for sparse ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.platform import googletest
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class SparseOpsTest(test_util.TensorFlowTestCase):
+
+ def testSparseEye(self):
+ def test_one(n, m, as_tensors):
+ expected = np.eye(n, m)
+ if as_tensors:
+ m = constant_op.constant(m)
+ n = constant_op.constant(n)
+ s = sparse_ops.sparse_eye(n, m)
+ d = sparse_ops.sparse_to_dense(s.indices, s.dense_shape, s.values)
+ self.assertAllEqual(self.evaluate(d), expected)
+
+ for n in range(2, 10, 2):
+ for m in range(2, 10, 2):
+ # Test with n and m as both constants and tensors.
+ test_one(n, m, True)
+ test_one(n, m, False)
+
+ def testSparseExpandDims(self):
+ for rank in range(1, 4):
+ # Create a dummy input. When rank=3, shape=[2, 4, 6].
+ shape = np.arange(1, rank + 1) * 2
+ before = np.arange(np.prod(shape)).reshape(shape)
+
+ # Make entries sparse.
+ before *= np.random.binomial(1, .2, before.shape)
+ dense_shape = before.shape
+ indices = np.array(np.where(before)).T
+ values = before[before != 0]
+
+ # Try every possible valid value of axis.
+ for axis in range(-rank - 1, rank + 1):
+ expected_after = np.expand_dims(before, axis)
+
+ for axis_as_tensor in [False, True]:
+ dense_shape_t = constant_op.constant(dense_shape, dtype=dtypes.int64)
+ indices_t = constant_op.constant(indices)
+ values_t = constant_op.constant(values)
+ before_t = sparse_tensor.SparseTensor(
+ indices=indices_t, values=values_t, dense_shape=dense_shape_t)
+
+ if axis_as_tensor:
+ axis = constant_op.constant(axis)
+
+ s = sparse_ops.sparse_expand_dims(before_t, axis)
+ d = sparse_ops.sparse_to_dense(s.indices, s.dense_shape, s.values)
+ self.assertAllEqual(self.evaluate(d), expected_after)
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index d556d11a1b..920047f38b 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -13,7 +13,10 @@
# limitations under the License.
# ==============================================================================
-"""Variables. See the @{$python/state_ops} guide."""
+"""Variables.
+
+See the [Variables](https://tensorflow.org/api_guides/python/state_ops) guide.
+"""
from __future__ import absolute_import
from __future__ import division
@@ -21,13 +24,15 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_state_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_state_ops import *
-from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
+from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access,g-doc-return-or-yield,g-doc-args
@@ -126,7 +131,7 @@ def is_variable_initialized(ref, name=None):
return ref.is_initialized(name=name)
-@tf_export("assign_sub")
+@tf_export(v1=["assign_sub"])
def assign_sub(ref, value, use_locking=None, name=None):
"""Update 'ref' by subtracting 'value' from it.
@@ -155,7 +160,7 @@ def assign_sub(ref, value, use_locking=None, name=None):
return ref.assign_sub(value)
-@tf_export("assign_add")
+@tf_export(v1=["assign_add"])
def assign_add(ref, value, use_locking=None, name=None):
"""Update 'ref' by adding 'value' to it.
@@ -184,7 +189,7 @@ def assign_add(ref, value, use_locking=None, name=None):
return ref.assign_add(value)
-@tf_export("assign")
+@tf_export(v1=["assign"])
def assign(ref, value, validate_shape=None, use_locking=None, name=None):
"""Update 'ref' by assigning 'value' to it.
@@ -217,7 +222,7 @@ def assign(ref, value, validate_shape=None, use_locking=None, name=None):
return ref.assign(value, name=name)
-@tf_export("count_up_to")
+@tf_export(v1=["count_up_to"])
def count_up_to(ref, limit, name=None):
r"""Increments 'ref' until it reaches 'limit'.
@@ -240,7 +245,7 @@ def count_up_to(ref, limit, name=None):
ref.handle, limit, T=ref.dtype, name=name)
-@tf_export("scatter_update")
+@tf_export(v1=["scatter_update"])
def scatter_update(ref, indices, updates, use_locking=True, name=None):
# pylint: disable=line-too-long
r"""Applies sparse updates to a variable reference.
@@ -294,7 +299,7 @@ def scatter_update(ref, indices, updates, use_locking=True, name=None):
name=name))
-@tf_export("scatter_nd_update")
+@tf_export(v1=["scatter_nd_update"])
def scatter_nd_update(ref, indices, updates, use_locking=True, name=None):
r"""Applies sparse `updates` to individual values or slices in a Variable.
@@ -356,7 +361,7 @@ def scatter_nd_update(ref, indices, updates, use_locking=True, name=None):
name=name))
-@tf_export("scatter_add")
+@tf_export(v1=["scatter_add"])
def scatter_add(ref, indices, updates, use_locking=False, name=None):
# pylint: disable=line-too-long
r"""Adds sparse updates to the variable referenced by `resource`.
@@ -408,7 +413,7 @@ def scatter_add(ref, indices, updates, use_locking=False, name=None):
name=name))
-@tf_export("scatter_nd_add")
+@tf_export(v1=["scatter_nd_add"])
def scatter_nd_add(ref, indices, updates, use_locking=False, name=None):
r"""Applies sparse addition to individual values or slices in a Variable.
@@ -472,7 +477,7 @@ def scatter_nd_add(ref, indices, updates, use_locking=False, name=None):
name=name))
-@tf_export("scatter_sub")
+@tf_export(v1=["scatter_sub"])
def scatter_sub(ref, indices, updates, use_locking=False, name=None):
r"""Subtracts sparse updates to a variable reference.
@@ -524,3 +529,164 @@ def scatter_sub(ref, indices, updates, use_locking=False, name=None):
return ref._lazy_read(gen_resource_variable_ops.resource_scatter_sub( # pylint: disable=protected-access
ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
name=name))
+
+
+@tf_export(v1=["scatter_nd_sub"])
+def scatter_nd_sub(ref, indices, updates, use_locking=False, name=None):
+ r"""Applies sparse subtraction to individual values or slices in a Variable.
+
+ `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+ `indices` must be integer tensor, containing indices into `ref`.
+ It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+ The innermost dimension of `indices` (with length `K`) corresponds to
+ indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+ dimension of `ref`.
+
+ `updates` is `Tensor` of rank `Q-1+P-K` with shape:
+
+ ```
+ [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+ ```
+
+ For example, say we want to subtract 4 scattered elements from a rank-1 tensor
+ with 8 elements. In Python, that update would look like this:
+
+ ```python
+ ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1], [7]])
+ updates = tf.constant([9, 10, 11, 12])
+ op = tf.scatter_nd_sub(ref, indices, updates)
+ with tf.Session() as sess:
+ print(sess.run(op))
+ ```
+
+ The resulting update to ref would look like this:
+
+ [1, -9, 3, -6, -4, 6, 7, -4]
+
+ See `tf.scatter_nd` for more details about how to make updates to
+ slices.
+
+ Args:
+ ref: A mutable `Tensor`. Must be one of the following types: `float32`,
+ `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
+ `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
+ `uint32`, `uint64`. A mutable Tensor. Should be from a Variable node.
+ indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
+ A tensor of indices into ref.
+ updates: A `Tensor`. Must have the same type as `ref`.
+ A tensor of updated values to add to ref.
+ use_locking: An optional `bool`. Defaults to `False`. If `True`, the
+ assignment will be protected by a lock; otherwise the behavior is
+ undefined, but may exhibit less contention.
+ name: A name for the operation (optional).
+
+ Returns:
+ A mutable `Tensor`. Has the same type as `ref`.
+ """
+ if ref.dtype._is_ref_dtype:
+ return gen_state_ops.scatter_nd_sub(
+ ref, indices, updates, use_locking, name)
+ return ref._lazy_read(gen_state_ops.resource_scatter_nd_sub( # pylint: disable=protected-access
+ ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
+ name=name))
+
+
+@tf_export("batch_scatter_update")
+def batch_scatter_update(ref, indices, updates, use_locking=True, name=None):
+ """Generalization of `tf.scatter_update` to axis different than 0.
+
+ Analogous to `batch_gather`. This assumes that `ref`, `indices` and `updates`
+ have a series of leading dimensions that are the same for all of them, and the
+ updates are performed on the last dimension of indices. In other words, the
+ dimensions should be the following:
+
+ `num_prefix_dims = indices.ndims - 1`
+ `batch_dim = num_prefix_dims + 1`
+ `updates.shape = indices.shape + var.shape[batch_dim:]`
+
+ where
+
+ `updates.shape[:num_prefix_dims]`
+ `== indices.shape[:num_prefix_dims]`
+ `== var.shape[:num_prefix_dims]`
+
+ And the operation performed can be expressed as:
+
+ `var[i_1, ..., i_n, indices[i_1, ..., i_n, j]] = updates[i_1, ..., i_n, j]`
+
+ When indices is a 1D tensor, this operation is equivalent to
+ `tf.scatter_update`.
+
+ Two alternatives to this operation would be:
+ 1) Reshaping the variable by merging the first `ndims` dimensions. However,
+ this is not possible because `tf.reshape` returns a `Tensor`, on which
+ `tf.scatter_update` cannot be used.
+ 2) Looping over the first `ndims` of the variable and using
+ `tf.scatter_update` on the subtensors that result from slicing the first
+ dimension. This is a valid option for `ndims = 1`, but less efficient than
+ this implementation.
+
+ See also `tf.scatter_update` and `tf.scatter_nd_update`.
+
+ Args:
+ ref: `Variable` to scatter onto.
+ indices: Tensor containing indices as described above.
+ updates: Tensor of updates to apply to `ref`.
+ use_locking: Boolean indicating whether to lock the writing operation.
+ name: Optional scope name string.
+
+ Returns:
+ Ref to `variable` after it has been modified.
+
+ Raises:
+ ValueError: If the initial `ndims` of `ref`, `indices`, and `updates` are
+ not the same.
+ """
+ with ops.name_scope(name):
+ indices = ops.convert_to_tensor(indices, name="indices")
+ indices_shape = array_ops.shape(indices)
+ indices_dimensions = indices.get_shape().ndims
+
+ if indices_dimensions is None:
+ raise ValueError("batch_gather does not allow indices with unknown "
+ "shape.")
+
+ nd_indices = array_ops.expand_dims(indices, axis=-1)
+ nd_indices_list = []
+
+ # Scatter ND requires indices to have an additional dimension, in which the
+ # coordinates of the updated elements are specified. To adapt this to a
+ # scatter_update with several leading dimensions, we simply build a tf.range
+ # for each leading dimension and concatenate those coordinates with the
+ # original indices.
+
+ # For example if indices.shape = [2, 3, 4], we should generate the following
+ # indices for tf.scatter_nd_update:
+ # nd_indices[:, :, 0] = [[0, 0, 0], [1, 1, 1]]
+ # nd_indices[:, :, 1] = [[0, 1, 2], [0, 1, 2]]
+ # nd_indices[:, :, 2] = indices
+ for dimension in range(indices_dimensions - 1):
+ # In this loop we generate the following for the example (one for each
+ # iteration).
+ # nd_indices[:, :, 0] = [[0, 0, 0], [1, 1, 1]]
+ # nd_indices[:, :, 1] = [[0, 1, 2], [0, 1, 2]]
+ # This is done at every iteration with a tf.range over the size of the
+ # i-th dimension and using broadcasting over the desired shape.
+ dimension_size = indices_shape[dimension]
+ shape_to_broadcast = [1] * (indices_dimensions + 1)
+ shape_to_broadcast[dimension] = dimension_size
+ dimension_range = array_ops.reshape(
+ gen_math_ops._range(0, dimension_size, 1), shape_to_broadcast)
+ if dimension_range.dtype.base_dtype != nd_indices.dtype:
+ dimension_range = gen_math_ops.cast(dimension_range, nd_indices.dtype)
+ nd_indices_list.append(
+ dimension_range * array_ops.ones_like(nd_indices))
+ # Add the original indices at the end, as described above, and concat.
+ nd_indices_list.append(nd_indices)
+ final_indices = array_ops.concat(nd_indices_list, axis=-1)
+ return scatter_nd_update(
+ ref, final_indices, updates, use_locking=use_locking)
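For reference, a short usage sketch of the op added above (illustrative values only; assumes TF 1.x graph mode and the `tf.batch_scatter_update` export from this change):

```python
import tensorflow as tf

# Performs var[i, indices[i, j]] = updates[i, j] for every row i.
var = tf.Variable([[1., 2., 3.], [4., 5., 6.]])
indices = tf.constant([[0, 2], [1, 2]])      # leading dims match updates
updates = tf.constant([[9., 8.], [7., 6.]])
op = tf.batch_scatter_update(var, indices, updates)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(op))  # [[9., 2., 8.], [4., 7., 6.]]
```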
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index 0280c89c10..29fefbe3a5 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -15,7 +15,7 @@
"""Operations for working with string Tensors.
-See the @{$python/string_ops} guide.
+See the [Strings](https://tensorflow.org/api_guides/python/string_ops) guide.
"""
from __future__ import absolute_import
@@ -24,6 +24,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -31,6 +32,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.util import compat as util_compat
# go/tf-wildcard-import
# pylint: disable=wildcard-import
@@ -39,9 +41,73 @@ from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
+
+# pylint: disable=redefined-builtin
+def regex_full_match(input, pattern, name=None):
+ r"""Match elements of `input` with regex `pattern`.
+
+ Args:
+ input: string `Tensor`, the source strings to process.
+ pattern: string or scalar string `Tensor`, regular expression to use,
+ see more details at https://github.com/google/re2/wiki/Syntax
+ name: Name of the op.
+
+ Returns:
+ bool `Tensor` of the same shape as `input` with match results.
+ """
+ # TODO(b/112455102): Remove compat.forward_compatible once past the horizon.
+ if not compat.forward_compatible(2018, 11, 10):
+ return gen_string_ops.regex_full_match(
+ input=input, pattern=pattern, name=name)
+ if isinstance(pattern, util_compat.bytes_or_text_types):
+ # When `pattern` is static through the life of the op we can
+ # use a version which performs the expensive regex compilation once at
+ # creation time.
+ return gen_string_ops.static_regex_full_match(
+ input=input, pattern=pattern, name=name)
+ return gen_string_ops.regex_full_match(
+ input=input, pattern=pattern, name=name)
+
+regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__
+
# Expose regex_full_match in strings namespace
tf_export("strings.regex_full_match")(regex_full_match)
+
+def regex_replace(source, pattern, rewrite, replace_global=True):
+ r"""Replace elements of `source` matching regex `pattern` with `rewrite`.
+
+ Args:
+ source: string `Tensor`, the source strings to process.
+ pattern: string or scalar string `Tensor`, regular expression to use,
+ see more details at https://github.com/google/re2/wiki/Syntax
+    rewrite: string or scalar string `Tensor`, value to use in match
+      replacement; supports backslash-escaped digits (\1 to \9) which can be
+      used to insert text matching the corresponding parenthesized group.
+ replace_global: `bool`, if `True` replace all non-overlapping matches,
+ else replace only the first match.
+
+ Returns:
+ string `Tensor` of the same shape as `source` with specified replacements.
+ """
+ # TODO(b/112455102): Remove compat.forward_compatible once past the horizon.
+ if not compat.forward_compatible(2018, 10, 10):
+ return gen_string_ops.regex_replace(
+ input=source, pattern=pattern,
+ rewrite=rewrite, replace_global=replace_global)
+ if (isinstance(pattern, util_compat.bytes_or_text_types) and
+ isinstance(rewrite, util_compat.bytes_or_text_types)):
+ # When `pattern` and `rewrite` are static through the life of the op we can
+ # use a version which performs the expensive regex compilation once at
+ # creation time.
+ return gen_string_ops.static_regex_replace(
+ input=source, pattern=pattern,
+ rewrite=rewrite, replace_global=replace_global)
+ return gen_string_ops.regex_replace(
+ input=source, pattern=pattern,
+ rewrite=rewrite, replace_global=replace_global)
+
+
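A brief sketch of how the dispatch above behaves (illustrative; assumes the TF 1.x `tf.strings.regex_full_match` and `tf.regex_replace` exports and a binary past the forward-compatibility horizons in the TODOs). A plain Python-string pattern selects the static kernels, which compile the regex once at op-creation time; a `Tensor` pattern falls back to the dynamic kernels:

```python
import tensorflow as tf

strings = tf.constant(["TensorFlow", "tensorflow"])
# Python-string pattern: eligible for the static, compile-once kernel.
matched = tf.strings.regex_full_match(strings, pattern=r"[a-z]+")
# Tensor pattern: dispatches to the dynamic kernel instead.
stripped = tf.regex_replace(strings, tf.constant(r"[Ff]low"), rewrite="")
with tf.Session() as sess:
    print(sess.run(matched))   # [False  True]
    print(sess.run(stripped))  # [b'Tensor' b'tensor']
```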
@tf_export("string_split")
def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=invalid-name
"""Split elements of `source` based on `delimiter` into a `SparseTensor`.
@@ -91,6 +157,7 @@ def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=inv
shape.set_shape([2])
return sparse_tensor.SparseTensor(indices, values, shape)
+
@tf_export("strings.split")
def string_split_v2(source, sep=None, maxsplit=-1):
"""Split elements of `source` based on `sep` into a `SparseTensor`.
@@ -133,7 +200,7 @@ def string_split_v2(source, sep=None, maxsplit=-1):
second column corresponds to the index of the split component in this row.
"""
if sep is None:
- sep = ''
+ sep = ""
sep = ops.convert_to_tensor(sep, dtype=dtypes.string)
source = ops.convert_to_tensor(source, dtype=dtypes.string)
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index c248dd9172..a43676cd70 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -40,8 +40,10 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import deprecation
from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
+from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
__all__ = [
@@ -204,6 +206,42 @@ it does exist, simply return it.
"""
+_DEFAULT_USE_RESOURCE = False
+
+
+@tf_export(v1=["enable_resource_variables"])
+def enable_resource_variables():
+ """Creates resource variables by default.
+
+ Resource variables are improved versions of TensorFlow variables with a
+ well-defined memory model. Accessing a resource variable reads its value, and
+ all ops which access a specific read value of the variable are guaranteed to
+ see the same value for that tensor. Writes which happen after a read (by
+ having a control or data dependency on the read) are guaranteed not to affect
+ the value of the read tensor, and similarly writes which happen before a read
+ are guaranteed to affect the value. No guarantees are made about unordered
+ read/write pairs.
+
+  Calling tf.enable_resource_variables() lets you opt in to this TensorFlow 2.0
+  feature.
+ """
+ global _DEFAULT_USE_RESOURCE
+ _DEFAULT_USE_RESOURCE = True
+
+
+@deprecation.deprecated(
+ None, "non-resource variables are not supported in the long term")
+@tf_export(v1=["disable_resource_variables"])
+def disable_resource_variables():
+ """Opts out of resource variables.
+
+  If your code needs tf.disable_resource_variables() to be called to work
+  properly, please file a bug.
+ """
+ global _DEFAULT_USE_RESOURCE
+ _DEFAULT_USE_RESOURCE = False
+
+
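A minimal sketch of the opt-in (TF 1.x graph mode assumed; the printed class name is illustrative):

```python
import tensorflow as tf

tf.enable_resource_variables()
v = tf.get_variable("v", shape=[], initializer=tf.zeros_initializer())
print(type(v).__name__)  # 'ResourceVariable' rather than the ref-based kind
```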
class _VariableStore(object):
"""Variable store that carries a number of named Variables.
@@ -837,9 +875,6 @@ class _VariableStore(object):
raise ValueError("Variable %s does not exist, or was not created with "
"tf.get_variable(). Did you mean to set "
"reuse=tf.AUTO_REUSE in VarScope?" % name)
- if not shape.is_fully_defined() and not initializing_from_value:
- raise ValueError("Shape of a new variable (%s) must be fully defined, "
- "but instead was %s." % (name, shape))
# Create the tensor to initialize the variable with default value.
if initializer is None:
@@ -854,14 +889,23 @@ class _VariableStore(object):
# Instantiate initializer if provided initializer is a type object.
if isinstance(initializer, type(init_ops.Initializer)):
initializer = initializer(dtype=dtype)
- init_val = lambda: initializer( # pylint: disable=g-long-lambda
- shape.as_list(), dtype=dtype, partition_info=partition_info)
+ if shape and shape.is_fully_defined():
+ init_val = lambda: initializer( # pylint: disable=g-long-lambda
+ shape.as_list(), dtype=dtype, partition_info=partition_info)
+ elif not tf_inspect.getargspec(initializer).args:
+ init_val = initializer
+ else:
+ raise ValueError("You can only pass an initializer function that "
+ "expects no arguments to its callable when the "
+ "shape is not fully defined. The given initializer "
+ "function expects the following args %s" %
+ tf_inspect.getargspec(initializer).args)
variable_dtype = dtype.base_dtype
# Create the variable.
if use_resource is None:
# Set the default value if unspecified.
- use_resource = False
+ use_resource = _DEFAULT_USE_RESOURCE
v = variable(
initial_value=init_val,
name=name,
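A sketch of what the new initializer branch above permits (hypothetical variable name; TF 1.x assumed): when no fully defined shape is supplied, a zero-argument callable may be passed as the initializer, since its return value carries the shape:

```python
import tensorflow as tf

# Zero-arg callable: supplies the initial value, and hence the shape, itself.
init_fn = lambda: tf.constant([[1.0, 2.0], [3.0, 4.0]])
w = tf.get_variable("w", initializer=init_fn)  # no shape argument needed
```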
@@ -1440,12 +1484,11 @@ def get_variable(name,
aggregation=aggregation)
-get_variable_or_local_docstring = (
- """%s
+get_variable_or_local_docstring = ("""%s
%sThis function prefixes the name with the current variable scope
and performs reuse checks. See the
-@{$variables$Variable Scope How To}
+[Variable Scope How To](https://tensorflow.org/guide/variables)
for an extensive description of how reusing works. Here is a basic example:
```python
@@ -1515,6 +1558,22 @@ Args:
def custom_getter(getter, name, *args, **kwargs):
return getter(name + '_suffix', *args, **kwargs)
```
+ constraint: An optional projection function to be applied to the variable
+ after being updated by an `Optimizer` (e.g. used to implement norm
+ constraints or value constraints for layer weights). The function must
+ take as input the unprojected Tensor representing the value of the
+ variable and return the Tensor for the projected value
+ (which must have the same shape). Constraints are not safe to
+ use when doing asynchronous distributed training.
+  synchronization: Indicates when a distributed variable will be
+    synchronized. Accepted values are constants defined in the class
+ `tf.VariableSynchronization`. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ `tf.VariableAggregation`.
Returns:
The created or existing `Variable` (or `PartitionedVariable`, if a
@@ -1548,10 +1607,10 @@ def get_local_variable( # pylint: disable=missing-docstring
partitioner=None,
validate_shape=True,
use_resource=None,
- synchronization=VariableSynchronization.AUTO,
- aggregation=VariableAggregation.NONE,
custom_getter=None,
- constraint=None):
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
if collections:
collections += [ops.GraphKeys.LOCAL_VARIABLES]
else:
@@ -1895,8 +1954,8 @@ class variable_scope(object):
Variable scope allows you to create new variables and to share already created
ones while providing checks to not create or share by accident. For details,
- see the @{$variables$Variable Scope How To}, here we present only a few basic
- examples.
+ see the [Variable Scope How To](https://tensorflow.org/guide/variables), here
+ we present only a few basic examples.
Simple example of how to create a new variable:
@@ -2363,6 +2422,8 @@ def default_variable_creator(next_creator=None, **kwargs):
if use_resource is None:
use_resource = get_variable_scope().use_resource
+ if use_resource is None:
+ use_resource = _DEFAULT_USE_RESOURCE
use_resource = use_resource or context.executing_eagerly()
if use_resource:
return resource_variable_ops.ResourceVariable(
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 402ab2dd9d..7a46157739 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -30,6 +30,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
@@ -54,33 +55,47 @@ def _make_getter(captured_getter, captured_previous):
@tf_export("VariableSynchronization")
class VariableSynchronization(enum.Enum):
- """Indicates when a distributed variable will be synced."""
-
- # Indicates that the synchronization will be determined by the current
- # `DistributionStrategy` (eg. With `MirroredStrategy` this would be
- # `ON_WRITE`).
+ """Indicates when a distributed variable will be synced.
+
+  * `AUTO`: Indicates that the synchronization will be determined by the
+    current `DistributionStrategy` (e.g. with `MirroredStrategy` this would
+    be `ON_WRITE`).
+ * `NONE`: Indicates that there will only be one copy of the variable, so
+ there is no need to sync.
+ * `ON_WRITE`: Indicates that the variable will be updated across devices
+ every time it is written.
+  * `ON_READ`: Indicates that the variable will be aggregated across devices
+    when it is read (e.g. when checkpointing or when evaluating an op that
+    uses the variable).
+ """
AUTO = 0
-
- # Indicates that there will only be one copy of the variable, so there is no
- # need to sync.
NONE = 1
-
- # Indicates that the variable will be aggregated across devices
- # every time it is updated.
ON_WRITE = 2
-
- # Indicates that the variable will be aggregated across devices
- # when it is read (eg. when checkpointing or when evaluating an op that uses
- # the variable).
ON_READ = 3
@tf_export("VariableAggregation")
class VariableAggregation(enum.Enum):
- """Indicates how a distributed variable will be aggregated."""
+ """Indicates how a distributed variable will be aggregated.
+
+ `tf.contrib.distribute.DistributionStrategy` distributes a model by making
+ multiple copies (called "towers") acting data-parallel on different elements
+ of the input batch. When performing some variable-update operation, say
+ `var.assign_add(x)`, in a model, we need to resolve how to combine the
+ different values for `x` computed in the different towers.
+
+ * `NONE`: This is the default, giving an error if you use a
+ variable-update operation with multiple towers.
+ * `SUM`: Add the updates across towers.
+ * `MEAN`: Take the arithmetic mean ("average") of the updates across towers.
+ * `ONLY_FIRST_TOWER`: This is for when every tower is performing the same
+ update, but we only want to perform the update once. Used, e.g., for the
+ global step counter.
+ """
NONE = 0
SUM = 1
MEAN = 2
+ ONLY_FIRST_TOWER = 3
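A small sketch of how these enums are consumed (hedged; assumes the `synchronization`/`aggregation` parameters that this change threads through `tf.get_variable`):

```python
import tensorflow as tf

# ON_READ variables aggregate on read and must not be marked trainable.
acc = tf.get_variable(
    "acc", shape=[],
    synchronization=tf.VariableSynchronization.ON_READ,
    aggregation=tf.VariableAggregation.SUM,
    trainable=False)
```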
class VariableMetaclass(type):
@@ -135,7 +150,7 @@ class VariableMetaclass(type):
@tf_export("Variable")
class Variable(six.with_metaclass(VariableMetaclass,
checkpointable.CheckpointableBase)):
- """See the @{$variables$Variables How To} for a high level overview.
+ """See the [Variables Guide](https://tensorflow.org/guide/variables).
A variable maintains state in the graph across calls to `run()`. You add a
variable to the graph by constructing an instance of the class `Variable`.
@@ -458,7 +473,7 @@ class Variable(six.with_metaclass(VariableMetaclass,
"""
raise NotImplementedError
- def assign(self, value, use_locking=False):
+ def assign(self, value, use_locking=False, name=None, read_value=True):
"""Assigns a new value to the variable.
This is essentially a shortcut for `assign(self, value)`.
@@ -466,6 +481,9 @@ class Variable(six.with_metaclass(VariableMetaclass,
Args:
value: A `Tensor`. The new value for this variable.
use_locking: If `True`, use locking during the assignment.
+      name: The name of the operation to be created.
+      read_value: if `True`, will return something which evaluates to the
+        new value of the variable; if `False`, will return the assign op.
Returns:
A `Tensor` that will hold the new value of this variable after
@@ -473,7 +491,7 @@ class Variable(six.with_metaclass(VariableMetaclass,
"""
raise NotImplementedError
- def assign_add(self, delta, use_locking=False):
+ def assign_add(self, delta, use_locking=False, name=None, read_value=True):
"""Adds a value to this variable.
This is essentially a shortcut for `assign_add(self, delta)`.
@@ -481,6 +499,9 @@ class Variable(six.with_metaclass(VariableMetaclass,
Args:
delta: A `Tensor`. The value to add to this variable.
use_locking: If `True`, use locking during the operation.
+      name: The name of the operation to be created.
+      read_value: if `True`, will return something which evaluates to the
+        new value of the variable; if `False`, will return the assign op.
Returns:
A `Tensor` that will hold the new value of this variable after
@@ -488,7 +509,7 @@ class Variable(six.with_metaclass(VariableMetaclass,
"""
raise NotImplementedError
- def assign_sub(self, delta, use_locking=False):
+ def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
"""Subtracts a value from this variable.
This is essentially a shortcut for `assign_sub(self, delta)`.
@@ -496,6 +517,9 @@ class Variable(six.with_metaclass(VariableMetaclass,
Args:
delta: A `Tensor`. The value to subtract from this variable.
use_locking: If `True`, use locking during the operation.
+      name: The name of the operation to be created.
+      read_value: if `True`, will return something which evaluates to the
+        new value of the variable; if `False`, will return the assign op.
Returns:
A `Tensor` that will hold the new value of this variable after
@@ -503,15 +527,200 @@ class Variable(six.with_metaclass(VariableMetaclass,
"""
raise NotImplementedError
- def scatter_sub(self, sparse_delta, use_locking=False):
+ def scatter_sub(self, sparse_delta, use_locking=False, name=None):
"""Subtracts `IndexedSlices` from this variable.
- This is essentially a shortcut for `scatter_sub(self, sparse_delta.indices,
- sparse_delta.values)`.
-
Args:
sparse_delta: `IndexedSlices` to be subtracted from this variable.
use_locking: If `True`, use locking during the operation.
+ name: the name of the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+ the scattered subtraction has completed.
+
+ Raises:
+ ValueError: if `sparse_delta` is not an `IndexedSlices`.
+ """
+ raise NotImplementedError
+
+ def scatter_add(self, sparse_delta, use_locking=False, name=None):
+ """Adds `IndexedSlices` to this variable.
+
+ Args:
+      sparse_delta: `IndexedSlices` to be added to this variable.
+ use_locking: If `True`, use locking during the operation.
+ name: the name of the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+      the scattered addition has completed.
+
+ Raises:
+ ValueError: if `sparse_delta` is not an `IndexedSlices`.
+ """
+ raise NotImplementedError
+
+ def scatter_update(self, sparse_delta, use_locking=False, name=None):
+ """Assigns `IndexedSlices` to this variable.
+
+ Args:
+ sparse_delta: `IndexedSlices` to be assigned to this variable.
+ use_locking: If `True`, use locking during the operation.
+ name: the name of the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+      the scattered assignment has completed.
+
+ Raises:
+ ValueError: if `sparse_delta` is not an `IndexedSlices`.
+ """
+ raise NotImplementedError
+
+ def scatter_nd_sub(self, indices, updates, name=None):
+ """Applies sparse subtraction to individual values or slices in a Variable.
+
+ `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+    `indices` must be an integer tensor, containing indices into `ref`.
+    It must have shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+    The innermost dimension of `indices` (with length `K`) corresponds to
+    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+    dimension of `ref`.
+
+    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:
+
+ ```
+ [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+ ```
+
+    For example, say we want to subtract 4 scattered elements from a rank-1
+    tensor with 8 elements. In Python, that update would look like this:
+
+ ```python
+ ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1] ,[7]])
+ updates = tf.constant([9, 10, 11, 12])
+ op = ref.scatter_nd_sub(indices, updates)
+ with tf.Session() as sess:
+        print(sess.run(op))
+ ```
+
+ The resulting update to ref would look like this:
+
+    [1, -9, 3, -6, -4, 6, 7, -4]
+
+ See `tf.scatter_nd` for more details about how to make updates to
+ slices.
+
+ Args:
+ indices: The indices to be used in the operation.
+ updates: The values to be used in the operation.
+ name: the name of the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+ the scattered subtraction has completed.
+ """
+ raise NotImplementedError
+
+ def scatter_nd_add(self, indices, updates, name=None):
+ """Applies sparse addition to individual values or slices in a Variable.
+
+ `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+    `indices` must be an integer tensor, containing indices into `ref`.
+    It must have shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+    The innermost dimension of `indices` (with length `K`) corresponds to
+    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+    dimension of `ref`.
+
+    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:
+
+ ```
+ [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+ ```
+
+    For example, say we want to add 4 scattered elements to a rank-1 tensor
+    with 8 elements. In Python, that update would look like this:
+
+ ```python
+ ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1] ,[7]])
+ updates = tf.constant([9, 10, 11, 12])
+ add = ref.scatter_nd_add(indices, updates)
+ with tf.Session() as sess:
+        print(sess.run(add))
+ ```
+
+ The resulting update to ref would look like this:
+
+ [1, 13, 3, 14, 14, 6, 7, 20]
+
+ See `tf.scatter_nd` for more details about how to make updates to
+ slices.
+
+ Args:
+ indices: The indices to be used in the operation.
+ updates: The values to be used in the operation.
+ name: the name of the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+      the scattered addition has completed.
+ """
+ raise NotImplementedError
+
+ def scatter_nd_update(self, indices, updates, name=None):
+ """Applies sparse assignment to individual values or slices in a Variable.
+
+ `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+    `indices` must be an integer tensor, containing indices into `ref`.
+    It must have shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+    The innermost dimension of `indices` (with length `K`) corresponds to
+    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+    dimension of `ref`.
+
+    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:
+
+ ```
+ [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+ ```
+
+    For example, say we want to assign 4 scattered elements to a rank-1 tensor
+    with 8 elements. In Python, that update would look like this:
+
+ ```python
+ ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1] ,[7]])
+ updates = tf.constant([9, 10, 11, 12])
+      op = ref.scatter_nd_update(indices, updates)
+ with tf.Session() as sess:
+        print(sess.run(op))
+ ```
+
+ The resulting update to ref would look like this:
+
+ [1, 11, 3, 10, 9, 6, 7, 12]
+
+ See `tf.scatter_nd` for more details about how to make updates to
+ slices.
+
+ Args:
+ indices: The indices to be used in the operation.
+ updates: The values to be used in the operation.
+ name: the name of the operation.
Returns:
A `Tensor` that will hold the new value of this variable after
@@ -1264,7 +1473,7 @@ class RefVariable(Variable):
"""
return self._constraint
- def assign(self, value, use_locking=False):
+ def assign(self, value, use_locking=False, name=None, read_value=True):
"""Assigns a new value to the variable.
This is essentially a shortcut for `assign(self, value)`.
@@ -1272,14 +1481,21 @@ class RefVariable(Variable):
Args:
value: A `Tensor`. The new value for this variable.
use_locking: If `True`, use locking during the assignment.
+      name: The name of the operation to be created.
+      read_value: if `True`, will return something which evaluates to the
+        new value of the variable; if `False`, will return the assign op.
Returns:
A `Tensor` that will hold the new value of this variable after
the assignment has completed.
"""
- return state_ops.assign(self._variable, value, use_locking=use_locking)
+ assign = state_ops.assign(self._variable, value, use_locking=use_locking,
+ name=name)
+ if read_value:
+ return assign
+ return assign.op
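A sketch of the new `read_value` flag (illustrative; TF 1.x graph mode): returning the op rather than a read avoids materializing the updated value:

```python
import tensorflow as tf

v = tf.Variable(0)
update_op = v.assign(1, read_value=False)     # a tf.Operation, no read
new_value = v.assign_add(2, read_value=True)  # a Tensor holding v + 2
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(update_op)
    print(sess.run(new_value))  # 3
```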
- def assign_add(self, delta, use_locking=False):
+ def assign_add(self, delta, use_locking=False, name=None, read_value=True):
"""Adds a value to this variable.
This is essentially a shortcut for `assign_add(self, delta)`.
@@ -1287,14 +1503,21 @@ class RefVariable(Variable):
Args:
delta: A `Tensor`. The value to add to this variable.
use_locking: If `True`, use locking during the operation.
+      name: The name of the operation to be created.
+      read_value: if `True`, will return something which evaluates to the
+        new value of the variable; if `False`, will return the assign op.
Returns:
A `Tensor` that will hold the new value of this variable after
the addition has completed.
"""
- return state_ops.assign_add(self._variable, delta, use_locking=use_locking)
+ assign = state_ops.assign_add(
+ self._variable, delta, use_locking=use_locking, name=name)
+ if read_value:
+ return assign
+ return assign.op
- def assign_sub(self, delta, use_locking=False):
+ def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
"""Subtracts a value from this variable.
This is essentially a shortcut for `assign_sub(self, delta)`.
@@ -1302,22 +1525,75 @@ class RefVariable(Variable):
Args:
delta: A `Tensor`. The value to subtract from this variable.
use_locking: If `True`, use locking during the operation.
+      name: The name of the operation to be created.
+      read_value: if `True`, will return something which evaluates to the
+        new value of the variable; if `False`, will return the assign op.
Returns:
A `Tensor` that will hold the new value of this variable after
the subtraction has completed.
"""
- return state_ops.assign_sub(self._variable, delta, use_locking=use_locking)
+ assign = state_ops.assign_sub(
+ self._variable, delta, use_locking=use_locking, name=name)
+ if read_value:
+ return assign
+ return assign.op
- def scatter_sub(self, sparse_delta, use_locking=False):
+ def scatter_sub(self, sparse_delta, use_locking=False, name=None):
"""Subtracts `IndexedSlices` from this variable.
- This is essentially a shortcut for `scatter_sub(self, sparse_delta.indices,
- sparse_delta.values)`.
-
Args:
sparse_delta: `IndexedSlices` to be subtracted from this variable.
use_locking: If `True`, use locking during the operation.
+ name: the name of the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+ the scattered subtraction has completed.
+
+ Raises:
+ ValueError: if `sparse_delta` is not an `IndexedSlices`.
+ """
+ if not isinstance(sparse_delta, ops.IndexedSlices):
+ raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
+ return gen_state_ops.scatter_sub(
+ self._variable,
+ sparse_delta.indices,
+ sparse_delta.values,
+ use_locking=use_locking,
+ name=name)
+
+ def scatter_add(self, sparse_delta, use_locking=False, name=None):
+ """Adds `IndexedSlices` from this variable.
+
+ Args:
+ sparse_delta: `IndexedSlices` to be added to this variable.
+ use_locking: If `True`, use locking during the operation.
+ name: the name of the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+      the scattered addition has completed.
+
+ Raises:
+ ValueError: if `sparse_delta` is not an `IndexedSlices`.
+ """
+ if not isinstance(sparse_delta, ops.IndexedSlices):
+ raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
+ return gen_state_ops.scatter_add(
+ self._variable,
+ sparse_delta.indices,
+ sparse_delta.values,
+ use_locking=use_locking,
+ name=name)
+
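A usage sketch for the scatter methods above (illustrative values; TF 1.x graph mode): the delta is packaged as a `tf.IndexedSlices` of values plus indices, e.g. as produced by sparse gradients:

```python
import tensorflow as tf

v = tf.Variable([1.0, 2.0, 3.0])
delta = tf.IndexedSlices(values=tf.constant([10.0]),
                         indices=tf.constant([1]))
op = v.scatter_add(delta)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(op))  # [ 1. 12.  3.]
```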
+ def scatter_update(self, sparse_delta, use_locking=False, name=None):
+ """Assigns `IndexedSlices` to this variable.
+
+ Args:
+ sparse_delta: `IndexedSlices` to be assigned to this variable.
+ use_locking: If `True`, use locking during the operation.
+ name: the name of the operation.
Returns:
A `Tensor` that will hold the new value of this variable after
@@ -1328,11 +1604,168 @@ class RefVariable(Variable):
"""
if not isinstance(sparse_delta, ops.IndexedSlices):
raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
- return state_ops.scatter_sub(
+ return gen_state_ops.scatter_update(
self._variable,
sparse_delta.indices,
sparse_delta.values,
- use_locking=use_locking)
+ use_locking=use_locking,
+ name=name)
+
+ def scatter_nd_sub(self, indices, updates, name=None):
+ """Applies sparse subtraction to individual values or slices in a Variable.
+
+ `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+    `indices` must be an integer tensor, containing indices into `ref`.
+    It must have shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+    The innermost dimension of `indices` (with length `K`) corresponds to
+    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+    dimension of `ref`.
+
+    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:
+
+ ```
+ [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+ ```
+
+    For example, say we want to subtract 4 scattered elements from a rank-1
+    tensor with 8 elements. In Python, that update would look like this:
+
+ ```python
+ ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1] ,[7]])
+ updates = tf.constant([9, 10, 11, 12])
+ op = ref.scatter_nd_sub(indices, updates)
+ with tf.Session() as sess:
+        print(sess.run(op))
+ ```
+
+ The resulting update to ref would look like this:
+
+    [1, -9, 3, -6, -4, 6, 7, -4]
+
+ See `tf.scatter_nd` for more details about how to make updates to
+ slices.
+
+ Args:
+ indices: The indices to be used in the operation.
+ updates: The values to be used in the operation.
+ name: the name of the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+ the scattered subtraction has completed.
+ """
+ return gen_state_ops.scatter_nd_sub(
+ self._variable, indices, updates, use_locking=True, name=name)
+
+ def scatter_nd_add(self, indices, updates, name=None):
+ """Applies sparse addition to individual values or slices in a Variable.
+
+ `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+    `indices` must be an integer tensor, containing indices into `ref`.
+    It must have shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+    The innermost dimension of `indices` (with length `K`) corresponds to
+    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+    dimension of `ref`.
+
+    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:
+
+ ```
+ [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+ ```
+
+    For example, say we want to add 4 scattered elements to a rank-1 tensor
+    with 8 elements. In Python, that update would look like this:
+
+ ```python
+ ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1] ,[7]])
+ updates = tf.constant([9, 10, 11, 12])
+ add = ref.scatter_nd_add(indices, updates)
+ with tf.Session() as sess:
+        print(sess.run(add))
+ ```
+
+ The resulting update to ref would look like this:
+
+ [1, 13, 3, 14, 14, 6, 7, 20]
+
+ See `tf.scatter_nd` for more details about how to make updates to
+ slices.
+
+ Args:
+ indices: The indices to be used in the operation.
+ updates: The values to be used in the operation.
+ name: the name of the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+      the scattered addition has completed.
+ """
+ return gen_state_ops.scatter_nd_add(
+ self._variable, indices, updates, use_locking=True, name=name)
+
+ def scatter_nd_update(self, indices, updates, name=None):
+ """Applies sparse assignment to individual values or slices in a Variable.
+
+ `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+    `indices` must be an integer tensor, containing indices into `ref`.
+    It must have shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+    The innermost dimension of `indices` (with length `K`) corresponds to
+    indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+    dimension of `ref`.
+
+    `updates` is a `Tensor` of rank `Q-1+P-K` with shape:
+
+ ```
+ [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+ ```
+
+    For example, say we want to assign 4 scattered elements to a rank-1 tensor
+    with 8 elements. In Python, that update would look like this:
+
+ ```python
+ ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1] ,[7]])
+ updates = tf.constant([9, 10, 11, 12])
+ op = ref.scatter_nd_update(indices, updates)
+ with tf.Session() as sess:
+        print(sess.run(op))
+ ```
+
+ The resulting update to ref would look like this:
+
+ [1, 11, 3, 10, 9, 6, 7, 12]
+
+ See `tf.scatter_nd` for more details about how to make updates to
+ slices.
+
+ Args:
+ indices: The indices to be used in the operation.
+ updates: The values to be used in the operation.
+ name: the name of the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+      the scattered assignment has completed.
+ """
+ return gen_state_ops.scatter_nd_update(
+ self._variable, indices, updates, use_locking=True, name=name)
def _strided_slice_assign(self,
begin,
@@ -1917,10 +2350,15 @@ class PartitionedVariable(object):
def as_tensor(self):
"""Returns the overall concatenated value as a `Tensor`.
+    The returned tensor will not inherit the control dependencies from the
+    scope where the value is used, which is similar to getting the value of
+    a `Variable`.
+
Returns:
`Tensor` containing the concatenated value.
"""
- return self._concat()
+ with ops.control_dependencies(None):
+ return self._concat()
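A sketch of what wrapping the concat in `control_dependencies(None)` accomplishes (hypothetical ops; TF 1.x graph mode): the `None` scope clears surrounding control dependencies, so the concatenated read does not inherit them:

```python
import tensorflow as tf

dep = tf.no_op()
with tf.control_dependencies([dep]):
    with tf.control_dependencies(None):  # clears the outer dependency
        read = tf.identity(tf.constant(1.0))
print(read.op.control_inputs)  # []
```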
@staticmethod
def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):
diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py
index 9ffb48c4a5..5dc4037d62 100644
--- a/tensorflow/python/platform/test.py
+++ b/tensorflow/python/platform/test.py
@@ -15,7 +15,7 @@
"""Testing.
-See the @{$python/test} guide.
+See the [Testing](https://tensorflow.org/api_guides/python/test) guide.
Note: `tf.test.mock` is an alias to the python `mock` or `unittest.mock`
depending on the python version.
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 157f2341e0..be8f425481 100644..100755
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -50,11 +50,12 @@ limitations under the License.
%rename("%s") TFE_Py_TapeSetRestartOnThread;
%rename("%s") TFE_Py_TapeSetIsEmpty;
%rename("%s") TFE_Py_TapeSetShouldRecord;
-%rename("%s") TFE_Py_TapeSetWatch;
%rename("%s") TFE_Py_TapeSetDeleteTrace;
%rename("%s") TFE_Py_TapeSetRecordOperation;
-%rename("%s") TFE_Py_TapeSetWatchVariable;
%rename("%s") TFE_Py_TapeGradient;
+%rename("%s") TFE_Py_TapeVariableAccessed;
+%rename("%s") TFE_Py_TapeWatch;
+%rename("%s") TFE_Py_TapeWatchVariable;
%rename("%s") TFE_Py_TapeWatchedVariables;
%rename("%s") TFE_NewContextOptions;
%rename("%s") TFE_ContextOptionsSetConfig;
@@ -65,6 +66,7 @@ limitations under the License.
%rename("%s") TFE_Py_TensorShapeOnDevice;
%rename("%s") TFE_ContextStartStep;
%rename("%s") TFE_ContextEndStep;
+%rename("%s") TFE_Py_RegisterVSpace;
%{
#include "tensorflow/python/eager/pywrap_tfe.h"
@@ -105,20 +107,29 @@ limitations under the License.
}
}
+// For const parameters in a function, SWIG pretty much ignores the const.
+// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
+// Hence the 'const_cast'.
%typemap(in) const char* serialized_function_def {
- $1 = TFE_GetPythonString($input);
+ $1 = const_cast<char*>(TFE_GetPythonString($input));
}
+// For const parameters in a function, SWIG pretty much ignores the const.
+// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
+// Hence the 'const_cast'.
%typemap(in) const char* device_name {
if ($input == Py_None) {
$1 = nullptr;
} else {
- $1 = TFE_GetPythonString($input);
+ $1 = const_cast<char*>(TFE_GetPythonString($input));
}
}
+// For const parameters in a function, SWIG pretty much ignores the const.
+// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
+// Hence the 'const_cast'.
%typemap(in) const char* op_name {
- $1 = TFE_GetPythonString($input);
+ $1 = const_cast<char*>(TFE_GetPythonString($input));
}
%typemap(in) (TFE_Context*) {
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index 7a37eda5ea..c9bc33e218 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -225,6 +225,7 @@ py_library(
":signature_constants",
":utils",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python:util",
],
)
diff --git a/tensorflow/python/saved_model/loader_test.py b/tensorflow/python/saved_model/loader_test.py
index 9a0b276a4b..b7e217a35b 100644
--- a/tensorflow/python/saved_model/loader_test.py
+++ b/tensorflow/python/saved_model/loader_test.py
@@ -79,13 +79,13 @@ class SavedModelLoaderTest(test.TestCase):
def test_load_function(self):
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo_graph"])
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
loader2 = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader2.load(sess, ["foo_graph"])
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval())
@@ -101,7 +101,7 @@ class SavedModelLoaderTest(test.TestCase):
with self.assertRaises(KeyError):
graph.get_tensor_by_name("z:0")
- with self.test_session(graph=graph) as sess:
+ with self.session(graph=graph) as sess:
# Check that x and y are not initialized
with self.assertRaises(errors.FailedPreconditionError):
sess.run(x)
@@ -110,7 +110,7 @@ class SavedModelLoaderTest(test.TestCase):
def test_load_with_import_scope(self):
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
saver, _ = loader.load_graph(
sess.graph, ["foo_graph"], import_scope="baz")
@@ -126,14 +126,14 @@ class SavedModelLoaderTest(test.TestCase):
# Test combined load function.
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo_graph"], import_scope="baa")
self.assertEqual(5, sess.graph.get_tensor_by_name("baa/x:0").eval())
self.assertEqual(7, sess.graph.get_tensor_by_name("baa/y:0").eval())
def test_restore_variables(self):
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
x = variables.Variable(0, name="x")
y = variables.Variable(0, name="y")
z = x * y
@@ -151,7 +151,7 @@ class SavedModelLoaderTest(test.TestCase):
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
graph = ops.Graph()
saver, _ = loader.load_graph(graph, ["foo_graph"])
- with self.test_session(graph=graph) as sess:
+ with self.session(graph=graph) as sess:
loader.restore_variables(sess, saver)
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
@@ -203,12 +203,12 @@ class SavedModelLoaderTest(test.TestCase):
builder.save()
loader = loader_impl.SavedModelLoader(path)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
saver, _ = loader.load_graph(sess.graph, ["foo_graph"])
self.assertFalse(variables._all_saveable_objects())
self.assertIsNotNone(saver)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo_graph"])
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 00b669fc97..49d52d3bee 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -97,7 +97,7 @@ class SavedModelTest(test.TestCase):
self.assertEqual(expected_asset_tensor_name, asset.tensor_info.name)
def _validate_inputs_tensor_info_fail(self, builder, tensor_info):
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
foo_signature = signature_def_utils.build_signature_def({
@@ -110,7 +110,7 @@ class SavedModelTest(test.TestCase):
signature_def_map={"foo_key": foo_signature})
def _validate_inputs_tensor_info_accept(self, builder, tensor_info):
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
foo_signature = signature_def_utils.build_signature_def({
@@ -121,7 +121,7 @@ class SavedModelTest(test.TestCase):
signature_def_map={"foo_key": foo_signature})
def _validate_outputs_tensor_info_fail(self, builder, tensor_info):
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
foo_signature = signature_def_utils.build_signature_def(
@@ -133,7 +133,7 @@ class SavedModelTest(test.TestCase):
signature_def_map={"foo_key": foo_signature})
def _validate_outputs_tensor_info_accept(self, builder, tensor_info):
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
foo_signature = signature_def_utils.build_signature_def(
@@ -153,7 +153,7 @@ class SavedModelTest(test.TestCase):
def testBadSavedModelFileFormat(self):
export_dir = self._get_export_dir("test_bad_saved_model_file_format")
# Attempt to load a SavedModel from an export directory that does not exist.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
with self.assertRaisesRegexp(IOError,
"SavedModel file does not exist at: %s" %
export_dir):
@@ -164,7 +164,7 @@ class SavedModelTest(test.TestCase):
path_to_pb = os.path.join(export_dir, constants.SAVED_MODEL_FILENAME_PB)
with open(path_to_pb, "w") as f:
f.write("invalid content")
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
with self.assertRaisesRegexp(IOError, "Cannot parse file.*%s" %
constants.SAVED_MODEL_FILENAME_PB):
loader.load(sess, ["foo"], export_dir)
@@ -178,7 +178,7 @@ class SavedModelTest(test.TestCase):
constants.SAVED_MODEL_FILENAME_PBTXT)
with open(path_to_pbtxt, "w") as f:
f.write("invalid content")
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
with self.assertRaisesRegexp(IOError, "Cannot parse file.*%s" %
constants.SAVED_MODEL_FILENAME_PBTXT):
loader.load(sess, ["foo"], export_dir)
@@ -187,7 +187,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_verify_session_graph_usage")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])
@@ -209,12 +209,12 @@ class SavedModelTest(test.TestCase):
# Expect an assertion error since add_meta_graph_and_variables() should be
# invoked before any add_meta_graph() calls.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self.assertRaises(AssertionError, builder.add_meta_graph, ["foo"])
# Expect an assertion error for multiple calls of
# add_meta_graph_and_variables() since weights should be saved exactly once.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
builder.add_meta_graph_and_variables(sess, ["bar"])
self.assertRaises(AssertionError, builder.add_meta_graph_and_variables,
@@ -227,35 +227,35 @@ class SavedModelTest(test.TestCase):
# Graph with a single variable. SavedModel invoked to:
# - add with weights.
# - a single tag (from predefined constants).
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])
# Graph that updates the single variable. SavedModel invoked to:
# - simply add the model (weights are not updated).
# - a single tag (from predefined constants).
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 43)
builder.add_meta_graph([tag_constants.SERVING])
# Graph that updates the single variable. SavedModel invoked to:
# - simply add the model (weights are not updated).
# - multiple tags (from predefined constants).
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 45)
builder.add_meta_graph([tag_constants.SERVING, tag_constants.GPU])
# Graph that updates the single variable. SavedModel invoked to:
# - simply add the model (weights are not updated).
# - multiple tags (from predefined constants for serving on TPU).
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 45)
builder.add_meta_graph([tag_constants.SERVING, tag_constants.TPU])
# Graph that updates the single variable. SavedModel is invoked:
# - to add the model (weights are not updated).
# - multiple custom tags.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 44)
builder.add_meta_graph(["foo", "bar"])
@@ -263,49 +263,49 @@ class SavedModelTest(test.TestCase):
builder.save()
# Restore the graph with a single predefined tag whose variables were saved.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, [tag_constants.TRAINING], export_dir)
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
# Restore the graph with a single predefined tag whose variables were not
# saved.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, [tag_constants.SERVING], export_dir)
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
# Restore the graph with multiple predefined tags whose variables were not
# saved.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, [tag_constants.SERVING, tag_constants.GPU], export_dir)
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
# Restore the graph with multiple predefined tags (for serving on TPU)
# whose variables were not saved.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, [tag_constants.SERVING, tag_constants.TPU], export_dir)
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
# Restore the graph with multiple tags. Provide duplicate tags to test set
# semantics.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo", "bar", "foo"], export_dir)
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
# Try restoring a graph with a non-existent tag. This should yield a runtime
# error.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self.assertRaises(RuntimeError, loader.load, sess, ["INVALID"],
export_dir)
# Try restoring a graph where a subset of the tags match. Since tag matching
# for meta graph defs follows "all" semantics, this should yield a runtime
# error.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self.assertRaises(RuntimeError, loader.load, sess, ["foo", "baz"],
export_dir)
@@ -315,7 +315,7 @@ class SavedModelTest(test.TestCase):
# Graph with two variables. SavedModel invoked to:
# - add with weights.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v1", 1)
self._init_and_validate_variable(sess, "v2", 2)
builder.add_meta_graph_and_variables(sess, ["foo"])
@@ -323,14 +323,14 @@ class SavedModelTest(test.TestCase):
# Graph with a single variable (subset of the variables from the previous
# graph whose weights were saved). SavedModel invoked to:
# - simply add the model (weights are not updated).
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v2", 3)
builder.add_meta_graph(["bar"])
# Graph with a single variable (disjoint set of variables from the previous
# graph whose weights were saved). SavedModel invoked to:
# - simply add the model (weights are not updated).
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v3", 4)
builder.add_meta_graph(["baz"])
@@ -338,7 +338,7 @@ class SavedModelTest(test.TestCase):
builder.save()
# Restore the graph with tag "foo", whose variables were saved.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertEqual(len(collection_vars), 2)
@@ -348,7 +348,7 @@ class SavedModelTest(test.TestCase):
# Restore the graph with tag "bar", whose variables were not saved. Only the
# subset of the variables added to the graph will be restored with the
# checkpointed value.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["bar"], export_dir)
collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertEqual(len(collection_vars), 1)
@@ -357,7 +357,7 @@ class SavedModelTest(test.TestCase):
# Try restoring the graph with tag "baz", whose variables were not saved.
# Since this graph has a disjoint set of variables from the set that was
# saved, this should raise an error.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self.assertRaises(errors.NotFoundError, loader.load, sess, ["baz"],
export_dir)
@@ -366,12 +366,12 @@ class SavedModelTest(test.TestCase):
builder = saved_model_builder.SavedModelBuilder(export_dir)
# Graph with no variables.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
constant_5_name = constant_op.constant(5.0).name
builder.add_meta_graph_and_variables(sess, ["foo"])
# Second graph with no variables
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
constant_6_name = constant_op.constant(6.0).name
builder.add_meta_graph(["bar"])
@@ -379,7 +379,7 @@ class SavedModelTest(test.TestCase):
builder.save()
# Restore the graph with tag "foo".
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
# Read the constant a from the graph.
a = ops.get_default_graph().get_tensor_by_name(constant_5_name)
@@ -388,7 +388,7 @@ class SavedModelTest(test.TestCase):
self.assertEqual(30.0, sess.run(c))
# Restore the graph with tag "bar".
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["bar"], export_dir)
# Read the constant a from the graph.
a = ops.get_default_graph().get_tensor_by_name(constant_6_name)
@@ -402,7 +402,7 @@ class SavedModelTest(test.TestCase):
# Graph with a single variable. SavedModel invoked to:
# - add with weights.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
builder.add_meta_graph_and_variables(sess, ["foo"])
@@ -410,7 +410,7 @@ class SavedModelTest(test.TestCase):
builder.save(as_text=True)
# Restore the graph with tag "foo", whose variables were saved.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
@@ -426,13 +426,13 @@ class SavedModelTest(test.TestCase):
# Graph with a single variable. SavedModel invoked to:
# - add with weights.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
builder.add_meta_graph_and_variables(sess, ["foo"])
# Graph with the same single variable. SavedModel invoked to:
# - simply add the model (weights are not updated).
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 43)
builder.add_meta_graph(["bar"])
@@ -440,13 +440,13 @@ class SavedModelTest(test.TestCase):
builder.save(as_text=True)
# Restore the graph with tag "foo", whose variables were saved.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
# Restore the graph with tag "bar", whose variables were not saved.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["bar"], export_dir)
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
@@ -457,7 +457,7 @@ class SavedModelTest(test.TestCase):
# Graph with a single variable added to a collection. SavedModel invoked to:
# - add with weights.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
v = variables.Variable(42, name="v")
ops.add_to_collection("foo_vars", v)
sess.run(variables.global_variables_initializer())
@@ -467,7 +467,7 @@ class SavedModelTest(test.TestCase):
# Graph with the same single variable added to a different collection.
# SavedModel invoked to:
# - simply add the model (weights are not updated).
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
v = variables.Variable(43, name="v")
ops.add_to_collection("bar_vars", v)
sess.run(variables.global_variables_initializer())
@@ -480,7 +480,7 @@ class SavedModelTest(test.TestCase):
# Restore the graph with tag "foo", whose variables were saved. The
# collection 'foo_vars' should contain a single element. The collection
# 'bar_vars' should not be found.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
collection_foo_vars = ops.get_collection("foo_vars")
self.assertEqual(len(collection_foo_vars), 1)
@@ -493,7 +493,7 @@ class SavedModelTest(test.TestCase):
# reflect the new collection. The value of the variable in the
# collection-def corresponds to the saved value (from the previous graph
# with tag "foo").
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["bar"], export_dir)
collection_bar_vars = ops.get_collection("bar_vars")
self.assertEqual(len(collection_bar_vars), 1)
@@ -507,7 +507,7 @@ class SavedModelTest(test.TestCase):
# Graph with a single variable and a single entry in the signature def map.
# SavedModel is invoked to add with weights.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
# Build and populate an empty SignatureDef for testing.
foo_signature = signature_def_utils.build_signature_def(dict(),
@@ -517,7 +517,7 @@ class SavedModelTest(test.TestCase):
# Graph with the same single variable and multiple entries in the signature
# def map. No weights are saved by SavedModel.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 43)
# Build and populate a different SignatureDef for testing.
bar_signature = signature_def_utils.build_signature_def(dict(),
@@ -539,7 +539,7 @@ class SavedModelTest(test.TestCase):
# Restore the graph with tag "foo". The single entry in the SignatureDef map
# corresponding to "foo_key" should exist.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
foo_graph = loader.load(sess, ["foo"], export_dir)
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
@@ -551,7 +551,7 @@ class SavedModelTest(test.TestCase):
# Restore the graph with tag "bar". The SignatureDef map should have two
# entries. One corresponding to "bar_key" and another corresponding to the
# new value of "foo_key".
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
bar_graph = loader.load(sess, ["bar"], export_dir)
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
@@ -610,7 +610,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_assets")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
# Build an asset collection.
@@ -628,7 +628,7 @@ class SavedModelTest(test.TestCase):
# Save the SavedModel to disk.
builder.save()
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
foo_graph = loader.load(sess, ["foo"], export_dir)
self._validate_asset_collection(export_dir, foo_graph.collection_def,
"hello42.txt", "foo bar baz",
@@ -643,7 +643,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_assets_name_collision_diff_file")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
asset_collection = self._build_asset_collection(
@@ -660,7 +660,7 @@ class SavedModelTest(test.TestCase):
# Save the SavedModel to disk.
builder.save()
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
foo_graph = loader.load(sess, ["foo"], export_dir)
self._validate_asset_collection(export_dir, foo_graph.collection_def,
"hello42.txt", "foo bar bak",
@@ -674,7 +674,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_assets_name_collision_same_path")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
asset_collection = self._build_asset_collection(
@@ -689,7 +689,7 @@ class SavedModelTest(test.TestCase):
# Save the SavedModel to disk.
builder.save()
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
foo_graph = loader.load(sess, ["foo"], export_dir)
self._validate_asset_collection(export_dir, foo_graph.collection_def,
"hello42.txt", "foo bar baz",
@@ -709,7 +709,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_assets_name_collision_same_file")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
asset_collection = self._build_asset_collection(
@@ -726,7 +726,7 @@ class SavedModelTest(test.TestCase):
# Save the SavedModel to disk.
builder.save()
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
foo_graph = loader.load(sess, ["foo"], export_dir)
self._validate_asset_collection(export_dir, foo_graph.collection_def,
"hello42.txt", "foo bar baz",
@@ -746,7 +746,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_assets_name_collision_many_files")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
for i in range(5):
@@ -761,7 +761,7 @@ class SavedModelTest(test.TestCase):
# Save the SavedModel to disk.
builder.save()
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
foo_graph = loader.load(sess, ["foo"], export_dir)
for i in range(1, 5):
idx = str(i)
@@ -778,7 +778,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_main_op")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
v1 = variables.Variable(1, name="v1")
ops.add_to_collection("v", v1)
@@ -801,7 +801,7 @@ class SavedModelTest(test.TestCase):
# Save the SavedModel to disk.
builder.save()
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
self.assertEqual(1, ops.get_collection("v")[0].eval())
self.assertEqual(2, ops.get_collection("v")[1].eval())
@@ -813,7 +813,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_legacy_init_op")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
v1 = variables.Variable(1, name="v1")
ops.add_to_collection("v", v1)
@@ -835,7 +835,7 @@ class SavedModelTest(test.TestCase):
# Save the SavedModel to disk.
builder.save()
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
self.assertEqual(1, ops.get_collection("v")[0].eval())
self.assertEqual(2, ops.get_collection("v")[1].eval())
@@ -858,7 +858,7 @@ class SavedModelTest(test.TestCase):
builder = saved_model_builder.SavedModelBuilder(export_dir)
g = ops.Graph()
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
# Initialize variable `v1` to 1.
v1 = variables.Variable(1, name="v1")
ops.add_to_collection("v", v1)
@@ -887,7 +887,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_train_op")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
v1 = variables.Variable(1, name="v1")
ops.add_to_collection("v", v1)
@@ -905,7 +905,7 @@ class SavedModelTest(test.TestCase):
# Save the SavedModel to disk.
builder.save()
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
self.assertEqual(3, ops.get_collection("v")[0].eval())
self.assertEqual(2, ops.get_collection("v")[1].eval())
@@ -916,7 +916,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_train_op_group")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
v1 = variables.Variable(1, name="v1")
ops.add_to_collection("v", v1)
@@ -934,7 +934,7 @@ class SavedModelTest(test.TestCase):
# Save the SavedModel to disk.
builder.save()
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
self.assertEqual(1, ops.get_collection("v")[0].eval())
self.assertEqual(2, ops.get_collection("v")[1].eval())
@@ -945,7 +945,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_train_op_after_variables")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
v1 = variables.Variable(1, name="v1")
ops.add_to_collection("v", v1)
@@ -964,12 +964,12 @@ class SavedModelTest(test.TestCase):
# Save the SavedModel to disk.
builder.save()
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
self.assertIsInstance(
ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Tensor)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["pre_foo"], export_dir)
self.assertFalse(ops.get_collection(constants.TRAIN_OP_KEY))
@@ -977,7 +977,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_multiple_assets")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
# Build an asset collection specific to `foo` graph.
@@ -988,7 +988,7 @@ class SavedModelTest(test.TestCase):
builder.add_meta_graph_and_variables(
sess, ["foo"], assets_collection=asset_collection)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
# Build an asset collection specific to `bar` graph.
@@ -1002,14 +1002,14 @@ class SavedModelTest(test.TestCase):
builder.save()
# Check assets restored for graph with tag "foo".
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
foo_graph = loader.load(sess, ["foo"], export_dir)
self._validate_asset_collection(export_dir, foo_graph.collection_def,
"foo.txt", "content_foo",
"asset_file_tensor:0")
# Check assets restored for graph with tag "bar".
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
bar_graph = loader.load(sess, ["bar"], export_dir)
self._validate_asset_collection(export_dir, bar_graph.collection_def,
"bar.txt", "content_bar",
@@ -1019,7 +1019,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_duplicate_assets")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
# Build an asset collection with `foo.txt` that has `foo` specific
@@ -1031,7 +1031,7 @@ class SavedModelTest(test.TestCase):
builder.add_meta_graph_and_variables(
sess, ["foo"], assets_collection=asset_collection)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
# Build an asset collection with `foo.txt` that has `bar` specific
@@ -1046,14 +1046,14 @@ class SavedModelTest(test.TestCase):
builder.save()
# Check assets restored for graph with tag "foo".
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
foo_graph = loader.load(sess, ["foo"], export_dir)
self._validate_asset_collection(export_dir, foo_graph.collection_def,
"foo.txt", "content_foo",
"asset_file_tensor:0")
# Check assets restored for graph with tag "bar".
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
bar_graph = loader.load(sess, ["bar"], export_dir)
# Validate the assets for `bar` graph. `foo.txt` should contain the
@@ -1139,7 +1139,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_custom_saver")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
variables.Variable(1, name="v1")
sess.run(variables.global_variables_initializer())
custom_saver = training.Saver(name="my_saver")
@@ -1149,7 +1149,7 @@ class SavedModelTest(test.TestCase):
builder.save()
with ops.Graph().as_default() as graph:
- with self.test_session(graph=graph) as sess:
+ with self.session(graph=graph) as sess:
saved_graph = loader.load(sess, ["tag"], export_dir)
graph_ops = [x.name for x in graph.get_operations()]
self.assertTrue("my_saver/restore_all" in graph_ops)
@@ -1161,7 +1161,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_no_custom_saver")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
variables.Variable(1, name="v1")
sess.run(variables.global_variables_initializer())
training.Saver(name="my_saver")
@@ -1171,7 +1171,7 @@ class SavedModelTest(test.TestCase):
builder.save()
with ops.Graph().as_default() as graph:
- with self.test_session(graph=graph) as sess:
+ with self.session(graph=graph) as sess:
saved_graph = loader.load(sess, ["tag"], export_dir)
graph_ops = [x.name for x in graph.get_operations()]
self.assertTrue("my_saver/restore_all" in graph_ops)
@@ -1183,7 +1183,7 @@ class SavedModelTest(test.TestCase):
export_dir = self._get_export_dir("test_multiple_custom_savers")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
variables.Variable(1, name="v1")
sess.run(variables.global_variables_initializer())
builder.add_meta_graph_and_variables(sess, ["tag_0"])
@@ -1199,7 +1199,7 @@ class SavedModelTest(test.TestCase):
def _validate_custom_saver(tag_name, saver_name):
with ops.Graph().as_default() as graph:
- with self.test_session(graph=graph) as sess:
+ with self.session(graph=graph) as sess:
saved_graph = loader.load(sess, [tag_name], export_dir)
self.assertEqual(
saved_graph.saver_def.restore_op_name,
@@ -1214,7 +1214,7 @@ class SavedModelTest(test.TestCase):
builder = saved_model_builder.SavedModelBuilder(export_dir)
# Build a SavedModel with a variable, an asset, and a constant tensor.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
asset_collection = self._build_asset_collection("foo.txt", "content_foo",
"asset_file_tensor")
@@ -1228,7 +1228,7 @@ class SavedModelTest(test.TestCase):
# Save the SavedModel to disk.
builder.save()
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
# Restore the SavedModel under an import_scope in a new graph/session.
graph_proto = loader.load(
sess, ["tag_name"], export_dir, import_scope="scope_name")
@@ -1281,7 +1281,7 @@ class SavedModelTest(test.TestCase):
# Restore the graph with a single predefined tag whose variables were saved
# without any device information.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
loader.load(sess, [tag_constants.TRAINING], export_dir)
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
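Throughout this test file the deprecated `test_session(graph=...)` becomes `session(graph=...)`. A minimal sketch of the migrated pattern, assuming TF 1.x graph mode (the class name, variable, and value are illustrative):

    import tensorflow as tf

    class MigratedSavedModelTest(tf.test.TestCase):

      def test_variable_round_trip(self):
        # `self.session(graph=...)` replaces the deprecated
        # `self.test_session(graph=...)`; passing an explicit graph gives
        # a session scoped to that graph, as in the hunks above.
        with self.session(graph=tf.Graph()) as sess:
          v = tf.Variable(42, name="v")
          sess.run(tf.global_variables_initializer())
          self.assertEqual(42, sess.run(v))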
diff --git a/tensorflow/python/saved_model/signature_def_utils_impl.py b/tensorflow/python/saved_model/signature_def_utils_impl.py
index f8ad788f77..37f927f381 100644
--- a/tensorflow/python/saved_model/signature_def_utils_impl.py
+++ b/tensorflow/python/saved_model/signature_def_utils_impl.py
@@ -21,9 +21,7 @@ from __future__ import print_function
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import utils
from tensorflow.python.util.tf_export import tf_export
@@ -316,80 +314,3 @@ def _is_valid_classification_signature(signature_def):
return True
-
-def _get_shapes_from_tensor_info_dict(tensor_info_dict):
- """Returns a map of keys to TensorShape objects.
-
- Args:
- tensor_info_dict: map with TensorInfo proto as values.
-
- Returns:
- Map with corresponding TensorShape objects as values.
- """
- return {
- key: tensor_shape.TensorShape(tensor_info.tensor_shape)
- for key, tensor_info in tensor_info_dict.items()
- }
-
-
-def _get_types_from_tensor_info_dict(tensor_info_dict):
- """Returns a map of keys to DType objects.
-
- Args:
- tensor_info_dict: map with TensorInfo proto as values.
-
- Returns:
- Map with corresponding DType objects as values.
- """
- return {
- key: dtypes.DType(tensor_info.dtype)
- for key, tensor_info in tensor_info_dict.items()
- }
-
-
-def get_signature_def_input_shapes(signature):
- """Returns map of parameter names to their shapes.
-
- Args:
- signature: SignatureDef proto.
-
- Returns:
- Map from string to TensorShape objects.
- """
- return _get_shapes_from_tensor_info_dict(signature.inputs)
-
-
-def get_signature_def_input_types(signature):
- """Returns map of output names to their types.
-
- Args:
- signature: SignatureDef proto.
-
- Returns:
- Map from string to DType objects.
- """
- return _get_types_from_tensor_info_dict(signature.inputs)
-
-
-def get_signature_def_output_shapes(signature):
- """Returns map of output names to their shapes.
-
- Args:
- signature: SignatureDef proto.
-
- Returns:
- Map from string to TensorShape objects.
- """
- return _get_shapes_from_tensor_info_dict(signature.outputs)
-
-
-def get_signature_def_output_types(signature):
- """Returns map of output names to their types.
-
- Args:
- signature: SignatureDef proto.
-
- Returns:
- Map from string to DType objects.
- """
- return _get_types_from_tensor_info_dict(signature.outputs)
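The four `get_signature_def_*` helpers above are deleted outright. Callers that still need the same information can recover it directly from the `TensorInfo` protos in a `SignatureDef`; a sketch equivalent to the removed helpers (the function name is illustrative):

    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import tensor_shape

    def shapes_and_types(tensor_info_dict):
      # Mirrors the removed private helpers: map TensorInfo protos to
      # TensorShape and DType objects keyed by tensor name.
      shapes = {key: tensor_shape.TensorShape(info.tensor_shape)
                for key, info in tensor_info_dict.items()}
      types = {key: dtypes.DType(info.dtype)
               for key, info in tensor_info_dict.items()}
      return shapes, types

    # e.g. shapes_and_types(signature_def.inputs) for the input side,
    # shapes_and_types(signature_def.outputs) for the output side.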
diff --git a/tensorflow/python/saved_model/signature_def_utils_test.py b/tensorflow/python/saved_model/signature_def_utils_test.py
index ebc5450633..18c55d8d33 100644
--- a/tensorflow/python/saved_model/signature_def_utils_test.py
+++ b/tensorflow/python/saved_model/signature_def_utils_test.py
@@ -275,44 +275,6 @@ class SignatureDefUtilsTest(test.TestCase):
self.assertEqual(method_name, signature_def.method_name)
self.assertEqual(3, len(signature_def.outputs))
- def testGetShapeAndTypes(self):
- inputs = {
- "input-1": constant_op.constant(["a", "b"]),
- "input-2": array_ops.placeholder(dtypes.float32, [10, 11]),
- }
- outputs = {
- "output-1": array_ops.placeholder(dtypes.float32, [10, 32]),
- "output-2": constant_op.constant([["b"]]),
- }
- signature_def = _make_signature(inputs, outputs)
- self.assertEqual(
- signature_def_utils_impl.get_signature_def_input_shapes(signature_def),
- {"input-1": [2], "input-2": [10, 11]})
- self.assertEqual(
- signature_def_utils_impl.get_signature_def_output_shapes(signature_def),
- {"output-1": [10, 32], "output-2": [1, 1]})
- self.assertEqual(
- signature_def_utils_impl.get_signature_def_input_types(signature_def),
- {"input-1": dtypes.string, "input-2": dtypes.float32})
- self.assertEqual(
- signature_def_utils_impl.get_signature_def_output_types(signature_def),
- {"output-1": dtypes.float32, "output-2": dtypes.string})
-
- def testGetNonFullySpecifiedShapes(self):
- outputs = {
- "output-1": array_ops.placeholder(dtypes.float32, [None, 10, None]),
- "output-2": array_ops.sparse_placeholder(dtypes.float32),
- }
- signature_def = _make_signature({}, outputs)
- shapes = signature_def_utils_impl.get_signature_def_output_shapes(
- signature_def)
- self.assertEqual(len(shapes), 2)
- # Must compare shapes with as_list() since 2 equivalent non-fully defined
- # shapes are not equal to each other.
- self.assertEqual(shapes["output-1"].as_list(), [None, 10, None])
- # Must compare `dims` since its an unknown shape.
- self.assertEqual(shapes["output-2"].dims, None)
-
def _assertValidSignature(self, inputs, outputs, method_name):
signature_def = signature_def_utils_impl.build_signature_def(
inputs, outputs, method_name)
diff --git a/tensorflow/python/saved_model/simple_save_test.py b/tensorflow/python/saved_model/simple_save_test.py
index b2fa40d4f1..18f82daada 100644
--- a/tensorflow/python/saved_model/simple_save_test.py
+++ b/tensorflow/python/saved_model/simple_save_test.py
@@ -60,7 +60,7 @@ class SimpleSaveTest(test.TestCase):
# Initialize input and output variables and save a prediction graph using
# the default parameters.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
var_x = self._init_and_validate_variable(sess, "var_x", 1)
var_y = self._init_and_validate_variable(sess, "var_y", 2)
inputs = {"x": var_x}
@@ -69,7 +69,7 @@ class SimpleSaveTest(test.TestCase):
# Restore the graph with a valid tag and check the global variables and
# signature def map.
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
graph = loader.load(sess, [tag_constants.SERVING], export_dir)
collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
diff --git a/tensorflow/python/saved_model/utils_impl.py b/tensorflow/python/saved_model/utils_impl.py
index 20ff34fd8e..06d09325c8 100644
--- a/tensorflow/python/saved_model/utils_impl.py
+++ b/tensorflow/python/saved_model/utils_impl.py
@@ -75,7 +75,7 @@ def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
KeyError: If `tensor_info` does not correspond to a tensor in `graph`.
ValueError: If `tensor_info` is malformed.
"""
- graph = graph if graph is not None else ops.get_default_graph()
+ graph = graph or ops.get_default_graph()
def _get_tensor(name):
return graph.get_tensor_by_name(
ops.prepend_name_scope(name, import_scope=import_scope))
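The `graph or ops.get_default_graph()` form is safe here only because a `tf.Graph` instance is always truthy; a sketch of the idiom and its caveat:

    import tensorflow as tf

    def resolve_graph(graph=None):
      # Fine for Graph arguments, which are never falsy. For parameters
      # where 0, "" or [] are legitimate values, keep the explicit
      # `graph if graph is not None else ...` form instead.
      return graph or tf.get_default_graph()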
diff --git a/tensorflow/python/summary/summary.py b/tensorflow/python/summary/summary.py
index 980320cc66..fbae2b77fa 100644
--- a/tensorflow/python/summary/summary.py
+++ b/tensorflow/python/summary/summary.py
@@ -15,7 +15,7 @@
"""Tensor summaries for exporting information about a model.
-See the @{$python/summary} guide.
+See the [Summary](https://tensorflow.org/api_guides/python/summary) guide.
"""
from __future__ import absolute_import
diff --git a/tensorflow/python/summary/summary_test.py b/tensorflow/python/summary/summary_test.py
index eb9dbf9645..ac5eb4dbbe 100644
--- a/tensorflow/python/summary/summary_test.py
+++ b/tensorflow/python/summary/summary_test.py
@@ -32,7 +32,7 @@ from tensorflow.python.summary import summary as summary_lib
class ScalarSummaryTest(test.TestCase):
def testScalarSummary(self):
- with self.test_session() as s:
+ with self.cached_session() as s:
i = constant_op.constant(3)
with ops.name_scope('outer'):
im = summary_lib.scalar('inner', i)
@@ -45,7 +45,7 @@ class ScalarSummaryTest(test.TestCase):
self.assertEqual(values[0].simple_value, 3.0)
def testScalarSummaryWithFamily(self):
- with self.test_session() as s:
+ with self.cached_session() as s:
i = constant_op.constant(7)
with ops.name_scope('outer'):
im1 = summary_lib.scalar('inner', i, family='family')
@@ -68,7 +68,7 @@ class ScalarSummaryTest(test.TestCase):
self.assertEqual(values[0].simple_value, 7.0)
def testSummarizingVariable(self):
- with self.test_session() as s:
+ with self.cached_session() as s:
c = constant_op.constant(42.0)
v = variables.Variable(c)
ss = summary_lib.scalar('summary', v)
@@ -83,7 +83,7 @@ class ScalarSummaryTest(test.TestCase):
self.assertEqual(value.simple_value, 42.0)
def testImageSummary(self):
- with self.test_session() as s:
+ with self.cached_session() as s:
i = array_ops.ones((5, 4, 4, 3))
with ops.name_scope('outer'):
im = summary_lib.image('inner', i, max_outputs=3)
@@ -97,7 +97,7 @@ class ScalarSummaryTest(test.TestCase):
self.assertEqual(tags, expected)
def testImageSummaryWithFamily(self):
- with self.test_session() as s:
+ with self.cached_session() as s:
i = array_ops.ones((5, 2, 3, 1))
with ops.name_scope('outer'):
im = summary_lib.image('inner', i, max_outputs=3, family='family')
@@ -113,7 +113,7 @@ class ScalarSummaryTest(test.TestCase):
self.assertEqual(tags, expected)
def testHistogramSummary(self):
- with self.test_session() as s:
+ with self.cached_session() as s:
i = array_ops.ones((5, 4, 4, 3))
with ops.name_scope('outer'):
summ_op = summary_lib.histogram('inner', i)
@@ -124,7 +124,7 @@ class ScalarSummaryTest(test.TestCase):
self.assertEqual(summary.value[0].tag, 'outer/inner')
def testHistogramSummaryWithFamily(self):
- with self.test_session() as s:
+ with self.cached_session() as s:
i = array_ops.ones((5, 4, 4, 3))
with ops.name_scope('outer'):
summ_op = summary_lib.histogram('inner', i, family='family')
@@ -136,7 +136,7 @@ class ScalarSummaryTest(test.TestCase):
self.assertEqual(summary.value[0].tag, 'family/outer/family/inner')
def testAudioSummary(self):
- with self.test_session() as s:
+ with self.cached_session() as s:
i = array_ops.ones((5, 3, 4))
with ops.name_scope('outer'):
aud = summary_lib.audio('inner', i, 0.2, max_outputs=3)
@@ -150,7 +150,7 @@ class ScalarSummaryTest(test.TestCase):
self.assertEqual(tags, expected)
def testAudioSummaryWithFamily(self):
- with self.test_session() as s:
+ with self.cached_session() as s:
i = array_ops.ones((5, 3, 4))
with ops.name_scope('outer'):
aud = summary_lib.audio('inner', i, 0.2, max_outputs=3, family='family')
@@ -194,7 +194,7 @@ class ScalarSummaryTest(test.TestCase):
new_summ_f = g.get_tensor_by_name('new_outer/family/inner:0')
# However, the tags are unaffected.
- with self.test_session() as s:
+ with self.cached_session() as s:
new_summ_str, new_summ_f_str = s.run([new_summ, new_summ_f])
new_summ_pb = summary_pb2.Summary()
new_summ_pb.ParseFromString(new_summ_str)
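These summary tests move to `cached_session`, which, unlike a fresh `session(graph=...)`, returns one session reused across calls within the same test. A minimal sketch of the pattern (the constant value is illustrative):

    import tensorflow as tf

    class CachedSessionSketch(tf.test.TestCase):

      def test_scalar(self):
        # `self.cached_session()` replaces `self.test_session()` when no
        # explicit graph is passed; repeated calls within one test reuse
        # the same underlying session.
        with self.cached_session() as sess:
          self.assertEqual(3, sess.run(tf.constant(3)))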
diff --git a/tensorflow/python/summary/text_summary_test.py b/tensorflow/python/summary/text_summary_test.py
index 4d357918f6..5b0db43cc1 100644
--- a/tensorflow/python/summary/text_summary_test.py
+++ b/tensorflow/python/summary/text_summary_test.py
@@ -33,7 +33,7 @@ class TextPluginTest(test_util.TensorFlowTestCase):
"""
def testTextSummaryAPI(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
num = array_ops.constant(1)
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
index 26e8acd897..39174fa589 100644
--- a/tensorflow/python/tensorflow.i
+++ b/tensorflow/python/tensorflow.i
@@ -54,4 +54,5 @@ limitations under the License.
%include "tensorflow/python/grappler/item.i"
%include "tensorflow/python/grappler/tf_optimizer.i"
%include "tensorflow/python/grappler/cost_analyzer.i"
+%include "tensorflow/python/grappler/graph_analyzer.i"
%include "tensorflow/python/grappler/model_analyzer.i"
diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD
index 222f856511..1c1a1a54cd 100644
--- a/tensorflow/python/tools/BUILD
+++ b/tensorflow/python/tools/BUILD
@@ -114,6 +114,12 @@ py_library(
],
)
+py_library(
+ name = "component_api_helper",
+ srcs = ["component_api_helper.py"],
+ srcs_version = "PY2AND3",
+)
+
py_binary(
name = "strip_unused",
srcs = ["strip_unused.py"],
@@ -131,6 +137,7 @@ py_test(
size = "small",
srcs = ["strip_unused_test.py"],
srcs_version = "PY2AND3",
+ tags = ["notap"],
deps = [
":strip_unused_lib",
"//tensorflow/core:protos_all_py",
diff --git a/tensorflow/python/tools/api/generator/BUILD b/tensorflow/python/tools/api/generator/BUILD
index f87fdb2d88..90be2cc4f7 100644
--- a/tensorflow/python/tools/api/generator/BUILD
+++ b/tensorflow/python/tools/api/generator/BUILD
@@ -6,6 +6,7 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "py_test")
load("//tensorflow/python/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES")
load("//tensorflow/python/tools/api/generator:api_init_files.bzl", "TENSORFLOW_API_INIT_FILES")
+load("//tensorflow/python/tools/api/generator:api_init_files_v1.bzl", "TENSORFLOW_API_INIT_FILES_V1")
exports_files(
[
@@ -14,14 +15,13 @@ exports_files(
],
)
-py_binary(
+py_library(
name = "create_python_api",
srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"],
- main = "//tensorflow/python/tools/api/generator:create_python_api.py",
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/python:no_contrib",
+ "//tensorflow/python:util",
"//tensorflow/python/tools/api/generator:doc_srcs",
],
)
@@ -56,7 +56,7 @@ py_test(
args = [
"--package=tensorflow.python",
"--api_name=tensorflow",
- ] + TENSORFLOW_API_INIT_FILES,
+ ] + TENSORFLOW_API_INIT_FILES + TENSORFLOW_API_INIT_FILES_V1,
main = "doc_srcs_test.py",
srcs_version = "PY2AND3",
deps = [
diff --git a/tensorflow/python/tools/api/generator/api_gen.bzl b/tensorflow/python/tools/api/generator/api_gen.bzl
index 2810d83bd2..271cf2afaf 100644
--- a/tensorflow/python/tools/api/generator/api_gen.bzl
+++ b/tensorflow/python/tools/api/generator/api_gen.bzl
@@ -12,10 +12,15 @@ ESTIMATOR_API_INIT_FILES = [
# END GENERATED ESTIMATOR FILES
]
+def get_compat_files(
+ file_paths,
+ compat_api_version):
+ """Prepends compat/v<compat_api_version> to file_paths."""
+ return ["compat/v%d/%s" % (compat_api_version, f) for f in file_paths]
+
def gen_api_init_files(
name,
output_files = TENSORFLOW_API_INIT_FILES,
- compat_output_files = {},
root_init_template = None,
srcs = [],
api_name = "tensorflow",
@@ -23,7 +28,8 @@ def gen_api_init_files(
compat_api_versions = [],
package = "tensorflow.python",
package_dep = "//tensorflow/python:no_contrib",
- output_package = "tensorflow"):
+ output_package = "tensorflow",
+ output_dir = ""):
"""Creates API directory structure and __init__.py files.
Creates a genrule that generates a directory structure with __init__.py
@@ -37,8 +43,6 @@ def gen_api_init_files(
tf_export. For e.g. if an op is decorated with
@tf_export('module1.module2', 'module3'). Then, output_files should
include module1/module2/__init__.py and module3/__init__.py.
- compat_output_files: Dictionary mapping each compat_api_version to the
- set of __init__.py file paths that should be generated for that version.
root_init_template: Python init file that should be used as template for
root __init__.py file. "# API IMPORTS PLACEHOLDER" comment inside this
template will be replaced with root imports collected by this genrule.
@@ -53,14 +57,16 @@ def gen_api_init_files(
process
package_dep: Python library target containing your package.
output_package: Package where generated API will be added to.
+ output_dir: Subdirectory to output API to.
+ If non-empty, must end with '/'.
"""
root_init_template_flag = ""
if root_init_template:
root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
- api_gen_binary_target = "create_" + package + "_api"
+ api_gen_binary_target = ("create_" + package + "_api_%d") % api_version
native.py_binary(
- name = "create_" + package + "_api",
+ name = api_gen_binary_target,
srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"],
main = "//tensorflow/python/tools/api/generator:create_python_api.py",
srcs_version = "PY2AND3",
@@ -72,14 +78,9 @@ def gen_api_init_files(
],
)
- all_output_files = list(output_files)
+ all_output_files = ["%s%s" % (output_dir, f) for f in output_files]
compat_api_version_flags = ""
for compat_api_version in compat_api_versions:
- compat_files = compat_output_files.get(compat_api_version, [])
- all_output_files.extend([
- "compat/v%d/%s" % (compat_api_version, f)
- for f in compat_files
- ])
compat_api_version_flags += " --compat_apiversion=%d" % compat_api_version
native.genrule(
@@ -87,12 +88,15 @@ def gen_api_init_files(
outs = all_output_files,
cmd = (
"$(location :" + api_gen_binary_target + ") " +
- root_init_template_flag + " --apidir=$(@D) --apiname=" +
- api_name + " --apiversion=" + str(api_version) +
+ root_init_template_flag + " --apidir=$(@D)" + output_dir +
+ " --apiname=" + api_name + " --apiversion=" + str(api_version) +
compat_api_version_flags + " --package=" + package +
" --output_package=" + output_package + " $(OUTS)"
),
srcs = srcs,
tools = [":" + api_gen_binary_target],
- visibility = ["//tensorflow:__pkg__"],
+ visibility = [
+ "//tensorflow:__pkg__",
+ "//tensorflow/tools/api/tests:__pkg__",
+ ],
)
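`get_compat_files` replaces the per-version `compat_output_files` dictionary with a plain path expansion; its effect can be seen in isolation (the macro body is Starlark, which reads as the Python below):

    def get_compat_files(file_paths, compat_api_version):
      """Prepends compat/v<compat_api_version>/ to each path."""
      return ["compat/v%d/%s" % (compat_api_version, f) for f in file_paths]

    # get_compat_files(["train/__init__.py"], 1)
    #   -> ["compat/v1/train/__init__.py"]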
diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl
index 7001e566ce..92446e2f8f 100644
--- a/tensorflow/python/tools/api/generator/api_init_files.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files.bzl
@@ -25,6 +25,7 @@ TENSORFLOW_API_INIT_FILES = [
"keras/applications/inception_resnet_v2/__init__.py",
"keras/applications/inception_v3/__init__.py",
"keras/applications/mobilenet/__init__.py",
+ "keras/applications/mobilenet_v2/__init__.py",
"keras/applications/nasnet/__init__.py",
"keras/applications/resnet50/__init__.py",
"keras/applications/vgg16/__init__.py",
@@ -86,7 +87,6 @@ TENSORFLOW_API_INIT_FILES = [
"sysconfig/__init__.py",
"test/__init__.py",
"train/__init__.py",
- "train/queue_runner/__init__.py",
"user_ops/__init__.py",
# END GENERATED FILES
]
diff --git a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
index 73d11199d9..bc2f3516d1 100644
--- a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
@@ -25,6 +25,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [
"keras/applications/inception_resnet_v2/__init__.py",
"keras/applications/inception_v3/__init__.py",
"keras/applications/mobilenet/__init__.py",
+ "keras/applications/mobilenet_v2/__init__.py",
"keras/applications/nasnet/__init__.py",
"keras/applications/resnet50/__init__.py",
"keras/applications/vgg16/__init__.py",
diff --git a/tensorflow/python/tools/api/generator/doc_srcs.py b/tensorflow/python/tools/api/generator/doc_srcs.py
index ad1988494d..fbec9c6635 100644
--- a/tensorflow/python/tools/api/generator/doc_srcs.py
+++ b/tensorflow/python/tools/api/generator/doc_srcs.py
@@ -62,8 +62,6 @@ _TENSORFLOW_DOC_SOURCES = {
'sysconfig': DocSource(docstring_module_name='platform.sysconfig'),
'test': DocSource(docstring_module_name='platform.test'),
'train': DocSource(docstring_module_name='training.training'),
- 'train.queue_runner': DocSource(
- docstring_module_name='training.queue_runner'),
}
_ESTIMATOR_DOC_SOURCES = {
diff --git a/tensorflow/python/tools/component_api_helper.py b/tensorflow/python/tools/component_api_helper.py
new file mode 100644
index 0000000000..97f46719e5
--- /dev/null
+++ b/tensorflow/python/tools/component_api_helper.py
@@ -0,0 +1,86 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Helper functions to help integrate TensorFlow components into TF API.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import importlib
+import os
+
+
+def package_hook(parent_package_str, child_package_str, error_msg=None):
+ """Used to hook in an external package into the TensorFlow namespace.
+
+ Example usage:
+ ### tensorflow/__init__.py
+ from tensorflow.python.tools import component_api_helper
+ component_api_helper.package_hook(
+ 'tensorflow', 'tensorflow_estimator.python')
+  component_api_helper.package_hook(
+ 'tensorflow.contrib', 'tensorflow_estimator.contrib.python')
+ del component_api_helper
+
+ TODO(mikecase): This function has a minor issue, where if the child package
+ does not exist alone in its directory, sibling packages to it will also be
+ accessible from the parent. This is because we just add
+ `child_pkg.__file__/..` to the subpackage search path. This should not be
+ a big issue because of how our API generation scripts work (the child package
+ we are hooking up should always be alone). But there might be a better way
+ of doing this.
+
+ Args:
+ parent_package_str: Parent package name as a string such as 'tensorflow' or
+ 'tensorflow.contrib'. This will become the parent package for the
+ component package being hooked in.
+ child_package_str: Child package name as a string such as
+ 'tensorflow_estimator.python'. This package will be added as a subpackage
+ of the parent.
+ error_msg: Message to print if child package cannot be found.
+ """
+ parent_pkg = importlib.import_module(parent_package_str)
+ try:
+ child_pkg = importlib.import_module(child_package_str)
+ except ImportError:
+ if error_msg:
+ print(error_msg)
+ return
+
+ def set_child_as_subpackage():
+ """Sets child package as a subpackage of parent package.
+
+ Will allow the following import statement to work.
+ >>> import parent.child
+ """
+ child_pkg_path = [os.path.abspath(
+ os.path.join(os.path.dirname(child_pkg.__file__), ".."))]
+ try:
+ parent_pkg.__path__ = child_pkg_path + parent_pkg.__path__
+ except AttributeError:
+ parent_pkg.__path__ = child_pkg_path
+
+ def set_child_as_attr():
+ """Sets child package as a attr of the parent package.
+
+ Will allow for the following.
+ >>> import parent
+ >>> parent.child
+ """
+ child_pkg_attr_name = child_pkg.__name__.split(".")[-1]
+ setattr(parent_pkg, child_pkg_attr_name, child_pkg)
+
+ set_child_as_subpackage()
+ set_child_as_attr()
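A usage sketch of the new helper, following the module docstring's own example (the package names come from that docstring; the error message is illustrative):

    from tensorflow.python.tools import component_api_helper

    component_api_helper.package_hook(
        parent_package_str="tensorflow",
        child_package_str="tensorflow_estimator.python",
        error_msg="tensorflow_estimator not found; estimator API not hooked in.")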
diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py
index acf070075e..893309f35a 100644
--- a/tensorflow/python/tools/freeze_graph.py
+++ b/tensorflow/python/tools/freeze_graph.py
@@ -59,7 +59,7 @@ from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
-def _has_variables(sess):
+def _has_no_variables(sess):
"""Determines if the graph has any variables.
Args:
@@ -89,7 +89,37 @@ def freeze_graph_with_def_protos(input_graph_def,
input_saved_model_dir=None,
saved_model_tags=None,
checkpoint_version=saver_pb2.SaverDef.V2):
- """Converts all variables in a graph and checkpoint into constants."""
+ """Converts all variables in a graph and checkpoint into constants.
+
+ Args:
+ input_graph_def: A `GraphDef`.
+ input_saver_def: A `SaverDef` (optional).
+ input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
+ priority. Typically the result of `Saver.save()` or that of
+ `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
+ V1/V2.
+ output_node_names: The name(s) of the output nodes, comma separated.
+ restore_op_name: Unused.
+ filename_tensor_name: Unused.
+ output_graph: String where to write the frozen `GraphDef`.
+    clear_devices: A Bool indicating whether to remove device specifications.
+ initializer_nodes: Comma separated string of initializer nodes to run before
+ freezing.
+ variable_names_whitelist: The set of variable names to convert (optional, by
+ default, all variables are converted).
+ variable_names_blacklist: The set of variable names to omit converting
+ to constants (optional).
+    input_meta_graph_def: A `MetaGraphDef` (optional).
+ input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
+ and variables (optional).
+ saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
+ load, in string format (optional).
+    checkpoint_version: TensorFlow variable file format (saver_pb2.SaverDef.V1
+      or saver_pb2.SaverDef.V2).
+
+ Returns:
+ Location of the output_graph_def.
+ """
del restore_op_name, filename_tensor_name # Unused by updated loading code.
# 'input_checkpoint' may be a prefix if we're using Saver V2 format
@@ -168,7 +198,7 @@ def freeze_graph_with_def_protos(input_graph_def,
"the flag --input_saved_model_dir.")
return -1
# Models that have been frozen previously do not contain Variables.
- elif _has_variables(sess):
+ elif _has_no_variables(sess):
print("No variables were found in this model. It is likely the model "
"was frozen previously. You cannot freeze a graph twice.")
return 0
@@ -271,7 +301,37 @@ def freeze_graph(input_graph,
input_saved_model_dir=None,
saved_model_tags=tag_constants.SERVING,
checkpoint_version=saver_pb2.SaverDef.V2):
- """Converts all variables in a graph and checkpoint into constants."""
+ """Converts all variables in a graph and checkpoint into constants.
+
+ Args:
+ input_graph: A `GraphDef` file to load.
+ input_saver: A TensorFlow Saver file.
+ input_binary: A Bool. True means input_graph is .pb, False indicates .pbtxt.
+ input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
+ priority. Typically the result of `Saver.save()` or that of
+ `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
+ V1/V2.
+ output_node_names: The name(s) of the output nodes, comma separated.
+ restore_op_name: Unused.
+ filename_tensor_name: Unused.
+ output_graph: String where to write the frozen `GraphDef`.
+    clear_devices: A Bool indicating whether to remove device specifications.
+ initializer_nodes: Comma separated list of initializer nodes to run before
+ freezing.
+ variable_names_whitelist: The set of variable names to convert (optional, by
+      default, all variables are converted).
+ variable_names_blacklist: The set of variable names to omit converting
+ to constants (optional).
+ input_meta_graph: A `MetaGraphDef` file to load (optional).
+ input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file and
+ variables (optional).
+ saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
+ load, in string format.
+    checkpoint_version: TensorFlow variable file format (saver_pb2.SaverDef.V1
+      or saver_pb2.SaverDef.V2).
+  Returns:
+    String that is the location of the frozen GraphDef.
+ """
input_graph_def = None
if input_saved_model_dir:
input_graph_def = saved_model_utils.get_meta_graph_def(
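With the arguments now documented, a typical invocation of `freeze_graph` looks like the sketch below; the paths and output node name are placeholders, and the two unused name arguments are passed as empty strings per the docstring:

    from tensorflow.python.tools import freeze_graph

    freeze_graph.freeze_graph(
        input_graph="/tmp/model/graph.pbtxt",
        input_saver="",
        input_binary=False,          # .pbtxt input
        input_checkpoint="/tmp/model/model.ckpt",
        output_node_names="softmax",
        restore_op_name="",          # documented above as unused
        filename_tensor_name="",     # documented above as unused
        output_graph="/tmp/model/frozen_graph.pb",
        clear_devices=True,
        initializer_nodes="")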
diff --git a/tensorflow/python/tools/optimize_for_inference_lib.py b/tensorflow/python/tools/optimize_for_inference_lib.py
index bb90d1cd6e..108f2b593c 100644
--- a/tensorflow/python/tools/optimize_for_inference_lib.py
+++ b/tensorflow/python/tools/optimize_for_inference_lib.py
@@ -133,14 +133,14 @@ def ensure_graph_is_valid(graph_def):
"""
node_map = {}
for node in graph_def.node:
- if node.name not in node_map.keys():
+ if node.name not in node_map:
node_map[node.name] = node
else:
raise ValueError("Duplicate node names detected for ", node.name)
for node in graph_def.node:
for input_name in node.input:
input_node_name = node_name_from_input(input_name)
- if input_node_name not in node_map.keys():
+ if input_node_name not in node_map:
raise ValueError("Input for ", node.name, " not found: ", input_name)
@@ -225,7 +225,7 @@ def fold_batch_norms(input_graph_def):
"""
input_node_map = {}
for node in input_graph_def.node:
- if node.name not in input_node_map.keys():
+ if node.name not in input_node_map:
input_node_map[node.name] = node
else:
raise ValueError("Duplicate node names detected for ", node.name)
@@ -390,7 +390,7 @@ def fuse_resize_and_conv(input_graph_def, output_node_names):
input_node_map = {}
for node in input_graph_def.node:
- if node.name not in input_node_map.keys():
+ if node.name not in input_node_map:
input_node_map[node.name] = node
else:
raise ValueError("Duplicate node names detected for ", node.name)
diff --git a/tensorflow/python/tools/print_selective_registration_header_test.py b/tensorflow/python/tools/print_selective_registration_header_test.py
index 4b3d98242c..cce8060fb9 100644
--- a/tensorflow/python/tools/print_selective_registration_header_test.py
+++ b/tensorflow/python/tools/print_selective_registration_header_test.py
@@ -59,6 +59,9 @@ GRAPH_DEF_TXT = """
}
"""
+# AccumulateNV2 is included below because it belongs in the header despite
+# lacking a kernel (it's rewritten by AccumulateNV2RemovePass; see
+# core/common_runtime/accumulate_n_optimizer.cc).
GRAPH_DEF_TXT_2 = """
node: {
name: "node_4"
@@ -67,6 +70,12 @@ GRAPH_DEF_TXT_2 = """
device: "/cpu:0"
attr: { key: "T" value: { type: DT_FLOAT } }
}
+ node: {
+ name: "node_5"
+ op: "AccumulateNV2"
+ attr: { key: "T" value: { type: DT_INT32 } }
+ attr: { key : "N" value: { i: 3 } }
+ }
"""
@@ -100,6 +109,7 @@ class PrintOpFilegroupTest(test.TestCase):
self.assertListEqual(
[
+ ('AccumulateNV2', None), #
('BiasAdd', 'BiasOp<CPUDevice, float>'), #
('MatMul',
matmul_prefix + 'MatMulOp<CPUDevice, double, false >'), #
@@ -117,6 +127,7 @@ class PrintOpFilegroupTest(test.TestCase):
'rawproto', self.WriteGraphFiles(graphs), default_ops)
self.assertListEqual(
[
+ ('AccumulateNV2', None), #
('BiasAdd', 'BiasOp<CPUDevice, float>'), #
('MatMul',
matmul_prefix + 'MatMulOp<CPUDevice, double, false >'), #
@@ -196,6 +207,7 @@ class PrintOpFilegroupTest(test.TestCase):
constexpr inline bool ShouldRegisterOp(const char op[]) {
return false
+ || isequal(op, "AccumulateNV2")
|| isequal(op, "BiasAdd")
;
}
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index 38fed5335e..c5289564fe 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -40,8 +40,8 @@ from tensorflow.python.client import session
from tensorflow.python.debug.wrappers import local_cli_wrapper
from tensorflow.python.framework import meta_graph as meta_graph_lib
from tensorflow.python.framework import ops as ops_lib
-from tensorflow.python.platform import app # pylint: disable=unused-import
from tensorflow.python.lib.io import file_io
+from tensorflow.python.platform import app # pylint: disable=unused-import
from tensorflow.python.saved_model import loader
from tensorflow.python.tools import saved_model_utils
@@ -140,7 +140,7 @@ def _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key, indent=0):
outputs_tensor_info = _get_outputs_tensor_info_from_meta_graph_def(
meta_graph_def, signature_def_key)
- indent_str = " " * indent
+ indent_str = ' ' * indent
def in_print(s):
print(indent_str + s)
@@ -166,7 +166,7 @@ def _print_tensor_info(tensor_info, indent=0):
tensor_info: TensorInfo object to be printed.
indent: How far (in increments of 2 spaces) to indent each line output
"""
- indent_str = " " * indent
+ indent_str = ' ' * indent
def in_print(s):
print(indent_str + s)
@@ -270,7 +270,7 @@ def scan_meta_graph_def(meta_graph_def):
def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
input_tensor_key_feed_dict, outdir,
- overwrite_flag, tf_debug=False):
+ overwrite_flag, worker=None, tf_debug=False):
"""Runs SavedModel and fetch all outputs.
Runs the input dictionary through the MetaGraphDef within a SavedModel
@@ -288,6 +288,8 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
it will be created.
overwrite_flag: A boolean flag to allow overwrite output file if file with
the same name exists.
+ worker: If provided, the session will be run on the worker. Valid worker
+ specification is a bns or gRPC path.
tf_debug: A boolean flag to use TensorFlow Debugger (TFDBG) to observe the
intermediate Tensor values and runtime GraphDefs while running the
SavedModel.
@@ -308,7 +310,7 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
# Check if input tensor keys are valid.
for input_key_name in input_tensor_key_feed_dict.keys():
- if input_key_name not in inputs_tensor_info.keys():
+ if input_key_name not in inputs_tensor_info:
raise ValueError(
'"%s" is not a valid input key. Please choose from %s, or use '
'--show option.' %
@@ -328,7 +330,7 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
for tensor_key in output_tensor_keys_sorted
]
- with session.Session(graph=ops_lib.Graph()) as sess:
+ with session.Session(worker, graph=ops_lib.Graph()) as sess:
loader.load(sess, tag_set.split(','), saved_model_dir)
if tf_debug:
@@ -544,7 +546,7 @@ def load_inputs_from_input_arg_string(inputs_str, input_exprs_str,
input_examples = preprocess_input_examples_arg_string(input_examples_str)
for input_tensor_key, (filename, variable_name) in inputs.items():
- data = np.load(file_io.FileIO(filename, mode='r'))
+ data = np.load(file_io.FileIO(filename, mode='rb'))
# When a variable_name key is specified for the input file
if variable_name:
@@ -632,7 +634,8 @@ def run(args):
args.inputs, args.input_exprs, args.input_examples)
run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def,
tensor_key_feed_dict, args.outdir,
- args.overwrite, tf_debug=args.tf_debug)
+ args.overwrite, worker=args.worker,
+ tf_debug=args.tf_debug)
def scan(args):
@@ -769,6 +772,12 @@ def create_parser():
help='if set, will use TensorFlow Debugger (tfdbg) to watch the '
'intermediate Tensors and runtime GraphDefs while running the '
'SavedModel.')
+ parser_run.add_argument(
+ '--worker',
+ type=str,
+ default=None,
+ help='if specified, a Session will be run on the worker. '
+ 'Valid worker specification is a bns or gRPC path.')
parser_run.set_defaults(func=run)
# scan command
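The new `--worker` flag threads through to `run_saved_model_with_feed_dict`, which now opens its session against the given target. A sketch using the updated signature (the address, tags, and feed values are placeholders):

    from tensorflow.python.tools import saved_model_cli

    saved_model_cli.run_saved_model_with_feed_dict(
        saved_model_dir="/tmp/saved_model",
        tag_set="serve",
        signature_def_key="serving_default",
        input_tensor_key_feed_dict={"x": [[1.0]]},
        outdir=None,
        overwrite_flag=False,
        worker="grpc://localhost:2222")  # bns or gRPC path, per the help text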
diff --git a/tensorflow/python/tools/selective_registration_header_lib.py b/tensorflow/python/tools/selective_registration_header_lib.py
index dc0612bb3f..b99c632c3e 100644
--- a/tensorflow/python/tools/selective_registration_header_lib.py
+++ b/tensorflow/python/tools/selective_registration_header_lib.py
@@ -32,6 +32,16 @@ from tensorflow.python import pywrap_tensorflow
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging
+# Usually, we use each graph node to induce registration of an op and
+# corresponding kernel; nodes without a corresponding kernel (perhaps due to
+# attr types) generate a warning but are otherwise ignored. Ops in this set are
+# registered even if there's no corresponding kernel.
+OPS_WITHOUT_KERNEL_WHITELIST = frozenset([
+ # AccumulateNV2 is rewritten away by AccumulateNV2RemovePass; see
+ # core/common_runtime/accumulate_n_optimizer.cc.
+ 'AccumulateNV2'
+])
+
def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str):
"""Gets the ops and kernels needed from the model files."""
@@ -53,8 +63,10 @@ def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str):
node_def.device = '/cpu:0'
kernel_class = pywrap_tensorflow.TryFindKernelClass(
node_def.SerializeToString())
- if kernel_class:
- op_and_kernel = (str(node_def.op), str(kernel_class.decode('utf-8')))
+ op = str(node_def.op)
+ if kernel_class or op in OPS_WITHOUT_KERNEL_WHITELIST:
+ op_and_kernel = (op, str(kernel_class.decode('utf-8'))
+ if kernel_class else None)
if op_and_kernel not in ops:
ops.add(op_and_kernel)
else:
@@ -129,6 +141,7 @@ def get_header_from_ops_and_kernels(ops_and_kernels,
'''
line += 'constexpr const char* kNecessaryOpKernelClasses[] = {\n'
for _, kernel_class in ops_and_kernels:
+ if kernel_class is None: continue
line += '"%s",\n' % kernel_class
line += '};'
append(line)
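Taken together, the whitelist changes record a kernel-less op as an `(op, None)` pair and then skip the `None` when the kernel-class array is emitted. The two sides of that contract, distilled into a sketch:

    OPS_WITHOUT_KERNEL_WHITELIST = frozenset(['AccumulateNV2'])

    def keep_entry(op, kernel_class):
      # Side 1: whitelisted ops survive with kernel_class=None so the op
      # itself is still registered in the generated header.
      if kernel_class or op in OPS_WITHOUT_KERNEL_WHITELIST:
        return (op, kernel_class.decode('utf-8') if kernel_class else None)
      return None

    def kernel_class_lines(ops_and_kernels):
      # Side 2: None kernels are skipped when kNecessaryOpKernelClasses
      # is emitted.
      return ['"%s",' % k for _, k in ops_and_kernels if k is not None]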
diff --git a/tensorflow/python/training/adadelta_test.py b/tensorflow/python/training/adadelta_test.py
index 2678016d24..a14ac895ac 100644
--- a/tensorflow/python/training/adadelta_test.py
+++ b/tensorflow/python/training/adadelta_test.py
@@ -155,7 +155,7 @@ class AdadeltaOptimizerTest(test.TestCase):
rtol=1e-5)
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
self.doTestBasic(use_resource=False)
@test_util.run_in_graph_and_eager_modes(reset_test=True)
@@ -168,7 +168,7 @@ class AdadeltaOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
diff --git a/tensorflow/python/training/adagrad.py b/tensorflow/python/training/adagrad.py
index 6778f3c735..3508b98475 100644
--- a/tensorflow/python/training/adagrad.py
+++ b/tensorflow/python/training/adagrad.py
@@ -70,20 +70,24 @@ class AdagradOptimizer(optimizer.Optimizer):
def _create_slots(self, var_list):
for v in var_list:
- with ops.colocate_with(v):
- dtype = v.dtype.base_dtype
- if v.get_shape().is_fully_defined():
- init = init_ops.constant_initializer(self._initial_accumulator_value,
- dtype=dtype)
- else:
- # Use a Tensor instead of initializer if variable does not have static
- # shape.
- init_constant = gen_array_ops.fill(array_ops.shape(v),
- self._initial_accumulator_value)
- init = math_ops.cast(init_constant, dtype)
+ dtype = v.dtype.base_dtype
+ if v.get_shape().is_fully_defined():
+ init = init_ops.constant_initializer(self._initial_accumulator_value,
+ dtype=dtype)
+ else:
+ init = self._init_constant_op(v, dtype)
self._get_or_make_slot_with_initializer(v, init, v.get_shape(), dtype,
"accumulator", self._name)
+ def _init_constant_op(self, v, dtype):
+ def init():
+ # Use a Tensor instead of initializer if variable does not have
+ # static shape.
+ init_constant = gen_array_ops.fill(array_ops.shape(v),
+ self._initial_accumulator_value)
+ return math_ops.cast(init_constant, dtype)
+ return init
+
def _prepare(self):
learning_rate = self._call_if_callable(self._learning_rate)
self._learning_rate_tensor = ops.convert_to_tensor(
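The Adagrad refactor swaps an eagerly built fill op for a callable, so the accumulator's initial value is constructed only when the slot is initialized; that is what lets variables without a static shape get an accumulator. The closure in isolation (names illustrative):

    import tensorflow as tf

    def make_accumulator_init(v, initial_accumulator_value):
      def init():
        # Built lazily: tf.shape(v) is only evaluated when the slot's
        # initializer runs, so a dynamically shaped v is fine.
        init_constant = tf.fill(tf.shape(v), initial_accumulator_value)
        return tf.cast(init_constant, v.dtype.base_dtype)
      return init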
diff --git a/tensorflow/python/training/adagrad_da_test.py b/tensorflow/python/training/adagrad_da_test.py
index c3a242a75e..00801be3b4 100644
--- a/tensorflow/python/training/adagrad_da_test.py
+++ b/tensorflow/python/training/adagrad_da_test.py
@@ -34,7 +34,7 @@ class AdagradDAOptimizerTest(test.TestCase):
def doTestAdagradDAwithoutRegularizationBasic1(self, use_resource=False):
for dtype in [dtypes.float64, dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(0, dtype=dtypes.int64)
if use_resource:
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
@@ -81,7 +81,7 @@ class AdagradDAOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
global_step = resource_variable_ops.ResourceVariable(
0, dtype=dtypes.int64)
@@ -101,7 +101,7 @@ class AdagradDAOptimizerTest(test.TestCase):
def testAdagradDAwithoutRegularizationBasic2(self):
for dtype in [dtypes.float64, dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(0, dtype=dtypes.int64)
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
@@ -133,7 +133,7 @@ class AdagradDAOptimizerTest(test.TestCase):
def testAdagradDAWithL1(self):
for dtype in [dtypes.float64, dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(0, dtype=dtypes.int64)
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
@@ -165,7 +165,7 @@ class AdagradDAOptimizerTest(test.TestCase):
def testAdagradDAWithL1_L2(self):
for dtype in [dtypes.float64, dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(0, dtype=dtypes.int64)
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
diff --git a/tensorflow/python/training/adagrad_test.py b/tensorflow/python/training/adagrad_test.py
index c9aec33d09..7caf01f64d 100644
--- a/tensorflow/python/training/adagrad_test.py
+++ b/tensorflow/python/training/adagrad_test.py
@@ -98,7 +98,7 @@ class AdagradOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable(
[[1.0, 2.0], [3.0, 4.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
@@ -117,7 +117,7 @@ class AdagradOptimizerTest(test.TestCase):
def testTensorLearningRate(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -141,7 +141,7 @@ class AdagradOptimizerTest(test.TestCase):
def testSparseBasic(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
grads0 = ops.IndexedSlices(
@@ -172,7 +172,7 @@ class AdagradOptimizerTest(test.TestCase):
def testSparseRepeatedIndices(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
repeated_index_update_var = variables.Variable(
[[1.0], [2.0]], dtype=dtype)
aggregated_update_var = variables.Variable(
@@ -202,7 +202,7 @@ class AdagradOptimizerTest(test.TestCase):
def testSparseRepeatedIndicesResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var_repeated = resource_variable_ops.ResourceVariable(
[1.0, 2.0], dtype=dtype)
loss_repeated = math_ops.reduce_sum(
@@ -226,7 +226,7 @@ class AdagradOptimizerTest(test.TestCase):
def testSparseStability(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
shape = [1, 6]
var0 = variables.Variable(
[[
@@ -262,7 +262,7 @@ class AdagradOptimizerTest(test.TestCase):
def testSharing(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -295,13 +295,46 @@ class AdagradOptimizerTest(test.TestCase):
np.array([2.715679168701172, 3.715679168701172]), var1.eval())
def testDynamicShapeVariable_Ok(self):
- with self.test_session():
+ with self.cached_session():
v = variable_scope.get_variable("v", initializer=constant_op.constant(1.),
validate_shape=False)
self.assertFalse(v.shape.is_fully_defined())
# Creating optimizer should cause no exception.
adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1)
+ def testDynamicShapeVariableWithCallableInit(self):
+ var0 = variable_scope.get_variable("var0",
+ initializer=constant_op.constant(1.),
+ validate_shape=False)
+ self.assertFalse(var0.shape.is_fully_defined())
+
+ grads0 = constant_op.constant(0.1, dtype=dtypes.float32)
+ learning_rate = lambda: 3.0
+
+ ada_opt = adagrad.AdagradOptimizer(
+ learning_rate, initial_accumulator_value=0.1, use_locking=True)
+
+ if not context.executing_eagerly():
+ ada_update = ada_opt.apply_gradients(
+ zip([grads0], [var0]))
+ self.evaluate(variables.global_variables_initializer())
+
+ # Fetch params to validate initial values
+ v0_val = self.evaluate([var0])
+ self.assertAllClose([1.0], v0_val)
+
+ # Run 3 steps of adagrad
+ for _ in range(3):
+ if not context.executing_eagerly():
+ self.evaluate(ada_update)
+ else:
+ ada_opt.apply_gradients(zip([grads0], [var0]))
+
+ # Validate updated params
+ v0_val = self.evaluate([var0])
+ self.assertAllCloseAccordingToType(
+ np.array([-1.6026098728179932]), v0_val)
+
if __name__ == "__main__":
test.main()
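As a quick plausibility check for the constant asserted in `testDynamicShapeVariableWithCallableInit`, the standard Adagrad recurrence (accumulate squared gradients, then scale the step by the inverse root) reproduces it in plain numpy:

```python
import numpy as np

var, acc = 1.0, 0.1  # initial value and initial_accumulator_value
lr, g = 3.0, 0.1     # learning rate and the constant gradient
for _ in range(3):
    acc += g * g
    var -= lr * g / np.sqrt(acc)
print(var)  # ~ -1.6026098, matching the assertAllCloseAccordingToType value
```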
diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py
index bcbe5907d6..704ad6d3fe 100644
--- a/tensorflow/python/training/adam.py
+++ b/tensorflow/python/training/adam.py
@@ -43,15 +43,15 @@ class AdamOptimizer(optimizer.Optimizer):
Initialization:
- $$m_0 := 0 (Initialize initial 1st moment vector)$$
- $$v_0 := 0 (Initialize initial 2nd moment vector)$$
- $$t := 0 (Initialize timestep)$$
+  $$m_0 := 0 \text{(Initialize 1st moment vector)}$$
+  $$v_0 := 0 \text{(Initialize 2nd moment vector)}$$
+ $$t := 0 \text{(Initialize timestep)}$$
The update rule for `variable` with gradient `g` uses an optimization
described at the end of section 2 of the paper:
$$t := t + 1$$
- $$lr_t := \text{learning_rate} * \sqrt{(1 - beta_2^t) / (1 - beta_1^t)}$$
+ $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
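The corrected formula places the square root over the `beta_2` correction only, which matches the effective step size of bias-corrected Adam. A short numeric check of the docstring expression (the values are arbitrary):

```python
import numpy as np

lr, b1, b2, t = 0.001, 0.9, 0.999, 10
lr_t = lr * np.sqrt(1.0 - b2**t) / (1.0 - b1**t)
print(lr_t)  # effective step size after 10 bias-corrected updates
```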
diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py
index 8f84427654..48db6e3733 100644
--- a/tensorflow/python/training/adam_test.py
+++ b/tensorflow/python/training/adam_test.py
@@ -56,7 +56,7 @@ class AdamOptimizerTest(test.TestCase):
def doTestSparse(self, use_resource=False):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -122,7 +122,7 @@ class AdamOptimizerTest(test.TestCase):
def testSparseRepeatedIndices(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
repeated_index_update_var = variables.Variable(
[[1.0], [2.0]], dtype=dtype)
aggregated_update_var = variables.Variable(
@@ -152,7 +152,7 @@ class AdamOptimizerTest(test.TestCase):
def doTestBasic(self, use_resource=False, use_callable_params=False):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -224,7 +224,7 @@ class AdamOptimizerTest(test.TestCase):
opt.get_slot(var=var0, name="m").name)
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
self.doTestBasic(use_resource=False)
@test_util.run_in_graph_and_eager_modes(reset_test=True)
@@ -237,7 +237,7 @@ class AdamOptimizerTest(test.TestCase):
def testTensorLearningRate(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -274,7 +274,7 @@ class AdamOptimizerTest(test.TestCase):
def testSharing(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py
index 76625624e4..3bd4bd75bd 100644
--- a/tensorflow/python/training/basic_session_run_hooks.py
+++ b/tensorflow/python/training/basic_session_run_hooks.py
@@ -1025,7 +1025,7 @@ class ProfilerHook(session_run_hook.SessionRunHook):
def before_run(self, run_context):
self._request_summary = (
- self._next_step is None or
+ self._next_step is not None and
self._timer.should_trigger_for_step(self._next_step))
requests = {"global_step": self._global_step_tensor}
opts = (config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
@@ -1035,6 +1035,10 @@ class ProfilerHook(session_run_hook.SessionRunHook):
def after_run(self, run_context, run_values):
stale_global_step = run_values.results["global_step"]
+ if self._next_step is None:
+ # Update the timer so that it does not activate until N steps or seconds
+ # have passed.
+ self._timer.update_last_triggered_step(stale_global_step)
global_step = stale_global_step + 1
if self._request_summary:
global_step = run_context.session.run(self._global_step_tensor)
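Taken together, the two hunks above change `ProfilerHook` from "always trace the first step" to "wait one full interval first": `before_run` no longer requests a trace while `_next_step` is `None`, and `after_run` seeds the timer on that first call. A self-contained sketch of the resulting control flow, with a toy stand-in for `SecondOrStepTimer` (all names here are illustrative):

```python
class Timer(object):
    """Toy timer: triggers once every `every_steps` steps after seeding."""
    def __init__(self, every_steps):
        self._every_steps = every_steps
        self._last = None

    def should_trigger_for_step(self, step):
        return self._last is not None and step >= self._last + self._every_steps

    def update_last_triggered_step(self, step):
        self._last = step

timer, next_step, saved = Timer(every_steps=2), None, []
for global_step in range(5):
    request = next_step is not None and timer.should_trigger_for_step(next_step)
    if next_step is None:
        timer.update_last_triggered_step(global_step)  # seed on the first step
    if request:
        saved.append(global_step)
        timer.update_last_triggered_step(global_step)
    next_step = global_step + 1
print(saved)  # [2, 4]: nothing is saved at step 0 anymore
```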
diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py
index b49a871a56..2d469634e0 100644
--- a/tensorflow/python/training/basic_session_run_hooks_test.py
+++ b/tensorflow/python/training/basic_session_run_hooks_test.py
@@ -1145,7 +1145,7 @@ class SummarySaverHookTest(test.TestCase):
summary_writer=self.summary_writer,
summary_op=self.summary_op)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
hook.begin()
sess.run(variables_lib.global_variables_initializer())
mon_sess = monitored_session._HookedSession(sess, [hook])
@@ -1177,7 +1177,7 @@ class SummarySaverHookTest(test.TestCase):
summary_writer=self.summary_writer,
summary_op=[self.summary_op, self.summary_op2])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
hook.begin()
sess.run(variables_lib.global_variables_initializer())
mon_sess = monitored_session._HookedSession(sess, [hook])
@@ -1205,7 +1205,7 @@ class SummarySaverHookTest(test.TestCase):
summary_writer=self.summary_writer,
summary_op=self.summary_op)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
hook.begin()
sess.run(variables_lib.global_variables_initializer())
mon_sess = monitored_session._HookedSession(sess, [hook])
@@ -1240,7 +1240,7 @@ class SummarySaverHookTest(test.TestCase):
summary_writer=self.summary_writer,
summary_op=self.summary_op)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
hook.begin()
sess.run(variables_lib.global_variables_initializer())
mon_sess = monitored_session._HookedSession(sess, [hook])
@@ -1388,7 +1388,7 @@ class ResourceSummarySaverHookTest(test.TestCase):
summary_writer=self.summary_writer,
summary_op=self.summary_op)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
hook.begin()
sess.run(variables_lib.global_variables_initializer())
mon_sess = monitored_session._HookedSession(sess, [hook])
@@ -1454,52 +1454,50 @@ class ProfilerHookTest(test.TestCase):
with self.assertRaises(ValueError):
basic_session_run_hooks.ProfilerHook(save_secs=None, save_steps=None)
- def test_save_secs_saves_in_first_step(self):
+ def test_save_secs_does_not_save_in_first_step(self):
with self.graph.as_default():
hook = basic_session_run_hooks.ProfilerHook(
save_secs=2, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
sess.run(self.train_op)
- self.assertEqual(1, self._count_timeline_files())
+ self.assertEqual(0, self._count_timeline_files())
@test.mock.patch.object(time, 'time')
def test_save_secs_saves_periodically(self, mock_time):
# Pick a fixed start time.
- current_time = 1484863632.320497
+ current_time = 1484863632.
with self.graph.as_default():
mock_time.return_value = current_time
hook = basic_session_run_hooks.ProfilerHook(
save_secs=2, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
- sess.run(self.train_op) # Saved.
- self.assertEqual(1, self._count_timeline_files())
sess.run(self.train_op) # Not saved.
- self.assertEqual(1, self._count_timeline_files())
+ self.assertEqual(0, self._count_timeline_files())
# Simulate 2.5 seconds of sleep.
mock_time.return_value = current_time + 2.5
sess.run(self.train_op) # Saved.
+ self.assertEqual(1, self._count_timeline_files())
# Pretend some small amount of time has passed.
- mock_time.return_value = current_time + 0.1
+ mock_time.return_value = current_time + 2.6
sess.run(self.train_op) # Not saved.
# Edge test just before we should save the timeline.
- mock_time.return_value = current_time + 1.9
+ mock_time.return_value = current_time + 4.4
sess.run(self.train_op) # Not saved.
- self.assertEqual(2, self._count_timeline_files())
+ self.assertEqual(1, self._count_timeline_files())
mock_time.return_value = current_time + 4.5
sess.run(self.train_op) # Saved.
- self.assertEqual(3, self._count_timeline_files())
+ self.assertEqual(2, self._count_timeline_files())
- def test_save_steps_saves_in_first_step(self):
+ def test_save_steps_does_not_save_in_first_step(self):
with self.graph.as_default():
hook = basic_session_run_hooks.ProfilerHook(
- save_secs=2, output_dir=self.output_dir)
+ save_steps=1, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
- sess.run(self.train_op) # Saved.
sess.run(self.train_op) # Not saved.
- self.assertEqual(1, self._count_timeline_files())
+ self.assertEqual(0, self._count_timeline_files())
def test_save_steps_saves_periodically(self):
with self.graph.as_default():
@@ -1507,6 +1505,8 @@ class ProfilerHookTest(test.TestCase):
save_steps=2, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
self.assertEqual(0, self._count_timeline_files())
+ sess.run(self.train_op) # Not saved.
+ self.assertEqual(0, self._count_timeline_files())
sess.run(self.train_op) # Saved.
self.assertEqual(1, self._count_timeline_files())
sess.run(self.train_op) # Not saved.
@@ -1515,20 +1515,19 @@ class ProfilerHookTest(test.TestCase):
self.assertEqual(2, self._count_timeline_files())
sess.run(self.train_op) # Not saved.
self.assertEqual(2, self._count_timeline_files())
- sess.run(self.train_op) # Saved.
- self.assertEqual(3, self._count_timeline_files())
- def test_run_metadata_saves_in_first_step(self):
+ def test_run_metadata_saves(self):
writer_cache.FileWriterCache.clear()
fake_summary_writer.FakeSummaryWriter.install()
fake_writer = writer_cache.FileWriterCache.get(self.output_dir)
with self.graph.as_default():
hook = basic_session_run_hooks.ProfilerHook(
- save_secs=2, output_dir=self.output_dir)
+ save_steps=1, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
+ sess.run(self.train_op) # Not saved.
sess.run(self.train_op) # Saved.
self.assertEqual(
- list(fake_writer._added_run_metadata.keys()), ['step_1'])
+ list(fake_writer._added_run_metadata.keys()), ['step_2'])
fake_summary_writer.FakeSummaryWriter.uninstall()
diff --git a/tensorflow/python/training/checkpoint_management.py b/tensorflow/python/training/checkpoint_management.py
index 85f2904318..38910fb246 100644
--- a/tensorflow/python/training/checkpoint_management.py
+++ b/tensorflow/python/training/checkpoint_management.py
@@ -510,7 +510,10 @@ class CheckpointManager(object):
max_to_keep: An integer, the number of checkpoints to keep. Unless
preserved by `keep_checkpoint_every_n_hours`, checkpoints will be
deleted from the active set, oldest first, until only `max_to_keep`
- checkpoints remain.
+ checkpoints remain. If `None`, no checkpoints are deleted and everything
+ stays in the active set. Note that `max_to_keep=None` will keep all
+ checkpoint paths in memory and in the checkpoint state protocol buffer
+ on disk.
keep_checkpoint_every_n_hours: Upon removal from the active set, a
checkpoint will be preserved if it has been at least
`keep_checkpoint_every_n_hours` since the last preserved checkpoint. The
@@ -521,9 +524,10 @@ class CheckpointManager(object):
"""
self._checkpoint = checkpoint
self._save_counter_assign = None
- if not max_to_keep or max_to_keep < 0:
+ if max_to_keep is not None and max_to_keep <= 0:
raise ValueError(
- "Expected a positive integer for `max_to_max_to_keep`, got %d."
+          ("Expected a positive integer or `None` for `max_to_keep`, "
+ "got %d.")
% (max_to_keep,))
self._max_to_keep = max_to_keep
self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
@@ -534,7 +538,9 @@ class CheckpointManager(object):
self._maybe_delete = collections.OrderedDict()
if recovered_state is None:
self._latest_checkpoint = None
- self._last_preserved_timestamp = current_clock
+      # Set the clock back slightly to avoid race conditions when quickly
+ # re-creating a CheckpointManager.
+ self._last_preserved_timestamp = current_clock - 1.
else:
self._latest_checkpoint = recovered_state.model_checkpoint_path
self._last_preserved_timestamp = recovered_state.last_preserved_timestamp
@@ -586,6 +592,10 @@ class CheckpointManager(object):
def _sweep(self):
"""Deletes or preserves managed checkpoints."""
+ if not self._max_to_keep:
+ # Does not update self._last_preserved_timestamp, since everything is kept
+ # in the active set.
+ return
while len(self._maybe_delete) > self._max_to_keep:
filename, timestamp = self._maybe_delete.popitem(last=False)
# Even if we're keeping this checkpoint due to
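A plain-Python sketch of the new argument contract: `None` now means "keep every checkpoint", and only strictly positive integers are otherwise accepted (the helper name is illustrative):

```python
def validate_max_to_keep(max_to_keep):
    if max_to_keep is not None and max_to_keep <= 0:
        raise ValueError(
            "Expected a positive integer or `None` for `max_to_keep`, "
            "got %d." % (max_to_keep,))
    return max_to_keep

validate_max_to_keep(None)  # OK: disables deletion entirely
validate_max_to_keep(5)     # OK: keep the five most recent
# validate_max_to_keep(0)   # would raise ValueError
```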
diff --git a/tensorflow/python/training/checkpoint_management_test.py b/tensorflow/python/training/checkpoint_management_test.py
index 1e2827d0a4..3a061bcb35 100644
--- a/tensorflow/python/training/checkpoint_management_test.py
+++ b/tensorflow/python/training/checkpoint_management_test.py
@@ -26,6 +26,7 @@ import tempfile
from google.protobuf import text_format
from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import test_util
@@ -72,7 +73,7 @@ class LatestCheckpointWithRelativePaths(test.TestCase):
# Collides with the default name of the checkpoint state file.
filepath = os.path.join(traindir, "checkpoint")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
unused_a = variables.Variable(0.0) # So that Saver saves something.
variables.global_variables_initializer().run()
@@ -112,7 +113,7 @@ class LatestCheckpointWithRelativePaths(test.TestCase):
filename = "snapshot"
filepath = os.path.join(traindir, filename)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Build a simple graph.
v0 = variables.Variable(0.0)
inc = v0.assign_add(1.0)
@@ -127,7 +128,7 @@ class LatestCheckpointWithRelativePaths(test.TestCase):
inc.eval()
save.save(sess, filepath, global_step=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Build a new graph with different initialization.
v0 = variables.Variable(-1.0)
@@ -272,7 +273,7 @@ class SaverUtilsTest(test.TestCase):
def testCheckpointExists(self):
for sharded in (False, True):
for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
unused_v = variables.Variable(1.0, name="v")
variables.global_variables_initializer().run()
saver = saver_module.Saver(sharded=sharded, write_version=version)
@@ -290,7 +291,7 @@ class SaverUtilsTest(test.TestCase):
def testGetCheckpointMtimes(self):
prefixes = []
for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
unused_v = variables.Variable(1.0, name="v")
variables.global_variables_initializer().run()
saver = saver_module.Saver(write_version=version)
@@ -304,7 +305,7 @@ class SaverUtilsTest(test.TestCase):
def testRemoveCheckpoint(self):
for sharded in (False, True):
for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
unused_v = variables.Variable(1.0, name="v")
variables.global_variables_initializer().run()
saver = saver_module.Saver(sharded=sharded, write_version=version)
@@ -333,6 +334,49 @@ class CheckpointManagerTest(test.TestCase):
self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
@test_util.run_in_graph_and_eager_modes
+ def testKeepAll(self):
+ checkpoint = util.Checkpoint()
+ directory = os.path.join(
+ self.get_temp_dir(),
+ # Avoid sharing directories between eager and graph
+ # TODO(allenl): stop run_in_graph_and_eager_modes reusing directories
+ str(context.executing_eagerly()))
+ manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=None)
+ first_path = manager.save()
+ second_path = manager.save()
+ third_path = manager.save()
+ self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
+ self.assertEqual(third_path, manager.latest_checkpoint)
+ self.assertEqual([first_path, second_path, third_path],
+ manager.checkpoints)
+ del manager
+ manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=None)
+ fourth_path = manager.save()
+ self.assertEqual([first_path, second_path, third_path, fourth_path],
+ manager.checkpoints)
+ del manager
+ manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=3)
+ self.assertEqual([first_path, second_path, third_path, fourth_path],
+ manager.checkpoints)
+ self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
+ fifth_path = manager.save()
+ self.assertEqual([third_path, fourth_path, fifth_path],
+ manager.checkpoints)
+ self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
+ self.assertFalse(checkpoint_management.checkpoint_exists(second_path))
+ self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
+
+ @test_util.run_in_graph_and_eager_modes
@test.mock.patch.object(checkpoint_management, "time")
def testSaveRestoreState(self, mock_time):
directory = self.get_temp_dir()
@@ -345,8 +389,6 @@ class CheckpointManagerTest(test.TestCase):
mock_time.time.return_value = first_time
first_manager.save()
state = checkpoint_management.get_checkpoint_state(directory)
- self.assertEqual([first_time], state.all_model_checkpoint_timestamps)
- self.assertEqual(3., state.last_preserved_timestamp)
second_time = first_time + 3610.
second_name = os.path.join(directory, "ckpt-2")
mock_time.time.return_value = second_time
@@ -354,7 +396,6 @@ class CheckpointManagerTest(test.TestCase):
state = checkpoint_management.get_checkpoint_state(directory)
self.assertEqual([first_time, second_time],
state.all_model_checkpoint_timestamps)
- self.assertEqual(3., state.last_preserved_timestamp)
self.assertEqual([first_name, second_name], first_manager.checkpoints)
self.assertEqual(second_name, first_manager.latest_checkpoint)
del first_manager
diff --git a/tensorflow/python/training/checkpoint_ops.py b/tensorflow/python/training/checkpoint_ops.py
index a6e9662b73..cfd9b39ddc 100644
--- a/tensorflow/python/training/checkpoint_ops.py
+++ b/tensorflow/python/training/checkpoint_ops.py
@@ -268,7 +268,8 @@ def _load_and_remap_matrix_initializer(ckpt_path,
vocab files are the same, and no column remapping is done.
The returned initializer only supports div-partitioning along the row axis. It
- does not support partitioning along the column axis or mod-partitioning.
+ does not support partitioning along the column axis (as this is not common in
+ practice) or mod-partitioning.
NOTE: When this is used to warm-start variables, client code should use
`tf.lookup.index_table_from_tensor()` like
diff --git a/tensorflow/python/training/checkpoint_ops_test.py b/tensorflow/python/training/checkpoint_ops_test.py
index 00611de862..dde8431497 100644
--- a/tensorflow/python/training/checkpoint_ops_test.py
+++ b/tensorflow/python/training/checkpoint_ops_test.py
@@ -43,7 +43,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
# 0., 1., ..., 79. reshaped into [5, 16].
initializer = init_ops.constant_initializer(
np.reshape(np.linspace(0.0, 79, 5 * 16), (5, 16)))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope('some_scope'):
variable_scope.get_variable(name='embeddings', shape=[5, 16],
initializer=initializer)
@@ -114,7 +114,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
],
axis=1)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(expected_remapped_matrix, remapped_matrix.eval())
def test_load_and_remap_output_layer_weight_initializer_linear(self):
@@ -150,7 +150,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
initializer=loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_matrix,
remapped_matrix.as_tensor().eval())
@@ -184,7 +184,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
initializer=loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_matrix,
remapped_matrix.as_tensor().eval())
@@ -222,7 +222,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
initializer=loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_matrix,
remapped_matrix.as_tensor().eval())
@@ -258,7 +258,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
initializer=loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_matrix,
remapped_matrix.as_tensor().eval())
@@ -292,7 +292,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
initializer=embedding_loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_embeddings,
remapped_embeddings.as_tensor().eval())
@@ -338,7 +338,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
initializer=embedding_loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_embeddings,
remapped_embeddings.as_tensor().eval())
@@ -376,7 +376,7 @@ class LoadAndRemapWrappersTest(test.TestCase):
initializer=embedding_loading_initializer,
partitioner=partitioned_variables.fixed_size_partitioner(2))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAllClose(expected_remapped_embeddings,
remapped_embeddings.as_tensor().eval())
diff --git a/tensorflow/python/training/checkpoint_utils_test.py b/tensorflow/python/training/checkpoint_utils_test.py
index 1c1f126ce9..61dcbdb2b8 100644
--- a/tensorflow/python/training/checkpoint_utils_test.py
+++ b/tensorflow/python/training/checkpoint_utils_test.py
@@ -84,7 +84,7 @@ class CheckpointsTest(test.TestCase):
def testNoTensor(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_, _, _, _ = _create_checkpoints(session, checkpoint_dir)
with self.assertRaises(errors_impl.OpError):
self.assertAllEqual(
@@ -92,7 +92,7 @@ class CheckpointsTest(test.TestCase):
def testGetTensor(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
self.assertAllEqual(
checkpoint_utils.load_variable(checkpoint_dir, "var1"), v1)
@@ -105,7 +105,7 @@ class CheckpointsTest(test.TestCase):
def testGetAllVariables(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_create_checkpoints(session, checkpoint_dir)
self.assertEqual(
checkpoint_utils.list_variables(checkpoint_dir),
@@ -114,12 +114,12 @@ class CheckpointsTest(test.TestCase):
def testInitFromCheckpoint(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
with variable_scope.variable_scope("some_scope"):
my1 = variable_scope.get_variable("my1", [1, 10])
with variable_scope.variable_scope("some_other_scope"):
@@ -148,12 +148,12 @@ class CheckpointsTest(test.TestCase):
def testInitialValueComesFromCheckpoint(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
with variable_scope.variable_scope(
"some_scope", initializer=init_ops.zeros_initializer()):
my1 = variable_scope.get_variable("my1", [1, 10])
@@ -178,7 +178,7 @@ class CheckpointsTest(test.TestCase):
def testInitWithScopeDoesNotCaptureSuffixes(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_, _, _, v4 = _create_checkpoints(session, checkpoint_dir)
with ops.Graph().as_default() as g:
@@ -190,14 +190,14 @@ class CheckpointsTest(test.TestCase):
checkpoint_utils.init_from_checkpoint(checkpoint_dir,
{"useful_scope/": "useful_scope/"})
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
session.run(variables.global_variables_initializer())
self.assertAllEqual(my4.eval(session), v4)
self.assertAllEqual(my5.eval(session), my5_init)
def testRestoreRunsOnSameDevice(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_create_checkpoints(session, checkpoint_dir)
with ops.Graph().as_default():
@@ -213,12 +213,12 @@ class CheckpointsTest(test.TestCase):
def testInitFromRootCheckpoint(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
with variable_scope.variable_scope("some_scope"):
my1 = variable_scope.get_variable("var1", [1, 10])
my2 = variable_scope.get_variable("var2", [10, 10])
@@ -237,12 +237,12 @@ class CheckpointsTest(test.TestCase):
def testInitToRootCheckpoint(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
my1 = variable_scope.get_variable("var1", [1, 10])
my2 = variable_scope.get_variable("var2", [10, 10])
my3 = variable_scope.get_variable("var3", [100, 100])
@@ -260,12 +260,12 @@ class CheckpointsTest(test.TestCase):
def testInitFromPartitionVar(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1 = _create_partition_checkpoints(session, checkpoint_dir)
# New graph and session.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
with variable_scope.variable_scope("some_scope"):
my1 = variable_scope.get_variable(
name="my1",
@@ -303,7 +303,7 @@ class CheckpointsTest(test.TestCase):
# New graph and session.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
with variable_scope.variable_scope("some_scope"):
my1 = variable_scope.get_variable(
name="my1",
@@ -322,12 +322,12 @@ class CheckpointsTest(test.TestCase):
def testInitFromCheckpointMissing(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_, _, _, _ = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
with variable_scope.variable_scope("some_scope"):
_ = variable_scope.get_variable("my1", [10, 10])
_ = variable_scope.get_variable(
@@ -367,12 +367,12 @@ class CheckpointsTest(test.TestCase):
def testNoAdditionalReadOpsForResourceVariables(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
my1 = resource_variable_ops.ResourceVariable([[0.0] * 10], name="my1")
with ops.name_scope("init_from_checkpoint"):
diff --git a/tensorflow/python/training/checkpointable/BUILD b/tensorflow/python/training/checkpointable/BUILD
index 6e9b8ff905..d26932c1aa 100644
--- a/tensorflow/python/training/checkpointable/BUILD
+++ b/tensorflow/python/training/checkpointable/BUILD
@@ -101,15 +101,26 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":base",
+ ":data_structures",
":tracking",
+ "//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:checkpoint_management",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:init_ops",
"//tensorflow/python:io_ops_gen",
- "//tensorflow/python:ops",
+ "//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:saveable_object",
+ "//tensorflow/python:saver",
+ "//tensorflow/python:session",
+ "//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
"//tensorflow/python/eager:context",
],
)
diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py
index 390434c0a2..095a90ddd4 100644
--- a/tensorflow/python/training/checkpointable/base.py
+++ b/tensorflow/python/training/checkpointable/base.py
@@ -17,11 +17,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import abc
import collections
import functools
import json
import weakref
+import six
+
+from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -90,32 +94,85 @@ class CheckpointInitialValue(ops.Tensor):
return self._checkpoint_position
-class PythonStringStateSaveable(saveable_object.SaveableObject):
+class NoRestoreSaveable(saveable_object.SaveableObject):
+ """Embeds a tensor in a checkpoint with no restore ops."""
+
+ def __init__(self, tensor, name, dtype=None):
+ spec = saveable_object.SaveSpec(tensor, "", name, dtype=dtype)
+ super(NoRestoreSaveable, self).__init__(tensor, [spec], name)
+
+ def restore(self, restored_tensors, restored_shapes):
+ return control_flow_ops.no_op()
+
+
+@six.add_metaclass(abc.ABCMeta)
+class PythonStateSaveable(saveable_object.SaveableObject):
+ """An interface for saving/restoring volatile Python state."""
+
+ @abc.abstractmethod
+ def feed_dict_additions(self):
+ """When running a graph, indicates fresh state to feed.
+
+ Returns:
+ A dictionary mapping `Tensor`s to current Python state.
+ """
+ pass
+
+ @abc.abstractmethod
+ def freeze(self):
+ """Create a new `SaveableObject` which freezes current state as a constant.
+
+ Used when executing eagerly to embed the current state as a constant, or
+ when creating a static tf.train.Saver with the frozen current Python state.
+
+ Returns:
+ A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e. has
+ no Python state associated with it).
+ """
+ pass
+
+
+class PythonStringStateSaveable(PythonStateSaveable):
"""Saves Python state in a checkpoint."""
- def __init__(self, name, state_callback):
+ def __init__(self, name, state_callback, restore_callback=None):
"""Configure saving.
Args:
name: The checkpoint key to write to.
state_callback: A function taking no arguments which returns a
string. This function is run every time a checkpoint is written.
+ restore_callback: A function taking a Python string, used to restore
+ state. Optional; defaults to doing nothing.
"""
- if context.executing_eagerly():
- self._save_string = (
- lambda: constant_op.constant(state_callback(), dtype=dtypes.string))
- else:
+ self._state_callback = state_callback
+ self._restore_callback = restore_callback
+ with ops.device("/cpu:0"):
self._save_string = constant_op.constant("", dtype=dtypes.string)
- self.feed_dict_additions = (
- lambda: {self._save_string: state_callback()})
spec = saveable_object.SaveSpec(
self._save_string, "", name, dtype=dtypes.string)
super(PythonStringStateSaveable, self).__init__(
self._save_string, [spec], name)
+ def feed_dict_additions(self):
+ """When running a graph, indicates fresh state to feed."""
+ return {self._save_string: self._state_callback()}
+
+ def freeze(self):
+ """Create a frozen `SaveableObject` which saves the current state."""
+ return NoRestoreSaveable(
+ tensor=self._state_callback,
+ dtype=dtypes.string,
+ name=self.name)
+
+ def python_restore(self, restored_strings):
+ """Called to restore Python state."""
+ if self._restore_callback:
+ restored, = restored_strings
+ self._restore_callback(restored)
+
def restore(self, restored_tensors, restored_shapes):
- # TODO(allenl): Add a Python hook for state coming out of a checkpoint
- # (currently PythonStringStateSaveable is write-only).
+ """Called to restore TensorFlow state (nothing to do)."""
return control_flow_ops.no_op()
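The hunk above splits Python-state handling into an abstract `PythonStateSaveable` and gives `PythonStringStateSaveable` an optional restore path. Stripped of the TensorFlow plumbing, the save/restore callback contract looks roughly like this (the class and names are illustrative):

```python
import json

class Counter(object):
    def __init__(self):
        self.steps = 0

    def serialize(self):          # plays the role of state_callback
        return json.dumps({"steps": self.steps})

    def deserialize(self, blob):  # plays the role of restore_callback
        self.steps = json.loads(blob)["steps"]

c = Counter()
c.steps = 7
blob = c.serialize()   # what would be embedded in the checkpoint at save time
c2 = Counter()
c2.deserialize(blob)   # what python_restore() feeds back at restore time
assert c2.steps == 7
```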
@@ -227,7 +284,7 @@ class _CheckpointPosition(object):
with ops.device("/cpu:0"):
# Run the restore itself on the CPU.
value, = io_ops.restore_v2(
- prefix=self._checkpoint.save_path,
+ prefix=self._checkpoint.save_path_tensor,
tensor_names=[checkpoint_key],
shape_and_slices=[""],
dtypes=[base_type],
@@ -236,42 +293,99 @@ class _CheckpointPosition(object):
value_tensors[serialized_tensor.name] = array_ops.identity(value)
return value_tensors
- def restore_ops(self):
- """Create or fetch restore ops for this object's attributes.
-
- Requires that the `Checkpointable` Python object has been bound to an object
- ID in the checkpoint.
-
- Returns:
- A list of operations when graph building, or an empty list when executing
- eagerly.
- """
+ def _gather_ops_or_named_saveables(self):
+ """Looks up or creates SaveableObjects which don't have cached ops."""
saveables = self.checkpointable._gather_saveables_for_checkpoint() # pylint: disable=protected-access
# Name saveables based on the name this object had when it was checkpointed.
named_saveables = {}
- restore_ops = []
- building_graph = not context.executing_eagerly()
+ python_saveables = []
+ existing_restore_ops = []
for serialized_tensor in self.object_proto.attributes:
- saveable_factory = saveables.get(serialized_tensor.name, None)
- if saveable_factory is None:
- # Purposefully does not throw an exception if attributes have been added
- # or deleted. Stores unused attributes so an exception can be raised if
- # the user decides to check that everything in the checkpoint was
- # loaded.
- self._checkpoint.unused_attributes.setdefault(
- self.checkpointable, []).append(serialized_tensor.name)
+ if context.executing_eagerly():
+ existing_op = None
+ else:
+ existing_op = self._checkpoint.restore_ops_by_name.get(
+ serialized_tensor.checkpoint_key, None)
+ if existing_op is not None:
+ existing_restore_ops.append(existing_op)
continue
- if building_graph:
- existing_ops = self._checkpoint.restore_ops_by_name.get(
- serialized_tensor.name, None)
+
+      # If we don't have cached ops for this SaveableObject, check whether the
+      # SaveableObject itself has been cached. If not, create it; either way,
+      # extract new ops from it (or, if it has Python state to restore, run
+      # that).
+ if self._checkpoint.saveable_object_cache is None:
+ # No SaveableObject caching when executing eagerly.
+ saveable = None
else:
- existing_ops = None
- if existing_ops is None:
+ # If we've already created and cached a SaveableObject for this
+ # attribute, we can re-use it to avoid re-creating some ops when graph
+ # building.
+ saveable_list = self._checkpoint.saveable_object_cache.get(
+ self.checkpointable, {}).get(serialized_tensor.name, (None,))
+ if len(saveable_list) == 1:
+ # Almost every attribute will have exactly one SaveableObject.
+ saveable, = saveable_list
+ else:
+ # Don't use cached SaveableObjects for partitioned variables, which is
+ # the only case where we'd have a list of SaveableObjects. Op caching
+ # will catch them.
+ saveable = None
+ if saveable is not None:
+ # The name of this attribute has changed, so we need to re-generate
+ # the SaveableObject.
+ if serialized_tensor.checkpoint_key not in saveable.name:
+ saveable = None
+ del self._checkpoint.saveable_object_cache[self.checkpointable]
+ break
+ if saveable is None:
+ # If there was no cached SaveableObject, we should check if the Python
+ # object has the attribute.
+ saveable_factory = saveables.get(serialized_tensor.name, None)
+ if saveable_factory is None:
+ # Purposefully does not throw an exception if attributes have been
+ # added or deleted. Stores unused attributes so an exception can be
+ # raised if the user decides to check that everything in the
+ # checkpoint was loaded.
+ self._checkpoint.unused_attributes.setdefault(
+ self.checkpointable, []).append(serialized_tensor.name)
+ continue
if callable(saveable_factory):
saveable = saveable_factory(name=serialized_tensor.checkpoint_key)
else:
saveable = saveable_factory
+ if self._checkpoint.saveable_object_cache is not None:
+ self._checkpoint.saveable_object_cache.setdefault(
+ self.checkpointable, {})[serialized_tensor.name] = [saveable]
+ if isinstance(saveable, PythonStateSaveable):
+ python_saveables.append(saveable)
+ else:
named_saveables[serialized_tensor.checkpoint_key] = saveable
+ return existing_restore_ops, named_saveables, python_saveables
+
+ def restore_ops(self):
+ """Create or fetch restore ops for this object's attributes.
+
+ Requires that the `Checkpointable` Python object has been bound to an object
+ ID in the checkpoint.
+
+ Returns:
+ A list of operations when graph building, or an empty list when executing
+ eagerly.
+ """
+ (restore_ops,
+ named_saveables,
+ python_saveables) = self._gather_ops_or_named_saveables()
+
+ # Eagerly run restorations for Python state.
+ reader = pywrap_tensorflow.NewCheckpointReader(
+ self._checkpoint.save_path_string)
+ for saveable in python_saveables:
+ spec_names = [spec.name for spec in saveable.specs]
+ saveable.python_restore(
+ [reader.get_tensor(name) for name in spec_names])
+
+ # If we have new SaveableObjects, extract and cache restore ops.
if named_saveables:
validated_saveables = (
self._checkpoint.builder._ValidateAndSliceInputs(named_saveables)) # pylint: disable=protected-access
@@ -281,7 +395,7 @@ class _CheckpointPosition(object):
("Saveable keys changed when validating. Got back %s, was "
"expecting %s") % (named_saveables.keys(), validated_names))
all_tensors = self._checkpoint.builder.bulk_restore(
- filename_tensor=self._checkpoint.save_path,
+ filename_tensor=self._checkpoint.save_path_tensor,
saveables=validated_saveables, preferred_shard=-1,
restore_sequentially=False)
saveable_index = 0
@@ -291,7 +405,7 @@ class _CheckpointPosition(object):
saveable_index:saveable_index + num_specs]
saveable_index += num_specs
restore_op = saveable.restore(saveable_tensors, restored_shapes=None)
- if building_graph:
+ if not context.executing_eagerly():
assert saveable.name not in self._checkpoint.restore_ops_by_name
self._checkpoint.restore_ops_by_name[saveable.name] = restore_op
restore_ops.append(restore_op)
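Most of the restructuring above exists so that restore ops and `SaveableObject`s can be cached and shared across repeated `restore()` calls while graph building. The caching idea in isolation (a hypothetical helper, not the module's API):

```python
restore_ops_by_name = {}

def get_or_create_restore_op(checkpoint_key, build_op):
    """Build each restore op once; hand back the cached op on later calls."""
    op = restore_ops_by_name.get(checkpoint_key)
    if op is None:
        op = build_op()
        restore_ops_by_name[checkpoint_key] = op
    return op
```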
@@ -753,7 +867,7 @@ class CheckpointableBase(object):
def _state_callback():
dereferenced_self = weak_self()
if dereferenced_self:
- return json.dumps(self,
+ return json.dumps(dereferenced_self,
default=serialization.get_json_type,
sort_keys=True).encode("utf8")
else:
diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py
index 507cda8734..c29e5db075 100644
--- a/tensorflow/python/training/checkpointable/data_structures.py
+++ b/tensorflow/python/training/checkpointable/data_structures.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import copy
import six
@@ -128,7 +129,8 @@ class CheckpointableDataStructure(base.CheckpointableBase):
"stored in a List object. Got %s, which does not inherit from "
"CheckpointableBase.") % (value,))
if (isinstance(value, CheckpointableDataStructure)
- or layer_utils.is_layer(value)):
+ or layer_utils.is_layer(value)
+ or layer_utils.has_weights(value)):
# Check for object-identity rather than with __eq__ to avoid
# de-duplicating empty container types. Automatically generated list
# wrappers keep things like "[] == []" true, which means "[] in [[]]" is
@@ -149,14 +151,14 @@ class CheckpointableDataStructure(base.CheckpointableBase):
def trainable_weights(self):
return layer_utils.gather_trainable_weights(
trainable=self.trainable,
- sub_layers=self.layers,
+ sub_layers=self._layers,
extra_variables=self._extra_variables)
@property
def non_trainable_weights(self):
return layer_utils.gather_non_trainable_weights(
trainable=self.trainable,
- sub_layers=self.layers,
+ sub_layers=self._layers,
extra_variables=self._extra_variables)
@property
@@ -183,7 +185,8 @@ class CheckpointableDataStructure(base.CheckpointableBase):
# have any inputs.
aggregated = []
for layer in self.layers:
- aggregated += layer.updates
+ if hasattr(layer, "updates"):
+ aggregated += layer.updates
return aggregated
@property
@@ -191,7 +194,8 @@ class CheckpointableDataStructure(base.CheckpointableBase):
"""Aggregate losses from any `Layer` instances."""
aggregated = []
for layer in self.layers:
- aggregated += layer.losses
+ if hasattr(layer, "losses"):
+ aggregated += layer.losses
return aggregated
def __hash__(self):
@@ -248,6 +252,12 @@ class List(CheckpointableDataStructure, collections.Sequence):
self._storage[index] = self._track_value(
element, name=self._name_element(index))
+ def __copy__(self):
+ return type(self)(copy.copy(self._storage))
+
+ def __deepcopy__(self, memo):
+ return type(self)(copy.deepcopy(self._storage, memo))
+
def _make_storage(self, *args, **kwargs):
"""Determines the backing storage (overridden in subclasses)."""
return list(*args, **kwargs)
@@ -322,6 +332,20 @@ class _ListWrapper(List, collections.MutableSequence,
super(_ListWrapper, self).__init__(wrapped_list)
self._last_wrapped_list_snapshot = list(self._storage)
+ # pylint: disable=protected-access
+ def __copy__(self):
+ copied = super(_ListWrapper, self).__copy__()
+ copied._non_append_mutation = self._non_append_mutation
+ copied._external_modification = self._external_modification
+ return copied
+
+ def __deepcopy__(self, memo):
+ copied = super(_ListWrapper, self).__deepcopy__(memo)
+ copied._non_append_mutation = self._non_append_mutation
+ copied._external_modification = self._external_modification
+ return copied
+ # pylint: enable=protected-access
+
def _make_storage(self, wrapped_list):
"""Use the user's original list for storage."""
return wrapped_list
@@ -446,6 +470,12 @@ class Mapping(CheckpointableDataStructure, collections.Mapping):
value, name=self._name_element(key))
for key, value in self._storage.items()})
+ def __copy__(self):
+ return type(self)(copy.copy(self._storage))
+
+ def __deepcopy__(self, memo):
+ return type(self)(copy.deepcopy(self._storage, memo))
+
def _make_storage(self, *args, **kwargs):
return dict(*args, **kwargs)
@@ -522,6 +552,22 @@ class _DictWrapper(Mapping, collections.MutableMapping):
super(_DictWrapper, self).__init__(wrapped_dict)
self._update_snapshot()
+ # pylint: disable=protected-access
+ def __copy__(self):
+ copied = super(_DictWrapper, self).__copy__()
+ copied._non_append_mutation = self._non_append_mutation
+ copied._external_modification = self._external_modification
+ copied._non_string_key = self._non_string_key
+ return copied
+
+ def __deepcopy__(self, memo):
+ copied = super(_DictWrapper, self).__deepcopy__(memo)
+ copied._non_append_mutation = self._non_append_mutation
+ copied._external_modification = self._external_modification
+ copied._non_string_key = self._non_string_key
+ return copied
+ # pylint: enable=protected-access
+
def _make_storage(self, wrapped_dict):
"""Re-use the wrapped dict for storage (to force them to be in sync)."""
return wrapped_dict
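The `__copy__`/`__deepcopy__` additions above let the checkpointable wrappers behave like the built-in containers they wrap while carrying their modification flags along to the copy. The pattern in miniature (a simplified stand-in class, not the wrapper itself):

```python
import copy

class TrackedList(object):
    def __init__(self, storage):
        self._storage = list(storage)
        self._external_modification = False  # "dirtiness" flag

    def __copy__(self):
        copied = TrackedList(copy.copy(self._storage))
        copied._external_modification = self._external_modification
        return copied

    def __deepcopy__(self, memo):
        copied = TrackedList(copy.deepcopy(self._storage, memo))
        copied._external_modification = self._external_modification
        return copied

a = TrackedList([[1.]])
shallow, deep = copy.copy(a), copy.deepcopy(a)
assert shallow._storage[0] is a._storage[0]   # shallow copy shares elements
assert deep._storage[0] is not a._storage[0]  # deep copy is fully independent
```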
diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py
index 472b7c32b4..5597c7c772 100644
--- a/tensorflow/python/training/checkpointable/data_structures_test.py
+++ b/tensorflow/python/training/checkpointable/data_structures_test.py
@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import copy
import os
import numpy
@@ -31,6 +32,7 @@ from tensorflow.python.layers import core as non_keras_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
@@ -96,6 +98,11 @@ class ListTests(test.TestCase):
model.load_weights(save_path)
self.assertAllEqual([[1., 2., 3.], [4., 5., 6.]],
self.evaluate(model.variables[0]))
+ v = variables.Variable(1.)
+ model.var_list = [v]
+ self.assertIn(v, model.variables)
+ self.assertIn(v, model.trainable_variables)
+ self.assertNotIn(v, model.non_trainable_variables)
def testUpdatesForwarded(self):
with context.graph_mode():
@@ -418,6 +425,104 @@ class MappingTests(test.TestCase):
new_dict.update(model.d)
self.assertEqual({1: 3}, new_dict)
+ def testListShallowCopy(self):
+ root = tracking.Checkpointable()
+ orig_list = [[1.]]
+ root.a = orig_list
+ copied = copy.copy(root.a)
+ self.assertAllEqual([[1.]], copied)
+ self.assertIsNot(root.a, copied)
+ self.assertIs(root.a[0], copied[0])
+
+ # Dirtiness should be inherited
+ util.list_objects(root.a)
+ orig_list.append(1.)
+ with self.assertRaises(ValueError):
+ util.list_objects(root.a)
+ with self.assertRaises(ValueError):
+ util.list_objects(copy.copy(root.a))
+
+ def testListDeepCopy(self):
+ root = tracking.Checkpointable()
+ orig_list = [[1.]]
+ root.a = orig_list
+ copied = copy.deepcopy(root.a)
+ self.assertAllEqual([[1.]], copied)
+ self.assertIsNot(root.a, copied)
+ self.assertIsNot(root.a[0], copied[0])
+
+ # Dirtiness should be inherited
+ util.list_objects(root.a)
+ orig_list.append(1.)
+ with self.assertRaises(ValueError):
+ util.list_objects(root.a)
+ with self.assertRaises(ValueError):
+ util.list_objects(copy.deepcopy(root.a))
+
+ def testDictShallowCopy(self):
+ root = tracking.Checkpointable()
+ orig_dict = {"a": [1.]}
+ root.a = orig_dict
+ copied = copy.copy(root.a)
+ self.assertAllEqual([1.], copied["a"])
+ self.assertIsNot(root.a, copied)
+ self.assertIs(root.a["a"], copied["a"])
+
+ # Dirtiness should be inherited
+ util.list_objects(root.a)
+ orig_dict["b"] = []
+ with self.assertRaises(ValueError):
+ util.list_objects(root.a)
+ with self.assertRaises(ValueError):
+ util.list_objects(copy.copy(root.a))
+
+ def testDictDeepCopy(self):
+ root = tracking.Checkpointable()
+ orig_dict = {"a": [1.]}
+ root.a = orig_dict
+ copied = copy.deepcopy(root.a)
+ self.assertAllEqual([1.], copied["a"])
+ self.assertIsNot(root.a, copied)
+ self.assertIsNot(root.a["a"], copied["a"])
+
+ # Dirtiness should be inherited
+ util.list_objects(root.a)
+ orig_dict["b"] = []
+ with self.assertRaises(ValueError):
+ util.list_objects(root.a)
+ with self.assertRaises(ValueError):
+ util.list_objects(copy.deepcopy(root.a))
+
+ def testShallowCopyCheckpointable(self):
+ original = tracking.Checkpointable()
+ original_sub = tracking.Checkpointable()
+ original.a = [[1.]]
+ original.b = {"a": original_sub}
+ shallow_copied = copy.copy(original)
+ self.assertIs(original_sub, shallow_copied.b["a"])
+ self.assertIsNot(original, shallow_copied)
+ self.assertEqual([[1.]], shallow_copied.a)
+ shallow_deps = util.list_objects(shallow_copied)
+ self.assertIn(shallow_copied.a, shallow_deps)
+ self.assertIn(shallow_copied.b, shallow_deps)
+ self.assertIn(shallow_copied.b["a"], shallow_deps)
+
+ def testDeepCopyCheckpointable(self):
+ original = tracking.Checkpointable()
+ original_sub = tracking.Checkpointable()
+ original.a = [[1.]]
+ original.b = {"a": original_sub}
+ deep_copied = copy.deepcopy(original)
+ self.assertIsNot(original, deep_copied)
+ self.assertIsNot(original_sub, deep_copied.b["a"])
+ self.assertEqual([[1.]], deep_copied.a)
+ self.assertIsInstance(deep_copied.b["a"], tracking.Checkpointable)
+ deps = util.list_objects(deep_copied)
+ self.assertIn(deep_copied.a, deps)
+ self.assertIn(deep_copied.b, deps)
+ self.assertIn(deep_copied.b["a"], deps)
+ self.assertNotIn(original_sub, deps)
+
def testConstructableFromSequence(self):
result = data_structures._DictWrapper([(1, 2), (3, 4)])
self.assertIsInstance(result, dict)
diff --git a/tensorflow/python/training/checkpointable/layer_utils.py b/tensorflow/python/training/checkpointable/layer_utils.py
index d65b631fe9..ec764bca89 100644
--- a/tensorflow/python/training/checkpointable/layer_utils.py
+++ b/tensorflow/python/training/checkpointable/layer_utils.py
@@ -30,13 +30,20 @@ def is_layer(obj):
and hasattr(obj, "variables"))
+def has_weights(obj):
+ """Implicit check for Layer-like objects."""
+ # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer).
+ return (hasattr(obj, "trainable_weights")
+ and hasattr(obj, "non_trainable_weights"))
+
+
def filter_empty_layer_containers(layer_list):
"""Filter out empty Layer-like containers."""
filtered = []
for obj in layer_list:
if is_layer(obj):
filtered.append(obj)
- else:
+ elif hasattr(obj, "layers"):
# Checkpointable data structures will not show up in ".layers" lists, but
# the layers they contain will.
filtered.extend(obj.layers)
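`has_weights` is a duck-typing check, so any object exposing both weight lists is gathered by the data structures even if it is not a `Layer` subclass. A tiny illustration (the holder class is hypothetical):

```python
class WeightHolder(object):
    """Not a Layer, but quacks like one for weight gathering."""
    trainable_weights = []
    non_trainable_weights = []

def has_weights(obj):  # mirrors the helper added above
    return (hasattr(obj, "trainable_weights")
            and hasattr(obj, "non_trainable_weights"))

assert has_weights(WeightHolder())
assert not has_weights(object())
```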
diff --git a/tensorflow/python/training/checkpointable/tracking_test.py b/tensorflow/python/training/checkpointable/tracking_test.py
index e85f812ce2..a44c570fb9 100644
--- a/tensorflow/python/training/checkpointable/tracking_test.py
+++ b/tensorflow/python/training/checkpointable/tracking_test.py
@@ -165,7 +165,7 @@ class InterfaceTests(test.TestCase):
self.assertEqual([c], a.attribute["c"].layers)
checkpoint = util.Checkpoint(a=a)
save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
- with self.test_session():
+ with self.cached_session():
checkpoint.restore(save_path).assert_consumed().initialize_or_restore()
@test_util.run_in_graph_and_eager_modes
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index e42f989469..56c4043d9d 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -32,7 +32,6 @@ from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops as io_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
@@ -68,16 +67,25 @@ _OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"
class _CheckpointRestoreCoordinator(object):
"""Holds the status of an object-based checkpoint load."""
- def __init__(self, object_graph_proto, save_path, dtype_map=None):
+ def __init__(self, object_graph_proto, save_path, save_path_tensor,
+ restore_op_cache, saveable_object_cache):
"""Specify the checkpoint being loaded.
Args:
object_graph_proto: The CheckpointableObjectGraph protocol buffer
associated with this checkpoint.
- save_path: A string `Tensor`. The path to the checkpoint, as returned by
+ save_path: A string, the path to the checkpoint, as returned by
`tf.train.latest_checkpoint`.
- dtype_map: When executing eagerly, specifies dtypes for creating slot
- variables. None when graph building.
+ save_path_tensor: A string `Tensor` which contains or will be fed the save
+ path.
+ restore_op_cache: A dictionary shared between
+ `_CheckpointRestoreCoordinator`s for the same Python objects, used to
+ look up restore ops by name to avoid re-creating them across multiple
+ `restore()` calls.
+ saveable_object_cache: A mapping of checkpointable objects -> attribute
+ names -> list(`SaveableObject`s), used when `SaveableObject`s must be
+ referenced on every restore (e.g. for Python state); without the cache
+ they would create new ops on each `restore()` call.
"""
self.builder = saver_lib.BulkSaverBuilder()
self.object_graph_proto = object_graph_proto
@@ -97,12 +105,17 @@ class _CheckpointRestoreCoordinator(object):
# loading). Used to make status assertions fail when loading checkpoints
# that don't quite match.
self.all_python_objects = _ObjectIdentityWeakSet()
- self.save_path = save_path
- self.dtype_map = dtype_map
+ self.save_path_tensor = save_path_tensor
+ self.save_path_string = save_path
+ self.dtype_map = pywrap_tensorflow.NewCheckpointReader(
+ save_path).get_variable_to_dtype_map()
+ # A NewCheckpointReader for the most recent checkpoint, for streaming Python
+ # state restoration.
# When graph building, contains a list of ops to run to restore objects from
# this checkpoint.
self.restore_ops = []
- self.restore_ops_by_name = {}
+ self.restore_ops_by_name = restore_op_cache
+ self.saveable_object_cache = saveable_object_cache
self.new_restore_ops_callback = None
# A mapping from optimizer proto ids to lists of slot variables to be
# restored when the optimizer is tracked. Only includes slot variables whose
@@ -185,6 +198,7 @@ class _NameBasedRestoreCoordinator(object):
for saveable in self.globally_named_object_attributes(
checkpointable):
restored_tensors = []
+ tensor_missing = False
for spec in saveable.specs:
if spec.name in self.dtype_map:
with ops.device("cpu:0"):
@@ -195,9 +209,15 @@ class _NameBasedRestoreCoordinator(object):
dtypes=[self.dtype_map[spec.name]],
name="%s_checkpoint_read" % (spec.name,))
restored_tensors.append(array_ops.identity(restored))
+ else:
+ tensor_missing = True
- saveable.restore(restored_tensors=restored_tensors,
- restored_shapes=None)
+ if not tensor_missing:
+ # Ignores values missing from the checkpoint, as with object-based
+ # restore. Status assertions can be used to check for exact matches,
+ # though exact matching is rarely needed for name-based checkpoints.
+ saveable.restore(restored_tensors=restored_tensors,
+ restored_shapes=None)
# TODO(allenl): If this ends up in a public API, consider adding LINT.IfChange
@@ -536,7 +556,14 @@ def _serialize_checkpointables(
object_graph_proto = (
checkpointable_object_graph_pb2.CheckpointableObjectGraph())
named_saveables = []
- feed_additions = {}
+ if saveables_cache is None:
+ # No SaveableObject caching. Either we're executing eagerly, or building a
+ # static save which is specialized to the current Python state.
+ feed_additions = None
+ else:
+ # If we are caching SaveableObjects, we need to build up a feed_dict whose
+ # values are computed from volatile Python state each time a save runs.
+ feed_additions = {}
for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
assert node_ids[checkpointable] == checkpoint_id
object_proto = object_graph_proto.nodes.add()
@@ -595,18 +622,25 @@ def _serialize_checkpointables(
for saveable in saveables:
if hasattr(saveable, "full_name"):
attribute.full_name = saveable.full_name
- saveable_feed_dict_fn = getattr(saveable, "feed_dict_additions", None)
- if saveable_feed_dict_fn is not None:
- saveable_feed_dict = saveable_feed_dict_fn() # pylint: disable=not-callable
- for new_feed_key in saveable_feed_dict.keys():
- if new_feed_key in feed_additions:
- raise AssertionError(
- ("The object %s tried to feed a value for the Tensor %s "
- "when saving, but another object is already feeding a "
- "value.")
- % (checkpointable, new_feed_key))
- feed_additions.update(saveable_feed_dict)
- named_saveables.extend(saveables)
+ if isinstance(saveable, base.PythonStateSaveable):
+ if feed_additions is None:
+ assert saveables_cache is None
+ # If we're not caching saveables, then we're either executing
+ # eagerly or building a static save/restore (e.g. for a
+ # SavedModel). In either case, we should embed the current Python
+ # state in the graph rather than relying on a feed dict.
+ saveable = saveable.freeze()
+ else:
+ saveable_feed_dict = saveable.feed_dict_additions()
+ for new_feed_key in saveable_feed_dict.keys():
+ if new_feed_key in feed_additions:
+ raise AssertionError(
+ ("The object %s tried to feed a value for the Tensor %s "
+ "when saving, but another object is already feeding a "
+ "value.")
+ % (checkpointable, new_feed_key))
+ feed_additions.update(saveable_feed_dict)
+ named_saveables.append(saveable)
for child in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access
child_proto = object_proto.children.add()
@@ -664,6 +698,11 @@ def _serialize_object_graph(root_checkpointable, saveables_cache):
saveables_cache=saveables_cache)
+def named_saveables(root_checkpointable):
+ """Gather list of all SaveableObjects in the Checkpointable object."""
+ return _serialize_object_graph(root_checkpointable, None)[0]
+
+
def list_objects(root_checkpointable):
"""Traverse the object graph and list all accessible objects.
@@ -801,16 +840,6 @@ def capture_dependencies(template):
yield
-class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
-
- def __init__(self, tensor, name):
- spec = saver_lib.BaseSaverBuilder.SaveSpec(tensor, "", name)
- super(_NoRestoreSaveable, self).__init__(tensor, [spec], name)
-
- def restore(self, restored_tensors, restored_shapes):
- return control_flow_ops.no_op()
-
-
class _LoadStatus(object):
"""Abstract base for load status callbacks."""
@@ -820,6 +849,11 @@ class _LoadStatus(object):
pass
@abc.abstractmethod
+ def assert_existing_objects_matched(self):
+ """Raises an exception unless existing Python objects have been matched."""
+ pass
+
+ @abc.abstractmethod
def run_restore_ops(self, session=None):
"""Runs restore ops from the checkpoint. Requires a valid checkpoint."""
pass
@@ -889,13 +923,11 @@ class CheckpointLoadStatus(_LoadStatus):
or if there are any checkpointed values which have not been matched to
Python objects.
"""
+ self.assert_existing_objects_matched()
for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes):
checkpointable = self._checkpoint.object_by_proto_id.get(node_id, None)
if checkpointable is None:
raise AssertionError("Unresolved object in checkpoint: %s" % (node,))
- if checkpointable._update_uid < self._checkpoint.restore_uid: # pylint: disable=protected-access
- raise AssertionError(
- "Object not assigned a value from checkpoint: %s" % (node,))
if self._checkpoint.slot_restorations:
# Sanity check; this collection should be clear if everything has been
# restored.
@@ -906,6 +938,31 @@ class CheckpointLoadStatus(_LoadStatus):
("Unused attributes in these objects (the attributes exist in the "
"checkpoint but not in the objects): %s") % (
self._checkpoint.unused_attributes.items(),))
+ return self
+
+ def assert_existing_objects_matched(self):
+ """Asserts that checkpointable Python objects have been matched.
+
+ Note that this is a weaker assertion than `assert_consumed`. It will only
+ fail for existing Python objects which are (transitive) dependencies of the
+ root object and which do not have an entry in the checkpoint.
+
+ It will not fail, for example, if a `tf.keras.Layer` object has not yet been
+ built and so has not created any `tf.Variable` objects.
+
+ Returns:
+ `self` for chaining.
+
+ Raises:
+ AssertionError: If a Python object exists in the transitive dependencies
+ of the root object but does not have a value in the checkpoint.
+ """
+ for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes):
+ checkpointable = self._checkpoint.object_by_proto_id.get(node_id, None)
+ if (checkpointable is not None
+ and checkpointable._update_uid < self._checkpoint.restore_uid): # pylint: disable=protected-access
+ raise AssertionError(
+ "Object not assigned a value from checkpoint: %s" % (node,))
for checkpointable_object in list_objects(self._root_checkpointable):
self._checkpoint.all_python_objects.add(checkpointable_object)
unused_python_objects = (
@@ -915,7 +972,7 @@ class CheckpointLoadStatus(_LoadStatus):
raise AssertionError(
("Some Python objects were not bound to checkpointed values, likely "
"due to changes in the Python program: %s")
- % (unused_python_objects,))
+ % (list(unused_python_objects),))
return self
def run_restore_ops(self, session=None):
@@ -977,6 +1034,11 @@ class InitializationOnlyStatus(_LoadStatus):
raise AssertionError(
"No checkpoint specified (save_path=None); nothing is being restored.")
+ def assert_existing_objects_matched(self):
+ """Assertion for consistency with `CheckpointLoadStatus`. Always fails."""
+ raise AssertionError(
+ "No checkpoint specified (save_path=None); nothing is being restored.")
+
def run_restore_ops(self, session=None):
"""For consistency with `CheckpointLoadStatus`.
@@ -1050,6 +1112,15 @@ class NameBasedSaverStatus(_LoadStatus):
if checkpointable._update_uid < self._checkpoint.restore_uid:
raise AssertionError("Object not restored: %s" % (checkpointable,))
# pylint: enable=protected-access
+ return self
+
+ def assert_existing_objects_matched(self):
+ """Raises an exception if currently created objects are unmatched."""
+ # For name-based checkpoints there's no object information in the
+ # checkpoint, so there's no distinction between
+ # assert_existing_objects_matched and assert_consumed (and both are less
+ # useful since we don't touch Python objects or Python state).
+ return self.assert_consumed()
def _gather_saveable_objects(self):
"""Walk the object graph, using global names for SaveableObjects."""
@@ -1153,16 +1224,15 @@ class CheckpointableSaver(object):
self._last_save_object_graph = None
self._last_save_saver = None
- # Op caching for restore
- self._last_restore_object_graph = None
- self._last_restore_checkpoint = None
+ # Op caching for restore, shared between _CheckpointRestoreCoordinators
+ self._restore_op_cache = {}
if context.executing_eagerly():
# SaveableObjects are always recreated when executing eagerly.
self._saveable_object_cache = None
else:
- # Maps Checkpointable objects -> attribute names -> SaveableObjects, to
- # avoid re-creating SaveableObjects when graph building.
+ # Maps Checkpointable objects -> attribute names -> list(SaveableObjects),
+ # to avoid re-creating SaveableObjects when graph building.
self._saveable_object_cache = _ObjectIdentityWeakKeyDictionary()
@property
@@ -1174,6 +1244,78 @@ class CheckpointableSaver(object):
else:
return self._root_checkpointable_ref
+ def _gather_saveables(
+ self, object_graph_tensor=None, saveable_object_cache=None):
+ """Wraps _serialize_object_graph to include the object graph proto."""
+ assert ((object_graph_tensor is None and saveable_object_cache is None)
+ or (object_graph_tensor is not None
+ and saveable_object_cache is not None))
+ (named_saveable_objects, graph_proto,
+ feed_additions) = _serialize_object_graph(
+ self._root_checkpointable,
+ saveables_cache=saveable_object_cache)
+ if object_graph_tensor is None:
+ with ops.device("/cpu:0"):
+ object_graph_tensor = constant_op.constant(
+ graph_proto.SerializeToString(), dtype=dtypes.string)
+ else:
+ feed_additions.update(
+ {object_graph_tensor: graph_proto.SerializeToString()})
+ assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects
+ named_saveable_objects.append(
+ base.NoRestoreSaveable(
+ tensor=object_graph_tensor,
+ name=base.OBJECT_GRAPH_PROTO_KEY))
+ return named_saveable_objects, graph_proto, feed_additions
+
+ def freeze(self):
+ """Creates a `tf.train.Saver` with the current object graph frozen."""
+ named_saveable_objects, _, _ = self._gather_saveables(
+ object_graph_tensor=None, saveable_object_cache=None)
+ return saver_lib.Saver(
+ var_list=named_saveable_objects, max_to_keep=None)
+
+ def _prepare_save(self,
+ object_graph_tensor=None,
+ saveable_object_cache=None):
+ """Create or retrieve save ops.
+
+ When graph building, `saveable_object_cache` will typically be non-`None`,
+ meaning that existing `SaveableObject`s are re-used across calls to
+ `_prepare_save` even if the object graph has grown. This avoids
+ unnecessarily re-creating save ops.
+
+ Args:
+ object_graph_tensor: A `Tensor` to which the current object graph will be
+ fed.
+ saveable_object_cache: A dictionary; if specified, used to cache
+ `SaveableObject`s.
+
+ Returns:
+ A two-element tuple with a `tf.train.Saver` and a feed_dict of `Tensor`s
+ to feed when running save ops. The feed dict contains the current object
+ graph and any Python state to be saved in the checkpoint.
+ """
+ (named_saveable_objects, graph_proto,
+ feed_additions) = self._gather_saveables(
+ object_graph_tensor=object_graph_tensor,
+ saveable_object_cache=saveable_object_cache)
+ if (self._last_save_object_graph != graph_proto
+ # When executing eagerly, we need to re-create SaveableObjects each time
+ # save() is called so they pick up new Tensors passed to their
+ # constructors. That means the Saver needs to be copied with a new
+ # var_list.
+ or context.executing_eagerly()):
+ if self._last_save_object_graph is not None:
+ self._last_save_saver = _copy_saver_with_new_var_list(
+ old_saver=self._last_save_saver,
+ new_var_list=named_saveable_objects)
+ else:
+ self._last_save_saver = saver_lib.Saver(
+ var_list=named_saveable_objects, max_to_keep=None)
+ self._last_save_object_graph = graph_proto
+ return self._last_save_saver, feed_additions
+
def save(self, file_prefix, checkpoint_number=None, session=None):
"""Save a training checkpoint.
@@ -1196,44 +1338,29 @@ class CheckpointableSaver(object):
Returns:
The full path to the checkpoint.
"""
- named_variables, graph_proto, feed_additions = _serialize_object_graph(
- self._root_checkpointable,
- saveables_cache=self._saveable_object_cache)
- if not context.executing_eagerly():
- if session is None:
- session = ops.get_default_session()
+ feed_additions = {}
+ graph_building = not context.executing_eagerly()
+ if graph_building:
if self._object_graph_feed_tensor is None:
with ops.device("/cpu:0"):
self._object_graph_feed_tensor = constant_op.constant(
"", dtype=dtypes.string)
object_graph_tensor = self._object_graph_feed_tensor
- feed_additions.update(
- {object_graph_tensor: graph_proto.SerializeToString()})
else:
+ object_graph_tensor = None
+
+ saver, new_feed_additions = self._prepare_save(
+ object_graph_tensor=object_graph_tensor,
+ saveable_object_cache=self._saveable_object_cache)
+ if new_feed_additions:
+ feed_additions.update(new_feed_additions)
+ if not graph_building:
session = None
- with ops.device("/cpu:0"):
- object_graph_tensor = constant_op.constant(
- graph_proto.SerializeToString(), dtype=dtypes.string)
- assert base.OBJECT_GRAPH_PROTO_KEY not in named_variables
- named_variables.append(
- _NoRestoreSaveable(
- tensor=object_graph_tensor,
- name=base.OBJECT_GRAPH_PROTO_KEY))
- if (self._last_save_object_graph != graph_proto
- # When executing eagerly, we need to re-create SaveableObjects each time
- # save() is called so they pick up new Tensors passed to their
- # constructors. That means the Saver needs to be copied with a new
- # var_list.
- or context.executing_eagerly()):
- if self._last_save_object_graph is not None:
- self._last_save_saver = _copy_saver_with_new_var_list(
- old_saver=self._last_save_saver, new_var_list=named_variables)
- else:
- self._last_save_saver = saver_lib.Saver(
- var_list=named_variables, max_to_keep=None)
- self._last_save_object_graph = graph_proto
+ elif session is None:
+ session = ops.get_default_session()
+
with ops.device("/cpu:0"):
- save_path = self._last_save_saver.save(
+ save_path = saver.save(
sess=_SessionWithFeedDictAdditions(
session=session, feed_additions=feed_additions),
save_path=file_prefix,
@@ -1340,22 +1467,12 @@ class CheckpointableSaver(object):
object_graph_proto = (
checkpointable_object_graph_pb2.CheckpointableObjectGraph())
object_graph_proto.ParseFromString(object_graph_string)
- if graph_building and object_graph_proto == self._last_restore_object_graph:
- checkpoint = self._last_restore_checkpoint
- else:
- checkpoint = _CheckpointRestoreCoordinator(
- object_graph_proto=object_graph_proto,
- save_path=file_prefix_tensor,
- dtype_map=dtype_map)
- if graph_building:
- if self._last_restore_object_graph is not None:
- raise NotImplementedError(
- "Using a single Saver to restore different object graphs is not "
- "currently supported when graph building. Use a different Saver "
- "for each object graph (restore ops will be duplicated), or "
- "file a feature request if this limitation bothers you.")
- self._last_restore_checkpoint = checkpoint
- self._last_restore_object_graph = object_graph_proto
+ checkpoint = _CheckpointRestoreCoordinator(
+ object_graph_proto=object_graph_proto,
+ save_path=save_path,
+ save_path_tensor=file_prefix_tensor,
+ restore_op_cache=self._restore_op_cache,
+ saveable_object_cache=self._saveable_object_cache)
base._CheckpointPosition( # pylint: disable=protected-access
checkpoint=checkpoint, proto_id=0).restore(self._root_checkpointable)
load_status = CheckpointLoadStatus(
@@ -1365,6 +1482,30 @@ class CheckpointableSaver(object):
return load_status
+def frozen_saver(root_checkpointable):
+ """Creates a static `tf.train.Saver` from a checkpointable object.
+
+ The returned `Saver` saves object-based checkpoints, but these checkpoints
+ will no longer reflect structural changes to the object graph, only changes to
+ the values of `Variable`s added as dependencies of the root object before
+ `freeze` was called.
+
+ `restore` works on the returned `Saver`, but requires that the object graph of
+ the checkpoint being loaded exactly matches the object graph when `freeze` was
+ called. This is in contrast to the object-based restore performed by
+ `tf.train.Checkpoint`, which attempts a fuzzy matching between a checkpoint's
+ object graph and the current Python object graph.
+
+ Args:
+ root_checkpointable: A checkpointable object to save.
+
+ Returns:
+ A `tf.train.Saver` which saves object-based checkpoints for the object graph
+ frozen at the time `frozen_saver` was called.
+ """
+ return CheckpointableSaver(root_checkpointable).freeze()
+
+
@tf_export("train.Checkpoint")
class Checkpoint(tracking.Checkpointable):
"""Groups checkpointable objects, saving and restoring them.
@@ -1644,6 +1785,17 @@ class Checkpoint(tracking.Checkpointable):
Python objects in the dependency graph with no values in the
checkpoint. This method returns the status object, and so may be
chained with `initialize_or_restore` or `run_restore_ops`.
+ - `assert_existing_objects_matched()`:
+ Raises an exception if any existing Python objects in the dependency
+ graph are unmatched. Unlike `assert_consumed`, this assertion will
+ pass if values in the checkpoint have no corresponding Python
+ objects. For example a `tf.keras.Layer` object which has not yet been
+ built, and so has not created any variables, will pass this assertion
+ but fail `assert_consumed`. Useful when loading part of a larger
+ checkpoint into a new Python program, e.g. a training checkpoint with
+ a `tf.train.Optimizer` was saved but only the state required for
+ inference is being loaded. This method returns the status object, and
+ so may be chained with `initialize_or_restore` or `run_restore_ops`.
- `initialize_or_restore(session=None)`:
When graph building, runs variable initializers if `save_path` is
`None`, but otherwise runs restore operations. If no `session` is
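Taken together, `_gather_saveables`, `freeze`, and the module-level `frozen_saver` give a static-saver escape hatch from the dynamic object-graph machinery. A condensed usage sketch mirroring `testFreezing` in util_test.py below (`session` and `prefix` stand in for a real session and checkpoint prefix):

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import resource_variable_ops
    from tensorflow.python.training.checkpointable import util as checkpointable_utils

    v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
    checkpoint = checkpointable_utils.Checkpoint(v=v)
    saver = checkpointable_utils.frozen_saver(checkpoint)  # var_list fixed here
    save_path = saver.save(session, prefix)
    # Restoring requires an object graph identical to the frozen one:
    saver.restore(session, save_path)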
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index a0a87b6b79..f8b5bd8501 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -384,8 +384,8 @@ class CheckpointingTests(test.TestCase):
saver = saver_lib.Saver(var_list=[v])
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
- self.evaluate(v.non_dep_variable.assign(42.))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
+ self.evaluate(v.non_dep_variable.assign(42.))
save_path = saver.save(sess, prefix)
self.evaluate(v.non_dep_variable.assign(43.))
self.evaluate(v.mirrored.assign(44.))
@@ -437,6 +437,9 @@ class CheckpointingTests(test.TestCase):
optimizer=on_create_optimizer, model=on_create_model)
# Deferred restoration
status = on_create_root.restore(save_path=save_path)
+ status.assert_existing_objects_matched()
+ with self.assertRaises(AssertionError):
+ status.assert_consumed()
on_create_model(constant_op.constant([[3.]])) # create variables
self.assertAllEqual(1, self.evaluate(on_create_root.save_counter))
self.assertAllEqual([42.],
@@ -444,6 +447,9 @@ class CheckpointingTests(test.TestCase):
on_create_model._named_dense.variables[1]))
on_create_m_bias_slot = on_create_optimizer.get_slot(
on_create_model._named_dense.variables[1], "m")
+ status.assert_existing_objects_matched()
+ with self.assertRaises(AssertionError):
+ status.assert_consumed()
# Optimizer slot variables are created when the original variable is
# restored.
self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
@@ -451,6 +457,7 @@ class CheckpointingTests(test.TestCase):
self.evaluate(on_create_optimizer.variables()))
dummy_var = resource_variable_ops.ResourceVariable([1.])
on_create_optimizer.minimize(loss=dummy_var.read_value)
+ status.assert_existing_objects_matched()
status.assert_consumed()
beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators()
self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power))
@@ -499,15 +506,18 @@ class CheckpointingTests(test.TestCase):
global_step=root.global_step)
checkpoint_path = checkpoint_management.latest_checkpoint(
checkpoint_directory)
- with self.test_session(graph=ops.get_default_graph()) as session:
+ with self.session(graph=ops.get_default_graph()) as session:
status = root.restore(save_path=checkpoint_path)
status.initialize_or_restore(session=session)
if checkpoint_path is None:
self.assertEqual(0, training_continuation)
with self.assertRaises(AssertionError):
status.assert_consumed()
+ with self.assertRaises(AssertionError):
+ status.assert_existing_objects_matched()
else:
status.assert_consumed()
+ status.assert_existing_objects_matched()
for _ in range(num_training_steps):
session.run(train_op)
root.save(file_prefix=checkpoint_prefix, session=session)
@@ -550,6 +560,46 @@ class CheckpointingTests(test.TestCase):
self.evaluate(root.save_counter))
@test_util.run_in_graph_and_eager_modes
+ def testFreezing(self):
+ with self.cached_session(use_gpu=True) as session:
+ # Save an object-based checkpoint using a frozen saver
+ directory = self.get_temp_dir()
+ prefix = os.path.join(directory, "ckpt")
+ v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
+ checkpoint = checkpointable_utils.Checkpoint(v=v)
+ self.evaluate(v.assign(3))
+ # Create the save counter so assert_consumed doesn't complain about it not
+ # existing in the checkpoint on restore.
+ self.evaluate(checkpoint.save_counter.assign(12))
+ saver = checkpointable_utils.frozen_saver(checkpoint)
+ save_path = saver.save(session, prefix)
+ self.evaluate(v.assign(10))
+ # Use the frozen saver to restore the same object graph
+ saver.restore(session, save_path)
+ self.assertEqual(3, self.evaluate(v))
+
+ # Restore using another frozen saver on an identical object graph
+ del v, checkpoint, saver
+ v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
+ checkpoint = checkpointable_utils.Checkpoint(v=v)
+ saver = checkpointable_utils.frozen_saver(checkpoint)
+ saver.restore(session, save_path)
+ self.assertEqual(3, self.evaluate(v))
+
+ # Restore as an object-based checkpoint
+ del v, checkpoint, saver
+ checkpoint = checkpointable_utils.Checkpoint()
+ status = checkpoint.restore(save_path)
+ v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
+ if context.executing_eagerly():
+ self.assertEqual(12, self.evaluate(checkpoint.save_counter))
+ self.assertEqual(0, self.evaluate(v))
+ checkpoint.v = v
+ status.assert_consumed().run_restore_ops()
+ self.assertEqual(3, self.evaluate(v))
+ self.assertEqual(12, self.evaluate(checkpoint.save_counter))
+
+ @test_util.run_in_graph_and_eager_modes
def testCustomNumbering(self):
directory = self.get_temp_dir()
prefix = os.path.join(directory, "ckpt")
@@ -704,11 +754,12 @@ class CheckpointingTests(test.TestCase):
load_into = LateDependencies()
status = checkpointable_utils.CheckpointableSaver(
load_into).restore(save_path)
+ status.assert_existing_objects_matched()
with self.assertRaises(AssertionError):
status.assert_consumed()
load_into.add_dep()
status.assert_consumed()
- status.run_restore_ops()
+ status.assert_existing_objects_matched().run_restore_ops()
self.assertEqual(123., self.evaluate(load_into.dep.var))
@test_util.run_in_graph_and_eager_modes
@@ -785,6 +836,7 @@ class CheckpointingTests(test.TestCase):
no_slot_status.run_restore_ops()
self.assertEqual(12., self.evaluate(new_root.var))
new_root.optimizer = adam.AdamOptimizer(0.1)
+ slot_status.assert_existing_objects_matched()
with self.assertRaisesRegexp(AssertionError, "beta1_power"):
slot_status.assert_consumed()
self.assertEqual(12., self.evaluate(new_root.var))
@@ -884,6 +936,8 @@ class CheckpointingTests(test.TestCase):
load_root.dep_one.dep_three, name="var", initializer=0.)
with self.assertRaises(AssertionError):
status.assert_consumed()
+ with self.assertRaises(AssertionError):
+ status.assert_existing_objects_matched()
@test_util.run_in_graph_and_eager_modes
def testObjectsCombined(self):
@@ -907,7 +961,7 @@ class CheckpointingTests(test.TestCase):
v2 = checkpointable_utils.add_variable(
load_root.dep_one, name="var2", shape=[], dtype=dtypes.float64)
status = checkpointable_utils.CheckpointableSaver(load_root).restore(
- save_path).assert_consumed()
+ save_path).assert_consumed().assert_existing_objects_matched()
status.run_restore_ops()
self.assertEqual(32., self.evaluate(v1))
self.assertEqual(64., self.evaluate(v2))
@@ -994,7 +1048,7 @@ class CheckpointingTests(test.TestCase):
"""Saves after the first should not modify the graph."""
with context.graph_mode():
graph = ops.Graph()
- with graph.as_default(), self.test_session(graph):
+ with graph.as_default(), self.session(graph):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
obj = tracking.Checkpointable()
@@ -1073,22 +1127,17 @@ class CheckpointingTests(test.TestCase):
self.assertEqual(5, self.evaluate(checkpoint.var_5))
self.assertEqual(1, self.evaluate(checkpoint.var_1))
self.assertEqual(0, self.evaluate(checkpoint.var_0))
- if context.executing_eagerly():
- checkpoint.restore(checkpoint_prefix + "-10").run_restore_ops()
- self.assertEqual(9, self.evaluate(checkpoint.var_9))
- self.assertEqual(8, self.evaluate(checkpoint.var_8))
- self.assertEqual(1, self.evaluate(checkpoint.var_1))
- self.assertEqual(0, self.evaluate(checkpoint.var_0))
- else:
- # Restoring into modified graphs is an error while graph building.
- with self.assertRaises(NotImplementedError):
- checkpoint.restore(checkpoint_prefix + "-10").run_restore_ops()
+ checkpoint.restore(checkpoint_prefix + "-10").run_restore_ops()
+ self.assertEqual(9, self.evaluate(checkpoint.var_9))
+ self.assertEqual(8, self.evaluate(checkpoint.var_8))
+ self.assertEqual(1, self.evaluate(checkpoint.var_1))
+ self.assertEqual(0, self.evaluate(checkpoint.var_0))
def testManyRestoresGraph(self):
"""Restores after the first should not modify the graph."""
with context.graph_mode():
graph = ops.Graph()
- with graph.as_default(), self.test_session(graph):
+ with graph.as_default(), self.session(graph):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
obj = tracking.Checkpointable()
@@ -1244,6 +1293,8 @@ class CheckpointingTests(test.TestCase):
status.initialize_or_restore()
train_fn()
with self.assertRaises(AssertionError):
+ status.assert_existing_objects_matched()
+ with self.assertRaises(AssertionError):
status.assert_consumed()
# Make sure initialization doesn't clobber later restores
@@ -1456,17 +1507,27 @@ class CheckpointCompatibilityTests(test.TestCase):
if context.executing_eagerly():
with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"):
status.assert_consumed()
+ with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"):
+ status.assert_existing_objects_matched()
else:
# When graph building, we haven't read any keys, so we don't know
# whether the restore will be complete.
with self.assertRaisesRegexp(AssertionError, "not restored"):
status.assert_consumed()
+ with self.assertRaisesRegexp(AssertionError, "not restored"):
+ status.assert_existing_objects_matched()
status.run_restore_ops()
self._check_sentinels(root)
self._set_sentinels(root)
status = object_saver.restore(save_path)
status.initialize_or_restore()
self._check_sentinels(root)
+ # Check that there is no error when keys are missing from the name-based
+ # checkpoint.
+ root.not_in_name_checkpoint = resource_variable_ops.ResourceVariable([1.])
+ status = object_saver.restore(save_path)
+ with self.assertRaises(AssertionError):
+ status.assert_existing_objects_matched()
def testSaveGraphLoadEager(self):
checkpoint_directory = self.get_temp_dir()
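The assertions exercised above amount to a partial-load idiom: restore a training checkpoint into a program that rebuilds only part of the original object graph, then use the weaker assertion to validate what does exist. A sketch (the `make_model` helper and `save_path` are placeholders):

    import tensorflow as tf

    inference_root = tf.train.Checkpoint(model=make_model())  # no optimizer
    status = inference_root.restore(save_path)
    status.assert_existing_objects_matched()  # existing objects all matched
    # status.assert_consumed() would raise here: the optimizer slots saved in
    # the checkpoint have no corresponding Python objects in this program.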
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 20e031569b..21ca1735e0 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
@@ -248,6 +249,7 @@ class DistributionStrategy(object):
devices.
We have then a few approaches we want to support:
+
* Code written (as if) with no knowledge of class `DistributionStrategy`.
This code should work as before, even if some of the layers, etc.
used by that code are written to be distribution-aware. This is done
@@ -370,7 +372,7 @@ class DistributionStrategy(object):
use its API, including `merge_call()` to get back to cross-tower
context), once for each tower. May use values with locality T or
M, and any variable.
- * `d.reduce(m, t)`: in cross-tower context, accepts t with locality T
+ * `d.reduce(m, t, t)`: in cross-tower context, accepts t with locality T
and produces a value with locality M.
* `d.reduce(m, t, v)`: in cross-tower context, accepts t with
locality T and produces a value with locality V(`v`).
@@ -403,10 +405,11 @@ class DistributionStrategy(object):
Another thing you might want to do in the middle of your tower function
is an all-reduce of some intermediate value, using `d.reduce()` or
- `d.batch_reduce()` without supplying a variable as the destination.
+ `d.batch_reduce()`. You simply provide the same tensor as the input and
+ destination.
Layers should expect to be called in a tower context, and can use
- the `get_tower_context()` function to get a `TowerContext` object. The
+ the `get_tower_context()` function to get a `TowerContext` object. The
`TowerContext` object has a `merge_call()` method for entering
cross-tower context where you can use `reduce()` (or
`batch_reduce()`) and then optionally `update()` to update state.
@@ -717,18 +720,18 @@ class DistributionStrategy(object):
def _call_for_each_tower(self, fn, *args, **kwargs):
raise NotImplementedError("must be implemented in descendants")
- def reduce(self, aggregation, value, destinations=None):
+ def reduce(self, aggregation, value, destinations):
"""Combine (via e.g. sum or mean) values across towers.
Args:
aggregation: Indicates how a variable will be aggregated. Accepted values
- are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`.
+ are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`,
+ `tf.VariableAggregation.ONLY_FIRST_TOWER`.
value: A per-device value with one value per tower.
- destinations: An optional mirrored variable, a device string,
- list of device strings. The return value will be copied to all
- destination devices (or all the devices where the mirrored
- variable resides). If `None` or unspecified, the destinations
- will match the devices `value` resides on.
+ destinations: A mirrored variable, a per-device tensor, a device string,
+ or list of device strings. The return value will be copied to all
+ destination devices (or all the devices where the `destinations` value
+ resides). To perform an all-reduction, pass `value` to `destinations`.
Returns:
A value mirrored to `destinations`.
@@ -739,7 +742,8 @@ class DistributionStrategy(object):
_require_cross_tower_context(self)
assert aggregation in [
variable_scope.VariableAggregation.SUM,
- variable_scope.VariableAggregation.MEAN
+ variable_scope.VariableAggregation.MEAN,
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER
]
return self._reduce(aggregation, value, destinations)
@@ -751,7 +755,8 @@ class DistributionStrategy(object):
Args:
aggregation: Indicates how a variable will be aggregated. Accepted values
- are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`.
+ are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`,
+ `tf.VariableAggregation.ONLY_FIRST_TOWER`.
value_destination_pairs: A sequence of (value, destinations)
pairs. See `reduce()` for a description.
@@ -762,7 +767,8 @@ class DistributionStrategy(object):
_require_cross_tower_context(self)
assert aggregation in [
variable_scope.VariableAggregation.SUM,
- variable_scope.VariableAggregation.MEAN
+ variable_scope.VariableAggregation.MEAN,
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER
]
return self._batch_reduce(aggregation, value_destination_pairs)
@@ -1071,10 +1077,15 @@ class TowerContext(object):
require_tower_context(self)
return device_util.current()
- # TODO(josh11b): Implement `start_all_reduce(method, t)` that returns
- # a function returning the result of reducing `t` across all
- # towers. Most likely can be implemented in terms of `merge_call()`
- # and `batch_reduce()`.
+ # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient
+ # all-reduce. It would return a function returning the result of reducing `t`
+ # across all towers. The caller would wait to call this function until they
+ # needed the reduce result, allowing an efficient implementation:
+ # * With eager execution, the reduction could be performed asynchronously
+ # in the background, not blocking until the result was needed.
+ # * When constructing a graph, it could batch up all reduction requests up
+ # to the point where the first result is needed. Most likely this can be
+ # implemented in terms of `merge_call()` and `batch_reduce()`.
# ------------------------------------------------------------------------------
@@ -1167,9 +1178,14 @@ class _DefaultDistributionStrategy(DistributionStrategy):
# ------------------------------------------------------------------------------
-# Common operations
+# Deprecated, use v.assign_add(amount) instead. Internal API, so expect
+# it to be deleted soon.
+@deprecation.deprecated(None,
+ "Use v.assign_add(amount) instead. You may need to set "
+ "aggregation=tf.VariableAggregation.ONLY_FIRST_TOWER "
+ "when creating the variable.")
def increment_var(v, amount=1):
"""`v += amount`, distributed-aware version."""
def update(vu):
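Two API points change above: `reduce()` now requires `destinations` (pass the reduced value itself to express an all-reduce), and `increment_var()` is deprecated in favor of `assign_add` on a variable created with the new `ONLY_FIRST_TOWER` aggregation. A sketch under those assumptions (`tower_fn` and its arguments are illustrative):

    from tensorflow.python.ops import variable_scope

    # All-reduce from a tower function: value and destinations coincide.
    def tower_fn(tower_ctx, t):
      def merge_fn(distribution, v):
        return distribution.reduce(
            variable_scope.VariableAggregation.SUM, v, destinations=v)
      return tower_ctx.merge_call(merge_fn, t)

    # Replacement for increment_var(counter), per the deprecation message:
    counter = variable_scope.variable(
        0, aggregation=variable_scope.VariableAggregation.ONLY_FIRST_TOWER)
    update_op = counter.assign_add(1)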
diff --git a/tensorflow/python/training/ftrl_test.py b/tensorflow/python/training/ftrl_test.py
index 775bdb3f60..09d6fe36d3 100644
--- a/tensorflow/python/training/ftrl_test.py
+++ b/tensorflow/python/training/ftrl_test.py
@@ -37,7 +37,7 @@ class FtrlOptimizerTest(test.TestCase):
def doTestFtrlwithoutRegularization(self, use_resource=False):
for dtype in [dtypes.half, dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if use_resource:
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
@@ -76,7 +76,7 @@ class FtrlOptimizerTest(test.TestCase):
def testFtrlwithoutRegularization2(self):
for dtype in [dtypes.half, dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -105,7 +105,7 @@ class FtrlOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
@@ -117,12 +117,11 @@ class FtrlOptimizerTest(test.TestCase):
# Run 1 step of sgd
sgd_op.run()
# Validate updated params
- self.assertAllCloseAccordingToType(
- [[0, 1]], var0.eval(), atol=0.01)
+ self.assertAllCloseAccordingToType([[0, 1]], var0.eval(), atol=0.01)
def testFtrlWithL1(self):
for dtype in [dtypes.half, dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -151,7 +150,7 @@ class FtrlOptimizerTest(test.TestCase):
def testFtrlWithL1_L2(self):
for dtype in [dtypes.half, dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -187,7 +186,7 @@ class FtrlOptimizerTest(test.TestCase):
weights will tend to have smaller magnitudes with this parameter set.
"""
for dtype in [dtypes.half, dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -212,24 +211,96 @@ class FtrlOptimizerTest(test.TestCase):
v0_val, v1_val = sess.run([var0, var1])
self.assertAllCloseAccordingToType(
- np.array([-0.22078767, -0.41378114]), v0_val)
+ np.array([-0.22578995, -0.44345796]), v0_val)
self.assertAllCloseAccordingToType(
- np.array([-0.02919818, -0.07343706]), v1_val)
+ np.array([-0.14378493, -0.13229476]), v1_val)
+
+ def testFtrlWithL1_L2_L2ShrinkageSparse(self):
+ """Tests the new FTRL op with support for l2 shrinkage on sparse grads."""
+ for dtype in [dtypes.half, dtypes.float32]:
+ with self.test_session() as sess:
+ var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
+ var1 = variables.Variable([[4.0], [3.0]], dtype=dtype)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
+ constant_op.constant([0]), constant_op.constant([2, 1]))
+ grads1 = ops.IndexedSlices(
+ constant_op.constant([0.02], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]), constant_op.constant([2, 1]))
+
+ opt = ftrl.FtrlOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0,
+ l2_shrinkage_regularization_strength=0.1)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ v0_val, v1_val = sess.run([var0, var1])
+ self.assertAllCloseAccordingToType([[1.0], [2.0]], v0_val)
+ self.assertAllCloseAccordingToType([[4.0], [3.0]], v1_val)
+
+ # Run 10 steps FTRL
+ for _ in range(10):
+ update.run()
+
+ v0_val, v1_val = sess.run([var0, var1])
+ self.assertAllCloseAccordingToType([[-0.22578995], [2.]], v0_val)
+ self.assertAllCloseAccordingToType([[4.], [-0.13229476]], v1_val)
+
+ def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self):
+ """Verifies that l2 shrinkage in FTRL does not change lr schedule."""
+ for dtype in [dtypes.half, dtypes.float32]:
+ with self.test_session() as sess:
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([1.0, 2.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.1, 0.2], dtype=dtype)
+
+ opt0 = ftrl.FtrlOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0,
+ l2_shrinkage_regularization_strength=0.1)
+ opt1 = ftrl.FtrlOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0)
+ update0 = opt0.apply_gradients([(grads0, var0)])
+ update1 = opt1.apply_gradients([(grads1, var1)])
+ variables.global_variables_initializer().run()
+
+ v0_val, v1_val = sess.run([var0, var1])
+ self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
+ self.assertAllCloseAccordingToType([1.0, 2.0], v1_val)
+
+ # Run 10 steps FTRL
+ for _ in range(10):
+ update0.run()
+ update1.run()
+
+ v0_val, v1_val = sess.run([var0, var1])
+ # var0 is experiencing L2 shrinkage so it should be smaller than var1
+ # in magnitude.
+ self.assertTrue((v0_val**2 < v1_val**2).all())
+ accum0 = list(sess.run(opt0._slots)["accum"].values())[0]
+ accum1 = list(sess.run(opt1._slots)["accum"].values())[0]
+ # L2 shrinkage should not change how we update grad accumulator.
+ self.assertAllCloseAccordingToType(accum0, accum1)
def applyOptimizer(self, opt, dtype, steps=5, is_sparse=False):
if is_sparse:
var0 = variables.Variable([[0.0], [0.0]], dtype=dtype)
var1 = variables.Variable([[0.0], [0.0]], dtype=dtype)
grads0 = ops.IndexedSlices(
- constant_op.constant(
- [0.1], shape=[1, 1], dtype=dtype),
- constant_op.constant([0]),
- constant_op.constant([2, 1]))
+ constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
+ constant_op.constant([0]), constant_op.constant([2, 1]))
grads1 = ops.IndexedSlices(
- constant_op.constant(
- [0.02], shape=[1, 1], dtype=dtype),
- constant_op.constant([1]),
- constant_op.constant([2, 1]))
+ constant_op.constant([0.02], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]), constant_op.constant([2, 1]))
else:
var0 = variables.Variable([0.0, 0.0], dtype=dtype)
var1 = variables.Variable([0.0, 0.0], dtype=dtype)
@@ -264,7 +335,7 @@ class FtrlOptimizerTest(test.TestCase):
# FTRL-Proximal performs same updates as Adagrad or GradientDescent.
def testEquivAdagradwithoutRegularization(self):
for dtype in [dtypes.half, dtypes.float32]:
- with self.test_session():
+ with self.cached_session():
val0, val1 = self.applyOptimizer(
ftrl.FtrlOptimizer(
3.0,
@@ -275,17 +346,16 @@ class FtrlOptimizerTest(test.TestCase):
l2_regularization_strength=0.0),
dtype)
- with self.test_session():
+ with self.cached_session():
val2, val3 = self.applyOptimizer(
- adagrad.AdagradOptimizer(
- 3.0, initial_accumulator_value=0.1), dtype)
+ adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1), dtype)
self.assertAllCloseAccordingToType(val0, val2)
self.assertAllCloseAccordingToType(val1, val3)
def testEquivSparseAdagradwithoutRegularization(self):
for dtype in [dtypes.half, dtypes.float32]:
- with self.test_session():
+ with self.cached_session():
val0, val1 = self.applyOptimizer(
ftrl.FtrlOptimizer(
3.0,
@@ -297,10 +367,9 @@ class FtrlOptimizerTest(test.TestCase):
dtype,
is_sparse=True)
- with self.test_session():
+ with self.cached_session():
val2, val3 = self.applyOptimizer(
- adagrad.AdagradOptimizer(
- 3.0, initial_accumulator_value=0.1),
+ adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1),
dtype,
is_sparse=True)
@@ -309,7 +378,7 @@ class FtrlOptimizerTest(test.TestCase):
def testEquivSparseGradientDescentwithoutRegularization(self):
for dtype in [dtypes.half, dtypes.float32]:
- with self.test_session():
+ with self.cached_session():
val0, val1 = self.applyOptimizer(
ftrl.FtrlOptimizer(
3.0,
@@ -321,7 +390,7 @@ class FtrlOptimizerTest(test.TestCase):
dtype,
is_sparse=True)
- with self.test_session():
+ with self.cached_session():
val2, val3 = self.applyOptimizer(
gradient_descent.GradientDescentOptimizer(3.0),
dtype,
@@ -332,7 +401,7 @@ class FtrlOptimizerTest(test.TestCase):
def testEquivGradientDescentwithoutRegularization(self):
for dtype in [dtypes.half, dtypes.float32]:
- with self.test_session():
+ with self.cached_session():
val0, val1 = self.applyOptimizer(
ftrl.FtrlOptimizer(
3.0,
@@ -343,7 +412,7 @@ class FtrlOptimizerTest(test.TestCase):
l2_regularization_strength=0.0),
dtype)
- with self.test_session():
+ with self.cached_session():
val2, val3 = self.applyOptimizer(
gradient_descent.GradientDescentOptimizer(3.0), dtype)
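The two new shrinkage tests above encode the intended semantics: `l2_shrinkage_regularization_strength` is a loss-level magnitude penalty that modifies the gradient applied to the weights but, as `testFtrlWithL2ShrinkageDoesNotChangeLrSchedule` verifies, never enters the squared-gradient accumulator that drives the learning-rate schedule. Schematically, per weight `w` with gradient `g` (a paraphrase of the documented behavior, not the kernel code):

    def shrinkage_grad(g, w, lam):
      """Gradient used for the FTRL weight update under L2 shrinkage."""
      return g + 2.0 * lam * w

    def accum_update(accum, g):
      """Accumulator update; uses g, not shrinkage_grad(g, w, lam), so the
      learning-rate schedule is unchanged by shrinkage."""
      return accum + g * g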
diff --git a/tensorflow/python/training/gradient_descent_test.py b/tensorflow/python/training/gradient_descent_test.py
index b304e92421..56d82a5b88 100644
--- a/tensorflow/python/training/gradient_descent_test.py
+++ b/tensorflow/python/training/gradient_descent_test.py
@@ -37,7 +37,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testBasic(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -60,7 +60,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testBasicResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -85,7 +85,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testBasicCallableParams(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -111,7 +111,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testMinimizeResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
@@ -137,7 +137,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
@@ -164,7 +164,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testTensorLearningRate(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -186,7 +186,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testGradWrtRef(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
opt = gradient_descent.GradientDescentOptimizer(3.0)
values = [1.0, 3.0]
vars_ = [variables.Variable([v], dtype=dtype) for v in values]
@@ -197,7 +197,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testWithGlobalStep(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
global_step = variables.Variable(0, trainable=False)
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
@@ -220,7 +220,7 @@ class GradientDescentOptimizerTest(test.TestCase):
def testSparseBasic(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
grads0 = ops.IndexedSlices(
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index caa26581e8..9d9db70890 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -15,7 +15,8 @@
"""Input pipeline.
-Please see the @{$reading_data$reading data how-to}
+Please see the [reading data
+how-to](https://tensorflow.org/api_guides/python/reading_data)
for context.
"""
@@ -44,6 +45,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.summary import summary
from tensorflow.python.training import queue_runner
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -74,7 +76,10 @@ def match_filenames_once(pattern, name=None):
collections=[ops.GraphKeys.LOCAL_VARIABLES])
-@tf_export("train.limit_epochs")
+@tf_export(v1=["train.limit_epochs"])
+@deprecation.deprecated(
+ None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+ "`tf.data.Dataset.from_tensors(tensor).repeat(num_epochs)`.")
def limit_epochs(tensor, num_epochs=None, name=None):
"""Returns tensor `num_epochs` times and then raises an `OutOfRange` error.
@@ -107,7 +112,12 @@ def limit_epochs(tensor, num_epochs=None, name=None):
return array_ops.identity(tensor, name=name)
-@tf_export("train.input_producer")
+@tf_export(v1=["train.input_producer"])
+@deprecation.deprecated(
+ None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+ "`tf.data.Dataset.from_tensor_slices(input_tensor).shuffle"
+ "(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)`. If "
+ "`shuffle=False`, omit the `.shuffle(...)`.")
def input_producer(input_tensor,
element_shape=None,
num_epochs=None,
@@ -190,7 +200,12 @@ def input_producer(input_tensor,
return q
-@tf_export("train.string_input_producer")
+@tf_export(v1=["train.string_input_producer"])
+@deprecation.deprecated(
+ None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+ "`tf.data.Dataset.from_tensor_slices(string_tensor).shuffle"
+ "(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)`. If "
+ "`shuffle=False`, omit the `.shuffle(...)`.")
def string_input_producer(string_tensor,
num_epochs=None,
shuffle=True,
@@ -260,7 +275,11 @@ def string_input_producer(string_tensor,
cancel_op=cancel_op)
-@tf_export("train.range_input_producer")
+@tf_export(v1=["train.range_input_producer"])
+@deprecation.deprecated(
+ None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+ "`tf.data.Dataset.range(limit).shuffle(limit).repeat(num_epochs)`. If "
+ "`shuffle=False`, omit the `.shuffle(...)`.")
def range_input_producer(limit, num_epochs=None, shuffle=True, seed=None,
capacity=32, shared_name=None, name=None):
"""Produces the integers from 0 to limit-1 in a queue.
@@ -298,7 +317,12 @@ def range_input_producer(limit, num_epochs=None, shuffle=True, seed=None,
shared_name, "fraction_of_%d_full" % capacity, name)
-@tf_export("train.slice_input_producer")
+@tf_export(v1=["train.slice_input_producer"])
+@deprecation.deprecated(
+ None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+ "`tf.data.Dataset.from_tensor_slices(tuple(tensor_list)).shuffle"
+ "(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)`. If "
+ "`shuffle=False`, omit the `.shuffle(...)`.")
def slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None,
capacity=32, shared_name=None, name=None):
"""Produces a slice of each `Tensor` in `tensor_list`.
@@ -893,7 +917,11 @@ def _shuffle_batch_join(tensors_list, batch_size, capacity,
# Batching functions ----------------------------------------------------------
-@tf_export("train.batch")
+@tf_export(v1=["train.batch"])
+@deprecation.deprecated(
+ None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+ "`tf.data.Dataset.batch(batch_size)` (or `padded_batch(...)` if "
+ "`dynamic_pad=True`).")
def batch(tensors, batch_size, num_threads=1, capacity=32,
enqueue_many=False, shapes=None, dynamic_pad=False,
allow_smaller_final_batch=False, shared_name=None, name=None):
@@ -988,7 +1016,11 @@ def batch(tensors, batch_size, num_threads=1, capacity=32,
name=name)
-@tf_export("train.maybe_batch")
+@tf_export(v1=["train.maybe_batch"])
+@deprecation.deprecated(
+ None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+ "`tf.data.Dataset.filter(...).batch(batch_size)` (or `padded_batch(...)`"
+ " if `dynamic_pad=True`).")
def maybe_batch(tensors, keep_input, batch_size, num_threads=1, capacity=32,
enqueue_many=False, shapes=None, dynamic_pad=False,
allow_smaller_final_batch=False, shared_name=None, name=None):
@@ -1041,7 +1073,11 @@ def maybe_batch(tensors, keep_input, batch_size, num_threads=1, capacity=32,
name=name)
-@tf_export("train.batch_join")
+@tf_export(v1=["train.batch_join"])
+@deprecation.deprecated(
+ None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+ "`tf.data.Dataset.interleave(...).batch(batch_size)` (or "
+ "`padded_batch(...)` if `dynamic_pad=True`).")
def batch_join(tensors_list, batch_size, capacity=32, enqueue_many=False,
shapes=None, dynamic_pad=False, allow_smaller_final_batch=False,
shared_name=None, name=None):
@@ -1147,7 +1183,11 @@ def batch_join(tensors_list, batch_size, capacity=32, enqueue_many=False,
name=name)
-@tf_export("train.maybe_batch_join")
+@tf_export(v1=["train.maybe_batch_join"])
+@deprecation.deprecated(
+ None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+ "`tf.data.Dataset.interleave(...).filter(...).batch(batch_size)` (or "
+ "`padded_batch(...)` if `dynamic_pad=True`).")
def maybe_batch_join(tensors_list, keep_input, batch_size, capacity=32,
enqueue_many=False, shapes=None, dynamic_pad=False,
allow_smaller_final_batch=False, shared_name=None,
@@ -1200,7 +1240,10 @@ def maybe_batch_join(tensors_list, keep_input, batch_size, capacity=32,
name=name)
-@tf_export("train.shuffle_batch")
+@tf_export(v1=["train.shuffle_batch"])
+@deprecation.deprecated(
+ None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+ "`tf.data.Dataset.shuffle(min_after_dequeue).batch(batch_size)`.")
def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
num_threads=1, seed=None, enqueue_many=False, shapes=None,
allow_smaller_final_batch=False, shared_name=None, name=None):
@@ -1300,7 +1343,11 @@ def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
name=name)
-@tf_export("train.maybe_shuffle_batch")
+@tf_export(v1=["train.maybe_shuffle_batch"])
+@deprecation.deprecated(
+ None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+ "`tf.data.Dataset.filter(...).shuffle(min_after_dequeue).batch(batch_size)`"
+ ".")
def maybe_shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
keep_input, num_threads=1, seed=None,
enqueue_many=False, shapes=None,
@@ -1360,7 +1407,11 @@ def maybe_shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
name=name)
-@tf_export("train.shuffle_batch_join")
+@tf_export(v1=["train.shuffle_batch_join"])
+@deprecation.deprecated(
+ None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+ "`tf.data.Dataset.interleave(...).shuffle(min_after_dequeue).batch"
+ "(batch_size)`.")
def shuffle_batch_join(tensors_list, batch_size, capacity,
min_after_dequeue, seed=None, enqueue_many=False,
shapes=None, allow_smaller_final_batch=False,
@@ -1454,7 +1505,11 @@ def shuffle_batch_join(tensors_list, batch_size, capacity,
name=name)
-@tf_export("train.maybe_shuffle_batch_join")
+@tf_export(v1=["train.maybe_shuffle_batch_join"])
+@deprecation.deprecated(
+ None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+ "`tf.data.Dataset.interleave(...).filter(...).shuffle(min_after_dequeue)"
+ ".batch(batch_size)`.")
def maybe_shuffle_batch_join(tensors_list, batch_size, capacity,
min_after_dequeue, keep_input, seed=None,
enqueue_many=False, shapes=None,
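The deprecation messages added above all point to the same migration; as a rough sketch of what they suggest for the `tf.train.shuffle_batch` case (assuming a TF 1.x runtime, with the input file path and `parse_fn` purely hypothetical):

```python
import tensorflow as tf

batch_size = 32
min_after_dequeue = 1000  # plays the role of the old shuffle-queue knob

def parse_fn(record):
  # Hypothetical parser; adapt to the actual serialized format.
  parsed = tf.parse_single_example(
      record, {"x": tf.FixedLenFeature([4], tf.float32)})
  return parsed["x"]

dataset = (tf.data.TFRecordDataset(["train.tfrecord"])  # placeholder path
           .map(parse_fn)
           .shuffle(min_after_dequeue)  # replaces the shuffling queue
           .batch(batch_size))          # replaces tf.train.shuffle_batch
features = dataset.make_one_shot_iterator().get_next()
```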
diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py
index 1b1e89cb26..a9b05dcc73 100644
--- a/tensorflow/python/training/input_test.py
+++ b/tensorflow/python/training/input_test.py
@@ -51,7 +51,7 @@ class MatchFilenamesOnceTest(test_lib.TestCase):
for name in additional:
open(name, "w").write("Some contents")
filenames = list(set(filenames + additional))
- with self.test_session():
+ with self.cached_session():
star = inp.match_filenames_once(os.path.join(self.get_temp_dir(), "*"))
question = inp.match_filenames_once(
os.path.join(self.get_temp_dir(), "match_filenames.?"))
@@ -66,7 +66,7 @@ class MatchFilenamesOnceTest(test_lib.TestCase):
class LimitEpochsTest(test_lib.TestCase):
def testNoLimit(self):
- with self.test_session():
+ with self.cached_session():
seven = constant_op.constant(7)
seven_forever = inp.limit_epochs(seven)
variables.local_variables_initializer().run()
@@ -74,7 +74,7 @@ class LimitEpochsTest(test_lib.TestCase):
self.assertEqual(7, seven_forever.eval())
def testLimit(self):
- with self.test_session():
+ with self.cached_session():
love_me = constant_op.constant("Love Me")
love_me_two_times = inp.limit_epochs(love_me, num_epochs=2)
variables.global_variables_initializer().run()
@@ -88,7 +88,7 @@ class LimitEpochsTest(test_lib.TestCase):
class InputProducerTest(test_lib.TestCase):
def testNoShuffle(self):
- with self.test_session():
+ with self.cached_session():
input_tensor = [[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]]
@@ -111,7 +111,7 @@ class InputProducerTest(test_lib.TestCase):
thread.join()
def testNoShapeInference(self):
- with self.test_session():
+ with self.cached_session():
# Disable shape inference for the input.
input_value = [[1, 2, 3, 4],
[5, 6, 7, 8],
@@ -144,7 +144,7 @@ class InputProducerTest(test_lib.TestCase):
class StringInputProducerTest(test_lib.TestCase):
def testNoShuffle(self):
- with self.test_session():
+ with self.cached_session():
strings = [b"to", b"be", b"or", b"not", b"to", b"be"]
num_epochs = 3
queue = inp.string_input_producer(
@@ -166,7 +166,7 @@ class StringInputProducerTest(test_lib.TestCase):
thread.join()
def testShuffle(self):
- with self.test_session():
+ with self.cached_session():
strings = [b"a", b"b", b"c"]
num_epochs = 600
queue = inp.string_input_producer(
@@ -206,7 +206,7 @@ class StringInputProducerTest(test_lib.TestCase):
def testNullStringPython(self):
# Graph-construction time check for empty string list:
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
_ = inp.string_input_producer([])
@@ -214,7 +214,7 @@ class StringInputProducerTest(test_lib.TestCase):
# Runtime check for empty string list. This is slightly oblique:
# The queue runner should die with an assertion error on the null
# input tensor, causing the dequeue to fail with an OutOfRangeError.
- with self.test_session():
+ with self.cached_session():
coord = coordinator.Coordinator()
queue = inp.string_input_producer(
constant_op.constant(
@@ -230,7 +230,7 @@ class StringInputProducerTest(test_lib.TestCase):
thread.join()
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
strings = [b"to", b"be", b"or", b"not", b"to", b"be"]
queue = inp.string_input_producer(
strings, shared_name="SHARED_NAME_XYZ", name="Q")
@@ -238,7 +238,7 @@ class StringInputProducerTest(test_lib.TestCase):
queue.queue_ref.op.node_def.attr["shared_name"])
def testConstructionRace(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
strings = [b"to", b"be", b"or", b"not", b"to", b"be"]
queue = inp.string_input_producer(strings, shuffle=False)
coord = coordinator.Coordinator()
@@ -260,7 +260,7 @@ class StringInputProducerTest(test_lib.TestCase):
class RangeInputProducerTest(test_lib.TestCase):
def testNoShuffle(self):
- with self.test_session():
+ with self.cached_session():
num_epochs = 3
range_size = 5
queue = inp.range_input_producer(
@@ -282,7 +282,7 @@ class RangeInputProducerTest(test_lib.TestCase):
thread.join()
def testShuffle(self):
- with self.test_session():
+ with self.cached_session():
num_epochs = 200
range_size = 2
queue = inp.range_input_producer(
@@ -321,7 +321,7 @@ class RangeInputProducerTest(test_lib.TestCase):
thread.join()
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
range_size = 5
queue = inp.range_input_producer(
range_size, shared_name="SHARED_NAME_XYZ", name="Q")
@@ -332,7 +332,7 @@ class RangeInputProducerTest(test_lib.TestCase):
class SliceInputProducerTest(test_lib.TestCase):
def testNoShuffle(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_epochs = 3
source_strings = [b"Alpha", b"Beta", b"Delta", b"Gamma"]
source_ints = [2, 3, 5, 7]
@@ -356,7 +356,7 @@ class SliceInputProducerTest(test_lib.TestCase):
thread.join()
def testShuffle(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_epochs = 1200
source_strings = ["A", "B", "D", "G"]
source_ints = [7, 3, 5, 2]
@@ -400,7 +400,7 @@ class SliceInputProducerTest(test_lib.TestCase):
thread.join()
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
source_strings = ["A", "B", "D", "G"]
source_ints = [7, 3, 5, 2]
slices = inp.slice_input_producer(
@@ -440,7 +440,7 @@ class DictHelperTest(test_lib.TestCase):
class BatchTest(test_lib.TestCase):
def _testOneThreadHelper(self, use_dict):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -500,7 +500,7 @@ class BatchTest(test_lib.TestCase):
def testUint32DataTypes(self):
values = constant_op.constant([0, 1, 2, 3, 4, 5], dtype=dtypes.uint32)
batched = inp.batch([values], batch_size=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
sess.run(batched)
@@ -511,7 +511,7 @@ class BatchTest(test_lib.TestCase):
def testUint64DataTypes(self):
values = constant_op.constant([0, 1, 2, 3, 4, 5], dtype=dtypes.uint64)
batched = inp.batch([values], batch_size=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
sess.run(batched)
@@ -520,7 +520,7 @@ class BatchTest(test_lib.TestCase):
thread.join()
def testOneThreadDynamicPad(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -550,7 +550,7 @@ class BatchTest(test_lib.TestCase):
thread.join()
def testOneThreadEnqueueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -585,7 +585,7 @@ class BatchTest(test_lib.TestCase):
thread.join()
def testManyThreads(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -625,7 +625,7 @@ class BatchTest(test_lib.TestCase):
thread.join()
def testOneThreadSmallerBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
extra_elements = 5
@@ -682,7 +682,7 @@ class BatchTest(test_lib.TestCase):
thread.join()
def testManyThreadsSmallerBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
extra_elements = 5
@@ -737,7 +737,7 @@ class BatchTest(test_lib.TestCase):
thread.join()
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -754,7 +754,7 @@ class BatchTest(test_lib.TestCase):
batched[0].op.inputs[0].op.node_def.attr["shared_name"])
def testCannotInferRankError(self):
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(dtype=dtypes.int64)
with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"):
inp.batch([x], batch_size=2)
@@ -797,7 +797,7 @@ class BatchTest(test_lib.TestCase):
def _testKeepInputHelper(self, num_threads, enqueue_many,
keep_input_vector=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 5
num_batches = 4
examples = variables.Variable(0)
@@ -934,7 +934,7 @@ class BatchTest(test_lib.TestCase):
batched = inp.maybe_batch(
[sparse_t], keep_input=keep, batch_size=1, enqueue_many=True)
- with self.test_session():
+ with self.cached_session():
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
@@ -952,7 +952,7 @@ class BatchTest(test_lib.TestCase):
class BatchJoinTest(test_lib.TestCase):
def _testTwoThreadsHelper(self, use_dict):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Two threads, the first generates (0..69, "a").
num_a = 70
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1069,7 +1069,7 @@ class BatchJoinTest(test_lib.TestCase):
batch_size=8)
def DISABLED_testTwoThreadsDynamicPad(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Two threads, the first generates (0..69, ["a"] * 1..70).
num_a = 70
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1144,7 +1144,7 @@ class BatchJoinTest(test_lib.TestCase):
thread.join()
def DISABLED_testTwoThreadsSmallerBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
extra_elements = 2
# Two threads, the first generates (0..69, "a").
num_a = 70 + extra_elements
@@ -1243,7 +1243,7 @@ class BatchJoinTest(test_lib.TestCase):
thread.join()
def DISABLED_testTwoThreadsDynamicPadSmallerBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
extra_elements = 2
# Two threads, the first generates (0..69, ["a"] * 1..70).
num_a = 70 + extra_elements
@@ -1338,7 +1338,7 @@ class BatchJoinTest(test_lib.TestCase):
thread.join()
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1360,7 +1360,7 @@ class BatchJoinTest(test_lib.TestCase):
batched[0].op.inputs[0].op.node_def.attr["shared_name"])
def testCannotInferRankError(self):
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(dtype=dtypes.int64)
with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"):
inp.batch_join([[x]], batch_size=2)
@@ -1371,7 +1371,7 @@ class BatchJoinTest(test_lib.TestCase):
def _testKeepInputHelper(self, num_threads, enqueue_many,
keep_input_vector=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 5
num_batches = 4
examples = variables.Variable(0)
@@ -1511,7 +1511,7 @@ class BatchJoinTest(test_lib.TestCase):
batched = inp.maybe_batch_join(
[[sparse]], keep_input=keep, batch_size=1, enqueue_many=True)
- with self.test_session():
+ with self.cached_session():
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
@@ -1529,7 +1529,7 @@ class BatchJoinTest(test_lib.TestCase):
class ShuffleBatchTest(test_lib.TestCase):
def _testOneThreadHelper(self, use_dict):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1594,7 +1594,7 @@ class ShuffleBatchTest(test_lib.TestCase):
self._testOneThreadHelper(use_dict=True)
def testOneThreadSmallerBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
extra_elements = 5
@@ -1650,7 +1650,7 @@ class ShuffleBatchTest(test_lib.TestCase):
thread.join()
def testManyThreads(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1697,7 +1697,7 @@ class ShuffleBatchTest(test_lib.TestCase):
thread.join()
def testManyThreadsSmallerBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
num_batches = 3
extra_elements = 5
@@ -1755,7 +1755,7 @@ class ShuffleBatchTest(test_lib.TestCase):
thread.join()
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1775,7 +1775,7 @@ class ShuffleBatchTest(test_lib.TestCase):
def _testKeepInputHelper(self, num_threads, enqueue_many,
keep_input_vector=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 5
num_batches = 4
examples = variables.Variable(0)
@@ -1906,7 +1906,7 @@ class ShuffleBatchTest(test_lib.TestCase):
class ShuffleBatchJoinTest(test_lib.TestCase):
def _testTwoThreadsHelper(self, use_dict):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Two threads, the first generates (0..24, "a").
num_a = 25
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -2017,7 +2017,7 @@ class ShuffleBatchJoinTest(test_lib.TestCase):
self._testTwoThreadsHelper(use_dict=True)
def testTwoThreadsSmallerBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Two threads, the first generates (0..26, "a").
extra_elements = 2
num_a = 25 + extra_elements
@@ -2137,7 +2137,7 @@ class ShuffleBatchJoinTest(test_lib.TestCase):
seed=223607)
def testSharedName(self):
- with self.test_session():
+ with self.cached_session():
batch_size = 10
num_batches = 3
zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -2162,7 +2162,7 @@ class ShuffleBatchJoinTest(test_lib.TestCase):
def _testKeepInputHelper(self, num_threads, enqueue_many,
keep_input_vector=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 5
num_batches = 4
examples = variables.Variable(0)
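Every `self.test_session()` in the hunks above becomes `self.cached_session()`, which hands back a session that is cached and reused for the rest of the test instead of being rebuilt per `with` block. A minimal standalone sketch of the resulting pattern (not taken from this patch):

```python
import tensorflow as tf

class AddTest(tf.test.TestCase):

  def testAdd(self):
    # cached_session() reuses one session for the lifetime of the test,
    # which is the convention these edits standardize on.
    with self.cached_session():
      self.assertAllEqual([3], tf.add([1], [2]).eval())

if __name__ == "__main__":
  tf.test.main()
```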
diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py
index fd195a7965..29b5465321 100644
--- a/tensorflow/python/training/learning_rate_decay.py
+++ b/tensorflow/python/training/learning_rate_decay.py
@@ -17,19 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import math
-
from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
+from tensorflow.python.training import learning_rate_decay_v2
from tensorflow.python.util.tf_export import tf_export
-@tf_export("train.exponential_decay")
+@tf_export(v1=["train.exponential_decay"])
def exponential_decay(learning_rate,
global_step,
decay_steps,
@@ -95,32 +88,19 @@ def exponential_decay(learning_rate,
the learning rate value across different invocations of optimizer functions.
@end_compatibility
"""
- if global_step is None:
- raise ValueError("global_step is required for exponential_decay.")
- with ops.name_scope(
- name, "ExponentialDecay",
- [learning_rate, global_step, decay_steps, decay_rate]) as name:
- learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
- dtype = learning_rate.dtype
- decay_steps = math_ops.cast(decay_steps, dtype)
- decay_rate = math_ops.cast(decay_rate, dtype)
-
- def decayed_lr():
- """Helper to recompute learning rate; most helpful in eager-mode."""
- global_step_recomp = math_ops.cast(global_step, dtype)
- p = global_step_recomp / decay_steps
- if staircase:
- p = math_ops.floor(p)
- return math_ops.multiply(
- learning_rate, math_ops.pow(decay_rate, p), name=name)
-
- if not context.executing_eagerly():
- decayed_lr = decayed_lr()
-
- return decayed_lr
-
-
-@tf_export("train.piecewise_constant")
+ decayed_lr = learning_rate_decay_v2.exponential_decay(learning_rate,
+ global_step,
+ decay_steps,
+ decay_rate,
+ staircase=staircase,
+ name=name)
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
+
+
+@tf_export(v1=["train.piecewise_constant"])
def piecewise_constant(x, boundaries, values, name=None):
"""Piecewise constant from boundaries and interval values.
@@ -163,58 +143,15 @@ def piecewise_constant(x, boundaries, values, name=None):
the learning rate value across different invocations of optimizer functions.
@end_compatibility
"""
- if len(boundaries) != len(values) - 1:
- raise ValueError(
- "The length of boundaries should be 1 less than the length of values")
- with ops.name_scope(name, "PiecewiseConstant",
- [x, boundaries, values, name]) as name:
- boundaries = ops.convert_n_to_tensor(boundaries)
- values = ops.convert_n_to_tensor(values)
-
- def decayed_lr():
- """Helper to recompute learning rate; most helpful in eager-mode."""
- x_recomp = ops.convert_to_tensor(x)
- # Avoid explicit conversion to x's dtype. This could result in faulty
- # comparisons, for example if floats are converted to integers.
- for i, b in enumerate(boundaries):
- if b.dtype.base_dtype != x_recomp.dtype.base_dtype:
- # We can promote int32 boundaries to int64 without loss of precision.
- # This covers the most common case where the user passes in boundaries
- # as an array of Python integers.
- if (b.dtype.base_dtype == dtypes.int32 and
- x_recomp.dtype.base_dtype == dtypes.int64):
- b = math_ops.cast(b, x_recomp.dtype.base_dtype)
- boundaries[i] = b
- else:
- raise ValueError(
- "Boundaries (%s) must have the same dtype as x (%s)." %
- (b.dtype.base_dtype, x_recomp.dtype.base_dtype))
- # TODO(rdipietro): Ensure that boundaries' elements strictly increases.
- for v in values[1:]:
- if v.dtype.base_dtype != values[0].dtype.base_dtype:
- raise ValueError(
- "Values must have elements all with the same dtype (%s vs %s)." %
- (values[0].dtype.base_dtype, v.dtype.base_dtype))
- pred_fn_pairs = []
- pred_fn_pairs.append((x_recomp <= boundaries[0], lambda: values[0]))
- pred_fn_pairs.append((x_recomp > boundaries[-1], lambda: values[-1]))
- for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
- # Need to bind v here; can do this with lambda v=v: ...
- pred = (x_recomp > low) & (x_recomp <= high)
- pred_fn_pairs.append((pred, lambda v=v: v))
-
- # The default isn't needed here because our conditions are mutually
- # exclusive and exhaustive, but tf.case requires it.
- default = lambda: values[0]
- return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
-
- if not context.executing_eagerly():
- decayed_lr = decayed_lr()
-
- return decayed_lr
-
-
-@tf_export("train.polynomial_decay")
+ decayed_lr = learning_rate_decay_v2.piecewise_constant(x, boundaries, values,
+ name=name)
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
+
+
+@tf_export(v1=["train.polynomial_decay"])
def polynomial_decay(learning_rate,
global_step,
decay_steps,
@@ -299,46 +236,22 @@ def polynomial_decay(learning_rate,
the learning rate value across different invocations of optimizer functions.
@end_compatibility
"""
- if global_step is None:
- raise ValueError("global_step is required for polynomial_decay.")
- with ops.name_scope(
- name, "PolynomialDecay",
- [learning_rate, global_step, decay_steps, end_learning_rate, power
- ]) as name:
- learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
- dtype = learning_rate.dtype
- end_learning_rate = math_ops.cast(end_learning_rate, dtype)
- power = math_ops.cast(power, dtype)
-
- def decayed_lr():
- """Helper to recompute learning rate; most helpful in eager-mode."""
- global_step_recomp = math_ops.cast(global_step, dtype)
- decay_steps_recomp = math_ops.cast(decay_steps, dtype)
- if cycle:
- # Find the first multiple of decay_steps that is bigger than
- # global_step. If global_step is zero set the multiplier to 1
- multiplier = control_flow_ops.cond(
- math_ops.equal(global_step_recomp, 0), lambda: 1.0,
- lambda: math_ops.ceil(global_step_recomp / decay_steps))
- decay_steps_recomp = math_ops.multiply(decay_steps_recomp, multiplier)
- else:
- # Make sure that the global_step used is not bigger than decay_steps.
- global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
-
- p = math_ops.div(global_step_recomp, decay_steps_recomp)
- return math_ops.add(
- math_ops.multiply(learning_rate - end_learning_rate,
- math_ops.pow(1 - p, power)),
- end_learning_rate,
- name=name)
-
- if not context.executing_eagerly():
- decayed_lr = decayed_lr()
-
- return decayed_lr
-
-
-@tf_export("train.natural_exp_decay")
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(
+ learning_rate,
+ global_step,
+ decay_steps,
+ end_learning_rate=end_learning_rate,
+ power=power,
+ cycle=cycle,
+ name=name)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
+
+
+@tf_export(v1=["train.natural_exp_decay"])
def natural_exp_decay(learning_rate,
global_step,
decay_steps,
@@ -410,32 +323,17 @@ def natural_exp_decay(learning_rate,
the learning rate value across different invocations of optimizer functions.
@end_compatibility
"""
- if global_step is None:
- raise ValueError("global_step is required for natural_exp_decay.")
- with ops.name_scope(name, "NaturalExpDecay",
- [learning_rate, global_step, decay_rate]) as name:
- learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
- dtype = learning_rate.dtype
- decay_steps = math_ops.cast(decay_steps, dtype)
- decay_rate = math_ops.cast(decay_rate, dtype)
-
- def decayed_lr():
- """Helper to recompute learning rate; most helpful in eager-mode."""
- global_step_recomp = math_ops.cast(global_step, dtype)
- p = global_step_recomp / decay_steps
- if staircase:
- p = math_ops.floor(p)
- exponent = math_ops.exp(
- math_ops.multiply(math_ops.negative(decay_rate), p))
- return math_ops.multiply(learning_rate, exponent, name=name)
-
- if not context.executing_eagerly():
- decayed_lr = decayed_lr()
-
- return decayed_lr
-
-
-@tf_export("train.inverse_time_decay")
+ decayed_lr = learning_rate_decay_v2.natural_exp_decay(
+ learning_rate, global_step, decay_steps, decay_rate, staircase=staircase,
+ name=name)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
+
+
+@tf_export(v1=["train.inverse_time_decay"])
def inverse_time_decay(learning_rate,
global_step,
decay_steps,
@@ -507,32 +405,21 @@ def inverse_time_decay(learning_rate,
the learning rate value across different invocations of optimizer functions.
@end_compatibility
"""
- if global_step is None:
- raise ValueError("global_step is required for inverse_time_decay.")
- with ops.name_scope(name, "InverseTimeDecay",
- [learning_rate, global_step, decay_rate]) as name:
- learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
- dtype = learning_rate.dtype
- decay_steps = math_ops.cast(decay_steps, dtype)
- decay_rate = math_ops.cast(decay_rate, dtype)
-
- def decayed_lr():
- """Helper to recompute learning rate; most helpful in eager-mode."""
- global_step_recomp = math_ops.cast(global_step, dtype)
- p = global_step_recomp / decay_steps
- if staircase:
- p = math_ops.floor(p)
- const = math_ops.cast(constant_op.constant(1), dtype)
- denom = math_ops.add(const, math_ops.multiply(decay_rate, p))
- return math_ops.div(learning_rate, denom, name=name)
-
- if not context.executing_eagerly():
- decayed_lr = decayed_lr()
-
- return decayed_lr
-
-
-@tf_export("train.cosine_decay")
+ decayed_lr = learning_rate_decay_v2.inverse_time_decay(
+ learning_rate,
+ global_step,
+ decay_steps,
+ decay_rate,
+ staircase=staircase,
+ name=name)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
+
+
+@tf_export(v1=["train.cosine_decay"])
def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None):
"""Applies cosine decay to the learning rate.
@@ -581,32 +468,16 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None):
the learning rate value across different invocations of optimizer functions.
@end_compatibility
"""
- if global_step is None:
- raise ValueError("cosine decay requires global_step")
- with ops.name_scope(name, "CosineDecay",
- [learning_rate, global_step]) as name:
- learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
- dtype = learning_rate.dtype
- decay_steps = math_ops.cast(decay_steps, dtype)
-
- def decayed_lr():
- """Helper to recompute learning rate; most helpful in eager-mode."""
- global_step_recomp = math_ops.cast(global_step, dtype)
- global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
- completed_fraction = global_step_recomp / decay_steps
- cosine_decayed = 0.5 * (1.0 + math_ops.cos(
- constant_op.constant(math.pi) * completed_fraction))
-
- decayed = (1 - alpha) * cosine_decayed + alpha
- return math_ops.multiply(learning_rate, decayed)
+ decayed_lr = learning_rate_decay_v2.cosine_decay(
+ learning_rate, global_step, decay_steps, alpha=alpha, name=name)
- if not context.executing_eagerly():
- decayed_lr = decayed_lr()
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
- return decayed_lr
+ return decayed_lr
-@tf_export("train.cosine_decay_restarts")
+@tf_export(v1=["train.cosine_decay_restarts"])
def cosine_decay_restarts(learning_rate,
global_step,
first_decay_steps,
@@ -664,57 +535,22 @@ def cosine_decay_restarts(learning_rate,
the learning rate value across different invocations of optimizer functions.
@end_compatibility
"""
- if global_step is None:
- raise ValueError("cosine decay restarts requires global_step")
- with ops.name_scope(name, "SGDRDecay", [learning_rate, global_step]) as name:
- learning_rate = ops.convert_to_tensor(
- learning_rate, name="initial_learning_rate")
- dtype = learning_rate.dtype
- first_decay_steps = math_ops.cast(first_decay_steps, dtype)
- alpha = math_ops.cast(alpha, dtype)
- t_mul = math_ops.cast(t_mul, dtype)
- m_mul = math_ops.cast(m_mul, dtype)
-
- def decayed_lr():
- """Helper to recompute learning rate; most helpful in eager-mode."""
- global_step_recomp = math_ops.cast(global_step, dtype)
- completed_fraction = global_step_recomp / first_decay_steps
-
- def compute_step(completed_fraction, geometric=False):
- """Helper for `cond` operation."""
- if geometric:
- i_restart = math_ops.floor(
- math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) /
- math_ops.log(t_mul))
-
- sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
- completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart
-
- else:
- i_restart = math_ops.floor(completed_fraction)
- completed_fraction -= i_restart
+ decayed_lr = learning_rate_decay_v2.cosine_decay_restarts(
+ learning_rate,
+ global_step,
+ first_decay_steps,
+ t_mul=t_mul,
+ m_mul=m_mul,
+ alpha=alpha,
+ name=name)
- return i_restart, completed_fraction
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
- i_restart, completed_fraction = control_flow_ops.cond(
- math_ops.equal(t_mul, 1.0),
- lambda: compute_step(completed_fraction, geometric=False),
- lambda: compute_step(completed_fraction, geometric=True))
+ return decayed_lr
- m_fac = m_mul**i_restart
- cosine_decayed = 0.5 * m_fac * (1.0 + math_ops.cos(
- constant_op.constant(math.pi) * completed_fraction))
- decayed = (1 - alpha) * cosine_decayed + alpha
- return math_ops.multiply(learning_rate, decayed, name=name)
-
- if not context.executing_eagerly():
- decayed_lr = decayed_lr()
-
- return decayed_lr
-
-
-@tf_export("train.linear_cosine_decay")
+@tf_export(v1=["train.linear_cosine_decay"])
def linear_cosine_decay(learning_rate,
global_step,
decay_steps,
@@ -781,37 +617,22 @@ def linear_cosine_decay(learning_rate,
the learning rate value across different invocations of optimizer functions.
@end_compatibility
"""
- if global_step is None:
- raise ValueError("linear cosine decay requires global_step")
- with ops.name_scope(name, "LinearCosineDecay",
- [learning_rate, global_step]) as name:
- learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
- dtype = learning_rate.dtype
- decay_steps = math_ops.cast(decay_steps, dtype)
- num_periods = math_ops.cast(num_periods, dtype)
- alpha = math_ops.cast(alpha, dtype)
- beta = math_ops.cast(beta, dtype)
-
- def decayed_lr():
- """Helper to recompute learning rate; most helpful in eager-mode."""
- global_step_recomp = math_ops.cast(global_step, dtype)
- global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
- linear_decayed = (decay_steps - global_step_recomp) / decay_steps
- completed_fraction = global_step_recomp / decay_steps
- fraction = 2.0 * num_periods * completed_fraction
- cosine_decayed = 0.5 * (
- 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
-
- linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta
- return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name)
-
- if not context.executing_eagerly():
- decayed_lr = decayed_lr()
-
- return decayed_lr
-
-
-@tf_export("train.noisy_linear_cosine_decay")
+ decayed_lr = learning_rate_decay_v2.linear_cosine_decay(
+ learning_rate,
+ global_step,
+ decay_steps,
+ num_periods=num_periods,
+ alpha=alpha,
+ beta=beta,
+ name=name)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
+
+
+@tf_export(v1=["train.noisy_linear_cosine_decay"])
def noisy_linear_cosine_decay(learning_rate,
global_step,
decay_steps,
@@ -886,42 +707,17 @@ def noisy_linear_cosine_decay(learning_rate,
the learning rate value across different invocations of optimizer functions.
@end_compatibility
"""
- if global_step is None:
- raise ValueError("noisy linear cosine decay requires global_step")
- with ops.name_scope(name, "NoisyLinearCosineDecay",
- [learning_rate, global_step]) as name:
- learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
- dtype = learning_rate.dtype
- decay_steps = math_ops.cast(decay_steps, dtype)
- initial_variance = math_ops.cast(initial_variance, dtype)
- variance_decay = math_ops.cast(variance_decay, dtype)
- num_periods = math_ops.cast(num_periods, dtype)
- alpha = math_ops.cast(alpha, dtype)
- beta = math_ops.cast(beta, dtype)
-
- def decayed_lr():
- """Helper to recompute learning rate; most helpful in eager-mode."""
- global_step_recomp = math_ops.cast(global_step, dtype)
- global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
- linear_decayed = (decay_steps - global_step_recomp) / decay_steps
- variance = initial_variance / (
- math_ops.pow(1.0 + global_step_recomp, variance_decay))
- std = math_ops.sqrt(variance)
- noisy_linear_decayed = (
- linear_decayed + random_ops.random_normal(
- linear_decayed.shape, stddev=std))
-
- completed_fraction = global_step_recomp / decay_steps
- fraction = 2.0 * num_periods * completed_fraction
- cosine_decayed = 0.5 * (
- 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
- noisy_linear_cosine_decayed = (
- (alpha + noisy_linear_decayed) * cosine_decayed + beta)
-
- return math_ops.multiply(
- learning_rate, noisy_linear_cosine_decayed, name=name)
-
- if not context.executing_eagerly():
- decayed_lr = decayed_lr()
-
- return decayed_lr
+ decayed_lr = learning_rate_decay_v2.noisy_linear_cosine_decay(
+ learning_rate, global_step,
+ decay_steps,
+ initial_variance=initial_variance,
+ variance_decay=variance_decay,
+ num_periods=num_periods,
+ alpha=alpha,
+ beta=beta,
+ name=name)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
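Every v1 wrapper above now has the same shape: delegate to the `_v2` function, which returns a no-arg callable, then call it immediately unless executing eagerly. A sketch of what the callable buys under eager execution (TF 1.x API; the optimizer accepting a callable learning rate is an assumption of this sketch, and the numbers are arbitrary):

```python
import tensorflow as tf

tf.enable_eager_execution()

global_step = tf.train.get_or_create_global_step()
lr_fn = tf.train.exponential_decay(
    0.1, global_step, decay_steps=1000, decay_rate=0.96, staircase=True)

# Under eager execution the wrapper skips the immediate call and returns
# the callable itself, so the rate is recomputed at the current step.
print(lr_fn())
opt = tf.train.GradientDescentOptimizer(lr_fn)
```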
diff --git a/tensorflow/python/training/learning_rate_decay_test.py b/tensorflow/python/training/learning_rate_decay_test.py
index 4f3cf01822..5a9215730e 100644
--- a/tensorflow/python/training/learning_rate_decay_test.py
+++ b/tensorflow/python/training/learning_rate_decay_test.py
@@ -62,7 +62,7 @@ class LRDecayTest(test_util.TensorFlowTestCase):
self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
def testVariables(self):
- with self.test_session():
+ with self.cached_session():
step = variables.Variable(1)
assign_1 = step.assign(1)
assign_2 = step.assign(2)
diff --git a/tensorflow/python/training/learning_rate_decay_v2.py b/tensorflow/python/training/learning_rate_decay_v2.py
new file mode 100644
index 0000000000..9c5e144be6
--- /dev/null
+++ b/tensorflow/python/training/learning_rate_decay_v2.py
@@ -0,0 +1,898 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Various learning rate decay functions."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import math
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("train.exponential_decay", v1=[])
+def exponential_decay(learning_rate,
+ global_step,
+ decay_steps,
+ decay_rate,
+ staircase=False,
+ name=None):
+ """Applies exponential decay to the learning rate.
+
+ When training a model, it is often recommended to lower the learning rate as
+ the training progresses. This function applies an exponential decay function
+ to a provided initial learning rate. It requires a `global_step` value to
+ compute the decayed learning rate. You can just pass a TensorFlow variable
+ that you increment at each training step.
+
+ The function returns a no-arg function that produces the decayed learning
+ rate. This can be useful for changing the learning rate value across
+ different invocations of optimizer functions.
+ It is computed as:
+
+ ```python
+ decayed_learning_rate = learning_rate *
+ decay_rate ^ (global_step / decay_steps)
+ ```
+
+ If the argument `staircase` is `True`, then `global_step / decay_steps` is an
+ integer division and the decayed learning rate follows a staircase function.
+
+ Example: decay every 100000 steps with a base of 0.96:
+
+ ```python
+ ...
+ global_step = tf.Variable(0, trainable=False)
+ starter_learning_rate = 0.1
+ learning_rate_fn = tf.train.exponential_decay(starter_learning_rate,
+ global_step, 100000, 0.96,
+ staircase=True)
+ # Passing global_step to minimize() will increment it at each step.
+ learning_step = (
+ tf.train.GradientDescentOptimizer(learning_rate_fn)
+ .minimize(...my loss..., global_step=global_step)
+ )
+ ```
+
+ Args:
+ learning_rate: A scalar `float32` or `float64` `Tensor` or a
+ Python number. The initial learning rate.
+ global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Global step to use for the decay computation. Must not be negative.
+ decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Must be positive. See the decay computation above.
+ decay_rate: A scalar `float32` or `float64` `Tensor` or a
+ Python number. The decay rate.
+    staircase: Boolean. If `True`, decay the learning rate at discrete
+      intervals.
+ name: String. Optional name of the operation. Defaults to
+ 'ExponentialDecay'.
+
+ Returns:
+ A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
+ of the same type as `learning_rate`.
+
+ Raises:
+ ValueError: if `global_step` is not supplied.
+ """
+ if global_step is None:
+ raise ValueError("global_step is required for exponential_decay.")
+ def decayed_lr(learning_rate, global_step, decay_steps, decay_rate,
+ staircase, name):
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ with ops.name_scope(
+ name, "ExponentialDecay",
+ [learning_rate, global_step, decay_steps, decay_rate]) as name:
+ learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
+ dtype = learning_rate.dtype
+ decay_steps = math_ops.cast(decay_steps, dtype)
+ decay_rate = math_ops.cast(decay_rate, dtype)
+
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ p = global_step_recomp / decay_steps
+ if staircase:
+ p = math_ops.floor(p)
+ return math_ops.multiply(
+ learning_rate, math_ops.pow(decay_rate, p), name=name)
+
+ return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
+ decay_rate, staircase, name)
+
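Each `_v2` function returns `functools.partial(decayed_lr, ...)` rather than a computed value, so every invocation re-reads `global_step`. A plain-Python sketch of that deferred-evaluation pattern (hypothetical names, no TensorFlow involved):

```python
import functools

def decayed_lr(base_lr, read_step, decay_steps, decay_rate):
  # Recomputed on every call, so a moving step is re-read each time.
  return base_lr * decay_rate ** (read_step() / decay_steps)

step = [0]  # stand-in for a mutable global_step variable
lr_fn = functools.partial(decayed_lr, 0.1, lambda: step[0], 1000, 0.96)

print(lr_fn())   # 0.1 at step 0
step[0] = 1000
print(lr_fn())   # 0.096 after 1000 steps
```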
+
+@tf_export("train.piecewise_constant", v1=[])
+def piecewise_constant(x, boundaries, values, name=None):
+ """Piecewise constant from boundaries and interval values.
+
+ This function returns a no-arg callable to compute the piecewise constant.
+ This can be useful for changing the learning rate value across
+ different invocations of optimizer functions.
+
+ Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5
+ for the next 10000 steps, and 0.1 for any additional steps.
+
+ ```python
+ global_step = tf.Variable(0, trainable=False)
+ boundaries = [100000, 110000]
+ values = [1.0, 0.5, 0.1]
+ learning_rate_fn = tf.train.piecewise_constant(global_step, boundaries,
+ values)
+ learning_rate = learning_rate_fn()
+
+ # Later, whenever we perform an optimization step, we increment global_step.
+ ```
+
+ Args:
+ x: A 0-D scalar `Tensor`. Must be one of the following types: `float32`,
+ `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`.
+ boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
+ increasing entries, and with all elements having the same type as `x`.
+ values: A list of `Tensor`s or `float`s or `int`s that specifies the values
+ for the intervals defined by `boundaries`. It should have one more element
+ than `boundaries`, and all elements should have the same type.
+ name: A string. Optional name of the operation. Defaults to
+ 'PiecewiseConstant'.
+
+ Returns:
+ A no-arg function that outputs a 0-D Tensor. The output of the no-arg
+ function is `values[0]` when `x <= boundaries[0]`,
+ `values[1]` when `x > boundaries[0]` and `x <= boundaries[1]`, ...,
+ and values[-1] when `x > boundaries[-1]`.
+
+ Raises:
+ ValueError: if types of `x` and `boundaries` do not match, or types of all
+ `values` do not match or
+ the number of elements in the lists does not match.
+ """
+ if len(boundaries) != len(values) - 1:
+ raise ValueError(
+ "The length of boundaries should be 1 less than the length of values")
+ def decayed_lr(x, boundaries, values, name):
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ with ops.name_scope(name, "PiecewiseConstant",
+ [x, boundaries, values, name]) as name:
+ boundaries = ops.convert_n_to_tensor(boundaries)
+ values = ops.convert_n_to_tensor(values)
+ x_recomp = ops.convert_to_tensor(x)
+ # Avoid explicit conversion to x's dtype. This could result in faulty
+ # comparisons, for example if floats are converted to integers.
+ for i, b in enumerate(boundaries):
+ if b.dtype.base_dtype != x_recomp.dtype.base_dtype:
+ # We can promote int32 boundaries to int64 without loss of precision.
+ # This covers the most common case where the user passes in boundaries
+ # as an array of Python integers.
+ if (b.dtype.base_dtype == dtypes.int32 and
+ x_recomp.dtype.base_dtype == dtypes.int64):
+ b = math_ops.cast(b, x_recomp.dtype.base_dtype)
+ boundaries[i] = b
+ else:
+ raise ValueError(
+ "Boundaries (%s) must have the same dtype as x (%s)." %
+ (b.dtype.base_dtype, x_recomp.dtype.base_dtype))
+ # TODO(rdipietro): Ensure that boundaries' elements strictly increases.
+ for v in values[1:]:
+ if v.dtype.base_dtype != values[0].dtype.base_dtype:
+ raise ValueError(
+ "Values must have elements all with the same dtype (%s vs %s)." %
+ (values[0].dtype.base_dtype, v.dtype.base_dtype))
+ pred_fn_pairs = []
+ pred_fn_pairs.append((x_recomp <= boundaries[0], lambda: values[0]))
+ pred_fn_pairs.append((x_recomp > boundaries[-1], lambda: values[-1]))
+ for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
+ # Need to bind v here; can do this with lambda v=v: ...
+ pred = (x_recomp > low) & (x_recomp <= high)
+ pred_fn_pairs.append((pred, lambda v=v: v))
+
+ # The default isn't needed here because our conditions are mutually
+ # exclusive and exhaustive, but tf.case requires it.
+ default = lambda: values[0]
+ return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
+
+ return functools.partial(decayed_lr, x, boundaries, values, name)
+
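For a concrete reading of the boundary conditions handled above, a pure-Python mirror of the interval logic (same boundaries and values as the docstring example; not TensorFlow code):

```python
boundaries, values = [100000, 110000], [1.0, 0.5, 0.1]

def piecewise(x):
  # values[0] when x <= boundaries[0], values[-1] when x > boundaries[-1],
  # and the interior value on each half-open interval in between.
  if x <= boundaries[0]:
    return values[0]
  if x > boundaries[-1]:
    return values[-1]
  for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
    if low < x <= high:
      return v

print(piecewise(50000), piecewise(105000), piecewise(200000))  # 1.0 0.5 0.1
```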
+
+@tf_export("train.polynomial_decay", v1=[])
+def polynomial_decay(learning_rate,
+ global_step,
+ decay_steps,
+ end_learning_rate=0.0001,
+ power=1.0,
+ cycle=False,
+ name=None):
+ """Applies a polynomial decay to the learning rate.
+
+ It is commonly observed that a monotonically decreasing learning rate, whose
+ degree of change is carefully chosen, results in a better performing model.
+ This function applies a polynomial decay function to a provided initial
+ `learning_rate` to reach an `end_learning_rate` in the given `decay_steps`.
+
+ It requires a `global_step` value to compute the decayed learning rate. You
+ can just pass a TensorFlow variable that you increment at each training step.
+
+ The function returns a no-arg callable that outputs the decayed learning
+ rate. This can be useful for changing the learning rate value across
+ different invocations of optimizer functions. It is computed as:
+
+ ```python
+ global_step = min(global_step, decay_steps)
+ decayed_learning_rate = (learning_rate - end_learning_rate) *
+ (1 - global_step / decay_steps) ^ (power) +
+ end_learning_rate
+
+ ```
+
+  If `cycle` is `True` then a multiple of `decay_steps` is used, the first one
+  that is bigger than `global_step`.
+
+ ```python
+ decay_steps = decay_steps * ceil(global_step / decay_steps)
+  decayed_learning_rate = (learning_rate - end_learning_rate) *
+                          (1 - global_step / decay_steps) ^ (power) +
+                          end_learning_rate
+
+ ```
+
+ Example: decay from 0.1 to 0.01 in 10000 steps using sqrt (i.e. power=0.5):
+
+ ```python
+ ...
+ global_step = tf.Variable(0, trainable=False)
+ starter_learning_rate = 0.1
+ end_learning_rate = 0.01
+ decay_steps = 10000
+ learning_rate_fn = tf.train.polynomial_decay(starter_learning_rate,
+ global_step, decay_steps,
+ end_learning_rate,
+ power=0.5)
+ # Passing global_step to minimize() will increment it at each step.
+ learning_step = (
+ tf.train.GradientDescentOptimizer(learning_rate_fn)
+ .minimize(...my loss..., global_step=global_step)
+ )
+ ```
+
+ Args:
+ learning_rate: A scalar `float32` or `float64` `Tensor` or a
+ Python number. The initial learning rate.
+ global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Global step to use for the decay computation. Must not be negative.
+ decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Must be positive. See the decay computation above.
+ end_learning_rate: A scalar `float32` or `float64` `Tensor` or a
+ Python number. The minimal end learning rate.
+ power: A scalar `float32` or `float64` `Tensor` or a
+ Python number. The power of the polynomial. Defaults to linear, 1.0.
+ cycle: A boolean, whether or not it should cycle beyond decay_steps.
+ name: String. Optional name of the operation. Defaults to
+ 'PolynomialDecay'.
+
+ Returns:
+ A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
+ of the same type as `learning_rate`.
+
+ Raises:
+ ValueError: if `global_step` is not supplied.
+ """
+ if global_step is None:
+ raise ValueError("global_step is required for polynomial_decay.")
+ def decayed_lr(learning_rate, global_step, decay_steps, end_learning_rate,
+ power, cycle, name):
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ with ops.name_scope(
+ name, "PolynomialDecay",
+ [learning_rate, global_step, decay_steps, end_learning_rate, power]
+ ) as name:
+ learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
+ dtype = learning_rate.dtype
+ end_learning_rate = math_ops.cast(end_learning_rate, dtype)
+ power = math_ops.cast(power, dtype)
+
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ decay_steps_recomp = math_ops.cast(decay_steps, dtype)
+ if cycle:
+ # Find the first multiple of decay_steps that is bigger than
+ # global_step. If global_step is zero set the multiplier to 1
+ multiplier = control_flow_ops.cond(
+ math_ops.equal(global_step_recomp, 0), lambda: 1.0,
+ lambda: math_ops.ceil(global_step_recomp / decay_steps))
+ decay_steps_recomp = math_ops.multiply(decay_steps_recomp, multiplier)
+ else:
+ # Make sure that the global_step used is not bigger than decay_steps.
+ global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
+
+ p = math_ops.div(global_step_recomp, decay_steps_recomp)
+ return math_ops.add(
+ math_ops.multiply(learning_rate - end_learning_rate,
+ math_ops.pow(1 - p, power)),
+ end_learning_rate,
+ name=name)
+
+ return functools.partial(
+ decayed_lr, learning_rate, global_step, decay_steps, end_learning_rate,
+ power, cycle, name)
+
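A quick numeric check of the `cycle=True` branch above, in plain Python with arbitrary example values:

```python
import math

learning_rate, end_learning_rate, power = 0.1, 0.01, 0.5
global_step, decay_steps = 1500, 1000

# cycle=True stretches decay_steps to the first multiple past global_step.
multiplier = math.ceil(global_step / decay_steps)   # 2
p = global_step / (decay_steps * multiplier)        # 0.75
lr = (learning_rate - end_learning_rate) * (1 - p) ** power + end_learning_rate
print(lr)  # 0.055
```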
+
+@tf_export("train.natural_exp_decay", v1=[])
+def natural_exp_decay(learning_rate,
+ global_step,
+ decay_steps,
+ decay_rate,
+ staircase=False,
+ name=None):
+ """Applies natural exponential decay to the initial learning rate.
+
+ When training a model, it is often recommended to lower the learning rate as
+ the training progresses. This function applies an exponential decay function
+  to a provided initial learning rate. It requires a `global_step` value to
+ compute the decayed learning rate. You can just pass a TensorFlow variable
+ that you increment at each training step.
+
+ The function returns a no-arg callable that produces the decayed learning
+ rate. This can be useful for changing the learning rate value across
+ different invocations of optimizer functions. It is computed as:
+
+ ```python
+  decayed_learning_rate = learning_rate * exp(-decay_rate * global_step /
+                                              decay_steps)
+ ```
+
+ or, if `staircase` is `True`, as:
+
+ ```python
+  decayed_learning_rate = learning_rate * exp(-decay_rate * floor(global_step /
+                                              decay_steps))
+ ```
+
+  Example: decay exponentially with a decay rate of 0.5:
+
+ ```python
+ ...
+ global_step = tf.Variable(0, trainable=False)
+ learning_rate = 0.1
+ decay_steps = 5
+ k = 0.5
+ learning_rate_fn = tf.train.natural_exp_decay(learning_rate, global_step,
+ decay_steps, k)
+
+ # Passing global_step to minimize() will increment it at each step.
+ learning_step = (
+ tf.train.GradientDescentOptimizer(learning_rate_fn)
+ .minimize(...my loss..., global_step=global_step)
+ )
+ ```
+
+ Args:
+ learning_rate: A scalar `float32` or `float64` `Tensor` or a
+ Python number. The initial learning rate.
+ global_step: A Python number.
+ Global step to use for the decay computation. Must not be negative.
+ decay_steps: How often to apply decay.
+ decay_rate: A Python number. The decay rate.
+ staircase: Whether to apply decay in a discrete staircase, as opposed to
+ continuous, fashion.
+ name: String. Optional name of the operation. Defaults to
+ 'ExponentialTimeDecay'.
+
+ Returns:
+ A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
+ of the same type as `learning_rate`.
+
+ Raises:
+ ValueError: if `global_step` is not supplied.
+ """
+ if global_step is None:
+ raise ValueError("global_step is required for natural_exp_decay.")
+ def decayed_lr(learning_rate, global_step, decay_steps, decay_rate, staircase,
+ name):
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ with ops.name_scope(name, "NaturalExpDecay",
+ [learning_rate, global_step, decay_rate]) as name:
+ learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
+ dtype = learning_rate.dtype
+ decay_steps = math_ops.cast(decay_steps, dtype)
+ decay_rate = math_ops.cast(decay_rate, dtype)
+
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ p = global_step_recomp / decay_steps
+ if staircase:
+ p = math_ops.floor(p)
+ exponent = math_ops.exp(
+ math_ops.multiply(math_ops.negative(decay_rate), p))
+ return math_ops.multiply(learning_rate, exponent, name=name)
+
+ return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
+ decay_rate, staircase, name)
+
+
+@tf_export("train.inverse_time_decay", v1=[])
+def inverse_time_decay(learning_rate,
+ global_step,
+ decay_steps,
+ decay_rate,
+ staircase=False,
+ name=None):
+ """Applies inverse time decay to the initial learning rate.
+
+ When training a model, it is often recommended to lower the learning rate as
+ the training progresses. This function applies an inverse decay function
+  to a provided initial learning rate. It requires a `global_step` value to
+ compute the decayed learning rate. You can just pass a TensorFlow variable
+ that you increment at each training step.
+
+ The function returns a no-arg callable that produces the decayed learning
+ rate. This can be useful for changing the learning rate value across
+ different invocations of optimizer functions. It is computed as:
+
+ ```python
+  decayed_learning_rate = learning_rate / (1 + decay_rate * global_step /
+                                           decay_steps)
+ ```
+
+ or, if `staircase` is `True`, as:
+
+ ```python
+  decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step /
+                                           decay_steps))
+ ```
+
+ Example: decay 1/t with a rate of 0.5:
+
+ ```python
+ ...
+ global_step = tf.Variable(0, trainable=False)
+ learning_rate = 0.1
+ decay_steps = 1.0
+ decay_rate = 0.5
+ learning_rate_fn = tf.train.inverse_time_decay(learning_rate, global_step,
+ decay_steps, decay_rate)
+
+ # Passing global_step to minimize() will increment it at each step.
+ learning_step = (
+ tf.train.GradientDescentOptimizer(learning_rate_fn)
+ .minimize(...my loss..., global_step=global_step)
+ )
+ ```
+
+ Args:
+ learning_rate: A scalar `float32` or `float64` `Tensor` or a
+ Python number. The initial learning rate.
+ global_step: A Python number.
+ Global step to use for the decay computation. Must not be negative.
+ decay_steps: How often to apply decay.
+ decay_rate: A Python number. The decay rate.
+ staircase: Whether to apply decay in a discrete staircase, as opposed to
+ continuous, fashion.
+ name: String. Optional name of the operation. Defaults to
+ 'InverseTimeDecay'.
+
+ Returns:
+ A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
+ of the same type as `learning_rate`.
+
+ Raises:
+ ValueError: if `global_step` is not supplied.
+ """
+ if global_step is None:
+ raise ValueError("global_step is required for inverse_time_decay.")
+ def decayed_lr(learning_rate, global_step, decay_steps, decay_rate, staircase,
+ name):
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ with ops.name_scope(name, "InverseTimeDecay",
+ [learning_rate, global_step, decay_rate]) as name:
+ learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
+ dtype = learning_rate.dtype
+ decay_steps = math_ops.cast(decay_steps, dtype)
+ decay_rate = math_ops.cast(decay_rate, dtype)
+
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ p = global_step_recomp / decay_steps
+ if staircase:
+ p = math_ops.floor(p)
+ const = math_ops.cast(constant_op.constant(1), dtype)
+ denom = math_ops.add(const, math_ops.multiply(decay_rate, p))
+ return math_ops.div(learning_rate, denom, name=name)
+
+ return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
+ decay_rate, staircase, name)
+
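A pure-Python check of the inverse-time formula above, using the docstring's example values:

```python
learning_rate, decay_steps, decay_rate = 0.1, 1.0, 0.5

for global_step in (0, 1, 2):
  p = global_step / decay_steps
  print(learning_rate / (1 + decay_rate * p))  # 0.1, ~0.0667, 0.05
```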
+
+@tf_export("train.cosine_decay", v1=[])
+def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0,
+ name=None):
+ """Applies cosine decay to the learning rate.
+
+ See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
+ with Warm Restarts. https://arxiv.org/abs/1608.03983
+
+ When training a model, it is often recommended to lower the learning rate as
+ the training progresses. This function applies a cosine decay function
+ to a provided initial learning rate. It requires a `global_step` value to
+ compute the decayed learning rate. You can just pass a TensorFlow variable
+ that you increment at each training step.
+
+ The function returns a no-arg callable that produces the decayed learning
+ rate. This can be useful for changing the learning rate value across
+ different invocations of optimizer functions. It is computed as:
+
+ ```python
+ global_step = min(global_step, decay_steps)
+ cosine_decay = 0.5 * (1 + cos(pi * global_step / decay_steps))
+ decayed = (1 - alpha) * cosine_decay + alpha
+ decayed_learning_rate = learning_rate * decayed
+ ```
+
+ Example usage:
+ ```python
+ decay_steps = 1000
+ lr_decayed_fn = tf.train.cosine_decay(learning_rate, global_step, decay_steps)
+ ```
+
+ Args:
+ learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
+ The initial learning rate.
+ global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Global step to use for the decay computation.
+ decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Number of steps to decay over.
+ alpha: A scalar `float32` or `float64` Tensor or a Python number.
+ Minimum learning rate value as a fraction of learning_rate.
+ name: String. Optional name of the operation. Defaults to 'CosineDecay'.
+ Returns:
+ A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
+ of the same type as `learning_rate`.
+ Raises:
+ ValueError: if `global_step` is not supplied.
+ """
+ if global_step is None:
+ raise ValueError("cosine decay requires global_step")
+ def decayed_lr(learning_rate, global_step, decay_steps, alpha, name):
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ with ops.name_scope(name, "CosineDecay",
+ [learning_rate, global_step]) as name:
+ learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
+ dtype = learning_rate.dtype
+ decay_steps = math_ops.cast(decay_steps, dtype)
+
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
+ completed_fraction = global_step_recomp / decay_steps
+ cosine_decayed = 0.5 * (1.0 + math_ops.cos(
+ constant_op.constant(math.pi) * completed_fraction))
+
+ decayed = (1 - alpha) * cosine_decayed + alpha
+ return math_ops.multiply(learning_rate, decayed)
+
+ return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
+ alpha, name)
+
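A numeric sanity check of the cosine schedule above (pure Python; halfway through the decay the multiplier is exactly 0.5):

```python
import math

learning_rate, decay_steps, alpha = 0.1, 1000, 0.0
global_step = 500

completed_fraction = min(global_step, decay_steps) / decay_steps
cosine_decayed = 0.5 * (1.0 + math.cos(math.pi * completed_fraction))  # 0.5
print(learning_rate * ((1 - alpha) * cosine_decayed + alpha))  # 0.05
```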
+
+@tf_export("train.cosine_decay_restarts", v1=[])
+def cosine_decay_restarts(learning_rate,
+ global_step,
+ first_decay_steps,
+ t_mul=2.0,
+ m_mul=1.0,
+ alpha=0.0,
+ name=None):
+ """Applies cosine decay with restarts to the learning rate.
+
+ See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
+ with Warm Restarts. https://arxiv.org/abs/1608.03983
+
+ When training a model, it is often recommended to lower the learning rate as
+ the training progresses. This function applies a cosine decay function with
+ restarts to a provided initial learning rate. It requires a `global_step`
+ value to compute the decayed learning rate. You can just pass a TensorFlow
+ variable that you increment at each training step.
+
+ The function returns a no-arg callable that produces the decayed learning
+ rate while taking into account possible warm restarts. This can be useful for
+ changing the learning rate value across different invocations of optimizer
+ functions.
+
+ The learning rate multiplier first decays
+ from 1 to `alpha` for `first_decay_steps` steps. Then, a warm
+ restart is performed. Each new warm restart runs for `t_mul` times more steps
+ and with `m_mul` times smaller initial learning rate.
+
+ Example usage:
+ ```python
+ first_decay_steps = 1000
+ lr_decayed_fn = tf.train.cosine_decay_restarts(learning_rate, global_step,
+ first_decay_steps)
+ ```
+
+ Args:
+ learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
+ The initial learning rate.
+ global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Global step to use for the decay computation.
+ first_decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Number of steps to decay over.
+ t_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
+ Used to derive the number of iterations in the i-th period.
+ m_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
+ Used to derive the initial learning rate of the i-th period.
+ alpha: A scalar `float32` or `float64` Tensor or a Python number.
+ Minimum learning rate value as a fraction of `learning_rate`.
+ name: String. Optional name of the operation. Defaults to 'SGDRDecay'.
+ Returns:
+ A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
+ of the same type as `learning_rate`.
+
+ Raises:
+ ValueError: if `global_step` is not supplied.
+ """
+ if global_step is None:
+ raise ValueError("cosine decay restarts requires global_step")
+ def decayed_lr(learning_rate, global_step, first_decay_steps, t_mul, m_mul,
+ alpha, name):
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ with ops.name_scope(name, "SGDRDecay", [learning_rate, global_step]
+ ) as name:
+ learning_rate = ops.convert_to_tensor(
+ learning_rate, name="initial_learning_rate")
+ dtype = learning_rate.dtype
+ first_decay_steps = math_ops.cast(first_decay_steps, dtype)
+ alpha = math_ops.cast(alpha, dtype)
+ t_mul = math_ops.cast(t_mul, dtype)
+ m_mul = math_ops.cast(m_mul, dtype)
+
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ completed_fraction = global_step_recomp / first_decay_steps
+
+ def compute_step(completed_fraction, geometric=False):
+ """Helper for `cond` operation."""
+ if geometric:
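+ # The first i periods span (1 - t_mul**i) / (1 - t_mul) units of
+ # `first_decay_steps`; the floor/log expression below inverts this
+ # geometric sum to recover the index of the current period.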
+ i_restart = math_ops.floor(
+ math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) /
+ math_ops.log(t_mul))
+
+ sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
+ completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart
+
+ else:
+ i_restart = math_ops.floor(completed_fraction)
+ completed_fraction -= i_restart
+
+ return i_restart, completed_fraction
+
+ i_restart, completed_fraction = control_flow_ops.cond(
+ math_ops.equal(t_mul, 1.0),
+ lambda: compute_step(completed_fraction, geometric=False),
+ lambda: compute_step(completed_fraction, geometric=True))
+
+ m_fac = m_mul**i_restart
+ cosine_decayed = 0.5 * m_fac * (1.0 + math_ops.cos(
+ constant_op.constant(math.pi) * completed_fraction))
+ decayed = (1 - alpha) * cosine_decayed + alpha
+
+ return math_ops.multiply(learning_rate, decayed, name=name)
+
+ return functools.partial(decayed_lr, learning_rate, global_step,
+ first_decay_steps, t_mul, m_mul, alpha, name)
+
+
+@tf_export("train.linear_cosine_decay", v1=[])
+def linear_cosine_decay(learning_rate,
+ global_step,
+ decay_steps,
+ num_periods=0.5,
+ alpha=0.0,
+ beta=0.001,
+ name=None):
+ """Applies linear cosine decay to the learning rate.
+
+ See [Bello et al., ICML 2017] Neural Optimizer Search with RL.
+ https://arxiv.org/abs/1709.07417
+
+ For the idea of warm starts here controlled by `num_periods`,
+ see [Loshchilov & Hutter, ICLR 2017] SGDR: Stochastic Gradient Descent
+ with Warm Restarts. https://arxiv.org/abs/1608.03983
+
+ Note that linear cosine decay is more aggressive than cosine decay, and
+ larger initial learning rates can typically be used.
+
+ When training a model, it is often recommended to lower the learning rate as
+ the training progresses. This function applies a linear cosine decay function
+ to a provided initial learning rate. It requires a `global_step` value to
+ compute the decayed learning rate. You can just pass a TensorFlow variable
+ that you increment at each training step.
+
+ The function returns a no-arg callable that produces the decayed learning
+ rate. This can be useful for changing the learning rate value across
+ different invocations of optimizer functions. It is computed as:
+
+ ```python
+ global_step = min(global_step, decay_steps)
+ linear_decay = (decay_steps - global_step) / decay_steps
+ cosine_decay = 0.5 * (
+ 1 + cos(pi * 2 * num_periods * global_step / decay_steps))
+ decayed = (alpha + linear_decay) * cosine_decay + beta
+ decayed_learning_rate = learning_rate * decayed
+ ```
+
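+ As a quick numeric check with the defaults (`num_periods=0.5`,
+ `alpha=0.0`, `beta=0.001`), halfway through the decay:
+
+ ```python
+ import math
+
+ decay_steps, global_step, num_periods = 1000, 500, 0.5
+ linear_decay = (decay_steps - global_step) / decay_steps  # 0.5
+ cosine_decay = 0.5 * (
+     1 + math.cos(math.pi * 2 * num_periods * global_step / decay_steps))
+ print((0.0 + linear_decay) * cosine_decay + 0.001)  # ~0.251
+ ```
+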
+ Example usage:
+ ```python
+ decay_steps = 1000
+ lr_decayed_fn = tf.train.linear_cosine_decay(learning_rate, global_step,
+ decay_steps)
+ ```
+
+ Args:
+ learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
+ The initial learning rate.
+ global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Global step to use for the decay computation.
+ decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Number of steps to decay over.
+ num_periods: Number of periods in the cosine part of the decay.
+ See computation above.
+ alpha: See computation above.
+ beta: See computation above.
+ name: String. Optional name of the operation. Defaults to
+ 'LinearCosineDecay'.
+ Returns:
+ A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
+ of the same type as `learning_rate`.
+ Raises:
+ ValueError: if `global_step` is not supplied.
+ """
+ if global_step is None:
+ raise ValueError("linear cosine decay requires global_step")
+ def decayed_lr(learning_rate, global_step, decay_steps, num_periods, alpha,
+ beta, name):
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ with ops.name_scope(name, "LinearCosineDecay",
+ [learning_rate, global_step]) as name:
+ learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
+ dtype = learning_rate.dtype
+ decay_steps = math_ops.cast(decay_steps, dtype)
+ num_periods = math_ops.cast(num_periods, dtype)
+ alpha = math_ops.cast(alpha, dtype)
+ beta = math_ops.cast(beta, dtype)
+
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
+ linear_decayed = (decay_steps - global_step_recomp) / decay_steps
+ completed_fraction = global_step_recomp / decay_steps
+ fraction = 2.0 * num_periods * completed_fraction
+ cosine_decayed = 0.5 * (
+ 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
+
+ linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta
+ return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name)
+
+ return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
+ num_periods, alpha, beta, name)
+
+
+@tf_export("train.noisy_linear_cosine_decay", v1=[])
+def noisy_linear_cosine_decay(learning_rate,
+ global_step,
+ decay_steps,
+ initial_variance=1.0,
+ variance_decay=0.55,
+ num_periods=0.5,
+ alpha=0.0,
+ beta=0.001,
+ name=None):
+ """Applies noisy linear cosine decay to the learning rate.
+
+ See [Bello et al., ICML 2017] Neural Optimizer Search with RL.
+ https://arxiv.org/abs/1709.07417
+
+ For the idea of warm starts here controlled by `num_periods`,
+ see [Loshchilov & Hutter, ICLR 2017] SGDR: Stochastic Gradient Descent
+ with Warm Restarts. https://arxiv.org/abs/1608.03983
+
+ Note that linear cosine decay is more aggressive than cosine decay, and
+ larger initial learning rates can typically be used.
+
+ When training a model, it is often recommended to lower the learning rate as
+ the training progresses. This function applies a noisy linear
+ cosine decay function to a provided initial learning rate.
+ It requires a `global_step` value to compute the decayed learning rate.
+ You can just pass a TensorFlow variable that you increment at each
+ training step.
+
+ The function returns a no-arg callable that produces the decayed learning
+ rate. This can be useful for changing the learning rate value across
+ different invocations of optimizer functions. It is computed as:
+
+ ```python
+ global_step = min(global_step, decay_steps)
+ linear_decay = (decay_steps - global_step) / decay_steps
+ cosine_decay = 0.5 * (
+ 1 + cos(pi * 2 * num_periods * global_step / decay_steps))
+ decayed = (alpha + linear_decay + eps_t) * cosine_decay + beta
+ decayed_learning_rate = learning_rate * decayed
+ ```
+ where `eps_t` is 0-centered Gaussian noise with variance
+ `initial_variance / (1 + global_step) ** variance_decay`.
+
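+ For illustration only, a NumPy sketch of the noise term (scalar case;
+ `np` and the variable values are assumptions, not part of this API):
+
+ ```python
+ import numpy as np
+
+ initial_variance, variance_decay, global_step = 1.0, 0.55, 100
+ variance = initial_variance / (1 + global_step) ** variance_decay
+ eps_t = np.random.normal(loc=0.0, scale=np.sqrt(variance))
+ ```
+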
+ Example usage:
+ ```python
+ decay_steps = 1000
+ lr_decayed_fn = tf.train.noisy_linear_cosine_decay(learning_rate, global_step,
+ decay_steps)
+ ```
+
+ Args:
+ learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
+ The initial learning rate.
+ global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Global step to use for the decay computation.
+ decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Number of steps to decay over.
+ initial_variance: initial variance for the noise. See computation above.
+ variance_decay: decay for the noise's variance. See computation above.
+ num_periods: Number of periods in the cosine part of the decay.
+ See computation above.
+ alpha: See computation above.
+ beta: See computation above.
+ name: String. Optional name of the operation. Defaults to
+ 'NoisyLinearCosineDecay'.
+ Returns:
+ A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
+ of the same type as `learning_rate`.
+ Raises:
+ ValueError: if `global_step` is not supplied.
+ """
+ if global_step is None:
+ raise ValueError("noisy linear cosine decay requires global_step")
+ def decayed_lr(learning_rate, global_step, decay_steps, initial_variance,
+ variance_decay, num_periods, alpha, beta, name):
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ with ops.name_scope(name, "NoisyLinearCosineDecay",
+ [learning_rate, global_step]) as name:
+ learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
+ dtype = learning_rate.dtype
+ decay_steps = math_ops.cast(decay_steps, dtype)
+ initial_variance = math_ops.cast(initial_variance, dtype)
+ variance_decay = math_ops.cast(variance_decay, dtype)
+ num_periods = math_ops.cast(num_periods, dtype)
+ alpha = math_ops.cast(alpha, dtype)
+ beta = math_ops.cast(beta, dtype)
+
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
+ linear_decayed = (decay_steps - global_step_recomp) / decay_steps
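+ # Sample 0-centered Gaussian noise whose variance decays with the step,
+ # matching initial_variance / (1 + global_step) ** variance_decay above.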
+ variance = initial_variance / (
+ math_ops.pow(1.0 + global_step_recomp, variance_decay))
+ std = math_ops.sqrt(variance)
+ noisy_linear_decayed = (
+ linear_decayed + random_ops.random_normal(
+ linear_decayed.shape, stddev=std))
+
+ completed_fraction = global_step_recomp / decay_steps
+ fraction = 2.0 * num_periods * completed_fraction
+ cosine_decayed = 0.5 * (
+ 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
+ noisy_linear_cosine_decayed = (
+ (alpha + noisy_linear_decayed) * cosine_decayed + beta)
+
+ return math_ops.multiply(
+ learning_rate, noisy_linear_cosine_decayed, name=name)
+
+ return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
+ initial_variance, variance_decay, num_periods, alpha,
+ beta, name)
diff --git a/tensorflow/python/training/learning_rate_decay_v2_test.py b/tensorflow/python/training/learning_rate_decay_v2_test.py
new file mode 100644
index 0000000000..0f2d60dafc
--- /dev/null
+++ b/tensorflow/python/training/learning_rate_decay_v2_test.py
@@ -0,0 +1,497 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Functional test for learning rate decay."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import test_util
+# Import resource_variable_ops for the variables-to-tensor implicit conversion.
+from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.training import learning_rate_decay_v2
+
+
+class LRDecayTestV2(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testContinuous(self):
+ self.evaluate(variables.global_variables_initializer())
+ step = 5
+ decayed_lr = learning_rate_decay_v2.exponential_decay(0.05, step, 10, 0.96)
+ expected = .05 * 0.96**(5.0 / 10.0)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testStaircase(self):
+ if context.executing_eagerly():
+ step = resource_variable_ops.ResourceVariable(0)
+ self.evaluate(variables.global_variables_initializer())
+ decayed_lr = learning_rate_decay_v2.exponential_decay(
+ .1, step, 3, 0.96, staircase=True)
+
+ # No change to learning rate due to staircase
+ expected = .1
+ self.evaluate(step.assign(1))
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ expected = .1
+ self.evaluate(step.assign(2))
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ # Decayed learning rate
+ expected = .1 * 0.96 ** (100 // 3)
+ self.evaluate(step.assign(100))
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ def testVariables(self):
+ with self.test_session():
+ step = variables.Variable(1)
+ assign_1 = step.assign(1)
+ assign_2 = step.assign(2)
+ assign_100 = step.assign(100)
+ decayed_lr = learning_rate_decay_v2.exponential_decay(.1, step, 3, 0.96,
+ staircase=True)
+ variables.global_variables_initializer().run()
+ # No change to learning rate
+ assign_1.op.run()
+ self.assertAllClose(decayed_lr().eval(), .1, 1e-6)
+ assign_2.op.run()
+ self.assertAllClose(decayed_lr().eval(), .1, 1e-6)
+ # Decayed learning rate
+ assign_100.op.run()
+ expected = .1 * 0.96 ** (100 // 3)
+ self.assertAllClose(decayed_lr().eval(), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testPiecewiseConstant(self):
+ x = resource_variable_ops.ResourceVariable(-999)
+ decayed_lr = learning_rate_decay_v2.piecewise_constant(
+ x, [100, 110, 120], [1.0, 0.1, 0.01, 0.001])
+
+ self.evaluate(variables.global_variables_initializer())
+
+ self.assertAllClose(self.evaluate(decayed_lr()), 1.0, 1e-6)
+ self.evaluate(x.assign(100))
+ self.assertAllClose(self.evaluate(decayed_lr()), 1.0, 1e-6)
+ self.evaluate(x.assign(105))
+ self.assertAllClose(self.evaluate(decayed_lr()), 0.1, 1e-6)
+ self.evaluate(x.assign(110))
+ self.assertAllClose(self.evaluate(decayed_lr()), 0.1, 1e-6)
+ self.evaluate(x.assign(120))
+ self.assertAllClose(self.evaluate(decayed_lr()), 0.01, 1e-6)
+ self.evaluate(x.assign(999))
+ self.assertAllClose(self.evaluate(decayed_lr()), 0.001, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testPiecewiseConstantEdgeCases(self):
+ x_int = resource_variable_ops.ResourceVariable(
+ 0, dtype=variables.dtypes.int32)
+ boundaries, values = [-1.0, 1.0], [1, 2, 3]
+ with self.assertRaises(ValueError):
+ decayed_lr = learning_rate_decay_v2.piecewise_constant(
+ x_int, boundaries, values)
+ decayed_lr()
+
+ x = resource_variable_ops.ResourceVariable(0.0)
+ boundaries, values = [-1.0, 1.0], [1.0, 2, 3]
+ with self.assertRaises(ValueError):
+ decayed_lr = learning_rate_decay_v2.piecewise_constant(
+ x, boundaries, values)
+ decayed_lr()
+
+ # Test that ref types are valid.
+ if not context.executing_eagerly():
+ x = variables.Variable(0.0)
+ x_ref = x.op.outputs[0] # float32_ref tensor should be accepted
+ boundaries, values = [1.0, 2.0], [1, 2, 3]
+ learning_rate_decay_v2.piecewise_constant(x_ref, boundaries, values)
+
+ # Test casting boundaries from int32 to int64.
+ x_int64 = resource_variable_ops.ResourceVariable(
+ 0, dtype=variables.dtypes.int64)
+ boundaries, values = [1, 2, 3], [0.4, 0.5, 0.6, 0.7]
+ decayed_lr = learning_rate_decay_v2.piecewise_constant(
+ x_int64, boundaries, values)
+
+ self.evaluate(variables.global_variables_initializer())
+ self.assertAllClose(self.evaluate(decayed_lr()), 0.4, 1e-6)
+ self.evaluate(x_int64.assign(1))
+ self.assertAllClose(self.evaluate(decayed_lr()), 0.4, 1e-6)
+ self.evaluate(x_int64.assign(2))
+ self.assertAllClose(self.evaluate(decayed_lr()), 0.5, 1e-6)
+ self.evaluate(x_int64.assign(3))
+ self.assertAllClose(self.evaluate(decayed_lr()), 0.6, 1e-6)
+ self.evaluate(x_int64.assign(4))
+ self.assertAllClose(self.evaluate(decayed_lr()), 0.7, 1e-6)
+
+
+class LinearDecayTestV2(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testHalfWay(self):
+ step = 5
+ lr = 0.05
+ end_lr = 0.0
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(lr, step, 10, end_lr)
+ expected = lr * 0.5
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testEnd(self):
+ step = 10
+ lr = 0.05
+ end_lr = 0.001
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(lr, step, 10, end_lr)
+ expected = end_lr
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testHalfWayWithEnd(self):
+ step = 5
+ lr = 0.05
+ end_lr = 0.001
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(lr, step, 10, end_lr)
+ expected = (lr + end_lr) * 0.5
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testBeyondEnd(self):
+ step = 15
+ lr = 0.05
+ end_lr = 0.001
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(lr, step, 10, end_lr)
+ expected = end_lr
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testBeyondEndWithCycle(self):
+ step = 15
+ lr = 0.05
+ end_lr = 0.001
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(
+ lr, step, 10, end_lr, cycle=True)
+ expected = (lr - end_lr) * 0.25 + end_lr
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+
+class SqrtDecayTestV2(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testHalfWay(self):
+ step = 5
+ lr = 0.05
+ end_lr = 0.0
+ power = 0.5
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(
+ lr, step, 10, end_lr, power=power)
+ expected = lr * 0.5**power
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testEnd(self):
+ step = 10
+ lr = 0.05
+ end_lr = 0.001
+ power = 0.5
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(
+ lr, step, 10, end_lr, power=power)
+ expected = end_lr
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testHalfWayWithEnd(self):
+ step = 5
+ lr = 0.05
+ end_lr = 0.001
+ power = 0.5
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(
+ lr, step, 10, end_lr, power=power)
+ expected = (lr - end_lr) * 0.5**power + end_lr
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testBeyondEnd(self):
+ step = 15
+ lr = 0.05
+ end_lr = 0.001
+ power = 0.5
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(
+ lr, step, 10, end_lr, power=power)
+ expected = end_lr
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testBeyondEndWithCycle(self):
+ step = 15
+ lr = 0.05
+ end_lr = 0.001
+ power = 0.5
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(
+ lr, step, 10, end_lr, power=power, cycle=True)
+ expected = (lr - end_lr) * 0.25**power + end_lr
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+
+class PolynomialDecayTestV2(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testBeginWithCycle(self):
+ lr = 0.001
+ decay_steps = 10
+ step = 0
+ decayed_lr = learning_rate_decay_v2.polynomial_decay(
+ lr, step, decay_steps, cycle=True)
+ expected = lr
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+
+class ExponentialDecayTestV2(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDecay(self):
+ initial_lr = 0.1
+ k = 10
+ decay_rate = 0.96
+ step = resource_variable_ops.ResourceVariable(0)
+ decayed_lr = learning_rate_decay_v2.natural_exp_decay(initial_lr, step, k,
+ decay_rate)
+
+ self.evaluate(variables.global_variables_initializer())
+ for i in range(k + 1):
+ expected = initial_lr * math.exp(-i / k * decay_rate)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+ self.evaluate(step.assign_add(1))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testStaircase(self):
+ initial_lr = 0.1
+ k = 10
+ decay_rate = 0.96
+ step = resource_variable_ops.ResourceVariable(0)
+ decayed_lr = learning_rate_decay_v2.natural_exp_decay(
+ initial_lr, step, k, decay_rate, staircase=True)
+
+ self.evaluate(variables.global_variables_initializer())
+ for i in range(k + 1):
+ expected = initial_lr * math.exp(-decay_rate * (i // k))
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+ self.evaluate(step.assign_add(1))
+
+
+class InverseDecayTestV2(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDecay(self):
+ initial_lr = 0.1
+ k = 10
+ decay_rate = 0.96
+ step = resource_variable_ops.ResourceVariable(0)
+ decayed_lr = learning_rate_decay_v2.inverse_time_decay(initial_lr, step, k,
+ decay_rate)
+
+ self.evaluate(variables.global_variables_initializer())
+ for i in range(k + 1):
+ expected = initial_lr / (1 + i / k * decay_rate)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+ self.evaluate(step.assign_add(1))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testStaircase(self):
+ initial_lr = 0.1
+ k = 10
+ decay_rate = 0.96
+ step = resource_variable_ops.ResourceVariable(0)
+ decayed_lr = learning_rate_decay_v2.inverse_time_decay(
+ initial_lr, step, k, decay_rate, staircase=True)
+
+ self.evaluate(variables.global_variables_initializer())
+ for i in range(k + 1):
+ expected = initial_lr / (1 + decay_rate * (i // k))
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+ self.evaluate(step.assign_add(1))
+
+
+class CosineDecayTestV2(test_util.TensorFlowTestCase):
+
+ def np_cosine_decay(self, step, decay_steps, alpha=0.0):
+ step = min(step, decay_steps)
+ completed_fraction = step / decay_steps
+ decay = 0.5 * (1.0 + math.cos(math.pi * completed_fraction))
+ return (1.0 - alpha) * decay + alpha
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDecay(self):
+ num_training_steps = 1000
+ initial_lr = 1.0
+ for step in range(0, 1500, 250):
+ decayed_lr = learning_rate_decay_v2.cosine_decay(initial_lr, step,
+ num_training_steps)
+ expected = self.np_cosine_decay(step, num_training_steps)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testAlpha(self):
+ num_training_steps = 1000
+ initial_lr = 1.0
+ alpha = 0.1
+ for step in range(0, 1500, 250):
+ decayed_lr = learning_rate_decay_v2.cosine_decay(initial_lr, step,
+ num_training_steps,
+ alpha)
+ expected = self.np_cosine_decay(step, num_training_steps, alpha)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+
+class CosineDecayRestartsTestV2(test_util.TensorFlowTestCase):
+
+ def np_cosine_decay_restarts(self, step, decay_steps, t_mul=2.0, m_mul=1.0,
+ alpha=0.0):
+ fac = 1.0
+ while step >= decay_steps:
+ step -= decay_steps
+ decay_steps *= t_mul
+ fac *= m_mul
+
+ completed_fraction = step / decay_steps
+ decay = fac * 0.5 * (1.0 + math.cos(math.pi * completed_fraction))
+ return (1.0 - alpha) * decay + alpha
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDecay(self):
+ num_training_steps = 1000
+ initial_lr = 1.0
+ for step in range(0, 1500, 250):
+ decayed_lr = learning_rate_decay_v2.cosine_decay_restarts(
+ initial_lr, step, num_training_steps)
+ expected = self.np_cosine_decay_restarts(step, num_training_steps)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testAlpha(self):
+ num_training_steps = 1000
+ initial_lr = 1.0
+ alpha = 0.1
+ for step in range(0, 1500, 250):
+ decayed_lr = learning_rate_decay_v2.cosine_decay_restarts(
+ initial_lr, step, num_training_steps, alpha=alpha)
+ expected = self.np_cosine_decay_restarts(
+ step, num_training_steps, alpha=alpha)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testMMul(self):
+ num_training_steps = 1000
+ initial_lr = 1.0
+ m_mul = 0.9
+ for step in range(0, 1500, 250):
+ decayed_lr = learning_rate_decay_v2.cosine_decay_restarts(
+ initial_lr, step, num_training_steps, m_mul=m_mul)
+ expected = self.np_cosine_decay_restarts(
+ step, num_training_steps, m_mul=m_mul)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testTMul(self):
+ num_training_steps = 1000
+ initial_lr = 1.0
+ t_mul = 1.0
+ for step in range(0, 1500, 250):
+ decayed_lr = learning_rate_decay_v2.cosine_decay_restarts(
+ initial_lr, step, num_training_steps, t_mul=t_mul)
+ expected = self.np_cosine_decay_restarts(
+ step, num_training_steps, t_mul=t_mul)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+
+class LinearCosineDecayTestV2(test_util.TensorFlowTestCase):
+
+ def np_linear_cosine_decay(self,
+ step,
+ decay_steps,
+ alpha=0.0,
+ beta=0.001,
+ num_periods=0.5):
+ step = min(step, decay_steps)
+ linear_decayed = float(decay_steps - step) / decay_steps
+ fraction = 2.0 * num_periods * step / float(decay_steps)
+ cosine_decayed = 0.5 * (1.0 + math.cos(math.pi * fraction))
+ return (alpha + linear_decayed) * cosine_decayed + beta
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDefaultDecay(self):
+ num_training_steps = 1000
+ initial_lr = 1.0
+ for step in range(0, 1500, 250):
+ decayed_lr = learning_rate_decay_v2.linear_cosine_decay(
+ initial_lr, step, num_training_steps)
+ expected = self.np_linear_cosine_decay(step, num_training_steps)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNonDefaultDecay(self):
+ num_training_steps = 1000
+ initial_lr = 1.0
+ for step in range(0, 1500, 250):
+ decayed_lr = learning_rate_decay_v2.linear_cosine_decay(
+ initial_lr,
+ step,
+ num_training_steps,
+ alpha=0.1,
+ beta=1e-4,
+ num_periods=5)
+ expected = self.np_linear_cosine_decay(
+ step, num_training_steps, alpha=0.1, beta=1e-4, num_periods=5)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
+
+
+class NoisyLinearCosineDecayTestV2(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDefaultNoisyLinearCosine(self):
+ num_training_steps = 1000
+ initial_lr = 1.0
+ for step in range(0, 1500, 250):
+ # No numerical check because of noise
+ decayed_lr = learning_rate_decay_v2.noisy_linear_cosine_decay(
+ initial_lr, step, num_training_steps)
+ # Cannot be deterministically tested
+ self.evaluate(decayed_lr())
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNonDefaultNoisyLinearCosine(self):
+ num_training_steps = 1000
+ initial_lr = 1.0
+ for step in range(0, 1500, 250):
+ # No numerical check because of noise
+ decayed_lr = learning_rate_decay_v2.noisy_linear_cosine_decay(
+ initial_lr,
+ step,
+ num_training_steps,
+ initial_variance=0.5,
+ variance_decay=0.1,
+ alpha=0.1,
+ beta=1e-4,
+ num_periods=5)
+ # Cannot be deterministically tested
+ self.evaluate(decayed_lr())
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py
index f7e78071d8..8a21c39d32 100644
--- a/tensorflow/python/training/momentum_test.py
+++ b/tensorflow/python/training/momentum_test.py
@@ -123,7 +123,7 @@ class MomentumOptimizerTest(test.TestCase):
]), self.evaluate(var1))
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
self.doTestBasic(use_resource=False)
@test_util.run_in_graph_and_eager_modes(reset_test=True)
@@ -162,7 +162,7 @@ class MomentumOptimizerTest(test.TestCase):
def testNesterovMomentum(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -188,7 +188,7 @@ class MomentumOptimizerTest(test.TestCase):
def testSparseNesterovMomentum(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
@@ -282,7 +282,7 @@ class MomentumOptimizerTest(test.TestCase):
def testTensorLearningRateAndMomentum(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -435,7 +435,7 @@ class MomentumOptimizerTest(test.TestCase):
return db_grad, db_out
def testLikeDistBeliefMom01(self):
- with self.test_session():
+ with self.cached_session():
db_grad, db_out = self._dbParamsMom01()
num_samples = len(db_grad)
var0 = variables.Variable([0.0] * num_samples)
@@ -449,7 +449,7 @@ class MomentumOptimizerTest(test.TestCase):
def testSparse(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable(array_ops.zeros([4, 2], dtype=dtype))
var1 = variables.Variable(constant_op.constant(1.0, dtype, [4, 2]))
grads0 = ops.IndexedSlices(
@@ -518,7 +518,7 @@ class MomentumOptimizerTest(test.TestCase):
def testSharing(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index 7b06bffa4b..0e0125a956 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -25,6 +25,7 @@ import sys
import six
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -284,6 +285,63 @@ class Scaffold(object):
resources.initialize_resources(resources.local_resources()))
+def _create_monitored_session_with_worker_context(worker_context, # pylint: disable=missing-docstring
+ scaffold,
+ checkpoint_dir=None,
+ hooks=None,
+ chief_only_hooks=None,
+ save_checkpoint_secs=None,
+ save_summaries_steps=None,
+ save_summaries_secs=None,
+ config=None,
+ stop_grace_period_secs=120,
+ log_step_count_steps=100,
+ max_wait_secs=7200,
+ save_checkpoint_steps=None,
+ summary_dir=None):
+ all_hooks = []
+ if hooks:
+ all_hooks.extend(hooks)
+ if chief_only_hooks and worker_context.is_chief:
+ all_hooks.extend(chief_only_hooks)
+
+ summary_dir = summary_dir or checkpoint_dir
+ if summary_dir and worker_context.should_save_summary:
+ if log_step_count_steps and log_step_count_steps > 0:
+ all_hooks.append(
+ basic_session_run_hooks.StepCounterHook(
+ output_dir=summary_dir, every_n_steps=log_step_count_steps))
+
+ if (save_summaries_steps and save_summaries_steps > 0) or (
+ save_summaries_secs and save_summaries_secs > 0):
+ all_hooks.append(
+ basic_session_run_hooks.SummarySaverHook(
+ scaffold=scaffold,
+ save_steps=save_summaries_steps,
+ save_secs=save_summaries_secs,
+ output_dir=summary_dir))
+
+ if checkpoint_dir and worker_context.should_checkpoint:
+ if (save_checkpoint_secs and save_checkpoint_secs > 0) or (
+ save_checkpoint_steps and save_checkpoint_steps > 0):
+ all_hooks.append(
+ basic_session_run_hooks.CheckpointSaverHook(
+ checkpoint_dir,
+ save_steps=save_checkpoint_steps,
+ save_secs=save_checkpoint_secs,
+ scaffold=scaffold))
+
+ session_creator = worker_context.session_creator(
+ scaffold,
+ config=config,
+ checkpoint_dir=checkpoint_dir,
+ max_wait_secs=max_wait_secs)
+ return MonitoredSession(
+ session_creator=session_creator,
+ hooks=all_hooks,
+ stop_grace_period_secs=stop_grace_period_secs)
+
+
@tf_export('train.MonitoredTrainingSession')
def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
is_chief=True,
@@ -373,14 +431,35 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
save_checkpoint_steps = None
scaffold = scaffold or Scaffold()
+ worker_context = distribute_coordinator_context.get_current_worker_context()
+
+ if worker_context:
+ return _create_monitored_session_with_worker_context(
+ worker_context,
+ scaffold,
+ checkpoint_dir=checkpoint_dir,
+ hooks=hooks,
+ chief_only_hooks=chief_only_hooks,
+ save_checkpoint_secs=save_checkpoint_secs,
+ save_summaries_steps=save_summaries_steps,
+ save_summaries_secs=save_summaries_secs,
+ config=config,
+ stop_grace_period_secs=stop_grace_period_secs,
+ log_step_count_steps=log_step_count_steps,
+ max_wait_secs=max_wait_secs,
+ save_checkpoint_steps=save_checkpoint_steps,
+ summary_dir=summary_dir)
+
if not is_chief:
session_creator = WorkerSessionCreator(
scaffold=scaffold,
master=master,
config=config,
max_wait_secs=max_wait_secs)
- return MonitoredSession(session_creator=session_creator, hooks=hooks or [],
- stop_grace_period_secs=stop_grace_period_secs)
+ return MonitoredSession(
+ session_creator=session_creator,
+ hooks=hooks or [],
+ stop_grace_period_secs=stop_grace_period_secs)
all_hooks = []
if chief_only_hooks:
@@ -400,25 +479,29 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
if (save_summaries_steps and save_summaries_steps > 0) or (
save_summaries_secs and save_summaries_secs > 0):
- all_hooks.append(basic_session_run_hooks.SummarySaverHook(
- scaffold=scaffold,
- save_steps=save_summaries_steps,
- save_secs=save_summaries_secs,
- output_dir=summary_dir))
+ all_hooks.append(
+ basic_session_run_hooks.SummarySaverHook(
+ scaffold=scaffold,
+ save_steps=save_summaries_steps,
+ save_secs=save_summaries_secs,
+ output_dir=summary_dir))
if checkpoint_dir:
if (save_checkpoint_secs and save_checkpoint_secs > 0) or (
save_checkpoint_steps and save_checkpoint_steps > 0):
- all_hooks.append(basic_session_run_hooks.CheckpointSaverHook(
- checkpoint_dir,
- save_steps=save_checkpoint_steps,
- save_secs=save_checkpoint_secs,
- scaffold=scaffold))
+ all_hooks.append(
+ basic_session_run_hooks.CheckpointSaverHook(
+ checkpoint_dir,
+ save_steps=save_checkpoint_steps,
+ save_secs=save_checkpoint_secs,
+ scaffold=scaffold))
if hooks:
all_hooks.extend(hooks)
- return MonitoredSession(session_creator=session_creator, hooks=all_hooks,
- stop_grace_period_secs=stop_grace_period_secs)
+ return MonitoredSession(
+ session_creator=session_creator,
+ hooks=all_hooks,
+ stop_grace_period_secs=stop_grace_period_secs)
@tf_export('train.SessionCreator')
@@ -546,6 +629,11 @@ class _MonitoredSession(object):
self._hooks = hooks or []
for h in self._hooks:
h.begin()
+
+ worker_context = distribute_coordinator_context.get_current_worker_context()
+ if not session_creator and worker_context:
+ session_creator = worker_context.session_creator()
+
# Create the session.
self._coordinated_creator = self._CoordinatedSessionCreator(
session_creator=session_creator or ChiefSessionCreator(),
@@ -712,7 +800,8 @@ class _MonitoredSession(object):
self.tf_sess = self._session_creator.create_session()
# We don't want coordinator to suppress any exception.
self.coord = coordinator.Coordinator(clean_stop_exception_types=[])
- queue_runner.start_queue_runners(sess=self.tf_sess, coord=self.coord)
+ if ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
+ queue_runner.start_queue_runners(sess=self.tf_sess, coord=self.coord)
# Inform the hooks that a new session has been created.
for hook in self._hooks:
hook.after_create_session(self.tf_sess, self.coord)
@@ -1275,3 +1364,6 @@ class _HookedSession(_WrappedSession):
options.debug_options.debug_tensor_watch_opts.extend(
incoming_options.debug_options.debug_tensor_watch_opts)
+ options.debug_options.reset_disk_byte_usage = (
+ options.debug_options.reset_disk_byte_usage or
+ incoming_options.debug_options.reset_disk_byte_usage)
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index 92533ca4f3..2d7799d66a 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -32,6 +32,7 @@ from tensorflow.contrib.testing.python.framework import util_test
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import debug_pb2
from tensorflow.python.client import session as session_lib
+from tensorflow.python.distribute import distribute_coordinator
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
@@ -79,7 +80,7 @@ class ScaffoldTest(test.TestCase):
self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor))
self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation))
self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertItemsEqual([b'my_var', b'my_local_var'],
sess.run(scaffold.ready_op))
self.assertItemsEqual([b'my_var'],
@@ -381,6 +382,119 @@ class MonitoredTrainingSessionTest(test.TestCase):
self.assertEqual(0, session.run(gstep))
+class MockStrategy(object):
+
+ def __init__(self,
+ between_graph=False,
+ should_init=True,
+ should_checkpoint=None,
+ should_save_summary=None):
+ self._between_graph = between_graph
+ self._should_init = should_init
+ self._should_checkpoint = should_checkpoint
+ self._should_save_summary = should_save_summary
+
+ @property
+ def between_graph(self):
+ return self._between_graph
+
+ @property
+ def should_init(self):
+ return self._should_init
+
+ @property
+ def should_checkpoint(self):
+ return self._should_checkpoint
+
+ @property
+ def should_save_summary(self):
+ return self._should_save_summary
+
+
+class MonitoredTrainingSessionWithDistributeCoordinatorTest(test.TestCase):
+ """Test distribute coordinator controls summary saving and checkpointing."""
+
+ def test_summary_hook_enabled(self):
+ context = distribute_coordinator._WorkerContext(
+ MockStrategy(should_save_summary=True), None, None, None)
+
+ logdir = _test_dir(self.get_temp_dir(), 'test_summaries_enabled')
+ with ops.Graph().as_default():
+ gstep = variables_lib.get_or_create_global_step()
+ new_gstep = state_ops.assign_add(gstep, 1)
+ summary.scalar('my_summary_tag', new_gstep * 2)
+ with context, monitored_session.MonitoredTrainingSession(
+ checkpoint_dir=logdir,
+ save_summaries_steps=100,
+ log_step_count_steps=10) as session:
+ for _ in range(101):
+ session.run(new_gstep)
+
+ summaries = util_test.latest_summaries(logdir)
+ tags = [s.summary.value[0].tag for s in summaries]
+ self.assertIn('my_summary_tag', tags)
+ self.assertIn('global_step/sec', tags)
+
+ def test_summary_hook_disabled(self):
+ context = distribute_coordinator._WorkerContext(
+ MockStrategy(should_save_summary=False), None, None, None)
+
+ logdir = _test_dir(self.get_temp_dir(), 'test_summaries_disabled')
+ with ops.Graph().as_default():
+ gstep = variables_lib.get_or_create_global_step()
+ new_gstep = state_ops.assign_add(gstep, 1)
+ summary.scalar('my_summary_tag', new_gstep * 2)
+ with context, monitored_session.MonitoredTrainingSession(
+ checkpoint_dir=logdir,
+ save_summaries_steps=100,
+ log_step_count_steps=10) as session:
+ for _ in range(101):
+ session.run(new_gstep)
+
+ # No summary is saved.
+ summaries = util_test.latest_summaries(logdir)
+ self.assertEqual(len(summaries), 0)
+
+ def test_checkpoint_hook_enabled(self):
+ context = distribute_coordinator._WorkerContext(
+ MockStrategy(should_checkpoint=True), None, None, None)
+
+ logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_enabled')
+ with ops.Graph().as_default():
+ gstep = variables_lib.get_or_create_global_step()
+ new_gstep = state_ops.assign_add(gstep, 1)
+ with context, monitored_session.MonitoredTrainingSession(
+ checkpoint_dir=logdir,
+ save_checkpoint_steps=100,
+ log_step_count_steps=10) as session:
+ for _ in range(100):
+ session.run(new_gstep)
+
+ # A restart will find the checkpoint and recover automatically.
+ with monitored_session.MonitoredTrainingSession(
+ is_chief=True, checkpoint_dir=logdir) as session:
+ self.assertEqual(100, session.run(gstep))
+
+ def test_checkpoint_hook_disabled(self):
+ context = distribute_coordinator._WorkerContext(
+ MockStrategy(should_checkpoint=False), None, None, None)
+
+ logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_disabled')
+ with ops.Graph().as_default():
+ gstep = variables_lib.get_or_create_global_step()
+ new_gstep = state_ops.assign_add(gstep, 1)
+ with context, monitored_session.MonitoredTrainingSession(
+ checkpoint_dir=logdir,
+ save_checkpoint_steps=100,
+ log_step_count_steps=10) as session:
+ for _ in range(100):
+ session.run(new_gstep)
+
+ # No checkpoint is saved.
+ checkpoint = checkpoint_management.latest_checkpoint(logdir)
+ self.assertIsNone(checkpoint)
+
+
class StopAtNSession(monitored_session._WrappedSession):
"""A wrapped session that stops at the N-th call to _check_stop."""
@@ -399,21 +513,21 @@ class WrappedSessionTest(test.TestCase):
"""_WrappedSession tests."""
def test_properties(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
constant_op.constant(0.0)
wrapped_sess = monitored_session._WrappedSession(sess)
self.assertEquals(sess.graph, wrapped_sess.graph)
self.assertEquals(sess.sess_str, wrapped_sess.sess_str)
def test_should_stop_on_close(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
wrapped_sess = monitored_session._WrappedSession(sess)
self.assertFalse(wrapped_sess.should_stop())
wrapped_sess.close()
self.assertTrue(wrapped_sess.should_stop())
def test_should_stop_uses_check_stop(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
wrapped_sess = StopAtNSession(sess, 3)
self.assertFalse(wrapped_sess.should_stop())
self.assertFalse(wrapped_sess.should_stop())
@@ -421,7 +535,7 @@ class WrappedSessionTest(test.TestCase):
self.assertTrue(wrapped_sess.should_stop())
def test_should_stop_delegates_to_wrapped_session(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
wrapped_sess0 = StopAtNSession(sess, 4)
wrapped_sess1 = monitored_session._WrappedSession(wrapped_sess0)
self.assertFalse(wrapped_sess1.should_stop())
@@ -431,7 +545,7 @@ class WrappedSessionTest(test.TestCase):
self.assertTrue(wrapped_sess1.should_stop())
def test_close_twice(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
wrapped_sess = monitored_session._WrappedSession(sess)
wrapped_sess.close()
self.assertTrue(wrapped_sess.should_stop())
@@ -439,7 +553,7 @@ class WrappedSessionTest(test.TestCase):
self.assertTrue(wrapped_sess.should_stop())
def test_run(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c = constant_op.constant(0)
v = array_ops.identity(c)
self.assertEqual(42, sess.run(v, feed_dict={c: 42}))
@@ -456,7 +570,7 @@ class CoordinatedSessionTest(test.TestCase):
"""_CoordinatedSession tests."""
def test_properties(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
constant_op.constant(0.0)
coord = coordinator.Coordinator()
coord_sess = monitored_session._CoordinatedSession(sess, coord)
@@ -464,7 +578,7 @@ class CoordinatedSessionTest(test.TestCase):
self.assertEquals(sess.sess_str, coord_sess.sess_str)
def test_run(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c = constant_op.constant(0)
v = array_ops.identity(c)
coord = coordinator.Coordinator()
@@ -472,7 +586,7 @@ class CoordinatedSessionTest(test.TestCase):
self.assertEqual(42, coord_sess.run(v, feed_dict={c: 42}))
def test_should_stop_on_close(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
coord_sess = monitored_session._CoordinatedSession(sess, coord)
self.assertFalse(coord_sess.should_stop())
@@ -480,7 +594,7 @@ class CoordinatedSessionTest(test.TestCase):
self.assertTrue(coord_sess.should_stop())
def test_should_stop_on_coord_stop(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
coord_sess = monitored_session._CoordinatedSession(sess, coord)
self.assertFalse(coord_sess.should_stop())
@@ -488,7 +602,7 @@ class CoordinatedSessionTest(test.TestCase):
self.assertTrue(coord_sess.should_stop())
def test_dont_request_stop_on_exception_in_main_thread(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c = constant_op.constant(0)
v = array_ops.identity(c)
coord = coordinator.Coordinator()
@@ -502,7 +616,7 @@ class CoordinatedSessionTest(test.TestCase):
self.assertFalse(coord_sess.should_stop())
def test_stop_threads_on_close_after_exception(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c = constant_op.constant(0)
v = array_ops.identity(c)
coord = coordinator.Coordinator()
@@ -532,7 +646,7 @@ class CoordinatedSessionTest(test.TestCase):
self.assertTrue(coord_sess.should_stop())
def test_stop_threads_on_close(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = [
threading.Thread(
@@ -550,7 +664,7 @@ class CoordinatedSessionTest(test.TestCase):
def test_propagates_exception_trace(self):
assertion = control_flow_ops.Assert(False, ['This should fail.'])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator(clean_stop_exception_types=())
coord_sess = monitored_session._CoordinatedSession(sess, coord)
try:
@@ -696,7 +810,7 @@ class RecoverableSessionTest(test.TestCase):
return self._sess
def test_properties(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
constant_op.constant(0.0)
recoverable_sess = monitored_session._RecoverableSession(
self._SessionReturner(sess))
@@ -704,7 +818,7 @@ class RecoverableSessionTest(test.TestCase):
self.assertEquals(sess.sess_str, recoverable_sess.sess_str)
def test_run(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c = constant_op.constant(0)
v = array_ops.identity(c)
recoverable_sess = monitored_session._RecoverableSession(
@@ -712,7 +826,7 @@ class RecoverableSessionTest(test.TestCase):
self.assertEqual(51, recoverable_sess.run(v, feed_dict={c: 51}))
def test_recovery(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
class StackSessionCreator(object):
@@ -758,7 +872,7 @@ class RecoverableSessionTest(test.TestCase):
recoverable_sess.run(v, feed_dict={c: -12})
def test_recovery_from_coordinator_exception(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = monitored_session.MonitoredSession(
session_creator,
@@ -783,7 +897,7 @@ class RecoverableSessionTest(test.TestCase):
self.assertEqual(2, session_creator.number_of_sessions_created)
def test_recovery_from_non_preemption_in_coordinator(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
hook = StopCoordinatorWithException(
calls_before_stopping=2,
@@ -812,7 +926,7 @@ class RecoverableSessionTest(test.TestCase):
session.close()
def test_recovery_from_session_getting_stuck(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = monitored_session.MonitoredSession(
session_creator,
@@ -836,7 +950,7 @@ class RecoverableSessionTest(test.TestCase):
self.assertEqual(2, session_creator.number_of_sessions_created)
def test_step_fn_recovery_from_coordinator_exception_when_run_hooks(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = monitored_session.MonitoredSession(
session_creator,
@@ -866,7 +980,7 @@ class RecoverableSessionTest(test.TestCase):
self.assertEqual(2, session_creator.number_of_sessions_created)
def test_recovery_from_non_preemption_in_coordinator_when_run_hooks(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
hook = StopCoordinatorWithException(
calls_before_stopping=2,
@@ -900,7 +1014,7 @@ class RecoverableSessionTest(test.TestCase):
session.close()
def test_recovery_from_session_getting_stuck_when_run_hooks(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = monitored_session.MonitoredSession(
session_creator,
@@ -944,7 +1058,7 @@ class RecoverableSessionTest(test.TestCase):
return session
def test_step_fn_recovery_from_coordinator_exception_with_raw_session(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = self.create_raw_session_with_failing_coordinator(
session_creator,
@@ -976,7 +1090,7 @@ class RecoverableSessionTest(test.TestCase):
self.assertEqual(2, session_creator.number_of_sessions_created)
def test_recovery_from_non_preemption_in_coordinator_with_raw_session(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = self.create_raw_session_with_failing_coordinator(
session_creator,
@@ -1013,7 +1127,7 @@ class RecoverableSessionTest(test.TestCase):
session.close()
def test_recovery_from_session_getting_stuck_with_raw_session(self):
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
session_creator = CountingSessionCreator(test_session)
session = self.create_raw_session_with_failing_coordinator(
session_creator,
@@ -1365,8 +1479,8 @@ class MonitoredSessionTest(test.TestCase):
with monitored_session.MonitoredSession(
session_creator=monitored_session.ChiefSessionCreator(
scaffold,
- checkpoint_filename_with_path=
- checkpoint_management.latest_checkpoint(logdir))) as session:
+ checkpoint_filename_with_path=checkpoint_management.
+ latest_checkpoint(logdir))) as session:
self.assertEqual(2, session.run(gstep))
def test_retry_initialization_on_aborted_error(self):
@@ -1933,7 +2047,7 @@ class MonitoredSessionTest(test.TestCase):
return value
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
with monitored_session.MonitoredSession(
CountingSessionCreator(test_session)) as session:
session.run(variables.global_variables_initializer())
@@ -1996,7 +2110,7 @@ class MonitoredSessionTest(test.TestCase):
step_context.session.run(graph_side_effect)
return step_context.run_with_hooks(fetches=v, feed_dict={c: 1.3})
- with self.test_session() as test_session:
+ with self.cached_session() as test_session:
with monitored_session.MonitoredSession(
CountingSessionCreator(test_session),
hooks=[Hook(self)]) as session:
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index 4b91d1e963..177a7ddfa5 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -363,10 +363,12 @@ class ExponentialMovingAverage(object):
`GraphKeys.ALL_VARIABLES` collection. They will be returned by calls to
`tf.global_variables()`.
- Returns an op that updates all shadow variables as described above.
+ Returns an op that updates all shadow variables from the current value of
+ their associated variables.
- Note that `apply()` can be called multiple times with different lists of
- variables.
+ Note that `apply()` can be called multiple times. When eager execution
+ is enabled, each call to `apply()` updates the variables once, so it
+ needs to be called in a loop.
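+
+ A minimal sketch of the eager pattern (the loop, `train_step` and
+ `model_vars` are illustrative, not part of this API):
+
+ ```python
+ ema = tf.train.ExponentialMovingAverage(decay=0.99)
+ for batch in dataset:      # hypothetical training loop
+   train_step(batch)        # updates the model variables
+   ema.apply(model_vars)    # updates each shadow variable once
+ ```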
Args:
var_list: A list of Variable or Tensor objects. The variables
@@ -389,31 +391,30 @@ class ExponentialMovingAverage(object):
dtypes.float64]:
raise TypeError("The variables must be half, float, or double: %s" %
var.name)
- if var in self._averages:
- raise ValueError("Moving average already computed for: %s" % var.name)
- # For variables: to lower communication bandwidth across devices we keep
- # the moving averages on the same device as the variables. For other
- # tensors, we rely on the existing device allocation mechanism.
- with ops.init_scope():
- if isinstance(var, variables.Variable):
- avg = slot_creator.create_slot(var,
- var.initialized_value(),
- self.name,
- colocate_with_primary=True)
- # NOTE(mrry): We only add `tf.Variable` objects to the
- # `MOVING_AVERAGE_VARIABLES` collection.
- ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
- else:
- avg = slot_creator.create_zeros_slot(
- var,
- self.name,
- colocate_with_primary=(var.op.type in ["Variable",
- "VariableV2",
- "VarHandleOp"]))
- if self._zero_debias:
- zero_debias_true.add(avg)
- self._averages[var] = avg
+ if var not in self._averages:
+ # For variables: to lower communication bandwidth across devices we keep
+ # the moving averages on the same device as the variables. For other
+ # tensors, we rely on the existing device allocation mechanism.
+ with ops.init_scope():
+ if isinstance(var, variables.Variable):
+ avg = slot_creator.create_slot(var,
+ var.initialized_value(),
+ self.name,
+ colocate_with_primary=True)
+ # NOTE(mrry): We only add `tf.Variable` objects to the
+ # `MOVING_AVERAGE_VARIABLES` collection.
+ ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
+ else:
+ avg = slot_creator.create_zeros_slot(
+ var,
+ self.name,
+ colocate_with_primary=(var.op.type in ["Variable",
+ "VariableV2",
+ "VarHandleOp"]))
+ if self._zero_debias:
+ zero_debias_true.add(avg)
+ self._averages[var] = avg
with ops.name_scope(self.name) as scope:
decay = ops.convert_to_tensor(self._decay, name="decay")
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
index 3e85e6bfa7..93991d0e14 100644
--- a/tensorflow/python/training/moving_averages_test.py
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -18,9 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import variable_scope
@@ -33,7 +35,7 @@ from tensorflow.python.training import saver as saver_lib
class MovingAveragesTest(test.TestCase):
def testAssignMovingAverageWithoutZeroDebias(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable([10.0, 11.0])
val = constant_op.constant([1.0, 2.0], dtypes.float32)
decay = 0.25
@@ -47,7 +49,7 @@ class MovingAveragesTest(test.TestCase):
var.eval())
def testAssignMovingAverage(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable([0.0, 0.0])
val = constant_op.constant([1.0, 2.0], dtypes.float32)
decay = 0.25
@@ -84,7 +86,7 @@ class MovingAveragesTest(test.TestCase):
moving_averages.assign_moving_average(var, 0.0, 0.99)
def testWeightedMovingAverage(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
decay = 0.5
weight = array_ops.placeholder(dtypes.float32, [])
val = array_ops.placeholder(dtypes.float32, [])
@@ -185,53 +187,53 @@ class ExponentialMovingAverageTest(test.TestCase):
self.assertAllClose(expected, avg2.eval())
def testAverageVariablesNoNumUpdates_Scalar(self):
- with self.test_session():
+ with self.cached_session():
ema = moving_averages.ExponentialMovingAverage(0.25)
self._CheckDecay(ema, actual_decay=0.25, dim=1)
def testAverageVariablesNoNumUpdates_Scalar_Debias(self):
- with self.test_session():
+ with self.cached_session():
ema = moving_averages.ExponentialMovingAverage(0.25, zero_debias=True)
self._CheckDecay(ema, actual_decay=0.25, dim=1)
def testAverageVariablesNoNumUpdates_Vector(self):
- with self.test_session():
+ with self.cached_session():
ema = moving_averages.ExponentialMovingAverage(0.25)
self._CheckDecay(ema, actual_decay=0.25, dim=5)
def testAverageVariablesNoNumUpdates_Vector_Debias(self):
- with self.test_session():
+ with self.cached_session():
ema = moving_averages.ExponentialMovingAverage(0.25, zero_debias=True)
self._CheckDecay(ema, actual_decay=0.25, dim=5)
def testAverageVariablesNumUpdates_Scalar(self):
- with self.test_session():
+ with self.cached_session():
# With num_updates 1, the decay applied is 0.1818
ema = moving_averages.ExponentialMovingAverage(0.25, num_updates=1)
self._CheckDecay(ema, actual_decay=0.181818, dim=1)
def testAverageVariablesNumUpdates_Scalar_Debias(self):
- with self.test_session():
+ with self.cached_session():
# With num_updates 1, the decay applied is 0.1818
ema = moving_averages.ExponentialMovingAverage(
0.25, num_updates=1, zero_debias=True)
self._CheckDecay(ema, actual_decay=0.181818, dim=1)
def testAverageVariablesNumUpdates_Vector(self):
- with self.test_session():
+ with self.cached_session():
# With num_updates 1, the decay applied is 0.1818
ema = moving_averages.ExponentialMovingAverage(0.25, num_updates=1)
self._CheckDecay(ema, actual_decay=0.181818, dim=5)
def testAverageVariablesNumUpdates_Vector_Debias(self):
- with self.test_session():
+ with self.cached_session():
# With num_updates 1, the decay applied is 0.1818
ema = moving_averages.ExponentialMovingAverage(
0.25, num_updates=1, zero_debias=True)
self._CheckDecay(ema, actual_decay=0.181818, dim=5)
def testAverageVariablesWithControlDeps(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v0 = variables.Variable(0, name="v0")
add_to_v0 = v0.assign_add(1)
v1 = variables.Variable([10.0], name="v1")
@@ -254,8 +256,27 @@ class ExponentialMovingAverageTest(test.TestCase):
self.assertEqual(1, sess.run(v0))
self.assertEqual([17.5], sess.run(v1_avg))
+ @test_util.run_in_graph_and_eager_modes
+ def testBasicEager(self):
+ v0 = variables.Variable(1.0)
+ v1 = variables.Variable(2.0)
+
+ ema = moving_averages.ExponentialMovingAverage(0.25)
+ op = ema.apply([v0, v1])
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(op)
+
+ self.evaluate(v0.assign(2.0))
+ self.evaluate(v1.assign(4.0))
+
+ self.evaluate(ema.apply([v0, v1]))
+
+ self.assertAllEqual(self.evaluate(ema.average(v0)), 1.75)
+ self.assertAllEqual(self.evaluate(ema.average(v1)), 3.5)
+
def averageVariablesNamesHelper(self, zero_debias):
- with self.test_session():
+ with self.cached_session():
v0 = variables.Variable(10.0, name="v0")
v1 = variables.Variable(30.0, name="v1")
# Add a non-trainable variable.
@@ -299,7 +320,7 @@ class ExponentialMovingAverageTest(test.TestCase):
def averageVariablesNamesRespectScopeHelper(self, zero_debias):
# See discussion on #2740.
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope("scope1"):
v0 = variables.Variable(10.0, name="v0")
v1 = variables.Variable(30.0, name="v1")
@@ -346,7 +367,7 @@ class ExponentialMovingAverageTest(test.TestCase):
self.averageVariablesNamesRespectScopeHelper(zero_debias=False)
def testSubsetAverageVariablesNames(self):
- with self.test_session():
+ with self.cached_session():
v0 = variables.Variable(10.0, name="v0")
v1 = variables.Variable(30.0, name="v1")
# Add a non-trainable variable.
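
Several tests above assert an effective decay of 0.1818 rather than the requested 0.25. That follows from the cap `ExponentialMovingAverage` applies when `num_updates` is given, `min(decay, (1 + num_updates) / (10 + num_updates))`; a quick check:

    num_updates = 1
    effective_decay = min(0.25, (1.0 + num_updates) / (10.0 + num_updates))
    assert abs(effective_decay - 0.181818) < 1e-6  # i.e. 2/11
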
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 1b6bce2865..2304a461c1 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -772,16 +772,15 @@ class Optimizer(
Returns:
A list of variables.
"""
- executing_eagerly = context.executing_eagerly()
current_graph = ops.get_default_graph()
def _from_current_graph(variable):
- if executing_eagerly:
+ if variable._in_graph_mode: # pylint: disable=protected-access
+ return variable.op.graph is current_graph
+ else:
# No variable.op in eager mode. We don't expect lots of eager graphs,
# but behavior should be consistent with graph mode.
return variable._graph_key == current_graph._graph_key # pylint: disable=protected-access
- else:
- return variable.op.graph is current_graph
optimizer_variables = [v for v in self._non_slot_variables()
if _from_current_graph(v)]
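
The reordered branches above decide graph membership from how each variable was created rather than from the current execution mode, so graph-mode variables are still matched while eager execution is enabled. An illustrative sketch of that dispatch; the class and fields are stand-ins, not TensorFlow API:

    # Stand-in for a variable: graph-mode variables carry an op/graph,
    # eager variables only remember the graph key they were created under.
    class FakeVar(object):
        def __init__(self, in_graph_mode, graph, graph_key):
            self._in_graph_mode = in_graph_mode
            self.graph = graph          # stands in for variable.op.graph
            self._graph_key = graph_key

    def from_current_graph(var, current_graph, current_graph_key):
        if var._in_graph_mode:
            return var.graph is current_graph
        # No variable.op in eager mode; compare the captured graph key.
        return var._graph_key == current_graph_key

    g = object()
    assert from_current_graph(FakeVar(True, g, None), g, "key")
    assert not from_current_graph(FakeVar(False, None, "other"), g, "key")
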
diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py
index dfe9176bea..7a7d01d50e 100644
--- a/tensorflow/python/training/optimizer_test.py
+++ b/tensorflow/python/training/optimizer_test.py
@@ -64,7 +64,7 @@ class OptimizerTest(test.TestCase):
def testAggregationMethod(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
cost = 5 * var0 + 3 * var1
@@ -89,7 +89,7 @@ class OptimizerTest(test.TestCase):
def testPrecomputedGradient(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
cost = 5 * var0 + 3 * var1
@@ -231,7 +231,7 @@ class OptimizerTest(test.TestCase):
sgd_op.apply_gradients(grads_and_vars)
def testTrainOp(self):
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0])
var1 = variables.Variable([3.0, 4.0])
cost = 5 * var0 + 3 * var1
@@ -244,7 +244,7 @@ class OptimizerTest(test.TestCase):
def testConstraint(self):
constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0],
constraint=constraint_01)
var1 = variables.Variable([3.0, 4.0],
diff --git a/tensorflow/python/training/proximal_adagrad_test.py b/tensorflow/python/training/proximal_adagrad_test.py
index 430c16b351..74e06a5e2e 100644
--- a/tensorflow/python/training/proximal_adagrad_test.py
+++ b/tensorflow/python/training/proximal_adagrad_test.py
@@ -35,7 +35,7 @@ from tensorflow.python.training import proximal_adagrad
class ProximalAdagradOptimizerTest(test.TestCase):
def doTestProximalAdagradwithoutRegularization(self, use_resource=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([0.0, 0.0])
var1 = variables.Variable([0.0, 0.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -71,7 +71,7 @@ class ProximalAdagradOptimizerTest(test.TestCase):
self.doTestProximalAdagradwithoutRegularization(use_resource=True)
def testProximalAdagradwithoutRegularization2(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([1.0, 2.0])
var1 = variables.Variable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -98,7 +98,7 @@ class ProximalAdagradOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
@@ -114,7 +114,7 @@ class ProximalAdagradOptimizerTest(test.TestCase):
[[0, 1]], var0.eval(), atol=0.01)
def testProximalAdagradWithL1(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([1.0, 2.0])
var1 = variables.Variable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -140,7 +140,7 @@ class ProximalAdagradOptimizerTest(test.TestCase):
self.assertAllClose(np.array([2.959304, 1.029232]), v1_val)
def testProximalAdagradWithL1_L2(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([1.0, 2.0])
var1 = variables.Variable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -206,7 +206,7 @@ class ProximalAdagradOptimizerTest(test.TestCase):
return v0_val, v1_val
def testEquivAdagradwithoutRegularization(self):
- with self.test_session():
+ with self.cached_session():
val0, val1 = self.applyOptimizer(
proximal_adagrad.ProximalAdagradOptimizer(
3.0,
@@ -214,7 +214,7 @@ class ProximalAdagradOptimizerTest(test.TestCase):
l1_regularization_strength=0.0,
l2_regularization_strength=0.0))
- with self.test_session():
+ with self.cached_session():
val2, val3 = self.applyOptimizer(
adagrad.AdagradOptimizer(
3.0, initial_accumulator_value=0.1))
@@ -223,7 +223,7 @@ class ProximalAdagradOptimizerTest(test.TestCase):
self.assertAllClose(val1, val3)
def testEquivSparseAdagradwithoutRegularization(self):
- with self.test_session():
+ with self.cached_session():
val0, val1 = self.applyOptimizer(
proximal_adagrad.ProximalAdagradOptimizer(
3.0,
@@ -232,7 +232,7 @@ class ProximalAdagradOptimizerTest(test.TestCase):
l2_regularization_strength=0.0),
is_sparse=True)
- with self.test_session():
+ with self.cached_session():
val2, val3 = self.applyOptimizer(
adagrad.AdagradOptimizer(
3.0, initial_accumulator_value=0.1),
diff --git a/tensorflow/python/training/proximal_gradient_descent_test.py b/tensorflow/python/training/proximal_gradient_descent_test.py
index 4e4812fe60..f77f68b234 100644
--- a/tensorflow/python/training/proximal_gradient_descent_test.py
+++ b/tensorflow/python/training/proximal_gradient_descent_test.py
@@ -36,7 +36,7 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
def doTestProximalGradientDescentwithoutRegularization(
self, use_resource=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if use_resource:
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0])
var1 = resource_variable_ops.ResourceVariable([0.0, 0.0])
@@ -69,7 +69,7 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
self.doTestProximalGradientDescentwithoutRegularization(use_resource=True)
def testProximalGradientDescentwithoutRegularization2(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([1.0, 2.0])
var1 = variables.Variable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -94,7 +94,7 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
@@ -111,7 +111,7 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
[[-111, -138]], var0.eval(), atol=0.01)
def testProximalGradientDescentWithL1_L2(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([1.0, 2.0])
var1 = variables.Variable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@@ -174,7 +174,7 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
return v0_val, v1_val
def testEquivSparseGradientDescentwithoutRegularization(self):
- with self.test_session():
+ with self.cached_session():
val0, val1 = self.applyOptimizer(
proximal_gradient_descent.ProximalGradientDescentOptimizer(
3.0,
@@ -182,7 +182,7 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
l2_regularization_strength=0.0),
is_sparse=True)
- with self.test_session():
+ with self.cached_session():
val2, val3 = self.applyOptimizer(
gradient_descent.GradientDescentOptimizer(3.0), is_sparse=True)
@@ -190,14 +190,14 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
self.assertAllClose(val1, val3)
def testEquivGradientDescentwithoutRegularization(self):
- with self.test_session():
+ with self.cached_session():
val0, val1 = self.applyOptimizer(
proximal_gradient_descent.ProximalGradientDescentOptimizer(
3.0,
l1_regularization_strength=0.0,
l2_regularization_strength=0.0))
- with self.test_session():
+ with self.cached_session():
val2, val3 = self.applyOptimizer(
gradient_descent.GradientDescentOptimizer(3.0))
diff --git a/tensorflow/python/training/queue_runner_impl.py b/tensorflow/python/training/queue_runner_impl.py
index d38c5499c7..ac9d4c850d 100644
--- a/tensorflow/python/training/queue_runner_impl.py
+++ b/tensorflow/python/training/queue_runner_impl.py
@@ -27,10 +27,14 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
+_DEPRECATION_INSTRUCTION = (
+ "To construct input pipelines, use the `tf.data` module.")
-@tf_export("train.queue_runner.QueueRunner", "train.QueueRunner")
+
+@tf_export(v1=["train.queue_runner.QueueRunner", "train.QueueRunner"])
class QueueRunner(object):
"""Holds a list of enqueue operations for a queue, each to be run in a thread.
@@ -53,6 +57,7 @@ class QueueRunner(object):
@end_compatibility
"""
+ @deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
def __init__(self, queue=None, enqueue_ops=None, close_op=None,
cancel_op=None, queue_closed_exception_types=None,
queue_runner_def=None, import_scope=None):
@@ -386,7 +391,8 @@ class QueueRunner(object):
import_scope=import_scope)
-@tf_export("train.queue_runner.add_queue_runner", "train.add_queue_runner")
+@tf_export(v1=["train.queue_runner.add_queue_runner", "train.add_queue_runner"])
+@deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS):
"""Adds a `QueueRunner` to a collection in the graph.
@@ -405,8 +411,9 @@ def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS):
ops.add_to_collection(collection, qr)
-@tf_export("train.queue_runner.start_queue_runners",
- "train.start_queue_runners")
+@tf_export(v1=["train.queue_runner.start_queue_runners",
+ "train.start_queue_runners"])
+@deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
collection=ops.GraphKeys.QUEUE_RUNNERS):
"""Starts all queue runners collected in the graph.
@@ -458,6 +465,13 @@ def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
raise TypeError("sess must be a `tf.Session` object. "
"Given class: {}".format(sess.__class__))
+ queue_runners = ops.get_collection(collection)
+ if not queue_runners:
+ logging.warning(
+ "`tf.train.start_queue_runners()` was called when no queue runners "
+ "were defined. You can safely remove the call to this deprecated "
+ "function.")
+
with sess.graph.as_default():
threads = []
for qr in ops.get_collection(collection):
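
The deprecation messages added above all point at `tf.data`. A hedged sketch of the replacement they suggest, in TF 1.x style (the constants and pipeline shape are illustrative):

    import tensorflow as tf

    # The tf.data runtime does its own prefetching, so no FIFOQueue,
    # QueueRunner, Coordinator, or start_queue_runners() is needed.
    dataset = (tf.data.Dataset.from_tensor_slices([10.0, 11.0, 12.0])
               .shuffle(buffer_size=3)
               .batch(2)
               .repeat())
    next_batch = dataset.make_one_shot_iterator().get_next()

    with tf.Session() as sess:
        print(sess.run(next_batch))
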
diff --git a/tensorflow/python/training/queue_runner_test.py b/tensorflow/python/training/queue_runner_test.py
index ac26e75bb9..9b9e28af2b 100644
--- a/tensorflow/python/training/queue_runner_test.py
+++ b/tensorflow/python/training/queue_runner_test.py
@@ -41,7 +41,7 @@ _MockOp = collections.namedtuple("MockOp", ["name"])
class QueueRunnerTest(test.TestCase):
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
var = variables.Variable(zero64)
@@ -61,7 +61,7 @@ class QueueRunnerTest(test.TestCase):
self.assertEqual(3, var.eval())
def testTwoOps(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
var0 = variables.Variable(zero64)
@@ -84,7 +84,7 @@ class QueueRunnerTest(test.TestCase):
self.assertEqual(30, var1.eval())
def testExceptionsCaptured(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
qr = queue_runner_impl.QueueRunner(queue, [_MockOp("i fail"),
_MockOp("so fail")])
@@ -100,7 +100,7 @@ class QueueRunnerTest(test.TestCase):
self.assertTrue("Operation not in the graph" in str(exceptions[1]))
def testRealDequeueEnqueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q0 = data_flow_ops.FIFOQueue(3, dtypes.float32)
enqueue0 = q0.enqueue((10.0,))
close0 = q0.close()
@@ -128,7 +128,7 @@ class QueueRunnerTest(test.TestCase):
dequeue1.eval()
def testRespectCoordShouldStop(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
var = variables.Variable(zero64)
@@ -152,7 +152,7 @@ class QueueRunnerTest(test.TestCase):
self.assertEqual(0, var.eval())
def testRequestStopOnException(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
qr = queue_runner_impl.QueueRunner(queue, [_MockOp("not an op")])
coord = coordinator.Coordinator()
@@ -164,7 +164,7 @@ class QueueRunnerTest(test.TestCase):
coord.join()
def testGracePeriod(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The enqueue will quickly block.
queue = data_flow_ops.FIFOQueue(2, dtypes.float32)
enqueue = queue.enqueue((10.0,))
@@ -181,7 +181,7 @@ class QueueRunnerTest(test.TestCase):
coord.join(stop_grace_period_secs=1.0)
def testMultipleSessions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with session.Session() as other_sess:
zero64 = constant_op.constant(0, dtype=dtypes.int64)
var = variables.Variable(zero64)
@@ -196,7 +196,7 @@ class QueueRunnerTest(test.TestCase):
self.assertEqual(len(threads), len(other_threads))
def testIgnoreMultiStarts(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
var = variables.Variable(zero64)
@@ -212,7 +212,7 @@ class QueueRunnerTest(test.TestCase):
self.assertEqual([], new_threads)
def testThreads(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
var = variables.Variable(zero64)
@@ -256,7 +256,7 @@ class QueueRunnerTest(test.TestCase):
init_op = variables.global_variables_initializer()
qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
queue_runner_impl.add_queue_runner(qr)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
init_op.run()
threads = queue_runner_impl.start_queue_runners(sess)
for t in threads:
@@ -273,7 +273,7 @@ class QueueRunnerTest(test.TestCase):
init_op = variables.global_variables_initializer()
qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
queue_runner_impl.add_queue_runner(qr)
- with self.test_session():
+ with self.cached_session():
init_op.run()
with self.assertRaisesRegexp(TypeError, "tf.Session"):
queue_runner_impl.start_queue_runners("NotASession")
@@ -286,7 +286,7 @@ class QueueRunnerTest(test.TestCase):
init_op = variables.global_variables_initializer()
qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
queue_runner_impl.add_queue_runner(qr)
- with self.test_session():
+ with self.cached_session():
init_op.run()
threads = queue_runner_impl.start_queue_runners(
monitored_session.MonitoredSession())
@@ -303,7 +303,7 @@ class QueueRunnerTest(test.TestCase):
init_op = variables.global_variables_initializer()
qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
queue_runner_impl.add_queue_runner(qr)
- with self.test_session(graph=graph) as sess:
+ with self.session(graph=graph) as sess:
init_op.run()
threads = queue_runner_impl.start_queue_runners(sess)
for t in threads:
diff --git a/tensorflow/python/training/rmsprop_test.py b/tensorflow/python/training/rmsprop_test.py
index 6043327384..4f5f96e2b4 100644
--- a/tensorflow/python/training/rmsprop_test.py
+++ b/tensorflow/python/training/rmsprop_test.py
@@ -165,7 +165,7 @@ class RMSPropOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
@@ -187,7 +187,7 @@ class RMSPropOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariableCentered(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 04fce496bd..274c856686 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -809,6 +809,22 @@ class BaseSaverBuilder(object):
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
version=self._write_version)
else:
+ graph = ops.get_default_graph()
+ # Do some sanity checking on collections containing
+ # PartitionedVariables. If a saved collection has a PartitionedVariable,
+ # the GraphDef needs to include concat ops to get the value (or there'll
+ # be a lookup error on load).
+ check_collection_list = graph.get_all_collection_keys()
+ for collection_type in check_collection_list:
+ for element in graph.get_collection(collection_type):
+ if isinstance(element, variables.PartitionedVariable):
+ try:
+ graph.get_operation_by_name(element.name)
+ except KeyError:
+ # Create a concat op for this PartitionedVariable. The user may
+ # not need it, but we'll try looking it up on MetaGraph restore
+ # since it's in a collection.
+ element.as_tensor()
return saver_pb2.SaverDef(
filename_tensor_name=filename_tensor.name,
save_tensor_name=save_tensor.name,
@@ -869,7 +885,7 @@ def _get_saver_or_default():
class Saver(object):
"""Saves and restores variables.
- See @{$variables$Variables}
+ See [Variables](https://tensorflow.org/guide/variables)
for an overview of variables, saving and restoring.
The `Saver` class adds ops to save and restore variables to and from
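
The builder check added in the first saver.py hunk above guards against a `PartitionedVariable` that sits in a collection without its concat op having been built. A small sketch of the failure mode and the fix, assuming TF 1.x APIs (`tf.fixed_size_partitioner`, `PartitionedVariable.as_tensor`); the variable and collection names are illustrative:

    import tensorflow as tf

    with tf.Graph().as_default() as g:
        v = tf.get_variable(
            "part", shape=[4], partitioner=tf.fixed_size_partitioner(2))
        tf.add_to_collection("my_collection", v)
        try:
            # Only the shards exist until as_tensor() runs, so this lookup
            # fails -- exactly what a MetaGraph restore would trip over.
            g.get_operation_by_name(v.name)
        except KeyError:
            v.as_tensor()  # materialize the concat op, as the builder now does
        g.get_operation_by_name(v.name)  # now succeeds
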
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index b55e64122a..0ac84813c8 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -84,7 +84,7 @@ class SaverTest(test.TestCase):
def basicSaveRestore(self, variable_op):
save_path = os.path.join(self.get_temp_dir(), "basic_save_restore")
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
# Build a graph with 2 parameter nodes, and Save and
# Restore nodes for them.
v0 = variable_op(10.0, name="v0")
@@ -115,7 +115,7 @@ class SaverTest(test.TestCase):
# Start a second session. In that session the parameter nodes
# have not been initialized either.
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
v0 = variable_op(-1.0, name="v0")
v1 = variable_op(-1.0, name="v1")
v2 = saver_test_utils.CheckpointedOp(name="v2")
@@ -137,7 +137,7 @@ class SaverTest(test.TestCase):
# Build another graph with 2 nodes, initialized
# differently, and a Restore node for them.
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
v0_2 = variable_op(1000.0, name="v0")
v1_2 = variable_op(2000.0, name="v1")
v2_2 = saver_test_utils.CheckpointedOp(name="v2")
@@ -222,7 +222,7 @@ class SaverTest(test.TestCase):
# Save from graph mode and restore from eager mode.
graph_ckpt_prefix = os.path.join(self.get_temp_dir(), "graph_ckpt")
with context.graph_mode():
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
# Create a graph model and save the checkpoint.
w1 = resource_variable_ops.ResourceVariable(1.0, name="w1")
w2 = resource_variable_ops.ResourceVariable(2.0, name="w2")
@@ -256,7 +256,7 @@ class SaverTest(test.TestCase):
graph_saver.save(None, eager_ckpt_prefix)
with context.graph_mode():
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
w3 = resource_variable_ops.ResourceVariable(0.0, name="w3")
w4 = resource_variable_ops.ResourceVariable(0.0, name="w4")
graph_saver = saver_module.Saver([w3, w4])
@@ -268,7 +268,7 @@ class SaverTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testResourceSaveRestoreCachingDevice(self):
save_path = os.path.join(self.get_temp_dir(), "resource_cache")
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
v = resource_variable_ops.ResourceVariable([1], caching_device="/cpu:0",
name="v")
if context.executing_eagerly():
@@ -324,7 +324,7 @@ class SaverTest(test.TestCase):
save_relative_paths=True)
init_all_op = [variables.global_variables_initializer(), v2_init]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initialize all variables
sess.run(init_all_op)
@@ -349,7 +349,7 @@ class SaverTest(test.TestCase):
# Start a second session. In that session the parameter nodes
# have not been initialized either.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v0 = variables.Variable(-1.0, name="v0")
v1 = variables.Variable(-1.0, name="v1")
v2 = saver_test_utils.CheckpointedOp(name="v2")
@@ -373,7 +373,7 @@ class SaverTest(test.TestCase):
v0 = variables.Variable(0, name="v0")
filename = b"somerandomfilename"
save = saver_module.Saver({"v0": v0}, filename=filename)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tensor = sess.graph.get_tensor_by_name(
save.saver_def.filename_tensor_name)
self.assertEqual(sess.run(tensor), filename)
@@ -381,7 +381,7 @@ class SaverTest(test.TestCase):
def testInvalidPath(self):
v0 = variables.Variable(0, name="v0")
for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
save = saver_module.Saver({"v0": v0}, write_version=ver)
with self.assertRaisesRegexp(
ValueError, "The passed save_path is not a valid checkpoint:"):
@@ -390,7 +390,7 @@ class SaverTest(test.TestCase):
def testInt64(self):
save_path = os.path.join(self.get_temp_dir(), "int64")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Build a graph with 1 node, and save and restore for them.
v = variables.Variable(np.int64(15), name="v")
save = saver_module.Saver({"v": v}, restore_sequentially=True)
@@ -401,7 +401,7 @@ class SaverTest(test.TestCase):
self.assertTrue(isinstance(val, six.string_types))
self.assertEqual(save_path, val)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = variables.Variable(np.int64(-1), name="v")
save = saver_module.Saver({"v": v})
@@ -465,7 +465,7 @@ class SaverTest(test.TestCase):
def testBasicsWithListOfVariables(self):
save_path = os.path.join(self.get_temp_dir(), "basics_with_list")
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
# Build a graph with 2 parameter nodes, and Save and
# Restore nodes for them.
v0 = variables.Variable(10.0, name="v0")
@@ -489,7 +489,7 @@ class SaverTest(test.TestCase):
# Start a second session. In that session the variables
# have not been initialized either.
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
v0 = variables.Variable(-1.0, name="v0")
v1 = variables.Variable(-1.0, name="v1")
v2 = saver_test_utils.CheckpointedOp(name="v2")
@@ -514,7 +514,7 @@ class SaverTest(test.TestCase):
# Build another graph with 2 nodes, initialized
# differently, and a Restore node for them.
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
v0_2 = variables.Variable(1000.0, name="v0")
v1_2 = variables.Variable(2000.0, name="v1")
v2_2 = saver_test_utils.CheckpointedOp(name="v2")
@@ -536,14 +536,14 @@ class SaverTest(test.TestCase):
self.assertEqual(30.0, v2_2.values().eval())
def _SaveAndLoad(self, var_name, var_value, other_value, save_path):
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
var = resource_variable_ops.ResourceVariable(var_value, name=var_name)
save = saver_module.Saver({var_name: var})
if not context.executing_eagerly():
self.evaluate(var.initializer)
val = save.save(sess, save_path)
self.assertEqual(save_path, val)
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
var = resource_variable_ops.ResourceVariable(other_value, name=var_name)
save = saver_module.Saver({var_name: var})
save.restore(sess, save_path)
@@ -559,12 +559,12 @@ class SaverTest(test.TestCase):
def testAllowEmpty(self):
save_path = os.path.join(self.get_temp_dir(), "allow_empty")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_ = constant_op.constant(1)
save = saver_module.Saver(allow_empty=True)
val = save.save(sess, save_path)
self.assertIsNone(val)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
save = saver_module.Saver(allow_empty=True)
save.restore(sess, save_path)
@@ -693,7 +693,7 @@ class SaverTest(test.TestCase):
# Save and reload one Variable named "var0".
self._SaveAndLoad("var0", 0.0, 1.0, save_path)
for use_tensor in [True, False]:
- with self.test_session(graph=ops_lib.Graph()):
+ with self.session(graph=ops_lib.Graph()):
var = resource_variable_ops.ResourceVariable(1.0, name="var0")
save = saver_module.Saver(
{
@@ -740,7 +740,7 @@ class SaverTest(test.TestCase):
# save succeeds or fails is implementation dependent. Therefore we allow
# both cases.
try:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initialize all variables
sess.run(init_all_op)
@@ -751,7 +751,7 @@ class SaverTest(test.TestCase):
# Save the graph.
save.save(sess, save_path)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Restore the saved values in the parameter nodes.
save.restore(sess, save_path)
# Check that the parameter nodes have been restored.
@@ -775,7 +775,7 @@ class SaverTest(test.TestCase):
save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
init_all_op = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initialize all variables
sess.run(init_all_op)
@@ -791,7 +791,7 @@ class SaverTest(test.TestCase):
save_path = os.path.join(self.get_temp_dir(), "basic_save_restore")
# Build the first session.
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
v0 = variable_op(10.0, name="v0", dtype=dtypes.float32)
if not context.executing_eagerly():
@@ -801,7 +801,7 @@ class SaverTest(test.TestCase):
save.save(sess, save_path)
# Start a second session.
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
v0_wrong_dtype = variable_op(1, name="v0", dtype=dtypes.int32)
# Restore the saved value with different dtype
# in the parameter nodes.
@@ -822,7 +822,7 @@ class SaverTest(test.TestCase):
return small_v + large_v
save_graph = ops_lib.Graph()
- with save_graph.as_default(), self.test_session(graph=save_graph) as sess:
+ with save_graph.as_default(), self.session(graph=save_graph) as sess:
orig_vars = _model()
sess.run(variables.global_variables_initializer())
save = saver_module.Saver(max_to_keep=1)
@@ -983,7 +983,7 @@ class SaveRestoreShardedTest(test.TestCase):
os.path.join(self.get_temp_dir(), "sharded_basics"))
def testSaverDef(self):
- with self.test_session():
+ with self.cached_session():
v0 = variables.Variable(123, name="v0")
save = saver_module.Saver({"v0": v0}, sharded=True)
sd = save.as_saver_def()
@@ -999,7 +999,7 @@ class SaveRestoreShardedTest(test.TestCase):
call_saver_with_dict = False # updated by test loop below
def _save(slices=None, partitioner=None):
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
# Calls .eval() to return the ndarray that makes up the full variable.
rnd = random_ops.random_uniform(var_full_shape).eval()
@@ -1036,7 +1036,7 @@ class SaveRestoreShardedTest(test.TestCase):
return rnd
def _restore(slices=None, partitioner=None):
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
if slices:
assert not partitioner
new_vs = partitioned_variables.create_partitioned_variables(
@@ -1209,7 +1209,7 @@ class MaxToKeepTest(test.TestCase):
def testNonSharded(self):
save_dir = self._get_test_dir("max_to_keep_non_sharded")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = variables.Variable(10.0, name="v")
save = saver_module.Saver({"v": v}, max_to_keep=2)
variables.global_variables_initializer().run()
@@ -1447,7 +1447,7 @@ class MaxToKeepTest(test.TestCase):
save_dir = self._get_test_dir("no_max_to_keep")
save_dir2 = self._get_test_dir("max_to_keep_0")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = variables.Variable(10.0, name="v")
variables.global_variables_initializer().run()
@@ -1474,7 +1474,7 @@ class MaxToKeepTest(test.TestCase):
def testNoMetaGraph(self):
save_dir = self._get_test_dir("no_meta_graph")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = variables.Variable(10.0, name="v")
save = saver_module.Saver({"v": v})
variables.global_variables_initializer().run()
@@ -1497,7 +1497,7 @@ class KeepCheckpointEveryNHoursTest(test.TestCase):
def testNonSharded(self, mock_time):
save_dir = self._get_test_dir("keep_checkpoint_every_n_hours")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = variable_scope.variable([10.0], name="v")
# Run the initializer NOW to avoid the 0.5s overhead of the first Run()
# call, which throws the test timing off in fastbuild mode.
@@ -1549,7 +1549,7 @@ class SaveRestoreWithVariableNameMap(test.TestCase):
def _testNonReshape(self, variable_op):
save_path = os.path.join(self.get_temp_dir(), "non_reshape")
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
# Build a graph with 2 parameter nodes, and Save and
# Restore nodes for them.
v0 = variable_op(10.0, name="v0")
@@ -1574,7 +1574,7 @@ class SaveRestoreWithVariableNameMap(test.TestCase):
# Verify that the mapped names are present in the Saved file and can be
# Restored using remapped names.
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
v0 = variable_op(-1.0, name="v0")
v1 = variable_op(-1.0, name="v1")
@@ -1594,7 +1594,7 @@ class SaveRestoreWithVariableNameMap(test.TestCase):
# Add a prefix to the node names in the current graph and Restore using
# remapped names.
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
v0 = variable_op(-1.0, name="restore_prefix/v0")
v1 = variable_op(-1.0, name="restore_prefix/v1")
@@ -1630,7 +1630,7 @@ class MetaGraphTest(test.TestCase):
def testAddCollectionDef(self):
test_dir = self._get_test_dir("good_collection")
filename = os.path.join(test_dir, "metafile")
- with self.test_session():
+ with self.cached_session():
# Creates a graph.
v0 = variables.Variable(1.0, name="v0")
control_flow_ops.cond(
@@ -1685,7 +1685,7 @@ class MetaGraphTest(test.TestCase):
self, meta_graph_def, new_meta_graph_def)
def testAddCollectionDefFails(self):
- with self.test_session():
+ with self.cached_session():
# Creates a graph.
v0 = variables.Variable(10.0, name="v0")
# Creates a saver.
@@ -1709,7 +1709,7 @@ class MetaGraphTest(test.TestCase):
filename = os.path.join(test_dir, "metafile")
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
# Creates a graph.
v0 = variables.Variable([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")
v1 = variables.Variable(11.0, name="v1")
@@ -1753,7 +1753,7 @@ class MetaGraphTest(test.TestCase):
filename = os.path.join(test_dir, "metafile")
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
# Imports from meta_graph.
saver_module.import_meta_graph(filename)
# Retrieves SAVERS collection. Verifies there are 2 entries.
@@ -1786,7 +1786,7 @@ class MetaGraphTest(test.TestCase):
filename = os.path.join(test_dir, "metafile")
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
# Creates a graph.
v0 = variables.Variable([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")
v1 = variables.Variable(11.0, name="v1")
@@ -1838,25 +1838,25 @@ class MetaGraphTest(test.TestCase):
def testBinaryAndTextFormat(self):
test_dir = self._get_test_dir("binary_and_text")
filename = os.path.join(test_dir, "metafile")
- with self.test_session(graph=ops_lib.Graph()):
+ with self.session(graph=ops_lib.Graph()):
# Creates a graph.
variables.Variable(10.0, name="v0")
# Exports the graph as binary format.
saver_module.export_meta_graph(filename, as_text=False)
- with self.test_session(graph=ops_lib.Graph()):
+ with self.session(graph=ops_lib.Graph()):
# Imports the binary format graph.
saver = saver_module.import_meta_graph(filename)
self.assertIsNotNone(saver)
# Exports the graph as text format.
saver.export_meta_graph(filename, as_text=True)
- with self.test_session(graph=ops_lib.Graph()):
+ with self.session(graph=ops_lib.Graph()):
# Imports the text format graph.
saver_module.import_meta_graph(filename)
# Writes wrong contents to the file.
graph_io.write_graph(saver.as_saver_def(),
os.path.dirname(filename),
os.path.basename(filename))
- with self.test_session(graph=ops_lib.Graph()):
+ with self.session(graph=ops_lib.Graph()):
# Import should fail.
with self.assertRaisesWithPredicateMatch(IOError,
lambda e: "Cannot parse file"):
@@ -1870,7 +1870,7 @@ class MetaGraphTest(test.TestCase):
def testSliceVariable(self):
test_dir = self._get_test_dir("slice_saver")
filename = os.path.join(test_dir, "metafile")
- with self.test_session():
+ with self.cached_session():
v1 = variables.Variable([20.0], name="v1")
v2 = variables.Variable([20.0], name="v2")
v2._set_save_slice_info(
@@ -1946,7 +1946,7 @@ class MetaGraphTest(test.TestCase):
ops_lib.add_to_collection("logits", logits)
init_all_op = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initializes all the variables.
sess.run(init_all_op)
# Runs to logit.
@@ -1961,7 +1961,7 @@ class MetaGraphTest(test.TestCase):
filename = os.path.join(test_dir, "metafile")
train_filename = os.path.join(test_dir, "train_metafile")
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
# Restores from MetaGraphDef.
new_saver = saver_module.import_meta_graph(filename)
# Generates a new MetaGraphDef.
@@ -1998,7 +1998,7 @@ class MetaGraphTest(test.TestCase):
def _testRestoreFromTrainGraphWithControlContext(self, test_dir):
train_filename = os.path.join(test_dir, "train_metafile")
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
# Restores from MetaGraphDef.
new_saver = saver_module.import_meta_graph(train_filename)
# Restores from checkpoint.
@@ -2120,7 +2120,7 @@ class MetaGraphTest(test.TestCase):
# pylint: enable=g-long-lambda
def testStrippedOpListDef(self):
- with self.test_session():
+ with self.cached_session():
# Creates a graph.
v0 = variables.Variable(0.0)
var = variables.Variable(10.0)
@@ -2160,7 +2160,7 @@ class MetaGraphTest(test.TestCase):
# With strip_default_attrs enabled, attributes "T" (float32) and "Tout"
# (complex64) in the "Complex" op must be removed.
- with self.test_session():
+ with self.cached_session():
real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
@@ -2177,7 +2177,7 @@ class MetaGraphTest(test.TestCase):
# With strip_default_attrs disabled, attributes "T" (float32) and "Tout"
# (complex64) in the "Complex" op must *not* be removed, even if they map
# to their defaults.
- with self.test_session(graph=ops_lib.Graph()):
+ with self.session(graph=ops_lib.Graph()):
real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
@@ -2397,7 +2397,7 @@ class CheckpointReaderTest(test.TestCase):
}, write_version=self._WRITE_VERSION)
save_path = os.path.join(self.get_temp_dir(),
"ckpt_for_debug_string" + str(self._WRITE_VERSION))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_all_op)
# Saves a checkpoint.
save.save(sess, save_path)
@@ -2541,7 +2541,7 @@ class ScopedGraphTest(test.TestCase):
export_scope="hidden1")
self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
- with self.test_session(graph=graph) as sess:
+ with self.session(graph=graph) as sess:
sess.run(variables.global_variables_initializer())
saver = saver_module.Saver(var_list=var_list, max_to_keep=1)
saver.save(sess, os.path.join(test_dir, ckpt_filename), write_state=False)
@@ -2601,7 +2601,7 @@ class ScopedGraphTest(test.TestCase):
set(variables.global_variables()) - set(var_list.keys()))
init_rest_op = variables.variables_initializer(rest_variables)
- with self.test_session(graph=graph) as sess:
+ with self.session(graph=graph) as sess:
saver = saver_module.Saver(var_list=var_list, max_to_keep=1)
saver.restore(sess, os.path.join(test_dir, ckpt_filename))
# Verify that we have restored weights1 and biases1.
@@ -2635,7 +2635,7 @@ class ScopedGraphTest(test.TestCase):
nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")
# Run the graph and save scoped checkpoint.
- with self.test_session(graph=graph1) as sess:
+ with self.session(graph=graph1) as sess:
sess.run(variables.global_variables_initializer())
_, var_list_1 = meta_graph.export_scoped_meta_graph(
export_scope="hidden1")
@@ -2656,7 +2656,7 @@ class ScopedGraphTest(test.TestCase):
var_list_2 = meta_graph.copy_scoped_meta_graph(
from_scope="hidden1", to_scope="hidden2")
- with self.test_session(graph=graph1) as sess:
+ with self.session(graph=graph1) as sess:
saver1 = saver_module.Saver(var_list=var_list_1, max_to_keep=1)
saver1.restore(sess, saver0_ckpt)
saver2 = saver_module.Saver(var_list=var_list_2, max_to_keep=1)
@@ -2672,7 +2672,7 @@ class ScopedGraphTest(test.TestCase):
from_graph=graph1,
to_graph=graph2)
- with self.test_session(graph=graph2) as sess:
+ with self.session(graph=graph2) as sess:
saver3 = saver_module.Saver(var_list=new_var_list_1, max_to_keep=1)
saver3.restore(sess, saver0_ckpt)
self.assertAllClose(expected, sess.run("new_hidden1/relu:0"))
@@ -2691,7 +2691,7 @@ class ScopedGraphTest(test.TestCase):
nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")
# Run the graph and save scoped checkpoint.
- with self.test_session(graph=graph1) as sess:
+ with self.session(graph=graph1) as sess:
sess.run(variables.global_variables_initializer())
_, var_list_1 = meta_graph.export_scoped_meta_graph(
graph_def=graph1.as_graph_def(), export_scope="hidden1")
@@ -2708,7 +2708,7 @@ class ScopedGraphTest(test.TestCase):
from_graph=graph1,
to_graph=graph2)
- with self.test_session(graph=graph2) as sess:
+ with self.session(graph=graph2) as sess:
saver3 = saver_module.Saver(var_list=new_var_list_1, max_to_keep=1)
saver3.restore(sess, saver0_ckpt)
self.assertAllClose(expected, sess.run("new_hidden1/relu:0"))
@@ -2729,7 +2729,7 @@ class ScopedGraphTest(test.TestCase):
saver2 = saver_module.Saver(var_list=[variable2], name="hidden2/")
graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver2)
- with self.test_session(graph=graph) as sess:
+ with self.session(graph=graph) as sess:
variables.global_variables_initializer().run()
saver1.save(sess, saver1_ckpt, write_state=False)
saver2.save(sess, saver2_ckpt, write_state=False)
@@ -2745,7 +2745,7 @@ class ScopedGraphTest(test.TestCase):
saver_list1 = graph1.get_collection(ops_lib.GraphKeys.SAVERS)
self.assertEqual(1, len(saver_list1))
- with self.test_session(graph=graph1) as sess:
+ with self.session(graph=graph1) as sess:
saver_list1[0].restore(sess, saver1_ckpt)
self.assertEqual(1.0, var_dict1["variable1:0"].eval())
@@ -2760,7 +2760,7 @@ class ScopedGraphTest(test.TestCase):
saver_list2 = graph2.get_collection(ops_lib.GraphKeys.SAVERS)
self.assertEqual(1, len(saver_list2))
- with self.test_session(graph=graph2) as sess:
+ with self.session(graph=graph2) as sess:
saver_list2[0].restore(sess, saver2_ckpt)
self.assertEqual(2.0, var_dict2["variable2:0"].eval())
@@ -2853,8 +2853,8 @@ class CheckpointableCompatibilityTests(test.TestCase):
saver = saver_module.Saver(var_list=[v])
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
- self.evaluate(v.non_dep_variable.assign(42.))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
+ self.evaluate(v.non_dep_variable.assign(42.))
save_path = saver.save(sess, prefix)
self.evaluate(v.non_dep_variable.assign(43.))
saver.restore(sess, save_path)
@@ -2867,7 +2867,7 @@ class CheckpointableCompatibilityTests(test.TestCase):
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
self.evaluate(v.non_dep_variable.assign(42.))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
save_path = saver.save(sess, prefix)
self.evaluate(v.non_dep_variable.assign(43.))
self.evaluate(v.mirrored.assign(44.))
@@ -2900,7 +2900,7 @@ class CheckpointableCompatibilityTests(test.TestCase):
saver = saver_module.Saver(var_list=[v])
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
save_path = saver.save(sess, prefix)
self.assertEqual(1, v.eval_count)
saver.restore(sess, save_path)
@@ -2957,7 +2957,7 @@ class CheckpointableCompatibilityTests(test.TestCase):
b = resource_variable_ops.ResourceVariable(1., name="b")
a_saver = saver_module.Saver([a])
b_saver = saver_module.Saver([b])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(a.initializer)
save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)
with self.assertRaisesRegexp(
@@ -2979,14 +2979,14 @@ class CheckpointableCompatibilityTests(test.TestCase):
a = variables.Variable(1., name="a")
a_saver = saver_module.Saver([a])
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sess.run(a.initializer)
save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)
with ops_lib.Graph().as_default() as g:
a = variables.Variable([1.], name="a")
a_saver = saver_module.Saver([a])
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"a mismatch between the current graph and the graph"):
@@ -2997,7 +2997,7 @@ class CheckpointableCompatibilityTests(test.TestCase):
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
save_graph = ops_lib.Graph()
- with save_graph.as_default(), self.test_session(graph=save_graph) as sess:
+ with save_graph.as_default(), self.session(graph=save_graph) as sess:
root = self._initialized_model()
object_saver = checkpointable_utils.CheckpointableSaver(root)
save_path = object_saver.save(file_prefix=checkpoint_prefix)
@@ -3031,7 +3031,7 @@ class CheckpointableCompatibilityTests(test.TestCase):
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
save_graph = ops_lib.Graph()
- with save_graph.as_default(), self.test_session(graph=save_graph):
+ with save_graph.as_default(), self.session(graph=save_graph):
root = self._initialized_model()
object_saver = checkpointable_utils.CheckpointableSaver(root)
save_path = object_saver.save(file_prefix=checkpoint_prefix)
diff --git a/tensorflow/python/training/session_manager_test.py b/tensorflow/python/training/session_manager_test.py
index d7e6dac95b..f1d18f7704 100644
--- a/tensorflow/python/training/session_manager_test.py
+++ b/tensorflow/python/training/session_manager_test.py
@@ -98,7 +98,7 @@ class SessionManagerTest(test.TestCase):
os.rename(checkpoint_dir, checkpoint_dir2)
gfile.MakeDirs(checkpoint_dir)
v = variables.Variable([6.0, 7.0, 8.0], name="v")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
@@ -236,7 +236,7 @@ class SessionManagerTest(test.TestCase):
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
self.assertEqual(False, variables.is_variable_initialized(w).eval())
sm2 = session_manager.SessionManager(
@@ -294,7 +294,7 @@ class SessionManagerTest(test.TestCase):
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
self.assertEqual(False, variables.is_variable_initialized(w).eval())
sm2 = session_manager.SessionManager(
@@ -326,7 +326,7 @@ class SessionManagerTest(test.TestCase):
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(w).eval())
sm2 = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables(),
@@ -362,7 +362,7 @@ class SessionManagerTest(test.TestCase):
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
self.assertEqual(False, variables.is_variable_initialized(w).eval())
sm2 = session_manager.SessionManager(
@@ -467,7 +467,7 @@ class SessionManagerTest(test.TestCase):
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="x")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
self.assertEqual(False, variables.is_variable_initialized(w).eval())
self.assertEqual(False, variables.is_variable_initialized(x).eval())
@@ -519,7 +519,7 @@ class SessionManagerTest(test.TestCase):
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="x_res")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
self.assertEqual(False, variables.is_variable_initialized(w).eval())
self.assertEqual(False, variables.is_variable_initialized(x).eval())
@@ -566,7 +566,7 @@ class SessionManagerTest(test.TestCase):
with ops.Graph().as_default():
i = control_flow_ops.while_loop(lambda i: i < 1, lambda i: i + 1, [0])
v = variables.Variable(array_ops.identity(i), name="v")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
@@ -585,7 +585,7 @@ class SessionManagerTest(test.TestCase):
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
self.assertEqual(False, variables.is_variable_initialized(w).eval())
sm2 = session_manager.SessionManager(
@@ -602,7 +602,7 @@ class SessionManagerTest(test.TestCase):
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
self.assertEqual(False, variables.is_variable_initialized(w).eval())
sm2 = session_manager.SessionManager(
@@ -619,7 +619,7 @@ class SessionManagerTest(test.TestCase):
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
self.assertEqual(False, variables.is_variable_initialized(w).eval())
sm2 = session_manager.SessionManager(
@@ -640,7 +640,7 @@ class SessionManagerTest(test.TestCase):
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
self.assertEqual(False, variables.is_variable_initialized(w).eval())
sm2 = session_manager.SessionManager(
@@ -714,7 +714,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
os.rename(checkpoint_dir, checkpoint_dir2)
gfile.MakeDirs(checkpoint_dir)
v = variables.Variable([6.0, 7.0, 8.0], name="v")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
session_manager.SessionManager(
ready_op=variables.assert_variables_initialized())
@@ -769,7 +769,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
# Create a new Graph and SessionManager and recover.
with ops.Graph().as_default():
v = variables.Variable(2, name="v")
- with self.test_session():
+ with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
sm2 = session_manager.SessionManager(
ready_op=variables.assert_variables_initialized())
diff --git a/tensorflow/python/training/slot_creator_test.py b/tensorflow/python/training/slot_creator_test.py
index 08a3c8dc53..6d6364169f 100644
--- a/tensorflow/python/training/slot_creator_test.py
+++ b/tensorflow/python/training/slot_creator_test.py
@@ -32,7 +32,7 @@ from tensorflow.python.training import slot_creator
class SlotCreatorTest(test.TestCase):
def testCreateSlotFromVariable(self):
- with self.test_session():
+ with self.cached_session():
v = variables.Variable([1.0, 2.5], name="var")
slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
@@ -44,7 +44,7 @@ class SlotCreatorTest(test.TestCase):
self.assertAllEqual([1.0, 2.5], slot.eval())
def testCreateSlotFromTensor(self):
- with self.test_session():
+ with self.cached_session():
v = constant_op.constant([1.0, 2.5], name="const")
slot = slot_creator.create_slot(v, v * 2, name="slot")
@@ -56,7 +56,7 @@ class SlotCreatorTest(test.TestCase):
self.assertAllEqual([2.0, 5.0], slot.eval())
def testCreateZerosSlotFromVariable(self):
- with self.test_session():
+ with self.cached_session():
v = variables.Variable([1.0, 2.5], name="var")
with ops.control_dependencies(None):
slot = slot_creator.create_zeros_slot(
@@ -70,7 +70,7 @@ class SlotCreatorTest(test.TestCase):
self.assertAllEqual([0.0, 0.0], slot.eval())
def testCreateZerosSlotFromDynamicShapedVariable(self):
- with self.test_session():
+ with self.cached_session():
dyn_shape = constant_op.constant([2], dtype=dtypes.int32)
dyn_shape = array_ops.placeholder_with_default(dyn_shape,
shape=[None])
@@ -91,7 +91,7 @@ class SlotCreatorTest(test.TestCase):
self.assertAllEqual([0.0, 0.0], slot.eval())
def testCreateZerosSlotFromTensor(self):
- with self.test_session():
+ with self.cached_session():
v = constant_op.constant([1.0, 2.5], name="const")
with ops.control_dependencies(None):
slot = slot_creator.create_zeros_slot(v, name="slot")
@@ -104,7 +104,7 @@ class SlotCreatorTest(test.TestCase):
self.assertAllEqual([0.0, 0.0], slot.eval())
def testCreateZerosSlotFromDynamicShapedTensor(self):
- with self.test_session():
+ with self.cached_session():
v = random_ops.random_uniform([2], dtype=dtypes.float64)
v = array_ops.placeholder_with_default(v, shape=[None], name="const")
with ops.control_dependencies(None):
@@ -120,7 +120,7 @@ class SlotCreatorTest(test.TestCase):
def testCreateSlotFromVariableRespectsScope(self):
# See discussion on #2740.
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope("scope"):
v = variables.Variable([1.0, 2.5], name="var")
slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
diff --git a/tensorflow/python/training/supervisor_test.py b/tensorflow/python/training/supervisor_test.py
index 71ed88093a..caf6eba3e0 100644
--- a/tensorflow/python/training/supervisor_test.py
+++ b/tensorflow/python/training/supervisor_test.py
@@ -795,7 +795,7 @@ class SupervisorTest(test.TestCase):
self.assertRaises(StopIteration, lambda: next(rr))
# There should be a checkpoint file with the variable "foo"
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v = variables.Variable([10.10], name="foo")
sav = saver_lib.Saver([v])
sav.restore(sess, save_path)
@@ -859,14 +859,14 @@ class SupervisorTest(test.TestCase):
self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status)
self.assertRaises(StopIteration, lambda: next(rr))
# There should be a checkpoint file with the variable "foo"
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v = variables.Variable([-12], name="global_step")
sav = saver_lib.Saver([v])
sav.restore(sess, save_path)
self.assertEqual(123, v.eval()[0])
def testNoQueueRunners(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
sv = supervisor.Supervisor(logdir=self._test_dir("no_queue_runners"))
self.assertEqual(0, len(sv.start_queue_runners(sess)))
sv.stop()
diff --git a/tensorflow/python/training/sync_replicas_optimizer.py b/tensorflow/python/training/sync_replicas_optimizer.py
index 0c6cf910d1..7afaa92699 100644
--- a/tensorflow/python/training/sync_replicas_optimizer.py
+++ b/tensorflow/python/training/sync_replicas_optimizer.py
@@ -53,7 +53,7 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
which replicas can fetch the new variables and continue.
The following accumulators/queue are created:
- <empty line>
+
* N `gradient accumulators`, one per variable to train. Gradients are pushed
to them and the chief worker will wait until enough gradients are collected
and then average them before applying to variables. The accumulator will
@@ -68,7 +68,7 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
The optimizer adds nodes to the graph to collect gradients and pause the
trainers until variables are updated.
For the Parameter Server job:
- <empty line>
+
1. An accumulator is created for each variable, and each replica pushes the
gradients into the accumulators instead of directly applying them to the
variables.
@@ -81,7 +81,7 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
update its local_step variable and start the next batch.
For the replicas:
- <empty line>
+
1. Start a step: fetch variables and compute gradients.
2. Once the gradients have been computed, push them into gradient
accumulators. Each accumulator will check the staleness and drop the stale.
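
For context on the accumulator/queue machinery this docstring describes, a
minimal usage sketch; the worker count, loss, and chief flag are stand-ins
for values a real training program would supply:

    import tensorflow as tf

    num_workers = 4
    is_chief = True
    global_step = tf.train.get_or_create_global_step()
    x = tf.Variable(1.0)
    loss = tf.square(x)

    # Wrap a base optimizer so gradients are accumulated and averaged
    # across replicas before being applied.
    opt = tf.train.SyncReplicasOptimizer(
        tf.train.GradientDescentOptimizer(learning_rate=0.1),
        replicas_to_aggregate=num_workers,
        total_num_replicas=num_workers)
    train_op = opt.minimize(loss, global_step=global_step)
    # The hook handles the chief's extra bookkeeping (token queue, init ops).
    sync_hook = opt.make_session_run_hook(is_chief)
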
diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py
index 6f6305a505..686c4be31a 100644
--- a/tensorflow/python/training/training.py
+++ b/tensorflow/python/training/training.py
@@ -15,7 +15,7 @@
"""Support for training models.
-See the @{$python/train} guide.
+See the [Training](https://tensorflow.org/api_guides/python/train) guide.
"""
# Optimizers.
diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py
index 2ff3eeb153..d998d6af81 100644
--- a/tensorflow/python/training/training_util.py
+++ b/tensorflow/python/training/training_util.py
@@ -129,6 +129,7 @@ def create_global_step(graph=None):
dtype=dtypes.int64,
initializer=init_ops.zeros_initializer(),
trainable=False,
+ aggregation=variables.VariableAggregation.ONLY_FIRST_TOWER,
collections=[ops.GraphKeys.GLOBAL_VARIABLES,
ops.GraphKeys.GLOBAL_STEP])
# Create in proper graph and base name_scope.
@@ -139,6 +140,7 @@ def create_global_step(graph=None):
dtype=dtypes.int64,
initializer=init_ops.zeros_initializer(),
trainable=False,
+ aggregation=variables.VariableAggregation.ONLY_FIRST_TOWER,
collections=[ops.GraphKeys.GLOBAL_VARIABLES,
ops.GraphKeys.GLOBAL_STEP])
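
The `aggregation` argument added above controls how concurrent tower updates
to the global step are merged. A sketch of the call site, assuming the usual
accessor; the aggregation semantics themselves are enforced by the
distribution strategy, not by this function:

    from tensorflow.python.framework import ops
    from tensorflow.python.training import training_util

    with ops.Graph().as_default():
      # With ONLY_FIRST_TOWER aggregation, only the first tower's increment
      # is applied, so the step is not multiplied by the number of towers.
      step = training_util.get_or_create_global_step()
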
diff --git a/tensorflow/python/training/warm_starting_util.py b/tensorflow/python/training/warm_starting_util.py
index 0ba7ba983d..bea9bb6dff 100644
--- a/tensorflow/python/training/warm_starting_util.py
+++ b/tensorflow/python/training/warm_starting_util.py
@@ -32,7 +32,7 @@ from tensorflow.python.training import saver
from tensorflow.python.util.tf_export import tf_export
-@tf_export("train.VocabInfo", allow_multiple_exports=True)
+@tf_export("train.VocabInfo")
class VocabInfo(
collections.namedtuple("VocabInfo", [
"new_vocab",
@@ -41,6 +41,7 @@ class VocabInfo(
"old_vocab",
"old_vocab_size",
"backup_initializer",
+ "axis",
])):
"""Vocabulary information for warm-starting.
@@ -62,6 +63,42 @@ class VocabInfo(
backup_initializer: [Optional] A variable initializer used for variables
corresponding to new vocabulary entries and OOV. If not provided, these
entries will be zero-initialized.
+ axis: [Optional] Denotes what axis the vocabulary corresponds to. The
+ default, 0, corresponds to the most common use case (embeddings or
+ linear weights for binary classification / regression). An axis of 1
+ could be used for warm-starting output layers with class vocabularies.
+
+ For example:
+
+ embeddings_vocab_info = tf.VocabInfo(
+ new_vocab='embeddings_vocab',
+ new_vocab_size=100,
+ num_oov_buckets=1,
+ old_vocab='pretrained_embeddings_vocab',
+ old_vocab_size=10000,
+ backup_initializer=tf.truncated_normal_initializer(
+ mean=0.0, stddev=(1 / math.sqrt(embedding_dim))),
+ axis=0)
+
+ softmax_output_layer_kernel_vocab_info = tf.VocabInfo(
+ new_vocab='class_vocab',
+ new_vocab_size=5,
+ num_oov_buckets=0, # No OOV for classes.
+ old_vocab='old_class_vocab',
+ old_vocab_size=8,
+ backup_initializer=tf.glorot_uniform_initializer(),
+ axis=1)
+
+ softmax_output_layer_bias_vocab_info = tf.VocabInfo(
+ new_vocab='class_vocab',
+ new_vocab_size=5,
+ num_oov_buckets=0, # No OOV for classes.
+ old_vocab='old_class_vocab',
+ old_vocab_size=8,
+ backup_initializer=tf.zeros_initializer(),
+ axis=0)
+
+ Currently, only axis=0 and axis=1 are supported.
"""
def __new__(cls,
@@ -70,7 +107,12 @@ class VocabInfo(
num_oov_buckets,
old_vocab,
old_vocab_size=-1,
- backup_initializer=None):
+ backup_initializer=None,
+ axis=0):
+ if axis != 0 and axis != 1:
+ raise ValueError("The only supported values for the axis argument are 0 "
+ "and 1. Provided axis: {}".format(axis))
+
return super(VocabInfo, cls).__new__(
cls,
new_vocab,
@@ -79,6 +121,7 @@ class VocabInfo(
old_vocab,
old_vocab_size,
backup_initializer,
+ axis,
)
@@ -149,7 +192,8 @@ def _warm_start_var_with_vocab(var,
previous_vocab_size=-1,
current_oov_buckets=0,
prev_tensor_name=None,
- initializer=None):
+ initializer=None,
+ axis=0):
"""Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.
Use this method when the `var` is backed by vocabulary. This method stitches
@@ -180,6 +224,7 @@ def _warm_start_var_with_vocab(var,
None, we lookup tensor with same name as given `var`.
initializer: Variable initializer to be used for missing entries. If None,
missing entries will be zero-initialized.
+ axis: Axis of the variable that the provided vocabulary corresponds to.
Raises:
ValueError: If required args are not provided.
@@ -204,6 +249,8 @@ def _warm_start_var_with_vocab(var,
# Assume tensor name remains the same.
prev_tensor_name = _infer_var_name(var)
+ # TODO(eddz): Fix functionality for rank-1 Variables (like FC biases).
+ total_v_first_axis = sum([v.get_shape().as_list()[0] for v in var])
for v in var:
v_shape = v.get_shape().as_list()
slice_info = v._get_save_slice_info()
@@ -213,19 +260,45 @@ def _warm_start_var_with_vocab(var,
full_shape=slice_info.full_shape,
var_offset=slice_info.var_offset)
- # TODO(eddz): Support cases where class vocabularies need remapping too.
+ if axis == 0:
+ new_row_vocab_size = current_vocab_size
+ new_col_vocab_size = v_shape[1]
+ old_row_vocab_size = previous_vocab_size
+ old_row_vocab_file = prev_vocab_path
+ new_row_vocab_file = current_vocab_path
+ old_col_vocab_file = None
+ new_col_vocab_file = None
+ num_row_oov_buckets = current_oov_buckets
+ num_col_oov_buckets = 0
+ elif axis == 1:
+ # Note that we must compute this value across all partitions, whereas
+ # in the axis = 0 case, we can simply use v_shape[1] because we don't
+ # allow partitioning across axis = 1.
+ new_row_vocab_size = total_v_first_axis
+ new_col_vocab_size = current_vocab_size
+ old_row_vocab_size = -1
+ old_row_vocab_file = None
+ new_row_vocab_file = None
+ old_col_vocab_file = prev_vocab_path
+ new_col_vocab_file = current_vocab_path
+ num_row_oov_buckets = 0
+ num_col_oov_buckets = current_oov_buckets
+ else:
+ raise ValueError("The only supported values for the axis argument are 0 "
+ "and 1. Provided axis: {}".format(axis))
+
init = checkpoint_ops._load_and_remap_matrix_initializer(
ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
old_tensor_name=prev_tensor_name,
- new_row_vocab_size=current_vocab_size,
- new_col_vocab_size=v_shape[1],
- old_row_vocab_size=previous_vocab_size,
- old_row_vocab_file=prev_vocab_path,
- new_row_vocab_file=current_vocab_path,
- old_col_vocab_file=None,
- new_col_vocab_file=None,
- num_row_oov_buckets=current_oov_buckets,
- num_col_oov_buckets=0,
+ new_row_vocab_size=new_row_vocab_size,
+ new_col_vocab_size=new_col_vocab_size,
+ old_row_vocab_size=old_row_vocab_size,
+ old_row_vocab_file=old_row_vocab_file,
+ new_row_vocab_file=new_row_vocab_file,
+ old_col_vocab_file=old_col_vocab_file,
+ new_col_vocab_file=new_col_vocab_file,
+ num_row_oov_buckets=num_row_oov_buckets,
+ num_col_oov_buckets=num_col_oov_buckets,
initializer=initializer)
new_init_val = ops.convert_to_tensor(
init(shape=v_shape, partition_info=partition_info))
@@ -374,7 +447,8 @@ def warm_start(ckpt_to_initialize_from,
previous_vocab_size=vocab_info.old_vocab_size,
current_oov_buckets=vocab_info.num_oov_buckets,
prev_tensor_name=prev_var_name,
- initializer=vocab_info.backup_initializer)
+ initializer=vocab_info.backup_initializer,
+ axis=vocab_info.axis)
else:
# For the special value of vars_to_warm_start = None,
# we only warm-start variables with explicitly specified vocabularies.
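
Putting the new `axis` plumbing together, a hedged sketch of warm-starting an
output layer along its class (column) vocabulary, where `ws_util` is this
module; the checkpoint path, variable name, and vocab files are hypothetical:

    vocab_info = ws_util.VocabInfo(
        new_vocab='class_vocab',
        new_vocab_size=5,
        num_oov_buckets=0,  # no OOV for classes
        old_vocab='old_class_vocab',
        old_vocab_size=8,
        axis=1)  # remap columns rather than rows
    ws_util.warm_start(
        ckpt_to_initialize_from='/tmp/prev_ckpt',
        vars_to_warm_start='.*output_layer.*',
        var_name_to_vocab_info={'output_layer/kernel': vocab_info})
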
diff --git a/tensorflow/python/training/warm_starting_util_test.py b/tensorflow/python/training/warm_starting_util_test.py
index 6a4c207d79..6c860cd452 100644
--- a/tensorflow/python/training/warm_starting_util_test.py
+++ b/tensorflow/python/training/warm_starting_util_test.py
@@ -59,7 +59,7 @@ class WarmStartingUtilTest(test.TestCase):
initializer=None,
partitioner=None):
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
var = variable_scope.get_variable(
var_name,
shape=shape,
@@ -102,12 +102,12 @@ class WarmStartingUtilTest(test.TestCase):
"fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
fruit_weights = variable_scope.get_variable(
"fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
ws_util._warm_start_var(fruit_weights, self.get_temp_dir())
sess.run(variables.global_variables_initializer())
- self.assertAllEqual(prev_val, fruit_weights.eval(sess))
+ self.assertAllClose(prev_val, fruit_weights.eval(sess))
def testWarmStartVarPrevVarPartitioned(self):
_, weights = self._create_prev_run_var(
@@ -118,19 +118,19 @@ class WarmStartingUtilTest(test.TestCase):
prev_val = np.concatenate([weights[0], weights[1]], axis=0)
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
fruit_weights = variable_scope.get_variable(
"fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
ws_util._warm_start_var(fruit_weights, self.get_temp_dir())
sess.run(variables.global_variables_initializer())
- self.assertAllEqual(prev_val, fruit_weights.eval(sess))
+ self.assertAllClose(prev_val, fruit_weights.eval(sess))
def testWarmStartVarCurrentVarPartitioned(self):
_, prev_val = self._create_prev_run_var(
"fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
fruit_weights = variable_scope.get_variable(
"fruit_weights",
shape=[4, 1],
@@ -143,7 +143,7 @@ class WarmStartingUtilTest(test.TestCase):
fruit_weights = fruit_weights._get_variable_list()
new_val = np.concatenate(
[fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
- self.assertAllEqual(prev_val, new_val)
+ self.assertAllClose(prev_val, new_val)
def testWarmStartVarBothVarsPartitioned(self):
_, weights = self._create_prev_run_var(
@@ -154,7 +154,7 @@ class WarmStartingUtilTest(test.TestCase):
prev_val = np.concatenate([weights[0], weights[1]], axis=0)
# New session and new graph.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
fruit_weights = variable_scope.get_variable(
"new_scope/fruit_weights",
shape=[4, 1],
@@ -170,7 +170,7 @@ class WarmStartingUtilTest(test.TestCase):
fruit_weights = fruit_weights._get_variable_list()
new_val = np.concatenate(
[fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
- self.assertAllEqual(prev_val, new_val)
+ self.assertAllClose(prev_val, new_val)
def testWarmStartVarWithVocab(self):
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
@@ -183,15 +183,40 @@ class WarmStartingUtilTest(test.TestCase):
["orange", "guava", "banana", "apple", "raspberry"], "new_vocab")
# New session and new graph.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
fruit_weights = variable_scope.get_variable(
"fruit_weights", initializer=[[0.], [0.], [0.], [0.], [0.]])
ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5,
self.get_temp_dir(), prev_vocab_path)
sess.run(variables.global_variables_initializer())
- self.assertAllEqual([[2.], [1.5], [1.], [0.5], [0.]],
+ self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]],
fruit_weights.eval(sess))
+ def testWarmStartVarWithColumnVocab(self):
+ prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
+ self._create_prev_run_var(
+ "fruit_output_layer",
+ initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]])
+
+ # New vocab with elements in reverse order and one new element.
+ new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
+ "new_vocab")
+ # New session and new graph.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ fruit_output_layer = variable_scope.get_variable(
+ "fruit_output_layer",
+ initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
+ [0., 0., 0.]])
+ ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
+ current_vocab_size=3,
+ prev_ckpt=self.get_temp_dir(),
+ prev_vocab_path=prev_vocab_path,
+ axis=1)
+ sess.run(variables.global_variables_initializer())
+ self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.],
+ [2.3, 2., 0.]], fruit_output_layer.eval(sess))
+
def testWarmStartVarWithVocabConstrainedOldVocabSize(self):
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
"old_vocab")
@@ -203,7 +228,7 @@ class WarmStartingUtilTest(test.TestCase):
["orange", "guava", "banana", "apple", "raspberry"], "new_vocab")
# New session and new graph.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
fruit_weights = variable_scope.get_variable(
"fruit_weights", initializer=[[0.], [0.], [0.], [0.], [0.]])
ws_util._warm_start_var_with_vocab(
@@ -215,7 +240,7 @@ class WarmStartingUtilTest(test.TestCase):
previous_vocab_size=2)
sess.run(variables.global_variables_initializer())
# Old vocabulary limited to ['apple', 'banana'].
- self.assertAllEqual([[0.], [0.], [1.], [0.5], [0.]],
+ self.assertAllClose([[0.], [0.], [1.], [0.5], [0.]],
fruit_weights.eval(sess))
def testWarmStartVarWithVocabPrevVarPartitioned(self):
@@ -232,15 +257,42 @@ class WarmStartingUtilTest(test.TestCase):
["orange", "guava", "banana", "apple", "raspberry"], "new_vocab")
# New session and new graph.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
fruit_weights = variable_scope.get_variable(
"fruit_weights", initializer=[[0.], [0.], [0.], [0.], [0.]])
ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5,
self.get_temp_dir(), prev_vocab_path)
sess.run(variables.global_variables_initializer())
- self.assertAllEqual([[2.], [1.5], [1.], [0.5], [0.]],
+ self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]],
fruit_weights.eval(sess))
+ def testWarmStartVarWithColumnVocabPrevVarPartitioned(self):
+ prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
+ self._create_prev_run_var(
+ "fruit_output_layer",
+ shape=[4, 2],
+ initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]],
+ partitioner=lambda shape, dtype: [2, 1])
+
+ # New vocab with elements in reverse order and one new element.
+ new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
+ "new_vocab")
+ # New session and new graph.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ fruit_output_layer = variable_scope.get_variable(
+ "fruit_output_layer",
+ initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
+ [0., 0., 0.]])
+ ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
+ current_vocab_size=3,
+ prev_ckpt=self.get_temp_dir(),
+ prev_vocab_path=prev_vocab_path,
+ axis=1)
+ sess.run(variables.global_variables_initializer())
+ self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.],
+ [2.3, 2., 0.]], fruit_output_layer.eval(sess))
+
def testWarmStartVarWithVocabCurrentVarPartitioned(self):
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
"old_vocab")
@@ -252,7 +304,7 @@ class WarmStartingUtilTest(test.TestCase):
["orange", "guava", "banana", "apple", "raspberry"], "new_vocab")
# New session and new graph.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
fruit_weights = variable_scope.get_variable(
"fruit_weights",
shape=[6, 1],
@@ -269,11 +321,43 @@ class WarmStartingUtilTest(test.TestCase):
self.assertTrue(
isinstance(fruit_weights, variables.PartitionedVariable))
fruit_weights_vars = fruit_weights._get_variable_list()
- self.assertAllEqual([[2.], [1.5], [1.]],
+ self.assertAllClose([[2.], [1.5], [1.]],
fruit_weights_vars[0].eval(sess))
- self.assertAllEqual([[0.5], [0.], [0.]],
+ self.assertAllClose([[0.5], [0.], [0.]],
fruit_weights_vars[1].eval(sess))
+ def testWarmStartVarWithColumnVocabCurrentVarPartitioned(self):
+ prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
+ self._create_prev_run_var(
+ "fruit_output_layer",
+ initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]])
+
+ # New vocab with elements in reverse order and one new element.
+ new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
+ "new_vocab")
+ # New session and new graph.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ fruit_output_layer = variable_scope.get_variable(
+ "fruit_output_layer",
+ shape=[4, 3],
+ initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
+ [0., 0., 0.]],
+ partitioner=lambda shape, dtype: [2, 1])
+ ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
+ current_vocab_size=3,
+ prev_ckpt=self.get_temp_dir(),
+ prev_vocab_path=prev_vocab_path,
+ axis=1)
+ sess.run(variables.global_variables_initializer())
+ self.assertTrue(
+ isinstance(fruit_output_layer, variables.PartitionedVariable))
+ fruit_output_layer_vars = fruit_output_layer._get_variable_list()
+ self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]],
+ fruit_output_layer_vars[0].eval(sess))
+ self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]],
+ fruit_output_layer_vars[1].eval(sess))
+
def testWarmStartVarWithVocabBothVarsPartitioned(self):
prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
"old_vocab")
@@ -289,7 +373,7 @@ class WarmStartingUtilTest(test.TestCase):
"blueberry"], "new_vocab")
# New session and new graph.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
fruit_weights = variable_scope.get_variable(
"fruit_weights",
shape=[6, 1],
@@ -301,11 +385,45 @@ class WarmStartingUtilTest(test.TestCase):
self.assertTrue(
isinstance(fruit_weights, variables.PartitionedVariable))
fruit_weights_vars = fruit_weights._get_variable_list()
- self.assertAllEqual([[2.], [1.5], [1.]],
+ self.assertAllClose([[2.], [1.5], [1.]],
fruit_weights_vars[0].eval(sess))
- self.assertAllEqual([[0.5], [0.], [0.]],
+ self.assertAllClose([[0.5], [0.], [0.]],
fruit_weights_vars[1].eval(sess))
+ def testWarmStartVarWithColumnVocabBothVarsPartitioned(self):
+ prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
+ self._create_prev_run_var(
+ "fruit_output_layer",
+ shape=[4, 2],
+ initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]],
+ partitioner=lambda shape, dtype: [2, 1])
+
+ # New vocab with elements in reverse order and one new element.
+ new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
+ "new_vocab")
+ # New session and new graph.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ fruit_output_layer = variable_scope.get_variable(
+ "fruit_output_layer",
+ shape=[4, 3],
+ initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
+ [0., 0., 0.]],
+ partitioner=lambda shape, dtype: [2, 1])
+ ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
+ current_vocab_size=3,
+ prev_ckpt=self.get_temp_dir(),
+ prev_vocab_path=prev_vocab_path,
+ axis=1)
+ sess.run(variables.global_variables_initializer())
+ self.assertTrue(
+ isinstance(fruit_output_layer, variables.PartitionedVariable))
+ fruit_output_layer_vars = fruit_output_layer._get_variable_list()
+ self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]],
+ fruit_output_layer_vars[0].eval(sess))
+ self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]],
+ fruit_output_layer_vars[1].eval(sess))
+
def testWarmStart_ListOfVariables(self):
# Save checkpoint from which to warm-start.
_, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1],
@@ -315,7 +433,7 @@ class WarmStartingUtilTest(test.TestCase):
# New graph, new session with warm-starting.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
# Initialize with zeros.
var = variable_scope.get_variable(
"v1",
@@ -335,7 +453,7 @@ class WarmStartingUtilTest(test.TestCase):
# New graph, new session with warm-starting.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
# Initialize with zeros.
var = variable_scope.get_variable(
"v1",
@@ -359,7 +477,7 @@ class WarmStartingUtilTest(test.TestCase):
partitioner = lambda shape, dtype: [1] * len(shape)
# New graph, new session WITHOUT warm-starting.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
cols_to_vars = self._create_linear_model([sc_int], partitioner)
sess.run(variables.global_variables_initializer())
# Without warm-starting, the weights should be initialized using default
@@ -369,7 +487,7 @@ class WarmStartingUtilTest(test.TestCase):
# New graph, new session with warm-starting.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
cols_to_vars = self._create_linear_model([sc_int], partitioner)
ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=".*sc_int.*")
sess.run(variables.global_variables_initializer())
@@ -388,7 +506,7 @@ class WarmStartingUtilTest(test.TestCase):
partitioner = lambda shape, dtype: [1] * len(shape)
# New graph, new session WITHOUT warm-starting.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
cols_to_vars = self._create_linear_model([sc_hash], partitioner)
sess.run(variables.global_variables_initializer())
# Without warm-starting, the weights should be initialized using default
@@ -398,7 +516,7 @@ class WarmStartingUtilTest(test.TestCase):
# New graph, new session with warm-starting.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
cols_to_vars = self._create_linear_model([sc_hash], partitioner)
ws_util.warm_start(
self.get_temp_dir(), vars_to_warm_start=".*sc_hash.*")
@@ -422,7 +540,7 @@ class WarmStartingUtilTest(test.TestCase):
partitioner = lambda shape, dtype: [1] * len(shape)
# New graph, new session WITHOUT warm-starting.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
sess.run(variables.global_variables_initializer())
# Without warm-starting, the weights should be initialized using default
@@ -432,7 +550,7 @@ class WarmStartingUtilTest(test.TestCase):
# New graph, new session with warm-starting.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
# Since old vocab is not explicitly set in WarmStartSettings, the old
# vocab is assumed to be same as new vocab.
@@ -458,7 +576,7 @@ class WarmStartingUtilTest(test.TestCase):
partitioner = lambda shape, dtype: [1] * len(shape)
# New graph, new session WITHOUT warm-starting.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
sess.run(variables.global_variables_initializer())
# Without warm-starting, the weights should be initialized using default
@@ -468,7 +586,7 @@ class WarmStartingUtilTest(test.TestCase):
# New graph, new session with warm-starting.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
# Since old vocab is not explicitly set in WarmStartSettings, the old
# vocab is assumed to be same as new vocab.
@@ -503,7 +621,7 @@ class WarmStartingUtilTest(test.TestCase):
partitioner = lambda shape, dtype: [1] * len(shape)
# New graph, new session WITHOUT warm-starting.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
sess.run(variables.global_variables_initializer())
# Without warm-starting, the weights should be initialized using default
@@ -513,7 +631,7 @@ class WarmStartingUtilTest(test.TestCase):
# New graph, new session with warm-starting.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
vocab_info = ws_util.VocabInfo(
new_vocab=sc_vocab.vocabulary_file,
@@ -546,7 +664,7 @@ class WarmStartingUtilTest(test.TestCase):
partitioner = lambda shape, dtype: [1] * len(shape)
# New graph, new session WITHOUT warm-starting.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
cols_to_vars = self._create_linear_model([real_bucket], partitioner)
sess.run(variables.global_variables_initializer())
# Without warm-starting, the weights should be initialized using default
@@ -556,7 +674,7 @@ class WarmStartingUtilTest(test.TestCase):
# New graph, new session with warm-starting.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
cols_to_vars = self._create_linear_model([real_bucket], partitioner)
ws_util.warm_start(
self.get_temp_dir(), vars_to_warm_start=".*real_bucketized.*")
@@ -586,7 +704,7 @@ class WarmStartingUtilTest(test.TestCase):
# Save checkpoint from which to warm-start. Also create a bias variable,
# so we can check that it's also warm-started.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
sc_int_weights = variable_scope.get_variable(
"linear_model/sc_int/weights", shape=[10, 1], initializer=ones())
sc_hash_weights = variable_scope.get_variable(
@@ -617,7 +735,7 @@ class WarmStartingUtilTest(test.TestCase):
partitioner = lambda shape, dtype: [1] * len(shape)
# New graph, new session WITHOUT warm-starting.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
cols_to_vars = self._create_linear_model(all_linear_cols, partitioner)
sess.run(variables.global_variables_initializer())
# Without warm-starting, all weights should be initialized using default
@@ -633,7 +751,7 @@ class WarmStartingUtilTest(test.TestCase):
# New graph, new session with warm-starting.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
cols_to_vars = self._create_linear_model(all_linear_cols, partitioner)
vocab_info = ws_util.VocabInfo(
new_vocab=sc_vocab.vocabulary_file,
@@ -675,7 +793,7 @@ class WarmStartingUtilTest(test.TestCase):
# Save checkpoint from which to warm-start.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
variable_scope.get_variable(
"linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())
sc_keys_weights = variable_scope.get_variable(
@@ -694,7 +812,7 @@ class WarmStartingUtilTest(test.TestCase):
# New graph, new session with warm-starting.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
cols_to_vars = self._create_linear_model(all_linear_cols, _partitioner)
vocab_info = ws_util.VocabInfo(
new_vocab=sc_vocab.vocabulary_file,
@@ -743,7 +861,7 @@ class WarmStartingUtilTest(test.TestCase):
# Save checkpoint from which to warm-start.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
variable_scope.get_variable(
"linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())
sc_keys_weights = variable_scope.get_variable(
@@ -756,7 +874,7 @@ class WarmStartingUtilTest(test.TestCase):
# New graph, new session with warm-starting.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
cols_to_vars = self._create_linear_model(all_linear_cols,
partitioner=None)
vocab_info = ws_util.VocabInfo(
@@ -802,7 +920,7 @@ class WarmStartingUtilTest(test.TestCase):
# Save checkpoint from which to warm-start.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
variable_scope.get_variable(
"linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())
variable_scope.get_variable(
@@ -820,7 +938,7 @@ class WarmStartingUtilTest(test.TestCase):
# New graph, new session with warm-starting.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
cols_to_vars = self._create_linear_model(all_linear_cols, _partitioner)
vocab_info = ws_util.VocabInfo(
new_vocab=sc_vocab.vocabulary_file,
@@ -866,7 +984,7 @@ class WarmStartingUtilTest(test.TestCase):
# Save checkpoint from which to warm-start.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
variable_scope.get_variable(
"input_layer/sc_vocab_embedding/embedding_weights",
initializer=[[0.5, 0.4], [1., 1.1], [2., 2.2], [3., 3.3]])
@@ -887,7 +1005,7 @@ class WarmStartingUtilTest(test.TestCase):
all_deep_cols = [emb_vocab_column]
# New graph, new session with warm-starting.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
cols_to_vars = {}
with variable_scope.variable_scope("", partitioner=_partitioner):
# Create the variables.
@@ -933,7 +1051,7 @@ class WarmStartingUtilTest(test.TestCase):
# Save checkpoint from which to warm-start.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
variable_scope.get_variable(
"linear_model/sc_vocab_embedding/embedding_weights",
initializer=[[0.5, 0.4], [1., 1.1], [2., 2.2], [3., 3.3]])
@@ -957,7 +1075,7 @@ class WarmStartingUtilTest(test.TestCase):
all_deep_cols = [emb_vocab]
# New graph, new session with warm-starting.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
cols_to_vars = {}
with variable_scope.variable_scope("", partitioner=_partitioner):
# Create the variables.
@@ -1015,7 +1133,7 @@ class WarmStartingUtilTest(test.TestCase):
# Unused variable names raises ValueError.
with ops.Graph().as_default():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = variable_scope.get_variable(
"x",
shape=[4, 1],
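
The switch from `assertAllEqual` to `assertAllClose` in these tests reflects
that remapped values round-trip through float32 checkpoints, so exact
equality is too strict. A standalone illustration of the difference:

    import numpy as np
    from tensorflow.python.platform import test

    class ToleranceExampleTest(test.TestCase):

      def testCloseButNotEqual(self):
        a = np.float32(0.1) + np.float32(0.2)  # 0.30000001192...
        self.assertAllClose(0.3, a)    # passes: within default tolerance
        # self.assertAllEqual(0.3, a)  # would fail on the last few bits

    if __name__ == '__main__':
      test.main()
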
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index faae0d89c3..2968ca9c07 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -62,6 +62,10 @@ def _is_namedtuple(instance, strict=False):
return _pywrap_tensorflow.IsNamedtuple(instance, strict)
+# See the swig file (util.i) for documentation.
+_is_mapping = _pywrap_tensorflow.IsMapping
+
+
def _sequence_like(instance, args):
"""Converts the sequence `args` to the same type as `instance`.
@@ -73,7 +77,7 @@ def _sequence_like(instance, args):
Returns:
`args` with the type of `instance`.
"""
- if isinstance(instance, (dict, _collections.Mapping)):
+ if _is_mapping(instance):
# Pack dictionaries in a deterministic order by sorting the keys.
# Notice this means that we ignore the original order of `OrderedDict`
# instances. This is intentional, to avoid potential bugs caused by mixing
@@ -89,7 +93,7 @@ def _sequence_like(instance, args):
def _yield_value(iterable):
- if isinstance(iterable, (dict, _collections.Mapping)):
+ if _is_mapping(iterable):
# Iterate through dictionaries in a deterministic order by sorting the
# keys. Notice this means that we ignore the original order of `OrderedDict`
# instances. This is intentional, to avoid potential bugs caused by mixing
@@ -102,53 +106,16 @@ def _yield_value(iterable):
yield value
-def is_sequence(seq):
- """Returns a true if its input is a collections.Sequence (except strings).
+# See the swig file (util.i) for documentation.
+is_sequence = _pywrap_tensorflow.IsSequence
- Args:
- seq: an input sequence.
- Returns:
- True if the sequence is a not a string and is a collections.Sequence or a
- dict.
- """
- return _pywrap_tensorflow.IsSequence(seq)
+# See the swig file (util.i) for documentation.
+flatten = _pywrap_tensorflow.Flatten
-def flatten(nest):
- """Returns a flat list from a given nested structure.
-
- If `nest` is not a sequence, tuple, or dict, then returns a single-element
- list: `[nest]`.
-
- In the case of dict instances, the sequence consists of the values, sorted by
- key to ensure deterministic behavior. This is true also for `OrderedDict`
- instances: their sequence order is ignored, the sorting order of keys is
- used instead. The same convention is followed in `pack_sequence_as`. This
- correctly repacks dicts and `OrderedDict`s after they have been flattened,
- and also allows flattening an `OrderedDict` and then repacking it back using
- a corresponding plain dict, or vice-versa.
- Dictionaries with non-sortable keys cannot be flattened.
-
- Users must not modify any collections used in `nest` while this function is
- running.
-
- Args:
- nest: an arbitrarily nested structure or a scalar object. Note, numpy
- arrays are considered scalars.
-
- Returns:
- A Python list, the flattened version of the input.
-
- Raises:
- TypeError: The nest is or contains a dict with non-sortable keys.
- """
- return _pywrap_tensorflow.Flatten(nest)
-
-
-def _same_namedtuples(nest1, nest2):
- """Returns True if the two namedtuples have the same name and fields."""
- return _pywrap_tensorflow.SameNamedtuples(nest1, nest2)
+# See the swig file (util.i) for documentation.
+_same_namedtuples = _pywrap_tensorflow.SameNamedtuples
def assert_same_structure(nest1, nest2, check_types=True):
@@ -311,14 +278,17 @@ def pack_sequence_as(structure, flat_sequence):
% len(flat_sequence))
return flat_sequence[0]
- flat_structure = flatten(structure)
- if len(flat_structure) != len(flat_sequence):
- raise ValueError(
- "Could not pack sequence. Structure had %d elements, but flat_sequence "
- "had %d elements. Structure: %s, flat_sequence: %s."
- % (len(flat_structure), len(flat_sequence), structure, flat_sequence))
-
- _, packed = _packed_nest_with_indices(structure, flat_sequence, 0)
+ try:
+ final_index, packed = _packed_nest_with_indices(structure, flat_sequence, 0)
+ if final_index < len(flat_sequence):
+ raise IndexError
+ except IndexError:
+ flat_structure = flatten(structure)
+ if len(flat_structure) != len(flat_sequence):
+ raise ValueError(
+ "Could not pack sequence. Structure had %d elements, but "
+ "flat_sequence had %d elements. Structure: %s, flat_sequence: %s." %
+ (len(flat_structure), len(flat_sequence), structure, flat_sequence))
return _sequence_like(structure, packed)
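
The fast path above only falls back to the explicit length check when
`_packed_nest_with_indices` runs out of elements or leaves some unconsumed;
the observable semantics are unchanged. A quick sketch:

    from tensorflow.python.util import nest

    structure = {"b": [1, 2], "a": 3}
    nest.flatten(structure)  # [3, 1, 2]: dict values in sorted-key order
    nest.pack_sequence_as(structure, [10, 20, 30])
    # => {"a": 10, "b": [20, 30]}
    # A flat list of the wrong length still raises the same ValueError,
    # now constructed only on the failure path.
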
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index 2369eb610e..ef503137d1 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -461,7 +461,7 @@ class NestTest(parameterized.TestCase, test.TestCase):
inp_b: (np.random.randn(3, 4), np.random.randn(3, 7))
}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output_np = sess.run(output, feed_dict=feed_dict)
self.assertAllClose(output_np[0],
feed_dict[inp_a][0] + feed_dict[inp_b][0])
diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py
index 274f32c21f..a5ac430ce7 100644
--- a/tensorflow/python/util/tf_export.py
+++ b/tensorflow/python/util/tf_export.py
@@ -136,11 +136,14 @@ class api_export(object): # pylint: disable=invalid-name
has no effect on exporting a constant.
api_name: Name of the API you want to generate (e.g. `tensorflow` or
`estimator`). Default is `tensorflow`.
+ allow_multiple_exports: Allow symbol to be exported multiple times under
+ different names.
"""
self._names = args
self._names_v1 = kwargs.get('v1', args)
self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME)
self._overrides = kwargs.get('overrides', [])
+ self._allow_multiple_exports = kwargs.get('allow_multiple_exports', False)
def __call__(self, func):
"""Calls this decorator.
@@ -173,9 +176,10 @@ class api_export(object): # pylint: disable=invalid-name
# __dict__ instead of using hasattr to verify that subclasses have
# their own _tf_api_names as opposed to just inheriting it.
if api_names_attr in func.__dict__:
- raise SymbolAlreadyExposedError(
- 'Symbol %s is already exposed as %s.' %
- (func.__name__, getattr(func, api_names_attr))) # pylint: disable=protected-access
+ if not self._allow_multiple_exports:
+ raise SymbolAlreadyExposedError(
+ 'Symbol %s is already exposed as %s.' %
+ (func.__name__, getattr(func, api_names_attr))) # pylint: disable=protected-access
setattr(func, api_names_attr, names)
def export_constant(self, module_name, name):
@@ -213,4 +217,5 @@ class api_export(object): # pylint: disable=invalid-name
tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME)
-estimator_export = functools.partial(tf_export, api_name=ESTIMATOR_API_NAME)
+estimator_export = functools.partial(
+ api_export, api_name=ESTIMATOR_API_NAME, allow_multiple_exports=True)
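
A sketch of the export behavior this change relaxes; the symbol names are
hypothetical:

    from tensorflow.python.util import tf_export

    @tf_export.tf_export('example.my_fn')
    def my_fn():
      return 42

    try:
      # A second export of the same callable raises, because plain tf_export
      # leaves allow_multiple_exports at its False default...
      tf_export.tf_export('example.my_fn_alias')(my_fn)
    except tf_export.SymbolAlreadyExposedError:
      pass
    # ...whereas estimator_export is now built with
    # allow_multiple_exports=True, so repeated exports are tolerated.
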
diff --git a/tensorflow/python/util/tf_should_use_test.py b/tensorflow/python/util/tf_should_use_test.py
index 16fa1f547d..fedbe1dff6 100644
--- a/tensorflow/python/util/tf_should_use_test.py
+++ b/tensorflow/python/util/tf_should_use_test.py
@@ -106,7 +106,7 @@ class TfShouldUseTest(test.TestCase):
def return_const(value):
return constant_op.constant(value, name='blah3')
with reroute_error() as (error, _):
- with self.test_session():
+ with self.cached_session():
return_const(0.0)
# Creating another op and executing it does not mark the
# unused op as being "used".
@@ -124,7 +124,8 @@ class TfShouldUseTest(test.TestCase):
@tf_should_use.should_use_result
def return_const(value):
return constant_op.constant(value, name='blah3')
- with self.test_session():
+
+ with self.cached_session():
return_const(0.0).mark_used()
if __name__ == '__main__':
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index ebb72079ef..562bbdcfeb 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -470,12 +470,14 @@ void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg,
// Leaves `error_msg` empty if structures matched. Else, fills `error_msg`
// with appropriate error and sets `is_type_error` to true iff
// the error to be raised should be TypeError.
-bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types,
- string* error_msg, bool* is_type_error) {
+bool AssertSameStructureHelper(
+ PyObject* o1, PyObject* o2, bool check_types, string* error_msg,
+ bool* is_type_error,
+ const std::function<int(PyObject*)>& is_sequence_helper) {
DCHECK(error_msg);
DCHECK(is_type_error);
- const bool is_seq1 = IsSequence(o1);
- const bool is_seq2 = IsSequence(o2);
+ const bool is_seq1 = is_sequence_helper(o1);
+ const bool is_seq2 = is_sequence_helper(o2);
if (PyErr_Occurred()) return false;
if (is_seq1 != is_seq2) {
string seq_str = is_seq1 ? PyObjectToString(o1) : PyObjectToString(o2);
@@ -487,7 +489,9 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types,
return true;
}
- // Got to scalars, so finished checking. Structures are the same.
+ // Got to objects that are considered non-sequences. Note that in the
+ // tf.data use case, lists and sparse_tensors are not considered sequences,
+ // so checking is finished and the structures are the same.
if (!is_seq1) return true;
if (check_types) {
@@ -586,7 +590,7 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types,
return false;
}
bool no_internal_errors = AssertSameStructureHelper(
- v1, v2, check_types, error_msg, is_type_error);
+ v1, v2, check_types, error_msg, is_type_error, is_sequence_helper);
Py_LeaveRecursiveCall();
if (!no_internal_errors) return false;
if (!error_msg->empty()) return true;
@@ -647,6 +651,7 @@ void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) {
}
bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
+bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
PyObject* Flatten(PyObject* nested) {
PyObject* list = PyList_New(0);
@@ -758,7 +763,32 @@ PyObject* SameNamedtuples(PyObject* o1, PyObject* o2) {
PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types) {
string error_msg;
bool is_type_error = false;
- AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error);
+ AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
+ IsSequenceHelper);
+ if (PyErr_Occurred()) {
+ // Don't hide Python exceptions while checking (e.g. errors fetching keys
+ // from custom mappings).
+ return nullptr;
+ }
+ if (!error_msg.empty()) {
+ PyErr_SetString(
+ is_type_error ? PyExc_TypeError : PyExc_ValueError,
+ tensorflow::strings::StrCat(
+ "The two structures don't have the same nested structure.\n\n",
+ "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
+ PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
+ .c_str());
+ return nullptr;
+ }
+ Py_RETURN_NONE;
+}
+
+PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
+ bool check_types) {
+ string error_msg;
+ bool is_type_error = false;
+ AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
+ IsSequenceForDataHelper);
if (PyErr_Occurred()) {
// Don't hide Python exceptions while checking (e.g. errors fetching keys
// from custom mappings).
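
Threading the sequence predicate through as `is_sequence_helper` lets the
same traversal back both the generic checker and the `tf.data`-specific one.
A sketch of the intended Python-visible difference, assuming
`tensorflow.python.data.util.nest` is wired to `AssertSameStructureForData`:

    from tensorflow.python.data.util import nest as data_nest
    from tensorflow.python.util import nest

    # For tf.data, lists are atoms, so these structures match even though
    # the inner lists have different lengths:
    data_nest.assert_same_structure((1, [2, 3]), (4, [5, 6, 7]))

    # The generic checker treats lists as sequences, so the same call on
    # nest.assert_same_structure would raise ValueError.
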
diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h
index 41dcc969f8..343605285e 100644
--- a/tensorflow/python/util/util.h
+++ b/tensorflow/python/util/util.h
@@ -47,6 +47,15 @@ bool IsSequence(PyObject* o);
// True if `instance` is a `namedtuple`.
PyObject* IsNamedtuple(PyObject* o, bool strict);
+// Returns true if its input is a collections.Mapping.
+//
+// Args:
+// o: the input to be checked.
+//
+// Returns:
+// True if the input subclasses collections.Mapping.
+bool IsMapping(PyObject* o);
+
// Implements the same interface as tensorflow.util.nest._same_namedtuples
// Returns Py_True iff the two namedtuples have the same name and fields.
// Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have
@@ -135,16 +144,20 @@ void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class);
// 1. It removes support for lists as a level of nesting in nested structures.
// 2. It adds support for `SparseTensorValue` as an atomic element.
-// IsSequence specialized for the data package. Additional comments about
-// difference in functionality can be found in nest.py in tensorflow.data.util
-// and in the comments for Flatten above.
+// IsSequence specialized for `tf.data`. Additional comments about
+// difference in functionality can be found in nest.py in
+// `tensorflow.python.data.util` and in the comments for Flatten above.
bool IsSequenceForData(PyObject* o);
-// IsSequence specialized for the data package. Additional comments about
-// difference in functionality can be found in nest.py in tensorflow.data.util
-// and in the comments for Flatten above.
+// Flatten specialized for `tf.data`. Additional comments about
+// difference in functionality can be found in nest.py in
+// `tensorflow.python.data.util` and in the comments for Flatten above.
PyObject* FlattenForData(PyObject* nested);
+// AssertSameStructure specialized for `tf.data`.
+PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
+ bool check_types);
+
} // namespace swig
} // namespace tensorflow
diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i
index 6ad1484295..6d336ac39d 100644
--- a/tensorflow/python/util/util.i
+++ b/tensorflow/python/util/util.i
@@ -37,18 +37,70 @@ limitations under the License.
%unignore tensorflow::swig::RegisterSparseTensorValueClass;
%noexception tensorflow::swig::RegisterSparseTensorValueClass;
+%feature("docstring") tensorflow::swig::IsSequence
+"""Returns a true if its input is a collections.Sequence (except strings).
+
+Args:
+ seq: an input sequence.
+
+Returns:
+ True if the input is not a string and is a collections.Sequence or a
+ dict.
+"""
%unignore tensorflow::swig::IsSequence;
%noexception tensorflow::swig::IsSequence;
%unignore tensorflow::swig::IsNamedtuple;
%noexception tensorflow::swig::IsNamedtuple;
+%feature("docstring") tensorflow::swig::IsMapping
+"""Returns True iff `instance` is a `collections.Mapping`.
+
+Args:
+ instance: An instance of a Python object.
+
+Returns:
+ True if `instance` is a `collections.Mapping`.
+"""
+%unignore tensorflow::swig::IsMapping;
+%noexception tensorflow::swig::IsMapping;
+
+%feature("docstring") tensorflow::swig::SameNamedtuples
+"Returns True if the two namedtuples have the same name and fields."
%unignore tensorflow::swig::SameNamedtuples;
%noexception tensorflow::swig::SameNamedtuples;
%unignore tensorflow::swig::AssertSameStructure;
%noexception tensorflow::swig::AssertSameStructure;
+%feature("docstring") tensorflow::swig::Flatten
+"""Returns a flat list from a given nested structure.
+
+If `nest` is not a sequence, tuple, or dict, then returns a single-element
+list: `[nest]`.
+
+In the case of dict instances, the sequence consists of the values, sorted by
+key to ensure deterministic behavior. This is true also for `OrderedDict`
+instances: their sequence order is ignored, the sorting order of keys is
+used instead. The same convention is followed in `pack_sequence_as`. This
+correctly repacks dicts and `OrderedDict`s after they have been flattened,
+and also allows flattening an `OrderedDict` and then repacking it back using
+a corresponding plain dict, or vice-versa.
+Dictionaries with non-sortable keys cannot be flattened.
+
+Users must not modify any collections used in `nest` while this function is
+running.
+
+Args:
+ nest: an arbitrarily nested structure or a scalar object. Note, numpy
+ arrays are considered scalars.
+
+Returns:
+ A Python list, the flattened version of the input.
+
+Raises:
+ TypeError: The nest is or contains a dict with non-sortable keys.
+"""
%unignore tensorflow::swig::Flatten;
%noexception tensorflow::swig::Flatten;
@@ -58,6 +110,9 @@ limitations under the License.
%unignore tensorflow::swig::FlattenForData;
%noexception tensorflow::swig::FlattenForData;
+%unignore tensorflow::swig::AssertSameStructureForData;
+%noexception tensorflow::swig::AssertSameStructureForData;
+
%include "tensorflow/python/util/util.h"
%unignoreall
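
One visible effect of the %feature("docstring") blocks: the C++-backed
callables now carry Python docstrings. A quick check, assuming the standard
pywrap module path:

    from tensorflow.python import pywrap_tensorflow

    pywrap_tensorflow.IsMapping({'a': 1})       # True
    pywrap_tensorflow.IsMapping([('a', 1)])     # False: a list is not a Mapping
    print(pywrap_tensorflow.IsMapping.__doc__)  # docstring from util.i
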
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index 7f851e3646..f25ed700d6 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -41,6 +41,7 @@ limitations under the License.
#define TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
#include <complex>
+#include <vector>
#include "tensorflow/stream_executor/host_or_device_scalar.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 55408ab9ab..3c533c7f99 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -3275,6 +3275,26 @@ port::Status CudnnSupport::DoConvolveBackwardFilterImpl(
"This configuration potentially produces incorrect results.");
}());
+ // Zero out the result buffer for strided conv backward filter for NHWC
+ // layouts. cuDNN 7.1.4 and 7.2 have a non-deterministic bug if the buffer
+ // is not zeroed.
+ //
+ // The wrong result caused by the bug is very flaky: a run may need to be
+ // repeated up to 20 times to produce a mismatch.
+ //
+ // TODO(timshen): add an nvbugs link.
+ if (CUDNN_VERSION >= 7100 &&
+ algorithm_config.algorithm().algo_id() ==
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 &&
+ cudnn_type == CUDNN_DATA_HALF &&
+ input_descriptor.layout() == dnn::DataLayout::kBatchYXDepth &&
+ filter_descriptor.layout() == dnn::FilterLayout::kOutputYXInput &&
+ output_descriptor.layout() == dnn::DataLayout::kBatchYXDepth &&
+ (convolution_descriptor.vertical_filter_stride() > 1 ||
+ convolution_descriptor.horizontal_filter_stride() > 1)) {
+ stream->ThenMemZero(backward_filter_data, backward_filter_data->size());
+ }
+
RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardFilter(
cudnn.handle(),
/*alpha=*/alpha,
@@ -3894,7 +3914,7 @@ bool CudnnSupport::DoDepthConcatenate(
for (size_t i = 0; i < input_data.size(); ++i) {
const auto& dimensions = input_dimensions[i];
tmp.resize(dimensions.ElementCount());
- stream->ThenMemcpyD2H<float>(*input_data[i], &tmp);
+ stream->ThenMemcpyD2H<float>(*input_data[i], absl::MakeSpan(tmp));
port::Status block_status = stream->BlockHostUntilDone();
if (!block_status.ok()) {
LOG(ERROR) << "BlockHostUntilDone failed: " << block_status;
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
index 73f05b94db..e30f50ea2a 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
@@ -164,8 +164,8 @@ bool CUDAExecutor::FindOnDiskForComputeCapability(
VLOG(2) << "could not find compute-capability specific file at: "
<< cc_specific;
- if (port::FileExists(filename.ToString()).ok()) {
- *found_filename = filename.ToString();
+ if (port::FileExists(string(filename)).ok()) {
+ *found_filename = string(filename);
return true;
}
diff --git a/tensorflow/stream_executor/dso_loader.cc b/tensorflow/stream_executor/dso_loader.cc
index 114143b3ab..ea5dffd15e 100644
--- a/tensorflow/stream_executor/dso_loader.cc
+++ b/tensorflow/stream_executor/dso_loader.cc
@@ -121,7 +121,7 @@ static mutex& GetRpathMutex() {
/* static */ void DsoLoader::RegisterRpath(port::StringPiece path) {
mutex_lock lock{GetRpathMutex()};
- GetRpaths()->push_back(path.ToString());
+ GetRpaths()->emplace_back(path);
}
/* static */ port::Status DsoLoader::GetDsoHandle(port::StringPiece path,
@@ -131,7 +131,7 @@ static mutex& GetRpathMutex() {
return port::Status(port::error::INVALID_ARGUMENT,
"Only LoadKind::kLocal is currently supported");
}
- string path_string = path.ToString();
+ string path_string(path);
port::Status s =
port::Env::Default()->LoadLibrary(path_string.c_str(), dso_handle);
if (!s.ok()) {
@@ -154,7 +154,7 @@ static mutex& GetRpathMutex() {
/* static */ string DsoLoader::GetBinaryDirectory(bool strip_executable_name) {
string exe_path = port::Env::Default()->GetExecutablePath();
- return strip_executable_name ? port::Dirname(exe_path).ToString() : exe_path;
+ return strip_executable_name ? string(port::Dirname(exe_path)) : exe_path;
}
// Creates a heap-allocated vector for initial rpaths.
@@ -212,7 +212,7 @@ static std::vector<string>* CreatePrimordialRpaths() {
}
attempted.push_back(candidate);
- return library_name.ToString();
+ return string(library_name);
}
/* static */ string DsoLoader::GetCudaLibraryDirPath() {
diff --git a/tensorflow/stream_executor/kernel.cc b/tensorflow/stream_executor/kernel.cc
index 7c1923da51..e84b7e6cc2 100644
--- a/tensorflow/stream_executor/kernel.cc
+++ b/tensorflow/stream_executor/kernel.cc
@@ -94,7 +94,7 @@ KernelCacheConfig KernelBase::GetPreferredCacheConfig() const {
static const char *kStubPrefix = "__device_stub_";
void KernelBase::set_name(port::StringPiece name) {
- name_ = std::string(name);
+ name_ = string(name);
port::StringPiece stubless_name = name;
if (tensorflow::str_util::StartsWith(name, kStubPrefix)) {
stubless_name.remove_prefix(strlen(kStubPrefix));
diff --git a/tensorflow/stream_executor/kernel_spec.cc b/tensorflow/stream_executor/kernel_spec.cc
index 902892af3f..1eaa080699 100644
--- a/tensorflow/stream_executor/kernel_spec.cc
+++ b/tensorflow/stream_executor/kernel_spec.cc
@@ -18,11 +18,11 @@ limitations under the License.
namespace stream_executor {
KernelLoaderSpec::KernelLoaderSpec(port::StringPiece kernelname)
- : kernelname_(std::string(kernelname)) {}
+ : kernelname_(string(kernelname)) {}
OnDiskKernelLoaderSpec::OnDiskKernelLoaderSpec(port::StringPiece filename,
port::StringPiece kernelname)
- : KernelLoaderSpec(kernelname), filename_(std::string(filename)) {}
+ : KernelLoaderSpec(kernelname), filename_(string(filename)) {}
CudaPtxOnDisk::CudaPtxOnDisk(port::StringPiece filename,
port::StringPiece kernelname)
@@ -161,7 +161,7 @@ OpenCLTextOnDisk::OpenCLTextOnDisk(port::StringPiece filename,
OpenCLTextInMemory::OpenCLTextInMemory(port::StringPiece text,
port::StringPiece kernelname)
- : KernelLoaderSpec(kernelname), text_(std::string(text)) {}
+ : KernelLoaderSpec(kernelname), text_(text) {}
OpenCLBinaryOnDisk::OpenCLBinaryOnDisk(port::StringPiece filename,
port::StringPiece kernelname)
diff --git a/tensorflow/stream_executor/lib/env.h b/tensorflow/stream_executor/lib/env.h
index 3ef8deb72e..d78bbfd425 100644
--- a/tensorflow/stream_executor/lib/env.h
+++ b/tensorflow/stream_executor/lib/env.h
@@ -32,7 +32,7 @@ inline Status FileExists(const string& filename) {
}
inline Status FileExists(const port::StringPiece& filename) {
- return Env::Default()->FileExists(std::string(filename));
+ return Env::Default()->FileExists(string(filename));
}
} // namespace port
diff --git a/tensorflow/stream_executor/lib/path.cc b/tensorflow/stream_executor/lib/path.cc
index 58a862206c..3d3da103e1 100644
--- a/tensorflow/stream_executor/lib/path.cc
+++ b/tensorflow/stream_executor/lib/path.cc
@@ -33,7 +33,7 @@ string JoinPathImpl(std::initializer_list<port::StringPiece> paths) {
if (path.empty()) continue;
if (result.empty()) {
- result = std::string(path);
+ result = string(path);
continue;
}
diff --git a/tensorflow/stream_executor/lib/statusor_internals.h b/tensorflow/stream_executor/lib/statusor_internals.h
index 09f88f5825..a159da57a2 100644
--- a/tensorflow/stream_executor/lib/statusor_internals.h
+++ b/tensorflow/stream_executor/lib/statusor_internals.h
@@ -16,7 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_INTERNALS_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_INTERNALS_H_
-
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/stream_executor/lib/status.h"
diff --git a/tensorflow/stream_executor/lib/str_util.h b/tensorflow/stream_executor/lib/str_util.h
index b02fe4f56f..e77dfcef76 100644
--- a/tensorflow/stream_executor/lib/str_util.h
+++ b/tensorflow/stream_executor/lib/str_util.h
@@ -31,7 +31,7 @@ inline string StripSuffixString(port::StringPiece str, port::StringPiece suffix)
if (tensorflow::str_util::EndsWith(str, suffix)) {
str.remove_suffix(suffix.size());
}
- return std::string(str);
+ return string(str);
}
using tensorflow::str_util::Lowercase;
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 9efd34de24..19d3b2389a 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -1959,7 +1959,9 @@ Stream *Stream::GetOrCreateSubStream() {
false);
Stream *sub_stream = sub_streams_.back().first.get();
sub_stream->Init();
- CHECK(ok_) << "sub-stream failed to be initialized";
+ if (!sub_stream->ok_) {
+ LOG(ERROR) << "sub-stream failed to be initialized";
+ }
VLOG(1) << DebugStreamPointers() << " created new sub_stream "
<< sub_stream->DebugStreamPointers();
@@ -5285,12 +5287,11 @@ Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
VLOG_CALL(PARAM(callback));
- if (ok()) {
- CheckError(parent_->HostCallback(this, callback));
- } else {
+ if (!ok()) {
LOG(INFO) << DebugStreamPointers()
<< " was in error state before adding host callback";
}
+ CheckError(parent_->HostCallback(this, std::move(callback)));
return *this;
}
@@ -5298,12 +5299,11 @@ Stream &Stream::ThenDoHostCallbackWithStatus(
std::function<port::Status()> callback) {
VLOG_CALL(PARAM(callback));
- if (ok()) {
- CheckError(parent_->HostCallback(this, std::move(callback)));
- } else {
- LOG(WARNING) << "stream " << DebugStreamPointers()
- << " was in error state before adding host callback";
+ if (!ok()) {
+ LOG(INFO) << DebugStreamPointers()
+ << " was in error state before adding host callback";
}
+ CheckError(parent_->HostCallback(this, std::move(callback)));
return *this;
}
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 14e678d1ca..adac895a17 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -31,6 +31,10 @@ load(
"//third_party/mkl_dnn:build_defs.bzl",
"if_mkl_open_source_only",
)
+load(
+ "//third_party/ngraph:build_defs.bzl",
+ "if_ngraph",
+)
def register_extension_info(**kwargs):
pass
@@ -233,6 +237,7 @@ def tf_copts(android_optimization_level_override = "-O2", is_external = False):
if_tensorrt(["-DGOOGLE_TENSORRT=1"]) +
if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML"]) +
if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) +
+ if_ngraph(["-DINTEL_NGRAPH=1"]) +
if_mkl_lnx_x64(["-fopenmp"]) +
if_android_arm(["-mfpu=neon"]) +
if_linux_x86_64(["-msse3"]) +
@@ -324,11 +329,16 @@ def tf_binary_additional_srcs():
],
)
+def _linux_kernel_dso_name(kernel_build_target):
+ """Given a build target, construct the dso name for linux."""
+ parts = kernel_build_target.split(":")
+ return "%s:libtfkernel_%s.so" % (parts[0], parts[1])
+
# Helper functions to add kernel dependencies to tf binaries when using dynamic
# kernel linking.
def tf_binary_dynamic_kernel_dsos(kernels):
return if_dynamic_kernels(
- extra_deps = ["libtfkernel_%s.so" % clean_dep(k) for k in kernels],
+ extra_deps = [_linux_kernel_dso_name(k) for k in kernels],
otherwise = [],
)
@@ -391,7 +401,7 @@ def tf_cc_binary(
srcs = srcs + tf_binary_additional_srcs(),
deps = deps + tf_binary_dynamic_kernel_deps(kernels) + if_mkl_ml(
[
- "//third_party/mkl:intel_binary_blob",
+ clean_dep("//third_party/mkl:intel_binary_blob"),
],
),
data = data + tf_binary_dynamic_kernel_dsos(kernels),
@@ -729,7 +739,7 @@ def tf_cc_test(
}) + linkopts + _rpath_linkopts(name),
deps = deps + tf_binary_dynamic_kernel_deps(kernels) + if_mkl_ml(
[
- "//third_party/mkl:intel_binary_blob",
+ clean_dep("//third_party/mkl:intel_binary_blob"),
],
),
data = data + tf_binary_dynamic_kernel_dsos(kernels),
@@ -791,6 +801,7 @@ def tf_cuda_cc_test(
extra_copts = [],
linkstatic = 0,
args = [],
+ kernels = [],
linkopts = []):
tf_cc_test(
name = name,
@@ -803,6 +814,7 @@ def tf_cuda_cc_test(
linkstatic = linkstatic,
linkopts = linkopts,
args = args,
+ kernels = kernels,
)
tf_cc_test(
name = name,
@@ -824,6 +836,7 @@ def tf_cuda_cc_test(
extra_copts = extra_copts,
linkopts = linkopts,
args = args,
+ kernels = kernels,
)
register_extension_info(
@@ -879,6 +892,7 @@ def tf_cc_tests(
size = "medium",
args = None,
linkopts = [],
+ kernels = [],
nocopts = None):
for src in srcs:
tf_cc_test(
@@ -891,6 +905,7 @@ def tf_cc_tests(
args = args,
linkopts = linkopts,
nocopts = nocopts,
+ kernels = kernels,
)
def tf_cc_test_mkl(
@@ -938,8 +953,9 @@ def tf_cc_tests_gpu(
linkstatic = 0,
tags = [],
size = "medium",
+ kernels = [],
args = None):
- tf_cc_tests(srcs, deps, linkstatic, tags = tags, size = size, args = args)
+ tf_cc_tests(srcs, deps, linkstatic, tags = tags, size = size, kernels = kernels, args = args)
def tf_cuda_cc_tests(
srcs,
@@ -949,6 +965,7 @@ def tf_cuda_cc_tests(
size = "medium",
linkstatic = 0,
args = None,
+ kernels = [],
linkopts = []):
for src in srcs:
tf_cuda_cc_test(
@@ -959,6 +976,7 @@ def tf_cuda_cc_tests(
size = size,
linkstatic = linkstatic,
args = args,
+ kernels = kernels,
linkopts = linkopts,
)
@@ -1347,12 +1365,13 @@ def transitive_hdrs(name, deps = [], **kwargs):
# Create a header only library that includes all the headers exported by
# the libraries in deps.
-def cc_header_only_library(name, deps = [], includes = [], **kwargs):
+def cc_header_only_library(name, deps = [], includes = [], extra_deps = [], **kwargs):
_transitive_hdrs(name = name + "_gather", deps = deps)
native.cc_library(
name = name,
hdrs = [":" + name + "_gather"],
includes = includes,
+ deps = extra_deps,
**kwargs
)
@@ -1649,17 +1668,17 @@ def tf_py_wrap_cc(
# Note that this only works on Windows. See the definition of
# //third_party/tensorflow/tools/pip_package:win_pip_package_marker for specific reasons.
# 2. When --define=no_tensorflow_py_deps=false (by default), it's a normal py_test.
-def py_test(deps = [], data = [], **kwargs):
+def py_test(deps = [], data = [], kernels = [], **kwargs):
native.py_test(
# TODO(jlebar): Ideally we'd use tcmalloc here.
deps = select({
"//conditions:default": deps,
clean_dep("//tensorflow:no_tensorflow_py_deps"): [],
- }),
+ }) + tf_binary_dynamic_kernel_deps(kernels),
data = data + select({
"//conditions:default": [],
clean_dep("//tensorflow:no_tensorflow_py_deps"): ["//tensorflow/tools/pip_package:win_pip_package_marker"],
- }),
+ }) + tf_binary_dynamic_kernel_dsos(kernels),
**kwargs
)
@@ -1678,6 +1697,7 @@ def tf_py_test(
tags = [],
shard_count = 1,
additional_deps = [],
+ kernels = [],
flaky = 0,
xla_enabled = False,
grpc_enabled = False):
@@ -1694,6 +1714,7 @@ def tf_py_test(
tags = tags,
visibility = [clean_dep("//tensorflow:internal")],
shard_count = shard_count,
+ kernels = kernels,
data = data,
deps = [
clean_dep("//tensorflow/python:extra_py_tests_deps"),
@@ -1717,6 +1738,7 @@ def cuda_py_test(
args = [],
shard_count = 1,
additional_deps = [],
+ kernels = [],
tags = [],
flaky = 0,
xla_enabled = False,
@@ -1732,6 +1754,7 @@ def cuda_py_test(
tags = test_tags,
shard_count = shard_count,
additional_deps = additional_deps,
+ kernels = kernels,
flaky = flaky,
xla_enabled = xla_enabled,
grpc_enabled = grpc_enabled,
@@ -1751,6 +1774,7 @@ def sycl_py_test(
args = [],
shard_count = 1,
additional_deps = [],
+ kernels = [],
tags = [],
flaky = 0,
xla_enabled = False,
@@ -1766,6 +1790,7 @@ def sycl_py_test(
tags = test_tags,
shard_count = shard_count,
additional_deps = additional_deps,
+ kernels = kernels,
flaky = flaky,
xla_enabled = xla_enabled,
grpc_enabled = grpc_enabled,
@@ -1781,6 +1806,7 @@ def py_tests(
srcs,
size = "medium",
additional_deps = [],
+ kernels = [],
data = [],
tags = [],
shard_count = 1,
@@ -1800,6 +1826,7 @@ def py_tests(
shard_count = shard_count,
data = data,
additional_deps = additional_deps,
+ kernels = kernels,
xla_enabled = xla_enabled,
grpc_enabled = grpc_enabled,
)
@@ -1809,6 +1836,7 @@ def cuda_py_tests(
srcs,
size = "medium",
additional_deps = [],
+ kernels = [],
data = [],
shard_count = 1,
tags = [],
@@ -1825,6 +1853,7 @@ def cuda_py_tests(
tags = test_tags,
shard_count = shard_count,
prefix = prefix,
+ kernels = kernels,
xla_enabled = xla_enabled,
grpc_enabled = grpc_enabled,
)
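Note on the `_linux_kernel_dso_name` helper added above: it only rewrites the label string, mapping a kernel target to its shared-object target. A standalone sketch of the same mapping, runnable as ordinary Python (Starlark shares this syntax); the kernel label is illustrative:

    def _linux_kernel_dso_name(kernel_build_target):
        """Given a build target, construct the dso name for linux."""
        parts = kernel_build_target.split(":")
        return "%s:libtfkernel_%s.so" % (parts[0], parts[1])

    print(_linux_kernel_dso_name("//tensorflow/core/kernels:bias_op"))
    # //tensorflow/core/kernels:libtfkernel_bias_op.so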
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt
index d23b3bd0ca..15e0ab76b6 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt
@@ -17,7 +17,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\'], "
+ argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\', \'MEAN\'], "
}
member_method {
name: "apply_grad"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
index eb41deee13..9f6dcd8fdb 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
@@ -9,16 +9,14 @@ tf_proto {
type: TYPE_STRING
}
field {
- name: "client_handles_error_formatting"
- number: 2
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
name: "executor_type"
number: 3
label: LABEL_OPTIONAL
type: TYPE_STRING
}
+ reserved_range {
+ start: 2
+ end: 3
+ }
}
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
index e565b903d2..f3a515163d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
@@ -132,17 +132,15 @@ tf_proto {
type: TYPE_STRING
}
field {
- name: "client_handles_error_formatting"
- number: 2
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
name: "executor_type"
number: 3
label: LABEL_OPTIONAL
type: TYPE_STRING
}
+ reserved_range {
+ start: 2
+ end: 3
+ }
}
}
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt
index cbf655498c..2f4257a66a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt
@@ -4,7 +4,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'persistent\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ argspec: "args=[\'self\', \'persistent\', \'watch_accessed_variables\'], varargs=None, keywords=None, defaults=[\'False\', \'True\'], "
}
member_method {
name: "gradient"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt
index 2260279ad2..39ff336c4f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt
@@ -17,7 +17,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\'], "
+ argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\', \'MEAN\'], "
}
member_method {
name: "apply_grad"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor.pbtxt
index eac236d498..3add49e90d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor.pbtxt
@@ -24,6 +24,10 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "values"
mtype: "<type \'property\'>"
}
@@ -32,6 +36,10 @@ tf_class {
argspec: "args=[\'self\', \'indices\', \'values\', \'dense_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "consumers"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "eval"
argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
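SparseTensor gains a `shape` property and a `consumers` method above, mirroring `tf.Tensor`. A graph-mode sketch; in graph mode `consumers()` is assumed to list the ops reading this tensor's components, empty here:

    import tensorflow as tf

    st = tf.SparseTensor(indices=[[0, 0]], values=[1.0], dense_shape=[2, 3])
    print(st.shape)        # TensorShape([2, 3]) when dense_shape is static
    print(st.consumers())  # []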
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-variable-aggregation.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-variable-aggregation.pbtxt
index 36b534af36..66a20547eb 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-variable-aggregation.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-variable-aggregation.pbtxt
@@ -10,6 +10,10 @@ tf_class {
mtype: "<enum \'VariableAggregation\'>"
}
member {
+ name: "ONLY_FIRST_TOWER"
+ mtype: "<enum \'VariableAggregation\'>"
+ }
+ member {
name: "SUM"
mtype: "<enum \'VariableAggregation\'>"
}
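The new `ONLY_FIRST_TOWER` member above complements NONE/SUM/MEAN; under a distribution strategy it keeps only the first replica's ("tower's", in the terminology of the time) value rather than combining across replicas:

    import tensorflow as tf

    # Selecting the aggregation mode; wiring it into a variable happens via
    # the aggregation= argument that variable creators accept.
    agg = tf.VariableAggregation.ONLY_FIRST_TOWER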
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-variable.pbtxt
index e841c4ad89..05698b03ee 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-variable.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-variable.pbtxt
@@ -53,15 +53,15 @@ tf_class {
}
member_method {
name: "assign"
- argspec: "args=[\'self\', \'value\', \'use_locking\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ argspec: "args=[\'self\', \'value\', \'use_locking\', \'name\', \'read_value\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'True\'], "
}
member_method {
name: "assign_add"
- argspec: "args=[\'self\', \'delta\', \'use_locking\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ argspec: "args=[\'self\', \'delta\', \'use_locking\', \'name\', \'read_value\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'True\'], "
}
member_method {
name: "assign_sub"
- argspec: "args=[\'self\', \'delta\', \'use_locking\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ argspec: "args=[\'self\', \'delta\', \'use_locking\', \'name\', \'read_value\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'True\'], "
}
member_method {
name: "count_up_to"
@@ -92,8 +92,28 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "scatter_add"
+ argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+ }
+ member_method {
+ name: "scatter_nd_add"
+ argspec: "args=[\'self\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "scatter_nd_sub"
+ argspec: "args=[\'self\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "scatter_nd_update"
+ argspec: "args=[\'self\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "scatter_sub"
- argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+ }
+ member_method {
+ name: "scatter_update"
+ argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "set_shape"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
index 834f0954d5..87745420ee 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
@@ -60,7 +60,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 4d854a4cee..6dd46365b0 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -61,7 +61,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
index 601f095a60..35b7105eba 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -61,7 +61,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
index 587829a4c0..8ae370af98 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
@@ -61,7 +61,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-classifier.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-classifier.pbtxt
index cf22e39d4c..082e26b99b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-classifier.pbtxt
@@ -32,6 +32,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "export_saved_model"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-regressor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-regressor.pbtxt
index a363bceae3..7cc4191eb3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-regressor.pbtxt
@@ -32,6 +32,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "export_saved_model"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt
index c23b04b4ef..7027e78df4 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt
@@ -32,6 +32,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "export_saved_model"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt
index 6878d28fff..d8167ea7cb 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt
@@ -32,6 +32,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "export_saved_model"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-classifier.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-classifier.pbtxt
index 0c6b7e4a82..718f415a77 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-classifier.pbtxt
@@ -32,6 +32,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "export_saved_model"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
index 9c1c072124..b23c019d6c 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
@@ -32,6 +32,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "export_saved_model"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
index 7391d4b07a..caa9e3f1de 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
@@ -32,6 +32,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "export_saved_model"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-regressor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-regressor.pbtxt
index f50e375f7c..1f5e650940 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-regressor.pbtxt
@@ -32,6 +32,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "export_saved_model"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-estimator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-estimator.pbtxt
index d72b576977..ebd3869c9b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-estimator.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-estimator.pbtxt
@@ -31,6 +31,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "export_saved_model"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
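`export_saved_model` appears alongside the older `export_savedmodel`; per the two argspecs it drops `strip_default_attrs`. A sketch in which `estimator` is assumed to be an already-trained tf.estimator.Estimator, so the export call itself is left commented:

    import tensorflow as tf

    feature_spec = {'x': tf.FixedLenFeature([1], tf.float32)}
    serving_input_receiver_fn = (
        tf.estimator.export.build_parsing_serving_input_receiver_fn(
            feature_spec))
    # export_dir = estimator.export_saved_model(
    #     '/tmp/export', serving_input_receiver_fn)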
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-classifier.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-classifier.pbtxt
index 154f171e89..53ec5a0c78 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-classifier.pbtxt
@@ -32,6 +32,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "export_saved_model"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-regressor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-regressor.pbtxt
index 4d46d1e6b6..3791162619 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-regressor.pbtxt
@@ -32,6 +32,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "export_saved_model"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-run-config.pbtxt
index bf1f94b6ae..269e18a0a7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-run-config.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-run-config.pbtxt
@@ -96,7 +96,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\', \'experimental_distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "replace"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt
index 5301b94eb3..b6942cb7ed 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt
@@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
is_instance: "<type \'tuple\'>"
member {
+ name: "axis"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "backup_initializer"
mtype: "<type \'property\'>"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.glorot_normal_initializer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.glorot_normal_initializer.pbtxt
new file mode 100644
index 0000000000..483d1f8ba0
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.glorot_normal_initializer.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.glorot_normal_initializer"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotNormal\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.glorot_uniform_initializer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.glorot_uniform_initializer.pbtxt
new file mode 100644
index 0000000000..bb8540d0fd
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.glorot_uniform_initializer.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.glorot_uniform_initializer"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotUniform\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.initializers.glorot_normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.initializers.glorot_normal.pbtxt
new file mode 100644
index 0000000000..4a81e52df9
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.initializers.glorot_normal.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.initializers.glorot_normal"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotNormal\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.initializers.glorot_uniform.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.initializers.glorot_uniform.pbtxt
new file mode 100644
index 0000000000..815dc81dff
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.initializers.glorot_uniform.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.initializers.glorot_uniform"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotUniform\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.initializers.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.initializers.pbtxt
index bc0426f2f1..d499c67d89 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.initializers.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.initializers.pbtxt
@@ -5,6 +5,14 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "glorot_normal"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "glorot_uniform"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "identity"
mtype: "<type \'type\'>"
}
@@ -45,14 +53,6 @@ tf_module {
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "glorot_normal"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
- name: "glorot_uniform"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
name: "he_normal"
argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
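The diff above moves glorot_normal/glorot_uniform from member_method to member: they are now classes (VarianceScaling subclasses) rather than bare functions, so instances carry get_config/from_config. A graph-mode sketch:

    import tensorflow as tf

    init = tf.initializers.glorot_uniform(seed=42)
    print(init.get_config())  # assumed to include at least seed and dtype
    w = tf.get_variable('w', shape=[3, 4], initializer=init)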
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.io.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.io.pbtxt
index 3a36c168aa..8938cf217b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.io.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.io.pbtxt
@@ -25,6 +25,10 @@ tf_module {
argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "parse_sequence_example"
+ argspec: "args=[\'serialized\', \'context_features\', \'sequence_features\', \'example_names\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "parse_tensor"
argspec: "args=[\'serialized\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
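`parse_sequence_example` is the batched counterpart of parse_single_sequence_example. A graph-mode sketch matching the argspec above; the feature names are illustrative:

    import tensorflow as tf

    serialized = tf.placeholder(tf.string, shape=[None])  # SequenceExamples
    parsed = tf.io.parse_sequence_example(
        serialized,
        context_features={'id': tf.FixedLenFeature([], tf.int64)},
        sequence_features={
            'tokens': tf.FixedLenSequenceFeature([], tf.int64)})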
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
index e579fe6a1a..d843194ef0 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
@@ -119,7 +119,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
index 97688fcb0f..b8e9baca71 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
@@ -124,7 +124,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-normal.pbtxt
index 23cd02c0b0..26784ce55d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-normal.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-normal.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.RandomNormal"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-uniform.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-uniform.pbtxt
index d98628f422..4110bda5f6 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-uniform.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.RandomUniform"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'-0.05\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-truncated-normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-truncated-normal.pbtxt
index 86d48257c1..0451d0d73a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-truncated-normal.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-truncated-normal.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.TruncatedNormal"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.TruncatedNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.TruncatedNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.glorot_normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.glorot_normal.pbtxt
new file mode 100644
index 0000000000..ef0815972d
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.glorot_normal.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.keras.initializers.glorot_normal"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotNormal\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.glorot_uniform.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.glorot_uniform.pbtxt
new file mode 100644
index 0000000000..439b5ada9b
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.glorot_uniform.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.keras.initializers.glorot_uniform"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotUniform\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.normal.pbtxt
index 7485772784..8d0b5c242b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.normal.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.normal.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.normal"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.pbtxt
index 8645e54302..1540c2915b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.pbtxt
@@ -45,6 +45,14 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "glorot_normal"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "glorot_uniform"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "identity"
mtype: "<type \'type\'>"
}
@@ -89,14 +97,6 @@ tf_module {
argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "glorot_normal"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
- name: "glorot_uniform"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
name: "he_normal"
argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_normal.pbtxt
index a6df1e87a3..bac8211a10 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_normal.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_normal.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.random_normal"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_uniform.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_uniform.pbtxt
index 37a0fa0d55..ab0d74d071 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_uniform.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.random_uniform"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'-0.05\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.truncated_normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.truncated_normal.pbtxt
index f97e93f0b7..358cca2b9c 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.truncated_normal.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.truncated_normal.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.truncated_normal"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.TruncatedNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.TruncatedNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.uniform.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.uniform.pbtxt
index 58186b1383..e6c731361a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.uniform.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.uniform"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'-0.05\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activation.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activation.pbtxt
index 86e328888e..5510465d7b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activation.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activation.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activity-regularization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activity-regularization.pbtxt
index b0ed545781..38ec8a0aff 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activity-regularization.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activity-regularization.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-add.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-add.pbtxt
index 42f98ed03d..41cb8e30bf 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-add.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-add.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-alpha-dropout.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-alpha-dropout.pbtxt
index 000898a4be..9a7aaa8e96 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-alpha-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-alpha-dropout.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt
index 380b49f99c..c3dd2ad046 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling2-d.pbtxt
index 82db5e6137..cc303bf7b9 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling3-d.pbtxt
index b6ff688ec3..628447ce35 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average.pbtxt
index b41290f8b0..f03c986c22 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt
index 88a033e61f..c440604aae 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool2-d.pbtxt
index c1b9b96044..a01eaf8a12 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool3-d.pbtxt
index f59f7727a3..0d6698f2ef 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt
index 7d3744ed92..f1b23be48f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt
index 3fd4ccdab2..0672cd5b7b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt
@@ -107,7 +107,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-concatenate.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-concatenate.pbtxt
index ba21b50be4..b25ae1e82e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-concatenate.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-concatenate.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
index 46f9fa2bbb..bb1918eba6 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
@@ -188,7 +188,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d.pbtxt
index c3ad326589..16e0fd5a31 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
index fd9eb43066..065bb4d35b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d.pbtxt
index 40d61688f2..543bae6fa9 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
index b8c227d725..c7ba6056f9 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d.pbtxt
index 095d35e574..072943dc2c 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d.pbtxt
index 8f99961198..222a1ef4fc 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
index 96d522a016..8f4f7918ab 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d.pbtxt
index de2824dab4..f939067178 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
index 1d563241d8..93c442bd55 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d.pbtxt
index c87e52c537..471b18ef85 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping1-d.pbtxt
index dccf5523e3..0f250a09b7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping1-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping2-d.pbtxt
index 7ac4116d92..f52128483c 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping2-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping3-d.pbtxt
index 024f72705d..98daf3bab1 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping3-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt
index 4e0233331b..64e7a9046b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt
@@ -108,7 +108,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt
index 32d46ce8f3..6fdffef776 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt
@@ -108,7 +108,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense.pbtxt
index 858486c725..3ac3825759 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt
index f65d750926..280ec8c25f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dot.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dot.pbtxt
index 2e71ef503d..560f66f9c7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dot.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dot.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dropout.pbtxt
index 42533bcd21..c0543529c3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dropout.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-e-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-e-l-u.pbtxt
index b5df169417..04eb2824b9 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-e-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-e-l-u.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-embedding.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-embedding.pbtxt
index 0ea17919a9..f400432915 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-embedding.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-embedding.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-flatten.pbtxt
index a33248bc00..ab176b441a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-flatten.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-flatten.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt
index 4ba21a25cd..c3895a0ac1 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
@@ -133,6 +133,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
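Besides the `add_weight` change, this hunk shows `GRUCell` gaining an exported `get_initial_state` method whose `inputs`, `batch_size`, and `dtype` arguments all default to None. A hypothetical usage, assuming at least one argument is supplied so the state shape can be inferred:

    import tensorflow as tf

    cell = tf.keras.layers.GRUCell(units=8)
    # Zero-filled initial state for a batch of 4; shape inferred from
    # batch_size/dtype since no inputs tensor is given.
    state = cell.get_initial_state(inputs=None, batch_size=4, dtype=tf.float32)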
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u.pbtxt
index a7a570418e..a0fe598ab9 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u.pbtxt
@@ -171,7 +171,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-dropout.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-dropout.pbtxt
index 763bc23113..55e0d7ef02 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-dropout.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-noise.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-noise.pbtxt
index 3c50a3d7f2..38fbff5e4a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-noise.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-noise.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
index ac78bdafad..5ea61d118d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
index 275282d9d2..929f48df23 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
index 0e31e6058b..2e6d59337f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
index aacd0b1791..11dca17c6d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
index c236548663..4e3e258430 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
index 6b9c0290aa..fb9166316f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
index 0d7b2211e6..278429af6f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
index d080ad6aed..87b7f6797a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
index fcb0a109da..98bf96fa0c 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
index 1d0e22abd0..935a69ab2f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
index 653c9f547b..c9d4158d1c 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
index cdbaf82cf6..9953102ff9 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt
index 230c5e9034..2617f5a95f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
index 511456e740..e9f6ef45aa 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
@@ -133,6 +133,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
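
The second hunk above records a new public method on `LSTMCell`: `get_initial_state(inputs=None, batch_size=None, dtype=None)`. A short sketch of how such a method is typically called, assuming the usual Keras behavior of returning zero-filled states sized from the cell (the golden file records only the signature, not the return structure):

    import tensorflow as tf

    cell = tf.keras.layers.LSTMCell(4)
    # Shape information can come from an `inputs` tensor, or be supplied
    # directly via `batch_size` and `dtype`, per the recorded argspec.
    initial_state = cell.get_initial_state(batch_size=2, dtype=tf.float32)
    # For an LSTM cell this is expected to be a pair of zero tensors
    # [h, c], each of shape (2, 4).
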
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m.pbtxt
index 4a3492ebd6..ecdbf48157 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m.pbtxt
@@ -171,7 +171,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt
index 2dff7a6de4..2e0b6bac24 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer.pbtxt
index 7efa29be77..1e93d1118a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer.pbtxt
@@ -97,7 +97,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-leaky-re-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
index 0ca8e0b52c..bfd36012a7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt
index ff19dcc3a3..5ad5990d7e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt
index 3c278fead6..40d03369a5 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-masking.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-masking.pbtxt
index 850ecff974..86666b51bb 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-masking.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-masking.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt
index 7c69e31f9a..238d96cca6 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool2-d.pbtxt
index fba42642d7..85f23df671 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool3-d.pbtxt
index 9c277411ea..235806b965 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt
index 7c2f6ccc8a..4a45bf7997 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling2-d.pbtxt
index 802178dba6..fda2562fc8 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling3-d.pbtxt
index e870dfe9ad..71d2d09a8d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-maximum.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-maximum.pbtxt
index c1337ce0cb..12949b39a6 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-maximum.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-maximum.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-minimum.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-minimum.pbtxt
index ed27a62765..ab16d0021e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-minimum.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-minimum.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multiply.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multiply.pbtxt
index b9f05cb3e5..61ccbf5962 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multiply.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multiply.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-p-re-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-p-re-l-u.pbtxt
index 336d9f76fb..ce2320d703 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-p-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-p-re-l-u.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-permute.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-permute.pbtxt
index 46282217e0..69848af8cf 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-permute.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-permute.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt
index 42cd7e87ee..2b6e8af11d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt
@@ -102,7 +102,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-re-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-re-l-u.pbtxt
index 4d3de58bd1..413f45f018 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-re-l-u.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-repeat-vector.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-repeat-vector.pbtxt
index 9f094a877a..9c61ff6027 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-repeat-vector.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-repeat-vector.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-reshape.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-reshape.pbtxt
index 2f519a2438..baa91804c4 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-reshape.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-reshape.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv1-d.pbtxt
index 6b93116ba0..15a5d6ac9e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv1-d.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv2-d.pbtxt
index fd17115e27..be43bd5b3c 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv2-d.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
index 4b37a94478..6105992c7a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
index 5bdadca74a..1b6cf1e9ec 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
index 9dfda96fc8..29488a37f8 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
@@ -133,6 +133,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n.pbtxt
index 7b7684ccd2..182efb83b8 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n.pbtxt
@@ -159,7 +159,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt
index 3b15407fca..d29731ecf9 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
index 6d04415267..a6d7494ca7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
index 04950654d5..c36e802693 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
index c424e6dcc8..9c46cfe40f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
index 6718e36dc6..8982f78794 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
@@ -106,7 +106,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
@@ -141,6 +141,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
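
`StackedRNNCells` gains the same `get_initial_state` signature as the individual cells above. A brief sketch under the same assumptions, presumably yielding one zero-filled initial state per wrapped cell:

    import tensorflow as tf

    cells = [tf.keras.layers.SimpleRNNCell(8),
             tf.keras.layers.SimpleRNNCell(8)]
    stacked = tf.keras.layers.StackedRNNCells(cells)
    # Same argspec as on LSTMCell and SimpleRNNCell in this commit.
    states = stacked.get_initial_state(batch_size=4, dtype=tf.float32)
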
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-subtract.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-subtract.pbtxt
index 740a03367b..ec2cc50298 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-subtract.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-subtract.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
index a08c583adb..d7bc1980f3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt
index c1294fed0f..fec2de6b49 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt
@@ -103,7 +103,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling1-d.pbtxt
index dc401d3ed0..3d285e7f17 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling1-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling2-d.pbtxt
index 4b5165ae97..40a56a0c94 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling2-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling3-d.pbtxt
index 789af15fea..728eca415a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling3-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt
index 0536a7cee7..da64e77c39 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt
@@ -102,7 +102,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding1-d.pbtxt
index 8915353ec3..2f505f9293 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding1-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding2-d.pbtxt
index 6efb5ef15a..f82c77072e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding2-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding3-d.pbtxt
index 4c33c5d0bf..54e01a9917 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding3-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
index 56914e1746..472b9818df 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
@@ -119,7 +119,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
index acfb3521c0..937516eff1 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
@@ -124,7 +124,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt
index 8ba0e7480b..7ad4a32d43 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt
@@ -9,6 +9,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member_method {
+ name: "clone_model"
+ argspec: "args=[\'model\', \'input_tensors\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "load_model"
argspec: "args=[\'filepath\', \'custom_objects\', \'compile\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
}
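
`clone_model` is newly exported under `tf.keras.models`. A short sketch of its use; per the recorded argspec, `input_tensors` is optional and the clone gets freshly created inputs when it is omitted:

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
    # Rebuilds the same architecture with newly initialized weights; pass
    # input_tensors to graft the clone onto existing tensors instead.
    clone = tf.keras.models.clone_model(model)
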
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-ordered-enqueuer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-ordered-enqueuer.pbtxt
new file mode 100644
index 0000000000..e7e7d2839b
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-ordered-enqueuer.pbtxt
@@ -0,0 +1,26 @@
+path: "tensorflow.keras.utils.OrderedEnqueuer"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras.utils.data_utils.OrderedEnqueuer\'>"
+ is_instance: "<class \'tensorflow.python.keras.utils.data_utils.SequenceEnqueuer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'sequence\', \'use_multiprocessing\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], "
+ }
+ member_method {
+ name: "get"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "is_running"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "start"
+ argspec: "args=[\'self\', \'workers\', \'max_queue_size\'], varargs=None, keywords=None, defaults=[\'1\', \'10\'], "
+ }
+ member_method {
+ name: "stop"
+ argspec: "args=[\'self\', \'timeout\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
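
The new `tensorflow.keras.utils.OrderedEnqueuer` golden documents a queue that runs a `Sequence` on workers while preserving batch order. A sketch matching the recorded argspecs (`RangeSequence` is an illustrative name):

    import numpy as np
    import tensorflow as tf

    class RangeSequence(tf.keras.utils.Sequence):
        def __len__(self):
            return 4

        def __getitem__(self, idx):
            return np.full((2, 3), idx), np.full((2, 1), idx)

    enqueuer = tf.keras.utils.OrderedEnqueuer(
        RangeSequence(), use_multiprocessing=False, shuffle=False)
    enqueuer.start(workers=2, max_queue_size=10)
    batches = enqueuer.get()   # generator yielding batches in index order
    x, y = next(batches)
    enqueuer.stop()
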
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt
index 4d7a1519ce..81b91d2780 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt
@@ -13,6 +13,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "OrderedEnqueuer"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "Progbar"
mtype: "<type \'type\'>"
}
@@ -45,6 +49,10 @@ tf_module {
argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], "
}
member_method {
+ name: "get_source_inputs"
+ argspec: "args=[\'tensor\', \'layer\', \'node_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
name: "multi_gpu_model"
argspec: "args=[\'model\', \'gpus\', \'cpu_merge\', \'cpu_relocation\'], varargs=None, keywords=None, defaults=[\'True\', \'False\'], "
}
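
`get_source_inputs` is also newly exported from `tf.keras.utils`; it walks a Keras tensor's layer history back to the `Input` tensors it depends on. A minimal sketch:

    import tensorflow as tf

    inp = tf.keras.Input(shape=(3,))
    out = tf.keras.layers.Dense(2)(inp)
    sources = tf.keras.utils.get_source_inputs(out)   # -> [inp]
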
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
index e606eab919..88b8f37c4f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
@@ -152,6 +152,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
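
Each `tf.nn.rnn_cell` golden in this stretch gains the same `get_initial_state` method, aligning the cells with the Keras RNN-cell contract. A sketch of the two calling styles the argspec implies, assuming either `inputs` or an explicit `batch_size`/`dtype` pair is supplied:

    import tensorflow as tf

    cell = tf.nn.rnn_cell.BasicLSTMCell(8)
    # Infer batch size and dtype from a batch of inputs...
    state = cell.get_initial_state(inputs=tf.zeros([4, 16]))
    # ...or provide them explicitly.
    state = cell.get_initial_state(batch_size=4, dtype=tf.float32)
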
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
index 5deb02d569..a4483fefa2 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
@@ -152,6 +152,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
index 8a63b49180..381c4975d7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
@@ -151,6 +151,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
index db1aae2757..912365a28b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
@@ -155,6 +155,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
index 32fa151a8e..a4bb3219c7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
@@ -152,6 +152,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
index 30c6c2ce3b..715bfd5fc7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
@@ -152,6 +152,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
index 72b40cc9f7..b66c0f89cc 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
@@ -151,6 +151,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
index a5c2b4aefd..faeb4f3513 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
@@ -150,6 +150,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
index 61d5f04b22..caa2e60080 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
@@ -151,6 +151,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index 4de662fe33..dd9f7c49e0 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -365,6 +365,14 @@ tf_module {
mtype: "<type \'module\'>"
}
member {
+ name: "glorot_normal_initializer"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "glorot_uniform_initializer"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "graph_util"
mtype: "<type \'module\'>"
}
@@ -789,6 +797,10 @@ tf_module {
argspec: "args=[\'params\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "batch_scatter_update"
+ argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+ }
+ member_method {
name: "batch_to_space"
argspec: "args=[\'input\', \'crops\', \'block_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -1001,10 +1013,18 @@ tf_module {
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "disable_resource_variables"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "div"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "div_no_nan"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "divide"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -1029,10 +1049,18 @@ tf_module {
argspec: "args=[\'config\', \'device_policy\', \'execution_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
+ name: "enable_resource_variables"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "encode_base64"
argspec: "args=[\'input\', \'pad\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
+ name: "ensure_shape"
+ argspec: "args=[\'x\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "equal"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -1162,7 +1190,7 @@ tf_module {
}
member_method {
name: "get_local_variable"
- argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'synchronization\', \'aggregation\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\', \'None\'], "
+ argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "get_seed"
@@ -1197,14 +1225,6 @@ tf_module {
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "glorot_normal_initializer"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
- name: "glorot_uniform_initializer"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
name: "gradients"
argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\'], "
}
@@ -1273,6 +1293,10 @@ tf_module {
argspec: "args=[\'graph_def\', \'input_map\', \'return_elements\', \'name\', \'op_dict\', \'producer_op_list\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "init_scope"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "initialize_all_tables"
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'init_all_tables\'], "
}
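
This `tensorflow.pbtxt` hunk bundles several top-level changes: `glorot_normal_initializer` and `glorot_uniform_initializer` move from `member_method` to `member` entries (they are now classes), while `batch_scatter_update`, `disable_resource_variables`/`enable_resource_variables`, `div_no_nan`, `ensure_shape`, and `init_scope` are newly exported. A sketch of two of the additions plus the initializer-as-class usage:

    import tensorflow as tf

    x = tf.constant([1.0, 2.0])
    y = tf.constant([2.0, 0.0])
    safe = tf.div_no_nan(x, y)      # -> [0.5, 0.0]; a zero divisor yields 0

    t = tf.placeholder(tf.float32)
    checked = tf.ensure_shape(t, [None, 3])   # runtime check that t is (?, 3)

    init = tf.glorot_uniform_initializer()    # instantiated as a class now
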
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-options.pbtxt
index 0853716023..614ba42d3e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-options.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-options.pbtxt
@@ -8,7 +8,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'compression_type\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'compression_type\', \'flush_mode\', \'input_buffer_size\', \'output_buffer_size\', \'window_bits\', \'compression_level\', \'compression_method\', \'mem_level\', \'compression_strategy\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "get_compression_type_string"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt
index bbfe395031..ba9e651b34 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt
@@ -8,4 +8,12 @@ tf_module {
name: "cross_hashed"
argspec: "args=[\'inputs\', \'num_buckets\', \'hash_key\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], "
}
+ member_method {
+ name: "expand_dims"
+ argspec: "args=[\'sp_input\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "eye"
+ argspec: "args=[\'num_rows\', \'num_columns\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'None\'], "
+ }
}
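
`tf.sparse` gains `expand_dims` and `eye`, mirroring their dense counterparts. A sketch:

    import tensorflow as tf

    sp = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1.0, 2.0],
                         dense_shape=[2, 3])
    expanded = tf.sparse.expand_dims(sp, axis=0)   # dense_shape -> [1, 2, 3]
    identity = tf.sparse.eye(3)                    # 3x3 sparse identity
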
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt
index 4ce7cb1111..39b946b82f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt
@@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
is_instance: "<type \'tuple\'>"
member {
+ name: "axis"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "backup_initializer"
mtype: "<type \'property\'>"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt
index d23b3bd0ca..15e0ab76b6 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt
@@ -17,7 +17,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\'], "
+ argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\', \'MEAN\'], "
}
member_method {
name: "apply_grad"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt
index eb41deee13..9f6dcd8fdb 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt
@@ -9,16 +9,14 @@ tf_proto {
type: TYPE_STRING
}
field {
- name: "client_handles_error_formatting"
- number: 2
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
name: "executor_type"
number: 3
label: LABEL_OPTIONAL
type: TYPE_STRING
}
+ reserved_range {
+ start: 2
+ end: 3
+ }
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt
index e565b903d2..f3a515163d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt
@@ -132,17 +132,15 @@ tf_proto {
type: TYPE_STRING
}
field {
- name: "client_handles_error_formatting"
- number: 2
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
name: "executor_type"
number: 3
label: LABEL_OPTIONAL
type: TYPE_STRING
}
+ reserved_range {
+ start: 2
+ end: 3
+ }
}
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-fixed-length-record-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-fixed-length-record-reader.pbtxt
deleted file mode 100644
index 260c796fd6..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.-fixed-length-record-reader.pbtxt
+++ /dev/null
@@ -1,46 +0,0 @@
-path: "tensorflow.FixedLengthRecordReader"
-tf_class {
- is_instance: "<class \'tensorflow.python.ops.io_ops.FixedLengthRecordReader\'>"
- is_instance: "<class \'tensorflow.python.ops.io_ops.ReaderBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "reader_ref"
- mtype: "<type \'property\'>"
- }
- member {
- name: "supports_serialize"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'record_bytes\', \'header_bytes\', \'footer_bytes\', \'hop_bytes\', \'name\', \'encoding\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "num_records_produced"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "num_work_units_completed"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "read"
- argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "read_up_to"
- argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "reset"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "restore_state"
- argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "serialize_state"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
-}
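
`FixedLengthRecordReader` is the first of several queue-based reader classes (`IdentityReader`, `LMDBReader`, `ReaderBase`, `TFRecordReader`, `TextLineReader`, and `WholeFileReader` below) deleted from the v2 goldens. A sketch of the `tf.data` equivalent that remains in the API, with an illustrative filename:

    import tensorflow as tf

    dataset = tf.data.FixedLengthRecordDataset(["data.bin"], record_bytes=16)
    iterator = dataset.make_one_shot_iterator()
    record = iterator.get_next()   # one 16-byte record per evaluation
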
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt
index cbf655498c..2f4257a66a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt
@@ -4,7 +4,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'persistent\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ argspec: "args=[\'self\', \'persistent\', \'watch_accessed_variables\'], varargs=None, keywords=None, defaults=[\'False\', \'True\'], "
}
member_method {
name: "gradient"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-identity-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-identity-reader.pbtxt
deleted file mode 100644
index 2eda320d63..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.-identity-reader.pbtxt
+++ /dev/null
@@ -1,46 +0,0 @@
-path: "tensorflow.IdentityReader"
-tf_class {
- is_instance: "<class \'tensorflow.python.ops.io_ops.IdentityReader\'>"
- is_instance: "<class \'tensorflow.python.ops.io_ops.ReaderBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "reader_ref"
- mtype: "<type \'property\'>"
- }
- member {
- name: "supports_serialize"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "num_records_produced"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "num_work_units_completed"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "read"
- argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "read_up_to"
- argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "reset"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "restore_state"
- argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "serialize_state"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-l-m-d-b-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-l-m-d-b-reader.pbtxt
deleted file mode 100644
index f9b7e9bbca..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.-l-m-d-b-reader.pbtxt
+++ /dev/null
@@ -1,46 +0,0 @@
-path: "tensorflow.LMDBReader"
-tf_class {
- is_instance: "<class \'tensorflow.python.ops.io_ops.LMDBReader\'>"
- is_instance: "<class \'tensorflow.python.ops.io_ops.ReaderBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "reader_ref"
- mtype: "<type \'property\'>"
- }
- member {
- name: "supports_serialize"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'name\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "num_records_produced"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "num_work_units_completed"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "read"
- argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "read_up_to"
- argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "reset"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "restore_state"
- argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "serialize_state"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-reader-base.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-reader-base.pbtxt
deleted file mode 100644
index f6a3ce76a1..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.-reader-base.pbtxt
+++ /dev/null
@@ -1,45 +0,0 @@
-path: "tensorflow.ReaderBase"
-tf_class {
- is_instance: "<class \'tensorflow.python.ops.io_ops.ReaderBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "reader_ref"
- mtype: "<type \'property\'>"
- }
- member {
- name: "supports_serialize"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'reader_ref\', \'supports_serialize\'], varargs=None, keywords=None, defaults=[\'False\'], "
- }
- member_method {
- name: "num_records_produced"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "num_work_units_completed"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "read"
- argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "read_up_to"
- argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "reset"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "restore_state"
- argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "serialize_state"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt
index 2260279ad2..39ff336c4f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt
@@ -17,7 +17,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\'], "
+ argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\', \'MEAN\'], "
}
member_method {
name: "apply_grad"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor.pbtxt
index eac236d498..3add49e90d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor.pbtxt
@@ -24,6 +24,10 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "values"
mtype: "<type \'property\'>"
}
@@ -32,6 +36,10 @@ tf_class {
argspec: "args=[\'self\', \'indices\', \'values\', \'dense_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "consumers"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "eval"
argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
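
`SparseTensor` picks up a `shape` property and a `consumers` method, bringing it closer to the `tf.Tensor` surface. A sketch:

    import tensorflow as tf

    sp = tf.SparseTensor(indices=[[0, 0]], values=[1.0], dense_shape=[2, 3])
    static_shape = sp.shape    # static TensorShape([2, 3]), like tf.Tensor
    ops = sp.consumers()       # ops consuming this SparseTensor's components
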
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-t-f-record-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-t-f-record-reader.pbtxt
deleted file mode 100644
index cdf7937391..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.-t-f-record-reader.pbtxt
+++ /dev/null
@@ -1,46 +0,0 @@
-path: "tensorflow.TFRecordReader"
-tf_class {
- is_instance: "<class \'tensorflow.python.ops.io_ops.TFRecordReader\'>"
- is_instance: "<class \'tensorflow.python.ops.io_ops.ReaderBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "reader_ref"
- mtype: "<type \'property\'>"
- }
- member {
- name: "supports_serialize"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'name\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "num_records_produced"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "num_work_units_completed"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "read"
- argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "read_up_to"
- argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "reset"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "restore_state"
- argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "serialize_state"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-text-line-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-text-line-reader.pbtxt
deleted file mode 100644
index e9779f0762..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.-text-line-reader.pbtxt
+++ /dev/null
@@ -1,46 +0,0 @@
-path: "tensorflow.TextLineReader"
-tf_class {
- is_instance: "<class \'tensorflow.python.ops.io_ops.TextLineReader\'>"
- is_instance: "<class \'tensorflow.python.ops.io_ops.ReaderBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "reader_ref"
- mtype: "<type \'property\'>"
- }
- member {
- name: "supports_serialize"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'skip_header_lines\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "num_records_produced"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "num_work_units_completed"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "read"
- argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "read_up_to"
- argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "reset"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "restore_state"
- argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "serialize_state"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-variable-aggregation.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-variable-aggregation.pbtxt
index 36b534af36..66a20547eb 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-variable-aggregation.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-variable-aggregation.pbtxt
@@ -10,6 +10,10 @@ tf_class {
mtype: "<enum \'VariableAggregation\'>"
}
member {
+ name: "ONLY_FIRST_TOWER"
+ mtype: "<enum \'VariableAggregation\'>"
+ }
+ member {
name: "SUM"
mtype: "<enum \'VariableAggregation\'>"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-variable.pbtxt
index e841c4ad89..05698b03ee 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-variable.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-variable.pbtxt
@@ -53,15 +53,15 @@ tf_class {
}
member_method {
name: "assign"
- argspec: "args=[\'self\', \'value\', \'use_locking\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ argspec: "args=[\'self\', \'value\', \'use_locking\', \'name\', \'read_value\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'True\'], "
}
member_method {
name: "assign_add"
- argspec: "args=[\'self\', \'delta\', \'use_locking\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ argspec: "args=[\'self\', \'delta\', \'use_locking\', \'name\', \'read_value\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'True\'], "
}
member_method {
name: "assign_sub"
- argspec: "args=[\'self\', \'delta\', \'use_locking\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ argspec: "args=[\'self\', \'delta\', \'use_locking\', \'name\', \'read_value\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'True\'], "
}
member_method {
name: "count_up_to"
@@ -92,8 +92,28 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "scatter_add"
+ argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+ }
+ member_method {
+ name: "scatter_nd_add"
+ argspec: "args=[\'self\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "scatter_nd_sub"
+ argspec: "args=[\'self\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "scatter_nd_update"
+ argspec: "args=[\'self\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "scatter_sub"
- argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+ }
+ member_method {
+ name: "scatter_update"
+ argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "set_shape"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-whole-file-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-whole-file-reader.pbtxt
deleted file mode 100644
index 4ac759891c..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.-whole-file-reader.pbtxt
+++ /dev/null
@@ -1,46 +0,0 @@
-path: "tensorflow.WholeFileReader"
-tf_class {
- is_instance: "<class \'tensorflow.python.ops.io_ops.WholeFileReader\'>"
- is_instance: "<class \'tensorflow.python.ops.io_ops.ReaderBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "reader_ref"
- mtype: "<type \'property\'>"
- }
- member {
- name: "supports_serialize"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "num_records_produced"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "num_work_units_completed"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "read"
- argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "read_up_to"
- argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "reset"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "restore_state"
- argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "serialize_state"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
index 834f0954d5..87745420ee 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
@@ -60,7 +60,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 4d854a4cee..6dd46365b0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -61,7 +61,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
index 601f095a60..35b7105eba 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -61,7 +61,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
index 587829a4c0..8ae370af98 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
@@ -61,7 +61,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-classifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-classifier.pbtxt
index cf22e39d4c..082e26b99b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-classifier.pbtxt
@@ -32,6 +32,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "export_saved_model"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
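
Every canned Estimator golden in this run gains `export_saved_model` next to the older `export_savedmodel`; the new method drops the `strip_default_attrs` flag. A sketch against `BaselineClassifier`, assuming a build that includes this change and using illustrative paths:

    import numpy as np
    import tensorflow as tf

    estimator = tf.estimator.BaselineClassifier(model_dir="/tmp/baseline")
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        {"x": np.array([[0.0], [1.0]])}, np.array([0, 1]),
        batch_size=2, num_epochs=1, shuffle=False)
    estimator.train(train_input_fn)

    serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
        {"x": tf.FixedLenFeature([1], tf.float32)})
    export_dir = estimator.export_saved_model("/tmp/export", serving_input_fn)
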
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-regressor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-regressor.pbtxt
index a363bceae3..7cc4191eb3 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-regressor.pbtxt
@@ -32,6 +32,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "export_saved_model"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
index c23b04b4ef..7027e78df4 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
@@ -32,6 +32,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "export_saved_model"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
index 6878d28fff..d8167ea7cb 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
@@ -32,6 +32,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "export_saved_model"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-classifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-classifier.pbtxt
index 0c6b7e4a82..718f415a77 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-classifier.pbtxt
@@ -32,6 +32,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "export_saved_model"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
index 9c1c072124..b23c019d6c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
@@ -32,6 +32,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "export_saved_model"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
index 7391d4b07a..caa9e3f1de 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
@@ -32,6 +32,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "export_saved_model"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-regressor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-regressor.pbtxt
index f50e375f7c..1f5e650940 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-regressor.pbtxt
@@ -32,6 +32,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "export_saved_model"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-estimator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-estimator.pbtxt
index d72b576977..ebd3869c9b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-estimator.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-estimator.pbtxt
@@ -31,6 +31,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "export_saved_model"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-classifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-classifier.pbtxt
index 154f171e89..53ec5a0c78 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-classifier.pbtxt
@@ -32,6 +32,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "export_saved_model"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-regressor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-regressor.pbtxt
index 4d46d1e6b6..3791162619 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-regressor.pbtxt
@@ -32,6 +32,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "export_saved_model"
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_savedmodel"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt
index bf1f94b6ae..269e18a0a7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt
@@ -96,7 +96,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\', \'experimental_distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "replace"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt
index 5301b94eb3..b6942cb7ed 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt
@@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
is_instance: "<type \'tuple\'>"
member {
+ name: "axis"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "backup_initializer"
mtype: "<type \'property\'>"
}
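The VocabInfo golden above gains an `axis` property; since VocabInfo is a tuple subclass, that corresponds to a new namedtuple field. A sketch with illustrative values (the constructor field order is an assumption based on the 1.x public API, not on this diff):

```python
import tensorflow as tf

vocab_info = tf.estimator.VocabInfo(
    new_vocab='new_vocab.txt',
    new_vocab_size=100,
    num_oov_buckets=1,
    old_vocab='old_vocab.txt',
    old_vocab_size=50,
    backup_initializer=tf.zeros_initializer(),
    axis=0,  # new field: 0 warm-starts along rows, 1 along columns
)
assert vocab_info.axis == 0  # exposed as the property listed in the golden
```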
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.glorot_normal_initializer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.glorot_normal_initializer.pbtxt
new file mode 100644
index 0000000000..483d1f8ba0
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.glorot_normal_initializer.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.glorot_normal_initializer"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotNormal\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
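The new golden above records `tf.glorot_normal_initializer` as a class (`init_ops.GlorotNormal`) rather than a bare function, with `__init__(seed, dtype)` plus `from_config`/`get_config`. A sketch of what that enables; the variable name and shape are illustrative:

```python
import tensorflow as tf

init = tf.glorot_normal_initializer(seed=42)
w = tf.get_variable('w', shape=[64, 32], initializer=init)

# Config round-trip through the hooks the golden lists; assumes get_config
# returns exactly the kwargs that __init__ accepts.
restored = tf.glorot_normal_initializer.from_config(init.get_config())
```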
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.glorot_uniform_initializer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.glorot_uniform_initializer.pbtxt
new file mode 100644
index 0000000000..bb8540d0fd
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.glorot_uniform_initializer.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.glorot_uniform_initializer"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotUniform\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_normal.pbtxt
new file mode 100644
index 0000000000..4a81e52df9
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_normal.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.initializers.glorot_normal"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotNormal\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_uniform.pbtxt
new file mode 100644
index 0000000000..815dc81dff
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.glorot_uniform.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.initializers.glorot_uniform"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotUniform\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt
index bc0426f2f1..d499c67d89 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt
@@ -5,6 +5,14 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "glorot_normal"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "glorot_uniform"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "identity"
mtype: "<type \'type\'>"
}
@@ -45,14 +53,6 @@ tf_module {
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "glorot_normal"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
- name: "glorot_uniform"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
name: "he_normal"
argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
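Reading the two hunks above together: `glorot_normal` and `glorot_uniform` move from `member_method` (plain functions) to `member` with mtype `<type 'type'>` (classes). Call sites keep the same spelling, since calling the class constructs an initializer instance; a brief sketch:

```python
import tensorflow as tf

# Previously a function call returning an initializer; now a class
# instantiation with the same signature, per the goldens above.
init = tf.initializers.glorot_uniform(seed=7)
v = tf.get_variable('v', shape=[10, 10], initializer=init)
```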
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.io.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.io.pbtxt
index 3a36c168aa..8938cf217b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.io.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.io.pbtxt
@@ -25,6 +25,10 @@ tf_module {
argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "parse_sequence_example"
+ argspec: "args=[\'serialized\', \'context_features\', \'sequence_features\', \'example_names\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "parse_tensor"
argspec: "args=[\'serialized\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
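The tf.io hunk above adds `parse_sequence_example`, a batched counterpart to `parse_single_sequence_example` that takes a vector of serialized `SequenceExample` protos. A hedged sketch; the feature specs are illustrative assumptions:

```python
import tensorflow as tf

serialized = tf.placeholder(tf.string, shape=[None])  # batch of protos
parsed = tf.io.parse_sequence_example(
    serialized,
    context_features={'length': tf.FixedLenFeature([], tf.int64)},
    sequence_features={'tokens': tf.FixedLenSequenceFeature([], tf.int64)},
)
# `parsed` yields dicts of context and sequence feature tensors.
```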
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
index e579fe6a1a..d843194ef0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
@@ -119,7 +119,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
index 97688fcb0f..b8e9baca71 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
@@ -124,7 +124,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-normal.pbtxt
index 23cd02c0b0..26784ce55d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-normal.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.RandomNormal"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-uniform.pbtxt
index d98628f422..4110bda5f6 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-uniform.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.RandomUniform"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'-0.05\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-truncated-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-truncated-normal.pbtxt
index 86d48257c1..0451d0d73a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-truncated-normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-truncated-normal.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.TruncatedNormal"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.TruncatedNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.TruncatedNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_normal.pbtxt
new file mode 100644
index 0000000000..ef0815972d
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_normal.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.keras.initializers.glorot_normal"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotNormal\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_uniform.pbtxt
new file mode 100644
index 0000000000..439b5ada9b
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_uniform.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.keras.initializers.glorot_uniform"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops.GlorotUniform\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.normal.pbtxt
index 7485772784..8d0b5c242b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.normal.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.normal"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.pbtxt
index 8645e54302..1540c2915b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.pbtxt
@@ -45,6 +45,14 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "glorot_normal"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "glorot_uniform"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "identity"
mtype: "<type \'type\'>"
}
@@ -89,14 +97,6 @@ tf_module {
argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "glorot_normal"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
- name: "glorot_uniform"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
name: "he_normal"
argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_normal.pbtxt
index a6df1e87a3..bac8211a10 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_normal.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.random_normal"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_uniform.pbtxt
index 37a0fa0d55..ab0d74d071 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_uniform.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.random_uniform"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'-0.05\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.truncated_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.truncated_normal.pbtxt
index f97e93f0b7..358cca2b9c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.truncated_normal.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.truncated_normal.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.truncated_normal"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.TruncatedNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.TruncatedNormal\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.uniform.pbtxt
index 58186b1383..e6c731361a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.uniform.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.uniform.pbtxt
@@ -1,11 +1,12 @@
path: "tensorflow.keras.initializers.uniform"
tf_class {
+ is_instance: "<class \'tensorflow.python.keras.initializers.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.RandomUniform\'>"
is_instance: "<class \'tensorflow.python.ops.init_ops.Initializer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'-0.05\', \'0.05\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activation.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activation.pbtxt
index 86e328888e..5510465d7b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activation.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activation.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activity-regularization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activity-regularization.pbtxt
index b0ed545781..38ec8a0aff 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activity-regularization.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activity-regularization.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-add.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-add.pbtxt
index 42f98ed03d..41cb8e30bf 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-add.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-add.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-alpha-dropout.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-alpha-dropout.pbtxt
index 000898a4be..9a7aaa8e96 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-alpha-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-alpha-dropout.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt
index 380b49f99c..c3dd2ad046 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling2-d.pbtxt
index 82db5e6137..cc303bf7b9 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling3-d.pbtxt
index b6ff688ec3..628447ce35 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average.pbtxt
index b41290f8b0..f03c986c22 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt
index 88a033e61f..c440604aae 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool2-d.pbtxt
index c1b9b96044..a01eaf8a12 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool3-d.pbtxt
index f59f7727a3..0d6698f2ef 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt
index 7d3744ed92..f1b23be48f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt
index 3fd4ccdab2..0672cd5b7b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt
@@ -107,7 +107,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-concatenate.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-concatenate.pbtxt
index ba21b50be4..b25ae1e82e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-concatenate.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-concatenate.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
index 46f9fa2bbb..bb1918eba6 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
@@ -188,7 +188,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d.pbtxt
index c3ad326589..16e0fd5a31 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
index fd9eb43066..065bb4d35b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d.pbtxt
index 40d61688f2..543bae6fa9 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
index b8c227d725..c7ba6056f9 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d.pbtxt
index 095d35e574..072943dc2c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d.pbtxt
index 8f99961198..222a1ef4fc 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
index 96d522a016..8f4f7918ab 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d.pbtxt
index de2824dab4..f939067178 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
index 1d563241d8..93c442bd55 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d.pbtxt
index c87e52c537..471b18ef85 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping1-d.pbtxt
index dccf5523e3..0f250a09b7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping1-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping2-d.pbtxt
index 7ac4116d92..f52128483c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping2-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping3-d.pbtxt
index 024f72705d..98daf3bab1 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping3-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt
index 4e0233331b..64e7a9046b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt
@@ -108,7 +108,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt
index 32d46ce8f3..6fdffef776 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt
@@ -108,7 +108,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense.pbtxt
index 858486c725..3ac3825759 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt
index f65d750926..280ec8c25f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dot.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dot.pbtxt
index 2e71ef503d..560f66f9c7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dot.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dot.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dropout.pbtxt
index 42533bcd21..c0543529c3 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dropout.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-e-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-e-l-u.pbtxt
index b5df169417..04eb2824b9 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-e-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-e-l-u.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-embedding.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-embedding.pbtxt
index 0ea17919a9..f400432915 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-embedding.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-embedding.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-flatten.pbtxt
index a33248bc00..ab176b441a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-flatten.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-flatten.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt
index 4ba21a25cd..c3895a0ac1 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
@@ -133,6 +133,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
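Note: besides the `add_weight` signature change, this hunk shows `get_initial_state` joining the public `GRUCell` API with the argspec `(self, inputs, batch_size, dtype)`, all three arguments defaulting to None. A hedged usage sketch (layer size and shapes chosen purely for illustration):

    import tensorflow as tf

    cell = tf.keras.layers.GRUCell(8)
    inputs = tf.zeros([4, 16])  # batch of 4, feature dim 16

    # Either pass a reference tensor to infer batch size and dtype...
    state = cell.get_initial_state(inputs=inputs)
    # ...or specify them directly, per the argspec above.
    state = cell.get_initial_state(batch_size=4, dtype=tf.float32)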
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt
index a7a570418e..a0fe598ab9 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt
@@ -171,7 +171,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-dropout.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-dropout.pbtxt
index 763bc23113..55e0d7ef02 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-dropout.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-noise.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-noise.pbtxt
index 3c50a3d7f2..38fbff5e4a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-noise.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-noise.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
index ac78bdafad..5ea61d118d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
index 275282d9d2..929f48df23 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
index 0e31e6058b..2e6d59337f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
index aacd0b1791..11dca17c6d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
index c236548663..4e3e258430 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
index 6b9c0290aa..fb9166316f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
index 0d7b2211e6..278429af6f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
index d080ad6aed..87b7f6797a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
index fcb0a109da..98bf96fa0c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
index 1d0e22abd0..935a69ab2f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
index 653c9f547b..c9d4158d1c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
index cdbaf82cf6..9953102ff9 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt
index 230c5e9034..2617f5a95f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
index 511456e740..e9f6ef45aa 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
@@ -133,6 +133,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
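Note: `LSTMCell` gains the same `get_initial_state` method as `GRUCell` above; the difference worth noting is that an LSTM carries two state tensors, so the method should return a pair (hidden state and cell state). A small sketch under that assumption:

    import tensorflow as tf

    lstm_cell = tf.keras.layers.LSTMCell(8)
    # state_size for LSTMCell is [units, units], so two zero tensors come back.
    h, c = lstm_cell.get_initial_state(batch_size=4, dtype=tf.float32)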
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt
index 4a3492ebd6..ecdbf48157 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt
@@ -171,7 +171,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt
index 2dff7a6de4..2e0b6bac24 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer.pbtxt
index 7efa29be77..1e93d1118a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer.pbtxt
@@ -97,7 +97,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-leaky-re-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
index 0ca8e0b52c..bfd36012a7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt
index ff19dcc3a3..5ad5990d7e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt
index 3c278fead6..40d03369a5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-masking.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-masking.pbtxt
index 850ecff974..86666b51bb 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-masking.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-masking.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt
index 7c69e31f9a..238d96cca6 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool2-d.pbtxt
index fba42642d7..85f23df671 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool3-d.pbtxt
index 9c277411ea..235806b965 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt
index 7c2f6ccc8a..4a45bf7997 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling2-d.pbtxt
index 802178dba6..fda2562fc8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling3-d.pbtxt
index e870dfe9ad..71d2d09a8d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-maximum.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-maximum.pbtxt
index c1337ce0cb..12949b39a6 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-maximum.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-maximum.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-minimum.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-minimum.pbtxt
index ed27a62765..ab16d0021e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-minimum.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-minimum.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multiply.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multiply.pbtxt
index b9f05cb3e5..61ccbf5962 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multiply.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multiply.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-p-re-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-p-re-l-u.pbtxt
index 336d9f76fb..ce2320d703 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-p-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-p-re-l-u.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-permute.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-permute.pbtxt
index 46282217e0..69848af8cf 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-permute.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-permute.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt
index 42cd7e87ee..2b6e8af11d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt
@@ -102,7 +102,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt
index 4d3de58bd1..413f45f018 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-repeat-vector.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-repeat-vector.pbtxt
index 9f094a877a..9c61ff6027 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-repeat-vector.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-repeat-vector.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-reshape.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-reshape.pbtxt
index 2f519a2438..baa91804c4 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-reshape.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-reshape.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv1-d.pbtxt
index 6b93116ba0..15a5d6ac9e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv1-d.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv2-d.pbtxt
index fd17115e27..be43bd5b3c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv2-d.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
index 4b37a94478..6105992c7a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
index 5bdadca74a..1b6cf1e9ec 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
index 9dfda96fc8..29488a37f8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
@@ -133,6 +133,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
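
Besides the add_weight change, the RNN cell classes gain a get_initial_state member method, mirrored later in the tf.nn.rnn_cell goldens. A minimal usage sketch based on the argspec above; either concrete inputs or an explicit batch_size/dtype pair is expected:

    # Sketch: build the initial state tensor(s) for a cell.
    cell = tf.keras.layers.SimpleRNNCell(32)
    state = cell.get_initial_state(inputs=None, batch_size=8, dtype=tf.float32)
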
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n.pbtxt
index 7b7684ccd2..182efb83b8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n.pbtxt
@@ -159,7 +159,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt
index 3b15407fca..d29731ecf9 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
index 6d04415267..a6d7494ca7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
index 04950654d5..c36e802693 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
index c424e6dcc8..9c46cfe40f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
index 6718e36dc6..8982f78794 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
@@ -106,7 +106,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
@@ -141,6 +141,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-subtract.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-subtract.pbtxt
index 740a03367b..ec2cc50298 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-subtract.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-subtract.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
index a08c583adb..d7bc1980f3 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt
index c1294fed0f..fec2de6b49 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt
@@ -103,7 +103,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling1-d.pbtxt
index dc401d3ed0..3d285e7f17 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling1-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling2-d.pbtxt
index 4b5165ae97..40a56a0c94 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling2-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling3-d.pbtxt
index 789af15fea..728eca415a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling3-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt
index 0536a7cee7..da64e77c39 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt
@@ -102,7 +102,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding1-d.pbtxt
index 8915353ec3..2f505f9293 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding1-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding2-d.pbtxt
index 6efb5ef15a..f82c77072e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding2-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding3-d.pbtxt
index 4c33c5d0bf..54e01a9917 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding3-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
index 56914e1746..472b9818df 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
@@ -119,7 +119,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
index acfb3521c0..937516eff1 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
@@ -124,7 +124,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt
index 8ba0e7480b..7ad4a32d43 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt
@@ -9,6 +9,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member_method {
+ name: "clone_model"
+ argspec: "args=[\'model\', \'input_tensors\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "load_model"
argspec: "args=[\'filepath\', \'custom_objects\', \'compile\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
}
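
tf.keras.models now also exports clone_model. A minimal sketch based on the argspec above: the clone shares the architecture but gets freshly initialized weights, and input_tensors optionally supplies existing tensors to build the clone on (new_inputs below is a hypothetical placeholder):

    # Sketch: duplicate a model's architecture without copying its weights.
    clone = tf.keras.models.clone_model(model)
    clone = tf.keras.models.clone_model(model, input_tensors=new_inputs)
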
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-ordered-enqueuer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-ordered-enqueuer.pbtxt
new file mode 100644
index 0000000000..e7e7d2839b
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-ordered-enqueuer.pbtxt
@@ -0,0 +1,26 @@
+path: "tensorflow.keras.utils.OrderedEnqueuer"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras.utils.data_utils.OrderedEnqueuer\'>"
+ is_instance: "<class \'tensorflow.python.keras.utils.data_utils.SequenceEnqueuer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'sequence\', \'use_multiprocessing\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], "
+ }
+ member_method {
+ name: "get"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "is_running"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "start"
+ argspec: "args=[\'self\', \'workers\', \'max_queue_size\'], varargs=None, keywords=None, defaults=[\'1\', \'10\'], "
+ }
+ member_method {
+ name: "stop"
+ argspec: "args=[\'self\', \'timeout\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
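
The newly added golden file above describes OrderedEnqueuer, which feeds batches from a Sequence in deterministic order, optionally across worker processes. A minimal sketch driven directly by those argspecs, where MySequence stands in for any tf.keras.utils.Sequence subclass:

    # Sketch: ordered, multiprocess batch loading from a Sequence.
    enqueuer = tf.keras.utils.OrderedEnqueuer(MySequence(),
                                              use_multiprocessing=True,
                                              shuffle=False)
    enqueuer.start(workers=4, max_queue_size=10)
    batches = enqueuer.get()   # generator yielding batches in order
    for batch in batches:
        ...                    # consume batches
    enqueuer.stop()
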
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt
index 4d7a1519ce..81b91d2780 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt
@@ -13,6 +13,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "OrderedEnqueuer"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "Progbar"
mtype: "<type \'type\'>"
}
@@ -45,6 +49,10 @@ tf_module {
argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], "
}
member_method {
+ name: "get_source_inputs"
+ argspec: "args=[\'tensor\', \'layer\', \'node_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
name: "multi_gpu_model"
argspec: "args=[\'model\', \'gpus\', \'cpu_merge\', \'cpu_relocation\'], varargs=None, keywords=None, defaults=[\'True\', \'False\'], "
}
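
keras.utils also picks up get_source_inputs, which walks a layer graph backwards from a tensor to the Input tensors it depends on. A minimal sketch based on the argspec above:

    # Sketch: recover the source Input tensor(s) behind a computed tensor.
    x = tf.keras.layers.Input(shape=(16,))
    y = tf.keras.layers.Dense(4)(x)
    sources = tf.keras.utils.get_source_inputs(y)   # [x]
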
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
index e606eab919..88b8f37c4f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
@@ -152,6 +152,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
index 5deb02d569..a4483fefa2 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
@@ -152,6 +152,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
index 8a63b49180..381c4975d7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
@@ -151,6 +151,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
index db1aae2757..912365a28b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
@@ -155,6 +155,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
index 32fa151a8e..a4bb3219c7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
@@ -152,6 +152,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
index 30c6c2ce3b..715bfd5fc7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
@@ -152,6 +152,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
index 72b40cc9f7..b66c0f89cc 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
@@ -151,6 +151,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
index a5c2b4aefd..faeb4f3513 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
@@ -150,6 +150,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
index 61d5f04b22..caa2e60080 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
@@ -151,6 +151,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "get_initial_state"
+ argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index 4de662fe33..9332e16bf6 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -61,10 +61,6 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
- name: "FixedLengthRecordReader"
- mtype: "<type \'type\'>"
- }
- member {
name: "GIT_VERSION"
mtype: "<type \'str\'>"
}
@@ -109,10 +105,6 @@ tf_module {
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
}
member {
- name: "IdentityReader"
- mtype: "<type \'type\'>"
- }
- member {
name: "IndexedSlices"
mtype: "<type \'type\'>"
}
@@ -121,10 +113,6 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
- name: "LMDBReader"
- mtype: "<type \'type\'>"
- }
- member {
name: "LogMessage"
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
}
@@ -177,10 +165,6 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
- name: "ReaderBase"
- mtype: "<type \'type\'>"
- }
- member {
name: "RegisterGradient"
mtype: "<type \'type\'>"
}
@@ -225,10 +209,6 @@ tf_module {
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
}
member {
- name: "TFRecordReader"
- mtype: "<type \'type\'>"
- }
- member {
name: "Tensor"
mtype: "<type \'type\'>"
}
@@ -245,10 +225,6 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
- name: "TextLineReader"
- mtype: "<type \'type\'>"
- }
- member {
name: "VERSION"
mtype: "<type \'str\'>"
}
@@ -273,10 +249,6 @@ tf_module {
mtype: "<class \'enum.EnumMeta\'>"
}
member {
- name: "WholeFileReader"
- mtype: "<type \'type\'>"
- }
- member {
name: "app"
mtype: "<type \'module\'>"
}
@@ -365,6 +337,14 @@ tf_module {
mtype: "<type \'module\'>"
}
member {
+ name: "glorot_normal_initializer"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "glorot_uniform_initializer"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "graph_util"
mtype: "<type \'module\'>"
}
@@ -761,18 +741,6 @@ tf_module {
argspec: "args=[\'var_list\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "assign"
- argspec: "args=[\'ref\', \'value\', \'validate_shape\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "assign_add"
- argspec: "args=[\'ref\', \'value\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "assign_sub"
- argspec: "args=[\'ref\', \'value\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
name: "atan"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -789,6 +757,10 @@ tf_module {
argspec: "args=[\'params\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "batch_scatter_update"
+ argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+ }
+ member_method {
name: "batch_to_space"
argspec: "args=[\'input\', \'crops\', \'block_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -925,10 +897,6 @@ tf_module {
argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'dtype\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"<dtype: \'int64\'>\", \'None\', \'None\', \'None\'], "
}
member_method {
- name: "count_up_to"
- argspec: "args=[\'ref\', \'limit\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
name: "create_partitioned_variables"
argspec: "args=[\'shape\', \'slicing\', \'initializer\', \'dtype\', \'trainable\', \'collections\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'True\', \'None\', \'None\', \'None\'], "
}
@@ -1005,6 +973,10 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "div_no_nan"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "divide"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -1033,6 +1005,10 @@ tf_module {
argspec: "args=[\'input\', \'pad\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
+ name: "ensure_shape"
+ argspec: "args=[\'x\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "equal"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -1162,7 +1138,7 @@ tf_module {
}
member_method {
name: "get_local_variable"
- argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'synchronization\', \'aggregation\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\', \'None\'], "
+ argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "get_seed"
@@ -1197,14 +1173,6 @@ tf_module {
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "glorot_normal_initializer"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
- name: "glorot_uniform_initializer"
- argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
- }
- member_method {
name: "gradients"
argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\'], "
}
@@ -1273,6 +1241,10 @@ tf_module {
argspec: "args=[\'graph_def\', \'input_map\', \'return_elements\', \'name\', \'op_dict\', \'producer_op_list\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "init_scope"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "initialize_all_tables"
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'init_all_tables\'], "
}
@@ -1729,10 +1701,6 @@ tf_module {
argspec: "args=[\'fn\', \'elems\', \'initializer\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'True\', \'False\', \'None\'], "
}
member_method {
- name: "scatter_add"
- argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
name: "scatter_div"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
@@ -1753,26 +1721,6 @@ tf_module {
argspec: "args=[\'indices\', \'updates\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "scatter_nd_add"
- argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "scatter_nd_sub"
- argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "scatter_nd_update"
- argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
- }
- member_method {
- name: "scatter_sub"
- argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "scatter_update"
- argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
- }
- member_method {
name: "segment_max"
argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-options.pbtxt
index 0853716023..614ba42d3e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-options.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-options.pbtxt
@@ -8,7 +8,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'compression_type\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'compression_type\', \'flush_mode\', \'input_buffer_size\', \'output_buffer_size\', \'window_bits\', \'compression_level\', \'compression_method\', \'mem_level\', \'compression_strategy\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "get_compression_type_string"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt
index bbfe395031..ba9e651b34 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt
@@ -8,4 +8,12 @@ tf_module {
name: "cross_hashed"
argspec: "args=[\'inputs\', \'num_buckets\', \'hash_key\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], "
}
+ member_method {
+ name: "expand_dims"
+ argspec: "args=[\'sp_input\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "eye"
+ argspec: "args=[\'num_rows\', \'num_columns\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'None\'], "
+ }
}
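tf.sparse gains expand_dims and eye per the entries above; a short sketch of both, following the argspecs (axis defaults to None, and eye's num_columns defaults to the row count):

    import tensorflow as tf

    st = tf.SparseTensor(indices=[[0, 0], [1, 2]],
                         values=[1.0, 2.0],
                         dense_shape=[2, 3])

    expanded = tf.sparse.expand_dims(st, axis=0)  # dense_shape becomes [1, 2, 3]
    identity = tf.sparse.eye(3)                   # 3x3 sparse identity matrix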
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-queue-runner.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-queue-runner.pbtxt
deleted file mode 100644
index d84d0058ee..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-queue-runner.pbtxt
+++ /dev/null
@@ -1,49 +0,0 @@
-path: "tensorflow.train.QueueRunner"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.queue_runner_impl.QueueRunner\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "cancel_op"
- mtype: "<type \'property\'>"
- }
- member {
- name: "close_op"
- mtype: "<type \'property\'>"
- }
- member {
- name: "enqueue_ops"
- mtype: "<type \'property\'>"
- }
- member {
- name: "exceptions_raised"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "queue"
- mtype: "<type \'property\'>"
- }
- member {
- name: "queue_closed_exception_types"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'queue\', \'enqueue_ops\', \'close_op\', \'cancel_op\', \'queue_closed_exception_types\', \'queue_runner_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "create_threads"
- argspec: "args=[\'self\', \'sess\', \'coord\', \'daemon\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'False\'], "
- }
- member_method {
- name: "from_proto"
- argspec: "args=[\'queue_runner_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "to_proto"
- argspec: "args=[\'self\', \'export_scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt
index 4ce7cb1111..39b946b82f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt
@@ -4,6 +4,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
is_instance: "<type \'tuple\'>"
member {
+ name: "axis"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "backup_initializer"
mtype: "<type \'property\'>"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
index 9f35395284..b21dabbde7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
@@ -145,10 +145,6 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
- name: "QueueRunner"
- mtype: "<type \'type\'>"
- }
- member {
name: "RMSPropOptimizer"
mtype: "<type \'type\'>"
}
@@ -236,10 +232,6 @@ tf_module {
name: "WorkerSessionCreator"
mtype: "<type \'type\'>"
}
- member {
- name: "queue_runner"
- mtype: "<type \'module\'>"
- }
member_method {
name: "MonitoredTrainingSession"
argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\', \'max_wait_secs\', \'save_checkpoint_steps\', \'summary_dir\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'None\', \'120\', \'100\', \'7200\', \'<object object instance>\', \'None\'], "
@@ -249,10 +241,6 @@ tf_module {
argspec: "args=[\'filepattern\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "add_queue_runner"
- argspec: "args=[\'qr\', \'collection\'], varargs=None, keywords=None, defaults=[\'queue_runners\'], "
- }
- member_method {
name: "assert_global_step"
argspec: "args=[\'global_step_tensor\'], varargs=None, keywords=None, defaults=None"
}
@@ -261,14 +249,6 @@ tf_module {
argspec: "args=[\'supervisor\', \'train_step_fn\', \'args\', \'kwargs\', \'master\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'\'], "
}
member_method {
- name: "batch"
- argspec: "args=[\'tensors\', \'batch_size\', \'num_threads\', \'capacity\', \'enqueue_many\', \'shapes\', \'dynamic_pad\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'32\', \'False\', \'None\', \'False\', \'False\', \'None\', \'None\'], "
- }
- member_method {
- name: "batch_join"
- argspec: "args=[\'tensors_list\', \'batch_size\', \'capacity\', \'enqueue_many\', \'shapes\', \'dynamic_pad\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'32\', \'False\', \'None\', \'False\', \'False\', \'None\', \'None\'], "
- }
- member_method {
name: "checkpoint_exists"
argspec: "args=[\'checkpoint_prefix\'], varargs=None, keywords=None, defaults=None"
}
@@ -329,10 +309,6 @@ tf_module {
argspec: "args=[\'ckpt_dir_or_file\', \'assignment_map\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "input_producer"
- argspec: "args=[\'input_tensor\', \'element_shape\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'summary_name\', \'name\', \'cancel_op\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'32\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
name: "inverse_time_decay"
argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
@@ -341,10 +317,6 @@ tf_module {
argspec: "args=[\'checkpoint_dir\', \'latest_filename\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "limit_epochs"
- argspec: "args=[\'tensor\', \'num_epochs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
name: "linear_cosine_decay"
argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'num_periods\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'0.0\', \'0.001\', \'None\'], "
}
@@ -365,22 +337,6 @@ tf_module {
argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "maybe_batch"
- argspec: "args=[\'tensors\', \'keep_input\', \'batch_size\', \'num_threads\', \'capacity\', \'enqueue_many\', \'shapes\', \'dynamic_pad\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'32\', \'False\', \'None\', \'False\', \'False\', \'None\', \'None\'], "
- }
- member_method {
- name: "maybe_batch_join"
- argspec: "args=[\'tensors_list\', \'keep_input\', \'batch_size\', \'capacity\', \'enqueue_many\', \'shapes\', \'dynamic_pad\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'32\', \'False\', \'None\', \'False\', \'False\', \'None\', \'None\'], "
- }
- member_method {
- name: "maybe_shuffle_batch"
- argspec: "args=[\'tensors\', \'batch_size\', \'capacity\', \'min_after_dequeue\', \'keep_input\', \'num_threads\', \'seed\', \'enqueue_many\', \'shapes\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'False\', \'None\', \'False\', \'None\', \'None\'], "
- }
- member_method {
- name: "maybe_shuffle_batch_join"
- argspec: "args=[\'tensors_list\', \'batch_size\', \'capacity\', \'min_after_dequeue\', \'keep_input\', \'seed\', \'enqueue_many\', \'shapes\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\', \'None\', \'None\'], "
- }
- member_method {
name: "natural_exp_decay"
argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
@@ -397,10 +353,6 @@ tf_module {
argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'end_learning_rate\', \'power\', \'cycle\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1.0\', \'False\', \'None\'], "
}
member_method {
- name: "range_input_producer"
- argspec: "args=[\'limit\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\'], "
- }
- member_method {
name: "remove_checkpoint"
argspec: "args=[\'checkpoint_prefix\', \'checkpoint_format_version\', \'meta_graph_suffix\'], varargs=None, keywords=None, defaults=[\'2\', \'meta\'], "
}
@@ -421,26 +373,6 @@ tf_module {
argspec: "args=[\'weights\', \'l1\', \'l2\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "shuffle_batch"
- argspec: "args=[\'tensors\', \'batch_size\', \'capacity\', \'min_after_dequeue\', \'num_threads\', \'seed\', \'enqueue_many\', \'shapes\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'False\', \'None\', \'False\', \'None\', \'None\'], "
- }
- member_method {
- name: "shuffle_batch_join"
- argspec: "args=[\'tensors_list\', \'batch_size\', \'capacity\', \'min_after_dequeue\', \'seed\', \'enqueue_many\', \'shapes\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\', \'None\', \'None\'], "
- }
- member_method {
- name: "slice_input_producer"
- argspec: "args=[\'tensor_list\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\'], "
- }
- member_method {
- name: "start_queue_runners"
- argspec: "args=[\'sess\', \'coord\', \'daemon\', \'start\', \'collection\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'True\', \'queue_runners\'], "
- }
- member_method {
- name: "string_input_producer"
- argspec: "args=[\'string_tensor\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\', \'cancel_op\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\', \'None\'], "
- }
- member_method {
name: "summary_iterator"
argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.queue_runner.-queue-runner.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.queue_runner.-queue-runner.pbtxt
deleted file mode 100644
index 23d402de30..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.queue_runner.-queue-runner.pbtxt
+++ /dev/null
@@ -1,49 +0,0 @@
-path: "tensorflow.train.queue_runner.QueueRunner"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.queue_runner_impl.QueueRunner\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "cancel_op"
- mtype: "<type \'property\'>"
- }
- member {
- name: "close_op"
- mtype: "<type \'property\'>"
- }
- member {
- name: "enqueue_ops"
- mtype: "<type \'property\'>"
- }
- member {
- name: "exceptions_raised"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "queue"
- mtype: "<type \'property\'>"
- }
- member {
- name: "queue_closed_exception_types"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'queue\', \'enqueue_ops\', \'close_op\', \'cancel_op\', \'queue_closed_exception_types\', \'queue_runner_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "create_threads"
- argspec: "args=[\'self\', \'sess\', \'coord\', \'daemon\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'False\'], "
- }
- member_method {
- name: "from_proto"
- argspec: "args=[\'queue_runner_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "to_proto"
- argspec: "args=[\'self\', \'export_scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.queue_runner.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.queue_runner.pbtxt
deleted file mode 100644
index 6e2d043049..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.queue_runner.pbtxt
+++ /dev/null
@@ -1,15 +0,0 @@
-path: "tensorflow.train.queue_runner"
-tf_module {
- member {
- name: "QueueRunner"
- mtype: "<type \'type\'>"
- }
- member_method {
- name: "add_queue_runner"
- argspec: "args=[\'qr\', \'collection\'], varargs=None, keywords=None, defaults=[\'queue_runners\'], "
- }
- member_method {
- name: "start_queue_runners"
- argspec: "args=[\'sess\', \'coord\', \'daemon\', \'start\', \'collection\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'True\', \'queue_runners\'], "
- }
-}
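These deletions, together with the tf.train hunks above (batch, shuffle_batch, *_input_producer, start_queue_runners), remove the queue-runner input pipeline from the v2 goldens. A rough tf.data counterpart of the removed string_input_producer/shuffle_batch pattern, sketched with 1.x-era names and placeholder sizes:

    import tensorflow as tf

    filenames = ["train-00000-of-00001.tfrecord"]  # placeholder paths

    dataset = (tf.data.TFRecordDataset(filenames)
               .shuffle(buffer_size=10000)  # plays the role of min_after_dequeue
               .repeat(10)                  # num_epochs
               .batch(32))
    next_batch = dataset.make_one_shot_iterator().get_next()
    # No tf.train.start_queue_runners() call is needed.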
diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD
index 8764409e4d..4efa4a9651 100644
--- a/tensorflow/tools/api/tests/BUILD
+++ b/tensorflow/tools/api/tests/BUILD
@@ -15,7 +15,10 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
py_test(
name = "api_compatibility_test",
- srcs = ["api_compatibility_test.py"],
+ srcs = [
+ "api_compatibility_test.py",
+ "//tensorflow:tf_python_api_gen_v2",
+ ],
data = [
"//tensorflow/tools/api/golden:api_golden_v1",
"//tensorflow/tools/api/golden:api_golden_v2",
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py
index 43d19bc99c..99bed5714f 100644
--- a/tensorflow/tools/api/tests/api_compatibility_test.py
+++ b/tensorflow/tools/api/tests/api_compatibility_test.py
@@ -34,6 +34,7 @@ import sys
import unittest
import tensorflow as tf
+from tensorflow._api import v2 as tf_v2
from google.protobuf import message
from google.protobuf import text_format
@@ -232,14 +233,14 @@ class ApiCompatibilityTest(test.TestCase):
return
visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
visitor.do_not_descend_map['tf'].append('contrib')
- traverse.traverse(tf.compat.v1, visitor)
+ traverse.traverse(tf_v2.compat.v1, visitor)
def testNoSubclassOfMessageV2(self):
if not hasattr(tf.compat, 'v2'):
return
visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
visitor.do_not_descend_map['tf'].append('contrib')
- traverse.traverse(tf.compat.v2, visitor)
+ traverse.traverse(tf_v2, visitor)
def _checkBackwardsCompatibility(
self, root, golden_file_pattern, api_version,
@@ -300,27 +301,24 @@ class ApiCompatibilityTest(test.TestCase):
sys.version_info.major == 2,
'API compatibility test goldens are generated using python2.')
def testAPIBackwardsCompatibilityV1(self):
- if not hasattr(tf.compat, 'v1'):
- return
api_version = 1
golden_file_pattern = os.path.join(
resource_loader.get_root_dir_with_all_resources(),
_KeyToFilePath('*', api_version))
self._checkBackwardsCompatibility(
- tf.compat.v1, golden_file_pattern, api_version)
+ tf_v2.compat.v1, golden_file_pattern, api_version)
@unittest.skipUnless(
sys.version_info.major == 2,
'API compatibility test goldens are generated using python2.')
def testAPIBackwardsCompatibilityV2(self):
- if not hasattr(tf.compat, 'v2'):
- return
api_version = 2
golden_file_pattern = os.path.join(
resource_loader.get_root_dir_with_all_resources(),
_KeyToFilePath('*', api_version))
self._checkBackwardsCompatibility(
- tf.compat.v2, golden_file_pattern, api_version)
+ tf_v2, golden_file_pattern, api_version,
+ additional_private_map={'tf.compat': ['v1']})
if __name__ == '__main__':
diff --git a/tensorflow/tools/ci_build/Dockerfile.cmake b/tensorflow/tools/ci_build/Dockerfile.cmake
index 4587bcf891..b7450c83de 100644
--- a/tensorflow/tools/ci_build/Dockerfile.cmake
+++ b/tensorflow/tools/ci_build/Dockerfile.cmake
@@ -28,8 +28,8 @@ RUN pip install --upgrade astor
RUN pip install --upgrade gast
RUN pip install --upgrade numpy
RUN pip install --upgrade termcolor
-RUN pip install keras_applications==1.0.4
-RUN pip install keras_preprocessing==1.0.2
+RUN pip install keras_applications==1.0.5
+RUN pip install keras_preprocessing==1.0.3
# Install golang
RUN apt-get install -t xenial-backports -y golang-1.9
diff --git a/tensorflow/tools/ci_build/Dockerfile.gpu b/tensorflow/tools/ci_build/Dockerfile.gpu
index 383f9545c9..a4cad4b6c6 100644
--- a/tensorflow/tools/ci_build/Dockerfile.gpu
+++ b/tensorflow/tools/ci_build/Dockerfile.gpu
@@ -30,4 +30,4 @@ RUN mkdir /usr/local/cuda-9.0/lib && \
# Configure the build for our CUDA configuration.
ENV TF_NEED_CUDA 1
-ENV TF_CUDA_COMPUTE_CAPABILITIES 3.0
+ENV TF_NEED_TENSORRT 1
diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04 b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04
new file mode 100644
index 0000000000..a30858db82
--- /dev/null
+++ b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04
@@ -0,0 +1,83 @@
+# To push a new version, run:
+# $ docker build -f Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04 \
+# --tag "gcr.io/asci-toolchain/nosla-cuda9.0-cudnn7-ubuntu14.04" .
+# $ docker push gcr.io/asci-toolchain/nosla-cuda9.0-cudnn7-ubuntu14.04
+#
+# TODO(klimek): Include clang in this image so we can also target clang
+# builds.
+
+FROM ubuntu:14.04
+LABEL maintainer="Manuel Klimek <klimek@google.com>"
+
+RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates apt-transport-https gnupg-curl && \
+ rm -rf /var/lib/apt/lists/* && \
+ NVIDIA_GPGKEY_SUM=d1be581509378368edeec8c1eb2958702feedf3bc3d17011adbf24efacce4ab5 && \
+ NVIDIA_GPGKEY_FPR=ae09fe4bbd223a84b2ccfce3f60f4b3d7fa2af80 && \
+ apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1404/x86_64/7fa2af80.pub && \
+ apt-key adv --export --no-emit-version -a $NVIDIA_GPGKEY_FPR | tail -n +2 > cudasign.pub && \
+ echo "$NVIDIA_GPGKEY_SUM cudasign.pub" | sha256sum -c --strict - && rm cudasign.pub && \
+ echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \
+ echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list
+
+ENV CUDA_VERSION 9.0.176
+ENV CUDA_PKG_VERSION 9-0=$CUDA_VERSION-1
+ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH}
+ENV NVIDIA_VISIBLE_DEVICES all
+ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
+ENV NVIDIA_REQUIRE_CUDA "cuda>=9.0"
+ENV NCCL_VERSION 2.2.13
+ENV CUDNN_VERSION 7.2.1.38
+
+# TODO(b/110903506): /usr/local/cuda/lib64/stubs should not be needed in
+# LD_LIBRARY_PATH. The stubs/libcuda.so is not meant to be used at runtime. The
+# correct way to pass the path to bfd-ld is to pass
+# -Wl,-rpath-link=/usr/local/cuda/lib64/stubs to all binaries transitively
+# depending on libcuda. Optimally, builds targeting cuda would do that
+# internally.
+ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64:/usr/local/cuda/lib64/stubs
+
+LABEL com.nvidia.volumes.needed="nvidia_driver"
+LABEL com.nvidia.cuda.version="${CUDA_VERSION}"
+LABEL com.nvidia.cudnn.version="${CUDNN_VERSION}"
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ cuda-cudart-$CUDA_PKG_VERSION \
+ cuda-libraries-$CUDA_PKG_VERSION \
+ cuda-cublas-9-0=9.0.176.4-1 \
+ libnccl2=$NCCL_VERSION-1+cuda9.0 \
+ cuda-libraries-dev-$CUDA_PKG_VERSION \
+ cuda-nvml-dev-$CUDA_PKG_VERSION \
+ cuda-minimal-build-$CUDA_PKG_VERSION \
+ cuda-command-line-tools-$CUDA_PKG_VERSION \
+ cuda-core-9-0=9.0.176.3-1 \
+ cuda-cublas-dev-9-0=9.0.176.4-1 \
+ libnccl-dev=$NCCL_VERSION-1+cuda9.0 \
+ libcudnn7-dev=$CUDNN_VERSION-1+cuda9.0 \
+ libcudnn7=$CUDNN_VERSION-1+cuda9.0 && \
+ ln -s cuda-9.0 /usr/local/cuda && \
+ apt-mark hold libnccl2 && \
+ apt-mark hold libcudnn7 libcudnn7-dev && \
+ rm -rf /var/lib/apt/lists/*
+
+RUN echo "/usr/local/nvidia/lib" >> /etc/ld.so.conf.d/nvidia.conf && \
+ echo "/usr/local/nvidia/lib64" >> /etc/ld.so.conf.d/nvidia.conf
+
+# TODO(b/110903506): Provide a link to the SONAME of libcuda.so.
+# https://github.com/NVIDIA/nvidia-docker/issues/775
+RUN ln -s libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1
+
+# TODO(klimek): Once the TODO in tensorflow's configure.py to correctly find
+# libnccl is resolved, delete this block.
+RUN ln -s /usr/lib/x86_64-linux-gnu/libnccl.so /usr/lib/libnccl.so \
+ && ln -s /usr/lib/x86_64-linux-gnu/libnccl.so /usr/lib/libnccl.so.2
+
+# Copy and run the install scripts.
+COPY install/*.sh /install/
+ARG DEBIAN_FRONTEND=noninteractive
+RUN /install/install_bootstrap_deb_packages.sh
+RUN add-apt-repository -y ppa:openjdk-r/ppa && \
+ add-apt-repository -y ppa:george-edison55/cmake-3.x
+RUN /install/install_deb_packages.sh
+RUN /install/install_pip_packages.sh
+RUN /install/install_golang.sh
+
diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.gpu b/tensorflow/tools/ci_build/Dockerfile.rbe.gpu
index 24ff4765a6..b656205836 100644
--- a/tensorflow/tools/ci_build/Dockerfile.rbe.gpu
+++ b/tensorflow/tools/ci_build/Dockerfile.rbe.gpu
@@ -19,8 +19,8 @@ RUN /install/install_golang.sh
# Install clang from pre-built package
RUN cd /tmp && \
- wget https://storage.googleapis.com/clang-builds-stable/clang-ubuntu16_04/clang_r323528.tar.gz && \
- echo "26752d9f5785df07193fac8316ba5d5ba3bec36d970c29a1577360848818ac74 clang_r323528.tar.gz" | sha256sum -c && \
+ wget https://storage.googleapis.com/clang-builds-stable/clang-ubuntu16_04/clang_r337145.tar.gz && \
+ echo "ab98c63eb09c04112cc992bc95ebc0dcea8c5e9d0760438789be2896cdc69ff8 clang_r337145.tar.gz" | sha256sum -c && \
- tar -C /usr/local -xf clang_r323528.tar.gz && \
- rm clang_r323528.tar.gz
+ tar -C /usr/local -xf clang_r337145.tar.gz && \
+ rm clang_r337145.tar.gz
diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh
index 993894d658..cc09784c1d 100755
--- a/tensorflow/tools/ci_build/ci_parameterized_build.sh
+++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh
@@ -86,7 +86,7 @@
# When set, overrides TF_BUILD_IS_OPT and TF_BUILD_MAVX
# options, as this will replace the two.
# TF_SKIP_CONTRIB_TESTS:
-# If set to any non-empty or non-0 value, will skipp running
+# If set to any non-empty or non-0 value, will skip running
# contrib tests.
# TF_NIGHTLY:
# If this run is being used to build the tf_nightly pip
@@ -127,11 +127,19 @@ NO_DOCKER_OPT_FLAG="--genrule_strategy=standalone"
DO_DOCKER=1
-BAZEL_CMD="bazel test"
-BAZEL_BUILD_ONLY_CMD="bazel build"
-BAZEL_CLEAN_CMD="bazel clean"
-DEFAULT_BAZEL_CONFIGS=""
+# Helpful flags:
+# --test_summary=detailed: Show detailed results for each test target
+# --keep_going: Don't stop at the first failure; report all failures
+# --build_tests_only: Build only the test targets (and their dependencies)
+#                     matched by the patterns, rather than everything; also
+#                     skips dependencies of disabled tests.
+BAZEL_TEST_FLAGS="--test_summary=detailed --build_tests_only --keep_going"
+BAZEL_BUILD_FLAGS="--keep_going"
+
+BAZEL_CMD="bazel test ${BAZEL_TEST_FLAGS}"
+BAZEL_BUILD_ONLY_CMD="bazel build ${BAZEL_BUILD_FLAGS}"
+BAZEL_CLEAN_CMD="bazel clean"
PIP_CMD="${CI_BUILD_DIR}/builds/pip.sh"
PIP_TEST_TUTORIALS_FLAG="--test_tutorials"
@@ -148,9 +156,7 @@ EXTRA_PARAMS=""
BAZEL_TARGET="//tensorflow/... -//tensorflow/compiler/..."
if [[ -n "$TF_SKIP_CONTRIB_TESTS" ]]; then
- BAZEL_TARGET="$BAZEL_TARGET -//tensorflow/contrib/..."
-else
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/..."
+ BAZEL_TARGET="${BAZEL_TARGET} -//tensorflow/contrib/..."
fi
TUT_TEST_DATA_DIR="/tmp/tf_tutorial_test_data"
@@ -377,6 +383,10 @@ else
if [[ ${IS_MAC} == "1" ]]; then
EXTRA_ARGS="${EXTRA_ARGS},-nomac"
fi
+ EXTRA_ARGS="${EXTRA_ARGS} --build_tag_filters=-no_oss,-oss_serial,-benchmark-test"
+ if [[ ${IS_MAC} == "1" ]]; then
+ EXTRA_ARGS="${EXTRA_ARGS},-nomac"
+ fi
fi
# For any "tool" dependencies in genrules, Bazel will build them for host
@@ -385,7 +395,7 @@ fi
EXTRA_ARGS="${EXTRA_ARGS} --distinct_host_configuration=false"
if [[ ! -z "${TF_BAZEL_BUILD_ONLY}" ]] &&
- [[ "${TF_BAZEL_BUILD_ONLY}" != "0" ]];then
+ [[ "${TF_BAZEL_BUILD_ONLY}" != "0" ]];then
BAZEL_CMD=${BAZEL_BUILD_ONLY_CMD}
fi
diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh
index 866fe95d2b..a98c15d961 100755
--- a/tensorflow/tools/ci_build/ci_sanity.sh
+++ b/tensorflow/tools/ci_build/ci_sanity.sh
@@ -99,6 +99,7 @@ do_pylint() {
"^tensorflow/contrib/layers/python/layers/feature_column\.py.*\[E0110.*abstract-class-instantiated "\
"^tensorflow/contrib/eager/python/evaluator\.py.*\[E0202.*method-hidden "\
"^tensorflow/contrib/eager/python/metrics_impl\.py.*\[E0202.*method-hidden "\
+"^tensorflow/contrib/rate/rate\.py.*\[E0202.*method-hidden "\
"^tensorflow/python/platform/gfile\.py.*\[E0301.*non-iterator "\
"^tensorflow/python/keras/callbacks\.py.*\[E1133.*not-an-iterable "\
"^tensorflow/python/keras/engine/base_layer.py.*\[E0203.*access-member-before-definition "\
diff --git a/tensorflow/tools/ci_build/install/install_deb_packages.sh b/tensorflow/tools/ci_build/install/install_deb_packages.sh
index 9640810533..179fc42d60 100755
--- a/tensorflow/tools/ci_build/install/install_deb_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_deb_packages.sh
@@ -67,6 +67,12 @@ apt-get install -y --no-install-recommends \
zip \
zlib1g-dev
+apt-get update && \
+ apt-get install -y nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0 && \
+ apt-get update && \
+ apt-get install -y libnvinfer4=4.1.2-1+cuda9.0 && \
+ apt-get install -y libnvinfer-dev=4.1.2-1+cuda9.0
+
# populate the database
updatedb
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh
index bb316ecfc9..a9ae715c6a 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh
@@ -115,10 +115,12 @@ pip2 install --upgrade setuptools==39.1.0
pip3 install --upgrade setuptools==39.1.0
# Keras
-pip2 install keras_applications==1.0.4 --no-deps
-pip3 install keras_applications==1.0.4 --no-deps
-pip2 install keras_preprocessing==1.0.2 --no-deps
-pip3 install keras_preprocessing==1.0.2 --no-deps
+pip2 install keras_applications==1.0.5 --no-deps
+pip3 install keras_applications==1.0.5 --no-deps
+pip2 install keras_preprocessing==1.0.3 --no-deps
+pip3 install keras_preprocessing==1.0.3 --no-deps
+pip2 install --upgrade h5py==2.8.0
+pip3 install --upgrade h5py==2.8.0
# Install last working version of setuptools.
pip2 install --upgrade setuptools==39.1.0
diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
index 15e4396ce3..37e6b51f66 100755
--- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
@@ -85,8 +85,9 @@ pip3.5 install --upgrade termcolor
pip3.5 install --upgrade setuptools==39.1.0
# Keras
-pip3.5 install keras_applications==1.0.4
-pip3.5 install keras_preprocessing==1.0.2
+pip3.5 install keras_applications==1.0.5
+pip3.5 install keras_preprocessing==1.0.3
+pip3.5 install --upgrade h5py==2.8.0
# Install last working version of setuptools.
pip3.5 install --upgrade setuptools==39.1.0
diff --git a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
index 0fc3eee71c..7520ff74cb 100755
--- a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
@@ -99,9 +99,10 @@ pip3 install --upgrade termcolor
# Install last working version of setuptools.
pip3 install --upgrade setuptools==39.1.0
+pip3 install --upgrade h5py==2.8.0
# Keras
-pip3 install keras_applications==1.0.4
-pip3 install keras_preprocessing==1.0.2
+pip3 install keras_applications==1.0.5
+pip3 install keras_preprocessing==1.0.3
# LINT.ThenChange(//tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh)
diff --git a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh
index f958b3c9b7..60c974c36b 100755
--- a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh
+++ b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh
@@ -52,6 +52,7 @@ ${DOCKER_BINARY} run \
-e "PYTHON_BIN_PATH=/usr/bin/python" \
-e "TF_NEED_HDFS=0" \
-e "TF_NEED_CUDA=${TF_NEED_CUDA}" \
+ -e "TF_NEED_TENSORRT=${TF_NEED_CUDA}" \
-e "TF_NEED_OPENCL_SYCL=0" \
"${DOCKER_IMAGE}" \
"/workspace/tensorflow/tools/ci_build/linux/libtensorflow.sh"
diff --git a/tensorflow/tools/ci_build/update_version.py b/tensorflow/tools/ci_build/update_version.py
index 30c318a58f..4373d464b6 100755
--- a/tensorflow/tools/ci_build/update_version.py
+++ b/tensorflow/tools/ci_build/update_version.py
@@ -211,44 +211,6 @@ def update_readme(old_version, new_version):
"%s-" % pep_440_str, README_MD)
-def update_md_files(old_version, new_version):
- """Update the md doc files.
-
- Args:
- old_version: Version object of current version
- new_version: Version object of new version
- """
-
- old_pep_version = old_version.pep_440_str
- new_pep_version = new_version.pep_440_str
- for filename in ["linux", "mac", "windows", "sources"]:
- filepath = "%s/docs_src/install/install_%s.md" % (TF_SRC_DIR,
- filename)
-
- if filename == "sources" and "rc0" in new_pep_version:
- replace_string_in_line("(?<!<td>)tensorflow-%s" % old_pep_version,
- "tensorflow-%s" % new_pep_version, filepath)
- replace_string_in_line("(?<!<td>)tensorflow_gpu-%s" % old_pep_version,
- "tensorflow_gpu-%s" % new_pep_version, filepath)
- else:
- replace_string_in_line("tensorflow-%s" % old_pep_version,
- "tensorflow-%s" % new_pep_version, filepath)
- replace_string_in_line("tensorflow_gpu-%s" % old_pep_version,
- "tensorflow_gpu-%s" % new_pep_version, filepath)
- replace_string_in_line("TensorFlow %s" % old_pep_version,
- "TensorFlow %s" % new_pep_version, filepath)
-
- for filename in ["java", "go", "c"]:
- filepath = "%s/docs_src/install/install_%s.md" % (TF_SRC_DIR,
- filename)
- replace_string_in_line(r"x86_64-%s" % old_version,
- "x86_64-%s" % new_version, filepath)
- replace_string_in_line(r"libtensorflow-%s.jar" % old_version,
- "libtensorflow-%s.jar" % new_version, filepath)
- replace_string_in_line(r"<version>%s<\/version>" % old_version,
- "<version>%s</version>" % new_version, filepath)
-
-
def major_minor_change(old_version, new_version):
"""Check if a major or minor change occurred."""
major_mismatch = old_version.major != new_version.major
@@ -350,7 +312,6 @@ def main():
update_version_h(old_version, new_version)
update_setup_dot_py(old_version, new_version)
update_readme(old_version, new_version)
- update_md_files(old_version, new_version)
update_dockerfiles(old_version, new_version)
# Print transition details.
@@ -359,12 +320,6 @@ def main():
print("Patch: %s -> %s\n" % (old_version.patch, new_version.patch))
check_for_old_version(old_version, new_version)
- if "rc0" in str(new_version):
- print("\n\n\033[93mNOTE: Please update the tensorflow/docs_src/install/"
- "install_sources.md and add a line for tensorflow-%s and "
- "tensorflow_gpu-%s in the tested source configurations "
- "table.\033[0m\n" % (new_version.pep_440_str,
- new_version.pep_440_str))
if __name__ == "__main__":
diff --git a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
index 0482cf619a..27b350e13e 100644
--- a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
+++ b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
@@ -27,7 +27,7 @@ function run_configure_for_gpu_build {
}
function set_remote_cache_options {
- echo "build --remote_instance_name=projects/tensorflow-testing-cpu" >> "${TMP_BAZELRC}"
+ echo "build --remote_instance_name=projects/tensorflow-testing/instances/default_instance" >> "${TMP_BAZELRC}"
echo "build --experimental_remote_platform_override='properties:{name:\"build\" value:\"windows-x64\"}'" >> "${TMP_BAZELRC}"
echo "build --remote_cache=remotebuildexecution.googleapis.com" >> "${TMP_BAZELRC}"
echo "build --tls_enabled=true" >> "${TMP_BAZELRC}"
diff --git a/tensorflow/tools/ci_build/windows/bazel/common_env.sh b/tensorflow/tools/ci_build/windows/bazel/common_env.sh
index 333a89d3f5..c18f0d6e69 100644
--- a/tensorflow/tools/ci_build/windows/bazel/common_env.sh
+++ b/tensorflow/tools/ci_build/windows/bazel/common_env.sh
@@ -53,7 +53,7 @@ export PATH="/c/${PYTHON_BASE_PATH}/Scripts:$PATH"
# Setting default values to CUDA related environment variables
export TF_CUDA_VERSION=${TF_CUDA_VERSION:-9.0}
-export TF_CUDNN_VERSION=${TF_CUDNN_VERSION:-7.0}
+export TF_CUDNN_VERSION=${TF_CUDNN_VERSION:-7}
export TF_CUDA_COMPUTE_CAPABILITIES=${TF_CUDA_COMPUTE_CAPABILITIES:-3.7}
export CUDA_TOOLKIT_PATH=${CUDA_TOOLKIT_PATH:-"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${TF_CUDA_VERSION}"}
export CUDNN_INSTALL_PATH=${CUDNN_INSTALL_PATH:-"C:/tools/cuda"}
diff --git a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
index 47e0e5dd59..177ef390db 100644
--- a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
+++ b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
@@ -57,8 +57,7 @@ PY_TEST_DIR="py_test_dir"
SKIP_TEST=0
RELEASE_BUILD=0
-TEST_TARGET="//${PY_TEST_DIR}/tensorflow/python/... \
- //${PY_TEST_DIR}/tensorflow/contrib/... "
+TEST_TARGET="//${PY_TEST_DIR}/tensorflow/python/..."
# --skip_test Skip running tests
# --enable_remote_cache Add options to enable remote cache for build and test
@@ -68,6 +67,7 @@ TEST_TARGET="//${PY_TEST_DIR}/tensorflow/python/... \
# --test_contrib_only Use tensorflow/contrib/... as test target
for ARG in "$@"; do
case "$ARG" in
+ --tf_nightly) TF_NIGHTLY=1 ;;
--skip_test) SKIP_TEST=1 ;;
--enable_remote_cache) set_remote_cache_options ;;
--release_build) RELEASE_BUILD=1 ;;
@@ -86,6 +86,11 @@ else
export TF_OVERRIDE_EIGEN_STRONG_INLINE=1
fi
+if [[ "$TF_NIGHTLY" == 1 ]]; then
+ python tensorflow/tools/ci_build/update_version.py --nightly
+ EXTRA_PIP_FLAG="--nightly_flag"
+fi
+
# Enable short object file path to avoid long path issue on Windows.
echo "startup --output_user_root=${TMPDIR}" >> "${TMP_BAZELRC}"
@@ -104,7 +109,11 @@ fi
# Create a python test directory to avoid package name conflict
create_python_test_dir "${PY_TEST_DIR}"
-./bazel-bin/tensorflow/tools/pip_package/build_pip_package "$PWD/${PY_TEST_DIR}"
+./bazel-bin/tensorflow/tools/pip_package/build_pip_package "$PWD/${PY_TEST_DIR}" "${EXTRA_PIP_FLAG}"
+
+if [[ "$TF_NIGHTLY" == 1 ]]; then
+ exit 0
+fi
# Running python tests on Windows needs pip package installed
PIP_NAME=$(ls ${PY_TEST_DIR}/tensorflow-*.whl)
diff --git a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
index e3eee11080..28d5565b98 100644
--- a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
+++ b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
@@ -57,8 +57,7 @@ PY_TEST_DIR="py_test_dir"
SKIP_TEST=0
RELEASE_BUILD=0
-TEST_TARGET="//${PY_TEST_DIR}/tensorflow/python/... \
- //${PY_TEST_DIR}/tensorflow/contrib/... "
+TEST_TARGET="//${PY_TEST_DIR}/tensorflow/python/..."
# --skip_test Skip running tests
# --enable_remote_cache Add options to enable remote cache for build and test
@@ -68,6 +67,7 @@ TEST_TARGET="//${PY_TEST_DIR}/tensorflow/python/... \
# --test_contrib_only Use tensorflow/contrib/... as test target
for ARG in "$@"; do
case "$ARG" in
+ --tf_nightly) TF_NIGHTLY=1 ;;
--skip_test) SKIP_TEST=1 ;;
--enable_remote_cache) set_remote_cache_options ;;
--release_build) RELEASE_BUILD=1 ;;
@@ -86,6 +86,11 @@ else
export TF_OVERRIDE_EIGEN_STRONG_INLINE=1
fi
+if [[ "$TF_NIGHTLY" == 1 ]]; then
+ python tensorflow/tools/ci_build/update_version.py --nightly
+ EXTRA_PIP_FLAG="--nightly_flag"
+fi
+
# Enable short object file path to avoid long path issue on Windows.
echo "startup --output_user_root=${TMPDIR}" >> "${TMP_BAZELRC}"
@@ -107,10 +112,14 @@ fi
# Create a python test directory to avoid package name conflict
create_python_test_dir "${PY_TEST_DIR}"
-./bazel-bin/tensorflow/tools/pip_package/build_pip_package "$PWD/${PY_TEST_DIR}"
+./bazel-bin/tensorflow/tools/pip_package/build_pip_package "$PWD/${PY_TEST_DIR}" --gpu "${EXTRA_PIP_FLAG}"
+
+if [[ "$TF_NIGHTLY" == 1 ]]; then
+ exit 0
+fi
# Running python tests on Windows needs pip package installed
-PIP_NAME=$(ls ${PY_TEST_DIR}/tensorflow-*.whl)
+PIP_NAME=$(ls ${PY_TEST_DIR}/tensorflow_gpu-*.whl)
reinstall_tensorflow_pip ${PIP_NAME}
TF_GPU_COUNT=${TF_GPU_COUNT:-8}
diff --git a/tensorflow/tools/common/public_api.py b/tensorflow/tools/common/public_api.py
index 09933d266b..82bb0713c4 100644
--- a/tensorflow/tools/common/public_api.py
+++ b/tensorflow/tools/common/public_api.py
@@ -102,9 +102,10 @@ class PublicAPIVisitor(object):
"""Override the default root name of 'tf'."""
self._root_name = root_name
- def _is_private(self, path, name):
+ def _is_private(self, path, name, obj=None):
"""Return whether a name is private."""
# TODO(wicke): Find out what names to exclude.
+ del obj # Unused.
return ((path in self._private_map and
name in self._private_map[path]) or
(name.startswith('_') and not re.match('__.*__$', name) or
@@ -129,7 +130,7 @@ class PublicAPIVisitor(object):
# Remove things that are not visible.
for name, child in list(children):
- if self._is_private(full_path, name):
+ if self._is_private(full_path, name, child):
children.remove((name, child))
self._visitor(path, parent, children)
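The extra obj parameter threaded through _is_private above lets visibility decisions inspect the object itself rather than just its name. A hypothetical subclass showing the hook (the _tf_deprecated attribute is made up for illustration and is not a real TF marker):

    from tensorflow.tools.common import public_api

    class SkipDeprecatedVisitor(public_api.PublicAPIVisitor):
      """Hides members flagged as deprecated, in addition to the usual rules."""

      def _is_private(self, path, name, obj=None):
        if obj is not None and getattr(obj, '_tf_deprecated', False):
          return True  # hypothetical attribute, for illustration only
        return super(SkipDeprecatedVisitor, self)._is_private(path, name, obj)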
diff --git a/tensorflow/tools/compatibility/renames_v2.py b/tensorflow/tools/compatibility/renames_v2.py
index 216aa41b60..7e66ad816a 100644
--- a/tensorflow/tools/compatibility/renames_v2.py
+++ b/tensorflow/tools/compatibility/renames_v2.py
@@ -65,6 +65,7 @@ renames = {
'tf.fft': 'tf.spectral.fft',
'tf.floor': 'tf.math.floor',
'tf.gather_nd': 'tf.manip.gather_nd',
+ 'tf.GraphKeys.VARIABLES': 'tf.GraphKeys.GLOBAL_VARIABLES',
'tf.greater': 'tf.math.greater',
'tf.greater_equal': 'tf.math.greater_equal',
'tf.ifft': 'tf.spectral.ifft',
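A quick illustration of how this mapping is consumed (the driver lives in ast_edits/tf_upgrade_v2, not in this hunk): renames is a plain dict from old to new dotted names.

    from tensorflow.tools.compatibility import renames_v2

    assert renames_v2.renames['tf.GraphKeys.VARIABLES'] == (
        'tf.GraphKeys.GLOBAL_VARIABLES')
    assert renames_v2.renames['tf.floor'] == 'tf.math.floor'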
diff --git a/tensorflow/tools/compatibility/testdata/test_file_v0_11.py b/tensorflow/tools/compatibility/testdata/test_file_v0_11.py
index 01f37d8768..35a74c9664 100644
--- a/tensorflow/tools/compatibility/testdata/test_file_v0_11.py
+++ b/tensorflow/tools/compatibility/testdata/test_file_v0_11.py
@@ -35,7 +35,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
"""
def testArgRenames(self):
- with self.test_session():
+ with self.cached_session():
a = [[1., 2., 3.], [4., 5., 6.]]
b = [[True, False, False], [False, True, True]]
@@ -98,7 +98,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
[[[1, 2]], [[3, 4]]])
def testArgMinMax(self):
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
tf.argmin([[1, 2, 3], [4, 1, 0]], dimension=1).eval(),
[0, 2])
@@ -113,7 +113,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
[1, 0, 0])
def testExpandAndSqueeze(self):
- with self.test_session():
+ with self.cached_session():
# TODO(aselle): sparse_split, sparse_reduce_sum,
# sparse_reduce_sum_sparse, reduce_join
@@ -140,7 +140,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
a)
def testArithmeticRenames(self):
- with self.test_session() as s:
+ with self.cached_session() as s:
stuff = tf.split(1, 2, [[1, 2, 3, 4], [4, 5, 6, 7]])
vals = s.run(stuff)
self.assertAllEqual(vals,
@@ -164,7 +164,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
# ]
def testBatchAndSvd(self):
- with self.test_session():
+ with self.cached_session():
mat = [[1., 2.], [2., 3.]]
batched_mat = tf.expand_dims(mat, [0])
result = tf.matmul(mat, mat).eval()
@@ -176,7 +176,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
def testCrossEntropy(self):
# TODO(aselle): Test sparse_softmax_...
- with self.test_session():
+ with self.cached_session():
labels = [.8, .5, .2, .1]
logits = [.9, .1, .3, .1]
self.assertAllEqual(
@@ -191,7 +191,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
labels=labels, logits=logits).eval())
def testVariables(self):
- with self.test_session() as s:
+ with self.cached_session() as s:
# make some variables
_ = [tf.Variable([1, 2, 3], dtype=tf.float32),
@@ -201,7 +201,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
_ = [v.name for v in tf.local_variables()]
def testSummaries(self):
- with self.test_session() as s:
+ with self.cached_session() as s:
var = tf.Variable([1, 2, 3], dtype=tf.float32)
s.run(tf.initialize_all_variables())
x, y = np.meshgrid(np.linspace(-10, 10, 256), np.linspace(-10, 10, 256))
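A short sketch of the idiom these hunks migrate to, assuming the 1.x test_util names; cached_session() makes the session reuse within a test explicit, and the conversion is otherwise mechanical:

    import tensorflow as tf
    from tensorflow.python.framework import test_util

    class ExampleTest(test_util.TensorFlowTestCase):

      def testAdd(self):
        with self.cached_session():
          self.assertAllEqual(tf.add([1, 2], [1, 2]).eval(), [2, 4])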
diff --git a/tensorflow/tools/compatibility/testdata/test_file_v1_10.py b/tensorflow/tools/compatibility/testdata/test_file_v1_10.py
index a49035a1a0..e5ca8d3e2e 100644
--- a/tensorflow/tools/compatibility/testdata/test_file_v1_10.py
+++ b/tensorflow/tools/compatibility/testdata/test_file_v1_10.py
@@ -26,7 +26,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
"""Test various APIs that have been changed in 2.0."""
def testRenames(self):
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(1.04719755, tf.acos(0.5).eval())
self.assertAllClose(0.5, tf.rsqrt(4.0).eval())
diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py
index 9702430a12..38216ce9b1 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_v2.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import argparse
+import functools
from tensorflow.tools.compatibility import ast_edits
from tensorflow.tools.compatibility import renames_v2
@@ -45,6 +46,29 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
# Specially handled functions.
self.function_handle = {}
+ for decay in ["tf.train.exponential_decay", "tf.train.piecewise_constant",
+ "tf.train.polynomial_decay", "tf.train.natural_exp_decay",
+ "tf.train.inverse_time_decay", "tf.train.cosine_decay",
+ "tf.train.cosine_decay_restarts",
+ "tf.train.linear_cosine_decay",
+ "tf.train.noisy_linear_cosine_decay"]:
+ self.function_handle[decay] = functools.partial(
+ self._learning_rate_decay_handler, decay_name=decay)
+
+ @staticmethod
+ def _learning_rate_decay_handler(file_edit_recorder, node, decay_name):
+ comment = ("ERROR: %s has been changed to return a callable instead of a "
+ "tensor when graph building, but its functionality remains "
+ "unchanged during eager execution (returns a callable like "
+ "before). The converter cannot detect and fix this reliably, so "
+ "you need to inspect this usage manually.\n") % decay_name
+ file_edit_recorder.add(
+ comment,
+ node.lineno,
+ node.col_offset,
+ decay_name,
+ decay_name,
+ error="%s requires manual check." % decay_name)
if __name__ == "__main__":
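A hypothetical before/after for the decay handler above: the converter leaves the call site untouched and instead records an error so the usage gets inspected manually (file name and line number below are placeholders):

    import tensorflow as tf

    global_step = tf.train.get_or_create_global_step()
    # Left unchanged by the upgrade script; flagged for manual review instead:
    lr = tf.train.exponential_decay(0.1, global_step, 1000, 0.96)

    # Expected report entry, per the handler above:
    #   model.py:1: tf.train.exponential_decay requires manual check.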
diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
index 57ac04de06..3886c1e8b9 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
@@ -63,6 +63,19 @@ class TestUpgrade(test_util.TensorFlowTestCase):
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(new_text, "tf.math.rsqrt(tf.math.log(3.8))\n")
+ def testLearningRateDecay(self):
+ for decay in ["tf.train.exponential_decay", "tf.train.piecewise_constant",
+ "tf.train.polynomial_decay", "tf.train.natural_exp_decay",
+ "tf.train.inverse_time_decay", "tf.train.cosine_decay",
+ "tf.train.cosine_decay_restarts",
+ "tf.train.linear_cosine_decay",
+ "tf.train.noisy_linear_cosine_decay"]:
+
+ text = "%s(a, b)\n" % decay
+ _, unused_report, errors, new_text = self._upgrade(text)
+ self.assertEqual(text, new_text)
+ self.assertEqual(errors, ["test.py:1: %s requires manual check." % decay])
+
class TestUpgradeFiles(test_util.TensorFlowTestCase):
diff --git a/tensorflow/tools/docker/Dockerfile b/tensorflow/tools/docker/Dockerfile
index 2c31d784e5..b5a6c05193 100644
--- a/tensorflow/tools/docker/Dockerfile
+++ b/tensorflow/tools/docker/Dockerfile
@@ -29,10 +29,10 @@ RUN pip --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.4 \
- keras_preprocessing==1.0.2 \
+ keras_applications==1.0.5 \
+ keras_preprocessing==1.0.3 \
matplotlib \
- numpy==1.14.5 \
+ numpy \
pandas \
scipy \
sklearn \
diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel
index bacdea72ce..39e7bc8b66 100644
--- a/tensorflow/tools/docker/Dockerfile.devel
+++ b/tensorflow/tools/docker/Dockerfile.devel
@@ -33,11 +33,11 @@ RUN pip --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.4 \
- keras_preprocessing==1.0.2 \
+ keras_applications==1.0.5 \
+ keras_preprocessing==1.0.3 \
matplotlib \
mock \
- numpy==1.14.5 \
+ numpy \
scipy \
sklearn \
pandas \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu
index 4f89e3f701..e487779e7a 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu
@@ -13,8 +13,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
cuda-cusparse-dev-9-0 \
curl \
git \
- libcudnn7=7.1.4.18-1+cuda9.0 \
- libcudnn7-dev=7.1.4.18-1+cuda9.0 \
+ libcudnn7=7.2.1.38-1+cuda9.0 \
+ libcudnn7-dev=7.2.1.38-1+cuda9.0 \
libnccl2=2.2.13-1+cuda9.0 \
libnccl-dev=2.2.13-1+cuda9.0 \
libcurl3-dev \
@@ -35,6 +35,12 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
find /usr/local/cuda-9.0/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
rm /usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a
+RUN apt-get update && \
+ apt-get install -y nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0 && \
+ apt-get update && \
+ apt-get install -y libnvinfer4=4.1.2-1+cuda9.0 && \
+ apt-get install -y libnvinfer-dev=4.1.2-1+cuda9.0
+
# Link NCCL library and header where the build script expects them.
RUN mkdir /usr/local/cuda-9.0/lib && \
ln -s /usr/lib/x86_64-linux-gnu/libnccl.so.2 /usr/local/cuda/lib/libnccl.so.2 && \
@@ -49,11 +55,11 @@ RUN pip --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.4 \
- keras_preprocessing==1.0.2 \
+ keras_applications==1.0.5 \
+ keras_preprocessing==1.0.3 \
matplotlib \
mock \
- numpy==1.14.5 \
+ numpy \
scipy \
sklearn \
pandas \
@@ -100,6 +106,7 @@ RUN git clone --branch=r1.10 --depth=1 https://github.com/tensorflow/tensorflow.
ENV CI_BUILD_PYTHON python
ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
ENV TF_NEED_CUDA 1
+ENV TF_NEED_TENSORRT 1
ENV TF_CUDA_COMPUTE_CAPABILITIES=3.5,5.2,6.0,6.1,7.0
ENV TF_CUDA_VERSION=9.0
ENV TF_CUDNN_VERSION=7
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7 b/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7
deleted file mode 100644
index 056b4755f4..0000000000
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7
+++ /dev/null
@@ -1,117 +0,0 @@
-FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04
-
-LABEL maintainer="Gunhan Gulsoy <gunan@google.com>"
-
-# It is possible to override these for releases.
-ARG TF_BRANCH=master
-ARG BAZEL_VERSION=0.15.0
-ARG TF_AVAILABLE_CPUS=32
-
-RUN apt-get update && apt-get install -y --no-install-recommends \
- build-essential \
- curl \
- git \
- golang \
- libcurl3-dev \
- libfreetype6-dev \
- libpng12-dev \
- libzmq3-dev \
- pkg-config \
- python-dev \
- python-pip \
- rsync \
- software-properties-common \
- unzip \
- zip \
- zlib1g-dev \
- openjdk-8-jdk \
- openjdk-8-jre-headless \
- wget \
- && \
- apt-get clean && \
- rm -rf /var/lib/apt/lists/*
-
-RUN pip --no-cache-dir install --upgrade \
- pip setuptools
-
-RUN pip --no-cache-dir install \
- ipykernel \
- jupyter \
- keras_applications==1.0.4 \
- keras_preprocessing==1.0.2 \
- matplotlib \
- numpy \
- scipy \
- sklearn \
- pandas \
- wheel \
- && \
- python -m ipykernel.kernelspec
-
-# Set up our notebook config.
-COPY jupyter_notebook_config.py /root/.jupyter/
-
-# Jupyter has issues with being run directly:
-# https://github.com/ipython/ipython/issues/7062
-# We just add a little wrapper script.
-COPY run_jupyter.sh /
-
-# Set up Bazel.
-
-# Running bazel inside a `docker build` command causes trouble, cf:
-# https://github.com/bazelbuild/bazel/issues/134
-# The easiest solution is to set up a bazelrc file forcing --batch.
-RUN echo "startup --batch" >>/etc/bazel.bazelrc
-# Similarly, we need to workaround sandboxing issues:
-# https://github.com/bazelbuild/bazel/issues/418
-RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
- >>/etc/bazel.bazelrc
-WORKDIR /
-RUN mkdir /bazel && \
- cd /bazel && \
- wget --quiet https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
- wget --quiet https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \
- chmod +x bazel-*.sh && \
- ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
- rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh
-
-# Download and build TensorFlow.
-WORKDIR /
-RUN git clone https://github.com/tensorflow/tensorflow.git && \
- cd tensorflow && \
- git checkout ${TF_BRANCH}
-WORKDIR /tensorflow
-
-# Configure the build for our CUDA configuration.
-ENV CI_BUILD_PYTHON=python \
- LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:${LD_LIBRARY_PATH} \
- CUDNN_INSTALL_PATH=/usr/lib/x86_64-linux-gnu \
- PYTHON_BIN_PATH=/usr/bin/python \
- PYTHON_LIB_PATH=/usr/local/lib/python2.7/dist-packages \
- TF_NEED_CUDA=1 \
- TF_CUDA_VERSION=9.0 \
- TF_CUDA_COMPUTE_CAPABILITIES=3.0,3.5,5.2,6.0,6.1,7.0 \
- TF_CUDNN_VERSION=7
-RUN ./configure
-
-# Build and Install TensorFlow.
-RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1 && \
- LD_LIBRARY_PATH=/usr/local/cuda/lib64/stubs:${LD_LIBRARY_PATH} \
- bazel build -c opt \
- --config=cuda \
- --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" \
- --jobs=${TF_AVAILABLE_CPUS} \
- tensorflow/tools/pip_package:build_pip_package && \
- mkdir /pip_pkg && \
- bazel-bin/tensorflow/tools/pip_package/build_pip_package /pip_pkg && \
- pip --no-cache-dir install --upgrade /pip_pkg/tensorflow-*.whl && \
- rm -rf /pip_pkg && \
- rm -rf /root/.cache
-# Clean up pip wheel and Bazel cache when done.
-
-WORKDIR /root
-
-# TensorBoard
-EXPOSE 6006
-# IPython
-EXPOSE 8888
diff --git a/tensorflow/tools/docker/Dockerfile.devel-mkl b/tensorflow/tools/docker/Dockerfile.devel-mkl
index 2df770e525..371451d2aa 100755
--- a/tensorflow/tools/docker/Dockerfile.devel-mkl
+++ b/tensorflow/tools/docker/Dockerfile.devel-mkl
@@ -52,8 +52,8 @@ RUN ${PIP} --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.4 \
- keras_preprocessing==1.0.2 \
+ keras_applications==1.0.5 \
+ keras_preprocessing==1.0.3 \
matplotlib \
mock \
numpy \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod b/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod
index ab2eec1728..987b582d10 100755
--- a/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod
+++ b/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod
@@ -45,8 +45,8 @@ RUN ${PIP} --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.4 \
- keras_preprocessing==1.0.2 \
+ keras_applications==1.0.5 \
+ keras_preprocessing==1.0.3 \
matplotlib \
mock \
numpy \
diff --git a/tensorflow/tools/docker/Dockerfile.gpu b/tensorflow/tools/docker/Dockerfile.gpu
index aa0e0face1..781bf9e851 100644
--- a/tensorflow/tools/docker/Dockerfile.gpu
+++ b/tensorflow/tools/docker/Dockerfile.gpu
@@ -12,7 +12,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
cuda-cusolver-9-0 \
cuda-cusparse-9-0 \
curl \
- libcudnn7=7.1.4.18-1+cuda9.0 \
+ libcudnn7=7.2.1.38-1+cuda9.0 \
libnccl2=2.2.13-1+cuda9.0 \
libfreetype6-dev \
libhdf5-serial-dev \
@@ -28,6 +28,11 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
+RUN apt-get update && \
+        apt-get install -y nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0 && \
+        apt-get update && \
+        apt-get install -y libnvinfer4=4.1.2-1+cuda9.0
+
RUN curl -O https://bootstrap.pypa.io/get-pip.py && \
python get-pip.py && \
rm get-pip.py
@@ -37,10 +42,10 @@ RUN pip --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.4 \
- keras_preprocessing==1.0.2 \
+ keras_applications==1.0.5 \
+ keras_preprocessing==1.0.3 \
matplotlib \
- numpy==1.14.5 \
+ numpy \
pandas \
scipy \
sklearn \
diff --git a/tensorflow/tools/docker/Dockerfile.mkl b/tensorflow/tools/docker/Dockerfile.mkl
index 69553302d8..641c9e3b16 100755
--- a/tensorflow/tools/docker/Dockerfile.mkl
+++ b/tensorflow/tools/docker/Dockerfile.mkl
@@ -38,8 +38,8 @@ RUN ${PIP} --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.4 \
- keras_preprocessing==1.0.2 \
+ keras_applications==1.0.5 \
+ keras_preprocessing==1.0.3 \
matplotlib \
numpy \
pandas \
diff --git a/tensorflow/tools/docker/Dockerfile.mkl-horovod b/tensorflow/tools/docker/Dockerfile.mkl-horovod
index 756716ee0e..2b11679f54 100755
--- a/tensorflow/tools/docker/Dockerfile.mkl-horovod
+++ b/tensorflow/tools/docker/Dockerfile.mkl-horovod
@@ -38,8 +38,8 @@ RUN ${PIP} --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.4 \
- keras_preprocessing==1.0.2 \
+ keras_applications==1.0.5 \
+ keras_preprocessing==1.0.3 \
matplotlib \
numpy \
pandas \
diff --git a/tensorflow/tools/docker/README.md b/tensorflow/tools/docker/README.md
index a286e8a212..263f25bc48 100644
--- a/tensorflow/tools/docker/README.md
+++ b/tensorflow/tools/docker/README.md
@@ -1,3 +1,10 @@
+# WARNING: THESE IMAGES ARE DEPRECATED.
+
+TensorFlow's Dockerfiles are now located in
+[`tensorflow/tools/dockerfiles/`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/dockerfiles).
+
+This directory will eventually be removed.
+
# Using TensorFlow via Docker
This directory contains `Dockerfile`s to make it easy to get up and running with
diff --git a/tensorflow/tools/dockerfiles/README.md b/tensorflow/tools/dockerfiles/README.md
new file mode 100644
index 0000000000..d64db35afb
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/README.md
@@ -0,0 +1,67 @@
+# TensorFlow Dockerfiles
+
+This directory houses TensorFlow's Dockerfiles. **DO NOT EDIT THE DOCKERFILES
+MANUALLY!** They are maintained by `assembler.py`, which builds Dockerfiles from
+the files in `partials/` and the rules in `spec.yml`. See [the Contributing
+section](#contributing) for more information.
+
+## Building
+
+The Dockerfiles in the `dockerfiles` directory must have their build context set
+to **the directory with this README.md** to copy in helper files. For example:
+
+```bash
+$ docker build -f ./dockerfiles/cpu.Dockerfile -t tf .
+```
+
+Each Dockerfile has its own set of available `--build-arg`s which are documented
+in the Dockerfile itself.
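+
+For example, a hypothetical invocation (the tag `tf-nightly-image` is just an
+illustration) that swaps the pip package baked into the CPU-only image for a
+nightly build:
+
+```bash
+$ docker build --build-arg TF_PACKAGE=tf-nightly \
+    -f ./dockerfiles/cpu.Dockerfile -t tf-nightly-image .
+```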
+
+## Running
+
+After building the image with the tag `tf` (for example), use `docker run` to
+run the images. Examples are below.
+
+Note for new Docker users: the `-v` and `-u` flags share directories between
+the Docker container and your machine, and are very important. Without
+`-v`, your work will be wiped once the container quits, and without `-u`, files
+created by the container will have the wrong file permissions on your host
+machine. If you are confused, check out the [Docker run
+documentation](https://docs.docker.com/engine/reference/run/).
+
+```bash
+# Volume mount (-v) is optional but highly recommended, especially for Jupyter.
+# User permissions (-u) are required if you use (-v).
+
+# CPU-based images
+$ docker run -u $(id -u):$(id -g) -v $(pwd):/my-devel -it tf
+
+# GPU-based images (set up nvidia-docker2 first)
+$ docker run --runtime=nvidia -u $(id -u):$(id -g) -v $(pwd):/my-devel -it tf
+
+# Images with Jupyter run on port 8888 and need a volume for notebooks
+$ docker run --user $(id -u):$(id -g) -p 8888:8888 -v $(pwd):/notebooks -it tf
+```
+
+These images do not come with the TensorFlow source code -- but the development
+images have git included, so you can `git clone` it yourself.
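+
+For example, inside one of the devel containers:
+
+```bash
+$ git clone https://github.com/tensorflow/tensorflow.git
+```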
+
+## Contributing
+
+To make changes to TensorFlow's Dockerfiles, you'll update `spec.yml` and the
+`*.partial.Dockerfile` files in the `partials` directory, then run
+`assembler.py` to re-generate the full Dockerfiles before creating a pull
+request.
+
+You can use the `Dockerfile` in this directory to build an editing environment
+that has all of the Python dependencies you'll need:
+
+```bash
+$ docker build -t tf-assembler -f assembler.Dockerfile .
+
+# Set --user to set correct permissions on generated files
+$ docker run --user $(id -u):$(id -g) -it -v $(pwd):/tf tf-assembler bash
+
+# In the container...
+/tf $ python3 ./assembler.py -o dockerfiles -s spec.yml
+```
diff --git a/tensorflow/contrib/kfac/python/ops/estimator_lib.py b/tensorflow/tools/dockerfiles/assembler.Dockerfile
index 9c9fef471f..7a8e07fced 100644
--- a/tensorflow/contrib/kfac/python/ops/estimator_lib.py
+++ b/tensorflow/tools/dockerfiles/assembler.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,20 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Defines the high-level Fisher estimator class."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+#
+# TensorFlow Dockerfile Development Container
+#
+# You can use this image to quickly develop changes to the Dockerfile assembler
+# or set of TF Docker partials. See README.md for usage instructions.
+FROM debian:stretch
+LABEL maintainer="Austin Anderson <angerson@google.com>"
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.estimator import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
+RUN apt-get update && apt-get install -y python3 python3-pip bash
+RUN pip3 install --upgrade pip setuptools pyyaml absl-py cerberus
-_allowed_symbols = [
- 'FisherEstimator',
- 'make_fisher_estimator',
-]
+WORKDIR /tf
+VOLUME ["/tf"]
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
diff --git a/tensorflow/tools/dockerfiles/assembler.py b/tensorflow/tools/dockerfiles/assembler.py
new file mode 100644
index 0000000000..9cdd9bb0cb
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/assembler.py
@@ -0,0 +1,554 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Assemble common TF Dockerfiles from many parts.
+
+This script constructs TF's Dockerfiles by aggregating partial
+Dockerfiles. See README.md for usage examples.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import errno
+import os
+import os.path
+import re
+import shutil
+import textwrap
+
+from absl import app
+from absl import flags
+import cerberus
+import yaml
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_boolean(
+ 'dry_run', False, 'Do not actually generate Dockerfiles', short_name='n')
+
+flags.DEFINE_string(
+ 'spec_file',
+ './spec.yml',
+ 'Path to a YAML specification file',
+ short_name='s')
+
+flags.DEFINE_string(
+ 'output_dir',
+ './dockerfiles', ('Path to an output directory for Dockerfiles. '
+ 'Will be created if it doesn\'t exist.'),
+ short_name='o')
+
+flags.DEFINE_string(
+ 'partial_dir',
+ './partials',
+ 'Path to a directory containing foo.partial.Dockerfile partial files.',
+ short_name='p')
+
+flags.DEFINE_boolean(
+ 'quiet_dry_run',
+ True,
+ 'Do not print contents of dry run Dockerfiles.',
+ short_name='q')
+
+flags.DEFINE_boolean(
+ 'validate', True, 'Validate generated Dockerfiles', short_name='c')
+
+# Schema to verify the contents of spec.yml with Cerberus.
+# Must be converted to a dict from yaml to work.
+# Note: can add python references with e.g.
+# !!python/name:builtins.str
+# !!python/name:__main__.funcname
+SCHEMA_TEXT = """
+header:
+ type: string
+
+partials:
+ type: dict
+ keyschema:
+ type: string
+ valueschema:
+ type: dict
+ schema:
+ desc:
+ type: string
+ args:
+ type: dict
+ keyschema:
+ type: string
+ valueschema:
+ anyof:
+ - type: [ boolean, number, string ]
+ - type: dict
+ schema:
+ default:
+ type: [ boolean, number, string ]
+ desc:
+ type: string
+ options:
+ type: list
+ schema:
+ type: string
+
+images:
+ keyschema:
+ type: string
+ valueschema:
+ type: dict
+ schema:
+ desc:
+ type: string
+ arg-defaults:
+ type: list
+ schema:
+ anyof:
+ - type: dict
+ keyschema:
+ type: string
+ arg_in_use: true
+ valueschema:
+ type: string
+ - type: string
+ isimage: true
+ create-dockerfile:
+ type: boolean
+ partials:
+ type: list
+ schema:
+ anyof:
+ - type: dict
+ keyschema:
+ type: string
+ regex: image
+ valueschema:
+ type: string
+ isimage: true
+ - type: string
+ ispartial: true
+"""
+
+
+class TfDockerValidator(cerberus.Validator):
+ """Custom Cerberus validator for TF dockerfile spec.
+
+ Note: Each _validate_foo function's docstring must end with a segment
+ describing its own validation schema, e.g. "The rule's arguments are...". If
+ you add a new validator, you can copy/paste that section.
+ """
+
+ def _validate_ispartial(self, ispartial, field, value):
+ """Validate that a partial references an existing partial spec.
+
+ Args:
+ ispartial: Value of the rule, a bool
+ field: The field being validated
+ value: The field's value
+
+ The rule's arguments are validated against this schema:
+ {'type': 'boolean'}
+ """
+ if ispartial and value not in self.root_document.get('partials', dict()):
+ self._error(field, '{} is not an existing partial.'.format(value))
+
+ def _validate_isimage(self, isimage, field, value):
+ """Validate that an image references an existing partial spec.
+
+ Args:
+ isimage: Value of the rule, a bool
+ field: The field being validated
+ value: The field's value
+
+ The rule's arguments are validated against this schema:
+ {'type': 'boolean'}
+ """
+ if isimage and value not in self.root_document.get('images', dict()):
+ self._error(field, '{} is not an existing image.'.format(value))
+
+ def _validate_arg_in_use(self, arg_in_use, field, value):
+ """Validate that an arg references an existing partial spec's args.
+
+ Args:
+ arg_in_use: Value of the rule, a bool
+ field: The field being validated
+ value: The field's value
+
+ The rule's arguments are validated against this schema:
+ {'type': 'boolean'}
+ """
+ if arg_in_use:
+ for partial in self.root_document.get('partials', dict()).values():
+ if value in partial.get('args', tuple()):
+ return
+
+ self._error(field, '{} is not an arg used in any partial.'.format(value))
+
+
+def build_partial_description(partial_spec):
+ """Create the documentation lines for a specific partial.
+
+ Generates something like this:
+
+ # This is the partial's description, from spec.yml.
+ # --build-arg ARG_NAME=argdefault
+ # this is one of the args.
+ # --build-arg ANOTHER_ARG=(some|choices)
+ # another arg.
+
+ Args:
+ partial_spec: A dict representing one of the partials from spec.yml. Doesn't
+      include the name of the partial; it is a dict like { desc: ..., args: ... }.
+
+ Returns:
+ A commented string describing this partial.
+ """
+
+ # Start from linewrapped desc field
+ lines = []
+ wrapper = textwrap.TextWrapper(
+ initial_indent='# ', subsequent_indent='# ', width=80)
+ description = wrapper.fill(partial_spec.get('desc', '( no comments )'))
+ lines.extend(['#', description])
+
+ # Document each arg
+ for arg, arg_data in partial_spec.get('args', dict()).items():
+ # Wrap arg description with comment lines
+ desc = arg_data.get('desc', '( no description )')
+ desc = textwrap.fill(
+ desc,
+ initial_indent='# ',
+ subsequent_indent='# ',
+ width=80,
+ drop_whitespace=False)
+
+ # Document (each|option|like|this)
+ if 'options' in arg_data:
+ arg_options = ' ({})'.format('|'.join(arg_data['options']))
+ else:
+ arg_options = ''
+
+ # Add usage sample
+ arg_use = '# --build-arg {}={}{}'.format(arg,
+ arg_data.get('default', '(unset)'),
+ arg_options)
+ lines.extend([arg_use, desc])
+
+ return '\n'.join(lines)
+
+
+def construct_contents(partial_specs, image_spec):
+ """Assemble the dockerfile contents for an image spec.
+
+ It assembles a concrete list of partial references into a single, large
+ string.
+ Also expands argument defaults, so that the resulting Dockerfile doesn't have
+ to be configured with --build-arg=... every time. That is, any ARG directive
+ will be updated with a new default value.
+
+ Args:
+ partial_specs: The dict from spec.yml["partials"].
+ image_spec: One of the dict values from spec.yml["images"].
+
+ Returns:
+ A string containing a valid Dockerfile based on the partials listed in
+ image_spec.
+ """
+ processed_partial_strings = []
+ for partial_name in image_spec['partials']:
+ # Apply image arg-defaults to existing arg defaults
+ partial_spec = copy.deepcopy(partial_specs[partial_name])
+ args = partial_spec.get('args', dict())
+ for k_v in image_spec.get('arg-defaults', []):
+ arg, value = list(k_v.items())[0]
+ if arg in args:
+ args[arg]['default'] = value
+
+ # Read partial file contents
+ filename = partial_spec.get('file', partial_name)
+ partial_path = os.path.join(FLAGS.partial_dir,
+ '{}.partial.Dockerfile'.format(filename))
+ with open(partial_path, 'r') as f_partial:
+ partial_contents = f_partial.read()
+
+ # Replace ARG FOO=BAR with ARG FOO=[new-default]
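+    # The pattern matches from "ARG <name>" to the end of that line, so any
+    # default already present in the partial is replaced wholesale.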
+ for arg, arg_data in args.items():
+ if 'default' in arg_data and arg_data['default']:
+ default = '={}'.format(arg_data['default'])
+ else:
+ default = ''
+ partial_contents = re.sub(r'ARG {}.*'.format(arg), 'ARG {}{}'.format(
+ arg, default), partial_contents)
+
+ # Store updated partial contents
+ processed_partial_strings.append(partial_contents)
+
+ # Join everything together
+ return '\n'.join(processed_partial_strings)
+
+
+def mkdir_p(path):
+ """Create a directory and its parents, even if it already exists."""
+ try:
+ os.makedirs(path)
+ except OSError as e:
+ if e.errno != errno.EEXIST:
+ raise
+
+
+def construct_documentation(header, partial_specs, image_spec):
+ """Assemble all of the documentation for a single dockerfile.
+
+ Builds explanations of included partials and available build args.
+
+ Args:
+ header: The string from spec.yml["header"]; will be commented and wrapped.
+ partial_specs: The dict from spec.yml["partials"].
+ image_spec: The spec for the dockerfile being built.
+
+ Returns:
+ A string containing a commented header that documents the contents of the
+ dockerfile.
+
+ """
+ # Comment and wrap header and image description
+ commented_header = '\n'.join(
+ [('# ' + l).rstrip() for l in header.splitlines()])
+ commented_desc = '\n'.join(
+ ['# ' + l for l in image_spec.get('desc', '').splitlines()])
+ partial_descriptions = []
+
+ # Build documentation for each partial in the image
+ for partial in image_spec['partials']:
+ # Copy partial data for default args unique to this image
+ partial_spec = copy.deepcopy(partial_specs[partial])
+ args = partial_spec.get('args', dict())
+
+ # Overwrite any existing arg defaults
+ for k_v in image_spec.get('arg-defaults', []):
+ arg, value = list(k_v.items())[0]
+ if arg in args:
+ args[arg]['default'] = value
+
+ # Build the description from new args
+ partial_description = build_partial_description(partial_spec)
+ partial_descriptions.append(partial_description)
+
+ contents = [commented_header, '#', commented_desc] + partial_descriptions
+ return '\n'.join(contents) + '\n'
+
+
+def normalize_partial_args(partial_specs):
+ """Normalize the shorthand form of a partial's args specification.
+
+ Turns this:
+
+ partial:
+ args:
+ SOME_ARG: arg_value
+
+ Into this:
+
+ partial:
+ args:
+ SOME_ARG:
+ default: arg_value
+
+ Args:
+ partial_specs: The dict from spec.yml["partials"]. This dict is modified in
+ place.
+
+ Returns:
+ The modified contents of partial_specs.
+
+ """
+ for _, partial in partial_specs.items():
+ args = partial.get('args', dict())
+ for arg, value in args.items():
+ if not isinstance(value, dict):
+ new_value = {'default': value}
+ args[arg] = new_value
+
+ return partial_specs
+
+
+def flatten_args_references(image_specs):
+ """Resolve all default-args in each image spec to a concrete dict.
+
+ Turns this:
+
+ example-image:
+ arg-defaults:
+ - MY_ARG: ARG_VALUE
+
+ another-example:
+ arg-defaults:
+ - ANOTHER_ARG: ANOTHER_VALUE
+ - example_image
+
+ Into this:
+
+ example-image:
+ arg-defaults:
+ - MY_ARG: ARG_VALUE
+
+ another-example:
+ arg-defaults:
+ - ANOTHER_ARG: ANOTHER_VALUE
+ - MY_ARG: ARG_VALUE
+
+ Args:
+ image_specs: A dict of image_spec dicts; should be the contents of the
+ "images" key in the global spec.yaml. This dict is modified in place and
+ then returned.
+
+ Returns:
+ The modified contents of image_specs.
+ """
+ for _, image_spec in image_specs.items():
+ too_deep = 0
+ while str in map(type, image_spec.get('arg-defaults', [])) and too_deep < 5:
+ new_args = []
+ for arg in image_spec['arg-defaults']:
+ if isinstance(arg, str):
+ new_args.extend(image_specs[arg]['arg-defaults'])
+ else:
+ new_args.append(arg)
+
+ image_spec['arg-defaults'] = new_args
+ too_deep += 1
+
+ return image_specs
+
+
+def flatten_partial_references(image_specs):
+ """Resolve all partial references in each image spec to a concrete list.
+
+ Turns this:
+
+ example-image:
+ partials:
+ - foo
+
+ another-example:
+ partials:
+ - bar
+ - image: example-image
+ - bat
+
+ Into this:
+
+ example-image:
+ partials:
+ - foo
+
+ another-example:
+ partials:
+ - bar
+ - foo
+      - bat
+
+  Args:
+ image_specs: A dict of image_spec dicts; should be the contents of the
+ "images" key in the global spec.yaml. This dict is modified in place and
+ then returned.
+
+ Returns:
+ The modified contents of image_specs.
+ """
+ for _, image_spec in image_specs.items():
+ too_deep = 0
+ while dict in map(type, image_spec['partials']) and too_deep < 5:
+ new_partials = []
+ for partial in image_spec['partials']:
+ if isinstance(partial, str):
+ new_partials.append(partial)
+ else:
+ new_partials.extend(image_specs[partial['image']]['partials'])
+
+ image_spec['partials'] = new_partials
+ too_deep += 1
+
+ return image_specs
+
+
+def construct_dockerfiles(tf_spec):
+ """Generate a mapping of {"cpu": <cpu dockerfile contents>, ...}.
+
+ Args:
+ tf_spec: The full spec.yml loaded as a python object.
+
+ Returns:
+ A string:string dict of short names ("cpu-devel") to Dockerfile contents.
+ """
+ names_to_contents = dict()
+ image_specs = tf_spec['images']
+ image_specs = flatten_partial_references(image_specs)
+ image_specs = flatten_args_references(image_specs)
+ partial_specs = tf_spec['partials']
+ partial_specs = normalize_partial_args(partial_specs)
+
+ for name, image_spec in image_specs.items():
+ if not image_spec.get('create-dockerfile', True):
+ continue
+ documentation = construct_documentation(tf_spec['header'], partial_specs,
+ image_spec)
+ contents = construct_contents(partial_specs, image_spec)
+ names_to_contents[name] = '\n'.join([documentation, contents])
+
+ return names_to_contents
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise app.UsageError('Unexpected command line args found: {}'.format(argv))
+
+ with open(FLAGS.spec_file, 'r') as spec_file:
+ tf_spec = yaml.load(spec_file)
+
+ # Abort if spec.yaml is invalid
+ if FLAGS.validate:
+ schema = yaml.load(SCHEMA_TEXT)
+ v = TfDockerValidator(schema)
+ if not v.validate(tf_spec):
+ print('>> ERROR: {} is an invalid spec! The errors are:'.format(
+ FLAGS.spec_file))
+ print(yaml.dump(v.errors, indent=2))
+ exit(1)
+ else:
+ print('>> WARNING: Not validating {}'.format(FLAGS.spec_file))
+
+ # Generate mapping of { "cpu-devel": "<cpu-devel dockerfile contents>", ... }
+ names_to_contents = construct_dockerfiles(tf_spec)
+
+ # Write each completed Dockerfile
+ if not FLAGS.dry_run:
+ print('>> Emptying destination dir "{}"'.format(FLAGS.output_dir))
+ shutil.rmtree(FLAGS.output_dir, ignore_errors=True)
+ mkdir_p(FLAGS.output_dir)
+ else:
+ print('>> Skipping creation of {} (dry run)'.format(FLAGS.output_dir))
+ for name, contents in names_to_contents.items():
+ path = os.path.join(FLAGS.output_dir, name + '.Dockerfile')
+    if FLAGS.dry_run:
+      print('>> Skipping writing contents of {} (dry run)'.format(path))
+      if not FLAGS.quiet_dry_run:
+        print(contents)
+ else:
+ mkdir_p(FLAGS.output_dir)
+ print('>> Writing {}'.format(path))
+ with open(path, 'w') as f:
+ f.write(contents)
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/tensorflow/tools/dockerfiles/bashrc b/tensorflow/tools/dockerfiles/bashrc
new file mode 100644
index 0000000000..48cacf20f6
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/bashrc
@@ -0,0 +1,50 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+export PS1="\[\e[31m\]tf-docker\[\e[m\] \[\e[33m\]\w\[\e[m\] > "
+export TERM=xterm-256color
+alias grep="grep --color=auto"
+alias ls="ls --color=auto"
+
+echo -e "\e[1;31m"
+cat<<TF
+________ _______________
+___ __/__________________________________ ____/__ /________ __
+__ / _ _ \_ __ \_ ___/ __ \_ ___/_ /_ __ /_ __ \_ | /| / /
+_ / / __/ / / /(__ )/ /_/ / / _ __/ _ / / /_/ /_ |/ |/ /
+/_/ \___//_/ /_//____/ \____//_/ /_/ /_/ \____/____/|__/
+
+TF
+echo -e "\e[0;33m"
+
+if [[ $EUID -eq 0 ]]; then
+ cat <<WARN
+WARNING: You are running this container as root, which can cause new files in
+mounted volumes to be created as the root user on your host machine.
+
+To avoid this, run the container by specifying your user's userid:
+
+$ docker run -u \$(id -u):\$(id -g) args...
+WARN
+else
+ cat <<EXPL
+You are running this container as a user with ID $(id -u) and group $(id -g),
+which should map to the ID and group for your user on the Docker host. Great!
+EXPL
+fi
+
+# Turn off colors
+echo -e "\e[m"
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/cpu-devel-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/cpu-devel-jupyter.Dockerfile
new file mode 100644
index 0000000000..dbbad7d03a
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/cpu-devel-jupyter.Dockerfile
@@ -0,0 +1,100 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# below. Please refer to the TensorFlow dockerfiles documentation for
+# more information. Build args are documented as their default value.
+#
+# Ubuntu-based, CPU-only environment for developing changes for TensorFlow, with Jupyter included.
+#
+# Start from Ubuntu, with TF development packages (no GPU support)
+# --build-arg UBUNTU_VERSION=16.04
+# ( no description )
+#
+# Python is required for TensorFlow and other libraries.
+# --build-arg USE_PYTHON_3_NOT_2=True
+# Install python 3 over Python 2
+#
+# Install the latest version of Bazel and Python development tools.
+#
+# Configure TensorFlow's shell prompt and login tools.
+#
+# Launch Jupyter on execution instead of a bash prompt.
+
+ARG UBUNTU_VERSION=16.04
+FROM ubuntu:${UBUNTU_VERSION}
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ curl \
+ git \
+ libcurl3-dev \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ python-dev \
+ rsync \
+ software-properties-common \
+ unzip \
+ zip \
+ zlib1g-dev \
+ openjdk-8-jdk \
+ openjdk-8-jre-headless \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+ARG USE_PYTHON_3_NOT_2=True
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} install --upgrade \
+ pip \
+ setuptools
+
+RUN apt-get update && apt-get install -y \
+ build-essential \
+ curl \
+ git \
+ openjdk-8-jdk \
+ ${PYTHON}-dev \
+ swig
+
+# Install bazel
+RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \
+ curl https://bazel.build/bazel-release.pub.gpg | apt-key add - && \
+ apt-get update && \
+ apt-get install -y bazel
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
+
+RUN ${PIP} install jupyter
+
+RUN mkdir /notebooks && chmod a+rwx /notebooks
+RUN mkdir /.local && chmod a+rwx /.local
+WORKDIR /notebooks
+EXPOSE 8888
+
+CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/notebooks --ip 0.0.0.0 --no-browser --allow-root"]
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/cpu-devel.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/cpu-devel.Dockerfile
new file mode 100644
index 0000000000..160d7c02e2
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/cpu-devel.Dockerfile
@@ -0,0 +1,89 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# below. Please refer to the TensorFlow dockerfiles documentation for
+# more information. Build args are documented as their default value.
+#
+# Ubuntu-based, CPU-only environment for developing changes for TensorFlow.
+#
+# Start from Ubuntu, with TF development packages (no GPU support)
+# --build-arg UBUNTU_VERSION=16.04
+# ( no description )
+#
+# Python is required for TensorFlow and other libraries.
+# --build-arg USE_PYTHON_3_NOT_2=True
+# Install python 3 over Python 2
+#
+# Install the latest version of Bazel and Python development tools.
+#
+# Configure TensorFlow's shell prompt and login tools.
+
+ARG UBUNTU_VERSION=16.04
+FROM ubuntu:${UBUNTU_VERSION}
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ curl \
+ git \
+ libcurl3-dev \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ python-dev \
+ rsync \
+ software-properties-common \
+ unzip \
+ zip \
+ zlib1g-dev \
+ openjdk-8-jdk \
+ openjdk-8-jre-headless \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+ARG USE_PYTHON_3_NOT_2=True
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} install --upgrade \
+ pip \
+ setuptools
+
+RUN apt-get update && apt-get install -y \
+ build-essential \
+ curl \
+ git \
+ openjdk-8-jdk \
+ ${PYTHON}-dev \
+ swig
+
+# Install bazel
+RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \
+ curl https://bazel.build/bazel-release.pub.gpg | apt-key add - && \
+ apt-get update && \
+ apt-get install -y bazel
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile
new file mode 100644
index 0000000000..8d5d653ab7
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile
@@ -0,0 +1,69 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# below. Please refer to the TensorFlow dockerfiles documentation for
+# more information. Build args are documented as their default value.
+#
+# Ubuntu-based, CPU-only environment for using TensorFlow, with Jupyter included.
+#
+# Start from Ubuntu (no GPU support)
+# --build-arg UBUNTU_VERSION=16.04
+# ( no description )
+#
+# Python is required for TensorFlow and other libraries.
+# --build-arg USE_PYTHON_3_NOT_2=True
+# Install python 3 over Python 2
+#
+# Install the TensorFlow Python package.
+# --build-arg TF_PACKAGE=tensorflow (tensorflow|tensorflow-gpu|tf-nightly|tf-nightly-gpu)
+# The specific TensorFlow Python package to install
+#
+# Configure TensorFlow's shell prompt and login tools.
+#
+# Launch Jupyter on execution instead of a bash prompt.
+
+ARG UBUNTU_VERSION=16.04
+FROM ubuntu:${UBUNTU_VERSION}
+
+ARG USE_PYTHON_3_NOT_2=True
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} install --upgrade \
+ pip \
+ setuptools
+
+ARG TF_PACKAGE=tensorflow
+RUN ${PIP} install ${TF_PACKAGE}
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
+
+RUN ${PIP} install jupyter
+
+RUN mkdir /notebooks && chmod a+rwx /notebooks
+RUN mkdir /.local && chmod a+rwx /.local
+WORKDIR /notebooks
+EXPOSE 8888
+
+CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/notebooks --ip 0.0.0.0 --no-browser --allow-root"]
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile
new file mode 100644
index 0000000000..35c41b49fd
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile
@@ -0,0 +1,58 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# below. Please refer to the TensorFlow dockerfiles documentation for
+# more information. Build args are documented as their default value.
+#
+# Ubuntu-based, CPU-only environment for using TensorFlow
+#
+# Start from Ubuntu (no GPU support)
+# --build-arg UBUNTU_VERSION=16.04
+# ( no description )
+#
+# Python is required for TensorFlow and other libraries.
+# --build-arg USE_PYTHON_3_NOT_2=True
+# Install python 3 over Python 2
+#
+# Install the TensorFlow Python package.
+# --build-arg TF_PACKAGE=tensorflow (tensorflow|tensorflow-gpu|tf-nightly|tf-nightly-gpu)
+# The specific TensorFlow Python package to install
+#
+# Configure TensorFlow's shell prompt and login tools.
+
+ARG UBUNTU_VERSION=16.04
+FROM ubuntu:${UBUNTU_VERSION}
+
+ARG USE_PYTHON_3_NOT_2=True
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} install --upgrade \
+ pip \
+ setuptools
+
+ARG TF_PACKAGE=tensorflow
+RUN ${PIP} install ${TF_PACKAGE}
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel-jupyter.Dockerfile
new file mode 100644
index 0000000000..68c0e2f2bd
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel-jupyter.Dockerfile
@@ -0,0 +1,126 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# below. Please refer to the TensorFlow dockerfiles documentation for
+# more information. Build args are documented as their default value.
+#
+# Ubuntu-based, Nvidia-GPU-enabled environment for developing changes for TensorFlow, with Jupyter included.
+#
+# Start from Nvidia's Ubuntu base image with CUDA and CuDNN, with TF development
+# packages.
+# --build-arg UBUNTU_VERSION=16.04
+# ( no description )
+#
+# Python is required for TensorFlow and other libraries.
+# --build-arg USE_PYTHON_3_NOT_2=True
+# Install python 3 over Python 2
+#
+# Install the latest version of Bazel and Python development tools.
+#
+# Configure TensorFlow's shell prompt and login tools.
+#
+# Launch Jupyter on execution instead of a bash prompt.
+
+ARG UBUNTU_VERSION=16.04
+FROM nvidia/cuda:9.0-base-ubuntu${UBUNTU_VERSION}
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ cuda-command-line-tools-9-0 \
+ cuda-cublas-dev-9-0 \
+ cuda-cudart-dev-9-0 \
+ cuda-cufft-dev-9-0 \
+ cuda-curand-dev-9-0 \
+ cuda-cusolver-dev-9-0 \
+ cuda-cusparse-dev-9-0 \
+ curl \
+ git \
+ libcudnn7=7.2.1.38-1+cuda9.0 \
+ libcudnn7-dev=7.2.1.38-1+cuda9.0 \
+ libnccl2=2.2.13-1+cuda9.0 \
+ libnccl-dev=2.2.13-1+cuda9.0 \
+ libcurl3-dev \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ rsync \
+ software-properties-common \
+ unzip \
+ zip \
+ zlib1g-dev \
+ wget \
+ && \
+ rm -rf /var/lib/apt/lists/* && \
+ find /usr/local/cuda-9.0/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
+ rm /usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a
+
+RUN apt-get update && \
+        apt-get install -y nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0 && \
+        apt-get update && \
+        apt-get install -y libnvinfer4=4.1.2-1+cuda9.0 && \
+        apt-get install -y libnvinfer-dev=4.1.2-1+cuda9.0
+
+# Link NCCL library and header where the build script expects them.
+RUN mkdir /usr/local/cuda-9.0/lib && \
+ ln -s /usr/lib/x86_64-linux-gnu/libnccl.so.2 /usr/local/cuda/lib/libnccl.so.2 && \
+ ln -s /usr/include/nccl.h /usr/local/cuda/include/nccl.h
+
+# TODO(tobyboyd): Remove after license is excluded from BUILD file.
+RUN gunzip /usr/share/doc/libnccl2/NCCL-SLA.txt.gz && \
+ cp /usr/share/doc/libnccl2/NCCL-SLA.txt /usr/local/cuda/
+
+ARG USE_PYTHON_3_NOT_2=True
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} install --upgrade \
+ pip \
+ setuptools
+
+RUN apt-get update && apt-get install -y \
+ build-essential \
+ curl \
+ git \
+ openjdk-8-jdk \
+ ${PYTHON}-dev \
+ swig
+
+# Install bazel
+RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \
+ curl https://bazel.build/bazel-release.pub.gpg | apt-key add - && \
+ apt-get update && \
+ apt-get install -y bazel
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
+
+RUN ${PIP} install jupyter
+
+RUN mkdir /notebooks && chmod a+rwx /notebooks
+RUN mkdir /.local && chmod a+rwx /.local
+WORKDIR /notebooks
+EXPOSE 8888
+
+CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/notebooks --ip 0.0.0.0 --no-browser --allow-root"]
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel.Dockerfile
new file mode 100644
index 0000000000..77be0dd287
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-devel.Dockerfile
@@ -0,0 +1,115 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# below. Please refer to the TensorFlow dockerfiles documentation for
+# more information. Build args are documented as their default value.
+#
+# Ubuntu-based, Nvidia-GPU-enabled environment for developing changes for TensorFlow.
+#
+# Start from Nvidia's Ubuntu base image with CUDA and CuDNN, with TF development
+# packages.
+# --build-arg UBUNTU_VERSION=16.04
+# ( no description )
+#
+# Python is required for TensorFlow and other libraries.
+# --build-arg USE_PYTHON_3_NOT_2=True
+# Install python 3 over Python 2
+#
+# Install the latest version of Bazel and Python development tools.
+#
+# Configure TensorFlow's shell prompt and login tools.
+
+ARG UBUNTU_VERSION=16.04
+FROM nvidia/cuda:9.0-base-ubuntu${UBUNTU_VERSION}
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ cuda-command-line-tools-9-0 \
+ cuda-cublas-dev-9-0 \
+ cuda-cudart-dev-9-0 \
+ cuda-cufft-dev-9-0 \
+ cuda-curand-dev-9-0 \
+ cuda-cusolver-dev-9-0 \
+ cuda-cusparse-dev-9-0 \
+ curl \
+ git \
+ libcudnn7=7.2.1.38-1+cuda9.0 \
+ libcudnn7-dev=7.2.1.38-1+cuda9.0 \
+ libnccl2=2.2.13-1+cuda9.0 \
+ libnccl-dev=2.2.13-1+cuda9.0 \
+ libcurl3-dev \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ rsync \
+ software-properties-common \
+ unzip \
+ zip \
+ zlib1g-dev \
+ wget \
+ && \
+ rm -rf /var/lib/apt/lists/* && \
+ find /usr/local/cuda-9.0/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
+ rm /usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a
+
+RUN apt-get update && \
+        apt-get install -y nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0 && \
+        apt-get update && \
+        apt-get install -y libnvinfer4=4.1.2-1+cuda9.0 && \
+        apt-get install -y libnvinfer-dev=4.1.2-1+cuda9.0
+
+# Link NCCL library and header where the build script expects them.
+RUN mkdir /usr/local/cuda-9.0/lib && \
+ ln -s /usr/lib/x86_64-linux-gnu/libnccl.so.2 /usr/local/cuda/lib/libnccl.so.2 && \
+ ln -s /usr/include/nccl.h /usr/local/cuda/include/nccl.h
+
+# TODO(tobyboyd): Remove after license is excluded from BUILD file.
+RUN gunzip /usr/share/doc/libnccl2/NCCL-SLA.txt.gz && \
+ cp /usr/share/doc/libnccl2/NCCL-SLA.txt /usr/local/cuda/
+
+ARG USE_PYTHON_3_NOT_2=True
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} install --upgrade \
+ pip \
+ setuptools
+
+RUN apt-get update && apt-get install -y \
+ build-essential \
+ curl \
+ git \
+ openjdk-8-jdk \
+ ${PYTHON}-dev \
+ swig
+
+# Install bazel
+RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \
+ curl https://bazel.build/bazel-release.pub.gpg | apt-key add - && \
+ apt-get update && \
+ apt-get install -y bazel
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/nvidia-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-jupyter.Dockerfile
new file mode 100644
index 0000000000..5ff1fa917a
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/nvidia-jupyter.Dockerfile
@@ -0,0 +1,95 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# below. Please refer to the TensorFlow dockerfiles documentation for
+# more information. Build args are documented as their default value.
+#
+# Ubuntu-based, Nvidia-GPU-enabled environment for using TensorFlow, with Jupyter included.
+#
+# NVIDIA with CUDA and CuDNN, no dev stuff
+# --build-arg UBUNTU_VERSION=16.04
+# ( no description )
+#
+# Python is required for TensorFlow and other libraries.
+# --build-arg USE_PYTHON_3_NOT_2=True
+# Install python 3 over Python 2
+#
+# Install the TensorFlow Python package.
+# --build-arg TF_PACKAGE=tensorflow-gpu (tensorflow|tensorflow-gpu|tf-nightly|tf-nightly-gpu)
+# The specific TensorFlow Python package to install
+#
+# Configure TensorFlow's shell prompt and login tools.
+#
+# Launch Jupyter on execution instead of a bash prompt.
+
+FROM nvidia/cuda:9.0-base-ubuntu16.04
+
+# Pick up some TF dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ cuda-command-line-tools-9-0 \
+ cuda-cublas-9-0 \
+ cuda-cufft-9-0 \
+ cuda-curand-9-0 \
+ cuda-cusolver-9-0 \
+ cuda-cusparse-9-0 \
+ libcudnn7=7.2.1.38-1+cuda9.0 \
+ libnccl2=2.2.13-1+cuda9.0 \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ software-properties-common \
+ unzip \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+RUN apt-get update && \
+        apt-get install -y nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0 && \
+        apt-get update && \
+        apt-get install -y libnvinfer4=4.1.2-1+cuda9.0
+
+ARG USE_PYTHON_3_NOT_2=True
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} install --upgrade \
+ pip \
+ setuptools
+
+ARG TF_PACKAGE=tensorflow-gpu
+RUN ${PIP} install ${TF_PACKAGE}
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
+
+RUN ${PIP} install jupyter
+
+RUN mkdir /notebooks && chmod a+rwx /notebooks
+RUN mkdir /.local && chmod a+rwx /.local
+WORKDIR /notebooks
+EXPOSE 8888
+
+CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/notebooks --ip 0.0.0.0 --no-browser --allow-root"]
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/nvidia.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/nvidia.Dockerfile
new file mode 100644
index 0000000000..3df810b5fe
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/nvidia.Dockerfile
@@ -0,0 +1,84 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# below. Please refer to the TensorFlow dockerfiles documentation for
+# more information. Build args are documented as their default value.
+#
+# Ubuntu-based, Nvidia-GPU-enabled environment for using TensorFlow.
+#
+# NVIDIA with CUDA and CuDNN, no dev stuff
+# --build-arg UBUNTU_VERSION=16.04
+# ( no description )
+#
+# Python is required for TensorFlow and other libraries.
+# --build-arg USE_PYTHON_3_NOT_2=True
+# Install python 3 over Python 2
+#
+# Install the TensorFlow Python package.
+# --build-arg TF_PACKAGE=tensorflow-gpu (tensorflow|tensorflow-gpu|tf-nightly|tf-nightly-gpu)
+# The specific TensorFlow Python package to install
+#
+# Configure TensorFlow's shell prompt and login tools.
+
+FROM nvidia/cuda:9.0-base-ubuntu16.04
+
+# Pick up some TF dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ cuda-command-line-tools-9-0 \
+ cuda-cublas-9-0 \
+ cuda-cufft-9-0 \
+ cuda-curand-9-0 \
+ cuda-cusolver-9-0 \
+ cuda-cusparse-9-0 \
+ libcudnn7=7.2.1.38-1+cuda9.0 \
+ libnccl2=2.2.13-1+cuda9.0 \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ software-properties-common \
+ unzip \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+RUN apt-get update && \
+        apt-get install -y nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0 && \
+        apt-get update && \
+        apt-get install -y libnvinfer4=4.1.2-1+cuda9.0
+
+ARG USE_PYTHON_3_NOT_2=True
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} install --upgrade \
+ pip \
+ setuptools
+
+ARG TF_PACKAGE=tensorflow-gpu
+RUN ${PIP} install ${TF_PACKAGE}
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
diff --git a/tensorflow/tools/dockerfiles/partials/bazel.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/bazel.partial.Dockerfile
new file mode 100644
index 0000000000..b08d8bdd14
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/partials/bazel.partial.Dockerfile
@@ -0,0 +1,13 @@
+RUN apt-get update && apt-get install -y \
+ build-essential \
+ curl \
+ git \
+ openjdk-8-jdk \
+ ${PYTHON}-dev \
+ swig
+
+# Install bazel
+RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \
+ curl https://bazel.build/bazel-release.pub.gpg | apt-key add - && \
+ apt-get update && \
+ apt-get install -y bazel
diff --git a/tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile
new file mode 100644
index 0000000000..2c9b9f3f9a
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile
@@ -0,0 +1,8 @@
+RUN ${PIP} install jupyter
+
+RUN mkdir /notebooks && chmod a+rwx /notebooks
+RUN mkdir /.local && chmod a+rwx /.local
+WORKDIR /notebooks
+EXPOSE 8888
+
+CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/notebooks --ip 0.0.0.0 --no-browser --allow-root"]
diff --git a/tensorflow/tools/dockerfiles/partials/nvidia-devel.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/nvidia-devel.partial.Dockerfile
new file mode 100644
index 0000000000..45159f711f
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/partials/nvidia-devel.partial.Dockerfile
@@ -0,0 +1,49 @@
+ARG UBUNTU_VERSION=16.04
+FROM nvidia/cuda:9.0-base-ubuntu${UBUNTU_VERSION}
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ cuda-command-line-tools-9-0 \
+ cuda-cublas-dev-9-0 \
+ cuda-cudart-dev-9-0 \
+ cuda-cufft-dev-9-0 \
+ cuda-curand-dev-9-0 \
+ cuda-cusolver-dev-9-0 \
+ cuda-cusparse-dev-9-0 \
+ curl \
+ git \
+ libcudnn7=7.2.1.38-1+cuda9.0 \
+ libcudnn7-dev=7.2.1.38-1+cuda9.0 \
+ libnccl2=2.2.13-1+cuda9.0 \
+ libnccl-dev=2.2.13-1+cuda9.0 \
+ libcurl3-dev \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ rsync \
+ software-properties-common \
+ unzip \
+ zip \
+ zlib1g-dev \
+ wget \
+ && \
+ rm -rf /var/lib/apt/lists/* && \
+ find /usr/local/cuda-9.0/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
+ rm /usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a
+
+RUN apt-get update && \
+        apt-get install -y nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0 && \
+        apt-get update && \
+        apt-get install -y libnvinfer4=4.1.2-1+cuda9.0 && \
+        apt-get install -y libnvinfer-dev=4.1.2-1+cuda9.0
+
+# Link NCCL library and header where the build script expects them.
+RUN mkdir /usr/local/cuda-9.0/lib && \
+ ln -s /usr/lib/x86_64-linux-gnu/libnccl.so.2 /usr/local/cuda/lib/libnccl.so.2 && \
+ ln -s /usr/include/nccl.h /usr/local/cuda/include/nccl.h
+
+# TODO(tobyboyd): Remove after license is excluded from BUILD file.
+RUN gunzip /usr/share/doc/libnccl2/NCCL-SLA.txt.gz && \
+ cp /usr/share/doc/libnccl2/NCCL-SLA.txt /usr/local/cuda/
diff --git a/tensorflow/tools/dockerfiles/partials/nvidia.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/nvidia.partial.Dockerfile
new file mode 100644
index 0000000000..1064390af3
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/partials/nvidia.partial.Dockerfile
@@ -0,0 +1,28 @@
+FROM nvidia/cuda:9.0-base-ubuntu16.04
+
+# Pick up some TF dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ cuda-command-line-tools-9-0 \
+ cuda-cublas-9-0 \
+ cuda-cufft-9-0 \
+ cuda-curand-9-0 \
+ cuda-cusolver-9-0 \
+ cuda-cusparse-9-0 \
+ libcudnn7=7.2.1.38-1+cuda9.0 \
+ libnccl2=2.2.13-1+cuda9.0 \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ software-properties-common \
+ unzip \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+RUN apt-get update && \
+        apt-get install -y nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0 && \
+        apt-get update && \
+        apt-get install -y libnvinfer4=4.1.2-1+cuda9.0
diff --git a/tensorflow/tools/dockerfiles/partials/python.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/python.partial.Dockerfile
new file mode 100644
index 0000000000..6f346236a5
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/partials/python.partial.Dockerfile
@@ -0,0 +1,12 @@
+ARG USE_PYTHON_3_NOT_2
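+# Note: ${USE_PYTHON_3_NOT_2:+3} expands to "3" only when the arg is set and
+# non-empty, so PYTHON/PIP resolve to python3/pip3 rather than python/pip.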
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} install --upgrade \
+ pip \
+ setuptools
diff --git a/tensorflow/tools/dockerfiles/partials/shell.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/shell.partial.Dockerfile
new file mode 100644
index 0000000000..d641a11b06
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/partials/shell.partial.Dockerfile
@@ -0,0 +1,2 @@
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
diff --git a/tensorflow/tools/dockerfiles/partials/tensorflow.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/tensorflow.partial.Dockerfile
new file mode 100644
index 0000000000..96e79547f0
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/partials/tensorflow.partial.Dockerfile
@@ -0,0 +1,2 @@
+ARG TF_PACKAGE
+RUN ${PIP} install ${TF_PACKAGE}
diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu-devel.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu-devel.partial.Dockerfile
new file mode 100644
index 0000000000..bc79272276
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/partials/ubuntu-devel.partial.Dockerfile
@@ -0,0 +1,24 @@
+ARG UBUNTU_VERSION=16.04
+FROM ubuntu:${UBUNTU_VERSION}
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ curl \
+ git \
+ libcurl3-dev \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ python-dev \
+ rsync \
+ software-properties-common \
+ unzip \
+ zip \
+ zlib1g-dev \
+ openjdk-8-jdk \
+ openjdk-8-jre-headless \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu.partial.Dockerfile
new file mode 100644
index 0000000000..0a50735bf8
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/partials/ubuntu.partial.Dockerfile
@@ -0,0 +1,2 @@
+ARG UBUNTU_VERSION=16.04
+FROM ubuntu:${UBUNTU_VERSION}
diff --git a/tensorflow/tools/dockerfiles/spec.yml b/tensorflow/tools/dockerfiles/spec.yml
new file mode 100644
index 0000000000..28bf9a55da
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/spec.yml
@@ -0,0 +1,195 @@
+# ======
+# HEADER
+# ======
+#
+# This is commented out and prepended to each generated Dockerfile.
+header: |
+ Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ============================================================================
+
+ THIS IS A GENERATED DOCKERFILE.
+
+ This file was assembled from multiple pieces, whose use is documented
+ below. Please refer to the TensorFlow dockerfiles documentation for
+ more information. Build args are documented as their default value.
+
+# ========
+# PARTIALS
+# ========
+#
+# Represent and document pieces of a Dockerfile. Spec:
+#
+# name: the name of the partial; referenced from the images section
+# desc: A description, inserted later into the Dockerfile
+# file: Alternative file prefix, e.g. file.partial.Dockerfile. The default is
+# the name of the partial.
+# args: A dict of ARGs in the Dockerfile; each entry has the format
+# ARG_NAME: VALUE where VALUE is one of:
+# - a dict:
+# desc: Documentation for the arg
+# default: Default value for the arg; is written to the Dockerfile
+# options: List of strings, part of documentation
+# - a concrete value: the same as a dictionary with default: [value].
+
+partials:
+ ubuntu:
+ desc: Start from Ubuntu (no GPU support)
+ args:
+ UBUNTU_VERSION: 16.04
+
+ ubuntu-devel:
+ desc: Start from Ubuntu, with TF development packages (no GPU support)
+ args:
+ UBUNTU_VERSION: 16.04
+
+ bazel:
+ desc: Install the latest version of Bazel and Python development tools.
+
+ nvidia:
+ desc: Start from Nvidia's Ubuntu base image with CUDA and CuDNN (no development packages)
+ args:
+ UBUNTU_VERSION: 16.04
+
+ nvidia-devel:
+ desc: >
+ Start from Nvidia's Ubuntu base image with CUDA and CuDNN, with TF
+ development packages.
+ args:
+ UBUNTU_VERSION: 16.04
+
+ python:
+ desc: Python is required for TensorFlow and other libraries.
+ args:
+ USE_PYTHON_3_NOT_2:
+ default: true
+ desc: Install Python 3 instead of Python 2
+
+ tensorflow:
+ desc: Install the TensorFlow Python package.
+ args:
+ TF_PACKAGE:
+ default: tensorflow
+ options:
+ - tensorflow
+ - tensorflow-gpu
+ - tf-nightly
+ - tf-nightly-gpu
+ desc: The specific TensorFlow Python package to install
+ shell:
+ desc: Configure TensorFlow's shell prompt and login tools.
+ jupyter:
+ desc: Launch Jupyter on execution instead of a bash prompt.
+
+# ======
+# IMAGES
+# ======
+#
+# Represent Dockerfiles. Spec:
+#
+# name: the name of the image, possibly referenced by other images
+# desc: A description, inserted later into the Dockerfile
+# create-dockerfile: Create a dockerfile based on this. Useful for creating
+# extensible base images that don't need a file. Default is true.
+# partials: List of VALUEs, where a VALUE is either:
+# - the name of a partial, which inserts that partial into this image
+# - image: [name of another image], which inserts the partials from that
+# image into this image
+# arg-defaults: List of VALUEs, where a VALUE is either:
+# - ARG_NAME: VALUE, which sets the ARG_NAME to VALUE wherever it appears
+# in this image's partials
+# - [name of another image], which loads the default args from that image
+images:
+
+ nodev:
+ create-dockerfile: false
+ partials:
+ - python
+ - tensorflow
+ - shell
+
+ dev:
+ create-dockerfile: false
+ partials:
+ - python
+ - bazel
+ - shell
+
+ cpu:
+ desc: Ubuntu-based, CPU-only environment for using TensorFlow
+ partials:
+ - ubuntu
+ - image: nodev
+
+ cpu-devel:
+ desc: >
+ Ubuntu-based, CPU-only environment for developing changes for
+ TensorFlow.
+ partials:
+ - ubuntu-devel
+ - image: dev
+
+ nvidia:
+ desc: Ubuntu-based, Nvidia-GPU-enabled environment for using TensorFlow.
+ arg-defaults:
+ - TF_PACKAGE: tensorflow-gpu
+ partials:
+ - nvidia
+ - image: nodev
+
+ nvidia-devel:
+ desc: >
+ Ubuntu-based, Nvidia-GPU-enabled environment for developing changes
+ for TensorFlow.
+ arg-defaults:
+ - TF_PACKAGE: tensorflow-gpu
+ partials:
+ - nvidia-devel
+ - image: dev
+
+ cpu-jupyter:
+ desc: >
+ Ubuntu-based, CPU-only environment for using TensorFlow, with Jupyter
+ included.
+ partials:
+ - image: cpu
+ - jupyter
+
+ cpu-devel-jupyter:
+ desc: >
+ Ubuntu-based, CPU-only environment for developing changes for
+ TensorFlow, with Jupyter included.
+ partials:
+ - image: cpu-devel
+ - jupyter
+
+ nvidia-jupyter:
+ desc: >
+ Ubuntu-based, Nvidia-GPU-enabled environment for using TensorFlow, with
+ Jupyter included.
+ arg-defaults:
+ - nvidia
+ partials:
+ - image: nvidia
+ - jupyter
+
+ nvidia-devel-jupyter:
+ desc: >
+ Ubuntu-based, Nvidia-GPU-enabled environment for developing changes for
+ TensorFlow, with Jupyter included.
+ arg-defaults:
+ - nvidia-devel
+ partials:
+ - image: nvidia-devel
+ - jupyter
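For context, the spec above is presumably consumed by a generator script that stitches the referenced partials into concrete Dockerfiles. A minimal sketch of that expansion under stated assumptions (PyYAML available; `load_spec`, `resolve_partials`, and `expand_image` are illustrative names, not the generator's actual API; `file:` overrides and arg-defaults handling are omitted for brevity):

    # Hypothetical sketch of expanding spec.yml into Dockerfiles.
    import os
    import yaml

    def load_spec(path='spec.yml'):
        with open(path) as f:
            return yaml.safe_load(f)

    def resolve_partials(image, images):
        """Flatten an image's partial list, inlining `image:` references."""
        names = []
        for entry in image.get('partials', []):
            if isinstance(entry, dict) and 'image' in entry:
                names.extend(resolve_partials(images[entry['image']], images))
            else:
                names.append(entry)
        return names

    def expand_image(name, spec, partials_dir='partials'):
        image = spec['images'][name]
        # The header is commented out and prepended, as documented above.
        parts = ['# ' + line for line in spec['header'].splitlines()]
        for partial in resolve_partials(image, spec['images']):
            path = os.path.join(partials_dir, partial + '.partial.Dockerfile')
            with open(path) as f:
                parts.append(f.read())
        return '\n'.join(parts)

    spec = load_spec()
    for name, image in spec['images'].items():
        if image.get('create-dockerfile', True):
            with open(name + '.Dockerfile', 'w') as f:
                f.write(expand_image(name, spec))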
diff --git a/tensorflow/tools/docs/doc_controls_test.py b/tensorflow/tools/docs/doc_controls_test.py
index 410342fb69..d5eb4ffc00 100644
--- a/tensorflow/tools/docs/doc_controls_test.py
+++ b/tensorflow/tools/docs/doc_controls_test.py
@@ -145,7 +145,7 @@ class DocControlsTest(googletest.TestCase):
self.assertTrue(
doc_controls.should_skip_class_attr(GrandChild, 'my_method'))
- def testfor_subclass_implementers(self):
+ def test_for_subclass_implementers(self):
class GrandParent(object):
@@ -178,6 +178,43 @@ class DocControlsTest(googletest.TestCase):
self.assertTrue(
doc_controls.should_skip_class_attr(Grand2Child, 'my_method'))
+ def test_for_subclass_implementers_short_circuit(self):
+
+ class GrandParent(object):
+
+ @doc_controls.for_subclass_implementers
+ def my_method(self):
+ pass
+
+ class Parent(GrandParent):
+
+ def my_method(self):
+ pass
+
+ class Child(Parent):
+
+ @doc_controls.do_not_doc_inheritable
+ def my_method(self):
+ pass
+
+ class GrandChild(Child):
+
+ @doc_controls.for_subclass_implementers
+ def my_method(self):
+ pass
+
+ class Grand2Child(Child):
+ pass
+
+ self.assertFalse(
+ doc_controls.should_skip_class_attr(GrandParent, 'my_method'))
+ self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method'))
+ self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method'))
+ self.assertFalse(
+ doc_controls.should_skip_class_attr(GrandChild, 'my_method'))
+ self.assertTrue(
+ doc_controls.should_skip_class_attr(Grand2Child, 'my_method'))
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/tools/docs/generate.py b/tensorflow/tools/docs/generate.py
index f96887e4c7..fc93085e3e 100644
--- a/tensorflow/tools/docs/generate.py
+++ b/tensorflow/tools/docs/generate.py
@@ -31,11 +31,6 @@ if __name__ == '__main__':
doc_generator = generate_lib.DocGenerator()
doc_generator.add_output_dir_argument()
doc_generator.add_src_dir_argument()
- doc_generator.argument_parser.add_argument(
- '--site_api_path',
- type=str, default='api_docs/python',
- help='The path from the site-root to api_docs'
- 'directory for this project')
# This doc generator works on the TensorFlow codebase. Since this script lives
# at tensorflow/tools/docs, and all code is defined somewhere inside
diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py
index 9387042224..1cd9cb7ca9 100644
--- a/tensorflow/tools/docs/generate_lib.py
+++ b/tensorflow/tools/docs/generate_lib.py
@@ -22,6 +22,7 @@ import argparse
import fnmatch
import os
import shutil
+import tempfile
import six
@@ -35,29 +36,12 @@ from tensorflow.tools.docs import pretty_docs
from tensorflow.tools.docs import py_guide_parser
-def _is_free_function(py_object, full_name, index):
- """Check if input is a free function (and not a class- or static method)."""
- if not tf_inspect.isfunction(py_object):
- return False
-
- # Static methods are functions to tf_inspect (in 2.7), so check if the parent
- # is a class. If there is no parent, it's not a function.
- if '.' not in full_name:
- return False
-
- parent_name = full_name.rsplit('.', 1)[0]
- if tf_inspect.isclass(index[parent_name]):
- return False
-
- return True
-
-
def write_docs(output_dir,
parser_config,
yaml_toc,
root_title='TensorFlow',
search_hints=True,
- site_api_path=None):
+ site_api_path=''):
"""Write previously extracted docs to disk.
Write a docs page for each symbol included in the indices of parser_config to
@@ -75,8 +59,8 @@ def write_docs(output_dir,
root_title: The title name for the root level index.md.
search_hints: (bool) include meta-data search hints at the top of each
output file.
- site_api_path: Used to write the api-duplicates _redirects.yaml file. if
- None (the default) the file is not generated.
+ site_api_path: The output path relative to the site root. Used in the
+ `_toc.yaml` and `_redirects.yaml` files.
Raises:
ValueError: if `output_dir` is not an absolute path
@@ -108,10 +92,7 @@ def write_docs(output_dir,
# Methods and some routines are documented only as part of their class.
if not (tf_inspect.ismodule(py_object) or tf_inspect.isclass(py_object) or
- _is_free_function(py_object, full_name, parser_config.index)):
- continue
-
- if doc_controls.should_skip(py_object):
+ parser.is_free_function(py_object, full_name, parser_config.index)):
continue
sitepath = os.path.join('api_docs/python',
@@ -160,22 +141,23 @@ def write_docs(output_dir,
raise OSError(
'Cannot write documentation for %s to %s' % (full_name, directory))
- if site_api_path:
- duplicates = parser_config.duplicates.get(full_name, [])
- if not duplicates:
- continue
+ duplicates = parser_config.duplicates.get(full_name, [])
+ if not duplicates:
+ continue
- duplicates = [item for item in duplicates if item != full_name]
+ duplicates = [item for item in duplicates if item != full_name]
- for dup in duplicates:
- from_path = os.path.join(site_api_path, dup.replace('.', '/'))
- to_path = os.path.join(site_api_path, full_name.replace('.', '/'))
- redirects.append((from_path, to_path))
+ for dup in duplicates:
+ from_path = os.path.join(site_api_path, dup.replace('.', '/'))
+ to_path = os.path.join(site_api_path, full_name.replace('.', '/'))
+ redirects.append((
+ os.path.join('/', from_path),
+ os.path.join('/', to_path)))
- if site_api_path and redirects:
+ if redirects:
redirects = sorted(redirects)
- template = ('- from: /{}\n'
- ' to: /{}\n')
+ template = ('- from: {}\n'
+ ' to: {}\n')
redirects = [template.format(f, t) for f, t in redirects]
api_redirects_path = os.path.join(output_dir, '_redirects.yaml')
with open(api_redirects_path, 'w') as redirect_file:
@@ -210,7 +192,8 @@ def write_docs(output_dir,
'- title: ' + title,
' section:',
' - title: Overview',
- ' path: /TARGET_DOC_ROOT/VERSION/' + symbol_to_file[module]]
+ ' path: ' + os.path.join('/', site_api_path,
+ symbol_to_file[module])]
header = ''.join([indent+line+'\n' for line in header])
f.write(header)
@@ -221,7 +204,8 @@ def write_docs(output_dir,
for full_name in symbols_in_module:
item = [
' - title: ' + full_name[len(module) + 1:],
- ' path: /TARGET_DOC_ROOT/VERSION/' + symbol_to_file[full_name]]
+ ' path: ' + os.path.join('/', site_api_path,
+ symbol_to_file[full_name])]
item = ''.join([indent+line+'\n' for line in item])
f.write(item)
@@ -295,6 +279,15 @@ def _get_default_do_not_descend_map():
}
+class DocControlsAwareCrawler(public_api.PublicAPIVisitor):
+ """A `docs_controls` aware API-crawler."""
+
+ def _is_private(self, path, name, obj):
+ if doc_controls.should_skip(obj):
+ return True
+ return super(DocControlsAwareCrawler, self)._is_private(path, name, obj)
+
+
def extract(py_modules,
private_map,
do_not_descend_map,
@@ -302,7 +295,7 @@ def extract(py_modules,
"""Extract docs from tf namespace and write them to disk."""
# Traverse the first module.
visitor = visitor_cls(py_modules[0][0])
- api_visitor = public_api.PublicAPIVisitor(visitor)
+ api_visitor = DocControlsAwareCrawler(visitor)
api_visitor.set_root_name(py_modules[0][0])
add_dict_to_dict(private_map, api_visitor.private_map)
add_dict_to_dict(do_not_descend_map, api_visitor.do_not_descend_map)
@@ -408,8 +401,8 @@ class _GenerateGuideIndex(py_guide_parser.PyGuideParser):
self.section_tag = tag
def process_line(self, _, line):
- """Index @{symbol} references as in the current file & section."""
- for match in parser.SYMBOL_REFERENCE_RE.finditer(line):
+ """Index the file and section of each `symbol` reference."""
+ for match in parser.AUTO_REFERENCE_RE.finditer(line):
val = self.index.get(match.group(1), [])
val.append(
_GuideRef(self.base_name, self.title, self.section_title,
@@ -532,6 +525,19 @@ class DocGenerator(object):
action='store_false',
default=True)
+ self.argument_parser.add_argument(
+ '--site_api_path',
+ type=str, default='',
+ help='The path from the site-root to the api_docs '
+ 'directory for this project')
+
+ self.argument_parser.add_argument(
+ '--api_cache_out_path',
+ type=str,
+ default=None,
+ help='Path to store a json-serialized api-index, so links can be '
+ 'inserted into docs without rebuilding the api_docs')
+
def add_output_dir_argument(self):
self.argument_parser.add_argument(
'--output_dir',
@@ -544,9 +550,9 @@ class DocGenerator(object):
self.argument_parser.add_argument(
'--src_dir',
type=str,
- default=None,
- required=True,
- help='Directory with the source docs.')
+ default=tempfile.mkdtemp(),
+ required=False,
+ help='Optional directory of source docs to add api_docs links to')
def add_base_dir_argument(self, default_base_dir):
self.argument_parser.add_argument(
@@ -632,6 +638,9 @@ class DocGenerator(object):
visitor = self.run_extraction()
reference_resolver = self.make_reference_resolver(visitor, doc_index)
+ if getattr(flags, 'api_cache_out_path', None):
+ reference_resolver.to_json_file(flags.api_cache_out_path)
+
# Build the guide_index for the api_docs back links.
root_title = getattr(flags, 'root_title', 'TensorFlow')
guide_index = _build_guide_index(
@@ -648,7 +657,7 @@ class DocGenerator(object):
yaml_toc=self.yaml_toc,
root_title=root_title,
search_hints=getattr(flags, 'search_hints', True),
- site_api_path=getattr(flags, 'site_api_path', None))
+ site_api_path=getattr(flags, 'site_api_path', ''))
# Replace all the @{} references in files under `FLAGS.src_dir`
replace_refs(flags.src_dir, flags.output_dir, reference_resolver, '*.md')
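With `--site_api_path` now defaulting to `''` and every redirect pair rooted via `os.path.join('/', ...)`, the paths written to `_redirects.yaml` come out root-relative. A worked example (the symbol names here are made up for illustration):

    # Worked example of the redirect paths written by write_docs().
    import os

    site_api_path = 'api_docs/python'         # as passed via --site_api_path
    full_name, dup = 'tf.math.add', 'tf.add'  # master name and a duplicate

    from_path = os.path.join(site_api_path, dup.replace('.', '/'))
    to_path = os.path.join(site_api_path, full_name.replace('.', '/'))
    print((os.path.join('/', from_path), os.path.join('/', to_path)))
    # ('/api_docs/python/tf/add', '/api_docs/python/tf/math/add')

    # With the new default site_api_path='', the pairs stay root-relative:
    # ('/tf/add', '/tf/math/add')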
diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py
index 801c8bcb4a..a6159fa692 100644
--- a/tensorflow/tools/docs/parser.py
+++ b/tensorflow/tools/docs/parser.py
@@ -35,6 +35,28 @@ from tensorflow.python.util import tf_inspect
from tensorflow.tools.docs import doc_controls
+def is_free_function(py_object, full_name, index):
+ """Check if input is a free function (and not a class- or static method).
+
+ Args:
+ py_object: The the object in question.
+ full_name: The full name of the object, like `tf.module.symbol`.
+ index: The {full_name:py_object} dictionary for the public API.
+
+ Returns:
+ True if the object is a stand-alone function, and not part of a class
+ definition.
+ """
+ if not tf_inspect.isfunction(py_object):
+ return False
+
+ parent_name = full_name.rsplit('.', 1)[0]
+ if tf_inspect.isclass(index[parent_name]):
+ return False
+
+ return True
+
+
# A regular expression capturing a python identifier.
IDENTIFIER_RE = r'[a-zA-Z_]\w*'
@@ -74,7 +96,7 @@ class _Errors(object):
return self._errors == other._errors # pylint: disable=protected-access
-def documentation_path(full_name):
+def documentation_path(full_name, is_fragment=False):
"""Returns the file path for the documentation for the given API symbol.
Given the fully qualified name of a library symbol, compute the path to which
@@ -84,12 +106,22 @@ def documentation_path(full_name):
Args:
full_name: Fully qualified name of a library symbol.
-
+ is_fragment: If `False`, produce a direct markdown link (`tf.a.b.c` -->
+ `tf/a/b/c.md`). If `True`, produce a fragment link (`tf.a.b.c` -->
+ `tf/a/b.md#c`).
Returns:
The file path to which to write the documentation for `full_name`.
"""
- dirs = full_name.split('.')
- return os.path.join(*dirs) + '.md'
+ parts = full_name.split('.')
+ if is_fragment:
+ parts, fragment = parts[:-1], parts[-1]
+
+ result = os.path.join(*parts) + '.md'
+
+ if is_fragment:
+ result = result + '#' + fragment
+
+ return result
def _get_raw_docstring(py_object):
@@ -136,8 +168,7 @@ class ReferenceResolver(object):
doc.
"""
- def __init__(self, duplicate_of, doc_index, is_class, is_module,
- py_module_names):
+ def __init__(self, duplicate_of, doc_index, is_fragment, py_module_names):
"""Initializes a Reference Resolver.
Args:
@@ -145,15 +176,15 @@ class ReferenceResolver(object):
symbols.
doc_index: A `dict` mapping symbol name strings to objects with `url`
and `title` fields. Used to resolve @{$doc} references in docstrings.
- is_class: A map from full names to bool for each symbol.
- is_module: A map from full names to bool for each symbol.
+ is_fragment: A map from full names to bool for each symbol. If True, the
+ object lives at a page fragment: `tf.a.b.c` --> `tf/a/b#c`. If False, the
+ object has a page to itself: `tf.a.b.c` --> `tf/a/b/c`.
py_module_names: A list of string names of Python modules.
"""
self._duplicate_of = duplicate_of
self._doc_index = doc_index
- self._is_class = is_class
- self._is_module = is_module
- self._all_names = set(is_class.keys())
+ self._is_fragment = is_fragment
+ self._all_names = set(is_fragment.keys())
self._py_module_names = py_module_names
self.current_doc_full_name = None
@@ -180,21 +211,18 @@ class ReferenceResolver(object):
Returns:
an instance of `ReferenceResolver` ()
"""
- is_class = {
- name: tf_inspect.isclass(visitor.index[name])
- for name, obj in visitor.index.items()
- }
+ is_fragment = {}
+ for name, obj in visitor.index.items():
+ has_page = (
+ tf_inspect.isclass(obj) or tf_inspect.ismodule(obj) or
+ is_free_function(obj, name, visitor.index))
- is_module = {
- name: tf_inspect.ismodule(visitor.index[name])
- for name, obj in visitor.index.items()
- }
+ is_fragment[name] = not has_page
return cls(
duplicate_of=visitor.duplicate_of,
doc_index=doc_index,
- is_class=is_class,
- is_module=is_module,
+ is_fragment=is_fragment,
**kwargs)
@classmethod
@@ -210,6 +238,10 @@ class ReferenceResolver(object):
Args:
filepath: The file path to write the json to.
"""
+ try:
+ os.makedirs(os.path.dirname(filepath))
+ except OSError:
+ pass
json_dict = {}
for key, value in self.__dict__.items():
# Drop these two fields. `_doc_index` is not serializable. `_all_names` is
@@ -223,7 +255,7 @@ class ReferenceResolver(object):
json_dict[key.lstrip('_')] = value
with open(filepath, 'w') as f:
- json.dump(json_dict, f)
+ json.dump(json_dict, f, indent=2, sort_keys=True)
def replace_references(self, string, relative_path_to_root):
"""Replace "@{symbol}" references with links to symbol's documentation page.
@@ -339,19 +371,7 @@ class ReferenceResolver(object):
raise TFDocsError(
'Cannot make link to "%s": Not in index.' % master_name)
- # If this is a member of a class, link to the class page with an anchor.
- ref_path = None
- if not (self._is_class[master_name] or self._is_module[master_name]):
- idents = master_name.split('.')
- if len(idents) > 1:
- class_name = '.'.join(idents[:-1])
- assert class_name in self._all_names
- if self._is_class[class_name]:
- ref_path = documentation_path(class_name) + '#%s' % idents[-1]
-
- if not ref_path:
- ref_path = documentation_path(master_name)
-
+ ref_path = documentation_path(master_name, self._is_fragment[master_name])
return os.path.join(relative_path_to_root, ref_path)
def _one_ref(self, match, relative_path_to_root):
@@ -947,6 +967,7 @@ class _ClassPageInfo(object):
self._aliases = None
self._doc = None
self._guides = None
+ self._namedtuplefields = None
self._bases = None
self._properties = []
@@ -1030,6 +1051,17 @@ class _ClassPageInfo(object):
self._guides = guides
@property
+ def namedtuplefields(self):
+ return self._namedtuplefields
+
+ def set_namedtuplefields(self, py_class):
+ if issubclass(py_class, tuple):
+ if all(
+ hasattr(py_class, attr)
+ for attr in ('_asdict', '_fields', '_make', '_replace')):
+ self._namedtuplefields = py_class._fields
+
+ @property
def bases(self):
"""Returns a list of `_LinkInfo` objects pointing to the class' parents."""
return self._bases
@@ -1066,7 +1098,15 @@ class _ClassPageInfo(object):
@property
def properties(self):
"""Returns a list of `_PropertyInfo` describing the class' properties."""
- return self._properties
+ props_dict = {prop.short_name: prop for prop in self._properties}
+ props = []
+ if self.namedtuplefields:
+ for field in self.namedtuplefields:
+ props.append(props_dict.pop(field))
+
+ props.extend(sorted(props_dict.values()))
+
+ return props
def _add_property(self, short_name, full_name, obj, doc):
"""Adds a `_PropertyInfo` entry to the `properties` list.
@@ -1077,6 +1117,9 @@ class _ClassPageInfo(object):
obj: The property object itself
doc: The property's parsed docstring, a `_DocstringInfo`.
"""
+ # Hide useless namedtuple field docstrings.
+ if re.match('Alias for field number [0-9]+', doc.docstring):
+ doc = doc._replace(docstring='', brief='')
property_info = _PropertyInfo(short_name, full_name, obj, doc)
self._properties.append(property_info)
@@ -1156,6 +1199,7 @@ class _ClassPageInfo(object):
py_class: The class object being documented
parser_config: An instance of ParserConfig.
"""
+ self.set_namedtuplefields(py_class)
doc_path = documentation_path(self.full_name)
relative_path = os.path.relpath(
path='.', start=os.path.dirname(doc_path) or '.')
@@ -1695,15 +1739,18 @@ class _Metadata(object):
Attributes:
name: The name of the page being described by the Metadata block.
+ version: The source version.
"""
- def __init__(self, name):
+ def __init__(self, name, version='Stable'):
"""Creates a Metadata builder.
Args:
name: The name of the page being described by the Metadata block.
+ version: The source version.
"""
self.name = name
+ self.version = version
self._content = []
def append(self, item):
@@ -1720,6 +1767,7 @@ class _Metadata(object):
parts = ['<div itemscope itemtype="%s">' % schema]
parts.append('<meta itemprop="name" content="%s" />' % self.name)
+ parts.append('<meta itemprop="path" content="%s" />' % self.version)
for item in self._content:
parts.append('<meta itemprop="property" content="%s"/>' % item)
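The `is_fragment` plumbing above reduces link generation to a single `documentation_path` call. Illustrative usage, assuming the TensorFlow source tree is on the Python path (the symbol name is hypothetical):

    # Symbols with their own page produce direct links...
    from tensorflow.tools.docs import parser

    print(parser.documentation_path('tf.a.b.c'))
    # -> 'tf/a/b/c.md'

    # ...while fragment symbols (methods, members) link into the parent page.
    print(parser.documentation_path('tf.a.b.c', is_fragment=True))
    # -> 'tf/a/b.md#c'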
diff --git a/tensorflow/tools/docs/parser_test.py b/tensorflow/tools/docs/parser_test.py
index 9f6b185e81..8a41796fb9 100644
--- a/tensorflow/tools/docs/parser_test.py
+++ b/tensorflow/tools/docs/parser_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import functools
import os
import sys
@@ -27,6 +28,12 @@ from tensorflow.python.util import tf_inspect
from tensorflow.tools.docs import doc_controls
from tensorflow.tools.docs import parser
+# The test needs a real module. `types.ModuleType()` doesn't work, as the result
+# is a `builtin` module. Using "parser" here is arbitrary. The tests don't
+# depend on the module contents. At this point in the process the public api
+# has already been extracted.
+test_module = parser
+
def test_function(unused_arg, unused_kwarg='default'):
"""Docstring for test function."""
@@ -190,6 +197,50 @@ class ParserTest(googletest.TestCase):
# Make sure this file is contained as the definition location.
self.assertEqual(os.path.relpath(__file__, '/'), page_info.defined_in.path)
+ def test_namedtuple_field_order(self):
+ namedtupleclass = collections.namedtuple('namedtupleclass',
+ {'z', 'y', 'x', 'w', 'v', 'u'})
+
+ index = {
+ 'namedtupleclass': namedtupleclass,
+ 'namedtupleclass.u': namedtupleclass.u,
+ 'namedtupleclass.v': namedtupleclass.v,
+ 'namedtupleclass.w': namedtupleclass.w,
+ 'namedtupleclass.x': namedtupleclass.x,
+ 'namedtupleclass.y': namedtupleclass.y,
+ 'namedtupleclass.z': namedtupleclass.z,
+ }
+
+ visitor = DummyVisitor(index=index, duplicate_of={})
+
+ reference_resolver = parser.ReferenceResolver.from_visitor(
+ visitor=visitor, doc_index={}, py_module_names=['tf'])
+
+ tree = {'namedtupleclass': {'u', 'v', 'w', 'x', 'y', 'z'}}
+ parser_config = parser.ParserConfig(
+ reference_resolver=reference_resolver,
+ duplicates={},
+ duplicate_of={},
+ tree=tree,
+ index=index,
+ reverse_index={},
+ guide_index={},
+ base_dir='/')
+
+ page_info = parser.docs_for_object(
+ full_name='namedtupleclass',
+ py_object=namedtupleclass,
+ parser_config=parser_config)
+
+ # Each namedtuple field has a docstring of the form:
+ # 'Alias for field number ##'. These props are returned sorted by field number.
+
+ def sort_key(prop_info):
+ return int(prop_info.obj.__doc__.split(' ')[-1])
+
+ self.assertSequenceEqual(page_info.properties,
+ sorted(page_info.properties, key=sort_key))
+
def test_docs_for_class_should_skip(self):
class Parent(object):
@@ -289,15 +340,16 @@ class ParserTest(googletest.TestCase):
self.assertEqual('my_method', page_info.methods[0].short_name)
def test_docs_for_module(self):
- # Get the current module.
- module = sys.modules[__name__]
index = {
- 'TestModule': module,
- 'TestModule.test_function': test_function,
+ 'TestModule':
+ test_module,
+ 'TestModule.test_function':
+ test_function,
'TestModule.test_function_with_args_kwargs':
- test_function_with_args_kwargs,
- 'TestModule.TestClass': TestClass,
+ test_function_with_args_kwargs,
+ 'TestModule.TestClass':
+ TestClass,
}
visitor = DummyVisitor(index=index, duplicate_of={})
@@ -320,11 +372,13 @@ class ParserTest(googletest.TestCase):
base_dir='/')
page_info = parser.docs_for_object(
- full_name='TestModule', py_object=module, parser_config=parser_config)
+ full_name='TestModule',
+ py_object=test_module,
+ parser_config=parser_config)
# Make sure the brief docstring is present
- self.assertEqual(tf_inspect.getdoc(module).split('\n')[0],
- page_info.doc.brief)
+ self.assertEqual(
+ tf_inspect.getdoc(test_module).split('\n')[0], page_info.doc.brief)
# Make sure that the members are there
funcs = {f_info.obj for f_info in page_info.functions}
@@ -333,8 +387,9 @@ class ParserTest(googletest.TestCase):
classes = {cls_info.obj for cls_info in page_info.classes}
self.assertEqual({TestClass}, classes)
- # Make sure this file is contained as the definition location.
- self.assertEqual(os.path.relpath(__file__, '/'), page_info.defined_in.path)
+ # Make sure the module's file is contained as the definition location.
+ self.assertEqual(
+ os.path.relpath(test_module.__file__, '/'), page_info.defined_in.path)
def test_docs_for_function(self):
index = {
@@ -450,6 +505,7 @@ class ParserTest(googletest.TestCase):
duplicate_of = {'tf.third': 'tf.fourth'}
index = {
+ 'tf': test_module,
'tf.fancy': test_function_with_fancy_docstring,
'tf.reference': HasOneMember,
'tf.reference.foo': HasOneMember.foo,
@@ -476,20 +532,18 @@ class ParserTest(googletest.TestCase):
'NumPy has nothing as awesome as this function.\n')
def test_generate_index(self):
- module = sys.modules[__name__]
index = {
- 'TestModule': module,
- 'test_function': test_function,
- 'TestModule.test_function': test_function,
- 'TestModule.TestClass': TestClass,
- 'TestModule.TestClass.a_method': TestClass.a_method,
- 'TestModule.TestClass.a_property': TestClass.a_property,
- 'TestModule.TestClass.ChildClass': TestClass.ChildClass,
- }
- duplicate_of = {
- 'TestModule.test_function': 'test_function'
+ 'tf': test_module,
+ 'tf.TestModule': test_module,
+ 'tf.test_function': test_function,
+ 'tf.TestModule.test_function': test_function,
+ 'tf.TestModule.TestClass': TestClass,
+ 'tf.TestModule.TestClass.a_method': TestClass.a_method,
+ 'tf.TestModule.TestClass.a_property': TestClass.a_property,
+ 'tf.TestModule.TestClass.ChildClass': TestClass.ChildClass,
}
+ duplicate_of = {'tf.TestModule.test_function': 'tf.test_function'}
visitor = DummyVisitor(index=index, duplicate_of=duplicate_of)
@@ -508,7 +562,7 @@ class ParserTest(googletest.TestCase):
self.assertIn('TestModule.test_function', docs)
# Leading backtick to make sure it's included top-level.
# This depends on formatting, but should be stable.
- self.assertIn('<code>test_function', docs)
+ self.assertIn('<code>tf.test_function', docs)
def test_argspec_for_functools_partial(self):
# pylint: disable=unused-argument
@@ -620,22 +674,18 @@ class ParserTest(googletest.TestCase):
duplicate_of = {'AClass': ['AClass2']}
doc_index = {'doc': you_cant_serialize_this}
- is_class = {
+ is_fragment = {
'tf': False,
- 'tf.AClass': True,
- 'tf.AClass2': True,
- 'tf.function': False
- }
- is_module = {
- 'tf': True,
+ 'tf.VERSION': True,
'tf.AClass': False,
+ 'tf.AClass.method': True,
'tf.AClass2': False,
'tf.function': False
}
py_module_names = ['tf', 'tfdbg']
- resolver = parser.ReferenceResolver(duplicate_of, doc_index, is_class,
- is_module, py_module_names)
+ resolver = parser.ReferenceResolver(duplicate_of, doc_index, is_fragment,
+ py_module_names)
outdir = googletest.GetTempDir()
@@ -647,6 +697,23 @@ class ParserTest(googletest.TestCase):
# There are no __slots__, so all fields are visible in __dict__.
self.assertEqual(resolver.__dict__, resolver2.__dict__)
+ def testIsFreeFunction(self):
+
+ result = parser.is_free_function(test_function, 'test_module.test_function',
+ {'test_module': test_module})
+ self.assertTrue(result)
+
+ result = parser.is_free_function(test_function, 'TestClass.test_function',
+ {'TestClass': TestClass})
+ self.assertFalse(result)
+
+ result = parser.is_free_function(TestClass, 'TestClass', {})
+ self.assertFalse(result)
+
+ result = parser.is_free_function(test_module, 'test_module', {})
+ self.assertFalse(result)
+
+
RELU_DOC = """Computes rectified linear: `max(features, 0)`
Args:
@@ -736,6 +803,5 @@ class TestGenerateSignature(googletest.TestCase):
sig = parser._generate_signature(example_fun, reverse_index={})
self.assertEqual(sig, ['arg1=a.b.c.d', 'arg2=a.b.c.d(1, 2)', "arg3=e['f']"])
-
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/tools/docs/pretty_docs.py b/tensorflow/tools/docs/pretty_docs.py
index 63d4fef91c..1a3e79621f 100644
--- a/tensorflow/tools/docs/pretty_docs.py
+++ b/tensorflow/tools/docs/pretty_docs.py
@@ -93,6 +93,15 @@ def _build_class_page(page_info):
parts.append('\n\n')
+ # Sort the methods list, but make sure constructors come first.
+ constructor_names = ['__init__', '__new__']
+ constructors = sorted(
+ method for method in page_info.methods
+ if method.short_name in constructor_names)
+ other_methods = sorted(
+ method for method in page_info.methods
+ if method.short_name not in constructor_names)
+
if len(page_info.aliases) > 1:
parts.append('### Aliases:\n\n')
parts.extend('* Class `%s`\n' % name for name in page_info.aliases)
@@ -109,6 +118,11 @@ def _build_class_page(page_info):
parts.append('\n\n')
+ if constructors:
+ for method_info in constructors:
+ parts.append(_build_method_section(method_info, heading_level=2))
+ parts.append('\n\n')
+
if page_info.classes:
parts.append('## Child Classes\n')
@@ -122,7 +136,7 @@ def _build_class_page(page_info):
if page_info.properties:
parts.append('## Properties\n\n')
- for prop_info in sorted(page_info.properties):
+ for prop_info in page_info.properties:
h3 = '<h3 id="{short_name}"><code>{short_name}</code></h3>\n\n'
parts.append(h3.format(short_name=prop_info.short_name))
@@ -134,28 +148,11 @@ def _build_class_page(page_info):
parts.append('\n\n')
- if page_info.methods:
+ if other_methods:
parts.append('## Methods\n\n')
- # Sort the methods list, but make sure constructors come first.
- constructors = ['__init__', '__new__']
- inits = [method for method in page_info.methods
- if method.short_name in constructors]
- others = [method for method in page_info.methods
- if method.short_name not in constructors]
-
- for method_info in sorted(inits) + sorted(others):
- h3 = ('<h3 id="{short_name}">'
- '<code>{short_name}</code>'
- '</h3>\n\n')
- parts.append(h3.format(**method_info._asdict()))
-
- if method_info.signature is not None:
- parts.append(_build_signature(method_info, use_full_name=False))
-
- parts.append(method_info.doc.docstring)
- parts.append(_build_function_details(method_info.doc.function_details))
- parts.append(_build_compatibility(method_info.doc.compatibility))
- parts.append('\n\n')
+
+ for method_info in other_methods:
+ parts.append(_build_method_section(method_info))
parts.append('\n\n')
if page_info.other_members:
@@ -172,6 +169,33 @@ def _build_class_page(page_info):
return ''.join(parts)
+def _build_method_section(method_info, heading_level=3):
+ """Generates a markdown section for a method.
+
+ Args:
+ method_info: A `MethodInfo` object.
+ heading_level: An Int, which HTML heading level to use.
+
+ Returns:
+ A markdown string.
+ """
+ parts = []
+ heading = ('<h{heading_level} id="{short_name}">'
+ '<code>{short_name}</code>'
+ '</h{heading_level}>\n\n')
+ parts.append(heading.format(heading_level=heading_level,
+ **method_info._asdict()))
+
+ if method_info.signature is not None:
+ parts.append(_build_signature(method_info, use_full_name=False))
+
+ parts.append(method_info.doc.docstring)
+ parts.append(_build_function_details(method_info.doc.function_details))
+ parts.append(_build_compatibility(method_info.doc.compatibility))
+ parts.append('\n\n')
+ return ''.join(parts)
+
+
def _build_module_page(page_info):
"""Given a ClassPageInfo object Return the page as an md string."""
parts = ['# Module: {full_name}\n\n'.format(full_name=page_info.full_name)]
@@ -231,8 +255,9 @@ def _build_module_page(page_info):
# at least for basic types.
parts.append('## Other Members\n\n')
+ h3 = '<h3 id="{short_name}"><code>{short_name}</code></h3>\n\n'
for item in page_info.other_members:
- parts.append('`{short_name}`\n\n'.format(**item._asdict()))
+ parts.append(h3.format(**item._asdict()))
return ''.join(parts)
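The constructor-first ordering factored out above reduces to a simple partition-and-sort. A minimal sketch, with plain strings standing in for the real `MethodInfo` namedtuples:

    # Minimal sketch of the constructor-first method ordering.
    constructor_names = ['__init__', '__new__']
    methods = ['reset', '__init__', 'apply']

    constructors = sorted(m for m in methods if m in constructor_names)
    other_methods = sorted(m for m in methods if m not in constructor_names)

    print(constructors + other_methods)  # ['__init__', 'apply', 'reset']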
diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.cc b/tensorflow/tools/graph_transforms/fold_constants_lib.cc
index f858411876..6df2718e61 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_lib.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_lib.cc
@@ -121,7 +121,7 @@ Status RewriteInputsAsPlaceholders(const TransformFuncContext& context,
GraphDef* graph_def) {
std::unordered_set<string> input_names;
for (const string& input_name : context.input_names) {
- input_names.insert(ParseTensorName(input_name).first.ToString());
+ input_names.emplace(ParseTensorName(input_name).first);
}
for (NodeDef& node : *graph_def->mutable_node()) {
diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.h b/tensorflow/tools/graph_transforms/fold_constants_lib.h
index 8aefa6ae0f..0802ebb815 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_lib.h
+++ b/tensorflow/tools/graph_transforms/fold_constants_lib.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FOLD_CONSTANTS_H_
-#define TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FOLD_CONSTANTS_H_
+#ifndef TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FOLD_CONSTANTS_LIB_H_
+#define TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FOLD_CONSTANTS_LIB_H_
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
@@ -40,4 +40,4 @@ Status RemoveUnusedNodes(const GraphDef& input_graph_def,
} // namespace graph_transforms
} // namespace tensorflow
-#endif // TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FOLD_CONSTANTS_H_
+#endif // TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FOLD_CONSTANTS_LIB_H_
diff --git a/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc b/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc
index c8dc2a7c4d..d97496cbeb 100644
--- a/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc
+++ b/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc
@@ -92,7 +92,7 @@ Status ExtractMinMaxRecords(const string& log_file_name,
if (!str_util::EndsWith(name_string, print_suffix)) {
continue;
}
- string name = std::string(
+ string name(
name_string.substr(0, name_string.size() - print_suffix.size()));
records->push_back({name, min, max});
}
diff --git a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
index dd95779a1f..b8d6ba00de 100644
--- a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
+++ b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
@@ -42,8 +42,8 @@ class SparsifyGatherTest : public ::testing::Test {
const std::vector<NodeDef*>& inputs, GraphDef* graph_def,
bool control_dep = false) {
NodeDef* node_def = graph_def->add_node();
- node_def->set_name(std::string(name));
- node_def->set_op(std::string(op));
+ node_def->set_name(string(name));
+ node_def->set_op(string(op));
if (!control_dep) {
std::for_each(inputs.begin(), inputs.end(), [&node_def](NodeDef* input) {
node_def->add_input(input->name());
diff --git a/tensorflow/tools/graph_transforms/transform_graph.cc b/tensorflow/tools/graph_transforms/transform_graph.cc
index 5cae8f8d8f..7efe450710 100644
--- a/tensorflow/tools/graph_transforms/transform_graph.cc
+++ b/tensorflow/tools/graph_transforms/transform_graph.cc
@@ -65,19 +65,19 @@ Status ParseTransformParameters(const string& transforms_string,
.GetResult(&remaining, &transform_name);
if (!found_transform_name) {
return errors::InvalidArgument("Looking for transform name, but found ",
- std::string(remaining).c_str());
+ string(remaining).c_str());
}
if (Scanner(remaining).OneLiteral("(").GetResult(&remaining, &match)) {
state = TRANSFORM_PARAM_NAME;
} else {
// Add a transform with no parameters.
- params_list->push_back({std::string(transform_name), func_parameters});
+ params_list->push_back({string(transform_name), func_parameters});
transform_name = "";
state = TRANSFORM_NAME;
}
} else if (state == TRANSFORM_PARAM_NAME) {
if (Scanner(remaining).OneLiteral(")").GetResult(&remaining, &match)) {
- params_list->push_back({std::string(transform_name), func_parameters});
+ params_list->push_back({string(transform_name), func_parameters});
transform_name = "";
state = TRANSFORM_NAME;
} else {
@@ -92,13 +92,13 @@ Status ParseTransformParameters(const string& transforms_string,
if (!found_parameter_name) {
return errors::InvalidArgument(
"Looking for parameter name, but found ",
- std::string(remaining).c_str());
+ string(remaining).c_str());
}
if (Scanner(remaining).OneLiteral("=").GetResult(&remaining, &match)) {
state = TRANSFORM_PARAM_VALUE;
} else {
return errors::InvalidArgument("Looking for =, but found ",
- std::string(remaining).c_str());
+ string(remaining).c_str());
}
}
} else if (state == TRANSFORM_PARAM_VALUE) {
@@ -120,10 +120,9 @@ Status ParseTransformParameters(const string& transforms_string,
}
if (!found_parameter_value) {
return errors::InvalidArgument("Looking for parameter name, but found ",
- std::string(remaining).c_str());
+ string(remaining).c_str());
}
- func_parameters[std::string(parameter_name)].push_back(
- std::string(parameter_value));
+ func_parameters[string(parameter_name)].emplace_back(parameter_value);
// Eat up any trailing quotes.
Scanner(remaining).ZeroOrOneLiteral("\"").GetResult(&remaining, &match);
Scanner(remaining).ZeroOrOneLiteral("'").GetResult(&remaining, &match);
diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc
index cb084e49b7..c715380aae 100644
--- a/tensorflow/tools/graph_transforms/transform_utils.cc
+++ b/tensorflow/tools/graph_transforms/transform_utils.cc
@@ -93,7 +93,7 @@ void NodeNamePartsFromInput(const string& input_name, string* prefix,
} else {
*prefix = "";
}
- *node_name = std::string(node_name_piece);
+ *node_name = string(node_name_piece);
}
string NodeNameFromInput(const string& input_name) {
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 00c1337b19..91c5cd094c 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -13,6 +13,10 @@ load("//third_party/mkl:build_defs.bzl", "if_mkl", "if_mkl_ml")
load("//tensorflow:tensorflow.bzl", "if_cuda")
load("@local_config_syslibs//:build_defs.bzl", "if_not_system_lib")
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps")
+load(
+ "//third_party/ngraph:build_defs.bzl",
+ "if_ngraph",
+)
# This returns a list of headers of all public header libraries (e.g.,
# framework, lib), and all of the transitive dependencies of those
@@ -71,6 +75,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/contrib/constrained_optimization:constrained_optimization_pip",
"//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base",
"//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
+ "//tensorflow/contrib/data/python/kernel_tests:test_utils",
"//tensorflow/contrib/data/python/ops:contrib_op_loader",
"//tensorflow/contrib/eager/python/examples:examples_pip",
"//tensorflow/contrib/eager/python:evaluator",
@@ -82,6 +87,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/contrib/predictor:predictor_pip",
"//tensorflow/contrib/proto:proto",
"//tensorflow/contrib/receptive_field:receptive_field_pip",
+ "//tensorflow/contrib/rate:rate",
"//tensorflow/contrib/rpc:rpc_pip",
"//tensorflow/contrib/session_bundle:session_bundle_pip",
"//tensorflow/contrib/signal:signal_py",
@@ -200,14 +206,23 @@ filegroup(
"@grpc//third_party/nanopb:LICENSE.txt",
"@grpc//third_party/address_sorting:LICENSE",
],
- ) + tf_additional_license_deps(),
+ ) + if_ngraph([
+ "@ngraph//:LICENSE",
+ "@ngraph_tf//:LICENSE",
+ "@nlohmann_json_lib//:LICENSE.MIT",
+ ]) + tf_additional_license_deps(),
)
sh_binary(
name = "build_pip_package",
srcs = ["build_pip_package.sh"],
data = select({
- "//tensorflow:windows": [":simple_console_for_windows"],
+ "//tensorflow:windows": [
+ ":simple_console_for_windows",
+ "//tensorflow/contrib/lite/python:interpreter_test_data",
+ "//tensorflow/contrib/lite/python:tflite_convert",
+ "//tensorflow/contrib/lite/toco/python:toco_from_protos",
+ ],
"//conditions:default": COMMON_PIP_DEPS + [
":simple_console",
"//tensorflow/contrib/lite/python:interpreter_test_data",
diff --git a/tensorflow/tools/pip_package/MANIFEST.in b/tensorflow/tools/pip_package/MANIFEST.in
index 86c5e4776d..c4b4af93b8 100644
--- a/tensorflow/tools/pip_package/MANIFEST.in
+++ b/tensorflow/tools/pip_package/MANIFEST.in
@@ -1,5 +1,6 @@
include README
recursive-include * *.py
+recursive-include * *.pb
recursive-include * *.so
recursive-include * *.dll
recursive-include * *.lib
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 5e179079c5..3102239a19 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -51,9 +51,9 @@ REQUIRED_PACKAGES = [
'absl-py >= 0.1.6',
'astor >= 0.6.0',
'gast >= 0.2.0',
- 'keras_applications == 1.0.4',
- 'keras_preprocessing == 1.0.2',
- 'numpy >= 1.13.3, <= 1.14.5',
+ 'keras_applications >= 1.0.5',
+ 'keras_preprocessing >= 1.0.3',
+ 'numpy >= 1.13.3',
'six >= 1.10.0',
'protobuf >= 3.6.0',
'setuptools <= 39.1.0',
@@ -167,17 +167,21 @@ class InstallHeaders(Command):
# directories for -I
install_dir = re.sub('/google/protobuf_archive/src', '', install_dir)
- # Copy eigen code into tensorflow/include.
+ # Copy external code headers into tensorflow/include.
# A symlink would do, but the wheel file that gets created ignores
# symlink within the directory hierarchy.
# NOTE(keveman): Figure out how to customize bdist_wheel package so
# we can do the symlink.
- if 'tensorflow/include/external/eigen_archive/' in install_dir:
- extra_dir = install_dir.replace(
- 'tensorflow/include/external/eigen_archive', '')
- if not os.path.exists(extra_dir):
- self.mkpath(extra_dir)
- self.copy_file(header, extra_dir)
+ external_header_locations = [
+ 'tensorflow/include/external/eigen_archive/',
+ 'tensorflow/include/external/com_google_absl/',
+ ]
+ for location in external_header_locations:
+ if location in install_dir:
+ extra_dir = install_dir.replace(location, '')
+ if not os.path.exists(extra_dir):
+ self.mkpath(extra_dir)
+ self.copy_file(header, extra_dir)
if not os.path.exists(install_dir):
self.mkpath(install_dir)
@@ -227,6 +231,8 @@ headers = (list(find_files('*.h', 'tensorflow/core')) +
list(find_files('*.h', 'tensorflow/stream_executor')) +
list(find_files('*.h', 'google/protobuf_archive/src')) +
list(find_files('*', 'third_party/eigen3')) +
+ list(find_files('*.h',
+ 'tensorflow/include/external/com_google_absl')) +
list(find_files('*', 'tensorflow/include/external/eigen_archive')))
setup(
diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions_lib.h b/tensorflow/tools/proto_text/gen_proto_text_functions_lib.h
index e18d749cff..20aa605480 100644
--- a/tensorflow/tools/proto_text/gen_proto_text_functions_lib.h
+++ b/tensorflow/tools/proto_text/gen_proto_text_functions_lib.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_UTIL_CREATE_PROTO_DEBUG_STRING_LIB_H_
-#define TENSORFLOW_CORE_UTIL_CREATE_PROTO_DEBUG_STRING_LIB_H_
+#ifndef TENSORFLOW_TOOLS_PROTO_TEXT_GEN_PROTO_TEXT_FUNCTIONS_LIB_H_
+#define TENSORFLOW_TOOLS_PROTO_TEXT_GEN_PROTO_TEXT_FUNCTIONS_LIB_H_
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
@@ -50,4 +50,4 @@ ProtoTextFunctionCode GetProtoTextFunctionCode(
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_UTIL_CREATE_PROTO_DEBUG_STRING_LIB_H_
+#endif // TENSORFLOW_TOOLS_PROTO_TEXT_GEN_PROTO_TEXT_FUNCTIONS_LIB_H_
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 9d02cc2885..8e6f4143a9 100644..100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -19,10 +19,10 @@ load(
"//tensorflow/tools/def_file_filter:def_file_filter_configure.bzl",
"def_file_filter_configure",
)
+load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
def initialize_third_party():
- # Fill in later
- pass
+ flatbuffers()
# Sanitize a dependency so that it works correctly from code that includes
# TensorFlow as a submodule.
@@ -60,31 +60,31 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
mkl_repository(
name = "mkl_linux",
urls = [
- "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.15/mklml_lnx_2018.0.3.20180406.tgz",
- "https://github.com/intel/mkl-dnn/releases/download/v0.15/mklml_lnx_2018.0.3.20180406.tgz",
+ "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.16/mklml_lnx_2019.0.20180710.tgz",
+ "https://github.com/intel/mkl-dnn/releases/download/v0.16/mklml_lnx_2019.0.20180710.tgz",
],
- sha256 = "d2305244fdc9b87db7426ed4496e87a4b3977ad3374d73b8000e8b7a5b7aa725",
- strip_prefix = "mklml_lnx_2018.0.3.20180406",
+ sha256 = "e2233534a9d15c387e22260997af4312a39e9f86f791768409be273b5453c4e6",
+ strip_prefix = "mklml_lnx_2019.0.20180710",
build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
)
mkl_repository(
name = "mkl_windows",
urls = [
- "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.15/mklml_win_2018.0.3.20180406.zip",
- "https://github.com/intel/mkl-dnn/releases/download/v0.15/mklml_win_2018.0.3.20180406.zip",
+ "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.16/mklml_win_2019.0.20180710.zip",
+ "https://github.com/intel/mkl-dnn/releases/download/v0.16/mklml_win_2019.0.20180710.zip",
],
- sha256 = "a584a5bf1c8d2ad70b90d12b52652030e9a338217719064fdb84b7ad0d693694",
- strip_prefix = "mklml_win_2018.0.3.20180406",
+ sha256 = "3fdcff17b018a0082491adf3ba143358265336a801646e46e0191ec8d58d24a2",
+ strip_prefix = "mklml_win_2019.0.20180710",
build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
)
mkl_repository(
name = "mkl_darwin",
urls = [
- "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.15/mklml_mac_2018.0.3.20180406.tgz",
- "https://github.com/intel/mkl-dnn/releases/download/v0.15/mklml_mac_2018.0.3.20180406.tgz",
+ "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.16/mklml_mac_2019.0.20180710.tgz",
+ "https://github.com/intel/mkl-dnn/releases/download/v0.16/mklml_mac_2019.0.20180710.tgz",
],
- sha256 = "094e3dfd61c816136dc8d12a45cc611ce26c5f4828176a3644cd0b0efa15a25b",
- strip_prefix = "mklml_mac_2018.0.3.20180406",
+ sha256 = "411a30014a938eb83fb9f37b3dbe8e371b106fc1dd621fc23123cadc72737ce6",
+ strip_prefix = "mklml_mac_2019.0.20180710",
build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
)
@@ -95,22 +95,22 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "mkl_dnn",
urls = [
- "https://mirror.bazel.build/github.com/intel/mkl-dnn/archive/0c1cf54b63732e5a723c5670f66f6dfb19b64d20.tar.gz",
- "https://github.com/intel/mkl-dnn/archive/0c1cf54b63732e5a723c5670f66f6dfb19b64d20.tar.gz",
+ "https://mirror.bazel.build/github.com/intel/mkl-dnn/archive/4e333787e0d66a1dca1218e99a891d493dbc8ef1.tar.gz",
+ "https://github.com/intel/mkl-dnn/archive/4e333787e0d66a1dca1218e99a891d493dbc8ef1.tar.gz",
],
- sha256 = "da1f27f92453a65331197dd8e4992e810fb7b1c4e0b902a1da5611592df2b633",
- strip_prefix = "mkl-dnn-0c1cf54b63732e5a723c5670f66f6dfb19b64d20",
+ sha256 = "363cc9239eacf8e7917753c6d8c94f767e4cd049160d0654a61ef32d5e1b3049",
+ strip_prefix = "mkl-dnn-4e333787e0d66a1dca1218e99a891d493dbc8ef1",
build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"),
)
tf_http_archive(
name = "com_google_absl",
urls = [
- "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/9613678332c976568272c8f4a78631a29159271d.tar.gz",
- "https://github.com/abseil/abseil-cpp/archive/9613678332c976568272c8f4a78631a29159271d.tar.gz",
+ "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/fb462224c058487763f263b7995d70efd0242c17.tar.gz",
+ "https://github.com/abseil/abseil-cpp/archive/fb462224c058487763f263b7995d70efd0242c17.tar.gz",
],
- sha256 = "1273a1434ced93bc3e703a48c5dced058c95e995c8c009e9bdcb24a69e2180e9",
- strip_prefix = "abseil-cpp-9613678332c976568272c8f4a78631a29159271d",
+ sha256 = "f4f34f90083d5259f9a1a4067749d842599748d8ca03c1d9fe723124a7045c63",
+ strip_prefix = "abseil-cpp-fb462224c058487763f263b7995d70efd0242c17",
build_file = clean_dep("//third_party:com_google_absl.BUILD"),
)
@@ -240,11 +240,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "jpeg",
urls = [
- "https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.3.tar.gz",
- "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.3.tar.gz",
+ "https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.0.tar.gz",
+ "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.0.tar.gz",
],
- sha256 = "1a17020f859cb12711175a67eab5c71fc1904e04b587046218e36106e07eabde",
- strip_prefix = "libjpeg-turbo-1.5.3",
+ sha256 = "f892fff427ab3adffc289363eac26d197ce3ccacefe5f5822377348a8166069b",
+ strip_prefix = "libjpeg-turbo-2.0.0",
build_file = clean_dep("//third_party/jpeg:jpeg.BUILD"),
system_build_file = clean_dep("//third_party/systemlibs:jpeg.BUILD"),
)
@@ -365,14 +365,18 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
},
)
+ PROTOBUF_URLS = [
+ "https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz",
+ "https://github.com/google/protobuf/archive/v3.6.0.tar.gz",
+ ]
+ PROTOBUF_SHA256 = "50a5753995b3142627ac55cfd496cebc418a2e575ca0236e29033c67bd5665f4"
+ PROTOBUF_STRIP_PREFIX = "protobuf-3.6.0"
+
tf_http_archive(
name = "protobuf_archive",
- urls = [
- "https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz",
- "https://github.com/google/protobuf/archive/v3.6.0.tar.gz",
- ],
- sha256 = "50a5753995b3142627ac55cfd496cebc418a2e575ca0236e29033c67bd5665f4",
- strip_prefix = "protobuf-3.6.0",
+ urls = PROTOBUF_URLS,
+ sha256 = PROTOBUF_SHA256,
+ strip_prefix = PROTOBUF_STRIP_PREFIX,
)
# We need to import the protobuf library under the names com_google_protobuf
@@ -380,32 +384,27 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
# Unfortunately there is no way to alias http_archives at the moment.
tf_http_archive(
name = "com_google_protobuf",
- urls = [
- "https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz",
- "https://github.com/google/protobuf/archive/v3.6.0.tar.gz",
- ],
- sha256 = "50a5753995b3142627ac55cfd496cebc418a2e575ca0236e29033c67bd5665f4",
- strip_prefix = "protobuf-3.6.0",
+ urls = PROTOBUF_URLS,
+ sha256 = PROTOBUF_SHA256,
+ strip_prefix = PROTOBUF_STRIP_PREFIX,
)
tf_http_archive(
name = "com_google_protobuf_cc",
- urls = [
- "https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz",
- "https://github.com/google/protobuf/archive/v3.6.0.tar.gz",
- ],
- sha256 = "50a5753995b3142627ac55cfd496cebc418a2e575ca0236e29033c67bd5665f4",
- strip_prefix = "protobuf-3.6.0",
+ urls = PROTOBUF_URLS,
+ sha256 = PROTOBUF_SHA256,
+ strip_prefix = PROTOBUF_STRIP_PREFIX,
)
tf_http_archive(
name = "nsync",
urls = [
- "https://mirror.bazel.build/github.com/google/nsync/archive/1.20.0.tar.gz",
- "https://github.com/google/nsync/archive/1.20.0.tar.gz",
+ "https://mirror.bazel.build/github.com/google/nsync/archive/1.20.1.tar.gz",
+ "https://github.com/google/nsync/archive/1.20.1.tar.gz",
],
- sha256 = "0c1b03962b2f8450f21e74a5a46116bf2d6009a807c57eb4207e974a8c4bb7dd",
- strip_prefix = "nsync-1.20.0",
+ sha256 = "692f9b30e219f71a6371b98edd39cef3cbda35ac3abc4cd99ce19db430a5591a",
+ strip_prefix = "nsync-1.20.1",
+ system_build_file = clean_dep("//third_party/systemlibs:nsync.BUILD"),
)
tf_http_archive(
@@ -492,11 +491,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/17454e67ca55357e103cec104c3dc973bbb11ff0.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/17454e67ca55357e103cec104c3dc973bbb11ff0.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/738b5f5028ef39cbb023967f80fa2e5dd568556b.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/738b5f5028ef39cbb023967f80fa2e5dd568556b.tar.gz",
],
- sha256 = "7543322052e27e70f882801ef70a45afc268e09aaf6a07b840450bfcac366eb6",
- strip_prefix = "llvm-17454e67ca55357e103cec104c3dc973bbb11ff0",
+ sha256 = "2bda8dd724ab432c162fb6eace259ccf8a97f13cb627336611bff68da2f33ec2",
+ strip_prefix = "llvm-738b5f5028ef39cbb023967f80fa2e5dd568556b",
build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
)
@@ -527,11 +526,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "boringssl",
urls = [
- "https://mirror.bazel.build/github.com/google/boringssl/archive/45c4a87ae97eb95a8fc2906c035d6a8d0e02e1b8.tar.gz",
- "https://github.com/google/boringssl/archive/45c4a87ae97eb95a8fc2906c035d6a8d0e02e1b8.tar.gz",
+ "https://mirror.bazel.build/github.com/google/boringssl/archive/7f634429a04abc48e2eb041c81c5235816c96514.tar.gz",
+ "https://github.com/google/boringssl/archive/7f634429a04abc48e2eb041c81c5235816c96514.tar.gz",
],
- sha256 = "972e8d8a9d1daf9892fff7155312b1af46b4754446575a7b285e62f917424c78",
- strip_prefix = "boringssl-45c4a87ae97eb95a8fc2906c035d6a8d0e02e1b8",
+ sha256 = "1188e29000013ed6517168600fc35a010d58c5d321846d6a6dfee74e4c788b45",
+ strip_prefix = "boringssl-7f634429a04abc48e2eb041c81c5235816c96514",
)
tf_http_archive(
@@ -582,11 +581,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "kafka",
urls = [
- "https://mirror.bazel.build/github.com/edenhill/librdkafka/archive/v0.11.4.tar.gz",
- "https://github.com/edenhill/librdkafka/archive/v0.11.4.tar.gz",
+ "https://mirror.bazel.build/github.com/edenhill/librdkafka/archive/v0.11.5.tar.gz",
+ "https://github.com/edenhill/librdkafka/archive/v0.11.5.tar.gz",
],
- sha256 = "9d8f1eb7b0e29e9ab1168347c939cb7ae5dff00a39cef99e7ef033fd8f92737c",
- strip_prefix = "librdkafka-0.11.4",
+ sha256 = "cc6ebbcd0a826eec1b8ce1f625ffe71b53ef3290f8192b6cae38412a958f4fd3",
+ strip_prefix = "librdkafka-0.11.5",
build_file = clean_dep("//third_party:kafka/BUILD"),
patch_file = clean_dep("//third_party/kafka:config.patch"),
)
@@ -739,18 +738,6 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
build_file = clean_dep("//third_party:arm_neon_2_x86_sse.BUILD"),
)
- tf_http_archive(
- name = "flatbuffers",
- strip_prefix = "flatbuffers-1.9.0",
- sha256 = "5ca5491e4260cacae30f1a5786d109230db3f3a6e5a0eb45d0d0608293d247e3",
- urls = [
- "https://mirror.bazel.build/github.com/google/flatbuffers/archive/v1.9.0.tar.gz",
- "https://github.com/google/flatbuffers/archive/v1.9.0.tar.gz",
- ],
- build_file = clean_dep("//third_party/flatbuffers:flatbuffers.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:flatbuffers.BUILD"),
- )
-
native.new_http_archive(
name = "double_conversion",
urls = [
@@ -780,6 +767,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
],
build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
)
+
tf_http_archive(
name = "tflite_mobilenet_ssd_quant",
sha256 = "a809cd290b4d6a2e8a9d5dad076e0bd695b8091974e0eed1052b480b2f21b6dc",
@@ -791,6 +779,17 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
)
tf_http_archive(
+ name = "tflite_mobilenet_ssd_quant_protobuf",
+ sha256 = "09280972c5777f1aa775ef67cb4ac5d5ed21970acd8535aeca62450ef14f0d79",
+ urls = [
+ "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18.tar.gz",
+ "http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18.tar.gz",
+ ],
+ strip_prefix = "ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18",
+ build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
+ )
+
+ tf_http_archive(
name = "tflite_conv_actions_frozen",
sha256 = "d947b38cba389b5e2d0bfc3ea6cc49c784e187b41a071387b3742d1acac7691e",
urls = [
@@ -831,6 +830,39 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
strip_prefix = "rules_android-0.1.1",
)
+ tf_http_archive(
+ name = "ngraph",
+ urls = [
+ "https://mirror.bazel.build/github.com/NervanaSystems/ngraph/archive/v0.5.0.tar.gz",
+ "https://github.com/NervanaSystems/ngraph/archive/v0.5.0.tar.gz",
+ ],
+ sha256 = "cb35d3d98836f615408afd18371fb13e3400711247e0d822ba7f306c45e9bb2c",
+ strip_prefix = "ngraph-0.5.0",
+ build_file = clean_dep("//third_party/ngraph:ngraph.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "nlohmann_json_lib",
+ urls = [
+ "https://mirror.bazel.build/github.com/nlohmann/json/archive/v3.1.1.tar.gz",
+ "https://github.com/nlohmann/json/archive/v3.1.1.tar.gz",
+ ],
+ sha256 = "9f3549824af3ca7e9707a2503959886362801fb4926b869789d6929098a79e47",
+ strip_prefix = "json-3.1.1",
+ build_file = clean_dep("//third_party/ngraph:nlohmann_json.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "ngraph_tf",
+ urls = [
+ "https://mirror.bazel.build/github.com/NervanaSystems/ngraph-tf/archive/v0.3.0-rc1.tar.gz",
+ "https://github.com/NervanaSystems/ngraph-tf/archive/v0.3.0-rc1.tar.gz",
+ ],
+ sha256 = "7919332cb15120101c3e05c1b969a5e029a6411581312583c8f80b6aaaa83072",
+ strip_prefix = "ngraph-tf-0.3.0-rc1",
+ build_file = clean_dep("//third_party/ngraph:ngraph_tf.BUILD"),
+ )
+
##############################################################################
# BIND DEFINITIONS
#